//! Task handler traits, the [`StreamEmitter`] helper, and the default
//! `message/send` / `message/stream` handlers.
//!
//! Path: `inference_gateway_adk/server/task_handler.rs`

1use super::agent::Agent;
2use super::storage::Storage;
3use crate::a2a_types::{
4    Artifact, Message as A2AMessage, Part, Role, StreamResponse, Task, TaskArtifactUpdateEvent,
5    TaskState, TaskStatus, TaskStatusUpdateEvent, Timestamp,
6};
7use anyhow::{Result, anyhow};
8use futures_util::stream::StreamExt;
9use inference_gateway_sdk::{Message, MessageContent, MessageRole};
10use serde_json::Value;
11use std::sync::Arc;
12use tokio::sync::mpsc;
13use tracing::{debug, warn};
14
/// Handler invoked by the server for `message/send` requests.
///
/// Implementations receive a freshly-built task (already in
/// `TaskStateSubmitted`) plus the incoming user message, run the business
/// logic, and return the final task - typically with `state == Completed`
/// and an agent reply attached to `status.message`.
#[async_trait::async_trait]
pub trait TaskHandler: Send + Sync + std::fmt::Debug {
    /// Process `task` to completion and return it in a terminal state.
    ///
    /// `message` is the triggering user message, when one was supplied.
    async fn handle_task(&self, task: Task, message: Option<A2AMessage>) -> Result<Task>;
}
25
/// Handler invoked by the server for `message/stream` requests.
///
/// The server is responsible for parsing the request, persisting the initial
/// `Submitted` task, and emitting the first event (the `Task` wrapper). The
/// handler then drives the task to a terminal state by emitting
/// `StreamResponse` events via [`StreamEmitter`]. The last emitted event
/// **must** carry a `TaskStatusUpdateEvent` with `final: true`; otherwise
/// callers will treat the stream as unterminated.
#[async_trait::async_trait]
pub trait StreamableTaskHandler: Send + Sync + std::fmt::Debug {
    /// Drive a `message/stream` interaction.
    ///
    /// `task` is the freshly-built task already persisted in storage at
    /// `TaskStateSubmitted`. The handler should emit subsequent events
    /// (typically `Working` → optional artifact(s) → `Completed`).
    /// `message` is the triggering user message, when one was supplied.
    async fn handle_streaming_task(
        &self,
        task: Task,
        message: Option<A2AMessage>,
        emitter: StreamEmitter,
    ) -> Result<()>;
}
48
/// Emits `StreamResponse` events into an active `message/stream` response and
/// keeps the stored task in sync with the latest status.
#[derive(Clone)]
pub struct StreamEmitter {
    // Channel feeding events back to the connected streaming client.
    tx: mpsc::Sender<StreamResponse>,
    // Shared task store, updated alongside each emitted status/artifact.
    storage: Arc<dyn Storage>,
}
56
57impl std::fmt::Debug for StreamEmitter {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("StreamEmitter").finish_non_exhaustive()
60    }
61}
62
63impl StreamEmitter {
64    pub(super) fn new(tx: mpsc::Sender<StreamResponse>, storage: Arc<dyn Storage>) -> Self {
65        Self { tx, storage }
66    }
67
68    /// Send a raw `StreamResponse` to the connected client.
69    pub async fn emit(&self, response: StreamResponse) -> Result<()> {
70        self.tx
71            .send(response)
72            .await
73            .map_err(|_| anyhow!("stream receiver dropped before handler finished"))
74    }
75
76    /// Convenience helper that updates the stored task to `state` (attaching
77    /// `message` to the task status if provided), then emits a
78    /// `TaskStatusUpdateEvent` describing the new state.
79    pub async fn emit_status(
80        &self,
81        task_id: &str,
82        context_id: &str,
83        state: TaskState,
84        message: Option<A2AMessage>,
85        final_: bool,
86    ) -> Result<()> {
87        let now = Timestamp(chrono::Utc::now());
88        let new_status = TaskStatus {
89            message: message.clone(),
90            state,
91            timestamp: Some(now),
92        };
93
94        if let Some(mut task) = self.storage.get_task(task_id).await {
95            task.status = new_status.clone();
96            if let Some(ref msg) = message {
97                task.history.push(msg.clone());
98            }
99            self.storage.put_task(task).await;
100        }
101
102        let event = TaskStatusUpdateEvent {
103            context_id: context_id.to_string(),
104            final_,
105            metadata: None,
106            status: new_status,
107            task_id: task_id.to_string(),
108        };
109
110        self.emit(StreamResponse {
111            artifact_update: None,
112            message: None,
113            status_update: Some(event),
114            task: None,
115        })
116        .await
117    }
118
119    /// Convenience helper that appends a text artifact to the stored task and
120    /// emits a `TaskArtifactUpdateEvent` describing it.
121    pub async fn emit_text_artifact(
122        &self,
123        task_id: &str,
124        context_id: &str,
125        text: impl Into<String>,
126        last_chunk: bool,
127    ) -> Result<()> {
128        let artifact_id = uuid::Uuid::new_v4().to_string();
129        let text = text.into();
130        let artifact = Artifact {
131            artifact_id: artifact_id.clone(),
132            description: None,
133            extensions: vec![],
134            metadata: None,
135            name: None,
136            parts: vec![Part {
137                data: None,
138                file: None,
139                metadata: None,
140                text: Some(text),
141            }],
142        };
143
144        if let Some(mut task) = self.storage.get_task(task_id).await {
145            task.artifacts.push(artifact.clone());
146            self.storage.put_task(task).await;
147        }
148
149        let event = TaskArtifactUpdateEvent {
150            append: None,
151            artifact,
152            context_id: context_id.to_string(),
153            last_chunk: Some(last_chunk),
154            metadata: None,
155            task_id: task_id.to_string(),
156        };
157
158        self.emit(StreamResponse {
159            artifact_update: Some(event),
160            message: None,
161            status_update: None,
162            task: None,
163        })
164        .await
165    }
166}
167
168pub(super) fn build_agent_text_message(task: &Task, text: &str) -> A2AMessage {
169    A2AMessage {
170        context_id: Some(task.context_id.clone()),
171        extensions: vec![],
172        message_id: uuid::Uuid::new_v4().to_string(),
173        metadata: None,
174        parts: vec![Part {
175            data: None,
176            file: None,
177            metadata: None,
178            text: Some(text.to_string()),
179        }],
180        reference_task_ids: vec![],
181        role: Role::RoleAgent,
182        task_id: Some(task.id.clone()),
183    }
184}
185
186fn message_content_to_string(content: &MessageContent) -> String {
187    match content {
188        MessageContent::String(s) => s.clone(),
189        MessageContent::Array(parts) => serde_json::to_string(parts).unwrap_or_default(),
190    }
191}
192
193/// Translate the task history into the SDK message format expected by the
194/// agent's [`LLMClient`]. Optionally prepends the agent's system prompt.
195///
196/// [`LLMClient`]: super::agent_llm_client::LLMClient
197fn build_sdk_messages(agent: &Agent, task: &Task) -> Vec<Message> {
198    let mut messages: Vec<Message> = Vec::new();
199    if let Some(prompt) = agent.system_prompt.clone() {
200        messages.push(Message {
201            role: MessageRole::System,
202            content: MessageContent::String(prompt),
203            reasoning: None,
204            reasoning_content: None,
205            tool_call_id: None,
206            tool_calls: Vec::new(),
207        });
208    }
209    for a2a_msg in &task.history {
210        let text = a2a_msg
211            .parts
212            .iter()
213            .filter_map(|p| p.text.clone())
214            .collect::<Vec<_>>()
215            .join("");
216        if text.is_empty() {
217            continue;
218        }
219        let role = match a2a_msg.role {
220            Role::RoleAgent => MessageRole::Assistant,
221            _ => MessageRole::User,
222        };
223        messages.push(Message {
224            role,
225            content: MessageContent::String(text),
226            reasoning: None,
227            reasoning_content: None,
228            tool_call_id: None,
229            tool_calls: Vec::new(),
230        });
231    }
232    messages
233}
234
/// Static message returned by the default handlers when no agent is
/// configured. Tells the caller how to wire up an agent so AI replies work.
const NO_AGENT_REPLY: &str = "I received your message. I'm a default task handler without AI capabilities. \
     To enable AI responses, configure an OpenAI-compatible agent via \
     `A2AServerBuilder::with_agent(...)`.";
240
/// Outcome of [`run_tool_loop`]. Carries the conversation buffer (with all
/// assistant tool-call messages + tool result messages appended in order)
/// plus the final assistant text the model returned once it stopped
/// invoking tools, and a flag indicating whether the loop hit the iteration
/// cap.
struct ToolLoopOutcome {
    // Full conversation, including appended assistant/tool messages.
    messages: Vec<Message>,
    // Final assistant text; empty when no answer was produced.
    final_text: String,
    // True when the iteration cap was hit before a tool-free reply.
    exhausted: bool,
}
251
/// Drive a non-streaming "model call → execute tool_calls → feed results
/// back" loop up to `agent.max_chat_completion()` iterations. The default
/// task handlers use this to bridge the gap between the inference gateway
/// (which only emits raw OpenAI-style tool_calls) and the registered
/// [`ToolHandler`] implementations on the agent.
///
/// Tool activity is silent at the wire level - this helper only debug-logs
/// tool lifecycle events instead of forwarding them as A2A
/// `TaskStatusUpdate` events (the A2A spec has no tool-event variant).
async fn run_tool_loop(agent: &Agent, mut messages: Vec<Message>) -> Result<ToolLoopOutcome> {
    let llm = agent.llm_client();
    let tools = agent.toolbox.clone();
    // Clamp to at least one iteration so a zero (or negative) configured cap
    // still results in a single model call.
    let max_iterations = agent.max_chat_completion().max(1) as usize;

    for _ in 0..max_iterations {
        let response = llm
            .create_chat_completion(messages.clone(), tools.clone())
            .await
            .map_err(|e| anyhow!("llm call failed: {e}"))?;

        // An empty choices array terminates the loop with no final text but
        // is NOT treated as exhaustion - the model simply returned nothing.
        let Some(choice) = response.choices.into_iter().next() else {
            return Ok(ToolLoopOutcome {
                messages,
                final_text: String::new(),
                exhausted: false,
            });
        };

        let assistant_text = message_content_to_string(&choice.message.content);
        let tool_calls = choice.message.tool_calls.clone();
        let reasoning = choice.message.reasoning.clone();
        let reasoning_content = choice.message.reasoning_content.clone();

        // Record the assistant turn (including any tool_calls) so follow-up
        // iterations see the full conversation.
        messages.push(Message {
            role: MessageRole::Assistant,
            content: MessageContent::String(assistant_text.clone()),
            reasoning,
            reasoning_content,
            tool_call_id: None,
            tool_calls: tool_calls.clone(),
        });

        // No tool requests → the model has produced its final answer.
        if tool_calls.is_empty() {
            return Ok(ToolLoopOutcome {
                messages,
                final_text: assistant_text,
                exhausted: false,
            });
        }

        for tool_call in tool_calls {
            let tool_name = tool_call.function.name.clone();
            // If the arguments are not valid JSON, fall back to passing the
            // raw argument string so the handler still sees something.
            let args: Value = serde_json::from_str(&tool_call.function.arguments)
                .unwrap_or_else(|_| Value::String(tool_call.function.arguments.clone()));

            debug!("tool dispatch: {tool_name}");

            // Handler failures and unknown tools are surfaced to the model
            // as tool-result text rather than aborting the loop.
            let tool_result = match agent.tool_handler(&tool_name) {
                Some(handler) => match handler.handle(args).await {
                    Ok(value) => value,
                    Err(e) => format!("tool `{tool_name}` failed: {e}"),
                },
                None => format!("no handler registered for tool `{tool_name}`"),
            };

            messages.push(Message {
                role: MessageRole::Tool,
                content: MessageContent::String(tool_result),
                reasoning: None,
                reasoning_content: None,
                tool_call_id: Some(tool_call.id.clone()),
                tool_calls: Vec::new(),
            });
        }
    }

    // Iteration cap reached while the model was still requesting tools.
    Ok(ToolLoopOutcome {
        messages,
        final_text: String::new(),
        exhausted: true,
    })
}
335
/// Opt-in default `message/send` handler wired up by
/// [`A2AServerBuilder::with_default_background_task_handler`] /
/// [`A2AServerBuilder::with_default_task_handlers`].
///
/// When an [`Agent`] is configured, delegates to the inference gateway via a
/// single non-streaming `generate_content` call and returns the resulting
/// task with `state == Completed` and the reply attached. Without an agent,
/// returns the static [`NO_AGENT_REPLY`] message - `processWithoutAgentBackground`.
#[derive(Debug)]
pub struct DefaultBackgroundTaskHandler {
    // Optional configured agent; `None` yields the static no-agent reply.
    agent: Option<Arc<Agent>>,
}
348
349impl DefaultBackgroundTaskHandler {
350    pub fn new(agent: Option<Arc<Agent>>) -> Self {
351        Self { agent }
352    }
353}
354
355#[async_trait::async_trait]
356impl TaskHandler for DefaultBackgroundTaskHandler {
357    async fn handle_task(&self, mut task: Task, _message: Option<A2AMessage>) -> Result<Task> {
358        let (reply_text, terminal_state) = match self.agent.as_ref() {
359            Some(agent) => {
360                let messages = build_sdk_messages(agent, &task);
361                match run_tool_loop(agent, messages).await {
362                    Ok(outcome) if outcome.exhausted => {
363                        warn!(
364                            "default background handler: tool loop exhausted \
365                             after {} iterations without a final answer",
366                            agent.max_chat_completion()
367                        );
368                        (
369                            "Tool loop exhausted before the model produced a \
370                             final answer."
371                                .to_string(),
372                            TaskState::TaskStateFailed,
373                        )
374                    }
375                    Ok(outcome) => {
376                        let text = if outcome.final_text.is_empty() {
377                            "Task completed".to_string()
378                        } else {
379                            outcome.final_text
380                        };
381                        (text, TaskState::TaskStateCompleted)
382                    }
383                    Err(e) => {
384                        warn!("default background handler: agent call failed: {e}");
385                        (
386                            format!("Agent call failed: {e}"),
387                            TaskState::TaskStateFailed,
388                        )
389                    }
390                }
391            }
392            None => (NO_AGENT_REPLY.to_string(), TaskState::TaskStateCompleted),
393        };
394
395        let reply = build_agent_text_message(&task, &reply_text);
396        task.history.push(reply.clone());
397        task.status = TaskStatus {
398            message: Some(reply),
399            state: terminal_state,
400            timestamp: Some(Timestamp(chrono::Utc::now())),
401        };
402        Ok(task)
403    }
404}
405
/// Opt-in default `message/stream` handler wired up by
/// [`A2AServerBuilder::with_default_streaming_task_handler`] /
/// [`A2AServerBuilder::with_default_task_handlers`].
///
/// When an [`Agent`] is configured, the handler iterates `generate_content_stream`
/// from the inference gateway, parses each OpenAI-style delta, and emits a
/// [`TaskArtifactUpdateEvent`] per non-empty content chunk (`append: true`,
/// shared `artifact_id`) - clients see the reply build up in real time. The
/// stream terminates with a final `last_chunk: true` artifact + a
/// `Completed` status update.
///
/// Without an agent, emits a single instructional artifact + `Completed`
/// so the bundled defaults remain usable for examples and tests.
#[derive(Debug)]
pub struct DefaultStreamingTaskHandler {
    // Optional configured agent; `None` yields the static no-agent artifact.
    agent: Option<Arc<Agent>>,
}
423
424impl DefaultStreamingTaskHandler {
425    pub fn new(agent: Option<Arc<Agent>>) -> Self {
426        Self { agent }
427    }
428}
429
430#[async_trait::async_trait]
431impl StreamableTaskHandler for DefaultStreamingTaskHandler {
432    async fn handle_streaming_task(
433        &self,
434        task: Task,
435        _message: Option<A2AMessage>,
436        emitter: StreamEmitter,
437    ) -> Result<()> {
438        emitter
439            .emit_status(
440                &task.id,
441                &task.context_id,
442                TaskState::TaskStateWorking,
443                None,
444                false,
445            )
446            .await?;
447
448        let final_text = match self.agent.as_ref() {
449            Some(agent) => stream_agent_deltas(agent, &task, &emitter).await?,
450            None => {
451                emitter
452                    .emit_text_artifact(&task.id, &task.context_id, NO_AGENT_REPLY, true)
453                    .await?;
454                NO_AGENT_REPLY.to_string()
455            }
456        };
457
458        let reply_message = build_agent_text_message(&task, &final_text);
459        emitter
460            .emit_status(
461                &task.id,
462                &task.context_id,
463                TaskState::TaskStateCompleted,
464                Some(reply_message),
465                true,
466            )
467            .await
468    }
469}
470
/// Drive `generate_content_stream` and forward each delta chunk to
/// `emitter` as an incremental [`TaskArtifactUpdateEvent`] sharing a single
/// `artifact_id`. Returns the accumulated reply text on success. On gateway
/// failure, the helper falls back to a one-shot error artifact so the
/// stream still terminates cleanly.
///
/// When the agent advertises tools, this helper first runs a non-streaming
/// [`run_tool_loop`] preflight so any `tool_calls` the model emits get
/// dispatched to registered [`ToolHandler`] implementations. Once the model
/// stops requesting tools (or the iteration cap is hit), the final answer
/// is fetched via `generate_content_stream` and delivered as deltas.
async fn stream_agent_deltas(
    agent: &Agent,
    task: &Task,
    emitter: &StreamEmitter,
) -> Result<String> {
    let base_messages = build_sdk_messages(agent, task);

    // Tool preflight: resolve tool_calls non-streamingly before the final
    // streamed answer. Every early-exit path emits a terminal artifact
    // (`last_chunk: true`) so the stream still closes cleanly.
    let messages = if agent.toolbox().is_some() {
        match run_tool_loop(agent, base_messages).await {
            Ok(outcome) if outcome.exhausted => {
                let msg = "Tool loop exhausted before the model produced a \
                           final answer."
                    .to_string();
                emitter
                    .emit_text_artifact(&task.id, &task.context_id, &msg, true)
                    .await?;
                return Ok(msg);
            }
            Ok(outcome) => {
                // If the preflight already produced a final tool-free answer,
                // deliver it as a single artifact instead of re-querying.
                if !outcome.final_text.is_empty()
                    && outcome
                        .messages
                        .last()
                        .map(|m| m.tool_calls.is_empty())
                        .unwrap_or(true)
                {
                    emitter
                        .emit_text_artifact(&task.id, &task.context_id, &outcome.final_text, true)
                        .await?;
                    return Ok(outcome.final_text);
                }
                outcome.messages
            }
            Err(e) => {
                warn!("default streaming handler: tool loop failed: {e}");
                let msg = format!("Agent stream failed: {e}");
                emitter
                    .emit_text_artifact(&task.id, &task.context_id, &msg, true)
                    .await?;
                return Ok(msg);
            }
        }
    } else {
        base_messages
    };

    let llm = agent.llm_client();
    let tools = agent.toolbox.clone();
    let mut stream = llm.create_streaming_chat_completion(messages, tools);

    // Single artifact id shared by every delta chunk and the terminal chunk.
    let artifact_id = uuid::Uuid::new_v4().to_string();
    let mut buffer = String::new();

    while let Some(item) = stream.next().await {
        let event = match item {
            Ok(e) => e,
            Err(e) => {
                // Gateway error mid-stream: fall back to a one-shot error
                // artifact so the client still sees a terminated stream.
                warn!("default streaming handler: gateway error: {e}");
                let msg = format!("Agent stream failed: {e}");
                emitter
                    .emit_text_artifact(&task.id, &task.context_id, &msg, true)
                    .await?;
                return Ok(msg);
            }
        };

        // OpenAI-style SSE framing: skip blank keep-alives, stop on [DONE].
        let data = event.data.trim();
        if data.is_empty() || data == "[DONE]" {
            if data == "[DONE]" {
                break;
            }
            continue;
        }

        // Non-JSON payloads and deltas without text content are ignored.
        let parsed: serde_json::Value = match serde_json::from_str(data) {
            Ok(v) => v,
            Err(_) => continue,
        };
        let Some(text) = parsed
            .get("choices")
            .and_then(|c| c.as_array())
            .and_then(|arr| arr.first())
            .and_then(|c| c.get("delta"))
            .and_then(|d| d.get("content"))
            .and_then(|t| t.as_str())
        else {
            continue;
        };
        if text.is_empty() {
            continue;
        }
        buffer.push_str(text);

        // Forward this delta as an append-mode artifact chunk.
        let chunk_event = TaskArtifactUpdateEvent {
            append: Some(true),
            artifact: Artifact {
                artifact_id: artifact_id.clone(),
                description: None,
                extensions: vec![],
                metadata: None,
                name: None,
                parts: vec![Part {
                    data: None,
                    file: None,
                    metadata: None,
                    text: Some(text.to_string()),
                }],
            },
            context_id: task.context_id.clone(),
            last_chunk: Some(false),
            metadata: None,
            task_id: task.id.clone(),
        };
        emitter
            .emit(StreamResponse {
                artifact_update: Some(chunk_event),
                message: None,
                status_update: None,
                task: None,
            })
            .await?;
    }

    // Terminal empty-parts chunk with `last_chunk: true` closes the artifact.
    let final_event = TaskArtifactUpdateEvent {
        append: Some(true),
        artifact: Artifact {
            artifact_id,
            description: None,
            extensions: vec![],
            metadata: None,
            name: None,
            parts: vec![],
        },
        context_id: task.context_id.clone(),
        last_chunk: Some(true),
        metadata: None,
        task_id: task.id.clone(),
    };
    emitter
        .emit(StreamResponse {
            artifact_update: Some(final_event),
            message: None,
            status_update: None,
            task: None,
        })
        .await?;

    Ok(buffer)
}
631
632#[cfg(test)]
633mod tests {
634    use super::*;
635    use crate::a2a_types::{AgentCard, Role, SendMessageRequest};
636    use crate::server::agent_builder::AgentBuilder;
637    use crate::server::protocol::{AppState, a2a_handler};
638    use crate::server::server_builder::A2AServerBuilder;
639    use axum::Router;
640    use axum::extract::State;
641    use axum::response::Json;
642    use axum::routing::post;
643    use inference_gateway_sdk::{
644        ChatCompletionTool, ChatCompletionToolType, FunctionObject, FunctionParameters,
645    };
646    use tokio::net::TcpListener;
647
648    fn agent_card_with_streaming(streaming: bool) -> AgentCard {
649        serde_json::from_value(serde_json::json!({
650            "name": "Validation Agent",
651            "description": "Builder validation tests",
652            "version": "0.0.0",
653            "protocolVersion": "0.2.6",
654            "url": "http://localhost/a2a",
655            "preferredTransport": "JSONRPC",
656            "capabilities": {
657                "streaming": streaming,
658                "pushNotifications": false,
659                "stateTransitionHistory": false
660            },
661            "defaultInputModes": ["text/plain"],
662            "defaultOutputModes": ["text/plain"],
663            "skills": [
664                {"id": "x", "name": "x", "description": "x", "tags": ["x"]}
665            ]
666        }))
667        .unwrap()
668    }
669
    /// Drive the `DefaultStreamingTaskHandler` against a mock OpenAI-compatible
    /// gateway and verify the handler iterates the delta stream, emitting an
    /// incremental artifact event per non-empty content chunk (all sharing a
    /// single artifact_id with `append: true`), terminating with `last_chunk:
    /// true` and a `Completed` status whose message carries the accumulated
    /// reply.
    #[tokio::test]
    async fn default_streaming_handler_iterates_gateway_deltas() {
        use crate::A2AClient;
        use crate::a2a_types::Message as A2AMessage;
        use crate::config::AgentConfig;
        use axum::response::sse::{Event as SseEvent, KeepAlive as SseKeepAlive, Sse as SseResp};
        use futures_util::StreamExt as _;

        // ----- Mock OpenAI-compatible gateway --------------------------------
        // Serves three content deltas ("Hel", "lo ", "world") then [DONE].
        async fn chat_completions() -> SseResp<
            impl futures_util::Stream<Item = std::result::Result<SseEvent, std::convert::Infallible>>,
        > {
            let deltas = [
                serde_json::json!({"choices":[{"delta":{"content":"Hel"}}]}).to_string(),
                serde_json::json!({"choices":[{"delta":{"content":"lo "}}]}).to_string(),
                serde_json::json!({"choices":[{"delta":{"content":"world"}}]}).to_string(),
                "[DONE]".to_string(),
            ];
            let stream = futures_util::stream::iter(
                deltas
                    .into_iter()
                    .map(|d| Ok::<_, std::convert::Infallible>(SseEvent::default().data(d))),
            );
            SseResp::new(stream).keep_alive(SseKeepAlive::default())
        }

        let gateway_listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind gateway");
        let gateway_addr = gateway_listener.local_addr().expect("addr");
        let gateway_app = Router::new().route("/chat/completions", post(chat_completions));
        tokio::spawn(async move {
            axum::serve(gateway_listener, gateway_app).await.ok();
        });

        // ----- A2A server using DefaultStreamingTaskHandler ------------------
        let agent_card = agent_card_with_streaming(true);
        let agent_config = AgentConfig {
            provider: "openai".to_string(),
            model: "test-model".to_string(),
            base_url: Some(format!("http://{gateway_addr}")),
            ..AgentConfig::default()
        };
        let agent = AgentBuilder::new()
            .with_config(&agent_config)
            .build()
            .await
            .expect("agent builds");

        let server = A2AServerBuilder::new()
            .with_agent_card(agent_card)
            .with_agent(agent)
            .with_default_task_handlers()
            .build()
            .await
            .expect("server builds");

        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind a2a");
        let addr = listener.local_addr().expect("addr");
        let app = Router::new()
            .route("/a2a", post(a2a_handler))
            .with_state(Arc::new(AppState::new(server)));
        tokio::spawn(async move {
            axum::serve(listener, app).await.ok();
        });

        let client = A2AClient::new(format!("http://{addr}")).expect("client");

        let request = SendMessageRequest {
            configuration: None,
            message: Some(A2AMessage {
                context_id: None,
                extensions: vec![],
                message_id: "msg-default-stream".to_string(),
                metadata: None,
                parts: vec![Part {
                    data: None,
                    file: None,
                    metadata: None,
                    text: Some("hi".to_string()),
                }],
                reference_task_ids: vec![],
                role: Role::RoleUser,
                task_id: None,
            }),
            metadata: None,
            tenant: "tests".to_string(),
        };

        // Collect every streamed event until the server closes the stream.
        let mut stream = Box::pin(client.stream_message(request).await.expect("stream"));
        let mut events: Vec<StreamResponse> = Vec::new();
        while let Some(item) = stream.next().await {
            events.push(item.expect("event"));
        }

        // Expected shape: Task, Working, 3 deltas, terminal chunk, Completed.
        assert_eq!(
            events.len(),
            7,
            "unexpected event count {}: {:?}",
            events.len(),
            events
        );

        assert!(events[0].task.is_some(), "first event carries task");
        let working = events[1]
            .status_update
            .as_ref()
            .expect("second event is status update");
        assert_eq!(working.status.state, TaskState::TaskStateWorking);
        assert!(!working.final_);

        // Events 2..=4 are the incremental content deltas.
        let mut artifact_ids = std::collections::HashSet::new();
        let chunks: Vec<String> = (2..=4)
            .map(|i| {
                let upd = events[i]
                    .artifact_update
                    .as_ref()
                    .unwrap_or_else(|| panic!("event[{i}] should be an artifact update"));
                assert_eq!(upd.append, Some(true), "deltas must have append=true");
                assert_eq!(upd.last_chunk, Some(false));
                artifact_ids.insert(upd.artifact.artifact_id.clone());
                upd.artifact
                    .parts
                    .iter()
                    .filter_map(|p| p.text.clone())
                    .collect::<String>()
            })
            .collect();
        assert_eq!(chunks, vec!["Hel", "lo ", "world"]);
        assert_eq!(
            artifact_ids.len(),
            1,
            "all deltas must share a single artifact_id"
        );

        let terminal_artifact = events[5]
            .artifact_update
            .as_ref()
            .expect("event[5] should be the terminal artifact chunk");
        assert_eq!(terminal_artifact.last_chunk, Some(true));
        assert!(
            terminal_artifact.artifact.parts.is_empty(),
            "terminal chunk should have empty parts"
        );
        assert_eq!(
            artifact_ids.iter().next().unwrap(),
            &terminal_artifact.artifact.artifact_id,
            "terminal chunk must share artifact_id with deltas"
        );

        // The Completed status message must carry the assembled reply text.
        let completed = events[6]
            .status_update
            .as_ref()
            .expect("event[6] should be the Completed status");
        assert_eq!(completed.status.state, TaskState::TaskStateCompleted);
        assert!(completed.final_);
        let assembled = completed
            .status
            .message
            .as_ref()
            .expect("completed status carries the final message")
            .parts
            .iter()
            .filter_map(|p| p.text.clone())
            .collect::<String>();
        assert_eq!(assembled, "Hello world");
    }
843
844    // ----- tool-dispatch coverage -------------------------------------------
845
    /// Shared observation state for the mock gateway endpoints, letting the
    /// tool-dispatch tests assert on what the "model" was asked and told.
    #[derive(Clone, Default)]
    struct ToolMockState {
        // Count of non-streaming /chat/completions calls served so far; the
        // handler uses it to reply with a tool-call first, a final answer after.
        non_streaming_calls: std::sync::Arc<std::sync::atomic::AtomicUsize>,
        // Contents of every Tool-role message seen in incoming request bodies,
        // in arrival order (i.e. the tool results echoed back to the gateway).
        captured_tool_results: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
    }
851
852    fn tool_call_response_json() -> serde_json::Value {
853        serde_json::json!({
854            "id": "chatcmpl-tool",
855            "object": "chat.completion",
856            "created": 0,
857            "model": "test-model",
858            "choices": [{
859                "index": 0,
860                "finish_reason": "tool_calls",
861                "message": {
862                    "role": "assistant",
863                    "content": "",
864                    "tool_calls": [{
865                        "id": "call_1",
866                        "type": "function",
867                        "function": {
868                            "name": "echo_arg",
869                            "arguments": "{\"text\":\"hi\"}",
870                        }
871                    }],
872                },
873            }],
874        })
875    }
876
877    fn final_answer_response_json(text: &str) -> serde_json::Value {
878        serde_json::json!({
879            "id": "chatcmpl-final",
880            "object": "chat.completion",
881            "created": 0,
882            "model": "test-model",
883            "choices": [{
884                "index": 0,
885                "finish_reason": "stop",
886                "message": {
887                    "role": "assistant",
888                    "content": text,
889                    "tool_calls": [],
890                },
891            }],
892        })
893    }
894
895    async fn mock_non_streaming(
896        State(state): State<std::sync::Arc<ToolMockState>>,
897        body: Value,
898    ) -> Json<Value> {
899        if let Some(msgs) = body.get("messages").and_then(|v| v.as_array()) {
900            for m in msgs {
901                if m.get("role").and_then(|v| v.as_str()) == Some("tool")
902                    && let Some(text) = m.get("content").and_then(|v| v.as_str())
903                {
904                    state
905                        .captured_tool_results
906                        .lock()
907                        .expect("mutex poisoned")
908                        .push(text.to_string());
909                }
910            }
911        }
912        let call_index = state
913            .non_streaming_calls
914            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
915        if call_index == 0 {
916            Json(tool_call_response_json())
917        } else {
918            Json(final_answer_response_json("12 is the tool result"))
919        }
920    }
921
922    /// Single dispatcher: the tool loop is non-streaming end-to-end (since
923    /// the inference gateway's OpenAI surface only exposes
924    /// `CreateChatCompletion`), so this mock only needs to serve the two
925    /// non-streaming responses in order.
926    async fn mock_chat_completions(
927        State(state): State<std::sync::Arc<ToolMockState>>,
928        body: axum::body::Bytes,
929    ) -> Json<Value> {
930        let parsed: Value = serde_json::from_slice(&body).expect("valid JSON");
931        mock_non_streaming(State(state), parsed).await
932    }
933
934    async fn build_echo_agent_with_recorder(
935        gateway_url: String,
936    ) -> (Agent, std::sync::Arc<std::sync::Mutex<Vec<String>>>) {
937        use crate::config::AgentConfig;
938
939        let recorded = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
940        let recorded_clone = std::sync::Arc::clone(&recorded);
941
942        let echo_tool = ChatCompletionTool {
943            type_: ChatCompletionToolType::Function,
944            function: FunctionObject {
945                name: "echo_arg".to_string(),
946                description: Some("echo back the text arg".to_string()),
947                parameters: Some(FunctionParameters(
948                    serde_json::json!({
949                        "type": "object",
950                        "properties": {"text": {"type": "string"}},
951                        "required": ["text"],
952                    })
953                    .as_object()
954                    .unwrap()
955                    .clone(),
956                )),
957                strict: false,
958            },
959        };
960
961        let agent_cfg = AgentConfig {
962            provider: "openai".to_string(),
963            model: "test-model".to_string(),
964            base_url: Some(gateway_url),
965            ..AgentConfig::default()
966        };
967
968        let agent = AgentBuilder::new()
969            .with_config(&agent_cfg)
970            .with_toolbox(vec![echo_tool])
971            .with_async_function_tool("echo_arg".to_string(), move |args: Value| {
972                let recorded = std::sync::Arc::clone(&recorded_clone);
973                async move {
974                    let text = args
975                        .get("text")
976                        .and_then(|v| v.as_str())
977                        .unwrap_or("")
978                        .to_string();
979                    recorded.lock().expect("mutex poisoned").push(text.clone());
980                    Ok(format!("echoed: {text}"))
981                }
982            })
983            .build()
984            .await
985            .expect("agent builds");
986        (agent, recorded)
987    }
988
    /// End-to-end check that the default *background* task handler drives a
    /// full tool loop: model asks for `echo_arg` → tool runs once → tool
    /// result is fed back to the gateway → final answer lands on the task.
    #[tokio::test]
    async fn default_background_handler_dispatches_tool_calls() {
        use crate::A2AClient;
        use crate::a2a_types::Message as A2AMessage;

        // Stand up the mock inference gateway on an ephemeral port.
        let mock_state = std::sync::Arc::new(ToolMockState::default());
        let gateway_listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let gateway_addr = gateway_listener.local_addr().expect("addr");
        let gateway_app = Router::new()
            .route("/chat/completions", post(mock_chat_completions))
            .with_state(std::sync::Arc::clone(&mock_state));
        tokio::spawn(async move {
            axum::serve(gateway_listener, gateway_app).await.ok();
        });

        // Agent wired to the mock gateway; card advertises no streaming so the
        // background (queue-driven) path is exercised.
        let (agent, recorded) =
            build_echo_agent_with_recorder(format!("http://{gateway_addr}")).await;
        let card = agent_card_with_streaming(false);

        let mut server = A2AServerBuilder::new()
            .with_agent_card(card)
            .with_agent(agent)
            .with_default_background_task_handler()
            .build()
            .await
            .expect("server builds");

        // The background handler needs the task-manager worker running; take
        // it out of the server and start it before serving requests.
        let runner = server
            .task_manager
            .take()
            .expect("task manager configured for background handler")
            .start();

        // Stand up the A2A server itself on a second ephemeral port.
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind a2a");
        let addr = listener.local_addr().expect("a2a addr");
        let app = Router::new()
            .route("/a2a", post(a2a_handler))
            .with_state(Arc::new(AppState::new(server)));
        tokio::spawn(async move {
            axum::serve(listener, app).await.ok();
        });

        // message/send returns immediately with the Submitted task wrapper.
        let client = A2AClient::new(format!("http://{addr}")).expect("client");
        let response = client
            .send_message(SendMessageRequest {
                configuration: None,
                message: Some(A2AMessage {
                    context_id: None,
                    extensions: vec![],
                    message_id: "msg-bg-tool".to_string(),
                    metadata: None,
                    parts: vec![Part {
                        data: None,
                        file: None,
                        metadata: None,
                        text: Some("ask".to_string()),
                    }],
                    reference_task_ids: vec![],
                    role: Role::RoleUser,
                    task_id: None,
                }),
                metadata: None,
                tenant: "tests".to_string(),
            })
            .await
            .expect("message/send");

        let submitted = response.task.expect("task in response");
        assert_eq!(submitted.status.state, TaskState::TaskStateSubmitted);

        // Wait for the background worker to complete the task, then verify the
        // final agent message is the mock's canned second response.
        let final_task = poll_until_terminal(&client, &submitted.id).await;
        assert_eq!(final_task.status.state, TaskState::TaskStateCompleted);
        let final_text = final_task
            .status
            .message
            .expect("final agent message")
            .parts
            .iter()
            .filter_map(|p| p.text.clone())
            .collect::<String>();
        assert_eq!(final_text, "12 is the tool result");

        // The tool ran exactly once with the model-supplied argument …
        assert_eq!(
            recorded.lock().expect("mutex poisoned").clone(),
            vec!["hi".to_string()],
            "echo_arg should fire exactly once with the model-supplied argument",
        );
        // … and its result was sent back to the gateway as a Tool-role message.
        assert_eq!(
            mock_state
                .captured_tool_results
                .lock()
                .expect("mutex poisoned")
                .clone(),
            vec!["echoed: hi".to_string()],
            "second gateway call should include the tool result as a Tool-role message",
        );

        runner.shutdown().await;
    }
1088
1089    /// Poll `tasks/get` until the task reaches a terminal state, with a
1090    /// per-test timeout. Used by the queue-driven `message/send` tests
1091    /// that need to wait for the background worker to complete.
1092    async fn poll_until_terminal(client: &crate::A2AClient, task_id: &str) -> Task {
1093        for _ in 0..100 {
1094            let fetched = client
1095                .get_task(crate::a2a_types::GetTaskRequest {
1096                    history_length: None,
1097                    name: format!("tasks/{task_id}"),
1098                    tenant: Some("tests".to_string()),
1099                })
1100                .await
1101                .expect("tasks/get");
1102            if matches!(
1103                fetched.status.state,
1104                TaskState::TaskStateCompleted
1105                    | TaskState::TaskStateFailed
1106                    | TaskState::TaskStateCancelled
1107                    | TaskState::TaskStateRejected
1108            ) {
1109                return fetched;
1110            }
1111            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1112        }
1113        panic!("task {task_id} never reached terminal state within 2s");
1114    }
1115
    /// End-to-end check that the default *streaming* task handler runs the
    /// tool loop as a non-streamed preflight: tool-lifecycle chatter must NOT
    /// leak into the SSE stream, which carries only artifact deltas of the
    /// final answer plus the terminal Completed status.
    #[tokio::test]
    async fn default_streaming_handler_dispatches_tool_calls() {
        use crate::A2AClient;
        use crate::a2a_types::Message as A2AMessage;
        use futures_util::StreamExt;

        // Stand up the mock inference gateway on an ephemeral port.
        let mock_state = std::sync::Arc::new(ToolMockState::default());
        let gateway_listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let gateway_addr = gateway_listener.local_addr().expect("addr");
        let gateway_app = Router::new()
            .route("/chat/completions", post(mock_chat_completions))
            .with_state(std::sync::Arc::clone(&mock_state));
        tokio::spawn(async move {
            axum::serve(gateway_listener, gateway_app).await.ok();
        });

        // Agent wired to the mock gateway; card advertises streaming this time.
        let (agent, recorded) =
            build_echo_agent_with_recorder(format!("http://{gateway_addr}")).await;
        let card = agent_card_with_streaming(true);

        let server = A2AServerBuilder::new()
            .with_agent_card(card)
            .with_agent(agent)
            .with_default_streaming_task_handler()
            .build()
            .await
            .expect("server builds");

        // Stand up the A2A server on a second ephemeral port.
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind a2a");
        let addr = listener.local_addr().expect("a2a addr");
        let app = Router::new()
            .route("/a2a", post(a2a_handler))
            .with_state(Arc::new(AppState::new(server)));
        tokio::spawn(async move {
            axum::serve(listener, app).await.ok();
        });

        let client = A2AClient::new(format!("http://{addr}")).expect("client");
        let request = SendMessageRequest {
            configuration: None,
            message: Some(A2AMessage {
                context_id: None,
                extensions: vec![],
                message_id: "msg-stream-tool".to_string(),
                metadata: None,
                parts: vec![Part {
                    data: None,
                    file: None,
                    metadata: None,
                    text: Some("ask".to_string()),
                }],
                reference_task_ids: vec![],
                role: Role::RoleUser,
                task_id: None,
            }),
            metadata: None,
            tenant: "tests".to_string(),
        };

        // Drain the whole stream; it must terminate on its own (final: true).
        let mut stream = Box::pin(client.stream_message(request).await.expect("stream"));
        let mut events: Vec<StreamResponse> = Vec::new();
        while let Some(item) = stream.next().await {
            events.push(item.expect("event"));
        }

        // The tool fired once during the non-streamed preflight loop.
        assert_eq!(
            recorded.lock().expect("mutex poisoned").clone(),
            vec!["hi".to_string()],
            "echo_arg should fire once during the tool-loop preflight"
        );

        // No status update in the stream may mention tool-call lifecycle text.
        let saw_tool_status = events.iter().any(|e| {
            e.status_update
                .as_ref()
                .and_then(|u| u.status.message.as_ref())
                .map(|m| {
                    m.parts
                        .iter()
                        .filter_map(|p| p.text.clone())
                        .any(|t| t.contains("calling tool"))
                })
                .unwrap_or(false)
        });
        assert!(
            !saw_tool_status,
            "stream should NOT carry tool-lifecycle status updates",
        );

        // Concatenating every artifact delta reconstructs the final answer.
        let accumulated: String = events
            .iter()
            .filter_map(|e| e.artifact_update.as_ref())
            .flat_map(|a| {
                a.artifact
                    .parts
                    .iter()
                    .filter_map(|p| p.text.clone())
                    .collect::<Vec<_>>()
            })
            .collect::<String>();
        assert_eq!(accumulated, "12 is the tool result");

        // The stream must close with a final Completed status update.
        let last = events.last().expect("at least one event");
        let last_status = last
            .status_update
            .as_ref()
            .expect("last event is a status update");
        assert_eq!(last_status.status.state, TaskState::TaskStateCompleted);
        assert!(last_status.final_);
    }
1225}