Skip to main content

ringkernel_core/
message.rs

1//! Message types and traits for kernel-to-kernel communication.
2//!
3//! This module defines the core message abstraction used for communication
4//! between GPU kernels and between host and device.
5
6use bytemuck::{Pod, Zeroable};
7use rkyv::{Archive, Deserialize, Serialize};
8use zerocopy::{AsBytes, FromBytes, FromZeroes};
9
10use crate::hlc::HlcTimestamp;
11use crate::k2k::audit_tag::AuditTag;
12use crate::k2k::tenant::TenantId;
13use crate::provenance::ProvenanceHeader;
14
15/// Unique message identifier.
16#[derive(
17    Debug,
18    Clone,
19    Copy,
20    PartialEq,
21    Eq,
22    Hash,
23    Default,
24    AsBytes,
25    FromBytes,
26    FromZeroes,
27    Pod,
28    Zeroable,
29    Archive,
30    Serialize,
31    Deserialize,
32)]
33#[repr(C)]
34pub struct MessageId(pub u64);
35
36impl MessageId {
37    /// Create a new message ID.
38    pub const fn new(id: u64) -> Self {
39        Self(id)
40    }
41
42    /// Generate a new unique message ID.
43    pub fn generate() -> Self {
44        use std::sync::atomic::{AtomicU64, Ordering};
45        static COUNTER: AtomicU64 = AtomicU64::new(1);
46        Self(COUNTER.fetch_add(1, Ordering::Relaxed))
47    }
48
49    /// Get the inner value.
50    pub const fn inner(&self) -> u64 {
51        self.0
52    }
53}
54
55impl std::fmt::Display for MessageId {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        write!(f, "msg:{:016x}", self.0)
58    }
59}
60
61/// Correlation ID for request-response patterns.
62#[derive(
63    Debug,
64    Clone,
65    Copy,
66    PartialEq,
67    Eq,
68    Hash,
69    Default,
70    AsBytes,
71    FromBytes,
72    FromZeroes,
73    Pod,
74    Zeroable,
75    Archive,
76    Serialize,
77    Deserialize,
78)]
79#[repr(C)]
80pub struct CorrelationId(pub u64);
81
82impl CorrelationId {
83    /// Create a new correlation ID.
84    pub const fn new(id: u64) -> Self {
85        Self(id)
86    }
87
88    /// Generate a new unique correlation ID.
89    pub fn generate() -> Self {
90        Self(MessageId::generate().0)
91    }
92
93    /// No correlation (for fire-and-forget messages).
94    pub const fn none() -> Self {
95        Self(0)
96    }
97
98    /// Check if this is a valid correlation ID.
99    pub const fn is_some(&self) -> bool {
100        self.0 != 0
101    }
102}
103
104/// Message priority levels.
105#[derive(
106    Debug,
107    Clone,
108    Copy,
109    PartialEq,
110    Eq,
111    Hash,
112    Default,
113    rkyv::Archive,
114    rkyv::Serialize,
115    rkyv::Deserialize,
116)]
117#[archive(compare(PartialEq))]
118#[repr(u8)]
119pub enum Priority {
120    /// Low priority (background tasks).
121    Low = 0,
122    /// Normal priority (default).
123    #[default]
124    Normal = 1,
125    /// High priority (important tasks).
126    High = 2,
127    /// Critical priority (system messages).
128    Critical = 3,
129}
130
131impl Priority {
132    /// Convert from u8.
133    pub const fn from_u8(value: u8) -> Self {
134        match value {
135            0 => Self::Low,
136            1 => Self::Normal,
137            2 => Self::High,
138            _ => Self::Critical,
139        }
140    }
141
142    /// Convert to u8.
143    pub const fn as_u8(self) -> u8 {
144        self as u8
145    }
146}
147
/// Raw `u8` priority constants for convenient use.
///
/// The four levels are spaced 64 apart. They are on a different scale from
/// the discriminants of the `Priority` enum in this file (`0..=3`), so they
/// are not interchangeable with `Priority::as_u8` values.
///
/// # Example
/// ```ignore
/// use ringkernel::prelude::*;
///
/// let opts = LaunchOptions::default()
///     .with_priority(priority::HIGH);
/// ```
pub mod priority {
    /// Distance between two adjacent levels.
    const STEP: u8 = 64;

    /// Low priority (0): background tasks.
    pub const LOW: u8 = 0;
    /// Normal priority (64): the default.
    pub const NORMAL: u8 = STEP;
    /// High priority (128): important tasks.
    pub const HIGH: u8 = 2 * STEP;
    /// Critical priority (192): system messages.
    pub const CRITICAL: u8 = 3 * STEP;
}
167
/// Fixed-size message header (256 bytes, 64-byte aligned).
///
/// This header precedes the variable-length payload and provides
/// all metadata needed for routing and processing.
///
/// The struct is `#[repr(C)]`, so the field order below is the byte layout
/// that `MessageHeader::as_bytes` exposes; do not reorder fields. The total
/// size is pinned to 256 bytes by a compile-time assertion in this file.
#[derive(Debug, Clone, Copy)]
#[repr(C, align(64))]
pub struct MessageHeader {
    /// Magic number for validation: `MessageHeader::MAGIC`
    /// (`0x52494E474B45524E`, the ASCII bytes of "RINGKERN").
    pub magic: u64,
    /// Header version for compatibility.
    pub version: u32,
    /// Message flags.
    pub flags: u32,
    /// Unique message identifier.
    pub message_id: MessageId,
    /// Correlation ID for request-response.
    pub correlation_id: CorrelationId,
    /// Source kernel ID (0 for host).
    pub source_kernel: u64,
    /// Destination kernel ID (0 for host).
    pub dest_kernel: u64,
    /// Message type discriminator.
    pub message_type: u64,
    /// Priority level, stored as a raw byte (the constructors in this file
    /// write a `Priority` cast to `u8`).
    pub priority: u8,
    /// Reserved for alignment.
    pub _reserved1: [u8; 7],
    /// Payload size in bytes.
    pub payload_size: u64,
    /// Checksum of payload (CRC32).
    // NOTE(review): the constructors in this file initialise this to 0 and
    // `validate` does not inspect it - confirm where it is computed/checked.
    pub checksum: u32,
    /// Reserved for alignment.
    pub _reserved2: u32,
    /// HLC timestamp when message was created.
    pub timestamp: HlcTimestamp,
    /// Deadline timestamp (0 = no deadline).
    pub deadline: HlcTimestamp,
    /// Reserved for future use (split for derive compatibility).
    pub _reserved3a: [u8; 32],
    /// Reserved for future use.
    pub _reserved3b: [u8; 32],
    /// Reserved for future use.
    pub _reserved3c: [u8; 32],
    /// Reserved for future use.
    pub _reserved3d: [u8; 8],
}
214
215impl MessageHeader {
216    /// Magic number for validation.
217    pub const MAGIC: u64 = 0x52494E474B45524E; // "RINGKERN"
218
219    /// Current header version.
220    pub const VERSION: u32 = 1;
221
222    /// Maximum payload size (64KB).
223    pub const MAX_PAYLOAD_SIZE: usize = 64 * 1024;
224
225    /// Convert header to bytes.
226    pub fn as_bytes(&self) -> &[u8] {
227        unsafe {
228            std::slice::from_raw_parts(
229                self as *const Self as *const u8,
230                std::mem::size_of::<Self>(),
231            )
232        }
233    }
234
235    /// Read header from bytes.
236    pub fn read_from(bytes: &[u8]) -> Option<Self> {
237        if bytes.len() < std::mem::size_of::<Self>() {
238            return None;
239        }
240        unsafe { Some(std::ptr::read_unaligned(bytes.as_ptr() as *const Self)) }
241    }
242
243    /// Create a new message header.
244    pub fn new(
245        message_type: u64,
246        source_kernel: u64,
247        dest_kernel: u64,
248        payload_size: usize,
249        timestamp: HlcTimestamp,
250    ) -> Self {
251        Self {
252            magic: Self::MAGIC,
253            version: Self::VERSION,
254            flags: 0,
255            message_id: MessageId::generate(),
256            correlation_id: CorrelationId::none(),
257            source_kernel,
258            dest_kernel,
259            message_type,
260            priority: Priority::Normal as u8,
261            _reserved1: [0; 7],
262            payload_size: payload_size as u64,
263            checksum: 0,
264            _reserved2: 0,
265            timestamp,
266            deadline: HlcTimestamp::zero(),
267            _reserved3a: [0; 32],
268            _reserved3b: [0; 32],
269            _reserved3c: [0; 32],
270            _reserved3d: [0; 8],
271        }
272    }
273
274    /// Validate the header.
275    pub fn validate(&self) -> bool {
276        self.magic == Self::MAGIC
277            && self.version <= Self::VERSION
278            && self.payload_size <= Self::MAX_PAYLOAD_SIZE as u64
279    }
280
281    /// Set correlation ID.
282    pub fn with_correlation(mut self, correlation_id: CorrelationId) -> Self {
283        self.correlation_id = correlation_id;
284        self
285    }
286
287    /// Set priority.
288    pub fn with_priority(mut self, priority: Priority) -> Self {
289        self.priority = priority as u8;
290        self
291    }
292
293    /// Set deadline.
294    pub fn with_deadline(mut self, deadline: HlcTimestamp) -> Self {
295        self.deadline = deadline;
296        self
297    }
298}
299
300impl Default for MessageHeader {
301    fn default() -> Self {
302        Self {
303            magic: Self::MAGIC,
304            version: Self::VERSION,
305            flags: 0,
306            message_id: MessageId::default(),
307            correlation_id: CorrelationId::none(),
308            source_kernel: 0,
309            dest_kernel: 0,
310            message_type: 0,
311            priority: Priority::Normal as u8,
312            _reserved1: [0; 7],
313            payload_size: 0,
314            checksum: 0,
315            _reserved2: 0,
316            timestamp: HlcTimestamp::zero(),
317            deadline: HlcTimestamp::zero(),
318            _reserved3a: [0; 32],
319            _reserved3b: [0; 32],
320            _reserved3c: [0; 32],
321            _reserved3d: [0; 8],
322        }
323    }
324}
325
326// Verify size at compile time
327const _: () = assert!(std::mem::size_of::<MessageHeader>() == 256);
328
329/// Trait for types that can be sent as kernel messages.
330///
331/// This trait is typically implemented via the `#[derive(RingMessage)]` macro.
332///
333/// # Example
334///
335/// ```ignore
336/// #[derive(RingMessage)]
337/// struct MyRequest {
338///     #[message(id)]
339///     id: MessageId,
340///     data: Vec<f32>,
341/// }
342/// ```
343pub trait RingMessage: Send + Sync + 'static {
344    /// Get the message type discriminator.
345    fn message_type() -> u64;
346
347    /// Get the message ID.
348    fn message_id(&self) -> MessageId;
349
350    /// Get the correlation ID (if any).
351    fn correlation_id(&self) -> CorrelationId {
352        CorrelationId::none()
353    }
354
355    /// Get the priority.
356    fn priority(&self) -> Priority {
357        Priority::Normal
358    }
359
360    /// Serialize the message to bytes.
361    fn serialize(&self) -> Vec<u8>;
362
363    /// Deserialize a message from bytes.
364    fn deserialize(bytes: &[u8]) -> crate::error::Result<Self>
365    where
366        Self: Sized;
367
368    /// Get the serialized size hint.
369    fn size_hint(&self) -> usize
370    where
371        Self: Sized,
372    {
373        std::mem::size_of::<Self>()
374    }
375}
376
/// Envelope containing header and serialized payload.
///
/// The optional `provenance` slot carries PROV-O attribution metadata for
/// NSAI reasoning chains (see [`crate::provenance`]). It is `None` in the
/// common case and holds a fixed-size [`ProvenanceHeader`] when populated
/// (see that type for size details).
// NOTE(review): as an inline `Option<ProvenanceHeader>`, this field reserves
// room for the full header in memory whichever variant is stored. The earlier
// "single discriminant byte when `None`" wording presumably described an
// encoded form - confirm before relying on it.
///
/// The `tenant_id` field is the primary multi-tenant isolation key. It
/// defaults to `0` (the unspecified tenant), which preserves single-tenant
/// fast-path behavior and is free of HashMap lookups in the K2K broker.
///
/// The `audit_tag` field carries billable-unit attribution (org_id +
/// engagement_id) and defaults to [`AuditTag::default`] (both fields zero).
/// The K2K broker stamps the sender's tag into outgoing envelopes when the
/// caller leaves this as the default.
///
/// None of the `provenance`, `tenant_id`, or `audit_tag` fields are included
/// in the legacy [`MessageEnvelope::to_bytes`] / [`MessageEnvelope::from_bytes`]
/// wire format, which is defined by `MessageHeader` + raw payload for
/// backwards compatibility. These travel separately (e.g. as part of
/// rkyv-encoded envelope transfer on GPU) or are reattached by the router.
#[derive(Debug, Clone)]
pub struct MessageEnvelope {
    /// Message header.
    pub header: MessageHeader,
    /// Serialized payload.
    pub payload: Vec<u8>,
    /// Optional PROV-O attribution metadata. Defaults to `None`; only
    /// populated when the message participates in an audited reasoning
    /// chain (e.g. VynGraph NSAI pipelines).
    pub provenance: Option<ProvenanceHeader>,
    /// Multi-tenant isolation key. Defaults to `0` (unspecified tenant),
    /// matching single-tenant deployments that never opt in to isolation.
    pub tenant_id: TenantId,
    /// Billable-unit attribution: `{org_id, engagement_id}`. Defaults to
    /// `AuditTag::default()` (both fields zero); the constructors in this
    /// file initialise it with `AuditTag::unspecified()`. The K2K broker
    /// stamps the sending kernel's tag into envelopes whose tag is still
    /// the default.
    pub audit_tag: AuditTag,
}
417
418impl MessageEnvelope {
419    /// Create a new envelope from a message.
420    pub fn new<M: RingMessage>(
421        message: &M,
422        source_kernel: u64,
423        dest_kernel: u64,
424        timestamp: HlcTimestamp,
425    ) -> Self {
426        let payload = message.serialize();
427        let header = MessageHeader::new(
428            M::message_type(),
429            source_kernel,
430            dest_kernel,
431            payload.len(),
432            timestamp,
433        )
434        .with_correlation(message.correlation_id())
435        .with_priority(message.priority());
436
437        Self {
438            header,
439            payload,
440            provenance: None,
441            tenant_id: 0,
442            audit_tag: AuditTag::unspecified(),
443        }
444    }
445
446    /// Get total size (header + payload).
447    pub fn total_size(&self) -> usize {
448        std::mem::size_of::<MessageHeader>() + self.payload.len()
449    }
450
451    /// Serialize to contiguous bytes.
452    ///
453    /// NOTE: the provenance metadata is intentionally *not* serialised here.
454    /// This method keeps the historical wire format unchanged; provenance is
455    /// transported out-of-band or via rkyv-encoded transfer.
456    pub fn to_bytes(&self) -> Vec<u8> {
457        let mut bytes = Vec::with_capacity(self.total_size());
458        bytes.extend_from_slice(self.header.as_bytes());
459        bytes.extend_from_slice(&self.payload);
460        bytes
461    }
462
463    /// Deserialize from bytes.
464    ///
465    /// Reconstructs an envelope with `provenance: None` and default
466    /// tenant/audit fields. Callers that need provenance must reattach it via
467    /// [`MessageEnvelope::with_provenance`]; callers that need tenant
468    /// attribution must stamp it via [`MessageEnvelope::with_tenant_id`] /
469    /// [`MessageEnvelope::with_audit_tag`].
470    pub fn from_bytes(bytes: &[u8]) -> crate::error::Result<Self> {
471        if bytes.len() < std::mem::size_of::<MessageHeader>() {
472            return Err(crate::error::RingKernelError::DeserializationError(
473                "buffer too small for header".to_string(),
474            ));
475        }
476
477        let header_bytes = &bytes[..std::mem::size_of::<MessageHeader>()];
478        let header = MessageHeader::read_from(header_bytes).ok_or_else(|| {
479            crate::error::RingKernelError::DeserializationError("invalid header".to_string())
480        })?;
481
482        if !header.validate() {
483            return Err(crate::error::RingKernelError::ValidationError(
484                "header validation failed".to_string(),
485            ));
486        }
487
488        let payload_start = std::mem::size_of::<MessageHeader>();
489        let payload_end = payload_start + header.payload_size as usize;
490
491        if bytes.len() < payload_end {
492            return Err(crate::error::RingKernelError::DeserializationError(
493                "buffer too small for payload".to_string(),
494            ));
495        }
496
497        let payload = bytes[payload_start..payload_end].to_vec();
498
499        Ok(Self {
500            header,
501            payload,
502            provenance: None,
503            tenant_id: 0,
504            audit_tag: AuditTag::unspecified(),
505        })
506    }
507
508    /// Create an empty envelope (for testing).
509    pub fn empty(source_kernel: u64, dest_kernel: u64, timestamp: HlcTimestamp) -> Self {
510        let header = MessageHeader::new(0, source_kernel, dest_kernel, 0, timestamp);
511        Self {
512            header,
513            payload: Vec::new(),
514            provenance: None,
515            tenant_id: 0,
516            audit_tag: AuditTag::unspecified(),
517        }
518    }
519
520    /// Attach a PROV-O provenance header (builder-style).
521    pub fn with_provenance(mut self, provenance: ProvenanceHeader) -> Self {
522        self.provenance = Some(provenance);
523        self
524    }
525
526    /// Strip provenance (builder-style). Useful when routing a message into
527    /// an untrusted tenant boundary where attribution must not leak.
528    pub fn without_provenance(mut self) -> Self {
529        self.provenance = None;
530        self
531    }
532
533    /// Stamp the envelope with a tenant ID (builder-style).
534    ///
535    /// In the two-tier tenancy model the tenant ID is the primary isolation
536    /// key; the K2K broker uses it to route the message into the correct
537    /// per-tenant sub-broker. Defaults to `0` (unspecified tenant).
538    #[inline]
539    pub fn with_tenant_id(mut self, tenant_id: TenantId) -> Self {
540        self.tenant_id = tenant_id;
541        self
542    }
543
544    /// Attach an audit tag (builder-style).
545    ///
546    /// The audit tag carries billable-unit attribution (`org_id +
547    /// engagement_id`). The K2K broker preserves this tag across delivery
548    /// so downstream cost accounting can attribute GPU-seconds back to the
549    /// specific engagement.
550    #[inline]
551    pub fn with_audit_tag(mut self, audit_tag: AuditTag) -> Self {
552        self.audit_tag = audit_tag;
553        self
554    }
555}
556
557impl Default for MessageEnvelope {
558    fn default() -> Self {
559        Self {
560            header: MessageHeader::default(),
561            payload: Vec::new(),
562            provenance: None,
563            tenant_id: 0,
564            audit_tag: AuditTag::unspecified(),
565        }
566    }
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    /// Two consecutive generated IDs must differ.
    #[test]
    fn test_message_id_generation() {
        let id1 = MessageId::generate();
        let id2 = MessageId::generate();
        assert_ne!(id1, id2);
    }

    /// A freshly built header validates; a corrupted magic does not.
    #[test]
    fn test_header_validation() {
        let header = MessageHeader::new(1, 0, 1, 100, HlcTimestamp::zero());
        assert!(header.validate());

        let mut invalid = header;
        invalid.magic = 0;
        assert!(!invalid.validate());
    }

    /// The header layout is pinned to 256 bytes.
    #[test]
    fn test_header_size() {
        assert_eq!(std::mem::size_of::<MessageHeader>(), 256);
    }

    /// `from_u8` maps 0..=2 exactly and saturates everything else to `Critical`.
    #[test]
    fn test_priority_conversion() {
        assert_eq!(Priority::from_u8(0), Priority::Low);
        assert_eq!(Priority::from_u8(1), Priority::Normal);
        assert_eq!(Priority::from_u8(2), Priority::High);
        assert_eq!(Priority::from_u8(3), Priority::Critical);
        assert_eq!(Priority::from_u8(255), Priority::Critical);
    }

    /// Header and payload survive `to_bytes` / `from_bytes`.
    #[test]
    fn test_envelope_roundtrip() {
        let header = MessageHeader::new(42, 0, 1, 8, HlcTimestamp::now(1));
        let envelope = MessageEnvelope {
            header,
            payload: vec![1, 2, 3, 4, 5, 6, 7, 8],
            provenance: None,
            tenant_id: 0,
            audit_tag: AuditTag::unspecified(),
        };

        let bytes = envelope.to_bytes();
        let restored = MessageEnvelope::from_bytes(&bytes).unwrap();

        assert_eq!(envelope.header.message_type, restored.header.message_type);
        assert_eq!(envelope.payload, restored.payload);
        // Provenance/tenant fields are intentionally not round-tripped through
        // the legacy wire format; restored envelopes carry defaults until
        // reattached.
        assert!(restored.provenance.is_none());
        assert_eq!(restored.tenant_id, 0);
        assert!(restored.audit_tag.is_unspecified());
    }

    /// The tenant and audit-tag builders replace the default values.
    #[test]
    fn test_envelope_with_tenant_and_audit_tag() {
        let envelope = MessageEnvelope::empty(0, 1, HlcTimestamp::zero());
        assert_eq!(envelope.tenant_id, 0);
        assert!(envelope.audit_tag.is_unspecified());

        let tagged = envelope
            .with_tenant_id(42)
            .with_audit_tag(AuditTag::new(99, 7));
        assert_eq!(tagged.tenant_id, 42);
        assert_eq!(tagged.audit_tag, AuditTag::new(99, 7));
    }

    /// Provenance can be attached and stripped again.
    #[test]
    fn test_envelope_with_provenance() {
        use crate::provenance::{ProvNodeType, ProvRelationKind, ProvenanceBuilder};

        let hdr = ProvenanceBuilder::new(ProvNodeType::Entity, 0x42)
            .with_relation(ProvRelationKind::WasAttributedTo, 0x7)
            .build()
            .unwrap();

        let envelope = MessageEnvelope::empty(0, 1, HlcTimestamp::now(1)).with_provenance(hdr);
        assert!(envelope.provenance.is_some());
        assert_eq!(envelope.provenance.unwrap().node_id, 0x42);

        let stripped = envelope.without_provenance();
        assert!(stripped.provenance.is_none());
    }

    /// Both `empty` and the `Default` impl start without provenance and with
    /// default tenant/audit fields.
    #[test]
    fn test_envelope_default_has_no_provenance() {
        let empty = MessageEnvelope::empty(0, 1, HlcTimestamp::zero());
        assert!(empty.provenance.is_none());

        let default = MessageEnvelope::default();
        assert!(default.provenance.is_none());
        assert!(default.payload.is_empty());
        assert_eq!(default.tenant_id, 0);
        assert!(default.audit_tag.is_unspecified());
        assert!(default.header.validate());
    }

    /// `from_bytes` rejects buffers too short for the header or the payload.
    #[test]
    fn test_from_bytes_rejects_truncated_input() {
        // Shorter than a header.
        assert!(MessageEnvelope::from_bytes(&[0u8; 16]).is_err());

        // Valid header, payload cut short by one byte.
        let header = MessageHeader::new(42, 0, 1, 8, HlcTimestamp::zero());
        let envelope = MessageEnvelope {
            header,
            payload: vec![0; 8],
            provenance: None,
            tenant_id: 0,
            audit_tag: AuditTag::unspecified(),
        };
        let bytes = envelope.to_bytes();
        assert!(MessageEnvelope::from_bytes(&bytes[..bytes.len() - 1]).is_err());
    }

    /// `from_bytes` rejects a header whose magic number is wrong.
    #[test]
    fn test_from_bytes_rejects_bad_magic() {
        let mut envelope = MessageEnvelope::empty(0, 1, HlcTimestamp::zero());
        envelope.header.magic = 0;
        assert!(MessageEnvelope::from_bytes(&envelope.to_bytes()).is_err());
    }
}