#![allow(clippy::unnecessary_cast)]
#![allow(clippy::useless_conversion)]
#![allow(clippy::too_many_arguments)]
use crate::prelude_dev::*;
use core::mem::transmute;
use core::num::NonZeroUsize;
use faer::prelude::*;
use faer::traits::ComplexField;
use num::Num;
use rayon::prelude::*;
/// Matrix order `n` from which `gemm_with_syrk_faer` mirrors the computed lower triangle into
/// the upper triangle in parallel (when more than one thread is available); for smaller
/// matrices the copy is done serially.
const PARALLEL_SWITCH: usize = 64;
/// General matrix multiplication `C <- alpha * A * B + beta * C` backed by faer.
///
/// `c`, `a` and `b` are raw storage buffers; `lc`, `la` and `lb` describe the 2-D views
/// (shape, strides, offset) into them.
///
/// * `beta == 0`: the previous contents of `C` are discarded (`faer::Accum::Replace`).
/// * `beta == 1`: the product is accumulated onto `C` directly (`faer::Accum::Add`).
/// * otherwise: `C` is first scaled in place by `beta`, then the product is accumulated.
///
/// The thread count handed to faer is `pool.current_num_threads()` (1 when `pool` is `None`);
/// `pool` is also forwarded to the in-place `beta` scaling.
/// NOTE(review): faer's parallel work is not `install`ed into `pool` here, so it runs on
/// whichever rayon pool is current for the caller.
///
/// # Errors
///
/// Returns an `InvalidLayout` error unless `C` is `m x n`, `A` is `m x k` and `B` is `k x n`.
pub fn gemm_faer<T>(
    c: &mut [T],
    lc: &Layout<Ix2>,
    a: &[T],
    la: &Layout<Ix2>,
    b: &[T],
    lb: &Layout<Ix2>,
    alpha: T,
    beta: T,
    pool: Option<&ThreadPool>,
) -> Result<()>
where
    T: ComplexField + Num + MulAssign<T>,
{
    let nthreads = pool.map_or_else(|| 1, |pool| pool.current_num_threads());
    let sc = lc.shape();
    let sa = la.shape();
    let sb = lb.shape();
    rstsr_assert_eq!(sc[0], sa[0], InvalidLayout)?;
    rstsr_assert_eq!(sa[1], sb[0], InvalidLayout)?;
    rstsr_assert_eq!(sc[1], sb[1], InvalidLayout)?;

    // Handle `beta` BEFORE any raw view of `c` exists: mutably reborrowing `c` for the scaling
    // would otherwise invalidate a raw pointer previously taken from it, making the later
    // faer write through that pointer undefined behavior.
    let accum = if beta == T::zero() {
        faer::Accum::Replace
    } else {
        if beta != T::one() {
            // SAFETY: `MaybeUninit<T>` has the same layout as `T`, and every element of `c` is
            // initialized, so viewing it as `[MaybeUninit<T>]` and `assume_init_mut` are sound.
            let c = unsafe { transmute::<&mut [T], &mut [MaybeUninit<T>]>(&mut *c) };
            op_muta_numb_func_cpu_rayon(
                c,
                lc,
                beta,
                &mut |vc, vb| unsafe { *vc.assume_init_mut() *= vb.clone() },
                pool,
            )?;
        }
        faer::Accum::Add
    };

    // SAFETY (all three views): assumes each layout only addresses elements inside its
    // buffer (offset and strides are not re-checked here).
    let faer_a = unsafe {
        MatRef::from_raw_parts(
            a.as_ptr().add(la.offset()) as *const T,
            la.shape()[0],
            la.shape()[1],
            la.stride()[0],
            la.stride()[1],
        )
    };
    let faer_b = unsafe {
        MatRef::from_raw_parts(
            b.as_ptr().add(lb.offset()) as *const T,
            lb.shape()[0],
            lb.shape()[1],
            lb.stride()[0],
            lb.stride()[1],
        )
    };
    // Created last; `c` is not touched through any other path afterwards.
    let faer_c = unsafe {
        MatMut::from_raw_parts_mut(
            c.as_mut_ptr().add(lc.offset()) as *mut T,
            lc.shape()[0],
            lc.shape()[1],
            lc.stride()[0],
            lc.stride()[1],
        )
    };

    faer::linalg::matmul::matmul(
        faer_c,
        accum,
        faer_a,
        faer_b,
        alpha,
        faer::Par::Rayon(NonZeroUsize::new(nthreads).expect("rayon pools have >= 1 thread")),
    );
    Ok(())
}
/// Rank-k update of one triangle of a square matrix: `C <- alpha * A * A^T + beta * C`,
/// backed by faer's triangular matmul.
///
/// `uplo` selects which triangle of `C` (including the diagonal) faer computes. `A^T` is a
/// plain transpose view of `A` (strides swapped), not a conjugate transpose, also for complex `T`.
///
/// * `beta == 0`: the selected triangle is overwritten (`faer::Accum::Replace`).
/// * `beta == 1`: the product is accumulated onto the selected triangle.
/// * otherwise: the WHOLE of `C` (both triangles) is scaled in place by `beta`, and then
///   the product is accumulated onto the selected triangle.
///
/// The thread count handed to faer is `pool.current_num_threads()` (1 when `pool` is `None`);
/// `pool` is also forwarded to the in-place `beta` scaling.
///
/// # Errors
///
/// Returns an `InvalidLayout` error unless `C` is square with as many rows as `A`.
pub fn syrk_faer<T>(
    c: &mut [T],
    lc: &Layout<Ix2>,
    a: &[T],
    la: &Layout<Ix2>,
    uplo: FlagUpLo,
    alpha: T,
    beta: T,
    pool: Option<&ThreadPool>,
) -> Result<()>
where
    T: ComplexField + Num + MulAssign<T>,
{
    use faer::linalg::matmul::triangular::BlockStructure;

    let nthreads = pool.map_or_else(|| 1, |pool| pool.current_num_threads());
    let sc = lc.shape();
    let sa = la.shape();
    rstsr_assert_eq!(sc[0], sc[1], InvalidLayout)?;
    rstsr_assert_eq!(sc[0], sa[0], InvalidLayout)?;

    // Handle `beta` BEFORE any raw view of `c` exists: mutably reborrowing `c` for the scaling
    // would otherwise invalidate a raw pointer previously taken from it.
    let accum = if beta == T::zero() {
        faer::Accum::Replace
    } else {
        if beta != T::one() {
            // SAFETY: `MaybeUninit<T>` has the same layout as `T`, and every element of `c` is
            // initialized, so viewing it as `[MaybeUninit<T>]` and `assume_init_mut` are sound.
            let c = unsafe { transmute::<&mut [T], &mut [MaybeUninit<T>]>(&mut *c) };
            op_muta_numb_func_cpu_rayon(
                c,
                lc,
                beta,
                &mut |vc, vb| unsafe { *vc.assume_init_mut() *= vb.clone() },
                pool,
            )?;
        }
        faer::Accum::Add
    };

    // SAFETY (all three views): assumes each layout only addresses elements inside its
    // buffer (offset and strides are not re-checked here).
    let faer_a = unsafe {
        MatRef::from_raw_parts(
            a.as_ptr().add(la.offset()) as *const T,
            la.shape()[0],
            la.shape()[1],
            la.stride()[0],
            la.stride()[1],
        )
    };
    // Transposed view of the same storage: shape and strides are swapped.
    let faer_at = unsafe {
        MatRef::from_raw_parts(
            a.as_ptr().add(la.offset()) as *const T,
            la.shape()[1],
            la.shape()[0],
            la.stride()[1],
            la.stride()[0],
        )
    };
    // Created last; `c` is not touched through any other path afterwards.
    let faer_c = unsafe {
        MatMut::from_raw_parts_mut(
            c.as_mut_ptr().add(lc.offset()) as *mut T,
            lc.shape()[0],
            lc.shape()[1],
            lc.stride()[0],
            lc.stride()[1],
        )
    };

    let block_structure = match uplo {
        FlagUpLo::U => BlockStructure::TriangularUpper,
        FlagUpLo::L => BlockStructure::TriangularLower,
    };
    faer::linalg::matmul::triangular::matmul(
        faer_c,
        block_structure,
        accum,
        faer_a,
        BlockStructure::Rectangular,
        faer_at,
        BlockStructure::Rectangular,
        alpha,
        faer::Par::Rayon(NonZeroUsize::new(nthreads).expect("rayon pools have >= 1 thread")),
    );
    Ok(())
}
pub fn gemm_with_syrk_faer<T>(
c: &mut [T],
lc: &Layout<Ix2>,
a: &[T],
la: &Layout<Ix2>,
alpha: T,
beta: T,
pool: Option<&ThreadPool>,
) -> Result<()>
where
T: ComplexField + Num + MulAssign<T>,
{
let nthreads = pool.map_or_else(|| 1, |pool| pool.current_num_threads());
if beta != T::zero() {
gemm_faer(c, lc, a, la, a, &la.reverse_axes(), alpha, beta, pool)?;
} else {
syrk_faer(c, lc, a, la, FlagUpLo::L, alpha, beta, pool)?;
let n = lc.shape()[0];
if n < PARALLEL_SWITCH || nthreads == 1 {
for i in 0..n {
for j in 0..i {
let idx_ij = unsafe { lc.index_uncheck(&[i, j]) as usize };
let idx_ji = unsafe { lc.index_uncheck(&[j, i]) as usize };
c[idx_ji] = c[idx_ij].clone();
}
}
} else {
let pool = rayon::ThreadPoolBuilder::new().num_threads(nthreads).build().unwrap();
pool.install(|| {
(0..n).into_par_iter().for_each(|i| {
(0..i).for_each(|j| unsafe {
let idx_ij = lc.index_uncheck(&[i, j]) as usize;
let idx_ji = lc.index_uncheck(&[j, i]) as usize;
let c_ptr_ji = c.as_ptr().add(idx_ji) as *mut T;
*c_ptr_ji = c[idx_ij].clone();
});
});
});
}
}
return Ok(());
}
#[cfg(test)]
mod test {
    use super::*;
    use std::time::Instant;

    /// Manual timing of a large f64 GEMM; ignored by default because it is a benchmark.
    #[test]
    #[ignore]
    fn playground_1() {
        let m = 2048;
        let n = 2049;
        let k = 2050;
        let a = (0..m * k).map(|x| x as f64).collect::<Vec<_>>();
        let b = (0..k * n).map(|x| x as f64).collect::<Vec<_>>();
        let mut c = vec![0.0; m * n];
        let la = [m, k].c();
        let lb = [k, n].c();
        let lc = [m, n].c();
        let pool = rayon::ThreadPoolBuilder::new().num_threads(16).build().unwrap();
        let pool = Some(&pool);
        let start = Instant::now();
        gemm_faer(&mut c, &lc, &a, &la, &b, &lb, 1.0, 0.0, pool).unwrap();
        println!("time: {:?}", start.elapsed());
        let start = Instant::now();
        gemm_faer(&mut c, &lc, &a, &la, &b, &lb, 1.0, 0.0, pool).unwrap();
        println!("time: {:?}", start.elapsed());
        let start = Instant::now();
        gemm_faer(&mut c, &lc, &a, &la, &b, &lb, 1.0, 0.0, pool).unwrap();
        println!("time: {:?}", start.elapsed());
    }

    /// SYRK with alpha = 2, beta = 1 on a 2x2 row-major input; checks the computed triangle.
    #[test]
    fn playground_2() {
        let a = vec![1.0_f64, 2.0, 3.0, 4.0];
        let mut c = vec![1.0_f64; 4];
        let la = [2, 2].c();
        let lc = [2, 2].c();
        let pool = rayon::ThreadPoolBuilder::new().num_threads(16).build().unwrap();
        let pool = Some(&pool);
        syrk_faer(&mut c, &lc, &a, &la, FlagUpLo::L, 2.0, 1.0, pool).unwrap();
        println!("{c:?}");
        // A * A^T = [[5, 11], [11, 25]]; 1 + 2 * that gives the lower triangle
        // (row-major indices 0, 2, 3) = [11, 23, 51]. The upper element is not checked.
        let expected_lower: [(usize, f64); 3] = [(0, 11.0), (2, 23.0), (3, 51.0)];
        for (idx, expected) in expected_lower {
            assert!(
                (c[idx] - expected).abs() < 1e-12,
                "c[{idx}] = {}, expected {expected}",
                c[idx]
            );
        }
    }

    /// Complex (3x2) @ (2x2) product; the element sum is -78 + 270i.
    #[test]
    #[cfg(not(feature = "col_major"))]
    fn test_minimal_correctness() {
        #[allow(non_camel_case_types)]
        type c32 = num::Complex<f32>;
        let vec_a = vec![
            c32::new(0., 1.),
            c32::new(1., 2.),
            c32::new(2., 3.),
            c32::new(3., 4.),
            c32::new(4., 5.),
            c32::new(5., 6.),
        ];
        let vec_b = vec![c32::new(0., 1.), c32::new(2., 3.), c32::new(4., 5.), c32::new(6., 7.)];
        let device = DeviceFaer::default();
        let a = asarray((vec_a, &device)).into_shape([3, 2]);
        let b = asarray((vec_b, &device)).into_shape([2, 2]);
        let c = a % b;
        let sum_c = c.raw().iter().sum::<c32>();
        // `.abs()` is required: without it any value below the target would pass.
        assert!((sum_c.re - -78.0).abs() < 1e-5);
        assert!((sum_c.im - 270.0).abs() < 1e-5);
    }
}