1use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Instant;
11
12use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
13use a2a_protocol_types::params::MessageSendParams;
14use a2a_protocol_types::responses::SendMessageResponse;
15use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};
16
17use crate::error::{ServerError, ServerResult};
18use crate::request_context::RequestContext;
19use crate::streaming::EventQueueWriter;
20
21use super::helpers::{build_call_context, validate_id};
22use super::{CancellationEntry, RequestHandler, SendMessageResult};
23
/// Upper bound on the number of messages kept in a task's stored history.
///
/// When a send pushes the history past this limit, the oldest messages are
/// dropped so that only the most recent `MAX_TASK_HISTORY_MESSAGES` remain.
pub const MAX_TASK_HISTORY_MESSAGES: usize = 1024;
30
31fn shape_response_history(task: &mut Task, history_length: Option<u32>) {
41 task.history = match (task.history.take(), history_length) {
42 (Some(msgs), Some(n)) if n > 0 => {
43 let n = n as usize;
44 if msgs.len() > n {
45 Some(msgs[msgs.len() - n..].to_vec())
46 } else {
47 Some(msgs)
48 }
49 }
50 _ => None,
51 };
52}
53
54fn json_byte_len(value: &serde_json::Value) -> serde_json::Result<usize> {
56 struct CountWriter(usize);
57 impl std::io::Write for CountWriter {
58 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
59 self.0 += buf.len();
60 Ok(buf.len())
61 }
62 fn flush(&mut self) -> std::io::Result<()> {
63 Ok(())
64 }
65 }
66 let mut w = CountWriter(0);
67 serde_json::to_writer(&mut w, value)?;
68 Ok(w.0)
69}
70
71impl RequestHandler {
72 pub async fn on_send_message(
81 &self,
82 params: MessageSendParams,
83 streaming: bool,
84 headers: Option<&HashMap<String, String>>,
85 ) -> ServerResult<SendMessageResult> {
86 let method_name = if streaming {
87 "SendStreamingMessage"
88 } else {
89 "SendMessage"
90 };
91 let start = Instant::now();
92 trace_info!(method = method_name, streaming, "handling send message");
93 self.metrics.on_request(method_name);
94
95 let tenant = params.tenant.clone().unwrap_or_default();
96 let result = crate::store::tenant::TenantContext::scope(tenant, async {
97 self.send_message_inner(params, streaming, method_name, headers)
98 .await
99 })
100 .await;
101 let elapsed = start.elapsed();
102 match &result {
103 Ok(_) => {
104 self.metrics.on_response(method_name);
105 self.metrics.on_latency(method_name, elapsed);
106 }
107 Err(e) => {
108 self.metrics.on_error(method_name, &e.to_string());
109 self.metrics.on_latency(method_name, elapsed);
110 }
111 }
112 result
113 }
114
115 #[allow(clippy::too_many_lines)]
118 async fn send_message_inner(
119 &self,
120 params: MessageSendParams,
121 streaming: bool,
122 method_name: &str,
123 headers: Option<&HashMap<String, String>>,
124 ) -> ServerResult<SendMessageResult> {
125 let call_ctx = build_call_context(method_name, headers);
126 self.interceptors.run_before(&call_ctx).await?;
127
128 if let Some(ref ctx_id) = params.message.context_id {
130 validate_id(&ctx_id.0, "context_id", self.limits.max_id_length)?;
131 }
132 if let Some(ref task_id) = params.message.task_id {
133 validate_id(&task_id.0, "task_id", self.limits.max_id_length)?;
134 }
135
136 if params.message.parts.is_empty() {
138 return Err(ServerError::InvalidParams(
139 "message must contain at least one part".into(),
140 ));
141 }
142
143 let max_meta = self.limits.max_metadata_size;
146 if let Some(ref meta) = params.message.metadata {
147 let meta_size = json_byte_len(meta).map_err(|_| {
148 ServerError::InvalidParams("message metadata is not serializable".into())
149 })?;
150 if meta_size > max_meta {
151 return Err(ServerError::InvalidParams(format!(
152 "message metadata exceeds maximum size ({meta_size} bytes, max {max_meta})"
153 )));
154 }
155 }
156 if let Some(ref meta) = params.metadata {
157 let meta_size = json_byte_len(meta).map_err(|_| {
158 ServerError::InvalidParams("request metadata is not serializable".into())
159 })?;
160 if meta_size > max_meta {
161 return Err(ServerError::InvalidParams(format!(
162 "request metadata exceeds maximum size ({meta_size} bytes, max {max_meta})"
163 )));
164 }
165 }
166
167 let context_id = params
170 .message
171 .context_id
172 .as_ref()
173 .map_or_else(|| uuid::Uuid::new_v4().to_string(), |c| c.0.clone());
174
175 let context_lock = {
179 let mut locks = self.context_locks.write().await;
180 if locks.len() >= self.limits.max_context_locks {
184 locks.retain(|_, v| Arc::strong_count(v) > 1);
185 }
186 locks.entry(context_id.clone()).or_default().clone()
187 };
188 let context_guard = context_lock.lock().await;
189
190 let stored_task = self.find_task_by_context(&context_id).await?;
192
193 let task_id = if let Some(ref msg_task_id) = params.message.task_id {
197 if let Some(ref stored) = stored_task {
198 if msg_task_id != &stored.id {
199 return Err(ServerError::InvalidParams(
200 "message task_id does not match task found for context".into(),
201 ));
202 }
203 if stored.status.state.is_terminal() {
207 return Err(ServerError::UnsupportedOperation(format!(
208 "task {} is in terminal state '{}' and cannot accept new messages",
209 stored.id, stored.status.state
210 )));
211 }
212 } else {
214 let exists = self.task_store.get(msg_task_id).await?.is_some();
218 if !exists {
219 return Err(ServerError::TaskNotFound(msg_task_id.clone()));
220 }
221 return Err(ServerError::InvalidParams(
223 "task_id exists but belongs to a different context".into(),
224 ));
225 }
226 msg_task_id.clone()
227 } else {
228 TaskId::new(uuid::Uuid::new_v4().to_string())
232 };
233
234 let return_immediately = params
236 .configuration
237 .as_ref()
238 .and_then(|c| c.return_immediately)
239 .unwrap_or(false);
240 let response_history_length = params.configuration.as_ref().and_then(|c| c.history_length);
241
242 trace_debug!(
244 task_id = %task_id,
245 context_id = %context_id,
246 "creating task"
247 );
248 let mut history = stored_task
255 .as_ref()
256 .and_then(|s| s.history.clone())
257 .unwrap_or_default();
258 history.push(params.message.clone());
259 if history.len() > MAX_TASK_HISTORY_MESSAGES {
260 let excess = history.len() - MAX_TASK_HISTORY_MESSAGES;
261 history.drain(..excess);
262 }
263 let task = Task {
264 id: task_id.clone(),
265 context_id: ContextId::new(&context_id),
266 status: TaskStatus::with_timestamp(TaskState::Submitted),
267 history: Some(history),
268 artifacts: stored_task.as_ref().and_then(|s| s.artifacts.clone()),
269 metadata: stored_task.as_ref().and_then(|s| s.metadata.clone()),
270 };
271
272 let mut ctx = RequestContext::new(params.message, task_id.clone(), context_id);
275 if let Some(stored) = stored_task {
276 ctx = ctx.with_stored_task(stored);
277 }
278 if let Some(meta) = params.metadata {
279 ctx = ctx.with_metadata(meta);
280 }
281
282 {
287 let stale_ids: Vec<TaskId> = {
291 let tokens = self.cancellation_tokens.read().await;
292 if tokens.len() >= self.limits.max_cancellation_tokens {
293 let now = Instant::now();
294 tokens
295 .iter()
296 .filter(|(_, entry)| {
297 entry.token.is_cancelled()
298 || now.duration_since(entry.created_at) >= self.limits.max_token_age
299 })
300 .map(|(id, _)| id.clone())
301 .collect()
302 } else {
303 Vec::new()
304 }
305 };
306
307 if !stale_ids.is_empty() {
309 let mut tokens = self.cancellation_tokens.write().await;
310 for id in &stale_ids {
311 tokens.remove(id);
312 }
313 }
314
315 let mut tokens = self.cancellation_tokens.write().await;
317 tokens.insert(
318 task_id.clone(),
319 CancellationEntry {
320 token: ctx.cancellation_token.clone(),
321 created_at: Instant::now(),
322 },
323 );
324 }
325
326 self.task_store.save(&task).await?;
327
328 drop(context_guard);
331
332 let (writer, reader, persistence_rx) = if streaming {
336 let (w, r, p) = self
337 .event_queue_manager
338 .get_or_create_with_persistence(&task_id)
339 .await;
340 let r = match r {
341 Some(r) => r,
342 None => {
343 self.event_queue_manager
345 .subscribe(&task_id)
346 .await
347 .ok_or_else(|| {
348 ServerError::Internal("event queue disappeared during subscribe".into())
349 })?
350 }
351 };
352 (w, r, p)
353 } else {
354 let (w, r) = self.event_queue_manager.get_or_create(&task_id).await;
355 let r = match r {
356 Some(r) => r,
357 None => {
358 self.event_queue_manager
360 .subscribe(&task_id)
361 .await
362 .ok_or_else(|| {
363 ServerError::Internal("event queue disappeared during subscribe".into())
364 })?
365 }
366 };
367 (w, r, None)
368 };
369
370 let executor = Arc::clone(&self.executor);
374 let task_id_for_cleanup = task_id.clone();
375 let event_queue_mgr = self.event_queue_manager.clone();
376 let cancel_tokens = Arc::clone(&self.cancellation_tokens);
377 let executor_timeout = self.executor_timeout;
378 let executor_handle = tokio::spawn(async move {
379 trace_debug!(task_id = %ctx.task_id, "executor started");
380
381 #[allow(clippy::items_after_statements)]
386 struct CleanupGuard {
387 task_id: Option<TaskId>,
388 queue_mgr: crate::streaming::EventQueueManager,
389 tokens: std::sync::Arc<tokio::sync::RwLock<HashMap<TaskId, CancellationEntry>>>,
390 }
391 #[allow(clippy::items_after_statements)]
392 impl Drop for CleanupGuard {
393 fn drop(&mut self) {
394 if let Some(tid) = self.task_id.take() {
395 let qmgr = self.queue_mgr.clone();
396 let tokens = std::sync::Arc::clone(&self.tokens);
397 tokio::task::spawn(async move {
398 qmgr.destroy(&tid).await;
399 tokens.write().await.remove(&tid);
400 });
401 }
402 }
403 }
404 let mut cleanup_guard = CleanupGuard {
405 task_id: Some(task_id_for_cleanup.clone()),
406 queue_mgr: event_queue_mgr.clone(),
407 tokens: Arc::clone(&cancel_tokens),
408 };
409
410 let result = {
412 let exec_future = if let Some(timeout) = executor_timeout {
413 tokio::time::timeout(timeout, executor.execute(&ctx, writer.as_ref()))
414 .await
415 .unwrap_or_else(|_| {
416 Err(a2a_protocol_types::error::A2aError::internal(format!(
417 "executor timed out after {}s",
418 timeout.as_secs()
419 )))
420 })
421 } else {
422 executor.execute(&ctx, writer.as_ref()).await
423 };
424 exec_future
425 };
426
427 if let Err(ref e) = result {
428 trace_error!(task_id = %ctx.task_id, error = %e, "executor failed");
429 let fail_event = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
431 task_id: ctx.task_id.clone(),
432 context_id: ContextId::new(ctx.context_id.clone()),
433 status: TaskStatus::with_timestamp(TaskState::Failed),
434 metadata: Some(serde_json::json!({ "error": e.to_string() })),
435 });
436 if let Err(_write_err) = writer.write(fail_event).await {
437 trace_error!(
438 task_id = %ctx.task_id,
439 error = %_write_err,
440 "failed to write failure event to queue"
441 );
442 }
443 }
444 drop(writer);
446 event_queue_mgr.destroy(&task_id_for_cleanup).await;
449 cancel_tokens.write().await.remove(&task_id_for_cleanup);
450 cleanup_guard.task_id = None;
451 });
452
453 self.interceptors.run_after(&call_ctx).await?;
454
455 if streaming {
456 self.spawn_background_event_processor(task_id.clone(), executor_handle, persistence_rx);
466 let mut reader = reader;
469 let mut snapshot = task.clone();
470 shape_response_history(&mut snapshot, response_history_length);
471 reader.set_first_event(StreamResponse::Task(snapshot));
472 Ok(SendMessageResult::Stream(reader))
473 } else if return_immediately {
474 let mut task = task;
476 shape_response_history(&mut task, response_history_length);
477 Ok(SendMessageResult::Response(SendMessageResponse::Task(task)))
478 } else {
479 let mut final_task = self
482 .collect_events(reader, task_id.clone(), executor_handle)
483 .await?;
484 shape_response_history(&mut final_task, response_history_length);
485 Ok(SendMessageResult::Response(SendMessageResponse::Task(
486 final_task,
487 )))
488 }
489 }
490}
491
492#[cfg(test)]
493mod tests {
494 use super::*;
495 use a2a_protocol_types::message::{Message, MessageId, MessageRole, Part};
496 use a2a_protocol_types::params::{MessageSendParams, SendMessageConfiguration};
497 use a2a_protocol_types::task::ContextId;
498
499 use crate::agent_executor;
500 use crate::builder::RequestHandlerBuilder;
501
502 struct DummyExecutor;
503 agent_executor!(DummyExecutor, |_ctx, _queue| async { Ok(()) });
504
505 fn make_handler() -> RequestHandler {
506 RequestHandlerBuilder::new(DummyExecutor)
507 .build()
508 .expect("default build should succeed")
509 }
510
511 fn make_params(context_id: Option<&str>) -> MessageSendParams {
512 MessageSendParams {
513 message: Message {
514 id: MessageId::new("msg-1"),
515 role: MessageRole::User,
516 parts: vec![Part::text("hello")],
517 context_id: context_id.map(ContextId::new),
518 task_id: None,
519 reference_task_ids: None,
520 extensions: None,
521 metadata: None,
522 },
523 configuration: None,
524 metadata: None,
525 tenant: None,
526 }
527 }
528
529 #[tokio::test]
530 async fn empty_message_parts_returns_invalid_params() {
531 let handler = make_handler();
532 let mut params = make_params(None);
533 params.message.parts = vec![];
534
535 let result = handler.on_send_message(params, false, None).await;
536
537 assert!(
538 matches!(result, Err(ServerError::InvalidParams(_))),
539 "expected InvalidParams for empty parts"
540 );
541 }
542
543 #[tokio::test]
544 async fn oversized_message_metadata_returns_invalid_params() {
545 let handler = make_handler();
546 let mut params = make_params(None);
547 let big_value = "x".repeat(1_100_000);
549 params.message.metadata = Some(serde_json::json!(big_value));
550
551 let result = handler.on_send_message(params, false, None).await;
552
553 assert!(
554 matches!(result, Err(ServerError::InvalidParams(_))),
555 "expected InvalidParams for oversized message metadata"
556 );
557 }
558
559 #[tokio::test]
560 async fn oversized_request_metadata_returns_invalid_params() {
561 let handler = make_handler();
562 let mut params = make_params(None);
563 let big_value = "x".repeat(1_100_000);
565 params.metadata = Some(serde_json::json!(big_value));
566
567 let result = handler.on_send_message(params, false, None).await;
568
569 assert!(
570 matches!(result, Err(ServerError::InvalidParams(_))),
571 "expected InvalidParams for oversized request metadata"
572 );
573 }
574
575 #[tokio::test]
576 async fn valid_message_returns_ok() {
577 let handler = make_handler();
578 let params = make_params(None);
579
580 let result = handler.on_send_message(params, false, None).await;
581
582 let send_result = result.expect("expected Ok for valid message");
583 assert!(
584 matches!(
585 send_result,
586 SendMessageResult::Response(SendMessageResponse::Task(_))
587 ),
588 "expected Response(Task) for non-streaming send"
589 );
590 }
591
592 #[tokio::test]
593 async fn return_immediately_returns_task() {
594 let handler = make_handler();
595 let mut params = make_params(None);
596 params.configuration = Some(SendMessageConfiguration {
597 accepted_output_modes: vec!["text/plain".into()],
598 task_push_notification_config: None,
599 history_length: None,
600 return_immediately: Some(true),
601 });
602
603 let result = handler.on_send_message(params, false, None).await;
604
605 assert!(
606 matches!(
607 result,
608 Ok(SendMessageResult::Response(SendMessageResponse::Task(_)))
609 ),
610 "expected Response(Task) for return_immediately=true"
611 );
612 }
613
614 #[tokio::test]
615 async fn empty_context_id_returns_invalid_params() {
616 let handler = make_handler();
617 let params = make_params(Some(""));
618
619 let result = handler.on_send_message(params, false, None).await;
620
621 assert!(
622 matches!(result, Err(ServerError::InvalidParams(_))),
623 "expected InvalidParams for empty context_id"
624 );
625 }
626
627 #[tokio::test]
628 async fn too_long_context_id_returns_invalid_params() {
629 use crate::handler::limits::HandlerLimits;
631
632 let handler = RequestHandlerBuilder::new(DummyExecutor)
633 .with_handler_limits(HandlerLimits::default().with_max_id_length(10))
634 .build()
635 .unwrap();
636 let long_ctx = "x".repeat(20);
637 let params = make_params(Some(&long_ctx));
638
639 let result = handler.on_send_message(params, false, None).await;
640 assert!(
641 matches!(result, Err(ServerError::InvalidParams(ref msg)) if msg.contains("maximum length")),
642 "expected InvalidParams for too-long context_id"
643 );
644 }
645
646 #[tokio::test]
647 async fn too_long_task_id_returns_invalid_params() {
648 use crate::handler::limits::HandlerLimits;
650 use a2a_protocol_types::task::TaskId;
651
652 let handler = RequestHandlerBuilder::new(DummyExecutor)
653 .with_handler_limits(HandlerLimits::default().with_max_id_length(10))
654 .build()
655 .unwrap();
656 let mut params = make_params(None);
657 params.message.task_id = Some(TaskId::new("a".repeat(20)));
658
659 let result = handler.on_send_message(params, false, None).await;
660 assert!(
661 matches!(result, Err(ServerError::InvalidParams(ref msg)) if msg.contains("maximum length")),
662 "expected InvalidParams for too-long task_id"
663 );
664 }
665
666 #[tokio::test]
667 async fn empty_task_id_returns_invalid_params() {
668 use a2a_protocol_types::task::TaskId;
670
671 let handler = make_handler();
672 let mut params = make_params(None);
673 params.message.task_id = Some(TaskId::new(""));
674
675 let result = handler.on_send_message(params, false, None).await;
676 assert!(
677 matches!(result, Err(ServerError::InvalidParams(ref msg)) if msg.contains("empty")),
678 "expected InvalidParams for empty task_id"
679 );
680 }
681
682 #[tokio::test]
683 async fn task_id_mismatch_returns_invalid_params() {
684 use a2a_protocol_types::task::{Task, TaskId, TaskState, TaskStatus};
686
687 let handler = make_handler();
688
689 let task = Task {
691 id: TaskId::new("stored-task-id"),
692 context_id: ContextId::new("ctx-existing"),
693 status: TaskStatus::new(TaskState::InputRequired),
694 history: None,
695 artifacts: None,
696 metadata: None,
697 };
698 handler.task_store.save(&task).await.unwrap();
699
700 let mut params = make_params(Some("ctx-existing"));
702 params.message.task_id = Some(TaskId::new("different-task-id"));
703
704 let result = handler.on_send_message(params, false, None).await;
705 assert!(
706 matches!(result, Err(ServerError::InvalidParams(ref msg)) if msg.contains("does not match")),
707 "expected InvalidParams for task_id mismatch, got: {result:?}"
708 );
709 }
710
711 #[tokio::test]
712 async fn send_message_records_user_message_in_history() {
713 let handler = make_handler();
716 let result = handler
717 .on_send_message(make_params(None), false, None)
718 .await
719 .expect("send should succeed");
720 let task_id = match result {
721 SendMessageResult::Response(SendMessageResponse::Task(t)) => t.id,
722 other => panic!("expected task response, got {other:?}"),
723 };
724 let stored = handler
725 .task_store
726 .get(&task_id)
727 .await
728 .expect("get")
729 .expect("task stored");
730 let history = stored.history.expect("history populated on send");
731 assert_eq!(history.len(), 1, "exactly the incoming user message");
732 assert_eq!(history[0].role, MessageRole::User);
733 assert_eq!(
734 history[0].parts[0].text_content(),
735 Some("hello"),
736 "history records the message content"
737 );
738 }
739
740 #[tokio::test]
741 async fn continuation_appends_history_and_preserves_artifacts() {
742 use a2a_protocol_types::artifact::Artifact;
745 let handler = make_handler();
746 let prior = Task {
747 id: TaskId::new("cont-task"),
748 context_id: ContextId::new("ctx-cont"),
749 status: TaskStatus::new(TaskState::InputRequired),
750 history: Some(vec![Message {
751 id: MessageId::new("m-prior"),
752 role: MessageRole::User,
753 parts: vec![Part::text("first turn")],
754 context_id: None,
755 task_id: None,
756 reference_task_ids: None,
757 extensions: None,
758 metadata: None,
759 }]),
760 artifacts: Some(vec![Artifact::new("a1", vec![Part::text("turn-1 output")])]),
761 metadata: Some(serde_json::json!({"k": "v"})),
762 };
763 handler.task_store.save(&prior).await.unwrap();
764
765 let mut params = make_params(Some("ctx-cont"));
766 params.message.task_id = Some(TaskId::new("cont-task"));
767 handler
768 .on_send_message(params, false, None)
769 .await
770 .expect("continuation should succeed");
771
772 let stored = handler
773 .task_store
774 .get(&TaskId::new("cont-task"))
775 .await
776 .expect("get")
777 .expect("task stored");
778 let history = stored.history.expect("history preserved");
779 assert_eq!(history.len(), 2, "prior message + continuation message");
780 assert_eq!(history[0].parts[0].text_content(), Some("first turn"));
781 assert_eq!(history[1].parts[0].text_content(), Some("hello"));
782 assert!(
783 stored.artifacts.as_ref().is_some_and(|a| a.len() == 1),
784 "continuation must not wipe accumulated artifacts"
785 );
786 assert_eq!(
787 stored.metadata,
788 Some(serde_json::json!({"k": "v"})),
789 "continuation must not wipe task metadata"
790 );
791 }
792
793 #[tokio::test]
794 async fn history_is_capped_at_max_messages() {
795 let handler = make_handler();
797 let mut long_history: Vec<Message> = (0..MAX_TASK_HISTORY_MESSAGES)
798 .map(|i| Message {
799 id: MessageId::new(format!("m-{i}")),
800 role: MessageRole::User,
801 parts: vec![Part::text(format!("msg {i}"))],
802 context_id: None,
803 task_id: None,
804 reference_task_ids: None,
805 extensions: None,
806 metadata: None,
807 })
808 .collect();
809 long_history[0].parts = vec![Part::text("OLDEST")];
810 let prior = Task {
811 id: TaskId::new("cap-task"),
812 context_id: ContextId::new("ctx-cap"),
813 status: TaskStatus::new(TaskState::InputRequired),
814 history: Some(long_history),
815 artifacts: None,
816 metadata: None,
817 };
818 handler.task_store.save(&prior).await.unwrap();
819
820 let mut params = make_params(Some("ctx-cap"));
821 params.message.task_id = Some(TaskId::new("cap-task"));
822 handler
823 .on_send_message(params, false, None)
824 .await
825 .expect("continuation should succeed");
826
827 let stored = handler
828 .task_store
829 .get(&TaskId::new("cap-task"))
830 .await
831 .unwrap()
832 .unwrap();
833 let history = stored.history.unwrap();
834 assert_eq!(history.len(), MAX_TASK_HISTORY_MESSAGES, "capped");
835 assert_ne!(
836 history[0].parts[0].text_content(),
837 Some("OLDEST"),
838 "the oldest message is dropped first"
839 );
840 assert_eq!(
841 history[MAX_TASK_HISTORY_MESSAGES - 1].parts[0].text_content(),
842 Some("hello"),
843 "the newest message is retained"
844 );
845 }
846
847 #[tokio::test]
848 async fn send_response_omits_history_by_default_and_honors_history_length() {
849 use a2a_protocol_types::params::SendMessageConfiguration;
854 let handler = make_handler();
855
856 let result = handler
857 .on_send_message(make_params(Some("ctx-resp")), false, None)
858 .await
859 .expect("send should succeed");
860 let task = match result {
861 SendMessageResult::Response(SendMessageResponse::Task(t)) => t,
862 other => panic!("expected task response, got {other:?}"),
863 };
864 assert!(
865 task.history.is_none(),
866 "default send response must not echo history"
867 );
868 let stored = handler
869 .task_store
870 .get(&task.id)
871 .await
872 .unwrap()
873 .expect("task stored");
874 assert_eq!(
875 stored.history.as_ref().map(Vec::len),
876 Some(1),
877 "the store still keeps the full history"
878 );
879
880 let mut params = make_params(Some("ctx-resp"));
881 params.message.task_id = Some(task.id.clone());
882 params.configuration = Some(SendMessageConfiguration {
883 history_length: Some(10),
884 ..Default::default()
885 });
886 let result = handler
887 .on_send_message(params, false, None)
888 .await
889 .expect("continuation should succeed");
890 let task = match result {
891 SendMessageResult::Response(SendMessageResponse::Task(t)) => t,
892 other => panic!("expected task response, got {other:?}"),
893 };
894 assert_eq!(
895 task.history.as_ref().map(Vec::len),
896 Some(2),
897 "historyLength=10 returns the (2) stored messages"
898 );
899 }
900
901 #[tokio::test]
902 async fn send_message_with_request_metadata() {
903 let handler = make_handler();
905 let mut params = make_params(None);
906 params.metadata = Some(serde_json::json!({"key": "value"}));
907
908 let result = handler.on_send_message(params, false, None).await;
909 assert!(
910 result.is_ok(),
911 "send_message with request metadata should succeed"
912 );
913 }
914
915 #[tokio::test]
916 async fn send_message_error_path_records_metrics() {
917 use crate::call_context::CallContext;
919 use crate::interceptor::ServerInterceptor;
920 use std::future::Future;
921 use std::pin::Pin;
922
923 struct FailInterceptor;
924 impl ServerInterceptor for FailInterceptor {
925 fn before<'a>(
926 &'a self,
927 _ctx: &'a CallContext,
928 ) -> Pin<Box<dyn Future<Output = a2a_protocol_types::error::A2aResult<()>> + Send + 'a>>
929 {
930 Box::pin(async {
931 Err(a2a_protocol_types::error::A2aError::internal(
932 "forced failure",
933 ))
934 })
935 }
936 fn after<'a>(
937 &'a self,
938 _ctx: &'a CallContext,
939 ) -> Pin<Box<dyn Future<Output = a2a_protocol_types::error::A2aResult<()>> + Send + 'a>>
940 {
941 Box::pin(async { Ok(()) })
942 }
943 }
944
945 let handler = RequestHandlerBuilder::new(DummyExecutor)
946 .with_interceptor(FailInterceptor)
947 .build()
948 .unwrap();
949
950 let params = make_params(None);
951 let result = handler.on_send_message(params, false, None).await;
952 assert!(
953 result.is_err(),
954 "send_message should fail when interceptor rejects, exercising error metrics path"
955 );
956 }
957
958 #[tokio::test]
959 async fn send_streaming_message_error_path_records_metrics() {
960 use crate::call_context::CallContext;
962 use crate::interceptor::ServerInterceptor;
963 use std::future::Future;
964 use std::pin::Pin;
965
966 struct FailInterceptor;
967 impl ServerInterceptor for FailInterceptor {
968 fn before<'a>(
969 &'a self,
970 _ctx: &'a CallContext,
971 ) -> Pin<Box<dyn Future<Output = a2a_protocol_types::error::A2aResult<()>> + Send + 'a>>
972 {
973 Box::pin(async {
974 Err(a2a_protocol_types::error::A2aError::internal(
975 "forced failure",
976 ))
977 })
978 }
979 fn after<'a>(
980 &'a self,
981 _ctx: &'a CallContext,
982 ) -> Pin<Box<dyn Future<Output = a2a_protocol_types::error::A2aResult<()>> + Send + 'a>>
983 {
984 Box::pin(async { Ok(()) })
985 }
986 }
987
988 let handler = RequestHandlerBuilder::new(DummyExecutor)
989 .with_interceptor(FailInterceptor)
990 .build()
991 .unwrap();
992
993 let params = make_params(None);
994 let result = handler.on_send_message(params, true, None).await;
995 assert!(
996 result.is_err(),
997 "streaming send_message should fail when interceptor rejects"
998 );
999 }
1000
1001 #[tokio::test]
1002 async fn streaming_mode_returns_stream_result() {
1003 let handler = make_handler();
1005 let params = make_params(None);
1006
1007 let result = handler.on_send_message(params, true, None).await;
1008 assert!(
1009 matches!(result, Ok(SendMessageResult::Stream(_))),
1010 "expected Stream result in streaming mode"
1011 );
1012 }
1013
1014 #[tokio::test]
1015 async fn send_message_with_stored_task_continuation() {
1016 use a2a_protocol_types::task::{Task, TaskState, TaskStatus};
1019
1020 let handler = make_handler();
1021
1022 let task = Task {
1024 id: TaskId::new("existing-task"),
1025 context_id: ContextId::new("continue-ctx"),
1026 status: TaskStatus::new(TaskState::InputRequired),
1027 history: None,
1028 artifacts: None,
1029 metadata: None,
1030 };
1031 handler.task_store.save(&task).await.unwrap();
1032
1033 let params = make_params(Some("continue-ctx"));
1035 let result = handler.on_send_message(params, false, None).await;
1036 assert!(
1037 result.is_ok(),
1038 "send_message with existing non-terminal context should succeed"
1039 );
1040 }
1041
1042 #[tokio::test]
1043 async fn send_message_to_terminal_task_returns_unsupported_operation() {
1044 use a2a_protocol_types::task::{Task, TaskState, TaskStatus};
1047
1048 let handler = make_handler();
1049
1050 let task = Task {
1052 id: TaskId::new("done-task"),
1053 context_id: ContextId::new("done-ctx"),
1054 status: TaskStatus::new(TaskState::Completed),
1055 history: None,
1056 artifacts: None,
1057 metadata: None,
1058 };
1059 handler.task_store.save(&task).await.unwrap();
1060
1061 let mut params = make_params(Some("done-ctx"));
1063 params.message.task_id = Some(TaskId::new("done-task"));
1064 let result = handler.on_send_message(params, false, None).await;
1065 assert!(
1066 matches!(result, Err(ServerError::UnsupportedOperation(ref msg)) if msg.contains("terminal")),
1067 "expected UnsupportedOperation for terminal task, got: {result:?}"
1068 );
1069 }
1070
1071 #[tokio::test]
1072 async fn send_message_to_terminal_context_without_task_id_creates_new_task() {
1073 use a2a_protocol_types::task::{Task, TaskState, TaskStatus};
1076
1077 let handler = make_handler();
1078
1079 let task = Task {
1081 id: TaskId::new("old-task"),
1082 context_id: ContextId::new("reuse-ctx"),
1083 status: TaskStatus::new(TaskState::Completed),
1084 history: None,
1085 artifacts: None,
1086 metadata: None,
1087 };
1088 handler.task_store.save(&task).await.unwrap();
1089
1090 let params = make_params(Some("reuse-ctx"));
1092 let result = handler.on_send_message(params, false, None).await;
1093 assert!(
1094 result.is_ok(),
1095 "should create new task on terminal context, got: {result:?}"
1096 );
1097 }
1098
1099 #[tokio::test]
1100 async fn send_message_with_headers() {
1101 let handler = make_handler();
1103 let params = make_params(None);
1104 let mut headers = HashMap::new();
1105 headers.insert("authorization".to_string(), "Bearer test-token".to_string());
1106
1107 let result = handler.on_send_message(params, false, Some(&headers)).await;
1108 let send_result = result.expect("send_message with headers should succeed");
1109 assert!(
1110 matches!(
1111 send_result,
1112 SendMessageResult::Response(SendMessageResponse::Task(_))
1113 ),
1114 "expected Response(Task) for send with headers"
1115 );
1116 }
1117
1118 #[tokio::test]
1119 async fn duplicate_task_id_without_context_match_returns_error() {
1120 use a2a_protocol_types::task::{Task, TaskId as TId, TaskState, TaskStatus};
1122
1123 let handler = make_handler();
1124
1125 let task = Task {
1127 id: TId::new("dup-task"),
1128 context_id: ContextId::new("other-ctx"),
1129 status: TaskStatus::new(TaskState::Completed),
1130 history: None,
1131 artifacts: None,
1132 metadata: None,
1133 };
1134 handler.task_store.save(&task).await.unwrap();
1135
1136 let mut params = make_params(Some("brand-new-ctx"));
1138 params.message.task_id = Some(TId::new("dup-task"));
1139
1140 let result = handler.on_send_message(params, false, None).await;
1141 assert!(
1142 matches!(result, Err(ServerError::InvalidParams(ref msg)) if msg.contains("different context")),
1143 "expected InvalidParams for task_id in different context, got: {result:?}"
1144 );
1145 }
1146
1147 #[tokio::test]
1148 async fn unknown_task_id_returns_task_not_found() {
1149 use a2a_protocol_types::task::TaskId as TId;
1151
1152 let handler = make_handler();
1153
1154 let mut params = make_params(Some("fresh-ctx"));
1156 params.message.task_id = Some(TId::new("nonexistent-task"));
1157
1158 let result = handler.on_send_message(params, false, None).await;
1159 assert!(
1160 matches!(result, Err(ServerError::TaskNotFound(_))),
1161 "expected TaskNotFound for unknown task_id, got: {result:?}"
1162 );
1163 }
1164
1165 #[tokio::test]
1166 async fn send_message_with_tenant() {
1167 let handler = make_handler();
1169 let mut params = make_params(None);
1170 params.tenant = Some("test-tenant".to_string());
1171
1172 let result = handler.on_send_message(params, false, None).await;
1173 let send_result = result.expect("send_message with tenant should succeed");
1174 assert!(
1175 matches!(
1176 send_result,
1177 SendMessageResult::Response(SendMessageResponse::Task(_))
1178 ),
1179 "expected Response(Task) for send with tenant"
1180 );
1181 }
1182
1183 #[tokio::test]
1184 async fn executor_timeout_returns_failed_task() {
1185 use a2a_protocol_types::error::A2aResult;
1187 use std::time::Duration;
1188
1189 struct SlowExecutor;
1190 impl crate::executor::AgentExecutor for SlowExecutor {
1191 fn execute<'a>(
1192 &'a self,
1193 _ctx: &'a crate::request_context::RequestContext,
1194 _queue: &'a dyn crate::streaming::EventQueueWriter,
1195 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = A2aResult<()>> + Send + 'a>>
1196 {
1197 Box::pin(async {
1198 tokio::time::sleep(Duration::from_secs(60)).await;
1199 Ok(())
1200 })
1201 }
1202 }
1203
1204 let handler = RequestHandlerBuilder::new(SlowExecutor)
1205 .with_executor_timeout(Duration::from_millis(50))
1206 .build()
1207 .unwrap();
1208
1209 let params = make_params(None);
1210 let result = handler.on_send_message(params, false, None).await;
1212 assert!(
1214 result.is_ok(),
1215 "executor timeout should still return a task result"
1216 );
1217 }
1218
1219 #[tokio::test]
1220 async fn executor_failure_writes_failed_event() {
1221 use a2a_protocol_types::error::{A2aError, A2aResult};
1223
1224 struct FailExecutor;
1225 impl crate::executor::AgentExecutor for FailExecutor {
1226 fn execute<'a>(
1227 &'a self,
1228 _ctx: &'a crate::request_context::RequestContext,
1229 _queue: &'a dyn crate::streaming::EventQueueWriter,
1230 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = A2aResult<()>> + Send + 'a>>
1231 {
1232 Box::pin(async { Err(A2aError::internal("executor exploded")) })
1233 }
1234 }
1235
1236 let handler = RequestHandlerBuilder::new(FailExecutor).build().unwrap();
1237 let params = make_params(None);
1238
1239 let result = handler.on_send_message(params, false, None).await;
1240 assert!(
1242 result.is_ok(),
1243 "executor failure should produce a task result"
1244 );
1245 }
1246
1247 #[tokio::test]
1248 async fn cancellation_token_sweep_runs_when_map_is_full() {
1249 use crate::handler::limits::HandlerLimits;
1252
1253 struct SlowExec;
1255 impl crate::executor::AgentExecutor for SlowExec {
1256 fn execute<'a>(
1257 &'a self,
1258 _ctx: &'a crate::request_context::RequestContext,
1259 _queue: &'a dyn crate::streaming::EventQueueWriter,
1260 ) -> std::pin::Pin<
1261 Box<
1262 dyn std::future::Future<Output = a2a_protocol_types::error::A2aResult<()>>
1263 + Send
1264 + 'a,
1265 >,
1266 > {
1267 Box::pin(async {
1268 tokio::time::sleep(std::time::Duration::from_secs(10)).await;
1270 Ok(())
1271 })
1272 }
1273 }
1274
1275 let handler = RequestHandlerBuilder::new(SlowExec)
1276 .with_handler_limits(HandlerLimits::default().with_max_cancellation_tokens(2))
1277 .build()
1278 .unwrap();
1279
1280 for _ in 0..3 {
1283 let params = make_params(None);
1284 let _ = handler.on_send_message(params, true, None).await;
1285 }
1286 handler.shutdown().await;
1289 }
1290
1291 #[tokio::test]
1292 async fn stale_cancellation_tokens_cleaned_up() {
1293 use crate::handler::limits::HandlerLimits;
1295 use std::time::Duration;
1296
1297 struct SlowExec2;
1299 impl crate::executor::AgentExecutor for SlowExec2 {
1300 fn execute<'a>(
1301 &'a self,
1302 _ctx: &'a crate::request_context::RequestContext,
1303 _queue: &'a dyn crate::streaming::EventQueueWriter,
1304 ) -> std::pin::Pin<
1305 Box<
1306 dyn std::future::Future<Output = a2a_protocol_types::error::A2aResult<()>>
1307 + Send
1308 + 'a,
1309 >,
1310 > {
1311 Box::pin(async {
1312 tokio::time::sleep(Duration::from_secs(10)).await;
1313 Ok(())
1314 })
1315 }
1316 }
1317
1318 let handler = RequestHandlerBuilder::new(SlowExec2)
1319 .with_handler_limits(
1320 HandlerLimits::default()
1321 .with_max_cancellation_tokens(2)
1322 .with_max_token_age(Duration::from_millis(1)),
1324 )
1325 .build()
1326 .unwrap();
1327
1328 for _ in 0..2 {
1330 let params = make_params(None);
1331 let _ = handler.on_send_message(params, true, None).await;
1332 }
1333
1334 tokio::time::sleep(Duration::from_millis(50)).await;
1336
1337 let params = make_params(None);
1341 let _ = handler.on_send_message(params, true, None).await;
1342
1343 handler.shutdown().await;
1345 }
1346
1347 #[tokio::test]
1348 async fn streaming_executor_failure_writes_error_event() {
1349 use a2a_protocol_types::error::{A2aError, A2aResult};
1351
1352 struct FailExecutor;
1353 impl crate::executor::AgentExecutor for FailExecutor {
1354 fn execute<'a>(
1355 &'a self,
1356 _ctx: &'a crate::request_context::RequestContext,
1357 _queue: &'a dyn crate::streaming::EventQueueWriter,
1358 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = A2aResult<()>> + Send + 'a>>
1359 {
1360 Box::pin(async { Err(A2aError::internal("streaming fail")) })
1361 }
1362 }
1363
1364 let handler = RequestHandlerBuilder::new(FailExecutor).build().unwrap();
1365 let params = make_params(None);
1366
1367 let result = handler.on_send_message(params, true, None).await;
1368 assert!(
1369 matches!(result, Ok(SendMessageResult::Stream(_))),
1370 "streaming executor failure should still return stream"
1371 );
1372 }
1373
1374 #[tokio::test]
1375 async fn input_required_continuation_reuses_task_id() {
1376 use a2a_protocol_types::task::{Task, TaskId, TaskState, TaskStatus};
1380
1381 let handler = make_handler();
1382
1383 let existing_task_id = TaskId::new("input-required-task");
1385 let task = Task {
1386 id: existing_task_id.clone(),
1387 context_id: ContextId::new("ctx-input"),
1388 status: TaskStatus::new(TaskState::InputRequired),
1389 history: None,
1390 artifacts: None,
1391 metadata: None,
1392 };
1393 handler.task_store.save(&task).await.unwrap();
1394
1395 let mut params = make_params(Some("ctx-input"));
1397 params.message.task_id = Some(existing_task_id.clone());
1398
1399 let result = handler.on_send_message(params, false, None).await;
1400 let send_result = result.expect("continuation should succeed");
1401 match send_result {
1402 SendMessageResult::Response(SendMessageResponse::Task(t)) => {
1403 assert_eq!(
1404 t.id, existing_task_id,
1405 "task_id should be reused for input-required continuation"
1406 );
1407 }
1408 _ => panic!("expected Response(Task)"),
1409 }
1410 }
1411}