ipfrs_tensorlogic/
shared_memory.rs

1//! Shared memory support for zero-copy IPC
2//!
3//! Provides mmap-based buffer sharing for:
4//! - Cross-process tensor sharing
5//! - Zero-copy IPC between processes
6//! - Memory-efficient model serving
7
8use crate::arrow::TensorDtype;
9use memmap2::{Mmap, MmapMut};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fs::{File, OpenOptions};
13use std::path::{Path, PathBuf};
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::sync::Arc;
16
/// Shared memory region for tensor data.
///
/// Backed by a memory-mapped file so other processes can map the same
/// file for zero-copy access. On-disk layout (written by `create`):
/// `[SharedBufferHeader][u64 metadata length (LE)][JSON metadata][tensor data]`.
pub struct SharedTensorBuffer {
    /// Writable memory-mapped view of the backing file.
    mmap: MmapMut,
    /// In-memory copy of the header stored at the start of the mapping.
    header: SharedBufferHeader,
    /// Path to the backing file.
    path: PathBuf,
}
26
/// Header stored at the beginning of shared memory.
///
/// `#[repr(C)]` pins the field layout because the header is written to
/// and read from the mapping as raw bytes (native endianness).
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct SharedBufferHeader {
    /// Magic number for validation (see [`Self::MAGIC`]).
    pub magic: u64,
    /// Format version number (currently always 1).
    pub version: u32,
    /// Reserved flag bits (currently always 0).
    pub flags: u32,
    /// Total size of the buffer in bytes (header + metadata + data).
    pub total_size: u64,
    /// Byte offset where tensor data begins (after header and metadata).
    pub data_offset: u64,
    /// Number of tensors described by the metadata section.
    pub num_tensors: u32,
    /// Checksum of the data section (wrapping byte sum).
    pub checksum: u64,
    /// Reference count (for multi-process access).
    pub ref_count: u64,
}

impl SharedBufferHeader {
    // The 8 bytes 0x49 0x50 0x46 0x52 0x53 0x54 0x45 0x4E spell "IPFRSTEN"
    // (the previous comment said "IPFRSTN", which is only 7 characters).
    const MAGIC: u64 = 0x4950_4652_5354_454E; // "IPFRSTEN"

    /// Create a new version-1 header with a zeroed checksum and an
    /// initial reference count of 1.
    pub fn new(total_size: u64, data_offset: u64, num_tensors: u32) -> Self {
        Self {
            magic: Self::MAGIC,
            version: 1,
            flags: 0,
            total_size,
            data_offset,
            num_tensors,
            checksum: 0,
            ref_count: 1,
        }
    }

    /// Validate the header: the magic must match and the version must be
    /// the single supported version (1).
    pub fn validate(&self) -> bool {
        self.magic == Self::MAGIC && self.version == 1
    }
}
71
/// Metadata for tensors in shared memory.
///
/// Serialized as JSON into the metadata section of the buffer, directly
/// after the header and the 8-byte metadata length.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedTensorInfo {
    /// Tensor name.
    pub name: String,
    /// Element data type.
    pub dtype: TensorDtype,
    /// Shape (dimension sizes).
    pub shape: Vec<usize>,
    /// Byte offset from the start of the data section.
    pub offset: usize,
    /// Size in bytes.
    pub size: usize,
}
86
87impl SharedTensorBuffer {
88    /// Create a new shared tensor buffer
89    pub fn create<P: AsRef<Path>>(
90        path: P,
91        size: usize,
92        tensors: &[SharedTensorInfo],
93    ) -> Result<Self, SharedMemoryError> {
94        let path = path.as_ref().to_path_buf();
95
96        // Serialize tensor metadata
97        let metadata_json = serde_json::to_vec(tensors)?;
98        let metadata_size = metadata_json.len();
99
100        // Calculate total size needed
101        let header_size = std::mem::size_of::<SharedBufferHeader>();
102        let metadata_offset = header_size;
103        let data_offset = metadata_offset + metadata_size + 8; // 8 bytes for metadata length
104        let total_size = data_offset + size;
105
106        // Create and size the file
107        let file = OpenOptions::new()
108            .read(true)
109            .write(true)
110            .create(true)
111            .truncate(true)
112            .open(&path)?;
113
114        file.set_len(total_size as u64)?;
115
116        // Memory map the file
117        let mut mmap = unsafe { MmapMut::map_mut(&file)? };
118
119        // Write header
120        let header =
121            SharedBufferHeader::new(total_size as u64, data_offset as u64, tensors.len() as u32);
122        let header_bytes: &[u8] = unsafe {
123            std::slice::from_raw_parts(
124                &header as *const SharedBufferHeader as *const u8,
125                std::mem::size_of::<SharedBufferHeader>(),
126            )
127        };
128        mmap[..header_size].copy_from_slice(header_bytes);
129
130        // Write metadata length and metadata
131        let metadata_len_bytes = (metadata_size as u64).to_le_bytes();
132        mmap[metadata_offset..metadata_offset + 8].copy_from_slice(&metadata_len_bytes);
133        mmap[metadata_offset + 8..metadata_offset + 8 + metadata_size]
134            .copy_from_slice(&metadata_json);
135
136        mmap.flush()?;
137
138        Ok(Self { mmap, header, path })
139    }
140
141    /// Open an existing shared tensor buffer
142    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, SharedMemoryError> {
143        let path = path.as_ref().to_path_buf();
144
145        let file = OpenOptions::new().read(true).write(true).open(&path)?;
146
147        let mmap = unsafe { MmapMut::map_mut(&file)? };
148
149        // Read and validate header
150        let header_size = std::mem::size_of::<SharedBufferHeader>();
151        if mmap.len() < header_size {
152            return Err(SharedMemoryError::InvalidFormat("File too small".into()));
153        }
154
155        let header: SharedBufferHeader =
156            unsafe { std::ptr::read(mmap.as_ptr() as *const SharedBufferHeader) };
157
158        if !header.validate() {
159            return Err(SharedMemoryError::InvalidFormat(
160                "Invalid header magic or version".into(),
161            ));
162        }
163
164        Ok(Self { mmap, header, path })
165    }
166
167    /// Open read-only
168    pub fn open_readonly<P: AsRef<Path>>(
169        path: P,
170    ) -> Result<SharedTensorBufferReadOnly, SharedMemoryError> {
171        let path = path.as_ref().to_path_buf();
172        let file = File::open(&path)?;
173        let mmap = unsafe { Mmap::map(&file)? };
174
175        let header_size = std::mem::size_of::<SharedBufferHeader>();
176        if mmap.len() < header_size {
177            return Err(SharedMemoryError::InvalidFormat("File too small".into()));
178        }
179
180        let header: SharedBufferHeader =
181            unsafe { std::ptr::read(mmap.as_ptr() as *const SharedBufferHeader) };
182
183        if !header.validate() {
184            return Err(SharedMemoryError::InvalidFormat(
185                "Invalid header magic or version".into(),
186            ));
187        }
188
189        Ok(SharedTensorBufferReadOnly { mmap, header, path })
190    }
191
192    /// Get tensor metadata
193    pub fn tensor_metadata(&self) -> Result<Vec<SharedTensorInfo>, SharedMemoryError> {
194        let header_size = std::mem::size_of::<SharedBufferHeader>();
195
196        // Read metadata length
197        let mut len_bytes = [0u8; 8];
198        len_bytes.copy_from_slice(&self.mmap[header_size..header_size + 8]);
199        let metadata_len = u64::from_le_bytes(len_bytes) as usize;
200
201        // Read metadata
202        let metadata_bytes = &self.mmap[header_size + 8..header_size + 8 + metadata_len];
203        let tensors: Vec<SharedTensorInfo> = serde_json::from_slice(metadata_bytes)?;
204
205        Ok(tensors)
206    }
207
208    /// Get mutable data slice for a tensor
209    pub fn tensor_data_mut(&mut self, info: &SharedTensorInfo) -> &mut [u8] {
210        let start = self.header.data_offset as usize + info.offset;
211        let end = start + info.size;
212        &mut self.mmap[start..end]
213    }
214
215    /// Get data slice for a tensor
216    pub fn tensor_data(&self, info: &SharedTensorInfo) -> &[u8] {
217        let start = self.header.data_offset as usize + info.offset;
218        let end = start + info.size;
219        &self.mmap[start..end]
220    }
221
222    /// Write tensor data
223    pub fn write_tensor<T: Copy>(&mut self, info: &SharedTensorInfo, data: &[T]) {
224        let bytes = unsafe {
225            std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data))
226        };
227        self.tensor_data_mut(info).copy_from_slice(bytes);
228    }
229
230    /// Read tensor data as typed Vec (safe copy)
231    pub fn read_tensor<T: Copy + Default>(&self, info: &SharedTensorInfo) -> Vec<T> {
232        let bytes = self.tensor_data(info);
233        let elem_size = std::mem::size_of::<T>();
234        let count = bytes.len() / elem_size;
235        let mut result = vec![T::default(); count];
236
237        // Safe copy using byte manipulation
238        let result_bytes = unsafe {
239            std::slice::from_raw_parts_mut(result.as_mut_ptr() as *mut u8, count * elem_size)
240        };
241        result_bytes.copy_from_slice(&bytes[..count * elem_size]);
242        result
243    }
244
245    /// Update checksum
246    pub fn update_checksum(&mut self) {
247        let data_start = self.header.data_offset as usize;
248        let data = &self.mmap[data_start..];
249
250        // Simple checksum (could use CRC32 or Blake3)
251        let checksum: u64 = data.iter().fold(0u64, |acc, &b| acc.wrapping_add(b as u64));
252
253        // Update header
254        let header_bytes = &mut self.mmap[..std::mem::size_of::<SharedBufferHeader>()];
255        let offset = std::mem::offset_of!(SharedBufferHeader, checksum);
256        header_bytes[offset..offset + 8].copy_from_slice(&checksum.to_le_bytes());
257    }
258
259    /// Flush changes to disk
260    pub fn flush(&self) -> Result<(), SharedMemoryError> {
261        self.mmap.flush()?;
262        Ok(())
263    }
264
265    /// Get the path to the backing file
266    pub fn path(&self) -> &Path {
267        &self.path
268    }
269
270    /// Get total size
271    pub fn size(&self) -> usize {
272        self.header.total_size as usize
273    }
274}
275
/// Read-only shared tensor buffer.
///
/// Immutable counterpart of `SharedTensorBuffer`, obtained via
/// `SharedTensorBuffer::open_readonly`.
pub struct SharedTensorBufferReadOnly {
    /// Read-only memory-mapped view of the backing file.
    mmap: Mmap,
    /// In-memory copy of the header stored at the start of the mapping.
    header: SharedBufferHeader,
    /// Path to the backing file.
    path: PathBuf,
}
285
impl SharedTensorBufferReadOnly {
    /// Deserialize the tensor metadata stored after the header.
    ///
    /// Layout: a little-endian u64 length directly after the header,
    /// followed by JSON-encoded `Vec<SharedTensorInfo>`.
    pub fn tensor_metadata(&self) -> Result<Vec<SharedTensorInfo>, SharedMemoryError> {
        let header_size = std::mem::size_of::<SharedBufferHeader>();

        // Read the metadata length (little-endian u64 after the header).
        let mut len_bytes = [0u8; 8];
        len_bytes.copy_from_slice(&self.mmap[header_size..header_size + 8]);
        let metadata_len = u64::from_le_bytes(len_bytes) as usize;

        // NOTE(review): a corrupt stored length makes this slice panic
        // rather than return an error — consider a bounds check.
        let metadata_bytes = &self.mmap[header_size + 8..header_size + 8 + metadata_len];
        let tensors: Vec<SharedTensorInfo> = serde_json::from_slice(metadata_bytes)?;

        Ok(tensors)
    }

    /// Get a byte slice covering one tensor's data.
    ///
    /// Panics if `info` describes a range outside the mapping.
    pub fn tensor_data(&self, info: &SharedTensorInfo) -> &[u8] {
        let start = self.header.data_offset as usize + info.offset;
        let end = start + info.size;
        &self.mmap[start..end]
    }

    /// Read tensor data as a typed `Vec` (safe copy out of the mapping).
    ///
    /// Trailing bytes that do not fill a whole element are ignored.
    pub fn read_tensor<T: Copy + Default>(&self, info: &SharedTensorInfo) -> Vec<T> {
        let bytes = self.tensor_data(info);
        let elem_size = std::mem::size_of::<T>();
        let count = bytes.len() / elem_size;
        let mut result = vec![T::default(); count];

        // SAFETY: `result` is a freshly allocated, fully initialized Vec of
        // `count` elements, so the raw byte view covers exactly its length.
        // Reinterpreting arbitrary bytes as T assumes T is plain-old-data
        // with no invalid bit patterns (true for the numeric types used).
        let result_bytes = unsafe {
            std::slice::from_raw_parts_mut(result.as_mut_ptr() as *mut u8, count * elem_size)
        };
        result_bytes.copy_from_slice(&bytes[..count * elem_size]);
        result
    }

    /// Recompute the wrapping byte-sum over the data section and compare it
    /// against the checksum stored in the header.
    pub fn verify_checksum(&self) -> bool {
        let data_start = self.header.data_offset as usize;
        let data = &self.mmap[data_start..];

        let computed: u64 = data.iter().fold(0u64, |acc, &b| acc.wrapping_add(b as u64));
        computed == self.header.checksum
    }

    /// Get the path to the backing file.
    pub fn path(&self) -> &Path {
        &self.path
    }
}
337
/// Shared memory pool for managing multiple buffers under a byte budget.
#[allow(dead_code)]
pub struct SharedMemoryPool {
    /// Base directory for shared memory files.
    base_dir: PathBuf,
    /// Active buffers, keyed by registration name.
    buffers: HashMap<String, Arc<SharedTensorBufferReadOnly>>,
    /// Maximum total mapped bytes across all registered buffers.
    max_size: usize,
    /// Current total mapped bytes.
    // NOTE(review): all mutating methods take &mut self, so a plain u64
    // would suffice; the atomic only matters if shared access is added.
    current_size: AtomicU64,
}
350
351impl SharedMemoryPool {
352    /// Create a new pool
353    pub fn new<P: AsRef<Path>>(base_dir: P, max_size: usize) -> Self {
354        std::fs::create_dir_all(base_dir.as_ref()).ok();
355
356        Self {
357            base_dir: base_dir.as_ref().to_path_buf(),
358            buffers: HashMap::new(),
359            max_size,
360            current_size: AtomicU64::new(0),
361        }
362    }
363
364    /// Register a buffer in the pool
365    pub fn register(
366        &mut self,
367        name: &str,
368        buffer: SharedTensorBufferReadOnly,
369    ) -> Result<(), SharedMemoryError> {
370        let size = buffer.mmap.len();
371
372        // Check size limit
373        let current = self.current_size.load(Ordering::Relaxed);
374        if current + size as u64 > self.max_size as u64 {
375            return Err(SharedMemoryError::PoolFull);
376        }
377
378        self.current_size.fetch_add(size as u64, Ordering::Relaxed);
379        self.buffers.insert(name.to_string(), Arc::new(buffer));
380
381        Ok(())
382    }
383
384    /// Get a buffer by name
385    pub fn get(&self, name: &str) -> Option<Arc<SharedTensorBufferReadOnly>> {
386        self.buffers.get(name).cloned()
387    }
388
389    /// Remove a buffer
390    pub fn remove(&mut self, name: &str) -> Option<Arc<SharedTensorBufferReadOnly>> {
391        if let Some(buffer) = self.buffers.remove(name) {
392            let size = buffer.mmap.len() as u64;
393            self.current_size.fetch_sub(size, Ordering::Relaxed);
394            Some(buffer)
395        } else {
396            None
397        }
398    }
399
400    /// List all buffer names
401    pub fn list(&self) -> Vec<&str> {
402        self.buffers.keys().map(|s| s.as_str()).collect()
403    }
404
405    /// Get current memory usage
406    pub fn memory_usage(&self) -> usize {
407        self.current_size.load(Ordering::Relaxed) as usize
408    }
409
410    /// Get available memory
411    pub fn available(&self) -> usize {
412        self.max_size.saturating_sub(self.memory_usage())
413    }
414}
415
/// Error types for shared memory operations.
#[derive(Debug)]
pub enum SharedMemoryError {
    /// Underlying IO error (file create/open/map/flush).
    Io(std::io::Error),
    /// The file does not look like a valid shared tensor buffer
    /// (bad magic/version, truncated, or inconsistent sizes).
    InvalidFormat(String),
    /// JSON (de)serialization of the tensor metadata failed.
    Json(serde_json::Error),
    /// Registering a buffer would exceed the pool's size budget.
    PoolFull,
    /// Buffer not found in the pool.
    NotFound(String),
}
430
431impl From<std::io::Error> for SharedMemoryError {
432    fn from(err: std::io::Error) -> Self {
433        SharedMemoryError::Io(err)
434    }
435}
436
437impl From<serde_json::Error> for SharedMemoryError {
438    fn from(err: serde_json::Error) -> Self {
439        SharedMemoryError::Json(err)
440    }
441}
442
443impl std::fmt::Display for SharedMemoryError {
444    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
445        match self {
446            SharedMemoryError::Io(e) => write!(f, "IO error: {}", e),
447            SharedMemoryError::InvalidFormat(s) => write!(f, "Invalid format: {}", s),
448            SharedMemoryError::Json(e) => write!(f, "JSON error: {}", e),
449            SharedMemoryError::PoolFull => write!(f, "Shared memory pool is full"),
450            SharedMemoryError::NotFound(s) => write!(f, "Buffer not found: {}", s),
451        }
452    }
453}
454
455impl std::error::Error for SharedMemoryError {}
456
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_shared_buffer_create_and_read() {
        let tmp = tempdir().unwrap();
        let shm_path = tmp.path().join("test.shm");

        // Two f32 tensors laid out back-to-back in a 36-byte data section.
        let infos = vec![
            SharedTensorInfo {
                name: "weights".to_string(),
                dtype: TensorDtype::Float32,
                shape: vec![2, 3],
                offset: 0,
                size: 24, // 6 * 4 bytes
            },
            SharedTensorInfo {
                name: "bias".to_string(),
                dtype: TensorDtype::Float32,
                shape: vec![3],
                offset: 24,
                size: 12, // 3 * 4 bytes
            },
        ];

        // Create the buffer and fill both tensors.
        let mut writer = SharedTensorBuffer::create(&shm_path, 36, &infos).unwrap();

        let weights: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let bias: Vec<f32> = vec![0.1, 0.2, 0.3];

        writer.write_tensor(&infos[0], &weights);
        writer.write_tensor(&infos[1], &bias);
        writer.update_checksum();
        writer.flush().unwrap();

        // Re-open read-only and verify the metadata and data round-trip.
        let reader = SharedTensorBuffer::open_readonly(&shm_path).unwrap();
        let metadata = reader.tensor_metadata().unwrap();

        assert_eq!(metadata.len(), 2);
        assert_eq!(metadata[0].name, "weights");
        assert_eq!(metadata[1].name, "bias");

        assert_eq!(reader.read_tensor::<f32>(&metadata[0]), weights);
        assert_eq!(reader.read_tensor::<f32>(&metadata[1]), bias);
    }

    #[test]
    fn test_memory_pool() {
        let tmp = tempdir().unwrap();
        let pool_dir = tmp.path().join("pool");

        let mut pool = SharedMemoryPool::new(&pool_dir, 1024 * 1024);

        // Write a small single-tensor buffer to disk.
        let shm_path = pool_dir.join("test1.shm");
        let infos = vec![SharedTensorInfo {
            name: "test".to_string(),
            dtype: TensorDtype::Float32,
            shape: vec![4],
            offset: 0,
            size: 16,
        }];

        SharedTensorBuffer::create(&shm_path, 16, &infos).unwrap();

        // Register a read-only view in the pool, then look it up.
        let view = SharedTensorBuffer::open_readonly(&shm_path).unwrap();
        pool.register("test1", view).unwrap();

        assert_eq!(pool.list().len(), 1);
        assert!(pool.get("test1").is_some());

        // Removing drops it from the pool.
        pool.remove("test1");
        assert!(pool.get("test1").is_none());
    }
}