Skip to main content

zeph_a2a/
client.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::pin::Pin;
5
6use eventsource_stream::Eventsource;
7use futures_core::Stream;
8use serde::{Deserialize, Serialize, de::DeserializeOwned};
9use tokio_stream::StreamExt;
10use zeph_common::net::is_private_ip;
11
12use crate::error::A2aError;
13use crate::jsonrpc::{
14    JsonRpcRequest, JsonRpcResponse, METHOD_CANCEL_TASK, METHOD_GET_TASK, METHOD_SEND_MESSAGE,
15    METHOD_SEND_STREAMING_MESSAGE, SendMessageParams, TaskIdParams,
16};
17use crate::types::{Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
18
/// Boxed, pinned, `Send` stream of task events as returned by [`A2aClient::stream_message`].
///
/// Each item is either a successfully decoded [`TaskEvent`] or the [`A2aError`] describing why
/// that particular event could not be produced.
pub type TaskEventStream = Pin<Box<dyn Stream<Item = Result<TaskEvent, A2aError>> + Send>>;
20
/// A single update delivered over a streaming response.
///
/// The enum is `#[serde(untagged)]`: no variant tag is written on serialization, and on
/// deserialization serde tries the variants in declaration order and keeps the first one whose
/// payload type accepts the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum TaskEvent {
    /// The task's status changed.
    StatusUpdate(TaskStatusUpdateEvent),
    /// An artifact was produced or updated for the task.
    ArtifactUpdate(TaskArtifactUpdateEvent),
}
27
/// JSON-RPC client for talking to a remote A2A endpoint over HTTP.
///
/// Every request is preceded by an endpoint validation step controlled by the two
/// security flags below; both are off unless enabled via [`A2aClient::with_security`].
pub struct A2aClient {
    /// Underlying HTTP client used to send all requests.
    client: reqwest::Client,
    /// When `true`, endpoints that do not start with `https://` are rejected.
    require_tls: bool,
    /// When `true`, the endpoint host is resolved and rejected if any resolved address is
    /// classified as private by `is_private_ip`.
    ssrf_protection: bool,
}
33
34impl A2aClient {
35    #[must_use]
36    pub fn new(client: reqwest::Client) -> Self {
37        Self {
38            client,
39            require_tls: false,
40            ssrf_protection: false,
41        }
42    }
43
44    #[must_use]
45    pub fn with_security(mut self, require_tls: bool, ssrf_protection: bool) -> Self {
46        self.require_tls = require_tls;
47        self.ssrf_protection = ssrf_protection;
48        self
49    }
50
51    /// # Errors
52    /// Returns `A2aError` on network, JSON, or JSON-RPC errors.
53    pub async fn send_message(
54        &self,
55        endpoint: &str,
56        params: SendMessageParams,
57        token: Option<&str>,
58    ) -> Result<Task, A2aError> {
59        self.rpc_call(endpoint, METHOD_SEND_MESSAGE, params, token)
60            .await
61    }
62
63    /// # Errors
64    /// Returns `A2aError` on network failure or if the SSE connection cannot be established.
65    pub async fn stream_message(
66        &self,
67        endpoint: &str,
68        params: SendMessageParams,
69        token: Option<&str>,
70    ) -> Result<TaskEventStream, A2aError> {
71        self.validate_endpoint(endpoint).await?;
72        let request = JsonRpcRequest::new(METHOD_SEND_STREAMING_MESSAGE, params);
73        let mut req = self.client.post(endpoint).json(&request);
74        if let Some(t) = token {
75            req = req.bearer_auth(t);
76        }
77        let resp = req.send().await?;
78
79        if !resp.status().is_success() {
80            let status = resp.status();
81            let body = resp.text().await.unwrap_or_default();
82            // Truncate body to avoid leaking large upstream error responses.
83            let truncated = if body.len() > 256 {
84                format!("{}…", &body[..256])
85            } else {
86                body
87            };
88            return Err(A2aError::Stream(format!("HTTP {status}: {truncated}")));
89        }
90
91        let event_stream = resp.bytes_stream().eventsource();
92        let mapped = event_stream.filter_map(|event| match event {
93            Ok(event) => {
94                if event.data.is_empty() || event.data == "[DONE]" {
95                    return None;
96                }
97                match serde_json::from_str::<JsonRpcResponse<TaskEvent>>(&event.data) {
98                    Ok(rpc_resp) => match rpc_resp.into_result() {
99                        Ok(task_event) => Some(Ok(task_event)),
100                        Err(rpc_err) => Some(Err(A2aError::from(rpc_err))),
101                    },
102                    Err(e) => Some(Err(A2aError::Stream(format!(
103                        "failed to parse SSE event: {e}"
104                    )))),
105                }
106            }
107            Err(e) => Some(Err(A2aError::Stream(format!("SSE stream error: {e}")))),
108        });
109
110        Ok(Box::pin(mapped))
111    }
112
113    /// # Errors
114    /// Returns `A2aError` on network, JSON, or JSON-RPC errors.
115    pub async fn get_task(
116        &self,
117        endpoint: &str,
118        params: TaskIdParams,
119        token: Option<&str>,
120    ) -> Result<Task, A2aError> {
121        self.rpc_call(endpoint, METHOD_GET_TASK, params, token)
122            .await
123    }
124
125    /// # Errors
126    /// Returns `A2aError` on network, JSON, or JSON-RPC errors.
127    pub async fn cancel_task(
128        &self,
129        endpoint: &str,
130        params: TaskIdParams,
131        token: Option<&str>,
132    ) -> Result<Task, A2aError> {
133        self.rpc_call(endpoint, METHOD_CANCEL_TASK, params, token)
134            .await
135    }
136
137    async fn validate_endpoint(&self, endpoint: &str) -> Result<(), A2aError> {
138        if self.require_tls && !endpoint.starts_with("https://") {
139            return Err(A2aError::Security(format!(
140                "TLS required but endpoint uses HTTP: {endpoint}"
141            )));
142        }
143
144        if self.ssrf_protection {
145            let url: url::Url = endpoint
146                .parse()
147                .map_err(|e| A2aError::Security(format!("invalid URL: {e}")))?;
148
149            if let Some(host) = url.host_str() {
150                let addrs = tokio::net::lookup_host(format!(
151                    "{}:{}",
152                    host,
153                    url.port_or_known_default().unwrap_or(443)
154                ))
155                .await
156                .map_err(|e| A2aError::Security(format!("DNS resolution failed: {e}")))?;
157
158                for addr in addrs {
159                    if is_private_ip(addr.ip()) {
160                        return Err(A2aError::Security(format!(
161                            "SSRF protection: private IP {} for host {host}",
162                            addr.ip()
163                        )));
164                    }
165                }
166            }
167        }
168
169        Ok(())
170    }
171
172    async fn rpc_call<P: Serialize, R: DeserializeOwned>(
173        &self,
174        endpoint: &str,
175        method: &str,
176        params: P,
177        token: Option<&str>,
178    ) -> Result<R, A2aError> {
179        self.validate_endpoint(endpoint).await?;
180        let request = JsonRpcRequest::new(method, params);
181        let mut req = self.client.post(endpoint).json(&request);
182        if let Some(t) = token {
183            req = req.bearer_auth(t);
184        }
185        let resp = req.send().await?;
186        let rpc_response: JsonRpcResponse<R> = resp.json().await?;
187        rpc_response.into_result().map_err(A2aError::from)
188    }
189}
190
191#[cfg(test)]
192mod tests {
193    use std::net::IpAddr;
194
195    use super::*;
196    use crate::jsonrpc::{JsonRpcError, JsonRpcResponse};
197    use crate::types::{
198        Artifact, Message, Part, Task, TaskArtifactUpdateEvent, TaskState, TaskStatus,
199        TaskStatusUpdateEvent,
200    };
201
202    #[test]
203    fn task_event_deserialize_status_update() {
204        let event = TaskStatusUpdateEvent {
205            kind: "status-update".into(),
206            task_id: "t-1".into(),
207            context_id: None,
208            status: TaskStatus {
209                state: TaskState::Working,
210                timestamp: "ts".into(),
211                message: Some(Message::user_text("thinking...")),
212            },
213            is_final: false,
214        };
215        let json = serde_json::to_string(&event).unwrap();
216        let parsed: TaskEvent = serde_json::from_str(&json).unwrap();
217        assert!(matches!(parsed, TaskEvent::StatusUpdate(_)));
218    }
219
220    #[test]
221    fn task_event_deserialize_artifact_update() {
222        let event = TaskArtifactUpdateEvent {
223            kind: "artifact-update".into(),
224            task_id: "t-1".into(),
225            context_id: None,
226            artifact: Artifact {
227                artifact_id: "a-1".into(),
228                name: None,
229                parts: vec![Part::text("result")],
230                metadata: None,
231            },
232            is_final: true,
233        };
234        let json = serde_json::to_string(&event).unwrap();
235        let parsed: TaskEvent = serde_json::from_str(&json).unwrap();
236        assert!(matches!(parsed, TaskEvent::ArtifactUpdate(_)));
237    }
238
239    #[test]
240    fn rpc_response_with_task_result() {
241        let task = Task {
242            id: "t-1".into(),
243            context_id: None,
244            status: TaskStatus {
245                state: TaskState::Completed,
246                timestamp: "ts".into(),
247                message: None,
248            },
249            artifacts: vec![],
250            history: vec![],
251            metadata: None,
252        };
253        let resp = JsonRpcResponse {
254            jsonrpc: "2.0".into(),
255            id: serde_json::Value::String("req-1".into()),
256            result: Some(task),
257            error: None,
258        };
259        let json = serde_json::to_string(&resp).unwrap();
260        let back: JsonRpcResponse<Task> = serde_json::from_str(&json).unwrap();
261        let task = back.into_result().unwrap();
262        assert_eq!(task.id, "t-1");
263        assert_eq!(task.status.state, TaskState::Completed);
264    }
265
266    #[test]
267    fn rpc_response_with_error() {
268        let resp: JsonRpcResponse<Task> = JsonRpcResponse {
269            jsonrpc: "2.0".into(),
270            id: serde_json::Value::String("req-1".into()),
271            result: None,
272            error: Some(JsonRpcError {
273                code: -32001,
274                message: "task not found".into(),
275                data: None,
276            }),
277        };
278        let json = serde_json::to_string(&resp).unwrap();
279        let back: JsonRpcResponse<Task> = serde_json::from_str(&json).unwrap();
280        let err = back.into_result().unwrap_err();
281        assert_eq!(err.code, -32001);
282    }
283
284    #[test]
285    fn a2a_client_construction() {
286        let client = A2aClient::new(reqwest::Client::new());
287        drop(client);
288    }
289
290    #[test]
291    fn is_private_ip_loopback() {
292        assert!(is_private_ip(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)));
293        assert!(is_private_ip(IpAddr::V6(std::net::Ipv6Addr::LOCALHOST)));
294    }
295
296    #[test]
297    fn is_private_ip_private_ranges() {
298        assert!(is_private_ip("10.0.0.1".parse().unwrap()));
299        assert!(is_private_ip("172.16.0.1".parse().unwrap()));
300        assert!(is_private_ip("192.168.1.1".parse().unwrap()));
301    }
302
303    #[test]
304    fn is_private_ip_link_local() {
305        assert!(is_private_ip("169.254.0.1".parse().unwrap()));
306    }
307
308    #[test]
309    fn is_private_ip_unspecified() {
310        assert!(is_private_ip("0.0.0.0".parse().unwrap()));
311        assert!(is_private_ip("::".parse().unwrap()));
312    }
313
314    #[test]
315    fn is_private_ip_public() {
316        assert!(!is_private_ip("8.8.8.8".parse().unwrap()));
317        assert!(!is_private_ip("1.1.1.1".parse().unwrap()));
318    }
319
320    #[tokio::test]
321    async fn tls_enforcement_rejects_http() {
322        let client = A2aClient::new(reqwest::Client::new()).with_security(true, false);
323        let result = client.validate_endpoint("http://example.com/rpc").await;
324        assert!(result.is_err());
325        let err = result.unwrap_err();
326        assert!(matches!(err, A2aError::Security(_)));
327        assert!(err.to_string().contains("TLS required"));
328    }
329
330    #[tokio::test]
331    async fn tls_enforcement_allows_https() {
332        let client = A2aClient::new(reqwest::Client::new()).with_security(true, false);
333        let result = client.validate_endpoint("https://example.com/rpc").await;
334        assert!(result.is_ok());
335    }
336
337    #[tokio::test]
338    async fn ssrf_protection_rejects_localhost() {
339        let client = A2aClient::new(reqwest::Client::new()).with_security(false, true);
340        let result = client.validate_endpoint("http://127.0.0.1:8080/rpc").await;
341        assert!(result.is_err());
342        assert!(result.unwrap_err().to_string().contains("SSRF"));
343    }
344
345    #[tokio::test]
346    async fn no_security_allows_http_localhost() {
347        let client = A2aClient::new(reqwest::Client::new());
348        let result = client.validate_endpoint("http://127.0.0.1:8080/rpc").await;
349        assert!(result.is_ok());
350    }
351
352    #[test]
353    fn jsonrpc_request_serialization_for_send_message() {
354        let params = SendMessageParams {
355            message: Message::user_text("hello"),
356            configuration: None,
357        };
358        let req = JsonRpcRequest::new(METHOD_SEND_MESSAGE, params);
359        let json = serde_json::to_string(&req).unwrap();
360        assert!(json.contains("\"method\":\"message/send\""));
361        assert!(json.contains("\"jsonrpc\":\"2.0\""));
362        assert!(json.contains("\"hello\""));
363    }
364
365    #[test]
366    fn jsonrpc_request_serialization_for_get_task() {
367        let params = TaskIdParams {
368            id: "task-123".into(),
369            history_length: Some(5),
370        };
371        let req = JsonRpcRequest::new(METHOD_GET_TASK, params);
372        let json = serde_json::to_string(&req).unwrap();
373        assert!(json.contains("\"method\":\"tasks/get\""));
374        assert!(json.contains("\"task-123\""));
375        assert!(json.contains("\"historyLength\":5"));
376    }
377
378    #[test]
379    fn jsonrpc_request_serialization_for_cancel_task() {
380        let params = TaskIdParams {
381            id: "task-456".into(),
382            history_length: None,
383        };
384        let req = JsonRpcRequest::new(METHOD_CANCEL_TASK, params);
385        let json = serde_json::to_string(&req).unwrap();
386        assert!(json.contains("\"method\":\"tasks/cancel\""));
387        assert!(!json.contains("historyLength"));
388    }
389
390    #[test]
391    fn jsonrpc_request_serialization_for_stream() {
392        let params = SendMessageParams {
393            message: Message::user_text("stream me"),
394            configuration: None,
395        };
396        let req = JsonRpcRequest::new(METHOD_SEND_STREAMING_MESSAGE, params);
397        let json = serde_json::to_string(&req).unwrap();
398        assert!(json.contains("\"method\":\"message/stream\""));
399    }
400
401    #[tokio::test]
402    async fn send_message_connection_error() {
403        let client = A2aClient::new(reqwest::Client::new());
404        let params = SendMessageParams {
405            message: Message::user_text("hello"),
406            configuration: None,
407        };
408        let result = client
409            .send_message("http://127.0.0.1:1/rpc", params, None)
410            .await;
411        assert!(result.is_err());
412        assert!(matches!(result.unwrap_err(), A2aError::Http(_)));
413    }
414
415    #[tokio::test]
416    async fn get_task_connection_error() {
417        let client = A2aClient::new(reqwest::Client::new());
418        let params = TaskIdParams {
419            id: "t-1".into(),
420            history_length: None,
421        };
422        let result = client
423            .get_task("http://127.0.0.1:1/rpc", params, None)
424            .await;
425        assert!(result.is_err());
426        assert!(matches!(result.unwrap_err(), A2aError::Http(_)));
427    }
428
429    #[tokio::test]
430    async fn cancel_task_connection_error() {
431        let client = A2aClient::new(reqwest::Client::new());
432        let params = TaskIdParams {
433            id: "t-1".into(),
434            history_length: None,
435        };
436        let result = client
437            .cancel_task("http://127.0.0.1:1/rpc", params, None)
438            .await;
439        assert!(result.is_err());
440        assert!(matches!(result.unwrap_err(), A2aError::Http(_)));
441    }
442
443    #[tokio::test]
444    async fn stream_message_connection_error() {
445        let client = A2aClient::new(reqwest::Client::new());
446        let params = SendMessageParams {
447            message: Message::user_text("stream me"),
448            configuration: None,
449        };
450        let result = client
451            .stream_message("http://127.0.0.1:1/rpc", params, None)
452            .await;
453        assert!(result.is_err());
454    }
455
456    #[tokio::test]
457    async fn stream_message_tls_required_rejects_http() {
458        let client = A2aClient::new(reqwest::Client::new()).with_security(true, false);
459        let params = SendMessageParams {
460            message: Message::user_text("hello"),
461            configuration: None,
462        };
463        let result = client
464            .stream_message("http://example.com/rpc", params, None)
465            .await;
466        match result {
467            Err(A2aError::Security(msg)) => assert!(msg.contains("TLS required")),
468            _ => panic!("expected Security error"),
469        }
470    }
471
472    #[tokio::test]
473    async fn send_message_tls_required_rejects_http() {
474        let client = A2aClient::new(reqwest::Client::new()).with_security(true, false);
475        let params = SendMessageParams {
476            message: Message::user_text("hello"),
477            configuration: None,
478        };
479        let result = client
480            .send_message("http://example.com/rpc", params, None)
481            .await;
482        assert!(result.is_err());
483        assert!(matches!(result.unwrap_err(), A2aError::Security(_)));
484    }
485
486    #[tokio::test]
487    async fn get_task_tls_required_rejects_http() {
488        let client = A2aClient::new(reqwest::Client::new()).with_security(true, false);
489        let params = TaskIdParams {
490            id: "t-1".into(),
491            history_length: None,
492        };
493        let result = client
494            .get_task("http://example.com/rpc", params, None)
495            .await;
496        assert!(result.is_err());
497        assert!(matches!(result.unwrap_err(), A2aError::Security(_)));
498    }
499
500    #[tokio::test]
501    async fn cancel_task_tls_required_rejects_http() {
502        let client = A2aClient::new(reqwest::Client::new()).with_security(true, false);
503        let params = TaskIdParams {
504            id: "t-1".into(),
505            history_length: None,
506        };
507        let result = client
508            .cancel_task("http://example.com/rpc", params, None)
509            .await;
510        assert!(result.is_err());
511        assert!(matches!(result.unwrap_err(), A2aError::Security(_)));
512    }
513
514    #[tokio::test]
515    async fn validate_endpoint_invalid_url_with_ssrf() {
516        let client = A2aClient::new(reqwest::Client::new()).with_security(false, true);
517        let result = client.validate_endpoint("not-a-url").await;
518        assert!(result.is_err());
519        assert!(matches!(result.unwrap_err(), A2aError::Security(_)));
520    }
521
522    #[test]
523    fn with_security_returns_configured_client() {
524        let client = A2aClient::new(reqwest::Client::new()).with_security(true, true);
525        assert!(client.require_tls);
526        assert!(client.ssrf_protection);
527    }
528
529    #[test]
530    fn default_client_no_security() {
531        let client = A2aClient::new(reqwest::Client::new());
532        assert!(!client.require_tls);
533        assert!(!client.ssrf_protection);
534    }
535
536    #[test]
537    fn task_event_clone() {
538        let event = TaskEvent::StatusUpdate(TaskStatusUpdateEvent {
539            kind: "status-update".into(),
540            task_id: "t-1".into(),
541            context_id: None,
542            status: TaskStatus {
543                state: TaskState::Working,
544                timestamp: "ts".into(),
545                message: None,
546            },
547            is_final: false,
548        });
549        let cloned = event.clone();
550        let json1 = serde_json::to_string(&event).unwrap();
551        let json2 = serde_json::to_string(&cloned).unwrap();
552        assert_eq!(json1, json2);
553    }
554
555    #[test]
556    fn task_event_debug() {
557        let event = TaskEvent::ArtifactUpdate(TaskArtifactUpdateEvent {
558            kind: "artifact-update".into(),
559            task_id: "t-1".into(),
560            context_id: None,
561            artifact: Artifact {
562                artifact_id: "a-1".into(),
563                name: None,
564                parts: vec![Part::text("data")],
565                metadata: None,
566            },
567            is_final: true,
568        });
569        let dbg = format!("{event:?}");
570        assert!(dbg.contains("ArtifactUpdate"));
571    }
572
573    #[test]
574    fn is_private_ip_ipv4_non_private() {
575        assert!(!is_private_ip("93.184.216.34".parse().unwrap()));
576    }
577
578    #[test]
579    fn is_private_ip_ipv6_non_private() {
580        assert!(!is_private_ip("2001:db8::1".parse().unwrap()));
581    }
582
583    #[test]
584    fn rpc_response_error_takes_priority_over_result() {
585        let resp = JsonRpcResponse {
586            jsonrpc: "2.0".into(),
587            id: serde_json::Value::String("1".into()),
588            result: Some(Task {
589                id: "t-1".into(),
590                context_id: None,
591                status: TaskStatus {
592                    state: TaskState::Completed,
593                    timestamp: "ts".into(),
594                    message: None,
595                },
596                artifacts: vec![],
597                history: vec![],
598                metadata: None,
599            }),
600            error: Some(JsonRpcError {
601                code: -32001,
602                message: "error".into(),
603                data: None,
604            }),
605        };
606        let err = resp.into_result().unwrap_err();
607        assert_eq!(err.code, -32001);
608    }
609
610    #[test]
611    fn rpc_response_neither_result_nor_error() {
612        let resp: JsonRpcResponse<Task> = JsonRpcResponse {
613            jsonrpc: "2.0".into(),
614            id: serde_json::Value::String("1".into()),
615            result: None,
616            error: None,
617        };
618        let err = resp.into_result().unwrap_err();
619        assert_eq!(err.code, -32603);
620    }
621
622    #[test]
623    fn task_event_serialize_round_trip() {
624        let event = TaskEvent::StatusUpdate(TaskStatusUpdateEvent {
625            kind: "status-update".into(),
626            task_id: "t-1".into(),
627            context_id: Some("ctx-1".into()),
628            status: TaskStatus {
629                state: TaskState::Completed,
630                timestamp: "2025-01-01T00:00:00Z".into(),
631                message: Some(Message::user_text("done")),
632            },
633            is_final: true,
634        });
635        let json = serde_json::to_string(&event).unwrap();
636        let back: TaskEvent = serde_json::from_str(&json).unwrap();
637        assert!(matches!(back, TaskEvent::StatusUpdate(_)));
638    }
639}
640
641#[cfg(test)]
642mod wiremock_tests {
643    use tokio_stream::StreamExt;
644    use wiremock::matchers::{header, method, path};
645    use wiremock::{Mock, MockServer, ResponseTemplate};
646
647    use crate::client::A2aClient;
648    use crate::jsonrpc::{SendMessageParams, TaskIdParams};
649    use crate::testing::*;
650    use crate::types::Message;
651
652    #[tokio::test]
653    async fn send_message_success() {
654        let server = MockServer::start().await;
655        Mock::given(method("POST"))
656            .and(path("/rpc"))
657            .respond_with(task_rpc_response("task-1", "submitted"))
658            .mount(&server)
659            .await;
660
661        let client = A2aClient::new(reqwest::Client::new());
662        let params = SendMessageParams {
663            message: Message::user_text("hello"),
664            configuration: None,
665        };
666        let task = client
667            .send_message(&format!("{}/rpc", server.uri()), params, None)
668            .await
669            .unwrap();
670        assert_eq!(task.id, "task-1");
671    }
672
673    #[tokio::test]
674    async fn send_message_rpc_error() {
675        let server = MockServer::start().await;
676        Mock::given(method("POST"))
677            .and(path("/rpc"))
678            .respond_with(task_rpc_error_response(-32001, "task not found"))
679            .mount(&server)
680            .await;
681
682        let client = A2aClient::new(reqwest::Client::new());
683        let params = SendMessageParams {
684            message: Message::user_text("hi"),
685            configuration: None,
686        };
687        let result = client
688            .send_message(&format!("{}/rpc", server.uri()), params, None)
689            .await;
690        assert!(result.is_err());
691        let err = result.unwrap_err();
692        assert!(matches!(
693            err,
694            crate::error::A2aError::JsonRpc { code: -32001, .. }
695        ));
696    }
697
698    #[tokio::test]
699    async fn send_message_with_bearer_auth() {
700        let server = MockServer::start().await;
701        Mock::given(method("POST"))
702            .and(path("/rpc"))
703            .and(header("authorization", "Bearer secret-token"))
704            .respond_with(task_rpc_response("task-auth", "submitted"))
705            .mount(&server)
706            .await;
707
708        let client = A2aClient::new(reqwest::Client::new());
709        let params = SendMessageParams {
710            message: Message::user_text("secure"),
711            configuration: None,
712        };
713        let task = client
714            .send_message(
715                &format!("{}/rpc", server.uri()),
716                params,
717                Some("secret-token"),
718            )
719            .await
720            .unwrap();
721        assert_eq!(task.id, "task-auth");
722    }
723
724    #[tokio::test]
725    async fn get_task_success() {
726        let server = MockServer::start().await;
727        Mock::given(method("POST"))
728            .and(path("/rpc"))
729            .respond_with(task_rpc_response("task-get", "completed"))
730            .mount(&server)
731            .await;
732
733        let client = A2aClient::new(reqwest::Client::new());
734        let params = TaskIdParams {
735            id: "task-get".into(),
736            history_length: None,
737        };
738        let task = client
739            .get_task(&format!("{}/rpc", server.uri()), params, None)
740            .await
741            .unwrap();
742        assert_eq!(task.id, "task-get");
743    }
744
745    #[tokio::test]
746    async fn cancel_task_success() {
747        let server = MockServer::start().await;
748        Mock::given(method("POST"))
749            .and(path("/rpc"))
750            .respond_with(task_rpc_response("task-cancel", "canceled"))
751            .mount(&server)
752            .await;
753
754        let client = A2aClient::new(reqwest::Client::new());
755        let params = TaskIdParams {
756            id: "task-cancel".into(),
757            history_length: None,
758        };
759        let task = client
760            .cancel_task(&format!("{}/rpc", server.uri()), params, None)
761            .await
762            .unwrap();
763        assert_eq!(task.id, "task-cancel");
764    }
765
766    #[tokio::test]
767    async fn stream_message_success() {
768        let server = MockServer::start().await;
769        Mock::given(method("POST"))
770            .and(path("/rpc"))
771            .respond_with(sse_task_events_response("task-stream", "result content"))
772            .mount(&server)
773            .await;
774
775        let client = A2aClient::new(reqwest::Client::new());
776        let params = SendMessageParams {
777            message: Message::user_text("stream"),
778            configuration: None,
779        };
780        let stream = client
781            .stream_message(&format!("{}/rpc", server.uri()), params, None)
782            .await
783            .unwrap();
784        let events: Vec<_> = stream.collect().await;
785        assert!(!events.is_empty());
786    }
787
788    #[tokio::test]
789    async fn stream_message_http_error() {
790        let server = MockServer::start().await;
791        Mock::given(method("POST"))
792            .and(path("/rpc"))
793            .respond_with(ResponseTemplate::new(500).set_body_string("Internal Server Error"))
794            .mount(&server)
795            .await;
796
797        let client = A2aClient::new(reqwest::Client::new());
798        let params = SendMessageParams {
799            message: Message::user_text("fail"),
800            configuration: None,
801        };
802        let result = client
803            .stream_message(&format!("{}/rpc", server.uri()), params, None)
804            .await;
805        let err = result.err().expect("expected error");
806        assert!(matches!(err, crate::error::A2aError::Stream(_)));
807    }
808}