1use std::sync::Arc;
8use std::time::Duration;
9
10use chrono::Utc;
11use dashmap::DashMap;
12use tokio::sync::watch;
13use tokio::task::JoinHandle;
14use tracing::{error, info, instrument};
15
16use punch_types::a2a::{A2ATask, A2ATaskInput, A2ATaskOutput, A2ATaskStatus};
17use punch_types::{FighterManifest, WeightClass};
18
19use crate::ring::Ring;
20
/// Delay between scans for `Pending` tasks used by `A2ATaskExecutor::new`;
/// `A2ATaskExecutor::with_poll_interval` lets callers choose their own.
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(500);
23
24pub struct A2ATaskExecutor {
26 ring: Arc<Ring>,
28 tasks: Arc<DashMap<String, A2ATask>>,
30 poll_interval: Duration,
32 shutdown_tx: watch::Sender<bool>,
34 shutdown_rx: watch::Receiver<bool>,
36 handle: Option<JoinHandle<()>>,
38}
39
40impl A2ATaskExecutor {
41 pub fn new(ring: Arc<Ring>, tasks: Arc<DashMap<String, A2ATask>>) -> Self {
44 let (shutdown_tx, shutdown_rx) = watch::channel(false);
45 Self {
46 ring,
47 tasks,
48 poll_interval: DEFAULT_POLL_INTERVAL,
49 shutdown_tx,
50 shutdown_rx,
51 handle: None,
52 }
53 }
54
55 pub fn with_poll_interval(
57 ring: Arc<Ring>,
58 tasks: Arc<DashMap<String, A2ATask>>,
59 poll_interval: Duration,
60 ) -> Self {
61 let (shutdown_tx, shutdown_rx) = watch::channel(false);
62 Self {
63 ring,
64 tasks,
65 poll_interval,
66 shutdown_tx,
67 shutdown_rx,
68 handle: None,
69 }
70 }
71
72 pub fn start(&mut self) {
78 let ring = Arc::clone(&self.ring);
79 let tasks = Arc::clone(&self.tasks);
80 let interval = self.poll_interval;
81 let mut shutdown_rx = self.shutdown_rx.clone();
82
83 let handle = tokio::spawn(async move {
84 info!(
85 poll_interval_ms = interval.as_millis(),
86 "A2A task executor started"
87 );
88
89 loop {
90 tokio::select! {
91 _ = tokio::time::sleep(interval) => {}
92 _ = shutdown_rx.changed() => {
93 if *shutdown_rx.borrow() {
94 info!("A2A task executor received shutdown signal");
95 break;
96 }
97 }
98 }
99
100 if *shutdown_rx.borrow() {
101 break;
102 }
103
104 let pending_ids: Vec<String> = tasks
106 .iter()
107 .filter(|entry| entry.value().status == A2ATaskStatus::Pending)
108 .map(|entry| entry.key().clone())
109 .collect();
110
111 for task_id in pending_ids {
112 let task_input = {
114 let mut entry = match tasks.get_mut(&task_id) {
115 Some(e) => e,
116 None => continue,
117 };
118 if entry.status != A2ATaskStatus::Pending {
120 continue;
121 }
122 entry.status = A2ATaskStatus::Running;
123 entry.updated_at = Utc::now();
124 entry.input.clone()
125 };
126
127 let ring = Arc::clone(&ring);
129 let tasks = Arc::clone(&tasks);
130 let id = task_id.clone();
131
132 tokio::spawn(async move {
133 execute_task(ring, tasks, id, task_input).await;
134 });
135 }
136 }
137
138 info!("A2A task executor stopped");
139 });
140
141 self.handle = Some(handle);
142 }
143
144 pub fn stop(&mut self) {
146 let _ = self.shutdown_tx.send(true);
147 if let Some(handle) = self.handle.take() {
148 handle.abort();
149 }
150 info!("A2A task executor stop requested");
151 }
152
153 pub fn is_running(&self) -> bool {
155 self.handle.as_ref().is_some_and(|h| !h.is_finished())
156 }
157}
158
159impl Drop for A2ATaskExecutor {
160 fn drop(&mut self) {
161 let _ = self.shutdown_tx.send(true);
163 if let Some(handle) = self.handle.take() {
164 handle.abort();
165 }
166 }
167}
168
169#[instrument(skip(ring, tasks, task_input), fields(task_id = %task_id))]
172async fn execute_task(
173 ring: Arc<Ring>,
174 tasks: Arc<DashMap<String, A2ATask>>,
175 task_id: String,
176 task_input: serde_json::Value,
177) {
178 let prompt = extract_prompt(&task_input);
180
181 let manifest = FighterManifest {
183 name: format!("a2a-task-{}", &task_id[..8.min(task_id.len())]),
184 description: format!("Temporary fighter for A2A task {task_id}"),
185 model: ring.config().default_model.clone(),
186 system_prompt: build_task_system_prompt(&task_input),
187 capabilities: Vec::new(),
188 weight_class: WeightClass::Middleweight,
189 tenant_id: None,
190 };
191
192 let fighter_id = ring.spawn_fighter(manifest).await;
194
195 let result = ring.send_message(&fighter_id, prompt).await;
197
198 match result {
200 Ok(loop_result) => {
201 if let Some(mut entry) = tasks.get_mut(&task_id) {
202 if entry.status == A2ATaskStatus::Cancelled {
204 info!(task_id = %task_id, "task was cancelled during execution, skipping update");
205 } else {
206 let output = A2ATaskOutput {
207 content: loop_result.response.clone(),
208 data: Some(serde_json::json!({
209 "tokens_used": loop_result.usage.total(),
210 "iterations": loop_result.iterations,
211 "tool_calls": loop_result.tool_calls_made,
212 })),
213 mode: "text".to_string(),
214 };
215 entry.status = A2ATaskStatus::Completed;
216 entry.output =
217 Some(serde_json::to_value(output).unwrap_or(serde_json::json!({})));
218 entry.updated_at = Utc::now();
219 info!(task_id = %task_id, "A2A task completed successfully");
220 }
221 }
222 }
223 Err(e) => {
224 error!(task_id = %task_id, error = %e, "A2A task execution failed");
225 if let Some(mut entry) = tasks.get_mut(&task_id)
226 && entry.status != A2ATaskStatus::Cancelled
227 {
228 entry.status = A2ATaskStatus::Failed(e.to_string());
229 entry.updated_at = Utc::now();
230 }
231 }
232 }
233
234 ring.kill_fighter(&fighter_id);
236}
237
238fn extract_prompt(input: &serde_json::Value) -> String {
243 if let Ok(structured) = serde_json::from_value::<A2ATaskInput>(input.clone()) {
245 return structured.prompt;
246 }
247
248 if let Some(prompt) = input.get("prompt").and_then(|v| v.as_str()) {
250 return prompt.to_string();
251 }
252
253 if let Some(msg) = input.get("message").and_then(|v| v.as_str()) {
255 return msg.to_string();
256 }
257
258 if let Some(s) = input.as_str() {
260 return s.to_string();
261 }
262
263 input.to_string()
264}
265
266fn build_task_system_prompt(input: &serde_json::Value) -> String {
269 let mut prompt = "You are an AI agent executing a task received via the A2A protocol. \
270 Complete the task thoroughly and return a clear, actionable response."
271 .to_string();
272
273 if let Some(context) = input.get("context")
275 && let Some(obj) = context.as_object()
276 && !obj.is_empty()
277 {
278 prompt.push_str("\n\n## Task Context\n");
279 for (key, value) in obj {
280 prompt.push_str(&format!("- **{key}**: {value}\n"));
281 }
282 }
283
284 prompt
285}
286
#[cfg(test)]
mod tests {
    // `A2ATaskStatus` and the other parent-module imports arrive via this glob.
    use super::*;

    /// Builds a task with a fixed `{"prompt": "hello world"}` input and the given status.
    fn make_task(id: &str, status: A2ATaskStatus) -> A2ATask {
        let now = Utc::now();
        A2ATask {
            id: id.to_string(),
            status,
            input: serde_json::json!({"prompt": "hello world"}),
            output: None,
            created_at: now,
            updated_at: now,
        }
    }

    #[test]
    fn test_extract_prompt_structured() {
        let input = serde_json::json!({
            "prompt": "Summarize this code",
            "context": {},
            "mode": "text"
        });
        assert_eq!(extract_prompt(&input), "Summarize this code");
    }

    #[test]
    fn test_extract_prompt_simple_prompt_field() {
        let input = serde_json::json!({"prompt": "Do the thing"});
        assert_eq!(extract_prompt(&input), "Do the thing");
    }

    #[test]
    fn test_extract_prompt_message_field() {
        let input = serde_json::json!({"message": "Hello agent"});
        assert_eq!(extract_prompt(&input), "Hello agent");
    }

    #[test]
    fn test_extract_prompt_string_value() {
        let input = serde_json::json!("Just a string prompt");
        assert_eq!(extract_prompt(&input), "Just a string prompt");
    }

    #[test]
    fn test_extract_prompt_fallback_json() {
        let input = serde_json::json!({"arbitrary": "data", "count": 42});
        let result = extract_prompt(&input);
        assert!(result.contains("arbitrary"));
    }

    #[test]
    fn test_build_task_system_prompt_no_context() {
        let input = serde_json::json!({"prompt": "hello"});
        let prompt = build_task_system_prompt(&input);
        assert!(prompt.contains("A2A protocol"));
        assert!(!prompt.contains("Task Context"));
    }

    #[test]
    fn test_build_task_system_prompt_with_context() {
        let input = serde_json::json!({
            "prompt": "hello",
            "context": {
                "language": "rust",
                "project": "punch"
            }
        });
        let prompt = build_task_system_prompt(&input);
        assert!(prompt.contains("Task Context"));
        assert!(prompt.contains("language"));
        assert!(prompt.contains("rust"));
    }

    #[test]
    fn test_build_task_system_prompt_empty_context() {
        let input = serde_json::json!({
            "prompt": "hello",
            "context": {}
        });
        let prompt = build_task_system_prompt(&input);
        assert!(!prompt.contains("Task Context"));
    }

    #[test]
    fn test_executor_creation() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        assert_eq!(tasks.len(), 0);
    }

    #[test]
    fn test_task_pending_to_running_transition() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("task-001", A2ATaskStatus::Pending);
        tasks.insert("task-001".to_string(), task);

        {
            let mut entry = tasks.get_mut("task-001").unwrap();
            assert_eq!(entry.status, A2ATaskStatus::Pending);
            entry.status = A2ATaskStatus::Running;
            entry.updated_at = Utc::now();
        }

        let entry = tasks.get("task-001").unwrap();
        assert_eq!(entry.status, A2ATaskStatus::Running);
    }

    #[test]
    fn test_task_running_to_completed_transition() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("task-002", A2ATaskStatus::Running);
        tasks.insert("task-002".to_string(), task);

        {
            let mut entry = tasks.get_mut("task-002").unwrap();
            let output = A2ATaskOutput {
                content: "Task result here".to_string(),
                data: None,
                mode: "text".to_string(),
            };
            entry.status = A2ATaskStatus::Completed;
            entry.output = Some(serde_json::to_value(output).unwrap());
            entry.updated_at = Utc::now();
        }

        let entry = tasks.get("task-002").unwrap();
        assert_eq!(entry.status, A2ATaskStatus::Completed);
        assert!(entry.output.is_some());
    }

    #[test]
    fn test_task_running_to_failed_transition() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("task-003", A2ATaskStatus::Running);
        tasks.insert("task-003".to_string(), task);

        {
            let mut entry = tasks.get_mut("task-003").unwrap();
            entry.status = A2ATaskStatus::Failed("LLM provider error".to_string());
            entry.updated_at = Utc::now();
        }

        let entry = tasks.get("task-003").unwrap();
        assert!(
            matches!(entry.status, A2ATaskStatus::Failed(ref msg) if msg.contains("LLM provider"))
        );
    }

    #[test]
    fn test_multiple_concurrent_tasks() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());

        for i in 0..5 {
            let task = make_task(&format!("concurrent-{i}"), A2ATaskStatus::Pending);
            tasks.insert(format!("concurrent-{i}"), task);
        }

        let pending_ids: Vec<String> = tasks
            .iter()
            .filter(|e| e.status == A2ATaskStatus::Pending)
            .map(|e| e.key().clone())
            .collect();

        assert_eq!(pending_ids.len(), 5);

        for id in &pending_ids {
            let mut entry = tasks.get_mut(id).unwrap();
            entry.status = A2ATaskStatus::Running;
        }

        let running_count = tasks
            .iter()
            .filter(|e| e.status == A2ATaskStatus::Running)
            .count();
        assert_eq!(running_count, 5);

        for (i, id) in pending_ids.iter().enumerate() {
            let mut entry = tasks.get_mut(id).unwrap();
            if i % 2 == 0 {
                entry.status = A2ATaskStatus::Completed;
                entry.output = Some(serde_json::json!({"result": "ok"}));
            } else {
                entry.status = A2ATaskStatus::Failed("test error".to_string());
            }
        }

        let completed = tasks
            .iter()
            .filter(|e| e.status == A2ATaskStatus::Completed)
            .count();
        let failed = tasks
            .iter()
            .filter(|e| matches!(e.status, A2ATaskStatus::Failed(_)))
            .count();
        assert_eq!(completed, 3);
        assert_eq!(failed, 2);
    }

    #[test]
    fn test_cancelled_task_not_overwritten() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("task-cancel", A2ATaskStatus::Running);
        tasks.insert("task-cancel".to_string(), task);

        {
            let mut entry = tasks.get_mut("task-cancel").unwrap();
            entry.status = A2ATaskStatus::Cancelled;
            entry.updated_at = Utc::now();
        }

        {
            // Same guard `execute_task` applies before writing a result.
            let mut entry = tasks.get_mut("task-cancel").unwrap();
            if entry.status != A2ATaskStatus::Cancelled {
                entry.status = A2ATaskStatus::Completed;
                entry.output = Some(serde_json::json!({"result": "should not appear"}));
            }
        }

        let entry = tasks.get("task-cancel").unwrap();
        assert_eq!(entry.status, A2ATaskStatus::Cancelled);
        assert!(entry.output.is_none());
    }

    #[test]
    fn test_completed_task_has_output() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("task-output", A2ATaskStatus::Running);
        tasks.insert("task-output".to_string(), task);

        let output = A2ATaskOutput {
            content: "The answer is 42".to_string(),
            data: Some(serde_json::json!({"tokens_used": 100})),
            mode: "text".to_string(),
        };

        {
            let mut entry = tasks.get_mut("task-output").unwrap();
            entry.status = A2ATaskStatus::Completed;
            entry.output = Some(serde_json::to_value(&output).unwrap());
            entry.updated_at = Utc::now();
        }

        let entry = tasks.get("task-output").unwrap();
        assert_eq!(entry.status, A2ATaskStatus::Completed);
        let stored_output = entry.output.as_ref().unwrap();
        assert_eq!(stored_output["content"], "The answer is 42");
        assert_eq!(stored_output["mode"], "text");
    }

    #[test]
    fn test_failed_task_has_error_message() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("task-err", A2ATaskStatus::Running);
        tasks.insert("task-err".to_string(), task);

        {
            let mut entry = tasks.get_mut("task-err").unwrap();
            entry.status = A2ATaskStatus::Failed("connection timeout to provider".to_string());
            entry.updated_at = Utc::now();
        }

        let entry = tasks.get("task-err").unwrap();
        match &entry.status {
            A2ATaskStatus::Failed(msg) => {
                assert!(msg.contains("connection timeout"));
            }
            _ => panic!("expected Failed status"),
        }
    }

    #[tokio::test]
    async fn test_stop_cancellation() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let (shutdown_tx, shutdown_rx) = watch::channel(false);

        // Mirrors the shape of the executor's polling loop with a short interval.
        let mut shutdown_rx_clone = shutdown_rx.clone();

        let handle = tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = tokio::time::sleep(Duration::from_millis(50)) => {}
                    _ = shutdown_rx_clone.changed() => {
                        if *shutdown_rx_clone.borrow() {
                            break;
                        }
                    }
                }
                if *shutdown_rx_clone.borrow() {
                    break;
                }
            }
        });

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(!handle.is_finished());

        let _ = shutdown_tx.send(true);
        // Await the loop itself (bounded by a generous timeout) rather than
        // sleeping a fixed 100 ms and polling `is_finished`, which is flaky on a
        // loaded machine.
        tokio::time::timeout(Duration::from_secs(5), handle)
            .await
            .expect("polling loop did not exit after shutdown signal")
            .expect("polling loop panicked");

        assert_eq!(tasks.len(), 0);
    }

    #[test]
    fn test_pending_task_skipped_if_already_claimed() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("task-race", A2ATaskStatus::Pending);
        tasks.insert("task-race".to_string(), task);

        {
            let mut entry = tasks.get_mut("task-race").unwrap();
            if entry.status == A2ATaskStatus::Pending {
                entry.status = A2ATaskStatus::Running;
            }
        }

        {
            let entry = tasks.get("task-race").unwrap();
            assert_eq!(entry.status, A2ATaskStatus::Running);
        }

        let pending: Vec<String> = tasks
            .iter()
            .filter(|e| e.status == A2ATaskStatus::Pending)
            .map(|e| e.key().clone())
            .collect();
        assert!(pending.is_empty());
    }

    #[test]
    fn test_extract_prompt_with_context_and_prompt() {
        let input = serde_json::json!({
            "prompt": "Analyze this code",
            "context": {
                "language": "rust"
            },
            "mode": "text"
        });
        assert_eq!(extract_prompt(&input), "Analyze this code");
    }

    #[test]
    fn test_extract_prompt_numeric_value() {
        let input = serde_json::json!(42);
        let result = extract_prompt(&input);
        assert_eq!(result, "42");
    }

    #[test]
    fn test_extract_prompt_null_value() {
        let input = serde_json::json!(null);
        let result = extract_prompt(&input);
        assert_eq!(result, "null");
    }

    #[test]
    fn test_extract_prompt_array_value() {
        let input = serde_json::json!(["a", "b"]);
        let result = extract_prompt(&input);
        assert!(result.contains('a'));
    }

    #[test]
    fn test_extract_prompt_empty_object() {
        let input = serde_json::json!({});
        let result = extract_prompt(&input);
        assert!(!result.is_empty());
    }

    #[test]
    fn test_extract_prompt_prefers_structured_over_prompt_field() {
        let input = serde_json::json!({
            "prompt": "structured prompt",
            "context": {},
            "mode": "text"
        });
        assert_eq!(extract_prompt(&input), "structured prompt");
    }

    #[test]
    fn test_extract_prompt_message_over_json_fallback() {
        let input = serde_json::json!({
            "message": "msg field",
            "other": "data"
        });
        assert_eq!(extract_prompt(&input), "msg field");
    }

    #[test]
    fn test_build_task_system_prompt_with_multiple_context_keys() {
        let input = serde_json::json!({
            "prompt": "do stuff",
            "context": {
                "a": "1",
                "b": "2",
                "c": "3"
            }
        });
        let prompt = build_task_system_prompt(&input);
        assert!(prompt.contains("Task Context"));
        assert!(prompt.contains("**a**"));
        assert!(prompt.contains("**b**"));
        assert!(prompt.contains("**c**"));
    }

    #[test]
    fn test_build_task_system_prompt_null_context() {
        let input = serde_json::json!({
            "prompt": "hello",
            "context": null
        });
        let prompt = build_task_system_prompt(&input);
        assert!(!prompt.contains("Task Context"));
    }

    #[test]
    fn test_build_task_system_prompt_context_is_string() {
        let input = serde_json::json!({
            "prompt": "hello",
            "context": "not an object"
        });
        let prompt = build_task_system_prompt(&input);
        assert!(!prompt.contains("Task Context"));
    }

    #[test]
    fn test_task_lifecycle_pending_running_completed() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("lifecycle", A2ATaskStatus::Pending);
        tasks.insert("lifecycle".to_string(), task);

        {
            let mut entry = tasks.get_mut("lifecycle").unwrap();
            assert_eq!(entry.status, A2ATaskStatus::Pending);
            entry.status = A2ATaskStatus::Running;
        }

        {
            let mut entry = tasks.get_mut("lifecycle").unwrap();
            assert_eq!(entry.status, A2ATaskStatus::Running);
            entry.status = A2ATaskStatus::Completed;
            entry.output = Some(serde_json::json!({"result": "done"}));
        }

        let entry = tasks.get("lifecycle").unwrap();
        assert_eq!(entry.status, A2ATaskStatus::Completed);
        assert!(entry.output.is_some());
    }

    #[test]
    fn test_task_lifecycle_pending_running_failed() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("fail-life", A2ATaskStatus::Pending);
        tasks.insert("fail-life".to_string(), task);

        {
            let mut entry = tasks.get_mut("fail-life").unwrap();
            entry.status = A2ATaskStatus::Running;
        }
        {
            let mut entry = tasks.get_mut("fail-life").unwrap();
            entry.status = A2ATaskStatus::Failed("some error".to_string());
        }

        let entry = tasks.get("fail-life").unwrap();
        assert!(matches!(entry.status, A2ATaskStatus::Failed(_)));
    }

    #[test]
    fn test_failed_task_preserves_error_detail() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("err-detail", A2ATaskStatus::Running);
        tasks.insert("err-detail".to_string(), task);

        let error_msg = "rate limit exceeded: retry after 60s".to_string();
        {
            let mut entry = tasks.get_mut("err-detail").unwrap();
            entry.status = A2ATaskStatus::Failed(error_msg.clone());
        }

        let entry = tasks.get("err-detail").unwrap();
        match &entry.status {
            A2ATaskStatus::Failed(msg) => assert_eq!(msg, &error_msg),
            _ => panic!("expected Failed"),
        }
    }

    #[test]
    fn test_concurrent_task_isolation() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());

        tasks.insert("t1".to_string(), make_task("t1", A2ATaskStatus::Pending));
        tasks.insert("t2".to_string(), make_task("t2", A2ATaskStatus::Running));
        tasks.insert("t3".to_string(), make_task("t3", A2ATaskStatus::Completed));

        {
            let mut entry = tasks.get_mut("t1").unwrap();
            entry.status = A2ATaskStatus::Running;
        }

        assert_eq!(tasks.get("t1").unwrap().status, A2ATaskStatus::Running);
        assert_eq!(tasks.get("t2").unwrap().status, A2ATaskStatus::Running);
        assert_eq!(tasks.get("t3").unwrap().status, A2ATaskStatus::Completed);
    }

    #[test]
    fn test_task_output_with_structured_data() {
        let output = A2ATaskOutput {
            content: "Result text".to_string(),
            data: Some(serde_json::json!({
                "tokens_used": 500,
                "iterations": 3,
                "tool_calls": 2,
            })),
            mode: "text".to_string(),
        };
        let json = serde_json::to_value(&output).unwrap();
        assert_eq!(json["content"], "Result text");
        assert_eq!(json["data"]["tokens_used"], 500);
        assert_eq!(json["data"]["iterations"], 3);
    }

    #[test]
    fn test_task_removal() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        tasks.insert(
            "rm-task".to_string(),
            make_task("rm-task", A2ATaskStatus::Completed),
        );

        assert!(tasks.contains_key("rm-task"));
        tasks.remove("rm-task");
        assert!(!tasks.contains_key("rm-task"));
    }

    #[test]
    fn test_task_updated_at_changes() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("time-task", A2ATaskStatus::Pending);
        let original_time = task.updated_at;
        tasks.insert("time-task".to_string(), task);

        std::thread::sleep(std::time::Duration::from_millis(10));

        {
            let mut entry = tasks.get_mut("time-task").unwrap();
            entry.status = A2ATaskStatus::Running;
            entry.updated_at = Utc::now();
        }

        let entry = tasks.get("time-task").unwrap();
        assert!(entry.updated_at >= original_time);
    }
}