use crate::arrow::TensorDtype;
use memmap2::{Mmap, MmapMut};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
/// Writable, memory-mapped tensor buffer backed by a single file.
///
/// File layout (see `create`): `[SharedBufferHeader][u64 LE metadata len]`
/// `[JSON metadata][raw tensor data]`.
pub struct SharedTensorBuffer {
    // Mutable mapping covering the entire backing file.
    mmap: MmapMut,
    // In-memory copy of the header stored at file offset 0.
    header: SharedBufferHeader,
    // Filesystem path of the backing file.
    path: PathBuf,
}
/// Fixed-layout header stored at offset 0 of the shared buffer file.
///
/// `#[repr(C)]` field order is the on-disk ABI — readers `ptr::read` this
/// struct straight out of the mapping, so fields must NOT be reordered.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct SharedBufferHeader {
    // Format magic (see `SharedBufferHeader::MAGIC`).
    pub magic: u64,
    // Format version; only 1 is accepted by `validate`.
    pub version: u32,
    // Reserved flag bits; written as 0 and never interpreted in this file.
    pub flags: u32,
    // Total backing-file size in bytes (header + metadata + data).
    pub total_size: u64,
    // Byte offset at which raw tensor data begins.
    pub data_offset: u64,
    // Number of tensors described by the JSON metadata section.
    pub num_tensors: u32,
    // Wrapping byte-sum over the data region; 0 until `update_checksum` runs.
    pub checksum: u64,
    // Initialized to 1 at creation and never touched elsewhere in this file —
    // NOTE(review): presumably intended for cross-process refcounting; confirm.
    pub ref_count: u64,
}
impl SharedBufferHeader {
const MAGIC: u64 = 0x4950_4652_5354_454E;
pub fn new(total_size: u64, data_offset: u64, num_tensors: u32) -> Self {
Self {
magic: Self::MAGIC,
version: 1,
flags: 0,
total_size,
data_offset,
num_tensors,
checksum: 0,
ref_count: 1,
}
}
pub fn validate(&self) -> bool {
self.magic == Self::MAGIC && self.version == 1
}
}
/// Serializable descriptor for one tensor stored in a shared buffer.
///
/// A `Vec<SharedTensorInfo>` is stored as JSON in the metadata section of
/// the buffer file, right after the header and the u64 length prefix.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedTensorInfo {
    // Tensor name used for lookup by consumers.
    pub name: String,
    // Element dtype (project-defined enum from `crate::arrow`).
    pub dtype: TensorDtype,
    // Logical shape; not cross-validated against `size` in this file.
    pub shape: Vec<usize>,
    // Byte offset of this tensor relative to the start of the data region.
    pub offset: usize,
    // Size of this tensor's payload in bytes.
    pub size: usize,
}
impl SharedTensorBuffer {
    /// Width in bytes of the little-endian `u64` length prefix that precedes
    /// the JSON metadata blob.
    const METADATA_LEN_PREFIX: usize = 8;

    /// Creates and maps a new buffer file at `path`, truncating any existing
    /// file, with room for `size` bytes of tensor data described by `tensors`.
    ///
    /// Layout written: `[SharedBufferHeader][u64 LE metadata len]`
    /// `[JSON metadata][size bytes of data]`.
    ///
    /// # Errors
    /// Fails on I/O or mapping errors, or if `tensors` cannot be serialized.
    pub fn create<P: AsRef<Path>>(
        path: P,
        size: usize,
        tensors: &[SharedTensorInfo],
    ) -> Result<Self, SharedMemoryError> {
        let path = path.as_ref().to_path_buf();
        let metadata_json = serde_json::to_vec(tensors)?;
        let metadata_size = metadata_json.len();
        let header_size = std::mem::size_of::<SharedBufferHeader>();
        let metadata_offset = header_size;
        let data_offset = metadata_offset + Self::METADATA_LEN_PREFIX + metadata_size;
        let total_size = data_offset + size;
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(true)
            .open(&path)?;
        file.set_len(total_size as u64)?;
        // SAFETY: we own the freshly created file and just sized it to
        // `total_size`; no other mapping of it exists yet.
        let mut mmap = unsafe { MmapMut::map_mut(&file)? };
        let header =
            SharedBufferHeader::new(total_size as u64, data_offset as u64, tensors.len() as u32);
        // SAFETY: SharedBufferHeader is #[repr(C)] and Copy, so viewing it as
        // `header_size` raw bytes is valid for the duration of the copy below.
        // NOTE(review): struct padding bytes are copied as-is; readers only
        // interpret the named fields, never the padding.
        let header_bytes: &[u8] = unsafe {
            std::slice::from_raw_parts(
                &header as *const SharedBufferHeader as *const u8,
                header_size,
            )
        };
        mmap[..header_size].copy_from_slice(header_bytes);
        // Length-prefixed JSON metadata follows the header.
        let len_prefix = (metadata_size as u64).to_le_bytes();
        mmap[metadata_offset..metadata_offset + Self::METADATA_LEN_PREFIX]
            .copy_from_slice(&len_prefix);
        let json_start = metadata_offset + Self::METADATA_LEN_PREFIX;
        mmap[json_start..json_start + metadata_size].copy_from_slice(&metadata_json);
        mmap.flush()?;
        Ok(Self { mmap, header, path })
    }

    /// Opens an existing buffer file read-write and validates its header.
    ///
    /// # Errors
    /// Returns `InvalidFormat` if the file is smaller than a header, has a bad
    /// magic/version, or claims a `total_size` larger than the actual file
    /// (truncated file whose header still looks valid).
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, SharedMemoryError> {
        let path = path.as_ref().to_path_buf();
        let file = OpenOptions::new().read(true).write(true).open(&path)?;
        // SAFETY: mapping a file we just opened; callers must not truncate
        // the file while this mapping is alive.
        let mmap = unsafe { MmapMut::map_mut(&file)? };
        let header_size = std::mem::size_of::<SharedBufferHeader>();
        if mmap.len() < header_size {
            return Err(SharedMemoryError::InvalidFormat("File too small".into()));
        }
        // SAFETY: the mapping is page-aligned and at least `header_size`
        // bytes long, and SharedBufferHeader is #[repr(C)] + Copy.
        let header: SharedBufferHeader =
            unsafe { std::ptr::read(mmap.as_ptr() as *const SharedBufferHeader) };
        if !header.validate() {
            return Err(SharedMemoryError::InvalidFormat(
                "Invalid header magic or version".into(),
            ));
        }
        if header.total_size as usize > mmap.len() {
            return Err(SharedMemoryError::InvalidFormat(
                "Header total_size exceeds file size".into(),
            ));
        }
        Ok(Self { mmap, header, path })
    }

    /// Opens an existing buffer file as a read-only mapping and validates its
    /// header.
    ///
    /// # Errors
    /// Same conditions as [`SharedTensorBuffer::open`].
    pub fn open_readonly<P: AsRef<Path>>(
        path: P,
    ) -> Result<SharedTensorBufferReadOnly, SharedMemoryError> {
        let path = path.as_ref().to_path_buf();
        let file = File::open(&path)?;
        // SAFETY: read-only mapping of a file we just opened; callers must
        // not truncate the file while this mapping is alive.
        let mmap = unsafe { Mmap::map(&file)? };
        let header_size = std::mem::size_of::<SharedBufferHeader>();
        if mmap.len() < header_size {
            return Err(SharedMemoryError::InvalidFormat("File too small".into()));
        }
        // SAFETY: page-aligned mapping of at least `header_size` bytes;
        // SharedBufferHeader is #[repr(C)] + Copy.
        let header: SharedBufferHeader =
            unsafe { std::ptr::read(mmap.as_ptr() as *const SharedBufferHeader) };
        if !header.validate() {
            return Err(SharedMemoryError::InvalidFormat(
                "Invalid header magic or version".into(),
            ));
        }
        if header.total_size as usize > mmap.len() {
            return Err(SharedMemoryError::InvalidFormat(
                "Header total_size exceeds file size".into(),
            ));
        }
        Ok(SharedTensorBufferReadOnly { mmap, header, path })
    }

    /// Parses the JSON tensor descriptors stored after the header.
    ///
    /// # Errors
    /// Returns `InvalidFormat` when the stored metadata length runs past the
    /// end of the mapping (corrupt/truncated file — previously this panicked
    /// on the slice), or a JSON error when deserialization fails.
    pub fn tensor_metadata(&self) -> Result<Vec<SharedTensorInfo>, SharedMemoryError> {
        let header_size = std::mem::size_of::<SharedBufferHeader>();
        let prefix_end = header_size + Self::METADATA_LEN_PREFIX;
        if self.mmap.len() < prefix_end {
            return Err(SharedMemoryError::InvalidFormat(
                "File too small for metadata length prefix".into(),
            ));
        }
        let mut len_bytes = [0u8; 8];
        len_bytes.copy_from_slice(&self.mmap[header_size..prefix_end]);
        let metadata_len = u64::from_le_bytes(len_bytes) as usize;
        // Treat the on-disk length as untrusted: reject it instead of
        // panicking on an out-of-bounds slice.
        let metadata_end = prefix_end
            .checked_add(metadata_len)
            .filter(|&end| end <= self.mmap.len())
            .ok_or_else(|| {
                SharedMemoryError::InvalidFormat("Metadata length out of bounds".into())
            })?;
        let tensors: Vec<SharedTensorInfo> =
            serde_json::from_slice(&self.mmap[prefix_end..metadata_end])?;
        Ok(tensors)
    }

    /// Mutable view of one tensor's raw bytes inside the data region.
    ///
    /// # Panics
    /// Panics if `info.offset + info.size` runs past the end of the mapping.
    pub fn tensor_data_mut(&mut self, info: &SharedTensorInfo) -> &mut [u8] {
        let start = self.header.data_offset as usize + info.offset;
        let end = start + info.size;
        &mut self.mmap[start..end]
    }

    /// Shared view of one tensor's raw bytes inside the data region.
    ///
    /// # Panics
    /// Panics if `info.offset + info.size` runs past the end of the mapping.
    pub fn tensor_data(&self, info: &SharedTensorInfo) -> &[u8] {
        let start = self.header.data_offset as usize + info.offset;
        let end = start + info.size;
        &self.mmap[start..end]
    }

    /// Copies `data` into the region described by `info`.
    ///
    /// # Panics
    /// Panics (with an explicit message) if the byte length of `data` does
    /// not equal `info.size`; previously this surfaced as an opaque
    /// `copy_from_slice` length-mismatch panic.
    pub fn write_tensor<T: Copy>(&mut self, info: &SharedTensorInfo, data: &[T]) {
        let byte_len = std::mem::size_of_val(data);
        assert_eq!(
            byte_len, info.size,
            "write_tensor: {} bytes provided for tensor '{}' of {} bytes",
            byte_len, info.name, info.size
        );
        // SAFETY: `data` is a valid slice of Copy elements; reinterpreting it
        // as `byte_len` bytes is sound for the duration of the copy.
        let bytes =
            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, byte_len) };
        self.tensor_data_mut(info).copy_from_slice(bytes);
    }

    /// Copies the tensor described by `info` out into a freshly allocated
    /// `Vec<T>`. Trailing bytes that do not fill a whole element are ignored.
    pub fn read_tensor<T: Copy + Default>(&self, info: &SharedTensorInfo) -> Vec<T> {
        let bytes = self.tensor_data(info);
        let elem_size = std::mem::size_of::<T>();
        if elem_size == 0 {
            // Zero-sized types carry no payload; avoid dividing by zero.
            return Vec::new();
        }
        let count = bytes.len() / elem_size;
        let mut result = vec![T::default(); count];
        // SAFETY: `result` owns `count * elem_size` initialized bytes; we
        // view them mutably as raw bytes only for the copy below.
        let result_bytes = unsafe {
            std::slice::from_raw_parts_mut(result.as_mut_ptr() as *mut u8, count * elem_size)
        };
        result_bytes.copy_from_slice(&bytes[..count * elem_size]);
        result
    }

    /// Recomputes the wrapping byte-sum checksum over the data region and
    /// stores it both in the mapped header bytes and in `self.header`
    /// (previously the in-memory copy was left stale).
    pub fn update_checksum(&mut self) {
        let data_start = self.header.data_offset as usize;
        let checksum: u64 = self.mmap[data_start..]
            .iter()
            .fold(0u64, |acc, &b| acc.wrapping_add(b as u64));
        self.header.checksum = checksum;
        let offset = std::mem::offset_of!(SharedBufferHeader, checksum);
        self.mmap[offset..offset + 8].copy_from_slice(&checksum.to_le_bytes());
    }

    /// Flushes the mapping's dirty pages back to the backing file.
    pub fn flush(&self) -> Result<(), SharedMemoryError> {
        self.mmap.flush()?;
        Ok(())
    }

    /// Path of the backing file.
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// Total buffer size in bytes, as recorded in the header.
    pub fn size(&self) -> usize {
        self.header.total_size as usize
    }
}
/// Read-only view of a shared tensor buffer file, produced by
/// `SharedTensorBuffer::open_readonly`.
pub struct SharedTensorBufferReadOnly {
    // Immutable mapping covering the entire backing file.
    mmap: Mmap,
    // In-memory copy of the header stored at file offset 0.
    header: SharedBufferHeader,
    // Filesystem path of the backing file.
    path: PathBuf,
}
impl SharedTensorBufferReadOnly {
pub fn tensor_metadata(&self) -> Result<Vec<SharedTensorInfo>, SharedMemoryError> {
let header_size = std::mem::size_of::<SharedBufferHeader>();
let mut len_bytes = [0u8; 8];
len_bytes.copy_from_slice(&self.mmap[header_size..header_size + 8]);
let metadata_len = u64::from_le_bytes(len_bytes) as usize;
let metadata_bytes = &self.mmap[header_size + 8..header_size + 8 + metadata_len];
let tensors: Vec<SharedTensorInfo> = serde_json::from_slice(metadata_bytes)?;
Ok(tensors)
}
pub fn tensor_data(&self, info: &SharedTensorInfo) -> &[u8] {
let start = self.header.data_offset as usize + info.offset;
let end = start + info.size;
&self.mmap[start..end]
}
pub fn read_tensor<T: Copy + Default>(&self, info: &SharedTensorInfo) -> Vec<T> {
let bytes = self.tensor_data(info);
let elem_size = std::mem::size_of::<T>();
let count = bytes.len() / elem_size;
let mut result = vec![T::default(); count];
let result_bytes = unsafe {
std::slice::from_raw_parts_mut(result.as_mut_ptr() as *mut u8, count * elem_size)
};
result_bytes.copy_from_slice(&bytes[..count * elem_size]);
result
}
pub fn verify_checksum(&self) -> bool {
let data_start = self.header.data_offset as usize;
let data = &self.mmap[data_start..];
let computed: u64 = data.iter().fold(0u64, |acc, &b| acc.wrapping_add(b as u64));
computed == self.header.checksum
}
pub fn path(&self) -> &Path {
&self.path
}
}
/// Registry of read-only shared buffers with a soft memory-usage cap.
#[allow(dead_code)]
pub struct SharedMemoryPool {
    // Directory intended to hold buffer files; created (best-effort) in `new`.
    // NOTE(review): not otherwise used by the methods visible in this file.
    base_dir: PathBuf,
    // Registered buffers keyed by caller-supplied name.
    buffers: HashMap<String, Arc<SharedTensorBufferReadOnly>>,
    // Maximum total mapped bytes the pool will accept.
    max_size: usize,
    // Sum of mapped sizes of currently registered buffers, in bytes.
    current_size: AtomicU64,
}
impl SharedMemoryPool {
    /// Creates a pool rooted at `base_dir` with a `max_size`-byte capacity.
    /// Directory creation is best-effort: failure is deliberately ignored so
    /// construction itself cannot fail (registration will surface I/O issues).
    pub fn new<P: AsRef<Path>>(base_dir: P, max_size: usize) -> Self {
        std::fs::create_dir_all(base_dir.as_ref()).ok();
        Self {
            base_dir: base_dir.as_ref().to_path_buf(),
            buffers: HashMap::new(),
            max_size,
            current_size: AtomicU64::new(0),
        }
    }

    /// Registers `buffer` under `name`, replacing any buffer previously
    /// registered under the same name.
    ///
    /// # Errors
    /// Returns `PoolFull` when accepting the buffer would exceed `max_size`;
    /// in that case the existing registration (if any) is left untouched.
    pub fn register(
        &mut self,
        name: &str,
        buffer: SharedTensorBufferReadOnly,
    ) -> Result<(), SharedMemoryError> {
        let size = buffer.mmap.len() as u64;
        // Account for the entry this registration would replace; previously a
        // re-register under the same name leaked the old buffer's accounting,
        // permanently shrinking the pool's effective capacity.
        let replaced_size = self
            .buffers
            .get(name)
            .map(|old| old.mmap.len() as u64)
            .unwrap_or(0);
        let current = self.current_size.load(Ordering::Relaxed);
        if current.saturating_sub(replaced_size) + size > self.max_size as u64 {
            return Err(SharedMemoryError::PoolFull);
        }
        if replaced_size != 0 {
            self.current_size.fetch_sub(replaced_size, Ordering::Relaxed);
        }
        self.current_size.fetch_add(size, Ordering::Relaxed);
        self.buffers.insert(name.to_string(), Arc::new(buffer));
        Ok(())
    }

    /// Returns a handle to the buffer registered under `name`, if any.
    pub fn get(&self, name: &str) -> Option<Arc<SharedTensorBufferReadOnly>> {
        self.buffers.get(name).cloned()
    }

    /// Unregisters and returns the buffer under `name`, releasing its bytes
    /// from the pool's accounting.
    pub fn remove(&mut self, name: &str) -> Option<Arc<SharedTensorBufferReadOnly>> {
        if let Some(buffer) = self.buffers.remove(name) {
            let size = buffer.mmap.len() as u64;
            self.current_size.fetch_sub(size, Ordering::Relaxed);
            Some(buffer)
        } else {
            None
        }
    }

    /// Names of all registered buffers (unordered).
    pub fn list(&self) -> Vec<&str> {
        self.buffers.keys().map(|s| s.as_str()).collect()
    }

    /// Total mapped bytes currently accounted to the pool.
    pub fn memory_usage(&self) -> usize {
        self.current_size.load(Ordering::Relaxed) as usize
    }

    /// Remaining capacity in bytes before the pool reports `PoolFull`.
    pub fn available(&self) -> usize {
        self.max_size.saturating_sub(self.memory_usage())
    }
}
/// Errors produced by shared-buffer creation, opening, and pool operations.
#[derive(Debug)]
pub enum SharedMemoryError {
    // Underlying filesystem or mmap failure.
    Io(std::io::Error),
    // Buffer file failed structural validation (size, magic, version, bounds).
    InvalidFormat(String),
    // Tensor metadata failed to (de)serialize.
    Json(serde_json::Error),
    // Registering a buffer would exceed the pool's `max_size`.
    PoolFull,
    // Lookup by name failed.
    // NOTE(review): not constructed anywhere in this file; confirm callers.
    NotFound(String),
}
impl From<std::io::Error> for SharedMemoryError {
fn from(err: std::io::Error) -> Self {
SharedMemoryError::Io(err)
}
}
impl From<serde_json::Error> for SharedMemoryError {
    /// Wraps a serde_json error so `?` works directly on metadata (de)serialization.
    fn from(err: serde_json::Error) -> Self {
        Self::Json(err)
    }
}
impl std::fmt::Display for SharedMemoryError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SharedMemoryError::Io(e) => write!(f, "IO error: {}", e),
SharedMemoryError::InvalidFormat(s) => write!(f, "Invalid format: {}", s),
SharedMemoryError::Json(e) => write!(f, "JSON error: {}", e),
SharedMemoryError::PoolFull => write!(f, "Shared memory pool is full"),
SharedMemoryError::NotFound(s) => write!(f, "Buffer not found: {}", s),
}
}
}
impl std::error::Error for SharedMemoryError {}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a descriptor for an f32 tensor with the given layout.
    fn f32_info(name: &str, shape: Vec<usize>, offset: usize, size: usize) -> SharedTensorInfo {
        SharedTensorInfo {
            name: name.to_string(),
            dtype: TensorDtype::Float32,
            shape,
            offset,
            size,
        }
    }

    #[test]
    fn test_shared_buffer_create_and_read() {
        let dir = tempdir().unwrap();
        let shm_path = dir.path().join("test.shm");
        let infos = vec![
            f32_info("weights", vec![2, 3], 0, 24),
            f32_info("bias", vec![3], 24, 12),
        ];
        let weights: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let bias: Vec<f32> = vec![0.1, 0.2, 0.3];

        // Write both tensors, checksum, and persist.
        let mut writer = SharedTensorBuffer::create(&shm_path, 36, &infos).unwrap();
        writer.write_tensor(&infos[0], &weights);
        writer.write_tensor(&infos[1], &bias);
        writer.update_checksum();
        writer.flush().unwrap();

        // Reopen read-only and verify metadata and payloads round-trip.
        let reader = SharedTensorBuffer::open_readonly(&shm_path).unwrap();
        let meta = reader.tensor_metadata().unwrap();
        assert_eq!(meta.len(), 2);
        assert_eq!(meta[0].name, "weights");
        assert_eq!(meta[1].name, "bias");
        assert_eq!(reader.read_tensor::<f32>(&meta[0]), weights);
        assert_eq!(reader.read_tensor::<f32>(&meta[1]), bias);
    }

    #[test]
    fn test_memory_pool() {
        let dir = tempdir().unwrap();
        let pool_dir = dir.path().join("pool");
        let mut pool = SharedMemoryPool::new(&pool_dir, 1024 * 1024);
        let shm_path = pool_dir.join("test1.shm");
        let infos = vec![f32_info("test", vec![4], 0, 16)];
        SharedTensorBuffer::create(&shm_path, 16, &infos).unwrap();
        let handle = SharedTensorBuffer::open_readonly(&shm_path).unwrap();

        // Register, look up, then remove; pool should reflect each step.
        pool.register("test1", handle).unwrap();
        assert_eq!(pool.list().len(), 1);
        assert!(pool.get("test1").is_some());
        pool.remove("test1");
        assert!(pool.get("test1").is_none());
    }
}