1use crate::constants::{AGENT_CARD_PATH, JSONRPC_VERSION};
7use crate::error::{A2AError, A2AResult};
8use a2a_types::{
9 AgentCard, DeleteTaskPushNotificationConfigParams, JSONRPCErrorResponse, JSONRPCId,
10 ListTaskPushNotificationConfigParams, MessageSendParams, SendMessageResponse,
11 SendStreamingMessageResult, Task, TaskIdParams, TaskPushNotificationConfig, TaskQueryParams,
12};
13use futures_core::Stream;
14use reqwest::Client;
15use serde::{Deserialize, Serialize};
16use std::pin::Pin;
17use std::sync::Arc;
18use std::sync::atomic::{AtomicU64, Ordering};
19
20#[derive(Clone)]
22pub struct A2AClient {
23 client: Client,
25 service_endpoint_url: String,
27 auth_token: Option<String>,
29 request_id_counter: Arc<AtomicU64>,
31 agent_card: Arc<AgentCard>,
33}
34
35#[derive(Debug, Serialize)]
37struct JsonRpcRequest<T> {
38 jsonrpc: String,
39 id: JSONRPCId,
40 method: String,
41 params: T,
42}
43
44#[derive(Debug, Deserialize)]
46#[serde(untagged)]
47enum JsonRpcResponse<T> {
48 Success {
49 jsonrpc: String,
50 id: Option<JSONRPCId>,
51 result: T,
52 },
53 Error(JSONRPCErrorResponse),
54}
55
56impl A2AClient {
57 pub async fn from_card_url(base_url: impl AsRef<str>) -> A2AResult<Self> {
76 Self::from_card_url_with_client(base_url, Client::new()).await
77 }
78
79 pub async fn from_card_url_with_client(
104 base_url: impl AsRef<str>,
105 http_client: Client,
106 ) -> A2AResult<Self> {
107 let base_url = base_url.as_ref().trim_end_matches('/');
108 let card_url = format!("{}/{}", base_url, AGENT_CARD_PATH);
109
110 let response = http_client
111 .get(&card_url)
112 .header("Accept", "application/json")
113 .send()
114 .await
115 .map_err(|e| A2AError::NetworkError {
116 message: format!("Failed to fetch agent card from {}: {}", card_url, e),
117 })?;
118
119 if !response.status().is_success() {
120 return Err(A2AError::NetworkError {
121 message: format!("Failed to fetch agent card: HTTP {}", response.status()),
122 });
123 }
124
125 let agent_card: AgentCard =
126 response
127 .json()
128 .await
129 .map_err(|e| A2AError::SerializationError {
130 message: format!("Failed to parse agent card: {}", e),
131 })?;
132
133 if agent_card.url.is_empty() {
134 return Err(A2AError::InvalidParameter {
135 message: "Agent card does not contain a valid 'url' for the service endpoint"
136 .to_string(),
137 });
138 }
139
140 Ok(Self {
141 client: http_client,
142 service_endpoint_url: agent_card.url.clone(),
143 auth_token: None,
144 request_id_counter: Arc::new(AtomicU64::new(1)),
145 agent_card: Arc::new(agent_card),
146 })
147 }
148
149 pub fn from_card(agent_card: AgentCard) -> A2AResult<Self> {
166 Self::from_card_with_client(agent_card, Client::new())
167 }
168
169 pub fn from_card_with_client(agent_card: AgentCard, http_client: Client) -> A2AResult<Self> {
197 if agent_card.url.is_empty() {
198 return Err(A2AError::InvalidParameter {
199 message: "Agent card does not contain a valid 'url' for the service endpoint"
200 .to_string(),
201 });
202 }
203
204 Ok(Self {
205 client: http_client,
206 service_endpoint_url: agent_card.url.clone(),
207 auth_token: None,
208 request_id_counter: Arc::new(AtomicU64::new(1)),
209 agent_card: Arc::new(agent_card),
210 })
211 }
212
213 pub fn from_card_with_headers(
235 agent_card: AgentCard,
236 headers: std::collections::HashMap<String, String>,
237 ) -> A2AResult<Self> {
238 use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
239 use std::str::FromStr;
240
241 let mut header_map = HeaderMap::new();
242 for (key, value) in headers {
243 let header_name =
244 HeaderName::from_str(&key).map_err(|e| A2AError::InvalidParameter {
245 message: format!("Invalid header name '{}': {}", key, e),
246 })?;
247 let header_value =
248 HeaderValue::from_str(&value).map_err(|e| A2AError::InvalidParameter {
249 message: format!("Invalid header value for '{}': {}", key, e),
250 })?;
251 header_map.insert(header_name, header_value);
252 }
253
254 let http_client = Client::builder()
255 .default_headers(header_map)
256 .build()
257 .map_err(|e| A2AError::NetworkError {
258 message: format!("Failed to build HTTP client with headers: {}", e),
259 })?;
260
261 Self::from_card_with_client(agent_card, http_client)
262 }
263
264 pub fn with_auth_token(mut self, token: impl Into<String>) -> Self {
266 self.auth_token = Some(token.into());
267 self
268 }
269
270 pub fn agent_card(&self) -> &AgentCard {
272 &self.agent_card
273 }
274
275 pub async fn fetch_agent_card(&self, base_url: impl AsRef<str>) -> A2AResult<AgentCard> {
277 let base_url = base_url.as_ref().trim_end_matches('/');
278 let card_url = format!("{}/{}", base_url, AGENT_CARD_PATH);
279
280 let mut req = self
281 .client
282 .get(&card_url)
283 .header("Accept", "application/json");
284
285 if let Some(token) = &self.auth_token {
286 req = req.bearer_auth(token);
287 }
288
289 let response = req.send().await.map_err(|e| A2AError::NetworkError {
290 message: format!("Failed to fetch agent card from {}: {}", card_url, e),
291 })?;
292
293 if !response.status().is_success() {
294 return Err(A2AError::NetworkError {
295 message: format!("Failed to fetch agent card: HTTP {}", response.status()),
296 });
297 }
298
299 response
300 .json()
301 .await
302 .map_err(|e| A2AError::SerializationError {
303 message: format!("Failed to parse agent card: {}", e),
304 })
305 }
306
307 fn next_request_id(&self) -> JSONRPCId {
309 let id = self.request_id_counter.fetch_add(1, Ordering::SeqCst);
310 JSONRPCId::Integer(id as i64)
311 }
312
313 async fn post_rpc_request<TParams, TResponse>(
315 &self,
316 method: &str,
317 params: TParams,
318 ) -> A2AResult<JsonRpcResponse<TResponse>>
319 where
320 TParams: Serialize,
321 TResponse: for<'de> Deserialize<'de>,
322 {
323 let request_id = self.next_request_id();
324 let rpc_request = JsonRpcRequest {
325 jsonrpc: JSONRPC_VERSION.to_string(),
326 method: method.to_string(),
327 params,
328 id: request_id.clone(),
329 };
330
331 let mut req = self
332 .client
333 .post(&self.service_endpoint_url)
334 .header("Content-Type", "application/json")
335 .header("Accept", "application/json")
336 .json(&rpc_request);
337
338 if let Some(token) = &self.auth_token {
339 req = req.bearer_auth(token);
340 }
341
342 let response = req.send().await.map_err(|e| A2AError::NetworkError {
343 message: format!("Failed to send {} request: {}", method, e),
344 })?;
345
346 if !response.status().is_success() {
347 let status = response.status();
349 let error_text = response.text().await.unwrap_or_default();
350 if let Ok(error_json) = serde_json::from_str::<JSONRPCErrorResponse>(&error_text) {
351 return Ok(JsonRpcResponse::Error(error_json));
352 }
353 return Err(A2AError::NetworkError {
354 message: format!("HTTP error {}: {}", status, error_text),
355 });
356 }
357
358 let json_response: JsonRpcResponse<TResponse> =
359 response
360 .json()
361 .await
362 .map_err(|e| A2AError::SerializationError {
363 message: format!("Failed to parse {} response: {}", method, e),
364 })?;
365
366 if let JsonRpcResponse::Success {
368 id: Some(resp_id), ..
369 } = &json_response
370 {
371 if resp_id != &request_id {
372 eprintln!(
373 "WARNING: RPC response ID mismatch for method {}. Expected {:?}, got {:?}",
374 method, request_id, resp_id
375 );
376 }
377 }
378
379 Ok(json_response)
380 }
381
382 pub async fn send_message(&self, params: MessageSendParams) -> A2AResult<SendMessageResponse> {
384 match self.post_rpc_request("message/send", params).await? {
385 JsonRpcResponse::Success { result, .. } => Ok(result),
386 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
387 message: format!("Remote agent error: {}", err.error.message),
388 code: Some(err.error.code),
389 }),
390 }
391 }
392
393 pub async fn send_streaming_message(
397 &self,
398 params: MessageSendParams,
399 ) -> A2AResult<Pin<Box<dyn Stream<Item = A2AResult<SendStreamingMessageResult>> + Send>>> {
400 if !self.agent_card.capabilities.streaming.unwrap_or(false) {
402 return Err(A2AError::InvalidParameter {
403 message: "Agent does not support streaming (capabilities.streaming is not true)"
404 .to_string(),
405 });
406 }
407
408 let request_id = self.next_request_id();
409 let rpc_request = JsonRpcRequest {
410 jsonrpc: JSONRPC_VERSION.to_string(),
411 method: "message/stream".to_string(),
412 params,
413 id: request_id.clone(),
414 };
415
416 let mut req = self
417 .client
418 .post(&self.service_endpoint_url)
419 .header("Content-Type", "application/json")
420 .header("Accept", "text/event-stream")
421 .json(&rpc_request);
422
423 if let Some(token) = &self.auth_token {
424 req = req.bearer_auth(token);
425 }
426
427 let response = req.send().await.map_err(|e| A2AError::NetworkError {
428 message: format!("Failed to send streaming message request: {}", e),
429 })?;
430
431 if !response.status().is_success() {
432 let status = response.status();
433 let error_text = response.text().await.unwrap_or_default();
434 return Err(A2AError::NetworkError {
435 message: format!("HTTP error {}: {}", status, error_text),
436 });
437 }
438
439 let content_type = response
441 .headers()
442 .get("Content-Type")
443 .and_then(|v| v.to_str().ok())
444 .unwrap_or("");
445
446 if !content_type.starts_with("text/event-stream") {
447 return Err(A2AError::NetworkError {
448 message: format!(
449 "Invalid response Content-Type for SSE stream. Expected 'text/event-stream', got '{}'",
450 content_type
451 ),
452 });
453 }
454
455 Ok(Box::pin(Self::parse_sse_stream(
457 response.bytes_stream(),
458 request_id,
459 )))
460 }
461
462 fn parse_sse_stream(
464 byte_stream: impl Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
465 _original_request_id: JSONRPCId,
466 ) -> impl Stream<Item = A2AResult<SendStreamingMessageResult>> + Send {
467 use futures_core::stream::Stream;
468 use std::pin::Pin;
469 use std::task::{Context, Poll};
470
471 struct SseParser {
472 inner: Pin<Box<dyn Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send>>,
473 buffer: String,
474 event_data_buffer: String,
475 pending_results: Vec<A2AResult<SendStreamingMessageResult>>,
476 }
477
478 impl SseParser {
479 fn new(
480 inner: impl Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
481 ) -> Self {
482 Self {
483 inner: Box::pin(inner),
484 buffer: String::new(),
485 event_data_buffer: String::new(),
486 pending_results: Vec::new(),
487 }
488 }
489
490 fn process_chunk(
491 &mut self,
492 chunk: bytes::Bytes,
493 ) -> Vec<A2AResult<SendStreamingMessageResult>> {
494 self.buffer.push_str(&String::from_utf8_lossy(&chunk));
496
497 let mut results = Vec::new();
498
499 while let Some(newline_pos) = self.buffer.find('\n') {
501 let line = self.buffer[..newline_pos]
502 .trim_end_matches('\r')
503 .to_string();
504 self.buffer = self.buffer[newline_pos + 1..].to_string();
505
506 if line.is_empty() {
507 if !self.event_data_buffer.is_empty() {
509 match A2AClient::process_sse_event(&self.event_data_buffer) {
510 Ok(result) => results.push(Ok(result)),
511 Err(e) => results.push(Err(e)),
512 }
513 self.event_data_buffer.clear();
514 }
515 } else if let Some(data) = line.strip_prefix("data:") {
516 if !self.event_data_buffer.is_empty() {
518 self.event_data_buffer.push('\n');
519 }
520 self.event_data_buffer.push_str(data.trim_start());
521 } else if line.starts_with(':') {
522 }
524 }
526
527 results
528 }
529 }
530
531 impl Stream for SseParser {
532 type Item = A2AResult<SendStreamingMessageResult>;
533
534 fn poll_next(
535 mut self: Pin<&mut Self>,
536 cx: &mut Context<'_>,
537 ) -> Poll<Option<Self::Item>> {
538 if let Some(result) = self.pending_results.pop() {
540 return Poll::Ready(Some(result));
541 }
542
543 match self.inner.as_mut().poll_next(cx) {
545 Poll::Ready(Some(Ok(chunk))) => {
546 let mut results = self.process_chunk(chunk);
548
549 if results.is_empty() {
550 cx.waker().wake_by_ref();
552 Poll::Pending
553 } else {
554 results.reverse();
556 self.pending_results = results;
557
558 Poll::Ready(self.pending_results.pop())
560 }
561 }
562 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(A2AError::NetworkError {
563 message: format!("Stream error: {}", e),
564 }))),
565 Poll::Ready(None) => Poll::Ready(None),
566 Poll::Pending => Poll::Pending,
567 }
568 }
569 }
570
571 SseParser::new(byte_stream)
572 }
573
574 fn process_sse_event(json_data: &str) -> A2AResult<SendStreamingMessageResult> {
576 if json_data.trim().is_empty() {
577 return Err(A2AError::SerializationError {
578 message: "Empty SSE event data".to_string(),
579 });
580 }
581
582 let json_response: JsonRpcResponse<SendStreamingMessageResult> =
584 serde_json::from_str(json_data).map_err(|e| A2AError::SerializationError {
585 message: format!("Failed to parse SSE event data: {}", e),
586 })?;
587
588 match json_response {
589 JsonRpcResponse::Success { result, .. } => Ok(result),
590 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
591 message: format!("SSE event contained an error: {}", err.error.message),
592 code: Some(err.error.code),
593 }),
594 }
595 }
596
597 pub async fn get_task(&self, params: TaskQueryParams) -> A2AResult<Task> {
599 match self.post_rpc_request("tasks/get", params).await? {
600 JsonRpcResponse::Success { result, .. } => Ok(result),
601 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
602 message: format!("Remote agent error: {}", err.error.message),
603 code: Some(err.error.code),
604 }),
605 }
606 }
607
608 pub async fn cancel_task(&self, params: TaskIdParams) -> A2AResult<Task> {
610 match self.post_rpc_request("tasks/cancel", params).await? {
611 JsonRpcResponse::Success { result, .. } => Ok(result),
612 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
613 message: format!("Remote agent error: {}", err.error.message),
614 code: Some(err.error.code),
615 }),
616 }
617 }
618
619 pub async fn resubscribe_task(
623 &self,
624 params: TaskIdParams,
625 ) -> A2AResult<Pin<Box<dyn Stream<Item = A2AResult<SendStreamingMessageResult>> + Send>>> {
626 if !self.agent_card.capabilities.streaming.unwrap_or(false) {
628 return Err(A2AError::InvalidParameter {
629 message: "Agent does not support streaming (required for tasks/resubscribe)"
630 .to_string(),
631 });
632 }
633
634 let request_id = self.next_request_id();
635 let rpc_request = JsonRpcRequest {
636 jsonrpc: JSONRPC_VERSION.to_string(),
637 method: "tasks/resubscribe".to_string(),
638 params,
639 id: request_id.clone(),
640 };
641
642 let mut req = self
643 .client
644 .post(&self.service_endpoint_url)
645 .header("Content-Type", "application/json")
646 .header("Accept", "text/event-stream")
647 .json(&rpc_request);
648
649 if let Some(token) = &self.auth_token {
650 req = req.bearer_auth(token);
651 }
652
653 let response = req.send().await.map_err(|e| A2AError::NetworkError {
654 message: format!("Failed to send resubscribe request: {}", e),
655 })?;
656
657 if !response.status().is_success() {
658 let status = response.status();
659 let error_text = response.text().await.unwrap_or_default();
660 return Err(A2AError::NetworkError {
661 message: format!("HTTP error {}: {}", status, error_text),
662 });
663 }
664
665 let content_type = response
667 .headers()
668 .get("Content-Type")
669 .and_then(|v| v.to_str().ok())
670 .unwrap_or("");
671
672 if !content_type.starts_with("text/event-stream") {
673 return Err(A2AError::NetworkError {
674 message: format!(
675 "Invalid response Content-Type for SSE stream on resubscribe. Expected 'text/event-stream', got '{}'",
676 content_type
677 ),
678 });
679 }
680
681 Ok(Box::pin(Self::parse_sse_stream(
682 response.bytes_stream(),
683 request_id,
684 )))
685 }
686
687 pub async fn set_task_push_notification_config(
689 &self,
690 params: TaskPushNotificationConfig,
691 ) -> A2AResult<TaskPushNotificationConfig> {
692 if !self
694 .agent_card
695 .capabilities
696 .push_notifications
697 .unwrap_or(false)
698 {
699 return Err(A2AError::InvalidParameter {
700 message: "Agent does not support push notifications (capabilities.pushNotifications is not true)"
701 .to_string(),
702 });
703 }
704
705 match self
706 .post_rpc_request("tasks/pushNotificationConfig/set", params)
707 .await?
708 {
709 JsonRpcResponse::Success { result, .. } => Ok(result),
710 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
711 message: format!("Remote agent error: {}", err.error.message),
712 code: Some(err.error.code),
713 }),
714 }
715 }
716
717 pub async fn get_task_push_notification_config(
719 &self,
720 params: TaskIdParams,
721 ) -> A2AResult<TaskPushNotificationConfig> {
722 match self
723 .post_rpc_request("tasks/pushNotificationConfig/get", params)
724 .await?
725 {
726 JsonRpcResponse::Success { result, .. } => Ok(result),
727 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
728 message: format!("Remote agent error: {}", err.error.message),
729 code: Some(err.error.code),
730 }),
731 }
732 }
733
734 pub async fn list_task_push_notification_config(
736 &self,
737 params: ListTaskPushNotificationConfigParams,
738 ) -> A2AResult<Vec<TaskPushNotificationConfig>> {
739 match self
740 .post_rpc_request("tasks/pushNotificationConfig/list", params)
741 .await?
742 {
743 JsonRpcResponse::Success { result, .. } => Ok(result),
744 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
745 message: format!("Remote agent error: {}", err.error.message),
746 code: Some(err.error.code),
747 }),
748 }
749 }
750
751 pub async fn delete_task_push_notification_config(
753 &self,
754 params: DeleteTaskPushNotificationConfigParams,
755 ) -> A2AResult<()> {
756 match self
757 .post_rpc_request::<_, serde_json::Value>("tasks/pushNotificationConfig/delete", params)
758 .await?
759 {
760 JsonRpcResponse::Success { .. } => Ok(()),
761 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
762 message: format!("Remote agent error: {}", err.error.message),
763 code: Some(err.error.code),
764 }),
765 }
766 }
767
768 pub async fn call_extension_method<TParams, TResponse>(
772 &self,
773 method: &str,
774 params: TParams,
775 ) -> A2AResult<TResponse>
776 where
777 TParams: Serialize,
778 TResponse: for<'de> Deserialize<'de>,
779 {
780 match self.post_rpc_request(method, params).await? {
781 JsonRpcResponse::Success { result, .. } => Ok(result),
782 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
783 message: format!("Remote agent error: {}", err.error.message),
784 code: Some(err.error.code),
785 }),
786 }
787 }
788
789 pub async fn list_tasks(&self, context_id: Option<String>) -> A2AResult<Vec<Task>> {
793 #[derive(Serialize)]
794 struct ListTasksParams {
795 #[serde(skip_serializing_if = "Option::is_none")]
796 context_id: Option<String>,
797 }
798
799 match self
800 .post_rpc_request("tasks/list", ListTasksParams { context_id })
801 .await?
802 {
803 JsonRpcResponse::Success { result, .. } => Ok(result),
804 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
805 message: format!("Remote agent error: {}", err.error.message),
806 code: Some(err.error.code),
807 }),
808 }
809 }
810}
811
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a minimal agent card; only `description` and `url` vary between tests.
    fn sample_card(description: &str, url: &str) -> AgentCard {
        AgentCard {
            name: "Test".to_string(),
            description: description.to_string(),
            version: "1.0.0".to_string(),
            protocol_version: "0.3.0".to_string(),
            url: url.to_string(),
            preferred_transport: a2a_types::TransportProtocol::JsonRpc,
            capabilities: a2a_types::AgentCapabilities::default(),
            default_input_modes: vec![],
            default_output_modes: vec![],
            skills: vec![],
            provider: None,
            additional_interfaces: vec![],
            documentation_url: None,
            icon_url: None,
            security: vec![],
            security_schemes: None,
            signatures: vec![],
            supports_authenticated_extended_card: None,
        }
    }

    /// Converts string pairs into the owned map `from_card_with_headers` takes.
    fn header_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(name, value)| (name.to_string(), value.to_string()))
            .collect()
    }

    /// A card with an empty `url` must be rejected.
    #[test]
    fn test_client_requires_valid_card_url() {
        let card_without_url = sample_card("Test", "");

        assert!(A2AClient::from_card(card_without_url).is_err());
    }

    /// Valid headers yield a client pointing at the card's URL.
    #[test]
    fn test_from_card_with_headers() {
        let headers = header_map(&[
            ("Authorization", "Bearer token123"),
            ("X-API-Key", "my-api-key"),
        ]);
        let card = sample_card("Test agent", "https://example.com");

        let result = A2AClient::from_card_with_headers(card, headers);
        assert!(result.is_ok());

        let client = result.unwrap();
        assert_eq!(client.service_endpoint_url, "https://example.com");
    }

    /// A malformed header name is reported as `InvalidParameter`.
    #[test]
    fn test_from_card_with_invalid_header_name() {
        let headers = header_map(&[("Invalid Header Name!", "value")]);
        let card = sample_card("Test agent", "https://example.com");

        let outcome = A2AClient::from_card_with_headers(card, headers).err();

        // `None` here would mean construction unexpectedly succeeded.
        assert!(matches!(
            outcome,
            Some(A2AError::InvalidParameter { .. })
        ));
    }
}