1use crate::tape::{Tape, TapeRef};
38use std::ops::{Add, Div, Mul, Neg, Sub};
39use std::rc::Rc;
40
/// A scalar value recorded as a node on a reverse-mode tape.
///
/// Cloning a `Var` copies its node index and value and clones the shared tape
/// handle; it does not record a new node on the tape.
#[derive(Clone, Debug)]
pub struct Var {
    /// Position of this variable's node in the tape's `nodes` vector.
    pub(crate) index: usize,
    /// Forward (primal) value of this node.
    pub value: f64,
    /// Shared handle to the tape this node was recorded on.
    pub(crate) tape: TapeRef,
}
54
55impl Var {
56 pub fn cst(&self, value: f64) -> Var {
58 let (index, value) = {
59 let mut t = self.tape.borrow_mut();
60 let idx = t.nodes.len();
61 t.nodes.push(crate::tape::Node {
62 value,
63 parent1: None,
64 parent2: None,
65 });
66 (idx, value)
67 };
68 Var {
69 index,
70 value,
71 tape: Rc::clone(&self.tape),
72 }
73 }
74
75 pub fn sin(&self) -> Var {
77 let val = self.value.sin();
78 let deriv = self.value.cos(); let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
80 Var {
81 index,
82 value,
83 tape: Rc::clone(&self.tape),
84 }
85 }
86
87 pub fn cos(&self) -> Var {
89 let val = self.value.cos();
90 let deriv = -self.value.sin(); let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
92 Var {
93 index,
94 value,
95 tape: Rc::clone(&self.tape),
96 }
97 }
98
99 pub fn tan(&self) -> Var {
101 let val = self.value.tan();
102 let c = self.value.cos();
103 let deriv = 1.0 / (c * c); let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
105 Var {
106 index,
107 value,
108 tape: Rc::clone(&self.tape),
109 }
110 }
111
112 pub fn exp(&self) -> Var {
114 let val = self.value.exp();
115 let deriv = val; let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
117 Var {
118 index,
119 value,
120 tape: Rc::clone(&self.tape),
121 }
122 }
123
124 pub fn ln(&self) -> Var {
126 let val = self.value.ln();
127 let deriv = 1.0 / self.value; let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
129 Var {
130 index,
131 value,
132 tape: Rc::clone(&self.tape),
133 }
134 }
135
136 pub fn sqrt(&self) -> Var {
138 let val = self.value.sqrt();
139 let deriv = 0.5 / val; let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
141 Var {
142 index,
143 value,
144 tape: Rc::clone(&self.tape),
145 }
146 }
147
148 pub fn abs(&self) -> Var {
150 let val = self.value.abs();
151 let deriv = if self.value >= 0.0 { 1.0 } else { -1.0 };
152 let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
153 Var {
154 index,
155 value,
156 tape: Rc::clone(&self.tape),
157 }
158 }
159
160 pub fn tanh(&self) -> Var {
162 let val = self.value.tanh();
163 let deriv = 1.0 - val * val; let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
165 Var {
166 index,
167 value,
168 tape: Rc::clone(&self.tape),
169 }
170 }
171
172 pub fn sinh(&self) -> Var {
174 let val = self.value.sinh();
175 let deriv = self.value.cosh();
176 let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
177 Var {
178 index,
179 value,
180 tape: Rc::clone(&self.tape),
181 }
182 }
183
184 pub fn cosh(&self) -> Var {
186 let val = self.value.cosh();
187 let deriv = self.value.sinh();
188 let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
189 Var {
190 index,
191 value,
192 tape: Rc::clone(&self.tape),
193 }
194 }
195
196 pub fn pow(&self, n: &Var) -> Var {
198 let val = self.value.powf(n.value);
199 let d_self = n.value * self.value.powf(n.value - 1.0);
201 let d_n = val * self.value.ln();
203 let (index, value) = Tape::push_binary(&self.tape, val, self.index, d_self, n.index, d_n);
204 Var {
205 index,
206 value,
207 tape: Rc::clone(&self.tape),
208 }
209 }
210
211 pub fn powf(&self, n: f64) -> Var {
213 let val = self.value.powf(n);
214 let deriv = n * self.value.powf(n - 1.0);
215 let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
216 Var {
217 index,
218 value,
219 tape: Rc::clone(&self.tape),
220 }
221 }
222
223 pub fn powi(&self, n: i32) -> Var {
225 self.powf(n as f64)
226 }
227
228 pub fn asin(&self) -> Var {
230 let val = self.value.asin();
231 let deriv = 1.0 / (1.0 - self.value * self.value).sqrt();
232 let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
233 Var {
234 index,
235 value,
236 tape: Rc::clone(&self.tape),
237 }
238 }
239
240 pub fn acos(&self) -> Var {
242 let val = self.value.acos();
243 let deriv = -1.0 / (1.0 - self.value * self.value).sqrt();
244 let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
245 Var {
246 index,
247 value,
248 tape: Rc::clone(&self.tape),
249 }
250 }
251
252 pub fn atan(&self) -> Var {
254 let val = self.value.atan();
255 let deriv = 1.0 / (1.0 + self.value * self.value);
256 let (index, value) = Tape::push_unary(&self.tape, val, self.index, deriv);
257 Var {
258 index,
259 value,
260 tape: Rc::clone(&self.tape),
261 }
262 }
263}
264
265impl Add for Var {
269 type Output = Var;
270 fn add(self, rhs: Var) -> Var {
271 let val = self.value + rhs.value;
272 let (index, value) = Tape::push_binary(&self.tape, val, self.index, 1.0, rhs.index, 1.0);
273 Var {
274 index,
275 value,
276 tape: self.tape,
277 }
278 }
279}
280
281impl Add for &Var {
283 type Output = Var;
284 fn add(self, rhs: &Var) -> Var {
285 let val = self.value + rhs.value;
286 let (index, value) = Tape::push_binary(&self.tape, val, self.index, 1.0, rhs.index, 1.0);
287 Var {
288 index,
289 value,
290 tape: Rc::clone(&self.tape),
291 }
292 }
293}
294
295impl Add<f64> for Var {
297 type Output = Var;
298 fn add(self, rhs: f64) -> Var {
299 let val = self.value + rhs;
300 let (index, value) = Tape::push_unary(&self.tape, val, self.index, 1.0);
301 Var {
302 index,
303 value,
304 tape: self.tape,
305 }
306 }
307}
308
309impl Add<Var> for f64 {
311 type Output = Var;
312 fn add(self, rhs: Var) -> Var {
313 let val = self + rhs.value;
314 let (index, value) = Tape::push_unary(&rhs.tape, val, rhs.index, 1.0);
315 Var {
316 index,
317 value,
318 tape: rhs.tape,
319 }
320 }
321}
322
323impl Sub for Var {
325 type Output = Var;
326 fn sub(self, rhs: Var) -> Var {
327 let val = self.value - rhs.value;
328 let (index, value) = Tape::push_binary(&self.tape, val, self.index, 1.0, rhs.index, -1.0);
329 Var {
330 index,
331 value,
332 tape: self.tape,
333 }
334 }
335}
336
337impl Sub for &Var {
339 type Output = Var;
340 fn sub(self, rhs: &Var) -> Var {
341 let val = self.value - rhs.value;
342 let (index, value) = Tape::push_binary(&self.tape, val, self.index, 1.0, rhs.index, -1.0);
343 Var {
344 index,
345 value,
346 tape: Rc::clone(&self.tape),
347 }
348 }
349}
350
351impl Sub<f64> for Var {
353 type Output = Var;
354 fn sub(self, rhs: f64) -> Var {
355 let val = self.value - rhs;
356 let (index, value) = Tape::push_unary(&self.tape, val, self.index, 1.0);
357 Var {
358 index,
359 value,
360 tape: self.tape,
361 }
362 }
363}
364
365impl Sub<Var> for f64 {
367 type Output = Var;
368 fn sub(self, rhs: Var) -> Var {
369 let val = self - rhs.value;
370 let (index, value) = Tape::push_unary(&rhs.tape, val, rhs.index, -1.0);
371 Var {
372 index,
373 value,
374 tape: rhs.tape,
375 }
376 }
377}
378
379impl Mul for Var {
381 type Output = Var;
382 fn mul(self, rhs: Var) -> Var {
383 let val = self.value * rhs.value;
384 let (index, value) = Tape::push_binary(
385 &self.tape, val, self.index, rhs.value, rhs.index, self.value, );
388 Var {
389 index,
390 value,
391 tape: self.tape,
392 }
393 }
394}
395
396impl Mul for &Var {
398 type Output = Var;
399 fn mul(self, rhs: &Var) -> Var {
400 let val = self.value * rhs.value;
401 let (index, value) = Tape::push_binary(
402 &self.tape, val, self.index, rhs.value, rhs.index, self.value,
403 );
404 Var {
405 index,
406 value,
407 tape: Rc::clone(&self.tape),
408 }
409 }
410}
411
412impl Mul<f64> for Var {
414 type Output = Var;
415 fn mul(self, rhs: f64) -> Var {
416 let val = self.value * rhs;
417 let (index, value) = Tape::push_unary(&self.tape, val, self.index, rhs);
418 Var {
419 index,
420 value,
421 tape: self.tape,
422 }
423 }
424}
425
426impl Mul<Var> for f64 {
428 type Output = Var;
429 fn mul(self, rhs: Var) -> Var {
430 let val = self * rhs.value;
431 let (index, value) = Tape::push_unary(&rhs.tape, val, rhs.index, self);
432 Var {
433 index,
434 value,
435 tape: rhs.tape,
436 }
437 }
438}
439
440impl Mul<&Var> for Var {
442 type Output = Var;
443 fn mul(self, rhs: &Var) -> Var {
444 let val = self.value * rhs.value;
445 let (index, value) = Tape::push_binary(
446 &self.tape, val, self.index, rhs.value, rhs.index, self.value,
447 );
448 Var {
449 index,
450 value,
451 tape: self.tape,
452 }
453 }
454}
455
456impl Mul<Var> for &Var {
458 type Output = Var;
459 fn mul(self, rhs: Var) -> Var {
460 let val = self.value * rhs.value;
461 let (index, value) = Tape::push_binary(
462 &self.tape, val, self.index, rhs.value, rhs.index, self.value,
463 );
464 Var {
465 index,
466 value,
467 tape: rhs.tape,
468 }
469 }
470}
471
472impl Div for Var {
474 type Output = Var;
475 fn div(self, rhs: Var) -> Var {
476 let val = self.value / rhs.value;
477 let (index, value) = Tape::push_binary(
478 &self.tape,
479 val,
480 self.index,
481 1.0 / rhs.value, rhs.index,
483 -self.value / (rhs.value * rhs.value), );
485 Var {
486 index,
487 value,
488 tape: self.tape,
489 }
490 }
491}
492
493impl Div for &Var {
495 type Output = Var;
496 fn div(self, rhs: &Var) -> Var {
497 let val = self.value / rhs.value;
498 let (index, value) = Tape::push_binary(
499 &self.tape,
500 val,
501 self.index,
502 1.0 / rhs.value,
503 rhs.index,
504 -self.value / (rhs.value * rhs.value),
505 );
506 Var {
507 index,
508 value,
509 tape: Rc::clone(&self.tape),
510 }
511 }
512}
513
514impl Div<f64> for Var {
516 type Output = Var;
517 fn div(self, rhs: f64) -> Var {
518 let val = self.value / rhs;
519 let (index, value) = Tape::push_unary(&self.tape, val, self.index, 1.0 / rhs);
520 Var {
521 index,
522 value,
523 tape: self.tape,
524 }
525 }
526}
527
528impl Neg for Var {
530 type Output = Var;
531 fn neg(self) -> Var {
532 let val = -self.value;
533 let (index, value) = Tape::push_unary(&self.tape, val, self.index, -1.0);
534 Var {
535 index,
536 value,
537 tape: self.tape,
538 }
539 }
540}
541
542impl Neg for &Var {
544 type Output = Var;
545 fn neg(self) -> Var {
546 let val = -self.value;
547 let (index, value) = Tape::push_unary(&self.tape, val, self.index, -1.0);
548 Var {
549 index,
550 value,
551 tape: Rc::clone(&self.tape),
552 }
553 }
554}
555
556pub fn grad<F>(f: F, x: &[f64]) -> Vec<f64>
573where
574 F: Fn(&[Var]) -> Var,
575{
576 let tape = Tape::new();
577 let vars: Vec<Var> = x.iter().map(|&xi| Tape::var(&tape, xi)).collect();
578 let output = f(&vars);
579 Tape::gradient(&tape, &output)
580}
581
582pub fn jacobian_reverse<F>(f: F, x: &[f64]) -> Vec<Vec<f64>>
600where
601 F: Fn(&[Var]) -> Vec<Var>,
602{
603 let tape = Tape::new();
604 let vars: Vec<Var> = x.iter().map(|&xi| Tape::var(&tape, xi)).collect();
605 let outputs = f(&vars);
606 Tape::jacobian(&tape, &outputs)
607}
608
609pub fn hessian<F>(f: F, x: &[f64]) -> Vec<Vec<f64>>
631where
632 F: Fn(&[Var]) -> Var,
633{
634 let n = x.len();
635 let eps = 1e-7;
636 let mut h = vec![vec![0.0; n]; n];
637
638 let g0 = grad(&f, x);
639
640 for j in 0..n {
641 let mut x_pert = x.to_vec();
642 x_pert[j] += eps;
643 let g_pert = grad(&f, &x_pert);
644
645 for i in 0..n {
646 h[i][j] = (g_pert[i] - g0[i]) / eps;
647 }
648 }
649
650 h
651}
652
#[cfg(test)]
mod tests {
    use super::*;

    /// Records `op(x)` on a fresh tape with `x = x0`; returns `(value, d op / d x)`.
    fn unary_value_and_grad<F>(x0: f64, op: F) -> (f64, f64)
    where
        F: Fn(Var) -> Var,
    {
        let tape = Tape::new();
        let x = Tape::var(&tape, x0);
        let z = op(x);
        let g = Tape::gradient(&tape, &z);
        (z.value, g[0])
    }

    /// Records `op(&x, &y)` on a fresh tape; returns `(value, d op / d x, d op / d y)`.
    fn binary_value_and_grad<F>(x0: f64, y0: f64, op: F) -> (f64, f64, f64)
    where
        F: Fn(&Var, &Var) -> Var,
    {
        let tape = Tape::new();
        let x = Tape::var(&tape, x0);
        let y = Tape::var(&tape, y0);
        let z = op(&x, &y);
        let g = Tape::gradient(&tape, &z);
        (z.value, g[0], g[1])
    }

    #[test]
    fn test_basic_arithmetic() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 2.0);
        let y = Tape::var(&tape, 3.0);

        let z = x.clone() + y.clone();
        assert!((z.value - 5.0).abs() < 1e-14);
        let g = Tape::gradient(&tape, &z);
        assert!((g[0] - 1.0).abs() < 1e-14);
        assert!((g[1] - 1.0).abs() < 1e-14);
    }

    #[test]
    fn test_multiplication() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 2.0);
        let y = Tape::var(&tape, 3.0);

        let z = x * y;
        let g = Tape::gradient(&tape, &z);
        assert!((g[0] - 3.0).abs() < 1e-14);
        assert!((g[1] - 2.0).abs() < 1e-14);
    }

    #[test]
    fn test_chain_rule() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 2.0);

        let z = x.clone() * x.clone() * x;
        let g = Tape::gradient(&tape, &z);
        assert!((g[0] - 12.0).abs() < 1e-12);
    }

    #[test]
    fn test_subtraction() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 5.0);
        let y = Tape::var(&tape, 3.0);

        let z = x - y;
        assert!((z.value - 2.0).abs() < 1e-14);
        let g = Tape::gradient(&tape, &z);
        assert!((g[0] - 1.0).abs() < 1e-14);
        assert!((g[1] - (-1.0)).abs() < 1e-14);
    }

    #[test]
    fn test_division() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 6.0);
        let y = Tape::var(&tape, 3.0);

        let z = x / y;
        let g = Tape::gradient(&tape, &z);
        assert!((g[0] - 1.0 / 3.0).abs() < 1e-14);
        assert!((g[1] - (-6.0 / 9.0)).abs() < 1e-14);
    }

    #[test]
    fn test_negation() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 3.0);
        let z = -x;
        assert!((z.value - (-3.0)).abs() < 1e-14);
        let g = Tape::gradient(&tape, &z);
        assert!((g[0] - (-1.0)).abs() < 1e-14);
    }

    #[test]
    fn test_sin_cos() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 1.0);

        let s = x.sin();
        let g = Tape::gradient(&tape, &s);
        assert!((g[0] - 1.0_f64.cos()).abs() < 1e-14);

        let tape2 = Tape::new();
        let x2 = Tape::var(&tape2, 1.0);
        let c = x2.cos();
        let g2 = Tape::gradient(&tape2, &c);
        assert!((g2[0] - (-1.0_f64.sin())).abs() < 1e-14);
    }

    #[test]
    fn test_exp_ln() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 2.0);
        let e = x.exp();
        let g = Tape::gradient(&tape, &e);
        assert!((g[0] - 2.0_f64.exp()).abs() < 1e-12);

        let tape2 = Tape::new();
        let x2 = Tape::var(&tape2, 3.0);
        let l = x2.ln();
        let g2 = Tape::gradient(&tape2, &l);
        assert!((g2[0] - 1.0 / 3.0).abs() < 1e-14);
    }

    #[test]
    fn test_sqrt() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 4.0);
        let s = x.sqrt();
        assert!((s.value - 2.0).abs() < 1e-14);
        let g = Tape::gradient(&tape, &s);
        assert!((g[0] - 0.25).abs() < 1e-14);
    }

    #[test]
    fn test_tanh() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 1.0);
        let t = x.tanh();
        let g = Tape::gradient(&tape, &t);
        let expected = 1.0 - 1.0_f64.tanh().powi(2);
        assert!((g[0] - expected).abs() < 1e-14);
    }

    #[test]
    fn test_powf() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 3.0);
        let p = x.powf(2.0);
        assert!((p.value - 9.0).abs() < 1e-14);
        let g = Tape::gradient(&tape, &p);
        assert!((g[0] - 6.0).abs() < 1e-12);
    }

    #[test]
    fn test_pow_var() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 2.0);
        let y = Tape::var(&tape, 3.0);
        let p = x.pow(&y);
        assert!((p.value - 8.0).abs() < 1e-12);
        let g = Tape::gradient(&tape, &p);
        assert!((g[0] - 12.0).abs() < 1e-10);
        assert!((g[1] - 8.0 * 2.0_f64.ln()).abs() < 1e-10);
    }

    #[test]
    fn test_scalar_ops() {
        // ((value, d/dx), (expected value, expected d/dx)) at x = 3, one fresh tape per overload.
        let cases = [
            (unary_value_and_grad(3.0, |x| x + 2.0), (5.0, 1.0)),
            (unary_value_and_grad(3.0, |x| 2.0 + x), (5.0, 1.0)),
            (unary_value_and_grad(3.0, |x| x * 3.0), (9.0, 3.0)),
            (unary_value_and_grad(3.0, |x| 3.0 * x), (9.0, 3.0)),
            (unary_value_and_grad(3.0, |x| x - 1.0), (2.0, 1.0)),
            (unary_value_and_grad(3.0, |x| 10.0 - x), (7.0, -1.0)),
            (unary_value_and_grad(3.0, |x| x / 2.0), (1.5, 0.5)),
        ];
        for (i, &((value, deriv), (want_value, want_deriv))) in cases.iter().enumerate() {
            assert!(
                (value - want_value).abs() < 1e-14,
                "case {}: value {} vs {}",
                i,
                value,
                want_value
            );
            assert!(
                (deriv - want_deriv).abs() < 1e-14,
                "case {}: derivative {} vs {}",
                i,
                deriv,
                want_deriv
            );
        }
    }

    #[test]
    fn test_reference_ops() {
        let (v, dx, dy) = binary_value_and_grad(2.0, 3.0, |x, y| x + y);
        assert!((v - 5.0).abs() < 1e-14);
        assert!((dx - 1.0).abs() < 1e-14);
        assert!((dy - 1.0).abs() < 1e-14);

        let (v, dx, dy) = binary_value_and_grad(2.0, 3.0, |x, y| x * y);
        assert!((v - 6.0).abs() < 1e-14);
        assert!((dx - 3.0).abs() < 1e-14);
        assert!((dy - 2.0).abs() < 1e-14);

        let (v, dx, dy) = binary_value_and_grad(2.0, 3.0, |x, y| x - y);
        assert!((v - (-1.0)).abs() < 1e-14);
        assert!((dx - 1.0).abs() < 1e-14);
        assert!((dy - (-1.0)).abs() < 1e-14);

        let (v, dx, dy) = binary_value_and_grad(2.0, 3.0, |x, y| x / y);
        assert!((v - 2.0 / 3.0).abs() < 1e-14);
        assert!((dx - 1.0 / 3.0).abs() < 1e-14);
        assert!((dy - (-2.0 / 9.0)).abs() < 1e-14);
    }

    #[test]
    fn test_grad_rosenbrock() {
        let g = grad(
            |x| {
                let a = x[0].cst(1.0) - x[0].clone();
                let b = x[1].clone() - x[0].clone() * x[0].clone();
                a.clone() * a + x[0].cst(100.0) * b.clone() * b
            },
            &[1.0, 1.0],
        );
        assert!(g[0].abs() < 1e-10);
        assert!(g[1].abs() < 1e-10);
    }

    #[test]
    fn test_grad_rosenbrock_nonzero() {
        let g = grad(
            |x| {
                let a = x[0].cst(1.0) - x[0].clone();
                let b = x[1].clone() - x[0].clone() * x[0].clone();
                a.clone() * a + x[0].cst(100.0) * b.clone() * b
            },
            &[0.0, 0.0],
        );
        assert!((g[0] - (-2.0)).abs() < 1e-10);
        assert!(g[1].abs() < 1e-10);
    }

    #[test]
    fn test_jacobian_reverse_fn() {
        let jac = jacobian_reverse(|x| vec![&x[0] + &x[1], &x[0] * &x[1]], &[2.0, 3.0]);
        assert_eq!(jac.len(), 2);
        assert!((jac[0][0] - 1.0).abs() < 1e-14);
        assert!((jac[0][1] - 1.0).abs() < 1e-14);
        assert!((jac[1][0] - 3.0).abs() < 1e-14);
        assert!((jac[1][1] - 2.0).abs() < 1e-14);
    }

    #[test]
    fn test_jacobian_rotation() {
        let theta: f64 = 0.5;
        let jac = jacobian_reverse(
            |v| {
                let x = &v[0];
                let y = &v[1];
                let ct = x.cst(theta.cos());
                let st = x.cst(theta.sin());
                vec![&(x * &ct) - &(y * &st), &(x * &st) + &(y * &ct)]
            },
            &[1.0, 0.0],
        );
        assert!((jac[0][0] - theta.cos()).abs() < 1e-12);
        assert!((jac[0][1] - (-theta.sin())).abs() < 1e-12);
        assert!((jac[1][0] - theta.sin()).abs() < 1e-12);
        assert!((jac[1][1] - theta.cos()).abs() < 1e-12);
    }

    #[test]
    fn test_hessian_quadratic() {
        let h = hessian(
            |x| &x[0] * &x[0] + x[0].cst(2.0) * &x[0] * &x[1] + x[0].cst(3.0) * &x[1] * &x[1],
            &[1.0, 1.0],
        );
        assert!((h[0][0] - 2.0).abs() < 1e-5);
        assert!((h[0][1] - 2.0).abs() < 1e-5);
        assert!((h[1][0] - 2.0).abs() < 1e-5);
        assert!((h[1][1] - 6.0).abs() < 1e-5);
    }

    #[test]
    fn test_hessian_rosenbrock() {
        let h = hessian(
            |x| {
                let a = x[0].cst(1.0) - x[0].clone();
                let b = x[1].clone() - x[0].clone() * x[0].clone();
                a.clone() * a + x[0].cst(100.0) * b.clone() * b
            },
            &[1.0, 1.0],
        );
        assert!((h[0][0] - 802.0).abs() < 1e-3);
        assert!((h[0][1] - (-400.0)).abs() < 1e-3);
        assert!((h[1][0] - (-400.0)).abs() < 1e-3);
        assert!((h[1][1] - 200.0).abs() < 1e-3);
    }

    #[test]
    fn test_inverse_trig() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 0.5);
        let a = x.asin();
        let g = Tape::gradient(&tape, &a);
        assert!((g[0] - 1.0 / (1.0 - 0.25_f64).sqrt()).abs() < 1e-12);

        let tape2 = Tape::new();
        let x2 = Tape::var(&tape2, 1.0);
        let a2 = x2.atan();
        let g2 = Tape::gradient(&tape2, &a2);
        assert!((g2[0] - 0.5).abs() < 1e-14);
    }

    #[test]
    fn test_grad_matches_fd() {
        let f = |x: &[Var]| x[0].sin() * x[1].exp() + x[2].clone() * x[2].clone();
        let x0 = [1.0, 2.0, 3.0];
        let g = grad(f, &x0);

        let eps = 1e-7;
        let f_val = |x: &[f64]| x[0].sin() * x[1].exp() + x[2] * x[2];
        let f0 = f_val(&x0);
        for i in 0..3 {
            let mut xp = x0;
            xp[i] += eps;
            let fd = (f_val(&xp) - f0) / eps;
            assert!(
                (g[i] - fd).abs() < 1e-5,
                "component {} mismatch: {} vs {}",
                i,
                g[i],
                fd
            );
        }
    }

    #[test]
    fn test_constant() {
        let tape = Tape::new();
        let x = Tape::var(&tape, 2.0);
        let c = x.cst(5.0);
        let z = x * c;
        let g = Tape::gradient(&tape, &z);
        assert!((g[0] - 5.0).abs() < 1e-14);
    }

    #[test]
    fn test_complex_composition() {
        let g = grad(
            |x| {
                let x2 = &x[0] * &x[0];
                let ex = x[0].exp();
                let inner = x2 + ex;
                inner.sin()
            },
            &[1.0],
        );
        let x = 1.0_f64;
        let inner = x * x + x.exp();
        let expected = inner.cos() * (2.0 * x + x.exp());
        assert!((g[0] - expected).abs() < 1e-10);
    }
}