1use chrono::Utc;
7use dashmap::DashMap;
8use std::time::Duration;
9use tokio::sync::{mpsc, oneshot};
10use tracing::warn;
11use uuid::Uuid;
12
13use punch_types::{
14 AgentMessage, AgentMessageType, FighterId, MessageChannel, MessagePriority, PunchError,
15 PunchResult,
16};
17
/// Number of messages each fighter's bounded mailbox can buffer before
/// further sends are rejected and dead-lettered.
const DEFAULT_MAILBOX_CAPACITY: usize = 2_56;

/// Upper bound on retained dead letters; the oldest entries are evicted first.
const MAX_DEAD_LETTERS: usize = 1_000;
23
24pub struct MessageRouter {
26 mailboxes: DashMap<FighterId, mpsc::Sender<AgentMessage>>,
28 pending_receivers: DashMap<FighterId, mpsc::Receiver<AgentMessage>>,
31 dead_letters: DashMap<u64, AgentMessage>,
33 dead_letter_counter: std::sync::atomic::AtomicU64,
35 pending_requests: DashMap<Uuid, oneshot::Sender<AgentMessage>>,
37}
38
39impl MessageRouter {
40 pub fn new() -> Self {
42 Self {
43 mailboxes: DashMap::new(),
44 pending_receivers: DashMap::new(),
45 dead_letters: DashMap::new(),
46 dead_letter_counter: std::sync::atomic::AtomicU64::new(0),
47 pending_requests: DashMap::new(),
48 }
49 }
50
51 pub fn register(&self, fighter_id: FighterId) -> mpsc::Receiver<AgentMessage> {
54 let (tx, rx) = mpsc::channel(DEFAULT_MAILBOX_CAPACITY);
55 self.mailboxes.insert(fighter_id, tx);
56 rx
57 }
58
59 pub fn unregister(&self, fighter_id: &FighterId) {
61 self.mailboxes.remove(fighter_id);
62 self.pending_receivers.remove(fighter_id);
63 }
64
65 pub fn is_registered(&self, fighter_id: &FighterId) -> bool {
67 self.mailboxes.contains_key(fighter_id)
68 }
69
70 pub async fn send_direct(
72 &self,
73 from: FighterId,
74 to: FighterId,
75 content: AgentMessageType,
76 priority: MessagePriority,
77 ) -> PunchResult<Uuid> {
78 let msg = AgentMessage {
79 id: Uuid::new_v4(),
80 from,
81 to,
82 channel: MessageChannel::Direct,
83 content,
84 priority,
85 timestamp: Utc::now(),
86 delivered: false,
87 };
88
89 self.deliver(msg).await
90 }
91
92 pub async fn broadcast(
94 &self,
95 from: FighterId,
96 content: AgentMessageType,
97 priority: MessagePriority,
98 ) -> PunchResult<Vec<Uuid>> {
99 let targets: Vec<FighterId> = self
100 .mailboxes
101 .iter()
102 .map(|entry| *entry.key())
103 .filter(|id| *id != from)
104 .collect();
105
106 let mut ids = Vec::new();
107 for target in targets {
108 let msg = AgentMessage {
109 id: Uuid::new_v4(),
110 from,
111 to: target,
112 channel: MessageChannel::Broadcast,
113 content: content.clone(),
114 priority,
115 timestamp: Utc::now(),
116 delivered: false,
117 };
118 match self.deliver(msg).await {
119 Ok(id) => ids.push(id),
120 Err(e) => warn!(target = %target, error = %e, "broadcast delivery failed"),
121 }
122 }
123
124 Ok(ids)
125 }
126
127 pub async fn multicast(
129 &self,
130 from: FighterId,
131 targets: Vec<FighterId>,
132 content: AgentMessageType,
133 priority: MessagePriority,
134 ) -> PunchResult<Vec<Uuid>> {
135 let mut ids = Vec::new();
136 for target in &targets {
137 let msg = AgentMessage {
138 id: Uuid::new_v4(),
139 from,
140 to: *target,
141 channel: MessageChannel::Multicast(targets.clone()),
142 content: content.clone(),
143 priority,
144 timestamp: Utc::now(),
145 delivered: false,
146 };
147 match self.deliver(msg).await {
148 Ok(id) => ids.push(id),
149 Err(e) => warn!(target = %target, error = %e, "multicast delivery failed"),
150 }
151 }
152
153 Ok(ids)
154 }
155
156 pub async fn request(
160 &self,
161 from: FighterId,
162 to: FighterId,
163 content: AgentMessageType,
164 timeout: Duration,
165 ) -> PunchResult<AgentMessage> {
166 let msg_id = Uuid::new_v4();
167 let (resp_tx, resp_rx) = oneshot::channel();
168
169 self.pending_requests.insert(msg_id, resp_tx);
170
171 let msg = AgentMessage {
172 id: msg_id,
173 from,
174 to,
175 channel: MessageChannel::Request {
176 timeout_ms: timeout.as_millis() as u64,
177 },
178 content,
179 priority: MessagePriority::High,
180 timestamp: Utc::now(),
181 delivered: false,
182 };
183
184 self.deliver(msg).await?;
185
186 match tokio::time::timeout(timeout, resp_rx).await {
187 Ok(Ok(response)) => Ok(response),
188 Ok(Err(_)) => {
189 self.pending_requests.remove(&msg_id);
190 Err(PunchError::Internal(
191 "request channel closed before response".to_string(),
192 ))
193 }
194 Err(_) => {
195 self.pending_requests.remove(&msg_id);
196 Err(PunchError::Internal(format!(
197 "request timed out after {}ms",
198 timeout.as_millis()
199 )))
200 }
201 }
202 }
203
204 pub fn respond(&self, original_msg_id: &Uuid, response: AgentMessage) -> PunchResult<()> {
206 let (_, tx) = self
207 .pending_requests
208 .remove(original_msg_id)
209 .ok_or_else(|| {
210 PunchError::Internal(format!(
211 "no pending request for message {}",
212 original_msg_id
213 ))
214 })?;
215
216 tx.send(response).map_err(|_| {
217 PunchError::Internal("failed to send response: requester dropped".to_string())
218 })
219 }
220
221 async fn deliver(&self, msg: AgentMessage) -> PunchResult<Uuid> {
223 let msg_id = msg.id;
224 let target = msg.to;
225
226 if let Some(tx) = self.mailboxes.get(&target) {
227 match tx.try_send(msg) {
228 Ok(()) => Ok(msg_id),
229 Err(mpsc::error::TrySendError::Full(returned_msg)) => {
230 warn!(to = %target, "mailbox full, message queued as dead letter");
231 self.add_dead_letter(returned_msg);
232 Err(PunchError::Internal(format!(
233 "mailbox full for fighter {}",
234 target
235 )))
236 }
237 Err(mpsc::error::TrySendError::Closed(returned_msg)) => {
238 warn!(to = %target, "mailbox closed, message queued as dead letter");
239 self.add_dead_letter(returned_msg);
240 Err(PunchError::Internal(format!(
241 "mailbox closed for fighter {}",
242 target
243 )))
244 }
245 }
246 } else {
247 self.add_dead_letter(msg);
248 Err(PunchError::Internal(format!(
249 "no mailbox registered for fighter {}",
250 target
251 )))
252 }
253 }
254
255 fn add_dead_letter(&self, msg: AgentMessage) {
257 let key = self
258 .dead_letter_counter
259 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
260 self.dead_letters.insert(key, msg);
261
262 while self.dead_letters.len() > MAX_DEAD_LETTERS {
264 if let Some(oldest) = self.dead_letters.iter().map(|e| *e.key()).min() {
266 self.dead_letters.remove(&oldest);
267 } else {
268 break;
269 }
270 }
271 }
272
273 pub fn dead_letter_count(&self) -> usize {
275 self.dead_letters.len()
276 }
277
278 pub fn drain_dead_letters(&self) -> Vec<AgentMessage> {
280 let keys: Vec<u64> = self.dead_letters.iter().map(|e| *e.key()).collect();
281 let mut messages = Vec::new();
282 for key in keys {
283 if let Some((_, msg)) = self.dead_letters.remove(&key) {
284 messages.push(msg);
285 }
286 }
287 messages
288 }
289
290 pub fn registered_count(&self) -> usize {
292 self.mailboxes.len()
293 }
294}
295
296impl Default for MessageRouter {
297 fn default() -> Self {
298 Self::new()
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a simple status payload used by tests that only care about routing.
    fn status(detail: &str) -> AgentMessageType {
        AgentMessageType::StatusUpdate {
            progress: 0.0,
            detail: detail.to_string(),
        }
    }

    #[tokio::test]
    async fn test_register_and_receive() {
        let router = MessageRouter::new();
        let f1 = FighterId::new();
        let f2 = FighterId::new();
        let mut rx1 = router.register(f1);
        let _rx2 = router.register(f2);

        let msg_id = router
            .send_direct(
                f2,
                f1,
                AgentMessageType::StatusUpdate {
                    progress: 1.0,
                    detail: "done".to_string(),
                },
                MessagePriority::Normal,
            )
            .await
            .expect("should deliver");

        let received = rx1.recv().await.expect("should receive");
        assert_eq!(received.id, msg_id);
        assert_eq!(received.from, f2);
    }

    #[tokio::test]
    async fn test_broadcast() {
        let router = MessageRouter::new();
        let sender = FighterId::new();
        let r1 = FighterId::new();
        let r2 = FighterId::new();
        let _sender_rx = router.register(sender);
        let mut rx1 = router.register(r1);
        let mut rx2 = router.register(r2);

        let ids = router
            .broadcast(
                sender,
                AgentMessageType::StatusUpdate {
                    progress: 0.5,
                    detail: "update".to_string(),
                },
                MessagePriority::Normal,
            )
            .await
            .expect("should broadcast");

        assert_eq!(ids.len(), 2);

        let m1 = rx1.recv().await.expect("should receive");
        let m2 = rx2.recv().await.expect("should receive");
        assert_eq!(m1.from, sender);
        assert_eq!(m2.from, sender);
    }

    #[tokio::test]
    async fn test_multicast() {
        let router = MessageRouter::new();
        let sender = FighterId::new();
        let t1 = FighterId::new();
        let t2 = FighterId::new();
        let t3 = FighterId::new();
        let _sr = router.register(sender);
        let mut rx1 = router.register(t1);
        let mut rx2 = router.register(t2);
        let _rx3 = router.register(t3);

        let ids = router
            .multicast(
                sender,
                vec![t1, t2],
                AgentMessageType::TaskAssignment {
                    task: "work".to_string(),
                },
                MessagePriority::High,
            )
            .await
            .expect("should multicast");

        assert_eq!(ids.len(), 2);

        let m1 = rx1.recv().await.expect("r1 should receive");
        let m2 = rx2.recv().await.expect("r2 should receive");
        assert_eq!(m1.from, sender);
        assert_eq!(m2.from, sender);
    }

    #[tokio::test]
    async fn test_request_response() {
        let router = std::sync::Arc::new(MessageRouter::new());
        let requester = FighterId::new();
        let responder = FighterId::new();
        let _req_rx = router.register(requester);
        let mut resp_rx = router.register(responder);

        let router_clone = router.clone();
        let requester_clone = requester;
        let responder_clone = responder;

        // Simulated responder: answers the first request it receives.
        tokio::spawn(async move {
            if let Some(msg) = resp_rx.recv().await {
                let response = AgentMessage {
                    id: Uuid::new_v4(),
                    from: responder_clone,
                    to: requester_clone,
                    channel: MessageChannel::Direct,
                    content: AgentMessageType::TaskResult {
                        result: "42".to_string(),
                        success: true,
                    },
                    priority: MessagePriority::Normal,
                    timestamp: Utc::now(),
                    delivered: false,
                };
                let _ = router_clone.respond(&msg.id, response);
            }
        });

        let result = router
            .request(
                requester,
                responder,
                AgentMessageType::TaskAssignment {
                    task: "compute".to_string(),
                },
                Duration::from_secs(5),
            )
            .await
            .expect("should get response");

        match &result.content {
            AgentMessageType::TaskResult { result, success } => {
                assert_eq!(result, "42");
                assert!(success);
            }
            _ => panic!("wrong response type"),
        }
    }

    #[tokio::test]
    async fn test_request_timeout() {
        let router = MessageRouter::new();
        let requester = FighterId::new();
        let responder = FighterId::new();
        let _req_rx = router.register(requester);
        let _resp_rx = router.register(responder);

        // Nobody answers, so the request must time out.
        let result = router
            .request(
                requester,
                responder,
                AgentMessageType::TaskAssignment {
                    task: "compute".to_string(),
                },
                Duration::from_millis(50),
            )
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("timed out"));
    }

    #[tokio::test]
    async fn test_dead_letter_on_unregistered() {
        let router = MessageRouter::new();
        let f1 = FighterId::new();
        let f2 = FighterId::new();
        let _rx = router.register(f1);

        // f2 was never registered, so the message must be dead-lettered.
        let result = router
            .send_direct(
                f1,
                f2,
                AgentMessageType::StatusUpdate {
                    progress: 0.0,
                    detail: "test".to_string(),
                },
                MessagePriority::Low,
            )
            .await;

        assert!(result.is_err());
        assert_eq!(router.dead_letter_count(), 1);
    }

    #[tokio::test]
    async fn test_drain_dead_letters() {
        let router = MessageRouter::new();
        let f1 = FighterId::new();
        let f2 = FighterId::new();
        let _rx = router.register(f1);

        let _ = router
            .send_direct(
                f1,
                f2,
                AgentMessageType::StatusUpdate {
                    progress: 0.0,
                    detail: "dead".to_string(),
                },
                MessagePriority::Low,
            )
            .await;

        let letters = router.drain_dead_letters();
        assert_eq!(letters.len(), 1);
        assert_eq!(router.dead_letter_count(), 0);
    }

    #[tokio::test]
    async fn test_mailbox_full_goes_to_dead_letters() {
        let router = MessageRouter::new();
        let sender = FighterId::new();
        let target = FighterId::new();
        // Keep the receiver alive but never read, so the mailbox fills up.
        let _rx = router.register(target);

        for _ in 0..DEFAULT_MAILBOX_CAPACITY {
            router
                .send_direct(sender, target, status("fill"), MessagePriority::Low)
                .await
                .expect("mailbox should still have room");
        }

        let result = router
            .send_direct(sender, target, status("overflow"), MessagePriority::Low)
            .await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("mailbox full"));
        assert_eq!(router.dead_letter_count(), 1);
    }

    #[tokio::test]
    async fn test_closed_mailbox_goes_to_dead_letters() {
        let router = MessageRouter::new();
        let sender = FighterId::new();
        let target = FighterId::new();
        // Dropping the receiver closes the channel while the sender stays registered.
        drop(router.register(target));

        let result = router
            .send_direct(sender, target, status("closed"), MessagePriority::Low)
            .await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("mailbox closed"));
        assert_eq!(router.dead_letter_count(), 1);
    }

    #[tokio::test]
    async fn test_dead_letters_are_capped() {
        let router = MessageRouter::new();
        let sender = FighterId::new();
        let ghost = FighterId::new();

        for _ in 0..(MAX_DEAD_LETTERS + 5) {
            let _ = router
                .send_direct(sender, ghost, status("lost"), MessagePriority::Low)
                .await;
        }

        assert_eq!(router.dead_letter_count(), MAX_DEAD_LETTERS);
    }

    #[test]
    fn test_unregister() {
        let router = MessageRouter::new();
        let f = FighterId::new();
        let _rx = router.register(f);
        assert!(router.is_registered(&f));
        router.unregister(&f);
        assert!(!router.is_registered(&f));
    }

    #[test]
    fn test_registered_count() {
        let router = MessageRouter::new();
        assert_eq!(router.registered_count(), 0);
        let f1 = FighterId::new();
        let f2 = FighterId::new();
        let _rx1 = router.register(f1);
        let _rx2 = router.register(f2);
        assert_eq!(router.registered_count(), 2);
    }

    #[tokio::test]
    async fn test_broadcast_excludes_sender() {
        let router = MessageRouter::new();
        let sender = FighterId::new();
        let mut sender_rx = router.register(sender);

        let ids = router
            .broadcast(
                sender,
                AgentMessageType::StatusUpdate {
                    progress: 1.0,
                    detail: "done".to_string(),
                },
                MessagePriority::Normal,
            )
            .await
            .expect("should broadcast");

        assert!(ids.is_empty());

        // Delivery uses `try_send`, so any message for the sender would already
        // be queued; a non-blocking check is deterministic and avoids a sleep.
        assert!(sender_rx.try_recv().is_err());
    }

    #[test]
    fn test_default_impl() {
        let router = MessageRouter::default();
        assert_eq!(router.registered_count(), 0);
    }

    #[tokio::test]
    async fn test_message_priority_preserved() {
        let router = MessageRouter::new();
        let f1 = FighterId::new();
        let f2 = FighterId::new();
        let mut rx = router.register(f1);
        let _rx2 = router.register(f2);

        router
            .send_direct(
                f2,
                f1,
                AgentMessageType::Escalation {
                    reason: "urgent".to_string(),
                    original_task: "task".to_string(),
                },
                MessagePriority::Critical,
            )
            .await
            .expect("should deliver");

        let msg = rx.recv().await.expect("should receive");
        assert_eq!(msg.priority, MessagePriority::Critical);
    }

    #[tokio::test]
    async fn test_respond_to_nonexistent_request() {
        let router = MessageRouter::new();
        let response = AgentMessage {
            id: Uuid::new_v4(),
            from: FighterId::new(),
            to: FighterId::new(),
            channel: MessageChannel::Direct,
            content: AgentMessageType::TaskResult {
                result: "nope".to_string(),
                success: false,
            },
            priority: MessagePriority::Normal,
            timestamp: Utc::now(),
            delivered: false,
        };

        let result = router.respond(&Uuid::new_v4(), response);
        assert!(result.is_err());
    }
}