use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::Float;
use std::fmt::Debug;
/// Evaluates an Einstein-summation expression over `operands`.
///
/// `subscripts` must be in explicit form, `"<spec>,<spec>,...-><output>"`, with one
/// comma-separated spec per operand. Whitespace anywhere in `subscripts` is ignored.
///
/// Common patterns are dispatched to dedicated routines:
///
/// | subscripts  | routine                      |
/// |-------------|------------------------------|
/// | `ij,jk->ik` | `matmul`                     |
/// | `i,i->`     | `vdot`                       |
/// | `ii->`      | `trace`                      |
/// | `ij->ji`    | `transpose`                  |
/// | `ii->i`     | `diag`                       |
/// | `i,j->ij`   | `outer`                      |
/// | `X,X->X`    | element-wise product         |
/// | `ab->a`/`ab->b` | `sum_axis` over the other axis |
///
/// Anything else is forwarded to the general (slow) implementation.
///
/// # Errors
///
/// * `InvalidOperation` if `subscripts` does not contain exactly one `->`, or if the
///   number of operand specs differs from `operands.len()`.
/// * `ShapeMismatch` if the element-wise pattern is used with operands of different shapes.
/// * Any error produced by the routine the expression is dispatched to.
pub fn einsum<T: Float + Clone + Debug + std::ops::AddAssign + 'static>(
    subscripts: &str,
    operands: &[&Array<T>],
) -> Result<Array<T>> {
    // Drop whitespace up front so that "ij, jk -> ik" hits the same paths as "ij,jk->ik".
    let normalized: String = subscripts.chars().filter(|ch| !ch.is_whitespace()).collect();

    let parts: Vec<&str> = normalized.split("->").collect();
    if parts.len() != 2 {
        return Err(NumRs2Error::InvalidOperation(
            "einsum subscripts must contain exactly one '->'".to_string(),
        ));
    }
    let (input_spec, output_spec) = (parts[0], parts[1]);

    let operand_specs: Vec<&str> = input_spec.split(',').collect();
    if operand_specs.len() != operands.len() {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Number of operand specs ({}) doesn't match number of operands ({})",
            operand_specs.len(),
            operands.len()
        )));
    }

    // Fast paths keyed on the literal subscripts. The operand count has been validated
    // above, so indexing `operands` by the number of matched specs is in bounds.
    match (operand_specs.as_slice(), output_spec) {
        // Matrix multiplication.
        (["ij", "jk"], "ik") => return operands[0].matmul(operands[1]),
        // Vector dot product, returned as a one-element array.
        (["i", "i"], "") => {
            use crate::linalg::vector_ops::vdot;
            let result = vdot(operands[0], operands[1])?;
            return Ok(Array::from_vec(vec![result]));
        }
        // Trace, returned as a one-element array.
        (["ii"], "") => {
            use crate::linalg::vector_ops::trace;
            let result = trace(operands[0])?;
            return Ok(Array::from_vec(vec![result]));
        }
        // Matrix transpose.
        (["ij"], "ji") => return Ok(operands[0].transpose()),
        // Main diagonal.
        (["ii"], "i") => {
            use crate::array_ops::diagonal::diag;
            return diag(operands[0], None);
        }
        // Outer product.
        (["i", "j"], "ij") => {
            use crate::linalg::vector_ops::outer;
            return outer(operands[0], operands[1]);
        }
        _ => {}
    }

    // Element-wise product: both operands and the output carry identical subscripts.
    if operand_specs.len() == 2
        && operand_specs[0] == operand_specs[1]
        && operand_specs[0] == output_spec
    {
        let lhs_shape = operands[0].shape();
        let rhs_shape = operands[1].shape();
        // Without this check `zip` would silently truncate to the shorter operand and the
        // reshape below would be applied to the wrong number of elements.
        if lhs_shape != rhs_shape {
            return Err(NumRs2Error::ShapeMismatch {
                expected: lhs_shape.to_vec(),
                actual: rhs_shape.to_vec(),
            });
        }
        let a_data = operands[0].to_vec();
        let b_data = operands[1].to_vec();
        let result_data: Vec<T> = a_data
            .iter()
            .zip(b_data.iter())
            .map(|(a, b)| *a * *b)
            .collect();
        return Ok(Array::from_vec(result_data).reshape(&lhs_shape));
    }

    // Reduction of a two-index operand to one of its indices: sum over the other axis.
    if operand_specs.len() == 1 && operand_specs[0].len() == 2 && output_spec.len() == 1 {
        let input_chars: Vec<char> = operand_specs[0].chars().collect();
        if let Some(output_char) = output_spec.chars().next() {
            if input_chars.contains(&output_char) {
                // Keeping the first index means summing over axis 1, and vice versa.
                let sum_axis = if input_chars[0] == output_char { 1 } else { 0 };
                return operands[0].sum_axis(sum_axis);
            }
        }
    }

    einsum_general(&normalized, operands)
}
fn einsum_general<T: Float + Clone + Debug + std::ops::AddAssign>(
subscripts: &str,
operands: &[&Array<T>],
) -> Result<Array<T>> {
let parts: Vec<&str> = subscripts.split("->").collect();
let input_spec = parts[0];
let output_spec = parts[1];
let operand_specs: Vec<&str> = input_spec.split(',').collect();
let mut all_indices = std::collections::HashSet::new();
for spec in &operand_specs {
for ch in spec.chars() {
if ch.is_alphabetic() {
all_indices.insert(ch);
}
}
}
let output_indices: Vec<char> = output_spec.chars().filter(|c| c.is_alphabetic()).collect();
let summation_indices: Vec<char> = all_indices
.iter()
.filter(|&&idx| !output_indices.contains(&idx))
.copied()
.collect();
let mut index_sizes = std::collections::HashMap::new();
for (op_idx, &operand) in operands.iter().enumerate() {
let spec = operand_specs[op_idx];
let shape = operand.shape();
for (dim_idx, idx_char) in spec.chars().enumerate() {
if idx_char.is_alphabetic() {
let size = shape[dim_idx];
if let Some(&existing_size) = index_sizes.get(&idx_char) {
if existing_size != size {
return Err(NumRs2Error::DimensionMismatch(format!(
"Index '{}' has inconsistent sizes: {} and {}",
idx_char, existing_size, size
)));
}
} else {
index_sizes.insert(idx_char, size);
}
}
}
}
let output_shape: Vec<usize> = output_indices
.iter()
.map(|&idx| index_sizes[&idx])
.collect();
let output_shape = if output_shape.is_empty() {
vec![1]
} else {
output_shape
};
let mut result = Array::zeros(&output_shape);
let total_output_size: usize = output_shape.iter().product();
for output_idx in 0..total_output_size {
let mut output_multi_idx = vec![0; output_shape.len()];
let mut temp = output_idx;
for i in (0..output_shape.len()).rev() {
output_multi_idx[i] = temp % output_shape[i];
temp /= output_shape[i];
}
let mut index_values = std::collections::HashMap::new();
for (i, &idx_char) in output_indices.iter().enumerate() {
if !output_shape.is_empty() && output_shape[0] != 1 {
index_values.insert(idx_char, output_multi_idx[i]);
}
}
let mut sum = T::zero();
let summation_ranges: Vec<usize> = summation_indices
.iter()
.map(|&idx| index_sizes[&idx])
.collect();
if summation_ranges.is_empty() {
let mut product = T::one();
for (op_idx, &operand) in operands.iter().enumerate() {
let spec = operand_specs[op_idx];
let op_shape = operand.shape();
let mut op_indices = vec![0; op_shape.len()];
for (dim_idx, idx_char) in spec.chars().enumerate() {
if idx_char.is_alphabetic() {
op_indices[dim_idx] = index_values[&idx_char];
}
}
product = product * operand.get(&op_indices)?;
}
sum += product;
} else {
let total_summation_size: usize = summation_ranges.iter().product();
for sum_idx in 0..total_summation_size {
let mut sum_multi_idx = vec![0; summation_ranges.len()];
let mut temp = sum_idx;
for i in (0..summation_ranges.len()).rev() {
sum_multi_idx[i] = temp % summation_ranges[i];
temp /= summation_ranges[i];
}
for (i, &idx_char) in summation_indices.iter().enumerate() {
index_values.insert(idx_char, sum_multi_idx[i]);
}
let mut product = T::one();
for (op_idx, &operand) in operands.iter().enumerate() {
let spec = operand_specs[op_idx];
let op_shape = operand.shape();
let mut op_indices = vec![0; op_shape.len()];
for (dim_idx, idx_char) in spec.chars().enumerate() {
if idx_char.is_alphabetic() {
op_indices[dim_idx] = index_values[&idx_char];
}
}
product = product * operand.get(&op_indices)?;
}
sum += product;
}
}
if output_shape[0] == 1 && output_shape.len() == 1 {
result.set(&[0], sum)?;
} else {
result.set(&output_multi_idx, sum)?;
}
}
Ok(result)
}
/// Computes the Kronecker product of two 2-D arrays.
///
/// For `a` of shape `(m, n)` and `b` of shape `(p, q)` the result has shape
/// `(m * p, n * q)`; the block at block-row `i`, block-column `j` is `a[i, j] * b`.
///
/// # Errors
///
/// * `DimensionMismatch` if either input is not 2-dimensional.
/// * `ComputationError` if the freshly allocated result cannot be viewed as a
///   contiguous mutable slice.
pub fn kron<T: Float + Clone + Debug>(a: &Array<T>, b: &Array<T>) -> Result<Array<T>> {
    if !(a.ndim() == 2 && b.ndim() == 2) {
        return Err(NumRs2Error::DimensionMismatch(
            "kron requires two 2D arrays".to_string(),
        ));
    }

    let lhs_dims = a.shape();
    let rhs_dims = b.shape();
    let (lhs_rows, lhs_cols) = (lhs_dims[0], lhs_dims[1]);
    let (rhs_rows, rhs_cols) = (rhs_dims[0], rhs_dims[1]);

    let out_dims = [lhs_rows * rhs_rows, lhs_cols * rhs_cols];
    let out_cols = out_dims[1];
    let mut product = Array::zeros(&out_dims);

    // Row-major copies of both inputs.
    let lhs_values = a.to_vec();
    let rhs_values = b.to_vec();

    let out = product.array_mut().as_slice_mut().ok_or_else(|| {
        NumRs2Error::ComputationError("array should have contiguous memory layout".to_string())
    })?;

    for lhs_row in 0..lhs_rows {
        for lhs_col in 0..lhs_cols {
            let scale = lhs_values[lhs_row * lhs_cols + lhs_col];
            for rhs_row in 0..rhs_rows {
                // Start of the destination run inside the (lhs_row, lhs_col) block.
                let out_start = (lhs_row * rhs_rows + rhs_row) * out_cols + lhs_col * rhs_cols;
                let rhs_start = rhs_row * rhs_cols;
                for rhs_col in 0..rhs_cols {
                    out[out_start + rhs_col] = scale * rhs_values[rhs_start + rhs_col];
                }
            }
        }
    }

    Ok(product)
}
/// Contracts one axis of `a` with one axis of `b` (2-D inputs only).
///
/// `axes` must hold exactly two entries: `axes[0]` is the axis of `a` to sum over and
/// `axes[1]` the axis of `b`. The result's first dimension is the remaining axis of `a`
/// and its second dimension is the remaining axis of `b`:
///
/// | `axes`   | result      |
/// |----------|-------------|
/// | `[1, 0]` | `a · b`     |
/// | `[0, 0]` | `aᵀ · b`    |
/// | `[1, 1]` | `a · bᵀ`    |
/// | `[0, 1]` | `aᵀ · bᵀ`   |
///
/// # Errors
///
/// * `InvalidOperation` if `axes.len() != 2`.
/// * `DimensionMismatch` if an axis is out of bounds or either input is not 2-D.
/// * `ShapeMismatch` if the two contracted axes have different lengths.
/// * Any error returned by `matmul`.
pub fn tensordot<T: Float + Clone + Debug>(
    a: &Array<T>,
    b: &Array<T>,
    axes: &[usize],
) -> Result<Array<T>> {
    if axes.len() != 2 {
        return Err(NumRs2Error::InvalidOperation(
            "This implementation of tensordot only supports 2 axes".to_string(),
        ));
    }
    let a_shape = a.shape();
    let b_shape = b.shape();
    let (a_axis, b_axis) = (axes[0], axes[1]);

    if a_axis >= a_shape.len() || b_axis >= b_shape.len() {
        return Err(NumRs2Error::DimensionMismatch(
            "Axis out of bounds".to_string(),
        ));
    }
    if a_shape[a_axis] != b_shape[b_axis] {
        return Err(NumRs2Error::ShapeMismatch {
            expected: vec![a_shape[a_axis]],
            actual: vec![b_shape[b_axis]],
        });
    }
    if a_shape.len() != 2 || b_shape.len() != 2 {
        return Err(NumRs2Error::DimensionMismatch(
            "This implementation of tensordot only supports 2D arrays".to_string(),
        ));
    }

    // Move the contracted axis to the inner position of a matrix product by transposing
    // whichever operand needs it. The product is returned as-is: its rows already follow
    // the free axis of `a` and its columns the free axis of `b`.
    match (a_axis, b_axis) {
        (1, 0) => a.matmul(b),
        (0, 0) => a.transpose().matmul(b),
        (1, 1) => a.matmul(&b.transpose()),
        (0, 1) => a.transpose().matmul(&b.transpose()),
        // Not reachable for 2-D inputs after the bounds check; kept as a safe fallback.
        _ => Err(NumRs2Error::InvalidOperation(
            "This axis combination is not implemented in this version".to_string(),
        )),
    }
}