use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
pub fn einsum_optimized(equation: &str, operands: &[&Tensor]) -> TorshResult<Tensor> {
if operands.is_empty() {
return Err(TorshError::invalid_argument_with_context(
"einsum requires at least one operand",
"einsum_optimized",
));
}
let (inputs, output) = parse_einsum_equation(equation)?;
if inputs.len() != operands.len() {
return Err(TorshError::invalid_argument_with_context(
&format!(
"einsum equation expects {} operands, got {}",
inputs.len(),
operands.len()
),
"einsum_optimized",
));
}
let optimal_path = optimize_contraction_path(&inputs, &output)?;
execute_contraction_path(operands, &optimal_path, &output)
}
fn parse_einsum_equation(equation: &str) -> TorshResult<(Vec<String>, String)> {
let parts: Vec<&str> = equation.split("->").collect();
if parts.len() > 2 {
return Err(TorshError::invalid_argument_with_context(
"einsum equation can have at most one '->' separator",
"parse_einsum_equation",
));
}
let input_str = parts[0];
let inputs: Vec<String> = input_str.split(',').map(|s| s.trim().to_string()).collect();
let output = if parts.len() == 2 {
parts[1].trim().to_string()
} else {
infer_output_indices(&inputs)
};
Ok((inputs, output))
}
fn infer_output_indices(inputs: &[String]) -> String {
use std::collections::HashMap;
let mut index_counts = HashMap::new();
for input in inputs {
for ch in input.chars() {
if ch.is_alphabetic() {
*index_counts.entry(ch).or_insert(0) += 1;
}
}
}
let mut output_chars: Vec<char> = index_counts
.iter()
.filter(|(_, &count)| count == 1)
.map(|(&ch, _)| ch)
.collect();
output_chars.sort_unstable();
output_chars.into_iter().collect()
}
fn optimize_contraction_path(
inputs: &[String],
_output: &str,
) -> TorshResult<Vec<ContractionStep>> {
let mut steps = Vec::new();
let mut remaining = inputs.to_vec();
while remaining.len() > 1 {
let (idx1, idx2) = find_best_contraction_pair(&remaining)?;
let indices1 = &remaining[idx1];
let indices2 = &remaining[idx2];
let result_indices = compute_contraction_result(indices1, indices2);
steps.push(ContractionStep {
_operand1: idx1,
_operand2: idx2,
_result_indices: result_indices.clone(),
});
remaining.remove(idx2.max(idx1));
remaining.remove(idx1.min(idx2));
remaining.push(result_indices);
}
Ok(steps)
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct ContractionStep {
_operand1: usize,
_operand2: usize,
_result_indices: String,
}
fn find_best_contraction_pair(remaining: &[String]) -> TorshResult<(usize, usize)> {
if remaining.len() < 2 {
return Err(TorshError::invalid_argument_with_context(
"need at least 2 tensors to find contraction pair",
"find_best_contraction_pair",
));
}
Ok((0, 1))
}
fn compute_contraction_result(indices1: &str, indices2: &str) -> String {
use std::collections::HashSet;
let set1: HashSet<char> = indices1.chars().collect();
let set2: HashSet<char> = indices2.chars().collect();
let contracted: HashSet<char> = set1.intersection(&set2).copied().collect();
let mut result_chars: Vec<char> = indices1
.chars()
.chain(indices2.chars())
.filter(|&ch| !contracted.contains(&ch))
.collect();
let mut seen = HashSet::new();
result_chars.retain(|&ch| seen.insert(ch));
result_chars.into_iter().collect()
}
fn execute_contraction_path(
operands: &[&Tensor],
_path: &[ContractionStep],
_output: &str,
) -> TorshResult<Tensor> {
if operands.len() == 2 {
let operand_vec: Vec<Tensor> = operands.iter().map(|&t| t.clone()).collect();
return crate::math::einsum("ij,jk->ik", &operand_vec);
}
Err(TorshError::InvalidOperation(
"general einsum contraction path execution not yet implemented (execute_contraction_path)"
.to_string(),
))
}
pub fn tensor_contract(
a: &Tensor,
b: &Tensor,
axes_a: &[usize],
axes_b: &[usize],
) -> TorshResult<Tensor> {
if axes_a.len() != axes_b.len() {
return Err(TorshError::invalid_argument_with_context(
"number of contraction axes must match",
"tensor_contract",
));
}
let a_shape_obj = a.shape();
let shape_a = a_shape_obj.dims();
let b_shape_obj = b.shape();
let shape_b = b_shape_obj.dims();
for &axis in axes_a {
if axis >= shape_a.len() {
return Err(TorshError::invalid_argument_with_context(
&format!(
"axis {} out of range for tensor with {} dimensions",
axis,
shape_a.len()
),
"tensor_contract",
));
}
}
for &axis in axes_b {
if axis >= shape_b.len() {
return Err(TorshError::invalid_argument_with_context(
&format!(
"axis {} out of range for tensor with {} dimensions",
axis,
shape_b.len()
),
"tensor_contract",
));
}
}
for (&axis_a, &axis_b) in axes_a.iter().zip(axes_b.iter()) {
if shape_a[axis_a] != shape_b[axis_b] {
return Err(TorshError::invalid_argument_with_context(
&format!(
"contracted dimensions must match: {} != {}",
shape_a[axis_a], shape_b[axis_b]
),
"tensor_contract",
));
}
}
crate::manipulation::tensordot(
a,
b,
crate::manipulation::TensorDotAxes::Arrays(axes_a.to_vec(), axes_b.to_vec()),
)
}
pub fn tensor_map<F>(input: &Tensor<f32>, f: F) -> TorshResult<Tensor<f32>>
where
F: Fn(f32) -> f32 + Send + Sync,
{
let data = input.data()?;
let shape = input.shape().dims().to_vec();
let device = input.device();
let result_data: Vec<f32> = if data.len() > 10000 {
use scirs2_core::parallel_ops::*;
data.iter()
.copied()
.collect::<Vec<_>>()
.into_par_iter()
.map(f)
.collect()
} else {
data.iter().map(|&x| f(x)).collect()
};
Tensor::from_data(result_data, shape, device)
}
pub fn tensor_reduce<F>(
input: &Tensor<f32>,
axis: Option<usize>,
f: F,
init: f32,
) -> TorshResult<Tensor<f32>>
where
F: Fn(f32, f32) -> f32 + Send + Sync,
{
let input_shape = input.shape();
let shape = input_shape.dims();
if let Some(ax) = axis {
if ax >= shape.len() {
return Err(TorshError::invalid_argument_with_context(
&format!(
"axis {} out of range for tensor with {} dimensions",
ax,
shape.len()
),
"tensor_reduce",
));
}
let data = input.data()?;
let mut output_shape = shape.to_vec();
output_shape.remove(ax);
if output_shape.is_empty() {
let result = data.iter().fold(init, |acc, &x| f(acc, x));
return Tensor::from_data(vec![result], vec![1], input.device());
}
let mut strides = vec![1; shape.len()];
for i in (0..shape.len() - 1).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
}
let output_size: usize = output_shape.iter().product();
let axis_size = shape[ax];
let mut result_data = vec![init; output_size];
for (out_idx, result_val) in result_data.iter_mut().enumerate() {
for axis_idx in 0..axis_size {
let mut in_idx = 0;
let mut remaining = out_idx;
let mut out_dim_idx = 0;
for dim_idx in 0..shape.len() {
if dim_idx == ax {
in_idx += axis_idx * strides[dim_idx];
} else {
let size = output_shape[out_dim_idx];
let coord = remaining % size;
remaining /= size;
in_idx += coord * strides[dim_idx];
out_dim_idx += 1;
}
}
if in_idx < data.len() {
*result_val = f(*result_val, data[in_idx]);
}
}
}
Tensor::from_data(result_data, output_shape, input.device())
} else {
let data = input.data()?;
let result = data.iter().fold(init, |acc, &x| f(acc, x));
Tensor::from_data(vec![result], vec![1], input.device())
}
}
pub fn tensor_scan<F>(input: &Tensor<f32>, axis: usize, f: F, init: f32) -> TorshResult<Tensor<f32>>
where
F: Fn(f32, f32) -> f32,
{
let input_shape = input.shape();
let shape = input_shape.dims();
if axis >= shape.len() {
return Err(TorshError::invalid_argument_with_context(
&format!(
"axis {} out of range for tensor with {} dimensions",
axis,
shape.len()
),
"tensor_scan",
));
}
let data = input.data()?;
let mut result_data = data.to_vec();
let mut strides = vec![1; shape.len()];
for i in (0..shape.len() - 1).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
}
let axis_size = shape[axis];
let axis_stride = strides[axis];
let other_size: usize = shape
.iter()
.enumerate()
.filter(|(i, _)| *i != axis)
.map(|(_, &s)| s)
.product();
for other_idx in 0..other_size {
let mut base_idx = 0;
let mut remaining = other_idx;
for (dim_idx, &size) in shape.iter().enumerate() {
if dim_idx != axis {
let coord = remaining % size;
remaining /= size;
base_idx += coord * strides[dim_idx];
}
}
let mut acc = init;
for axis_idx in 0..axis_size {
let idx = base_idx + axis_idx * axis_stride;
if idx < result_data.len() {
acc = f(acc, result_data[idx]);
result_data[idx] = acc;
}
}
}
Tensor::from_data(result_data, shape.to_vec(), input.device())
}
pub fn tensor_fold<F>(input: &Tensor<f32>, f: F, init: f32) -> TorshResult<f32>
where
F: Fn(f32, f32) -> f32,
{
let data = input.data()?;
Ok(data.iter().fold(init, |acc, &x| f(acc, x)))
}
pub fn tensor_outer(a: &Tensor<f32>, b: &Tensor<f32>) -> TorshResult<Tensor<f32>> {
let a_shape_obj = a.shape();
let shape_a = a_shape_obj.dims();
let b_shape_obj = b.shape();
let shape_b = b_shape_obj.dims();
let mut new_shape_a = shape_a.to_vec();
new_shape_a.extend(vec![1; shape_b.len()]);
let mut new_shape_b = vec![1; shape_a.len()];
new_shape_b.extend(shape_b);
let a_reshaped = a.view(&new_shape_a.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
let b_reshaped = b.view(&new_shape_b.iter().map(|&x| x as i32).collect::<Vec<_>>())?;
a_reshaped.mul(&b_reshaped)
}
pub fn tensor_zip<F>(a: &Tensor<f32>, b: &Tensor<f32>, f: F) -> TorshResult<Tensor<f32>>
where
F: Fn(f32, f32) -> f32 + Send + Sync,
{
if a.shape().dims() != b.shape().dims() {
return Err(TorshError::invalid_argument_with_context(
&format!(
"tensor shapes must match for zip: {:?} vs {:?}",
a.shape().dims(),
b.shape().dims()
),
"tensor_zip",
));
}
let data_a = a.data()?;
let data_b = b.data()?;
let shape = a.shape().dims().to_vec();
let device = a.device();
let result_data: Vec<f32> = if data_a.len() > 10000 {
use scirs2_core::parallel_ops::*;
let pairs: Vec<(f32, f32)> = data_a.iter().copied().zip(data_b.iter().copied()).collect();
pairs.into_par_iter().map(|(x, y)| f(x, y)).collect()
} else {
data_a
.iter()
.zip(data_b.iter())
.map(|(&x, &y)| f(x, y))
.collect()
};
Tensor::from_data(result_data, shape, device)
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
#[test]
fn test_tensor_map() {
let input = Tensor::from_data(
vec![1.0, 2.0, 3.0, 4.0],
vec![2, 2],
torsh_core::device::DeviceType::Cpu,
)
.expect("failed to create tensor");
let output = tensor_map(&input, |x| x * 2.0).expect("map failed");
let output_data = output.data().expect("failed to get data");
assert_relative_eq!(output_data[0], 2.0, epsilon = 1e-6);
assert_relative_eq!(output_data[1], 4.0, epsilon = 1e-6);
assert_relative_eq!(output_data[2], 6.0, epsilon = 1e-6);
assert_relative_eq!(output_data[3], 8.0, epsilon = 1e-6);
}
#[test]
fn test_tensor_reduce() {
let input = Tensor::from_data(
vec![1.0, 2.0, 3.0, 4.0],
vec![4],
torsh_core::device::DeviceType::Cpu,
)
.expect("failed to create tensor");
let output = tensor_reduce(&input, None, |a, b| a + b, 0.0).expect("reduce failed");
let output_data = output.data().expect("failed to get data");
assert_relative_eq!(output_data[0], 10.0, epsilon = 1e-6);
}
#[test]
fn test_tensor_fold() {
let input = Tensor::from_data(
vec![1.0, 2.0, 3.0, 4.0],
vec![4],
torsh_core::device::DeviceType::Cpu,
)
.expect("failed to create tensor");
let result = tensor_fold(&input, |acc, x| acc + x, 0.0).expect("fold failed");
assert_relative_eq!(result, 10.0, epsilon = 1e-6);
}
#[test]
fn test_tensor_scan() {
let input = Tensor::from_data(
vec![1.0, 2.0, 3.0, 4.0],
vec![4],
torsh_core::device::DeviceType::Cpu,
)
.expect("failed to create tensor");
let output = tensor_scan(&input, 0, |a, b| a + b, 0.0).expect("scan failed");
let output_data = output.data().expect("failed to get data");
assert_relative_eq!(output_data[0], 1.0, epsilon = 1e-6);
assert_relative_eq!(output_data[1], 3.0, epsilon = 1e-6);
assert_relative_eq!(output_data[2], 6.0, epsilon = 1e-6);
assert_relative_eq!(output_data[3], 10.0, epsilon = 1e-6);
}
#[test]
fn test_tensor_outer() {
let a = Tensor::from_data(
vec![1.0, 2.0, 3.0],
vec![3],
torsh_core::device::DeviceType::Cpu,
)
.expect("failed to create tensor");
let b = Tensor::from_data(vec![4.0, 5.0], vec![2], torsh_core::device::DeviceType::Cpu)
.expect("failed to create tensor");
let c = tensor_outer(&a, &b).expect("outer product failed");
assert_eq!(c.shape().dims(), &[3, 2]);
let c_data = c.data().expect("failed to get data");
assert_relative_eq!(c_data[0], 4.0, epsilon = 1e-6); assert_relative_eq!(c_data[1], 5.0, epsilon = 1e-6); assert_relative_eq!(c_data[2], 8.0, epsilon = 1e-6); assert_relative_eq!(c_data[3], 10.0, epsilon = 1e-6); }
#[test]
fn test_tensor_zip() {
let a = Tensor::from_data(
vec![1.0, 2.0, 3.0, 4.0],
vec![4],
torsh_core::device::DeviceType::Cpu,
)
.expect("failed to create tensor");
let b = Tensor::from_data(
vec![5.0, 6.0, 7.0, 8.0],
vec![4],
torsh_core::device::DeviceType::Cpu,
)
.expect("failed to create tensor");
let c = tensor_zip(&a, &b, |x, y| x + y).expect("zip failed");
let c_data = c.data().expect("failed to get data");
assert_relative_eq!(c_data[0], 6.0, epsilon = 1e-6);
assert_relative_eq!(c_data[1], 8.0, epsilon = 1e-6);
assert_relative_eq!(c_data[2], 10.0, epsilon = 1e-6);
assert_relative_eq!(c_data[3], 12.0, epsilon = 1e-6);
}
#[test]
fn test_parse_einsum_equation() {
let (inputs, output) = parse_einsum_equation("ij,jk->ik").expect("parse failed");
assert_eq!(inputs, vec!["ij", "jk"]);
assert_eq!(output, "ik");
let (inputs, output) = parse_einsum_equation("ii->").expect("parse failed");
assert_eq!(inputs, vec!["ii"]);
assert_eq!(output, "");
}
#[test]
fn test_tensor_reduce_axis() {
let input = Tensor::from_data(
vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
vec![2, 3],
torsh_core::device::DeviceType::Cpu,
)
.expect("failed to create tensor");
let output = tensor_reduce(&input, Some(0), |a, b| a + b, 0.0).expect("reduce failed");
assert_eq!(output.shape().dims(), &[3]);
let output_data = output.data().expect("failed to get data");
assert_relative_eq!(output_data[0], 5.0, epsilon = 1e-6); assert_relative_eq!(output_data[1], 7.0, epsilon = 1e-6); assert_relative_eq!(output_data[2], 9.0, epsilon = 1e-6); }
}