Skip to main content

punch_types/
a2a.rs

//! # Agent-to-Agent (A2A) Protocol
//!
//! Inter-system communication protocol for agent discovery and task delegation.
//! Think of it like a fight card — each agent publishes its card so others know
//! its weight class, special moves, and how to reach it in the ring.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use dashmap::DashMap;
10use serde::{Deserialize, Serialize};
11use std::time::Duration;
12
13use crate::error::{PunchError, PunchResult};
14
15// ---------------------------------------------------------------------------
16// Authentication
17// ---------------------------------------------------------------------------
18
19/// Authentication method for reaching a remote agent.
20#[derive(Debug, Clone, Serialize, Deserialize)]
21#[serde(tag = "type", content = "value")]
22pub enum A2AAuth {
23    /// Bearer token authentication.
24    Bearer(String),
25    /// API key authentication.
26    ApiKey(String),
27    /// No authentication required.
28    None,
29}
30
31// ---------------------------------------------------------------------------
32// Agent Card (fight card)
33// ---------------------------------------------------------------------------
34
35/// An agent's public identity card — its fight card.
36///
37/// Published so other agents can discover capabilities, supported I/O modes,
38/// and how to send tasks to this fighter.
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct AgentCard {
41    /// Human-readable name of the agent.
42    pub name: String,
43    /// Description of what this agent does.
44    pub description: String,
45    /// The URL where this agent can be reached.
46    pub url: String,
47    /// Semantic version of the agent.
48    pub version: String,
49    /// List of capability identifiers (e.g. "code_review", "web_search").
50    pub capabilities: Vec<String>,
51    /// Supported input modes (e.g. "text", "json", "image").
52    pub input_modes: Vec<String>,
53    /// Supported output modes (e.g. "text", "json", "markdown").
54    pub output_modes: Vec<String>,
55    /// Optional authentication details for reaching this agent.
56    pub authentication: Option<A2AAuth>,
57}
58
59// ---------------------------------------------------------------------------
60// Task types
61// ---------------------------------------------------------------------------
62
63/// Status of an A2A task as it moves through the pipeline.
64#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
65pub enum A2ATaskStatus {
66    /// Task is queued but not yet started.
67    Pending,
68    /// Task is actively being processed.
69    Running,
70    /// Task finished successfully.
71    Completed,
72    /// Task failed with an error message.
73    Failed(String),
74    /// Task was cancelled before completion.
75    Cancelled,
76}
77
78/// A task sent from one agent to another.
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct A2ATask {
81    /// Unique identifier for this task.
82    pub id: String,
83    /// Current status.
84    pub status: A2ATaskStatus,
85    /// Input payload (JSON).
86    pub input: serde_json::Value,
87    /// Output payload, populated on completion.
88    pub output: Option<serde_json::Value>,
89    /// When the task was created.
90    pub created_at: DateTime<Utc>,
91    /// When the task was last updated.
92    pub updated_at: DateTime<Utc>,
93}
94
95/// Structured input payload for an A2A task.
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct A2ATaskInput {
98    /// The prompt or instruction to execute.
99    pub prompt: String,
100    /// Optional additional context as key-value pairs.
101    #[serde(default)]
102    pub context: serde_json::Map<String, serde_json::Value>,
103    /// Input mode (e.g. "text", "json").
104    #[serde(default = "default_mode")]
105    pub mode: String,
106}
107
108/// Structured output payload from an A2A task.
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct A2ATaskOutput {
111    /// The result content.
112    pub content: String,
113    /// Optional structured data.
114    #[serde(default)]
115    pub data: Option<serde_json::Value>,
116    /// Output mode (e.g. "text", "json").
117    #[serde(default = "default_mode")]
118    pub mode: String,
119}
120
/// Serde default for the `mode` fields of task input/output: plain text.
fn default_mode() -> String {
    String::from("text")
}
124
125// ---------------------------------------------------------------------------
126// Messages
127// ---------------------------------------------------------------------------
128
129/// A message exchanged during an A2A task.
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct A2AMessage {
132    /// The task this message belongs to.
133    pub task_id: String,
134    /// Role of the sender (e.g. "user", "agent").
135    pub role: String,
136    /// The message content.
137    pub content: String,
138    /// When the message was sent.
139    pub timestamp: DateTime<Utc>,
140}
141
142// ---------------------------------------------------------------------------
143// Client trait
144// ---------------------------------------------------------------------------
145
146/// Client interface for A2A protocol operations.
147///
148/// Implementations handle the transport layer (HTTP, gRPC, etc.) to communicate
149/// with remote agents.
150#[async_trait]
151pub trait A2AClient: Send + Sync {
152    /// Discover a remote agent by fetching its card from a URL.
153    async fn discover(&self, url: &str) -> PunchResult<AgentCard>;
154
155    /// Send a task to a remote agent for execution.
156    async fn send_task(&self, agent: &AgentCard, task: A2ATask) -> PunchResult<A2ATask>;
157
158    /// Poll the status of a previously submitted task.
159    async fn get_task_status(&self, agent: &AgentCard, task_id: &str)
160    -> PunchResult<A2ATaskStatus>;
161
162    /// Cancel a running task on a remote agent.
163    async fn cancel_task(&self, agent: &AgentCard, task_id: &str) -> PunchResult<()>;
164}
165
166// ---------------------------------------------------------------------------
167// HTTP Client Implementation
168// ---------------------------------------------------------------------------
169
/// Request timeout used by `HttpA2AClient::new` — thirty seconds.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
172
173/// HTTP-based implementation of the A2A client protocol.
174///
175/// Uses reqwest to make real HTTP calls to remote agents' A2A endpoints.
176pub struct HttpA2AClient {
177    client: reqwest::Client,
178}
179
180impl HttpA2AClient {
181    /// Create a new HTTP A2A client with the default 30-second timeout.
182    pub fn new() -> PunchResult<Self> {
183        let client = reqwest::Client::builder()
184            .timeout(DEFAULT_TIMEOUT)
185            .build()
186            .map_err(|e| PunchError::Internal(format!("failed to build HTTP client: {e}")))?;
187        Ok(Self { client })
188    }
189
190    /// Create a new HTTP A2A client with a custom timeout.
191    pub fn with_timeout(timeout: Duration) -> PunchResult<Self> {
192        let client = reqwest::Client::builder()
193            .timeout(timeout)
194            .build()
195            .map_err(|e| PunchError::Internal(format!("failed to build HTTP client: {e}")))?;
196        Ok(Self { client })
197    }
198
199    /// Build the full URL for an A2A endpoint on a remote agent.
200    pub fn build_url(base_url: &str, path: &str) -> String {
201        let base = base_url.trim_end_matches('/');
202        format!("{base}{path}")
203    }
204}
205
206impl Default for HttpA2AClient {
207    fn default() -> Self {
208        Self::new().expect("failed to create default HttpA2AClient")
209    }
210}
211
212#[async_trait]
213impl A2AClient for HttpA2AClient {
214    /// Fetch a remote agent's card from its well-known URL.
215    async fn discover(&self, url: &str) -> PunchResult<AgentCard> {
216        let card_url = Self::build_url(url, "/.well-known/agent.json");
217        let resp = self
218            .client
219            .get(&card_url)
220            .send()
221            .await
222            .map_err(|e| PunchError::Internal(format!("A2A discover failed for {url}: {e}")))?;
223
224        if !resp.status().is_success() {
225            return Err(PunchError::Internal(format!(
226                "A2A discover returned {} for {card_url}",
227                resp.status()
228            )));
229        }
230
231        resp.json::<AgentCard>()
232            .await
233            .map_err(|e| PunchError::Internal(format!("A2A discover parse error: {e}")))
234    }
235
236    /// Send a task to a remote agent via HTTP POST.
237    async fn send_task(&self, agent: &AgentCard, task: A2ATask) -> PunchResult<A2ATask> {
238        let url = Self::build_url(&agent.url, "/a2a/tasks/send");
239        let resp = self
240            .client
241            .post(&url)
242            .json(&task)
243            .send()
244            .await
245            .map_err(|e| {
246                PunchError::Internal(format!("A2A send_task failed for {}: {e}", agent.name))
247            })?;
248
249        if !resp.status().is_success() {
250            return Err(PunchError::Internal(format!(
251                "A2A send_task returned {} for {}",
252                resp.status(),
253                agent.name
254            )));
255        }
256
257        resp.json::<A2ATask>()
258            .await
259            .map_err(|e| PunchError::Internal(format!("A2A send_task parse error: {e}")))
260    }
261
262    /// Get the status of a task on a remote agent via HTTP GET.
263    async fn get_task_status(
264        &self,
265        agent: &AgentCard,
266        task_id: &str,
267    ) -> PunchResult<A2ATaskStatus> {
268        let url = Self::build_url(&agent.url, &format!("/a2a/tasks/{task_id}"));
269        let resp = self.client.get(&url).send().await.map_err(|e| {
270            PunchError::Internal(format!(
271                "A2A get_task_status failed for {}: {e}",
272                agent.name
273            ))
274        })?;
275
276        if !resp.status().is_success() {
277            return Err(PunchError::Internal(format!(
278                "A2A get_task_status returned {} for {}",
279                resp.status(),
280                agent.name
281            )));
282        }
283
284        let task = resp
285            .json::<A2ATask>()
286            .await
287            .map_err(|e| PunchError::Internal(format!("A2A get_task_status parse error: {e}")))?;
288
289        Ok(task.status)
290    }
291
292    /// Cancel a task on a remote agent via HTTP POST.
293    async fn cancel_task(&self, agent: &AgentCard, task_id: &str) -> PunchResult<()> {
294        let url = Self::build_url(&agent.url, &format!("/a2a/tasks/{task_id}/cancel"));
295        let resp = self.client.post(&url).send().await.map_err(|e| {
296            PunchError::Internal(format!("A2A cancel_task failed for {}: {e}", agent.name))
297        })?;
298
299        if !resp.status().is_success() {
300            return Err(PunchError::Internal(format!(
301                "A2A cancel_task returned {} for {}",
302                resp.status(),
303                agent.name
304            )));
305        }
306
307        Ok(())
308    }
309}
310
311// ---------------------------------------------------------------------------
312// Registry
313// ---------------------------------------------------------------------------
314
315/// Thread-safe registry of known agent cards.
316///
317/// Acts as the fight roster — keeps track of all agents that have checked in
318/// so we can discover and delegate to them.
319pub struct A2ARegistry {
320    agents: DashMap<String, AgentCard>,
321}
322
323impl A2ARegistry {
324    /// Create a new empty registry.
325    pub fn new() -> Self {
326        Self {
327            agents: DashMap::new(),
328        }
329    }
330
331    /// Register an agent card (or overwrite an existing one with the same name).
332    pub fn register(&self, card: AgentCard) {
333        self.agents.insert(card.name.clone(), card);
334    }
335
336    /// Discover an agent by name.
337    pub fn discover(&self, name: &str) -> Option<AgentCard> {
338        self.agents.get(name).map(|entry| entry.value().clone())
339    }
340
341    /// List all registered agent cards.
342    pub fn list(&self) -> Vec<AgentCard> {
343        self.agents
344            .iter()
345            .map(|entry| entry.value().clone())
346            .collect()
347    }
348
349    /// Remove an agent from the registry. Returns `true` if the agent was found.
350    pub fn remove(&self, name: &str) -> bool {
351        self.agents.remove(name).is_some()
352    }
353
354    /// Generate our own agent card — the fight card we publish to the world.
355    pub fn our_card(name: &str, url: &str, capabilities: Vec<String>) -> AgentCard {
356        AgentCard {
357            name: name.to_string(),
358            description: format!("Punch Agent: {name}"),
359            url: url.to_string(),
360            version: env!("CARGO_PKG_VERSION").to_string(),
361            capabilities,
362            input_modes: vec!["text".to_string(), "json".to_string()],
363            output_modes: vec!["text".to_string(), "json".to_string()],
364            authentication: Some(A2AAuth::None),
365        }
366    }
367}
368
369impl Default for A2ARegistry {
370    fn default() -> Self {
371        Self::new()
372    }
373}
374
375// ---------------------------------------------------------------------------
376// Tests
377// ---------------------------------------------------------------------------
378
#[cfg(test)]
mod tests {
    use super::*;

    /// Convert a slice of string literals into owned strings.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| (*item).to_string()).collect()
    }

    /// A minimal, unauthenticated card whose description and URL derive from `name`.
    fn sample_card(name: &str) -> AgentCard {
        AgentCard {
            name: name.to_owned(),
            description: format!("Test agent {name}"),
            url: format!("http://localhost:8080/{name}"),
            version: String::from("0.1.0"),
            capabilities: strings(&["code", "search"]),
            input_modes: strings(&["text"]),
            output_modes: strings(&["text"]),
            authentication: None,
        }
    }

    /// A freshly created pending task with a trivial prompt.
    fn pending_task(id: &str) -> A2ATask {
        let now = Utc::now();
        A2ATask {
            id: id.to_owned(),
            status: A2ATaskStatus::Pending,
            input: serde_json::json!({"prompt": "hello"}),
            output: None,
            created_at: now,
            updated_at: now,
        }
    }

    #[test]
    fn test_agent_card_creation() {
        let AgentCard {
            name,
            capabilities,
            authentication,
            ..
        } = sample_card("alpha");
        assert_eq!(name, "alpha");
        assert_eq!(capabilities.len(), 2);
        assert!(authentication.is_none());
    }

    #[test]
    fn test_registry_register_and_discover() {
        let registry = A2ARegistry::new();
        registry.register(sample_card("boxer"));
        let found_name = registry.discover("boxer").map(|card| card.name);
        assert_eq!(found_name.as_deref(), Some("boxer"));
    }

    #[test]
    fn test_registry_list() {
        let registry = A2ARegistry::new();
        for name in ["a", "b"] {
            registry.register(sample_card(name));
        }
        assert_eq!(registry.list().len(), 2);
    }

    #[test]
    fn test_registry_remove() {
        let registry = A2ARegistry::new();
        registry.register(sample_card("temp"));
        let removed = registry.remove("temp");
        assert!(removed);
        assert!(registry.discover("temp").is_none());
    }

    #[test]
    fn test_registry_remove_nonexistent() {
        let removed = A2ARegistry::new().remove("ghost");
        assert!(!removed);
    }

    #[test]
    fn test_task_status_transitions() {
        let mut task = pending_task("task-1");
        assert_eq!(task.status, A2ATaskStatus::Pending);

        for next in [A2ATaskStatus::Running, A2ATaskStatus::Completed] {
            task.status = next.clone();
            assert_eq!(task.status, next);
        }

        task.output = Some(serde_json::json!({"result": "done"}));
        assert_eq!(task.status, A2ATaskStatus::Completed);
        assert!(task.output.is_some());
    }

    #[test]
    fn test_our_card_generation() {
        let card = A2ARegistry::our_card(
            "punch-main",
            "http://localhost:3000",
            strings(&["coordination"]),
        );
        assert_eq!(card.name, "punch-main");
        assert_eq!(card.url, "http://localhost:3000");
        assert!(card.input_modes.iter().any(|mode| mode == "text"));
        assert!(card.output_modes.iter().any(|mode| mode == "json"));
    }

    #[test]
    fn test_registry_count() {
        let registry = A2ARegistry::new();
        assert!(registry.list().is_empty());
        for name in ["one", "two", "three"] {
            registry.register(sample_card(name));
        }
        assert_eq!(registry.list().len(), 3);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let original = sample_card("serial");
        let encoded = serde_json::to_string(&original).expect("serialize");
        let decoded: AgentCard = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.name, "serial");
        assert_eq!(decoded.capabilities, original.capabilities);
    }

    #[test]
    fn test_unknown_agent_returns_none() {
        let registry = A2ARegistry::new();
        registry.register(sample_card("known"));
        assert!(registry.discover("unknown").is_none());
    }

    #[test]
    fn test_duplicate_registration_overwrites() {
        let registry = A2ARegistry::new();
        for description in ["first", "second"] {
            registry.register(AgentCard {
                description: description.to_owned(),
                ..sample_card("dup")
            });
        }

        let found = registry.discover("dup").expect("should exist");
        assert_eq!(found.description, "second");
        assert_eq!(registry.list().len(), 1);
    }

    #[test]
    fn test_task_input_serialization() {
        let input = A2ATaskInput {
            prompt: String::from("Summarize this code"),
            context: serde_json::Map::new(),
            mode: String::from("text"),
        };
        let encoded = serde_json::to_string(&input).expect("serialize");
        let decoded: A2ATaskInput = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.prompt, "Summarize this code");
        assert_eq!(decoded.mode, "text");
    }

    #[test]
    fn test_task_output_serialization() {
        let output = A2ATaskOutput {
            content: String::from("Here is the summary"),
            data: Some(serde_json::json!({"tokens": 42})),
            mode: String::from("text"),
        };
        let encoded = serde_json::to_string(&output).expect("serialize");
        let decoded: A2ATaskOutput = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.content, "Here is the summary");
        assert!(decoded.data.is_some());
    }

    #[test]
    fn test_task_input_default_mode() {
        let input: A2ATaskInput =
            serde_json::from_str(r#"{"prompt": "hello", "context": {}}"#).expect("deserialize");
        assert_eq!(input.mode, "text");
    }

    #[test]
    fn test_task_output_optional_data() {
        let output = A2ATaskOutput {
            content: String::from("done"),
            data: None,
            mode: String::from("json"),
        };
        let encoded = serde_json::to_string(&output).expect("serialize");
        assert!(encoded.contains("\"data\":null"));
    }

    #[test]
    fn test_http_client_url_construction() {
        let cases = [
            (
                "http://localhost:3000",
                "/.well-known/agent.json",
                "http://localhost:3000/.well-known/agent.json",
            ),
            (
                "http://localhost:3000/",
                "/a2a/tasks/send",
                "http://localhost:3000/a2a/tasks/send",
            ),
            (
                "https://agent.example.com",
                "/a2a/tasks/abc-123",
                "https://agent.example.com/a2a/tasks/abc-123",
            ),
        ];
        for (base, path, expected) in cases {
            assert_eq!(HttpA2AClient::build_url(base, path), expected);
        }
    }

    #[test]
    fn test_http_client_creation() {
        assert!(HttpA2AClient::new().is_ok());
    }

    #[test]
    fn test_http_client_custom_timeout() {
        assert!(HttpA2AClient::with_timeout(Duration::from_secs(5)).is_ok());
    }

    #[test]
    fn test_task_status_serialization_roundtrip() {
        let statuses = [
            A2ATaskStatus::Pending,
            A2ATaskStatus::Running,
            A2ATaskStatus::Completed,
            A2ATaskStatus::Failed(String::from("boom")),
            A2ATaskStatus::Cancelled,
        ];
        for status in statuses {
            let encoded = serde_json::to_string(&status).expect("serialize");
            let decoded: A2ATaskStatus = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(decoded, status);
        }
    }

    #[test]
    fn test_a2a_message_serialization() {
        let message = A2AMessage {
            task_id: String::from("t1"),
            role: String::from("agent"),
            content: String::from("Working on it"),
            timestamp: Utc::now(),
        };
        let encoded = serde_json::to_string(&message).expect("serialize");
        let decoded: A2AMessage = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.task_id, "t1");
        assert_eq!(decoded.role, "agent");
    }

    #[test]
    fn test_agent_card_with_auth() {
        let card = AgentCard {
            name: String::from("secure-agent"),
            description: String::from("An authenticated agent"),
            url: String::from("https://secure.example.com"),
            version: String::from("1.0.0"),
            capabilities: strings(&["code"]),
            input_modes: strings(&["text"]),
            output_modes: strings(&["text"]),
            authentication: Some(A2AAuth::Bearer(String::from("token123"))),
        };
        let encoded = serde_json::to_string(&card).expect("serialize");
        let decoded: AgentCard = serde_json::from_str(&encoded).expect("deserialize");
        match decoded.authentication {
            Some(A2AAuth::Bearer(token)) => assert_eq!(token, "token123"),
            other => panic!("expected bearer authentication, got {other:?}"),
        }
    }
}