use crate::convert::flatten_rectangular;
use crate::error::MattenError;
use crate::shape;
use std::fmt;
/// Hard cap on the number of elements `arange_impl` will generate: 2^28 `f64` values,
/// i.e. 2 GiB of element storage.
const ARANGE_MAX_ELEMENTS: usize = 268_435_456;
/// An n-dimensional array of `f64` values stored as one flat buffer plus a shape.
///
/// The checked constructors in this file (e.g. `try_new`) keep `data.len()` equal to the
/// element count implied by `shape`; a rank-0 (scalar) tensor has an empty `shape` and a
/// single element.
#[derive(PartialEq, Clone)]
pub struct Tensor {
    /// Flat element storage.
    pub(crate) data: Vec<f64>,
    /// Extent of each axis; its length is the tensor's rank.
    pub(crate) shape: Vec<usize>,
}
// `len` is exposed without an `is_empty` companion, so Clippy's `len_without_is_empty` is
// silenced here. NOTE(review): presumably shapes accepted by `shape::validate_shape` never
// describe an empty tensor, which would make `is_empty` meaningless — confirm before adding one.
#[allow(clippy::len_without_is_empty)]
impl Tensor {
    /// Builds a tensor from a flat element buffer and a shape.
    ///
    /// # Panics
    ///
    /// Panics with the error's message whenever [`Tensor::try_new`] would return `Err`.
    #[must_use]
    pub fn new(data: Vec<f64>, shape: &[usize]) -> Tensor {
        match Self::try_new(data, shape) {
            Ok(tensor) => tensor,
            Err(err) => panic!("{err}"),
        }
    }

    /// Fallible counterpart of [`Tensor::new`].
    ///
    /// # Errors
    ///
    /// - Any error reported by `shape::validate_shape` for `shape`.
    /// - [`MattenError::Shape`] when `data.len()` differs from the element count that `shape`
    ///   requires.
    pub fn try_new(data: Vec<f64>, shape: &[usize]) -> Result<Tensor, MattenError> {
        let expected = shape::validate_shape(shape, "try_new")?;
        let actual = data.len();
        if actual == expected {
            Ok(Self {
                data,
                shape: shape.to_vec(),
            })
        } else {
            Err(MattenError::Shape {
                operation: "try_new",
                message: format!(
                    "data length {} does not match shape {shape:?}, which requires {expected} elements",
                    actual
                ),
            })
        }
    }

    /// Creates a rank-0 tensor (empty shape) holding the single element `value`.
    #[must_use]
    pub fn scalar(value: f64) -> Tensor {
        Self {
            shape: vec![],
            data: vec![value],
        }
    }

    /// Creates a tensor of the given shape with every element set to `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `shape::validate_shape` rejects `shape`.
    #[must_use]
    pub fn zeros(shape: &[usize]) -> Tensor {
        match shape::validate_shape(shape, "zeros") {
            Ok(count) => Self {
                data: vec![0.0; count],
                shape: shape.to_vec(),
            },
            Err(err) => panic!("{err}"),
        }
    }

    /// Creates a tensor of the given shape with every element set to `1.0`.
    ///
    /// # Panics
    ///
    /// Panics if `shape::validate_shape` rejects `shape`.
    #[must_use]
    pub fn ones(shape: &[usize]) -> Tensor {
        match shape::validate_shape(shape, "ones") {
            Ok(count) => Self {
                data: vec![1.0; count],
                shape: shape.to_vec(),
            },
            Err(err) => panic!("{err}"),
        }
    }

    /// Creates a tensor of the given shape with every element set to `value`.
    ///
    /// # Panics
    ///
    /// Panics if `shape::validate_shape` rejects `shape`.
    #[must_use]
    pub fn full(shape: &[usize], value: f64) -> Tensor {
        match shape::validate_shape(shape, "full") {
            Ok(count) => Self {
                data: vec![value; count],
                shape: shape.to_vec(),
            },
            Err(err) => panic!("{err}"),
        }
    }

    /// Wraps `data` as a rank-1 tensor of shape `[data.len()]`.
    ///
    /// # Panics
    ///
    /// Panics if [`Tensor::new`] rejects the shape `[data.len()]`.
    /// NOTE(review): whether an empty `data` is accepted depends on `shape::validate_shape`.
    #[must_use]
    pub fn from_vec(data: Vec<f64>) -> Tensor {
        let n = data.len();
        Self::new(data, &[n])
    }

    /// Rank-1 tensor of `start, start + step, ...` stopping before `end` (exclusive).
    ///
    /// # Panics
    ///
    /// Panics whenever [`Tensor::try_arange`] would return `Err`.
    #[must_use]
    pub fn arange(start: f64, end: f64, step: f64) -> Tensor {
        match arange_impl(start, end, step, "arange") {
            Ok(tensor) => tensor,
            Err(err) => panic!("{err}"),
        }
    }

    /// Fallible counterpart of [`Tensor::arange`].
    ///
    /// # Errors
    ///
    /// Fails if `start`/`end` are not finite, `step` is zero or not finite, the range is empty,
    /// or it would exceed the crate's element limit for `arange`.
    pub fn try_arange(start: f64, end: f64, step: f64) -> Result<Tensor, MattenError> {
        arange_impl(start, end, step, "try_arange")
    }

    /// Builds a tensor from nested rows via `flatten_rectangular`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `flatten_rectangular`, presumably when the rows are ragged —
    /// TODO confirm against `crate::convert`.
    pub fn try_from_rows(rows: Vec<Vec<f64>>) -> Result<Tensor, MattenError> {
        let (data, shape) = flatten_rectangular(rows, "try_from_rows")?;
        Ok(Self { data, shape })
    }

    /// Extent of each axis.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    /// Number of axes (the rank).
    #[must_use]
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Total number of stored elements.
    #[must_use]
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// `true` for a rank-0 tensor.
    #[must_use]
    pub fn is_scalar(&self) -> bool {
        self.shape.is_empty()
    }

    /// `true` for a rank-1 tensor.
    #[must_use]
    pub fn is_vector(&self) -> bool {
        self.shape.len() == 1
    }

    /// `true` for a rank-2 tensor.
    #[must_use]
    pub fn is_matrix(&self) -> bool {
        self.shape.len() == 2
    }

    /// Borrows the flat element buffer.
    #[must_use]
    pub fn as_slice(&self) -> &[f64] {
        self.data.as_slice()
    }

    /// Returns an owned copy of the flat element buffer.
    #[must_use]
    pub fn to_vec(&self) -> Vec<f64> {
        self.data.to_vec()
    }

    /// Consumes the tensor and returns its flat element buffer without copying.
    #[must_use]
    pub fn into_vec(self) -> Vec<f64> {
        let Self { data, .. } = self;
        data
    }
}
impl fmt::Debug for Tensor {
    /// Prints the shape and at most the first eight elements, followed by a count of
    /// elements that were left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const PREVIEW_LIMIT: usize = 8;

        write!(f, "Tensor(shape={:?}, data=[", self.shape)?;

        let mut preview = self.data.iter().take(PREVIEW_LIMIT);
        if let Some(first) = preview.next() {
            write!(f, "{first:?}")?;
            for value in preview {
                write!(f, ", {value:?}")?;
            }
        }

        let hidden = self.data.len().saturating_sub(PREVIEW_LIMIT);
        if hidden > 0 {
            write!(f, ", ... ({hidden} more)")?;
        }

        f.write_str("])")
    }
}
/// Shared implementation behind `Tensor::arange` and `Tensor::try_arange`.
///
/// Generates the rank-1 sequence `start, start + step, start + 2*step, ...`, stopping before
/// the first value that reaches `end` (so `end` is exclusive). Element `i` is computed as
/// `start + step * i` rather than by repeated addition, so rounding error does not accumulate.
///
/// `operation` is recorded in `MattenError::Shape` errors.
///
/// # Errors
///
/// - `MattenError::Shape` if `start`/`end` are not finite, if `step` is zero or not finite,
///   or if the range yields no elements (including a `step` pointing away from `end`).
/// - `MattenError::Allocation` if the range would hold more than `ARANGE_MAX_ELEMENTS` values.
fn arange_impl(
    start: f64,
    end: f64,
    step: f64,
    operation: &'static str,
) -> Result<Tensor, MattenError> {
    let bounds_ok = start.is_finite() && end.is_finite();
    if !bounds_ok {
        return Err(MattenError::Shape {
            operation,
            message: format!("start and end must be finite (got start={start}, end={end})"),
        });
    }

    let step_ok = step.is_finite() && step != 0.0;
    if !step_ok {
        return Err(MattenError::Shape {
            operation,
            message: format!("step must be a non-zero finite value (got {step})"),
        });
    }

    // Length estimate: used only to reject oversized requests before allocating and to size
    // the buffer. The generator below is authoritative for the actual element count.
    let estimate = ((end - start) / step).ceil();
    if estimate > ARANGE_MAX_ELEMENTS as f64 {
        let requested = estimate as usize;
        return Err(MattenError::Allocation {
            requested_elements: requested,
            message: format!(
                "arange would produce ~{requested} elements, exceeding the limit of {ARANGE_MAX_ELEMENTS}"
            ),
        });
    }
    let capacity = if estimate > 0.0 { estimate as usize } else { 0 };

    // `step` is known to be non-zero here, so its sign alone picks the stopping test.
    let past_end = |v: f64| if step > 0.0 { v >= end } else { v <= end };

    // Take one element beyond the limit so an overrun is detectable below.
    let mut data = Vec::with_capacity(capacity);
    data.extend(
        (0usize..)
            .map(|i| start + step * i as f64)
            .take_while(|&v| !past_end(v))
            .take(ARANGE_MAX_ELEMENTS + 1),
    );

    if data.len() > ARANGE_MAX_ELEMENTS {
        return Err(MattenError::Allocation {
            requested_elements: data.len(),
            message: format!("arange exceeded the element limit of {ARANGE_MAX_ELEMENTS}"),
        });
    }

    if data.is_empty() {
        return Err(MattenError::Shape {
            operation,
            message: format!("arange(start={start}, end={end}, step={step}) produces no elements"),
        });
    }

    let len = data.len();
    Ok(Tensor {
        data,
        shape: vec![len],
    })
}
impl Tensor {
    /// Returns a new tensor with this tensor's elements under `new_shape`.
    ///
    /// # Panics
    ///
    /// Panics with the error's message whenever [`Tensor::try_reshape`] would return `Err`.
    #[must_use]
    pub fn reshape(&self, new_shape: &[usize]) -> Tensor {
        crate::reshape::try_reshape_impl(self, new_shape).unwrap_or_else(|e| panic!("{e}"))
    }

    /// Fallible counterpart of [`Tensor::reshape`].
    ///
    /// # Errors
    ///
    /// Propagates any error from `crate::reshape::try_reshape_impl`, presumably when the
    /// element counts of the two shapes differ — TODO confirm.
    pub fn try_reshape(&self, new_shape: &[usize]) -> Result<Tensor, MattenError> {
        crate::reshape::try_reshape_impl(self, new_shape)
    }

    /// Returns a rank-1 tensor holding a copy of every element, in storage order.
    #[must_use]
    pub fn flatten(&self) -> Tensor {
        let len = self.data.len();
        Tensor {
            data: self.data.clone(),
            shape: vec![len],
        }
    }

    /// Reverses the order of all axes by passing the permutation `[ndim-1, ..., 1, 0]` to
    /// `crate::reshape::permute_axes`.
    ///
    /// # Panics
    ///
    /// Panics if the tensor is a scalar (rank 0).
    #[must_use]
    pub fn transpose(&self) -> Tensor {
        let ndim = self.ndim();
        if ndim == 0 {
            panic!("matten shape error in transpose: cannot transpose a scalar (rank 0)");
        }
        let perm: Vec<usize> = (0..ndim).rev().collect();
        crate::reshape::permute_axes(self, &perm)
    }

    /// Short alias for [`Tensor::transpose`].
    ///
    /// # Panics
    ///
    /// Panics if the tensor is a scalar (rank 0).
    #[must_use]
    pub fn t(&self) -> Tensor {
        self.transpose()
    }

    /// Exchanges axes `axis1` and `axis2`, leaving every other axis in place.
    ///
    /// # Panics
    ///
    /// Panics if `crate::reshape::validate_axes` rejects the pair for this tensor's rank
    /// (presumably an out-of-range axis — TODO confirm).
    #[must_use]
    pub fn swap_axes(&self, axis1: usize, axis2: usize) -> Tensor {
        crate::reshape::validate_axes(axis1, axis2, self.ndim(), "swap_axes")
            .unwrap_or_else(|e| panic!("{e}"));
        let mut perm: Vec<usize> = (0..self.ndim()).collect();
        perm.swap(axis1, axis2);
        crate::reshape::permute_axes(self, &perm)
    }

    /// Returns the element at `coord`, or `None` if `crate::shape::coord_to_flat` rejects the
    /// coordinate or the resulting flat index falls outside the buffer.
    // `#[must_use]` matches every other non-`Result` accessor on `Tensor`; discarding a lookup
    // result is almost certainly a mistake.
    #[must_use]
    pub fn get(&self, coord: &[usize]) -> Option<f64> {
        let flat = crate::shape::coord_to_flat(coord, &self.shape)?;
        self.data.get(flat).copied()
    }

    /// Starts a builder-style slice of this tensor (see `crate::slice::SliceBuilder`).
    pub fn slice(&self) -> crate::slice::SliceBuilder<'_> {
        crate::slice::SliceBuilder::new(self)
    }

    /// Parses `spec` with `crate::slice::parse_slice_str` and applies the resulting slice specs.
    ///
    /// # Errors
    ///
    /// Propagates parse errors from `parse_slice_str` and any error from `execute_slice`.
    pub fn slice_str(&self, spec: &str) -> Result<Tensor, MattenError> {
        let specs = crate::slice::parse_slice_str(spec)?;
        crate::slice::execute_slice(self, &specs, "slice_str")
    }
}