Skip to main content

ariadnetor_tensor/tensor/
mod.rs

1//! Tensor type combining storage and layout.
2//!
3//! `Tensor<St, L>` is a thin wrapper over a
4//! [`TensorData<St, L>`](crate::TensorData) bundle. Concrete user-facing
5//! aliases:
6//!
7//! - [`DenseTensor<T>`] = `Tensor<DenseStorage<T>, DenseLayout>`
8//! - [`BlockSparseTensor<T, S>`] =
9//!   `Tensor<BlockSparseStorage<T>, BlockSparseLayout<S>>`
10//!
11//! The tensor carries no compute backend: operations take the backend
12//! explicitly at the call site (see `ariadnetor-linalg`). Convenience
13//! constructors that need a memory order read it from the host substrate
14//! ([`Host`](crate::Host)) without binding the tensor to any backend.
15
16use std::fmt;
17
18use ariadnetor_core::Scalar;
19use ariadnetor_core::backend::{ComputeBackend, MemoryOrder};
20use num_traits::Zero;
21use rand::RngExt;
22
23use crate::capability::Host;
24use crate::{
25    BlockCoord, BlockSparseLayout, BlockSparseStorage, BlockSparseTensorData, DenseLayout,
26    DenseStorage, DenseTensorData, QNIndex, Sector, Storage, StorageFor, TensorData, TensorLayout,
27};
28
29mod dense_ops;
30
31mod block_sparse_ops;
32
33#[cfg(test)]
34mod tests;
35
36/// Memory order for host-resident convenience constructors.
37///
38/// Read through the [`Host`](crate::Host) substrate alias rather than a
39/// `NativeBackend` literal so the host order has a single source and the
40/// substrate can be repointed in one place.
41fn host_order() -> MemoryOrder {
42    Host::shared().preferred_order()
43}
44
45/// Tensor wrapping a [`TensorData`] bundle.
46///
47/// # Type Parameters
48///
49/// * `St` - Storage half ([`DenseStorage<T>`] or [`BlockSparseStorage<T>`])
50/// * `L`  - Layout half ([`DenseLayout`] or [`BlockSparseLayout<S>`])
51///
52/// # Examples
53///
54/// ```
55/// use ariadnetor_tensor::DenseTensor;
56///
57/// let a = DenseTensor::<f64>::zeros(vec![2, 2]);
58/// assert_eq!(a.shape(), &[2, 2]);
59/// ```
60pub struct Tensor<St, L>
61where
62    St: Storage + StorageFor<L>,
63    L: TensorLayout,
64{
65    data: TensorData<St, L>,
66}
67
68/// Dense tensor alias.
69pub type DenseTensor<T = f64> = Tensor<DenseStorage<T>, DenseLayout>;
70
71/// BlockSparse tensor alias.
72pub type BlockSparseTensor<T, S> = Tensor<BlockSparseStorage<T>, BlockSparseLayout<S>>;
73
74// ============================================================================
75// Manual Clone / Debug
76//
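// `Tensor` is generic over `St` and `L`; `#[derive]` would add a bound on
// every type parameter (e.g. `St: Debug`, `L: Debug`) that is not always
// satisfied. The manual impls add only the bounds each impl actually
// needs: `Clone` requires `Clone` on both halves, `Debug` requires nothing
// beyond the struct's own bounds.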
80// ============================================================================
81
82impl<St, L> Clone for Tensor<St, L>
83where
84    St: Storage + StorageFor<L> + Clone,
85    L: TensorLayout + Clone,
86{
87    fn clone(&self) -> Self {
88        Self {
89            data: self.data.clone(),
90        }
91    }
92}
93
94impl<St, L> fmt::Debug for Tensor<St, L>
95where
96    St: Storage + StorageFor<L>,
97    L: TensorLayout,
98{
99    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100        f.debug_struct("Tensor")
101            .field("shape", &self.data.layout().shape())
102            .finish()
103    }
104}
105
106// ============================================================================
107// Generic methods (all storage / layout combinations)
108// ============================================================================
109
110impl<St, L> Tensor<St, L>
111where
112    St: Storage + StorageFor<L>,
113    L: TensorLayout,
114{
115    /// Build a tensor from a pre-bundled [`TensorData`].
116    pub fn from_data(data: TensorData<St, L>) -> Self {
117        Self { data }
118    }
119
120    /// Internal escape hatch: reference to the joined [`TensorData`]
121    /// bundle.
122    ///
123    /// Intended for cross-crate kernel-access paths inside `ariadnetor-linalg`
124    /// and `ariadnetor-mps`; user code should reach for the inherent methods
125    /// on [`DenseTensor`] / [`BlockSparseTensor`] instead.
126    pub fn data(&self) -> &TensorData<St, L> {
127        &self.data
128    }
129
130    /// Internal escape hatch: mutable reference to the joined
131    /// [`TensorData`] bundle.
132    ///
133    /// Same audience as [`Tensor::data`] — cross-crate kernel paths
134    /// that need to mutate raw storage / layout state.
135    pub fn data_mut(&mut self) -> &mut TensorData<St, L> {
136        &mut self.data
137    }
138
139    /// Logical shape (delegates to the layout).
140    pub fn shape(&self) -> &[usize] {
141        self.data.layout().shape()
142    }
143
144    /// Rank (number of dimensions).
145    pub fn rank(&self) -> usize {
146        self.shape().len()
147    }
148
149    /// Total number of logical elements (`product(shape)`).
150    pub fn len(&self) -> usize {
151        self.shape().iter().product()
152    }
153
154    /// Whether the tensor has zero logical elements.
155    pub fn is_empty(&self) -> bool {
156        self.len() == 0
157    }
158}
159
160// ============================================================================
161// Dense-specific host constructors
162//
163// The memory order is taken from the host substrate's preferred order so
164// dispatch paths that require preferred-order alignment find it satisfied
165// at construction. These build host-resident data and bind the tensor to
166// no backend.
167// ============================================================================
168
169impl<S: Scalar> Tensor<DenseStorage<S>, DenseLayout> {
170    /// Create a Dense tensor filled with zeros.
171    pub fn zeros(shape: Vec<usize>) -> Self {
172        Self::dense_filled(shape, S::zero())
173    }
174
175    /// Create a Dense tensor filled with ones.
176    pub fn ones(shape: Vec<usize>) -> Self {
177        Self::dense_filled(shape, S::one())
178    }
179
180    /// Create a Dense tensor filled with `value`.
181    pub fn filled(shape: Vec<usize>, value: S) -> Self {
182        Self::dense_filled(shape, value)
183    }
184
185    /// Create an n×n identity matrix.
186    pub fn eye(n: usize) -> Self {
187        let order = host_order();
188        let mut data = vec![S::zero(); n * n];
189        // The identity matrix is symmetric, so the flat data is the
190        // same regardless of memory order; only the layout's `order()`
191        // field differs.
192        for i in 0..n {
193            data[i * n + i] = S::one();
194        }
195        let td = DenseTensorData::from_raw_parts(data, vec![n, n], order);
196        Self::from_data(td)
197    }
198
199    /// Create a Dense tensor filled with values drawn from the
200    /// standard distribution via the supplied RNG.
201    pub fn random<R: rand::Rng>(shape: Vec<usize>, rng: &mut R) -> Self
202    where
203        rand::distr::StandardUniform: rand::distr::Distribution<S>,
204    {
205        let order = host_order();
206        let total: usize = shape.iter().product();
207        let data: Vec<S> = (0..total).map(|_| rng.random()).collect();
208        let td = DenseTensorData::from_raw_parts(data, shape, order);
209        Self::from_data(td)
210    }
211
212    fn dense_filled(shape: Vec<usize>, value: S) -> Self {
213        let order = host_order();
214        let len: usize = shape.iter().product();
215        let data = DenseTensorData::from_raw_parts(vec![value; len], shape, order);
216        Self::from_data(data)
217    }
218}
219
220// ============================================================================
221// BlockSparse-specific host constructors
222//
223// As with Dense, the memory order is taken from the host substrate's
224// preferred order; users needing arbitrary order must go through the
225// joined-path `TensorData::new` route.
226// ============================================================================
227
228impl<T, S> Tensor<BlockSparseStorage<T>, BlockSparseLayout<S>>
229where
230    T: Clone + Zero,
231    S: Sector,
232{
233    /// Create a zero-filled `BlockSparseTensor` enumerating every
234    /// flux-allowed block of the supplied `QNIndex` legs.
235    pub fn zeros(indices: Vec<QNIndex<S>>, flux: S) -> Self {
236        let order = host_order();
237        let td = BlockSparseTensorData::zeros(indices, flux, order);
238        Self::from_data(td)
239    }
240}
241
242impl<T, S> Tensor<BlockSparseStorage<T>, BlockSparseLayout<S>>
243where
244    T: Clone,
245    S: Sector,
246    rand::distr::StandardUniform: rand::distr::Distribution<T>,
247{
248    /// Create a `BlockSparseTensor` whose flux-allowed blocks are
249    /// filled with values drawn from the standard distribution via the
250    /// supplied RNG.
251    pub fn random<R: rand::Rng>(indices: Vec<QNIndex<S>>, flux: S, rng: &mut R) -> Self {
252        let order = host_order();
253        let td = BlockSparseTensorData::random(indices, flux, order, rng);
254        Self::from_data(td)
255    }
256}
257
258impl<T, S> Tensor<BlockSparseStorage<T>, BlockSparseLayout<S>>
259where
260    T: Clone + Zero,
261    S: Sector,
262{
263    /// Construct a `BlockSparseTensor` by populating each flux-allowed
264    /// block from a closure receiving the block coordinate and its
265    /// dense block shape.
266    pub fn from_block_fn<F>(indices: Vec<QNIndex<S>>, flux: S, f: F) -> Self
267    where
268        F: FnMut(&BlockCoord, &[usize]) -> Vec<T>,
269    {
270        let order = host_order();
271        let td = BlockSparseTensorData::from_block_fn(indices, flux, order, f);
272        Self::from_data(td)
273    }
274}