1use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Instant;
11
12use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
13use a2a_protocol_types::params::MessageSendParams;
14use a2a_protocol_types::responses::SendMessageResponse;
15use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};
16
17use crate::error::{ServerError, ServerResult};
18use crate::request_context::RequestContext;
19use crate::streaming::EventQueueWriter;
20
21use super::helpers::{build_call_context, validate_id};
22use super::{CancellationEntry, RequestHandler, SendMessageResult};
23
24fn json_byte_len(value: &serde_json::Value) -> serde_json::Result<usize> {
26 struct CountWriter(usize);
27 impl std::io::Write for CountWriter {
28 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
29 self.0 += buf.len();
30 Ok(buf.len())
31 }
32 fn flush(&mut self) -> std::io::Result<()> {
33 Ok(())
34 }
35 }
36 let mut w = CountWriter(0);
37 serde_json::to_writer(&mut w, value)?;
38 Ok(w.0)
39}
40
41impl RequestHandler {
42 pub async fn on_send_message(
51 &self,
52 params: MessageSendParams,
53 streaming: bool,
54 headers: Option<&HashMap<String, String>>,
55 ) -> ServerResult<SendMessageResult> {
56 let method_name = if streaming {
57 "SendStreamingMessage"
58 } else {
59 "SendMessage"
60 };
61 let start = Instant::now();
62 trace_info!(method = method_name, streaming, "handling send message");
63 self.metrics.on_request(method_name);
64
65 let tenant = params.tenant.clone().unwrap_or_default();
66 let result = crate::store::tenant::TenantContext::scope(tenant, async {
67 self.send_message_inner(params, streaming, method_name, headers)
68 .await
69 })
70 .await;
71 let elapsed = start.elapsed();
72 match &result {
73 Ok(_) => {
74 self.metrics.on_response(method_name);
75 self.metrics.on_latency(method_name, elapsed);
76 }
77 Err(e) => {
78 self.metrics.on_error(method_name, &e.to_string());
79 self.metrics.on_latency(method_name, elapsed);
80 }
81 }
82 result
83 }
84
85 #[allow(clippy::too_many_lines)]
88 async fn send_message_inner(
89 &self,
90 params: MessageSendParams,
91 streaming: bool,
92 method_name: &str,
93 headers: Option<&HashMap<String, String>>,
94 ) -> ServerResult<SendMessageResult> {
95 let call_ctx = build_call_context(method_name, headers);
96 self.interceptors.run_before(&call_ctx).await?;
97
98 if let Some(ref ctx_id) = params.context_id {
100 validate_id(ctx_id, "context_id", self.limits.max_id_length)?;
101 }
102 if let Some(ref ctx_id) = params.message.context_id {
103 validate_id(&ctx_id.0, "context_id", self.limits.max_id_length)?;
104 }
105 if let Some(ref task_id) = params.message.task_id {
106 validate_id(&task_id.0, "task_id", self.limits.max_id_length)?;
107 }
108
109 if params.message.parts.is_empty() {
111 return Err(ServerError::InvalidParams(
112 "message must contain at least one part".into(),
113 ));
114 }
115
116 let max_meta = self.limits.max_metadata_size;
119 if let Some(ref meta) = params.message.metadata {
120 let meta_size = json_byte_len(meta).map_err(|_| {
121 ServerError::InvalidParams("message metadata is not serializable".into())
122 })?;
123 if meta_size > max_meta {
124 return Err(ServerError::InvalidParams(format!(
125 "message metadata exceeds maximum size ({meta_size} bytes, max {max_meta})"
126 )));
127 }
128 }
129 if let Some(ref meta) = params.metadata {
130 let meta_size = json_byte_len(meta).map_err(|_| {
131 ServerError::InvalidParams("request metadata is not serializable".into())
132 })?;
133 if meta_size > max_meta {
134 return Err(ServerError::InvalidParams(format!(
135 "request metadata exceeds maximum size ({meta_size} bytes, max {max_meta})"
136 )));
137 }
138 }
139
140 let context_id = params
143 .context_id
144 .as_deref()
145 .or_else(|| params.message.context_id.as_ref().map(|c| c.0.as_str()))
146 .map_or_else(|| uuid::Uuid::new_v4().to_string(), ToString::to_string);
147
148 let context_lock = {
152 let mut locks = self.context_locks.write().await;
153 if locks.len() >= self.limits.max_context_locks {
157 locks.retain(|_, v| Arc::strong_count(v) > 1);
158 }
159 locks.entry(context_id.clone()).or_default().clone()
160 };
161 let context_guard = context_lock.lock().await;
162
163 let stored_task = self.find_task_by_context(&context_id).await?;
165
166 let task_id = if let Some(ref msg_task_id) = params.message.task_id {
170 if let Some(ref stored) = stored_task {
171 if msg_task_id != &stored.id {
172 return Err(ServerError::InvalidParams(
173 "message task_id does not match task found for context".into(),
174 ));
175 }
176 } else {
178 let placeholder = Task {
181 id: msg_task_id.clone(),
182 context_id: ContextId::new(&context_id),
183 status: TaskStatus::with_timestamp(TaskState::Submitted),
184 history: None,
185 artifacts: None,
186 metadata: None,
187 };
188 if !self.task_store.insert_if_absent(placeholder).await? {
189 return Err(ServerError::InvalidParams(
190 "task_id already exists; cannot create duplicate".into(),
191 ));
192 }
193 }
194 msg_task_id.clone()
195 } else {
196 TaskId::new(uuid::Uuid::new_v4().to_string())
197 };
198
199 let return_immediately = params
201 .configuration
202 .as_ref()
203 .and_then(|c| c.return_immediately)
204 .unwrap_or(false);
205
206 trace_debug!(
208 task_id = %task_id,
209 context_id = %context_id,
210 "creating task"
211 );
212 let task = Task {
213 id: task_id.clone(),
214 context_id: ContextId::new(&context_id),
215 status: TaskStatus::with_timestamp(TaskState::Submitted),
216 history: None,
217 artifacts: None,
218 metadata: None,
219 };
220
221 let mut ctx = RequestContext::new(params.message, task_id.clone(), context_id);
224 if let Some(stored) = stored_task {
225 ctx = ctx.with_stored_task(stored);
226 }
227 if let Some(meta) = params.metadata {
228 ctx = ctx.with_metadata(meta);
229 }
230
231 {
236 let stale_ids: Vec<TaskId> = {
240 let tokens = self.cancellation_tokens.read().await;
241 if tokens.len() >= self.limits.max_cancellation_tokens {
242 let now = Instant::now();
243 tokens
244 .iter()
245 .filter(|(_, entry)| {
246 entry.token.is_cancelled()
247 || now.duration_since(entry.created_at) >= self.limits.max_token_age
248 })
249 .map(|(id, _)| id.clone())
250 .collect()
251 } else {
252 Vec::new()
253 }
254 };
255
256 if !stale_ids.is_empty() {
258 let mut tokens = self.cancellation_tokens.write().await;
259 for id in &stale_ids {
260 tokens.remove(id);
261 }
262 }
263
264 let mut tokens = self.cancellation_tokens.write().await;
266 tokens.insert(
267 task_id.clone(),
268 CancellationEntry {
269 token: ctx.cancellation_token.clone(),
270 created_at: Instant::now(),
271 },
272 );
273 }
274
275 self.task_store.save(task.clone()).await?;
276
277 drop(context_guard);
280
281 let (writer, reader, persistence_rx) = if streaming {
285 let (w, r, p) = self
286 .event_queue_manager
287 .get_or_create_with_persistence(&task_id)
288 .await;
289 let r = match r {
290 Some(r) => r,
291 None => {
292 self.event_queue_manager
294 .subscribe(&task_id)
295 .await
296 .ok_or_else(|| {
297 ServerError::Internal("event queue disappeared during subscribe".into())
298 })?
299 }
300 };
301 (w, r, p)
302 } else {
303 let (w, r) = self.event_queue_manager.get_or_create(&task_id).await;
304 let r = match r {
305 Some(r) => r,
306 None => {
307 self.event_queue_manager
309 .subscribe(&task_id)
310 .await
311 .ok_or_else(|| {
312 ServerError::Internal("event queue disappeared during subscribe".into())
313 })?
314 }
315 };
316 (w, r, None)
317 };
318
319 let executor = Arc::clone(&self.executor);
323 let task_id_for_cleanup = task_id.clone();
324 let event_queue_mgr = self.event_queue_manager.clone();
325 let cancel_tokens = Arc::clone(&self.cancellation_tokens);
326 let executor_timeout = self.executor_timeout;
327 let executor_handle = tokio::spawn(async move {
328 trace_debug!(task_id = %ctx.task_id, "executor started");
329
330 #[allow(clippy::items_after_statements)]
335 struct CleanupGuard {
336 task_id: Option<TaskId>,
337 queue_mgr: crate::streaming::EventQueueManager,
338 tokens: std::sync::Arc<tokio::sync::RwLock<HashMap<TaskId, CancellationEntry>>>,
339 }
340 #[allow(clippy::items_after_statements)]
341 impl Drop for CleanupGuard {
342 fn drop(&mut self) {
343 if let Some(tid) = self.task_id.take() {
344 let qmgr = self.queue_mgr.clone();
345 let tokens = std::sync::Arc::clone(&self.tokens);
346 tokio::task::spawn(async move {
347 qmgr.destroy(&tid).await;
348 tokens.write().await.remove(&tid);
349 });
350 }
351 }
352 }
353 let mut cleanup_guard = CleanupGuard {
354 task_id: Some(task_id_for_cleanup.clone()),
355 queue_mgr: event_queue_mgr.clone(),
356 tokens: Arc::clone(&cancel_tokens),
357 };
358
359 let result = {
361 let exec_future = if let Some(timeout) = executor_timeout {
362 tokio::time::timeout(timeout, executor.execute(&ctx, writer.as_ref()))
363 .await
364 .unwrap_or_else(|_| {
365 Err(a2a_protocol_types::error::A2aError::internal(format!(
366 "executor timed out after {}s",
367 timeout.as_secs()
368 )))
369 })
370 } else {
371 executor.execute(&ctx, writer.as_ref()).await
372 };
373 exec_future
374 };
375
376 if let Err(ref e) = result {
377 trace_error!(task_id = %ctx.task_id, error = %e, "executor failed");
378 let fail_event = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
380 task_id: ctx.task_id.clone(),
381 context_id: ContextId::new(ctx.context_id.clone()),
382 status: TaskStatus::with_timestamp(TaskState::Failed),
383 metadata: Some(serde_json::json!({ "error": e.to_string() })),
384 });
385 if let Err(_write_err) = writer.write(fail_event).await {
386 trace_error!(
387 task_id = %ctx.task_id,
388 error = %_write_err,
389 "failed to write failure event to queue"
390 );
391 }
392 }
393 drop(writer);
395 event_queue_mgr.destroy(&task_id_for_cleanup).await;
398 cancel_tokens.write().await.remove(&task_id_for_cleanup);
399 cleanup_guard.task_id = None;
400 });
401
402 self.interceptors.run_after(&call_ctx).await?;
403
404 if streaming {
405 self.spawn_background_event_processor(task_id.clone(), executor_handle, persistence_rx);
415 Ok(SendMessageResult::Stream(reader))
416 } else if return_immediately {
417 Ok(SendMessageResult::Response(SendMessageResponse::Task(task)))
419 } else {
420 let final_task = self
423 .collect_events(reader, task_id.clone(), executor_handle)
424 .await?;
425 Ok(SendMessageResult::Response(SendMessageResponse::Task(
426 final_task,
427 )))
428 }
429 }
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;
    use std::time::Duration;

    use a2a_protocol_types::error::{A2aError, A2aResult};
    use a2a_protocol_types::message::{Message, MessageId, MessageRole, Part};
    use a2a_protocol_types::params::{MessageSendParams, SendMessageConfiguration};
    use a2a_protocol_types::task::ContextId;

    use crate::agent_executor;
    use crate::builder::RequestHandlerBuilder;
    use crate::call_context::CallContext;
    use crate::handler::limits::HandlerLimits;
    use crate::interceptor::ServerInterceptor;

    /// Boxed future returned by executor and interceptor hooks.
    type HookFuture<'a> = Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>>;

    // ----- test doubles -----

    /// Executor that returns `Ok(())` immediately without touching the queue.
    struct DummyExecutor;
    agent_executor!(DummyExecutor, |_ctx, _queue| async { Ok(()) });

    /// Executor that sleeps for the wrapped duration and then succeeds,
    /// keeping its task in flight meanwhile.
    struct SleepyExecutor(Duration);

    impl crate::executor::AgentExecutor for SleepyExecutor {
        fn execute<'a>(
            &'a self,
            _ctx: &'a crate::request_context::RequestContext,
            _queue: &'a dyn crate::streaming::EventQueueWriter,
        ) -> HookFuture<'a> {
            let delay = self.0;
            Box::pin(async move {
                tokio::time::sleep(delay).await;
                Ok(())
            })
        }
    }

    /// Executor that always fails with `A2aError::internal(<message>)`.
    struct FailingExecutor(&'static str);

    impl crate::executor::AgentExecutor for FailingExecutor {
        fn execute<'a>(
            &'a self,
            _ctx: &'a crate::request_context::RequestContext,
            _queue: &'a dyn crate::streaming::EventQueueWriter,
        ) -> HookFuture<'a> {
            let message = self.0;
            Box::pin(async move { Err(A2aError::internal(message)) })
        }
    }

    /// Interceptor whose `before` hook rejects every call.
    struct RejectingInterceptor;

    impl ServerInterceptor for RejectingInterceptor {
        fn before<'a>(&'a self, _ctx: &'a CallContext) -> HookFuture<'a> {
            Box::pin(async { Err(A2aError::internal("forced failure")) })
        }

        fn after<'a>(&'a self, _ctx: &'a CallContext) -> HookFuture<'a> {
            Box::pin(async { Ok(()) })
        }
    }

    // ----- helpers -----

    fn make_handler() -> RequestHandler {
        RequestHandlerBuilder::new(DummyExecutor)
            .build()
            .expect("default build should succeed")
    }

    /// `DummyExecutor` handler with custom limits.
    fn make_handler_with_limits(limits: HandlerLimits) -> RequestHandler {
        RequestHandlerBuilder::new(DummyExecutor)
            .with_handler_limits(limits)
            .build()
            .expect("build with custom limits should succeed")
    }

    /// Handler whose `before` interceptor rejects every request.
    fn make_rejecting_handler() -> RequestHandler {
        RequestHandlerBuilder::new(DummyExecutor)
            .with_interceptor(RejectingInterceptor)
            .build()
            .expect("build with interceptor should succeed")
    }

    /// Single-part user message; `context_id` is set on the message, not the request.
    fn make_params(context_id: Option<&str>) -> MessageSendParams {
        MessageSendParams {
            context_id: None,
            message: Message {
                id: MessageId::new("msg-1"),
                role: MessageRole::User,
                parts: vec![Part::text("hello")],
                context_id: context_id.map(ContextId::new),
                task_id: None,
                reference_task_ids: None,
                extensions: None,
                metadata: None,
            },
            configuration: None,
            metadata: None,
            tenant: None,
        }
    }

    /// A task to pre-seed into the store.
    fn make_task(id: &str, context_id: &str, state: TaskState) -> Task {
        Task {
            id: TaskId::new(id),
            context_id: ContextId::new(context_id),
            status: TaskStatus::new(state),
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    /// Metadata value assumed to exceed the default `max_metadata_size`.
    fn oversized_metadata() -> serde_json::Value {
        serde_json::json!("x".repeat(1_100_000))
    }

    /// Asserts `result` is `InvalidParams` whose message contains `needle`.
    #[track_caller]
    fn assert_invalid_params(result: ServerResult<SendMessageResult>, needle: &str) {
        match result {
            Err(ServerError::InvalidParams(msg)) => assert!(
                msg.contains(needle),
                "InvalidParams message {msg:?} should contain {needle:?}"
            ),
            Err(other) => panic!("expected InvalidParams containing {needle:?}, got error {other}"),
            Ok(_) => panic!("expected InvalidParams containing {needle:?}, got Ok"),
        }
    }

    /// Unwraps a non-streaming `Response(Task)` result, panicking with `what` otherwise.
    #[track_caller]
    fn expect_task(result: ServerResult<SendMessageResult>, what: &str) -> Task {
        match result {
            Ok(SendMessageResult::Response(SendMessageResponse::Task(task))) => task,
            Ok(_) => panic!("{what}: expected Response(Task)"),
            Err(e) => panic!("{what}: expected Ok, got error {e}"),
        }
    }

    // ----- validation -----

    #[tokio::test]
    async fn empty_message_parts_returns_invalid_params() {
        let mut params = make_params(None);
        params.message.parts = vec![];

        let result = make_handler().on_send_message(params, false, None).await;
        assert_invalid_params(result, "at least one part");
    }

    #[tokio::test]
    async fn oversized_message_metadata_returns_invalid_params() {
        let mut params = make_params(None);
        params.message.metadata = Some(oversized_metadata());

        let result = make_handler().on_send_message(params, false, None).await;
        assert_invalid_params(result, "exceeds maximum size");
    }

    #[tokio::test]
    async fn oversized_request_metadata_returns_invalid_params() {
        let mut params = make_params(None);
        params.metadata = Some(oversized_metadata());

        let result = make_handler().on_send_message(params, false, None).await;
        assert_invalid_params(result, "exceeds maximum size");
    }

    #[tokio::test]
    async fn empty_context_id_returns_invalid_params() {
        let result = make_handler()
            .on_send_message(make_params(Some("")), false, None)
            .await;
        assert_invalid_params(result, "empty");
    }

    #[tokio::test]
    async fn too_long_context_id_returns_invalid_params() {
        let handler = make_handler_with_limits(HandlerLimits::default().with_max_id_length(10));
        let long_ctx = "x".repeat(20);

        let result = handler
            .on_send_message(make_params(Some(&long_ctx)), false, None)
            .await;
        assert_invalid_params(result, "maximum length");
    }

    #[tokio::test]
    async fn too_long_task_id_returns_invalid_params() {
        let handler = make_handler_with_limits(HandlerLimits::default().with_max_id_length(10));
        let mut params = make_params(None);
        params.message.task_id = Some(TaskId::new("a".repeat(20)));

        let result = handler.on_send_message(params, false, None).await;
        assert_invalid_params(result, "maximum length");
    }

    #[tokio::test]
    async fn empty_task_id_returns_invalid_params() {
        let mut params = make_params(None);
        params.message.task_id = Some(TaskId::new(""));

        let result = make_handler().on_send_message(params, false, None).await;
        assert_invalid_params(result, "empty");
    }

    // ----- task-id resolution -----

    #[tokio::test]
    async fn task_id_mismatch_returns_invalid_params() {
        let handler = make_handler();
        handler
            .task_store
            .save(make_task("stored-task-id", "ctx-existing", TaskState::Completed))
            .await
            .expect("seeding the store should succeed");

        let mut params = make_params(Some("ctx-existing"));
        params.message.task_id = Some(TaskId::new("different-task-id"));

        let result = handler.on_send_message(params, false, None).await;
        assert_invalid_params(result, "does not match");
    }

    #[tokio::test]
    async fn duplicate_task_id_without_context_match_returns_error() {
        let handler = make_handler();
        handler
            .task_store
            .save(make_task("dup-task", "other-ctx", TaskState::Completed))
            .await
            .expect("seeding the store should succeed");

        let mut params = make_params(Some("brand-new-ctx"));
        params.message.task_id = Some(TaskId::new("dup-task"));

        let result = handler.on_send_message(params, false, None).await;
        assert_invalid_params(result, "already exists");
    }

    #[tokio::test]
    async fn send_message_with_stored_task_continuation() {
        let handler = make_handler();
        handler
            .task_store
            .save(make_task("existing-task", "continue-ctx", TaskState::Completed))
            .await
            .expect("seeding the store should succeed");

        let result = handler
            .on_send_message(make_params(Some("continue-ctx")), false, None)
            .await;
        expect_task(result, "send_message with existing context");
    }

    #[tokio::test]
    async fn input_required_continuation_reuses_task_id() {
        let handler = make_handler();
        let existing_task_id = TaskId::new("input-required-task");
        handler
            .task_store
            .save(make_task("input-required-task", "ctx-input", TaskState::InputRequired))
            .await
            .expect("seeding the store should succeed");

        let mut params = make_params(Some("ctx-input"));
        params.message.task_id = Some(existing_task_id.clone());

        let result = handler.on_send_message(params, false, None).await;
        let task = expect_task(result, "continuation");
        assert_eq!(
            task.id, existing_task_id,
            "task_id should be reused for input-required continuation"
        );
    }

    // ----- successful sends -----

    #[tokio::test]
    async fn valid_message_returns_ok() {
        let result = make_handler()
            .on_send_message(make_params(None), false, None)
            .await;
        expect_task(result, "non-streaming send");
    }

    #[tokio::test]
    async fn return_immediately_returns_task() {
        let mut params = make_params(None);
        params.configuration = Some(SendMessageConfiguration {
            accepted_output_modes: vec!["text/plain".into()],
            task_push_notification_config: None,
            history_length: None,
            return_immediately: Some(true),
        });

        let result = make_handler().on_send_message(params, false, None).await;
        expect_task(result, "return_immediately=true");
    }

    #[tokio::test]
    async fn send_message_with_request_metadata() {
        let mut params = make_params(None);
        params.metadata = Some(serde_json::json!({"key": "value"}));

        let result = make_handler().on_send_message(params, false, None).await;
        expect_task(result, "send_message with request metadata");
    }

    #[tokio::test]
    async fn send_message_with_headers() {
        let mut headers = HashMap::new();
        headers.insert("authorization".to_string(), "Bearer test-token".to_string());

        let result = make_handler()
            .on_send_message(make_params(None), false, Some(&headers))
            .await;
        expect_task(result, "send with headers");
    }

    #[tokio::test]
    async fn send_message_with_tenant() {
        let mut params = make_params(None);
        params.tenant = Some("test-tenant".to_string());

        let result = make_handler().on_send_message(params, false, None).await;
        expect_task(result, "send with tenant");
    }

    #[tokio::test]
    async fn streaming_mode_returns_stream_result() {
        let result = make_handler()
            .on_send_message(make_params(None), true, None)
            .await;
        assert!(
            matches!(result, Ok(SendMessageResult::Stream(_))),
            "expected Stream result in streaming mode"
        );
    }

    // ----- interceptor rejection (exercises the error-metrics path) -----

    #[tokio::test]
    async fn send_message_error_path_records_metrics() {
        let result = make_rejecting_handler()
            .on_send_message(make_params(None), false, None)
            .await;
        assert!(
            result.is_err(),
            "send_message should fail when interceptor rejects, exercising error metrics path"
        );
    }

    #[tokio::test]
    async fn send_streaming_message_error_path_records_metrics() {
        let result = make_rejecting_handler()
            .on_send_message(make_params(None), true, None)
            .await;
        assert!(
            result.is_err(),
            "streaming send_message should fail when interceptor rejects"
        );
    }

    // ----- executor failures -----

    #[tokio::test]
    async fn executor_timeout_returns_failed_task() {
        let handler = RequestHandlerBuilder::new(SleepyExecutor(Duration::from_secs(60)))
            .with_executor_timeout(Duration::from_millis(50))
            .build()
            .expect("build with executor timeout should succeed");

        let result = handler.on_send_message(make_params(None), false, None).await;
        expect_task(result, "executor timeout");
    }

    #[tokio::test]
    async fn executor_failure_writes_failed_event() {
        let handler = RequestHandlerBuilder::new(FailingExecutor("executor exploded"))
            .build()
            .expect("build should succeed");

        let result = handler.on_send_message(make_params(None), false, None).await;
        expect_task(result, "executor failure");
    }

    #[tokio::test]
    async fn streaming_executor_failure_writes_error_event() {
        let handler = RequestHandlerBuilder::new(FailingExecutor("streaming fail"))
            .build()
            .expect("build should succeed");

        let result = handler.on_send_message(make_params(None), true, None).await;
        assert!(
            matches!(result, Ok(SendMessageResult::Stream(_))),
            "streaming executor failure should still return stream"
        );
    }

    // ----- cancellation-token bookkeeping -----

    #[tokio::test]
    async fn cancellation_token_sweep_runs_when_map_is_full() {
        // Smoke test: the third send reaches the sweep branch with live tokens
        // and must neither panic nor deadlock.
        let handler = RequestHandlerBuilder::new(SleepyExecutor(Duration::from_secs(10)))
            .with_handler_limits(HandlerLimits::default().with_max_cancellation_tokens(2))
            .build()
            .expect("build with custom limits should succeed");

        for _ in 0..3 {
            let _ = handler.on_send_message(make_params(None), true, None).await;
        }
        handler.shutdown().await;
    }

    #[tokio::test]
    async fn stale_cancellation_tokens_cleaned_up() {
        let handler = RequestHandlerBuilder::new(SleepyExecutor(Duration::from_secs(10)))
            .with_handler_limits(
                HandlerLimits::default()
                    .with_max_cancellation_tokens(2)
                    .with_max_token_age(Duration::from_millis(1)),
            )
            .build()
            .expect("build with custom limits should succeed");

        // Fill the map to its limit with tokens that age out almost immediately.
        for _ in 0..2 {
            let _ = handler.on_send_message(make_params(None), true, None).await;
        }
        tokio::time::sleep(Duration::from_millis(50)).await;

        // This send finds the map full and sweeps the expired entries before
        // registering its own token.
        let _ = handler.on_send_message(make_params(None), true, None).await;

        let live_tokens = handler.cancellation_tokens.read().await.len();
        handler.shutdown().await;
        assert!(
            live_tokens <= 1,
            "expired tokens should have been swept, {live_tokens} remain"
        );
    }
}
1087}