ariadnetor_tensor/dense/
operations.rs1use num_traits::Zero;
4use std::ops::{Add, Mul, MulAssign};
5
6use crate::{DenseLayout, DenseTensorData, TensorData, TensorError};
7use ariadnetor_core::MemoryOrder;
8
9impl<T> DenseTensorData<T>
10where
11 T: Clone,
12{
13 pub fn reshape(&self, new_shape: Vec<usize>) -> Self {
28 let new_total: usize = new_shape.iter().product();
29 assert_eq!(
30 self.len(),
31 new_total,
32 "reshape: total elements must match ({} vs {new_total})",
33 self.len()
34 );
35 let storage = self.storage().clone();
36 let layout = DenseLayout::new(new_shape, self.order());
37 TensorData::new(storage, layout)
38 }
39
40 pub fn map<U, F>(&self, f: F) -> DenseTensorData<U>
45 where
46 F: Fn(&T) -> U,
47 U: Clone + 'static,
48 {
49 let result: Vec<U> = self.storage().data().iter().map(f).collect();
50 DenseTensorData::<U>::from_raw_parts(result, self.shape().to_vec(), self.order())
51 }
52
53 pub fn map_with_index<U, F>(&self, f: F) -> DenseTensorData<U>
60 where
61 F: Fn(&[usize], &T) -> U,
62 U: Clone + 'static,
63 {
64 let order = self.order();
65 let shape = self.shape();
66 let rank = shape.len();
67 let total = self.len();
68 let raw = self.storage().data();
69 let mut coords = vec![0usize; rank];
70 let mut result = Vec::with_capacity(total);
71
72 let axis_order: Vec<usize> = match order {
73 MemoryOrder::RowMajor => (0..rank).collect(),
74 MemoryOrder::ColumnMajor => (0..rank).rev().collect(),
75 };
76
77 for val in raw.iter().take(total) {
78 result.push(f(&coords, val));
79 for &d in axis_order.iter().rev() {
80 coords[d] += 1;
81 if coords[d] < shape[d] {
82 break;
83 }
84 coords[d] = 0;
85 }
86 }
87
88 DenseTensorData::<U>::from_raw_parts(result, shape.to_vec(), order)
89 }
90
91 pub fn scaled<S>(&self, factor: S) -> Self
93 where
94 T: Mul<S, Output = T>,
95 S: Clone,
96 {
97 let mut result = self.clone();
98 result.storage_mut().scale(factor);
99 result
100 }
101}
102
103impl<T> Mul<T> for DenseTensorData<T>
113where
114 T: Clone + Mul<Output = T>,
115{
116 type Output = DenseTensorData<T>;
117
118 fn mul(mut self, rhs: T) -> Self::Output {
122 self.scale(rhs);
123 self
124 }
125}
126
127impl<T> Mul<T> for &DenseTensorData<T>
128where
129 T: Clone + Mul<Output = T>,
130{
131 type Output = DenseTensorData<T>;
132
133 fn mul(self, rhs: T) -> Self::Output {
135 self.scaled(rhs)
136 }
137}
138
139impl<T> MulAssign<T> for DenseTensorData<T>
140where
141 T: Clone + Mul<Output = T>,
142{
143 fn mul_assign(&mut self, rhs: T) {
145 self.scale(rhs);
146 }
147}
148
149impl<T> DenseTensorData<T>
154where
155 T: Clone,
156{
157 pub fn add_all(tensors: &[&DenseTensorData<T>]) -> Result<DenseTensorData<T>, TensorError>
159 where
160 T: Zero + num_traits::One + Add<Output = T> + Mul<Output = T>,
161 {
162 let coefs = vec![T::one(); tensors.len()];
163 Self::linear_combine(tensors, &coefs)
164 }
165
166 pub fn linear_combine(
177 tensors: &[&DenseTensorData<T>],
178 coefs: &[T],
179 ) -> Result<DenseTensorData<T>, TensorError>
180 where
181 T: Zero + Add<Output = T> + Mul<Output = T>,
182 {
183 if tensors.is_empty() {
184 return Err(TensorError::InvalidArgument(
185 "Cannot combine empty tensor list".to_string(),
186 ));
187 }
188 if tensors.len() != coefs.len() {
189 return Err(TensorError::InvalidArgument(format!(
190 "Mismatched lengths: {} tensors vs {} coefficients",
191 tensors.len(),
192 coefs.len()
193 )));
194 }
195 let shape = tensors[0].shape();
196 let order = tensors[0].order();
197 for t in &tensors[1..] {
198 if t.shape() != shape {
199 return Err(TensorError::InvalidArgument(
200 "All tensors must have the same shape".to_string(),
201 ));
202 }
203 if t.order() != order {
204 return Err(TensorError::InvalidArgument(format!(
205 "All tensors must have the same memory order; got {:?} and {:?}",
206 order,
207 t.order()
208 )));
209 }
210 }
211 let len = tensors[0].len();
212 let mut result = vec![T::zero(); len];
213 for (tensor, coef) in tensors.iter().zip(coefs) {
214 for (r, val) in result.iter_mut().zip(tensor.storage().data()) {
215 *r = r.clone() + coef.clone() * val.clone();
216 }
217 }
218 Ok(DenseTensorData::from_raw_parts(
219 result,
220 shape.to_vec(),
221 order,
222 ))
223 }
224}