use crate::{assert, mat::*, *};
use reborrow::*;
/// Writes the Kronecker product of `lhs` and `rhs` into `dst`.
///
/// `dst` is treated as a grid of tiles, each with the shape of `rhs`. The tile
/// at block position `(i, j)` is filled with `lhs[(i, j)] * rhs`.
///
/// # Panics
///
/// Panics if `dst.nrows() != lhs.nrows() * rhs.nrows()` or
/// `dst.ncols() != lhs.ncols() * rhs.ncols()`, including the case where either
/// product overflows `usize`.
#[track_caller]
pub fn kron<E: ComplexField>(dst: MatMut<E>, lhs: MatRef<E>, rhs: MatRef<E>) {
    // When the destination's column stride is smaller in magnitude than its row
    // stride, work on the transposed problem instead. This is valid because
    // transposing a Kronecker product equals the Kronecker product of the
    // transposed operands. Presumably this lets the innermost (row) loop below
    // walk the smaller-stride axis of `dst`.
    let swap_axes = dst.col_stride().unsigned_abs() < dst.row_stride().unsigned_abs();
    let (mut dst, lhs, rhs) = if swap_axes {
        (dst.transpose_mut(), lhs.transpose(), rhs.transpose())
    } else {
        (dst, lhs, rhs)
    };

    assert!(Some(dst.nrows()) == lhs.nrows().checked_mul(rhs.nrows()));
    assert!(Some(dst.ncols()) == lhs.ncols().checked_mul(rhs.ncols()));

    let tile_nrows = rhs.nrows();
    let tile_ncols = rhs.ncols();

    for block_j in 0..lhs.ncols() {
        for block_i in 0..lhs.nrows() {
            let scale = lhs.read(block_i, block_j);

            // The region of `dst` that receives `scale * rhs`.
            let mut tile = dst.rb_mut().submatrix_mut(
                block_i * tile_nrows,
                block_j * tile_ncols,
                tile_nrows,
                tile_ncols,
            );

            for j in 0..tile_ncols {
                for i in 0..tile_nrows {
                    // SAFETY: `i < rhs.nrows()` and `j < rhs.ncols()`, and `tile`
                    // was created with exactly `rhs.nrows()` rows and
                    // `rhs.ncols()` columns, so both accesses are in bounds.
                    unsafe {
                        let product = scale.faer_mul(rhs.read_unchecked(i, j));
                        tile.write_unchecked(i, j, product);
                    }
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::{assert, prelude::*};

    /// Returns an `nrows` by `ncols` matrix whose entries are all one.
    fn ones_mat(nrows: usize, ncols: usize) -> Mat<f64> {
        Mat::from_fn(nrows, ncols, |_, _| 1.0_f64)
    }

    /// Returns a column vector of length `len` whose entries are all one.
    fn ones_col(len: usize) -> Col<f64> {
        Col::from_fn(len, |_| 1.0_f64)
    }

    /// Returns a row vector of length `len` whose entries are all one.
    fn ones_row(len: usize) -> Row<f64> {
        Row::from_fn(len, |_| 1.0_f64)
    }

    /// The Kronecker product of all-ones operands must itself be all ones,
    /// with each dimension equal to the product of the operands' dimensions.
    #[test]
    fn test_kron_ones() {
        // Matrix with matrix.
        for (m, n, p, q) in [(2, 3, 4, 5), (3, 2, 5, 4), (1, 1, 1, 1)] {
            let left = ones_mat(m, n);
            let right = ones_mat(p, q);
            let want = ones_mat(m * p, n * q);
            assert!(left.kron(&right) == want);
        }

        // Matrix with column vector and with row vector, in both orders.
        for (m, n, p) in [(2, 3, 4), (3, 2, 5), (1, 1, 1)] {
            let mat = ones_mat(m, n);

            let col = ones_col(p);
            let want = ones_mat(m * p, n);
            assert!(mat.kron(&col) == want);
            assert!(col.kron(&mat) == want);

            let row = ones_row(p);
            let want = ones_mat(m, n * p);
            assert!(mat.kron(&row) == want);
            assert!(row.kron(&mat) == want);
        }

        // Vector with vector: row/column mixes, then row/row and column/column.
        for (m, n) in [(2, 3), (3, 2), (1, 1)] {
            let row_m = ones_row(m);
            let col_n = ones_col(n);
            let want = ones_mat(n, m);
            assert!(row_m.kron(&col_n) == want);
            assert!(col_n.kron(&row_m) == want);

            let row_n = ones_row(n);
            let want = ones_mat(1, m * n);
            assert!(row_m.kron(&row_n) == want);

            let col_m = ones_col(m);
            let want = ones_mat(m * n, 1);
            assert!(col_m.kron(&col_n) == want);
        }
    }
}