Skip to main content

ariadnetor_tensor/
tensor_data.rs

1//! `TensorData<St, L>`: the storage + layout bundle.
2//!
3//! Joins a [`Storage`] half with a paired [`TensorLayout`] half. The
4//! `new` constructor enforces the storage-layout boundary
5//! (length-equality check); layout-internal invariants are validated
6//! by the layout's own constructor.
7//!
8//! Flavor-specific aliases [`DenseTensorData<T>`](crate::DenseTensorData)
9//! and [`BlockSparseTensorData<T, S>`](crate::BlockSparseTensorData)
10//! carry the convenience constructors and joined accessors that need
11//! to touch both halves simultaneously (e.g. block-data slicing for
12//! block-sparse tensors).
13
14use crate::{Storage, StorageFor, TensorLayout};
15
16/// Joined storage + layout bundle.
17///
18/// Construction goes through [`new`](Self::new), which asserts
19/// `storage.flat_len() == layout.storage_extent()`. The bound
20/// `St: StorageFor<L>` enforces flavor compatibility at the type
21/// level (only `DenseStorage` ⇔ `DenseLayout`,
22/// `BlockSparseStorage` ⇔ `BlockSparseLayout`).
23pub struct TensorData<St, L>
24where
25    St: Storage + StorageFor<L>,
26    L: TensorLayout,
27{
28    storage: St,
29    layout: L,
30}
31
32impl<St, L> TensorData<St, L>
33where
34    St: Storage + StorageFor<L>,
35    L: TensorLayout,
36{
37    /// Construct from a `Storage` half and a paired `TensorLayout`
38    /// half. Asserts the storage-layout boundary: the storage's flat
39    /// length must match the layout's expected storage extent.
40    pub fn new(storage: St, layout: L) -> Self {
41        assert_eq!(
42            storage.flat_len(),
43            layout.storage_extent(),
44            "TensorData::new: storage.flat_len() = {} but layout.storage_extent() = {}",
45            storage.flat_len(),
46            layout.storage_extent(),
47        );
48        Self { storage, layout }
49    }
50
51    /// Reference to the storage half.
52    pub fn storage(&self) -> &St {
53        &self.storage
54    }
55
56    /// Mutable reference to the storage half.
57    ///
58    /// Crate-internal: wholesale replacement (`*td.storage_mut() = ...`)
59    /// would let a caller break the storage-layout boundary invariant
60    /// (`storage.flat_len() == layout.storage_extent()`) re-checked only
61    /// at [`new`](Self::new). Internal callers use this for length-preserving
62    /// element-wise mutation (via the storage's own `data_mut` etc.).
63    pub(crate) fn storage_mut(&mut self) -> &mut St {
64        &mut self.storage
65    }
66
67    /// Reference to the layout half.
68    pub fn layout(&self) -> &L {
69        &self.layout
70    }
71
72    /// Consume and return both halves.
73    pub fn into_parts(self) -> (St, L) {
74        (self.storage, self.layout)
75    }
76}
77
78impl<St, L> Clone for TensorData<St, L>
79where
80    St: Storage + StorageFor<L> + Clone,
81    L: TensorLayout + Clone,
82{
83    fn clone(&self) -> Self {
84        Self {
85            storage: self.storage.clone(),
86            layout: self.layout.clone(),
87        }
88    }
89}
90
#[cfg(test)]
mod tests {
    use ariadnetor_core::backend::MemoryOrder;

    use crate::{DenseLayout, DenseStorage, TensorData};

    /// A storage/layout pair whose lengths disagree must be rejected
    /// by `TensorData::new`.
    #[test]
    #[should_panic(expected = "storage.flat_len() = 5 but layout.storage_extent() = 6")]
    fn new_panics_on_storage_layout_length_mismatch() {
        // Five stored elements against a 2 x 3 layout: the expected
        // panic message reports a flat length of 5 versus an extent
        // of 6. Rejecting the pair here keeps downstream kernels from
        // ever receiving a buffer that is too short for the layout.
        let five_elements = DenseStorage::<f64>::new(vec![0.0; 5]);
        let two_by_three = DenseLayout::new(vec![2, 3], MemoryOrder::RowMajor);
        let _ = TensorData::new(five_elements, two_by_three);
    }

    /// A pair with matching lengths is accepted and both halves stay
    /// reachable through the accessors.
    #[test]
    fn new_accepts_matching_lengths() {
        let bundle = TensorData::new(
            DenseStorage::<f64>::new(vec![0.0; 6]),
            DenseLayout::new(vec![2, 3], MemoryOrder::RowMajor),
        );
        assert_eq!(bundle.storage().data().len(), 6);
        assert_eq!(bundle.layout().shape(), &[2, 3]);
    }
}