1use futures::StreamExt;
7use reqwest::Client;
8use serde::Serialize;
9use url::Url;
10
11use crate::agent_card::AgentCard;
12use crate::error::{A2AError, A2AResult};
13use crate::message::Message;
14use crate::notification::PushNotificationConfig;
15use crate::task::{Task, TaskQueryParams};
16use crate::transport::jsonrpc::{self, JsonRpcRequest, JsonRpcResponse, A2A_MEDIA_TYPE};
17use crate::transport::sse::TaskEventStream;
18
19#[derive(Debug, Clone)]
21pub struct A2AClient {
22 base_url: Url,
24
25 agent_card: Option<AgentCard>,
27
28 http: Client,
30
31 auth_token: Option<String>,
33}
34
35impl A2AClient {
36 pub fn new(base_url: &str) -> Self {
38 Self {
39 base_url: Url::parse(base_url).expect("Invalid base URL"),
40 agent_card: None,
41 http: Client::new(),
42 auth_token: None,
43 }
44 }
45
46 pub fn with_http_client(base_url: &str, http: Client) -> Self {
48 Self {
49 base_url: Url::parse(base_url).expect("Invalid base URL"),
50 agent_card: None,
51 http,
52 auth_token: None,
53 }
54 }
55
56 pub fn with_auth(mut self, token: impl Into<String>) -> Self {
58 self.auth_token = Some(token.into());
59 self
60 }
61
62 pub async fn discover(&mut self) -> A2AResult<&AgentCard> {
64 let card = AgentCard::discover(self.base_url.as_str()).await?;
65 self.agent_card = Some(card);
66 Ok(self.agent_card.as_ref().unwrap())
67 }
68
69 pub fn agent_card(&self) -> Option<&AgentCard> {
71 self.agent_card.as_ref()
72 }
73
74 pub async fn send_message(&self, request: SendMessageRequest) -> A2AResult<Task> {
78 let params = serde_json::to_value(&request).map_err(A2AError::Serialization)?;
79
80 let rpc_request = JsonRpcRequest::send_message(params);
81 let response = self.send_rpc(rpc_request).await?;
82 let result = response.into_result().map_err(|e| A2AError::JsonRpc {
83 code: e.code,
84 message: e.message,
85 data: e.data,
86 })?;
87
88 let task: Task = serde_json::from_value(result)?;
89 Ok(task)
90 }
91
92 pub async fn send_message_text(&self, text: &str) -> A2AResult<Task> {
94 self.send_message(SendMessageRequest {
95 message: Message::user_text(text),
96 task_id: None,
97 context_id: None,
98 metadata: None,
99 })
100 .await
101 }
102
103 pub async fn continue_task(&self, task_id: &str, text: &str) -> A2AResult<Task> {
105 self.send_message(SendMessageRequest {
106 message: Message::user_text(text),
107 task_id: Some(task_id.to_string()),
108 context_id: None,
109 metadata: None,
110 })
111 .await
112 }
113
114 pub async fn get_task(&self, task_id: &str) -> A2AResult<Task> {
116 let rpc_request = JsonRpcRequest::get_task(task_id);
117 let response = self.send_rpc(rpc_request).await?;
118 let result = response.into_result().map_err(|e| A2AError::JsonRpc {
119 code: e.code,
120 message: e.message,
121 data: e.data,
122 })?;
123
124 let task: Task = serde_json::from_value(result)?;
125 Ok(task)
126 }
127
128 pub async fn list_tasks(&self, params: TaskQueryParams) -> A2AResult<Vec<Task>> {
130 let rpc_params = serde_json::to_value(¶ms)?;
131 let rpc_request = JsonRpcRequest::list_tasks(rpc_params);
132 let response = self.send_rpc(rpc_request).await?;
133 let result = response.into_result().map_err(|e| A2AError::JsonRpc {
134 code: e.code,
135 message: e.message,
136 data: e.data,
137 })?;
138
139 let tasks: Vec<Task> = serde_json::from_value(result)?;
140 Ok(tasks)
141 }
142
143 pub async fn cancel_task(&self, task_id: &str) -> A2AResult<Task> {
145 let rpc_request = JsonRpcRequest::cancel_task(task_id);
146 let response = self.send_rpc(rpc_request).await?;
147 let result = response.into_result().map_err(|e| A2AError::JsonRpc {
148 code: e.code,
149 message: e.message,
150 data: e.data,
151 })?;
152
153 let task: Task = serde_json::from_value(result)?;
154 Ok(task)
155 }
156
157 pub async fn send_streaming_message(
164 &self,
165 request: SendMessageRequest,
166 ) -> A2AResult<TaskEventStream> {
167 let params = serde_json::to_value(&request).map_err(A2AError::Serialization)?;
168
169 let rpc_request = JsonRpcRequest::send_streaming_message(params);
170
171 let mut http_request = self
172 .http
173 .post(self.base_url.as_str())
174 .header("Content-Type", A2A_MEDIA_TYPE)
175 .header("Accept", "text/event-stream")
176 .json(&rpc_request);
177
178 if let Some(ref token) = self.auth_token {
179 http_request = http_request.bearer_auth(token);
180 }
181
182 tracing::debug!(url = %self.base_url, "Sending streaming A2A request");
183
184 let response = http_request.send().await?;
185
186 if !response.status().is_success() {
187 return Err(A2AError::Transport(
188 response.error_for_status().unwrap_err(),
189 ));
190 }
191
192 let byte_stream = response.bytes_stream();
193 let event_stream = Box::pin(
194 byte_stream
195 .map(|chunk| match chunk {
196 Ok(bytes) => {
197 let text = String::from_utf8_lossy(&bytes);
198 let mut events = Vec::new();
200 for line in text.lines() {
201 if let Some(data) = line.strip_prefix("data: ") {
202 if data == "[DONE]" {
203 break;
204 }
205 match crate::transport::sse::parse_sse_event(data) {
206 Ok(event) => events.push(Ok(event)),
207 Err(e) => events.push(Err(e)),
208 }
209 }
210 }
211 futures::stream::iter(events)
212 }
213 Err(e) => futures::stream::iter(vec![Err(A2AError::StreamingError(format!(
214 "Stream read error: {e}"
215 )))]),
216 })
217 .flatten(),
218 );
219
220 Ok(TaskEventStream::new(event_stream))
221 }
222
223 pub async fn send_streaming_text(&self, text: &str) -> A2AResult<TaskEventStream> {
225 self.send_streaming_message(SendMessageRequest {
226 message: Message::user_text(text),
227 task_id: None,
228 context_id: None,
229 metadata: None,
230 })
231 .await
232 }
233
234 pub async fn subscribe_task(&self, task_id: &str) -> A2AResult<TaskEventStream> {
236 let rpc_request = JsonRpcRequest::new(
237 jsonrpc::methods::SUBSCRIBE_TASK,
238 Some(serde_json::json!({ "taskId": task_id })),
239 );
240
241 let mut http_request = self
242 .http
243 .post(self.base_url.as_str())
244 .header("Content-Type", A2A_MEDIA_TYPE)
245 .header("Accept", "text/event-stream")
246 .json(&rpc_request);
247
248 if let Some(ref token) = self.auth_token {
249 http_request = http_request.bearer_auth(token);
250 }
251
252 let response = http_request.send().await?;
253
254 if !response.status().is_success() {
255 return Err(A2AError::Transport(
256 response.error_for_status().unwrap_err(),
257 ));
258 }
259
260 let byte_stream = response.bytes_stream();
261 let event_stream = Box::pin(
262 byte_stream
263 .map(|chunk| match chunk {
264 Ok(bytes) => {
265 let text = String::from_utf8_lossy(&bytes);
266 let mut events = Vec::new();
267 for line in text.lines() {
268 if let Some(data) = line.strip_prefix("data: ") {
269 if data == "[DONE]" {
270 break;
271 }
272 match crate::transport::sse::parse_sse_event(data) {
273 Ok(event) => events.push(Ok(event)),
274 Err(e) => events.push(Err(e)),
275 }
276 }
277 }
278 futures::stream::iter(events)
279 }
280 Err(e) => futures::stream::iter(vec![Err(A2AError::StreamingError(format!(
281 "Stream read error: {e}"
282 )))]),
283 })
284 .flatten(),
285 );
286
287 Ok(TaskEventStream::new(event_stream))
288 }
289
290 pub async fn create_push_notification(
294 &self,
295 config: &PushNotificationConfig,
296 ) -> A2AResult<PushNotificationConfig> {
297 let params = serde_json::to_value(config)?;
298 let rpc_request =
299 JsonRpcRequest::new(jsonrpc::methods::CREATE_PUSH_NOTIFICATION, Some(params));
300 let response = self.send_rpc(rpc_request).await?;
301 let result = response.into_result().map_err(|e| A2AError::JsonRpc {
302 code: e.code,
303 message: e.message,
304 data: e.data,
305 })?;
306 Ok(serde_json::from_value(result)?)
307 }
308
309 pub async fn get_push_notification(
311 &self,
312 config_id: &str,
313 task_id: &str,
314 ) -> A2AResult<PushNotificationConfig> {
315 let rpc_request = JsonRpcRequest::new(
316 jsonrpc::methods::GET_PUSH_NOTIFICATION,
317 Some(serde_json::json!({ "configId": config_id, "taskId": task_id })),
318 );
319 let response = self.send_rpc(rpc_request).await?;
320 let result = response.into_result().map_err(|e| A2AError::JsonRpc {
321 code: e.code,
322 message: e.message,
323 data: e.data,
324 })?;
325 Ok(serde_json::from_value(result)?)
326 }
327
328 pub async fn list_push_notifications(
330 &self,
331 task_id: &str,
332 ) -> A2AResult<Vec<PushNotificationConfig>> {
333 let rpc_request = JsonRpcRequest::new(
334 jsonrpc::methods::LIST_PUSH_NOTIFICATIONS,
335 Some(serde_json::json!({ "taskId": task_id })),
336 );
337 let response = self.send_rpc(rpc_request).await?;
338 let result = response.into_result().map_err(|e| A2AError::JsonRpc {
339 code: e.code,
340 message: e.message,
341 data: e.data,
342 })?;
343 Ok(serde_json::from_value(result)?)
344 }
345
346 pub async fn delete_push_notification(&self, config_id: &str, task_id: &str) -> A2AResult<()> {
348 let rpc_request = JsonRpcRequest::new(
349 jsonrpc::methods::DELETE_PUSH_NOTIFICATION,
350 Some(serde_json::json!({ "configId": config_id, "taskId": task_id })),
351 );
352 let response = self.send_rpc(rpc_request).await?;
353 response.into_result().map_err(|e| A2AError::JsonRpc {
354 code: e.code,
355 message: e.message,
356 data: e.data,
357 })?;
358 Ok(())
359 }
360
361 pub async fn get_extended_agent_card(&self) -> A2AResult<AgentCard> {
363 let rpc_request = JsonRpcRequest::new(jsonrpc::methods::GET_EXTENDED_AGENT_CARD, None);
364 let response = self.send_rpc(rpc_request).await?;
365 let result = response.into_result().map_err(|e| A2AError::JsonRpc {
366 code: e.code,
367 message: e.message,
368 data: e.data,
369 })?;
370 Ok(serde_json::from_value(result)?)
371 }
372
373 async fn send_rpc(&self, request: JsonRpcRequest) -> A2AResult<JsonRpcResponse> {
377 let mut http_request = self
378 .http
379 .post(self.base_url.as_str())
380 .header("Content-Type", A2A_MEDIA_TYPE)
381 .header("Accept", A2A_MEDIA_TYPE)
382 .json(&request);
383
384 if let Some(ref token) = self.auth_token {
385 http_request = http_request.bearer_auth(token);
386 }
387
388 tracing::debug!(
389 method = %request.method,
390 url = %self.base_url,
391 "Sending A2A request"
392 );
393
394 let response = http_request.send().await?;
395
396 if !response.status().is_success() {
397 return Err(A2AError::Transport(
398 response.error_for_status().unwrap_err(),
399 ));
400 }
401
402 let rpc_response: JsonRpcResponse = response.json().await?;
403 Ok(rpc_response)
404 }
405}
406
407#[derive(Debug, Clone, Serialize)]
411#[serde(rename_all = "camelCase")]
412pub struct SendMessageRequest {
413 pub message: Message,
415
416 #[serde(skip_serializing_if = "Option::is_none")]
418 pub task_id: Option<String>,
419
420 #[serde(skip_serializing_if = "Option::is_none")]
422 pub context_id: Option<String>,
423
424 #[serde(skip_serializing_if = "Option::is_none")]
426 pub metadata: Option<serde_json::Value>,
427}
428
429impl Default for SendMessageRequest {
430 fn default() -> Self {
431 Self {
432 message: Message::user(vec![]),
433 task_id: None,
434 context_id: None,
435 metadata: None,
436 }
437 }
438}
439
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the wire format: camelCase keys, and `None` fields omitted.
    #[test]
    fn test_send_message_request_serialization() {
        let req = SendMessageRequest {
            message: Message::user_text("Hello"),
            task_id: None,
            context_id: Some("session-1".into()),
            metadata: None,
        };

        let json = serde_json::to_value(&req).unwrap();

        assert_eq!(json["contextId"], "session-1");
        // The snake_case name must not leak through `rename_all`.
        assert!(json.get("context_id").is_none());
        // `skip_serializing_if` drops `None` fields entirely.
        assert!(json.get("taskId").is_none());
        assert!(json.get("metadata").is_none());
        assert!(json["message"].to_string().contains("Hello"));
    }
}