1use parking_lot::RwLock;
54use std::collections::HashMap;
55use std::sync::Arc;
56
57use crate::error::{Result, RingKernelError};
58use crate::hlc::HlcTimestamp;
59use crate::k2k::{DeliveryReceipt, DeliveryStatus, K2KBroker, K2KConfig};
60use crate::message::MessageEnvelope;
61use crate::persistent_message::{DispatchTable, PersistentMessage};
62use crate::runtime::KernelId;
63
/// Behavioural settings for a [`KernelDispatcher`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DispatcherConfig {
    /// Whether the dispatcher should emit diagnostic log output (default: `false`).
    pub enable_logging: bool,
    /// Whether the dispatcher should record [`DispatcherMetrics`] counters (default: `true`).
    pub enable_metrics: bool,
    /// Priority handed to the K2K broker for every dispatched message (default: `0`).
    pub default_priority: u8,
}
74
75impl Default for DispatcherConfig {
76 fn default() -> Self {
77 Self {
78 enable_logging: false,
79 enable_metrics: true,
80 default_priority: 0,
81 }
82 }
83}
84
/// Point-in-time counters describing dispatcher activity.
///
/// All counters start at zero (`Default`).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DispatcherMetrics {
    /// Number of dispatch attempts, including ones that failed.
    pub messages_dispatched: u64,
    /// Number of dispatches whose receipt reported successful delivery.
    pub messages_delivered: u64,
    /// Number of dispatches rejected because no route existed for the message type.
    pub unknown_type_errors: u64,
    /// Number of dispatches that were routed but not delivered.
    pub delivery_errors: u64,
}
97
98pub struct KernelDispatcher {
105 routes: RwLock<HashMap<u64, KernelId>>,
107 handler_tables: RwLock<HashMap<KernelId, DispatchTable>>,
109 broker: Arc<K2KBroker>,
111 config: DispatcherConfig,
113 metrics: RwLock<DispatcherMetrics>,
115}
116
117impl KernelDispatcher {
118 pub fn builder() -> DispatcherBuilder {
120 DispatcherBuilder::new()
121 }
122
123 pub fn new(broker: Arc<K2KBroker>) -> Self {
125 Self::with_config(broker, DispatcherConfig::default())
126 }
127
128 pub fn with_config(broker: Arc<K2KBroker>, config: DispatcherConfig) -> Self {
130 Self {
131 routes: RwLock::new(HashMap::new()),
132 handler_tables: RwLock::new(HashMap::new()),
133 broker,
134 config,
135 metrics: RwLock::new(DispatcherMetrics::default()),
136 }
137 }
138
139 pub fn register<M: PersistentMessage>(&self, kernel_id: KernelId) -> crate::error::Result<()> {
149 self.register_with_name::<M>(kernel_id, std::any::type_name::<M>())
150 }
151
152 pub fn register_with_name<M: PersistentMessage>(
154 &self,
155 kernel_id: KernelId,
156 handler_name: &str,
157 ) -> crate::error::Result<()> {
158 let type_id = M::message_type();
159
160 self.routes.write().insert(type_id, kernel_id.clone());
162
163 let mut handler_tables = self.handler_tables.write();
165 let table = handler_tables.entry(kernel_id).or_default();
166 table.register_message::<M>(handler_name)
167 }
168
169 pub fn register_route(&self, type_id: u64, kernel_id: KernelId) {
171 self.routes.write().insert(type_id, kernel_id);
172 }
173
174 pub fn unregister(&self, type_id: u64) {
176 self.routes.write().remove(&type_id);
177 }
178
179 pub fn get_route(&self, type_id: u64) -> Option<KernelId> {
181 self.routes.read().get(&type_id).cloned()
182 }
183
184 pub fn has_route(&self, type_id: u64) -> bool {
186 self.routes.read().contains_key(&type_id)
187 }
188
189 pub fn routes(&self) -> Vec<(u64, KernelId)> {
191 self.routes
192 .read()
193 .iter()
194 .map(|(k, v)| (*k, v.clone()))
195 .collect()
196 }
197
198 pub fn get_dispatch_table(&self, kernel_id: &KernelId) -> Option<DispatchTable> {
200 self.handler_tables.read().get(kernel_id).cloned()
201 }
202
203 pub async fn dispatch(&self, envelope: MessageEnvelope) -> Result<DeliveryReceipt> {
213 self.dispatch_from(KernelId::new("host"), envelope).await
215 }
216
217 pub async fn dispatch_from(
219 &self,
220 source: KernelId,
221 envelope: MessageEnvelope,
222 ) -> Result<DeliveryReceipt> {
223 let type_id = envelope.header.message_type;
224
225 let kernel_id = {
227 let routes = self.routes.read();
228 routes.get(&type_id).cloned()
229 };
230
231 let kernel_id = match kernel_id {
232 Some(id) => id,
233 None => {
234 {
236 let mut metrics = self.metrics.write();
237 metrics.messages_dispatched += 1;
238 metrics.unknown_type_errors += 1;
239 }
240 return Err(RingKernelError::K2KError(format!(
241 "No route for message type_id: {}",
242 type_id
243 )));
244 }
245 };
246
247 let receipt = self
249 .broker
250 .send_priority(source, kernel_id, envelope, self.config.default_priority)
251 .await?;
252
253 {
255 let mut metrics = self.metrics.write();
256 metrics.messages_dispatched += 1;
257 match receipt.status {
258 DeliveryStatus::Delivered => metrics.messages_delivered += 1,
259 DeliveryStatus::Pending => {} _ => metrics.delivery_errors += 1,
261 }
262 }
263
264 Ok(receipt)
265 }
266
267 pub async fn dispatch_message<M: PersistentMessage>(
271 &self,
272 message: &M,
273 timestamp: HlcTimestamp,
274 ) -> Result<DeliveryReceipt> {
275 let envelope = MessageEnvelope::new(message, 0, 0, timestamp);
277 self.dispatch(envelope).await
278 }
279
280 pub fn metrics(&self) -> DispatcherMetrics {
282 let metrics = self.metrics.read();
283 DispatcherMetrics {
284 messages_dispatched: metrics.messages_dispatched,
285 messages_delivered: metrics.messages_delivered,
286 unknown_type_errors: metrics.unknown_type_errors,
287 delivery_errors: metrics.delivery_errors,
288 }
289 }
290
291 pub fn reset_metrics(&self) {
293 *self.metrics.write() = DispatcherMetrics::default();
294 }
295
296 pub fn broker(&self) -> &Arc<K2KBroker> {
298 &self.broker
299 }
300}
301
302pub struct DispatcherBuilder {
304 routes: Vec<Route>,
306 config: DispatcherConfig,
308 k2k_config: K2KConfig,
310}
311
312struct Route {
314 type_id: u64,
316 kernel_id: KernelId,
318 handler_name: String,
320 handler_id: Option<u32>,
322 requires_response: bool,
324}
325
326impl DispatcherBuilder {
327 pub fn new() -> Self {
329 Self {
330 routes: Vec::new(),
331 config: DispatcherConfig::default(),
332 k2k_config: K2KConfig::default(),
333 }
334 }
335
336 pub fn route<M: PersistentMessage>(mut self, kernel_id: KernelId) -> Self {
338 self.routes.push(Route {
339 type_id: M::message_type(),
340 kernel_id,
341 handler_name: std::any::type_name::<M>().to_string(),
342 handler_id: Some(M::handler_id()),
343 requires_response: M::requires_response(),
344 });
345 self
346 }
347
348 pub fn route_named<M: PersistentMessage>(
350 mut self,
351 kernel_id: KernelId,
352 handler_name: &str,
353 ) -> Self {
354 self.routes.push(Route {
355 type_id: M::message_type(),
356 kernel_id,
357 handler_name: handler_name.to_string(),
358 handler_id: Some(M::handler_id()),
359 requires_response: M::requires_response(),
360 });
361 self
362 }
363
364 pub fn route_raw(mut self, type_id: u64, kernel_id: KernelId) -> Self {
366 self.routes.push(Route {
367 type_id,
368 kernel_id,
369 handler_name: format!("handler_{}", type_id),
370 handler_id: None,
371 requires_response: false,
372 });
373 self
374 }
375
376 pub fn with_config(mut self, config: DispatcherConfig) -> Self {
378 self.config = config;
379 self
380 }
381
382 pub fn with_k2k_config(mut self, config: K2KConfig) -> Self {
384 self.k2k_config = config;
385 self
386 }
387
388 pub fn with_logging(mut self) -> Self {
390 self.config.enable_logging = true;
391 self
392 }
393
394 pub fn with_priority(mut self, priority: u8) -> Self {
396 self.config.default_priority = priority;
397 self
398 }
399
400 pub fn build(self) -> KernelDispatcher {
402 let broker = K2KBroker::new(self.k2k_config.clone());
403 self.build_with_broker(broker)
404 }
405
406 pub fn build_with_broker(self, broker: Arc<K2KBroker>) -> KernelDispatcher {
408 let dispatcher = KernelDispatcher::with_config(broker, self.config);
409
410 for route in self.routes {
412 dispatcher
413 .routes
414 .write()
415 .insert(route.type_id, route.kernel_id.clone());
416
417 if let Some(handler_id) = route.handler_id {
419 use crate::persistent_message::HandlerRegistration;
420
421 let mut handler_tables = dispatcher.handler_tables.write();
422 let table = handler_tables.entry(route.kernel_id).or_default();
423
424 let mut registration =
425 HandlerRegistration::new(handler_id, &route.handler_name, route.type_id);
426
427 if route.requires_response {
428 registration = registration.with_response(0);
431 }
432
433 if let Err(e) = table.register(registration) {
434 tracing::warn!("Failed to register handler in dispatcher build: {}", e);
435 }
436 }
437 }
438
439 dispatcher
440 }
441}
442
443impl Default for DispatcherBuilder {
444 fn default() -> Self {
445 Self::new()
446 }
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hlc::HlcClock;
    use crate::message::{MessageHeader, RingMessage};
    use crate::persistent_message::MAX_INLINE_PAYLOAD_SIZE;

    /// Minimal persistent message carrying a single little-endian `u64`.
    #[derive(Clone, Copy, Debug)]
    #[repr(C)]
    struct TestRequest {
        value: u64,
    }

    impl RingMessage for TestRequest {
        fn message_type() -> u64 {
            5001
        }

        fn message_id(&self) -> crate::message::MessageId {
            crate::message::MessageId::new(0)
        }

        fn correlation_id(&self) -> crate::message::CorrelationId {
            crate::message::CorrelationId::none()
        }

        fn priority(&self) -> crate::message::Priority {
            crate::message::Priority::Normal
        }

        fn serialize(&self) -> Vec<u8> {
            self.value.to_le_bytes().to_vec()
        }

        fn deserialize(bytes: &[u8]) -> Result<Self> {
            if bytes.len() < 8 {
                return Err(RingKernelError::DeserializationError(
                    "Too small".to_string(),
                ));
            }
            let value = u64::from_le_bytes(bytes[..8].try_into().unwrap());
            Ok(Self { value })
        }

        fn size_hint(&self) -> usize {
            8
        }
    }

    impl PersistentMessage for TestRequest {
        fn handler_id() -> u32 {
            1
        }

        fn requires_response() -> bool {
            true
        }

        fn payload_size() -> usize {
            8
        }

        fn to_inline_payload(&self) -> Option<[u8; MAX_INLINE_PAYLOAD_SIZE]> {
            // Size the buffer from the constant rather than a hard-coded 32 so
            // the fixture keeps compiling if the inline payload size changes.
            let mut payload = [0u8; MAX_INLINE_PAYLOAD_SIZE];
            payload[..8].copy_from_slice(&self.value.to_le_bytes());
            Some(payload)
        }

        fn from_inline_payload(payload: &[u8]) -> Result<Self> {
            if payload.len() < 8 {
                return Err(RingKernelError::DeserializationError(
                    "Too small".to_string(),
                ));
            }
            let value = u64::from_le_bytes(payload[..8].try_into().unwrap());
            Ok(Self { value })
        }
    }

    #[test]
    fn test_dispatcher_builder() {
        let kernel_id = KernelId::new("test_kernel");

        let dispatcher = DispatcherBuilder::new()
            .route::<TestRequest>(kernel_id.clone())
            .build();

        assert!(dispatcher.has_route(5001));
        assert_eq!(dispatcher.get_route(5001), Some(kernel_id));
    }

    #[test]
    fn test_dispatcher_registration() {
        let dispatcher = DispatcherBuilder::new().build();

        let kernel_id = KernelId::new("processor");
        dispatcher
            .register::<TestRequest>(kernel_id.clone())
            .unwrap();

        assert!(dispatcher.has_route(5001));
        assert_eq!(dispatcher.get_route(5001), Some(kernel_id));
    }

    #[test]
    fn test_register_route_raw() {
        let dispatcher = DispatcherBuilder::new().build();
        let kernel_id = KernelId::new("raw_kernel");

        assert!(!dispatcher.has_route(42));
        dispatcher.register_route(42, kernel_id.clone());
        assert_eq!(dispatcher.get_route(42), Some(kernel_id.clone()));

        // Raw routes never populate a dispatch table.
        assert!(dispatcher.get_dispatch_table(&kernel_id).is_none());

        dispatcher.unregister(42);
        assert!(!dispatcher.has_route(42));
    }

    #[test]
    fn test_dispatcher_unregister() {
        let dispatcher = DispatcherBuilder::new()
            .route::<TestRequest>(KernelId::new("processor"))
            .build();

        assert!(dispatcher.has_route(5001));
        dispatcher.unregister(5001);
        assert!(!dispatcher.has_route(5001));
    }

    #[test]
    fn test_dispatcher_routes() {
        let kernel_a = KernelId::new("kernel_a");
        let kernel_b = KernelId::new("kernel_b");

        let dispatcher = DispatcherBuilder::new()
            .route::<TestRequest>(kernel_a.clone())
            .route_raw(9999, kernel_b.clone())
            .build();

        let routes = dispatcher.routes();
        assert_eq!(routes.len(), 2);
        assert!(routes.contains(&(5001, kernel_a)));
        assert!(routes.contains(&(9999, kernel_b)));
    }

    #[test]
    fn test_dispatch_table_generation() {
        let kernel_id = KernelId::new("test_kernel");

        let dispatcher = DispatcherBuilder::new()
            .route::<TestRequest>(kernel_id.clone())
            .build();

        let table = dispatcher
            .get_dispatch_table(&kernel_id)
            .expect("typed route should create a dispatch table");
        assert_eq!(table.len(), 1);

        let handler = table.get(1).unwrap();
        assert_eq!(handler.handler_id, 1);
        assert_eq!(handler.message_type_id, 5001);
    }

    #[tokio::test]
    async fn test_dispatch_unknown_type() {
        let dispatcher = DispatcherBuilder::new().build();

        let clock = HlcClock::new(1);
        let header = MessageHeader::new(9999, 0, 0, 0, clock.now());
        let envelope = MessageEnvelope {
            header,
            payload: vec![],
        };

        let result = dispatcher.dispatch(envelope).await;
        assert!(result.is_err());

        let metrics = dispatcher.metrics();
        assert_eq!(metrics.messages_dispatched, 1);
        assert_eq!(metrics.unknown_type_errors, 1);
    }

    #[tokio::test]
    async fn test_dispatch_to_registered_kernel() {
        let kernel_id = KernelId::new("test_kernel");

        let broker = K2KBroker::new(K2KConfig::default());
        let _endpoint = broker.register(kernel_id.clone());

        let dispatcher = DispatcherBuilder::new()
            .route::<TestRequest>(kernel_id)
            .build_with_broker(broker);

        let clock = HlcClock::new(1);
        let msg = TestRequest { value: 42 };
        let envelope = MessageEnvelope::new(&msg, 0, 0, clock.now());

        let receipt = dispatcher.dispatch(envelope).await.unwrap();
        assert_eq!(receipt.status, DeliveryStatus::Delivered);

        let metrics = dispatcher.metrics();
        assert_eq!(metrics.messages_dispatched, 1);
        assert_eq!(metrics.messages_delivered, 1);
    }

    #[test]
    fn test_metrics_reset() {
        let dispatcher = DispatcherBuilder::new().build();

        {
            let mut metrics = dispatcher.metrics.write();
            metrics.messages_dispatched = 100;
            metrics.messages_delivered = 50;
        }

        let metrics = dispatcher.metrics();
        assert_eq!(metrics.messages_dispatched, 100);

        dispatcher.reset_metrics();

        let metrics = dispatcher.metrics();
        assert_eq!(metrics.messages_dispatched, 0);
        assert_eq!(metrics.messages_delivered, 0);
    }
}