use super::utils::flat_to_multi_index;
use crate::tensor::TensorStorage;
use crate::{Result, Tensor, TensorError};
use scirs2_core::numeric::{One, Zero};
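/// Executes a pairwise contraction `path` over `operands`, contracting two
/// intermediate tensors per step until a single result remains.
///
/// All operands are first cloned into a working list. Each `(left, right)`
/// entry indexes into that list; the two tensors are contracted, removed, and
/// the result is appended to the end. The path is invalid if any index is out
/// of range, if both indices are equal, or if more than one tensor remains
/// after all steps.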
pub fn execute_contraction_path<T>(
operands: &[&Tensor<T>],
path: &[(usize, usize)],
) -> Result<Tensor<T>>
where
T: Clone
+ Default
+ Zero
+ One
+ std::ops::Add<Output = T>
+ std::ops::Mul<Output = T>
+ Send
+ Sync
+ 'static
+ bytemuck::Pod
+ bytemuck::Zeroable,
{
let mut intermediate_tensors = operands.iter().map(|&t| t.clone()).collect::<Vec<_>>();
for &(left_idx, right_idx) in path {
        // Each step must reference two distinct, still-existing intermediates.
        if left_idx >= intermediate_tensors.len()
            || right_idx >= intermediate_tensors.len()
            || left_idx == right_idx
        {
            return Err(TensorError::invalid_argument(
                "Invalid contraction path indices".to_string(),
            ));
        }
        let left = &intermediate_tensors[left_idx];
        let right = &intermediate_tensors[right_idx];
        let contracted = cache_optimized_contraction(left, right)?;
        // Remove the two consumed operands (higher index first so the lower
        // one stays valid); the contracted result is pushed back below.
        intermediate_tensors.remove(left_idx.max(right_idx));
        intermediate_tensors.remove(left_idx.min(right_idx));
intermediate_tensors.push(contracted);
}
if intermediate_tensors.len() != 1 {
return Err(TensorError::invalid_argument(
"Invalid contraction path: should result in single tensor".to_string(),
));
}
Ok(intermediate_tensors
.into_iter()
.next()
.expect("intermediate_tensors guaranteed to have exactly 1 element"))
}
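/// Dispatches a pairwise contraction based on storage: CPU/CPU pairs use the
/// cache-friendly CPU kernel, while any other combination (only possible when
/// the `gpu` feature is enabled) falls back to the tensor's element-wise `mul`.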
pub fn cache_optimized_contraction<T>(left: &Tensor<T>, right: &Tensor<T>) -> Result<Tensor<T>>
where
T: Clone
+ Default
+ Zero
+ One
+ std::ops::Add<Output = T>
+ std::ops::Mul<Output = T>
+ Send
+ Sync
+ 'static
+ bytemuck::Pod
+ bytemuck::Zeroable,
{
match (&left.storage, &right.storage) {
(TensorStorage::Cpu(_), TensorStorage::Cpu(_)) => {
cache_friendly_cpu_contraction(left, right)
}
#[cfg(feature = "gpu")]
_ => {
left.mul(right)
}
}
}
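/// CPU contraction kernel: identically shaped operands take the blocked
/// element-wise path; anything else defers to the tensor's general `mul`.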
pub(super) fn cache_friendly_cpu_contraction<T>(
left: &Tensor<T>,
right: &Tensor<T>,
) -> Result<Tensor<T>>
where
T: Clone
+ Default
+ Zero
+ One
+ std::ops::Add<Output = T>
+ std::ops::Mul<Output = T>
+ Send
+ Sync
+ 'static
+ bytemuck::Pod
+ bytemuck::Zeroable,
{
let left_shape = left.shape().dims();
let right_shape = right.shape().dims();
if left_shape == right_shape {
return cache_friendly_elementwise_mul(left, right);
}
    left.mul(right)
}
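/// Element-wise product of two same-shaped tensors. CPU tensors with more
/// than 16384 elements go through the blocked kernel; smaller CPU tensors and
/// (with the `gpu` feature) non-CPU tensors fall back to `mul`.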
pub(super) fn cache_friendly_elementwise_mul<T>(
left: &Tensor<T>,
right: &Tensor<T>,
) -> Result<Tensor<T>>
where
T: Clone
+ Default
+ Zero
+ One
+ std::ops::Add<Output = T>
+ std::ops::Mul<Output = T>
+ Send
+ Sync
+ 'static
+ bytemuck::Pod
+ bytemuck::Zeroable,
{
match (&left.storage, &right.storage) {
(TensorStorage::Cpu(left_arr), TensorStorage::Cpu(right_arr)) => {
let shape = left.shape().dims();
if shape.iter().product::<usize>() > 16384 {
cache_friendly_blocked_multiply(left_arr, right_arr, shape)
} else {
left.mul(right)
}
}
#[cfg(feature = "gpu")]
_ => left.mul(right),
}
}
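/// Element-wise multiply over raw `ArrayD` storage in blocks of at most 1024
/// elements, using the parallel variant for more than 65536 elements and the
/// sequential variant otherwise.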
pub(super) fn cache_friendly_blocked_multiply<T>(
left_arr: &scirs2_core::ndarray::ArrayD<T>,
right_arr: &scirs2_core::ndarray::ArrayD<T>,
shape: &[usize],
) -> Result<Tensor<T>>
where
T: Clone
+ Default
+ Zero
+ One
+ std::ops::Add<Output = T>
+ std::ops::Mul<Output = T>
+ Send
+ Sync
+ 'static,
{
let total_elements = shape.iter().product::<usize>();
let block_size = 1024.min(total_elements);
if total_elements > 65536 {
parallel_blocked_multiply(left_arr, right_arr, shape, block_size)
} else {
sequential_blocked_multiply(left_arr, right_arr, shape, block_size)
}
}
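/// Single-threaded blocked multiply: walks the flattened index space in
/// `block_size` chunks, converting each flat index back to a multi-index.
/// Indices that fall outside either array contribute zero.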
pub(super) fn sequential_blocked_multiply<T>(
left_arr: &scirs2_core::ndarray::ArrayD<T>,
right_arr: &scirs2_core::ndarray::ArrayD<T>,
shape: &[usize],
block_size: usize,
) -> Result<Tensor<T>>
where
T: Clone
+ Default
+ Zero
+ One
+ std::ops::Add<Output = T>
+ std::ops::Mul<Output = T>
+ Send
+ Sync
+ 'static,
{
let total_elements = shape.iter().product::<usize>();
let mut result_data = Vec::with_capacity(total_elements);
for block_start in (0..total_elements).step_by(block_size) {
let block_end = (block_start + block_size).min(total_elements);
for flat_idx in block_start..block_end {
let multi_idx = flat_to_multi_index(flat_idx, shape);
let left_val = left_arr
.get(multi_idx.as_slice())
.unwrap_or(&T::zero())
.clone();
let right_val = right_arr
.get(multi_idx.as_slice())
.unwrap_or(&T::zero())
.clone();
result_data.push(left_val * right_val);
}
}
Tensor::from_vec(result_data, shape)
}
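/// Multi-threaded blocked multiply using scoped threads, capped at 8 threads
/// (falling back to the sequential kernel when only one is available). Each
/// thread computes a contiguous range of flat indices; results are written
/// into a preallocated zero-filled buffer after the threads are joined.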
pub(super) fn parallel_blocked_multiply<T>(
left_arr: &scirs2_core::ndarray::ArrayD<T>,
right_arr: &scirs2_core::ndarray::ArrayD<T>,
shape: &[usize],
block_size: usize,
) -> Result<Tensor<T>>
where
T: Clone
+ Default
+ Zero
+ One
+ std::ops::Add<Output = T>
+ std::ops::Mul<Output = T>
+ Send
+ Sync
+ 'static,
{
let total_elements = shape.iter().product::<usize>();
let num_threads = std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1)
.min(8);
if num_threads <= 1 {
return sequential_blocked_multiply(left_arr, right_arr, shape, block_size);
}
    let elements_per_thread = total_elements.div_ceil(num_threads);
let mut result_data = vec![T::zero(); total_elements];
std::thread::scope(|s| {
let mut handles = Vec::new();
for thread_id in 0..num_threads {
let start_idx = thread_id * elements_per_thread;
let end_idx = (start_idx + elements_per_thread).min(total_elements);
if start_idx >= total_elements {
break;
}
let handle = s.spawn(move || {
let mut chunk_results = Vec::new();
for flat_idx in start_idx..end_idx {
let multi_idx = flat_to_multi_index(flat_idx, shape);
let left_val = left_arr
.get(multi_idx.as_slice())
.unwrap_or(&T::zero())
.clone();
let right_val = right_arr
.get(multi_idx.as_slice())
.unwrap_or(&T::zero())
.clone();
chunk_results.push((flat_idx, left_val * right_val));
}
chunk_results
});
handles.push(handle);
}
for handle in handles {
let chunk_results = handle.join().expect("thread join should succeed");
for (idx, value) in chunk_results {
result_data[idx] = value;
}
}
});
Tensor::from_vec(result_data, shape)
}
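// A minimal usage sketch rather than part of the original module: it assumes that
// `f32` satisfies the trait bounds used above, that `Tensor::from_vec` produces
// CPU-backed tensors, and that `TensorError` implements `Debug` (needed for `unwrap`).
#[cfg(test)]
mod contraction_path_tests {
    use super::*;

    #[test]
    fn single_step_path_contracts_two_same_shape_operands() {
        let a = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], &[2usize, 2]).unwrap();
        let b = Tensor::from_vec(vec![5.0f32, 6.0, 7.0, 8.0], &[2usize, 2]).unwrap();
        // One contraction step over operands 0 and 1 must leave exactly one tensor;
        // for same-shaped operands this is the element-wise product.
        let result = execute_contraction_path(&[&a, &b], &[(0, 1)]).unwrap();
        assert_eq!(result.shape().dims(), &[2usize, 2]);
    }

    #[test]
    fn out_of_range_path_index_is_rejected() {
        let a = Tensor::from_vec(vec![1.0f32, 2.0], &[2usize]).unwrap();
        let b = Tensor::from_vec(vec![3.0f32, 4.0], &[2usize]).unwrap();
        // Index 5 does not exist in the intermediate list, so the path is invalid.
        assert!(execute_contraction_path(&[&a, &b], &[(0, 5)]).is_err());
    }
}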