Skip to main content

ringkernel_core/
persistent_message.rs

1//! Persistent Message Traits for Type-Based Kernel Dispatch
2//!
3//! This module provides traits and types for user-defined message dispatch within
4//! persistent GPU kernels. It enables multiple analytics types (fraud detection,
5//! aggregations, pattern detection) to run within a single persistent kernel with
6//! type-based routing to specialized handlers.
7//!
8//! # Architecture
9//!
10//! ```text
11//! Host                    GPU (Persistent Kernel)
12//! ┌──────────────┐       ┌─────────────────────────────────────┐
13//! │ send_message │──────▶│ H2K Queue                           │
14//! │ <FraudCheck> │       │   ↓                                 │
15//! │              │       │ Type Dispatcher (switch on type_id) │
16//! │              │       │   ├─▶ handle_fraud_check()          │
17//! │              │       │   ├─▶ handle_aggregate()            │
18//! │              │       │   └─▶ handle_pattern_detect()       │
19//! │              │       │         ↓                           │
20//! │ poll_typed   │◀──────│ K2H Queue                           │
21//! │ <FraudResult>│       └─────────────────────────────────────┘
22//! └──────────────┘
23//! ```
24//!
25//! # Example
26//!
27//! ```ignore
28//! use ringkernel_core::persistent_message::{PersistentMessage, DispatchTable};
29//! use ringkernel_derive::{RingMessage, PersistentMessage};
30//!
31//! #[derive(RingMessage, PersistentMessage)]
32//! #[message(type_id = 1001)]
33//! #[persistent_message(handler_id = 1, requires_response = true)]
34//! pub struct FraudCheckRequest {
35//!     pub transaction_id: u64,
36//!     pub amount: f32,
37//!     pub account_id: u32,
38//! }
39//!
40//! // Runtime usage
41//! sim.send_message(FraudCheckRequest { ... })?;  // ~0.03µs
42//! let results: Vec<FraudCheckResult> = sim.poll_typed();
43//! ```
44
45use crate::message::RingMessage;
46
/// Upper bound, in bytes, on a payload that is carried inline in an extended message.
///
/// A message whose serialized payload exceeds this limit cannot be inlined and
/// has to be passed through an external buffer reference instead.
pub const MAX_INLINE_PAYLOAD_SIZE: usize = 32;
50
/// Bit flags carried by extended H2K messages.
///
/// Every constant occupies its own bit, so several flags can be combined with
/// `|` and tested individually with `&`.
pub mod message_flags {
    /// Bit 0: the message uses the extended message format.
    pub const FLAG_EXTENDED: u32 = 0b0001;
    /// Bit 1: the message has high priority.
    pub const FLAG_HIGH_PRIORITY: u32 = 0b0010;
    /// Bit 2: the message uses an external buffer.
    pub const FLAG_EXTERNAL_BUFFER: u32 = 0b0100;
    /// Bit 3: the message requires a response.
    pub const FLAG_REQUIRES_RESPONSE: u32 = 0b1000;
}
62
63/// Trait for messages that can be dispatched within a persistent GPU kernel.
64///
65/// This trait extends `RingMessage` with additional metadata needed for
66/// type-based dispatch within a unified kernel. Each message type is
67/// associated with a handler ID that maps to a CUDA device function.
68///
69/// # Implementation
70///
71/// Use the `#[derive(PersistentMessage)]` macro for automatic implementation:
72///
73/// ```ignore
74/// #[derive(RingMessage, PersistentMessage)]
75/// #[message(type_id = 1001)]
76/// #[persistent_message(handler_id = 1, requires_response = true)]
77/// pub struct FraudCheckRequest {
78///     pub transaction_id: u64,
79///     pub amount: f32,
80///     pub account_id: u32,
81/// }
82/// ```
83pub trait PersistentMessage: RingMessage + Sized {
84    /// Handler ID for CUDA dispatch (0-255).
85    ///
86    /// This maps to a case in the generated switch statement:
87    /// ```cuda
88    /// switch (msg->handler_id) {
89    ///     case 1: handle_fraud_check(msg, state, response); break;
90    ///     // ...
91    /// }
92    /// ```
93    fn handler_id() -> u32;
94
95    /// Whether this message type expects a response.
96    ///
97    /// When true, the kernel will generate a response message after
98    /// processing. The caller should use `poll_typed::<ResponseType>()`
99    /// to retrieve responses.
100    fn requires_response() -> bool {
101        false
102    }
103
104    /// Convert message to inline payload bytes.
105    ///
106    /// Returns `Some([u8; 32])` if the message fits in 32 bytes,
107    /// `None` if the message requires external buffer allocation.
108    fn to_inline_payload(&self) -> Option<[u8; MAX_INLINE_PAYLOAD_SIZE]>;
109
110    /// Reconstruct message from inline payload bytes.
111    ///
112    /// # Errors
113    ///
114    /// Returns error if the payload is invalid or incomplete.
115    fn from_inline_payload(payload: &[u8]) -> crate::error::Result<Self>;
116
117    /// Get the serialized payload size in bytes.
118    fn payload_size() -> usize;
119
120    /// Check if this message type can be inlined (fits in 32 bytes).
121    fn can_inline() -> bool {
122        Self::payload_size() <= MAX_INLINE_PAYLOAD_SIZE
123    }
124}
125
/// One entry of the dispatch table: ties a handler ID to a named handler
/// function and the message type it is registered for.
#[derive(Debug, Clone)]
pub struct HandlerRegistration {
    /// Handler ID (0-255).
    pub handler_id: u32,
    /// Name of the handler function.
    pub name: String,
    /// Message type ID (from `RingMessage::message_type()`).
    pub message_type_id: u64,
    /// Whether this handler produces responses.
    pub produces_response: bool,
    /// Response type ID (if `produces_response` is true).
    pub response_type_id: Option<u64>,
    /// CUDA function body (for code generation).
    pub cuda_body: Option<String>,
}

impl HandlerRegistration {
    /// Builds a registration that produces no response and has no CUDA body.
    ///
    /// Use [`with_response`](Self::with_response) and
    /// [`with_cuda_body`](Self::with_cuda_body) to fill in the optional parts.
    pub fn new(handler_id: u32, name: impl Into<String>, message_type_id: u64) -> Self {
        let name = name.into();
        Self {
            handler_id,
            name,
            message_type_id,
            produces_response: false,
            response_type_id: None,
            cuda_body: None,
        }
    }

    /// Marks this handler as producing responses of type `response_type_id`.
    ///
    /// Sets `produces_response` to `true` and records the response type ID;
    /// all other fields are carried over unchanged.
    pub fn with_response(self, response_type_id: u64) -> Self {
        Self {
            produces_response: true,
            response_type_id: Some(response_type_id),
            ..self
        }
    }

    /// Attaches the CUDA function body used for code generation.
    ///
    /// All other fields are carried over unchanged.
    pub fn with_cuda_body(self, body: impl Into<String>) -> Self {
        Self {
            cuda_body: Some(body.into()),
            ..self
        }
    }
}
169
170/// Dispatch table mapping handler IDs to functions.
171///
172/// Used during code generation to build the CUDA switch statement.
173#[derive(Debug, Clone, Default)]
174pub struct DispatchTable {
175    /// Registered handlers.
176    handlers: Vec<HandlerRegistration>,
177    /// Maximum handler ID seen.
178    max_handler_id: u32,
179}
180
181impl DispatchTable {
182    /// Create a new empty dispatch table.
183    pub fn new() -> Self {
184        Self::default()
185    }
186
187    /// Register a handler.
188    ///
189    /// # Panics
190    ///
191    /// Panics if a handler with the same ID is already registered.
192    pub fn register(&mut self, registration: HandlerRegistration) {
193        // Check for duplicate handler ID
194        if self
195            .handlers
196            .iter()
197            .any(|h| h.handler_id == registration.handler_id)
198        {
199            panic!(
200                "Duplicate handler ID: {} ({})",
201                registration.handler_id, registration.name
202            );
203        }
204
205        self.max_handler_id = self.max_handler_id.max(registration.handler_id);
206        self.handlers.push(registration);
207    }
208
209    /// Register a handler from a PersistentMessage type.
210    pub fn register_message<M: PersistentMessage>(&mut self, name: impl Into<String>) {
211        let registration = HandlerRegistration::new(M::handler_id(), name, M::message_type());
212
213        let registration = if M::requires_response() {
214            // Note: Response type ID would need to be provided separately
215            registration
216        } else {
217            registration
218        };
219
220        self.register(registration);
221    }
222
223    /// Get all registered handlers.
224    pub fn handlers(&self) -> &[HandlerRegistration] {
225        &self.handlers
226    }
227
228    /// Get a handler by ID.
229    pub fn get(&self, handler_id: u32) -> Option<&HandlerRegistration> {
230        self.handlers.iter().find(|h| h.handler_id == handler_id)
231    }
232
233    /// Get the maximum handler ID.
234    pub fn max_handler_id(&self) -> u32 {
235        self.max_handler_id
236    }
237
238    /// Get the number of registered handlers.
239    pub fn len(&self) -> usize {
240        self.handlers.len()
241    }
242
243    /// Check if the table is empty.
244    pub fn is_empty(&self) -> bool {
245        self.handlers.is_empty()
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Registering distinct handlers makes each one retrievable by ID and
    /// tracks both the count and the largest ID.
    #[test]
    fn test_dispatch_table_registration() {
        let entries = [
            HandlerRegistration::new(1, "fraud_check", 1001),
            HandlerRegistration::new(2, "aggregate", 1002),
            HandlerRegistration::new(3, "pattern_detect", 1003).with_response(2003),
        ];

        let mut table = DispatchTable::new();
        for entry in entries {
            table.register(entry);
        }

        assert_eq!(table.len(), 3);
        assert_eq!(table.max_handler_id(), 3);

        // A handler registered without `with_response` produces no response.
        let aggregate = table.get(2).unwrap();
        assert_eq!(aggregate.name, "aggregate");
        assert_eq!(aggregate.message_type_id, 1002);
        assert!(!aggregate.produces_response);

        // `with_response` sets both the flag and the response type ID.
        let pattern = table.get(3).unwrap();
        assert!(pattern.produces_response);
        assert_eq!(pattern.response_type_id, Some(2003));
    }

    /// Reusing a handler ID must abort registration with a panic.
    #[test]
    #[should_panic(expected = "Duplicate handler ID")]
    fn test_duplicate_handler_panics() {
        let mut table = DispatchTable::new();
        let first = HandlerRegistration::new(1, "first", 1001);
        let second = HandlerRegistration::new(1, "second", 1002);
        table.register(first);
        table.register(second); // Should panic
    }

    /// The flag constants keep their expected bit values.
    #[test]
    fn test_message_flags() {
        use message_flags::{
            FLAG_EXTENDED, FLAG_EXTERNAL_BUFFER, FLAG_HIGH_PRIORITY, FLAG_REQUIRES_RESPONSE,
        };

        assert_eq!(FLAG_EXTENDED, 0x01);
        assert_eq!(FLAG_HIGH_PRIORITY, 0x02);
        assert_eq!(FLAG_EXTERNAL_BUFFER, 0x04);
        assert_eq!(FLAG_REQUIRES_RESPONSE, 0x08);
    }

    /// The inline payload limit is 32 bytes.
    #[test]
    fn test_max_inline_payload_size() {
        assert_eq!(MAX_INLINE_PAYLOAD_SIZE, 32);
    }
}