use num_traits::Zero;
use std::ops::{Add, Mul};
use mdarray::{DSlice, DTensor, Layout, Shape, Slice};
use num_complex::ComplexFloat;
use crate::matmul::{Triangle, Type};
use crate::matvec::{Argmax, MatVec, MatVecBuilder, VecOps};
use crate::utils::unravel_index;
use crate::Naive;
/// Lazily-evaluated matrix–vector product used by the [`Naive`] backend.
///
/// The builder only borrows its operands; no arithmetic happens until one of
/// the terminal [`MatVecBuilder`] methods consumes it.
struct BlasMatVecBuilder<'a, T, La: Layout, Lx: Layout> {
    /// Scalar factor applied to the product; each `scale` call multiplies into it.
    alpha: T,
    /// Borrowed matrix operand `A`.
    a: &'a DSlice<T, 2, La>,
    /// Borrowed vector operand `x`.
    x: &'a DSlice<T, 1, Lx>,
}
/// Computes `y[i] = combine(α·(A·x)[i], y[i])` for every row `i` of `a`.
///
/// This is the shared kernel behind `eval`, `overwrite`, `add_to` and
/// `add_to_scaled`. Each row dot product is formed naively in
/// O(rows·cols). `combine` then merges the scaled product with the previous
/// value of `y[i]`.
///
/// # Panics
///
/// Panics if `x.len()` is not the number of columns of `a`, or if `y.len()`
/// is not the number of rows of `a`.
fn naive_gemv_into<T, La, Lx, Ly, F>(
    alpha: T,
    a: &DSlice<T, 2, La>,
    x: &DSlice<T, 1, Lx>,
    y: &mut DSlice<T, 1, Ly>,
    combine: F,
) where
    T: ComplexFloat,
    La: Layout,
    Lx: Layout,
    Ly: Layout,
    F: Fn(T, T) -> T,
{
    let (rows, cols) = *a.shape();
    assert_eq!(
        x.len(),
        cols,
        "matvec: x.len() must equal the number of columns of A"
    );
    assert_eq!(
        y.len(),
        rows,
        "matvec: y.len() must equal the number of rows of A"
    );
    for i in 0..rows {
        let mut row_dot = T::zero();
        for j in 0..cols {
            row_dot = row_dot + a[[i, j]] * x[[j]];
        }
        let updated = combine(alpha * row_dot, y[[i]]);
        y[[i]] = updated;
    }
}

/// Naive, single-threaded implementation of [`MatVecBuilder`].
///
/// NOTE(review): `eval`/`overwrite`/`add_to`/`add_to_scaled` follow BLAS
/// `gemv` semantics (`y := α·A·x + β·y`). Confirm this matches the contract
/// documented on the `MatVecBuilder` trait.
impl<'a, T, La, Lx> MatVecBuilder<'a, T, La, Lx> for BlasMatVecBuilder<'a, T, La, Lx>
where
    La: Layout,
    Lx: Layout,
    T: ComplexFloat,
    i8: Into<T::Real>,
    T::Real: Into<T>,
{
    /// No-op: the naive backend always runs on the calling thread.
    fn parallelize(self) -> Self {
        self
    }

    /// Multiplies the accumulated scalar `α` by `alpha`.
    fn scale(mut self, alpha: T) -> Self {
        self.alpha = alpha * self.alpha;
        self
    }

    /// Returns a newly allocated `α·A·x`.
    ///
    /// The result has one entry per row of `A`. The old code sized it by
    /// `x.len()`, which is the column count and wrong for non-square `A`.
    fn eval(self) -> DTensor<T, 1> {
        let (rows, _) = *self.a.shape();
        let mut y = DTensor::<T, 1>::from_elem(rows, T::zero());
        naive_gemv_into(self.alpha, self.a, self.x, &mut y, |ax, _| ax);
        y
    }

    /// `y := α·A·x`. The previous contents of `y` do not affect the result.
    fn overwrite<Ly: Layout>(self, y: &mut DSlice<T, 1, Ly>) {
        naive_gemv_into(self.alpha, self.a, self.x, y, |ax, _| ax);
    }

    /// `y := α·A·x + y`.
    fn add_to<Ly: Layout>(self, y: &mut DSlice<T, 1, Ly>) {
        naive_gemv_into(self.alpha, self.a, self.x, y, |ax, old| ax + old);
    }

    /// `y := α·A·x + β·y`.
    ///
    /// As in reference BLAS, `β == 0` means `y` is pure output. Its old values
    /// (even NaN/inf) are not propagated into the result.
    fn add_to_scaled<Ly: Layout>(self, y: &mut DSlice<T, 1, Ly>, beta: T) {
        if beta == T::zero() {
            naive_gemv_into(self.alpha, self.a, self.x, y, |ax, _| ax);
        } else {
            naive_gemv_into(self.alpha, self.a, self.x, y, |ax, old| ax + beta * old);
        }
    }

    /// Returns `α·A + β·x·yᵀ` as a new matrix. `A` itself is only borrowed and
    /// is left untouched.
    ///
    /// # Panics
    ///
    /// Panics if `x.len()` differs from the row count of `A`, or `y.len()`
    /// differs from its column count.
    fn add_outer<Ly: Layout>(self, y: &DSlice<T, 1, Ly>, beta: T) -> DTensor<T, 2> {
        let (rows, cols) = *self.a.shape();
        assert_eq!(
            self.x.len(),
            rows,
            "add_outer: x.len() must equal the number of rows of A"
        );
        assert_eq!(
            y.len(),
            cols,
            "add_outer: y.len() must equal the number of columns of A"
        );
        let mut a_copy = DTensor::<T, 2>::from_elem(*self.a.shape(), T::zero());
        a_copy.assign(self.a);
        for i in 0..rows {
            for j in 0..cols {
                a_copy[[i, j]] = self.alpha * a_copy[[i, j]] + beta * self.x[[i]] * y[[j]];
            }
        }
        a_copy
    }

    /// Structured rank-1 update driven by `Type`/`Triangle`. Not yet implemented.
    ///
    /// NOTE(review): the meaning of the `Type`/`Triangle` variants lives in
    /// `crate::matmul` and is not visible here, so no semantics are assumed.
    fn add_outer_special(self, _beta: T, _ty: Type, _tr: Triangle) -> DTensor<T, 2> {
        let mut a_copy = DTensor::<T, 2>::from_elem(*self.a.shape(), 0.into().into());
        a_copy.assign(self.a);
        todo!()
    }
}
/// Entry point of the naive matrix–vector product: wraps the operands in a
/// [`BlasMatVecBuilder`] whose scalar starts at one.
impl<T> MatVec<T> for Naive
where
    T: ComplexFloat,
    i8: Into<T::Real>,
    T::Real: Into<T>,
{
    /// Borrows `a` and `x` and returns a builder for `α·A·x` with `α = 1`.
    fn matvec<'a, La, Lx>(
        &self,
        a: &'a DSlice<T, 2, La>,
        x: &'a DSlice<T, 1, Lx>,
    ) -> impl MatVecBuilder<'a, T, La, Lx>
    where
        La: Layout,
        Lx: Layout,
    {
        // Identity scaling; callers adjust it through `scale`.
        let unit: T = 1.into().into();
        BlasMatVecBuilder { a, x, alpha: unit }
    }
}
impl<T: ComplexFloat + 'static + Add<Output = T> + Mul<Output = T> + Zero + Copy> VecOps<T>
for Naive
{
fn add_to_scaled<Lx: Layout, Ly: Layout>(
&self,
_alpha: T,
_x: &DSlice<T, 1, Lx>,
_y: &mut DSlice<T, 1, Ly>,
) {
todo!()
}
fn dot<Lx: Layout, Ly: Layout>(&self, x: &DSlice<T, 1, Lx>, y: &DSlice<T, 1, Ly>) -> T {
let mut result = T::zero();
for (elem_x, elem_y) in std::iter::zip(x.into_iter(), y.into_iter()) {
result = result + *elem_x * (*elem_y);
}
result
}
fn dotc<Lx: Layout, Ly: Layout>(&self, _x: &DSlice<T, 1, Lx>, _y: &DSlice<T, 1, Ly>) -> T {
todo!()
}
fn norm2<Lx: Layout>(&self, _x: &DSlice<T, 1, Lx>) -> T::Real {
todo!()
}
fn norm1<Lx: Layout>(&self, _x: &DSlice<T, 1, Lx>) -> T::Real
where
T: ComplexFloat,
{
todo!()
}
fn copy<Lx: Layout, Ly: Layout>(&self, _x: &DSlice<T, 1, Lx>, _y: &mut DSlice<T, 1, Ly>) {
todo!()
}
fn scal<Lx: Layout>(&self, _alpha: T, _x: &mut DSlice<T, 1, Lx>) {
todo!()
}
fn swap<Lx: Layout, Ly: Layout>(&self, _x: &mut DSlice<T, 1, Lx>, _y: &mut DSlice<T, 1, Ly>) {
todo!()
}
fn rot<Lx: Layout, Ly: Layout>(
&self,
_x: &mut DSlice<T, 1, Lx>,
_y: &mut DSlice<T, 1, Ly>,
_c: T::Real,
_s: T,
) where
T: ComplexFloat,
{
todo!()
}
}
/// Returns the position of the first maximal item of `keys`, or `None` when
/// `keys` is empty.
///
/// A later key replaces the current best only if it compares strictly
/// greater. Ties therefore keep the earliest position. A key that is
/// unordered relative to the current best (e.g. NaN) never replaces it.
fn first_max_flat_index<K: PartialOrd>(keys: impl Iterator<Item = K>) -> Option<usize> {
    let mut best: Option<(usize, K)> = None;
    for (idx, key) in keys.enumerate() {
        let is_better = match &best {
            None => true,
            Some((_, current)) => key > *current,
        };
        if is_better {
            best = Some((idx, key));
        }
    }
    best.map(|(idx, _)| idx)
}

/// Naive argmax over arbitrary-rank slices.
///
/// Results are multi-dimensional indices produced by `unravel_index` from the
/// flat iteration position of the winning element.
impl<T> Argmax<T> for Naive
where
    T: ComplexFloat<Real = T> + 'static + PartialOrd + Add<Output = T> + Mul<Output = T> + Zero + Copy,
{
    /// Writes the index of the first largest element of `x` into `output`.
    ///
    /// Returns `false` (with `output` empty) for an empty slice. A rank-0
    /// slice yields `true` with an empty index.
    fn argmax_overwrite<Lx: Layout, S: Shape>(
        &self,
        x: &Slice<T, S, Lx>,
        output: &mut Vec<usize>,
    ) -> bool {
        output.clear();
        if x.is_empty() {
            return false;
        }
        if x.rank() != 0 {
            let flat = first_max_flat_index(x.iter()).unwrap();
            output.extend_from_slice(&unravel_index(x, flat));
        }
        true
    }

    /// Allocating wrapper around `argmax_overwrite`; `None` for an empty slice.
    fn argmax<Lx: Layout, S: Shape>(&self, x: &Slice<T, S, Lx>) -> Option<Vec<usize>> {
        let mut indices = Vec::new();
        let found = self.argmax_overwrite(x, &mut indices);
        found.then_some(indices)
    }

    /// Like `argmax_overwrite`, but it compares elements by `abs()`.
    fn argmax_abs_overwrite<Lx: Layout, S: Shape>(
        &self,
        x: &Slice<T, S, Lx>,
        output: &mut Vec<usize>,
    ) -> bool {
        output.clear();
        if x.is_empty() {
            return false;
        }
        if x.rank() != 0 {
            let flat = first_max_flat_index(x.iter().map(|v| v.abs())).unwrap();
            output.extend_from_slice(&unravel_index(x, flat));
        }
        true
    }

    /// Allocating wrapper around `argmax_abs_overwrite`; `None` for an empty slice.
    fn argmax_abs<Lx: Layout, S: Shape>(&self, x: &Slice<T, S, Lx>) -> Option<Vec<usize>> {
        let mut indices = Vec::new();
        let found = self.argmax_abs_overwrite(x, &mut indices);
        found.then_some(indices)
    }
}