Skip to main content

punch_kernel/
a2a_executor.rs

1//! A2A Task Executor — picks up pending A2A tasks and runs them through fighters.
2//!
3//! The [`A2ATaskExecutor`] polls a shared [`DashMap`] of tasks for any in
4//! [`Pending`](punch_types::a2a::A2ATaskStatus::Pending) status, spawns a
5//! temporary fighter for each, executes the task, and writes the result back.
6
7use std::sync::Arc;
8use std::time::Duration;
9
10use chrono::Utc;
11use dashmap::DashMap;
12use tokio::sync::watch;
13use tokio::task::JoinHandle;
14use tracing::{error, info, instrument};
15
16use punch_types::a2a::{A2ATask, A2ATaskInput, A2ATaskOutput, A2ATaskStatus};
17use punch_types::{FighterManifest, WeightClass};
18
19use crate::ring::Ring;
20
/// Polling interval used by [`A2ATaskExecutor::new`] when the caller does not
/// supply one: 500 milliseconds.
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(500);
23
24/// The A2A task executor: polls for pending tasks and executes them via fighters.
25pub struct A2ATaskExecutor {
26    /// Reference to the Ring for spawning fighters and sending messages.
27    ring: Arc<Ring>,
28    /// Shared task map (same instance as the HTTP handlers use).
29    tasks: Arc<DashMap<String, A2ATask>>,
30    /// Polling interval.
31    poll_interval: Duration,
32    /// Shutdown signal sender.
33    shutdown_tx: watch::Sender<bool>,
34    /// Shutdown signal receiver (cloned for the polling task).
35    shutdown_rx: watch::Receiver<bool>,
36    /// Handle to the background polling task.
37    handle: Option<JoinHandle<()>>,
38}
39
40impl A2ATaskExecutor {
41    /// Create a new executor that will poll the given task map and use the Ring
42    /// to spawn fighters for execution.
43    pub fn new(ring: Arc<Ring>, tasks: Arc<DashMap<String, A2ATask>>) -> Self {
44        let (shutdown_tx, shutdown_rx) = watch::channel(false);
45        Self {
46            ring,
47            tasks,
48            poll_interval: DEFAULT_POLL_INTERVAL,
49            shutdown_tx,
50            shutdown_rx,
51            handle: None,
52        }
53    }
54
55    /// Create a new executor with a custom polling interval.
56    pub fn with_poll_interval(
57        ring: Arc<Ring>,
58        tasks: Arc<DashMap<String, A2ATask>>,
59        poll_interval: Duration,
60    ) -> Self {
61        let (shutdown_tx, shutdown_rx) = watch::channel(false);
62        Self {
63            ring,
64            tasks,
65            poll_interval,
66            shutdown_tx,
67            shutdown_rx,
68            handle: None,
69        }
70    }
71
72    /// Start the background polling loop.
73    ///
74    /// Spawns a tokio task that polls the DashMap every `poll_interval` for
75    /// [`Pending`](A2ATaskStatus::Pending) tasks. Each pending task is picked
76    /// up and executed in its own spawned task.
77    pub fn start(&mut self) {
78        let ring = Arc::clone(&self.ring);
79        let tasks = Arc::clone(&self.tasks);
80        let interval = self.poll_interval;
81        let mut shutdown_rx = self.shutdown_rx.clone();
82
83        let handle = tokio::spawn(async move {
84            info!(
85                poll_interval_ms = interval.as_millis(),
86                "A2A task executor started"
87            );
88
89            loop {
90                tokio::select! {
91                    _ = tokio::time::sleep(interval) => {}
92                    _ = shutdown_rx.changed() => {
93                        if *shutdown_rx.borrow() {
94                            info!("A2A task executor received shutdown signal");
95                            break;
96                        }
97                    }
98                }
99
100                if *shutdown_rx.borrow() {
101                    break;
102                }
103
104                // Collect pending task IDs (avoid holding DashMap guards across await).
105                let pending_ids: Vec<String> = tasks
106                    .iter()
107                    .filter(|entry| entry.value().status == A2ATaskStatus::Pending)
108                    .map(|entry| entry.key().clone())
109                    .collect();
110
111                for task_id in pending_ids {
112                    // Transition to Running (atomic check-and-set).
113                    let task_input = {
114                        let mut entry = match tasks.get_mut(&task_id) {
115                            Some(e) => e,
116                            None => continue,
117                        };
118                        // Double-check it's still Pending (another poll may have grabbed it).
119                        if entry.status != A2ATaskStatus::Pending {
120                            continue;
121                        }
122                        entry.status = A2ATaskStatus::Running;
123                        entry.updated_at = Utc::now();
124                        entry.input.clone()
125                    };
126
127                    // Spawn execution in a separate task so we don't block polling.
128                    let ring = Arc::clone(&ring);
129                    let tasks = Arc::clone(&tasks);
130                    let id = task_id.clone();
131
132                    tokio::spawn(async move {
133                        execute_task(ring, tasks, id, task_input).await;
134                    });
135                }
136            }
137
138            info!("A2A task executor stopped");
139        });
140
141        self.handle = Some(handle);
142    }
143
144    /// Stop the polling loop.
145    pub fn stop(&mut self) {
146        let _ = self.shutdown_tx.send(true);
147        if let Some(handle) = self.handle.take() {
148            handle.abort();
149        }
150        info!("A2A task executor stop requested");
151    }
152
153    /// Returns `true` if the executor is currently running.
154    pub fn is_running(&self) -> bool {
155        self.handle
156            .as_ref()
157            .is_some_and(|h| !h.is_finished())
158    }
159}
160
161impl Drop for A2ATaskExecutor {
162    fn drop(&mut self) {
163        // Best-effort shutdown on drop.
164        let _ = self.shutdown_tx.send(true);
165        if let Some(handle) = self.handle.take() {
166            handle.abort();
167        }
168    }
169}
170
171/// Execute a single A2A task: spawn a fighter, send the prompt, collect the
172/// result, and update the DashMap.
173#[instrument(skip(ring, tasks, task_input), fields(task_id = %task_id))]
174async fn execute_task(
175    ring: Arc<Ring>,
176    tasks: Arc<DashMap<String, A2ATask>>,
177    task_id: String,
178    task_input: serde_json::Value,
179) {
180    // Extract the prompt from the input.
181    let prompt = extract_prompt(&task_input);
182
183    // Build a temporary fighter manifest for this task.
184    let manifest = FighterManifest {
185        name: format!("a2a-task-{}", &task_id[..8.min(task_id.len())]),
186        description: format!("Temporary fighter for A2A task {task_id}"),
187        model: ring.config().default_model.clone(),
188        system_prompt: build_task_system_prompt(&task_input),
189        capabilities: Vec::new(),
190        weight_class: WeightClass::Middleweight,
191        tenant_id: None,
192    };
193
194    // Spawn the fighter.
195    let fighter_id = ring.spawn_fighter(manifest).await;
196
197    // Send the prompt and collect the result.
198    let result = ring.send_message(&fighter_id, prompt).await;
199
200    // Update the task based on the result.
201    match result {
202        Ok(loop_result) => {
203            if let Some(mut entry) = tasks.get_mut(&task_id) {
204                // Don't overwrite a Cancelled task.
205                if entry.status == A2ATaskStatus::Cancelled {
206                    info!(task_id = %task_id, "task was cancelled during execution, skipping update");
207                } else {
208                    let output = A2ATaskOutput {
209                        content: loop_result.response.clone(),
210                        data: Some(serde_json::json!({
211                            "tokens_used": loop_result.usage.total(),
212                            "iterations": loop_result.iterations,
213                            "tool_calls": loop_result.tool_calls_made,
214                        })),
215                        mode: "text".to_string(),
216                    };
217                    entry.status = A2ATaskStatus::Completed;
218                    entry.output =
219                        Some(serde_json::to_value(output).unwrap_or(serde_json::json!({})));
220                    entry.updated_at = Utc::now();
221                    info!(task_id = %task_id, "A2A task completed successfully");
222                }
223            }
224        }
225        Err(e) => {
226            error!(task_id = %task_id, error = %e, "A2A task execution failed");
227            if let Some(mut entry) = tasks.get_mut(&task_id)
228                && entry.status != A2ATaskStatus::Cancelled
229            {
230                entry.status = A2ATaskStatus::Failed(e.to_string());
231                entry.updated_at = Utc::now();
232            }
233        }
234    }
235
236    // Kill the temporary fighter.
237    ring.kill_fighter(&fighter_id);
238}
239
240/// Extract the prompt text from a task input JSON value.
241///
242/// Tries to parse as [`A2ATaskInput`] first, then falls back to looking for a
243/// "prompt" field, and finally uses the JSON as a string.
244fn extract_prompt(input: &serde_json::Value) -> String {
245    // Try structured A2ATaskInput.
246    if let Ok(structured) = serde_json::from_value::<A2ATaskInput>(input.clone()) {
247        return structured.prompt;
248    }
249
250    // Try a "prompt" field directly.
251    if let Some(prompt) = input.get("prompt").and_then(|v| v.as_str()) {
252        return prompt.to_string();
253    }
254
255    // Try a "message" field.
256    if let Some(msg) = input.get("message").and_then(|v| v.as_str()) {
257        return msg.to_string();
258    }
259
260    // Fall back to the JSON as a string.
261    if let Some(s) = input.as_str() {
262        return s.to_string();
263    }
264
265    input.to_string()
266}
267
268/// Build a system prompt for the task fighter, incorporating any context from
269/// the task input.
270fn build_task_system_prompt(input: &serde_json::Value) -> String {
271    let mut prompt = "You are an AI agent executing a task received via the A2A protocol. \
272                      Complete the task thoroughly and return a clear, actionable response."
273        .to_string();
274
275    // If the input has a context object, include it.
276    if let Some(context) = input.get("context")
277        && let Some(obj) = context.as_object()
278        && !obj.is_empty()
279    {
280        prompt.push_str("\n\n## Task Context\n");
281        for (key, value) in obj {
282            prompt.push_str(&format!("- **{key}**: {value}\n"));
283        }
284    }
285
286    prompt
287}
288
289// ---------------------------------------------------------------------------
290// Tests
291// ---------------------------------------------------------------------------
292
#[cfg(test)]
mod tests {
    use super::*;
    use punch_types::a2a::A2ATaskStatus;

    /// Task map type shared between the executor and the HTTP handlers.
    type TaskMap = Arc<DashMap<String, A2ATask>>;

    /// Build a task with the given ID and status and a fixed prompt input.
    fn make_task(id: &str, status: A2ATaskStatus) -> A2ATask {
        let now = Utc::now();
        A2ATask {
            id: id.to_string(),
            status,
            input: serde_json::json!({"prompt": "hello world"}),
            output: None,
            created_at: now,
            updated_at: now,
        }
    }

    /// Create an empty task map.
    fn empty_map() -> TaskMap {
        Arc::new(DashMap::new())
    }

    /// Create a task map holding a single task.
    fn map_with(id: &str, status: A2ATaskStatus) -> TaskMap {
        let tasks = empty_map();
        tasks.insert(id.to_string(), make_task(id, status));
        tasks
    }

    /// Run `apply` on the task `id` while holding its mutable map guard.
    fn mutate(tasks: &TaskMap, id: &str, apply: impl FnOnce(&mut A2ATask)) {
        let mut entry = tasks.get_mut(id).unwrap();
        apply(entry.value_mut());
    }

    /// Assert that the task `id` currently has status `expected`.
    fn assert_status(tasks: &TaskMap, id: &str, expected: A2ATaskStatus) {
        assert_eq!(tasks.get(id).unwrap().status, expected);
    }

    /// Collect the IDs of all pending tasks, the same way the poll loop does.
    fn pending_ids(tasks: &TaskMap) -> Vec<String> {
        tasks
            .iter()
            .filter(|e| e.status == A2ATaskStatus::Pending)
            .map(|e| e.key().clone())
            .collect()
    }

    /// Count the tasks whose status satisfies `pred`.
    fn count_where(tasks: &TaskMap, pred: impl Fn(&A2ATaskStatus) -> bool) -> usize {
        tasks.iter().filter(|e| pred(&e.status)).count()
    }

    #[test]
    fn test_extract_prompt_structured() {
        let input = serde_json::json!({
            "prompt": "Summarize this code",
            "context": {},
            "mode": "text"
        });
        assert_eq!(extract_prompt(&input), "Summarize this code");
    }

    #[test]
    fn test_extract_prompt_simple_prompt_field() {
        let input = serde_json::json!({"prompt": "Do the thing"});
        assert_eq!(extract_prompt(&input), "Do the thing");
    }

    #[test]
    fn test_extract_prompt_message_field() {
        let input = serde_json::json!({"message": "Hello agent"});
        assert_eq!(extract_prompt(&input), "Hello agent");
    }

    #[test]
    fn test_extract_prompt_string_value() {
        let input = serde_json::json!("Just a string prompt");
        assert_eq!(extract_prompt(&input), "Just a string prompt");
    }

    #[test]
    fn test_extract_prompt_fallback_json() {
        let input = serde_json::json!({"arbitrary": "data", "count": 42});
        let result = extract_prompt(&input);
        assert!(result.contains("arbitrary"));
    }

    #[test]
    fn test_build_task_system_prompt_no_context() {
        let input = serde_json::json!({"prompt": "hello"});
        let prompt = build_task_system_prompt(&input);
        assert!(prompt.contains("A2A protocol"));
        assert!(!prompt.contains("Task Context"));
    }

    #[test]
    fn test_build_task_system_prompt_with_context() {
        let input = serde_json::json!({
            "prompt": "hello",
            "context": {
                "language": "rust",
                "project": "punch"
            }
        });
        let prompt = build_task_system_prompt(&input);
        assert!(prompt.contains("Task Context"));
        assert!(prompt.contains("language"));
        assert!(prompt.contains("rust"));
    }

    #[test]
    fn test_build_task_system_prompt_empty_context() {
        let input = serde_json::json!({
            "prompt": "hello",
            "context": {}
        });
        let prompt = build_task_system_prompt(&input);
        assert!(!prompt.contains("Task Context"));
    }

    #[test]
    fn test_executor_creation() {
        // A real Ring is not easy to build in a unit test, so only the parts
        // that do not need one are exercised here.
        let tasks = empty_map();
        assert_eq!(tasks.len(), 0);
    }

    #[test]
    fn test_task_pending_to_running_transition() {
        let tasks = map_with("task-001", A2ATaskStatus::Pending);

        // Simulate the executor picking up the task.
        mutate(&tasks, "task-001", |task| {
            assert_eq!(task.status, A2ATaskStatus::Pending);
            task.status = A2ATaskStatus::Running;
            task.updated_at = Utc::now();
        });

        assert_status(&tasks, "task-001", A2ATaskStatus::Running);
    }

    #[test]
    fn test_task_running_to_completed_transition() {
        let tasks = map_with("task-002", A2ATaskStatus::Running);

        // Simulate successful completion.
        mutate(&tasks, "task-002", |task| {
            let output = A2ATaskOutput {
                content: "Task result here".to_string(),
                data: None,
                mode: "text".to_string(),
            };
            task.status = A2ATaskStatus::Completed;
            task.output = Some(serde_json::to_value(output).unwrap());
            task.updated_at = Utc::now();
        });

        let entry = tasks.get("task-002").unwrap();
        assert_eq!(entry.status, A2ATaskStatus::Completed);
        assert!(entry.output.is_some());
    }

    #[test]
    fn test_task_running_to_failed_transition() {
        let tasks = map_with("task-003", A2ATaskStatus::Running);

        // Simulate failure.
        mutate(&tasks, "task-003", |task| {
            task.status = A2ATaskStatus::Failed("LLM provider error".to_string());
            task.updated_at = Utc::now();
        });

        let entry = tasks.get("task-003").unwrap();
        assert!(matches!(entry.status, A2ATaskStatus::Failed(ref msg) if msg.contains("LLM provider")));
    }

    #[test]
    fn test_multiple_concurrent_tasks() {
        let tasks = empty_map();

        // Insert multiple pending tasks.
        for i in 0..5 {
            let id = format!("concurrent-{i}");
            tasks.insert(id.clone(), make_task(&id, A2ATaskStatus::Pending));
        }

        // Collect pending IDs (simulating an executor poll).
        let ids = pending_ids(&tasks);
        assert_eq!(ids.len(), 5);

        // Transition all to Running.
        for id in &ids {
            mutate(&tasks, id, |task| task.status = A2ATaskStatus::Running);
        }
        assert_eq!(count_where(&tasks, |s| *s == A2ATaskStatus::Running), 5);

        // Complete the even-indexed ones, fail the odd-indexed ones.
        for (i, id) in ids.iter().enumerate() {
            mutate(&tasks, id, |task| {
                if i % 2 == 0 {
                    task.status = A2ATaskStatus::Completed;
                    task.output = Some(serde_json::json!({"result": "ok"}));
                } else {
                    task.status = A2ATaskStatus::Failed("test error".to_string());
                }
            });
        }

        let completed = count_where(&tasks, |s| *s == A2ATaskStatus::Completed);
        let failed = count_where(&tasks, |s| matches!(s, A2ATaskStatus::Failed(_)));
        assert_eq!(completed, 3);
        assert_eq!(failed, 2);
    }

    #[test]
    fn test_cancelled_task_not_overwritten() {
        let tasks = map_with("task-cancel", A2ATaskStatus::Running);

        // Cancel the task while it's "running".
        mutate(&tasks, "task-cancel", |task| {
            task.status = A2ATaskStatus::Cancelled;
            task.updated_at = Utc::now();
        });

        // Simulate the executor trying to write a result after cancellation.
        mutate(&tasks, "task-cancel", |task| {
            if task.status != A2ATaskStatus::Cancelled {
                task.status = A2ATaskStatus::Completed;
                task.output = Some(serde_json::json!({"result": "should not appear"}));
            }
        });

        let entry = tasks.get("task-cancel").unwrap();
        assert_eq!(entry.status, A2ATaskStatus::Cancelled);
        assert!(entry.output.is_none());
    }

    #[test]
    fn test_completed_task_has_output() {
        let tasks = map_with("task-output", A2ATaskStatus::Running);

        let output = A2ATaskOutput {
            content: "The answer is 42".to_string(),
            data: Some(serde_json::json!({"tokens_used": 100})),
            mode: "text".to_string(),
        };

        mutate(&tasks, "task-output", |task| {
            task.status = A2ATaskStatus::Completed;
            task.output = Some(serde_json::to_value(&output).unwrap());
            task.updated_at = Utc::now();
        });

        let entry = tasks.get("task-output").unwrap();
        assert_eq!(entry.status, A2ATaskStatus::Completed);
        let stored_output = entry.output.as_ref().unwrap();
        assert_eq!(stored_output["content"], "The answer is 42");
        assert_eq!(stored_output["mode"], "text");
    }

    #[test]
    fn test_failed_task_has_error_message() {
        let tasks = map_with("task-err", A2ATaskStatus::Running);

        mutate(&tasks, "task-err", |task| {
            task.status = A2ATaskStatus::Failed("connection timeout to provider".to_string());
            task.updated_at = Utc::now();
        });

        let entry = tasks.get("task-err").unwrap();
        match &entry.status {
            A2ATaskStatus::Failed(msg) => {
                assert!(msg.contains("connection timeout"));
            }
            _ => panic!("expected Failed status"),
        }
    }

    #[tokio::test]
    async fn test_stop_cancellation() {
        let tasks = empty_map();
        let (shutdown_tx, shutdown_rx) = watch::channel(false);

        // Simulate the executor's loop skeleton without a real Ring.
        let mut loop_rx = shutdown_rx.clone();

        let handle = tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = tokio::time::sleep(Duration::from_millis(50)) => {}
                    _ = loop_rx.changed() => {
                        if *loop_rx.borrow() {
                            break;
                        }
                    }
                }
                if *loop_rx.borrow() {
                    break;
                }
            }
        });

        // Let it run briefly.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(!handle.is_finished());

        // Send shutdown.
        let _ = shutdown_tx.send(true);
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(handle.is_finished());

        // Tasks map should still be accessible.
        assert_eq!(tasks.len(), 0);
    }

    #[test]
    fn test_pending_task_skipped_if_already_claimed() {
        let tasks = map_with("task-race", A2ATaskStatus::Pending);

        // First "poll" claims it.
        mutate(&tasks, "task-race", |task| {
            if task.status == A2ATaskStatus::Pending {
                task.status = A2ATaskStatus::Running;
            }
        });

        // Second "poll" should see Running, not Pending.
        assert_status(&tasks, "task-race", A2ATaskStatus::Running);

        // Collect pending — should be empty now.
        assert!(pending_ids(&tasks).is_empty());
    }

    #[test]
    fn test_extract_prompt_with_context_and_prompt() {
        let input = serde_json::json!({
            "prompt": "Analyze this code",
            "context": {
                "language": "rust"
            },
            "mode": "text"
        });
        assert_eq!(extract_prompt(&input), "Analyze this code");
    }

    #[test]
    fn test_extract_prompt_numeric_value() {
        let input = serde_json::json!(42);
        assert_eq!(extract_prompt(&input), "42");
    }

    #[test]
    fn test_extract_prompt_null_value() {
        let input = serde_json::json!(null);
        assert_eq!(extract_prompt(&input), "null");
    }

    #[test]
    fn test_extract_prompt_array_value() {
        let input = serde_json::json!(["a", "b"]);
        assert!(extract_prompt(&input).contains('a'));
    }

    #[test]
    fn test_extract_prompt_empty_object() {
        let input = serde_json::json!({});
        assert!(!extract_prompt(&input).is_empty());
    }

    #[test]
    fn test_extract_prompt_prefers_structured_over_prompt_field() {
        // A2ATaskInput has prompt, context, mode fields.
        let input = serde_json::json!({
            "prompt": "structured prompt",
            "context": {},
            "mode": "text"
        });
        assert_eq!(extract_prompt(&input), "structured prompt");
    }

    #[test]
    fn test_extract_prompt_message_over_json_fallback() {
        let input = serde_json::json!({
            "message": "msg field",
            "other": "data"
        });
        assert_eq!(extract_prompt(&input), "msg field");
    }

    #[test]
    fn test_build_task_system_prompt_with_multiple_context_keys() {
        let input = serde_json::json!({
            "prompt": "do stuff",
            "context": {
                "a": "1",
                "b": "2",
                "c": "3"
            }
        });
        let prompt = build_task_system_prompt(&input);
        assert!(prompt.contains("Task Context"));
        assert!(prompt.contains("**a**"));
        assert!(prompt.contains("**b**"));
        assert!(prompt.contains("**c**"));
    }

    #[test]
    fn test_build_task_system_prompt_null_context() {
        let input = serde_json::json!({
            "prompt": "hello",
            "context": null
        });
        let prompt = build_task_system_prompt(&input);
        assert!(!prompt.contains("Task Context"));
    }

    #[test]
    fn test_build_task_system_prompt_context_is_string() {
        let input = serde_json::json!({
            "prompt": "hello",
            "context": "not an object"
        });
        let prompt = build_task_system_prompt(&input);
        assert!(!prompt.contains("Task Context"));
    }

    #[test]
    fn test_task_lifecycle_pending_running_completed() {
        let tasks = map_with("lifecycle", A2ATaskStatus::Pending);

        // Step 1: Pending -> Running.
        mutate(&tasks, "lifecycle", |task| {
            assert_eq!(task.status, A2ATaskStatus::Pending);
            task.status = A2ATaskStatus::Running;
        });

        // Step 2: Running -> Completed.
        mutate(&tasks, "lifecycle", |task| {
            assert_eq!(task.status, A2ATaskStatus::Running);
            task.status = A2ATaskStatus::Completed;
            task.output = Some(serde_json::json!({"result": "done"}));
        });

        let entry = tasks.get("lifecycle").unwrap();
        assert_eq!(entry.status, A2ATaskStatus::Completed);
        assert!(entry.output.is_some());
    }

    #[test]
    fn test_task_lifecycle_pending_running_failed() {
        let tasks = map_with("fail-life", A2ATaskStatus::Pending);

        mutate(&tasks, "fail-life", |task| task.status = A2ATaskStatus::Running);
        mutate(&tasks, "fail-life", |task| {
            task.status = A2ATaskStatus::Failed("some error".to_string());
        });

        let entry = tasks.get("fail-life").unwrap();
        assert!(matches!(entry.status, A2ATaskStatus::Failed(_)));
    }

    #[test]
    fn test_failed_task_preserves_error_detail() {
        let tasks = map_with("err-detail", A2ATaskStatus::Running);

        let error_msg = "rate limit exceeded: retry after 60s".to_string();
        mutate(&tasks, "err-detail", |task| {
            task.status = A2ATaskStatus::Failed(error_msg.clone());
        });

        let entry = tasks.get("err-detail").unwrap();
        match &entry.status {
            A2ATaskStatus::Failed(msg) => assert_eq!(msg, &error_msg),
            _ => panic!("expected Failed"),
        }
    }

    #[test]
    fn test_concurrent_task_isolation() {
        let tasks = empty_map();

        // Create independent tasks.
        tasks.insert("t1".to_string(), make_task("t1", A2ATaskStatus::Pending));
        tasks.insert("t2".to_string(), make_task("t2", A2ATaskStatus::Running));
        tasks.insert("t3".to_string(), make_task("t3", A2ATaskStatus::Completed));

        // Modifying one doesn't affect the others.
        mutate(&tasks, "t1", |task| task.status = A2ATaskStatus::Running);

        assert_status(&tasks, "t1", A2ATaskStatus::Running);
        assert_status(&tasks, "t2", A2ATaskStatus::Running);
        assert_status(&tasks, "t3", A2ATaskStatus::Completed);
    }

    #[test]
    fn test_task_output_with_structured_data() {
        let output = A2ATaskOutput {
            content: "Result text".to_string(),
            data: Some(serde_json::json!({
                "tokens_used": 500,
                "iterations": 3,
                "tool_calls": 2,
            })),
            mode: "text".to_string(),
        };
        let json = serde_json::to_value(&output).unwrap();
        assert_eq!(json["content"], "Result text");
        assert_eq!(json["data"]["tokens_used"], 500);
        assert_eq!(json["data"]["iterations"], 3);
    }

    #[test]
    fn test_task_removal() {
        let tasks = map_with("rm-task", A2ATaskStatus::Completed);

        assert!(tasks.contains_key("rm-task"));
        tasks.remove("rm-task");
        assert!(!tasks.contains_key("rm-task"));
    }

    #[test]
    fn test_task_updated_at_changes() {
        let tasks = empty_map();
        let task = make_task("time-task", A2ATaskStatus::Pending);
        let original_time = task.updated_at;
        tasks.insert("time-task".to_string(), task);

        // Small sleep to ensure a time difference.
        std::thread::sleep(std::time::Duration::from_millis(10));

        mutate(&tasks, "time-task", |task| {
            task.status = A2ATaskStatus::Running;
            task.updated_at = Utc::now();
        });

        let entry = tasks.get("time-task").unwrap();
        assert!(entry.updated_at >= original_time);
    }
}