1use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Instant;
11
12use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
13use a2a_protocol_types::params::MessageSendParams;
14use a2a_protocol_types::responses::SendMessageResponse;
15use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};
16
17use crate::error::{ServerError, ServerResult};
18use crate::request_context::RequestContext;
19use crate::streaming::EventQueueWriter;
20
21use super::helpers::{build_call_context, validate_id};
22use super::{CancellationEntry, RequestHandler, SendMessageResult};
23
24fn json_byte_len(value: &serde_json::Value) -> serde_json::Result<usize> {
26 struct CountWriter(usize);
27 impl std::io::Write for CountWriter {
28 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
29 self.0 += buf.len();
30 Ok(buf.len())
31 }
32 fn flush(&mut self) -> std::io::Result<()> {
33 Ok(())
34 }
35 }
36 let mut w = CountWriter(0);
37 serde_json::to_writer(&mut w, value)?;
38 Ok(w.0)
39}
40
41impl RequestHandler {
42 pub async fn on_send_message(
51 &self,
52 params: MessageSendParams,
53 streaming: bool,
54 headers: Option<&HashMap<String, String>>,
55 ) -> ServerResult<SendMessageResult> {
56 let method_name = if streaming {
57 "SendStreamingMessage"
58 } else {
59 "SendMessage"
60 };
61 let start = Instant::now();
62 trace_info!(method = method_name, streaming, "handling send message");
63 self.metrics.on_request(method_name);
64
65 let tenant = params.tenant.clone().unwrap_or_default();
66 let result = crate::store::tenant::TenantContext::scope(tenant, async {
67 self.send_message_inner(params, streaming, method_name, headers)
68 .await
69 })
70 .await;
71 let elapsed = start.elapsed();
72 match &result {
73 Ok(_) => {
74 self.metrics.on_response(method_name);
75 self.metrics.on_latency(method_name, elapsed);
76 }
77 Err(e) => {
78 self.metrics.on_error(method_name, &e.to_string());
79 self.metrics.on_latency(method_name, elapsed);
80 }
81 }
82 result
83 }
84
85 #[allow(clippy::too_many_lines)]
88 async fn send_message_inner(
89 &self,
90 params: MessageSendParams,
91 streaming: bool,
92 method_name: &str,
93 headers: Option<&HashMap<String, String>>,
94 ) -> ServerResult<SendMessageResult> {
95 let call_ctx = build_call_context(method_name, headers);
96 self.interceptors.run_before(&call_ctx).await?;
97
98 if let Some(ref ctx_id) = params.message.context_id {
100 validate_id(&ctx_id.0, "context_id", self.limits.max_id_length)?;
101 }
102 if let Some(ref task_id) = params.message.task_id {
103 validate_id(&task_id.0, "task_id", self.limits.max_id_length)?;
104 }
105
106 if params.message.parts.is_empty() {
108 return Err(ServerError::InvalidParams(
109 "message must contain at least one part".into(),
110 ));
111 }
112
113 let max_meta = self.limits.max_metadata_size;
116 if let Some(ref meta) = params.message.metadata {
117 let meta_size = json_byte_len(meta).map_err(|_| {
118 ServerError::InvalidParams("message metadata is not serializable".into())
119 })?;
120 if meta_size > max_meta {
121 return Err(ServerError::InvalidParams(format!(
122 "message metadata exceeds maximum size ({meta_size} bytes, max {max_meta})"
123 )));
124 }
125 }
126 if let Some(ref meta) = params.metadata {
127 let meta_size = json_byte_len(meta).map_err(|_| {
128 ServerError::InvalidParams("request metadata is not serializable".into())
129 })?;
130 if meta_size > max_meta {
131 return Err(ServerError::InvalidParams(format!(
132 "request metadata exceeds maximum size ({meta_size} bytes, max {max_meta})"
133 )));
134 }
135 }
136
137 let context_id = params
140 .message
141 .context_id
142 .as_ref()
143 .map_or_else(|| uuid::Uuid::new_v4().to_string(), |c| c.0.clone());
144
145 let context_lock = {
149 let mut locks = self.context_locks.write().await;
150 if locks.len() >= self.limits.max_context_locks {
154 locks.retain(|_, v| Arc::strong_count(v) > 1);
155 }
156 locks.entry(context_id.clone()).or_default().clone()
157 };
158 let context_guard = context_lock.lock().await;
159
160 let stored_task = self.find_task_by_context(&context_id).await?;
162
163 let task_id = if let Some(ref msg_task_id) = params.message.task_id {
167 if let Some(ref stored) = stored_task {
168 if msg_task_id != &stored.id {
169 return Err(ServerError::InvalidParams(
170 "message task_id does not match task found for context".into(),
171 ));
172 }
173 if stored.status.state.is_terminal() {
177 return Err(ServerError::UnsupportedOperation(format!(
178 "task {} is in terminal state '{}' and cannot accept new messages",
179 stored.id, stored.status.state
180 )));
181 }
182 } else {
184 let exists = self.task_store.get(msg_task_id).await?.is_some();
188 if !exists {
189 return Err(ServerError::TaskNotFound(msg_task_id.clone()));
190 }
191 return Err(ServerError::InvalidParams(
193 "task_id exists but belongs to a different context".into(),
194 ));
195 }
196 msg_task_id.clone()
197 } else {
198 TaskId::new(uuid::Uuid::new_v4().to_string())
202 };
203
204 let return_immediately = params
206 .configuration
207 .as_ref()
208 .and_then(|c| c.return_immediately)
209 .unwrap_or(false);
210
211 trace_debug!(
213 task_id = %task_id,
214 context_id = %context_id,
215 "creating task"
216 );
217 let task = Task {
218 id: task_id.clone(),
219 context_id: ContextId::new(&context_id),
220 status: TaskStatus::with_timestamp(TaskState::Submitted),
221 history: None,
222 artifacts: None,
223 metadata: None,
224 };
225
226 let mut ctx = RequestContext::new(params.message, task_id.clone(), context_id);
229 if let Some(stored) = stored_task {
230 ctx = ctx.with_stored_task(stored);
231 }
232 if let Some(meta) = params.metadata {
233 ctx = ctx.with_metadata(meta);
234 }
235
236 {
241 let stale_ids: Vec<TaskId> = {
245 let tokens = self.cancellation_tokens.read().await;
246 if tokens.len() >= self.limits.max_cancellation_tokens {
247 let now = Instant::now();
248 tokens
249 .iter()
250 .filter(|(_, entry)| {
251 entry.token.is_cancelled()
252 || now.duration_since(entry.created_at) >= self.limits.max_token_age
253 })
254 .map(|(id, _)| id.clone())
255 .collect()
256 } else {
257 Vec::new()
258 }
259 };
260
261 if !stale_ids.is_empty() {
263 let mut tokens = self.cancellation_tokens.write().await;
264 for id in &stale_ids {
265 tokens.remove(id);
266 }
267 }
268
269 let mut tokens = self.cancellation_tokens.write().await;
271 tokens.insert(
272 task_id.clone(),
273 CancellationEntry {
274 token: ctx.cancellation_token.clone(),
275 created_at: Instant::now(),
276 },
277 );
278 }
279
280 self.task_store.save(&task).await?;
281
282 drop(context_guard);
285
286 let (writer, reader, persistence_rx) = if streaming {
290 let (w, r, p) = self
291 .event_queue_manager
292 .get_or_create_with_persistence(&task_id)
293 .await;
294 let r = match r {
295 Some(r) => r,
296 None => {
297 self.event_queue_manager
299 .subscribe(&task_id)
300 .await
301 .ok_or_else(|| {
302 ServerError::Internal("event queue disappeared during subscribe".into())
303 })?
304 }
305 };
306 (w, r, p)
307 } else {
308 let (w, r) = self.event_queue_manager.get_or_create(&task_id).await;
309 let r = match r {
310 Some(r) => r,
311 None => {
312 self.event_queue_manager
314 .subscribe(&task_id)
315 .await
316 .ok_or_else(|| {
317 ServerError::Internal("event queue disappeared during subscribe".into())
318 })?
319 }
320 };
321 (w, r, None)
322 };
323
324 let executor = Arc::clone(&self.executor);
328 let task_id_for_cleanup = task_id.clone();
329 let event_queue_mgr = self.event_queue_manager.clone();
330 let cancel_tokens = Arc::clone(&self.cancellation_tokens);
331 let executor_timeout = self.executor_timeout;
332 let executor_handle = tokio::spawn(async move {
333 trace_debug!(task_id = %ctx.task_id, "executor started");
334
335 #[allow(clippy::items_after_statements)]
340 struct CleanupGuard {
341 task_id: Option<TaskId>,
342 queue_mgr: crate::streaming::EventQueueManager,
343 tokens: std::sync::Arc<tokio::sync::RwLock<HashMap<TaskId, CancellationEntry>>>,
344 }
345 #[allow(clippy::items_after_statements)]
346 impl Drop for CleanupGuard {
347 fn drop(&mut self) {
348 if let Some(tid) = self.task_id.take() {
349 let qmgr = self.queue_mgr.clone();
350 let tokens = std::sync::Arc::clone(&self.tokens);
351 tokio::task::spawn(async move {
352 qmgr.destroy(&tid).await;
353 tokens.write().await.remove(&tid);
354 });
355 }
356 }
357 }
358 let mut cleanup_guard = CleanupGuard {
359 task_id: Some(task_id_for_cleanup.clone()),
360 queue_mgr: event_queue_mgr.clone(),
361 tokens: Arc::clone(&cancel_tokens),
362 };
363
364 let result = {
366 let exec_future = if let Some(timeout) = executor_timeout {
367 tokio::time::timeout(timeout, executor.execute(&ctx, writer.as_ref()))
368 .await
369 .unwrap_or_else(|_| {
370 Err(a2a_protocol_types::error::A2aError::internal(format!(
371 "executor timed out after {}s",
372 timeout.as_secs()
373 )))
374 })
375 } else {
376 executor.execute(&ctx, writer.as_ref()).await
377 };
378 exec_future
379 };
380
381 if let Err(ref e) = result {
382 trace_error!(task_id = %ctx.task_id, error = %e, "executor failed");
383 let fail_event = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
385 task_id: ctx.task_id.clone(),
386 context_id: ContextId::new(ctx.context_id.clone()),
387 status: TaskStatus::with_timestamp(TaskState::Failed),
388 metadata: Some(serde_json::json!({ "error": e.to_string() })),
389 });
390 if let Err(_write_err) = writer.write(fail_event).await {
391 trace_error!(
392 task_id = %ctx.task_id,
393 error = %_write_err,
394 "failed to write failure event to queue"
395 );
396 }
397 }
398 drop(writer);
400 event_queue_mgr.destroy(&task_id_for_cleanup).await;
403 cancel_tokens.write().await.remove(&task_id_for_cleanup);
404 cleanup_guard.task_id = None;
405 });
406
407 self.interceptors.run_after(&call_ctx).await?;
408
409 if streaming {
410 self.spawn_background_event_processor(task_id.clone(), executor_handle, persistence_rx);
420 let mut reader = reader;
423 reader.set_first_event(StreamResponse::Task(task.clone()));
424 Ok(SendMessageResult::Stream(reader))
425 } else if return_immediately {
426 Ok(SendMessageResult::Response(SendMessageResponse::Task(task)))
428 } else {
429 let final_task = self
432 .collect_events(reader, task_id.clone(), executor_handle)
433 .await?;
434 Ok(SendMessageResult::Response(SendMessageResponse::Task(
435 final_task,
436 )))
437 }
438 }
439}
440
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    use a2a_protocol_types::message::{Message, MessageId, MessageRole, Part};
    use a2a_protocol_types::params::{MessageSendParams, SendMessageConfiguration};
    use a2a_protocol_types::task::ContextId;

    use crate::agent_executor;
    use crate::builder::RequestHandlerBuilder;
    use crate::call_context::CallContext;
    use crate::handler::limits::HandlerLimits;
    use crate::interceptor::ServerInterceptor;

    /// Boxed future returned by executor and interceptor hooks.
    type HookFuture<'a> = std::pin::Pin<
        Box<
            dyn std::future::Future<Output = a2a_protocol_types::error::A2aResult<()>>
                + Send
                + 'a,
        >,
    >;

    struct DummyExecutor;
    agent_executor!(DummyExecutor, |_ctx, _queue| async { Ok(()) });

    /// Executor that sleeps for the wrapped duration, then succeeds.
    struct SleepExecutor(Duration);

    impl crate::executor::AgentExecutor for SleepExecutor {
        fn execute<'a>(
            &'a self,
            _ctx: &'a crate::request_context::RequestContext,
            _queue: &'a dyn crate::streaming::EventQueueWriter,
        ) -> HookFuture<'a> {
            let delay = self.0;
            Box::pin(async move {
                tokio::time::sleep(delay).await;
                Ok(())
            })
        }
    }

    /// Executor that fails immediately with the wrapped message.
    struct FailExecutor(&'static str);

    impl crate::executor::AgentExecutor for FailExecutor {
        fn execute<'a>(
            &'a self,
            _ctx: &'a crate::request_context::RequestContext,
            _queue: &'a dyn crate::streaming::EventQueueWriter,
        ) -> HookFuture<'a> {
            let message = self.0;
            Box::pin(async move { Err(a2a_protocol_types::error::A2aError::internal(message)) })
        }
    }

    /// Interceptor whose `before` hook rejects every call.
    struct FailInterceptor;

    impl ServerInterceptor for FailInterceptor {
        fn before<'a>(&'a self, _ctx: &'a CallContext) -> HookFuture<'a> {
            Box::pin(async {
                Err(a2a_protocol_types::error::A2aError::internal(
                    "forced failure",
                ))
            })
        }

        fn after<'a>(&'a self, _ctx: &'a CallContext) -> HookFuture<'a> {
            Box::pin(async { Ok(()) })
        }
    }

    fn make_handler() -> RequestHandler {
        RequestHandlerBuilder::new(DummyExecutor)
            .build()
            .expect("default build should succeed")
    }

    /// Handler with the dummy executor and custom limits.
    fn make_handler_with_limits(limits: HandlerLimits) -> RequestHandler {
        RequestHandlerBuilder::new(DummyExecutor)
            .with_handler_limits(limits)
            .build()
            .unwrap()
    }

    /// Handler whose only interceptor rejects every request.
    fn make_rejecting_handler() -> RequestHandler {
        RequestHandlerBuilder::new(DummyExecutor)
            .with_interceptor(FailInterceptor)
            .build()
            .unwrap()
    }

    fn make_params(context_id: Option<&str>) -> MessageSendParams {
        MessageSendParams {
            message: Message {
                id: MessageId::new("msg-1"),
                role: MessageRole::User,
                parts: vec![Part::text("hello")],
                context_id: context_id.map(ContextId::new),
                task_id: None,
                reference_task_ids: None,
                extensions: None,
                metadata: None,
            },
            configuration: None,
            metadata: None,
            tenant: None,
        }
    }

    /// Persists a task with the given id, context and state directly in the store.
    async fn seed_task(handler: &RequestHandler, id: &str, context: &str, state: TaskState) {
        let task = Task {
            id: TaskId::new(id),
            context_id: ContextId::new(context),
            status: TaskStatus::new(state),
            history: None,
            artifacts: None,
            metadata: None,
        };
        handler.task_store.save(&task).await.unwrap();
    }

    #[tokio::test]
    async fn empty_message_parts_returns_invalid_params() {
        let handler = make_handler();
        let mut params = make_params(None);
        params.message.parts = vec![];

        let result = handler.on_send_message(params, false, None).await;

        assert!(
            matches!(result, Err(ServerError::InvalidParams(_))),
            "expected InvalidParams for empty parts"
        );
    }

    #[tokio::test]
    async fn oversized_message_metadata_returns_invalid_params() {
        let handler = make_handler();
        let mut params = make_params(None);
        params.message.metadata = Some(serde_json::json!("x".repeat(1_100_000)));

        let result = handler.on_send_message(params, false, None).await;

        assert!(
            matches!(result, Err(ServerError::InvalidParams(_))),
            "expected InvalidParams for oversized message metadata"
        );
    }

    #[tokio::test]
    async fn oversized_request_metadata_returns_invalid_params() {
        let handler = make_handler();
        let mut params = make_params(None);
        params.metadata = Some(serde_json::json!("x".repeat(1_100_000)));

        let result = handler.on_send_message(params, false, None).await;

        assert!(
            matches!(result, Err(ServerError::InvalidParams(_))),
            "expected InvalidParams for oversized request metadata"
        );
    }

    #[tokio::test]
    async fn valid_message_returns_ok() {
        let handler = make_handler();

        let result = handler.on_send_message(make_params(None), false, None).await;

        let send_result = result.expect("expected Ok for valid message");
        assert!(
            matches!(
                send_result,
                SendMessageResult::Response(SendMessageResponse::Task(_))
            ),
            "expected Response(Task) for non-streaming send"
        );
    }

    #[tokio::test]
    async fn return_immediately_returns_task() {
        let handler = make_handler();
        let mut params = make_params(None);
        params.configuration = Some(SendMessageConfiguration {
            accepted_output_modes: vec!["text/plain".into()],
            task_push_notification_config: None,
            history_length: None,
            return_immediately: Some(true),
        });

        let result = handler.on_send_message(params, false, None).await;

        assert!(
            matches!(
                result,
                Ok(SendMessageResult::Response(SendMessageResponse::Task(_)))
            ),
            "expected Response(Task) for return_immediately=true"
        );
    }

    #[tokio::test]
    async fn empty_context_id_returns_invalid_params() {
        let handler = make_handler();

        let result = handler
            .on_send_message(make_params(Some("")), false, None)
            .await;

        assert!(
            matches!(result, Err(ServerError::InvalidParams(_))),
            "expected InvalidParams for empty context_id"
        );
    }

    #[tokio::test]
    async fn too_long_context_id_returns_invalid_params() {
        let handler = make_handler_with_limits(HandlerLimits::default().with_max_id_length(10));
        let long_ctx = "x".repeat(20);

        let result = handler
            .on_send_message(make_params(Some(&long_ctx)), false, None)
            .await;

        assert!(
            matches!(result, Err(ServerError::InvalidParams(ref msg)) if msg.contains("maximum length")),
            "expected InvalidParams for too-long context_id"
        );
    }

    #[tokio::test]
    async fn too_long_task_id_returns_invalid_params() {
        let handler = make_handler_with_limits(HandlerLimits::default().with_max_id_length(10));
        let mut params = make_params(None);
        params.message.task_id = Some(TaskId::new("a".repeat(20)));

        let result = handler.on_send_message(params, false, None).await;

        assert!(
            matches!(result, Err(ServerError::InvalidParams(ref msg)) if msg.contains("maximum length")),
            "expected InvalidParams for too-long task_id"
        );
    }

    #[tokio::test]
    async fn empty_task_id_returns_invalid_params() {
        let handler = make_handler();
        let mut params = make_params(None);
        params.message.task_id = Some(TaskId::new(""));

        let result = handler.on_send_message(params, false, None).await;

        assert!(
            matches!(result, Err(ServerError::InvalidParams(ref msg)) if msg.contains("empty")),
            "expected InvalidParams for empty task_id"
        );
    }

    #[tokio::test]
    async fn task_id_mismatch_returns_invalid_params() {
        let handler = make_handler();
        seed_task(&handler, "stored-task-id", "ctx-existing", TaskState::InputRequired).await;

        let mut params = make_params(Some("ctx-existing"));
        params.message.task_id = Some(TaskId::new("different-task-id"));

        let result = handler.on_send_message(params, false, None).await;

        assert!(
            matches!(result, Err(ServerError::InvalidParams(ref msg)) if msg.contains("does not match")),
            "expected InvalidParams for task_id mismatch, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn send_message_with_request_metadata() {
        let handler = make_handler();
        let mut params = make_params(None);
        params.metadata = Some(serde_json::json!({"key": "value"}));

        let result = handler.on_send_message(params, false, None).await;

        assert!(
            result.is_ok(),
            "send_message with request metadata should succeed"
        );
    }

    #[tokio::test]
    async fn send_message_error_path_records_metrics() {
        let handler = make_rejecting_handler();

        let result = handler.on_send_message(make_params(None), false, None).await;

        assert!(
            result.is_err(),
            "send_message should fail when interceptor rejects, exercising error metrics path"
        );
    }

    #[tokio::test]
    async fn send_streaming_message_error_path_records_metrics() {
        let handler = make_rejecting_handler();

        let result = handler.on_send_message(make_params(None), true, None).await;

        assert!(
            result.is_err(),
            "streaming send_message should fail when interceptor rejects"
        );
    }

    #[tokio::test]
    async fn streaming_mode_returns_stream_result() {
        let handler = make_handler();

        let result = handler.on_send_message(make_params(None), true, None).await;

        assert!(
            matches!(result, Ok(SendMessageResult::Stream(_))),
            "expected Stream result in streaming mode"
        );
    }

    #[tokio::test]
    async fn send_message_with_stored_task_continuation() {
        let handler = make_handler();
        seed_task(&handler, "existing-task", "continue-ctx", TaskState::InputRequired).await;

        let result = handler
            .on_send_message(make_params(Some("continue-ctx")), false, None)
            .await;

        assert!(
            result.is_ok(),
            "send_message with existing non-terminal context should succeed"
        );
    }

    #[tokio::test]
    async fn send_message_to_terminal_task_returns_unsupported_operation() {
        let handler = make_handler();
        seed_task(&handler, "done-task", "done-ctx", TaskState::Completed).await;

        let mut params = make_params(Some("done-ctx"));
        params.message.task_id = Some(TaskId::new("done-task"));
        let result = handler.on_send_message(params, false, None).await;

        assert!(
            matches!(result, Err(ServerError::UnsupportedOperation(ref msg)) if msg.contains("terminal")),
            "expected UnsupportedOperation for terminal task, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn send_message_to_terminal_context_without_task_id_creates_new_task() {
        let handler = make_handler();
        seed_task(&handler, "old-task", "reuse-ctx", TaskState::Completed).await;

        let result = handler
            .on_send_message(make_params(Some("reuse-ctx")), false, None)
            .await;

        assert!(
            result.is_ok(),
            "should create new task on terminal context, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn send_message_with_headers() {
        let handler = make_handler();
        let mut headers = HashMap::new();
        headers.insert("authorization".to_string(), "Bearer test-token".to_string());

        let result = handler
            .on_send_message(make_params(None), false, Some(&headers))
            .await;

        let send_result = result.expect("send_message with headers should succeed");
        assert!(
            matches!(
                send_result,
                SendMessageResult::Response(SendMessageResponse::Task(_))
            ),
            "expected Response(Task) for send with headers"
        );
    }

    #[tokio::test]
    async fn duplicate_task_id_without_context_match_returns_error() {
        let handler = make_handler();
        seed_task(&handler, "dup-task", "other-ctx", TaskState::Completed).await;

        let mut params = make_params(Some("brand-new-ctx"));
        params.message.task_id = Some(TaskId::new("dup-task"));

        let result = handler.on_send_message(params, false, None).await;

        assert!(
            matches!(result, Err(ServerError::InvalidParams(ref msg)) if msg.contains("different context")),
            "expected InvalidParams for task_id in different context, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn unknown_task_id_returns_task_not_found() {
        let handler = make_handler();
        let mut params = make_params(Some("fresh-ctx"));
        params.message.task_id = Some(TaskId::new("nonexistent-task"));

        let result = handler.on_send_message(params, false, None).await;

        assert!(
            matches!(result, Err(ServerError::TaskNotFound(_))),
            "expected TaskNotFound for unknown task_id, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn send_message_with_tenant() {
        let handler = make_handler();
        let mut params = make_params(None);
        params.tenant = Some("test-tenant".to_string());

        let result = handler.on_send_message(params, false, None).await;

        let send_result = result.expect("send_message with tenant should succeed");
        assert!(
            matches!(
                send_result,
                SendMessageResult::Response(SendMessageResponse::Task(_))
            ),
            "expected Response(Task) for send with tenant"
        );
    }

    #[tokio::test]
    async fn executor_timeout_returns_failed_task() {
        let handler = RequestHandlerBuilder::new(SleepExecutor(Duration::from_secs(60)))
            .with_executor_timeout(Duration::from_millis(50))
            .build()
            .unwrap();

        let result = handler.on_send_message(make_params(None), false, None).await;

        assert!(
            result.is_ok(),
            "executor timeout should still return a task result"
        );
    }

    #[tokio::test]
    async fn executor_failure_writes_failed_event() {
        let handler = RequestHandlerBuilder::new(FailExecutor("executor exploded"))
            .build()
            .unwrap();

        let result = handler.on_send_message(make_params(None), false, None).await;

        assert!(
            result.is_ok(),
            "executor failure should produce a task result"
        );
    }

    #[tokio::test]
    async fn cancellation_token_sweep_runs_when_map_is_full() {
        let handler = RequestHandlerBuilder::new(SleepExecutor(Duration::from_secs(10)))
            .with_handler_limits(HandlerLimits::default().with_max_cancellation_tokens(2))
            .build()
            .unwrap();

        for _ in 0..3 {
            let result = handler.on_send_message(make_params(None), true, None).await;
            assert!(result.is_ok(), "send should succeed even at token capacity");
        }

        // Every executor is still sleeping and no token is cancelled or aged
        // out, so the sweep must not evict any live token.
        let live = handler.cancellation_tokens.read().await.len();
        assert_eq!(live, 3, "sweep must keep live, non-expired tokens");

        handler.shutdown().await;
    }

    #[tokio::test]
    async fn stale_cancellation_tokens_cleaned_up() {
        let handler = RequestHandlerBuilder::new(SleepExecutor(Duration::from_secs(10)))
            .with_handler_limits(
                HandlerLimits::default()
                    .with_max_cancellation_tokens(2)
                    .with_max_token_age(Duration::from_millis(1)),
            )
            .build()
            .unwrap();

        for _ in 0..2 {
            let _ = handler.on_send_message(make_params(None), true, None).await;
        }

        // Let both registered tokens exceed the 1ms maximum age.
        tokio::time::sleep(Duration::from_millis(50)).await;

        // The map is at capacity, so this send sweeps the aged tokens first.
        let _ = handler.on_send_message(make_params(None), true, None).await;

        let remaining = handler.cancellation_tokens.read().await.len();
        assert_eq!(
            remaining, 1,
            "aged-out tokens should be swept, leaving only the newest"
        );

        handler.shutdown().await;
    }

    #[tokio::test]
    async fn streaming_executor_failure_writes_error_event() {
        let handler = RequestHandlerBuilder::new(FailExecutor("streaming fail"))
            .build()
            .unwrap();

        let result = handler.on_send_message(make_params(None), true, None).await;

        assert!(
            matches!(result, Ok(SendMessageResult::Stream(_))),
            "streaming executor failure should still return stream"
        );
    }

    #[tokio::test]
    async fn input_required_continuation_reuses_task_id() {
        let handler = make_handler();
        seed_task(&handler, "input-required-task", "ctx-input", TaskState::InputRequired).await;

        let existing_task_id = TaskId::new("input-required-task");
        let mut params = make_params(Some("ctx-input"));
        params.message.task_id = Some(existing_task_id.clone());

        let result = handler.on_send_message(params, false, None).await;

        match result.expect("continuation should succeed") {
            SendMessageResult::Response(SendMessageResponse::Task(t)) => {
                assert_eq!(
                    t.id, existing_task_id,
                    "task_id should be reused for input-required continuation"
                );
            }
            _ => panic!("expected Response(Task)"),
        }
    }
}