use crate::matrix::Matrix;
use crate::vector::Vector;
use crate::{to_i32, StrError, CBLAS_COL_MAJOR, CBLAS_NO_TRANS};
// FFI declaration of the CBLAS double-precision matrix-vector routine; the symbol is
// resolved from whichever BLAS library the crate links against (not visible here).
extern "C" {
/// CBLAS `dgemv`: `y := alpha * op(A) * x + beta * y`, where `op(A)` is `A` or `Aᵀ`
/// as selected by `transa`.
///
/// * `layout` — storage order of `A` (this file passes `CBLAS_COL_MAJOR`)
/// * `transa` — transpose flag for `A` (this file passes `CBLAS_NO_TRANS`)
/// * `m`, `n` — number of rows and columns of `A`
/// * `alpha`, `beta` — scalar multipliers
/// * `a`, `lda` — pointer to the entries of `A` and its leading dimension
/// * `x`, `incx` — input vector and its stride
/// * `y`, `incy` — input/output vector and its stride (overwritten in place)
fn cblas_dgemv(
layout: i32,
transa: i32,
m: i32,
n: i32,
alpha: f64,
a: *const f64,
lda: i32,
x: *const f64,
incx: i32,
beta: f64,
y: *mut f64,
incy: i32,
);
}
/// Performs the matrix-vector multiplication with update
///
/// ```text
///  v  :=  α ⋅  a   ⋅  u  +  β ⋅ v
/// (m)        (m,n)   (n)       (m)
/// ```
///
/// # Arguments
///
/// * `v` -- output vector with `m` entries; its previous contents are scaled by `beta`
/// * `alpha` -- multiplier of the product `a ⋅ u`
/// * `a` -- matrix with `m` rows and `n` columns
/// * `u` -- input vector with `n` entries
/// * `beta` -- multiplier of the previous contents of `v`
///
/// # Errors
///
/// Returns `Err("matrix and vectors are incompatible")` if `a` is not `m × n`,
/// where `m = v.dim()` and `n = u.dim()`.
///
/// # Notes
///
/// * If `m == 0`, nothing is done.
/// * If `n == 0`, the product `a ⋅ u` is an empty sum (zero); hence `v := β ⋅ v`.
///   When `β == 0`, `v` is overwritten with zeros instead of being multiplied, so
///   NaN/Inf values previously stored in `v` do not propagate to the result.
/// * NOTE(review): sizes are converted with `to_i32`; presumably this panics when `m`
///   or `n` exceed `i32::MAX` — confirm against `to_i32`.
pub fn mat_vec_mul_update(v: &mut Vector, alpha: f64, a: &Matrix, u: &Vector, beta: f64) -> Result<(), StrError> {
    let m = v.dim();
    let n = u.dim();
    if m != a.nrow() || n != a.ncol() {
        return Err("matrix and vectors are incompatible");
    }
    if m == 0 {
        return Ok(());
    }
    if n == 0 {
        // α ⋅ a ⋅ u sums over zero terms, so only the β ⋅ v part of the update remains
        if beta == 0.0 {
            v.fill(0.0);
        } else {
            for x in v.as_mut_data().iter_mut() {
                *x *= beta;
            }
        }
        return Ok(());
    }
    let m_i32: i32 = to_i32(m);
    let n_i32: i32 = to_i32(n);
    let incx = 1;
    let incy = 1;
    // SAFETY: `u` holds exactly `n` entries and `v` exactly `m` entries (both checked
    // against the shape of `a` above), and both are traversed with unit stride. `a` is
    // described to BLAS as column-major with leading dimension `m`; since `m, n > 0`
    // here, `lda = m >= max(1, m)` as BLAS requires.
    // NOTE(review): this presumes `Matrix` stores its m×n entries column by column,
    // as CBLAS_COL_MAJOR with lda = m implies — confirm against `Matrix` internals.
    unsafe {
        cblas_dgemv(
            CBLAS_COL_MAJOR,
            CBLAS_NO_TRANS,
            m_i32,
            n_i32,
            alpha,
            a.as_data().as_ptr(),
            m_i32,
            u.as_data().as_ptr(),
            incx,
            beta,
            v.as_mut_data().as_mut_ptr(),
            incy,
        );
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::{mat_vec_mul_update, Matrix, Vector};
    use crate::vec_approx_eq;

    /// Error message expected from `mat_vec_mul_update` when the shapes disagree.
    const INCOMPATIBLE: &str = "matrix and vectors are incompatible";

    #[test]
    fn mat_vec_mul_update_fails_on_wrong_dims() {
        // v has 3 entries and u has 2, so only a 3×2 matrix would be accepted:
        // 1×2 has the wrong number of rows and 3×1 the wrong number of columns
        let u = Vector::new(2);
        let mut v = Vector::new(3);
        for bad in &[Matrix::new(1, 2), Matrix::new(3, 1)] {
            assert_eq!(mat_vec_mul_update(&mut v, 1.0, bad, &u, 0.0), Err(INCOMPATIBLE));
        }
    }

    #[test]
    fn mat_vec_mul_update_works() {
        #[rustfmt::skip]
        let a = Matrix::from(&[
            [ 5.0, -2.0, 0.0, 1.0],
            [10.0, -4.0, 0.0, 2.0],
            [15.0, -6.0, 0.0, 3.0],
        ]);
        let u = Vector::from(&[1.0, 3.0, 8.0, 5.0]);

        // a ⋅ u = [4, 8, 12]; start from a zero vector with β = 0
        let mut fresh = Vector::new(a.nrow());
        mat_vec_mul_update(&mut fresh, 1.0, &a, &u, 0.0).unwrap();
        let product = &[4.0, 8.0, 12.0];
        vec_approx_eq(&fresh, product, 1e-15);

        // start from a non-zero vector: β = 0 discards it, β = 1 accumulates onto it
        let cases: [(&[f64; 3], f64, &[f64; 3]); 2] = [
            (&[100.0, 200.0, 300.0], 0.0, &[4.0, 8.0, 12.0]),
            (&[100.0, 200.0, 300.0], 1.0, &[104.0, 208.0, 312.0]),
        ];
        for &(initial, beta, expected) in cases.iter() {
            let mut v = Vector::from(initial);
            mat_vec_mul_update(&mut v, 1.0, &a, &u, beta).unwrap();
            vec_approx_eq(&v, expected, 1e-15);
        }
    }

    #[test]
    fn mat_vec_mul_update_zero_works() {
        let nothing: &[f64] = &[];
        let (u0, u1) = (Vector::new(0), Vector::new(1));
        let (mut v0, mut v1) = (Vector::new(0), Vector::new(1));

        // 0×0 and 0×1 matrices: the output has no entries and remains empty
        mat_vec_mul_update(&mut v0, 1.0, &Matrix::new(0, 0), &u0, 0.0).unwrap();
        assert_eq!(v0.as_data(), nothing);
        mat_vec_mul_update(&mut v0, 1.0, &Matrix::new(0, 1), &u1, 1.0).unwrap();
        assert_eq!(v0.as_data(), nothing);

        // 1×0 matrix: nothing to multiply, the single zero entry remains zero
        mat_vec_mul_update(&mut v1, 1.0, &Matrix::new(1, 0), &u0, 2.0).unwrap();
        assert_eq!(v1.as_data(), &[0.0]);
    }
}