Skip to main content

punch_kernel/
agent_messaging.rs

1//! # Inter-Agent Messaging
2//!
3//! Rich messaging between fighters using tokio channels.
4//! Supports direct, broadcast, multicast, request-response, and streaming patterns.
5
6use chrono::Utc;
7use dashmap::DashMap;
8use std::time::Duration;
9use tokio::sync::{mpsc, oneshot};
10use tracing::warn;
11use uuid::Uuid;
12
13use punch_types::{
14    AgentMessage, AgentMessageType, FighterId, MessageChannel, MessagePriority, PunchError,
15    PunchResult,
16};
17
/// Bound used for every fighter mailbox channel created by `register`.
const DEFAULT_MAILBOX_CAPACITY: usize = 256;

/// Upper limit on retained dead letters; once exceeded, the entries with the
/// smallest (oldest) keys are pruned.
const MAX_DEAD_LETTERS: usize = 1_000;
23
24/// The messaging router handles delivery of inter-agent messages.
25pub struct MessageRouter {
26    /// Active mailboxes keyed by fighter ID.
27    mailboxes: DashMap<FighterId, mpsc::Sender<AgentMessage>>,
28    /// Receivers waiting to be claimed (fighter_id -> receiver).
29    /// Using a DashMap with Option to allow one-time take.
30    pending_receivers: DashMap<FighterId, mpsc::Receiver<AgentMessage>>,
31    /// Dead letter queue for undeliverable messages.
32    dead_letters: DashMap<u64, AgentMessage>,
33    /// Counter for dead letter keys.
34    dead_letter_counter: std::sync::atomic::AtomicU64,
35    /// Pending request-response callbacks.
36    pending_requests: DashMap<Uuid, oneshot::Sender<AgentMessage>>,
37}
38
39impl MessageRouter {
40    /// Create a new message router.
41    pub fn new() -> Self {
42        Self {
43            mailboxes: DashMap::new(),
44            pending_receivers: DashMap::new(),
45            dead_letters: DashMap::new(),
46            dead_letter_counter: std::sync::atomic::AtomicU64::new(0),
47            pending_requests: DashMap::new(),
48        }
49    }
50
51    /// Register a fighter's mailbox. Returns a receiver for the fighter to
52    /// consume messages from.
53    pub fn register(&self, fighter_id: FighterId) -> mpsc::Receiver<AgentMessage> {
54        let (tx, rx) = mpsc::channel(DEFAULT_MAILBOX_CAPACITY);
55        self.mailboxes.insert(fighter_id, tx);
56        rx
57    }
58
59    /// Unregister a fighter's mailbox.
60    pub fn unregister(&self, fighter_id: &FighterId) {
61        self.mailboxes.remove(fighter_id);
62        self.pending_receivers.remove(fighter_id);
63    }
64
65    /// Check if a fighter has a registered mailbox.
66    pub fn is_registered(&self, fighter_id: &FighterId) -> bool {
67        self.mailboxes.contains_key(fighter_id)
68    }
69
70    /// Send a direct message from one fighter to another.
71    pub async fn send_direct(
72        &self,
73        from: FighterId,
74        to: FighterId,
75        content: AgentMessageType,
76        priority: MessagePriority,
77    ) -> PunchResult<Uuid> {
78        let msg = AgentMessage {
79            id: Uuid::new_v4(),
80            from,
81            to,
82            channel: MessageChannel::Direct,
83            content,
84            priority,
85            timestamp: Utc::now(),
86            delivered: false,
87        };
88
89        self.deliver(msg).await
90    }
91
92    /// Broadcast a message to all registered fighters (except the sender).
93    pub async fn broadcast(
94        &self,
95        from: FighterId,
96        content: AgentMessageType,
97        priority: MessagePriority,
98    ) -> PunchResult<Vec<Uuid>> {
99        let targets: Vec<FighterId> = self
100            .mailboxes
101            .iter()
102            .map(|entry| *entry.key())
103            .filter(|id| *id != from)
104            .collect();
105
106        let mut ids = Vec::new();
107        for target in targets {
108            let msg = AgentMessage {
109                id: Uuid::new_v4(),
110                from,
111                to: target,
112                channel: MessageChannel::Broadcast,
113                content: content.clone(),
114                priority,
115                timestamp: Utc::now(),
116                delivered: false,
117            };
118            match self.deliver(msg).await {
119                Ok(id) => ids.push(id),
120                Err(e) => warn!(target = %target, error = %e, "broadcast delivery failed"),
121            }
122        }
123
124        Ok(ids)
125    }
126
127    /// Multicast a message to a specific set of fighters.
128    pub async fn multicast(
129        &self,
130        from: FighterId,
131        targets: Vec<FighterId>,
132        content: AgentMessageType,
133        priority: MessagePriority,
134    ) -> PunchResult<Vec<Uuid>> {
135        let mut ids = Vec::new();
136        for target in &targets {
137            let msg = AgentMessage {
138                id: Uuid::new_v4(),
139                from,
140                to: *target,
141                channel: MessageChannel::Multicast(targets.clone()),
142                content: content.clone(),
143                priority,
144                timestamp: Utc::now(),
145                delivered: false,
146            };
147            match self.deliver(msg).await {
148                Ok(id) => ids.push(id),
149                Err(e) => warn!(target = %target, error = %e, "multicast delivery failed"),
150            }
151        }
152
153        Ok(ids)
154    }
155
156    /// Send a request and wait for a response with timeout.
157    ///
158    /// Returns the response message on success, or a timeout error.
159    pub async fn request(
160        &self,
161        from: FighterId,
162        to: FighterId,
163        content: AgentMessageType,
164        timeout: Duration,
165    ) -> PunchResult<AgentMessage> {
166        let msg_id = Uuid::new_v4();
167        let (resp_tx, resp_rx) = oneshot::channel();
168
169        self.pending_requests.insert(msg_id, resp_tx);
170
171        let msg = AgentMessage {
172            id: msg_id,
173            from,
174            to,
175            channel: MessageChannel::Request {
176                timeout_ms: timeout.as_millis() as u64,
177            },
178            content,
179            priority: MessagePriority::High,
180            timestamp: Utc::now(),
181            delivered: false,
182        };
183
184        self.deliver(msg).await?;
185
186        match tokio::time::timeout(timeout, resp_rx).await {
187            Ok(Ok(response)) => Ok(response),
188            Ok(Err(_)) => {
189                self.pending_requests.remove(&msg_id);
190                Err(PunchError::Internal(
191                    "request channel closed before response".to_string(),
192                ))
193            }
194            Err(_) => {
195                self.pending_requests.remove(&msg_id);
196                Err(PunchError::Internal(format!(
197                    "request timed out after {}ms",
198                    timeout.as_millis()
199                )))
200            }
201        }
202    }
203
204    /// Respond to a request message.
205    pub fn respond(&self, original_msg_id: &Uuid, response: AgentMessage) -> PunchResult<()> {
206        let (_, tx) = self
207            .pending_requests
208            .remove(original_msg_id)
209            .ok_or_else(|| {
210                PunchError::Internal(format!(
211                    "no pending request for message {}",
212                    original_msg_id
213                ))
214            })?;
215
216        tx.send(response).map_err(|_| {
217            PunchError::Internal("failed to send response: requester dropped".to_string())
218        })
219    }
220
221    /// Internal delivery to a fighter's mailbox.
222    async fn deliver(&self, msg: AgentMessage) -> PunchResult<Uuid> {
223        let msg_id = msg.id;
224        let target = msg.to;
225
226        if let Some(tx) = self.mailboxes.get(&target) {
227            match tx.try_send(msg) {
228                Ok(()) => Ok(msg_id),
229                Err(mpsc::error::TrySendError::Full(returned_msg)) => {
230                    warn!(to = %target, "mailbox full, message queued as dead letter");
231                    self.add_dead_letter(returned_msg);
232                    Err(PunchError::Internal(format!(
233                        "mailbox full for fighter {}",
234                        target
235                    )))
236                }
237                Err(mpsc::error::TrySendError::Closed(returned_msg)) => {
238                    warn!(to = %target, "mailbox closed, message queued as dead letter");
239                    self.add_dead_letter(returned_msg);
240                    Err(PunchError::Internal(format!(
241                        "mailbox closed for fighter {}",
242                        target
243                    )))
244                }
245            }
246        } else {
247            self.add_dead_letter(msg);
248            Err(PunchError::Internal(format!(
249                "no mailbox registered for fighter {}",
250                target
251            )))
252        }
253    }
254
255    /// Add a message to the dead letter queue.
256    fn add_dead_letter(&self, msg: AgentMessage) {
257        let key = self
258            .dead_letter_counter
259            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
260        self.dead_letters.insert(key, msg);
261
262        // Prune oldest if over limit.
263        while self.dead_letters.len() > MAX_DEAD_LETTERS {
264            // Remove the smallest key (oldest).
265            if let Some(oldest) = self.dead_letters.iter().map(|e| *e.key()).min() {
266                self.dead_letters.remove(&oldest);
267            } else {
268                break;
269            }
270        }
271    }
272
273    /// Get the count of dead letters.
274    pub fn dead_letter_count(&self) -> usize {
275        self.dead_letters.len()
276    }
277
278    /// Drain all dead letters.
279    pub fn drain_dead_letters(&self) -> Vec<AgentMessage> {
280        let keys: Vec<u64> = self.dead_letters.iter().map(|e| *e.key()).collect();
281        let mut messages = Vec::new();
282        for key in keys {
283            if let Some((_, msg)) = self.dead_letters.remove(&key) {
284                messages.push(msg);
285            }
286        }
287        messages
288    }
289
290    /// Get the number of registered mailboxes.
291    pub fn registered_count(&self) -> usize {
292        self.mailboxes.len()
293    }
294}
295
296impl Default for MessageRouter {
297    fn default() -> Self {
298        Self::new()
299    }
300}
301
302#[cfg(test)]
303mod tests {
304    use super::*;
305
306    #[tokio::test]
307    async fn test_register_and_receive() {
308        let router = MessageRouter::new();
309        let f1 = FighterId::new();
310        let f2 = FighterId::new();
311        let mut rx1 = router.register(f1);
312        let _rx2 = router.register(f2);
313
314        let msg_id = router
315            .send_direct(
316                f2,
317                f1,
318                AgentMessageType::StatusUpdate {
319                    progress: 1.0,
320                    detail: "done".to_string(),
321                },
322                MessagePriority::Normal,
323            )
324            .await
325            .expect("should deliver");
326
327        let received = rx1.recv().await.expect("should receive");
328        assert_eq!(received.id, msg_id);
329        assert_eq!(received.from, f2);
330    }
331
332    #[tokio::test]
333    async fn test_broadcast() {
334        let router = MessageRouter::new();
335        let sender = FighterId::new();
336        let r1 = FighterId::new();
337        let r2 = FighterId::new();
338        let _sender_rx = router.register(sender);
339        let mut rx1 = router.register(r1);
340        let mut rx2 = router.register(r2);
341
342        let ids = router
343            .broadcast(
344                sender,
345                AgentMessageType::StatusUpdate {
346                    progress: 0.5,
347                    detail: "update".to_string(),
348                },
349                MessagePriority::Normal,
350            )
351            .await
352            .expect("should broadcast");
353
354        assert_eq!(ids.len(), 2);
355
356        let m1 = rx1.recv().await.expect("should receive");
357        let m2 = rx2.recv().await.expect("should receive");
358        assert_eq!(m1.from, sender);
359        assert_eq!(m2.from, sender);
360    }
361
362    #[tokio::test]
363    async fn test_multicast() {
364        let router = MessageRouter::new();
365        let sender = FighterId::new();
366        let t1 = FighterId::new();
367        let t2 = FighterId::new();
368        let t3 = FighterId::new();
369        let _sr = router.register(sender);
370        let mut rx1 = router.register(t1);
371        let mut rx2 = router.register(t2);
372        let _rx3 = router.register(t3);
373
374        let ids = router
375            .multicast(
376                sender,
377                vec![t1, t2],
378                AgentMessageType::TaskAssignment {
379                    task: "work".to_string(),
380                },
381                MessagePriority::High,
382            )
383            .await
384            .expect("should multicast");
385
386        assert_eq!(ids.len(), 2);
387
388        let m1 = rx1.recv().await.expect("r1 should receive");
389        let m2 = rx2.recv().await.expect("r2 should receive");
390        assert_eq!(m1.from, sender);
391        assert_eq!(m2.from, sender);
392    }
393
394    #[tokio::test]
395    async fn test_request_response() {
396        let router = std::sync::Arc::new(MessageRouter::new());
397        let requester = FighterId::new();
398        let responder = FighterId::new();
399        let _req_rx = router.register(requester);
400        let mut resp_rx = router.register(responder);
401
402        let router_clone = router.clone();
403        let requester_clone = requester;
404        let responder_clone = responder;
405
406        // Spawn responder task.
407        tokio::spawn(async move {
408            if let Some(msg) = resp_rx.recv().await {
409                let response = AgentMessage {
410                    id: Uuid::new_v4(),
411                    from: responder_clone,
412                    to: requester_clone,
413                    channel: MessageChannel::Direct,
414                    content: AgentMessageType::TaskResult {
415                        result: "42".to_string(),
416                        success: true,
417                    },
418                    priority: MessagePriority::Normal,
419                    timestamp: Utc::now(),
420                    delivered: false,
421                };
422                let _ = router_clone.respond(&msg.id, response);
423            }
424        });
425
426        let result = router
427            .request(
428                requester,
429                responder,
430                AgentMessageType::TaskAssignment {
431                    task: "compute".to_string(),
432                },
433                Duration::from_secs(5),
434            )
435            .await
436            .expect("should get response");
437
438        match &result.content {
439            AgentMessageType::TaskResult { result, success } => {
440                assert_eq!(result, "42");
441                assert!(success);
442            }
443            _ => panic!("wrong response type"),
444        }
445    }
446
447    #[tokio::test]
448    async fn test_request_timeout() {
449        let router = MessageRouter::new();
450        let requester = FighterId::new();
451        let responder = FighterId::new();
452        let _req_rx = router.register(requester);
453        let _resp_rx = router.register(responder);
454
455        // Don't spawn a responder, so this will timeout.
456        let result = router
457            .request(
458                requester,
459                responder,
460                AgentMessageType::TaskAssignment {
461                    task: "compute".to_string(),
462                },
463                Duration::from_millis(50),
464            )
465            .await;
466
467        assert!(result.is_err());
468        let err = result.unwrap_err().to_string();
469        assert!(err.contains("timed out"));
470    }
471
472    #[tokio::test]
473    async fn test_dead_letter_on_unregistered() {
474        let router = MessageRouter::new();
475        let f1 = FighterId::new();
476        let f2 = FighterId::new();
477        let _rx = router.register(f1);
478
479        // f2 is not registered; message should become dead letter.
480        let result = router
481            .send_direct(
482                f1,
483                f2,
484                AgentMessageType::StatusUpdate {
485                    progress: 0.0,
486                    detail: "test".to_string(),
487                },
488                MessagePriority::Low,
489            )
490            .await;
491
492        assert!(result.is_err());
493        assert_eq!(router.dead_letter_count(), 1);
494    }
495
496    #[tokio::test]
497    async fn test_drain_dead_letters() {
498        let router = MessageRouter::new();
499        let f1 = FighterId::new();
500        let f2 = FighterId::new();
501        let _rx = router.register(f1);
502
503        let _ = router
504            .send_direct(
505                f1,
506                f2,
507                AgentMessageType::StatusUpdate {
508                    progress: 0.0,
509                    detail: "dead".to_string(),
510                },
511                MessagePriority::Low,
512            )
513            .await;
514
515        let letters = router.drain_dead_letters();
516        assert_eq!(letters.len(), 1);
517        assert_eq!(router.dead_letter_count(), 0);
518    }
519
520    #[test]
521    fn test_unregister() {
522        let router = MessageRouter::new();
523        let f = FighterId::new();
524        let _rx = router.register(f);
525        assert!(router.is_registered(&f));
526        router.unregister(&f);
527        assert!(!router.is_registered(&f));
528    }
529
530    #[test]
531    fn test_registered_count() {
532        let router = MessageRouter::new();
533        assert_eq!(router.registered_count(), 0);
534        let f1 = FighterId::new();
535        let f2 = FighterId::new();
536        let _rx1 = router.register(f1);
537        let _rx2 = router.register(f2);
538        assert_eq!(router.registered_count(), 2);
539    }
540
541    #[tokio::test]
542    async fn test_broadcast_excludes_sender() {
543        let router = MessageRouter::new();
544        let sender = FighterId::new();
545        let mut sender_rx = router.register(sender);
546
547        let ids = router
548            .broadcast(
549                sender,
550                AgentMessageType::StatusUpdate {
551                    progress: 1.0,
552                    detail: "done".to_string(),
553                },
554                MessagePriority::Normal,
555            )
556            .await
557            .expect("should broadcast");
558
559        // No recipients besides sender, who is excluded.
560        assert!(ids.is_empty());
561
562        // Sender should NOT receive their own broadcast.
563        let result = tokio::time::timeout(Duration::from_millis(50), sender_rx.recv()).await;
564        assert!(result.is_err()); // Timeout means nothing received.
565    }
566
567    #[test]
568    fn test_default_impl() {
569        let router = MessageRouter::default();
570        assert_eq!(router.registered_count(), 0);
571    }
572
573    #[tokio::test]
574    async fn test_message_priority_preserved() {
575        let router = MessageRouter::new();
576        let f1 = FighterId::new();
577        let f2 = FighterId::new();
578        let mut rx = router.register(f1);
579        let _rx2 = router.register(f2);
580
581        router
582            .send_direct(
583                f2,
584                f1,
585                AgentMessageType::Escalation {
586                    reason: "urgent".to_string(),
587                    original_task: "task".to_string(),
588                },
589                MessagePriority::Critical,
590            )
591            .await
592            .expect("should deliver");
593
594        let msg = rx.recv().await.expect("should receive");
595        assert_eq!(msg.priority, MessagePriority::Critical);
596    }
597
598    #[tokio::test]
599    async fn test_respond_to_nonexistent_request() {
600        let router = MessageRouter::new();
601        let response = AgentMessage {
602            id: Uuid::new_v4(),
603            from: FighterId::new(),
604            to: FighterId::new(),
605            channel: MessageChannel::Direct,
606            content: AgentMessageType::TaskResult {
607                result: "nope".to_string(),
608                success: false,
609            },
610            priority: MessagePriority::Normal,
611            timestamp: Utc::now(),
612            delivered: false,
613        };
614
615        let result = router.respond(&Uuid::new_v4(), response);
616        assert!(result.is_err());
617    }
618}