use crate::error::AutogradError;
use crate::forward_mode::DualNumber;
use num::Float as NumFloat;
use scirs2_core::ndarray::{Array1, Array2, Axis};
use std::fmt;
use std::sync::Arc;
/// A shared, thread-safe transformation over a 2-D array — one pipeline stage
/// stored by `Checkpoint`. `Arc` allows cheap cloning of the stored closure.
type SegmentFn<F> = Arc<dyn Fn(&Array2<F>) -> Array2<F> + Send + Sync>;
/// Vectorizing map: applies `func` independently to every row of `inputs` and
/// stacks the per-row outputs into a new 2-D array.
///
/// Errors if the batch is empty or if the per-row outputs do not all have the
/// same length (the output dimension is taken from the first row).
pub fn vmap<F, Func>(func: Func, inputs: &Array2<F>) -> Result<Array2<F>, AutogradError>
where
    F: NumFloat + Copy + Send + Sync + fmt::Debug + 'static,
    Func: Fn(&Array1<F>) -> Array1<F> + Send + Sync,
{
    let n_rows = inputs.nrows();
    if n_rows == 0 {
        return Err(AutogradError::OperationError(
            "vmap: input batch is empty".to_string(),
        ));
    }
    // Evaluate every row first so we can validate a uniform output dimension
    // before allocating the stacked result.
    let mapped: Vec<Array1<F>> = (0..n_rows)
        .map(|i| func(&inputs.row(i).to_owned()))
        .collect();
    let expected = mapped[0].len();
    if let Some((idx, bad)) = mapped.iter().enumerate().find(|(_, r)| r.len() != expected) {
        return Err(AutogradError::ShapeMismatch(format!(
            "vmap: output dimension mismatch at row {}: expected {}, got {}",
            idx,
            expected,
            bad.len()
        )));
    }
    let mut stacked = Array2::<F>::zeros((n_rows, expected));
    for (i, row_out) in mapped.iter().enumerate() {
        for (j, &value) in row_out.iter().enumerate() {
            stacked[[i, j]] = value;
        }
    }
    Ok(stacked)
}
/// Parallel map: applies `func` to each row of `inputs` on its own scoped
/// thread and stacks the results.
///
/// Errors if the batch is empty, if any worker thread panics, or if the
/// per-row outputs do not all share the same length.
pub fn pmap<F, Func>(func: Func, inputs: &Array2<F>) -> Result<Array2<F>, AutogradError>
where
    F: NumFloat + Copy + Send + Sync + fmt::Debug + 'static,
    Func: Fn(&Array1<F>) -> Array1<F> + Send + Sync,
{
    let batch_size = inputs.nrows();
    if batch_size == 0 {
        return Err(AutogradError::OperationError(
            "pmap: input batch is empty".to_string(),
        ));
    }
    let rows: Vec<Array1<F>> = (0..batch_size).map(|i| inputs.row(i).to_owned()).collect();
    // Capture each thread's outcome as a Result so a panicking worker is
    // reported explicitly. The previous sentinel scheme (substituting an empty
    // array on panic and later checking `r.is_empty() && ncols > 0`) wrongly
    // flagged legitimately-empty outputs as failures and missed panics when
    // the input had zero columns.
    let joined: Vec<Result<Array1<F>, ()>> = std::thread::scope(|scope| {
        let handles: Vec<_> = rows
            .iter()
            .map(|row| {
                let func_ref = &func;
                scope.spawn(move || func_ref(row))
            })
            .collect();
        handles
            .into_iter()
            .map(|h| h.join().map_err(|_| ()))
            .collect()
    });
    let mut results: Vec<Array1<F>> = Vec::with_capacity(batch_size);
    for (i, outcome) in joined.into_iter().enumerate() {
        match outcome {
            Ok(r) => results.push(r),
            Err(()) => {
                return Err(AutogradError::OperationError(format!(
                    "pmap: thread for row {} failed",
                    i,
                )))
            }
        }
    }
    // Safe: batch_size > 0 was checked above.
    let out_dim = results[0].len();
    for (i, r) in results.iter().enumerate() {
        if r.len() != out_dim {
            return Err(AutogradError::ShapeMismatch(format!(
                "pmap: output dimension mismatch at row {}: expected {}, got {}",
                i,
                out_dim,
                r.len()
            )));
        }
    }
    let mut output = Array2::<F>::zeros((batch_size, out_dim));
    for (i, r) in results.iter().enumerate() {
        for j in 0..out_dim {
            output[[i, j]] = r[j];
        }
    }
    Ok(output)
}
/// Returns a closure that computes the gradient of `func` at a point using
/// forward-mode dual numbers — one seeded evaluation per input coordinate.
pub fn grad<F, Func>(func: Func) -> impl Fn(&Array1<F>) -> Array1<F>
where
    F: NumFloat + Copy + fmt::Debug + Send + Sync + 'static,
    Func: Fn(&[DualNumber<F>]) -> DualNumber<F> + Clone + 'static,
{
    move |x: &Array1<F>| {
        let n = x.len();
        let mut gradient = Array1::<F>::zeros(n);
        for i in 0..n {
            // Seed the i-th tangent with 1 so the dual evaluation carries
            // the partial derivative with respect to x[i].
            let duals: Vec<DualNumber<F>> = x
                .iter()
                .enumerate()
                .map(|(k, &xk)| {
                    let tangent = if k == i { F::one() } else { F::zero() };
                    DualNumber::new(xk, tangent)
                })
                .collect();
            // `func` is a `Fn` and only needs to be borrowed here; the
            // previous per-iteration `func.clone()` was unnecessary.
            gradient[i] = func(&duals).tangent();
        }
        gradient
    }
}
/// Returns a closure that computes the Hessian (second-derivative matrix) of
/// `func` at a point by delegating to `crate::forward_mode::hessian`.
///
/// `func` is cloned on every call — presumably `hessian` takes it by value
/// while the returned closure must remain reusable (TODO confirm against
/// `forward_mode`).
pub fn grad_grad<F, Func>(func: Func) -> impl Fn(&Array1<F>) -> Array2<F>
where
    F: NumFloat + Copy + fmt::Debug + Send + Sync + 'static,
    Func: Fn(&[DualNumber<F>]) -> DualNumber<F> + Clone + 'static,
{
    move |x: &Array1<F>| crate::forward_mode::hessian(func.clone(), x)
}
/// Returns a closure that computes both the value and the gradient of `func`
/// at a point using forward-mode dual numbers.
///
/// The value comes from one extra pass with all-zero tangents, which keeps it
/// well-defined even for zero-length inputs; each gradient component then
/// takes one seeded pass.
pub fn value_and_grad<F, Func>(func: Func) -> impl Fn(&Array1<F>) -> (F, Array1<F>)
where
    F: NumFloat + Copy + fmt::Debug + Send + Sync + 'static,
    Func: Fn(&[DualNumber<F>]) -> DualNumber<F> + Clone + 'static,
{
    move |x: &Array1<F>| {
        let n = x.len();
        let mut gradient = Array1::<F>::zeros(n);
        let primal_duals: Vec<DualNumber<F>> =
            x.iter().map(|&xk| DualNumber::constant(xk)).collect();
        // `func` is a `Fn` and only needs to be borrowed; the previous
        // per-call `func.clone()` was unnecessary.
        let value = func(&primal_duals).value();
        for i in 0..n {
            // Seed the i-th tangent with 1 to extract the i-th partial.
            let duals: Vec<DualNumber<F>> = x
                .iter()
                .enumerate()
                .map(|(k, &xk)| {
                    let tangent = if k == i { F::one() } else { F::zero() };
                    DualNumber::new(xk, tangent)
                })
                .collect();
            gradient[i] = func(&duals).tangent();
        }
        (value, gradient)
    }
}
/// Returns a closure that computes the Jacobian of a vector-valued `func` at a
/// point, delegating to `crate::forward_mode::jacobian_forward`.
///
/// `func` is cloned per call so the returned closure stays reusable
/// (presumably `jacobian_forward` consumes it — see `forward_mode`).
pub fn jacobian<F, Func>(func: Func) -> impl Fn(&Array1<F>) -> Array2<F>
where
    F: NumFloat + Copy + fmt::Debug + Send + Sync + 'static,
    Func: Fn(&[DualNumber<F>]) -> Vec<DualNumber<F>> + Clone + 'static,
{
    move |x: &Array1<F>| crate::forward_mode::jacobian_forward(func.clone(), x)
}
/// Identity on values: returns an owned copy of `tensor`. In a reverse-mode
/// graph this marks where gradient flow would be cut; with the forward-mode
/// machinery here it is simply a copy.
pub fn stop_gradient<F: NumFloat + Copy>(tensor: &Array2<F>) -> Array2<F> {
    tensor.to_owned()
}
/// 1-D counterpart of `stop_gradient`: returns an owned copy of `tensor`.
pub fn stop_gradient_1d<F: NumFloat + Copy>(tensor: &Array1<F>) -> Array1<F> {
    tensor.to_owned()
}
/// Detaches dual numbers from differentiation: keeps each primal value but
/// resets every tangent to zero (via `DualNumber::constant`).
pub fn stop_gradient_dual<F: NumFloat + Copy + fmt::Debug>(
    duals: &[DualNumber<F>],
) -> Vec<DualNumber<F>> {
    let mut detached = Vec::with_capacity(duals.len());
    for d in duals {
        detached.push(DualNumber::constant(d.value()));
    }
    detached
}
/// Gradient-checkpointing pipeline: an ordered sequence of array-to-array
/// segments whose intermediate activations can be recomputed on demand
/// instead of stored.
pub struct Checkpoint<F: NumFloat + Copy> {
    // Pipeline stages, applied in insertion order by `forward`.
    segments: Vec<SegmentFn<F>>,
}
impl<F: NumFloat + Copy + fmt::Debug> Checkpoint<F> {
pub fn new() -> Self {
Self {
segments: Vec::new(),
}
}
pub fn add_segment(
&mut self,
f: impl Fn(&Array2<F>) -> Array2<F> + Send + Sync + 'static,
) -> &mut Self {
self.segments.push(Arc::new(f));
self
}
pub fn forward(&self, input: &Array2<F>) -> Array2<F> {
let mut current = input.clone();
for seg in &self.segments {
current = seg(¤t);
}
current
}
pub fn forward_with_checkpoints(&self, input: &Array2<F>) -> (Array2<F>, Vec<Array2<F>>) {
let mut checkpoints = Vec::with_capacity(self.segments.len() + 1);
let mut current = input.clone();
for seg in &self.segments {
checkpoints.push(current.clone());
current = seg(¤t);
}
(current, checkpoints)
}
pub fn recompute_segment(
&self,
segment_idx: usize,
checkpoint: &Array2<F>,
) -> Option<Array2<F>> {
self.segments.get(segment_idx).map(|seg| seg(checkpoint))
}
pub fn memory_savings_ratio(&self) -> f64 {
let n = self.segments.len();
if n <= 1 {
return 1.0;
}
(n as f64).sqrt() / n as f64
}
pub fn num_segments(&self) -> usize {
self.segments.len()
}
}
// Default pipeline is empty, identical to `Checkpoint::new()`.
impl<F: NumFloat + Copy + fmt::Debug> Default for Checkpoint<F> {
    fn default() -> Self {
        Self::new()
    }
}
/// Optimization hints for a (hypothetical) JIT compilation pass, built with a
/// consuming builder API. All flags default to off; the default maximum
/// fusion depth is 8.
#[derive(Debug, Clone)]
pub struct JitHint {
    constant_folding: bool,
    fusion: bool,
    cse: bool,
    static_args: Vec<usize>,
    dead_code_elimination: bool,
    max_fusion_depth: usize,
}

impl JitHint {
    /// All optimizations disabled, no static arguments, fusion depth 8.
    pub fn new() -> Self {
        JitHint {
            constant_folding: false,
            fusion: false,
            cse: false,
            static_args: Vec::new(),
            dead_code_elimination: false,
            max_fusion_depth: 8,
        }
    }

    /// Every optimization enabled, with a deeper fusion limit of 16.
    pub fn all_optimizations() -> Self {
        Self::new()
            .enable_constant_folding(true)
            .enable_fusion(true)
            .enable_cse(true)
            .enable_dead_code_elimination(true)
            .set_max_fusion_depth(16)
    }

    // --- builder setters (consume and return self for chaining) ---

    /// Toggles constant folding.
    pub fn enable_constant_folding(mut self, enable: bool) -> Self {
        self.constant_folding = enable;
        self
    }

    /// Toggles operator fusion.
    pub fn enable_fusion(mut self, enable: bool) -> Self {
        self.fusion = enable;
        self
    }

    /// Toggles common-subexpression elimination.
    pub fn enable_cse(mut self, enable: bool) -> Self {
        self.cse = enable;
        self
    }

    /// Toggles dead-code elimination.
    pub fn enable_dead_code_elimination(mut self, enable: bool) -> Self {
        self.dead_code_elimination = enable;
        self
    }

    /// Caps how many operations may be fused together.
    pub fn set_max_fusion_depth(mut self, depth: usize) -> Self {
        self.max_fusion_depth = depth;
        self
    }

    /// Marks argument positions whose values are treated as compile-time
    /// constants (JAX-style `static_argnums`).
    pub fn with_static_argnums(mut self, indices: &[usize]) -> Self {
        self.static_args = indices.to_vec();
        self
    }

    // --- getters ---

    /// Whether constant folding is enabled.
    pub fn constant_folding(&self) -> bool {
        self.constant_folding
    }

    /// Whether operator fusion is enabled.
    pub fn fusion(&self) -> bool {
        self.fusion
    }

    /// Whether common-subexpression elimination is enabled.
    pub fn cse(&self) -> bool {
        self.cse
    }

    /// Whether dead-code elimination is enabled.
    pub fn dead_code_elimination(&self) -> bool {
        self.dead_code_elimination
    }

    /// Maximum fusion depth.
    pub fn max_fusion_depth(&self) -> usize {
        self.max_fusion_depth
    }

    /// Static argument positions.
    pub fn static_argnums(&self) -> &[usize] {
        &self.static_args
    }
}

// `Default` matches `new()`: everything off, fusion depth 8.
impl Default for JitHint {
    fn default() -> Self {
        Self::new()
    }
}
/// Evaluates value and gradient of `func` at every row of `inputs`.
///
/// Returns a vector of per-row values and a matrix whose `b`-th row is the
/// gradient at `inputs.row(b)`. Errors on an empty batch.
pub fn batched_value_and_grad<F, Func>(
    func: Func,
    inputs: &Array2<F>,
) -> Result<(Array1<F>, Array2<F>), AutogradError>
where
    F: NumFloat + Copy + fmt::Debug + Send + Sync + 'static,
    Func: Fn(&[DualNumber<F>]) -> DualNumber<F> + Clone + Send + Sync + 'static,
{
    let batch_size = inputs.nrows();
    let dim = inputs.ncols();
    if batch_size == 0 {
        return Err(AutogradError::OperationError(
            "batched_value_and_grad: empty batch".to_string(),
        ));
    }
    // Build the value-and-gradient closure once and reuse it per row.
    let evaluate = value_and_grad(func);
    let mut values = Array1::<F>::zeros(batch_size);
    let mut gradients = Array2::<F>::zeros((batch_size, dim));
    for b in 0..batch_size {
        let point = inputs.row(b).to_owned();
        let (value, grad_vec) = evaluate(&point);
        values[b] = value;
        for (j, &gj) in grad_vec.iter().enumerate() {
            gradients[[b, j]] = gj;
        }
    }
    Ok((values, gradients))
}
/// Threads `input` through `transforms` in order, returning the final state
/// together with the state after each step (like a prefix scan over the
/// pipeline).
pub fn scan<F, Func>(transforms: &[Func], input: &Array1<F>) -> (Array1<F>, Vec<Array1<F>>)
where
    F: NumFloat + Copy + fmt::Debug,
    Func: Fn(&Array1<F>) -> Array1<F>,
{
    let mut states = Vec::with_capacity(transforms.len());
    let mut state = input.clone();
    for step in transforms {
        state = step(&state);
        states.push(state.clone());
    }
    (state, states)
}
/// Compares the forward-mode gradient of `func_dual` against central finite
/// differences of `func_scalar` at `x`, returning the largest per-component
/// absolute error.
///
/// The two callables are expected to implement the same mathematical function.
pub fn check_grad<F, FuncDual, FuncScalar>(
    func_dual: FuncDual,
    func_scalar: FuncScalar,
    x: &Array1<F>,
    epsilon: F,
) -> F
where
    F: NumFloat + Copy + fmt::Debug,
    FuncDual: Fn(&[DualNumber<F>]) -> DualNumber<F> + Clone,
    FuncScalar: Fn(&Array1<F>) -> F,
{
    let double_eps = (F::one() + F::one()) * epsilon;
    let analytical = crate::forward_mode::gradient_forward(func_dual, x);
    (0..x.len())
        .map(|i| {
            // Central difference: (f(x + e_i*eps) - f(x - e_i*eps)) / (2*eps).
            let mut plus = x.clone();
            let mut minus = x.clone();
            plus[i] = x[i] + epsilon;
            minus[i] = x[i] - epsilon;
            let central = (func_scalar(&plus) - func_scalar(&minus)) / double_eps;
            (analytical[i] - central).abs()
        })
        .fold(F::zero(), |worst, err| worst.max(err))
}
/// Function composition: returns a closure computing `g(f(x))`
/// (apply `f` first, then `g`).
pub fn compose<F, G, A, B, C>(f: F, g: G) -> impl Fn(A) -> C
where
    F: Fn(A) -> B,
    G: Fn(B) -> C,
{
    move |input| {
        let intermediate = f(input);
        g(intermediate)
    }
}
/// Applies `func` to `x` repeatedly, `n` times, returning the final state
/// (`n == 0` yields a copy of `x`).
pub fn iterate<F, Func>(func: &Func, x: &Array1<F>, n: usize) -> Array1<F>
where
    F: NumFloat + Copy + fmt::Debug,
    Func: Fn(&Array1<F>) -> Array1<F>,
{
    (0..n).fold(x.clone(), |state, _| func(&state))
}
/// Approximates the Jacobian of `func` at `x` by central finite differences
/// with step `epsilon`; the result has shape (output_len, input_len).
pub fn numerical_jacobian<F, Func>(func: &Func, x: &Array1<F>, epsilon: F) -> Array2<F>
where
    F: NumFloat + Copy + fmt::Debug,
    Func: Fn(&Array1<F>) -> Array1<F>,
{
    let double_eps = (F::one() + F::one()) * epsilon;
    let n_inputs = x.len();
    // One probe evaluation just to learn the output dimension.
    let n_outputs = func(x).len();
    let mut jac = Array2::<F>::zeros((n_outputs, n_inputs));
    for col in 0..n_inputs {
        let mut plus = x.clone();
        let mut minus = x.clone();
        plus[col] = x[col] + epsilon;
        minus[col] = x[col] - epsilon;
        let y_plus = func(&plus);
        let y_minus = func(&minus);
        for row in 0..n_outputs {
            jac[[row, col]] = (y_plus[row] - y_minus[row]) / double_eps;
        }
    }
    jac
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    // --- vmap ---

    #[test]
    fn test_vmap_double() {
        let batch = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("valid shape");
        let result = vmap(|x: &Array1<f64>| x.mapv(|v| v * 2.0), &batch).expect("vmap succeeds");
        assert_eq!(result.shape(), &[3, 2]);
        assert!((result[[0, 0]] - 2.0).abs() < 1e-12);
        assert!((result[[1, 0]] - 6.0).abs() < 1e-12);
        assert!((result[[2, 1]] - 12.0).abs() < 1e-12);
    }

    #[test]
    fn test_vmap_nonlinear() {
        let batch = Array2::from_shape_vec((2, 3), vec![1.0, 4.0, 9.0, 16.0, 25.0, 36.0])
            .expect("valid shape");
        let result = vmap(|x: &Array1<f64>| x.mapv(|v| v.sqrt()), &batch).expect("vmap succeeds");
        assert!((result[[0, 0]] - 1.0).abs() < 1e-12);
        assert!((result[[0, 1]] - 2.0).abs() < 1e-12);
        assert!((result[[1, 2]] - 6.0).abs() < 1e-12);
    }

    // Empty batch is rejected with an error rather than returning an empty array.
    #[test]
    fn test_vmap_empty_batch() {
        let batch = Array2::<f64>::zeros((0, 3));
        let result = vmap(|x: &Array1<f64>| x.clone(), &batch);
        assert!(result.is_err());
    }

    #[test]
    fn test_vmap_single_element() {
        let batch = Array2::from_shape_vec((1, 4), vec![1.0, 2.0, 3.0, 4.0]).expect("valid shape");
        let result = vmap(|x: &Array1<f64>| x.mapv(|v| v + 10.0), &batch).expect("vmap succeeds");
        assert_eq!(result.shape(), &[1, 4]);
        assert!((result[[0, 0]] - 11.0).abs() < 1e-12);
        assert!((result[[0, 3]] - 14.0).abs() < 1e-12);
    }

    // The mapped function may change the output dimension (3 -> 2 here).
    #[test]
    fn test_vmap_dimension_change() {
        let batch = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("valid shape");
        let result = vmap(
            |x: &Array1<f64>| Array1::from(vec![x[0] + x[1], x[2]]),
            &batch,
        )
        .expect("vmap succeeds");
        assert_eq!(result.shape(), &[2, 2]);
        assert!((result[[0, 0]] - 3.0).abs() < 1e-12);
        assert!((result[[0, 1]] - 3.0).abs() < 1e-12);
        assert!((result[[1, 0]] - 9.0).abs() < 1e-12);
        assert!((result[[1, 1]] - 6.0).abs() < 1e-12);
    }

    // --- grad / grad_grad / value_and_grad ---

    #[test]
    fn test_grad_quadratic() {
        let grad_f = grad(|xs: &[DualNumber<f64>]| xs[0] * xs[0] + xs[1] * xs[1]);
        let x = Array1::from(vec![3.0, 4.0]);
        let g = grad_f(&x);
        assert!((g[0] - 6.0).abs() < 1e-12);
        assert!((g[1] - 8.0).abs() < 1e-12);
    }

    #[test]
    fn test_grad_linear() {
        let grad_f = grad(|xs: &[DualNumber<f64>]| {
            let three = DualNumber::constant(3.0);
            let seven = DualNumber::constant(7.0);
            three * xs[0] + seven * xs[1]
        });
        let x = Array1::from(vec![100.0, 200.0]);
        let g = grad_f(&x);
        assert!((g[0] - 3.0).abs() < 1e-12);
        assert!((g[1] - 7.0).abs() < 1e-12);
    }

    // d/dx0 [sin(x0) e^{x1}] = cos(0) e^0 = 1; d/dx1 = sin(0) e^0 = 0.
    #[test]
    fn test_grad_transcendental() {
        let grad_f = grad(|xs: &[DualNumber<f64>]| xs[0].sin() * xs[1].exp());
        let x = Array1::from(vec![0.0, 0.0]);
        let g = grad_f(&x);
        assert!((g[0] - 1.0).abs() < 1e-12);
        assert!(g[1].abs() < 1e-12);
    }

    #[test]
    fn test_grad_grad_quadratic() {
        let hessian_f = grad_grad(|xs: &[DualNumber<f64>]| {
            let two = DualNumber::constant(2.0);
            let three = DualNumber::constant(3.0);
            xs[0] * xs[0] + three * xs[0] * xs[1] + two * xs[1] * xs[1]
        });
        let x = Array1::from(vec![1.0, 1.0]);
        let h = hessian_f(&x);
        assert!((h[[0, 0]] - 2.0).abs() < 1e-4);
        assert!((h[[0, 1]] - 3.0).abs() < 1e-4);
        assert!((h[[1, 0]] - 3.0).abs() < 1e-4);
        assert!((h[[1, 1]] - 4.0).abs() < 1e-4);
    }

    #[test]
    fn test_grad_grad_diagonal() {
        let hessian_f = grad_grad(|xs: &[DualNumber<f64>]| {
            let five = DualNumber::constant(5.0);
            xs[0] * xs[0] + five * xs[1] * xs[1]
        });
        let x = Array1::from(vec![0.0, 0.0]);
        let h = hessian_f(&x);
        assert!((h[[0, 0]] - 2.0).abs() < 1e-4);
        assert!(h[[0, 1]].abs() < 1e-4);
        assert!(h[[1, 0]].abs() < 1e-4);
        assert!((h[[1, 1]] - 10.0).abs() < 1e-4);
    }

    #[test]
    fn test_value_and_grad_basic() {
        let vg = value_and_grad(|xs: &[DualNumber<f64>]| xs[0] * xs[0] + xs[1] * xs[1]);
        let x = Array1::from(vec![3.0, 4.0]);
        let (val, g) = vg(&x);
        assert!((val - 25.0).abs() < 1e-12);
        assert!((g[0] - 6.0).abs() < 1e-12);
        assert!((g[1] - 8.0).abs() < 1e-12);
    }

    // Rosenbrock's function has its global minimum (value 0, zero gradient) at (1, 1).
    #[test]
    fn test_value_and_grad_rosenbrock() {
        let vg = value_and_grad(|xs: &[DualNumber<f64>]| {
            let one = DualNumber::constant(1.0);
            let hundred = DualNumber::constant(100.0);
            let a = one - xs[0];
            let b = xs[1] - xs[0] * xs[0];
            a * a + hundred * b * b
        });
        let x = Array1::from(vec![1.0, 1.0]);
        let (val, g) = vg(&x);
        assert!(val.abs() < 1e-12);
        assert!(g[0].abs() < 1e-12);
        assert!(g[1].abs() < 1e-12);
    }

    // --- stop_gradient ---

    // The returned copy must be independent of the input.
    #[test]
    fn test_stop_gradient_2d() {
        let t = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("valid shape");
        let stopped = stop_gradient(&t);
        assert_eq!(t, stopped);
        let mut stopped_mut = stopped;
        stopped_mut[[0, 0]] = 999.0;
        assert!((t[[0, 0]] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_stop_gradient_1d() {
        let t = Array1::from(vec![1.0, 2.0, 3.0]);
        let stopped = stop_gradient_1d(&t);
        assert_eq!(t, stopped);
    }

    // Values are kept; tangents are zeroed.
    #[test]
    fn test_stop_gradient_dual() {
        let duals = vec![DualNumber::new(1.0_f64, 0.5), DualNumber::new(2.0, 0.7)];
        let stopped = stop_gradient_dual(&duals);
        assert!((stopped[0].value() - 1.0).abs() < 1e-12);
        assert!(stopped[0].tangent().abs() < 1e-12);
        assert!((stopped[1].value() - 2.0).abs() < 1e-12);
        assert!(stopped[1].tangent().abs() < 1e-12);
    }

    // --- Checkpoint ---

    #[test]
    fn test_checkpoint_forward() {
        let mut ckpt = Checkpoint::<f64>::new();
        ckpt.add_segment(|x: &Array2<f64>| x.mapv(|v| v * 2.0))
            .add_segment(|x: &Array2<f64>| x.mapv(|v| v + 1.0))
            .add_segment(|x: &Array2<f64>| x.mapv(|v| v * v));
        let input = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("valid shape");
        let output = ckpt.forward(&input);
        assert!((output[[0, 0]] - 9.0).abs() < 1e-12);
        assert!((output[[0, 1]] - 25.0).abs() < 1e-12);
        assert!((output[[1, 0]] - 49.0).abs() < 1e-12);
        assert!((output[[1, 1]] - 81.0).abs() < 1e-12);
    }

    // Checkpoints record the *input* of each segment, not its output.
    #[test]
    fn test_checkpoint_with_intermediates() {
        let mut ckpt = Checkpoint::<f64>::new();
        ckpt.add_segment(|x: &Array2<f64>| x.mapv(|v| v + 10.0))
            .add_segment(|x: &Array2<f64>| x.mapv(|v| v * 3.0));
        let input = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).expect("valid shape");
        let (output, checkpoints) = ckpt.forward_with_checkpoints(&input);
        assert_eq!(checkpoints.len(), 2);
        assert!((checkpoints[0][[0, 0]] - 1.0).abs() < 1e-12);
        assert!((checkpoints[1][[0, 0]] - 11.0).abs() < 1e-12);
        assert!((output[[0, 0]] - 33.0).abs() < 1e-12);
        assert!((output[[0, 1]] - 36.0).abs() < 1e-12);
    }

    #[test]
    fn test_checkpoint_recompute() {
        let mut ckpt = Checkpoint::<f64>::new();
        ckpt.add_segment(|x: &Array2<f64>| x.mapv(|v| v * 5.0));
        let input = Array2::from_shape_vec((1, 1), vec![3.0]).expect("valid shape");
        let recomputed = ckpt.recompute_segment(0, &input);
        assert!(recomputed.is_some());
        let r = recomputed.expect("segment exists");
        assert!((r[[0, 0]] - 15.0).abs() < 1e-12);
        assert!(ckpt.recompute_segment(5, &input).is_none());
    }

    // sqrt(100)/100 = 0.1 for the sqrt-checkpointing savings model.
    #[test]
    fn test_checkpoint_memory_savings() {
        let mut ckpt = Checkpoint::<f64>::new();
        for _ in 0..100 {
            ckpt.add_segment(|x: &Array2<f64>| x.mapv(|v| v + 1.0));
        }
        let ratio = ckpt.memory_savings_ratio();
        assert!((ratio - 0.1).abs() < 1e-12);
        assert_eq!(ckpt.num_segments(), 100);
    }

    #[test]
    fn test_checkpoint_empty() {
        let ckpt = Checkpoint::<f64>::new();
        let input = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).expect("valid shape");
        let output = ckpt.forward(&input);
        assert_eq!(output, input);
        assert!((ckpt.memory_savings_ratio() - 1.0).abs() < 1e-12);
    }

    // --- JitHint ---

    #[test]
    fn test_jit_hint_builder() {
        let hint = JitHint::new()
            .enable_constant_folding(true)
            .enable_fusion(true)
            .enable_cse(false)
            .enable_dead_code_elimination(true)
            .with_static_argnums(&[0, 2])
            .set_max_fusion_depth(4);
        assert!(hint.constant_folding());
        assert!(hint.fusion());
        assert!(!hint.cse());
        assert!(hint.dead_code_elimination());
        assert_eq!(hint.static_argnums(), &[0, 2]);
        assert_eq!(hint.max_fusion_depth(), 4);
    }

    #[test]
    fn test_jit_hint_all_optimizations() {
        let hint = JitHint::all_optimizations();
        assert!(hint.constant_folding());
        assert!(hint.fusion());
        assert!(hint.cse());
        assert!(hint.dead_code_elimination());
    }

    // --- pmap ---

    #[test]
    fn test_pmap_parallel() {
        let batch = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
            .expect("valid shape");
        let result = pmap(|x: &Array1<f64>| x.mapv(|v| v * v), &batch).expect("pmap succeeds");
        assert_eq!(result.shape(), &[4, 2]);
        assert!((result[[0, 0]] - 1.0).abs() < 1e-12);
        assert!((result[[1, 1]] - 16.0).abs() < 1e-12);
        assert!((result[[3, 0]] - 49.0).abs() < 1e-12);
    }

    #[test]
    fn test_pmap_empty_batch() {
        let batch = Array2::<f64>::zeros((0, 3));
        let result = pmap(|x: &Array1<f64>| x.clone(), &batch);
        assert!(result.is_err());
    }

    #[test]
    fn test_batched_value_and_grad() {
        let batch = Array2::from_shape_vec((2, 2), vec![3.0, 4.0, 1.0, 0.0]).expect("valid shape");
        let (vals, grads) = batched_value_and_grad(
            |xs: &[DualNumber<f64>]| xs[0] * xs[0] + xs[1] * xs[1],
            &batch,
        )
        .expect("succeeds");
        assert!((vals[0] - 25.0).abs() < 1e-12);
        assert!((grads[[0, 0]] - 6.0).abs() < 1e-12);
        assert!((grads[[0, 1]] - 8.0).abs() < 1e-12);
        assert!((vals[1] - 1.0).abs() < 1e-12);
        assert!((grads[[1, 0]] - 2.0).abs() < 1e-12);
        assert!(grads[[1, 1]].abs() < 1e-12);
    }

    // --- scan / check_grad / compose / iterate / jacobians ---

    // 10 -> +1 -> 11 -> *2 -> 22 -> -3 -> 19; every intermediate is recorded.
    #[test]
    fn test_scan_transforms() {
        type TransformFn = Box<dyn Fn(&Array1<f64>) -> Array1<f64>>;
        let transforms: Vec<TransformFn> = vec![
            Box::new(|x: &Array1<f64>| x.mapv(|v| v + 1.0)),
            Box::new(|x: &Array1<f64>| x.mapv(|v| v * 2.0)),
            Box::new(|x: &Array1<f64>| x.mapv(|v| v - 3.0)),
        ];
        let input = Array1::from(vec![10.0]);
        let (final_out, intermediates) = scan(&transforms, &input);
        assert_eq!(intermediates.len(), 3);
        assert!((intermediates[0][0] - 11.0).abs() < 1e-12);
        assert!((intermediates[1][0] - 22.0).abs() < 1e-12);
        assert!((intermediates[2][0] - 19.0).abs() < 1e-12);
        assert!((final_out[0] - 19.0).abs() < 1e-12);
    }

    #[test]
    fn test_check_grad_accurate() {
        let err = check_grad(
            |xs: &[DualNumber<f64>]| xs[0] * xs[0] + xs[1] * xs[1],
            |xs: &Array1<f64>| xs[0] * xs[0] + xs[1] * xs[1],
            &Array1::from(vec![3.0, 4.0]),
            1e-6,
        );
        assert!(err < 1e-5, "Gradient check error too large: {}", err);
    }

    // compose applies f first, then g: g(f(5)) = (5*2)+10 = 20.
    #[test]
    fn test_compose_functions() {
        let f = |x: f64| x * 2.0;
        let g = |x: f64| x + 10.0;
        let h = compose(f, g);
        assert!((h(5.0) - 20.0).abs() < 1e-12);
    }

    #[test]
    fn test_iterate_function() {
        let f = |x: &Array1<f64>| x.mapv(|v| v + 1.0);
        let x = Array1::from(vec![0.0]);
        let result = iterate(&f, &x, 5);
        assert!((result[0] - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_numerical_jacobian() {
        let f = |x: &Array1<f64>| Array1::from(vec![x[0] * x[0], x[0] * x[1]]);
        let x = Array1::from(vec![2.0, 3.0]);
        let jac = numerical_jacobian(&f, &x, 1e-6);
        assert!((jac[[0, 0]] - 4.0).abs() < 1e-4);
        assert!(jac[[0, 1]].abs() < 1e-4);
        assert!((jac[[1, 0]] - 3.0).abs() < 1e-4);
        assert!((jac[[1, 1]] - 2.0).abs() < 1e-4);
    }

    #[test]
    fn test_jacobian_transform() {
        let jac_f = jacobian(|xs: &[DualNumber<f64>]| vec![xs[0] * xs[0], xs[0] * xs[1]]);
        let x = Array1::from(vec![2.0_f64, 3.0]);
        let jac = jac_f(&x);
        assert!((jac[[0, 0]] - 4.0).abs() < 1e-12);
        assert!(jac[[0, 1]].abs() < 1e-12);
        assert!((jac[[1, 0]] - 3.0).abs() < 1e-12);
        assert!((jac[[1, 1]] - 2.0).abs() < 1e-12);
    }

    // --- combined transforms ---

    // Transforms compose: vmap over a grad closure evaluates d(x^2)/dx = 2x per row.
    #[test]
    fn test_grad_then_vmap() {
        let grad_f = grad(|xs: &[DualNumber<f64>]| xs[0] * xs[0]);
        let points = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).expect("valid shape");
        let gradients =
            vmap(move |x: &Array1<f64>| grad_f(x), &points).expect("vmap(grad) succeeds");
        assert!((gradients[[0, 0]] - 2.0).abs() < 1e-12);
        assert!((gradients[[1, 0]] - 4.0).abs() < 1e-12);
        assert!((gradients[[2, 0]] - 6.0).abs() < 1e-12);
    }

    // A two-segment softmax-like pipeline: exp, then row-normalize.
    #[test]
    fn test_checkpoint_with_grad() {
        let mut ckpt = Checkpoint::<f64>::new();
        ckpt.add_segment(|x: &Array2<f64>| x.mapv(|v| v.exp()))
            .add_segment(|x: &Array2<f64>| {
                let sums: Vec<f64> = x
                    .rows()
                    .into_iter()
                    .map(|row| row.iter().copied().fold(0.0, |a, b| a + b))
                    .collect();
                let ncols = x.ncols();
                let mut out = Array2::zeros(x.raw_dim());
                for (i, &s) in sums.iter().enumerate() {
                    for j in 0..ncols {
                        out[[i, j]] = x[[i, j]] / s;
                    }
                }
                out
            });
        let input = Array2::from_shape_vec((1, 3), vec![1.0, 2.0, 3.0]).expect("valid shape");
        let output = ckpt.forward(&input);
        let sum_exp = 1.0_f64.exp() + 2.0_f64.exp() + 3.0_f64.exp();
        assert!((output[[0, 0]] - 1.0_f64.exp() / sum_exp).abs() < 1e-10);
        assert!((output[[0, 1]] - 2.0_f64.exp() / sum_exp).abs() < 1e-10);
    }

    // value_and_grad must agree with a separate grad pass and a primal evaluation.
    #[test]
    fn test_value_and_grad_consistency() {
        let func = |xs: &[DualNumber<f64>]| {
            let three = DualNumber::constant(3.0);
            xs[0].powi(3) + three * xs[1].exp()
        };
        let x = Array1::from(vec![2.0, 1.0]);
        let vg = value_and_grad(func);
        let (val, g) = vg(&x);
        let grad_f = grad(func);
        let g2 = grad_f(&x);
        let primal: Vec<DualNumber<f64>> = x.iter().map(|&xi| DualNumber::constant(xi)).collect();
        let val2 = func(&primal).value();
        assert!((val - val2).abs() < 1e-12, "Values should match");
        assert!((g[0] - g2[0]).abs() < 1e-12, "Gradient x0 should match");
        assert!((g[1] - g2[1]).abs() < 1e-12, "Gradient x1 should match");
    }
}