1use crate::{
4 client::config::ClientConfig,
5 prelude::A2AError,
6 protocol::{A2AOperation, AgentCard, Message, Task, TaskStatus},
7 service::{A2ARequest, A2AResponse, RequestContext},
8};
9use tower_service::Service;
10
11pub struct AgentClient<S> {
34 service: S,
35 config: ClientConfig,
36}
37
38impl<S> AgentClient<S>
39where
40 S: Service<A2ARequest, Response = A2AResponse, Error = A2AError>,
41{
42 pub fn new(service: S, config: ClientConfig) -> Self {
49 Self { service, config }
50 }
51
52 pub fn config(&self) -> &ClientConfig {
54 &self.config
55 }
56
57 fn build_context(&self) -> RequestContext {
59 RequestContext {
60 agent_url: self.config.agent_url.clone(),
61 auth: None, timeout: Some(self.config.timeout),
63 metadata: Default::default(),
64 }
65 }
66
67 pub async fn send_message(&mut self, message: Message) -> Result<Task, A2AError> {
81 let operation = A2AOperation::SendMessage {
82 message,
83 stream: false,
84 context_id: None,
85 task_id: None,
86 };
87
88 let request = A2ARequest::new(operation, self.build_context());
89 let response = self.service.call(request).await?;
90
91 match response {
92 A2AResponse::Task(task) => Ok(*task),
93 _ => Err(A2AError::Protocol(
94 "Expected task response from send_message".into(),
95 )),
96 }
97 }
98
99 pub async fn send_message_streaming(&mut self, message: Message) -> Result<Task, A2AError> {
103 let operation = A2AOperation::SendMessage {
104 message,
105 stream: true,
106 context_id: None,
107 task_id: None,
108 };
109
110 let request = A2ARequest::new(operation, self.build_context());
111 let response = self.service.call(request).await?;
112
113 match response {
114 A2AResponse::Task(task) => Ok(*task),
115 _ => Err(A2AError::Protocol(
116 "Expected task response from send_message_streaming".into(),
117 )),
118 }
119 }
120
121 pub async fn send_message_in_context(
128 &mut self,
129 message: Message,
130 context_id: String,
131 ) -> Result<Task, A2AError> {
132 let operation = A2AOperation::SendMessage {
133 message,
134 stream: false,
135 context_id: Some(context_id),
136 task_id: None,
137 };
138
139 let request = A2ARequest::new(operation, self.build_context());
140 let response = self.service.call(request).await?;
141
142 match response {
143 A2AResponse::Task(task) => Ok(*task),
144 _ => Err(A2AError::Protocol(
145 "Expected task response from send_message_in_context".into(),
146 )),
147 }
148 }
149
150 pub async fn get_task(&mut self, task_id: String) -> Result<Task, A2AError> {
164 let operation = A2AOperation::GetTask { task_id };
165
166 let request = A2ARequest::new(operation, self.build_context());
167 let response = self.service.call(request).await?;
168
169 match response {
170 A2AResponse::Task(task) => Ok(*task),
171 _ => Err(A2AError::Protocol(
172 "Expected task response from get_task".into(),
173 )),
174 }
175 }
176
177 pub async fn list_tasks(
188 &mut self,
189 status: Option<TaskStatus>,
190 limit: Option<u32>,
191 ) -> Result<Vec<Task>, A2AError> {
192 let operation = A2AOperation::ListTasks {
193 status,
194 limit,
195 offset: None,
196 next_token: None,
197 };
198
199 let request = A2ARequest::new(operation, self.build_context());
200 let response = self.service.call(request).await?;
201
202 match response {
203 A2AResponse::TaskList { tasks, .. } => Ok(tasks),
204 _ => Err(A2AError::Protocol(
205 "Expected task list response from list_tasks".into(),
206 )),
207 }
208 }
209
210 pub async fn list_all_tasks(&mut self) -> Result<Vec<Task>, A2AError> {
212 self.list_tasks(None, None).await
213 }
214
215 pub async fn list_tasks_by_status(
217 &mut self,
218 status: TaskStatus,
219 ) -> Result<Vec<Task>, A2AError> {
220 self.list_tasks(Some(status), None).await
221 }
222
223 pub async fn cancel_task(&mut self, task_id: String) -> Result<Task, A2AError> {
233 let operation = A2AOperation::CancelTask { task_id };
234
235 let request = A2ARequest::new(operation, self.build_context());
236 let response = self.service.call(request).await?;
237
238 match response {
239 A2AResponse::Task(task) => Ok(*task),
240 _ => Err(A2AError::Protocol(
241 "Expected task response from cancel_task".into(),
242 )),
243 }
244 }
245
246 pub async fn discover(&mut self) -> Result<AgentCard, A2AError> {
254 let operation = A2AOperation::DiscoverAgent;
255
256 let request = A2ARequest::new(operation, self.build_context());
257 let response = self.service.call(request).await?;
258
259 match response {
260 A2AResponse::AgentCard(card) => Ok(*card),
261 _ => Err(A2AError::Protocol(
262 "Expected agent card response from discover".into(),
263 )),
264 }
265 }
266
267 pub async fn poll_until_complete(
282 &mut self,
283 task_id: String,
284 poll_interval_ms: u64,
285 max_attempts: usize,
286 ) -> Result<Task, A2AError> {
287 let mut attempts = 0;
288
289 loop {
290 let task = self.get_task(task_id.clone()).await?;
291
292 if task.is_terminal() {
293 return Ok(task);
294 }
295
296 attempts += 1;
297 if max_attempts > 0 && attempts >= max_attempts {
298 return Err(A2AError::Timeout);
299 }
300
301 tokio::time::sleep(tokio::time::Duration::from_millis(poll_interval_ms)).await;
302 }
303 }
304}
305
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use bytes::Bytes;

    use super::*;
    use crate::{
        codec::JsonCodec,
        protocol::message::Message,
        service::A2AProtocolService,
        transport::{mock::MockTransport, TransportResponse},
    };

    /// `send_message` should surface the task the (mocked) agent returns.
    #[tokio::test]
    async fn test_send_message() {
        // The mock ignores the request and always replies with a serialized task.
        let mock = MockTransport::new(|_req| {
            let payload = serde_json::to_vec(&Task::new("task-123", Message::user("Test"))).unwrap();
            TransportResponse::new(200).body(Bytes::from(payload))
        });

        let protocol = A2AProtocolService::new(mock, Arc::new(JsonCodec::new()));
        let mut client = AgentClient::new(protocol, ClientConfig::new("https://example.com"));

        let returned = client.send_message(Message::user("Hello")).await.unwrap();
        assert_eq!(returned.id, "task-123");
    }

    /// `get_task` should decode the task body returned by the agent.
    #[tokio::test]
    async fn test_get_task() {
        let mock = MockTransport::new(|_req| {
            let payload = serde_json::to_vec(&Task::new("task-456", Message::user("Test"))).unwrap();
            TransportResponse::new(200).body(Bytes::from(payload))
        });

        let protocol = A2AProtocolService::new(mock, Arc::new(JsonCodec::new()));
        let mut client = AgentClient::new(protocol, ClientConfig::new("https://example.com"));

        let fetched = client.get_task("task-456".to_string()).await.unwrap();
        assert_eq!(fetched.id, "task-456");
    }

    /// `discover` should decode the agent card returned by the agent.
    #[tokio::test]
    async fn test_discover() {
        use crate::protocol::agent::{AgentCapabilities, AgentCard};

        let mock = MockTransport::new(|_req| {
            let payload = serde_json::to_vec(&AgentCard::new(
                "Test Agent",
                "A test agent",
                AgentCapabilities::default(),
            ))
            .unwrap();
            TransportResponse::new(200).body(Bytes::from(payload))
        });

        let protocol = A2AProtocolService::new(mock, Arc::new(JsonCodec::new()));
        let mut client = AgentClient::new(protocol, ClientConfig::new("https://example.com"));

        let discovered = client.discover().await.unwrap();
        assert_eq!(discovered.name, "Test Agent");
    }
}