ringkernel_core/persistent_message.rs
1//! Persistent Message Traits for Type-Based Kernel Dispatch
2//!
3//! This module provides traits and types for user-defined message dispatch within
4//! persistent GPU kernels. It enables multiple analytics types (fraud detection,
5//! aggregations, pattern detection) to run within a single persistent kernel with
6//! type-based routing to specialized handlers.
7//!
8//! # Architecture
9//!
//! ```text
//! Host                   GPU (Persistent Kernel)
//! ┌──────────────┐       ┌─────────────────────────────────────┐
//! │ send_message │──────▶│ H2K Queue                           │
//! │ <FraudCheck> │       │     ↓                               │
//! │              │       │ Type Dispatcher (switch on type_id) │
//! │              │       │   ├─▶ handle_fraud_check()          │
//! │              │       │   ├─▶ handle_aggregate()            │
//! │              │       │   └─▶ handle_pattern_detect()       │
//! │              │       │     ↓                               │
//! │ poll_typed   │◀──────│ K2H Queue                           │
//! │ <FraudResult>│       └─────────────────────────────────────┘
//! └──────────────┘
//! ```
24//!
25//! # Example
26//!
27//! ```ignore
28//! use ringkernel_core::persistent_message::{PersistentMessage, DispatchTable};
29//! use ringkernel_derive::{RingMessage, PersistentMessage};
30//!
31//! #[derive(RingMessage, PersistentMessage)]
32//! #[message(type_id = 1001)]
33//! #[persistent_message(handler_id = 1, requires_response = true)]
34//! pub struct FraudCheckRequest {
35//! pub transaction_id: u64,
36//! pub amount: f32,
37//! pub account_id: u32,
38//! }
39//!
40//! // Runtime usage
41//! sim.send_message(FraudCheckRequest { ... })?; // ~0.03µs
42//! let results: Vec<FraudCheckResult> = sim.poll_typed();
43//! ```
44
45use crate::message::RingMessage;
46
/// Maximum size, in bytes, of a payload carried inline in an extended message.
///
/// Messages larger than this must use external buffer references. The same
/// constant is the length of the byte array returned by
/// [`PersistentMessage::to_inline_payload`] and the threshold compared against
/// in [`PersistentMessage::can_inline`].
pub const MAX_INLINE_PAYLOAD_SIZE: usize = 32;
50
/// Bit flags for extended H2K messages.
///
/// Every constant occupies its own bit, so several flags can be OR-ed
/// together into a single `u32`.
pub mod message_flags {
    /// Bit 0 (`0x01`): this is an extended message format.
    pub const FLAG_EXTENDED: u32 = 1 << 0;
    /// Bit 1 (`0x02`): this message has high priority.
    pub const FLAG_HIGH_PRIORITY: u32 = 1 << 1;
    /// Bit 2 (`0x04`): this message uses an external buffer.
    pub const FLAG_EXTERNAL_BUFFER: u32 = 1 << 2;
    /// Bit 3 (`0x08`): this message requires a response.
    pub const FLAG_REQUIRES_RESPONSE: u32 = 1 << 3;
}
62
63/// Trait for messages that can be dispatched within a persistent GPU kernel.
64///
65/// This trait extends `RingMessage` with additional metadata needed for
66/// type-based dispatch within a unified kernel. Each message type is
67/// associated with a handler ID that maps to a CUDA device function.
68///
69/// # Implementation
70///
71/// Use the `#[derive(PersistentMessage)]` macro for automatic implementation:
72///
73/// ```ignore
74/// #[derive(RingMessage, PersistentMessage)]
75/// #[message(type_id = 1001)]
76/// #[persistent_message(handler_id = 1, requires_response = true)]
77/// pub struct FraudCheckRequest {
78/// pub transaction_id: u64,
79/// pub amount: f32,
80/// pub account_id: u32,
81/// }
82/// ```
83pub trait PersistentMessage: RingMessage + Sized {
84 /// Handler ID for CUDA dispatch (0-255).
85 ///
86 /// This maps to a case in the generated switch statement:
87 /// ```cuda
88 /// switch (msg->handler_id) {
89 /// case 1: handle_fraud_check(msg, state, response); break;
90 /// // ...
91 /// }
92 /// ```
93 fn handler_id() -> u32;
94
95 /// Whether this message type expects a response.
96 ///
97 /// When true, the kernel will generate a response message after
98 /// processing. The caller should use `poll_typed::<ResponseType>()`
99 /// to retrieve responses.
100 fn requires_response() -> bool {
101 false
102 }
103
104 /// Convert message to inline payload bytes.
105 ///
106 /// Returns `Some([u8; 32])` if the message fits in 32 bytes,
107 /// `None` if the message requires external buffer allocation.
108 fn to_inline_payload(&self) -> Option<[u8; MAX_INLINE_PAYLOAD_SIZE]>;
109
110 /// Reconstruct message from inline payload bytes.
111 ///
112 /// # Errors
113 ///
114 /// Returns error if the payload is invalid or incomplete.
115 fn from_inline_payload(payload: &[u8]) -> crate::error::Result<Self>;
116
117 /// Get the serialized payload size in bytes.
118 fn payload_size() -> usize;
119
120 /// Check if this message type can be inlined (fits in 32 bytes).
121 fn can_inline() -> bool {
122 Self::payload_size() <= MAX_INLINE_PAYLOAD_SIZE
123 }
124}
125
/// One entry of the dispatch table: describes a single registered handler.
#[derive(Debug, Clone)]
pub struct HandlerRegistration {
    /// Handler ID (0-255).
    pub handler_id: u32,
    /// Name of the handler function.
    pub name: String,
    /// Message type ID (from RingMessage::message_type()).
    pub message_type_id: u64,
    /// Whether this handler produces responses.
    pub produces_response: bool,
    /// Response type ID (if produces_response is true).
    pub response_type_id: Option<u64>,
    /// CUDA function body (for code generation).
    pub cuda_body: Option<String>,
}

impl HandlerRegistration {
    /// Builds a registration for `handler_id` / `message_type_id`.
    ///
    /// The new entry starts without a response (`produces_response` is
    /// `false`, `response_type_id` is `None`) and without a CUDA body; use the
    /// `with_*` builders to fill those in.
    pub fn new(handler_id: u32, name: impl Into<String>, message_type_id: u64) -> Self {
        let name = name.into();
        Self {
            handler_id,
            name,
            message_type_id,
            produces_response: false,
            response_type_id: None,
            cuda_body: None,
        }
    }

    /// Marks this handler as producing responses and records their type ID.
    pub fn with_response(self, response_type_id: u64) -> Self {
        Self {
            produces_response: true,
            response_type_id: Some(response_type_id),
            ..self
        }
    }

    /// Attaches the CUDA function body used for code generation.
    pub fn with_cuda_body(self, body: impl Into<String>) -> Self {
        Self {
            cuda_body: Some(body.into()),
            ..self
        }
    }
}
169
170/// Dispatch table mapping handler IDs to functions.
171///
172/// Used during code generation to build the CUDA switch statement.
173#[derive(Debug, Clone, Default)]
174pub struct DispatchTable {
175 /// Registered handlers.
176 handlers: Vec<HandlerRegistration>,
177 /// Maximum handler ID seen.
178 max_handler_id: u32,
179}
180
181impl DispatchTable {
182 /// Create a new empty dispatch table.
183 pub fn new() -> Self {
184 Self::default()
185 }
186
187 /// Register a handler.
188 ///
189 /// # Panics
190 ///
191 /// Panics if a handler with the same ID is already registered.
192 pub fn register(&mut self, registration: HandlerRegistration) {
193 // Check for duplicate handler ID
194 if self
195 .handlers
196 .iter()
197 .any(|h| h.handler_id == registration.handler_id)
198 {
199 panic!(
200 "Duplicate handler ID: {} ({})",
201 registration.handler_id, registration.name
202 );
203 }
204
205 self.max_handler_id = self.max_handler_id.max(registration.handler_id);
206 self.handlers.push(registration);
207 }
208
209 /// Register a handler from a PersistentMessage type.
210 pub fn register_message<M: PersistentMessage>(&mut self, name: impl Into<String>) {
211 let registration = HandlerRegistration::new(M::handler_id(), name, M::message_type());
212
213 let registration = if M::requires_response() {
214 // Note: Response type ID would need to be provided separately
215 registration
216 } else {
217 registration
218 };
219
220 self.register(registration);
221 }
222
223 /// Get all registered handlers.
224 pub fn handlers(&self) -> &[HandlerRegistration] {
225 &self.handlers
226 }
227
228 /// Get a handler by ID.
229 pub fn get(&self, handler_id: u32) -> Option<&HandlerRegistration> {
230 self.handlers.iter().find(|h| h.handler_id == handler_id)
231 }
232
233 /// Get the maximum handler ID.
234 pub fn max_handler_id(&self) -> u32 {
235 self.max_handler_id
236 }
237
238 /// Get the number of registered handlers.
239 pub fn len(&self) -> usize {
240 self.handlers.len()
241 }
242
243 /// Check if the table is empty.
244 pub fn is_empty(&self) -> bool {
245 self.handlers.is_empty()
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Registering three distinct handlers makes all of them retrievable and
    /// tracks the table size and the largest handler ID.
    #[test]
    fn test_dispatch_table_registration() {
        let mut table = DispatchTable::new();

        let registrations = [
            HandlerRegistration::new(1, "fraud_check", 1001),
            HandlerRegistration::new(2, "aggregate", 1002),
            HandlerRegistration::new(3, "pattern_detect", 1003).with_response(2003),
        ];
        for registration in registrations {
            table.register(registration);
        }

        assert_eq!(table.len(), 3);
        assert_eq!(table.max_handler_id(), 3);

        // Handler 2 was registered without a response.
        let aggregate = table.get(2).unwrap();
        assert_eq!(aggregate.name, "aggregate");
        assert_eq!(aggregate.message_type_id, 1002);
        assert!(!aggregate.produces_response);

        // Handler 3 was registered with a response type.
        let pattern = table.get(3).unwrap();
        assert!(pattern.produces_response);
        assert_eq!(pattern.response_type_id, Some(2003));
    }

    /// Registering the same handler ID twice must panic.
    #[test]
    #[should_panic(expected = "Duplicate handler ID")]
    fn test_duplicate_handler_panics() {
        let mut table = DispatchTable::new();
        let first = HandlerRegistration::new(1, "first", 1001);
        let second = HandlerRegistration::new(1, "second", 1002);
        table.register(first);
        table.register(second); // Should panic
    }

    /// The numeric values of the message flags are fixed.
    #[test]
    fn test_message_flags() {
        let expected = [
            (message_flags::FLAG_EXTENDED, 0x01),
            (message_flags::FLAG_HIGH_PRIORITY, 0x02),
            (message_flags::FLAG_EXTERNAL_BUFFER, 0x04),
            (message_flags::FLAG_REQUIRES_RESPONSE, 0x08),
        ];
        for (flag, value) in expected {
            assert_eq!(flag, value);
        }
    }

    /// The inline payload limit is 32 bytes.
    #[test]
    fn test_max_inline_payload_size() {
        assert_eq!(MAX_INLINE_PAYLOAD_SIZE, 32);
    }
}