amari_dual/
multivector.rs

1//! Dual number multivectors for automatic differentiation in geometric algebra
2
3use crate::DualNumber;
4use alloc::vec::Vec;
5use amari_core::Multivector;
6use num_traits::{Float, Zero};
7
8/// Multivector with dual number coefficients for automatic differentiation
9#[derive(Clone, Debug)]
10pub struct DualMultivector<T: Float, const P: usize, const Q: usize, const R: usize> {
11    coefficients: Vec<DualNumber<T>>,
12}
13
14impl<T: Float, const P: usize, const Q: usize, const R: usize> DualMultivector<T, P, Q, R> {
15    const DIM: usize = P + Q + R;
16    const BASIS_COUNT: usize = 1 << Self::DIM;
17
18    /// Create zero dual multivector
19    pub fn zero() -> Self {
20        let mut coeffs = Vec::with_capacity(Self::BASIS_COUNT);
21        for _ in 0..Self::BASIS_COUNT {
22            coeffs.push(DualNumber::constant(T::zero()));
23        }
24        Self {
25            coefficients: coeffs,
26        }
27    }
28
29    /// Create from dual number coefficients
30    pub fn from_dual_coefficients(coeffs: Vec<DualNumber<T>>) -> Self {
31        assert_eq!(coeffs.len(), Self::BASIS_COUNT);
32        Self {
33            coefficients: coeffs,
34        }
35    }
36
37    /// Create dual multivector where each coefficient is a variable
38    pub fn new_variables(values: &[T]) -> Self {
39        let mut coeffs = Vec::with_capacity(Self::BASIS_COUNT);
40        for i in 0..Self::BASIS_COUNT {
41            if i < values.len() {
42                coeffs.push(DualNumber::variable(values[i]));
43            } else {
44                coeffs.push(DualNumber::constant(T::zero()));
45            }
46        }
47        Self {
48            coefficients: coeffs,
49        }
50    }
51
52    /// Create dual multivector where each coefficient is a variable - alias
53    pub fn new_variable(values: &[T]) -> Self {
54        Self::new_variables(values)
55    }
56
57    /// Create constant dual multivector from Multivector
58    pub fn constant_mv(mv: &Multivector<P, Q, R>) -> Self {
59        let mut coeffs = Vec::with_capacity(Self::BASIS_COUNT);
60        for i in 0..Self::BASIS_COUNT {
61            coeffs.push(DualNumber::constant(
62                T::from(mv.get(i)).unwrap_or(T::zero()),
63            ));
64        }
65        Self {
66            coefficients: coeffs,
67        }
68    }
69
70    /// Create constant dual multivector from slice
71    pub fn constant(values: &[T]) -> Self {
72        let mut coeffs = Vec::with_capacity(Self::BASIS_COUNT);
73        for i in 0..Self::BASIS_COUNT {
74            if i < values.len() {
75                coeffs.push(DualNumber::constant(values[i]));
76            } else {
77                coeffs.push(DualNumber::constant(T::zero()));
78            }
79        }
80        Self {
81            coefficients: coeffs,
82        }
83    }
84
85    /// Get coefficient at index
86    pub fn get(&self, index: usize) -> DualNumber<T> {
87        self.coefficients
88            .get(index)
89            .copied()
90            .unwrap_or(DualNumber::constant(T::zero()))
91    }
92
93    /// Set coefficient at index
94    pub fn set(&mut self, index: usize, value: DualNumber<T>) {
95        if index < self.coefficients.len() {
96            self.coefficients[index] = value;
97        }
98    }
99
100    /// Get the value part (without derivatives)
101    pub fn value(&self) -> Multivector<P, Q, R> {
102        let values: Vec<f64> = self
103            .coefficients
104            .iter()
105            .map(|coeff| coeff.real.to_f64().unwrap_or(0.0))
106            .collect();
107        Multivector::from_coefficients(values)
108    }
109
110    /// Get the derivative part
111    pub fn derivative(&self) -> Multivector<P, Q, R> {
112        let derivatives: Vec<f64> = self
113            .coefficients
114            .iter()
115            .map(|coeff| coeff.dual.to_f64().unwrap_or(0.0))
116            .collect();
117        Multivector::from_coefficients(derivatives)
118    }
119
120    /// Get the real part (value) as a Multivector
121    pub fn real_part(&self) -> Multivector<P, Q, R> {
122        self.value()
123    }
124
125    /// Compute magnitude of the real part
126    pub fn magnitude(&self) -> T {
127        T::from(self.value().magnitude()).unwrap()
128    }
129
130    /// Create from real multivector (alias for constant_mv)
131    pub fn from_real_multivector(mv: Multivector<P, Q, R>) -> Self {
132        Self::constant_mv(&mv)
133    }
134
135    /// Create scalar dual multivector from DualNumber
136    pub fn scalar(value: DualNumber<T>) -> Self {
137        let mut result = Self::zero();
138        result.coefficients[0] = value;
139        result
140    }
141
142    /// Create a basis vector (unit vector in given direction)
143    pub fn basis_vector(index: usize) -> Result<Self, &'static str> {
144        if index >= Self::DIM {
145            return Err("Index out of bounds for basis vector");
146        }
147
148        let mut result = Self::zero();
149        let basis_index = 1 << index; // 2^index for single basis blade
150        if basis_index < Self::BASIS_COUNT {
151            result.coefficients[basis_index] = DualNumber::constant(T::one());
152            Ok(result)
153        } else {
154            Err("Invalid basis vector index")
155        }
156    }
157
158    /// Get the grade of the multivector (highest grade component with non-zero coefficient)
159    pub fn grade(&self) -> usize {
160        let mut highest_grade = 0;
161        for (i, coeff) in self.coefficients.iter().enumerate() {
162            if coeff.real != T::zero() || coeff.dual != T::zero() {
163                let grade = i.count_ones() as usize;
164                if grade > highest_grade {
165                    highest_grade = grade;
166                }
167            }
168        }
169        highest_grade
170    }
171
172    /// Geometric product with automatic differentiation
173    pub fn geometric_product(&self, other: &Self) -> Self {
174        // Use the same Cayley table structure as regular geometric product
175        // but apply dual number multiplication rules
176
177        let mut result = Self::zero();
178
179        for i in 0..Self::BASIS_COUNT {
180            for j in 0..Self::BASIS_COUNT {
181                let index = i ^ j; // Same blade combination rule
182
183                // Dual number multiplication with chain rule
184                let product = self.coefficients[i] * other.coefficients[j];
185
186                // Apply the appropriate sign from Cayley table
187                let sign = self.compute_cayley_sign(i, j);
188                let signed_product = if sign > 0.0 { product } else { -product };
189
190                result.coefficients[index] = result.coefficients[index] + signed_product;
191            }
192        }
193
194        result
195    }
196
197    /// Simplified Cayley table sign computation (should use proper table)
198    fn compute_cayley_sign(&self, i: usize, j: usize) -> f64 {
199        // Simplified - in practice would use the Cayley table from amari-core
200        let mut swaps = 0;
201        let mut b = j;
202
203        while b != 0 {
204            let lowest_b = b & (!b + 1);
205            b ^= lowest_b;
206            let mask = lowest_b - 1;
207            let count = (i & mask).count_ones();
208            swaps += count;
209        }
210
211        if swaps % 2 == 0 {
212            1.0
213        } else {
214            -1.0
215        }
216    }
217
218    /// Reverse with automatic differentiation
219    pub fn reverse(&self) -> Self {
220        let mut result = Self::zero();
221
222        for i in 0..Self::BASIS_COUNT {
223            let grade = i.count_ones() as usize;
224            let sign = if grade == 0 || (grade * (grade - 1) / 2).is_multiple_of(2) {
225                1.0
226            } else {
227                -1.0
228            };
229
230            if sign > 0.0 {
231                result.coefficients[i] = self.coefficients[i];
232            } else {
233                result.coefficients[i] = -self.coefficients[i];
234            }
235        }
236
237        result
238    }
239
240    /// Grade projection with automatic differentiation
241    pub fn grade_projection(&self, grade: usize) -> Self {
242        let mut result = Self::zero();
243
244        for i in 0..Self::BASIS_COUNT {
245            if i.count_ones() as usize == grade {
246                result.coefficients[i] = self.coefficients[i];
247            }
248        }
249
250        result
251    }
252
253    /// Dual number norm (with automatic differentiation)
254    pub fn norm_squared(&self) -> DualNumber<T> {
255        let reversed = self.reverse();
256        let product = self.geometric_product(&reversed);
257        product.coefficients[0] // Scalar part
258    }
259
260    /// Dual number norm
261    pub fn norm(&self) -> DualNumber<T> {
262        self.norm_squared().sqrt()
263    }
264
265    /// Normalize with automatic differentiation
266    pub fn normalize(&self) -> Self {
267        let norm = self.norm();
268        let mut result = Self::zero();
269
270        for i in 0..Self::BASIS_COUNT {
271            result.coefficients[i] = self.coefficients[i] / norm;
272        }
273
274        result
275    }
276
277    /// Exponential map with automatic differentiation
278    pub fn exp(&self) -> Self {
279        // For bivectors, use closed form with dual numbers
280        let grade2 = self.grade_projection(2);
281        let remainder = self.clone() - grade2.clone();
282
283        if remainder.norm().value() < T::from(1e-10).unwrap() {
284            // Pure bivector case
285            let b_squared = grade2.geometric_product(&grade2).coefficients[0];
286
287            if b_squared.real > T::from(-1e-14).unwrap() {
288                // Hyperbolic case
289                let norm = b_squared.sqrt();
290                let cosh_norm = norm.apply_with_derivative(|x| x.cosh(), |x| x.sinh());
291                let sinh_norm = norm.apply_with_derivative(|x| x.sinh(), |x| x.cosh());
292
293                let mut result = Self::zero();
294                result.coefficients[0] = cosh_norm;
295
296                if !norm.is_zero() {
297                    let factor = sinh_norm / norm;
298                    for i in 0..Self::BASIS_COUNT {
299                        if i.count_ones() == 2 {
300                            // Bivector components
301                            result.coefficients[i] = grade2.coefficients[i] * factor;
302                        }
303                    }
304                }
305
306                result
307            } else {
308                // Circular case
309                let norm = (-b_squared).sqrt();
310                let cos_norm = norm.apply_with_derivative(|x| x.cos(), |x| -x.sin());
311                let sin_norm = norm.apply_with_derivative(|x| x.sin(), |x| x.cos());
312
313                let mut result = Self::zero();
314                result.coefficients[0] = cos_norm;
315
316                let factor = sin_norm / norm;
317                for i in 0..Self::BASIS_COUNT {
318                    if i.count_ones() == 2 {
319                        result.coefficients[i] = grade2.coefficients[i] * factor;
320                    }
321                }
322
323                result
324            }
325        } else {
326            // General case - use series expansion
327            self.exp_series()
328        }
329    }
330
331    /// Series expansion for exponential
332    fn exp_series(&self) -> Self {
333        let mut result = Self::zero();
334        result.coefficients[0] = DualNumber::constant(T::one());
335
336        let mut term = result.clone();
337
338        for n in 1..20 {
339            term = term.geometric_product(self);
340            let factorial = T::from((1..=n).product::<usize>()).unwrap();
341            let scaled_term = term.clone() * DualNumber::constant(T::one() / factorial);
342
343            result = result + scaled_term;
344
345            // Check convergence (simplified)
346            if term.norm().value() < T::from(1e-14).unwrap() {
347                break;
348            }
349        }
350
351        result
352    }
353
354    /// Apply a function element-wise with automatic differentiation
355    pub fn map<F, G>(&self, f: F, df: G) -> Self
356    where
357        F: Fn(T) -> T,
358        G: Fn(T) -> T,
359    {
360        let mut result = Self::zero();
361        for i in 0..Self::BASIS_COUNT {
362            result.coefficients[i] = self.coefficients[i].apply_with_derivative(&f, &df);
363        }
364        result
365    }
366
367    /// Forward-mode automatic differentiation
368    ///
369    /// Computes both the value and gradient of a function using dual numbers.
370    /// Forward-mode AD propagates derivatives alongside function evaluation,
371    /// allowing exact derivative computation without finite differences.
372    ///
373    /// # Parameters
374    /// - `f`: A function that takes a `DualMultivector` and returns a `DualMultivector`
375    ///
376    /// # Returns
377    /// A tuple containing:
378    /// - The scalar value of the function at the input point
379    /// - The full `DualMultivector` result containing all partial derivatives
380    ///
381    /// # Example
382    /// ```
383    /// use amari_dual::{DualMultivector, DualNumber};
384    /// let input = DualMultivector::<f64, 3, 0, 0>::scalar(DualNumber::constant(2.0));
385    /// let (value, gradient) = input.forward_mode_ad(|x| {
386    ///     x.geometric_product(&x)  // x^2 in geometric algebra
387    /// });
388    /// // Computes geometric square and its derivative
389    /// ```
390    ///
391    /// # How it works
392    /// The dual number coefficients automatically track derivatives through
393    /// all arithmetic operations via operator overloading. The 'dual' part
394    /// of each number carries the derivative information, which is propagated
395    /// according to the chain rule during computation.
396    pub fn forward_mode_ad<F>(&self, f: F) -> (T, Self)
397    where
398        F: Fn(Self) -> Self,
399    {
400        let result = f(self.clone());
401        let value = result.coefficients[0].real; // Extract scalar part value
402        let gradient = result; // The result contains gradients
403        (value, gradient)
404    }
405}
406
407// Arithmetic operations
408impl<T: Float, const P: usize, const Q: usize, const R: usize> core::ops::Add
409    for DualMultivector<T, P, Q, R>
410{
411    type Output = Self;
412
413    fn add(mut self, other: Self) -> Self {
414        for i in 0..Self::BASIS_COUNT {
415            self.coefficients[i] = self.coefficients[i] + other.coefficients[i];
416        }
417        self
418    }
419}
420
421impl<T: Float, const P: usize, const Q: usize, const R: usize> core::ops::Sub
422    for DualMultivector<T, P, Q, R>
423{
424    type Output = Self;
425
426    fn sub(mut self, other: Self) -> Self {
427        for i in 0..Self::BASIS_COUNT {
428            self.coefficients[i] = self.coefficients[i] - other.coefficients[i];
429        }
430        self
431    }
432}
433
434impl<T: Float, const P: usize, const Q: usize, const R: usize> core::ops::Mul<DualNumber<T>>
435    for DualMultivector<T, P, Q, R>
436{
437    type Output = Self;
438
439    fn mul(mut self, scalar: DualNumber<T>) -> Self {
440        for i in 0..Self::BASIS_COUNT {
441            self.coefficients[i] = self.coefficients[i] * scalar;
442        }
443        self
444    }
445}
446
447/// Multi-variable dual multivector for computing full Jacobians
448#[derive(Clone, Debug)]
449pub struct MultiDualMultivector<T: Float> {
450    /// Function values for each basis component
451    pub values: Vec<T>,
452    /// Jacobian matrix: [basis_component][variable]
453    pub jacobian: Vec<Vec<T>>,
454    pub n_vars: usize,
455    pub basis_count: usize,
456}
457
458impl<T: Float> MultiDualMultivector<T> {
459    /// Create new multi-dual multivector
460    pub fn new(values: Vec<T>, n_vars: usize) -> Self {
461        let basis_count = values.len();
462        let mut jacobian = Vec::with_capacity(basis_count);
463        for _ in 0..basis_count {
464            let mut row = Vec::with_capacity(n_vars);
465            for _ in 0..n_vars {
466                row.push(T::zero());
467            }
468            jacobian.push(row);
469        }
470
471        Self {
472            values,
473            jacobian,
474            n_vars,
475            basis_count,
476        }
477    }
478
479    /// Create variable multivector (one variable per coefficient)
480    pub fn variables(values: Vec<T>) -> Self {
481        let n_vars = values.len();
482        let basis_count = values.len();
483        let mut jacobian = Vec::with_capacity(basis_count);
484        for _ in 0..basis_count {
485            let mut row = Vec::with_capacity(n_vars);
486            for _ in 0..n_vars {
487                row.push(T::zero());
488            }
489            jacobian.push(row);
490        }
491
492        // Set up identity jacobian
493        for (i, row) in jacobian.iter_mut().enumerate().take(basis_count) {
494            if i < n_vars {
495                row[i] = T::one();
496            }
497        }
498
499        Self {
500            values,
501            jacobian,
502            n_vars,
503            basis_count,
504        }
505    }
506
507    /// Get partial derivative of coefficient i with respect to variable j
508    pub fn partial(&self, coeff_index: usize, var_index: usize) -> T {
509        self.jacobian
510            .get(coeff_index)
511            .and_then(|row| row.get(var_index))
512            .copied()
513            .unwrap_or(T::zero())
514    }
515
516    /// Get full gradient of coefficient i
517    pub fn gradient(&self, coeff_index: usize) -> Vec<T> {
518        self.jacobian.get(coeff_index).cloned().unwrap_or_else(|| {
519            let mut grad = Vec::with_capacity(self.n_vars);
520            for _ in 0..self.n_vars {
521                grad.push(T::zero());
522            }
523            grad
524        })
525    }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use approx::assert_relative_eq;

    /// Building from a full slice keeps each value and seeds a unit derivative.
    #[test]
    fn test_dual_multivector_creation() {
        let values: Vec<f64> = (1..=8).map(f64::from).collect();
        let dmv = DualMultivector::<f64, 3, 0, 0>::new_variables(&values);

        assert_eq!(dmv.coefficients.len(), 8);
        let first = dmv.get(0);
        assert_eq!(first.real, 1.0);
        assert_eq!(first.dual, 1.0); // Variable has derivative 1
    }

    /// Multiplying the scalar 1 by e1 leaves a non-zero e1 coefficient.
    #[test]
    fn test_dual_geometric_product() {
        let mut scalar_one = vec![0.0; 8];
        scalar_one[0] = 1.0;
        let mut unit_e1 = vec![0.0; 8];
        unit_e1[1] = 1.0;

        let lhs = DualMultivector::<f64, 3, 0, 0>::new_variables(&scalar_one);
        let rhs = DualMultivector::<f64, 3, 0, 0>::new_variables(&unit_e1);

        assert!(!lhs.geometric_product(&rhs).get(1).is_zero());
    }

    /// A 3-4-5 multivector has norm 5 and a non-vanishing derivative.
    #[test]
    fn test_dual_norm() {
        let mut values = vec![0.0; 8];
        values[0] = 3.0;
        values[1] = 4.0;

        let norm = DualMultivector::<f64, 3, 0, 0>::new_variables(&values).norm();

        assert_relative_eq!(norm.real, 5.0, epsilon = 1e-10);
        assert!(norm.dual.abs() > 1e-10);
    }

    /// The exponential of a small bivector stays close to 1 and carries a derivative.
    #[test]
    fn test_dual_exp() {
        let mut values = vec![0.0; 8];
        values[3] = 0.1;

        let scalar_part = DualMultivector::<f64, 3, 0, 0>::new_variables(&values)
            .exp()
            .get(0);

        assert!(scalar_part.real > 0.9);
        assert!(scalar_part.dual.abs() > 0.0);
    }

    /// `variables` seeds an identity Jacobian.
    #[test]
    fn test_multi_dual_multivector() {
        let values = vec![1.0, 2.0, 3.0, 4.0];
        let mdmv = MultiDualMultivector::variables(values.clone());

        assert_eq!(mdmv.values, values);
        assert_eq!(mdmv.n_vars, 4);

        for (row, col) in (0..4).flat_map(|row| (0..4).map(move |col| (row, col))) {
            let expected = if row == col { 1.0 } else { 0.0 };
            assert_eq!(mdmv.partial(row, col), expected);
        }
    }
}