ringkernel_core/
k2k.rs

1//! Kernel-to-Kernel (K2K) direct messaging.
2//!
3//! This module provides infrastructure for direct communication between
4//! GPU kernels without host-side mediation.
5
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10use tokio::sync::mpsc;
11
12use crate::error::{Result, RingKernelError};
13use crate::hlc::HlcTimestamp;
14use crate::message::{MessageEnvelope, MessageId};
15use crate::runtime::KernelId;
16
/// Tunable settings for the K2K messaging layer.
///
/// Use [`Default::default`] for the stock values, or adjust individual
/// fields before handing the configuration to a broker.
#[derive(Debug, Clone)]
pub struct K2KConfig {
    /// Upper bound on queued messages. The broker in this module uses it as
    /// the capacity of each registered endpoint's inbound channel.
    pub max_pending_messages: usize,
    /// Delivery timeout, expressed in milliseconds.
    pub delivery_timeout_ms: u64,
    /// Whether message tracing is switched on.
    pub enable_tracing: bool,
    /// Largest hop count a routed message may reach before it is rejected.
    pub max_hops: u8,
}

impl Default for K2KConfig {
    /// Stock configuration: 1024 pending messages, a 5000 ms delivery
    /// timeout, tracing disabled and at most 8 routing hops.
    fn default() -> Self {
        const DEFAULT_MAX_PENDING: usize = 1024;
        const DEFAULT_TIMEOUT_MS: u64 = 5000;
        const DEFAULT_MAX_HOPS: u8 = 8;

        K2KConfig {
            max_pending_messages: DEFAULT_MAX_PENDING,
            delivery_timeout_ms: DEFAULT_TIMEOUT_MS,
            enable_tracing: false,
            max_hops: DEFAULT_MAX_HOPS,
        }
    }
}
40
41/// A K2K message with routing information.
42#[derive(Debug, Clone)]
43pub struct K2KMessage {
44    /// Unique message ID.
45    pub id: MessageId,
46    /// Source kernel.
47    pub source: KernelId,
48    /// Destination kernel.
49    pub destination: KernelId,
50    /// The message envelope.
51    pub envelope: MessageEnvelope,
52    /// Hop count (for detecting routing loops).
53    pub hops: u8,
54    /// Timestamp when message was sent.
55    pub sent_at: HlcTimestamp,
56    /// Priority (higher = more urgent).
57    pub priority: u8,
58}
59
60impl K2KMessage {
61    /// Create a new K2K message.
62    pub fn new(
63        source: KernelId,
64        destination: KernelId,
65        envelope: MessageEnvelope,
66        timestamp: HlcTimestamp,
67    ) -> Self {
68        Self {
69            id: MessageId::generate(),
70            source,
71            destination,
72            envelope,
73            hops: 0,
74            sent_at: timestamp,
75            priority: 0,
76        }
77    }
78
79    /// Create with priority.
80    pub fn with_priority(mut self, priority: u8) -> Self {
81        self.priority = priority;
82        self
83    }
84
85    /// Increment hop count.
86    pub fn increment_hops(&mut self) -> Result<()> {
87        self.hops += 1;
88        if self.hops > 16 {
89            return Err(RingKernelError::K2KError(
90                "Maximum hop count exceeded".to_string(),
91            ));
92        }
93        Ok(())
94    }
95}
96
/// Receipt describing the outcome of a single K2K delivery attempt.
///
/// The broker returns one receipt per send call; the `status` field tells
/// the caller whether the message was enqueued, forwarded or rejected.
#[derive(Debug, Clone)]
pub struct DeliveryReceipt {
    /// ID of the message this receipt refers to.
    pub message_id: MessageId,
    /// Kernel that sent the message.
    pub source: KernelId,
    /// Kernel the message was addressed to.
    pub destination: KernelId,
    /// Outcome of the delivery attempt.
    pub status: DeliveryStatus,
    /// Timestamp attached to the receipt. The broker in this module copies
    /// the message's `sent_at` value here rather than sampling a new
    /// delivery/failure time.
    pub timestamp: HlcTimestamp,
}
111
/// Outcome of a delivery attempt, as reported in a `DeliveryReceipt`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeliveryStatus {
    /// The message was placed directly on the destination's queue.
    Delivered,
    /// The message was handed to an intermediate hop and has not yet
    /// reached its final destination.
    Pending,
    /// No deliverable destination was found. The broker also reports this
    /// when the destination's channel has been closed.
    NotFound,
    /// The destination queue had no free capacity.
    QueueFull,
    /// Delivery timed out.
    Timeout,
    /// The message exceeded the configured maximum hop count.
    MaxHopsExceeded,
}
128
129/// K2K endpoint for a single kernel.
130pub struct K2KEndpoint {
131    /// Kernel ID.
132    kernel_id: KernelId,
133    /// Incoming message channel.
134    receiver: mpsc::Receiver<K2KMessage>,
135    /// Reference to the broker.
136    broker: Arc<K2KBroker>,
137}
138
139impl K2KEndpoint {
140    /// Receive a K2K message (blocking).
141    pub async fn receive(&mut self) -> Option<K2KMessage> {
142        self.receiver.recv().await
143    }
144
145    /// Try to receive a K2K message (non-blocking).
146    pub fn try_receive(&mut self) -> Option<K2KMessage> {
147        self.receiver.try_recv().ok()
148    }
149
150    /// Send a message to another kernel.
151    pub async fn send(
152        &self,
153        destination: KernelId,
154        envelope: MessageEnvelope,
155    ) -> Result<DeliveryReceipt> {
156        self.broker
157            .send(self.kernel_id.clone(), destination, envelope)
158            .await
159    }
160
161    /// Send a high-priority message.
162    pub async fn send_priority(
163        &self,
164        destination: KernelId,
165        envelope: MessageEnvelope,
166        priority: u8,
167    ) -> Result<DeliveryReceipt> {
168        self.broker
169            .send_priority(self.kernel_id.clone(), destination, envelope, priority)
170            .await
171    }
172
173    /// Get pending message count.
174    pub fn pending_count(&self) -> usize {
175        // Note: This is an estimate since the channel may be modified concurrently
176        0 // mpsc doesn't provide len() directly
177    }
178}
179
180/// K2K message broker for routing messages between kernels.
181pub struct K2KBroker {
182    /// Configuration.
183    config: K2KConfig,
184    /// Registered endpoints (kernel_id -> sender).
185    endpoints: RwLock<HashMap<KernelId, mpsc::Sender<K2KMessage>>>,
186    /// Message counter.
187    message_counter: AtomicU64,
188    /// Delivery receipts (for acknowledgment).
189    receipts: RwLock<HashMap<MessageId, DeliveryReceipt>>,
190    /// Routing table for indirect delivery.
191    routing_table: RwLock<HashMap<KernelId, KernelId>>,
192}
193
194impl K2KBroker {
195    /// Create a new K2K broker.
196    pub fn new(config: K2KConfig) -> Arc<Self> {
197        Arc::new(Self {
198            config,
199            endpoints: RwLock::new(HashMap::new()),
200            message_counter: AtomicU64::new(0),
201            receipts: RwLock::new(HashMap::new()),
202            routing_table: RwLock::new(HashMap::new()),
203        })
204    }
205
206    /// Register a kernel endpoint.
207    pub fn register(self: &Arc<Self>, kernel_id: KernelId) -> K2KEndpoint {
208        let (sender, receiver) = mpsc::channel(self.config.max_pending_messages);
209
210        self.endpoints.write().insert(kernel_id.clone(), sender);
211
212        K2KEndpoint {
213            kernel_id,
214            receiver,
215            broker: Arc::clone(self),
216        }
217    }
218
219    /// Unregister a kernel endpoint.
220    pub fn unregister(&self, kernel_id: &KernelId) {
221        self.endpoints.write().remove(kernel_id);
222        self.routing_table.write().remove(kernel_id);
223    }
224
225    /// Check if a kernel is registered.
226    pub fn is_registered(&self, kernel_id: &KernelId) -> bool {
227        self.endpoints.read().contains_key(kernel_id)
228    }
229
230    /// Get all registered kernels.
231    pub fn registered_kernels(&self) -> Vec<KernelId> {
232        self.endpoints.read().keys().cloned().collect()
233    }
234
235    /// Send a message from one kernel to another.
236    pub async fn send(
237        &self,
238        source: KernelId,
239        destination: KernelId,
240        envelope: MessageEnvelope,
241    ) -> Result<DeliveryReceipt> {
242        self.send_priority(source, destination, envelope, 0).await
243    }
244
245    /// Send a priority message.
246    pub async fn send_priority(
247        &self,
248        source: KernelId,
249        destination: KernelId,
250        envelope: MessageEnvelope,
251        priority: u8,
252    ) -> Result<DeliveryReceipt> {
253        let timestamp = envelope.header.timestamp;
254        let mut message = K2KMessage::new(source.clone(), destination.clone(), envelope, timestamp);
255        message.priority = priority;
256
257        self.deliver(message).await
258    }
259
260    /// Deliver a message to its destination.
261    async fn deliver(&self, message: K2KMessage) -> Result<DeliveryReceipt> {
262        let message_id = message.id;
263        let source = message.source.clone();
264        let destination = message.destination.clone();
265        let timestamp = message.sent_at;
266
267        // Try direct delivery first
268        let endpoints = self.endpoints.read();
269        if let Some(sender) = endpoints.get(&destination) {
270            match sender.try_send(message) {
271                Ok(()) => {
272                    self.message_counter.fetch_add(1, Ordering::Relaxed);
273                    let receipt = DeliveryReceipt {
274                        message_id,
275                        source,
276                        destination,
277                        status: DeliveryStatus::Delivered,
278                        timestamp,
279                    };
280                    self.receipts.write().insert(message_id, receipt.clone());
281                    return Ok(receipt);
282                }
283                Err(mpsc::error::TrySendError::Full(_)) => {
284                    return Ok(DeliveryReceipt {
285                        message_id,
286                        source,
287                        destination,
288                        status: DeliveryStatus::QueueFull,
289                        timestamp,
290                    });
291                }
292                Err(mpsc::error::TrySendError::Closed(_)) => {
293                    return Ok(DeliveryReceipt {
294                        message_id,
295                        source,
296                        destination,
297                        status: DeliveryStatus::NotFound,
298                        timestamp,
299                    });
300                }
301            }
302        }
303        drop(endpoints);
304
305        // Try routing table
306        let next_hop = {
307            let routing = self.routing_table.read();
308            routing.get(&destination).cloned()
309        };
310
311        if let Some(next_hop) = next_hop {
312            let routed_message = K2KMessage {
313                id: message_id,
314                source,
315                destination: destination.clone(),
316                envelope: message.envelope,
317                hops: message.hops + 1,
318                sent_at: message.sent_at,
319                priority: message.priority,
320            };
321
322            if routed_message.hops > self.config.max_hops {
323                return Ok(DeliveryReceipt {
324                    message_id,
325                    source: routed_message.source,
326                    destination,
327                    status: DeliveryStatus::MaxHopsExceeded,
328                    timestamp,
329                });
330            }
331
332            // Try to deliver to next hop
333            let endpoints = self.endpoints.read();
334            if let Some(sender) = endpoints.get(&next_hop) {
335                if sender.try_send(routed_message).is_ok() {
336                    self.message_counter.fetch_add(1, Ordering::Relaxed);
337                    return Ok(DeliveryReceipt {
338                        message_id,
339                        source: message.source,
340                        destination,
341                        status: DeliveryStatus::Pending,
342                        timestamp,
343                    });
344                }
345            }
346        }
347
348        // Destination not found
349        Ok(DeliveryReceipt {
350            message_id,
351            source: message.source,
352            destination,
353            status: DeliveryStatus::NotFound,
354            timestamp,
355        })
356    }
357
358    /// Add a route to the routing table.
359    pub fn add_route(&self, destination: KernelId, next_hop: KernelId) {
360        self.routing_table.write().insert(destination, next_hop);
361    }
362
363    /// Remove a route from the routing table.
364    pub fn remove_route(&self, destination: &KernelId) {
365        self.routing_table.write().remove(destination);
366    }
367
368    /// Get statistics.
369    pub fn stats(&self) -> K2KStats {
370        K2KStats {
371            registered_endpoints: self.endpoints.read().len(),
372            messages_delivered: self.message_counter.load(Ordering::Relaxed),
373            routes_configured: self.routing_table.read().len(),
374        }
375    }
376
377    /// Get delivery receipt for a message.
378    pub fn get_receipt(&self, message_id: &MessageId) -> Option<DeliveryReceipt> {
379        self.receipts.read().get(message_id).cloned()
380    }
381}
382
/// Snapshot of K2K broker statistics, as returned by `K2KBroker::stats`.
#[derive(Debug, Clone, Default)]
pub struct K2KStats {
    /// Number of endpoints currently registered with the broker.
    pub registered_endpoints: usize,
    /// Total messages the broker successfully enqueued, either directly at
    /// the destination or at a next hop.
    pub messages_delivered: u64,
    /// Number of entries in the routing table.
    pub routes_configured: usize,
}
393
394/// Builder for creating K2K infrastructure.
395pub struct K2KBuilder {
396    config: K2KConfig,
397}
398
399impl K2KBuilder {
400    /// Create a new builder.
401    pub fn new() -> Self {
402        Self {
403            config: K2KConfig::default(),
404        }
405    }
406
407    /// Set maximum pending messages.
408    pub fn max_pending_messages(mut self, count: usize) -> Self {
409        self.config.max_pending_messages = count;
410        self
411    }
412
413    /// Set delivery timeout.
414    pub fn delivery_timeout_ms(mut self, timeout: u64) -> Self {
415        self.config.delivery_timeout_ms = timeout;
416        self
417    }
418
419    /// Enable message tracing.
420    pub fn enable_tracing(mut self, enable: bool) -> Self {
421        self.config.enable_tracing = enable;
422        self
423    }
424
425    /// Set maximum hop count.
426    pub fn max_hops(mut self, hops: u8) -> Self {
427        self.config.max_hops = hops;
428        self
429    }
430
431    /// Build the K2K broker.
432    pub fn build(self) -> Arc<K2KBroker> {
433        K2KBroker::new(self.config)
434    }
435}
436
437impl Default for K2KBuilder {
438    fn default() -> Self {
439        Self::new()
440    }
441}
442
443// ============================================================================
444// K2K Message Type Registry (FR-3)
445// ============================================================================
446
/// Registration information for a K2K-routable message type.
///
/// This struct is automatically generated by the `#[derive(RingMessage)]` macro
/// when `k2k_routable = true` is specified. Registrations are collected
/// through the `inventory` crate and can be enumerated at runtime with
/// `K2KTypeRegistry::discover`.
///
/// # Example
///
/// ```ignore
/// #[derive(RingMessage)]
/// #[ring_message(type_id = 1, domain = "OrderMatching", k2k_routable = true)]
/// pub struct SubmitOrderInput { ... }
///
/// // Runtime discovery
/// let registry = K2KTypeRegistry::discover();
/// assert!(registry.is_routable(501)); // domain base (500) + type_id (1)
/// ```
#[derive(Debug, Clone)]
pub struct K2KMessageRegistration {
    /// Message type ID (from RingMessage::message_type()); used as the
    /// lookup key in the type registry.
    pub type_id: u64,
    /// Full type name for debugging/logging; also usable as a lookup key.
    pub type_name: &'static str,
    /// Whether this message type is routable via K2K.
    pub k2k_routable: bool,
    /// Optional routing category for grouped routing; `None` means the type
    /// is not listed under any category.
    pub category: Option<&'static str>,
}

// Declare `K2KMessageRegistration` as an `inventory`-collected type so that
// entries submitted elsewhere can be iterated at runtime.
inventory::collect!(K2KMessageRegistration);
478
479/// Registry for discovering K2K-routable message types at runtime.
480///
481/// The registry is built by scanning all `K2KMessageRegistration` entries
482/// submitted via the `inventory` crate. This enables runtime discovery of
483/// message types for routing, validation, and monitoring.
484///
485/// # Example
486///
487/// ```ignore
488/// let registry = K2KTypeRegistry::discover();
489///
490/// // Check if a type is routable
491/// if registry.is_routable(501) {
492///     // Allow K2K routing
493/// }
494///
495/// // Get all types in a category
496/// let order_types = registry.get_category("orders");
497/// for type_id in order_types {
498///     println!("Order message type: {}", type_id);
499/// }
500/// ```
501pub struct K2KTypeRegistry {
502    /// Type ID to registration mapping.
503    by_type_id: HashMap<u64, &'static K2KMessageRegistration>,
504    /// Type name to registration mapping.
505    by_type_name: HashMap<&'static str, &'static K2KMessageRegistration>,
506    /// Category to type IDs mapping.
507    by_category: HashMap<&'static str, Vec<u64>>,
508}
509
510impl K2KTypeRegistry {
511    /// Discover all registered K2K message types at runtime.
512    ///
513    /// This scans all `K2KMessageRegistration` entries that were submitted
514    /// via `inventory::submit!` during compilation.
515    pub fn discover() -> Self {
516        let mut registry = Self {
517            by_type_id: HashMap::new(),
518            by_type_name: HashMap::new(),
519            by_category: HashMap::new(),
520        };
521
522        for reg in inventory::iter::<K2KMessageRegistration>() {
523            registry.by_type_id.insert(reg.type_id, reg);
524            registry.by_type_name.insert(reg.type_name, reg);
525            if let Some(cat) = reg.category {
526                registry
527                    .by_category
528                    .entry(cat)
529                    .or_default()
530                    .push(reg.type_id);
531            }
532        }
533
534        registry
535    }
536
537    /// Check if a message type ID is K2K routable.
538    pub fn is_routable(&self, type_id: u64) -> bool {
539        self.by_type_id
540            .get(&type_id)
541            .map(|r| r.k2k_routable)
542            .unwrap_or(false)
543    }
544
545    /// Get registration by type ID.
546    pub fn get(&self, type_id: u64) -> Option<&'static K2KMessageRegistration> {
547        self.by_type_id.get(&type_id).copied()
548    }
549
550    /// Get registration by type name.
551    pub fn get_by_name(&self, type_name: &str) -> Option<&'static K2KMessageRegistration> {
552        self.by_type_name.get(type_name).copied()
553    }
554
555    /// Get all type IDs in a category.
556    pub fn get_category(&self, category: &str) -> &[u64] {
557        self.by_category
558            .get(category)
559            .map(|v| v.as_slice())
560            .unwrap_or(&[])
561    }
562
563    /// Get all registered categories.
564    pub fn categories(&self) -> impl Iterator<Item = &'static str> + '_ {
565        self.by_category.keys().copied()
566    }
567
568    /// Iterate all registered message types.
569    pub fn iter(&self) -> impl Iterator<Item = &'static K2KMessageRegistration> + '_ {
570        self.by_type_id.values().copied()
571    }
572
573    /// Get all routable type IDs.
574    pub fn routable_types(&self) -> Vec<u64> {
575        self.by_type_id
576            .iter()
577            .filter(|(_, r)| r.k2k_routable)
578            .map(|(id, _)| *id)
579            .collect()
580    }
581
582    /// Get total number of registered message types.
583    pub fn len(&self) -> usize {
584        self.by_type_id.len()
585    }
586
587    /// Check if the registry is empty.
588    pub fn is_empty(&self) -> bool {
589        self.by_type_id.is_empty()
590    }
591}
592
593impl Default for K2KTypeRegistry {
594    fn default() -> Self {
595        Self::discover()
596    }
597}
598
599impl std::fmt::Debug for K2KTypeRegistry {
600    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
601        f.debug_struct("K2KTypeRegistry")
602            .field("registered_types", &self.by_type_id.len())
603            .field("categories", &self.by_category.keys().collect::<Vec<_>>())
604            .finish()
605    }
606}
607
#[cfg(test)]
mod tests {
    use super::*;

    /// Registering two kernels makes both visible through the broker.
    #[tokio::test]
    async fn test_k2k_broker_registration() {
        let broker = K2KBuilder::new().build();
        let kernel_ids = [KernelId::new("kernel1"), KernelId::new("kernel2")];

        // Keep the endpoints alive for the duration of the checks.
        let _endpoints: Vec<K2KEndpoint> = kernel_ids
            .iter()
            .map(|id| broker.register(id.clone()))
            .collect();

        for id in &kernel_ids {
            assert!(broker.is_registered(id));
        }
        assert_eq!(broker.registered_kernels().len(), 2);
    }

    /// A message sent from one registered kernel arrives at the other.
    #[tokio::test]
    async fn test_k2k_message_delivery() {
        let broker = K2KBuilder::new().build();

        let sender_id = KernelId::new("kernel1");
        let receiver_id = KernelId::new("kernel2");

        let outbound = broker.register(sender_id.clone());
        let mut inbound = broker.register(receiver_id.clone());

        // Minimal envelope used purely as a payload for the round trip.
        let envelope = MessageEnvelope::empty(1, 2, HlcTimestamp::now(1));

        let receipt = outbound
            .send(receiver_id.clone(), envelope)
            .await
            .unwrap();
        assert_eq!(receipt.status, DeliveryStatus::Delivered);

        let received = inbound.try_receive();
        assert!(received.is_some());
        assert_eq!(received.unwrap().source, sender_id);
    }

    /// The default configuration exposes the documented stock values.
    #[test]
    fn test_k2k_config_default() {
        let defaults = K2KConfig::default();
        assert_eq!(defaults.max_pending_messages, 1024);
        assert_eq!(defaults.delivery_timeout_ms, 5000);
    }
}
656}