tower_a2a/client/
agent.rs

1//! High-level A2A agent client
2
3use crate::{
4    client::config::ClientConfig,
5    prelude::A2AError,
6    protocol::{A2AOperation, AgentCard, Message, Task, TaskStatus},
7    service::{A2ARequest, A2AResponse, RequestContext},
8};
9use tower_service::Service;
10
11/// High-level A2A client for interacting with agents
12///
13/// This client wraps a Tower service and provides convenient methods for common A2A operations.
14/// The service is generic over any implementation that satisfies the Service trait bounds.
15///
16/// # Example
17///
18/// ```rust,no_run
19/// use tower_a2a::prelude::*;
20///
21/// # async fn example() -> Result<(), A2AError> {
22/// let url = "https://agent.example.com".parse().unwrap();
23/// let mut client = A2AClientBuilder::new(url)
24///     .with_http()
25///     .build()?;
26///
27/// let message = Message::user("Hello, agent!");
28/// let task = client.send_message(message).await?;
29/// println!("Task created: {}", task.id);
30/// # Ok(())
31/// # }
32/// ```
33pub struct AgentClient<S> {
34    service: S,
35    config: ClientConfig,
36}
37
38impl<S> AgentClient<S>
39where
40    S: Service<A2ARequest, Response = A2AResponse, Error = A2AError>,
41{
42    /// Create a new agent client
43    ///
44    /// # Arguments
45    ///
46    /// * `service` - The Tower service that handles requests
47    /// * `config` - Client configuration
48    pub fn new(service: S, config: ClientConfig) -> Self {
49        Self { service, config }
50    }
51
52    /// Get the client configuration
53    pub fn config(&self) -> &ClientConfig {
54        &self.config
55    }
56
57    /// Build a request context from the client configuration
58    fn build_context(&self) -> RequestContext {
59        RequestContext {
60            agent_url: self.config.agent_url.clone(),
61            auth: None, // Set by AuthLayer
62            timeout: Some(self.config.timeout),
63            metadata: Default::default(),
64        }
65    }
66
67    /// Send a message to the agent and get a task
68    ///
69    /// # Arguments
70    ///
71    /// * `message` - The message to send to the agent
72    ///
73    /// # Returns
74    ///
75    /// A task representing the agent's processing of the message
76    ///
77    /// # Errors
78    ///
79    /// Returns an error if the message fails to send or the response is invalid
80    pub async fn send_message(&mut self, message: Message) -> Result<Task, A2AError> {
81        let operation = A2AOperation::SendMessage {
82            message,
83            stream: false,
84            context_id: None,
85            task_id: None,
86        };
87
88        let request = A2ARequest::new(operation, self.build_context());
89        let response = self.service.call(request).await?;
90
91        match response {
92            A2AResponse::Task(task) => Ok(*task),
93            _ => Err(A2AError::Protocol(
94                "Expected task response from send_message".into(),
95            )),
96        }
97    }
98
99    /// Send a message with streaming enabled
100    ///
101    /// Note: Streaming is not yet fully implemented
102    pub async fn send_message_streaming(&mut self, message: Message) -> Result<Task, A2AError> {
103        let operation = A2AOperation::SendMessage {
104            message,
105            stream: true,
106            context_id: None,
107            task_id: None,
108        };
109
110        let request = A2ARequest::new(operation, self.build_context());
111        let response = self.service.call(request).await?;
112
113        match response {
114            A2AResponse::Task(task) => Ok(*task),
115            _ => Err(A2AError::Protocol(
116                "Expected task response from send_message_streaming".into(),
117            )),
118        }
119    }
120
121    /// Send a message in a specific context for multi-turn conversations
122    ///
123    /// # Arguments
124    ///
125    /// * `message` - The message to send
126    /// * `context_id` - The context ID for grouping related messages
127    pub async fn send_message_in_context(
128        &mut self,
129        message: Message,
130        context_id: String,
131    ) -> Result<Task, A2AError> {
132        let operation = A2AOperation::SendMessage {
133            message,
134            stream: false,
135            context_id: Some(context_id),
136            task_id: None,
137        };
138
139        let request = A2ARequest::new(operation, self.build_context());
140        let response = self.service.call(request).await?;
141
142        match response {
143            A2AResponse::Task(task) => Ok(*task),
144            _ => Err(A2AError::Protocol(
145                "Expected task response from send_message_in_context".into(),
146            )),
147        }
148    }
149
150    /// Get a task by ID
151    ///
152    /// # Arguments
153    ///
154    /// * `task_id` - The unique identifier of the task to retrieve
155    ///
156    /// # Returns
157    ///
158    /// The task with the specified ID
159    ///
160    /// # Errors
161    ///
162    /// Returns `A2AError::TaskNotFound` if the task doesn't exist
163    pub async fn get_task(&mut self, task_id: String) -> Result<Task, A2AError> {
164        let operation = A2AOperation::GetTask { task_id };
165
166        let request = A2ARequest::new(operation, self.build_context());
167        let response = self.service.call(request).await?;
168
169        match response {
170            A2AResponse::Task(task) => Ok(*task),
171            _ => Err(A2AError::Protocol(
172                "Expected task response from get_task".into(),
173            )),
174        }
175    }
176
177    /// List tasks with optional filtering
178    ///
179    /// # Arguments
180    ///
181    /// * `status` - Optional filter by task status
182    /// * `limit` - Maximum number of tasks to return (default: 100)
183    ///
184    /// # Returns
185    ///
186    /// A vector of tasks matching the query
187    pub async fn list_tasks(
188        &mut self,
189        status: Option<TaskStatus>,
190        limit: Option<u32>,
191    ) -> Result<Vec<Task>, A2AError> {
192        let operation = A2AOperation::ListTasks {
193            status,
194            limit,
195            offset: None,
196            next_token: None,
197        };
198
199        let request = A2ARequest::new(operation, self.build_context());
200        let response = self.service.call(request).await?;
201
202        match response {
203            A2AResponse::TaskList { tasks, .. } => Ok(tasks),
204            _ => Err(A2AError::Protocol(
205                "Expected task list response from list_tasks".into(),
206            )),
207        }
208    }
209
210    /// List all tasks without filtering
211    pub async fn list_all_tasks(&mut self) -> Result<Vec<Task>, A2AError> {
212        self.list_tasks(None, None).await
213    }
214
215    /// List tasks with a specific status
216    pub async fn list_tasks_by_status(
217        &mut self,
218        status: TaskStatus,
219    ) -> Result<Vec<Task>, A2AError> {
220        self.list_tasks(Some(status), None).await
221    }
222
223    /// Cancel a task by ID
224    ///
225    /// # Arguments
226    ///
227    /// * `task_id` - The unique identifier of the task to cancel
228    ///
229    /// # Returns
230    ///
231    /// The updated task with cancelled status
232    pub async fn cancel_task(&mut self, task_id: String) -> Result<Task, A2AError> {
233        let operation = A2AOperation::CancelTask { task_id };
234
235        let request = A2ARequest::new(operation, self.build_context());
236        let response = self.service.call(request).await?;
237
238        match response {
239            A2AResponse::Task(task) => Ok(*task),
240            _ => Err(A2AError::Protocol(
241                "Expected task response from cancel_task".into(),
242            )),
243        }
244    }
245
246    /// Discover agent capabilities by fetching the Agent Card
247    ///
248    /// This retrieves the agent's metadata from `/.well-known/agent-card.json`
249    ///
250    /// # Returns
251    ///
252    /// The agent's capability card
253    pub async fn discover(&mut self) -> Result<AgentCard, A2AError> {
254        let operation = A2AOperation::DiscoverAgent;
255
256        let request = A2ARequest::new(operation, self.build_context());
257        let response = self.service.call(request).await?;
258
259        match response {
260            A2AResponse::AgentCard(card) => Ok(*card),
261            _ => Err(A2AError::Protocol(
262                "Expected agent card response from discover".into(),
263            )),
264        }
265    }
266
267    /// Poll a task until it reaches a terminal state
268    ///
269    /// This is a convenience method that repeatedly calls get_task until
270    /// the task is completed, failed, cancelled, or rejected.
271    ///
272    /// # Arguments
273    ///
274    /// * `task_id` - The task ID to poll
275    /// * `poll_interval` - How often to poll (in milliseconds)
276    /// * `max_attempts` - Maximum number of polling attempts (0 = unlimited)
277    ///
278    /// # Returns
279    ///
280    /// The final task state
281    pub async fn poll_until_complete(
282        &mut self,
283        task_id: String,
284        poll_interval_ms: u64,
285        max_attempts: usize,
286    ) -> Result<Task, A2AError> {
287        let mut attempts = 0;
288
289        loop {
290            let task = self.get_task(task_id.clone()).await?;
291
292            if task.is_terminal() {
293                return Ok(task);
294            }
295
296            attempts += 1;
297            if max_attempts > 0 && attempts >= max_attempts {
298                return Err(A2AError::Timeout);
299            }
300
301            tokio::time::sleep(tokio::time::Duration::from_millis(poll_interval_ms)).await;
302        }
303    }
304}
305
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use bytes::Bytes;

    use super::*;
    use crate::{
        codec::JsonCodec,
        protocol::message::Message,
        service::A2AProtocolService,
        transport::{mock::MockTransport, TransportResponse},
    };

    /// Agent URL shared by every client built in these tests.
    const AGENT_URL: &str = "https://example.com";

    /// `send_message` surfaces the task returned by the mocked agent.
    #[tokio::test]
    async fn test_send_message() {
        let transport = MockTransport::new(|_req| {
            let reply = Task::new("task-123", Message::user("Test"));
            let payload = serde_json::to_vec(&reply).expect("task serializes to JSON");
            TransportResponse::new(200).body(Bytes::from(payload))
        });
        let service = A2AProtocolService::new(transport, Arc::new(JsonCodec::new()));
        let mut client = AgentClient::new(service, ClientConfig::new(AGENT_URL));

        let created = client
            .send_message(Message::user("Hello"))
            .await
            .expect("send_message succeeds");

        assert_eq!(created.id, "task-123");
    }

    /// `get_task` decodes the task body returned by the mocked agent.
    #[tokio::test]
    async fn test_get_task() {
        let transport = MockTransport::new(|_req| {
            let reply = Task::new("task-456", Message::user("Test"));
            let payload = serde_json::to_vec(&reply).expect("task serializes to JSON");
            TransportResponse::new(200).body(Bytes::from(payload))
        });
        let service = A2AProtocolService::new(transport, Arc::new(JsonCodec::new()));
        let mut client = AgentClient::new(service, ClientConfig::new(AGENT_URL));

        let fetched = client
            .get_task("task-456".to_string())
            .await
            .expect("get_task succeeds");

        assert_eq!(fetched.id, "task-456");
    }

    /// `discover` decodes the agent card returned by the mocked agent.
    #[tokio::test]
    async fn test_discover() {
        use crate::protocol::agent::{AgentCapabilities, AgentCard};

        let transport = MockTransport::new(|_req| {
            let reply = AgentCard::new("Test Agent", "A test agent", AgentCapabilities::default());
            let payload = serde_json::to_vec(&reply).expect("card serializes to JSON");
            TransportResponse::new(200).body(Bytes::from(payload))
        });
        let service = A2AProtocolService::new(transport, Arc::new(JsonCodec::new()));
        let mut client = AgentClient::new(service, ClientConfig::new(AGENT_URL));

        let discovered = client.discover().await.expect("discover succeeds");

        assert_eq!(discovered.name, "Test Agent");
    }
}