Skip to main content

ariadnetor_tensor/tensor/
dense_ops.rs

1//! Dense-specific inherent methods on `Tensor<DenseStorage<S>, DenseLayout>`.
2//!
3//! Covers element access, in-place fills / scales, Frobenius-norm-based
4//! normalization, conjugation, zero-copy reshape, and reorder. These
5//! operations are storage-local: they do not need a backend for dispatch.
6
7use std::ops::{Mul, MulAssign};
8
9use ariadnetor_core::Scalar;
10use num_traits::{One, Zero};
11
12use super::Tensor;
13use crate::{DenseLayout, DenseStorage, DenseTensorData, TensorData};
14
15// ============================================================================
16// Dense-specific data access (all backends)
17// ============================================================================
18
19impl<S> Tensor<DenseStorage<S>, DenseLayout> {
20    /// Get a reference to the underlying contiguous data buffer.
21    pub fn data_slice(&self) -> &[S] {
22        self.data.storage().data()
23    }
24
25    /// Get a mutable reference to the underlying data buffer
26    /// (CoW-aware).
27    pub fn data_slice_mut(&mut self) -> &mut [S]
28    where
29        S: Clone,
30    {
31        self.data.storage_mut().data_mut()
32    }
33
34    /// Reshape to `new_shape` (zero-copy). Preserves the layout's memory
35    /// order. The flat data buffer is `Arc`-shared via
36    /// `DenseStorage::Clone`, so the result aliases the same allocation
37    /// as `self`.
38    ///
39    /// Under non-adjacent axis fusion the logical mapping differs
40    /// between row-major and column-major; callers fusing such axes
41    /// must reorder the flat buffer to the appropriate order first.
42    ///
43    /// # Panics
44    ///
45    /// Panics if `new_shape.iter().product() != self.len()`, via
46    /// [`TensorData::new`]'s `storage.flat_len() == layout.storage_extent()`
47    /// assert.
48    pub fn reshape(&self, new_shape: Vec<usize>) -> Self {
49        let new_layout = DenseLayout::new(new_shape, self.data.layout().order());
50        let new_storage = self.data.storage().clone();
51        Self::from_data(TensorData::new(new_storage, new_layout))
52    }
53}
54
55impl<S: Scalar> Tensor<DenseStorage<S>, DenseLayout> {
56    /// Memory order this tensor's flat data is laid out in.
57    pub fn order(&self) -> ariadnetor_core::backend::MemoryOrder {
58        self.data.layout().order()
59    }
60
61    /// Get element at the given indices.
62    ///
63    /// `indices` accepts any `AsRef<[usize]>`, so an array literal can be
64    /// passed without a borrow: `t.get([i, j])` as well as `t.get(&coords)`.
65    ///
66    /// # Panics
67    ///
68    /// Panics if `indices.len() != rank` or any index exceeds the
69    /// corresponding axis dimension.
70    pub fn get(&self, indices: impl AsRef<[usize]>) -> S {
71        let indices = indices.as_ref();
72        let shape = self.shape();
73        assert_eq!(
74            indices.len(),
75            shape.len(),
76            "Tensor::get: indices length {} doesn't match rank {}",
77            indices.len(),
78            shape.len(),
79        );
80        for (axis, (&idx, &dim)) in indices.iter().zip(shape).enumerate() {
81            assert!(
82                idx < dim,
83                "Tensor::get: index {idx} out of bounds for axis {axis} with size {dim}",
84            );
85        }
86        let order = self.order();
87        let flat = crate::flat_index(indices, shape, order);
88        self.data.storage().data()[flat]
89    }
90
91    /// Set element at the given indices.
92    ///
93    /// `indices` accepts any `AsRef<[usize]>`, so an array literal can be
94    /// passed without a borrow: `t.set([i, j], v)` as well as `t.set(&coords, v)`.
95    ///
96    /// # Panics
97    ///
98    /// Panics if `indices.len() != rank` or any index exceeds the
99    /// corresponding axis dimension.
100    pub fn set(&mut self, indices: impl AsRef<[usize]>, value: S) {
101        let indices = indices.as_ref();
102        // Resolve the flat offset under an immutable borrow that ends before
103        // the mutable storage borrow below, so no owned-shape copy is needed.
104        let flat = {
105            let shape = self.shape();
106            assert_eq!(
107                indices.len(),
108                shape.len(),
109                "Tensor::set: indices length {} doesn't match rank {}",
110                indices.len(),
111                shape.len(),
112            );
113            for (axis, (&idx, &dim)) in indices.iter().zip(shape).enumerate() {
114                assert!(
115                    idx < dim,
116                    "Tensor::set: index {idx} out of bounds for axis {axis} with size {dim}",
117                );
118            }
119            crate::flat_index(indices, shape, self.order())
120        };
121        self.data.storage_mut().data_mut()[flat] = value;
122    }
123
124    /// Fill the tensor with a constant value.
125    pub fn fill(&mut self, value: S) {
126        for slot in self.data.storage_mut().data_mut().iter_mut() {
127            *slot = value;
128        }
129    }
130}
131
132// ============================================================================
133// Dense-specific arithmetic operations (all backends)
134// ============================================================================
135
136impl<S: Clone> Tensor<DenseStorage<S>, DenseLayout> {
137    /// Scale every element by a factor (in-place).
138    pub fn scale<F>(&mut self, factor: F)
139    where
140        S: Mul<F, Output = S>,
141        F: Clone,
142    {
143        for slot in self.data.storage_mut().data_mut().iter_mut() {
144            *slot = slot.clone() * factor.clone();
145        }
146    }
147
148    /// Scale every element by a factor (out-of-place).
149    pub fn scaled<F>(&self, factor: F) -> Self
150    where
151        S: Mul<F, Output = S>,
152        F: Clone,
153    {
154        let new_data: Vec<S> = self
155            .data
156            .storage()
157            .data()
158            .iter()
159            .map(|x| x.clone() * factor.clone())
160            .collect();
161        let shape = self.shape().to_vec();
162        let order = self.data.layout().order();
163        let td = DenseTensorData::from_raw_parts(new_data, shape, order);
164        Self::from_data(td)
165    }
166}
167
168// ============================================================================
169// Scalar-multiplication operators on the joined DenseTensor surface
170// ============================================================================
171//
172// Convenience aliases for `scale` / `scaled`, restricted to a same-type
173// factor (`S` matches the element type). Cross-type factors keep using
174// the named methods, since a single `Mul` impl cannot cover them without
175// conflicting coherence.
176
177impl<S> Mul<S> for Tensor<DenseStorage<S>, DenseLayout>
178where
179    S: Clone + Mul<Output = S>,
180{
181    type Output = Tensor<DenseStorage<S>, DenseLayout>;
182
183    /// Scale by `rhs`, consuming `self`. Reuses the owned buffer in
184    /// place (no extra allocation when the storage is uniquely owned;
185    /// a buffer still shared via copy-on-write is cloned first).
186    fn mul(mut self, rhs: S) -> Self::Output {
187        self.scale(rhs);
188        self
189    }
190}
191
192impl<S> Mul<S> for &Tensor<DenseStorage<S>, DenseLayout>
193where
194    S: Clone + Mul<Output = S>,
195{
196    type Output = Tensor<DenseStorage<S>, DenseLayout>;
197
198    /// Scale by `rhs`, leaving `self` untouched (out-of-place).
199    fn mul(self, rhs: S) -> Self::Output {
200        self.scaled(rhs)
201    }
202}
203
204impl<S> MulAssign<S> for Tensor<DenseStorage<S>, DenseLayout>
205where
206    S: Clone + Mul<Output = S>,
207{
208    /// Scale every element by `rhs` in place.
209    fn mul_assign(&mut self, rhs: S) {
210        self.scale(rhs);
211    }
212}
213
214// ============================================================================
215// Dense-specific norm / normalization (all backends)
216// ============================================================================
217
218impl<S> Tensor<DenseStorage<S>, DenseLayout>
219where
220    S: Scalar,
221{
222    /// Frobenius norm.
223    pub fn norm(&self) -> S::Real {
224        let mut sq = S::Real::zero();
225        for &x in self.data.storage().data() {
226            let a = x.abs();
227            sq = sq + a * a;
228        }
229        <S::Real as num_traits::Float>::sqrt(sq)
230    }
231
232    /// Normalize to unit norm (in-place). Returns the original norm.
233    ///
234    /// # Panics
235    ///
236    /// Panics if the tensor has zero norm.
237    pub fn normalize(&mut self) -> S::Real {
238        let norm = self.norm();
239        assert!(norm != S::Real::zero(), "Cannot normalize zero tensor");
240        let inv_norm = S::Real::one() / norm;
241        for slot in self.data.storage_mut().data_mut().iter_mut() {
242            *slot = slot.scale_real(inv_norm);
243        }
244        norm
245    }
246
247    /// Normalize and return a new tensor (out-of-place).
248    pub fn normalized(&self) -> (Self, S::Real) {
249        let mut clone = self.clone();
250        let n = clone.normalize();
251        (clone, n)
252    }
253
254    /// Element-wise complex conjugate. Symmetric with
255    /// [`BlockSparseTensor::conj`].
256    pub fn conj(&self) -> Self {
257        Self {
258            data: self.data.conj(),
259        }
260    }
261
262    /// Return a tensor with flat data reordered to `to`. When
263    /// `self.data().order() == to`, the underlying buffer is shared via
264    /// `Arc` rather than copied.
265    ///
266    /// This is a **workspace-internal escape hatch**, not a user entry
267    /// point. The public `Tensor` surface hides memory layout: constructors
268    /// take no order, and the linalg / algorithm layers normalize to the
269    /// backend's preferred order internally. The only in-tree callers are
270    /// that internal plumbing (and the order-mismatch rejection tests).
271    /// End users should never need to choose a `MemoryOrder`; as an inherent
272    /// method on a re-exported type it cannot be hidden from umbrella users,
273    /// hence this note.
274    pub fn reordered(&self, to: ariadnetor_core::backend::MemoryOrder) -> Self {
275        let reordered = crate::reorder::reorder_data(&self.data, to);
276        Self { data: reordered }
277    }
278
279    /// General logical (C-order) reshape to an arbitrary target shape,
280    /// preserving the tensor's memory order. The buffer is routed
281    /// through row-major so the logical axis grouping is independent of
282    /// the physical layout, then restored to the original order; for a
283    /// row-major tensor each step is a zero-copy `Arc` share, for a
284    /// column-major tensor it costs one round-trip transpose.
285    ///
286    /// This is the low-level escape hatch for multi-leg regroupings that
287    /// [`fuse_legs`] / [`split_leg`] cannot express in a single
288    /// operation — e.g. fusing two disjoint leg groups at once. Prefer
289    /// [`fuse_legs`] / [`split_leg`] for single-leg fuse / split: they
290    /// constrain which axis changes and read as intent. Like
291    /// [`reshape`], only the total element count is validated.
292    ///
293    /// [`fuse_legs`]: Self::fuse_legs
294    /// [`split_leg`]: Self::split_leg
295    /// [`reshape`]: Self::reshape
296    ///
297    /// # Panics
298    ///
299    /// Panics if `new_shape`'s total element count differs from the
300    /// tensor's, via [`reshape`].
301    pub fn reshape_logical(&self, new_shape: Vec<usize>) -> Self {
302        let orig_order = self.order();
303        self.reordered(ariadnetor_core::backend::MemoryOrder::RowMajor)
304            .reshape(new_shape)
305            .reordered(orig_order)
306    }
307
308    /// Fuse a contiguous range of axes into a single leg, grouping
309    /// them in row-major (C-order) logical order regardless of the
310    /// tensor's physical memory order.
311    ///
312    /// The fused leg's extent is the product of the fused axes'
313    /// extents and its logical index runs fastest over the last fused
314    /// axis. The result keeps `self`'s memory order. Use [`reshape`]
315    /// instead when a raw, order-preserving buffer reinterpretation is
316    /// wanted. Inverse of [`split_leg`] over the same range. Convenience
317    /// over [`reshape_logical`] for the single-leg case; for multi-group
318    /// regroupings call [`reshape_logical`] directly.
319    ///
320    /// [`reshape`]: Self::reshape
321    /// [`split_leg`]: Self::split_leg
322    /// [`reshape_logical`]: Self::reshape_logical
323    ///
324    /// # Panics
325    ///
326    /// Panics unless `range.start < range.end <= rank`.
327    pub fn fuse_legs(&self, range: std::ops::Range<usize>) -> Self {
328        let shape = self.shape();
329        let rank = shape.len();
330        assert!(
331            range.start < range.end && range.end <= rank,
332            "fuse_legs: range {range:?} out of bounds for rank {rank}",
333        );
334        let fused: usize = shape[range.clone()].iter().product();
335        let mut new_shape = shape[..range.start].to_vec();
336        new_shape.push(fused);
337        new_shape.extend_from_slice(&shape[range.end..]);
338        self.reshape_logical(new_shape)
339    }
340
341    /// Split one axis into multiple axes, distributing the extent in
342    /// row-major (C-order) logical order regardless of the tensor's
343    /// physical memory order.
344    ///
345    /// `into` lists the resulting extents from slowest- to
346    /// fastest-varying. The result keeps `self`'s memory order.
347    /// Inverse of [`fuse_legs`] for a contiguous range. Convenience over
348    /// [`reshape_logical`] for the single-leg case; for multi-group
349    /// regroupings call [`reshape_logical`] directly.
350    ///
351    /// [`fuse_legs`]: Self::fuse_legs
352    /// [`reshape_logical`]: Self::reshape_logical
353    ///
354    /// # Panics
355    ///
356    /// Panics unless `axis < rank`, `into` is non-empty, and
357    /// `into.iter().product() == shape[axis]`.
358    pub fn split_leg(&self, axis: usize, into: &[usize]) -> Self {
359        let shape = self.shape();
360        let rank = shape.len();
361        assert!(
362            axis < rank,
363            "split_leg: axis {axis} out of bounds for rank {rank}",
364        );
365        assert!(!into.is_empty(), "split_leg: `into` must be non-empty");
366        let prod: usize = into.iter().product();
367        assert_eq!(
368            prod, shape[axis],
369            "split_leg: product of {into:?} != axis {axis} extent {}",
370            shape[axis],
371        );
372        let mut new_shape = shape[..axis].to_vec();
373        new_shape.extend_from_slice(into);
374        new_shape.extend_from_slice(&shape[axis + 1..]);
375        self.reshape_logical(new_shape)
376    }
377}