1use std::sync::Arc;
8use std::time::Duration;
9
10use chrono::Utc;
11use dashmap::DashMap;
12use tokio::sync::watch;
13use tokio::task::JoinHandle;
14use tracing::{error, info, instrument};
15
16use punch_types::a2a::{A2ATask, A2ATaskInput, A2ATaskOutput, A2ATaskStatus};
17use punch_types::{FighterManifest, WeightClass};
18
19use crate::ring::Ring;
20
/// Default delay between scans of the task store for `Pending` tasks, used by
/// [`A2ATaskExecutor::new`].
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(500);
23
/// Background worker that picks up `Pending` A2A tasks from a shared store and
/// runs each one on a temporary fighter spawned in the [`Ring`].
pub struct A2ATaskExecutor {
    /// Ring used to spawn, message, and kill the per-task fighters.
    ring: Arc<Ring>,
    /// Shared task store keyed by task id; scanned for `Pending` tasks and
    /// updated with each task's final status and output.
    tasks: Arc<DashMap<String, A2ATask>>,
    /// Delay between successive scans of `tasks`.
    poll_interval: Duration,
    /// Publishes the shutdown flag (`true` = stop) to the polling loop.
    shutdown_tx: watch::Sender<bool>,
    /// Receiver that is cloned into the polling loop when it is started.
    shutdown_rx: watch::Receiver<bool>,
    /// Join handle of the polling loop; `None` before `start` and after `stop`.
    handle: Option<JoinHandle<()>>,
}
39
40impl A2ATaskExecutor {
41 pub fn new(ring: Arc<Ring>, tasks: Arc<DashMap<String, A2ATask>>) -> Self {
44 let (shutdown_tx, shutdown_rx) = watch::channel(false);
45 Self {
46 ring,
47 tasks,
48 poll_interval: DEFAULT_POLL_INTERVAL,
49 shutdown_tx,
50 shutdown_rx,
51 handle: None,
52 }
53 }
54
55 pub fn with_poll_interval(
57 ring: Arc<Ring>,
58 tasks: Arc<DashMap<String, A2ATask>>,
59 poll_interval: Duration,
60 ) -> Self {
61 let (shutdown_tx, shutdown_rx) = watch::channel(false);
62 Self {
63 ring,
64 tasks,
65 poll_interval,
66 shutdown_tx,
67 shutdown_rx,
68 handle: None,
69 }
70 }
71
72 pub fn start(&mut self) {
78 let ring = Arc::clone(&self.ring);
79 let tasks = Arc::clone(&self.tasks);
80 let interval = self.poll_interval;
81 let mut shutdown_rx = self.shutdown_rx.clone();
82
83 let handle = tokio::spawn(async move {
84 info!(
85 poll_interval_ms = interval.as_millis(),
86 "A2A task executor started"
87 );
88
89 loop {
90 tokio::select! {
91 _ = tokio::time::sleep(interval) => {}
92 _ = shutdown_rx.changed() => {
93 if *shutdown_rx.borrow() {
94 info!("A2A task executor received shutdown signal");
95 break;
96 }
97 }
98 }
99
100 if *shutdown_rx.borrow() {
101 break;
102 }
103
104 let pending_ids: Vec<String> = tasks
106 .iter()
107 .filter(|entry| entry.value().status == A2ATaskStatus::Pending)
108 .map(|entry| entry.key().clone())
109 .collect();
110
111 for task_id in pending_ids {
112 let task_input = {
114 let mut entry = match tasks.get_mut(&task_id) {
115 Some(e) => e,
116 None => continue,
117 };
118 if entry.status != A2ATaskStatus::Pending {
120 continue;
121 }
122 entry.status = A2ATaskStatus::Running;
123 entry.updated_at = Utc::now();
124 entry.input.clone()
125 };
126
127 let ring = Arc::clone(&ring);
129 let tasks = Arc::clone(&tasks);
130 let id = task_id.clone();
131
132 tokio::spawn(async move {
133 execute_task(ring, tasks, id, task_input).await;
134 });
135 }
136 }
137
138 info!("A2A task executor stopped");
139 });
140
141 self.handle = Some(handle);
142 }
143
144 pub fn stop(&mut self) {
146 let _ = self.shutdown_tx.send(true);
147 if let Some(handle) = self.handle.take() {
148 handle.abort();
149 }
150 info!("A2A task executor stop requested");
151 }
152
153 pub fn is_running(&self) -> bool {
155 self.handle
156 .as_ref()
157 .is_some_and(|h| !h.is_finished())
158 }
159}
160
161impl Drop for A2ATaskExecutor {
162 fn drop(&mut self) {
163 let _ = self.shutdown_tx.send(true);
165 if let Some(handle) = self.handle.take() {
166 handle.abort();
167 }
168 }
169}
170
171#[instrument(skip(ring, tasks, task_input), fields(task_id = %task_id))]
174async fn execute_task(
175 ring: Arc<Ring>,
176 tasks: Arc<DashMap<String, A2ATask>>,
177 task_id: String,
178 task_input: serde_json::Value,
179) {
180 let prompt = extract_prompt(&task_input);
182
183 let manifest = FighterManifest {
185 name: format!("a2a-task-{}", &task_id[..8.min(task_id.len())]),
186 description: format!("Temporary fighter for A2A task {task_id}"),
187 model: ring.config().default_model.clone(),
188 system_prompt: build_task_system_prompt(&task_input),
189 capabilities: Vec::new(),
190 weight_class: WeightClass::Middleweight,
191 tenant_id: None,
192 };
193
194 let fighter_id = ring.spawn_fighter(manifest).await;
196
197 let result = ring.send_message(&fighter_id, prompt).await;
199
200 match result {
202 Ok(loop_result) => {
203 if let Some(mut entry) = tasks.get_mut(&task_id) {
204 if entry.status == A2ATaskStatus::Cancelled {
206 info!(task_id = %task_id, "task was cancelled during execution, skipping update");
207 } else {
208 let output = A2ATaskOutput {
209 content: loop_result.response.clone(),
210 data: Some(serde_json::json!({
211 "tokens_used": loop_result.usage.total(),
212 "iterations": loop_result.iterations,
213 "tool_calls": loop_result.tool_calls_made,
214 })),
215 mode: "text".to_string(),
216 };
217 entry.status = A2ATaskStatus::Completed;
218 entry.output =
219 Some(serde_json::to_value(output).unwrap_or(serde_json::json!({})));
220 entry.updated_at = Utc::now();
221 info!(task_id = %task_id, "A2A task completed successfully");
222 }
223 }
224 }
225 Err(e) => {
226 error!(task_id = %task_id, error = %e, "A2A task execution failed");
227 if let Some(mut entry) = tasks.get_mut(&task_id)
228 && entry.status != A2ATaskStatus::Cancelled
229 {
230 entry.status = A2ATaskStatus::Failed(e.to_string());
231 entry.updated_at = Utc::now();
232 }
233 }
234 }
235
236 ring.kill_fighter(&fighter_id);
238}
239
240fn extract_prompt(input: &serde_json::Value) -> String {
245 if let Ok(structured) = serde_json::from_value::<A2ATaskInput>(input.clone()) {
247 return structured.prompt;
248 }
249
250 if let Some(prompt) = input.get("prompt").and_then(|v| v.as_str()) {
252 return prompt.to_string();
253 }
254
255 if let Some(msg) = input.get("message").and_then(|v| v.as_str()) {
257 return msg.to_string();
258 }
259
260 if let Some(s) = input.as_str() {
262 return s.to_string();
263 }
264
265 input.to_string()
266}
267
268fn build_task_system_prompt(input: &serde_json::Value) -> String {
271 let mut prompt = "You are an AI agent executing a task received via the A2A protocol. \
272 Complete the task thoroughly and return a clear, actionable response."
273 .to_string();
274
275 if let Some(context) = input.get("context")
277 && let Some(obj) = context.as_object()
278 && !obj.is_empty()
279 {
280 prompt.push_str("\n\n## Task Context\n");
281 for (key, value) in obj {
282 prompt.push_str(&format!("- **{key}**: {value}\n"));
283 }
284 }
285
286 prompt
287}
288
#[cfg(test)]
mod tests {
    //! Tests for prompt extraction, system-prompt construction, and the task
    //! status transitions performed on the shared store.
    //!
    //! The transition tests drive the `DashMap` by hand, mirroring the updates
    //! made by `A2ATaskExecutor::start` and `execute_task`. Those code paths
    //! need a live `Ring`, which these tests do not construct.

    use super::*;
    use punch_types::a2a::A2ATaskStatus;

    fn make_task(id: &str, status: A2ATaskStatus) -> A2ATask {
        let now = Utc::now();
        A2ATask {
            id: id.to_string(),
            status,
            input: serde_json::json!({"prompt": "hello world"}),
            output: None,
            created_at: now,
            updated_at: now,
        }
    }

    #[test]
    fn test_extract_prompt_structured() {
        let input = serde_json::json!({
            "prompt": "Summarize this code",
            "context": {},
            "mode": "text"
        });
        assert_eq!(extract_prompt(&input), "Summarize this code");
    }

    #[test]
    fn test_extract_prompt_simple_prompt_field() {
        let input = serde_json::json!({"prompt": "Do the thing"});
        assert_eq!(extract_prompt(&input), "Do the thing");
    }

    #[test]
    fn test_extract_prompt_message_field() {
        let input = serde_json::json!({"message": "Hello agent"});
        assert_eq!(extract_prompt(&input), "Hello agent");
    }

    #[test]
    fn test_extract_prompt_string_value() {
        let input = serde_json::json!("Just a string prompt");
        assert_eq!(extract_prompt(&input), "Just a string prompt");
    }

    #[test]
    fn test_extract_prompt_fallback_json() {
        let input = serde_json::json!({"arbitrary": "data", "count": 42});
        let result = extract_prompt(&input);
        assert!(result.contains("arbitrary"));
    }

    #[test]
    fn test_build_task_system_prompt_no_context() {
        let input = serde_json::json!({"prompt": "hello"});
        let prompt = build_task_system_prompt(&input);
        assert!(prompt.contains("A2A protocol"));
        assert!(!prompt.contains("Task Context"));
    }

    #[test]
    fn test_build_task_system_prompt_with_context() {
        let input = serde_json::json!({
            "prompt": "hello",
            "context": {
                "language": "rust",
                "project": "punch"
            }
        });
        let prompt = build_task_system_prompt(&input);
        assert!(prompt.contains("Task Context"));
        assert!(prompt.contains("language"));
        assert!(prompt.contains("rust"));
    }

    #[test]
    fn test_build_task_system_prompt_empty_context() {
        let input = serde_json::json!({
            "prompt": "hello",
            "context": {}
        });
        let prompt = build_task_system_prompt(&input);
        assert!(!prompt.contains("Task Context"));
    }

    #[test]
    fn test_executor_creation() {
        // Constructing an `A2ATaskExecutor` needs a `Ring`; this only checks
        // that the shared store the executor would poll starts out empty.
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        assert_eq!(tasks.len(), 0);
    }

    #[test]
    fn test_task_pending_to_running_transition() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("task-001", A2ATaskStatus::Pending);
        tasks.insert("task-001".to_string(), task);

        {
            let mut entry = tasks.get_mut("task-001").unwrap();
            assert_eq!(entry.status, A2ATaskStatus::Pending);
            entry.status = A2ATaskStatus::Running;
            entry.updated_at = Utc::now();
        }

        let entry = tasks.get("task-001").unwrap();
        assert_eq!(entry.status, A2ATaskStatus::Running);
    }

    #[test]
    fn test_task_running_to_completed_transition() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("task-002", A2ATaskStatus::Running);
        tasks.insert("task-002".to_string(), task);

        {
            let mut entry = tasks.get_mut("task-002").unwrap();
            let output = A2ATaskOutput {
                content: "Task result here".to_string(),
                data: None,
                mode: "text".to_string(),
            };
            entry.status = A2ATaskStatus::Completed;
            entry.output = Some(serde_json::to_value(output).unwrap());
            entry.updated_at = Utc::now();
        }

        let entry = tasks.get("task-002").unwrap();
        assert_eq!(entry.status, A2ATaskStatus::Completed);
        assert!(entry.output.is_some());
    }

    #[test]
    fn test_task_running_to_failed_transition() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("task-003", A2ATaskStatus::Running);
        tasks.insert("task-003".to_string(), task);

        {
            let mut entry = tasks.get_mut("task-003").unwrap();
            entry.status = A2ATaskStatus::Failed("LLM provider error".to_string());
            entry.updated_at = Utc::now();
        }

        let entry = tasks.get("task-003").unwrap();
        assert!(matches!(entry.status, A2ATaskStatus::Failed(ref msg) if msg.contains("LLM provider")));
    }

    #[test]
    fn test_multiple_concurrent_tasks() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());

        for i in 0..5 {
            let task = make_task(&format!("concurrent-{i}"), A2ATaskStatus::Pending);
            tasks.insert(format!("concurrent-{i}"), task);
        }

        let pending_ids: Vec<String> = tasks
            .iter()
            .filter(|e| e.status == A2ATaskStatus::Pending)
            .map(|e| e.key().clone())
            .collect();

        assert_eq!(pending_ids.len(), 5);

        for id in &pending_ids {
            let mut entry = tasks.get_mut(id).unwrap();
            entry.status = A2ATaskStatus::Running;
        }

        let running_count = tasks
            .iter()
            .filter(|e| e.status == A2ATaskStatus::Running)
            .count();
        assert_eq!(running_count, 5);

        for (i, id) in pending_ids.iter().enumerate() {
            let mut entry = tasks.get_mut(id).unwrap();
            if i % 2 == 0 {
                entry.status = A2ATaskStatus::Completed;
                entry.output = Some(serde_json::json!({"result": "ok"}));
            } else {
                entry.status = A2ATaskStatus::Failed("test error".to_string());
            }
        }

        let completed = tasks
            .iter()
            .filter(|e| e.status == A2ATaskStatus::Completed)
            .count();
        let failed = tasks
            .iter()
            .filter(|e| matches!(e.status, A2ATaskStatus::Failed(_)))
            .count();
        assert_eq!(completed, 3);
        assert_eq!(failed, 2);
    }

    #[test]
    fn test_cancelled_task_not_overwritten() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("task-cancel", A2ATaskStatus::Running);
        tasks.insert("task-cancel".to_string(), task);

        {
            let mut entry = tasks.get_mut("task-cancel").unwrap();
            entry.status = A2ATaskStatus::Cancelled;
            entry.updated_at = Utc::now();
        }

        // Same guard `execute_task` applies before recording a result.
        {
            let mut entry = tasks.get_mut("task-cancel").unwrap();
            if entry.status != A2ATaskStatus::Cancelled {
                entry.status = A2ATaskStatus::Completed;
                entry.output = Some(serde_json::json!({"result": "should not appear"}));
            }
        }

        let entry = tasks.get("task-cancel").unwrap();
        assert_eq!(entry.status, A2ATaskStatus::Cancelled);
        assert!(entry.output.is_none());
    }

    #[test]
    fn test_completed_task_has_output() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("task-output", A2ATaskStatus::Running);
        tasks.insert("task-output".to_string(), task);

        let output = A2ATaskOutput {
            content: "The answer is 42".to_string(),
            data: Some(serde_json::json!({"tokens_used": 100})),
            mode: "text".to_string(),
        };

        {
            let mut entry = tasks.get_mut("task-output").unwrap();
            entry.status = A2ATaskStatus::Completed;
            entry.output = Some(serde_json::to_value(&output).unwrap());
            entry.updated_at = Utc::now();
        }

        let entry = tasks.get("task-output").unwrap();
        assert_eq!(entry.status, A2ATaskStatus::Completed);
        let stored_output = entry.output.as_ref().unwrap();
        assert_eq!(stored_output["content"], "The answer is 42");
        assert_eq!(stored_output["mode"], "text");
    }

    #[test]
    fn test_failed_task_has_error_message() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("task-err", A2ATaskStatus::Running);
        tasks.insert("task-err".to_string(), task);

        {
            let mut entry = tasks.get_mut("task-err").unwrap();
            entry.status = A2ATaskStatus::Failed("connection timeout to provider".to_string());
            entry.updated_at = Utc::now();
        }

        let entry = tasks.get("task-err").unwrap();
        match &entry.status {
            A2ATaskStatus::Failed(msg) => {
                assert!(msg.contains("connection timeout"));
            }
            _ => panic!("expected Failed status"),
        }
    }

    #[tokio::test]
    async fn test_stop_cancellation() {
        // Mirrors the shutdown-aware select loop in `A2ATaskExecutor::start`.
        let (shutdown_tx, mut shutdown_rx) = watch::channel(false);

        let handle = tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = tokio::time::sleep(Duration::from_millis(50)) => {}
                    _ = shutdown_rx.changed() => {
                        if *shutdown_rx.borrow() {
                            break;
                        }
                    }
                }
                if *shutdown_rx.borrow() {
                    break;
                }
            }
        });

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(!handle.is_finished());

        let _ = shutdown_tx.send(true);
        // Await the loop under a generous bound instead of sleeping a fixed time
        // and polling `is_finished`, which is timing-dependent and can flake.
        tokio::time::timeout(Duration::from_secs(5), handle)
            .await
            .expect("loop did not observe the shutdown signal in time")
            .expect("loop task panicked");
    }

    #[test]
    fn test_pending_task_skipped_if_already_claimed() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("task-race", A2ATaskStatus::Pending);
        tasks.insert("task-race".to_string(), task);

        {
            let mut entry = tasks.get_mut("task-race").unwrap();
            if entry.status == A2ATaskStatus::Pending {
                entry.status = A2ATaskStatus::Running;
            }
        }

        {
            let entry = tasks.get("task-race").unwrap();
            assert_eq!(entry.status, A2ATaskStatus::Running);
        }

        let pending: Vec<String> = tasks
            .iter()
            .filter(|e| e.status == A2ATaskStatus::Pending)
            .map(|e| e.key().clone())
            .collect();
        assert!(pending.is_empty());
    }

    #[test]
    fn test_extract_prompt_with_context_and_prompt() {
        let input = serde_json::json!({
            "prompt": "Analyze this code",
            "context": {
                "language": "rust"
            },
            "mode": "text"
        });
        assert_eq!(extract_prompt(&input), "Analyze this code");
    }

    #[test]
    fn test_extract_prompt_numeric_value() {
        let input = serde_json::json!(42);
        let result = extract_prompt(&input);
        assert_eq!(result, "42");
    }

    #[test]
    fn test_extract_prompt_null_value() {
        let input = serde_json::json!(null);
        let result = extract_prompt(&input);
        assert_eq!(result, "null");
    }

    #[test]
    fn test_extract_prompt_array_value() {
        let input = serde_json::json!(["a", "b"]);
        let result = extract_prompt(&input);
        assert!(result.contains('a'));
    }

    #[test]
    fn test_extract_prompt_empty_object() {
        let input = serde_json::json!({});
        let result = extract_prompt(&input);
        assert!(!result.is_empty());
    }

    #[test]
    fn test_extract_prompt_prefers_structured_over_prompt_field() {
        let input = serde_json::json!({
            "prompt": "structured prompt",
            "context": {},
            "mode": "text"
        });
        assert_eq!(extract_prompt(&input), "structured prompt");
    }

    #[test]
    fn test_extract_prompt_message_over_json_fallback() {
        let input = serde_json::json!({
            "message": "msg field",
            "other": "data"
        });
        assert_eq!(extract_prompt(&input), "msg field");
    }

    #[test]
    fn test_build_task_system_prompt_with_multiple_context_keys() {
        let input = serde_json::json!({
            "prompt": "do stuff",
            "context": {
                "a": "1",
                "b": "2",
                "c": "3"
            }
        });
        let prompt = build_task_system_prompt(&input);
        assert!(prompt.contains("Task Context"));
        assert!(prompt.contains("**a**"));
        assert!(prompt.contains("**b**"));
        assert!(prompt.contains("**c**"));
    }

    #[test]
    fn test_build_task_system_prompt_null_context() {
        let input = serde_json::json!({
            "prompt": "hello",
            "context": null
        });
        let prompt = build_task_system_prompt(&input);
        assert!(!prompt.contains("Task Context"));
    }

    #[test]
    fn test_build_task_system_prompt_context_is_string() {
        let input = serde_json::json!({
            "prompt": "hello",
            "context": "not an object"
        });
        let prompt = build_task_system_prompt(&input);
        assert!(!prompt.contains("Task Context"));
    }

    #[test]
    fn test_task_lifecycle_pending_running_completed() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("lifecycle", A2ATaskStatus::Pending);
        tasks.insert("lifecycle".to_string(), task);

        {
            let mut entry = tasks.get_mut("lifecycle").unwrap();
            assert_eq!(entry.status, A2ATaskStatus::Pending);
            entry.status = A2ATaskStatus::Running;
        }

        {
            let mut entry = tasks.get_mut("lifecycle").unwrap();
            assert_eq!(entry.status, A2ATaskStatus::Running);
            entry.status = A2ATaskStatus::Completed;
            entry.output = Some(serde_json::json!({"result": "done"}));
        }

        let entry = tasks.get("lifecycle").unwrap();
        assert_eq!(entry.status, A2ATaskStatus::Completed);
        assert!(entry.output.is_some());
    }

    #[test]
    fn test_task_lifecycle_pending_running_failed() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("fail-life", A2ATaskStatus::Pending);
        tasks.insert("fail-life".to_string(), task);

        {
            let mut entry = tasks.get_mut("fail-life").unwrap();
            entry.status = A2ATaskStatus::Running;
        }
        {
            let mut entry = tasks.get_mut("fail-life").unwrap();
            entry.status = A2ATaskStatus::Failed("some error".to_string());
        }

        let entry = tasks.get("fail-life").unwrap();
        assert!(matches!(entry.status, A2ATaskStatus::Failed(_)));
    }

    #[test]
    fn test_failed_task_preserves_error_detail() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("err-detail", A2ATaskStatus::Running);
        tasks.insert("err-detail".to_string(), task);

        let error_msg = "rate limit exceeded: retry after 60s".to_string();
        {
            let mut entry = tasks.get_mut("err-detail").unwrap();
            entry.status = A2ATaskStatus::Failed(error_msg.clone());
        }

        let entry = tasks.get("err-detail").unwrap();
        match &entry.status {
            A2ATaskStatus::Failed(msg) => assert_eq!(msg, &error_msg),
            _ => panic!("expected Failed"),
        }
    }

    #[test]
    fn test_concurrent_task_isolation() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());

        tasks.insert("t1".to_string(), make_task("t1", A2ATaskStatus::Pending));
        tasks.insert("t2".to_string(), make_task("t2", A2ATaskStatus::Running));
        tasks.insert("t3".to_string(), make_task("t3", A2ATaskStatus::Completed));

        {
            let mut entry = tasks.get_mut("t1").unwrap();
            entry.status = A2ATaskStatus::Running;
        }

        assert_eq!(tasks.get("t1").unwrap().status, A2ATaskStatus::Running);
        assert_eq!(tasks.get("t2").unwrap().status, A2ATaskStatus::Running);
        assert_eq!(tasks.get("t3").unwrap().status, A2ATaskStatus::Completed);
    }

    #[test]
    fn test_task_output_with_structured_data() {
        let output = A2ATaskOutput {
            content: "Result text".to_string(),
            data: Some(serde_json::json!({
                "tokens_used": 500,
                "iterations": 3,
                "tool_calls": 2,
            })),
            mode: "text".to_string(),
        };
        let json = serde_json::to_value(&output).unwrap();
        assert_eq!(json["content"], "Result text");
        assert_eq!(json["data"]["tokens_used"], 500);
        assert_eq!(json["data"]["iterations"], 3);
    }

    #[test]
    fn test_task_removal() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        tasks.insert("rm-task".to_string(), make_task("rm-task", A2ATaskStatus::Completed));

        assert!(tasks.contains_key("rm-task"));
        tasks.remove("rm-task");
        assert!(!tasks.contains_key("rm-task"));
    }

    #[test]
    fn test_task_updated_at_changes() {
        let tasks: Arc<DashMap<String, A2ATask>> = Arc::new(DashMap::new());
        let task = make_task("time-task", A2ATaskStatus::Pending);
        let original_time = task.updated_at;
        tasks.insert("time-task".to_string(), task);

        std::thread::sleep(std::time::Duration::from_millis(10));

        {
            let mut entry = tasks.get_mut("time-task").unwrap();
            entry.status = A2ATaskStatus::Running;
            entry.updated_at = Utc::now();
        }

        let entry = tasks.get("time-task").unwrap();
        assert!(entry.updated_at >= original_time);
    }
}