1use std::collections::HashMap;
8use std::sync::Arc;
9
10use anyhow::Result;
11use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13use tokio::sync::RwLock;
14use uuid::Uuid;
15
16use crate::event_bus::{EventBus, KernelEvent};
17use crate::types::{AgentId, AgentStatus};
18
/// A message exchanged between agents over the A2A protocol.
///
/// Serialized as an internally tagged object: the `snake_case` variant name is
/// stored under the `"type"` key next to the variant's fields
/// (e.g. `{"type": "status_update", "task_id": ..., "progress": ..., "message": ...}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum A2AMessage {
    /// Asks the recipient to perform a task.
    TaskDelegation {
        /// Id of the delegated task.
        task_id: Uuid,
        /// Human-readable description of the task.
        description: String,
        /// Task-specific input data.
        payload: serde_json::Value,
        /// Priority requested by the sender.
        priority: TaskPriority,
    },
    /// Progress report for a task.
    StatusUpdate {
        /// Id of the task being reported on.
        task_id: Uuid,
        /// Progress value; presumably a percentage (0-100) — not validated here.
        progress: u8,
        /// Free-form status text.
        message: String,
    },
    /// Carries a task's result back to another agent.
    ResultSharing {
        /// Id of the task whose result is shared.
        task_id: Uuid,
        /// Result data.
        result: serde_json::Value,
        /// Human-readable summary of the result.
        summary: String,
    },
    /// Capability lookup request; how it is answered is up to the receiver.
    CapabilityQuery {
        /// Free-form query text.
        query: String,
        /// Capability tags the requester is looking for.
        required_capabilities: Vec<String>,
    },
    /// Introduces an agent and advertises its capabilities.
    Handshake {
        /// Id of the introducing agent.
        agent_id: AgentId,
        /// Display name of the introducing agent.
        name: String,
        /// Capability tags the introducing agent offers.
        capabilities: Vec<String>,
    },
}
69
70impl A2AMessage {
71 pub fn type_name(&self) -> &'static str {
73 match self {
74 A2AMessage::TaskDelegation { .. } => "task_delegation",
75 A2AMessage::StatusUpdate { .. } => "status_update",
76 A2AMessage::ResultSharing { .. } => "result_sharing",
77 A2AMessage::CapabilityQuery { .. } => "capability_query",
78 A2AMessage::Handshake { .. } => "handshake",
79 }
80 }
81}
82
83#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
85pub enum TaskPriority {
86 Low,
88 #[default]
90 Normal,
91 High,
93 Critical,
95}
96
/// A unit of work that one agent delegates to another.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskSpec {
    /// Unique task id (a random v4 UUID when built via [`TaskSpec::new`]).
    pub task_id: Uuid,
    /// Human-readable summary of the task.
    pub description: String,
    /// Task-specific input data.
    pub payload: serde_json::Value,
    /// Requested priority; [`TaskSpec::new`] uses `TaskPriority::default()`.
    pub priority: TaskPriority,
    /// Optional completion deadline (UTC); `None` unless set with `with_deadline`.
    pub deadline: Option<DateTime<Utc>>,
}
111
112impl TaskSpec {
113 pub fn new(description: impl Into<String>, payload: serde_json::Value) -> Self {
115 Self {
116 task_id: Uuid::new_v4(),
117 description: description.into(),
118 payload,
119 priority: TaskPriority::default(),
120 deadline: None,
121 }
122 }
123
124 pub fn with_priority(mut self, priority: TaskPriority) -> Self {
126 self.priority = priority;
127 self
128 }
129
130 pub fn with_deadline(mut self, deadline: DateTime<Utc>) -> Self {
132 self.deadline = Some(deadline);
133 self
134 }
135}
136
/// Envelope for an [`A2AMessage`] sent from one agent to another.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct A2ARequest {
    /// Unique id of this request (a random v4 UUID when built via [`A2ARequest::new`]).
    pub request_id: Uuid,
    /// Sending agent.
    pub from: AgentId,
    /// Receiving agent.
    pub to: AgentId,
    /// The message being delivered.
    pub message: A2AMessage,
    /// Creation time (UTC); [`A2ARequest::new`] sets it to `Utc::now()`.
    pub timestamp: DateTime<Utc>,
}
151
152impl A2ARequest {
153 pub fn new(from: AgentId, to: AgentId, message: A2AMessage) -> Self {
155 Self {
156 request_id: Uuid::new_v4(),
157 from,
158 to,
159 message,
160 timestamp: Utc::now(),
161 }
162 }
163}
164
/// Reply to an [`A2ARequest`], linked to it by `request_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct A2AResponse {
    /// Unique id of this response.
    pub response_id: Uuid,
    /// Id of the request being answered.
    pub request_id: Uuid,
    /// Responding agent.
    pub from: AgentId,
    /// Agent the response is addressed to.
    pub to: AgentId,
    /// `true` when built by [`A2AResponse::success`], `false` for [`A2AResponse::error`].
    pub accepted: bool,
    /// Result data; for errors an object of the form `{"error": "<message>"}`.
    pub payload: serde_json::Value,
    /// Creation time (UTC).
    pub timestamp: DateTime<Utc>,
}
183
184impl A2AResponse {
185 pub fn success(
187 request_id: Uuid,
188 from: AgentId,
189 to: AgentId,
190 payload: serde_json::Value,
191 ) -> Self {
192 Self {
193 response_id: Uuid::new_v4(),
194 request_id,
195 from,
196 to,
197 accepted: true,
198 payload,
199 timestamp: Utc::now(),
200 }
201 }
202
203 pub fn error(request_id: Uuid, from: AgentId, to: AgentId, error: impl Into<String>) -> Self {
205 Self {
206 response_id: Uuid::new_v4(),
207 request_id,
208 from,
209 to,
210 accepted: false,
211 payload: serde_json::json!({ "error": error.into() }),
212 timestamp: Utc::now(),
213 }
214 }
215}
216
/// A request waiting in an agent's inbox, together with when it was enqueued.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PendingMessage {
    /// The queued request.
    pub request: A2ARequest,
    /// Time (UTC) at which the request was wrapped for queueing.
    pub queued_at: DateTime<Utc>,
}
225
226impl PendingMessage {
227 fn new(request: A2ARequest) -> Self {
228 Self {
229 request,
230 queued_at: Utc::now(),
231 }
232 }
233}
234
/// Self-description an agent publishes to the [`AgentCardRegistry`], which can
/// look agents up by capability or skill.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentCard {
    /// Id of the described agent (also its key in the registry).
    pub agent_id: AgentId,
    /// Display name.
    pub name: String,
    /// Free-form description.
    pub description: String,
    /// Capability tags, matched exactly by [`AgentCard::has_capability`].
    pub capabilities: Vec<String>,
    /// Skill tags, matched exactly by [`AgentCard::has_skill`].
    pub skills: Vec<String>,
    /// Where the agent can be reached; [`AgentCard::new`] sets `"local"`.
    pub endpoint: String,
    /// Lifecycle status; [`AgentCard::new`] sets `AgentStatus::Starting`.
    pub status: AgentStatus,
}
256
257impl AgentCard {
258 pub fn new(agent_id: AgentId, name: impl Into<String>, description: impl Into<String>) -> Self {
260 Self {
261 agent_id,
262 name: name.into(),
263 description: description.into(),
264 capabilities: Vec::new(),
265 skills: Vec::new(),
266 endpoint: "local".into(),
267 status: AgentStatus::Starting,
268 }
269 }
270
271 pub fn with_capability(mut self, capability: impl Into<String>) -> Self {
273 self.capabilities.push(capability.into());
274 self
275 }
276
277 pub fn with_skill(mut self, skill: impl Into<String>) -> Self {
279 self.skills.push(skill.into());
280 self
281 }
282
283 pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
285 self.endpoint = endpoint.into();
286 self
287 }
288
289 pub fn with_status(mut self, status: AgentStatus) -> Self {
291 self.status = status;
292 self
293 }
294
295 pub fn has_capability(&self, capability: &str) -> bool {
297 self.capabilities.iter().any(|c| c == capability)
298 }
299
300 pub fn has_skill(&self, skill: &str) -> bool {
302 self.skills.iter().any(|s| s == skill)
303 }
304}
305
/// Directory of [`AgentCard`]s keyed by agent id, with lifecycle events published
/// on an [`EventBus`].
///
/// The card map lives behind an `Arc`, so clones of the registry share it.
#[derive(Clone)]
pub struct AgentCardRegistry {
    // Registered cards keyed by agent id (tokio `RwLock`, safe to hold across `.await`).
    cards: Arc<RwLock<HashMap<AgentId, AgentCard>>>,
    // Receives `AgentCreated` / `AgentStopped` events on (un)registration.
    event_bus: EventBus,
}
317
318impl AgentCardRegistry {
319 pub fn new(event_bus: EventBus) -> Self {
321 Self {
322 cards: Arc::new(RwLock::new(HashMap::new())),
323 event_bus,
324 }
325 }
326
327 pub async fn register_agent(&self, card: AgentCard) -> Result<()> {
329 let agent_id = card.agent_id;
330 let mut cards = self.cards.write().await;
331 cards.insert(agent_id, card.clone());
332 drop(cards);
333
334 self.event_bus.publish(KernelEvent::AgentCreated {
335 id: agent_id,
336 name: card.name.clone(),
337 })?;
338
339 tracing::info!(agent_id = %agent_id, name = %card.name, "Agent registered in A2A registry");
340 Ok(())
341 }
342
343 pub async fn unregister_agent(&self, agent_id: AgentId) -> Result<()> {
345 let mut cards = self.cards.write().await;
346 if let Some(card) = cards.remove(&agent_id) {
347 tracing::info!(agent_id = %agent_id, name = %card.name, "Agent unregistered from A2A registry");
348 drop(cards);
349
350 self.event_bus
351 .publish(KernelEvent::AgentStopped { id: agent_id })?;
352 }
353 Ok(())
354 }
355
356 pub async fn find_agents_by_capability(&self, capability: &str) -> Result<Vec<AgentCard>> {
358 let cards = self.cards.read().await;
359 let matches: Vec<AgentCard> = cards
360 .values()
361 .filter(|card| card.has_capability(capability))
362 .cloned()
363 .collect();
364 Ok(matches)
365 }
366
367 pub async fn find_agents_by_skill(&self, skill: &str) -> Result<Vec<AgentCard>> {
369 let cards = self.cards.read().await;
370 let matches: Vec<AgentCard> = cards
371 .values()
372 .filter(|card| card.has_skill(skill))
373 .cloned()
374 .collect();
375 Ok(matches)
376 }
377
378 pub async fn get_agent(&self, agent_id: AgentId) -> Option<AgentCard> {
380 let cards = self.cards.read().await;
381 cards.get(&agent_id).cloned()
382 }
383
384 pub async fn list_agents(&self) -> Vec<AgentCard> {
386 let cards = self.cards.read().await;
387 cards.values().cloned().collect()
388 }
389
390 pub async fn agent_count(&self) -> usize {
392 let cards = self.cards.read().await;
393 cards.len()
394 }
395
396 pub async fn update_status(&self, agent_id: AgentId, status: AgentStatus) -> Result<()> {
398 let mut cards = self.cards.write().await;
399 if let Some(card) = cards.get_mut(&agent_id) {
400 card.status = status;
401 }
402 Ok(())
403 }
404}
405
406impl std::fmt::Debug for AgentCardRegistry {
407 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408 f.debug_struct("AgentCardRegistry").finish()
409 }
410}
411
412struct AgentQueue {
417 messages: parking_lot::Mutex<Vec<PendingMessage>>,
419 notify: tokio::sync::Notify,
421}
422
423impl AgentQueue {
424 fn new() -> Self {
425 Self {
426 messages: parking_lot::Mutex::new(Vec::new()),
427 notify: tokio::sync::Notify::new(),
428 }
429 }
430}
431
/// Async callback that executes a delegated task.
///
/// Called by `A2AProtocol::execute_delegation` with `(from, to, task)`; the returned
/// boxed `Send` future resolves to the task's JSON result or an error. The callback
/// itself must be `Send + Sync` because it is shared behind an `Arc`.
pub type DelegationHandler = Arc<
    dyn Fn(
            AgentId,
            AgentId,
            TaskSpec,
        )
            -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send>>
        + Send
        + Sync,
>;
446
/// In-process agent-to-agent messaging hub.
///
/// Combines an [`AgentCardRegistry`], one inbox per agent, and an optional
/// [`DelegationHandler`]. The inbox map and handler slot live behind `Arc`s and the
/// registry shares its card map across clones, so clones of this value share state.
#[derive(Clone)]
pub struct A2AProtocol {
    // Directory of agent cards.
    registry: AgentCardRegistry,
    // Per-agent inboxes, created lazily by `get_or_create_queue`.
    queues: Arc<RwLock<HashMap<AgentId, Arc<AgentQueue>>>>,
    // Receives `MessageReceived` events for sent/delegated messages.
    event_bus: EventBus,
    // Handler used by `execute_delegation`; `None` until `set_delegation_handler`.
    delegation_handler: Arc<RwLock<Option<DelegationHandler>>>,
}
459
460impl A2AProtocol {
461 pub fn new(event_bus: EventBus) -> Self {
463 let registry = AgentCardRegistry::new(event_bus.clone());
464 Self {
465 registry,
466 queues: Arc::new(RwLock::new(HashMap::new())),
467 event_bus,
468 delegation_handler: Arc::new(RwLock::new(None)),
469 }
470 }
471
472 pub async fn set_delegation_handler(&self, handler: DelegationHandler) {
478 let mut h = self.delegation_handler.write().await;
479 *h = Some(handler);
480 }
481
482 async fn get_or_create_queue(&self, agent_id: AgentId) -> Arc<AgentQueue> {
484 let mut queues = self.queues.write().await;
485 queues
486 .entry(agent_id)
487 .or_insert_with(|| Arc::new(AgentQueue::new()))
488 .clone()
489 }
490
491 pub fn registry(&self) -> &AgentCardRegistry {
493 &self.registry
494 }
495
496 pub async fn execute_delegation(
503 &self,
504 from: AgentId,
505 to: AgentId,
506 task: TaskSpec,
507 ) -> Option<Result<serde_json::Value>> {
508 let handler = self.delegation_handler.read().await;
509 let handler_ref = handler.as_ref()?;
510
511 let _ = self.event_bus.publish(KernelEvent::MessageReceived {
513 from,
514 content: format!("[task_delegation] {:?}", task.task_id),
515 });
516
517 tracing::info!(
518 from = %from,
519 to = %to,
520 task_id = %task.task_id,
521 "A2A execute_delegation: starting"
522 );
523
524 let result = handler_ref(from, to, task).await;
525
526 tracing::info!(
527 from = %from,
528 to = %to,
529 success = result.is_ok(),
530 "A2A execute_delegation: completed"
531 );
532
533 Some(result)
534 }
535
536 pub async fn send_message(
538 &self,
539 from: AgentId,
540 to: AgentId,
541 message: A2AMessage,
542 ) -> Result<Uuid> {
543 let msg_type = message.type_name();
544 let request = A2ARequest::new(from, to, message);
545 let request_id = request.request_id;
546
547 let queue = self.get_or_create_queue(to).await;
549 queue
550 .messages
551 .lock()
552 .push(PendingMessage::new(request.clone()));
553 queue.notify.notify_one();
554
555 self.event_bus.publish(KernelEvent::MessageReceived {
556 from,
557 content: format!("[{}] {:?}", msg_type, request_id),
558 })?;
559
560 tracing::debug!(
561 from = %from,
562 to = %to,
563 request_id = %request_id,
564 msg_type,
565 "A2A message sent"
566 );
567
568 Ok(request_id)
569 }
570
571 pub async fn delegate_task(&self, from: AgentId, to: AgentId, task: TaskSpec) -> Result<Uuid> {
573 let message = A2AMessage::TaskDelegation {
574 task_id: task.task_id,
575 description: task.description.clone(),
576 payload: task.payload.clone(),
577 priority: task.priority,
578 };
579
580 self.send_message(from, to, message).await
581 }
582
583 pub async fn send_status_update(
585 &self,
586 from: AgentId,
587 to: AgentId,
588 task_id: Uuid,
589 progress: u8,
590 message: String,
591 ) -> Result<Uuid> {
592 let message = A2AMessage::StatusUpdate {
593 task_id,
594 progress,
595 message,
596 };
597
598 self.send_message(from, to, message).await
599 }
600
601 pub async fn share_result(
603 &self,
604 from: AgentId,
605 to: AgentId,
606 task_id: Uuid,
607 result: serde_json::Value,
608 summary: String,
609 ) -> Result<Uuid> {
610 let message = A2AMessage::ResultSharing {
611 task_id,
612 result,
613 summary,
614 };
615
616 self.send_message(from, to, message).await
617 }
618
619 pub async fn query_capabilities(&self, capability: &str) -> Result<Vec<AgentCard>> {
621 self.registry.find_agents_by_capability(capability).await
622 }
623
624 pub async fn send_handshake(&self, from: AgentId, to: AgentId) -> Result<Uuid> {
626 let card = self.registry.get_agent(from).await;
627
628 let (name, capabilities) = if let Some(card) = card {
629 (card.name, card.capabilities.clone())
630 } else {
631 ("unknown".into(), Vec::new())
632 };
633
634 let message = A2AMessage::Handshake {
635 agent_id: from,
636 name,
637 capabilities,
638 };
639
640 self.send_message(from, to, message).await
641 }
642
643 pub async fn receive_messages(&self, agent_id: AgentId) -> Vec<A2ARequest> {
645 let queues = self.queues.read().await;
646 if let Some(queue) = queues.get(&agent_id) {
647 let drained: Vec<PendingMessage> = queue.messages.lock().drain(..).collect();
648 drained.into_iter().map(|m| m.request).collect()
649 } else {
650 Vec::new()
651 }
652 }
653
654 pub async fn pending_count(&self, agent_id: AgentId) -> usize {
656 let queues = self.queues.read().await;
657 queues
658 .get(&agent_id)
659 .map(|q| q.messages.lock().len())
660 .unwrap_or(0)
661 }
662
663 pub async fn has_messages(&self, agent_id: AgentId) -> bool {
665 self.pending_count(agent_id).await > 0
666 }
667
668 pub async fn deliver_pending_messages(&self, agent_id: AgentId) -> Result<Vec<A2ARequest>> {
674 Ok(self.receive_messages(agent_id).await)
675 }
676
677 pub async fn send_and_wait(
685 &self,
686 from: AgentId,
687 to: AgentId,
688 message: A2AMessage,
689 timeout: std::time::Duration,
690 ) -> Result<A2AResponse> {
691 let wait_task_id = match &message {
693 A2AMessage::TaskDelegation { task_id, .. } => Some(*task_id),
694 _ => None,
695 };
696
697 let request_id = self.send_message(from, to, message).await?;
698 let queue = self.get_or_create_queue(from).await;
699 let deadline = tokio::time::Instant::now() + timeout;
700
701 loop {
702 {
704 let mut msgs = queue.messages.lock();
705 let match_idx = msgs.iter().position(|p| {
706 match (&p.request.message, wait_task_id) {
707 (A2AMessage::ResultSharing { task_id, .. }, Some(wait_id)) => {
709 *task_id == wait_id
710 }
711 (A2AMessage::ResultSharing { result, .. }, None) => {
713 result.get("request_id").and_then(|v| v.as_str())
714 == Some(&request_id.to_string())
715 }
716 _ => false,
717 }
718 });
719 if let Some(idx) = match_idx {
720 let matched = msgs.remove(idx);
721 if let A2AMessage::ResultSharing { result, .. } = matched.request.message {
722 return Ok(A2AResponse::success(request_id, to, from, result));
723 }
724 }
725 }
726
727 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
729 if remaining.is_zero() {
730 anyhow::bail!("A2A response timeout after {:?}", timeout);
731 }
732
733 tokio::select! {
734 _ = queue.notify.notified() => {
735 }
737 _ = tokio::time::sleep(remaining) => {
738 anyhow::bail!("A2A response timeout after {:?}", timeout);
739 }
740 }
741 }
742 }
743}
744
745impl std::fmt::Debug for A2AProtocol {
746 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
747 f.debug_struct("A2AProtocol")
748 .field("registry", &self.registry)
749 .finish()
750 }
751}
752
#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_event_bus() -> EventBus {
        EventBus::new(256)
    }

    fn create_test_agent_id() -> AgentId {
        Uuid::new_v4()
    }

    // Purely synchronous checks need no async runtime, so plain `#[test]` suffices.
    #[test]
    fn test_agent_card_creation() {
        let agent_id = create_test_agent_id();
        let card = AgentCard::new(agent_id, "test-agent", "A test agent")
            .with_capability("code-review")
            .with_capability("lint")
            .with_skill("rust")
            .with_endpoint("local");

        assert_eq!(card.agent_id, agent_id);
        assert_eq!(card.name, "test-agent");
        assert!(card.has_capability("code-review"));
        assert!(card.has_capability("lint"));
        assert!(!card.has_capability("refactor"));
        assert!(card.has_skill("rust"));
        assert!(!card.has_skill("python"));
    }

    #[test]
    fn test_task_spec_builders() {
        let deadline = Utc::now();
        let task = TaskSpec::new("Review PR", serde_json::json!({ "pr": 42 }))
            .with_priority(TaskPriority::High)
            .with_deadline(deadline);

        assert_eq!(task.description, "Review PR");
        assert_eq!(task.priority, TaskPriority::High);
        assert_eq!(task.deadline, Some(deadline));

        let defaults = TaskSpec::new("noop", serde_json::Value::Null);
        assert_eq!(defaults.priority, TaskPriority::Normal);
        assert!(defaults.deadline.is_none());
    }

    #[test]
    fn test_type_name_matches_serde_tag() {
        let message = A2AMessage::StatusUpdate {
            task_id: Uuid::new_v4(),
            progress: 50,
            message: "halfway".into(),
        };
        let json = serde_json::to_value(&message).unwrap();
        assert_eq!(json["type"], message.type_name());
        assert_eq!(message.type_name(), "status_update");
    }

    #[test]
    fn test_error_response_payload() {
        let from = create_test_agent_id();
        let to = create_test_agent_id();
        let response = A2AResponse::error(Uuid::new_v4(), from, to, "boom");

        assert!(!response.accepted);
        assert_eq!(response.from, from);
        assert_eq!(response.to, to);
        assert_eq!(response.payload, serde_json::json!({ "error": "boom" }));
    }

    #[tokio::test]
    async fn test_registry_register_unregister() {
        let bus = create_test_event_bus();
        let registry = AgentCardRegistry::new(bus);

        let agent_id = create_test_agent_id();
        let card = AgentCard::new(agent_id, "register-test", "Test agent").with_capability("test");

        registry.register_agent(card.clone()).await.unwrap();
        assert_eq!(registry.agent_count().await, 1);

        let found = registry.get_agent(agent_id).await;
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "register-test");

        registry.unregister_agent(agent_id).await.unwrap();
        assert_eq!(registry.agent_count().await, 0);

        let found = registry.get_agent(agent_id).await;
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn test_registry_find_by_capability() {
        let bus = create_test_event_bus();
        let registry = AgentCardRegistry::new(bus);

        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();

        registry
            .register_agent(
                AgentCard::new(id1, "agent-1", "First agent").with_capability("code-review"),
            )
            .await
            .unwrap();

        registry
            .register_agent(
                AgentCard::new(id2, "agent-2", "Second agent")
                    .with_capability("code-review")
                    .with_capability("refactor"),
            )
            .await
            .unwrap();

        let reviewers = registry
            .find_agents_by_capability("code-review")
            .await
            .unwrap();
        assert_eq!(reviewers.len(), 2);

        let refactorers = registry
            .find_agents_by_capability("refactor")
            .await
            .unwrap();
        assert_eq!(refactorers.len(), 1);
        assert_eq!(refactorers[0].agent_id, id2);
    }

    #[tokio::test]
    async fn test_a2a_protocol_send_receive() {
        let bus = create_test_event_bus();
        let a2a = A2AProtocol::new(bus);

        let from = create_test_agent_id();
        let to = create_test_agent_id();

        let message = A2AMessage::Handshake {
            agent_id: from,
            name: "sender".into(),
            capabilities: vec!["test".into()],
        };

        a2a.send_message(from, to, message).await.unwrap();
        assert_eq!(a2a.pending_count(to).await, 1);

        let messages = a2a.receive_messages(to).await;
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].from, from);
        assert_eq!(messages[0].to, to);
        assert_eq!(a2a.pending_count(to).await, 0);
    }

    #[tokio::test]
    async fn test_delegate_task() {
        let bus = create_test_event_bus();
        let a2a = A2AProtocol::new(bus);

        let from = create_test_agent_id();
        let to = create_test_agent_id();

        let task = TaskSpec::new("Review PR", serde_json::json!({ "pr": 42 }));
        let task_id = task.task_id;

        let request_id = a2a.delegate_task(from, to, task).await.unwrap();
        assert_ne!(request_id, Uuid::nil());

        let messages = a2a.receive_messages(to).await;
        assert_eq!(messages.len(), 1);
        match &messages[0].message {
            A2AMessage::TaskDelegation {
                task_id: id,
                payload,
                priority,
                ..
            } => {
                assert_eq!(*id, task_id);
                assert_eq!(payload, &serde_json::json!({ "pr": 42 }));
                assert_eq!(*priority, TaskPriority::Normal);
            }
            other => panic!("expected TaskDelegation, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_send_and_wait_returns_matching_result() {
        let a2a = A2AProtocol::new(create_test_event_bus());
        let from = create_test_agent_id();
        let to = create_test_agent_id();
        let task_id = Uuid::new_v4();

        // Queue the reply up front so the call completes without a second task.
        a2a.share_result(
            to,
            from,
            task_id,
            serde_json::json!({ "answer": 42 }),
            "done".into(),
        )
        .await
        .unwrap();

        let message = A2AMessage::TaskDelegation {
            task_id,
            description: "answer".into(),
            payload: serde_json::Value::Null,
            priority: TaskPriority::Normal,
        };
        let response = a2a
            .send_and_wait(from, to, message, std::time::Duration::from_secs(1))
            .await
            .unwrap();

        assert!(response.accepted);
        assert_eq!(response.from, to);
        assert_eq!(response.to, from);
        assert_eq!(response.payload, serde_json::json!({ "answer": 42 }));
        // The reply was consumed; the delegation itself is still queued for `to`.
        assert_eq!(a2a.pending_count(from).await, 0);
        assert_eq!(a2a.pending_count(to).await, 1);
    }

    #[tokio::test]
    async fn test_send_and_wait_times_out_without_reply() {
        let a2a = A2AProtocol::new(create_test_event_bus());
        let from = create_test_agent_id();
        let to = create_test_agent_id();

        let message = A2AMessage::CapabilityQuery {
            query: "anyone?".into(),
            required_capabilities: Vec::new(),
        };
        let outcome = a2a
            .send_and_wait(from, to, message, std::time::Duration::from_millis(20))
            .await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_execute_delegation_without_handler_returns_none() {
        let a2a = A2AProtocol::new(create_test_event_bus());
        let task = TaskSpec::new("noop", serde_json::Value::Null);

        let outcome = a2a
            .execute_delegation(create_test_agent_id(), create_test_agent_id(), task)
            .await;

        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn test_execute_delegation_invokes_handler() {
        let a2a = A2AProtocol::new(create_test_event_bus());
        let handler: DelegationHandler = Arc::new(|_from: AgentId, _to: AgentId, task: TaskSpec| {
            Box::pin(async move {
                Ok::<_, anyhow::Error>(serde_json::json!({ "echo": task.description }))
            })
                as std::pin::Pin<
                    Box<
                        dyn std::future::Future<Output = anyhow::Result<serde_json::Value>>
                            + Send,
                    >,
                >
        });
        a2a.set_delegation_handler(handler).await;

        let task = TaskSpec::new("ping", serde_json::Value::Null);
        let value = a2a
            .execute_delegation(create_test_agent_id(), create_test_agent_id(), task)
            .await
            .expect("a handler is installed")
            .unwrap();

        assert_eq!(value, serde_json::json!({ "echo": "ping" }));
    }
}