Skip to main content

memlink_protocol/
shandle.rs

1//! Stream handle for large payload transfers.
2//!
3//! Defines StreamHandle struct (80 bytes) with stream_id, total_size,
4//! expires_ns, and checksum for referencing out-of-band data streams.
5
6use core::mem::size_of;
7
8use crate::error::Result;
9use crate::magic::MEMLINK_MAGIC;
10use crate::sproto::{STREAM_ID_SIZE, STREAM_HANDLE_SIZE};
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq)]
13#[repr(C)]
14pub struct StreamHandle {
15    pub stream_id: [u8; STREAM_ID_SIZE],
16    pub total_size: u64,
17    pub expires_ns: u64,
18    magic: u32,
19    _reserved1: [u8; 4],
20    pub checksum: u128,
21}
22
23const _: () = {
24    assert!(
25        size_of::<StreamHandle>() == STREAM_HANDLE_SIZE,
26        "StreamHandle must be exactly 80 bytes"
27    );
28};
29
30impl StreamHandle {
31    pub fn new(stream_id: [u8; STREAM_ID_SIZE], total_size: u64, expires_ns: u64) -> Self {
32        Self {
33            stream_id,
34            total_size,
35            expires_ns,
36            magic: MEMLINK_MAGIC,
37            _reserved1: [0; 4],
38            checksum: 0,
39        }
40    }
41
42    pub fn generate(total_size: u64) -> Self {
43        let stream_id = generate_random_id();
44        let expires_ns = 0;
45        Self::new(stream_id, total_size, expires_ns)
46    }
47
48    pub fn with_timeout(total_size: u64, timeout_ns: u64) -> Self {
49        let stream_id = generate_random_id();
50        let expires_ns = get_current_time_ns().saturating_add(timeout_ns);
51        Self::new(stream_id, total_size, expires_ns)
52    }
53
54    pub fn stream_id(&self) -> &[u8; STREAM_ID_SIZE] {
55        &self.stream_id
56    }
57
58    pub fn total_size(&self) -> u64 {
59        self.total_size
60    }
61
62    pub fn checksum(&self) -> u128 {
63        self.checksum
64    }
65
66    pub fn expires_ns(&self) -> u64 {
67        self.expires_ns
68    }
69
70    pub fn is_expired(&self) -> bool {
71        if self.expires_ns == 0 {
72            return false;
73        }
74
75        let now = get_current_time_ns();
76        now > self.expires_ns
77    }
78
79    pub fn validate(&self) -> Result<(), StreamError> {
80        if self.magic != MEMLINK_MAGIC {
81            return Err(StreamError::InvalidMagic);
82        }
83
84        if self.stream_id.iter().all(|&b| b == 0) {
85            return Err(StreamError::InvalidStreamId);
86        }
87
88        if self.is_expired() {
89            return Err(StreamError::StreamExpired);
90        }
91
92        Ok(())
93    }
94
95    pub fn as_bytes(&self) -> [u8; STREAM_HANDLE_SIZE] {
96        let mut bytes = [0u8; STREAM_HANDLE_SIZE];
97
98        bytes[0..32].copy_from_slice(&self.stream_id);
99        bytes[32..40].copy_from_slice(&self.total_size.to_le_bytes());
100        bytes[40..48].copy_from_slice(&self.expires_ns.to_le_bytes());
101        bytes[48..52].copy_from_slice(&self.magic.to_le_bytes());
102        bytes[52..56].copy_from_slice(&self._reserved1);
103        bytes[56..64].copy_from_slice(&[0u8; 8]);
104        bytes[64..80].copy_from_slice(&self.checksum.to_le_bytes());
105
106        bytes
107    }
108
109    pub fn from_bytes(bytes: &[u8; STREAM_HANDLE_SIZE]) -> Result<Self, StreamError> {
110        let mut stream_id = [0u8; STREAM_ID_SIZE];
111        stream_id.copy_from_slice(&bytes[0..32]);
112
113        let total_size = u64::from_le_bytes([
114            bytes[32], bytes[33], bytes[34], bytes[35],
115            bytes[36], bytes[37], bytes[38], bytes[39],
116        ]);
117
118        let expires_ns = u64::from_le_bytes([
119            bytes[40], bytes[41], bytes[42], bytes[43],
120            bytes[44], bytes[45], bytes[46], bytes[47],
121        ]);
122
123        let magic = u32::from_le_bytes([bytes[48], bytes[49], bytes[50], bytes[51]]);
124        let _reserved1 = [bytes[52], bytes[53], bytes[54], bytes[55]];
125
126        let checksum = u128::from_le_bytes([
127            bytes[64], bytes[65], bytes[66], bytes[67],
128            bytes[68], bytes[69], bytes[70], bytes[71],
129            bytes[72], bytes[73], bytes[74], bytes[75],
130            bytes[76], bytes[77], bytes[78], bytes[79],
131        ]);
132
133        let handle = Self {
134            stream_id,
135            total_size,
136            expires_ns,
137            magic,
138            _reserved1,
139            checksum,
140        };
141
142        handle.validate()?;
143
144        Ok(handle)
145    }
146
147    pub fn set_checksum(&mut self, checksum: u128) {
148        self.checksum = checksum;
149    }
150}
151
/// Reasons a [`StreamHandle`] can be rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamError {
    /// The handle's magic field does not match `MEMLINK_MAGIC`.
    InvalidMagic,
    /// The stream ID consists entirely of zero bytes.
    InvalidStreamId,
    /// The handle's expiry time has passed.
    StreamExpired,
    /// A byte buffer did not have the expected length.
    InvalidLength,
}

impl core::fmt::Display for StreamError {
    /// Writes a short lowercase description; formatter flags such as width are not applied.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match *self {
            Self::InvalidMagic => "invalid magic number",
            Self::InvalidStreamId => "invalid stream ID",
            Self::StreamExpired => "stream handle has expired",
            Self::InvalidLength => "invalid byte array length",
        };
        f.write_str(message)
    }
}
170
171fn generate_random_id() -> [u8; STREAM_ID_SIZE] {
172    let mut id = [0u8; STREAM_ID_SIZE];
173    let seed = get_current_time_ns() ^ (MEMLINK_MAGIC as u64);
174
175    let mut state = seed;
176    for chunk in id.chunks_mut(8) {
177        state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
178        chunk.copy_from_slice(&state.to_le_bytes());
179    }
180
181    id
182}
183
/// Current wall-clock time in nanoseconds since the Unix epoch.
///
/// * With the `std` feature: reads `SystemTime::now()`. A clock set before the epoch yields
///   `0`, and the `u128` nanosecond count is narrowed to `u64` with an `as` cast.
/// * Without the `std` feature: there is no clock source, so this always returns `0`.
#[inline]
fn get_current_time_ns() -> u64 {
    #[cfg(feature = "std")]
    {
        extern crate std;
        let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
        since_epoch.map_or(0, |elapsed| elapsed.as_nanos() as u64)
    }
    #[cfg(not(feature = "std"))]
    {
        0
    }
}