use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, Semaphore};
use crate::state::State;
use crate::text::TextAgent;
pub struct BackgroundAgentDispatcher {
budget: Arc<Semaphore>,
tasks: Arc<Mutex<HashMap<String, tokio::task::JoinHandle<()>>>>,
max_concurrent: usize,
}
impl BackgroundAgentDispatcher {
pub fn new(max_concurrent: usize) -> Self {
Self {
budget: Arc::new(Semaphore::new(max_concurrent)),
tasks: Arc::new(Mutex::new(HashMap::new())),
max_concurrent,
}
}
pub fn max_concurrent(&self) -> usize {
self.max_concurrent
}
pub fn available_permits(&self) -> usize {
self.budget.available_permits()
}
pub fn dispatch(&self, task_name: impl Into<String>, agent: Arc<dyn TextAgent>, state: State) {
let name = task_name.into();
let budget = self.budget.clone();
let tasks = self.tasks.clone();
let result_key = format!("{name}:result");
let error_key = format!("{name}:error");
let name_for_cleanup = name.clone();
let handle = tokio::spawn(async move {
let _permit = match budget.acquire().await {
Ok(p) => p,
Err(_) => return, };
match agent.run(&state).await {
Ok(result) => {
state.set(&result_key, &result);
}
Err(e) => {
state.set(&error_key, format!("{e}"));
}
}
tasks.lock().await.remove(&name_for_cleanup);
});
if let Ok(mut guard) = self.tasks.try_lock() {
guard.insert(name, handle);
}
}
pub async fn is_running(&self, name: &str) -> bool {
let guard = self.tasks.lock().await;
guard.get(name).map(|h| !h.is_finished()).unwrap_or(false)
}
pub async fn cancel_all(&self) {
let mut guard = self.tasks.lock().await;
for (_, handle) in guard.drain() {
handle.abort();
}
}
pub async fn cancel(&self, name: &str) {
let mut guard = self.tasks.lock().await;
if let Some(handle) = guard.remove(name) {
handle.abort();
}
}
pub async fn active_count(&self) -> usize {
let guard = self.tasks.lock().await;
guard.values().filter(|h| !h.is_finished()).count()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::error::AgentError;
use async_trait::async_trait;
struct QuickAgent {
output: String,
}
#[async_trait]
impl TextAgent for QuickAgent {
fn name(&self) -> &str {
"quick"
}
async fn run(&self, _state: &State) -> Result<String, AgentError> {
Ok(self.output.clone())
}
}
struct SlowAgent;
#[async_trait]
impl TextAgent for SlowAgent {
fn name(&self) -> &str {
"slow"
}
async fn run(&self, _state: &State) -> Result<String, AgentError> {
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
Ok("done".into())
}
}
struct FailAgent;
#[async_trait]
impl TextAgent for FailAgent {
fn name(&self) -> &str {
"fail"
}
async fn run(&self, _state: &State) -> Result<String, AgentError> {
Err(AgentError::Other("background failure".into()))
}
}
struct StateWriterAgent;
#[async_trait]
impl TextAgent for StateWriterAgent {
fn name(&self) -> &str {
"writer"
}
async fn run(&self, state: &State) -> Result<String, AgentError> {
state.set("bg_wrote", true);
Ok("wrote state".into())
}
}
#[tokio::test]
async fn dispatch_writes_result_to_state() {
let dispatcher = BackgroundAgentDispatcher::new(5);
let state = State::new();
let agent: Arc<dyn TextAgent> = Arc::new(QuickAgent {
output: "analysis complete".into(),
});
dispatcher.dispatch("analysis", agent, state.clone());
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
assert_eq!(
state.get::<String>("analysis:result"),
Some("analysis complete".into())
);
}
#[tokio::test]
async fn dispatch_writes_error_to_state() {
let dispatcher = BackgroundAgentDispatcher::new(5);
let state = State::new();
let agent: Arc<dyn TextAgent> = Arc::new(FailAgent);
dispatcher.dispatch("check", agent, state.clone());
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let error = state.get::<String>("check:error");
assert!(error.is_some());
assert!(error.unwrap().contains("background failure"));
}
#[tokio::test]
async fn budget_limits_concurrency() {
let dispatcher = BackgroundAgentDispatcher::new(2);
let state = State::new();
let agent: Arc<dyn TextAgent> = Arc::new(SlowAgent);
assert_eq!(dispatcher.available_permits(), 2);
dispatcher.dispatch("task1", agent.clone(), state.clone());
dispatcher.dispatch("task2", agent.clone(), state.clone());
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
assert_eq!(dispatcher.available_permits(), 0);
tokio::time::sleep(std::time::Duration::from_millis(300)).await;
assert_eq!(dispatcher.available_permits(), 2);
}
#[tokio::test]
async fn cancel_all_aborts_tasks() {
let dispatcher = BackgroundAgentDispatcher::new(5);
let state = State::new();
let agent: Arc<dyn TextAgent> = Arc::new(SlowAgent);
dispatcher.dispatch("long", agent, state.clone());
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
assert!(dispatcher.is_running("long").await);
dispatcher.cancel_all().await;
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
assert!(state.get::<String>("long:result").is_none());
}
#[tokio::test]
async fn state_mutations_visible_to_parent() {
let dispatcher = BackgroundAgentDispatcher::new(5);
let state = State::new();
let agent: Arc<dyn TextAgent> = Arc::new(StateWriterAgent);
dispatcher.dispatch("writer", agent, state.clone());
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
assert_eq!(state.get::<bool>("bg_wrote"), Some(true));
assert_eq!(
state.get::<String>("writer:result"),
Some("wrote state".into())
);
}
#[tokio::test]
async fn cancel_specific_task() {
let dispatcher = BackgroundAgentDispatcher::new(5);
let state = State::new();
let agent: Arc<dyn TextAgent> = Arc::new(SlowAgent);
dispatcher.dispatch("keep", agent.clone(), state.clone());
dispatcher.dispatch("abort", agent, state.clone());
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
dispatcher.cancel("abort").await;
tokio::time::sleep(std::time::Duration::from_millis(300)).await;
assert_eq!(state.get::<String>("keep:result"), Some("done".into()));
assert!(state.get::<String>("abort:result").is_none());
}
}