1use super::protocol::{
4 AgentCard, JsonRpcRequest, JsonRpcResponse, TaskRequest, TaskResponse, TaskStatusUpdate,
5 methods,
6};
7use anyhow::{Context, Result};
8use std::collections::HashMap;
9use tokio::sync::RwLock;
10
11pub struct A2aClient {
13 http: reqwest::Client,
14 agent_cache: RwLock<HashMap<String, AgentCard>>,
16 agent_tokens: RwLock<HashMap<String, String>>,
18 auth_token: Option<String>,
20}
21
/// Per-call options for [`A2aClient::send_task`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SendTaskOptions {
    /// Upper bound, in seconds, on the `send_task` JSON-RPC round trip.
    /// `None` applies no explicit deadline beyond the client's default.
    pub timeout_secs: Option<u64>,
    /// Whether to subscribe to task updates (meaning inferred from the name).
    ///
    /// NOTE(review): `send_task` does not currently read this flag.
    pub subscribe: bool,
}
29
30impl Default for SendTaskOptions {
31 fn default() -> Self {
32 Self {
33 timeout_secs: Some(30),
34 subscribe: false,
35 }
36 }
37}
38
39impl A2aClient {
40 pub fn new(auth_token: Option<String>) -> Self {
42 Self {
43 http: reqwest::Client::builder()
44 .connect_timeout(std::time::Duration::from_secs(10))
45 .timeout(std::time::Duration::from_secs(30))
46 .build()
47 .unwrap_or_else(|_| reqwest::Client::new()),
48 agent_cache: RwLock::new(HashMap::new()),
49 agent_tokens: RwLock::new(HashMap::new()),
50 auth_token,
51 }
52 }
53
54 pub async fn set_agent_token(&self, base_url: &str, token: String) {
56 let mut tokens = self.agent_tokens.write().await;
57 tokens.insert(base_url.to_string(), token);
58 }
59
60 async fn resolve_token(&self, base_url: &str) -> Option<String> {
63 let tokens = self.agent_tokens.read().await;
64 if let Some(token) = tokens.get(base_url) {
65 return Some(token.clone());
66 }
67 self.auth_token.clone()
68 }
69
70 fn validate_url(url_str: &str) -> Result<()> {
72 if !url_str.starts_with("http://") && !url_str.starts_with("https://") {
73 anyhow::bail!("Only http/https URLs allowed, got: {}", url_str);
74 }
75 let host = url_str
77 .split("://").nth(1).unwrap_or("")
78 .split('/').next().unwrap_or("")
79 .split(':').next().unwrap_or("");
80 let blocked = [
81 "localhost", "127.0.0.1", "::1", "0.0.0.0",
82 "metadata.google.internal",
83 ];
84 if blocked.contains(&host)
85 || host.starts_with("169.254.")
86 || host.starts_with("10.")
87 || host.starts_with("192.168.")
88 || host.starts_with("172.16.") || host.starts_with("172.17.")
89 {
90 anyhow::bail!("SSRF blocked: URL points to private/loopback address: {}", host);
91 }
92 Ok(())
93 }
94
95 pub async fn discover(&self, base_url: &str) -> Result<AgentCard> {
97 let url = format!(
98 "{}/.well-known/agent.json",
99 base_url.trim_end_matches('/')
100 );
101 Self::validate_url(&url)?;
102
103 let mut request = self.http.get(&url);
104 if let Some(token) = self.resolve_token(base_url).await {
105 request = request.header("Authorization", format!("Bearer {}", token));
106 }
107
108 let response = request
109 .send()
110 .await
111 .with_context(|| format!("Fetching agent card from {}", url))?;
112
113 if !response.status().is_success() {
114 anyhow::bail!(
115 "Agent discovery failed for {}: HTTP {}",
116 url,
117 response.status()
118 );
119 }
120
121 let card: AgentCard = response
122 .json()
123 .await
124 .context("Parsing agent card JSON")?;
125
126 {
128 let mut cache = self.agent_cache.write().await;
129 cache.insert(base_url.to_string(), card.clone());
130 }
131
132 Ok(card)
133 }
134
135 pub async fn get_agent(&self, base_url: &str) -> Result<AgentCard> {
137 {
138 let cache = self.agent_cache.read().await;
139 if let Some(card) = cache.get(base_url) {
140 return Ok(card.clone());
141 }
142 }
143 self.discover(base_url).await
144 }
145
146 pub async fn send_task(
148 &self,
149 agent_url: &str,
150 task: &TaskRequest,
151 options: Option<SendTaskOptions>,
152 ) -> Result<TaskResponse> {
153 let rpc_request = JsonRpcRequest::new(
154 methods::TASKS_SEND,
155 Some(serde_json::to_value(task)?),
156 serde_json::json!(uuid::Uuid::new_v4().to_string()),
157 );
158
159 let timeout_secs = options.and_then(|o| o.timeout_secs);
160 let rpc_future = self.send_rpc(agent_url, &rpc_request);
161
162 let response = if let Some(secs) = timeout_secs {
163 tokio::time::timeout(std::time::Duration::from_secs(secs), rpc_future)
164 .await
165 .map_err(|_| anyhow::anyhow!("Task send timed out after {}s", secs))??
166 } else {
167 rpc_future.await?
168 };
169
170 if let Some(error) = response.error {
171 anyhow::bail!("A2A error {}: {}", error.code, error.message);
172 }
173
174 let result = response
175 .result
176 .ok_or_else(|| anyhow::anyhow!("Empty result from agent"))?;
177
178 serde_json::from_value(result).context("Parsing task response")
179 }
180
181 pub async fn get_task(
183 &self,
184 agent_url: &str,
185 task_id: &str,
186 ) -> Result<TaskResponse> {
187 let rpc_request = JsonRpcRequest::new(
188 methods::TASKS_GET,
189 Some(serde_json::json!({"id": task_id})),
190 serde_json::json!(uuid::Uuid::new_v4().to_string()),
191 );
192
193 let response = self.send_rpc(agent_url, &rpc_request).await?;
194
195 if let Some(error) = response.error {
196 anyhow::bail!("A2A error {}: {}", error.code, error.message);
197 }
198
199 let result = response
200 .result
201 .ok_or_else(|| anyhow::anyhow!("Empty result from agent"))?;
202
203 serde_json::from_value(result).context("Parsing task response")
204 }
205
206 pub async fn cancel_task(
208 &self,
209 agent_url: &str,
210 task_id: &str,
211 ) -> Result<TaskStatusUpdate> {
212 let rpc_request = JsonRpcRequest::new(
213 methods::TASKS_CANCEL,
214 Some(serde_json::json!({"id": task_id})),
215 serde_json::json!(uuid::Uuid::new_v4().to_string()),
216 );
217
218 let response = self.send_rpc(agent_url, &rpc_request).await?;
219
220 if let Some(error) = response.error {
221 anyhow::bail!("A2A error {}: {}", error.code, error.message);
222 }
223
224 let result = response
225 .result
226 .ok_or_else(|| anyhow::anyhow!("Empty result from agent"))?;
227
228 serde_json::from_value(result).context("Parsing cancel response")
229 }
230
231 async fn send_rpc(
233 &self,
234 agent_url: &str,
235 request: &JsonRpcRequest,
236 ) -> Result<JsonRpcResponse> {
237 let url = format!("{}/a2a", agent_url.trim_end_matches('/'));
238
239 let mut http_request = self
240 .http
241 .post(&url)
242 .header("Content-Type", "application/json");
243
244 if let Some(token) = self.resolve_token(agent_url).await {
245 http_request =
246 http_request.header("Authorization", format!("Bearer {}", token));
247 }
248
249 let response = http_request
250 .json(request)
251 .send()
252 .await
253 .with_context(|| format!("Sending A2A request to {}", url))?;
254
255 if !response.status().is_success() {
256 anyhow::bail!("A2A request failed: HTTP {}", response.status());
257 }
258
259 response.json().await.context("Parsing JSON-RPC response")
260 }
261
262 pub async fn cached_agents(&self) -> Vec<AgentCard> {
264 let cache = self.agent_cache.read().await;
265 cache.values().cloned().collect()
266 }
267
268 pub async fn clear_cache(&self) {
270 let mut cache = self.agent_cache.write().await;
271 cache.clear();
272 }
273}
274
275#[derive(Debug)]
277pub struct BroadcastResult {
278 pub agent_url: String,
279 pub result: Result<TaskResponse>,
280}
281
282pub async fn broadcast_task(
287 client: &A2aClient,
288 agent_urls: &[String],
289 task: &TaskRequest,
290) -> Vec<BroadcastResult> {
291 let mut results = Vec::with_capacity(agent_urls.len());
292
293 for url in agent_urls {
294 let result = client.send_task(url, task, None).await;
295 results.push(BroadcastResult {
296 agent_url: url.clone(),
297 result,
298 });
299 }
300
301 results
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_client_creation() {
        let client = A2aClient::new(None);
        assert!(client.cached_agents().await.is_empty());

        let client_auth = A2aClient::new(Some("test-token".into()));
        assert!(client_auth.auth_token.is_some());
    }

    #[test]
    fn test_send_task_options_default() {
        let opts = SendTaskOptions::default();
        assert_eq!(opts.timeout_secs, Some(30));
        assert!(!opts.subscribe);
    }

    #[tokio::test]
    async fn test_cache_operations() {
        let client = A2aClient::new(None);
        assert!(client.cached_agents().await.is_empty());
        client.clear_cache().await;
        assert!(client.cached_agents().await.is_empty());
    }

    #[tokio::test]
    async fn test_per_agent_token() {
        let client = A2aClient::new(Some("global-token".into()));

        // With no override registered, the global token is used.
        let token = client.resolve_token("https://agent-a.example.com").await;
        assert_eq!(token, Some("global-token".into()));

        client
            .set_agent_token("https://agent-a.example.com", "agent-a-token".into())
            .await;

        // The override applies to the agent it was registered for...
        let token = client.resolve_token("https://agent-a.example.com").await;
        assert_eq!(token, Some("agent-a-token".into()));

        // ...and other agents still get the global token.
        let token = client.resolve_token("https://agent-b.example.com").await;
        assert_eq!(token, Some("global-token".into()));
    }

    #[tokio::test]
    async fn test_no_token() {
        let client = A2aClient::new(None);
        let token = client.resolve_token("https://agent.example.com").await;
        assert!(token.is_none());
    }

    /// The SSRF guard must reject non-http(s) schemes and well-known
    /// loopback, private, link-local and metadata hosts.
    #[test]
    fn test_validate_url_rejects_unsafe_targets() {
        let rejected = [
            "ftp://agent.example.com/",
            "http://localhost/a2a",
            "http://127.0.0.1:8080/a2a",
            "http://169.254.169.254/latest/meta-data",
            "http://10.0.0.5/a2a",
            "http://192.168.1.10/a2a",
            "http://metadata.google.internal/computeMetadata/v1",
        ];
        for url in rejected.iter() {
            assert!(
                A2aClient::validate_url(url).is_err(),
                "expected {} to be rejected",
                url
            );
        }
    }

    /// Ordinary public http(s) agent URLs pass the SSRF guard.
    #[test]
    fn test_validate_url_allows_public_hosts() {
        assert!(A2aClient::validate_url("https://agent.example.com/.well-known/agent.json").is_ok());
        assert!(A2aClient::validate_url("http://agent.example.com:8080/a2a").is_ok());
    }
}