use super::{GradFn, GradStore, Var, VarGradStore, var_add};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::TensorOps;
use crate::runtime::{Runtime, RuntimeClient};
use crate::tensor::{Tensor, TensorId};
use std::collections::HashSet;
use std::sync::Arc;
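/// Observer for the backward pass.
///
/// `on_leaf_grad_ready` fires once per leaf variable (a node with no
/// `GradFn`), after every contribution to its gradient has been
/// accumulated. Implementors must be `Send`.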
pub trait BackwardHook<R: Runtime>: Send {
fn on_leaf_grad_ready(&mut self, id: TensorId, grad: &Tensor<R>);
}
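/// A [`BackwardHook`] that ignores every callback; [`backward`] uses it.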
pub struct NoOpHook;
impl<R: Runtime> BackwardHook<R> for NoOpHook {
fn on_leaf_grad_ready(&mut self, _id: TensorId, _grad: &Tensor<R>) {}
}
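/// Checks that `loss` is a scalar (exactly one element) and requires grad.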
#[inline]
fn validate_loss<R: Runtime>(loss: &Var<R>, fn_name: &str) -> Result<()> {
if loss.numel() != 1 {
return Err(Error::ShapeMismatch {
expected: vec![1],
got: loss.shape().to_vec(),
});
}
if !loss.requires_grad() {
return Err(Error::Internal(format!(
"{}() called on tensor that doesn't require grad",
fn_name
)));
}
Ok(())
}
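/// Seeds the backward pass with dL/dL: a tensor of ones matching the loss's
/// shape, dtype, and device.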
#[inline]
fn create_loss_gradient<R: Runtime<DType = DType>>(loss: &Var<R>) -> Tensor<R> {
Tensor::<R>::ones(loss.shape(), loss.tensor().dtype(), loss.tensor().device())
}
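/// Computes gradients of a scalar `loss` with respect to every variable
/// reachable from it, returned as a [`GradStore`] keyed by [`TensorId`].
///
/// A minimal sketch mirroring this module's tests (CPU runtime; imports
/// elided, other runtimes may construct clients differently):
///
/// ```ignore
/// let device = CpuDevice::new();
/// let client = CpuRuntime::default_client(&device);
/// let x = Var::new(Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1], &device), true);
/// let y = var_mul(&x, &x, &client).unwrap(); // y = x^2
/// let grads = backward(&y, &client).unwrap();
/// let dx: Vec<f32> = grads.get(x.id()).unwrap().to_vec(); // dy/dx = 2x = [4.0]
/// ```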
pub fn backward<R, C>(loss: &Var<R>, client: &C) -> Result<GradStore<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R> + TensorOps<R>,
{
backward_with_hooks(loss, client, &mut NoOpHook)
}
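/// Like [`backward`], but calls `hooks.on_leaf_grad_ready` for each leaf
/// variable once its gradient is fully accumulated. Because nodes are
/// visited in reverse topological order, every downstream contribution has
/// already been added when the hook fires.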
pub fn backward_with_hooks<R, C, H>(
loss: &Var<R>,
client: &C,
hooks: &mut H,
) -> Result<GradStore<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R> + TensorOps<R>,
H: BackwardHook<R>,
{
validate_loss(loss, "backward_with_hooks")?;
let mut grad_store = GradStore::new();
grad_store.insert(loss.id(), create_loss_gradient(loss));
let topo_order = topological_sort(loss);
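// Reverse topological order guarantees a node's output gradient is
// complete before its grad_fn consumes it.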
for var_entry in topo_order.into_iter().rev() {
let (var_id, grad_fn_opt, input_ids) = var_entry;
let grad_output = match grad_store.get(var_id) {
Some(g) => g.clone(),
None => continue,
};
if let Some(grad_fn) = grad_fn_opt {
let input_grads = grad_fn.backward(&grad_output)?;
for (input_id, input_grad_opt) in input_ids.iter().zip(input_grads.into_iter()) {
if let Some(input_grad) = input_grad_opt {
grad_store.try_accumulate(*input_id, input_grad, |existing, new| {
client.add(&existing, &new)
})?;
}
}
} else {
hooks.on_leaf_grad_ready(var_id, &grad_output);
}
}
Ok(grad_store)
}
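/// Like [`backward`], but keeps gradients as [`Var`]s with their own graphs,
/// enabling higher-order derivatives such as grad-of-grad and
/// Hessian-vector products (see the tests below).
///
/// A second-order sketch mirroring `test_second_order_derivative_x_squared`
/// (imports elided):
///
/// ```ignore
/// let y = var_mul(&x, &x, &client).unwrap();             // y = x^2
/// let grads = backward_with_graph(&y, &client).unwrap(); // dy/dx = 2x, still a Var
/// let g = grads.get_var(x.id()).unwrap();
/// let g_sum = var_sum(g, &[], false, &client).unwrap();
/// let second = backward(&g_sum, &client).unwrap();       // d²y/dx² = 2
/// ```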
pub fn backward_with_graph<R, C>(loss: &Var<R>, client: &C) -> Result<VarGradStore<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R> + TensorOps<R>,
R::Client: TensorOps<R>,
{
validate_loss(loss, "backward_with_graph")?;
let mut var_grad_store = VarGradStore::new();
var_grad_store.insert(loss.id(), Var::new(create_loss_gradient(loss), true));
let topo_order = topological_sort(loss);
for var_entry in topo_order.into_iter().rev() {
let (var_id, grad_fn_opt, input_ids) = var_entry;
let grad_output = match var_grad_store.get_var(var_id) {
Some(g) => g.clone(),
None => continue,
};
if let Some(grad_fn) = grad_fn_opt {
let input_grads = grad_fn.backward_var(&grad_output)?;
for (input_id, input_grad_opt) in input_ids.iter().zip(input_grads.into_iter()) {
if let Some(input_grad) = input_grad_opt {
var_grad_store.try_accumulate(*input_id, input_grad, |existing, new| {
var_add(&existing, &new, client)
})?;
}
}
}
}
Ok(var_grad_store)
}
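/// One graph node: the tensor's id, the `GradFn` that produced it (leaf
/// nodes have none), and the ids of that op's inputs.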
type TopoEntry<R> = (TensorId, Option<Arc<dyn GradFn<R>>>, Vec<TensorId>);
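/// Depth-first post-order traversal from `loss`: inputs are pushed before
/// the node that consumes them, so iterating the result in reverse is a
/// valid backpropagation order.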
fn topological_sort<R: Runtime>(loss: &Var<R>) -> Vec<TopoEntry<R>> {
let mut result = Vec::new();
let mut visited = HashSet::new();
fn dfs<R: Runtime>(
id: TensorId,
grad_fn: Option<Arc<dyn GradFn<R>>>,
visited: &mut HashSet<TensorId>,
result: &mut Vec<TopoEntry<R>>,
) {
// `HashSet::insert` returns false when the id was already visited.
if !visited.insert(id) {
return;
}
let input_ids: Vec<TensorId> = grad_fn
.as_ref()
.map(|gf| gf.inputs().to_vec())
.unwrap_or_default();
if let Some(gf) = &grad_fn {
for (input_id, input_grad_fn) in input_ids.iter().zip(gf.input_grad_fns()) {
dfs(*input_id, input_grad_fn, visited, result);
}
}
result.push((id, grad_fn, input_ids));
}
dfs(
loss.id(),
loss.grad_fn().cloned(),
&mut visited,
&mut result,
);
result
}
#[cfg(test)]
mod tests {
use super::*;
use crate::autograd::{var_mul, var_sum};
use crate::runtime::cpu::{CpuDevice, CpuRuntime};
use std::cell::RefCell;
use std::rc::Rc;
struct RecordingHook {
leaf_ids: Rc<RefCell<Vec<TensorId>>>,
}
impl RecordingHook {
fn new() -> (Self, Rc<RefCell<Vec<TensorId>>>) {
let ids = Rc::new(RefCell::new(Vec::new()));
(
Self {
leaf_ids: ids.clone(),
},
ids,
)
}
}
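// SAFETY: `BackwardHook` requires `Send`, but this test-only hook never
// leaves the test thread, so the non-Send `Rc<RefCell<..>>` is never
// shared across threads.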
unsafe impl Send for RecordingHook {}
impl BackwardHook<CpuRuntime> for RecordingHook {
fn on_leaf_grad_ready(&mut self, id: TensorId, _grad: &Tensor<CpuRuntime>) {
self.leaf_ids.borrow_mut().push(id);
}
}
#[test]
fn test_backward_with_hooks_matches_backward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1], &device),
true,
);
let y = Var::new(
Tensor::<CpuRuntime>::from_slice(&[3.0f32], &[1], &device),
true,
);
let z1 = var_mul(&x, &y, &client).unwrap();
let z2 = var_mul(&x, &y, &client).unwrap();
let grads1 = backward(&z1, &client).unwrap();
let (mut hook, leaf_ids) = RecordingHook::new();
let grads2 = backward_with_hooks(&z2, &client, &mut hook).unwrap();
let gx1: Vec<f32> = grads1.get(x.id()).unwrap().to_vec();
let gx2: Vec<f32> = grads2.get(x.id()).unwrap().to_vec();
assert!((gx1[0] - gx2[0]).abs() < 1e-6);
let gy1: Vec<f32> = grads1.get(y.id()).unwrap().to_vec();
let gy2: Vec<f32> = grads2.get(y.id()).unwrap().to_vec();
assert!((gy1[0] - gy2[0]).abs() < 1e-6);
let ids = leaf_ids.borrow();
assert_eq!(ids.len(), 2);
assert!(ids.contains(&x.id()));
assert!(ids.contains(&y.id()));
}
#[test]
fn test_backward_with_hooks_no_hook_for_non_leaf() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[2.0f32, 3.0], &[2], &device),
true,
);
let x_sq = var_mul(&x, &x, &client).unwrap();
let loss = var_sum(&x_sq, &[0], false, &client).unwrap();
let (mut hook, leaf_ids) = RecordingHook::new();
let _grads = backward_with_hooks(&loss, &client, &mut hook).unwrap();
let ids = leaf_ids.borrow();
assert_eq!(ids.len(), 1);
assert!(ids.contains(&x.id()));
}
#[test]
fn test_backward_requires_scalar() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let tensor = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
let var = Var::new(tensor, true);
let result = backward(&var, &client);
assert!(result.is_err());
}
#[test]
fn test_backward_leaf_variable() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let tensor = Tensor::<CpuRuntime>::from_slice(&[5.0f32], &[1], &device);
let var = Var::new(tensor, true);
let grads = backward(&var, &client).unwrap();
let grad = grads.get(var.id()).unwrap();
let grad_data: Vec<f32> = grad.to_vec();
assert_eq!(grad_data, vec![1.0f32]);
}
#[test]
fn test_backward_with_graph_requires_scalar() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let tensor = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
let var = Var::new(tensor, true);
let result = backward_with_graph(&var, &client);
assert!(result.is_err());
}
#[test]
fn test_backward_with_graph_leaf_variable() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let tensor = Tensor::<CpuRuntime>::from_slice(&[5.0f32], &[1], &device);
let var = Var::new(tensor, true);
let grads = backward_with_graph(&var, &client).unwrap();
let grad_var = grads.get_var(var.id()).unwrap();
let grad_data: Vec<f32> = grad_var.tensor().to_vec();
assert_eq!(grad_data, vec![1.0f32]);
assert!(grad_var.requires_grad());
}
#[test]
fn test_backward_with_graph_simple_mul() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[3.0f32], &[1], &device),
true,
);
let y = var_mul(&x, &x, &client).unwrap();
let grads = backward_with_graph(&y, &client).unwrap();
let grad_x = grads.get_var(x.id()).unwrap();
let grad_data: Vec<f32> = grad_x.tensor().to_vec();
assert!((grad_data[0] - 6.0).abs() < 1e-6);
assert!(grad_x.requires_grad());
}
#[test]
fn test_backward_with_graph_matches_backward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1], &device),
true,
);
let y = Var::new(
Tensor::<CpuRuntime>::from_slice(&[3.0f32], &[1], &device),
true,
);
let z1 = var_mul(&x, &y, &client).unwrap();
let z2 = var_mul(&x, &y, &client).unwrap();
let grads1 = backward(&z1, &client).unwrap();
let grads2 = backward_with_graph(&z2, &client).unwrap();
let grad_x1: Vec<f32> = grads1.get(x.id()).unwrap().to_vec();
let grad_x2: Vec<f32> = grads2.get(x.id()).unwrap().to_vec();
assert!((grad_x1[0] - grad_x2[0]).abs() < 1e-6);
let grad_y1: Vec<f32> = grads1.get(y.id()).unwrap().to_vec();
let grad_y2: Vec<f32> = grads2.get(y.id()).unwrap().to_vec();
assert!((grad_y1[0] - grad_y2[0]).abs() < 1e-6);
}
#[test]
fn test_backward_with_graph_to_grad_store() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1], &device),
true,
);
let y = var_mul(&x, &x, &client).unwrap();
let var_grads = backward_with_graph(&y, &client).unwrap();
let grad_store = var_grads.to_grad_store();
let grad_x: Vec<f32> = grad_store.get(x.id()).unwrap().to_vec();
assert!((grad_x[0] - 4.0).abs() < 1e-6);
}
#[test]
fn test_second_order_derivative_x_squared() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[3.0f32], &[1], &device),
true,
);
let y = var_mul(&x, &x, &client).unwrap();
let grads = backward_with_graph(&y, &client).unwrap();
let grad_x = grads.get_var(x.id()).unwrap();
let first_deriv: Vec<f32> = grad_x.tensor().to_vec();
assert!((first_deriv[0] - 6.0).abs() < 1e-6);
let grad_x_sum = var_sum(grad_x, &[], false, &client).unwrap();
let second_grads = backward(&grad_x_sum, &client).unwrap();
let second_deriv: Vec<f32> = second_grads.get(x.id()).unwrap().to_vec();
assert!(
(second_deriv[0] - 2.0).abs() < 1e-5,
"Expected 2.0, got {}",
second_deriv[0]
);
}
#[test]
fn test_hessian_vector_product() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[3.0f32], &[1], &device),
true,
);
let y = var_mul(&x, &x, &client).unwrap();
let grads = backward_with_graph(&y, &client).unwrap();
let grad_x = grads.get_var(x.id()).unwrap();
let v = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device),
false,
);
let grad_v = var_mul(grad_x, &v, &client).unwrap();
let grad_v_sum = var_sum(&grad_v, &[], false, &client).unwrap();
let hvp_grads = backward(&grad_v_sum, &client).unwrap();
let hvp: Vec<f32> = hvp_grads.get(x.id()).unwrap().to_vec();
assert!(
(hvp[0] - 2.0).abs() < 1e-5,
"Expected HVP = 2.0, got {}",
hvp[0]
);
}
#[test]
fn test_second_order_add() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[3.0f32], &[1], &device),
true,
);
let y = Var::new(
Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1], &device),
true,
);
let z = crate::autograd::var_add(&x, &y, &client).unwrap();
let grads = backward_with_graph(&z, &client).unwrap();
let grad_x: Vec<f32> = grads.get(x.id()).unwrap().to_vec();
let grad_y: Vec<f32> = grads.get(y.id()).unwrap().to_vec();
assert!((grad_x[0] - 1.0).abs() < 1e-6);
assert!((grad_y[0] - 1.0).abs() < 1e-6);
let grad_x_var = grads.get_var(x.id()).unwrap();
let grad_x_sum = var_sum(grad_x_var, &[], false, &client).unwrap();
let second_grads = backward(&grad_x_sum, &client).unwrap();
assert!(
second_grads.get(x.id()).is_none(),
"Expected no second-order gradient for add"
);
}
#[test]
fn test_second_order_sub() {
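// `sub` has constant gradients, so only first-order values through the
// graph-retaining path are checked here.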
use crate::autograd::var_sub;
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[3.0f32], &[1], &device),
true,
);
let y = Var::new(
Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1], &device),
true,
);
let z = var_sub(&x, &y, &client).unwrap();
let grads = backward_with_graph(&z, &client).unwrap();
let grad_x: Vec<f32> = grads.get(x.id()).unwrap().to_vec();
let grad_y: Vec<f32> = grads.get(y.id()).unwrap().to_vec();
assert!((grad_x[0] - 1.0).abs() < 1e-6);
assert!((grad_y[0] - (-1.0)).abs() < 1e-6);
}
#[test]
fn test_second_order_div() {
use crate::autograd::var_div;
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let one = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device),
false,
);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1], &device),
true,
);
let y = var_div(&one, &x, &client).unwrap();
let grads = backward_with_graph(&y, &client).unwrap();
let grad_x: Vec<f32> = grads.get(x.id()).unwrap().to_vec();
assert!(
(grad_x[0] - (-0.25)).abs() < 1e-5,
"Expected -0.25, got {}",
grad_x[0]
);
let grad_x_var = grads.get_var(x.id()).unwrap();
let grad_x_sum = var_sum(grad_x_var, &[], false, &client).unwrap();
let second_grads = backward(&grad_x_sum, &client).unwrap();
let second_deriv: Vec<f32> = second_grads.get(x.id()).unwrap().to_vec();
assert!(
(second_deriv[0] - 0.25).abs() < 1e-4,
"Expected 0.25, got {}",
second_deriv[0]
);
}
#[test]
fn test_second_order_through_sum() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[3.0f32, 4.0], &[2], &device),
true,
);
let x_squared = var_mul(&x, &x, &client).unwrap();
let y = var_sum(&x_squared, &[0], false, &client).unwrap();
let grads = backward_with_graph(&y, &client).unwrap();
let grad_x: Vec<f32> = grads.get(x.id()).unwrap().to_vec();
assert!(
(grad_x[0] - 6.0).abs() < 1e-5,
"Expected 6.0, got {}",
grad_x[0]
);
assert!(
(grad_x[1] - 8.0).abs() < 1e-5,
"Expected 8.0, got {}",
grad_x[1]
);
let grad_x_var = grads.get_var(x.id()).unwrap();
let grad_x_sum = var_sum(grad_x_var, &[0], false, &client).unwrap();
let second_grads = backward(&grad_x_sum, &client).unwrap();
let second_deriv: Vec<f32> = second_grads.get(x.id()).unwrap().to_vec();
assert!(
(second_deriv[0] - 2.0).abs() < 1e-4,
"Expected 2.0, got {}",
second_deriv[0]
);
assert!(
(second_deriv[1] - 2.0).abs() < 1e-4,
"Expected 2.0, got {}",
second_deriv[1]
);
}
#[test]
fn test_second_order_through_mean() {
use crate::autograd::var_mean;
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[3.0f32, 4.0], &[2], &device),
true,
);
let x_squared = var_mul(&x, &x, &client).unwrap();
let y = var_mean(&x_squared, &[0], false, &client).unwrap();
let grads = backward_with_graph(&y, &client).unwrap();
let grad_x: Vec<f32> = grads.get(x.id()).unwrap().to_vec();
assert!(
(grad_x[0] - 3.0).abs() < 1e-5,
"Expected 3.0, got {}",
grad_x[0]
);
assert!(
(grad_x[1] - 4.0).abs() < 1e-5,
"Expected 4.0, got {}",
grad_x[1]
);
let grad_x_var = grads.get_var(x.id()).unwrap();
let grad_x_sum = var_sum(grad_x_var, &[0], false, &client).unwrap();
let second_grads = backward(&grad_x_sum, &client).unwrap();
let second_deriv: Vec<f32> = second_grads.get(x.id()).unwrap().to_vec();
assert!(
(second_deriv[0] - 1.0).abs() < 1e-4,
"Expected 1.0, got {}",
second_deriv[0]
);
assert!(
(second_deriv[1] - 1.0).abs() < 1e-4,
"Expected 1.0, got {}",
second_deriv[1]
);
}
#[test]
fn test_second_order_through_mul_scalar() {
use crate::autograd::var_mul_scalar;
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1], &device),
true,
);
let x_squared = var_mul(&x, &x, &client).unwrap();
let y = var_mul_scalar(&x_squared, 3.0, &client).unwrap();
let grads = backward_with_graph(&y, &client).unwrap();
let grad_x: Vec<f32> = grads.get(x.id()).unwrap().to_vec();
assert!(
(grad_x[0] - 12.0).abs() < 1e-5,
"Expected 12.0, got {}",
grad_x[0]
);
let grad_x_var = grads.get_var(x.id()).unwrap();
let grad_x_sum = var_sum(grad_x_var, &[0], false, &client).unwrap();
let second_grads = backward(&grad_x_sum, &client).unwrap();
let second_deriv: Vec<f32> = second_grads.get(x.id()).unwrap().to_vec();
assert!(
(second_deriv[0] - 6.0).abs() < 1e-4,
"Expected 6.0, got {}",
second_deriv[0]
);
}
#[test]
fn test_second_order_through_pow_scalar() {
use crate::autograd::var_pow_scalar;
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1], &device),
true,
);
let y = var_pow_scalar(&x, 3.0, &client).unwrap();
let grads = backward_with_graph(&y, &client).unwrap();
let grad_x: Vec<f32> = grads.get(x.id()).unwrap().to_vec();
assert!(
(grad_x[0] - 12.0).abs() < 1e-5,
"Expected 12.0, got {}",
grad_x[0]
);
let grad_x_var = grads.get_var(x.id()).unwrap();
let grad_x_sum = var_sum(grad_x_var, &[0], false, &client).unwrap();
let second_grads = backward(&grad_x_sum, &client).unwrap();
let second_deriv: Vec<f32> = second_grads.get(x.id()).unwrap().to_vec();
assert!(
(second_deriv[0] - 12.0).abs() < 1e-4,
"Expected 12.0, got {}",
second_deriv[0]
);
}
#[test]
fn test_second_order_through_broadcast() {
use crate::autograd::var_add;
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device),
true,
);
let b = Var::new(
Tensor::<CpuRuntime>::from_slice(&[0.1f32, 0.2], &[2], &device),
true,
);
let x_plus_b = var_add(&x, &b, &client).unwrap();
let squared = var_mul(&x_plus_b, &x_plus_b, &client).unwrap();
let loss = var_sum(&squared, &[0], false, &client).unwrap();
let grads = backward_with_graph(&loss, &client).unwrap();
assert!(grads.get(x.id()).is_some(), "Should have gradient for x");
assert!(grads.get(b.id()).is_some(), "Should have gradient for b");
let grad_x: Vec<f32> = grads.get(x.id()).unwrap().to_vec();
assert!(
(grad_x[0] - 2.2).abs() < 1e-5,
"Expected 2.2, got {}",
grad_x[0]
);
let grad_x_var = grads.get_var(x.id()).unwrap();
let grad_x_sum = var_sum(grad_x_var, &[0], false, &client).unwrap();
let second_grads = backward(&grad_x_sum, &client).unwrap();
let second_deriv_x: Vec<f32> = second_grads.get(x.id()).unwrap().to_vec();
for (i, &val) in second_deriv_x.iter().enumerate() {
assert!(
(val - 2.0).abs() < 1e-4,
"Expected d²L/dx²[{}] = 2.0, got {}",
i,
val
);
}
}
#[test]
fn test_second_order_through_broadcast_shapes() {
use crate::autograd::var_add;
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], &device),
true,
);
let b = Var::new(
Tensor::<CpuRuntime>::from_slice(&[0.1f32, 0.2, 0.3], &[3], &device),
true,
);
let x_plus_b = var_add(&x, &b, &client).unwrap();
let squared = var_mul(&x_plus_b, &x_plus_b, &client).unwrap();
let loss = var_sum(&squared, &[0, 1], false, &client).unwrap();
let grads = backward_with_graph(&loss, &client).unwrap();
assert!(grads.get(x.id()).is_some(), "Should have gradient for x");
assert!(grads.get(b.id()).is_some(), "Should have gradient for b");
assert_eq!(grads.get(x.id()).unwrap().shape(), &[2, 3]);
assert_eq!(grads.get(b.id()).unwrap().shape(), &[3]);
let grad_x: Vec<f32> = grads.get(x.id()).unwrap().to_vec();
assert!(
(grad_x[0] - 2.2).abs() < 1e-5,
"Expected 2.2, got {}",
grad_x[0]
);
let grad_b: Vec<f32> = grads.get(b.id()).unwrap().to_vec();
assert!(
(grad_b[0] - 10.4).abs() < 1e-4,
"Expected 10.4, got {}",
grad_b[0]
);
if let Some(grad_x_var) = grads.get_var(x.id()) {
let grad_x_sum = var_sum(grad_x_var, &[0, 1], false, &client).unwrap();
let second_grads = backward(&grad_x_sum, &client).unwrap();
if let Some(second_deriv_x) = second_grads.get(x.id()) {
let second_deriv_x: Vec<f32> = second_deriv_x.to_vec();
for (i, &val) in second_deriv_x.iter().enumerate() {
assert!(
(val - 2.0).abs() < 1e-4,
"Expected d²L/dx²[{}] = 2.0, got {}",
i,
val
);
}
}
}
}
}