Skip to main content

mur_core/a2a/
client.rs

1//! A2A Client — discover and communicate with remote agents.
2
3use super::protocol::{
4    AgentCard, JsonRpcRequest, JsonRpcResponse, TaskRequest, TaskResponse, TaskStatusUpdate,
5    methods,
6};
7use anyhow::{Context, Result};
8use std::collections::HashMap;
9use tokio::sync::RwLock;
10
11/// Client for communicating with A2A-compatible agents.
12pub struct A2aClient {
13    http: reqwest::Client,
14    /// Cache of discovered agent cards keyed by base URL.
15    agent_cache: RwLock<HashMap<String, AgentCard>>,
16    /// Per-agent bearer tokens keyed by base URL.
17    agent_tokens: RwLock<HashMap<String, String>>,
18    /// Optional global fallback bearer token used when no per-agent token is set.
19    auth_token: Option<String>,
20}
21
/// Options for sending a task.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so callers can log, copy
/// and compare option sets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SendTaskOptions {
    /// Timeout in seconds. `None` means no task-level timeout is requested.
    pub timeout_secs: Option<u64>,
    /// Whether to subscribe to streaming updates.
    ///
    /// NOTE(review): `A2aClient::send_task` in this file does not read this
    /// flag — confirm whether streaming is handled elsewhere.
    pub subscribe: bool,
}

impl Default for SendTaskOptions {
    /// Default options: a 30-second timeout and no streaming subscription.
    fn default() -> Self {
        Self {
            timeout_secs: Some(30),
            subscribe: false,
        }
    }
}
38
39impl A2aClient {
40    /// Create a new A2A client with an optional global fallback auth token.
41    pub fn new(auth_token: Option<String>) -> Self {
42        Self {
43            http: reqwest::Client::builder()
44                .connect_timeout(std::time::Duration::from_secs(10))
45                .timeout(std::time::Duration::from_secs(30))
46                .build()
47                .unwrap_or_else(|_| reqwest::Client::new()),
48            agent_cache: RwLock::new(HashMap::new()),
49            agent_tokens: RwLock::new(HashMap::new()),
50            auth_token,
51        }
52    }
53
54    /// Set an authentication token for a specific agent base URL.
55    pub async fn set_agent_token(&self, base_url: &str, token: String) {
56        let mut tokens = self.agent_tokens.write().await;
57        tokens.insert(base_url.to_string(), token);
58    }
59
60    /// Look up the bearer token for a given base URL.
61    /// Returns the per-agent token if set, otherwise falls back to the global token.
62    async fn resolve_token(&self, base_url: &str) -> Option<String> {
63        let tokens = self.agent_tokens.read().await;
64        if let Some(token) = tokens.get(base_url) {
65            return Some(token.clone());
66        }
67        self.auth_token.clone()
68    }
69
70    /// Validate a URL to prevent SSRF attacks.
71    fn validate_url(url_str: &str) -> Result<()> {
72        if !url_str.starts_with("http://") && !url_str.starts_with("https://") {
73            anyhow::bail!("Only http/https URLs allowed, got: {}", url_str);
74        }
75        // Extract host portion
76        let host = url_str
77            .split("://").nth(1).unwrap_or("")
78            .split('/').next().unwrap_or("")
79            .split(':').next().unwrap_or("");
80        let blocked = [
81            "localhost", "127.0.0.1", "::1", "0.0.0.0",
82            "metadata.google.internal",
83        ];
84        if blocked.contains(&host)
85            || host.starts_with("169.254.")
86            || host.starts_with("10.")
87            || host.starts_with("192.168.")
88            || host.starts_with("172.16.") || host.starts_with("172.17.")
89        {
90            anyhow::bail!("SSRF blocked: URL points to private/loopback address: {}", host);
91        }
92        Ok(())
93    }
94
95    /// Discover an agent by fetching its card from `/.well-known/agent.json`.
96    pub async fn discover(&self, base_url: &str) -> Result<AgentCard> {
97        let url = format!(
98            "{}/.well-known/agent.json",
99            base_url.trim_end_matches('/')
100        );
101        Self::validate_url(&url)?;
102
103        let mut request = self.http.get(&url);
104        if let Some(token) = self.resolve_token(base_url).await {
105            request = request.header("Authorization", format!("Bearer {}", token));
106        }
107
108        let response = request
109            .send()
110            .await
111            .with_context(|| format!("Fetching agent card from {}", url))?;
112
113        if !response.status().is_success() {
114            anyhow::bail!(
115                "Agent discovery failed for {}: HTTP {}",
116                url,
117                response.status()
118            );
119        }
120
121        let card: AgentCard = response
122            .json()
123            .await
124            .context("Parsing agent card JSON")?;
125
126        // Cache the card
127        {
128            let mut cache = self.agent_cache.write().await;
129            cache.insert(base_url.to_string(), card.clone());
130        }
131
132        Ok(card)
133    }
134
135    /// Get a cached agent card, or discover if not cached.
136    pub async fn get_agent(&self, base_url: &str) -> Result<AgentCard> {
137        {
138            let cache = self.agent_cache.read().await;
139            if let Some(card) = cache.get(base_url) {
140                return Ok(card.clone());
141            }
142        }
143        self.discover(base_url).await
144    }
145
146    /// Send a task to an agent, optionally applying a timeout from `SendTaskOptions`.
147    pub async fn send_task(
148        &self,
149        agent_url: &str,
150        task: &TaskRequest,
151        options: Option<SendTaskOptions>,
152    ) -> Result<TaskResponse> {
153        let rpc_request = JsonRpcRequest::new(
154            methods::TASKS_SEND,
155            Some(serde_json::to_value(task)?),
156            serde_json::json!(uuid::Uuid::new_v4().to_string()),
157        );
158
159        let timeout_secs = options.and_then(|o| o.timeout_secs);
160        let rpc_future = self.send_rpc(agent_url, &rpc_request);
161
162        let response = if let Some(secs) = timeout_secs {
163            tokio::time::timeout(std::time::Duration::from_secs(secs), rpc_future)
164                .await
165                .map_err(|_| anyhow::anyhow!("Task send timed out after {}s", secs))??
166        } else {
167            rpc_future.await?
168        };
169
170        if let Some(error) = response.error {
171            anyhow::bail!("A2A error {}: {}", error.code, error.message);
172        }
173
174        let result = response
175            .result
176            .ok_or_else(|| anyhow::anyhow!("Empty result from agent"))?;
177
178        serde_json::from_value(result).context("Parsing task response")
179    }
180
181    /// Get the status of a task.
182    pub async fn get_task(
183        &self,
184        agent_url: &str,
185        task_id: &str,
186    ) -> Result<TaskResponse> {
187        let rpc_request = JsonRpcRequest::new(
188            methods::TASKS_GET,
189            Some(serde_json::json!({"id": task_id})),
190            serde_json::json!(uuid::Uuid::new_v4().to_string()),
191        );
192
193        let response = self.send_rpc(agent_url, &rpc_request).await?;
194
195        if let Some(error) = response.error {
196            anyhow::bail!("A2A error {}: {}", error.code, error.message);
197        }
198
199        let result = response
200            .result
201            .ok_or_else(|| anyhow::anyhow!("Empty result from agent"))?;
202
203        serde_json::from_value(result).context("Parsing task response")
204    }
205
206    /// Cancel a running task.
207    pub async fn cancel_task(
208        &self,
209        agent_url: &str,
210        task_id: &str,
211    ) -> Result<TaskStatusUpdate> {
212        let rpc_request = JsonRpcRequest::new(
213            methods::TASKS_CANCEL,
214            Some(serde_json::json!({"id": task_id})),
215            serde_json::json!(uuid::Uuid::new_v4().to_string()),
216        );
217
218        let response = self.send_rpc(agent_url, &rpc_request).await?;
219
220        if let Some(error) = response.error {
221            anyhow::bail!("A2A error {}: {}", error.code, error.message);
222        }
223
224        let result = response
225            .result
226            .ok_or_else(|| anyhow::anyhow!("Empty result from agent"))?;
227
228        serde_json::from_value(result).context("Parsing cancel response")
229    }
230
231    /// Send a raw JSON-RPC request to an agent.
232    async fn send_rpc(
233        &self,
234        agent_url: &str,
235        request: &JsonRpcRequest,
236    ) -> Result<JsonRpcResponse> {
237        let url = format!("{}/a2a", agent_url.trim_end_matches('/'));
238
239        let mut http_request = self
240            .http
241            .post(&url)
242            .header("Content-Type", "application/json");
243
244        if let Some(token) = self.resolve_token(agent_url).await {
245            http_request =
246                http_request.header("Authorization", format!("Bearer {}", token));
247        }
248
249        let response = http_request
250            .json(request)
251            .send()
252            .await
253            .with_context(|| format!("Sending A2A request to {}", url))?;
254
255        if !response.status().is_success() {
256            anyhow::bail!("A2A request failed: HTTP {}", response.status());
257        }
258
259        response.json().await.context("Parsing JSON-RPC response")
260    }
261
262    /// List all cached agents.
263    pub async fn cached_agents(&self) -> Vec<AgentCard> {
264        let cache = self.agent_cache.read().await;
265        cache.values().cloned().collect()
266    }
267
268    /// Clear the agent cache.
269    pub async fn clear_cache(&self) {
270        let mut cache = self.agent_cache.write().await;
271        cache.clear();
272    }
273}
274
275/// Response from a multi-agent broadcast.
276#[derive(Debug)]
277pub struct BroadcastResult {
278    pub agent_url: String,
279    pub result: Result<TaskResponse>,
280}
281
282/// Broadcast a task to multiple agents sequentially.
283///
284/// Since `client` is behind `&` (not `Arc`), we dispatch one at a time.
285/// Each agent receives its own send_task call.
286pub async fn broadcast_task(
287    client: &A2aClient,
288    agent_urls: &[String],
289    task: &TaskRequest,
290) -> Vec<BroadcastResult> {
291    let mut results = Vec::with_capacity(agent_urls.len());
292
293    for url in agent_urls {
294        let result = client.send_task(url, task, None).await;
295        results.push(BroadcastResult {
296            agent_url: url.clone(),
297            result,
298        });
299    }
300
301    results
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh client has an empty cache and keeps the global token it was given.
    #[tokio::test]
    async fn test_client_creation() {
        let anonymous = A2aClient::new(None);
        assert!(anonymous.cached_agents().await.is_empty());

        let with_token = A2aClient::new(Some("test-token".into()));
        assert!(with_token.auth_token.is_some());
    }

    /// Default options are a 30-second timeout without subscription.
    #[test]
    fn test_send_task_options_default() {
        let defaults = SendTaskOptions::default();
        assert_eq!(defaults.timeout_secs, Some(30));
        assert!(!defaults.subscribe);
    }

    /// Clearing an already-empty cache leaves it empty.
    #[tokio::test]
    async fn test_cache_operations() {
        let client = A2aClient::new(None);
        assert!(client.cached_agents().await.is_empty());
        client.clear_cache().await;
        assert!(client.cached_agents().await.is_empty());
    }

    /// Per-agent tokens override the global token for that agent only.
    #[tokio::test]
    async fn test_per_agent_token() {
        let agent_a = "https://agent-a.example.com";
        let agent_b = "https://agent-b.example.com";
        let client = A2aClient::new(Some("global-token".into()));

        // Nothing registered for agent A yet, so the global token is used.
        let resolved = client.resolve_token(agent_a).await;
        assert_eq!(resolved.as_deref(), Some("global-token"));

        // Register a token for agent A.
        client.set_agent_token(agent_a, "agent-a-token".into()).await;

        // Agent A now resolves to its own token.
        let resolved = client.resolve_token(agent_a).await;
        assert_eq!(resolved.as_deref(), Some("agent-a-token"));

        // Agent B is unaffected and still gets the global token.
        let resolved = client.resolve_token(agent_b).await;
        assert_eq!(resolved.as_deref(), Some("global-token"));
    }

    /// With neither a per-agent nor a global token, nothing is resolved.
    #[tokio::test]
    async fn test_no_token() {
        let client = A2aClient::new(None);
        let resolved = client.resolve_token("https://agent.example.com").await;
        assert!(resolved.is_none());
    }
}