use crate::Tensor;
use crate::shape::{coord_to_flat, flat_to_coord};
impl Tensor {
    /// Guard shared by the reductions: panics when `self.is_dynamic()` is true.
    ///
    /// `operation` is the public method name and is embedded in the panic
    /// message so the caller can see which reduction was rejected.
    #[cfg(feature = "dynamic")]
    #[inline(always)]
    fn require_numeric(&self, operation: &'static str) {
        assert!(
            !self.is_dynamic(),
            "matten unsupported error in {operation}: this reduction is not supported on dynamic tensors; call try_numeric() first to convert"
        );
    }

    /// Sum of every element, accumulated in storage order.
    ///
    /// # Panics
    ///
    /// With the `dynamic` feature, panics if the tensor is dynamic.
    #[must_use]
    pub fn sum(&self) -> f64 {
        #[cfg(feature = "dynamic")]
        self.require_numeric("sum");
        self.data.iter().copied().sum::<f64>()
    }

    /// Arithmetic mean of every element (`sum / len`).
    ///
    /// An empty tensor yields NaN, because the division is 0 / 0.
    ///
    /// # Panics
    ///
    /// With the `dynamic` feature, panics if the tensor is dynamic.
    #[must_use]
    pub fn mean(&self) -> f64 {
        #[cfg(feature = "dynamic")]
        self.require_numeric("mean");
        let total = self.sum();
        let count = self.data.len() as f64;
        total / count
    }

    /// Smallest element.
    ///
    /// Returns NaN if any element is NaN, and `f64::INFINITY` for an empty
    /// tensor.
    ///
    /// # Panics
    ///
    /// With the `dynamic` feature, panics if the tensor is dynamic.
    #[must_use]
    pub fn min(&self) -> f64 {
        #[cfg(feature = "dynamic")]
        self.require_numeric("min");
        nan_reduce(&self.data, f64::INFINITY, f64::min)
    }

    /// Largest element.
    ///
    /// Returns NaN if any element is NaN, and `f64::NEG_INFINITY` for an
    /// empty tensor.
    ///
    /// # Panics
    ///
    /// With the `dynamic` feature, panics if the tensor is dynamic.
    #[must_use]
    pub fn max(&self) -> f64 {
        #[cfg(feature = "dynamic")]
        self.require_numeric("max");
        nan_reduce(&self.data, f64::NEG_INFINITY, f64::max)
    }
}
/// Folds `data` with `f`, starting from `init`, but short-circuits to NaN.
///
/// As soon as a NaN element is encountered, the fold stops and `f64::NAN` is
/// returned. Otherwise the result is `f` applied left-to-right over every
/// element. Empty input returns `init` unchanged.
fn nan_reduce(data: &[f64], init: f64, f: impl Fn(f64, f64) -> f64) -> f64 {
    data.iter()
        .try_fold(init, |acc, &v| if v.is_nan() { None } else { Some(f(acc, v)) })
        .unwrap_or(f64::NAN)
}
impl Tensor {
    /// Sums along `axis`, returning a tensor with that axis removed.
    ///
    /// # Panics
    ///
    /// Panics if `axis >= self.ndim()`. With the `dynamic` feature, also
    /// panics if the tensor is dynamic.
    #[must_use]
    pub fn sum_axis(&self, axis: usize) -> Tensor {
        #[cfg(feature = "dynamic")]
        self.require_numeric("sum_axis");
        axis_reduce(self, axis, "sum_axis", |total, x| total + x, 0.0)
    }

    /// Mean along `axis`: the per-slot sum divided by the extent of `axis`.
    ///
    /// A zero-length axis produces NaN entries (0 / 0).
    ///
    /// # Panics
    ///
    /// Panics if `axis >= self.ndim()`. With the `dynamic` feature, also
    /// panics if the tensor is dynamic.
    #[must_use]
    pub fn mean_axis(&self, axis: usize) -> Tensor {
        #[cfg(feature = "dynamic")]
        self.require_numeric("mean_axis");
        // `axis_reduce` validates `axis` first, panicking with the
        // "mean_axis" shape error. This ensures the extent lookup below never
        // indexes out of range.
        let sums = axis_reduce(self, axis, "mean_axis", |total, x| total + x, 0.0);
        let count = self.shape()[axis] as f64;
        &sums / count
    }
}
impl Tensor {
    /// Minimum along `axis`, returning a tensor with that axis removed.
    ///
    /// A reduced slice containing NaN yields NaN. An empty slice yields
    /// `f64::INFINITY`.
    ///
    /// # Panics
    ///
    /// Panics if `axis >= self.ndim()`. With the `dynamic` feature, also
    /// panics if the tensor is dynamic.
    #[must_use]
    pub fn min_axis(&self, axis: usize) -> Tensor {
        #[cfg(feature = "dynamic")]
        self.require_numeric("min_axis");
        nan_axis_reduce(self, axis, "min_axis", f64::INFINITY, f64::min)
    }

    /// Maximum along `axis`, returning a tensor with that axis removed.
    ///
    /// A reduced slice containing NaN yields NaN. An empty slice yields
    /// `f64::NEG_INFINITY`.
    ///
    /// # Panics
    ///
    /// Panics if `axis >= self.ndim()`. With the `dynamic` feature, also
    /// panics if the tensor is dynamic.
    #[must_use]
    pub fn max_axis(&self, axis: usize) -> Tensor {
        #[cfg(feature = "dynamic")]
        self.require_numeric("max_axis");
        nan_axis_reduce(self, axis, "max_axis", f64::NEG_INFINITY, f64::max)
    }
}
fn nan_axis_reduce(
t: &Tensor,
axis: usize,
operation: &'static str,
identity: f64,
f: impl Fn(f64, f64) -> f64,
) -> Tensor {
if axis >= t.ndim() {
panic!(
"matten shape error in {operation}: axis {axis} is out of range for rank-{} tensor",
t.ndim()
);
}
let src_shape = t.shape();
let out_shape: Vec<usize> = src_shape
.iter()
.enumerate()
.filter(|&(i, _)| i != axis)
.map(|(_, &d)| d)
.collect();
let out_len: usize = if out_shape.is_empty() {
1
} else {
out_shape.iter().product()
};
let mut out_data = vec![identity; out_len];
let mut has_nan = vec![false; out_len];
for (src_flat, &val) in t.data.iter().enumerate() {
let src_coord = flat_to_coord(src_flat, src_shape);
let out_coord: Vec<usize> = src_coord
.iter()
.enumerate()
.filter(|&(i, _)| i != axis)
.map(|(_, &c)| c)
.collect();
let dst_flat = if out_shape.is_empty() {
0
} else {
coord_to_flat(&out_coord, &out_shape).expect("valid by construction")
};
if val.is_nan() {
has_nan[dst_flat] = true;
} else {
out_data[dst_flat] = f(out_data[dst_flat], val);
}
}
for (i, &nan) in has_nan.iter().enumerate() {
if nan {
out_data[i] = f64::NAN;
}
}
Tensor {
data: out_data,
shape: out_shape,
#[cfg(feature = "dynamic")]
dynamic: None,
}
}
/// Folds every element of `t` into the output slot whose coordinate is the
/// element's coordinate with `axis` removed.
///
/// Each slot starts at `identity`, and elements are applied with `f` in
/// flat-index order. The result has `t`'s shape without `axis`, which gives
/// an empty shape (one slot) when `t` is rank 1.
///
/// # Panics
///
/// Panics with a shape error naming `operation` if `axis >= t.ndim()`.
fn axis_reduce(
    t: &Tensor,
    axis: usize,
    operation: &'static str,
    f: impl Fn(f64, f64) -> f64,
    identity: f64,
) -> Tensor {
    let rank = t.ndim();
    if axis >= rank {
        panic!(
            "matten shape error in {operation}: axis {axis} is out of range \
             for rank-{rank} tensor"
        );
    }

    let src_shape = t.shape();
    let mut out_shape: Vec<usize> = Vec::with_capacity(rank - 1);
    for (dim, &extent) in src_shape.iter().enumerate() {
        if dim != axis {
            out_shape.push(extent);
        }
    }

    // The empty product is 1, so a rank-0 output automatically gets one slot.
    let out_len: usize = out_shape.iter().product();
    let mut out_data = vec![identity; out_len];

    for (src_flat, &val) in t.data.iter().enumerate() {
        let src_coord = flat_to_coord(src_flat, src_shape);
        let out_coord: Vec<usize> = src_coord
            .iter()
            .enumerate()
            .filter_map(|(dim, &c)| if dim == axis { None } else { Some(c) })
            .collect();
        let dst_flat = match out_shape.is_empty() {
            true => 0,
            false => coord_to_flat(&out_coord, &out_shape).expect("valid by construction"),
        };
        let slot = &mut out_data[dst_flat];
        *slot = f(*slot, val);
    }

    Tensor {
        data: out_data,
        shape: out_shape,
        #[cfg(feature = "dynamic")]
        dynamic: None,
    }
}
impl Tensor {
    /// Generalized dot product.
    ///
    /// Supported rank pairs:
    /// - `[n]·[n]` produces a scalar;
    /// - `[m,n]·[n]` produces an `[m]` tensor;
    /// - `[n]·[n,p]` produces a `[p]` tensor;
    /// - `[m,n]·[n,p]` produces an `[m,p]` tensor.
    ///
    /// # Panics
    ///
    /// Panics on an unsupported rank combination or mismatched inner
    /// dimensions. With the `dynamic` feature, also panics if either operand
    /// is dynamic.
    #[must_use]
    pub fn dot(&self, rhs: &Tensor) -> Tensor {
        #[cfg(feature = "dynamic")]
        self.reject_dynamic_matmul_operands(rhs);
        matmul_dispatch(self, rhs, "dot")
    }

    /// Matrix product, with the same semantics as [`Tensor::dot`].
    ///
    /// Shape errors report the operation as "matmul".
    ///
    /// # Panics
    ///
    /// Panics on an unsupported rank combination or mismatched inner
    /// dimensions. With the `dynamic` feature, also panics if either operand
    /// is dynamic.
    #[must_use]
    pub fn matmul(&self, rhs: &Tensor) -> Tensor {
        // Previously `matmul` skipped this guard even though the shared error
        // message names both "dot/matmul". Dynamic operands must be rejected
        // here too.
        #[cfg(feature = "dynamic")]
        self.reject_dynamic_matmul_operands(rhs);
        matmul_dispatch(self, rhs, "matmul")
    }

    /// Shared guard for `dot`/`matmul`: panics if either operand reports
    /// `is_dynamic()`.
    #[cfg(feature = "dynamic")]
    fn reject_dynamic_matmul_operands(&self, rhs: &Tensor) {
        if self.is_dynamic() || rhs.is_dynamic() {
            panic!(
                "matten unsupported error in dot/matmul: not supported on dynamic \
                 tensors; call try_numeric() on each operand first"
            );
        }
    }
}
/// Routes a product to the kernel matching the operand ranks.
///
/// `op` is the public method name, used in error messages.
///
/// # Panics
///
/// Panics for any rank pair other than (1,1), (2,1), (1,2) or (2,2).
fn matmul_dispatch(lhs: &Tensor, rhs: &Tensor, op: &'static str) -> Tensor {
    let (left_rank, right_rank) = (lhs.ndim(), rhs.ndim());
    match (left_rank, right_rank) {
        (2, 2) => mm_mul(lhs, rhs, op),
        (2, 1) => mv_mul(lhs, rhs, op),
        (1, 2) => vm_mul(lhs, rhs, op),
        (1, 1) => vv_dot(lhs, rhs, op),
        _ => panic!(
            "matten shape error in {op}: unsupported rank combination \
             (left rank {left_rank}, right rank {right_rank}); \
             supported: [n]×[n], [m,n]×[n], [n]×[n,p], [m,n]×[n,p]"
        ),
    }
}
/// Inner product of two rank-1 tensors, returned as a scalar tensor.
///
/// # Panics
///
/// Panics if the two lengths differ.
fn vv_dot(a: &Tensor, b: &Tensor, op: &'static str) -> Tensor {
    let (left_len, right_len) = (a.len(), b.len());
    assert!(
        left_len == right_len,
        "matten shape error in {op}: vector lengths must match \
         (left {left_len}, right {right_len})"
    );
    let products = a.data.iter().zip(b.data.iter()).map(|(&x, &y)| x * y);
    Tensor::scalar(products.sum::<f64>())
}
/// Matrix–vector product of an `[m, n]` tensor and a length-`n` vector.
///
/// Produces a tensor of shape `[m]`. The matrix is indexed row-major
/// (`a.data[row * n + col]`).
///
/// # Panics
///
/// Panics if `a` is not rank 2 or if `n` differs from `b.len()`.
fn mv_mul(a: &Tensor, b: &Tensor, op: &'static str) -> Tensor {
    let [rows, cols] = shape2(a, op);
    dim_check(cols, b.len(), "left columns", "right length", op);
    let out: Vec<f64> = (0..rows)
        .map(|r| (0..cols).fold(0.0f64, |acc, c| acc + a.data[r * cols + c] * b.data[c]))
        .collect();
    Tensor {
        data: out,
        shape: vec![rows],
        #[cfg(feature = "dynamic")]
        dynamic: None,
    }
}
/// Vector–matrix product of a length-`n` vector and an `[n, p]` tensor.
///
/// Produces a tensor of shape `[p]`. The matrix is indexed row-major
/// (`b.data[row * p + col]`).
///
/// # Panics
///
/// Panics if `b` is not rank 2 or if `a.len()` differs from `n`.
fn vm_mul(a: &Tensor, b: &Tensor, op: &'static str) -> Tensor {
    let [rows, cols] = shape2(b, op);
    dim_check(a.len(), rows, "left length", "right rows", op);
    // Each output column accumulates its terms in increasing row order,
    // starting from 0.0.
    let out: Vec<f64> = (0..cols)
        .map(|j| (0..rows).fold(0.0f64, |acc, k| acc + a.data[k] * b.data[k * cols + j]))
        .collect();
    Tensor {
        data: out,
        shape: vec![cols],
        #[cfg(feature = "dynamic")]
        dynamic: None,
    }
}
/// Matrix–matrix product of an `[m, n]` tensor and an `[n, p]` tensor.
///
/// Produces a tensor of shape `[m, p]`. Both operands are indexed row-major
/// (`a.data[i * n + k]`, `b.data[k * p + j]`). Zero-sized dimensions are
/// accepted: any of `m`, `n` or `p` may be 0, and with `n == 0` every output
/// entry is 0.0.
///
/// # Panics
///
/// Panics if either operand is not rank 2, or if the left operand's column
/// count differs from the right operand's row count.
fn mm_mul(a: &Tensor, b: &Tensor, op: &'static str) -> Tensor {
    let [m, n] = shape2(a, op);
    let [nb, p] = shape2(b, op);
    dim_check(n, nb, "left columns", "right rows", op);
    let mut out = vec![0.0f64; m * p];
    // `chunks_exact_mut` (like `chunks_mut`) panics on a chunk size of 0. A
    // product with zero output columns must therefore skip the loop; `out` is
    // already the correct empty buffer in that case.
    if p > 0 {
        for (i, row) in out.chunks_exact_mut(p).enumerate() {
            let a_row = &a.data[i * n..(i + 1) * n];
            for (j, slot) in row.iter_mut().enumerate() {
                let mut acc = 0.0f64;
                for (k, &x) in a_row.iter().enumerate() {
                    acc += x * b.data[k * p + j];
                }
                *slot = acc;
            }
        }
    }
    Tensor {
        data: out,
        shape: vec![m, p],
        #[cfg(feature = "dynamic")]
        dynamic: None,
    }
}
/// Destructures a rank-2 tensor's shape into `[rows, cols]`.
///
/// # Panics
///
/// Panics with a shape error naming `op` if `t` is not rank 2.
fn shape2(t: &Tensor, op: &'static str) -> [usize; 2] {
    let dims = t.shape();
    if let &[rows, cols] = dims {
        [rows, cols]
    } else {
        panic!("matten shape error in {op}: expected rank-2 tensor, got shape {dims:?}")
    }
}
/// Asserts that two dimensions agree.
///
/// `left_name` and `right_name` describe the dimensions in the panic message.
///
/// # Panics
///
/// Panics with a shape error naming `op` when `left != right`.
fn dim_check(left: usize, right: usize, left_name: &str, right_name: &str, op: &'static str) {
    assert!(
        left == right,
        "matten shape error in {op}: {left_name} ({left}) \
         must equal {right_name} ({right})"
    );
}