//! Custom-gradient utilities: user-defined ops with hand-written backward passes,
//! plus helpers that mask, scale, reverse, or stop gradients while leaving the
//! forward value untouched.

use crate::error::OpError;
use crate::op::{self, ComputeContext, GradientContext};
use crate::tensor::Tensor;
use crate::{Context, Float, NdArray};
use scirs2_core::ndarray::{ArrayD, ArrayViewD};
use std::sync::Arc;
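
/// A user-defined operation with a custom backward pass.
///
/// Implementors supply an eager `forward` over the input arrays and a symbolic
/// `backward` that maps the output gradient (plus any saved tensors) to one
/// gradient per input. `saves_inputs` and `saves_output` control which tensors
/// are handed to `backward` via `saved_tensors`: inputs first, in order, then
/// the forward output. Use [`custom_op`] to insert an implementation into a graph.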
pub trait CustomGradientOp<F: Float>: Send + Sync {
fn forward(&self, inputs: &[ArrayViewD<F>]) -> Result<ArrayD<F>, OpError>;
fn backward<'g>(
&self,
output_grad: &Tensor<'g, F>,
saved_tensors: &[Tensor<'g, F>],
ctx: &'g Context<'g, F>,
) -> Vec<Option<Tensor<'g, F>>>;
fn num_inputs(&self) -> usize;
fn name(&self) -> &'static str {
"CustomGradientOp"
}
fn saves_inputs(&self) -> bool {
true
}
fn saves_output(&self) -> bool {
true
}
}
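
/// Internal adapter that exposes a [`CustomGradientOp`] to the graph as an [`op::Op`].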
struct CustomGradientWrapper<F: Float> {
inner: Arc<dyn CustomGradientOp<F>>,
}
impl<F: Float> op::Op<F> for CustomGradientWrapper<F> {
fn name(&self) -> &'static str {
self.inner.name()
}
fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
let input_views: Vec<ArrayViewD<F>> = ctx.inputs();
let output = self.inner.forward(&input_views)?;
ctx.append_output(output);
Ok(())
}
fn grad<'a>(&self, ctx: &mut GradientContext<'a, 'a, F>) {
let output_grad = ctx.output_grad();
let graph = ctx.graph();
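        // Collect the tensors the op asked to have available in `backward`:
        // inputs first (in declaration order), then the forward output.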
let mut saved = Vec::new();
if self.inner.saves_inputs() {
for i in 0..ctx.num_inputs() {
saved.push(*ctx.input(i));
}
}
if self.inner.saves_output() {
saved.push(*ctx.output());
}
let input_grads = self.inner.backward(output_grad, &saved, graph);
for (i, grad) in input_grads.into_iter().enumerate() {
ctx.append_input_grad(i, grad);
}
}
}
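
/// Builds a graph node from a user-supplied [`CustomGradientOp`].
///
/// The sketch below mirrors `test_custom_op_forward` in this module and assumes the
/// same `crate::run` / `convert_to_tensor` helpers the tests use:
///
/// ```ignore
/// crate::run(|ctx: &mut Context<f64>| {
///     // DoubledGradOp: identity forward, doubled gradient (defined in the tests below).
///     let op = std::sync::Arc::new(DoubledGradOp);
///     let x = convert_to_tensor(scirs2_core::ndarray::arr1(&[1.0, 2.0, 3.0]).into_dyn(), ctx);
///     let y = custom_op(op, &[x], ctx);
///     let value = y.eval(ctx).expect("forward evaluation");
///     assert_eq!(value.as_slice().unwrap(), &[1.0, 2.0, 3.0][..]);
/// });
/// ```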
pub fn custom_op<'g, F: Float>(
op: Arc<dyn CustomGradientOp<F>>,
inputs: &[Tensor<'g, F>],
ctx: &'g Context<'g, F>,
) -> Tensor<'g, F> {
let wrapper = CustomGradientWrapper { inner: op };
let mut builder = Tensor::builder(ctx);
for input in inputs {
builder = builder.append_input(input, false);
}
builder.build(wrapper)
}
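
/// Builds a single-input graph node from a pair of closures.
///
/// `forward_fn` runs eagerly over the input array during evaluation. `backward_fn`
/// is accepted for API completeness, but the current implementation does not invoke
/// it: the gradient of the resulting node is the incoming output gradient passed
/// through unchanged (see the note in `ClosureOp::grad` below). Use [`custom_op`]
/// with a [`CustomGradientOp`] when a genuinely custom backward pass is required.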
pub fn custom_unary_op<'g, F, FwdFn, BwdFn>(
name: &'static str,
forward_fn: FwdFn,
backward_fn: BwdFn,
input: Tensor<'g, F>,
ctx: &'g Context<'g, F>,
) -> Tensor<'g, F>
where
F: Float,
FwdFn: Fn(&ArrayViewD<F>) -> ArrayD<F> + Send + Sync + 'static,
BwdFn: Fn(&Tensor<'g, F>, &Tensor<'g, F>, &Tensor<'g, F>) -> Option<Tensor<'g, F>>
+ Send
+ Sync
+ 'static,
{
struct ClosureOp<F: Float, Fwd, Bwd> {
name: &'static str,
forward: Fwd,
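        // Kept for API symmetry with `forward`; not yet invoked during backprop
        // (see the note in `grad` below), so silence the dead-code lint.
        #[allow(dead_code)]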
backward: Bwd,
_phantom: std::marker::PhantomData<F>,
}
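    // SAFETY: `ClosureOp` stores only the two closures plus a `PhantomData<F>` that
    // holds no value of type `F`; thread safety therefore follows from the
    // `Send`/`Sync` bounds required of the closures.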
unsafe impl<F: Float, Fwd: Send, Bwd: Send> Send for ClosureOp<F, Fwd, Bwd> {}
unsafe impl<F: Float, Fwd: Sync, Bwd: Sync> Sync for ClosureOp<F, Fwd, Bwd> {}
impl<F: Float, Fwd, Bwd> op::Op<F> for ClosureOp<F, Fwd, Bwd>
where
Fwd: Fn(&ArrayViewD<F>) -> ArrayD<F> + Send + Sync,
Bwd: Send + Sync,
{
fn name(&self) -> &'static str {
self.name
}
fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
let input = ctx.input(0);
let output = (self.forward)(&input);
ctx.append_output(output);
Ok(())
}
fn grad<'a>(&self, ctx: &mut GradientContext<'a, 'a, F>) {
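            // NOTE: the user-supplied `backward_fn` is not called here; its tensor
            // lifetimes are tied to the caller's graph rather than to this
            // `GradientContext`, so the gradient is passed through unchanged.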
let gy = ctx.output_grad();
ctx.append_input_grad(0, Some(*gy));
}
}
let op = ClosureOp {
name,
forward: forward_fn,
backward: backward_fn,
_phantom: std::marker::PhantomData,
};
Tensor::builder(ctx).append_input(input, false).build(op)
}
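
/// Identity forward pass whose backward pass zeroes the gradient of selected elements.
///
/// `mask[i] == true` lets the gradient of element `i` flow; `false` blocks it.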
pub struct SelectiveStopGradient {
mask: Vec<bool>,
}
impl SelectiveStopGradient {
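    /// Creates the op from an explicit per-element mask (`true` means the gradient flows).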
pub fn new(mask: Vec<bool>) -> Self {
Self { mask }
}
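
    /// Builds a mask that blocks gradients at `blocked_indices` and lets every other
    /// element through. Out-of-range indices are ignored.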
pub fn block_indices(size: usize, blocked_indices: &[usize]) -> Self {
let mut mask = vec![true; size];
for &idx in blocked_indices {
if idx < size {
mask[idx] = false;
}
}
Self { mask }
}
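
    /// Builds a mask that lets gradients flow only at `allowed_indices` and blocks the
    /// rest. Out-of-range indices are ignored.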
pub fn allow_indices(size: usize, allowed_indices: &[usize]) -> Self {
let mut mask = vec![false; size];
for &idx in allowed_indices {
if idx < size {
mask[idx] = true;
}
}
Self { mask }
}
}
impl<F: Float> op::Op<F> for SelectiveStopGradient {
fn name(&self) -> &'static str {
"SelectiveStopGradient"
}
fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
let input = ctx.input(0);
ctx.append_output(input.to_owned());
Ok(())
}
fn grad(&self, ctx: &mut GradientContext<'_, '_, F>) {
let gy = ctx.output_grad();
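        // Turn the boolean mask into a 0/1 tensor and scale the incoming gradient
        // element-wise; positions masked `false` receive a zero gradient.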
let mask_vals: Vec<F> = self
.mask
.iter()
.map(|&m| if m { F::one() } else { F::zero() })
.collect();
let mask_arr = scirs2_core::ndarray::Array1::from(mask_vals).into_dyn();
let mask_tensor = crate::tensor_ops::convert_to_tensor(mask_arr, ctx.graph());
let masked_grad = *gy * mask_tensor;
ctx.append_input_grad(0, Some(masked_grad));
}
}
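
/// Applies [`SelectiveStopGradient`] to `input`: the value is unchanged, but gradients
/// are zeroed wherever `mask` is `false`.
///
/// A sketch based on `test_selective_stop_gradient` below:
///
/// ```ignore
/// let y = selective_stop_gradient(x, vec![true, false, true, false], ctx);
/// // d(sum(y))/dx is 1 at indices 0 and 2, and 0 at indices 1 and 3.
/// ```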
pub fn selective_stop_gradient<'g, F: Float>(
input: Tensor<'g, F>,
mask: Vec<bool>,
ctx: &'g Context<'g, F>,
) -> Tensor<'g, F> {
let op = SelectiveStopGradient::new(mask);
Tensor::builder(ctx).append_input(input, false).build(op)
}
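
/// Identity forward pass that multiplies the backward gradient by a constant factor.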
pub struct ScaleGradient<F: Float> {
scale: F,
}
impl<F: Float> ScaleGradient<F> {
pub fn new(scale: F) -> Self {
Self { scale }
}
}
impl<F: Float> op::Op<F> for ScaleGradient<F> {
fn name(&self) -> &'static str {
"ScaleGradient"
}
fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
let input = ctx.input(0);
ctx.append_output(input.to_owned());
Ok(())
}
fn grad(&self, ctx: &mut GradientContext<'_, '_, F>) {
let gy = ctx.output_grad();
let scaled = *gy * self.scale;
ctx.append_input_grad(0, Some(scaled));
}
}
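
/// Wraps `input` so that its value is unchanged while gradients flowing back through it
/// are multiplied by `scale`.
///
/// A sketch based on `test_scale_gradient` below:
///
/// ```ignore
/// let y = scale_gradient(x, 0.5, ctx);
/// // d(sum(y))/dx is 0.5 for every element instead of 1.0.
/// ```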
pub fn scale_gradient<'g, F: Float>(
input: Tensor<'g, F>,
scale: F,
ctx: &'g Context<'g, F>,
) -> Tensor<'g, F> {
let op = ScaleGradient::new(scale);
Tensor::builder(ctx).append_input(input, false).build(op)
}
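
/// Gradient reversal layer: identity in the forward pass, gradient multiplied by `-1` in
/// the backward pass (as used in domain-adversarial training). Implemented as
/// `scale_gradient(input, -1, ctx)`.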
pub fn gradient_reversal<'g, F: Float>(
input: Tensor<'g, F>,
ctx: &'g Context<'g, F>,
) -> Tensor<'g, F> {
let neg_one = F::from(-1.0).unwrap_or_else(|| F::zero() - F::one());
scale_gradient(input, neg_one, ctx)
}
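
/// Identity forward pass that blocks gradient flow entirely (`grad` reports `None`).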
struct DetachOp;
impl<F: Float> op::Op<F> for DetachOp {
fn name(&self) -> &'static str {
"Detach"
}
fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
let input = ctx.input(0);
ctx.append_output(input.to_owned());
Ok(())
}
fn grad(&self, ctx: &mut GradientContext<'_, '_, F>) {
ctx.append_input_grad(0, None);
}
}
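
/// Returns a tensor with the same value as `input` but detached from the gradient graph:
/// no gradient flows back through the returned node.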
pub fn detach<'g, F: Float>(input: Tensor<'g, F>, ctx: &'g Context<'g, F>) -> Tensor<'g, F> {
Tensor::builder(ctx)
.append_input(input, false)
.build(DetachOp)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor_ops;
use crate::tensor_ops::*;
use std::sync::Arc;
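
    /// Test op: identity in the forward pass, but `backward` reports twice the
    /// incoming gradient.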
struct DoubledGradOp;
impl CustomGradientOp<f64> for DoubledGradOp {
fn forward(&self, inputs: &[ArrayViewD<f64>]) -> Result<ArrayD<f64>, OpError> {
Ok(inputs[0].to_owned())
}
fn backward<'g>(
&self,
output_grad: &Tensor<'g, f64>,
_saved: &[Tensor<'g, f64>],
_ctx: &'g Context<'g, f64>,
) -> Vec<Option<Tensor<'g, f64>>> {
vec![Some(*output_grad * 2.0)]
}
fn num_inputs(&self) -> usize {
1
}
fn name(&self) -> &'static str {
"DoubledGrad"
}
}
#[test]
fn test_custom_op_forward() {
crate::run(|ctx: &mut Context<f64>| {
let x = convert_to_tensor(scirs2_core::ndarray::arr1(&[1.0, 2.0, 3.0]).into_dyn(), ctx);
let op = Arc::new(DoubledGradOp);
let y = custom_op(op, &[x], ctx);
let result = y.eval(ctx);
match result {
Ok(arr) => {
let vals = arr.as_slice().unwrap_or(&[]);
assert!((vals[0] - 1.0).abs() < 1e-10);
assert!((vals[1] - 2.0).abs() < 1e-10);
assert!((vals[2] - 3.0).abs() < 1e-10);
}
Err(e) => panic!("Forward eval failed: {e:?}"),
}
});
}
#[test]
fn test_custom_op_backward_doubled() {
crate::run(|ctx: &mut Context<f64>| {
let x = ctx.placeholder("x", &[3]);
let op = Arc::new(DoubledGradOp);
let y = custom_op(op, &[x], ctx);
let loss = crate::tensor_ops::reduction::sum_all(y);
let grads = crate::tensor_ops::grad(&[loss], &[x]);
let x_val = scirs2_core::ndarray::arr1(&[1.0, 2.0, 3.0]);
let result = ctx
.evaluator()
.push(&grads[0])
.feed(x, x_val.view().into_dyn())
.run();
let grad_arr = result[0].as_ref().expect("Should evaluate gradient");
let grad_vals = grad_arr.as_slice().unwrap_or(&[]);
for val in grad_vals {
assert!(val.is_finite(), "Gradient should be finite");
assert!(*val > 0.0, "Gradient should be positive");
}
});
}
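
    /// Test op implementing a straight-through estimator: `round` in the forward pass,
    /// gradient passed through unchanged in the backward pass.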
struct StraightThroughEstimator;
impl CustomGradientOp<f64> for StraightThroughEstimator {
fn forward(&self, inputs: &[ArrayViewD<f64>]) -> Result<ArrayD<f64>, OpError> {
Ok(inputs[0].mapv(|v| v.round()))
}
fn backward<'g>(
&self,
output_grad: &Tensor<'g, f64>,
_saved: &[Tensor<'g, f64>],
_ctx: &'g Context<'g, f64>,
) -> Vec<Option<Tensor<'g, f64>>> {
vec![Some(*output_grad)]
}
fn num_inputs(&self) -> usize {
1
}
fn name(&self) -> &'static str {
"STE"
}
}
#[test]
fn test_straight_through_estimator() {
crate::run(|ctx: &mut Context<f64>| {
let x = ctx.placeholder("x", &[4]);
let op = Arc::new(StraightThroughEstimator);
let y = custom_op(op, &[x], ctx);
let x_val = scirs2_core::ndarray::arr1(&[0.3, 1.7, -0.5, 2.9]);
let fwd_result = ctx
.evaluator()
.push(&y)
.feed(x, x_val.view().into_dyn())
.run();
let fwd_arr = fwd_result[0].as_ref().expect("Forward should work");
let fwd_vals = fwd_arr.as_slice().unwrap_or(&[]);
assert!((fwd_vals[0] - 0.0).abs() < 1e-10);
assert!((fwd_vals[1] - 2.0).abs() < 1e-10);
let loss = crate::tensor_ops::reduction::sum_all(y);
let grads = crate::tensor_ops::grad(&[loss], &[x]);
let grad_result = ctx
.evaluator()
.push(&grads[0])
.feed(x, x_val.view().into_dyn())
.run();
let grad_arr = grad_result[0].as_ref().expect("Gradient should work");
let grad_vals = grad_arr.as_slice().unwrap_or(&[]);
for val in grad_vals {
assert!(val.is_finite(), "STE gradient should be finite");
}
});
}
#[test]
fn test_selective_stop_gradient() {
crate::run(|ctx: &mut Context<f64>| {
let x = ctx.placeholder("x", &[4]);
let mask = vec![true, false, true, false];
let y = selective_stop_gradient(x, mask, ctx);
let loss = crate::tensor_ops::reduction::sum_all(y);
let grads = crate::tensor_ops::grad(&[loss], &[x]);
let x_val = scirs2_core::ndarray::arr1(&[1.0, 2.0, 3.0, 4.0]);
let result = ctx
.evaluator()
.push(&grads[0])
.feed(x, x_val.view().into_dyn())
.run();
let grad_arr = result[0].as_ref().expect("Should evaluate");
let grad_vals = grad_arr.as_slice().unwrap_or(&[]);
for val in grad_vals {
assert!(val.is_finite(), "Gradient should be finite");
}
assert!(grad_vals.len() == 4, "Should have 4 gradient elements");
});
}
#[test]
fn test_scale_gradient() {
crate::run(|ctx: &mut Context<f64>| {
let x = ctx.placeholder("x", &[3]);
let y = scale_gradient(x, 0.5, ctx);
let loss = crate::tensor_ops::reduction::sum_all(y);
let grads = crate::tensor_ops::grad(&[loss], &[x]);
let x_val = scirs2_core::ndarray::arr1(&[1.0, 2.0, 3.0]);
let result = ctx
.evaluator()
.push(&grads[0])
.feed(x, x_val.view().into_dyn())
.run();
let grad_arr = result[0].as_ref().expect("Should evaluate");
let grad_vals = grad_arr.as_slice().unwrap_or(&[]);
for val in grad_vals {
assert!(val.is_finite(), "Gradient should be finite");
}
});
}
#[test]
fn test_gradient_reversal() {
crate::run(|ctx: &mut Context<f64>| {
let x = ctx.placeholder("x", &[2]);
let y = gradient_reversal(x, ctx);
let loss = crate::tensor_ops::reduction::sum_all(y);
let grads = crate::tensor_ops::grad(&[loss], &[x]);
let x_val = scirs2_core::ndarray::arr1(&[1.0, 2.0]);
let result = ctx
.evaluator()
.push(&grads[0])
.feed(x, x_val.view().into_dyn())
.run();
let grad_arr = result[0].as_ref().expect("Should evaluate");
let grad_vals = grad_arr.as_slice().unwrap_or(&[]);
for val in grad_vals {
assert!(val.is_finite(), "Gradient should be finite");
}
let sum: f64 = grad_vals.iter().copied().sum();
assert!(sum.abs() > 1e-15, "Gradient sum should be nonzero");
});
}
#[test]
fn test_detach() {
crate::run(|ctx: &mut Context<f64>| {
let x = ctx.placeholder("x", &[3]);
let y = x * 2.0;
let z = super::detach(y, ctx);
let loss = crate::tensor_ops::reduction::sum_all(z + x);
let grads = crate::tensor_ops::grad(&[loss], &[x]);
let x_val = scirs2_core::ndarray::arr1(&[1.0, 2.0, 3.0]);
let result = ctx
.evaluator()
.push(&grads[0])
.feed(x, x_val.view().into_dyn())
.run();
let grad_arr = result[0].as_ref().expect("Should evaluate");
let grad_vals = grad_arr.as_slice().unwrap_or(&[]);
for val in grad_vals {
assert!(val.is_finite(), "Gradient should be finite");
}
});
}
#[test]
fn test_block_indices() {
let ssg = SelectiveStopGradient::block_indices(5, &[1, 3]);
assert!(ssg.mask[0]);
assert!(!ssg.mask[1]);
assert!(ssg.mask[2]);
assert!(!ssg.mask[3]);
assert!(ssg.mask[4]);
}
#[test]
fn test_allow_indices() {
let ssg = SelectiveStopGradient::allow_indices(5, &[0, 2, 4]);
assert!(ssg.mask[0]);
assert!(!ssg.mask[1]);
assert!(ssg.mask[2]);
assert!(!ssg.mask[3]);
assert!(ssg.mask[4]);
}
#[test]
fn test_custom_op_name() {
let op = DoubledGradOp;
assert_eq!(op.name(), "DoubledGrad");
assert!(op.saves_inputs());
assert!(op.saves_output());
assert_eq!(op.num_inputs(), 1);
}
}