use crate::error::MattenError;
use crate::shape;
use std::fmt;
/// An n-dimensional array of `f64` values kept in a single flat buffer.
///
/// `data` stores the elements and `shape` stores the extent of each dimension.
/// Equality is derived, so two tensors compare equal only when both their shapes
/// and their element buffers match element-wise under `f64` semantics (a tensor
/// containing NaN is therefore not equal to itself).
#[derive(PartialEq, Clone)]
pub struct Tensor {
    /// Flat element storage.
    pub(crate) data: Vec<f64>,
    /// Extent of each dimension; an empty shape denotes a zero-dimensional tensor.
    pub(crate) shape: Vec<usize>,
}
// `len` reports the element count, but this impl provides no `is_empty` companion.
// NOTE(review): the reason `is_empty` is omitted is not visible in this file — confirm
// whether that is intentional (e.g. defined in another `impl Tensor`) before adding one.
#[allow(clippy::len_without_is_empty)]
impl Tensor {
    /// Creates a tensor from flat `data` and a `shape`, panicking on invalid input.
    ///
    /// This is the panicking counterpart of [`Tensor::try_new`].
    ///
    /// # Panics
    ///
    /// Panics with the error's `Display` text whenever [`Tensor::try_new`] would
    /// return an error. The panic is attributed to the caller's source location.
    #[must_use]
    #[track_caller]
    pub fn new(data: Vec<f64>, shape: &[usize]) -> Tensor {
        // `match` keeps the `panic!` directly in this function body: closures do not
        // inherit `#[track_caller]`, so panicking inside `unwrap_or_else(|e| ...)`
        // would report this file instead of the call site that passed bad arguments.
        match Self::try_new(data, shape) {
            Ok(tensor) => tensor,
            Err(e) => panic!("{e}"),
        }
    }

    /// Creates a tensor from flat `data` and a `shape`, validating both.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `shape::validate_shape` for `shape`, or a
    /// [`MattenError::Shape`] when `data.len()` differs from the element count that
    /// `validate_shape` reports for `shape`.
    pub fn try_new(data: Vec<f64>, shape: &[usize]) -> Result<Tensor, MattenError> {
        let expected = shape::validate_shape(shape, "try_new")?;
        if data.len() != expected {
            return Err(MattenError::Shape {
                operation: "try_new",
                message: format!(
                    "data length {} does not match shape {shape:?}, which requires {expected} elements",
                    data.len()
                ),
            });
        }
        Ok(Tensor { data, shape: shape.to_vec() })
    }

    /// Creates a zero-dimensional tensor (empty shape) holding the single `value`.
    #[must_use]
    pub fn scalar(value: f64) -> Tensor {
        Tensor { data: vec![value], shape: Vec::new() }
    }

    /// Returns the extent of each dimension; empty for a zero-dimensional tensor.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Returns the number of dimensions, i.e. the length of [`Tensor::shape`].
    #[must_use]
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Returns the number of stored elements.
    #[must_use]
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` when the tensor has zero dimensions.
    #[must_use]
    pub fn is_scalar(&self) -> bool {
        self.ndim() == 0
    }

    /// Returns `true` when the tensor has exactly one dimension.
    #[must_use]
    pub fn is_vector(&self) -> bool {
        self.ndim() == 1
    }

    /// Returns `true` when the tensor has exactly two dimensions.
    #[must_use]
    pub fn is_matrix(&self) -> bool {
        self.ndim() == 2
    }

    /// Borrows the flat element buffer without copying.
    #[must_use]
    pub fn as_slice(&self) -> &[f64] {
        &self.data
    }

    /// Returns an owned copy of the flat element buffer.
    #[must_use]
    pub fn to_vec(&self) -> Vec<f64> {
        self.data.clone()
    }
}
/// Compact diagnostic rendering: the shape, then at most the first eight
/// elements, followed by a count of any elements that were left out.
impl fmt::Debug for Tensor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Upper bound on printed elements so large tensors stay readable.
        const PREVIEW_LEN: usize = 8;

        let total = self.data.len();
        let shown = &self.data[..total.min(PREVIEW_LEN)];
        let hidden = total - shown.len();

        write!(f, "Tensor(shape={:?}, data=[", self.shape)?;

        // Emit a separator before every element except the first.
        let mut separator = "";
        for value in shown {
            f.write_str(separator)?;
            write!(f, "{value:?}")?;
            separator = ", ";
        }

        if hidden > 0 {
            write!(f, ", ... ({} more)", hidden)?;
        }
        f.write_str("])")
    }
}