1use core::mem::size_of;
7
8use crate::error::Result;
9use crate::magic::MEMLINK_MAGIC;
10use crate::sproto::{STREAM_ID_SIZE, STREAM_HANDLE_SIZE};
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq)]
13#[repr(C)]
14pub struct StreamHandle {
15 pub stream_id: [u8; STREAM_ID_SIZE],
16 pub total_size: u64,
17 pub expires_ns: u64,
18 magic: u32,
19 _reserved1: [u8; 4],
20 pub checksum: u128,
21}
22
23const _: () = {
24 assert!(
25 size_of::<StreamHandle>() == STREAM_HANDLE_SIZE,
26 "StreamHandle must be exactly 80 bytes"
27 );
28};
29
30impl StreamHandle {
31 pub fn new(stream_id: [u8; STREAM_ID_SIZE], total_size: u64, expires_ns: u64) -> Self {
32 Self {
33 stream_id,
34 total_size,
35 expires_ns,
36 magic: MEMLINK_MAGIC,
37 _reserved1: [0; 4],
38 checksum: 0,
39 }
40 }
41
42 pub fn generate(total_size: u64) -> Self {
43 let stream_id = generate_random_id();
44 let expires_ns = 0;
45 Self::new(stream_id, total_size, expires_ns)
46 }
47
48 pub fn with_timeout(total_size: u64, timeout_ns: u64) -> Self {
49 let stream_id = generate_random_id();
50 let expires_ns = get_current_time_ns().saturating_add(timeout_ns);
51 Self::new(stream_id, total_size, expires_ns)
52 }
53
54 pub fn stream_id(&self) -> &[u8; STREAM_ID_SIZE] {
55 &self.stream_id
56 }
57
58 pub fn total_size(&self) -> u64 {
59 self.total_size
60 }
61
62 pub fn checksum(&self) -> u128 {
63 self.checksum
64 }
65
66 pub fn expires_ns(&self) -> u64 {
67 self.expires_ns
68 }
69
70 pub fn is_expired(&self) -> bool {
71 if self.expires_ns == 0 {
72 return false;
73 }
74
75 let now = get_current_time_ns();
76 now > self.expires_ns
77 }
78
79 pub fn validate(&self) -> Result<(), StreamError> {
80 if self.magic != MEMLINK_MAGIC {
81 return Err(StreamError::InvalidMagic);
82 }
83
84 if self.stream_id.iter().all(|&b| b == 0) {
85 return Err(StreamError::InvalidStreamId);
86 }
87
88 if self.is_expired() {
89 return Err(StreamError::StreamExpired);
90 }
91
92 Ok(())
93 }
94
95 pub fn as_bytes(&self) -> [u8; STREAM_HANDLE_SIZE] {
96 let mut bytes = [0u8; STREAM_HANDLE_SIZE];
97
98 bytes[0..32].copy_from_slice(&self.stream_id);
99 bytes[32..40].copy_from_slice(&self.total_size.to_le_bytes());
100 bytes[40..48].copy_from_slice(&self.expires_ns.to_le_bytes());
101 bytes[48..52].copy_from_slice(&self.magic.to_le_bytes());
102 bytes[52..56].copy_from_slice(&self._reserved1);
103 bytes[56..64].copy_from_slice(&[0u8; 8]);
104 bytes[64..80].copy_from_slice(&self.checksum.to_le_bytes());
105
106 bytes
107 }
108
109 pub fn from_bytes(bytes: &[u8; STREAM_HANDLE_SIZE]) -> Result<Self, StreamError> {
110 let mut stream_id = [0u8; STREAM_ID_SIZE];
111 stream_id.copy_from_slice(&bytes[0..32]);
112
113 let total_size = u64::from_le_bytes([
114 bytes[32], bytes[33], bytes[34], bytes[35],
115 bytes[36], bytes[37], bytes[38], bytes[39],
116 ]);
117
118 let expires_ns = u64::from_le_bytes([
119 bytes[40], bytes[41], bytes[42], bytes[43],
120 bytes[44], bytes[45], bytes[46], bytes[47],
121 ]);
122
123 let magic = u32::from_le_bytes([bytes[48], bytes[49], bytes[50], bytes[51]]);
124 let _reserved1 = [bytes[52], bytes[53], bytes[54], bytes[55]];
125
126 let checksum = u128::from_le_bytes([
127 bytes[64], bytes[65], bytes[66], bytes[67],
128 bytes[68], bytes[69], bytes[70], bytes[71],
129 bytes[72], bytes[73], bytes[74], bytes[75],
130 bytes[76], bytes[77], bytes[78], bytes[79],
131 ]);
132
133 let handle = Self {
134 stream_id,
135 total_size,
136 expires_ns,
137 magic,
138 _reserved1,
139 checksum,
140 };
141
142 handle.validate()?;
143
144 Ok(handle)
145 }
146
147 pub fn set_checksum(&mut self, checksum: u128) {
148 self.checksum = checksum;
149 }
150}
151
/// Reasons a stream handle can be rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamError {
    /// The handle's magic field does not match the expected magic value.
    InvalidMagic,
    /// The handle's stream ID consists solely of zero bytes.
    InvalidStreamId,
    /// The handle's expiry time has passed.
    StreamExpired,
    /// A byte array had an unexpected length.
    /// NOTE(review): no code in this part of the file produces this variant.
    InvalidLength,
}

impl core::fmt::Display for StreamError {
    /// Writes the fixed, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            Self::InvalidMagic => "invalid magic number",
            Self::InvalidStreamId => "invalid stream ID",
            Self::StreamExpired => "stream handle has expired",
            Self::InvalidLength => "invalid byte array length",
        };
        f.write_str(message)
    }
}
170
171fn generate_random_id() -> [u8; STREAM_ID_SIZE] {
172 let mut id = [0u8; STREAM_ID_SIZE];
173 let seed = get_current_time_ns() ^ (MEMLINK_MAGIC as u64);
174
175 let mut state = seed;
176 for chunk in id.chunks_mut(8) {
177 state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
178 chunk.copy_from_slice(&state.to_le_bytes());
179 }
180
181 id
182}
183
/// Returns the current time in nanoseconds.
///
/// With the `std` feature enabled this is the time elapsed since the Unix
/// epoch according to the system clock, truncated to `u64`; a clock reading
/// before the epoch yields `0`. Without `std` no clock is available and the
/// function always returns `0`.
#[inline]
fn get_current_time_ns() -> u64 {
    #[cfg(feature = "std")]
    let nanos = {
        extern crate std;
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default();
        since_epoch.as_nanos() as u64
    };

    #[cfg(not(feature = "std"))]
    let nanos = 0;

    nanos
}