use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::io::AsyncWrite;
use tokio::sync::Mutex;
use agent_team_mail_core::event_log::{EventFields, emit_event_best_effort};
pub fn queue_dir(team: &str, agent_id: &str) -> anyhow::Result<PathBuf> {
let home = agent_team_mail_core::home::get_home_dir()?;
Ok(home
.join(".config/atm/agent-sessions")
.join(team)
.join(agent_id)
.join("stdin_queue"))
}
pub async fn enqueue(team: &str, agent_id: &str, content: &str) -> anyhow::Result<()> {
let dir = queue_dir(team, agent_id)?;
tokio::fs::create_dir_all(&dir).await?;
let id = uuid::Uuid::new_v4();
let path = dir.join(format!("{id}.json"));
tokio::fs::write(&path, content.as_bytes()).await?;
emit_event_best_effort(EventFields {
level: "info",
source: "atm-agent-mcp",
action: "stdin_queue_enqueue",
team: Some(team.to_string()),
result: Some(format!("agent:{agent_id}")),
..Default::default()
});
Ok(())
}
pub async fn drain(
team: &str,
agent_id: &str,
stdin: &Arc<Mutex<Box<dyn AsyncWrite + Send + Unpin>>>,
ttl: Duration,
) -> anyhow::Result<usize> {
let dir = queue_dir(team, agent_id)?;
if !dir.exists() {
return Ok(0);
}
let _ = cleanup_ttl(&dir, ttl).await;
let mut entries = match tokio::fs::read_dir(&dir).await {
Ok(e) => e,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(0),
Err(e) => return Err(e.into()),
};
let mut drained = 0usize;
let mut json_files = Vec::new();
while let Ok(Some(entry)) = entries.next_entry().await {
let path = entry.path();
if path.extension().and_then(|e| e.to_str()) == Some("json") {
json_files.push(path);
}
}
for path in json_files {
let lock_path = path.with_extension("lock");
let claim_result = tokio::task::spawn_blocking({
let lock_path = lock_path.clone();
move || {
std::fs::OpenOptions::new()
.write(true)
.create_new(true)
.open(&lock_path)
}
})
.await;
match claim_result {
Ok(Ok(_file)) => {
}
Ok(Err(e)) if e.kind() == std::io::ErrorKind::AlreadyExists => {
continue;
}
Ok(Err(e)) => {
tracing::warn!(
path = %lock_path.display(),
error = %e,
"unexpected error creating stdin queue lock file; skipping"
);
continue;
}
Err(join_err) => {
tracing::warn!(
path = %lock_path.display(),
error = %join_err,
"spawn_blocking panicked creating stdin queue lock; skipping"
);
continue;
}
}
let content = match tokio::fs::read_to_string(&path).await {
Ok(c) => c,
Err(e) => {
tracing::warn!(
path = %path.display(),
error = %e,
"failed to read stdin queue file; releasing lock"
);
let _ = tokio::fs::remove_file(&lock_path).await;
continue;
}
};
let write_result = {
let mut guard = stdin.lock().await;
crate::framing::write_newline_delimited(&mut **guard, content.trim()).await
};
match write_result {
Ok(()) => {
let _ = tokio::fs::remove_file(&path).await;
let _ = tokio::fs::remove_file(&lock_path).await;
drained += 1;
}
Err(e) => {
tracing::warn!(
path = %path.display(),
error = %e,
"stdin queue write failed; releasing lock for retry"
);
let _ = tokio::fs::remove_file(&lock_path).await;
}
}
}
emit_event_best_effort(EventFields {
level: "info",
source: "atm-agent-mcp",
action: "stdin_queue_drain",
team: Some(team.to_string()),
result: Some(format!("drained:{drained}")),
..Default::default()
});
Ok(drained)
}
async fn cleanup_ttl(dir: &Path, ttl: Duration) -> anyhow::Result<usize> {
let cutoff = SystemTime::now()
.checked_sub(ttl)
.unwrap_or(SystemTime::UNIX_EPOCH);
let mut entries = match tokio::fs::read_dir(dir).await {
Ok(e) => e,
Err(_) => return Ok(0),
};
let mut removed = 0usize;
while let Ok(Some(entry)) = entries.next_entry().await {
let path = entry.path();
let ext = path.extension().and_then(|e| e.to_str());
if ext != Some("json") && ext != Some("lock") {
continue;
}
let metadata = match tokio::fs::metadata(&path).await {
Ok(m) => m,
Err(_) => continue,
};
let mtime = match metadata.modified() {
Ok(t) => t,
Err(_) => continue,
};
if mtime < cutoff && tokio::fs::remove_file(&path).await.is_ok() {
removed += 1;
}
}
Ok(removed)
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use tokio::sync::Mutex;
fn setup_env(tmp: &tempfile::TempDir) -> (String, String) {
let team = "test-team";
let agent_id = "test-agent";
unsafe {
std::env::set_var("ATM_HOME", tmp.path());
}
(team.to_string(), agent_id.to_string())
}
#[tokio::test]
#[serial_test::serial]
async fn enqueue_creates_file() {
let tmp = tempfile::TempDir::new().unwrap();
let (team, agent_id) = setup_env(&tmp);
enqueue(&team, &agent_id, r#"{"hello":"world"}"#)
.await
.unwrap();
let dir = queue_dir(&team, &agent_id).unwrap();
let mut entries = tokio::fs::read_dir(&dir).await.unwrap();
let mut count = 0;
while let Ok(Some(entry)) = entries.next_entry().await {
assert_eq!(
entry.path().extension().and_then(|e| e.to_str()),
Some("json")
);
count += 1;
}
assert_eq!(count, 1);
}
struct SharedCapWriter(Arc<std::sync::Mutex<Vec<u8>>>);
impl SharedCapWriter {
fn new() -> (Self, Arc<std::sync::Mutex<Vec<u8>>>) {
let buf = Arc::new(std::sync::Mutex::new(Vec::new()));
(Self(Arc::clone(&buf)), buf)
}
}
impl AsyncWrite for SharedCapWriter {
fn poll_write(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
self.0.lock().unwrap().extend_from_slice(buf);
std::task::Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::task::Poll::Ready(Ok(()))
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::task::Poll::Ready(Ok(()))
}
}
#[tokio::test]
#[serial_test::serial]
async fn enqueue_drain_roundtrip() {
let tmp = tempfile::TempDir::new().unwrap();
let (team, agent_id) = setup_env(&tmp);
let msg = r#"{"type":"tool_result","data":"hello"}"#;
enqueue(&team, &agent_id, msg).await.unwrap();
enqueue(&team, &agent_id, msg).await.unwrap();
let (writer, captured) = SharedCapWriter::new();
let stdin: Arc<Mutex<Box<dyn AsyncWrite + Send + Unpin>>> =
Arc::new(Mutex::new(Box::new(writer)));
let count = drain(&team, &agent_id, &stdin, Duration::from_secs(600))
.await
.unwrap();
assert_eq!(count, 2);
let output = captured.lock().unwrap().clone();
let text = String::from_utf8_lossy(&output);
let lines: Vec<&str> = text.lines().collect();
assert_eq!(lines.len(), 2);
for line in &lines {
assert!(line.contains("tool_result"));
}
let dir = queue_dir(&team, &agent_id).unwrap();
let mut entries = tokio::fs::read_dir(&dir).await.unwrap();
let mut remaining = 0;
while let Ok(Some(_)) = entries.next_entry().await {
remaining += 1;
}
assert_eq!(remaining, 0, "all queue files should be removed after drain");
}
#[tokio::test]
#[serial_test::serial]
async fn drain_empty_queue_returns_zero() {
let tmp = tempfile::TempDir::new().unwrap();
let (team, agent_id) = setup_env(&tmp);
let (writer, _captured) = SharedCapWriter::new();
let stdin: Arc<Mutex<Box<dyn AsyncWrite + Send + Unpin>>> =
Arc::new(Mutex::new(Box::new(writer)));
let count = drain(&team, &agent_id, &stdin, Duration::from_secs(600))
.await
.unwrap();
assert_eq!(count, 0);
}
#[tokio::test]
#[serial_test::serial]
async fn concurrent_drain_no_double_delivery() {
let tmp = tempfile::TempDir::new().unwrap();
let (team, agent_id) = setup_env(&tmp);
for i in 0..5 {
enqueue(&team, &agent_id, &format!(r#"{{"msg":{i}}}"#))
.await
.unwrap();
}
let (writer1, cap1) = SharedCapWriter::new();
let stdin1: Arc<Mutex<Box<dyn AsyncWrite + Send + Unpin>>> =
Arc::new(Mutex::new(Box::new(writer1)));
let (writer2, cap2) = SharedCapWriter::new();
let stdin2: Arc<Mutex<Box<dyn AsyncWrite + Send + Unpin>>> =
Arc::new(Mutex::new(Box::new(writer2)));
let team_a = team.clone();
let agent_a = agent_id.clone();
let stdin_a = Arc::clone(&stdin1);
let team_b = team.clone();
let agent_b = agent_id.clone();
let stdin_b = Arc::clone(&stdin2);
let (count_a, count_b) = tokio::join!(
drain(&team_a, &agent_a, &stdin_a, Duration::from_secs(600)),
drain(&team_b, &agent_b, &stdin_b, Duration::from_secs(600)),
);
let total = count_a.unwrap() + count_b.unwrap();
assert_eq!(total, 5, "total drained should be exactly 5 (no double delivery)");
let out1 = cap1.lock().unwrap().clone();
let out2 = cap2.lock().unwrap().clone();
let text1 = String::from_utf8_lossy(&out1);
let text2 = String::from_utf8_lossy(&out2);
let lines1: Vec<&str> = text1.lines().filter(|l| !l.is_empty()).collect();
let lines2: Vec<&str> = text2.lines().filter(|l| !l.is_empty()).collect();
assert_eq!(
lines1.len() + lines2.len(),
5,
"exactly 5 messages should be delivered across both drains"
);
}
#[tokio::test]
#[serial_test::serial]
async fn ttl_cleanup_removes_old_files() {
let tmp = tempfile::TempDir::new().unwrap();
let (team, agent_id) = setup_env(&tmp);
enqueue(&team, &agent_id, r#"{"old":"message"}"#)
.await
.unwrap();
let dir = queue_dir(&team, &agent_id).unwrap();
let removed = cleanup_ttl(&dir, Duration::from_secs(0)).await.unwrap();
assert_eq!(removed, 1, "file should be removed with 0-second TTL");
}
#[tokio::test]
#[serial_test::serial]
async fn ttl_cleanup_removes_stale_lock_files() {
let tmp = tempfile::TempDir::new().unwrap();
let (team, agent_id) = setup_env(&tmp);
let dir = queue_dir(&team, &agent_id).unwrap();
tokio::fs::create_dir_all(&dir).await.unwrap();
let lock_path = dir.join("00000000-0000-0000-0000-000000000001.lock");
std::fs::write(&lock_path, b"").unwrap();
let removed = cleanup_ttl(&dir, Duration::from_secs(0)).await.unwrap();
assert_eq!(removed, 1, "stale lock file should be removed with 0-second TTL");
}
}