use scirs2_core::ndarray::{Array, Axis, IxDyn, Slice};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::error::{AutogradError, Result};
use crate::graph::{Node, OpType};
use crate::tensor::Tensor;
#[allow(dead_code)]
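/// Element-wise addition of two tensors of identical shape.
///
/// A minimal usage sketch (marked `ignore`; it assumes the
/// `Tensor::new(data, requires_grad)` constructor used throughout this module):
///
/// ```ignore
/// let a = Tensor::new(Array::from_elem(IxDyn(&[2, 2]), 1.0f64), false);
/// let b = Tensor::new(Array::from_elem(IxDyn(&[2, 2]), 2.0f64), false);
/// let c = add(&a, &b)?; // every element is 3.0
/// ```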
pub fn add<F: Float + Debug + Send + Sync + 'static>(
a: &Tensor<F>,
b: &Tensor<F>,
) -> Result<Tensor<F>> {
if a.shape() != b.shape() {
return Err(AutogradError::ShapeMismatch(format!(
"Addition requires tensors of the same shape: {:?} vs {:?}",
a.shape(),
b.shape()
)));
}
Ok(a + b)
}
#[allow(dead_code)]
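/// Element-wise subtraction of two tensors of identical shape.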
pub fn sub<F: Float + Debug + Send + Sync + 'static>(
a: &Tensor<F>,
b: &Tensor<F>,
) -> Result<Tensor<F>> {
if a.shape() != b.shape() {
return Err(AutogradError::ShapeMismatch(format!(
"Subtraction requires tensors of the same shape: {:?} vs {:?}",
a.shape(),
b.shape()
)));
}
Ok(a - b)
}
#[allow(dead_code)]
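/// Element-wise (Hadamard) multiplication of two tensors of identical shape.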
pub fn mul<F: Float + Debug + Send + Sync + 'static>(
a: &Tensor<F>,
b: &Tensor<F>,
) -> Result<Tensor<F>> {
if a.shape() != b.shape() {
return Err(AutogradError::ShapeMismatch(format!(
"Multiplication requires tensors of the same shape: {:?} vs {:?}",
a.shape(),
b.shape()
)));
}
Ok(a * b)
}
#[allow(dead_code)]
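/// Element-wise division of two tensors of identical shape.
///
/// No zero check is performed: for standard float types, division by zero
/// yields `inf` or `NaN` per IEEE 754 semantics.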
pub fn div<F: Float + Debug + Send + Sync + 'static>(
a: &Tensor<F>,
b: &Tensor<F>,
) -> Result<Tensor<F>> {
if a.shape() != b.shape() {
return Err(AutogradError::ShapeMismatch(format!(
"Division requires tensors of the same shape: {:?} vs {:?}",
a.shape(),
b.shape()
)));
}
Ok(a / b)
}
#[allow(dead_code)]
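/// Matrix product of two 2D tensors: for `a` of shape `[m, k]` and `b` of
/// shape `[k, n]`, returns an `[m, n]` tensor with
/// `out[i][j] = sum_k a[i][k] * b[k][j]`.
///
/// A minimal sketch (marked `ignore`; it assumes `Tensor::new` as used in this
/// module and an ndarray-style `array!` macro re-exported by `scirs2_core`):
///
/// ```ignore
/// // [1 2]   [5 6]   [19 22]
/// // [3 4] * [7 8] = [43 50]
/// let a = Tensor::new(array![[1.0, 2.0], [3.0, 4.0]].into_dyn(), false);
/// let b = Tensor::new(array![[5.0, 6.0], [7.0, 8.0]].into_dyn(), false);
/// let c = matmul(&a, &b)?;
/// ```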
pub fn matmul<F: Float + Debug + Send + Sync + 'static>(
a: &Tensor<F>,
b: &Tensor<F>,
) -> Result<Tensor<F>> {
    if a.data.ndim() != 2 || b.data.ndim() != 2 {
        return Err(AutogradError::ShapeMismatch(
            "Matrix multiplication currently supports exactly 2D tensors".to_string(),
        ));
    }
    let a_shape = a.shape();
    let b_shape = b.shape();
    if a_shape[1] != b_shape[0] {
        return Err(AutogradError::ShapeMismatch(format!(
            "Matrix multiplication dimension mismatch: {:?} and {:?}",
            a_shape, b_shape
        )));
    }
    let a_rows = a_shape[0];
    let a_cols = a_shape[1];
    let b_cols = b_shape[1];
    // Naive O(m * n * k) triple loop; correct but not cache- or SIMD-optimized.
    let mut result_data_2d = Array::<F, _>::zeros((a_rows, b_cols));
for i in 0..a_rows {
for j in 0..b_cols {
let mut sum = F::zero();
for k in 0..a_cols {
sum = sum + a.data[[i, k]] * b.data[[k, j]];
}
result_data_2d[[i, j]] = sum;
}
}
let result_data = result_data_2d.into_dyn();
let requires_grad = a.requires_grad || b.requires_grad;
if requires_grad {
let node = Node::matmul(a, b)?;
Ok(Tensor::from_operation(result_data, node, requires_grad))
} else {
Ok(Tensor::new(result_data, false))
}
}
#[allow(dead_code)]
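/// Rectified linear unit, `max(v, 0)` applied element-wise; the backward pass
/// is delegated to `Node::relu`.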
pub fn relu<F: Float + Debug + Send + Sync + 'static>(x: &Tensor<F>) -> Result<Tensor<F>> {
let result_data = x.data.mapv(|v| if v > F::zero() { v } else { F::zero() });
if x.requires_grad {
let node = Node::relu(x);
Ok(Tensor::from_operation(result_data, node, true))
} else {
Ok(Tensor::new(result_data, false))
}
}
#[allow(dead_code)]
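/// Logistic sigmoid, `1 / (1 + exp(-v))` applied element-wise; the backward
/// pass is delegated to `Node::sigmoid`.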
pub fn sigmoid<F: Float + Debug + Send + Sync + 'static>(x: &Tensor<F>) -> Result<Tensor<F>> {
let result_data = x.data.mapv(|v| F::one() / (F::one() + (-v).exp()));
if x.requires_grad {
let node = Node::sigmoid(x);
Ok(Tensor::from_operation(result_data, node, true))
} else {
Ok(Tensor::new(result_data, false))
}
}
#[allow(dead_code)]
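/// Hyperbolic tangent applied element-wise; the backward pass is delegated to
/// `Node::tanh`.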
pub fn tanh<F: Float + Debug + Send + Sync + 'static>(x: &Tensor<F>) -> Result<Tensor<F>> {
let result_data = x.data.mapv(|v| v.tanh());
if x.requires_grad {
let node = Node::tanh(x);
Ok(Tensor::from_operation(result_data, node, true))
} else {
Ok(Tensor::new(result_data, false))
}
}
#[allow(dead_code)]
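/// Sums along `axis`, or over all elements (yielding a 1-element tensor) when
/// `axis` is `None`.
///
/// Expected results sketched below (marked `ignore`; assumes `Tensor::new` and
/// an ndarray-style `array!` macro):
///
/// ```ignore
/// let x = Tensor::new(array![[1.0, 2.0], [3.0, 4.0]].into_dyn(), false);
/// let total = sum(&x, None)?;     // [10.0]
/// let by_col = sum(&x, Some(0))?; // [4.0, 6.0]
/// ```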
pub fn sum<F: Float + Debug + Send + Sync + 'static>(
x: &Tensor<F>,
axis: Option<usize>,
) -> Result<Tensor<F>> {
let result_data = if let Some(axis) = axis {
if axis >= x.data.ndim() {
return Err(AutogradError::ShapeMismatch(format!(
"Sum axis {} out of bounds for tensor with {} dimensions",
axis,
x.data.ndim()
)));
}
        x.data.sum_axis(Axis(axis)).into_dyn()
} else {
let sum_val = x.data.sum();
Array::from_elem(IxDyn(&[1]), sum_val)
};
if x.requires_grad {
let node = Node::sum(x, axis);
Ok(Tensor::from_operation(result_data, node, true))
} else {
Ok(Tensor::new(result_data, false))
}
}
#[allow(dead_code)]
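/// Averages along `axis`, or over all elements (yielding a 1-element tensor)
/// when `axis` is `None`.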
pub fn mean<F: Float + Debug + Send + Sync + 'static>(
x: &Tensor<F>,
axis: Option<usize>,
) -> Result<Tensor<F>> {
let result_data = if let Some(axis) = axis {
if axis >= x.data.ndim() {
return Err(AutogradError::ShapeMismatch(format!(
"Mean axis {} out of bounds for tensor with {} dimensions",
axis,
x.data.ndim()
)));
}
        let sum = x.data.sum_axis(Axis(axis));
        let count = F::from(x.shape()[axis]).expect("axis length should be representable as F");
        sum.mapv(|v| v / count).into_dyn()
} else {
let sum_val = x.data.sum();
        let count = F::from(x.size()).expect("element count should be representable as F");
let mean_val = sum_val / count;
Array::from_elem(IxDyn(&[1]), mean_val)
};
if x.requires_grad {
let node = Node::mean(x, axis);
Ok(Tensor::from_operation(result_data, node, true))
} else {
Ok(Tensor::new(result_data, false))
}
}
#[allow(dead_code)]
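/// Reshapes `x` to `shape`, which must contain the same total number of
/// elements; the backward pass reshapes gradients back to the original shape.
///
/// Sketch (marked `ignore`; assumes `Tensor::new` as used in this module):
///
/// ```ignore
/// let x = Tensor::new(Array::from_elem(IxDyn(&[2, 2]), 1.0f64), false);
/// let flat = reshape(&x, &[4])?; // shape [4], same 4 elements
/// ```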
pub fn reshape<F: Float + Debug + Send + Sync + 'static>(
x: &Tensor<F>,
shape: &[usize],
) -> Result<Tensor<F>> {
let old_size = x.size();
let new_size: usize = shape.iter().product();
if old_size != new_size {
return Err(AutogradError::ShapeMismatch(format!(
"Cannot reshape tensor of size {} to shape {:?} with size {}",
old_size, shape, new_size
)));
}
let result_data = match x.data.clone().into_shape_with_order(shape) {
Ok(reshaped) => reshaped.into_dyn(),
Err(e) => {
return Err(AutogradError::OperationError(format!(
"Reshape error: {}",
e
)))
}
};
if x.requires_grad {
        let original_shape = x.shape().to_vec();
        let backward = Box::new(move |grad: Array<F, IxDyn>| -> Result<Array<F, IxDyn>> {
            match grad.into_shape_with_order(original_shape.clone()) {
                Ok(reshaped_grad) => Ok(reshaped_grad),
                Err(e) => Err(AutogradError::OperationError(format!(
                    "Gradient reshape error: {}",
                    e
                ))),
            }
        })
            as Box<dyn Fn(Array<F, IxDyn>) -> Result<Array<F, IxDyn>> + Send + Sync>;
        let node = Node::new(OpType::Reshape, vec![x], vec![Some(backward)]);
Ok(Tensor::from_operation(result_data, node, true))
} else {
Ok(Tensor::new(result_data, false))
}
}
#[allow(dead_code)]
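/// Swaps axes `dim0` and `dim1`. The swap is its own inverse, so the backward
/// pass applies the same permutation to the incoming gradient.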
pub fn transpose<F: Float + Debug + Send + Sync + 'static>(
x: &Tensor<F>,
dim0: usize,
dim1: usize,
) -> Result<Tensor<F>> {
let ndim = x.data.ndim();
if dim0 >= ndim || dim1 >= ndim {
return Err(AutogradError::ShapeMismatch(format!(
"Transpose dimensions {}, {} out of bounds for tensor with {} dimensions",
dim0, dim1, ndim
)));
}
    // Build the axis permutation once; swapping dim0 and dim1 is its own
    // inverse, so the same permutation maps gradients back in the backward pass.
    let axes: Vec<usize> = (0..ndim)
        .map(|i| {
            if i == dim0 {
                dim1
            } else if i == dim1 {
                dim0
            } else {
                i
            }
        })
        .collect();
    let result_data = x.data.clone().permuted_axes(axes.clone());
    if x.requires_grad {
        let backward = Box::new(move |grad: Array<F, IxDyn>| -> Result<Array<F, IxDyn>> {
            Ok(grad.permuted_axes(axes.clone()))
        })
            as Box<dyn Fn(Array<F, IxDyn>) -> Result<Array<F, IxDyn>> + Send + Sync>;
        let node = Node::new(OpType::Transpose, vec![x], vec![Some(backward)]);
Ok(Tensor::from_operation(result_data, node, true))
} else {
Ok(Tensor::new(result_data, false))
}
}
#[allow(dead_code)]
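/// Natural logarithm applied element-wise, with gradient `grad / x`.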
pub fn log<F: Float + Debug + Send + Sync + 'static>(x: &Tensor<F>) -> Result<Tensor<F>> {
let result_data = x.data.mapv(|v| v.ln());
if x.requires_grad {
let x_data = x.data.clone();
let backward = Box::new(move |grad: Array<F, IxDyn>| -> Result<Array<F, IxDyn>> {
Ok(&grad / &x_data)
})
as Box<dyn Fn(Array<F, IxDyn>) -> Result<Array<F, IxDyn>> + Send + Sync>;
        let node = Node::new(OpType::Log, vec![x], vec![Some(backward)]);
Ok(Tensor::from_operation(result_data, node, true))
} else {
Ok(Tensor::new(result_data, false))
}
}
#[allow(dead_code)]
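/// Exponential applied element-wise. The forward output is captured and reused
/// in the backward pass, since `d/dx exp(x) = exp(x)`.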
pub fn exp<F: Float + Debug + Send + Sync + 'static>(x: &Tensor<F>) -> Result<Tensor<F>> {
let result_data = x.data.mapv(|v| v.exp());
if x.requires_grad {
let result_data_clone = result_data.clone();
let backward = Box::new(move |grad: Array<F, IxDyn>| -> Result<Array<F, IxDyn>> {
Ok(&grad * &result_data_clone)
})
as Box<dyn Fn(Array<F, IxDyn>) -> Result<Array<F, IxDyn>> + Send + Sync>;
        let node = Node::new(OpType::Exp, vec![x], vec![Some(backward)]);
Ok(Tensor::from_operation(result_data, node, true))
} else {
Ok(Tensor::new(result_data, false))
}
}
#[allow(dead_code)]
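/// Numerically stable softmax along `dim`: the per-lane maximum is subtracted
/// before exponentiation so `exp` cannot overflow.
///
/// Expected values sketched below (marked `ignore`; assumes `Tensor::new` and
/// an ndarray-style `array!` macro):
///
/// ```ignore
/// let x = Tensor::new(array![1.0, 2.0, 3.0].into_dyn(), false);
/// let y = softmax(&x, 0)?; // approximately [0.090, 0.245, 0.665]
/// ```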
pub fn softmax<F: Float + Debug + Send + Sync + 'static>(
x: &Tensor<F>,
dim: usize,
) -> Result<Tensor<F>> {
let ndim = x.data.ndim();
if dim >= ndim {
return Err(AutogradError::ShapeMismatch(format!(
"Softmax dimension {} out of bounds for tensor with {} dimensions",
dim, ndim
)));
}
    // Subtract the per-lane maximum before exponentiating for numerical stability.
    let max_vals = x.data.map_axis(Axis(dim), |view| {
        view.fold(F::neg_infinity(), |a, &b| if a > b { a } else { b })
    });
    let mut exp_vals = x.data.clone();
    for (mut lane, &max) in exp_vals
        .lanes_mut(Axis(dim))
        .into_iter()
        .zip(max_vals.iter())
    {
        lane.mapv_inplace(|v| (v - max).exp());
    }
    let sum_vals = exp_vals.map_axis(Axis(dim), |view| view.sum());
    let mut result_data = exp_vals;
    for (mut lane, &sum) in result_data
        .lanes_mut(Axis(dim))
        .into_iter()
        .zip(sum_vals.iter())
    {
        lane.mapv_inplace(|v| v / sum);
    }
    if x.requires_grad {
        let result_data_clone = result_data.clone();
        let backward = Box::new(move |grad: Array<F, IxDyn>| -> Result<Array<F, IxDyn>> {
            // For y = softmax(x) along `dim`, the Jacobian-vector product is
            // dL/dx = y * (dL/dy - sum(y * dL/dy over the softmax axis)).
            let mut dx = &grad * &result_data_clone;
            let axis = Axis(dim);
            for (mut dx_lane, y_lane) in dx
                .lanes_mut(axis)
                .into_iter()
                .zip(result_data_clone.lanes(axis))
            {
                // dx_lane currently holds y * dL/dy; reduce it, then subtract.
                let dot = dx_lane.fold(F::zero(), |acc, &v| acc + v);
                for (d, &y) in dx_lane.iter_mut().zip(y_lane.iter()) {
                    *d = *d - y * dot;
                }
            }
            Ok(dx)
        })
            as Box<dyn Fn(Array<F, IxDyn>) -> Result<Array<F, IxDyn>> + Send + Sync>;
        let node = Node::new(
            OpType::Activation("softmax".to_string()),
            vec![x],
            vec![Some(backward)],
        );
Ok(Tensor::from_operation(result_data, node, true))
} else {
Ok(Tensor::new(result_data, false))
}
}
#[allow(dead_code)]
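/// Concatenates `tensors` along `dim`. All inputs must agree in every
/// dimension except `dim`; each input's gradient is the matching slice of the
/// upstream gradient.
///
/// Sketch (marked `ignore`; assumes `Tensor::new` as used in this module):
///
/// ```ignore
/// // Two [2, 2] tensors concatenated along axis 0 give a [4, 2] tensor.
/// let a = Tensor::new(Array::from_elem(IxDyn(&[2, 2]), 1.0f64), false);
/// let b = Tensor::new(Array::from_elem(IxDyn(&[2, 2]), 2.0f64), false);
/// let c = cat(&[&a, &b], 0)?;
/// ```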
pub fn cat<F: Float + Debug + Send + Sync + 'static>(
tensors: &[&Tensor<F>],
dim: usize,
) -> Result<Tensor<F>> {
if tensors.is_empty() {
return Err(AutogradError::OperationError(
"Cannot concatenate empty list of tensors".to_string(),
));
}
    let ref_shape = tensors[0].shape();
    let ndim = ref_shape.len();
if dim >= ndim {
return Err(AutogradError::ShapeMismatch(format!(
"Concatenation dimension {} out of bounds for tensor with {} dimensions",
dim, ndim
)));
}
for (i, tensor) in tensors.iter().enumerate().skip(1) {
let shape = tensor.shape();
if shape.len() != ndim {
return Err(AutogradError::ShapeMismatch(format!(
"All tensors must have the same number of dimensions, but tensor 0 has {} and tensor {} has {}",
ndim, i, shape.len()
)));
}
        for (j, (&s1, &s2)) in ref_shape.iter().zip(shape.iter()).enumerate() {
if j != dim && s1 != s2 {
return Err(AutogradError::ShapeMismatch(format!(
"Incompatible shapes for concatenation: {:?} and {:?} at dimension {}",
                    ref_shape, shape, j
)));
}
}
}
    let mut result_shape = ref_shape.to_vec();
    result_shape[dim] = tensors.iter().map(|t| t.shape()[dim]).sum();
    let mut result_data = Array::<F, IxDyn>::zeros(IxDyn(&result_shape));
    let mut offset = 0;
    for tensor in tensors {
        let size_along_dim = tensor.shape()[dim];
        // Copy each input into its slot along the concatenation axis.
        result_data
            .slice_axis_mut(Axis(dim), Slice::from(offset..offset + size_along_dim))
            .assign(&tensor.data);
        offset += size_along_dim;
    }
let requires_grad = tensors.iter().any(|t| t.requires_grad);
    if requires_grad {
        let mut backward_fns = Vec::with_capacity(tensors.len());
        let mut offset_along_dim = 0;
        for tensor in tensors {
            let offset = offset_along_dim;
            let size_along_dim = tensor.shape()[dim];
            offset_along_dim += size_along_dim;
            if tensor.requires_grad {
                backward_fns.push(Some(Box::new(
                    move |grad: Array<F, IxDyn>| -> Result<Array<F, IxDyn>> {
                        // The gradient of concatenation is the matching slice of the
                        // upstream gradient along the concatenation axis.
                        Ok(grad
                            .slice_axis(Axis(dim), Slice::from(offset..offset + size_along_dim))
                            .to_owned())
                    },
                )
                    as Box<dyn Fn(Array<F, IxDyn>) -> Result<Array<F, IxDyn>> + Send + Sync>));
            } else {
                backward_fns.push(None);
            }
        }
        let node = Node::new(OpType::Concat, tensors.to_vec(), backward_fns);
Ok(Tensor::from_operation(result_data, node, requires_grad))
} else {
Ok(Tensor::new(result_data, false))
}
}