use scirs2_core::ndarray::{Array, Axis, IxDyn};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::error::{AutogradError, Result};
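
/// Forward pass of 2-D matrix multiplication: returns `a . b`.
///
/// Shapes are validated first; batched (>2-D) inputs are rejected for now.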
#[allow(dead_code)]
pub fn matmul_forward<F: Float + Debug + Send + Sync + 'static>(
    a: &Array<F, IxDyn>,
    b: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
    if a.ndim() < 2 || b.ndim() < 2 {
        return Err(AutogradError::ShapeMismatch(
            "Matrix multiplication requires at least 2D tensors".to_string(),
        ));
    }
    let ashape = a.shape();
    let bshape = b.shape();
    if ashape[ashape.len() - 1] != bshape[bshape.len() - 2] {
        return Err(AutogradError::ShapeMismatch(format!(
            "Matrix multiplication dimension mismatch: {:?} and {:?}",
            ashape, bshape
        )));
    }
    if a.ndim() == 2 && b.ndim() == 2 {
        let a_rows = a.shape()[0];
        let a_cols = a.shape()[1];
        let b_cols = b.shape()[1];
        // Naive triple-loop product: result[i, j] = sum_k a[i, k] * b[k, j].
        let mut result = Array::<F, _>::zeros((a_rows, b_cols));
        for i in 0..a_rows {
            for j in 0..b_cols {
                let mut sum = F::zero();
                for k in 0..a_cols {
                    sum = sum + a[[i, k]] * b[[k, j]];
                }
                result[[i, j]] = sum;
            }
        }
        Ok(result.into_dyn())
    } else {
        Err(AutogradError::OperationError(
            "Batched matrix multiplication not implemented yet".to_string(),
        ))
    }
}
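
/// Backward pass of 2-D matrix multiplication.
///
/// Given the upstream gradient `grad` of `a . b`, returns
/// `(grad . b^T, a^T . grad)`, the gradients with respect to `a` and `b`.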
#[allow(dead_code)]
pub fn matmul_backward<F: Float + Debug + Send + Sync + 'static>(
    grad: &Array<F, IxDyn>,
    a: &Array<F, IxDyn>,
    b: &Array<F, IxDyn>,
) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)> {
    if a.ndim() != 2 || b.ndim() != 2 || grad.ndim() != 2 {
        return Err(AutogradError::OperationError(
            "Matrix multiplication gradient currently only implemented for 2D tensors".to_string(),
        ));
    }
    let ashape = a.shape();
    let bshape = b.shape();
    let gradshape = grad.shape();
    if gradshape[0] != ashape[0] || gradshape[1] != bshape[1] {
        return Err(AutogradError::ShapeMismatch(format!(
            "Gradient shape mismatch: {:?} for matmul of {:?} and {:?}",
            gradshape, ashape, bshape
        )));
    }
    let grad_rows = grad.shape()[0];
    let grad_cols = grad.shape()[1];
    let b_rows = b.shape()[0];
    let b_cols = b.shape()[1];
    // grad_a = grad . b^T: materialize b^T, then multiply.
    let mut b_t = Array::<F, _>::zeros((b_cols, b_rows));
    for i in 0..b_rows {
        for j in 0..b_cols {
            b_t[[j, i]] = b[[i, j]];
        }
    }
    let mut grad_a = Array::<F, _>::zeros((grad_rows, b_rows));
    for i in 0..grad_rows {
        for j in 0..b_rows {
            let mut sum = F::zero();
            for k in 0..grad_cols {
                sum = sum + grad[[i, k]] * b_t[[k, j]];
            }
            grad_a[[i, j]] = sum;
        }
    }
    let grad_a = grad_a.into_dyn();
    // grad_b = a^T . grad: materialize a^T, then multiply.
    let a_rows = a.shape()[0];
    let a_cols = a.shape()[1];
    let mut a_t = Array::<F, _>::zeros((a_cols, a_rows));
    for i in 0..a_rows {
        for j in 0..a_cols {
            a_t[[j, i]] = a[[i, j]];
        }
    }
    let mut grad_b = Array::<F, _>::zeros((a_cols, grad_cols));
    for i in 0..a_cols {
        for j in 0..grad_cols {
            let mut sum = F::zero();
            for k in 0..a_rows {
                sum = sum + a_t[[i, k]] * grad[[k, j]];
            }
            grad_b[[i, j]] = sum;
        }
    }
    let grad_b = grad_b.into_dyn();
    Ok((grad_a, grad_b))
}
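
/// ReLU activation: element-wise `max(x, 0)`.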
#[allow(dead_code)]
pub fn relu_forward<F: Float + Debug + Send + Sync + 'static>(
    input: &Array<F, IxDyn>,
) -> Array<F, IxDyn> {
    input.mapv(|x| if x > F::zero() { x } else { F::zero() })
}
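
/// Backward pass of ReLU: the upstream gradient is passed through where the
/// forward input was positive and zeroed elsewhere.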
#[allow(dead_code)]
pub fn relu_backward<F: Float + Debug + Send + Sync + 'static>(
    grad: &Array<F, IxDyn>,
    input: &Array<F, IxDyn>,
) -> Array<F, IxDyn> {
    let mut result = grad.clone();
    for (r, &i) in result.iter_mut().zip(input.iter()) {
        if i <= F::zero() {
            *r = F::zero();
        }
    }
    result
}
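
/// Sigmoid activation: element-wise `1 / (1 + exp(-x))`.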
#[allow(dead_code)]
pub fn sigmoid_forward<F: Float + Debug + Send + Sync + 'static>(
    input: &Array<F, IxDyn>,
) -> Array<F, IxDyn> {
    input.mapv(|x| F::one() / (F::one() + (-x).exp()))
}
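
/// Backward pass of sigmoid, expressed in terms of the forward `output`:
/// `grad * output * (1 - output)`.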
#[allow(dead_code)]
pub fn sigmoid_backward<F: Float + Debug + Send + Sync + 'static>(
    grad: &Array<F, IxDyn>,
    output: &Array<F, IxDyn>,
) -> Array<F, IxDyn> {
    let sigmoid_grad = output.mapv(|y| y * (F::one() - y));
    grad * &sigmoid_grad
}
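
/// Hyperbolic tangent activation, applied element-wise.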
#[allow(dead_code)]
pub fn tanh_forward<F: Float + Debug + Send + Sync + 'static>(
    input: &Array<F, IxDyn>,
) -> Array<F, IxDyn> {
    input.mapv(|x| x.tanh())
}
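
/// Backward pass of tanh, expressed in terms of the forward `output`:
/// `grad * (1 - output^2)`.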
#[allow(dead_code)]
pub fn tanh_backward<F: Float + Debug + Send + Sync + 'static>(
    grad: &Array<F, IxDyn>,
    output: &Array<F, IxDyn>,
) -> Array<F, IxDyn> {
    let tanh_grad = output.mapv(|y| F::one() - y * y);
    grad * &tanh_grad
}
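
/// Softmax along dimension `dim`, computed with the max-subtraction trick
/// for numerical stability.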
#[allow(dead_code)]
pub fn softmax_forward<F: Float + Debug + Send + Sync + 'static>(
    input: &Array<F, IxDyn>,
    dim: usize,
) -> Result<Array<F, IxDyn>> {
    if dim >= input.ndim() {
        return Err(AutogradError::ShapeMismatch(format!(
            "Softmax dimension {} out of bounds for tensor with {} dimensions",
            dim,
            input.ndim()
        )));
    }
    // Subtract each lane's maximum before exponentiating for numerical stability.
    let max_vals = input.map_axis(Axis(dim), |view| {
        view.fold(F::neg_infinity(), |a, &b| if a > b { a } else { b })
    });
    let mut exp_vals = input.clone();
    for (mut lane, &max) in exp_vals
        .lanes_mut(Axis(dim))
        .into_iter()
        .zip(max_vals.iter())
    {
        lane.mapv_inplace(|v| (v - max).exp());
    }
    // Normalize each lane so that its entries sum to one.
    let sum_vals = exp_vals.map_axis(Axis(dim), |view| view.sum());
    let mut result = exp_vals;
    for (mut lane, &sum) in result.lanes_mut(Axis(dim)).into_iter().zip(sum_vals.iter()) {
        lane.mapv_inplace(|v| v / sum);
    }
    Ok(result)
}
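
/// Backward pass of softmax along dimension `dim`.
///
/// Computes the Jacobian-vector product
/// `output * (grad - sum(grad * output, dim))`.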
#[allow(dead_code)]
pub fn softmax_backward<F: Float + Debug + Send + Sync + 'static>(
    grad: &Array<F, IxDyn>,
    output: &Array<F, IxDyn>,
    dim: usize,
) -> Result<Array<F, IxDyn>> {
    if dim >= output.ndim() {
        return Err(AutogradError::ShapeMismatch(format!(
            "Softmax dimension {} out of bounds for tensor with {} dimensions",
            dim,
            output.ndim()
        )));
    }
    // Full softmax Jacobian-vector product:
    // grad_input = output * (grad - sum(grad * output)) along `dim`.
    let weighted = grad * output;
    let sums = weighted.map_axis(Axis(dim), |view| view.sum());
    let mut shifted = grad.clone();
    for (mut lane, &s) in shifted.lanes_mut(Axis(dim)).into_iter().zip(sums.iter()) {
        lane.mapv_inplace(|g| g - s);
    }
    let result = shifted * output;
    Ok(result)
}