use crate::shape::{coord_to_flat, flat_to_coord};
use crate::{MattenError, Tensor};
impl Tensor {
    /// Adds up every element of the tensor, visiting storage in order.
    ///
    /// An empty tensor sums to zero.
    #[must_use]
    pub fn sum(&self) -> f64 {
        self.data.iter().copied().sum()
    }

    /// Arithmetic mean of all elements (`sum / element count`).
    ///
    /// An empty tensor has no elements to average, so the result is `0 / 0`, i.e. NaN.
    #[must_use]
    pub fn mean(&self) -> f64 {
        let count = self.data.len() as f64;
        self.sum() / count
    }

    /// Smallest element, or NaN if any element is NaN.
    ///
    /// An empty tensor yields `f64::INFINITY`, the fold's starting value.
    #[must_use]
    pub fn min(&self) -> f64 {
        // `f64::min` on its own ignores NaN operands; `nan_reduce` makes NaN propagate.
        nan_reduce(&self.data, f64::INFINITY, f64::min)
    }

    /// Largest element, or NaN if any element is NaN.
    ///
    /// An empty tensor yields `f64::NEG_INFINITY`, the fold's starting value.
    #[must_use]
    pub fn max(&self) -> f64 {
        // `f64::max` on its own ignores NaN operands; `nan_reduce` makes NaN propagate.
        nan_reduce(&self.data, f64::NEG_INFINITY, f64::max)
    }
}
/// Folds `data` with `f`, starting from `init`, but returns NaN at the first NaN element.
///
/// `f` is only ever called with the non-NaN elements that precede the first NaN (or with
/// all elements if there is none). An empty `data` returns `init` unchanged.
fn nan_reduce(data: &[f64], init: f64, f: impl Fn(f64, f64) -> f64) -> f64 {
    data.iter()
        .try_fold(init, |acc, &v| if v.is_nan() { None } else { Some(f(acc, v)) })
        .unwrap_or(f64::NAN)
}
impl Tensor {
    /// Sums the elements along `axis`, removing that dimension from the result.
    ///
    /// Reducing a rank-1 tensor yields a tensor with an empty shape holding the total.
    ///
    /// # Panics
    ///
    /// Panics with a `matten shape error in sum_axis` message if `axis >= self.ndim()`.
    #[must_use]
    pub fn sum_axis(&self, axis: usize) -> Tensor {
        axis_reduce(self, axis, "sum_axis", |acc, v| acc + v, 0.0)
    }

    /// Averages the elements along `axis`, removing that dimension from the result.
    ///
    /// If the reduced dimension has length 0, every sum is `0.0` and is divided by `0.0`.
    ///
    /// # Panics
    ///
    /// Panics with a `matten shape error in mean_axis` message if `axis >= self.ndim()`.
    #[must_use]
    pub fn mean_axis(&self, axis: usize) -> Tensor {
        // Reduce first: `axis_reduce` validates `axis` and reports a descriptive shape
        // error. Indexing `shape()[axis]` before that check would instead panic with a
        // bare index-out-of-bounds message for an invalid axis.
        let sums = axis_reduce(self, axis, "mean_axis", |acc, v| acc + v, 0.0);
        let n = self.shape()[axis] as f64;
        &sums / n
    }
}
/// Shared engine for the `*_axis` reductions. Collapses dimension `axis` of `t` by folding
/// every element into its output cell with `f`, where each cell starts at `identity`.
///
/// The output shape is `t.shape()` with position `axis` removed. Reducing a rank-1 tensor
/// therefore gives an empty shape holding a single value. The mapping between flat indices
/// and coordinates is delegated to `flat_to_coord` / `coord_to_flat`; this assumes both use
/// the same layout convention (TODO confirm in `crate::shape`).
///
/// # Panics
///
/// Panics with a `matten shape error in {operation}` message if `axis >= t.ndim()`.
fn axis_reduce(
    t: &Tensor,
    axis: usize,
    operation: &'static str,
    f: impl Fn(f64, f64) -> f64,
    identity: f64,
) -> Tensor {
    /// Collects `dims` into a `Vec`, leaving out the entry at position `skip`.
    fn without<'a>(dims: impl IntoIterator<Item = &'a usize>, skip: usize) -> Vec<usize> {
        dims.into_iter()
            .enumerate()
            .filter(|&(pos, _)| pos != skip)
            .map(|(_, &d)| d)
            .collect()
    }

    let rank = t.ndim();
    assert!(
        axis < rank,
        "matten shape error in {operation}: axis {axis} is out of range \
         for rank-{rank} tensor"
    );

    let src_shape = t.shape();
    let out_shape = without(src_shape, axis);
    // The product of an empty shape is 1, so a fully collapsed result still gets one cell.
    let out_len: usize = out_shape.iter().product();
    let collapses_to_scalar = out_shape.is_empty();

    let mut cells = vec![identity; out_len];
    for (src_flat, &value) in t.data.iter().enumerate() {
        let src_coord = flat_to_coord(src_flat, src_shape);
        let dst_flat = if collapses_to_scalar {
            0
        } else {
            coord_to_flat(&without(&src_coord, axis), &out_shape).expect("valid by construction")
        };
        let cell = &mut cells[dst_flat];
        *cell = f(*cell, value);
    }

    Tensor {
        data: cells,
        shape: out_shape,
        #[cfg(feature = "dynamic")]
        dynamic: None,
    }
}
impl Tensor {
    /// Product of `self` and `rhs`. It behaves exactly like [`Tensor::matmul`] but reports
    /// shape errors under the operation name `dot`.
    ///
    /// Supported rank pairs: `[n]·[n]` → scalar, `[m,n]·[n]` → `[m]`,
    /// `[n]·[n,p]` → `[p]`, `[m,n]·[n,p]` → `[m,p]`.
    ///
    /// # Panics
    ///
    /// Panics with a `matten shape error in dot` message for any other rank pair or when
    /// the contracted dimensions differ.
    #[must_use]
    pub fn dot(&self, rhs: &Self) -> Self {
        matmul_dispatch(self, rhs, "dot")
    }

    /// Matrix/vector product of `self` and `rhs`.
    ///
    /// Supported rank pairs: `[n]×[n]` → scalar, `[m,n]×[n]` → `[m]`,
    /// `[n]×[n,p]` → `[p]`, `[m,n]×[n,p]` → `[m,p]`.
    ///
    /// # Panics
    ///
    /// Panics with a `matten shape error in matmul` message for any other rank pair or
    /// when the contracted dimensions differ.
    #[must_use]
    pub fn matmul(&self, rhs: &Self) -> Self {
        matmul_dispatch(self, rhs, "matmul")
    }
}
fn matmul_dispatch(lhs: &Tensor, rhs: &Tensor, op: &'static str) -> Tensor {
match (lhs.ndim(), rhs.ndim()) {
(1, 1) => vv_dot(lhs, rhs, op),
(2, 1) => mv_mul(lhs, rhs, op),
(1, 2) => vm_mul(lhs, rhs, op),
(2, 2) => mm_mul(lhs, rhs, op),
_ => panic!(
"matten shape error in {op}: unsupported rank combination \
(left rank {}, right rank {}); supported: [n]×[n], [m,n]×[n], \
[n]×[n,p], [m,n]×[n,p]",
lhs.ndim(),
rhs.ndim()
),
}
}
/// Inner product of two rank-1 tensors, returned as a scalar tensor.
///
/// # Panics
///
/// Panics with a `matten shape error in {op}` message if the lengths differ.
fn vv_dot(a: &Tensor, b: &Tensor, op: &'static str) -> Tensor {
    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        panic!(
            "matten shape error in {op}: vector lengths must match \
             (left {len_a}, right {len_b})"
        );
    }
    let products = a.data.iter().zip(b.data.iter()).map(|(x, y)| x * y);
    Tensor::scalar(products.sum::<f64>())
}
/// Matrix × vector product: `[m, n] × [n] → [m]`.
///
/// # Panics
///
/// Panics if `a` is not rank 2 or if its column count differs from the length of `b`.
fn mv_mul(a: &Tensor, b: &Tensor, op: &'static str) -> Tensor {
    let [rows, cols] = shape2(a, op);
    dim_check(cols, b.len(), "left columns", "right length", op);
    // Entry r is the dot product of matrix row r with the vector, accumulated left to
    // right starting from 0.0.
    let data: Vec<f64> = (0..rows)
        .map(|r| {
            a.data[r * cols..(r + 1) * cols]
                .iter()
                .zip(&b.data)
                .fold(0.0f64, |acc, (x, y)| acc + x * y)
        })
        .collect();
    Tensor {
        data,
        shape: vec![rows],
        #[cfg(feature = "dynamic")]
        dynamic: None,
    }
}
/// Vector × matrix product: `[n] × [n, p] → [p]`.
///
/// Row `k` of the matrix is scaled by `a[k]` and added into the output, so entry `j`
/// accumulates `a[k] * b[k, j]` in increasing `k`, starting from `0.0`.
///
/// # Panics
///
/// Panics if `b` is not rank 2 or if the length of `a` differs from the matrix row count.
fn vm_mul(a: &Tensor, b: &Tensor, op: &'static str) -> Tensor {
    let [rows, cols] = shape2(b, op);
    dim_check(a.len(), rows, "left length", "right rows", op);
    let mut data = vec![0.0f64; cols];
    for k in 0..rows {
        let weight = a.data[k];
        let b_row = &b.data[k * cols..(k + 1) * cols];
        for (slot, &entry) in data.iter_mut().zip(b_row) {
            *slot += weight * entry;
        }
    }
    Tensor {
        data,
        shape: vec![cols],
        #[cfg(feature = "dynamic")]
        dynamic: None,
    }
}
/// Matrix × matrix product: `[m, n] × [n, p] → [m, p]`.
///
/// Each output entry `(i, j)` accumulates `a[i, k] * b[k, j]` in increasing `k`, starting
/// from `0.0`. Products with an empty dimension (`m`, `n` or `p` equal to 0) are valid and
/// yield a zero-filled or empty result.
///
/// # Panics
///
/// Panics if either operand is not rank 2 or if the left column count differs from the
/// right row count.
fn mm_mul(a: &Tensor, b: &Tensor, op: &'static str) -> Tensor {
    let [m, n] = shape2(a, op);
    let [nb, p] = shape2(b, op);
    dim_check(n, nb, "left columns", "right rows", op);
    let mut out = vec![0.0f64; m * p];
    // `chunks_exact_mut(0)` panics, so a result with zero columns (`[m, 0]`) must skip the
    // loop. Its data is simply the empty `out`.
    if p > 0 {
        for (i, row) in out.chunks_exact_mut(p).enumerate() {
            let a_row = &a.data[i * n..(i + 1) * n];
            for (j, slot) in row.iter_mut().enumerate() {
                let mut acc = 0.0f64;
                for (k, &x) in a_row.iter().enumerate() {
                    acc += x * b.data[k * p + j];
                }
                *slot = acc;
            }
        }
    }
    Tensor {
        data: out,
        shape: vec![m, p],
        #[cfg(feature = "dynamic")]
        dynamic: None,
    }
}
/// Destructures a rank-2 tensor's shape into `[rows, cols]`.
///
/// # Panics
///
/// Panics with a `matten shape error in {op}` message when `t` is not rank 2.
fn shape2(t: &Tensor, op: &'static str) -> [usize; 2] {
    let dims = t.shape();
    if let [rows, cols] = dims {
        return [*rows, *cols];
    }
    panic!("matten shape error in {op}: expected rank-2 tensor, got shape {dims:?}")
}
/// Asserts that two dimension sizes agree.
///
/// `left_name` / `right_name` describe the sizes in the panic message, and `op` names the
/// public operation being checked.
///
/// # Panics
///
/// Panics with a `matten shape error in {op}` message when `left != right`.
fn dim_check(left: usize, right: usize, left_name: &str, right_name: &str, op: &'static str) {
    // Matching sizes are the common case; return before any message is built.
    if left == right {
        return;
    }
    panic!(
        "matten shape error in {op}: {left_name} ({left}) \
         must equal {right_name} ({right})"
    );
}
/// Fallible counterpart of [`Tensor::matmul`] / [`Tensor::dot`]. It checks the operand
/// shapes up front and returns an error instead of panicking.
///
/// # Errors
///
/// Returns `MattenError::Shape`, tagged with `op`, when:
/// - the rank pair is not one of `[n]×[n]`, `[m,n]×[n]`, `[n]×[n,p]`, `[m,n]×[n,p]`, or
/// - the contracted dimensions disagree. The kernels would otherwise panic on this case.
// `dead_code` is allowed because no caller is visible in this module; presumably it is only
// reached from some crate configurations — TODO confirm.
#[allow(dead_code)]
pub(crate) fn try_matmul(
    lhs: &Tensor,
    rhs: &Tensor,
    op: &'static str,
) -> Result<Tensor, MattenError> {
    /// Describes a contracted-dimension mismatch, or returns `None` when the sizes agree.
    fn inner_mismatch(
        left: usize,
        right: usize,
        left_name: &str,
        right_name: &str,
    ) -> Option<String> {
        (left != right).then(|| format!("{left_name} ({left}) must equal {right_name} ({right})"))
    }

    // Mirror every check the panicking kernels perform, so `matmul_dispatch` below can no
    // longer panic on shape grounds.
    let problem = match (lhs.ndim(), rhs.ndim()) {
        (1, 1) => (lhs.len() != rhs.len()).then(|| {
            format!(
                "vector lengths must match (left {}, right {})",
                lhs.len(),
                rhs.len()
            )
        }),
        (2, 1) => inner_mismatch(lhs.shape()[1], rhs.len(), "left columns", "right length"),
        (1, 2) => inner_mismatch(lhs.len(), rhs.shape()[0], "left length", "right rows"),
        (2, 2) => inner_mismatch(lhs.shape()[1], rhs.shape()[0], "left columns", "right rows"),
        (left_rank, right_rank) => Some(format!(
            "unsupported rank combination (left rank {left_rank}, right rank {right_rank})"
        )),
    };
    match problem {
        Some(message) => Err(MattenError::Shape {
            operation: op,
            message,
        }),
        None => Ok(matmul_dispatch(lhs, rhs, op)),
    }
}