use crate::ArgminDot;
use faer::{mat::AsMatRef, Mat, MatRef};
use faer_traits::ComplexField;
use std::ops::Mul;
mod matrix_matrix_multiplication {
    //! Matrix–matrix products for every pairing of owned (`Mat`) and borrowed
    //! (`MatRef`) operands. Only the `MatRef · MatRef` impl does any work; the
    //! other three borrow their owned operands and forward to it.
    use super::*;

    /// `MatRef · MatRef`: the matrix product `self * other`, computed by faer.
    impl<'a, E: ComplexField> ArgminDot<MatRef<'a, E>, Mat<E>> for MatRef<'_, E> {
        #[inline]
        fn dot(&self, other: &MatRef<'a, E>) -> Mat<E> {
            self * other
        }
    }

    /// `MatRef · Mat`: borrows `other` and forwards to the `MatRef · MatRef` impl.
    impl<E: ComplexField> ArgminDot<Mat<E>, Mat<E>> for MatRef<'_, E> {
        #[inline]
        fn dot(&self, other: &Mat<E>) -> Mat<E> {
            let rhs = other.as_mat_ref();
            <Self as ArgminDot<MatRef<'_, E>, Mat<E>>>::dot(self, &rhs)
        }
    }

    /// `Mat · MatRef`: borrows `self` and forwards to the `MatRef · MatRef` impl.
    impl<'a, E: ComplexField> ArgminDot<MatRef<'a, E>, Mat<E>> for Mat<E> {
        #[inline]
        fn dot(&self, other: &MatRef<'a, E>) -> Mat<E> {
            let lhs = self.as_mat_ref();
            <MatRef<'_, E> as ArgminDot<MatRef<'a, E>, Mat<E>>>::dot(&lhs, other)
        }
    }

    /// `Mat · Mat`: borrows both operands and forwards to the `MatRef · MatRef` impl.
    impl<E: ComplexField> ArgminDot<Mat<E>, Mat<E>> for Mat<E> {
        #[inline]
        fn dot(&self, other: &Mat<E>) -> Mat<E> {
            let lhs = self.as_mat_ref();
            let rhs = other.as_mat_ref();
            <MatRef<'_, E> as ArgminDot<MatRef<'_, E>, Mat<E>>>::dot(&lhs, &rhs)
        }
    }
}
mod scalar_product {
    //! Scalar (inner) product of two vectors stored as faer matrices.
    //!
    //! Either operand may be a row (`1 × n`) or a column (`n × 1`) vector. The
    //! `Conjugate<Conj = E>` bound means the left operand is conjugated into the
    //! same element type it started with.
    use faer_traits::Conjugate;
    use super::*;

    /// Returns a column-vector (`n × 1`) view of the vector `v`.
    ///
    /// A view whose column count is already 1 is returned unchanged. Anything
    /// else (a `1 × n` row vector, including the empty `1 × 0` case) is
    /// transposed. This reorients without copying and, unlike a shape
    /// re-typing, is valid for any stride layout.
    #[inline]
    fn as_column<E: ComplexField>(v: MatRef<'_, E>) -> MatRef<'_, E> {
        if v.ncols() == 1 {
            v
        } else {
            v.transpose()
        }
    }

    /// Computes `conj(self)ᵀ · other` for two vectors of equal length.
    ///
    /// # Panics
    /// - If either operand is not a vector, i.e. has neither exactly one row
    ///   nor exactly one column.
    /// - If the two vectors hold a different number of elements.
    ///
    /// Two empty vectors yield the value faer produces for a `1 × 0` by
    /// `0 × 1` product (expected to be zero).
    impl<'a, E: ComplexField + Conjugate<Conj = E>> ArgminDot<MatRef<'a, E>, E> for MatRef<'_, E> {
        #[inline]
        fn dot(&self, other: &MatRef<'a, E>) -> E {
            assert!(
                (self.nrows() == 1 || self.ncols() == 1)
                    && (other.nrows() == 1 || other.ncols() == 1),
                "arguments for dot product must be vectors"
            );
            // For a vector one dimension is 1, so the product is the true
            // element count. Unlike `max`, it is also correct for an empty
            // `0 × 1` / `1 × 0` vector.
            let count = self.nrows() * self.ncols();
            let count_rhs = other.nrows() * other.ncols();
            assert_eq!(
                count, count_rhs,
                "vectors for dot product must have same number of elements"
            );
            // Bring both operands into column orientation so the product
            // below is always (1 × n) · (n × 1) = 1 × 1.
            let lhs = as_column(*self);
            let rhs = as_column(*other);
            let value: Mat<E> =
                <_ as ArgminDot<_, _>>::dot(&lhs.conjugate().transpose(), &rhs);
            debug_assert_eq!(value.nrows(), 1);
            debug_assert_eq!(value.ncols(), 1);
            value[(0, 0)].clone()
        }
    }

    /// `MatRef · Mat` scalar product; borrows `other` and forwards to the
    /// `MatRef · MatRef` impl.
    impl<E: ComplexField + Conjugate<Conj = E>> ArgminDot<Mat<E>, E> for MatRef<'_, E> {
        #[inline]
        fn dot(&self, other: &Mat<E>) -> E {
            <_ as ArgminDot<_, _>>::dot(self, &other.as_mat_ref())
        }
    }

    /// `Mat · MatRef` scalar product; borrows `self` and forwards to the
    /// `MatRef · MatRef` impl.
    impl<'a, E: ComplexField + Conjugate<Conj = E>> ArgminDot<MatRef<'a, E>, E> for Mat<E> {
        #[inline]
        fn dot(&self, other: &MatRef<'a, E>) -> E {
            <_ as ArgminDot<_, _>>::dot(&self.as_mat_ref(), other)
        }
    }

    /// `Mat · Mat` scalar product; borrows both operands and forwards to the
    /// `MatRef · MatRef` impl.
    impl<E: ComplexField + Conjugate<Conj = E>> ArgminDot<Mat<E>, E> for Mat<E> {
        #[inline]
        fn dot(&self, other: &Mat<E>) -> E {
            <_ as ArgminDot<_, _>>::dot(&self.as_mat_ref(), &other.as_mat_ref())
        }
    }
}
mod multiply_matrix_with_scalar {
use super::*;
use crate::ArgminMul;
use faer_traits::ComplexField;
use std::ops::Mul;
impl<E: ComplexField> ArgminDot<E, Mat<E>> for MatRef<'_, E> {
#[inline]
fn dot(&self, other: &E) -> Mat<E> {
<Self as ArgminMul<E, _>>::mul(self, other)
}
}
impl<E: ComplexField> ArgminDot<E, Mat<E>> for Mat<E> {
#[inline]
fn dot(&self, other: &E) -> Mat<E> {
<_ as ArgminDot<E, _>>::dot(&self.as_mat_ref(), other)
}
}
impl<'a, E: ComplexField> ArgminDot<MatRef<'a, E>, Mat<E>> for E {
#[inline]
fn dot(&self, other: &MatRef<'a, E>) -> Mat<E> {
<E as ArgminMul<MatRef<'a, E>, _>>::mul(self, other)
}
}
impl<E: ComplexField> ArgminDot<Mat<E>, Mat<E>> for E {
#[inline]
fn dot(&self, other: &Mat<E>) -> Mat<E> {
<E as ArgminDot<_, _>>::dot(self, &other.as_mat_ref())
}
}
}