1use std::collections::{BTreeSet, HashMap, HashSet};
26use std::rc::Rc;
27use std::sync::Arc;
28
29use super::nl_external::{ExternalArg, ExternalLibrary, ExternalResolver};
30use super::nl_reader::{BinOp, Expr, FuncallArg, UnaryOp};
31
/// One instruction of a [`Tape`]; the value it produces lives in the slot equal to its
/// position in `Tape::ops`.
///
/// Every `usize` operand is the slot index of an earlier op (`Tape::forward` reads it
/// from the values computed so far), except in `Var`, where the index selects an entry
/// of the input point `x`.
#[derive(Debug, Clone)]
pub enum TapeOp {
    /// Literal constant.
    Const(f64),
    /// Input variable `x[i]`.
    Var(usize),
    /// `a + b`.
    Add(usize, usize),
    /// `a - b`.
    Sub(usize, usize),
    /// `a * b`.
    Mul(usize, usize),
    /// `a / b`.
    Div(usize, usize),
    /// `a.powf(b)`.
    Pow(usize, usize),
    /// `-a`.
    Neg(usize),
    /// `|a|`.
    Abs(usize),
    /// `a.sqrt()`.
    Sqrt(usize),
    /// `a.exp()`.
    Exp(usize),
    /// Natural logarithm `a.ln()`.
    Log(usize),
    /// Base-10 logarithm `a.log10()`.
    Log10(usize),
    /// `a.sin()`.
    Sin(usize),
    /// `a.cos()`.
    Cos(usize),
    /// Call of external function `name`, evaluated through `lib.eval(name, ...)`.
    Funcall {
        /// Library that implements `name`.
        lib: Arc<ExternalLibrary>,
        /// Function name passed to `lib.eval`.
        name: String,
        /// Arguments in call order (tape slots or string literals).
        args: Vec<TapeFuncallArg>,
    },
}
62
/// One argument of a [`TapeOp::Funcall`].
#[derive(Debug, Clone)]
pub enum TapeFuncallArg {
    /// Real argument read from the given tape slot (passed as `ExternalArg::Real`).
    Tape(usize),
    /// String argument passed through unchanged (as `ExternalArg::Str`).
    Str(String),
}
71
72fn funcall_to_ext_args<'a>(args: &'a [TapeFuncallArg], vals: &[f64]) -> Vec<ExternalArg<'a>> {
73 args.iter()
74 .map(|a| match a {
75 TapeFuncallArg::Tape(idx) => ExternalArg::Real(vals[*idx]),
76 TapeFuncallArg::Str(s) => ExternalArg::Str(s.as_str()),
77 })
78 .collect()
79}
80
/// A linearised expression: a list of ops where each op's value occupies the slot at its
/// own index and may only read earlier slots.
///
/// `Tape::eval` and the reverse sweeps treat the last op as the expression's output.
#[derive(Debug, Clone)]
pub struct Tape {
    /// Ops in evaluation order.
    pub ops: Vec<TapeOp>,
}
87
88impl Tape {
89 pub fn build(expr: &Expr) -> Self {
93 Self::build_with_externals(expr, &ExternalResolver::default())
94 }
95
96 pub fn build_with_externals(expr: &Expr, resolver: &ExternalResolver) -> Self {
101 let mut ops = Vec::new();
102 let mut cache: HashMap<*const Expr, usize> = HashMap::new();
103 build_recursive(expr, &mut ops, &mut cache, resolver);
104 Tape { ops }
105 }
106
107 pub fn forward(&self, x: &[f64]) -> Vec<f64> {
110 let mut vals: Vec<f64> = Vec::with_capacity(self.ops.len());
111 for op in &self.ops {
112 let v = match op {
113 TapeOp::Const(c) => *c,
114 TapeOp::Var(i) => x[*i],
115 TapeOp::Add(a, b) => vals[*a] + vals[*b],
116 TapeOp::Sub(a, b) => vals[*a] - vals[*b],
117 TapeOp::Mul(a, b) => vals[*a] * vals[*b],
118 TapeOp::Div(a, b) => vals[*a] / vals[*b],
119 TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
120 TapeOp::Neg(a) => -vals[*a],
121 TapeOp::Abs(a) => vals[*a].abs(),
122 TapeOp::Sqrt(a) => vals[*a].sqrt(),
123 TapeOp::Exp(a) => vals[*a].exp(),
124 TapeOp::Log(a) => vals[*a].ln(),
125 TapeOp::Log10(a) => vals[*a].log10(),
126 TapeOp::Sin(a) => vals[*a].sin(),
127 TapeOp::Cos(a) => vals[*a].cos(),
128 TapeOp::Funcall { lib, name, args } => {
129 let call_args = funcall_to_ext_args(args, &vals);
130 let res = lib
131 .eval(name, &call_args, false, false)
132 .unwrap_or_else(|e| {
133 panic!("external function '{name}' forward eval failed: {e}")
134 });
135 res.value
136 }
137 };
138 vals.push(v);
139 }
140 vals
141 }
142
143 pub fn eval(&self, x: &[f64]) -> f64 {
144 let vals = self.forward(x);
145 *vals.last().unwrap_or(&0.0)
146 }
147
148 pub fn gradient_seed(&self, x: &[f64], seed: f64, grad: &mut [f64]) {
153 if seed == 0.0 || self.ops.is_empty() {
154 return;
155 }
156 let vals = self.forward(x);
157 self.reverse(&vals, seed, grad);
158 }
159
160 fn reverse(&self, vals: &[f64], seed: f64, grad: &mut [f64]) {
161 let n = self.ops.len();
162 let mut adj = vec![0.0f64; n];
163 adj[n - 1] = seed;
164
165 for i in (0..n).rev() {
166 let a = adj[i];
167 if a == 0.0 {
168 continue;
169 }
170 match &self.ops[i] {
171 TapeOp::Const(_) => {}
172 TapeOp::Var(j) => {
173 grad[*j] += a;
174 }
175 TapeOp::Add(l, r) => {
176 adj[*l] += a;
177 adj[*r] += a;
178 }
179 TapeOp::Sub(l, r) => {
180 adj[*l] += a;
181 adj[*r] -= a;
182 }
183 TapeOp::Mul(l, r) => {
184 adj[*l] += a * vals[*r];
185 adj[*r] += a * vals[*l];
186 }
187 TapeOp::Div(l, r) => {
188 let rv = vals[*r];
189 adj[*l] += a / rv;
190 adj[*r] -= a * vals[*l] / (rv * rv);
191 }
192 TapeOp::Pow(l, r) => {
193 let lv = vals[*l];
194 let rv = vals[*r];
195 if rv != 0.0 {
196 adj[*l] += a * rv * lv.powf(rv - 1.0);
197 }
198 if lv > 0.0 {
199 adj[*r] += a * vals[i] * lv.ln();
200 }
201 }
202 TapeOp::Neg(j) => {
203 adj[*j] -= a;
204 }
205 TapeOp::Abs(j) => {
206 if vals[*j] >= 0.0 {
207 adj[*j] += a;
208 } else {
209 adj[*j] -= a;
210 }
211 }
212 TapeOp::Sqrt(j) => {
213 let sv = vals[i];
214 if sv > 0.0 {
215 adj[*j] += a * 0.5 / sv;
216 }
217 }
218 TapeOp::Exp(j) => {
219 adj[*j] += a * vals[i];
220 }
221 TapeOp::Log(j) => {
222 adj[*j] += a / vals[*j];
223 }
224 TapeOp::Log10(j) => {
225 adj[*j] += a / (vals[*j] * std::f64::consts::LN_10);
226 }
227 TapeOp::Sin(j) => {
228 adj[*j] += a * vals[*j].cos();
229 }
230 TapeOp::Cos(j) => {
231 adj[*j] -= a * vals[*j].sin();
232 }
233 TapeOp::Funcall { lib, name, args } => {
234 let call_args = funcall_to_ext_args(args, vals);
235 let res = lib.eval(name, &call_args, true, false).unwrap_or_else(|e| {
236 panic!("external function '{name}' reverse eval failed: {e}")
237 });
238 let derivs = res.derivs.expect("want_derivs=true returns derivs");
239 let mut k = 0usize;
240 for arg in args {
241 if let TapeFuncallArg::Tape(idx) = arg {
242 adj[*idx] += a * derivs[k];
243 k += 1;
244 }
245 }
246 }
247 }
248 }
249 }
250
251 pub fn variables(&self) -> Vec<usize> {
253 let mut s: BTreeSet<usize> = BTreeSet::new();
254 for op in &self.ops {
255 if let TapeOp::Var(j) = op {
256 s.insert(*j);
257 }
258 }
259 s.into_iter().collect()
260 }
261
262 fn forward_tangent(&self, vals: &[f64], seed_var: usize, dot: &mut [f64]) {
267 let n = self.ops.len();
268 debug_assert_eq!(dot.len(), n);
269 for i in 0..n {
270 dot[i] = match &self.ops[i] {
271 TapeOp::Const(_) => 0.0,
272 TapeOp::Var(k) => {
273 if *k == seed_var {
274 1.0
275 } else {
276 0.0
277 }
278 }
279 TapeOp::Add(a, b) => dot[*a] + dot[*b],
280 TapeOp::Sub(a, b) => dot[*a] - dot[*b],
281 TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
282 TapeOp::Div(a, b) => {
283 let vb = vals[*b];
284 (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
285 }
286 TapeOp::Pow(a, b) => {
287 let u = vals[*a];
288 let r = vals[*b];
289 let du = dot[*a];
290 let dr = dot[*b];
291 let mut result = 0.0;
292 if r != 0.0 && u != 0.0 {
293 result += r * u.powf(r - 1.0) * du;
294 }
295 if u > 0.0 {
296 result += vals[i] * u.ln() * dr;
297 }
298 result
299 }
300 TapeOp::Neg(a) => -dot[*a],
301 TapeOp::Abs(a) => {
302 if vals[*a] >= 0.0 {
303 dot[*a]
304 } else {
305 -dot[*a]
306 }
307 }
308 TapeOp::Sqrt(a) => {
309 let sv = vals[i];
310 if sv > 0.0 {
311 dot[*a] * 0.5 / sv
312 } else {
313 0.0
314 }
315 }
316 TapeOp::Exp(a) => dot[*a] * vals[i],
317 TapeOp::Log(a) => dot[*a] / vals[*a],
318 TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
319 TapeOp::Sin(a) => dot[*a] * vals[*a].cos(),
320 TapeOp::Cos(a) => -dot[*a] * vals[*a].sin(),
321 TapeOp::Funcall { lib, name, args } => {
322 let call_args = funcall_to_ext_args(args, vals);
323 let res = lib.eval(name, &call_args, true, false).unwrap_or_else(|e| {
324 panic!("external function '{name}' tangent eval failed: {e}")
325 });
326 let derivs = res.derivs.expect("want_derivs=true returns derivs");
327 let mut acc = 0.0;
328 let mut k = 0usize;
329 for arg in args {
330 if let TapeFuncallArg::Tape(idx) = arg {
331 acc += derivs[k] * dot[*idx];
332 k += 1;
333 }
334 }
335 acc
336 }
337 };
338 }
339 }
340
341 pub fn forward_into(&self, x: &[f64], vals: &mut [f64]) {
345 let n = self.ops.len();
346 debug_assert!(vals.len() >= n);
347 for i in 0..n {
348 vals[i] = match &self.ops[i] {
349 TapeOp::Const(c) => *c,
350 TapeOp::Var(j) => x[*j],
351 TapeOp::Add(a, b) => vals[*a] + vals[*b],
352 TapeOp::Sub(a, b) => vals[*a] - vals[*b],
353 TapeOp::Mul(a, b) => vals[*a] * vals[*b],
354 TapeOp::Div(a, b) => vals[*a] / vals[*b],
355 TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
356 TapeOp::Neg(a) => -vals[*a],
357 TapeOp::Abs(a) => vals[*a].abs(),
358 TapeOp::Sqrt(a) => vals[*a].sqrt(),
359 TapeOp::Exp(a) => vals[*a].exp(),
360 TapeOp::Log(a) => vals[*a].ln(),
361 TapeOp::Log10(a) => vals[*a].log10(),
362 TapeOp::Sin(a) => vals[*a].sin(),
363 TapeOp::Cos(a) => vals[*a].cos(),
364 TapeOp::Funcall { lib, name, args } => {
365 let call_args = funcall_to_ext_args(args, &*vals);
366 let res = lib
367 .eval(name, &call_args, false, false)
368 .unwrap_or_else(|e| {
369 panic!("external function '{name}' forward_into failed: {e}")
370 });
371 res.value
372 }
373 };
374 }
375 }
376
377 pub fn hessian_directional(
394 &self,
395 vals: &[f64],
396 seed: &[f64],
397 weight: f64,
398 out: &mut [f64],
399 dot: &mut [f64],
400 adj: &mut [f64],
401 adj_dot: &mut [f64],
402 ) {
403 let n = self.ops.len();
404 if n == 0 || weight == 0.0 {
405 return;
406 }
407 debug_assert!(vals.len() >= n);
408 debug_assert!(dot.len() >= n);
409 debug_assert!(adj.len() >= n);
410 debug_assert!(adj_dot.len() >= n);
411
412 for i in 0..n {
416 dot[i] = match &self.ops[i] {
417 TapeOp::Const(_) => 0.0,
418 TapeOp::Var(k) => seed[*k],
419 TapeOp::Add(a, b) => dot[*a] + dot[*b],
420 TapeOp::Sub(a, b) => dot[*a] - dot[*b],
421 TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
422 TapeOp::Div(a, b) => {
423 let vb = vals[*b];
424 (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
425 }
426 TapeOp::Pow(a, b) => {
427 let u = vals[*a];
428 let r = vals[*b];
429 let du = dot[*a];
430 let dr = dot[*b];
431 let mut result = 0.0;
432 if r != 0.0 && u != 0.0 {
433 result += r * u.powf(r - 1.0) * du;
434 }
435 if u > 0.0 {
436 result += vals[i] * u.ln() * dr;
437 }
438 result
439 }
440 TapeOp::Neg(a) => -dot[*a],
441 TapeOp::Abs(a) => {
442 if vals[*a] >= 0.0 {
443 dot[*a]
444 } else {
445 -dot[*a]
446 }
447 }
448 TapeOp::Sqrt(a) => {
449 let sv = vals[i];
450 if sv > 0.0 {
451 dot[*a] * 0.5 / sv
452 } else {
453 0.0
454 }
455 }
456 TapeOp::Exp(a) => vals[i] * dot[*a],
457 TapeOp::Log(a) => dot[*a] / vals[*a],
458 TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
459 TapeOp::Sin(a) => vals[*a].cos() * dot[*a],
460 TapeOp::Cos(a) => -vals[*a].sin() * dot[*a],
461 TapeOp::Funcall { lib, name, args } => {
462 let call_args = funcall_to_ext_args(args, vals);
463 let res = lib.eval(name, &call_args, true, false).unwrap_or_else(|e| {
464 panic!("external function '{name}' tangent eval failed: {e}")
465 });
466 let derivs = res.derivs.expect("want_derivs=true returns derivs");
467 let mut acc = 0.0;
468 let mut k = 0usize;
469 for arg in args {
470 if let TapeFuncallArg::Tape(idx) = arg {
471 acc += derivs[k] * dot[*idx];
472 k += 1;
473 }
474 }
475 acc
476 }
477 };
478 }
479
480 for slot in adj.iter_mut().take(n) {
484 *slot = 0.0;
485 }
486 for slot in adj_dot.iter_mut().take(n) {
487 *slot = 0.0;
488 }
489 adj[n - 1] = 1.0;
490
491 for i in (0..n).rev() {
492 let w = adj[i];
493 let wd = adj_dot[i];
494 if w == 0.0 && wd == 0.0 {
495 continue;
496 }
497 match &self.ops[i] {
498 TapeOp::Const(_) => {}
499 TapeOp::Var(k) => {
500 if wd != 0.0 {
501 out[*k] += weight * wd;
502 }
503 }
504 TapeOp::Add(a, b) => {
505 adj[*a] += w;
506 adj[*b] += w;
507 adj_dot[*a] += wd;
508 adj_dot[*b] += wd;
509 }
510 TapeOp::Sub(a, b) => {
511 adj[*a] += w;
512 adj[*b] -= w;
513 adj_dot[*a] += wd;
514 adj_dot[*b] -= wd;
515 }
516 TapeOp::Mul(a, b) => {
517 adj[*a] += w * vals[*b];
518 adj[*b] += w * vals[*a];
519 adj_dot[*a] += wd * vals[*b] + w * dot[*b];
520 adj_dot[*b] += wd * vals[*a] + w * dot[*a];
521 }
522 TapeOp::Div(a, b) => {
523 let vb = vals[*b];
524 let vb2 = vb * vb;
525 let vb3 = vb2 * vb;
526 adj[*a] += w / vb;
527 adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
528 adj[*b] += w * (-vals[*a] / vb2);
529 adj_dot[*b] += wd * (-vals[*a] / vb2)
530 + w * (-dot[*a] / vb2 + 2.0 * vals[*a] * dot[*b] / vb3);
531 }
532 TapeOp::Pow(a, b) => {
533 let u = vals[*a];
534 let r = vals[*b];
535 let du = dot[*a];
536 let dr = dot[*b];
537 if r != 0.0 {
538 if u != 0.0 {
539 let p_a = r * u.powf(r - 1.0);
540 adj[*a] += w * p_a;
541 let mut dp_a = dr * u.powf(r - 1.0);
542 if u > 0.0 {
543 dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
544 } else {
545 dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
546 }
547 adj_dot[*a] += wd * p_a + w * dp_a;
548 } else if r >= 2.0 {
549 let p_a = 0.0;
550 adj[*a] += w * p_a;
551 let dp_a = if r == 2.0 {
552 2.0 * du
553 } else {
554 r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
555 };
556 adj_dot[*a] += wd * p_a + w * dp_a;
557 }
558 }
559 if u > 0.0 {
560 let ln_u = u.ln();
561 let p_b = vals[i] * ln_u;
562 adj[*b] += w * p_b;
563 let dur = vals[i] * (r * du / u + dr * ln_u);
564 let dp_b = dur * ln_u + vals[i] * du / u;
565 adj_dot[*b] += wd * p_b + w * dp_b;
566 }
567 }
568 TapeOp::Neg(a) => {
569 adj[*a] -= w;
570 adj_dot[*a] -= wd;
571 }
572 TapeOp::Abs(a) => {
573 let s = if vals[*a] >= 0.0 { 1.0 } else { -1.0 };
574 adj[*a] += w * s;
575 adj_dot[*a] += wd * s;
576 }
577 TapeOp::Sqrt(a) => {
578 let sv = vals[i];
579 if sv > 0.0 {
580 let fp = 0.5 / sv;
581 let fpp = -0.25 / (vals[*a] * sv);
582 adj[*a] += w * fp;
583 adj_dot[*a] += wd * fp + w * fpp * dot[*a];
584 }
585 }
586 TapeOp::Exp(a) => {
587 let ev = vals[i];
588 adj[*a] += w * ev;
589 adj_dot[*a] += wd * ev + w * ev * dot[*a];
590 }
591 TapeOp::Log(a) => {
592 let u = vals[*a];
593 adj[*a] += w / u;
594 adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
595 }
596 TapeOp::Log10(a) => {
597 let u = vals[*a];
598 let c = std::f64::consts::LN_10;
599 adj[*a] += w / (u * c);
600 adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
601 }
602 TapeOp::Sin(a) => {
603 let u = vals[*a];
604 let cu = u.cos();
605 adj[*a] += w * cu;
606 adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
607 }
608 TapeOp::Cos(a) => {
609 let u = vals[*a];
610 let su = u.sin();
611 adj[*a] -= w * su;
612 adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
613 }
614 TapeOp::Funcall { lib, name, args } => {
615 let call_args = funcall_to_ext_args(args, vals);
616 let res = lib.eval(name, &call_args, true, true).unwrap_or_else(|e| {
617 panic!("external function '{name}' 2nd-order eval failed: {e}")
618 });
619 let derivs = res.derivs.expect("want_derivs=true returns derivs");
620 let hes = res.hessian.expect("want_hes=true returns hessian");
621 let real_tape: Vec<usize> = args
622 .iter()
623 .filter_map(|a| match a {
624 TapeFuncallArg::Tape(t) => Some(*t),
625 TapeFuncallArg::Str(_) => None,
626 })
627 .collect();
628 for (k, &tk) in real_tape.iter().enumerate() {
629 adj[tk] += w * derivs[k];
630 let mut second_term = 0.0;
631 for (l, &tl) in real_tape.iter().enumerate() {
632 let (lo, hi) = if k <= l { (k, l) } else { (l, k) };
633 let h_kl = hes[lo + hi * (hi + 1) / 2];
634 second_term += h_kl * dot[tl];
635 }
636 adj_dot[tk] += wd * derivs[k] + w * second_term;
637 }
638 }
639 }
640 }
641 }
642
643 pub fn hessian_accumulate(
650 &self,
651 x: &[f64],
652 weight: f64,
653 hess_map: &HashMap<(usize, usize), usize>,
654 values: &mut [f64],
655 ) {
656 let n = self.ops.len();
657 if n == 0 || weight == 0.0 {
658 return;
659 }
660 let v = self.forward(x);
661 let var_indices = self.variables();
662
663 let mut dot = vec![0.0f64; n];
670 let mut adj = vec![0.0f64; n];
671 let mut adj_dot = vec![0.0f64; n];
672 for &j in &var_indices {
673 self.forward_tangent(&v, j, &mut dot);
674
675 adj.fill(0.0);
678 adj_dot.fill(0.0);
679 adj[n - 1] = 1.0;
680
681 for i in (0..n).rev() {
682 let w = adj[i];
683 let wd = adj_dot[i];
684 if w == 0.0 && wd == 0.0 {
685 continue;
686 }
687 match &self.ops[i] {
688 TapeOp::Const(_) => {}
689 TapeOp::Var(k) => {
690 if wd != 0.0 && *k >= j {
693 if let Some(&pos) = hess_map.get(&(*k, j)) {
694 values[pos] += weight * wd;
695 }
696 }
697 }
698 TapeOp::Add(a, b) => {
699 adj[*a] += w;
700 adj[*b] += w;
701 adj_dot[*a] += wd;
702 adj_dot[*b] += wd;
703 }
704 TapeOp::Sub(a, b) => {
705 adj[*a] += w;
706 adj[*b] -= w;
707 adj_dot[*a] += wd;
708 adj_dot[*b] -= wd;
709 }
710 TapeOp::Mul(a, b) => {
711 adj[*a] += w * v[*b];
712 adj[*b] += w * v[*a];
713 adj_dot[*a] += wd * v[*b] + w * dot[*b];
714 adj_dot[*b] += wd * v[*a] + w * dot[*a];
715 }
716 TapeOp::Div(a, b) => {
717 let vb = v[*b];
718 let vb2 = vb * vb;
719 let vb3 = vb2 * vb;
720 adj[*a] += w / vb;
721 adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
722 adj[*b] += w * (-v[*a] / vb2);
723 adj_dot[*b] += wd * (-v[*a] / vb2)
724 + w * (-dot[*a] / vb2 + 2.0 * v[*a] * dot[*b] / vb3);
725 }
726 TapeOp::Pow(a, b) => {
727 let u = v[*a];
728 let r = v[*b];
729 let du = dot[*a];
730 let dr = dot[*b];
731 if r != 0.0 {
732 if u != 0.0 {
733 let p_a = r * u.powf(r - 1.0);
734 adj[*a] += w * p_a;
735 let mut dp_a = dr * u.powf(r - 1.0);
736 if u > 0.0 {
737 dp_a +=
738 r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
739 } else {
740 dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
741 }
742 adj_dot[*a] += wd * p_a + w * dp_a;
743 } else if r >= 2.0 {
744 let p_a = 0.0;
745 adj[*a] += w * p_a;
746 let dp_a = if r == 2.0 {
747 2.0 * du
748 } else {
749 r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
750 };
751 adj_dot[*a] += wd * p_a + w * dp_a;
752 }
753 }
754 if u > 0.0 {
755 let ln_u = u.ln();
756 let p_b = v[i] * ln_u;
757 adj[*b] += w * p_b;
758 let dur = v[i] * (r * du / u + dr * ln_u);
759 let dp_b = dur * ln_u + v[i] * du / u;
760 adj_dot[*b] += wd * p_b + w * dp_b;
761 }
762 }
763 TapeOp::Neg(a) => {
764 adj[*a] -= w;
765 adj_dot[*a] -= wd;
766 }
767 TapeOp::Abs(a) => {
768 let s = if v[*a] >= 0.0 { 1.0 } else { -1.0 };
769 adj[*a] += w * s;
770 adj_dot[*a] += wd * s;
771 }
772 TapeOp::Sqrt(a) => {
773 let sv = v[i];
774 if sv > 0.0 {
775 let fp = 0.5 / sv;
776 let fpp = -0.25 / (v[*a] * sv);
777 adj[*a] += w * fp;
778 adj_dot[*a] += wd * fp + w * fpp * dot[*a];
779 }
780 }
781 TapeOp::Exp(a) => {
782 let ev = v[i];
783 adj[*a] += w * ev;
784 adj_dot[*a] += wd * ev + w * ev * dot[*a];
785 }
786 TapeOp::Log(a) => {
787 let u = v[*a];
788 adj[*a] += w / u;
789 adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
790 }
791 TapeOp::Log10(a) => {
792 let u = v[*a];
793 let c = std::f64::consts::LN_10;
794 adj[*a] += w / (u * c);
795 adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
796 }
797 TapeOp::Sin(a) => {
798 let u = v[*a];
799 let cu = u.cos();
800 adj[*a] += w * cu;
801 adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
802 }
803 TapeOp::Cos(a) => {
804 let u = v[*a];
805 let su = u.sin();
806 adj[*a] -= w * su;
807 adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
808 }
809 TapeOp::Funcall { lib, name, args } => {
810 let call_args = funcall_to_ext_args(args, &v);
811 let res = lib.eval(name, &call_args, true, true).unwrap_or_else(|e| {
812 panic!("external function '{name}' 2nd-order eval failed: {e}")
813 });
814 let derivs = res.derivs.expect("want_derivs=true returns derivs");
815 let hes = res.hessian.expect("want_hes=true returns hessian");
816 let real_tape: Vec<usize> = args
817 .iter()
818 .filter_map(|a| match a {
819 TapeFuncallArg::Tape(t) => Some(*t),
820 TapeFuncallArg::Str(_) => None,
821 })
822 .collect();
823 for (k, &tk) in real_tape.iter().enumerate() {
824 adj[tk] += w * derivs[k];
825 let mut second_term = 0.0;
826 for (l, &tl) in real_tape.iter().enumerate() {
827 let (lo, hi) = if k <= l { (k, l) } else { (l, k) };
828 let h_kl = hes[lo + hi * (hi + 1) / 2];
829 second_term += h_kl * dot[tl];
830 }
831 adj_dot[tk] += wd * derivs[k] + w * second_term;
832 }
833 }
834 }
835 }
836 }
837 }
838
839 pub fn hessian_sparsity(&self) -> BTreeSet<(usize, usize)> {
844 let n = self.ops.len();
845 let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(n);
846 let mut pairs: BTreeSet<(usize, usize)> = BTreeSet::new();
847
848 let emit_cross =
849 |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
850 for &v1 in s1 {
851 for &v2 in s2 {
852 let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
853 pairs.insert((r, c));
854 }
855 }
856 };
857 let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
858 let vars: Vec<usize> = s.iter().copied().collect();
859 for (ai, &vi) in vars.iter().enumerate() {
860 for &vj in &vars[..=ai] {
861 let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
862 pairs.insert((r, c));
863 }
864 }
865 };
866
867 for op in &self.ops {
868 let vset = match op {
869 TapeOp::Const(_) => BTreeSet::new(),
870 TapeOp::Var(j) => {
871 let mut s = BTreeSet::new();
872 s.insert(*j);
873 s
874 }
875 TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
876 var_sets[*a].union(&var_sets[*b]).copied().collect()
877 }
878 TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
879 TapeOp::Mul(a, b) => {
880 emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
881 var_sets[*a].union(&var_sets[*b]).copied().collect()
882 }
883 TapeOp::Div(a, b) => {
884 emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
885 emit_self(&var_sets[*b], &mut pairs);
886 var_sets[*a].union(&var_sets[*b]).copied().collect()
887 }
888 TapeOp::Pow(a, b) => {
889 let combined: BTreeSet<usize> =
890 var_sets[*a].union(&var_sets[*b]).copied().collect();
891 emit_self(&combined, &mut pairs);
892 combined
893 }
894 TapeOp::Sqrt(a)
895 | TapeOp::Exp(a)
896 | TapeOp::Log(a)
897 | TapeOp::Log10(a)
898 | TapeOp::Sin(a)
899 | TapeOp::Cos(a) => {
900 emit_self(&var_sets[*a], &mut pairs);
901 var_sets[*a].clone()
902 }
903 TapeOp::Funcall { args, .. } => {
904 let mut combined: BTreeSet<usize> = BTreeSet::new();
905 for arg in args {
906 if let TapeFuncallArg::Tape(t) = arg {
907 for &vv in &var_sets[*t] {
908 combined.insert(vv);
909 }
910 }
911 }
912 emit_self(&combined, &mut pairs);
913 combined
914 }
915 };
916 var_sets.push(vset);
917 }
918 pairs
919 }
920}
921
922fn build_recursive(
923 expr: &Expr,
924 ops: &mut Vec<TapeOp>,
925 cache: &mut HashMap<*const Expr, usize>,
926 resolver: &ExternalResolver,
927) -> usize {
928 match expr {
929 Expr::Const(c) => {
930 let idx = ops.len();
931 ops.push(TapeOp::Const(*c));
932 idx
933 }
934 Expr::Var(i) => {
935 let idx = ops.len();
936 ops.push(TapeOp::Var(*i));
937 idx
938 }
939 Expr::Binary(op, a, b) => {
940 if let BinOp::Pow = op {
948 if let Some(c) = peek_const(b) {
949 if let Some(idx) = try_emit_const_pow(a, c, ops, cache, resolver) {
950 return idx;
951 }
952 }
953 }
954 let l = build_recursive(a, ops, cache, resolver);
955 let r = build_recursive(b, ops, cache, resolver);
956 let idx = ops.len();
957 ops.push(match op {
958 BinOp::Add => TapeOp::Add(l, r),
959 BinOp::Sub => TapeOp::Sub(l, r),
960 BinOp::Mul => TapeOp::Mul(l, r),
961 BinOp::Div => TapeOp::Div(l, r),
962 BinOp::Pow => TapeOp::Pow(l, r),
963 });
964 idx
965 }
966 Expr::Unary(op, a) => {
967 let v = build_recursive(a, ops, cache, resolver);
968 let idx = ops.len();
969 ops.push(match op {
970 UnaryOp::Neg => TapeOp::Neg(v),
971 UnaryOp::Sqrt => TapeOp::Sqrt(v),
972 UnaryOp::Log => TapeOp::Log(v),
973 UnaryOp::Log10 => TapeOp::Log10(v),
974 UnaryOp::Exp => TapeOp::Exp(v),
975 UnaryOp::Abs => TapeOp::Abs(v),
976 UnaryOp::Sin => TapeOp::Sin(v),
977 UnaryOp::Cos => TapeOp::Cos(v),
978 });
979 idx
980 }
981 Expr::Sum(args) => {
982 if args.is_empty() {
983 let idx = ops.len();
984 ops.push(TapeOp::Const(0.0));
985 return idx;
986 }
987 let mut acc = build_recursive(&args[0], ops, cache, resolver);
988 for a in &args[1..] {
989 let next = build_recursive(a, ops, cache, resolver);
990 let idx = ops.len();
991 ops.push(TapeOp::Add(acc, next));
992 acc = idx;
993 }
994 acc
995 }
996 Expr::Cse(body) => {
997 let key = Rc::as_ptr(body) as *const Expr;
1004 if let Some(&idx) = cache.get(&key) {
1005 idx
1006 } else {
1007 let idx = build_recursive(body, ops, cache, resolver);
1008 cache.insert(key, idx);
1009 idx
1010 }
1011 }
1012 Expr::Funcall { id, args } => {
1013 let (lib, name) = resolver
1014 .funcs_by_id
1015 .get(id)
1016 .unwrap_or_else(|| panic!("unresolved AMPL funcall id {id}"));
1017 let tape_args: Vec<TapeFuncallArg> = args
1018 .iter()
1019 .map(|a| match a {
1020 FuncallArg::Real(e) => {
1021 TapeFuncallArg::Tape(build_recursive(e, ops, cache, resolver))
1022 }
1023 FuncallArg::Str(s) => TapeFuncallArg::Str(s.clone()),
1024 })
1025 .collect();
1026 let idx = ops.len();
1027 ops.push(TapeOp::Funcall {
1028 lib: Arc::clone(lib),
1029 name: name.clone(),
1030 args: tape_args,
1031 });
1032 idx
1033 }
1034 }
1035}
1036
1037fn peek_const(e: &Expr) -> Option<f64> {
1041 match e {
1042 Expr::Const(c) => Some(*c),
1043 Expr::Cse(body) => peek_const(body),
1044 _ => None,
1045 }
1046}
1047
1048fn try_emit_const_pow(
1056 base_expr: &Expr,
1057 c: f64,
1058 ops: &mut Vec<TapeOp>,
1059 cache: &mut HashMap<*const Expr, usize>,
1060 resolver: &ExternalResolver,
1061) -> Option<usize> {
1062 if c == 0.0 {
1063 let idx = ops.len();
1064 ops.push(TapeOp::Const(1.0));
1065 return Some(idx);
1066 }
1067 if c == 1.0 {
1068 return Some(build_recursive(base_expr, ops, cache, resolver));
1069 }
1070 if c == 0.5 {
1071 let b = build_recursive(base_expr, ops, cache, resolver);
1072 let idx = ops.len();
1073 ops.push(TapeOp::Sqrt(b));
1074 return Some(idx);
1075 }
1076 if c.is_finite() && c.fract() == 0.0 && c.abs() <= 8.0 {
1081 let n = c.abs() as u32;
1082 if n == 0 {
1083 let idx = ops.len();
1085 ops.push(TapeOp::Const(1.0));
1086 return Some(idx);
1087 }
1088 let b = build_recursive(base_expr, ops, cache, resolver);
1089 let pos = emit_int_pow(b, n, ops);
1090 if c < 0.0 {
1091 let one_idx = ops.len();
1094 ops.push(TapeOp::Const(1.0));
1095 let idx = ops.len();
1096 ops.push(TapeOp::Div(one_idx, pos));
1097 return Some(idx);
1098 }
1099 return Some(pos);
1100 }
1101 None
1102}
1103
1104fn emit_int_pow(base: usize, n: u32, ops: &mut Vec<TapeOp>) -> usize {
1108 debug_assert!(n >= 1);
1109 if n == 1 {
1110 return base;
1111 }
1112 let half = emit_int_pow(base, n / 2, ops);
1113 let squared = ops.len();
1114 ops.push(TapeOp::Mul(half, half));
1115 if n % 2 == 1 {
1116 let idx = ops.len();
1117 ops.push(TapeOp::Mul(squared, base));
1118 idx
1119 } else {
1120 squared
1121 }
1122}
1123
/// One slot of a [`Summand`]'s local op list.
#[derive(Debug, Clone)]
pub enum SummandOp {
    /// An op evaluated over the summand's own local slots.
    Local(TapeOp),
    /// Copies the value (and, in the sweeps, the tangent/adjoint) of prelude slot `k`.
    Shared(usize),
}
1161
/// One summand of a [`HybridTape`]: a local op list plus precomputed reachability and
/// variable sets used by the sweeps.
#[derive(Debug, Clone)]
pub struct Summand {
    /// Local ops; slot `i` holds the value of `ops[i]`.
    pub ops: Vec<SummandOp>,
    /// Local slot holding the summand's value (see `HybridTape::root_value`).
    pub root_slot: usize,
    /// Local slots reachable from `root_slot` (from `compute_local_reach`). The sweeps
    /// walk it forwards for values/tangents and backwards for adjoints.
    /// NOTE(review): presumably ascending (topological) order — confirm.
    pub local_reach: Vec<usize>,
    /// Prelude slots reachable through this summand's `Shared` refs, sorted ascending.
    /// Empty when the summand has no `Shared` refs.
    pub prelude_reach: Vec<usize>,
    /// Sorted, de-duplicated indices of the `Var` ops within `local_reach`.
    pub local_vars: Vec<usize>,
    /// `vars_in(prelude, prelude_reach)`; presumably the sorted variables of those
    /// prelude ops — confirm in `vars_in`.
    pub prelude_vars: Vec<usize>,
    /// Sorted union of `local_vars` and `prelude_vars`.
    pub all_vars: Vec<usize>,
}
1179
/// Several expressions lowered together: ops shared between summands live once in
/// `prelude`, and each summand keeps its own local op list that refers to prelude slots
/// via `SummandOp::Shared`.
#[derive(Debug)]
pub struct HybridTape {
    /// Shared ops; slot `k` holds the value of `prelude[k]`.
    pub prelude: Vec<TapeOp>,
    /// One summand per input expression, in input order.
    pub summands: Vec<Summand>,
}
1189
1190impl HybridTape {
1191 pub fn build_multi(exprs: &[Expr]) -> Self {
1196 let mut cse_count: HashMap<*const Expr, usize> = HashMap::new();
1200 for e in exprs {
1201 let mut seen_in_root: HashSet<*const Expr> = HashSet::new();
1202 count_cse_appearances(e, &mut seen_in_root, &mut cse_count);
1203 }
1204
1205 let mut prelude: Vec<TapeOp> = Vec::new();
1210 let mut prelude_map: HashMap<*const Expr, usize> = HashMap::new();
1211 let mut summands: Vec<Summand> = Vec::with_capacity(exprs.len());
1212 for e in exprs {
1213 let mut local: Vec<SummandOp> = Vec::new();
1214 let mut local_cache: HashMap<*const Expr, usize> = HashMap::new();
1215 let root_slot = build_into_summand(
1216 e,
1217 &mut local,
1218 &mut local_cache,
1219 &mut prelude,
1220 &mut prelude_map,
1221 &cse_count,
1222 );
1223 summands.push(Summand {
1224 ops: local,
1225 root_slot,
1226 local_reach: Vec::new(),
1227 prelude_reach: Vec::new(),
1228 local_vars: Vec::new(),
1229 prelude_vars: Vec::new(),
1230 all_vars: Vec::new(),
1231 });
1232 }
1233
1234 let mut p_visited: Vec<u32> = vec![0; prelude.len()];
1238 let mut p_epoch: u32 = 0;
1239 let mut p_stack: Vec<usize> = Vec::new();
1240 for s in &mut summands {
1241 let (local_reach, shared_refs) = compute_local_reach(&s.ops, s.root_slot);
1242 s.local_reach = local_reach;
1243
1244 let mut lv: BTreeSet<usize> = BTreeSet::new();
1245 for &i in &s.local_reach {
1246 if let SummandOp::Local(TapeOp::Var(j)) = &s.ops[i] {
1247 lv.insert(*j);
1248 }
1249 }
1250 s.local_vars = lv.iter().copied().collect();
1251
1252 if !shared_refs.is_empty() {
1253 p_epoch += 1;
1254 let mut preach: Vec<usize> = Vec::new();
1255 for &start in &shared_refs {
1256 bfs_prelude(
1257 &prelude,
1258 start,
1259 &mut p_visited,
1260 p_epoch,
1261 &mut p_stack,
1262 &mut preach,
1263 );
1264 }
1265 preach.sort_unstable();
1266 s.prelude_vars = vars_in(&prelude, &preach);
1267 s.prelude_reach = preach;
1268 }
1269
1270 let mut av: BTreeSet<usize> = lv;
1271 for &v in &s.prelude_vars {
1272 av.insert(v);
1273 }
1274 s.all_vars = av.into_iter().collect();
1275 }
1276
1277 HybridTape { prelude, summands }
1278 }
1279
1280 pub fn n_prelude_ops(&self) -> usize {
1281 self.prelude.len()
1282 }
1283 pub fn n_summands(&self) -> usize {
1284 self.summands.len()
1285 }
1286 pub fn max_summand_ops(&self) -> usize {
1287 self.summands.iter().map(|s| s.ops.len()).max().unwrap_or(0)
1288 }
1289 pub fn total_local_ops(&self) -> usize {
1290 self.summands.iter().map(|s| s.ops.len()).sum()
1291 }
1292
1293 pub fn forward_prelude(&self, x: &[f64], prelude_vals: &mut [f64]) {
1296 debug_assert_eq!(prelude_vals.len(), self.prelude.len());
1297 for i in 0..self.prelude.len() {
1298 prelude_vals[i] = fwd_step(&self.prelude[i], x, prelude_vals);
1299 }
1300 }
1301
1302 pub fn forward_summand(
1305 &self,
1306 s: &Summand,
1307 x: &[f64],
1308 prelude_vals: &[f64],
1309 local_vals: &mut [f64],
1310 ) {
1311 debug_assert!(local_vals.len() >= s.ops.len());
1312 for i in 0..s.ops.len() {
1313 local_vals[i] = match &s.ops[i] {
1314 SummandOp::Local(op) => fwd_step(op, x, local_vals),
1315 SummandOp::Shared(k) => prelude_vals[*k],
1316 };
1317 }
1318 }
1319
1320 #[inline]
1322 pub fn root_value(&self, s: &Summand, local_vals: &[f64]) -> f64 {
1323 local_vals[s.root_slot]
1324 }
1325
1326 #[allow(clippy::too_many_arguments)]
1333 pub fn gradient_summand(
1334 &self,
1335 s: &Summand,
1336 prelude_vals: &[f64],
1337 local_vals: &[f64],
1338 seed: f64,
1339 grad: &mut [f64],
1340 local_adj: &mut [f64],
1341 prelude_adj: &mut [f64],
1342 ) {
1343 if seed == 0.0 || s.local_reach.is_empty() {
1344 return;
1345 }
1346 for &i in &s.local_reach {
1347 local_adj[i] = 0.0;
1348 }
1349 for &i in &s.prelude_reach {
1350 prelude_adj[i] = 0.0;
1351 }
1352 local_adj[s.root_slot] = seed;
1353 for &i in s.local_reach.iter().rev() {
1354 let a = local_adj[i];
1355 if a == 0.0 {
1356 continue;
1357 }
1358 match &s.ops[i] {
1359 SummandOp::Local(op) => rev_step(op, i, local_vals, local_adj, a, grad),
1360 SummandOp::Shared(k) => {
1361 prelude_adj[*k] += a;
1362 }
1363 }
1364 }
1365 for &i in s.prelude_reach.iter().rev() {
1366 let a = prelude_adj[i];
1367 if a == 0.0 {
1368 continue;
1369 }
1370 rev_step(&self.prelude[i], i, prelude_vals, prelude_adj, a, grad);
1371 }
1372 }
1373
#[allow(clippy::too_many_arguments)]
/// Accumulate `weight` times the Hessian of summand `s` into `values`.
///
/// For each variable `j` in `s.all_vars` this performs one forward-tangent
/// sweep seeded on `j`, followed by one reverse-over-forward sweep. The
/// reverse sweep carries the first-order adjoints (`*_adj`) together with
/// their tangents (`*_adj_dot`). Whenever `ror_step` reaches a leaf `Var(k)`
/// with `k >= j`, it adds `weight * adj_dot` into `values[hess_map[(k, j)]]`.
/// Each sweep therefore fills the lower-triangular column `j`.
///
/// The work is split between two op lists:
/// * the summand's own ops (`s.ops`, visited through `s.local_reach`);
/// * the shared prelude (`self.prelude`, visited through `s.prelude_reach`).
///
/// A `SummandOp::Shared(k)` slot copies its tangent from prelude slot `k` and
/// hands its adjoint back to it. Both reach lists are walked forward for
/// tangents and backward for adjoints, so they are assumed to be in
/// dependency (ascending slot) order. NOTE(review): confirm this at the
/// construction site.
///
/// * `prelude_vals` / `local_vals`: forward values of the prelude and of this
///   summand at the evaluation point. These presumably come from a preceding
///   forward pass (TODO confirm with callers).
/// * `local_dot`, `local_adj`, `local_adj_dot`, `prelude_dot`, `prelude_adj`,
///   `prelude_adj_dot`: scratch buffers. Only the slots listed in the two
///   reach sets are reset and used.
///
/// Returns without touching anything when `weight == 0.0` or the summand's
/// local reach is empty.
pub fn hessian_summand(
    &self,
    s: &Summand,
    prelude_vals: &[f64],
    local_vals: &[f64],
    weight: f64,
    hess_map: &HashMap<(usize, usize), usize>,
    values: &mut [f64],
    local_dot: &mut [f64],
    local_adj: &mut [f64],
    local_adj_dot: &mut [f64],
    prelude_dot: &mut [f64],
    prelude_adj: &mut [f64],
    prelude_adj_dot: &mut [f64],
) {
    if weight == 0.0 || s.local_reach.is_empty() {
        return;
    }
    // One column of the (lower-triangular) Hessian per variable `j`.
    for &j in &s.all_vars {
        // Clear only the slots this summand touches.
        for &i in &s.local_reach {
            local_dot[i] = 0.0;
            local_adj[i] = 0.0;
            local_adj_dot[i] = 0.0;
        }
        for &i in &s.prelude_reach {
            prelude_dot[i] = 0.0;
            prelude_adj[i] = 0.0;
            prelude_adj_dot[i] = 0.0;
        }
        // Forward tangents seeded on `j`. The prelude goes first because
        // `Shared` local slots copy their tangent from it.
        for &i in &s.prelude_reach {
            prelude_dot[i] = fwd_tan_step(&self.prelude[i], j, prelude_vals, prelude_dot, i);
        }
        for &i in &s.local_reach {
            local_dot[i] = match &s.ops[i] {
                SummandOp::Local(op) => fwd_tan_step(op, j, local_vals, local_dot, i),
                SummandOp::Shared(k) => prelude_dot[*k],
            };
        }
        // Seed the reverse sweep at the summand's output.
        local_adj[s.root_slot] = 1.0;
        // Reverse-over-forward through the local ops. `Shared` slots pass
        // their adjoint and adjoint tangent back to the prelude.
        for &i in s.local_reach.iter().rev() {
            let w = local_adj[i];
            let wd = local_adj_dot[i];
            if w == 0.0 && wd == 0.0 {
                continue;
            }
            match &s.ops[i] {
                SummandOp::Local(op) => {
                    ror_step(
                        op,
                        i,
                        j,
                        local_vals,
                        local_dot,
                        local_adj,
                        local_adj_dot,
                        w,
                        wd,
                        weight,
                        hess_map,
                        values,
                    );
                }
                SummandOp::Shared(k) => {
                    prelude_adj[*k] += w;
                    prelude_adj_dot[*k] += wd;
                }
            }
        }
        // Finish the sweep through the prelude slots this summand reaches.
        for &i in s.prelude_reach.iter().rev() {
            let w = prelude_adj[i];
            let wd = prelude_adj_dot[i];
            if w == 0.0 && wd == 0.0 {
                continue;
            }
            ror_step(
                &self.prelude[i],
                i,
                j,
                prelude_vals,
                prelude_dot,
                prelude_adj,
                prelude_adj_dot,
                w,
                wd,
                weight,
                hess_map,
                values,
            );
        }
    }
}
1473
1474 pub fn hessian_sparsity_all(&self) -> BTreeSet<(usize, usize)> {
1477 let mut pairs = hessian_sparsity_impl(&self.prelude);
1478
1479 let prelude_var_sets = compute_var_sets(&self.prelude);
1482
1483 for s in &self.summands {
1484 summand_sparsity(&s.ops, &prelude_var_sets, &mut pairs);
1485 }
1486 pairs
1487 }
1488}
1489
1490fn count_cse_appearances(
1495 e: &Expr,
1496 seen_in_root: &mut HashSet<*const Expr>,
1497 counts: &mut HashMap<*const Expr, usize>,
1498) {
1499 match e {
1500 Expr::Const(_) | Expr::Var(_) => {}
1501 Expr::Binary(_, a, b) => {
1502 count_cse_appearances(a, seen_in_root, counts);
1503 count_cse_appearances(b, seen_in_root, counts);
1504 }
1505 Expr::Unary(_, a) => count_cse_appearances(a, seen_in_root, counts),
1506 Expr::Sum(args) => {
1507 for a in args {
1508 count_cse_appearances(a, seen_in_root, counts);
1509 }
1510 }
1511 Expr::Cse(body) => {
1512 let key = Rc::as_ptr(body) as *const Expr;
1513 if seen_in_root.insert(key) {
1514 *counts.entry(key).or_insert(0) += 1;
1515 count_cse_appearances(body, seen_in_root, counts);
1516 }
1517 }
1518 Expr::Funcall { args, .. } => {
1519 for arg in args {
1520 if let FuncallArg::Real(e) = arg {
1521 count_cse_appearances(e, seen_in_root, counts);
1522 }
1523 }
1524 }
1525 }
1526}
1527
/// Lower `expr` into the summand-local op list `local` and return the index of
/// the op that holds its value.
///
/// * `local_cache` memoises `Expr::Cse` bodies already lowered for this
///   summand, keyed by the body's `Rc` pointer. Repeated references reuse one
///   local slot.
/// * A CSE body whose entry in `cse_count` is at least 2 is "promoted". It is
///   built into the shared `prelude` with `build_recursive`, using
///   `prelude_map` as that builder's cache, and referenced locally through a
///   `SummandOp::Shared` slot. Other CSE bodies are inlined into `local`.
/// * `base ^ c` with a constant exponent (as reported by `peek_const`, defined
///   elsewhere) is first offered to `try_emit_const_pow_summand` for strength
///   reduction. Otherwise it falls back to a generic `Pow`.
///
/// # Panics
/// Panics on `Expr::Funcall`: AMPL external functions are not supported on
/// the hybrid tape path.
fn build_into_summand(
    expr: &Expr,
    local: &mut Vec<SummandOp>,
    local_cache: &mut HashMap<*const Expr, usize>,
    prelude: &mut Vec<TapeOp>,
    prelude_map: &mut HashMap<*const Expr, usize>,
    cse_count: &HashMap<*const Expr, usize>,
) -> usize {
    match expr {
        Expr::Const(c) => {
            let i = local.len();
            local.push(SummandOp::Local(TapeOp::Const(*c)));
            i
        }
        Expr::Var(j) => {
            let i = local.len();
            local.push(SummandOp::Local(TapeOp::Var(*j)));
            i
        }
        Expr::Binary(op, a, b) => {
            // Constant exponent: try a cheaper lowering before the generic Pow.
            if let BinOp::Pow = op {
                if let Some(c) = peek_const(b) {
                    if let Some(i) = try_emit_const_pow_summand(
                        a,
                        c,
                        local,
                        local_cache,
                        prelude,
                        prelude_map,
                        cse_count,
                    ) {
                        return i;
                    }
                }
            }
            let l = build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
            let r = build_into_summand(b, local, local_cache, prelude, prelude_map, cse_count);
            let i = local.len();
            local.push(SummandOp::Local(match op {
                BinOp::Add => TapeOp::Add(l, r),
                BinOp::Sub => TapeOp::Sub(l, r),
                BinOp::Mul => TapeOp::Mul(l, r),
                BinOp::Div => TapeOp::Div(l, r),
                BinOp::Pow => TapeOp::Pow(l, r),
            }));
            i
        }
        Expr::Unary(op, a) => {
            let v = build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
            let i = local.len();
            local.push(SummandOp::Local(match op {
                UnaryOp::Neg => TapeOp::Neg(v),
                UnaryOp::Sqrt => TapeOp::Sqrt(v),
                UnaryOp::Log => TapeOp::Log(v),
                UnaryOp::Log10 => TapeOp::Log10(v),
                UnaryOp::Exp => TapeOp::Exp(v),
                UnaryOp::Abs => TapeOp::Abs(v),
                UnaryOp::Sin => TapeOp::Sin(v),
                UnaryOp::Cos => TapeOp::Cos(v),
            }));
            i
        }
        Expr::Sum(args) => {
            // An empty sum is the constant 0. Otherwise the terms are folded
            // left into a chain of binary `Add`s.
            if args.is_empty() {
                let i = local.len();
                local.push(SummandOp::Local(TapeOp::Const(0.0)));
                return i;
            }
            let mut acc = build_into_summand(
                &args[0],
                local,
                local_cache,
                prelude,
                prelude_map,
                cse_count,
            );
            for a in &args[1..] {
                let nxt =
                    build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
                let i = local.len();
                local.push(SummandOp::Local(TapeOp::Add(acc, nxt)));
                acc = i;
            }
            acc
        }
        Expr::Cse(body) => {
            let key = Rc::as_ptr(body) as *const Expr;
            // Already lowered for this summand: reuse the slot.
            if let Some(&li) = local_cache.get(&key) {
                return li;
            }
            let promoted = cse_count.get(&key).copied().unwrap_or(0) >= 2;
            if promoted {
                // The whole `Cse` node (not just its body) goes to the prelude.
                // Presumably this lets `build_recursive`'s own CSE handling and
                // `prelude_map` deduplicate it (TODO confirm).
                let pslot =
                    build_recursive(expr, prelude, prelude_map, &ExternalResolver::default());
                let li = local.len();
                local.push(SummandOp::Shared(pslot));
                local_cache.insert(key, li);
                li
            } else {
                // Inline the body into this summand.
                let li =
                    build_into_summand(body, local, local_cache, prelude, prelude_map, cse_count);
                local_cache.insert(key, li);
                li
            }
        }
        Expr::Funcall { .. } => {
            panic!(
                "HybridTape: AMPL external function calls are not supported on the \
                 hybrid (partial-separability) tape path. Build with Tape::build_with_externals \
                 instead."
            );
        }
    }
}
1651
1652fn try_emit_const_pow_summand(
1655 base_expr: &Expr,
1656 c: f64,
1657 local: &mut Vec<SummandOp>,
1658 local_cache: &mut HashMap<*const Expr, usize>,
1659 prelude: &mut Vec<TapeOp>,
1660 prelude_map: &mut HashMap<*const Expr, usize>,
1661 cse_count: &HashMap<*const Expr, usize>,
1662) -> Option<usize> {
1663 if c == 0.0 {
1664 let i = local.len();
1665 local.push(SummandOp::Local(TapeOp::Const(1.0)));
1666 return Some(i);
1667 }
1668 if c == 1.0 {
1669 return Some(build_into_summand(
1670 base_expr,
1671 local,
1672 local_cache,
1673 prelude,
1674 prelude_map,
1675 cse_count,
1676 ));
1677 }
1678 if c == 0.5 {
1679 let b = build_into_summand(
1680 base_expr,
1681 local,
1682 local_cache,
1683 prelude,
1684 prelude_map,
1685 cse_count,
1686 );
1687 let i = local.len();
1688 local.push(SummandOp::Local(TapeOp::Sqrt(b)));
1689 return Some(i);
1690 }
1691 if c.is_finite() && c.fract() == 0.0 && c.abs() <= 8.0 {
1692 let n = c.abs() as u32;
1693 if n == 0 {
1694 let i = local.len();
1695 local.push(SummandOp::Local(TapeOp::Const(1.0)));
1696 return Some(i);
1697 }
1698 let b = build_into_summand(
1699 base_expr,
1700 local,
1701 local_cache,
1702 prelude,
1703 prelude_map,
1704 cse_count,
1705 );
1706 let pos = emit_int_pow_summand(b, n, local);
1707 if c < 0.0 {
1708 let one_idx = local.len();
1709 local.push(SummandOp::Local(TapeOp::Const(1.0)));
1710 let i = local.len();
1711 local.push(SummandOp::Local(TapeOp::Div(one_idx, pos)));
1712 return Some(i);
1713 }
1714 return Some(pos);
1715 }
1716 None
1717}
1718
1719fn emit_int_pow_summand(base: usize, n: u32, local: &mut Vec<SummandOp>) -> usize {
1720 debug_assert!(n >= 1);
1721 if n == 1 {
1722 return base;
1723 }
1724 let half = emit_int_pow_summand(base, n / 2, local);
1725 let squared = local.len();
1726 local.push(SummandOp::Local(TapeOp::Mul(half, half)));
1727 if n % 2 == 1 {
1728 let i = local.len();
1729 local.push(SummandOp::Local(TapeOp::Mul(squared, base)));
1730 i
1731 } else {
1732 squared
1733 }
1734}
1735
1736fn compute_local_reach(ops: &[SummandOp], root: usize) -> (Vec<usize>, Vec<usize>) {
1740 let mut visited = vec![false; ops.len()];
1741 let mut reach: Vec<usize> = Vec::new();
1742 let mut shared: BTreeSet<usize> = BTreeSet::new();
1743 let mut stack: Vec<usize> = Vec::with_capacity(16);
1744 visited[root] = true;
1745 reach.push(root);
1746 stack.push(root);
1747 while let Some(s) = stack.pop() {
1748 match &ops[s] {
1749 SummandOp::Local(op) => {
1750 let (a, b) = op_operands(op);
1751 if let Some(a) = a {
1752 if !visited[a] {
1753 visited[a] = true;
1754 reach.push(a);
1755 stack.push(a);
1756 }
1757 }
1758 if let Some(b) = b {
1759 if !visited[b] {
1760 visited[b] = true;
1761 reach.push(b);
1762 stack.push(b);
1763 }
1764 }
1765 }
1766 SummandOp::Shared(k) => {
1767 shared.insert(*k);
1768 }
1769 }
1770 }
1771 reach.sort_unstable();
1772 (reach, shared.into_iter().collect())
1773}
1774
1775fn bfs_prelude(
1779 prelude: &[TapeOp],
1780 start: usize,
1781 visited: &mut [u32],
1782 cur: u32,
1783 stack: &mut Vec<usize>,
1784 out: &mut Vec<usize>,
1785) {
1786 if visited[start] == cur {
1787 return;
1788 }
1789 visited[start] = cur;
1790 out.push(start);
1791 stack.push(start);
1792 while let Some(s) = stack.pop() {
1793 let (a, b) = op_operands(&prelude[s]);
1794 if let Some(a) = a {
1795 if visited[a] != cur {
1796 visited[a] = cur;
1797 out.push(a);
1798 stack.push(a);
1799 }
1800 }
1801 if let Some(b) = b {
1802 if visited[b] != cur {
1803 visited[b] = cur;
1804 out.push(b);
1805 stack.push(b);
1806 }
1807 }
1808 }
1809}
1810
1811fn compute_var_sets(ops: &[TapeOp]) -> Vec<BTreeSet<usize>> {
1815 let mut out: Vec<BTreeSet<usize>> = Vec::with_capacity(ops.len());
1816 for op in ops {
1817 let vs: BTreeSet<usize> = match op {
1818 TapeOp::Const(_) => BTreeSet::new(),
1819 TapeOp::Var(j) => {
1820 let mut s = BTreeSet::new();
1821 s.insert(*j);
1822 s
1823 }
1824 TapeOp::Add(a, b)
1825 | TapeOp::Sub(a, b)
1826 | TapeOp::Mul(a, b)
1827 | TapeOp::Div(a, b)
1828 | TapeOp::Pow(a, b) => out[*a].union(&out[*b]).copied().collect(),
1829 TapeOp::Neg(a)
1830 | TapeOp::Abs(a)
1831 | TapeOp::Sqrt(a)
1832 | TapeOp::Exp(a)
1833 | TapeOp::Log(a)
1834 | TapeOp::Log10(a)
1835 | TapeOp::Sin(a)
1836 | TapeOp::Cos(a) => out[*a].clone(),
1837 TapeOp::Funcall { .. } => unreachable!(
1838 "HybridTape prelude cannot contain TapeOp::Funcall; \
1839 build_into_summand panics on Expr::Funcall."
1840 ),
1841 };
1842 out.push(vs);
1843 }
1844 out
1845}
1846
1847fn summand_sparsity(
1852 ops: &[SummandOp],
1853 prelude_var_sets: &[BTreeSet<usize>],
1854 pairs: &mut BTreeSet<(usize, usize)>,
1855) {
1856 let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(ops.len());
1857 let emit_cross =
1858 |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
1859 for &v1 in s1 {
1860 for &v2 in s2 {
1861 let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
1862 pairs.insert((r, c));
1863 }
1864 }
1865 };
1866 let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
1867 let vars: Vec<usize> = s.iter().copied().collect();
1868 for (ai, &vi) in vars.iter().enumerate() {
1869 for &vj in &vars[..=ai] {
1870 let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
1871 pairs.insert((r, c));
1872 }
1873 }
1874 };
1875 for so in ops {
1876 let vset: BTreeSet<usize> = match so {
1877 SummandOp::Shared(k) => prelude_var_sets[*k].clone(),
1878 SummandOp::Local(op) => match op {
1879 TapeOp::Const(_) => BTreeSet::new(),
1880 TapeOp::Var(j) => {
1881 let mut s = BTreeSet::new();
1882 s.insert(*j);
1883 s
1884 }
1885 TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
1886 var_sets[*a].union(&var_sets[*b]).copied().collect()
1887 }
1888 TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
1889 TapeOp::Mul(a, b) => {
1890 emit_cross(&var_sets[*a], &var_sets[*b], pairs);
1891 var_sets[*a].union(&var_sets[*b]).copied().collect()
1892 }
1893 TapeOp::Div(a, b) => {
1894 emit_cross(&var_sets[*a], &var_sets[*b], pairs);
1895 emit_self(&var_sets[*b], pairs);
1896 var_sets[*a].union(&var_sets[*b]).copied().collect()
1897 }
1898 TapeOp::Pow(a, b) => {
1899 let combined: BTreeSet<usize> =
1900 var_sets[*a].union(&var_sets[*b]).copied().collect();
1901 emit_self(&combined, pairs);
1902 combined
1903 }
1904 TapeOp::Sqrt(a)
1905 | TapeOp::Exp(a)
1906 | TapeOp::Log(a)
1907 | TapeOp::Log10(a)
1908 | TapeOp::Sin(a)
1909 | TapeOp::Cos(a) => {
1910 emit_self(&var_sets[*a], pairs);
1911 var_sets[*a].clone()
1912 }
1913 TapeOp::Funcall { .. } => unreachable!(
1914 "HybridTape summand cannot contain TapeOp::Funcall; \
1915 build_into_summand panics on Expr::Funcall."
1916 ),
1917 },
1918 };
1919 var_sets.push(vset);
1920 }
1921}
1922
1923#[inline]
1926fn op_operands(op: &TapeOp) -> (Option<usize>, Option<usize>) {
1927 match op {
1928 TapeOp::Const(_) | TapeOp::Var(_) => (None, None),
1929 TapeOp::Add(a, b)
1930 | TapeOp::Sub(a, b)
1931 | TapeOp::Mul(a, b)
1932 | TapeOp::Div(a, b)
1933 | TapeOp::Pow(a, b) => (Some(*a), Some(*b)),
1934 TapeOp::Neg(a)
1935 | TapeOp::Abs(a)
1936 | TapeOp::Sqrt(a)
1937 | TapeOp::Exp(a)
1938 | TapeOp::Log(a)
1939 | TapeOp::Log10(a)
1940 | TapeOp::Sin(a)
1941 | TapeOp::Cos(a) => (Some(*a), None),
1942 TapeOp::Funcall { .. } => (None, None),
1943 }
1944}
1945
1946fn vars_in(ops: &[TapeOp], reach: &[usize]) -> Vec<usize> {
1947 let mut s: BTreeSet<usize> = BTreeSet::new();
1948 for &i in reach {
1949 if let TapeOp::Var(j) = &ops[i] {
1950 s.insert(*j);
1951 }
1952 }
1953 s.into_iter().collect()
1954}
1955
1956#[inline]
1959fn fwd_step(op: &TapeOp, x: &[f64], vals: &[f64]) -> f64 {
1960 match op {
1961 TapeOp::Const(c) => *c,
1962 TapeOp::Var(i) => x[*i],
1963 TapeOp::Add(a, b) => vals[*a] + vals[*b],
1964 TapeOp::Sub(a, b) => vals[*a] - vals[*b],
1965 TapeOp::Mul(a, b) => vals[*a] * vals[*b],
1966 TapeOp::Div(a, b) => vals[*a] / vals[*b],
1967 TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
1968 TapeOp::Neg(a) => -vals[*a],
1969 TapeOp::Abs(a) => vals[*a].abs(),
1970 TapeOp::Sqrt(a) => vals[*a].sqrt(),
1971 TapeOp::Exp(a) => vals[*a].exp(),
1972 TapeOp::Log(a) => vals[*a].ln(),
1973 TapeOp::Log10(a) => vals[*a].log10(),
1974 TapeOp::Sin(a) => vals[*a].sin(),
1975 TapeOp::Cos(a) => vals[*a].cos(),
1976 TapeOp::Funcall { lib, name, args } => {
1977 let call_args = funcall_to_ext_args(args, vals);
1978 let res = lib
1979 .eval(name, &call_args, false, false)
1980 .unwrap_or_else(|e| panic!("external function '{name}' eval failed: {e}"));
1981 res.value
1982 }
1983 }
1984}
1985
#[inline]
/// Reverse-mode (adjoint) step for tape op `op` stored at slot `i`.
///
/// Adds the contribution of slot `i`'s adjoint `a` to the adjoints of its
/// operands in `adj`, or to `grad[j]` when the op is the leaf `Var(j)`.
/// `vals` holds the forward values. `vals[i]`, the op's own value, is reused
/// where it appears in the derivative (`Pow`, `Sqrt`, `Exp`).
///
/// # Panics
/// Panics if an external function (`Funcall`) fails to evaluate.
fn rev_step(op: &TapeOp, i: usize, vals: &[f64], adj: &mut [f64], a: f64, grad: &mut [f64]) {
    match op {
        TapeOp::Const(_) => {}
        TapeOp::Var(j) => {
            grad[*j] += a;
        }
        TapeOp::Add(l, r) => {
            adj[*l] += a;
            adj[*r] += a;
        }
        TapeOp::Sub(l, r) => {
            adj[*l] += a;
            adj[*r] -= a;
        }
        TapeOp::Mul(l, r) => {
            adj[*l] += a * vals[*r];
            adj[*r] += a * vals[*l];
        }
        TapeOp::Div(l, r) => {
            let rv = vals[*r];
            adj[*l] += a / rv;
            adj[*r] -= a * vals[*l] / (rv * rv);
        }
        TapeOp::Pow(l, r) => {
            let lv = vals[*l];
            let rv = vals[*r];
            // d(l^r)/dl = r * l^(r-1). Skipped when r == 0, where the factor
            // r makes it zero.
            if rv != 0.0 {
                adj[*l] += a * rv * lv.powf(rv - 1.0);
            }
            // d(l^r)/dr = l^r * ln(l), only propagated for a positive base.
            if lv > 0.0 {
                adj[*r] += a * vals[i] * lv.ln();
            }
        }
        TapeOp::Neg(j) => {
            adj[*j] -= a;
        }
        TapeOp::Abs(j) => {
            // Slope +1 is used at 0.
            if vals[*j] >= 0.0 {
                adj[*j] += a;
            } else {
                adj[*j] -= a;
            }
        }
        TapeOp::Sqrt(j) => {
            // d(sqrt u)/du = 1 / (2 sqrt u). Skipped unless sqrt u > 0, since
            // the derivative is unbounded at 0 and NaN for negative u.
            let sv = vals[i];
            if sv > 0.0 {
                adj[*j] += a * 0.5 / sv;
            }
        }
        TapeOp::Exp(j) => {
            adj[*j] += a * vals[i];
        }
        TapeOp::Log(j) => {
            adj[*j] += a / vals[*j];
        }
        TapeOp::Log10(j) => {
            adj[*j] += a / (vals[*j] * std::f64::consts::LN_10);
        }
        TapeOp::Sin(j) => {
            adj[*j] += a * vals[*j].cos();
        }
        TapeOp::Cos(j) => {
            adj[*j] -= a * vals[*j].sin();
        }
        TapeOp::Funcall { lib, name, args } => {
            let call_args = funcall_to_ext_args(args, vals);
            let res = lib
                .eval(name, &call_args, true, false)
                .unwrap_or_else(|e| panic!("external function '{name}' reverse eval failed: {e}"));
            let derivs = res.derivs.expect("want_derivs=true returns derivs");
            // `derivs` is indexed over the tape (real) arguments only, in
            // order. String arguments are skipped.
            let mut k = 0usize;
            for arg in args {
                if let TapeFuncallArg::Tape(idx) = arg {
                    adj[*idx] += a * derivs[k];
                    k += 1;
                }
            }
            // `i` and `grad` are not needed for external calls.
            let _ = i;
            let _ = grad;
        }
    }
}
2069
2070#[inline]
2071fn fwd_tan_step(op: &TapeOp, seed_var: usize, vals: &[f64], dot: &[f64], i: usize) -> f64 {
2072 match op {
2073 TapeOp::Const(_) => 0.0,
2074 TapeOp::Var(k) => {
2075 if *k == seed_var {
2076 1.0
2077 } else {
2078 0.0
2079 }
2080 }
2081 TapeOp::Add(a, b) => dot[*a] + dot[*b],
2082 TapeOp::Sub(a, b) => dot[*a] - dot[*b],
2083 TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
2084 TapeOp::Div(a, b) => {
2085 let vb = vals[*b];
2086 (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
2087 }
2088 TapeOp::Pow(a, b) => {
2089 let u = vals[*a];
2090 let r = vals[*b];
2091 let du = dot[*a];
2092 let dr = dot[*b];
2093 let mut result = 0.0;
2094 if r != 0.0 && u != 0.0 {
2095 result += r * u.powf(r - 1.0) * du;
2096 }
2097 if u > 0.0 {
2098 result += vals[i] * u.ln() * dr;
2099 }
2100 result
2101 }
2102 TapeOp::Neg(a) => -dot[*a],
2103 TapeOp::Abs(a) => {
2104 if vals[*a] >= 0.0 {
2105 dot[*a]
2106 } else {
2107 -dot[*a]
2108 }
2109 }
2110 TapeOp::Sqrt(a) => {
2111 let sv = vals[i];
2112 if sv > 0.0 {
2113 dot[*a] * 0.5 / sv
2114 } else {
2115 0.0
2116 }
2117 }
2118 TapeOp::Exp(a) => dot[*a] * vals[i],
2119 TapeOp::Log(a) => dot[*a] / vals[*a],
2120 TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
2121 TapeOp::Sin(a) => dot[*a] * vals[*a].cos(),
2122 TapeOp::Cos(a) => -dot[*a] * vals[*a].sin(),
2123 TapeOp::Funcall { lib, name, args } => {
2124 let call_args = funcall_to_ext_args(args, vals);
2125 let res = lib
2126 .eval(name, &call_args, true, false)
2127 .unwrap_or_else(|e| panic!("external function '{name}' tangent eval failed: {e}"));
2128 let derivs = res.derivs.expect("want_derivs=true returns derivs");
2129 let mut acc = 0.0;
2130 let mut k = 0usize;
2131 for arg in args {
2132 if let TapeFuncallArg::Tape(idx) = arg {
2133 acc += derivs[k] * dot[*idx];
2134 k += 1;
2135 }
2136 }
2137 let _ = seed_var;
2138 acc
2139 }
2140 }
2141}
2142
2143#[allow(clippy::too_many_arguments)]
2144#[inline]
2145fn ror_step(
2146 op: &TapeOp,
2147 i: usize,
2148 seed_var: usize,
2149 vals: &[f64],
2150 dot: &[f64],
2151 adj: &mut [f64],
2152 adj_dot: &mut [f64],
2153 w: f64,
2154 wd: f64,
2155 weight: f64,
2156 hess_map: &HashMap<(usize, usize), usize>,
2157 values: &mut [f64],
2158) {
2159 match op {
2160 TapeOp::Const(_) => {}
2161 TapeOp::Var(k) => {
2162 if wd != 0.0 && *k >= seed_var {
2163 if let Some(&pos) = hess_map.get(&(*k, seed_var)) {
2164 values[pos] += weight * wd;
2165 }
2166 }
2167 }
2168 TapeOp::Add(a, b) => {
2169 adj[*a] += w;
2170 adj[*b] += w;
2171 adj_dot[*a] += wd;
2172 adj_dot[*b] += wd;
2173 }
2174 TapeOp::Sub(a, b) => {
2175 adj[*a] += w;
2176 adj[*b] -= w;
2177 adj_dot[*a] += wd;
2178 adj_dot[*b] -= wd;
2179 }
2180 TapeOp::Mul(a, b) => {
2181 adj[*a] += w * vals[*b];
2182 adj[*b] += w * vals[*a];
2183 adj_dot[*a] += wd * vals[*b] + w * dot[*b];
2184 adj_dot[*b] += wd * vals[*a] + w * dot[*a];
2185 }
2186 TapeOp::Div(a, b) => {
2187 let vb = vals[*b];
2188 let vb2 = vb * vb;
2189 let vb3 = vb2 * vb;
2190 adj[*a] += w / vb;
2191 adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
2192 adj[*b] += w * (-vals[*a] / vb2);
2193 adj_dot[*b] +=
2194 wd * (-vals[*a] / vb2) + w * (-dot[*a] / vb2 + 2.0 * vals[*a] * dot[*b] / vb3);
2195 }
2196 TapeOp::Pow(a, b) => {
2197 let u = vals[*a];
2198 let r = vals[*b];
2199 let du = dot[*a];
2200 let dr = dot[*b];
2201 if r != 0.0 {
2202 if u != 0.0 {
2203 let p_a = r * u.powf(r - 1.0);
2204 adj[*a] += w * p_a;
2205 let mut dp_a = dr * u.powf(r - 1.0);
2206 if u > 0.0 {
2207 dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
2208 } else {
2209 dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
2210 }
2211 adj_dot[*a] += wd * p_a + w * dp_a;
2212 } else if r >= 2.0 {
2213 let p_a = 0.0;
2214 adj[*a] += w * p_a;
2215 let dp_a = if r == 2.0 {
2216 2.0 * du
2217 } else {
2218 r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
2219 };
2220 adj_dot[*a] += wd * p_a + w * dp_a;
2221 }
2222 }
2223 if u > 0.0 {
2224 let ln_u = u.ln();
2225 let p_b = vals[i] * ln_u;
2226 adj[*b] += w * p_b;
2227 let dur = vals[i] * (r * du / u + dr * ln_u);
2228 let dp_b = dur * ln_u + vals[i] * du / u;
2229 adj_dot[*b] += wd * p_b + w * dp_b;
2230 }
2231 }
2232 TapeOp::Neg(a) => {
2233 adj[*a] -= w;
2234 adj_dot[*a] -= wd;
2235 }
2236 TapeOp::Abs(a) => {
2237 let s = if vals[*a] >= 0.0 { 1.0 } else { -1.0 };
2238 adj[*a] += w * s;
2239 adj_dot[*a] += wd * s;
2240 }
2241 TapeOp::Sqrt(a) => {
2242 let sv = vals[i];
2243 if sv > 0.0 {
2244 let fp = 0.5 / sv;
2245 let fpp = -0.25 / (vals[*a] * sv);
2246 adj[*a] += w * fp;
2247 adj_dot[*a] += wd * fp + w * fpp * dot[*a];
2248 }
2249 }
2250 TapeOp::Exp(a) => {
2251 let ev = vals[i];
2252 adj[*a] += w * ev;
2253 adj_dot[*a] += wd * ev + w * ev * dot[*a];
2254 }
2255 TapeOp::Log(a) => {
2256 let u = vals[*a];
2257 adj[*a] += w / u;
2258 adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
2259 }
2260 TapeOp::Log10(a) => {
2261 let u = vals[*a];
2262 let c = std::f64::consts::LN_10;
2263 adj[*a] += w / (u * c);
2264 adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
2265 }
2266 TapeOp::Sin(a) => {
2267 let u = vals[*a];
2268 let cu = u.cos();
2269 adj[*a] += w * cu;
2270 adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
2271 }
2272 TapeOp::Cos(a) => {
2273 let u = vals[*a];
2274 let su = u.sin();
2275 adj[*a] -= w * su;
2276 adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
2277 }
2278 TapeOp::Funcall { lib, name, args } => {
2279 let call_args = funcall_to_ext_args(args, vals);
2280 let res = lib.eval(name, &call_args, true, true).unwrap_or_else(|e| {
2281 panic!("external function '{name}' 2nd-order eval failed: {e}")
2282 });
2283 let derivs = res.derivs.expect("want_derivs=true returns derivs");
2284 let hes = res.hessian.expect("want_hes=true returns hessian");
2285 let real_tape: Vec<usize> = args
2286 .iter()
2287 .filter_map(|a| match a {
2288 TapeFuncallArg::Tape(t) => Some(*t),
2289 TapeFuncallArg::Str(_) => None,
2290 })
2291 .collect();
2292 for (k, &tk) in real_tape.iter().enumerate() {
2293 adj[tk] += w * derivs[k];
2294 let mut second_term = 0.0;
2295 for (l, &tl) in real_tape.iter().enumerate() {
2296 let (lo, hi) = if k <= l { (k, l) } else { (l, k) };
2297 let h_kl = hes[lo + hi * (hi + 1) / 2];
2298 second_term += h_kl * dot[tl];
2299 }
2300 adj_dot[tk] += wd * derivs[k] + w * second_term;
2301 }
2302 let _ = seed_var;
2303 let _ = hess_map;
2304 let _ = values;
2305 let _ = weight;
2306 let _ = i;
2307 }
2308 }
2309}
2310
2311fn hessian_sparsity_impl(ops: &[TapeOp]) -> BTreeSet<(usize, usize)> {
2315 let n = ops.len();
2316 let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(n);
2317 let mut pairs: BTreeSet<(usize, usize)> = BTreeSet::new();
2318
2319 let emit_cross =
2320 |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
2321 for &v1 in s1 {
2322 for &v2 in s2 {
2323 let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
2324 pairs.insert((r, c));
2325 }
2326 }
2327 };
2328 let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
2329 let vars: Vec<usize> = s.iter().copied().collect();
2330 for (ai, &vi) in vars.iter().enumerate() {
2331 for &vj in &vars[..=ai] {
2332 let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
2333 pairs.insert((r, c));
2334 }
2335 }
2336 };
2337
2338 for op in ops {
2339 let vset = match op {
2340 TapeOp::Const(_) => BTreeSet::new(),
2341 TapeOp::Var(j) => {
2342 let mut s = BTreeSet::new();
2343 s.insert(*j);
2344 s
2345 }
2346 TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
2347 var_sets[*a].union(&var_sets[*b]).copied().collect()
2348 }
2349 TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
2350 TapeOp::Mul(a, b) => {
2351 emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
2352 var_sets[*a].union(&var_sets[*b]).copied().collect()
2353 }
2354 TapeOp::Div(a, b) => {
2355 emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
2356 emit_self(&var_sets[*b], &mut pairs);
2357 var_sets[*a].union(&var_sets[*b]).copied().collect()
2358 }
2359 TapeOp::Pow(a, b) => {
2360 let combined: BTreeSet<usize> =
2361 var_sets[*a].union(&var_sets[*b]).copied().collect();
2362 emit_self(&combined, &mut pairs);
2363 combined
2364 }
2365 TapeOp::Sqrt(a)
2366 | TapeOp::Exp(a)
2367 | TapeOp::Log(a)
2368 | TapeOp::Log10(a)
2369 | TapeOp::Sin(a)
2370 | TapeOp::Cos(a) => {
2371 emit_self(&var_sets[*a], &mut pairs);
2372 var_sets[*a].clone()
2373 }
2374 TapeOp::Funcall { args, .. } => {
2375 let mut combined: BTreeSet<usize> = BTreeSet::new();
2376 for arg in args {
2377 if let TapeFuncallArg::Tape(t) = arg {
2378 for &vv in &var_sets[*t] {
2379 combined.insert(vv);
2380 }
2381 }
2382 }
2383 emit_self(&combined, &mut pairs);
2384 combined
2385 }
2386 };
2387 var_sets.push(vset);
2388 }
2389 pairs
2390}
2391
2392#[cfg(test)]
2393mod tests {
2394 use super::*;
2395
2396 fn cnst(c: f64) -> Expr {
2397 Expr::Const(c)
2398 }
2399 fn var(i: usize) -> Expr {
2400 Expr::Var(i)
2401 }
2402 fn add(a: Expr, b: Expr) -> Expr {
2403 Expr::Binary(BinOp::Add, Box::new(a), Box::new(b))
2404 }
2405 fn mul(a: Expr, b: Expr) -> Expr {
2406 Expr::Binary(BinOp::Mul, Box::new(a), Box::new(b))
2407 }
2408 fn pow(a: Expr, b: Expr) -> Expr {
2409 Expr::Binary(BinOp::Pow, Box::new(a), Box::new(b))
2410 }
2411 fn div(a: Expr, b: Expr) -> Expr {
2412 Expr::Binary(BinOp::Div, Box::new(a), Box::new(b))
2413 }
2414 fn unary(op: UnaryOp, a: Expr) -> Expr {
2415 Expr::Unary(op, Box::new(a))
2416 }
2417
2418 #[test]
2419 fn polynomial_eval_and_grad() {
2420 let e = add(
2422 mul(cnst(3.0), pow(var(0), cnst(2.0))),
2423 mul(cnst(2.0), var(1)),
2424 );
2425 let t = Tape::build(&e);
2426 assert!((t.eval(&[2.0, 3.0]) - 18.0).abs() < 1e-12);
2427 let mut g = vec![0.0; 2];
2428 t.gradient_seed(&[2.0, 3.0], 1.0, &mut g);
2429 assert!((g[0] - 12.0).abs() < 1e-12);
2431 assert!((g[1] - 2.0).abs() < 1e-12);
2432 }
2433
2434 #[test]
2435 fn cse_shared_body_evaluated_once() {
2436 let body = Rc::new(add(var(0), var(1)));
2438 let e = add(
2439 pow(Expr::Cse(body.clone()), cnst(2.0)),
2440 Expr::Cse(body.clone()),
2441 );
2442 let t = Tape::build(&e);
2443 let n_body_adds = t
2445 .ops
2446 .iter()
2447 .filter(|op| {
2448 matches!(op, TapeOp::Add(a, b) if {
2449 matches!(t.ops[*a], TapeOp::Var(0)) && matches!(t.ops[*b], TapeOp::Var(1))
2450 })
2451 })
2452 .count();
2453 assert_eq!(n_body_adds, 1, "CSE body should be emitted exactly once");
2454
2455 assert!((t.eval(&[1.0, 2.0]) - 12.0).abs() < 1e-12);
2457 let mut g = vec![0.0; 2];
2458 t.gradient_seed(&[1.0, 2.0], 1.0, &mut g);
2459 assert!((g[0] - 7.0).abs() < 1e-12);
2461 assert!((g[1] - 7.0).abs() < 1e-12);
2462 }
2463
2464 fn fd_check(tape: &Tape, x: &[f64], n: usize, tol: f64) {
2465 let vars = tape.variables();
2466 let mut hess_map: HashMap<(usize, usize), usize> = HashMap::new();
2467 let mut pairs = Vec::new();
2468 for (ai, &vi) in vars.iter().enumerate() {
2469 for &vj in &vars[..=ai] {
2470 let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
2471 hess_map.entry((r, c)).or_insert_with(|| {
2472 let p = pairs.len();
2473 pairs.push((r, c));
2474 p
2475 });
2476 }
2477 }
2478 let nnz = pairs.len();
2479 let mut ad = vec![0.0; nnz];
2480 tape.hessian_accumulate(x, 1.0, &hess_map, &mut ad);
2481
2482 let mut fd = vec![0.0; nnz];
2483 let mut xp = x.to_vec();
2484 let mut gp = vec![0.0; n];
2485 let mut gm = vec![0.0; n];
2486 for &j in &vars {
2487 let h = (1e-7_f64).max(x[j].abs() * 1e-7);
2488 xp[j] = x[j] + h;
2489 gp.iter_mut().for_each(|v| *v = 0.0);
2490 tape.gradient_seed(&xp, 1.0, &mut gp);
2491 xp[j] = x[j] - h;
2492 gm.iter_mut().for_each(|v| *v = 0.0);
2493 tape.gradient_seed(&xp, 1.0, &mut gm);
2494 xp[j] = x[j];
2495 for &i in &vars {
2496 if i >= j {
2497 if let Some(&pos) = hess_map.get(&(i, j)) {
2498 fd[pos] = (gp[i] - gm[i]) / (2.0 * h);
2499 }
2500 }
2501 }
2502 }
2503 for (k, &(r, c)) in pairs.iter().enumerate() {
2504 let scale = fd[k].abs().max(1.0);
2505 assert!(
2506 (ad[k] - fd[k]).abs() / scale < tol,
2507 "H[{},{}]: AD={:.6e} FD={:.6e}",
2508 r,
2509 c,
2510 ad[k],
2511 fd[k]
2512 );
2513 }
2514 }
2515
2516 #[test]
2517 fn hessian_quadratic_matches_fd() {
2518 let e = add(
2520 add(
2521 mul(cnst(3.0), pow(var(0), cnst(2.0))),
2522 mul(cnst(2.0), mul(var(0), var(1))),
2523 ),
2524 pow(var(1), cnst(2.0)),
2525 );
2526 let t = Tape::build(&e);
2527 fd_check(&t, &[2.0, 3.0], 2, 1e-5);
2528 }
2529
2530 #[test]
2531 fn hessian_transcendental_matches_fd() {
2532 let e = Expr::Sum(vec![
2534 unary(UnaryOp::Exp, var(0)),
2535 unary(UnaryOp::Sin, var(1)),
2536 unary(UnaryOp::Log, var(0)),
2537 unary(UnaryOp::Sqrt, var(1)),
2538 mul(var(0), var(1)),
2539 ]);
2540 let t = Tape::build(&e);
2541 fd_check(&t, &[1.5, 2.0], 2, 1e-5);
2542 }
2543
2544 #[test]
2545 fn hessian_division_matches_fd() {
2546 let e = add(div(var(0), var(1)), unary(UnaryOp::Cos, var(0)));
2548 let t = Tape::build(&e);
2549 fd_check(&t, &[0.5, 1.2], 2, 1e-5);
2550 }
2551
2552 fn directional_matches_accumulate(tape: &Tape, x: &[f64], n: usize) {
2557 let vars = tape.variables();
2558 let mut hess_map: HashMap<(usize, usize), usize> = HashMap::new();
2559 let mut pairs = Vec::new();
2560 for (ai, &vi) in vars.iter().enumerate() {
2561 for &vj in &vars[..=ai] {
2562 let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
2563 hess_map.entry((r, c)).or_insert_with(|| {
2564 let p = pairs.len();
2565 pairs.push((r, c));
2566 p
2567 });
2568 }
2569 }
2570 let nnz = pairs.len();
2571 let mut ad = vec![0.0; nnz];
2572 tape.hessian_accumulate(x, 1.0, &hess_map, &mut ad);
2573
2574 let nops = tape.ops.len();
2575 let mut vals = vec![0.0; nops];
2576 tape.forward_into(x, &mut vals);
2577 let mut dot = vec![0.0; nops];
2578 let mut adj = vec![0.0; nops];
2579 let mut adj_dot = vec![0.0; nops];
2580
2581 for &j in &vars {
2582 let mut seed = vec![0.0; n];
2583 seed[j] = 1.0;
2584 let mut col = vec![0.0; n];
2585 tape.hessian_directional(
2586 &vals,
2587 &seed,
2588 1.0,
2589 &mut col,
2590 &mut dot,
2591 &mut adj,
2592 &mut adj_dot,
2593 );
2594 for &i in &vars {
2595 let (r, c) = if i >= j { (i, j) } else { (j, i) };
2596 let expect = ad[hess_map[&(r, c)]];
2597 assert!(
2598 (col[i] - expect).abs() < 1e-10,
2599 "directional H[{i},{j}] = {} vs accumulate {}",
2600 col[i],
2601 expect
2602 );
2603 }
2604 }
2605 }
2606
2607 #[test]
2608 fn directional_quadratic_matches_accumulate() {
2609 let e = add(
2611 add(
2612 mul(cnst(3.0), pow(var(0), cnst(2.0))),
2613 mul(mul(cnst(2.0), var(0)), var(1)),
2614 ),
2615 pow(var(1), cnst(2.0)),
2616 );
2617 let t = Tape::build(&e);
2618 directional_matches_accumulate(&t, &[0.5, -0.3], 2);
2619 }
2620
2621 #[test]
2622 fn directional_transcendental_matches_accumulate() {
2623 let e = Expr::Sum(vec![
2624 unary(UnaryOp::Exp, var(0)),
2625 unary(UnaryOp::Sin, var(1)),
2626 unary(UnaryOp::Log, var(0)),
2627 unary(UnaryOp::Sqrt, var(1)),
2628 mul(var(0), var(1)),
2629 ]);
2630 let t = Tape::build(&e);
2631 directional_matches_accumulate(&t, &[1.5, 2.0], 2);
2632 }
2633
2634 #[test]
2635 fn directional_with_division_matches_accumulate() {
2636 let e = add(div(var(0), var(1)), unary(UnaryOp::Cos, var(0)));
2637 let t = Tape::build(&e);
2638 directional_matches_accumulate(&t, &[0.5, 1.2], 2);
2639 }
2640
2641 #[test]
2642 fn hessian_sparsity_separable() {
2643 let e = add(unary(UnaryOp::Sin, var(0)), mul(var(1), var(2)));
2645 let t = Tape::build(&e);
2646 let s = t.hessian_sparsity();
2647 assert!(s.contains(&(0, 0)));
2648 assert!(s.contains(&(2, 1)));
2649 assert!(!s.contains(&(1, 0)));
2650 assert!(!s.contains(&(2, 0)));
2651 }
2652
2653 fn count_op<F: Fn(&TapeOp) -> bool>(t: &Tape, pred: F) -> usize {
2654 t.ops.iter().filter(|o| pred(o)).count()
2655 }
2656
2657 #[test]
2658 fn pow_zero_const_folds_to_one() {
2659 let e = pow(var(0), cnst(0.0));
2661 let t = Tape::build(&e);
2662 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
2663 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Var(_))), 0);
2664 assert!((t.eval(&[7.0]) - 1.0).abs() < 1e-12);
2665 }
2666
2667 #[test]
2668 fn pow_one_passes_through() {
2669 let e = pow(var(0), cnst(1.0));
2671 let t = Tape::build(&e);
2672 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
2673 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Const(_))), 0);
2674 assert!((t.eval(&[3.5]) - 3.5).abs() < 1e-12);
2675 }
2676
2677 #[test]
2678 fn pow_half_lowers_to_sqrt() {
2679 let e = pow(var(0), cnst(0.5));
2680 let t = Tape::build(&e);
2681 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
2682 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Sqrt(_))), 1);
2683 assert!((t.eval(&[16.0]) - 4.0).abs() < 1e-12);
2684 }
2685
2686 #[test]
2687 fn pow_two_lowers_to_single_mul() {
2688 let e = pow(var(0), cnst(2.0));
2689 let t = Tape::build(&e);
2690 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
2691 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 1);
2692 assert!((t.eval(&[3.0]) - 9.0).abs() < 1e-12);
2693 }
2694
2695 #[test]
2696 fn pow_three_lowers_to_two_muls() {
2697 let e = pow(var(0), cnst(3.0));
2698 let t = Tape::build(&e);
2699 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
2700 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 2);
2701 assert!((t.eval(&[2.0]) - 8.0).abs() < 1e-12);
2702 }
2703
2704 #[test]
2705 fn pow_eight_lowers_to_three_muls() {
2706 let e = pow(var(0), cnst(8.0));
2708 let t = Tape::build(&e);
2709 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
2710 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 3);
2711 assert!((t.eval(&[2.0]) - 256.0).abs() < 1e-12);
2712 }
2713
2714 #[test]
2715 fn pow_negative_two_lowers_to_div() {
2716 let e = pow(var(0), cnst(-2.0));
2718 let t = Tape::build(&e);
2719 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
2720 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Div(..))), 1);
2721 assert!((t.eval(&[4.0]) - (1.0 / 16.0)).abs() < 1e-12);
2722 }
2723
2724 #[test]
2725 fn pow_large_const_stays_generic() {
2726 let e = pow(var(0), cnst(9.0));
2728 let t = Tape::build(&e);
2729 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 1);
2730 }
2731
2732 #[test]
2733 fn pow_non_integer_const_stays_generic() {
2734 let e = pow(var(0), cnst(1.5));
2736 let t = Tape::build(&e);
2737 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 1);
2738 }
2739
2740 #[test]
2741 fn pow_const_through_cse_const() {
2742 let two = Rc::new(cnst(2.0));
2744 let e = Expr::Binary(BinOp::Pow, Box::new(var(0)), Box::new(Expr::Cse(two)));
2745 let t = Tape::build(&e);
2746 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
2747 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 1);
2748 }
2749
2750 #[test]
2751 fn hessian_pow_three_matches_fd() {
2752 let e = add(mul(cnst(5.0), pow(var(0), cnst(3.0))), mul(var(0), var(1)));
2754 let t = Tape::build(&e);
2755 fd_check(&t, &[1.7, 0.8], 2, 1e-5);
2756 }
2757
2758 #[test]
2759 fn hessian_pow_negative_matches_fd() {
2760 let e = add(pow(var(0), cnst(-2.0)), pow(var(1), cnst(2.0)));
2762 let t = Tape::build(&e);
2763 fd_check(&t, &[1.3, 2.4], 2, 1e-5);
2764 }
2765
2766 #[test]
2767 fn hessian_pow_half_matches_fd() {
2768 let e = add(pow(var(0), cnst(0.5)), mul(var(0), var(1)));
2770 let t = Tape::build(&e);
2771 fd_check(&t, &[2.5, 1.1], 2, 1e-5);
2772 }
2773
2774 #[test]
2775 fn hessian_sparsity_through_cse() {
2776 let body = Rc::new(add(var(0), var(1)));
2779 let e = add(
2780 pow(Expr::Cse(body.clone()), cnst(2.0)),
2781 Expr::Cse(body.clone()),
2782 );
2783 let t = Tape::build(&e);
2784 let s = t.hessian_sparsity();
2785 assert!(s.contains(&(0, 0)));
2786 assert!(s.contains(&(1, 0)));
2787 assert!(s.contains(&(1, 1)));
2788 assert_eq!(s.len(), 3);
2789 }
2790}