1use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Instant;
11
12use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
13use a2a_protocol_types::params::MessageSendParams;
14use a2a_protocol_types::responses::SendMessageResponse;
15use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};
16
17use crate::error::{ServerError, ServerResult};
18use crate::request_context::RequestContext;
19use crate::streaming::EventQueueWriter;
20
21use super::helpers::{build_call_context, validate_id};
22use super::{CancellationEntry, RequestHandler, SendMessageResult};
23
24fn json_byte_len(value: &serde_json::Value) -> serde_json::Result<usize> {
26 struct CountWriter(usize);
27 impl std::io::Write for CountWriter {
28 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
29 self.0 += buf.len();
30 Ok(buf.len())
31 }
32 fn flush(&mut self) -> std::io::Result<()> {
33 Ok(())
34 }
35 }
36 let mut w = CountWriter(0);
37 serde_json::to_writer(&mut w, value)?;
38 Ok(w.0)
39}
40
41impl RequestHandler {
42 pub async fn on_send_message(
51 &self,
52 params: MessageSendParams,
53 streaming: bool,
54 headers: Option<&HashMap<String, String>>,
55 ) -> ServerResult<SendMessageResult> {
56 let method_name = if streaming {
57 "SendStreamingMessage"
58 } else {
59 "SendMessage"
60 };
61 let start = Instant::now();
62 trace_info!(method = method_name, streaming, "handling send message");
63 self.metrics.on_request(method_name);
64
65 let tenant = params.tenant.clone().unwrap_or_default();
66 let result = crate::store::tenant::TenantContext::scope(tenant, async {
67 self.send_message_inner(params, streaming, method_name, headers)
68 .await
69 })
70 .await;
71 let elapsed = start.elapsed();
72 match &result {
73 Ok(_) => {
74 self.metrics.on_response(method_name);
75 self.metrics.on_latency(method_name, elapsed);
76 }
77 Err(e) => {
78 self.metrics.on_error(method_name, &e.to_string());
79 self.metrics.on_latency(method_name, elapsed);
80 }
81 }
82 result
83 }
84
85 #[allow(clippy::too_many_lines)]
88 async fn send_message_inner(
89 &self,
90 params: MessageSendParams,
91 streaming: bool,
92 method_name: &str,
93 headers: Option<&HashMap<String, String>>,
94 ) -> ServerResult<SendMessageResult> {
95 let call_ctx = build_call_context(method_name, headers);
96 self.interceptors.run_before(&call_ctx).await?;
97
98 if let Some(ref ctx_id) = params.message.context_id {
100 validate_id(&ctx_id.0, "context_id", self.limits.max_id_length)?;
101 }
102 if let Some(ref task_id) = params.message.task_id {
103 validate_id(&task_id.0, "task_id", self.limits.max_id_length)?;
104 }
105
106 if params.message.parts.is_empty() {
108 return Err(ServerError::InvalidParams(
109 "message must contain at least one part".into(),
110 ));
111 }
112
113 let max_meta = self.limits.max_metadata_size;
116 if let Some(ref meta) = params.message.metadata {
117 let meta_size = json_byte_len(meta).map_err(|_| {
118 ServerError::InvalidParams("message metadata is not serializable".into())
119 })?;
120 if meta_size > max_meta {
121 return Err(ServerError::InvalidParams(format!(
122 "message metadata exceeds maximum size ({meta_size} bytes, max {max_meta})"
123 )));
124 }
125 }
126 if let Some(ref meta) = params.metadata {
127 let meta_size = json_byte_len(meta).map_err(|_| {
128 ServerError::InvalidParams("request metadata is not serializable".into())
129 })?;
130 if meta_size > max_meta {
131 return Err(ServerError::InvalidParams(format!(
132 "request metadata exceeds maximum size ({meta_size} bytes, max {max_meta})"
133 )));
134 }
135 }
136
137 let context_id = params
140 .context_id
141 .as_deref()
142 .or_else(|| params.message.context_id.as_ref().map(|c| c.0.as_str()))
143 .map_or_else(|| uuid::Uuid::new_v4().to_string(), ToString::to_string);
144
145 let context_lock = {
149 let mut locks = self.context_locks.write().await;
150 locks.entry(context_id.clone()).or_default().clone()
151 };
152 let context_guard = context_lock.lock().await;
153
154 let stored_task = self.find_task_by_context(&context_id).await?;
156
157 let task_id = if let Some(ref msg_task_id) = params.message.task_id {
161 if let Some(ref stored) = stored_task {
162 if msg_task_id != &stored.id {
163 return Err(ServerError::InvalidParams(
164 "message task_id does not match task found for context".into(),
165 ));
166 }
167 } else {
169 let placeholder = Task {
172 id: msg_task_id.clone(),
173 context_id: ContextId::new(&context_id),
174 status: TaskStatus::with_timestamp(TaskState::Submitted),
175 history: None,
176 artifacts: None,
177 metadata: None,
178 };
179 if !self.task_store.insert_if_absent(placeholder).await? {
180 return Err(ServerError::InvalidParams(
181 "task_id already exists; cannot create duplicate".into(),
182 ));
183 }
184 }
185 msg_task_id.clone()
186 } else {
187 TaskId::new(uuid::Uuid::new_v4().to_string())
188 };
189
190 let return_immediately = params
192 .configuration
193 .as_ref()
194 .and_then(|c| c.return_immediately)
195 .unwrap_or(false);
196
197 trace_debug!(
199 task_id = %task_id,
200 context_id = %context_id,
201 "creating task"
202 );
203 let task = Task {
204 id: task_id.clone(),
205 context_id: ContextId::new(&context_id),
206 status: TaskStatus::with_timestamp(TaskState::Submitted),
207 history: None,
208 artifacts: None,
209 metadata: None,
210 };
211
212 let mut ctx = RequestContext::new(params.message, task_id.clone(), context_id);
215 if let Some(stored) = stored_task {
216 ctx = ctx.with_stored_task(stored);
217 }
218 if let Some(meta) = params.metadata {
219 ctx = ctx.with_metadata(meta);
220 }
221
222 {
227 let stale_ids: Vec<TaskId> = {
231 let tokens = self.cancellation_tokens.read().await;
232 if tokens.len() >= self.limits.max_cancellation_tokens {
233 let now = Instant::now();
234 tokens
235 .iter()
236 .filter(|(_, entry)| {
237 entry.token.is_cancelled()
238 || now.duration_since(entry.created_at) >= self.limits.max_token_age
239 })
240 .map(|(id, _)| id.clone())
241 .collect()
242 } else {
243 Vec::new()
244 }
245 };
246
247 if !stale_ids.is_empty() {
249 let mut tokens = self.cancellation_tokens.write().await;
250 for id in &stale_ids {
251 tokens.remove(id);
252 }
253 }
254
255 let mut tokens = self.cancellation_tokens.write().await;
257 tokens.insert(
258 task_id.clone(),
259 CancellationEntry {
260 token: ctx.cancellation_token.clone(),
261 created_at: Instant::now(),
262 },
263 );
264 }
265
266 self.task_store.save(task.clone()).await?;
267
268 drop(context_guard);
271
272 let (writer, reader, persistence_rx) = if streaming {
276 let (w, r, p) = self
277 .event_queue_manager
278 .get_or_create_with_persistence(&task_id)
279 .await;
280 let r = match r {
281 Some(r) => r,
282 None => {
283 self.event_queue_manager
285 .subscribe(&task_id)
286 .await
287 .ok_or_else(|| {
288 ServerError::Internal("event queue disappeared during subscribe".into())
289 })?
290 }
291 };
292 (w, r, p)
293 } else {
294 let (w, r) = self.event_queue_manager.get_or_create(&task_id).await;
295 let r = match r {
296 Some(r) => r,
297 None => {
298 self.event_queue_manager
300 .subscribe(&task_id)
301 .await
302 .ok_or_else(|| {
303 ServerError::Internal("event queue disappeared during subscribe".into())
304 })?
305 }
306 };
307 (w, r, None)
308 };
309
310 let executor = Arc::clone(&self.executor);
314 let task_id_for_cleanup = task_id.clone();
315 let event_queue_mgr = self.event_queue_manager.clone();
316 let cancel_tokens = Arc::clone(&self.cancellation_tokens);
317 let executor_timeout = self.executor_timeout;
318 let executor_handle = tokio::spawn(async move {
319 trace_debug!(task_id = %ctx.task_id, "executor started");
320
321 #[allow(clippy::items_after_statements)]
326 struct CleanupGuard {
327 task_id: Option<TaskId>,
328 queue_mgr: crate::streaming::EventQueueManager,
329 tokens: std::sync::Arc<tokio::sync::RwLock<HashMap<TaskId, CancellationEntry>>>,
330 }
331 #[allow(clippy::items_after_statements)]
332 impl Drop for CleanupGuard {
333 fn drop(&mut self) {
334 if let Some(tid) = self.task_id.take() {
335 let qmgr = self.queue_mgr.clone();
336 let tokens = std::sync::Arc::clone(&self.tokens);
337 tokio::task::spawn(async move {
338 qmgr.destroy(&tid).await;
339 tokens.write().await.remove(&tid);
340 });
341 }
342 }
343 }
344 let mut cleanup_guard = CleanupGuard {
345 task_id: Some(task_id_for_cleanup.clone()),
346 queue_mgr: event_queue_mgr.clone(),
347 tokens: Arc::clone(&cancel_tokens),
348 };
349
350 let result = {
352 let exec_future = if let Some(timeout) = executor_timeout {
353 tokio::time::timeout(timeout, executor.execute(&ctx, writer.as_ref()))
354 .await
355 .unwrap_or_else(|_| {
356 Err(a2a_protocol_types::error::A2aError::internal(format!(
357 "executor timed out after {}s",
358 timeout.as_secs()
359 )))
360 })
361 } else {
362 executor.execute(&ctx, writer.as_ref()).await
363 };
364 exec_future
365 };
366
367 if let Err(ref e) = result {
368 trace_error!(task_id = %ctx.task_id, error = %e, "executor failed");
369 let fail_event = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
371 task_id: ctx.task_id.clone(),
372 context_id: ContextId::new(ctx.context_id.clone()),
373 status: TaskStatus::with_timestamp(TaskState::Failed),
374 metadata: Some(serde_json::json!({ "error": e.to_string() })),
375 });
376 if let Err(_write_err) = writer.write(fail_event).await {
377 trace_error!(
378 task_id = %ctx.task_id,
379 error = %_write_err,
380 "failed to write failure event to queue"
381 );
382 }
383 }
384 drop(writer);
386 event_queue_mgr.destroy(&task_id_for_cleanup).await;
389 cancel_tokens.write().await.remove(&task_id_for_cleanup);
390 cleanup_guard.task_id = None;
391 });
392
393 self.interceptors.run_after(&call_ctx).await?;
394
395 if streaming {
396 self.spawn_background_event_processor(task_id.clone(), executor_handle, persistence_rx);
406 Ok(SendMessageResult::Stream(reader))
407 } else if return_immediately {
408 Ok(SendMessageResult::Response(SendMessageResponse::Task(task)))
410 } else {
411 let final_task = self
414 .collect_events(reader, task_id.clone(), executor_handle)
415 .await?;
416 Ok(SendMessageResult::Response(SendMessageResponse::Task(
417 final_task,
418 )))
419 }
420 }
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::message::{Message, MessageId, MessageRole, Part};
    use a2a_protocol_types::params::{MessageSendParams, SendMessageConfiguration};
    use a2a_protocol_types::task::ContextId;

    use crate::agent_executor;
    use crate::builder::RequestHandlerBuilder;

    // ── Fixtures ────────────────────────────────────────────────────────────

    /// Executor that succeeds immediately without emitting any events.
    struct DummyExecutor;
    agent_executor!(DummyExecutor, |_ctx, _queue| async { Ok(()) });

    /// Executor that sleeps for the wrapped duration and then succeeds.
    struct SleepingExecutor(std::time::Duration);

    impl crate::executor::AgentExecutor for SleepingExecutor {
        fn execute<'a>(
            &'a self,
            _ctx: &'a crate::request_context::RequestContext,
            _queue: &'a dyn crate::streaming::EventQueueWriter,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<Output = a2a_protocol_types::error::A2aResult<()>>
                    + Send
                    + 'a,
            >,
        > {
            let nap = self.0;
            Box::pin(async move {
                tokio::time::sleep(nap).await;
                Ok(())
            })
        }
    }

    /// Executor that fails at once with an internal error carrying the wrapped text.
    struct FailingExecutor(&'static str);

    impl crate::executor::AgentExecutor for FailingExecutor {
        fn execute<'a>(
            &'a self,
            _ctx: &'a crate::request_context::RequestContext,
            _queue: &'a dyn crate::streaming::EventQueueWriter,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<Output = a2a_protocol_types::error::A2aResult<()>>
                    + Send
                    + 'a,
            >,
        > {
            let reason = self.0;
            Box::pin(async move { Err(a2a_protocol_types::error::A2aError::internal(reason)) })
        }
    }

    /// Interceptor whose `before` hook always rejects the call.
    struct RejectingInterceptor;

    impl crate::interceptor::ServerInterceptor for RejectingInterceptor {
        fn before<'a>(
            &'a self,
            _ctx: &'a crate::call_context::CallContext,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<Output = a2a_protocol_types::error::A2aResult<()>>
                    + Send
                    + 'a,
            >,
        > {
            Box::pin(async {
                Err(a2a_protocol_types::error::A2aError::internal(
                    "forced failure",
                ))
            })
        }

        fn after<'a>(
            &'a self,
            _ctx: &'a crate::call_context::CallContext,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<Output = a2a_protocol_types::error::A2aResult<()>>
                    + Send
                    + 'a,
            >,
        > {
            Box::pin(async { Ok(()) })
        }
    }

    /// Handler backed by [`DummyExecutor`] with default limits.
    fn make_handler() -> RequestHandler {
        RequestHandlerBuilder::new(DummyExecutor)
            .build()
            .expect("default build should succeed")
    }

    /// Handler whose ids may be at most 10 characters long.
    fn short_id_handler() -> RequestHandler {
        RequestHandlerBuilder::new(DummyExecutor)
            .with_handler_limits(
                crate::handler::limits::HandlerLimits::default().with_max_id_length(10),
            )
            .build()
            .unwrap()
    }

    /// Handler whose interceptor chain rejects every request.
    fn rejecting_handler() -> RequestHandler {
        RequestHandlerBuilder::new(DummyExecutor)
            .with_interceptor(RejectingInterceptor)
            .build()
            .unwrap()
    }

    /// Minimal single-part user message, optionally bound to `context_id`.
    fn make_params(context_id: Option<&str>) -> MessageSendParams {
        let message = Message {
            id: MessageId::new("msg-1"),
            role: MessageRole::User,
            parts: vec![Part::text("hello")],
            context_id: context_id.map(ContextId::new),
            task_id: None,
            reference_task_ids: None,
            extensions: None,
            metadata: None,
        };
        MessageSendParams {
            context_id: None,
            message,
            configuration: None,
            metadata: None,
            tenant: None,
        }
    }

    /// A bare task in `state`, suitable for seeding the task store.
    fn stored_task(id: TaskId, context: &str, state: TaskState) -> Task {
        Task {
            id,
            context_id: ContextId::new(context),
            status: TaskStatus::new(state),
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    /// `true` when `result` is `InvalidParams` whose message contains `needle`
    /// (an empty `needle` matches any `InvalidParams`).
    fn is_invalid_params(result: &ServerResult<SendMessageResult>, needle: &str) -> bool {
        matches!(result, Err(ServerError::InvalidParams(msg)) if msg.contains(needle))
    }

    // ── Validation ──────────────────────────────────────────────────────────

    #[tokio::test]
    async fn empty_message_parts_returns_invalid_params() {
        let handler = make_handler();
        let mut params = make_params(None);
        params.message.parts.clear();

        let outcome = handler.on_send_message(params, false, None).await;
        assert!(
            is_invalid_params(&outcome, ""),
            "expected InvalidParams for empty parts"
        );
    }

    #[tokio::test]
    async fn oversized_message_metadata_returns_invalid_params() {
        let handler = make_handler();
        let mut params = make_params(None);
        params.message.metadata = Some(serde_json::Value::String("x".repeat(1_100_000)));

        let outcome = handler.on_send_message(params, false, None).await;
        assert!(
            is_invalid_params(&outcome, ""),
            "expected InvalidParams for oversized message metadata"
        );
    }

    #[tokio::test]
    async fn oversized_request_metadata_returns_invalid_params() {
        let handler = make_handler();
        let mut params = make_params(None);
        params.metadata = Some(serde_json::Value::String("x".repeat(1_100_000)));

        let outcome = handler.on_send_message(params, false, None).await;
        assert!(
            is_invalid_params(&outcome, ""),
            "expected InvalidParams for oversized request metadata"
        );
    }

    #[tokio::test]
    async fn empty_context_id_returns_invalid_params() {
        let handler = make_handler();

        let outcome = handler
            .on_send_message(make_params(Some("")), false, None)
            .await;
        assert!(
            is_invalid_params(&outcome, ""),
            "expected InvalidParams for empty context_id"
        );
    }

    #[tokio::test]
    async fn too_long_context_id_returns_invalid_params() {
        let handler = short_id_handler();
        let long_ctx = "x".repeat(20);

        let outcome = handler
            .on_send_message(make_params(Some(&long_ctx)), false, None)
            .await;
        assert!(
            is_invalid_params(&outcome, "maximum length"),
            "expected InvalidParams for too-long context_id"
        );
    }

    #[tokio::test]
    async fn too_long_task_id_returns_invalid_params() {
        let handler = short_id_handler();
        let mut params = make_params(None);
        params.message.task_id = Some(TaskId::new("a".repeat(20)));

        let outcome = handler.on_send_message(params, false, None).await;
        assert!(
            is_invalid_params(&outcome, "maximum length"),
            "expected InvalidParams for too-long task_id"
        );
    }

    #[tokio::test]
    async fn empty_task_id_returns_invalid_params() {
        let handler = make_handler();
        let mut params = make_params(None);
        params.message.task_id = Some(TaskId::new(""));

        let outcome = handler.on_send_message(params, false, None).await;
        assert!(
            is_invalid_params(&outcome, "empty"),
            "expected InvalidParams for empty task_id"
        );
    }

    #[tokio::test]
    async fn task_id_mismatch_returns_invalid_params() {
        let handler = make_handler();
        handler
            .task_store
            .save(stored_task(
                TaskId::new("stored-task-id"),
                "ctx-existing",
                TaskState::Completed,
            ))
            .await
            .unwrap();

        let mut params = make_params(Some("ctx-existing"));
        params.message.task_id = Some(TaskId::new("different-task-id"));

        let outcome = handler.on_send_message(params, false, None).await;
        assert!(
            is_invalid_params(&outcome, "does not match"),
            "expected InvalidParams for task_id mismatch"
        );
    }

    #[tokio::test]
    async fn duplicate_task_id_without_context_match_returns_error() {
        let handler = make_handler();
        handler
            .task_store
            .save(stored_task(
                TaskId::new("dup-task"),
                "other-ctx",
                TaskState::Completed,
            ))
            .await
            .unwrap();

        let mut params = make_params(Some("brand-new-ctx"));
        params.message.task_id = Some(TaskId::new("dup-task"));

        let outcome = handler.on_send_message(params, false, None).await;
        assert!(
            is_invalid_params(&outcome, "already exists"),
            "expected InvalidParams for duplicate task_id"
        );
    }

    // ── Happy paths ─────────────────────────────────────────────────────────

    #[tokio::test]
    async fn valid_message_returns_ok() {
        let handler = make_handler();

        let send_result = handler
            .on_send_message(make_params(None), false, None)
            .await
            .expect("expected Ok for valid message");
        assert!(
            matches!(
                send_result,
                SendMessageResult::Response(SendMessageResponse::Task(_))
            ),
            "expected Response(Task) for non-streaming send"
        );
    }

    #[tokio::test]
    async fn return_immediately_returns_task() {
        let handler = make_handler();
        let mut params = make_params(None);
        params.configuration = Some(SendMessageConfiguration {
            accepted_output_modes: vec!["text/plain".into()],
            task_push_notification_config: None,
            history_length: None,
            return_immediately: Some(true),
        });

        let outcome = handler.on_send_message(params, false, None).await;
        assert!(
            matches!(
                outcome,
                Ok(SendMessageResult::Response(SendMessageResponse::Task(_)))
            ),
            "expected Response(Task) for return_immediately=true"
        );
    }

    #[tokio::test]
    async fn send_message_with_request_metadata() {
        let handler = make_handler();
        let mut params = make_params(None);
        params.metadata = Some(serde_json::json!({"key": "value"}));

        let outcome = handler.on_send_message(params, false, None).await;
        assert!(
            outcome.is_ok(),
            "send_message with request metadata should succeed"
        );
    }

    #[tokio::test]
    async fn streaming_mode_returns_stream_result() {
        let handler = make_handler();

        let outcome = handler.on_send_message(make_params(None), true, None).await;
        assert!(
            matches!(outcome, Ok(SendMessageResult::Stream(_))),
            "expected Stream result in streaming mode"
        );
    }

    #[tokio::test]
    async fn send_message_with_stored_task_continuation() {
        let handler = make_handler();
        handler
            .task_store
            .save(stored_task(
                TaskId::new("existing-task"),
                "continue-ctx",
                TaskState::Completed,
            ))
            .await
            .unwrap();

        let outcome = handler
            .on_send_message(make_params(Some("continue-ctx")), false, None)
            .await;
        assert!(
            outcome.is_ok(),
            "send_message with existing context should succeed"
        );
    }

    #[tokio::test]
    async fn send_message_with_headers() {
        let handler = make_handler();
        let headers = HashMap::from([(
            "authorization".to_string(),
            "Bearer test-token".to_string(),
        )]);

        let send_result = handler
            .on_send_message(make_params(None), false, Some(&headers))
            .await
            .expect("send_message with headers should succeed");
        assert!(
            matches!(
                send_result,
                SendMessageResult::Response(SendMessageResponse::Task(_))
            ),
            "expected Response(Task) for send with headers"
        );
    }

    #[tokio::test]
    async fn send_message_with_tenant() {
        let handler = make_handler();
        let mut params = make_params(None);
        params.tenant = Some("test-tenant".to_string());

        let send_result = handler
            .on_send_message(params, false, None)
            .await
            .expect("send_message with tenant should succeed");
        assert!(
            matches!(
                send_result,
                SendMessageResult::Response(SendMessageResponse::Task(_))
            ),
            "expected Response(Task) for send with tenant"
        );
    }

    #[tokio::test]
    async fn input_required_continuation_reuses_task_id() {
        let handler = make_handler();
        let existing_task_id = TaskId::new("input-required-task");
        handler
            .task_store
            .save(stored_task(
                existing_task_id.clone(),
                "ctx-input",
                TaskState::InputRequired,
            ))
            .await
            .unwrap();

        let mut params = make_params(Some("ctx-input"));
        params.message.task_id = Some(existing_task_id.clone());

        let send_result = handler
            .on_send_message(params, false, None)
            .await
            .expect("continuation should succeed");
        if let SendMessageResult::Response(SendMessageResponse::Task(t)) = send_result {
            assert_eq!(
                t.id, existing_task_id,
                "task_id should be reused for input-required continuation"
            );
        } else {
            panic!("expected Response(Task)");
        }
    }

    // ── Interceptor rejection (error metrics path) ──────────────────────────

    #[tokio::test]
    async fn send_message_error_path_records_metrics() {
        let handler = rejecting_handler();

        let outcome = handler.on_send_message(make_params(None), false, None).await;
        assert!(
            outcome.is_err(),
            "send_message should fail when interceptor rejects, exercising error metrics path"
        );
    }

    #[tokio::test]
    async fn send_streaming_message_error_path_records_metrics() {
        let handler = rejecting_handler();

        let outcome = handler.on_send_message(make_params(None), true, None).await;
        assert!(
            outcome.is_err(),
            "streaming send_message should fail when interceptor rejects"
        );
    }

    // ── Executor failures and timeouts ──────────────────────────────────────

    #[tokio::test]
    async fn executor_timeout_returns_failed_task() {
        let handler =
            RequestHandlerBuilder::new(SleepingExecutor(std::time::Duration::from_secs(60)))
                .with_executor_timeout(std::time::Duration::from_millis(50))
                .build()
                .unwrap();

        let outcome = handler.on_send_message(make_params(None), false, None).await;
        assert!(
            outcome.is_ok(),
            "executor timeout should still return a task result"
        );
    }

    #[tokio::test]
    async fn executor_failure_writes_failed_event() {
        let handler = RequestHandlerBuilder::new(FailingExecutor("executor exploded"))
            .build()
            .unwrap();

        let outcome = handler.on_send_message(make_params(None), false, None).await;
        assert!(
            outcome.is_ok(),
            "executor failure should produce a task result"
        );
    }

    #[tokio::test]
    async fn streaming_executor_failure_writes_error_event() {
        let handler = RequestHandlerBuilder::new(FailingExecutor("streaming fail"))
            .build()
            .unwrap();

        let outcome = handler.on_send_message(make_params(None), true, None).await;
        assert!(
            matches!(outcome, Ok(SendMessageResult::Stream(_))),
            "streaming executor failure should still return stream"
        );
    }

    // ── Cancellation-token bookkeeping ──────────────────────────────────────

    #[tokio::test]
    async fn cancellation_token_sweep_runs_when_map_is_full() {
        let handler =
            RequestHandlerBuilder::new(SleepingExecutor(std::time::Duration::from_secs(10)))
                .with_handler_limits(
                    crate::handler::limits::HandlerLimits::default()
                        .with_max_cancellation_tokens(2),
                )
                .build()
                .unwrap();

        for _ in 0..3 {
            let _ = handler.on_send_message(make_params(None), true, None).await;
        }
        handler.shutdown().await;
    }

    #[tokio::test]
    async fn stale_cancellation_tokens_cleaned_up() {
        let handler =
            RequestHandlerBuilder::new(SleepingExecutor(std::time::Duration::from_secs(10)))
                .with_handler_limits(
                    crate::handler::limits::HandlerLimits::default()
                        .with_max_cancellation_tokens(2)
                        .with_max_token_age(std::time::Duration::from_millis(1)),
                )
                .build()
                .unwrap();

        for _ in 0..2 {
            let _ = handler.on_send_message(make_params(None), true, None).await;
        }

        // Let both registered tokens exceed the 1 ms age limit.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let _ = handler.on_send_message(make_params(None), true, None).await;

        handler.shutdown().await;
    }
}
1078}