amari_dual/
lib.rs

1//! Dual number automatic differentiation for efficient gradient computation
2//!
3//! Dual numbers extend real numbers with an infinitesimal unit ε where ε² = 0.
4//! This allows for exact computation of derivatives without numerical approximation
5//! or computational graphs, making it ideal for forward-mode automatic differentiation.
6
7#![cfg_attr(not(feature = "std"), no_std)]
8
9extern crate alloc;
10use alloc::vec::Vec;
11use core::ops::{Add, Div, Mul, Neg, Sub};
12use num_traits::{Float, One, Zero};
13
14// Import precision types from amari-core
15#[cfg(feature = "high-precision")]
16pub use amari_core::HighPrecisionFloat;
17pub use amari_core::{ExtendedFloat, PrecisionFloat, StandardFloat};
18
19pub mod comprehensive_tests;
20pub mod error;
21pub mod functions;
22pub mod multivector;
23pub mod verified;
24pub mod verified_contracts;
25
26#[cfg(feature = "gpu")]
27pub mod gpu;
28
29// Re-export commonly used types
30pub use error::{DualError, DualResult};
31pub use multivector::{DualMultivector, MultiDualMultivector};
32
33// GPU acceleration exports
34#[cfg(feature = "gpu")]
35pub use gpu::{
36    DualGpuAccelerated, DualGpuContext, DualGpuError, DualGpuOps, DualGpuResult, DualOperation,
37    GpuDualNumber, GpuMultiDual, GpuOperationParams, GpuParameter, NeuralNetworkConfig,
38    ObjectiveFunction, VectorFunction,
39};
40
41// Precision-aware type aliases for dual numbers
42/// Standard precision dual number using f64
43pub type StandardDual = DualNumber<StandardFloat>;
44
45/// Extended precision dual number - uses high precision when available
46pub type ExtendedDual = DualNumber<ExtendedFloat>;
47
48/// Standard precision multi-dual number using f64
49pub type StandardMultiDual = MultiDual<StandardFloat>;
50
51/// Extended precision multi-dual number - uses high precision when available
52pub type ExtendedMultiDual = MultiDual<ExtendedFloat>;
53
54/// Multi-variable dual number for computing gradients
55#[derive(Clone, Debug, PartialEq)]
56pub struct MultiDualNumber<T: Float> {
57    /// Function value
58    pub real: T,
59    /// Partial derivatives (gradient)
60    pub duals: Vec<T>,
61}
62
63impl<T: Float> MultiDualNumber<T> {
64    /// Create a new multi-dual number
65    pub fn new(real: T, duals: Vec<T>) -> Self {
66        Self { real, duals }
67    }
68
69    /// Create a variable with derivative 1 at the specified index
70    pub fn variable(value: T, num_vars: usize, var_index: usize) -> Self {
71        let mut duals = vec![T::zero(); num_vars];
72        if var_index < num_vars {
73            duals[var_index] = T::one();
74        }
75        Self::new(value, duals)
76    }
77
78    /// Create a constant (all derivatives are zero)
79    pub fn constant(value: T, num_vars: usize) -> Self {
80        Self::new(value, vec![T::zero(); num_vars])
81    }
82
83    /// Get the number of variables
84    pub fn num_vars(&self) -> usize {
85        self.duals.len()
86    }
87
88    /// Square root function
89    pub fn sqrt(&self) -> Self {
90        let sqrt_real = self.real.sqrt();
91        let sqrt_deriv = T::one() / (T::from(2.0).unwrap() * sqrt_real);
92
93        let mut new_duals = Vec::with_capacity(self.duals.len());
94        for &dual in &self.duals {
95            new_duals.push(dual * sqrt_deriv);
96        }
97
98        Self::new(sqrt_real, new_duals)
99    }
100}
101
102impl<T: Float> Add for &MultiDualNumber<T> {
103    type Output = MultiDualNumber<T>;
104
105    fn add(self, other: Self) -> Self::Output {
106        assert_eq!(self.duals.len(), other.duals.len());
107        let mut new_duals = Vec::with_capacity(self.duals.len());
108        for (a, b) in self.duals.iter().zip(other.duals.iter()) {
109            new_duals.push(*a + *b);
110        }
111        MultiDualNumber::new(self.real + other.real, new_duals)
112    }
113}
114
115impl<T: Float> Mul for &MultiDualNumber<T> {
116    type Output = MultiDualNumber<T>;
117
118    fn mul(self, other: Self) -> Self::Output {
119        assert_eq!(self.duals.len(), other.duals.len());
120        let mut new_duals = Vec::with_capacity(self.duals.len());
121        for (a, b) in self.duals.iter().zip(other.duals.iter()) {
122            new_duals.push(*a * other.real + self.real * *b);
123        }
124        MultiDualNumber::new(self.real * other.real, new_duals)
125    }
126}
127
128// Add missing combinations for owned + reference
129impl<T: Float> Add<&MultiDualNumber<T>> for MultiDualNumber<T> {
130    type Output = MultiDualNumber<T>;
131
132    fn add(self, other: &MultiDualNumber<T>) -> Self::Output {
133        &self + other
134    }
135}
136
137impl<T: Float> Mul<&MultiDualNumber<T>> for MultiDualNumber<T> {
138    type Output = MultiDualNumber<T>;
139
140    fn mul(self, other: &MultiDualNumber<T>) -> Self::Output {
141        &self * other
142    }
143}
144
145/// A dual number: a + bε where ε² = 0
146///
147/// The real part stores the function value, the dual part stores the derivative.
148#[derive(Clone, Copy, Debug, PartialEq)]
149pub struct DualNumber<T: Float> {
150    /// Real part (function value)
151    pub real: T,
152    /// Dual part (derivative with respect to input variable)
153    pub dual: T,
154}
155
156impl<T: Float> DualNumber<T> {
157    /// Create a new dual number
158    pub fn new(real: T, dual: T) -> Self {
159        Self { real, dual }
160    }
161
162    /// Create a variable (derivative = 1)
163    pub fn variable(value: T) -> Self {
164        Self {
165            real: value,
166            dual: T::one(),
167        }
168    }
169
170    /// Create a variable (derivative = 1) - alias for consistency
171    pub fn new_variable(value: T) -> Self {
172        Self::variable(value)
173    }
174
175    /// Create a constant (derivative = 0)
176    pub fn constant(value: T) -> Self {
177        Self {
178            real: value,
179            dual: T::zero(),
180        }
181    }
182
183    /// Get the value (real part)
184    pub fn value(&self) -> T {
185        self.real
186    }
187
188    /// Get the derivative (dual part)
189    pub fn derivative(&self) -> T {
190        self.dual
191    }
192
193    /// Apply a function with known derivative
194    pub fn apply_with_derivative<F, G>(&self, f: F, df: G) -> Self
195    where
196        F: Fn(T) -> T,
197        G: Fn(T) -> T,
198    {
199        Self {
200            real: f(self.real),
201            dual: df(self.real) * self.dual,
202        }
203    }
204
205    /// Sine function
206    pub fn sin(self) -> Self {
207        self.apply_with_derivative(|x| x.sin(), |x| x.cos())
208    }
209
210    /// Cosine function
211    pub fn cos(self) -> Self {
212        self.apply_with_derivative(|x| x.cos(), |x| -x.sin())
213    }
214
215    /// Exponential function
216    pub fn exp(self) -> Self {
217        let exp_val = self.real.exp();
218        Self {
219            real: exp_val,
220            dual: exp_val * self.dual,
221        }
222    }
223
224    /// Natural logarithm
225    pub fn ln(self) -> Self {
226        Self {
227            real: self.real.ln(),
228            dual: self.dual / self.real,
229        }
230    }
231
232    /// Power function
233    pub fn powf(self, n: T) -> Self {
234        Self {
235            real: self.real.powf(n),
236            dual: n * self.real.powf(n - T::one()) * self.dual,
237        }
238    }
239
240    /// Square root
241    pub fn sqrt(self) -> Self {
242        let sqrt_val = self.real.sqrt();
243        Self {
244            real: sqrt_val,
245            dual: self.dual / (T::from(2.0).unwrap() * sqrt_val),
246        }
247    }
248
249    /// Hyperbolic tangent
250    pub fn tanh(self) -> Self {
251        let tanh_val = self.real.tanh();
252        Self {
253            real: tanh_val,
254            dual: self.dual * (T::one() - tanh_val * tanh_val),
255        }
256    }
257
258    /// ReLU activation function
259    pub fn relu(self) -> Self {
260        if self.real > T::zero() {
261            self
262        } else {
263            Self::constant(T::zero())
264        }
265    }
266
267    /// Sigmoid activation function
268    pub fn sigmoid(self) -> Self {
269        let exp_neg_x = (-self.real).exp();
270        let sigmoid_val = T::one() / (T::one() + exp_neg_x);
271        Self {
272            real: sigmoid_val,
273            dual: self.dual * sigmoid_val * (T::one() - sigmoid_val),
274        }
275    }
276
277    /// Softplus activation function
278    pub fn softplus(self) -> Self {
279        let exp_x = self.real.exp();
280        Self {
281            real: (T::one() + exp_x).ln(),
282            dual: self.dual * exp_x / (T::one() + exp_x),
283        }
284    }
285
286    /// Maximum of two dual numbers
287    pub fn max(self, other: Self) -> Self {
288        if self.real >= other.real {
289            self
290        } else {
291            other
292        }
293    }
294
295    /// Minimum of two dual numbers
296    pub fn min(self, other: Self) -> Self {
297        if self.real <= other.real {
298            self
299        } else {
300            other
301        }
302    }
303
304    /// Tangent function
305    pub fn tan(self) -> Self {
306        let tan_val = self.real.tan();
307        let sec_squared = T::one() + tan_val * tan_val;
308        Self {
309            real: tan_val,
310            dual: self.dual * sec_squared,
311        }
312    }
313
314    /// Hyperbolic sine
315    pub fn sinh(self) -> Self {
316        let sinh_val = self.real.sinh();
317        Self {
318            real: sinh_val,
319            dual: self.dual * self.real.cosh(),
320        }
321    }
322
323    /// Hyperbolic cosine
324    pub fn cosh(self) -> Self {
325        let cosh_val = self.real.cosh();
326        Self {
327            real: cosh_val,
328            dual: self.dual * self.real.sinh(),
329        }
330    }
331
332    /// Integer power
333    pub fn powi(self, n: i32) -> Self {
334        if n == 0 {
335            return Self::new(T::one(), T::zero());
336        }
337        let real_result = self.real.powi(n);
338        let n_float = T::from(n).unwrap();
339        let dual_result = self.dual * n_float * self.real.powi(n - 1);
340        Self {
341            real: real_result,
342            dual: dual_result,
343        }
344    }
345}
346
347// GPU-compatible constants for f32 dual numbers
348impl DualNumber<f32> {
349    /// Zero dual number (0 + 0ε)
350    pub const ZERO: Self = Self {
351        real: 0.0,
352        dual: 0.0,
353    };
354
355    /// One dual number (1 + 0ε)
356    pub const ONE: Self = Self {
357        real: 1.0,
358        dual: 0.0,
359    };
360
361    /// Variable dual number (value + 1ε) - useful for GPU operations
362    pub const fn new_variable_const(value: f32) -> Self {
363        Self {
364            real: value,
365            dual: 1.0,
366        }
367    }
368
369    /// Constant dual number (value + 0ε) - useful for GPU operations
370    pub const fn new_constant_const(value: f32) -> Self {
371        Self {
372            real: value,
373            dual: 0.0,
374        }
375    }
376}
377
378// Arithmetic operations for dual numbers
379impl<T: Float> Add for DualNumber<T> {
380    type Output = Self;
381
382    fn add(self, other: Self) -> Self {
383        Self {
384            real: self.real + other.real,
385            dual: self.dual + other.dual,
386        }
387    }
388}
389
390impl<T: Float> Sub for DualNumber<T> {
391    type Output = Self;
392
393    fn sub(self, other: Self) -> Self {
394        Self {
395            real: self.real - other.real,
396            dual: self.dual - other.dual,
397        }
398    }
399}
400
401impl<T: Float> Mul for DualNumber<T> {
402    type Output = Self;
403
404    fn mul(self, other: Self) -> Self {
405        Self {
406            real: self.real * other.real,
407            dual: self.real * other.dual + self.dual * other.real,
408            // ε² = 0, so dual * dual term vanishes
409        }
410    }
411}
412
413impl<T: Float> Div for DualNumber<T> {
414    type Output = Self;
415
416    fn div(self, other: Self) -> Self {
417        let real_result = self.real / other.real;
418        let dual_result =
419            (self.dual * other.real - self.real * other.dual) / (other.real * other.real);
420
421        Self {
422            real: real_result,
423            dual: dual_result,
424        }
425    }
426}
427
428impl<T: Float> Neg for DualNumber<T> {
429    type Output = Self;
430
431    fn neg(self) -> Self {
432        Self {
433            real: -self.real,
434            dual: -self.dual,
435        }
436    }
437}
438
439// Scalar operations
440impl<T: Float> Add<T> for DualNumber<T> {
441    type Output = Self;
442
443    fn add(self, scalar: T) -> Self {
444        Self {
445            real: self.real + scalar,
446            dual: self.dual,
447        }
448    }
449}
450
451impl<T: Float> Sub<T> for DualNumber<T> {
452    type Output = Self;
453
454    fn sub(self, scalar: T) -> Self {
455        Self {
456            real: self.real - scalar,
457            dual: self.dual,
458        }
459    }
460}
461
462impl<T: Float> Mul<T> for DualNumber<T> {
463    type Output = Self;
464
465    fn mul(self, scalar: T) -> Self {
466        Self {
467            real: self.real * scalar,
468            dual: self.dual * scalar,
469        }
470    }
471}
472
473impl<T: Float> Div<T> for DualNumber<T> {
474    type Output = Self;
475
476    fn div(self, scalar: T) -> Self {
477        Self {
478            real: self.real / scalar,
479            dual: self.dual / scalar,
480        }
481    }
482}
483
484impl<T: Float> Zero for DualNumber<T> {
485    fn zero() -> Self {
486        Self::constant(T::zero())
487    }
488
489    fn is_zero(&self) -> bool {
490        self.real.is_zero() && self.dual.is_zero()
491    }
492}
493
494impl<T: Float> One for DualNumber<T> {
495    fn one() -> Self {
496        Self::constant(T::one())
497    }
498}
499
500/// Multi-variable dual number for partial derivatives
501#[derive(Clone, Debug)]
502pub struct MultiDual<T: Float> {
503    /// Function value
504    pub value: T,
505    /// Partial derivatives (gradient)
506    pub gradient: Vec<T>,
507}
508
509impl<T: Float> MultiDual<T> {
510    /// Create new multi-dual number
511    pub fn new(value: T, gradient: Vec<T>) -> Self {
512        Self { value, gradient }
513    }
514
515    /// Create variable at given index
516    pub fn variable(value: T, index: usize, n_vars: usize) -> Self {
517        let mut gradient = Vec::with_capacity(n_vars);
518        for _ in 0..n_vars {
519            gradient.push(T::zero());
520        }
521        gradient[index] = T::one();
522        Self { value, gradient }
523    }
524
525    /// Create constant
526    pub fn constant(value: T, n_vars: usize) -> Self {
527        Self {
528            value,
529            gradient: {
530                let mut g = Vec::with_capacity(n_vars);
531                for _ in 0..n_vars {
532                    g.push(T::zero());
533                }
534                g
535            },
536        }
537    }
538
539    /// Get partial derivative at index
540    pub fn partial(&self, index: usize) -> T {
541        self.gradient.get(index).copied().unwrap_or(T::zero())
542    }
543
544    /// Compute norm of gradient (for optimization)
545    pub fn gradient_norm(&self) -> T {
546        self.gradient
547            .iter()
548            .map(|&x| x * x)
549            .fold(T::zero(), |acc, x| acc + x)
550            .sqrt()
551    }
552}
553
554impl<T: Float> Add for MultiDual<T> {
555    type Output = Self;
556
557    fn add(self, other: Self) -> Self {
558        let mut gradient = Vec::with_capacity(self.gradient.len().max(other.gradient.len()));
559        for i in 0..gradient.capacity() {
560            let a = self.gradient.get(i).copied().unwrap_or(T::zero());
561            let b = other.gradient.get(i).copied().unwrap_or(T::zero());
562            gradient.push(a + b);
563        }
564
565        Self {
566            value: self.value + other.value,
567            gradient,
568        }
569    }
570}
571
572impl<T: Float> Mul for MultiDual<T> {
573    type Output = Self;
574
575    fn mul(self, other: Self) -> Self {
576        let mut gradient = Vec::with_capacity(self.gradient.len().max(other.gradient.len()));
577        for i in 0..gradient.capacity() {
578            let a_grad = self.gradient.get(i).copied().unwrap_or(T::zero());
579            let b_grad = other.gradient.get(i).copied().unwrap_or(T::zero());
580            gradient.push(self.value * b_grad + a_grad * other.value);
581        }
582
583        Self {
584            value: self.value * other.value,
585            gradient,
586        }
587    }
588}
589
590/// Automatic differentiation context
591pub struct AutoDiffContext<T: Float> {
592    variables: Vec<DualNumber<T>>,
593    n_vars: usize,
594}
595
596impl<T: Float> AutoDiffContext<T> {
597    /// Create new context with n variables
598    pub fn new(n_vars: usize) -> Self {
599        Self {
600            variables: Vec::with_capacity(n_vars),
601            n_vars,
602        }
603    }
604
605    /// Add variable to context
606    pub fn add_variable(&mut self, value: T) -> usize {
607        let index = self.variables.len();
608        self.variables.push(DualNumber::variable(value));
609        index
610    }
611
612    /// Evaluate function and get all partial derivatives
613    pub fn eval_gradient<F>(&self, f: F) -> (T, Vec<T>)
614    where
615        F: Fn(&[DualNumber<T>]) -> DualNumber<T>,
616    {
617        let mut gradient = Vec::with_capacity(self.n_vars);
618        let mut value = T::zero();
619
620        for (i, _var) in self.variables.iter().enumerate() {
621            // Set up dual number for i-th partial derivative
622            let mut inputs = Vec::with_capacity(self.variables.len());
623            for _ in 0..self.variables.len() {
624                inputs.push(DualNumber::constant(T::zero()));
625            }
626            for (j, &v) in self.variables.iter().enumerate() {
627                inputs[j] = if i == j {
628                    DualNumber::variable(v.real)
629                } else {
630                    DualNumber::constant(v.real)
631                };
632            }
633
634            let result = f(&inputs);
635            if i == 0 {
636                value = result.real;
637            }
638            gradient.push(result.dual);
639        }
640
641        (value, gradient)
642    }
643}
644
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use approx::assert_relative_eq;

    /// Sum of two seeded duals and product with a plain scalar.
    #[test]
    fn test_dual_arithmetic() {
        let x = DualNumber::variable(2.0);
        let y = DualNumber::variable(3.0);

        // Both operands carry a unit seed, so the dual parts add up to 2.
        let sum = x + y;
        assert_eq!(sum.real, 5.0);
        assert_eq!(sum.dual, 2.0);

        // Test multiplication: d/dx(x * 3) = 3
        let product = x * 3.0;
        assert_eq!(product.real, 6.0);
        assert_eq!(product.dual, 3.0);
    }

    /// Chain rule through a composed function.
    #[test]
    fn test_chain_rule() {
        let x = DualNumber::variable(2.0);

        // Test sin(x^2): derivative should be 2x*cos(x^2)
        let result = (x * x).sin();
        let expected_derivative = 2.0 * 2.0 * (2.0 * 2.0).cos(); // 2x * cos(x^2) at x=2

        assert_relative_eq!(result.real, (2.0 * 2.0).sin(), epsilon = 1e-10);
        assert_relative_eq!(result.dual, expected_derivative, epsilon = 1e-10);
    }

    /// Exponential and logarithm values and derivatives at x = 1.
    #[test]
    fn test_exp_and_ln() {
        let x = DualNumber::variable(1.0);

        // Test exp(x): derivative should be exp(x)
        let exp_result = x.exp();
        assert_relative_eq!(exp_result.real, 1.0f64.exp(), epsilon = 1e-10);
        assert_relative_eq!(exp_result.dual, 1.0f64.exp(), epsilon = 1e-10);

        // Test ln(x): derivative should be 1/x
        let ln_result = x.ln();
        assert_relative_eq!(ln_result.real, 1.0f64.ln(), epsilon = 1e-10);
        assert_relative_eq!(ln_result.dual, 1.0, epsilon = 1e-10);
    }

    /// ReLU on both sides of zero, and sigmoid with its derivative.
    #[test]
    fn test_activation_functions() {
        let x = DualNumber::variable(1.0);

        // Test ReLU
        let relu_result = x.relu();
        assert_eq!(relu_result.real, 1.0);
        assert_eq!(relu_result.dual, 1.0);

        let x_neg = DualNumber::variable(-1.0);
        let relu_neg = x_neg.relu();
        assert_eq!(relu_neg.real, 0.0);
        assert_eq!(relu_neg.dual, 0.0);

        // Test sigmoid
        let sigmoid_result = x.sigmoid();
        let expected_sigmoid = 1.0 / (1.0 + (-1.0f64).exp());
        assert_relative_eq!(sigmoid_result.real, expected_sigmoid, epsilon = 1e-10);

        // Sigmoid derivative: sigmoid(x) * (1 - sigmoid(x))
        let expected_derivative = expected_sigmoid * (1.0 - expected_sigmoid);
        assert_relative_eq!(sigmoid_result.dual, expected_derivative, epsilon = 1e-10);
    }

    /// Gradient of f(x,y) = x*y + x^2 through `MultiDual` arithmetic.
    #[test]
    fn test_multi_dual() {
        // Test f(x,y) = x*y + x^2
        let x = MultiDual::variable(2.0, 0, 2); // Variable 0 of 2
        let y = MultiDual::variable(3.0, 1, 2); // Variable 1 of 2

        let x_squared = MultiDual::new(x.value * x.value, vec![2.0 * x.value, 0.0]);
        let xy = x.clone() * y.clone();
        let result = xy + x_squared;

        // f(2,3) = 2*3 + 2^2 = 6 + 4 = 10
        assert_eq!(result.value, 10.0);

        // ∂f/∂x = y + 2x = 3 + 4 = 7
        assert_eq!(result.partial(0), 7.0);

        // ∂f/∂y = x = 2
        assert_eq!(result.partial(1), 2.0);
    }

    /// Value and full gradient of f(x,y) = x*y + x^2 through the context.
    #[test]
    fn test_autodiff_context() {
        let mut ctx = AutoDiffContext::new(2);
        ctx.add_variable(2.0); // x = 2
        ctx.add_variable(3.0); // y = 3

        // Evaluate f(x,y) = x*y + x^2
        let (value, grad) = ctx.eval_gradient(|vars| {
            let x = vars[0];
            let y = vars[1];
            x * y + x * x
        });

        assert_eq!(value, 10.0); // f(2,3) = 6 + 4 = 10
        assert_eq!(grad.len(), 2);
        // ∂f/∂x = y + 2x = 7 and ∂f/∂y = x = 2
        assert_eq!(grad, vec![7.0, 2.0]);
    }
}