a2a_client/
client.rs

1//! A2A Client for calling remote A2A agents
2//!
3//! This module provides a client for making A2A protocol calls to remote agents.
4//! It supports both streaming and non-streaming interactions.
5
6use self::sse_parser::SseParser;
7use crate::constants::{AGENT_CARD_PATH, JSONRPC_VERSION};
8use crate::error::{A2AError, A2AResult};
9use a2a_types::{
10    AgentCard, DeleteTaskPushNotificationConfigParams, JSONRPCErrorResponse, JSONRPCId,
11    ListTaskPushNotificationConfigParams, MessageSendParams, SendMessageResponse,
12    SendStreamingMessageResult, Task, TaskIdParams, TaskPushNotificationConfig, TaskQueryParams,
13};
14use futures_core::Stream;
15use reqwest::Client;
16use serde::{Deserialize, Serialize};
17use std::pin::Pin;
18use std::sync::Arc;
19use std::sync::atomic::{AtomicU64, Ordering};
20
21#[cfg(not(target_arch = "wasm32"))]
22type SseStream = Pin<Box<dyn Stream<Item = A2AResult<SendStreamingMessageResult>> + Send>>;
23#[cfg(target_arch = "wasm32")]
24type SseStream = Pin<Box<dyn Stream<Item = A2AResult<SendStreamingMessageResult>>>>;
25
26/// A2A client for communicating with remote agents
27#[derive(Clone)]
28pub struct A2AClient {
29    /// HTTP client for making requests
30    client: Client,
31    /// Service endpoint URL from agent card
32    service_endpoint_url: String,
33    /// Optional authentication token
34    auth_token: Option<String>,
35    /// Request ID counter for JSON-RPC requests
36    request_id_counter: Arc<AtomicU64>,
37    /// Cached agent card
38    agent_card: Arc<AgentCard>,
39}
40
/// JSON-RPC 2.0 request envelope, serialized as the body of each POST to the agent.
///
/// serde serializes struct fields in declaration order, so the JSON keys appear as
/// `jsonrpc`, `id`, `method`, `params`.
#[derive(Debug, Serialize)]
struct JsonRpcRequest<T> {
    /// Protocol version; every constructor in this file fills it from `JSONRPC_VERSION`.
    jsonrpc: String,
    /// Request identifier; `post_rpc_request` compares it with the response's `id`.
    id: JSONRPCId,
    /// Method name, e.g. `message/send` or `tasks/get`.
    method: String,
    /// Method-specific parameters, emitted as the `params` member.
    params: T,
}
49
/// JSON-RPC 2.0 response envelope: either a success carrying `result` or an error object.
///
/// With `#[serde(untagged)]` serde tries the variants in declaration order. A body is
/// decoded as `Success` when it has a `result` member that deserializes as `T`, and
/// otherwise as `Error`; keep `Success` first. Unknown members (including `jsonrpc` on
/// success responses) are ignored.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum JsonRpcResponse<T> {
    /// Successful reply; `id` is optional because JSON-RPC allows a `null` id.
    Success { id: Option<JSONRPCId>, result: T },
    /// Error reply carrying the JSON-RPC error object.
    Error(JSONRPCErrorResponse),
}
57
58/// Handles parsing of Server-Sent Events (SSE) streams, accommodating both WASM and native targets.
59mod sse_parser {
60    use super::{A2AError, A2AResult, JsonRpcResponse};
61    use a2a_types::SendStreamingMessageResult;
62    use futures_core::Stream;
63    use std::pin::Pin;
64    use std::task::{Context, Poll};
65
66    // Define a trait that abstracts over the `Send` bound, which is required for non-WASM targets.
67    #[cfg(not(target_arch = "wasm32"))]
68    pub trait ByteStreamTrait: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send {}
69    #[cfg(not(target_arch = "wasm32"))]
70    impl<T: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send> ByteStreamTrait for T {}
71
72    #[cfg(target_arch = "wasm32")]
73    pub trait ByteStreamTrait: Stream<Item = Result<bytes::Bytes, reqwest::Error>> {}
74    #[cfg(target_arch = "wasm32")]
75    impl<T: Stream<Item = Result<bytes::Bytes, reqwest::Error>>> ByteStreamTrait for T {}
76
77    // Define a type alias for the pinned byte stream to avoid repetition.
78    #[cfg(not(target_arch = "wasm32"))]
79    type PinnedByteStream =
80        Pin<Box<dyn Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send>>;
81    #[cfg(target_arch = "wasm32")]
82    type PinnedByteStream = Pin<Box<dyn Stream<Item = Result<bytes::Bytes, reqwest::Error>>>>;
83
84    /// A parser for Server-Sent Events (SSE) streams.
85    pub struct SseParser {
86        inner: PinnedByteStream,
87        buffer: String,
88        event_data_buffer: String,
89        pending_results: Vec<A2AResult<SendStreamingMessageResult>>,
90    }
91
92    impl SseParser {
93        /// Creates a new SSE parser from a byte stream.
94        pub fn new(inner: impl ByteStreamTrait + 'static) -> Self {
95            Self {
96                inner: Box::pin(inner),
97                buffer: String::new(),
98                event_data_buffer: String::new(),
99                pending_results: Vec::new(),
100            }
101        }
102
103        /// Processes a chunk of bytes from the stream, parsing full SSE events.
104        fn process_chunk(
105            &mut self,
106            chunk: bytes::Bytes,
107        ) -> Vec<A2AResult<SendStreamingMessageResult>> {
108            self.buffer.push_str(&String::from_utf8_lossy(&chunk));
109            let mut results = Vec::new();
110
111            // Process buffer line by line.
112            while let Some(newline_pos) = self.buffer.find('\n') {
113                let line = self.buffer[..newline_pos]
114                    .trim_end_matches('\r')
115                    .to_string();
116                self.buffer = self.buffer[newline_pos + 1..].to_string();
117
118                if line.is_empty() {
119                    // An empty line signifies the end of an event.
120                    if !self.event_data_buffer.is_empty() {
121                        match process_sse_event(&self.event_data_buffer) {
122                            Ok(result) => results.push(Ok(result)),
123                            Err(e) => results.push(Err(e)),
124                        }
125                        self.event_data_buffer.clear();
126                    }
127                } else if let Some(data) = line.strip_prefix("data:") {
128                    // Accumulate data lines for a single event.
129                    if !self.event_data_buffer.is_empty() {
130                        self.event_data_buffer.push('\n');
131                    }
132                    self.event_data_buffer.push_str(data.trim_start());
133                } else if line.starts_with(':') {
134                    // Ignore comment lines.
135                }
136            }
137            results
138        }
139    }
140
141    impl Stream for SseParser {
142        type Item = A2AResult<SendStreamingMessageResult>;
143
144        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
145            // Drain any pending results from the last chunk processing.
146            if let Some(result) = self.pending_results.pop() {
147                return Poll::Ready(Some(result));
148            }
149
150            // Poll the underlying stream for more data.
151            match self.inner.as_mut().poll_next(cx) {
152                Poll::Ready(Some(Ok(chunk))) => {
153                    let mut results = self.process_chunk(chunk);
154                    if results.is_empty() {
155                        // If no full events were parsed, wait for more data.
156                        cx.waker().wake_by_ref();
157                        Poll::Pending
158                    } else {
159                        // Reverse results to return them in the correct order.
160                        results.reverse();
161                        self.pending_results = results;
162                        Poll::Ready(self.pending_results.pop())
163                    }
164                }
165                Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(A2AError::NetworkError {
166                    message: format!("Stream error: {}", e),
167                }))),
168                Poll::Ready(None) => Poll::Ready(None),
169                Poll::Pending => Poll::Pending,
170            }
171        }
172    }
173
174    /// Processes the data part of a single SSE event.
175    fn process_sse_event(json_data: &str) -> A2AResult<SendStreamingMessageResult> {
176        if json_data.trim().is_empty() {
177            return Err(A2AError::SerializationError {
178                message: "Empty SSE event data".to_string(),
179            });
180        }
181
182        // The data is expected to be a JSON-RPC response.
183        let json_response: JsonRpcResponse<SendStreamingMessageResult> =
184            serde_json::from_str(json_data).map_err(|e| A2AError::SerializationError {
185                message: format!("Failed to parse SSE event data: {}", e),
186            })?;
187
188        match json_response {
189            JsonRpcResponse::Success { result, .. } => Ok(result),
190            JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
191                message: format!("SSE event contained an error: {}", err.error.message),
192                code: Some(err.error.code),
193            }),
194        }
195    }
196
197    #[cfg(test)]
198    mod tests {
199        use super::*;
200        use a2a_types::{
201            JSONRPCError, JSONRPCErrorResponse, JSONRPCId, Message, MessageRole, Part,
202        };
203        use bytes::Bytes;
204        use futures_util::{StreamExt, stream};
205
206        fn sample_message(text: &str) -> Message {
207            Message {
208                kind: "message".to_string(),
209                message_id: format!("msg-{text}"),
210                role: MessageRole::Agent,
211                parts: vec![Part::Text {
212                    text: text.to_string(),
213                    metadata: None,
214                }],
215                context_id: Some("ctx-1".into()),
216                task_id: Some("task-1".into()),
217                reference_task_ids: Vec::new(),
218                extensions: Vec::new(),
219                metadata: None,
220            }
221        }
222
223        #[tokio::test]
224        async fn sse_parser_emits_multiple_events_in_order() {
225            let first = JsonRpcResponse::Success {
226                id: Some(JSONRPCId::Integer(1)),
227                result: SendStreamingMessageResult::Message(sample_message("one")),
228            };
229            let second = JsonRpcResponse::Success {
230                id: Some(JSONRPCId::Integer(2)),
231                result: SendStreamingMessageResult::Message(sample_message("two")),
232            };
233            let payload = format!(
234                "data: {}\n\ndata: {}\n\n",
235                serde_json::to_string(&first).expect("json"),
236                serde_json::to_string(&second).expect("json")
237            );
238            let byte_stream = stream::iter(vec![Ok::<Bytes, reqwest::Error>(Bytes::from(payload))]);
239
240            let mut parser = SseParser::new(byte_stream);
241            let first_item = parser.next().await.expect("first event").expect("ok");
242            let second_item = parser.next().await.expect("second event").expect("ok");
243
244            match first_item {
245                SendStreamingMessageResult::Message(msg) => {
246                    assert!(msg.parts.iter().any(|part| part.as_data().is_none()));
247                }
248                other => panic!("expected message, got {other:?}"),
249            }
250
251            match second_item {
252                SendStreamingMessageResult::Message(msg) => {
253                    assert!(msg.message_id.contains("two"));
254                }
255                other => panic!("expected message, got {other:?}"),
256            }
257        }
258
259        #[test]
260        fn process_sse_event_returns_error_for_remote_failure() {
261            let error =
262                JsonRpcResponse::<SendStreamingMessageResult>::Error(JSONRPCErrorResponse {
263                    jsonrpc: "2.0".into(),
264                    error: JSONRPCError {
265                        code: -1,
266                        message: "boom".into(),
267                        data: None,
268                    },
269                    id: Some(JSONRPCId::Integer(1)),
270                });
271            let json = serde_json::to_string(&error).expect("json");
272            let result = process_sse_event(&json);
273            assert!(matches!(result, Err(A2AError::RemoteAgentError { .. })));
274        }
275    }
276}
277
278impl A2AClient {
279    /// Create a new A2A client from an agent card URL
280    ///
281    /// This will fetch the agent card from the specified URL and use the
282    /// service endpoint URL from the card for all subsequent requests.
283    ///
284    /// Uses a default `reqwest::Client` for HTTP requests. For custom HTTP
285    /// configuration, use `from_card_url_with_client()`.
286    ///
287    /// # Example
288    ///
289    /// ```no_run
290    /// use a2a_client::A2AClient;
291    ///
292    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
293    /// let client = A2AClient::from_card_url("https://agent.example.com").await?;
294    /// # Ok(())
295    /// # }
296    /// ```
297    pub async fn from_card_url(base_url: impl AsRef<str>) -> A2AResult<Self> {
298        Self::from_card_url_with_client(base_url, Client::new()).await
299    }
300
301    /// Create a new A2A client from an agent card URL with a custom HTTP client
302    ///
303    /// This allows you to provide a pre-configured `reqwest::Client` with
304    /// custom settings like timeouts, proxies, TLS config, default headers, etc.
305    ///
306    /// # Example
307    ///
308    /// ```no_run
309    /// # #[cfg(not(target_family = "wasm"))]
310    /// # {
311    /// use a2a_client::A2AClient;
312    /// use reqwest::Client;
313    /// use std::time::Duration;
314    ///
315    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
316    /// let http_client = Client::builder()
317    ///     .timeout(Duration::from_secs(30))
318    ///     .build()?;
319    ///
320    /// let client = A2AClient::from_card_url_with_client(
321    ///     "https://agent.example.com",
322    ///     http_client
323    /// ).await?;
324    /// # Ok(())
325    /// # }
326    /// # }
327    /// ```
328    pub async fn from_card_url_with_client(
329        base_url: impl AsRef<str>,
330        http_client: Client,
331    ) -> A2AResult<Self> {
332        let base_url = base_url.as_ref().trim_end_matches('/');
333        let card_url = format!("{}/{}", base_url, AGENT_CARD_PATH);
334
335        let response = http_client
336            .get(&card_url)
337            .header("Accept", "application/json")
338            .send()
339            .await
340            .map_err(|e| A2AError::NetworkError {
341                message: format!("Failed to fetch agent card from {}: {}", card_url, e),
342            })?;
343
344        if !response.status().is_success() {
345            return Err(A2AError::NetworkError {
346                message: format!("Failed to fetch agent card: HTTP {}", response.status()),
347            });
348        }
349
350        let agent_card: AgentCard =
351            response
352                .json()
353                .await
354                .map_err(|e| A2AError::SerializationError {
355                    message: format!("Failed to parse agent card: {}", e),
356                })?;
357
358        if agent_card.url.is_empty() {
359            return Err(A2AError::InvalidParameter {
360                message: "Agent card does not contain a valid 'url' for the service endpoint"
361                    .to_string(),
362            });
363        }
364
365        Ok(Self {
366            client: http_client,
367            service_endpoint_url: agent_card.url.clone(),
368            auth_token: None,
369            request_id_counter: Arc::new(AtomicU64::new(1)),
370            agent_card: Arc::new(agent_card),
371        })
372    }
373
374    /// Create a new A2A client directly from an agent card
375    ///
376    /// This is useful when you already have an agent card and don't need to fetch it.
377    /// Uses a default `reqwest::Client`. For custom HTTP configuration, use `from_card_with_client()`.
378    ///
379    /// # Example
380    ///
381    /// ```no_run
382    /// use a2a_client::A2AClient;
383    /// use a2a_types::AgentCard;
384    ///
385    /// # fn example(agent_card: AgentCard) -> Result<(), Box<dyn std::error::Error>> {
386    /// let client = A2AClient::from_card(agent_card)?;
387    /// # Ok(())
388    /// # }
389    /// ```
390    pub fn from_card(agent_card: AgentCard) -> A2AResult<Self> {
391        Self::from_card_with_client(agent_card, Client::new())
392    }
393
394    /// Create a new A2A client from an agent card with a custom HTTP client
395    ///
396    /// This allows you to provide a pre-configured `reqwest::Client` with
397    /// custom settings like timeouts, proxies, TLS config, default headers, etc.
398    ///
399    /// # Example
400    ///
401    /// ```no_run
402    /// # #[cfg(not(target_family = "wasm"))]
403    /// # {
404    /// use a2a_client::A2AClient;
405    /// use a2a_types::AgentCard;
406    /// use reqwest::Client;
407    /// use std::time::Duration;
408    ///
409    /// # fn example(agent_card: AgentCard) -> Result<(), Box<dyn std::error::Error>> {
410    /// let http_client = Client::builder()
411    ///     .timeout(Duration::from_secs(30))
412    ///     .default_headers({
413    ///         let mut headers = reqwest::header::HeaderMap::new();
414    ///         headers.insert("X-Custom-Header", "value".parse()?);
415    ///         headers
416    ///     })
417    ///     .build()?;
418    ///
419    /// let client = A2AClient::from_card_with_client(agent_card, http_client)?;
420    /// # Ok(())
421    /// # }
422    /// # }
423    /// ```
424    pub fn from_card_with_client(agent_card: AgentCard, http_client: Client) -> A2AResult<Self> {
425        if agent_card.url.is_empty() {
426            return Err(A2AError::InvalidParameter {
427                message: "Agent card does not contain a valid 'url' for the service endpoint"
428                    .to_string(),
429            });
430        }
431
432        Ok(Self {
433            client: http_client,
434            service_endpoint_url: agent_card.url.clone(),
435            auth_token: None,
436            request_id_counter: Arc::new(AtomicU64::new(1)),
437            agent_card: Arc::new(agent_card),
438        })
439    }
440
441    /// Create a new A2A client from an agent card with custom headers
442    ///
443    /// This is a convenience method that builds a reqwest::Client with the provided
444    /// headers and uses it to create the A2AClient.
445    ///
446    /// # Example
447    ///
448    /// ```no_run
449    /// use a2a_client::A2AClient;
450    /// use a2a_types::AgentCard;
451    /// use std::collections::HashMap;
452    ///
453    /// # fn example(agent_card: AgentCard) -> Result<(), Box<dyn std::error::Error>> {
454    /// let mut headers = HashMap::new();
455    /// headers.insert("Authorization".to_string(), "Bearer token123".to_string());
456    /// headers.insert("X-API-Key".to_string(), "my-api-key".to_string());
457    ///
458    /// let client = A2AClient::from_card_with_headers(agent_card, headers)?;
459    /// # Ok(())
460    /// # }
461    /// ```
462    pub fn from_card_with_headers(
463        agent_card: AgentCard,
464        headers: std::collections::HashMap<String, String>,
465    ) -> A2AResult<Self> {
466        use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
467        use std::str::FromStr;
468
469        let mut header_map = HeaderMap::new();
470        for (key, value) in headers {
471            let header_name =
472                HeaderName::from_str(&key).map_err(|e| A2AError::InvalidParameter {
473                    message: format!("Invalid header name '{}': {}", key, e),
474                })?;
475            let header_value =
476                HeaderValue::from_str(&value).map_err(|e| A2AError::InvalidParameter {
477                    message: format!("Invalid header value for '{}': {}", key, e),
478                })?;
479            header_map.insert(header_name, header_value);
480        }
481
482        let http_client = Client::builder()
483            .default_headers(header_map)
484            .build()
485            .map_err(|e| A2AError::NetworkError {
486                message: format!("Failed to build HTTP client with headers: {}", e),
487            })?;
488
489        Self::from_card_with_client(agent_card, http_client)
490    }
491
492    /// Set authentication token (builder pattern)
493    pub fn with_auth_token(mut self, token: impl Into<String>) -> Self {
494        self.auth_token = Some(token.into());
495        self
496    }
497
498    /// Get the cached agent card
499    pub fn agent_card(&self) -> &AgentCard {
500        &self.agent_card
501    }
502
503    /// Fetch a fresh agent card from the base URL
504    pub async fn fetch_agent_card(&self, base_url: impl AsRef<str>) -> A2AResult<AgentCard> {
505        let base_url = base_url.as_ref().trim_end_matches('/');
506        let card_url = format!("{}/{}", base_url, AGENT_CARD_PATH);
507
508        let mut req = self
509            .client
510            .get(&card_url)
511            .header("Accept", "application/json");
512
513        if let Some(token) = &self.auth_token {
514            req = req.bearer_auth(token);
515        }
516
517        let response = req.send().await.map_err(|e| A2AError::NetworkError {
518            message: format!("Failed to fetch agent card from {}: {}", card_url, e),
519        })?;
520
521        if !response.status().is_success() {
522            return Err(A2AError::NetworkError {
523                message: format!("Failed to fetch agent card: HTTP {}", response.status()),
524            });
525        }
526
527        response
528            .json()
529            .await
530            .map_err(|e| A2AError::SerializationError {
531                message: format!("Failed to parse agent card: {}", e),
532            })
533    }
534
535    /// Get the next request ID
536    fn next_request_id(&self) -> JSONRPCId {
537        let id = self.request_id_counter.fetch_add(1, Ordering::SeqCst);
538        JSONRPCId::Integer(id as i64)
539    }
540
541    /// Inject W3C Trace Context into HTTP headers for distributed tracing
542    ///
543    /// Extracts the OpenTelemetry context from the current tracing span and
544    /// injects it into a carrier (HashMap) that can be used as HTTP headers.
545    /// This enables trace propagation across service boundaries.
546    fn inject_trace_context() -> std::collections::HashMap<String, String> {
547        use opentelemetry::global;
548        use tracing_opentelemetry::OpenTelemetrySpanExt;
549
550        let mut carrier = std::collections::HashMap::new();
551
552        // Get the OpenTelemetry context from the current tracing span
553        let context = tracing::Span::current().context();
554
555        // Inject the context into the carrier (adds traceparent, tracestate headers)
556        // OpenTelemetry 0.31+ uses a closure-based API
557        global::get_text_map_propagator(|propagator| {
558            propagator.inject_context(&context, &mut carrier);
559        });
560
561        carrier
562    }
563
564    /// Helper method to make a generic JSON-RPC POST request
565    async fn post_rpc_request<TParams, TResponse>(
566        &self,
567        method: &str,
568        params: TParams,
569    ) -> A2AResult<JsonRpcResponse<TResponse>>
570    where
571        TParams: Serialize,
572        TResponse: for<'de> Deserialize<'de>,
573    {
574        let request_id = self.next_request_id();
575        let rpc_request = JsonRpcRequest {
576            jsonrpc: JSONRPC_VERSION.to_string(),
577            method: method.to_string(),
578            params,
579            id: request_id.clone(),
580        };
581
582        let mut req = self
583            .client
584            .post(&self.service_endpoint_url)
585            .header("Content-Type", "application/json")
586            .header("Accept", "application/json")
587            .json(&rpc_request);
588
589        // Inject distributed tracing headers (W3C Trace Context)
590        for (key, value) in Self::inject_trace_context() {
591            req = req.header(key, value);
592        }
593
594        if let Some(token) = &self.auth_token {
595            req = req.bearer_auth(token);
596        }
597
598        let response = req.send().await.map_err(|e| A2AError::NetworkError {
599            message: format!("Failed to send {} request: {}", method, e),
600        })?;
601
602        if !response.status().is_success() {
603            // Try to parse error response
604            let status = response.status();
605            let error_text = response.text().await.unwrap_or_default();
606            if let Ok(error_json) = serde_json::from_str::<JSONRPCErrorResponse>(&error_text) {
607                return Ok(JsonRpcResponse::Error(error_json));
608            }
609            return Err(A2AError::NetworkError {
610                message: format!("HTTP error {}: {}", status, error_text),
611            });
612        }
613
614        let json_response: JsonRpcResponse<TResponse> =
615            response
616                .json()
617                .await
618                .map_err(|e| A2AError::SerializationError {
619                    message: format!("Failed to parse {} response: {}", method, e),
620                })?;
621
622        // Validate response ID matches request ID
623        if let JsonRpcResponse::Success {
624            id: Some(resp_id), ..
625        } = &json_response
626            && resp_id != &request_id
627        {
628            eprintln!(
629                "WARNING: RPC response ID mismatch for method {}. Expected {:?}, got {:?}",
630                method, request_id, resp_id
631            );
632        }
633
634        Ok(json_response)
635    }
636
637    /// Send a message to the remote agent (non-streaming)
638    pub async fn send_message(&self, params: MessageSendParams) -> A2AResult<SendMessageResponse> {
639        match self.post_rpc_request("message/send", params).await? {
640            JsonRpcResponse::Success { result, .. } => Ok(result),
641            JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
642                message: format!("Remote agent error: {}", err.error.message),
643                code: Some(err.error.code),
644            }),
645        }
646    }
647
648    /// Send a streaming message to the remote agent
649    ///
650    /// Returns a stream of events (Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent)
651    pub async fn send_streaming_message(&self, params: MessageSendParams) -> A2AResult<SseStream> {
652        // Check if agent supports streaming
653        if !self.agent_card.capabilities.streaming.unwrap_or(false) {
654            return Err(A2AError::InvalidParameter {
655                message: "Agent does not support streaming (capabilities.streaming is not true)"
656                    .to_string(),
657            });
658        }
659
660        let request_id = self.next_request_id();
661        let rpc_request = JsonRpcRequest {
662            jsonrpc: JSONRPC_VERSION.to_string(),
663            method: "message/stream".to_string(),
664            params,
665            id: request_id.clone(),
666        };
667
668        let mut req = self
669            .client
670            .post(&self.service_endpoint_url)
671            .header("Content-Type", "application/json")
672            .header("Accept", "text/event-stream")
673            .json(&rpc_request);
674
675        // Inject distributed tracing headers (W3C Trace Context)
676        for (key, value) in Self::inject_trace_context() {
677            req = req.header(key, value);
678        }
679
680        if let Some(token) = &self.auth_token {
681            req = req.bearer_auth(token);
682        }
683
684        let response = req.send().await.map_err(|e| A2AError::NetworkError {
685            message: format!("Failed to send streaming message request: {}", e),
686        })?;
687
688        if !response.status().is_success() {
689            let status = response.status();
690            let error_text = response.text().await.unwrap_or_default();
691            return Err(A2AError::NetworkError {
692                message: format!("HTTP error {}: {}", status, error_text),
693            });
694        }
695
696        // Verify content type
697        let content_type = response
698            .headers()
699            .get("Content-Type")
700            .and_then(|v| v.to_str().ok())
701            .unwrap_or("");
702
703        if !content_type.starts_with("text/event-stream") {
704            return Err(A2AError::NetworkError {
705                message: format!(
706                    "Invalid response Content-Type for SSE stream. Expected 'text/event-stream', got '{}'",
707                    content_type
708                ),
709            });
710        }
711
712        // Parse SSE stream
713        Ok(Box::pin(SseParser::new(response.bytes_stream())))
714    }
715
716    /// Get a specific task from the remote agent
717    pub async fn get_task(&self, params: TaskQueryParams) -> A2AResult<Task> {
718        match self.post_rpc_request("tasks/get", params).await? {
719            JsonRpcResponse::Success { result, .. } => Ok(result),
720            JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
721                message: format!("Remote agent error: {}", err.error.message),
722                code: Some(err.error.code),
723            }),
724        }
725    }
726
727    /// Cancel a task by its ID
728    pub async fn cancel_task(&self, params: TaskIdParams) -> A2AResult<Task> {
729        match self.post_rpc_request("tasks/cancel", params).await? {
730            JsonRpcResponse::Success { result, .. } => Ok(result),
731            JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
732                message: format!("Remote agent error: {}", err.error.message),
733                code: Some(err.error.code),
734            }),
735        }
736    }
737
738    /// Resubscribe to a task's event stream
739    ///
740    /// This is used if a previous SSE connection for an active task was broken.
741    pub async fn resubscribe_task(&self, params: TaskIdParams) -> A2AResult<SseStream> {
742        // Check if agent supports streaming
743        if !self.agent_card.capabilities.streaming.unwrap_or(false) {
744            return Err(A2AError::InvalidParameter {
745                message: "Agent does not support streaming (required for tasks/resubscribe)"
746                    .to_string(),
747            });
748        }
749
750        let request_id = self.next_request_id();
751        let rpc_request = JsonRpcRequest {
752            jsonrpc: JSONRPC_VERSION.to_string(),
753            method: "tasks/resubscribe".to_string(),
754            params,
755            id: request_id.clone(),
756        };
757
758        let mut req = self
759            .client
760            .post(&self.service_endpoint_url)
761            .header("Content-Type", "application/json")
762            .header("Accept", "text/event-stream")
763            .json(&rpc_request);
764
765        // Inject distributed tracing headers (W3C Trace Context)
766        for (key, value) in Self::inject_trace_context() {
767            req = req.header(key, value);
768        }
769
770        if let Some(token) = &self.auth_token {
771            req = req.bearer_auth(token);
772        }
773
774        let response = req.send().await.map_err(|e| A2AError::NetworkError {
775            message: format!("Failed to send resubscribe request: {}", e),
776        })?;
777
778        if !response.status().is_success() {
779            let status = response.status();
780            let error_text = response.text().await.unwrap_or_default();
781            return Err(A2AError::NetworkError {
782                message: format!("HTTP error {}: {}", status, error_text),
783            });
784        }
785
786        // Verify content type
787        let content_type = response
788            .headers()
789            .get("Content-Type")
790            .and_then(|v| v.to_str().ok())
791            .unwrap_or("");
792
793        if !content_type.starts_with("text/event-stream") {
794            return Err(A2AError::NetworkError {
795                message: format!(
796                    "Invalid response Content-Type for SSE stream on resubscribe. Expected 'text/event-stream', got '{}'",
797                    content_type
798                ),
799            });
800        }
801
802        Ok(Box::pin(SseParser::new(response.bytes_stream())))
803    }
804
805    /// Set or update the push notification configuration for a given task
806    pub async fn set_task_push_notification_config(
807        &self,
808        params: TaskPushNotificationConfig,
809    ) -> A2AResult<TaskPushNotificationConfig> {
810        // Check if agent supports push notifications
811        if !self
812            .agent_card
813            .capabilities
814            .push_notifications
815            .unwrap_or(false)
816        {
817            return Err(A2AError::InvalidParameter {
818                message: "Agent does not support push notifications (capabilities.pushNotifications is not true)"
819                    .to_string(),
820            });
821        }
822
823        match self
824            .post_rpc_request("tasks/pushNotificationConfig/set", params)
825            .await?
826        {
827            JsonRpcResponse::Success { result, .. } => Ok(result),
828            JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
829                message: format!("Remote agent error: {}", err.error.message),
830                code: Some(err.error.code),
831            }),
832        }
833    }
834
835    /// Get the push notification configuration for a given task
836    pub async fn get_task_push_notification_config(
837        &self,
838        params: TaskIdParams,
839    ) -> A2AResult<TaskPushNotificationConfig> {
840        match self
841            .post_rpc_request("tasks/pushNotificationConfig/get", params)
842            .await?
843        {
844            JsonRpcResponse::Success { result, .. } => Ok(result),
845            JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
846                message: format!("Remote agent error: {}", err.error.message),
847                code: Some(err.error.code),
848            }),
849        }
850    }
851
852    /// List push notification configurations for a given task
853    pub async fn list_task_push_notification_config(
854        &self,
855        params: ListTaskPushNotificationConfigParams,
856    ) -> A2AResult<Vec<TaskPushNotificationConfig>> {
857        match self
858            .post_rpc_request("tasks/pushNotificationConfig/list", params)
859            .await?
860        {
861            JsonRpcResponse::Success { result, .. } => Ok(result),
862            JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
863                message: format!("Remote agent error: {}", err.error.message),
864                code: Some(err.error.code),
865            }),
866        }
867    }
868
869    /// Delete a push notification configuration for a given task
870    pub async fn delete_task_push_notification_config(
871        &self,
872        params: DeleteTaskPushNotificationConfigParams,
873    ) -> A2AResult<()> {
874        match self
875            .post_rpc_request::<_, serde_json::Value>("tasks/pushNotificationConfig/delete", params)
876            .await?
877        {
878            JsonRpcResponse::Success { .. } => Ok(()),
879            JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
880                message: format!("Remote agent error: {}", err.error.message),
881                code: Some(err.error.code),
882            }),
883        }
884    }
885
886    /// Call a custom extension method
887    ///
888    /// This allows calling custom JSON-RPC methods defined by agent extensions.
889    pub async fn call_extension_method<TParams, TResponse>(
890        &self,
891        method: &str,
892        params: TParams,
893    ) -> A2AResult<TResponse>
894    where
895        TParams: Serialize,
896        TResponse: for<'de> Deserialize<'de>,
897    {
898        match self.post_rpc_request(method, params).await? {
899            JsonRpcResponse::Success { result, .. } => Ok(result),
900            JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
901                message: format!("Remote agent error: {}", err.error.message),
902                code: Some(err.error.code),
903            }),
904        }
905    }
906
907    /// List tasks from the remote agent
908    ///
909    /// Note: This method is not part of the official A2A spec but is commonly implemented.
910    pub async fn list_tasks(&self, context_id: Option<String>) -> A2AResult<Vec<Task>> {
911        #[derive(Serialize)]
912        struct ListTasksParams {
913            #[serde(skip_serializing_if = "Option::is_none")]
914            context_id: Option<String>,
915        }
916
917        match self
918            .post_rpc_request("tasks/list", ListTasksParams { context_id })
919            .await?
920        {
921            JsonRpcResponse::Success { result, .. } => Ok(result),
922            JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
923                message: format!("Remote agent error: {}", err.error.message),
924                code: Some(err.error.code),
925            }),
926        }
927    }
928}
929
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal JSON-RPC agent card for `url`, with every optional field left empty.
    fn test_card(url: &str) -> AgentCard {
        AgentCard {
            name: "Test".to_string(),
            description: "Test agent".to_string(),
            version: "1.0.0".to_string(),
            protocol_version: "0.3.0".to_string(),
            url: url.to_string(),
            preferred_transport: a2a_types::TransportProtocol::JsonRpc,
            capabilities: a2a_types::AgentCapabilities::default(),
            default_input_modes: vec![],
            default_output_modes: vec![],
            skills: vec![],
            provider: None,
            additional_interfaces: vec![],
            documentation_url: None,
            icon_url: None,
            security: vec![],
            security_schemes: None,
            signatures: vec![],
            supports_authenticated_extended_card: None,
        }
    }

    #[test]
    fn test_client_requires_valid_card_url() {
        // A card with an empty `url` must be rejected at construction time.
        assert!(A2AClient::from_card(test_card("")).is_err());
    }

    #[test]
    fn test_from_card_with_headers() {
        let headers = std::collections::HashMap::from([
            ("Authorization".to_string(), "Bearer token123".to_string()),
            ("X-API-Key".to_string(), "my-api-key".to_string()),
        ]);

        let client = A2AClient::from_card_with_headers(test_card("https://example.com"), headers)
            .expect("a valid card with valid headers should build a client");
        assert_eq!(client.service_endpoint_url, "https://example.com");
    }

    #[test]
    fn test_from_card_with_invalid_header_name() {
        // Spaces and `!` are not allowed in an HTTP header name, so this must be rejected.
        let headers = std::collections::HashMap::from([(
            "Invalid Header Name!".to_string(),
            "value".to_string(),
        )]);

        let result = A2AClient::from_card_with_headers(test_card("https://example.com"), headers);
        // `matches!` avoids needing `Debug` on `A2AClient` (it only derives `Clone`).
        assert!(matches!(result, Err(A2AError::InvalidParameter { .. })));
    }

    #[test]
    fn next_request_id_is_monotonic() {
        let client = A2AClient::from_card(AgentCard::new(
            "Test",
            "desc",
            "1.0.0",
            "https://example.com",
        ))
        .expect("valid card");

        let first = match client.next_request_id() {
            JSONRPCId::Integer(value) => value,
            other => panic!("unexpected id variant: {other:?}"),
        };
        let second = match client.next_request_id() {
            JSONRPCId::Integer(value) => value,
            other => panic!("unexpected id variant: {other:?}"),
        };

        // A freshly built client starts numbering at 1 and increments by one per request.
        assert_eq!(first, 1);
        assert_eq!(second, 2);
    }
}