// punch_kernel/a2a_executor.rs

//! A2A Task Executor — picks up pending A2A tasks and runs them through fighters.
//!
//! The [`A2ATaskExecutor`] polls a shared [`DashMap`] of tasks for entries in
//! [`Pending`](punch_types::a2a::A2ATaskStatus::Pending) status, spawns a
//! temporary fighter for each, executes the task, and writes the result back.

use std::fmt::Write as _;
use std::sync::Arc;
use std::time::Duration;

use chrono::Utc;
use dashmap::DashMap;
use tokio::sync::watch;
use tokio::task::JoinHandle;
use tracing::{error, info, instrument};

use punch_types::a2a::{A2ATask, A2ATaskInput, A2ATaskOutput, A2ATaskStatus};
use punch_types::{FighterManifest, WeightClass};

use crate::ring::Ring;

/// Interval between task-map scans used by [`A2ATaskExecutor::new`]:
/// half a second.
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(500);

24/// The A2A task executor: polls for pending tasks and executes them via fighters.
25pub struct A2ATaskExecutor {
26    /// Reference to the Ring for spawning fighters and sending messages.
27    ring: Arc<Ring>,
28    /// Shared task map (same instance as the HTTP handlers use).
29    tasks: Arc<DashMap<String, A2ATask>>,
30    /// Polling interval.
31    poll_interval: Duration,
32    /// Shutdown signal sender.
33    shutdown_tx: watch::Sender<bool>,
34    /// Shutdown signal receiver (cloned for the polling task).
35    shutdown_rx: watch::Receiver<bool>,
36    /// Handle to the background polling task.
37    handle: Option<JoinHandle<()>>,
38}
39
40impl A2ATaskExecutor {
41    /// Create a new executor that will poll the given task map and use the Ring
42    /// to spawn fighters for execution.
43    pub fn new(ring: Arc<Ring>, tasks: Arc<DashMap<String, A2ATask>>) -> Self {
44        let (shutdown_tx, shutdown_rx) = watch::channel(false);
45        Self {
46            ring,
47            tasks,
48            poll_interval: DEFAULT_POLL_INTERVAL,
49            shutdown_tx,
50            shutdown_rx,
51            handle: None,
52        }
53    }
54
55    /// Create a new executor with a custom polling interval.
56    pub fn with_poll_interval(
57        ring: Arc<Ring>,
58        tasks: Arc<DashMap<String, A2ATask>>,
59        poll_interval: Duration,
60    ) -> Self {
61        let (shutdown_tx, shutdown_rx) = watch::channel(false);
62        Self {
63            ring,
64            tasks,
65            poll_interval,
66            shutdown_tx,
67            shutdown_rx,
68            handle: None,
69        }
70    }
71
72    /// Start the background polling loop.
73    ///
74    /// Spawns a tokio task that polls the DashMap every `poll_interval` for
75    /// [`Pending`](A2ATaskStatus::Pending) tasks. Each pending task is picked
76    /// up and executed in its own spawned task.
77    pub fn start(&mut self) {
78        let ring = Arc::clone(&self.ring);
79        let tasks = Arc::clone(&self.tasks);
80        let interval = self.poll_interval;
81        let mut shutdown_rx = self.shutdown_rx.clone();
82
83        let handle = tokio::spawn(async move {
84            info!(
85                poll_interval_ms = interval.as_millis(),
86                "A2A task executor started"
87            );
88
89            loop {
90                tokio::select! {
91                    _ = tokio::time::sleep(interval) => {}
92                    _ = shutdown_rx.changed() => {
93                        if *shutdown_rx.borrow() {
94                            info!("A2A task executor received shutdown signal");
95                            break;
96                        }
97                    }
98                }
99
100                if *shutdown_rx.borrow() {
101                    break;
102                }
103
104                // Collect pending task IDs (avoid holding DashMap guards across await).
105                let pending_ids: Vec<String> = tasks
106                    .iter()
107                    .filter(|entry| entry.value().status == A2ATaskStatus::Pending)
108                    .map(|entry| entry.key().clone())
109                    .collect();
110
111                for task_id in pending_ids {
112                    // Transition to Running (atomic check-and-set).
113                    let task_input = {
114                        let mut entry = match tasks.get_mut(&task_id) {
115                            Some(e) => e,
116                            None => continue,
117                        };
118                        // Double-check it's still Pending (another poll may have grabbed it).
119                        if entry.status != A2ATaskStatus::Pending {
120                            continue;
121                        }
122                        entry.status = A2ATaskStatus::Running;
123                        entry.updated_at = Utc::now();
124                        entry.input.clone()
125                    };
126
127                    // Spawn execution in a separate task so we don't block polling.
128                    let ring = Arc::clone(&ring);
129                    let tasks = Arc::clone(&tasks);
130                    let id = task_id.clone();
131
132                    tokio::spawn(async move {
133                        execute_task(ring, tasks, id, task_input).await;
134                    });
135                }
136            }
137
138            info!("A2A task executor stopped");
139        });
140
141        self.handle = Some(handle);
142    }
143
144    /// Stop the polling loop.
145    pub fn stop(&mut self) {
146        let _ = self.shutdown_tx.send(true);
147        if let Some(handle) = self.handle.take() {
148            handle.abort();
149        }
150        info!("A2A task executor stop requested");
151    }
152
153    /// Returns `true` if the executor is currently running.
154    pub fn is_running(&self) -> bool {
155        self.handle.as_ref().is_some_and(|h| !h.is_finished())
156    }
157}
158
159impl Drop for A2ATaskExecutor {
160    fn drop(&mut self) {
161        // Best-effort shutdown on drop.
162        let _ = self.shutdown_tx.send(true);
163        if let Some(handle) = self.handle.take() {
164            handle.abort();
165        }
166    }
167}
168
169/// Execute a single A2A task: spawn a fighter, send the prompt, collect the
170/// result, and update the DashMap.
171#[instrument(skip(ring, tasks, task_input), fields(task_id = %task_id))]
172async fn execute_task(
173    ring: Arc<Ring>,
174    tasks: Arc<DashMap<String, A2ATask>>,
175    task_id: String,
176    task_input: serde_json::Value,
177) {
178    // Extract the prompt from the input.
179    let prompt = extract_prompt(&task_input);
180
181    // Build a temporary fighter manifest for this task.
182    let manifest = FighterManifest {
183        name: format!("a2a-task-{}", &task_id[..8.min(task_id.len())]),
184        description: format!("Temporary fighter for A2A task {task_id}"),
185        model: ring.config().default_model.clone(),
186        system_prompt: build_task_system_prompt(&task_input),
187        capabilities: Vec::new(),
188        weight_class: WeightClass::Middleweight,
189        tenant_id: None,
190    };
191
192    // Spawn the fighter.
193    let fighter_id = ring.spawn_fighter(manifest).await;
194
195    // Send the prompt and collect the result.
196    let result = ring.send_message(&fighter_id, prompt).await;
197
198    // Update the task based on the result.
199    match result {
200        Ok(loop_result) => {
201            if let Some(mut entry) = tasks.get_mut(&task_id) {
202                // Don't overwrite a Cancelled task.
203                if entry.status == A2ATaskStatus::Cancelled {
204                    info!(task_id = %task_id, "task was cancelled during execution, skipping update");
205                } else {
206                    let output = A2ATaskOutput {
207                        content: loop_result.response.clone(),
208                        data: Some(serde_json::json!({
209                            "tokens_used": loop_result.usage.total(),
210                            "iterations": loop_result.iterations,
211                            "tool_calls": loop_result.tool_calls_made,
212                        })),
213                        mode: "text".to_string(),
214                    };
215                    entry.status = A2ATaskStatus::Completed;
216                    entry.output =
217                        Some(serde_json::to_value(output).unwrap_or(serde_json::json!({})));
218                    entry.updated_at = Utc::now();
219                    info!(task_id = %task_id, "A2A task completed successfully");
220                }
221            }
222        }
223        Err(e) => {
224            error!(task_id = %task_id, error = %e, "A2A task execution failed");
225            if let Some(mut entry) = tasks.get_mut(&task_id)
226                && entry.status != A2ATaskStatus::Cancelled
227            {
228                entry.status = A2ATaskStatus::Failed(e.to_string());
229                entry.updated_at = Utc::now();
230            }
231        }
232    }
233
234    // Kill the temporary fighter.
235    ring.kill_fighter(&fighter_id);
236}
237
238/// Extract the prompt text from a task input JSON value.
239///
240/// Tries to parse as [`A2ATaskInput`] first, then falls back to looking for a
241/// "prompt" field, and finally uses the JSON as a string.
242fn extract_prompt(input: &serde_json::Value) -> String {
243    // Try structured A2ATaskInput.
244    if let Ok(structured) = serde_json::from_value::<A2ATaskInput>(input.clone()) {
245        return structured.prompt;
246    }
247
248    // Try a "prompt" field directly.
249    if let Some(prompt) = input.get("prompt").and_then(|v| v.as_str()) {
250        return prompt.to_string();
251    }
252
253    // Try a "message" field.
254    if let Some(msg) = input.get("message").and_then(|v| v.as_str()) {
255        return msg.to_string();
256    }
257
258    // Fall back to the JSON as a string.
259    if let Some(s) = input.as_str() {
260        return s.to_string();
261    }
262
263    input.to_string()
264}
265
266/// Build a system prompt for the task fighter, incorporating any context from
267/// the task input.
268fn build_task_system_prompt(input: &serde_json::Value) -> String {
269    let mut prompt = "You are an AI agent executing a task received via the A2A protocol. \
270                      Complete the task thoroughly and return a clear, actionable response."
271        .to_string();
272
273    // If the input has a context object, include it.
274    if let Some(context) = input.get("context")
275        && let Some(obj) = context.as_object()
276        && !obj.is_empty()
277    {
278        prompt.push_str("\n\n## Task Context\n");
279        for (key, value) in obj {
280            prompt.push_str(&format!("- **{key}**: {value}\n"));
281        }
282    }
283
284    prompt
285}
286
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    //! Tests for the prompt helpers and for the task-state transitions the
    //! executor applies to the shared DashMap. A real `Ring` is not easy to
    //! build in unit tests, so executor behaviour is simulated on the map.

    use super::*;
    use punch_types::a2a::A2ATaskStatus;

    /// Build a task with the given id/status and a `{"prompt": "hello world"}` input.
    fn make_task(id: &str, status: A2ATaskStatus) -> A2ATask {
        let now = Utc::now();
        A2ATask {
            id: id.to_string(),
            status,
            input: serde_json::json!({"prompt": "hello world"}),
            output: None,
            created_at: now,
            updated_at: now,
        }
    }

    /// Fresh, empty shared task map.
    fn new_map() -> Arc<DashMap<String, A2ATask>> {
        Arc::new(DashMap::new())
    }

    #[test]
    fn test_extract_prompt_structured() {
        let input = serde_json::json!({
            "prompt": "Summarize this code",
            "context": {},
            "mode": "text"
        });
        assert_eq!(extract_prompt(&input), "Summarize this code");
    }

    #[test]
    fn test_extract_prompt_simple_prompt_field() {
        let input = serde_json::json!({"prompt": "Do the thing"});
        assert_eq!(extract_prompt(&input), "Do the thing");
    }

    #[test]
    fn test_extract_prompt_message_field() {
        let input = serde_json::json!({"message": "Hello agent"});
        assert_eq!(extract_prompt(&input), "Hello agent");
    }

    #[test]
    fn test_extract_prompt_string_value() {
        let input = serde_json::json!("Just a string prompt");
        assert_eq!(extract_prompt(&input), "Just a string prompt");
    }

    #[test]
    fn test_extract_prompt_fallback_json() {
        let input = serde_json::json!({"arbitrary": "data", "count": 42});
        assert!(extract_prompt(&input).contains("arbitrary"));
    }

    #[test]
    fn test_build_task_system_prompt_no_context() {
        let input = serde_json::json!({"prompt": "hello"});
        let prompt = build_task_system_prompt(&input);
        assert!(prompt.contains("A2A protocol"));
        assert!(!prompt.contains("Task Context"));
    }

    #[test]
    fn test_build_task_system_prompt_with_context() {
        let input = serde_json::json!({
            "prompt": "hello",
            "context": {
                "language": "rust",
                "project": "punch"
            }
        });
        let prompt = build_task_system_prompt(&input);
        assert!(prompt.contains("Task Context"));
        assert!(prompt.contains("language"));
        assert!(prompt.contains("rust"));
    }

    #[test]
    fn test_build_task_system_prompt_empty_context() {
        let input = serde_json::json!({
            "prompt": "hello",
            "context": {}
        });
        assert!(!build_task_system_prompt(&input).contains("Task Context"));
    }

    #[test]
    fn test_executor_creation() {
        // Building an executor needs a real `Ring`, so check the pieces that
        // don't: the default poll interval and an empty shared map.
        assert_eq!(DEFAULT_POLL_INTERVAL, Duration::from_millis(500));
        let tasks = new_map();
        assert!(tasks.is_empty());
    }

    #[test]
    fn test_task_pending_to_running_transition() {
        let tasks = new_map();
        tasks.insert(
            "task-001".to_string(),
            make_task("task-001", A2ATaskStatus::Pending),
        );

        // Simulate the executor picking up the task.
        {
            let mut entry = tasks.get_mut("task-001").unwrap();
            assert_eq!(entry.status, A2ATaskStatus::Pending);
            entry.status = A2ATaskStatus::Running;
            entry.updated_at = Utc::now();
        }

        assert_eq!(tasks.get("task-001").unwrap().status, A2ATaskStatus::Running);
    }

    #[test]
    fn test_task_running_to_completed_transition() {
        let tasks = new_map();
        tasks.insert(
            "task-002".to_string(),
            make_task("task-002", A2ATaskStatus::Running),
        );

        // Simulate successful completion.
        {
            let mut entry = tasks.get_mut("task-002").unwrap();
            let output = A2ATaskOutput {
                content: "Task result here".to_string(),
                data: None,
                mode: "text".to_string(),
            };
            entry.status = A2ATaskStatus::Completed;
            entry.output = Some(serde_json::to_value(output).unwrap());
            entry.updated_at = Utc::now();
        }

        let entry = tasks.get("task-002").unwrap();
        assert_eq!(entry.status, A2ATaskStatus::Completed);
        assert!(entry.output.is_some());
    }

    #[test]
    fn test_task_running_to_failed_transition() {
        let tasks = new_map();
        tasks.insert(
            "task-003".to_string(),
            make_task("task-003", A2ATaskStatus::Running),
        );

        // Simulate failure.
        {
            let mut entry = tasks.get_mut("task-003").unwrap();
            entry.status = A2ATaskStatus::Failed("LLM provider error".to_string());
            entry.updated_at = Utc::now();
        }

        let entry = tasks.get("task-003").unwrap();
        assert!(
            matches!(entry.status, A2ATaskStatus::Failed(ref msg) if msg.contains("LLM provider"))
        );
    }

    #[test]
    fn test_multiple_concurrent_tasks() {
        let tasks = new_map();
        for i in 0..5 {
            let id = format!("concurrent-{i}");
            tasks.insert(id.clone(), make_task(&id, A2ATaskStatus::Pending));
        }

        // Collect pending IDs (simulating an executor poll).
        let pending_ids: Vec<String> = tasks
            .iter()
            .filter(|e| e.status == A2ATaskStatus::Pending)
            .map(|e| e.key().clone())
            .collect();
        assert_eq!(pending_ids.len(), 5);

        // Transition all to Running.
        for id in &pending_ids {
            tasks.get_mut(id).unwrap().status = A2ATaskStatus::Running;
        }
        let running_count = tasks
            .iter()
            .filter(|e| e.status == A2ATaskStatus::Running)
            .count();
        assert_eq!(running_count, 5);

        // Complete the even-indexed tasks, fail the odd ones.
        for (i, id) in pending_ids.iter().enumerate() {
            let mut entry = tasks.get_mut(id).unwrap();
            if i % 2 == 0 {
                entry.status = A2ATaskStatus::Completed;
                entry.output = Some(serde_json::json!({"result": "ok"}));
            } else {
                entry.status = A2ATaskStatus::Failed("test error".to_string());
            }
        }

        let completed = tasks
            .iter()
            .filter(|e| e.status == A2ATaskStatus::Completed)
            .count();
        let failed = tasks
            .iter()
            .filter(|e| matches!(e.status, A2ATaskStatus::Failed(_)))
            .count();
        assert_eq!(completed, 3);
        assert_eq!(failed, 2);
    }

    #[test]
    fn test_cancelled_task_not_overwritten() {
        let tasks = new_map();
        tasks.insert(
            "task-cancel".to_string(),
            make_task("task-cancel", A2ATaskStatus::Running),
        );

        // Cancel the task while it's "running".
        {
            let mut entry = tasks.get_mut("task-cancel").unwrap();
            entry.status = A2ATaskStatus::Cancelled;
            entry.updated_at = Utc::now();
        }

        // Simulate the executor trying to write a result after cancellation.
        {
            let mut entry = tasks.get_mut("task-cancel").unwrap();
            if entry.status != A2ATaskStatus::Cancelled {
                entry.status = A2ATaskStatus::Completed;
                entry.output = Some(serde_json::json!({"result": "should not appear"}));
            }
        }

        let entry = tasks.get("task-cancel").unwrap();
        assert_eq!(entry.status, A2ATaskStatus::Cancelled);
        assert!(entry.output.is_none());
    }

    #[test]
    fn test_completed_task_has_output() {
        let tasks = new_map();
        tasks.insert(
            "task-output".to_string(),
            make_task("task-output", A2ATaskStatus::Running),
        );

        let output = A2ATaskOutput {
            content: "The answer is 42".to_string(),
            data: Some(serde_json::json!({"tokens_used": 100})),
            mode: "text".to_string(),
        };

        {
            let mut entry = tasks.get_mut("task-output").unwrap();
            entry.status = A2ATaskStatus::Completed;
            entry.output = Some(serde_json::to_value(&output).unwrap());
            entry.updated_at = Utc::now();
        }

        let entry = tasks.get("task-output").unwrap();
        assert_eq!(entry.status, A2ATaskStatus::Completed);
        let stored_output = entry.output.as_ref().unwrap();
        assert_eq!(stored_output["content"], "The answer is 42");
        assert_eq!(stored_output["mode"], "text");
    }

    #[test]
    fn test_failed_task_has_error_message() {
        let tasks = new_map();
        tasks.insert(
            "task-err".to_string(),
            make_task("task-err", A2ATaskStatus::Running),
        );

        {
            let mut entry = tasks.get_mut("task-err").unwrap();
            entry.status = A2ATaskStatus::Failed("connection timeout to provider".to_string());
            entry.updated_at = Utc::now();
        }

        let entry = tasks.get("task-err").unwrap();
        match &entry.status {
            A2ATaskStatus::Failed(msg) => assert!(msg.contains("connection timeout")),
            _ => panic!("expected Failed status"),
        }
    }

    #[tokio::test]
    async fn test_stop_cancellation() {
        let tasks = new_map();
        let (shutdown_tx, shutdown_rx) = watch::channel(false);
        let mut loop_rx = shutdown_rx.clone();

        // Mirror of the executor's polling loop, without a Ring.
        let handle = tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = tokio::time::sleep(Duration::from_millis(50)) => {}
                    changed = loop_rx.changed() => {
                        if changed.is_err() || *loop_rx.borrow() {
                            break;
                        }
                    }
                }
                if *loop_rx.borrow() {
                    break;
                }
            }
        });

        // Without a shutdown signal the loop keeps running.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(!handle.is_finished());

        // After the signal it must exit. Await the handle (bounded by a generous
        // timeout) instead of sleeping a fixed time, so the test isn't flaky.
        shutdown_tx.send(true).expect("a receiver is still alive");
        tokio::time::timeout(Duration::from_secs(5), handle)
            .await
            .expect("polling loop did not stop after shutdown")
            .expect("polling loop panicked");

        // The task map is untouched by shutdown.
        assert!(tasks.is_empty());
    }

    #[test]
    fn test_pending_task_skipped_if_already_claimed() {
        let tasks = new_map();
        tasks.insert(
            "task-race".to_string(),
            make_task("task-race", A2ATaskStatus::Pending),
        );

        // First "poll" claims it.
        {
            let mut entry = tasks.get_mut("task-race").unwrap();
            if entry.status == A2ATaskStatus::Pending {
                entry.status = A2ATaskStatus::Running;
            }
        }

        // Second "poll" sees Running, not Pending.
        assert_eq!(tasks.get("task-race").unwrap().status, A2ATaskStatus::Running);

        // Nothing is pending any more.
        let pending: Vec<String> = tasks
            .iter()
            .filter(|e| e.status == A2ATaskStatus::Pending)
            .map(|e| e.key().clone())
            .collect();
        assert!(pending.is_empty());
    }

    #[test]
    fn test_extract_prompt_with_context_and_prompt() {
        let input = serde_json::json!({
            "prompt": "Analyze this code",
            "context": {
                "language": "rust"
            },
            "mode": "text"
        });
        assert_eq!(extract_prompt(&input), "Analyze this code");
    }

    #[test]
    fn test_extract_prompt_numeric_value() {
        let input = serde_json::json!(42);
        assert_eq!(extract_prompt(&input), "42");
    }

    #[test]
    fn test_extract_prompt_null_value() {
        let input = serde_json::json!(null);
        assert_eq!(extract_prompt(&input), "null");
    }

    #[test]
    fn test_extract_prompt_array_value() {
        let input = serde_json::json!(["a", "b"]);
        assert!(extract_prompt(&input).contains('a'));
    }

    #[test]
    fn test_extract_prompt_empty_object() {
        let input = serde_json::json!({});
        assert!(!extract_prompt(&input).is_empty());
    }

    #[test]
    fn test_extract_prompt_prefers_structured_over_prompt_field() {
        // A2ATaskInput has prompt, context, mode fields.
        let input = serde_json::json!({
            "prompt": "structured prompt",
            "context": {},
            "mode": "text"
        });
        assert_eq!(extract_prompt(&input), "structured prompt");
    }

    #[test]
    fn test_extract_prompt_message_over_json_fallback() {
        let input = serde_json::json!({
            "message": "msg field",
            "other": "data"
        });
        assert_eq!(extract_prompt(&input), "msg field");
    }

    #[test]
    fn test_build_task_system_prompt_with_multiple_context_keys() {
        let input = serde_json::json!({
            "prompt": "do stuff",
            "context": {
                "a": "1",
                "b": "2",
                "c": "3"
            }
        });
        let prompt = build_task_system_prompt(&input);
        assert!(prompt.contains("Task Context"));
        for key in ["**a**", "**b**", "**c**"] {
            assert!(prompt.contains(key), "missing {key}");
        }
    }

    #[test]
    fn test_build_task_system_prompt_null_context() {
        let input = serde_json::json!({
            "prompt": "hello",
            "context": null
        });
        assert!(!build_task_system_prompt(&input).contains("Task Context"));
    }

    #[test]
    fn test_build_task_system_prompt_context_is_string() {
        let input = serde_json::json!({
            "prompt": "hello",
            "context": "not an object"
        });
        assert!(!build_task_system_prompt(&input).contains("Task Context"));
    }

    #[test]
    fn test_task_lifecycle_pending_running_completed() {
        let tasks = new_map();
        tasks.insert(
            "lifecycle".to_string(),
            make_task("lifecycle", A2ATaskStatus::Pending),
        );

        // Step 1: Pending -> Running.
        {
            let mut entry = tasks.get_mut("lifecycle").unwrap();
            assert_eq!(entry.status, A2ATaskStatus::Pending);
            entry.status = A2ATaskStatus::Running;
        }

        // Step 2: Running -> Completed.
        {
            let mut entry = tasks.get_mut("lifecycle").unwrap();
            assert_eq!(entry.status, A2ATaskStatus::Running);
            entry.status = A2ATaskStatus::Completed;
            entry.output = Some(serde_json::json!({"result": "done"}));
        }

        let entry = tasks.get("lifecycle").unwrap();
        assert_eq!(entry.status, A2ATaskStatus::Completed);
        assert!(entry.output.is_some());
    }

    #[test]
    fn test_task_lifecycle_pending_running_failed() {
        let tasks = new_map();
        tasks.insert(
            "fail-life".to_string(),
            make_task("fail-life", A2ATaskStatus::Pending),
        );

        tasks.get_mut("fail-life").unwrap().status = A2ATaskStatus::Running;
        tasks.get_mut("fail-life").unwrap().status =
            A2ATaskStatus::Failed("some error".to_string());

        let entry = tasks.get("fail-life").unwrap();
        assert!(matches!(entry.status, A2ATaskStatus::Failed(_)));
    }

    #[test]
    fn test_failed_task_preserves_error_detail() {
        let tasks = new_map();
        tasks.insert(
            "err-detail".to_string(),
            make_task("err-detail", A2ATaskStatus::Running),
        );

        let error_msg = "rate limit exceeded: retry after 60s".to_string();
        tasks.get_mut("err-detail").unwrap().status = A2ATaskStatus::Failed(error_msg.clone());

        let entry = tasks.get("err-detail").unwrap();
        match &entry.status {
            A2ATaskStatus::Failed(msg) => assert_eq!(msg, &error_msg),
            _ => panic!("expected Failed"),
        }
    }

    #[test]
    fn test_concurrent_task_isolation() {
        let tasks = new_map();
        tasks.insert("t1".to_string(), make_task("t1", A2ATaskStatus::Pending));
        tasks.insert("t2".to_string(), make_task("t2", A2ATaskStatus::Running));
        tasks.insert("t3".to_string(), make_task("t3", A2ATaskStatus::Completed));

        // Modifying one entry must not affect the others.
        tasks.get_mut("t1").unwrap().status = A2ATaskStatus::Running;

        assert_eq!(tasks.get("t1").unwrap().status, A2ATaskStatus::Running);
        assert_eq!(tasks.get("t2").unwrap().status, A2ATaskStatus::Running);
        assert_eq!(tasks.get("t3").unwrap().status, A2ATaskStatus::Completed);
    }

    #[test]
    fn test_task_output_with_structured_data() {
        let output = A2ATaskOutput {
            content: "Result text".to_string(),
            data: Some(serde_json::json!({
                "tokens_used": 500,
                "iterations": 3,
                "tool_calls": 2,
            })),
            mode: "text".to_string(),
        };
        let json = serde_json::to_value(&output).unwrap();
        assert_eq!(json["content"], "Result text");
        assert_eq!(json["data"]["tokens_used"], 500);
        assert_eq!(json["data"]["iterations"], 3);
    }

    #[test]
    fn test_task_removal() {
        let tasks = new_map();
        tasks.insert(
            "rm-task".to_string(),
            make_task("rm-task", A2ATaskStatus::Completed),
        );

        assert!(tasks.contains_key("rm-task"));
        tasks.remove("rm-task");
        assert!(!tasks.contains_key("rm-task"));
    }

    #[test]
    fn test_task_updated_at_changes() {
        let tasks = new_map();
        let task = make_task("time-task", A2ATaskStatus::Pending);
        let original_time = task.updated_at;
        tasks.insert("time-task".to_string(), task);

        // No sleep needed: the assertion is `>=`, which holds for equal timestamps.
        {
            let mut entry = tasks.get_mut("time-task").unwrap();
            entry.status = A2ATaskStatus::Running;
            entry.updated_at = Utc::now();
        }

        assert!(tasks.get("time-task").unwrap().updated_at >= original_time);
    }
}