1use std::pin::Pin;
5
6use eventsource_stream::Eventsource;
7use futures_core::Stream;
8use serde::{Deserialize, Serialize, de::DeserializeOwned};
9use tokio_stream::StreamExt;
10use zeph_common::net::is_private_ip;
11
12use crate::error::A2aError;
13use crate::jsonrpc::{
14 JsonRpcRequest, JsonRpcResponse, METHOD_CANCEL_TASK, METHOD_GET_TASK, METHOD_SEND_MESSAGE,
15 METHOD_SEND_STREAMING_MESSAGE, SendMessageParams, TaskIdParams,
16};
17use crate::types::{Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
18
/// Boxed, pinned, `Send` stream whose items are `Result<TaskEvent, A2aError>`.
///
/// This is the value returned by [`A2aClient::stream_message`] on success.
pub type TaskEventStream = Pin<Box<dyn Stream<Item = Result<TaskEvent, A2aError>> + Send>>;
20
/// One event delivered on a [`TaskEventStream`].
///
/// The enum is `#[serde(untagged)]`: no discriminator is written when
/// serializing, and when deserializing serde tries the variants in declaration
/// order and keeps the first one that matches. The order of the variants below
/// is therefore significant and should not be changed casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum TaskEvent {
    /// Wraps a [`TaskStatusUpdateEvent`].
    StatusUpdate(TaskStatusUpdateEvent),
    /// Wraps a [`TaskArtifactUpdateEvent`].
    ArtifactUpdate(TaskArtifactUpdateEvent),
}
27
/// JSON-RPC client that can send messages (optionally streamed over SSE) and
/// fetch or cancel tasks on a remote endpoint.
///
/// Each call takes the endpoint URL explicitly; the client itself only stores
/// the HTTP client and the endpoint checks to apply before every request.
pub struct A2aClient {
    /// HTTP client used to send every request.
    client: reqwest::Client,
    /// When `true`, endpoints that do not start with `https://` are rejected.
    require_tls: bool,
    /// When `true`, endpoints whose host resolves to an address that
    /// `is_private_ip` flags are rejected.
    ssrf_protection: bool,
}
33
34impl A2aClient {
35 #[must_use]
36 pub fn new(client: reqwest::Client) -> Self {
37 Self {
38 client,
39 require_tls: false,
40 ssrf_protection: false,
41 }
42 }
43
44 #[must_use]
45 pub fn with_security(mut self, require_tls: bool, ssrf_protection: bool) -> Self {
46 self.require_tls = require_tls;
47 self.ssrf_protection = ssrf_protection;
48 self
49 }
50
51 pub async fn send_message(
54 &self,
55 endpoint: &str,
56 params: SendMessageParams,
57 token: Option<&str>,
58 ) -> Result<Task, A2aError> {
59 self.rpc_call(endpoint, METHOD_SEND_MESSAGE, params, token)
60 .await
61 }
62
63 pub async fn stream_message(
66 &self,
67 endpoint: &str,
68 params: SendMessageParams,
69 token: Option<&str>,
70 ) -> Result<TaskEventStream, A2aError> {
71 self.validate_endpoint(endpoint).await?;
72 let request = JsonRpcRequest::new(METHOD_SEND_STREAMING_MESSAGE, params);
73 let mut req = self.client.post(endpoint).json(&request);
74 if let Some(t) = token {
75 req = req.bearer_auth(t);
76 }
77 let resp = req.send().await?;
78
79 if !resp.status().is_success() {
80 let status = resp.status();
81 let body = resp.text().await.unwrap_or_default();
82 let truncated = if body.len() > 256 {
84 format!("{}…", &body[..256])
85 } else {
86 body
87 };
88 return Err(A2aError::Stream(format!("HTTP {status}: {truncated}")));
89 }
90
91 let event_stream = resp.bytes_stream().eventsource();
92 let mapped = event_stream.filter_map(|event| match event {
93 Ok(event) => {
94 if event.data.is_empty() || event.data == "[DONE]" {
95 return None;
96 }
97 match serde_json::from_str::<JsonRpcResponse<TaskEvent>>(&event.data) {
98 Ok(rpc_resp) => match rpc_resp.into_result() {
99 Ok(task_event) => Some(Ok(task_event)),
100 Err(rpc_err) => Some(Err(A2aError::from(rpc_err))),
101 },
102 Err(e) => Some(Err(A2aError::Stream(format!(
103 "failed to parse SSE event: {e}"
104 )))),
105 }
106 }
107 Err(e) => Some(Err(A2aError::Stream(format!("SSE stream error: {e}")))),
108 });
109
110 Ok(Box::pin(mapped))
111 }
112
113 pub async fn get_task(
116 &self,
117 endpoint: &str,
118 params: TaskIdParams,
119 token: Option<&str>,
120 ) -> Result<Task, A2aError> {
121 self.rpc_call(endpoint, METHOD_GET_TASK, params, token)
122 .await
123 }
124
125 pub async fn cancel_task(
128 &self,
129 endpoint: &str,
130 params: TaskIdParams,
131 token: Option<&str>,
132 ) -> Result<Task, A2aError> {
133 self.rpc_call(endpoint, METHOD_CANCEL_TASK, params, token)
134 .await
135 }
136
137 async fn validate_endpoint(&self, endpoint: &str) -> Result<(), A2aError> {
138 if self.require_tls && !endpoint.starts_with("https://") {
139 return Err(A2aError::Security(format!(
140 "TLS required but endpoint uses HTTP: {endpoint}"
141 )));
142 }
143
144 if self.ssrf_protection {
145 let url: url::Url = endpoint
146 .parse()
147 .map_err(|e| A2aError::Security(format!("invalid URL: {e}")))?;
148
149 if let Some(host) = url.host_str() {
150 let addrs = tokio::net::lookup_host(format!(
151 "{}:{}",
152 host,
153 url.port_or_known_default().unwrap_or(443)
154 ))
155 .await
156 .map_err(|e| A2aError::Security(format!("DNS resolution failed: {e}")))?;
157
158 for addr in addrs {
159 if is_private_ip(addr.ip()) {
160 return Err(A2aError::Security(format!(
161 "SSRF protection: private IP {} for host {host}",
162 addr.ip()
163 )));
164 }
165 }
166 }
167 }
168
169 Ok(())
170 }
171
172 async fn rpc_call<P: Serialize, R: DeserializeOwned>(
173 &self,
174 endpoint: &str,
175 method: &str,
176 params: P,
177 token: Option<&str>,
178 ) -> Result<R, A2aError> {
179 self.validate_endpoint(endpoint).await?;
180 let request = JsonRpcRequest::new(method, params);
181 let mut req = self.client.post(endpoint).json(&request);
182 if let Some(t) = token {
183 req = req.bearer_auth(t);
184 }
185 let resp = req.send().await?;
186 let rpc_response: JsonRpcResponse<R> = resp.json().await?;
187 rpc_response.into_result().map_err(A2aError::from)
188 }
189}
190
#[cfg(test)]
mod tests {
    use std::net::IpAddr;

    use super::*;
    use crate::jsonrpc::{JsonRpcError, JsonRpcResponse};
    use crate::types::{
        Artifact, Message, Part, Task, TaskArtifactUpdateEvent, TaskState, TaskStatus,
        TaskStatusUpdateEvent,
    };

    /// Loopback endpoint on port 1, used by the connection-error tests.
    const UNREACHABLE_RPC: &str = "http://127.0.0.1:1/rpc";
    /// Plain-HTTP endpoint used to trip the TLS requirement.
    const PLAIN_HTTP_RPC: &str = "http://example.com/rpc";

    /// Client with both endpoint checks left at their defaults.
    fn default_client() -> A2aClient {
        A2aClient::new(reqwest::Client::new())
    }

    /// Client with the endpoint checks set explicitly.
    fn client_with(require_tls: bool, ssrf_protection: bool) -> A2aClient {
        default_client().with_security(require_tls, ssrf_protection)
    }

    /// Send-message params carrying a single user text message.
    fn text_params(text: &'static str) -> SendMessageParams {
        SendMessageParams {
            message: Message::user_text(text),
            configuration: None,
        }
    }

    /// Task-id params without a history length.
    fn task_ref(id: &'static str) -> TaskIdParams {
        TaskIdParams {
            id: id.into(),
            history_length: None,
        }
    }

    /// Completed task `t-1` with no artifacts, history or metadata.
    fn completed_task() -> Task {
        Task {
            id: "t-1".into(),
            context_id: None,
            status: TaskStatus {
                state: TaskState::Completed,
                timestamp: "ts".into(),
                message: None,
            },
            artifacts: vec![],
            history: vec![],
            metadata: None,
        }
    }

    /// JSON-RPC error with the given code and message and no data.
    fn rpc_error(code: i32, message: &'static str) -> JsonRpcError {
        JsonRpcError {
            code,
            message: message.into(),
            data: None,
        }
    }

    #[test]
    fn task_event_deserialize_status_update() {
        let status = TaskStatus {
            state: TaskState::Working,
            timestamp: "ts".into(),
            message: Some(Message::user_text("thinking...")),
        };
        let original = TaskStatusUpdateEvent {
            kind: "status-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            status,
            is_final: false,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: TaskEvent = serde_json::from_str(&encoded).unwrap();
        assert!(matches!(decoded, TaskEvent::StatusUpdate(_)));
    }

    #[test]
    fn task_event_deserialize_artifact_update() {
        let artifact = Artifact {
            artifact_id: "a-1".into(),
            name: None,
            parts: vec![Part::text("result")],
            metadata: None,
        };
        let original = TaskArtifactUpdateEvent {
            kind: "artifact-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            artifact,
            is_final: true,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: TaskEvent = serde_json::from_str(&encoded).unwrap();
        assert!(matches!(decoded, TaskEvent::ArtifactUpdate(_)));
    }

    #[test]
    fn rpc_response_with_task_result() {
        let outgoing = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String("req-1".into()),
            result: Some(completed_task()),
            error: None,
        };
        let encoded = serde_json::to_string(&outgoing).unwrap();
        let incoming: JsonRpcResponse<Task> = serde_json::from_str(&encoded).unwrap();
        let task = incoming.into_result().unwrap();
        assert_eq!(task.id, "t-1");
        assert_eq!(task.status.state, TaskState::Completed);
    }

    #[test]
    fn rpc_response_with_error() {
        let outgoing: JsonRpcResponse<Task> = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String("req-1".into()),
            result: None,
            error: Some(rpc_error(-32001, "task not found")),
        };
        let encoded = serde_json::to_string(&outgoing).unwrap();
        let incoming: JsonRpcResponse<Task> = serde_json::from_str(&encoded).unwrap();
        let err = incoming.into_result().unwrap_err();
        assert_eq!(err.code, -32001);
    }

    #[test]
    fn a2a_client_construction() {
        drop(default_client());
    }

    #[test]
    fn is_private_ip_loopback() {
        let v4 = IpAddr::V4(std::net::Ipv4Addr::LOCALHOST);
        let v6 = IpAddr::V6(std::net::Ipv6Addr::LOCALHOST);
        assert!(is_private_ip(v4));
        assert!(is_private_ip(v6));
    }

    #[test]
    fn is_private_ip_private_ranges() {
        for text in ["10.0.0.1", "172.16.0.1", "192.168.1.1"] {
            let ip: IpAddr = text.parse().unwrap();
            assert!(is_private_ip(ip));
        }
    }

    #[test]
    fn is_private_ip_link_local() {
        let ip: IpAddr = "169.254.0.1".parse().unwrap();
        assert!(is_private_ip(ip));
    }

    #[test]
    fn is_private_ip_unspecified() {
        for text in ["0.0.0.0", "::"] {
            let ip: IpAddr = text.parse().unwrap();
            assert!(is_private_ip(ip));
        }
    }

    #[test]
    fn is_private_ip_public() {
        for text in ["8.8.8.8", "1.1.1.1"] {
            let ip: IpAddr = text.parse().unwrap();
            assert!(!is_private_ip(ip));
        }
    }

    #[tokio::test]
    async fn tls_enforcement_rejects_http() {
        let outcome = client_with(true, false)
            .validate_endpoint(PLAIN_HTTP_RPC)
            .await;
        let err = outcome.unwrap_err();
        assert!(matches!(err, A2aError::Security(_)));
        assert!(err.to_string().contains("TLS required"));
    }

    #[tokio::test]
    async fn tls_enforcement_allows_https() {
        let outcome = client_with(true, false)
            .validate_endpoint("https://example.com/rpc")
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn ssrf_protection_rejects_localhost() {
        let outcome = client_with(false, true)
            .validate_endpoint("http://127.0.0.1:8080/rpc")
            .await;
        let err = outcome.unwrap_err();
        assert!(err.to_string().contains("SSRF"));
    }

    #[tokio::test]
    async fn no_security_allows_http_localhost() {
        let outcome = default_client()
            .validate_endpoint("http://127.0.0.1:8080/rpc")
            .await;
        assert!(outcome.is_ok());
    }

    #[test]
    fn jsonrpc_request_serialization_for_send_message() {
        let request = JsonRpcRequest::new(METHOD_SEND_MESSAGE, text_params("hello"));
        let encoded = serde_json::to_string(&request).unwrap();
        assert!(encoded.contains("\"method\":\"message/send\""));
        assert!(encoded.contains("\"jsonrpc\":\"2.0\""));
        assert!(encoded.contains("\"hello\""));
    }

    #[test]
    fn jsonrpc_request_serialization_for_get_task() {
        let params = TaskIdParams {
            id: "task-123".into(),
            history_length: Some(5),
        };
        let request = JsonRpcRequest::new(METHOD_GET_TASK, params);
        let encoded = serde_json::to_string(&request).unwrap();
        assert!(encoded.contains("\"method\":\"tasks/get\""));
        assert!(encoded.contains("\"task-123\""));
        assert!(encoded.contains("\"historyLength\":5"));
    }

    #[test]
    fn jsonrpc_request_serialization_for_cancel_task() {
        let request = JsonRpcRequest::new(METHOD_CANCEL_TASK, task_ref("task-456"));
        let encoded = serde_json::to_string(&request).unwrap();
        assert!(encoded.contains("\"method\":\"tasks/cancel\""));
        assert!(!encoded.contains("historyLength"));
    }

    #[test]
    fn jsonrpc_request_serialization_for_stream() {
        let request =
            JsonRpcRequest::new(METHOD_SEND_STREAMING_MESSAGE, text_params("stream me"));
        let encoded = serde_json::to_string(&request).unwrap();
        assert!(encoded.contains("\"method\":\"message/stream\""));
    }

    #[tokio::test]
    async fn send_message_connection_error() {
        let outcome = default_client()
            .send_message(UNREACHABLE_RPC, text_params("hello"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Http(_))));
    }

    #[tokio::test]
    async fn get_task_connection_error() {
        let outcome = default_client()
            .get_task(UNREACHABLE_RPC, task_ref("t-1"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Http(_))));
    }

    #[tokio::test]
    async fn cancel_task_connection_error() {
        let outcome = default_client()
            .cancel_task(UNREACHABLE_RPC, task_ref("t-1"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Http(_))));
    }

    #[tokio::test]
    async fn stream_message_connection_error() {
        let outcome = default_client()
            .stream_message(UNREACHABLE_RPC, text_params("stream me"), None)
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn stream_message_tls_required_rejects_http() {
        let outcome = client_with(true, false)
            .stream_message(PLAIN_HTTP_RPC, text_params("hello"), None)
            .await;
        if let Err(A2aError::Security(msg)) = outcome {
            assert!(msg.contains("TLS required"));
        } else {
            panic!("expected Security error");
        }
    }

    #[tokio::test]
    async fn send_message_tls_required_rejects_http() {
        let outcome = client_with(true, false)
            .send_message(PLAIN_HTTP_RPC, text_params("hello"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }

    #[tokio::test]
    async fn get_task_tls_required_rejects_http() {
        let outcome = client_with(true, false)
            .get_task(PLAIN_HTTP_RPC, task_ref("t-1"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }

    #[tokio::test]
    async fn cancel_task_tls_required_rejects_http() {
        let outcome = client_with(true, false)
            .cancel_task(PLAIN_HTTP_RPC, task_ref("t-1"), None)
            .await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }

    #[tokio::test]
    async fn validate_endpoint_invalid_url_with_ssrf() {
        let outcome = client_with(false, true)
            .validate_endpoint("not-a-url")
            .await;
        assert!(matches!(outcome, Err(A2aError::Security(_))));
    }

    #[test]
    fn with_security_returns_configured_client() {
        let A2aClient {
            require_tls,
            ssrf_protection,
            ..
        } = client_with(true, true);
        assert!(require_tls);
        assert!(ssrf_protection);
    }

    #[test]
    fn default_client_no_security() {
        let A2aClient {
            require_tls,
            ssrf_protection,
            ..
        } = default_client();
        assert!(!require_tls);
        assert!(!ssrf_protection);
    }

    #[test]
    fn task_event_clone() {
        let original = TaskEvent::StatusUpdate(TaskStatusUpdateEvent {
            kind: "status-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            status: TaskStatus {
                state: TaskState::Working,
                timestamp: "ts".into(),
                message: None,
            },
            is_final: false,
        });
        let copy = original.clone();
        assert_eq!(
            serde_json::to_string(&original).unwrap(),
            serde_json::to_string(&copy).unwrap()
        );
    }

    #[test]
    fn task_event_debug() {
        let artifact = Artifact {
            artifact_id: "a-1".into(),
            name: None,
            parts: vec![Part::text("data")],
            metadata: None,
        };
        let event = TaskEvent::ArtifactUpdate(TaskArtifactUpdateEvent {
            kind: "artifact-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            artifact,
            is_final: true,
        });
        let rendered = format!("{event:?}");
        assert!(rendered.contains("ArtifactUpdate"));
    }

    #[test]
    fn is_private_ip_ipv4_non_private() {
        let ip: IpAddr = "93.184.216.34".parse().unwrap();
        assert!(!is_private_ip(ip));
    }

    #[test]
    fn is_private_ip_ipv6_non_private() {
        let ip: IpAddr = "2001:db8::1".parse().unwrap();
        assert!(!is_private_ip(ip));
    }

    #[test]
    fn rpc_response_error_takes_priority_over_result() {
        let response = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String("1".into()),
            result: Some(completed_task()),
            error: Some(rpc_error(-32001, "error")),
        };
        let err = response.into_result().unwrap_err();
        assert_eq!(err.code, -32001);
    }

    #[test]
    fn rpc_response_neither_result_nor_error() {
        let response: JsonRpcResponse<Task> = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String("1".into()),
            result: None,
            error: None,
        };
        let err = response.into_result().unwrap_err();
        assert_eq!(err.code, -32603);
    }

    #[test]
    fn task_event_serialize_round_trip() {
        let status = TaskStatus {
            state: TaskState::Completed,
            timestamp: "2025-01-01T00:00:00Z".into(),
            message: Some(Message::user_text("done")),
        };
        let original = TaskEvent::StatusUpdate(TaskStatusUpdateEvent {
            kind: "status-update".into(),
            task_id: "t-1".into(),
            context_id: Some("ctx-1".into()),
            status,
            is_final: true,
        });
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: TaskEvent = serde_json::from_str(&encoded).unwrap();
        assert!(matches!(decoded, TaskEvent::StatusUpdate(_)));
    }
}
640
#[cfg(test)]
mod wiremock_tests {
    use tokio_stream::StreamExt;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    use crate::client::A2aClient;
    use crate::jsonrpc::{SendMessageParams, TaskIdParams};
    use crate::testing::*;
    use crate::types::Message;

    /// Client with both endpoint checks left at their defaults.
    fn default_client() -> A2aClient {
        A2aClient::new(reqwest::Client::new())
    }

    /// URL of the `/rpc` route on the given mock server.
    fn rpc_url(server: &MockServer) -> String {
        format!("{}/rpc", server.uri())
    }

    /// Send-message params carrying a single user text message.
    fn text_params(text: &'static str) -> SendMessageParams {
        SendMessageParams {
            message: Message::user_text(text),
            configuration: None,
        }
    }

    /// Task-id params without a history length.
    fn task_ref(id: &'static str) -> TaskIdParams {
        TaskIdParams {
            id: id.into(),
            history_length: None,
        }
    }

    #[tokio::test]
    async fn send_message_success() {
        let server = MockServer::start().await;
        let mock = Mock::given(method("POST"))
            .and(path("/rpc"))
            .respond_with(task_rpc_response("task-1", "submitted"));
        mock.mount(&server).await;

        let task = default_client()
            .send_message(&rpc_url(&server), text_params("hello"), None)
            .await
            .unwrap();
        assert_eq!(task.id, "task-1");
    }

    #[tokio::test]
    async fn send_message_rpc_error() {
        let server = MockServer::start().await;
        let mock = Mock::given(method("POST"))
            .and(path("/rpc"))
            .respond_with(task_rpc_error_response(-32001, "task not found"));
        mock.mount(&server).await;

        let outcome = default_client()
            .send_message(&rpc_url(&server), text_params("hi"), None)
            .await;
        assert!(matches!(
            outcome,
            Err(crate::error::A2aError::JsonRpc { code: -32001, .. })
        ));
    }

    #[tokio::test]
    async fn send_message_with_bearer_auth() {
        let server = MockServer::start().await;
        // The mock only matches when the bearer token header is present.
        let mock = Mock::given(method("POST"))
            .and(path("/rpc"))
            .and(header("authorization", "Bearer secret-token"))
            .respond_with(task_rpc_response("task-auth", "submitted"));
        mock.mount(&server).await;

        let task = default_client()
            .send_message(
                &rpc_url(&server),
                text_params("secure"),
                Some("secret-token"),
            )
            .await
            .unwrap();
        assert_eq!(task.id, "task-auth");
    }

    #[tokio::test]
    async fn get_task_success() {
        let server = MockServer::start().await;
        let mock = Mock::given(method("POST"))
            .and(path("/rpc"))
            .respond_with(task_rpc_response("task-get", "completed"));
        mock.mount(&server).await;

        let task = default_client()
            .get_task(&rpc_url(&server), task_ref("task-get"), None)
            .await
            .unwrap();
        assert_eq!(task.id, "task-get");
    }

    #[tokio::test]
    async fn cancel_task_success() {
        let server = MockServer::start().await;
        let mock = Mock::given(method("POST"))
            .and(path("/rpc"))
            .respond_with(task_rpc_response("task-cancel", "canceled"));
        mock.mount(&server).await;

        let task = default_client()
            .cancel_task(&rpc_url(&server), task_ref("task-cancel"), None)
            .await
            .unwrap();
        assert_eq!(task.id, "task-cancel");
    }

    #[tokio::test]
    async fn stream_message_success() {
        let server = MockServer::start().await;
        let mock = Mock::given(method("POST"))
            .and(path("/rpc"))
            .respond_with(sse_task_events_response("task-stream", "result content"));
        mock.mount(&server).await;

        let stream = default_client()
            .stream_message(&rpc_url(&server), text_params("stream"), None)
            .await
            .unwrap();
        let events: Vec<_> = stream.collect().await;
        assert!(!events.is_empty());
    }

    #[tokio::test]
    async fn stream_message_http_error() {
        let server = MockServer::start().await;
        let mock = Mock::given(method("POST"))
            .and(path("/rpc"))
            .respond_with(ResponseTemplate::new(500).set_body_string("Internal Server Error"));
        mock.mount(&server).await;

        let outcome = default_client()
            .stream_message(&rpc_url(&server), text_params("fail"), None)
            .await;
        let err = outcome.err().expect("expected error");
        assert!(matches!(err, crate::error::A2aError::Stream(_)));
    }
}