use crate::{Storage, StorageFor, TensorLayout};
/// A storage value of type `St` paired with a layout value of type `L`.
///
/// The public constructor [`TensorData::new`] checks that the storage's
/// `flat_len()` equals the layout's `storage_extent()` before building a value.
pub struct TensorData<St: Storage + StorageFor<L>, L: TensorLayout> {
    /// The storage half of the pair.
    storage: St,
    /// The layout half of the pair.
    layout: L,
}
impl<St, L> TensorData<St, L>
where
    St: Storage + StorageFor<L>,
    L: TensorLayout,
{
    /// Creates a `TensorData` from `storage` and `layout`.
    ///
    /// # Panics
    ///
    /// Panics if `storage.flat_len()` is not equal to `layout.storage_extent()`.
    /// Because of `#[track_caller]`, the reported panic location is the caller's
    /// call site rather than this constructor.
    #[track_caller]
    pub fn new(storage: St, layout: L) -> Self {
        // Evaluate each side once so the comparison and the panic message use
        // the same values without querying storage/layout a second time.
        let flat_len = storage.flat_len();
        let extent = layout.storage_extent();
        assert_eq!(
            flat_len, extent,
            "TensorData::new: storage.flat_len() = {} but layout.storage_extent() = {}",
            flat_len, extent,
        );
        Self { storage, layout }
    }

    /// Returns a shared reference to the storage.
    pub fn storage(&self) -> &St {
        &self.storage
    }

    /// Returns a mutable reference to the storage (crate-internal only).
    ///
    /// NOTE(review): nothing here re-checks the length check performed by
    /// [`TensorData::new`]; callers presumably must keep `flat_len()` equal to
    /// the layout's `storage_extent()` — confirm.
    pub(crate) fn storage_mut(&mut self) -> &mut St {
        &mut self.storage
    }

    /// Returns a shared reference to the layout.
    pub fn layout(&self) -> &L {
        &self.layout
    }

    /// Consumes `self` and returns `(storage, layout)`.
    pub fn into_parts(self) -> (St, L) {
        (self.storage, self.layout)
    }
}
impl<St, L> Clone for TensorData<St, L>
where
St: Storage + StorageFor<L> + Clone,
L: TensorLayout + Clone,
{
fn clone(&self) -> Self {
Self {
storage: self.storage.clone(),
layout: self.layout.clone(),
}
}
}
#[cfg(test)]
mod tests {
    use ariadnetor_core::backend::MemoryOrder;
    use crate::{DenseLayout, DenseStorage, TensorData};

    /// `new` must reject storage whose flat length differs from the layout's extent.
    #[test]
    #[should_panic(expected = "storage.flat_len() = 5 but layout.storage_extent() = 6")]
    fn new_panics_on_storage_layout_length_mismatch() {
        let storage = DenseStorage::<f64>::new(vec![0.0; 5]);
        let layout = DenseLayout::new(vec![2, 3], MemoryOrder::RowMajor);
        let _ = TensorData::new(storage, layout);
    }

    /// `new` must accept storage whose flat length matches the layout's extent.
    #[test]
    fn new_accepts_matching_lengths() {
        let storage = DenseStorage::<f64>::new(vec![0.0; 6]);
        let layout = DenseLayout::new(vec![2, 3], MemoryOrder::RowMajor);
        let td = TensorData::new(storage, layout);
        assert_eq!(td.storage().data().len(), 6);
        assert_eq!(td.layout().shape(), &[2, 3]);
    }

    /// `into_parts` must hand back the same storage and layout that were passed to `new`.
    #[test]
    fn into_parts_returns_storage_and_layout() {
        let storage = DenseStorage::<f64>::new(vec![0.0; 6]);
        let layout = DenseLayout::new(vec![2, 3], MemoryOrder::RowMajor);
        let (storage, layout) = TensorData::new(storage, layout).into_parts();
        assert_eq!(storage.data().len(), 6);
        assert_eq!(layout.shape(), &[2, 3]);
    }
}