Skip to main content

rs_adk/text/
dispatch.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_trait::async_trait;
6
7use super::TextAgent;
8use crate::error::AgentError;
9use crate::state::State;
10
11/// Shared registry for dispatched background tasks.
12#[derive(Clone, Default)]
13pub struct TaskRegistry {
14    pub(crate) inner:
15        Arc<tokio::sync::Mutex<HashMap<String, tokio::task::JoinHandle<Result<String, String>>>>>,
16}
17
18impl TaskRegistry {
19    /// Create a new empty task registry.
20    pub fn new() -> Self {
21        Self::default()
22    }
23}
24
25/// Fire-and-forget background task launcher with global task budget.
26///
27/// Launches each child agent as a background `tokio::spawn` task,
28/// stores handles in a `TaskRegistry`, and returns immediately.
29pub struct DispatchTextAgent {
30    name: String,
31    children: Vec<(String, Arc<dyn TextAgent>)>,
32    registry: TaskRegistry,
33    budget: Arc<tokio::sync::Semaphore>,
34}
35
36impl DispatchTextAgent {
37    /// Create a new dispatch agent with named children and a concurrency budget.
38    pub fn new(
39        name: impl Into<String>,
40        children: Vec<(String, Arc<dyn TextAgent>)>,
41        registry: TaskRegistry,
42        budget: Arc<tokio::sync::Semaphore>,
43    ) -> Self {
44        Self {
45            name: name.into(),
46            children,
47            registry,
48            budget,
49        }
50    }
51}
52
53#[async_trait]
54impl TextAgent for DispatchTextAgent {
55    fn name(&self) -> &str {
56        &self.name
57    }
58
59    async fn run(&self, state: &State) -> Result<String, AgentError> {
60        let mut registry = self.registry.inner.lock().await;
61
62        for (task_name, agent) in &self.children {
63            let agent = agent.clone();
64            let state = state.clone();
65            let budget = self.budget.clone();
66            let task_name_owned = task_name.clone();
67
68            let handle = tokio::spawn(async move {
69                let _permit = budget
70                    .acquire()
71                    .await
72                    .map_err(|e| format!("Semaphore closed: {e}"))?;
73                agent
74                    .run(&state)
75                    .await
76                    .map_err(|e| format!("Task '{}' failed: {}", task_name_owned, e))
77            });
78
79            registry.insert(task_name.clone(), handle);
80        }
81
82        state.set(
83            "_dispatch_status",
84            self.children
85                .iter()
86                .map(|(name, _)| (name.clone(), "running".to_string()))
87                .collect::<HashMap<String, String>>(),
88        );
89
90        Ok(String::new())
91    }
92}
93
94// ── JoinTextAgent ─────────────────────────────────────────────────────────
95
96/// Waits for dispatched background tasks and collects their results.
97pub struct JoinTextAgent {
98    name: String,
99    registry: TaskRegistry,
100    target_names: Option<Vec<String>>,
101    timeout: Option<Duration>,
102}
103
104impl JoinTextAgent {
105    /// Create a new join agent that waits for dispatched tasks.
106    pub fn new(name: impl Into<String>, registry: TaskRegistry) -> Self {
107        Self {
108            name: name.into(),
109            registry,
110            target_names: None,
111            timeout: None,
112        }
113    }
114
115    /// Only wait for specific named tasks.
116    pub fn targets(mut self, names: Vec<String>) -> Self {
117        self.target_names = Some(names);
118        self
119    }
120
121    /// Set a timeout for waiting.
122    pub fn timeout(mut self, timeout: Duration) -> Self {
123        self.timeout = Some(timeout);
124        self
125    }
126}
127
128#[async_trait]
129impl TextAgent for JoinTextAgent {
130    fn name(&self) -> &str {
131        &self.name
132    }
133
134    async fn run(&self, state: &State) -> Result<String, AgentError> {
135        let mut registry = self.registry.inner.lock().await;
136
137        // Select tasks to wait for.
138        let tasks: HashMap<String, _> = if let Some(targets) = &self.target_names {
139            targets
140                .iter()
141                .filter_map(|name| registry.remove(name).map(|h| (name.clone(), h)))
142                .collect()
143        } else {
144            std::mem::take(&mut *registry)
145        };
146        drop(registry);
147
148        let mut results = Vec::new();
149
150        for (task_name, handle) in tasks {
151            let result = if let Some(timeout) = self.timeout {
152                match tokio::time::timeout(timeout, handle).await {
153                    Ok(Ok(Ok(text))) => {
154                        state.set(format!("_result_{}", task_name), &text);
155                        Ok(text)
156                    }
157                    Ok(Ok(Err(e))) => Err(AgentError::Other(e)),
158                    Ok(Err(e)) => Err(AgentError::Other(format!("Join error: {e}"))),
159                    Err(_) => Err(AgentError::Timeout),
160                }
161            } else {
162                match handle.await {
163                    Ok(Ok(text)) => {
164                        state.set(format!("_result_{}", task_name), &text);
165                        Ok(text)
166                    }
167                    Ok(Err(e)) => Err(AgentError::Other(e)),
168                    Err(e) => Err(AgentError::Other(format!("Join error: {e}"))),
169                }
170            };
171
172            results.push(result?);
173        }
174
175        let combined = results.join("\n");
176        state.set("output", &combined);
177        Ok(combined)
178    }
179}