use yscv_tensor::Tensor;
use super::error::AutogradError;
use super::graph::Graph;
use super::node::NodeId;
/// Backward pass for ReLU.
///
/// Keeps each `upstream` element whose forward input is `> 0.0` and zeroes the rest
/// (elements whose input is `<= 0.0`), then accumulates the result into `input`'s gradient.
///
/// # Errors
/// Returns [`AutogradError::InvalidGradientShape`] when `upstream` and the input value
/// differ in shape, and propagates any error from `Graph::accumulate_grad`.
pub(crate) fn relu_backward(
    graph: &mut Graph,
    upstream: Tensor,
    _index: usize,
    input: NodeId,
) -> Result<(), AutogradError> {
    let x = &graph.nodes[input.0].value;
    if x.shape() != upstream.shape() {
        return Err(AutogradError::InvalidGradientShape {
            node: input.0,
            expected: x.shape().to_vec(),
            got: upstream.shape().to_vec(),
        });
    }

    // Reuse the upstream buffer as the output and mask it in place.
    let mut masked = upstream;
    relu_backward_slice(masked.data_mut(), x.data());
    graph.accumulate_grad(input, masked)?;
    Ok(())
}
/// Backward pass for `exp`.
///
/// Because `d/dx exp(x) = exp(x)`, the input gradient is `upstream * y`. Here `y` is the
/// op's output, read from `graph.nodes[index]`.
pub(crate) fn exp_backward(
    graph: &mut Graph,
    upstream: Tensor,
    index: usize,
    input: NodeId,
) -> Result<(), AutogradError> {
    let y = &graph.nodes[index].value;
    let grad_x = graph.dispatch_mul(&upstream, y)?;
    graph.accumulate_grad(input, grad_x)?;
    Ok(())
}
/// Backward pass for the natural logarithm.
///
/// Because `d/dx ln(x) = 1 / x`, the input gradient is `upstream * (1 / x)`, where `x` is
/// the value stored on the `input` node.
pub(crate) fn log_backward(
    graph: &mut Graph,
    upstream: Tensor,
    input: NodeId,
) -> Result<(), AutogradError> {
    // The reciprocal is a temporary that is dropped at the end of this statement.
    let grad_x = graph.dispatch_mul(&upstream, &graph.nodes[input.0].value.reciprocal())?;
    graph.accumulate_grad(input, grad_x)?;
    Ok(())
}
/// Backward pass for `sqrt`.
///
/// With `y = sqrt(x)` read from `graph.nodes[index]`, `dy/dx = 1 / (2y)`. The gradient is
/// therefore `upstream * (0.5 * (1 / y))`.
pub(crate) fn sqrt_backward(
    graph: &mut Graph,
    upstream: Tensor,
    index: usize,
    input: NodeId,
) -> Result<(), AutogradError> {
    let grad_x = graph.dispatch_mul(&upstream, &graph.nodes[index].value.reciprocal().scale(0.5))?;
    graph.accumulate_grad(input, grad_x)?;
    Ok(())
}
/// Backward pass for the logistic sigmoid.
///
/// With `y = σ(x)` read from `graph.nodes[index]`, `σ'(x) = y * (1 - y)`. The input
/// gradient is `upstream * y * (1 - y)`.
pub(crate) fn sigmoid_backward(
    graph: &mut Graph,
    upstream: Tensor,
    index: usize,
    input: NodeId,
) -> Result<(), AutogradError> {
    let grad_x = {
        let y = &graph.nodes[index].value;
        // Local derivative: y * (1 - y), where (1 - y) is built as (-y) + 1.
        let dydx = graph.dispatch_mul(y, &graph.dispatch_neg(y).add_scalar(1.0))?;
        graph.dispatch_mul(&upstream, &dydx)?
    };
    graph.accumulate_grad(input, grad_x)?;
    Ok(())
}
/// Backward pass for `tanh`.
///
/// With `y = tanh(x)` read from `graph.nodes[index]`, `dy/dx = 1 - y²`. The input
/// gradient is `upstream * (1 - y²)`.
pub(crate) fn tanh_backward(
    graph: &mut Graph,
    upstream: Tensor,
    index: usize,
    input: NodeId,
) -> Result<(), AutogradError> {
    let grad_x = {
        let y = &graph.nodes[index].value;
        let y_sq = graph.dispatch_mul(y, y)?;
        // 1 - y² is built as (-y²) + 1.
        graph.dispatch_mul(&upstream, &graph.dispatch_neg(&y_sq).add_scalar(1.0))?
    };
    graph.accumulate_grad(input, grad_x)?;
    Ok(())
}
pub(crate) fn gelu_backward(
graph: &mut Graph,
upstream: Tensor,
input: NodeId,
) -> Result<(), AutogradError> {
let input_grad = {
let iv = &graph.nodes[input.0].value;
let mut result = upstream;
gelu_backward_slice(result.data_mut(), iv.data());
result
};
graph.accumulate_grad(input, input_grad)?;
Ok(())
}
pub(crate) fn silu_backward(
graph: &mut Graph,
upstream: Tensor,
input: NodeId,
) -> Result<(), AutogradError> {
let input_grad = {
let iv = &graph.nodes[input.0].value;
let mut result = upstream;
silu_backward_slice(result.data_mut(), iv.data());
result
};
graph.accumulate_grad(input, input_grad)?;
Ok(())
}
pub(crate) fn mish_backward(
graph: &mut Graph,
upstream: Tensor,
input: NodeId,
) -> Result<(), AutogradError> {
let input_grad = {
let iv = &graph.nodes[input.0].value;
let mut result = upstream;
mish_backward_slice(result.data_mut(), iv.data());
result
};
graph.accumulate_grad(input, input_grad)?;
Ok(())
}
pub(crate) fn leaky_relu_backward(
graph: &mut Graph,
upstream: Tensor,
input: NodeId,
negative_slope: u32,
) -> Result<(), AutogradError> {
let slope = f32::from_bits(negative_slope);
let input_grad = {
let iv = &graph.nodes[input.0].value;
let mut result = upstream;
leaky_relu_backward_slice(result.data_mut(), iv.data(), slope);
result
};
graph.accumulate_grad(input, input_grad)?;
Ok(())
}
pub(crate) fn softmax_backward(
graph: &mut Graph,
upstream: Tensor,
index: usize,
input: NodeId,
) -> Result<(), AutogradError> {
let input_grad = {
let ov = &graph.nodes[index].value;
let shape = ov.shape();
let last = *shape.last().unwrap_or(&1);
let outer = ov.len() / last;
let sm = ov.data();
let up = upstream.data();
let mut grad = vec![0.0f32; ov.len()];
for o in 0..outer {
let base = o * last;
let dot: f32 = (0..last).map(|i| up[base + i] * sm[base + i]).sum();
for i in 0..last {
grad[base + i] = sm[base + i] * (up[base + i] - dot);
}
}
Tensor::from_vec(shape.to_vec(), grad)?
};
graph.accumulate_grad(input, input_grad)?;
Ok(())
}
pub(crate) fn log_softmax_backward(
graph: &mut Graph,
upstream: Tensor,
index: usize,
input: NodeId,
) -> Result<(), AutogradError> {
let input_grad = {
let ov = &graph.nodes[index].value;
let shape = ov.shape();
let last = *shape.last().unwrap_or(&1);
let outer = ov.len() / last;
let ov_data = ov.data();
let up = upstream.data();
let mut grad = vec![0.0f32; ov.len()];
for o in 0..outer {
let base = o * last;
let sum_up: f32 = (0..last).map(|i| up[base + i]).sum();
for i in 0..last {
grad[base + i] = up[base + i] - ov_data[base + i].exp() * sum_up;
}
}
Tensor::from_vec(shape.to_vec(), grad)?
};
graph.accumulate_grad(input, input_grad)?;
Ok(())
}
/// Zeroes `grad[i]` wherever `input[i] <= 0.0`, in place.
///
/// A NaN input compares false and leaves its gradient untouched. Lengths should match;
/// this is only asserted in debug builds.
#[inline(always)]
fn relu_backward_slice(grad: &mut [f32], input: &[f32]) {
    debug_assert_eq!(grad.len(), input.len());
    let n = grad.len();
    let xs = &input[..n];
    for (g, &x) in grad.iter_mut().zip(xs) {
        if x <= 0.0 {
            *g = 0.0;
        }
    }
}
/// Multiplies `grad` in place by the derivative of the sigmoid-approximated GELU.
///
/// For `gelu(x) ≈ x * σ(1.702 x)`, the derivative is `σ(kx) + kx * σ(kx) * (1 - σ(kx))`
/// with `kx = 1.702 x`. Lengths should match; this is only asserted in debug builds.
#[inline(always)]
fn gelu_backward_slice(grad: &mut [f32], input: &[f32]) {
    debug_assert_eq!(grad.len(), input.len());
    let n = grad.len();
    let xs = &input[..n];
    for (g, &x) in grad.iter_mut().zip(xs) {
        let kx = 1.702 * x;
        let sig = 1.0 / (1.0 + (-kx).exp());
        *g *= sig + kx * sig * (1.0 - sig);
    }
}
/// Multiplies `grad` in place by the SiLU derivative `σ(x) + x σ(x) (1 - σ(x))`.
///
/// Lengths should match; this is only asserted in debug builds.
#[inline(always)]
fn silu_backward_slice(grad: &mut [f32], input: &[f32]) {
    debug_assert_eq!(grad.len(), input.len());
    let n = grad.len();
    let xs = &input[..n];
    for (g, &x) in grad.iter_mut().zip(xs) {
        let sig = 1.0 / (1.0 + (-x).exp());
        let dsilu = sig + x * sig * (1.0 - sig);
        *g *= dsilu;
    }
}
/// Multiplies `grad` in place by the derivative of Mish, `mish(x) = x * tanh(softplus(x))`.
///
/// `d/dx mish(x) = tanh(sp) + x * sech²(sp) * σ(x)`, with `sp = softplus(x) = ln(1 + eˣ)`.
/// Lengths should match; this is only asserted in debug builds.
#[inline(always)]
fn mish_backward_slice(grad: &mut [f32], input: &[f32]) {
    debug_assert_eq!(grad.len(), input.len());
    let n = grad.len();
    let xs = &input[..n];
    for (g, &x) in grad.iter_mut().zip(xs) {
        // Use `ln_1p` for softplus. The naive `(1.0 + eˣ).ln()` loses precision as eˣ
        // shrinks, and once x is below roughly -17, `1 + eˣ` rounds to exactly 1 in f32
        // and softplus collapses to 0. For large x, eˣ → inf and ln_1p(inf) = inf, so
        // tanh(sp) = 1 and the result stays finite.
        let sp = x.exp().ln_1p();
        let tanh_sp = sp.tanh();
        let sech2_sp = 1.0 - tanh_sp * tanh_sp;
        let sig = 1.0 / (1.0 + (-x).exp());
        *g *= tanh_sp + x * sech2_sp * sig;
    }
}
/// Multiplies `grad[i]` by `slope` wherever `input[i] < 0.0`, in place.
///
/// Zero, positive, and NaN inputs leave their gradient unchanged. Lengths should match;
/// this is only asserted in debug builds.
#[inline(always)]
fn leaky_relu_backward_slice(grad: &mut [f32], input: &[f32], slope: f32) {
    debug_assert_eq!(grad.len(), input.len());
    let n = grad.len();
    let xs = &input[..n];
    for (g, &x) in grad.iter_mut().zip(xs) {
        if x < 0.0 {
            *g *= slope;
        }
    }
}