use crate::arrow::TensorDtype;
use memmap2::{Mmap, MmapMut};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

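/// A writable, memory-mapped buffer holding one or more named tensors,
/// intended to be shared with other processes through the filesystem.
///
/// Usage sketch (marked `ignore`: the path is a placeholder and imports are
/// omitted; the calls themselves are the APIs defined in this module):
///
/// ```ignore
/// let tensors = vec![SharedTensorInfo {
///     name: "weights".to_string(),
///     dtype: TensorDtype::Float32,
///     shape: vec![4],
///     offset: 0,
///     size: 16, // 4 f32 elements
/// }];
/// let values: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
/// let mut buf = SharedTensorBuffer::create("/tmp/example.shm", 16, &tensors)?;
/// buf.write_tensor(&tensors[0], &values);
/// buf.update_checksum();
/// buf.flush()?;
/// ```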
pub struct SharedTensorBuffer {
    /// Writable memory mapping backing the buffer file.
    mmap: MmapMut,
    /// Copy of the header stored at the start of the mapping.
    header: SharedBufferHeader,
    /// Path of the backing file on disk.
    path: PathBuf,
}

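/// Fixed-size header written at offset 0 of every shared buffer file.
///
/// On-disk layout (see `SharedTensorBuffer::create`):
/// header | u64 metadata length (little-endian) | JSON-encoded `Vec<SharedTensorInfo>` | tensor data.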
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct SharedBufferHeader {
    /// Magic value identifying the file format (`SharedBufferHeader::MAGIC`).
    pub magic: u64,
    /// Format version; currently always 1.
    pub version: u32,
    /// Reserved flag bits, currently unused and written as 0.
    pub flags: u32,
    /// Total size of the backing file in bytes.
    pub total_size: u64,
    /// Byte offset at which the tensor data region begins.
    pub data_offset: u64,
    /// Number of tensors described by the JSON metadata block.
    pub num_tensors: u32,
    /// Additive checksum over the data region.
    pub checksum: u64,
    /// Reference count; informational only, not updated across processes.
    pub ref_count: u64,
}

impl SharedBufferHeader {
    /// Magic bytes written at the start of every shared buffer file.
    const MAGIC: u64 = 0x4950_4652_5354_454E;

    pub fn new(total_size: u64, data_offset: u64, num_tensors: u32) -> Self {
        Self {
            magic: Self::MAGIC,
            version: 1,
            flags: 0,
            total_size,
            data_offset,
            num_tensors,
            checksum: 0,
            ref_count: 1,
        }
    }

    /// Returns `true` if the magic value and version match this implementation.
    pub fn validate(&self) -> bool {
        self.magic == Self::MAGIC && self.version == 1
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedTensorInfo {
    /// Tensor name used for lookup.
    pub name: String,
    /// Element data type.
    pub dtype: TensorDtype,
    /// Tensor shape (dimensions).
    pub shape: Vec<usize>,
    /// Byte offset of this tensor within the data region.
    pub offset: usize,
    /// Size of this tensor in bytes.
    pub size: usize,
}

impl SharedTensorBuffer {
    /// Creates a new shared buffer file at `path` with room for `size` bytes of
    /// tensor data, described by `tensors`.
    pub fn create<P: AsRef<Path>>(
        path: P,
        size: usize,
        tensors: &[SharedTensorInfo],
    ) -> Result<Self, SharedMemoryError> {
        let path = path.as_ref().to_path_buf();

        let metadata_json = serde_json::to_vec(tensors)?;
        let metadata_size = metadata_json.len();

        // Layout: header | u64 metadata length | JSON metadata | tensor data.
        let header_size = std::mem::size_of::<SharedBufferHeader>();
        let metadata_offset = header_size;
        let data_offset = metadata_offset + metadata_size + 8;
        let total_size = data_offset + size;

        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(true)
            .open(&path)?;

        file.set_len(total_size as u64)?;

        let mut mmap = unsafe { MmapMut::map_mut(&file)? };

        // Write the header at the start of the mapping.
        let header =
            SharedBufferHeader::new(total_size as u64, data_offset as u64, tensors.len() as u32);
        let header_bytes: &[u8] = unsafe {
            std::slice::from_raw_parts(
                &header as *const SharedBufferHeader as *const u8,
                std::mem::size_of::<SharedBufferHeader>(),
            )
        };
        mmap[..header_size].copy_from_slice(header_bytes);

        // Write the length-prefixed JSON metadata after the header.
        let metadata_len_bytes = (metadata_size as u64).to_le_bytes();
        mmap[metadata_offset..metadata_offset + 8].copy_from_slice(&metadata_len_bytes);
        mmap[metadata_offset + 8..metadata_offset + 8 + metadata_size]
            .copy_from_slice(&metadata_json);

        mmap.flush()?;

        Ok(Self { mmap, header, path })
    }

    /// Opens an existing shared buffer file for reading and writing.
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, SharedMemoryError> {
        let path = path.as_ref().to_path_buf();

        let file = OpenOptions::new().read(true).write(true).open(&path)?;

        let mmap = unsafe { MmapMut::map_mut(&file)? };

        let header_size = std::mem::size_of::<SharedBufferHeader>();
        if mmap.len() < header_size {
            return Err(SharedMemoryError::InvalidFormat("File too small".into()));
        }

        let header: SharedBufferHeader =
            unsafe { std::ptr::read(mmap.as_ptr() as *const SharedBufferHeader) };

        if !header.validate() {
            return Err(SharedMemoryError::InvalidFormat(
                "Invalid header magic or version".into(),
            ));
        }

        Ok(Self { mmap, header, path })
    }

    /// Opens an existing shared buffer file with a read-only mapping.
    pub fn open_readonly<P: AsRef<Path>>(
        path: P,
    ) -> Result<SharedTensorBufferReadOnly, SharedMemoryError> {
        let path = path.as_ref().to_path_buf();
        let file = File::open(&path)?;
        let mmap = unsafe { Mmap::map(&file)? };

        let header_size = std::mem::size_of::<SharedBufferHeader>();
        if mmap.len() < header_size {
            return Err(SharedMemoryError::InvalidFormat("File too small".into()));
        }

        let header: SharedBufferHeader =
            unsafe { std::ptr::read(mmap.as_ptr() as *const SharedBufferHeader) };

        if !header.validate() {
            return Err(SharedMemoryError::InvalidFormat(
                "Invalid header magic or version".into(),
            ));
        }

        Ok(SharedTensorBufferReadOnly { mmap, header, path })
    }

    /// Decodes the JSON tensor metadata stored after the header.
    pub fn tensor_metadata(&self) -> Result<Vec<SharedTensorInfo>, SharedMemoryError> {
        let header_size = std::mem::size_of::<SharedBufferHeader>();

        // The metadata block is prefixed with its length as a little-endian u64.
        let mut len_bytes = [0u8; 8];
        len_bytes.copy_from_slice(&self.mmap[header_size..header_size + 8]);
        let metadata_len = u64::from_le_bytes(len_bytes) as usize;

        let metadata_bytes = &self.mmap[header_size + 8..header_size + 8 + metadata_len];
        let tensors: Vec<SharedTensorInfo> = serde_json::from_slice(metadata_bytes)?;

        Ok(tensors)
    }

    /// Returns a mutable byte slice covering the tensor described by `info`.
    pub fn tensor_data_mut(&mut self, info: &SharedTensorInfo) -> &mut [u8] {
        let start = self.header.data_offset as usize + info.offset;
        let end = start + info.size;
        &mut self.mmap[start..end]
    }

    /// Returns a read-only byte slice covering the tensor described by `info`.
    pub fn tensor_data(&self, info: &SharedTensorInfo) -> &[u8] {
        let start = self.header.data_offset as usize + info.offset;
        let end = start + info.size;
        &self.mmap[start..end]
    }

    /// Copies `data` into the tensor region described by `info`.
    ///
    /// Panics if the byte length of `data` does not match `info.size`.
    pub fn write_tensor<T: Copy>(&mut self, info: &SharedTensorInfo, data: &[T]) {
        let bytes = unsafe {
            std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data))
        };
        self.tensor_data_mut(info).copy_from_slice(bytes);
    }

    /// Copies the tensor described by `info` out of the buffer into a new `Vec<T>`.
    pub fn read_tensor<T: Copy + Default>(&self, info: &SharedTensorInfo) -> Vec<T> {
        let bytes = self.tensor_data(info);
        let elem_size = std::mem::size_of::<T>();
        let count = bytes.len() / elem_size;
        let mut result = vec![T::default(); count];

        let result_bytes = unsafe {
            std::slice::from_raw_parts_mut(result.as_mut_ptr() as *mut u8, count * elem_size)
        };
        result_bytes.copy_from_slice(&bytes[..count * elem_size]);
        result
    }

    /// Recomputes the additive checksum over the data region and writes it into the header.
    pub fn update_checksum(&mut self) {
        let data_start = self.header.data_offset as usize;
        let data = &self.mmap[data_start..];

        let checksum: u64 = data.iter().fold(0u64, |acc, &b| acc.wrapping_add(b as u64));

        // Patch the checksum field directly inside the mapped header.
        let header_bytes = &mut self.mmap[..std::mem::size_of::<SharedBufferHeader>()];
        let offset = std::mem::offset_of!(SharedBufferHeader, checksum);
        header_bytes[offset..offset + 8].copy_from_slice(&checksum.to_le_bytes());

        // Keep the in-memory header copy in sync with what was written to the mapping.
        self.header.checksum = checksum;
    }

    /// Flushes the mapping to the backing file.
    pub fn flush(&self) -> Result<(), SharedMemoryError> {
        self.mmap.flush()?;
        Ok(())
    }

    /// Path of the backing file.
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// Total size of the buffer file in bytes, including header and metadata.
    pub fn size(&self) -> usize {
        self.header.total_size as usize
    }
}

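/// Read-only view of a shared buffer file, typically held by consumer processes.
///
/// Reader-side sketch (marked `ignore`; the path is a placeholder and the
/// element type must match what the writer stored):
///
/// ```ignore
/// let buf = SharedTensorBuffer::open_readonly("/tmp/example.shm")?;
/// assert!(buf.verify_checksum());
/// for info in buf.tensor_metadata()? {
///     let values: Vec<f32> = buf.read_tensor(&info);
///     println!("{} -> {:?}", info.name, values);
/// }
/// ```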
pub struct SharedTensorBufferReadOnly {
    /// Read-only memory mapping of the buffer file.
    mmap: Mmap,
    /// Header read from the start of the mapping.
    header: SharedBufferHeader,
    /// Path of the backing file on disk.
    path: PathBuf,
}

impl SharedTensorBufferReadOnly {
    /// Decodes the JSON tensor metadata stored after the header.
    pub fn tensor_metadata(&self) -> Result<Vec<SharedTensorInfo>, SharedMemoryError> {
        let header_size = std::mem::size_of::<SharedBufferHeader>();

        let mut len_bytes = [0u8; 8];
        len_bytes.copy_from_slice(&self.mmap[header_size..header_size + 8]);
        let metadata_len = u64::from_le_bytes(len_bytes) as usize;

        let metadata_bytes = &self.mmap[header_size + 8..header_size + 8 + metadata_len];
        let tensors: Vec<SharedTensorInfo> = serde_json::from_slice(metadata_bytes)?;

        Ok(tensors)
    }

    /// Returns a read-only byte slice covering the tensor described by `info`.
    pub fn tensor_data(&self, info: &SharedTensorInfo) -> &[u8] {
        let start = self.header.data_offset as usize + info.offset;
        let end = start + info.size;
        &self.mmap[start..end]
    }

    /// Copies the tensor described by `info` out of the buffer into a new `Vec<T>`.
    pub fn read_tensor<T: Copy + Default>(&self, info: &SharedTensorInfo) -> Vec<T> {
        let bytes = self.tensor_data(info);
        let elem_size = std::mem::size_of::<T>();
        let count = bytes.len() / elem_size;
        let mut result = vec![T::default(); count];

        let result_bytes = unsafe {
            std::slice::from_raw_parts_mut(result.as_mut_ptr() as *mut u8, count * elem_size)
        };
        result_bytes.copy_from_slice(&bytes[..count * elem_size]);
        result
    }

    /// Recomputes the additive checksum over the data region and compares it
    /// against the value stored in the header.
    pub fn verify_checksum(&self) -> bool {
        let data_start = self.header.data_offset as usize;
        let data = &self.mmap[data_start..];

        let computed: u64 = data.iter().fold(0u64, |acc, &b| acc.wrapping_add(b as u64));
        computed == self.header.checksum
    }

    /// Path of the backing file.
    pub fn path(&self) -> &Path {
        &self.path
    }
}

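/// Keeps a set of opened read-only buffers in memory, keyed by name, with a
/// soft cap on the total mapped size.
///
/// Sketch of intended use (marked `ignore`; directory and names are placeholders):
///
/// ```ignore
/// let mut pool = SharedMemoryPool::new("/tmp/tensor-pool", 64 * 1024 * 1024);
/// let buf = SharedTensorBuffer::open_readonly("/tmp/tensor-pool/model.shm")?;
/// pool.register("model", buf)?;
/// let model = pool.get("model").expect("just registered");
/// ```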
#[allow(dead_code)]
pub struct SharedMemoryPool {
    /// Directory in which pooled buffer files live.
    base_dir: PathBuf,
    /// Registered read-only buffers, keyed by name.
    buffers: HashMap<String, Arc<SharedTensorBufferReadOnly>>,
    /// Maximum total mapped size in bytes.
    max_size: usize,
    /// Sum of the mapped sizes of all registered buffers, in bytes.
    current_size: AtomicU64,
}

impl SharedMemoryPool {
    /// Creates a pool rooted at `base_dir` with a soft limit of `max_size` bytes.
    pub fn new<P: AsRef<Path>>(base_dir: P, max_size: usize) -> Self {
        std::fs::create_dir_all(base_dir.as_ref()).ok();

        Self {
            base_dir: base_dir.as_ref().to_path_buf(),
            buffers: HashMap::new(),
            max_size,
            current_size: AtomicU64::new(0),
        }
    }

    /// Registers an opened read-only buffer under `name`, accounting for its mapped size.
    pub fn register(
        &mut self,
        name: &str,
        buffer: SharedTensorBufferReadOnly,
    ) -> Result<(), SharedMemoryError> {
        let size = buffer.mmap.len();

        let current = self.current_size.load(Ordering::Relaxed);
        if current + size as u64 > self.max_size as u64 {
            return Err(SharedMemoryError::PoolFull);
        }

        self.current_size.fetch_add(size as u64, Ordering::Relaxed);
        self.buffers.insert(name.to_string(), Arc::new(buffer));

        Ok(())
    }

    /// Returns the buffer registered under `name`, if any.
    pub fn get(&self, name: &str) -> Option<Arc<SharedTensorBufferReadOnly>> {
        self.buffers.get(name).cloned()
    }

    /// Removes the buffer registered under `name`, returning it if it was present.
    pub fn remove(&mut self, name: &str) -> Option<Arc<SharedTensorBufferReadOnly>> {
        if let Some(buffer) = self.buffers.remove(name) {
            let size = buffer.mmap.len() as u64;
            self.current_size.fetch_sub(size, Ordering::Relaxed);
            Some(buffer)
        } else {
            None
        }
    }

    /// Lists the names of all registered buffers.
    pub fn list(&self) -> Vec<&str> {
        self.buffers.keys().map(|s| s.as_str()).collect()
    }

    /// Total mapped bytes currently registered in the pool.
    pub fn memory_usage(&self) -> usize {
        self.current_size.load(Ordering::Relaxed) as usize
    }

    /// Remaining capacity in bytes before the pool reports `PoolFull`.
    pub fn available(&self) -> usize {
        self.max_size.saturating_sub(self.memory_usage())
    }
}

#[derive(Debug)]
pub enum SharedMemoryError {
    /// Underlying filesystem or mmap I/O failure.
    Io(std::io::Error),
    /// The file is too small or its header is not recognized.
    InvalidFormat(String),
    /// The JSON tensor metadata could not be serialized or deserialized.
    Json(serde_json::Error),
    /// Registering the buffer would exceed the pool's size limit.
    PoolFull,
    /// No buffer is registered under the requested name.
    NotFound(String),
}

impl From<std::io::Error> for SharedMemoryError {
    fn from(err: std::io::Error) -> Self {
        SharedMemoryError::Io(err)
    }
}

impl From<serde_json::Error> for SharedMemoryError {
    fn from(err: serde_json::Error) -> Self {
        SharedMemoryError::Json(err)
    }
}

impl std::fmt::Display for SharedMemoryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SharedMemoryError::Io(e) => write!(f, "IO error: {}", e),
            SharedMemoryError::InvalidFormat(s) => write!(f, "Invalid format: {}", s),
            SharedMemoryError::Json(e) => write!(f, "JSON error: {}", e),
            SharedMemoryError::PoolFull => write!(f, "Shared memory pool is full"),
            SharedMemoryError::NotFound(s) => write!(f, "Buffer not found: {}", s),
        }
    }
}

impl std::error::Error for SharedMemoryError {}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_shared_buffer_create_and_read() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.shm");

        let tensors = vec![
            SharedTensorInfo {
                name: "weights".to_string(),
                dtype: TensorDtype::Float32,
                shape: vec![2, 3],
                offset: 0,
                size: 24, // 2 * 3 f32 elements
            },
            SharedTensorInfo {
                name: "bias".to_string(),
                dtype: TensorDtype::Float32,
                shape: vec![3],
                offset: 24,
                size: 12, // 3 f32 elements
            },
        ];

        let mut buffer = SharedTensorBuffer::create(&path, 36, &tensors).unwrap();

        let weights: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let bias: Vec<f32> = vec![0.1, 0.2, 0.3];

        buffer.write_tensor(&tensors[0], &weights);
        buffer.write_tensor(&tensors[1], &bias);
        buffer.update_checksum();
        buffer.flush().unwrap();

        let read_buffer = SharedTensorBuffer::open_readonly(&path).unwrap();
        let metadata = read_buffer.tensor_metadata().unwrap();

        assert_eq!(metadata.len(), 2);
        assert_eq!(metadata[0].name, "weights");
        assert_eq!(metadata[1].name, "bias");

        let read_weights: Vec<f32> = read_buffer.read_tensor(&metadata[0]);
        let read_bias: Vec<f32> = read_buffer.read_tensor(&metadata[1]);

        assert_eq!(read_weights, weights);
        assert_eq!(read_bias, bias);
    }

    #[test]
    fn test_memory_pool() {
        let dir = tempdir().unwrap();
        let pool_dir = dir.path().join("pool");

        let mut pool = SharedMemoryPool::new(&pool_dir, 1024 * 1024);

        let path = pool_dir.join("test1.shm");
        let tensors = vec![SharedTensorInfo {
            name: "test".to_string(),
            dtype: TensorDtype::Float32,
            shape: vec![4],
            offset: 0,
            size: 16,
        }];

        SharedTensorBuffer::create(&path, 16, &tensors).unwrap();

        let buffer = SharedTensorBuffer::open_readonly(&path).unwrap();
        pool.register("test1", buffer).unwrap();

        assert_eq!(pool.list().len(), 1);
        assert!(pool.get("test1").is_some());

        pool.remove("test1");
        assert!(pool.get("test1").is_none());
    }
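
    // Additional round-trip check for the checksum path. Added as a sketch on
    // top of the existing tests; it only exercises APIs defined in this module.
    #[test]
    fn test_checksum_roundtrip() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("checksum.shm");

        let tensors = vec![SharedTensorInfo {
            name: "values".to_string(),
            dtype: TensorDtype::Float32,
            shape: vec![4],
            offset: 0,
            size: 16,
        }];

        let values: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];

        let mut buffer = SharedTensorBuffer::create(&path, 16, &tensors).unwrap();
        buffer.write_tensor(&tensors[0], &values);
        buffer.update_checksum();
        buffer.flush().unwrap();

        // A fresh read-only mapping should see a checksum matching the data region.
        let reader = SharedTensorBuffer::open_readonly(&path).unwrap();
        assert!(reader.verify_checksum());
    }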
}