use ndarray::ArrayD;
use ndarray::IxDyn;
pub fn tensordot(
a: &ArrayD<f64>,
b: &ArrayD<f64>,
axes_a: &[usize],
axes_b: &[usize],
) -> Result<ArrayD<f64>, String> {
if axes_a.len() != axes_b.len() {
return Err("Contracted axes \
must have the \
same length."
.to_string());
}
for (&ax_a, &ax_b) in axes_a.iter().zip(axes_b.iter()) {
if a.shape()[ax_a] != b.shape()[ax_b] {
return Err(format!(
"Dimension mismatch \
on contracted axes: \
{} != {}",
a.shape()[ax_a],
b.shape()[ax_b]
));
}
}
let free_axes_a: Vec<_> = (0..a.ndim()).filter(|i| !axes_a.contains(i)).collect();
let free_axes_b: Vec<_> = (0..b.ndim()).filter(|i| !axes_b.contains(i)).collect();
let perm_a: Vec<_> = free_axes_a.iter().chain(axes_a.iter()).copied().collect();
let perm_b: Vec<_> = axes_b.iter().chain(free_axes_b.iter()).copied().collect();
let a_perm = a.clone().permuted_axes(perm_a);
let b_perm = b.clone().permuted_axes(perm_b);
let free_dim_a = free_axes_a.iter().map(|&i| a.shape()[i]).product::<usize>();
let free_dim_b = free_axes_b.iter().map(|&i| b.shape()[i]).product::<usize>();
let contracted_dim = axes_a.iter().map(|&i| a.shape()[i]).product::<usize>();
let a_mat = a_perm
.to_shape((free_dim_a, contracted_dim))
.map_err(|e| e.to_string())?
.to_owned();
let b_mat = b_perm
.to_shape((contracted_dim, free_dim_b))
.map_err(|e| e.to_string())?
.to_owned();
let result_mat = a_mat.dot(&b_mat);
let mut final_shape_dims = Vec::new();
final_shape_dims.extend(free_axes_a.iter().map(|&i| a.shape()[i]));
final_shape_dims.extend(free_axes_b.iter().map(|&i| b.shape()[i]));
Ok(result_mat
.to_shape(IxDyn(&final_shape_dims))
.map_err(|e| e.to_string())?
.to_owned())
}
pub fn outer_product(
a: &ArrayD<f64>,
b: &ArrayD<f64>,
) -> Result<ArrayD<f64>, String> {
let mut new_shape = a.shape().to_vec();
new_shape.extend_from_slice(b.shape());
let a_flat = a.as_slice().ok_or_else(|| {
"Input tensor 'a' is not \
contiguous"
.to_string()
})?;
let b_flat = b.as_slice().ok_or_else(|| {
"Input tensor 'b' is not \
contiguous"
.to_string()
})?;
let mut result_data = Vec::with_capacity(a.len() * b.len());
for val_a in a_flat {
for val_b in b_flat {
result_data.push(val_a * val_b);
}
}
ArrayD::from_shape_vec(IxDyn(&new_shape), result_data).map_err(|e| e.to_string())
}
pub fn tensor_vec_mul(
tensor: &ArrayD<f64>,
vector: &[f64],
) -> Result<ArrayD<f64>, String> {
if tensor.ndim() < 1 {
return Err("Tensor must \
have at least \
one dimension."
.to_string());
}
let last_dim = tensor.shape()[tensor.ndim() - 1];
if last_dim != vector.len() {
return Err(format!(
"Dimension mismatch: last \
tensor dim {} != vector \
length {}",
last_dim,
vector.len()
));
}
let vec_arr = ndarray::Array1::from_vec(vector.to_vec());
let res = tensordot(tensor, &vec_arr.into_dyn(), &[tensor.ndim() - 1], &[0])?;
Ok(res)
}
pub fn inner_product(
a: &ArrayD<f64>,
b: &ArrayD<f64>,
) -> Result<f64, String> {
if a.shape() != b.shape() {
return Err("Tensors must \
have the same \
shape for inner \
product."
.to_string());
}
let a_flat = a.as_slice().ok_or("Tensor 'a' is not contiguous")?;
let b_flat = b.as_slice().ok_or("Tensor 'b' is not contiguous")?;
Ok(a_flat.iter().zip(b_flat.iter()).map(|(x, y)| x * y).sum())
}
pub fn contract(
a: &ArrayD<f64>,
axis1: usize,
axis2: usize,
) -> Result<ArrayD<f64>, String> {
if axis1 == axis2 {
return Err("Axes must be \
different for \
contraction."
.to_string());
}
if a.shape()[axis1] != a.shape()[axis2] {
return Err("Dimensions \
along contraction \
axes must be \
equal."
.to_string());
}
let n = a.shape()[axis1];
#[warn(clippy::collection_is_never_read)]
let mut new_shape = Vec::new();
for i in 0..a.ndim() {
if i != axis1 && i != axis2 {
new_shape.push(a.shape()[i]);
}
}
if a.ndim() == 2 {
let mut sum = 0.0;
for i in 0..n {
sum += a[[i, i]];
}
return Ok(ndarray::Array0::from_elem((), sum).into_dyn());
}
Err("General tensor contraction \
(trace) for rank > 2 not yet \
implemented."
.to_string())
}
#[must_use]
pub fn norm(a: &ArrayD<f64>) -> f64 {
a.iter().map(|x| x * x).sum::<f64>().sqrt()
}
use serde::Deserialize;
use serde::Serialize;
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TensorData {
pub shape: Vec<usize>,
pub data: Vec<f64>,
}
impl From<&ArrayD<f64>> for TensorData {
fn from(arr: &ArrayD<f64>) -> Self {
Self {
shape: arr.shape().to_vec(),
data: arr.clone().into_raw_vec_and_offset().0,
}
}
}
impl TensorData {
pub fn to_arrayd(&self) -> Result<ArrayD<f64>, String> {
ArrayD::from_shape_vec(IxDyn(&self.shape), self.data.clone()).map_err(|e| e.to_string())
}
}
#[cfg(test)]
mod tests {
use ndarray::array;
use super::*;
#[test]
fn test_tensordot() {
let a = array![[1.0, 2.0], [3.0, 4.0]].into_dyn();
let b = array![[5.0, 6.0], [7.0, 8.0]].into_dyn();
let res = tensordot(&a, &b, &[1], &[0]).unwrap();
assert_eq!(res.shape(), &[2, 2]);
assert_eq!(res[[0, 0]], 1.0 * 5.0 + 2.0 * 7.0);
}
#[test]
fn test_outer_product() {
let a = array![1.0, 2.0].into_dyn();
let b = array![3.0, 4.0].into_dyn();
let res = outer_product(&a, &b).unwrap();
assert_eq!(res.shape(), &[2, 2]);
assert_eq!(res[[0, 0]], 3.0);
assert_eq!(res[[1, 1]], 8.0);
}
}