1use std::pin::Pin;
7use std::time::Duration;
8
9use eventsource_stream::Eventsource;
10use futures_core::Stream;
11use serde::{Deserialize, Serialize, de::DeserializeOwned};
12use tokio_stream::StreamExt;
13use zeph_common::net::is_private_ip;
14
15use crate::error::A2aError;
16use crate::jsonrpc::{
17 JsonRpcRequest, JsonRpcResponse, METHOD_CANCEL_TASK, METHOD_GET_TASK, METHOD_SEND_MESSAGE,
18 METHOD_SEND_STREAMING_MESSAGE, SendMessageParams, TaskIdParams,
19};
20use crate::types::{Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
21
22pub type TaskEventStream = Pin<Box<dyn Stream<Item = Result<TaskEvent, A2aError>> + Send>>;
27
28#[non_exhaustive]
34#[derive(Debug, Clone, Serialize, Deserialize)]
35#[serde(untagged)]
36pub enum TaskEvent {
37 StatusUpdate(TaskStatusUpdateEvent),
39 ArtifactUpdate(TaskArtifactUpdateEvent),
41}
42
43pub struct A2aClient {
76 client: reqwest::Client,
77 require_tls: bool,
78 ssrf_protection: bool,
79 request_timeout: Duration,
87}
88
89impl A2aClient {
90 #[must_use]
95 pub fn new(client: reqwest::Client) -> Self {
96 Self {
97 client,
98 require_tls: false,
99 ssrf_protection: false,
100 request_timeout: Duration::from_secs(30),
101 }
102 }
103
104 #[must_use]
121 pub fn with_security(mut self, require_tls: bool, ssrf_protection: bool) -> Self {
122 self.require_tls = require_tls;
123 self.ssrf_protection = ssrf_protection;
124 self
125 }
126
127 #[must_use]
133 pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
134 self.request_timeout = timeout;
135 self
136 }
137
138 pub async fn send_message(
142 &self,
143 endpoint: &str,
144 params: SendMessageParams,
145 token: Option<&str>,
146 ) -> Result<Task, A2aError> {
147 self.rpc_call(endpoint, METHOD_SEND_MESSAGE, params, token)
148 .await
149 }
150
151 pub async fn stream_message(
154 &self,
155 endpoint: &str,
156 params: SendMessageParams,
157 token: Option<&str>,
158 ) -> Result<TaskEventStream, A2aError> {
159 self.validate_endpoint(endpoint).await?;
160 let request = JsonRpcRequest::new(METHOD_SEND_STREAMING_MESSAGE, params);
161 let mut req = self.client.post(endpoint).json(&request);
162 if let Some(t) = token {
163 req = req.bearer_auth(t);
164 }
165 let resp = tokio::time::timeout(self.request_timeout, req.send())
166 .await
167 .map_err(|_| A2aError::Timeout(self.request_timeout))?
168 .map_err(A2aError::Http)?;
169
170 if !resp.status().is_success() {
171 let status = resp.status();
172 let body = tokio::time::timeout(Duration::from_secs(5), resp.text())
173 .await
174 .unwrap_or(Ok(String::new()))
175 .unwrap_or_default();
176 let truncated = if body.len() > 256 {
178 format!("{}…", &body[..256])
179 } else {
180 body
181 };
182 return Err(A2aError::Stream(format!("HTTP {status}: {truncated}")));
183 }
184
185 let event_stream = resp.bytes_stream().eventsource();
186 let mapped = event_stream.filter_map(|event| match event {
187 Ok(event) => {
188 if event.data.is_empty() || event.data == "[DONE]" {
189 return None;
190 }
191 match serde_json::from_str::<JsonRpcResponse<TaskEvent>>(&event.data) {
192 Ok(rpc_resp) => match rpc_resp.into_result() {
193 Ok(task_event) => Some(Ok(task_event)),
194 Err(rpc_err) => Some(Err(A2aError::from(rpc_err))),
195 },
196 Err(e) => Some(Err(A2aError::Stream(format!(
197 "failed to parse SSE event: {e}"
198 )))),
199 }
200 }
201 Err(e) => Some(Err(A2aError::Stream(format!("SSE stream error: {e}")))),
202 });
203
204 Ok(Box::pin(mapped))
205 }
206
207 pub async fn get_task(
211 &self,
212 endpoint: &str,
213 params: TaskIdParams,
214 token: Option<&str>,
215 ) -> Result<Task, A2aError> {
216 self.rpc_call(endpoint, METHOD_GET_TASK, params, token)
217 .await
218 }
219
220 pub async fn cancel_task(
224 &self,
225 endpoint: &str,
226 params: TaskIdParams,
227 token: Option<&str>,
228 ) -> Result<Task, A2aError> {
229 self.rpc_call(endpoint, METHOD_CANCEL_TASK, params, token)
230 .await
231 }
232
233 async fn validate_endpoint(&self, endpoint: &str) -> Result<(), A2aError> {
234 if self.require_tls && !endpoint.starts_with("https://") {
235 return Err(A2aError::Security(format!(
236 "TLS required but endpoint uses HTTP: {endpoint}"
237 )));
238 }
239
240 if self.ssrf_protection {
241 let url: url::Url = endpoint
242 .parse()
243 .map_err(|e| A2aError::Security(format!("invalid URL: {e}")))?;
244
245 if let Some(host) = url.host_str() {
246 let addrs = tokio::net::lookup_host(format!(
247 "{}:{}",
248 host,
249 url.port_or_known_default().unwrap_or(443)
250 ))
251 .await
252 .map_err(|e| A2aError::Security(format!("DNS resolution failed: {e}")))?;
253
254 for addr in addrs {
255 if is_private_ip(addr.ip()) {
256 return Err(A2aError::Security(format!(
257 "SSRF protection: private IP {} for host {host}",
258 addr.ip()
259 )));
260 }
261 }
262 }
263 }
264
265 Ok(())
266 }
267
268 async fn rpc_call<P: Serialize, R: DeserializeOwned>(
269 &self,
270 endpoint: &str,
271 method: &str,
272 params: P,
273 token: Option<&str>,
274 ) -> Result<R, A2aError> {
275 self.validate_endpoint(endpoint).await?;
276 let request = JsonRpcRequest::new(method, params);
277 let mut req = self.client.post(endpoint).json(&request);
278 if let Some(t) = token {
279 req = req.bearer_auth(t);
280 }
281 let rpc_response: JsonRpcResponse<R> = tokio::time::timeout(self.request_timeout, async {
282 let resp = req.send().await?;
283 resp.json().await
284 })
285 .await
286 .map_err(|_| A2aError::Timeout(self.request_timeout))?
287 .map_err(A2aError::Http)?;
288 rpc_response.into_result().map_err(A2aError::from)
289 }
290}
291
#[cfg(test)]
mod tests {
    // Unit tests for `A2aClient` configuration, endpoint validation, and the
    // JSON-RPC / event (de)serialization the client relies on.
    use std::net::IpAddr;

    use super::*;
    use crate::jsonrpc::{JsonRpcError, JsonRpcResponse};
    use crate::types::{
        Artifact, Message, Part, Task, TaskArtifactUpdateEvent, TaskState, TaskStatus,
        TaskStatusUpdateEvent,
    };

    #[test]
    fn task_event_deserialize_status_update() {
        let event = TaskStatusUpdateEvent {
            kind: "status-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            status: TaskStatus {
                state: TaskState::Working,
                timestamp: "ts".into(),
                message: Some(Message::user_text("thinking...")),
            },
            is_final: false,
        };
        let json = serde_json::to_string(&event).unwrap();
        let parsed: TaskEvent = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, TaskEvent::StatusUpdate(_)));
    }

    #[test]
    fn task_event_deserialize_artifact_update() {
        let event = TaskArtifactUpdateEvent {
            kind: "artifact-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            artifact: Artifact {
                artifact_id: "a-1".into(),
                name: None,
                parts: vec![Part::text("result")],
                metadata: None,
            },
            is_final: true,
        };
        let json = serde_json::to_string(&event).unwrap();
        let parsed: TaskEvent = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, TaskEvent::ArtifactUpdate(_)));
    }

    #[test]
    fn rpc_response_with_task_result() {
        let task = Task {
            id: "t-1".into(),
            context_id: None,
            status: TaskStatus {
                state: TaskState::Completed,
                timestamp: "ts".into(),
                message: None,
            },
            artifacts: vec![],
            history: vec![],
            metadata: None,
        };
        let resp = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String("req-1".into()),
            result: Some(task),
            error: None,
        };
        let json = serde_json::to_string(&resp).unwrap();
        let back: JsonRpcResponse<Task> = serde_json::from_str(&json).unwrap();
        let task = back.into_result().unwrap();
        assert_eq!(task.id, "t-1");
        assert_eq!(task.status.state, TaskState::Completed);
    }

    #[test]
    fn rpc_response_with_error() {
        let resp: JsonRpcResponse<Task> = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String("req-1".into()),
            result: None,
            error: Some(JsonRpcError {
                code: -32001,
                message: "task not found".into(),
                data: None,
            }),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let back: JsonRpcResponse<Task> = serde_json::from_str(&json).unwrap();
        let err = back.into_result().unwrap_err();
        assert_eq!(err.code, -32001);
    }

    #[test]
    fn a2a_client_construction() {
        let client = A2aClient::new(reqwest::Client::new());
        // `new` documents a 30-second default request timeout.
        assert_eq!(client.request_timeout, Duration::from_secs(30));
    }

    #[test]
    fn is_private_ip_loopback() {
        assert!(is_private_ip(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)));
        assert!(is_private_ip(IpAddr::V6(std::net::Ipv6Addr::LOCALHOST)));
    }

    #[test]
    fn is_private_ip_private_ranges() {
        assert!(is_private_ip("10.0.0.1".parse().unwrap()));
        assert!(is_private_ip("172.16.0.1".parse().unwrap()));
        assert!(is_private_ip("192.168.1.1".parse().unwrap()));
    }

    #[test]
    fn is_private_ip_link_local() {
        assert!(is_private_ip("169.254.0.1".parse().unwrap()));
    }

    #[test]
    fn is_private_ip_unspecified() {
        assert!(is_private_ip("0.0.0.0".parse().unwrap()));
        assert!(is_private_ip("::".parse().unwrap()));
    }

    #[test]
    fn is_private_ip_public() {
        assert!(!is_private_ip("8.8.8.8".parse().unwrap()));
        assert!(!is_private_ip("1.1.1.1".parse().unwrap()));
    }

    #[tokio::test]
    async fn tls_enforcement_rejects_http() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(true, false);
        let result = client.validate_endpoint("http://example.com/rpc").await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, A2aError::Security(_)));
        assert!(err.to_string().contains("TLS required"));
    }

    #[tokio::test]
    async fn tls_enforcement_allows_https() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(true, false);
        let result = client.validate_endpoint("https://example.com/rpc").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn ssrf_protection_rejects_localhost() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(false, true);
        let err = client
            .validate_endpoint("http://127.0.0.1:8080/rpc")
            .await
            .unwrap_err();
        assert!(matches!(err, A2aError::Security(_)));
        assert!(err.to_string().contains("SSRF"));
    }

    #[tokio::test]
    async fn no_security_allows_http_localhost() {
        let client = A2aClient::new(reqwest::Client::new());
        let result = client.validate_endpoint("http://127.0.0.1:8080/rpc").await;
        assert!(result.is_ok());
    }

    #[test]
    fn jsonrpc_request_serialization_for_send_message() {
        let params = SendMessageParams {
            message: Message::user_text("hello"),
            configuration: None,
        };
        let req = JsonRpcRequest::new(METHOD_SEND_MESSAGE, params);
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"method\":\"message/send\""));
        assert!(json.contains("\"jsonrpc\":\"2.0\""));
        assert!(json.contains("\"hello\""));
    }

    #[test]
    fn jsonrpc_request_serialization_for_get_task() {
        let params = TaskIdParams {
            id: "task-123".into(),
            history_length: Some(5),
        };
        let req = JsonRpcRequest::new(METHOD_GET_TASK, params);
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"method\":\"tasks/get\""));
        assert!(json.contains("\"task-123\""));
        assert!(json.contains("\"historyLength\":5"));
    }

    #[test]
    fn jsonrpc_request_serialization_for_cancel_task() {
        let params = TaskIdParams {
            id: "task-456".into(),
            history_length: None,
        };
        let req = JsonRpcRequest::new(METHOD_CANCEL_TASK, params);
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"method\":\"tasks/cancel\""));
        assert!(!json.contains("historyLength"));
    }

    #[test]
    fn jsonrpc_request_serialization_for_stream() {
        let params = SendMessageParams {
            message: Message::user_text("stream me"),
            configuration: None,
        };
        let req = JsonRpcRequest::new(METHOD_SEND_STREAMING_MESSAGE, params);
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"method\":\"message/stream\""));
    }

    #[tokio::test]
    async fn send_message_connection_error() {
        let client = A2aClient::new(reqwest::Client::new());
        let params = SendMessageParams {
            message: Message::user_text("hello"),
            configuration: None,
        };
        let result = client
            .send_message("http://127.0.0.1:1/rpc", params, None)
            .await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), A2aError::Http(_)));
    }

    #[tokio::test]
    async fn get_task_connection_error() {
        let client = A2aClient::new(reqwest::Client::new());
        let params = TaskIdParams {
            id: "t-1".into(),
            history_length: None,
        };
        let result = client
            .get_task("http://127.0.0.1:1/rpc", params, None)
            .await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), A2aError::Http(_)));
    }

    #[tokio::test]
    async fn cancel_task_connection_error() {
        let client = A2aClient::new(reqwest::Client::new());
        let params = TaskIdParams {
            id: "t-1".into(),
            history_length: None,
        };
        let result = client
            .cancel_task("http://127.0.0.1:1/rpc", params, None)
            .await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), A2aError::Http(_)));
    }

    #[tokio::test]
    async fn stream_message_connection_error() {
        let client = A2aClient::new(reqwest::Client::new());
        let params = SendMessageParams {
            message: Message::user_text("stream me"),
            configuration: None,
        };
        let result = client
            .stream_message("http://127.0.0.1:1/rpc", params, None)
            .await;
        // `TaskEventStream` is not `Debug`, so match instead of `unwrap_err`.
        match result {
            Err(A2aError::Http(_)) => {}
            Err(other) => panic!("expected Http error, got {other}"),
            Ok(_) => panic!("expected connection error"),
        }
    }

    #[tokio::test]
    async fn stream_message_tls_required_rejects_http() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(true, false);
        let params = SendMessageParams {
            message: Message::user_text("hello"),
            configuration: None,
        };
        let result = client
            .stream_message("http://example.com/rpc", params, None)
            .await;
        match result {
            Err(A2aError::Security(msg)) => assert!(msg.contains("TLS required")),
            _ => panic!("expected Security error"),
        }
    }

    #[tokio::test]
    async fn send_message_tls_required_rejects_http() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(true, false);
        let params = SendMessageParams {
            message: Message::user_text("hello"),
            configuration: None,
        };
        let result = client
            .send_message("http://example.com/rpc", params, None)
            .await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), A2aError::Security(_)));
    }

    #[tokio::test]
    async fn get_task_tls_required_rejects_http() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(true, false);
        let params = TaskIdParams {
            id: "t-1".into(),
            history_length: None,
        };
        let result = client
            .get_task("http://example.com/rpc", params, None)
            .await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), A2aError::Security(_)));
    }

    #[tokio::test]
    async fn cancel_task_tls_required_rejects_http() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(true, false);
        let params = TaskIdParams {
            id: "t-1".into(),
            history_length: None,
        };
        let result = client
            .cancel_task("http://example.com/rpc", params, None)
            .await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), A2aError::Security(_)));
    }

    #[tokio::test]
    async fn validate_endpoint_invalid_url_with_ssrf() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(false, true);
        let result = client.validate_endpoint("not-a-url").await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), A2aError::Security(_)));
    }

    #[test]
    fn with_security_returns_configured_client() {
        let client = A2aClient::new(reqwest::Client::new()).with_security(true, true);
        assert!(client.require_tls);
        assert!(client.ssrf_protection);
    }

    #[test]
    fn default_client_no_security() {
        let client = A2aClient::new(reqwest::Client::new());
        assert!(!client.require_tls);
        assert!(!client.ssrf_protection);
    }

    #[test]
    fn with_request_timeout_overrides_default() {
        let client = A2aClient::new(reqwest::Client::new())
            .with_request_timeout(Duration::from_millis(250));
        assert_eq!(client.request_timeout, Duration::from_millis(250));
    }

    #[test]
    fn task_event_clone() {
        let event = TaskEvent::StatusUpdate(TaskStatusUpdateEvent {
            kind: "status-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            status: TaskStatus {
                state: TaskState::Working,
                timestamp: "ts".into(),
                message: None,
            },
            is_final: false,
        });
        let cloned = event.clone();
        let json1 = serde_json::to_string(&event).unwrap();
        let json2 = serde_json::to_string(&cloned).unwrap();
        assert_eq!(json1, json2);
    }

    #[test]
    fn task_event_debug() {
        let event = TaskEvent::ArtifactUpdate(TaskArtifactUpdateEvent {
            kind: "artifact-update".into(),
            task_id: "t-1".into(),
            context_id: None,
            artifact: Artifact {
                artifact_id: "a-1".into(),
                name: None,
                parts: vec![Part::text("data")],
                metadata: None,
            },
            is_final: true,
        });
        let dbg = format!("{event:?}");
        assert!(dbg.contains("ArtifactUpdate"));
    }

    #[test]
    fn is_private_ip_ipv4_non_private() {
        assert!(!is_private_ip("93.184.216.34".parse().unwrap()));
    }

    #[test]
    fn is_private_ip_ipv6_non_private() {
        assert!(!is_private_ip("2001:db8::1".parse().unwrap()));
    }

    #[test]
    fn rpc_response_error_takes_priority_over_result() {
        let resp = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String("1".into()),
            result: Some(Task {
                id: "t-1".into(),
                context_id: None,
                status: TaskStatus {
                    state: TaskState::Completed,
                    timestamp: "ts".into(),
                    message: None,
                },
                artifacts: vec![],
                history: vec![],
                metadata: None,
            }),
            error: Some(JsonRpcError {
                code: -32001,
                message: "error".into(),
                data: None,
            }),
        };
        let err = resp.into_result().unwrap_err();
        assert_eq!(err.code, -32001);
    }

    #[test]
    fn rpc_response_neither_result_nor_error() {
        let resp: JsonRpcResponse<Task> = JsonRpcResponse {
            jsonrpc: "2.0".into(),
            id: serde_json::Value::String("1".into()),
            result: None,
            error: None,
        };
        let err = resp.into_result().unwrap_err();
        assert_eq!(err.code, -32603);
    }

    #[test]
    fn task_event_serialize_round_trip() {
        let event = TaskEvent::StatusUpdate(TaskStatusUpdateEvent {
            kind: "status-update".into(),
            task_id: "t-1".into(),
            context_id: Some("ctx-1".into()),
            status: TaskStatus {
                state: TaskState::Completed,
                timestamp: "2025-01-01T00:00:00Z".into(),
                message: Some(Message::user_text("done")),
            },
            is_final: true,
        });
        let json = serde_json::to_string(&event).unwrap();
        let back: TaskEvent = serde_json::from_str(&json).unwrap();
        assert!(matches!(back, TaskEvent::StatusUpdate(_)));
    }
}
741
#[cfg(test)]
mod wiremock_tests {
    // End-to-end tests of `A2aClient` against an in-process wiremock HTTP server.
    use tokio_stream::StreamExt;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    use crate::client::A2aClient;
    use crate::jsonrpc::{SendMessageParams, TaskIdParams};
    use crate::testing::*;
    use crate::types::Message;

    #[tokio::test]
    async fn send_message_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/rpc"))
            .respond_with(task_rpc_response("task-1", "submitted"))
            .mount(&server)
            .await;

        let client = A2aClient::new(reqwest::Client::new());
        let params = SendMessageParams {
            message: Message::user_text("hello"),
            configuration: None,
        };
        let task = client
            .send_message(&format!("{}/rpc", server.uri()), params, None)
            .await
            .unwrap();
        assert_eq!(task.id, "task-1");
    }

    #[tokio::test]
    async fn send_message_rpc_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/rpc"))
            .respond_with(task_rpc_error_response(-32001, "task not found"))
            .mount(&server)
            .await;

        let client = A2aClient::new(reqwest::Client::new());
        let params = SendMessageParams {
            message: Message::user_text("hi"),
            configuration: None,
        };
        let result = client
            .send_message(&format!("{}/rpc", server.uri()), params, None)
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(
            err,
            crate::error::A2aError::JsonRpc { code: -32001, .. }
        ));
    }

    #[tokio::test]
    async fn send_message_with_bearer_auth() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/rpc"))
            .and(header("authorization", "Bearer secret-token"))
            .respond_with(task_rpc_response("task-auth", "submitted"))
            .mount(&server)
            .await;

        let client = A2aClient::new(reqwest::Client::new());
        let params = SendMessageParams {
            message: Message::user_text("secure"),
            configuration: None,
        };
        let task = client
            .send_message(
                &format!("{}/rpc", server.uri()),
                params,
                Some("secret-token"),
            )
            .await
            .unwrap();
        assert_eq!(task.id, "task-auth");
    }

    #[tokio::test]
    async fn get_task_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/rpc"))
            .respond_with(task_rpc_response("task-get", "completed"))
            .mount(&server)
            .await;

        let client = A2aClient::new(reqwest::Client::new());
        let params = TaskIdParams {
            id: "task-get".into(),
            history_length: None,
        };
        let task = client
            .get_task(&format!("{}/rpc", server.uri()), params, None)
            .await
            .unwrap();
        assert_eq!(task.id, "task-get");
    }

    #[tokio::test]
    async fn cancel_task_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/rpc"))
            .respond_with(task_rpc_response("task-cancel", "canceled"))
            .mount(&server)
            .await;

        let client = A2aClient::new(reqwest::Client::new());
        let params = TaskIdParams {
            id: "task-cancel".into(),
            history_length: None,
        };
        let task = client
            .cancel_task(&format!("{}/rpc", server.uri()), params, None)
            .await
            .unwrap();
        assert_eq!(task.id, "task-cancel");
    }

    #[tokio::test]
    async fn stream_message_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/rpc"))
            .respond_with(sse_task_events_response("task-stream", "result content"))
            .mount(&server)
            .await;

        let client = A2aClient::new(reqwest::Client::new());
        let params = SendMessageParams {
            message: Message::user_text("stream"),
            configuration: None,
        };
        let stream = client
            .stream_message(&format!("{}/rpc", server.uri()), params, None)
            .await
            .unwrap();
        let events: Vec<_> = stream.collect().await;
        assert!(!events.is_empty());
    }

    #[tokio::test]
    async fn stream_message_with_bearer_auth() {
        let server = MockServer::start().await;
        // The mock only matches when the bearer token is present, so a missing
        // header surfaces as a non-success status and fails the `unwrap` below.
        Mock::given(method("POST"))
            .and(path("/rpc"))
            .and(header("authorization", "Bearer stream-token"))
            .respond_with(sse_task_events_response("task-stream-auth", "result content"))
            .mount(&server)
            .await;

        let client = A2aClient::new(reqwest::Client::new());
        let params = SendMessageParams {
            message: Message::user_text("stream"),
            configuration: None,
        };
        let stream = client
            .stream_message(
                &format!("{}/rpc", server.uri()),
                params,
                Some("stream-token"),
            )
            .await
            .unwrap();
        let events: Vec<_> = stream.collect().await;
        assert!(!events.is_empty());
    }

    #[tokio::test]
    async fn stream_message_http_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/rpc"))
            .respond_with(ResponseTemplate::new(500).set_body_string("Internal Server Error"))
            .mount(&server)
            .await;

        let client = A2aClient::new(reqwest::Client::new());
        let params = SendMessageParams {
            message: Message::user_text("fail"),
            configuration: None,
        };
        let result = client
            .stream_message(&format!("{}/rpc", server.uri()), params, None)
            .await;
        let err = result.err().expect("expected error");
        // The error message must carry both the status and the (short) body.
        match err {
            crate::error::A2aError::Stream(msg) => {
                assert!(msg.starts_with("HTTP 500"), "unexpected message: {msg}");
                assert!(msg.contains("Internal Server Error"), "unexpected message: {msg}");
            }
            other => panic!("expected Stream error, got {other}"),
        }
    }

    #[tokio::test]
    async fn rpc_call_times_out() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/rpc"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_delay(std::time::Duration::from_secs(5))
                    .set_body_json(serde_json::json!({
                        "jsonrpc": "2.0",
                        "id": "req-1",
                        "result": {
                            "id": "t-1",
                            "status": {"state": "completed", "timestamp": "2026-01-01T00:00:00Z"}
                        }
                    })),
            )
            .mount(&server)
            .await;

        let client = A2aClient::new(reqwest::Client::new())
            .with_request_timeout(std::time::Duration::from_millis(100));
        let params = SendMessageParams {
            message: Message::user_text("hello"),
            configuration: None,
        };
        let result = client
            .send_message(&format!("{}/rpc", server.uri()), params, None)
            .await;
        assert!(result.is_err());
        assert!(
            matches!(result.unwrap_err(), crate::error::A2aError::Timeout(_)),
            "expected Timeout error"
        );
    }
}