1use std::collections::HashMap;
42
43use super::nl_tape::{Tape, TapeOp};
44
/// A single instruction of a compiled [`HessianProgram`].
///
/// Unless noted otherwise, every `u32` operand is an index into the scratch
/// buffer passed to `HessianProgram::execute` (written `s` below). The
/// exceptions are `x_idx` (index into the point `x`), `c_idx` (index into the
/// program's constant pool) and `hess_ptr` (index into the output `values`).
#[derive(Debug, Clone, Copy)]
pub enum HOp {
    // ---- Forward value sweep ----
    /// `s[dst] = x[x_idx]`.
    FwdLoadVar { dst: u32, x_idx: u32 },
    /// `s[dst] = consts[c_idx]`.
    FwdLoadConst { dst: u32, c_idx: u32 },
    /// `s[dst] = s[a] + s[b]`.
    FwdAdd { dst: u32, a: u32, b: u32 },
    /// `s[dst] = s[a] - s[b]`.
    FwdSub { dst: u32, a: u32, b: u32 },
    /// `s[dst] = s[a] * s[b]`.
    FwdMul { dst: u32, a: u32, b: u32 },
    /// `s[dst] = s[a] / s[b]`.
    FwdDiv { dst: u32, a: u32, b: u32 },
    /// `s[dst] = s[a].powf(s[b])`.
    FwdPow { dst: u32, a: u32, b: u32 },
    /// `s[dst] = -s[a]`.
    FwdNeg { dst: u32, a: u32 },
    /// `s[dst] = |s[a]|`.
    FwdAbs { dst: u32, a: u32 },
    /// `s[dst] = sqrt(s[a])`.
    FwdSqrt { dst: u32, a: u32 },
    /// `s[dst] = exp(s[a])`.
    FwdExp { dst: u32, a: u32 },
    /// `s[dst] = ln(s[a])`.
    FwdLog { dst: u32, a: u32 },
    /// `s[dst] = log10(s[a])`.
    FwdLog10 { dst: u32, a: u32 },
    /// `s[dst] = sin(s[a])`.
    FwdSin { dst: u32, a: u32 },
    /// `s[dst] = cos(s[a])`.
    FwdCos { dst: u32, a: u32 },

    // ---- Slot initialisation ----
    /// `s[dst] = 0.0`.
    SetZero { dst: u32 },
    /// `s[dst] = 1.0`.
    SetOne { dst: u32 },
    /// `s[start..start + len] = 0.0`.
    ZeroRange { start: u32, len: u32 },

    // ---- Forward tangent sweep ----
    // `dot_*` slots hold operand tangents, `va`/`vb` operand values and `vd`
    // the value of the node being differentiated.
    /// `s[dst] = s[a] + s[b]` (tangent of a sum).
    DotAdd { dst: u32, a: u32, b: u32 },
    /// `s[dst] = s[a] - s[b]` (tangent of a difference).
    DotSub { dst: u32, a: u32, b: u32 },
    /// `s[dst] = s[dot_a] * s[vb] + s[va] * s[dot_b]`.
    DotMul { dst: u32, dot_a: u32, vb: u32, va: u32, dot_b: u32 },
    /// `s[dst] = (s[dot_a] * s[vb] - s[va] * s[dot_b]) / s[vb]^2`.
    DotDiv { dst: u32, dot_a: u32, vb: u32, va: u32, dot_b: u32 },
    /// `s[dst] = s[dot_a] * 0.5 / s[vd]`, or `0.0` when `s[vd] <= 0`.
    DotSqrt { dst: u32, dot_a: u32, vd: u32 },
    /// `s[dst] = s[dot_a] * s[vd]`.
    DotExp { dst: u32, dot_a: u32, vd: u32 },
    /// `s[dst] = s[dot_a] / s[va]`.
    DotLog { dst: u32, dot_a: u32, va: u32 },
    /// `s[dst] = s[dot_a] / (s[va] * ln 10)`.
    DotLog10 { dst: u32, dot_a: u32, va: u32 },
    /// `s[dst] = s[dot_a] * cos(s[va])`.
    DotSin { dst: u32, dot_a: u32, va: u32 },
    /// `s[dst] = -s[dot_a] * sin(s[va])`.
    DotCos { dst: u32, dot_a: u32, va: u32 },
    /// `s[dst] = -s[dot_a]`.
    DotNeg { dst: u32, dot_a: u32 },
    /// `s[dst] = s[dot_a]` if `s[va] >= 0`, else `-s[dot_a]`.
    DotAbs { dst: u32, dot_a: u32, va: u32 },
    /// Tangent of `u^r` with `u = s[va]`, `r = s[vb]` and `u^r = s[vd]`.
    DotPow { dst: u32, va: u32, vb: u32, vd: u32, dot_a: u32, dot_b: u32 },

    // ---- Reverse (second-order adjoint) sweep ----
    // `w`/`wd` are the node's adjoint and adjoint tangent; they are propagated
    // into the operands' `adj_*`/`adj_dot_*` slots. `va`/`vb`/`vd`/`dot_*` are
    // read to form the local first and second derivatives.
    /// Reverse step of `a + b`.
    RevAdd { adj_a: u32, adj_b: u32, adj_dot_a: u32, adj_dot_b: u32, w: u32, wd: u32 },
    /// Reverse step of `a - b`.
    RevSub { adj_a: u32, adj_b: u32, adj_dot_a: u32, adj_dot_b: u32, w: u32, wd: u32 },
    /// Reverse step of `a * b`.
    RevMul {
        adj_a: u32, adj_b: u32, adj_dot_a: u32, adj_dot_b: u32,
        w: u32, wd: u32, va: u32, vb: u32, dot_a: u32, dot_b: u32,
    },
    /// Reverse step of `a / b`.
    RevDiv {
        adj_a: u32, adj_b: u32, adj_dot_a: u32, adj_dot_b: u32,
        w: u32, wd: u32, va: u32, vb: u32, dot_a: u32, dot_b: u32,
    },
    /// Reverse step of `a ^ b`.
    RevPow {
        adj_a: u32, adj_b: u32, adj_dot_a: u32, adj_dot_b: u32,
        w: u32, wd: u32, va: u32, vb: u32, vd: u32, dot_a: u32, dot_b: u32,
    },
    /// Reverse step of `-a`.
    RevNeg { adj_a: u32, adj_dot_a: u32, w: u32, wd: u32 },
    /// Reverse step of `|a|`.
    RevAbs { adj_a: u32, adj_dot_a: u32, w: u32, wd: u32, va: u32 },
    /// Reverse step of `sqrt(a)`.
    RevSqrt { adj_a: u32, adj_dot_a: u32, w: u32, wd: u32, va: u32, vd: u32, dot_a: u32 },
    /// Reverse step of `exp(a)`.
    RevExp { adj_a: u32, adj_dot_a: u32, w: u32, wd: u32, vd: u32, dot_a: u32 },
    /// Reverse step of `ln(a)`.
    RevLog { adj_a: u32, adj_dot_a: u32, w: u32, wd: u32, va: u32, dot_a: u32 },
    /// Reverse step of `log10(a)`.
    RevLog10 { adj_a: u32, adj_dot_a: u32, w: u32, wd: u32, va: u32, dot_a: u32 },
    /// Reverse step of `sin(a)`.
    RevSin { adj_a: u32, adj_dot_a: u32, w: u32, wd: u32, va: u32, dot_a: u32 },
    /// Reverse step of `cos(a)`.
    RevCos { adj_a: u32, adj_dot_a: u32, w: u32, wd: u32, va: u32, dot_a: u32 },

    // ---- Output ----
    /// `values[hess_ptr] += weight * s[adj_dot_slot]`.
    HessEmit { hess_ptr: u32, adj_dot_slot: u32 },
}
337
338#[derive(Debug, Clone)]
341pub struct HessianProgram {
342 ops: Vec<HOp>,
343 consts: Vec<f64>,
344 n_slots: u32,
345}
346
347impl HessianProgram {
348 pub fn compile(tape: &Tape, hess_map: &HashMap<(usize, usize), usize>) -> Self {
352 let n = tape.ops.len() as u32;
353 let v_base = 0u32;
354 let dot_base = n;
355 let adj_base = 2 * n;
356 let adj_dot_base = 3 * n;
357 let n_slots = 4 * n;
358
359 let v_slot = |i: u32| v_base + i;
360 let dot_slot = |i: u32| dot_base + i;
361 let adj_slot = |i: u32| adj_base + i;
362 let adj_dot_slot = |i: u32| adj_dot_base + i;
363
364 let reachable = reachable_to_output(tape);
365 let var_indices = tape.variables();
366 let depends_on: Vec<Vec<bool>> = (0..var_indices.len())
368 .map(|k_idx| depends_on_var(tape, var_indices[k_idx]))
369 .collect();
370
371 let mut consts: Vec<f64> = Vec::new();
372 let mut const_intern: HashMap<u64, u32> = HashMap::new();
373 let mut intern_const = |c: f64, consts: &mut Vec<f64>| -> u32 {
374 let bits = c.to_bits();
375 if let Some(&idx) = const_intern.get(&bits) {
376 return idx;
377 }
378 let idx = consts.len() as u32;
379 consts.push(c);
380 const_intern.insert(bits, idx);
381 idx
382 };
383
384 let mut ops: Vec<HOp> = Vec::new();
385
386 for (i, tape_op) in tape.ops.iter().enumerate() {
388 let i = i as u32;
389 let dst = v_slot(i);
390 let op = match *tape_op {
391 TapeOp::Const(c) => HOp::FwdLoadConst {
392 dst,
393 c_idx: intern_const(c, &mut consts),
394 },
395 TapeOp::Var(x_idx) => HOp::FwdLoadVar {
396 dst,
397 x_idx: x_idx as u32,
398 },
399 TapeOp::Add(a, b) => HOp::FwdAdd {
400 dst,
401 a: v_slot(a as u32),
402 b: v_slot(b as u32),
403 },
404 TapeOp::Sub(a, b) => HOp::FwdSub {
405 dst,
406 a: v_slot(a as u32),
407 b: v_slot(b as u32),
408 },
409 TapeOp::Mul(a, b) => HOp::FwdMul {
410 dst,
411 a: v_slot(a as u32),
412 b: v_slot(b as u32),
413 },
414 TapeOp::Div(a, b) => HOp::FwdDiv {
415 dst,
416 a: v_slot(a as u32),
417 b: v_slot(b as u32),
418 },
419 TapeOp::Pow(a, b) => HOp::FwdPow {
420 dst,
421 a: v_slot(a as u32),
422 b: v_slot(b as u32),
423 },
424 TapeOp::Neg(a) => HOp::FwdNeg {
425 dst,
426 a: v_slot(a as u32),
427 },
428 TapeOp::Abs(a) => HOp::FwdAbs {
429 dst,
430 a: v_slot(a as u32),
431 },
432 TapeOp::Sqrt(a) => HOp::FwdSqrt {
433 dst,
434 a: v_slot(a as u32),
435 },
436 TapeOp::Exp(a) => HOp::FwdExp {
437 dst,
438 a: v_slot(a as u32),
439 },
440 TapeOp::Log(a) => HOp::FwdLog {
441 dst,
442 a: v_slot(a as u32),
443 },
444 TapeOp::Log10(a) => HOp::FwdLog10 {
445 dst,
446 a: v_slot(a as u32),
447 },
448 TapeOp::Sin(a) => HOp::FwdSin {
449 dst,
450 a: v_slot(a as u32),
451 },
452 TapeOp::Cos(a) => HOp::FwdCos {
453 dst,
454 a: v_slot(a as u32),
455 },
456 TapeOp::Funcall(_) => panic!(
457 "HessianProgram path does not support AMPL external functions; \
458 use the Tape (build_with_externals) path instead."
459 ),
460 TapeOp::Tan(_)
461 | TapeOp::Atan(_)
462 | TapeOp::Acos(_)
463 | TapeOp::Sinh(_)
464 | TapeOp::Cosh(_)
465 | TapeOp::Tanh(_)
466 | TapeOp::Asin(_)
467 | TapeOp::Acosh(_)
468 | TapeOp::Asinh(_)
469 | TapeOp::Atanh(_)
470 | TapeOp::Atan2(_, _)
471 | TapeOp::Cmp(_, _, _)
472 | TapeOp::And(_, _)
473 | TapeOp::Or(_, _)
474 | TapeOp::Not(_)
475 | TapeOp::Select(_, _, _)
476 | TapeOp::Min(_, _)
477 | TapeOp::Max(_, _) => panic!(
478 "HessianProgram path does not yet support tan/atan/acos, the \
479 other transcendental opcodes, atan2, min/max, or \
480 conditional / logical opcodes; use the Tape \
481 (build_with_externals) interpreter path instead."
482 ),
483 };
484 ops.push(op);
485 }
486
487 if n == 0 || var_indices.is_empty() {
488 return HessianProgram {
489 ops,
490 consts,
491 n_slots,
492 };
493 }
494
495 for (k_idx, &j) in var_indices.iter().enumerate() {
497 ops.push(HOp::ZeroRange {
499 start: dot_base,
500 len: 3 * n,
501 });
502 ops.push(HOp::SetOne {
503 dst: adj_slot(n - 1),
504 });
505
506 for (i, tape_op) in tape.ops.iter().enumerate() {
510 let i_u = i as u32;
511 if !depends_on[k_idx][i] {
512 continue;
513 }
514 let dst = dot_slot(i_u);
515 let dot_op = match *tape_op {
516 TapeOp::Const(_) => continue,
519 TapeOp::Var(_) => HOp::SetOne { dst },
523 TapeOp::Add(a, b) => HOp::DotAdd {
524 dst,
525 a: dot_slot(a as u32),
526 b: dot_slot(b as u32),
527 },
528 TapeOp::Sub(a, b) => HOp::DotSub {
529 dst,
530 a: dot_slot(a as u32),
531 b: dot_slot(b as u32),
532 },
533 TapeOp::Mul(a, b) => HOp::DotMul {
534 dst,
535 dot_a: dot_slot(a as u32),
536 vb: v_slot(b as u32),
537 va: v_slot(a as u32),
538 dot_b: dot_slot(b as u32),
539 },
540 TapeOp::Div(a, b) => HOp::DotDiv {
541 dst,
542 dot_a: dot_slot(a as u32),
543 vb: v_slot(b as u32),
544 va: v_slot(a as u32),
545 dot_b: dot_slot(b as u32),
546 },
547 TapeOp::Pow(a, b) => HOp::DotPow {
548 dst,
549 va: v_slot(a as u32),
550 vb: v_slot(b as u32),
551 vd: v_slot(i_u),
552 dot_a: dot_slot(a as u32),
553 dot_b: dot_slot(b as u32),
554 },
555 TapeOp::Neg(a) => HOp::DotNeg {
556 dst,
557 dot_a: dot_slot(a as u32),
558 },
559 TapeOp::Abs(a) => HOp::DotAbs {
560 dst,
561 dot_a: dot_slot(a as u32),
562 va: v_slot(a as u32),
563 },
564 TapeOp::Sqrt(a) => HOp::DotSqrt {
565 dst,
566 dot_a: dot_slot(a as u32),
567 vd: v_slot(i_u),
568 },
569 TapeOp::Exp(a) => HOp::DotExp {
570 dst,
571 dot_a: dot_slot(a as u32),
572 vd: v_slot(i_u),
573 },
574 TapeOp::Log(a) => HOp::DotLog {
575 dst,
576 dot_a: dot_slot(a as u32),
577 va: v_slot(a as u32),
578 },
579 TapeOp::Log10(a) => HOp::DotLog10 {
580 dst,
581 dot_a: dot_slot(a as u32),
582 va: v_slot(a as u32),
583 },
584 TapeOp::Sin(a) => HOp::DotSin {
585 dst,
586 dot_a: dot_slot(a as u32),
587 va: v_slot(a as u32),
588 },
589 TapeOp::Cos(a) => HOp::DotCos {
590 dst,
591 dot_a: dot_slot(a as u32),
592 va: v_slot(a as u32),
593 },
594 TapeOp::Funcall(_) => panic!(
595 "HessianProgram path does not support AMPL external functions; \
596 use the Tape (build_with_externals) path instead."
597 ),
598 TapeOp::Tan(_)
599 | TapeOp::Atan(_)
600 | TapeOp::Acos(_)
601 | TapeOp::Sinh(_)
602 | TapeOp::Cosh(_)
603 | TapeOp::Tanh(_)
604 | TapeOp::Asin(_)
605 | TapeOp::Acosh(_)
606 | TapeOp::Asinh(_)
607 | TapeOp::Atanh(_)
608 | TapeOp::Atan2(_, _)
609 | TapeOp::Cmp(_, _, _)
610 | TapeOp::And(_, _)
611 | TapeOp::Or(_, _)
612 | TapeOp::Not(_)
613 | TapeOp::Select(_, _, _)
614 | TapeOp::Min(_, _)
615 | TapeOp::Max(_, _) => panic!(
616 "HessianProgram path does not yet support tan/atan/acos, the \
617 other transcendental opcodes, atan2, min/max, or \
618 conditional / logical opcodes; use the Tape \
619 (build_with_externals) interpreter path instead."
620 ),
621 };
622 ops.push(dot_op);
623 }
624
625 for i in (0..n as usize).rev() {
628 if !reachable[i] {
629 continue;
630 }
631 let i_u = i as u32;
632 let w = adj_slot(i_u);
633 let wd = adj_dot_slot(i_u);
634 let tape_op = &tape.ops[i];
635 let rev_op = match *tape_op {
636 TapeOp::Const(_) => continue,
637 TapeOp::Var(k) => {
638 if k >= j {
642 if let Some(&ptr) = hess_map.get(&(k, j)) {
643 ops.push(HOp::HessEmit {
644 hess_ptr: ptr as u32,
645 adj_dot_slot: wd,
646 });
647 }
648 }
649 continue;
650 }
651 TapeOp::Add(a, b) => HOp::RevAdd {
652 adj_a: adj_slot(a as u32),
653 adj_b: adj_slot(b as u32),
654 adj_dot_a: adj_dot_slot(a as u32),
655 adj_dot_b: adj_dot_slot(b as u32),
656 w,
657 wd,
658 },
659 TapeOp::Sub(a, b) => HOp::RevSub {
660 adj_a: adj_slot(a as u32),
661 adj_b: adj_slot(b as u32),
662 adj_dot_a: adj_dot_slot(a as u32),
663 adj_dot_b: adj_dot_slot(b as u32),
664 w,
665 wd,
666 },
667 TapeOp::Mul(a, b) => HOp::RevMul {
668 adj_a: adj_slot(a as u32),
669 adj_b: adj_slot(b as u32),
670 adj_dot_a: adj_dot_slot(a as u32),
671 adj_dot_b: adj_dot_slot(b as u32),
672 w,
673 wd,
674 va: v_slot(a as u32),
675 vb: v_slot(b as u32),
676 dot_a: dot_slot(a as u32),
677 dot_b: dot_slot(b as u32),
678 },
679 TapeOp::Div(a, b) => HOp::RevDiv {
680 adj_a: adj_slot(a as u32),
681 adj_b: adj_slot(b as u32),
682 adj_dot_a: adj_dot_slot(a as u32),
683 adj_dot_b: adj_dot_slot(b as u32),
684 w,
685 wd,
686 va: v_slot(a as u32),
687 vb: v_slot(b as u32),
688 dot_a: dot_slot(a as u32),
689 dot_b: dot_slot(b as u32),
690 },
691 TapeOp::Pow(a, b) => HOp::RevPow {
692 adj_a: adj_slot(a as u32),
693 adj_b: adj_slot(b as u32),
694 adj_dot_a: adj_dot_slot(a as u32),
695 adj_dot_b: adj_dot_slot(b as u32),
696 w,
697 wd,
698 va: v_slot(a as u32),
699 vb: v_slot(b as u32),
700 vd: v_slot(i_u),
701 dot_a: dot_slot(a as u32),
702 dot_b: dot_slot(b as u32),
703 },
704 TapeOp::Neg(a) => HOp::RevNeg {
705 adj_a: adj_slot(a as u32),
706 adj_dot_a: adj_dot_slot(a as u32),
707 w,
708 wd,
709 },
710 TapeOp::Abs(a) => HOp::RevAbs {
711 adj_a: adj_slot(a as u32),
712 adj_dot_a: adj_dot_slot(a as u32),
713 w,
714 wd,
715 va: v_slot(a as u32),
716 },
717 TapeOp::Sqrt(a) => HOp::RevSqrt {
718 adj_a: adj_slot(a as u32),
719 adj_dot_a: adj_dot_slot(a as u32),
720 w,
721 wd,
722 va: v_slot(a as u32),
723 vd: v_slot(i_u),
724 dot_a: dot_slot(a as u32),
725 },
726 TapeOp::Exp(a) => HOp::RevExp {
727 adj_a: adj_slot(a as u32),
728 adj_dot_a: adj_dot_slot(a as u32),
729 w,
730 wd,
731 vd: v_slot(i_u),
732 dot_a: dot_slot(a as u32),
733 },
734 TapeOp::Log(a) => HOp::RevLog {
735 adj_a: adj_slot(a as u32),
736 adj_dot_a: adj_dot_slot(a as u32),
737 w,
738 wd,
739 va: v_slot(a as u32),
740 dot_a: dot_slot(a as u32),
741 },
742 TapeOp::Log10(a) => HOp::RevLog10 {
743 adj_a: adj_slot(a as u32),
744 adj_dot_a: adj_dot_slot(a as u32),
745 w,
746 wd,
747 va: v_slot(a as u32),
748 dot_a: dot_slot(a as u32),
749 },
750 TapeOp::Sin(a) => HOp::RevSin {
751 adj_a: adj_slot(a as u32),
752 adj_dot_a: adj_dot_slot(a as u32),
753 w,
754 wd,
755 va: v_slot(a as u32),
756 dot_a: dot_slot(a as u32),
757 },
758 TapeOp::Cos(a) => HOp::RevCos {
759 adj_a: adj_slot(a as u32),
760 adj_dot_a: adj_dot_slot(a as u32),
761 w,
762 wd,
763 va: v_slot(a as u32),
764 dot_a: dot_slot(a as u32),
765 },
766 TapeOp::Funcall(_) => panic!(
767 "HessianProgram path does not support AMPL external functions; \
768 use the Tape (build_with_externals) path instead."
769 ),
770 TapeOp::Tan(_)
771 | TapeOp::Atan(_)
772 | TapeOp::Acos(_)
773 | TapeOp::Sinh(_)
774 | TapeOp::Cosh(_)
775 | TapeOp::Tanh(_)
776 | TapeOp::Asin(_)
777 | TapeOp::Acosh(_)
778 | TapeOp::Asinh(_)
779 | TapeOp::Atanh(_)
780 | TapeOp::Atan2(_, _)
781 | TapeOp::Cmp(_, _, _)
782 | TapeOp::And(_, _)
783 | TapeOp::Or(_, _)
784 | TapeOp::Not(_)
785 | TapeOp::Select(_, _, _)
786 | TapeOp::Min(_, _)
787 | TapeOp::Max(_, _) => panic!(
788 "HessianProgram path does not yet support tan/atan/acos, the \
789 other transcendental opcodes, atan2, min/max, or \
790 conditional / logical opcodes; use the Tape \
791 (build_with_externals) interpreter path instead."
792 ),
793 };
794 ops.push(rev_op);
795 }
796 }
797
798 HessianProgram {
799 ops,
800 consts,
801 n_slots,
802 }
803 }
804
805 pub fn n_slots(&self) -> usize {
806 self.n_slots as usize
807 }
808
809 pub fn n_ops(&self) -> usize {
810 self.ops.len()
811 }
812
813 pub fn execute(&self, x: &[f64], weight: f64, scratch: &mut [f64], values: &mut [f64]) {
819 debug_assert!(scratch.len() >= self.n_slots as usize);
820 if self.ops.is_empty() || weight == 0.0 {
821 return;
822 }
823 let consts = &self.consts[..];
824 for &op in &self.ops {
825 match op {
826 HOp::FwdLoadVar { dst, x_idx } => {
827 scratch[dst as usize] = x[x_idx as usize];
828 }
829 HOp::FwdLoadConst { dst, c_idx } => {
830 scratch[dst as usize] = consts[c_idx as usize];
831 }
832 HOp::FwdAdd { dst, a, b } => {
833 scratch[dst as usize] = scratch[a as usize] + scratch[b as usize];
834 }
835 HOp::FwdSub { dst, a, b } => {
836 scratch[dst as usize] = scratch[a as usize] - scratch[b as usize];
837 }
838 HOp::FwdMul { dst, a, b } => {
839 scratch[dst as usize] = scratch[a as usize] * scratch[b as usize];
840 }
841 HOp::FwdDiv { dst, a, b } => {
842 scratch[dst as usize] = scratch[a as usize] / scratch[b as usize];
843 }
844 HOp::FwdPow { dst, a, b } => {
845 scratch[dst as usize] = scratch[a as usize].powf(scratch[b as usize]);
846 }
847 HOp::FwdNeg { dst, a } => {
848 scratch[dst as usize] = -scratch[a as usize];
849 }
850 HOp::FwdAbs { dst, a } => {
851 scratch[dst as usize] = scratch[a as usize].abs();
852 }
853 HOp::FwdSqrt { dst, a } => {
854 scratch[dst as usize] = scratch[a as usize].sqrt();
855 }
856 HOp::FwdExp { dst, a } => {
857 scratch[dst as usize] = scratch[a as usize].exp();
858 }
859 HOp::FwdLog { dst, a } => {
860 scratch[dst as usize] = scratch[a as usize].ln();
861 }
862 HOp::FwdLog10 { dst, a } => {
863 scratch[dst as usize] = scratch[a as usize].log10();
864 }
865 HOp::FwdSin { dst, a } => {
866 scratch[dst as usize] = scratch[a as usize].sin();
867 }
868 HOp::FwdCos { dst, a } => {
869 scratch[dst as usize] = scratch[a as usize].cos();
870 }
871
872 HOp::SetZero { dst } => {
873 scratch[dst as usize] = 0.0;
874 }
875 HOp::SetOne { dst } => {
876 scratch[dst as usize] = 1.0;
877 }
878 HOp::ZeroRange { start, len } => {
879 let s = start as usize;
880 let e = s + len as usize;
881 scratch[s..e].fill(0.0);
882 }
883
884 HOp::DotAdd { dst, a, b } => {
885 scratch[dst as usize] = scratch[a as usize] + scratch[b as usize];
886 }
887 HOp::DotSub { dst, a, b } => {
888 scratch[dst as usize] = scratch[a as usize] - scratch[b as usize];
889 }
890 HOp::DotMul {
891 dst,
892 dot_a,
893 vb,
894 va,
895 dot_b,
896 } => {
897 scratch[dst as usize] = scratch[dot_a as usize] * scratch[vb as usize]
898 + scratch[va as usize] * scratch[dot_b as usize];
899 }
900 HOp::DotDiv {
901 dst,
902 dot_a,
903 vb,
904 va,
905 dot_b,
906 } => {
907 let v_b = scratch[vb as usize];
908 scratch[dst as usize] = (scratch[dot_a as usize] * v_b
909 - scratch[va as usize] * scratch[dot_b as usize])
910 / (v_b * v_b);
911 }
912 HOp::DotSqrt { dst, dot_a, vd } => {
913 let svd = scratch[vd as usize];
914 scratch[dst as usize] = if svd > 0.0 {
915 scratch[dot_a as usize] * 0.5 / svd
916 } else {
917 0.0
918 };
919 }
920 HOp::DotExp { dst, dot_a, vd } => {
921 scratch[dst as usize] = scratch[dot_a as usize] * scratch[vd as usize];
922 }
923 HOp::DotLog { dst, dot_a, va } => {
924 scratch[dst as usize] = scratch[dot_a as usize] / scratch[va as usize];
925 }
926 HOp::DotLog10 { dst, dot_a, va } => {
927 scratch[dst as usize] =
928 scratch[dot_a as usize] / (scratch[va as usize] * std::f64::consts::LN_10);
929 }
930 HOp::DotSin { dst, dot_a, va } => {
931 scratch[dst as usize] = scratch[dot_a as usize] * scratch[va as usize].cos();
932 }
933 HOp::DotCos { dst, dot_a, va } => {
934 scratch[dst as usize] = -scratch[dot_a as usize] * scratch[va as usize].sin();
935 }
936 HOp::DotNeg { dst, dot_a } => {
937 scratch[dst as usize] = -scratch[dot_a as usize];
938 }
939 HOp::DotAbs { dst, dot_a, va } => {
940 scratch[dst as usize] = if scratch[va as usize] >= 0.0 {
941 scratch[dot_a as usize]
942 } else {
943 -scratch[dot_a as usize]
944 };
945 }
946 HOp::DotPow {
947 dst,
948 va,
949 vb,
950 vd,
951 dot_a,
952 dot_b,
953 } => {
954 let u = scratch[va as usize];
955 let r = scratch[vb as usize];
956 let du = scratch[dot_a as usize];
957 let dr = scratch[dot_b as usize];
958 let mut result = 0.0;
959 if r != 0.0 && u != 0.0 {
960 result += r * u.powf(r - 1.0) * du;
961 }
962 if u > 0.0 {
963 result += scratch[vd as usize] * u.ln() * dr;
964 }
965 scratch[dst as usize] = result;
966 }
967
968 HOp::RevAdd {
969 adj_a,
970 adj_b,
971 adj_dot_a,
972 adj_dot_b,
973 w,
974 wd,
975 } => {
976 let w_v = scratch[w as usize];
977 let wd_v = scratch[wd as usize];
978 scratch[adj_a as usize] += w_v;
979 scratch[adj_b as usize] += w_v;
980 scratch[adj_dot_a as usize] += wd_v;
981 scratch[adj_dot_b as usize] += wd_v;
982 }
983 HOp::RevSub {
984 adj_a,
985 adj_b,
986 adj_dot_a,
987 adj_dot_b,
988 w,
989 wd,
990 } => {
991 let w_v = scratch[w as usize];
992 let wd_v = scratch[wd as usize];
993 scratch[adj_a as usize] += w_v;
994 scratch[adj_b as usize] -= w_v;
995 scratch[adj_dot_a as usize] += wd_v;
996 scratch[adj_dot_b as usize] -= wd_v;
997 }
998 HOp::RevMul {
999 adj_a,
1000 adj_b,
1001 adj_dot_a,
1002 adj_dot_b,
1003 w,
1004 wd,
1005 va,
1006 vb,
1007 dot_a,
1008 dot_b,
1009 } => {
1010 let w_v = scratch[w as usize];
1011 let wd_v = scratch[wd as usize];
1012 let va_v = scratch[va as usize];
1013 let vb_v = scratch[vb as usize];
1014 let da_v = scratch[dot_a as usize];
1015 let db_v = scratch[dot_b as usize];
1016 scratch[adj_a as usize] += w_v * vb_v;
1017 scratch[adj_b as usize] += w_v * va_v;
1018 scratch[adj_dot_a as usize] += wd_v * vb_v + w_v * db_v;
1019 scratch[adj_dot_b as usize] += wd_v * va_v + w_v * da_v;
1020 }
1021 HOp::RevDiv {
1022 adj_a,
1023 adj_b,
1024 adj_dot_a,
1025 adj_dot_b,
1026 w,
1027 wd,
1028 va,
1029 vb,
1030 dot_a,
1031 dot_b,
1032 } => {
1033 let w_v = scratch[w as usize];
1034 let wd_v = scratch[wd as usize];
1035 let va_v = scratch[va as usize];
1036 let vb_v = scratch[vb as usize];
1037 let vb2 = vb_v * vb_v;
1038 let vb3 = vb2 * vb_v;
1039 let da_v = scratch[dot_a as usize];
1040 let db_v = scratch[dot_b as usize];
1041 scratch[adj_a as usize] += w_v / vb_v;
1042 scratch[adj_dot_a as usize] += wd_v / vb_v + w_v * (-db_v / vb2);
1043 scratch[adj_b as usize] += w_v * (-va_v / vb2);
1044 scratch[adj_dot_b as usize] +=
1045 wd_v * (-va_v / vb2) + w_v * (-da_v / vb2 + 2.0 * va_v * db_v / vb3);
1046 }
1047 HOp::RevPow {
1048 adj_a,
1049 adj_b,
1050 adj_dot_a,
1051 adj_dot_b,
1052 w,
1053 wd,
1054 va,
1055 vb,
1056 vd,
1057 dot_a,
1058 dot_b,
1059 } => {
1060 let w_v = scratch[w as usize];
1061 let wd_v = scratch[wd as usize];
1062 let u = scratch[va as usize];
1063 let r = scratch[vb as usize];
1064 let du = scratch[dot_a as usize];
1065 let dr = scratch[dot_b as usize];
1066 if r != 0.0 {
1067 if u != 0.0 {
1068 let p_a = r * u.powf(r - 1.0);
1069 scratch[adj_a as usize] += w_v * p_a;
1070 let mut dp_a = dr * u.powf(r - 1.0);
1071 if u > 0.0 {
1072 dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
1073 } else {
1074 dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
1075 }
1076 scratch[adj_dot_a as usize] += wd_v * p_a + w_v * dp_a;
1077 } else if r >= 2.0 {
1078 let p_a = 0.0;
1079 scratch[adj_a as usize] += w_v * p_a;
1080 let dp_a = if r == 2.0 {
1081 2.0 * du
1082 } else {
1083 r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
1084 };
1085 scratch[adj_dot_a as usize] += wd_v * p_a + w_v * dp_a;
1086 }
1087 }
1088 if u > 0.0 {
1089 let ln_u = u.ln();
1090 let p_b = scratch[vd as usize] * ln_u;
1091 scratch[adj_b as usize] += w_v * p_b;
1092 let dur = scratch[vd as usize] * (r * du / u + dr * ln_u);
1093 let dp_b = dur * ln_u + scratch[vd as usize] * du / u;
1094 scratch[adj_dot_b as usize] += wd_v * p_b + w_v * dp_b;
1095 }
1096 }
1097 HOp::RevNeg {
1098 adj_a,
1099 adj_dot_a,
1100 w,
1101 wd,
1102 } => {
1103 scratch[adj_a as usize] -= scratch[w as usize];
1104 scratch[adj_dot_a as usize] -= scratch[wd as usize];
1105 }
1106 HOp::RevAbs {
1107 adj_a,
1108 adj_dot_a,
1109 w,
1110 wd,
1111 va,
1112 } => {
1113 let s = if scratch[va as usize] >= 0.0 {
1114 1.0
1115 } else {
1116 -1.0
1117 };
1118 scratch[adj_a as usize] += scratch[w as usize] * s;
1119 scratch[adj_dot_a as usize] += scratch[wd as usize] * s;
1120 }
1121 HOp::RevSqrt {
1122 adj_a,
1123 adj_dot_a,
1124 w,
1125 wd,
1126 va: _,
1127 vd,
1128 dot_a,
1129 } => {
1130 let sv = scratch[vd as usize];
1131 if sv > 0.0 {
1132 let fp = 0.5 / sv;
1133 let fpp = -0.25 / (sv * sv * sv);
1134 let w_v = scratch[w as usize];
1135 let wd_v = scratch[wd as usize];
1136 scratch[adj_a as usize] += w_v * fp;
1137 scratch[adj_dot_a as usize] +=
1138 wd_v * fp + w_v * fpp * scratch[dot_a as usize];
1139 }
1140 }
1141 HOp::RevExp {
1142 adj_a,
1143 adj_dot_a,
1144 w,
1145 wd,
1146 vd,
1147 dot_a,
1148 } => {
1149 let ev = scratch[vd as usize];
1150 let w_v = scratch[w as usize];
1151 let wd_v = scratch[wd as usize];
1152 scratch[adj_a as usize] += w_v * ev;
1153 scratch[adj_dot_a as usize] += wd_v * ev + w_v * ev * scratch[dot_a as usize];
1154 }
1155 HOp::RevLog {
1156 adj_a,
1157 adj_dot_a,
1158 w,
1159 wd,
1160 va,
1161 dot_a,
1162 } => {
1163 let u = scratch[va as usize];
1164 let w_v = scratch[w as usize];
1165 let wd_v = scratch[wd as usize];
1166 scratch[adj_a as usize] += w_v / u;
1167 scratch[adj_dot_a as usize] +=
1168 wd_v / u + w_v * (-1.0 / (u * u)) * scratch[dot_a as usize];
1169 }
1170 HOp::RevLog10 {
1171 adj_a,
1172 adj_dot_a,
1173 w,
1174 wd,
1175 va,
1176 dot_a,
1177 } => {
1178 let u = scratch[va as usize];
1179 let c = std::f64::consts::LN_10;
1180 let w_v = scratch[w as usize];
1181 let wd_v = scratch[wd as usize];
1182 scratch[adj_a as usize] += w_v / (u * c);
1183 scratch[adj_dot_a as usize] +=
1184 wd_v / (u * c) + w_v * (-1.0 / (u * u * c)) * scratch[dot_a as usize];
1185 }
1186 HOp::RevSin {
1187 adj_a,
1188 adj_dot_a,
1189 w,
1190 wd,
1191 va,
1192 dot_a,
1193 } => {
1194 let u = scratch[va as usize];
1195 let cu = u.cos();
1196 let w_v = scratch[w as usize];
1197 let wd_v = scratch[wd as usize];
1198 scratch[adj_a as usize] += w_v * cu;
1199 scratch[adj_dot_a as usize] +=
1200 wd_v * cu + w_v * (-u.sin()) * scratch[dot_a as usize];
1201 }
1202 HOp::RevCos {
1203 adj_a,
1204 adj_dot_a,
1205 w,
1206 wd,
1207 va,
1208 dot_a,
1209 } => {
1210 let u = scratch[va as usize];
1211 let su = u.sin();
1212 let w_v = scratch[w as usize];
1213 let wd_v = scratch[wd as usize];
1214 scratch[adj_a as usize] -= w_v * su;
1215 scratch[adj_dot_a as usize] +=
1216 wd_v * (-su) + w_v * (-u.cos()) * scratch[dot_a as usize];
1217 }
1218
1219 HOp::HessEmit {
1220 hess_ptr,
1221 adj_dot_slot,
1222 } => {
1223 values[hess_ptr as usize] += weight * scratch[adj_dot_slot as usize];
1224 }
1225 }
1226 }
1227 }
1228}
1229
1230fn reachable_to_output(tape: &Tape) -> Vec<bool> {
1234 let n = tape.ops.len();
1235 let mut r = vec![false; n];
1236 if n == 0 {
1237 return r;
1238 }
1239 r[n - 1] = true;
1240 for i in (0..n).rev() {
1241 if !r[i] {
1242 continue;
1243 }
1244 match tape.ops[i] {
1245 TapeOp::Const(_) | TapeOp::Var(_) => {}
1246 TapeOp::Add(a, b)
1247 | TapeOp::Sub(a, b)
1248 | TapeOp::Mul(a, b)
1249 | TapeOp::Div(a, b)
1250 | TapeOp::Pow(a, b)
1251 | TapeOp::Atan2(a, b) => {
1252 r[a] = true;
1253 r[b] = true;
1254 }
1255 TapeOp::Neg(a)
1256 | TapeOp::Abs(a)
1257 | TapeOp::Sqrt(a)
1258 | TapeOp::Exp(a)
1259 | TapeOp::Log(a)
1260 | TapeOp::Log10(a)
1261 | TapeOp::Sin(a)
1262 | TapeOp::Cos(a)
1263 | TapeOp::Tan(a)
1264 | TapeOp::Atan(a)
1265 | TapeOp::Acos(a)
1266 | TapeOp::Sinh(a)
1267 | TapeOp::Cosh(a)
1268 | TapeOp::Tanh(a)
1269 | TapeOp::Asin(a)
1270 | TapeOp::Acosh(a)
1271 | TapeOp::Asinh(a)
1272 | TapeOp::Atanh(a) => {
1273 r[a] = true;
1274 }
1275 TapeOp::Funcall(_) => panic!(
1276 "HessianProgram path does not support AMPL external functions; \
1277 use the Tape (build_with_externals) path instead."
1278 ),
1279 TapeOp::Cmp(_, _, _)
1280 | TapeOp::And(_, _)
1281 | TapeOp::Or(_, _)
1282 | TapeOp::Not(_)
1283 | TapeOp::Select(_, _, _)
1284 | TapeOp::Min(_, _)
1285 | TapeOp::Max(_, _) => panic!(
1286 "HessianProgram path does not support conditional / logical / min-max \
1287 opcodes; use the Tape (build_with_externals) path instead."
1288 ),
1289 }
1290 }
1291 r
1292}
1293
1294fn depends_on_var(tape: &Tape, j: usize) -> Vec<bool> {
1299 let n = tape.ops.len();
1300 let mut d = vec![false; n];
1301 for (i, op) in tape.ops.iter().enumerate() {
1302 d[i] = match *op {
1303 TapeOp::Const(_) => false,
1304 TapeOp::Var(k) => k == j,
1305 TapeOp::Add(a, b)
1306 | TapeOp::Sub(a, b)
1307 | TapeOp::Mul(a, b)
1308 | TapeOp::Div(a, b)
1309 | TapeOp::Pow(a, b)
1310 | TapeOp::Atan2(a, b) => d[a] || d[b],
1311 TapeOp::Neg(a)
1312 | TapeOp::Abs(a)
1313 | TapeOp::Sqrt(a)
1314 | TapeOp::Exp(a)
1315 | TapeOp::Log(a)
1316 | TapeOp::Log10(a)
1317 | TapeOp::Sin(a)
1318 | TapeOp::Cos(a)
1319 | TapeOp::Tan(a)
1320 | TapeOp::Atan(a)
1321 | TapeOp::Acos(a)
1322 | TapeOp::Sinh(a)
1323 | TapeOp::Cosh(a)
1324 | TapeOp::Tanh(a)
1325 | TapeOp::Asin(a)
1326 | TapeOp::Acosh(a)
1327 | TapeOp::Asinh(a)
1328 | TapeOp::Atanh(a) => d[a],
1329 TapeOp::Funcall(_) => panic!(
1330 "HessianProgram path does not support AMPL external functions; \
1331 use the Tape (build_with_externals) path instead."
1332 ),
1333 TapeOp::Cmp(_, _, _)
1334 | TapeOp::And(_, _)
1335 | TapeOp::Or(_, _)
1336 | TapeOp::Not(_)
1337 | TapeOp::Select(_, _, _)
1338 | TapeOp::Min(_, _)
1339 | TapeOp::Max(_, _) => panic!(
1340 "HessianProgram path does not support conditional / logical / min-max \
1341 opcodes; use the Tape (build_with_externals) path instead."
1342 ),
1343 };
1344 }
1345 d
1346}
1347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nl_reader::{BinOp, Expr, UnaryOp};
    use std::collections::hash_map::Entry;
    use std::collections::BTreeSet;
    use std::rc::Rc;

    // ---- Expression builders ----

    fn cnst(c: f64) -> Expr {
        Expr::Const(c)
    }

    fn var(i: usize) -> Expr {
        Expr::Var(i)
    }

    fn binary(op: BinOp, lhs: Expr, rhs: Expr) -> Expr {
        Expr::Binary(op, Box::new(lhs), Box::new(rhs))
    }

    fn add(a: Expr, b: Expr) -> Expr {
        binary(BinOp::Add, a, b)
    }

    fn mul(a: Expr, b: Expr) -> Expr {
        binary(BinOp::Mul, a, b)
    }

    fn pow(a: Expr, b: Expr) -> Expr {
        binary(BinOp::Pow, a, b)
    }

    fn div(a: Expr, b: Expr) -> Expr {
        binary(BinOp::Div, a, b)
    }

    fn sub(a: Expr, b: Expr) -> Expr {
        binary(BinOp::Sub, a, b)
    }

    fn unary(op: UnaryOp, a: Expr) -> Expr {
        Expr::Unary(op, Box::new(a))
    }

    // ---- Harness ----

    /// Dense lower-triangular pattern over `tape.variables()`. Returns the
    /// `(row, col) -> index` map (row >= col) and the pairs in index order.
    fn build_hess_map(tape: &Tape) -> (HashMap<(usize, usize), usize>, Vec<(usize, usize)>) {
        let vars = tape.variables();
        let mut map: HashMap<(usize, usize), usize> = HashMap::new();
        let mut pairs: Vec<(usize, usize)> = Vec::new();
        for (ai, &vi) in vars.iter().enumerate() {
            for &vj in &vars[..=ai] {
                let key = (vi.max(vj), vi.min(vj));
                if let Entry::Vacant(slot) = map.entry(key) {
                    slot.insert(pairs.len());
                    pairs.push(key);
                }
            }
        }
        (map, pairs)
    }

    /// Asserts that the compiled program and the tape interpreter produce the
    /// same weighted Hessian at `x`, up to a relative tolerance of 1e-12.
    fn assert_program_matches_tape(tape: &Tape, x: &[f64], weight: f64) {
        let (hess_map, pairs) = build_hess_map(tape);

        let mut expected = vec![0.0_f64; pairs.len()];
        tape.hessian_accumulate(x, weight, &hess_map, &mut expected);

        let program = HessianProgram::compile(tape, &hess_map);
        let mut scratch = vec![0.0_f64; program.n_slots()];
        let mut actual = vec![0.0_f64; pairs.len()];
        program.execute(x, weight, &mut scratch, &mut actual);

        for ((&(r, c), &want), &got) in pairs.iter().zip(&expected).zip(&actual) {
            let tol = want.abs().max(1.0) * 1e-12;
            assert!(
                (want - got).abs() < tol,
                "H[{},{}]: tape={:.6e} prog={:.6e}",
                r,
                c,
                want,
                got
            );
        }
    }

    // ---- Cases ----

    #[test]
    fn matches_quadratic() {
        let cross = mul(cnst(2.0), mul(var(0), var(1)));
        let x0_sq = mul(cnst(3.0), pow(var(0), cnst(2.0)));
        let e = add(add(x0_sq, cross), pow(var(1), cnst(2.0)));
        let tape = Tape::build(&e);
        assert_program_matches_tape(&tape, &[2.0, 3.0], 1.0);
        assert_program_matches_tape(&tape, &[-1.5, 0.7], 2.5);
    }

    #[test]
    fn matches_transcendental() {
        let terms = vec![
            unary(UnaryOp::Exp, var(0)),
            unary(UnaryOp::Sin, var(1)),
            unary(UnaryOp::Log, var(0)),
            unary(UnaryOp::Sqrt, var(1)),
            mul(var(0), var(1)),
            unary(UnaryOp::Cos, add(var(0), var(1))),
        ];
        let tape = Tape::build(&Expr::Sum(terms));
        assert_program_matches_tape(&tape, &[1.5, 2.0], 1.0);
        assert_program_matches_tape(&tape, &[0.3, 4.1], -0.4);
    }

    #[test]
    fn matches_division() {
        let tape = Tape::build(&add(div(var(0), var(1)), unary(UnaryOp::Cos, var(0))));
        assert_program_matches_tape(&tape, &[0.5, 1.2], 1.0);
    }

    #[test]
    fn matches_through_cse() {
        let shared = Rc::new(add(var(0), var(1)));
        let e = add(
            pow(Expr::Cse(Rc::clone(&shared)), cnst(2.0)),
            Expr::Cse(Rc::clone(&shared)),
        );
        let tape = Tape::build(&e);
        assert_program_matches_tape(&tape, &[1.0, 2.0], 1.0);
        assert_program_matches_tape(&tape, &[-0.5, 3.3], 0.7);
    }

    #[test]
    fn matches_pow_chain() {
        let tape = Tape::build(&add(pow(var(0), cnst(3.0)), pow(var(1), cnst(-2.0))));
        assert_program_matches_tape(&tape, &[1.7, 0.8], 1.0);
    }

    #[test]
    fn matches_residual_pow_with_var_exponent() {
        let tape = Tape::build(&pow(var(0), var(1)));
        assert_program_matches_tape(&tape, &[2.5, 1.4], 1.0);
        assert_program_matches_tape(&tape, &[0.6, 2.1], -1.0);
    }

    #[test]
    fn matches_sub_neg_abs() {
        let negated = unary(UnaryOp::Neg, var(0));
        let magnitude = unary(UnaryOp::Abs, sub(var(1), var(0)));
        let tape = Tape::build(&sub(negated, magnitude));
        assert_program_matches_tape(&tape, &[1.0, -2.0], 1.0);
        assert_program_matches_tape(&tape, &[-3.5, 4.0], 0.9);
    }

    #[test]
    fn slots_layout_matches_design() {
        let tape = Tape::build(&mul(var(0), var(1)));
        let (hess_map, _) = build_hess_map(&tape);
        let prog = HessianProgram::compile(&tape, &hess_map);
        assert_eq!(prog.n_slots(), 4 * tape.ops.len());
    }

    #[test]
    fn dependence_matches_hessian_sparsity_for_simple_case() {
        let tape = Tape::build(&add(unary(UnaryOp::Sin, var(0)), mul(var(1), var(2))));
        let pattern: BTreeSet<(usize, usize)> = tape.hessian_sparsity();
        assert!(pattern.contains(&(0, 0)));
        assert!(pattern.contains(&(2, 1)));
        assert_program_matches_tape(&tape, &[0.7, 1.1, 2.2], 1.0);
    }
}