1use super::agent::Agent;
2use super::storage::Storage;
3use crate::a2a_types::{
4 Artifact, Message as A2AMessage, Part, Role, StreamResponse, Task, TaskArtifactUpdateEvent,
5 TaskState, TaskStatus, TaskStatusUpdateEvent, Timestamp,
6};
7use anyhow::{Result, anyhow};
8use futures_util::stream::StreamExt;
9use inference_gateway_sdk::{Message, MessageContent, MessageRole};
10use serde_json::Value;
11use std::sync::Arc;
12use tokio::sync::mpsc;
13use tracing::{debug, warn};
14
15#[async_trait::async_trait]
22pub trait TaskHandler: Send + Sync + std::fmt::Debug {
23 async fn handle_task(&self, task: Task, message: Option<A2AMessage>) -> Result<Task>;
24}
25
26#[async_trait::async_trait]
35pub trait StreamableTaskHandler: Send + Sync + std::fmt::Debug {
36 async fn handle_streaming_task(
42 &self,
43 task: Task,
44 message: Option<A2AMessage>,
45 emitter: StreamEmitter,
46 ) -> Result<()>;
47}
48
49#[derive(Clone)]
52pub struct StreamEmitter {
53 tx: mpsc::Sender<StreamResponse>,
54 storage: Arc<dyn Storage>,
55}
56
57impl std::fmt::Debug for StreamEmitter {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("StreamEmitter").finish_non_exhaustive()
60 }
61}
62
63impl StreamEmitter {
64 pub(super) fn new(tx: mpsc::Sender<StreamResponse>, storage: Arc<dyn Storage>) -> Self {
65 Self { tx, storage }
66 }
67
68 pub async fn emit(&self, response: StreamResponse) -> Result<()> {
70 self.tx
71 .send(response)
72 .await
73 .map_err(|_| anyhow!("stream receiver dropped before handler finished"))
74 }
75
76 pub async fn emit_status(
80 &self,
81 task_id: &str,
82 context_id: &str,
83 state: TaskState,
84 message: Option<A2AMessage>,
85 final_: bool,
86 ) -> Result<()> {
87 let now = Timestamp(chrono::Utc::now());
88 let new_status = TaskStatus {
89 message: message.clone(),
90 state,
91 timestamp: Some(now),
92 };
93
94 if let Some(mut task) = self.storage.get_task(task_id).await {
95 task.status = new_status.clone();
96 if let Some(ref msg) = message {
97 task.history.push(msg.clone());
98 }
99 self.storage.put_task(task).await;
100 }
101
102 let event = TaskStatusUpdateEvent {
103 context_id: context_id.to_string(),
104 final_,
105 metadata: None,
106 status: new_status,
107 task_id: task_id.to_string(),
108 };
109
110 self.emit(StreamResponse {
111 artifact_update: None,
112 message: None,
113 status_update: Some(event),
114 task: None,
115 })
116 .await
117 }
118
119 pub async fn emit_text_artifact(
122 &self,
123 task_id: &str,
124 context_id: &str,
125 text: impl Into<String>,
126 last_chunk: bool,
127 ) -> Result<()> {
128 let artifact_id = uuid::Uuid::new_v4().to_string();
129 let text = text.into();
130 let artifact = Artifact {
131 artifact_id: artifact_id.clone(),
132 description: None,
133 extensions: vec![],
134 metadata: None,
135 name: None,
136 parts: vec![Part {
137 data: None,
138 file: None,
139 metadata: None,
140 text: Some(text),
141 }],
142 };
143
144 if let Some(mut task) = self.storage.get_task(task_id).await {
145 task.artifacts.push(artifact.clone());
146 self.storage.put_task(task).await;
147 }
148
149 let event = TaskArtifactUpdateEvent {
150 append: None,
151 artifact,
152 context_id: context_id.to_string(),
153 last_chunk: Some(last_chunk),
154 metadata: None,
155 task_id: task_id.to_string(),
156 };
157
158 self.emit(StreamResponse {
159 artifact_update: Some(event),
160 message: None,
161 status_update: None,
162 task: None,
163 })
164 .await
165 }
166}
167
168pub(super) fn build_agent_text_message(task: &Task, text: &str) -> A2AMessage {
169 A2AMessage {
170 context_id: Some(task.context_id.clone()),
171 extensions: vec![],
172 message_id: uuid::Uuid::new_v4().to_string(),
173 metadata: None,
174 parts: vec![Part {
175 data: None,
176 file: None,
177 metadata: None,
178 text: Some(text.to_string()),
179 }],
180 reference_task_ids: vec![],
181 role: Role::RoleAgent,
182 task_id: Some(task.id.clone()),
183 }
184}
185
186fn message_content_to_string(content: &MessageContent) -> String {
187 match content {
188 MessageContent::String(s) => s.clone(),
189 MessageContent::Array(parts) => serde_json::to_string(parts).unwrap_or_default(),
190 }
191}
192
193fn build_sdk_messages(agent: &Agent, task: &Task) -> Vec<Message> {
198 let mut messages: Vec<Message> = Vec::new();
199 if let Some(prompt) = agent.system_prompt.clone() {
200 messages.push(Message {
201 role: MessageRole::System,
202 content: MessageContent::String(prompt),
203 reasoning: None,
204 reasoning_content: None,
205 tool_call_id: None,
206 tool_calls: Vec::new(),
207 });
208 }
209 for a2a_msg in &task.history {
210 let text = a2a_msg
211 .parts
212 .iter()
213 .filter_map(|p| p.text.clone())
214 .collect::<Vec<_>>()
215 .join("");
216 if text.is_empty() {
217 continue;
218 }
219 let role = match a2a_msg.role {
220 Role::RoleAgent => MessageRole::Assistant,
221 _ => MessageRole::User,
222 };
223 messages.push(Message {
224 role,
225 content: MessageContent::String(text),
226 reasoning: None,
227 reasoning_content: None,
228 tool_call_id: None,
229 tool_calls: Vec::new(),
230 });
231 }
232 messages
233}
234
/// Reply sent by the default handlers when no agent is configured.
const NO_AGENT_REPLY: &str = concat!(
    "I received your message. I'm a default task handler without AI capabilities. ",
    "To enable AI responses, configure an OpenAI-compatible agent via ",
    "`A2AServerBuilder::with_agent(...)`."
);
240
241struct ToolLoopOutcome {
247 messages: Vec<Message>,
248 final_text: String,
249 exhausted: bool,
250}
251
252async fn run_tool_loop(agent: &Agent, mut messages: Vec<Message>) -> Result<ToolLoopOutcome> {
263 let llm = agent.llm_client();
264 let tools = agent.toolbox.clone();
265 let max_iterations = agent.max_chat_completion().max(1) as usize;
266
267 for _ in 0..max_iterations {
268 let response = llm
269 .create_chat_completion(messages.clone(), tools.clone())
270 .await
271 .map_err(|e| anyhow!("llm call failed: {e}"))?;
272
273 let Some(choice) = response.choices.into_iter().next() else {
274 return Ok(ToolLoopOutcome {
275 messages,
276 final_text: String::new(),
277 exhausted: false,
278 });
279 };
280
281 let assistant_text = message_content_to_string(&choice.message.content);
282 let tool_calls = choice.message.tool_calls.clone();
283 let reasoning = choice.message.reasoning.clone();
284 let reasoning_content = choice.message.reasoning_content.clone();
285
286 messages.push(Message {
287 role: MessageRole::Assistant,
288 content: MessageContent::String(assistant_text.clone()),
289 reasoning,
290 reasoning_content,
291 tool_call_id: None,
292 tool_calls: tool_calls.clone(),
293 });
294
295 if tool_calls.is_empty() {
296 return Ok(ToolLoopOutcome {
297 messages,
298 final_text: assistant_text,
299 exhausted: false,
300 });
301 }
302
303 for tool_call in tool_calls {
304 let tool_name = tool_call.function.name.clone();
305 let args: Value = serde_json::from_str(&tool_call.function.arguments)
306 .unwrap_or_else(|_| Value::String(tool_call.function.arguments.clone()));
307
308 debug!("tool dispatch: {tool_name}");
309
310 let tool_result = match agent.tool_handler(&tool_name) {
311 Some(handler) => match handler.handle(args).await {
312 Ok(value) => value,
313 Err(e) => format!("tool `{tool_name}` failed: {e}"),
314 },
315 None => format!("no handler registered for tool `{tool_name}`"),
316 };
317
318 messages.push(Message {
319 role: MessageRole::Tool,
320 content: MessageContent::String(tool_result),
321 reasoning: None,
322 reasoning_content: None,
323 tool_call_id: Some(tool_call.id.clone()),
324 tool_calls: Vec::new(),
325 });
326 }
327 }
328
329 Ok(ToolLoopOutcome {
330 messages,
331 final_text: String::new(),
332 exhausted: true,
333 })
334}
335
336#[derive(Debug)]
345pub struct DefaultBackgroundTaskHandler {
346 agent: Option<Arc<Agent>>,
347}
348
349impl DefaultBackgroundTaskHandler {
350 pub fn new(agent: Option<Arc<Agent>>) -> Self {
351 Self { agent }
352 }
353}
354
355#[async_trait::async_trait]
356impl TaskHandler for DefaultBackgroundTaskHandler {
357 async fn handle_task(&self, mut task: Task, _message: Option<A2AMessage>) -> Result<Task> {
358 let (reply_text, terminal_state) = match self.agent.as_ref() {
359 Some(agent) => {
360 let messages = build_sdk_messages(agent, &task);
361 match run_tool_loop(agent, messages).await {
362 Ok(outcome) if outcome.exhausted => {
363 warn!(
364 "default background handler: tool loop exhausted \
365 after {} iterations without a final answer",
366 agent.max_chat_completion()
367 );
368 (
369 "Tool loop exhausted before the model produced a \
370 final answer."
371 .to_string(),
372 TaskState::TaskStateFailed,
373 )
374 }
375 Ok(outcome) => {
376 let text = if outcome.final_text.is_empty() {
377 "Task completed".to_string()
378 } else {
379 outcome.final_text
380 };
381 (text, TaskState::TaskStateCompleted)
382 }
383 Err(e) => {
384 warn!("default background handler: agent call failed: {e}");
385 (
386 format!("Agent call failed: {e}"),
387 TaskState::TaskStateFailed,
388 )
389 }
390 }
391 }
392 None => (NO_AGENT_REPLY.to_string(), TaskState::TaskStateCompleted),
393 };
394
395 let reply = build_agent_text_message(&task, &reply_text);
396 task.history.push(reply.clone());
397 task.status = TaskStatus {
398 message: Some(reply),
399 state: terminal_state,
400 timestamp: Some(Timestamp(chrono::Utc::now())),
401 };
402 Ok(task)
403 }
404}
405
406#[derive(Debug)]
420pub struct DefaultStreamingTaskHandler {
421 agent: Option<Arc<Agent>>,
422}
423
424impl DefaultStreamingTaskHandler {
425 pub fn new(agent: Option<Arc<Agent>>) -> Self {
426 Self { agent }
427 }
428}
429
430#[async_trait::async_trait]
431impl StreamableTaskHandler for DefaultStreamingTaskHandler {
432 async fn handle_streaming_task(
433 &self,
434 task: Task,
435 _message: Option<A2AMessage>,
436 emitter: StreamEmitter,
437 ) -> Result<()> {
438 emitter
439 .emit_status(
440 &task.id,
441 &task.context_id,
442 TaskState::TaskStateWorking,
443 None,
444 false,
445 )
446 .await?;
447
448 let final_text = match self.agent.as_ref() {
449 Some(agent) => stream_agent_deltas(agent, &task, &emitter).await?,
450 None => {
451 emitter
452 .emit_text_artifact(&task.id, &task.context_id, NO_AGENT_REPLY, true)
453 .await?;
454 NO_AGENT_REPLY.to_string()
455 }
456 };
457
458 let reply_message = build_agent_text_message(&task, &final_text);
459 emitter
460 .emit_status(
461 &task.id,
462 &task.context_id,
463 TaskState::TaskStateCompleted,
464 Some(reply_message),
465 true,
466 )
467 .await
468 }
469}
470
471async fn stream_agent_deltas(
483 agent: &Agent,
484 task: &Task,
485 emitter: &StreamEmitter,
486) -> Result<String> {
487 let base_messages = build_sdk_messages(agent, task);
488
489 let messages = if agent.toolbox().is_some() {
490 match run_tool_loop(agent, base_messages).await {
491 Ok(outcome) if outcome.exhausted => {
492 let msg = "Tool loop exhausted before the model produced a \
493 final answer."
494 .to_string();
495 emitter
496 .emit_text_artifact(&task.id, &task.context_id, &msg, true)
497 .await?;
498 return Ok(msg);
499 }
500 Ok(outcome) => {
501 if !outcome.final_text.is_empty()
502 && outcome
503 .messages
504 .last()
505 .map(|m| m.tool_calls.is_empty())
506 .unwrap_or(true)
507 {
508 emitter
509 .emit_text_artifact(&task.id, &task.context_id, &outcome.final_text, true)
510 .await?;
511 return Ok(outcome.final_text);
512 }
513 outcome.messages
514 }
515 Err(e) => {
516 warn!("default streaming handler: tool loop failed: {e}");
517 let msg = format!("Agent stream failed: {e}");
518 emitter
519 .emit_text_artifact(&task.id, &task.context_id, &msg, true)
520 .await?;
521 return Ok(msg);
522 }
523 }
524 } else {
525 base_messages
526 };
527
528 let llm = agent.llm_client();
529 let tools = agent.toolbox.clone();
530 let mut stream = llm.create_streaming_chat_completion(messages, tools);
531
532 let artifact_id = uuid::Uuid::new_v4().to_string();
533 let mut buffer = String::new();
534
535 while let Some(item) = stream.next().await {
536 let event = match item {
537 Ok(e) => e,
538 Err(e) => {
539 warn!("default streaming handler: gateway error: {e}");
540 let msg = format!("Agent stream failed: {e}");
541 emitter
542 .emit_text_artifact(&task.id, &task.context_id, &msg, true)
543 .await?;
544 return Ok(msg);
545 }
546 };
547
548 let data = event.data.trim();
549 if data.is_empty() || data == "[DONE]" {
550 if data == "[DONE]" {
551 break;
552 }
553 continue;
554 }
555
556 let parsed: serde_json::Value = match serde_json::from_str(data) {
557 Ok(v) => v,
558 Err(_) => continue,
559 };
560 let Some(text) = parsed
561 .get("choices")
562 .and_then(|c| c.as_array())
563 .and_then(|arr| arr.first())
564 .and_then(|c| c.get("delta"))
565 .and_then(|d| d.get("content"))
566 .and_then(|t| t.as_str())
567 else {
568 continue;
569 };
570 if text.is_empty() {
571 continue;
572 }
573 buffer.push_str(text);
574
575 let chunk_event = TaskArtifactUpdateEvent {
576 append: Some(true),
577 artifact: Artifact {
578 artifact_id: artifact_id.clone(),
579 description: None,
580 extensions: vec![],
581 metadata: None,
582 name: None,
583 parts: vec![Part {
584 data: None,
585 file: None,
586 metadata: None,
587 text: Some(text.to_string()),
588 }],
589 },
590 context_id: task.context_id.clone(),
591 last_chunk: Some(false),
592 metadata: None,
593 task_id: task.id.clone(),
594 };
595 emitter
596 .emit(StreamResponse {
597 artifact_update: Some(chunk_event),
598 message: None,
599 status_update: None,
600 task: None,
601 })
602 .await?;
603 }
604
605 let final_event = TaskArtifactUpdateEvent {
606 append: Some(true),
607 artifact: Artifact {
608 artifact_id,
609 description: None,
610 extensions: vec![],
611 metadata: None,
612 name: None,
613 parts: vec![],
614 },
615 context_id: task.context_id.clone(),
616 last_chunk: Some(true),
617 metadata: None,
618 task_id: task.id.clone(),
619 };
620 emitter
621 .emit(StreamResponse {
622 artifact_update: Some(final_event),
623 message: None,
624 status_update: None,
625 task: None,
626 })
627 .await?;
628
629 Ok(buffer)
630}
631
#[cfg(test)]
mod tests {
    //! End-to-end tests for the default task handlers.
    //!
    //! Each test starts a mock LLM gateway and an A2A server on ephemeral
    //! local ports, then drives the server through `A2AClient`.

    use super::*;
    use crate::a2a_types::{AgentCard, Role, SendMessageRequest};
    use crate::server::agent_builder::AgentBuilder;
    use crate::server::protocol::{AppState, a2a_handler};
    use crate::server::server_builder::A2AServerBuilder;
    use axum::Router;
    use axum::extract::State;
    use axum::response::Json;
    use axum::routing::post;
    use inference_gateway_sdk::{
        ChatCompletionTool, ChatCompletionToolType, FunctionObject, FunctionParameters,
    };
    use tokio::net::TcpListener;

    /// Minimal agent card deserialized from JSON; only the `streaming`
    /// capability varies between tests.
    fn agent_card_with_streaming(streaming: bool) -> AgentCard {
        serde_json::from_value(serde_json::json!({
            "name": "Validation Agent",
            "description": "Builder validation tests",
            "version": "0.0.0",
            "protocolVersion": "0.2.6",
            "url": "http://localhost/a2a",
            "preferredTransport": "JSONRPC",
            "capabilities": {
                "streaming": streaming,
                "pushNotifications": false,
                "stateTransitionHistory": false
            },
            "defaultInputModes": ["text/plain"],
            "defaultOutputModes": ["text/plain"],
            "skills": [
                {"id": "x", "name": "x", "description": "x", "tags": ["x"]}
            ]
        }))
        .unwrap()
    }

    /// With an agent and no toolbox configured, the default streaming handler:
    /// - forwards each gateway SSE delta as an `append` chunk of one artifact;
    /// - closes that artifact with an empty `last_chunk` chunk;
    /// - ends with a final `Completed` status carrying the assembled text.
    #[tokio::test]
    async fn default_streaming_handler_iterates_gateway_deltas() {
        use crate::A2AClient;
        use crate::a2a_types::Message as A2AMessage;
        use crate::config::AgentConfig;
        use axum::response::sse::{Event as SseEvent, KeepAlive as SseKeepAlive, Sse as SseResp};
        use futures_util::StreamExt as _;

        // Mock gateway `/chat/completions`: streams three content deltas
        // ("Hel", "lo ", "world") as SSE events, followed by `[DONE]`.
        async fn chat_completions() -> SseResp<
            impl futures_util::Stream<Item = std::result::Result<SseEvent, std::convert::Infallible>>,
        > {
            let deltas = [
                serde_json::json!({"choices":[{"delta":{"content":"Hel"}}]}).to_string(),
                serde_json::json!({"choices":[{"delta":{"content":"lo "}}]}).to_string(),
                serde_json::json!({"choices":[{"delta":{"content":"world"}}]}).to_string(),
                "[DONE]".to_string(),
            ];
            let stream = futures_util::stream::iter(
                deltas
                    .into_iter()
                    .map(|d| Ok::<_, std::convert::Infallible>(SseEvent::default().data(d))),
            );
            SseResp::new(stream).keep_alive(SseKeepAlive::default())
        }

        let gateway_listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind gateway");
        let gateway_addr = gateway_listener.local_addr().expect("addr");
        let gateway_app = Router::new().route("/chat/completions", post(chat_completions));
        tokio::spawn(async move {
            axum::serve(gateway_listener, gateway_app).await.ok();
        });

        let agent_card = agent_card_with_streaming(true);
        // Agent pointed at the mock gateway. No `with_toolbox` call: the test
        // expects the handler to go straight to the streaming completion.
        let agent_config = AgentConfig {
            provider: "openai".to_string(),
            model: "test-model".to_string(),
            base_url: Some(format!("http://{gateway_addr}")),
            ..AgentConfig::default()
        };
        let agent = AgentBuilder::new()
            .with_config(&agent_config)
            .build()
            .await
            .expect("agent builds");

        let server = A2AServerBuilder::new()
            .with_agent_card(agent_card)
            .with_agent(agent)
            .with_default_task_handlers()
            .build()
            .await
            .expect("server builds");

        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind a2a");
        let addr = listener.local_addr().expect("addr");
        let app = Router::new()
            .route("/a2a", post(a2a_handler))
            .with_state(Arc::new(AppState::new(server)));
        tokio::spawn(async move {
            axum::serve(listener, app).await.ok();
        });

        let client = A2AClient::new(format!("http://{addr}")).expect("client");

        let request = SendMessageRequest {
            configuration: None,
            message: Some(A2AMessage {
                context_id: None,
                extensions: vec![],
                message_id: "msg-default-stream".to_string(),
                metadata: None,
                parts: vec![Part {
                    data: None,
                    file: None,
                    metadata: None,
                    text: Some("hi".to_string()),
                }],
                reference_task_ids: vec![],
                role: Role::RoleUser,
                task_id: None,
            }),
            metadata: None,
            tenant: "tests".to_string(),
        };

        // Drain the whole stream before asserting on the collected events.
        let mut stream = Box::pin(client.stream_message(request).await.expect("stream"));
        let mut events: Vec<StreamResponse> = Vec::new();
        while let Some(item) = stream.next().await {
            events.push(item.expect("event"));
        }

        // Expected sequence: task, Working status, three deltas, terminal
        // chunk, final Completed status.
        assert_eq!(
            events.len(),
            7,
            "unexpected event count {}: {:?}",
            events.len(),
            events
        );

        assert!(events[0].task.is_some(), "first event carries task");
        let working = events[1]
            .status_update
            .as_ref()
            .expect("second event is status update");
        assert_eq!(working.status.state, TaskState::TaskStateWorking);
        assert!(!working.final_);

        // Events 2..=4 are the deltas, in gateway order, on one artifact id.
        let mut artifact_ids = std::collections::HashSet::new();
        let chunks: Vec<String> = (2..=4)
            .map(|i| {
                let upd = events[i]
                    .artifact_update
                    .as_ref()
                    .unwrap_or_else(|| panic!("event[{i}] should be an artifact update"));
                assert_eq!(upd.append, Some(true), "deltas must have append=true");
                assert_eq!(upd.last_chunk, Some(false));
                artifact_ids.insert(upd.artifact.artifact_id.clone());
                upd.artifact
                    .parts
                    .iter()
                    .filter_map(|p| p.text.clone())
                    .collect::<String>()
            })
            .collect();
        assert_eq!(chunks, vec!["Hel", "lo ", "world"]);
        assert_eq!(
            artifact_ids.len(),
            1,
            "all deltas must share a single artifact_id"
        );

        let terminal_artifact = events[5]
            .artifact_update
            .as_ref()
            .expect("event[5] should be the terminal artifact chunk");
        assert_eq!(terminal_artifact.last_chunk, Some(true));
        assert!(
            terminal_artifact.artifact.parts.is_empty(),
            "terminal chunk should have empty parts"
        );
        assert_eq!(
            artifact_ids.iter().next().unwrap(),
            &terminal_artifact.artifact.artifact_id,
            "terminal chunk must share artifact_id with deltas"
        );

        // The final status message must carry the concatenation of all deltas.
        let completed = events[6]
            .status_update
            .as_ref()
            .expect("event[6] should be the Completed status");
        assert_eq!(completed.status.state, TaskState::TaskStateCompleted);
        assert!(completed.final_);
        let assembled = completed
            .status
            .message
            .as_ref()
            .expect("completed status carries the final message")
            .parts
            .iter()
            .filter_map(|p| p.text.clone())
            .collect::<String>();
        assert_eq!(assembled, "Hello world");
    }

    /// Shared state of the tool-calling mock gateway.
    #[derive(Clone, Default)]
    struct ToolMockState {
        /// Number of `/chat/completions` requests served so far.
        non_streaming_calls: std::sync::Arc<std::sync::atomic::AtomicUsize>,
        /// `content` of every `tool`-role message seen in request bodies.
        captured_tool_results: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
    }

    /// Non-streaming completion whose only choice requests a single `echo_arg`
    /// call (id `call_1`) with arguments `{"text":"hi"}`.
    fn tool_call_response_json() -> serde_json::Value {
        serde_json::json!({
            "id": "chatcmpl-tool",
            "object": "chat.completion",
            "created": 0,
            "model": "test-model",
            "choices": [{
                "index": 0,
                "finish_reason": "tool_calls",
                "message": {
                    "role": "assistant",
                    "content": "",
                    "tool_calls": [{
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "echo_arg",
                            "arguments": "{\"text\":\"hi\"}",
                        }
                    }],
                },
            }],
        })
    }

    /// Non-streaming completion whose only choice is a plain answer `text`
    /// with no tool calls.
    fn final_answer_response_json(text: &str) -> serde_json::Value {
        serde_json::json!({
            "id": "chatcmpl-final",
            "object": "chat.completion",
            "created": 0,
            "model": "test-model",
            "choices": [{
                "index": 0,
                "finish_reason": "stop",
                "message": {
                    "role": "assistant",
                    "content": text,
                    "tool_calls": [],
                },
            }],
        })
    }

    /// Mock gateway logic.
    ///
    /// 1. Records the content of every `tool`-role message in the request.
    /// 2. Answers the first request with a tool call, and every later request
    ///    with the final answer "12 is the tool result".
    async fn mock_non_streaming(
        State(state): State<std::sync::Arc<ToolMockState>>,
        body: Value,
    ) -> Json<Value> {
        if let Some(msgs) = body.get("messages").and_then(|v| v.as_array()) {
            for m in msgs {
                if m.get("role").and_then(|v| v.as_str()) == Some("tool")
                    && let Some(text) = m.get("content").and_then(|v| v.as_str())
                {
                    state
                        .captured_tool_results
                        .lock()
                        .expect("mutex poisoned")
                        .push(text.to_string());
                }
            }
        }
        let call_index = state
            .non_streaming_calls
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        if call_index == 0 {
            Json(tool_call_response_json())
        } else {
            Json(final_answer_response_json("12 is the tool result"))
        }
    }

    /// `/chat/completions` route for the tool mock.
    ///
    /// Parses the raw body as JSON (panicking if it is not) and delegates to
    /// `mock_non_streaming`. It always answers with a JSON body, never SSE.
    async fn mock_chat_completions(
        State(state): State<std::sync::Arc<ToolMockState>>,
        body: axum::body::Bytes,
    ) -> Json<Value> {
        let parsed: Value = serde_json::from_slice(&body).expect("valid JSON");
        mock_non_streaming(State(state), parsed).await
    }

    /// Builds an agent wired to `gateway_url` with a single `echo_arg` tool.
    ///
    /// The tool pushes each `text` argument it receives into the returned
    /// recorder and replies `"echoed: {text}"`.
    async fn build_echo_agent_with_recorder(
        gateway_url: String,
    ) -> (Agent, std::sync::Arc<std::sync::Mutex<Vec<String>>>) {
        use crate::config::AgentConfig;

        let recorded = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
        let recorded_clone = std::sync::Arc::clone(&recorded);

        let echo_tool = ChatCompletionTool {
            type_: ChatCompletionToolType::Function,
            function: FunctionObject {
                name: "echo_arg".to_string(),
                description: Some("echo back the text arg".to_string()),
                parameters: Some(FunctionParameters(
                    serde_json::json!({
                        "type": "object",
                        "properties": {"text": {"type": "string"}},
                        "required": ["text"],
                    })
                    .as_object()
                    .unwrap()
                    .clone(),
                )),
                strict: false,
            },
        };

        let agent_cfg = AgentConfig {
            provider: "openai".to_string(),
            model: "test-model".to_string(),
            base_url: Some(gateway_url),
            ..AgentConfig::default()
        };

        let agent = AgentBuilder::new()
            .with_config(&agent_cfg)
            .with_toolbox(vec![echo_tool])
            .with_async_function_tool("echo_arg".to_string(), move |args: Value| {
                let recorded = std::sync::Arc::clone(&recorded_clone);
                async move {
                    let text = args
                        .get("text")
                        .and_then(|v| v.as_str())
                        .unwrap_or("")
                        .to_string();
                    recorded.lock().expect("mutex poisoned").push(text.clone());
                    Ok(format!("echoed: {text}"))
                }
            })
            .build()
            .await
            .expect("agent builds");
        (agent, recorded)
    }

    /// The background handler runs the tool loop end to end:
    /// - the model requests `echo_arg` and the tool fires exactly once;
    /// - the follow-up gateway call carries the tool result;
    /// - the task completes with the model's final answer.
    #[tokio::test]
    async fn default_background_handler_dispatches_tool_calls() {
        use crate::A2AClient;
        use crate::a2a_types::Message as A2AMessage;

        let mock_state = std::sync::Arc::new(ToolMockState::default());
        let gateway_listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let gateway_addr = gateway_listener.local_addr().expect("addr");
        let gateway_app = Router::new()
            .route("/chat/completions", post(mock_chat_completions))
            .with_state(std::sync::Arc::clone(&mock_state));
        tokio::spawn(async move {
            axum::serve(gateway_listener, gateway_app).await.ok();
        });

        let (agent, recorded) =
            build_echo_agent_with_recorder(format!("http://{gateway_addr}")).await;
        let card = agent_card_with_streaming(false);

        let mut server = A2AServerBuilder::new()
            .with_agent_card(card)
            .with_agent(agent)
            .with_default_background_task_handler()
            .build()
            .await
            .expect("server builds");

        // The task manager is taken out of the server and started here so the
        // test can shut it down explicitly at the end.
        let runner = server
            .task_manager
            .take()
            .expect("task manager configured for background handler")
            .start();

        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind a2a");
        let addr = listener.local_addr().expect("a2a addr");
        let app = Router::new()
            .route("/a2a", post(a2a_handler))
            .with_state(Arc::new(AppState::new(server)));
        tokio::spawn(async move {
            axum::serve(listener, app).await.ok();
        });

        let client = A2AClient::new(format!("http://{addr}")).expect("client");
        let response = client
            .send_message(SendMessageRequest {
                configuration: None,
                message: Some(A2AMessage {
                    context_id: None,
                    extensions: vec![],
                    message_id: "msg-bg-tool".to_string(),
                    metadata: None,
                    parts: vec![Part {
                        data: None,
                        file: None,
                        metadata: None,
                        text: Some("ask".to_string()),
                    }],
                    reference_task_ids: vec![],
                    role: Role::RoleUser,
                    task_id: None,
                }),
                metadata: None,
                tenant: "tests".to_string(),
            })
            .await
            .expect("message/send");

        // `message/send` returns immediately with the task still Submitted;
        // the outcome is observed by polling.
        let submitted = response.task.expect("task in response");
        assert_eq!(submitted.status.state, TaskState::TaskStateSubmitted);

        let final_task = poll_until_terminal(&client, &submitted.id).await;
        assert_eq!(final_task.status.state, TaskState::TaskStateCompleted);
        let final_text = final_task
            .status
            .message
            .expect("final agent message")
            .parts
            .iter()
            .filter_map(|p| p.text.clone())
            .collect::<String>();
        assert_eq!(final_text, "12 is the tool result");

        assert_eq!(
            recorded.lock().expect("mutex poisoned").clone(),
            vec!["hi".to_string()],
            "echo_arg should fire exactly once with the model-supplied argument",
        );
        assert_eq!(
            mock_state
                .captured_tool_results
                .lock()
                .expect("mutex poisoned")
                .clone(),
            vec!["echoed: hi".to_string()],
            "second gateway call should include the tool result as a Tool-role message",
        );

        runner.shutdown().await;
    }

    /// Polls `tasks/get` until the task reaches a terminal state
    /// (Completed/Failed/Cancelled/Rejected).
    ///
    /// It checks every 20 ms, up to 100 times, and panics if the task never
    /// becomes terminal.
    async fn poll_until_terminal(client: &crate::A2AClient, task_id: &str) -> Task {
        for _ in 0..100 {
            let fetched = client
                .get_task(crate::a2a_types::GetTaskRequest {
                    history_length: None,
                    name: format!("tasks/{task_id}"),
                    tenant: Some("tests".to_string()),
                })
                .await
                .expect("tasks/get");
            if matches!(
                fetched.status.state,
                TaskState::TaskStateCompleted
                    | TaskState::TaskStateFailed
                    | TaskState::TaskStateCancelled
                    | TaskState::TaskStateRejected
            ) {
                return fetched;
            }
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        }
        panic!("task {task_id} never reached terminal state within 2s");
    }

    /// The streaming handler with a toolbox runs the tool loop as a preflight:
    /// - the tool fires once;
    /// - no tool-lifecycle status updates appear on the stream;
    /// - the streamed artifacts carry the model's final answer;
    /// - the stream ends with a final `Completed` status.
    #[tokio::test]
    async fn default_streaming_handler_dispatches_tool_calls() {
        use crate::A2AClient;
        use crate::a2a_types::Message as A2AMessage;
        use futures_util::StreamExt;

        let mock_state = std::sync::Arc::new(ToolMockState::default());
        let gateway_listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let gateway_addr = gateway_listener.local_addr().expect("addr");
        let gateway_app = Router::new()
            .route("/chat/completions", post(mock_chat_completions))
            .with_state(std::sync::Arc::clone(&mock_state));
        tokio::spawn(async move {
            axum::serve(gateway_listener, gateway_app).await.ok();
        });

        let (agent, recorded) =
            build_echo_agent_with_recorder(format!("http://{gateway_addr}")).await;
        let card = agent_card_with_streaming(true);

        let server = A2AServerBuilder::new()
            .with_agent_card(card)
            .with_agent(agent)
            .with_default_streaming_task_handler()
            .build()
            .await
            .expect("server builds");

        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind a2a");
        let addr = listener.local_addr().expect("a2a addr");
        let app = Router::new()
            .route("/a2a", post(a2a_handler))
            .with_state(Arc::new(AppState::new(server)));
        tokio::spawn(async move {
            axum::serve(listener, app).await.ok();
        });

        let client = A2AClient::new(format!("http://{addr}")).expect("client");
        let request = SendMessageRequest {
            configuration: None,
            message: Some(A2AMessage {
                context_id: None,
                extensions: vec![],
                message_id: "msg-stream-tool".to_string(),
                metadata: None,
                parts: vec![Part {
                    data: None,
                    file: None,
                    metadata: None,
                    text: Some("ask".to_string()),
                }],
                reference_task_ids: vec![],
                role: Role::RoleUser,
                task_id: None,
            }),
            metadata: None,
            tenant: "tests".to_string(),
        };

        let mut stream = Box::pin(client.stream_message(request).await.expect("stream"));
        let mut events: Vec<StreamResponse> = Vec::new();
        while let Some(item) = stream.next().await {
            events.push(item.expect("event"));
        }

        assert_eq!(
            recorded.lock().expect("mutex poisoned").clone(),
            vec!["hi".to_string()],
            "echo_arg should fire once during the tool-loop preflight"
        );

        // Tool dispatch happens silently: no status message mentions it.
        let saw_tool_status = events.iter().any(|e| {
            e.status_update
                .as_ref()
                .and_then(|u| u.status.message.as_ref())
                .map(|m| {
                    m.parts
                        .iter()
                        .filter_map(|p| p.text.clone())
                        .any(|t| t.contains("calling tool"))
                })
                .unwrap_or(false)
        });
        assert!(
            !saw_tool_status,
            "stream should NOT carry tool-lifecycle status updates",
        );

        // Concatenate the text of every artifact event on the stream.
        let accumulated: String = events
            .iter()
            .filter_map(|e| e.artifact_update.as_ref())
            .flat_map(|a| {
                a.artifact
                    .parts
                    .iter()
                    .filter_map(|p| p.text.clone())
                    .collect::<Vec<_>>()
            })
            .collect::<String>();
        assert_eq!(accumulated, "12 is the tool result");

        let last = events.last().expect("at least one event");
        let last_status = last
            .status_update
            .as_ref()
            .expect("last event is a status update");
        assert_eq!(last_status.status.state, TaskState::TaskStateCompleted);
        assert!(last_status.final_);
    }
}