Skip to main content

llama_rs/tensor/
storage.rs

//! Tensor data storage.
2
3use std::sync::Arc;
4
/// Owned or borrowed tensor data.
///
/// Cloning is cheap for both variants. Cloning [`TensorStorage::Owned`] bumps the
/// [`Arc`] reference count, so the bytes are shared, not copied. Cloning
/// [`TensorStorage::View`] copies the raw pointer. Use [`TensorStorage::to_owned`]
/// to get an independent deep copy.
#[derive(Debug, Clone)]
pub enum TensorStorage {
    /// Owned data on CPU, shared between clones via reference counting.
    Owned(Arc<Vec<u8>>),
    /// View into external data (e.g., memory-mapped file).
    ///
    /// Prefer constructing this through [`TensorStorage::view`]. The rest of
    /// this type relies on that function's safety contract.
    // NOTE(review): the variant and its fields are public, so safe code can build
    // a `View` from an arbitrary pointer and then call the safe `as_bytes`.
    // Making the variant private would close that hole, but it is an API change.
    View { data: *const u8, len: usize },
}

// SAFETY: `Owned` holds an `Arc<Vec<u8>>`, which is already `Send + Sync`. Only
// the raw pointer in `View` blocks the auto impls, and this type only ever reads
// through it (`as_bytes_mut` refuses views). These impls are therefore sound
// only if every `View` upholds the contract documented on `TensorStorage::view`:
// the pointed-to bytes stay valid, stay unmodified, and are safe to read from
// any thread for as long as any copy of the storage exists. Memory-mapped files
// are the intended source of such views.
unsafe impl Send for TensorStorage {}
unsafe impl Sync for TensorStorage {}

impl TensorStorage {
    /// Wraps an owned byte buffer.
    pub fn owned(data: Vec<u8>) -> Self {
        Self::Owned(Arc::new(data))
    }

    /// Creates a borrowed view over `len` bytes starting at `data`.
    ///
    /// No bytes are copied; the storage only records the pointer and length.
    ///
    /// # Safety
    ///
    /// If `len > 0`, the caller must guarantee all of the following:
    /// - `data` is non-null and valid for reads of `len` bytes, all within a
    ///   single allocated object (e.g. one memory mapping).
    /// - `len` is at most `isize::MAX`.
    /// - The memory stays valid and is not mutated for as long as this storage,
    ///   **or any clone of it**, is alive. Clones copy the pointer, not the bytes.
    /// - The memory is sound to read from any thread, because `TensorStorage`
    ///   is `Send + Sync`.
    ///
    /// When `len == 0` the pointer is never dereferenced and may be null or
    /// dangling.
    pub unsafe fn view(data: *const u8, len: usize) -> Self {
        Self::View { data, len }
    }

    /// Returns the stored bytes as a slice.
    ///
    /// # Panics
    ///
    /// Panics if this is a non-empty `View` with a null pointer. That can only
    /// happen if the variant was constructed directly without honouring the
    /// contract of [`TensorStorage::view`].
    pub fn as_bytes(&self) -> &[u8] {
        match self {
            Self::Owned(data) => data.as_slice(),
            // `slice::from_raw_parts` requires a non-null pointer even for a zero
            // length. An empty view may legitimately carry a null or dangling
            // pointer, so never hand it to that function.
            Self::View { len: 0, .. } => &[],
            Self::View { data, len } => {
                // Cheap guard: turns an obvious contract violation into a panic
                // instead of undefined behaviour.
                assert!(
                    !data.is_null(),
                    "TensorStorage::View has a null data pointer but len {}",
                    len
                );
                // SAFETY: `data` is non-null (checked above). Per the contract of
                // `TensorStorage::view`, it is valid for reads of `len` bytes that
                // are not mutated while `self` is borrowed. `u8` has alignment 1,
                // so any non-null pointer is suitably aligned.
                unsafe { std::slice::from_raw_parts(*data, *len) }
            }
        }
    }

    /// Returns a mutable slice over the bytes, if exclusive access is possible.
    ///
    /// Returns `None` for views, because external memory is treated as
    /// read-only. Also returns `None` for owned buffers whose [`Arc`] is shared,
    /// e.g. with a clone. Call [`TensorStorage::to_owned`] first to obtain a
    /// uniquely owned copy that can always be mutated.
    pub fn as_bytes_mut(&mut self) -> Option<&mut [u8]> {
        match self {
            Self::Owned(data) => Arc::get_mut(data).map(|v| v.as_mut_slice()),
            Self::View { .. } => None,
        }
    }

    /// Returns the number of bytes in the storage.
    pub fn len(&self) -> usize {
        match self {
            Self::Owned(data) => data.len(),
            Self::View { len, .. } => *len,
        }
    }

    /// Returns `true` if the storage holds no bytes.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Makes a deep copy of the bytes into a fresh, uniquely owned buffer.
    ///
    /// [`Clone::clone`] shares the underlying `Arc` or pointer. This method
    /// instead always allocates and copies, so the result is independent of the
    /// original and of any external memory a `View` points into.
    // NOTE(review): this inherent method shadows the blanket `ToOwned::to_owned`,
    // which would only call `clone`. Consider a distinct name in a future API
    // revision.
    pub fn to_owned(&self) -> Self {
        Self::owned(self.as_bytes().to_vec())
    }
}