use crate::{Tensor, TensorElement};
use torsh_core::error::Result;
impl<T: TensorElement> Tensor<T> {
    /// Applies `op` element-wise over two tensors of identical shape and
    /// returns the result as a new tensor.
    ///
    /// No broadcasting is performed; use
    /// [`Self::broadcast_binary_op_inplace`] for shape-compatible operands.
    ///
    /// # Errors
    /// Returns `TorshError::ShapeMismatch` when the two shapes differ.
    pub fn element_wise_op<F>(&self, other: &Self, op: F) -> Result<Self>
    where
        F: Fn(T, T) -> T + Send + Sync,
        T: Send + Sync,
    {
        if self.shape() != other.shape() {
            return Err(torsh_core::error::TorshError::ShapeMismatch {
                expected: self.shape().to_vec(),
                got: other.shape().to_vec(),
            });
        }
        let self_data = self.data();
        let other_data = other.data();
        // NOTE(review): the `|(&a, &b)|` patterns assume `T: Copy` (implied by
        // `TensorElement` — confirm in its declaration).
        let result_data: Vec<T> = self_data
            .iter()
            .zip(other_data.iter())
            .map(|(&a, &b)| op(a, b))
            .collect();
        Self::from_vec_and_shape(result_data, self.shape().to_vec())
    }

    /// In-place variant of [`Self::element_wise_op`]: overwrites each element
    /// of `self` with `op(self_elem, other_elem)`.
    ///
    /// # Errors
    /// Returns `TorshError::ShapeMismatch` when the two shapes differ.
    pub fn element_wise_op_inplace<F>(&mut self, other: &Self, op: F) -> Result<()>
    where
        F: Fn(T, T) -> T + Send + Sync,
        T: Send + Sync,
    {
        if self.shape() != other.shape() {
            return Err(torsh_core::error::TorshError::ShapeMismatch {
                expected: self.shape().to_vec(),
                got: other.shape().to_vec(),
            });
        }
        // Borrow `other` first (immutable) so the mutable borrow of `self`
        // below does not conflict with it.
        let other_data = other.data();
        let self_data = self.data_mut();
        for (a, &b) in self_data.iter_mut().zip(other_data.iter()) {
            *a = op(*a, b);
        }
        Ok(())
    }

    /// Applies `op` in place, broadcasting `other` against `self`.
    ///
    /// Because the result is written into `self`, `self`'s shape must already
    /// be the broadcast result shape: every dimension of `other`, aligned
    /// from the trailing dimension, must be 1 or equal to `self`'s dimension,
    /// and `other` must not have more dimensions than `self`.
    ///
    /// # Errors
    /// Returns `TorshError::BroadcastError` when `other` cannot be broadcast
    /// into `self`'s shape. (Previously some of these cases indexed out of
    /// bounds instead of erroring.)
    pub fn broadcast_binary_op_inplace<F>(&mut self, other: &Self, op: F) -> Result<()>
    where
        F: Fn(T, T) -> T + Send + Sync + Clone,
        T: Send + Sync,
    {
        // Fast path: identical shapes need no index translation.
        if self.shape() == other.shape() {
            return self.element_wise_op_inplace(other, op);
        }
        let self_shape = self.shape().to_vec();
        let other_shape = other.shape().to_vec();
        // Strict in-place compatibility check. This subsumes the general
        // `can_broadcast_with` test: any pair passing here is also
        // broadcast-compatible, but not vice versa (e.g. self dim 1 vs
        // other dim 3 would need to grow `self`, which in-place cannot do).
        let incompatible = other_shape.len() > self_shape.len()
            || other_shape
                .iter()
                .rev()
                .zip(self_shape.iter().rev())
                .any(|(&o, &s)| o != 1 && o != s);
        if incompatible {
            return Err(torsh_core::error::TorshError::BroadcastError {
                shape1: self_shape,
                shape2: other_shape,
            });
        }
        let other_strides = other.strides();
        let numel = self.numel();
        // Precompute the source index in `other` for every destination
        // element of `self` BEFORE taking the mutable data borrow: the index
        // helpers borrow `self` immutably and must not overlap `data_mut()`.
        let src_indices: Vec<usize> = (0..numel)
            .map(|i| {
                let self_idx = self.linear_to_multi_index(i);
                let other_idx = self.broadcast_indices(&self_idx, &other_shape);
                Self::multi_to_linear_index(&other_idx, other_strides)
            })
            .collect();
        let other_data = other.data();
        let self_data = self.data_mut();
        for (i, &src) in src_indices.iter().enumerate() {
            self_data[i] = op(self_data[i], other_data[src]);
        }
        Ok(())
    }

    /// Reports whether the two tensors are broadcast-compatible under
    /// NumPy-style rules: aligning shapes from the trailing dimension, every
    /// pair of dimensions must be equal or contain a 1; missing leading
    /// dimensions are treated as 1.
    pub fn can_broadcast_with(&self, other: &Self) -> bool {
        let a = self.shape();
        let b = other.shape();
        let dims = a.len().max(b.len());
        (0..dims).all(|i| {
            // Index `i` counts from the trailing (rightmost) dimension.
            let da = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
            let db = if i < b.len() { b[b.len() - 1 - i] } else { 1 };
            da == db || da == 1 || db == 1
        })
    }

    /// Translates `index` (a multi-index into `self`) into a multi-index
    /// into a tensor of shape `target_shape`, following broadcast rules:
    /// dimensions are aligned from the right, a target dimension of size 1
    /// always maps to coordinate 0, and surplus leading dimensions of `self`
    /// are ignored.
    fn broadcast_indices(&self, index: &[usize], target_shape: &[usize]) -> Vec<usize> {
        let self_shape = self.shape();
        let self_len = self_shape.len();
        let target_len = target_shape.len();
        let mut result = vec![0; target_len];
        for i in 0..target_len {
            // `offset` is the distance of target dim `i` from the right edge;
            // the matching self dim sits at the same offset from self's right
            // edge. Guarding on `offset < self_len` (not `i < self_len`)
            // prevents the usize underflow the old code hit when the target
            // had more dimensions than self.
            let offset = target_len - 1 - i;
            if offset >= self_len {
                continue; // target dim has no counterpart in self; stays 0
            }
            let self_dim_idx = self_len - 1 - offset;
            // A size-1 target dimension is "stretched": every source
            // coordinate maps to 0. (Keying this off the TARGET dim fixes the
            // old out-of-range mapping when self's dim was the larger one.)
            result[i] = if target_shape[i] == 1 {
                0
            } else {
                index[self_dim_idx]
            };
        }
        result
    }

    /// Converts a row-major linear index into a per-dimension multi-index
    /// for `self`'s shape.
    ///
    /// Uses trailing-first mod/div, which is O(ndim) and — unlike the old
    /// divide-by-stride loop run in reverse — actually correct: the previous
    /// version divided by stride 1 first, leaving `remaining` zero for every
    /// other dimension.
    fn linear_to_multi_index(&self, linear_idx: usize) -> Vec<usize> {
        let shape = self.shape();
        let mut index = vec![0; shape.len()];
        let mut remaining = linear_idx;
        for i in (0..shape.len()).rev() {
            index[i] = remaining % shape[i];
            remaining /= shape[i];
        }
        index
    }

    /// Collapses a multi-index into a linear offset via the dot product of
    /// coordinates and strides (strides are assumed to be in elements).
    fn multi_to_linear_index(index: &[usize], strides: &[usize]) -> usize {
        index.iter().zip(strides.iter()).map(|(&i, &s)| i * s).sum()
    }
}