ariadnetor_tensor/tensor/dense_ops.rs
1//! Dense-specific inherent methods on `Tensor<DenseStorage<S>, DenseLayout>`.
2//!
//! Covers element access, in-place fills / scales (plus the `*` / `*=`
//! scalar-multiplication operators), Frobenius-norm-based normalization,
//! conjugation, zero-copy reshape, logical reshape / leg fusion / leg
//! splitting, and reorder. These operations are storage-local: they do not
//! need a backend for dispatch.
6
7use std::ops::{Mul, MulAssign};
8
9use ariadnetor_core::Scalar;
10use num_traits::{One, Zero};
11
12use super::Tensor;
13use crate::{DenseLayout, DenseStorage, DenseTensorData, TensorData};
14
15// ============================================================================
16// Dense-specific data access (all backends)
17// ============================================================================
18
19impl<S> Tensor<DenseStorage<S>, DenseLayout> {
20 /// Get a reference to the underlying contiguous data buffer.
21 pub fn data_slice(&self) -> &[S] {
22 self.data.storage().data()
23 }
24
25 /// Get a mutable reference to the underlying data buffer
26 /// (CoW-aware).
27 pub fn data_slice_mut(&mut self) -> &mut [S]
28 where
29 S: Clone,
30 {
31 self.data.storage_mut().data_mut()
32 }
33
34 /// Reshape to `new_shape` (zero-copy). Preserves the layout's memory
35 /// order. The flat data buffer is `Arc`-shared via
36 /// `DenseStorage::Clone`, so the result aliases the same allocation
37 /// as `self`.
38 ///
39 /// Under non-adjacent axis fusion the logical mapping differs
40 /// between row-major and column-major; callers fusing such axes
41 /// must reorder the flat buffer to the appropriate order first.
42 ///
43 /// # Panics
44 ///
45 /// Panics if `new_shape.iter().product() != self.len()`, via
46 /// [`TensorData::new`]'s `storage.flat_len() == layout.storage_extent()`
47 /// assert.
48 pub fn reshape(&self, new_shape: Vec<usize>) -> Self {
49 let new_layout = DenseLayout::new(new_shape, self.data.layout().order());
50 let new_storage = self.data.storage().clone();
51 Self::from_data(TensorData::new(new_storage, new_layout))
52 }
53}
54
55impl<S: Scalar> Tensor<DenseStorage<S>, DenseLayout> {
56 /// Memory order this tensor's flat data is laid out in.
57 pub fn order(&self) -> ariadnetor_core::backend::MemoryOrder {
58 self.data.layout().order()
59 }
60
61 /// Get element at the given indices.
62 ///
63 /// `indices` accepts any `AsRef<[usize]>`, so an array literal can be
64 /// passed without a borrow: `t.get([i, j])` as well as `t.get(&coords)`.
65 ///
66 /// # Panics
67 ///
68 /// Panics if `indices.len() != rank` or any index exceeds the
69 /// corresponding axis dimension.
70 pub fn get(&self, indices: impl AsRef<[usize]>) -> S {
71 let indices = indices.as_ref();
72 let shape = self.shape();
73 assert_eq!(
74 indices.len(),
75 shape.len(),
76 "Tensor::get: indices length {} doesn't match rank {}",
77 indices.len(),
78 shape.len(),
79 );
80 for (axis, (&idx, &dim)) in indices.iter().zip(shape).enumerate() {
81 assert!(
82 idx < dim,
83 "Tensor::get: index {idx} out of bounds for axis {axis} with size {dim}",
84 );
85 }
86 let order = self.order();
87 let flat = crate::flat_index(indices, shape, order);
88 self.data.storage().data()[flat]
89 }
90
91 /// Set element at the given indices.
92 ///
93 /// `indices` accepts any `AsRef<[usize]>`, so an array literal can be
94 /// passed without a borrow: `t.set([i, j], v)` as well as `t.set(&coords, v)`.
95 ///
96 /// # Panics
97 ///
98 /// Panics if `indices.len() != rank` or any index exceeds the
99 /// corresponding axis dimension.
100 pub fn set(&mut self, indices: impl AsRef<[usize]>, value: S) {
101 let indices = indices.as_ref();
102 // Resolve the flat offset under an immutable borrow that ends before
103 // the mutable storage borrow below, so no owned-shape copy is needed.
104 let flat = {
105 let shape = self.shape();
106 assert_eq!(
107 indices.len(),
108 shape.len(),
109 "Tensor::set: indices length {} doesn't match rank {}",
110 indices.len(),
111 shape.len(),
112 );
113 for (axis, (&idx, &dim)) in indices.iter().zip(shape).enumerate() {
114 assert!(
115 idx < dim,
116 "Tensor::set: index {idx} out of bounds for axis {axis} with size {dim}",
117 );
118 }
119 crate::flat_index(indices, shape, self.order())
120 };
121 self.data.storage_mut().data_mut()[flat] = value;
122 }
123
124 /// Fill the tensor with a constant value.
125 pub fn fill(&mut self, value: S) {
126 for slot in self.data.storage_mut().data_mut().iter_mut() {
127 *slot = value;
128 }
129 }
130}
131
132// ============================================================================
133// Dense-specific arithmetic operations (all backends)
134// ============================================================================
135
136impl<S: Clone> Tensor<DenseStorage<S>, DenseLayout> {
137 /// Scale every element by a factor (in-place).
138 pub fn scale<F>(&mut self, factor: F)
139 where
140 S: Mul<F, Output = S>,
141 F: Clone,
142 {
143 for slot in self.data.storage_mut().data_mut().iter_mut() {
144 *slot = slot.clone() * factor.clone();
145 }
146 }
147
148 /// Scale every element by a factor (out-of-place).
149 pub fn scaled<F>(&self, factor: F) -> Self
150 where
151 S: Mul<F, Output = S>,
152 F: Clone,
153 {
154 let new_data: Vec<S> = self
155 .data
156 .storage()
157 .data()
158 .iter()
159 .map(|x| x.clone() * factor.clone())
160 .collect();
161 let shape = self.shape().to_vec();
162 let order = self.data.layout().order();
163 let td = DenseTensorData::from_raw_parts(new_data, shape, order);
164 Self::from_data(td)
165 }
166}
167
168// ============================================================================
169// Scalar-multiplication operators on the joined DenseTensor surface
170// ============================================================================
171//
172// Convenience aliases for `scale` / `scaled`, restricted to a same-type
173// factor (`S` matches the element type). Cross-type factors keep using
174// the named methods, since a single `Mul` impl cannot cover them without
175// conflicting coherence.
176
177impl<S> Mul<S> for Tensor<DenseStorage<S>, DenseLayout>
178where
179 S: Clone + Mul<Output = S>,
180{
181 type Output = Tensor<DenseStorage<S>, DenseLayout>;
182
183 /// Scale by `rhs`, consuming `self`. Reuses the owned buffer in
184 /// place (no extra allocation when the storage is uniquely owned;
185 /// a buffer still shared via copy-on-write is cloned first).
186 fn mul(mut self, rhs: S) -> Self::Output {
187 self.scale(rhs);
188 self
189 }
190}
191
192impl<S> Mul<S> for &Tensor<DenseStorage<S>, DenseLayout>
193where
194 S: Clone + Mul<Output = S>,
195{
196 type Output = Tensor<DenseStorage<S>, DenseLayout>;
197
198 /// Scale by `rhs`, leaving `self` untouched (out-of-place).
199 fn mul(self, rhs: S) -> Self::Output {
200 self.scaled(rhs)
201 }
202}
203
204impl<S> MulAssign<S> for Tensor<DenseStorage<S>, DenseLayout>
205where
206 S: Clone + Mul<Output = S>,
207{
208 /// Scale every element by `rhs` in place.
209 fn mul_assign(&mut self, rhs: S) {
210 self.scale(rhs);
211 }
212}
213
214// ============================================================================
215// Dense-specific norm / normalization (all backends)
216// ============================================================================
217
218impl<S> Tensor<DenseStorage<S>, DenseLayout>
219where
220 S: Scalar,
221{
222 /// Frobenius norm.
223 pub fn norm(&self) -> S::Real {
224 let mut sq = S::Real::zero();
225 for &x in self.data.storage().data() {
226 let a = x.abs();
227 sq = sq + a * a;
228 }
229 <S::Real as num_traits::Float>::sqrt(sq)
230 }
231
232 /// Normalize to unit norm (in-place). Returns the original norm.
233 ///
234 /// # Panics
235 ///
236 /// Panics if the tensor has zero norm.
237 pub fn normalize(&mut self) -> S::Real {
238 let norm = self.norm();
239 assert!(norm != S::Real::zero(), "Cannot normalize zero tensor");
240 let inv_norm = S::Real::one() / norm;
241 for slot in self.data.storage_mut().data_mut().iter_mut() {
242 *slot = slot.scale_real(inv_norm);
243 }
244 norm
245 }
246
247 /// Normalize and return a new tensor (out-of-place).
248 pub fn normalized(&self) -> (Self, S::Real) {
249 let mut clone = self.clone();
250 let n = clone.normalize();
251 (clone, n)
252 }
253
254 /// Element-wise complex conjugate. Symmetric with
255 /// [`BlockSparseTensor::conj`].
256 pub fn conj(&self) -> Self {
257 Self {
258 data: self.data.conj(),
259 }
260 }
261
262 /// Return a tensor with flat data reordered to `to`. When
263 /// `self.data().order() == to`, the underlying buffer is shared via
264 /// `Arc` rather than copied.
265 ///
266 /// This is a **workspace-internal escape hatch**, not a user entry
267 /// point. The public `Tensor` surface hides memory layout: constructors
268 /// take no order, and the linalg / algorithm layers normalize to the
269 /// backend's preferred order internally. The only in-tree callers are
270 /// that internal plumbing (and the order-mismatch rejection tests).
271 /// End users should never need to choose a `MemoryOrder`; as an inherent
272 /// method on a re-exported type it cannot be hidden from umbrella users,
273 /// hence this note.
274 pub fn reordered(&self, to: ariadnetor_core::backend::MemoryOrder) -> Self {
275 let reordered = crate::reorder::reorder_data(&self.data, to);
276 Self { data: reordered }
277 }
278
279 /// General logical (C-order) reshape to an arbitrary target shape,
280 /// preserving the tensor's memory order. The buffer is routed
281 /// through row-major so the logical axis grouping is independent of
282 /// the physical layout, then restored to the original order; for a
283 /// row-major tensor each step is a zero-copy `Arc` share, for a
284 /// column-major tensor it costs one round-trip transpose.
285 ///
286 /// This is the low-level escape hatch for multi-leg regroupings that
287 /// [`fuse_legs`] / [`split_leg`] cannot express in a single
288 /// operation — e.g. fusing two disjoint leg groups at once. Prefer
289 /// [`fuse_legs`] / [`split_leg`] for single-leg fuse / split: they
290 /// constrain which axis changes and read as intent. Like
291 /// [`reshape`], only the total element count is validated.
292 ///
293 /// [`fuse_legs`]: Self::fuse_legs
294 /// [`split_leg`]: Self::split_leg
295 /// [`reshape`]: Self::reshape
296 ///
297 /// # Panics
298 ///
299 /// Panics if `new_shape`'s total element count differs from the
300 /// tensor's, via [`reshape`].
301 pub fn reshape_logical(&self, new_shape: Vec<usize>) -> Self {
302 let orig_order = self.order();
303 self.reordered(ariadnetor_core::backend::MemoryOrder::RowMajor)
304 .reshape(new_shape)
305 .reordered(orig_order)
306 }
307
308 /// Fuse a contiguous range of axes into a single leg, grouping
309 /// them in row-major (C-order) logical order regardless of the
310 /// tensor's physical memory order.
311 ///
312 /// The fused leg's extent is the product of the fused axes'
313 /// extents and its logical index runs fastest over the last fused
314 /// axis. The result keeps `self`'s memory order. Use [`reshape`]
315 /// instead when a raw, order-preserving buffer reinterpretation is
316 /// wanted. Inverse of [`split_leg`] over the same range. Convenience
317 /// over [`reshape_logical`] for the single-leg case; for multi-group
318 /// regroupings call [`reshape_logical`] directly.
319 ///
320 /// [`reshape`]: Self::reshape
321 /// [`split_leg`]: Self::split_leg
322 /// [`reshape_logical`]: Self::reshape_logical
323 ///
324 /// # Panics
325 ///
326 /// Panics unless `range.start < range.end <= rank`.
327 pub fn fuse_legs(&self, range: std::ops::Range<usize>) -> Self {
328 let shape = self.shape();
329 let rank = shape.len();
330 assert!(
331 range.start < range.end && range.end <= rank,
332 "fuse_legs: range {range:?} out of bounds for rank {rank}",
333 );
334 let fused: usize = shape[range.clone()].iter().product();
335 let mut new_shape = shape[..range.start].to_vec();
336 new_shape.push(fused);
337 new_shape.extend_from_slice(&shape[range.end..]);
338 self.reshape_logical(new_shape)
339 }
340
341 /// Split one axis into multiple axes, distributing the extent in
342 /// row-major (C-order) logical order regardless of the tensor's
343 /// physical memory order.
344 ///
345 /// `into` lists the resulting extents from slowest- to
346 /// fastest-varying. The result keeps `self`'s memory order.
347 /// Inverse of [`fuse_legs`] for a contiguous range. Convenience over
348 /// [`reshape_logical`] for the single-leg case; for multi-group
349 /// regroupings call [`reshape_logical`] directly.
350 ///
351 /// [`fuse_legs`]: Self::fuse_legs
352 /// [`reshape_logical`]: Self::reshape_logical
353 ///
354 /// # Panics
355 ///
356 /// Panics unless `axis < rank`, `into` is non-empty, and
357 /// `into.iter().product() == shape[axis]`.
358 pub fn split_leg(&self, axis: usize, into: &[usize]) -> Self {
359 let shape = self.shape();
360 let rank = shape.len();
361 assert!(
362 axis < rank,
363 "split_leg: axis {axis} out of bounds for rank {rank}",
364 );
365 assert!(!into.is_empty(), "split_leg: `into` must be non-empty");
366 let prod: usize = into.iter().product();
367 assert_eq!(
368 prod, shape[axis],
369 "split_leg: product of {into:?} != axis {axis} extent {}",
370 shape[axis],
371 );
372 let mut new_shape = shape[..axis].to_vec();
373 new_shape.extend_from_slice(into);
374 new_shape.extend_from_slice(&shape[axis + 1..]);
375 self.reshape_logical(new_shape)
376 }
377}