1use std::collections::{BTreeSet, HashMap, HashSet};
26use std::rc::Rc;
27
28use super::nl_reader::{BinOp, Expr, UnaryOp};
29
/// One instruction of a [`Tape`].
///
/// Operand fields (`usize`) are indices of earlier ops on the same tape; the value of an op
/// is computed from the already-computed values of its operands (see `Tape::forward`).
#[derive(Debug, Clone)]
pub enum TapeOp {
    /// Literal constant.
    Const(f64),
    /// Reads `x[i]` from the evaluation point.
    Var(usize),
    /// `vals[a] + vals[b]`.
    Add(usize, usize),
    /// `vals[a] - vals[b]`.
    Sub(usize, usize),
    /// `vals[a] * vals[b]`.
    Mul(usize, usize),
    /// `vals[a] / vals[b]`.
    Div(usize, usize),
    /// `vals[a].powf(vals[b])`.
    Pow(usize, usize),
    /// `-vals[a]`.
    Neg(usize),
    /// `|vals[a]|`.
    Abs(usize),
    /// `sqrt(vals[a])`.
    Sqrt(usize),
    /// `exp(vals[a])`.
    Exp(usize),
    /// Natural logarithm, `ln(vals[a])`.
    Log(usize),
    /// Base-10 logarithm, `log10(vals[a])`.
    Log10(usize),
    /// `sin(vals[a])`.
    Sin(usize),
    /// `cos(vals[a])`.
    Cos(usize),
}
51
/// A linearised expression: ops in evaluation order, where each op only refers to earlier
/// ops. The evaluation and derivative routines treat the last op as the function output.
#[derive(Debug, Clone)]
pub struct Tape {
    /// The instructions, in evaluation order.
    pub ops: Vec<TapeOp>,
}
58
59impl Tape {
60 pub fn build(expr: &Expr) -> Self {
64 let mut ops = Vec::new();
65 let mut cache: HashMap<*const Expr, usize> = HashMap::new();
66 build_recursive(expr, &mut ops, &mut cache);
67 Tape { ops }
68 }
69
70 pub fn forward(&self, x: &[f64]) -> Vec<f64> {
73 let mut vals: Vec<f64> = Vec::with_capacity(self.ops.len());
74 for op in &self.ops {
75 let v = match op {
76 TapeOp::Const(c) => *c,
77 TapeOp::Var(i) => x[*i],
78 TapeOp::Add(a, b) => vals[*a] + vals[*b],
79 TapeOp::Sub(a, b) => vals[*a] - vals[*b],
80 TapeOp::Mul(a, b) => vals[*a] * vals[*b],
81 TapeOp::Div(a, b) => vals[*a] / vals[*b],
82 TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
83 TapeOp::Neg(a) => -vals[*a],
84 TapeOp::Abs(a) => vals[*a].abs(),
85 TapeOp::Sqrt(a) => vals[*a].sqrt(),
86 TapeOp::Exp(a) => vals[*a].exp(),
87 TapeOp::Log(a) => vals[*a].ln(),
88 TapeOp::Log10(a) => vals[*a].log10(),
89 TapeOp::Sin(a) => vals[*a].sin(),
90 TapeOp::Cos(a) => vals[*a].cos(),
91 };
92 vals.push(v);
93 }
94 vals
95 }
96
97 pub fn eval(&self, x: &[f64]) -> f64 {
98 let vals = self.forward(x);
99 *vals.last().unwrap_or(&0.0)
100 }
101
102 pub fn gradient_seed(&self, x: &[f64], seed: f64, grad: &mut [f64]) {
107 if seed == 0.0 || self.ops.is_empty() {
108 return;
109 }
110 let vals = self.forward(x);
111 self.reverse(&vals, seed, grad);
112 }
113
114 fn reverse(&self, vals: &[f64], seed: f64, grad: &mut [f64]) {
115 let n = self.ops.len();
116 let mut adj = vec![0.0f64; n];
117 adj[n - 1] = seed;
118
119 for i in (0..n).rev() {
120 let a = adj[i];
121 if a == 0.0 {
122 continue;
123 }
124 match &self.ops[i] {
125 TapeOp::Const(_) => {}
126 TapeOp::Var(j) => {
127 grad[*j] += a;
128 }
129 TapeOp::Add(l, r) => {
130 adj[*l] += a;
131 adj[*r] += a;
132 }
133 TapeOp::Sub(l, r) => {
134 adj[*l] += a;
135 adj[*r] -= a;
136 }
137 TapeOp::Mul(l, r) => {
138 adj[*l] += a * vals[*r];
139 adj[*r] += a * vals[*l];
140 }
141 TapeOp::Div(l, r) => {
142 let rv = vals[*r];
143 adj[*l] += a / rv;
144 adj[*r] -= a * vals[*l] / (rv * rv);
145 }
146 TapeOp::Pow(l, r) => {
147 let lv = vals[*l];
148 let rv = vals[*r];
149 if rv != 0.0 {
150 adj[*l] += a * rv * lv.powf(rv - 1.0);
151 }
152 if lv > 0.0 {
153 adj[*r] += a * vals[i] * lv.ln();
154 }
155 }
156 TapeOp::Neg(j) => {
157 adj[*j] -= a;
158 }
159 TapeOp::Abs(j) => {
160 if vals[*j] >= 0.0 {
161 adj[*j] += a;
162 } else {
163 adj[*j] -= a;
164 }
165 }
166 TapeOp::Sqrt(j) => {
167 let sv = vals[i];
168 if sv > 0.0 {
169 adj[*j] += a * 0.5 / sv;
170 }
171 }
172 TapeOp::Exp(j) => {
173 adj[*j] += a * vals[i];
174 }
175 TapeOp::Log(j) => {
176 adj[*j] += a / vals[*j];
177 }
178 TapeOp::Log10(j) => {
179 adj[*j] += a / (vals[*j] * std::f64::consts::LN_10);
180 }
181 TapeOp::Sin(j) => {
182 adj[*j] += a * vals[*j].cos();
183 }
184 TapeOp::Cos(j) => {
185 adj[*j] -= a * vals[*j].sin();
186 }
187 }
188 }
189 }
190
191 pub fn variables(&self) -> Vec<usize> {
193 let mut s: BTreeSet<usize> = BTreeSet::new();
194 for op in &self.ops {
195 if let TapeOp::Var(j) = op {
196 s.insert(*j);
197 }
198 }
199 s.into_iter().collect()
200 }
201
202 fn forward_tangent(&self, vals: &[f64], seed_var: usize, dot: &mut [f64]) {
207 let n = self.ops.len();
208 debug_assert_eq!(dot.len(), n);
209 for i in 0..n {
210 dot[i] = match &self.ops[i] {
211 TapeOp::Const(_) => 0.0,
212 TapeOp::Var(k) => {
213 if *k == seed_var {
214 1.0
215 } else {
216 0.0
217 }
218 }
219 TapeOp::Add(a, b) => dot[*a] + dot[*b],
220 TapeOp::Sub(a, b) => dot[*a] - dot[*b],
221 TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
222 TapeOp::Div(a, b) => {
223 let vb = vals[*b];
224 (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
225 }
226 TapeOp::Pow(a, b) => {
227 let u = vals[*a];
228 let r = vals[*b];
229 let du = dot[*a];
230 let dr = dot[*b];
231 let mut result = 0.0;
232 if r != 0.0 && u != 0.0 {
233 result += r * u.powf(r - 1.0) * du;
234 }
235 if u > 0.0 {
236 result += vals[i] * u.ln() * dr;
237 }
238 result
239 }
240 TapeOp::Neg(a) => -dot[*a],
241 TapeOp::Abs(a) => {
242 if vals[*a] >= 0.0 {
243 dot[*a]
244 } else {
245 -dot[*a]
246 }
247 }
248 TapeOp::Sqrt(a) => {
249 let sv = vals[i];
250 if sv > 0.0 {
251 dot[*a] * 0.5 / sv
252 } else {
253 0.0
254 }
255 }
256 TapeOp::Exp(a) => dot[*a] * vals[i],
257 TapeOp::Log(a) => dot[*a] / vals[*a],
258 TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
259 TapeOp::Sin(a) => dot[*a] * vals[*a].cos(),
260 TapeOp::Cos(a) => -dot[*a] * vals[*a].sin(),
261 };
262 }
263 }
264
265 pub fn forward_into(&self, x: &[f64], vals: &mut [f64]) {
269 let n = self.ops.len();
270 debug_assert!(vals.len() >= n);
271 for i in 0..n {
272 vals[i] = match &self.ops[i] {
273 TapeOp::Const(c) => *c,
274 TapeOp::Var(j) => x[*j],
275 TapeOp::Add(a, b) => vals[*a] + vals[*b],
276 TapeOp::Sub(a, b) => vals[*a] - vals[*b],
277 TapeOp::Mul(a, b) => vals[*a] * vals[*b],
278 TapeOp::Div(a, b) => vals[*a] / vals[*b],
279 TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
280 TapeOp::Neg(a) => -vals[*a],
281 TapeOp::Abs(a) => vals[*a].abs(),
282 TapeOp::Sqrt(a) => vals[*a].sqrt(),
283 TapeOp::Exp(a) => vals[*a].exp(),
284 TapeOp::Log(a) => vals[*a].ln(),
285 TapeOp::Log10(a) => vals[*a].log10(),
286 TapeOp::Sin(a) => vals[*a].sin(),
287 TapeOp::Cos(a) => vals[*a].cos(),
288 };
289 }
290 }
291
292 pub fn hessian_directional(
309 &self,
310 vals: &[f64],
311 seed: &[f64],
312 weight: f64,
313 out: &mut [f64],
314 dot: &mut [f64],
315 adj: &mut [f64],
316 adj_dot: &mut [f64],
317 ) {
318 let n = self.ops.len();
319 if n == 0 || weight == 0.0 {
320 return;
321 }
322 debug_assert!(vals.len() >= n);
323 debug_assert!(dot.len() >= n);
324 debug_assert!(adj.len() >= n);
325 debug_assert!(adj_dot.len() >= n);
326
327 for i in 0..n {
331 dot[i] = match &self.ops[i] {
332 TapeOp::Const(_) => 0.0,
333 TapeOp::Var(k) => seed[*k],
334 TapeOp::Add(a, b) => dot[*a] + dot[*b],
335 TapeOp::Sub(a, b) => dot[*a] - dot[*b],
336 TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
337 TapeOp::Div(a, b) => {
338 let vb = vals[*b];
339 (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
340 }
341 TapeOp::Pow(a, b) => {
342 let u = vals[*a];
343 let r = vals[*b];
344 let du = dot[*a];
345 let dr = dot[*b];
346 let mut result = 0.0;
347 if r != 0.0 && u != 0.0 {
348 result += r * u.powf(r - 1.0) * du;
349 }
350 if u > 0.0 {
351 result += vals[i] * u.ln() * dr;
352 }
353 result
354 }
355 TapeOp::Neg(a) => -dot[*a],
356 TapeOp::Abs(a) => {
357 if vals[*a] >= 0.0 {
358 dot[*a]
359 } else {
360 -dot[*a]
361 }
362 }
363 TapeOp::Sqrt(a) => {
364 let sv = vals[i];
365 if sv > 0.0 {
366 dot[*a] * 0.5 / sv
367 } else {
368 0.0
369 }
370 }
371 TapeOp::Exp(a) => vals[i] * dot[*a],
372 TapeOp::Log(a) => dot[*a] / vals[*a],
373 TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
374 TapeOp::Sin(a) => vals[*a].cos() * dot[*a],
375 TapeOp::Cos(a) => -vals[*a].sin() * dot[*a],
376 };
377 }
378
379 for slot in adj.iter_mut().take(n) {
383 *slot = 0.0;
384 }
385 for slot in adj_dot.iter_mut().take(n) {
386 *slot = 0.0;
387 }
388 adj[n - 1] = 1.0;
389
390 for i in (0..n).rev() {
391 let w = adj[i];
392 let wd = adj_dot[i];
393 if w == 0.0 && wd == 0.0 {
394 continue;
395 }
396 match &self.ops[i] {
397 TapeOp::Const(_) => {}
398 TapeOp::Var(k) => {
399 if wd != 0.0 {
400 out[*k] += weight * wd;
401 }
402 }
403 TapeOp::Add(a, b) => {
404 adj[*a] += w;
405 adj[*b] += w;
406 adj_dot[*a] += wd;
407 adj_dot[*b] += wd;
408 }
409 TapeOp::Sub(a, b) => {
410 adj[*a] += w;
411 adj[*b] -= w;
412 adj_dot[*a] += wd;
413 adj_dot[*b] -= wd;
414 }
415 TapeOp::Mul(a, b) => {
416 adj[*a] += w * vals[*b];
417 adj[*b] += w * vals[*a];
418 adj_dot[*a] += wd * vals[*b] + w * dot[*b];
419 adj_dot[*b] += wd * vals[*a] + w * dot[*a];
420 }
421 TapeOp::Div(a, b) => {
422 let vb = vals[*b];
423 let vb2 = vb * vb;
424 let vb3 = vb2 * vb;
425 adj[*a] += w / vb;
426 adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
427 adj[*b] += w * (-vals[*a] / vb2);
428 adj_dot[*b] += wd * (-vals[*a] / vb2)
429 + w * (-dot[*a] / vb2 + 2.0 * vals[*a] * dot[*b] / vb3);
430 }
431 TapeOp::Pow(a, b) => {
432 let u = vals[*a];
433 let r = vals[*b];
434 let du = dot[*a];
435 let dr = dot[*b];
436 if r != 0.0 {
437 if u != 0.0 {
438 let p_a = r * u.powf(r - 1.0);
439 adj[*a] += w * p_a;
440 let mut dp_a = dr * u.powf(r - 1.0);
441 if u > 0.0 {
442 dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
443 } else {
444 dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
445 }
446 adj_dot[*a] += wd * p_a + w * dp_a;
447 } else if r >= 2.0 {
448 let p_a = 0.0;
449 adj[*a] += w * p_a;
450 let dp_a = if r == 2.0 {
451 2.0 * du
452 } else {
453 r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
454 };
455 adj_dot[*a] += wd * p_a + w * dp_a;
456 }
457 }
458 if u > 0.0 {
459 let ln_u = u.ln();
460 let p_b = vals[i] * ln_u;
461 adj[*b] += w * p_b;
462 let dur = vals[i] * (r * du / u + dr * ln_u);
463 let dp_b = dur * ln_u + vals[i] * du / u;
464 adj_dot[*b] += wd * p_b + w * dp_b;
465 }
466 }
467 TapeOp::Neg(a) => {
468 adj[*a] -= w;
469 adj_dot[*a] -= wd;
470 }
471 TapeOp::Abs(a) => {
472 let s = if vals[*a] >= 0.0 { 1.0 } else { -1.0 };
473 adj[*a] += w * s;
474 adj_dot[*a] += wd * s;
475 }
476 TapeOp::Sqrt(a) => {
477 let sv = vals[i];
478 if sv > 0.0 {
479 let fp = 0.5 / sv;
480 let fpp = -0.25 / (vals[*a] * sv);
481 adj[*a] += w * fp;
482 adj_dot[*a] += wd * fp + w * fpp * dot[*a];
483 }
484 }
485 TapeOp::Exp(a) => {
486 let ev = vals[i];
487 adj[*a] += w * ev;
488 adj_dot[*a] += wd * ev + w * ev * dot[*a];
489 }
490 TapeOp::Log(a) => {
491 let u = vals[*a];
492 adj[*a] += w / u;
493 adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
494 }
495 TapeOp::Log10(a) => {
496 let u = vals[*a];
497 let c = std::f64::consts::LN_10;
498 adj[*a] += w / (u * c);
499 adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
500 }
501 TapeOp::Sin(a) => {
502 let u = vals[*a];
503 let cu = u.cos();
504 adj[*a] += w * cu;
505 adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
506 }
507 TapeOp::Cos(a) => {
508 let u = vals[*a];
509 let su = u.sin();
510 adj[*a] -= w * su;
511 adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
512 }
513 }
514 }
515 }
516
517 pub fn hessian_accumulate(
524 &self,
525 x: &[f64],
526 weight: f64,
527 hess_map: &HashMap<(usize, usize), usize>,
528 values: &mut [f64],
529 ) {
530 let n = self.ops.len();
531 if n == 0 || weight == 0.0 {
532 return;
533 }
534 let v = self.forward(x);
535 let var_indices = self.variables();
536
537 let mut dot = vec![0.0f64; n];
544 let mut adj = vec![0.0f64; n];
545 let mut adj_dot = vec![0.0f64; n];
546 for &j in &var_indices {
547 self.forward_tangent(&v, j, &mut dot);
548
549 adj.fill(0.0);
552 adj_dot.fill(0.0);
553 adj[n - 1] = 1.0;
554
555 for i in (0..n).rev() {
556 let w = adj[i];
557 let wd = adj_dot[i];
558 if w == 0.0 && wd == 0.0 {
559 continue;
560 }
561 match &self.ops[i] {
562 TapeOp::Const(_) => {}
563 TapeOp::Var(k) => {
564 if wd != 0.0 && *k >= j {
567 if let Some(&pos) = hess_map.get(&(*k, j)) {
568 values[pos] += weight * wd;
569 }
570 }
571 }
572 TapeOp::Add(a, b) => {
573 adj[*a] += w;
574 adj[*b] += w;
575 adj_dot[*a] += wd;
576 adj_dot[*b] += wd;
577 }
578 TapeOp::Sub(a, b) => {
579 adj[*a] += w;
580 adj[*b] -= w;
581 adj_dot[*a] += wd;
582 adj_dot[*b] -= wd;
583 }
584 TapeOp::Mul(a, b) => {
585 adj[*a] += w * v[*b];
586 adj[*b] += w * v[*a];
587 adj_dot[*a] += wd * v[*b] + w * dot[*b];
588 adj_dot[*b] += wd * v[*a] + w * dot[*a];
589 }
590 TapeOp::Div(a, b) => {
591 let vb = v[*b];
592 let vb2 = vb * vb;
593 let vb3 = vb2 * vb;
594 adj[*a] += w / vb;
595 adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
596 adj[*b] += w * (-v[*a] / vb2);
597 adj_dot[*b] += wd * (-v[*a] / vb2)
598 + w * (-dot[*a] / vb2 + 2.0 * v[*a] * dot[*b] / vb3);
599 }
600 TapeOp::Pow(a, b) => {
601 let u = v[*a];
602 let r = v[*b];
603 let du = dot[*a];
604 let dr = dot[*b];
605 if r != 0.0 {
606 if u != 0.0 {
607 let p_a = r * u.powf(r - 1.0);
608 adj[*a] += w * p_a;
609 let mut dp_a = dr * u.powf(r - 1.0);
610 if u > 0.0 {
611 dp_a +=
612 r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
613 } else {
614 dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
615 }
616 adj_dot[*a] += wd * p_a + w * dp_a;
617 } else if r >= 2.0 {
618 let p_a = 0.0;
619 adj[*a] += w * p_a;
620 let dp_a = if r == 2.0 {
621 2.0 * du
622 } else {
623 r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
624 };
625 adj_dot[*a] += wd * p_a + w * dp_a;
626 }
627 }
628 if u > 0.0 {
629 let ln_u = u.ln();
630 let p_b = v[i] * ln_u;
631 adj[*b] += w * p_b;
632 let dur = v[i] * (r * du / u + dr * ln_u);
633 let dp_b = dur * ln_u + v[i] * du / u;
634 adj_dot[*b] += wd * p_b + w * dp_b;
635 }
636 }
637 TapeOp::Neg(a) => {
638 adj[*a] -= w;
639 adj_dot[*a] -= wd;
640 }
641 TapeOp::Abs(a) => {
642 let s = if v[*a] >= 0.0 { 1.0 } else { -1.0 };
643 adj[*a] += w * s;
644 adj_dot[*a] += wd * s;
645 }
646 TapeOp::Sqrt(a) => {
647 let sv = v[i];
648 if sv > 0.0 {
649 let fp = 0.5 / sv;
650 let fpp = -0.25 / (v[*a] * sv);
651 adj[*a] += w * fp;
652 adj_dot[*a] += wd * fp + w * fpp * dot[*a];
653 }
654 }
655 TapeOp::Exp(a) => {
656 let ev = v[i];
657 adj[*a] += w * ev;
658 adj_dot[*a] += wd * ev + w * ev * dot[*a];
659 }
660 TapeOp::Log(a) => {
661 let u = v[*a];
662 adj[*a] += w / u;
663 adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
664 }
665 TapeOp::Log10(a) => {
666 let u = v[*a];
667 let c = std::f64::consts::LN_10;
668 adj[*a] += w / (u * c);
669 adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
670 }
671 TapeOp::Sin(a) => {
672 let u = v[*a];
673 let cu = u.cos();
674 adj[*a] += w * cu;
675 adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
676 }
677 TapeOp::Cos(a) => {
678 let u = v[*a];
679 let su = u.sin();
680 adj[*a] -= w * su;
681 adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
682 }
683 }
684 }
685 }
686 }
687
688 pub fn hessian_sparsity(&self) -> BTreeSet<(usize, usize)> {
693 let n = self.ops.len();
694 let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(n);
695 let mut pairs: BTreeSet<(usize, usize)> = BTreeSet::new();
696
697 let emit_cross =
698 |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
699 for &v1 in s1 {
700 for &v2 in s2 {
701 let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
702 pairs.insert((r, c));
703 }
704 }
705 };
706 let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
707 let vars: Vec<usize> = s.iter().copied().collect();
708 for (ai, &vi) in vars.iter().enumerate() {
709 for &vj in &vars[..=ai] {
710 let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
711 pairs.insert((r, c));
712 }
713 }
714 };
715
716 for op in &self.ops {
717 let vset = match op {
718 TapeOp::Const(_) => BTreeSet::new(),
719 TapeOp::Var(j) => {
720 let mut s = BTreeSet::new();
721 s.insert(*j);
722 s
723 }
724 TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
725 var_sets[*a].union(&var_sets[*b]).copied().collect()
726 }
727 TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
728 TapeOp::Mul(a, b) => {
729 emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
730 var_sets[*a].union(&var_sets[*b]).copied().collect()
731 }
732 TapeOp::Div(a, b) => {
733 emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
734 emit_self(&var_sets[*b], &mut pairs);
735 var_sets[*a].union(&var_sets[*b]).copied().collect()
736 }
737 TapeOp::Pow(a, b) => {
738 let combined: BTreeSet<usize> =
739 var_sets[*a].union(&var_sets[*b]).copied().collect();
740 emit_self(&combined, &mut pairs);
741 combined
742 }
743 TapeOp::Sqrt(a)
744 | TapeOp::Exp(a)
745 | TapeOp::Log(a)
746 | TapeOp::Log10(a)
747 | TapeOp::Sin(a)
748 | TapeOp::Cos(a) => {
749 emit_self(&var_sets[*a], &mut pairs);
750 var_sets[*a].clone()
751 }
752 };
753 var_sets.push(vset);
754 }
755 pairs
756 }
757}
758
759fn build_recursive(
760 expr: &Expr,
761 ops: &mut Vec<TapeOp>,
762 cache: &mut HashMap<*const Expr, usize>,
763) -> usize {
764 match expr {
765 Expr::Const(c) => {
766 let idx = ops.len();
767 ops.push(TapeOp::Const(*c));
768 idx
769 }
770 Expr::Var(i) => {
771 let idx = ops.len();
772 ops.push(TapeOp::Var(*i));
773 idx
774 }
775 Expr::Binary(op, a, b) => {
776 if let BinOp::Pow = op {
784 if let Some(c) = peek_const(b) {
785 if let Some(idx) = try_emit_const_pow(a, c, ops, cache) {
786 return idx;
787 }
788 }
789 }
790 let l = build_recursive(a, ops, cache);
791 let r = build_recursive(b, ops, cache);
792 let idx = ops.len();
793 ops.push(match op {
794 BinOp::Add => TapeOp::Add(l, r),
795 BinOp::Sub => TapeOp::Sub(l, r),
796 BinOp::Mul => TapeOp::Mul(l, r),
797 BinOp::Div => TapeOp::Div(l, r),
798 BinOp::Pow => TapeOp::Pow(l, r),
799 });
800 idx
801 }
802 Expr::Unary(op, a) => {
803 let v = build_recursive(a, ops, cache);
804 let idx = ops.len();
805 ops.push(match op {
806 UnaryOp::Neg => TapeOp::Neg(v),
807 UnaryOp::Sqrt => TapeOp::Sqrt(v),
808 UnaryOp::Log => TapeOp::Log(v),
809 UnaryOp::Log10 => TapeOp::Log10(v),
810 UnaryOp::Exp => TapeOp::Exp(v),
811 UnaryOp::Abs => TapeOp::Abs(v),
812 UnaryOp::Sin => TapeOp::Sin(v),
813 UnaryOp::Cos => TapeOp::Cos(v),
814 });
815 idx
816 }
817 Expr::Sum(args) => {
818 if args.is_empty() {
819 let idx = ops.len();
820 ops.push(TapeOp::Const(0.0));
821 return idx;
822 }
823 let mut acc = build_recursive(&args[0], ops, cache);
824 for a in &args[1..] {
825 let next = build_recursive(a, ops, cache);
826 let idx = ops.len();
827 ops.push(TapeOp::Add(acc, next));
828 acc = idx;
829 }
830 acc
831 }
832 Expr::Cse(body) => {
833 let key = Rc::as_ptr(body) as *const Expr;
840 if let Some(&idx) = cache.get(&key) {
841 idx
842 } else {
843 let idx = build_recursive(body, ops, cache);
844 cache.insert(key, idx);
845 idx
846 }
847 }
848 }
849}
850
851fn peek_const(e: &Expr) -> Option<f64> {
855 match e {
856 Expr::Const(c) => Some(*c),
857 Expr::Cse(body) => peek_const(body),
858 _ => None,
859 }
860}
861
862fn try_emit_const_pow(
870 base_expr: &Expr,
871 c: f64,
872 ops: &mut Vec<TapeOp>,
873 cache: &mut HashMap<*const Expr, usize>,
874) -> Option<usize> {
875 if c == 0.0 {
876 let idx = ops.len();
877 ops.push(TapeOp::Const(1.0));
878 return Some(idx);
879 }
880 if c == 1.0 {
881 return Some(build_recursive(base_expr, ops, cache));
882 }
883 if c == 0.5 {
884 let b = build_recursive(base_expr, ops, cache);
885 let idx = ops.len();
886 ops.push(TapeOp::Sqrt(b));
887 return Some(idx);
888 }
889 if c.is_finite() && c.fract() == 0.0 && c.abs() <= 8.0 {
894 let n = c.abs() as u32;
895 if n == 0 {
896 let idx = ops.len();
898 ops.push(TapeOp::Const(1.0));
899 return Some(idx);
900 }
901 let b = build_recursive(base_expr, ops, cache);
902 let pos = emit_int_pow(b, n, ops);
903 if c < 0.0 {
904 let one_idx = ops.len();
907 ops.push(TapeOp::Const(1.0));
908 let idx = ops.len();
909 ops.push(TapeOp::Div(one_idx, pos));
910 return Some(idx);
911 }
912 return Some(pos);
913 }
914 None
915}
916
917fn emit_int_pow(base: usize, n: u32, ops: &mut Vec<TapeOp>) -> usize {
921 debug_assert!(n >= 1);
922 if n == 1 {
923 return base;
924 }
925 let half = emit_int_pow(base, n / 2, ops);
926 let squared = ops.len();
927 ops.push(TapeOp::Mul(half, half));
928 if n % 2 == 1 {
929 let idx = ops.len();
930 ops.push(TapeOp::Mul(squared, base));
931 idx
932 } else {
933 squared
934 }
935}
936
/// One op in a [`Summand`]'s local tape.
#[derive(Debug, Clone)]
pub enum SummandOp {
    /// An ordinary op evaluated in the summand's own value buffer; its operand indices are
    /// earlier slots of the same summand.
    Local(TapeOp),
    /// Copies slot `k` of the shared prelude (`prelude_vals[k]`) into this local slot.
    Shared(usize),
}
974
/// One expression of a [`HybridTape`]: a local tape plus precomputed reachability data.
#[derive(Debug, Clone)]
pub struct Summand {
    /// The summand's local ops, in evaluation order.
    pub ops: Vec<SummandOp>,
    /// Slot of `ops` holding this summand's value.
    pub root_slot: usize,
    /// Indices of `ops` reachable from `root_slot`, sorted ascending.
    pub local_reach: Vec<usize>,
    /// Indices of prelude ops reachable through this summand's `Shared` slots, sorted
    /// ascending (empty when the summand references no prelude op).
    pub prelude_reach: Vec<usize>,
    /// Distinct variable indices read by `Var` ops in `local_reach`, ascending.
    pub local_vars: Vec<usize>,
    /// Variables contributed by the reachable prelude ops (computed by `vars_in`;
    /// presumably distinct and ascending — TODO confirm against `vars_in`).
    pub prelude_vars: Vec<usize>,
    /// Union of `local_vars` and `prelude_vars`, ascending and deduplicated.
    pub all_vars: Vec<usize>,
}
992
/// A family of expressions compiled as one shared prelude plus one local tape per
/// expression.
#[derive(Debug)]
pub struct HybridTape {
    /// Ops for `Cse` sub-expressions referenced by two or more of the input expressions,
    /// evaluated once per point and shared by every summand.
    pub prelude: Vec<TapeOp>,
    /// One summand per input expression, in input order.
    pub summands: Vec<Summand>,
}
1002
1003impl HybridTape {
1004 pub fn build_multi(exprs: &[Expr]) -> Self {
1009 let mut cse_count: HashMap<*const Expr, usize> = HashMap::new();
1013 for e in exprs {
1014 let mut seen_in_root: HashSet<*const Expr> = HashSet::new();
1015 count_cse_appearances(e, &mut seen_in_root, &mut cse_count);
1016 }
1017
1018 let mut prelude: Vec<TapeOp> = Vec::new();
1023 let mut prelude_map: HashMap<*const Expr, usize> = HashMap::new();
1024 let mut summands: Vec<Summand> = Vec::with_capacity(exprs.len());
1025 for e in exprs {
1026 let mut local: Vec<SummandOp> = Vec::new();
1027 let mut local_cache: HashMap<*const Expr, usize> = HashMap::new();
1028 let root_slot = build_into_summand(
1029 e,
1030 &mut local,
1031 &mut local_cache,
1032 &mut prelude,
1033 &mut prelude_map,
1034 &cse_count,
1035 );
1036 summands.push(Summand {
1037 ops: local,
1038 root_slot,
1039 local_reach: Vec::new(),
1040 prelude_reach: Vec::new(),
1041 local_vars: Vec::new(),
1042 prelude_vars: Vec::new(),
1043 all_vars: Vec::new(),
1044 });
1045 }
1046
1047 let mut p_visited: Vec<u32> = vec![0; prelude.len()];
1051 let mut p_epoch: u32 = 0;
1052 let mut p_stack: Vec<usize> = Vec::new();
1053 for s in &mut summands {
1054 let (local_reach, shared_refs) = compute_local_reach(&s.ops, s.root_slot);
1055 s.local_reach = local_reach;
1056
1057 let mut lv: BTreeSet<usize> = BTreeSet::new();
1058 for &i in &s.local_reach {
1059 if let SummandOp::Local(TapeOp::Var(j)) = &s.ops[i] {
1060 lv.insert(*j);
1061 }
1062 }
1063 s.local_vars = lv.iter().copied().collect();
1064
1065 if !shared_refs.is_empty() {
1066 p_epoch += 1;
1067 let mut preach: Vec<usize> = Vec::new();
1068 for &start in &shared_refs {
1069 bfs_prelude(
1070 &prelude,
1071 start,
1072 &mut p_visited,
1073 p_epoch,
1074 &mut p_stack,
1075 &mut preach,
1076 );
1077 }
1078 preach.sort_unstable();
1079 s.prelude_vars = vars_in(&prelude, &preach);
1080 s.prelude_reach = preach;
1081 }
1082
1083 let mut av: BTreeSet<usize> = lv;
1084 for &v in &s.prelude_vars {
1085 av.insert(v);
1086 }
1087 s.all_vars = av.into_iter().collect();
1088 }
1089
1090 HybridTape { prelude, summands }
1091 }
1092
1093 pub fn n_prelude_ops(&self) -> usize {
1094 self.prelude.len()
1095 }
1096 pub fn n_summands(&self) -> usize {
1097 self.summands.len()
1098 }
1099 pub fn max_summand_ops(&self) -> usize {
1100 self.summands.iter().map(|s| s.ops.len()).max().unwrap_or(0)
1101 }
1102 pub fn total_local_ops(&self) -> usize {
1103 self.summands.iter().map(|s| s.ops.len()).sum()
1104 }
1105
1106 pub fn forward_prelude(&self, x: &[f64], prelude_vals: &mut [f64]) {
1109 debug_assert_eq!(prelude_vals.len(), self.prelude.len());
1110 for i in 0..self.prelude.len() {
1111 prelude_vals[i] = fwd_step(&self.prelude[i], x, prelude_vals);
1112 }
1113 }
1114
1115 pub fn forward_summand(
1118 &self,
1119 s: &Summand,
1120 x: &[f64],
1121 prelude_vals: &[f64],
1122 local_vals: &mut [f64],
1123 ) {
1124 debug_assert!(local_vals.len() >= s.ops.len());
1125 for i in 0..s.ops.len() {
1126 local_vals[i] = match &s.ops[i] {
1127 SummandOp::Local(op) => fwd_step(op, x, local_vals),
1128 SummandOp::Shared(k) => prelude_vals[*k],
1129 };
1130 }
1131 }
1132
1133 #[inline]
1135 pub fn root_value(&self, s: &Summand, local_vals: &[f64]) -> f64 {
1136 local_vals[s.root_slot]
1137 }
1138
1139 #[allow(clippy::too_many_arguments)]
1146 pub fn gradient_summand(
1147 &self,
1148 s: &Summand,
1149 prelude_vals: &[f64],
1150 local_vals: &[f64],
1151 seed: f64,
1152 grad: &mut [f64],
1153 local_adj: &mut [f64],
1154 prelude_adj: &mut [f64],
1155 ) {
1156 if seed == 0.0 || s.local_reach.is_empty() {
1157 return;
1158 }
1159 for &i in &s.local_reach {
1160 local_adj[i] = 0.0;
1161 }
1162 for &i in &s.prelude_reach {
1163 prelude_adj[i] = 0.0;
1164 }
1165 local_adj[s.root_slot] = seed;
1166 for &i in s.local_reach.iter().rev() {
1167 let a = local_adj[i];
1168 if a == 0.0 {
1169 continue;
1170 }
1171 match &s.ops[i] {
1172 SummandOp::Local(op) => rev_step(op, i, local_vals, local_adj, a, grad),
1173 SummandOp::Shared(k) => {
1174 prelude_adj[*k] += a;
1175 }
1176 }
1177 }
1178 for &i in s.prelude_reach.iter().rev() {
1179 let a = prelude_adj[i];
1180 if a == 0.0 {
1181 continue;
1182 }
1183 rev_step(&self.prelude[i], i, prelude_vals, prelude_adj, a, grad);
1184 }
1185 }
1186
1187 #[allow(clippy::too_many_arguments)]
1195 pub fn hessian_summand(
1196 &self,
1197 s: &Summand,
1198 prelude_vals: &[f64],
1199 local_vals: &[f64],
1200 weight: f64,
1201 hess_map: &HashMap<(usize, usize), usize>,
1202 values: &mut [f64],
1203 local_dot: &mut [f64],
1204 local_adj: &mut [f64],
1205 local_adj_dot: &mut [f64],
1206 prelude_dot: &mut [f64],
1207 prelude_adj: &mut [f64],
1208 prelude_adj_dot: &mut [f64],
1209 ) {
1210 if weight == 0.0 || s.local_reach.is_empty() {
1211 return;
1212 }
1213 for &j in &s.all_vars {
1214 for &i in &s.local_reach {
1215 local_dot[i] = 0.0;
1216 local_adj[i] = 0.0;
1217 local_adj_dot[i] = 0.0;
1218 }
1219 for &i in &s.prelude_reach {
1220 prelude_dot[i] = 0.0;
1221 prelude_adj[i] = 0.0;
1222 prelude_adj_dot[i] = 0.0;
1223 }
1224 for &i in &s.prelude_reach {
1225 prelude_dot[i] = fwd_tan_step(&self.prelude[i], j, prelude_vals, prelude_dot, i);
1226 }
1227 for &i in &s.local_reach {
1228 local_dot[i] = match &s.ops[i] {
1229 SummandOp::Local(op) => fwd_tan_step(op, j, local_vals, local_dot, i),
1230 SummandOp::Shared(k) => prelude_dot[*k],
1231 };
1232 }
1233 local_adj[s.root_slot] = 1.0;
1234 for &i in s.local_reach.iter().rev() {
1235 let w = local_adj[i];
1236 let wd = local_adj_dot[i];
1237 if w == 0.0 && wd == 0.0 {
1238 continue;
1239 }
1240 match &s.ops[i] {
1241 SummandOp::Local(op) => {
1242 ror_step(
1243 op,
1244 i,
1245 j,
1246 local_vals,
1247 local_dot,
1248 local_adj,
1249 local_adj_dot,
1250 w,
1251 wd,
1252 weight,
1253 hess_map,
1254 values,
1255 );
1256 }
1257 SummandOp::Shared(k) => {
1258 prelude_adj[*k] += w;
1259 prelude_adj_dot[*k] += wd;
1260 }
1261 }
1262 }
1263 for &i in s.prelude_reach.iter().rev() {
1264 let w = prelude_adj[i];
1265 let wd = prelude_adj_dot[i];
1266 if w == 0.0 && wd == 0.0 {
1267 continue;
1268 }
1269 ror_step(
1270 &self.prelude[i],
1271 i,
1272 j,
1273 prelude_vals,
1274 prelude_dot,
1275 prelude_adj,
1276 prelude_adj_dot,
1277 w,
1278 wd,
1279 weight,
1280 hess_map,
1281 values,
1282 );
1283 }
1284 }
1285 }
1286
1287 pub fn hessian_sparsity_all(&self) -> BTreeSet<(usize, usize)> {
1290 let mut pairs = hessian_sparsity_impl(&self.prelude);
1291
1292 let prelude_var_sets = compute_var_sets(&self.prelude);
1295
1296 for s in &self.summands {
1297 summand_sparsity(&s.ops, &prelude_var_sets, &mut pairs);
1298 }
1299 pairs
1300 }
1301}
1302
1303fn count_cse_appearances(
1308 e: &Expr,
1309 seen_in_root: &mut HashSet<*const Expr>,
1310 counts: &mut HashMap<*const Expr, usize>,
1311) {
1312 match e {
1313 Expr::Const(_) | Expr::Var(_) => {}
1314 Expr::Binary(_, a, b) => {
1315 count_cse_appearances(a, seen_in_root, counts);
1316 count_cse_appearances(b, seen_in_root, counts);
1317 }
1318 Expr::Unary(_, a) => count_cse_appearances(a, seen_in_root, counts),
1319 Expr::Sum(args) => {
1320 for a in args {
1321 count_cse_appearances(a, seen_in_root, counts);
1322 }
1323 }
1324 Expr::Cse(body) => {
1325 let key = Rc::as_ptr(body) as *const Expr;
1326 if seen_in_root.insert(key) {
1327 *counts.entry(key).or_insert(0) += 1;
1328 count_cse_appearances(body, seen_in_root, counts);
1329 }
1330 }
1331 }
1332}
1333
1334fn build_into_summand(
1340 expr: &Expr,
1341 local: &mut Vec<SummandOp>,
1342 local_cache: &mut HashMap<*const Expr, usize>,
1343 prelude: &mut Vec<TapeOp>,
1344 prelude_map: &mut HashMap<*const Expr, usize>,
1345 cse_count: &HashMap<*const Expr, usize>,
1346) -> usize {
1347 match expr {
1348 Expr::Const(c) => {
1349 let i = local.len();
1350 local.push(SummandOp::Local(TapeOp::Const(*c)));
1351 i
1352 }
1353 Expr::Var(j) => {
1354 let i = local.len();
1355 local.push(SummandOp::Local(TapeOp::Var(*j)));
1356 i
1357 }
1358 Expr::Binary(op, a, b) => {
1359 if let BinOp::Pow = op {
1360 if let Some(c) = peek_const(b) {
1361 if let Some(i) = try_emit_const_pow_summand(
1362 a,
1363 c,
1364 local,
1365 local_cache,
1366 prelude,
1367 prelude_map,
1368 cse_count,
1369 ) {
1370 return i;
1371 }
1372 }
1373 }
1374 let l = build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
1375 let r = build_into_summand(b, local, local_cache, prelude, prelude_map, cse_count);
1376 let i = local.len();
1377 local.push(SummandOp::Local(match op {
1378 BinOp::Add => TapeOp::Add(l, r),
1379 BinOp::Sub => TapeOp::Sub(l, r),
1380 BinOp::Mul => TapeOp::Mul(l, r),
1381 BinOp::Div => TapeOp::Div(l, r),
1382 BinOp::Pow => TapeOp::Pow(l, r),
1383 }));
1384 i
1385 }
1386 Expr::Unary(op, a) => {
1387 let v = build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
1388 let i = local.len();
1389 local.push(SummandOp::Local(match op {
1390 UnaryOp::Neg => TapeOp::Neg(v),
1391 UnaryOp::Sqrt => TapeOp::Sqrt(v),
1392 UnaryOp::Log => TapeOp::Log(v),
1393 UnaryOp::Log10 => TapeOp::Log10(v),
1394 UnaryOp::Exp => TapeOp::Exp(v),
1395 UnaryOp::Abs => TapeOp::Abs(v),
1396 UnaryOp::Sin => TapeOp::Sin(v),
1397 UnaryOp::Cos => TapeOp::Cos(v),
1398 }));
1399 i
1400 }
1401 Expr::Sum(args) => {
1402 if args.is_empty() {
1403 let i = local.len();
1404 local.push(SummandOp::Local(TapeOp::Const(0.0)));
1405 return i;
1406 }
1407 let mut acc = build_into_summand(
1408 &args[0],
1409 local,
1410 local_cache,
1411 prelude,
1412 prelude_map,
1413 cse_count,
1414 );
1415 for a in &args[1..] {
1416 let nxt =
1417 build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
1418 let i = local.len();
1419 local.push(SummandOp::Local(TapeOp::Add(acc, nxt)));
1420 acc = i;
1421 }
1422 acc
1423 }
1424 Expr::Cse(body) => {
1425 let key = Rc::as_ptr(body) as *const Expr;
1426 if let Some(&li) = local_cache.get(&key) {
1427 return li;
1428 }
1429 let promoted = cse_count.get(&key).copied().unwrap_or(0) >= 2;
1430 if promoted {
1431 let pslot = build_recursive(expr, prelude, prelude_map);
1436 let li = local.len();
1437 local.push(SummandOp::Shared(pslot));
1438 local_cache.insert(key, li);
1439 li
1440 } else {
1441 let li =
1442 build_into_summand(body, local, local_cache, prelude, prelude_map, cse_count);
1443 local_cache.insert(key, li);
1444 li
1445 }
1446 }
1447 }
1448}
1449
1450fn try_emit_const_pow_summand(
1453 base_expr: &Expr,
1454 c: f64,
1455 local: &mut Vec<SummandOp>,
1456 local_cache: &mut HashMap<*const Expr, usize>,
1457 prelude: &mut Vec<TapeOp>,
1458 prelude_map: &mut HashMap<*const Expr, usize>,
1459 cse_count: &HashMap<*const Expr, usize>,
1460) -> Option<usize> {
1461 if c == 0.0 {
1462 let i = local.len();
1463 local.push(SummandOp::Local(TapeOp::Const(1.0)));
1464 return Some(i);
1465 }
1466 if c == 1.0 {
1467 return Some(build_into_summand(
1468 base_expr,
1469 local,
1470 local_cache,
1471 prelude,
1472 prelude_map,
1473 cse_count,
1474 ));
1475 }
1476 if c == 0.5 {
1477 let b = build_into_summand(
1478 base_expr,
1479 local,
1480 local_cache,
1481 prelude,
1482 prelude_map,
1483 cse_count,
1484 );
1485 let i = local.len();
1486 local.push(SummandOp::Local(TapeOp::Sqrt(b)));
1487 return Some(i);
1488 }
1489 if c.is_finite() && c.fract() == 0.0 && c.abs() <= 8.0 {
1490 let n = c.abs() as u32;
1491 if n == 0 {
1492 let i = local.len();
1493 local.push(SummandOp::Local(TapeOp::Const(1.0)));
1494 return Some(i);
1495 }
1496 let b = build_into_summand(
1497 base_expr,
1498 local,
1499 local_cache,
1500 prelude,
1501 prelude_map,
1502 cse_count,
1503 );
1504 let pos = emit_int_pow_summand(b, n, local);
1505 if c < 0.0 {
1506 let one_idx = local.len();
1507 local.push(SummandOp::Local(TapeOp::Const(1.0)));
1508 let i = local.len();
1509 local.push(SummandOp::Local(TapeOp::Div(one_idx, pos)));
1510 return Some(i);
1511 }
1512 return Some(pos);
1513 }
1514 None
1515}
1516
1517fn emit_int_pow_summand(base: usize, n: u32, local: &mut Vec<SummandOp>) -> usize {
1518 debug_assert!(n >= 1);
1519 if n == 1 {
1520 return base;
1521 }
1522 let half = emit_int_pow_summand(base, n / 2, local);
1523 let squared = local.len();
1524 local.push(SummandOp::Local(TapeOp::Mul(half, half)));
1525 if n % 2 == 1 {
1526 let i = local.len();
1527 local.push(SummandOp::Local(TapeOp::Mul(squared, base)));
1528 i
1529 } else {
1530 squared
1531 }
1532}
1533
1534fn compute_local_reach(ops: &[SummandOp], root: usize) -> (Vec<usize>, Vec<usize>) {
1538 let mut visited = vec![false; ops.len()];
1539 let mut reach: Vec<usize> = Vec::new();
1540 let mut shared: BTreeSet<usize> = BTreeSet::new();
1541 let mut stack: Vec<usize> = Vec::with_capacity(16);
1542 visited[root] = true;
1543 reach.push(root);
1544 stack.push(root);
1545 while let Some(s) = stack.pop() {
1546 match &ops[s] {
1547 SummandOp::Local(op) => {
1548 let (a, b) = op_operands(op);
1549 if let Some(a) = a {
1550 if !visited[a] {
1551 visited[a] = true;
1552 reach.push(a);
1553 stack.push(a);
1554 }
1555 }
1556 if let Some(b) = b {
1557 if !visited[b] {
1558 visited[b] = true;
1559 reach.push(b);
1560 stack.push(b);
1561 }
1562 }
1563 }
1564 SummandOp::Shared(k) => {
1565 shared.insert(*k);
1566 }
1567 }
1568 }
1569 reach.sort_unstable();
1570 (reach, shared.into_iter().collect())
1571}
1572
1573fn bfs_prelude(
1577 prelude: &[TapeOp],
1578 start: usize,
1579 visited: &mut [u32],
1580 cur: u32,
1581 stack: &mut Vec<usize>,
1582 out: &mut Vec<usize>,
1583) {
1584 if visited[start] == cur {
1585 return;
1586 }
1587 visited[start] = cur;
1588 out.push(start);
1589 stack.push(start);
1590 while let Some(s) = stack.pop() {
1591 let (a, b) = op_operands(&prelude[s]);
1592 if let Some(a) = a {
1593 if visited[a] != cur {
1594 visited[a] = cur;
1595 out.push(a);
1596 stack.push(a);
1597 }
1598 }
1599 if let Some(b) = b {
1600 if visited[b] != cur {
1601 visited[b] = cur;
1602 out.push(b);
1603 stack.push(b);
1604 }
1605 }
1606 }
1607}
1608
1609fn compute_var_sets(ops: &[TapeOp]) -> Vec<BTreeSet<usize>> {
1613 let mut out: Vec<BTreeSet<usize>> = Vec::with_capacity(ops.len());
1614 for op in ops {
1615 let vs: BTreeSet<usize> = match op {
1616 TapeOp::Const(_) => BTreeSet::new(),
1617 TapeOp::Var(j) => {
1618 let mut s = BTreeSet::new();
1619 s.insert(*j);
1620 s
1621 }
1622 TapeOp::Add(a, b)
1623 | TapeOp::Sub(a, b)
1624 | TapeOp::Mul(a, b)
1625 | TapeOp::Div(a, b)
1626 | TapeOp::Pow(a, b) => out[*a].union(&out[*b]).copied().collect(),
1627 TapeOp::Neg(a)
1628 | TapeOp::Abs(a)
1629 | TapeOp::Sqrt(a)
1630 | TapeOp::Exp(a)
1631 | TapeOp::Log(a)
1632 | TapeOp::Log10(a)
1633 | TapeOp::Sin(a)
1634 | TapeOp::Cos(a) => out[*a].clone(),
1635 };
1636 out.push(vs);
1637 }
1638 out
1639}
1640
1641fn summand_sparsity(
1646 ops: &[SummandOp],
1647 prelude_var_sets: &[BTreeSet<usize>],
1648 pairs: &mut BTreeSet<(usize, usize)>,
1649) {
1650 let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(ops.len());
1651 let emit_cross =
1652 |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
1653 for &v1 in s1 {
1654 for &v2 in s2 {
1655 let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
1656 pairs.insert((r, c));
1657 }
1658 }
1659 };
1660 let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
1661 let vars: Vec<usize> = s.iter().copied().collect();
1662 for (ai, &vi) in vars.iter().enumerate() {
1663 for &vj in &vars[..=ai] {
1664 let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
1665 pairs.insert((r, c));
1666 }
1667 }
1668 };
1669 for so in ops {
1670 let vset: BTreeSet<usize> = match so {
1671 SummandOp::Shared(k) => prelude_var_sets[*k].clone(),
1672 SummandOp::Local(op) => match op {
1673 TapeOp::Const(_) => BTreeSet::new(),
1674 TapeOp::Var(j) => {
1675 let mut s = BTreeSet::new();
1676 s.insert(*j);
1677 s
1678 }
1679 TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
1680 var_sets[*a].union(&var_sets[*b]).copied().collect()
1681 }
1682 TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
1683 TapeOp::Mul(a, b) => {
1684 emit_cross(&var_sets[*a], &var_sets[*b], pairs);
1685 var_sets[*a].union(&var_sets[*b]).copied().collect()
1686 }
1687 TapeOp::Div(a, b) => {
1688 emit_cross(&var_sets[*a], &var_sets[*b], pairs);
1689 emit_self(&var_sets[*b], pairs);
1690 var_sets[*a].union(&var_sets[*b]).copied().collect()
1691 }
1692 TapeOp::Pow(a, b) => {
1693 let combined: BTreeSet<usize> =
1694 var_sets[*a].union(&var_sets[*b]).copied().collect();
1695 emit_self(&combined, pairs);
1696 combined
1697 }
1698 TapeOp::Sqrt(a)
1699 | TapeOp::Exp(a)
1700 | TapeOp::Log(a)
1701 | TapeOp::Log10(a)
1702 | TapeOp::Sin(a)
1703 | TapeOp::Cos(a) => {
1704 emit_self(&var_sets[*a], pairs);
1705 var_sets[*a].clone()
1706 }
1707 },
1708 };
1709 var_sets.push(vset);
1710 }
1711}
1712
1713#[inline]
1716fn op_operands(op: &TapeOp) -> (Option<usize>, Option<usize>) {
1717 match op {
1718 TapeOp::Const(_) | TapeOp::Var(_) => (None, None),
1719 TapeOp::Add(a, b)
1720 | TapeOp::Sub(a, b)
1721 | TapeOp::Mul(a, b)
1722 | TapeOp::Div(a, b)
1723 | TapeOp::Pow(a, b) => (Some(*a), Some(*b)),
1724 TapeOp::Neg(a)
1725 | TapeOp::Abs(a)
1726 | TapeOp::Sqrt(a)
1727 | TapeOp::Exp(a)
1728 | TapeOp::Log(a)
1729 | TapeOp::Log10(a)
1730 | TapeOp::Sin(a)
1731 | TapeOp::Cos(a) => (Some(*a), None),
1732 }
1733}
1734
1735fn vars_in(ops: &[TapeOp], reach: &[usize]) -> Vec<usize> {
1736 let mut s: BTreeSet<usize> = BTreeSet::new();
1737 for &i in reach {
1738 if let TapeOp::Var(j) = &ops[i] {
1739 s.insert(*j);
1740 }
1741 }
1742 s.into_iter().collect()
1743}
1744
1745#[inline]
1748fn fwd_step(op: &TapeOp, x: &[f64], vals: &[f64]) -> f64 {
1749 match op {
1750 TapeOp::Const(c) => *c,
1751 TapeOp::Var(i) => x[*i],
1752 TapeOp::Add(a, b) => vals[*a] + vals[*b],
1753 TapeOp::Sub(a, b) => vals[*a] - vals[*b],
1754 TapeOp::Mul(a, b) => vals[*a] * vals[*b],
1755 TapeOp::Div(a, b) => vals[*a] / vals[*b],
1756 TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
1757 TapeOp::Neg(a) => -vals[*a],
1758 TapeOp::Abs(a) => vals[*a].abs(),
1759 TapeOp::Sqrt(a) => vals[*a].sqrt(),
1760 TapeOp::Exp(a) => vals[*a].exp(),
1761 TapeOp::Log(a) => vals[*a].ln(),
1762 TapeOp::Log10(a) => vals[*a].log10(),
1763 TapeOp::Sin(a) => vals[*a].sin(),
1764 TapeOp::Cos(a) => vals[*a].cos(),
1765 }
1766}
1767
1768#[inline]
1769fn rev_step(op: &TapeOp, i: usize, vals: &[f64], adj: &mut [f64], a: f64, grad: &mut [f64]) {
1770 match op {
1771 TapeOp::Const(_) => {}
1772 TapeOp::Var(j) => {
1773 grad[*j] += a;
1774 }
1775 TapeOp::Add(l, r) => {
1776 adj[*l] += a;
1777 adj[*r] += a;
1778 }
1779 TapeOp::Sub(l, r) => {
1780 adj[*l] += a;
1781 adj[*r] -= a;
1782 }
1783 TapeOp::Mul(l, r) => {
1784 adj[*l] += a * vals[*r];
1785 adj[*r] += a * vals[*l];
1786 }
1787 TapeOp::Div(l, r) => {
1788 let rv = vals[*r];
1789 adj[*l] += a / rv;
1790 adj[*r] -= a * vals[*l] / (rv * rv);
1791 }
1792 TapeOp::Pow(l, r) => {
1793 let lv = vals[*l];
1794 let rv = vals[*r];
1795 if rv != 0.0 {
1796 adj[*l] += a * rv * lv.powf(rv - 1.0);
1797 }
1798 if lv > 0.0 {
1799 adj[*r] += a * vals[i] * lv.ln();
1800 }
1801 }
1802 TapeOp::Neg(j) => {
1803 adj[*j] -= a;
1804 }
1805 TapeOp::Abs(j) => {
1806 if vals[*j] >= 0.0 {
1807 adj[*j] += a;
1808 } else {
1809 adj[*j] -= a;
1810 }
1811 }
1812 TapeOp::Sqrt(j) => {
1813 let sv = vals[i];
1814 if sv > 0.0 {
1815 adj[*j] += a * 0.5 / sv;
1816 }
1817 }
1818 TapeOp::Exp(j) => {
1819 adj[*j] += a * vals[i];
1820 }
1821 TapeOp::Log(j) => {
1822 adj[*j] += a / vals[*j];
1823 }
1824 TapeOp::Log10(j) => {
1825 adj[*j] += a / (vals[*j] * std::f64::consts::LN_10);
1826 }
1827 TapeOp::Sin(j) => {
1828 adj[*j] += a * vals[*j].cos();
1829 }
1830 TapeOp::Cos(j) => {
1831 adj[*j] -= a * vals[*j].sin();
1832 }
1833 }
1834}
1835
1836#[inline]
1837fn fwd_tan_step(op: &TapeOp, seed_var: usize, vals: &[f64], dot: &[f64], i: usize) -> f64 {
1838 match op {
1839 TapeOp::Const(_) => 0.0,
1840 TapeOp::Var(k) => {
1841 if *k == seed_var {
1842 1.0
1843 } else {
1844 0.0
1845 }
1846 }
1847 TapeOp::Add(a, b) => dot[*a] + dot[*b],
1848 TapeOp::Sub(a, b) => dot[*a] - dot[*b],
1849 TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
1850 TapeOp::Div(a, b) => {
1851 let vb = vals[*b];
1852 (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
1853 }
1854 TapeOp::Pow(a, b) => {
1855 let u = vals[*a];
1856 let r = vals[*b];
1857 let du = dot[*a];
1858 let dr = dot[*b];
1859 let mut result = 0.0;
1860 if r != 0.0 && u != 0.0 {
1861 result += r * u.powf(r - 1.0) * du;
1862 }
1863 if u > 0.0 {
1864 result += vals[i] * u.ln() * dr;
1865 }
1866 result
1867 }
1868 TapeOp::Neg(a) => -dot[*a],
1869 TapeOp::Abs(a) => {
1870 if vals[*a] >= 0.0 {
1871 dot[*a]
1872 } else {
1873 -dot[*a]
1874 }
1875 }
1876 TapeOp::Sqrt(a) => {
1877 let sv = vals[i];
1878 if sv > 0.0 {
1879 dot[*a] * 0.5 / sv
1880 } else {
1881 0.0
1882 }
1883 }
1884 TapeOp::Exp(a) => dot[*a] * vals[i],
1885 TapeOp::Log(a) => dot[*a] / vals[*a],
1886 TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
1887 TapeOp::Sin(a) => dot[*a] * vals[*a].cos(),
1888 TapeOp::Cos(a) => -dot[*a] * vals[*a].sin(),
1889 }
1890}
1891
/// Second-order ("reverse-over-forward") adjoint step for tape slot `i`.
///
/// For each operand `x` of `op` with partial `p = ∂op/∂x`, this accumulates:
/// - `adj[x] += w * p`, where `w` is the first-order adjoint of slot `i`;
/// - `adj_dot[x] += wd * p + w * (dp/dt)`, where `wd` is the tangent of `w`.
///
/// `dp/dt` is taken along the forward direction whose tangents are in `dot`
/// (presumably produced by `fwd_tan_step` with the same `seed_var`; confirm at the
/// call site). `vals` are forward values, with `vals[i]` being this op's output.
///
/// At a `Var(k)` leaf nothing is written to `adj`. Instead `weight * wd` is added to
/// `values[hess_map[&(k, seed_var)]]`, but only for the lower triangle
/// (`k >= seed_var`), only when `wd != 0`, and only when the pair is present in
/// `hess_map`. `weight` only scales what is recorded into `values`; presumably it is
/// an objective/constraint multiplier (confirm against the caller).
#[allow(clippy::too_many_arguments)]
#[inline]
fn ror_step(
    op: &TapeOp,
    i: usize,
    seed_var: usize,
    vals: &[f64],
    dot: &[f64],
    adj: &mut [f64],
    adj_dot: &mut [f64],
    w: f64,
    wd: f64,
    weight: f64,
    hess_map: &HashMap<(usize, usize), usize>,
    values: &mut [f64],
) {
    match op {
        TapeOp::Const(_) => {}
        TapeOp::Var(k) => {
            // Leaf: record the Hessian entry (k, seed_var), lower triangle only.
            if wd != 0.0 && *k >= seed_var {
                if let Some(&pos) = hess_map.get(&(*k, seed_var)) {
                    values[pos] += weight * wd;
                }
            }
        }
        TapeOp::Add(a, b) => {
            adj[*a] += w;
            adj[*b] += w;
            adj_dot[*a] += wd;
            adj_dot[*b] += wd;
        }
        TapeOp::Sub(a, b) => {
            adj[*a] += w;
            adj[*b] -= w;
            adj_dot[*a] += wd;
            adj_dot[*b] -= wd;
        }
        TapeOp::Mul(a, b) => {
            // p_a = v_b, dp_a/dt = dot_b (and symmetrically for b).
            adj[*a] += w * vals[*b];
            adj[*b] += w * vals[*a];
            adj_dot[*a] += wd * vals[*b] + w * dot[*b];
            adj_dot[*b] += wd * vals[*a] + w * dot[*a];
        }
        TapeOp::Div(a, b) => {
            // f = a / b: p_a = 1/b, p_b = -a/b^2.
            let vb = vals[*b];
            let vb2 = vb * vb;
            let vb3 = vb2 * vb;
            adj[*a] += w / vb;
            adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
            adj[*b] += w * (-vals[*a] / vb2);
            // d(-a/b^2)/dt = -dot_a/b^2 + 2 a dot_b / b^3.
            adj_dot[*b] +=
                wd * (-vals[*a] / vb2) + w * (-dot[*a] / vb2 + 2.0 * vals[*a] * dot[*b] / vb3);
        }
        TapeOp::Pow(a, b) => {
            let u = vals[*a];
            let r = vals[*b];
            let du = dot[*a];
            let dr = dot[*b];
            // Base partial p_a = r * u^(r-1); skipped entirely when r == 0.
            if r != 0.0 {
                if u != 0.0 {
                    let p_a = r * u.powf(r - 1.0);
                    adj[*a] += w * p_a;
                    // dp_a/dt = dr * u^(r-1) + r * d(u^(r-1))/dt.
                    let mut dp_a = dr * u.powf(r - 1.0);
                    if u > 0.0 {
                        // u^(r-1) = exp((r-1) ln u), differentiated in both u and r.
                        dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
                    } else {
                        // Negative base: ln u is undefined, so only the u-dependence of
                        // u^(r-1) is differentiated (the dr * ln u term is dropped).
                        dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
                    }
                    adj_dot[*a] += wd * p_a + w * dp_a;
                } else if r >= 2.0 {
                    // Zero base with r >= 2: p_a = 0, and dp_a/dt reduces to
                    // r (r-1) 0^(r-2) du, which is 2 du for r == 2 and 0 for r > 2.
                    let p_a = 0.0;
                    adj[*a] += w * p_a;
                    let dp_a = if r == 2.0 {
                        2.0 * du
                    } else {
                        r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
                    };
                    adj_dot[*a] += wd * p_a + w * dp_a;
                }
                // NOTE(review): for u == 0 with 1 <= r < 2 nothing is propagated to
                // operand `a`. For r == 1 the true partial is 1 (`rev_step` adds it),
                // so the first-order adjoints here can disagree with the gradient in
                // that edge case. Confirm whether this is intended.
            }
            // Exponent partial p_b = u^r * ln u, only defined for u > 0.
            if u > 0.0 {
                let ln_u = u.ln();
                let p_b = vals[i] * ln_u;
                adj[*b] += w * p_b;
                // d(u^r)/dt = u^r * (r du/u + dr ln u).
                let dur = vals[i] * (r * du / u + dr * ln_u);
                let dp_b = dur * ln_u + vals[i] * du / u;
                adj_dot[*b] += wd * p_b + w * dp_b;
            }
        }
        TapeOp::Neg(a) => {
            adj[*a] -= w;
            adj_dot[*a] -= wd;
        }
        TapeOp::Abs(a) => {
            // Piecewise linear: slope +1 at 0, second derivative 0.
            let s = if vals[*a] >= 0.0 { 1.0 } else { -1.0 };
            adj[*a] += w * s;
            adj_dot[*a] += wd * s;
        }
        TapeOp::Sqrt(a) => {
            // f' = 1/(2 sqrt u), f'' = -1/(4 u sqrt u); skipped unless sqrt u > 0.
            let sv = vals[i];
            if sv > 0.0 {
                let fp = 0.5 / sv;
                let fpp = -0.25 / (vals[*a] * sv);
                adj[*a] += w * fp;
                adj_dot[*a] += wd * fp + w * fpp * dot[*a];
            }
        }
        TapeOp::Exp(a) => {
            // f' = f'' = exp(u), which is this op's own output vals[i].
            let ev = vals[i];
            adj[*a] += w * ev;
            adj_dot[*a] += wd * ev + w * ev * dot[*a];
        }
        TapeOp::Log(a) => {
            // f' = 1/u, f'' = -1/u^2.
            let u = vals[*a];
            adj[*a] += w / u;
            adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
        }
        TapeOp::Log10(a) => {
            // f' = 1/(u ln 10), f'' = -1/(u^2 ln 10).
            let u = vals[*a];
            let c = std::f64::consts::LN_10;
            adj[*a] += w / (u * c);
            adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
        }
        TapeOp::Sin(a) => {
            // f' = cos u, f'' = -sin u.
            let u = vals[*a];
            let cu = u.cos();
            adj[*a] += w * cu;
            adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
        }
        TapeOp::Cos(a) => {
            // f' = -sin u, f'' = -cos u.
            let u = vals[*a];
            let su = u.sin();
            adj[*a] -= w * su;
            adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
        }
    }
}
2029
2030fn hessian_sparsity_impl(ops: &[TapeOp]) -> BTreeSet<(usize, usize)> {
2034 let n = ops.len();
2035 let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(n);
2036 let mut pairs: BTreeSet<(usize, usize)> = BTreeSet::new();
2037
2038 let emit_cross =
2039 |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
2040 for &v1 in s1 {
2041 for &v2 in s2 {
2042 let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
2043 pairs.insert((r, c));
2044 }
2045 }
2046 };
2047 let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
2048 let vars: Vec<usize> = s.iter().copied().collect();
2049 for (ai, &vi) in vars.iter().enumerate() {
2050 for &vj in &vars[..=ai] {
2051 let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
2052 pairs.insert((r, c));
2053 }
2054 }
2055 };
2056
2057 for op in ops {
2058 let vset = match op {
2059 TapeOp::Const(_) => BTreeSet::new(),
2060 TapeOp::Var(j) => {
2061 let mut s = BTreeSet::new();
2062 s.insert(*j);
2063 s
2064 }
2065 TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
2066 var_sets[*a].union(&var_sets[*b]).copied().collect()
2067 }
2068 TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
2069 TapeOp::Mul(a, b) => {
2070 emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
2071 var_sets[*a].union(&var_sets[*b]).copied().collect()
2072 }
2073 TapeOp::Div(a, b) => {
2074 emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
2075 emit_self(&var_sets[*b], &mut pairs);
2076 var_sets[*a].union(&var_sets[*b]).copied().collect()
2077 }
2078 TapeOp::Pow(a, b) => {
2079 let combined: BTreeSet<usize> =
2080 var_sets[*a].union(&var_sets[*b]).copied().collect();
2081 emit_self(&combined, &mut pairs);
2082 combined
2083 }
2084 TapeOp::Sqrt(a)
2085 | TapeOp::Exp(a)
2086 | TapeOp::Log(a)
2087 | TapeOp::Log10(a)
2088 | TapeOp::Sin(a)
2089 | TapeOp::Cos(a) => {
2090 emit_self(&var_sets[*a], &mut pairs);
2091 var_sets[*a].clone()
2092 }
2093 };
2094 var_sets.push(vset);
2095 }
2096 pairs
2097}
2098
#[cfg(test)]
mod tests {
    use super::*;

    fn cnst(c: f64) -> Expr {
        Expr::Const(c)
    }
    fn var(i: usize) -> Expr {
        Expr::Var(i)
    }
    fn add(a: Expr, b: Expr) -> Expr {
        Expr::Binary(BinOp::Add, Box::new(a), Box::new(b))
    }
    fn mul(a: Expr, b: Expr) -> Expr {
        Expr::Binary(BinOp::Mul, Box::new(a), Box::new(b))
    }
    fn pow(a: Expr, b: Expr) -> Expr {
        Expr::Binary(BinOp::Pow, Box::new(a), Box::new(b))
    }
    fn div(a: Expr, b: Expr) -> Expr {
        Expr::Binary(BinOp::Div, Box::new(a), Box::new(b))
    }
    fn unary(op: UnaryOp, a: Expr) -> Expr {
        Expr::Unary(op, Box::new(a))
    }

    #[test]
    fn polynomial_eval_and_grad() {
        // f(x0, x1) = 3*x0^2 + 2*x1
        let e = add(
            mul(cnst(3.0), pow(var(0), cnst(2.0))),
            mul(cnst(2.0), var(1)),
        );
        let t = Tape::build(&e);
        assert!((t.eval(&[2.0, 3.0]) - 18.0).abs() < 1e-12);
        let mut g = vec![0.0; 2];
        t.gradient_seed(&[2.0, 3.0], 1.0, &mut g);
        // df/dx0 = 6*x0 = 12
        assert!((g[0] - 12.0).abs() < 1e-12);
        // df/dx1 = 2
        assert!((g[1] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn cse_shared_body_evaluated_once() {
        // (x0 + x1) is shared by both terms through `Expr::Cse`:
        // f = (x0 + x1)^2 + (x0 + x1)
        let body = Rc::new(add(var(0), var(1)));
        let e = add(
            pow(Expr::Cse(body.clone()), cnst(2.0)),
            Expr::Cse(body.clone()),
        );
        let t = Tape::build(&e);
        // Count `Add(Var(0), Var(1))` nodes on the tape.
        let n_body_adds = t
            .ops
            .iter()
            .filter(|op| {
                matches!(op, TapeOp::Add(a, b) if {
                    matches!(t.ops[*a], TapeOp::Var(0)) && matches!(t.ops[*b], TapeOp::Var(1))
                })
            })
            .count();
        assert_eq!(n_body_adds, 1, "CSE body should be emitted exactly once");

        // f(1, 2) = 3^2 + 3 = 12
        assert!((t.eval(&[1.0, 2.0]) - 12.0).abs() < 1e-12);
        let mut g = vec![0.0; 2];
        t.gradient_seed(&[1.0, 2.0], 1.0, &mut g);
        // df/dx0 = df/dx1 = 2*(x0 + x1) + 1 = 7
        assert!((g[0] - 7.0).abs() < 1e-12);
        assert!((g[1] - 7.0).abs() < 1e-12);
    }

    /// Dense lower-triangular Hessian pattern over `vars`.
    ///
    /// Returns a map from each `(row, col)` pair (`row >= col`) to its position, plus
    /// the pairs listed in position order. Shared by the FD and directional checks
    /// below so both index the accumulated Hessian identically.
    fn dense_lower_pattern(
        vars: &[usize],
    ) -> (HashMap<(usize, usize), usize>, Vec<(usize, usize)>) {
        let mut hess_map: HashMap<(usize, usize), usize> = HashMap::new();
        let mut pairs = Vec::new();
        for (ai, &vi) in vars.iter().enumerate() {
            for &vj in &vars[..=ai] {
                let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
                hess_map.entry((r, c)).or_insert_with(|| {
                    let p = pairs.len();
                    pairs.push((r, c));
                    p
                });
            }
        }
        (hess_map, pairs)
    }

    /// Compares the AD Hessian (`hessian_accumulate`) with central finite differences
    /// of the reverse-mode gradient, entry by entry over the lower triangle.
    /// `tol` is relative to `max(|FD|, 1)`.
    fn fd_check(tape: &Tape, x: &[f64], n: usize, tol: f64) {
        let vars = tape.variables();
        let (hess_map, pairs) = dense_lower_pattern(&vars);
        let nnz = pairs.len();
        let mut ad = vec![0.0; nnz];
        tape.hessian_accumulate(x, 1.0, &hess_map, &mut ad);

        let mut fd = vec![0.0; nnz];
        let mut xp = x.to_vec();
        let mut gp = vec![0.0; n];
        let mut gm = vec![0.0; n];
        for &j in &vars {
            // Step scaled to |x_j| so large coordinates are not under-perturbed.
            let h = (1e-7_f64).max(x[j].abs() * 1e-7);
            xp[j] = x[j] + h;
            gp.iter_mut().for_each(|v| *v = 0.0);
            tape.gradient_seed(&xp, 1.0, &mut gp);
            xp[j] = x[j] - h;
            gm.iter_mut().for_each(|v| *v = 0.0);
            tape.gradient_seed(&xp, 1.0, &mut gm);
            xp[j] = x[j];
            for &i in &vars {
                if i >= j {
                    if let Some(&pos) = hess_map.get(&(i, j)) {
                        fd[pos] = (gp[i] - gm[i]) / (2.0 * h);
                    }
                }
            }
        }
        for (k, &(r, c)) in pairs.iter().enumerate() {
            let scale = fd[k].abs().max(1.0);
            assert!(
                (ad[k] - fd[k]).abs() / scale < tol,
                "H[{},{}]: AD={:.6e} FD={:.6e}",
                r,
                c,
                ad[k],
                fd[k]
            );
        }
    }

    #[test]
    fn hessian_quadratic_matches_fd() {
        // 3*x0^2 + 2*x0*x1 + x1^2
        let e = add(
            add(
                mul(cnst(3.0), pow(var(0), cnst(2.0))),
                mul(cnst(2.0), mul(var(0), var(1))),
            ),
            pow(var(1), cnst(2.0)),
        );
        let t = Tape::build(&e);
        fd_check(&t, &[2.0, 3.0], 2, 1e-5);
    }

    #[test]
    fn hessian_transcendental_matches_fd() {
        // exp(x0) + sin(x1) + ln(x0) + sqrt(x1) + x0*x1
        let e = Expr::Sum(vec![
            unary(UnaryOp::Exp, var(0)),
            unary(UnaryOp::Sin, var(1)),
            unary(UnaryOp::Log, var(0)),
            unary(UnaryOp::Sqrt, var(1)),
            mul(var(0), var(1)),
        ]);
        let t = Tape::build(&e);
        fd_check(&t, &[1.5, 2.0], 2, 1e-5);
    }

    #[test]
    fn hessian_division_matches_fd() {
        // x0 / x1 + cos(x0)
        let e = add(div(var(0), var(1)), unary(UnaryOp::Cos, var(0)));
        let t = Tape::build(&e);
        fd_check(&t, &[0.5, 1.2], 2, 1e-5);
    }

    /// Checks that each Hessian column from `hessian_directional`, seeded with a unit
    /// vector, agrees with the matching entries of `hessian_accumulate`.
    fn directional_matches_accumulate(tape: &Tape, x: &[f64], n: usize) {
        let vars = tape.variables();
        let (hess_map, pairs) = dense_lower_pattern(&vars);
        let nnz = pairs.len();
        let mut ad = vec![0.0; nnz];
        tape.hessian_accumulate(x, 1.0, &hess_map, &mut ad);

        let nops = tape.ops.len();
        let mut vals = vec![0.0; nops];
        tape.forward_into(x, &mut vals);
        let mut dot = vec![0.0; nops];
        let mut adj = vec![0.0; nops];
        let mut adj_dot = vec![0.0; nops];

        for &j in &vars {
            let mut seed = vec![0.0; n];
            seed[j] = 1.0;
            let mut col = vec![0.0; n];
            tape.hessian_directional(
                &vals,
                &seed,
                1.0,
                &mut col,
                &mut dot,
                &mut adj,
                &mut adj_dot,
            );
            for &i in &vars {
                let (r, c) = if i >= j { (i, j) } else { (j, i) };
                let expect = ad[hess_map[&(r, c)]];
                assert!(
                    (col[i] - expect).abs() < 1e-10,
                    "directional H[{i},{j}] = {} vs accumulate {}",
                    col[i],
                    expect
                );
            }
        }
    }

    #[test]
    fn directional_quadratic_matches_accumulate() {
        // 3*x0^2 + (2*x0)*x1 + x1^2
        let e = add(
            add(
                mul(cnst(3.0), pow(var(0), cnst(2.0))),
                mul(mul(cnst(2.0), var(0)), var(1)),
            ),
            pow(var(1), cnst(2.0)),
        );
        let t = Tape::build(&e);
        directional_matches_accumulate(&t, &[0.5, -0.3], 2);
    }

    #[test]
    fn directional_transcendental_matches_accumulate() {
        let e = Expr::Sum(vec![
            unary(UnaryOp::Exp, var(0)),
            unary(UnaryOp::Sin, var(1)),
            unary(UnaryOp::Log, var(0)),
            unary(UnaryOp::Sqrt, var(1)),
            mul(var(0), var(1)),
        ]);
        let t = Tape::build(&e);
        directional_matches_accumulate(&t, &[1.5, 2.0], 2);
    }

    #[test]
    fn directional_with_division_matches_accumulate() {
        let e = add(div(var(0), var(1)), unary(UnaryOp::Cos, var(0)));
        let t = Tape::build(&e);
        directional_matches_accumulate(&t, &[0.5, 1.2], 2);
    }

    #[test]
    fn hessian_sparsity_separable() {
        // sin(x0) + x1*x2: x0 couples only with itself; x1 and x2 couple with each other.
        let e = add(unary(UnaryOp::Sin, var(0)), mul(var(1), var(2)));
        let t = Tape::build(&e);
        let s = t.hessian_sparsity();
        assert!(s.contains(&(0, 0)));
        assert!(s.contains(&(2, 1)));
        assert!(!s.contains(&(1, 0)));
        assert!(!s.contains(&(2, 0)));
    }

    /// Number of tape ops satisfying `pred`.
    fn count_op<F: Fn(&TapeOp) -> bool>(t: &Tape, pred: F) -> usize {
        t.ops.iter().filter(|o| pred(o)).count()
    }

    #[test]
    fn pow_zero_const_folds_to_one() {
        // x^0 folds to a constant; the variable is not even emitted.
        let e = pow(var(0), cnst(0.0));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Var(_))), 0);
        assert!((t.eval(&[7.0]) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn pow_one_passes_through() {
        // x^1 is just x: no Pow and no exponent constant.
        let e = pow(var(0), cnst(1.0));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Const(_))), 0);
        assert!((t.eval(&[3.5]) - 3.5).abs() < 1e-12);
    }

    #[test]
    fn pow_half_lowers_to_sqrt() {
        let e = pow(var(0), cnst(0.5));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Sqrt(_))), 1);
        assert!((t.eval(&[16.0]) - 4.0).abs() < 1e-12);
    }

    #[test]
    fn pow_two_lowers_to_single_mul() {
        let e = pow(var(0), cnst(2.0));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 1);
        assert!((t.eval(&[3.0]) - 9.0).abs() < 1e-12);
    }

    #[test]
    fn pow_three_lowers_to_two_muls() {
        let e = pow(var(0), cnst(3.0));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 2);
        assert!((t.eval(&[2.0]) - 8.0).abs() < 1e-12);
    }

    #[test]
    fn pow_eight_lowers_to_three_muls() {
        // x^8 = ((x^2)^2)^2: three squarings.
        let e = pow(var(0), cnst(8.0));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 3);
        assert!((t.eval(&[2.0]) - 256.0).abs() < 1e-12);
    }

    #[test]
    fn pow_negative_two_lowers_to_div() {
        // x^-2 = 1 / (x*x)
        let e = pow(var(0), cnst(-2.0));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Div(..))), 1);
        assert!((t.eval(&[4.0]) - (1.0 / 16.0)).abs() < 1e-12);
    }

    #[test]
    fn pow_large_const_stays_generic() {
        // Exponents above 8 are not unrolled.
        let e = pow(var(0), cnst(9.0));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 1);
    }

    #[test]
    fn pow_non_integer_const_stays_generic() {
        let e = pow(var(0), cnst(1.5));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 1);
    }

    #[test]
    fn pow_const_through_cse_const() {
        // A constant exponent hidden behind `Expr::Cse` is still recognized.
        let two = Rc::new(cnst(2.0));
        let e = Expr::Binary(BinOp::Pow, Box::new(var(0)), Box::new(Expr::Cse(two)));
        let t = Tape::build(&e);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
        assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 1);
    }

    #[test]
    fn hessian_pow_three_matches_fd() {
        // 5*x0^3 + x0*x1
        let e = add(mul(cnst(5.0), pow(var(0), cnst(3.0))), mul(var(0), var(1)));
        let t = Tape::build(&e);
        fd_check(&t, &[1.7, 0.8], 2, 1e-5);
    }

    #[test]
    fn hessian_pow_negative_matches_fd() {
        // x0^-2 + x1^2
        let e = add(pow(var(0), cnst(-2.0)), pow(var(1), cnst(2.0)));
        let t = Tape::build(&e);
        fd_check(&t, &[1.3, 2.4], 2, 1e-5);
    }

    #[test]
    fn hessian_pow_half_matches_fd() {
        // sqrt(x0) + x0*x1
        let e = add(pow(var(0), cnst(0.5)), mul(var(0), var(1)));
        let t = Tape::build(&e);
        fd_check(&t, &[2.5, 1.1], 2, 1e-5);
    }

    #[test]
    fn hessian_sparsity_through_cse() {
        // (x0 + x1)^2 + (x0 + x1) with a shared body: the square couples x0 and x1,
        // giving exactly the three lower-triangular entries.
        let body = Rc::new(add(var(0), var(1)));
        let e = add(
            pow(Expr::Cse(body.clone()), cnst(2.0)),
            Expr::Cse(body.clone()),
        );
        let t = Tape::build(&e);
        let s = t.hessian_sparsity();
        assert!(s.contains(&(0, 0)));
        assert!(s.contains(&(1, 0)));
        assert!(s.contains(&(1, 1)));
        assert_eq!(s.len(), 3);
    }
}