use crate::matrix::Matrix;
use crate::vector::Vector;
use crate::{to_i32, StrError, CBLAS_COL_MAJOR, CBLAS_NO_TRANS};
// Foreign declarations resolved at link time against a CBLAS implementation.
extern "C" {
/// Double-precision general matrix-vector product from the CBLAS interface.
///
/// Per the BLAS reference, with `transa` set to "no transpose" this computes
/// `y := alpha * A * x + beta * y`, where `A` is an `m`-by-`n` matrix.
///
/// Parameters (in C order):
///
/// * `layout` -- storage-order selector (row-major or column-major)
/// * `transa` -- transpose selector for `A`
/// * `m`, `n` -- number of rows and columns of `A`
/// * `alpha` -- scalar multiplying the product `A * x`
/// * `a` -- pointer to the first entry of `A`
/// * `lda` -- leading dimension of `A` (the stride between consecutive
///   columns when the layout is column-major)
/// * `x` -- pointer to the input vector
/// * `incx` -- stride between consecutive entries of `x`
/// * `beta` -- scalar multiplying the incoming contents of `y`
/// * `y` -- pointer to the output vector, updated in place
/// * `incy` -- stride between consecutive entries of `y`
///
/// NOTE(review): the integer and enum arguments are declared as `i32`, which
/// assumes the linked library uses 32-bit `int` (LP64 interface) -- confirm
/// against the build configuration.
fn cblas_dgemv(
layout: i32,
transa: i32,
m: i32,
n: i32,
alpha: f64,
a: *const f64,
lda: i32,
x: *const f64,
incx: i32,
beta: f64,
y: *mut f64,
incy: i32,
);
}
/// Performs the matrix-vector multiplication `v := alpha * a * u`
///
/// The product itself is delegated to the CBLAS routine `dgemv`, called with
/// `beta = 0` so that the previous contents of `v` do not contribute to the
/// result.
///
/// # Input
///
/// * `v` -- output vector; its dimension must equal `a.nrow()`
/// * `alpha` -- scalar multiplier applied to the product
/// * `a` -- the matrix
/// * `u` -- input vector; its dimension must equal `a.ncol()`
///
/// # Errors
///
/// Returns `Err("matrix and vectors are incompatible")` when the dimension of
/// `v` differs from the number of rows of `a`, or the dimension of `u` differs
/// from the number of columns of `a`.
///
/// # Degenerate shapes
///
/// * `v` has no components: returns `Ok(())` without touching anything
/// * `u` has no components (and `v` has some): `v` is filled with zeros
pub fn mat_vec_mul(v: &mut Vector, alpha: f64, a: &Matrix, u: &Vector) -> Result<(), StrError> {
    let (nrow, ncol) = (v.dim(), u.dim());
    if nrow != a.nrow() || ncol != a.ncol() {
        return Err("matrix and vectors are incompatible");
    }

    // Handle empty shapes here so that BLAS is never called with a zero dimension.
    match (nrow, ncol) {
        // nothing to compute: the output has no components
        (0, _) => return Ok(()),
        // an empty sum for every row: the product is the zero vector
        (_, 0) => {
            v.fill(0.0);
            return Ok(());
        }
        _ => {}
    }

    let rows: i32 = to_i32(nrow);
    let cols: i32 = to_i32(ncol);

    // Raw pointers are taken outside the unsafe block; only the FFI call needs it.
    let a_ptr = a.as_data().as_ptr();
    let u_ptr = u.as_data().as_ptr();
    let v_ptr = v.as_mut_data().as_mut_ptr();

    // SAFETY: the dimension check above guarantees that `v` holds `nrow` values
    // and `u` holds `ncol` values, both accessed with unit stride. The call
    // describes `a` as column-major with leading dimension `nrow`, which assumes
    // that `Matrix::as_data` exposes `nrow * ncol` contiguous values in that
    // order -- NOTE(review): confirm against the `Matrix` implementation.
    unsafe {
        cblas_dgemv(
            CBLAS_COL_MAJOR,
            CBLAS_NO_TRANS,
            rows,
            cols,
            alpha,
            a_ptr,
            rows, // lda: leading dimension equals the number of rows
            u_ptr,
            1, // incx
            0.0, // beta
            v_ptr,
            1, // incy
        );
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::{mat_vec_mul, Matrix, Vector};
    use crate::{vec_approx_eq, vec_norm, Norm};

    /// Mismatched row or column counts must be rejected with the error message.
    #[test]
    fn mat_vec_mul_fails_on_wrong_dims() {
        let u = Vector::new(2);
        let a_1x2 = Matrix::new(1, 2);
        let a_3x1 = Matrix::new(3, 1);
        let mut v = Vector::new(3);
        // the number of rows (1) differs from the dimension of v (3)
        assert_eq!(
            mat_vec_mul(&mut v, 1.0, &a_1x2, &u),
            Err("matrix and vectors are incompatible")
        );
        // the number of columns (1) differs from the dimension of u (2)
        assert_eq!(
            mat_vec_mul(&mut v, 1.0, &a_3x1, &u),
            Err("matrix and vectors are incompatible")
        );
    }

    /// Checks the product against hand-computed values for a 3x4 matrix.
    #[test]
    fn mat_vec_mul_works() {
        #[rustfmt::skip]
        let a = Matrix::from(&[
            [ 5.0, -2.0, 0.0, 1.0],
            [10.0, -4.0, 0.0, 2.0],
            [15.0, -6.0, 0.0, 3.0],
        ]);
        let u = Vector::from(&[1.0, 3.0, 8.0, 5.0]);
        let mut v = Vector::new(a.nrow());
        mat_vec_mul(&mut v, 1.0, &a, &u).unwrap();
        let correct = &[4.0, 8.0, 12.0];
        vec_approx_eq(&v, correct, 1e-15);

        // A scaling factor other than one, reusing `v`: the values left over
        // from the previous call must be overwritten, not accumulated.
        mat_vec_mul(&mut v, 0.5, &a, &u).unwrap();
        let correct_half = &[2.0, 4.0, 6.0];
        vec_approx_eq(&v, correct_half, 1e-15);
    }

    /// Zero-sized operands are accepted and yield an empty or all-zero result.
    #[test]
    fn mat_vec_mul_zero_works() {
        let a_0x0 = Matrix::new(0, 0);
        let a_0x1 = Matrix::new(0, 1);
        let a_1x0 = Matrix::new(1, 0);
        let u0 = Vector::new(0);
        let u1 = Vector::new(1);
        let mut v0 = Vector::new(0);
        // Start from a non-zero sentinel so the assertion below can only pass
        // if the function actively zeroes the output.
        let mut v1 = Vector::filled(1, 123.0);
        mat_vec_mul(&mut v0, 1.0, &a_0x0, &u0).unwrap();
        assert_eq!(v0.as_data(), &[] as &[f64]);
        mat_vec_mul(&mut v0, 1.0, &a_0x1, &u1).unwrap();
        assert_eq!(v0.as_data(), &[] as &[f64]);
        mat_vec_mul(&mut v1, 1.0, &a_1x0, &u0).unwrap();
        assert_eq!(v1.as_data(), &[0.0]);
    }

    /// Sweeps several shapes with all-ones operands: each component of the
    /// result equals the number of columns.
    #[test]
    fn mat_vec_mul_works_range() {
        for m in [0, 7, 15_usize] {
            for n in [0, 4, 8_usize] {
                let a = Matrix::filled(m, n, 1.0);
                let u = Vector::filled(n, 1.0);
                let mut v = Vector::new(m);
                mat_vec_mul(&mut v, 1.0, &a, &u).unwrap();
                // with no rows there are no components, hence a zero norm
                let expected = if m == 0 { 0.0 } else { n as f64 };
                assert_eq!(vec_norm(&v, Norm::Max), expected);
            }
        }
    }
}