1use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use a2a_rs_core::{
9 AgentCard, GetTaskRequest, JsonRpcRequest, JsonRpcResponse, Message, SendMessageRequest,
10 SendMessageResponse, Task,
11};
12use anyhow::{anyhow, Context, Result};
13use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
14use rand::Rng;
15use reqwest::{Client, Url};
16use serde::{Deserialize, Serialize};
17use sha2::{Digest, Sha256};
18use tokio::sync::RwLock;
19use tokio::time::sleep;
20use tracing::{info, warn};
21
/// How long a fetched agent card is served from the in-memory cache before
/// `A2aClient::fetch_agent_card` goes back to the server (five minutes).
const AGENT_CARD_CACHE_TTL: Duration = Duration::from_secs(5 * 60);
24
/// Connection and polling settings for an `A2aClient`.
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Base URL of the A2A server. `A2aClient::new` parses it with `Url::parse`, and it
    /// is the base against which the agent-card and `/oauth/authorize` paths are joined.
    pub server_url: String,
    /// Maximum number of follow-up polls `A2aClient::poll_until_complete` performs after
    /// its initial status check.
    pub max_polls: u32,
    /// Delay between polls in `A2aClient::poll_until_complete`, in milliseconds.
    pub poll_interval_ms: u64,
    /// OAuth settings; when `None`, the OAuth helpers fail with "OAuth not configured".
    pub oauth: Option<OAuthConfig>,
}
37
38impl Default for ClientConfig {
39 fn default() -> Self {
40 Self {
41 server_url: "http://127.0.0.1:8080".to_string(),
42 max_polls: 30,
43 poll_interval_ms: 2000,
44 oauth: None,
45 }
46 }
47}
48
/// Settings for the PKCE (S256) OAuth flow driven by `A2aClient::perform_oauth_interactive`
/// and `A2aClient::start_oauth_flow`.
///
/// `Debug` is implemented by hand so that `session_token` is never printed: this config is
/// embedded in `ClientConfig`, whose `Debug` output is in turn shown by `A2aClient`'s `Debug`.
#[derive(Clone)]
pub struct OAuthConfig {
    /// Client identifier sent as `client_id` to the server's `/oauth/authorize` endpoint.
    pub client_id: String,
    /// Redirect target sent as `redirect_uri`. The interactive flow also uses it to rebuild
    /// a full URL when the user pastes only the query string.
    pub redirect_uri: String,
    /// Requested scopes, joined with single spaces into the `scope` parameter.
    pub scopes: Vec<String>,
    /// Pre-existing session token. When set, `perform_oauth_interactive` returns it without
    /// contacting the server.
    pub session_token: Option<String>,
}

impl std::fmt::Debug for OAuthConfig {
    /// Formats every field except the secret: a present `session_token` is rendered as
    /// `Some("<redacted>")` so credentials cannot leak into logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthConfig")
            .field("client_id", &self.client_id)
            .field("redirect_uri", &self.redirect_uri)
            .field("scopes", &self.scopes)
            .field(
                "session_token",
                &self.session_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
61
62impl Default for OAuthConfig {
63 fn default() -> Self {
64 Self {
65 client_id: "a2a-client".to_string(),
66 redirect_uri: "http://localhost:3000/callback".to_string(),
67 scopes: vec![
68 "User.Read".to_string(),
69 "Sites.Read.All".to_string(),
70 "Mail.Read".to_string(),
71 "offline_access".to_string(),
72 ],
73 session_token: None,
74 }
75 }
76}
77
/// An agent card together with the moment it was fetched. `A2aClient::fetch_agent_card`
/// serves it from memory until `AGENT_CARD_CACHE_TTL` has elapsed.
struct CachedCard {
    /// The card as decoded from the server's response.
    card: AgentCard,
    /// Monotonic fetch timestamp; compared against the TTL in `is_valid`.
    fetched_at: Instant,
}
83
84impl CachedCard {
85 fn is_valid(&self) -> bool {
86 self.fetched_at.elapsed() < AGENT_CARD_CACHE_TTL
87 }
88}
89
/// Async client for an A2A agent server. It discovers the JSON-RPC endpoint from the
/// agent card, sends messages, polls tasks, and drives the OAuth helpers.
///
/// `Clone` deep-copies `config`, but the card and endpoint caches sit behind `Arc`, so
/// clones share (and invalidate) the same cached data.
#[derive(Clone)]
pub struct A2aClient {
    /// Settings this client was built from.
    config: ClientConfig,
    /// HTTP client used for every request.
    http: Client,
    /// `config.server_url`, parsed once at construction.
    base_url: Url,
    /// Agent card cache, subject to `AGENT_CARD_CACHE_TTL`.
    card_cache: Arc<RwLock<Option<CachedCard>>>,
    /// JSON-RPC endpoint extracted from the card. It has no TTL of its own and is only
    /// cleared by `invalidate_card_cache`.
    endpoint_cache: Arc<RwLock<Option<String>>>,
}
101
102impl std::fmt::Debug for A2aClient {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.debug_struct("A2aClient")
105 .field("config", &self.config)
106 .field("base_url", &self.base_url)
107 .finish_non_exhaustive()
108 }
109}
110
/// JSON body POSTed to the server's `/oauth/authorize` endpoint to start an
/// authorization request carrying a PKCE challenge.
#[derive(Debug, Serialize)]
struct OAuthAuthorizeRequest {
    /// Always `"code"` as built by this client.
    response_type: String,
    client_id: String,
    redirect_uri: String,
    /// Requested scopes joined with single spaces.
    scope: String,
    /// Random value generated per request by this client.
    state: String,
    /// Unpadded base64url of SHA-256 over the code verifier.
    code_challenge: String,
    /// Always `"S256"` as built by this client.
    code_challenge_method: String,
}
121
/// Server reply to `/oauth/authorize`.
#[derive(Debug, Deserialize)]
struct OAuthAuthorizeResponse {
    /// URL the user must open in a browser to authenticate.
    authorization_url: String,
    // Never read by this client. Because it is a non-optional `String`, a response
    // without `state` fails to deserialize; hence the field stays and `dead_code` is allowed.
    // NOTE(review): confirm whether the server is meant to echo the client's state here.
    #[allow(dead_code)]
    state: String,
}
128
129impl A2aClient {
130 pub fn new(config: ClientConfig) -> Result<Self> {
132 let base_url = Url::parse(&config.server_url)
133 .with_context(|| format!("Invalid server URL: {}", config.server_url))?;
134
135 Ok(Self {
136 config,
137 http: Client::new(),
138 base_url,
139 card_cache: Arc::new(RwLock::new(None)),
140 endpoint_cache: Arc::new(RwLock::new(None)),
141 })
142 }
143
144 pub fn with_server(server_url: &str) -> Result<Self> {
146 Self::new(ClientConfig {
147 server_url: server_url.to_string(),
148 ..Default::default()
149 })
150 }
151
152 pub fn server_url(&self) -> &str {
154 &self.config.server_url
155 }
156
157 pub async fn fetch_agent_card(&self) -> Result<AgentCard> {
159 {
161 let cache = self.card_cache.read().await;
162 if let Some(cached) = cache.as_ref() {
163 if cached.is_valid() {
164 return Ok(cached.card.clone());
165 }
166 }
167 }
168
169 let url = self.base_url.join("/.well-known/agent-card.json")?;
171 let card: AgentCard = self
172 .http
173 .get(url)
174 .send()
175 .await?
176 .error_for_status()?
177 .json()
178 .await?;
179
180 {
182 let mut cache = self.card_cache.write().await;
183 *cache = Some(CachedCard {
184 card: card.clone(),
185 fetched_at: Instant::now(),
186 });
187 }
188
189 Ok(card)
190 }
191
192 pub async fn invalidate_card_cache(&self) {
194 let mut cache = self.card_cache.write().await;
195 *cache = None;
196 let mut endpoint = self.endpoint_cache.write().await;
197 *endpoint = None;
198 }
199
200 async fn get_cached_endpoint(&self) -> Result<String> {
202 {
204 let cache = self.endpoint_cache.read().await;
205 if let Some(endpoint) = cache.as_ref() {
206 return Ok(endpoint.clone());
207 }
208 }
209
210 let card = self.fetch_agent_card().await?;
212 let endpoint = card
213 .endpoint()
214 .ok_or_else(|| anyhow!("Agent card has no JSONRPC endpoint"))?
215 .to_string();
216
217 {
218 let mut cache = self.endpoint_cache.write().await;
219 *cache = Some(endpoint.clone());
220 }
221
222 Ok(endpoint)
223 }
224
225 #[inline]
227 pub fn get_rpc_url(card: &AgentCard) -> Option<&str> {
228 card.endpoint()
229 }
230
231 async fn json_rpc_call<P: Serialize, R: for<'de> Deserialize<'de>>(
233 &self,
234 method: &str,
235 params: P,
236 session_token: Option<&str>,
237 ) -> Result<R> {
238 let rpc_url = self.get_cached_endpoint().await?;
239
240 let request = JsonRpcRequest {
241 jsonrpc: "2.0".into(),
242 method: method.into(),
243 params: Some(serde_json::to_value(params)?),
244 id: serde_json::json!(1),
245 };
246
247 let mut req_builder = self.http.post(rpc_url).json(&request);
248 if let Some(token) = session_token {
249 req_builder = req_builder.header("Authorization", format!("Bearer {token}"));
250 }
251
252 let mut resp: JsonRpcResponse = req_builder
253 .send()
254 .await?
255 .error_for_status()?
256 .json()
257 .await?;
258
259 if let Some(err) = resp.error.take() {
260 anyhow::bail!("Server error {}: {}", err.code, err.message);
261 }
262
263 resp.result
264 .as_ref()
265 .map(|v| serde_json::from_value(v.clone()))
266 .transpose()?
267 .ok_or_else(|| anyhow!("Server returned no result"))
268 }
269
270 pub async fn send_message(
272 &self,
273 message: Message,
274 session_token: Option<&str>,
275 ) -> Result<SendMessageResponse> {
276 let params = SendMessageRequest {
277 tenant: None,
278 message,
279 configuration: None,
280 metadata: None,
281 };
282 self.json_rpc_call("message/send", params, session_token).await
283 }
284
285 pub async fn poll_task(&self, task_id: &str, session_token: Option<&str>) -> Result<Task> {
287 let params = GetTaskRequest {
288 id: task_id.to_string(),
289 history_length: None,
290 tenant: None,
291 };
292 self.json_rpc_call("tasks/get", params, session_token).await
293 }
294
295 pub async fn poll_until_complete(
297 &self,
298 task_id: &str,
299 session_token: Option<&str>,
300 ) -> Result<Task> {
301 let mut task = self.poll_task(task_id, session_token).await?;
302
303 for i in 0..self.config.max_polls {
304 if task.status.state.is_terminal() {
305 return Ok(task);
306 }
307
308 sleep(Duration::from_millis(self.config.poll_interval_ms)).await;
309
310 match self.poll_task(task_id, session_token).await {
311 Ok(updated_task) => {
312 info!(
313 "Poll {}/{}: state={:?}",
314 i + 1,
315 self.config.max_polls,
316 updated_task.status.state
317 );
318 task = updated_task;
319 }
320 Err(e) => {
321 warn!("Poll {}/{} failed: {}", i + 1, self.config.max_polls, e);
322 }
324 }
325 }
326
327 Ok(task)
328 }
329
330 pub async fn perform_oauth_interactive(&self) -> Result<String> {
332 let oauth_config = self
333 .config
334 .oauth
335 .as_ref()
336 .ok_or_else(|| anyhow!("OAuth not configured"))?;
337
338 if let Some(token) = &oauth_config.session_token {
340 return Ok(token.clone());
341 }
342
343 let code_verifier = generate_code_verifier();
345 let code_challenge = generate_code_challenge(&code_verifier);
346 let client_state = generate_random_string(32);
347
348 let authorize_req = OAuthAuthorizeRequest {
349 response_type: "code".to_string(),
350 client_id: oauth_config.client_id.clone(),
351 redirect_uri: oauth_config.redirect_uri.clone(),
352 scope: oauth_config.scopes.join(" "),
353 state: client_state.clone(),
354 code_challenge,
355 code_challenge_method: "S256".to_string(),
356 };
357
358 let oauth_url = self.base_url.join("/oauth/authorize")?;
360 let auth_response: OAuthAuthorizeResponse = self
361 .http
362 .post(oauth_url)
363 .json(&authorize_req)
364 .send()
365 .await?
366 .error_for_status()?
367 .json()
368 .await?;
369
370 println!("\nš OAuth Authentication Required");
372 println!("āāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāā");
373 println!("Please visit this URL to authenticate:\n");
374 println!("{}\n", auth_response.authorization_url);
375 println!("After authentication, you'll be redirected to:");
376 println!("{}?session_token=...", oauth_config.redirect_uri);
377 println!("āāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāā\n");
378
379 println!("Paste the full redirect URL here:");
381 let mut input = String::new();
382 std::io::stdin().read_line(&mut input)?;
383 let input = input.trim();
384
385 let parsed_url = Url::parse(input).or_else(|_| {
387 if input.starts_with("session_token=") || input.contains("session_token=") {
388 Ok(Url::parse(&format!(
389 "{}?{}",
390 oauth_config.redirect_uri, input
391 ))?)
392 } else {
393 Err(anyhow!("Invalid URL or token format"))
394 }
395 })?;
396
397 let session_token = parsed_url
398 .query_pairs()
399 .find(|(key, _)| key == "session_token")
400 .map(|(_, value)| value.to_string())
401 .ok_or_else(|| anyhow!("No session_token found in URL"))?;
402
403 Ok(session_token)
404 }
405
406 pub async fn start_oauth_flow(&self) -> Result<(String, String)> {
408 let oauth_config = self
409 .config
410 .oauth
411 .as_ref()
412 .ok_or_else(|| anyhow!("OAuth not configured"))?;
413
414 let code_verifier = generate_code_verifier();
415 let code_challenge = generate_code_challenge(&code_verifier);
416 let client_state = generate_random_string(32);
417
418 let authorize_req = OAuthAuthorizeRequest {
419 response_type: "code".to_string(),
420 client_id: oauth_config.client_id.clone(),
421 redirect_uri: oauth_config.redirect_uri.clone(),
422 scope: oauth_config.scopes.join(" "),
423 state: client_state.clone(),
424 code_challenge,
425 code_challenge_method: "S256".to_string(),
426 };
427
428 let oauth_url = self.base_url.join("/oauth/authorize")?;
429 let auth_response: OAuthAuthorizeResponse = self
430 .http
431 .post(oauth_url)
432 .json(&authorize_req)
433 .send()
434 .await?
435 .error_for_status()?
436 .json()
437 .await?;
438
439 Ok((auth_response.authorization_url, code_verifier))
440 }
441}
442
443pub fn generate_code_verifier() -> String {
445 let mut rng = rand::thread_rng();
446 let random_bytes: Vec<u8> = (0..32).map(|_| rng.gen()).collect();
447 URL_SAFE_NO_PAD.encode(&random_bytes)
448}
449
450pub fn generate_code_challenge(verifier: &str) -> String {
452 let mut hasher = Sha256::new();
453 hasher.update(verifier.as_bytes());
454 let hash = hasher.finalize();
455 URL_SAFE_NO_PAD.encode(hash)
456}
457
458pub fn generate_random_string(length: usize) -> String {
460 let mut rng = rand::thread_rng();
461 let random_bytes: Vec<u8> = (0..length).map(|_| rng.gen()).collect();
462 URL_SAFE_NO_PAD.encode(&random_bytes)
463}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `s` only uses the unpadded base64url alphabet (`[A-Za-z0-9_-]`).
    fn is_base64url(s: &str) -> bool {
        s.chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
    }

    #[test]
    fn test_default_config() {
        let config = ClientConfig::default();
        assert_eq!(config.server_url, "http://127.0.0.1:8080");
        assert_eq!(config.max_polls, 30);
        assert_eq!(config.poll_interval_ms, 2000);
        assert!(config.oauth.is_none());
    }

    #[test]
    fn test_code_challenge() {
        let verifier = generate_code_verifier();
        let challenge = generate_code_challenge(&verifier);

        // RFC 7636 §4.1: a verifier is 43..=128 unreserved characters.
        assert!((43..=128).contains(&verifier.len()));
        assert!(is_base64url(&verifier));

        // A SHA-256 digest is 32 bytes, i.e. 43 unpadded base64url characters.
        assert_eq!(challenge.len(), 43);
        assert!(is_base64url(&challenge));

        // The challenge is a pure function of the verifier.
        assert_eq!(challenge, generate_code_challenge(&verifier));
    }

    #[test]
    fn test_random_string_length_counts_bytes() {
        assert_eq!(generate_random_string(0), "");
        assert_eq!(generate_random_string(3).len(), 4);
        assert_eq!(generate_random_string(32).len(), 43);
        assert!(is_base64url(&generate_random_string(16)));
        assert_ne!(generate_random_string(32), generate_random_string(32));
    }

    #[test]
    fn test_client_creation() {
        let client = A2aClient::with_server("http://localhost:8080").unwrap();
        assert_eq!(client.server_url(), "http://localhost:8080");
    }

    #[test]
    fn test_invalid_server_url_is_rejected() {
        assert!(A2aClient::with_server("not a url").is_err());
    }
}