Skip to main content

a2a_ao/
client.rs

//! A2A Client — high-level client for interacting with A2A-compatible agents.
//!
//! The client handles agent discovery, message sending, task tracking,
//! streaming, and push notification management.
5
6use futures::StreamExt;
7use reqwest::Client;
8use serde::Serialize;
9use url::Url;
10
11use crate::agent_card::AgentCard;
12use crate::error::{A2AError, A2AResult};
13use crate::message::Message;
14use crate::notification::PushNotificationConfig;
15use crate::task::{Task, TaskQueryParams};
16use crate::transport::jsonrpc::{self, JsonRpcRequest, JsonRpcResponse, A2A_MEDIA_TYPE};
17use crate::transport::sse::TaskEventStream;
18
19/// High-level A2A client for communicating with remote agents.
20#[derive(Debug, Clone)]
21pub struct A2AClient {
22    /// Base URL of the remote agent.
23    base_url: Url,
24
25    /// The discovered agent card (populated after discover()).
26    agent_card: Option<AgentCard>,
27
28    /// HTTP client.
29    http: Client,
30
31    /// Optional bearer token for authentication.
32    auth_token: Option<String>,
33}
34
35impl A2AClient {
36    /// Create a new A2A client for a remote agent.
37    pub fn new(base_url: &str) -> Self {
38        Self {
39            base_url: Url::parse(base_url).expect("Invalid base URL"),
40            agent_card: None,
41            http: Client::new(),
42            auth_token: None,
43        }
44    }
45
46    /// Create a client with a custom HTTP client.
47    pub fn with_http_client(base_url: &str, http: Client) -> Self {
48        Self {
49            base_url: Url::parse(base_url).expect("Invalid base URL"),
50            agent_card: None,
51            http,
52            auth_token: None,
53        }
54    }
55
56    /// Set authentication token.
57    pub fn with_auth(mut self, token: impl Into<String>) -> Self {
58        self.auth_token = Some(token.into());
59        self
60    }
61
62    /// Discover the remote agent's capabilities by fetching its Agent Card.
63    pub async fn discover(&mut self) -> A2AResult<&AgentCard> {
64        let card = AgentCard::discover(self.base_url.as_str()).await?;
65        self.agent_card = Some(card);
66        Ok(self.agent_card.as_ref().unwrap())
67    }
68
69    /// Get the cached agent card (call discover() first).
70    pub fn agent_card(&self) -> Option<&AgentCard> {
71        self.agent_card.as_ref()
72    }
73
74    // ── Core Operations ──────────────────────────────────────
75
76    /// Send a message to the remote agent, creating or continuing a task.
77    pub async fn send_message(&self, request: SendMessageRequest) -> A2AResult<Task> {
78        let params = serde_json::to_value(&request).map_err(A2AError::Serialization)?;
79
80        let rpc_request = JsonRpcRequest::send_message(params);
81        let response = self.send_rpc(rpc_request).await?;
82        let result = response.into_result().map_err(|e| A2AError::JsonRpc {
83            code: e.code,
84            message: e.message,
85            data: e.data,
86        })?;
87
88        let task: Task = serde_json::from_value(result)?;
89        Ok(task)
90    }
91
92    /// Convenience: send a simple text message.
93    pub async fn send_message_text(&self, text: &str) -> A2AResult<Task> {
94        self.send_message(SendMessageRequest {
95            message: Message::user_text(text),
96            task_id: None,
97            context_id: None,
98            metadata: None,
99        })
100        .await
101    }
102
103    /// Continue an existing task with additional input.
104    pub async fn continue_task(&self, task_id: &str, text: &str) -> A2AResult<Task> {
105        self.send_message(SendMessageRequest {
106            message: Message::user_text(text),
107            task_id: Some(task_id.to_string()),
108            context_id: None,
109            metadata: None,
110        })
111        .await
112    }
113
114    /// Get a task by its ID.
115    pub async fn get_task(&self, task_id: &str) -> A2AResult<Task> {
116        let rpc_request = JsonRpcRequest::get_task(task_id);
117        let response = self.send_rpc(rpc_request).await?;
118        let result = response.into_result().map_err(|e| A2AError::JsonRpc {
119            code: e.code,
120            message: e.message,
121            data: e.data,
122        })?;
123
124        let task: Task = serde_json::from_value(result)?;
125        Ok(task)
126    }
127
128    /// List tasks matching the given query parameters.
129    pub async fn list_tasks(&self, params: TaskQueryParams) -> A2AResult<Vec<Task>> {
130        let rpc_params = serde_json::to_value(&params)?;
131        let rpc_request = JsonRpcRequest::list_tasks(rpc_params);
132        let response = self.send_rpc(rpc_request).await?;
133        let result = response.into_result().map_err(|e| A2AError::JsonRpc {
134            code: e.code,
135            message: e.message,
136            data: e.data,
137        })?;
138
139        let tasks: Vec<Task> = serde_json::from_value(result)?;
140        Ok(tasks)
141    }
142
143    /// Cancel a running task.
144    pub async fn cancel_task(&self, task_id: &str) -> A2AResult<Task> {
145        let rpc_request = JsonRpcRequest::cancel_task(task_id);
146        let response = self.send_rpc(rpc_request).await?;
147        let result = response.into_result().map_err(|e| A2AError::JsonRpc {
148            code: e.code,
149            message: e.message,
150            data: e.data,
151        })?;
152
153        let task: Task = serde_json::from_value(result)?;
154        Ok(task)
155    }
156
157    // ── Streaming Operations ─────────────────────────────────
158
159    /// Send a streaming message — returns an SSE stream of task events.
160    ///
161    /// The remote agent will stream back `TaskEvent`s as it processes
162    /// the message (state changes, partial artifacts, etc.).
163    pub async fn send_streaming_message(
164        &self,
165        request: SendMessageRequest,
166    ) -> A2AResult<TaskEventStream> {
167        let params = serde_json::to_value(&request).map_err(A2AError::Serialization)?;
168
169        let rpc_request = JsonRpcRequest::send_streaming_message(params);
170
171        let mut http_request = self
172            .http
173            .post(self.base_url.as_str())
174            .header("Content-Type", A2A_MEDIA_TYPE)
175            .header("Accept", "text/event-stream")
176            .json(&rpc_request);
177
178        if let Some(ref token) = self.auth_token {
179            http_request = http_request.bearer_auth(token);
180        }
181
182        tracing::debug!(url = %self.base_url, "Sending streaming A2A request");
183
184        let response = http_request.send().await?;
185
186        if !response.status().is_success() {
187            return Err(A2AError::Transport(
188                response.error_for_status().unwrap_err(),
189            ));
190        }
191
192        let byte_stream = response.bytes_stream();
193        let event_stream = Box::pin(
194            byte_stream
195                .map(|chunk| match chunk {
196                    Ok(bytes) => {
197                        let text = String::from_utf8_lossy(&bytes);
198                        // Parse SSE data lines
199                        let mut events = Vec::new();
200                        for line in text.lines() {
201                            if let Some(data) = line.strip_prefix("data: ") {
202                                if data == "[DONE]" {
203                                    break;
204                                }
205                                match crate::transport::sse::parse_sse_event(data) {
206                                    Ok(event) => events.push(Ok(event)),
207                                    Err(e) => events.push(Err(e)),
208                                }
209                            }
210                        }
211                        futures::stream::iter(events)
212                    }
213                    Err(e) => futures::stream::iter(vec![Err(A2AError::StreamingError(format!(
214                        "Stream read error: {e}"
215                    )))]),
216                })
217                .flatten(),
218        );
219
220        Ok(TaskEventStream::new(event_stream))
221    }
222
223    /// Convenience: send a streaming text message.
224    pub async fn send_streaming_text(&self, text: &str) -> A2AResult<TaskEventStream> {
225        self.send_streaming_message(SendMessageRequest {
226            message: Message::user_text(text),
227            task_id: None,
228            context_id: None,
229            metadata: None,
230        })
231        .await
232    }
233
234    /// Subscribe to an existing task's updates via SSE.
235    pub async fn subscribe_task(&self, task_id: &str) -> A2AResult<TaskEventStream> {
236        let rpc_request = JsonRpcRequest::new(
237            jsonrpc::methods::SUBSCRIBE_TASK,
238            Some(serde_json::json!({ "taskId": task_id })),
239        );
240
241        let mut http_request = self
242            .http
243            .post(self.base_url.as_str())
244            .header("Content-Type", A2A_MEDIA_TYPE)
245            .header("Accept", "text/event-stream")
246            .json(&rpc_request);
247
248        if let Some(ref token) = self.auth_token {
249            http_request = http_request.bearer_auth(token);
250        }
251
252        let response = http_request.send().await?;
253
254        if !response.status().is_success() {
255            return Err(A2AError::Transport(
256                response.error_for_status().unwrap_err(),
257            ));
258        }
259
260        let byte_stream = response.bytes_stream();
261        let event_stream = Box::pin(
262            byte_stream
263                .map(|chunk| match chunk {
264                    Ok(bytes) => {
265                        let text = String::from_utf8_lossy(&bytes);
266                        let mut events = Vec::new();
267                        for line in text.lines() {
268                            if let Some(data) = line.strip_prefix("data: ") {
269                                if data == "[DONE]" {
270                                    break;
271                                }
272                                match crate::transport::sse::parse_sse_event(data) {
273                                    Ok(event) => events.push(Ok(event)),
274                                    Err(e) => events.push(Err(e)),
275                                }
276                            }
277                        }
278                        futures::stream::iter(events)
279                    }
280                    Err(e) => futures::stream::iter(vec![Err(A2AError::StreamingError(format!(
281                        "Stream read error: {e}"
282                    )))]),
283                })
284                .flatten(),
285        );
286
287        Ok(TaskEventStream::new(event_stream))
288    }
289
290    // ── Push Notification Operations ─────────────────────────
291
292    /// Create a push notification configuration for a task.
293    pub async fn create_push_notification(
294        &self,
295        config: &PushNotificationConfig,
296    ) -> A2AResult<PushNotificationConfig> {
297        let params = serde_json::to_value(config)?;
298        let rpc_request =
299            JsonRpcRequest::new(jsonrpc::methods::CREATE_PUSH_NOTIFICATION, Some(params));
300        let response = self.send_rpc(rpc_request).await?;
301        let result = response.into_result().map_err(|e| A2AError::JsonRpc {
302            code: e.code,
303            message: e.message,
304            data: e.data,
305        })?;
306        Ok(serde_json::from_value(result)?)
307    }
308
309    /// Get a push notification configuration by ID.
310    pub async fn get_push_notification(
311        &self,
312        config_id: &str,
313        task_id: &str,
314    ) -> A2AResult<PushNotificationConfig> {
315        let rpc_request = JsonRpcRequest::new(
316            jsonrpc::methods::GET_PUSH_NOTIFICATION,
317            Some(serde_json::json!({ "configId": config_id, "taskId": task_id })),
318        );
319        let response = self.send_rpc(rpc_request).await?;
320        let result = response.into_result().map_err(|e| A2AError::JsonRpc {
321            code: e.code,
322            message: e.message,
323            data: e.data,
324        })?;
325        Ok(serde_json::from_value(result)?)
326    }
327
328    /// List push notification configurations for a task.
329    pub async fn list_push_notifications(
330        &self,
331        task_id: &str,
332    ) -> A2AResult<Vec<PushNotificationConfig>> {
333        let rpc_request = JsonRpcRequest::new(
334            jsonrpc::methods::LIST_PUSH_NOTIFICATIONS,
335            Some(serde_json::json!({ "taskId": task_id })),
336        );
337        let response = self.send_rpc(rpc_request).await?;
338        let result = response.into_result().map_err(|e| A2AError::JsonRpc {
339            code: e.code,
340            message: e.message,
341            data: e.data,
342        })?;
343        Ok(serde_json::from_value(result)?)
344    }
345
346    /// Delete a push notification configuration.
347    pub async fn delete_push_notification(&self, config_id: &str, task_id: &str) -> A2AResult<()> {
348        let rpc_request = JsonRpcRequest::new(
349            jsonrpc::methods::DELETE_PUSH_NOTIFICATION,
350            Some(serde_json::json!({ "configId": config_id, "taskId": task_id })),
351        );
352        let response = self.send_rpc(rpc_request).await?;
353        response.into_result().map_err(|e| A2AError::JsonRpc {
354            code: e.code,
355            message: e.message,
356            data: e.data,
357        })?;
358        Ok(())
359    }
360
361    /// Get the extended agent card (post-authentication).
362    pub async fn get_extended_agent_card(&self) -> A2AResult<AgentCard> {
363        let rpc_request = JsonRpcRequest::new(jsonrpc::methods::GET_EXTENDED_AGENT_CARD, None);
364        let response = self.send_rpc(rpc_request).await?;
365        let result = response.into_result().map_err(|e| A2AError::JsonRpc {
366            code: e.code,
367            message: e.message,
368            data: e.data,
369        })?;
370        Ok(serde_json::from_value(result)?)
371    }
372
373    // ── Internal Transport ───────────────────────────────────
374
375    /// Send a JSON-RPC request to the remote agent.
376    async fn send_rpc(&self, request: JsonRpcRequest) -> A2AResult<JsonRpcResponse> {
377        let mut http_request = self
378            .http
379            .post(self.base_url.as_str())
380            .header("Content-Type", A2A_MEDIA_TYPE)
381            .header("Accept", A2A_MEDIA_TYPE)
382            .json(&request);
383
384        if let Some(ref token) = self.auth_token {
385            http_request = http_request.bearer_auth(token);
386        }
387
388        tracing::debug!(
389            method = %request.method,
390            url = %self.base_url,
391            "Sending A2A request"
392        );
393
394        let response = http_request.send().await?;
395
396        if !response.status().is_success() {
397            return Err(A2AError::Transport(
398                response.error_for_status().unwrap_err(),
399            ));
400        }
401
402        let rpc_response: JsonRpcResponse = response.json().await?;
403        Ok(rpc_response)
404    }
405}
406
407// ── Request Types ────────────────────────────────────────────
408
409/// Request to send a message to the remote agent.
410#[derive(Debug, Clone, Serialize)]
411#[serde(rename_all = "camelCase")]
412pub struct SendMessageRequest {
413    /// The message to send.
414    pub message: Message,
415
416    /// Existing task ID to continue (optional — omit to create a new task).
417    #[serde(skip_serializing_if = "Option::is_none")]
418    pub task_id: Option<String>,
419
420    /// Context ID to group related tasks.
421    #[serde(skip_serializing_if = "Option::is_none")]
422    pub context_id: Option<String>,
423
424    /// Optional metadata.
425    #[serde(skip_serializing_if = "Option::is_none")]
426    pub metadata: Option<serde_json::Value>,
427}
428
429impl Default for SendMessageRequest {
430    fn default() -> Self {
431        Self {
432            message: Message::user(vec![]),
433            task_id: None,
434            context_id: None,
435            metadata: None,
436        }
437    }
438}
439
#[cfg(test)]
mod tests {
    use super::*;

    /// The serialized request must carry both the message text and the
    /// context ID.
    #[test]
    fn test_send_message_request_serialization() {
        let request = SendMessageRequest {
            message: Message::user_text("Hello"),
            task_id: None,
            context_id: Some(String::from("session-1")),
            metadata: None,
        };

        let rendered = serde_json::to_string_pretty(&request).unwrap();

        for expected in ["session-1", "Hello"] {
            assert!(rendered.contains(expected));
        }
    }
}
457}