Skip to main content

ariadnetor_tensor/dense/
operations.rs

1//! Reshape, element-wise, and arithmetic operations for `DenseTensorData<T>`.
2
3use num_traits::Zero;
4use std::ops::{Add, Mul, MulAssign};
5
6use crate::{DenseLayout, DenseTensorData, TensorData, TensorError};
7use ariadnetor_core::MemoryOrder;
8
9impl<T> DenseTensorData<T>
10where
11    T: Clone,
12{
13    /// Reshape the tensor to a new shape (zero-copy: shares the
14    /// underlying storage Arc).
15    ///
16    /// The flat data is not rearranged — only the layout's shape
17    /// changes. The output preserves `self.order()`. Reshape semantics
18    /// depend on the order: adjacent-axis fusion is zero-copy under
19    /// both row-major and column-major for contiguous tensors, but
20    /// non-adjacent fusion produces a different logical mapping under
21    /// each order.
22    ///
23    /// # Panics
24    ///
25    /// Panics if the new shape has a different total number of
26    /// elements.
27    pub fn reshape(&self, new_shape: Vec<usize>) -> Self {
28        let new_total: usize = new_shape.iter().product();
29        assert_eq!(
30            self.len(),
31            new_total,
32            "reshape: total elements must match ({} vs {new_total})",
33            self.len()
34        );
35        let storage = self.storage().clone();
36        let layout = DenseLayout::new(new_shape, self.order());
37        TensorData::new(storage, layout)
38    }
39
40    /// Apply a function to each element.
41    ///
42    /// Iterates flat data directly. The result preserves
43    /// `self.order()`.
44    pub fn map<U, F>(&self, f: F) -> DenseTensorData<U>
45    where
46        F: Fn(&T) -> U,
47        U: Clone + 'static,
48    {
49        let result: Vec<U> = self.storage().data().iter().map(f).collect();
50        DenseTensorData::<U>::from_raw_parts(result, self.shape().to_vec(), self.order())
51    }
52
53    /// Apply a function with multi-dimensional coordinates to each
54    /// element.
55    ///
56    /// Iterates coordinates in `self.order()` while reading storage
57    /// linearly, so the coordinate-to-value mapping always matches
58    /// the storage's layout. The output preserves `self.order()`.
59    pub fn map_with_index<U, F>(&self, f: F) -> DenseTensorData<U>
60    where
61        F: Fn(&[usize], &T) -> U,
62        U: Clone + 'static,
63    {
64        let order = self.order();
65        let shape = self.shape();
66        let rank = shape.len();
67        let total = self.len();
68        let raw = self.storage().data();
69        let mut coords = vec![0usize; rank];
70        let mut result = Vec::with_capacity(total);
71
72        let axis_order: Vec<usize> = match order {
73            MemoryOrder::RowMajor => (0..rank).collect(),
74            MemoryOrder::ColumnMajor => (0..rank).rev().collect(),
75        };
76
77        for val in raw.iter().take(total) {
78            result.push(f(&coords, val));
79            for &d in axis_order.iter().rev() {
80                coords[d] += 1;
81                if coords[d] < shape[d] {
82                    break;
83                }
84                coords[d] = 0;
85            }
86        }
87
88        DenseTensorData::<U>::from_raw_parts(result, shape.to_vec(), order)
89    }
90
91    /// Scale all elements and return a new tensor (out-of-place).
92    pub fn scaled<S>(&self, factor: S) -> Self
93    where
94        T: Mul<S, Output = T>,
95        S: Clone,
96    {
97        let mut result = self.clone();
98        result.storage_mut().scale(factor);
99        result
100    }
101}
102
103// ============================================================================
104// Scalar-multiplication operators on DenseTensorData<T>
105// ============================================================================
106//
107// Convenience aliases for `scale` / `scaled`, restricted to a same-type
108// factor (`S = T`). Cross-type factors (e.g. scaling a complex tensor by
109// a real) cannot be expressed through a single `Mul` impl without
110// conflicting coherence, so those callers keep using the named methods.
111
112impl<T> Mul<T> for DenseTensorData<T>
113where
114    T: Clone + Mul<Output = T>,
115{
116    type Output = DenseTensorData<T>;
117
118    /// Scale by `rhs`, consuming `self`. Reuses the owned buffer in
119    /// place (no extra allocation when the storage is uniquely owned;
120    /// a buffer still shared via copy-on-write is cloned first).
121    fn mul(mut self, rhs: T) -> Self::Output {
122        self.scale(rhs);
123        self
124    }
125}
126
127impl<T> Mul<T> for &DenseTensorData<T>
128where
129    T: Clone + Mul<Output = T>,
130{
131    type Output = DenseTensorData<T>;
132
133    /// Scale by `rhs`, leaving `self` untouched (out-of-place).
134    fn mul(self, rhs: T) -> Self::Output {
135        self.scaled(rhs)
136    }
137}
138
139impl<T> MulAssign<T> for DenseTensorData<T>
140where
141    T: Clone + Mul<Output = T>,
142{
143    /// Scale every element by `rhs` in place.
144    fn mul_assign(&mut self, rhs: T) {
145        self.scale(rhs);
146    }
147}
148
149// ============================================================================
150// Multi-tensor arithmetic on DenseTensorData<T>
151// ============================================================================
152
153impl<T> DenseTensorData<T>
154where
155    T: Clone,
156{
157    /// Add all tensors (coefficients all = 1).
158    pub fn add_all(tensors: &[&DenseTensorData<T>]) -> Result<DenseTensorData<T>, TensorError>
159    where
160        T: Zero + num_traits::One + Add<Output = T> + Mul<Output = T>,
161    {
162        let coefs = vec![T::one(); tensors.len()];
163        Self::linear_combine(tensors, &coefs)
164    }
165
166    /// Linear combination: Σ coefs\[i\] * tensors\[i\].
167    ///
168    /// All input tensors must share the same `order()`; the result
169    /// preserves that order.
170    ///
171    /// # Errors
172    ///
173    /// Returns an error if tensors have different shapes, different
174    /// orders, the list is empty, or tensors and coefficients have
175    /// different lengths.
176    pub fn linear_combine(
177        tensors: &[&DenseTensorData<T>],
178        coefs: &[T],
179    ) -> Result<DenseTensorData<T>, TensorError>
180    where
181        T: Zero + Add<Output = T> + Mul<Output = T>,
182    {
183        if tensors.is_empty() {
184            return Err(TensorError::InvalidArgument(
185                "Cannot combine empty tensor list".to_string(),
186            ));
187        }
188        if tensors.len() != coefs.len() {
189            return Err(TensorError::InvalidArgument(format!(
190                "Mismatched lengths: {} tensors vs {} coefficients",
191                tensors.len(),
192                coefs.len()
193            )));
194        }
195        let shape = tensors[0].shape();
196        let order = tensors[0].order();
197        for t in &tensors[1..] {
198            if t.shape() != shape {
199                return Err(TensorError::InvalidArgument(
200                    "All tensors must have the same shape".to_string(),
201                ));
202            }
203            if t.order() != order {
204                return Err(TensorError::InvalidArgument(format!(
205                    "All tensors must have the same memory order; got {:?} and {:?}",
206                    order,
207                    t.order()
208                )));
209            }
210        }
211        let len = tensors[0].len();
212        let mut result = vec![T::zero(); len];
213        for (tensor, coef) in tensors.iter().zip(coefs) {
214            for (r, val) in result.iter_mut().zip(tensor.storage().data()) {
215                *r = r.clone() + coef.clone() * val.clone();
216            }
217        }
218        Ok(DenseTensorData::from_raw_parts(
219            result,
220            shape.to_vec(),
221            order,
222        ))
223    }
224}