use crate::custom_gradient::CustomGradientOp;
use crate::error::{AutogradError, OpError};
use crate::op::{self, ComputeContext, GradientContext};
use crate::tensor::Tensor;
use crate::{Context, Float, Result};
use scirs2_core::ndarray::{ArrayD, ArrayViewD};
use std::sync::Arc;
/// Shared, thread-safe forward closure: maps the op's raw input array views
/// to one owned output array.
type FwdBoxed<F> = Arc<dyn Fn(&[ArrayViewD<F>]) -> ArrayD<F> + Send + Sync + 'static>;
/// Shared, thread-safe backward closure: given the upstream gradient, the
/// tensors saved during the forward pass, and the graph context, produces one
/// optional gradient per saved tensor (`None` = no gradient for that slot).
/// The higher-ranked lifetime `for<'g>` lets one stored closure be invoked
/// against any graph lifetime.
type BwdBoxed<F> = Arc<
    dyn for<'g> Fn(
            &Tensor<'g, F>,
            &[Tensor<'g, F>],
            &'g Context<'g, F>,
        ) -> Vec<Option<Tensor<'g, F>>>
        + Send
        + Sync
        + 'static,
>;
/// Adapter that packages a pair of user-supplied closures as a custom-gradient
/// operation (see the `CustomGradientOp` impl below).
struct ClosureCustomOp<F: Float> {
    // Static op name surfaced to the graph (debugging / display).
    name: &'static str,
    // Forward-pass closure.
    forward: FwdBoxed<F>,
    // Backward (gradient) closure.
    backward: BwdBoxed<F>,
    // Arity reported to the graph machinery via `num_inputs()`.
    num_inputs: usize,
}
// NOTE(review): the manual `unsafe impl Send/Sync for ClosureCustomOp<F>`
// that used to live here have been removed. Every field of the struct —
// `&'static str`, `usize`, and the `Arc<dyn Fn … + Send + Sync + 'static>`
// closures — is already `Send + Sync`, so the compiler derives both auto
// traits for `ClosureCustomOp<F>` on its own, for any `F`. The manual
// `unsafe` impls carried no `// SAFETY:` justification and, worse, would
// silently keep asserting thread-safety even if a non-`Send` field were
// added later — exactly the class of bug auto traits exist to catch.
impl<F: Float> CustomGradientOp<F> for ClosureCustomOp<F> {
    /// Name recorded at construction time.
    fn name(&self) -> &'static str {
        self.name
    }

    /// Runs the stored forward closure. The closure itself is infallible,
    /// so this always returns `Ok`.
    fn forward(&self, inputs: &[ArrayViewD<F>]) -> std::result::Result<ArrayD<F>, OpError> {
        let output = (self.forward)(inputs);
        Ok(output)
    }

    /// Runs the stored backward closure, yielding one optional gradient per
    /// saved tensor.
    fn backward<'g>(
        &self,
        output_grad: &Tensor<'g, F>,
        saved_tensors: &[Tensor<'g, F>],
        ctx: &'g Context<'g, F>,
    ) -> Vec<Option<Tensor<'g, F>>> {
        let bwd = &self.backward;
        bwd(output_grad, saved_tensors, ctx)
    }

    /// Arity recorded at construction time.
    fn num_inputs(&self) -> usize {
        self.num_inputs
    }

    /// Inputs are retained so the backward closure can inspect them.
    fn saves_inputs(&self) -> bool {
        true
    }

    /// The forward output is not retained for the backward pass.
    fn saves_output(&self) -> bool {
        false
    }
}
/// Builds a graph node with a user-defined forward and backward pass.
///
/// * `forward` — computes the output array from the raw input array views.
/// * `backward` — given the upstream gradient, the saved input tensors, and
///   the graph context, returns one optional gradient per input.
/// * `inputs` — the tensors this op consumes.
///
/// Returns the tensor produced by the custom op.
pub fn register_custom_gradient<'g, F, FwdFn, BwdFn>(
    forward: FwdFn,
    backward: BwdFn,
    inputs: &[Tensor<'g, F>],
    ctx: &'g Context<'g, F>,
) -> Tensor<'g, F>
where
    F: Float,
    FwdFn: Fn(&[ArrayViewD<F>]) -> ArrayD<F> + Send + Sync + 'static,
    BwdFn: for<'a> Fn(
            &Tensor<'a, F>,
            &[Tensor<'a, F>],
            &'a Context<'a, F>,
        ) -> Vec<Option<Tensor<'a, F>>>
        + Send
        + Sync
        + 'static,
{
    // An empty input slice is still reported as arity 1 — presumably the op
    // framework requires at least one declared input; confirm before changing.
    let arity = inputs.len().max(1);
    let wrapped = ClosureCustomOp {
        name: "CustomGradOp",
        forward: Arc::new(forward),
        backward: Arc::new(backward),
        num_inputs: arity,
    };
    crate::custom_gradient::custom_op(Arc::new(wrapped), inputs, ctx)
}
pub fn stop_gradient_tensor<'g, F: Float>(
x: Tensor<'g, F>,
_ctx: &'g Context<'g, F>,
) -> Tensor<'g, F> {
crate::tensor_ops::stop_gradient(x)
}
/// Wraps `f(x)` in a custom op whose backward pass recomputes `f` instead of
/// relying on saved intermediates (gradient checkpointing).
///
/// NOTE(review): `f(&x)` is still evaluated eagerly here to build the forward
/// graph, so any memory savings depend on how the surrounding evaluator treats
/// the recomputed subgraph — confirm against the evaluator's caching behavior.
pub fn checkpoint_fn<'g, F, Func>(
    f: Func,
    x: Tensor<'g, F>,
    ctx: &'g Context<'g, F>,
) -> Result<Tensor<'g, F>>
where
    F: Float,
    Func: for<'a> Fn(&Tensor<'a, F>) -> Tensor<'a, F> + Clone + Send + Sync + 'static,
{
    // Build the forward result once; the op below forwards it unchanged.
    let y = f(&x);
    // Move the user function into the backward closure for recomputation.
    let f_bwd = f;
    let backward_closure: BwdBoxed<F> = Arc::new(
        move |gy: &Tensor<'_, F>,
              saved: &[Tensor<'_, F>],
              _ctx: &Context<'_, F>|
              -> Vec<Option<Tensor<'_, F>>> {
            // The op saves its inputs in order `[y, x]` (see `custom_op` call
            // below); without both slots there is nothing to differentiate.
            if saved.len() < 2 {
                return vec![None, None];
            }
            // Slot 1 is the original input `x`.
            let x_saved = saved[1];
            // Recompute the forward pass from `x` rather than using saved
            // intermediates — the essence of checkpointing.
            let y_recomputed = f_bwd(&x_saved);
            // VJP trick: d/dx sum(gy ⊙ f(x)) equals the vector-Jacobian
            // product gyᵀ·J_f(x), i.e. the gradient we need w.r.t. `x`.
            let y_flat = crate::tensor_ops::flatten(y_recomputed);
            let gy_flat = crate::tensor_ops::flatten(*gy);
            let dot = crate::tensor_ops::reduction::sum_all(gy_flat * y_flat);
            let g_x = crate::tensor_ops::grad(&[dot], &[x_saved])[0];
            // No gradient for slot 0 (`y`); gradient flows only through `x`.
            vec![None, Some(g_x)]
        },
    );
    let op = Arc::new(ClosureCustomOp {
        name: "CheckpointOp",
        // Forward simply passes through input 0, which is the precomputed `y`.
        forward: Arc::new(|inputs: &[ArrayViewD<F>]| inputs[0].to_owned()),
        backward: backward_closure,
        num_inputs: 2,
    });
    // Input order `[y, x]` must match the `saved[1]` indexing in the
    // backward closure above.
    let checkpointed = crate::custom_gradient::custom_op(op, &[y, x], ctx);
    Ok(checkpointed)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor_ops::*;

    // Identity custom op: forward output must equal the fed input.
    #[test]
    fn test_custom_gradient_identity_forward() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[3]);
            let y = register_custom_gradient(
                |inputs| inputs[0].to_owned(),
                |gy, _saved, _ctx| vec![Some(*gy)],
                &[x],
                ctx,
            );
            let x_val = scirs2_core::ndarray::arr1(&[1.0f64, 2.0, 3.0]);
            let out = ctx
                .evaluator()
                .push(&y)
                .feed(x, x_val.view().into_dyn())
                .run();
            let arr = out[0].as_ref().expect("should eval");
            let s = arr.as_slice().expect("slice");
            assert!((s[0] - 1.0).abs() < 1e-9);
            assert!((s[1] - 2.0).abs() < 1e-9);
            assert!((s[2] - 3.0).abs() < 1e-9);
        });
    }

    // Doubling op: checks the forward closure is actually invoked.
    #[test]
    fn test_custom_gradient_doubled_forward() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[2]);
            let y = register_custom_gradient(
                |inputs| inputs[0].mapv(|v: f64| v * 2.0),
                |gy, _saved, _ctx| vec![Some(*gy * 2.0_f64)],
                &[x],
                ctx,
            );
            let x_val = scirs2_core::ndarray::arr1(&[3.0f64, 4.0]);
            let out = ctx
                .evaluator()
                .push(&y)
                .feed(x, x_val.view().into_dyn())
                .run();
            let arr = out[0].as_ref().expect("eval");
            let s = arr.as_slice().expect("slice");
            assert!((s[0] - 6.0).abs() < 1e-9, "forward 3*2=6, got {}", s[0]);
            assert!((s[1] - 8.0).abs() < 1e-9, "forward 4*2=8, got {}", s[1]);
        });
    }

    // |x| op whose backward uses the saved input (sign of x); only the
    // forward values are checked here.
    #[test]
    fn test_custom_gradient_abs_forward() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[4]);
            let y = register_custom_gradient(
                |inputs| inputs[0].mapv(|v: f64| v.abs()),
                |gy, saved, _ctx| {
                    let sign = crate::tensor_ops::map(saved[0], |arr| arr.mapv(|v: f64| if v >= 0.0 { 1.0 } else { -1.0 }));
                    vec![Some(*gy * sign)]
                },
                &[x],
                ctx,
            );
            let x_val = scirs2_core::ndarray::arr1(&[-2.0f64, -1.0, 1.0, 2.0]);
            let out = ctx
                .evaluator()
                .push(&y)
                .feed(x, x_val.view().into_dyn())
                .run();
            let arr = out[0].as_ref().expect("eval");
            let s = arr.as_slice().expect("slice");
            assert!((s[0] - 2.0).abs() < 1e-9, "|-2|=2, got {}", s[0]);
            assert!((s[1] - 1.0).abs() < 1e-9, "|-1|=1, got {}", s[1]);
            assert!((s[2] - 1.0).abs() < 1e-9, "|1|=1, got {}", s[2]);
            assert!((s[3] - 2.0).abs() < 1e-9, "|2|=2, got {}", s[3]);
        });
    }

    // stop_gradient must be the identity in the forward direction.
    #[test]
    fn test_stop_gradient_tensor_forward_unchanged() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[3]);
            let stopped = stop_gradient_tensor(x, ctx);
            let x_val = scirs2_core::ndarray::arr1(&[10.0f64, 20.0, 30.0]);
            let out = ctx
                .evaluator()
                .push(&stopped)
                .feed(x, x_val.view().into_dyn())
                .run();
            let arr = out[0].as_ref().expect("eval");
            let s = arr.as_slice().expect("slice");
            assert!((s[0] - 10.0).abs() < 1e-9);
            assert!((s[1] - 20.0).abs() < 1e-9);
            assert!((s[2] - 30.0).abs() < 1e-9);
        });
    }

    // Gradient of sum(stop_grad(x) + x): only asserts finiteness, not the
    // exact value 1.0 — NOTE(review): a stricter check may be possible.
    #[test]
    fn test_stop_gradient_tensor_blocks_gradient() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[2]);
            let stopped = stop_gradient_tensor(x, ctx);
            let loss = reduction::sum_all(stopped + x);
            let grads = grad(&[loss], &[x]);
            let x_val = scirs2_core::ndarray::arr1(&[5.0f64, 7.0]);
            let out = ctx
                .evaluator()
                .push(&grads[0])
                .feed(x, x_val.view().into_dyn())
                .run();
            let arr = out[0].as_ref().expect("eval");
            let s = arr.as_slice().expect("slice");
            assert!(s[0].is_finite(), "gradient should be finite, got {}", s[0]);
            assert!(s[1].is_finite(), "gradient should be finite, got {}", s[1]);
        });
    }

    // Checkpointed f(x)=x^2 must match the directly-built graph in forward.
    #[test]
    fn test_checkpoint_fn_forward_value() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[3]);
            let y_direct = x * x;
            let y_ckpt = checkpoint_fn(|t| *t * *t, x, ctx)
                .expect("checkpoint should succeed");
            let x_val = scirs2_core::ndarray::arr1(&[1.0f64, 2.0, 3.0]);
            let outs = ctx
                .evaluator()
                .push(&y_direct)
                .push(&y_ckpt)
                .feed(x, x_val.view().into_dyn())
                .run();
            let direct = outs[0].as_ref().expect("direct eval").as_slice().expect("s");
            let ckpt = outs[1].as_ref().expect("ckpt eval").as_slice().expect("s");
            for (a, b) in direct.iter().zip(ckpt.iter()) {
                assert!(
                    (a - b).abs() < 1e-9,
                    "checkpoint forward must equal direct: {} vs {}",
                    a,
                    b
                );
            }
        });
    }

    // d/dx sum(x^2) = 2x: positive and strictly increasing on [1, 2, 3].
    #[test]
    fn test_checkpoint_fn_gradient_flow() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[3]);
            let y = checkpoint_fn(|t| *t * *t, x, ctx)
                .expect("checkpoint should succeed");
            let loss = reduction::sum_all(y);
            let grads = grad(&[loss], &[x]);
            let x_val = scirs2_core::ndarray::arr1(&[1.0f64, 2.0, 3.0]);
            let out = ctx
                .evaluator()
                .push(&grads[0])
                .feed(x, x_val.view().into_dyn())
                .run();
            let arr = out[0].as_ref().expect("eval");
            let s = arr.as_slice().expect("slice");
            assert!(s[0] > 0.0, "grad[0] should be positive, got {}", s[0]);
            assert!(s[1] > 0.0, "grad[1] should be positive, got {}", s[1]);
            assert!(s[2] > 0.0, "grad[2] should be positive, got {}", s[2]);
            assert!(s[1] > s[0], "grad[1] > grad[0]");
            assert!(s[2] > s[1], "grad[2] > grad[1]");
        });
    }

    // d/dx x^3 at x=2 is 12; the wide tolerance (±2.0) suggests the
    // checkpointed gradient is only approximately exact here —
    // NOTE(review): confirm whether the tolerance can be tightened.
    #[test]
    fn test_checkpoint_fn_cubic_gradient() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[]);
            let y = checkpoint_fn(|t| *t * *t * *t, x, ctx)
                .expect("checkpoint cubic should succeed");
            let grads = grad(&[y], &[x]);
            let x_val = scirs2_core::ndarray::arr0(2.0f64);
            let out = ctx
                .evaluator()
                .push(&grads[0])
                .feed(x, x_val.view().into_dyn())
                .run();
            let val = out[0]
                .as_ref()
                .expect("eval")
                .first()
                .copied()
                .expect("first");
            assert!(
                (val - 12.0).abs() < 2.0,
                "checkpoint cubic grad at 2 = 12, got {}",
                val
            );
        });
    }
}