openinfer_simulator/tensor/
tensor.rs1use anyhow::{anyhow, Result};
6use std::cell::UnsafeCell;
7use std::ops::Index;
8
9use super::shape::{is_contiguous, linear_to_indices, numel, offset_for, view_parts, compute_strides};
10
/// Optional construction settings for [`Tensor::from_vec_with_opts`].
///
/// Every field defaults to "not set": `None` for the vectors, `false` for the flag.
#[derive(Clone, Debug, Default)]
pub struct TensorOptions {
    /// Explicit shape. When `None`, the data is treated as 1-D with shape
    /// `[data.len()]`.
    pub shape: Option<Vec<usize>>,
    /// Explicit per-dimension strides. When `None`, they are derived from the
    /// shape with `compute_strides`; when set, there must be exactly one entry
    /// per dimension of the shape.
    pub strides: Option<Vec<usize>>,
    /// When `true`, skip the check that the shape's element count equals
    /// `data.len()`. The "scalar needs exactly one value" check still applies.
    pub allow_len_mismatch: bool,
}
21
22#[derive(Debug, Clone)]
24pub struct TensorView<T> {
25 data: *const T,
26 shape: Vec<usize>,
27 strides: Vec<usize>,
28}
29
30impl<T> TensorView<T> {
31 fn new(data: *const T, shape: Vec<usize>, strides: Vec<usize>) -> Self {
32 Self {
33 data,
34 shape,
35 strides,
36 }
37 }
38
39 pub fn shape(&self) -> &[usize] {
41 &self.shape
42 }
43
44 pub fn strides(&self) -> &[usize] {
46 &self.strides
47 }
48
49 pub fn len(&self) -> usize {
51 numel(&self.shape)
52 }
53
54 pub fn at(&self, indices: &[usize]) -> &T {
56 let offset = offset_for(&self.shape, &self.strides, indices)
57 .unwrap_or_else(|err| panic!("tensor view index error: {}", err));
58 unsafe { &*self.data.add(offset) }
59 }
60
61 pub fn as_slice(&self) -> Option<&[T]> {
63 if !is_contiguous(&self.shape, &self.strides) {
64 return None;
65 }
66 let len = self.len();
67 if len == 0 {
68 return Some(&[]);
69 }
70 unsafe { Some(std::slice::from_raw_parts(self.data, len)) }
71 }
72
73 pub fn to_vec(&self) -> Vec<T>
75 where
76 T: Clone,
77 {
78 if let Some(slice) = self.as_slice() {
79 return slice.to_vec();
80 }
81 let mut out = Vec::with_capacity(self.len());
82 for idx in 0..self.len() {
83 let coords = linear_to_indices(idx, &self.shape);
84 out.push(self.at(&coords).clone());
85 }
86 out
87 }
88}
89
90#[derive(Debug)]
92pub struct Tensor<T> {
93 pub data: Vec<T>,
94 shape: Vec<usize>,
95 strides: Vec<usize>,
96 view_cache: UnsafeCell<TensorView<T>>,
98}
99
100unsafe impl<T: Send> Send for Tensor<T> {}
103
104impl<T: Clone> Clone for Tensor<T> {
105 fn clone(&self) -> Self {
106 let data = self.data.clone();
107 let shape = self.shape.clone();
108 let strides = self.strides.clone();
109 let data_ptr = data.as_ptr();
110 Self {
111 data,
112 shape: shape.clone(),
113 strides: strides.clone(),
114 view_cache: UnsafeCell::new(TensorView::new(data_ptr, shape, strides)),
115 }
116 }
117}
118
119impl<T> Tensor<T> {
120 pub fn from_vec(data: Vec<T>) -> Result<Self> {
130 Self::from_vec_with_opts(data, TensorOptions::default())
131 }
132
133 pub fn from_vec_with_opts(data: Vec<T>, opts: TensorOptions) -> Result<Self> {
146 let shape = match opts.shape {
147 Some(shape) => shape,
148 None => vec![data.len()],
149 };
150 let expected = numel(&shape);
151 if !opts.allow_len_mismatch && expected != data.len() {
152 return Err(anyhow!(
153 "tensor shape {:?} expects {} values, got {}",
154 shape,
155 expected,
156 data.len()
157 ));
158 }
159 if shape.is_empty() && data.len() != 1 {
160 return Err(anyhow!(
161 "scalar tensor expects 1 value, got {}",
162 data.len()
163 ));
164 }
165 let strides = match opts.strides {
166 Some(strides) => {
167 if strides.len() != shape.len() {
168 return Err(anyhow!(
169 "tensor strides length {} does not match shape length {}",
170 strides.len(),
171 shape.len()
172 ));
173 }
174 strides
175 }
176 None => compute_strides(&shape),
177 };
178 let data_ptr = data.as_ptr();
179 Ok(Self {
180 data,
181 shape: shape.clone(),
182 strides: strides.clone(),
183 view_cache: UnsafeCell::new(TensorView::new(data_ptr, shape, strides)),
184 })
185 }
186
187 pub fn from_scalar(value: T) -> Self {
195 let data = vec![value];
196 let data_ptr = data.as_ptr();
197 let shape = Vec::new();
198 let strides = Vec::new();
199 Self {
200 data,
201 shape: shape.clone(),
202 strides: strides.clone(),
203 view_cache: UnsafeCell::new(TensorView::new(data_ptr, shape, strides)),
204 }
205 }
206
207 pub fn new(data: Vec<T>) -> Self {
209 Tensor::from_vec(data)
210 .unwrap_or_else(|err| panic!("tensor creation failed: {}", err))
211 }
212
213 pub fn len(&self) -> usize {
215 self.data.len()
216 }
217
218 pub fn shape(&self) -> &[usize] {
220 &self.shape
221 }
222
223 pub fn strides(&self) -> &[usize] {
225 &self.strides
226 }
227
228 pub fn numel(&self) -> usize {
230 numel(&self.shape)
231 }
232
233 pub fn at(&self, indices: &[usize]) -> &T {
235 let offset = offset_for(&self.shape, &self.strides, indices)
236 .unwrap_or_else(|err| panic!("tensor index error: {}", err));
237 &self.data[offset]
238 }
239
240 pub fn view(&self, indices: &[usize]) -> TensorView<T> {
242 let (offset, shape, strides) =
243 view_parts(&self.shape, &self.strides, indices)
244 .unwrap_or_else(|err| panic!("tensor view error: {}", err));
245 TensorView::new(unsafe { self.data.as_ptr().add(offset) }, shape, strides)
246 }
247
248 pub fn to_vec(&self) -> Vec<T>
250 where
251 T: Clone,
252 {
253 self.data.clone()
254 }
255}
256
257impl<T, const N: usize> Index<[usize; N]> for Tensor<T> {
258 type Output = TensorView<T>;
259
260 fn index(&self, index: [usize; N]) -> &Self::Output {
261 let view = self.view(&index);
262 unsafe {
263 *self.view_cache.get() = view;
264 &*self.view_cache.get()
265 }
266 }
267}