// raddy/scalar/field_impl.rs

1#![allow(unused)]
2
3use crate::Ad;
4use approx::{AbsDiffEq, RelativeEq, UlpsEq};
5use na::{ComplexField, Field, RealField, SimdValue};
6use num_traits::FromPrimitive;
7use simba::scalar::SubsetOf;
8use std::f64::consts::LN_2;
9
10// ################################################
11// ################## Unexamined ##################
12// ################################################
13
14impl<const N: usize> AbsDiffEq for Ad<N> {
15    type Epsilon = Self;
16
17    fn default_epsilon() -> Self::Epsilon {
18        todo!()
19    }
20
21    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
22        todo!()
23    }
24}
25
26impl<const N: usize> UlpsEq for Ad<N> {
27    fn default_max_ulps() -> u32 {
28        todo!()
29    }
30
31    fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
32        todo!()
33    }
34}
35
36impl<const N: usize> RelativeEq for Ad<N> {
37    fn default_max_relative() -> Self::Epsilon {
38        todo!()
39    }
40
41    fn relative_eq(
42        &self,
43        other: &Self,
44        epsilon: Self::Epsilon,
45        max_relative: Self::Epsilon,
46    ) -> bool {
47        todo!()
48    }
49}
50
51impl<const N: usize> Field for Ad<N> {}
52
53impl<const N: usize> SimdValue for Ad<N> {
54    const LANES: usize = 1;
55
56    type Element = Self;
57
58    type SimdBool = bool;
59
60    fn splat(val: Self::Element) -> Self {
61        todo!()
62    }
63
64    fn extract(&self, i: usize) -> Self::Element {
65        todo!()
66    }
67
68    unsafe fn extract_unchecked(&self, i: usize) -> Self::Element {
69        todo!()
70    }
71
72    fn replace(&mut self, i: usize, val: Self::Element) {
73        todo!()
74    }
75
76    unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element) {
77        todo!()
78    }
79
80    fn select(self, cond: Self::SimdBool, other: Self) -> Self {
81        todo!()
82    }
83}
84
85impl<const N: usize> FromPrimitive for Ad<N> {
86    fn from_i64(n: i64) -> Option<Self> {
87        todo!()
88    }
89
90    fn from_u64(n: u64) -> Option<Self> {
91        todo!()
92    }
93}
94
95impl<const N: usize> SubsetOf<Ad<N>> for Ad<N> {
96    fn to_superset(&self) -> Ad<N> {
97        todo!()
98    }
99
100    fn from_superset_unchecked(element: &Ad<N>) -> Self {
101        todo!()
102    }
103
104    fn is_in_subset(element: &Ad<N>) -> bool {
105        todo!()
106    }
107}
108
109impl<const N: usize> SubsetOf<Ad<N>> for f64 {
110    fn to_superset(&self) -> Ad<N> {
111        todo!()
112    }
113
114    fn from_superset_unchecked(element: &Ad<N>) -> Self {
115        todo!()
116    }
117
118    fn is_in_subset(element: &Ad<N>) -> bool {
119        todo!()
120    }
121}
122impl<const N: usize> SubsetOf<Ad<N>> for f32 {
123    fn to_superset(&self) -> Ad<N> {
124        todo!()
125    }
126
127    fn from_superset_unchecked(element: &Ad<N>) -> Self {
128        todo!()
129    }
130
131    fn is_in_subset(element: &Ad<N>) -> bool {
132        todo!()
133    }
134}
135
136impl<const N: usize> RealField for Ad<N> {
137    fn is_sign_positive(&self) -> bool {
138        todo!()
139    }
140
141    fn is_sign_negative(&self) -> bool {
142        todo!()
143    }
144
145    fn copysign(self, sign: Self) -> Self {
146        todo!()
147    }
148
149    fn max(self, other: Self) -> Self {
150        todo!()
151    }
152
153    fn min(self, other: Self) -> Self {
154        todo!()
155    }
156
157    fn clamp(self, min: Self, max: Self) -> Self {
158        todo!()
159    }
160
161    fn atan2(self, other: Self) -> Self {
162        todo!()
163    }
164
165    fn min_value() -> Option<Self> {
166        todo!()
167    }
168
169    fn max_value() -> Option<Self> {
170        todo!()
171    }
172
173    fn pi() -> Self {
174        todo!()
175    }
176
177    fn two_pi() -> Self {
178        todo!()
179    }
180
181    fn frac_pi_2() -> Self {
182        todo!()
183    }
184
185    fn frac_pi_3() -> Self {
186        todo!()
187    }
188
189    fn frac_pi_4() -> Self {
190        todo!()
191    }
192
193    fn frac_pi_6() -> Self {
194        todo!()
195    }
196
197    fn frac_pi_8() -> Self {
198        todo!()
199    }
200
201    fn frac_1_pi() -> Self {
202        todo!()
203    }
204
205    fn frac_2_pi() -> Self {
206        todo!()
207    }
208
209    fn frac_2_sqrt_pi() -> Self {
210        todo!()
211    }
212
213    fn e() -> Self {
214        todo!()
215    }
216
217    fn log2_e() -> Self {
218        todo!()
219    }
220
221    fn log10_e() -> Self {
222        todo!()
223    }
224
225    fn ln_2() -> Self {
226        todo!()
227    }
228
229    fn ln_10() -> Self {
230        todo!()
231    }
232}
233
234// ################################################
235// ################### Examined ###################
236// ################################################
237
238impl<const N: usize> ComplexField for Ad<N> {
239    type RealField = Ad<N>;
240
241    #[doc = r" Builds a pure-real complex number from the given value."]
242    fn from_real(re: Self::RealField) -> Self {
243        re
244    }
245
246    #[doc = r" The real part of this complex number."]
247    fn real(self) -> Self::RealField {
248        self
249    }
250
251    #[doc = r" The imaginary part of this complex number."]
252    fn imaginary(self) -> Self::RealField {
253        unimplemented!("This is a real type");
254    }
255
256    #[doc = r" The modulus of this complex number."]
257    fn modulus(self) -> Self::RealField {
258        self.abs()
259    }
260
261    #[doc = r" The squared modulus of this complex number."]
262    fn modulus_squared(self) -> Self::RealField {
263        self.square()
264    }
265
266    #[doc = r" The argument of this complex number."]
267    /// This should be zero with no grad w.r.t. self, but the use of this method is itself a bug.
268    fn argument(self) -> Self::RealField {
269        unimplemented!("This should not be used");
270    }
271
272    #[doc = r" The sum of the absolute value of this complex number's real and imaginary part."]
273    fn norm1(self) -> Self::RealField {
274        self.abs()
275    }
276
277    #[doc = r" Multiplies this complex number by `factor`."]
278    fn scale(self, factor: Self::RealField) -> Self {
279        factor * self
280    }
281
282    #[doc = r" Divides this complex number by `factor`."]
283    fn unscale(self, factor: Self::RealField) -> Self {
284        self / factor
285    }
286
287    fn floor(self) -> Self {
288        unimplemented!("Floor is not differentiable!");
289    }
290
291    fn ceil(self) -> Self {
292        unimplemented!("Ceil is not differentiable!");
293    }
294
295    fn round(self) -> Self {
296        unimplemented!("Round is not differentiable!");
297    }
298
299    fn trunc(self) -> Self {
300        unimplemented!("Trunc is not differentiable!");
301    }
302
303    fn fract(self) -> Self {
304        unimplemented!("Fract is not differentiable!");
305    }
306
307    fn mul_add(self, a: Self, b: Self) -> Self {
308        a * self + b
309    }
310
311    #[doc = r" The absolute value of this complex number: `self / self.signum()`."]
312    #[doc = r""]
313    #[doc = r" This is equivalent to `self.modulus()`."]
314    fn abs(self) -> Self::RealField {
315        let mut res = Self::_zeroed();
316        res.value = self.value.abs();
317        let sign = if self.value >= 0.0 { 1.0 } else { -1.0 };
318        res.grad = sign * self.grad;
319        res.hess = sign * self.hess;
320
321        res
322    }
323
324    #[doc = r" Computes (self.conjugate() * self + other.conjugate() * other).sqrt()"]
325    fn hypot(self, other: Self) -> Self::RealField {
326        (&self * &self + &other * &other).sqrt()
327    }
328
329    fn recip(self) -> Self {
330        Ad::inactive_scalar(1.0) / self
331    }
332
333    /// Real number has itself as conjugate
334    fn conjugate(self) -> Self {
335        self
336    }
337
338    fn sin(self) -> Self {
339        let sin_val = self.value.sin();
340        let cos_val = self.value.cos();
341
342        Self::chain(sin_val, cos_val, -sin_val, &self)
343    }
344
345    fn cos(self) -> Self {
346        let cos_val = self.value.cos();
347        let sin_val = self.value.sin();
348
349        Self::chain(cos_val, -sin_val, -cos_val, &self)
350    }
351
352    fn sin_cos(self) -> (Self, Self) {
353        // (self.sin(), self.cos())
354        todo!()
355    }
356
357    fn tan(self) -> Self {
358        let cos_val = self.value.cos();
359        let cos_sq = cos_val * cos_val;
360
361        Self::chain(
362            self.value.tan(),
363            1.0 / cos_sq,
364            2.0 * self.value.sin() / (cos_sq * cos_val),
365            &self,
366        )
367    }
368
369    fn asin(self) -> Self {
370        if self.value < -1.0 || self.value > 1.0 {
371            panic!("Asin out of domain!");
372        }
373        let s = 1.0 - self.value * self.value;
374        let s_sqrt = s.sqrt();
375
376        Self::chain(
377            self.value.asin(),
378            1.0 / s_sqrt,
379            self.value / (s * s_sqrt),
380            &self,
381        )
382    }
383
384    fn acos(self) -> Self {
385        if self.value < -1.0 || self.value > 1.0 {
386            panic!("Acos out of domain!");
387        }
388        let s = 1.0 - self.value * self.value;
389        let s_sqrt = s.sqrt();
390
391        Self::chain(
392            self.value.acos(),
393            -1.0 / s_sqrt,
394            -self.value / (s * s_sqrt),
395            &self,
396        )
397    }
398
399    fn atan(self) -> Self {
400        let s = self.value * self.value + 1.0;
401
402        Self::chain(
403            self.value.atan(),
404            1.0 / s,
405            -2.0 * self.value / (s * s),
406            &self,
407        )
408    }
409
410    fn sinh(self) -> Self {
411        let sinh_val = self.value.sinh();
412        let cosh_val = self.value.cosh();
413
414        Self::chain(sinh_val, cosh_val, sinh_val, &self)
415    }
416
417    fn cosh(self) -> Self {
418        let sinh_val = self.value.sinh();
419        let cosh_val = self.value.cosh();
420
421        Self::chain(cosh_val, sinh_val, cosh_val, &self)
422    }
423
424    fn tanh(self) -> Self {
425        let cosh_val = self.value.cosh();
426        let cosh_sq = cosh_val * cosh_val;
427
428        Self::chain(
429            self.value.tanh(),
430            1.0 / cosh_sq,
431            -2.0 * self.value.sinh() / (cosh_sq * cosh_val),
432            &self,
433        )
434    }
435
436    fn asinh(self) -> Self {
437        let s = self.value * self.value + 1.0;
438        let s_sqrt = s.sqrt();
439
440        Self::chain(
441            self.value.asinh(),
442            1.0 / s_sqrt,
443            -self.value / (s * s_sqrt),
444            &self,
445        )
446    }
447
448    fn acosh(self) -> Self {
449        if self.value < 1.0 {
450            panic!("Acosh out of domain!");
451        }
452        let sm = self.value - 1.0;
453        let sp = self.value + 1.0;
454        let prod = (sm * sp).sqrt();
455
456        Self::chain(
457            self.value.acosh(),
458            1.0 / prod,
459            -self.value / (prod * sm * sp),
460            &self,
461        )
462    }
463
464    fn atanh(self) -> Self {
465        if self.value <= -1.0 || self.value >= 1.0 {
466            panic!("Atanh out of domain!");
467        }
468        let s = 1.0 - self.value * self.value;
469
470        Self::chain(
471            self.value.atanh(),
472            1.0 / s,
473            2.0 * self.value / (s * s),
474            &self,
475        )
476    }
477
478    fn log(self, base: Self::RealField) -> Self {
479        unimplemented!("Differentiation w.r.t. base is not implemented...")
480    }
481
482    fn log2(self) -> Self {
483        if self.value <= 0.0 {
484            panic!("Log2 on non-positive value!");
485        }
486        let inv = 1.0 / self.value / std::f64::consts::LN_2;
487
488        Self::chain(self.value.log2(), inv, -inv / self.value, &self)
489    }
490
491    fn log10(self) -> Self {
492        if self.value <= 0.0 {
493            panic!("Log10 on non-positive value!");
494        }
495        let inv = 1.0 / self.value / std::f64::consts::LN_10;
496
497        Self::chain(self.value.log10(), inv, -inv / self.value, &self)
498    }
499
500    fn ln(self) -> Self {
501        if self.value <= 0.0 {
502            panic!("Ln on non-positive value!");
503        }
504        let inv = 1.0 / self.value;
505
506        Self::chain(self.value.ln(), inv, -inv * inv, &self)
507    }
508
509    fn ln_1p(self) -> Self {
510        (self + Self::inactive_scalar(1.0)).ln()
511    }
512
513    fn sqrt(self) -> Self {
514        if self.value < -0.0 {
515            // We don't want to mute this behavior or get NaN as this is fucking undebuggable.
516            panic!("Sqrt on negative value!");
517        }
518        let f = self.value.sqrt();
519
520        Self::chain(f, 0.5 / f, -0.25 / (f * self.value), &self)
521    }
522
523    fn exp(self) -> Self {
524        let exp_val = self.value.exp();
525
526        Self::chain(exp_val, exp_val, exp_val, &self)
527    }
528
529    fn exp2(self) -> Self {
530        let exp_val = self.value.exp2();
531
532        Self::chain(exp_val, exp_val * LN_2, exp_val * LN_2 * LN_2, &self)
533    }
534
535    fn exp_m1(self) -> Self {
536        (self - Self::inactive_scalar(1.0)).exp()
537    }
538
539    fn powi(self, exponent: i32) -> Self {
540        if self.value.abs() == 0.0 && exponent == 0 {
541            // We don't want to mute this behavior or get NaN as this is fucking undebuggable.
542            panic!("0.pow(0) is undefined!");
543        }
544
545        let f2 = self.value.powi(exponent - 2);
546        let f1 = f2 * self.value;
547        let f = f1 * self.value;
548
549        // exponent in float
550        let ef = exponent as f64;
551
552        Self::chain(f, ef * f1, ef * (ef - 1.0) * f2, &self)
553    }
554
555    fn powf(self, n: Self::RealField) -> Self {
556        unimplemented!("Differentiation w.r.t. power it not supported");
557    }
558
559    fn powc(self, n: Self) -> Self {
560        unimplemented!("Differentiation w.r.t. complex power it not supported");
561    }
562
563    fn cbrt(self) -> Self {
564        let f = self.value.cbrt();
565
566        let d = 1.0 / (3.0 * f * f);
567        let dd = -2.0 / (9.0 * f * f * f * self.value);
568
569        Self::chain(f, d, dd, &self)
570    }
571
572    fn is_finite(&self) -> bool {
573        self.value.is_finite()
574            && self.grad.as_slice().into_iter().all(|x| x.is_finite())
575            && self.hess.as_slice().into_iter().all(|x| x.is_finite())
576    }
577
578    fn try_sqrt(self) -> Option<Self> {
579        if self.value < -0.0 {
580            None
581        } else {
582            Some(self.sqrt())
583        }
584    }
585}
586
587// ################################################
588// #################### Tests #####################
589// ################################################
590
#[cfg(test)]
mod test_field_impl {
    use crate::{
        make::{self, var},
        misc::symbolic_1::grad_det3,
        types::advec,
        Ad, GetValue,
    };
    use approx::assert_abs_diff_eq;
    use na::U3;
    use rand::{thread_rng, Rng};

    const EPS: f64 = 1e-12;

    /// Checks nalgebra's `determinant()` on a random 3x3 `Ad` matrix.
    ///
    /// The value is compared against the `f64` determinant, and the gradient against the
    /// symbolic `grad_det3`.
    #[test]
    fn test_det() {
        const N: usize = 3;
        const NVEC: usize = N * N;
        let mut rng = thread_rng();
        let vals: Vec<_> = (0..NVEC).map(|_| rng.gen_range(-3.0..3.0)).collect();
        let matvec: advec<9, 9> = var::vector_from_slice(&vals);

        // `reshape_generic` fills column-major, but `vals` is meant row-major: transpose.
        let mat: na::SMatrix<Ad<NVEC>, 3, 3> = matvec.reshape_generic(U3, U3).transpose();
        let mat_val = mat.value();

        let det = mat.determinant();
        let gt_det = mat_val.determinant();

        let m = |r: usize, c: usize| mat_val[(r, c)];
        let gt_det_grad = grad_det3(
            m(0, 0),
            m(0, 1),
            m(0, 2),
            m(1, 0),
            m(1, 1),
            m(1, 2),
            m(2, 0),
            m(2, 1),
            m(2, 2),
        );

        // Compare with a tolerance rather than exactly: the AD and f64 code paths are not
        // guaranteed to round identically.
        assert_abs_diff_eq!(det.value(), gt_det, epsilon = EPS);

        let grad_diff = (det.grad() - gt_det_grad).norm_squared();
        assert_abs_diff_eq!(grad_diff, 0.0, epsilon = EPS);
    }
}
639}