1use std::collections::HashMap;
42
43use super::nl_tape::{Tape, TapeOp};
44
/// One instruction of a compiled [`HessianProgram`].
///
/// Every `u32` field named `dst`, `a`, `b`, `va`, `vb`, `vd`, `dot_*`,
/// `adj_*`, `w` or `wd` is an index into the `scratch` buffer passed to
/// `HessianProgram::execute`. The exceptions are `x_idx` (index into `x`),
/// `c_idx` (index into the program's constant pool) and `hess_ptr` (index
/// into the output `values`).
///
/// Naming used by `HessianProgram::compile` when it fills these fields:
/// `v*` are primal value slots, `dot_*` are tangent slots, `adj_*` / `w` are
/// adjoint slots and `adj_dot_*` / `wd` are adjoint-tangent slots.
#[derive(Debug, Clone, Copy)]
pub enum HOp {
    /// `scratch[dst] = x[x_idx]`.
    FwdLoadVar {
        dst: u32,
        x_idx: u32,
    },
    /// `scratch[dst] = consts[c_idx]`.
    FwdLoadConst {
        dst: u32,
        c_idx: u32,
    },
    /// `scratch[dst] = scratch[a] + scratch[b]`.
    FwdAdd {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `scratch[dst] = scratch[a] - scratch[b]`.
    FwdSub {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `scratch[dst] = scratch[a] * scratch[b]`.
    FwdMul {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `scratch[dst] = scratch[a] / scratch[b]`.
    FwdDiv {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `scratch[dst] = scratch[a].powf(scratch[b])`.
    FwdPow {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `scratch[dst] = -scratch[a]`.
    FwdNeg {
        dst: u32,
        a: u32,
    },
    /// `scratch[dst] = scratch[a].abs()`.
    FwdAbs {
        dst: u32,
        a: u32,
    },
    /// `scratch[dst] = scratch[a].sqrt()`.
    FwdSqrt {
        dst: u32,
        a: u32,
    },
    /// `scratch[dst] = scratch[a].exp()`.
    FwdExp {
        dst: u32,
        a: u32,
    },
    /// `scratch[dst] = scratch[a].ln()` (natural logarithm).
    FwdLog {
        dst: u32,
        a: u32,
    },
    /// `scratch[dst] = scratch[a].log10()`.
    FwdLog10 {
        dst: u32,
        a: u32,
    },
    /// `scratch[dst] = scratch[a].sin()`.
    FwdSin {
        dst: u32,
        a: u32,
    },
    /// `scratch[dst] = scratch[a].cos()`.
    FwdCos {
        dst: u32,
        a: u32,
    },

    /// `scratch[dst] = 0.0`.
    SetZero {
        dst: u32,
    },
    /// `scratch[dst] = 1.0`. Used by `compile` to seed a tangent or the
    /// output adjoint.
    SetOne {
        dst: u32,
    },

    /// Fills `scratch[start .. start + len]` with `0.0`.
    ZeroRange {
        start: u32,
        len: u32,
    },

    /// Tangent of a sum: `scratch[dst] = scratch[a] + scratch[b]`, where
    /// `a` and `b` are the operands' tangent slots.
    DotAdd {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// Tangent of a difference: `scratch[dst] = scratch[a] - scratch[b]`,
    /// where `a` and `b` are the operands' tangent slots.
    DotSub {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// Tangent of a product: `dot_a * vb + va * dot_b`.
    DotMul {
        dst: u32,
        dot_a: u32,
        vb: u32,
        va: u32,
        dot_b: u32,
    },
    /// Tangent of a quotient: `(dot_a * vb - va * dot_b) / vb^2`.
    DotDiv {
        dst: u32,
        dot_a: u32,
        vb: u32,
        va: u32,
        dot_b: u32,
    },
    /// Tangent of a square root: `dot_a * 0.5 / vd` when `vd > 0`,
    /// otherwise `0.0`. `vd` is the slot of the square root's own value.
    DotSqrt {
        dst: u32,
        dot_a: u32,
        vd: u32,
    },
    /// Tangent of an exponential: `dot_a * vd`, where `vd` is the slot of
    /// the exponential's own value.
    DotExp {
        dst: u32,
        dot_a: u32,
        vd: u32,
    },
    /// Tangent of a natural logarithm: `dot_a / va`.
    DotLog {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// Tangent of a base-10 logarithm: `dot_a / (va * ln 10)`.
    DotLog10 {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// Tangent of a sine: `dot_a * cos(va)`.
    DotSin {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// Tangent of a cosine: `-dot_a * sin(va)`.
    DotCos {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// Tangent of a negation: `-dot_a`.
    DotNeg {
        dst: u32,
        dot_a: u32,
    },
    /// Tangent of an absolute value: `dot_a` when `va >= 0`, else `-dot_a`.
    DotAbs {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// Tangent of `va ^ vb`. `vd` is the slot of the power's own value.
    /// The base term is added only when both `va` and `vb` are non-zero and
    /// the exponent term only when `va > 0`.
    DotPow {
        dst: u32,
        va: u32,
        vb: u32,
        vd: u32,
        dot_a: u32,
        dot_b: u32,
    },

    /// Reverse step for a sum: adds `w` into both operand adjoints and `wd`
    /// into both operand adjoint tangents.
    RevAdd {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
    },
    /// Reverse step for a difference: like `RevAdd`, but the `b` operand
    /// receives `-w` / `-wd`.
    RevSub {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
    },
    /// Reverse step for a product. Needs both operand values and both
    /// operand tangents to form the adjoint-tangent update.
    RevMul {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
        va: u32,
        vb: u32,
        dot_a: u32,
        dot_b: u32,
    },
    /// Reverse step for a quotient `va / vb`.
    RevDiv {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
        va: u32,
        vb: u32,
        dot_a: u32,
        dot_b: u32,
    },
    /// Reverse step for a power `va ^ vb`; `vd` is the slot of the power's
    /// own value.
    RevPow {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
        va: u32,
        vb: u32,
        vd: u32,
        dot_a: u32,
        dot_b: u32,
    },
    /// Reverse step for a negation: subtracts `w` / `wd` from the operand's
    /// adjoint / adjoint tangent.
    RevNeg {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
    },
    /// Reverse step for an absolute value: scales `w` / `wd` by `+1` when
    /// `va >= 0` and by `-1` otherwise.
    RevAbs {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
    },
    /// Reverse step for a square root; `vd` is the slot of the square
    /// root's own value. `va` is carried but not read by `execute`.
    RevSqrt {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        vd: u32,
        dot_a: u32,
    },
    /// Reverse step for an exponential; `vd` is the slot of the
    /// exponential's own value.
    RevExp {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        vd: u32,
        dot_a: u32,
    },
    /// Reverse step for a natural logarithm of `va`.
    RevLog {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },
    /// Reverse step for a base-10 logarithm of `va`.
    RevLog10 {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },
    /// Reverse step for a sine of `va`.
    RevSin {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },
    /// Reverse step for a cosine of `va`.
    RevCos {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },

    /// `values[hess_ptr] += weight * scratch[adj_dot_slot]`: writes one
    /// accumulated second-derivative entry to the output.
    HessEmit {
        hess_ptr: u32,
        adj_dot_slot: u32,
    },
}
337
/// A tape compiled into a flat list of [`HOp`] instructions that
/// `execute` interprets over a caller-supplied scratch buffer.
#[derive(Debug, Clone)]
pub struct HessianProgram {
    /// Instructions, in the order `execute` runs them.
    ops: Vec<HOp>,
    /// Constant pool indexed by the `c_idx` of `HOp::FwdLoadConst`.
    consts: Vec<f64>,
    /// Number of scratch slots the program addresses; `compile` sets this
    /// to four times the number of tape ops.
    n_slots: u32,
}
346
347impl HessianProgram {
348 pub fn compile(tape: &Tape, hess_map: &HashMap<(usize, usize), usize>) -> Self {
352 let n = tape.ops.len() as u32;
353 let v_base = 0u32;
354 let dot_base = n;
355 let adj_base = 2 * n;
356 let adj_dot_base = 3 * n;
357 let n_slots = 4 * n;
358
359 let v_slot = |i: u32| v_base + i;
360 let dot_slot = |i: u32| dot_base + i;
361 let adj_slot = |i: u32| adj_base + i;
362 let adj_dot_slot = |i: u32| adj_dot_base + i;
363
364 let reachable = reachable_to_output(tape);
365 let var_indices = tape.variables();
366 let depends_on: Vec<Vec<bool>> = (0..var_indices.len())
368 .map(|k_idx| depends_on_var(tape, var_indices[k_idx]))
369 .collect();
370
371 let mut consts: Vec<f64> = Vec::new();
372 let mut const_intern: HashMap<u64, u32> = HashMap::new();
373 let mut intern_const = |c: f64, consts: &mut Vec<f64>| -> u32 {
374 let bits = c.to_bits();
375 if let Some(&idx) = const_intern.get(&bits) {
376 return idx;
377 }
378 let idx = consts.len() as u32;
379 consts.push(c);
380 const_intern.insert(bits, idx);
381 idx
382 };
383
384 let mut ops: Vec<HOp> = Vec::new();
385
386 for (i, tape_op) in tape.ops.iter().enumerate() {
388 let i = i as u32;
389 let dst = v_slot(i);
390 let op = match *tape_op {
391 TapeOp::Const(c) => HOp::FwdLoadConst {
392 dst,
393 c_idx: intern_const(c, &mut consts),
394 },
395 TapeOp::Var(x_idx) => HOp::FwdLoadVar {
396 dst,
397 x_idx: x_idx as u32,
398 },
399 TapeOp::Add(a, b) => HOp::FwdAdd {
400 dst,
401 a: v_slot(a as u32),
402 b: v_slot(b as u32),
403 },
404 TapeOp::Sub(a, b) => HOp::FwdSub {
405 dst,
406 a: v_slot(a as u32),
407 b: v_slot(b as u32),
408 },
409 TapeOp::Mul(a, b) => HOp::FwdMul {
410 dst,
411 a: v_slot(a as u32),
412 b: v_slot(b as u32),
413 },
414 TapeOp::Div(a, b) => HOp::FwdDiv {
415 dst,
416 a: v_slot(a as u32),
417 b: v_slot(b as u32),
418 },
419 TapeOp::Pow(a, b) => HOp::FwdPow {
420 dst,
421 a: v_slot(a as u32),
422 b: v_slot(b as u32),
423 },
424 TapeOp::Neg(a) => HOp::FwdNeg {
425 dst,
426 a: v_slot(a as u32),
427 },
428 TapeOp::Abs(a) => HOp::FwdAbs {
429 dst,
430 a: v_slot(a as u32),
431 },
432 TapeOp::Sqrt(a) => HOp::FwdSqrt {
433 dst,
434 a: v_slot(a as u32),
435 },
436 TapeOp::Exp(a) => HOp::FwdExp {
437 dst,
438 a: v_slot(a as u32),
439 },
440 TapeOp::Log(a) => HOp::FwdLog {
441 dst,
442 a: v_slot(a as u32),
443 },
444 TapeOp::Log10(a) => HOp::FwdLog10 {
445 dst,
446 a: v_slot(a as u32),
447 },
448 TapeOp::Sin(a) => HOp::FwdSin {
449 dst,
450 a: v_slot(a as u32),
451 },
452 TapeOp::Cos(a) => HOp::FwdCos {
453 dst,
454 a: v_slot(a as u32),
455 },
456 };
457 ops.push(op);
458 }
459
460 if n == 0 || var_indices.is_empty() {
461 return HessianProgram {
462 ops,
463 consts,
464 n_slots,
465 };
466 }
467
468 for (k_idx, &j) in var_indices.iter().enumerate() {
470 ops.push(HOp::ZeroRange {
472 start: dot_base,
473 len: 3 * n,
474 });
475 ops.push(HOp::SetOne {
476 dst: adj_slot(n - 1),
477 });
478
479 for (i, tape_op) in tape.ops.iter().enumerate() {
483 let i_u = i as u32;
484 if !depends_on[k_idx][i] {
485 continue;
486 }
487 let dst = dot_slot(i_u);
488 let dot_op = match *tape_op {
489 TapeOp::Const(_) => continue,
492 TapeOp::Var(_) => HOp::SetOne { dst },
496 TapeOp::Add(a, b) => HOp::DotAdd {
497 dst,
498 a: dot_slot(a as u32),
499 b: dot_slot(b as u32),
500 },
501 TapeOp::Sub(a, b) => HOp::DotSub {
502 dst,
503 a: dot_slot(a as u32),
504 b: dot_slot(b as u32),
505 },
506 TapeOp::Mul(a, b) => HOp::DotMul {
507 dst,
508 dot_a: dot_slot(a as u32),
509 vb: v_slot(b as u32),
510 va: v_slot(a as u32),
511 dot_b: dot_slot(b as u32),
512 },
513 TapeOp::Div(a, b) => HOp::DotDiv {
514 dst,
515 dot_a: dot_slot(a as u32),
516 vb: v_slot(b as u32),
517 va: v_slot(a as u32),
518 dot_b: dot_slot(b as u32),
519 },
520 TapeOp::Pow(a, b) => HOp::DotPow {
521 dst,
522 va: v_slot(a as u32),
523 vb: v_slot(b as u32),
524 vd: v_slot(i_u),
525 dot_a: dot_slot(a as u32),
526 dot_b: dot_slot(b as u32),
527 },
528 TapeOp::Neg(a) => HOp::DotNeg {
529 dst,
530 dot_a: dot_slot(a as u32),
531 },
532 TapeOp::Abs(a) => HOp::DotAbs {
533 dst,
534 dot_a: dot_slot(a as u32),
535 va: v_slot(a as u32),
536 },
537 TapeOp::Sqrt(a) => HOp::DotSqrt {
538 dst,
539 dot_a: dot_slot(a as u32),
540 vd: v_slot(i_u),
541 },
542 TapeOp::Exp(a) => HOp::DotExp {
543 dst,
544 dot_a: dot_slot(a as u32),
545 vd: v_slot(i_u),
546 },
547 TapeOp::Log(a) => HOp::DotLog {
548 dst,
549 dot_a: dot_slot(a as u32),
550 va: v_slot(a as u32),
551 },
552 TapeOp::Log10(a) => HOp::DotLog10 {
553 dst,
554 dot_a: dot_slot(a as u32),
555 va: v_slot(a as u32),
556 },
557 TapeOp::Sin(a) => HOp::DotSin {
558 dst,
559 dot_a: dot_slot(a as u32),
560 va: v_slot(a as u32),
561 },
562 TapeOp::Cos(a) => HOp::DotCos {
563 dst,
564 dot_a: dot_slot(a as u32),
565 va: v_slot(a as u32),
566 },
567 };
568 ops.push(dot_op);
569 }
570
571 for i in (0..n as usize).rev() {
574 if !reachable[i] {
575 continue;
576 }
577 let i_u = i as u32;
578 let w = adj_slot(i_u);
579 let wd = adj_dot_slot(i_u);
580 let tape_op = &tape.ops[i];
581 let rev_op = match *tape_op {
582 TapeOp::Const(_) => continue,
583 TapeOp::Var(k) => {
584 if k >= j {
588 if let Some(&ptr) = hess_map.get(&(k, j)) {
589 ops.push(HOp::HessEmit {
590 hess_ptr: ptr as u32,
591 adj_dot_slot: wd,
592 });
593 }
594 }
595 continue;
596 }
597 TapeOp::Add(a, b) => HOp::RevAdd {
598 adj_a: adj_slot(a as u32),
599 adj_b: adj_slot(b as u32),
600 adj_dot_a: adj_dot_slot(a as u32),
601 adj_dot_b: adj_dot_slot(b as u32),
602 w,
603 wd,
604 },
605 TapeOp::Sub(a, b) => HOp::RevSub {
606 adj_a: adj_slot(a as u32),
607 adj_b: adj_slot(b as u32),
608 adj_dot_a: adj_dot_slot(a as u32),
609 adj_dot_b: adj_dot_slot(b as u32),
610 w,
611 wd,
612 },
613 TapeOp::Mul(a, b) => HOp::RevMul {
614 adj_a: adj_slot(a as u32),
615 adj_b: adj_slot(b as u32),
616 adj_dot_a: adj_dot_slot(a as u32),
617 adj_dot_b: adj_dot_slot(b as u32),
618 w,
619 wd,
620 va: v_slot(a as u32),
621 vb: v_slot(b as u32),
622 dot_a: dot_slot(a as u32),
623 dot_b: dot_slot(b as u32),
624 },
625 TapeOp::Div(a, b) => HOp::RevDiv {
626 adj_a: adj_slot(a as u32),
627 adj_b: adj_slot(b as u32),
628 adj_dot_a: adj_dot_slot(a as u32),
629 adj_dot_b: adj_dot_slot(b as u32),
630 w,
631 wd,
632 va: v_slot(a as u32),
633 vb: v_slot(b as u32),
634 dot_a: dot_slot(a as u32),
635 dot_b: dot_slot(b as u32),
636 },
637 TapeOp::Pow(a, b) => HOp::RevPow {
638 adj_a: adj_slot(a as u32),
639 adj_b: adj_slot(b as u32),
640 adj_dot_a: adj_dot_slot(a as u32),
641 adj_dot_b: adj_dot_slot(b as u32),
642 w,
643 wd,
644 va: v_slot(a as u32),
645 vb: v_slot(b as u32),
646 vd: v_slot(i_u),
647 dot_a: dot_slot(a as u32),
648 dot_b: dot_slot(b as u32),
649 },
650 TapeOp::Neg(a) => HOp::RevNeg {
651 adj_a: adj_slot(a as u32),
652 adj_dot_a: adj_dot_slot(a as u32),
653 w,
654 wd,
655 },
656 TapeOp::Abs(a) => HOp::RevAbs {
657 adj_a: adj_slot(a as u32),
658 adj_dot_a: adj_dot_slot(a as u32),
659 w,
660 wd,
661 va: v_slot(a as u32),
662 },
663 TapeOp::Sqrt(a) => HOp::RevSqrt {
664 adj_a: adj_slot(a as u32),
665 adj_dot_a: adj_dot_slot(a as u32),
666 w,
667 wd,
668 va: v_slot(a as u32),
669 vd: v_slot(i_u),
670 dot_a: dot_slot(a as u32),
671 },
672 TapeOp::Exp(a) => HOp::RevExp {
673 adj_a: adj_slot(a as u32),
674 adj_dot_a: adj_dot_slot(a as u32),
675 w,
676 wd,
677 vd: v_slot(i_u),
678 dot_a: dot_slot(a as u32),
679 },
680 TapeOp::Log(a) => HOp::RevLog {
681 adj_a: adj_slot(a as u32),
682 adj_dot_a: adj_dot_slot(a as u32),
683 w,
684 wd,
685 va: v_slot(a as u32),
686 dot_a: dot_slot(a as u32),
687 },
688 TapeOp::Log10(a) => HOp::RevLog10 {
689 adj_a: adj_slot(a as u32),
690 adj_dot_a: adj_dot_slot(a as u32),
691 w,
692 wd,
693 va: v_slot(a as u32),
694 dot_a: dot_slot(a as u32),
695 },
696 TapeOp::Sin(a) => HOp::RevSin {
697 adj_a: adj_slot(a as u32),
698 adj_dot_a: adj_dot_slot(a as u32),
699 w,
700 wd,
701 va: v_slot(a as u32),
702 dot_a: dot_slot(a as u32),
703 },
704 TapeOp::Cos(a) => HOp::RevCos {
705 adj_a: adj_slot(a as u32),
706 adj_dot_a: adj_dot_slot(a as u32),
707 w,
708 wd,
709 va: v_slot(a as u32),
710 dot_a: dot_slot(a as u32),
711 },
712 };
713 ops.push(rev_op);
714 }
715 }
716
717 HessianProgram {
718 ops,
719 consts,
720 n_slots,
721 }
722 }
723
724 pub fn n_slots(&self) -> usize {
725 self.n_slots as usize
726 }
727
728 pub fn n_ops(&self) -> usize {
729 self.ops.len()
730 }
731
    /// Interprets the compiled program at the point `x`, adding
    /// `weight`-scaled second-derivative entries into `values`.
    ///
    /// # Parameters
    ///
    /// * `x` - variable values; read by `FwdLoadVar` at index `x_idx`.
    /// * `weight` - factor applied to every entry written by `HessEmit`.
    ///   When it is exactly `0.0` the call returns immediately and neither
    ///   `scratch` nor `values` is touched.
    /// * `scratch` - working storage, at least `n_slots()` elements long
    ///   (checked by `debug_assert!` only).
    /// * `values` - output; `HessEmit` accumulates with `+=`, so existing
    ///   contents are kept and added to.
    ///
    /// # Panics
    ///
    /// Slice indexing panics if an instruction refers to an index outside
    /// `x`, `scratch`, `values` or the constant pool.
    pub fn execute(&self, x: &[f64], weight: f64, scratch: &mut [f64], values: &mut [f64]) {
        debug_assert!(scratch.len() >= self.n_slots as usize);
        // Nothing to run, or every emitted entry would be scaled by zero.
        if self.ops.is_empty() || weight == 0.0 {
            return;
        }
        let consts = &self.consts[..];
        for &op in &self.ops {
            match op {
                // ---- forward value instructions ----
                HOp::FwdLoadVar { dst, x_idx } => {
                    scratch[dst as usize] = x[x_idx as usize];
                }
                HOp::FwdLoadConst { dst, c_idx } => {
                    scratch[dst as usize] = consts[c_idx as usize];
                }
                HOp::FwdAdd { dst, a, b } => {
                    scratch[dst as usize] = scratch[a as usize] + scratch[b as usize];
                }
                HOp::FwdSub { dst, a, b } => {
                    scratch[dst as usize] = scratch[a as usize] - scratch[b as usize];
                }
                HOp::FwdMul { dst, a, b } => {
                    scratch[dst as usize] = scratch[a as usize] * scratch[b as usize];
                }
                HOp::FwdDiv { dst, a, b } => {
                    scratch[dst as usize] = scratch[a as usize] / scratch[b as usize];
                }
                HOp::FwdPow { dst, a, b } => {
                    scratch[dst as usize] = scratch[a as usize].powf(scratch[b as usize]);
                }
                HOp::FwdNeg { dst, a } => {
                    scratch[dst as usize] = -scratch[a as usize];
                }
                HOp::FwdAbs { dst, a } => {
                    scratch[dst as usize] = scratch[a as usize].abs();
                }
                HOp::FwdSqrt { dst, a } => {
                    scratch[dst as usize] = scratch[a as usize].sqrt();
                }
                HOp::FwdExp { dst, a } => {
                    scratch[dst as usize] = scratch[a as usize].exp();
                }
                HOp::FwdLog { dst, a } => {
                    scratch[dst as usize] = scratch[a as usize].ln();
                }
                HOp::FwdLog10 { dst, a } => {
                    scratch[dst as usize] = scratch[a as usize].log10();
                }
                HOp::FwdSin { dst, a } => {
                    scratch[dst as usize] = scratch[a as usize].sin();
                }
                HOp::FwdCos { dst, a } => {
                    scratch[dst as usize] = scratch[a as usize].cos();
                }

                // ---- slot initialisation ----
                HOp::SetZero { dst } => {
                    scratch[dst as usize] = 0.0;
                }
                HOp::SetOne { dst } => {
                    scratch[dst as usize] = 1.0;
                }
                HOp::ZeroRange { start, len } => {
                    let s = start as usize;
                    let e = s + len as usize;
                    scratch[s..e].fill(0.0);
                }

                // ---- forward tangent instructions ----
                HOp::DotAdd { dst, a, b } => {
                    scratch[dst as usize] = scratch[a as usize] + scratch[b as usize];
                }
                HOp::DotSub { dst, a, b } => {
                    scratch[dst as usize] = scratch[a as usize] - scratch[b as usize];
                }
                HOp::DotMul {
                    dst,
                    dot_a,
                    vb,
                    va,
                    dot_b,
                } => {
                    // Product rule.
                    scratch[dst as usize] = scratch[dot_a as usize] * scratch[vb as usize]
                        + scratch[va as usize] * scratch[dot_b as usize];
                }
                HOp::DotDiv {
                    dst,
                    dot_a,
                    vb,
                    va,
                    dot_b,
                } => {
                    // Quotient rule.
                    let v_b = scratch[vb as usize];
                    scratch[dst as usize] = (scratch[dot_a as usize] * v_b
                        - scratch[va as usize] * scratch[dot_b as usize])
                        / (v_b * v_b);
                }
                HOp::DotSqrt { dst, dot_a, vd } => {
                    // `vd` holds sqrt(a); a non-positive value yields a zero
                    // tangent instead of a division by zero or NaN.
                    let svd = scratch[vd as usize];
                    scratch[dst as usize] = if svd > 0.0 {
                        scratch[dot_a as usize] * 0.5 / svd
                    } else {
                        0.0
                    };
                }
                HOp::DotExp { dst, dot_a, vd } => {
                    scratch[dst as usize] = scratch[dot_a as usize] * scratch[vd as usize];
                }
                HOp::DotLog { dst, dot_a, va } => {
                    scratch[dst as usize] = scratch[dot_a as usize] / scratch[va as usize];
                }
                HOp::DotLog10 { dst, dot_a, va } => {
                    scratch[dst as usize] =
                        scratch[dot_a as usize] / (scratch[va as usize] * std::f64::consts::LN_10);
                }
                HOp::DotSin { dst, dot_a, va } => {
                    scratch[dst as usize] = scratch[dot_a as usize] * scratch[va as usize].cos();
                }
                HOp::DotCos { dst, dot_a, va } => {
                    scratch[dst as usize] = -scratch[dot_a as usize] * scratch[va as usize].sin();
                }
                HOp::DotNeg { dst, dot_a } => {
                    scratch[dst as usize] = -scratch[dot_a as usize];
                }
                HOp::DotAbs { dst, dot_a, va } => {
                    // The sign at exactly zero is taken as +1.
                    scratch[dst as usize] = if scratch[va as usize] >= 0.0 {
                        scratch[dot_a as usize]
                    } else {
                        -scratch[dot_a as usize]
                    };
                }
                HOp::DotPow {
                    dst,
                    va,
                    vb,
                    vd,
                    dot_a,
                    dot_b,
                } => {
                    // d(u^r) = r * u^(r-1) * du + u^r * ln(u) * dr
                    let u = scratch[va as usize];
                    let r = scratch[vb as usize];
                    let du = scratch[dot_a as usize];
                    let dr = scratch[dot_b as usize];
                    let mut result = 0.0;
                    // NOTE(review): the base term is dropped whenever u == 0,
                    // including r == 1 where the derivative would be `du`;
                    // presumably this mirrors the tape's reference
                    // implementation - confirm before changing.
                    if r != 0.0 && u != 0.0 {
                        result += r * u.powf(r - 1.0) * du;
                    }
                    // ln(u) is only defined for u > 0.
                    if u > 0.0 {
                        result += scratch[vd as usize] * u.ln() * dr;
                    }
                    scratch[dst as usize] = result;
                }

                // ---- reverse instructions ----
                // Each one reads the op's adjoint `w` and adjoint tangent
                // `wd`, and accumulates into the operands' adjoint and
                // adjoint-tangent slots.
                HOp::RevAdd {
                    adj_a,
                    adj_b,
                    adj_dot_a,
                    adj_dot_b,
                    w,
                    wd,
                } => {
                    let w_v = scratch[w as usize];
                    let wd_v = scratch[wd as usize];
                    scratch[adj_a as usize] += w_v;
                    scratch[adj_b as usize] += w_v;
                    scratch[adj_dot_a as usize] += wd_v;
                    scratch[adj_dot_b as usize] += wd_v;
                }
                HOp::RevSub {
                    adj_a,
                    adj_b,
                    adj_dot_a,
                    adj_dot_b,
                    w,
                    wd,
                } => {
                    let w_v = scratch[w as usize];
                    let wd_v = scratch[wd as usize];
                    scratch[adj_a as usize] += w_v;
                    scratch[adj_b as usize] -= w_v;
                    scratch[adj_dot_a as usize] += wd_v;
                    scratch[adj_dot_b as usize] -= wd_v;
                }
                HOp::RevMul {
                    adj_a,
                    adj_b,
                    adj_dot_a,
                    adj_dot_b,
                    w,
                    wd,
                    va,
                    vb,
                    dot_a,
                    dot_b,
                } => {
                    let w_v = scratch[w as usize];
                    let wd_v = scratch[wd as usize];
                    let va_v = scratch[va as usize];
                    let vb_v = scratch[vb as usize];
                    let da_v = scratch[dot_a as usize];
                    let db_v = scratch[dot_b as usize];
                    scratch[adj_a as usize] += w_v * vb_v;
                    scratch[adj_b as usize] += w_v * va_v;
                    scratch[adj_dot_a as usize] += wd_v * vb_v + w_v * db_v;
                    scratch[adj_dot_b as usize] += wd_v * va_v + w_v * da_v;
                }
                HOp::RevDiv {
                    adj_a,
                    adj_b,
                    adj_dot_a,
                    adj_dot_b,
                    w,
                    wd,
                    va,
                    vb,
                    dot_a,
                    dot_b,
                } => {
                    // Partials of a/b: 1/b with respect to a, -a/b^2 with
                    // respect to b; the adjoint-tangent terms differentiate
                    // those partials along the tangent direction.
                    let w_v = scratch[w as usize];
                    let wd_v = scratch[wd as usize];
                    let va_v = scratch[va as usize];
                    let vb_v = scratch[vb as usize];
                    let vb2 = vb_v * vb_v;
                    let vb3 = vb2 * vb_v;
                    let da_v = scratch[dot_a as usize];
                    let db_v = scratch[dot_b as usize];
                    scratch[adj_a as usize] += w_v / vb_v;
                    scratch[adj_dot_a as usize] += wd_v / vb_v + w_v * (-db_v / vb2);
                    scratch[adj_b as usize] += w_v * (-va_v / vb2);
                    scratch[adj_dot_b as usize] +=
                        wd_v * (-va_v / vb2) + w_v * (-da_v / vb2 + 2.0 * va_v * db_v / vb3);
                }
                HOp::RevPow {
                    adj_a,
                    adj_b,
                    adj_dot_a,
                    adj_dot_b,
                    w,
                    wd,
                    va,
                    vb,
                    vd,
                    dot_a,
                    dot_b,
                } => {
                    let w_v = scratch[w as usize];
                    let wd_v = scratch[wd as usize];
                    let u = scratch[va as usize];
                    let r = scratch[vb as usize];
                    let du = scratch[dot_a as usize];
                    let dr = scratch[dot_b as usize];
                    // Base operand: partial p_a = r * u^(r-1).
                    if r != 0.0 {
                        if u != 0.0 {
                            let p_a = r * u.powf(r - 1.0);
                            scratch[adj_a as usize] += w_v * p_a;
                            let mut dp_a = dr * u.powf(r - 1.0);
                            if u > 0.0 {
                                dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
                            } else {
                                // Negative base: the ln(u) term is undefined and is left out.
                                dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
                            }
                            scratch[adj_dot_a as usize] += wd_v * p_a + w_v * dp_a;
                        } else if r >= 2.0 {
                            // u == 0 with r >= 2: the first partial vanishes;
                            // the second partial is 2 for r == 2.
                            // NOTE(review): u == 0 with r < 2 contributes
                            // nothing here - confirm this matches the
                            // tape's reference implementation.
                            let p_a = 0.0;
                            scratch[adj_a as usize] += w_v * p_a;
                            let dp_a = if r == 2.0 {
                                2.0 * du
                            } else {
                                r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
                            };
                            scratch[adj_dot_a as usize] += wd_v * p_a + w_v * dp_a;
                        }
                    }
                    // Exponent operand: partial p_b = u^r * ln(u), only for u > 0.
                    if u > 0.0 {
                        let ln_u = u.ln();
                        let p_b = scratch[vd as usize] * ln_u;
                        scratch[adj_b as usize] += w_v * p_b;
                        let dur = scratch[vd as usize] * (r * du / u + dr * ln_u);
                        let dp_b = dur * ln_u + scratch[vd as usize] * du / u;
                        scratch[adj_dot_b as usize] += wd_v * p_b + w_v * dp_b;
                    }
                }
                HOp::RevNeg {
                    adj_a,
                    adj_dot_a,
                    w,
                    wd,
                } => {
                    scratch[adj_a as usize] -= scratch[w as usize];
                    scratch[adj_dot_a as usize] -= scratch[wd as usize];
                }
                HOp::RevAbs {
                    adj_a,
                    adj_dot_a,
                    w,
                    wd,
                    va,
                } => {
                    // Same sign convention as `DotAbs`: +1 at zero.
                    let s = if scratch[va as usize] >= 0.0 {
                        1.0
                    } else {
                        -1.0
                    };
                    scratch[adj_a as usize] += scratch[w as usize] * s;
                    scratch[adj_dot_a as usize] += scratch[wd as usize] * s;
                }
                HOp::RevSqrt {
                    adj_a,
                    adj_dot_a,
                    w,
                    wd,
                    va: _,
                    vd,
                    dot_a,
                } => {
                    // With sv = sqrt(a): f' = 0.5 / sv and f'' = -0.25 / sv^3.
                    // Nothing is propagated when sv <= 0.
                    let sv = scratch[vd as usize];
                    if sv > 0.0 {
                        let fp = 0.5 / sv;
                        let fpp = -0.25 / (sv * sv * sv);
                        let w_v = scratch[w as usize];
                        let wd_v = scratch[wd as usize];
                        scratch[adj_a as usize] += w_v * fp;
                        scratch[adj_dot_a as usize] +=
                            wd_v * fp + w_v * fpp * scratch[dot_a as usize];
                    }
                }
                HOp::RevExp {
                    adj_a,
                    adj_dot_a,
                    w,
                    wd,
                    vd,
                    dot_a,
                } => {
                    // exp is its own first and second derivative.
                    let ev = scratch[vd as usize];
                    let w_v = scratch[w as usize];
                    let wd_v = scratch[wd as usize];
                    scratch[adj_a as usize] += w_v * ev;
                    scratch[adj_dot_a as usize] += wd_v * ev + w_v * ev * scratch[dot_a as usize];
                }
                HOp::RevLog {
                    adj_a,
                    adj_dot_a,
                    w,
                    wd,
                    va,
                    dot_a,
                } => {
                    // f' = 1/u, f'' = -1/u^2.
                    let u = scratch[va as usize];
                    let w_v = scratch[w as usize];
                    let wd_v = scratch[wd as usize];
                    scratch[adj_a as usize] += w_v / u;
                    scratch[adj_dot_a as usize] +=
                        wd_v / u + w_v * (-1.0 / (u * u)) * scratch[dot_a as usize];
                }
                HOp::RevLog10 {
                    adj_a,
                    adj_dot_a,
                    w,
                    wd,
                    va,
                    dot_a,
                } => {
                    // f' = 1/(u ln 10), f'' = -1/(u^2 ln 10).
                    let u = scratch[va as usize];
                    let c = std::f64::consts::LN_10;
                    let w_v = scratch[w as usize];
                    let wd_v = scratch[wd as usize];
                    scratch[adj_a as usize] += w_v / (u * c);
                    scratch[adj_dot_a as usize] +=
                        wd_v / (u * c) + w_v * (-1.0 / (u * u * c)) * scratch[dot_a as usize];
                }
                HOp::RevSin {
                    adj_a,
                    adj_dot_a,
                    w,
                    wd,
                    va,
                    dot_a,
                } => {
                    // f' = cos u, f'' = -sin u.
                    let u = scratch[va as usize];
                    let cu = u.cos();
                    let w_v = scratch[w as usize];
                    let wd_v = scratch[wd as usize];
                    scratch[adj_a as usize] += w_v * cu;
                    scratch[adj_dot_a as usize] +=
                        wd_v * cu + w_v * (-u.sin()) * scratch[dot_a as usize];
                }
                HOp::RevCos {
                    adj_a,
                    adj_dot_a,
                    w,
                    wd,
                    va,
                    dot_a,
                } => {
                    // f' = -sin u, f'' = -cos u.
                    let u = scratch[va as usize];
                    let su = u.sin();
                    let w_v = scratch[w as usize];
                    let wd_v = scratch[wd as usize];
                    scratch[adj_a as usize] -= w_v * su;
                    scratch[adj_dot_a as usize] +=
                        wd_v * (-su) + w_v * (-u.cos()) * scratch[dot_a as usize];
                }

                // ---- output ----
                HOp::HessEmit {
                    hess_ptr,
                    adj_dot_slot,
                } => {
                    values[hess_ptr as usize] += weight * scratch[adj_dot_slot as usize];
                }
            }
        }
    }
1147}
1148
1149fn reachable_to_output(tape: &Tape) -> Vec<bool> {
1153 let n = tape.ops.len();
1154 let mut r = vec![false; n];
1155 if n == 0 {
1156 return r;
1157 }
1158 r[n - 1] = true;
1159 for i in (0..n).rev() {
1160 if !r[i] {
1161 continue;
1162 }
1163 match tape.ops[i] {
1164 TapeOp::Const(_) | TapeOp::Var(_) => {}
1165 TapeOp::Add(a, b)
1166 | TapeOp::Sub(a, b)
1167 | TapeOp::Mul(a, b)
1168 | TapeOp::Div(a, b)
1169 | TapeOp::Pow(a, b) => {
1170 r[a] = true;
1171 r[b] = true;
1172 }
1173 TapeOp::Neg(a)
1174 | TapeOp::Abs(a)
1175 | TapeOp::Sqrt(a)
1176 | TapeOp::Exp(a)
1177 | TapeOp::Log(a)
1178 | TapeOp::Log10(a)
1179 | TapeOp::Sin(a)
1180 | TapeOp::Cos(a) => {
1181 r[a] = true;
1182 }
1183 }
1184 }
1185 r
1186}
1187
1188fn depends_on_var(tape: &Tape, j: usize) -> Vec<bool> {
1193 let n = tape.ops.len();
1194 let mut d = vec![false; n];
1195 for (i, op) in tape.ops.iter().enumerate() {
1196 d[i] = match *op {
1197 TapeOp::Const(_) => false,
1198 TapeOp::Var(k) => k == j,
1199 TapeOp::Add(a, b)
1200 | TapeOp::Sub(a, b)
1201 | TapeOp::Mul(a, b)
1202 | TapeOp::Div(a, b)
1203 | TapeOp::Pow(a, b) => d[a] || d[b],
1204 TapeOp::Neg(a)
1205 | TapeOp::Abs(a)
1206 | TapeOp::Sqrt(a)
1207 | TapeOp::Exp(a)
1208 | TapeOp::Log(a)
1209 | TapeOp::Log10(a)
1210 | TapeOp::Sin(a)
1211 | TapeOp::Cos(a) => d[a],
1212 };
1213 }
1214 d
1215}
1216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nl_reader::{BinOp, Expr, UnaryOp};
    use std::collections::BTreeSet;
    use std::rc::Rc;

    // ---- expression-building shorthands ----

    /// Boxes both operands into an `Expr::Binary` node.
    fn binary(op: BinOp, lhs: Expr, rhs: Expr) -> Expr {
        Expr::Binary(op, Box::new(lhs), Box::new(rhs))
    }
    fn cnst(value: f64) -> Expr {
        Expr::Const(value)
    }
    fn var(index: usize) -> Expr {
        Expr::Var(index)
    }
    fn add(lhs: Expr, rhs: Expr) -> Expr {
        binary(BinOp::Add, lhs, rhs)
    }
    fn mul(lhs: Expr, rhs: Expr) -> Expr {
        binary(BinOp::Mul, lhs, rhs)
    }
    fn pow(base: Expr, exponent: Expr) -> Expr {
        binary(BinOp::Pow, base, exponent)
    }
    fn div(num: Expr, den: Expr) -> Expr {
        binary(BinOp::Div, num, den)
    }
    fn sub(lhs: Expr, rhs: Expr) -> Expr {
        binary(BinOp::Sub, lhs, rhs)
    }
    fn unary(op: UnaryOp, arg: Expr) -> Expr {
        Expr::Unary(op, Box::new(arg))
    }

    /// Builds a dense lower-triangular map over the tape's variables.
    ///
    /// Returns the `(row, col) -> value index` map together with the list of
    /// pairs in index order; every pair satisfies `row >= col`.
    fn build_hess_map(tape: &Tape) -> (HashMap<(usize, usize), usize>, Vec<(usize, usize)>) {
        let vars = tape.variables();
        let mut pairs: Vec<(usize, usize)> = Vec::new();
        let mut map: HashMap<(usize, usize), usize> = HashMap::new();
        for (pos, &vi) in vars.iter().enumerate() {
            // Pair `vi` with itself and every variable listed before it.
            for &vj in vars.iter().take(pos + 1) {
                let key = (vi.max(vj), vi.min(vj));
                map.entry(key).or_insert_with(|| {
                    pairs.push(key);
                    pairs.len() - 1
                });
            }
        }
        (map, pairs)
    }

    /// Checks that the compiled program reproduces `Tape::hessian_accumulate`
    /// at `x` with the given `weight`, entry by entry, to a relative
    /// tolerance of `1e-12` (absolute below magnitude 1).
    fn assert_program_matches_tape(tape: &Tape, x: &[f64], weight: f64) {
        let (hess_map, pairs) = build_hess_map(tape);

        let mut expected: Vec<f64> = vec![0.0; pairs.len()];
        tape.hessian_accumulate(x, weight, &hess_map, &mut expected);

        let program = HessianProgram::compile(tape, &hess_map);
        let mut scratch = vec![0.0; program.n_slots()];
        let mut actual: Vec<f64> = vec![0.0; pairs.len()];
        program.execute(x, weight, &mut scratch, &mut actual);

        for ((&(r, c), &want), &got) in pairs.iter().zip(&expected).zip(&actual) {
            let tol = want.abs().max(1.0) * 1e-12;
            assert!(
                (want - got).abs() < tol,
                "H[{},{}]: tape={:.6e} prog={:.6e}",
                r,
                c,
                want,
                got
            );
        }
    }

    #[test]
    fn matches_quadratic() {
        // 3*x0^2 + 2*x0*x1 + x1^2
        let x0_term = mul(cnst(3.0), pow(var(0), cnst(2.0)));
        let cross_term = mul(cnst(2.0), mul(var(0), var(1)));
        let x1_term = pow(var(1), cnst(2.0));
        let tape = Tape::build(&add(add(x0_term, cross_term), x1_term));
        assert_program_matches_tape(&tape, &[2.0, 3.0], 1.0);
        assert_program_matches_tape(&tape, &[-1.5, 0.7], 2.5);
    }

    #[test]
    fn matches_transcendental() {
        let terms = vec![
            unary(UnaryOp::Exp, var(0)),
            unary(UnaryOp::Sin, var(1)),
            unary(UnaryOp::Log, var(0)),
            unary(UnaryOp::Sqrt, var(1)),
            mul(var(0), var(1)),
            unary(UnaryOp::Cos, add(var(0), var(1))),
        ];
        let tape = Tape::build(&Expr::Sum(terms));
        assert_program_matches_tape(&tape, &[1.5, 2.0], 1.0);
        assert_program_matches_tape(&tape, &[0.3, 4.1], -0.4);
    }

    #[test]
    fn matches_division() {
        // x0 / x1 + cos(x0)
        let quotient = div(var(0), var(1));
        let cosine = unary(UnaryOp::Cos, var(0));
        let tape = Tape::build(&add(quotient, cosine));
        assert_program_matches_tape(&tape, &[0.5, 1.2], 1.0);
    }

    #[test]
    fn matches_through_cse() {
        // The same shared sub-expression is used twice: (x0 + x1)^2 + (x0 + x1).
        let shared = Rc::new(add(var(0), var(1)));
        let squared = pow(Expr::Cse(Rc::clone(&shared)), cnst(2.0));
        let linear = Expr::Cse(Rc::clone(&shared));
        let tape = Tape::build(&add(squared, linear));
        assert_program_matches_tape(&tape, &[1.0, 2.0], 1.0);
        assert_program_matches_tape(&tape, &[-0.5, 3.3], 0.7);
    }

    #[test]
    fn matches_pow_chain() {
        // Constant exponents, one positive and one negative.
        let cubic = pow(var(0), cnst(3.0));
        let inverse_square = pow(var(1), cnst(-2.0));
        let tape = Tape::build(&add(cubic, inverse_square));
        assert_program_matches_tape(&tape, &[1.7, 0.8], 1.0);
    }

    #[test]
    fn matches_residual_pow_with_var_exponent() {
        // Both base and exponent are variables.
        let tape = Tape::build(&pow(var(0), var(1)));
        assert_program_matches_tape(&tape, &[2.5, 1.4], 1.0);
        assert_program_matches_tape(&tape, &[0.6, 2.1], -1.0);
    }

    #[test]
    fn matches_sub_neg_abs() {
        // -x0 - |x1 - x0|
        let negated = unary(UnaryOp::Neg, var(0));
        let distance = unary(UnaryOp::Abs, sub(var(1), var(0)));
        let tape = Tape::build(&sub(negated, distance));
        assert_program_matches_tape(&tape, &[1.0, -2.0], 1.0);
        assert_program_matches_tape(&tape, &[-3.5, 4.0], 0.9);
    }

    #[test]
    fn slots_layout_matches_design() {
        // Four scratch slots per tape op.
        let tape = Tape::build(&mul(var(0), var(1)));
        let (hess_map, _) = build_hess_map(&tape);
        let prog = HessianProgram::compile(&tape, &hess_map);
        assert_eq!(prog.n_slots(), 4 * tape.ops.len());
    }

    #[test]
    fn dependence_matches_hessian_sparsity_for_simple_case() {
        // sin(x0) + x1*x2
        let sine = unary(UnaryOp::Sin, var(0));
        let product = mul(var(1), var(2));
        let tape = Tape::build(&add(sine, product));
        let s: BTreeSet<(usize, usize)> = tape.hessian_sparsity();
        assert!(s.contains(&(0, 0)));
        assert!(s.contains(&(2, 1)));
        assert_program_matches_tape(&tape, &[0.7, 1.1, 2.2], 1.0);
    }
}