carton_runner_interface/do_not_modify/
storage.rs

// Copyright 2023 Vivek Panyam
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
//! `TensorStorage`: tensor data stored inline, together with its shape and optional strides.
16
17use std::{fmt::Debug, marker::PhantomData};
18
19use ndarray::{ShapeBuilder, StrideShape};
20use serde::{Deserialize, Serialize};
21
22use super::alloc::AsPtr;
23
/// Tensor element data plus the shape/stride metadata needed to view it as an ndarray.
///
/// NOTE(review): this type lives under `do_not_modify` and derives `Serialize`/`Deserialize`;
/// changing field order, field types, or derives presumably changes the serialized
/// representation and breaks compatibility — confirm before editing.
#[derive(Debug, Serialize, Deserialize)]
pub struct TensorStorage<T, Storage> {
    /// Backing buffer; the impl in this file obtains raw element pointers from it via `AsPtr<T>`.
    pub(crate) data: Storage,
    /// Length of each axis, stored as `u64` and converted to `usize` when a view is built.
    pub(crate) shape: Vec<u64>,
    /// Optional explicit per-axis strides (converted to `usize` when a view is built).
    /// `None` means the view uses ndarray's default layout for a bare shape.
    pub(crate) strides: Option<Vec<u64>>,
    /// Records the element type `T` without storing a value of it.
    pub(crate) pd: PhantomData<T>,
}
31
32impl<T, Storage> TensorStorage<T, Storage>
33where
34    Storage: AsPtr<T>,
35{
36    fn get_shape(&self) -> StrideShape<ndarray::IxDyn> {
37        match &self.strides {
38            None => self
39                .shape
40                .iter()
41                .map(|v| *v as usize)
42                .collect::<Vec<_>>()
43                .into(),
44            Some(strides) => self
45                .shape
46                .iter()
47                .map(|v| *v as usize)
48                .collect::<Vec<_>>()
49                .strides(strides.iter().map(|v| (*v).try_into().unwrap()).collect())
50                .into(),
51        }
52    }
53
54    pub fn view(&self) -> ndarray::ArrayViewD<T> {
55        let data = self.data.as_ptr();
56        unsafe { ndarray::ArrayView::from_shape_ptr(self.get_shape(), data) }
57    }
58
59    pub fn view_mut(&mut self) -> ndarray::ArrayViewMutD<T> {
60        let data = self.data.as_mut_ptr();
61        unsafe { ndarray::ArrayViewMut::from_shape_ptr(self.get_shape(), data) }
62    }
63}