use crate::MattenError;
use crate::Tensor;
use crate::shape::{coord_to_flat, flat_to_coord};
impl Tensor {
    /// Adds up every element of the tensor.
    ///
    /// # Panics
    ///
    /// Panics with the error's message whenever [`Tensor::try_sum`] would
    /// return an error.
    #[must_use]
    pub fn sum(&self) -> f64 {
        match self.try_sum() {
            Ok(total) => total,
            Err(e) => panic!("{e}"),
        }
    }

    /// Fallible counterpart of [`Tensor::sum`].
    ///
    /// # Errors
    ///
    /// Propagates the error from `reject_dynamic` (operation name `"sum"`).
    pub fn try_sum(&self) -> Result<f64, MattenError> {
        reject_dynamic(self, "sum")?;
        let total: f64 = self.data.iter().sum();
        Ok(total)
    }

    /// Arithmetic mean of every element.
    ///
    /// # Panics
    ///
    /// Panics with the error's message whenever [`Tensor::try_mean`] would
    /// return an error.
    #[must_use]
    pub fn mean(&self) -> f64 {
        match self.try_mean() {
            Ok(avg) => avg,
            Err(e) => panic!("{e}"),
        }
    }

    /// Fallible counterpart of [`Tensor::mean`].
    ///
    /// The sum is divided by the element count, so an empty tensor yields NaN
    /// (a zero divided by `0.0`).
    ///
    /// # Errors
    ///
    /// Propagates the error from `reject_dynamic` (operation name `"mean"`).
    pub fn try_mean(&self) -> Result<f64, MattenError> {
        reject_dynamic(self, "mean")?;
        let total: f64 = self.data.iter().sum();
        let count = self.data.len() as f64;
        Ok(total / count)
    }

    /// Smallest element of the tensor.
    ///
    /// # Panics
    ///
    /// Panics with the error's message whenever [`Tensor::try_min`] would
    /// return an error.
    #[must_use]
    pub fn min(&self) -> f64 {
        match self.try_min() {
            Ok(lowest) => lowest,
            Err(e) => panic!("{e}"),
        }
    }

    /// Fallible counterpart of [`Tensor::min`].
    ///
    /// Returns NaN if any element is NaN, and `f64::INFINITY` for an empty
    /// tensor (the fold's starting value).
    ///
    /// # Errors
    ///
    /// Propagates the error from `reject_dynamic` (operation name `"min"`).
    pub fn try_min(&self) -> Result<f64, MattenError> {
        reject_dynamic(self, "min")?;
        Ok(nan_reduce(&self.data, f64::INFINITY, f64::min))
    }

    /// Largest element of the tensor.
    ///
    /// # Panics
    ///
    /// Panics with the error's message whenever [`Tensor::try_max`] would
    /// return an error.
    #[must_use]
    pub fn max(&self) -> f64 {
        match self.try_max() {
            Ok(highest) => highest,
            Err(e) => panic!("{e}"),
        }
    }

    /// Fallible counterpart of [`Tensor::max`].
    ///
    /// Returns NaN if any element is NaN, and `f64::NEG_INFINITY` for an empty
    /// tensor (the fold's starting value).
    ///
    /// # Errors
    ///
    /// Propagates the error from `reject_dynamic` (operation name `"max"`).
    pub fn try_max(&self) -> Result<f64, MattenError> {
        reject_dynamic(self, "max")?;
        Ok(nan_reduce(&self.data, f64::NEG_INFINITY, f64::max))
    }
}
/// Refuses numeric operations on dynamic tensors.
///
/// With the `dynamic` feature enabled, a tensor for which `is_dynamic()` is
/// true produces an error naming `operation`. Without that feature every
/// tensor is accepted.
///
/// # Errors
///
/// Returns [`MattenError::Unsupported`] for a dynamic tensor (only possible
/// when the `dynamic` feature is enabled).
pub(crate) fn reject_dynamic(t: &Tensor, operation: &'static str) -> Result<(), MattenError> {
    #[cfg(feature = "dynamic")]
    {
        if t.is_dynamic() {
            let message = format!(
                "{operation} is not supported on dynamic tensors; call try_numeric() first"
            );
            return Err(MattenError::Unsupported { operation, message });
        }
    }
    #[cfg(not(feature = "dynamic"))]
    {
        // Nothing can be dynamic here; consume the arguments to silence
        // unused-variable warnings in this configuration.
        let _ = (t, operation);
    }
    Ok(())
}
/// Confirms that `axis` names an existing dimension of `t`.
///
/// # Errors
///
/// Returns [`MattenError::Shape`], tagged with `operation`, when `axis` is not
/// smaller than the tensor's rank.
pub(crate) fn check_axis(
    t: &Tensor,
    axis: usize,
    operation: &'static str,
) -> Result<(), MattenError> {
    let rank = t.shape().len();
    if axis < rank {
        return Ok(());
    }
    Err(MattenError::Shape {
        operation,
        message: format!("axis {axis} is out of range for a rank-{rank} tensor"),
    })
}
/// Folds `data` with `f` starting from `init`, but yields NaN as soon as a NaN
/// element is encountered.
///
/// This lets NaN-ignoring combiners such as `f64::min` / `f64::max` still
/// propagate NaN. `f` is never called on a NaN element or on anything after
/// it. An empty slice returns `init` unchanged.
fn nan_reduce(data: &[f64], init: f64, f: impl Fn(f64, f64) -> f64) -> f64 {
    data.iter()
        .try_fold(init, |acc, &v| if v.is_nan() { None } else { Some(f(acc, v)) })
        .unwrap_or(f64::NAN)
}
impl Tensor {
    /// Sums along `axis`, removing that dimension from the result's shape.
    ///
    /// # Panics
    ///
    /// Panics with the error's message whenever [`Tensor::try_sum_axis`]
    /// would return an error.
    #[must_use]
    pub fn sum_axis(&self, axis: usize) -> Tensor {
        match self.try_sum_axis(axis) {
            Ok(reduced) => reduced,
            Err(e) => panic!("{e}"),
        }
    }

    /// Fallible counterpart of [`Tensor::sum_axis`].
    ///
    /// # Errors
    ///
    /// Propagates errors from `reject_dynamic` and `check_axis` (operation
    /// name `"sum_axis"`).
    pub fn try_sum_axis(&self, axis: usize) -> Result<Tensor, MattenError> {
        reject_dynamic(self, "sum_axis")?;
        check_axis(self, axis, "sum_axis")?;
        let summed = axis_reduce(self, axis, "sum_axis", |acc, v| acc + v, 0.0);
        Ok(summed)
    }

    /// Averages along `axis`, removing that dimension from the result's shape.
    ///
    /// # Panics
    ///
    /// Panics with the error's message whenever [`Tensor::try_mean_axis`]
    /// would return an error.
    #[must_use]
    pub fn mean_axis(&self, axis: usize) -> Tensor {
        match self.try_mean_axis(axis) {
            Ok(reduced) => reduced,
            Err(e) => panic!("{e}"),
        }
    }

    /// Fallible counterpart of [`Tensor::mean_axis`].
    ///
    /// Each per-slot sum is divided by the length of `axis`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `reject_dynamic` and `check_axis` (operation
    /// name `"mean_axis"`).
    pub fn try_mean_axis(&self, axis: usize) -> Result<Tensor, MattenError> {
        reject_dynamic(self, "mean_axis")?;
        check_axis(self, axis, "mean_axis")?;
        let sums = axis_reduce(self, axis, "mean_axis", |acc, v| acc + v, 0.0);
        let count = self.shape()[axis] as f64;
        Ok(&sums / count)
    }
}
impl Tensor {
    /// Minimum along `axis`, removing that dimension from the result's shape.
    ///
    /// # Panics
    ///
    /// Panics with the error's message whenever [`Tensor::try_min_axis`]
    /// would return an error.
    #[must_use]
    pub fn min_axis(&self, axis: usize) -> Tensor {
        match self.try_min_axis(axis) {
            Ok(reduced) => reduced,
            Err(e) => panic!("{e}"),
        }
    }

    /// Fallible counterpart of [`Tensor::min_axis`].
    ///
    /// Any output slot that receives a NaN input becomes NaN.
    ///
    /// # Errors
    ///
    /// Propagates errors from `reject_dynamic` and `check_axis` (operation
    /// name `"min_axis"`).
    pub fn try_min_axis(&self, axis: usize) -> Result<Tensor, MattenError> {
        reject_dynamic(self, "min_axis")?;
        check_axis(self, axis, "min_axis")?;
        let lowest = nan_axis_reduce(self, axis, "min_axis", f64::INFINITY, f64::min);
        Ok(lowest)
    }

    /// Maximum along `axis`, removing that dimension from the result's shape.
    ///
    /// # Panics
    ///
    /// Panics with the error's message whenever [`Tensor::try_max_axis`]
    /// would return an error.
    #[must_use]
    pub fn max_axis(&self, axis: usize) -> Tensor {
        match self.try_max_axis(axis) {
            Ok(reduced) => reduced,
            Err(e) => panic!("{e}"),
        }
    }

    /// Fallible counterpart of [`Tensor::max_axis`].
    ///
    /// Any output slot that receives a NaN input becomes NaN.
    ///
    /// # Errors
    ///
    /// Propagates errors from `reject_dynamic` and `check_axis` (operation
    /// name `"max_axis"`).
    pub fn try_max_axis(&self, axis: usize) -> Result<Tensor, MattenError> {
        reject_dynamic(self, "max_axis")?;
        check_axis(self, axis, "max_axis")?;
        let highest = nan_axis_reduce(self, axis, "max_axis", f64::NEG_INFINITY, f64::max);
        Ok(highest)
    }
}
/// Reduces `t` along `axis` with `f`, forcing NaN into any output slot that
/// received a NaN input.
///
/// Each output slot starts at `identity` and is folded with `f` over the
/// non-NaN source values mapped to it. Slots that saw a NaN are overwritten
/// with NaN afterwards, so NaN-ignoring combiners (e.g. `f64::min`) still
/// propagate it. The output shape is `t`'s shape with `axis` removed; reducing
/// a rank-1 tensor gives an empty shape holding one element.
///
/// # Panics
///
/// Panics if `axis >= t.ndim()`.
fn nan_axis_reduce(
    t: &Tensor,
    axis: usize,
    operation: &'static str,
    identity: f64,
    f: impl Fn(f64, f64) -> f64,
) -> Tensor {
    let rank = t.ndim();
    if axis >= rank {
        panic!(
            "matten shape error in {operation}: axis {axis} is out of range for rank-{rank} tensor"
        );
    }
    let src_shape = t.shape();
    let mut out_shape: Vec<usize> = src_shape.to_vec();
    out_shape.remove(axis);
    // The product over an empty shape is 1: the single scalar slot.
    let out_len: usize = out_shape.iter().product();
    let mut out_data = vec![identity; out_len];
    let mut saw_nan = vec![false; out_len];
    for (src_flat, &value) in t.data.iter().enumerate() {
        let coord = flat_to_coord(src_flat, src_shape);
        let kept: Vec<usize> = coord
            .iter()
            .enumerate()
            .filter(|(dim, _)| *dim != axis)
            .map(|(_, c)| *c)
            .collect();
        let slot = if out_shape.is_empty() {
            0
        } else {
            coord_to_flat(&kept, &out_shape).expect("valid by construction")
        };
        if value.is_nan() {
            saw_nan[slot] = true;
        } else {
            out_data[slot] = f(out_data[slot], value);
        }
    }
    for (slot, _) in out_data
        .iter_mut()
        .zip(&saw_nan)
        .filter(|&(_, &nan)| nan)
    {
        *slot = f64::NAN;
    }
    Tensor {
        data: out_data,
        shape: out_shape,
        #[cfg(feature = "dynamic")]
        dynamic: None,
    }
}
/// Reduces `t` along `axis` by folding each output slot with `f`, starting
/// from `identity`.
///
/// NaN receives no special treatment here: it flows through `f` like any
/// other value. The output shape is `t`'s shape with `axis` removed; reducing
/// a rank-1 tensor gives an empty shape holding one element.
///
/// # Panics
///
/// Panics if `axis >= t.ndim()`.
fn axis_reduce(
    t: &Tensor,
    axis: usize,
    operation: &'static str,
    f: impl Fn(f64, f64) -> f64,
    identity: f64,
) -> Tensor {
    let rank = t.ndim();
    if axis >= rank {
        panic!(
            "matten shape error in {operation}: axis {axis} is out of range for rank-{rank} tensor"
        );
    }
    let src_shape = t.shape();
    let mut out_shape: Vec<usize> = src_shape.to_vec();
    out_shape.remove(axis);
    // The product over an empty shape is 1: the single scalar slot.
    let out_len: usize = out_shape.iter().product();
    let mut out_data = vec![identity; out_len];
    for (src_flat, &value) in t.data.iter().enumerate() {
        let coord = flat_to_coord(src_flat, src_shape);
        let kept: Vec<usize> = coord
            .iter()
            .enumerate()
            .filter(|(dim, _)| *dim != axis)
            .map(|(_, c)| *c)
            .collect();
        let slot = if out_shape.is_empty() {
            0
        } else {
            coord_to_flat(&kept, &out_shape).expect("valid by construction")
        };
        out_data[slot] = f(out_data[slot], value);
    }
    Tensor {
        data: out_data,
        shape: out_shape,
        #[cfg(feature = "dynamic")]
        dynamic: None,
    }
}
impl Tensor {
#[must_use]
pub fn dot(&self, rhs: &Tensor) -> Tensor {
#[cfg(feature = "dynamic")]
if self.is_dynamic() || rhs.is_dynamic() {
panic!(
"matten unsupported error in dot/matmul: not supported on dynamic tensors; call try_numeric() on each operand first"
);
}
matmul_dispatch(self, rhs, "dot")
}
#[must_use]
pub fn matmul(&self, rhs: &Tensor) -> Tensor {
self.dot(rhs)
}
}
/// Chooses the product kernel from the operands' ranks.
///
/// # Panics
///
/// Panics for any rank pair other than (1,1), (2,1), (1,2) or (2,2); the
/// chosen kernel may also panic on a dimension mismatch.
fn matmul_dispatch(lhs: &Tensor, rhs: &Tensor, op: &'static str) -> Tensor {
    let left_rank = lhs.ndim();
    let right_rank = rhs.ndim();
    match (left_rank, right_rank) {
        (2, 2) => mm_mul(lhs, rhs, op),
        (2, 1) => mv_mul(lhs, rhs, op),
        (1, 2) => vm_mul(lhs, rhs, op),
        (1, 1) => vv_dot(lhs, rhs, op),
        _ => panic!(
            "matten shape error in {op}: unsupported rank combination \
             (left rank {left_rank}, right rank {right_rank}); supported: [n]×[n], [m,n]×[n], [n]×[n,p], [m,n]×[n,p]"
        ),
    }
}
/// Inner product of two vectors, returned as a scalar tensor.
///
/// # Panics
///
/// Panics if the two lengths differ.
fn vv_dot(a: &Tensor, b: &Tensor, op: &'static str) -> Tensor {
    let (left, right) = (a.len(), b.len());
    assert!(
        left == right,
        "matten shape error in {op}: vector lengths must match (left {left}, right {right})"
    );
    let products = a.data.iter().zip(b.data.iter()).map(|(x, y)| x * y);
    Tensor::scalar(products.sum::<f64>())
}
/// Matrix–vector product: `[m, n] × [n] → [m]`.
///
/// Row `r` of `a` is `a.data[r * n..(r + 1) * n]`. Each output element is
/// that row's dot product with `b`, accumulated from `0.0` in column order.
///
/// # Panics
///
/// Panics if `a` is not rank 2 or its column count differs from `b.len()`.
fn mv_mul(a: &Tensor, b: &Tensor, op: &'static str) -> Tensor {
    let [rows, cols] = shape2(a, op);
    dim_check(cols, b.len(), "left columns", "right length", op);
    let out: Vec<f64> = (0..rows)
        .map(|r| {
            a.data[r * cols..(r + 1) * cols]
                .iter()
                .zip(&b.data)
                .fold(0.0f64, |acc, (x, y)| acc + x * y)
        })
        .collect();
    Tensor {
        data: out,
        shape: vec![rows],
        #[cfg(feature = "dynamic")]
        dynamic: None,
    }
}
/// Vector–matrix product: `[n] × [n, p] → [p]`.
///
/// Row `k` of `b` is scaled by `a[k]` and added into the output, so each
/// output element accumulates its terms from `0.0` in increasing `k`.
///
/// # Panics
///
/// Panics if `b` is not rank 2 or its row count differs from `a.len()`.
fn vm_mul(a: &Tensor, b: &Tensor, op: &'static str) -> Tensor {
    let [rows, cols] = shape2(b, op);
    dim_check(a.len(), rows, "left length", "right rows", op);
    let mut out = vec![0.0f64; cols];
    for (k, &weight) in a.data.iter().take(rows).enumerate() {
        let b_row = &b.data[k * cols..(k + 1) * cols];
        for (slot, &b_kj) in out.iter_mut().zip(b_row) {
            *slot += weight * b_kj;
        }
    }
    Tensor {
        data: out,
        shape: vec![cols],
        #[cfg(feature = "dynamic")]
        dynamic: None,
    }
}
/// Matrix–matrix product: `[m, n] × [n, p] → [m, p]`.
///
/// Operands are read row-major: element `(r, c)` of an `[rows, cols]` tensor
/// is `data[r * cols + c]`. Each output element `(i, j)` starts at `0.0` and
/// receives `a[i, k] * b[k, j]` for `k = 0..n` in increasing order.
///
/// Degenerate shapes are handled. If `m` or `p` is zero the result is an empty
/// `[m, p]` tensor; if only `n` is zero the result is all zeros.
///
/// # Panics
///
/// Panics if either operand is not rank 2 or if `a`'s column count differs
/// from `b`'s row count.
fn mm_mul(a: &Tensor, b: &Tensor, op: &'static str) -> Tensor {
    let [m, n] = shape2(a, op);
    let [nb, p] = shape2(b, op);
    dim_check(n, nb, "left columns", "right rows", op);
    let mut out = vec![0.0f64; m * p];
    // Explicit row slices instead of `out.chunks_mut(p)`: `chunks_mut` panics
    // when `p == 0`, which used to reject the valid product `[m, n] × [n, 0]`.
    // The i-k-j order walks `b` row by row (contiguous memory). Every
    // `out[i, j]` still accumulates its terms in increasing `k` from 0.0, so
    // the floating-point results are unchanged.
    for i in 0..m {
        let a_row = &a.data[i * n..(i + 1) * n];
        let out_row = &mut out[i * p..(i + 1) * p];
        for (k, &a_ik) in a_row.iter().enumerate() {
            let b_row = &b.data[k * p..(k + 1) * p];
            for (slot, &b_kj) in out_row.iter_mut().zip(b_row) {
                *slot += a_ik * b_kj;
            }
        }
    }
    Tensor {
        data: out,
        shape: vec![m, p],
        #[cfg(feature = "dynamic")]
        dynamic: None,
    }
}
/// Returns `t`'s shape as `[rows, cols]`.
///
/// # Panics
///
/// Panics if `t` is not rank 2.
fn shape2(t: &Tensor, op: &'static str) -> [usize; 2] {
    let dims = t.shape();
    if let &[rows, cols] = dims {
        return [rows, cols];
    }
    panic!("matten shape error in {op}: expected rank-2 tensor, got shape {dims:?}")
}
/// Asserts that two named dimensions agree.
///
/// # Panics
///
/// Panics with a message naming both dimensions and their values when
/// `left != right`.
fn dim_check(left: usize, right: usize, left_name: &str, right_name: &str, op: &'static str) {
    if left == right {
        return;
    }
    panic!("matten shape error in {op}: {left_name} ({left}) must equal {right_name} ({right})");
}
#[cfg(test)]
mod tests;