use mdarray::{DSlice, DTensor, Layout, Shape, Slice, tensor};
use num_complex::ComplexFloat;
use num_traits::{One, Zero};
/// Prints a 2-D matrix to stdout, one matrix row per output line.
///
/// Each entry is written as its real part (right-aligned, width 10, 4 decimals)
/// followed by its imaginary part with an explicit sign and an `i` suffix.
/// A blank line is printed after the last row.
pub fn pretty_print<T: ComplexFloat + std::fmt::Display>(mat: &DTensor<T, 2>)
where
    <T as num_complex::ComplexFloat>::Real: std::fmt::Display,
{
    let (rows, cols) = *mat.shape();
    for r in 0..rows {
        for c in 0..cols {
            let z = mat[[r, c]];
            print!("{:>10.4} {:+.4}i ", z.re(), z.im());
        }
        println!();
    }
    println!();
}
/// Converts `x` into an `i32`.
///
/// Used by the dimension helpers in this module to turn `usize`/`isize`
/// sizes and strides into `i32` values.
///
/// # Panics
///
/// Panics with `"dimension must fit into i32"` if the value is out of range.
pub fn into_i32<T>(x: T) -> i32
where
    T: TryInto<i32>,
    <T as TryInto<i32>>::Error: std::fmt::Debug,
{
    let converted = <T as TryInto<i32>>::try_into(x);
    converted.expect("dimension must fit into i32")
}
/// Returns the `(rows, cols)` of each matrix argument as `(i32, i32)` pairs.
///
/// With several arguments the result is a tuple of pairs, e.g.
/// `get_dims!(a, b)` yields `((ra, ca), (rb, cb))`. With a single argument the
/// expansion is just the parenthesised pair `(r, c)`, not a one-element tuple.
///
/// Each argument must have a `shape()` method whose result exposes `.0`/`.1`.
/// The expansion calls `into_i32` unqualified, so that function must be in
/// scope at the invocation site.
#[macro_export]
macro_rules! get_dims {
    ( $( $m:expr ),+ ) => {
        (
            $(
                {
                    let dims = $m.shape();
                    let rows = into_i32(dims.0);
                    let cols = into_i32(dims.1);
                    (rows, cols)
                }
            ),*
        )
    };
}
/// Checks that the shapes are compatible with multiplying an `m × k` matrix `a`
/// by a `k × n` matrix `b` into an `m × n` matrix `c`.
///
/// Returns `(m, n, k)` converted to `i32`.
///
/// # Panics
///
/// The checks run in this order, and the first failing one panics:
/// - the row counts of `a` and `c` differ;
/// - the column counts of `b` and `c` differ;
/// - `a`'s column count differs from `b`'s row count.
///
/// Also panics if any returned dimension does not fit into `i32`.
pub fn dims3(
    a_shape: &(usize, usize),
    b_shape: &(usize, usize),
    c_shape: &(usize, usize),
) -> (i32, i32, i32) {
    let &(a_rows, a_cols) = a_shape;
    let &(b_rows, b_cols) = b_shape;
    let &(c_rows, c_cols) = c_shape;
    assert!(a_rows == c_rows, "a and c must agree in number of rows");
    assert!(b_cols == c_cols, "b and c must agree in number of columns");
    assert!(
        a_cols == b_rows,
        "a's number of columns must be equal to b's number of rows"
    );
    (into_i32(a_rows), into_i32(b_cols), into_i32(a_cols))
}
/// Checks that an `m × k` matrix `a` can be multiplied by a `k × n` matrix `b`.
///
/// Returns `(m, n)` converted to `i32`.
///
/// # Panics
///
/// Panics if `a`'s column count differs from `b`'s row count, or if a returned
/// dimension does not fit into `i32`.
pub fn dims2(a_shape: &(usize, usize), b_shape: &(usize, usize)) -> (i32, i32) {
    let &(a_rows, a_cols) = a_shape;
    let &(b_rows, b_cols) = b_shape;
    let inner_match = a_cols == b_rows;
    assert!(
        inner_match,
        "a's number of columns must be equal to b's number of rows"
    );
    (into_i32(a_rows), into_i32(b_cols))
}
/// Chooses an order flag and the matching stride for a 2-D matrix expression.
///
/// - If axis 1 has unit stride, yields `($same_order, stride(0))`.
/// - Otherwise axis 0 must have unit stride, and it yields `($other_order, stride(1))`.
///
/// The stride is converted with `into_i32`. NOTE(review): this presumably feeds
/// the order/transpose flag and leading dimension of a BLAS/LAPACK call; confirm
/// at the call sites.
///
/// `$x` is evaluated exactly once. The expansion calls `into_i32` unqualified,
/// so that function must be in scope at the invocation site.
///
/// # Panics
///
/// Panics with "`<$x> must be contiguous in one dimension`" if neither axis has
/// unit stride.
#[macro_export]
macro_rules! trans_stride {
    ($x:expr, $same_order:expr, $other_order:expr) => {{
        // Bind once so `$x` is not re-evaluated for every `stride` call.
        let mat = &$x;
        if mat.stride(1) == 1 {
            ($same_order, into_i32(mat.stride(0)))
        } else {
            // Pass the stringified expression as an argument, not as the format
            // string itself, so expressions containing `{`/`}` still compile.
            assert!(
                mat.stride(0) == 1,
                "{}",
                stringify!($x must be contiguous in one dimension)
            );
            ($other_order, into_i32(mat.stride(1)))
        }
    }};
}
/// Rewrites the flat element buffer of `c` so that it holds the transpose.
///
/// With `(m, n) = c.shape()`, the value read from flat index `i * n + j` ends up
/// at flat index `j * m + i`.
/// - Square input: mirrored flat positions are swapped in place.
/// - Rectangular input: values are staged in a temporary `n × m` tensor and then
///   copied back.
///
/// NOTE(review): the shape of `c` itself is not changed. For non-square input,
/// callers must reinterpret the buffer as `(n, m)`. This also assumes flat
/// `usize` indexing walks `c` in row-major order; confirm for non-dense layouts.
pub fn transpose_in_place<T, L>(c: &mut DSlice<T, 2, L>)
where
    T: ComplexFloat + Default,
    L: Layout,
{
    let (rows, cols) = *c.shape();
    if rows != cols {
        // Stage the transposed values in a scratch tensor, then copy its flat
        // buffer back over `c` (indices 0..rows*cols in increasing order).
        let mut scratch = tensor![[T::default(); rows]; cols];
        for col in 0..cols {
            for row in 0..rows {
                scratch[col * rows + row] = c[row * cols + col];
            }
        }
        for flat in 0..rows * cols {
            c[flat] = scratch[flat];
        }
        return;
    }
    // Square case: swap each position above the diagonal with its mirror below.
    for row in 0..rows {
        for col in (row + 1)..cols {
            c.swap(row * cols + col, col * cols + row);
        }
    }
}
pub fn ipiv_to_perm_mat<T: ComplexFloat>(ipiv: &[i32], m: usize) -> DTensor<T, 2> {
let mut p = tensor![[T::zero(); m]; m];
for i in 0..m {
p[[i, i]] = T::one();
}
for i in 0..ipiv.len() {
let pivot_row = (ipiv[i] - 1) as usize; if pivot_row != i {
for j in 0..m {
let temp = p[[i, j]];
p[[i, j]] = p[[pivot_row, j]];
p[[pivot_row, j]] = temp;
}
}
}
p
}
/// Returns the transpose of `c` as a new `n × m` tensor, where `(m, n) = c.shape()`.
///
/// Element `[[j, i]]` of the result equals `c[[i, j]]`. Assuming `DTensor`
/// stores its elements row-major (confirm against mdarray), the result's buffer
/// is `c`'s data in column-major order.
pub fn to_col_major<T, L>(c: &DSlice<T, 2, L>) -> DTensor<T, 2>
where
    T: ComplexFloat + Default + Clone,
    L: Layout,
{
    let (rows, cols) = *c.shape();
    DTensor::<T, 2>::from_fn([cols, rows], |idx| c[[idx[1], idx[0]]])
}
/// Sums the diagonal entries `a[[i, i]]` of a square matrix.
///
/// The sum starts from zero and adds the entries in increasing `i`.
///
/// # Panics
///
/// Panics if `a` is not square.
pub fn trace<T, L>(a: &DSlice<T, 2, L>) -> T
where
    T: ComplexFloat + std::ops::Add<Output = T> + Copy,
    L: Layout,
{
    let (rows, cols) = *a.shape();
    assert_eq!(rows, cols, "trace is only defined for square matrices");
    (0..cols).fold(T::zero(), |acc, d| acc + a[[d, d]])
}
/// Builds the `n × n` identity matrix: ones on the main diagonal, zeros elsewhere.
pub fn identity<T: Zero + One>(n: usize) -> DTensor<T, 2> {
    DTensor::<T, 2>::from_fn([n, n], |pos| {
        let on_diagonal = pos[0] == pos[1];
        if on_diagonal {
            T::one()
        } else {
            T::zero()
        }
    })
}
/// Builds an `n × n` matrix with ones on the `k`-th diagonal and zeros elsewhere.
///
/// Entry `(row, col)` is one exactly when `col - row == k`:
/// - `k = 0` gives the identity;
/// - positive `k` selects a diagonal above the main one;
/// - negative `k` selects a diagonal below it.
pub fn identity_k<T: Zero + One>(n: usize, k: isize) -> DTensor<T, 2> {
    DTensor::<T, 2>::from_fn([n, n], |pos| {
        let row = pos[0] as isize;
        let col = pos[1] as isize;
        if col - row != k {
            T::zero()
        } else {
            T::one()
        }
    })
}
/// Computes the Kronecker product `a ⊗ b`.
///
/// For `a` of shape `(ma, na)` and `b` of shape `(mb, nb)`, the result has shape
/// `(ma * mb, na * nb)`. Entry `(r, c)` equals
/// `a[[r / mb, c / nb]] * b[[r % mb, c % nb]]`.
pub fn kron<T, La, Lb>(a: &DSlice<T, 2, La>, b: &DSlice<T, 2, Lb>) -> DTensor<T, 2>
where
    T: ComplexFloat + std::ops::Mul<Output = T> + Copy,
    La: Layout,
    Lb: Layout,
{
    let (a_rows, a_cols) = *a.shape();
    let (b_rows, b_cols) = *b.shape();
    DTensor::<T, 2>::from_fn([a_rows * b_rows, a_cols * b_cols], |pos| {
        let (row, col) = (pos[0], pos[1]);
        // The block index selects the entry of `a`; the offset inside the block
        // selects the entry of `b`.
        let (block_r, inner_r) = (row / b_rows, row % b_rows);
        let (block_c, inner_c) = (col / b_cols, col % b_cols);
        a[[block_r, block_c]] * b[[inner_r, inner_c]]
    })
}
/// Converts a flat element index into per-axis coordinates for `x`'s shape.
///
/// Coordinates are peeled off from the last axis first, so the last axis varies
/// fastest (row-major / C order). A rank-0 `x` yields an empty vector.
///
/// # Panics
///
/// Panics if `flat >= x.len()`.
pub fn unravel_index<T, S: Shape, L: Layout>(x: &Slice<T, S, L>, flat: usize) -> Vec<usize> {
    let total = x.len();
    assert!(
        flat < total,
        "flat index out of bounds: {} >= {}",
        flat,
        total
    );
    let shape = x.shape();
    let mut coords = vec![0usize; x.rank()];
    let mut rest = flat;
    for (axis, coord) in coords.iter_mut().enumerate().rev() {
        let extent = shape.dim(axis);
        *coord = rest % extent;
        rest /= extent;
    }
    coords
}