ringkernel_core/
dispatcher.rs

1//! Multi-Kernel Message Dispatcher
2//!
3//! This module provides a `KernelDispatcher` that routes messages by type_id
4//! to appropriate handler kernels. It builds on the K2K broker infrastructure
5//! to enable type-based message routing across multiple GPU kernels.
6//!
7//! # Architecture
8//!
9//! ```text
10//! Host Application
11//!       │
12//!       ▼
13//! ┌─────────────────────────────────────────────────┐
14//! │            KernelDispatcher                      │
15//! │  ┌─────────────────────────────────────────┐   │
16//! │  │ Route Table (type_id → kernel_id)       │   │
17//! │  │  1001 → fraud_processor                 │   │
18//! │  │  1002 → aggregator                      │   │
19//! │  │  1003 → pattern_detector                │   │
20//! │  └─────────────────────────────────────────┘   │
21//! │                     │                           │
22//! │                     ▼                           │
23//! │  ┌─────────────────────────────────────────┐   │
24//! │  │            K2K Broker                   │   │
25//! │  └─────────────────────────────────────────┘   │
26//! └─────────────────────────────────────────────────┘
27//!                       │
28//!       ┌───────────────┼───────────────┐
29//!       ▼               ▼               ▼
30//! ┌──────────┐   ┌──────────┐   ┌──────────┐
31//! │ Kernel A │   │ Kernel B │   │ Kernel C │
32//! └──────────┘   └──────────┘   └──────────┘
33//! ```
34//!
35//! # Example
36//!
37//! ```ignore
//! use ringkernel_core::dispatcher::{KernelDispatcher, DispatcherBuilder};
//! use ringkernel_core::k2k::{K2KBroker, K2KConfig};
//!
//! // Create dispatcher with routes
//! let broker = K2KBroker::new(K2KConfig::default());
//! let dispatcher = DispatcherBuilder::new()
//!     .route::<FraudCheckRequest>(KernelId::new("fraud_processor"))
//!     .route::<AggregateRequest>(KernelId::new("aggregator"))
//!     .build_with_broker(broker);
47//!
48//! // Dispatch a message (routing determined by type_id)
49//! let envelope = MessageEnvelope::from_message(&fraud_check, clock.now());
50//! let receipt = dispatcher.dispatch(envelope).await?;
51//! ```
52
53use parking_lot::RwLock;
54use std::collections::HashMap;
55use std::sync::Arc;
56
57use crate::error::{Result, RingKernelError};
58use crate::hlc::HlcTimestamp;
59use crate::k2k::{DeliveryReceipt, DeliveryStatus, K2KBroker, K2KConfig};
60use crate::message::MessageEnvelope;
61use crate::persistent_message::{DispatchTable, PersistentMessage};
62use crate::runtime::KernelId;
63
/// Tunable settings for a `KernelDispatcher`.
///
/// Obtain the stock settings through `DispatcherConfig::default()` and
/// override individual fields as needed.
#[derive(Debug, Clone)]
pub struct DispatcherConfig {
    /// When `true`, dispatch operations should be logged.
    pub enable_logging: bool,
    /// When `true`, dispatch metrics should be collected.
    pub enable_metrics: bool,
    /// Priority value handed to the broker for every dispatched message.
    pub default_priority: u8,
}
74
75impl Default for DispatcherConfig {
76    fn default() -> Self {
77        Self {
78            enable_logging: false,
79            enable_metrics: true,
80            default_priority: 0,
81        }
82    }
83}
84
/// Counters describing what a dispatcher has done so far.
///
/// This is a plain value type: a snapshot can be freely copied and two
/// snapshots can be compared with `==` (useful in assertions and when
/// diffing metrics between two points in time).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct DispatcherMetrics {
    /// Total messages dispatched.
    pub messages_dispatched: u64,
    /// Messages successfully delivered.
    pub messages_delivered: u64,
    /// Messages that failed to route (unknown type).
    pub unknown_type_errors: u64,
    /// Messages that failed to deliver (queue full, etc.).
    pub delivery_errors: u64,
}
97
98/// Routes messages by type_id to registered handler kernels.
99///
100/// The dispatcher maintains a routing table mapping message type IDs to kernel IDs.
101/// When a message envelope is dispatched, the dispatcher looks up the type_id
102/// in the routing table and forwards the message to the appropriate kernel
103/// via the K2K broker.
104pub struct KernelDispatcher {
105    /// Routing table: type_id -> kernel_id
106    routes: RwLock<HashMap<u64, KernelId>>,
107    /// Handler dispatch tables per kernel (for CUDA codegen)
108    handler_tables: RwLock<HashMap<KernelId, DispatchTable>>,
109    /// K2K broker for message delivery
110    broker: Arc<K2KBroker>,
111    /// Configuration
112    config: DispatcherConfig,
113    /// Metrics
114    metrics: RwLock<DispatcherMetrics>,
115}
116
117impl KernelDispatcher {
118    /// Create a new dispatcher builder.
119    pub fn builder() -> DispatcherBuilder {
120        DispatcherBuilder::new()
121    }
122
123    /// Create a new dispatcher with the given broker.
124    pub fn new(broker: Arc<K2KBroker>) -> Self {
125        Self::with_config(broker, DispatcherConfig::default())
126    }
127
128    /// Create a new dispatcher with custom configuration.
129    pub fn with_config(broker: Arc<K2KBroker>, config: DispatcherConfig) -> Self {
130        Self {
131            routes: RwLock::new(HashMap::new()),
132            handler_tables: RwLock::new(HashMap::new()),
133            broker,
134            config,
135            metrics: RwLock::new(DispatcherMetrics::default()),
136        }
137    }
138
139    /// Register a message type to route to a specific kernel.
140    ///
141    /// # Type Parameters
142    ///
143    /// - `M`: A message type implementing `PersistentMessage`
144    ///
145    /// # Arguments
146    ///
147    /// - `kernel_id`: The kernel that will handle messages of this type
148    pub fn register<M: PersistentMessage>(&self, kernel_id: KernelId) {
149        self.register_with_name::<M>(kernel_id, std::any::type_name::<M>());
150    }
151
152    /// Register a message type with a custom handler name.
153    pub fn register_with_name<M: PersistentMessage>(
154        &self,
155        kernel_id: KernelId,
156        handler_name: &str,
157    ) {
158        let type_id = M::message_type();
159
160        // Add to routing table
161        self.routes.write().insert(type_id, kernel_id.clone());
162
163        // Add to handler table for the kernel
164        let mut handler_tables = self.handler_tables.write();
165        let table = handler_tables.entry(kernel_id).or_default();
166        table.register_message::<M>(handler_name);
167    }
168
169    /// Register a route with explicit type_id (for dynamic registration).
170    pub fn register_route(&self, type_id: u64, kernel_id: KernelId) {
171        self.routes.write().insert(type_id, kernel_id);
172    }
173
174    /// Unregister a message type.
175    pub fn unregister(&self, type_id: u64) {
176        self.routes.write().remove(&type_id);
177    }
178
179    /// Get the kernel ID for a message type.
180    pub fn get_route(&self, type_id: u64) -> Option<KernelId> {
181        self.routes.read().get(&type_id).cloned()
182    }
183
184    /// Check if a route exists for a type.
185    pub fn has_route(&self, type_id: u64) -> bool {
186        self.routes.read().contains_key(&type_id)
187    }
188
189    /// Get all registered routes.
190    pub fn routes(&self) -> Vec<(u64, KernelId)> {
191        self.routes
192            .read()
193            .iter()
194            .map(|(k, v)| (*k, v.clone()))
195            .collect()
196    }
197
198    /// Get the dispatch table for a kernel (for CUDA codegen).
199    pub fn get_dispatch_table(&self, kernel_id: &KernelId) -> Option<DispatchTable> {
200        self.handler_tables.read().get(kernel_id).cloned()
201    }
202
203    /// Dispatch a message envelope to the appropriate kernel.
204    ///
205    /// The type_id from the envelope header is used to look up the destination
206    /// kernel. If no route exists for the type_id, returns an error.
207    ///
208    /// # Returns
209    ///
210    /// - `Ok(DeliveryReceipt)` with delivery status
211    /// - `Err(RingKernelError::UnknownMessageType)` if no route exists
212    pub async fn dispatch(&self, envelope: MessageEnvelope) -> Result<DeliveryReceipt> {
213        // Use "host" as the default source for dispatched messages
214        self.dispatch_from(KernelId::new("host"), envelope).await
215    }
216
217    /// Dispatch a message from a specific source kernel.
218    pub async fn dispatch_from(
219        &self,
220        source: KernelId,
221        envelope: MessageEnvelope,
222    ) -> Result<DeliveryReceipt> {
223        let type_id = envelope.header.message_type;
224
225        // Look up the destination kernel
226        let kernel_id = {
227            let routes = self.routes.read();
228            routes.get(&type_id).cloned()
229        };
230
231        let kernel_id = match kernel_id {
232            Some(id) => id,
233            None => {
234                // Update metrics
235                {
236                    let mut metrics = self.metrics.write();
237                    metrics.messages_dispatched += 1;
238                    metrics.unknown_type_errors += 1;
239                }
240                return Err(RingKernelError::K2KError(format!(
241                    "No route for message type_id: {}",
242                    type_id
243                )));
244            }
245        };
246
247        // Dispatch via K2K broker
248        let receipt = self
249            .broker
250            .send_priority(source, kernel_id, envelope, self.config.default_priority)
251            .await?;
252
253        // Update metrics
254        {
255            let mut metrics = self.metrics.write();
256            metrics.messages_dispatched += 1;
257            match receipt.status {
258                DeliveryStatus::Delivered => metrics.messages_delivered += 1,
259                DeliveryStatus::Pending => {} // Still in flight
260                _ => metrics.delivery_errors += 1,
261            }
262        }
263
264        Ok(receipt)
265    }
266
267    /// Dispatch a typed message.
268    ///
269    /// Creates an envelope from the message and dispatches it.
270    pub async fn dispatch_message<M: PersistentMessage>(
271        &self,
272        message: &M,
273        timestamp: HlcTimestamp,
274    ) -> Result<DeliveryReceipt> {
275        // Use 0 for source/dest kernel IDs - the dispatcher will route based on type_id
276        let envelope = MessageEnvelope::new(message, 0, 0, timestamp);
277        self.dispatch(envelope).await
278    }
279
280    /// Get current metrics.
281    pub fn metrics(&self) -> DispatcherMetrics {
282        let metrics = self.metrics.read();
283        DispatcherMetrics {
284            messages_dispatched: metrics.messages_dispatched,
285            messages_delivered: metrics.messages_delivered,
286            unknown_type_errors: metrics.unknown_type_errors,
287            delivery_errors: metrics.delivery_errors,
288        }
289    }
290
291    /// Reset metrics.
292    pub fn reset_metrics(&self) {
293        *self.metrics.write() = DispatcherMetrics::default();
294    }
295
296    /// Get a reference to the underlying K2K broker.
297    pub fn broker(&self) -> &Arc<K2KBroker> {
298        &self.broker
299    }
300}
301
302/// Builder for creating a KernelDispatcher.
303pub struct DispatcherBuilder {
304    /// Pending routes to register
305    routes: Vec<Route>,
306    /// Configuration
307    config: DispatcherConfig,
308    /// K2K configuration
309    k2k_config: K2KConfig,
310}
311
312/// A route registration.
313struct Route {
314    /// Message type ID
315    type_id: u64,
316    /// Target kernel ID
317    kernel_id: KernelId,
318    /// Handler name
319    handler_name: String,
320    /// Handler ID (for PersistentMessage types)
321    handler_id: Option<u32>,
322    /// Whether response is required
323    requires_response: bool,
324}
325
326impl DispatcherBuilder {
327    /// Create a new builder.
328    pub fn new() -> Self {
329        Self {
330            routes: Vec::new(),
331            config: DispatcherConfig::default(),
332            k2k_config: K2KConfig::default(),
333        }
334    }
335
336    /// Add a route for a PersistentMessage type.
337    pub fn route<M: PersistentMessage>(mut self, kernel_id: KernelId) -> Self {
338        self.routes.push(Route {
339            type_id: M::message_type(),
340            kernel_id,
341            handler_name: std::any::type_name::<M>().to_string(),
342            handler_id: Some(M::handler_id()),
343            requires_response: M::requires_response(),
344        });
345        self
346    }
347
348    /// Add a route with custom handler name.
349    pub fn route_named<M: PersistentMessage>(
350        mut self,
351        kernel_id: KernelId,
352        handler_name: &str,
353    ) -> Self {
354        self.routes.push(Route {
355            type_id: M::message_type(),
356            kernel_id,
357            handler_name: handler_name.to_string(),
358            handler_id: Some(M::handler_id()),
359            requires_response: M::requires_response(),
360        });
361        self
362    }
363
364    /// Add a raw route (for dynamic type_ids).
365    pub fn route_raw(mut self, type_id: u64, kernel_id: KernelId) -> Self {
366        self.routes.push(Route {
367            type_id,
368            kernel_id,
369            handler_name: format!("handler_{}", type_id),
370            handler_id: None,
371            requires_response: false,
372        });
373        self
374    }
375
376    /// Set dispatcher configuration.
377    pub fn with_config(mut self, config: DispatcherConfig) -> Self {
378        self.config = config;
379        self
380    }
381
382    /// Set K2K configuration.
383    pub fn with_k2k_config(mut self, config: K2KConfig) -> Self {
384        self.k2k_config = config;
385        self
386    }
387
388    /// Enable logging.
389    pub fn with_logging(mut self) -> Self {
390        self.config.enable_logging = true;
391        self
392    }
393
394    /// Set default message priority.
395    pub fn with_priority(mut self, priority: u8) -> Self {
396        self.config.default_priority = priority;
397        self
398    }
399
400    /// Build the dispatcher with a new K2K broker.
401    pub fn build(self) -> KernelDispatcher {
402        let broker = K2KBroker::new(self.k2k_config.clone());
403        self.build_with_broker(broker)
404    }
405
406    /// Build the dispatcher with an existing K2K broker.
407    pub fn build_with_broker(self, broker: Arc<K2KBroker>) -> KernelDispatcher {
408        let dispatcher = KernelDispatcher::with_config(broker, self.config);
409
410        // Register all routes
411        for route in self.routes {
412            dispatcher
413                .routes
414                .write()
415                .insert(route.type_id, route.kernel_id.clone());
416
417            // Also register in handler tables if we have handler_id
418            if let Some(handler_id) = route.handler_id {
419                use crate::persistent_message::HandlerRegistration;
420
421                let mut handler_tables = dispatcher.handler_tables.write();
422                let table = handler_tables.entry(route.kernel_id).or_default();
423
424                let mut registration =
425                    HandlerRegistration::new(handler_id, &route.handler_name, route.type_id);
426
427                if route.requires_response {
428                    // Note: We don't have the response type_id here, so we use 0
429                    // In practice, the response type would be registered separately
430                    registration = registration.with_response(0);
431                }
432
433                table.register(registration);
434            }
435        }
436
437        dispatcher
438    }
439}
440
441impl Default for DispatcherBuilder {
442    fn default() -> Self {
443        Self::new()
444    }
445}
446
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hlc::HlcClock;
    use crate::message::{MessageHeader, RingMessage};

    /// Minimal message type used by the tests below: type ID 5001,
    /// handler ID 1, an 8-byte little-endian payload.
    #[derive(Clone, Copy, Debug)]
    #[repr(C)]
    struct TestRequest {
        value: u64,
    }

    impl RingMessage for TestRequest {
        fn message_type() -> u64 {
            5001
        }

        fn message_id(&self) -> crate::message::MessageId {
            crate::message::MessageId::new(0)
        }

        fn correlation_id(&self) -> crate::message::CorrelationId {
            crate::message::CorrelationId::none()
        }

        fn priority(&self) -> crate::message::Priority {
            crate::message::Priority::Normal
        }

        fn serialize(&self) -> Vec<u8> {
            self.value.to_le_bytes().to_vec()
        }

        fn deserialize(bytes: &[u8]) -> Result<Self> {
            if bytes.len() < 8 {
                return Err(RingKernelError::DeserializationError(
                    "Too small".to_string(),
                ));
            }
            let value = u64::from_le_bytes(bytes[..8].try_into().unwrap());
            Ok(Self { value })
        }

        fn size_hint(&self) -> usize {
            8
        }
    }

    impl PersistentMessage for TestRequest {
        fn handler_id() -> u32 {
            1
        }

        fn requires_response() -> bool {
            true
        }

        fn payload_size() -> usize {
            8
        }

        fn to_inline_payload(
            &self,
        ) -> Option<[u8; crate::persistent_message::MAX_INLINE_PAYLOAD_SIZE]> {
            // Size the buffer from the constant named in the return type
            // rather than repeating its numeric value.
            let mut payload = [0u8; crate::persistent_message::MAX_INLINE_PAYLOAD_SIZE];
            payload[..8].copy_from_slice(&self.value.to_le_bytes());
            Some(payload)
        }

        fn from_inline_payload(payload: &[u8]) -> Result<Self> {
            if payload.len() < 8 {
                return Err(RingKernelError::DeserializationError(
                    "Too small".to_string(),
                ));
            }
            let value = u64::from_le_bytes(payload[..8].try_into().unwrap());
            Ok(Self { value })
        }
    }

    /// A route added through the builder is visible on the built dispatcher.
    #[test]
    fn test_dispatcher_builder() {
        let kernel_id = KernelId::new("test_kernel");

        let dispatcher = DispatcherBuilder::new()
            .route::<TestRequest>(kernel_id.clone())
            .build();

        assert!(dispatcher.has_route(5001));
        assert_eq!(dispatcher.get_route(5001), Some(kernel_id));
    }

    /// A route can be registered after the dispatcher has been built.
    #[test]
    fn test_dispatcher_registration() {
        let dispatcher = DispatcherBuilder::new().build();

        let kernel_id = KernelId::new("processor");
        dispatcher.register::<TestRequest>(kernel_id.clone());

        assert!(dispatcher.has_route(5001));
        assert_eq!(dispatcher.get_route(5001), Some(kernel_id));
    }

    /// Unregistering removes the route.
    #[test]
    fn test_dispatcher_unregister() {
        let dispatcher = DispatcherBuilder::new()
            .route::<TestRequest>(KernelId::new("processor"))
            .build();

        assert!(dispatcher.has_route(5001));
        dispatcher.unregister(5001);
        assert!(!dispatcher.has_route(5001));
    }

    /// `routes()` lists typed and raw routes alike.
    #[test]
    fn test_dispatcher_routes() {
        let kernel_a = KernelId::new("kernel_a");
        let kernel_b = KernelId::new("kernel_b");

        let dispatcher = DispatcherBuilder::new()
            .route::<TestRequest>(kernel_a.clone())
            .route_raw(9999, kernel_b.clone())
            .build();

        let routes = dispatcher.routes();
        assert_eq!(routes.len(), 2);
        assert!(routes.contains(&(5001, kernel_a)));
        assert!(routes.contains(&(9999, kernel_b)));
    }

    /// A typed route also produces a handler entry in the kernel's dispatch table.
    #[test]
    fn test_dispatch_table_generation() {
        let kernel_id = KernelId::new("test_kernel");

        let dispatcher = DispatcherBuilder::new()
            .route::<TestRequest>(kernel_id.clone())
            .build();

        let table = dispatcher.get_dispatch_table(&kernel_id);
        assert!(table.is_some());

        let table = table.unwrap();
        assert_eq!(table.len(), 1);

        let handler = table.get(1).unwrap();
        assert_eq!(handler.handler_id, 1);
        assert_eq!(handler.message_type_id, 5001);
    }

    /// Dispatching an unrouted type fails and is counted as an unknown-type error.
    #[tokio::test]
    async fn test_dispatch_unknown_type() {
        let dispatcher = DispatcherBuilder::new().build();

        let clock = HlcClock::new(1);
        let header = MessageHeader::new(9999, 0, 0, 0, clock.now());
        let envelope = MessageEnvelope {
            header,
            payload: vec![],
        };

        let result = dispatcher.dispatch(envelope).await;
        assert!(result.is_err());

        let metrics = dispatcher.metrics();
        assert_eq!(metrics.messages_dispatched, 1);
        assert_eq!(metrics.unknown_type_errors, 1);
    }

    /// Dispatching to a kernel registered with the broker is delivered and counted.
    #[tokio::test]
    async fn test_dispatch_to_registered_kernel() {
        let kernel_id = KernelId::new("test_kernel");

        let broker = K2KBroker::new(K2KConfig::default());
        // Keep the endpoint alive for the duration of the test.
        let _endpoint = broker.register(kernel_id.clone());

        let dispatcher = DispatcherBuilder::new()
            .route::<TestRequest>(kernel_id)
            .build_with_broker(broker);

        let clock = HlcClock::new(1);
        let msg = TestRequest { value: 42 };
        let envelope = MessageEnvelope::new(&msg, 0, 0, clock.now());

        let receipt = dispatcher.dispatch(envelope).await.unwrap();
        assert_eq!(receipt.status, DeliveryStatus::Delivered);

        let metrics = dispatcher.metrics();
        assert_eq!(metrics.messages_dispatched, 1);
        assert_eq!(metrics.messages_delivered, 1);
    }

    /// `reset_metrics` zeroes previously recorded counters.
    #[test]
    fn test_metrics_reset() {
        let dispatcher = DispatcherBuilder::new().build();

        // Seed the counters directly through the private field.
        {
            let mut metrics = dispatcher.metrics.write();
            metrics.messages_dispatched = 100;
            metrics.messages_delivered = 50;
        }

        let metrics = dispatcher.metrics();
        assert_eq!(metrics.messages_dispatched, 100);

        dispatcher.reset_metrics();

        let metrics = dispatcher.metrics();
        assert_eq!(metrics.messages_dispatched, 0);
        assert_eq!(metrics.messages_delivered, 0);
    }
}