Skip to main content

ariadnetor_tensor/dense/
access.rs

//! Coordinate-based element access for `DenseTensorData<T>`.
//!
//! `get` / `set` resolve flat indices through the layout's memory
//! order, so a `RowMajor`-tagged and a `ColumnMajor`-tagged tensor
//! holding the same logical matrix return the same value at the same
//! `[i, j, ...]`.
7
8use crate::DenseTensorData;
9use crate::reorder::flat_index;
10
11impl<T> DenseTensorData<T>
12where
13    T: Clone,
14{
15    /// Get element at multi-dimensional indices.
16    ///
17    /// The flat index is computed using `self.order()`.
18    ///
19    /// # Panics
20    ///
21    /// Panics if indices are out of bounds.
22    pub fn get(&self, indices: &[usize]) -> T {
23        let shape = self.shape();
24        assert_eq!(indices.len(), shape.len());
25        for (axis, (&idx, &dim)) in indices.iter().zip(shape).enumerate() {
26            assert!(
27                idx < dim,
28                "index {idx} out of bounds for axis {axis} with size {dim}"
29            );
30        }
31        let order = self.order();
32        let idx = flat_index(indices, shape, order);
33        self.storage().data()[idx].clone()
34    }
35
36    /// Set element at multi-dimensional indices (triggers CoW on the
37    /// storage half if shared).
38    ///
39    /// The flat index is computed using `self.order()`.
40    ///
41    /// # Panics
42    ///
43    /// Panics if indices are out of bounds.
44    pub fn set(&mut self, indices: &[usize], value: T) {
45        let idx = {
46            let shape = self.shape();
47            assert_eq!(indices.len(), shape.len());
48            for (axis, (&i, &dim)) in indices.iter().zip(shape).enumerate() {
49                assert!(
50                    i < dim,
51                    "index {i} out of bounds for axis {axis} with size {dim}"
52                );
53            }
54            flat_index(indices, shape, self.order())
55        };
56        self.storage_mut().data_mut()[idx] = value;
57    }
58
59    /// Mutable reference to the underlying contiguous data buffer
60    /// (triggers CoW on the storage half if shared).
61    pub fn data_mut(&mut self) -> &mut [T] {
62        self.storage_mut().data_mut()
63    }
64
65    /// Iterate over the flat storage in flat (storage) order.
66    pub fn iter(&self) -> std::slice::Iter<'_, T> {
67        self.storage().data().iter()
68    }
69
70    /// Fill every element with a constant value (triggers CoW if
71    /// shared). Forwards to the storage half.
72    pub fn fill(&mut self, value: T) {
73        self.storage_mut().fill(value);
74    }
75
76    /// Scale every element by a scalar factor in place (triggers CoW
77    /// if shared).
78    pub fn scale<S>(&mut self, factor: S)
79    where
80        T: std::ops::Mul<S, Output = T>,
81        S: Clone,
82    {
83        self.storage_mut().scale(factor);
84    }
85
86    /// Apply a function to each element in place (triggers CoW if
87    /// shared).
88    pub fn map_mut<F>(&mut self, f: F)
89    where
90        F: Fn(&T) -> T,
91    {
92        self.storage_mut().map_mut(f);
93    }
94}