1#![cfg_attr(not(feature = "std"), no_std)]
8
9extern crate alloc;
10use alloc::vec::Vec;
11use core::ops::{Add, Div, Mul, Neg, Sub};
12use num_traits::{Float, One, Zero};
13
14#[cfg(feature = "high-precision")]
16pub use amari_core::HighPrecisionFloat;
17pub use amari_core::{ExtendedFloat, PrecisionFloat, StandardFloat};
18
19pub mod comprehensive_tests;
20pub mod error;
21pub mod functions;
22pub mod multivector;
23pub mod verified;
24pub mod verified_contracts;
25
26#[cfg(feature = "gpu")]
27pub mod gpu;
28
29pub use error::{DualError, DualResult};
31pub use multivector::{DualMultivector, MultiDualMultivector};
32
33#[cfg(feature = "gpu")]
35pub use gpu::{
36 DualGpuAccelerated, DualGpuContext, DualGpuError, DualGpuOps, DualGpuResult, DualOperation,
37 GpuDualNumber, GpuMultiDual, GpuOperationParams, GpuParameter, NeuralNetworkConfig,
38 ObjectiveFunction, VectorFunction,
39};
40
/// [`DualNumber`] over [`StandardFloat`] (re-exported from `amari_core`).
pub type StandardDual = DualNumber<StandardFloat>;

/// [`DualNumber`] over [`ExtendedFloat`] (re-exported from `amari_core`).
pub type ExtendedDual = DualNumber<ExtendedFloat>;

/// [`MultiDual`] over [`StandardFloat`] (re-exported from `amari_core`).
pub type StandardMultiDual = MultiDual<StandardFloat>;

/// [`MultiDual`] over [`ExtendedFloat`] (re-exported from `amari_core`).
pub type ExtendedMultiDual = MultiDual<ExtendedFloat>;
53
/// A real value paired with a vector of dual components, one per variable.
///
/// The reference `Add`/`Mul` impls assert that both operands carry the same
/// number of dual components.
#[derive(Clone, Debug, PartialEq)]
pub struct MultiDualNumber<T: Float> {
    /// The real (primal) part.
    pub real: T,
    /// Dual components; entry `i` holds the derivative with respect to variable `i`.
    pub duals: Vec<T>,
}
62
63impl<T: Float> MultiDualNumber<T> {
64 pub fn new(real: T, duals: Vec<T>) -> Self {
66 Self { real, duals }
67 }
68
69 pub fn variable(value: T, num_vars: usize, var_index: usize) -> Self {
71 let mut duals = vec![T::zero(); num_vars];
72 if var_index < num_vars {
73 duals[var_index] = T::one();
74 }
75 Self::new(value, duals)
76 }
77
78 pub fn constant(value: T, num_vars: usize) -> Self {
80 Self::new(value, vec![T::zero(); num_vars])
81 }
82
83 pub fn num_vars(&self) -> usize {
85 self.duals.len()
86 }
87
88 pub fn sqrt(&self) -> Self {
90 let sqrt_real = self.real.sqrt();
91 let sqrt_deriv = T::one() / (T::from(2.0).unwrap() * sqrt_real);
92
93 let mut new_duals = Vec::with_capacity(self.duals.len());
94 for &dual in &self.duals {
95 new_duals.push(dual * sqrt_deriv);
96 }
97
98 Self::new(sqrt_real, new_duals)
99 }
100}
101
102impl<T: Float> Add for &MultiDualNumber<T> {
103 type Output = MultiDualNumber<T>;
104
105 fn add(self, other: Self) -> Self::Output {
106 assert_eq!(self.duals.len(), other.duals.len());
107 let mut new_duals = Vec::with_capacity(self.duals.len());
108 for (a, b) in self.duals.iter().zip(other.duals.iter()) {
109 new_duals.push(*a + *b);
110 }
111 MultiDualNumber::new(self.real + other.real, new_duals)
112 }
113}
114
115impl<T: Float> Mul for &MultiDualNumber<T> {
116 type Output = MultiDualNumber<T>;
117
118 fn mul(self, other: Self) -> Self::Output {
119 assert_eq!(self.duals.len(), other.duals.len());
120 let mut new_duals = Vec::with_capacity(self.duals.len());
121 for (a, b) in self.duals.iter().zip(other.duals.iter()) {
122 new_duals.push(*a * other.real + self.real * *b);
123 }
124 MultiDualNumber::new(self.real * other.real, new_duals)
125 }
126}
127
128impl<T: Float> Add<&MultiDualNumber<T>> for MultiDualNumber<T> {
130 type Output = MultiDualNumber<T>;
131
132 fn add(self, other: &MultiDualNumber<T>) -> Self::Output {
133 &self + other
134 }
135}
136
137impl<T: Float> Mul<&MultiDualNumber<T>> for MultiDualNumber<T> {
138 type Output = MultiDualNumber<T>;
139
140 fn mul(self, other: &MultiDualNumber<T>) -> Self::Output {
141 &self * other
142 }
143}
144
/// A forward-mode dual number `real + dual·ε`, with `ε² = 0` (the `Mul` impl
/// drops the `dual·dual` term).
///
/// When built with `DualNumber::variable`, the dual part carries the derivative
/// of the real part with respect to that variable.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DualNumber<T: Float> {
    /// The real (primal) value.
    pub real: T,
    /// The dual (derivative) part.
    pub dual: T,
}
155
156impl<T: Float> DualNumber<T> {
157 pub fn new(real: T, dual: T) -> Self {
159 Self { real, dual }
160 }
161
162 pub fn variable(value: T) -> Self {
164 Self {
165 real: value,
166 dual: T::one(),
167 }
168 }
169
170 pub fn new_variable(value: T) -> Self {
172 Self::variable(value)
173 }
174
175 pub fn constant(value: T) -> Self {
177 Self {
178 real: value,
179 dual: T::zero(),
180 }
181 }
182
183 pub fn value(&self) -> T {
185 self.real
186 }
187
188 pub fn derivative(&self) -> T {
190 self.dual
191 }
192
193 pub fn apply_with_derivative<F, G>(&self, f: F, df: G) -> Self
195 where
196 F: Fn(T) -> T,
197 G: Fn(T) -> T,
198 {
199 Self {
200 real: f(self.real),
201 dual: df(self.real) * self.dual,
202 }
203 }
204
205 pub fn sin(self) -> Self {
207 self.apply_with_derivative(|x| x.sin(), |x| x.cos())
208 }
209
210 pub fn cos(self) -> Self {
212 self.apply_with_derivative(|x| x.cos(), |x| -x.sin())
213 }
214
215 pub fn exp(self) -> Self {
217 let exp_val = self.real.exp();
218 Self {
219 real: exp_val,
220 dual: exp_val * self.dual,
221 }
222 }
223
224 pub fn ln(self) -> Self {
226 Self {
227 real: self.real.ln(),
228 dual: self.dual / self.real,
229 }
230 }
231
232 pub fn powf(self, n: T) -> Self {
234 Self {
235 real: self.real.powf(n),
236 dual: n * self.real.powf(n - T::one()) * self.dual,
237 }
238 }
239
240 pub fn sqrt(self) -> Self {
242 let sqrt_val = self.real.sqrt();
243 Self {
244 real: sqrt_val,
245 dual: self.dual / (T::from(2.0).unwrap() * sqrt_val),
246 }
247 }
248
249 pub fn tanh(self) -> Self {
251 let tanh_val = self.real.tanh();
252 Self {
253 real: tanh_val,
254 dual: self.dual * (T::one() - tanh_val * tanh_val),
255 }
256 }
257
258 pub fn relu(self) -> Self {
260 if self.real > T::zero() {
261 self
262 } else {
263 Self::constant(T::zero())
264 }
265 }
266
267 pub fn sigmoid(self) -> Self {
269 let exp_neg_x = (-self.real).exp();
270 let sigmoid_val = T::one() / (T::one() + exp_neg_x);
271 Self {
272 real: sigmoid_val,
273 dual: self.dual * sigmoid_val * (T::one() - sigmoid_val),
274 }
275 }
276
277 pub fn softplus(self) -> Self {
279 let exp_x = self.real.exp();
280 Self {
281 real: (T::one() + exp_x).ln(),
282 dual: self.dual * exp_x / (T::one() + exp_x),
283 }
284 }
285
286 pub fn max(self, other: Self) -> Self {
288 if self.real >= other.real {
289 self
290 } else {
291 other
292 }
293 }
294
295 pub fn min(self, other: Self) -> Self {
297 if self.real <= other.real {
298 self
299 } else {
300 other
301 }
302 }
303
304 pub fn tan(self) -> Self {
306 let tan_val = self.real.tan();
307 let sec_squared = T::one() + tan_val * tan_val;
308 Self {
309 real: tan_val,
310 dual: self.dual * sec_squared,
311 }
312 }
313
314 pub fn sinh(self) -> Self {
316 let sinh_val = self.real.sinh();
317 Self {
318 real: sinh_val,
319 dual: self.dual * self.real.cosh(),
320 }
321 }
322
323 pub fn cosh(self) -> Self {
325 let cosh_val = self.real.cosh();
326 Self {
327 real: cosh_val,
328 dual: self.dual * self.real.sinh(),
329 }
330 }
331
332 pub fn powi(self, n: i32) -> Self {
334 if n == 0 {
335 return Self::new(T::one(), T::zero());
336 }
337 let real_result = self.real.powi(n);
338 let n_float = T::from(n).unwrap();
339 let dual_result = self.dual * n_float * self.real.powi(n - 1);
340 Self {
341 real: real_result,
342 dual: dual_result,
343 }
344 }
345}
346
347impl DualNumber<f32> {
349 pub const ZERO: Self = Self {
351 real: 0.0,
352 dual: 0.0,
353 };
354
355 pub const ONE: Self = Self {
357 real: 1.0,
358 dual: 0.0,
359 };
360
361 pub const fn new_variable_const(value: f32) -> Self {
363 Self {
364 real: value,
365 dual: 1.0,
366 }
367 }
368
369 pub const fn new_constant_const(value: f32) -> Self {
371 Self {
372 real: value,
373 dual: 0.0,
374 }
375 }
376}
377
378impl<T: Float> Add for DualNumber<T> {
380 type Output = Self;
381
382 fn add(self, other: Self) -> Self {
383 Self {
384 real: self.real + other.real,
385 dual: self.dual + other.dual,
386 }
387 }
388}
389
390impl<T: Float> Sub for DualNumber<T> {
391 type Output = Self;
392
393 fn sub(self, other: Self) -> Self {
394 Self {
395 real: self.real - other.real,
396 dual: self.dual - other.dual,
397 }
398 }
399}
400
401impl<T: Float> Mul for DualNumber<T> {
402 type Output = Self;
403
404 fn mul(self, other: Self) -> Self {
405 Self {
406 real: self.real * other.real,
407 dual: self.real * other.dual + self.dual * other.real,
408 }
410 }
411}
412
413impl<T: Float> Div for DualNumber<T> {
414 type Output = Self;
415
416 fn div(self, other: Self) -> Self {
417 let real_result = self.real / other.real;
418 let dual_result =
419 (self.dual * other.real - self.real * other.dual) / (other.real * other.real);
420
421 Self {
422 real: real_result,
423 dual: dual_result,
424 }
425 }
426}
427
428impl<T: Float> Neg for DualNumber<T> {
429 type Output = Self;
430
431 fn neg(self) -> Self {
432 Self {
433 real: -self.real,
434 dual: -self.dual,
435 }
436 }
437}
438
439impl<T: Float> Add<T> for DualNumber<T> {
441 type Output = Self;
442
443 fn add(self, scalar: T) -> Self {
444 Self {
445 real: self.real + scalar,
446 dual: self.dual,
447 }
448 }
449}
450
451impl<T: Float> Sub<T> for DualNumber<T> {
452 type Output = Self;
453
454 fn sub(self, scalar: T) -> Self {
455 Self {
456 real: self.real - scalar,
457 dual: self.dual,
458 }
459 }
460}
461
462impl<T: Float> Mul<T> for DualNumber<T> {
463 type Output = Self;
464
465 fn mul(self, scalar: T) -> Self {
466 Self {
467 real: self.real * scalar,
468 dual: self.dual * scalar,
469 }
470 }
471}
472
473impl<T: Float> Div<T> for DualNumber<T> {
474 type Output = Self;
475
476 fn div(self, scalar: T) -> Self {
477 Self {
478 real: self.real / scalar,
479 dual: self.dual / scalar,
480 }
481 }
482}
483
484impl<T: Float> Zero for DualNumber<T> {
485 fn zero() -> Self {
486 Self::constant(T::zero())
487 }
488
489 fn is_zero(&self) -> bool {
490 self.real.is_zero() && self.dual.is_zero()
491 }
492}
493
494impl<T: Float> One for DualNumber<T> {
495 fn one() -> Self {
496 Self::constant(T::one())
497 }
498}
499
/// A value together with its gradient, with one partial derivative per variable.
///
/// Unlike `MultiDualNumber`, the `Add`/`Mul` impls accept gradients of
/// different lengths and treat missing entries as zero.
#[derive(Clone, Debug)]
pub struct MultiDual<T: Float> {
    /// The function value.
    pub value: T,
    /// Partial derivatives; entry `i` is the derivative with respect to variable `i`.
    pub gradient: Vec<T>,
}
508
509impl<T: Float> MultiDual<T> {
510 pub fn new(value: T, gradient: Vec<T>) -> Self {
512 Self { value, gradient }
513 }
514
515 pub fn variable(value: T, index: usize, n_vars: usize) -> Self {
517 let mut gradient = Vec::with_capacity(n_vars);
518 for _ in 0..n_vars {
519 gradient.push(T::zero());
520 }
521 gradient[index] = T::one();
522 Self { value, gradient }
523 }
524
525 pub fn constant(value: T, n_vars: usize) -> Self {
527 Self {
528 value,
529 gradient: {
530 let mut g = Vec::with_capacity(n_vars);
531 for _ in 0..n_vars {
532 g.push(T::zero());
533 }
534 g
535 },
536 }
537 }
538
539 pub fn partial(&self, index: usize) -> T {
541 self.gradient.get(index).copied().unwrap_or(T::zero())
542 }
543
544 pub fn gradient_norm(&self) -> T {
546 self.gradient
547 .iter()
548 .map(|&x| x * x)
549 .fold(T::zero(), |acc, x| acc + x)
550 .sqrt()
551 }
552}
553
554impl<T: Float> Add for MultiDual<T> {
555 type Output = Self;
556
557 fn add(self, other: Self) -> Self {
558 let mut gradient = Vec::with_capacity(self.gradient.len().max(other.gradient.len()));
559 for i in 0..gradient.capacity() {
560 let a = self.gradient.get(i).copied().unwrap_or(T::zero());
561 let b = other.gradient.get(i).copied().unwrap_or(T::zero());
562 gradient.push(a + b);
563 }
564
565 Self {
566 value: self.value + other.value,
567 gradient,
568 }
569 }
570}
571
572impl<T: Float> Mul for MultiDual<T> {
573 type Output = Self;
574
575 fn mul(self, other: Self) -> Self {
576 let mut gradient = Vec::with_capacity(self.gradient.len().max(other.gradient.len()));
577 for i in 0..gradient.capacity() {
578 let a_grad = self.gradient.get(i).copied().unwrap_or(T::zero());
579 let b_grad = other.gradient.get(i).copied().unwrap_or(T::zero());
580 gradient.push(self.value * b_grad + a_grad * other.value);
581 }
582
583 Self {
584 value: self.value * other.value,
585 gradient,
586 }
587 }
588}
589
/// Holds the point (a list of variable values) at which
/// `AutoDiffContext::eval_gradient` evaluates a multivariate function.
pub struct AutoDiffContext<T: Float> {
    /// Registered variables in insertion order; `eval_gradient` reads only
    /// their real parts.
    variables: Vec<DualNumber<T>>,
    /// Expected number of variables; used only as a capacity hint.
    n_vars: usize,
}
595
596impl<T: Float> AutoDiffContext<T> {
597 pub fn new(n_vars: usize) -> Self {
599 Self {
600 variables: Vec::with_capacity(n_vars),
601 n_vars,
602 }
603 }
604
605 pub fn add_variable(&mut self, value: T) -> usize {
607 let index = self.variables.len();
608 self.variables.push(DualNumber::variable(value));
609 index
610 }
611
612 pub fn eval_gradient<F>(&self, f: F) -> (T, Vec<T>)
614 where
615 F: Fn(&[DualNumber<T>]) -> DualNumber<T>,
616 {
617 let mut gradient = Vec::with_capacity(self.n_vars);
618 let mut value = T::zero();
619
620 for (i, _var) in self.variables.iter().enumerate() {
621 let mut inputs = Vec::with_capacity(self.variables.len());
623 for _ in 0..self.variables.len() {
624 inputs.push(DualNumber::constant(T::zero()));
625 }
626 for (j, &v) in self.variables.iter().enumerate() {
627 inputs[j] = if i == j {
628 DualNumber::variable(v.real)
629 } else {
630 DualNumber::constant(v.real)
631 };
632 }
633
634 let result = f(&inputs);
635 if i == 0 {
636 value = result.real;
637 }
638 gradient.push(result.dual);
639 }
640
641 (value, gradient)
642 }
643}
644
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use approx::assert_relative_eq;

    #[test]
    fn test_dual_arithmetic() {
        let x = DualNumber::variable(2.0);
        let y = DualNumber::variable(3.0);

        // Both operands are seeded variables, so their dual parts add: 1 + 1.
        let sum = x + y;
        assert_eq!(sum.real, 5.0);
        assert_eq!(sum.dual, 2.0);

        // Scaling by a plain scalar scales both parts.
        let product = x * 3.0;
        assert_eq!(product.real, 6.0);
        assert_eq!(product.dual, 3.0);
    }

    #[test]
    fn test_chain_rule() {
        let x = DualNumber::variable(2.0);

        // d/dx sin(x²) = 2x·cos(x²).
        let result = (x * x).sin();
        // Explicit `f64` suffix: calling a method on an unsuffixed float
        // literal is ambiguous (E0689).
        let expected_derivative = 2.0 * 2.0 * (2.0_f64 * 2.0).cos();
        assert_relative_eq!(result.real, (2.0_f64 * 2.0).sin(), epsilon = 1e-10);
        assert_relative_eq!(result.dual, expected_derivative, epsilon = 1e-10);
    }

    #[test]
    fn test_exp_and_ln() {
        let x = DualNumber::variable(1.0);

        let exp_result = x.exp();
        assert_relative_eq!(exp_result.real, 1.0f64.exp(), epsilon = 1e-10);
        assert_relative_eq!(exp_result.dual, 1.0f64.exp(), epsilon = 1e-10);

        let ln_result = x.ln();
        assert_relative_eq!(ln_result.real, 1.0f64.ln(), epsilon = 1e-10);
        assert_relative_eq!(ln_result.dual, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_activation_functions() {
        let x = DualNumber::variable(1.0);

        let relu_result = x.relu();
        assert_eq!(relu_result.real, 1.0);
        assert_eq!(relu_result.dual, 1.0);

        let x_neg = DualNumber::variable(-1.0);
        let relu_neg = x_neg.relu();
        assert_eq!(relu_neg.real, 0.0);
        assert_eq!(relu_neg.dual, 0.0);

        let sigmoid_result = x.sigmoid();
        let expected_sigmoid = 1.0 / (1.0 + (-1.0f64).exp());
        assert_relative_eq!(sigmoid_result.real, expected_sigmoid, epsilon = 1e-10);

        let expected_derivative = expected_sigmoid * (1.0 - expected_sigmoid);
        assert_relative_eq!(sigmoid_result.dual, expected_derivative, epsilon = 1e-10);
    }

    #[test]
    fn test_multi_dual() {
        let x = MultiDual::variable(2.0, 0, 2);
        let y = MultiDual::variable(3.0, 1, 2);
        let x_squared = MultiDual::new(x.value * x.value, vec![2.0 * x.value, 0.0]);
        let xy = x.clone() * y.clone();
        let result = xy + x_squared;

        // f = xy + x² at (2, 3): f = 10, ∂f/∂x = y + 2x = 7, ∂f/∂y = x = 2.
        assert_eq!(result.value, 10.0);
        assert_eq!(result.partial(0), 7.0);
        assert_eq!(result.partial(1), 2.0);
    }

    #[test]
    fn test_multi_dual_mismatched_lengths() {
        // The shorter gradient is treated as zero-padded.
        let a = MultiDual::new(1.0, vec![1.0]);
        let b = MultiDual::new(2.0, vec![1.0, 2.0]);
        let sum = a + b;
        assert_eq!(sum.value, 3.0);
        assert_eq!(sum.gradient, vec![2.0, 2.0]);
    }

    #[test]
    fn test_multi_dual_number_arithmetic() {
        let x = MultiDualNumber::variable(2.0, 2, 0);
        let y = MultiDualNumber::variable(3.0, 2, 1);

        let product = &x * &y;
        assert_eq!(product.real, 6.0);
        assert_eq!(product.duals, vec![3.0, 2.0]);

        let sum = x + &y;
        assert_eq!(sum.real, 5.0);
        assert_eq!(sum.duals, vec![1.0, 1.0]);
    }

    #[test]
    fn test_autodiff_context() {
        let mut ctx = AutoDiffContext::new(2);
        ctx.add_variable(2.0);
        ctx.add_variable(3.0);
        let (value, grad) = ctx.eval_gradient(|vars| {
            let x = vars[0];
            let y = vars[1];
            x * y + x * x
        });

        // f = xy + x² at (2, 3): ∂f/∂x = y + 2x = 7, ∂f/∂y = x = 2.
        assert_eq!(value, 10.0);
        assert_eq!(grad.len(), 2);
        assert_eq!(grad[0], 7.0);
        assert_eq!(grad[1], 2.0);
    }
}