Skip to main content

punch_kernel/
agent_messaging.rs

1//! # Inter-Agent Messaging
2//!
3//! Rich messaging between fighters using tokio channels.
4//! Supports direct, broadcast, multicast, request-response, and streaming patterns.
5
6use chrono::Utc;
7use dashmap::DashMap;
8use std::time::Duration;
9use tokio::sync::{mpsc, oneshot};
10use tracing::warn;
11use uuid::Uuid;
12
13use punch_types::{
14    AgentMessage, AgentMessageType, FighterId, MessageChannel, MessagePriority, PunchError,
15    PunchResult,
16};
17
/// Number of messages a single fighter's mailbox can buffer before sends fail.
const DEFAULT_MAILBOX_CAPACITY: usize = 256;

/// Upper bound on retained dead letters; once exceeded, the oldest are evicted.
const MAX_DEAD_LETTERS: usize = 1_000;
23
24/// The messaging router handles delivery of inter-agent messages.
25pub struct MessageRouter {
26    /// Active mailboxes keyed by fighter ID.
27    mailboxes: DashMap<FighterId, mpsc::Sender<AgentMessage>>,
28    /// Receivers waiting to be claimed (fighter_id -> receiver).
29    /// Using a DashMap with Option to allow one-time take.
30    pending_receivers: DashMap<FighterId, mpsc::Receiver<AgentMessage>>,
31    /// Dead letter queue for undeliverable messages.
32    dead_letters: DashMap<u64, AgentMessage>,
33    /// Counter for dead letter keys.
34    dead_letter_counter: std::sync::atomic::AtomicU64,
35    /// Pending request-response callbacks.
36    pending_requests: DashMap<Uuid, oneshot::Sender<AgentMessage>>,
37}
38
39impl MessageRouter {
40    /// Create a new message router.
41    pub fn new() -> Self {
42        Self {
43            mailboxes: DashMap::new(),
44            pending_receivers: DashMap::new(),
45            dead_letters: DashMap::new(),
46            dead_letter_counter: std::sync::atomic::AtomicU64::new(0),
47            pending_requests: DashMap::new(),
48        }
49    }
50
51    /// Register a fighter's mailbox. Returns a receiver for the fighter to
52    /// consume messages from.
53    pub fn register(&self, fighter_id: FighterId) -> mpsc::Receiver<AgentMessage> {
54        let (tx, rx) = mpsc::channel(DEFAULT_MAILBOX_CAPACITY);
55        self.mailboxes.insert(fighter_id, tx);
56        rx
57    }
58
59    /// Unregister a fighter's mailbox.
60    pub fn unregister(&self, fighter_id: &FighterId) {
61        self.mailboxes.remove(fighter_id);
62        self.pending_receivers.remove(fighter_id);
63    }
64
65    /// Check if a fighter has a registered mailbox.
66    pub fn is_registered(&self, fighter_id: &FighterId) -> bool {
67        self.mailboxes.contains_key(fighter_id)
68    }
69
70    /// Send a direct message from one fighter to another.
71    pub async fn send_direct(
72        &self,
73        from: FighterId,
74        to: FighterId,
75        content: AgentMessageType,
76        priority: MessagePriority,
77    ) -> PunchResult<Uuid> {
78        let msg = AgentMessage {
79            id: Uuid::new_v4(),
80            from,
81            to,
82            channel: MessageChannel::Direct,
83            content,
84            priority,
85            timestamp: Utc::now(),
86            delivered: false,
87        };
88
89        self.deliver(msg).await
90    }
91
92    /// Broadcast a message to all registered fighters (except the sender).
93    pub async fn broadcast(
94        &self,
95        from: FighterId,
96        content: AgentMessageType,
97        priority: MessagePriority,
98    ) -> PunchResult<Vec<Uuid>> {
99        let targets: Vec<FighterId> = self
100            .mailboxes
101            .iter()
102            .map(|entry| *entry.key())
103            .filter(|id| *id != from)
104            .collect();
105
106        let mut ids = Vec::new();
107        for target in targets {
108            let msg = AgentMessage {
109                id: Uuid::new_v4(),
110                from,
111                to: target,
112                channel: MessageChannel::Broadcast,
113                content: content.clone(),
114                priority,
115                timestamp: Utc::now(),
116                delivered: false,
117            };
118            match self.deliver(msg).await {
119                Ok(id) => ids.push(id),
120                Err(e) => warn!(target = %target, error = %e, "broadcast delivery failed"),
121            }
122        }
123
124        Ok(ids)
125    }
126
127    /// Multicast a message to a specific set of fighters.
128    pub async fn multicast(
129        &self,
130        from: FighterId,
131        targets: Vec<FighterId>,
132        content: AgentMessageType,
133        priority: MessagePriority,
134    ) -> PunchResult<Vec<Uuid>> {
135        let mut ids = Vec::new();
136        for target in &targets {
137            let msg = AgentMessage {
138                id: Uuid::new_v4(),
139                from,
140                to: *target,
141                channel: MessageChannel::Multicast(targets.clone()),
142                content: content.clone(),
143                priority,
144                timestamp: Utc::now(),
145                delivered: false,
146            };
147            match self.deliver(msg).await {
148                Ok(id) => ids.push(id),
149                Err(e) => warn!(target = %target, error = %e, "multicast delivery failed"),
150            }
151        }
152
153        Ok(ids)
154    }
155
156    /// Send a request and wait for a response with timeout.
157    ///
158    /// Returns the response message on success, or a timeout error.
159    pub async fn request(
160        &self,
161        from: FighterId,
162        to: FighterId,
163        content: AgentMessageType,
164        timeout: Duration,
165    ) -> PunchResult<AgentMessage> {
166        let msg_id = Uuid::new_v4();
167        let (resp_tx, resp_rx) = oneshot::channel();
168
169        self.pending_requests.insert(msg_id, resp_tx);
170
171        let msg = AgentMessage {
172            id: msg_id,
173            from,
174            to,
175            channel: MessageChannel::Request {
176                timeout_ms: timeout.as_millis() as u64,
177            },
178            content,
179            priority: MessagePriority::High,
180            timestamp: Utc::now(),
181            delivered: false,
182        };
183
184        self.deliver(msg).await?;
185
186        match tokio::time::timeout(timeout, resp_rx).await {
187            Ok(Ok(response)) => Ok(response),
188            Ok(Err(_)) => {
189                self.pending_requests.remove(&msg_id);
190                Err(PunchError::Internal(
191                    "request channel closed before response".to_string(),
192                ))
193            }
194            Err(_) => {
195                self.pending_requests.remove(&msg_id);
196                Err(PunchError::Internal(format!(
197                    "request timed out after {}ms",
198                    timeout.as_millis()
199                )))
200            }
201        }
202    }
203
204    /// Respond to a request message.
205    pub fn respond(
206        &self,
207        original_msg_id: &Uuid,
208        response: AgentMessage,
209    ) -> PunchResult<()> {
210        let (_, tx) = self
211            .pending_requests
212            .remove(original_msg_id)
213            .ok_or_else(|| {
214                PunchError::Internal(format!(
215                    "no pending request for message {}",
216                    original_msg_id
217                ))
218            })?;
219
220        tx.send(response).map_err(|_| {
221            PunchError::Internal("failed to send response: requester dropped".to_string())
222        })
223    }
224
225    /// Internal delivery to a fighter's mailbox.
226    async fn deliver(&self, msg: AgentMessage) -> PunchResult<Uuid> {
227        let msg_id = msg.id;
228        let target = msg.to;
229
230        if let Some(tx) = self.mailboxes.get(&target) {
231            match tx.try_send(msg) {
232                Ok(()) => Ok(msg_id),
233                Err(mpsc::error::TrySendError::Full(returned_msg)) => {
234                    warn!(to = %target, "mailbox full, message queued as dead letter");
235                    self.add_dead_letter(returned_msg);
236                    Err(PunchError::Internal(format!(
237                        "mailbox full for fighter {}",
238                        target
239                    )))
240                }
241                Err(mpsc::error::TrySendError::Closed(returned_msg)) => {
242                    warn!(to = %target, "mailbox closed, message queued as dead letter");
243                    self.add_dead_letter(returned_msg);
244                    Err(PunchError::Internal(format!(
245                        "mailbox closed for fighter {}",
246                        target
247                    )))
248                }
249            }
250        } else {
251            self.add_dead_letter(msg);
252            Err(PunchError::Internal(format!(
253                "no mailbox registered for fighter {}",
254                target
255            )))
256        }
257    }
258
259    /// Add a message to the dead letter queue.
260    fn add_dead_letter(&self, msg: AgentMessage) {
261        let key = self
262            .dead_letter_counter
263            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
264        self.dead_letters.insert(key, msg);
265
266        // Prune oldest if over limit.
267        while self.dead_letters.len() > MAX_DEAD_LETTERS {
268            // Remove the smallest key (oldest).
269            if let Some(oldest) = self.dead_letters.iter().map(|e| *e.key()).min() {
270                self.dead_letters.remove(&oldest);
271            } else {
272                break;
273            }
274        }
275    }
276
277    /// Get the count of dead letters.
278    pub fn dead_letter_count(&self) -> usize {
279        self.dead_letters.len()
280    }
281
282    /// Drain all dead letters.
283    pub fn drain_dead_letters(&self) -> Vec<AgentMessage> {
284        let keys: Vec<u64> = self.dead_letters.iter().map(|e| *e.key()).collect();
285        let mut messages = Vec::new();
286        for key in keys {
287            if let Some((_, msg)) = self.dead_letters.remove(&key) {
288                messages.push(msg);
289            }
290        }
291        messages
292    }
293
294    /// Get the number of registered mailboxes.
295    pub fn registered_count(&self) -> usize {
296        self.mailboxes.len()
297    }
298}
299
300impl Default for MessageRouter {
301    fn default() -> Self {
302        Self::new()
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_register_and_receive() {
        let router = MessageRouter::new();
        let f1 = FighterId::new();
        let f2 = FighterId::new();
        let mut rx1 = router.register(f1);
        let _rx2 = router.register(f2);

        let msg_id = router
            .send_direct(
                f2,
                f1,
                AgentMessageType::StatusUpdate {
                    progress: 1.0,
                    detail: "done".to_string(),
                },
                MessagePriority::Normal,
            )
            .await
            .expect("should deliver");

        let received = rx1.recv().await.expect("should receive");
        assert_eq!(received.id, msg_id);
        assert_eq!(received.from, f2);
    }

    #[tokio::test]
    async fn test_broadcast() {
        let router = MessageRouter::new();
        let sender = FighterId::new();
        let r1 = FighterId::new();
        let r2 = FighterId::new();
        let _sender_rx = router.register(sender);
        let mut rx1 = router.register(r1);
        let mut rx2 = router.register(r2);

        let ids = router
            .broadcast(
                sender,
                AgentMessageType::StatusUpdate {
                    progress: 0.5,
                    detail: "update".to_string(),
                },
                MessagePriority::Normal,
            )
            .await
            .expect("should broadcast");

        assert_eq!(ids.len(), 2);

        let m1 = rx1.recv().await.expect("should receive");
        let m2 = rx2.recv().await.expect("should receive");
        assert_eq!(m1.from, sender);
        assert_eq!(m2.from, sender);
    }

    #[tokio::test]
    async fn test_multicast() {
        let router = MessageRouter::new();
        let sender = FighterId::new();
        let t1 = FighterId::new();
        let t2 = FighterId::new();
        let t3 = FighterId::new();
        let _sr = router.register(sender);
        let mut rx1 = router.register(t1);
        let mut rx2 = router.register(t2);
        let _rx3 = router.register(t3);

        let ids = router
            .multicast(
                sender,
                vec![t1, t2],
                AgentMessageType::TaskAssignment {
                    task: "work".to_string(),
                },
                MessagePriority::High,
            )
            .await
            .expect("should multicast");

        assert_eq!(ids.len(), 2);

        let m1 = rx1.recv().await.expect("r1 should receive");
        let m2 = rx2.recv().await.expect("r2 should receive");
        assert_eq!(m1.from, sender);
        assert_eq!(m2.from, sender);
    }

    #[tokio::test]
    async fn test_multicast_skips_unregistered_target() {
        let router = MessageRouter::new();
        let sender = FighterId::new();
        let t1 = FighterId::new();
        let ghost = FighterId::new();
        let _sr = router.register(sender);
        let mut rx1 = router.register(t1);

        let ids = router
            .multicast(
                sender,
                vec![t1, ghost],
                AgentMessageType::TaskAssignment {
                    task: "work".to_string(),
                },
                MessagePriority::High,
            )
            .await
            .expect("multicast is best-effort and should not fail as a whole");

        // Only the registered target is delivered; the other copy is dead-lettered.
        assert_eq!(ids.len(), 1);
        assert_eq!(router.dead_letter_count(), 1);

        let m1 = rx1.recv().await.expect("t1 should receive");
        assert_eq!(m1.id, ids[0]);
        assert!(matches!(
            &m1.channel,
            MessageChannel::Multicast(targets) if targets.len() == 2
        ));
    }

    #[tokio::test]
    async fn test_request_response() {
        let router = std::sync::Arc::new(MessageRouter::new());
        let requester = FighterId::new();
        let responder = FighterId::new();
        let _req_rx = router.register(requester);
        let mut resp_rx = router.register(responder);

        let router_clone = std::sync::Arc::clone(&router);

        // Spawn responder task. The fighter IDs are `Copy`, so the `move`
        // closure captures copies and the originals stay usable below.
        tokio::spawn(async move {
            if let Some(msg) = resp_rx.recv().await {
                let response = AgentMessage {
                    id: Uuid::new_v4(),
                    from: responder,
                    to: requester,
                    channel: MessageChannel::Direct,
                    content: AgentMessageType::TaskResult {
                        result: "42".to_string(),
                        success: true,
                    },
                    priority: MessagePriority::Normal,
                    timestamp: Utc::now(),
                    delivered: false,
                };
                router_clone
                    .respond(&msg.id, response)
                    .expect("request should still be pending");
            }
        });

        let result = router
            .request(
                requester,
                responder,
                AgentMessageType::TaskAssignment {
                    task: "compute".to_string(),
                },
                Duration::from_secs(5),
            )
            .await
            .expect("should get response");

        match &result.content {
            AgentMessageType::TaskResult { result, success } => {
                assert_eq!(result, "42");
                assert!(success);
            }
            _ => panic!("wrong response type"),
        }
    }

    #[tokio::test]
    async fn test_request_timeout() {
        let router = MessageRouter::new();
        let requester = FighterId::new();
        let responder = FighterId::new();
        let _req_rx = router.register(requester);
        let _resp_rx = router.register(responder);

        // Don't spawn a responder, so this will timeout.
        let result = router
            .request(
                requester,
                responder,
                AgentMessageType::TaskAssignment {
                    task: "compute".to_string(),
                },
                Duration::from_millis(50),
            )
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("timed out"));
    }

    #[tokio::test]
    async fn test_request_to_unregistered_fails_without_waiting() {
        let router = MessageRouter::new();
        let requester = FighterId::new();
        let _req_rx = router.register(requester);

        // Delivery failure must surface immediately, not after the request timeout.
        let result = tokio::time::timeout(
            Duration::from_secs(1),
            router.request(
                requester,
                FighterId::new(),
                AgentMessageType::TaskAssignment {
                    task: "compute".to_string(),
                },
                Duration::from_secs(30),
            ),
        )
        .await
        .expect("delivery failure should return immediately");

        assert!(result.is_err());
        assert_eq!(router.dead_letter_count(), 1);
    }

    #[tokio::test]
    async fn test_dead_letter_on_unregistered() {
        let router = MessageRouter::new();
        let f1 = FighterId::new();
        let f2 = FighterId::new();
        let _rx = router.register(f1);

        // f2 is not registered; message should become dead letter.
        let result = router
            .send_direct(
                f1,
                f2,
                AgentMessageType::StatusUpdate {
                    progress: 0.0,
                    detail: "test".to_string(),
                },
                MessagePriority::Low,
            )
            .await;

        assert!(result.is_err());
        assert_eq!(router.dead_letter_count(), 1);
    }

    #[tokio::test]
    async fn test_dead_letter_on_closed_mailbox() {
        let router = MessageRouter::new();
        let f1 = FighterId::new();
        let f2 = FighterId::new();
        let rx1 = router.register(f1);
        let _rx2 = router.register(f2);

        // Dropping the receiver closes f1's mailbox while it stays registered.
        drop(rx1);

        let result = router
            .send_direct(
                f2,
                f1,
                AgentMessageType::StatusUpdate {
                    progress: 0.0,
                    detail: "closed".to_string(),
                },
                MessagePriority::Low,
            )
            .await;

        assert!(result.is_err());
        assert_eq!(router.dead_letter_count(), 1);
    }

    #[tokio::test]
    async fn test_dead_letter_on_full_mailbox() {
        let router = MessageRouter::new();
        let f1 = FighterId::new();
        let f2 = FighterId::new();
        let _rx1 = router.register(f1);
        let _rx2 = router.register(f2);

        for i in 0..DEFAULT_MAILBOX_CAPACITY {
            router
                .send_direct(
                    f2,
                    f1,
                    AgentMessageType::StatusUpdate {
                        progress: 0.0,
                        detail: format!("fill {i}"),
                    },
                    MessagePriority::Low,
                )
                .await
                .expect("mailbox should still have room");
        }

        let overflow = router
            .send_direct(
                f2,
                f1,
                AgentMessageType::StatusUpdate {
                    progress: 0.0,
                    detail: "overflow".to_string(),
                },
                MessagePriority::Low,
            )
            .await;

        assert!(overflow.is_err());
        assert_eq!(router.dead_letter_count(), 1);
    }

    #[tokio::test]
    async fn test_dead_letter_queue_is_bounded() {
        let router = MessageRouter::new();
        let sender = FighterId::new();
        let ghost = FighterId::new();

        for _ in 0..MAX_DEAD_LETTERS + 5 {
            let _ = router
                .send_direct(
                    sender,
                    ghost,
                    AgentMessageType::StatusUpdate {
                        progress: 0.0,
                        detail: "lost".to_string(),
                    },
                    MessagePriority::Low,
                )
                .await;
        }

        assert_eq!(router.dead_letter_count(), MAX_DEAD_LETTERS);
    }

    #[tokio::test]
    async fn test_drain_dead_letters() {
        let router = MessageRouter::new();
        let f1 = FighterId::new();
        let f2 = FighterId::new();
        let _rx = router.register(f1);

        let _ = router
            .send_direct(
                f1,
                f2,
                AgentMessageType::StatusUpdate {
                    progress: 0.0,
                    detail: "dead".to_string(),
                },
                MessagePriority::Low,
            )
            .await;

        let letters = router.drain_dead_letters();
        assert_eq!(letters.len(), 1);
        assert_eq!(router.dead_letter_count(), 0);
    }

    #[test]
    fn test_unregister() {
        let router = MessageRouter::new();
        let f = FighterId::new();
        let mut rx = router.register(f);
        assert!(router.is_registered(&f));
        router.unregister(&f);
        assert!(!router.is_registered(&f));
        // The router held the only sender, so the mailbox is now closed.
        assert!(matches!(
            rx.try_recv(),
            Err(mpsc::error::TryRecvError::Disconnected)
        ));
    }

    #[test]
    fn test_registered_count() {
        let router = MessageRouter::new();
        assert_eq!(router.registered_count(), 0);
        let f1 = FighterId::new();
        let f2 = FighterId::new();
        let _rx1 = router.register(f1);
        let _rx2 = router.register(f2);
        assert_eq!(router.registered_count(), 2);
    }

    #[tokio::test]
    async fn test_broadcast_excludes_sender() {
        let router = MessageRouter::new();
        let sender = FighterId::new();
        let mut sender_rx = router.register(sender);

        let ids = router
            .broadcast(
                sender,
                AgentMessageType::StatusUpdate {
                    progress: 1.0,
                    detail: "done".to_string(),
                },
                MessagePriority::Normal,
            )
            .await
            .expect("should broadcast");

        // No recipients besides sender, who is excluded.
        assert!(ids.is_empty());

        // Delivery is synchronous (`try_send`), so anything sent would already
        // be queued. An empty-but-open mailbox proves nothing was delivered,
        // without relying on a timing-based wait.
        assert!(matches!(
            sender_rx.try_recv(),
            Err(mpsc::error::TryRecvError::Empty)
        ));
    }

    #[test]
    fn test_default_impl() {
        let router = MessageRouter::default();
        assert_eq!(router.registered_count(), 0);
    }

    #[tokio::test]
    async fn test_message_priority_preserved() {
        let router = MessageRouter::new();
        let f1 = FighterId::new();
        let f2 = FighterId::new();
        let mut rx = router.register(f1);
        let _rx2 = router.register(f2);

        router
            .send_direct(
                f2,
                f1,
                AgentMessageType::Escalation {
                    reason: "urgent".to_string(),
                    original_task: "task".to_string(),
                },
                MessagePriority::Critical,
            )
            .await
            .expect("should deliver");

        let msg = rx.recv().await.expect("should receive");
        assert_eq!(msg.priority, MessagePriority::Critical);
    }

    #[tokio::test]
    async fn test_respond_to_nonexistent_request() {
        let router = MessageRouter::new();
        let response = AgentMessage {
            id: Uuid::new_v4(),
            from: FighterId::new(),
            to: FighterId::new(),
            channel: MessageChannel::Direct,
            content: AgentMessageType::TaskResult {
                result: "nope".to_string(),
                success: false,
            },
            priority: MessagePriority::Normal,
            timestamp: Utc::now(),
            delivered: false,
        };

        let result = router.respond(&Uuid::new_v4(), response);
        assert!(result.is_err());
    }
}