use std::time::Duration;
use crate::support::integration_test_lock;
use embers_core::{ErrorCode, MuxError, RequestId, SessionId, new_request_id};
use embers_protocol::{
ClientMessage, FrameType, NodeRequest, PingRequest, RawFrame, ServerEnvelope, ServerEvent,
ServerResponse, SessionRequest, decode_server_envelope, encode_client_message, read_frame,
write_frame,
};
use embers_test_support::{TestConnection, TestServer};
use tokio::io::AsyncWriteExt;
use tokio::net::UnixStream;
use tokio::time::sleep;
/// Sends a `SessionRequest::Create` for `name` over `connection` and returns
/// the session snapshot carried by the reply.
///
/// # Panics
///
/// Panics if the request itself fails or if the server answers with anything
/// other than `ServerResponse::SessionSnapshot`.
async fn create_session(
    connection: &mut TestConnection,
    name: &str,
) -> embers_protocol::SessionSnapshotResponse {
    let create_message = ClientMessage::Session(SessionRequest::Create {
        request_id: new_request_id(),
        name: name.to_owned(),
    });
    let reply = connection
        .request(&create_message)
        .await
        .expect("create session request succeeds");

    // Any reply other than a snapshot is a test failure.
    match reply {
        ServerResponse::SessionSnapshot(snapshot) => snapshot,
        other => panic!("expected session snapshot response, got {other:?}"),
    }
}
fn expect_error(
response: ServerResponse,
request_id: Option<RequestId>,
code: ErrorCode,
) -> String {
match response {
ServerResponse::Error(error) => {
assert_eq!(error.request_id, request_id);
assert_eq!(error.error.code, code);
error.error.message
}
other => panic!("expected error response, got {other:?}"),
}
}
fn encode_frame_bytes(frame: &RawFrame) -> Vec<u8> {
let mut bytes = Vec::with_capacity(13 + frame.payload.len());
bytes.extend_from_slice(&(frame.payload.len() as u32).to_le_bytes());
bytes.push(frame.frame_type as u8);
bytes.extend_from_slice(&u64::from(frame.request_id).to_le_bytes());
bytes.extend_from_slice(&frame.payload);
bytes
}
/// Checks that session-close events fan out according to each subscriber's
/// filter.
///
/// An actor connection creates the sessions "main" and "other" and then closes
/// "main". The global subscriber and the subscriber scoped to "main" are
/// expected to observe `ServerEvent::SessionClosed` for that session, while
/// the subscriber scoped to "other" is expected to time out instead.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn subscriptions_fan_out_to_multiple_clients_with_session_filters() {
    let _guard = integration_test_lock().lock().await;
    let server = TestServer::start().await.expect("start server");

    // One connection performs mutations; the other three only listen.
    let mut actor = TestConnection::connect(server.socket_path())
        .await
        .expect("connect actor");
    let mut global = TestConnection::connect(server.socket_path())
        .await
        .expect("connect global subscriber");
    let mut scoped = TestConnection::connect(server.socket_path())
        .await
        .expect("connect scoped subscriber");
    let mut other_scope = TestConnection::connect(server.socket_path())
        .await
        .expect("connect other scoped subscriber");

    let main_id = create_session(&mut actor, "main").await.snapshot.session.id;
    let other_id = create_session(&mut actor, "other").await.snapshot.session.id;

    global.subscribe(None).await.expect("subscribe globally");
    scoped
        .subscribe(Some(main_id))
        .await
        .expect("subscribe to main session");
    other_scope
        .subscribe(Some(other_id))
        .await
        .expect("subscribe to other session");

    let close_main = ClientMessage::Session(SessionRequest::Close {
        request_id: RequestId(41),
        session_id: main_id,
        force: false,
    });
    let close_outcome = actor
        .request(&close_main)
        .await
        .expect("close session request succeeds");
    assert!(matches!(close_outcome, ServerResponse::Ok(_)));

    // Subscribers that should see the event get a generous window; the
    // subscriber that should stay silent gets a shorter one.
    let delivery_timeout = Duration::from_secs(1);
    let silence_window = Duration::from_millis(200);

    let seen_by_global = global
        .wait_for_event(delivery_timeout, |candidate| {
            matches!(
                candidate,
                ServerEvent::SessionClosed(closed) if closed.session_id == main_id
            )
        })
        .await
        .expect("global subscriber receives session close");
    assert!(matches!(
        seen_by_global,
        ServerEvent::SessionClosed(closed) if closed.session_id == main_id
    ));

    let seen_by_scoped = scoped
        .wait_for_event(delivery_timeout, |candidate| {
            matches!(
                candidate,
                ServerEvent::SessionClosed(closed) if closed.session_id == main_id
            )
        })
        .await
        .expect("scoped subscriber receives matching session close");
    assert!(matches!(
        seen_by_scoped,
        ServerEvent::SessionClosed(closed) if closed.session_id == main_id
    ));

    // The subscriber filtered to the other session must not be notified.
    let silence = other_scope
        .wait_for_event(silence_window, |candidate| {
            matches!(
                candidate,
                ServerEvent::SessionClosed(closed) if closed.session_id == main_id
            )
        })
        .await
        .expect_err("non-matching scoped subscriber should not receive the event");
    assert!(matches!(silence, MuxError::Timeout(_)));

    server.shutdown().await.expect("shutdown server");
}
/// Writes a ping request to the server socket in 3-byte chunks and checks that
/// the reply is a `Response` frame with the same request id whose payload
/// decodes to a pong echoing "fragmented".
#[tokio::test]
async fn fragmented_request_frames_round_trip_and_preserve_correlation_id() {
    let _guard = integration_test_lock().lock().await;
    let server = TestServer::start().await.expect("start server");
    let mut stream = UnixStream::connect(server.socket_path())
        .await
        .expect("connect raw client");

    let request_id = RequestId(52);
    let ping = ClientMessage::Ping(PingRequest {
        request_id,
        payload: "fragmented".to_owned(),
    });
    let encoded_ping = encode_client_message(&ping).expect("encode ping request");
    let frame = RawFrame::new(FrameType::Request, request_id, encoded_ping);

    // Write the frame a few bytes at a time, yielding between writes.
    let wire_bytes = encode_frame_bytes(&frame);
    for piece in wire_bytes.chunks(3) {
        stream.write_all(piece).await.expect("write request chunk");
        tokio::task::yield_now().await;
    }

    let reply = read_frame(&mut stream)
        .await
        .expect("read response frame")
        .expect("response frame");
    assert_eq!(reply.frame_type, FrameType::Response);
    assert_eq!(reply.request_id, request_id);

    let envelope = decode_server_envelope(&reply.payload).expect("decode response payload");
    let pong = match envelope {
        ServerEnvelope::Response(ServerResponse::Pong(pong)) => pong,
        other => panic!("expected pong response, got {other:?}"),
    };
    assert_eq!(pong.request_id, request_id);
    assert_eq!(pong.payload, "fragmented");

    server.shutdown().await.expect("shutdown server");
}
#[tokio::test]
async fn malformed_payloads_return_protocol_violation_errors() {
let _guard = integration_test_lock().lock().await;
let server = TestServer::start().await.expect("start server");
let mut stream = UnixStream::connect(server.socket_path())
.await
.expect("connect raw client");
let request_id = RequestId(61);
let malformed = RawFrame::new(FrameType::Request, request_id, vec![0, 1, 2, 3, 4]);
write_frame(&mut stream, &malformed)
.await
.expect("write malformed request");
let response_frame = read_frame(&mut stream)
.await
.expect("read response frame")
.expect("response frame");
assert_eq!(response_frame.frame_type, FrameType::Response);
assert_eq!(response_frame.request_id, request_id);
match decode_server_envelope(&response_frame.payload).expect("decode response payload") {
ServerEnvelope::Response(ServerResponse::Error(error)) => {
assert_eq!(error.request_id, Some(request_id));
assert_eq!(error.error.code, ErrorCode::ProtocolViolation);
}
other => panic!("expected protocol violation response, got {other:?}"),
}
server.shutdown().await.expect("shutdown server");
}
/// Issues three requests that are expected to fail and checks the typed error
/// each one produces:
///
/// - getting session id 999 yields `ErrorCode::NotFound`;
/// - focusing the root node of a freshly created session yields
///   `ErrorCode::InvalidRequest` with a message containing "no focusable leaf";
/// - moving buffer id 1 onto that same root node yields
///   `ErrorCode::InvalidRequest`.
#[tokio::test]
async fn typed_errors_cover_invalid_ids_and_impossible_mutations() {
    let _guard = integration_test_lock().lock().await;
    let server = TestServer::start().await.expect("start server");
    let mut connection = TestConnection::connect(server.socket_path())
        .await
        .expect("connect client");

    // Lookup of a session id used before any session has been created.
    let lookup_id = RequestId(71);
    let lookup = ClientMessage::Session(SessionRequest::Get {
        request_id: lookup_id,
        session_id: SessionId(999),
    });
    let lookup_reply = connection
        .request(&lookup)
        .await
        .expect("missing session request returns response");
    expect_error(lookup_reply, Some(lookup_id), ErrorCode::NotFound);

    let session = create_session(&mut connection, "empty").await;
    let session_id = session.snapshot.session.id;
    let root_node_id = session.snapshot.session.root_node_id;

    // Focus request aimed at the new session's root node.
    let focus_id = RequestId(72);
    let focus = ClientMessage::Node(NodeRequest::Focus {
        request_id: focus_id,
        session_id,
        node_id: root_node_id,
    });
    let focus_reply = connection
        .request(&focus)
        .await
        .expect("invalid focus request returns response");
    let focus_message = expect_error(focus_reply, Some(focus_id), ErrorCode::InvalidRequest);
    assert!(focus_message.contains("no focusable leaf"));

    // Buffer move that targets the root node as if it were a leaf.
    let move_id = RequestId(73);
    let move_buffer = ClientMessage::Node(NodeRequest::MoveBufferToNode {
        request_id: move_id,
        buffer_id: embers_core::BufferId(1),
        target_leaf_node_id: root_node_id,
    });
    let move_reply = connection
        .request(&move_buffer)
        .await
        .expect("invalid move request returns response");
    expect_error(move_reply, Some(move_id), ErrorCode::InvalidRequest);

    server.shutdown().await.expect("shutdown server");
}
/// Subscribes two clients to the same session, drops one of them, then closes
/// the session and checks that the remaining subscriber still receives
/// `ServerEvent::SessionClosed` and that the actor connection can still ping
/// the server afterwards.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn disconnected_subscribers_are_cleaned_up_without_breaking_remaining_clients() {
    let _guard = integration_test_lock().lock().await;
    let server = TestServer::start().await.expect("start server");

    let mut actor = TestConnection::connect(server.socket_path())
        .await
        .expect("connect actor");
    let mut survivor = TestConnection::connect(server.socket_path())
        .await
        .expect("connect surviving subscriber");
    let mut doomed = TestConnection::connect(server.socket_path())
        .await
        .expect("connect subscriber to disconnect");

    let session_id = create_session(&mut actor, "cleanup")
        .await
        .snapshot
        .session
        .id;

    survivor
        .subscribe(Some(session_id))
        .await
        .expect("subscribe surviving client");
    doomed
        .subscribe(Some(session_id))
        .await
        .expect("subscribe client to disconnect");

    // Drop one subscriber, then pause briefly before triggering the event.
    // NOTE(review): the pause presumably gives the server time to notice the
    // disconnect; confirm whether a deterministic signal is available.
    drop(doomed);
    sleep(Duration::from_millis(50)).await;

    let close_session = ClientMessage::Session(SessionRequest::Close {
        request_id: RequestId(81),
        session_id,
        force: false,
    });
    let close_outcome = actor
        .request(&close_session)
        .await
        .expect("close session succeeds");
    assert!(matches!(close_outcome, ServerResponse::Ok(_)));

    let delivered = survivor
        .wait_for_event(Duration::from_secs(1), |candidate| {
            matches!(
                candidate,
                ServerEvent::SessionClosed(closed) if closed.session_id == session_id
            )
        })
        .await
        .expect("surviving subscriber receives session close");
    assert!(matches!(
        delivered,
        ServerEvent::SessionClosed(closed) if closed.session_id == session_id
    ));

    // The server must still answer ordinary requests after the disconnect.
    let echoed = actor
        .ping("still-alive")
        .await
        .expect("server stays usable");
    assert_eq!(echoed, "still-alive");

    server.shutdown().await.expect("shutdown server");
}