use std::path::PathBuf;
use std::time::Duration;
use atm_core::{AgentType, Model, SessionDomain, SessionId};
use atm_protocol::{ClientMessage, DaemonMessage, MessageType, ProtocolVersion};
use atmd::registry::spawn_registry;
use atmd::server::DaemonServer;
use tempfile::TempDir;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::UnixStream;
use tokio::time::sleep;
use tokio_util::sync::CancellationToken;
/// Polling cadence used while waiting on the daemon's socket file.
const SOCKET_POLL_INTERVAL: Duration = Duration::from_millis(10);
/// Upper bound for polling on the daemon's socket file.
const SOCKET_WAIT_TIMEOUT: Duration = Duration::from_millis(500);
/// Time given to the daemon to react to cancellation or a disconnect before
/// the test moves on.
const SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_millis(100);
/// A `DaemonServer` running on a background Tokio task, listening on a Unix
/// socket inside a private temporary directory.
struct TestServer {
    /// Path of the Unix socket the server was asked to bind.
    socket_path: PathBuf,
    /// Token handed to the server; cancelling it asks the server to stop.
    cancel_token: CancellationToken,
    /// Task running `DaemonServer::run`; awaited (with a bound) on shutdown.
    server_task: tokio::task::JoinHandle<()>,
    /// Owns the temporary directory holding the socket for as long as the
    /// `TestServer` lives.
    _temp_dir: TempDir,
}

impl TestServer {
    /// Starts a server with a fresh registry and waits until its socket file
    /// exists.
    ///
    /// Returns the server together with a clone of the registry handle the
    /// server was given, so tests can register sessions directly.
    ///
    /// # Panics
    /// If the temporary directory cannot be created, or if the socket file
    /// does not appear within `SOCKET_WAIT_TIMEOUT`.
    async fn spawn_internal() -> (Self, atmd::registry::RegistryHandle) {
        let temp_dir = tempfile::tempdir().expect("create temp dir");
        let socket_path = temp_dir.path().join("test.sock");
        let registry = spawn_registry();
        let registry_handle = registry.clone();
        let cancel_token = CancellationToken::new();
        let server = DaemonServer::new(socket_path.clone(), registry, cancel_token.clone());
        let server_task = tokio::spawn(async move {
            // The result is deliberately dropped: a server that fails before
            // binding surfaces as the socket assertion below failing.
            let _ = server.run().await;
        });

        // Use the socket file's existence as the readiness signal.
        let start = tokio::time::Instant::now();
        while start.elapsed() < SOCKET_WAIT_TIMEOUT && !socket_path.exists() {
            sleep(SOCKET_POLL_INTERVAL).await;
        }
        assert!(
            socket_path.exists(),
            "Server socket did not appear within {SOCKET_WAIT_TIMEOUT:?}"
        );

        let test_server = TestServer {
            socket_path,
            cancel_token,
            server_task,
            _temp_dir: temp_dir,
        };
        (test_server, registry_handle)
    }

    /// Starts a server, discarding the registry handle.
    async fn spawn() -> Self {
        Self::spawn_internal().await.0
    }

    /// Starts a server and also returns a handle to its registry.
    async fn spawn_with_registry() -> (Self, atmd::registry::RegistryHandle) {
        Self::spawn_internal().await
    }

    /// Opens a new client connection to this server's socket.
    ///
    /// # Panics
    /// If the connection cannot be established.
    async fn connect(&self) -> TestClient {
        let stream = UnixStream::connect(&self.socket_path)
            .await
            .expect("connect to server");
        TestClient::new(stream)
    }

    /// Cancels the server and waits for its task to finish, giving up after
    /// `SHUTDOWN_GRACE_PERIOD`.
    ///
    /// Awaiting the task, rather than sleeping a fixed time, lets a test end
    /// as soon as the server stops. The bound keeps a slow shutdown from
    /// hanging the test.
    async fn shutdown(self) {
        self.cancel_token.cancel();
        // Teardown is best-effort: both a timeout and a failed server task
        // are ignored. The remaining fields (including the temp dir) stay
        // alive until this function returns.
        let _ = tokio::time::timeout(SHUTDOWN_GRACE_PERIOD, self.server_task).await;
    }
}
/// Test-side client for the daemon: a connected Unix socket split into a
/// buffered read half and a raw write half.
// Field order is kept as-is: it determines drop order of the two halves.
struct TestClient {
    /// Buffered read half, so replies can be read one line at a time.
    reader: BufReader<tokio::net::unix::OwnedReadHalf>,
    /// Write half used to send newline-terminated JSON messages.
    writer: tokio::net::unix::OwnedWriteHalf,
}
impl TestClient {
fn new(stream: UnixStream) -> Self {
let (reader, writer) = stream.into_split();
Self {
reader: BufReader::new(reader),
writer,
}
}
async fn send(&mut self, msg: ClientMessage) {
let json = serde_json::to_string(&msg).unwrap();
self.writer.write_all(json.as_bytes()).await.unwrap();
self.writer.write_all(b"\n").await.unwrap();
self.writer.flush().await.unwrap();
}
async fn recv(&mut self) -> DaemonMessage {
let mut line = String::new();
self.reader.read_line(&mut line).await.unwrap();
serde_json::from_str(&line).unwrap()
}
async fn handshake(&mut self, client_id: Option<String>) -> String {
self.send(ClientMessage::connect(client_id)).await;
match self.recv().await {
DaemonMessage::Connected { client_id, .. } => client_id,
other => panic!("Expected Connected, got {other:?}"),
}
}
async fn handshake_with_version(&mut self, version: ProtocolVersion) -> DaemonMessage {
let msg = ClientMessage {
protocol_version: version,
message: MessageType::Connect { client_id: None },
};
self.send(msg).await;
self.recv().await
}
}
/// Builds a `SessionDomain` with the given id, using
/// `AgentType::GeneralPurpose` and `Model::Sonnet4`.
fn create_test_session(id: &str) -> SessionDomain {
    let session_id = SessionId::new(id);
    SessionDomain::new(session_id, AgentType::GeneralPurpose, Model::Sonnet4)
}
/// A freshly spawned server accepts a socket connection.
#[tokio::test]
async fn test_server_accepts_connection() {
    let server = TestServer::spawn().await;
    // The connection is held open until after shutdown.
    let _open_connection = server.connect().await;
    server.shutdown().await;
}
/// Connecting with an explicit client id yields `Connected` carrying that id
/// and `ProtocolVersion::CURRENT`.
#[tokio::test]
async fn test_handshake_success() {
    let server = TestServer::spawn().await;
    let mut client = server.connect().await;
    let requested_id = "test-client".to_string();
    client.send(ClientMessage::connect(Some(requested_id))).await;

    let (protocol_version, client_id) = match client.recv().await {
        DaemonMessage::Connected {
            protocol_version,
            client_id,
        } => (protocol_version, client_id),
        other => panic!("Expected Connected, got {other:?}"),
    };
    assert_eq!(protocol_version, ProtocolVersion::CURRENT);
    assert_eq!(client_id, "test-client");

    server.shutdown().await;
}
/// Connecting without a client id results in a daemon-assigned id that
/// starts with `client-`.
#[tokio::test]
async fn test_handshake_auto_assigns_client_id() {
    let server = TestServer::spawn().await;
    let mut client = server.connect().await;
    // `handshake` sends Connect and unwraps the `Connected` reply's id.
    let client_id = client.handshake(None).await;
    assert!(
        client_id.starts_with("client-"),
        "Expected auto-assigned ID starting with 'client-', got: {client_id}"
    );
    server.shutdown().await;
}
/// A Connect claiming protocol version 99.0 is answered with `Rejected`,
/// and the reason mentions incompatibility.
#[tokio::test]
async fn test_handshake_version_mismatch() {
    let server = TestServer::spawn().await;
    let mut client = server.connect().await;
    let unsupported = ProtocolVersion::new(99, 0);
    let reason = match client.handshake_with_version(unsupported).await {
        DaemonMessage::Rejected { reason, .. } => reason,
        other => panic!("Expected Rejected, got {other:?}"),
    };
    assert!(
        reason.contains("not compatible"),
        "Expected 'not compatible' in reason, got: {reason}"
    );
    server.shutdown().await;
}
/// Subscribing yields an empty initial `SessionList`. After `Unsubscribe`,
/// the next reply read is the `SessionList` answering `list_sessions`, so as
/// written the test expects no reply to `Unsubscribe` itself.
#[tokio::test]
async fn test_subscribe_unsubscribe_flow() {
    let server = TestServer::spawn().await;
    let mut client = server.connect().await;
    client.handshake(Some("sub-client".to_string())).await;

    client.send(ClientMessage::subscribe(None)).await;
    let initial = match client.recv().await {
        DaemonMessage::SessionList { sessions } => sessions,
        other => panic!("Expected SessionList, got {other:?}"),
    };
    assert_eq!(initial.len(), 0, "Initial session list should be empty");

    client.send(ClientMessage::new(MessageType::Unsubscribe)).await;
    client.send(ClientMessage::list_sessions()).await;
    let reply = client.recv().await;
    assert!(
        matches!(reply, DaemonMessage::SessionList { .. }),
        "Expected SessionList after unsubscribe, got {reply:?}"
    );

    server.shutdown().await;
}
/// Subscribing with a session-id filter still yields a `SessionList` reply.
#[tokio::test]
async fn test_subscribe_with_session_filter() {
    let server = TestServer::spawn().await;
    let mut client = server.connect().await;
    client.handshake(None).await;

    let filter = Some(SessionId::new("specific-session"));
    client.send(ClientMessage::subscribe(filter)).await;
    let reply = client.recv().await;
    assert!(
        matches!(reply, DaemonMessage::SessionList { .. }),
        "Expected SessionList, got {reply:?}"
    );

    server.shutdown().await;
}
#[tokio::test]
async fn test_broadcast_respects_filter() {
let (server, registry) = TestServer::spawn_with_registry().await;
let mut client1 = server.connect().await;
client1.handshake(Some("client-1".to_string())).await;
client1
.send(ClientMessage::subscribe(Some(SessionId::new("target"))))
.await;
match client1.recv().await {
DaemonMessage::SessionList { sessions } => {
assert_eq!(sessions.len(), 0, "Initial session list should be empty");
}
other => panic!("Expected SessionList, got {other:?}"),
}
let mut client2 = server.connect().await;
client2.handshake(Some("client-2".to_string())).await;
client2.send(ClientMessage::subscribe(None)).await;
match client2.recv().await {
DaemonMessage::SessionList { sessions } => {
assert_eq!(sessions.len(), 0, "Initial session list should be empty");
}
other => panic!("Expected SessionList, got {other:?}"),
}
let target_session = create_test_session("target");
registry
.register(target_session)
.await
.expect("register target session");
let other_session = create_test_session("other");
registry
.register(other_session)
.await
.expect("register other session");
client1.send(ClientMessage::list_sessions()).await;
match client1.recv().await {
DaemonMessage::SessionList { sessions } => {
assert_eq!(sessions.len(), 2, "Should have 2 sessions registered");
}
other => panic!("Expected SessionList with 2 sessions, got {other:?}"),
}
server.shutdown().await;
}
/// Number of simultaneously subscribed clients `test_max_clients_rejection`
/// opens.
///
/// NOTE(review): presumably mirrors the daemon's own client limit; confirm
/// against atmd.
const MAX_TUI_CLIENTS: usize = 10;

/// Opens `MAX_TUI_CLIENTS` subscribed clients and checks that every one of
/// them still answers a ping with the matching sequence number.
///
/// NOTE(review): despite its name, this test never attempts a connection
/// beyond the limit, so no rejection path is exercised.
#[tokio::test]
async fn test_max_clients_rejection() {
    let server = TestServer::spawn().await;

    let mut clients = Vec::with_capacity(MAX_TUI_CLIENTS);
    for i in 0..MAX_TUI_CLIENTS {
        let mut client = server.connect().await;
        client.handshake(Some(format!("client-{i}"))).await;
        client.send(ClientMessage::subscribe(None)).await;
        // Drain the initial session list; its contents are not checked here.
        let _ = client.recv().await;
        clients.push(client);
    }

    for (expected, client) in (0u64..).zip(clients.iter_mut()) {
        client.send(ClientMessage::ping(expected)).await;
        match client.recv().await {
            DaemonMessage::Pong { seq } => assert_eq!(seq, expected),
            other => panic!("Expected Pong for client {expected}, got {other:?}"),
        }
    }

    server.shutdown().await;
}
#[tokio::test]
async fn test_graceful_shutdown() {
let server = TestServer::spawn().await;
let mut client = server.connect().await;
client.handshake(None).await;
let socket_path = server.socket_path.clone();
server.cancel_token.cancel();
sleep(SHUTDOWN_GRACE_PERIOD).await;
assert!(
!socket_path.exists(),
"Socket file should be removed after shutdown"
);
}
/// Sends `Disconnect` after a handshake, waits briefly, then shuts the
/// server down.
///
/// NOTE(review): nothing is asserted about how the daemon reacts to the
/// message. The test only checks that this sequence completes without
/// panicking.
#[tokio::test]
async fn test_client_disconnect_message() {
    let server = TestServer::spawn().await;
    let mut client = server.connect().await;
    client.handshake(None).await;

    let goodbye = ClientMessage::disconnect();
    client.send(goodbye).await;
    // Give the daemon a moment to process the message before teardown.
    sleep(SHUTDOWN_GRACE_PERIOD).await;

    server.shutdown().await;
}
/// A ping after the handshake is answered with a pong echoing its sequence
/// number.
#[tokio::test]
async fn test_ping_pong() {
    const PING_SEQ: u64 = 42;

    let server = TestServer::spawn().await;
    let mut client = server.connect().await;
    client.handshake(None).await;

    client.send(ClientMessage::ping(PING_SEQ)).await;
    let echoed = match client.recv().await {
        DaemonMessage::Pong { seq } => seq,
        other => panic!("Expected Pong, got {other:?}"),
    };
    assert_eq!(echoed, PING_SEQ, "Pong seq should match ping seq");

    server.shutdown().await;
}
/// `list_sessions` on a fresh server returns an empty `SessionList`.
#[tokio::test]
async fn test_list_sessions_command() {
    let server = TestServer::spawn().await;
    let mut client = server.connect().await;
    client.handshake(None).await;

    client.send(ClientMessage::list_sessions()).await;
    let sessions = match client.recv().await {
        DaemonMessage::SessionList { sessions } => sessions,
        other => panic!("Expected SessionList, got {other:?}"),
    };
    assert_eq!(sessions.len(), 0);

    server.shutdown().await;
}
#[tokio::test]
async fn test_wrong_message_before_handshake() {
let server = TestServer::spawn().await;
let mut client = server.connect().await;
client.send(ClientMessage::list_sessions()).await;
match client.recv().await {
DaemonMessage::Error { message, .. } => {
assert!(
message.contains("Expected Connect"),
"Error should mention expected Connect message, got: {message}"
);
}
other => panic!("Expected Error, got {other:?}"),
}
server.shutdown().await;
}
#[tokio::test]
async fn test_duplicate_connect_rejected() {
let server = TestServer::spawn().await;
let mut client = server.connect().await;
client.handshake(None).await;
client.send(ClientMessage::connect(None)).await;
match client.recv().await {
DaemonMessage::Error { message, .. } => {
assert!(
message.contains("Already connected"),
"Error should mention 'Already connected', got: {message}"
);
}
other => panic!("Expected Error, got {other:?}"),
}
server.shutdown().await;
}
/// Five clients connect from separate tasks. Each gets its requested id
/// back and a reply to `list_sessions`.
#[tokio::test]
async fn test_multiple_clients_concurrent() {
    let server = TestServer::spawn().await;

    let handles: Vec<_> = (0..5)
        .map(|i| {
            let socket_path = server.socket_path.clone();
            tokio::spawn(async move {
                let expected_id = format!("concurrent-{i}");
                let stream = UnixStream::connect(&socket_path).await.unwrap();
                let mut client = TestClient::new(stream);
                let assigned_id = client.handshake(Some(expected_id.clone())).await;
                assert_eq!(assigned_id, expected_id);
                client.send(ClientMessage::list_sessions()).await;
                // Only the arrival of a reply matters; its contents are ignored.
                let _ = client.recv().await;
            })
        })
        .collect();

    // Propagate any panic from a client task into this test.
    for task in handles {
        task.await.expect("concurrent client task should succeed");
    }

    server.shutdown().await;
}
/// Three clients each send a ping before any pong is read. Every client then
/// receives its own sequence number (index × 100).
#[tokio::test]
async fn test_concurrent_ping_pong() {
    let seq_for = |index: usize| (index * 100) as u64;
    let server = TestServer::spawn().await;

    let mut clients = Vec::new();
    for i in 0..3 {
        let mut client = server.connect().await;
        client.handshake(Some(format!("ping-client-{i}"))).await;
        clients.push(client);
    }

    // Send every ping first so the pings are outstanding at the same time.
    for (index, client) in clients.iter_mut().enumerate() {
        client.send(ClientMessage::ping(seq_for(index))).await;
    }

    for (i, client) in clients.iter_mut().enumerate() {
        match client.recv().await {
            DaemonMessage::Pong { seq } => assert_eq!(seq, seq_for(i)),
            other => panic!("Expected Pong for client {i}, got {other:?}"),
        }
    }

    server.shutdown().await;
}