use std::path::PathBuf;
use std::time::Duration;
use atm_core::{AgentType, Model, SessionDomain, SessionId};
use atm_protocol::{ClientMessage, DaemonMessage, MessageType, ProtocolVersion};
use atmd::registry::spawn_registry;
use atmd::server::DaemonServer;
use tempfile::TempDir;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::UnixStream;
use tokio::time::sleep;
use tokio_util::sync::CancellationToken;
/// Upper bound on how long a freshly spawned daemon may take to create its socket file.
const SOCKET_WAIT_TIMEOUT: Duration = Duration::from_millis(500);
/// Pause between successive checks for the socket file while waiting.
const SOCKET_POLL_INTERVAL: Duration = Duration::from_millis(10);
/// Settling delay granted to the daemon after a shutdown or disconnect request.
const SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_millis(100);
/// A daemon running in a background task, listening on a Unix socket that
/// lives inside a private temporary directory.
struct TestServer {
    /// Socket the daemon listens on (located inside `_temp_dir`).
    socket_path: PathBuf,
    /// Cancelling this token asks the daemon's `run` loop to stop.
    cancel_token: CancellationToken,
    /// Task driving `DaemonServer::run`; joined (with a bound) on shutdown.
    server_task: tokio::task::JoinHandle<()>,
    /// Held only so the temporary directory is removed when the server is dropped.
    _temp_dir: TempDir,
}

impl TestServer {
    /// Starts a daemon on a fresh socket and waits up to `SOCKET_WAIT_TIMEOUT`
    /// for the socket file to appear.
    ///
    /// Returns the server together with a handle to the registry it serves.
    ///
    /// # Panics
    /// If the temp dir cannot be created or the socket never appears.
    async fn spawn_internal() -> (Self, atmd::registry::RegistryHandle) {
        let temp_dir = tempfile::tempdir().expect("create temp dir");
        let socket_path = temp_dir.path().join("test.sock");
        let registry = spawn_registry();
        let registry_handle = registry.clone();
        let cancel_token = CancellationToken::new();
        let server = DaemonServer::new(socket_path.clone(), registry, cancel_token.clone());
        // The run result is discarded on purpose; if `run` fails before the
        // socket is created, the assertion below reports it.
        let server_task = tokio::spawn(async move {
            let _ = server.run().await;
        });

        let deadline = tokio::time::Instant::now() + SOCKET_WAIT_TIMEOUT;
        while !socket_path.exists() && tokio::time::Instant::now() < deadline {
            sleep(SOCKET_POLL_INTERVAL).await;
        }
        assert!(
            socket_path.exists(),
            "Server socket did not appear within {SOCKET_WAIT_TIMEOUT:?}"
        );

        let test_server = TestServer {
            socket_path,
            cancel_token,
            server_task,
            _temp_dir: temp_dir,
        };
        (test_server, registry_handle)
    }

    /// Spawns a daemon when the test does not need direct registry access.
    async fn spawn() -> Self {
        Self::spawn_internal().await.0
    }

    /// Spawns a daemon and also returns its registry handle, so tests can
    /// register and inspect sessions directly.
    async fn spawn_with_registry() -> (Self, atmd::registry::RegistryHandle) {
        Self::spawn_internal().await
    }

    /// Opens a new client connection to this daemon (no handshake performed).
    ///
    /// # Panics
    /// If the socket cannot be connected to.
    async fn connect(&self) -> TestClient {
        let stream = UnixStream::connect(&self.socket_path)
            .await
            .expect("connect to server");
        TestClient::new(stream)
    }

    /// Cancels the daemon and waits for its task to finish, giving up after
    /// `SHUTDOWN_GRACE_PERIOD` so a stuck server cannot hang the test.
    ///
    /// Consumes the server; the temp dir is removed when this returns.
    async fn shutdown(self) {
        self.cancel_token.cancel();
        // Joining the task (instead of sleeping a fixed period) returns as soon
        // as the daemon has finished, and guarantees its cleanup runs before
        // the temp dir is deleted. Timeouts and join errors are deliberately
        // ignored: shutdown stays best-effort.
        let _ = tokio::time::timeout(SHUTDOWN_GRACE_PERIOD, self.server_task).await;
    }
}
struct TestClient {
reader: BufReader<tokio::net::unix::OwnedReadHalf>,
writer: tokio::net::unix::OwnedWriteHalf,
}
impl TestClient {
fn new(stream: UnixStream) -> Self {
let (reader, writer) = stream.into_split();
Self {
reader: BufReader::new(reader),
writer,
}
}
async fn send(&mut self, msg: ClientMessage) {
let json = serde_json::to_string(&msg).unwrap();
self.writer.write_all(json.as_bytes()).await.unwrap();
self.writer.write_all(b"\n").await.unwrap();
self.writer.flush().await.unwrap();
}
async fn recv(&mut self) -> DaemonMessage {
let mut line = String::new();
self.reader.read_line(&mut line).await.unwrap();
serde_json::from_str(&line).unwrap()
}
async fn handshake(&mut self, client_id: Option<String>) -> String {
self.send(ClientMessage::connect(client_id)).await;
match self.recv().await {
DaemonMessage::Connected { client_id, .. } => client_id,
other => panic!("Expected Connected, got {other:?}"),
}
}
async fn handshake_with_version(&mut self, version: ProtocolVersion) -> DaemonMessage {
let msg = ClientMessage {
protocol_version: version,
message: MessageType::Connect { client_id: None },
};
self.send(msg).await;
self.recv().await
}
}
/// Builds a `GeneralPurpose` / `Sonnet4` session domain object with the given ID,
/// ready to be registered in a test registry.
fn create_test_session(id: &str) -> SessionDomain {
    let session_id = SessionId::new(id);
    SessionDomain::new(session_id, AgentType::GeneralPurpose, Model::Sonnet4)
}
/// Smoke test: a freshly spawned daemon accepts a plain socket connection.
#[tokio::test]
async fn test_server_accepts_connection() {
    let daemon = TestServer::spawn().await;
    // A named binding (not `_`) keeps the connection open until shutdown.
    let _conn = daemon.connect().await;
    daemon.shutdown().await;
}
/// A client that asks for an explicit ID gets it echoed back along with the
/// daemon's current protocol version.
#[tokio::test]
async fn test_handshake_success() {
    let daemon = TestServer::spawn().await;
    let mut conn = daemon.connect().await;

    let hello = ClientMessage::connect(Some("test-client".to_string()));
    conn.send(hello).await;

    let reply = conn.recv().await;
    if let DaemonMessage::Connected {
        protocol_version,
        client_id,
    } = reply
    {
        assert_eq!(protocol_version, ProtocolVersion::CURRENT);
        assert_eq!(client_id, "test-client");
    } else {
        let other = reply;
        panic!("Expected Connected, got {other:?}");
    }

    daemon.shutdown().await;
}
/// Connecting without an ID makes the daemon assign one prefixed with `client-`.
#[tokio::test]
async fn test_handshake_auto_assigns_client_id() {
    let daemon = TestServer::spawn().await;
    let mut conn = daemon.connect().await;

    conn.send(ClientMessage::connect(None)).await;

    let reply = conn.recv().await;
    match reply {
        DaemonMessage::Connected { client_id, .. } => assert!(
            client_id.starts_with("client-"),
            "Expected auto-assigned ID starting with 'client-', got: {client_id}"
        ),
        other => panic!("Expected Connected, got {other:?}"),
    }

    daemon.shutdown().await;
}
/// A Connect carrying an incompatible protocol version is rejected, and the
/// rejection reason mentions the incompatibility.
#[tokio::test]
async fn test_handshake_version_mismatch() {
    let daemon = TestServer::spawn().await;
    let mut conn = daemon.connect().await;

    let future_version = ProtocolVersion::new(99, 0);
    let response = conn.handshake_with_version(future_version).await;

    if let DaemonMessage::Rejected { reason, .. } = response {
        assert!(
            reason.contains("not compatible"),
            "Expected 'not compatible' in reason, got: {reason}"
        );
    } else {
        let other = response;
        panic!("Expected Rejected, got {other:?}");
    }

    daemon.shutdown().await;
}
/// Subscribe answers with an (empty) snapshot. After Unsubscribe the
/// connection still serves ordinary commands.
#[tokio::test]
async fn test_subscribe_unsubscribe_flow() {
    let daemon = TestServer::spawn().await;
    let mut conn = daemon.connect().await;
    conn.handshake(Some("sub-client".to_string())).await;

    conn.send(ClientMessage::subscribe(None)).await;
    let snapshot = conn.recv().await;
    match snapshot {
        DaemonMessage::SessionList { sessions } => {
            assert_eq!(sessions.len(), 0, "Initial session list should be empty");
        }
        other => panic!("Expected SessionList, got {other:?}"),
    }

    // The test assumes Unsubscribe itself produces no reply: the next message
    // read must be the answer to the list request.
    let unsubscribe = ClientMessage::new(MessageType::Unsubscribe);
    conn.send(unsubscribe).await;
    conn.send(ClientMessage::list_sessions()).await;

    let listing = conn.recv().await;
    match listing {
        DaemonMessage::SessionList { .. } => {}
        other => panic!("Expected SessionList after unsubscribe, got {other:?}"),
    }

    daemon.shutdown().await;
}
/// Subscribing to one specific session is accepted and answered with a session list.
#[tokio::test]
async fn test_subscribe_with_session_filter() {
    let daemon = TestServer::spawn().await;
    let mut conn = daemon.connect().await;
    conn.handshake(None).await;

    let filter = Some(SessionId::new("specific-session"));
    conn.send(ClientMessage::subscribe(filter)).await;

    let reply = conn.recv().await;
    match reply {
        DaemonMessage::SessionList { .. } => {}
        other => panic!("Expected SessionList, got {other:?}"),
    }

    daemon.shutdown().await;
}
/// Two subscribers exist: one filtered to the `target` session and one
/// unfiltered. Then a matching and a non-matching session are registered.
///
/// NOTE(review): despite the name, only the initial snapshots and the final
/// `list_sessions` count are asserted. The test never checks which broadcast
/// updates each subscriber receives. TODO: assert filtered delivery.
#[tokio::test]
async fn test_broadcast_respects_filter() {
    let (daemon, registry) = TestServer::spawn_with_registry().await;

    // Subscriber filtered to the "target" session only.
    let mut filtered = daemon.connect().await;
    filtered.handshake(Some("client-1".to_string())).await;
    let target_filter = Some(SessionId::new("target"));
    filtered.send(ClientMessage::subscribe(target_filter)).await;
    let filtered_snapshot = filtered.recv().await;
    match filtered_snapshot {
        DaemonMessage::SessionList { sessions } => {
            assert_eq!(sessions.len(), 0, "Initial session list should be empty");
        }
        other => panic!("Expected SessionList, got {other:?}"),
    }

    // Subscriber with no filter.
    let mut unfiltered = daemon.connect().await;
    unfiltered.handshake(Some("client-2".to_string())).await;
    unfiltered.send(ClientMessage::subscribe(None)).await;
    let unfiltered_snapshot = unfiltered.recv().await;
    match unfiltered_snapshot {
        DaemonMessage::SessionList { sessions } => {
            assert_eq!(sessions.len(), 0, "Initial session list should be empty");
        }
        other => panic!("Expected SessionList, got {other:?}"),
    }

    let target = create_test_session("target");
    registry.register(target).await.expect("register target session");
    let unrelated = create_test_session("other");
    registry.register(unrelated).await.expect("register other session");

    filtered.send(ClientMessage::list_sessions()).await;
    let listing = filtered.recv().await;
    match listing {
        DaemonMessage::SessionList { sessions } => {
            assert_eq!(sessions.len(), 2, "Should have 2 sessions registered");
        }
        other => panic!("Expected SessionList with 2 sessions, got {other:?}"),
    }

    daemon.shutdown().await;
}
/// Number of simultaneous subscribed clients the limit test opens.
/// NOTE(review): presumably mirrors the daemon's client limit — confirm it stays in sync.
const MAX_TUI_CLIENTS: usize = 10;

/// Opens `MAX_TUI_CLIENTS` subscribed clients and checks that each one still
/// answers a ping with its own sequence number.
///
/// NOTE(review): no client beyond the limit is attempted, so rejection itself
/// is not exercised here despite the test name.
#[tokio::test]
async fn test_max_clients_rejection() {
    let daemon = TestServer::spawn().await;

    let mut clients = Vec::with_capacity(MAX_TUI_CLIENTS);
    for i in 0..MAX_TUI_CLIENTS {
        let mut conn = daemon.connect().await;
        conn.handshake(Some(format!("client-{i}"))).await;
        conn.send(ClientMessage::subscribe(None)).await;
        // Drop the snapshot that answers the subscribe.
        let _ = conn.recv().await;
        clients.push(conn);
    }

    for (i, client) in clients.iter_mut().enumerate() {
        let expected = i as u64;
        client.send(ClientMessage::ping(expected)).await;
        match client.recv().await {
            DaemonMessage::Pong { seq } => assert_eq!(seq, expected),
            other => panic!("Expected Pong for client {i}, got {other:?}"),
        }
    }

    daemon.shutdown().await;
}
/// Cancelling the daemon's token stops it, and the daemon removes its socket file.
#[tokio::test]
async fn test_graceful_shutdown() {
    let daemon = TestServer::spawn().await;
    let mut conn = daemon.connect().await;
    conn.handshake(None).await;

    // Cancel by hand instead of calling `shutdown()`. That method consumes the
    // server and its temp dir, which would remove the socket regardless of what
    // the daemon does.
    let socket_path = daemon.socket_path.clone();
    daemon.cancel_token.cancel();
    sleep(SHUTDOWN_GRACE_PERIOD).await;

    let socket_left_behind = socket_path.exists();
    assert!(
        !socket_left_behind,
        "Socket file should be removed after shutdown"
    );
}
/// An explicit Disconnect after the handshake is accepted without disturbing the daemon.
/// NOTE(review): nothing is asserted beyond the absence of a panic.
#[tokio::test]
async fn test_client_disconnect_message() {
    let daemon = TestServer::spawn().await;
    let mut conn = daemon.connect().await;
    conn.handshake(None).await;

    conn.send(ClientMessage::disconnect()).await;
    // Let the daemon process the disconnect before it is stopped.
    sleep(SHUTDOWN_GRACE_PERIOD).await;

    daemon.shutdown().await;
}
/// A Ping is answered with a Pong carrying the same sequence number.
#[tokio::test]
async fn test_ping_pong() {
    let daemon = TestServer::spawn().await;
    let mut conn = daemon.connect().await;
    conn.handshake(None).await;

    conn.send(ClientMessage::ping(42)).await;

    let reply = conn.recv().await;
    if let DaemonMessage::Pong { seq } = reply {
        assert_eq!(seq, 42, "Pong seq should match ping seq");
    } else {
        let other = reply;
        panic!("Expected Pong, got {other:?}");
    }

    daemon.shutdown().await;
}
/// Listing sessions on a fresh daemon returns an empty list.
#[tokio::test]
async fn test_list_sessions_command() {
    let daemon = TestServer::spawn().await;
    let mut conn = daemon.connect().await;
    conn.handshake(None).await;

    conn.send(ClientMessage::list_sessions()).await;

    let reply = conn.recv().await;
    match reply {
        DaemonMessage::SessionList { sessions } => assert_eq!(sessions.len(), 0),
        other => panic!("Expected SessionList, got {other:?}"),
    }

    daemon.shutdown().await;
}
#[tokio::test]
async fn test_wrong_message_before_handshake() {
let server = TestServer::spawn().await;
let mut client = server.connect().await;
client.send(ClientMessage::list_sessions()).await;
match client.recv().await {
DaemonMessage::Error { message, .. } => {
assert!(
message.contains("Expected Connect"),
"Error should mention expected Connect message, got: {message}"
);
}
other => panic!("Expected Error, got {other:?}"),
}
server.shutdown().await;
}
#[tokio::test]
async fn test_duplicate_connect_rejected() {
let server = TestServer::spawn().await;
let mut client = server.connect().await;
client.handshake(None).await;
client.send(ClientMessage::connect(None)).await;
match client.recv().await {
DaemonMessage::Error { message, .. } => {
assert!(
message.contains("Already connected"),
"Error should mention 'Already connected', got: {message}"
);
}
other => panic!("Expected Error, got {other:?}"),
}
server.shutdown().await;
}
/// Five clients, each in its own task, handshake and list sessions at the same
/// time. Each must receive exactly the client ID it requested.
#[tokio::test]
async fn test_multiple_clients_concurrent() {
    let daemon = TestServer::spawn().await;

    let handles: Vec<_> = (0..5)
        .map(|i| {
            let socket_path = daemon.socket_path.clone();
            tokio::spawn(async move {
                let stream = UnixStream::connect(&socket_path).await.unwrap();
                let mut conn = TestClient::new(stream);
                let wanted = format!("concurrent-{i}");
                let assigned = conn.handshake(Some(wanted.clone())).await;
                assert_eq!(assigned, wanted);
                conn.send(ClientMessage::list_sessions()).await;
                let _ = conn.recv().await;
            })
        })
        .collect();

    for handle in handles {
        handle.await.expect("concurrent client task should succeed");
    }

    daemon.shutdown().await;
}
/// Builds a hook-event JSON object containing `session_id` and
/// `hook_event_name`, plus every top-level field of `extras`.
///
/// A field in `extras` overwrites a base field with the same key. A
/// non-object `extras` is ignored.
fn hook_event_json(
    session_id: &str,
    event_name: &str,
    extras: serde_json::Value,
) -> serde_json::Value {
    let mut payload = serde_json::json!({
        "session_id": session_id,
        "hook_event_name": event_name,
    });
    if let (Some(fields), serde_json::Value::Object(extra_fields)) =
        (payload.as_object_mut(), extras)
    {
        // Owned pairs are moved in; no per-entry clone is needed.
        fields.extend(extra_fields);
    }
    payload
}
/// `PreToolUse` for an ordinary tool marks the session as working and records
/// the tool name as its activity.
#[tokio::test]
async fn test_e2e_pre_tool_use_translates_to_tool_call_start() {
    let (daemon, registry) = TestServer::spawn_with_registry().await;
    let mut conn = daemon.connect().await;
    conn.handshake(None).await;

    let session_id = SessionId::new("e2e-tool");
    let session = create_test_session(session_id.as_str());
    registry.register(session).await.expect("register session");

    let event = hook_event_json(
        session_id.as_str(),
        "PreToolUse",
        serde_json::json!({"tool_name": "Bash"}),
    );
    conn.send(ClientMessage::hook_event(event)).await;
    // No reply is read for hook events; give the daemon time to apply it.
    sleep(Duration::from_millis(50)).await;

    let view = registry
        .get_session(session_id.clone())
        .await
        .expect("session should still exist");
    assert_eq!(
        view.status_label, "working",
        "PreToolUse(Bash) should translate to ToolCallStart -> Working"
    );
    assert_eq!(view.activity_detail, Some("Bash".into()));

    daemon.shutdown().await;
}
/// `PreToolUse` for the interactive `AskUserQuestion` tool marks the session
/// as needing input.
#[tokio::test]
async fn test_e2e_pre_tool_use_interactive_translates_to_needs_input() {
    let (daemon, registry) = TestServer::spawn_with_registry().await;
    let mut conn = daemon.connect().await;
    conn.handshake(None).await;

    let session_id = SessionId::new("e2e-interactive");
    let session = create_test_session(session_id.as_str());
    registry.register(session).await.expect("register session");

    let event = hook_event_json(
        session_id.as_str(),
        "PreToolUse",
        serde_json::json!({"tool_name": "AskUserQuestion"}),
    );
    conn.send(ClientMessage::hook_event(event)).await;
    sleep(Duration::from_millis(50)).await;

    let view = registry
        .get_session(session_id.clone())
        .await
        .expect("session should still exist");
    assert_eq!(
        view.status_label, "needs input",
        "interactive tool should translate to NeedsInput -> AttentionNeeded"
    );
    assert_eq!(view.activity_detail, Some("AskUserQuestion".into()));

    daemon.shutdown().await;
}
/// A `Stop` hook that follows a tool start returns the session to idle.
#[tokio::test]
async fn test_e2e_stop_translates_to_idle() {
    let (daemon, registry) = TestServer::spawn_with_registry().await;
    let mut conn = daemon.connect().await;
    conn.handshake(None).await;

    let session_id = SessionId::new("e2e-stop");
    let session = create_test_session(session_id.as_str());
    registry.register(session).await.expect("register session");

    let settle = Duration::from_millis(50);

    // First put the session into a working state...
    let tool_start = hook_event_json(
        session_id.as_str(),
        "PreToolUse",
        serde_json::json!({"tool_name": "Bash"}),
    );
    conn.send(ClientMessage::hook_event(tool_start)).await;
    sleep(settle).await;

    // ...then end the turn.
    let stop = hook_event_json(session_id.as_str(), "Stop", serde_json::json!({}));
    conn.send(ClientMessage::hook_event(stop)).await;
    sleep(settle).await;

    let view = registry
        .get_session(session_id.clone())
        .await
        .expect("session should still exist after Stop");
    assert_eq!(
        view.status_label, "idle",
        "Stop should translate to WorkingEnd -> Idle"
    );

    daemon.shutdown().await;
}
/// A `SessionEnd` hook removes the session from the registry.
#[tokio::test]
async fn test_e2e_session_end_removes_session() {
    let (daemon, registry) = TestServer::spawn_with_registry().await;
    let mut conn = daemon.connect().await;
    conn.handshake(None).await;

    let session_id = SessionId::new("e2e-end");
    let session = create_test_session(session_id.as_str());
    registry.register(session).await.expect("register session");
    assert!(registry.get_session(session_id.clone()).await.is_some());

    let end_event = hook_event_json(
        session_id.as_str(),
        "SessionEnd",
        serde_json::json!({"reason": "clear"}),
    );
    conn.send(ClientMessage::hook_event(end_event)).await;
    sleep(Duration::from_millis(50)).await;

    let remaining = registry.get_session(session_id.clone()).await;
    assert!(
        remaining.is_none(),
        "SessionEnd lifecycle event should remove the session"
    );

    daemon.shutdown().await;
}
/// A `Notification` hook of type `permission_prompt` marks the session as needing input.
#[tokio::test]
async fn test_e2e_notification_permission_prompt_becomes_needs_input() {
    let (daemon, registry) = TestServer::spawn_with_registry().await;
    let mut conn = daemon.connect().await;
    conn.handshake(None).await;

    let session_id = SessionId::new("e2e-permission");
    let session = create_test_session(session_id.as_str());
    registry.register(session).await.expect("register session");

    let notification = hook_event_json(
        session_id.as_str(),
        "Notification",
        serde_json::json!({"notification_type": "permission_prompt"}),
    );
    conn.send(ClientMessage::hook_event(notification)).await;
    sleep(Duration::from_millis(50)).await;

    let view = registry
        .get_session(session_id.clone())
        .await
        .expect("session should still exist");
    assert_eq!(
        view.status_label, "needs input",
        "Notification(permission_prompt) should translate to NeedsInput"
    );

    daemon.shutdown().await;
}
/// Wraps a pi extension `payload` in an event envelope named `event`. The
/// envelope is always addressed to the fixed `e2e-pi` session used by the pi tests.
fn pi_event_json(event: &str, payload: serde_json::Value) -> serde_json::Value {
    let mut envelope = serde_json::Map::new();
    envelope.insert("event".to_owned(), serde_json::Value::from(event));
    envelope.insert("payload".to_owned(), payload);
    envelope.insert("session_id".to_owned(), serde_json::Value::from("e2e-pi"));
    serde_json::Value::Object(envelope)
}
/// A pi `tool_call` event marks the session as working on the named tool.
#[tokio::test]
async fn test_e2e_pi_tool_call_translates_to_tool_call_start() {
    let (daemon, registry) = TestServer::spawn_with_registry().await;
    let mut conn = daemon.connect().await;
    conn.handshake(None).await;

    let session_id = SessionId::new("e2e-pi");
    let session = create_test_session(session_id.as_str());
    registry.register(session).await.expect("register session");

    let tool_call = serde_json::json!({
        "type": "tool_call",
        "toolName": "Bash",
        "toolCallId": "toolu_pi_demo",
        "input": {"command": "ls"}
    });
    conn.send(ClientMessage::pi_event(pi_event_json("tool_call", tool_call)))
        .await;
    sleep(Duration::from_millis(50)).await;

    let view = registry
        .get_session(session_id.clone())
        .await
        .expect("session should exist");
    assert_eq!(view.status_label, "working");
    assert_eq!(view.activity_detail, Some("Bash".into()));

    daemon.shutdown().await;
}
/// A pi `tool_call` flagged with `needs_user_input` marks the session as
/// needing input and still records the tool name.
#[tokio::test]
async fn test_e2e_pi_permission_gate_becomes_needs_input() {
    let (daemon, registry) = TestServer::spawn_with_registry().await;
    let mut conn = daemon.connect().await;
    conn.handshake(None).await;

    let session_id = SessionId::new("e2e-pi");
    let session = create_test_session(session_id.as_str());
    registry.register(session).await.expect("register session");

    let gated_call = serde_json::json!({
        "type": "tool_call",
        "toolName": "Bash",
        "toolCallId": "toolu_dangerous",
        "needs_user_input": true
    });
    conn.send(ClientMessage::pi_event(pi_event_json("tool_call", gated_call)))
        .await;
    sleep(Duration::from_millis(50)).await;

    let view = registry
        .get_session(session_id.clone())
        .await
        .expect("session should exist");
    assert_eq!(view.status_label, "needs input");
    assert_eq!(view.activity_detail, Some("Bash".into()));

    daemon.shutdown().await;
}
/// An `agent_start` followed by an `agent_end` leaves the session idle.
#[tokio::test]
async fn test_e2e_pi_agent_end_translates_to_idle() {
    let (daemon, registry) = TestServer::spawn_with_registry().await;
    let mut conn = daemon.connect().await;
    conn.handshake(None).await;

    let session_id = SessionId::new("e2e-pi");
    let session = create_test_session(session_id.as_str());
    registry.register(session).await.expect("register session");

    let settle = Duration::from_millis(50);

    let start = pi_event_json("agent_start", serde_json::json!({"type":"agent_start"}));
    conn.send(ClientMessage::pi_event(start)).await;
    sleep(settle).await;

    let end = pi_event_json("agent_end", serde_json::json!({"type":"agent_end"}));
    conn.send(ClientMessage::pi_event(end)).await;
    sleep(settle).await;

    let view = registry
        .get_session(session_id.clone())
        .await
        .expect("session should still exist");
    assert_eq!(view.status_label, "idle");

    daemon.shutdown().await;
}
/// A pi `context` event with assistant usage data updates the session's cost.
/// The expected value equals the payload's `usage.cost.total`.
#[tokio::test]
async fn test_e2e_pi_context_event_drives_cost_and_token_display() {
    let (daemon, registry) = TestServer::spawn_with_registry().await;
    let mut conn = daemon.connect().await;
    conn.handshake(None).await;

    let session_id = SessionId::new("e2e-pi");
    let session = create_test_session(session_id.as_str());
    registry.register(session).await.expect("register session");

    let initial = registry
        .get_session(session_id.clone())
        .await
        .expect("session exists");
    assert!((initial.cost_usd - 0.0).abs() < 1e-9, "cost starts at $0");

    let context = serde_json::json!({
        "type": "context",
        "messages": [
            {"role": "user", "content": "hi"},
            {
                "role": "assistant",
                "usage": {
                    "input": 1088,
                    "output": 55,
                    "totalTokens": 1143,
                    "cost": {"input": 0.00544, "output": 0.00165, "total": 0.00709}
                }
            }
        ]
    });
    conn.send(ClientMessage::pi_event(pi_event_json("context", context)))
        .await;
    sleep(Duration::from_millis(50)).await;

    let updated = registry
        .get_session(session_id.clone())
        .await
        .expect("session exists");
    assert!(
        (updated.cost_usd - 0.00709).abs() < 1e-9,
        "expected cost 0.00709, got {}",
        updated.cost_usd
    );

    daemon.shutdown().await;
}
/// `atm_needs_input_open` puts the session into "needs input", and the
/// matching `atm_needs_input_resolved` moves it back to "working".
#[tokio::test]
async fn test_e2e_pi_atm_needs_input_open_drives_attention_needed() {
    let (daemon, registry) = TestServer::spawn_with_registry().await;
    let mut conn = daemon.connect().await;
    conn.handshake(None).await;

    let session_id = SessionId::new("e2e-pi");
    let session = create_test_session(session_id.as_str());
    registry.register(session).await.expect("register session");

    let settle = Duration::from_millis(50);

    let opened = pi_event_json(
        "atm_needs_input_open",
        serde_json::json!({"title": "Allow `rm -rf /tmp/junk`?"}),
    );
    conn.send(ClientMessage::pi_event(opened)).await;
    sleep(settle).await;

    let while_open = registry
        .get_session(session_id.clone())
        .await
        .expect("session exists");
    assert_eq!(while_open.status_label, "needs input");

    let resolved = pi_event_json("atm_needs_input_resolved", serde_json::json!({}));
    conn.send(ClientMessage::pi_event(resolved)).await;
    sleep(settle).await;

    let after_resolve = registry
        .get_session(session_id.clone())
        .await
        .expect("session exists");
    assert_eq!(after_resolve.status_label, "working");

    daemon.shutdown().await;
}
/// A pi `session_shutdown` event removes the session from the registry.
#[tokio::test]
async fn test_e2e_pi_session_shutdown_removes_session() {
    let (daemon, registry) = TestServer::spawn_with_registry().await;
    let mut conn = daemon.connect().await;
    conn.handshake(None).await;

    let session_id = SessionId::new("e2e-pi");
    let session = create_test_session(session_id.as_str());
    registry.register(session).await.expect("register session");

    let shutdown_event = pi_event_json(
        "session_shutdown",
        serde_json::json!({"type":"session_shutdown","reason":"quit"}),
    );
    conn.send(ClientMessage::pi_event(shutdown_event)).await;
    sleep(Duration::from_millis(50)).await;

    let remaining = registry.get_session(session_id.clone()).await;
    assert!(
        remaining.is_none(),
        "session_shutdown should remove the session"
    );

    daemon.shutdown().await;
}
/// A `tool_execution_start` event is expected to be ignored. It must neither
/// crash the daemon nor change the session's status.
#[tokio::test]
async fn test_e2e_pi_suppressed_event_does_not_panic_or_disrupt() {
    let (daemon, registry) = TestServer::spawn_with_registry().await;
    let mut conn = daemon.connect().await;
    conn.handshake(None).await;

    let session_id = SessionId::new("e2e-pi");
    let session = create_test_session(session_id.as_str());
    registry.register(session).await.expect("register session");

    let before = registry
        .get_session(session_id.clone())
        .await
        .expect("exists");

    let suppressed = serde_json::json!({
        "type":"tool_execution_start",
        "toolName":"ls",
        "toolCallId":"call_x"
    });
    conn.send(ClientMessage::pi_event(pi_event_json(
        "tool_execution_start",
        suppressed,
    )))
    .await;
    sleep(Duration::from_millis(50)).await;

    let after = registry
        .get_session(session_id.clone())
        .await
        .expect("still exists");
    assert_eq!(before.status_label, after.status_label);

    daemon.shutdown().await;
}
/// Three clients each send a ping before any pong is read. Every client must
/// then receive its own sequence number back.
#[tokio::test]
async fn test_concurrent_ping_pong() {
    let daemon = TestServer::spawn().await;

    let mut clients = Vec::new();
    for i in 0..3 {
        let mut conn = daemon.connect().await;
        conn.handshake(Some(format!("ping-client-{i}"))).await;
        clients.push(conn);
    }

    // Send all pings first so the requests are in flight simultaneously.
    for (i, client) in clients.iter_mut().enumerate() {
        client.send(ClientMessage::ping((i * 100) as u64)).await;
    }

    for (i, client) in clients.iter_mut().enumerate() {
        let expected = (i * 100) as u64;
        match client.recv().await {
            DaemonMessage::Pong { seq } => assert_eq!(seq, expected),
            other => panic!("Expected Pong for client {i}, got {other:?}"),
        }
    }

    daemon.shutdown().await;
}