ariadnetor_tensor/dense/access.rs
1//! Coordinate-based element access for `DenseTensorData<T>`.
2//!
3//! `get` / `set` resolve flat indices through the layout's memory
4//! order, so a `RowMajor`-tagged and a `ColumnMajor`-tagged tensor
5//! holding the same logical matrix return the same value at the same
6//! `[i, j, ...]`.
7
8use crate::DenseTensorData;
9use crate::reorder::flat_index;
10
11impl<T> DenseTensorData<T>
12where
13 T: Clone,
14{
15 /// Get element at multi-dimensional indices.
16 ///
17 /// The flat index is computed using `self.order()`.
18 ///
19 /// # Panics
20 ///
21 /// Panics if indices are out of bounds.
22 pub fn get(&self, indices: &[usize]) -> T {
23 let shape = self.shape();
24 assert_eq!(indices.len(), shape.len());
25 for (axis, (&idx, &dim)) in indices.iter().zip(shape).enumerate() {
26 assert!(
27 idx < dim,
28 "index {idx} out of bounds for axis {axis} with size {dim}"
29 );
30 }
31 let order = self.order();
32 let idx = flat_index(indices, shape, order);
33 self.storage().data()[idx].clone()
34 }
35
36 /// Set element at multi-dimensional indices (triggers CoW on the
37 /// storage half if shared).
38 ///
39 /// The flat index is computed using `self.order()`.
40 ///
41 /// # Panics
42 ///
43 /// Panics if indices are out of bounds.
44 pub fn set(&mut self, indices: &[usize], value: T) {
45 let idx = {
46 let shape = self.shape();
47 assert_eq!(indices.len(), shape.len());
48 for (axis, (&i, &dim)) in indices.iter().zip(shape).enumerate() {
49 assert!(
50 i < dim,
51 "index {i} out of bounds for axis {axis} with size {dim}"
52 );
53 }
54 flat_index(indices, shape, self.order())
55 };
56 self.storage_mut().data_mut()[idx] = value;
57 }
58
59 /// Mutable reference to the underlying contiguous data buffer
60 /// (triggers CoW on the storage half if shared).
61 pub fn data_mut(&mut self) -> &mut [T] {
62 self.storage_mut().data_mut()
63 }
64
65 /// Iterate over the flat storage in flat (storage) order.
66 pub fn iter(&self) -> std::slice::Iter<'_, T> {
67 self.storage().data().iter()
68 }
69
70 /// Fill every element with a constant value (triggers CoW if
71 /// shared). Forwards to the storage half.
72 pub fn fill(&mut self, value: T) {
73 self.storage_mut().fill(value);
74 }
75
76 /// Scale every element by a scalar factor in place (triggers CoW
77 /// if shared).
78 pub fn scale<S>(&mut self, factor: S)
79 where
80 T: std::ops::Mul<S, Output = T>,
81 S: Clone,
82 {
83 self.storage_mut().scale(factor);
84 }
85
86 /// Apply a function to each element in place (triggers CoW if
87 /// shared).
88 pub fn map_mut<F>(&mut self, f: F)
89 where
90 F: Fn(&T) -> T,
91 {
92 self.storage_mut().map_mut(f);
93 }
94}