use crate::vector_op as vector;
use crate::{Float, Num};
/// Builds a matrix of `RC` elements, every one of them `V::zero()`.
pub fn zero<V: Num, const RC: usize>() -> [V; RC] {
    let origin = V::zero();
    [origin; RC]
}
/// Zeroes the matrix `m` in place.
///
/// A matrix is stored as a flat slice, so this forwards straight to
/// `vector::set_zero`.
/// NOTE(review): the element-wise behaviour lives in `vector_op`, which is
/// not visible from this module - confirm there.
pub fn set_zero<V: Num>(m: &mut [V]) {
vector::set_zero(m)
}
/// Reports whether the matrix `m` is zero, as decided by `vector::is_zero`.
///
/// NOTE(review): presumably true only when every element equals zero; the
/// exact comparison is defined in `vector_op` - confirm there.
pub fn is_zero<V: Num>(m: &[V]) -> bool {
vector::is_zero(m)
}
/// Scales the matrix `m` by `s` and returns the result.
///
/// Forwards to `vector::scale`.
/// NOTE(review): presumably multiplies every element by `s` - confirm
/// against `vector_op`.
pub fn scale<V: Num, const RC: usize>(m: [V; RC], s: V) -> [V; RC] {
vector::scale(m, s)
}
/// Reduces the matrix `m` by `s` and returns the result.
///
/// Forwards to `vector::reduce`.
/// NOTE(review): presumably divides every element by `s` (the inverse of
/// `scale`) - confirm against `vector_op`, including what happens when `s`
/// is zero.
pub fn reduce<V: Num, const RC: usize>(m: [V; RC], s: V) -> [V; RC] {
vector::reduce(m, s)
}
/// Combines `other`, weighted by `scale`, into `m` and returns the result.
///
/// Forwards to `vector::add`.
/// NOTE(review): presumably computes `m + other * scale` element-wise -
/// confirm against `vector_op`.
pub fn add<V: Num, const RC: usize>(m: [V; RC], other: &[V; RC], scale: V) -> [V; RC] {
vector::add(m, other, scale)
}
/// Removes `other`, weighted by `scale`, from `m` and returns the result.
///
/// Forwards to `vector::sub`.
/// NOTE(review): presumably computes `m - other * scale` element-wise -
/// confirm against `vector_op`.
pub fn sub<V: Num, const RC: usize>(m: [V; RC], other: &[V; RC], scale: V) -> [V; RC] {
vector::sub(m, other, scale)
}
/// Returns the largest absolute value found among the elements of `m`.
///
/// The running maximum starts at `V::zero()`, so an empty slice yields zero.
pub fn absmax<V: Float>(m: &[V]) -> V {
    let mut largest = V::zero();
    for &element in m {
        largest = V::max(largest, V::abs(element));
    }
    largest
}
/// Normalizes `m` relative to its largest absolute element.
///
/// When that magnitude (see `absmax`) is below `V::epsilon()` the matrix is
/// zeroed and returned; otherwise it is handed to `reduce` together with the
/// magnitude.
/// NOTE(review): `reduce` presumably divides each element by the magnitude,
/// leaving the largest element with absolute value one - confirm.
///
/// The const parameter `C` is not used by the body.
pub fn normalize<V: Float, const RC: usize, const C: usize>(mut m: [V; RC]) -> [V; RC] {
    let magnitude = absmax::<V>(&m);
    if magnitude < V::epsilon() {
        // Too small to divide by safely: collapse to the zero matrix instead.
        set_zero::<V>(&mut m);
        return m;
    }
    reduce::<V, RC>(m, magnitude)
}
/// Transposes the row-major `R`-by-`C` matrix `m`.
///
/// The returned array holds the `C`-by-`R` transpose, also row-major.
///
/// # Panics
///
/// Panics if `RC != R * C`.
pub fn transpose<V: Num, const RC: usize, const R: usize, const C: usize>(m: [V; RC]) -> [V; RC] {
    assert_eq!(RC, R * C);
    let mut out = zero::<V, RC>();
    for col in 0..C {
        // Column `col` of the source becomes row `col` of the result. Walking
        // the source with stride `C` visits that column top to bottom.
        let dst_row = &mut out[col * R..(col + 1) * R];
        let src_col = m.iter().skip(col).step_by(C);
        for (slot, &value) in dst_row.iter_mut().zip(src_col) {
            *slot = value;
        }
    }
    out
}
/// Multiplies the row-major `R`-by-`X` matrix `a` by the row-major
/// `X`-by-`C` matrix `b`, returning the row-major `R`-by-`C` product.
///
/// The flattened lengths `RX`, `XC` and `RC` are separate const parameters
/// and are checked against the dimensions at run time.
///
/// # Panics
///
/// Panics if `RX != R * X`, `RC != R * C` or `XC != X * C`. Because of
/// `#[track_caller]` the panic is attributed to the calling site.
#[track_caller]
pub fn multiply<
    V: Num,
    const RX: usize,
    const XC: usize,
    const RC: usize,
    const R: usize,
    const X: usize,
    const C: usize,
>(
    a: &[V; RX],
    b: &[V; XC],
) -> [V; RC] {
    assert_eq!(RX, R * X, "RX must equal R*X");
    assert_eq!(RC, R * C, "RC must equal R*C");
    assert_eq!(XC, X * C, "XC must equal X*C");
    let mut product = [V::zero(); RC];
    for row in 0..R {
        let lhs_row = &a[row * X..(row + 1) * X];
        for col in 0..C {
            // Stride `C` through `b` to walk column `col` from top to bottom.
            let rhs_col = b.iter().skip(col).step_by(C);
            product[row * C + col] = lhs_row
                .iter()
                .zip(rhs_col)
                .fold(V::zero(), |acc, (&lhs, &rhs)| acc + lhs * rhs);
        }
    }
    product
}
/// Transposes a row-major `r`-by-`c` matrix whose size is only known at run
/// time.
///
/// `m` is the source and `m_t` receives the `c`-by-`r` transpose, also
/// row-major. Every element of `m_t` is overwritten.
///
/// # Panics
///
/// Panics unless both `m` and `m_t` hold exactly `r * c` elements.
pub fn transpose_dyn<V: Copy>(r: usize, c: usize, m: &[V], m_t: &mut [V]) {
    assert!(m.len() == r * c);
    assert!(m_t.len() == r * c);
    for ci in 0..c {
        // Source column `ci` (stride `c` through `m`) becomes destination
        // row `ci`, which is a contiguous run of `r` elements in `m_t`.
        let dst_row = &mut m_t[ci * r..(ci + 1) * r];
        let src_col = m.iter().skip(ci).step_by(c);
        for (slot, &value) in dst_row.iter_mut().zip(src_col) {
            *slot = value;
        }
    }
}
/// Multiplies two row-major matrices whose sizes are only known at run time.
///
/// `a` is read as an `r`-by-`x` matrix and `b` as an `x`-by-`c` matrix; the
/// `r`-by-`c` product is written to the first `r * c` elements of `m`.
/// Slices may be longer than required: surplus elements of `a` and `b` are
/// ignored and surplus elements of `m` are left untouched.
///
/// # Panics
///
/// Panics if `a` is shorter than `r * x`, `b` is shorter than `x * c`, or
/// `m` is shorter than `r * c`. Each message names the operand that failed.
pub fn multiply_dyn<V: Num>(r: usize, x: usize, c: usize, a: &[V], b: &[V], m: &mut [V]) {
    assert!(
        a.len() >= r * x,
        "Expected a (length {}) to be r*x = {} * {}",
        a.len(),
        r,
        x
    );
    // The two checks below used to repeat the message for `a`, so a failure
    // reported the wrong slice and the wrong dimensions.
    assert!(
        b.len() >= x * c,
        "Expected b (length {}) to be x*c = {} * {}",
        b.len(),
        x,
        c
    );
    assert!(
        m.len() >= r * c,
        "Expected m (length {}) to be r*c = {} * {}",
        m.len(),
        r,
        c
    );
    for ri in 0..r {
        for ci in 0..c {
            // Dot product of row `ri` of `a` with column `ci` of `b`.
            let mut v = V::zero();
            for xi in 0..x {
                v = v + a[ri * x + xi] * b[xi * c + ci];
            }
            m[ri * c + ci] = v;
        }
    }
}
/// Applies the row-major `R`-by-`D` matrix `m` to the vector `v`.
///
/// Returns the `R`-element product, computed by `multiply`.
///
/// # Panics
///
/// Panics (inside `multiply`) if `RD != R * D`.
pub fn transform_vec<V: Num, const RD: usize, const R: usize, const D: usize>(
m: &[V; RD],
v: &[V; D],
) -> [V; R] {
// `v` is treated as a D-by-1 matrix, so the flattened lengths passed to
// `multiply` are RX = RD, XC = D and RC = R, with dimensions (R, D, 1).
multiply::<V, RD, D, R, R, D, 1>(m, v)
}
/// Writes the elements of `v` to `f` as a bracketed, row-major matrix with
/// `C` elements per row.
///
/// Elements within a row are separated by `,` and a single space precedes
/// the first element of every row after the first, e.g. `[1,2 3,4]` for
/// `C == 2`. With `C == 0` the column counter never wraps, so all elements
/// are written comma-separated as one row.
///
/// An empty slice is written as `[]`. Previously the opening bracket was
/// only emitted together with the first element, so an empty slice produced
/// a lone `]`.
///
/// # Errors
///
/// Returns any error reported by the underlying formatter.
pub fn fmt<V: Num, const C: usize>(f: &mut std::fmt::Formatter, v: &[V]) -> std::fmt::Result {
    // Emit the opening bracket unconditionally so the output stays balanced
    // even when there are no elements.
    write!(f, "[")?;
    // Column of the element about to be written; wraps back to 0 after `C`.
    let mut column = 0;
    for (i, value) in v.iter().enumerate() {
        if i == 0 {
            write!(f, "{value}")?;
        } else if column == 0 {
            // First element of a new row.
            write!(f, " {value}")?;
        } else {
            write!(f, ",{value}")?;
        }
        column += 1;
        if column == C {
            column = 0;
        }
    }
    write!(f, "]")
}
/// Borrowed view of a flat matrix of `RC` elements, carrying the column
/// count `C` as a const parameter.
///
/// The `Display` implementation for this type prints the elements with `C`
/// per row.
#[derive(Debug)]
pub struct MatrixType<'a, V: Num, const RC: usize, const C: usize> {
/// The borrowed matrix elements.
m: &'a [V; RC],
}
impl<'a, V: Num, const RC: usize, const C: usize> MatrixType<'a, V, RC, C> {
pub fn new(m: &'a [V; RC]) -> Self {
Self { m }
}
}
impl<'a, V: Num, const RC: usize, const C: usize> std::fmt::Display for MatrixType<'a, V, RC, C> {
    /// Formats the borrowed matrix with `C` elements per row.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let elements: &[V] = self.m;
        // The unqualified `fmt` here is this module's free function, not the
        // trait method being defined, so this call does not recurse.
        fmt::<V, C>(f, elements)
    }
}