1use crate::DualNum;
2use nalgebra::allocator::Allocator;
3use nalgebra::constraint::{SameNumberOfRows, ShapeConstraint};
4use nalgebra::*;
5use num_traits::Zero;
6use std::fmt;
7use std::marker::PhantomData;
8use std::mem::MaybeUninit;
9use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
10
11#[derive(PartialEq, Eq, Clone, Debug)]
13pub struct Derivative<T: DualNum<F>, F, R: Dim, C: Dim>(
14 pub(crate) Option<OMatrix<T, R, C>>,
15 PhantomData<F>,
16)
17where
18 DefaultAllocator: Allocator<R, C>;
19
20impl<T: DualNum<F> + Copy, F: Copy, const R: usize, const C: usize> Copy
21 for Derivative<T, F, Const<R>, Const<C>>
22{
23}
24
25impl<T: DualNum<F>, F, R: Dim, C: Dim> Derivative<T, F, R, C>
26where
27 DefaultAllocator: Allocator<R, C>,
28{
29 pub fn new(derivative: Option<OMatrix<T, R, C>>) -> Self {
30 Self(derivative, PhantomData)
31 }
32
33 pub fn some(derivative: OMatrix<T, R, C>) -> Self {
34 Self::new(Some(derivative))
35 }
36
37 pub fn none() -> Self {
38 Self::new(None)
39 }
40
41 pub(crate) fn map<T2, F2>(&self, f: impl FnMut(T) -> T2) -> Derivative<T2, F2, R, C>
42 where
43 T2: DualNum<F2>,
44 DefaultAllocator: Allocator<R, C>,
45 {
46 let opt = self.0.as_ref().map(|eps| eps.map(f));
47 Derivative::new(opt)
48 }
49
50 pub(crate) fn map_borrowed<T2, F2>(
57 &self,
58 mut f: impl FnMut(&T) -> T2,
59 ) -> Derivative<T2, F2, R, C>
60 where
61 T2: DualNum<F2>,
62 DefaultAllocator: Allocator<R, C>,
63 {
64 let opt = self.0.as_ref().map(move |eps| {
65 let (nrows, ncols) = eps.shape_generic();
66 let mut res: Matrix<MaybeUninit<T2>, R, C, _> = Matrix::uninit(nrows, ncols);
67
68 for j in 0..ncols.value() {
69 for i in 0..nrows.value() {
70 unsafe {
72 let a = eps.data.get_unchecked(i, j);
73 *res.data.get_unchecked_mut(i, j) = MaybeUninit::new(f(a));
74 }
75 }
76 }
77
78 unsafe { res.assume_init() }
80 });
81 Derivative::new(opt)
82 }
83
84 pub(crate) fn try_map_borrowed<T2, F2>(
86 &self,
87 mut f: impl FnMut(&T) -> Option<T2>,
88 ) -> Option<Derivative<T2, F2, R, C>>
89 where
90 T2: DualNum<F2>,
91 DefaultAllocator: Allocator<R, C>,
92 {
93 self.0
94 .as_ref()
95 .and_then(move |eps| {
96 let (nrows, ncols) = eps.shape_generic();
97 let mut res: Matrix<MaybeUninit<T2>, R, C, _> = Matrix::uninit(nrows, ncols);
98
99 for j in 0..ncols.value() {
100 for i in 0..nrows.value() {
101 unsafe {
103 let a = eps.data.get_unchecked(i, j);
104 *res.data.get_unchecked_mut(i, j) = MaybeUninit::new(f(a)?);
105 }
106 }
107 }
108
109 Some(unsafe { res.assume_init() })
111 })
112 .map(Derivative::some)
113 }
114
115 pub fn derivative_generic(r: R, c: C, i: usize) -> Self {
116 let mut m = OMatrix::zeros_generic(r, c);
117 m[i] = T::one();
118 Self::some(m)
119 }
120
121 pub fn unwrap_generic(self, r: R, c: C) -> OMatrix<T, R, C> {
122 self.0.unwrap_or_else(|| OMatrix::zeros_generic(r, c))
123 }
124
125 pub fn fmt(&self, f: &mut fmt::Formatter, symbol: &str) -> fmt::Result {
126 if let Some(m) = self.0.as_ref() {
127 write!(f, " + ")?;
128 match m.shape() {
129 (1, 1) => write!(f, "{}", m[0])?,
130 (1, _) | (_, 1) => {
131 let x: Vec<_> = m.iter().map(T::to_string).collect();
132 write!(f, "[{}]", x.join(", "))?
133 }
134 (_, _) => write!(f, "{m}")?,
135 };
136 write!(f, "{symbol}")?;
137 }
138 write!(f, "")
139 }
140}
141
142impl<T: DualNum<F>, F> Derivative<T, F, U1, U1> {
143 #[expect(clippy::self_named_constructors)]
144 pub fn derivative() -> Self {
145 Self::some(SVector::identity())
146 }
147
148 pub fn unwrap(self) -> T {
149 self.0.map_or_else(
150 || T::zero(),
151 |s| {
152 let [[r]] = s.data.0;
153 r
154 },
155 )
156 }
157}
158
159impl<T: DualNum<F>, F, R: Dim, C: Dim> Mul<T> for Derivative<T, F, R, C>
160where
161 DefaultAllocator: Allocator<R, C>,
162{
163 type Output = Self;
164
165 fn mul(self, rhs: T) -> Self::Output {
166 Derivative::new(self.0.map(|x| x * rhs))
167 }
168}
169
170impl<T: DualNum<F>, F, R: Dim, C: Dim> Mul<T> for &Derivative<T, F, R, C>
171where
172 DefaultAllocator: Allocator<R, C>,
173{
174 type Output = Derivative<T, F, R, C>;
175
176 fn mul(self, rhs: T) -> Self::Output {
177 Derivative::new(self.0.as_ref().map(|x| x * rhs))
178 }
179}
180
181impl<T: DualNum<F>, F, R: Dim, C: Dim, R2: Dim, C2: Dim> Mul<&Derivative<T, F, R2, C2>>
182 for &Derivative<T, F, R, C>
183where
184 DefaultAllocator: Allocator<R, C> + Allocator<R2, C2> + Allocator<R, C2>,
185 ShapeConstraint: SameNumberOfRows<C, R2>,
186{
187 type Output = Derivative<T, F, R, C2>;
188
189 fn mul(self, rhs: &Derivative<T, F, R2, C2>) -> Derivative<T, F, R, C2> {
190 Derivative::new(self.0.as_ref().zip(rhs.0.as_ref()).map(|(s, r)| s * r))
191 }
192}
193
194impl<T: DualNum<F>, F, R: Dim, C: Dim> Div<T> for Derivative<T, F, R, C>
195where
196 DefaultAllocator: Allocator<R, C>,
197{
198 type Output = Self;
199
200 fn div(self, rhs: T) -> Self::Output {
201 Derivative::new(self.0.map(|x| x / rhs))
202 }
203}
204
205impl<T: DualNum<F>, F, R: Dim, C: Dim> Div<T> for &Derivative<T, F, R, C>
206where
207 DefaultAllocator: Allocator<R, C>,
208{
209 type Output = Derivative<T, F, R, C>;
210
211 fn div(self, rhs: T) -> Self::Output {
212 Derivative::new(self.0.as_ref().map(|x| x / rhs))
213 }
214}
215
216impl<T: DualNum<F>, F, R: Dim, C: Dim> Derivative<T, F, R, C>
217where
218 DefaultAllocator: Allocator<R, C>,
219{
220 pub fn tr_mul<R2: Dim, C2: Dim>(
221 &self,
222 rhs: &Derivative<T, F, R2, C2>,
223 ) -> Derivative<T, F, C, C2>
224 where
225 DefaultAllocator: Allocator<R2, C2> + Allocator<C, C2>,
226 ShapeConstraint: SameNumberOfRows<R, R2>,
227 {
228 Derivative::new(
229 self.0
230 .as_ref()
231 .zip(rhs.0.as_ref())
232 .map(|(s, r)| s.tr_mul(r)),
233 )
234 }
235}
236
237impl<T: DualNum<F>, F, R: Dim, C: Dim> Add for Derivative<T, F, R, C>
238where
239 DefaultAllocator: Allocator<R, C>,
240{
241 type Output = Self;
242
243 fn add(self, rhs: Self) -> Self::Output {
244 Self::new(match (self.0, rhs.0) {
245 (Some(s), Some(r)) => Some(s + r),
246 (Some(s), None) => Some(s),
247 (None, Some(r)) => Some(r),
248 (None, None) => None,
249 })
250 }
251}
252
253impl<T: DualNum<F>, F, R: Dim, C: Dim> Add<&Derivative<T, F, R, C>> for Derivative<T, F, R, C>
254where
255 DefaultAllocator: Allocator<R, C>,
256{
257 type Output = Derivative<T, F, R, C>;
258
259 fn add(self, rhs: &Derivative<T, F, R, C>) -> Self::Output {
260 Derivative::new(match (&self.0, &rhs.0) {
261 (Some(s), Some(r)) => Some(s + r),
262 (Some(s), None) => Some(s.clone()),
263 (None, Some(r)) => Some(r.clone()),
264 (None, None) => None,
265 })
266 }
267}
268
269impl<T: DualNum<F>, F, R: Dim, C: Dim> Add for &Derivative<T, F, R, C>
270where
271 DefaultAllocator: Allocator<R, C>,
272{
273 type Output = Derivative<T, F, R, C>;
274
275 fn add(self, rhs: Self) -> Self::Output {
276 Derivative::new(match (&self.0, &rhs.0) {
277 (Some(s), Some(r)) => Some(s + r),
278 (Some(s), None) => Some(s.clone()),
279 (None, Some(r)) => Some(r.clone()),
280 (None, None) => None,
281 })
282 }
283}
284
285impl<T: DualNum<F>, F, R: Dim, C: Dim> Sub for Derivative<T, F, R, C>
286where
287 DefaultAllocator: Allocator<R, C>,
288{
289 type Output = Self;
290
291 fn sub(self, rhs: Self) -> Self::Output {
292 Self::new(match (self.0, rhs.0) {
293 (Some(s), Some(r)) => Some(s - r),
294 (Some(s), None) => Some(s),
295 (None, Some(r)) => Some(-r),
296 (None, None) => None,
297 })
298 }
299}
300
301impl<T: DualNum<F>, F, R: Dim, C: Dim> Sub<&Derivative<T, F, R, C>> for Derivative<T, F, R, C>
302where
303 DefaultAllocator: Allocator<R, C>,
304{
305 type Output = Derivative<T, F, R, C>;
306
307 fn sub(self, rhs: &Derivative<T, F, R, C>) -> Self::Output {
308 Derivative::new(match (&self.0, &rhs.0) {
309 (Some(s), Some(r)) => Some(s - r),
310 (Some(s), None) => Some(s.clone()),
311 (None, Some(r)) => Some(-r.clone()),
312 (None, None) => None,
313 })
314 }
315}
316
317impl<T: DualNum<F>, F, R: Dim, C: Dim> Sub for &Derivative<T, F, R, C>
318where
319 DefaultAllocator: Allocator<R, C>,
320{
321 type Output = Derivative<T, F, R, C>;
322
323 fn sub(self, rhs: Self) -> Self::Output {
324 Derivative::new(match (&self.0, &rhs.0) {
325 (Some(s), Some(r)) => Some(s - r),
326 (Some(s), None) => Some(s.clone()),
327 (None, Some(r)) => Some(-r),
328 (None, None) => None,
329 })
330 }
331}
332
333impl<T: DualNum<F>, F, R: Dim, C: Dim> Neg for &Derivative<T, F, R, C>
334where
335 DefaultAllocator: Allocator<R, C>,
336{
337 type Output = Derivative<T, F, R, C>;
338
339 fn neg(self) -> Self::Output {
340 Derivative::new(self.0.as_ref().map(|x| -x))
341 }
342}
343
344impl<T: DualNum<F>, F, R: Dim, C: Dim> Neg for Derivative<T, F, R, C>
345where
346 DefaultAllocator: Allocator<R, C>,
347{
348 type Output = Self;
349
350 fn neg(self) -> Self::Output {
351 Derivative::new(self.0.map(|x| -x))
352 }
353}
354
355impl<T: DualNum<F>, F, R: Dim, C: Dim> AddAssign for Derivative<T, F, R, C>
356where
357 DefaultAllocator: Allocator<R, C>,
358{
359 fn add_assign(&mut self, rhs: Self) {
360 match (&mut self.0, rhs.0) {
361 (Some(s), Some(r)) => *s += &r,
362 (None, Some(r)) => self.0 = Some(r),
363 (_, None) => (),
364 };
365 }
366}
367
368impl<T: DualNum<F>, F, R: Dim, C: Dim> SubAssign for Derivative<T, F, R, C>
369where
370 DefaultAllocator: Allocator<R, C>,
371{
372 fn sub_assign(&mut self, rhs: Self) {
373 match (&mut self.0, rhs.0) {
374 (Some(s), Some(r)) => *s -= &r,
375 (None, Some(r)) => self.0 = Some(-&r),
376 (_, None) => (),
377 };
378 }
379}
380
381impl<T: DualNum<F>, F, R: Dim, C: Dim> MulAssign<T> for Derivative<T, F, R, C>
382where
383 DefaultAllocator: Allocator<R, C>,
384{
385 fn mul_assign(&mut self, rhs: T) {
386 if let Some(s) = &mut self.0 {
387 *s *= rhs
388 }
389 }
390}
391
392impl<T: DualNum<F>, F, R: Dim, C: Dim> DivAssign<T> for Derivative<T, F, R, C>
393where
394 DefaultAllocator: Allocator<R, C>,
395{
396 fn div_assign(&mut self, rhs: T) {
397 if let Some(s) = &mut self.0 {
398 *s /= rhs
399 }
400 }
401}
402
403impl<T, R: Dim, C: Dim> nalgebra::SimdValue for Derivative<T, T::Element, R, C>
404where
405 DefaultAllocator: Allocator<R, C>,
406 T: DualNum<T::Element> + SimdValue + Scalar,
407 T::Element: DualNum<T::Element> + Scalar + Zero,
408{
409 type Element = Derivative<T::Element, T::Element, R, C>;
410
411 type SimdBool = T::SimdBool;
412
413 const LANES: usize = T::LANES;
414
415 #[inline]
416 fn splat(val: Self::Element) -> Self {
417 val.map(|e| T::splat(e))
418 }
419
420 #[inline]
421 fn extract(&self, i: usize) -> Self::Element {
422 self.map_borrowed(|e| T::extract(e, i))
423 }
424
425 #[inline]
426 unsafe fn extract_unchecked(&self, i: usize) -> Self::Element {
427 let opt = self
428 .map_borrowed(|e| unsafe { T::extract_unchecked(e, i) })
429 .0
430 .filter(|x| Iterator::any(&mut x.iter(), |e| !e.is_zero()));
435 Derivative::new(opt)
436 }
437
438 fn replace(&mut self, i: usize, val: Self::Element) {
449 match (&mut self.0, val.0) {
450 (Some(ours), Some(theirs)) => {
451 ours.zip_apply(&theirs, |e, replacement| e.replace(i, replacement));
452 }
453 (ours @ None, Some(theirs)) => {
454 let (r, c) = theirs.shape_generic();
455 let mut init: OMatrix<T, R, C> = OMatrix::zeros_generic(r, c);
456 init.zip_apply(&theirs, |e, replacement| e.replace(i, replacement));
457 *ours = Some(init);
458 }
459 (Some(ours), None) => {
460 ours.apply(|e| e.replace(i, T::Element::zero()));
461 }
462 _ => {}
463 }
464 }
465
466 unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element) {
467 match (&mut self.0, val.0) {
468 (Some(ours), Some(theirs)) => {
469 ours.zip_apply(&theirs, |e, replacement| unsafe {
470 e.replace_unchecked(i, replacement)
471 });
472 }
473 (ours @ None, Some(theirs)) => {
474 let (r, c) = theirs.shape_generic();
475 let mut init: OMatrix<T, R, C> = OMatrix::zeros_generic(r, c);
476 init.zip_apply(&theirs, |e, replacement| unsafe {
477 e.replace_unchecked(i, replacement)
478 });
479 *ours = Some(init);
480 }
481 (Some(ours), None) => {
482 ours.apply(|e| unsafe { e.replace_unchecked(i, T::Element::zero()) });
483 }
484 _ => {}
485 }
486 }
487
488 fn select(mut self, cond: Self::SimdBool, other: Self) -> Self {
489 if cond.all() {
492 self
493 } else if cond.none() {
494 other
495 } else {
496 match (&mut self.0, other.0) {
497 (Some(ours), Some(theirs)) => {
498 ours.zip_apply(&theirs, |e, other_e| {
499 let e_ = std::mem::replace(e, T::zero());
501 *e = e_.select(cond, other_e)
502 });
503 self
504 }
505 (Some(ours), None) => {
506 ours.apply(|e| {
507 let e_ = std::mem::replace(e, T::zero());
509 *e = e_.select(cond, T::zero());
510 });
511 self
512 }
513 (ours @ None, Some(mut theirs)) => {
514 use std::ops::Not;
515 let inverted: T::SimdBool = cond.not();
516 theirs.apply(|e| {
517 let e_ = std::mem::replace(e, T::zero());
519 *e = e_.select(inverted, T::zero());
520 });
521 *ours = Some(theirs);
522 self
523 }
524 _ => self,
525 }
526 }
527 }
528}
529
530use simba::scalar::{SubsetOf, SupersetOf};
531
532impl<TSuper, FSuper, T, F, R: Dim, C: Dim> SubsetOf<Derivative<TSuper, FSuper, R, C>>
533 for Derivative<T, F, R, C>
534where
535 TSuper: DualNum<FSuper> + SupersetOf<T>,
536 T: DualNum<F>,
537 DefaultAllocator: Allocator<R, C>,
538 {
543 #[inline(always)]
544 fn to_superset(&self) -> Derivative<TSuper, FSuper, R, C> {
545 self.map_borrowed(|elem| TSuper::from_subset(elem))
546 }
547 #[inline(always)]
548 fn from_superset(element: &Derivative<TSuper, FSuper, R, C>) -> Option<Self> {
549 element.try_map_borrowed(|elem| TSuper::to_subset(elem))
550 }
551 #[inline(always)]
552 fn from_superset_unchecked(element: &Derivative<TSuper, FSuper, R, C>) -> Self {
553 element.map_borrowed(|elem| TSuper::to_subset_unchecked(elem))
554 }
555 #[inline(always)]
556 fn is_in_subset(element: &Derivative<TSuper, FSuper, R, C>) -> bool {
557 element
558 .0
559 .as_ref()
560 .is_none_or(|matrix| matrix.iter().all(|elem| TSuper::is_in_subset(elem)))
561 }
562}