1use core::fmt;
12use core::ops::{Add, Div, Mul, Neg, Sub};
13use num_traits::{Float, One, Zero};
14
/// A forward-mode dual number `real + dual·ε` with `ε² = 0`.
///
/// Pushing a dual number through arithmetic and the elementary functions
/// carries the derivative with respect to the seeded variable in `dual`.
///
/// The numeric bound (`T: Float`) lives on the `impl` blocks rather than on
/// the struct itself, so merely naming or deriving traits for the type does
/// not require it.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DualNumber<T> {
    /// The real (function value) component.
    pub real: T,
    /// The infinitesimal (derivative) component.
    pub dual: T,
}
43
44impl<T: Float> DualNumber<T> {
45 pub fn new(real: T, dual: T) -> Self {
47 Self { real, dual }
48 }
49
50 pub fn constant(value: T) -> Self {
54 Self {
55 real: value,
56 dual: T::zero(),
57 }
58 }
59
60 pub fn variable(value: T) -> Self {
64 Self {
65 real: value,
66 dual: T::one(),
67 }
68 }
69
70 pub fn value(&self) -> T {
72 self.real
73 }
74
75 pub fn derivative(&self) -> T {
77 self.dual
78 }
79
80 pub fn exp(self) -> Self {
84 let exp_real = self.real.exp();
85 Self {
86 real: exp_real,
87 dual: self.dual * exp_real,
88 }
89 }
90
91 pub fn ln(self) -> Self {
95 Self {
96 real: self.real.ln(),
97 dual: self.dual / self.real,
98 }
99 }
100
101 pub fn sin(self) -> Self {
105 Self {
106 real: self.real.sin(),
107 dual: self.dual * self.real.cos(),
108 }
109 }
110
111 pub fn cos(self) -> Self {
115 Self {
116 real: self.real.cos(),
117 dual: -self.dual * self.real.sin(),
118 }
119 }
120
121 pub fn tan(self) -> Self {
125 let tan_real = self.real.tan();
126 let cos_real = self.real.cos();
127 Self {
128 real: tan_real,
129 dual: self.dual / (cos_real * cos_real),
130 }
131 }
132
133 pub fn sqrt(self) -> Self {
137 let sqrt_real = self.real.sqrt();
138 Self {
139 real: sqrt_real,
140 dual: self.dual / (T::from(2.0).unwrap() * sqrt_real),
141 }
142 }
143
144 pub fn powf(self, n: T) -> Self {
148 let pow_real = self.real.powf(n);
149 Self {
150 real: pow_real,
151 dual: n * self.dual * self.real.powf(n - T::one()),
152 }
153 }
154
155 pub fn powi(self, n: i32) -> Self {
159 let n_float = T::from(n).unwrap();
160 let pow_real = self.real.powi(n);
161 Self {
162 real: pow_real,
163 dual: n_float * self.dual * self.real.powi(n - 1),
164 }
165 }
166
167 pub fn abs(self) -> Self {
171 let sign = if self.real >= T::zero() {
172 T::one()
173 } else {
174 -T::one()
175 };
176 Self {
177 real: self.real.abs(),
178 dual: self.dual * sign,
179 }
180 }
181
182 pub fn sinh(self) -> Self {
184 Self {
185 real: self.real.sinh(),
186 dual: self.dual * self.real.cosh(),
187 }
188 }
189
190 pub fn cosh(self) -> Self {
192 Self {
193 real: self.real.cosh(),
194 dual: self.dual * self.real.sinh(),
195 }
196 }
197
198 pub fn tanh(self) -> Self {
200 let tanh_real = self.real.tanh();
201 let cosh_real = self.real.cosh();
202 Self {
203 real: tanh_real,
204 dual: self.dual / (cosh_real * cosh_real),
205 }
206 }
207
208 pub fn max(self, other: Self) -> Self {
210 if self.real >= other.real {
211 self
212 } else {
213 other
214 }
215 }
216
217 pub fn min(self, other: Self) -> Self {
219 if self.real <= other.real {
220 self
221 } else {
222 other
223 }
224 }
225
226 pub fn sigmoid(self) -> Self {
230 let exp_neg = (-self.real).exp();
231 let sigmoid_real = T::one() / (T::one() + exp_neg);
232 let sigmoid_deriv = sigmoid_real * (T::one() - sigmoid_real);
233 Self {
234 real: sigmoid_real,
235 dual: self.dual * sigmoid_deriv,
236 }
237 }
238
239 pub fn apply_with_derivative<F, G>(self, f: F, df: G) -> Self
248 where
249 F: Fn(T) -> T,
250 G: Fn(T) -> T,
251 {
252 Self {
253 real: f(self.real),
254 dual: self.dual * df(self.real),
255 }
256 }
257}
258
259impl<T: Float> Add for DualNumber<T> {
268 type Output = Self;
269
270 fn add(self, other: Self) -> Self {
271 Self {
272 real: self.real + other.real,
273 dual: self.dual + other.dual,
274 }
275 }
276}
277
278impl<T: Float> Sub for DualNumber<T> {
279 type Output = Self;
280
281 fn sub(self, other: Self) -> Self {
282 Self {
283 real: self.real - other.real,
284 dual: self.dual - other.dual,
285 }
286 }
287}
288
289impl<T: Float> Mul for DualNumber<T> {
290 type Output = Self;
291
292 fn mul(self, other: Self) -> Self {
293 Self {
294 real: self.real * other.real,
295 dual: self.real * other.dual + self.dual * other.real,
296 }
297 }
298}
299
300impl<T: Float> Div for DualNumber<T> {
301 type Output = Self;
302
303 fn div(self, other: Self) -> Self {
304 let real = self.real / other.real;
305 let dual = (self.dual * other.real - self.real * other.dual) / (other.real * other.real);
306 Self { real, dual }
307 }
308}
309
310impl<T: Float> Neg for DualNumber<T> {
311 type Output = Self;
312
313 fn neg(self) -> Self {
314 Self {
315 real: -self.real,
316 dual: -self.dual,
317 }
318 }
319}
320
321impl<T: Float> Zero for DualNumber<T> {
322 fn zero() -> Self {
323 Self::constant(T::zero())
324 }
325
326 fn is_zero(&self) -> bool {
327 self.real.is_zero() && self.dual.is_zero()
328 }
329}
330
331impl<T: Float> One for DualNumber<T> {
332 fn one() -> Self {
333 Self::constant(T::one())
334 }
335}
336
337impl<T: Float + fmt::Display> fmt::Display for DualNumber<T> {
338 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
339 write!(f, "{} + {}ε", self.real, self.dual)
340 }
341}
342
343pub type StandardDual = DualNumber<f64>;
346
347#[cfg(feature = "high-precision")]
349pub type ExtendedDual = DualNumber<crate::ExtendedFloat>;
350
351pub type StandardMultiDual = MultiDualNumber<f64>;
353
354#[cfg(feature = "high-precision")]
356pub type ExtendedMultiDual = MultiDualNumber<crate::ExtendedFloat>;
357
/// A forward-mode dual number that carries a full gradient.
///
/// `gradient[i]` holds the partial derivative of `value` with respect to the
/// `i`-th seeded variable. The numeric bound (`T: Float`) lives on the `impl`
/// blocks rather than on the struct itself.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiDualNumber<T> {
    /// The real (function value) component.
    pub value: T,
    /// Partial derivatives, one entry per independent variable.
    pub gradient: Vec<T>,
}
387
388impl<T: Float> MultiDualNumber<T> {
389 pub fn new(value: T, gradient: Vec<T>) -> Self {
391 Self { value, gradient }
392 }
393
394 pub fn constant(value: T, n_vars: usize) -> Self {
396 Self {
397 value,
398 gradient: vec![T::zero(); n_vars],
399 }
400 }
401
402 pub fn variable(value: T, var_index: usize, n_vars: usize) -> Self {
404 let mut gradient = vec![T::zero(); n_vars];
405 gradient[var_index] = T::one();
406 Self { value, gradient }
407 }
408
409 pub fn n_vars(&self) -> usize {
411 self.gradient.len()
412 }
413
414 pub fn get_value(&self) -> T {
416 self.value
417 }
418
419 pub fn get_gradient(&self) -> &[T] {
421 &self.gradient
422 }
423}
424
425impl<T: Float> Add for MultiDualNumber<T> {
426 type Output = Self;
427
428 fn add(self, other: Self) -> Self {
429 assert_eq!(
430 self.gradient.len(),
431 other.gradient.len(),
432 "Gradient dimension mismatch"
433 );
434 let gradient = self
435 .gradient
436 .iter()
437 .zip(&other.gradient)
438 .map(|(&a, &b)| a + b)
439 .collect();
440 Self {
441 value: self.value + other.value,
442 gradient,
443 }
444 }
445}
446
447impl<T: Float> Sub for MultiDualNumber<T> {
448 type Output = Self;
449
450 fn sub(self, other: Self) -> Self {
451 assert_eq!(
452 self.gradient.len(),
453 other.gradient.len(),
454 "Gradient dimension mismatch"
455 );
456 let gradient = self
457 .gradient
458 .iter()
459 .zip(&other.gradient)
460 .map(|(&a, &b)| a - b)
461 .collect();
462 Self {
463 value: self.value - other.value,
464 gradient,
465 }
466 }
467}
468
469impl<T: Float> Mul for MultiDualNumber<T> {
470 type Output = Self;
471
472 fn mul(self, other: Self) -> Self {
473 assert_eq!(
474 self.gradient.len(),
475 other.gradient.len(),
476 "Gradient dimension mismatch"
477 );
478 let gradient = self
480 .gradient
481 .iter()
482 .zip(&other.gradient)
483 .map(|(&df, &dg)| df * other.value + self.value * dg)
484 .collect();
485 Self {
486 value: self.value * other.value,
487 gradient,
488 }
489 }
490}
491
492impl<T: Float> Div for MultiDualNumber<T> {
493 type Output = Self;
494
495 fn div(self, other: Self) -> Self {
496 assert_eq!(
497 self.gradient.len(),
498 other.gradient.len(),
499 "Gradient dimension mismatch"
500 );
501 let g_squared = other.value * other.value;
503 let gradient = self
504 .gradient
505 .iter()
506 .zip(&other.gradient)
507 .map(|(&df, &dg)| (df * other.value - self.value * dg) / g_squared)
508 .collect();
509 Self {
510 value: self.value / other.value,
511 gradient,
512 }
513 }
514}
515
516impl<T: Float> Neg for MultiDualNumber<T> {
517 type Output = Self;
518
519 fn neg(self) -> Self {
520 let gradient = self.gradient.iter().map(|&x| -x).collect();
521 Self {
522 value: -self.value,
523 gradient,
524 }
525 }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Tolerance for comparisons involving transcendental functions.
    const TOL: f64 = 1e-10;

    /// Asserts that `d` matches `(real, dual)` within `TOL`.
    fn assert_dual_close(d: DualNumber<f64>, real: f64, dual: f64) {
        assert_relative_eq!(d.real, real, epsilon = TOL);
        assert_relative_eq!(d.dual, dual, epsilon = TOL);
    }

    #[test]
    fn test_dual_number_creation() {
        let c = DualNumber::constant(5.0);
        assert_eq!((c.real, c.dual), (5.0, 0.0));

        let v = DualNumber::variable(3.0);
        assert_eq!((v.real, v.dual), (3.0, 1.0));
    }

    #[test]
    fn test_dual_number_arithmetic() {
        let x = DualNumber::variable(2.0);
        let c = DualNumber::constant(3.0);

        // d/dx (3x) = 3
        let scaled = c * x;
        assert_eq!((scaled.real, scaled.dual), (6.0, 3.0));

        // d/dx (x²) = 2x = 4 at x = 2
        let squared = x * x;
        assert_eq!((squared.real, squared.dual), (4.0, 4.0));
    }

    #[test]
    fn test_dual_number_division() {
        // d/dx (x / 2) = 1/2
        let q = DualNumber::variable(4.0) / DualNumber::constant(2.0);
        assert_eq!((q.real, q.dual), (2.0, 0.5));
    }

    #[test]
    fn test_dual_number_exp() {
        // e^0 = 1 and d/dx e^x = e^x = 1 at x = 0
        assert_dual_close(DualNumber::variable(0.0).exp(), 1.0, 1.0);
    }

    #[test]
    fn test_dual_number_ln() {
        // ln 1 = 0 and d/dx ln x = 1/x = 1 at x = 1
        assert_dual_close(DualNumber::variable(1.0).ln(), 0.0, 1.0);
    }

    #[test]
    fn test_dual_number_sin() {
        // sin 0 = 0 and d/dx sin x = cos x = 1 at x = 0
        assert_dual_close(DualNumber::variable(0.0).sin(), 0.0, 1.0);
    }

    #[test]
    fn test_dual_number_cos() {
        // cos 0 = 1 and d/dx cos x = -sin x = 0 at x = 0
        assert_dual_close(DualNumber::variable(0.0).cos(), 1.0, 0.0);
    }

    #[test]
    fn test_dual_number_sqrt() {
        // √4 = 2 and d/dx √x = 1/(2√x) = 0.25 at x = 4
        assert_dual_close(DualNumber::variable(4.0).sqrt(), 2.0, 0.25);
    }

    #[test]
    fn test_multi_dual_number() {
        let x = MultiDualNumber::variable(2.0, 0, 2);
        let y = MultiDualNumber::variable(3.0, 1, 2);

        // ∇(x + y) = (1, 1)
        let sum = x + y;
        assert_eq!(sum.value, 5.0);
        assert_eq!(sum.gradient[0], 1.0);
        assert_eq!(sum.gradient[1], 1.0);
    }

    #[test]
    fn test_multi_dual_number_product() {
        let x = MultiDualNumber::variable(2.0, 0, 2);
        let y = MultiDualNumber::variable(3.0, 1, 2);

        // ∇(x·y) = (y, x) = (3, 2)
        let product = x * y;
        assert_eq!(product.value, 6.0);
        assert_eq!(product.gradient[0], 3.0);
        assert_eq!(product.gradient[1], 2.0);
    }

    #[test]
    fn test_chain_rule() {
        // d/dx sin(x²) = cos(x²)·2x, evaluated at x = 1
        let x = DualNumber::variable(1.0);
        let result = (x * x).sin();
        assert_dual_close(result, 1.0_f64.sin(), 1.0_f64.cos() * 2.0);
    }
}