1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use dashmap::DashMap;
10use serde::{Deserialize, Serialize};
11use std::time::Duration;
12
13use crate::error::{PunchError, PunchResult};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
21#[serde(tag = "type", content = "value")]
22pub enum A2AAuth {
23 Bearer(String),
25 ApiKey(String),
27 None,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct AgentCard {
41 pub name: String,
43 pub description: String,
45 pub url: String,
47 pub version: String,
49 pub capabilities: Vec<String>,
51 pub input_modes: Vec<String>,
53 pub output_modes: Vec<String>,
55 pub authentication: Option<A2AAuth>,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
65pub enum A2ATaskStatus {
66 Pending,
68 Running,
70 Completed,
72 Failed(String),
74 Cancelled,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct A2ATask {
81 pub id: String,
83 pub status: A2ATaskStatus,
85 pub input: serde_json::Value,
87 pub output: Option<serde_json::Value>,
89 pub created_at: DateTime<Utc>,
91 pub updated_at: DateTime<Utc>,
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct A2ATaskInput {
98 pub prompt: String,
100 #[serde(default)]
102 pub context: serde_json::Map<String, serde_json::Value>,
103 #[serde(default = "default_mode")]
105 pub mode: String,
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct A2ATaskOutput {
111 pub content: String,
113 #[serde(default)]
115 pub data: Option<serde_json::Value>,
116 #[serde(default = "default_mode")]
118 pub mode: String,
119}
120
/// Serde default for the `mode` field of task inputs/outputs: the `"text"` mode.
fn default_mode() -> String {
    String::from("text")
}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct A2AMessage {
132 pub task_id: String,
134 pub role: String,
136 pub content: String,
138 pub timestamp: DateTime<Utc>,
140}
141
142#[async_trait]
151pub trait A2AClient: Send + Sync {
152 async fn discover(&self, url: &str) -> PunchResult<AgentCard>;
154
155 async fn send_task(&self, agent: &AgentCard, task: A2ATask) -> PunchResult<A2ATask>;
157
158 async fn get_task_status(&self, agent: &AgentCard, task_id: &str)
160 -> PunchResult<A2ATaskStatus>;
161
162 async fn cancel_task(&self, agent: &AgentCard, task_id: &str) -> PunchResult<()>;
164}
165
/// Request timeout used by `HttpA2AClient::new` (30 seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
172
173pub struct HttpA2AClient {
177 client: reqwest::Client,
178}
179
180impl HttpA2AClient {
181 pub fn new() -> PunchResult<Self> {
183 let client = reqwest::Client::builder()
184 .timeout(DEFAULT_TIMEOUT)
185 .build()
186 .map_err(|e| PunchError::Internal(format!("failed to build HTTP client: {e}")))?;
187 Ok(Self { client })
188 }
189
190 pub fn with_timeout(timeout: Duration) -> PunchResult<Self> {
192 let client = reqwest::Client::builder()
193 .timeout(timeout)
194 .build()
195 .map_err(|e| PunchError::Internal(format!("failed to build HTTP client: {e}")))?;
196 Ok(Self { client })
197 }
198
199 pub fn build_url(base_url: &str, path: &str) -> String {
201 let base = base_url.trim_end_matches('/');
202 format!("{base}{path}")
203 }
204}
205
206impl Default for HttpA2AClient {
207 fn default() -> Self {
208 Self::new().expect("failed to create default HttpA2AClient")
209 }
210}
211
212#[async_trait]
213impl A2AClient for HttpA2AClient {
214 async fn discover(&self, url: &str) -> PunchResult<AgentCard> {
216 let card_url = Self::build_url(url, "/.well-known/agent.json");
217 let resp = self
218 .client
219 .get(&card_url)
220 .send()
221 .await
222 .map_err(|e| PunchError::Internal(format!("A2A discover failed for {url}: {e}")))?;
223
224 if !resp.status().is_success() {
225 return Err(PunchError::Internal(format!(
226 "A2A discover returned {} for {card_url}",
227 resp.status()
228 )));
229 }
230
231 resp.json::<AgentCard>()
232 .await
233 .map_err(|e| PunchError::Internal(format!("A2A discover parse error: {e}")))
234 }
235
236 async fn send_task(&self, agent: &AgentCard, task: A2ATask) -> PunchResult<A2ATask> {
238 let url = Self::build_url(&agent.url, "/a2a/tasks/send");
239 let resp = self
240 .client
241 .post(&url)
242 .json(&task)
243 .send()
244 .await
245 .map_err(|e| {
246 PunchError::Internal(format!("A2A send_task failed for {}: {e}", agent.name))
247 })?;
248
249 if !resp.status().is_success() {
250 return Err(PunchError::Internal(format!(
251 "A2A send_task returned {} for {}",
252 resp.status(),
253 agent.name
254 )));
255 }
256
257 resp.json::<A2ATask>()
258 .await
259 .map_err(|e| PunchError::Internal(format!("A2A send_task parse error: {e}")))
260 }
261
262 async fn get_task_status(
264 &self,
265 agent: &AgentCard,
266 task_id: &str,
267 ) -> PunchResult<A2ATaskStatus> {
268 let url = Self::build_url(&agent.url, &format!("/a2a/tasks/{task_id}"));
269 let resp = self.client.get(&url).send().await.map_err(|e| {
270 PunchError::Internal(format!(
271 "A2A get_task_status failed for {}: {e}",
272 agent.name
273 ))
274 })?;
275
276 if !resp.status().is_success() {
277 return Err(PunchError::Internal(format!(
278 "A2A get_task_status returned {} for {}",
279 resp.status(),
280 agent.name
281 )));
282 }
283
284 let task = resp
285 .json::<A2ATask>()
286 .await
287 .map_err(|e| PunchError::Internal(format!("A2A get_task_status parse error: {e}")))?;
288
289 Ok(task.status)
290 }
291
292 async fn cancel_task(&self, agent: &AgentCard, task_id: &str) -> PunchResult<()> {
294 let url = Self::build_url(&agent.url, &format!("/a2a/tasks/{task_id}/cancel"));
295 let resp = self.client.post(&url).send().await.map_err(|e| {
296 PunchError::Internal(format!("A2A cancel_task failed for {}: {e}", agent.name))
297 })?;
298
299 if !resp.status().is_success() {
300 return Err(PunchError::Internal(format!(
301 "A2A cancel_task returned {} for {}",
302 resp.status(),
303 agent.name
304 )));
305 }
306
307 Ok(())
308 }
309}
310
311pub struct A2ARegistry {
320 agents: DashMap<String, AgentCard>,
321}
322
323impl A2ARegistry {
324 pub fn new() -> Self {
326 Self {
327 agents: DashMap::new(),
328 }
329 }
330
331 pub fn register(&self, card: AgentCard) {
333 self.agents.insert(card.name.clone(), card);
334 }
335
336 pub fn discover(&self, name: &str) -> Option<AgentCard> {
338 self.agents.get(name).map(|entry| entry.value().clone())
339 }
340
341 pub fn list(&self) -> Vec<AgentCard> {
343 self.agents
344 .iter()
345 .map(|entry| entry.value().clone())
346 .collect()
347 }
348
349 pub fn remove(&self, name: &str) -> bool {
351 self.agents.remove(name).is_some()
352 }
353
354 pub fn our_card(name: &str, url: &str, capabilities: Vec<String>) -> AgentCard {
356 AgentCard {
357 name: name.to_string(),
358 description: format!("Punch Agent: {name}"),
359 url: url.to_string(),
360 version: env!("CARGO_PKG_VERSION").to_string(),
361 capabilities,
362 input_modes: vec!["text".to_string(), "json".to_string()],
363 output_modes: vec!["text".to_string(), "json".to_string()],
364 authentication: Some(A2AAuth::None),
365 }
366 }
367}
368
369impl Default for A2ARegistry {
370 fn default() -> Self {
371 Self::new()
372 }
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal unauthenticated card whose description and URL are derived from `name`.
    fn sample_card(name: &str) -> AgentCard {
        AgentCard {
            name: name.to_string(),
            description: format!("Test agent {name}"),
            url: format!("http://localhost:8080/{name}"),
            version: "0.1.0".to_string(),
            capabilities: Vec::from(["code", "search"].map(String::from)),
            input_modes: vec!["text".to_string()],
            output_modes: vec!["text".to_string()],
            authentication: None,
        }
    }

    /// Registry pre-populated with one `sample_card` per entry of `names`.
    fn registry_with(names: &[&str]) -> A2ARegistry {
        let registry = A2ARegistry::new();
        for name in names {
            registry.register(sample_card(name));
        }
        registry
    }

    #[test]
    fn test_agent_card_creation() {
        let AgentCard {
            name,
            capabilities,
            authentication,
            ..
        } = sample_card("alpha");
        assert_eq!(name, "alpha");
        assert_eq!(capabilities.len(), 2);
        assert!(authentication.is_none());
    }

    #[test]
    fn test_registry_register_and_discover() {
        let registry = registry_with(&["boxer"]);
        let found = registry
            .discover("boxer")
            .expect("boxer should be registered");
        assert_eq!(found.name, "boxer");
    }

    #[test]
    fn test_registry_list() {
        let registry = registry_with(&["a", "b"]);
        assert_eq!(registry.list().len(), 2);
    }

    #[test]
    fn test_registry_remove() {
        let registry = registry_with(&["temp"]);
        assert!(registry.remove("temp"));
        assert!(registry.discover("temp").is_none());
    }

    #[test]
    fn test_registry_remove_nonexistent() {
        assert!(!A2ARegistry::new().remove("ghost"));
    }

    #[test]
    fn test_task_status_transitions() {
        let stamp = Utc::now();
        let mut task = A2ATask {
            id: "task-1".to_string(),
            status: A2ATaskStatus::Pending,
            input: serde_json::json!({"prompt": "hello"}),
            output: None,
            created_at: stamp,
            updated_at: stamp,
        };
        assert_eq!(task.status, A2ATaskStatus::Pending);

        for next in [A2ATaskStatus::Running, A2ATaskStatus::Completed] {
            task.status = next.clone();
            assert_eq!(task.status, next);
        }

        task.output = Some(serde_json::json!({"result": "done"}));
        assert_eq!(task.status, A2ATaskStatus::Completed);
        assert!(task.output.is_some());
    }

    #[test]
    fn test_our_card_generation() {
        let capabilities = vec!["coordination".to_string()];
        let card = A2ARegistry::our_card("punch-main", "http://localhost:3000", capabilities);
        assert_eq!(card.name, "punch-main");
        assert_eq!(card.url, "http://localhost:3000");
        assert!(card.input_modes.iter().any(|m| m == "text"));
        assert!(card.output_modes.iter().any(|m| m == "json"));
    }

    #[test]
    fn test_registry_count() {
        let registry = A2ARegistry::new();
        assert!(registry.list().is_empty());
        for name in ["one", "two", "three"] {
            registry.register(sample_card(name));
        }
        assert_eq!(registry.list().len(), 3);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let original = sample_card("serial");
        let text = serde_json::to_string(&original).expect("serialize");
        let restored: AgentCard = serde_json::from_str(&text).expect("deserialize");
        assert_eq!(restored.name, "serial");
        assert_eq!(restored.capabilities, original.capabilities);
    }

    #[test]
    fn test_unknown_agent_returns_none() {
        let registry = registry_with(&["known"]);
        assert!(registry.discover("unknown").is_none());
    }

    #[test]
    fn test_duplicate_registration_overwrites() {
        let registry = A2ARegistry::new();
        for description in ["first", "second"] {
            registry.register(AgentCard {
                description: description.to_string(),
                ..sample_card("dup")
            });
        }

        let found = registry.discover("dup").expect("should exist");
        assert_eq!(found.description, "second");
        assert_eq!(registry.list().len(), 1);
    }

    #[test]
    fn test_task_input_serialization() {
        let original = A2ATaskInput {
            prompt: "Summarize this code".to_string(),
            context: serde_json::Map::new(),
            mode: "text".to_string(),
        };
        let text = serde_json::to_string(&original).expect("serialize");
        let restored: A2ATaskInput = serde_json::from_str(&text).expect("deserialize");
        assert_eq!(restored.prompt, "Summarize this code");
        assert_eq!(restored.mode, "text");
    }

    #[test]
    fn test_task_output_serialization() {
        let original = A2ATaskOutput {
            content: "Here is the summary".to_string(),
            data: Some(serde_json::json!({"tokens": 42})),
            mode: "text".to_string(),
        };
        let text = serde_json::to_string(&original).expect("serialize");
        let restored: A2ATaskOutput = serde_json::from_str(&text).expect("deserialize");
        assert_eq!(restored.content, "Here is the summary");
        assert!(restored.data.is_some());
    }

    #[test]
    fn test_task_input_default_mode() {
        let raw = r#"{"prompt": "hello", "context": {}}"#;
        let parsed: A2ATaskInput = serde_json::from_str(raw).expect("deserialize");
        assert_eq!(parsed.mode, "text");
    }

    #[test]
    fn test_task_output_optional_data() {
        let original = A2ATaskOutput {
            content: "done".to_string(),
            data: None,
            mode: "json".to_string(),
        };
        let text = serde_json::to_string(&original).expect("serialize");
        assert!(text.contains("\"data\":null"));
    }

    #[test]
    fn test_http_client_url_construction() {
        let cases = [
            (
                "http://localhost:3000",
                "/.well-known/agent.json",
                "http://localhost:3000/.well-known/agent.json",
            ),
            (
                "http://localhost:3000/",
                "/a2a/tasks/send",
                "http://localhost:3000/a2a/tasks/send",
            ),
            (
                "https://agent.example.com",
                "/a2a/tasks/abc-123",
                "https://agent.example.com/a2a/tasks/abc-123",
            ),
        ];
        for (base, path, expected) in cases {
            assert_eq!(HttpA2AClient::build_url(base, path), expected);
        }
    }

    #[test]
    fn test_http_client_creation() {
        assert!(HttpA2AClient::new().is_ok());
    }

    #[test]
    fn test_http_client_custom_timeout() {
        assert!(HttpA2AClient::with_timeout(Duration::from_secs(5)).is_ok());
    }

    #[test]
    fn test_task_status_serialization_roundtrip() {
        let all = [
            A2ATaskStatus::Pending,
            A2ATaskStatus::Running,
            A2ATaskStatus::Completed,
            A2ATaskStatus::Failed("boom".to_string()),
            A2ATaskStatus::Cancelled,
        ];
        for expected in all {
            let text = serde_json::to_string(&expected).expect("serialize");
            let restored: A2ATaskStatus = serde_json::from_str(&text).expect("deserialize");
            assert_eq!(restored, expected);
        }
    }

    #[test]
    fn test_a2a_message_serialization() {
        let original = A2AMessage {
            task_id: "t1".to_string(),
            role: "agent".to_string(),
            content: "Working on it".to_string(),
            timestamp: Utc::now(),
        };
        let text = serde_json::to_string(&original).expect("serialize");
        let restored: A2AMessage = serde_json::from_str(&text).expect("deserialize");
        assert_eq!(restored.task_id, "t1");
        assert_eq!(restored.role, "agent");
    }

    #[test]
    fn test_agent_card_with_auth() {
        let card = AgentCard {
            name: "secure-agent".to_string(),
            description: "An authenticated agent".to_string(),
            url: "https://secure.example.com".to_string(),
            version: "1.0.0".to_string(),
            capabilities: vec!["code".to_string()],
            input_modes: vec!["text".to_string()],
            output_modes: vec!["text".to_string()],
            authentication: Some(A2AAuth::Bearer("token123".to_string())),
        };
        let text = serde_json::to_string(&card).expect("serialize");
        let restored: AgentCard = serde_json::from_str(&text).expect("deserialize");
        let token_ok = match restored.authentication {
            Some(A2AAuth::Bearer(ref token)) => token == "token123",
            _ => false,
        };
        assert!(token_ok);
    }
}