use num_traits::Zero;
use std::ops::{Add, Mul, MulAssign};
use crate::{DenseLayout, DenseTensorData, TensorData, TensorError};
use ariadnetor_core::MemoryOrder;
impl<T> DenseTensorData<T>
where
    T: Clone,
{
    /// Returns a tensor holding the same elements, indexed under `new_shape`.
    ///
    /// The storage is cloned unchanged and paired with a new [`DenseLayout`] that keeps
    /// this tensor's memory order. Elements keep their storage positions; only the shape
    /// used to index them changes.
    ///
    /// # Panics
    ///
    /// Panics if the product of `new_shape` differs from `self.len()`.
    pub fn reshape(&self, new_shape: Vec<usize>) -> Self {
        let new_total: usize = new_shape.iter().product();
        assert_eq!(
            self.len(),
            new_total,
            "reshape: total elements must match ({} vs {new_total})",
            self.len()
        );
        let storage = self.storage().clone();
        let layout = DenseLayout::new(new_shape, self.order());
        TensorData::new(storage, layout)
    }

    /// Applies `f` to every element and returns a new tensor of the results.
    ///
    /// Elements are visited in storage order. The result keeps this tensor's shape and
    /// memory order.
    pub fn map<U, F>(&self, f: F) -> DenseTensorData<U>
    where
        F: Fn(&T) -> U,
        U: Clone + 'static,
    {
        // Bound the walk by the logical element count, as `map_with_index` does. This
        // keeps the output length consistent with `shape` regardless of how large the
        // backing buffer is.
        let result: Vec<U> = self
            .storage()
            .data()
            .iter()
            .take(self.len())
            .map(f)
            .collect();
        DenseTensorData::<U>::from_raw_parts(result, self.shape().to_vec(), self.order())
    }

    /// Like [`Self::map`], but `f` also receives each element's multi-index.
    ///
    /// Coordinates are produced by an odometer that advances the fastest-varying axis
    /// for this tensor's memory order:
    /// - the last axis for `RowMajor`;
    /// - the first axis for `ColumnMajor`.
    ///
    /// NOTE(review): this assumes storage is laid out contiguously in that order; confirm
    /// against the storage type.
    pub fn map_with_index<U, F>(&self, f: F) -> DenseTensorData<U>
    where
        F: Fn(&[usize], &T) -> U,
        U: Clone + 'static,
    {
        let order = self.order();
        let shape = self.shape();
        let rank = shape.len();
        let total = self.len();
        let raw = self.storage().data();
        let mut coords = vec![0usize; rank];
        let mut result = Vec::with_capacity(total);
        // Axes listed slowest-to-fastest. The odometer below walks this list in reverse,
        // so it increments the fastest axis first.
        let axis_order: Vec<usize> = match order {
            MemoryOrder::RowMajor => (0..rank).collect(),
            MemoryOrder::ColumnMajor => (0..rank).rev().collect(),
        };
        for val in raw.iter().take(total) {
            result.push(f(&coords, val));
            // Carry-propagating increment. An axis that reaches its extent wraps to zero
            // and carries into the next slower axis.
            for &d in axis_order.iter().rev() {
                coords[d] += 1;
                if coords[d] < shape[d] {
                    break;
                }
                coords[d] = 0;
            }
        }
        DenseTensorData::<U>::from_raw_parts(result, shape.to_vec(), order)
    }

    /// Returns a copy of this tensor scaled by `factor`; `self` is left untouched.
    ///
    /// The clone's storage is scaled via its `scale` method. Given the `T: Mul<S>` bound,
    /// this presumably multiplies each element by `factor` (confirm against the storage
    /// type).
    pub fn scaled<S>(&self, factor: S) -> Self
    where
        T: Mul<S, Output = T>,
        S: Clone,
    {
        let mut result = self.clone();
        result.storage_mut().scale(factor);
        result
    }
}
/// `tensor * factor` for an owned tensor.
///
/// Because `self` is consumed, it is scaled in place via `scale` and returned, with no
/// clone of the tensor.
impl<T> Mul<T> for DenseTensorData<T>
where
    T: Clone + Mul<Output = T>,
{
    type Output = DenseTensorData<T>;

    fn mul(self, factor: T) -> Self::Output {
        let mut product = self;
        product.scale(factor);
        product
    }
}
impl<T> Mul<T> for &DenseTensorData<T>
where
T: Clone + Mul<Output = T>,
{
type Output = DenseTensorData<T>;
fn mul(self, rhs: T) -> Self::Output {
self.scaled(rhs)
}
}
/// `tensor *= factor`: scales the tensor in place via its `scale` method.
impl<T> MulAssign<T> for DenseTensorData<T>
where
    T: Clone + Mul<Output = T>,
{
    fn mul_assign(&mut self, factor: T) {
        self.scale(factor);
    }
}
impl<T> DenseTensorData<T>
where
    T: Clone,
{
    /// Element-wise sum of `tensors`.
    ///
    /// This is [`Self::linear_combine`] with every coefficient equal to one.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Self::linear_combine`]: an empty list, or tensors whose
    /// shapes or memory orders differ.
    pub fn add_all(tensors: &[&DenseTensorData<T>]) -> Result<DenseTensorData<T>, TensorError>
    where
        T: Zero + num_traits::One + Add<Output = T> + Mul<Output = T>,
    {
        let unit_weights: Vec<T> = std::iter::repeat(T::one()).take(tensors.len()).collect();
        Self::linear_combine(tensors, &unit_weights)
    }

    /// Element-wise weighted sum: `result[k] = Σ_i coefs[i] * tensors[i][k]`.
    ///
    /// The accumulator starts at `T::zero()` and folds tensors in list order. The result
    /// takes the shape and memory order of the first tensor.
    ///
    /// # Errors
    ///
    /// Returns `TensorError::InvalidArgument` in these cases, checked in this order:
    /// - `tensors` is empty;
    /// - `tensors.len() != coefs.len()`;
    /// - a tensor's shape differs from the first tensor's shape;
    /// - a tensor's memory order differs from the first tensor's memory order.
    pub fn linear_combine(
        tensors: &[&DenseTensorData<T>],
        coefs: &[T],
    ) -> Result<DenseTensorData<T>, TensorError>
    where
        T: Zero + Add<Output = T> + Mul<Output = T>,
    {
        let (head, rest) = match tensors.split_first() {
            Some(split) => split,
            None => {
                return Err(TensorError::InvalidArgument(
                    "Cannot combine empty tensor list".to_string(),
                ))
            }
        };
        if coefs.len() != tensors.len() {
            return Err(TensorError::InvalidArgument(format!(
                "Mismatched lengths: {} tensors vs {} coefficients",
                tensors.len(),
                coefs.len()
            )));
        }

        let shape = head.shape();
        let order = head.order();
        for other in rest {
            // Shape is checked before memory order for each tensor, so the first
            // offending tensor reports its shape mismatch first.
            if other.shape() != shape {
                return Err(TensorError::InvalidArgument(
                    "All tensors must have the same shape".to_string(),
                ));
            }
            if other.order() != order {
                return Err(TensorError::InvalidArgument(format!(
                    "All tensors must have the same memory order; got {:?} and {:?}",
                    order,
                    other.order()
                )));
            }
        }

        let combined = tensors.iter().zip(coefs).fold(
            vec![T::zero(); head.len()],
            |mut acc, (tensor, coef)| {
                for (slot, val) in acc.iter_mut().zip(tensor.storage().data()) {
                    *slot = slot.clone() + coef.clone() * val.clone();
                }
                acc
            },
        );

        Ok(DenseTensorData::from_raw_parts(combined, shape.to_vec(), order))
    }
}