1use std::collections::{BTreeSet, HashMap, HashSet};
26use std::sync::Arc;
27
28use super::nl_external::{EvalResult, ExternalArg, ExternalLibrary, ExternalResolver};
29use super::nl_reader::{BinOp, CmpOp, Expr, FuncallArg, UnaryOp};
30
/// One instruction of a flattened expression tape.
///
/// Apart from the payload of [`TapeOp::Var`], every `usize` is the index of a
/// tape slot whose value the instruction consumes. `Tape::forward` fills the
/// slots in order, so operands must refer to slots that come earlier on the
/// tape than the instruction itself.
#[derive(Debug, Clone)]
pub enum TapeOp {
    /// Literal constant.
    Const(f64),
    /// Input variable `x[i]`.
    Var(usize),
    /// `a + b`.
    Add(usize, usize),
    /// `a - b`.
    Sub(usize, usize),
    /// `a * b`.
    Mul(usize, usize),
    /// `a / b`.
    Div(usize, usize),
    /// `a.powf(b)`: first slot is the base, second the exponent.
    Pow(usize, usize),
    /// `-a`.
    Neg(usize),
    /// `|a|`.
    Abs(usize),
    /// Square root.
    Sqrt(usize),
    /// Natural exponential.
    Exp(usize),
    /// Natural logarithm.
    Log(usize),
    /// Base-10 logarithm.
    Log10(usize),
    /// Sine.
    Sin(usize),
    /// Cosine.
    Cos(usize),
    /// Tangent.
    Tan(usize),
    /// Arctangent.
    Atan(usize),
    /// Arccosine.
    Acos(usize),
    /// Hyperbolic sine.
    Sinh(usize),
    /// Hyperbolic cosine.
    Cosh(usize),
    /// Hyperbolic tangent.
    Tanh(usize),
    /// Arcsine.
    Asin(usize),
    /// Inverse hyperbolic cosine.
    Acosh(usize),
    /// Inverse hyperbolic sine.
    Asinh(usize),
    /// Inverse hyperbolic tangent.
    Atanh(usize),
    /// `a.atan2(b)`: first slot is `y`, second is `x`.
    Atan2(usize, usize),
    /// `f64::min` of the two slots.
    Min(usize, usize),
    /// `f64::max` of the two slots.
    Max(usize, usize),
    /// Comparison of two slots; evaluates to `1.0` when it holds, else `0.0`.
    Cmp(CmpOp, usize, usize),
    /// Logical AND; a slot counts as true when its value is non-zero.
    And(usize, usize),
    /// Logical OR; a slot counts as true when its value is non-zero.
    Or(usize, usize),
    /// Logical NOT; evaluates to `1.0` when the slot is exactly `0.0`.
    Not(usize),
    /// `Select(c, t, e)`: value of slot `t` when slot `c` is non-zero,
    /// otherwise the value of slot `e`.
    Select(usize, usize, usize),
    /// Call into an external library function; boxed so the payload does not
    /// sit inline in every `TapeOp`.
    Funcall(Box<FuncallData>),
}
97
/// Payload of [`TapeOp::Funcall`]: everything needed to invoke one external
/// function during tape evaluation.
#[derive(Debug, Clone)]
pub struct FuncallData {
    /// Library whose `eval` method is invoked.
    pub lib: Arc<ExternalLibrary>,
    /// Function name handed to `lib.eval`.
    pub name: String,
    /// Arguments in call order (tape slots and string literals).
    pub args: Vec<TapeFuncallArg>,
}
108
/// A single argument of an external function call recorded on the tape.
#[derive(Debug, Clone)]
pub enum TapeFuncallArg {
    /// Index of the tape slot whose current value is passed as a real argument.
    Tape(usize),
    /// String literal passed through to the external function unchanged.
    Str(String),
}
117
118#[inline]
121fn cmp_holds(op: CmpOp, a: f64, b: f64) -> bool {
122 match op {
123 CmpOp::Lt => a < b,
124 CmpOp::Le => a <= b,
125 CmpOp::Eq => a == b,
126 CmpOp::Ge => a >= b,
127 CmpOp::Gt => a > b,
128 CmpOp::Ne => a != b,
129 }
130}
131
132fn funcall_to_ext_args<'a>(args: &'a [TapeFuncallArg], vals: &[f64]) -> Vec<ExternalArg<'a>> {
133 args.iter()
134 .map(|a| match a {
135 TapeFuncallArg::Tape(idx) => ExternalArg::Real(vals[*idx]),
136 TapeFuncallArg::Str(s) => ExternalArg::Str(s.as_str()),
137 })
138 .collect()
139}
140
141fn ext_eval_or_nan(
154 lib: &ExternalLibrary,
155 name: &str,
156 call_args: &[ExternalArg<'_>],
157 n_args: usize,
158 want_derivs: bool,
159 want_hes: bool,
160) -> EvalResult {
161 lib.eval(name, call_args, want_derivs, want_hes)
162 .unwrap_or_else(|_| EvalResult {
163 value: f64::NAN,
164 derivs: want_derivs.then(|| vec![f64::NAN; n_args]),
165 hessian: want_hes.then(|| vec![f64::NAN; n_args * (n_args + 1) / 2]),
166 })
167}
168
/// A flattened expression: a sequence of [`TapeOp`] instructions evaluated in
/// order. The value of the whole expression is the value of the last slot.
#[derive(Debug, Clone)]
pub struct Tape {
    /// Instructions in evaluation order; operands refer to earlier entries.
    pub ops: Vec<TapeOp>,
}
175
176impl Tape {
177 pub fn build(expr: &Expr) -> Self {
181 Self::build_with_externals(expr, &ExternalResolver::default())
182 }
183
184 pub fn build_with_externals(expr: &Expr, resolver: &ExternalResolver) -> Self {
189 let mut ops = Vec::new();
190 let mut cache: HashMap<*const Expr, usize> = HashMap::new();
191 build_recursive(expr, &mut ops, &mut cache, resolver);
192 Tape { ops }
193 }
194
195 pub fn forward(&self, x: &[f64]) -> Vec<f64> {
198 let mut vals: Vec<f64> = Vec::with_capacity(self.ops.len());
199 for op in &self.ops {
200 let v = match op {
201 TapeOp::Const(c) => *c,
202 TapeOp::Var(i) => x[*i],
203 TapeOp::Add(a, b) => vals[*a] + vals[*b],
204 TapeOp::Sub(a, b) => vals[*a] - vals[*b],
205 TapeOp::Mul(a, b) => vals[*a] * vals[*b],
206 TapeOp::Div(a, b) => vals[*a] / vals[*b],
207 TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
208 TapeOp::Neg(a) => -vals[*a],
209 TapeOp::Abs(a) => vals[*a].abs(),
210 TapeOp::Sqrt(a) => vals[*a].sqrt(),
211 TapeOp::Exp(a) => vals[*a].exp(),
212 TapeOp::Log(a) => vals[*a].ln(),
213 TapeOp::Log10(a) => vals[*a].log10(),
214 TapeOp::Sin(a) => vals[*a].sin(),
215 TapeOp::Cos(a) => vals[*a].cos(),
216 TapeOp::Tan(a) => vals[*a].tan(),
217 TapeOp::Atan(a) => vals[*a].atan(),
218 TapeOp::Acos(a) => vals[*a].acos(),
219 TapeOp::Sinh(a) => vals[*a].sinh(),
220 TapeOp::Cosh(a) => vals[*a].cosh(),
221 TapeOp::Tanh(a) => vals[*a].tanh(),
222 TapeOp::Asin(a) => vals[*a].asin(),
223 TapeOp::Acosh(a) => vals[*a].acosh(),
224 TapeOp::Asinh(a) => vals[*a].asinh(),
225 TapeOp::Atanh(a) => vals[*a].atanh(),
226 TapeOp::Atan2(a, b) => vals[*a].atan2(vals[*b]),
227 TapeOp::Min(a, b) => vals[*a].min(vals[*b]),
228 TapeOp::Max(a, b) => vals[*a].max(vals[*b]),
229 TapeOp::Cmp(op, a, b) => f64::from(cmp_holds(*op, vals[*a], vals[*b])),
230 TapeOp::And(a, b) => f64::from(vals[*a] != 0.0 && vals[*b] != 0.0),
231 TapeOp::Or(a, b) => f64::from(vals[*a] != 0.0 || vals[*b] != 0.0),
232 TapeOp::Not(a) => f64::from(vals[*a] == 0.0),
233 TapeOp::Select(c, t, e) => {
234 if vals[*c] != 0.0 {
235 vals[*t]
236 } else {
237 vals[*e]
238 }
239 }
240 TapeOp::Funcall(fc) => {
241 let FuncallData { lib, name, args } = fc.as_ref();
242 let call_args = funcall_to_ext_args(args, &vals);
243 let res = ext_eval_or_nan(lib, name, &call_args, args.len(), false, false);
244 res.value
245 }
246 };
247 vals.push(v);
248 }
249 vals
250 }
251
252 pub fn eval(&self, x: &[f64]) -> f64 {
253 let vals = self.forward(x);
254 *vals.last().unwrap_or(&0.0)
255 }
256
257 pub fn gradient_seed(&self, x: &[f64], seed: f64, grad: &mut [f64]) {
262 if seed == 0.0 || self.ops.is_empty() {
263 return;
264 }
265 let vals = self.forward(x);
266 self.reverse(&vals, seed, grad);
267 }
268
269 pub fn gradient_seed_into(
281 &self,
282 x: &[f64],
283 seed: f64,
284 grad: &mut [f64],
285 vals: &mut [f64],
286 adj: &mut [f64],
287 ) {
288 if seed == 0.0 || self.ops.is_empty() {
289 return;
290 }
291 debug_assert!(vals.len() >= self.ops.len());
292 self.forward_into(x, vals);
293 self.reverse_into(vals, seed, grad, adj);
294 }
295
296 fn reverse(&self, vals: &[f64], seed: f64, grad: &mut [f64]) {
297 let n = self.ops.len();
298 let mut adj = vec![0.0f64; n];
299 self.reverse_into(vals, seed, grad, &mut adj);
300 }
301
302 fn reverse_into(&self, vals: &[f64], seed: f64, grad: &mut [f64], adj: &mut [f64]) {
307 let n = self.ops.len();
308 debug_assert!(adj.len() >= n);
309 adj[..n].fill(0.0);
310 adj[n - 1] = seed;
311
312 for i in (0..n).rev() {
313 let a = adj[i];
314 if a == 0.0 {
315 continue;
316 }
317 match &self.ops[i] {
318 TapeOp::Const(_) => {}
319 TapeOp::Var(j) => {
320 grad[*j] += a;
321 }
322 TapeOp::Add(l, r) => {
323 adj[*l] += a;
324 adj[*r] += a;
325 }
326 TapeOp::Sub(l, r) => {
327 adj[*l] += a;
328 adj[*r] -= a;
329 }
330 TapeOp::Mul(l, r) => {
331 adj[*l] += a * vals[*r];
332 adj[*r] += a * vals[*l];
333 }
334 TapeOp::Div(l, r) => {
335 let rv = vals[*r];
336 adj[*l] += a / rv;
337 adj[*r] -= a * vals[*l] / (rv * rv);
338 }
339 TapeOp::Pow(l, r) => {
340 let lv = vals[*l];
341 let rv = vals[*r];
342 if rv != 0.0 {
343 adj[*l] += a * rv * lv.powf(rv - 1.0);
344 }
345 if lv > 0.0 {
346 adj[*r] += a * vals[i] * lv.ln();
347 }
348 }
349 TapeOp::Neg(j) => {
350 adj[*j] -= a;
351 }
352 TapeOp::Abs(j) => {
353 if vals[*j] >= 0.0 {
354 adj[*j] += a;
355 } else {
356 adj[*j] -= a;
357 }
358 }
359 TapeOp::Sqrt(j) => {
360 let sv = vals[i];
361 if sv > 0.0 {
362 adj[*j] += a * 0.5 / sv;
363 }
364 }
365 TapeOp::Exp(j) => {
366 adj[*j] += a * vals[i];
367 }
368 TapeOp::Log(j) => {
369 adj[*j] += a / vals[*j];
370 }
371 TapeOp::Log10(j) => {
372 adj[*j] += a / (vals[*j] * std::f64::consts::LN_10);
373 }
374 TapeOp::Sin(j) => {
375 adj[*j] += a * vals[*j].cos();
376 }
377 TapeOp::Cos(j) => {
378 adj[*j] -= a * vals[*j].sin();
379 }
380 TapeOp::Tan(j) => {
381 let t = vals[i];
382 adj[*j] += a * (1.0 + t * t);
383 }
384 TapeOp::Atan(j) => {
385 let u = vals[*j];
386 adj[*j] += a / (1.0 + u * u);
387 }
388 TapeOp::Acos(j) => {
389 let u = vals[*j];
390 adj[*j] -= a / (1.0 - u * u).sqrt();
391 }
392 TapeOp::Sinh(j) => {
393 adj[*j] += a * vals[*j].cosh();
394 }
395 TapeOp::Cosh(j) => {
396 adj[*j] += a * vals[*j].sinh();
397 }
398 TapeOp::Tanh(j) => {
399 let t = vals[i];
400 adj[*j] += a * (1.0 - t * t);
401 }
402 TapeOp::Asin(j) => {
403 let u = vals[*j];
404 adj[*j] += a / (1.0 - u * u).sqrt();
405 }
406 TapeOp::Acosh(j) => {
407 let u = vals[*j];
408 adj[*j] += a / (u * u - 1.0).sqrt();
409 }
410 TapeOp::Asinh(j) => {
411 let u = vals[*j];
412 adj[*j] += a / (u * u + 1.0).sqrt();
413 }
414 TapeOp::Atanh(j) => {
415 let u = vals[*j];
416 adj[*j] += a / (1.0 - u * u);
417 }
418 TapeOp::Atan2(l, r) => {
419 let y = vals[*l];
420 let x = vals[*r];
421 let d = y * y + x * x;
422 adj[*l] += a * (x / d);
423 adj[*r] += a * (-y / d);
424 }
425 TapeOp::Min(l, r) => {
429 if vals[*l] <= vals[*r] {
430 adj[*l] += a;
431 } else {
432 adj[*r] += a;
433 }
434 }
435 TapeOp::Max(l, r) => {
436 if vals[*l] >= vals[*r] {
437 adj[*l] += a;
438 } else {
439 adj[*r] += a;
440 }
441 }
442 TapeOp::Cmp(_, _, _) | TapeOp::And(_, _) | TapeOp::Or(_, _) | TapeOp::Not(_) => {}
445 TapeOp::Select(c, t, e) => {
448 if vals[*c] != 0.0 {
449 adj[*t] += a;
450 } else {
451 adj[*e] += a;
452 }
453 }
454 TapeOp::Funcall(fc) => {
455 let FuncallData { lib, name, args } = fc.as_ref();
456 let call_args = funcall_to_ext_args(args, vals);
457 let res = ext_eval_or_nan(lib, name, &call_args, args.len(), true, false);
458 let derivs = res.derivs.expect("want_derivs=true returns derivs");
459 let mut k = 0usize;
460 for arg in args {
461 if let TapeFuncallArg::Tape(idx) = arg {
462 adj[*idx] += a * derivs[k];
463 k += 1;
464 }
465 }
466 }
467 }
468 }
469 }
470
471 pub fn variables(&self) -> Vec<usize> {
473 let mut s: BTreeSet<usize> = BTreeSet::new();
474 for op in &self.ops {
475 if let TapeOp::Var(j) = op {
476 s.insert(*j);
477 }
478 }
479 s.into_iter().collect()
480 }
481
482 fn forward_tangent(&self, vals: &[f64], seed_var: usize, dot: &mut [f64]) {
487 let n = self.ops.len();
488 debug_assert_eq!(dot.len(), n);
489 for i in 0..n {
490 dot[i] = match &self.ops[i] {
491 TapeOp::Const(_) => 0.0,
492 TapeOp::Var(k) => {
493 if *k == seed_var {
494 1.0
495 } else {
496 0.0
497 }
498 }
499 TapeOp::Add(a, b) => dot[*a] + dot[*b],
500 TapeOp::Sub(a, b) => dot[*a] - dot[*b],
501 TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
502 TapeOp::Div(a, b) => {
503 let vb = vals[*b];
504 (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
505 }
506 TapeOp::Pow(a, b) => {
507 let u = vals[*a];
508 let r = vals[*b];
509 let du = dot[*a];
510 let dr = dot[*b];
511 let mut result = 0.0;
512 if r != 0.0 {
517 result += r * u.powf(r - 1.0) * du;
518 }
519 if u > 0.0 {
520 result += vals[i] * u.ln() * dr;
521 }
522 result
523 }
524 TapeOp::Neg(a) => -dot[*a],
525 TapeOp::Abs(a) => {
526 if vals[*a] >= 0.0 {
527 dot[*a]
528 } else {
529 -dot[*a]
530 }
531 }
532 TapeOp::Sqrt(a) => {
533 let sv = vals[i];
534 if sv > 0.0 {
535 dot[*a] * 0.5 / sv
536 } else {
537 0.0
538 }
539 }
540 TapeOp::Exp(a) => dot[*a] * vals[i],
541 TapeOp::Log(a) => dot[*a] / vals[*a],
542 TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
543 TapeOp::Sin(a) => dot[*a] * vals[*a].cos(),
544 TapeOp::Cos(a) => -dot[*a] * vals[*a].sin(),
545 TapeOp::Tan(a) => {
546 let t = vals[i];
547 dot[*a] * (1.0 + t * t)
548 }
549 TapeOp::Atan(a) => {
550 let u = vals[*a];
551 dot[*a] / (1.0 + u * u)
552 }
553 TapeOp::Acos(a) => {
554 let u = vals[*a];
555 -dot[*a] / (1.0 - u * u).sqrt()
556 }
557 TapeOp::Sinh(a) => dot[*a] * vals[*a].cosh(),
558 TapeOp::Cosh(a) => dot[*a] * vals[*a].sinh(),
559 TapeOp::Tanh(a) => {
560 let t = vals[i];
561 dot[*a] * (1.0 - t * t)
562 }
563 TapeOp::Asin(a) => {
564 let u = vals[*a];
565 dot[*a] / (1.0 - u * u).sqrt()
566 }
567 TapeOp::Acosh(a) => {
568 let u = vals[*a];
569 dot[*a] / (u * u - 1.0).sqrt()
570 }
571 TapeOp::Asinh(a) => {
572 let u = vals[*a];
573 dot[*a] / (u * u + 1.0).sqrt()
574 }
575 TapeOp::Atanh(a) => {
576 let u = vals[*a];
577 dot[*a] / (1.0 - u * u)
578 }
579 TapeOp::Atan2(a, b) => {
580 let y = vals[*a];
581 let x = vals[*b];
582 let d = y * y + x * x;
583 (x * dot[*a] - y * dot[*b]) / d
584 }
585 TapeOp::Min(a, b) => {
587 if vals[*a] <= vals[*b] {
588 dot[*a]
589 } else {
590 dot[*b]
591 }
592 }
593 TapeOp::Max(a, b) => {
594 if vals[*a] >= vals[*b] {
595 dot[*a]
596 } else {
597 dot[*b]
598 }
599 }
600 TapeOp::Cmp(_, _, _) | TapeOp::And(_, _) | TapeOp::Or(_, _) | TapeOp::Not(_) => 0.0,
601 TapeOp::Select(c, t, e) => {
602 if vals[*c] != 0.0 {
603 dot[*t]
604 } else {
605 dot[*e]
606 }
607 }
608 TapeOp::Funcall(fc) => {
609 let FuncallData { lib, name, args } = fc.as_ref();
610 let call_args = funcall_to_ext_args(args, vals);
611 let res = ext_eval_or_nan(lib, name, &call_args, args.len(), true, false);
612 let derivs = res.derivs.expect("want_derivs=true returns derivs");
613 let mut acc = 0.0;
614 let mut k = 0usize;
615 for arg in args {
616 if let TapeFuncallArg::Tape(idx) = arg {
617 acc += derivs[k] * dot[*idx];
618 k += 1;
619 }
620 }
621 acc
622 }
623 };
624 }
625 }
626
627 pub fn forward_into(&self, x: &[f64], vals: &mut [f64]) {
631 let n = self.ops.len();
632 debug_assert!(vals.len() >= n);
633 for i in 0..n {
634 vals[i] = match &self.ops[i] {
635 TapeOp::Const(c) => *c,
636 TapeOp::Var(j) => x[*j],
637 TapeOp::Add(a, b) => vals[*a] + vals[*b],
638 TapeOp::Sub(a, b) => vals[*a] - vals[*b],
639 TapeOp::Mul(a, b) => vals[*a] * vals[*b],
640 TapeOp::Div(a, b) => vals[*a] / vals[*b],
641 TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
642 TapeOp::Neg(a) => -vals[*a],
643 TapeOp::Abs(a) => vals[*a].abs(),
644 TapeOp::Sqrt(a) => vals[*a].sqrt(),
645 TapeOp::Exp(a) => vals[*a].exp(),
646 TapeOp::Log(a) => vals[*a].ln(),
647 TapeOp::Log10(a) => vals[*a].log10(),
648 TapeOp::Sin(a) => vals[*a].sin(),
649 TapeOp::Cos(a) => vals[*a].cos(),
650 TapeOp::Tan(a) => vals[*a].tan(),
651 TapeOp::Atan(a) => vals[*a].atan(),
652 TapeOp::Acos(a) => vals[*a].acos(),
653 TapeOp::Sinh(a) => vals[*a].sinh(),
654 TapeOp::Cosh(a) => vals[*a].cosh(),
655 TapeOp::Tanh(a) => vals[*a].tanh(),
656 TapeOp::Asin(a) => vals[*a].asin(),
657 TapeOp::Acosh(a) => vals[*a].acosh(),
658 TapeOp::Asinh(a) => vals[*a].asinh(),
659 TapeOp::Atanh(a) => vals[*a].atanh(),
660 TapeOp::Atan2(a, b) => vals[*a].atan2(vals[*b]),
661 TapeOp::Min(a, b) => vals[*a].min(vals[*b]),
662 TapeOp::Max(a, b) => vals[*a].max(vals[*b]),
663 TapeOp::Cmp(op, a, b) => f64::from(cmp_holds(*op, vals[*a], vals[*b])),
664 TapeOp::And(a, b) => f64::from(vals[*a] != 0.0 && vals[*b] != 0.0),
665 TapeOp::Or(a, b) => f64::from(vals[*a] != 0.0 || vals[*b] != 0.0),
666 TapeOp::Not(a) => f64::from(vals[*a] == 0.0),
667 TapeOp::Select(c, t, e) => {
668 if vals[*c] != 0.0 {
669 vals[*t]
670 } else {
671 vals[*e]
672 }
673 }
674 TapeOp::Funcall(fc) => {
675 let FuncallData { lib, name, args } = fc.as_ref();
676 let call_args = funcall_to_ext_args(args, &*vals);
677 let res = ext_eval_or_nan(lib, name, &call_args, args.len(), false, false);
678 res.value
679 }
680 };
681 }
682 }
683
684 pub fn hessian_directional(
701 &self,
702 vals: &[f64],
703 seed: &[f64],
704 weight: f64,
705 out: &mut [f64],
706 dot: &mut [f64],
707 adj: &mut [f64],
708 adj_dot: &mut [f64],
709 ) {
710 let n = self.ops.len();
711 if n == 0 || weight == 0.0 {
712 return;
713 }
714 debug_assert!(vals.len() >= n);
715 debug_assert!(dot.len() >= n);
716 debug_assert!(adj.len() >= n);
717 debug_assert!(adj_dot.len() >= n);
718
719 for i in 0..n {
723 dot[i] = match &self.ops[i] {
724 TapeOp::Const(_) => 0.0,
725 TapeOp::Var(k) => seed[*k],
726 TapeOp::Add(a, b) => dot[*a] + dot[*b],
727 TapeOp::Sub(a, b) => dot[*a] - dot[*b],
728 TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
729 TapeOp::Div(a, b) => {
730 let vb = vals[*b];
731 (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
732 }
733 TapeOp::Pow(a, b) => {
734 let u = vals[*a];
735 let r = vals[*b];
736 let du = dot[*a];
737 let dr = dot[*b];
738 let mut result = 0.0;
739 if r != 0.0 {
744 result += r * u.powf(r - 1.0) * du;
745 }
746 if u > 0.0 {
747 result += vals[i] * u.ln() * dr;
748 }
749 result
750 }
751 TapeOp::Neg(a) => -dot[*a],
752 TapeOp::Abs(a) => {
753 if vals[*a] >= 0.0 {
754 dot[*a]
755 } else {
756 -dot[*a]
757 }
758 }
759 TapeOp::Sqrt(a) => {
760 let sv = vals[i];
761 if sv > 0.0 {
762 dot[*a] * 0.5 / sv
763 } else {
764 0.0
765 }
766 }
767 TapeOp::Exp(a) => vals[i] * dot[*a],
768 TapeOp::Log(a) => dot[*a] / vals[*a],
769 TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
770 TapeOp::Sin(a) => vals[*a].cos() * dot[*a],
771 TapeOp::Cos(a) => -vals[*a].sin() * dot[*a],
772 TapeOp::Tan(a) => {
773 let t = vals[i];
774 (1.0 + t * t) * dot[*a]
775 }
776 TapeOp::Atan(a) => {
777 let u = vals[*a];
778 dot[*a] / (1.0 + u * u)
779 }
780 TapeOp::Acos(a) => {
781 let u = vals[*a];
782 -dot[*a] / (1.0 - u * u).sqrt()
783 }
784 TapeOp::Sinh(a) => dot[*a] * vals[*a].cosh(),
785 TapeOp::Cosh(a) => dot[*a] * vals[*a].sinh(),
786 TapeOp::Tanh(a) => {
787 let t = vals[i];
788 (1.0 - t * t) * dot[*a]
789 }
790 TapeOp::Asin(a) => {
791 let u = vals[*a];
792 dot[*a] / (1.0 - u * u).sqrt()
793 }
794 TapeOp::Acosh(a) => {
795 let u = vals[*a];
796 dot[*a] / (u * u - 1.0).sqrt()
797 }
798 TapeOp::Asinh(a) => {
799 let u = vals[*a];
800 dot[*a] / (u * u + 1.0).sqrt()
801 }
802 TapeOp::Atanh(a) => {
803 let u = vals[*a];
804 dot[*a] / (1.0 - u * u)
805 }
806 TapeOp::Atan2(a, b) => {
807 let y = vals[*a];
808 let x = vals[*b];
809 let d = y * y + x * x;
810 (x * dot[*a] - y * dot[*b]) / d
811 }
812 TapeOp::Min(a, b) => {
814 if vals[*a] <= vals[*b] {
815 dot[*a]
816 } else {
817 dot[*b]
818 }
819 }
820 TapeOp::Max(a, b) => {
821 if vals[*a] >= vals[*b] {
822 dot[*a]
823 } else {
824 dot[*b]
825 }
826 }
827 TapeOp::Cmp(_, _, _) | TapeOp::And(_, _) | TapeOp::Or(_, _) | TapeOp::Not(_) => 0.0,
828 TapeOp::Select(c, t, e) => {
829 if vals[*c] != 0.0 {
830 dot[*t]
831 } else {
832 dot[*e]
833 }
834 }
835 TapeOp::Funcall(fc) => {
836 let FuncallData { lib, name, args } = fc.as_ref();
837 let call_args = funcall_to_ext_args(args, vals);
838 let res = ext_eval_or_nan(lib, name, &call_args, args.len(), true, false);
839 let derivs = res.derivs.expect("want_derivs=true returns derivs");
840 let mut acc = 0.0;
841 let mut k = 0usize;
842 for arg in args {
843 if let TapeFuncallArg::Tape(idx) = arg {
844 acc += derivs[k] * dot[*idx];
845 k += 1;
846 }
847 }
848 acc
849 }
850 };
851 }
852
853 for slot in adj.iter_mut().take(n) {
857 *slot = 0.0;
858 }
859 for slot in adj_dot.iter_mut().take(n) {
860 *slot = 0.0;
861 }
862 adj[n - 1] = 1.0;
863
864 for i in (0..n).rev() {
865 let w = adj[i];
866 let wd = adj_dot[i];
867 if w == 0.0 && wd == 0.0 {
868 continue;
869 }
870 match &self.ops[i] {
871 TapeOp::Const(_) => {}
872 TapeOp::Var(k) => {
873 if wd != 0.0 {
874 out[*k] += weight * wd;
875 }
876 }
877 TapeOp::Add(a, b) => {
878 adj[*a] += w;
879 adj[*b] += w;
880 adj_dot[*a] += wd;
881 adj_dot[*b] += wd;
882 }
883 TapeOp::Sub(a, b) => {
884 adj[*a] += w;
885 adj[*b] -= w;
886 adj_dot[*a] += wd;
887 adj_dot[*b] -= wd;
888 }
889 TapeOp::Mul(a, b) => {
890 adj[*a] += w * vals[*b];
891 adj[*b] += w * vals[*a];
892 adj_dot[*a] += wd * vals[*b] + w * dot[*b];
893 adj_dot[*b] += wd * vals[*a] + w * dot[*a];
894 }
895 TapeOp::Div(a, b) => {
896 let vb = vals[*b];
897 let vb2 = vb * vb;
898 let vb3 = vb2 * vb;
899 adj[*a] += w / vb;
900 adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
901 adj[*b] += w * (-vals[*a] / vb2);
902 adj_dot[*b] += wd * (-vals[*a] / vb2)
903 + w * (-dot[*a] / vb2 + 2.0 * vals[*a] * dot[*b] / vb3);
904 }
905 TapeOp::Pow(a, b) => {
906 let u = vals[*a];
907 let r = vals[*b];
908 let du = dot[*a];
909 let dr = dot[*b];
910 if r != 0.0 {
911 if u != 0.0 {
912 let p_a = r * u.powf(r - 1.0);
913 adj[*a] += w * p_a;
914 let mut dp_a = dr * u.powf(r - 1.0);
915 if u > 0.0 {
916 dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
917 } else {
918 dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
919 }
920 adj_dot[*a] += wd * p_a + w * dp_a;
921 } else if r >= 2.0 {
922 let p_a = 0.0;
923 adj[*a] += w * p_a;
924 let dp_a = if r == 2.0 {
925 2.0 * du
926 } else {
927 r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
928 };
929 adj_dot[*a] += wd * p_a + w * dp_a;
930 }
931 }
932 if u > 0.0 {
933 let ln_u = u.ln();
934 let p_b = vals[i] * ln_u;
935 adj[*b] += w * p_b;
936 let dur = vals[i] * (r * du / u + dr * ln_u);
937 let dp_b = dur * ln_u + vals[i] * du / u;
938 adj_dot[*b] += wd * p_b + w * dp_b;
939 }
940 }
941 TapeOp::Neg(a) => {
942 adj[*a] -= w;
943 adj_dot[*a] -= wd;
944 }
945 TapeOp::Abs(a) => {
946 let s = if vals[*a] >= 0.0 { 1.0 } else { -1.0 };
947 adj[*a] += w * s;
948 adj_dot[*a] += wd * s;
949 }
950 TapeOp::Sqrt(a) => {
951 let sv = vals[i];
952 if sv > 0.0 {
953 let fp = 0.5 / sv;
954 let fpp = -0.25 / (vals[*a] * sv);
955 adj[*a] += w * fp;
956 adj_dot[*a] += wd * fp + w * fpp * dot[*a];
957 }
958 }
959 TapeOp::Exp(a) => {
960 let ev = vals[i];
961 adj[*a] += w * ev;
962 adj_dot[*a] += wd * ev + w * ev * dot[*a];
963 }
964 TapeOp::Log(a) => {
965 let u = vals[*a];
966 adj[*a] += w / u;
967 adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
968 }
969 TapeOp::Log10(a) => {
970 let u = vals[*a];
971 let c = std::f64::consts::LN_10;
972 adj[*a] += w / (u * c);
973 adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
974 }
975 TapeOp::Sin(a) => {
976 let u = vals[*a];
977 let cu = u.cos();
978 adj[*a] += w * cu;
979 adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
980 }
981 TapeOp::Cos(a) => {
982 let u = vals[*a];
983 let su = u.sin();
984 adj[*a] -= w * su;
985 adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
986 }
987 TapeOp::Tan(a) => {
988 let t = vals[i];
989 let gp = 1.0 + t * t;
990 let gpp = 2.0 * t * gp;
991 adj[*a] += w * gp;
992 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
993 }
994 TapeOp::Atan(a) => {
995 let u = vals[*a];
996 let d = 1.0 + u * u;
997 let gp = 1.0 / d;
998 let gpp = -2.0 * u / (d * d);
999 adj[*a] += w * gp;
1000 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1001 }
1002 TapeOp::Acos(a) => {
1003 let u = vals[*a];
1004 let s = 1.0 - u * u;
1005 let r = s.sqrt();
1006 let gp = -1.0 / r;
1007 let gpp = -u / (s * r);
1008 adj[*a] += w * gp;
1009 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1010 }
1011 TapeOp::Sinh(a) => {
1012 let u = vals[*a];
1013 let gp = u.cosh();
1014 let gpp = u.sinh();
1015 adj[*a] += w * gp;
1016 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1017 }
1018 TapeOp::Cosh(a) => {
1019 let u = vals[*a];
1020 let gp = u.sinh();
1021 let gpp = u.cosh();
1022 adj[*a] += w * gp;
1023 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1024 }
1025 TapeOp::Tanh(a) => {
1026 let t = vals[i];
1027 let gp = 1.0 - t * t;
1028 let gpp = -2.0 * t * gp;
1029 adj[*a] += w * gp;
1030 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1031 }
1032 TapeOp::Asin(a) => {
1033 let u = vals[*a];
1034 let s = 1.0 - u * u;
1035 let r = s.sqrt();
1036 let gp = 1.0 / r;
1037 let gpp = u / (s * r);
1038 adj[*a] += w * gp;
1039 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1040 }
1041 TapeOp::Acosh(a) => {
1042 let u = vals[*a];
1043 let s = u * u - 1.0;
1044 let r = s.sqrt();
1045 let gp = 1.0 / r;
1046 let gpp = -u / (s * r);
1047 adj[*a] += w * gp;
1048 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1049 }
1050 TapeOp::Asinh(a) => {
1051 let u = vals[*a];
1052 let s = u * u + 1.0;
1053 let r = s.sqrt();
1054 let gp = 1.0 / r;
1055 let gpp = -u / (s * r);
1056 adj[*a] += w * gp;
1057 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1058 }
1059 TapeOp::Atanh(a) => {
1060 let u = vals[*a];
1061 let d = 1.0 - u * u;
1062 let gp = 1.0 / d;
1063 let gpp = 2.0 * u / (d * d);
1064 adj[*a] += w * gp;
1065 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1066 }
1067 TapeOp::Atan2(a, b) => {
1068 let y = vals[*a];
1069 let x = vals[*b];
1070 let d = y * y + x * x;
1071 let d2 = d * d;
1072 let fa = x / d;
1073 let fb = -y / d;
1074 let faa = -2.0 * y * x / d2;
1075 let fab = (y * y - x * x) / d2;
1076 let fbb = 2.0 * y * x / d2;
1077 adj[*a] += w * fa;
1078 adj[*b] += w * fb;
1079 adj_dot[*a] += wd * fa + w * (faa * dot[*a] + fab * dot[*b]);
1080 adj_dot[*b] += wd * fb + w * (fab * dot[*a] + fbb * dot[*b]);
1081 }
1082 TapeOp::Min(a, b) => {
1086 let br = if vals[*a] <= vals[*b] { *a } else { *b };
1087 adj[br] += w;
1088 adj_dot[br] += wd;
1089 }
1090 TapeOp::Max(a, b) => {
1091 let br = if vals[*a] >= vals[*b] { *a } else { *b };
1092 adj[br] += w;
1093 adj_dot[br] += wd;
1094 }
1095 TapeOp::Cmp(_, _, _) | TapeOp::And(_, _) | TapeOp::Or(_, _) | TapeOp::Not(_) => {}
1097 TapeOp::Select(c, t, e) => {
1100 let br = if vals[*c] != 0.0 { *t } else { *e };
1101 adj[br] += w;
1102 adj_dot[br] += wd;
1103 }
1104 TapeOp::Funcall(fc) => {
1105 let FuncallData { lib, name, args } = fc.as_ref();
1106 let call_args = funcall_to_ext_args(args, vals);
1107 let res = ext_eval_or_nan(lib, name, &call_args, args.len(), true, true);
1108 let derivs = res.derivs.expect("want_derivs=true returns derivs");
1109 let hes = res.hessian.expect("want_hes=true returns hessian");
1110 let real_tape: Vec<usize> = args
1111 .iter()
1112 .filter_map(|a| match a {
1113 TapeFuncallArg::Tape(t) => Some(*t),
1114 TapeFuncallArg::Str(_) => None,
1115 })
1116 .collect();
1117 for (k, &tk) in real_tape.iter().enumerate() {
1118 adj[tk] += w * derivs[k];
1119 let mut second_term = 0.0;
1120 for (l, &tl) in real_tape.iter().enumerate() {
1121 let (lo, hi) = if k <= l { (k, l) } else { (l, k) };
1122 let h_kl = hes[lo + hi * (hi + 1) / 2];
1123 second_term += h_kl * dot[tl];
1124 }
1125 adj_dot[tk] += wd * derivs[k] + w * second_term;
1126 }
1127 }
1128 }
1129 }
1130 }
1131
1132 pub fn hessian_accumulate(
1139 &self,
1140 x: &[f64],
1141 weight: f64,
1142 hess_map: &HashMap<(usize, usize), usize>,
1143 values: &mut [f64],
1144 ) {
1145 let n = self.ops.len();
1146 if n == 0 || weight == 0.0 {
1147 return;
1148 }
1149 let v = self.forward(x);
1150 let var_indices = self.variables();
1151
1152 let mut dot = vec![0.0f64; n];
1159 let mut adj = vec![0.0f64; n];
1160 let mut adj_dot = vec![0.0f64; n];
1161 for &j in &var_indices {
1162 self.forward_tangent(&v, j, &mut dot);
1163
1164 adj.fill(0.0);
1167 adj_dot.fill(0.0);
1168 adj[n - 1] = 1.0;
1169
1170 for i in (0..n).rev() {
1171 let w = adj[i];
1172 let wd = adj_dot[i];
1173 if w == 0.0 && wd == 0.0 {
1174 continue;
1175 }
1176 match &self.ops[i] {
1177 TapeOp::Const(_) => {}
1178 TapeOp::Var(k) => {
1179 if wd != 0.0 && *k >= j {
1182 if let Some(&pos) = hess_map.get(&(*k, j)) {
1183 values[pos] += weight * wd;
1184 }
1185 }
1186 }
1187 TapeOp::Add(a, b) => {
1188 adj[*a] += w;
1189 adj[*b] += w;
1190 adj_dot[*a] += wd;
1191 adj_dot[*b] += wd;
1192 }
1193 TapeOp::Sub(a, b) => {
1194 adj[*a] += w;
1195 adj[*b] -= w;
1196 adj_dot[*a] += wd;
1197 adj_dot[*b] -= wd;
1198 }
1199 TapeOp::Mul(a, b) => {
1200 adj[*a] += w * v[*b];
1201 adj[*b] += w * v[*a];
1202 adj_dot[*a] += wd * v[*b] + w * dot[*b];
1203 adj_dot[*b] += wd * v[*a] + w * dot[*a];
1204 }
1205 TapeOp::Div(a, b) => {
1206 let vb = v[*b];
1207 let vb2 = vb * vb;
1208 let vb3 = vb2 * vb;
1209 adj[*a] += w / vb;
1210 adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
1211 adj[*b] += w * (-v[*a] / vb2);
1212 adj_dot[*b] += wd * (-v[*a] / vb2)
1213 + w * (-dot[*a] / vb2 + 2.0 * v[*a] * dot[*b] / vb3);
1214 }
1215 TapeOp::Pow(a, b) => {
1216 let u = v[*a];
1217 let r = v[*b];
1218 let du = dot[*a];
1219 let dr = dot[*b];
1220 if r != 0.0 {
1221 if u != 0.0 {
1222 let p_a = r * u.powf(r - 1.0);
1223 adj[*a] += w * p_a;
1224 let mut dp_a = dr * u.powf(r - 1.0);
1225 if u > 0.0 {
1226 dp_a +=
1227 r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
1228 } else {
1229 dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
1230 }
1231 adj_dot[*a] += wd * p_a + w * dp_a;
1232 } else if r >= 2.0 {
1233 let p_a = 0.0;
1234 adj[*a] += w * p_a;
1235 let dp_a = if r == 2.0 {
1236 2.0 * du
1237 } else {
1238 r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
1239 };
1240 adj_dot[*a] += wd * p_a + w * dp_a;
1241 }
1242 }
1243 if u > 0.0 {
1244 let ln_u = u.ln();
1245 let p_b = v[i] * ln_u;
1246 adj[*b] += w * p_b;
1247 let dur = v[i] * (r * du / u + dr * ln_u);
1248 let dp_b = dur * ln_u + v[i] * du / u;
1249 adj_dot[*b] += wd * p_b + w * dp_b;
1250 }
1251 }
1252 TapeOp::Neg(a) => {
1253 adj[*a] -= w;
1254 adj_dot[*a] -= wd;
1255 }
1256 TapeOp::Abs(a) => {
1257 let s = if v[*a] >= 0.0 { 1.0 } else { -1.0 };
1258 adj[*a] += w * s;
1259 adj_dot[*a] += wd * s;
1260 }
1261 TapeOp::Sqrt(a) => {
1262 let sv = v[i];
1263 if sv > 0.0 {
1264 let fp = 0.5 / sv;
1265 let fpp = -0.25 / (v[*a] * sv);
1266 adj[*a] += w * fp;
1267 adj_dot[*a] += wd * fp + w * fpp * dot[*a];
1268 }
1269 }
1270 TapeOp::Exp(a) => {
1271 let ev = v[i];
1272 adj[*a] += w * ev;
1273 adj_dot[*a] += wd * ev + w * ev * dot[*a];
1274 }
1275 TapeOp::Log(a) => {
1276 let u = v[*a];
1277 adj[*a] += w / u;
1278 adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
1279 }
1280 TapeOp::Log10(a) => {
1281 let u = v[*a];
1282 let c = std::f64::consts::LN_10;
1283 adj[*a] += w / (u * c);
1284 adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
1285 }
1286 TapeOp::Sin(a) => {
1287 let u = v[*a];
1288 let cu = u.cos();
1289 adj[*a] += w * cu;
1290 adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
1291 }
1292 TapeOp::Cos(a) => {
1293 let u = v[*a];
1294 let su = u.sin();
1295 adj[*a] -= w * su;
1296 adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
1297 }
1298 TapeOp::Tan(a) => {
1299 let t = v[i];
1300 let gp = 1.0 + t * t;
1301 let gpp = 2.0 * t * gp;
1302 adj[*a] += w * gp;
1303 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1304 }
1305 TapeOp::Atan(a) => {
1306 let u = v[*a];
1307 let d = 1.0 + u * u;
1308 let gp = 1.0 / d;
1309 let gpp = -2.0 * u / (d * d);
1310 adj[*a] += w * gp;
1311 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1312 }
1313 TapeOp::Acos(a) => {
1314 let u = v[*a];
1315 let s = 1.0 - u * u;
1316 let r = s.sqrt();
1317 let gp = -1.0 / r;
1318 let gpp = -u / (s * r);
1319 adj[*a] += w * gp;
1320 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1321 }
1322 TapeOp::Sinh(a) => {
1323 let u = v[*a];
1324 let gp = u.cosh();
1325 let gpp = u.sinh();
1326 adj[*a] += w * gp;
1327 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1328 }
1329 TapeOp::Cosh(a) => {
1330 let u = v[*a];
1331 let gp = u.sinh();
1332 let gpp = u.cosh();
1333 adj[*a] += w * gp;
1334 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1335 }
1336 TapeOp::Tanh(a) => {
1337 let t = v[i];
1338 let gp = 1.0 - t * t;
1339 let gpp = -2.0 * t * gp;
1340 adj[*a] += w * gp;
1341 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1342 }
1343 TapeOp::Asin(a) => {
1344 let u = v[*a];
1345 let s = 1.0 - u * u;
1346 let r = s.sqrt();
1347 let gp = 1.0 / r;
1348 let gpp = u / (s * r);
1349 adj[*a] += w * gp;
1350 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1351 }
1352 TapeOp::Acosh(a) => {
1353 let u = v[*a];
1354 let s = u * u - 1.0;
1355 let r = s.sqrt();
1356 let gp = 1.0 / r;
1357 let gpp = -u / (s * r);
1358 adj[*a] += w * gp;
1359 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1360 }
1361 TapeOp::Asinh(a) => {
1362 let u = v[*a];
1363 let s = u * u + 1.0;
1364 let r = s.sqrt();
1365 let gp = 1.0 / r;
1366 let gpp = -u / (s * r);
1367 adj[*a] += w * gp;
1368 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1369 }
1370 TapeOp::Atanh(a) => {
1371 let u = v[*a];
1372 let d = 1.0 - u * u;
1373 let gp = 1.0 / d;
1374 let gpp = 2.0 * u / (d * d);
1375 adj[*a] += w * gp;
1376 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1377 }
1378 TapeOp::Atan2(a, b) => {
1379 let y = v[*a];
1380 let x = v[*b];
1381 let d = y * y + x * x;
1382 let d2 = d * d;
1383 let fa = x / d;
1384 let fb = -y / d;
1385 let faa = -2.0 * y * x / d2;
1386 let fab = (y * y - x * x) / d2;
1387 let fbb = 2.0 * y * x / d2;
1388 adj[*a] += w * fa;
1389 adj[*b] += w * fb;
1390 adj_dot[*a] += wd * fa + w * (faa * dot[*a] + fab * dot[*b]);
1391 adj_dot[*b] += wd * fb + w * (fab * dot[*a] + fbb * dot[*b]);
1392 }
1393 TapeOp::Min(a, b) => {
1397 let br = if v[*a] <= v[*b] { *a } else { *b };
1398 adj[br] += w;
1399 adj_dot[br] += wd;
1400 }
1401 TapeOp::Max(a, b) => {
1402 let br = if v[*a] >= v[*b] { *a } else { *b };
1403 adj[br] += w;
1404 adj_dot[br] += wd;
1405 }
1406 TapeOp::Cmp(_, _, _)
1408 | TapeOp::And(_, _)
1409 | TapeOp::Or(_, _)
1410 | TapeOp::Not(_) => {}
1411 TapeOp::Select(c, t, e) => {
1414 let br = if v[*c] != 0.0 { *t } else { *e };
1415 adj[br] += w;
1416 adj_dot[br] += wd;
1417 }
1418 TapeOp::Funcall(fc) => {
1419 let FuncallData { lib, name, args } = fc.as_ref();
1420 let call_args = funcall_to_ext_args(args, &v);
1421 let res = ext_eval_or_nan(lib, name, &call_args, args.len(), true, true);
1422 let derivs = res.derivs.expect("want_derivs=true returns derivs");
1423 let hes = res.hessian.expect("want_hes=true returns hessian");
1424 let real_tape: Vec<usize> = args
1425 .iter()
1426 .filter_map(|a| match a {
1427 TapeFuncallArg::Tape(t) => Some(*t),
1428 TapeFuncallArg::Str(_) => None,
1429 })
1430 .collect();
1431 for (k, &tk) in real_tape.iter().enumerate() {
1432 adj[tk] += w * derivs[k];
1433 let mut second_term = 0.0;
1434 for (l, &tl) in real_tape.iter().enumerate() {
1435 let (lo, hi) = if k <= l { (k, l) } else { (l, k) };
1436 let h_kl = hes[lo + hi * (hi + 1) / 2];
1437 second_term += h_kl * dot[tl];
1438 }
1439 adj_dot[tk] += wd * derivs[k] + w * second_term;
1440 }
1441 }
1442 }
1443 }
1444 }
1445 }
1446
1447 pub fn hessian_sparsity(&self) -> BTreeSet<(usize, usize)> {
1452 let n = self.ops.len();
1453 let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(n);
1454 let mut pairs: BTreeSet<(usize, usize)> = BTreeSet::new();
1455
1456 let emit_cross =
1457 |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
1458 for &v1 in s1 {
1459 for &v2 in s2 {
1460 let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
1461 pairs.insert((r, c));
1462 }
1463 }
1464 };
1465 let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
1466 let vars: Vec<usize> = s.iter().copied().collect();
1467 for (ai, &vi) in vars.iter().enumerate() {
1468 for &vj in &vars[..=ai] {
1469 let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
1470 pairs.insert((r, c));
1471 }
1472 }
1473 };
1474
1475 for op in &self.ops {
1476 let vset = match op {
1477 TapeOp::Const(_) => BTreeSet::new(),
1478 TapeOp::Var(j) => {
1479 let mut s = BTreeSet::new();
1480 s.insert(*j);
1481 s
1482 }
1483 TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
1484 var_sets[*a].union(&var_sets[*b]).copied().collect()
1485 }
1486 TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
1487 TapeOp::Mul(a, b) => {
1488 emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
1489 var_sets[*a].union(&var_sets[*b]).copied().collect()
1490 }
1491 TapeOp::Div(a, b) => {
1492 emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
1493 emit_self(&var_sets[*b], &mut pairs);
1494 var_sets[*a].union(&var_sets[*b]).copied().collect()
1495 }
1496 TapeOp::Pow(a, b) => {
1497 let combined: BTreeSet<usize> =
1498 var_sets[*a].union(&var_sets[*b]).copied().collect();
1499 emit_self(&combined, &mut pairs);
1500 combined
1501 }
1502 TapeOp::Sqrt(a)
1503 | TapeOp::Exp(a)
1504 | TapeOp::Log(a)
1505 | TapeOp::Log10(a)
1506 | TapeOp::Sin(a)
1507 | TapeOp::Cos(a)
1508 | TapeOp::Tan(a)
1509 | TapeOp::Atan(a)
1510 | TapeOp::Acos(a)
1511 | TapeOp::Sinh(a)
1512 | TapeOp::Cosh(a)
1513 | TapeOp::Tanh(a)
1514 | TapeOp::Asin(a)
1515 | TapeOp::Acosh(a)
1516 | TapeOp::Asinh(a)
1517 | TapeOp::Atanh(a) => {
1518 emit_self(&var_sets[*a], &mut pairs);
1519 var_sets[*a].clone()
1520 }
1521 TapeOp::Atan2(a, b) => {
1525 let combined: BTreeSet<usize> =
1526 var_sets[*a].union(&var_sets[*b]).copied().collect();
1527 emit_self(&combined, &mut pairs);
1528 combined
1529 }
1530 TapeOp::Cmp(_, _, _) | TapeOp::And(_, _) | TapeOp::Or(_, _) | TapeOp::Not(_) => {
1536 BTreeSet::new()
1537 }
1538 TapeOp::Select(_c, t, e) => var_sets[*t].union(&var_sets[*e]).copied().collect(),
1545 TapeOp::Min(a, b) | TapeOp::Max(a, b) => {
1552 var_sets[*a].union(&var_sets[*b]).copied().collect()
1553 }
1554 TapeOp::Funcall(fc) => {
1555 let args = &fc.args;
1556 let mut combined: BTreeSet<usize> = BTreeSet::new();
1557 for arg in args {
1558 if let TapeFuncallArg::Tape(t) = arg {
1559 for &vv in &var_sets[*t] {
1560 combined.insert(vv);
1561 }
1562 }
1563 }
1564 emit_self(&combined, &mut pairs);
1565 combined
1566 }
1567 };
1568 var_sets.push(vset);
1569 }
1570 pairs
1571 }
1572}
1573
1574fn build_recursive(
1575 expr: &Expr,
1576 ops: &mut Vec<TapeOp>,
1577 cache: &mut HashMap<*const Expr, usize>,
1578 resolver: &ExternalResolver,
1579) -> usize {
1580 match expr {
1581 Expr::Const(c) => {
1582 let idx = ops.len();
1583 ops.push(TapeOp::Const(*c));
1584 idx
1585 }
1586 Expr::Var(i) => {
1587 let idx = ops.len();
1588 ops.push(TapeOp::Var(*i));
1589 idx
1590 }
1591 Expr::Binary(op, a, b) => {
1592 if let BinOp::Pow = op {
1600 if let Some(c) = peek_const(b) {
1601 if let Some(idx) = try_emit_const_pow(a, c, ops, cache, resolver) {
1602 return idx;
1603 }
1604 }
1605 }
1606 let l = build_recursive(a, ops, cache, resolver);
1607 let r = build_recursive(b, ops, cache, resolver);
1608 let idx = ops.len();
1609 ops.push(match op {
1610 BinOp::Add => TapeOp::Add(l, r),
1611 BinOp::Sub => TapeOp::Sub(l, r),
1612 BinOp::Mul => TapeOp::Mul(l, r),
1613 BinOp::Div => TapeOp::Div(l, r),
1614 BinOp::Pow => TapeOp::Pow(l, r),
1615 BinOp::Atan2 => TapeOp::Atan2(l, r),
1616 });
1617 idx
1618 }
1619 Expr::Unary(op, a) => {
1620 let v = build_recursive(a, ops, cache, resolver);
1621 let idx = ops.len();
1622 ops.push(match op {
1623 UnaryOp::Neg => TapeOp::Neg(v),
1624 UnaryOp::Sqrt => TapeOp::Sqrt(v),
1625 UnaryOp::Log => TapeOp::Log(v),
1626 UnaryOp::Log10 => TapeOp::Log10(v),
1627 UnaryOp::Exp => TapeOp::Exp(v),
1628 UnaryOp::Abs => TapeOp::Abs(v),
1629 UnaryOp::Sin => TapeOp::Sin(v),
1630 UnaryOp::Cos => TapeOp::Cos(v),
1631 UnaryOp::Tan => TapeOp::Tan(v),
1632 UnaryOp::Atan => TapeOp::Atan(v),
1633 UnaryOp::Acos => TapeOp::Acos(v),
1634 UnaryOp::Sinh => TapeOp::Sinh(v),
1635 UnaryOp::Cosh => TapeOp::Cosh(v),
1636 UnaryOp::Tanh => TapeOp::Tanh(v),
1637 UnaryOp::Asin => TapeOp::Asin(v),
1638 UnaryOp::Acosh => TapeOp::Acosh(v),
1639 UnaryOp::Asinh => TapeOp::Asinh(v),
1640 UnaryOp::Atanh => TapeOp::Atanh(v),
1641 });
1642 idx
1643 }
1644 Expr::Sum(args) => {
1645 if args.is_empty() {
1646 let idx = ops.len();
1647 ops.push(TapeOp::Const(0.0));
1648 return idx;
1649 }
1650 let mut acc = build_recursive(&args[0], ops, cache, resolver);
1651 for a in &args[1..] {
1652 let next = build_recursive(a, ops, cache, resolver);
1653 let idx = ops.len();
1654 ops.push(TapeOp::Add(acc, next));
1655 acc = idx;
1656 }
1657 acc
1658 }
1659 Expr::MinList(args) | Expr::MaxList(args) => {
1667 let is_min = matches!(expr, Expr::MinList(_));
1668 if args.is_empty() {
1669 let idx = ops.len();
1670 ops.push(TapeOp::Const(0.0));
1671 return idx;
1672 }
1673 let mut acc = build_recursive(&args[0], ops, cache, resolver);
1674 for a in &args[1..] {
1675 let next = build_recursive(a, ops, cache, resolver);
1676 let idx = ops.len();
1677 ops.push(if is_min {
1678 TapeOp::Min(acc, next)
1679 } else {
1680 TapeOp::Max(acc, next)
1681 });
1682 acc = idx;
1683 }
1684 acc
1685 }
1686 Expr::Cse(body) => {
1687 let key = Arc::as_ptr(body) as *const Expr;
1694 if let Some(&idx) = cache.get(&key) {
1695 idx
1696 } else {
1697 let idx = build_recursive(body, ops, cache, resolver);
1698 cache.insert(key, idx);
1699 idx
1700 }
1701 }
1702 Expr::Compare(op, a, b) => {
1703 let l = build_recursive(a, ops, cache, resolver);
1704 let r = build_recursive(b, ops, cache, resolver);
1705 let idx = ops.len();
1706 ops.push(TapeOp::Cmp(*op, l, r));
1707 idx
1708 }
1709 Expr::And(a, b) => {
1710 let l = build_recursive(a, ops, cache, resolver);
1711 let r = build_recursive(b, ops, cache, resolver);
1712 let idx = ops.len();
1713 ops.push(TapeOp::And(l, r));
1714 idx
1715 }
1716 Expr::Or(a, b) => {
1717 let l = build_recursive(a, ops, cache, resolver);
1718 let r = build_recursive(b, ops, cache, resolver);
1719 let idx = ops.len();
1720 ops.push(TapeOp::Or(l, r));
1721 idx
1722 }
1723 Expr::Not(a) => {
1724 let v = build_recursive(a, ops, cache, resolver);
1725 let idx = ops.len();
1726 ops.push(TapeOp::Not(v));
1727 idx
1728 }
1729 Expr::Cond { cond, then_, else_ } => {
1730 let c = build_recursive(cond, ops, cache, resolver);
1731 let t = build_recursive(then_, ops, cache, resolver);
1732 let e = build_recursive(else_, ops, cache, resolver);
1733 let idx = ops.len();
1734 ops.push(TapeOp::Select(c, t, e));
1735 idx
1736 }
1737 Expr::Funcall { id, args } => {
1738 let (lib, name) = resolver
1739 .funcs_by_id
1740 .get(id)
1741 .unwrap_or_else(|| panic!("unresolved AMPL funcall id {id}"));
1742 let tape_args: Vec<TapeFuncallArg> = args
1743 .iter()
1744 .map(|a| match a {
1745 FuncallArg::Real(e) => {
1746 TapeFuncallArg::Tape(build_recursive(e, ops, cache, resolver))
1747 }
1748 FuncallArg::Str(s) => TapeFuncallArg::Str(s.clone()),
1749 })
1750 .collect();
1751 let idx = ops.len();
1752 ops.push(TapeOp::Funcall(Box::new(FuncallData {
1753 lib: Arc::clone(lib),
1754 name: name.clone(),
1755 args: tape_args,
1756 })));
1757 idx
1758 }
1759 }
1760}
1761
1762fn peek_const(e: &Expr) -> Option<f64> {
1766 match e {
1767 Expr::Const(c) => Some(*c),
1768 Expr::Cse(body) => peek_const(body),
1769 _ => None,
1770 }
1771}
1772
1773fn try_emit_const_pow(
1781 base_expr: &Expr,
1782 c: f64,
1783 ops: &mut Vec<TapeOp>,
1784 cache: &mut HashMap<*const Expr, usize>,
1785 resolver: &ExternalResolver,
1786) -> Option<usize> {
1787 if c == 0.0 {
1788 let idx = ops.len();
1789 ops.push(TapeOp::Const(1.0));
1790 return Some(idx);
1791 }
1792 if c == 1.0 {
1793 return Some(build_recursive(base_expr, ops, cache, resolver));
1794 }
1795 if c == 0.5 {
1796 let b = build_recursive(base_expr, ops, cache, resolver);
1797 let idx = ops.len();
1798 ops.push(TapeOp::Sqrt(b));
1799 return Some(idx);
1800 }
1801 if c.is_finite() && c.fract() == 0.0 && c.abs() <= 8.0 {
1806 let n = c.abs() as u32;
1807 if n == 0 {
1808 let idx = ops.len();
1810 ops.push(TapeOp::Const(1.0));
1811 return Some(idx);
1812 }
1813 let b = build_recursive(base_expr, ops, cache, resolver);
1814 let pos = emit_int_pow(b, n, ops);
1815 if c < 0.0 {
1816 let one_idx = ops.len();
1819 ops.push(TapeOp::Const(1.0));
1820 let idx = ops.len();
1821 ops.push(TapeOp::Div(one_idx, pos));
1822 return Some(idx);
1823 }
1824 return Some(pos);
1825 }
1826 None
1827}
1828
1829fn emit_int_pow(base: usize, n: u32, ops: &mut Vec<TapeOp>) -> usize {
1833 debug_assert!(n >= 1);
1834 if n == 1 {
1835 return base;
1836 }
1837 let half = emit_int_pow(base, n / 2, ops);
1838 let squared = ops.len();
1839 ops.push(TapeOp::Mul(half, half));
1840 if n % 2 == 1 {
1841 let idx = ops.len();
1842 ops.push(TapeOp::Mul(squared, base));
1843 idx
1844 } else {
1845 squared
1846 }
1847}
1848
1849#[derive(Debug, Clone)]
1877pub enum SummandOp {
1878 Local(TapeOp),
1881 Shared(usize),
1885}
1886
1887#[derive(Debug, Clone)]
1888pub struct Summand {
1889 pub ops: Vec<SummandOp>,
1890 pub root_slot: usize,
1892 pub local_reach: Vec<usize>,
1894 pub prelude_reach: Vec<usize>,
1897 pub local_vars: Vec<usize>,
1899 pub prelude_vars: Vec<usize>,
1901 pub all_vars: Vec<usize>,
1903}
1904
1905#[derive(Debug)]
1906pub struct HybridTape {
1907 pub prelude: Vec<TapeOp>,
1912 pub summands: Vec<Summand>,
1913}
1914
1915impl HybridTape {
1916 pub fn build_multi(exprs: &[Expr]) -> Self {
1921 let mut cse_count: HashMap<*const Expr, usize> = HashMap::new();
1925 for e in exprs {
1926 let mut seen_in_root: HashSet<*const Expr> = HashSet::new();
1927 count_cse_appearances(e, &mut seen_in_root, &mut cse_count);
1928 }
1929
1930 let mut prelude: Vec<TapeOp> = Vec::new();
1935 let mut prelude_map: HashMap<*const Expr, usize> = HashMap::new();
1936 let mut summands: Vec<Summand> = Vec::with_capacity(exprs.len());
1937 for e in exprs {
1938 let mut local: Vec<SummandOp> = Vec::new();
1939 let mut local_cache: HashMap<*const Expr, usize> = HashMap::new();
1940 let root_slot = build_into_summand(
1941 e,
1942 &mut local,
1943 &mut local_cache,
1944 &mut prelude,
1945 &mut prelude_map,
1946 &cse_count,
1947 );
1948 summands.push(Summand {
1949 ops: local,
1950 root_slot,
1951 local_reach: Vec::new(),
1952 prelude_reach: Vec::new(),
1953 local_vars: Vec::new(),
1954 prelude_vars: Vec::new(),
1955 all_vars: Vec::new(),
1956 });
1957 }
1958
1959 let mut p_visited: Vec<u32> = vec![0; prelude.len()];
1963 let mut p_epoch: u32 = 0;
1964 let mut p_stack: Vec<usize> = Vec::new();
1965 for s in &mut summands {
1966 let (local_reach, shared_refs) = compute_local_reach(&s.ops, s.root_slot);
1967 s.local_reach = local_reach;
1968
1969 let mut lv: BTreeSet<usize> = BTreeSet::new();
1970 for &i in &s.local_reach {
1971 if let SummandOp::Local(TapeOp::Var(j)) = &s.ops[i] {
1972 lv.insert(*j);
1973 }
1974 }
1975 s.local_vars = lv.iter().copied().collect();
1976
1977 if !shared_refs.is_empty() {
1978 p_epoch += 1;
1979 let mut preach: Vec<usize> = Vec::new();
1980 for &start in &shared_refs {
1981 bfs_prelude(
1982 &prelude,
1983 start,
1984 &mut p_visited,
1985 p_epoch,
1986 &mut p_stack,
1987 &mut preach,
1988 );
1989 }
1990 preach.sort_unstable();
1991 s.prelude_vars = vars_in(&prelude, &preach);
1992 s.prelude_reach = preach;
1993 }
1994
1995 let mut av: BTreeSet<usize> = lv;
1996 for &v in &s.prelude_vars {
1997 av.insert(v);
1998 }
1999 s.all_vars = av.into_iter().collect();
2000 }
2001
2002 HybridTape { prelude, summands }
2003 }
2004
2005 pub fn n_prelude_ops(&self) -> usize {
2006 self.prelude.len()
2007 }
2008 pub fn n_summands(&self) -> usize {
2009 self.summands.len()
2010 }
2011 pub fn max_summand_ops(&self) -> usize {
2012 self.summands.iter().map(|s| s.ops.len()).max().unwrap_or(0)
2013 }
2014 pub fn total_local_ops(&self) -> usize {
2015 self.summands.iter().map(|s| s.ops.len()).sum()
2016 }
2017
2018 pub fn forward_prelude(&self, x: &[f64], prelude_vals: &mut [f64]) {
2021 debug_assert_eq!(prelude_vals.len(), self.prelude.len());
2022 for i in 0..self.prelude.len() {
2023 prelude_vals[i] = fwd_step(&self.prelude[i], x, prelude_vals);
2024 }
2025 }
2026
2027 pub fn forward_summand(
2030 &self,
2031 s: &Summand,
2032 x: &[f64],
2033 prelude_vals: &[f64],
2034 local_vals: &mut [f64],
2035 ) {
2036 debug_assert!(local_vals.len() >= s.ops.len());
2037 for i in 0..s.ops.len() {
2038 local_vals[i] = match &s.ops[i] {
2039 SummandOp::Local(op) => fwd_step(op, x, local_vals),
2040 SummandOp::Shared(k) => prelude_vals[*k],
2041 };
2042 }
2043 }
2044
2045 #[inline]
2047 pub fn root_value(&self, s: &Summand, local_vals: &[f64]) -> f64 {
2048 local_vals[s.root_slot]
2049 }
2050
2051 #[allow(clippy::too_many_arguments)]
2058 pub fn gradient_summand(
2059 &self,
2060 s: &Summand,
2061 prelude_vals: &[f64],
2062 local_vals: &[f64],
2063 seed: f64,
2064 grad: &mut [f64],
2065 local_adj: &mut [f64],
2066 prelude_adj: &mut [f64],
2067 ) {
2068 if seed == 0.0 || s.local_reach.is_empty() {
2069 return;
2070 }
2071 for &i in &s.local_reach {
2072 local_adj[i] = 0.0;
2073 }
2074 for &i in &s.prelude_reach {
2075 prelude_adj[i] = 0.0;
2076 }
2077 local_adj[s.root_slot] = seed;
2078 for &i in s.local_reach.iter().rev() {
2079 let a = local_adj[i];
2080 if a == 0.0 {
2081 continue;
2082 }
2083 match &s.ops[i] {
2084 SummandOp::Local(op) => rev_step(op, i, local_vals, local_adj, a, grad),
2085 SummandOp::Shared(k) => {
2086 prelude_adj[*k] += a;
2087 }
2088 }
2089 }
2090 for &i in s.prelude_reach.iter().rev() {
2091 let a = prelude_adj[i];
2092 if a == 0.0 {
2093 continue;
2094 }
2095 rev_step(&self.prelude[i], i, prelude_vals, prelude_adj, a, grad);
2096 }
2097 }
2098
2099 #[allow(clippy::too_many_arguments)]
2107 pub fn hessian_summand(
2108 &self,
2109 s: &Summand,
2110 prelude_vals: &[f64],
2111 local_vals: &[f64],
2112 weight: f64,
2113 hess_map: &HashMap<(usize, usize), usize>,
2114 values: &mut [f64],
2115 local_dot: &mut [f64],
2116 local_adj: &mut [f64],
2117 local_adj_dot: &mut [f64],
2118 prelude_dot: &mut [f64],
2119 prelude_adj: &mut [f64],
2120 prelude_adj_dot: &mut [f64],
2121 ) {
2122 if weight == 0.0 || s.local_reach.is_empty() {
2123 return;
2124 }
2125 for &j in &s.all_vars {
2126 for &i in &s.local_reach {
2127 local_dot[i] = 0.0;
2128 local_adj[i] = 0.0;
2129 local_adj_dot[i] = 0.0;
2130 }
2131 for &i in &s.prelude_reach {
2132 prelude_dot[i] = 0.0;
2133 prelude_adj[i] = 0.0;
2134 prelude_adj_dot[i] = 0.0;
2135 }
2136 for &i in &s.prelude_reach {
2137 prelude_dot[i] = fwd_tan_step(&self.prelude[i], j, prelude_vals, prelude_dot, i);
2138 }
2139 for &i in &s.local_reach {
2140 local_dot[i] = match &s.ops[i] {
2141 SummandOp::Local(op) => fwd_tan_step(op, j, local_vals, local_dot, i),
2142 SummandOp::Shared(k) => prelude_dot[*k],
2143 };
2144 }
2145 local_adj[s.root_slot] = 1.0;
2146 for &i in s.local_reach.iter().rev() {
2147 let w = local_adj[i];
2148 let wd = local_adj_dot[i];
2149 if w == 0.0 && wd == 0.0 {
2150 continue;
2151 }
2152 match &s.ops[i] {
2153 SummandOp::Local(op) => {
2154 ror_step(
2155 op,
2156 i,
2157 j,
2158 local_vals,
2159 local_dot,
2160 local_adj,
2161 local_adj_dot,
2162 w,
2163 wd,
2164 weight,
2165 hess_map,
2166 values,
2167 );
2168 }
2169 SummandOp::Shared(k) => {
2170 prelude_adj[*k] += w;
2171 prelude_adj_dot[*k] += wd;
2172 }
2173 }
2174 }
2175 for &i in s.prelude_reach.iter().rev() {
2176 let w = prelude_adj[i];
2177 let wd = prelude_adj_dot[i];
2178 if w == 0.0 && wd == 0.0 {
2179 continue;
2180 }
2181 ror_step(
2182 &self.prelude[i],
2183 i,
2184 j,
2185 prelude_vals,
2186 prelude_dot,
2187 prelude_adj,
2188 prelude_adj_dot,
2189 w,
2190 wd,
2191 weight,
2192 hess_map,
2193 values,
2194 );
2195 }
2196 }
2197 }
2198
2199 pub fn hessian_sparsity_all(&self) -> BTreeSet<(usize, usize)> {
2202 let mut pairs = hessian_sparsity_impl(&self.prelude);
2203
2204 let prelude_var_sets = compute_var_sets(&self.prelude);
2207
2208 for s in &self.summands {
2209 summand_sparsity(&s.ops, &prelude_var_sets, &mut pairs);
2210 }
2211 pairs
2212 }
2213}
2214
2215fn cse_contains_funcall(expr: &Expr) -> bool {
2231 match expr {
2232 Expr::Funcall { .. } => true,
2233 Expr::Const(_) | Expr::Var(_) => false,
2234 Expr::Binary(_, a, b) => cse_contains_funcall(a) || cse_contains_funcall(b),
2235 Expr::Unary(_, a) => cse_contains_funcall(a),
2236 Expr::Sum(args) | Expr::MinList(args) | Expr::MaxList(args) => {
2237 args.iter().any(cse_contains_funcall)
2238 }
2239 Expr::Compare(_, a, b) | Expr::And(a, b) | Expr::Or(a, b) => {
2240 cse_contains_funcall(a) || cse_contains_funcall(b)
2241 }
2242 Expr::Not(a) => cse_contains_funcall(a),
2243 Expr::Cond { cond, then_, else_ } => {
2244 cse_contains_funcall(cond) || cse_contains_funcall(then_) || cse_contains_funcall(else_)
2245 }
2246 Expr::Cse(body) => cse_contains_funcall(body),
2247 }
2248}
2249
2250fn count_cse_appearances(
2251 e: &Expr,
2252 seen_in_root: &mut HashSet<*const Expr>,
2253 counts: &mut HashMap<*const Expr, usize>,
2254) {
2255 match e {
2256 Expr::Const(_) | Expr::Var(_) => {}
2257 Expr::Binary(_, a, b) => {
2258 count_cse_appearances(a, seen_in_root, counts);
2259 count_cse_appearances(b, seen_in_root, counts);
2260 }
2261 Expr::Unary(_, a) => count_cse_appearances(a, seen_in_root, counts),
2262 Expr::Sum(args) | Expr::MinList(args) | Expr::MaxList(args) => {
2263 for a in args {
2264 count_cse_appearances(a, seen_in_root, counts);
2265 }
2266 }
2267 Expr::Compare(_, a, b) | Expr::And(a, b) | Expr::Or(a, b) => {
2268 count_cse_appearances(a, seen_in_root, counts);
2269 count_cse_appearances(b, seen_in_root, counts);
2270 }
2271 Expr::Not(a) => count_cse_appearances(a, seen_in_root, counts),
2272 Expr::Cond { cond, then_, else_ } => {
2273 count_cse_appearances(cond, seen_in_root, counts);
2274 count_cse_appearances(then_, seen_in_root, counts);
2275 count_cse_appearances(else_, seen_in_root, counts);
2276 }
2277 Expr::Cse(body) => {
2278 let key = Arc::as_ptr(body) as *const Expr;
2279 if seen_in_root.insert(key) {
2280 *counts.entry(key).or_insert(0) += 1;
2281 count_cse_appearances(body, seen_in_root, counts);
2282 }
2283 }
2284 Expr::Funcall { args, .. } => {
2285 for arg in args {
2286 if let FuncallArg::Real(e) = arg {
2287 count_cse_appearances(e, seen_in_root, counts);
2288 }
2289 }
2290 }
2291 }
2292}
2293
2294fn build_into_summand(
2300 expr: &Expr,
2301 local: &mut Vec<SummandOp>,
2302 local_cache: &mut HashMap<*const Expr, usize>,
2303 prelude: &mut Vec<TapeOp>,
2304 prelude_map: &mut HashMap<*const Expr, usize>,
2305 cse_count: &HashMap<*const Expr, usize>,
2306) -> usize {
2307 match expr {
2308 Expr::Const(c) => {
2309 let i = local.len();
2310 local.push(SummandOp::Local(TapeOp::Const(*c)));
2311 i
2312 }
2313 Expr::Var(j) => {
2314 let i = local.len();
2315 local.push(SummandOp::Local(TapeOp::Var(*j)));
2316 i
2317 }
2318 Expr::Binary(op, a, b) => {
2319 if let BinOp::Pow = op {
2320 if let Some(c) = peek_const(b) {
2321 if let Some(i) = try_emit_const_pow_summand(
2322 a,
2323 c,
2324 local,
2325 local_cache,
2326 prelude,
2327 prelude_map,
2328 cse_count,
2329 ) {
2330 return i;
2331 }
2332 }
2333 }
2334 let l = build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
2335 let r = build_into_summand(b, local, local_cache, prelude, prelude_map, cse_count);
2336 let i = local.len();
2337 local.push(SummandOp::Local(match op {
2338 BinOp::Add => TapeOp::Add(l, r),
2339 BinOp::Sub => TapeOp::Sub(l, r),
2340 BinOp::Mul => TapeOp::Mul(l, r),
2341 BinOp::Div => TapeOp::Div(l, r),
2342 BinOp::Pow => TapeOp::Pow(l, r),
2343 BinOp::Atan2 => TapeOp::Atan2(l, r),
2344 }));
2345 i
2346 }
2347 Expr::Unary(op, a) => {
2348 let v = build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
2349 let i = local.len();
2350 local.push(SummandOp::Local(match op {
2351 UnaryOp::Neg => TapeOp::Neg(v),
2352 UnaryOp::Sqrt => TapeOp::Sqrt(v),
2353 UnaryOp::Log => TapeOp::Log(v),
2354 UnaryOp::Log10 => TapeOp::Log10(v),
2355 UnaryOp::Exp => TapeOp::Exp(v),
2356 UnaryOp::Abs => TapeOp::Abs(v),
2357 UnaryOp::Sin => TapeOp::Sin(v),
2358 UnaryOp::Cos => TapeOp::Cos(v),
2359 UnaryOp::Tan => TapeOp::Tan(v),
2360 UnaryOp::Atan => TapeOp::Atan(v),
2361 UnaryOp::Acos => TapeOp::Acos(v),
2362 UnaryOp::Sinh => TapeOp::Sinh(v),
2363 UnaryOp::Cosh => TapeOp::Cosh(v),
2364 UnaryOp::Tanh => TapeOp::Tanh(v),
2365 UnaryOp::Asin => TapeOp::Asin(v),
2366 UnaryOp::Acosh => TapeOp::Acosh(v),
2367 UnaryOp::Asinh => TapeOp::Asinh(v),
2368 UnaryOp::Atanh => TapeOp::Atanh(v),
2369 }));
2370 i
2371 }
2372 Expr::Sum(args) => {
2373 if args.is_empty() {
2374 let i = local.len();
2375 local.push(SummandOp::Local(TapeOp::Const(0.0)));
2376 return i;
2377 }
2378 let mut acc = build_into_summand(
2379 &args[0],
2380 local,
2381 local_cache,
2382 prelude,
2383 prelude_map,
2384 cse_count,
2385 );
2386 for a in &args[1..] {
2387 let nxt =
2388 build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
2389 let i = local.len();
2390 local.push(SummandOp::Local(TapeOp::Add(acc, nxt)));
2391 acc = i;
2392 }
2393 acc
2394 }
2395 Expr::Cse(body) => {
2396 let key = Arc::as_ptr(body) as *const Expr;
2397 if let Some(&li) = local_cache.get(&key) {
2398 return li;
2399 }
2400 let promoted = cse_count.get(&key).copied().unwrap_or(0) >= 2;
2401 if promoted {
2402 if cse_contains_funcall(body) {
2409 panic!(
2410 "HybridTape: AMPL external function calls are not supported on the \
2411 hybrid (partial-separability) tape path. Build with \
2412 Tape::build_with_externals instead."
2413 );
2414 }
2415 let pslot =
2420 build_recursive(expr, prelude, prelude_map, &ExternalResolver::default());
2421 let li = local.len();
2422 local.push(SummandOp::Shared(pslot));
2423 local_cache.insert(key, li);
2424 li
2425 } else {
2426 let li =
2427 build_into_summand(body, local, local_cache, prelude, prelude_map, cse_count);
2428 local_cache.insert(key, li);
2429 li
2430 }
2431 }
2432 Expr::Compare(_, _, _)
2433 | Expr::And(_, _)
2434 | Expr::Or(_, _)
2435 | Expr::Not(_)
2436 | Expr::Cond { .. }
2437 | Expr::MinList(_)
2438 | Expr::MaxList(_) => {
2439 panic!(
2440 "HybridTape: conditional / logical / min-max opcodes (comparisons, \
2441 AND/OR/NOT, if-then-else, min/max lists) are not supported on the \
2442 hybrid (partial-separability) tape path. Build with \
2443 Tape::build_with_externals instead."
2444 );
2445 }
2446 Expr::Funcall { .. } => {
2447 panic!(
2448 "HybridTape: AMPL external function calls are not supported on the \
2449 hybrid (partial-separability) tape path. Build with Tape::build_with_externals \
2450 instead."
2451 );
2452 }
2453 }
2454}
2455
2456fn try_emit_const_pow_summand(
2459 base_expr: &Expr,
2460 c: f64,
2461 local: &mut Vec<SummandOp>,
2462 local_cache: &mut HashMap<*const Expr, usize>,
2463 prelude: &mut Vec<TapeOp>,
2464 prelude_map: &mut HashMap<*const Expr, usize>,
2465 cse_count: &HashMap<*const Expr, usize>,
2466) -> Option<usize> {
2467 if c == 0.0 {
2468 let i = local.len();
2469 local.push(SummandOp::Local(TapeOp::Const(1.0)));
2470 return Some(i);
2471 }
2472 if c == 1.0 {
2473 return Some(build_into_summand(
2474 base_expr,
2475 local,
2476 local_cache,
2477 prelude,
2478 prelude_map,
2479 cse_count,
2480 ));
2481 }
2482 if c == 0.5 {
2483 let b = build_into_summand(
2484 base_expr,
2485 local,
2486 local_cache,
2487 prelude,
2488 prelude_map,
2489 cse_count,
2490 );
2491 let i = local.len();
2492 local.push(SummandOp::Local(TapeOp::Sqrt(b)));
2493 return Some(i);
2494 }
2495 if c.is_finite() && c.fract() == 0.0 && c.abs() <= 8.0 {
2496 let n = c.abs() as u32;
2497 if n == 0 {
2498 let i = local.len();
2499 local.push(SummandOp::Local(TapeOp::Const(1.0)));
2500 return Some(i);
2501 }
2502 let b = build_into_summand(
2503 base_expr,
2504 local,
2505 local_cache,
2506 prelude,
2507 prelude_map,
2508 cse_count,
2509 );
2510 let pos = emit_int_pow_summand(b, n, local);
2511 if c < 0.0 {
2512 let one_idx = local.len();
2513 local.push(SummandOp::Local(TapeOp::Const(1.0)));
2514 let i = local.len();
2515 local.push(SummandOp::Local(TapeOp::Div(one_idx, pos)));
2516 return Some(i);
2517 }
2518 return Some(pos);
2519 }
2520 None
2521}
2522
2523fn emit_int_pow_summand(base: usize, n: u32, local: &mut Vec<SummandOp>) -> usize {
2524 debug_assert!(n >= 1);
2525 if n == 1 {
2526 return base;
2527 }
2528 let half = emit_int_pow_summand(base, n / 2, local);
2529 let squared = local.len();
2530 local.push(SummandOp::Local(TapeOp::Mul(half, half)));
2531 if n % 2 == 1 {
2532 let i = local.len();
2533 local.push(SummandOp::Local(TapeOp::Mul(squared, base)));
2534 i
2535 } else {
2536 squared
2537 }
2538}
2539
2540fn compute_local_reach(ops: &[SummandOp], root: usize) -> (Vec<usize>, Vec<usize>) {
2544 let mut visited = vec![false; ops.len()];
2545 let mut reach: Vec<usize> = Vec::new();
2546 let mut shared: BTreeSet<usize> = BTreeSet::new();
2547 let mut stack: Vec<usize> = Vec::with_capacity(16);
2548 visited[root] = true;
2549 reach.push(root);
2550 stack.push(root);
2551 while let Some(s) = stack.pop() {
2552 match &ops[s] {
2553 SummandOp::Local(op) => {
2554 let (a, b) = op_operands(op);
2555 if let Some(a) = a {
2556 if !visited[a] {
2557 visited[a] = true;
2558 reach.push(a);
2559 stack.push(a);
2560 }
2561 }
2562 if let Some(b) = b {
2563 if !visited[b] {
2564 visited[b] = true;
2565 reach.push(b);
2566 stack.push(b);
2567 }
2568 }
2569 }
2570 SummandOp::Shared(k) => {
2571 shared.insert(*k);
2572 }
2573 }
2574 }
2575 reach.sort_unstable();
2576 (reach, shared.into_iter().collect())
2577}
2578
2579fn bfs_prelude(
2583 prelude: &[TapeOp],
2584 start: usize,
2585 visited: &mut [u32],
2586 cur: u32,
2587 stack: &mut Vec<usize>,
2588 out: &mut Vec<usize>,
2589) {
2590 if visited[start] == cur {
2591 return;
2592 }
2593 visited[start] = cur;
2594 out.push(start);
2595 stack.push(start);
2596 while let Some(s) = stack.pop() {
2597 let (a, b) = op_operands(&prelude[s]);
2598 if let Some(a) = a {
2599 if visited[a] != cur {
2600 visited[a] = cur;
2601 out.push(a);
2602 stack.push(a);
2603 }
2604 }
2605 if let Some(b) = b {
2606 if visited[b] != cur {
2607 visited[b] = cur;
2608 out.push(b);
2609 stack.push(b);
2610 }
2611 }
2612 }
2613}
2614
2615fn compute_var_sets(ops: &[TapeOp]) -> Vec<BTreeSet<usize>> {
2619 let mut out: Vec<BTreeSet<usize>> = Vec::with_capacity(ops.len());
2620 for op in ops {
2621 let vs: BTreeSet<usize> = match op {
2622 TapeOp::Const(_) => BTreeSet::new(),
2623 TapeOp::Var(j) => {
2624 let mut s = BTreeSet::new();
2625 s.insert(*j);
2626 s
2627 }
2628 TapeOp::Add(a, b)
2629 | TapeOp::Sub(a, b)
2630 | TapeOp::Mul(a, b)
2631 | TapeOp::Div(a, b)
2632 | TapeOp::Pow(a, b)
2633 | TapeOp::Atan2(a, b) => out[*a].union(&out[*b]).copied().collect(),
2634 TapeOp::Neg(a)
2635 | TapeOp::Abs(a)
2636 | TapeOp::Sqrt(a)
2637 | TapeOp::Exp(a)
2638 | TapeOp::Log(a)
2639 | TapeOp::Log10(a)
2640 | TapeOp::Sin(a)
2641 | TapeOp::Cos(a)
2642 | TapeOp::Tan(a)
2643 | TapeOp::Atan(a)
2644 | TapeOp::Acos(a)
2645 | TapeOp::Sinh(a)
2646 | TapeOp::Cosh(a)
2647 | TapeOp::Tanh(a)
2648 | TapeOp::Asin(a)
2649 | TapeOp::Acosh(a)
2650 | TapeOp::Asinh(a)
2651 | TapeOp::Atanh(a) => out[*a].clone(),
2652 TapeOp::Cmp(_, _, _)
2653 | TapeOp::And(_, _)
2654 | TapeOp::Or(_, _)
2655 | TapeOp::Not(_)
2656 | TapeOp::Select(_, _, _)
2657 | TapeOp::Min(_, _)
2658 | TapeOp::Max(_, _) => unreachable!(
2659 "HybridTape prelude cannot contain conditional / logical / min-max \
2660 TapeOps; build_into_summand panics on those Expr variants."
2661 ),
2662 TapeOp::Funcall(_) => unreachable!(
2663 "HybridTape prelude cannot contain TapeOp::Funcall; \
2664 build_into_summand panics on Expr::Funcall."
2665 ),
2666 };
2667 out.push(vs);
2668 }
2669 out
2670}
2671
2672fn summand_sparsity(
2677 ops: &[SummandOp],
2678 prelude_var_sets: &[BTreeSet<usize>],
2679 pairs: &mut BTreeSet<(usize, usize)>,
2680) {
2681 let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(ops.len());
2682 let emit_cross =
2683 |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
2684 for &v1 in s1 {
2685 for &v2 in s2 {
2686 let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
2687 pairs.insert((r, c));
2688 }
2689 }
2690 };
2691 let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
2692 let vars: Vec<usize> = s.iter().copied().collect();
2693 for (ai, &vi) in vars.iter().enumerate() {
2694 for &vj in &vars[..=ai] {
2695 let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
2696 pairs.insert((r, c));
2697 }
2698 }
2699 };
2700 for so in ops {
2701 let vset: BTreeSet<usize> = match so {
2702 SummandOp::Shared(k) => prelude_var_sets[*k].clone(),
2703 SummandOp::Local(op) => match op {
2704 TapeOp::Const(_) => BTreeSet::new(),
2705 TapeOp::Var(j) => {
2706 let mut s = BTreeSet::new();
2707 s.insert(*j);
2708 s
2709 }
2710 TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
2711 var_sets[*a].union(&var_sets[*b]).copied().collect()
2712 }
2713 TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
2714 TapeOp::Mul(a, b) => {
2715 emit_cross(&var_sets[*a], &var_sets[*b], pairs);
2716 var_sets[*a].union(&var_sets[*b]).copied().collect()
2717 }
2718 TapeOp::Div(a, b) => {
2719 emit_cross(&var_sets[*a], &var_sets[*b], pairs);
2720 emit_self(&var_sets[*b], pairs);
2721 var_sets[*a].union(&var_sets[*b]).copied().collect()
2722 }
2723 TapeOp::Pow(a, b) | TapeOp::Atan2(a, b) => {
2724 let combined: BTreeSet<usize> =
2725 var_sets[*a].union(&var_sets[*b]).copied().collect();
2726 emit_self(&combined, pairs);
2727 combined
2728 }
2729 TapeOp::Sqrt(a)
2730 | TapeOp::Exp(a)
2731 | TapeOp::Log(a)
2732 | TapeOp::Log10(a)
2733 | TapeOp::Sin(a)
2734 | TapeOp::Cos(a)
2735 | TapeOp::Tan(a)
2736 | TapeOp::Atan(a)
2737 | TapeOp::Acos(a)
2738 | TapeOp::Sinh(a)
2739 | TapeOp::Cosh(a)
2740 | TapeOp::Tanh(a)
2741 | TapeOp::Asin(a)
2742 | TapeOp::Acosh(a)
2743 | TapeOp::Asinh(a)
2744 | TapeOp::Atanh(a) => {
2745 emit_self(&var_sets[*a], pairs);
2746 var_sets[*a].clone()
2747 }
2748 TapeOp::Cmp(_, _, _)
2749 | TapeOp::And(_, _)
2750 | TapeOp::Or(_, _)
2751 | TapeOp::Not(_)
2752 | TapeOp::Select(_, _, _)
2753 | TapeOp::Min(_, _)
2754 | TapeOp::Max(_, _) => unreachable!(
2755 "HybridTape summand cannot contain conditional / logical / min-max \
2756 TapeOps; build_into_summand panics on those Expr variants."
2757 ),
2758 TapeOp::Funcall(_) => unreachable!(
2759 "HybridTape summand cannot contain TapeOp::Funcall; \
2760 build_into_summand panics on Expr::Funcall."
2761 ),
2762 },
2763 };
2764 var_sets.push(vset);
2765 }
2766}
2767
2768#[inline]
2771fn op_operands(op: &TapeOp) -> (Option<usize>, Option<usize>) {
2772 match op {
2773 TapeOp::Const(_) | TapeOp::Var(_) => (None, None),
2774 TapeOp::Add(a, b)
2775 | TapeOp::Sub(a, b)
2776 | TapeOp::Mul(a, b)
2777 | TapeOp::Div(a, b)
2778 | TapeOp::Pow(a, b)
2779 | TapeOp::Atan2(a, b) => (Some(*a), Some(*b)),
2780 TapeOp::Neg(a)
2781 | TapeOp::Abs(a)
2782 | TapeOp::Sqrt(a)
2783 | TapeOp::Exp(a)
2784 | TapeOp::Log(a)
2785 | TapeOp::Log10(a)
2786 | TapeOp::Sin(a)
2787 | TapeOp::Cos(a)
2788 | TapeOp::Tan(a)
2789 | TapeOp::Atan(a)
2790 | TapeOp::Acos(a)
2791 | TapeOp::Sinh(a)
2792 | TapeOp::Cosh(a)
2793 | TapeOp::Tanh(a)
2794 | TapeOp::Asin(a)
2795 | TapeOp::Acosh(a)
2796 | TapeOp::Asinh(a)
2797 | TapeOp::Atanh(a) => (Some(*a), None),
2798 TapeOp::Cmp(_, a, b) | TapeOp::And(a, b) | TapeOp::Or(a, b) => (Some(*a), Some(*b)),
2804 TapeOp::Not(a) => (Some(*a), None),
2805 TapeOp::Select(_, _, _) => unreachable!(
2806 "op_operands: TapeOp::Select has three operands and is unsupported on \
2807 the HybridTape path"
2808 ),
2809 TapeOp::Min(_, _) | TapeOp::Max(_, _) => unreachable!(
2810 "op_operands: TapeOp::Min/Max are unsupported on the HybridTape path \
2811 (build_into_summand rejects min/max lists)"
2812 ),
2813 TapeOp::Funcall(_) => (None, None),
2814 }
2815}
2816
2817fn vars_in(ops: &[TapeOp], reach: &[usize]) -> Vec<usize> {
2818 let mut s: BTreeSet<usize> = BTreeSet::new();
2819 for &i in reach {
2820 if let TapeOp::Var(j) = &ops[i] {
2821 s.insert(*j);
2822 }
2823 }
2824 s.into_iter().collect()
2825}
2826
2827#[inline]
2830fn fwd_step(op: &TapeOp, x: &[f64], vals: &[f64]) -> f64 {
2831 match op {
2832 TapeOp::Const(c) => *c,
2833 TapeOp::Var(i) => x[*i],
2834 TapeOp::Add(a, b) => vals[*a] + vals[*b],
2835 TapeOp::Sub(a, b) => vals[*a] - vals[*b],
2836 TapeOp::Mul(a, b) => vals[*a] * vals[*b],
2837 TapeOp::Div(a, b) => vals[*a] / vals[*b],
2838 TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
2839 TapeOp::Neg(a) => -vals[*a],
2840 TapeOp::Abs(a) => vals[*a].abs(),
2841 TapeOp::Sqrt(a) => vals[*a].sqrt(),
2842 TapeOp::Exp(a) => vals[*a].exp(),
2843 TapeOp::Log(a) => vals[*a].ln(),
2844 TapeOp::Log10(a) => vals[*a].log10(),
2845 TapeOp::Sin(a) => vals[*a].sin(),
2846 TapeOp::Cos(a) => vals[*a].cos(),
2847 TapeOp::Tan(a) => vals[*a].tan(),
2848 TapeOp::Atan(a) => vals[*a].atan(),
2849 TapeOp::Acos(a) => vals[*a].acos(),
2850 TapeOp::Sinh(a) => vals[*a].sinh(),
2851 TapeOp::Cosh(a) => vals[*a].cosh(),
2852 TapeOp::Tanh(a) => vals[*a].tanh(),
2853 TapeOp::Asin(a) => vals[*a].asin(),
2854 TapeOp::Acosh(a) => vals[*a].acosh(),
2855 TapeOp::Asinh(a) => vals[*a].asinh(),
2856 TapeOp::Atanh(a) => vals[*a].atanh(),
2857 TapeOp::Atan2(a, b) => vals[*a].atan2(vals[*b]),
2858 TapeOp::Cmp(_, _, _)
2859 | TapeOp::And(_, _)
2860 | TapeOp::Or(_, _)
2861 | TapeOp::Not(_)
2862 | TapeOp::Select(_, _, _)
2863 | TapeOp::Min(_, _)
2864 | TapeOp::Max(_, _) => panic!(
2865 "GlobalTape free-function kernels do not implement conditional / logical \
2866 / min-max TapeOps; use the Tape (build_with_externals) interpreter path \
2867 instead."
2868 ),
2869 TapeOp::Funcall(fc) => {
2870 let FuncallData { lib, name, args } = fc.as_ref();
2871 let call_args = funcall_to_ext_args(args, vals);
2872 let res = lib
2873 .eval(name, &call_args, false, false)
2874 .unwrap_or_else(|e| panic!("external function '{name}' eval failed: {e}"));
2875 res.value
2876 }
2877 }
2878}
2879
/// Reverse-mode (adjoint) step for a single tape op.
///
/// Pushes the adjoint `a` of the op stored at slot `i` onto its operands:
/// each operand slot receives `a * d(op)/d(operand)` added into `adj`, and a
/// `Var` leaf adds `a` into `grad` at its variable index.
///
/// * `op`   - the op being differentiated.
/// * `i`    - slot of `op` itself; `vals[i]` is read as the op's own value.
/// * `vals` - forward values of all tape slots.
/// * `adj`  - per-slot adjoint accumulator (only operand slots are written).
/// * `a`    - adjoint of this op, supplied by the caller.
/// * `grad` - per-variable gradient accumulator.
///
/// Updates are `+=`/`-=`, so an op whose two operands are the same slot
/// receives both contributions.
///
/// # Panics
///
/// Panics on conditional / logical / min-max ops, when an external function
/// evaluation fails or returns no derivatives, and on out-of-bounds indices.
#[inline]
fn rev_step(op: &TapeOp, i: usize, vals: &[f64], adj: &mut [f64], a: f64, grad: &mut [f64]) {
    match op {
        TapeOp::Const(_) => {}
        TapeOp::Var(j) => {
            // Leaf: the adjoint lands in the gradient entry of variable `j`.
            grad[*j] += a;
        }
        TapeOp::Add(l, r) => {
            adj[*l] += a;
            adj[*r] += a;
        }
        TapeOp::Sub(l, r) => {
            adj[*l] += a;
            adj[*r] -= a;
        }
        TapeOp::Mul(l, r) => {
            adj[*l] += a * vals[*r];
            adj[*r] += a * vals[*l];
        }
        TapeOp::Div(l, r) => {
            // d(l/r)/dl = 1/r, d(l/r)/dr = -l/r^2.
            let rv = vals[*r];
            adj[*l] += a / rv;
            adj[*r] -= a * vals[*l] / (rv * rv);
        }
        TapeOp::Pow(l, r) => {
            let lv = vals[*l];
            let rv = vals[*r];
            // Base partial r * l^(r-1); skipped entirely when the exponent is 0.
            if rv != 0.0 {
                adj[*l] += a * rv * lv.powf(rv - 1.0);
            }
            // Exponent partial l^r * ln(l); only taken for a positive base.
            if lv > 0.0 {
                adj[*r] += a * vals[i] * lv.ln();
            }
        }
        TapeOp::Neg(j) => {
            adj[*j] -= a;
        }
        TapeOp::Abs(j) => {
            // Slope +1 for non-negative input (including exactly 0), else -1.
            if vals[*j] >= 0.0 {
                adj[*j] += a;
            } else {
                adj[*j] -= a;
            }
        }
        TapeOp::Sqrt(j) => {
            // 1 / (2 * sqrt(u)), reusing the stored result; no contribution
            // unless that result is strictly positive.
            let sv = vals[i];
            if sv > 0.0 {
                adj[*j] += a * 0.5 / sv;
            }
        }
        TapeOp::Exp(j) => {
            // exp is its own derivative, so the stored result is reused.
            adj[*j] += a * vals[i];
        }
        TapeOp::Log(j) => {
            adj[*j] += a / vals[*j];
        }
        TapeOp::Log10(j) => {
            adj[*j] += a / (vals[*j] * std::f64::consts::LN_10);
        }
        TapeOp::Sin(j) => {
            adj[*j] += a * vals[*j].cos();
        }
        TapeOp::Cos(j) => {
            adj[*j] -= a * vals[*j].sin();
        }
        TapeOp::Tan(j) => {
            // 1 + tan^2(u), from the stored tan value.
            let t = vals[i];
            adj[*j] += a * (1.0 + t * t);
        }
        TapeOp::Atan(j) => {
            let u = vals[*j];
            adj[*j] += a / (1.0 + u * u);
        }
        TapeOp::Acos(j) => {
            let u = vals[*j];
            adj[*j] -= a / (1.0 - u * u).sqrt();
        }
        TapeOp::Sinh(j) => {
            adj[*j] += a * vals[*j].cosh();
        }
        TapeOp::Cosh(j) => {
            adj[*j] += a * vals[*j].sinh();
        }
        TapeOp::Tanh(j) => {
            // 1 - tanh^2(u), from the stored tanh value.
            let t = vals[i];
            adj[*j] += a * (1.0 - t * t);
        }
        TapeOp::Asin(j) => {
            let u = vals[*j];
            adj[*j] += a / (1.0 - u * u).sqrt();
        }
        TapeOp::Acosh(j) => {
            let u = vals[*j];
            adj[*j] += a / (u * u - 1.0).sqrt();
        }
        TapeOp::Asinh(j) => {
            let u = vals[*j];
            adj[*j] += a / (u * u + 1.0).sqrt();
        }
        TapeOp::Atanh(j) => {
            let u = vals[*j];
            adj[*j] += a / (1.0 - u * u);
        }
        TapeOp::Atan2(l, r) => {
            // atan2(y, x): d/dy = x / (x^2 + y^2), d/dx = -y / (x^2 + y^2).
            let y = vals[*l];
            let x = vals[*r];
            let d = y * y + x * x;
            adj[*l] += a * (x / d);
            adj[*r] += a * (-y / d);
        }
        TapeOp::Cmp(_, _, _)
        | TapeOp::And(_, _)
        | TapeOp::Or(_, _)
        | TapeOp::Not(_)
        | TapeOp::Select(_, _, _)
        | TapeOp::Min(_, _)
        | TapeOp::Max(_, _) => panic!(
            "GlobalTape free-function kernels do not implement conditional / logical \
             / min-max TapeOps; use the Tape (build_with_externals) interpreter path \
             instead."
        ),
        TapeOp::Funcall(fc) => {
            // Re-evaluate the external function asking for first derivatives.
            let FuncallData { lib, name, args } = fc.as_ref();
            let call_args = funcall_to_ext_args(args, vals);
            let res = lib
                .eval(name, &call_args, true, false)
                .unwrap_or_else(|e| panic!("external function '{name}' reverse eval failed: {e}"));
            let derivs = res.derivs.expect("want_derivs=true returns derivs");
            // `derivs` is indexed by the position among tape-valued arguments
            // only; string arguments do not advance `k`.
            let mut k = 0usize;
            for arg in args {
                if let TapeFuncallArg::Tape(idx) = arg {
                    adj[*idx] += a * derivs[k];
                    k += 1;
                }
            }
            // No-op statements: `i` and `grad` are not needed in this arm.
            let _ = i;
            let _ = grad;
        }
    }
}
3020
3021#[inline]
3022fn fwd_tan_step(op: &TapeOp, seed_var: usize, vals: &[f64], dot: &[f64], i: usize) -> f64 {
3023 match op {
3024 TapeOp::Const(_) => 0.0,
3025 TapeOp::Var(k) => {
3026 if *k == seed_var {
3027 1.0
3028 } else {
3029 0.0
3030 }
3031 }
3032 TapeOp::Add(a, b) => dot[*a] + dot[*b],
3033 TapeOp::Sub(a, b) => dot[*a] - dot[*b],
3034 TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
3035 TapeOp::Div(a, b) => {
3036 let vb = vals[*b];
3037 (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
3038 }
3039 TapeOp::Pow(a, b) => {
3040 let u = vals[*a];
3041 let r = vals[*b];
3042 let du = dot[*a];
3043 let dr = dot[*b];
3044 let mut result = 0.0;
3045 if r != 0.0 {
3050 result += r * u.powf(r - 1.0) * du;
3051 }
3052 if u > 0.0 {
3053 result += vals[i] * u.ln() * dr;
3054 }
3055 result
3056 }
3057 TapeOp::Neg(a) => -dot[*a],
3058 TapeOp::Abs(a) => {
3059 if vals[*a] >= 0.0 {
3060 dot[*a]
3061 } else {
3062 -dot[*a]
3063 }
3064 }
3065 TapeOp::Sqrt(a) => {
3066 let sv = vals[i];
3067 if sv > 0.0 {
3068 dot[*a] * 0.5 / sv
3069 } else {
3070 0.0
3071 }
3072 }
3073 TapeOp::Exp(a) => dot[*a] * vals[i],
3074 TapeOp::Log(a) => dot[*a] / vals[*a],
3075 TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
3076 TapeOp::Sin(a) => dot[*a] * vals[*a].cos(),
3077 TapeOp::Cos(a) => -dot[*a] * vals[*a].sin(),
3078 TapeOp::Tan(a) => {
3079 let t = vals[i];
3080 dot[*a] * (1.0 + t * t)
3081 }
3082 TapeOp::Atan(a) => {
3083 let u = vals[*a];
3084 dot[*a] / (1.0 + u * u)
3085 }
3086 TapeOp::Acos(a) => {
3087 let u = vals[*a];
3088 -dot[*a] / (1.0 - u * u).sqrt()
3089 }
3090 TapeOp::Sinh(a) => dot[*a] * vals[*a].cosh(),
3091 TapeOp::Cosh(a) => dot[*a] * vals[*a].sinh(),
3092 TapeOp::Tanh(a) => {
3093 let t = vals[i];
3094 dot[*a] * (1.0 - t * t)
3095 }
3096 TapeOp::Asin(a) => {
3097 let u = vals[*a];
3098 dot[*a] / (1.0 - u * u).sqrt()
3099 }
3100 TapeOp::Acosh(a) => {
3101 let u = vals[*a];
3102 dot[*a] / (u * u - 1.0).sqrt()
3103 }
3104 TapeOp::Asinh(a) => {
3105 let u = vals[*a];
3106 dot[*a] / (u * u + 1.0).sqrt()
3107 }
3108 TapeOp::Atanh(a) => {
3109 let u = vals[*a];
3110 dot[*a] / (1.0 - u * u)
3111 }
3112 TapeOp::Atan2(a, b) => {
3113 let y = vals[*a];
3114 let x = vals[*b];
3115 let d = y * y + x * x;
3116 (x * dot[*a] - y * dot[*b]) / d
3117 }
3118 TapeOp::Cmp(_, _, _)
3119 | TapeOp::And(_, _)
3120 | TapeOp::Or(_, _)
3121 | TapeOp::Not(_)
3122 | TapeOp::Select(_, _, _)
3123 | TapeOp::Min(_, _)
3124 | TapeOp::Max(_, _) => panic!(
3125 "GlobalTape free-function kernels do not implement conditional / logical \
3126 / min-max TapeOps; use the Tape (build_with_externals) interpreter path \
3127 instead."
3128 ),
3129 TapeOp::Funcall(fc) => {
3130 let FuncallData { lib, name, args } = fc.as_ref();
3131 let call_args = funcall_to_ext_args(args, vals);
3132 let res = lib
3133 .eval(name, &call_args, true, false)
3134 .unwrap_or_else(|e| panic!("external function '{name}' tangent eval failed: {e}"));
3135 let derivs = res.derivs.expect("want_derivs=true returns derivs");
3136 let mut acc = 0.0;
3137 let mut k = 0usize;
3138 for arg in args {
3139 if let TapeFuncallArg::Tape(idx) = arg {
3140 acc += derivs[k] * dot[*idx];
3141 k += 1;
3142 }
3143 }
3144 let _ = seed_var;
3145 acc
3146 }
3147 }
3148}
3149
/// Second-order step for a single tape op: propagates an adjoint together
/// with its directional derivative.
///
/// For each operand the arm adds `w * f'` into `adj` and
/// `wd * f' + w * d(f')` into `adj_dot`, where `f'` is the op's partial with
/// respect to that operand and `d(f')` is the change of that partial along
/// the tangent direction stored in `dot`.
///
/// * `op`       - the op being processed.
/// * `i`        - slot of `op` itself; `vals[i]` is read as the op's value.
/// * `seed_var` - variable index used by the `Var` arm to address `hess_map`.
/// * `vals`     - forward values of all tape slots.
/// * `dot`      - tangents of all tape slots.
/// * `adj`      - per-slot adjoint accumulator.
/// * `adj_dot`  - per-slot accumulator for the adjoint's directional derivative.
/// * `w`, `wd`  - adjoint and adjoint-derivative of this op, supplied by the
///   caller (presumably `adj[i]` / `adj_dot[i]` - confirm at the call site).
/// * `weight`   - scale applied to entries written into `values`.
/// * `hess_map` - maps `(row, col)` to a position in `values`; the `Var` arm
///   looks up `(k, seed_var)` and only when `k >= seed_var`.
/// * `values`   - output accumulator addressed through `hess_map`.
///
/// # Panics
///
/// Panics on conditional / logical / min-max ops, when an external function
/// evaluation fails or omits derivatives / Hessian, and on out-of-bounds
/// indices.
#[allow(clippy::too_many_arguments)]
#[inline]
fn ror_step(
    op: &TapeOp,
    i: usize,
    seed_var: usize,
    vals: &[f64],
    dot: &[f64],
    adj: &mut [f64],
    adj_dot: &mut [f64],
    w: f64,
    wd: f64,
    weight: f64,
    hess_map: &HashMap<(usize, usize), usize>,
    values: &mut [f64],
) {
    match op {
        TapeOp::Const(_) => {}
        TapeOp::Var(k) => {
            // Leaf: emit `weight * wd` into the slot registered for
            // `(k, seed_var)`; pairs with k < seed_var and pairs missing from
            // the map are skipped.
            if wd != 0.0 && *k >= seed_var {
                if let Some(&pos) = hess_map.get(&(*k, seed_var)) {
                    values[pos] += weight * wd;
                }
            }
        }
        TapeOp::Add(a, b) => {
            adj[*a] += w;
            adj[*b] += w;
            adj_dot[*a] += wd;
            adj_dot[*b] += wd;
        }
        TapeOp::Sub(a, b) => {
            adj[*a] += w;
            adj[*b] -= w;
            adj_dot[*a] += wd;
            adj_dot[*b] -= wd;
        }
        TapeOp::Mul(a, b) => {
            // Partials are the opposite operand's value; their directional
            // derivative is the opposite operand's tangent.
            adj[*a] += w * vals[*b];
            adj[*b] += w * vals[*a];
            adj_dot[*a] += wd * vals[*b] + w * dot[*b];
            adj_dot[*b] += wd * vals[*a] + w * dot[*a];
        }
        TapeOp::Div(a, b) => {
            // f = a/b: f_a = 1/b, f_b = -a/b^2.
            let vb = vals[*b];
            let vb2 = vb * vb;
            let vb3 = vb2 * vb;
            adj[*a] += w / vb;
            adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
            adj[*b] += w * (-vals[*a] / vb2);
            adj_dot[*b] +=
                wd * (-vals[*a] / vb2) + w * (-dot[*a] / vb2 + 2.0 * vals[*a] * dot[*b] / vb3);
        }
        TapeOp::Pow(a, b) => {
            let u = vals[*a];
            let r = vals[*b];
            let du = dot[*a];
            let dr = dot[*b];
            // Base operand; nothing is propagated when the exponent is 0.
            if r != 0.0 {
                if u != 0.0 {
                    // p_a = r * u^(r-1).
                    let p_a = r * u.powf(r - 1.0);
                    adj[*a] += w * p_a;
                    let mut dp_a = dr * u.powf(r - 1.0);
                    if u > 0.0 {
                        // Positive base: include the ln(u) term from the
                        // exponent's tangent.
                        dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
                    } else {
                        // Negative base: ln(u) is unavailable, so only the
                        // base-tangent term is added.
                        dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
                    }
                    adj_dot[*a] += wd * p_a + w * dp_a;
                } else if r >= 2.0 {
                    // Zero base with exponent >= 2: the first partial is 0 and
                    // the second is finite. Zero base with exponent < 2 adds
                    // nothing in this arm.
                    let p_a = 0.0;
                    adj[*a] += w * p_a;
                    let dp_a = if r == 2.0 {
                        2.0 * du
                    } else {
                        r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
                    };
                    adj_dot[*a] += wd * p_a + w * dp_a;
                }
            }
            // Exponent operand; only propagated for a positive base.
            if u > 0.0 {
                let ln_u = u.ln();
                // p_b = u^r * ln(u), with u^r taken from the stored result.
                let p_b = vals[i] * ln_u;
                adj[*b] += w * p_b;
                // `dur` is the directional derivative of u^r itself.
                let dur = vals[i] * (r * du / u + dr * ln_u);
                let dp_b = dur * ln_u + vals[i] * du / u;
                adj_dot[*b] += wd * p_b + w * dp_b;
            }
        }
        TapeOp::Neg(a) => {
            adj[*a] -= w;
            adj_dot[*a] -= wd;
        }
        TapeOp::Abs(a) => {
            // Piecewise linear: sign factor only, no curvature term.
            let s = if vals[*a] >= 0.0 { 1.0 } else { -1.0 };
            adj[*a] += w * s;
            adj_dot[*a] += wd * s;
        }
        TapeOp::Sqrt(a) => {
            // Skipped unless the stored root is strictly positive.
            let sv = vals[i];
            if sv > 0.0 {
                let fp = 0.5 / sv;
                let fpp = -0.25 / (vals[*a] * sv);
                adj[*a] += w * fp;
                adj_dot[*a] += wd * fp + w * fpp * dot[*a];
            }
        }
        TapeOp::Exp(a) => {
            // First and second derivative both equal the stored result.
            let ev = vals[i];
            adj[*a] += w * ev;
            adj_dot[*a] += wd * ev + w * ev * dot[*a];
        }
        TapeOp::Log(a) => {
            let u = vals[*a];
            adj[*a] += w / u;
            adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
        }
        TapeOp::Log10(a) => {
            let u = vals[*a];
            let c = std::f64::consts::LN_10;
            adj[*a] += w / (u * c);
            adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
        }
        TapeOp::Sin(a) => {
            let u = vals[*a];
            let cu = u.cos();
            adj[*a] += w * cu;
            adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
        }
        TapeOp::Cos(a) => {
            let u = vals[*a];
            let su = u.sin();
            adj[*a] -= w * su;
            adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
        }
        TapeOp::Tan(a) => {
            // gp = 1 + t^2, gpp = 2 t (1 + t^2), from the stored tan value.
            let t = vals[i];
            let gp = 1.0 + t * t;
            let gpp = 2.0 * t * gp;
            adj[*a] += w * gp;
            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
        }
        TapeOp::Atan(a) => {
            let u = vals[*a];
            let d = 1.0 + u * u;
            let gp = 1.0 / d;
            let gpp = -2.0 * u / (d * d);
            adj[*a] += w * gp;
            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
        }
        TapeOp::Acos(a) => {
            let u = vals[*a];
            let s = 1.0 - u * u;
            let r = s.sqrt();
            let gp = -1.0 / r;
            let gpp = -u / (s * r);
            adj[*a] += w * gp;
            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
        }
        TapeOp::Sinh(a) => {
            let u = vals[*a];
            let gp = u.cosh();
            // Second derivative of sinh is sinh itself: the stored result.
            let gpp = vals[i];
            adj[*a] += w * gp;
            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
        }
        TapeOp::Cosh(a) => {
            let u = vals[*a];
            let gp = u.sinh();
            // Second derivative of cosh is cosh itself: the stored result.
            let gpp = vals[i];
            adj[*a] += w * gp;
            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
        }
        TapeOp::Tanh(a) => {
            // gp = 1 - t^2, gpp = -2 t (1 - t^2), from the stored tanh value.
            let t = vals[i];
            let gp = 1.0 - t * t;
            let gpp = -2.0 * t * gp;
            adj[*a] += w * gp;
            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
        }
        TapeOp::Asin(a) => {
            let u = vals[*a];
            let s = 1.0 - u * u;
            let r = s.sqrt();
            let gp = 1.0 / r;
            let gpp = u / (s * r);
            adj[*a] += w * gp;
            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
        }
        TapeOp::Acosh(a) => {
            let u = vals[*a];
            let s = u * u - 1.0;
            let r = s.sqrt();
            let gp = 1.0 / r;
            let gpp = -u / (s * r);
            adj[*a] += w * gp;
            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
        }
        TapeOp::Asinh(a) => {
            let u = vals[*a];
            let s = u * u + 1.0;
            let r = s.sqrt();
            let gp = 1.0 / r;
            let gpp = -u / (s * r);
            adj[*a] += w * gp;
            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
        }
        TapeOp::Atanh(a) => {
            let u = vals[*a];
            let d = 1.0 - u * u;
            let gp = 1.0 / d;
            let gpp = 2.0 * u / (d * d);
            adj[*a] += w * gp;
            adj_dot[*a] += wd * gp + w * gpp * dot[*a];
        }
        TapeOp::Atan2(a, b) => {
            // f = atan2(y, x) with y = operand a and x = operand b.
            let y = vals[*a];
            let x = vals[*b];
            let d = y * y + x * x;
            let d2 = d * d;
            // First partials.
            let fa = x / d;
            let fb = -y / d;
            // Second partials (faa = f_yy, fab = f_yx, fbb = f_xx).
            let faa = -2.0 * x * y / d2;
            let fab = (y * y - x * x) / d2;
            let fbb = 2.0 * x * y / d2;
            adj[*a] += w * fa;
            adj[*b] += w * fb;
            adj_dot[*a] += wd * fa + w * (faa * dot[*a] + fab * dot[*b]);
            adj_dot[*b] += wd * fb + w * (fab * dot[*a] + fbb * dot[*b]);
        }
        TapeOp::Cmp(_, _, _)
        | TapeOp::And(_, _)
        | TapeOp::Or(_, _)
        | TapeOp::Not(_)
        | TapeOp::Select(_, _, _)
        | TapeOp::Min(_, _)
        | TapeOp::Max(_, _) => panic!(
            "GlobalTape free-function kernels do not implement conditional / logical \
             / min-max TapeOps; use the Tape (build_with_externals) interpreter path \
             instead."
        ),
        TapeOp::Funcall(fc) => {
            // Re-evaluate the external function asking for both first
            // derivatives and the Hessian.
            let FuncallData { lib, name, args } = fc.as_ref();
            let call_args = funcall_to_ext_args(args, vals);
            let res = lib.eval(name, &call_args, true, true).unwrap_or_else(|e| {
                panic!("external function '{name}' 2nd-order eval failed: {e}")
            });
            let derivs = res.derivs.expect("want_derivs=true returns derivs");
            let hes = res.hessian.expect("want_hes=true returns hessian");
            // Tape slots of the real-valued arguments, in argument order;
            // string arguments are dropped so positions line up with `derivs`.
            let real_tape: Vec<usize> = args
                .iter()
                .filter_map(|a| match a {
                    TapeFuncallArg::Tape(t) => Some(*t),
                    TapeFuncallArg::Str(_) => None,
                })
                .collect();
            for (k, &tk) in real_tape.iter().enumerate() {
                adj[tk] += w * derivs[k];
                // Row k of the Hessian times the argument tangents.
                let mut second_term = 0.0;
                for (l, &tl) in real_tape.iter().enumerate() {
                    // `hes` is read as a packed triangle: entry (lo, hi) with
                    // lo <= hi sits at lo + hi * (hi + 1) / 2.
                    let (lo, hi) = if k <= l { (k, l) } else { (l, k) };
                    let h_kl = hes[lo + hi * (hi + 1) / 2];
                    second_term += h_kl * dot[tl];
                }
                adj_dot[tk] += wd * derivs[k] + w * second_term;
            }
            // No-op statements: these parameters are not needed in this arm.
            let _ = seed_var;
            let _ = hess_map;
            let _ = values;
            let _ = weight;
            let _ = i;
        }
    }
}
3424
3425fn hessian_sparsity_impl(ops: &[TapeOp]) -> BTreeSet<(usize, usize)> {
3429 let n = ops.len();
3430 let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(n);
3431 let mut pairs: BTreeSet<(usize, usize)> = BTreeSet::new();
3432
3433 let emit_cross =
3434 |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
3435 for &v1 in s1 {
3436 for &v2 in s2 {
3437 let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
3438 pairs.insert((r, c));
3439 }
3440 }
3441 };
3442 let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
3443 let vars: Vec<usize> = s.iter().copied().collect();
3444 for (ai, &vi) in vars.iter().enumerate() {
3445 for &vj in &vars[..=ai] {
3446 let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
3447 pairs.insert((r, c));
3448 }
3449 }
3450 };
3451
3452 for op in ops {
3453 let vset = match op {
3454 TapeOp::Const(_) => BTreeSet::new(),
3455 TapeOp::Var(j) => {
3456 let mut s = BTreeSet::new();
3457 s.insert(*j);
3458 s
3459 }
3460 TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
3461 var_sets[*a].union(&var_sets[*b]).copied().collect()
3462 }
3463 TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
3464 TapeOp::Mul(a, b) => {
3465 emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
3466 var_sets[*a].union(&var_sets[*b]).copied().collect()
3467 }
3468 TapeOp::Div(a, b) => {
3469 emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
3470 emit_self(&var_sets[*b], &mut pairs);
3471 var_sets[*a].union(&var_sets[*b]).copied().collect()
3472 }
3473 TapeOp::Pow(a, b) | TapeOp::Atan2(a, b) => {
3474 let combined: BTreeSet<usize> =
3475 var_sets[*a].union(&var_sets[*b]).copied().collect();
3476 emit_self(&combined, &mut pairs);
3477 combined
3478 }
3479 TapeOp::Sqrt(a)
3480 | TapeOp::Exp(a)
3481 | TapeOp::Log(a)
3482 | TapeOp::Log10(a)
3483 | TapeOp::Sin(a)
3484 | TapeOp::Cos(a)
3485 | TapeOp::Tan(a)
3486 | TapeOp::Atan(a)
3487 | TapeOp::Acos(a)
3488 | TapeOp::Sinh(a)
3489 | TapeOp::Cosh(a)
3490 | TapeOp::Tanh(a)
3491 | TapeOp::Asin(a)
3492 | TapeOp::Acosh(a)
3493 | TapeOp::Asinh(a)
3494 | TapeOp::Atanh(a) => {
3495 emit_self(&var_sets[*a], &mut pairs);
3496 var_sets[*a].clone()
3497 }
3498 TapeOp::Funcall(fc) => {
3499 let args = &fc.args;
3500 let mut combined: BTreeSet<usize> = BTreeSet::new();
3501 for arg in args {
3502 if let TapeFuncallArg::Tape(t) = arg {
3503 for &vv in &var_sets[*t] {
3504 combined.insert(vv);
3505 }
3506 }
3507 }
3508 emit_self(&combined, &mut pairs);
3509 combined
3510 }
3511 TapeOp::Cmp(_, _, _) | TapeOp::And(_, _) | TapeOp::Or(_, _) | TapeOp::Not(_) => {
3512 BTreeSet::new()
3515 }
3516 TapeOp::Select(_, t, e) => {
3517 var_sets[*t].union(&var_sets[*e]).copied().collect()
3520 }
3521 TapeOp::Min(a, b) | TapeOp::Max(a, b) => {
3522 var_sets[*a].union(&var_sets[*b]).copied().collect()
3525 }
3526 };
3527 var_sets.push(vset);
3528 }
3529 pairs
3530}
3531
3532#[cfg(test)]
3533mod tests {
3534 use super::*;
3535
    /// Test shorthand: wraps `c` in an `Expr::Const` node.
    fn cnst(c: f64) -> Expr {
        Expr::Const(c)
    }
    /// Test shorthand: builds an `Expr::Var` node for variable index `i`.
    fn var(i: usize) -> Expr {
        Expr::Var(i)
    }
3542 fn add(a: Expr, b: Expr) -> Expr {
3543 Expr::Binary(BinOp::Add, Box::new(a), Box::new(b))
3544 }
3545 fn mul(a: Expr, b: Expr) -> Expr {
3546 Expr::Binary(BinOp::Mul, Box::new(a), Box::new(b))
3547 }
3548 fn pow(a: Expr, b: Expr) -> Expr {
3549 Expr::Binary(BinOp::Pow, Box::new(a), Box::new(b))
3550 }
3551 fn div(a: Expr, b: Expr) -> Expr {
3552 Expr::Binary(BinOp::Div, Box::new(a), Box::new(b))
3553 }
3554 fn unary(op: UnaryOp, a: Expr) -> Expr {
3555 Expr::Unary(op, Box::new(a))
3556 }
3557 fn cmp(op: CmpOp, a: Expr, b: Expr) -> Expr {
3558 Expr::Compare(op, Box::new(a), Box::new(b))
3559 }
3560 fn cond(c: Expr, t: Expr, e: Expr) -> Expr {
3561 Expr::Cond {
3562 cond: Box::new(c),
3563 then_: Box::new(t),
3564 else_: Box::new(e),
3565 }
3566 }
3567
3568 #[test]
3569 fn polynomial_eval_and_grad() {
3570 let e = add(
3572 mul(cnst(3.0), pow(var(0), cnst(2.0))),
3573 mul(cnst(2.0), var(1)),
3574 );
3575 let t = Tape::build(&e);
3576 assert!((t.eval(&[2.0, 3.0]) - 18.0).abs() < 1e-12);
3577 let mut g = vec![0.0; 2];
3578 t.gradient_seed(&[2.0, 3.0], 1.0, &mut g);
3579 assert!((g[0] - 12.0).abs() < 1e-12);
3581 assert!((g[1] - 2.0).abs() < 1e-12);
3582 }
3583
3584 #[test]
3585 fn cse_shared_body_evaluated_once() {
3586 let body = Arc::new(add(var(0), var(1)));
3588 let e = add(
3589 pow(Expr::Cse(body.clone()), cnst(2.0)),
3590 Expr::Cse(body.clone()),
3591 );
3592 let t = Tape::build(&e);
3593 let n_body_adds = t
3595 .ops
3596 .iter()
3597 .filter(|op| {
3598 matches!(op, TapeOp::Add(a, b) if {
3599 matches!(t.ops[*a], TapeOp::Var(0)) && matches!(t.ops[*b], TapeOp::Var(1))
3600 })
3601 })
3602 .count();
3603 assert_eq!(n_body_adds, 1, "CSE body should be emitted exactly once");
3604
3605 assert!((t.eval(&[1.0, 2.0]) - 12.0).abs() < 1e-12);
3607 let mut g = vec![0.0; 2];
3608 t.gradient_seed(&[1.0, 2.0], 1.0, &mut g);
3609 assert!((g[0] - 7.0).abs() < 1e-12);
3611 assert!((g[1] - 7.0).abs() < 1e-12);
3612 }
3613
3614 fn fd_check(tape: &Tape, x: &[f64], n: usize, tol: f64) {
3615 let vars = tape.variables();
3616 let mut hess_map: HashMap<(usize, usize), usize> = HashMap::new();
3617 let mut pairs = Vec::new();
3618 for (ai, &vi) in vars.iter().enumerate() {
3619 for &vj in &vars[..=ai] {
3620 let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
3621 hess_map.entry((r, c)).or_insert_with(|| {
3622 let p = pairs.len();
3623 pairs.push((r, c));
3624 p
3625 });
3626 }
3627 }
3628 let nnz = pairs.len();
3629 let mut ad = vec![0.0; nnz];
3630 tape.hessian_accumulate(x, 1.0, &hess_map, &mut ad);
3631
3632 let mut fd = vec![0.0; nnz];
3633 let mut xp = x.to_vec();
3634 let mut gp = vec![0.0; n];
3635 let mut gm = vec![0.0; n];
3636 for &j in &vars {
3637 let h = (1e-7_f64).max(x[j].abs() * 1e-7);
3638 xp[j] = x[j] + h;
3639 gp.iter_mut().for_each(|v| *v = 0.0);
3640 tape.gradient_seed(&xp, 1.0, &mut gp);
3641 xp[j] = x[j] - h;
3642 gm.iter_mut().for_each(|v| *v = 0.0);
3643 tape.gradient_seed(&xp, 1.0, &mut gm);
3644 xp[j] = x[j];
3645 for &i in &vars {
3646 if i >= j {
3647 if let Some(&pos) = hess_map.get(&(i, j)) {
3648 fd[pos] = (gp[i] - gm[i]) / (2.0 * h);
3649 }
3650 }
3651 }
3652 }
3653 for (k, &(r, c)) in pairs.iter().enumerate() {
3654 let scale = fd[k].abs().max(1.0);
3655 assert!(
3656 (ad[k] - fd[k]).abs() / scale < tol,
3657 "H[{},{}]: AD={:.6e} FD={:.6e}",
3658 r,
3659 c,
3660 ad[k],
3661 fd[k]
3662 );
3663 }
3664 }
3665
3666 #[test]
3667 fn hessian_quadratic_matches_fd() {
3668 let e = add(
3670 add(
3671 mul(cnst(3.0), pow(var(0), cnst(2.0))),
3672 mul(cnst(2.0), mul(var(0), var(1))),
3673 ),
3674 pow(var(1), cnst(2.0)),
3675 );
3676 let t = Tape::build(&e);
3677 fd_check(&t, &[2.0, 3.0], 2, 1e-5);
3678 }
3679
3680 #[test]
3681 fn hessian_transcendental_matches_fd() {
3682 let e = Expr::Sum(vec![
3684 unary(UnaryOp::Exp, var(0)),
3685 unary(UnaryOp::Sin, var(1)),
3686 unary(UnaryOp::Log, var(0)),
3687 unary(UnaryOp::Sqrt, var(1)),
3688 mul(var(0), var(1)),
3689 ]);
3690 let t = Tape::build(&e);
3691 fd_check(&t, &[1.5, 2.0], 2, 1e-5);
3692 }
3693
3694 #[test]
3695 fn inverse_trig_grad_and_hessian_match_fd() {
3696 let e = Expr::Sum(vec![
3700 unary(UnaryOp::Tan, var(0)),
3701 unary(UnaryOp::Atan, var(1)),
3702 unary(UnaryOp::Acos, var(2)),
3703 mul(var(0), var(1)),
3704 ]);
3705 let t = Tape::build(&e);
3706 let x = [0.5, 1.3, 0.3];
3707
3708 let mut g = vec![0.0; 3];
3712 t.gradient_seed(&x, 1.0, &mut g);
3713 for j in 0..3 {
3714 let h = (1e-7_f64).max(x[j].abs() * 1e-7);
3715 let mut xp = x;
3716 let mut xm = x;
3717 xp[j] += h;
3718 xm[j] -= h;
3719 let fd = (t.eval(&xp) - t.eval(&xm)) / (2.0 * h);
3720 let scale = fd.abs().max(1.0);
3721 assert!(
3722 (g[j] - fd).abs() / scale < 1e-5,
3723 "grad[{j}]: AD={:.6e} FD={:.6e}",
3724 g[j],
3725 fd
3726 );
3727 }
3728
3729 fd_check(&t, &x, 3, 1e-5);
3731 }
3732
3733 fn grad_and_hess_match_fd(e: &Expr, x: &[f64], tol: f64) {
3736 let n = x.len();
3737 let t = Tape::build(e);
3738 let mut g = vec![0.0; n];
3739 t.gradient_seed(x, 1.0, &mut g);
3740 for j in 0..n {
3741 let h = (1e-7_f64).max(x[j].abs() * 1e-7);
3742 let mut xp = x.to_vec();
3743 let mut xm = x.to_vec();
3744 xp[j] += h;
3745 xm[j] -= h;
3746 let fd = (t.eval(&xp) - t.eval(&xm)) / (2.0 * h);
3747 let scale = fd.abs().max(1.0);
3748 assert!(
3749 (g[j] - fd).abs() / scale < tol,
3750 "grad[{j}]: AD={:.6e} FD={:.6e}",
3751 g[j],
3752 fd
3753 );
3754 }
3755 fd_check(&t, x, n, tol);
3756 }
3757
3758 #[test]
3759 fn hyperbolic_grad_and_hessian_match_fd() {
3760 let e = Expr::Sum(vec![
3763 unary(UnaryOp::Sinh, var(0)),
3764 unary(UnaryOp::Cosh, var(1)),
3765 unary(UnaryOp::Tanh, var(2)),
3766 unary(UnaryOp::Asinh, var(3)),
3767 mul(var(0), var(1)),
3768 mul(var(2), var(3)),
3769 ]);
3770 grad_and_hess_match_fd(&e, &[0.5, 0.7, 0.3, 1.1], 1e-5);
3771 }
3772
3773 #[test]
3774 fn restricted_inverse_grad_and_hessian_match_fd() {
3775 let e = Expr::Sum(vec![
3779 unary(UnaryOp::Asin, var(0)),
3780 unary(UnaryOp::Acosh, var(1)),
3781 unary(UnaryOp::Atanh, var(2)),
3782 mul(var(0), var(2)),
3783 ]);
3784 grad_and_hess_match_fd(&e, &[0.4, 1.8, 0.3], 1e-5);
3785 }
3786
3787 #[test]
3788 fn atan2_grad_and_hessian_match_fd() {
3789 let atan2 = |a: Expr, b: Expr| Expr::Binary(BinOp::Atan2, Box::new(a), Box::new(b));
3791 let e = Expr::Sum(vec![atan2(var(0), var(1)), mul(var(0), var(1))]);
3792 grad_and_hess_match_fd(&e, &[1.2, 0.7], 1e-5);
3793 }
3794
3795 #[test]
3796 fn minmax_grad_and_hessian_match_fd() {
3797 let e = Expr::Sum(vec![
3804 Expr::MinList(vec![var(0), var(1), var(2)]),
3805 Expr::MaxList(vec![var(1), var(2)]),
3806 mul(var(0), var(2)),
3807 ]);
3808 grad_and_hess_match_fd(&e, &[0.5, 3.0, 2.0], 1e-5);
3809 }
3810
3811 #[test]
3812 fn minmax_value_and_active_operand() {
3813 let e = Expr::Sum(vec![
3816 Expr::MinList(vec![var(0), var(1)]),
3817 Expr::MaxList(vec![var(0), var(1)]),
3818 ]);
3819 let t = Tape::build(&e);
3820 let x = [1.3, -0.4];
3822 assert!((t.eval(&x) - (x[0] + x[1])).abs() < 1e-12);
3823 let mut g = vec![0.0; 2];
3824 t.gradient_seed(&x, 1.0, &mut g);
3825 assert!((g[0] - 1.0).abs() < 1e-12, "g0={}", g[0]);
3828 assert!((g[1] - 1.0).abs() < 1e-12, "g1={}", g[1]);
3829 }
3830
3831 #[test]
3832 fn hessian_division_matches_fd() {
3833 let e = add(div(var(0), var(1)), unary(UnaryOp::Cos, var(0)));
3835 let t = Tape::build(&e);
3836 fd_check(&t, &[0.5, 1.2], 2, 1e-5);
3837 }
3838
3839 #[test]
3840 fn conditional_value_grad_hessian_active_branch() {
3841 let e = cond(
3845 cmp(CmpOp::Ge, var(0), cnst(1.0)),
3846 mul(var(0), var(1)),
3847 pow(var(1), cnst(2.0)),
3848 );
3849 let t = Tape::build(&e);
3850
3851 let x = [2.0, 5.0];
3853 assert!((t.eval(&x) - 10.0).abs() < 1e-12);
3854 let mut g = vec![0.0; 2];
3855 t.gradient_seed(&x, 1.0, &mut g);
3856 assert!((g[0] - 5.0).abs() < 1e-10);
3858 assert!((g[1] - 2.0).abs() < 1e-10);
3859 fd_check(&t, &x, 2, 1e-5);
3861
3862 let x2 = [0.0, 5.0];
3864 assert!((t.eval(&x2) - 25.0).abs() < 1e-12);
3865 let mut g2 = vec![0.0; 2];
3866 t.gradient_seed(&x2, 1.0, &mut g2);
3867 assert!(g2[0].abs() < 1e-10);
3868 assert!((g2[1] - 10.0).abs() < 1e-10);
3869 fd_check(&t, &x2, 2, 1e-5);
3870 }
3871
3872 #[test]
3873 fn comparison_and_logical_have_zero_derivative() {
3874 let lt = cmp(CmpOp::Lt, var(0), var(1));
3878 let and = Expr::And(
3879 Box::new(cmp(CmpOp::Gt, var(0), cnst(0.0))),
3880 Box::new(cmp(CmpOp::Gt, var(1), cnst(0.0))),
3881 );
3882 let notc = Expr::Not(Box::new(cmp(CmpOp::Eq, var(0), var(1))));
3883 let e = add(add(lt, and), notc);
3884 let t = Tape::build(&e);
3885
3886 let x = [1.0, 2.0];
3887 assert!((t.eval(&x) - 3.0).abs() < 1e-12);
3889 let mut g = vec![0.0; 2];
3890 t.gradient_seed(&x, 1.0, &mut g);
3891 assert!(g[0].abs() < 1e-12, "d/dx0 should be 0, got {}", g[0]);
3892 assert!(g[1].abs() < 1e-12, "d/dx1 should be 0, got {}", g[1]);
3893 }
3894
3895 #[test]
3896 fn logical_or_value() {
3897 let e = Expr::Or(
3899 Box::new(cmp(CmpOp::Gt, var(0), cnst(0.0))),
3900 Box::new(cmp(CmpOp::Gt, var(1), cnst(0.0))),
3901 );
3902 let t = Tape::build(&e);
3903 assert!((t.eval(&[-1.0, 3.0]) - 1.0).abs() < 1e-12);
3904 assert!((t.eval(&[-1.0, -3.0]) - 0.0).abs() < 1e-12);
3905 }
3906
3907 fn directional_matches_accumulate(tape: &Tape, x: &[f64], n: usize) {
3912 let vars = tape.variables();
3913 let mut hess_map: HashMap<(usize, usize), usize> = HashMap::new();
3914 let mut pairs = Vec::new();
3915 for (ai, &vi) in vars.iter().enumerate() {
3916 for &vj in &vars[..=ai] {
3917 let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
3918 hess_map.entry((r, c)).or_insert_with(|| {
3919 let p = pairs.len();
3920 pairs.push((r, c));
3921 p
3922 });
3923 }
3924 }
3925 let nnz = pairs.len();
3926 let mut ad = vec![0.0; nnz];
3927 tape.hessian_accumulate(x, 1.0, &hess_map, &mut ad);
3928
3929 let nops = tape.ops.len();
3930 let mut vals = vec![0.0; nops];
3931 tape.forward_into(x, &mut vals);
3932 let mut dot = vec![0.0; nops];
3933 let mut adj = vec![0.0; nops];
3934 let mut adj_dot = vec![0.0; nops];
3935
3936 for &j in &vars {
3937 let mut seed = vec![0.0; n];
3938 seed[j] = 1.0;
3939 let mut col = vec![0.0; n];
3940 tape.hessian_directional(
3941 &vals,
3942 &seed,
3943 1.0,
3944 &mut col,
3945 &mut dot,
3946 &mut adj,
3947 &mut adj_dot,
3948 );
3949 for &i in &vars {
3950 let (r, c) = if i >= j { (i, j) } else { (j, i) };
3951 let expect = ad[hess_map[&(r, c)]];
3952 assert!(
3953 (col[i] - expect).abs() < 1e-10,
3954 "directional H[{i},{j}] = {} vs accumulate {}",
3955 col[i],
3956 expect
3957 );
3958 }
3959 }
3960 }
3961
3962 #[test]
3963 fn directional_quadratic_matches_accumulate() {
3964 let e = add(
3966 add(
3967 mul(cnst(3.0), pow(var(0), cnst(2.0))),
3968 mul(mul(cnst(2.0), var(0)), var(1)),
3969 ),
3970 pow(var(1), cnst(2.0)),
3971 );
3972 let t = Tape::build(&e);
3973 directional_matches_accumulate(&t, &[0.5, -0.3], 2);
3974 }
3975
3976 #[test]
3977 fn directional_transcendental_matches_accumulate() {
3978 let e = Expr::Sum(vec![
3979 unary(UnaryOp::Exp, var(0)),
3980 unary(UnaryOp::Sin, var(1)),
3981 unary(UnaryOp::Log, var(0)),
3982 unary(UnaryOp::Sqrt, var(1)),
3983 mul(var(0), var(1)),
3984 ]);
3985 let t = Tape::build(&e);
3986 directional_matches_accumulate(&t, &[1.5, 2.0], 2);
3987 }
3988
3989 #[test]
3990 fn directional_with_division_matches_accumulate() {
3991 let e = add(div(var(0), var(1)), unary(UnaryOp::Cos, var(0)));
3992 let t = Tape::build(&e);
3993 directional_matches_accumulate(&t, &[0.5, 1.2], 2);
3994 }
3995
/// Sparsity of `sin(x0) + x1*x2`: the pattern must contain `(0, 0)` and
/// `(2, 1)` and must not couple `x0` with `x1` or `x2`.
#[test]
fn hessian_sparsity_separable() {
    let expr = add(unary(UnaryOp::Sin, var(0)), mul(var(1), var(2)));
    let tape = Tape::build(&expr);
    let sparsity = tape.hessian_sparsity();

    // Membership probe for a (row, col) pair, as stored by `hessian_sparsity`.
    let has = |row, col| sparsity.contains(&(row, col));

    // Entries the test requires to be present.
    assert!(has(0, 0));
    assert!(has(2, 1));

    // Cross terms between x0 and the other variables must be absent.
    assert!(!has(1, 0));
    assert!(!has(2, 0));
}
4007
4008 fn count_op<F: Fn(&TapeOp) -> bool>(t: &Tape, pred: F) -> usize {
4009 t.ops.iter().filter(|o| pred(o)).count()
4010 }
4011
/// `x0^0` with a literal zero exponent: the built tape must contain neither a
/// `Pow` nor a `Var` op, and must evaluate to 1 at `x0 = 7`.
#[test]
fn pow_zero_const_folds_to_one() {
    let expr = pow(var(0), cnst(0.0));
    let tape = Tape::build(&expr);

    let pow_ops = count_op(&tape, |op| matches!(op, TapeOp::Pow(..)));
    assert_eq!(pow_ops, 0);
    let var_ops = count_op(&tape, |op| matches!(op, TapeOp::Var(_)));
    assert_eq!(var_ops, 0);

    let value = tape.eval(&[7.0]);
    assert!((value - 1.0).abs() < 1e-12);
}
4021
/// `x0^1` with a literal exponent of one: the built tape must contain neither
/// a `Pow` nor a `Const` op, and must evaluate to the input (3.5).
#[test]
fn pow_one_passes_through() {
    let expr = pow(var(0), cnst(1.0));
    let tape = Tape::build(&expr);

    let pow_ops = count_op(&tape, |op| matches!(op, TapeOp::Pow(..)));
    assert_eq!(pow_ops, 0);
    let const_ops = count_op(&tape, |op| matches!(op, TapeOp::Const(_)));
    assert_eq!(const_ops, 0);

    let value = tape.eval(&[3.5]);
    assert!((value - 3.5).abs() < 1e-12);
}
4031
/// `x0^0.5`: the built tape must contain no `Pow` op and exactly one `Sqrt`
/// op, and must evaluate to 4 at `x0 = 16`.
#[test]
fn pow_half_lowers_to_sqrt() {
    let expr = pow(var(0), cnst(0.5));
    let tape = Tape::build(&expr);

    let pow_ops = count_op(&tape, |op| matches!(op, TapeOp::Pow(..)));
    assert_eq!(pow_ops, 0);
    let sqrt_ops = count_op(&tape, |op| matches!(op, TapeOp::Sqrt(_)));
    assert_eq!(sqrt_ops, 1);

    let value = tape.eval(&[16.0]);
    assert!((value - 4.0).abs() < 1e-12);
}
4040
/// `x0^2`: the built tape must contain no `Pow` op and exactly one `Mul` op,
/// and must evaluate to 9 at `x0 = 3`.
#[test]
fn pow_two_lowers_to_single_mul() {
    let expr = pow(var(0), cnst(2.0));
    let tape = Tape::build(&expr);

    let pow_ops = count_op(&tape, |op| matches!(op, TapeOp::Pow(..)));
    assert_eq!(pow_ops, 0);
    let mul_ops = count_op(&tape, |op| matches!(op, TapeOp::Mul(..)));
    assert_eq!(mul_ops, 1);

    let value = tape.eval(&[3.0]);
    assert!((value - 9.0).abs() < 1e-12);
}
4049
/// `x0^3`: the built tape must contain no `Pow` op and exactly two `Mul` ops,
/// and must evaluate to 8 at `x0 = 2`.
#[test]
fn pow_three_lowers_to_two_muls() {
    let expr = pow(var(0), cnst(3.0));
    let tape = Tape::build(&expr);

    let pow_ops = count_op(&tape, |op| matches!(op, TapeOp::Pow(..)));
    assert_eq!(pow_ops, 0);
    let mul_ops = count_op(&tape, |op| matches!(op, TapeOp::Mul(..)));
    assert_eq!(mul_ops, 2);

    let value = tape.eval(&[2.0]);
    assert!((value - 8.0).abs() < 1e-12);
}
4058
/// `x0^8`: the built tape must contain no `Pow` op and exactly three `Mul`
/// ops, and must evaluate to 256 at `x0 = 2`.
#[test]
fn pow_eight_lowers_to_three_muls() {
    let expr = pow(var(0), cnst(8.0));
    let tape = Tape::build(&expr);

    let pow_ops = count_op(&tape, |op| matches!(op, TapeOp::Pow(..)));
    assert_eq!(pow_ops, 0);
    let mul_ops = count_op(&tape, |op| matches!(op, TapeOp::Mul(..)));
    assert_eq!(mul_ops, 3);

    let value = tape.eval(&[2.0]);
    assert!((value - 256.0).abs() < 1e-12);
}
4068
/// `x0^(-2)`: the built tape must contain no `Pow` op and exactly one `Div`
/// op, and must evaluate to `1/16` at `x0 = 4`.
#[test]
fn pow_negative_two_lowers_to_div() {
    let expr = pow(var(0), cnst(-2.0));
    let tape = Tape::build(&expr);

    let pow_ops = count_op(&tape, |op| matches!(op, TapeOp::Pow(..)));
    assert_eq!(pow_ops, 0);
    let div_ops = count_op(&tape, |op| matches!(op, TapeOp::Div(..)));
    assert_eq!(div_ops, 1);

    let value = tape.eval(&[4.0]);
    assert!((value - (1.0 / 16.0)).abs() < 1e-12);
}
4078
/// `x0^9`: this exponent is expected to stay a single generic `Pow` op in the
/// built tape.
#[test]
fn pow_large_const_stays_generic() {
    let expr = pow(var(0), cnst(9.0));
    let tape = Tape::build(&expr);

    let generic_pows = count_op(&tape, |op| matches!(op, TapeOp::Pow(..)));
    assert_eq!(generic_pows, 1);
}
4086
/// `x0^1.5`: a non-integer exponent is expected to stay a single generic
/// `Pow` op in the built tape.
#[test]
fn pow_non_integer_const_stays_generic() {
    let expr = pow(var(0), cnst(1.5));
    let tape = Tape::build(&expr);

    let generic_pows = count_op(&tape, |op| matches!(op, TapeOp::Pow(..)));
    assert_eq!(generic_pows, 1);
}
4094
/// `x0^2` where the exponent `2.0` is wrapped in `Expr::Cse`: the built tape
/// must still contain no `Pow` op and exactly one `Mul` op.
#[test]
fn pow_const_through_cse_const() {
    let shared_two = Arc::new(cnst(2.0));
    let base = Box::new(var(0));
    let exponent = Box::new(Expr::Cse(shared_two));
    let expr = Expr::Binary(BinOp::Pow, base, exponent);

    let tape = Tape::build(&expr);

    let pow_ops = count_op(&tape, |op| matches!(op, TapeOp::Pow(..)));
    assert_eq!(pow_ops, 0);
    let mul_ops = count_op(&tape, |op| matches!(op, TapeOp::Mul(..)));
    assert_eq!(mul_ops, 1);
}
4104
/// Runs `fd_check` on `5*x0^3 + x0*x1` at `(1.7, 0.8)` with the last argument
/// set to `1e-5`.
///
/// NOTE(review): `fd_check` is defined elsewhere in this module; presumably
/// it compares the tape Hessian against finite differences — confirm.
#[test]
fn hessian_pow_three_matches_fd() {
    let cubic = mul(cnst(5.0), pow(var(0), cnst(3.0)));
    let cross = mul(var(0), var(1));
    let expr = add(cubic, cross);

    let tape = Tape::build(&expr);
    fd_check(&tape, &[1.7, 0.8], 2, 1e-5);
}
4112
/// Runs `fd_check` on `x0^(-2) + x1^2` at `(1.3, 2.4)` with the last argument
/// set to `1e-5`.
///
/// NOTE(review): `fd_check` is defined elsewhere in this module; presumably
/// it compares the tape Hessian against finite differences — confirm.
#[test]
fn hessian_pow_negative_matches_fd() {
    let inverse_square = pow(var(0), cnst(-2.0));
    let square = pow(var(1), cnst(2.0));
    let expr = add(inverse_square, square);

    let tape = Tape::build(&expr);
    fd_check(&tape, &[1.3, 2.4], 2, 1e-5);
}
4120
/// Runs `fd_check` on `x0^0.5 + x0*x1` at `(2.5, 1.1)` with the last argument
/// set to `1e-5`.
///
/// NOTE(review): `fd_check` is defined elsewhere in this module; presumably
/// it compares the tape Hessian against finite differences — confirm.
#[test]
fn hessian_pow_half_matches_fd() {
    let root = pow(var(0), cnst(0.5));
    let cross = mul(var(0), var(1));
    let expr = add(root, cross);

    let tape = Tape::build(&expr);
    fd_check(&tape, &[2.5, 1.1], 2, 1e-5);
}
4128
/// Sparsity of `(x0 + x1)^2 + (x0 + x1)` where both occurrences of `x0 + x1`
/// are `Expr::Cse` nodes sharing one `Arc`: the pattern must be exactly
/// `{(0, 0), (1, 0), (1, 1)}`.
#[test]
fn hessian_sparsity_through_cse() {
    let shared_sum = Arc::new(add(var(0), var(1)));
    let squared = pow(Expr::Cse(Arc::clone(&shared_sum)), cnst(2.0));
    let linear = Expr::Cse(Arc::clone(&shared_sum));
    let expr = add(squared, linear);

    let tape = Tape::build(&expr);
    let sparsity = tape.hessian_sparsity();

    // Membership probe for a (row, col) pair, as stored by `hessian_sparsity`.
    let has = |row, col| sparsity.contains(&(row, col));
    assert!(has(0, 0));
    assert!(has(1, 0));
    assert!(has(1, 1));

    // No entries beyond the three checked above.
    assert_eq!(sparsity.len(), 3);
}
4145
/// For `f = x0^x1` at `(x0, x1) = (0, 1)`, the reverse-mode gradient entry
/// `df/dx0` must equal 1 and the forward tangent seeded on variable 0 must
/// agree with it.
///
/// The test first checks that the tape really contains a `Pow` op, so the
/// comparison exercises that op and not a lowered form.
#[test]
fn pow_forward_tangent_matches_reverse_gradient_at_base_zero() {
    let expr = pow(var(0), var(1));
    let tape = Tape::build(&expr);

    let has_pow = tape.ops.iter().any(|op| matches!(op, TapeOp::Pow(..)));
    assert!(
        has_pow,
        "expected a Pow op in the tape; got {:?}",
        tape.ops
    );

    let point = [0.0, 1.0];
    let n_ops = tape.ops.len();

    // Reverse mode: gradient with seed weight 1.0.
    let mut grad = vec![0.0; 2];
    tape.gradient_seed(&point, 1.0, &mut grad);

    // Forward mode: tangent with respect to variable 0; the result is read
    // from the last tape slot.
    let vals = tape.forward(&point);
    let mut dot = vec![0.0; n_ops];
    tape.forward_tangent(&vals, 0, &mut dot);
    let fwd_dfx0 = dot[n_ops - 1];

    let rev_dfx0 = grad[0];
    assert!(
        (rev_dfx0 - 1.0).abs() < 1e-12,
        "reverse gradient df/dx0 at base 0 should be 1, got {}",
        rev_dfx0
    );
    assert!(
        (fwd_dfx0 - rev_dfx0).abs() < 1e-12,
        "forward tangent df/dx0 = {fwd_dfx0} must match reverse gradient {} at base 0",
        rev_dfx0
    );
}
4189
/// Two expressions share one `Expr::Cse` whose body is an `Expr::Funcall`;
/// `HybridTape::build_multi` is expected to panic with a message containing
/// the text given in `should_panic`.
#[test]
#[should_panic(expected = "external function calls are not supported on the")]
fn hybrid_promoted_cse_with_funcall_reports_clear_message() {
    let call = Expr::Funcall {
        id: 0,
        args: vec![FuncallArg::Real(var(0))],
    };
    let shared_call = Arc::new(call);

    // Both expressions reference the same shared body.
    let first = add(Expr::Cse(Arc::clone(&shared_call)), cnst(1.0));
    let second = add(Expr::Cse(Arc::clone(&shared_call)), cnst(2.0));
    let exprs = vec![first, second];

    HybridTape::build_multi(&exprs);
}
4210}