use std::collections::HashMap;
use std::sync::Arc;
use car_ir::{ToolHandle, ToolStatus, ToolStreamChunk, ToolStreamEvent};
use serde::{Deserialize, Serialize};
use tokio::sync::{broadcast, Mutex};
use tokio_util::sync::CancellationToken;
/// Capacity passed to `broadcast::channel` when the registry creates its event channel.
const EVENT_CHANNEL_CAP: usize = 256;
/// Per-handle cap on buffered chunks. Once a handle's buffer holds this many chunks,
/// `push_chunk` evicts the oldest one and counts it in `dropped_chunks`.
const MAX_BUFFERED_CHUNKS: usize = 1024;
/// How long a sealed (terminal) entry may remain in the registry before
/// `reap_expired` removes it: 15 minutes.
const TERMINAL_TTL: std::time::Duration = std::time::Duration::from_secs(15 * 60);
/// Snapshot of one handle as returned by `ToolHandleRegistry::poll`.
///
/// The optional fields (`dropped_chunks`, `result`, `error`) are omitted from the
/// serialized form when they are zero / `None`, and fall back to their defaults
/// when absent during deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolPollResult {
/// Id of the handle that was polled.
pub handle: String,
/// Tool name recorded when the handle was registered.
pub tool: String,
/// Action id recorded when the handle was registered.
pub action_id: String,
/// Status of the handle at the moment of the poll.
pub status: ToolStatus,
/// Chunks buffered since the previous poll; polling drains the buffer.
pub chunks: Vec<ToolStreamChunk>,
/// Number of chunks evicted from the buffer since the previous poll
/// (the counter is reset by each poll). Skipped in serialized output when zero.
#[serde(default, skip_serializing_if = "is_zero")]
pub dropped_chunks: u64,
/// Result payload recorded from a `Done` chunk, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub result: Option<serde_json::Value>,
/// Error message recorded from an `Error` chunk or from a stream that closed
/// without a terminal chunk, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
}
/// Predicate used by `#[serde(skip_serializing_if = "is_zero")]`: true when the
/// counter is zero.
///
/// Takes `&u64` rather than `u64` because serde calls the predicate with a
/// reference to the field.
fn is_zero(n: &u64) -> bool {
    matches!(*n, 0)
}
/// Registry-internal state tracked for one registered tool handle.
struct HandleEntry {
/// Tool name supplied at registration.
tool: String,
/// Action id supplied at registration.
action_id: String,
/// Current status; starts as `ToolStatus::Running` and is never changed again
/// once it becomes terminal.
status: ToolStatus,
/// Chunks accepted since the last poll, oldest first.
buffered: Vec<ToolStreamChunk>,
/// Chunks evicted from `buffered` since the last poll because the buffer was full.
dropped_chunks: u64,
/// Result payload copied from a `Done` chunk.
result: Option<serde_json::Value>,
/// Error message copied from an `Error` chunk, or set when the stream closed
/// without a terminal chunk.
error: Option<String>,
/// Clone of the token handed out by `register`; fired by `cancel` / `cancel_all`.
cancel: CancellationToken,
/// Set by the first poll that observes a terminal status; the next poll
/// removes the entry.
drained_terminal: bool,
/// When the entry became terminal (`None` while still running). Used by
/// `reap_expired` to enforce `TERMINAL_TTL`.
sealed_at: Option<std::time::Instant>,
}
/// Registry of tool handles: tracks per-handle status and buffered stream chunks,
/// and re-broadcasts every accepted chunk to subscribers.
pub struct ToolHandleRegistry {
/// Handle id -> per-handle state, guarded by an async (`tokio::sync`) mutex.
entries: Mutex<HashMap<String, HandleEntry>>,
/// Broadcast sender on which `push_chunk` publishes each accepted chunk.
events: broadcast::Sender<ToolStreamEvent>,
}
impl Default for ToolHandleRegistry {
fn default() -> Self {
Self::new()
}
}
impl ToolHandleRegistry {
/// Creates an empty registry with a fresh broadcast channel of capacity
/// `EVENT_CHANNEL_CAP`.
///
/// The receiver produced alongside the sender is dropped immediately; callers
/// obtain receivers through `subscribe`.
pub fn new() -> Self {
    let (sender, _) = broadcast::channel(EVENT_CHANNEL_CAP);
    let table = HashMap::new();
    Self {
        entries: Mutex::new(table),
        events: sender,
    }
}
/// Returns a new receiver attached to the registry's event channel.
pub fn subscribe(&self) -> broadcast::Receiver<ToolStreamEvent> {
    let sender = &self.events;
    sender.subscribe()
}
/// Registers a new running handle for `tool` / `action_id`.
///
/// Returns the freshly minted handle (a random UUID in simple form) together
/// with the cancellation token that `cancel` / `cancel_all` will fire for it.
/// As a side effect, terminal entries older than `TERMINAL_TTL` are reaped.
pub async fn register(&self, tool: &str, action_id: &str) -> (ToolHandle, CancellationToken) {
    let token = CancellationToken::new();
    let handle_id = uuid::Uuid::new_v4().simple().to_string();
    let fresh = HandleEntry {
        tool: tool.to_owned(),
        action_id: action_id.to_owned(),
        status: ToolStatus::Running,
        buffered: Vec::new(),
        dropped_chunks: 0,
        result: None,
        error: None,
        cancel: token.clone(),
        drained_terminal: false,
        sealed_at: None,
    };
    {
        // Scope the guard so the lock is released before the return value is built.
        let mut table = self.entries.lock().await;
        Self::reap_expired(&mut table);
        table.insert(handle_id.clone(), fresh);
    }
    (ToolHandle::new(handle_id), token)
}
/// Drops every entry that was sealed at least `TERMINAL_TTL` ago.
///
/// Entries that have not been sealed (`sealed_at` is `None`) are always kept.
fn reap_expired(entries: &mut HashMap<String, HandleEntry>) {
    entries.retain(|_, entry| {
        entry
            .sealed_at
            .map_or(true, |sealed| sealed.elapsed() < TERMINAL_TTL)
    });
}
/// Records `chunk` for `handle_id` and broadcasts it to subscribers.
///
/// Chunks for unknown handles, or for handles that are already terminal, are
/// silently ignored. A `Done` chunk seals the handle as `Succeeded` (storing its
/// result); an `Error` chunk seals it as `Failed` (storing its message). When the
/// buffer already holds `MAX_BUFFERED_CHUNKS` chunks the oldest one is evicted and
/// counted in `dropped_chunks`.
///
/// The broadcast happens while the registry lock is held, and its result is
/// discarded.
pub async fn push_chunk(&self, handle_id: &str, chunk: ToolStreamChunk) {
    let mut table = self.entries.lock().await;
    let slot = match table.get_mut(handle_id) {
        Some(slot) if !slot.status.is_terminal() => slot,
        _ => return,
    };

    // Work out whether this chunk seals the handle, capturing its payload.
    let sealed_as = match &chunk {
        ToolStreamChunk::Done { result } => {
            slot.result = result.clone();
            Some(ToolStatus::Succeeded)
        }
        ToolStreamChunk::Error { message } => {
            slot.error = Some(message.clone());
            Some(ToolStatus::Failed)
        }
        _ => None,
    };
    if let Some(terminal) = sealed_as {
        slot.status = terminal;
        slot.sealed_at = Some(std::time::Instant::now());
    }

    // Evict the oldest chunk when full. `Vec::remove(0)` shifts the remaining
    // elements, a cost bounded by MAX_BUFFERED_CHUNKS.
    if slot.buffered.len() >= MAX_BUFFERED_CHUNKS {
        slot.buffered.remove(0);
        slot.dropped_chunks += 1;
    }
    slot.buffered.push(chunk.clone());

    let event = ToolStreamEvent {
        handle: ToolHandle::new(handle_id.to_string()),
        chunk,
    };
    let _ = self.events.send(event);
}
/// Seals `handle_id` as `Failed` if its stream ended before any terminal chunk.
///
/// Does nothing for unknown handles or handles that are already terminal. No
/// chunk is buffered and no event is broadcast for this transition.
pub async fn mark_stream_closed(&self, handle_id: &str) {
    let mut table = self.entries.lock().await;
    let Some(slot) = table.get_mut(handle_id) else {
        return;
    };
    if slot.status.is_terminal() {
        return;
    }
    slot.status = ToolStatus::Failed;
    slot.error = Some("tool stream closed without a terminal chunk".to_string());
    slot.sealed_at = Some(std::time::Instant::now());
}
/// Fires the cancellation token of `handle_id` and, if the handle is still
/// running, seals it as `Cancelled`.
///
/// Expired terminal entries are reaped first. Returns `false` when the handle is
/// not (or no longer) in the registry, `true` otherwise — including when the
/// handle was already terminal, in which case only the token is fired.
pub async fn cancel(&self, handle_id: &str) -> bool {
    let mut table = self.entries.lock().await;
    Self::reap_expired(&mut table);
    match table.get_mut(handle_id) {
        None => false,
        Some(slot) => {
            slot.cancel.cancel();
            if !slot.status.is_terminal() {
                slot.status = ToolStatus::Cancelled;
                slot.sealed_at = Some(std::time::Instant::now());
            }
            true
        }
    }
}
/// Cancels every handle that is not yet terminal and returns how many were cancelled.
///
/// Each affected handle has its token fired and is sealed as `Cancelled`;
/// handles that are already terminal are left untouched.
pub async fn cancel_all(&self) -> usize {
    let mut table = self.entries.lock().await;
    let mut cancelled = 0;
    table
        .values_mut()
        .filter(|slot| !slot.status.is_terminal())
        .for_each(|slot| {
            slot.cancel.cancel();
            slot.status = ToolStatus::Cancelled;
            slot.sealed_at = Some(std::time::Instant::now());
            cancelled += 1;
        });
    cancelled
}
/// Drains and returns everything buffered for `handle_id`, or `None` when the
/// handle is unknown.
///
/// Expired terminal entries are reaped first. The buffered chunks and the
/// dropped-chunk counter are moved into the result, leaving both empty/zero.
/// A terminal handle survives exactly one poll after sealing: the first poll that
/// sees the terminal status marks the entry as drained, and the next poll
/// returns a final snapshot and removes the entry.
pub async fn poll(&self, handle_id: &str) -> Option<ToolPollResult> {
    let mut table = self.entries.lock().await;
    Self::reap_expired(&mut table);
    let slot = table.get_mut(handle_id)?;

    let snapshot = ToolPollResult {
        handle: handle_id.to_owned(),
        tool: slot.tool.clone(),
        action_id: slot.action_id.clone(),
        status: slot.status,
        chunks: std::mem::take(&mut slot.buffered),
        dropped_chunks: std::mem::take(&mut slot.dropped_chunks),
        result: slot.result.clone(),
        error: slot.error.clone(),
    };

    match (slot.status.is_terminal(), slot.drained_terminal) {
        // Terminal state was already reported once: forget the handle.
        (true, true) => {
            table.remove(handle_id);
        }
        // First time the terminal state is reported: remember that.
        (true, false) => slot.drained_terminal = true,
        (false, _) => {}
    }
    Some(snapshot)
}
/// Returns the current status of `handle_id`, or `None` if it is not in the
/// registry. Does not drain chunks and does not reap expired entries.
pub async fn status(&self, handle_id: &str) -> Option<ToolStatus> {
    let table = self.entries.lock().await;
    let slot = table.get(handle_id)?;
    Some(slot.status)
}
/// Counts the handles whose status is not terminal.
pub async fn live_count(&self) -> usize {
    let table = self.entries.lock().await;
    let mut running = 0;
    for slot in table.values() {
        if !slot.status.is_terminal() {
            running += 1;
        }
    }
    running
}
}
/// Spawns a detached Tokio task that forwards chunks from `rx` into `registry`
/// under `handle_id`.
///
/// The task ends when one of the following happens:
/// * a terminal chunk is forwarded (it is pushed first, then the loop stops);
/// * `rx` is closed without a terminal chunk, in which case the handle is sealed
///   as failed via `mark_stream_closed`;
/// * `cancel` is fired, in which case the handle is sealed as cancelled.
///
/// The `JoinHandle` is discarded, so the task cannot be awaited by the caller.
///
/// # Panics
///
/// `tokio::spawn` panics when called outside of a Tokio runtime.
pub fn spawn_drain(
    registry: Arc<ToolHandleRegistry>,
    handle_id: String,
    mut rx: tokio::sync::mpsc::Receiver<ToolStreamChunk>,
    cancel: CancellationToken,
) {
    tokio::spawn(async move {
        loop {
            tokio::select! {
                _ = cancel.cancelled() => {
                    // The token is handed out by `register`, so it may have been
                    // fired by someone other than `ToolHandleRegistry::cancel`.
                    // Seal the entry here as well; otherwise it would stay
                    // `Running` forever and never become eligible for TTL
                    // reaping. This is a no-op when the registry already sealed
                    // or removed the handle.
                    registry.cancel(&handle_id).await;
                    break;
                }
                chunk = rx.recv() => match chunk {
                    Some(c) => {
                        // Read the flag before the chunk is moved into the registry.
                        let terminal = c.is_terminal();
                        registry.push_chunk(&handle_id, c).await;
                        if terminal {
                            break;
                        }
                    }
                    None => {
                        // Sender dropped without ever sending a terminal chunk.
                        registry.mark_stream_closed(&handle_id).await;
                        break;
                    }
                }
            }
        }
    });
}
#[cfg(test)]
mod tests {
use super::*;
// Two pushed chunks come back from the first poll; the second poll finds the
// buffer empty while the handle is still running.
#[tokio::test]
async fn chunks_buffer_and_drain_in_order() {
    let registry = ToolHandleRegistry::new();
    let (handle, _token) = registry.register("tail_log", "a1").await;
    for word in ["one", "two"] {
        registry
            .push_chunk(&handle.id, ToolStreamChunk::Text { text: word.into() })
            .await;
    }

    let first = registry.poll(&handle.id).await.expect("known handle");
    assert_eq!(first.status, ToolStatus::Running);
    assert_eq!(first.chunks.len(), 2);

    let second = registry.poll(&handle.id).await.expect("still live");
    assert!(second.chunks.is_empty());
}
// A `Done` chunk seals the handle as succeeded with its result; later chunks are
// ignored, and the handle disappears after the second terminal poll.
#[tokio::test]
async fn done_chunk_seals_success_with_result() {
    let registry = ToolHandleRegistry::new();
    let (handle, _token) = registry.register("build", "a2").await;
    let payload = serde_json::json!({"exit": 0});
    let done = ToolStreamChunk::Done {
        result: Some(payload.clone()),
    };
    registry.push_chunk(&handle.id, done).await;

    let sealed = registry.poll(&handle.id).await.unwrap();
    assert_eq!(sealed.status, ToolStatus::Succeeded);
    assert_eq!(sealed.result, Some(payload));

    // A chunk arriving after the terminal one must not be buffered.
    let late = ToolStreamChunk::Text { text: "late".into() };
    registry.push_chunk(&handle.id, late).await;
    let after = registry.poll(&handle.id).await.unwrap();
    assert!(after.chunks.is_empty());

    assert!(registry.poll(&handle.id).await.is_none());
}
// Cancelling fires the token and seals the handle; a later `Done` chunk cannot
// override the cancelled status, and unknown handles report `false`.
#[tokio::test]
async fn cancel_seals_cancelled_and_fires_token() {
    let registry = ToolHandleRegistry::new();
    let (handle, token) = registry.register("watch", "a3").await;

    let was_known = registry.cancel(&handle.id).await;
    assert!(was_known);
    assert!(token.is_cancelled());
    assert_eq!(
        registry.status(&handle.id).await,
        Some(ToolStatus::Cancelled)
    );

    let done = ToolStreamChunk::Done { result: None };
    registry.push_chunk(&handle.id, done).await;
    assert_eq!(
        registry.status(&handle.id).await,
        Some(ToolStatus::Cancelled)
    );

    let unknown = registry.cancel("nope").await;
    assert!(!unknown);
}
// The drain task forwards a progress chunk and the terminal `Done` chunk, after
// which the handle reports success with both chunks buffered.
#[tokio::test]
async fn drain_task_forwards_until_terminal() {
    let registry = Arc::new(ToolHandleRegistry::new());
    let (handle, token) = registry.register("gen", "a4").await;
    let (sender, receiver) = tokio::sync::mpsc::channel(8);
    spawn_drain(registry.clone(), handle.id.clone(), receiver, token);

    let progress = ToolStreamChunk::Progress {
        fraction: 0.5,
        message: Some("half".into()),
    };
    sender.send(progress).await.unwrap();
    sender
        .send(ToolStreamChunk::Done { result: None })
        .await
        .unwrap();

    // Give the background task up to 100 x 5ms to observe the terminal chunk.
    let mut attempts = 0;
    while attempts < 100 && registry.status(&handle.id).await != Some(ToolStatus::Succeeded) {
        tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        attempts += 1;
    }

    let outcome = registry.poll(&handle.id).await.unwrap();
    assert_eq!(outcome.status, ToolStatus::Succeeded);
    assert_eq!(outcome.chunks.len(), 2);
}
// Dropping the sender before any terminal chunk makes the drain task seal the
// handle as failed with the "without a terminal chunk" error.
#[tokio::test]
async fn dropped_stream_without_terminal_is_failure() {
    let registry = Arc::new(ToolHandleRegistry::new());
    let (handle, token) = registry.register("flaky", "a5").await;
    let (sender, receiver) = tokio::sync::mpsc::channel(8);
    spawn_drain(registry.clone(), handle.id.clone(), receiver, token);

    let partial = ToolStreamChunk::Text {
        text: "partial".into(),
    };
    sender.send(partial).await.unwrap();
    drop(sender);

    // Give the background task up to 100 x 5ms to notice the closed channel.
    let mut attempts = 0;
    while attempts < 100 && registry.status(&handle.id).await != Some(ToolStatus::Failed) {
        tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        attempts += 1;
    }

    let outcome = registry.poll(&handle.id).await.unwrap();
    assert_eq!(outcome.status, ToolStatus::Failed);
    assert!(outcome.error.unwrap().contains("without a terminal chunk"));
}
// A subscriber created before the push receives the chunk tagged with the
// originating handle.
#[tokio::test]
async fn events_broadcast_to_subscribers() {
    let registry = ToolHandleRegistry::new();
    let mut events = registry.subscribe();
    let (handle, _token) = registry.register("emit", "a6").await;

    let chunk = ToolStreamChunk::Text { text: "x".into() };
    registry.push_chunk(&handle.id, chunk).await;

    let delivered = events.recv().await.expect("event delivered");
    assert_eq!(delivered.handle.id, handle.id);
    assert!(matches!(delivered.chunk, ToolStreamChunk::Text { .. }));
}
}