ariadnetor_tensor/tensor_data.rs
1//! `TensorData<St, L>`: the storage + layout bundle.
2//!
3//! Joins a [`Storage`] half with a paired [`TensorLayout`] half. The
4//! `new` constructor enforces the storage-layout boundary
5//! (length-equality check); layout-internal invariants are validated
6//! by the layout's own constructor.
7//!
8//! Flavor-specific aliases [`DenseTensorData<T>`](crate::DenseTensorData)
9//! and [`BlockSparseTensorData<T, S>`](crate::BlockSparseTensorData)
10//! carry the convenience constructors and joined accessors that need
11//! to touch both halves simultaneously (e.g. block-data slicing for
12//! block-sparse tensors).
13
14use crate::{Storage, StorageFor, TensorLayout};
15
16/// Joined storage + layout bundle.
17///
18/// Construction goes through [`new`](Self::new), which asserts
19/// `storage.flat_len() == layout.storage_extent()`. The bound
20/// `St: StorageFor<L>` enforces flavor compatibility at the type
21/// level (only `DenseStorage` ⇔ `DenseLayout`,
22/// `BlockSparseStorage` ⇔ `BlockSparseLayout`).
23pub struct TensorData<St, L>
24where
25 St: Storage + StorageFor<L>,
26 L: TensorLayout,
27{
28 storage: St,
29 layout: L,
30}
31
32impl<St, L> TensorData<St, L>
33where
34 St: Storage + StorageFor<L>,
35 L: TensorLayout,
36{
37 /// Construct from a `Storage` half and a paired `TensorLayout`
38 /// half. Asserts the storage-layout boundary: the storage's flat
39 /// length must match the layout's expected storage extent.
40 pub fn new(storage: St, layout: L) -> Self {
41 assert_eq!(
42 storage.flat_len(),
43 layout.storage_extent(),
44 "TensorData::new: storage.flat_len() = {} but layout.storage_extent() = {}",
45 storage.flat_len(),
46 layout.storage_extent(),
47 );
48 Self { storage, layout }
49 }
50
51 /// Reference to the storage half.
52 pub fn storage(&self) -> &St {
53 &self.storage
54 }
55
56 /// Mutable reference to the storage half.
57 ///
58 /// Crate-internal: wholesale replacement (`*td.storage_mut() = ...`)
59 /// would let a caller break the storage-layout boundary invariant
60 /// (`storage.flat_len() == layout.storage_extent()`) re-checked only
61 /// at [`new`](Self::new). Internal callers use this for length-preserving
62 /// element-wise mutation (via the storage's own `data_mut` etc.).
63 pub(crate) fn storage_mut(&mut self) -> &mut St {
64 &mut self.storage
65 }
66
67 /// Reference to the layout half.
68 pub fn layout(&self) -> &L {
69 &self.layout
70 }
71
72 /// Consume and return both halves.
73 pub fn into_parts(self) -> (St, L) {
74 (self.storage, self.layout)
75 }
76}
77
78impl<St, L> Clone for TensorData<St, L>
79where
80 St: Storage + StorageFor<L> + Clone,
81 L: TensorLayout + Clone,
82{
83 fn clone(&self) -> Self {
84 Self {
85 storage: self.storage.clone(),
86 layout: self.layout.clone(),
87 }
88 }
89}
90
#[cfg(test)]
mod tests {
    use ariadnetor_core::backend::MemoryOrder;

    use crate::{DenseLayout, DenseStorage, TensorData};

    #[test]
    #[should_panic(expected = "storage.flat_len() = 5 but layout.storage_extent() = 6")]
    fn new_panics_on_storage_layout_length_mismatch() {
        // 2 x 3 dense layout expects storage_extent = 6, but the
        // storage carries only 5 elements. `TensorData::new` must
        // reject the pair so downstream kernels never see a buffer
        // that can index out of range under the layout's strides.
        let storage = DenseStorage::<f64>::new(vec![0.0; 5]);
        let layout = DenseLayout::new(vec![2, 3], MemoryOrder::RowMajor);
        let _ = TensorData::new(storage, layout);
    }

    #[test]
    fn new_accepts_matching_lengths() {
        let storage = DenseStorage::<f64>::new(vec![0.0; 6]);
        let layout = DenseLayout::new(vec![2, 3], MemoryOrder::RowMajor);
        let td = TensorData::new(storage, layout);
        assert_eq!(td.storage().data().len(), 6);
        assert_eq!(td.layout().shape(), &[2, 3]);
    }

    #[test]
    fn into_parts_returns_both_halves() {
        // Round trip: the halves handed to `new` come back out of
        // `into_parts` with the same length and shape.
        let storage = DenseStorage::<f64>::new(vec![0.0; 6]);
        let layout = DenseLayout::new(vec![2, 3], MemoryOrder::RowMajor);
        let (storage, layout) = TensorData::new(storage, layout).into_parts();
        assert_eq!(storage.data().len(), 6);
        assert_eq!(layout.shape(), &[2, 3]);
    }
}
117}