Skip to main content

inference_gateway_adk/server/
task_manager.rs

1//! Background task manager that drains the [`Storage`] queue and
2//! dispatches each dequeued task to the configured [`TaskHandler`].
3//!
//! One or more workers call [`Storage::dequeue_task`] (blocking), move
5//! the task into the active store, drive the handler to a terminal
6//! state, then route the result to the active store (intermediate
7//! state) or the dead-letter store (terminal state) based on the
8//! handler's returned `status.state`.
9//!
10//! Construction is decoupled from spawning so the server builder can
11//! own the manager configuration and call [`DefaultTaskManager::start`]
12//! at serve time:
13//!
14//! ```text
15//!     let manager = DefaultTaskManager::new(storage, handler, workers);
16//!     let runner = manager.start();
17//!     // ... serve until SIGINT ...
18//!     runner.shutdown().await;
19//! ```
20
21use super::storage::Storage;
22use super::task_handler::TaskHandler;
23use crate::a2a_types::{TaskState, TaskStatus, Timestamp};
24use std::sync::Arc;
25use std::time::Duration;
26use tokio::task::JoinSet;
27use tokio_util::sync::CancellationToken;
28use tracing::{debug, warn};
29
/// Drains the storage queue and dispatches each dequeued task to the
/// configured background [`TaskHandler`]. Construct via
/// [`DefaultTaskManager::new`], then call [`start`](Self::start) when
/// the server is ready to begin processing.
pub struct DefaultTaskManager {
    // Shared queue + active/dead-letter stores drained by the workers.
    storage: Arc<dyn Storage>,
    // Application-provided handler invoked once per dequeued task.
    handler: Arc<dyn TaskHandler>,
    // Number of worker loops spawned by `start`; clamped to >= 1 in `new`.
    worker_count: usize,
}
39
40impl std::fmt::Debug for DefaultTaskManager {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        f.debug_struct("DefaultTaskManager")
43            .field("worker_count", &self.worker_count)
44            .finish_non_exhaustive()
45    }
46}
47
48impl DefaultTaskManager {
49    pub fn new(
50        storage: Arc<dyn Storage>,
51        handler: Arc<dyn TaskHandler>,
52        worker_count: usize,
53    ) -> Self {
54        let worker_count = worker_count.max(1);
55        Self {
56            storage,
57            handler,
58            worker_count,
59        }
60    }
61
62    /// Spawn `worker_count` workers. Returns a [`TaskManagerRunner`]
63    /// that holds the join set and cancellation token; call
64    /// [`TaskManagerRunner::shutdown`] to drain workers gracefully on
65    /// server shutdown.
66    pub fn start(&self) -> TaskManagerRunner {
67        let token = CancellationToken::new();
68        let mut join_set: JoinSet<()> = JoinSet::new();
69        for worker_id in 0..self.worker_count {
70            let storage = Arc::clone(&self.storage);
71            let handler = Arc::clone(&self.handler);
72            let token = token.clone();
73            join_set.spawn(async move {
74                run_worker(worker_id, storage, handler, token).await;
75            });
76        }
77        debug!("task manager started with {} worker(s)", self.worker_count);
78        TaskManagerRunner {
79            shutdown: token,
80            join_set,
81        }
82    }
83}
84
/// Handle returned by [`DefaultTaskManager::start`]. Drop to detach the
/// workers (they keep running) or call [`shutdown`](Self::shutdown) to
/// cooperatively cancel + await them.
#[derive(Debug)]
pub struct TaskManagerRunner {
    // Token every worker polls between dequeues; cancelled on shutdown.
    shutdown: CancellationToken,
    // Join handles for all spawned worker loops.
    join_set: JoinSet<()>,
}
93
94impl TaskManagerRunner {
95    /// Cancel the workers' cancellation token and wait for every worker
96    /// to exit its loop. Each worker stops at the next `select!` point
97    /// (between dequeues; an in-flight handler call is allowed to finish).
98    pub async fn shutdown(mut self) {
99        self.shutdown.cancel();
100        while self.join_set.join_next().await.is_some() {}
101        debug!("task manager shutdown complete");
102    }
103
104    /// Trigger cancellation without waiting. Useful when the caller
105    /// needs to interleave shutdown with other concurrent work.
106    pub fn cancel(&self) {
107        self.shutdown.cancel();
108    }
109}
110
111async fn run_worker(
112    worker_id: usize,
113    storage: Arc<dyn Storage>,
114    handler: Arc<dyn TaskHandler>,
115    shutdown: CancellationToken,
116) {
117    debug!(worker_id, "task manager worker started");
118    loop {
119        let queued = tokio::select! {
120            biased;
121            _ = shutdown.cancelled() => {
122                debug!(worker_id, "task manager worker exiting on cancellation");
123                return;
124            }
125            res = storage.dequeue_task() => match res {
126                Ok(q) => q,
127                Err(e) => {
128                    warn!(worker_id, error = %e, "dequeue_task failed; backing off");
129                    tokio::select! {
130                        _ = shutdown.cancelled() => return,
131                        _ = tokio::time::sleep(Duration::from_secs(1)) => continue,
132                    }
133                }
134            }
135        };
136
137        let task = queued.task;
138        let task_id = task.id.clone();
139
140        if let Err(e) = storage.create_active_task(&task).await {
141            debug!(worker_id, task_id = %task_id, error = %e, "create_active_task: continuing");
142        }
143
144        let last_message = task.history.last().cloned();
145        match handler.handle_task(task.clone(), last_message).await {
146            Ok(result) => route_terminal_or_active(&storage, worker_id, result).await,
147            Err(e) => {
148                warn!(worker_id, task_id = %task_id, error = %e, "task handler failed");
149                let mut failed = task;
150                failed.status = TaskStatus {
151                    message: failed.status.message.clone(),
152                    state: TaskState::TaskStateFailed,
153                    timestamp: Some(Timestamp(chrono::Utc::now())),
154                };
155                if let Err(store_err) = storage.store_dead_letter_task(&failed).await {
156                    warn!(worker_id, task_id = %task_id, error = %store_err,
157                        "store_dead_letter_task failed after handler error");
158                }
159            }
160        }
161    }
162}
163
164async fn route_terminal_or_active(
165    storage: &Arc<dyn Storage>,
166    worker_id: usize,
167    result: crate::a2a_types::Task,
168) {
169    let terminal = matches!(
170        result.status.state,
171        TaskState::TaskStateCompleted
172            | TaskState::TaskStateFailed
173            | TaskState::TaskStateCancelled
174            | TaskState::TaskStateRejected
175    );
176    if terminal {
177        if let Err(e) = storage.store_dead_letter_task(&result).await {
178            warn!(worker_id, task_id = %result.id, error = %e, "store_dead_letter_task failed");
179        }
180    } else if let Err(e) = storage.update_active_task(&result).await {
181        warn!(worker_id, task_id = %result.id, error = %e, "update_active_task failed");
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::a2a_types::{
        Message as A2AMessage, Part, Role, Task, TaskState, TaskStatus, Timestamp,
    };
    use crate::server::storage::InMemoryStorage;
    use crate::server::task_handler::TaskHandler;
    use anyhow::Result;
    use async_trait::async_trait;
    use std::sync::Mutex;

    /// Build a minimal task in the Submitted state with a single
    /// user message in its history.
    fn make_task(id: &str) -> Task {
        Task {
            artifacts: vec![],
            context_id: "ctx".to_string(),
            history: vec![A2AMessage {
                context_id: Some("ctx".to_string()),
                extensions: vec![],
                message_id: format!("msg-{id}"),
                metadata: None,
                parts: vec![Part {
                    data: None,
                    file: None,
                    metadata: None,
                    text: Some("hello".to_string()),
                }],
                reference_task_ids: vec![],
                role: Role::RoleUser,
                task_id: Some(id.to_string()),
            }],
            id: id.to_string(),
            metadata: None,
            status: TaskStatus {
                message: None,
                state: TaskState::TaskStateSubmitted,
                timestamp: Some(Timestamp(chrono::Utc::now())),
            },
        }
    }

    /// Records every task it sees, then returns it with a configurable
    /// terminal state so we can exercise the active/dead-letter routing.
    #[derive(Debug)]
    struct RecordingHandler {
        seen: Arc<Mutex<Vec<String>>>,
        terminal_state: TaskState,
    }

    #[async_trait]
    impl TaskHandler for RecordingHandler {
        async fn handle_task(&self, mut task: Task, _message: Option<A2AMessage>) -> Result<Task> {
            self.seen
                .lock()
                .expect("mutex poisoned")
                .push(task.id.clone());
            task.status = TaskStatus {
                message: None,
                state: self.terminal_state,
                timestamp: Some(Timestamp(chrono::Utc::now())),
            };
            Ok(task)
        }
    }

    /// Always errors, so we can verify failures route to dead-letter
    /// with state == Failed.
    #[derive(Debug)]
    struct FailingHandler;

    #[async_trait]
    impl TaskHandler for FailingHandler {
        async fn handle_task(&self, _task: Task, _message: Option<A2AMessage>) -> Result<Task> {
            Err(anyhow::anyhow!("handler always fails"))
        }
    }

    /// Poll storage (up to 50 x 20ms = ~1s) until the task reaches a
    /// terminal state; panic if it never does. Polling is needed
    /// because the worker processes the task asynchronously.
    async fn wait_for_terminal(storage: &Arc<InMemoryStorage>, task_id: &str) -> Task {
        for _ in 0..50 {
            if let Some(task) = storage.get_task(task_id).await
                && matches!(
                    task.status.state,
                    TaskState::TaskStateCompleted
                        | TaskState::TaskStateFailed
                        | TaskState::TaskStateCancelled
                        | TaskState::TaskStateRejected
                )
            {
                return task;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        panic!("task {task_id} never reached terminal state");
    }

    /// Happy path: a completed task must leave the active store and
    /// land in the dead-letter store exactly once.
    #[tokio::test]
    async fn worker_dequeues_and_routes_completed_to_dead_letter() {
        let storage: Arc<InMemoryStorage> = Arc::new(InMemoryStorage::new());
        let seen = Arc::new(Mutex::new(Vec::new()));
        let handler = Arc::new(RecordingHandler {
            seen: Arc::clone(&seen),
            terminal_state: TaskState::TaskStateCompleted,
        });

        let manager = DefaultTaskManager::new(
            storage.clone() as Arc<dyn Storage>,
            handler as Arc<dyn TaskHandler>,
            1,
        );
        let runner = manager.start();

        storage
            .enqueue_task(make_task("t1"), serde_json::Value::Null)
            .await
            .expect("enqueue");

        let terminal = wait_for_terminal(&storage, "t1").await;
        assert_eq!(terminal.status.state, TaskState::TaskStateCompleted);
        assert!(
            storage.get_active_task("t1").await.expect("ok").is_none(),
            "completed tasks must be evicted from active store",
        );
        let stats = storage.get_stats().await;
        assert_eq!(stats.dead_letter_tasks, 1);
        assert_eq!(stats.active_tasks, 0);
        assert_eq!(seen.lock().expect("mutex poisoned").as_slice(), &["t1"]);

        runner.shutdown().await;
    }

    /// Intermediate state: input-required is not terminal, so the task
    /// must stay in the active store and never reach dead-letter.
    #[tokio::test]
    async fn worker_routes_input_required_to_active_store() {
        let storage: Arc<InMemoryStorage> = Arc::new(InMemoryStorage::new());
        let handler = Arc::new(RecordingHandler {
            seen: Arc::new(Mutex::new(Vec::new())),
            terminal_state: TaskState::TaskStateInputRequired,
        });

        let manager = DefaultTaskManager::new(
            storage.clone() as Arc<dyn Storage>,
            handler as Arc<dyn TaskHandler>,
            1,
        );
        let runner = manager.start();

        storage
            .enqueue_task(make_task("t2"), serde_json::Value::Null)
            .await
            .expect("enqueue");

        // Poll until the worker has written the intermediate state back.
        for _ in 0..50 {
            let active = storage.get_active_task("t2").await.expect("ok");
            if matches!(
                active.as_ref().map(|t| t.status.state),
                Some(TaskState::TaskStateInputRequired)
            ) {
                break;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        let active = storage
            .get_active_task("t2")
            .await
            .expect("ok")
            .expect("task should remain in active store");
        assert_eq!(active.status.state, TaskState::TaskStateInputRequired);
        assert_eq!(storage.get_stats().await.dead_letter_tasks, 0);

        runner.shutdown().await;
    }

    /// Handler error path: the worker must mark the task Failed and
    /// store it in the dead-letter store.
    #[tokio::test]
    async fn handler_failure_routes_to_dead_letter_as_failed() {
        let storage: Arc<InMemoryStorage> = Arc::new(InMemoryStorage::new());
        let manager = DefaultTaskManager::new(
            storage.clone() as Arc<dyn Storage>,
            Arc::new(FailingHandler) as Arc<dyn TaskHandler>,
            1,
        );
        let runner = manager.start();

        storage
            .enqueue_task(make_task("t3"), serde_json::Value::Null)
            .await
            .expect("enqueue");

        let terminal = wait_for_terminal(&storage, "t3").await;
        assert_eq!(terminal.status.state, TaskState::TaskStateFailed);

        runner.shutdown().await;
    }

    /// Shutdown must not hang when workers are blocked on an empty
    /// queue; this would deadlock if cancellation were not honored.
    #[tokio::test]
    async fn shutdown_exits_workers_even_with_empty_queue() {
        let storage: Arc<InMemoryStorage> = Arc::new(InMemoryStorage::new());
        let handler = Arc::new(RecordingHandler {
            seen: Arc::new(Mutex::new(Vec::new())),
            terminal_state: TaskState::TaskStateCompleted,
        });
        let manager = DefaultTaskManager::new(
            storage.clone() as Arc<dyn Storage>,
            handler as Arc<dyn TaskHandler>,
            2,
        );
        let runner = manager.start();

        runner.shutdown().await;
        assert_eq!(storage.get_stats().await.dead_letter_tasks, 0);
    }
}