num_dual/datatypes/
derivative.rs

1use crate::DualNum;
2use nalgebra::allocator::Allocator;
3use nalgebra::constraint::{SameNumberOfRows, ShapeConstraint};
4use nalgebra::*;
5use num_traits::Zero;
6use std::fmt;
7use std::marker::PhantomData;
8use std::mem::MaybeUninit;
9use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
10
11/// Wrapper struct for a derivative vector or matrix.
12#[derive(PartialEq, Eq, Clone, Debug)]
13pub struct Derivative<T: DualNum<F>, F, R: Dim, C: Dim>(
14    pub(crate) Option<OMatrix<T, R, C>>,
15    PhantomData<F>,
16)
17where
18    DefaultAllocator: Allocator<R, C>;
19
20impl<T: DualNum<F> + Copy, F: Copy, const R: usize, const C: usize> Copy
21    for Derivative<T, F, Const<R>, Const<C>>
22{
23}
24
25impl<T: DualNum<F>, F, R: Dim, C: Dim> Derivative<T, F, R, C>
26where
27    DefaultAllocator: Allocator<R, C>,
28{
29    pub fn new(derivative: Option<OMatrix<T, R, C>>) -> Self {
30        Self(derivative, PhantomData)
31    }
32
33    pub fn some(derivative: OMatrix<T, R, C>) -> Self {
34        Self::new(Some(derivative))
35    }
36
37    pub fn none() -> Self {
38        Self::new(None)
39    }
40
41    pub(crate) fn map<T2, F2>(&self, f: impl FnMut(T) -> T2) -> Derivative<T2, F2, R, C>
42    where
43        T2: DualNum<F2>,
44        DefaultAllocator: Allocator<R, C>,
45    {
46        let opt = self.0.as_ref().map(|eps| eps.map(f));
47        Derivative::new(opt)
48    }
49
50    // A version of map that doesn't clone values before mapping. Useful for the SimdValue impl,
51    // which would be redundantly cloning all the lanes of each epsilon value before extracting
52    // just one of them.
53    //
54    // To implement, we inline a copy of Matrix::map, which implicitly clones values, and remove
55    // the cloning.
56    pub(crate) fn map_borrowed<T2, F2>(
57        &self,
58        mut f: impl FnMut(&T) -> T2,
59    ) -> Derivative<T2, F2, R, C>
60    where
61        T2: DualNum<F2>,
62        DefaultAllocator: Allocator<R, C>,
63    {
64        let opt = self.0.as_ref().map(move |eps| {
65            let (nrows, ncols) = eps.shape_generic();
66            let mut res: Matrix<MaybeUninit<T2>, R, C, _> = Matrix::uninit(nrows, ncols);
67
68            for j in 0..ncols.value() {
69                for i in 0..nrows.value() {
70                    // Safety: all indices are in range.
71                    unsafe {
72                        let a = eps.data.get_unchecked(i, j);
73                        *res.data.get_unchecked_mut(i, j) = MaybeUninit::new(f(a));
74                    }
75                }
76            }
77
78            // Safety: res is now fully initialized.
79            unsafe { res.assume_init() }
80        });
81        Derivative::new(opt)
82    }
83
84    /// Same but bails out if the closure returns None
85    pub(crate) fn try_map_borrowed<T2, F2>(
86        &self,
87        mut f: impl FnMut(&T) -> Option<T2>,
88    ) -> Option<Derivative<T2, F2, R, C>>
89    where
90        T2: DualNum<F2>,
91        DefaultAllocator: Allocator<R, C>,
92    {
93        self.0
94            .as_ref()
95            .and_then(move |eps| {
96                let (nrows, ncols) = eps.shape_generic();
97                let mut res: Matrix<MaybeUninit<T2>, R, C, _> = Matrix::uninit(nrows, ncols);
98
99                for j in 0..ncols.value() {
100                    for i in 0..nrows.value() {
101                        // Safety: all indices are in range.
102                        unsafe {
103                            let a = eps.data.get_unchecked(i, j);
104                            *res.data.get_unchecked_mut(i, j) = MaybeUninit::new(f(a)?);
105                        }
106                    }
107                }
108
109                // Safety: res is now fully initialized.
110                Some(unsafe { res.assume_init() })
111            })
112            .map(Derivative::some)
113    }
114
115    pub fn derivative_generic(r: R, c: C, i: usize) -> Self {
116        let mut m = OMatrix::zeros_generic(r, c);
117        m[i] = T::one();
118        Self::some(m)
119    }
120
121    pub fn unwrap_generic(self, r: R, c: C) -> OMatrix<T, R, C> {
122        self.0.unwrap_or_else(|| OMatrix::zeros_generic(r, c))
123    }
124
125    pub fn fmt(&self, f: &mut fmt::Formatter, symbol: &str) -> fmt::Result {
126        if let Some(m) = self.0.as_ref() {
127            write!(f, " + ")?;
128            match m.shape() {
129                (1, 1) => write!(f, "{}", m[0])?,
130                (1, _) | (_, 1) => {
131                    let x: Vec<_> = m.iter().map(T::to_string).collect();
132                    write!(f, "[{}]", x.join(", "))?
133                }
134                (_, _) => write!(f, "{m}")?,
135            };
136            write!(f, "{symbol}")?;
137        }
138        write!(f, "")
139    }
140}
141
142impl<T: DualNum<F>, F> Derivative<T, F, U1, U1> {
143    #[expect(clippy::self_named_constructors)]
144    pub fn derivative() -> Self {
145        Self::some(SVector::identity())
146    }
147
148    pub fn unwrap(self) -> T {
149        self.0.map_or_else(
150            || T::zero(),
151            |s| {
152                let [[r]] = s.data.0;
153                r
154            },
155        )
156    }
157}
158
159impl<T: DualNum<F>, F, R: Dim, C: Dim> Mul<T> for Derivative<T, F, R, C>
160where
161    DefaultAllocator: Allocator<R, C>,
162{
163    type Output = Self;
164
165    fn mul(self, rhs: T) -> Self::Output {
166        Derivative::new(self.0.map(|x| x * rhs))
167    }
168}
169
170impl<T: DualNum<F>, F, R: Dim, C: Dim> Mul<T> for &Derivative<T, F, R, C>
171where
172    DefaultAllocator: Allocator<R, C>,
173{
174    type Output = Derivative<T, F, R, C>;
175
176    fn mul(self, rhs: T) -> Self::Output {
177        Derivative::new(self.0.as_ref().map(|x| x * rhs))
178    }
179}
180
181impl<T: DualNum<F>, F, R: Dim, C: Dim, R2: Dim, C2: Dim> Mul<&Derivative<T, F, R2, C2>>
182    for &Derivative<T, F, R, C>
183where
184    DefaultAllocator: Allocator<R, C> + Allocator<R2, C2> + Allocator<R, C2>,
185    ShapeConstraint: SameNumberOfRows<C, R2>,
186{
187    type Output = Derivative<T, F, R, C2>;
188
189    fn mul(self, rhs: &Derivative<T, F, R2, C2>) -> Derivative<T, F, R, C2> {
190        Derivative::new(self.0.as_ref().zip(rhs.0.as_ref()).map(|(s, r)| s * r))
191    }
192}
193
194impl<T: DualNum<F>, F, R: Dim, C: Dim> Div<T> for Derivative<T, F, R, C>
195where
196    DefaultAllocator: Allocator<R, C>,
197{
198    type Output = Self;
199
200    fn div(self, rhs: T) -> Self::Output {
201        Derivative::new(self.0.map(|x| x / rhs))
202    }
203}
204
205impl<T: DualNum<F>, F, R: Dim, C: Dim> Div<T> for &Derivative<T, F, R, C>
206where
207    DefaultAllocator: Allocator<R, C>,
208{
209    type Output = Derivative<T, F, R, C>;
210
211    fn div(self, rhs: T) -> Self::Output {
212        Derivative::new(self.0.as_ref().map(|x| x / rhs))
213    }
214}
215
216impl<T: DualNum<F>, F, R: Dim, C: Dim> Derivative<T, F, R, C>
217where
218    DefaultAllocator: Allocator<R, C>,
219{
220    pub fn tr_mul<R2: Dim, C2: Dim>(
221        &self,
222        rhs: &Derivative<T, F, R2, C2>,
223    ) -> Derivative<T, F, C, C2>
224    where
225        DefaultAllocator: Allocator<R2, C2> + Allocator<C, C2>,
226        ShapeConstraint: SameNumberOfRows<R, R2>,
227    {
228        Derivative::new(
229            self.0
230                .as_ref()
231                .zip(rhs.0.as_ref())
232                .map(|(s, r)| s.tr_mul(r)),
233        )
234    }
235}
236
237impl<T: DualNum<F>, F, R: Dim, C: Dim> Add for Derivative<T, F, R, C>
238where
239    DefaultAllocator: Allocator<R, C>,
240{
241    type Output = Self;
242
243    fn add(self, rhs: Self) -> Self::Output {
244        Self::new(match (self.0, rhs.0) {
245            (Some(s), Some(r)) => Some(s + r),
246            (Some(s), None) => Some(s),
247            (None, Some(r)) => Some(r),
248            (None, None) => None,
249        })
250    }
251}
252
253impl<T: DualNum<F>, F, R: Dim, C: Dim> Add<&Derivative<T, F, R, C>> for Derivative<T, F, R, C>
254where
255    DefaultAllocator: Allocator<R, C>,
256{
257    type Output = Derivative<T, F, R, C>;
258
259    fn add(self, rhs: &Derivative<T, F, R, C>) -> Self::Output {
260        Derivative::new(match (&self.0, &rhs.0) {
261            (Some(s), Some(r)) => Some(s + r),
262            (Some(s), None) => Some(s.clone()),
263            (None, Some(r)) => Some(r.clone()),
264            (None, None) => None,
265        })
266    }
267}
268
269impl<T: DualNum<F>, F, R: Dim, C: Dim> Add for &Derivative<T, F, R, C>
270where
271    DefaultAllocator: Allocator<R, C>,
272{
273    type Output = Derivative<T, F, R, C>;
274
275    fn add(self, rhs: Self) -> Self::Output {
276        Derivative::new(match (&self.0, &rhs.0) {
277            (Some(s), Some(r)) => Some(s + r),
278            (Some(s), None) => Some(s.clone()),
279            (None, Some(r)) => Some(r.clone()),
280            (None, None) => None,
281        })
282    }
283}
284
285impl<T: DualNum<F>, F, R: Dim, C: Dim> Sub for Derivative<T, F, R, C>
286where
287    DefaultAllocator: Allocator<R, C>,
288{
289    type Output = Self;
290
291    fn sub(self, rhs: Self) -> Self::Output {
292        Self::new(match (self.0, rhs.0) {
293            (Some(s), Some(r)) => Some(s - r),
294            (Some(s), None) => Some(s),
295            (None, Some(r)) => Some(-r),
296            (None, None) => None,
297        })
298    }
299}
300
301impl<T: DualNum<F>, F, R: Dim, C: Dim> Sub<&Derivative<T, F, R, C>> for Derivative<T, F, R, C>
302where
303    DefaultAllocator: Allocator<R, C>,
304{
305    type Output = Derivative<T, F, R, C>;
306
307    fn sub(self, rhs: &Derivative<T, F, R, C>) -> Self::Output {
308        Derivative::new(match (&self.0, &rhs.0) {
309            (Some(s), Some(r)) => Some(s - r),
310            (Some(s), None) => Some(s.clone()),
311            (None, Some(r)) => Some(-r.clone()),
312            (None, None) => None,
313        })
314    }
315}
316
317impl<T: DualNum<F>, F, R: Dim, C: Dim> Sub for &Derivative<T, F, R, C>
318where
319    DefaultAllocator: Allocator<R, C>,
320{
321    type Output = Derivative<T, F, R, C>;
322
323    fn sub(self, rhs: Self) -> Self::Output {
324        Derivative::new(match (&self.0, &rhs.0) {
325            (Some(s), Some(r)) => Some(s - r),
326            (Some(s), None) => Some(s.clone()),
327            (None, Some(r)) => Some(-r),
328            (None, None) => None,
329        })
330    }
331}
332
333impl<T: DualNum<F>, F, R: Dim, C: Dim> Neg for &Derivative<T, F, R, C>
334where
335    DefaultAllocator: Allocator<R, C>,
336{
337    type Output = Derivative<T, F, R, C>;
338
339    fn neg(self) -> Self::Output {
340        Derivative::new(self.0.as_ref().map(|x| -x))
341    }
342}
343
344impl<T: DualNum<F>, F, R: Dim, C: Dim> Neg for Derivative<T, F, R, C>
345where
346    DefaultAllocator: Allocator<R, C>,
347{
348    type Output = Self;
349
350    fn neg(self) -> Self::Output {
351        Derivative::new(self.0.map(|x| -x))
352    }
353}
354
355impl<T: DualNum<F>, F, R: Dim, C: Dim> AddAssign for Derivative<T, F, R, C>
356where
357    DefaultAllocator: Allocator<R, C>,
358{
359    fn add_assign(&mut self, rhs: Self) {
360        match (&mut self.0, rhs.0) {
361            (Some(s), Some(r)) => *s += &r,
362            (None, Some(r)) => self.0 = Some(r),
363            (_, None) => (),
364        };
365    }
366}
367
368impl<T: DualNum<F>, F, R: Dim, C: Dim> SubAssign for Derivative<T, F, R, C>
369where
370    DefaultAllocator: Allocator<R, C>,
371{
372    fn sub_assign(&mut self, rhs: Self) {
373        match (&mut self.0, rhs.0) {
374            (Some(s), Some(r)) => *s -= &r,
375            (None, Some(r)) => self.0 = Some(-&r),
376            (_, None) => (),
377        };
378    }
379}
380
381impl<T: DualNum<F>, F, R: Dim, C: Dim> MulAssign<T> for Derivative<T, F, R, C>
382where
383    DefaultAllocator: Allocator<R, C>,
384{
385    fn mul_assign(&mut self, rhs: T) {
386        if let Some(s) = &mut self.0 {
387            *s *= rhs
388        }
389    }
390}
391
392impl<T: DualNum<F>, F, R: Dim, C: Dim> DivAssign<T> for Derivative<T, F, R, C>
393where
394    DefaultAllocator: Allocator<R, C>,
395{
396    fn div_assign(&mut self, rhs: T) {
397        if let Some(s) = &mut self.0 {
398            *s /= rhs
399        }
400    }
401}
402
403impl<T, R: Dim, C: Dim> nalgebra::SimdValue for Derivative<T, T::Element, R, C>
404where
405    DefaultAllocator: Allocator<R, C>,
406    T: DualNum<T::Element> + SimdValue + Scalar,
407    T::Element: DualNum<T::Element> + Scalar + Zero,
408{
409    type Element = Derivative<T::Element, T::Element, R, C>;
410
411    type SimdBool = T::SimdBool;
412
413    const LANES: usize = T::LANES;
414
415    #[inline]
416    fn splat(val: Self::Element) -> Self {
417        val.map(|e| T::splat(e))
418    }
419
420    #[inline]
421    fn extract(&self, i: usize) -> Self::Element {
422        self.map_borrowed(|e| T::extract(e, i))
423    }
424
425    #[inline]
426    unsafe fn extract_unchecked(&self, i: usize) -> Self::Element {
427        let opt = self
428            .map_borrowed(|e| unsafe { T::extract_unchecked(e, i) })
429            .0
430            // Now check it's all zeros.
431            // Unfortunately there is no way to use the vectorized version of `is_zero`, which is
432            // only for matrices with statically known dimensions. Specialization would be
433            // required.
434            .filter(|x| Iterator::any(&mut x.iter(), |e| !e.is_zero()));
435        Derivative::new(opt)
436    }
437
438    // SIMD code will expect to be able to replace one lane with another Self::Element,
439    // even with a None Derivative, e.g.
440    //
441    // let single = Derivative::none();
442    // let mut x4 = Derivative::splat(single);
443    // let one = Derivative::some(...);
444    // x4.replace(1, one);
445    //
446    // So the implementation of `replace` will need to auto-upgrade to Some(zeros) in
447    // order to satisfy requests like that.
448    fn replace(&mut self, i: usize, val: Self::Element) {
449        match (&mut self.0, val.0) {
450            (Some(ours), Some(theirs)) => {
451                ours.zip_apply(&theirs, |e, replacement| e.replace(i, replacement));
452            }
453            (ours @ None, Some(theirs)) => {
454                let (r, c) = theirs.shape_generic();
455                let mut init: OMatrix<T, R, C> = OMatrix::zeros_generic(r, c);
456                init.zip_apply(&theirs, |e, replacement| e.replace(i, replacement));
457                *ours = Some(init);
458            }
459            (Some(ours), None) => {
460                ours.apply(|e| e.replace(i, T::Element::zero()));
461            }
462            _ => {}
463        }
464    }
465
466    unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element) {
467        match (&mut self.0, val.0) {
468            (Some(ours), Some(theirs)) => {
469                ours.zip_apply(&theirs, |e, replacement| unsafe {
470                    e.replace_unchecked(i, replacement)
471                });
472            }
473            (ours @ None, Some(theirs)) => {
474                let (r, c) = theirs.shape_generic();
475                let mut init: OMatrix<T, R, C> = OMatrix::zeros_generic(r, c);
476                init.zip_apply(&theirs, |e, replacement| unsafe {
477                    e.replace_unchecked(i, replacement)
478                });
479                *ours = Some(init);
480            }
481            (Some(ours), None) => {
482                ours.apply(|e| unsafe { e.replace_unchecked(i, T::Element::zero()) });
483            }
484            _ => {}
485        }
486    }
487
488    fn select(mut self, cond: Self::SimdBool, other: Self) -> Self {
489        // If cond is mixed, then we may need to generate big zero matrices to do the
490        // component-wise select on. So check if cond is all-true or all-first to avoid that.
491        if cond.all() {
492            self
493        } else if cond.none() {
494            other
495        } else {
496            match (&mut self.0, other.0) {
497                (Some(ours), Some(theirs)) => {
498                    ours.zip_apply(&theirs, |e, other_e| {
499                        // this will probably get optimized out
500                        let e_ = std::mem::replace(e, T::zero());
501                        *e = e_.select(cond, other_e)
502                    });
503                    self
504                }
505                (Some(ours), None) => {
506                    ours.apply(|e| {
507                        // this will probably get optimized out
508                        let e_ = std::mem::replace(e, T::zero());
509                        *e = e_.select(cond, T::zero());
510                    });
511                    self
512                }
513                (ours @ None, Some(mut theirs)) => {
514                    use std::ops::Not;
515                    let inverted: T::SimdBool = cond.not();
516                    theirs.apply(|e| {
517                        // this will probably get optimized out
518                        let e_ = std::mem::replace(e, T::zero());
519                        *e = e_.select(inverted, T::zero());
520                    });
521                    *ours = Some(theirs);
522                    self
523                }
524                _ => self,
525            }
526        }
527    }
528}
529
530use simba::scalar::{SubsetOf, SupersetOf};
531
532impl<TSuper, FSuper, T, F, R: Dim, C: Dim> SubsetOf<Derivative<TSuper, FSuper, R, C>>
533    for Derivative<T, F, R, C>
534where
535    TSuper: DualNum<FSuper> + SupersetOf<T>,
536    T: DualNum<F>,
537    DefaultAllocator: Allocator<R, C>,
538    // DefaultAllocator: Allocator<D>
539    //     + Allocator<U1, D>
540    //     + Allocator<D, U1>
541    //     + Allocator<D, D>,
542{
543    #[inline(always)]
544    fn to_superset(&self) -> Derivative<TSuper, FSuper, R, C> {
545        self.map_borrowed(|elem| TSuper::from_subset(elem))
546    }
547    #[inline(always)]
548    fn from_superset(element: &Derivative<TSuper, FSuper, R, C>) -> Option<Self> {
549        element.try_map_borrowed(|elem| TSuper::to_subset(elem))
550    }
551    #[inline(always)]
552    fn from_superset_unchecked(element: &Derivative<TSuper, FSuper, R, C>) -> Self {
553        element.map_borrowed(|elem| TSuper::to_subset_unchecked(elem))
554    }
555    #[inline(always)]
556    fn is_in_subset(element: &Derivative<TSuper, FSuper, R, C>) -> bool {
557        element
558            .0
559            .as_ref()
560            .is_none_or(|matrix| matrix.iter().all(|elem| TSuper::is_in_subset(elem)))
561    }
562}