1#![allow(unused)]
2
3use crate::Ad;
4use approx::{AbsDiffEq, RelativeEq, UlpsEq};
5use na::{ComplexField, Field, RealField, SimdValue};
6use num_traits::FromPrimitive;
7use simba::scalar::SubsetOf;
8use std::f64::consts::LN_2;
9
10impl<const N: usize> AbsDiffEq for Ad<N> {
15 type Epsilon = Self;
16
17 fn default_epsilon() -> Self::Epsilon {
18 todo!()
19 }
20
21 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
22 todo!()
23 }
24}
25
26impl<const N: usize> UlpsEq for Ad<N> {
27 fn default_max_ulps() -> u32 {
28 todo!()
29 }
30
31 fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
32 todo!()
33 }
34}
35
36impl<const N: usize> RelativeEq for Ad<N> {
37 fn default_max_relative() -> Self::Epsilon {
38 todo!()
39 }
40
41 fn relative_eq(
42 &self,
43 other: &Self,
44 epsilon: Self::Epsilon,
45 max_relative: Self::Epsilon,
46 ) -> bool {
47 todo!()
48 }
49}
50
// Marker impl: no methods need to be provided here. The supertrait bounds of
// `Field` (arithmetic operators etc.) must be satisfied by other impls for this
// line to compile — presumably the crate's operator impls for `Ad`; not visible here.
impl<const N: usize> Field for Ad<N> {}
52
53impl<const N: usize> SimdValue for Ad<N> {
54 const LANES: usize = 1;
55
56 type Element = Self;
57
58 type SimdBool = bool;
59
60 fn splat(val: Self::Element) -> Self {
61 todo!()
62 }
63
64 fn extract(&self, i: usize) -> Self::Element {
65 todo!()
66 }
67
68 unsafe fn extract_unchecked(&self, i: usize) -> Self::Element {
69 todo!()
70 }
71
72 fn replace(&mut self, i: usize, val: Self::Element) {
73 todo!()
74 }
75
76 unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element) {
77 todo!()
78 }
79
80 fn select(self, cond: Self::SimdBool, other: Self) -> Self {
81 todo!()
82 }
83}
84
85impl<const N: usize> FromPrimitive for Ad<N> {
86 fn from_i64(n: i64) -> Option<Self> {
87 todo!()
88 }
89
90 fn from_u64(n: u64) -> Option<Self> {
91 todo!()
92 }
93}
94
95impl<const N: usize> SubsetOf<Ad<N>> for Ad<N> {
96 fn to_superset(&self) -> Ad<N> {
97 todo!()
98 }
99
100 fn from_superset_unchecked(element: &Ad<N>) -> Self {
101 todo!()
102 }
103
104 fn is_in_subset(element: &Ad<N>) -> bool {
105 todo!()
106 }
107}
108
109impl<const N: usize> SubsetOf<Ad<N>> for f64 {
110 fn to_superset(&self) -> Ad<N> {
111 todo!()
112 }
113
114 fn from_superset_unchecked(element: &Ad<N>) -> Self {
115 todo!()
116 }
117
118 fn is_in_subset(element: &Ad<N>) -> bool {
119 todo!()
120 }
121}
122impl<const N: usize> SubsetOf<Ad<N>> for f32 {
123 fn to_superset(&self) -> Ad<N> {
124 todo!()
125 }
126
127 fn from_superset_unchecked(element: &Ad<N>) -> Self {
128 todo!()
129 }
130
131 fn is_in_subset(element: &Ad<N>) -> bool {
132 todo!()
133 }
134}
135
136impl<const N: usize> RealField for Ad<N> {
137 fn is_sign_positive(&self) -> bool {
138 todo!()
139 }
140
141 fn is_sign_negative(&self) -> bool {
142 todo!()
143 }
144
145 fn copysign(self, sign: Self) -> Self {
146 todo!()
147 }
148
149 fn max(self, other: Self) -> Self {
150 todo!()
151 }
152
153 fn min(self, other: Self) -> Self {
154 todo!()
155 }
156
157 fn clamp(self, min: Self, max: Self) -> Self {
158 todo!()
159 }
160
161 fn atan2(self, other: Self) -> Self {
162 todo!()
163 }
164
165 fn min_value() -> Option<Self> {
166 todo!()
167 }
168
169 fn max_value() -> Option<Self> {
170 todo!()
171 }
172
173 fn pi() -> Self {
174 todo!()
175 }
176
177 fn two_pi() -> Self {
178 todo!()
179 }
180
181 fn frac_pi_2() -> Self {
182 todo!()
183 }
184
185 fn frac_pi_3() -> Self {
186 todo!()
187 }
188
189 fn frac_pi_4() -> Self {
190 todo!()
191 }
192
193 fn frac_pi_6() -> Self {
194 todo!()
195 }
196
197 fn frac_pi_8() -> Self {
198 todo!()
199 }
200
201 fn frac_1_pi() -> Self {
202 todo!()
203 }
204
205 fn frac_2_pi() -> Self {
206 todo!()
207 }
208
209 fn frac_2_sqrt_pi() -> Self {
210 todo!()
211 }
212
213 fn e() -> Self {
214 todo!()
215 }
216
217 fn log2_e() -> Self {
218 todo!()
219 }
220
221 fn log10_e() -> Self {
222 todo!()
223 }
224
225 fn ln_2() -> Self {
226 todo!()
227 }
228
229 fn ln_10() -> Self {
230 todo!()
231 }
232}
233
234impl<const N: usize> ComplexField for Ad<N> {
239 type RealField = Ad<N>;
240
241 #[doc = r" Builds a pure-real complex number from the given value."]
242 fn from_real(re: Self::RealField) -> Self {
243 re
244 }
245
246 #[doc = r" The real part of this complex number."]
247 fn real(self) -> Self::RealField {
248 self
249 }
250
251 #[doc = r" The imaginary part of this complex number."]
252 fn imaginary(self) -> Self::RealField {
253 unimplemented!("This is a real type");
254 }
255
256 #[doc = r" The modulus of this complex number."]
257 fn modulus(self) -> Self::RealField {
258 self.abs()
259 }
260
261 #[doc = r" The squared modulus of this complex number."]
262 fn modulus_squared(self) -> Self::RealField {
263 self.square()
264 }
265
266 #[doc = r" The argument of this complex number."]
267 fn argument(self) -> Self::RealField {
269 unimplemented!("This should not be used");
270 }
271
272 #[doc = r" The sum of the absolute value of this complex number's real and imaginary part."]
273 fn norm1(self) -> Self::RealField {
274 self.abs()
275 }
276
277 #[doc = r" Multiplies this complex number by `factor`."]
278 fn scale(self, factor: Self::RealField) -> Self {
279 factor * self
280 }
281
282 #[doc = r" Divides this complex number by `factor`."]
283 fn unscale(self, factor: Self::RealField) -> Self {
284 self / factor
285 }
286
287 fn floor(self) -> Self {
288 unimplemented!("Floor is not differentiable!");
289 }
290
291 fn ceil(self) -> Self {
292 unimplemented!("Ceil is not differentiable!");
293 }
294
295 fn round(self) -> Self {
296 unimplemented!("Round is not differentiable!");
297 }
298
299 fn trunc(self) -> Self {
300 unimplemented!("Trunc is not differentiable!");
301 }
302
303 fn fract(self) -> Self {
304 unimplemented!("Fract is not differentiable!");
305 }
306
307 fn mul_add(self, a: Self, b: Self) -> Self {
308 a * self + b
309 }
310
311 #[doc = r" The absolute value of this complex number: `self / self.signum()`."]
312 #[doc = r""]
313 #[doc = r" This is equivalent to `self.modulus()`."]
314 fn abs(self) -> Self::RealField {
315 let mut res = Self::_zeroed();
316 res.value = self.value.abs();
317 let sign = if self.value >= 0.0 { 1.0 } else { -1.0 };
318 res.grad = sign * self.grad;
319 res.hess = sign * self.hess;
320
321 res
322 }
323
324 #[doc = r" Computes (self.conjugate() * self + other.conjugate() * other).sqrt()"]
325 fn hypot(self, other: Self) -> Self::RealField {
326 (&self * &self + &other * &other).sqrt()
327 }
328
329 fn recip(self) -> Self {
330 Ad::inactive_scalar(1.0) / self
331 }
332
333 fn conjugate(self) -> Self {
335 self
336 }
337
338 fn sin(self) -> Self {
339 let sin_val = self.value.sin();
340 let cos_val = self.value.cos();
341
342 Self::chain(sin_val, cos_val, -sin_val, &self)
343 }
344
345 fn cos(self) -> Self {
346 let cos_val = self.value.cos();
347 let sin_val = self.value.sin();
348
349 Self::chain(cos_val, -sin_val, -cos_val, &self)
350 }
351
352 fn sin_cos(self) -> (Self, Self) {
353 todo!()
355 }
356
357 fn tan(self) -> Self {
358 let cos_val = self.value.cos();
359 let cos_sq = cos_val * cos_val;
360
361 Self::chain(
362 self.value.tan(),
363 1.0 / cos_sq,
364 2.0 * self.value.sin() / (cos_sq * cos_val),
365 &self,
366 )
367 }
368
369 fn asin(self) -> Self {
370 if self.value < -1.0 || self.value > 1.0 {
371 panic!("Asin out of domain!");
372 }
373 let s = 1.0 - self.value * self.value;
374 let s_sqrt = s.sqrt();
375
376 Self::chain(
377 self.value.asin(),
378 1.0 / s_sqrt,
379 self.value / (s * s_sqrt),
380 &self,
381 )
382 }
383
384 fn acos(self) -> Self {
385 if self.value < -1.0 || self.value > 1.0 {
386 panic!("Acos out of domain!");
387 }
388 let s = 1.0 - self.value * self.value;
389 let s_sqrt = s.sqrt();
390
391 Self::chain(
392 self.value.acos(),
393 -1.0 / s_sqrt,
394 -self.value / (s * s_sqrt),
395 &self,
396 )
397 }
398
399 fn atan(self) -> Self {
400 let s = self.value * self.value + 1.0;
401
402 Self::chain(
403 self.value.atan(),
404 1.0 / s,
405 -2.0 * self.value / (s * s),
406 &self,
407 )
408 }
409
410 fn sinh(self) -> Self {
411 let sinh_val = self.value.sinh();
412 let cosh_val = self.value.cosh();
413
414 Self::chain(sinh_val, cosh_val, sinh_val, &self)
415 }
416
417 fn cosh(self) -> Self {
418 let sinh_val = self.value.sinh();
419 let cosh_val = self.value.cosh();
420
421 Self::chain(cosh_val, sinh_val, cosh_val, &self)
422 }
423
424 fn tanh(self) -> Self {
425 let cosh_val = self.value.cosh();
426 let cosh_sq = cosh_val * cosh_val;
427
428 Self::chain(
429 self.value.tanh(),
430 1.0 / cosh_sq,
431 -2.0 * self.value.sinh() / (cosh_sq * cosh_val),
432 &self,
433 )
434 }
435
436 fn asinh(self) -> Self {
437 let s = self.value * self.value + 1.0;
438 let s_sqrt = s.sqrt();
439
440 Self::chain(
441 self.value.asinh(),
442 1.0 / s_sqrt,
443 -self.value / (s * s_sqrt),
444 &self,
445 )
446 }
447
448 fn acosh(self) -> Self {
449 if self.value < 1.0 {
450 panic!("Acosh out of domain!");
451 }
452 let sm = self.value - 1.0;
453 let sp = self.value + 1.0;
454 let prod = (sm * sp).sqrt();
455
456 Self::chain(
457 self.value.acosh(),
458 1.0 / prod,
459 -self.value / (prod * sm * sp),
460 &self,
461 )
462 }
463
464 fn atanh(self) -> Self {
465 if self.value <= -1.0 || self.value >= 1.0 {
466 panic!("Atanh out of domain!");
467 }
468 let s = 1.0 - self.value * self.value;
469
470 Self::chain(
471 self.value.atanh(),
472 1.0 / s,
473 2.0 * self.value / (s * s),
474 &self,
475 )
476 }
477
478 fn log(self, base: Self::RealField) -> Self {
479 unimplemented!("Differentiation w.r.t. base is not implemented...")
480 }
481
482 fn log2(self) -> Self {
483 if self.value <= 0.0 {
484 panic!("Log2 on non-positive value!");
485 }
486 let inv = 1.0 / self.value / std::f64::consts::LN_2;
487
488 Self::chain(self.value.log2(), inv, -inv / self.value, &self)
489 }
490
491 fn log10(self) -> Self {
492 if self.value <= 0.0 {
493 panic!("Log10 on non-positive value!");
494 }
495 let inv = 1.0 / self.value / std::f64::consts::LN_10;
496
497 Self::chain(self.value.log10(), inv, -inv / self.value, &self)
498 }
499
500 fn ln(self) -> Self {
501 if self.value <= 0.0 {
502 panic!("Ln on non-positive value!");
503 }
504 let inv = 1.0 / self.value;
505
506 Self::chain(self.value.ln(), inv, -inv * inv, &self)
507 }
508
509 fn ln_1p(self) -> Self {
510 (self + Self::inactive_scalar(1.0)).ln()
511 }
512
513 fn sqrt(self) -> Self {
514 if self.value < -0.0 {
515 panic!("Sqrt on negative value!");
517 }
518 let f = self.value.sqrt();
519
520 Self::chain(f, 0.5 / f, -0.25 / (f * self.value), &self)
521 }
522
523 fn exp(self) -> Self {
524 let exp_val = self.value.exp();
525
526 Self::chain(exp_val, exp_val, exp_val, &self)
527 }
528
529 fn exp2(self) -> Self {
530 let exp_val = self.value.exp2();
531
532 Self::chain(exp_val, exp_val * LN_2, exp_val * LN_2 * LN_2, &self)
533 }
534
535 fn exp_m1(self) -> Self {
536 (self - Self::inactive_scalar(1.0)).exp()
537 }
538
539 fn powi(self, exponent: i32) -> Self {
540 if self.value.abs() == 0.0 && exponent == 0 {
541 panic!("0.pow(0) is undefined!");
543 }
544
545 let f2 = self.value.powi(exponent - 2);
546 let f1 = f2 * self.value;
547 let f = f1 * self.value;
548
549 let ef = exponent as f64;
551
552 Self::chain(f, ef * f1, ef * (ef - 1.0) * f2, &self)
553 }
554
555 fn powf(self, n: Self::RealField) -> Self {
556 unimplemented!("Differentiation w.r.t. power it not supported");
557 }
558
559 fn powc(self, n: Self) -> Self {
560 unimplemented!("Differentiation w.r.t. complex power it not supported");
561 }
562
563 fn cbrt(self) -> Self {
564 let f = self.value.cbrt();
565
566 let d = 1.0 / (3.0 * f * f);
567 let dd = -2.0 / (9.0 * f * f * f * self.value);
568
569 Self::chain(f, d, dd, &self)
570 }
571
572 fn is_finite(&self) -> bool {
573 self.value.is_finite()
574 && self.grad.as_slice().into_iter().all(|x| x.is_finite())
575 && self.hess.as_slice().into_iter().all(|x| x.is_finite())
576 }
577
578 fn try_sqrt(self) -> Option<Self> {
579 if self.value < -0.0 {
580 None
581 } else {
582 Some(self.sqrt())
583 }
584 }
585}
586
#[cfg(test)]
mod test_field_impl {
    use crate::{
        make::{self, var},
        misc::symbolic_1::grad_det3,
        types::advec,
        Ad, GetValue,
    };
    use approx::assert_abs_diff_eq;
    use na::U3;
    use rand::{thread_rng, Rng};

    /// Absolute tolerance on the squared error of the determinant gradient.
    const EPS: f64 = 1e-12;

    /// Checks two things for the AD determinant of a random 3x3 matrix:
    /// its value exactly equals the plain `f64` determinant, and its gradient
    /// matches the closed-form gradient from `grad_det3`.
    #[test]
    fn test_det() {
        const N: usize = 3;
        const NVEC: usize = N * N;

        // NVEC random entries in [-3, 3), each seeded as an independent variable.
        let mut rng = thread_rng();
        let entries: Vec<_> = std::iter::repeat_with(|| rng.gen_range(-3.0..3.0))
            .take(NVEC)
            .collect();
        let seeds: advec<NVEC, NVEC> = var::vector_from_slice(&entries);

        let ad_mat: na::SMatrix<Ad<NVEC>, 3, 3> = seeds.reshape_generic(U3, U3).transpose();
        let plain_mat = ad_mat.value();

        let ad_det = ad_mat.determinant();
        let plain_det = plain_mat.determinant();

        let ad_grad = ad_det.grad();
        let entry = |row: usize, col: usize| plain_mat[(row, col)];
        let expected_grad = grad_det3(
            entry(0, 0),
            entry(0, 1),
            entry(0, 2),
            entry(1, 0),
            entry(1, 1),
            entry(1, 2),
            entry(2, 0),
            entry(2, 1),
            entry(2, 2),
        );

        assert_eq!(ad_det.value(), plain_det);

        let sq_err = (ad_grad - expected_grad).norm_squared();
        assert_abs_diff_eq!(sq_err, 0.0, epsilon = EPS);
    }
}