use crate::core::traits::{MatShape, MatTransVec, MatVec};
use std::marker::PhantomData;
/// Signature of the closures stored by [`ShellMat`]: a thread-safe callable
/// that receives a shared reference to one vector and a mutable reference to
/// another.
type ShellFn<V> = dyn Fn(&V, &mut V) + Send + Sync;

/// Matrix-free operator defined by a pair of user-supplied closures.
///
/// No matrix entries are stored; the struct only holds a dimension and two
/// boxed closures, one for the product and one for the transposed product.
pub struct ShellMat<V> {
    /// Dimension supplied at construction time.
    pub dim: usize,
    /// Closure used for the forward product.
    mult: Box<ShellFn<V>>,
    /// Closure used for the transposed product.
    mult_trans: Box<ShellFn<V>>,
    /// Mentions `V` without storing a value of that type.
    _marker: PhantomData<fn(&V)>,
}

impl<V> ShellMat<V> {
    /// Builds an operator from separate forward and transposed closures.
    ///
    /// * `dim` - dimension later returned by [`ShellMat::dimension`].
    /// * `mult` - closure stored as the forward product.
    /// * `mult_trans` - closure stored as the transposed product.
    ///
    /// `dim` is stored as given; it is not checked against the closures.
    pub fn new<F, G>(dim: usize, mult: F, mult_trans: G) -> Self
    where
        F: Fn(&V, &mut V) + Send + Sync + 'static,
        G: Fn(&V, &mut V) + Send + Sync + 'static,
    {
        ShellMat {
            dim,
            mult: Box::new(mult),
            mult_trans: Box::new(mult_trans),
            _marker: PhantomData,
        }
    }

    /// Builds an operator whose forward and transposed products are the same
    /// closure.
    ///
    /// The closure is placed behind a single `Arc` and both stored products
    /// forward to it, so any state captured by `mult` exists once rather than
    /// being duplicated. As a consequence `F` does not have to be `Clone`;
    /// closures that are `Clone` remain accepted.
    pub fn new_symmetric<F>(dim: usize, mult: F) -> Self
    where
        F: Fn(&V, &mut V) + Send + Sync + 'static,
    {
        let shared = std::sync::Arc::new(mult);
        let forward = std::sync::Arc::clone(&shared);
        ShellMat {
            dim,
            mult: Box::new(move |x: &V, y: &mut V| (*forward)(x, y)),
            mult_trans: Box::new(move |x: &V, y: &mut V| (*shared)(x, y)),
            _marker: PhantomData,
        }
    }

    /// Returns the dimension given at construction.
    pub fn dimension(&self) -> usize {
        self.dim
    }
}
/// Forward product: delegates to the closure stored in `mult`.
impl<V> MatVec<V> for ShellMat<V>
where
    V: AsRef<[f64]> + AsMut<[f64]>,
{
    /// Invokes the stored forward closure with `x` and `y`.
    fn matvec(&self, x: &V, y: &mut V) {
        let forward = &*self.mult;
        forward(x, y);
    }
}
/// Transposed product: delegates to the closure stored in `mult_trans`.
impl<V> MatTransVec<V> for ShellMat<V>
where
    V: AsRef<[f64]> + AsMut<[f64]>,
{
    /// Invokes the stored transposed closure with `x` and `y`.
    fn mattransvec(&self, x: &V, y: &mut V) {
        let transpose = &*self.mult_trans;
        transpose(x, y);
    }
}
/// Shape queries: rows and columns both report the single stored dimension.
impl<V> MatShape for ShellMat<V> {
    /// Row count, equal to the stored dimension.
    fn nrows(&self) -> usize {
        self.dimension()
    }

    /// Column count, equal to the stored dimension.
    fn ncols(&self) -> usize {
        self.dimension()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::traits::{MatTransVec, MatVec};

    /// Concrete operator type exercised by every test in this module.
    type Op = ShellMat<Vec<f64>>;

    /// Runs the forward product into a fresh zero vector as long as `x`.
    fn forward(op: &Op, x: &Vec<f64>) -> Vec<f64> {
        let mut out = vec![0.0; x.len()];
        op.matvec(x, &mut out);
        out
    }

    /// Runs the transposed product into a fresh zero vector as long as `x`.
    fn transposed(op: &Op, x: &Vec<f64>) -> Vec<f64> {
        let mut out = vec![0.0; x.len()];
        op.mattransvec(x, &mut out);
        out
    }

    /// Identity operator: both products copy the input, and the shape is 3x3.
    #[test]
    fn test_shell_identity() {
        // A closure without captures is `Copy`, so one definition fills both slots.
        let copy_through = |src: &Vec<f64>, dst: &mut Vec<f64>| {
            dst[..src.len()].copy_from_slice(src);
        };
        let identity = ShellMat::new(3, copy_through, copy_through);

        let x: Vec<f64> = vec![1.0, 2.0, 3.0];
        let expected: Vec<f64> = vec![1.0, 2.0, 3.0];

        assert_eq!(forward(&identity, &x), expected);
        assert_eq!(transposed(&identity, &x), expected);
        assert_eq!(identity.nrows(), 3);
        assert_eq!(identity.ncols(), 3);
    }

    /// Diagonal operator: each product scales the input entry-wise.
    #[test]
    fn test_shell_diagonal() {
        let weights: Vec<f64> = vec![2.0, 3.0, 4.0];
        let weights_t = weights.clone();
        let diagonal = ShellMat::new(
            3,
            move |x: &Vec<f64>, y: &mut Vec<f64>| {
                for (i, xi) in x.iter().enumerate() {
                    y[i] = weights[i] * *xi;
                }
            },
            move |x: &Vec<f64>, y: &mut Vec<f64>| {
                for (i, xi) in x.iter().enumerate() {
                    y[i] = weights_t[i] * *xi;
                }
            },
        );

        let ones: Vec<f64> = vec![1.0; 3];
        let expected: Vec<f64> = vec![2.0, 3.0, 4.0];

        assert_eq!(forward(&diagonal, &ones), expected);
        assert_eq!(transposed(&diagonal, &ones), expected);
    }

    /// Symmetric 2x2 operator built from a single closure.
    #[test]
    fn test_shell_symmetric() {
        let symmetric = ShellMat::new_symmetric(2, |x: &Vec<f64>, y: &mut Vec<f64>| {
            let (a, b) = (x[0], x[1]);
            y[0] = 2.0 * a + 1.0 * b;
            y[1] = 1.0 * a + 3.0 * b;
        });

        let e0: Vec<f64> = vec![1.0, 0.0];
        let first_column: Vec<f64> = vec![2.0, 1.0];
        assert_eq!(forward(&symmetric, &e0), first_column);
        assert_eq!(transposed(&symmetric, &e0), first_column);

        let e1: Vec<f64> = vec![0.0, 1.0];
        let second_column: Vec<f64> = vec![1.0, 3.0];
        assert_eq!(forward(&symmetric, &e1), second_column);
    }

    /// Non-symmetric 2x2 operator: forward and transposed products differ.
    #[test]
    fn test_shell_transpose() {
        let matrix = ShellMat::new(
            2,
            |x: &Vec<f64>, y: &mut Vec<f64>| {
                let (a, b) = (x[0], x[1]);
                y[0] = 1.0 * a + 2.0 * b;
                y[1] = 3.0 * a + 4.0 * b;
            },
            |x: &Vec<f64>, y: &mut Vec<f64>| {
                let (a, b) = (x[0], x[1]);
                y[0] = 1.0 * a + 3.0 * b;
                y[1] = 2.0 * a + 4.0 * b;
            },
        );

        // (input, expected forward product, expected transposed product)
        let cases: [(Vec<f64>, Vec<f64>, Vec<f64>); 2] = [
            (vec![1.0, 0.0], vec![1.0, 3.0], vec![1.0, 2.0]),
            (vec![0.0, 1.0], vec![2.0, 4.0], vec![3.0, 4.0]),
        ];
        for (x, want_forward, want_transposed) in &cases {
            assert_eq!(&forward(&matrix, x), want_forward);
            assert_eq!(&transposed(&matrix, x), want_transposed);
        }
    }
}