use crate::error::AutogradError;
use crate::tensor::Tensor;
use crate::{Context, Float, Result};
/// Builds a gradient function from a graph-building function `f`.
///
/// The returned closure takes an input tensor and a graph `Context`,
/// constructs `y = f(x)` symbolically, and returns the gradient node
/// `dy/dx` in the same graph. The `Context` parameter is accepted for
/// API symmetry with other transforms but is not consulted here.
pub fn grad_fn<'g, F, Func>(
    f: Func,
) -> impl Fn(&Tensor<'g, F>, &'g Context<'g, F>) -> Result<Tensor<'g, F>>
where
    F: Float,
    Func: Fn(&Tensor<'g, F>) -> Tensor<'g, F> + 'static,
{
    move |input: &Tensor<'g, F>, _ctx: &'g Context<'g, F>| {
        // Build the forward graph, then request the gradient of the
        // (single) output with respect to the (single) input.
        let output = f(input);
        let grads = crate::tensor_ops::grad(&[output], &[*input]);
        Ok(grads[0])
    }
}
/// Builds a function returning both the value and the gradient of `f`.
///
/// The returned closure constructs `y = f(x)` and the gradient node
/// `dy/dx` in the same graph, yielding them as a `(value, gradient)`
/// pair. As with [`grad_fn`], the `Context` argument is unused; it is
/// kept so all transforms share one call shape.
pub fn value_and_grad_fn<'g, F, Func>(
    f: Func,
) -> impl Fn(&Tensor<'g, F>, &'g Context<'g, F>) -> Result<(Tensor<'g, F>, Tensor<'g, F>)>
where
    F: Float,
    Func: Fn(&Tensor<'g, F>) -> Tensor<'g, F> + 'static,
{
    move |input: &Tensor<'g, F>, _ctx: &'g Context<'g, F>| {
        let value = f(input);
        // Gradient of the scalar-valued output with respect to the input.
        let gradient = crate::tensor_ops::grad(&[value], &[*input])[0];
        Ok((value, gradient))
    }
}
/// Graph-level "vectorized map": applies `f` independently to each of
/// `batch_size` contiguous samples of length `sample_dim` taken from the
/// flattened input, and concatenates the flattened per-sample outputs
/// along axis 0.
///
/// # Errors
/// Returns an [`AutogradError::OperationError`] when `batch_size` or
/// `sample_dim` is zero.
pub fn vmap_graph<'g, F, Func>(
    f: Func,
    x: &Tensor<'g, F>,
    _ctx: &'g Context<'g, F>,
    batch_size: usize,
    sample_dim: usize,
) -> Result<Tensor<'g, F>>
where
    F: Float,
    Func: Fn(&Tensor<'g, F>) -> Tensor<'g, F>,
{
    // Reject degenerate batch layouts up front.
    if batch_size == 0 {
        return Err(AutogradError::OperationError(
            "vmap_graph: batch_size must be positive".to_string(),
        ));
    }
    if sample_dim == 0 {
        return Err(AutogradError::OperationError(
            "vmap_graph: sample_dim must be positive".to_string(),
        ));
    }
    // Work on a 1-D view so each sample is a contiguous slice.
    let flat = crate::tensor_ops::flatten(*x);
    let per_sample: Vec<_> = (0..batch_size)
        .map(|idx| {
            let begin = (idx * sample_dim) as isize;
            let stop = ((idx + 1) * sample_dim) as isize;
            let sample = crate::tensor_ops::slice(flat, [begin], [stop]);
            // Flatten each output so results concatenate cleanly on axis 0.
            crate::tensor_ops::flatten(f(&sample))
        })
        .collect();
    Ok(crate::tensor_ops::linear_algebra::concat(&per_sample, 0))
}
/// Marks a function for potential JIT-style optimization.
///
/// Currently a pure passthrough adapter: it wraps `f` in a closure that
/// takes its tensor argument by value and forwards it by reference. The
/// wrapped function behaves identically to `f`.
pub fn jit_hint<'g, F, Func>(f: Func) -> impl Fn(Tensor<'g, F>) -> Tensor<'g, F>
where
    F: Float,
    Func: Fn(&Tensor<'g, F>) -> Tensor<'g, F> + 'static,
{
    move |input: Tensor<'g, F>| {
        let result = f(&input);
        result
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor_ops::*;
    // f(x) = x^2, so f'(x) = 2x; at x = 5 the gradient should be 10.
    #[test]
    fn test_grad_fn_x_squared() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[]);
            let gf = grad_fn(|t: &Tensor<'_, f64>| *t * *t);
            let g = gf(&x, ctx).expect("grad_fn should succeed");
            let x_val = scirs2_core::ndarray::arr0(5.0f64);
            let out = ctx
                .evaluator()
                .push(&g)
                .feed(x, x_val.view().into_dyn())
                .run();
            let val = out[0]
                .as_ref()
                .expect("should eval")
                .first()
                .copied()
                .expect("first");
            assert!(
                (val - 10.0).abs() < 1e-9,
                "d(x^2)/dx at x=5 should be 10, got {}",
                val
            );
        });
    }
    // f(x) = x^3, so f'(x) = 3x^2; at x = 2 the gradient should be 12.
    #[test]
    fn test_grad_fn_cubic() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[]);
            let gf = grad_fn(|t: &Tensor<'_, f64>| *t * *t * *t);
            let g = gf(&x, ctx).expect("grad_fn cubic should succeed");
            let x_val = scirs2_core::ndarray::arr0(2.0f64);
            let out = ctx
                .evaluator()
                .push(&g)
                .feed(x, x_val.view().into_dyn())
                .run();
            let val = out[0]
                .as_ref()
                .expect("eval")
                .first()
                .copied()
                .expect("first");
            assert!(
                (val - 12.0).abs() < 1e-9,
                "d(x^3)/dx at x=2 should be 12, got {}",
                val
            );
        });
    }
    // f(x) = sum(x_i^2) over a length-3 vector; the gradient is 2x elementwise.
    #[test]
    fn test_grad_fn_multivar_sum_of_squares() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[3]);
            let gf = grad_fn(|t: &Tensor<'_, f64>| reduction::sum_all(*t * *t));
            let g = gf(&x, ctx).expect("grad_fn multivar should succeed");
            let x_val = scirs2_core::ndarray::arr1(&[1.0f64, 2.0, 3.0]);
            let out = ctx
                .evaluator()
                .push(&g)
                .feed(x, x_val.view().into_dyn())
                .run();
            let arr = out[0].as_ref().expect("eval");
            let s = arr.as_slice().expect("slice");
            // For x = [1, 2, 3], 2x = [2, 4, 6].
            assert!((s[0] - 2.0).abs() < 1e-9, "∇f[0] = 2, got {}", s[0]);
            assert!((s[1] - 4.0).abs() < 1e-9, "∇f[1] = 4, got {}", s[1]);
            assert!((s[2] - 6.0).abs() < 1e-9, "∇f[2] = 6, got {}", s[2]);
        });
    }
    // Same sum-of-squares gradient, but built via reduce_sum over an
    // explicit axis instead of sum_all.
    #[test]
    fn test_grad_fn_element_wise() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[2]);
            let gf = grad_fn(|t: &Tensor<'_, f64>| {
                let axes = [0_isize];
                crate::tensor_ops::reduce_sum(*t * *t, &axes, false)
            });
            let g = gf(&x, ctx).expect("grad_fn element-wise should succeed");
            let x_val = scirs2_core::ndarray::arr1(&[2.0f64, 3.0]);
            let out = ctx
                .evaluator()
                .push(&g)
                .feed(x, x_val.view().into_dyn())
                .run();
            let arr = out[0].as_ref().expect("should eval");
            let s = arr.as_slice().expect("slice");
            // Gradient of sum(x^2) is 2x: [4, 6] at x = [2, 3].
            assert!((s[0] - 4.0).abs() < 1e-9, "grad[0]=4, got {}", s[0]);
            assert!((s[1] - 6.0).abs() < 1e-9, "grad[1]=6, got {}", s[1]);
        });
    }
    // value_and_grad_fn must return a value tensor and gradient tensor that
    // agree with evaluating f and f' separately: f(2)=8, f'(2)=12 for x^3.
    #[test]
    fn test_value_and_grad_fn_consistency() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[]);
            let vg = value_and_grad_fn(|t: &Tensor<'_, f64>| *t * *t * *t);
            let (val_t, grad_t) = vg(&x, ctx).expect("value_and_grad should succeed");
            let x_val = scirs2_core::ndarray::arr0(2.0f64);
            let outs = ctx
                .evaluator()
                .push(&val_t)
                .push(&grad_t)
                .feed(x, x_val.view().into_dyn())
                .run();
            let v = outs[0]
                .as_ref()
                .expect("val eval")
                .first()
                .copied()
                .expect("v");
            let g = outs[1]
                .as_ref()
                .expect("grad eval")
                .first()
                .copied()
                .expect("g");
            assert!((v - 8.0).abs() < 1e-9, "f(2)=8, got {}", v);
            assert!((g - 12.0).abs() < 1e-9, "f'(2)=12, got {}", g);
        });
    }
    // f(x) = 0.5 * sum(x_i^2): value is 0.5*(1+4+9)=7 and gradient is x itself.
    #[test]
    fn test_value_and_grad_fn_quadratic() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[3]);
            let vg = value_and_grad_fn(|t: &Tensor<'_, f64>| {
                reduction::sum_all(*t * *t) * 0.5_f64
            });
            let (val_t, grad_t) = vg(&x, ctx).expect("value_and_grad quad should succeed");
            let x_val = scirs2_core::ndarray::arr1(&[1.0f64, 2.0, 3.0]);
            let outs = ctx
                .evaluator()
                .push(&val_t)
                .push(&grad_t)
                .feed(x, x_val.view().into_dyn())
                .run();
            let v = outs[0]
                .as_ref()
                .expect("val eval")
                .first()
                .copied()
                .expect("v");
            assert!((v - 7.0).abs() < 1e-9, "f([1,2,3])=7, got {}", v);
            let g_arr = outs[1].as_ref().expect("grad eval");
            let g = g_arr.as_slice().expect("slice");
            // Gradient of 0.5*sum(x^2) is x: [1, 2, 3].
            assert!((g[0] - 1.0).abs() < 1e-9, "∇f[0]=1, got {}", g[0]);
            assert!((g[1] - 2.0).abs() < 1e-9, "∇f[1]=2, got {}", g[1]);
            assert!((g[2] - 3.0).abs() < 1e-9, "∇f[2]=3, got {}", g[2]);
        });
    }
    // vmap_graph over a 2x2 input: each row [a, b] maps to a^2 + b^2, so
    // [1,2] -> 5 and [3,4] -> 25.
    #[test]
    fn test_vmap_graph_squared() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[2, 2]);
            let out = vmap_graph(
                |s| {
                    // Split the 2-element sample into its two scalars.
                    let s0 = slice(*s, [0isize], [1isize]);
                    let s1 = slice(*s, [1isize], [2isize]);
                    reduction::sum_all(s0 * s0 + s1 * s1)
                },
                &x,
                ctx,
                2,
                2,
            )
            .expect("vmap_graph should succeed");
            let x_val =
                scirs2_core::ndarray::Array2::from_shape_vec((2, 2), vec![1.0f64, 2.0, 3.0, 4.0])
                    .expect("shape ok")
                    .into_dyn();
            let outs = ctx
                .evaluator()
                .push(&out)
                .feed(x, x_val.view())
                .run();
            let arr = outs[0].as_ref().expect("eval");
            let s = arr.as_slice().expect("slice");
            assert!((s[0] - 5.0).abs() < 1e-6, "batch[0] expected 5, got {}", s[0]);
            assert!((s[1] - 25.0).abs() < 1e-6, "batch[1] expected 25, got {}", s[1]);
        });
    }
    // batch_size == 0 must be rejected with an error rather than silently
    // producing an empty result.
    #[test]
    fn test_vmap_graph_empty_batch_error() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[2, 2]);
            let result = vmap_graph(|s| *s, &x, ctx, 0, 2);
            assert!(result.is_err(), "empty batch should return error");
        });
    }
    // jit_hint is a passthrough: the wrapped graph must evaluate to the
    // same values as building the expression directly.
    #[test]
    fn test_jit_hint_passthrough() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[3]);
            let jitted = jit_hint(|t: &Tensor<'_, f64>| *t * *t);
            let y = jitted(x);
            let y_direct = x * x;
            let x_val = scirs2_core::ndarray::arr1(&[1.0f64, 2.0, 3.0]);
            let outs = ctx
                .evaluator()
                .push(&y)
                .push(&y_direct)
                .feed(x, x_val.view().into_dyn())
                .run();
            let jit_s = outs[0].as_ref().expect("jit eval").as_slice().expect("s1");
            let dir_s = outs[1].as_ref().expect("direct eval").as_slice().expect("s2");
            for (a, b) in jit_s.iter().zip(dir_s.iter()) {
                assert!((a - b).abs() < 1e-12, "jit_hint should be identity");
            }
        });
    }
    // Sweep several evaluation points (including 0 and negatives) and check
    // the canonical identity d(x^2)/dx = 2x at each one.
    #[test]
    fn test_canonical_grad_x_sq_eq_2x() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[]);
            let gf = grad_fn(|t: &Tensor<'_, f64>| *t * *t);
            for &xv in &[0.0f64, 1.0, 2.0, 3.0, -1.0, -2.5] {
                let g = gf(&x, ctx).expect("grad should succeed");
                let x_val = scirs2_core::ndarray::arr0(xv);
                let out = ctx
                    .evaluator()
                    .push(&g)
                    .feed(x, x_val.view().into_dyn())
                    .run();
                let val = out[0]
                    .as_ref()
                    .expect("eval")
                    .first()
                    .copied()
                    .expect("first");
                let expected = 2.0 * xv;
                assert!(
                    (val - expected).abs() < 1e-9,
                    "d(x^2)/dx at {} = {}, got {}",
                    xv,
                    expected,
                    val
                );
            }
        });
    }
}