use super::{
super::{dot, matmul},
Matrix, Vector,
};
/// Products of `self` with a right-hand operand of type `T`, yielding `S`, with optional
/// transposition of either side.
///
/// The method name encodes the transposition: a `t_` prefix transposes `self`, a `_t`
/// suffix transposes `other`.
pub trait Dot<T, S> {
    /// `self · other`.
    fn dot(&self, other: T) -> S;

    /// `self^T · other`.
    fn t_dot(&self, other: T) -> S;

    /// `self · other^T`.
    fn dot_t(&self, other: T) -> S;

    /// `self^T · other^T`.
    fn t_dot_t(&self, other: T) -> S;
}
/// Invokes `$inner!(L, R)` for every owned/borrowed pairing of the two types, in the order
/// `(L, R)`, `(L, &R)`, `(&L, R)`, `(&L, &R)`.
///
/// The `@rhs` arm is internal: it fixes the left type and emits the owned and borrowed
/// variants of the right type.
macro_rules! impl_macro_for_types {
    (@rhs $inner: ident, $lhs: ty, $rhs: ty) => {
        $inner!($lhs, $rhs);
        $inner!($lhs, &$rhs);
    };
    ($inner: ident, $lhs: ty, $rhs: ty) => {
        impl_macro_for_types!(@rhs $inner, $lhs, $rhs);
        impl_macro_for_types!(@rhs $inner, &$lhs, $rhs);
    };
}
/// Implements [`Dot`] between a matrix left operand (`$selftype`) and a matrix right
/// operand (`$othertype`), producing a [`Matrix`].
///
/// Each method asserts that the contracted dimensions agree, panicking with
/// `"matrix shapes not compatible"` otherwise. It then passes both operands' data and their
/// stored row counts to `matmul`. The two boolean flags mirror the method name: the first
/// is set for a `t_` prefix (`self` transposed), the second for a `_t` suffix (`other`
/// transposed).
macro_rules! impl_mat_mat_dot {
    ($selftype: ty, $othertype: ty) => {
        impl Dot<$othertype, Matrix> for $selftype {
            /// `self · other`: needs `self.ncols == other.nrows`; the result is built with
            /// shape `(self.nrows, other.ncols)`.
            fn dot(&self, other: $othertype) -> Matrix {
                assert_eq!(self.ncols, other.nrows, "matrix shapes not compatible");
                let (rows, cols) = (self.nrows, other.ncols);
                let product =
                    matmul(&self.data(), &other.data(), self.nrows, other.nrows, false, false);
                Matrix::new(product, rows as i32, cols as i32)
            }

            /// `self · other^T`: needs `self.ncols == other.ncols`; the result is built with
            /// shape `(self.nrows, other.nrows)`.
            fn dot_t(&self, other: $othertype) -> Matrix {
                assert_eq!(self.ncols, other.ncols, "matrix shapes not compatible");
                let (rows, cols) = (self.nrows, other.nrows);
                let product =
                    matmul(&self.data(), &other.data(), self.nrows, other.nrows, false, true);
                Matrix::new(product, rows as i32, cols as i32)
            }

            /// `self^T · other`: needs `self.nrows == other.nrows`; the result is built with
            /// shape `(self.ncols, other.ncols)`.
            fn t_dot(&self, other: $othertype) -> Matrix {
                assert_eq!(self.nrows, other.nrows, "matrix shapes not compatible");
                let (rows, cols) = (self.ncols, other.ncols);
                let product =
                    matmul(&self.data(), &other.data(), self.nrows, other.nrows, true, false);
                Matrix::new(product, rows as i32, cols as i32)
            }

            /// `self^T · other^T`: needs `self.nrows == other.ncols`; the result is built
            /// with shape `(self.ncols, other.nrows)`.
            fn t_dot_t(&self, other: $othertype) -> Matrix {
                assert_eq!(self.nrows, other.ncols, "matrix shapes not compatible");
                let (rows, cols) = (self.ncols, other.nrows);
                let product =
                    matmul(&self.data(), &other.data(), self.nrows, other.nrows, true, true);
                Matrix::new(product, rows as i32, cols as i32)
            }
        }
    };
}
// Matrix · Matrix for every owned/borrowed combination of the two operands.
impl_macro_for_types! { impl_mat_mat_dot, Matrix, Matrix }
/// Generates matrix-by-vector [`Dot`] methods that return a [`Vector`].
///
/// Every method name listed in `$op` gets the same body:
/// 1. `other` is copied into an owned `Vector` and converted with `to_matrix`.
/// 2. The result is transposed in place with `t_mut`.
/// 3. It is passed to the matrix-matrix method `$innerop` on `self`.
/// 4. The product is converted back with `to_vec`.
///
/// NOTE(review): the transpose presumably turns the vector's matrix form into a single
/// column so that `self.$innerop` sees a compatible shape. Confirm against
/// `Vector::to_matrix`.
///
/// `$othertype` is instantiated with both `Vector` and `&Vector`.
macro_rules! impl_dot_append_one {
    ($othertype: ty, $innerop: ident, $($op: ident),+) => {
        $(
            fn $op(&self, other: $othertype) -> Vector {
                // `&other` is `&Vector` or `&&Vector`. Either deref-coerces to the
                // `&Vector` that `Vector::clone` expects, so exactly one copy is made.
                // Avoid `other.clone().to_owned()`: `clone` already yields an owned `Vector`
                // here, so `to_owned` would copy the data a second time.
                let mut rhs = Vector::clone(&other).to_matrix();
                rhs.t_mut();
                self.$innerop(rhs).to_vec()
            }
        )+
    };
}
/// Implements [`Dot`] between a matrix left operand (`$lhs`) and a vector right operand
/// (`$rhs`), producing a [`Vector`].
///
/// The `_t` variants, which transpose the vector operand, share a body with their
/// non-`_t` counterparts:
/// - `dot` and `dot_t` delegate to the matrix `dot`.
/// - `t_dot` and `t_dot_t` delegate to the matrix `t_dot`.
macro_rules! impl_mat_vec_dot {
    ($lhs: ty, $rhs: ty) => {
        impl Dot<$rhs, Vector> for $lhs {
            impl_dot_append_one!($rhs, t_dot, t_dot, t_dot_t);
            impl_dot_append_one!($rhs, dot, dot, dot_t);
        }
    };
}
// Matrix · Vector for every owned/borrowed combination of the two operands.
impl_macro_for_types! { impl_mat_vec_dot, Matrix, Vector }
/// Generates vector-by-matrix [`Dot`] methods that return a [`Vector`].
///
/// Every method name listed in `$op` gets the same body:
/// 1. `self` is copied into an owned `Vector` and converted with `to_matrix`.
/// 2. The result is multiplied with `other` via the matrix method `$innerop`.
/// 3. The product is converted back with `to_vec`.
///
/// The implementing type is instantiated with both `Vector` and `&Vector`.
macro_rules! impl_dot_prepend_one {
    ($othertype: ty, $innerop: ident, $($op: ident),+) => {
        $(
            fn $op(&self, other: $othertype) -> Vector {
                // `self` is `&Vector` or `&&Vector`. Both deref-coerce to the `&Vector` that
                // `Vector::clone` expects, giving exactly one owned copy. Avoid
                // `self.clone().to_owned()`: when `self` is `&Vector`, `clone` already copies
                // the data and `to_owned` would copy it again.
                Vector::clone(self).to_matrix().$innerop(other).to_vec()
            }
        )+
    };
}
/// Implements [`Dot`] between a vector left operand (`$lhs`) and a matrix right operand
/// (`$rhs`), producing a [`Vector`].
///
/// The `t_` variants, which transpose the vector operand, share a body with their
/// non-`t_` counterparts:
/// - `dot` and `t_dot` delegate to the matrix `dot`.
/// - `dot_t` and `t_dot_t` delegate to the matrix `dot_t`.
macro_rules! impl_vec_mat_dot {
    ($lhs: ty, $rhs: ty) => {
        impl Dot<$rhs, Vector> for $lhs {
            impl_dot_prepend_one!($rhs, dot_t, dot_t, t_dot_t);
            impl_dot_prepend_one!($rhs, dot, dot, t_dot);
        }
    };
}
// Vector · Matrix for every owned/borrowed combination of the two operands.
impl_macro_for_types! { impl_vec_mat_dot, Vector, Matrix }
/// Generates vector-by-vector [`Dot`] methods returning `f64`.
///
/// Every method name listed in `$op` gets the same body, which passes both vectors' data to
/// `dot`. All transpose variants therefore compute the same value.
///
/// # Panics
///
/// Panics with `"vector lengths not compatible"` if the two vectors' data lengths differ.
/// This mirrors the shape assertions of the matrix-matrix implementation.
macro_rules! impl_dot_vec_vec {
    ($othertype: ty, $($op: ident),+) => {
        $(
            fn $op(&self, other: $othertype) -> f64 {
                let lhs = self.data();
                let rhs = other.data();
                // Check up front so a mismatch fails loudly, like the matrix shape checks.
                // Otherwise the outcome depends on how `dot` treats unequal inputs.
                assert_eq!(lhs.len(), rhs.len(), "vector lengths not compatible");
                dot(&lhs, &rhs)
            }
        )+
    };
}
/// Implements [`Dot`] between two vector operands (`$lhs` on the left, `$rhs` on the
/// right), producing an `f64`.
///
/// All four variants expand to the same `dot` call on the two vectors' data; transposition
/// is not distinguished for vector operands.
macro_rules! impl_vec_vec_dot {
    ($lhs: ty, $rhs: ty) => {
        impl Dot<$rhs, f64> for $lhs {
            impl_dot_vec_vec!($rhs, t_dot_t, dot_t, t_dot, dot);
        }
    };
}
// Vector · Vector for every owned/borrowed combination of the two operands.
impl_macro_for_types! { impl_vec_vec_dot, Vector, Vector }