//! Per-opcode math kernels (single source of truth for formulas).
//!
//! Historical context: before this module, each opcode's numerical
//! formula was duplicated across several CPU AD types (`Dual`,
//! `DualVec`, `Reverse`, `BReverse`, `Laurent`) and the bytecode-tape
//! opcode dispatcher (`opcode.rs`). The GPU side — WGSL shaders and the
//! CUDA NVRTC kernel — carry yet another copy in their respective
//! source-level languages. A change to one copy could drift silently
//! from the others; Phase 7 of Cycle 6 found multiple such drifts
//! (atan large-|a|, div small-|b|, hypot Inf handling).
//!
//! Going forward, any numerical formula shared between the CPU AD
//! types **and** the opcode dispatcher should live here as a single
//! generic function over `num_traits::Float`. Each AD type then
//! delegates to the helper, so CPU drift becomes impossible. GPU
//! drift is caught by `tests/gpu_cpu_parity.rs`.
//!
//! Scope note: This module is intentionally minimal at the start.
//! Only the formulas most prone to recent drift (hypot partials,
//! atan large-|a|) are extracted; the rest remain inline in the AD
//! types and opcode dispatcher. Every extraction here must come with
//! a call-site refactor so we don't add abstraction for its own sake.
use num_traits::Float;
/// Partial derivatives of `hypot(a, b) = sqrt(a² + b²)`.
///
/// `r` is the primal value `hypot(a, b)`. Returns
/// `(∂r/∂a, ∂r/∂b) = (a/r, b/r)`.
///
/// At the origin (`r == 0`) the gradient is mathematically undefined
/// but we return `(0, 0)` to match the JAX / PyTorch convention and
/// to avoid emitting NaN into downstream adjoint chains.
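The doc comment above fully determines the helper. A minimal sketch, monomorphized to `f64` for self-containment (the module's actual version would be generic over `num_traits::Float`, and the name `hypot_partials` is illustrative, not the module's confirmed API):

```rust
/// Sketch of the hypot partials: ∂r/∂a = a/r, ∂r/∂b = b/r, with the
/// (0, 0) convention at the origin. Name and f64 signature are assumptions.
fn hypot_partials(a: f64, b: f64, r: f64) -> (f64, f64) {
    if r == 0.0 {
        // Gradient undefined at the origin; return zeros per the
        // JAX / PyTorch convention instead of NaN.
        (0.0, 0.0)
    } else {
        (a / r, b / r)
    }
}
```

Dividing by the precomputed primal `r` (rather than recomputing `sqrt(a² + b²)`) keeps the backward pass consistent with the forward value.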
/// Partial derivatives of `atan2(a, b)`.
///
/// Formula: `∂/∂a = b/(a²+b²)`, `∂/∂b = -a/(a²+b²)`. Factored as
/// `(b/h)/h` and `-(a/h)/h` where `h = hypot(a, b)` so the
/// intermediate never forms `a²+b²` directly — that would overflow
/// for `|a|, |b| > sqrt(MAX)` and underflow for values below
/// `sqrt(MIN_POSITIVE)`.
///
/// At the origin (`h == 0`) the gradient is mathematically undefined
/// and we return `(0, 0)`.
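A sketch of the factoring the doc comment describes, again monomorphized to `f64` (the name `atan2_partials` is illustrative):

```rust
/// Partials of atan2(a, b), factored as (b/h)/h and -(a/h)/h with
/// h = hypot(a, b) so a² + b² is never formed directly.
/// Sketch only: name and f64 signature are assumptions.
fn atan2_partials(a: f64, b: f64) -> (f64, f64) {
    let h = a.hypot(b);
    if h == 0.0 {
        // Gradient undefined at the origin; return zeros.
        return (0.0, 0.0);
    }
    ((b / h) / h, -(a / h) / h)
}
```

Because `hypot` itself is overflow-safe, each intermediate (`b / h`, then a second divide by `h`) stays in range even when `a² + b²` would not.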
/// Derivative of `atan(a)` with an overflow-safe large-|a| path.
///
/// For `|a| ≤ 1e8`, returns `1/(1+a²)`. For `|a| > 1e8`, reformulates
/// via `u = 1/a` so `1/(1+a²) = u²/(1+u²)`, keeping every intermediate
/// in-range even at `|a| ≈ 1e19`, where `a²` overflows in f32.
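The two branches as a self-contained `f64` sketch (name `atan_deriv` is illustrative):

```rust
/// d/da atan(a) = 1/(1+a²), reformulated via u = 1/a for large |a|
/// so no intermediate overflows. Sketch: name/signature are assumptions.
fn atan_deriv(a: f64) -> f64 {
    if a.abs() <= 1e8 {
        1.0 / (1.0 + a * a)
    } else {
        // 1/(1+a²) = u²/(1+u²) with u = 1/a; u is tiny, so every
        // intermediate stays in range.
        let u = 1.0 / a;
        (u * u) / (1.0 + u * u)
    }
}
```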
/// Derivative of `asinh(a) = ln(a + sqrt(1+a²))` with a large-|a|
/// overflow-safe path.
///
/// For `|a| ≤ 1e8`, returns `1/sqrt(1+a²)`. For `|a| > 1e8`, uses
/// `u = 1/a` and `|u|/sqrt(1+u²)` so `1+a²` can't overflow.
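The same two-branch shape for asinh, sketched over `f64` (name `asinh_deriv` is illustrative):

```rust
/// d/da asinh(a) = 1/sqrt(1+a²), with the u = 1/a reformulation
/// |u|/sqrt(1+u²) for large |a|. Sketch: name/signature are assumptions.
fn asinh_deriv(a: f64) -> f64 {
    if a.abs() <= 1e8 {
        1.0 / (1.0 + a * a).sqrt()
    } else {
        let u = 1.0 / a;
        // 1/sqrt(1+a²) = |u|/sqrt(1+u²); |u| keeps the result positive
        // for negative a, matching the true derivative's sign.
        u.abs() / (1.0 + u * u).sqrt()
    }
}
```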
/// Derivative of `acosh(a) = ln(a + sqrt(a²-1))` with a large-|a|
/// overflow-safe path.
///
/// For `|a| ≤ 1e8`, returns `1/sqrt((a-1)·(a+1))`. The factored form
/// (vs naive `a*a - 1`) avoids catastrophic cancellation near `a = 1`:
/// at `a = 1 + ε`, `a*a` rounds to `1 + 2ε` and `a*a - 1 = 2ε` loses
/// the `ε²` contribution, while `(a-1)·(a+1) = ε·(2 + ε)` retains it.
/// For `|a| > 1e8`, uses `u = 1/a` and `|u|/sqrt(1-u²)`.
///
/// The WGSL shaders (`reverse.wgsl`, `tangent_forward.wgsl`,
/// `tangent_reverse.wgsl`, plus the `acosh_f32` primal helper in
/// `forward.wgsl`), the CUDA kernel (`tape_eval.cu` at three derivative
/// sites), and the Taylor jet codegen (`taylor_codegen.rs` for both
/// WGSL and CUDA emitters, including the `acosh_f` primal helper) all
/// use the same factored form so CPU and GPU stay in lockstep. The
/// regression test `acosh_deriv_factored_form_keeps_precision_near_one`
/// below pins the f64 behaviour — any swap back to `a*a - 1` will
/// trip it.
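Putting the factored small-`a` form and the `u = 1/a` large-`a` form together, sketched over `f64` (name `acosh_deriv` is illustrative):

```rust
/// d/da acosh(a) = 1/sqrt(a²-1), computed as 1/sqrt((a-1)·(a+1)) to
/// avoid cancellation near a = 1, and as |u|/sqrt(1-u²) with u = 1/a
/// for large a. Sketch: name and f64 signature are assumptions.
fn acosh_deriv(a: f64) -> f64 {
    if a.abs() <= 1e8 {
        // Factored form: at a = 1 + ε this is ε·(2 + ε), which keeps
        // the contribution that a*a - 1 would round away.
        1.0 / ((a - 1.0) * (a + 1.0)).sqrt()
    } else {
        let u = 1.0 / a;
        u.abs() / (1.0 - u * u).sqrt()
    }
}
```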