use crate::{Result, Tensor, TensorError};
use scirs2_core::numeric::{One, Zero};
use std::collections::{BTreeSet, HashSet};
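/// Greedily chooses a pairwise contraction order for an einsum expression.
///
/// Each step contracts the pair with the lowest heuristic cost (see
/// [`estimate_contraction_cost`]) and feeds the intermediate back into the
/// pool under index `n + step`. For example, `["ij", "jk", "kl"]` yields
/// `[(0, 1), (2, 3)]`: operands 0 and 1 first, then operand 2 against the
/// intermediate labelled 3.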
pub(super) fn compute_optimal_path(
input_subscripts: &[String],
_output_subscript: &str,
) -> Result<Vec<(usize, usize)>> {
let n = input_subscripts.len();
if n <= 2 {
let mut path = Vec::new();
for i in 1..n {
path.push((0, i));
}
return Ok(path);
}
    let mut remaining_indices: Vec<usize> = (0..n).collect();
    // Subscripts grow as intermediates are produced: index `n + k` refers to
    // the intermediate created by contraction step `k`. Without this, the
    // cost lookup below would index past `input_subscripts` as soon as the
    // first intermediate enters the pool.
    let mut subscripts: Vec<String> = input_subscripts.to_vec();
    let mut path = Vec::new();
    while remaining_indices.len() > 1 {
        let mut best_pair = (0, 1);
        let mut best_score = f64::INFINITY;
        for i in 0..remaining_indices.len() {
            for j in (i + 1)..remaining_indices.len() {
                let idx1 = remaining_indices[i];
                let idx2 = remaining_indices[j];
                let cost = estimate_contraction_cost(&subscripts[idx1], &subscripts[idx2]);
                if cost < best_score {
                    best_score = cost;
                    best_pair = (i, j);
                }
            }
        }
        let left_idx = remaining_indices[best_pair.0];
        let right_idx = remaining_indices[best_pair.1];
        path.push((left_idx, right_idx));
        // Record the intermediate's subscript so later cost estimates see it.
        subscripts.push(get_intermediate_subscript(
            &subscripts[left_idx],
            &subscripts[right_idx],
        )?);
        // Remove the higher position first so the lower one stays valid.
        remaining_indices.remove(best_pair.1);
        remaining_indices.remove(best_pair.0);
        remaining_indices.push(n + path.len() - 1);
    }
Ok(path)
}
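/// Heuristic cost of contracting two operands, based on subscripts alone.
///
/// The score is `10 * total_indices - 5 * eliminated_indices`, so pairs
/// sharing more indices score lower. E.g. `("ij", "jk")` gives
/// `4 * 10 - 1 * 5 = 35`, beating the disjoint `("ij", "kl")` at `40`.
/// Dimension sizes are not consulted, so this is a rough proxy for FLOPs.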
pub(super) fn estimate_contraction_cost(sub1: &str, sub2: &str) -> f64 {
let total_indices = sub1.len() + sub2.len();
let unique_indices = sub1
.chars()
.chain(sub2.chars())
.collect::<HashSet<_>>()
.len();
let eliminated_indices = total_indices - unique_indices;
total_indices as f64 * 10.0 - eliminated_indices as f64 * 5.0
}
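/// Converts a flat row-major index into a multi-dimensional index.
///
/// E.g. `flat_to_multi_index(17, &[2, 3, 4])` returns `[1, 1, 1]`, since
/// `17 = 1 * 12 + 1 * 4 + 1`.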
pub(super) fn flat_to_multi_index(flat_idx: usize, shape: &[usize]) -> Vec<usize> {
let mut multi_idx = Vec::with_capacity(shape.len());
let mut remaining = flat_idx;
for &dim_size in shape.iter().rev() {
multi_idx.push(remaining % dim_size);
remaining /= dim_size;
}
multi_idx.reverse();
multi_idx
}
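/// Returns axes ordered from last to first, so the row-major contiguous
/// axis comes first. Currently a placeholder heuristic: only the maximum
/// rank of `shapes` is consulted, not the dimension sizes.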
#[allow(dead_code)]
pub(super) fn optimal_loop_ordering(shapes: &[&[usize]]) -> Vec<usize> {
let max_dims = shapes.iter().map(|s| s.len()).max().unwrap_or(0);
    (0..max_dims).rev().collect()
}
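/// Computes the trace of a square 2D tensor as a scalar tensor.
///
/// For a contiguous row-major matrix the diagonal elements sit `n + 1`
/// apart, so a single strided pass over the flat buffer touches them in
/// address order; other layouts fall back to [`extract_trace`].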
pub fn cache_friendly_trace<T>(tensor: &Tensor<T>) -> Result<Tensor<T>>
where
T: Clone + Default + Zero + std::ops::Add<Output = T> + Send + Sync + 'static,
{
use crate::tensor::TensorStorage;
match &tensor.storage {
TensorStorage::Cpu(arr) => {
let shape = tensor.shape().dims();
if shape.len() != 2 || shape[0] != shape[1] {
return Err(TensorError::invalid_argument(
"Trace requires square 2D matrix".to_string(),
));
}
let n = shape[0];
let mut trace = T::zero();
            let stride = n + 1;
            if let Some(slice) = arr.as_slice() {
                // Contiguous fast path: diagonal elements sit `stride` apart.
                for i in 0..n {
                    trace = trace + slice[i * stride].clone();
                }
            } else {
                // Non-contiguous layout: fall back to per-element access
                // instead of silently returning zero.
                return extract_trace(tensor);
            }
Tensor::from_vec(vec![trace], &[])
}
#[cfg(feature = "gpu")]
_ => {
extract_trace(tensor)
}
}
}
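/// Trace via per-element indexing. Works for any storage backing
/// `Tensor::get`, at the cost of a bounds check per element; used as the
/// fallback when the CPU buffer is not contiguous.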
pub(super) fn extract_trace<T>(tensor: &Tensor<T>) -> Result<Tensor<T>>
where
T: Clone + Default + Zero + std::ops::Add<Output = T> + Send + Sync + 'static,
{
let shape = tensor.shape().dims();
if shape.len() != 2 || shape[0] != shape[1] {
return Err(TensorError::invalid_argument(
"Trace requires square 2D matrix".to_string(),
));
}
let n = shape[0];
let mut trace = T::zero();
for i in 0..n {
if let Some(val) = tensor.get(&[i, i]) {
trace = trace + val;
}
}
Tensor::from_vec(vec![trace], &[])
}
pub fn compute_outer_product<T>(a: &Tensor<T>, b: &Tensor<T>) -> Result<Tensor<T>>
where
T: Clone
+ Default
+ Zero
+ One
+ std::ops::Mul<Output = T>
+ Send
+ Sync
+ 'static
+ bytemuck::Pod
+ bytemuck::Zeroable,
{
let a_shape = a.shape().dims();
let b_shape = b.shape().dims();
if a_shape.len() != 1 || b_shape.len() != 1 {
return Err(TensorError::invalid_argument(
"Outer product requires 1D tensors".to_string(),
));
}
let m = a_shape[0];
let n = b_shape[0];
let a_vec = a.to_vec().map_err(|e| {
TensorError::invalid_argument(format!("Failed to access tensor A data: {e}"))
})?;
let b_vec = b.to_vec().map_err(|e| {
TensorError::invalid_argument(format!("Failed to access tensor B data: {e}"))
})?;
let mut result_data = Vec::with_capacity(m * n);
for a_val in a_vec.iter().take(m) {
for b_val in b_vec.iter().take(n) {
result_data.push(*a_val * *b_val);
}
}
Tensor::from_vec(result_data, &[m, n])
}
pub fn batch_transpose<T>(tensor: &Tensor<T>) -> Result<Tensor<T>>
where
T: Clone + Default + Zero + Send + Sync + 'static,
{
let shape = tensor.shape().dims();
if shape.len() != 3 {
return Err(TensorError::invalid_argument(
"Batch transpose requires 3D tensor".to_string(),
));
}
    let permutation = vec![0, 2, 1];
    crate::ops::manipulation::transpose_axes(tensor, Some(&permutation))
}
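/// Labels the intermediate produced by contracting `left` with `right` as
/// the sorted union of their indices, e.g. `"ij"` and `"jk"` give `"ijk"`.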
pub(super) fn get_intermediate_subscript(left: &str, right: &str) -> Result<String> {
    // Sorted, de-duplicated union of both operands' indices. This keeps
    // every index (no contraction-aware elimination); it only labels
    // intermediates for the cost heuristic.
    let chars: BTreeSet<char> = left.chars().chain(right.chars()).collect();
    Ok(chars.into_iter().collect())
}
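/// Assigns the letters `a`, `b`, `c`, ... to a tensor's axes in order,
/// e.g. rank 3 gives `"abc"`. Fails for ranks above 26.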
#[allow(dead_code)]
pub(super) fn get_result_subscript<T>(tensor: &Tensor<T>) -> Result<String> {
    let rank = tensor.shape().rank();
    // Only 26 lowercase ASCII letters are available as axis labels;
    // `b'a' + i` would otherwise wrap into non-letter characters.
    if rank > 26 {
        return Err(TensorError::invalid_argument(format!(
            "Cannot generate subscripts for rank {rank}: at most 26 axes supported"
        )));
    }
    let chars: String = (0..rank).map(|i| char::from(b'a' + (i as u8))).collect();
Ok(chars)
}
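
// A minimal test sketch for the pure helpers above; the expected values
// follow directly from the definitions in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn flat_index_decomposes_row_major() {
        assert_eq!(flat_to_multi_index(17, &[2, 3, 4]), vec![1, 1, 1]);
        assert_eq!(flat_to_multi_index(0, &[2, 3, 4]), vec![0, 0, 0]);
    }

    #[test]
    fn contraction_cost_prefers_shared_indices() {
        // "ij" x "jk" eliminates one index, so it scores below the disjoint pair.
        assert!(estimate_contraction_cost("ij", "jk") < estimate_contraction_cost("ij", "kl"));
    }

    #[test]
    fn intermediate_subscript_is_sorted_union() {
        assert_eq!(get_intermediate_subscript("ij", "jk").unwrap(), "ijk");
    }

    #[test]
    fn greedy_path_contracts_shared_pair_first() {
        let subs = vec!["ij".to_string(), "jk".to_string(), "kl".to_string()];
        let path = compute_optimal_path(&subs, "il").unwrap();
        // Index 3 denotes the intermediate produced by the first contraction.
        assert_eq!(path, vec![(0, 1), (2, 3)]);
    }
}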