use std::marker::PhantomData;
use std::ops::Range;
use std::sync::Arc;
use rten_base::byte_cast::Pod;
use rten_tensor::{DynLayout, Storage, TensorBase};
#[cfg(feature = "mmap")]
use memmap2::Mmap;
/// Returns the half-open range of memory addresses occupied by `slice`.
///
/// The range begins at the address of the slice's first element and covers
/// the slice's total size in bytes, so an empty slice (or a slice of
/// zero-sized elements) yields an empty range.
fn slice_address_range<T>(slice: &[T]) -> Range<usize> {
    let byte_len = std::mem::size_of_val(slice);
    let start = slice.as_ptr() as usize;
    Range {
        start,
        end: start + byte_len,
    }
}
/// Byte buffer that backs [`ArcSlice`] values.
///
/// Every variant exposes its contents as a `&[u8]` through
/// [`ConstantStorage::data`].
#[derive(Debug)]
pub enum ConstantStorage {
/// Bytes provided by a `memmap2::Mmap`. Only compiled in when the `mmap`
/// feature is enabled.
#[cfg(feature = "mmap")]
Mmap(Mmap),
/// Bytes held in an owned `Vec<u8>`.
Buffer(Vec<u8>),
/// Bytes borrowed for the `'static` lifetime.
StaticSlice(&'static [u8]),
}
impl ConstantStorage {
    /// Returns the entire contents of this storage as a byte slice.
    pub fn data(&self) -> &[u8] {
        match self {
            #[cfg(feature = "mmap")]
            Self::Mmap(mmap) => mmap,
            Self::StaticSlice(bytes) => bytes,
            Self::Buffer(bytes) => bytes,
        }
    }

    /// Locates `data` within this storage and returns its position as a range
    /// of byte offsets measured from the start of [`Self::data`].
    ///
    /// Returns `None` when:
    /// - `T` is a zero-sized type,
    /// - the start address of `data` does not lie inside the storage, or
    /// - `data` extends beyond the end of the storage.
    ///
    /// Note that an empty `data` slice positioned exactly at the end of the
    /// storage is also rejected, because its start address falls outside the
    /// half-open address range of the storage.
    fn byte_range_of<T>(&self, data: &[T]) -> Option<Range<usize>> {
        if std::mem::size_of::<T>() == 0 {
            return None;
        }

        let storage = slice_address_range(self.data());
        let target = slice_address_range(data);

        let starts_inside = storage.start <= target.start && target.start < storage.end;
        let ends_inside = target.end <= storage.end;
        if !(starts_inside && ends_inside) {
            return None;
        }

        // Convert absolute addresses into offsets relative to the storage.
        let base = storage.start;
        Some((target.start - base)..(target.end - base))
    }
}
/// A run of `len` elements of type `T` located inside a shared
/// [`ConstantStorage`] buffer.
///
/// The slice holds an `Arc` to the storage, so the underlying bytes remain
/// alive for as long as the slice exists.
#[derive(Debug)]
pub struct ArcSlice<T> {
/// Shared buffer that contains the elements.
storage: Arc<ConstantStorage>,
/// Offset, in bytes, of the first element from the start of
/// `storage.data()`.
byte_offset: usize,
/// Number of `T` elements (not bytes) in the slice.
len: usize,
/// Records the element type; no `T` value is stored directly.
phantom: PhantomData<T>,
}
impl<T> Clone for ArcSlice<T> {
fn clone(&self) -> ArcSlice<T> {
ArcSlice {
storage: self.storage.clone(),
byte_offset: self.byte_offset,
len: self.len,
phantom: PhantomData,
}
}
}
impl<T> ArcSlice<T> {
    /// Creates a slice that refers to `data`, which must be located inside
    /// the buffer owned by `storage`.
    ///
    /// Returns `None` if `ConstantStorage::byte_range_of` cannot locate
    /// `data` within `storage`; this includes the case where `T` is a
    /// zero-sized type.
    pub fn new(storage: Arc<ConstantStorage>, data: &[T]) -> Option<ArcSlice<T>> {
        let byte_offset = storage.byte_range_of(data)?.start;
        let len = data.len();
        Some(Self {
            storage,
            byte_offset,
            len,
            phantom: PhantomData,
        })
    }

    /// Creates a slice that takes ownership of `buf` and interprets its
    /// bytes as a sequence of `T` values.
    ///
    /// Returns `None` if the buffer's address is not a multiple of `T`'s
    /// alignment, if its length is not a multiple of `T`'s size, or if `T`
    /// is a zero-sized type.
    pub fn from_bytes(buf: Vec<u8>) -> Option<ArcSlice<T>>
    where
        T: Pod,
    {
        // A `Vec` without any capacity is swapped for one with a one-byte
        // allocation; presumably so that the alignment check below sees the
        // address of a real allocation rather than the placeholder pointer
        // of an unallocated `Vec`.
        let bytes = match buf.capacity() {
            0 => Vec::with_capacity(1),
            _ => buf,
        };

        let elem_size = size_of::<T>();
        let aligned = (bytes.as_ptr() as usize).is_multiple_of(align_of::<T>());
        let whole_elements = bytes.len().is_multiple_of(elem_size);
        if !(aligned && whole_elements) {
            return None;
        }

        // `checked_div` yields `None` when `T` is zero-sized.
        let len = bytes.len().checked_div(elem_size)?;
        Some(Self {
            storage: Arc::new(ConstantStorage::Buffer(bytes)),
            byte_offset: 0,
            len,
            phantom: PhantomData,
        })
    }
}
// NOTE(review): `Storage` is an unsafe trait declared in `rten_tensor`; its
// exact requirements are not visible here. This impl reports `len` elements
// starting at `as_ptr()` and never hands out mutable access — confirm that
// this satisfies the trait's documented contract.
unsafe impl<T> Storage for ArcSlice<T> {
    type Elem = T;

    /// The storage is exposed as read-only.
    const MUTABLE: bool = false;

    /// Returns the number of `T` elements in the slice.
    fn len(&self) -> usize {
        self.len
    }

    /// Returns a pointer to the first element of the slice.
    fn as_ptr(&self) -> *const Self::Elem {
        let base = self.storage.data().as_ptr();
        // SAFETY: The constructors visible in this file set `byte_offset`
        // either to 0 (`from_bytes`) or to the start of a byte range that
        // `byte_range_of` confirmed lies within `storage.data()` (`new`), so
        // the offset pointer stays inside the buffer's allocation.
        let first = unsafe { base.add(self.byte_offset) };
        // A plain pointer cast replaces `transmute`: it performs the same
        // conversion but can only ever change the pointee type.
        first.cast::<T>()
    }
}
/// Tensor whose storage is an [`ArcSlice`] and whose layout type is
/// [`DynLayout`].
pub type ArcTensorView<T> = TensorBase<ArcSlice<T>, DynLayout>;
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rten_base::byte_cast::cast_pod_slice;
    use rten_tensor::prelude::*;

    use super::{ArcSlice, ArcTensorView, ConstantStorage};

    /// Serializes `vec` into bytes, using the native byte order for each
    /// element.
    fn vec_to_ne_bytes(vec: Vec<i32>) -> Vec<u8> {
        let mut bytes = Vec::new();
        for value in vec {
            bytes.extend_from_slice(&value.to_ne_bytes());
        }
        bytes
    }

    #[test]
    fn test_constant_storage() {
        let bytes = vec_to_ne_bytes((0..16).collect());
        let storage = Arc::new(ConstantStorage::Buffer(bytes));

        // Reinterpret each half of the 64-byte buffer as eight `i32`s.
        let first_half = cast_pod_slice::<u8, i32>(&storage.data()[..32]).unwrap();
        let second_half = cast_pod_slice::<u8, i32>(&storage.data()[32..]).unwrap();
        assert_eq!(first_half, [0, 1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(second_half, [8, 9, 10, 11, 12, 13, 14, 15]);

        // Slices that lie inside the storage can be wrapped and used as
        // tensor data.
        let first_arc = ArcSlice::new(Arc::clone(&storage), first_half).unwrap();
        let second_arc = ArcSlice::new(Arc::clone(&storage), second_half).unwrap();
        let first_view = ArcTensorView::from_data(&[2, 4], first_arc);
        let second_view = ArcTensorView::from_data(&[4, 2], second_arc);

        assert_eq!(first_view.shape(), &[2, 4]);
        assert_eq!(first_view.data().unwrap(), first_half);
        assert_eq!(second_view.shape(), &[4, 2]);
        assert_eq!(second_view.data().unwrap(), second_half);

        // A slice that lives outside the storage is rejected.
        let unrelated = &[1, 2, 3];
        assert!(ArcSlice::new(Arc::clone(&storage), unrelated).is_none());

        // Slices of zero-sized elements are rejected.
        let zero_sized = &[(), ()];
        assert!(ArcSlice::new(Arc::clone(&storage), zero_sized).is_none());
    }

    #[test]
    fn test_arc_slice_from_bytes() {
        let expected: Vec<i32> = (0..16).collect();
        let encoded = vec_to_ne_bytes(expected.clone());

        let slice = ArcSlice::<i32>::from_bytes(encoded).unwrap();
        let tensor = ArcTensorView::from_data(&[16], slice);

        assert_eq!(tensor.data().unwrap(), expected);
    }

    #[test]
    fn test_arc_slice_from_empty_bytes() {
        let expected: Vec<i32> = Vec::new();
        let encoded = vec_to_ne_bytes(expected.clone());

        let slice = ArcSlice::<i32>::from_bytes(encoded).unwrap();
        let tensor = ArcTensorView::from_data(&[0], slice);

        assert_eq!(tensor.data().unwrap(), expected);
    }

    #[test]
    fn test_arc_slice_from_bytes_invalid() {
        let mut encoded = vec_to_ne_bytes((0..16).collect());

        // One extra byte means the length is no longer a multiple of the
        // element size.
        encoded.push(0u8);

        assert!(ArcSlice::<i32>::from_bytes(encoded).is_none());
    }
}