use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};

use a2a_protocol_types::agent_card::AgentCard;
use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
use a2a_protocol_types::params::{
    CancelTaskParams, DeletePushConfigParams, GetPushConfigParams, ListTasksParams,
    MessageSendParams, TaskIdParams, TaskQueryParams,
};
use a2a_protocol_types::push::TaskPushNotificationConfig;
use a2a_protocol_types::responses::{SendMessageResponse, TaskListResponse};
use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};

use crate::call_context::CallContext;
use crate::error::{ServerError, ServerResult};
use crate::executor::AgentExecutor;
use crate::interceptor::ServerInterceptorChain;
use crate::metrics::Metrics;
use crate::push::{PushConfigStore, PushSender};
use crate::request_context::RequestContext;
use crate::store::TaskStore;
use crate::streaming::{
    EventQueueManager, EventQueueReader, EventQueueWriter, InMemoryQueueReader,
};
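
/// Maximum accepted length, in bytes, for a client-supplied task or context ID.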
const MAX_ID_LENGTH: usize = 1024;
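
/// Maximum accepted size, in bytes, of serialized message or request metadata.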
const MAX_METADATA_SIZE: usize = 1_048_576;
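
/// Cap on retained cancellation tokens; hitting it prunes cancelled and aged-out entries.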
const MAX_CANCELLATION_TOKENS: usize = 10_000;
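
/// Age after which an entry in the cancellation-token map is considered stale.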
const MAX_TOKEN_AGE: Duration = Duration::from_secs(3600);
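
/// Rejects identifiers that are empty, whitespace-only, or longer than [`MAX_ID_LENGTH`].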
fn validate_id(raw: &str, name: &str) -> ServerResult<()> {
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        return Err(ServerError::InvalidParams(format!(
            "{name} must not be empty or whitespace-only"
        )));
    }
    if trimmed.len() > MAX_ID_LENGTH {
        return Err(ServerError::InvalidParams(format!(
            "{name} exceeds maximum length (got {}, max {MAX_ID_LENGTH})",
            trimmed.len()
        )));
    }
    Ok(())
}
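
/// Core request handler for the A2A server: validates incoming parameters,
/// runs the interceptor chain, drives the [`AgentExecutor`], and persists
/// task state through the configured [`TaskStore`].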
pub struct RequestHandler {
    pub(crate) executor: Arc<dyn AgentExecutor>,
    pub(crate) task_store: Box<dyn TaskStore>,
    pub(crate) push_config_store: Box<dyn PushConfigStore>,
    pub(crate) push_sender: Option<Box<dyn PushSender>>,
    pub(crate) event_queue_manager: EventQueueManager,
    pub(crate) interceptors: ServerInterceptorChain,
    pub(crate) agent_card: Option<AgentCard>,
    pub(crate) executor_timeout: Option<Duration>,
    pub(crate) metrics: Box<dyn Metrics>,
    pub(crate) cancellation_tokens: Arc<tokio::sync::RwLock<HashMap<TaskId, CancellationEntry>>>,
}
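
/// A task's cancellation token together with its registration time.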
#[derive(Debug, Clone)]
pub(crate) struct CancellationEntry {
    /// Token used to signal cancellation to the running executor.
    pub(crate) token: tokio_util::sync::CancellationToken,
    /// When the token was registered; lets stale entries be pruned.
    pub(crate) created_at: Instant,
}

impl RequestHandler {
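    /// Handles `SendMessage` and `SendStreamingMessage` requests: validates
    /// input, creates and persists the task, spawns the executor, and returns
    /// either the event stream, the freshly created task, or the final task.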
    #[allow(clippy::too_many_lines)]
    pub async fn on_send_message(
        &self,
        params: MessageSendParams,
        streaming: bool,
    ) -> ServerResult<SendMessageResult> {
        let method_name = if streaming {
            "SendStreamingMessage"
        } else {
            "SendMessage"
        };
        trace_info!(method = method_name, streaming, "handling send message");
        self.metrics.on_request(method_name);

        let call_ctx = CallContext::new(method_name);
        self.interceptors.run_before(&call_ctx).await?;

        // Validate client-supplied identifiers before doing any work.
        if let Some(ref ctx_id) = params.message.context_id {
            validate_id(&ctx_id.0, "context_id")?;
        }
        if let Some(ref task_id) = params.message.task_id {
            validate_id(&task_id.0, "task_id")?;
        }

        if params.message.parts.is_empty() {
            return Err(ServerError::InvalidParams(
                "message must contain at least one part".into(),
            ));
        }

        // Bound metadata payloads to keep serialization and storage costs predictable.
        if let Some(ref meta) = params.message.metadata {
            let meta_size = serde_json::to_string(meta).map(|s| s.len()).unwrap_or(0);
            if meta_size > MAX_METADATA_SIZE {
                return Err(ServerError::InvalidParams(format!(
                    "message metadata exceeds maximum size ({meta_size} bytes, max {MAX_METADATA_SIZE})"
                )));
            }
        }
        if let Some(ref meta) = params.metadata {
            let meta_size = serde_json::to_string(meta).map(|s| s.len()).unwrap_or(0);
            if meta_size > MAX_METADATA_SIZE {
                return Err(ServerError::InvalidParams(format!(
                    "request metadata exceeds maximum size ({meta_size} bytes, max {MAX_METADATA_SIZE})"
                )));
            }
        }

        let task_id = TaskId::new(uuid::Uuid::new_v4().to_string());
        let context_id = params
            .message
            .context_id
            .as_ref()
            .map_or_else(|| uuid::Uuid::new_v4().to_string(), |c| c.0.clone());

        let stored_task = self.find_task_by_context(&context_id).await;

        // A client-referenced task must match the task already associated with
        // the context; otherwise reserve the requested ID atomically.
        if let Some(ref msg_task_id) = params.message.task_id {
            if let Some(ref stored) = stored_task {
                if msg_task_id != &stored.id {
                    return Err(ServerError::InvalidParams(
                        "message task_id does not match task found for context".into(),
                    ));
                }
            } else {
                let placeholder = Task {
                    id: msg_task_id.clone(),
                    context_id: ContextId::new(&context_id),
                    status: TaskStatus::with_timestamp(TaskState::Submitted),
                    history: None,
                    artifacts: None,
                    metadata: None,
                };
                if !self.task_store.insert_if_absent(placeholder).await? {
                    return Err(ServerError::InvalidParams(
                        "task_id already exists; cannot create duplicate".into(),
                    ));
                }
            }
        }

        let return_immediately = params
            .configuration
            .as_ref()
            .and_then(|c| c.return_immediately)
            .unwrap_or(false);

        trace_debug!(
            task_id = %task_id,
            context_id = %context_id,
            "creating task"
        );
        let task = Task {
            id: task_id.clone(),
            context_id: ContextId::new(&context_id),
            status: TaskStatus::with_timestamp(TaskState::Submitted),
            history: None,
            artifacts: None,
            metadata: None,
        };

        self.task_store.save(task.clone()).await?;

        let mut ctx = RequestContext::new(params.message, task_id.clone(), context_id);
        if let Some(stored) = stored_task {
            ctx = ctx.with_stored_task(stored);
        }
        if let Some(meta) = params.metadata {
            ctx = ctx.with_metadata(meta);
        }

        // Register the cancellation token, pruning cancelled or aged-out
        // entries once the map hits its cap.
        {
            let mut tokens = self.cancellation_tokens.write().await;
            if tokens.len() >= MAX_CANCELLATION_TOKENS {
                let now = Instant::now();
                tokens.retain(|_, entry| {
                    !entry.token.is_cancelled()
                        && now.duration_since(entry.created_at) < MAX_TOKEN_AGE
                });
            }
            tokens.insert(
                task_id.clone(),
                CancellationEntry {
                    token: ctx.cancellation_token.clone(),
                    created_at: Instant::now(),
                },
            );
        }

        let (writer, reader) = self.event_queue_manager.get_or_create(&task_id).await;
        let reader = reader
            .ok_or_else(|| ServerError::Internal("event queue already exists for task".into()))?;

        // Run the executor on its own task so the handler can return a stream
        // immediately; the spawned task cleans up the queue and token when done.
        let executor = Arc::clone(&self.executor);
        let task_id_for_cleanup = task_id.clone();
        let event_queue_mgr = self.event_queue_manager.clone();
        let cancel_tokens = Arc::clone(&self.cancellation_tokens);
        let executor_timeout = self.executor_timeout;
        let executor_handle = tokio::spawn(async move {
            trace_debug!(task_id = %ctx.task_id, "executor started");

            let result = if let Some(timeout) = executor_timeout {
                tokio::time::timeout(timeout, executor.execute(&ctx, writer.as_ref()))
                    .await
                    .unwrap_or_else(|_| {
                        Err(a2a_protocol_types::error::A2aError::internal(format!(
                            "executor timed out after {}s",
                            timeout.as_secs()
                        )))
                    })
            } else {
                executor.execute(&ctx, writer.as_ref()).await
            };

            if let Err(ref e) = result {
                trace_error!(task_id = %ctx.task_id, error = %e, "executor failed");
                let fail_event = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
                    task_id: ctx.task_id.clone(),
                    context_id: ContextId::new(ctx.context_id.clone()),
                    status: TaskStatus::with_timestamp(TaskState::Failed),
                    metadata: Some(serde_json::json!({ "error": e.to_string() })),
                });
                if let Err(write_err) = writer.write(fail_event).await {
                    trace_error!(
                        task_id = %ctx.task_id,
                        error = %write_err,
                        "failed to write failure event to queue"
                    );
                }
            }
            // Dropping the writer closes the queue so readers see end-of-stream.
            drop(writer);
            event_queue_mgr.destroy(&task_id_for_cleanup).await;
            cancel_tokens.write().await.remove(&task_id_for_cleanup);
        });

        self.interceptors.run_after(&call_ctx).await?;

        if streaming {
            Ok(SendMessageResult::Stream(reader))
        } else if return_immediately {
            Ok(SendMessageResult::Response(SendMessageResponse::Task(task)))
        } else {
            // Blocking unary call: drain events until the executor finishes,
            // then return the final task state.
            let final_task = self
                .collect_events(reader, task_id.clone(), executor_handle)
                .await?;
            Ok(SendMessageResult::Response(SendMessageResponse::Task(
                final_task,
            )))
        }
    }
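
    /// Returns the stored task identified by `params.id`.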
    pub async fn on_get_task(&self, params: TaskQueryParams) -> ServerResult<Task> {
        trace_info!(method = "GetTask", task_id = %params.id, "handling get task");
        let call_ctx = CallContext::new("GetTask");
        self.interceptors.run_before(&call_ctx).await?;

        let task_id = TaskId::new(&params.id);
        let task = self
            .task_store
            .get(&task_id)
            .await?
            .ok_or_else(|| ServerError::TaskNotFound(task_id))?;

        self.interceptors.run_after(&call_ctx).await?;
        Ok(task)
    }
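
    /// Lists tasks from the store using the given filter and pagination parameters.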
    pub async fn on_list_tasks(&self, params: ListTasksParams) -> ServerResult<TaskListResponse> {
        trace_info!(method = "ListTasks", "handling list tasks");
        let call_ctx = CallContext::new("ListTasks");
        self.interceptors.run_before(&call_ctx).await?;

        let result = self.task_store.list(&params).await?;

        self.interceptors.run_after(&call_ctx).await?;
        Ok(result)
    }
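
    /// Cancels a running task: fires its cancellation token, invokes the
    /// executor's cancel hook, and persists the `Canceled` status.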
    pub async fn on_cancel_task(&self, params: CancelTaskParams) -> ServerResult<Task> {
        trace_info!(method = "CancelTask", task_id = %params.id, "handling cancel task");
        let call_ctx = CallContext::new("CancelTask");
        self.interceptors.run_before(&call_ctx).await?;

        let task_id = TaskId::new(&params.id);
        let task = self
            .task_store
            .get(&task_id)
            .await?
            .ok_or_else(|| ServerError::TaskNotFound(task_id.clone()))?;

        if task.status.state.is_terminal() {
            return Err(ServerError::TaskNotCancelable(task_id));
        }

        // Signal the running executor, if any, via its cancellation token.
        {
            let tokens = self.cancellation_tokens.read().await;
            if let Some(entry) = tokens.get(&task_id) {
                entry.token.cancel();
            }
        }

        // Build a minimal context so the executor's cancel hook can run.
        let ctx = RequestContext::new(
            a2a_protocol_types::message::Message {
                id: a2a_protocol_types::message::MessageId::new(uuid::Uuid::new_v4().to_string()),
                role: a2a_protocol_types::message::MessageRole::User,
                parts: vec![],
                task_id: Some(task_id.clone()),
                context_id: Some(task.context_id.clone()),
                reference_task_ids: None,
                extensions: None,
                metadata: None,
            },
            task_id.clone(),
            task.context_id.0.clone(),
        );

        let (writer, _reader) = self.event_queue_manager.get_or_create(&task_id).await;
        self.executor.cancel(&ctx, writer.as_ref()).await?;

        let mut updated = task;
        updated.status = TaskStatus::with_timestamp(TaskState::Canceled);
        self.task_store.save(updated.clone()).await?;

        self.interceptors.run_after(&call_ctx).await?;
        Ok(updated)
    }
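
    /// Reattaches a client to the event stream of an existing task.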
    pub async fn on_resubscribe(&self, params: TaskIdParams) -> ServerResult<InMemoryQueueReader> {
        trace_info!(method = "SubscribeToTask", task_id = %params.id, "handling resubscribe");
        let call_ctx = CallContext::new("SubscribeToTask");
        self.interceptors.run_before(&call_ctx).await?;

        let task_id = TaskId::new(&params.id);

        // The task must exist even if its event queue no longer does.
        let _task = self
            .task_store
            .get(&task_id)
            .await?
            .ok_or_else(|| ServerError::TaskNotFound(task_id.clone()))?;

        let (_writer, reader) = self.event_queue_manager.get_or_create(&task_id).await;
        let reader = reader.ok_or_else(|| {
            ServerError::Internal("no event queue available for resubscribe".into())
        })?;

        self.interceptors.run_after(&call_ctx).await?;
        Ok(reader)
    }
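
    /// Stores a push notification config for a task; fails if no push sender is configured.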
    pub async fn on_set_push_config(
        &self,
        config: TaskPushNotificationConfig,
    ) -> ServerResult<TaskPushNotificationConfig> {
        if self.push_sender.is_none() {
            return Err(ServerError::PushNotSupported);
        }
        let call_ctx = CallContext::new("CreateTaskPushNotificationConfig");
        self.interceptors.run_before(&call_ctx).await?;

        let result = self.push_config_store.set(config).await?;

        self.interceptors.run_after(&call_ctx).await?;
        Ok(result)
    }
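
    /// Fetches a stored push notification config by task ID and config ID.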
    pub async fn on_get_push_config(
        &self,
        params: GetPushConfigParams,
    ) -> ServerResult<TaskPushNotificationConfig> {
        let call_ctx = CallContext::new("GetTaskPushNotificationConfig");
        self.interceptors.run_before(&call_ctx).await?;

        let config = self
            .push_config_store
            .get(&params.task_id, &params.id)
            .await?
            .ok_or_else(|| {
                ServerError::InvalidParams(format!(
                    "push config not found: task={}, id={}",
                    params.task_id, params.id
                ))
            })?;

        self.interceptors.run_after(&call_ctx).await?;
        Ok(config)
    }
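
    /// Lists all push notification configs registered for a task.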
    pub async fn on_list_push_configs(
        &self,
        task_id: &str,
    ) -> ServerResult<Vec<TaskPushNotificationConfig>> {
        let call_ctx = CallContext::new("ListTaskPushNotificationConfigs");
        self.interceptors.run_before(&call_ctx).await?;

        let configs = self.push_config_store.list(task_id).await?;

        self.interceptors.run_after(&call_ctx).await?;
        Ok(configs)
    }
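
    /// Deletes a stored push notification config by task ID and config ID.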
    pub async fn on_delete_push_config(&self, params: DeletePushConfigParams) -> ServerResult<()> {
        let call_ctx = CallContext::new("DeleteTaskPushNotificationConfig");
        self.interceptors.run_before(&call_ctx).await?;

        self.push_config_store
            .delete(&params.task_id, &params.id)
            .await?;

        self.interceptors.run_after(&call_ctx).await?;
        Ok(())
    }
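
    /// Returns the configured extended agent card, or an error if none was set.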
    pub async fn on_get_extended_agent_card(&self) -> ServerResult<AgentCard> {
        let call_ctx = CallContext::new("GetExtendedAgentCard");
        self.interceptors.run_before(&call_ctx).await?;

        let card = self
            .agent_card
            .clone()
            .ok_or_else(|| ServerError::Internal("no agent card configured".into()))?;

        self.interceptors.run_after(&call_ctx).await?;
        Ok(card)
    }
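
    /// Returns the most recent task associated with `context_id`, if any.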
    async fn find_task_by_context(&self, context_id: &str) -> Option<Task> {
        if context_id.len() > MAX_ID_LENGTH {
            return None;
        }
        let params = ListTasksParams {
            tenant: None,
            context_id: Some(context_id.to_owned()),
            status: None,
            page_size: Some(1),
            page_token: None,
            status_timestamp_after: None,
            include_artifacts: None,
            history_length: None,
        };
        self.task_store
            .list(&params)
            .await
            .ok()
            .and_then(|resp| resp.tasks.into_iter().next())
    }
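
    /// Drains the event queue for a blocking send, applying each event to the
    /// task until the stream closes; marks the task failed if the executor panics.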
    async fn collect_events(
        &self,
        mut reader: InMemoryQueueReader,
        task_id: TaskId,
        executor_handle: tokio::task::JoinHandle<()>,
    ) -> ServerResult<Task> {
        let mut last_task = self
            .task_store
            .get(&task_id)
            .await?
            .ok_or_else(|| ServerError::TaskNotFound(task_id.clone()))?;

        // Poll the executor handle alongside the queue, but only until it
        // completes; polling a finished JoinHandle again would panic.
        let mut executor_done = false;
        let mut handle_fuse = executor_handle;

        loop {
            if executor_done {
                // Executor finished: drain whatever remains in the queue.
                match reader.read().await {
                    Some(event) => {
                        self.process_event(event, &task_id, &mut last_task).await?;
                    }
                    None => break,
                }
            } else {
                tokio::select! {
                    biased;
                    event = reader.read() => {
                        match event {
                            Some(event) => {
                                self.process_event(event, &task_id, &mut last_task).await?;
                            }
                            None => break,
                        }
                    }
                    result = &mut handle_fuse => {
                        executor_done = true;
                        if result.is_err() {
                            trace_error!(
                                task_id = %task_id,
                                "executor task panicked"
                            );
                            if !last_task.status.state.is_terminal() {
                                last_task.status = TaskStatus::with_timestamp(TaskState::Failed);
                                self.task_store.save(last_task.clone()).await?;
                            }
                        }
                    }
                }
            }
        }

        Ok(last_task)
    }
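
    /// Applies a single stream event to the task: validates status transitions,
    /// appends artifacts, persists updates, and fans out push notifications.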
    async fn process_event(
        &self,
        event: a2a_protocol_types::error::A2aResult<StreamResponse>,
        task_id: &TaskId,
        last_task: &mut Task,
    ) -> ServerResult<()> {
        match event {
            Ok(ref stream_resp @ StreamResponse::StatusUpdate(ref update)) => {
                let current = last_task.status.state;
                let next = update.status.state;
                if !current.can_transition_to(next) {
                    trace_warn!(
                        task_id = %task_id,
                        from = %current,
                        to = %next,
                        "invalid state transition rejected"
                    );
                    return Err(ServerError::InvalidStateTransition {
                        task_id: task_id.clone(),
                        from: current,
                        to: next,
                    });
                }
                last_task.status = TaskStatus {
                    state: next,
                    message: update.status.message.clone(),
                    timestamp: update.status.timestamp.clone(),
                };
                self.task_store.save(last_task.clone()).await?;
                self.deliver_push(task_id, stream_resp).await;
            }
            Ok(ref stream_resp @ StreamResponse::ArtifactUpdate(ref update)) => {
                let artifacts = last_task.artifacts.get_or_insert_with(Vec::new);
                artifacts.push(update.artifact.clone());
                self.task_store.save(last_task.clone()).await?;
                self.deliver_push(task_id, stream_resp).await;
            }
            Ok(StreamResponse::Task(task)) => {
                *last_task = task;
                self.task_store.save(last_task.clone()).await?;
            }
            Ok(_) => {
                // Standalone messages (and any other events) do not change task state.
            }
            Err(e) => {
                last_task.status = TaskStatus::with_timestamp(TaskState::Failed);
                self.task_store.save(last_task.clone()).await?;
                return Err(ServerError::Protocol(e));
            }
        }
        Ok(())
    }
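
    /// Best-effort delivery of an event to every push config registered for the
    /// task; failures and timeouts are logged but never fail the request.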
    async fn deliver_push(&self, task_id: &TaskId, event: &StreamResponse) {
        let Some(ref sender) = self.push_sender else {
            return;
        };
        let Ok(configs) = self.push_config_store.list(task_id.as_ref()).await else {
            return;
        };
        for config in &configs {
            // Cap each delivery attempt so one slow endpoint cannot stall the stream.
            let result = tokio::time::timeout(
                Duration::from_secs(5),
                sender.send(&config.url, event, config),
            )
            .await;
            match result {
                Ok(Err(err)) => {
                    trace_warn!(
                        task_id = %task_id,
                        url = %config.url,
                        error = %err,
                        "push notification delivery failed"
                    );
                }
                Err(_) => {
                    trace_warn!(
                        task_id = %task_id,
                        url = %config.url,
                        "push notification delivery timed out"
                    );
                }
                Ok(Ok(())) => {}
            }
        }
    }
}

impl RequestHandler {
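    /// Shuts the handler down: cancels all in-flight tasks, tears down event
    /// queues, clears cancellation tokens, and notifies the executor.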
    pub async fn shutdown(&self) {
        // Cancel everything that is still running.
        {
            let tokens = self.cancellation_tokens.read().await;
            for entry in tokens.values() {
                entry.token.cancel();
            }
        }

        self.event_queue_manager.destroy_all().await;

        {
            let mut tokens = self.cancellation_tokens.write().await;
            tokens.clear();
        }

        self.executor.on_shutdown().await;
    }
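
    /// Like [`Self::shutdown`], but first waits up to `timeout` for active
    /// event queues to drain before force-destroying them.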
    pub async fn shutdown_with_timeout(&self, timeout: Duration) {
        // Cancel everything that is still running.
        {
            let tokens = self.cancellation_tokens.read().await;
            for entry in tokens.values() {
                entry.token.cancel();
            }
        }

        // Give in-flight streams a chance to drain before tearing down.
        let drain_start = Instant::now();
        loop {
            let active = self.event_queue_manager.active_count().await;
            if active == 0 {
                break;
            }
            if drain_start.elapsed() >= timeout {
                trace_warn!(
                    active_queues = active,
                    "shutdown timeout reached, force-destroying remaining queues"
                );
                break;
            }
            tokio::time::sleep(Duration::from_millis(50)).await;
        }

        self.event_queue_manager.destroy_all().await;

        {
            let mut tokens = self.cancellation_tokens.write().await;
            tokens.clear();
        }

        self.executor.on_shutdown().await;
    }
}
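
// Manual `Debug` impl: summarizes the trait-object fields instead of
// requiring `Debug` on them.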
impl std::fmt::Debug for RequestHandler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RequestHandler")
            .field("push_sender", &self.push_sender.is_some())
            .field("event_queue_manager", &self.event_queue_manager)
            .field("interceptors", &self.interceptors)
            .field("agent_card", &self.agent_card.is_some())
            .field("metrics", &"<dyn Metrics>")
            .finish_non_exhaustive()
    }
}
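
/// Result of a send-message call: a complete response for unary requests, or
/// an event stream for streaming requests.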
#[allow(clippy::large_enum_variant)]
pub enum SendMessageResult {
    /// Finished response for a unary `SendMessage` call.
    Response(SendMessageResponse),
    /// Live event-stream reader for a `SendStreamingMessage` call.
    Stream(InMemoryQueueReader),
}