// inference_gateway_adk/server/task_manager.rs
use super::storage::Storage;
22use super::task_handler::TaskHandler;
23use crate::a2a_types::{TaskState, TaskStatus, Timestamp};
24use std::sync::Arc;
25use std::time::Duration;
26use tokio::task::JoinSet;
27use tokio_util::sync::CancellationToken;
28use tracing::{debug, warn};
29
/// Pulls queued tasks from [`Storage`] and runs them through a [`TaskHandler`]
/// on a fixed-size pool of worker loops.
pub struct DefaultTaskManager {
    // Shared queue + task stores; each worker gets its own `Arc` clone.
    storage: Arc<dyn Storage>,
    // Business logic invoked once per dequeued task.
    handler: Arc<dyn TaskHandler>,
    // Number of worker loops spawned by `start`; clamped to >= 1 in `new`.
    worker_count: usize,
}
39
40impl std::fmt::Debug for DefaultTaskManager {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 f.debug_struct("DefaultTaskManager")
43 .field("worker_count", &self.worker_count)
44 .finish_non_exhaustive()
45 }
46}
47
48impl DefaultTaskManager {
49 pub fn new(
50 storage: Arc<dyn Storage>,
51 handler: Arc<dyn TaskHandler>,
52 worker_count: usize,
53 ) -> Self {
54 let worker_count = worker_count.max(1);
55 Self {
56 storage,
57 handler,
58 worker_count,
59 }
60 }
61
62 pub fn start(&self) -> TaskManagerRunner {
67 let token = CancellationToken::new();
68 let mut join_set: JoinSet<()> = JoinSet::new();
69 for worker_id in 0..self.worker_count {
70 let storage = Arc::clone(&self.storage);
71 let handler = Arc::clone(&self.handler);
72 let token = token.clone();
73 join_set.spawn(async move {
74 run_worker(worker_id, storage, handler, token).await;
75 });
76 }
77 debug!("task manager started with {} worker(s)", self.worker_count);
78 TaskManagerRunner {
79 shutdown: token,
80 join_set,
81 }
82 }
83}
84
/// Handle returned by [`DefaultTaskManager::start`]: owns the spawned worker
/// tasks and the token used to ask them to stop.
#[derive(Debug)]
pub struct TaskManagerRunner {
    // Cancelling this token makes every worker loop exit.
    shutdown: CancellationToken,
    // Join handles for the worker tasks spawned in `start`.
    join_set: JoinSet<()>,
}
93
94impl TaskManagerRunner {
95 pub async fn shutdown(mut self) {
99 self.shutdown.cancel();
100 while self.join_set.join_next().await.is_some() {}
101 debug!("task manager shutdown complete");
102 }
103
104 pub fn cancel(&self) {
107 self.shutdown.cancel();
108 }
109}
110
/// Single worker loop: dequeue a task, mark it active, run the handler, and
/// route the outcome. Exits only when `shutdown` is cancelled.
///
/// Shutdown responsiveness relies on the `biased` select below: cancellation
/// is always polled before the dequeue future, so a cancelled worker exits
/// even if tasks are still queued.
async fn run_worker(
    worker_id: usize,
    storage: Arc<dyn Storage>,
    handler: Arc<dyn TaskHandler>,
    shutdown: CancellationToken,
) {
    debug!(worker_id, "task manager worker started");
    loop {
        let queued = tokio::select! {
            // `biased` makes polling order deterministic: the cancellation
            // branch is checked first on every wakeup.
            biased;
            _ = shutdown.cancelled() => {
                debug!(worker_id, "task manager worker exiting on cancellation");
                return;
            }
            res = storage.dequeue_task() => match res {
                Ok(q) => q,
                Err(e) => {
                    warn!(worker_id, error = %e, "dequeue_task failed; backing off");
                    // 1s backoff so a persistently failing backend does not
                    // spin the loop; the inner select keeps the backoff
                    // interruptible by shutdown.
                    tokio::select! {
                        _ = shutdown.cancelled() => return,
                        _ = tokio::time::sleep(Duration::from_secs(1)) => continue,
                    }
                }
            }
        };

        let task = queued.task;
        let task_id = task.id.clone();

        // Best-effort: a failure to record the task as active is logged at
        // debug level and does not prevent the handler from running.
        if let Err(e) = storage.create_active_task(&task).await {
            debug!(worker_id, task_id = %task_id, error = %e, "create_active_task: continuing");
        }

        // The handler receives the full task plus the most recent history
        // message (if any) as a convenience.
        let last_message = task.history.last().cloned();
        match handler.handle_task(task.clone(), last_message).await {
            Ok(result) => route_terminal_or_active(&storage, worker_id, result).await,
            Err(e) => {
                warn!(worker_id, task_id = %task_id, error = %e, "task handler failed");
                // Handler error: stamp the task as failed (keeping its last
                // status message) and archive it in the dead-letter store.
                let mut failed = task;
                failed.status = TaskStatus {
                    message: failed.status.message.clone(),
                    state: TaskState::TaskStateFailed,
                    timestamp: Some(Timestamp(chrono::Utc::now())),
                };
                // If even dead-lettering fails we can only log; the task is
                // then dropped by this worker. NOTE(review): presumably the
                // storage backend evicts it from the active store on
                // dead-letter — confirm against the Storage impl.
                if let Err(store_err) = storage.store_dead_letter_task(&failed).await {
                    warn!(worker_id, task_id = %task_id, error = %store_err,
                        "store_dead_letter_task failed after handler error");
                }
            }
        }
    }
}
163
164async fn route_terminal_or_active(
165 storage: &Arc<dyn Storage>,
166 worker_id: usize,
167 result: crate::a2a_types::Task,
168) {
169 let terminal = matches!(
170 result.status.state,
171 TaskState::TaskStateCompleted
172 | TaskState::TaskStateFailed
173 | TaskState::TaskStateCancelled
174 | TaskState::TaskStateRejected
175 );
176 if terminal {
177 if let Err(e) = storage.store_dead_letter_task(&result).await {
178 warn!(worker_id, task_id = %result.id, error = %e, "store_dead_letter_task failed");
179 }
180 } else if let Err(e) = storage.update_active_task(&result).await {
181 warn!(worker_id, task_id = %result.id, error = %e, "update_active_task failed");
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::a2a_types::{
        Message as A2AMessage, Part, Role, Task, TaskState, TaskStatus, Timestamp,
    };
    use crate::server::storage::InMemoryStorage;
    use crate::server::task_handler::TaskHandler;
    use anyhow::Result;
    use async_trait::async_trait;
    use std::sync::Mutex;

    /// Builds a submitted task with one user message whose ids derive from `id`.
    fn make_task(id: &str) -> Task {
        Task {
            artifacts: vec![],
            context_id: "ctx".to_string(),
            history: vec![A2AMessage {
                context_id: Some("ctx".to_string()),
                extensions: vec![],
                message_id: format!("msg-{id}"),
                metadata: None,
                parts: vec![Part {
                    data: None,
                    file: None,
                    metadata: None,
                    text: Some("hello".to_string()),
                }],
                reference_task_ids: vec![],
                role: Role::RoleUser,
                task_id: Some(id.to_string()),
            }],
            id: id.to_string(),
            metadata: None,
            status: TaskStatus {
                message: None,
                state: TaskState::TaskStateSubmitted,
                timestamp: Some(Timestamp(chrono::Utc::now())),
            },
        }
    }

    /// Test handler that records every task id it sees and returns the task
    /// transitioned to a configurable state.
    #[derive(Debug)]
    struct RecordingHandler {
        // Task ids in the order they were handled; shared with the test body.
        seen: Arc<Mutex<Vec<String>>>,
        // State stamped on every handled task.
        terminal_state: TaskState,
    }

    #[async_trait]
    impl TaskHandler for RecordingHandler {
        async fn handle_task(&self, mut task: Task, _message: Option<A2AMessage>) -> Result<Task> {
            self.seen
                .lock()
                .expect("mutex poisoned")
                .push(task.id.clone());
            task.status = TaskStatus {
                message: None,
                state: self.terminal_state,
                timestamp: Some(Timestamp(chrono::Utc::now())),
            };
            Ok(task)
        }
    }

    /// Test handler that always returns an error, exercising the worker's
    /// handler-failure (dead-letter) path.
    #[derive(Debug)]
    struct FailingHandler;

    #[async_trait]
    impl TaskHandler for FailingHandler {
        async fn handle_task(&self, _task: Task, _message: Option<A2AMessage>) -> Result<Task> {
            Err(anyhow::anyhow!("handler always fails"))
        }
    }

    /// Polls storage (up to ~1s: 50 × 20ms) until `task_id` reaches a
    /// terminal state; panics if it never does.
    async fn wait_for_terminal(storage: &Arc<InMemoryStorage>, task_id: &str) -> Task {
        for _ in 0..50 {
            if let Some(task) = storage.get_task(task_id).await
                && matches!(
                    task.status.state,
                    TaskState::TaskStateCompleted
                        | TaskState::TaskStateFailed
                        | TaskState::TaskStateCancelled
                        | TaskState::TaskStateRejected
                )
            {
                return task;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        panic!("task {task_id} never reached terminal state");
    }

    // Happy path: a completed task is dead-lettered, evicted from the active
    // store, and handled exactly once.
    #[tokio::test]
    async fn worker_dequeues_and_routes_completed_to_dead_letter() {
        let storage: Arc<InMemoryStorage> = Arc::new(InMemoryStorage::new());
        let seen = Arc::new(Mutex::new(Vec::new()));
        let handler = Arc::new(RecordingHandler {
            seen: Arc::clone(&seen),
            terminal_state: TaskState::TaskStateCompleted,
        });

        let manager = DefaultTaskManager::new(
            storage.clone() as Arc<dyn Storage>,
            handler as Arc<dyn TaskHandler>,
            1,
        );
        let runner = manager.start();

        storage
            .enqueue_task(make_task("t1"), serde_json::Value::Null)
            .await
            .expect("enqueue");

        let terminal = wait_for_terminal(&storage, "t1").await;
        assert_eq!(terminal.status.state, TaskState::TaskStateCompleted);
        assert!(
            storage.get_active_task("t1").await.expect("ok").is_none(),
            "completed tasks must be evicted from active store",
        );
        let stats = storage.get_stats().await;
        assert_eq!(stats.dead_letter_tasks, 1);
        assert_eq!(stats.active_tasks, 0);
        assert_eq!(seen.lock().expect("mutex poisoned").as_slice(), &["t1"]);

        runner.shutdown().await;
    }

    // Non-terminal path: an input-required result must stay in the active
    // store and never reach the dead-letter store.
    #[tokio::test]
    async fn worker_routes_input_required_to_active_store() {
        let storage: Arc<InMemoryStorage> = Arc::new(InMemoryStorage::new());
        let handler = Arc::new(RecordingHandler {
            seen: Arc::new(Mutex::new(Vec::new())),
            terminal_state: TaskState::TaskStateInputRequired,
        });

        let manager = DefaultTaskManager::new(
            storage.clone() as Arc<dyn Storage>,
            handler as Arc<dyn TaskHandler>,
            1,
        );
        let runner = manager.start();

        storage
            .enqueue_task(make_task("t2"), serde_json::Value::Null)
            .await
            .expect("enqueue");

        // Poll until the worker has routed the task (bounded wait, ~1s max).
        for _ in 0..50 {
            let active = storage.get_active_task("t2").await.expect("ok");
            if matches!(
                active.as_ref().map(|t| t.status.state),
                Some(TaskState::TaskStateInputRequired)
            ) {
                break;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        let active = storage
            .get_active_task("t2")
            .await
            .expect("ok")
            .expect("task should remain in active store");
        assert_eq!(active.status.state, TaskState::TaskStateInputRequired);
        assert_eq!(storage.get_stats().await.dead_letter_tasks, 0);

        runner.shutdown().await;
    }

    // Error path: a handler error must surface as a failed task in the
    // dead-letter store.
    #[tokio::test]
    async fn handler_failure_routes_to_dead_letter_as_failed() {
        let storage: Arc<InMemoryStorage> = Arc::new(InMemoryStorage::new());
        let manager = DefaultTaskManager::new(
            storage.clone() as Arc<dyn Storage>,
            Arc::new(FailingHandler) as Arc<dyn TaskHandler>,
            1,
        );
        let runner = manager.start();

        storage
            .enqueue_task(make_task("t3"), serde_json::Value::Null)
            .await
            .expect("enqueue");

        let terminal = wait_for_terminal(&storage, "t3").await;
        assert_eq!(terminal.status.state, TaskState::TaskStateFailed);

        runner.shutdown().await;
    }

    // Shutdown must not hang on workers blocked in dequeue with nothing
    // queued (exercises the biased cancellation branch in run_worker).
    #[tokio::test]
    async fn shutdown_exits_workers_even_with_empty_queue() {
        let storage: Arc<InMemoryStorage> = Arc::new(InMemoryStorage::new());
        let handler = Arc::new(RecordingHandler {
            seen: Arc::new(Mutex::new(Vec::new())),
            terminal_state: TaskState::TaskStateCompleted,
        });
        let manager = DefaultTaskManager::new(
            storage.clone() as Arc<dyn Storage>,
            handler as Arc<dyn TaskHandler>,
            2,
        );
        let runner = manager.start();

        runner.shutdown().await;
        assert_eq!(storage.get_stats().await.dead_letter_tasks, 0);
    }
}
394}