1use std::collections::HashMap;
42
43use super::nl_tape::{Tape, TapeOp};
44
/// One instruction of a compiled [`HessianProgram`].
///
/// Every field is an index. `x_idx` indexes the input point, `c_idx` the
/// program's constant pool, `hess_ptr` the output value array, and all other
/// fields index the flat `f64` scratch buffer passed to
/// [`HessianProgram::execute`].
///
/// Field naming used by the emitter in [`HessianProgram::compile`]:
/// `va` / `vb` are the value slots of the operands, `vd` is the value slot of
/// the node itself, `dot_*` are tangent slots, `adj_*` are adjoint slots,
/// `adj_dot_*` are adjoint-tangent slots, and `w` / `wd` are the adjoint and
/// adjoint-tangent slots of the node being propagated.
#[derive(Debug, Clone, Copy)]
pub enum HOp {
    // ---- Forward value ops: scratch[dst] = f(operands) ----
    /// `scratch[dst] = x[x_idx]`.
    FwdLoadVar {
        dst: u32,
        x_idx: u32,
    },
    /// `scratch[dst] = consts[c_idx]`.
    FwdLoadConst {
        dst: u32,
        c_idx: u32,
    },
    /// `scratch[dst] = scratch[a] + scratch[b]`.
    FwdAdd {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `scratch[dst] = scratch[a] - scratch[b]`.
    FwdSub {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `scratch[dst] = scratch[a] * scratch[b]`.
    FwdMul {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `scratch[dst] = scratch[a] / scratch[b]`.
    FwdDiv {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `scratch[dst] = scratch[a].powf(scratch[b])`.
    FwdPow {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `scratch[dst] = -scratch[a]`.
    FwdNeg {
        dst: u32,
        a: u32,
    },
    /// `scratch[dst] = |scratch[a]|`.
    FwdAbs {
        dst: u32,
        a: u32,
    },
    /// `scratch[dst] = sqrt(scratch[a])`.
    FwdSqrt {
        dst: u32,
        a: u32,
    },
    /// `scratch[dst] = exp(scratch[a])`.
    FwdExp {
        dst: u32,
        a: u32,
    },
    /// `scratch[dst] = ln(scratch[a])`.
    FwdLog {
        dst: u32,
        a: u32,
    },
    /// `scratch[dst] = log10(scratch[a])`.
    FwdLog10 {
        dst: u32,
        a: u32,
    },
    /// `scratch[dst] = sin(scratch[a])`.
    FwdSin {
        dst: u32,
        a: u32,
    },
    /// `scratch[dst] = cos(scratch[a])`.
    FwdCos {
        dst: u32,
        a: u32,
    },

    // ---- Slot initialisation ----
    /// `scratch[dst] = 0.0`. Handled by `execute`; `compile` does not emit it.
    SetZero {
        dst: u32,
    },
    /// `scratch[dst] = 1.0`. Used to seed a tangent or the output adjoint.
    SetOne {
        dst: u32,
    },

    /// Fills `scratch[start .. start + len]` with `0.0`.
    ZeroRange {
        start: u32,
        len: u32,
    },

    // ---- Tangent (directional-derivative) ops: scratch[dst] = tangent of a node ----
    /// `scratch[dst] = scratch[a] + scratch[b]` (operands are tangent slots).
    DotAdd {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// `scratch[dst] = scratch[a] - scratch[b]` (operands are tangent slots).
    DotSub {
        dst: u32,
        a: u32,
        b: u32,
    },
    /// Product rule: `dot_a * vb + va * dot_b`.
    DotMul {
        dst: u32,
        dot_a: u32,
        vb: u32,
        va: u32,
        dot_b: u32,
    },
    /// Quotient rule: `(dot_a * vb - va * dot_b) / vb^2`.
    DotDiv {
        dst: u32,
        dot_a: u32,
        vb: u32,
        va: u32,
        dot_b: u32,
    },
    /// `dot_a * 0.5 / vd` when `vd > 0`, otherwise `0.0`; `vd` holds the
    /// already-computed square root.
    DotSqrt {
        dst: u32,
        dot_a: u32,
        vd: u32,
    },
    /// `dot_a * vd`; `vd` holds the already-computed exponential.
    DotExp {
        dst: u32,
        dot_a: u32,
        vd: u32,
    },
    /// `dot_a / va`.
    DotLog {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// `dot_a / (va * ln 10)`.
    DotLog10 {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// `dot_a * cos(va)`.
    DotSin {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// `-dot_a * sin(va)`.
    DotCos {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// `-dot_a`.
    DotNeg {
        dst: u32,
        dot_a: u32,
    },
    /// `dot_a` when `va >= 0`, otherwise `-dot_a`.
    DotAbs {
        dst: u32,
        dot_a: u32,
        va: u32,
    },
    /// Tangent of `va ^ vb`: adds `vb * va^(vb-1) * dot_a` when both `va` and
    /// `vb` are non-zero, and `vd * ln(va) * dot_b` when `va > 0`; `vd` holds
    /// the already-computed power.
    DotPow {
        dst: u32,
        va: u32,
        vb: u32,
        vd: u32,
        dot_a: u32,
        dot_b: u32,
    },

    // ---- Reverse ops: accumulate (`+=` / `-=`) into the operands' adjoint and
    // ---- adjoint-tangent slots from this node's `w` / `wd`.
    /// Reverse of `a + b`: both operands receive `+w` and `+wd`.
    RevAdd {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
    },
    /// Reverse of `a - b`: `a` receives `+w`, `+wd`; `b` receives `-w`, `-wd`.
    RevSub {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
    },
    /// Reverse of `a * b`, including the tangent of each partial derivative.
    RevMul {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
        va: u32,
        vb: u32,
        dot_a: u32,
        dot_b: u32,
    },
    /// Reverse of `a / b`, including the tangent of each partial derivative.
    RevDiv {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
        va: u32,
        vb: u32,
        dot_a: u32,
        dot_b: u32,
    },
    /// Reverse of `a ^ b`; the special cases for zero / negative base and zero
    /// exponent are handled in `execute`. `vd` holds the computed power.
    RevPow {
        adj_a: u32,
        adj_b: u32,
        adj_dot_a: u32,
        adj_dot_b: u32,
        w: u32,
        wd: u32,
        va: u32,
        vb: u32,
        vd: u32,
        dot_a: u32,
        dot_b: u32,
    },
    /// Reverse of `-a`: the operand receives `-w` and `-wd`.
    RevNeg {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
    },
    /// Reverse of `|a|`: `w` and `wd` scaled by `+1` when `va >= 0`, else `-1`.
    RevAbs {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
    },
    /// Reverse of `sqrt(a)`; a no-op unless `vd > 0`. `va` is carried but not
    /// read by `execute`.
    RevSqrt {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        vd: u32,
        dot_a: u32,
    },
    /// Reverse of `exp(a)`; `vd` holds the computed exponential.
    RevExp {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        vd: u32,
        dot_a: u32,
    },
    /// Reverse of `ln(a)`.
    RevLog {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },
    /// Reverse of `log10(a)`.
    RevLog10 {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },
    /// Reverse of `sin(a)`.
    RevSin {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },
    /// Reverse of `cos(a)`.
    RevCos {
        adj_a: u32,
        adj_dot_a: u32,
        w: u32,
        wd: u32,
        va: u32,
        dot_a: u32,
    },

    /// `values[hess_ptr] += weight * scratch[adj_dot_slot]`: writes one
    /// Hessian contribution to the caller's output array.
    HessEmit {
        hess_ptr: u32,
        adj_dot_slot: u32,
    },
}
337
/// A tape compiled into a flat list of [`HOp`] instructions that accumulates
/// weighted Hessian entries when run by `execute`.
#[derive(Debug, Clone)]
pub struct HessianProgram {
    /// Instructions, executed strictly in order.
    ops: Vec<HOp>,
    /// Constant pool referenced by `HOp::FwdLoadConst::c_idx`.
    consts: Vec<f64>,
    /// Minimum scratch length required by `execute` (four slots per tape op).
    n_slots: u32,
}
346
347impl HessianProgram {
348 pub fn compile(tape: &Tape, hess_map: &HashMap<(usize, usize), usize>) -> Self {
352 let n = tape.ops.len() as u32;
353 let v_base = 0u32;
354 let dot_base = n;
355 let adj_base = 2 * n;
356 let adj_dot_base = 3 * n;
357 let n_slots = 4 * n;
358
359 let v_slot = |i: u32| v_base + i;
360 let dot_slot = |i: u32| dot_base + i;
361 let adj_slot = |i: u32| adj_base + i;
362 let adj_dot_slot = |i: u32| adj_dot_base + i;
363
364 let reachable = reachable_to_output(tape);
365 let var_indices = tape.variables();
366 let depends_on: Vec<Vec<bool>> = (0..var_indices.len())
368 .map(|k_idx| depends_on_var(tape, var_indices[k_idx]))
369 .collect();
370
371 let mut consts: Vec<f64> = Vec::new();
372 let mut const_intern: HashMap<u64, u32> = HashMap::new();
373 let mut intern_const = |c: f64, consts: &mut Vec<f64>| -> u32 {
374 let bits = c.to_bits();
375 if let Some(&idx) = const_intern.get(&bits) {
376 return idx;
377 }
378 let idx = consts.len() as u32;
379 consts.push(c);
380 const_intern.insert(bits, idx);
381 idx
382 };
383
384 let mut ops: Vec<HOp> = Vec::new();
385
386 for (i, tape_op) in tape.ops.iter().enumerate() {
388 let i = i as u32;
389 let dst = v_slot(i);
390 let op = match *tape_op {
391 TapeOp::Const(c) => HOp::FwdLoadConst {
392 dst,
393 c_idx: intern_const(c, &mut consts),
394 },
395 TapeOp::Var(x_idx) => HOp::FwdLoadVar {
396 dst,
397 x_idx: x_idx as u32,
398 },
399 TapeOp::Add(a, b) => HOp::FwdAdd {
400 dst,
401 a: v_slot(a as u32),
402 b: v_slot(b as u32),
403 },
404 TapeOp::Sub(a, b) => HOp::FwdSub {
405 dst,
406 a: v_slot(a as u32),
407 b: v_slot(b as u32),
408 },
409 TapeOp::Mul(a, b) => HOp::FwdMul {
410 dst,
411 a: v_slot(a as u32),
412 b: v_slot(b as u32),
413 },
414 TapeOp::Div(a, b) => HOp::FwdDiv {
415 dst,
416 a: v_slot(a as u32),
417 b: v_slot(b as u32),
418 },
419 TapeOp::Pow(a, b) => HOp::FwdPow {
420 dst,
421 a: v_slot(a as u32),
422 b: v_slot(b as u32),
423 },
424 TapeOp::Neg(a) => HOp::FwdNeg {
425 dst,
426 a: v_slot(a as u32),
427 },
428 TapeOp::Abs(a) => HOp::FwdAbs {
429 dst,
430 a: v_slot(a as u32),
431 },
432 TapeOp::Sqrt(a) => HOp::FwdSqrt {
433 dst,
434 a: v_slot(a as u32),
435 },
436 TapeOp::Exp(a) => HOp::FwdExp {
437 dst,
438 a: v_slot(a as u32),
439 },
440 TapeOp::Log(a) => HOp::FwdLog {
441 dst,
442 a: v_slot(a as u32),
443 },
444 TapeOp::Log10(a) => HOp::FwdLog10 {
445 dst,
446 a: v_slot(a as u32),
447 },
448 TapeOp::Sin(a) => HOp::FwdSin {
449 dst,
450 a: v_slot(a as u32),
451 },
452 TapeOp::Cos(a) => HOp::FwdCos {
453 dst,
454 a: v_slot(a as u32),
455 },
456 TapeOp::Funcall { .. } => panic!(
457 "HessianProgram path does not support AMPL external functions; \
458 use the Tape (build_with_externals) path instead."
459 ),
460 };
461 ops.push(op);
462 }
463
464 if n == 0 || var_indices.is_empty() {
465 return HessianProgram {
466 ops,
467 consts,
468 n_slots,
469 };
470 }
471
472 for (k_idx, &j) in var_indices.iter().enumerate() {
474 ops.push(HOp::ZeroRange {
476 start: dot_base,
477 len: 3 * n,
478 });
479 ops.push(HOp::SetOne {
480 dst: adj_slot(n - 1),
481 });
482
483 for (i, tape_op) in tape.ops.iter().enumerate() {
487 let i_u = i as u32;
488 if !depends_on[k_idx][i] {
489 continue;
490 }
491 let dst = dot_slot(i_u);
492 let dot_op = match *tape_op {
493 TapeOp::Const(_) => continue,
496 TapeOp::Var(_) => HOp::SetOne { dst },
500 TapeOp::Add(a, b) => HOp::DotAdd {
501 dst,
502 a: dot_slot(a as u32),
503 b: dot_slot(b as u32),
504 },
505 TapeOp::Sub(a, b) => HOp::DotSub {
506 dst,
507 a: dot_slot(a as u32),
508 b: dot_slot(b as u32),
509 },
510 TapeOp::Mul(a, b) => HOp::DotMul {
511 dst,
512 dot_a: dot_slot(a as u32),
513 vb: v_slot(b as u32),
514 va: v_slot(a as u32),
515 dot_b: dot_slot(b as u32),
516 },
517 TapeOp::Div(a, b) => HOp::DotDiv {
518 dst,
519 dot_a: dot_slot(a as u32),
520 vb: v_slot(b as u32),
521 va: v_slot(a as u32),
522 dot_b: dot_slot(b as u32),
523 },
524 TapeOp::Pow(a, b) => HOp::DotPow {
525 dst,
526 va: v_slot(a as u32),
527 vb: v_slot(b as u32),
528 vd: v_slot(i_u),
529 dot_a: dot_slot(a as u32),
530 dot_b: dot_slot(b as u32),
531 },
532 TapeOp::Neg(a) => HOp::DotNeg {
533 dst,
534 dot_a: dot_slot(a as u32),
535 },
536 TapeOp::Abs(a) => HOp::DotAbs {
537 dst,
538 dot_a: dot_slot(a as u32),
539 va: v_slot(a as u32),
540 },
541 TapeOp::Sqrt(a) => HOp::DotSqrt {
542 dst,
543 dot_a: dot_slot(a as u32),
544 vd: v_slot(i_u),
545 },
546 TapeOp::Exp(a) => HOp::DotExp {
547 dst,
548 dot_a: dot_slot(a as u32),
549 vd: v_slot(i_u),
550 },
551 TapeOp::Log(a) => HOp::DotLog {
552 dst,
553 dot_a: dot_slot(a as u32),
554 va: v_slot(a as u32),
555 },
556 TapeOp::Log10(a) => HOp::DotLog10 {
557 dst,
558 dot_a: dot_slot(a as u32),
559 va: v_slot(a as u32),
560 },
561 TapeOp::Sin(a) => HOp::DotSin {
562 dst,
563 dot_a: dot_slot(a as u32),
564 va: v_slot(a as u32),
565 },
566 TapeOp::Cos(a) => HOp::DotCos {
567 dst,
568 dot_a: dot_slot(a as u32),
569 va: v_slot(a as u32),
570 },
571 TapeOp::Funcall { .. } => panic!(
572 "HessianProgram path does not support AMPL external functions; \
573 use the Tape (build_with_externals) path instead."
574 ),
575 };
576 ops.push(dot_op);
577 }
578
579 for i in (0..n as usize).rev() {
582 if !reachable[i] {
583 continue;
584 }
585 let i_u = i as u32;
586 let w = adj_slot(i_u);
587 let wd = adj_dot_slot(i_u);
588 let tape_op = &tape.ops[i];
589 let rev_op = match *tape_op {
590 TapeOp::Const(_) => continue,
591 TapeOp::Var(k) => {
592 if k >= j {
596 if let Some(&ptr) = hess_map.get(&(k, j)) {
597 ops.push(HOp::HessEmit {
598 hess_ptr: ptr as u32,
599 adj_dot_slot: wd,
600 });
601 }
602 }
603 continue;
604 }
605 TapeOp::Add(a, b) => HOp::RevAdd {
606 adj_a: adj_slot(a as u32),
607 adj_b: adj_slot(b as u32),
608 adj_dot_a: adj_dot_slot(a as u32),
609 adj_dot_b: adj_dot_slot(b as u32),
610 w,
611 wd,
612 },
613 TapeOp::Sub(a, b) => HOp::RevSub {
614 adj_a: adj_slot(a as u32),
615 adj_b: adj_slot(b as u32),
616 adj_dot_a: adj_dot_slot(a as u32),
617 adj_dot_b: adj_dot_slot(b as u32),
618 w,
619 wd,
620 },
621 TapeOp::Mul(a, b) => HOp::RevMul {
622 adj_a: adj_slot(a as u32),
623 adj_b: adj_slot(b as u32),
624 adj_dot_a: adj_dot_slot(a as u32),
625 adj_dot_b: adj_dot_slot(b as u32),
626 w,
627 wd,
628 va: v_slot(a as u32),
629 vb: v_slot(b as u32),
630 dot_a: dot_slot(a as u32),
631 dot_b: dot_slot(b as u32),
632 },
633 TapeOp::Div(a, b) => HOp::RevDiv {
634 adj_a: adj_slot(a as u32),
635 adj_b: adj_slot(b as u32),
636 adj_dot_a: adj_dot_slot(a as u32),
637 adj_dot_b: adj_dot_slot(b as u32),
638 w,
639 wd,
640 va: v_slot(a as u32),
641 vb: v_slot(b as u32),
642 dot_a: dot_slot(a as u32),
643 dot_b: dot_slot(b as u32),
644 },
645 TapeOp::Pow(a, b) => HOp::RevPow {
646 adj_a: adj_slot(a as u32),
647 adj_b: adj_slot(b as u32),
648 adj_dot_a: adj_dot_slot(a as u32),
649 adj_dot_b: adj_dot_slot(b as u32),
650 w,
651 wd,
652 va: v_slot(a as u32),
653 vb: v_slot(b as u32),
654 vd: v_slot(i_u),
655 dot_a: dot_slot(a as u32),
656 dot_b: dot_slot(b as u32),
657 },
658 TapeOp::Neg(a) => HOp::RevNeg {
659 adj_a: adj_slot(a as u32),
660 adj_dot_a: adj_dot_slot(a as u32),
661 w,
662 wd,
663 },
664 TapeOp::Abs(a) => HOp::RevAbs {
665 adj_a: adj_slot(a as u32),
666 adj_dot_a: adj_dot_slot(a as u32),
667 w,
668 wd,
669 va: v_slot(a as u32),
670 },
671 TapeOp::Sqrt(a) => HOp::RevSqrt {
672 adj_a: adj_slot(a as u32),
673 adj_dot_a: adj_dot_slot(a as u32),
674 w,
675 wd,
676 va: v_slot(a as u32),
677 vd: v_slot(i_u),
678 dot_a: dot_slot(a as u32),
679 },
680 TapeOp::Exp(a) => HOp::RevExp {
681 adj_a: adj_slot(a as u32),
682 adj_dot_a: adj_dot_slot(a as u32),
683 w,
684 wd,
685 vd: v_slot(i_u),
686 dot_a: dot_slot(a as u32),
687 },
688 TapeOp::Log(a) => HOp::RevLog {
689 adj_a: adj_slot(a as u32),
690 adj_dot_a: adj_dot_slot(a as u32),
691 w,
692 wd,
693 va: v_slot(a as u32),
694 dot_a: dot_slot(a as u32),
695 },
696 TapeOp::Log10(a) => HOp::RevLog10 {
697 adj_a: adj_slot(a as u32),
698 adj_dot_a: adj_dot_slot(a as u32),
699 w,
700 wd,
701 va: v_slot(a as u32),
702 dot_a: dot_slot(a as u32),
703 },
704 TapeOp::Sin(a) => HOp::RevSin {
705 adj_a: adj_slot(a as u32),
706 adj_dot_a: adj_dot_slot(a as u32),
707 w,
708 wd,
709 va: v_slot(a as u32),
710 dot_a: dot_slot(a as u32),
711 },
712 TapeOp::Cos(a) => HOp::RevCos {
713 adj_a: adj_slot(a as u32),
714 adj_dot_a: adj_dot_slot(a as u32),
715 w,
716 wd,
717 va: v_slot(a as u32),
718 dot_a: dot_slot(a as u32),
719 },
720 TapeOp::Funcall { .. } => panic!(
721 "HessianProgram path does not support AMPL external functions; \
722 use the Tape (build_with_externals) path instead."
723 ),
724 };
725 ops.push(rev_op);
726 }
727 }
728
729 HessianProgram {
730 ops,
731 consts,
732 n_slots,
733 }
734 }
735
736 pub fn n_slots(&self) -> usize {
737 self.n_slots as usize
738 }
739
740 pub fn n_ops(&self) -> usize {
741 self.ops.len()
742 }
743
744 pub fn execute(&self, x: &[f64], weight: f64, scratch: &mut [f64], values: &mut [f64]) {
750 debug_assert!(scratch.len() >= self.n_slots as usize);
751 if self.ops.is_empty() || weight == 0.0 {
752 return;
753 }
754 let consts = &self.consts[..];
755 for &op in &self.ops {
756 match op {
757 HOp::FwdLoadVar { dst, x_idx } => {
758 scratch[dst as usize] = x[x_idx as usize];
759 }
760 HOp::FwdLoadConst { dst, c_idx } => {
761 scratch[dst as usize] = consts[c_idx as usize];
762 }
763 HOp::FwdAdd { dst, a, b } => {
764 scratch[dst as usize] = scratch[a as usize] + scratch[b as usize];
765 }
766 HOp::FwdSub { dst, a, b } => {
767 scratch[dst as usize] = scratch[a as usize] - scratch[b as usize];
768 }
769 HOp::FwdMul { dst, a, b } => {
770 scratch[dst as usize] = scratch[a as usize] * scratch[b as usize];
771 }
772 HOp::FwdDiv { dst, a, b } => {
773 scratch[dst as usize] = scratch[a as usize] / scratch[b as usize];
774 }
775 HOp::FwdPow { dst, a, b } => {
776 scratch[dst as usize] = scratch[a as usize].powf(scratch[b as usize]);
777 }
778 HOp::FwdNeg { dst, a } => {
779 scratch[dst as usize] = -scratch[a as usize];
780 }
781 HOp::FwdAbs { dst, a } => {
782 scratch[dst as usize] = scratch[a as usize].abs();
783 }
784 HOp::FwdSqrt { dst, a } => {
785 scratch[dst as usize] = scratch[a as usize].sqrt();
786 }
787 HOp::FwdExp { dst, a } => {
788 scratch[dst as usize] = scratch[a as usize].exp();
789 }
790 HOp::FwdLog { dst, a } => {
791 scratch[dst as usize] = scratch[a as usize].ln();
792 }
793 HOp::FwdLog10 { dst, a } => {
794 scratch[dst as usize] = scratch[a as usize].log10();
795 }
796 HOp::FwdSin { dst, a } => {
797 scratch[dst as usize] = scratch[a as usize].sin();
798 }
799 HOp::FwdCos { dst, a } => {
800 scratch[dst as usize] = scratch[a as usize].cos();
801 }
802
803 HOp::SetZero { dst } => {
804 scratch[dst as usize] = 0.0;
805 }
806 HOp::SetOne { dst } => {
807 scratch[dst as usize] = 1.0;
808 }
809 HOp::ZeroRange { start, len } => {
810 let s = start as usize;
811 let e = s + len as usize;
812 scratch[s..e].fill(0.0);
813 }
814
815 HOp::DotAdd { dst, a, b } => {
816 scratch[dst as usize] = scratch[a as usize] + scratch[b as usize];
817 }
818 HOp::DotSub { dst, a, b } => {
819 scratch[dst as usize] = scratch[a as usize] - scratch[b as usize];
820 }
821 HOp::DotMul {
822 dst,
823 dot_a,
824 vb,
825 va,
826 dot_b,
827 } => {
828 scratch[dst as usize] = scratch[dot_a as usize] * scratch[vb as usize]
829 + scratch[va as usize] * scratch[dot_b as usize];
830 }
831 HOp::DotDiv {
832 dst,
833 dot_a,
834 vb,
835 va,
836 dot_b,
837 } => {
838 let v_b = scratch[vb as usize];
839 scratch[dst as usize] = (scratch[dot_a as usize] * v_b
840 - scratch[va as usize] * scratch[dot_b as usize])
841 / (v_b * v_b);
842 }
843 HOp::DotSqrt { dst, dot_a, vd } => {
844 let svd = scratch[vd as usize];
845 scratch[dst as usize] = if svd > 0.0 {
846 scratch[dot_a as usize] * 0.5 / svd
847 } else {
848 0.0
849 };
850 }
851 HOp::DotExp { dst, dot_a, vd } => {
852 scratch[dst as usize] = scratch[dot_a as usize] * scratch[vd as usize];
853 }
854 HOp::DotLog { dst, dot_a, va } => {
855 scratch[dst as usize] = scratch[dot_a as usize] / scratch[va as usize];
856 }
857 HOp::DotLog10 { dst, dot_a, va } => {
858 scratch[dst as usize] =
859 scratch[dot_a as usize] / (scratch[va as usize] * std::f64::consts::LN_10);
860 }
861 HOp::DotSin { dst, dot_a, va } => {
862 scratch[dst as usize] = scratch[dot_a as usize] * scratch[va as usize].cos();
863 }
864 HOp::DotCos { dst, dot_a, va } => {
865 scratch[dst as usize] = -scratch[dot_a as usize] * scratch[va as usize].sin();
866 }
867 HOp::DotNeg { dst, dot_a } => {
868 scratch[dst as usize] = -scratch[dot_a as usize];
869 }
870 HOp::DotAbs { dst, dot_a, va } => {
871 scratch[dst as usize] = if scratch[va as usize] >= 0.0 {
872 scratch[dot_a as usize]
873 } else {
874 -scratch[dot_a as usize]
875 };
876 }
877 HOp::DotPow {
878 dst,
879 va,
880 vb,
881 vd,
882 dot_a,
883 dot_b,
884 } => {
885 let u = scratch[va as usize];
886 let r = scratch[vb as usize];
887 let du = scratch[dot_a as usize];
888 let dr = scratch[dot_b as usize];
889 let mut result = 0.0;
890 if r != 0.0 && u != 0.0 {
891 result += r * u.powf(r - 1.0) * du;
892 }
893 if u > 0.0 {
894 result += scratch[vd as usize] * u.ln() * dr;
895 }
896 scratch[dst as usize] = result;
897 }
898
899 HOp::RevAdd {
900 adj_a,
901 adj_b,
902 adj_dot_a,
903 adj_dot_b,
904 w,
905 wd,
906 } => {
907 let w_v = scratch[w as usize];
908 let wd_v = scratch[wd as usize];
909 scratch[adj_a as usize] += w_v;
910 scratch[adj_b as usize] += w_v;
911 scratch[adj_dot_a as usize] += wd_v;
912 scratch[adj_dot_b as usize] += wd_v;
913 }
914 HOp::RevSub {
915 adj_a,
916 adj_b,
917 adj_dot_a,
918 adj_dot_b,
919 w,
920 wd,
921 } => {
922 let w_v = scratch[w as usize];
923 let wd_v = scratch[wd as usize];
924 scratch[adj_a as usize] += w_v;
925 scratch[adj_b as usize] -= w_v;
926 scratch[adj_dot_a as usize] += wd_v;
927 scratch[adj_dot_b as usize] -= wd_v;
928 }
929 HOp::RevMul {
930 adj_a,
931 adj_b,
932 adj_dot_a,
933 adj_dot_b,
934 w,
935 wd,
936 va,
937 vb,
938 dot_a,
939 dot_b,
940 } => {
941 let w_v = scratch[w as usize];
942 let wd_v = scratch[wd as usize];
943 let va_v = scratch[va as usize];
944 let vb_v = scratch[vb as usize];
945 let da_v = scratch[dot_a as usize];
946 let db_v = scratch[dot_b as usize];
947 scratch[adj_a as usize] += w_v * vb_v;
948 scratch[adj_b as usize] += w_v * va_v;
949 scratch[adj_dot_a as usize] += wd_v * vb_v + w_v * db_v;
950 scratch[adj_dot_b as usize] += wd_v * va_v + w_v * da_v;
951 }
952 HOp::RevDiv {
953 adj_a,
954 adj_b,
955 adj_dot_a,
956 adj_dot_b,
957 w,
958 wd,
959 va,
960 vb,
961 dot_a,
962 dot_b,
963 } => {
964 let w_v = scratch[w as usize];
965 let wd_v = scratch[wd as usize];
966 let va_v = scratch[va as usize];
967 let vb_v = scratch[vb as usize];
968 let vb2 = vb_v * vb_v;
969 let vb3 = vb2 * vb_v;
970 let da_v = scratch[dot_a as usize];
971 let db_v = scratch[dot_b as usize];
972 scratch[adj_a as usize] += w_v / vb_v;
973 scratch[adj_dot_a as usize] += wd_v / vb_v + w_v * (-db_v / vb2);
974 scratch[adj_b as usize] += w_v * (-va_v / vb2);
975 scratch[adj_dot_b as usize] +=
976 wd_v * (-va_v / vb2) + w_v * (-da_v / vb2 + 2.0 * va_v * db_v / vb3);
977 }
978 HOp::RevPow {
979 adj_a,
980 adj_b,
981 adj_dot_a,
982 adj_dot_b,
983 w,
984 wd,
985 va,
986 vb,
987 vd,
988 dot_a,
989 dot_b,
990 } => {
991 let w_v = scratch[w as usize];
992 let wd_v = scratch[wd as usize];
993 let u = scratch[va as usize];
994 let r = scratch[vb as usize];
995 let du = scratch[dot_a as usize];
996 let dr = scratch[dot_b as usize];
997 if r != 0.0 {
998 if u != 0.0 {
999 let p_a = r * u.powf(r - 1.0);
1000 scratch[adj_a as usize] += w_v * p_a;
1001 let mut dp_a = dr * u.powf(r - 1.0);
1002 if u > 0.0 {
1003 dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
1004 } else {
1005 dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
1006 }
1007 scratch[adj_dot_a as usize] += wd_v * p_a + w_v * dp_a;
1008 } else if r >= 2.0 {
1009 let p_a = 0.0;
1010 scratch[adj_a as usize] += w_v * p_a;
1011 let dp_a = if r == 2.0 {
1012 2.0 * du
1013 } else {
1014 r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
1015 };
1016 scratch[adj_dot_a as usize] += wd_v * p_a + w_v * dp_a;
1017 }
1018 }
1019 if u > 0.0 {
1020 let ln_u = u.ln();
1021 let p_b = scratch[vd as usize] * ln_u;
1022 scratch[adj_b as usize] += w_v * p_b;
1023 let dur = scratch[vd as usize] * (r * du / u + dr * ln_u);
1024 let dp_b = dur * ln_u + scratch[vd as usize] * du / u;
1025 scratch[adj_dot_b as usize] += wd_v * p_b + w_v * dp_b;
1026 }
1027 }
1028 HOp::RevNeg {
1029 adj_a,
1030 adj_dot_a,
1031 w,
1032 wd,
1033 } => {
1034 scratch[adj_a as usize] -= scratch[w as usize];
1035 scratch[adj_dot_a as usize] -= scratch[wd as usize];
1036 }
1037 HOp::RevAbs {
1038 adj_a,
1039 adj_dot_a,
1040 w,
1041 wd,
1042 va,
1043 } => {
1044 let s = if scratch[va as usize] >= 0.0 {
1045 1.0
1046 } else {
1047 -1.0
1048 };
1049 scratch[adj_a as usize] += scratch[w as usize] * s;
1050 scratch[adj_dot_a as usize] += scratch[wd as usize] * s;
1051 }
1052 HOp::RevSqrt {
1053 adj_a,
1054 adj_dot_a,
1055 w,
1056 wd,
1057 va: _,
1058 vd,
1059 dot_a,
1060 } => {
1061 let sv = scratch[vd as usize];
1062 if sv > 0.0 {
1063 let fp = 0.5 / sv;
1064 let fpp = -0.25 / (sv * sv * sv);
1065 let w_v = scratch[w as usize];
1066 let wd_v = scratch[wd as usize];
1067 scratch[adj_a as usize] += w_v * fp;
1068 scratch[adj_dot_a as usize] +=
1069 wd_v * fp + w_v * fpp * scratch[dot_a as usize];
1070 }
1071 }
1072 HOp::RevExp {
1073 adj_a,
1074 adj_dot_a,
1075 w,
1076 wd,
1077 vd,
1078 dot_a,
1079 } => {
1080 let ev = scratch[vd as usize];
1081 let w_v = scratch[w as usize];
1082 let wd_v = scratch[wd as usize];
1083 scratch[adj_a as usize] += w_v * ev;
1084 scratch[adj_dot_a as usize] += wd_v * ev + w_v * ev * scratch[dot_a as usize];
1085 }
1086 HOp::RevLog {
1087 adj_a,
1088 adj_dot_a,
1089 w,
1090 wd,
1091 va,
1092 dot_a,
1093 } => {
1094 let u = scratch[va as usize];
1095 let w_v = scratch[w as usize];
1096 let wd_v = scratch[wd as usize];
1097 scratch[adj_a as usize] += w_v / u;
1098 scratch[adj_dot_a as usize] +=
1099 wd_v / u + w_v * (-1.0 / (u * u)) * scratch[dot_a as usize];
1100 }
1101 HOp::RevLog10 {
1102 adj_a,
1103 adj_dot_a,
1104 w,
1105 wd,
1106 va,
1107 dot_a,
1108 } => {
1109 let u = scratch[va as usize];
1110 let c = std::f64::consts::LN_10;
1111 let w_v = scratch[w as usize];
1112 let wd_v = scratch[wd as usize];
1113 scratch[adj_a as usize] += w_v / (u * c);
1114 scratch[adj_dot_a as usize] +=
1115 wd_v / (u * c) + w_v * (-1.0 / (u * u * c)) * scratch[dot_a as usize];
1116 }
1117 HOp::RevSin {
1118 adj_a,
1119 adj_dot_a,
1120 w,
1121 wd,
1122 va,
1123 dot_a,
1124 } => {
1125 let u = scratch[va as usize];
1126 let cu = u.cos();
1127 let w_v = scratch[w as usize];
1128 let wd_v = scratch[wd as usize];
1129 scratch[adj_a as usize] += w_v * cu;
1130 scratch[adj_dot_a as usize] +=
1131 wd_v * cu + w_v * (-u.sin()) * scratch[dot_a as usize];
1132 }
1133 HOp::RevCos {
1134 adj_a,
1135 adj_dot_a,
1136 w,
1137 wd,
1138 va,
1139 dot_a,
1140 } => {
1141 let u = scratch[va as usize];
1142 let su = u.sin();
1143 let w_v = scratch[w as usize];
1144 let wd_v = scratch[wd as usize];
1145 scratch[adj_a as usize] -= w_v * su;
1146 scratch[adj_dot_a as usize] +=
1147 wd_v * (-su) + w_v * (-u.cos()) * scratch[dot_a as usize];
1148 }
1149
1150 HOp::HessEmit {
1151 hess_ptr,
1152 adj_dot_slot,
1153 } => {
1154 values[hess_ptr as usize] += weight * scratch[adj_dot_slot as usize];
1155 }
1156 }
1157 }
1158 }
1159}
1160
1161fn reachable_to_output(tape: &Tape) -> Vec<bool> {
1165 let n = tape.ops.len();
1166 let mut r = vec![false; n];
1167 if n == 0 {
1168 return r;
1169 }
1170 r[n - 1] = true;
1171 for i in (0..n).rev() {
1172 if !r[i] {
1173 continue;
1174 }
1175 match tape.ops[i] {
1176 TapeOp::Const(_) | TapeOp::Var(_) => {}
1177 TapeOp::Add(a, b)
1178 | TapeOp::Sub(a, b)
1179 | TapeOp::Mul(a, b)
1180 | TapeOp::Div(a, b)
1181 | TapeOp::Pow(a, b) => {
1182 r[a] = true;
1183 r[b] = true;
1184 }
1185 TapeOp::Neg(a)
1186 | TapeOp::Abs(a)
1187 | TapeOp::Sqrt(a)
1188 | TapeOp::Exp(a)
1189 | TapeOp::Log(a)
1190 | TapeOp::Log10(a)
1191 | TapeOp::Sin(a)
1192 | TapeOp::Cos(a) => {
1193 r[a] = true;
1194 }
1195 TapeOp::Funcall { .. } => panic!(
1196 "HessianProgram path does not support AMPL external functions; \
1197 use the Tape (build_with_externals) path instead."
1198 ),
1199 }
1200 }
1201 r
1202}
1203
1204fn depends_on_var(tape: &Tape, j: usize) -> Vec<bool> {
1209 let n = tape.ops.len();
1210 let mut d = vec![false; n];
1211 for (i, op) in tape.ops.iter().enumerate() {
1212 d[i] = match *op {
1213 TapeOp::Const(_) => false,
1214 TapeOp::Var(k) => k == j,
1215 TapeOp::Add(a, b)
1216 | TapeOp::Sub(a, b)
1217 | TapeOp::Mul(a, b)
1218 | TapeOp::Div(a, b)
1219 | TapeOp::Pow(a, b) => d[a] || d[b],
1220 TapeOp::Neg(a)
1221 | TapeOp::Abs(a)
1222 | TapeOp::Sqrt(a)
1223 | TapeOp::Exp(a)
1224 | TapeOp::Log(a)
1225 | TapeOp::Log10(a)
1226 | TapeOp::Sin(a)
1227 | TapeOp::Cos(a) => d[a],
1228 TapeOp::Funcall { .. } => panic!(
1229 "HessianProgram path does not support AMPL external functions; \
1230 use the Tape (build_with_externals) path instead."
1231 ),
1232 };
1233 }
1234 d
1235}
1236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nl_reader::{BinOp, Expr, UnaryOp};
    use std::collections::BTreeSet;
    use std::rc::Rc;

    // ---- Expression-building shorthands ----

    /// Boxes both operands into a binary expression node.
    fn binary(op: BinOp, lhs: Expr, rhs: Expr) -> Expr {
        Expr::Binary(op, Box::new(lhs), Box::new(rhs))
    }
    fn cnst(c: f64) -> Expr {
        Expr::Const(c)
    }
    fn var(i: usize) -> Expr {
        Expr::Var(i)
    }
    fn add(a: Expr, b: Expr) -> Expr {
        binary(BinOp::Add, a, b)
    }
    fn mul(a: Expr, b: Expr) -> Expr {
        binary(BinOp::Mul, a, b)
    }
    fn pow(a: Expr, b: Expr) -> Expr {
        binary(BinOp::Pow, a, b)
    }
    fn div(a: Expr, b: Expr) -> Expr {
        binary(BinOp::Div, a, b)
    }
    fn sub(a: Expr, b: Expr) -> Expr {
        binary(BinOp::Sub, a, b)
    }
    fn unary(op: UnaryOp, a: Expr) -> Expr {
        Expr::Unary(op, Box::new(a))
    }

    /// Builds a dense lower-triangular Hessian index over the tape's
    /// variables: returns the `(row, col) -> position` map together with the
    /// pairs in position order.
    fn build_hess_map(tape: &Tape) -> (HashMap<(usize, usize), usize>, Vec<(usize, usize)>) {
        let vars = tape.variables();
        let mut index: HashMap<(usize, usize), usize> = HashMap::new();
        let mut order: Vec<(usize, usize)> = Vec::new();
        for (pos, &vi) in vars.iter().enumerate() {
            for &vj in vars.iter().take(pos + 1) {
                // Normalise so that row >= col.
                let key = (vi.max(vj), vi.min(vj));
                if !index.contains_key(&key) {
                    index.insert(key, order.len());
                    order.push(key);
                }
            }
        }
        (index, order)
    }

    /// Checks that the compiled program and `Tape::hessian_accumulate`
    /// produce the same weighted Hessian entries at `x`.
    fn assert_program_matches_tape(tape: &Tape, x: &[f64], weight: f64) {
        let (hess_map, pairs) = build_hess_map(tape);
        let nnz = pairs.len();

        // Reference result straight from the tape.
        let mut tape_vals = vec![0.0; nnz];
        tape.hessian_accumulate(x, weight, &hess_map, &mut tape_vals);

        // Result from the compiled program.
        let program = HessianProgram::compile(tape, &hess_map);
        let mut scratch = vec![0.0; program.n_slots()];
        let mut prog_vals = vec![0.0; nnz];
        program.execute(x, weight, &mut scratch, &mut prog_vals);

        let compared = pairs.iter().zip(tape_vals.iter().zip(prog_vals.iter()));
        for (&(r, c), (&expected, &actual)) in compared {
            let tol = expected.abs().max(1.0) * 1e-12;
            assert!(
                (expected - actual).abs() < tol,
                "H[{},{}]: tape={:.6e} prog={:.6e}",
                r,
                c,
                expected,
                actual
            );
        }
    }

    #[test]
    fn matches_quadratic() {
        // 3*x0^2 + 2*x0*x1 + x1^2
        let x0_term = mul(cnst(3.0), pow(var(0), cnst(2.0)));
        let cross_term = mul(cnst(2.0), mul(var(0), var(1)));
        let x1_term = pow(var(1), cnst(2.0));
        let tape = Tape::build(&add(add(x0_term, cross_term), x1_term));
        assert_program_matches_tape(&tape, &[2.0, 3.0], 1.0);
        assert_program_matches_tape(&tape, &[-1.5, 0.7], 2.5);
    }

    #[test]
    fn matches_transcendental() {
        let terms = vec![
            unary(UnaryOp::Exp, var(0)),
            unary(UnaryOp::Sin, var(1)),
            unary(UnaryOp::Log, var(0)),
            unary(UnaryOp::Sqrt, var(1)),
            mul(var(0), var(1)),
            unary(UnaryOp::Cos, add(var(0), var(1))),
        ];
        let tape = Tape::build(&Expr::Sum(terms));
        assert_program_matches_tape(&tape, &[1.5, 2.0], 1.0);
        assert_program_matches_tape(&tape, &[0.3, 4.1], -0.4);
    }

    #[test]
    fn matches_division() {
        let quotient = div(var(0), var(1));
        let tape = Tape::build(&add(quotient, unary(UnaryOp::Cos, var(0))));
        assert_program_matches_tape(&tape, &[0.5, 1.2], 1.0);
    }

    #[test]
    fn matches_through_cse() {
        // (x0 + x1)^2 + (x0 + x1), with the sum shared as a CSE node.
        let shared = Rc::new(add(var(0), var(1)));
        let squared = pow(Expr::Cse(Rc::clone(&shared)), cnst(2.0));
        let tape = Tape::build(&add(squared, Expr::Cse(shared)));
        assert_program_matches_tape(&tape, &[1.0, 2.0], 1.0);
        assert_program_matches_tape(&tape, &[-0.5, 3.3], 0.7);
    }

    #[test]
    fn matches_pow_chain() {
        // x0^3 + x1^(-2)
        let cubic = pow(var(0), cnst(3.0));
        let inverse_square = pow(var(1), cnst(-2.0));
        let tape = Tape::build(&add(cubic, inverse_square));
        assert_program_matches_tape(&tape, &[1.7, 0.8], 1.0);
    }

    #[test]
    fn matches_residual_pow_with_var_exponent() {
        // x0 ^ x1: both base and exponent are variables.
        let tape = Tape::build(&pow(var(0), var(1)));
        assert_program_matches_tape(&tape, &[2.5, 1.4], 1.0);
        assert_program_matches_tape(&tape, &[0.6, 2.1], -1.0);
    }

    #[test]
    fn matches_sub_neg_abs() {
        // -x0 - |x1 - x0|
        let negated = unary(UnaryOp::Neg, var(0));
        let distance = unary(UnaryOp::Abs, sub(var(1), var(0)));
        let tape = Tape::build(&sub(negated, distance));
        assert_program_matches_tape(&tape, &[1.0, -2.0], 1.0);
        assert_program_matches_tape(&tape, &[-3.5, 4.0], 0.9);
    }

    #[test]
    fn slots_layout_matches_design() {
        let tape = Tape::build(&mul(var(0), var(1)));
        let (hess_map, _) = build_hess_map(&tape);
        let compiled = HessianProgram::compile(&tape, &hess_map);
        // Four scratch regions, one slot per tape op in each.
        assert_eq!(compiled.n_slots(), 4 * tape.ops.len());
    }

    #[test]
    fn dependence_matches_hessian_sparsity_for_simple_case() {
        // sin(x0) + x1*x2
        let tape = Tape::build(&add(unary(UnaryOp::Sin, var(0)), mul(var(1), var(2))));
        let sparsity: BTreeSet<(usize, usize)> = tape.hessian_sparsity();
        assert!(sparsity.contains(&(0, 0)));
        assert!(sparsity.contains(&(2, 1)));
        assert_program_matches_tape(&tape, &[0.7, 1.1, 2.2], 1.0);
    }
}