use std::ops::{Add, Div, Mul, Neg, Sub};

/// A dual number `v + dv·ε` (with `ε² = 0`) for forward-mode automatic
/// differentiation.
///
/// The arithmetic and elementary functions defined on this type carry the
/// derivative part through the chain rule, so evaluating a function on
/// `Dual { v: x, dv: 1.0 }` yields both `f(x)` and `f'(x)`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct Dual {
    /// Primal part: the value of the expression.
    pub v: f64,
    /// Tangent part: the derivative of the expression with respect to the
    /// seeded input.
    pub dv: f64,
}

38pub fn dual(x: f64, dx: f64) -> Dual {
43 Dual { v: x, dv: dx }
44}
45
46impl Dual {
47 #[allow(dead_code)]
49 pub fn constant(c: f64) -> Self {
50 Dual { v: c, dv: 0.0 }
51 }
52
53 #[allow(dead_code)]
55 pub fn variable(x: f64) -> Self {
56 Dual { v: x, dv: 1.0 }
57 }
58
59 pub fn sin(self) -> Self {
61 Dual {
62 v: self.v.sin(),
63 dv: self.v.cos() * self.dv,
64 }
65 }
66
67 pub fn cos(self) -> Self {
69 Dual {
70 v: self.v.cos(),
71 dv: -self.v.sin() * self.dv,
72 }
73 }
74
75 pub fn exp(self) -> Self {
77 let ev = self.v.exp();
78 Dual {
79 v: ev,
80 dv: ev * self.dv,
81 }
82 }
83
84 pub fn ln(self) -> Self {
86 Dual {
87 v: self.v.ln(),
88 dv: self.dv / self.v,
89 }
90 }
91
92 pub fn sqrt(self) -> Self {
94 let sv = self.v.sqrt();
95 Dual {
96 v: sv,
97 dv: self.dv / (2.0 * sv),
98 }
99 }
100
101 pub fn abs(self) -> Self {
103 Dual {
104 v: self.v.abs(),
105 dv: self.v.signum() * self.dv,
106 }
107 }
108
109 pub fn powi(self, n: i32) -> Self {
111 Dual {
112 v: self.v.powi(n),
113 dv: (n as f64) * self.v.powi(n - 1) * self.dv,
114 }
115 }
116
117 pub fn powf(self, p: f64) -> Self {
119 Dual {
120 v: self.v.powf(p),
121 dv: p * self.v.powf(p - 1.0) * self.dv,
122 }
123 }
124
125 #[allow(dead_code)]
127 pub fn tan(self) -> Self {
128 let c = self.v.cos();
129 Dual {
130 v: self.v.tan(),
131 dv: self.dv / (c * c),
132 }
133 }
134
135 #[allow(dead_code)]
137 pub fn sinh(self) -> Self {
138 Dual {
139 v: self.v.sinh(),
140 dv: self.v.cosh() * self.dv,
141 }
142 }
143
144 #[allow(dead_code)]
146 pub fn cosh(self) -> Self {
147 Dual {
148 v: self.v.cosh(),
149 dv: self.v.sinh() * self.dv,
150 }
151 }
152}
153
154impl Add for Dual {
159 type Output = Self;
160 fn add(self, rhs: Self) -> Self {
161 Dual {
162 v: self.v + rhs.v,
163 dv: self.dv + rhs.dv,
164 }
165 }
166}
167
168impl Add<f64> for Dual {
169 type Output = Self;
170 fn add(self, rhs: f64) -> Self {
171 Dual {
172 v: self.v + rhs,
173 dv: self.dv,
174 }
175 }
176}
177
178impl Add<Dual> for f64 {
179 type Output = Dual;
180 fn add(self, rhs: Dual) -> Dual {
181 Dual {
182 v: self + rhs.v,
183 dv: rhs.dv,
184 }
185 }
186}
187
188impl Sub for Dual {
189 type Output = Self;
190 fn sub(self, rhs: Self) -> Self {
191 Dual {
192 v: self.v - rhs.v,
193 dv: self.dv - rhs.dv,
194 }
195 }
196}
197
198impl Sub<f64> for Dual {
199 type Output = Self;
200 fn sub(self, rhs: f64) -> Self {
201 Dual {
202 v: self.v - rhs,
203 dv: self.dv,
204 }
205 }
206}
207
208impl Sub<Dual> for f64 {
209 type Output = Dual;
210 fn sub(self, rhs: Dual) -> Dual {
211 Dual {
212 v: self - rhs.v,
213 dv: -rhs.dv,
214 }
215 }
216}
217
218impl Mul for Dual {
219 type Output = Self;
220 fn mul(self, rhs: Self) -> Self {
221 Dual {
222 v: self.v * rhs.v,
223 dv: self.dv * rhs.v + self.v * rhs.dv,
224 }
225 }
226}
227
228impl Mul<f64> for Dual {
229 type Output = Self;
230 fn mul(self, rhs: f64) -> Self {
231 Dual {
232 v: self.v * rhs,
233 dv: self.dv * rhs,
234 }
235 }
236}
237
238impl Mul<Dual> for f64 {
239 type Output = Dual;
240 fn mul(self, rhs: Dual) -> Dual {
241 Dual {
242 v: self * rhs.v,
243 dv: self * rhs.dv,
244 }
245 }
246}
247
248impl Div for Dual {
249 type Output = Self;
250 fn div(self, rhs: Self) -> Self {
251 Dual {
252 v: self.v / rhs.v,
253 dv: (self.dv * rhs.v - self.v * rhs.dv) / (rhs.v * rhs.v),
254 }
255 }
256}
257
258impl Div<f64> for Dual {
259 type Output = Self;
260 fn div(self, rhs: f64) -> Self {
261 Dual {
262 v: self.v / rhs,
263 dv: self.dv / rhs,
264 }
265 }
266}
267
268impl Neg for Dual {
269 type Output = Self;
270 fn neg(self) -> Self {
271 Dual {
272 v: -self.v,
273 dv: -self.dv,
274 }
275 }
276}
277
278pub fn grad1(f: impl Fn(Dual) -> Dual, x: f64) -> f64 {
286 f(dual(x, 1.0)).dv
287}
288
289pub fn hessian_diag(f: impl Fn(Dual) -> Dual, xs: &[f64]) -> Vec<f64> {
301 let h = 1e-5;
302 xs.iter()
303 .map(|&xi| {
304 let fp = f(dual(xi + h, 1.0)).dv;
305 let fm = f(dual(xi - h, 1.0)).dv;
306 (fp - fm) / (2.0 * h)
307 })
308 .collect()
309}
310
311pub fn jacobian_row(f: impl Fn(&[Dual]) -> Dual, xs: &[f64]) -> Vec<f64> {
316 let n = xs.len();
317 let mut row = Vec::with_capacity(n);
318 for i in 0..n {
319 let duals: Vec<Dual> = xs
320 .iter()
321 .enumerate()
322 .map(|(j, &x)| dual(x, if j == i { 1.0 } else { 0.0 }))
323 .collect();
324 row.push(f(&duals).dv);
325 }
326 row
327}
328
329#[allow(dead_code)]
335#[derive(Debug, Clone)]
336pub struct DualVec {
337 pub components: Vec<Dual>,
339}
340
341impl DualVec {
342 #[allow(dead_code)]
344 pub fn from_pairs(pairs: &[(f64, f64)]) -> Self {
345 DualVec {
346 components: pairs.iter().map(|&(v, dv)| Dual { v, dv }).collect(),
347 }
348 }
349
350 #[allow(dead_code)]
352 pub fn variable(xs: &[f64], seed_idx: usize) -> Self {
353 DualVec {
354 components: xs
355 .iter()
356 .enumerate()
357 .map(|(i, &x)| dual(x, if i == seed_idx { 1.0 } else { 0.0 }))
358 .collect(),
359 }
360 }
361
362 #[allow(dead_code)]
364 pub fn len(&self) -> usize {
365 self.components.len()
366 }
367
368 #[allow(dead_code)]
370 pub fn is_empty(&self) -> bool {
371 self.components.is_empty()
372 }
373}
374
375pub fn gradient(f: impl Fn(&[Dual]) -> Dual, xs: &[f64]) -> Vec<f64> {
379 jacobian_row(f, xs)
380}
381
382pub fn newton_step(f: impl Fn(Dual) -> Dual, x: f64) -> f64 {
395 let d = f(dual(x, 1.0));
396 if d.dv.abs() < 1e-14 {
397 x
398 } else {
399 x - d.v / d.dv
400 }
401}
402
/// Coefficients of a third-order Taylor expansion of a scalar function.
///
/// `eval` combines the fields as
/// `f0 + f1·dx + f2/2·dx² + f3/6·dx³` with `dx = x - center`.
// Lint allowance carried over unchanged. NOTE(review): presumably some
// fields are not read in every build configuration — confirm before removing.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub struct TaylorExpand {
    /// Expansion point.
    pub center: f64,
    /// Function value at `center`.
    pub f0: f64,
    /// First derivative at `center`.
    pub f1: f64,
    /// Second derivative at `center`.
    pub f2: f64,
    /// Third derivative at `center`.
    pub f3: f64,
}

424impl TaylorExpand {
425 pub fn build<F>(f: F, x0: f64) -> Self
432 where
433 F: Fn(Dual) -> Dual + Copy,
434 {
435 let h = 1e-5;
436 let f0 = f(dual(x0, 0.0)).v;
437 let f1 = f(dual(x0, 1.0)).dv;
438 let f2 = (f(dual(x0 + h, 1.0)).dv - f(dual(x0 - h, 1.0)).dv) / (2.0 * h);
440 let f2p = |xv: f64| (f(dual(xv + h, 1.0)).dv - f(dual(xv - h, 1.0)).dv) / (2.0 * h);
442 let f3 = (f2p(x0 + h) - f2p(x0 - h)) / (2.0 * h);
443 TaylorExpand {
444 center: x0,
445 f0,
446 f1,
447 f2,
448 f3,
449 }
450 }
451
452 pub fn eval(&self, x: f64) -> f64 {
455 let dx = x - self.center;
456 self.f0 + self.f1 * dx + (self.f2 / 2.0) * dx * dx + (self.f3 / 6.0) * dx * dx * dx
457 }
458}
459
/// Unit tests for the dual-number arithmetic, the derivative helpers and the
/// Taylor expansion.
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{E, PI};

    /// Tolerance for results that are exact up to floating-point rounding.
    const EPS: f64 = 1e-9;
    /// Tolerance for results that involve a finite-difference step.
    const LOOSE: f64 = 1e-5;

    // --- Arithmetic operators ---------------------------------------------

    #[test]
    fn test_dual_add() {
        let a = dual(3.0, 1.0);
        let b = dual(2.0, 4.0);
        let c = a + b;
        assert!((c.v - 5.0).abs() < EPS);
        assert!((c.dv - 5.0).abs() < EPS);
    }

    #[test]
    fn test_dual_sub() {
        let a = dual(5.0, 2.0);
        let b = dual(3.0, 1.0);
        let c = a - b;
        assert!((c.v - 2.0).abs() < EPS);
        assert!((c.dv - 1.0).abs() < EPS);
    }

    #[test]
    fn test_dual_mul_product_rule() {
        let x = dual(3.0, 1.0);
        // d/dx x^2 = 2x
        let y = x * x;
        assert!((y.v - 9.0).abs() < EPS);
        assert!((y.dv - 6.0).abs() < EPS);
    }

    #[test]
    fn test_dual_div_quotient_rule() {
        let x = dual(2.0, 1.0);
        // d/dx [x / (x + 1)] = 1 / (x + 1)^2
        let y = x / (x + 1.0);
        assert!((y.v - 2.0 / 3.0).abs() < EPS);
        assert!((y.dv - 1.0 / 9.0).abs() < EPS);
    }

    #[test]
    fn test_dual_neg() {
        let x = dual(4.0, 1.0);
        let y = -x;
        assert!((y.v + 4.0).abs() < EPS);
        assert!((y.dv + 1.0).abs() < EPS);
    }

    #[test]
    fn test_dual_add_f64() {
        let x = dual(1.0, 1.0) + 5.0;
        assert!((x.v - 6.0).abs() < EPS);
        assert!((x.dv - 1.0).abs() < EPS);
    }

    #[test]
    fn test_dual_mul_f64() {
        let x = dual(3.0, 1.0) * 4.0;
        assert!((x.v - 12.0).abs() < EPS);
        assert!((x.dv - 4.0).abs() < EPS);
    }

    #[test]
    fn test_dual_sub_f64_rhs() {
        let x = 10.0_f64 - dual(3.0, 1.0);
        assert!((x.v - 7.0).abs() < EPS);
        assert!((x.dv + 1.0).abs() < EPS);
    }

    // --- Elementary functions ---------------------------------------------

    #[test]
    fn test_dual_sin_derivative() {
        let x = dual(PI / 4.0, 1.0);
        // d/dx sin(x) = cos(x)
        let y = x.sin();
        assert!((y.v - (PI / 4.0).sin()).abs() < EPS);
        assert!((y.dv - (PI / 4.0).cos()).abs() < EPS);
    }

    #[test]
    fn test_dual_cos_derivative() {
        let x = dual(PI / 3.0, 1.0);
        let y = x.cos();
        assert!((y.v - (PI / 3.0).cos()).abs() < EPS);
        assert!((y.dv + (PI / 3.0).sin()).abs() < EPS);
    }

    #[test]
    fn test_dual_exp_derivative() {
        let x = dual(2.0, 1.0);
        // d/dx e^x = e^x
        let y = x.exp();
        assert!((y.v - E * E).abs() < 1e-10);
        assert!((y.dv - E * E).abs() < 1e-10);
    }

    #[test]
    fn test_dual_ln_derivative() {
        let x = dual(3.0, 1.0);
        // d/dx ln(x) = 1 / x
        let y = x.ln();
        assert!((y.v - 3.0_f64.ln()).abs() < EPS);
        assert!((y.dv - 1.0 / 3.0).abs() < EPS);
    }

    #[test]
    fn test_dual_sqrt_derivative() {
        let x = dual(4.0, 1.0);
        // d/dx sqrt(x) = 1 / (2 sqrt(x))
        let y = x.sqrt();
        assert!((y.v - 2.0).abs() < EPS);
        assert!((y.dv - 0.25).abs() < EPS);
    }

    #[test]
    fn test_dual_abs_positive() {
        let x = dual(3.0, 1.0);
        let y = x.abs();
        assert!((y.v - 3.0).abs() < EPS);
        assert!((y.dv - 1.0).abs() < EPS);
    }

    #[test]
    fn test_dual_abs_negative() {
        let x = dual(-3.0, 1.0);
        let y = x.abs();
        assert!((y.v - 3.0).abs() < EPS);
        assert!((y.dv + 1.0).abs() < EPS);
    }

    #[test]
    fn test_dual_powi() {
        let x = dual(2.0, 1.0);
        // d/dx x^3 = 3x^2
        let y = x.powi(3);
        assert!((y.v - 8.0).abs() < EPS);
        assert!((y.dv - 12.0).abs() < EPS);
    }

    #[test]
    fn test_dual_powf() {
        let x = dual(4.0, 1.0);
        // d/dx x^2.5 = 2.5 x^1.5
        let y = x.powf(2.5);
        assert!((y.v - 4.0_f64.powf(2.5)).abs() < 1e-10);
        assert!((y.dv - 2.5 * 4.0_f64.powf(1.5)).abs() < 1e-10);
    }

    // --- grad1 ------------------------------------------------------------

    #[test]
    fn test_grad1_quadratic() {
        let f = |x: Dual| x * x + dual(3.0, 0.0) * x - 5.0;
        // f'(x) = 2x + 3
        assert!((grad1(f, 4.0) - 11.0).abs() < EPS);
    }

    #[test]
    fn test_grad1_sin() {
        let f = |x: Dual| x.sin();
        let expected = (PI / 6.0).cos();
        assert!((grad1(f, PI / 6.0) - expected).abs() < EPS);
    }

    #[test]
    fn test_grad1_chain_rule() {
        let f = |x: Dual| (x * x).exp();
        // f'(x) = 2x e^(x^2)
        let expected = 2.0 * E;
        assert!((grad1(f, 1.0) - expected).abs() < 1e-10);
    }

    #[test]
    fn test_grad1_constant_function() {
        let f = |_x: Dual| dual(42.0, 0.0);
        assert!(grad1(f, 1.0).abs() < EPS);
    }

    // --- hessian_diag -----------------------------------------------------

    #[test]
    fn test_hessian_diag_quadratic() {
        let f = |x: Dual| x * x;
        // f''(x) = 2 everywhere
        let xs = [1.0, 2.0, -3.0];
        let h = hessian_diag(f, &xs);
        for hi in &h {
            assert!((hi - 2.0).abs() < LOOSE, "expected 2, got {hi}");
        }
    }

    #[test]
    fn test_hessian_diag_sin() {
        let f = |x: Dual| x.sin();
        // f''(x) = -sin(x)
        let xs = [PI / 4.0];
        let h = hessian_diag(f, &xs);
        let expected = -(PI / 4.0).sin();
        assert!((h[0] - expected).abs() < LOOSE);
    }

    // --- jacobian_row / gradient -------------------------------------------

    #[test]
    fn test_jacobian_row_linear() {
        let f = |xs: &[Dual]| dual(2.0, 0.0) * xs[0] + dual(3.0, 0.0) * xs[1];
        // The partials of a linear map are its coefficients.
        let row = jacobian_row(f, &[1.0, 1.0]);
        assert!((row[0] - 2.0).abs() < EPS);
        assert!((row[1] - 3.0).abs() < EPS);
    }

    #[test]
    fn test_gradient_quadratic_surface() {
        let f = |xs: &[Dual]| xs[0] * xs[0] + xs[1] * xs[1];
        // grad = (2x, 2y)
        let g = gradient(f, &[3.0, 4.0]);
        assert!((g[0] - 6.0).abs() < EPS);
        assert!((g[1] - 8.0).abs() < EPS);
    }

    #[test]
    fn test_gradient_cross_term() {
        let f = |xs: &[Dual]| xs[0] * xs[1] * xs[2];
        // grad = (yz, xz, xy)
        let g = gradient(f, &[2.0, 3.0, 4.0]);
        assert!((g[0] - 12.0).abs() < EPS);
        assert!((g[1] - 8.0).abs() < EPS);
        assert!((g[2] - 6.0).abs() < EPS);
    }

    // --- newton_step --------------------------------------------------------

    #[test]
    fn test_newton_step_sqrt2() {
        let f = |x: Dual| x * x - 2.0;
        // The positive root of x^2 - 2 is sqrt(2).
        let mut x = 2.0_f64;
        for _ in 0..20 {
            x = newton_step(f, x);
        }
        assert!((x - 2.0_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn test_newton_step_cube_root() {
        let f = |x: Dual| x.powi(3) - 8.0;
        // The real root of x^3 - 8 is 2.
        let mut x = 3.0_f64;
        for _ in 0..30 {
            x = newton_step(f, x);
        }
        assert!((x - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_newton_step_zero_derivative_safe() {
        let f = |x: Dual| x * x;
        // f'(0) = 0: the step must not divide by zero.
        let x_new = newton_step(f, 0.0);
        assert!(x_new.is_finite());
    }

    // --- DualVec --------------------------------------------------------------

    #[test]
    fn test_dualvec_len() {
        let dv = DualVec::variable(&[1.0, 2.0, 3.0], 0);
        assert_eq!(dv.len(), 3);
    }

    #[test]
    fn test_dualvec_seed() {
        let dv = DualVec::variable(&[1.0, 2.0, 3.0], 1);
        assert!((dv.components[0].dv).abs() < EPS);
        assert!((dv.components[1].dv - 1.0).abs() < EPS);
        assert!((dv.components[2].dv).abs() < EPS);
    }

    #[test]
    fn test_dualvec_from_pairs() {
        let dv = DualVec::from_pairs(&[(1.0, 0.0), (2.0, 1.0)]);
        assert_eq!(dv.len(), 2);
        assert!((dv.components[1].dv - 1.0).abs() < EPS);
    }

    // --- TaylorExpand -----------------------------------------------------------

    #[test]
    fn test_taylor_expand_sin_at_zero() {
        let t = TaylorExpand::build(|x| x.sin(), 0.0);
        // Third-order expansion of sin around 0 is x - x^3 / 6.
        let x = 0.3;
        let approx = t.eval(x);
        let exact = x.sin();
        assert!(
            (approx - exact).abs() < 1e-4,
            "approx={approx}, exact={exact}"
        );
    }

    #[test]
    fn test_taylor_expand_exp_at_zero() {
        let t = TaylorExpand::build(|x| x.exp(), 0.0);
        // The truncation error at x = 0.5 is about x^4 / 24.
        let x = 0.5;
        let approx = t.eval(x);
        let exact = x.exp();
        assert!(
            (approx - exact).abs() < 1e-2,
            "approx={approx}, exact={exact}"
        );
    }

    #[test]
    fn test_taylor_expand_exact_at_center() {
        let t = TaylorExpand::build(|x| x * x + x, 2.0);
        let approx = t.eval(2.0);
        // At the centre only the constant term contributes: 2^2 + 2 = 6.
        assert!((approx - 6.0).abs() < 1e-9, "approx={approx}");
    }

    #[test]
    fn test_taylor_expand_f1_matches_grad1() {
        let f = |x: Dual| (x * x).sin();
        let x0 = 1.0;
        let t = TaylorExpand::build(f, x0);
        let g = grad1(f, x0);
        assert!((t.f1 - g).abs() < 1e-10);
    }

    // --- Constructors -------------------------------------------------------------

    #[test]
    fn test_dual_constant_zero_derivative() {
        let c = Dual::constant(7.0);
        assert!((c.v - 7.0).abs() < EPS);
        assert!(c.dv.abs() < EPS);
    }

    #[test]
    fn test_dual_variable_unit_derivative() {
        let v = Dual::variable(5.0);
        assert!((v.v - 5.0).abs() < EPS);
        assert!((v.dv - 1.0).abs() < EPS);
    }

    // --- Further elementary functions -----------------------------------------------

    #[test]
    fn test_dual_tan_derivative() {
        let x = dual(PI / 4.0, 1.0);
        // d/dx tan(x) = 1 / cos^2(x)
        let y = x.tan();
        let expected_dv = 1.0 / (PI / 4.0).cos().powi(2);
        assert!((y.dv - expected_dv).abs() < 1e-10);
    }

    #[test]
    fn test_dual_sinh_derivative() {
        let x = dual(1.0, 1.0);
        // d/dx sinh(x) = cosh(x)
        let y = x.sinh();
        assert!((y.dv - 1.0_f64.cosh()).abs() < EPS);
    }

    #[test]
    fn test_dual_cosh_derivative() {
        let x = dual(1.0, 1.0);
        // d/dx cosh(x) = sinh(x)
        let y = x.cosh();
        assert!((y.dv - 1.0_f64.sinh()).abs() < EPS);
    }

    // --- Scalar on the left-hand side --------------------------------------------------

    #[test]
    fn test_f64_add_dual() {
        let d = 3.0_f64 + dual(2.0, 1.0);
        assert!((d.v - 5.0).abs() < EPS);
        assert!((d.dv - 1.0).abs() < EPS);
    }

    #[test]
    fn test_f64_mul_dual() {
        let d = 5.0_f64 * dual(3.0, 1.0);
        assert!((d.v - 15.0).abs() < EPS);
        assert!((d.dv - 5.0).abs() < EPS);
    }

    // --- Composition ---------------------------------------------------------------------

    #[test]
    fn test_composed_sin_exp() {
        let f = |x: Dual| x.exp().sin();
        // d/dx sin(e^x) at 0 is cos(1) * 1.
        let expected = 1.0_f64.cos();
        assert!((grad1(f, 0.0) - expected).abs() < EPS);
    }

    #[test]
    fn test_composed_sqrt_ln() {
        let f = |x: Dual| x.ln().sqrt();
        // d/dx sqrt(ln x) at e is 1 / (2e).
        let expected = 1.0 / (2.0 * E);
        assert!((grad1(f, E) - expected).abs() < 1e-10);
    }

    #[test]
    fn test_div_f64() {
        let d = dual(6.0, 3.0) / 2.0;
        assert!((d.v - 3.0).abs() < EPS);
        assert!((d.dv - 1.5).abs() < EPS);
    }
}