Skip to main content

openinfer_simulator/tensor/
tensor.rs

//! Tensor container and views.
//!
//! `Tensor<T>` owns flat storage with shape/stride metadata and provides views
//! for slicing and indexing.
5use anyhow::{anyhow, Result};
6use std::cell::UnsafeCell;
7use std::ops::Index;
8
9use super::shape::{is_contiguous, linear_to_indices, numel, offset_for, view_parts, compute_strides};
10
/// Overrides accepted by `Tensor::from_vec_with_opts`.
///
/// Every field defaults to "no override": the data is treated as 1-D, strides
/// are derived from the shape, and the element count must match the data length.
#[derive(Clone, Debug, Default)]
pub struct TensorOptions {
    /// Shape to use instead of the default `[data.len()]`.
    pub shape: Option<Vec<usize>>,
    /// Strides to use instead of the ones `compute_strides` derives from the shape.
    pub strides: Option<Vec<usize>>,
    /// Skip the element-count vs. data-length check (intended for packed types).
    pub allow_len_mismatch: bool,
}
21
/// Non-owning view into a tensor's storage: a base pointer plus the
/// shape/stride metadata describing which elements the view covers.
///
/// The base pointer carries no lifetime, so a view must not be used after the
/// tensor it was taken from is dropped or its storage is reallocated.
#[derive(Clone, Debug)]
pub struct TensorView<T> {
    /// First element addressed by the view.
    data: *const T,
    /// Extent of each dimension.
    shape: Vec<usize>,
    /// Per-dimension step, in elements, through the underlying storage.
    strides: Vec<usize>,
}
29
30impl<T> TensorView<T> {
31    fn new(data: *const T, shape: Vec<usize>, strides: Vec<usize>) -> Self {
32        Self {
33            data,
34            shape,
35            strides,
36        }
37    }
38
39    /// Return the shape of this view.
40    pub fn shape(&self) -> &[usize] {
41        &self.shape
42    }
43
44    /// Return the strides of this view.
45    pub fn strides(&self) -> &[usize] {
46        &self.strides
47    }
48
49    /// Return the logical element count.
50    pub fn len(&self) -> usize {
51        numel(&self.shape)
52    }
53
54    /// Access a value by multidimensional indices.
55    pub fn at(&self, indices: &[usize]) -> &T {
56        let offset = offset_for(&self.shape, &self.strides, indices)
57            .unwrap_or_else(|err| panic!("tensor view index error: {}", err));
58        unsafe { &*self.data.add(offset) }
59    }
60
61    /// Return a contiguous slice if the view is contiguous.
62    pub fn as_slice(&self) -> Option<&[T]> {
63        if !is_contiguous(&self.shape, &self.strides) {
64            return None;
65        }
66        let len = self.len();
67        if len == 0 {
68            return Some(&[]);
69        }
70        unsafe { Some(std::slice::from_raw_parts(self.data, len)) }
71    }
72
73    /// Collect the view into a contiguous vector.
74    pub fn to_vec(&self) -> Vec<T>
75    where
76        T: Clone,
77    {
78        if let Some(slice) = self.as_slice() {
79            return slice.to_vec();
80        }
81        let mut out = Vec::with_capacity(self.len());
82        for idx in 0..self.len() {
83            let coords = linear_to_indices(idx, &self.shape);
84            out.push(self.at(&coords).clone());
85        }
86        out
87    }
88}
89
90/// Owned tensor container with shape and stride metadata.
91#[derive(Debug)]
92pub struct Tensor<T> {
93    pub data: Vec<T>,
94    shape: Vec<usize>,
95    strides: Vec<usize>,
96    // Indexing caches a view; this is not thread-safe.
97    view_cache: UnsafeCell<TensorView<T>>,
98}
99
100// Tensor owns its backing storage; moving between threads is safe when it is
101// not accessed concurrently.
102unsafe impl<T: Send> Send for Tensor<T> {}
103
104impl<T: Clone> Clone for Tensor<T> {
105    fn clone(&self) -> Self {
106        let data = self.data.clone();
107        let shape = self.shape.clone();
108        let strides = self.strides.clone();
109        let data_ptr = data.as_ptr();
110        Self {
111            data,
112            shape: shape.clone(),
113            strides: strides.clone(),
114            view_cache: UnsafeCell::new(TensorView::new(data_ptr, shape, strides)),
115        }
116    }
117}
118
119impl<T> Tensor<T> {
120    /// Build a tensor from a flat data vector.
121    ///
122    /// # Example
123    /// ```no_run
124    /// # use openinfer::tensor::Tensor;
125    /// # fn main() -> anyhow::Result<()> {
126    /// let t = Tensor::from_vec(vec![1.0f32, 2.0, 3.0])?;
127    /// # Ok(()) }
128    /// ```
129    pub fn from_vec(data: Vec<T>) -> Result<Self> {
130        Self::from_vec_with_opts(data, TensorOptions::default())
131    }
132
133    /// Build a tensor with explicit options.
134    ///
135    /// # Example
136    /// ```no_run
137    /// # use openinfer::tensor::{Tensor, TensorOptions};
138    /// # fn main() -> anyhow::Result<()> {
139    /// let t = Tensor::from_vec_with_opts(
140    ///     vec![1.0f32, 2.0, 3.0, 4.0],
141    ///     TensorOptions { shape: Some(vec![2, 2]), ..TensorOptions::default() },
142    /// )?;
143    /// # Ok(()) }
144    /// ```
145    pub fn from_vec_with_opts(data: Vec<T>, opts: TensorOptions) -> Result<Self> {
146        let shape = match opts.shape {
147            Some(shape) => shape,
148            None => vec![data.len()],
149        };
150        let expected = numel(&shape);
151        if !opts.allow_len_mismatch && expected != data.len() {
152            return Err(anyhow!(
153                "tensor shape {:?} expects {} values, got {}",
154                shape,
155                expected,
156                data.len()
157            ));
158        }
159        if shape.is_empty() && data.len() != 1 {
160            return Err(anyhow!(
161                "scalar tensor expects 1 value, got {}",
162                data.len()
163            ));
164        }
165        let strides = match opts.strides {
166            Some(strides) => {
167                if strides.len() != shape.len() {
168                    return Err(anyhow!(
169                        "tensor strides length {} does not match shape length {}",
170                        strides.len(),
171                        shape.len()
172                    ));
173                }
174                strides
175            }
176            None => compute_strides(&shape),
177        };
178        let data_ptr = data.as_ptr();
179        Ok(Self {
180            data,
181            shape: shape.clone(),
182            strides: strides.clone(),
183            view_cache: UnsafeCell::new(TensorView::new(data_ptr, shape, strides)),
184        })
185    }
186
187    /// Create a scalar tensor from a single value.
188    ///
189    /// # Example
190    /// ```no_run
191    /// # use openinfer::tensor::Tensor;
192    /// let t = Tensor::from_scalar(3.14f32);
193    /// ```
194    pub fn from_scalar(value: T) -> Self {
195        let data = vec![value];
196        let data_ptr = data.as_ptr();
197        let shape = Vec::new();
198        let strides = Vec::new();
199        Self {
200            data,
201            shape: shape.clone(),
202            strides: strides.clone(),
203            view_cache: UnsafeCell::new(TensorView::new(data_ptr, shape, strides)),
204        }
205    }
206
207    /// Create a tensor, panicking on invalid shape.
208    pub fn new(data: Vec<T>) -> Self {
209        Tensor::from_vec(data)
210            .unwrap_or_else(|err| panic!("tensor creation failed: {}", err))
211    }
212
213    /// Return the raw data length.
214    pub fn len(&self) -> usize {
215        self.data.len()
216    }
217
218    /// Return the tensor shape.
219    pub fn shape(&self) -> &[usize] {
220        &self.shape
221    }
222
223    /// Return the tensor strides.
224    pub fn strides(&self) -> &[usize] {
225        &self.strides
226    }
227
228    /// Return the logical element count.
229    pub fn numel(&self) -> usize {
230        numel(&self.shape)
231    }
232
233    /// Access a value by multidimensional indices.
234    pub fn at(&self, indices: &[usize]) -> &T {
235        let offset = offset_for(&self.shape, &self.strides, indices)
236            .unwrap_or_else(|err| panic!("tensor index error: {}", err));
237        &self.data[offset]
238    }
239
240    /// Create a view starting at the provided indices.
241    pub fn view(&self, indices: &[usize]) -> TensorView<T> {
242        let (offset, shape, strides) =
243            view_parts(&self.shape, &self.strides, indices)
244                .unwrap_or_else(|err| panic!("tensor view error: {}", err));
245        TensorView::new(unsafe { self.data.as_ptr().add(offset) }, shape, strides)
246    }
247
248    /// Clone the tensor data into a vector.
249    pub fn to_vec(&self) -> Vec<T>
250    where
251        T: Clone,
252    {
253        self.data.clone()
254    }
255}
256
257impl<T, const N: usize> Index<[usize; N]> for Tensor<T> {
258    type Output = TensorView<T>;
259
260    fn index(&self, index: [usize; N]) -> &Self::Output {
261        let view = self.view(&index);
262        unsafe {
263            *self.view_cache.get() = view;
264            &*self.view_cache.get()
265        }
266    }
267}