Skip to main content

a2a_rs_client/
client.rs

1//! A2A Client implementation
2//!
3//! Provides a reusable client for A2A RC 1.0 agent communication.
4
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use a2a_rs_core::{
9    AgentCard, GetTaskRequest, JsonRpcRequest, JsonRpcResponse, Message, SendMessageRequest,
10    SendMessageResponse, Task,
11};
12use anyhow::{anyhow, Context, Result};
13use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
14use rand::Rng;
15use reqwest::{Client, Url};
16use serde::{Deserialize, Serialize};
17use sha2::{Digest, Sha256};
18use tokio::sync::RwLock;
19use tokio::time::sleep;
20use tracing::{info, warn};
21
/// How long a fetched agent card is reused before it is fetched again (five minutes).
const AGENT_CARD_CACHE_TTL: Duration = Duration::from_secs(5 * 60);
24
25/// Client configuration
26#[derive(Debug, Clone)]
27pub struct ClientConfig {
28    /// Base URL of the A2A server
29    pub server_url: String,
30    /// Maximum number of poll attempts for task completion
31    pub max_polls: u32,
32    /// Milliseconds between poll attempts
33    pub poll_interval_ms: u64,
34    /// OAuth configuration (if using OAuth authentication)
35    pub oauth: Option<OAuthConfig>,
36}
37
38impl Default for ClientConfig {
39    fn default() -> Self {
40        Self {
41            server_url: "http://127.0.0.1:8080".to_string(),
42            max_polls: 30,
43            poll_interval_ms: 2000,
44            oauth: None,
45        }
46    }
47}
48
/// OAuth settings consumed by the client's OAuth helper methods.
#[derive(Debug, Clone)]
pub struct OAuthConfig {
    /// OAuth client identifier sent with the authorize request.
    pub client_id: String,
    /// Where the authorization server redirects after the user logs in.
    pub redirect_uri: String,
    /// Scopes to request; they are joined with single spaces when sent.
    pub scopes: Vec<String>,
    /// Already-obtained session token; when set, the interactive flow returns it directly.
    pub session_token: Option<String>,
}

impl Default for OAuthConfig {
    fn default() -> Self {
        const DEFAULT_SCOPES: [&str; 4] =
            ["User.Read", "Sites.Read.All", "Mail.Read", "offline_access"];

        Self {
            client_id: String::from("a2a-client"),
            redirect_uri: String::from("http://localhost:3000/callback"),
            scopes: DEFAULT_SCOPES.iter().copied().map(String::from).collect(),
            session_token: None,
        }
    }
}
77
78/// Cached agent card with expiration
79struct CachedCard {
80    card: AgentCard,
81    fetched_at: Instant,
82}
83
84impl CachedCard {
85    fn is_valid(&self) -> bool {
86        self.fetched_at.elapsed() < AGENT_CARD_CACHE_TTL
87    }
88}
89
90/// A2A Client for communicating with A2A-compliant agent servers
91#[derive(Clone)]
92pub struct A2aClient {
93    config: ClientConfig,
94    http: Client,
95    base_url: Url,
96    /// Cached agent card to avoid repeated fetches
97    card_cache: Arc<RwLock<Option<CachedCard>>>,
98    /// Cached RPC endpoint URL for fast lookups (derived from agent card)
99    endpoint_cache: Arc<RwLock<Option<String>>>,
100}
101
102impl std::fmt::Debug for A2aClient {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        f.debug_struct("A2aClient")
105            .field("config", &self.config)
106            .field("base_url", &self.base_url)
107            .finish_non_exhaustive()
108    }
109}
110
111#[derive(Debug, Serialize)]
112struct OAuthAuthorizeRequest {
113    response_type: String,
114    client_id: String,
115    redirect_uri: String,
116    scope: String,
117    state: String,
118    code_challenge: String,
119    code_challenge_method: String,
120}
121
122#[derive(Debug, Deserialize)]
123struct OAuthAuthorizeResponse {
124    authorization_url: String,
125    #[allow(dead_code)]
126    state: String,
127}
128
129impl A2aClient {
130    /// Create a new A2A client with the given configuration
131    pub fn new(config: ClientConfig) -> Result<Self> {
132        let base_url = Url::parse(&config.server_url)
133            .with_context(|| format!("Invalid server URL: {}", config.server_url))?;
134
135        Ok(Self {
136            config,
137            http: Client::new(),
138            base_url,
139            card_cache: Arc::new(RwLock::new(None)),
140            endpoint_cache: Arc::new(RwLock::new(None)),
141        })
142    }
143
144    /// Create a client with default configuration for a given server URL
145    pub fn with_server(server_url: &str) -> Result<Self> {
146        Self::new(ClientConfig {
147            server_url: server_url.to_string(),
148            ..Default::default()
149        })
150    }
151
152    /// Get the server base URL
153    pub fn server_url(&self) -> &str {
154        &self.config.server_url
155    }
156
157    /// Fetch the agent card, using cache if available and valid
158    pub async fn fetch_agent_card(&self) -> Result<AgentCard> {
159        // Check cache first
160        {
161            let cache = self.card_cache.read().await;
162            if let Some(cached) = cache.as_ref() {
163                if cached.is_valid() {
164                    return Ok(cached.card.clone());
165                }
166            }
167        }
168
169        // Fetch fresh card
170        let url = self.base_url.join("/.well-known/agent-card.json")?;
171        let card: AgentCard = self
172            .http
173            .get(url)
174            .send()
175            .await?
176            .error_for_status()?
177            .json()
178            .await?;
179
180        // Update cache
181        {
182            let mut cache = self.card_cache.write().await;
183            *cache = Some(CachedCard {
184                card: card.clone(),
185                fetched_at: Instant::now(),
186            });
187        }
188
189        Ok(card)
190    }
191
192    /// Invalidate the cached agent card and endpoint
193    pub async fn invalidate_card_cache(&self) {
194        let mut cache = self.card_cache.write().await;
195        *cache = None;
196        let mut endpoint = self.endpoint_cache.write().await;
197        *endpoint = None;
198    }
199
200    /// Get the cached RPC endpoint URL, fetching from agent card if needed
201    async fn get_cached_endpoint(&self) -> Result<String> {
202        // Check endpoint cache first
203        {
204            let cache = self.endpoint_cache.read().await;
205            if let Some(endpoint) = cache.as_ref() {
206                return Ok(endpoint.clone());
207            }
208        }
209
210        // Fetch agent card and cache the endpoint
211        let card = self.fetch_agent_card().await?;
212        let endpoint = card
213            .endpoint()
214            .ok_or_else(|| anyhow!("Agent card has no JSONRPC endpoint"))?
215            .to_string();
216
217        {
218            let mut cache = self.endpoint_cache.write().await;
219            *cache = Some(endpoint.clone());
220        }
221
222        Ok(endpoint)
223    }
224
225    /// Get the JSON-RPC endpoint URL from the agent card
226    #[inline]
227    pub fn get_rpc_url(card: &AgentCard) -> Option<&str> {
228        card.endpoint()
229    }
230
231    /// Send a JSON-RPC request and parse the response
232    async fn json_rpc_call<P: Serialize, R: for<'de> Deserialize<'de>>(
233        &self,
234        method: &str,
235        params: P,
236        session_token: Option<&str>,
237    ) -> Result<R> {
238        let rpc_url = self.get_cached_endpoint().await?;
239
240        let request = JsonRpcRequest {
241            jsonrpc: "2.0".into(),
242            method: method.into(),
243            params: Some(serde_json::to_value(params)?),
244            id: serde_json::json!(1),
245        };
246
247        let mut req_builder = self.http.post(rpc_url).json(&request);
248        if let Some(token) = session_token {
249            req_builder = req_builder.header("Authorization", format!("Bearer {token}"));
250        }
251
252        let mut resp: JsonRpcResponse = req_builder
253            .send()
254            .await?
255            .error_for_status()?
256            .json()
257            .await?;
258
259        if let Some(err) = resp.error.take() {
260            anyhow::bail!("Server error {}: {}", err.code, err.message);
261        }
262
263        resp.result
264            .as_ref()
265            .map(|v| serde_json::from_value(v.clone()))
266            .transpose()?
267            .ok_or_else(|| anyhow!("Server returned no result"))
268    }
269
270    /// Send a message to the agent and receive a response (Task or Message)
271    pub async fn send_message(
272        &self,
273        message: Message,
274        session_token: Option<&str>,
275    ) -> Result<SendMessageResponse> {
276        let params = SendMessageRequest {
277            tenant: None,
278            message,
279            configuration: None,
280            metadata: None,
281        };
282        self.json_rpc_call("message/send", params, session_token).await
283    }
284
285    /// Poll a task by ID
286    pub async fn poll_task(&self, task_id: &str, session_token: Option<&str>) -> Result<Task> {
287        let params = GetTaskRequest {
288            id: task_id.to_string(),
289            history_length: None,
290            tenant: None,
291        };
292        self.json_rpc_call("tasks/get", params, session_token).await
293    }
294
295    /// Poll a task until it reaches a terminal state or max polls exceeded
296    pub async fn poll_until_complete(
297        &self,
298        task_id: &str,
299        session_token: Option<&str>,
300    ) -> Result<Task> {
301        let mut task = self.poll_task(task_id, session_token).await?;
302
303        for i in 0..self.config.max_polls {
304            if task.status.state.is_terminal() {
305                return Ok(task);
306            }
307
308            sleep(Duration::from_millis(self.config.poll_interval_ms)).await;
309
310            match self.poll_task(task_id, session_token).await {
311                Ok(updated_task) => {
312                    info!(
313                        "Poll {}/{}: state={:?}",
314                        i + 1,
315                        self.config.max_polls,
316                        updated_task.status.state
317                    );
318                    task = updated_task;
319                }
320                Err(e) => {
321                    warn!("Poll {}/{} failed: {}", i + 1, self.config.max_polls, e);
322                    // Continue polling, the task might still complete
323                }
324            }
325        }
326
327        Ok(task)
328    }
329
330    /// Perform interactive OAuth flow (prompts user to visit URL and paste callback)
331    pub async fn perform_oauth_interactive(&self) -> Result<String> {
332        let oauth_config = self
333            .config
334            .oauth
335            .as_ref()
336            .ok_or_else(|| anyhow!("OAuth not configured"))?;
337
338        // If we already have a session token, return it
339        if let Some(token) = &oauth_config.session_token {
340            return Ok(token.clone());
341        }
342
343        // Generate PKCE code verifier and challenge
344        let code_verifier = generate_code_verifier();
345        let code_challenge = generate_code_challenge(&code_verifier);
346        let client_state = generate_random_string(32);
347
348        let authorize_req = OAuthAuthorizeRequest {
349            response_type: "code".to_string(),
350            client_id: oauth_config.client_id.clone(),
351            redirect_uri: oauth_config.redirect_uri.clone(),
352            scope: oauth_config.scopes.join(" "),
353            state: client_state.clone(),
354            code_challenge,
355            code_challenge_method: "S256".to_string(),
356        };
357
358        // Call /oauth/authorize endpoint
359        let oauth_url = self.base_url.join("/oauth/authorize")?;
360        let auth_response: OAuthAuthorizeResponse = self
361            .http
362            .post(oauth_url)
363            .json(&authorize_req)
364            .send()
365            .await?
366            .error_for_status()?
367            .json()
368            .await?;
369
370        // Display authorization URL to user
371        println!("\nšŸ” OAuth Authentication Required");
372        println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
373        println!("Please visit this URL to authenticate:\n");
374        println!("{}\n", auth_response.authorization_url);
375        println!("After authentication, you'll be redirected to:");
376        println!("{}?session_token=...", oauth_config.redirect_uri);
377        println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
378
379        // Prompt user to paste the session token
380        println!("Paste the full redirect URL here:");
381        let mut input = String::new();
382        std::io::stdin().read_line(&mut input)?;
383        let input = input.trim();
384
385        // Extract session_token from the redirect URL
386        let parsed_url = Url::parse(input).or_else(|_| {
387            if input.starts_with("session_token=") || input.contains("session_token=") {
388                Ok(Url::parse(&format!(
389                    "{}?{}",
390                    oauth_config.redirect_uri, input
391                ))?)
392            } else {
393                Err(anyhow!("Invalid URL or token format"))
394            }
395        })?;
396
397        let session_token = parsed_url
398            .query_pairs()
399            .find(|(key, _)| key == "session_token")
400            .map(|(_, value)| value.to_string())
401            .ok_or_else(|| anyhow!("No session_token found in URL"))?;
402
403        Ok(session_token)
404    }
405
406    /// Start OAuth flow and return authorization URL (for programmatic use)
407    pub async fn start_oauth_flow(&self) -> Result<(String, String)> {
408        let oauth_config = self
409            .config
410            .oauth
411            .as_ref()
412            .ok_or_else(|| anyhow!("OAuth not configured"))?;
413
414        let code_verifier = generate_code_verifier();
415        let code_challenge = generate_code_challenge(&code_verifier);
416        let client_state = generate_random_string(32);
417
418        let authorize_req = OAuthAuthorizeRequest {
419            response_type: "code".to_string(),
420            client_id: oauth_config.client_id.clone(),
421            redirect_uri: oauth_config.redirect_uri.clone(),
422            scope: oauth_config.scopes.join(" "),
423            state: client_state.clone(),
424            code_challenge,
425            code_challenge_method: "S256".to_string(),
426        };
427
428        let oauth_url = self.base_url.join("/oauth/authorize")?;
429        let auth_response: OAuthAuthorizeResponse = self
430            .http
431            .post(oauth_url)
432            .json(&authorize_req)
433            .send()
434            .await?
435            .error_for_status()?
436            .json()
437            .await?;
438
439        Ok((auth_response.authorization_url, code_verifier))
440    }
441}
442
443/// Generate a PKCE code verifier (43-128 character random string)
444pub fn generate_code_verifier() -> String {
445    let mut rng = rand::thread_rng();
446    let random_bytes: Vec<u8> = (0..32).map(|_| rng.gen()).collect();
447    URL_SAFE_NO_PAD.encode(&random_bytes)
448}
449
450/// Generate a PKCE code challenge from a code verifier using S256 method
451pub fn generate_code_challenge(verifier: &str) -> String {
452    let mut hasher = Sha256::new();
453    hasher.update(verifier.as_bytes());
454    let hash = hasher.finalize();
455    URL_SAFE_NO_PAD.encode(hash)
456}
457
458/// Generate a random string for state parameter
459pub fn generate_random_string(length: usize) -> String {
460    let mut rng = rand::thread_rng();
461    let random_bytes: Vec<u8> = (0..length).map(|_| rng.gen()).collect();
462    URL_SAFE_NO_PAD.encode(&random_bytes)
463}
464
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = ClientConfig::default();
        assert_eq!(config.server_url, "http://127.0.0.1:8080");
        assert_eq!(config.max_polls, 30);
        assert_eq!(config.poll_interval_ms, 2000);
        assert!(config.oauth.is_none());
    }

    #[test]
    fn test_code_challenge() {
        let verifier = generate_code_verifier();
        let challenge = generate_code_challenge(&verifier);

        // Verifier should be URL-safe base64
        assert!(verifier.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
        // Challenge should also be URL-safe base64
        assert!(challenge.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
        // A SHA-256 digest (32 bytes) encodes to 43 unpadded base64url characters.
        assert_eq!(challenge.len(), 43);
    }

    /// Known-answer vector from RFC 7636 Appendix B. The charset check above cannot detect
    /// a wrong hash or a wrong base64 alphabet; this test can.
    #[test]
    fn test_code_challenge_rfc7636_vector() {
        let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
        assert_eq!(
            generate_code_challenge(verifier),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        );
    }

    #[test]
    fn test_code_verifier_length() {
        // 32 random bytes -> 43 characters, inside RFC 7636's 43..=128 range.
        assert_eq!(generate_code_verifier().len(), 43);
    }

    #[test]
    fn test_random_string_length_counts_bytes() {
        assert_eq!(generate_random_string(0), "");
        assert_eq!(generate_random_string(1).len(), 2);
        assert_eq!(generate_random_string(32).len(), 43);
        assert_ne!(generate_random_string(32), generate_random_string(32));
    }

    #[test]
    fn test_client_creation() {
        let client = A2aClient::with_server("http://localhost:8080").unwrap();
        assert_eq!(client.server_url(), "http://localhost:8080");
    }

    #[test]
    fn test_client_rejects_invalid_url() {
        assert!(A2aClient::with_server("not a url").is_err());
    }
}