use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
use tokio::sync::oneshot;
/// Outcome of one background agent task, as handed back by
/// `BgAgentRegistry::drain_completed`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BgAgentResult {
    /// Name of the agent the task was registered under.
    pub agent_name: String,
    /// Prompt the task was registered with.
    pub prompt: String,
    /// The task's output on success; otherwise its error or cancellation message.
    pub output: String,
    /// `true` only when the task reported `Ok(..)`.
    pub success: bool,
}
/// Bookkeeping for a single background task that has not been drained yet.
struct BgAgentEntry {
    /// Receiving half of the task's result channel: `Ok(output)` marks
    /// success, `Err(message)` marks failure.
    rx: oneshot::Receiver<Result<String, String>>,
    /// Agent name supplied at registration time.
    agent_name: String,
    /// Prompt supplied at registration time.
    prompt: String,
}
pub struct BgAgentRegistry {
pending: Mutex<HashMap<u32, BgAgentEntry>>,
next_id: Mutex<u32>,
}
impl BgAgentRegistry {
pub fn new() -> Self {
Self {
pending: Mutex::new(HashMap::new()),
next_id: Mutex::new(1),
}
}
pub fn register(
&self,
agent_name: &str,
prompt: &str,
) -> (u32, oneshot::Sender<Result<String, String>>) {
let (tx, rx) = oneshot::channel();
let mut id = self.next_id.lock().unwrap();
let task_id = *id;
*id += 1;
drop(id);
self.pending.lock().unwrap().insert(
task_id,
BgAgentEntry {
agent_name: agent_name.to_string(),
prompt: prompt.to_string(),
rx,
},
);
(task_id, tx)
}
pub fn drain_completed(&self) -> Vec<BgAgentResult> {
let mut guard = self.pending.lock().unwrap();
let mut completed = Vec::new();
let mut done_ids = Vec::new();
for (id, entry) in guard.iter_mut() {
match entry.rx.try_recv() {
Ok(Ok(output)) => {
done_ids.push(*id);
completed.push(BgAgentResult {
agent_name: entry.agent_name.clone(),
prompt: entry.prompt.clone(),
output,
success: true,
});
}
Ok(Err(err)) => {
done_ids.push(*id);
completed.push(BgAgentResult {
agent_name: entry.agent_name.clone(),
prompt: entry.prompt.clone(),
output: err,
success: false,
});
}
Err(oneshot::error::TryRecvError::Empty) => {
}
Err(oneshot::error::TryRecvError::Closed) => {
done_ids.push(*id);
completed.push(BgAgentResult {
agent_name: entry.agent_name.clone(),
prompt: entry.prompt.clone(),
output: "[background agent task was cancelled]".to_string(),
success: false,
});
}
}
}
for id in done_ids {
guard.remove(&id);
}
completed
}
pub fn pending_count(&self) -> usize {
self.pending.lock().unwrap().len()
}
}
impl Default for BgAgentRegistry {
fn default() -> Self {
Self::new()
}
}
/// Creates an empty [`BgAgentRegistry`] behind an [`Arc`], so that several
/// owners can hold a handle to the same registry.
pub fn new_shared() -> Arc<BgAgentRegistry> {
    let registry = BgAgentRegistry::new();
    Arc::new(registry)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn register_and_complete() {
        let reg = BgAgentRegistry::new();
        let (task_id, tx) = reg.register("explore", "find all tests");
        assert_eq!(task_id, 1);
        assert_eq!(reg.pending_count(), 1);
        assert!(reg.drain_completed().is_empty());
        tx.send(Ok("found 42 tests".to_string())).unwrap();
        let results = reg.drain_completed();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].agent_name, "explore");
        assert_eq!(results[0].prompt, "find all tests");
        assert_eq!(results[0].output, "found 42 tests");
        assert!(results[0].success);
        assert_eq!(reg.pending_count(), 0);
        // A drained task must not be reported a second time.
        assert!(reg.drain_completed().is_empty());
    }

    #[test]
    fn drain_only_completed() {
        let reg = BgAgentRegistry::new();
        let (_id1, tx1) = reg.register("task", "build");
        let (_id2, _tx2) = reg.register("explore", "search");
        tx1.send(Ok("done".to_string())).unwrap();
        let results = reg.drain_completed();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].agent_name, "task");
        assert_eq!(results[0].prompt, "build");
        assert_eq!(reg.pending_count(), 1);
    }

    #[test]
    fn dropped_sender_reports_cancelled() {
        let reg = BgAgentRegistry::new();
        let (_id, tx) = reg.register("task", "build");
        drop(tx);
        let results = reg.drain_completed();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].agent_name, "task");
        assert!(!results[0].success);
        assert!(results[0].output.contains("cancelled"));
        assert_eq!(reg.pending_count(), 0);
    }

    #[test]
    fn error_result() {
        let reg = BgAgentRegistry::new();
        let (_id, tx) = reg.register("verify", "check");
        tx.send(Err("test failures".to_string())).unwrap();
        let results = reg.drain_completed();
        assert_eq!(results.len(), 1);
        assert!(!results[0].success);
        assert_eq!(results[0].output, "test failures");
        assert_eq!(reg.pending_count(), 0);
    }

    #[test]
    fn ids_are_sequential() {
        let reg = BgAgentRegistry::new();
        let (a, _ta) = reg.register("a", "x");
        let (b, _tb) = reg.register("b", "y");
        let (c, _tc) = reg.register("c", "z");
        assert_eq!((a, b, c), (1, 2, 3));
        assert_eq!(reg.pending_count(), 3);
    }
}