carton_runner_interface/do_not_modify/
storage.rs1use std::{fmt::Debug, marker::PhantomData};
18
19use ndarray::{ShapeBuilder, StrideShape};
20use serde::{Deserialize, Serialize};
21
22use super::alloc::AsPtr;
23
/// Serializable backing store for a tensor: an element buffer plus the
/// metadata needed to view it as an n-dimensional array.
///
/// `T` is the element type and `Storage` is the buffer type. When
/// `Storage: AsPtr<T>`, `view` / `view_mut` expose the data as ndarray views.
///
/// NOTE(review): this file lives under `do_not_modify/`, presumably because the
/// serialized form of this type must stay stable — confirm before changing it.
/// The derived `Serialize`/`Deserialize` impls encode fields in declaration
/// order, so reordering, renaming, or retyping fields may break compatibility
/// with previously serialized data.
#[derive(Debug, Serialize, Deserialize)]
pub struct TensorStorage<T, Storage> {
    /// Element buffer; read through `AsPtr::as_ptr` / `as_mut_ptr` when views
    /// are built.
    pub(crate) data: Storage,
    /// Number of elements along each axis.
    pub(crate) shape: Vec<u64>,
    /// Optional per-axis strides, handed to ndarray as element strides.
    /// `None` means ndarray's default (row-major / C-order) layout for `shape`.
    pub(crate) strides: Option<Vec<u64>>,
    /// Ties the struct to the element type `T` without storing a `T`.
    pub(crate) pd: PhantomData<T>,
}
31
32impl<T, Storage> TensorStorage<T, Storage>
33where
34 Storage: AsPtr<T>,
35{
36 fn get_shape(&self) -> StrideShape<ndarray::IxDyn> {
37 match &self.strides {
38 None => self
39 .shape
40 .iter()
41 .map(|v| *v as usize)
42 .collect::<Vec<_>>()
43 .into(),
44 Some(strides) => self
45 .shape
46 .iter()
47 .map(|v| *v as usize)
48 .collect::<Vec<_>>()
49 .strides(strides.iter().map(|v| (*v).try_into().unwrap()).collect())
50 .into(),
51 }
52 }
53
54 pub fn view(&self) -> ndarray::ArrayViewD<T> {
55 let data = self.data.as_ptr();
56 unsafe { ndarray::ArrayView::from_shape_ptr(self.get_shape(), data) }
57 }
58
59 pub fn view_mut(&mut self) -> ndarray::ArrayViewMutD<T> {
60 let data = self.data.as_mut_ptr();
61 unsafe { ndarray::ArrayViewMut::from_shape_ptr(self.get_shape(), data) }
62 }
63}