use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::unix::{OwnedReadHalf, OwnedWriteHalf};
use tokio::sync::{Mutex, RwLock};
use tokio::time::timeout;
use tracing::{debug, error, info, warn};
use atm_claude_adapter::RawHookEvent;
use atm_core::SessionId;
use atm_pi_adapter::RawPiEvent;
use atm_protocol::{ClientMessage, DaemonMessage, MessageType, ProtocolVersion};
use crate::discovery::{DiscoveryResult, DiscoveryService};
use crate::registry::{RegistryHandle, SessionEvent};
/// A client that has asked to receive session updates.
pub struct Subscriber {
    /// Shared writer through which updates reach this client.
    pub writer: SubscriberWriter,
    /// Session named in the client's `Subscribe { session_id }` request;
    /// `None` when the client did not name one.
    pub filter: Option<SessionId>,
}

/// Buffered write half of a client's Unix-socket connection, behind an async
/// mutex so every holder of the shared handle writes one message at a time.
pub type SubscriberWriter = Arc<Mutex<BufWriter<OwnedWriteHalf>>>;

/// Shared table of active subscribers, keyed by client id.
pub type SubscribersMap = Arc<RwLock<HashMap<String, Subscriber>>>;
/// Identifier a client is known by once the handshake has completed.
type ClientId = String;

/// Maximum number of clients that may be subscribed at the same time.
const MAX_TUI_CLIENTS: usize = 10;

/// Largest accepted client message in bytes (1 MiB), trailing newline included.
const MAX_MESSAGE_SIZE: usize = 1024 * 1024;

/// How long a connection may sit idle waiting for the next message before
/// it is closed.
const READ_TIMEOUT: Duration = Duration::from_secs(5 * 60);

/// Upper bound on writing and flushing a single outgoing message.
const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
/// State for a single client connection to the daemon.
///
/// Built from the two halves of a Unix socket stream via
/// [`ConnectionHandler::new`] and consumed by [`ConnectionHandler::run`].
pub struct ConnectionHandler {
    /// Buffered read half; client messages are read as newline-terminated lines.
    reader: BufReader<OwnedReadHalf>,
    /// Buffered write half behind `Arc<Mutex<_>>`; the same handle is stored in
    /// `subscribers` when the client subscribes.
    writer: SubscriberWriter,
    /// Registry that receives status/hook/pi events and answers session queries.
    registry: RegistryHandle,
    /// Shared map of subscribed clients, keyed by client id.
    subscribers: SubscribersMap,
    /// Id assigned by the `Connect` handshake; `None` until it succeeds.
    client_id: Option<ClientId>,
    /// Whether this connection currently has an active subscription.
    subscribed: bool,
    /// Session filter from the last `Subscribe`; cleared on `Unsubscribe`.
    subscription_filter: Option<SessionId>,
    /// Number identifying this connection in logs; also used to build the
    /// default `client-<n>` id.
    connection_number: u64,
}
impl ConnectionHandler {
pub fn new(
reader: OwnedReadHalf,
writer: OwnedWriteHalf,
registry: RegistryHandle,
subscribers: SubscribersMap,
connection_number: u64,
) -> Self {
Self {
reader: BufReader::new(reader),
writer: Arc::new(Mutex::new(BufWriter::new(writer))),
registry,
subscribers,
client_id: None,
subscribed: false,
subscription_filter: None,
connection_number,
}
}
pub fn writer_handle(&self) -> SubscriberWriter {
Arc::clone(&self.writer)
}
pub async fn run(mut self) -> Option<ClientId> {
debug!(connection = self.connection_number, "New client connected");
match self.handle_handshake().await {
Ok(()) => {
debug!(
client_id = ?self.client_id,
"Client handshake completed"
);
}
Err(e) => {
warn!(
connection = self.connection_number,
error = %e,
"Handshake failed"
);
return None;
}
}
let client_id = self.client_id.clone();
if let Err(e) = self.process_messages().await {
debug!(
client_id = ?self.client_id,
error = %e,
"Connection closed"
);
}
debug!(client_id = ?self.client_id, "Client disconnected");
client_id
}
async fn handle_handshake(&mut self) -> Result<(), ConnectionError> {
let msg = self.read_message().await?;
let client_version = msg.protocol_version;
if !client_version.is_compatible_with(&ProtocolVersion::CURRENT) {
warn!(
client_version = %client_version,
server_version = %ProtocolVersion::CURRENT,
"Protocol version mismatch"
);
self.send_message(DaemonMessage::rejected(&format!(
"Protocol version {} not compatible with server version {}",
client_version,
ProtocolVersion::CURRENT
)))
.await?;
return Err(ConnectionError::VersionMismatch {
client: client_version,
server: ProtocolVersion::CURRENT,
});
}
match msg.message {
MessageType::Connect { client_id } => {
let assigned_id =
client_id.unwrap_or_else(|| format!("client-{}", self.connection_number));
self.client_id = Some(assigned_id.clone());
self.send_message(DaemonMessage::connected(assigned_id))
.await?;
Ok(())
}
other => {
self.send_message(DaemonMessage::error(
"Expected Connect message for handshake",
))
.await?;
Err(ConnectionError::UnexpectedMessage(format!("{other:?}")))
}
}
}
async fn process_messages(&mut self) -> Result<(), ConnectionError> {
loop {
let msg = match timeout(READ_TIMEOUT, self.read_message()).await {
Ok(Ok(msg)) => msg,
Ok(Err(ConnectionError::Eof)) => {
debug!(client_id = ?self.client_id, "Client sent EOF");
return Ok(());
}
Ok(Err(e)) => return Err(e),
Err(_) => {
debug!(client_id = ?self.client_id, "Connection timed out");
return Err(ConnectionError::Timeout);
}
};
if let Err(e) = self.handle_message(msg).await {
error!(
client_id = ?self.client_id,
error = %e,
"Error handling message"
);
let _ = self
.send_message(DaemonMessage::error(&e.to_string()))
.await;
}
}
}
async fn handle_message(&mut self, msg: ClientMessage) -> Result<(), ConnectionError> {
match msg.message {
MessageType::Connect { .. } => {
self.send_message(DaemonMessage::error("Already connected"))
.await?;
}
MessageType::StatusUpdate { data } => {
self.handle_status_update(data).await?;
}
MessageType::HookEvent { data } => {
self.handle_hook_event(data).await?;
}
MessageType::PiEvent { data } => {
self.handle_pi_event(data).await?;
}
MessageType::ListSessions => {
let sessions = self.registry.get_all_sessions().await;
self.send_message(DaemonMessage::session_list(sessions))
.await?;
}
MessageType::Subscribe { session_id } => {
let client_id = match &self.client_id {
Some(id) => id.clone(),
None => {
self.send_message(DaemonMessage::error("Must connect before subscribing"))
.await?;
return Ok(());
}
};
{
let mut subs = self.subscribers.write().await;
if subs.len() >= MAX_TUI_CLIENTS && !subs.contains_key(&client_id) {
self.send_message(DaemonMessage::error(&format!(
"Too many subscribers (max: {MAX_TUI_CLIENTS})"
)))
.await?;
return Ok(());
}
subs.insert(
client_id.clone(),
Subscriber {
writer: Arc::clone(&self.writer),
filter: session_id.clone(),
},
);
}
self.subscribed = true;
self.subscription_filter = session_id;
debug!(
client_id = %client_id,
filter = ?self.subscription_filter,
"Client subscribed to updates"
);
let sessions = self.registry.get_all_sessions().await;
self.send_message(DaemonMessage::session_list(sessions))
.await?;
}
MessageType::Unsubscribe => {
if let Some(ref client_id) = self.client_id {
let mut subs = self.subscribers.write().await;
subs.remove(client_id);
}
self.subscribed = false;
self.subscription_filter = None;
debug!(
client_id = ?self.client_id,
"Client unsubscribed from updates"
);
}
MessageType::Ping { seq } => {
self.send_message(DaemonMessage::pong(seq)).await?;
}
MessageType::Discover => {
debug!(client_id = ?self.client_id, "Client requested discovery");
let result = self.handle_discover().await;
self.send_message(DaemonMessage::discovery_complete(
result.discovered,
result.failed,
))
.await?;
}
MessageType::Disconnect => {
debug!(client_id = ?self.client_id, "Client requested disconnect");
return Err(ConnectionError::Eof);
}
}
Ok(())
}
async fn handle_status_update(
&mut self,
data: serde_json::Value,
) -> Result<(), ConnectionError> {
let session_id = data
.get("session_id")
.and_then(|v| v.as_str())
.map(SessionId::new)
.ok_or_else(|| ConnectionError::ParseError("Missing session_id".to_string()))?;
self.registry
.update_from_status_line(session_id, data)
.await
.map_err(|e| ConnectionError::RegistryError(e.to_string()))?;
Ok(())
}
async fn handle_hook_event(&mut self, data: serde_json::Value) -> Result<(), ConnectionError> {
debug!(client_id = ?self.client_id, "Received hook event data");
let raw_event: RawHookEvent =
serde_json::from_value(data).map_err(|e| ConnectionError::ParseError(e.to_string()))?;
debug!(
session_id = %raw_event.session_id(),
event_type = ?raw_event.event_type(),
pid = ?raw_event.pid,
tmux_pane = ?raw_event.tmux_pane,
"Processing hook event"
);
let lifecycle = match raw_event.to_lifecycle_event() {
Some(le) => le,
None => {
debug!(
hook_event_name = %raw_event.hook_event_name,
event_type = ?raw_event.event_type(),
tool_name = ?raw_event.tool_name,
"hook event suppressed by adapter"
);
return Ok(());
}
};
let session_id = raw_event.session_id();
let pid = raw_event.pid;
let tmux_pane = raw_event.tmux_pane.clone();
self.registry
.apply_lifecycle_event(
session_id,
lifecycle,
atm_core::Harness::ClaudeCode,
pid,
tmux_pane,
)
.await
.map_err(|e| ConnectionError::RegistryError(e.to_string()))?;
Ok(())
}
async fn handle_pi_event(&mut self, data: serde_json::Value) -> Result<(), ConnectionError> {
debug!(client_id = ?self.client_id, "Received pi event data");
let raw_event: RawPiEvent =
serde_json::from_value(data).map_err(|e| ConnectionError::ParseError(e.to_string()))?;
debug!(
session_id = ?raw_event.session_id,
event = %raw_event.event,
pid = ?raw_event.pid,
tmux_pane = ?raw_event.tmux_pane,
"Processing pi event"
);
let lifecycle = match raw_event.to_lifecycle_event() {
Some(le) => le,
None => {
debug!(event = %raw_event.event, "pi event suppressed by adapter");
return Ok(());
}
};
let session_id = match raw_event.session_id.as_deref() {
Some(s) => atm_core::SessionId::new(s),
None => raw_event
.pid
.map(atm_core::SessionId::pending_from_pid)
.ok_or_else(|| {
ConnectionError::ParseError(
"pi event missing both session_id and pid; cannot attribute".to_string(),
)
})?,
};
self.registry
.apply_lifecycle_event(
session_id,
lifecycle,
atm_core::Harness::Pi,
raw_event.pid,
raw_event.tmux_pane,
)
.await
.map_err(|e| ConnectionError::RegistryError(e.to_string()))?;
Ok(())
}
async fn handle_discover(&mut self) -> DiscoveryResult {
info!(client_id = ?self.client_id, "Processing discovery request");
let discovery = DiscoveryService::new(self.registry.clone());
discovery.discover().await
}
async fn read_message(&mut self) -> Result<ClientMessage, ConnectionError> {
let mut line = String::new();
let bytes_read = self
.reader
.read_line(&mut line)
.await
.map_err(|e| ConnectionError::Io(e.to_string()))?;
if bytes_read == 0 {
return Err(ConnectionError::Eof);
}
if line.len() > MAX_MESSAGE_SIZE {
return Err(ConnectionError::MessageTooLarge {
size: line.len(),
max: MAX_MESSAGE_SIZE,
});
}
let msg: ClientMessage =
serde_json::from_str(&line).map_err(|e| ConnectionError::ParseError(e.to_string()))?;
debug!(
client_id = ?self.client_id,
message_type = ?std::mem::discriminant(&msg.message),
"Received message"
);
Ok(msg)
}
async fn send_message(&self, msg: DaemonMessage) -> Result<(), ConnectionError> {
let json =
serde_json::to_string(&msg).map_err(|e| ConnectionError::ParseError(e.to_string()))?;
let mut writer = self.writer.lock().await;
match timeout(WRITE_TIMEOUT, async {
writer.write_all(json.as_bytes()).await?;
writer.write_all(b"\n").await?;
writer.flush().await?;
Ok::<(), std::io::Error>(())
})
.await
{
Ok(Ok(())) => Ok(()),
Ok(Err(e)) => Err(ConnectionError::Io(e.to_string())),
Err(_) => Err(ConnectionError::WriteTimeout),
}
}
pub fn is_subscribed(&self) -> bool {
self.subscribed
}
pub fn should_receive_event(&self, session_id: &SessionId) -> bool {
if !self.subscribed {
return false;
}
match &self.subscription_filter {
Some(filter) => filter == session_id,
None => true, }
}
pub fn client_id(&self) -> Option<&str> {
self.client_id.as_deref()
}
}
/// Reasons a client connection fails or a single request cannot be handled.
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
    // Handshake failures.
    /// The client's protocol version is not compatible with the daemon's.
    #[error("Protocol version mismatch: client {client}, server {server}")]
    VersionMismatch {
        /// Version announced by the client.
        client: ProtocolVersion,
        /// Version this daemon speaks.
        server: ProtocolVersion,
    },
    /// The handshake received something other than `Connect`; holds the
    /// message's debug rendering.
    #[error("Unexpected message: {0}")]
    UnexpectedMessage(String),

    // Transport failures.
    /// Reading from or writing to the socket failed; holds the I/O error text.
    #[error("I/O error: {0}")]
    Io(String),
    /// The client closed the stream, or asked to `Disconnect`.
    #[error("Connection closed")]
    Eof,
    /// No message arrived within the read timeout.
    #[error("Read timeout")]
    Timeout,
    /// An outgoing message could not be written and flushed in time.
    #[error("Write timeout")]
    WriteTimeout,
    /// A received line exceeded the maximum message size.
    #[error("Message too large: {size} bytes (max: {max})")]
    MessageTooLarge {
        /// Bytes read for the offending line.
        size: usize,
        /// The configured limit.
        max: usize,
    },

    // Content failures.
    /// JSON (de)serialization failed or a required field was missing.
    #[error("Parse error: {0}")]
    ParseError(String),
    /// The session registry rejected an update; holds its error text.
    #[error("Registry error: {0}")]
    RegistryError(String),
}
/// Pushes one registry [`SessionEvent`] to a subscriber as a single
/// newline-terminated JSON line.
///
/// `Registered` events are skipped: nothing is written and `Ok(())` is
/// returned. `Updated` is sent as `session_updated` carrying a clone of the
/// session, and `Removed` is sent as `session_removed`.
///
/// # Errors
/// - [`ConnectionError::ParseError`] if the message cannot be serialized.
/// - [`ConnectionError::Io`] if writing or flushing fails.
/// - [`ConnectionError::WriteTimeout`] if the write and flush together take
///   longer than [`WRITE_TIMEOUT`].
// NOTE(review): nothing in this module calls this; `dead_code` is allowed
// presumably because the only callers are elsewhere or conditional — confirm.
#[allow(dead_code)]
pub async fn send_event(
    writer: &Arc<Mutex<BufWriter<OwnedWriteHalf>>>,
    event: &SessionEvent,
) -> Result<(), ConnectionError> {
    let message = match event {
        // Registrations are not forwarded to subscribers.
        SessionEvent::Registered { .. } => return Ok(()),
        SessionEvent::Updated { session } => {
            let snapshot = (**session).clone();
            DaemonMessage::session_updated(snapshot)
        }
        SessionEvent::Removed { session_id, .. } => {
            DaemonMessage::session_removed(session_id.clone())
        }
    };

    let payload = serde_json::to_string(&message)
        .map_err(|err| ConnectionError::ParseError(err.to_string()))?;

    let mut guard = writer.lock().await;
    let write_line = async {
        guard.write_all(payload.as_bytes()).await?;
        guard.write_all(b"\n").await?;
        guard.flush().await?;
        Ok::<(), std::io::Error>(())
    };

    timeout(WRITE_TIMEOUT, write_line)
        .await
        .map_err(|_| ConnectionError::WriteTimeout)?
        .map_err(|err| ConnectionError::Io(err.to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The mismatch message must name both the client and the server version.
    #[test]
    fn test_connection_error_display() {
        let client = ProtocolVersion::new(2, 0);
        let server = ProtocolVersion::new(1, 0);
        let rendered = ConnectionError::VersionMismatch { client, server }.to_string();

        for expected in ["2.0", "1.0"] {
            assert!(
                rendered.contains(expected),
                "missing {expected:?} in {rendered:?}"
            );
        }
    }

    /// The size-limit message must include the offending size.
    #[test]
    fn test_message_size_error() {
        let rendered = ConnectionError::MessageTooLarge {
            size: 2_000_000,
            max: MAX_MESSAGE_SIZE,
        }
        .to_string();
        assert!(rendered.contains("2000000"), "unexpected message: {rendered}");
    }
}