use num_complex::ComplexFloat;
use num_traits::{One, Zero};
use mdarray::{DSlice, DTensor, DynRank, Layout, Slice, Tensor};
/// Operand-side argument for [`MatMulBuilder::special`].
///
/// NOTE(review): presumably mirrors the BLAS `SIDE` argument, i.e. whether the
/// structured matrix is the left or the right factor of the product — confirm
/// against the backend implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Side {
    /// Presumably: the structured matrix is the left factor.
    Left,
    /// Presumably: the structured matrix is the right factor.
    Right,
}
/// Matrix-structure argument for [`MatMulBuilder::special`].
///
/// NOTE(review): the variant names suggest BLAS-style symmetric (`symm`),
/// Hermitian (`hemm`) and triangular (`trmm`) kernels — confirm against the
/// backend implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Type {
    /// Presumably a symmetric matrix.
    Sym,
    /// Presumably a Hermitian matrix.
    Her,
    /// Presumably a triangular matrix.
    Tri,
}
/// Triangle-selection argument for [`MatMulBuilder::special`].
///
/// NOTE(review): presumably mirrors the BLAS `UPLO` argument (which triangle
/// of the structured matrix is referenced) — confirm against the backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Triangle {
    /// Presumably: the upper triangle is referenced.
    Upper,
    /// Presumably: the lower triangle is referenced.
    Lower,
}
/// Backend entry point for matrix products and tensor contractions.
///
/// Every method only takes the operands (borrowed for `'a`) and returns a
/// builder. NOTE(review): the work is presumably deferred until one of the
/// builder's terminal methods (`eval`, `overwrite`, ...) is called — confirm
/// against the implementations.
pub trait MatMul<T: One> {
/// Starts a matrix product of the 2-D operands `a` and `b`.
///
/// `_contract` calls this with `a` shaped `[m, k]` and `b` shaped `[k, n]` and
/// treats the evaluated result as `[m, n]`, i.e. presumably the ordinary
/// product `a · b`.
fn matmul<'a, La, Lb>(
&self,
a: &'a DSlice<T, 2, La>,
b: &'a DSlice<T, 2, Lb>,
) -> impl MatMulBuilder<'a, T, La, Lb>
where
T: One,
La: Layout,
Lb: Layout;
/// Starts a contraction of `a` and `b`.
///
/// Presumably contracts over all axes, like `Axes::All` in `_contract` —
/// confirm against the implementations.
fn contract_all<'a, La, Lb>(
&self,
a: &'a Slice<T, DynRank, La>,
b: &'a Slice<T, DynRank, Lb>,
) -> impl ContractBuilder<'a, T, La, Lb>
where
T: 'a,
La: Layout,
Lb: Layout;
/// Starts a contraction of `a` and `b` over `n` axes.
///
/// Presumably the last `n` axes of `a` with the first `n` axes of `b`, like
/// `Axes::LastFirst { k: n }` in `_contract` — confirm against the
/// implementations.
fn contract_n<'a, La, Lb>(
&self,
a: &'a Slice<T, DynRank, La>,
b: &'a Slice<T, DynRank, Lb>,
n: usize,
) -> impl ContractBuilder<'a, T, La, Lb>
where
T: 'a,
La: Layout,
Lb: Layout;
/// Starts a contraction of `a` and `b` over explicitly listed axes.
///
/// Presumably `axes_a[i]` of `a` is contracted with `axes_b[i]` of `b`, like
/// `Axes::Specific` in `_contract` — confirm against the implementations.
fn contract<'a, La, Lb>(
&self,
a: &'a Slice<T, DynRank, La>,
b: &'a Slice<T, DynRank, Lb>,
axes_a: impl Into<Box<[usize]>>,
axes_b: impl Into<Box<[usize]>>,
) -> impl ContractBuilder<'a, T, La, Lb>
where
T: 'a,
La: Layout,
Lb: Layout;
}
/// Configures and then runs a matrix product started by [`MatMul::matmul`].
///
/// The configuration methods take `self` and return the updated builder. The
/// remaining methods consume the builder and produce or store the result.
pub trait MatMulBuilder<'a, T, La, Lb>
where
La: Layout,
Lb: Layout,
T: 'a,
La: 'a,
Lb: 'a,
{
/// Presumably requests a parallel (multi-threaded) evaluation — confirm.
fn parallelize(self) -> Self;
/// Sets a factor applied to the product (`_contract` passes its `alpha` here).
fn scale(self, factor: T) -> Self;
/// Evaluates the product and returns it as an owned 2-D tensor.
fn eval(self) -> DTensor<T, 2>;
/// Writes the product into `c`; presumably replaces its previous contents.
fn overwrite<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>);
/// Presumably accumulates the product into `c` (`c += product`) — confirm.
fn add_to<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>);
/// Presumably computes `c = product + beta * c` (BLAS-style `beta`) — confirm.
fn add_to_scaled<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>, beta: T);
/// Evaluates the product with a structured-matrix kernel and returns it.
///
/// `lr`, `type_of_matrix` and `tr` select the operand side, the matrix
/// structure and the referenced triangle. NOTE(review): presumably maps to
/// BLAS `symm`/`hemm`/`trmm` — confirm against the implementations.
fn special(self, lr: Side, type_of_matrix: Type, tr: Triangle) -> DTensor<T, 2>;
}
/// Configures and then runs a tensor contraction started by one of the
/// `MatMul::contract*` methods.
pub trait ContractBuilder<'a, T, La, Lb>
where
T: 'a,
La: Layout,
Lb: Layout,
{
/// Sets a factor applied to the contraction result.
fn scale(self, factor: T) -> Self;
/// Evaluates the contraction and returns it as an owned dynamic-rank tensor.
fn eval(self) -> Tensor<T, DynRank>;
/// Writes the contraction result into `c`; presumably replaces its contents.
fn overwrite(self, c: &mut Slice<T>);
}
/// Selects which axes [`_contract`] contracts between tensor A and tensor B.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Axes {
    /// Contract every axis of A with the axis of B at the same position.
    /// Both tensors must therefore have the same rank.
    All,
    /// Contract the last `k` axes of A with the first `k` axes of B, in order.
    LastFirst { k: usize },
    /// Contract the listed axes of A with the listed axes of B, paired by
    /// position (the first entry of each list form a pair, and so on).
    Specific(Box<[usize]>, Box<[usize]>),
}
/// Contracts tensor `a` with tensor `b` over the axes selected by `axes` and
/// scales the result by `alpha`.
///
/// The contraction is lowered to a single matrix product:
/// * `a` is permuted to (kept axes, contracted axes) and viewed as a
///   `[keep_size_a, contract_size]` matrix;
/// * `b` is permuted to (contracted axes, kept axes) and viewed as a
///   `[contract_size, keep_size_b]` matrix;
/// * the product computed by `bd` is reshaped to the kept dimensions of `a`
///   followed by the kept dimensions of `b`.
///
/// # Arguments
/// * `bd` - backend performing the matrix product (assumed to compute `A · B`).
/// * `a`, `b` - the input tensors.
/// * `axes` - which axes to contract; contracted axes are paired by position.
/// * `alpha` - scale factor forwarded to the backend via `MatMulBuilder::scale`.
///
/// # Returns
/// A tensor of shape `kept(a) ++ kept(b)`. When every axis of both tensors is
/// contracted, the `[1, 1]` product matrix is returned as-is (converted to
/// dynamic rank), not as a rank-0 tensor.
///
/// # Panics
/// * `Axes::LastFirst { k }` with `k` larger than the rank of `a` or of `b`.
/// * A different number of contraction axes for `a` and `b`.
/// * A contraction axis that is out of range for its tensor.
/// * A paired contraction axis with different extents in `a` and `b`.
/// * The same axis listed more than once for one tensor.
pub fn _contract<T: Zero + ComplexFloat, La: Layout, Lb: Layout>(
    bd: impl MatMul<T>,
    a: &Slice<T, DynRank, La>,
    b: &Slice<T, DynRank, Lb>,
    axes: Axes,
    alpha: T,
) -> Tensor<T, DynRank> {
    let rank_a = a.rank();
    let rank_b = b.rank();
    let extract_shape = |s: &DynRank| match s {
        DynRank::Dyn(arr) => arr.clone(),
        DynRank::One(n) => Box::new([*n]),
    };
    let shape_a = extract_shape(a.shape());
    let shape_b = extract_shape(b.shape());

    let (axes_a, axes_b) = match axes {
        Axes::All => ((0..rank_a).collect(), (0..rank_b).collect()),
        Axes::LastFirst { k } => {
            // Checked explicitly: otherwise `rank_a - k` underflows (or B is
            // indexed past its rank) with an unhelpful panic message.
            assert!(
                k <= rank_a && k <= rank_b,
                "Cannot contract {} axes: tensor A has rank {}, tensor B has rank {}",
                k,
                rank_a,
                rank_b
            );
            (((rank_a - k)..rank_a).collect(), (0..k).collect())
        }
        Axes::Specific(ax_a, ax_b) => (ax_a, ax_b),
    };
    assert_eq!(
        axes_a.len(),
        axes_b.len(),
        "Axis count mismatch: {} (tensor A) vs {} (tensor B)",
        axes_a.len(),
        axes_b.len()
    );
    for (&a_ax, &b_ax) in axes_a.iter().zip(axes_b.iter()) {
        assert!(
            a_ax < rank_a,
            "Axis {} out of range for tensor A of rank {}",
            a_ax,
            rank_a
        );
        assert!(
            b_ax < rank_b,
            "Axis {} out of range for tensor B of rank {}",
            b_ax,
            rank_b
        );
        assert_eq!(
            shape_a[a_ax], shape_b[b_ax],
            "Dimension mismatch at contraction: A[axis {}] = {} ≠ B[axis {}] = {}",
            a_ax, shape_a[a_ax], b_ax, shape_b[b_ax]
        );
    }

    let compute_keep_axes = |rank: usize, axes: &[usize]| -> Vec<usize> {
        (0..rank).filter(|k| !axes.contains(k)).collect()
    };
    let keep_axes_a = compute_keep_axes(rank_a, &axes_a);
    let keep_axes_b = compute_keep_axes(rank_b, &axes_b);
    // All axes are in range, so kept + contracted equals the rank exactly when
    // no contraction axis is repeated. A repeat would corrupt the permutation.
    assert_eq!(
        keep_axes_a.len() + axes_a.len(),
        rank_a,
        "Duplicate contraction axis for tensor A: {:?}",
        axes_a
    );
    assert_eq!(
        keep_axes_b.len() + axes_b.len(),
        rank_b,
        "Duplicate contraction axis for tensor B: {:?}",
        axes_b
    );

    let compute_keep_shape = |axes: &[usize], shape: &[usize]| -> Vec<usize> {
        axes.iter().map(|&ax| shape[ax]).collect()
    };
    let mut keep_shape_a = compute_keep_shape(&keep_axes_a, &shape_a);
    let keep_shape_b = compute_keep_shape(&keep_axes_b, &shape_b);
    let compute_size =
        |axes: &[usize], shape: &[usize]| -> usize { axes.iter().map(|&k| shape[k]).product() };
    let contract_size_a = compute_size(&axes_a, &shape_a);
    let contract_size_b = compute_size(&axes_b, &shape_b);
    let keep_size_a = compute_size(&keep_axes_a, &shape_a);
    let keep_size_b = compute_size(&keep_axes_b, &shape_b);

    // Move the kept axes of A to the front and the kept axes of B to the back,
    // so both operands become plain matrices sharing the contracted dimension.
    let order_a: Vec<usize> = keep_axes_a.iter().chain(axes_a.iter()).copied().collect();
    let order_b: Vec<usize> = axes_b.iter().chain(keep_axes_b.iter()).copied().collect();
    let trans_a = a.permute(order_a).to_tensor();
    let trans_b = b.permute(order_b).to_tensor();
    let a_resh = trans_a.reshape([keep_size_a, contract_size_a]);
    let b_resh = trans_b.reshape([contract_size_b, keep_size_b]);
    let ab_resh = bd.matmul(&a_resh, &b_resh).scale(alpha).eval();

    if keep_shape_a.is_empty() && keep_shape_b.is_empty() {
        // Full contraction: return the [1, 1] product matrix unchanged.
        ab_resh.to_owned().into_dyn()
    } else {
        // `ab_resh` is an owned `[keep_size_a, keep_size_b]` matrix whose
        // elements are already ordered like the concatenated kept shape. This
        // also covers the cases where only one side keeps axes (a factor of 1).
        // The former single-sided branches reshaped to the other side's kept
        // shape, which was wrong.
        keep_shape_a.extend(keep_shape_b);
        ab_resh.reshape(keep_shape_a).to_owned().into_dyn().into()
    }
}