use parking_lot::Mutex;
use std::{mem, slice};
use torsh_core::error::{Result, TorshError};
/// Tensor that points at existing memory instead of copying it.
///
/// The buffer is either borrowed (`from_slice`, `from_raw_parts`,
/// `slice_view`) or owned (`from_vec`); the `owned` flag tells `Drop`
/// whether it must free the allocation.
pub struct ZeroCopyTensor<T> {
    // Start of the element buffer. For borrowed views, keeping the backing
    // memory alive is the caller's responsibility (no lifetime tie).
    data_ptr: *const T,
    // Extent of each dimension.
    shape: Vec<usize>,
    // Per-dimension strides, in elements (not bytes).
    stride: Vec<usize>,
    // Total number of elements (product of `shape`); also used as the
    // length/capacity when reconstructing the owned allocation in `Drop`.
    capacity: usize,
    // True only when constructed via `from_vec`; `Drop` then frees the buffer.
    owned: bool,
}
impl<T> ZeroCopyTensor<T> {
    /// Builds a borrowed tensor view over raw memory without copying.
    ///
    /// The given `stride` is stored as-is and is not validated against
    /// `shape`.
    ///
    /// # Safety
    ///
    /// `data_ptr` must be valid for reads of `shape.iter().product()`
    /// elements of `T` for the entire lifetime of the returned tensor, and
    /// the memory must not be mutated or freed while the tensor is alive.
    pub unsafe fn from_raw_parts(
        data_ptr: *const T,
        shape: Vec<usize>,
        stride: Vec<usize>,
    ) -> Self {
        let capacity = shape.iter().product();
        Self {
            data_ptr,
            shape,
            stride,
            capacity,
            owned: false,
        }
    }

    /// Borrowing view over `data` with row-major strides (no copy).
    ///
    /// NOTE(review): the result is not tied to `data` by a lifetime, so the
    /// caller must keep `data` alive for as long as the view is used —
    /// consider adding a lifetime parameter to make this compiler-checked.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from the product of `shape`.
    pub fn from_slice(data: &[T], shape: Vec<usize>) -> Self {
        let capacity = shape.iter().product();
        assert_eq!(
            data.len(),
            capacity,
            "Data length must match tensor capacity"
        );
        let stride = Self::compute_stride(&shape);
        Self {
            data_ptr: data.as_ptr(),
            shape,
            stride,
            capacity,
            owned: false,
        }
    }

    /// Takes ownership of `data` without copying; the buffer is released
    /// when the tensor is dropped.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from the product of `shape`.
    pub fn from_vec(data: Vec<T>, shape: Vec<usize>) -> Self {
        let capacity = shape.iter().product();
        assert_eq!(
            data.len(),
            capacity,
            "Data length must match tensor capacity"
        );
        let stride = Self::compute_stride(&shape);
        // Convert to a boxed slice first so the allocation's capacity is
        // exactly `capacity`. `Drop` rebuilds the allocation with
        // `Vec::from_raw_parts(ptr, capacity, capacity)`, which is undefined
        // behavior unless `capacity` matches the allocation's true capacity —
        // a plain `Vec` may carry spare capacity beyond its length.
        let boxed = data.into_boxed_slice();
        let data_ptr = boxed.as_ptr();
        mem::forget(boxed);
        Self {
            data_ptr,
            shape,
            stride,
            capacity,
            owned: true,
        }
    }

    /// Extent of each dimension.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Per-dimension strides, in elements.
    pub fn stride(&self) -> &[usize] {
        &self.stride
    }

    /// Total number of elements (product of the shape).
    pub fn len(&self) -> usize {
        self.capacity
    }

    /// True when the tensor holds no elements.
    pub fn is_empty(&self) -> bool {
        self.capacity == 0
    }

    /// View of `len()` consecutive elements starting at the base pointer.
    ///
    /// NOTE(review): for non-contiguous views produced by `slice_view`,
    /// this yields `len()` consecutive elements of the *parent* buffer, not
    /// the logical elements of the view — only rely on it for contiguous
    /// tensors.
    pub fn as_slice(&self) -> &[T] {
        // SAFETY: every constructor sets `capacity` to the element count the
        // pointer is valid for (asserted against the input length, or part
        // of the `from_raw_parts` caller contract).
        unsafe { slice::from_raw_parts(self.data_ptr, self.capacity) }
    }

    /// Row-major (C-order) strides for `shape`: the last dimension is
    /// contiguous and each earlier stride is the product of the later
    /// extents. An empty shape yields an empty stride vector.
    fn compute_stride(shape: &[usize]) -> Vec<usize> {
        let mut stride = vec![1; shape.len()];
        for i in (0..shape.len().saturating_sub(1)).rev() {
            stride[i] = stride[i + 1] * shape[i + 1];
        }
        stride
    }

    /// Borrowed sub-view selecting the half-open range `[start, end)` in
    /// each dimension.
    ///
    /// The parent's strides are kept, so the result may be non-contiguous
    /// (see the caveat on `as_slice`). NOTE(review): like `from_slice`, the
    /// returned view is not lifetime-bound to `self`; it must not outlive
    /// the tensor it was sliced from.
    ///
    /// # Errors
    ///
    /// Returns `InvalidArgument` if the number of ranges differs from the
    /// number of dimensions, or if any range is empty, reversed, or out of
    /// bounds.
    pub fn slice_view(&self, ranges: &[(usize, usize)]) -> Result<ZeroCopyTensor<T>> {
        if ranges.len() != self.shape.len() {
            return Err(TorshError::InvalidArgument(
                "Number of slice ranges must match tensor dimensions".to_string(),
            ));
        }
        let mut new_shape = Vec::new();
        let mut offset = 0;
        for (i, &(start, end)) in ranges.iter().enumerate() {
            if start >= end || end > self.shape[i] {
                return Err(TorshError::InvalidArgument(
                    "Invalid slice range".to_string(),
                ));
            }
            new_shape.push(end - start);
            // Linear offset of the view's first element within the parent.
            offset += start * self.stride[i];
        }
        let new_stride = self.stride.clone();
        // SAFETY: each `start` was bounds-checked above, so the offset
        // pointer stays within the parent allocation.
        let new_data_ptr = unsafe { self.data_ptr.add(offset) };
        let capacity = new_shape.iter().product();
        Ok(ZeroCopyTensor {
            data_ptr: new_data_ptr,
            shape: new_shape,
            stride: new_stride,
            capacity,
            owned: false,
        })
    }

    /// Number of dimensions.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Whether `Drop` will free the underlying buffer.
    pub fn is_owned(&self) -> bool {
        self.owned
    }
}
// SAFETY: ZeroCopyTensor is a (possibly owning) pointer plus metadata over a
// buffer of T, mirroring Vec<T>: moving it to another thread is sound when T
// itself is Send. NOTE(review): borrowed views carry no lifetime, so the
// usual aliasing discipline is already the caller's responsibility.
unsafe impl<T: Send> Send for ZeroCopyTensor<T> {}
// SAFETY: &ZeroCopyTensor only exposes shared reads of T (shape/stride
// getters and `as_slice`), so sharing it across threads is sound when T is
// Sync — the same bound std uses for Vec<T>.
unsafe impl<T: Sync> Sync for ZeroCopyTensor<T> {}
impl<T> Drop for ZeroCopyTensor<T> {
    // Frees the backing allocation only when this tensor owns it (built via
    // `from_vec`); borrowed views leave the memory untouched.
    fn drop(&mut self) {
        if self.owned {
            // Reconstruct the forgotten allocation so Vec's destructor frees
            // it. NOTE(review): `Vec::from_raw_parts` requires `capacity` to
            // equal the allocation's *actual* capacity; `from_vec` forgets a
            // Vec whose spare capacity may exceed its length, which would
            // make this undefined behavior — verify `from_vec` shrinks the
            // allocation (e.g. via `into_boxed_slice`) before forgetting it.
            unsafe {
                let _vec =
                    Vec::from_raw_parts(self.data_ptr as *mut T, self.capacity, self.capacity);
            }
        }
    }
}
/// Thread-safe free list of reusable `Vec<T>` buffers.
pub struct TensorPool<T> {
    // Buffers waiting to be reused; protected by a parking_lot mutex.
    pool: Mutex<Vec<Vec<T>>>,
    // Upper bound on how many buffers the pool will retain at once.
    max_size: usize,
}
impl<T: Clone + Default> TensorPool<T> {
    /// Creates an empty pool that retains at most `max_size` buffers.
    pub fn new(max_size: usize) -> Self {
        Self {
            pool: Mutex::new(Vec::new()),
            max_size,
        }
    }

    /// Hands out a buffer of exactly `capacity` default-initialized
    /// elements, reusing the first pooled allocation that is large enough,
    /// or allocating fresh when none fits.
    pub fn get(&self, capacity: usize) -> Vec<T> {
        let mut pool = self.pool.lock();
        match pool.iter().position(|buf| buf.capacity() >= capacity) {
            Some(idx) => {
                // Reuse: O(1) removal (order is irrelevant), then reset the
                // contents to `capacity` default elements.
                let mut buf = pool.swap_remove(idx);
                buf.clear();
                buf.resize(capacity, T::default());
                buf
            }
            None => vec![T::default(); capacity],
        }
    }

    /// Returns a buffer to the pool; silently dropped when the pool is full.
    pub fn return_tensor(&self, tensor: Vec<T>) {
        let mut pool = self.pool.lock();
        if pool.len() < self.max_size {
            pool.push(tensor);
        }
    }

    /// Number of buffers currently held by the pool.
    pub fn pool_size(&self) -> usize {
        self.pool.lock().len()
    }

    /// Drops every pooled buffer.
    pub fn clear(&self) {
        self.pool.lock().clear();
    }
}
/// Placeholder loader for memory-mapped tensor files; the mapping itself is
/// not implemented (`load_slice` always errors, `can_map` returns false).
pub struct MemoryMappedLoader {
    // Path of the file that would be mapped; validated to exist at creation.
    file_path: std::path::PathBuf,
}
impl MemoryMappedLoader {
    /// Records `file_path` for later loads, failing if it does not exist.
    ///
    /// # Errors
    ///
    /// Returns `InvalidArgument` when the path is missing.
    pub fn new<P: AsRef<std::path::Path>>(file_path: P) -> Result<Self> {
        let file_path = file_path.as_ref().to_path_buf();
        if file_path.exists() {
            Ok(Self { file_path })
        } else {
            Err(TorshError::InvalidArgument(format!(
                "File does not exist: {}",
                file_path.display()
            )))
        }
    }

    /// Path this loader was created with.
    pub fn file_path(&self) -> &std::path::Path {
        self.file_path.as_path()
    }

    /// Size of the backing file in bytes, queried from the filesystem.
    ///
    /// # Errors
    ///
    /// Returns `InvalidArgument` when the metadata lookup fails.
    pub fn file_size(&self) -> Result<u64> {
        match std::fs::metadata(&self.file_path) {
            Ok(metadata) => Ok(metadata.len()),
            Err(e) => Err(TorshError::InvalidArgument(format!(
                "Failed to get file size: {}",
                e
            ))),
        }
    }

    /// Memory mapping is not implemented; this always fails.
    ///
    /// # Errors
    ///
    /// Always returns `UnsupportedOperation`.
    pub fn load_slice(&self, _offset: usize, _length: usize) -> Result<&[u8]> {
        Err(TorshError::UnsupportedOperation {
            op: "memory mapping".to_string(),
            dtype: "MemoryMappedLoader".to_string(),
        })
    }

    /// Reports whether mapping is supported (currently never).
    pub fn can_map(&self) -> bool {
        false
    }
}
/// Fixed-size set of pre-allocated buffers handed out and returned under a
/// lock.
pub struct BufferManager<T> {
    // Buffers not currently handed out; never grows beyond `max_buffers`.
    available_buffers: Mutex<Vec<Vec<T>>>,
    // Total number of buffers managed (in use + available).
    max_buffers: usize,
    // Element count of every managed buffer.
    buffer_size: usize,
}
impl<T: Clone + Default> BufferManager<T> {
    /// Pre-allocates `max_buffers` buffers of `buffer_size` default
    /// elements each.
    pub fn new(max_buffers: usize, buffer_size: usize) -> Self {
        let buffers: Vec<Vec<T>> = (0..max_buffers)
            .map(|_| vec![T::default(); buffer_size])
            .collect();
        Self {
            available_buffers: Mutex::new(buffers),
            max_buffers,
            buffer_size,
        }
    }

    /// Takes a buffer from the free list, or `None` when all are in use.
    pub fn acquire_buffer(&self) -> Option<Vec<T>> {
        self.available_buffers.lock().pop()
    }

    /// Puts a buffer back; silently dropped if the free list is already at
    /// capacity.
    pub fn release_buffer(&self, buffer: Vec<T>) {
        let mut available = self.available_buffers.lock();
        if available.len() < self.max_buffers {
            available.push(buffer);
        }
    }

    /// Count of buffers currently available for acquisition.
    pub fn available_count(&self) -> usize {
        self.available_buffers.lock().len()
    }

    /// Count of buffers handed out and not yet released.
    pub fn in_use_count(&self) -> usize {
        self.max_buffers - self.available_count()
    }

    /// Element count of each managed buffer.
    pub fn buffer_size(&self) -> usize {
        self.buffer_size
    }

    /// Total number of buffers this manager was configured with.
    pub fn max_buffers(&self) -> usize {
        self.max_buffers
    }

    /// Discards the current free list and refills it with freshly allocated
    /// buffers (outstanding buffers are unaffected until released).
    pub fn reset(&self) {
        let mut available = self.available_buffers.lock();
        available.clear();
        available.extend((0..self.max_buffers).map(|_| vec![T::default(); self.buffer_size]));
    }
}
/// Convenience wrapper for [`ZeroCopyTensor::from_vec`] (owning, no copy).
pub fn zero_copy_from_vec<T>(data: Vec<T>, shape: Vec<usize>) -> ZeroCopyTensor<T> {
    ZeroCopyTensor::from_vec(data, shape)
}
/// Convenience wrapper for [`ZeroCopyTensor::from_slice`] (borrowing view;
/// the caller must keep `data` alive while the tensor is used).
pub fn zero_copy_from_slice<T>(data: &[T], shape: Vec<usize>) -> ZeroCopyTensor<T> {
    ZeroCopyTensor::from_slice(data, shape)
}
/// Convenience constructor for a [`TensorPool`] retaining up to `max_size`
/// buffers.
pub fn create_tensor_pool<T: Clone + Default>(max_size: usize) -> TensorPool<T> {
    TensorPool::new(max_size)
}
/// Convenience constructor for a [`BufferManager`] with `max_buffers`
/// buffers of `buffer_size` elements each.
pub fn create_buffer_manager<T: Clone + Default>(
    max_buffers: usize,
    buffer_size: usize,
) -> BufferManager<T> {
    BufferManager::new(max_buffers, buffer_size)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Owned tensor: shape/len/ndim metadata and the ownership flag are set.
    #[test]
    fn test_zero_copy_tensor_from_vec() {
        let data = vec![1, 2, 3, 4, 5, 6];
        let shape = vec![2, 3];
        let tensor = ZeroCopyTensor::from_vec(data, shape.clone());
        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.len(), 6);
        assert!(!tensor.is_empty());
        assert!(tensor.is_owned());
        assert_eq!(tensor.ndim(), 2);
    }

    // Borrowed tensor: not owned, and as_slice round-trips the input data.
    #[test]
    fn test_zero_copy_tensor_from_slice() {
        let data = vec![1, 2, 3, 4];
        let shape = vec![2, 2];
        let tensor = ZeroCopyTensor::from_slice(&data, shape.clone());
        assert_eq!(tensor.shape(), &[2, 2]);
        assert_eq!(tensor.len(), 4);
        assert!(!tensor.is_owned());
        assert_eq!(tensor.as_slice(), &[1, 2, 3, 4]);
    }

    // Sub-view: checks shape/len metadata and that views never own memory.
    #[test]
    fn test_zero_copy_tensor_slice_view() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        let shape = vec![3, 3];
        let tensor = ZeroCopyTensor::from_vec(data, shape);
        let ranges = vec![(1, 3), (1, 3)];
        let slice_view = tensor.slice_view(&ranges).unwrap();
        assert_eq!(slice_view.shape(), &[2, 2]);
        assert_eq!(slice_view.len(), 4);
        assert!(!slice_view.is_owned());
    }

    // Pool lifecycle: miss allocates, return pools, hit drains, clear empties.
    #[test]
    fn test_tensor_pool() {
        let pool = TensorPool::<f32>::new(3);
        assert_eq!(pool.pool_size(), 0);
        let tensor1 = pool.get(10);
        assert_eq!(tensor1.len(), 10);
        pool.return_tensor(tensor1);
        assert_eq!(pool.pool_size(), 1);
        let tensor2 = pool.get(10);
        assert_eq!(tensor2.len(), 10);
        assert_eq!(pool.pool_size(), 0);
        pool.return_tensor(tensor2);
        pool.clear();
        assert_eq!(pool.pool_size(), 0);
    }

    // Buffer manager: acquire until exhausted, release to refill, then reset.
    #[test]
    fn test_buffer_manager() {
        let manager = BufferManager::<u8>::new(2, 100);
        assert_eq!(manager.available_count(), 2);
        assert_eq!(manager.in_use_count(), 0);
        assert_eq!(manager.buffer_size(), 100);
        assert_eq!(manager.max_buffers(), 2);
        let buffer1 = manager.acquire_buffer().unwrap();
        assert_eq!(buffer1.len(), 100);
        assert_eq!(manager.available_count(), 1);
        let buffer2 = manager.acquire_buffer().unwrap();
        assert_eq!(manager.available_count(), 0);
        assert!(manager.acquire_buffer().is_none());
        manager.release_buffer(buffer1);
        assert_eq!(manager.available_count(), 1);
        manager.release_buffer(buffer2);
        assert_eq!(manager.available_count(), 2);
        manager.reset();
        assert_eq!(manager.available_count(), 2);
    }

    // Loader: missing path errors; /dev/null (when present) still can't map.
    #[test]
    fn test_memory_mapped_loader() {
        let result = MemoryMappedLoader::new("/non/existent/file");
        assert!(result.is_err());
        if let Ok(loader) = MemoryMappedLoader::new("/dev/null") {
            let result = loader.load_slice(0, 10);
            assert!(result.is_err());
            assert!(!loader.can_map());
        }
    }

    // Row-major strides for 1-D, 2-D and 3-D shapes.
    #[test]
    fn test_stride_computation() {
        let stride = ZeroCopyTensor::<f32>::compute_stride(&[3, 4]);
        assert_eq!(stride, vec![4, 1]);
        let stride = ZeroCopyTensor::<f32>::compute_stride(&[2, 3, 4]);
        assert_eq!(stride, vec![12, 4, 1]);
        let stride = ZeroCopyTensor::<f32>::compute_stride(&[5]);
        assert_eq!(stride, vec![1]);
    }

    // Smoke test: the free-function wrappers construct without panicking.
    #[test]
    fn test_convenience_functions() {
        let data = vec![1, 2, 3, 4];
        let shape = vec![2, 2];
        let _tensor_from_vec = zero_copy_from_vec(data.clone(), shape.clone());
        let _tensor_from_slice = zero_copy_from_slice(&data, shape);
        let _pool = create_tensor_pool::<f32>(10);
        let _manager = create_buffer_manager::<u8>(5, 100);
    }

    // Length/shape mismatch must trip the construction assert.
    #[test]
    #[should_panic(expected = "Data length must match tensor capacity")]
    fn test_shape_mismatch() {
        let data = vec![1, 2, 3];
        let shape = vec![2, 2];
        ZeroCopyTensor::from_vec(data, shape);
    }
}