1use chrono::Utc;
7use dashmap::DashMap;
8use std::time::Duration;
9use tokio::sync::{mpsc, oneshot};
10use tracing::warn;
11use uuid::Uuid;
12
13use punch_types::{
14 AgentMessage, AgentMessageType, FighterId, MessageChannel, MessagePriority, PunchError,
15 PunchResult,
16};
17
/// Number of messages a fighter's mailbox channel can buffer; used as the
/// bound of the `mpsc` channel created on registration.
const DEFAULT_MAILBOX_CAPACITY: usize = 256;

/// Upper bound on retained dead letters; once exceeded, the entries with the
/// smallest keys are evicted.
const MAX_DEAD_LETTERS: usize = 1000;
23
24pub struct MessageRouter {
26 mailboxes: DashMap<FighterId, mpsc::Sender<AgentMessage>>,
28 pending_receivers: DashMap<FighterId, mpsc::Receiver<AgentMessage>>,
31 dead_letters: DashMap<u64, AgentMessage>,
33 dead_letter_counter: std::sync::atomic::AtomicU64,
35 pending_requests: DashMap<Uuid, oneshot::Sender<AgentMessage>>,
37}
38
39impl MessageRouter {
40 pub fn new() -> Self {
42 Self {
43 mailboxes: DashMap::new(),
44 pending_receivers: DashMap::new(),
45 dead_letters: DashMap::new(),
46 dead_letter_counter: std::sync::atomic::AtomicU64::new(0),
47 pending_requests: DashMap::new(),
48 }
49 }
50
51 pub fn register(&self, fighter_id: FighterId) -> mpsc::Receiver<AgentMessage> {
54 let (tx, rx) = mpsc::channel(DEFAULT_MAILBOX_CAPACITY);
55 self.mailboxes.insert(fighter_id, tx);
56 rx
57 }
58
59 pub fn unregister(&self, fighter_id: &FighterId) {
61 self.mailboxes.remove(fighter_id);
62 self.pending_receivers.remove(fighter_id);
63 }
64
65 pub fn is_registered(&self, fighter_id: &FighterId) -> bool {
67 self.mailboxes.contains_key(fighter_id)
68 }
69
70 pub async fn send_direct(
72 &self,
73 from: FighterId,
74 to: FighterId,
75 content: AgentMessageType,
76 priority: MessagePriority,
77 ) -> PunchResult<Uuid> {
78 let msg = AgentMessage {
79 id: Uuid::new_v4(),
80 from,
81 to,
82 channel: MessageChannel::Direct,
83 content,
84 priority,
85 timestamp: Utc::now(),
86 delivered: false,
87 };
88
89 self.deliver(msg).await
90 }
91
92 pub async fn broadcast(
94 &self,
95 from: FighterId,
96 content: AgentMessageType,
97 priority: MessagePriority,
98 ) -> PunchResult<Vec<Uuid>> {
99 let targets: Vec<FighterId> = self
100 .mailboxes
101 .iter()
102 .map(|entry| *entry.key())
103 .filter(|id| *id != from)
104 .collect();
105
106 let mut ids = Vec::new();
107 for target in targets {
108 let msg = AgentMessage {
109 id: Uuid::new_v4(),
110 from,
111 to: target,
112 channel: MessageChannel::Broadcast,
113 content: content.clone(),
114 priority,
115 timestamp: Utc::now(),
116 delivered: false,
117 };
118 match self.deliver(msg).await {
119 Ok(id) => ids.push(id),
120 Err(e) => warn!(target = %target, error = %e, "broadcast delivery failed"),
121 }
122 }
123
124 Ok(ids)
125 }
126
127 pub async fn multicast(
129 &self,
130 from: FighterId,
131 targets: Vec<FighterId>,
132 content: AgentMessageType,
133 priority: MessagePriority,
134 ) -> PunchResult<Vec<Uuid>> {
135 let mut ids = Vec::new();
136 for target in &targets {
137 let msg = AgentMessage {
138 id: Uuid::new_v4(),
139 from,
140 to: *target,
141 channel: MessageChannel::Multicast(targets.clone()),
142 content: content.clone(),
143 priority,
144 timestamp: Utc::now(),
145 delivered: false,
146 };
147 match self.deliver(msg).await {
148 Ok(id) => ids.push(id),
149 Err(e) => warn!(target = %target, error = %e, "multicast delivery failed"),
150 }
151 }
152
153 Ok(ids)
154 }
155
156 pub async fn request(
160 &self,
161 from: FighterId,
162 to: FighterId,
163 content: AgentMessageType,
164 timeout: Duration,
165 ) -> PunchResult<AgentMessage> {
166 let msg_id = Uuid::new_v4();
167 let (resp_tx, resp_rx) = oneshot::channel();
168
169 self.pending_requests.insert(msg_id, resp_tx);
170
171 let msg = AgentMessage {
172 id: msg_id,
173 from,
174 to,
175 channel: MessageChannel::Request {
176 timeout_ms: timeout.as_millis() as u64,
177 },
178 content,
179 priority: MessagePriority::High,
180 timestamp: Utc::now(),
181 delivered: false,
182 };
183
184 self.deliver(msg).await?;
185
186 match tokio::time::timeout(timeout, resp_rx).await {
187 Ok(Ok(response)) => Ok(response),
188 Ok(Err(_)) => {
189 self.pending_requests.remove(&msg_id);
190 Err(PunchError::Internal(
191 "request channel closed before response".to_string(),
192 ))
193 }
194 Err(_) => {
195 self.pending_requests.remove(&msg_id);
196 Err(PunchError::Internal(format!(
197 "request timed out after {}ms",
198 timeout.as_millis()
199 )))
200 }
201 }
202 }
203
204 pub fn respond(
206 &self,
207 original_msg_id: &Uuid,
208 response: AgentMessage,
209 ) -> PunchResult<()> {
210 let (_, tx) = self
211 .pending_requests
212 .remove(original_msg_id)
213 .ok_or_else(|| {
214 PunchError::Internal(format!(
215 "no pending request for message {}",
216 original_msg_id
217 ))
218 })?;
219
220 tx.send(response).map_err(|_| {
221 PunchError::Internal("failed to send response: requester dropped".to_string())
222 })
223 }
224
225 async fn deliver(&self, msg: AgentMessage) -> PunchResult<Uuid> {
227 let msg_id = msg.id;
228 let target = msg.to;
229
230 if let Some(tx) = self.mailboxes.get(&target) {
231 match tx.try_send(msg) {
232 Ok(()) => Ok(msg_id),
233 Err(mpsc::error::TrySendError::Full(returned_msg)) => {
234 warn!(to = %target, "mailbox full, message queued as dead letter");
235 self.add_dead_letter(returned_msg);
236 Err(PunchError::Internal(format!(
237 "mailbox full for fighter {}",
238 target
239 )))
240 }
241 Err(mpsc::error::TrySendError::Closed(returned_msg)) => {
242 warn!(to = %target, "mailbox closed, message queued as dead letter");
243 self.add_dead_letter(returned_msg);
244 Err(PunchError::Internal(format!(
245 "mailbox closed for fighter {}",
246 target
247 )))
248 }
249 }
250 } else {
251 self.add_dead_letter(msg);
252 Err(PunchError::Internal(format!(
253 "no mailbox registered for fighter {}",
254 target
255 )))
256 }
257 }
258
259 fn add_dead_letter(&self, msg: AgentMessage) {
261 let key = self
262 .dead_letter_counter
263 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
264 self.dead_letters.insert(key, msg);
265
266 while self.dead_letters.len() > MAX_DEAD_LETTERS {
268 if let Some(oldest) = self.dead_letters.iter().map(|e| *e.key()).min() {
270 self.dead_letters.remove(&oldest);
271 } else {
272 break;
273 }
274 }
275 }
276
277 pub fn dead_letter_count(&self) -> usize {
279 self.dead_letters.len()
280 }
281
282 pub fn drain_dead_letters(&self) -> Vec<AgentMessage> {
284 let keys: Vec<u64> = self.dead_letters.iter().map(|e| *e.key()).collect();
285 let mut messages = Vec::new();
286 for key in keys {
287 if let Some((_, msg)) = self.dead_letters.remove(&key) {
288 messages.push(msg);
289 }
290 }
291 messages
292 }
293
294 pub fn registered_count(&self) -> usize {
296 self.mailboxes.len()
297 }
298}
299
300impl Default for MessageRouter {
301 fn default() -> Self {
302 Self::new()
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// A direct message arrives in the recipient's mailbox with the returned
    /// id and the original sender.
    #[tokio::test]
    async fn test_register_and_receive() {
        let router = MessageRouter::new();
        let recipient = FighterId::new();
        let sender = FighterId::new();
        let mut inbox = router.register(recipient);
        let _sender_inbox = router.register(sender);

        let payload = AgentMessageType::StatusUpdate {
            progress: 1.0,
            detail: "done".to_string(),
        };
        let sent_id = router
            .send_direct(sender, recipient, payload, MessagePriority::Normal)
            .await
            .expect("should deliver");

        let delivered = inbox.recv().await.expect("should receive");
        assert_eq!(delivered.id, sent_id);
        assert_eq!(delivered.from, sender);
    }

    /// A broadcast reaches every registered fighter other than the sender.
    #[tokio::test]
    async fn test_broadcast() {
        let router = MessageRouter::new();
        let sender = FighterId::new();
        let first = FighterId::new();
        let second = FighterId::new();
        let _sender_inbox = router.register(sender);
        let mut first_inbox = router.register(first);
        let mut second_inbox = router.register(second);

        let payload = AgentMessageType::StatusUpdate {
            progress: 0.5,
            detail: "update".to_string(),
        };
        let ids = router
            .broadcast(sender, payload, MessagePriority::Normal)
            .await
            .expect("should broadcast");

        assert_eq!(ids.len(), 2);

        let from_first = first_inbox.recv().await.expect("should receive");
        let from_second = second_inbox.recv().await.expect("should receive");
        assert_eq!(from_first.from, sender);
        assert_eq!(from_second.from, sender);
    }

    /// A multicast delivers one message per listed target.
    #[tokio::test]
    async fn test_multicast() {
        let router = MessageRouter::new();
        let sender = FighterId::new();
        let target_a = FighterId::new();
        let target_b = FighterId::new();
        let bystander = FighterId::new();
        let _sender_inbox = router.register(sender);
        let mut inbox_a = router.register(target_a);
        let mut inbox_b = router.register(target_b);
        let _bystander_inbox = router.register(bystander);

        let payload = AgentMessageType::TaskAssignment {
            task: "work".to_string(),
        };
        let ids = router
            .multicast(
                sender,
                vec![target_a, target_b],
                payload,
                MessagePriority::High,
            )
            .await
            .expect("should multicast");

        assert_eq!(ids.len(), 2);

        let got_a = inbox_a.recv().await.expect("r1 should receive");
        let got_b = inbox_b.recv().await.expect("r2 should receive");
        assert_eq!(got_a.from, sender);
        assert_eq!(got_b.from, sender);
    }

    /// A request resolves with the message passed to `respond` by the
    /// responder task.
    #[tokio::test]
    async fn test_request_response() {
        let router = std::sync::Arc::new(MessageRouter::new());
        let requester = FighterId::new();
        let responder = FighterId::new();
        let _requester_inbox = router.register(requester);
        let mut responder_inbox = router.register(responder);

        let worker_router = std::sync::Arc::clone(&router);

        // Responder task: answer the first message that shows up.
        tokio::spawn(async move {
            if let Some(incoming) = responder_inbox.recv().await {
                let reply = AgentMessage {
                    id: Uuid::new_v4(),
                    from: responder,
                    to: requester,
                    channel: MessageChannel::Direct,
                    content: AgentMessageType::TaskResult {
                        result: "42".to_string(),
                        success: true,
                    },
                    priority: MessagePriority::Normal,
                    timestamp: Utc::now(),
                    delivered: false,
                };
                let _ = worker_router.respond(&incoming.id, reply);
            }
        });

        let question = AgentMessageType::TaskAssignment {
            task: "compute".to_string(),
        };
        let reply = router
            .request(requester, responder, question, Duration::from_secs(5))
            .await
            .expect("should get response");

        if let AgentMessageType::TaskResult {
            result: answer,
            success: ok,
        } = &reply.content
        {
            assert_eq!(answer, "42");
            assert!(*ok);
        } else {
            panic!("wrong response type");
        }
    }

    /// A request nobody answers fails with a timeout error.
    #[tokio::test]
    async fn test_request_timeout() {
        let router = MessageRouter::new();
        let requester = FighterId::new();
        let responder = FighterId::new();
        let _requester_inbox = router.register(requester);
        let _responder_inbox = router.register(responder);

        let question = AgentMessageType::TaskAssignment {
            task: "compute".to_string(),
        };
        let outcome = router
            .request(requester, responder, question, Duration::from_millis(50))
            .await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("timed out"));
    }

    /// Sending to an unregistered fighter fails and records one dead letter.
    #[tokio::test]
    async fn test_dead_letter_on_unregistered() {
        let router = MessageRouter::new();
        let known = FighterId::new();
        let unknown = FighterId::new();
        let _known_inbox = router.register(known);

        let payload = AgentMessageType::StatusUpdate {
            progress: 0.0,
            detail: "test".to_string(),
        };
        let outcome = router
            .send_direct(known, unknown, payload, MessagePriority::Low)
            .await;

        assert!(outcome.is_err());
        assert_eq!(router.dead_letter_count(), 1);
    }

    /// Draining returns the stored dead letters and leaves none behind.
    #[tokio::test]
    async fn test_drain_dead_letters() {
        let router = MessageRouter::new();
        let known = FighterId::new();
        let unknown = FighterId::new();
        let _known_inbox = router.register(known);

        let payload = AgentMessageType::StatusUpdate {
            progress: 0.0,
            detail: "dead".to_string(),
        };
        let _ = router
            .send_direct(known, unknown, payload, MessagePriority::Low)
            .await;

        let drained = router.drain_dead_letters();
        assert_eq!(drained.len(), 1);
        assert_eq!(router.dead_letter_count(), 0);
    }

    /// Unregistering a fighter makes `is_registered` report `false`.
    #[test]
    fn test_unregister() {
        let router = MessageRouter::new();
        let fighter = FighterId::new();
        let _inbox = router.register(fighter);
        assert!(router.is_registered(&fighter));
        router.unregister(&fighter);
        assert!(!router.is_registered(&fighter));
    }

    /// `registered_count` tracks the number of registered fighters.
    #[test]
    fn test_registered_count() {
        let router = MessageRouter::new();
        assert_eq!(router.registered_count(), 0);
        let first = FighterId::new();
        let second = FighterId::new();
        let _first_inbox = router.register(first);
        let _second_inbox = router.register(second);
        assert_eq!(router.registered_count(), 2);
    }

    /// With only the sender registered, a broadcast delivers nothing and the
    /// sender's own mailbox stays empty.
    #[tokio::test]
    async fn test_broadcast_excludes_sender() {
        let router = MessageRouter::new();
        let sender = FighterId::new();
        let mut sender_inbox = router.register(sender);

        let payload = AgentMessageType::StatusUpdate {
            progress: 1.0,
            detail: "done".to_string(),
        };
        let ids = router
            .broadcast(sender, payload, MessagePriority::Normal)
            .await
            .expect("should broadcast");

        assert_eq!(ids.len(), 0);

        // Nothing was sent, so waiting on the mailbox must time out.
        let waited =
            tokio::time::timeout(Duration::from_millis(50), sender_inbox.recv()).await;
        assert!(waited.is_err());
    }

    /// `Default` yields an empty router.
    #[test]
    fn test_default_impl() {
        let router = MessageRouter::default();
        assert_eq!(router.registered_count(), 0);
    }

    /// The priority given to `send_direct` is the one the recipient sees.
    #[tokio::test]
    async fn test_message_priority_preserved() {
        let router = MessageRouter::new();
        let recipient = FighterId::new();
        let sender = FighterId::new();
        let mut inbox = router.register(recipient);
        let _sender_inbox = router.register(sender);

        let payload = AgentMessageType::Escalation {
            reason: "urgent".to_string(),
            original_task: "task".to_string(),
        };
        router
            .send_direct(sender, recipient, payload, MessagePriority::Critical)
            .await
            .expect("should deliver");

        let delivered = inbox.recv().await.expect("should receive");
        assert_eq!(delivered.priority, MessagePriority::Critical);
    }

    /// Responding to an id with no pending request is an error.
    #[tokio::test]
    async fn test_respond_to_nonexistent_request() {
        let router = MessageRouter::new();
        let reply = AgentMessage {
            id: Uuid::new_v4(),
            from: FighterId::new(),
            to: FighterId::new(),
            channel: MessageChannel::Direct,
            content: AgentMessageType::TaskResult {
                result: "nope".to_string(),
                success: false,
            },
            priority: MessagePriority::Normal,
            timestamp: Utc::now(),
            delivered: false,
        };

        let outcome = router.respond(&Uuid::new_v4(), reply);
        assert!(outcome.is_err());
    }
}