ringkernel_core/persistent_message.rs
1//! Persistent Message Traits for Type-Based Kernel Dispatch
2//!
3//! This module provides traits and types for user-defined message dispatch within
4//! persistent GPU kernels. It enables multiple analytics types (fraud detection,
5//! aggregations, pattern detection) to run within a single persistent kernel with
6//! type-based routing to specialized handlers.
7//!
8//! # Architecture
9//!
//! ```text
//! Host                   GPU (Persistent Kernel)
//! ┌──────────────┐       ┌─────────────────────────────────────┐
//! │ send_message │──────▶│ H2K Queue                           │
//! │ <FraudCheck> │       │     ↓                               │
//! │              │       │ Type Dispatcher (switch on type_id) │
//! │              │       │   ├─▶ handle_fraud_check()          │
//! │              │       │   ├─▶ handle_aggregate()            │
//! │              │       │   └─▶ handle_pattern_detect()       │
//! │              │       │     ↓                               │
//! │ poll_typed   │◀──────│ K2H Queue                           │
//! │ <FraudResult>│       └─────────────────────────────────────┘
//! └──────────────┘
//! ```
24//!
25//! # Example
26//!
27//! ```ignore
28//! use ringkernel_core::persistent_message::{PersistentMessage, DispatchTable};
29//! use ringkernel_derive::{RingMessage, PersistentMessage};
30//!
31//! #[derive(RingMessage, PersistentMessage)]
32//! #[message(type_id = 1001)]
33//! #[persistent_message(handler_id = 1, requires_response = true)]
34//! pub struct FraudCheckRequest {
35//! pub transaction_id: u64,
36//! pub amount: f32,
37//! pub account_id: u32,
38//! }
39//!
40//! // Runtime usage
41//! sim.send_message(FraudCheckRequest { ... })?; // ~0.03µs
42//! let results: Vec<FraudCheckResult> = sim.poll_typed();
43//! ```
44
45use crate::message::RingMessage;
46
/// Maximum size, in bytes, for an inline payload in extended messages.
///
/// Messages larger than this must use external buffer references. This value is
/// the length of the array returned by `PersistentMessage::to_inline_payload`
/// and the threshold checked by `PersistentMessage::can_inline`.
pub const MAX_INLINE_PAYLOAD_SIZE: usize = 32;
50
/// Bit flags for extended H2K messages.
///
/// Each constant occupies its own bit, so flags can be combined with `|` and
/// tested with `&`.
pub mod message_flags {
    /// Bit 0: the message uses the extended message format.
    pub const FLAG_EXTENDED: u32 = 0b0001;
    /// Bit 1: the message has high priority.
    pub const FLAG_HIGH_PRIORITY: u32 = 0b0010;
    /// Bit 2: the message uses an external buffer.
    pub const FLAG_EXTERNAL_BUFFER: u32 = 0b0100;
    /// Bit 3: the message requires a response.
    pub const FLAG_REQUIRES_RESPONSE: u32 = 0b1000;
}
62
63/// Trait for messages that can be dispatched within a persistent GPU kernel.
64///
65/// This trait extends `RingMessage` with additional metadata needed for
66/// type-based dispatch within a unified kernel. Each message type is
67/// associated with a handler ID that maps to a CUDA device function.
68///
69/// # Implementation
70///
71/// Use the `#[derive(PersistentMessage)]` macro for automatic implementation:
72///
73/// ```ignore
74/// #[derive(RingMessage, PersistentMessage)]
75/// #[message(type_id = 1001)]
76/// #[persistent_message(handler_id = 1, requires_response = true)]
77/// pub struct FraudCheckRequest {
78/// pub transaction_id: u64,
79/// pub amount: f32,
80/// pub account_id: u32,
81/// }
82/// ```
83pub trait PersistentMessage: RingMessage + Sized {
84 /// Handler ID for CUDA dispatch (0-255).
85 ///
86 /// This maps to a case in the generated switch statement:
87 /// ```cuda
88 /// switch (msg->handler_id) {
89 /// case 1: handle_fraud_check(msg, state, response); break;
90 /// // ...
91 /// }
92 /// ```
93 fn handler_id() -> u32;
94
95 /// Whether this message type expects a response.
96 ///
97 /// When true, the kernel will generate a response message after
98 /// processing. The caller should use `poll_typed::<ResponseType>()`
99 /// to retrieve responses.
100 fn requires_response() -> bool {
101 false
102 }
103
104 /// Convert message to inline payload bytes.
105 ///
106 /// Returns `Some([u8; 32])` if the message fits in 32 bytes,
107 /// `None` if the message requires external buffer allocation.
108 fn to_inline_payload(&self) -> Option<[u8; MAX_INLINE_PAYLOAD_SIZE]>;
109
110 /// Reconstruct message from inline payload bytes.
111 ///
112 /// # Errors
113 ///
114 /// Returns error if the payload is invalid or incomplete.
115 fn from_inline_payload(payload: &[u8]) -> crate::error::Result<Self>;
116
117 /// Get the serialized payload size in bytes.
118 fn payload_size() -> usize;
119
120 /// Check if this message type can be inlined (fits in 32 bytes).
121 fn can_inline() -> bool {
122 Self::payload_size() <= MAX_INLINE_PAYLOAD_SIZE
123 }
124}
125
/// Handler registration entry for the dispatch table.
///
/// Holds the metadata for one handler. Two registrations compare equal when
/// all of their fields are equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandlerRegistration {
    /// Handler ID (documented range 0-255; the `u32` type does not enforce it).
    pub handler_id: u32,
    /// Name of the handler function.
    pub name: String,
    /// Message type ID (from `RingMessage::message_type()`).
    pub message_type_id: u64,
    /// Whether this handler produces responses.
    pub produces_response: bool,
    /// Response type ID (if `produces_response` is true).
    pub response_type_id: Option<u64>,
    /// CUDA function body (for code generation).
    pub cuda_body: Option<String>,
}
142
143impl HandlerRegistration {
144 /// Create a new handler registration.
145 pub fn new(handler_id: u32, name: impl Into<String>, message_type_id: u64) -> Self {
146 Self {
147 handler_id,
148 name: name.into(),
149 message_type_id,
150 produces_response: false,
151 response_type_id: None,
152 cuda_body: None,
153 }
154 }
155
156 /// Set whether this handler produces responses.
157 pub fn with_response(mut self, response_type_id: u64) -> Self {
158 self.produces_response = true;
159 self.response_type_id = Some(response_type_id);
160 self
161 }
162
163 /// Set the CUDA function body for code generation.
164 pub fn with_cuda_body(mut self, body: impl Into<String>) -> Self {
165 self.cuda_body = Some(body.into());
166 self
167 }
168}
169
/// Dispatch table mapping handler IDs to their `HandlerRegistration` entries.
///
/// Used during code generation to build the CUDA switch statement.
#[derive(Debug, Clone, Default)]
pub struct DispatchTable {
    /// Registered handlers, in the order they were registered.
    handlers: Vec<HandlerRegistration>,
    /// Largest handler ID registered so far. It is 0 for a table with no
    /// handlers, so an empty table and a table holding only handler 0 report
    /// the same value.
    max_handler_id: u32,
}
180
181impl DispatchTable {
182 /// Create a new empty dispatch table.
183 pub fn new() -> Self {
184 Self::default()
185 }
186
187 /// Register a handler.
188 ///
189 /// Returns an error if a handler with the same ID is already registered.
190 pub fn register(&mut self, registration: HandlerRegistration) -> crate::error::Result<()> {
191 // Check for duplicate handler ID
192 if let Some(existing) = self
193 .handlers
194 .iter()
195 .find(|h| h.handler_id == registration.handler_id)
196 {
197 return Err(crate::error::RingKernelError::InvalidConfig(format!(
198 "Duplicate handler ID: {} (new: {}, existing: {})",
199 registration.handler_id, registration.name, existing.name
200 )));
201 }
202
203 self.max_handler_id = self.max_handler_id.max(registration.handler_id);
204 self.handlers.push(registration);
205 Ok(())
206 }
207
208 /// Register a handler from a PersistentMessage type.
209 pub fn register_message<M: PersistentMessage>(
210 &mut self,
211 name: impl Into<String>,
212 ) -> crate::error::Result<()> {
213 let registration = HandlerRegistration::new(M::handler_id(), name, M::message_type());
214
215 let registration = if M::requires_response() {
216 // Note: Response type ID would need to be provided separately
217 registration
218 } else {
219 registration
220 };
221
222 self.register(registration)
223 }
224
225 /// Get all registered handlers.
226 pub fn handlers(&self) -> &[HandlerRegistration] {
227 &self.handlers
228 }
229
230 /// Get a handler by ID.
231 pub fn get(&self, handler_id: u32) -> Option<&HandlerRegistration> {
232 self.handlers.iter().find(|h| h.handler_id == handler_id)
233 }
234
235 /// Get the maximum handler ID.
236 pub fn max_handler_id(&self) -> u32 {
237 self.max_handler_id
238 }
239
240 /// Get the number of registered handlers.
241 pub fn len(&self) -> usize {
242 self.handlers.len()
243 }
244
245 /// Check if the table is empty.
246 pub fn is_empty(&self) -> bool {
247 self.handlers.is_empty()
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Distinct handlers register successfully, stay retrievable by ID, and the
    /// table reports the right count and maximum ID.
    #[test]
    fn test_dispatch_table_registration() {
        let mut table = DispatchTable::new();
        for registration in [
            HandlerRegistration::new(1, "fraud_check", 1001),
            HandlerRegistration::new(2, "aggregate", 1002),
            HandlerRegistration::new(3, "pattern_detect", 1003).with_response(2003),
        ] {
            table.register(registration).unwrap();
        }

        assert_eq!(table.len(), 3);
        assert_eq!(table.max_handler_id(), 3);

        // A plain registration carries no response metadata.
        let aggregate = table.get(2).unwrap();
        assert_eq!(aggregate.name, "aggregate");
        assert_eq!(aggregate.message_type_id, 1002);
        assert!(!aggregate.produces_response);

        // `with_response` sets both the flag and the response type ID.
        let pattern_detect = table.get(3).unwrap();
        assert!(pattern_detect.produces_response);
        assert_eq!(pattern_detect.response_type_id, Some(2003));
    }

    /// Registering a second handler under an already-used ID is an error.
    #[test]
    fn test_duplicate_handler_returns_error() {
        let mut table = DispatchTable::new();
        table
            .register(HandlerRegistration::new(1, "first", 1001))
            .unwrap();

        let err = table
            .register(HandlerRegistration::new(1, "second", 1002))
            .expect_err("registering a duplicate handler ID must fail");
        assert!(err.to_string().contains("Duplicate handler ID"));
    }

    /// The flag constants keep their bit values.
    #[test]
    fn test_message_flags() {
        let expected = [
            (message_flags::FLAG_EXTENDED, 0x01),
            (message_flags::FLAG_HIGH_PRIORITY, 0x02),
            (message_flags::FLAG_EXTERNAL_BUFFER, 0x04),
            (message_flags::FLAG_REQUIRES_RESPONSE, 0x08),
        ];
        for (flag, value) in expected {
            assert_eq!(flag, value);
        }
    }

    /// The inline payload limit is 32 bytes.
    #[test]
    fn test_max_inline_payload_size() {
        assert_eq!(MAX_INLINE_PAYLOAD_SIZE, 32);
    }
}