1use crate::constants::{AGENT_CARD_PATH, JSONRPC_VERSION};
7use crate::error::{A2AError, A2AResult};
8use a2a_types::{
9 AgentCard, DeleteTaskPushNotificationConfigParams, JSONRPCErrorResponse, JSONRPCId,
10 ListTaskPushNotificationConfigParams, MessageSendParams, SendMessageResponse,
11 SendStreamingMessageResult, Task, TaskIdParams, TaskPushNotificationConfig, TaskQueryParams,
12};
13use futures_core::Stream;
14use reqwest::Client;
15use serde::{Deserialize, Serialize};
16use std::pin::Pin;
17use std::sync::Arc;
18use std::sync::atomic::{AtomicU64, Ordering};
19
20#[derive(Clone)]
22pub struct A2AClient {
23 client: Client,
25 service_endpoint_url: String,
27 auth_token: Option<String>,
29 request_id_counter: Arc<AtomicU64>,
31 agent_card: Arc<AgentCard>,
33}
34
35#[derive(Debug, Serialize)]
37struct JsonRpcRequest<T> {
38 jsonrpc: String,
39 id: JSONRPCId,
40 method: String,
41 params: T,
42}
43
44#[derive(Debug, Deserialize)]
46#[serde(untagged)]
47enum JsonRpcResponse<T> {
48 Success {
49 jsonrpc: String,
50 id: Option<JSONRPCId>,
51 result: T,
52 },
53 Error(JSONRPCErrorResponse),
54}
55
56impl A2AClient {
57 pub async fn from_card_url(base_url: impl AsRef<str>) -> A2AResult<Self> {
76 Self::from_card_url_with_client(base_url, Client::new()).await
77 }
78
79 pub async fn from_card_url_with_client(
104 base_url: impl AsRef<str>,
105 http_client: Client,
106 ) -> A2AResult<Self> {
107 let base_url = base_url.as_ref().trim_end_matches('/');
108 let card_url = format!("{}/{}", base_url, AGENT_CARD_PATH);
109
110 let response = http_client
111 .get(&card_url)
112 .header("Accept", "application/json")
113 .send()
114 .await
115 .map_err(|e| A2AError::NetworkError {
116 message: format!("Failed to fetch agent card from {}: {}", card_url, e),
117 })?;
118
119 if !response.status().is_success() {
120 return Err(A2AError::NetworkError {
121 message: format!("Failed to fetch agent card: HTTP {}", response.status()),
122 });
123 }
124
125 let agent_card: AgentCard =
126 response
127 .json()
128 .await
129 .map_err(|e| A2AError::SerializationError {
130 message: format!("Failed to parse agent card: {}", e),
131 })?;
132
133 if agent_card.url.is_empty() {
134 return Err(A2AError::InvalidParameter {
135 message: "Agent card does not contain a valid 'url' for the service endpoint"
136 .to_string(),
137 });
138 }
139
140 Ok(Self {
141 client: http_client,
142 service_endpoint_url: agent_card.url.clone(),
143 auth_token: None,
144 request_id_counter: Arc::new(AtomicU64::new(1)),
145 agent_card: Arc::new(agent_card),
146 })
147 }
148
149 pub fn from_card(agent_card: AgentCard) -> A2AResult<Self> {
166 Self::from_card_with_client(agent_card, Client::new())
167 }
168
169 pub fn from_card_with_client(agent_card: AgentCard, http_client: Client) -> A2AResult<Self> {
197 if agent_card.url.is_empty() {
198 return Err(A2AError::InvalidParameter {
199 message: "Agent card does not contain a valid 'url' for the service endpoint"
200 .to_string(),
201 });
202 }
203
204 Ok(Self {
205 client: http_client,
206 service_endpoint_url: agent_card.url.clone(),
207 auth_token: None,
208 request_id_counter: Arc::new(AtomicU64::new(1)),
209 agent_card: Arc::new(agent_card),
210 })
211 }
212
213 pub fn with_auth_token(mut self, token: impl Into<String>) -> Self {
215 self.auth_token = Some(token.into());
216 self
217 }
218
219 pub fn agent_card(&self) -> &AgentCard {
221 &self.agent_card
222 }
223
224 pub async fn fetch_agent_card(&self, base_url: impl AsRef<str>) -> A2AResult<AgentCard> {
226 let base_url = base_url.as_ref().trim_end_matches('/');
227 let card_url = format!("{}/{}", base_url, AGENT_CARD_PATH);
228
229 let mut req = self
230 .client
231 .get(&card_url)
232 .header("Accept", "application/json");
233
234 if let Some(token) = &self.auth_token {
235 req = req.bearer_auth(token);
236 }
237
238 let response = req.send().await.map_err(|e| A2AError::NetworkError {
239 message: format!("Failed to fetch agent card from {}: {}", card_url, e),
240 })?;
241
242 if !response.status().is_success() {
243 return Err(A2AError::NetworkError {
244 message: format!("Failed to fetch agent card: HTTP {}", response.status()),
245 });
246 }
247
248 response
249 .json()
250 .await
251 .map_err(|e| A2AError::SerializationError {
252 message: format!("Failed to parse agent card: {}", e),
253 })
254 }
255
256 fn next_request_id(&self) -> JSONRPCId {
258 let id = self.request_id_counter.fetch_add(1, Ordering::SeqCst);
259 JSONRPCId::Integer(id as i64)
260 }
261
262 async fn post_rpc_request<TParams, TResponse>(
264 &self,
265 method: &str,
266 params: TParams,
267 ) -> A2AResult<JsonRpcResponse<TResponse>>
268 where
269 TParams: Serialize,
270 TResponse: for<'de> Deserialize<'de>,
271 {
272 let request_id = self.next_request_id();
273 let rpc_request = JsonRpcRequest {
274 jsonrpc: JSONRPC_VERSION.to_string(),
275 method: method.to_string(),
276 params,
277 id: request_id.clone(),
278 };
279
280 let mut req = self
281 .client
282 .post(&self.service_endpoint_url)
283 .header("Content-Type", "application/json")
284 .header("Accept", "application/json")
285 .json(&rpc_request);
286
287 if let Some(token) = &self.auth_token {
288 req = req.bearer_auth(token);
289 }
290
291 let response = req.send().await.map_err(|e| A2AError::NetworkError {
292 message: format!("Failed to send {} request: {}", method, e),
293 })?;
294
295 if !response.status().is_success() {
296 let status = response.status();
298 let error_text = response.text().await.unwrap_or_default();
299 if let Ok(error_json) = serde_json::from_str::<JSONRPCErrorResponse>(&error_text) {
300 return Ok(JsonRpcResponse::Error(error_json));
301 }
302 return Err(A2AError::NetworkError {
303 message: format!("HTTP error {}: {}", status, error_text),
304 });
305 }
306
307 let json_response: JsonRpcResponse<TResponse> =
308 response
309 .json()
310 .await
311 .map_err(|e| A2AError::SerializationError {
312 message: format!("Failed to parse {} response: {}", method, e),
313 })?;
314
315 if let JsonRpcResponse::Success {
317 id: Some(resp_id), ..
318 } = &json_response
319 {
320 if resp_id != &request_id {
321 eprintln!(
322 "WARNING: RPC response ID mismatch for method {}. Expected {:?}, got {:?}",
323 method, request_id, resp_id
324 );
325 }
326 }
327
328 Ok(json_response)
329 }
330
331 pub async fn send_message(&self, params: MessageSendParams) -> A2AResult<SendMessageResponse> {
333 match self.post_rpc_request("message/send", params).await? {
334 JsonRpcResponse::Success { result, .. } => Ok(result),
335 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
336 message: format!("Remote agent error: {}", err.error.message),
337 code: Some(err.error.code),
338 }),
339 }
340 }
341
342 pub async fn send_streaming_message(
346 &self,
347 params: MessageSendParams,
348 ) -> A2AResult<Pin<Box<dyn Stream<Item = A2AResult<SendStreamingMessageResult>> + Send>>> {
349 if !self.agent_card.capabilities.streaming.unwrap_or(false) {
351 return Err(A2AError::InvalidParameter {
352 message: "Agent does not support streaming (capabilities.streaming is not true)"
353 .to_string(),
354 });
355 }
356
357 let request_id = self.next_request_id();
358 let rpc_request = JsonRpcRequest {
359 jsonrpc: JSONRPC_VERSION.to_string(),
360 method: "message/stream".to_string(),
361 params,
362 id: request_id.clone(),
363 };
364
365 let mut req = self
366 .client
367 .post(&self.service_endpoint_url)
368 .header("Content-Type", "application/json")
369 .header("Accept", "text/event-stream")
370 .json(&rpc_request);
371
372 if let Some(token) = &self.auth_token {
373 req = req.bearer_auth(token);
374 }
375
376 let response = req.send().await.map_err(|e| A2AError::NetworkError {
377 message: format!("Failed to send streaming message request: {}", e),
378 })?;
379
380 if !response.status().is_success() {
381 let status = response.status();
382 let error_text = response.text().await.unwrap_or_default();
383 return Err(A2AError::NetworkError {
384 message: format!("HTTP error {}: {}", status, error_text),
385 });
386 }
387
388 let content_type = response
390 .headers()
391 .get("Content-Type")
392 .and_then(|v| v.to_str().ok())
393 .unwrap_or("");
394
395 if !content_type.starts_with("text/event-stream") {
396 return Err(A2AError::NetworkError {
397 message: format!(
398 "Invalid response Content-Type for SSE stream. Expected 'text/event-stream', got '{}'",
399 content_type
400 ),
401 });
402 }
403
404 Ok(Box::pin(Self::parse_sse_stream(
406 response.bytes_stream(),
407 request_id,
408 )))
409 }
410
411 fn parse_sse_stream(
413 byte_stream: impl Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
414 _original_request_id: JSONRPCId,
415 ) -> impl Stream<Item = A2AResult<SendStreamingMessageResult>> + Send {
416 use futures_core::stream::Stream;
417 use std::pin::Pin;
418 use std::task::{Context, Poll};
419
420 struct SseParser {
421 inner: Pin<Box<dyn Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send>>,
422 buffer: String,
423 event_data_buffer: String,
424 pending_results: Vec<A2AResult<SendStreamingMessageResult>>,
425 }
426
427 impl SseParser {
428 fn new(
429 inner: impl Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
430 ) -> Self {
431 Self {
432 inner: Box::pin(inner),
433 buffer: String::new(),
434 event_data_buffer: String::new(),
435 pending_results: Vec::new(),
436 }
437 }
438
439 fn process_chunk(
440 &mut self,
441 chunk: bytes::Bytes,
442 ) -> Vec<A2AResult<SendStreamingMessageResult>> {
443 self.buffer.push_str(&String::from_utf8_lossy(&chunk));
445
446 let mut results = Vec::new();
447
448 while let Some(newline_pos) = self.buffer.find('\n') {
450 let line = self.buffer[..newline_pos]
451 .trim_end_matches('\r')
452 .to_string();
453 self.buffer = self.buffer[newline_pos + 1..].to_string();
454
455 if line.is_empty() {
456 if !self.event_data_buffer.is_empty() {
458 match A2AClient::process_sse_event(&self.event_data_buffer) {
459 Ok(result) => results.push(Ok(result)),
460 Err(e) => results.push(Err(e)),
461 }
462 self.event_data_buffer.clear();
463 }
464 } else if let Some(data) = line.strip_prefix("data:") {
465 if !self.event_data_buffer.is_empty() {
467 self.event_data_buffer.push('\n');
468 }
469 self.event_data_buffer.push_str(data.trim_start());
470 } else if line.starts_with(':') {
471 }
473 }
475
476 results
477 }
478 }
479
480 impl Stream for SseParser {
481 type Item = A2AResult<SendStreamingMessageResult>;
482
483 fn poll_next(
484 mut self: Pin<&mut Self>,
485 cx: &mut Context<'_>,
486 ) -> Poll<Option<Self::Item>> {
487 if let Some(result) = self.pending_results.pop() {
489 return Poll::Ready(Some(result));
490 }
491
492 match self.inner.as_mut().poll_next(cx) {
494 Poll::Ready(Some(Ok(chunk))) => {
495 let mut results = self.process_chunk(chunk);
497
498 if results.is_empty() {
499 cx.waker().wake_by_ref();
501 Poll::Pending
502 } else {
503 results.reverse();
505 self.pending_results = results;
506
507 Poll::Ready(self.pending_results.pop())
509 }
510 }
511 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(A2AError::NetworkError {
512 message: format!("Stream error: {}", e),
513 }))),
514 Poll::Ready(None) => Poll::Ready(None),
515 Poll::Pending => Poll::Pending,
516 }
517 }
518 }
519
520 SseParser::new(byte_stream)
521 }
522
523 fn process_sse_event(json_data: &str) -> A2AResult<SendStreamingMessageResult> {
525 if json_data.trim().is_empty() {
526 return Err(A2AError::SerializationError {
527 message: "Empty SSE event data".to_string(),
528 });
529 }
530
531 let json_response: JsonRpcResponse<SendStreamingMessageResult> =
533 serde_json::from_str(json_data).map_err(|e| A2AError::SerializationError {
534 message: format!("Failed to parse SSE event data: {}", e),
535 })?;
536
537 match json_response {
538 JsonRpcResponse::Success { result, .. } => Ok(result),
539 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
540 message: format!("SSE event contained an error: {}", err.error.message),
541 code: Some(err.error.code),
542 }),
543 }
544 }
545
546 pub async fn get_task(&self, params: TaskQueryParams) -> A2AResult<Task> {
548 match self.post_rpc_request("tasks/get", params).await? {
549 JsonRpcResponse::Success { result, .. } => Ok(result),
550 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
551 message: format!("Remote agent error: {}", err.error.message),
552 code: Some(err.error.code),
553 }),
554 }
555 }
556
557 pub async fn cancel_task(&self, params: TaskIdParams) -> A2AResult<Task> {
559 match self.post_rpc_request("tasks/cancel", params).await? {
560 JsonRpcResponse::Success { result, .. } => Ok(result),
561 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
562 message: format!("Remote agent error: {}", err.error.message),
563 code: Some(err.error.code),
564 }),
565 }
566 }
567
568 pub async fn resubscribe_task(
572 &self,
573 params: TaskIdParams,
574 ) -> A2AResult<Pin<Box<dyn Stream<Item = A2AResult<SendStreamingMessageResult>> + Send>>> {
575 if !self.agent_card.capabilities.streaming.unwrap_or(false) {
577 return Err(A2AError::InvalidParameter {
578 message: "Agent does not support streaming (required for tasks/resubscribe)"
579 .to_string(),
580 });
581 }
582
583 let request_id = self.next_request_id();
584 let rpc_request = JsonRpcRequest {
585 jsonrpc: JSONRPC_VERSION.to_string(),
586 method: "tasks/resubscribe".to_string(),
587 params,
588 id: request_id.clone(),
589 };
590
591 let mut req = self
592 .client
593 .post(&self.service_endpoint_url)
594 .header("Content-Type", "application/json")
595 .header("Accept", "text/event-stream")
596 .json(&rpc_request);
597
598 if let Some(token) = &self.auth_token {
599 req = req.bearer_auth(token);
600 }
601
602 let response = req.send().await.map_err(|e| A2AError::NetworkError {
603 message: format!("Failed to send resubscribe request: {}", e),
604 })?;
605
606 if !response.status().is_success() {
607 let status = response.status();
608 let error_text = response.text().await.unwrap_or_default();
609 return Err(A2AError::NetworkError {
610 message: format!("HTTP error {}: {}", status, error_text),
611 });
612 }
613
614 let content_type = response
616 .headers()
617 .get("Content-Type")
618 .and_then(|v| v.to_str().ok())
619 .unwrap_or("");
620
621 if !content_type.starts_with("text/event-stream") {
622 return Err(A2AError::NetworkError {
623 message: format!(
624 "Invalid response Content-Type for SSE stream on resubscribe. Expected 'text/event-stream', got '{}'",
625 content_type
626 ),
627 });
628 }
629
630 Ok(Box::pin(Self::parse_sse_stream(
631 response.bytes_stream(),
632 request_id,
633 )))
634 }
635
636 pub async fn set_task_push_notification_config(
638 &self,
639 params: TaskPushNotificationConfig,
640 ) -> A2AResult<TaskPushNotificationConfig> {
641 if !self
643 .agent_card
644 .capabilities
645 .push_notifications
646 .unwrap_or(false)
647 {
648 return Err(A2AError::InvalidParameter {
649 message: "Agent does not support push notifications (capabilities.pushNotifications is not true)"
650 .to_string(),
651 });
652 }
653
654 match self
655 .post_rpc_request("tasks/pushNotificationConfig/set", params)
656 .await?
657 {
658 JsonRpcResponse::Success { result, .. } => Ok(result),
659 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
660 message: format!("Remote agent error: {}", err.error.message),
661 code: Some(err.error.code),
662 }),
663 }
664 }
665
666 pub async fn get_task_push_notification_config(
668 &self,
669 params: TaskIdParams,
670 ) -> A2AResult<TaskPushNotificationConfig> {
671 match self
672 .post_rpc_request("tasks/pushNotificationConfig/get", params)
673 .await?
674 {
675 JsonRpcResponse::Success { result, .. } => Ok(result),
676 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
677 message: format!("Remote agent error: {}", err.error.message),
678 code: Some(err.error.code),
679 }),
680 }
681 }
682
683 pub async fn list_task_push_notification_config(
685 &self,
686 params: ListTaskPushNotificationConfigParams,
687 ) -> A2AResult<Vec<TaskPushNotificationConfig>> {
688 match self
689 .post_rpc_request("tasks/pushNotificationConfig/list", params)
690 .await?
691 {
692 JsonRpcResponse::Success { result, .. } => Ok(result),
693 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
694 message: format!("Remote agent error: {}", err.error.message),
695 code: Some(err.error.code),
696 }),
697 }
698 }
699
700 pub async fn delete_task_push_notification_config(
702 &self,
703 params: DeleteTaskPushNotificationConfigParams,
704 ) -> A2AResult<()> {
705 match self
706 .post_rpc_request::<_, serde_json::Value>("tasks/pushNotificationConfig/delete", params)
707 .await?
708 {
709 JsonRpcResponse::Success { .. } => Ok(()),
710 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
711 message: format!("Remote agent error: {}", err.error.message),
712 code: Some(err.error.code),
713 }),
714 }
715 }
716
717 pub async fn call_extension_method<TParams, TResponse>(
721 &self,
722 method: &str,
723 params: TParams,
724 ) -> A2AResult<TResponse>
725 where
726 TParams: Serialize,
727 TResponse: for<'de> Deserialize<'de>,
728 {
729 match self.post_rpc_request(method, params).await? {
730 JsonRpcResponse::Success { result, .. } => Ok(result),
731 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
732 message: format!("Remote agent error: {}", err.error.message),
733 code: Some(err.error.code),
734 }),
735 }
736 }
737
738 pub async fn list_tasks(&self, context_id: Option<String>) -> A2AResult<Vec<Task>> {
742 #[derive(Serialize)]
743 struct ListTasksParams {
744 #[serde(skip_serializing_if = "Option::is_none")]
745 context_id: Option<String>,
746 }
747
748 match self
749 .post_rpc_request("tasks/list", ListTasksParams { context_id })
750 .await?
751 {
752 JsonRpcResponse::Success { result, .. } => Ok(result),
753 JsonRpcResponse::Error(err) => Err(A2AError::RemoteAgentError {
754 message: format!("Remote agent error: {}", err.error.message),
755 code: Some(err.error.code),
756 }),
757 }
758 }
759}
760
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_types::{AgentCapabilities, TransportProtocol};

    /// Minimal agent card for constructor tests; only `url` varies between cases.
    fn card_with_url(url: &str) -> AgentCard {
        AgentCard {
            name: "Test".to_string(),
            description: "Test".to_string(),
            version: "1.0.0".to_string(),
            protocol_version: "0.3.0".to_string(),
            url: url.to_string(),
            preferred_transport: TransportProtocol::JsonRpc,
            capabilities: AgentCapabilities::default(),
            default_input_modes: vec![],
            default_output_modes: vec![],
            skills: vec![],
            provider: None,
            additional_interfaces: vec![],
            documentation_url: None,
            icon_url: None,
            security: vec![],
            security_schemes: None,
            signatures: vec![],
            supports_authenticated_extended_card: None,
        }
    }

    #[test]
    fn test_client_requires_valid_card_url() {
        let outcome = A2AClient::from_card(card_with_url(""));
        assert!(outcome.is_err());
    }
}