echidna 0.9.0

A high-performance automatic differentiation library for Rust
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
// Forward-over-reverse (HVP) shader for sparse Hessian computation.
//
// Each thread performs:
// 1. Forward tangent pass: compute (primals, tangents) for all tape entries
// 2. Reverse adjoint sweep with Dual adjoints: adj_re and adj_eps
//    adj_re → gradient, adj_eps → Hessian-vector product
//
// For adjoint accumulation with Dual partials:
//   adj_re[a] += da_re * adj_re[i]
//   adj_eps[a] += da_re * adj_eps[i] + da_eps * adj_re[i]
// where da = Dual(da_re, da_eps) is the tangent of the reverse partial.

const OP_INPUT:  u32 = 0u;
const OP_CONST:  u32 = 1u;
const OP_ADD:    u32 = 2u;
const OP_SUB:    u32 = 3u;
const OP_MUL:    u32 = 4u;
const OP_DIV:    u32 = 5u;
const OP_REM:    u32 = 6u;
const OP_POWF:   u32 = 7u;
const OP_ATAN2:  u32 = 8u;
const OP_HYPOT:  u32 = 9u;
const OP_MAX:    u32 = 10u;
const OP_MIN:    u32 = 11u;
const OP_NEG:    u32 = 12u;
const OP_RECIP:  u32 = 13u;
const OP_SQRT:   u32 = 14u;
const OP_CBRT:   u32 = 15u;
const OP_POWI:   u32 = 16u;
const OP_EXP:    u32 = 17u;
const OP_EXP2:   u32 = 18u;
const OP_EXPM1:  u32 = 19u;
const OP_LN:     u32 = 20u;
const OP_LOG2:   u32 = 21u;
const OP_LOG10:  u32 = 22u;
const OP_LN1P:   u32 = 23u;
const OP_SIN:    u32 = 24u;
const OP_COS:    u32 = 25u;
const OP_TAN:    u32 = 26u;
const OP_ASIN:   u32 = 27u;
const OP_ACOS:   u32 = 28u;
const OP_ATAN:   u32 = 29u;
const OP_SINH:   u32 = 30u;
const OP_COSH:   u32 = 31u;
const OP_TANH:   u32 = 32u;
const OP_ASINH:  u32 = 33u;
const OP_ACOSH:  u32 = 34u;
const OP_ATANH:  u32 = 35u;
const OP_ABS:    u32 = 36u;
const OP_SIGNUM: u32 = 37u;
const OP_FLOOR:  u32 = 38u;
const OP_CEIL:   u32 = 39u;
const OP_ROUND:  u32 = 40u;
const OP_TRUNC:  u32 = 41u;
const OP_FRACT:  u32 = 42u;

const UNUSED: u32 = 0xFFFFFFFFu;

struct TapeMeta {
    num_ops: u32,
    num_inputs: u32,
    num_variables: u32,
    num_outputs: u32,
    batch_size: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

@group(0) @binding(0) var<storage, read> opcodes: array<u32>;
@group(0) @binding(1) var<storage, read> arg0: array<u32>;
@group(0) @binding(2) var<storage, read> arg1: array<u32>;
@group(0) @binding(3) var<storage, read> constants: array<f32>;
@group(0) @binding(4) var<uniform> tape_meta: TapeMeta;
@group(0) @binding(5) var<storage, read> output_indices: array<u32>;

// I/O: bind group 1
// 0: primal_inputs [B * num_inputs]
@group(1) @binding(0) var<storage, read> primal_inputs: array<f32>;
// 1: tangent_seeds [B * num_inputs]
@group(1) @binding(1) var<storage, read> tangent_seeds: array<f32>;
// 2: primals [B * num_variables]
@group(1) @binding(2) var<storage, read_write> primals: array<f32>;
// 3: tangents [B * num_variables]
@group(1) @binding(3) var<storage, read_write> tans: array<f32>;
// 4: adj_re [B * num_variables]
@group(1) @binding(4) var<storage, read_write> adj_re: array<f32>;
// 5: adj_eps [B * num_variables]
@group(1) @binding(5) var<storage, read_write> adj_eps: array<f32>;
// 6: grad_out [B * num_inputs]
@group(1) @binding(6) var<storage, read_write> grad_out: array<f32>;
// 7: hvp_out [B * num_inputs]
@group(1) @binding(7) var<storage, read_write> hvp_out: array<f32>;

fn sinh_f(x: f32) -> f32 { return (exp(x) - exp(-x)) * 0.5; }
fn cosh_f(x: f32) -> f32 { return (exp(x) + exp(-x)) * 0.5; }

// Precision-preserving EXPM1 / LN1P primals for small |x|, matching
// forward.wgsl helpers. `exp(x) - 1` and `log(1 + x)` cancel
// catastrophically as x → 0; the Taylor-series shortcut avoids that.
fn expm1_f32(x: f32) -> f32 {
    if abs(x) < 1e-4 { return x + 0.5 * x * x; }
    return exp(x) - 1.0;
}
fn ln1p_f32(x: f32) -> f32 {
    if abs(x) < 1e-4 { return x - 0.5 * x * x; }
    return log(1.0 + x);
}

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let bid = gid.x;
    if bid >= tape_meta.batch_size {
        return;
    }

    let nv = tape_meta.num_variables;
    let ni = tape_meta.num_inputs;
    let num_ops = tape_meta.num_ops;
    let base = bid * nv;

    // ──── Phase 1: Forward tangent pass ────
    for (var i = 0u; i < nv; i = i + 1u) {
        primals[base + i] = constants[i];
        tans[base + i] = 0.0;
    }
    let in_base = bid * ni;
    for (var i = 0u; i < ni; i = i + 1u) {
        primals[base + i] = primal_inputs[in_base + i];
        tans[base + i] = tangent_seeds[in_base + i];
    }

    for (var i = ni; i < num_ops; i = i + 1u) {
        let op = opcodes[i];
        if op == OP_CONST { continue; }
        let ai = arg0[i];
        let bi = arg1[i];
        let a = primals[base + ai];
        let at = tans[base + ai];
        var r = 0.0f;
        var rt = 0.0f;

        switch op {
            case 2u: { let b = primals[base+bi]; let bt = tans[base+bi]; r=a+b; rt=at+bt; }
            case 3u: { let b = primals[base+bi]; let bt = tans[base+bi]; r=a-b; rt=at-bt; }
            case 4u: { let b = primals[base+bi]; let bt = tans[base+bi]; r=a*b; rt=b*at+a*bt; }
            case 5u: { let b = primals[base+bi]; let bt = tans[base+bi]; r=a/b; let inv=1.0/b; rt=inv*at-a*inv*inv*bt; }
            case 6u: { let b=primals[base+bi]; let bt=tans[base+bi]; r=a-trunc(a/b)*b; rt=at-trunc(a/b)*bt; }
            case 7u: { let b=primals[base+bi]; let bt=tans[base+bi]; r=pow(a,b); let dx=select(b*r/a*at, b*pow(a,b-1.0)*at, a==0.0); let dy=select(r*log(a)*bt, 0.0, r==0.0); rt=dx+dy; }
            case 8u: { let b=primals[base+bi]; let bt=tans[base+bi]; r=atan2(a,b); let mx=max(abs(a),abs(b)); if mx==0.0 {rt=0.0;} else {let au=a/mx; let bu=b/mx; let d=mx*(au*au+bu*bu); rt=(bu*at-au*bt)/d;} }
            case 9u: { let b=primals[base+bi]; let bt=tans[base+bi]; r=sqrt(a*a+b*b); if r==0.0 {rt=0.0;} else {rt=(a*at+b*bt)/r;} }
            case 10u: { let b=primals[base+bi]; let bt=tans[base+bi]; let bb=bitcast<u32>(b); let bn=((bb>>23u)&0xffu)==0xffu && (bb&0x7fffffu)!=0u; if a>=b || bn {r=a;rt=at;} else {r=b;rt=bt;} }
            case 11u: { let b=primals[base+bi]; let bt=tans[base+bi]; let bb=bitcast<u32>(b); let bn=((bb>>23u)&0xffu)==0xffu && (bb&0x7fffffu)!=0u; if a<=b || bn {r=a;rt=at;} else {r=b;rt=bt;} }
            case 12u: { r=-a; rt=-at; }
            case 13u: { r=1.0/a; rt=-at/(a*a); }
            case 14u: { r=sqrt(a); rt=at/(2.0*r); }
            case 15u: { let s=sign(a); r=s*pow(abs(a),1.0/3.0); rt=at/(3.0*r*r); }
            case 16u: { let e=bitcast<i32>(bi); let n=f32(e); r=pow(a,n); rt=select(n*pow(a,n-1.0)*at, 0.0, e==0); }
            case 17u: { r=exp(a); rt=r*at; }
            case 18u: { r=exp2(a); rt=r*log(2.0)*at; }
            case 19u: { r=expm1_f32(a); rt=(r+1.0)*at; }
            case 20u: { r=log(a); rt=at/a; }
            case 21u: { r=log2(a); rt=at/(a*log(2.0)); }
            case 22u: { r=log(a)/log(10.0); rt=at/(a*log(10.0)); }
            case 23u: { r=ln1p_f32(a); rt=at/(1.0+a); }
            case 24u: { r=sin(a); rt=cos(a)*at; }
            case 25u: { r=cos(a); rt=-sin(a)*at; }
            case 26u: { r=tan(a); let c=cos(a); rt=at/(c*c); }
            case 27u: { r=asin(a); rt=at/sqrt((1.0-a)*(1.0+a)); }
            case 28u: { r=acos(a); rt=-at/sqrt((1.0-a)*(1.0+a)); }
            case 29u: {
                let aa = abs(a); r = atan(a);
                if aa > 1e8 { let inv = 1.0 / a; rt = at * inv * inv / (1.0 + inv * inv); }
                else        { rt = at / (1.0 + a * a); }
            }
            case 30u: { r=sinh_f(a); rt=cosh_f(a)*at; }
            case 31u: { r=cosh_f(a); rt=sinh_f(a)*at; }
            case 32u: { r=tanh(a); let c=cosh_f(a); rt=at/(c*c); }
            case 33u: { let ax=abs(a); r=select(-log(ax+sqrt(ax*ax+1.0)), log(ax+sqrt(ax*ax+1.0)), a>=0.0); if ax>1e8 {let inv=1.0/a; rt=at*abs(inv)/sqrt(1.0+inv*inv);} else {rt=at/sqrt(a*a+1.0);} }
            case 34u: { r=log(a+sqrt(a*a-1.0)); if abs(a)>1e8 {let inv=1.0/a; rt=at*abs(inv)/sqrt(1.0-inv*inv);} else {rt=at/sqrt(a*a-1.0);} }
            case 35u: { r=0.5*log((1.0+a)/(1.0-a)); rt=at/((1.0-a)*(1.0+a)); }
            case 36u: { r=abs(a); if a!=a {rt=0.0;} else {let bits=bitcast<u32>(a); let s=select(1.0, -1.0, (bits&0x80000000u)!=0u); rt=s*at;} }
            case 37u: { if a!=a {r=a;} else if a>=0.0 {r=1.0;} else {r=-1.0;} rt=0.0; }
            case 38u: { r=floor(a); rt=0.0; }
            case 39u: { r=ceil(a); rt=0.0; }
            case 40u: { r=round(a); rt=0.0; }
            case 41u: { r=trunc(a); rt=0.0; }
            case 42u: { r=a-trunc(a); rt=at; }
            default: {}
        }
        primals[base + i] = r;
        tans[base + i] = rt;
    }

    // ──── Phase 2: Reverse sweep with Dual adjoints ────
    for (var i = 0u; i < nv; i = i + 1u) {
        adj_re[base + i] = 0.0;
        adj_eps[base + i] = 0.0;
    }
    // Seed output adjoint
    let seed_idx = output_indices[0];
    adj_re[base + seed_idx] = 1.0;

    for (var ii = 0u; ii < num_ops; ii = ii + 1u) {
        let i = num_ops - 1u - ii;
        let ar = adj_re[base + i];
        let ae = adj_eps[base + i];
        if ar == 0.0 && ae == 0.0 { continue; }

        let op = opcodes[i];
        if op == OP_INPUT || op == OP_CONST { continue; }

        adj_re[base + i] = 0.0;
        adj_eps[base + i] = 0.0;

        let ai = arg0[i];
        let bi = arg1[i];
        let a = primals[base + ai];
        let at = tans[base + ai];
        let r = primals[base + i];

        // Compute Dual reverse partials: (da_re, da_eps, db_re, db_eps)
        var da_re = 0.0f;
        var da_eps = 0.0f;
        var db_re = 0.0f;
        var db_eps = 0.0f;

        switch op {
            case 2u /* ADD */: { da_re=1.0; db_re=1.0; }
            case 3u /* SUB */: { da_re=1.0; db_re=-1.0; }
            case 4u /* MUL */: {
                let b=primals[base+bi]; let bt=tans[base+bi];
                da_re=b; da_eps=bt; db_re=a; db_eps=at;
            }
            case 5u /* DIV */: {
                let b=primals[base+bi]; let bt=tans[base+bi];
                let inv=1.0/b;
                // Factor through `r = a/b` to drop one `inv` from each
                // higher-order term: `-a*inv²` → `-r*inv`, and
                // `2*a*inv³` → `2*r*inv²`. One factor of `inv*inv` still
                // remains in the eps-eps terms (unavoidable second
                // derivative), but `inv³` is eliminated.
                da_re=inv; da_eps=-bt*inv*inv;
                db_re=-r*inv; db_eps=-at*inv*inv+2.0*r*bt*inv*inv;
            }
            case 6u /* REM */: {
                let b=primals[base+bi];
                da_re=1.0;
                db_re=-trunc(a/b);
                // db_eps = 0 since trunc has zero derivative a.e.
            }
            case 7u /* POWF */: {
                let b=primals[base+bi]; let bt=tans[base+bi];
                let ab1 = pow(a, b-1.0);
                da_re = b * ab1;
                // For a <= 0, `log(a)` is NaN. Finite `r` at a < 0 implies b
                // was integer; the classical derivative w.r.t. b is undefined
                // there, convention is 0 — mirrors CPU OpCode::Powf safety net.
                if a <= 0.0 {
                    da_eps = 0.0;
                } else {
                    da_eps = bt*ab1 + b*ab1*((b-1.0)/a*at + log(a)*bt);
                }
                let rr = primals[base+i]; // r = a^b from forward pass
                if rr == 0.0 || a <= 0.0 {
                    db_re = 0.0;
                    db_eps = 0.0;
                } else {
                    let la = log(a);
                    let rt = tans[base+i];
                    db_re = rr * la;
                    db_eps = rt*la + rr*at/a;
                }
            }
            case 8u /* ATAN2 */: {
                let b=primals[base+bi]; let bt=tans[base+bi];
                // Normalize by max(|a|,|b|) so a²+b² doesn't overflow in f32
                // even when a*a alone would. Let au = a/mx, bu = b/mx, s =
                // au² + bu² ∈ [1, 2], ms = mx·s. Then:
                //   a² + b² = mx · ms
                //   da_re  = b/(a²+b²) = bu/ms
                //   da_eps = bt/(a²+b²) − bu·dd_over_mx/ms²
                //          = (bt/mx − bu·dd_over_mx/ms) / ms
                // where dd_over_mx = 2·(au·at + bu·bt) is bounded for
                // fixed tangent magnitudes. Expressing in this form avoids
                // the explicit mx² that would overflow for |mx| > sqrt(f32::MAX).
                let mx = max(abs(a), abs(b));
                if mx == 0.0 {
                    da_re = 0.0; da_eps = 0.0; db_re = 0.0; db_eps = 0.0;
                } else {
                    let au = a / mx;
                    let bu = b / mx;
                    let s = au * au + bu * bu;
                    let ms = mx * s;
                    let dd_over_mx = 2.0 * (au * at + bu * bt);
                    da_re = bu / ms;
                    db_re = -au / ms;
                    da_eps = (bt / mx - bu * dd_over_mx / ms) / ms;
                    db_eps = (-at / mx + au * dd_over_mx / ms) / ms;
                }
            }
            case 9u /* HYPOT */: {
                let b=primals[base+bi]; let bt=tans[base+bi];
                if r == 0.0 {
                    da_re=0.0; da_eps=0.0; db_re=0.0; db_eps=0.0;
                } else {
                    let r2=r*r; let rt2=tans[base+i];
                    da_re=a/r; da_eps=(at*r-a*rt2)/(r2);
                    db_re=b/r; db_eps=(bt*r-b*rt2)/(r2);
                }
            }
            case 10u /* MAX */: {
                let b=primals[base+bi];
                if a>=b || b!=b { da_re=1.0; } else { db_re=1.0; }
            }
            case 11u /* MIN */: {
                let b=primals[base+bi];
                if a<=b || b!=b { da_re=1.0; } else { db_re=1.0; }
            }

            // Unary ops: da_re = f'(a), da_eps = f''(a)*at
            case 12u /* NEG */: { da_re=-1.0; }
            case 13u /* RECIP */: { let inv=1.0/a; da_re=-inv*inv; da_eps=2.0*at*inv*inv*inv; }
            case 14u /* SQRT */: { da_re=0.5/r; da_eps=-0.25*at/(a*r); }
            // f''(a) = -2/(9*r^5) where r = cbrt(a)
            case 15u /* CBRT */: { let rr=r*r; da_re=1.0/(3.0*rr); da_eps=-2.0*at/(9.0*rr*rr*r); }
            case 16u /* POWI */: {
                let e=bitcast<i32>(bi);
                if e == 0 { da_re=0.0; da_eps=0.0; }
                else if e == 1 {
                    // f(a)=a, f'=1, f''=0. The general formula evaluates
                    // `pow(a, -1) → Inf` at a=0, giving `0 * Inf * at = NaN`.
                    // Short-circuit to the mathematically exact zero second
                    // derivative. Mirrors the CUDA fix in tape_eval.cu.
                    da_re=1.0; da_eps=0.0;
                } else {
                    let n=f32(e); da_re=n*pow(a,n-1.0); da_eps=n*(n-1.0)*pow(a,n-2.0)*at;
                }
            }
            case 17u /* EXP */: { da_re=r; da_eps=r*at; }
            case 18u /* EXP2 */: { let l2=log(2.0); da_re=r*l2; da_eps=r*l2*l2*at; }
            case 19u /* EXPM1 */: { da_re=r+1.0; da_eps=(r+1.0)*at; }
            case 20u /* LN */: { da_re=1.0/a; da_eps=-at/(a*a); }
            case 21u /* LOG2 */: { let l2=log(2.0); da_re=1.0/(a*l2); da_eps=-at/(a*a*l2); }
            case 22u /* LOG10 */: { let l10=log(10.0); da_re=1.0/(a*l10); da_eps=-at/(a*a*l10); }
            case 23u /* LN1P */: { let t=1.0+a; da_re=1.0/t; da_eps=-at/(t*t); }
            case 24u /* SIN */: { da_re=cos(a); da_eps=-sin(a)*at; }
            case 25u /* COS */: { da_re=-sin(a); da_eps=-cos(a)*at; }
            case 26u /* TAN */: { let c=cos(a); let s=1.0/(c*c); da_re=s; da_eps=2.0*tan(a)*s*at; }
            case 27u /* ASIN */: { let t=sqrt((1.0-a)*(1.0+a)); da_re=1.0/t; da_eps=a*at/(t*t*t); }
            case 28u /* ACOS */: { let t=sqrt((1.0-a)*(1.0+a)); da_re=-1.0/t; da_eps=-a*at/(t*t*t); }
            case 29u /* ATAN */: {
                let aa = abs(a);
                if aa > 1e8 {
                    let inv = 1.0 / a;
                    let h = 1.0 + inv * inv;
                    da_re = inv * inv / h;
                    da_eps = -2.0 * inv * inv * inv * at / (h * h);
                } else {
                    let t = 1.0 + a * a;
                    da_re = 1.0 / t;
                    da_eps = -2.0 * a * at / (t * t);
                }
            }
            case 30u /* SINH */: { da_re=cosh_f(a); da_eps=sinh_f(a)*at; }
            case 31u /* COSH */: { da_re=sinh_f(a); da_eps=cosh_f(a)*at; }
            case 32u /* TANH */: { let c=cosh_f(a); let s=1.0/(c*c); da_re=s; da_eps=-2.0*tanh(a)*s*at; }
            case 33u /* ASINH */: {
                // For |a| > 1e8 use inv-based formula to avoid a*a+1 overflow.
                if abs(a) > 1e8 {
                    let inv = 1.0 / a;
                    // d/dx asinh = 1/sqrt(1+x²) = |1/x|/sqrt(1+1/x²) for large |x|
                    let denom = sqrt(1.0 + inv * inv);
                    da_re = abs(inv) / denom;
                    // d²/dx² asinh = -x/(1+x²)^(3/2). Rewrite via inv = 1/x:
                    //   = -sign(x)·|inv|³/denom³ = -sign(x)·inv²·|inv|/denom³
                    let denom3 = denom * denom * denom;
                    da_eps = -a * at * inv * inv * abs(inv) / denom3;
                } else {
                    let t = sqrt(a * a + 1.0);
                    da_re = 1.0 / t;
                    da_eps = -a * at / (t * t * t);
                }
            }
            case 34u /* ACOSH */: {
                if abs(a) > 1e8 {
                    let inv = 1.0 / a;
                    let denom = sqrt(1.0 - inv * inv);
                    da_re = abs(inv) / denom;
                    let denom3 = denom * denom * denom;
                    da_eps = -a * at * inv * inv * abs(inv) / denom3;
                } else {
                    // Factored form (a-1)(a+1) avoids cancellation
                    // near a=1; matches kernels::acosh_deriv. Both
                    // first-order (1/t) and second-order (-a*at/t³)
                    // benefit from the precision improvement.
                    let t = sqrt((a - 1.0) * (a + 1.0));
                    da_re = 1.0 / t;
                    da_eps = -a * at / (t * t * t);
                }
            }
            case 35u /* ATANH */: { let t=(1.0-a)*(1.0+a); da_re=1.0/t; da_eps=2.0*a*at/(t*t); }
            case 36u /* ABS */: {
                if a != a {
                    da_re = 0.0;
                } else {
                    let bits = bitcast<u32>(a);
                    da_re = select(1.0, -1.0, (bits & 0x80000000u) != 0u);
                }
            }
            case 37u, 38u, 39u, 40u, 41u: { /* zero derivative */ }
            case 42u /* FRACT */: { da_re=1.0; }
            default: {}
        }

        // Dual accumulation: adj[arg] += Dual(da_re, da_eps) * Dual(ar, ae)
        adj_re[base + ai] += da_re * ar;
        adj_eps[base + ai] += da_re * ae + da_eps * ar;

        if bi != UNUSED && op != OP_POWI {
            adj_re[base + bi] += db_re * ar;
            adj_eps[base + bi] += db_re * ae + db_eps * ar;
        }
    }

    // Write gradient and HVP outputs
    let g_base = bid * ni;
    for (var i = 0u; i < ni; i = i + 1u) {
        grad_out[g_base + i] = adj_re[base + i];
        hvp_out[g_base + i] = adj_eps[base + i];
    }
}