use num_traits::Zero;
use rayon::prelude::*;

use crate::math::scalar::Scalar;

use super::errors;
use super::tensor_trait::TensorTrait;
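
/// Total number of elements implied by `shape`.
///
/// Panics with the error from `errors::checked_num_elements` if the shape is rejected.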
#[inline]
pub fn size(shape: &[usize]) -> usize {
errors::checked_num_elements(shape).unwrap_or_else(|error| panic!("{error}"))
}
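
/// Asserts that `lhs` and `rhs` have identical shapes.
///
/// Panics with the error from `errors::ensure_same_shape` if they differ.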
#[inline]
pub fn assert_same_shape<T, Lhs, Rhs>(lhs: &Lhs, rhs: &Rhs)
where
T: Scalar,
Lhs: TensorTrait<T>,
Rhs: TensorTrait<T>,
{
errors::ensure_same_shape(lhs.shape(), rhs.shape()).unwrap_or_else(|error| panic!("{error}"));
}
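
/// Returns a copy of `tensor` with `f` applied to every element.
///
/// The mapping is performed in parallel on the cloned tensor.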
#[inline]
pub fn map<T, A, F>(tensor: &A, f: F) -> A
where
T: Scalar + Copy + Send + Sync,
A: TensorTrait<T>,
F: Fn(T) -> T + Sync + Send,
{
let mut out = tensor.clone();
out.par_map_in_place(f);
out
}
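
/// Combines `lhs` and `rhs` element-wise with `f`, producing a tensor of
/// `lhs`'s type.
///
/// The combination is performed in parallel. Panics if the shapes differ.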
#[inline]
pub fn zip_with<T, Lhs, Rhs, F>(lhs: &Lhs, rhs: &Rhs, f: F) -> Lhs
where
T: Scalar + Copy + Send + Sync,
Lhs: TensorTrait<T>,
Rhs: TensorTrait<T>,
F: Fn(T, T) -> T + Sync + Send,
{
assert_same_shape::<T, _, _>(lhs, rhs);
let mut out = lhs.clone();
out.par_zip_with_inplace(rhs, f);
out
}
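
/// Element-wise complex conjugate of `tensor`.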
#[inline]
pub fn conj<T, A>(tensor: &A) -> A
where
T: Scalar + Copy + Send + Sync,
A: TensorTrait<T>,
{
map(tensor, Scalar::conj)
}
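
/// Element-wise absolute value (magnitude) of `tensor`.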
#[inline]
pub fn abs<T, A>(tensor: &A) -> A
where
T: Scalar + Copy + Send + Sync,
A: TensorTrait<T>,
{
map(tensor, Scalar::abs)
}
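
/// Element-wise squared magnitude of `tensor`.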
#[inline]
pub fn norm_sqr<T, A>(tensor: &A) -> A
where
T: Scalar + Copy + Send + Sync,
A: TensorTrait<T>,
{
map(tensor, Scalar::norm_sqr)
}
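
/// Element-wise square root of `tensor`.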
#[inline]
pub fn sqrt<T, A>(tensor: &A) -> A
where
T: Scalar + Copy + Send + Sync,
A: TensorTrait<T>,
{
map(tensor, Scalar::sqrt)
}
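
/// Multiplies every element of `tensor` by `scalar`.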
#[inline]
pub fn scalar_mul<T, A>(tensor: &A, scalar: T) -> A
where
T: Scalar + Copy + Send + Sync,
A: TensorTrait<T>,
{
map(tensor, move |x| x * scalar)
}
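
/// Element-wise (Hadamard) product of `lhs` and `rhs`.
///
/// Panics if the shapes differ.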
#[inline]
pub fn elem_mul<T, Lhs, Rhs>(lhs: &Lhs, rhs: &Rhs) -> Lhs
where
T: Scalar + Copy + Send + Sync,
Lhs: TensorTrait<T>,
Rhs: TensorTrait<T>,
{
zip_with(lhs, rhs, |a, b| a * b)
}
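
/// Element-wise division of `lhs` by `rhs`.
///
/// Panics if the shapes differ.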
#[inline]
pub fn elem_div<T, Lhs, Rhs>(lhs: &Lhs, rhs: &Rhs) -> Lhs
where
T: Scalar + Copy + Send + Sync,
Lhs: TensorTrait<T>,
Rhs: TensorTrait<T>,
{
zip_with(lhs, rhs, |a, b| a / b)
}
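
/// Transpose of a rank-2 tensor: the result has shape `[cols, rows]` with
/// element `(j, i)` equal to `tensor`'s element `(i, j)`.
///
/// Panics if `tensor` is not rank 2.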
pub fn transpose<T, A>(tensor: &A) -> A
where
T: Scalar + Copy + Send + Sync,
A: TensorTrait<T>,
{
assert_eq!(tensor.rank(), 2, "transpose requires a rank-2 tensor");
let rows = tensor.shape()[0];
let cols = tensor.shape()[1];
let entries: Vec<((isize, isize), T)> = (0..size(&[rows, cols]))
.into_par_iter()
.map(|k| {
let i = k / cols;
let j = k % cols;
(
(j as isize, i as isize),
tensor.get(&[i as isize, j as isize]),
)
})
.collect();
let mut out = A::empty(&[cols, rows]);
for (idx, value) in entries {
out.set(&[idx.0, idx.1], value);
}
out
}
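
/// Conjugate (Hermitian) transpose of a rank-2 tensor: like [`transpose`],
/// but each element is also conjugated.
///
/// Panics if `tensor` is not rank 2.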
pub fn hermitian_transpose<T, A>(tensor: &A) -> A
where
T: Scalar + Copy + Send + Sync,
A: TensorTrait<T>,
{
assert_eq!(
tensor.rank(),
2,
"hermitian_transpose requires a rank-2 tensor"
);
let rows = tensor.shape()[0];
let cols = tensor.shape()[1];
let entries: Vec<((isize, isize), T)> = (0..size(&[rows, cols]))
.into_par_iter()
.map(|k| {
let i = k / cols;
let j = k % cols;
(
(j as isize, i as isize),
tensor.get(&[i as isize, j as isize]).conj(),
)
})
.collect();
let mut out = A::empty(&[cols, rows]);
for (idx, value) in entries {
out.set(&[idx.0, idx.1], value);
}
out
}
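
/// Unconjugated inner product: the sum of `lhs[idx] * rhs[idx]` over all
/// elements. For the complex Hermitian inner product use [`hermitian_dot`].
///
/// Panics if the shapes differ.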
pub fn dot<T, Lhs, Rhs>(lhs: &Lhs, rhs: &Rhs) -> T
where
T: Scalar + Copy + Send + Sync,
Lhs: TensorTrait<T>,
Rhs: TensorTrait<T>,
{
assert_same_shape::<T, _, _>(lhs, rhs);
(0..lhs.size())
.into_par_iter()
.map(|k| {
let idx = flat_to_index(lhs.shape(), k);
lhs.get(&idx) * rhs.get(&idx)
})
.reduce(|| T::zero(), |a, b| a + b)
}
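
/// Hermitian inner product: the sum of `conj(lhs[idx]) * rhs[idx]` over all
/// elements.
///
/// Panics if the shapes differ.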
pub fn hermitian_dot<T, Lhs, Rhs>(lhs: &Lhs, rhs: &Rhs) -> T
where
T: Scalar + Copy + Send + Sync,
Lhs: TensorTrait<T>,
Rhs: TensorTrait<T>,
{
assert_same_shape::<T, _, _>(lhs, rhs);
(0..lhs.size())
.into_par_iter()
.map(|k| {
let idx = flat_to_index(lhs.shape(), k);
lhs.get(&idx).conj() * rhs.get(&idx)
})
.reduce(|| T::zero(), |a, b| a + b)
}
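
/// Sum of the squared magnitudes of all elements, returned as the underlying
/// real type.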
pub fn norm_sqr_real<T, A>(tensor: &A) -> T::Real
where
T: Scalar + Copy + Send + Sync,
T::Real: Send + Sync,
A: TensorTrait<T>,
{
(0..tensor.size())
.into_par_iter()
.map(|k| {
let idx = flat_to_index(tensor.shape(), k);
tensor.get(&idx).norm_sqr_real()
})
.reduce(|| T::Real::zero(), |a, b| a + b)
}
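
/// Euclidean (Frobenius) norm of the tensor: the square root of
/// [`norm_sqr_real`], returned as a scalar with zero imaginary part.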
pub fn norm<T, A>(tensor: &A) -> T
where
T: Scalar + Copy + Send + Sync,
T::Real: Send + Sync,
A: TensorTrait<T>,
{
T::from_re_im(norm_sqr_real::<T, A>(tensor).sqrt(), T::Real::zero())
}
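
/// Cross product of two length-3 vectors, returned as a tensor of `lhs`'s
/// type.
///
/// Panics if either argument is not a rank-1 tensor of length 3.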
pub fn cross<T, Lhs, Rhs>(lhs: &Lhs, rhs: &Rhs) -> Lhs
where
T: Scalar + Copy + Send + Sync,
Lhs: TensorTrait<T>,
Rhs: TensorTrait<T>,
{
assert_vector_len(lhs, 3, "cross");
assert_vector_len(rhs, 3, "cross");
let a0 = lhs.get(&[0]);
let a1 = lhs.get(&[1]);
let a2 = lhs.get(&[2]);
let b0 = rhs.get(&[0]);
let b1 = rhs.get(&[1]);
let b2 = rhs.get(&[2]);
    let mut out = Lhs::empty(&[3]);
    out.set(&[0], a1 * b2 - a2 * b1);
    out.set(&[1], a2 * b0 - a0 * b2);
    out.set(&[2], a0 * b1 - a1 * b0);
    out
}
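
/// Wedge (exterior) product of two vectors of equal length `n`: the
/// antisymmetric `n x n` matrix with entry `(i, j)` equal to
/// `lhs[i] * rhs[j] - lhs[j] * rhs[i]`.
///
/// Panics if either argument is not rank 1 or the lengths differ.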
pub fn wedge<T, Lhs, Rhs>(lhs: &Lhs, rhs: &Rhs) -> Lhs
where
T: Scalar + Copy + Send + Sync,
Lhs: TensorTrait<T>,
Rhs: TensorTrait<T>,
{
assert_vector(lhs, "wedge");
assert_vector(rhs, "wedge");
let n = lhs.shape()[0];
assert_eq!(rhs.shape()[0], n, "wedge vector length mismatch");
let entries: Vec<((isize, isize), T)> = (0..size(&[n, n]))
.into_par_iter()
.map(|k| {
let i = k / n;
let j = k % n;
let value = lhs.get(&[i as isize]) * rhs.get(&[j as isize])
- lhs.get(&[j as isize]) * rhs.get(&[i as isize]);
((i as isize, j as isize), value)
})
.collect();
let mut out = Lhs::empty(&[n, n]);
for (idx, value) in entries {
out.set(&[idx.0, idx.1], value);
}
out
}
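
/// Dense matrix product of two rank-2 tensors: `lhs` is `rows x inner`,
/// `rhs` is `inner x cols`, and the result is `rows x cols`.
///
/// Panics if either argument is not rank 2 or the inner dimensions differ.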
pub fn matmul<T, Lhs, Rhs>(lhs: &Lhs, rhs: &Rhs) -> Lhs
where
T: Scalar + Copy + Send + Sync,
Lhs: TensorTrait<T>,
Rhs: TensorTrait<T>,
{
assert_eq!(lhs.rank(), 2, "matmul lhs must be rank 2");
assert_eq!(rhs.rank(), 2, "matmul rhs must be rank 2");
let rows = lhs.shape()[0];
let inner = lhs.shape()[1];
assert_eq!(rhs.shape()[0], inner, "matmul inner dimensions mismatch");
let cols = rhs.shape()[1];
let entries: Vec<((isize, isize), T)> = (0..size(&[rows, cols]))
.into_par_iter()
.map(|k| {
let i = k / cols;
let j = k % cols;
let value = (0..inner).fold(T::zero(), |acc, m| {
acc + lhs.get(&[i as isize, m as isize]) * rhs.get(&[m as isize, j as isize])
});
((i as isize, j as isize), value)
})
.collect();
let mut out = Lhs::empty(&[rows, cols]);
for (idx, value) in entries {
out.set(&[idx.0, idx.1], value);
}
out
}
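
/// Converts a flat row-major element index into a per-axis index for `shape`
/// (the last axis varies fastest).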
fn flat_to_index(shape: &[usize], mut flat: usize) -> Vec<isize> {
let mut idx = vec![0isize; shape.len()];
for axis in (0..shape.len()).rev() {
let dim = shape[axis];
idx[axis] = (flat % dim) as isize;
flat /= dim;
}
idx
}
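
/// Panics unless `tensor` has rank 1.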
fn assert_vector<T, A>(tensor: &A, op: &str)
where
T: Scalar,
A: TensorTrait<T>,
{
assert_eq!(tensor.rank(), 1, "{op} requires rank-1 vectors");
}
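
/// Panics unless `tensor` is a rank-1 vector of exactly `len` elements.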
fn assert_vector_len<T, A>(tensor: &A, len: usize, op: &str)
where
T: Scalar,
A: TensorTrait<T>,
{
assert_vector::<T, A>(tensor, op);
assert_eq!(tensor.shape()[0], len, "{op} requires length-{len} vectors");
}