1use std::collections::{BTreeSet, HashMap, HashSet};
26use std::rc::Rc;
27use std::sync::Arc;
28
29use super::nl_external::{EvalResult, ExternalArg, ExternalLibrary, ExternalResolver};
30use super::nl_reader::{BinOp, CmpOp, Expr, FuncallArg, UnaryOp};
31
/// One instruction of a linearized expression tape.
///
/// Every `usize` operand (other than the payload of `Var`) is the index of an
/// earlier tape slot: [`Tape::forward`] reads `vals[k]` while the value vector
/// only holds the slots computed so far. `Var(i)` instead reads `x[i]` from the
/// evaluation point. Boolean-valued ops produce `1.0` (true) / `0.0` (false),
/// and any nonzero operand is treated as true.
#[derive(Debug, Clone)]
pub enum TapeOp {
    /// Literal constant.
    Const(f64),
    /// Reads component `i` of the evaluation point `x`.
    Var(usize),
    /// `vals[a] + vals[b]`.
    Add(usize, usize),
    /// `vals[a] - vals[b]`.
    Sub(usize, usize),
    /// `vals[a] * vals[b]`.
    Mul(usize, usize),
    /// `vals[a] / vals[b]`.
    Div(usize, usize),
    /// `vals[a].powf(vals[b])` (base, exponent).
    Pow(usize, usize),
    /// `-vals[a]`.
    Neg(usize),
    /// `|vals[a]|`; differentiated with slope +1 at 0.
    Abs(usize),
    /// Square root.
    Sqrt(usize),
    /// Natural exponential.
    Exp(usize),
    /// Natural logarithm (`f64::ln`).
    Log(usize),
    /// Base-10 logarithm.
    Log10(usize),
    /// Sine.
    Sin(usize),
    /// Cosine.
    Cos(usize),
    /// Tangent.
    Tan(usize),
    /// Arctangent.
    Atan(usize),
    /// Arccosine.
    Acos(usize),
    /// Hyperbolic sine.
    Sinh(usize),
    /// Hyperbolic cosine.
    Cosh(usize),
    /// Hyperbolic tangent.
    Tanh(usize),
    /// Arcsine.
    Asin(usize),
    /// Inverse hyperbolic cosine.
    Acosh(usize),
    /// Inverse hyperbolic sine.
    Asinh(usize),
    /// Inverse hyperbolic tangent.
    Atanh(usize),
    /// `vals[a].atan2(vals[b])`, i.e. `atan2(y = vals[a], x = vals[b])`.
    Atan2(usize, usize),
    /// `f64::min`; derivatives follow the left operand on ties.
    Min(usize, usize),
    /// `f64::max`; derivatives follow the left operand on ties.
    Max(usize, usize),
    /// `1.0` if `vals[a] <op> vals[b]` holds, else `0.0`; zero derivative.
    Cmp(CmpOp, usize, usize),
    /// Logical AND of the operands' nonzero-ness; zero derivative.
    And(usize, usize),
    /// Logical OR of the operands' nonzero-ness; zero derivative.
    Or(usize, usize),
    /// `1.0` if the operand is zero, else `0.0`; zero derivative.
    Not(usize),
    /// `vals[t]` if `vals[c] != 0.0`, else `vals[e]`. Both branch slots are
    /// evaluated on the tape; derivatives flow only into the chosen branch.
    Select(usize, usize, usize),
    /// Call to an external library function. An evaluation error yields NaN
    /// (see `ext_eval_or_nan`).
    Funcall(Box<FuncallData>),
}
98
/// Payload of [`TapeOp::Funcall`]: which external function to call and with
/// which arguments. (Presumably boxed inside `TapeOp` to keep that enum
/// small — NOTE(review): confirm intent.)
#[derive(Debug, Clone)]
pub struct FuncallData {
    /// Library implementing `name`; invoked through `ExternalLibrary::eval`.
    pub lib: Arc<ExternalLibrary>,
    /// Function name passed to `ExternalLibrary::eval`.
    pub name: String,
    /// Call arguments in order. Derivative and Hessian arrays returned by the
    /// library are indexed by position among the `Tape` arguments only;
    /// string arguments are skipped when counting.
    pub args: Vec<TapeFuncallArg>,
}
109
/// One argument of an external function call.
#[derive(Debug, Clone)]
pub enum TapeFuncallArg {
    /// Real argument: the value of this tape slot, passed as `ExternalArg::Real`.
    Tape(usize),
    /// String argument, passed through as `ExternalArg::Str`; it carries no derivative.
    Str(String),
}
118
119#[inline]
122fn cmp_holds(op: CmpOp, a: f64, b: f64) -> bool {
123 match op {
124 CmpOp::Lt => a < b,
125 CmpOp::Le => a <= b,
126 CmpOp::Eq => a == b,
127 CmpOp::Ge => a >= b,
128 CmpOp::Gt => a > b,
129 CmpOp::Ne => a != b,
130 }
131}
132
133fn funcall_to_ext_args<'a>(args: &'a [TapeFuncallArg], vals: &[f64]) -> Vec<ExternalArg<'a>> {
134 args.iter()
135 .map(|a| match a {
136 TapeFuncallArg::Tape(idx) => ExternalArg::Real(vals[*idx]),
137 TapeFuncallArg::Str(s) => ExternalArg::Str(s.as_str()),
138 })
139 .collect()
140}
141
142fn ext_eval_or_nan(
155 lib: &ExternalLibrary,
156 name: &str,
157 call_args: &[ExternalArg<'_>],
158 n_args: usize,
159 want_derivs: bool,
160 want_hes: bool,
161) -> EvalResult {
162 lib.eval(name, call_args, want_derivs, want_hes)
163 .unwrap_or_else(|_| EvalResult {
164 value: f64::NAN,
165 derivs: want_derivs.then(|| vec![f64::NAN; n_args]),
166 hessian: want_hes.then(|| vec![f64::NAN; n_args * (n_args + 1) / 2]),
167 })
168}
169
/// A linearized expression: a list of [`TapeOp`]s in evaluation order.
///
/// Operands refer to earlier slots, and the value of the last op is the
/// value of the whole expression ([`Tape::eval`] returns it).
#[derive(Debug, Clone)]
pub struct Tape {
    /// Instructions in evaluation order.
    pub ops: Vec<TapeOp>,
}
176
177impl Tape {
178 pub fn build(expr: &Expr) -> Self {
182 Self::build_with_externals(expr, &ExternalResolver::default())
183 }
184
185 pub fn build_with_externals(expr: &Expr, resolver: &ExternalResolver) -> Self {
190 let mut ops = Vec::new();
191 let mut cache: HashMap<*const Expr, usize> = HashMap::new();
192 build_recursive(expr, &mut ops, &mut cache, resolver);
193 Tape { ops }
194 }
195
196 pub fn forward(&self, x: &[f64]) -> Vec<f64> {
199 let mut vals: Vec<f64> = Vec::with_capacity(self.ops.len());
200 for op in &self.ops {
201 let v = match op {
202 TapeOp::Const(c) => *c,
203 TapeOp::Var(i) => x[*i],
204 TapeOp::Add(a, b) => vals[*a] + vals[*b],
205 TapeOp::Sub(a, b) => vals[*a] - vals[*b],
206 TapeOp::Mul(a, b) => vals[*a] * vals[*b],
207 TapeOp::Div(a, b) => vals[*a] / vals[*b],
208 TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
209 TapeOp::Neg(a) => -vals[*a],
210 TapeOp::Abs(a) => vals[*a].abs(),
211 TapeOp::Sqrt(a) => vals[*a].sqrt(),
212 TapeOp::Exp(a) => vals[*a].exp(),
213 TapeOp::Log(a) => vals[*a].ln(),
214 TapeOp::Log10(a) => vals[*a].log10(),
215 TapeOp::Sin(a) => vals[*a].sin(),
216 TapeOp::Cos(a) => vals[*a].cos(),
217 TapeOp::Tan(a) => vals[*a].tan(),
218 TapeOp::Atan(a) => vals[*a].atan(),
219 TapeOp::Acos(a) => vals[*a].acos(),
220 TapeOp::Sinh(a) => vals[*a].sinh(),
221 TapeOp::Cosh(a) => vals[*a].cosh(),
222 TapeOp::Tanh(a) => vals[*a].tanh(),
223 TapeOp::Asin(a) => vals[*a].asin(),
224 TapeOp::Acosh(a) => vals[*a].acosh(),
225 TapeOp::Asinh(a) => vals[*a].asinh(),
226 TapeOp::Atanh(a) => vals[*a].atanh(),
227 TapeOp::Atan2(a, b) => vals[*a].atan2(vals[*b]),
228 TapeOp::Min(a, b) => vals[*a].min(vals[*b]),
229 TapeOp::Max(a, b) => vals[*a].max(vals[*b]),
230 TapeOp::Cmp(op, a, b) => f64::from(cmp_holds(*op, vals[*a], vals[*b])),
231 TapeOp::And(a, b) => f64::from(vals[*a] != 0.0 && vals[*b] != 0.0),
232 TapeOp::Or(a, b) => f64::from(vals[*a] != 0.0 || vals[*b] != 0.0),
233 TapeOp::Not(a) => f64::from(vals[*a] == 0.0),
234 TapeOp::Select(c, t, e) => {
235 if vals[*c] != 0.0 {
236 vals[*t]
237 } else {
238 vals[*e]
239 }
240 }
241 TapeOp::Funcall(fc) => {
242 let FuncallData { lib, name, args } = fc.as_ref();
243 let call_args = funcall_to_ext_args(args, &vals);
244 let res = ext_eval_or_nan(lib, name, &call_args, args.len(), false, false);
245 res.value
246 }
247 };
248 vals.push(v);
249 }
250 vals
251 }
252
253 pub fn eval(&self, x: &[f64]) -> f64 {
254 let vals = self.forward(x);
255 *vals.last().unwrap_or(&0.0)
256 }
257
258 pub fn gradient_seed(&self, x: &[f64], seed: f64, grad: &mut [f64]) {
263 if seed == 0.0 || self.ops.is_empty() {
264 return;
265 }
266 let vals = self.forward(x);
267 self.reverse(&vals, seed, grad);
268 }
269
/// Reverse (adjoint) sweep. Given slot values `vals` from a forward pass at a
/// point, this adds `seed * d(last slot)/d x[j]` into `grad[j]` for every
/// variable `j` referenced by a `Var` op.
///
/// `grad` is accumulated into and never cleared. Non-smooth ops use a fixed
/// choice: `Abs` takes slope +1 at 0, and `Min`/`Max` follow the left operand
/// on ties. Some singular points contribute nothing, e.g. `Sqrt` when its
/// result is 0, and the exponent term of `Pow` when the base is <= 0.
///
/// # Panics
/// Panics on an empty tape (`adj[n - 1]`); `gradient_seed` guards against this.
fn reverse(&self, vals: &[f64], seed: f64, grad: &mut [f64]) {
    let n = self.ops.len();
    let mut adj = vec![0.0f64; n];
    // Seed the adjoint of the output (last slot).
    adj[n - 1] = seed;

    // Operands always precede their users, so walking backwards finishes each
    // slot's adjoint before it is pushed down to its operands.
    for i in (0..n).rev() {
        let a = adj[i];
        // Zero adjoints contribute nothing. Skipping them also avoids
        // 0 * inf = NaN from singular local derivatives (e.g. Div by 0).
        if a == 0.0 {
            continue;
        }
        match &self.ops[i] {
            TapeOp::Const(_) => {}
            TapeOp::Var(j) => {
                grad[*j] += a;
            }
            TapeOp::Add(l, r) => {
                adj[*l] += a;
                adj[*r] += a;
            }
            TapeOp::Sub(l, r) => {
                adj[*l] += a;
                adj[*r] -= a;
            }
            TapeOp::Mul(l, r) => {
                adj[*l] += a * vals[*r];
                adj[*r] += a * vals[*l];
            }
            TapeOp::Div(l, r) => {
                let rv = vals[*r];
                adj[*l] += a / rv;
                adj[*r] -= a * vals[*l] / (rv * rv);
            }
            TapeOp::Pow(l, r) => {
                let lv = vals[*l];
                let rv = vals[*r];
                // d(l^r)/dl = r * l^(r-1); identically 0 when r == 0.
                if rv != 0.0 {
                    adj[*l] += a * rv * lv.powf(rv - 1.0);
                }
                // d(l^r)/dr = l^r * ln(l), only taken for a positive base.
                if lv > 0.0 {
                    adj[*r] += a * vals[i] * lv.ln();
                }
            }
            TapeOp::Neg(j) => {
                adj[*j] -= a;
            }
            TapeOp::Abs(j) => {
                if vals[*j] >= 0.0 {
                    adj[*j] += a;
                } else {
                    adj[*j] -= a;
                }
            }
            TapeOp::Sqrt(j) => {
                // d sqrt(u)/du = 0.5 / sqrt(u); skipped where the root is 0.
                let sv = vals[i];
                if sv > 0.0 {
                    adj[*j] += a * 0.5 / sv;
                }
            }
            TapeOp::Exp(j) => {
                adj[*j] += a * vals[i];
            }
            TapeOp::Log(j) => {
                adj[*j] += a / vals[*j];
            }
            TapeOp::Log10(j) => {
                adj[*j] += a / (vals[*j] * std::f64::consts::LN_10);
            }
            TapeOp::Sin(j) => {
                adj[*j] += a * vals[*j].cos();
            }
            TapeOp::Cos(j) => {
                adj[*j] -= a * vals[*j].sin();
            }
            TapeOp::Tan(j) => {
                // tan' = 1 + tan^2, using the already-computed result.
                let t = vals[i];
                adj[*j] += a * (1.0 + t * t);
            }
            TapeOp::Atan(j) => {
                let u = vals[*j];
                adj[*j] += a / (1.0 + u * u);
            }
            TapeOp::Acos(j) => {
                let u = vals[*j];
                adj[*j] -= a / (1.0 - u * u).sqrt();
            }
            TapeOp::Sinh(j) => {
                adj[*j] += a * vals[*j].cosh();
            }
            TapeOp::Cosh(j) => {
                adj[*j] += a * vals[*j].sinh();
            }
            TapeOp::Tanh(j) => {
                // tanh' = 1 - tanh^2, using the already-computed result.
                let t = vals[i];
                adj[*j] += a * (1.0 - t * t);
            }
            TapeOp::Asin(j) => {
                let u = vals[*j];
                adj[*j] += a / (1.0 - u * u).sqrt();
            }
            TapeOp::Acosh(j) => {
                let u = vals[*j];
                adj[*j] += a / (u * u - 1.0).sqrt();
            }
            TapeOp::Asinh(j) => {
                let u = vals[*j];
                adj[*j] += a / (u * u + 1.0).sqrt();
            }
            TapeOp::Atanh(j) => {
                let u = vals[*j];
                adj[*j] += a / (1.0 - u * u);
            }
            TapeOp::Atan2(l, r) => {
                // atan2(y, x): d/dy = x / (x^2 + y^2), d/dx = -y / (x^2 + y^2).
                let y = vals[*l];
                let x = vals[*r];
                let d = y * y + x * x;
                adj[*l] += a * (x / d);
                adj[*r] += a * (-y / d);
            }
            TapeOp::Min(l, r) => {
                // Ties send the whole adjoint to the left operand.
                if vals[*l] <= vals[*r] {
                    adj[*l] += a;
                } else {
                    adj[*r] += a;
                }
            }
            TapeOp::Max(l, r) => {
                // Ties send the whole adjoint to the left operand.
                if vals[*l] >= vals[*r] {
                    adj[*l] += a;
                } else {
                    adj[*r] += a;
                }
            }
            // Piecewise-constant ops: zero derivative almost everywhere.
            TapeOp::Cmp(_, _, _) | TapeOp::And(_, _) | TapeOp::Or(_, _) | TapeOp::Not(_) => {}
            TapeOp::Select(c, t, e) => {
                // Only the branch that was selected in the forward pass
                // receives the adjoint.
                if vals[*c] != 0.0 {
                    adj[*t] += a;
                } else {
                    adj[*e] += a;
                }
            }
            TapeOp::Funcall(fc) => {
                // Re-evaluate the external function to obtain its first
                // derivatives. derivs[k] belongs to the k-th Tape argument;
                // string arguments are skipped when counting k.
                let FuncallData { lib, name, args } = fc.as_ref();
                let call_args = funcall_to_ext_args(args, vals);
                let res = ext_eval_or_nan(lib, name, &call_args, args.len(), true, false);
                let derivs = res.derivs.expect("want_derivs=true returns derivs");
                let mut k = 0usize;
                for arg in args {
                    if let TapeFuncallArg::Tape(idx) = arg {
                        adj[*idx] += a * derivs[k];
                        k += 1;
                    }
                }
            }
        }
    }
}
433
434 pub fn variables(&self) -> Vec<usize> {
436 let mut s: BTreeSet<usize> = BTreeSet::new();
437 for op in &self.ops {
438 if let TapeOp::Var(j) = op {
439 s.insert(*j);
440 }
441 }
442 s.into_iter().collect()
443 }
444
445 fn forward_tangent(&self, vals: &[f64], seed_var: usize, dot: &mut [f64]) {
450 let n = self.ops.len();
451 debug_assert_eq!(dot.len(), n);
452 for i in 0..n {
453 dot[i] = match &self.ops[i] {
454 TapeOp::Const(_) => 0.0,
455 TapeOp::Var(k) => {
456 if *k == seed_var {
457 1.0
458 } else {
459 0.0
460 }
461 }
462 TapeOp::Add(a, b) => dot[*a] + dot[*b],
463 TapeOp::Sub(a, b) => dot[*a] - dot[*b],
464 TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
465 TapeOp::Div(a, b) => {
466 let vb = vals[*b];
467 (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
468 }
469 TapeOp::Pow(a, b) => {
470 let u = vals[*a];
471 let r = vals[*b];
472 let du = dot[*a];
473 let dr = dot[*b];
474 let mut result = 0.0;
475 if r != 0.0 && u != 0.0 {
476 result += r * u.powf(r - 1.0) * du;
477 }
478 if u > 0.0 {
479 result += vals[i] * u.ln() * dr;
480 }
481 result
482 }
483 TapeOp::Neg(a) => -dot[*a],
484 TapeOp::Abs(a) => {
485 if vals[*a] >= 0.0 {
486 dot[*a]
487 } else {
488 -dot[*a]
489 }
490 }
491 TapeOp::Sqrt(a) => {
492 let sv = vals[i];
493 if sv > 0.0 {
494 dot[*a] * 0.5 / sv
495 } else {
496 0.0
497 }
498 }
499 TapeOp::Exp(a) => dot[*a] * vals[i],
500 TapeOp::Log(a) => dot[*a] / vals[*a],
501 TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
502 TapeOp::Sin(a) => dot[*a] * vals[*a].cos(),
503 TapeOp::Cos(a) => -dot[*a] * vals[*a].sin(),
504 TapeOp::Tan(a) => {
505 let t = vals[i];
506 dot[*a] * (1.0 + t * t)
507 }
508 TapeOp::Atan(a) => {
509 let u = vals[*a];
510 dot[*a] / (1.0 + u * u)
511 }
512 TapeOp::Acos(a) => {
513 let u = vals[*a];
514 -dot[*a] / (1.0 - u * u).sqrt()
515 }
516 TapeOp::Sinh(a) => dot[*a] * vals[*a].cosh(),
517 TapeOp::Cosh(a) => dot[*a] * vals[*a].sinh(),
518 TapeOp::Tanh(a) => {
519 let t = vals[i];
520 dot[*a] * (1.0 - t * t)
521 }
522 TapeOp::Asin(a) => {
523 let u = vals[*a];
524 dot[*a] / (1.0 - u * u).sqrt()
525 }
526 TapeOp::Acosh(a) => {
527 let u = vals[*a];
528 dot[*a] / (u * u - 1.0).sqrt()
529 }
530 TapeOp::Asinh(a) => {
531 let u = vals[*a];
532 dot[*a] / (u * u + 1.0).sqrt()
533 }
534 TapeOp::Atanh(a) => {
535 let u = vals[*a];
536 dot[*a] / (1.0 - u * u)
537 }
538 TapeOp::Atan2(a, b) => {
539 let y = vals[*a];
540 let x = vals[*b];
541 let d = y * y + x * x;
542 (x * dot[*a] - y * dot[*b]) / d
543 }
544 TapeOp::Min(a, b) => {
546 if vals[*a] <= vals[*b] {
547 dot[*a]
548 } else {
549 dot[*b]
550 }
551 }
552 TapeOp::Max(a, b) => {
553 if vals[*a] >= vals[*b] {
554 dot[*a]
555 } else {
556 dot[*b]
557 }
558 }
559 TapeOp::Cmp(_, _, _) | TapeOp::And(_, _) | TapeOp::Or(_, _) | TapeOp::Not(_) => 0.0,
560 TapeOp::Select(c, t, e) => {
561 if vals[*c] != 0.0 {
562 dot[*t]
563 } else {
564 dot[*e]
565 }
566 }
567 TapeOp::Funcall(fc) => {
568 let FuncallData { lib, name, args } = fc.as_ref();
569 let call_args = funcall_to_ext_args(args, vals);
570 let res = ext_eval_or_nan(lib, name, &call_args, args.len(), true, false);
571 let derivs = res.derivs.expect("want_derivs=true returns derivs");
572 let mut acc = 0.0;
573 let mut k = 0usize;
574 for arg in args {
575 if let TapeFuncallArg::Tape(idx) = arg {
576 acc += derivs[k] * dot[*idx];
577 k += 1;
578 }
579 }
580 acc
581 }
582 };
583 }
584 }
585
586 pub fn forward_into(&self, x: &[f64], vals: &mut [f64]) {
590 let n = self.ops.len();
591 debug_assert!(vals.len() >= n);
592 for i in 0..n {
593 vals[i] = match &self.ops[i] {
594 TapeOp::Const(c) => *c,
595 TapeOp::Var(j) => x[*j],
596 TapeOp::Add(a, b) => vals[*a] + vals[*b],
597 TapeOp::Sub(a, b) => vals[*a] - vals[*b],
598 TapeOp::Mul(a, b) => vals[*a] * vals[*b],
599 TapeOp::Div(a, b) => vals[*a] / vals[*b],
600 TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
601 TapeOp::Neg(a) => -vals[*a],
602 TapeOp::Abs(a) => vals[*a].abs(),
603 TapeOp::Sqrt(a) => vals[*a].sqrt(),
604 TapeOp::Exp(a) => vals[*a].exp(),
605 TapeOp::Log(a) => vals[*a].ln(),
606 TapeOp::Log10(a) => vals[*a].log10(),
607 TapeOp::Sin(a) => vals[*a].sin(),
608 TapeOp::Cos(a) => vals[*a].cos(),
609 TapeOp::Tan(a) => vals[*a].tan(),
610 TapeOp::Atan(a) => vals[*a].atan(),
611 TapeOp::Acos(a) => vals[*a].acos(),
612 TapeOp::Sinh(a) => vals[*a].sinh(),
613 TapeOp::Cosh(a) => vals[*a].cosh(),
614 TapeOp::Tanh(a) => vals[*a].tanh(),
615 TapeOp::Asin(a) => vals[*a].asin(),
616 TapeOp::Acosh(a) => vals[*a].acosh(),
617 TapeOp::Asinh(a) => vals[*a].asinh(),
618 TapeOp::Atanh(a) => vals[*a].atanh(),
619 TapeOp::Atan2(a, b) => vals[*a].atan2(vals[*b]),
620 TapeOp::Min(a, b) => vals[*a].min(vals[*b]),
621 TapeOp::Max(a, b) => vals[*a].max(vals[*b]),
622 TapeOp::Cmp(op, a, b) => f64::from(cmp_holds(*op, vals[*a], vals[*b])),
623 TapeOp::And(a, b) => f64::from(vals[*a] != 0.0 && vals[*b] != 0.0),
624 TapeOp::Or(a, b) => f64::from(vals[*a] != 0.0 || vals[*b] != 0.0),
625 TapeOp::Not(a) => f64::from(vals[*a] == 0.0),
626 TapeOp::Select(c, t, e) => {
627 if vals[*c] != 0.0 {
628 vals[*t]
629 } else {
630 vals[*e]
631 }
632 }
633 TapeOp::Funcall(fc) => {
634 let FuncallData { lib, name, args } = fc.as_ref();
635 let call_args = funcall_to_ext_args(args, &*vals);
636 let res = ext_eval_or_nan(lib, name, &call_args, args.len(), false, false);
637 res.value
638 }
639 };
640 }
641 }
642
643 pub fn hessian_directional(
660 &self,
661 vals: &[f64],
662 seed: &[f64],
663 weight: f64,
664 out: &mut [f64],
665 dot: &mut [f64],
666 adj: &mut [f64],
667 adj_dot: &mut [f64],
668 ) {
669 let n = self.ops.len();
670 if n == 0 || weight == 0.0 {
671 return;
672 }
673 debug_assert!(vals.len() >= n);
674 debug_assert!(dot.len() >= n);
675 debug_assert!(adj.len() >= n);
676 debug_assert!(adj_dot.len() >= n);
677
678 for i in 0..n {
682 dot[i] = match &self.ops[i] {
683 TapeOp::Const(_) => 0.0,
684 TapeOp::Var(k) => seed[*k],
685 TapeOp::Add(a, b) => dot[*a] + dot[*b],
686 TapeOp::Sub(a, b) => dot[*a] - dot[*b],
687 TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
688 TapeOp::Div(a, b) => {
689 let vb = vals[*b];
690 (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
691 }
692 TapeOp::Pow(a, b) => {
693 let u = vals[*a];
694 let r = vals[*b];
695 let du = dot[*a];
696 let dr = dot[*b];
697 let mut result = 0.0;
698 if r != 0.0 && u != 0.0 {
699 result += r * u.powf(r - 1.0) * du;
700 }
701 if u > 0.0 {
702 result += vals[i] * u.ln() * dr;
703 }
704 result
705 }
706 TapeOp::Neg(a) => -dot[*a],
707 TapeOp::Abs(a) => {
708 if vals[*a] >= 0.0 {
709 dot[*a]
710 } else {
711 -dot[*a]
712 }
713 }
714 TapeOp::Sqrt(a) => {
715 let sv = vals[i];
716 if sv > 0.0 {
717 dot[*a] * 0.5 / sv
718 } else {
719 0.0
720 }
721 }
722 TapeOp::Exp(a) => vals[i] * dot[*a],
723 TapeOp::Log(a) => dot[*a] / vals[*a],
724 TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
725 TapeOp::Sin(a) => vals[*a].cos() * dot[*a],
726 TapeOp::Cos(a) => -vals[*a].sin() * dot[*a],
727 TapeOp::Tan(a) => {
728 let t = vals[i];
729 (1.0 + t * t) * dot[*a]
730 }
731 TapeOp::Atan(a) => {
732 let u = vals[*a];
733 dot[*a] / (1.0 + u * u)
734 }
735 TapeOp::Acos(a) => {
736 let u = vals[*a];
737 -dot[*a] / (1.0 - u * u).sqrt()
738 }
739 TapeOp::Sinh(a) => dot[*a] * vals[*a].cosh(),
740 TapeOp::Cosh(a) => dot[*a] * vals[*a].sinh(),
741 TapeOp::Tanh(a) => {
742 let t = vals[i];
743 (1.0 - t * t) * dot[*a]
744 }
745 TapeOp::Asin(a) => {
746 let u = vals[*a];
747 dot[*a] / (1.0 - u * u).sqrt()
748 }
749 TapeOp::Acosh(a) => {
750 let u = vals[*a];
751 dot[*a] / (u * u - 1.0).sqrt()
752 }
753 TapeOp::Asinh(a) => {
754 let u = vals[*a];
755 dot[*a] / (u * u + 1.0).sqrt()
756 }
757 TapeOp::Atanh(a) => {
758 let u = vals[*a];
759 dot[*a] / (1.0 - u * u)
760 }
761 TapeOp::Atan2(a, b) => {
762 let y = vals[*a];
763 let x = vals[*b];
764 let d = y * y + x * x;
765 (x * dot[*a] - y * dot[*b]) / d
766 }
767 TapeOp::Min(a, b) => {
769 if vals[*a] <= vals[*b] {
770 dot[*a]
771 } else {
772 dot[*b]
773 }
774 }
775 TapeOp::Max(a, b) => {
776 if vals[*a] >= vals[*b] {
777 dot[*a]
778 } else {
779 dot[*b]
780 }
781 }
782 TapeOp::Cmp(_, _, _) | TapeOp::And(_, _) | TapeOp::Or(_, _) | TapeOp::Not(_) => 0.0,
783 TapeOp::Select(c, t, e) => {
784 if vals[*c] != 0.0 {
785 dot[*t]
786 } else {
787 dot[*e]
788 }
789 }
790 TapeOp::Funcall(fc) => {
791 let FuncallData { lib, name, args } = fc.as_ref();
792 let call_args = funcall_to_ext_args(args, vals);
793 let res = ext_eval_or_nan(lib, name, &call_args, args.len(), true, false);
794 let derivs = res.derivs.expect("want_derivs=true returns derivs");
795 let mut acc = 0.0;
796 let mut k = 0usize;
797 for arg in args {
798 if let TapeFuncallArg::Tape(idx) = arg {
799 acc += derivs[k] * dot[*idx];
800 k += 1;
801 }
802 }
803 acc
804 }
805 };
806 }
807
808 for slot in adj.iter_mut().take(n) {
812 *slot = 0.0;
813 }
814 for slot in adj_dot.iter_mut().take(n) {
815 *slot = 0.0;
816 }
817 adj[n - 1] = 1.0;
818
819 for i in (0..n).rev() {
820 let w = adj[i];
821 let wd = adj_dot[i];
822 if w == 0.0 && wd == 0.0 {
823 continue;
824 }
825 match &self.ops[i] {
826 TapeOp::Const(_) => {}
827 TapeOp::Var(k) => {
828 if wd != 0.0 {
829 out[*k] += weight * wd;
830 }
831 }
832 TapeOp::Add(a, b) => {
833 adj[*a] += w;
834 adj[*b] += w;
835 adj_dot[*a] += wd;
836 adj_dot[*b] += wd;
837 }
838 TapeOp::Sub(a, b) => {
839 adj[*a] += w;
840 adj[*b] -= w;
841 adj_dot[*a] += wd;
842 adj_dot[*b] -= wd;
843 }
844 TapeOp::Mul(a, b) => {
845 adj[*a] += w * vals[*b];
846 adj[*b] += w * vals[*a];
847 adj_dot[*a] += wd * vals[*b] + w * dot[*b];
848 adj_dot[*b] += wd * vals[*a] + w * dot[*a];
849 }
850 TapeOp::Div(a, b) => {
851 let vb = vals[*b];
852 let vb2 = vb * vb;
853 let vb3 = vb2 * vb;
854 adj[*a] += w / vb;
855 adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
856 adj[*b] += w * (-vals[*a] / vb2);
857 adj_dot[*b] += wd * (-vals[*a] / vb2)
858 + w * (-dot[*a] / vb2 + 2.0 * vals[*a] * dot[*b] / vb3);
859 }
860 TapeOp::Pow(a, b) => {
861 let u = vals[*a];
862 let r = vals[*b];
863 let du = dot[*a];
864 let dr = dot[*b];
865 if r != 0.0 {
866 if u != 0.0 {
867 let p_a = r * u.powf(r - 1.0);
868 adj[*a] += w * p_a;
869 let mut dp_a = dr * u.powf(r - 1.0);
870 if u > 0.0 {
871 dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
872 } else {
873 dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
874 }
875 adj_dot[*a] += wd * p_a + w * dp_a;
876 } else if r >= 2.0 {
877 let p_a = 0.0;
878 adj[*a] += w * p_a;
879 let dp_a = if r == 2.0 {
880 2.0 * du
881 } else {
882 r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
883 };
884 adj_dot[*a] += wd * p_a + w * dp_a;
885 }
886 }
887 if u > 0.0 {
888 let ln_u = u.ln();
889 let p_b = vals[i] * ln_u;
890 adj[*b] += w * p_b;
891 let dur = vals[i] * (r * du / u + dr * ln_u);
892 let dp_b = dur * ln_u + vals[i] * du / u;
893 adj_dot[*b] += wd * p_b + w * dp_b;
894 }
895 }
896 TapeOp::Neg(a) => {
897 adj[*a] -= w;
898 adj_dot[*a] -= wd;
899 }
900 TapeOp::Abs(a) => {
901 let s = if vals[*a] >= 0.0 { 1.0 } else { -1.0 };
902 adj[*a] += w * s;
903 adj_dot[*a] += wd * s;
904 }
905 TapeOp::Sqrt(a) => {
906 let sv = vals[i];
907 if sv > 0.0 {
908 let fp = 0.5 / sv;
909 let fpp = -0.25 / (vals[*a] * sv);
910 adj[*a] += w * fp;
911 adj_dot[*a] += wd * fp + w * fpp * dot[*a];
912 }
913 }
914 TapeOp::Exp(a) => {
915 let ev = vals[i];
916 adj[*a] += w * ev;
917 adj_dot[*a] += wd * ev + w * ev * dot[*a];
918 }
919 TapeOp::Log(a) => {
920 let u = vals[*a];
921 adj[*a] += w / u;
922 adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
923 }
924 TapeOp::Log10(a) => {
925 let u = vals[*a];
926 let c = std::f64::consts::LN_10;
927 adj[*a] += w / (u * c);
928 adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
929 }
930 TapeOp::Sin(a) => {
931 let u = vals[*a];
932 let cu = u.cos();
933 adj[*a] += w * cu;
934 adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
935 }
936 TapeOp::Cos(a) => {
937 let u = vals[*a];
938 let su = u.sin();
939 adj[*a] -= w * su;
940 adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
941 }
942 TapeOp::Tan(a) => {
943 let t = vals[i];
944 let gp = 1.0 + t * t;
945 let gpp = 2.0 * t * gp;
946 adj[*a] += w * gp;
947 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
948 }
949 TapeOp::Atan(a) => {
950 let u = vals[*a];
951 let d = 1.0 + u * u;
952 let gp = 1.0 / d;
953 let gpp = -2.0 * u / (d * d);
954 adj[*a] += w * gp;
955 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
956 }
957 TapeOp::Acos(a) => {
958 let u = vals[*a];
959 let s = 1.0 - u * u;
960 let r = s.sqrt();
961 let gp = -1.0 / r;
962 let gpp = -u / (s * r);
963 adj[*a] += w * gp;
964 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
965 }
966 TapeOp::Sinh(a) => {
967 let u = vals[*a];
968 let gp = u.cosh();
969 let gpp = u.sinh();
970 adj[*a] += w * gp;
971 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
972 }
973 TapeOp::Cosh(a) => {
974 let u = vals[*a];
975 let gp = u.sinh();
976 let gpp = u.cosh();
977 adj[*a] += w * gp;
978 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
979 }
980 TapeOp::Tanh(a) => {
981 let t = vals[i];
982 let gp = 1.0 - t * t;
983 let gpp = -2.0 * t * gp;
984 adj[*a] += w * gp;
985 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
986 }
987 TapeOp::Asin(a) => {
988 let u = vals[*a];
989 let s = 1.0 - u * u;
990 let r = s.sqrt();
991 let gp = 1.0 / r;
992 let gpp = u / (s * r);
993 adj[*a] += w * gp;
994 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
995 }
996 TapeOp::Acosh(a) => {
997 let u = vals[*a];
998 let s = u * u - 1.0;
999 let r = s.sqrt();
1000 let gp = 1.0 / r;
1001 let gpp = -u / (s * r);
1002 adj[*a] += w * gp;
1003 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1004 }
1005 TapeOp::Asinh(a) => {
1006 let u = vals[*a];
1007 let s = u * u + 1.0;
1008 let r = s.sqrt();
1009 let gp = 1.0 / r;
1010 let gpp = -u / (s * r);
1011 adj[*a] += w * gp;
1012 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1013 }
1014 TapeOp::Atanh(a) => {
1015 let u = vals[*a];
1016 let d = 1.0 - u * u;
1017 let gp = 1.0 / d;
1018 let gpp = 2.0 * u / (d * d);
1019 adj[*a] += w * gp;
1020 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1021 }
1022 TapeOp::Atan2(a, b) => {
1023 let y = vals[*a];
1024 let x = vals[*b];
1025 let d = y * y + x * x;
1026 let d2 = d * d;
1027 let fa = x / d;
1028 let fb = -y / d;
1029 let faa = -2.0 * y * x / d2;
1030 let fab = (y * y - x * x) / d2;
1031 let fbb = 2.0 * y * x / d2;
1032 adj[*a] += w * fa;
1033 adj[*b] += w * fb;
1034 adj_dot[*a] += wd * fa + w * (faa * dot[*a] + fab * dot[*b]);
1035 adj_dot[*b] += wd * fb + w * (fab * dot[*a] + fbb * dot[*b]);
1036 }
1037 TapeOp::Min(a, b) => {
1041 let br = if vals[*a] <= vals[*b] { *a } else { *b };
1042 adj[br] += w;
1043 adj_dot[br] += wd;
1044 }
1045 TapeOp::Max(a, b) => {
1046 let br = if vals[*a] >= vals[*b] { *a } else { *b };
1047 adj[br] += w;
1048 adj_dot[br] += wd;
1049 }
1050 TapeOp::Cmp(_, _, _) | TapeOp::And(_, _) | TapeOp::Or(_, _) | TapeOp::Not(_) => {}
1052 TapeOp::Select(c, t, e) => {
1055 let br = if vals[*c] != 0.0 { *t } else { *e };
1056 adj[br] += w;
1057 adj_dot[br] += wd;
1058 }
1059 TapeOp::Funcall(fc) => {
1060 let FuncallData { lib, name, args } = fc.as_ref();
1061 let call_args = funcall_to_ext_args(args, vals);
1062 let res = ext_eval_or_nan(lib, name, &call_args, args.len(), true, true);
1063 let derivs = res.derivs.expect("want_derivs=true returns derivs");
1064 let hes = res.hessian.expect("want_hes=true returns hessian");
1065 let real_tape: Vec<usize> = args
1066 .iter()
1067 .filter_map(|a| match a {
1068 TapeFuncallArg::Tape(t) => Some(*t),
1069 TapeFuncallArg::Str(_) => None,
1070 })
1071 .collect();
1072 for (k, &tk) in real_tape.iter().enumerate() {
1073 adj[tk] += w * derivs[k];
1074 let mut second_term = 0.0;
1075 for (l, &tl) in real_tape.iter().enumerate() {
1076 let (lo, hi) = if k <= l { (k, l) } else { (l, k) };
1077 let h_kl = hes[lo + hi * (hi + 1) / 2];
1078 second_term += h_kl * dot[tl];
1079 }
1080 adj_dot[tk] += wd * derivs[k] + w * second_term;
1081 }
1082 }
1083 }
1084 }
1085 }
1086
1087 pub fn hessian_accumulate(
1094 &self,
1095 x: &[f64],
1096 weight: f64,
1097 hess_map: &HashMap<(usize, usize), usize>,
1098 values: &mut [f64],
1099 ) {
1100 let n = self.ops.len();
1101 if n == 0 || weight == 0.0 {
1102 return;
1103 }
1104 let v = self.forward(x);
1105 let var_indices = self.variables();
1106
1107 let mut dot = vec![0.0f64; n];
1114 let mut adj = vec![0.0f64; n];
1115 let mut adj_dot = vec![0.0f64; n];
1116 for &j in &var_indices {
1117 self.forward_tangent(&v, j, &mut dot);
1118
1119 adj.fill(0.0);
1122 adj_dot.fill(0.0);
1123 adj[n - 1] = 1.0;
1124
1125 for i in (0..n).rev() {
1126 let w = adj[i];
1127 let wd = adj_dot[i];
1128 if w == 0.0 && wd == 0.0 {
1129 continue;
1130 }
1131 match &self.ops[i] {
1132 TapeOp::Const(_) => {}
1133 TapeOp::Var(k) => {
1134 if wd != 0.0 && *k >= j {
1137 if let Some(&pos) = hess_map.get(&(*k, j)) {
1138 values[pos] += weight * wd;
1139 }
1140 }
1141 }
1142 TapeOp::Add(a, b) => {
1143 adj[*a] += w;
1144 adj[*b] += w;
1145 adj_dot[*a] += wd;
1146 adj_dot[*b] += wd;
1147 }
1148 TapeOp::Sub(a, b) => {
1149 adj[*a] += w;
1150 adj[*b] -= w;
1151 adj_dot[*a] += wd;
1152 adj_dot[*b] -= wd;
1153 }
1154 TapeOp::Mul(a, b) => {
1155 adj[*a] += w * v[*b];
1156 adj[*b] += w * v[*a];
1157 adj_dot[*a] += wd * v[*b] + w * dot[*b];
1158 adj_dot[*b] += wd * v[*a] + w * dot[*a];
1159 }
1160 TapeOp::Div(a, b) => {
1161 let vb = v[*b];
1162 let vb2 = vb * vb;
1163 let vb3 = vb2 * vb;
1164 adj[*a] += w / vb;
1165 adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
1166 adj[*b] += w * (-v[*a] / vb2);
1167 adj_dot[*b] += wd * (-v[*a] / vb2)
1168 + w * (-dot[*a] / vb2 + 2.0 * v[*a] * dot[*b] / vb3);
1169 }
1170 TapeOp::Pow(a, b) => {
1171 let u = v[*a];
1172 let r = v[*b];
1173 let du = dot[*a];
1174 let dr = dot[*b];
1175 if r != 0.0 {
1176 if u != 0.0 {
1177 let p_a = r * u.powf(r - 1.0);
1178 adj[*a] += w * p_a;
1179 let mut dp_a = dr * u.powf(r - 1.0);
1180 if u > 0.0 {
1181 dp_a +=
1182 r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
1183 } else {
1184 dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
1185 }
1186 adj_dot[*a] += wd * p_a + w * dp_a;
1187 } else if r >= 2.0 {
1188 let p_a = 0.0;
1189 adj[*a] += w * p_a;
1190 let dp_a = if r == 2.0 {
1191 2.0 * du
1192 } else {
1193 r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
1194 };
1195 adj_dot[*a] += wd * p_a + w * dp_a;
1196 }
1197 }
1198 if u > 0.0 {
1199 let ln_u = u.ln();
1200 let p_b = v[i] * ln_u;
1201 adj[*b] += w * p_b;
1202 let dur = v[i] * (r * du / u + dr * ln_u);
1203 let dp_b = dur * ln_u + v[i] * du / u;
1204 adj_dot[*b] += wd * p_b + w * dp_b;
1205 }
1206 }
1207 TapeOp::Neg(a) => {
1208 adj[*a] -= w;
1209 adj_dot[*a] -= wd;
1210 }
1211 TapeOp::Abs(a) => {
1212 let s = if v[*a] >= 0.0 { 1.0 } else { -1.0 };
1213 adj[*a] += w * s;
1214 adj_dot[*a] += wd * s;
1215 }
1216 TapeOp::Sqrt(a) => {
1217 let sv = v[i];
1218 if sv > 0.0 {
1219 let fp = 0.5 / sv;
1220 let fpp = -0.25 / (v[*a] * sv);
1221 adj[*a] += w * fp;
1222 adj_dot[*a] += wd * fp + w * fpp * dot[*a];
1223 }
1224 }
1225 TapeOp::Exp(a) => {
1226 let ev = v[i];
1227 adj[*a] += w * ev;
1228 adj_dot[*a] += wd * ev + w * ev * dot[*a];
1229 }
1230 TapeOp::Log(a) => {
1231 let u = v[*a];
1232 adj[*a] += w / u;
1233 adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
1234 }
1235 TapeOp::Log10(a) => {
1236 let u = v[*a];
1237 let c = std::f64::consts::LN_10;
1238 adj[*a] += w / (u * c);
1239 adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
1240 }
1241 TapeOp::Sin(a) => {
1242 let u = v[*a];
1243 let cu = u.cos();
1244 adj[*a] += w * cu;
1245 adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
1246 }
1247 TapeOp::Cos(a) => {
1248 let u = v[*a];
1249 let su = u.sin();
1250 adj[*a] -= w * su;
1251 adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
1252 }
1253 TapeOp::Tan(a) => {
1254 let t = v[i];
1255 let gp = 1.0 + t * t;
1256 let gpp = 2.0 * t * gp;
1257 adj[*a] += w * gp;
1258 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1259 }
1260 TapeOp::Atan(a) => {
1261 let u = v[*a];
1262 let d = 1.0 + u * u;
1263 let gp = 1.0 / d;
1264 let gpp = -2.0 * u / (d * d);
1265 adj[*a] += w * gp;
1266 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1267 }
1268 TapeOp::Acos(a) => {
1269 let u = v[*a];
1270 let s = 1.0 - u * u;
1271 let r = s.sqrt();
1272 let gp = -1.0 / r;
1273 let gpp = -u / (s * r);
1274 adj[*a] += w * gp;
1275 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1276 }
1277 TapeOp::Sinh(a) => {
1278 let u = v[*a];
1279 let gp = u.cosh();
1280 let gpp = u.sinh();
1281 adj[*a] += w * gp;
1282 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1283 }
1284 TapeOp::Cosh(a) => {
1285 let u = v[*a];
1286 let gp = u.sinh();
1287 let gpp = u.cosh();
1288 adj[*a] += w * gp;
1289 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1290 }
1291 TapeOp::Tanh(a) => {
1292 let t = v[i];
1293 let gp = 1.0 - t * t;
1294 let gpp = -2.0 * t * gp;
1295 adj[*a] += w * gp;
1296 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1297 }
1298 TapeOp::Asin(a) => {
1299 let u = v[*a];
1300 let s = 1.0 - u * u;
1301 let r = s.sqrt();
1302 let gp = 1.0 / r;
1303 let gpp = u / (s * r);
1304 adj[*a] += w * gp;
1305 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1306 }
1307 TapeOp::Acosh(a) => {
1308 let u = v[*a];
1309 let s = u * u - 1.0;
1310 let r = s.sqrt();
1311 let gp = 1.0 / r;
1312 let gpp = -u / (s * r);
1313 adj[*a] += w * gp;
1314 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1315 }
1316 TapeOp::Asinh(a) => {
1317 let u = v[*a];
1318 let s = u * u + 1.0;
1319 let r = s.sqrt();
1320 let gp = 1.0 / r;
1321 let gpp = -u / (s * r);
1322 adj[*a] += w * gp;
1323 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1324 }
1325 TapeOp::Atanh(a) => {
1326 let u = v[*a];
1327 let d = 1.0 - u * u;
1328 let gp = 1.0 / d;
1329 let gpp = 2.0 * u / (d * d);
1330 adj[*a] += w * gp;
1331 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
1332 }
1333 TapeOp::Atan2(a, b) => {
1334 let y = v[*a];
1335 let x = v[*b];
1336 let d = y * y + x * x;
1337 let d2 = d * d;
1338 let fa = x / d;
1339 let fb = -y / d;
1340 let faa = -2.0 * y * x / d2;
1341 let fab = (y * y - x * x) / d2;
1342 let fbb = 2.0 * y * x / d2;
1343 adj[*a] += w * fa;
1344 adj[*b] += w * fb;
1345 adj_dot[*a] += wd * fa + w * (faa * dot[*a] + fab * dot[*b]);
1346 adj_dot[*b] += wd * fb + w * (fab * dot[*a] + fbb * dot[*b]);
1347 }
1348 TapeOp::Min(a, b) => {
1352 let br = if v[*a] <= v[*b] { *a } else { *b };
1353 adj[br] += w;
1354 adj_dot[br] += wd;
1355 }
1356 TapeOp::Max(a, b) => {
1357 let br = if v[*a] >= v[*b] { *a } else { *b };
1358 adj[br] += w;
1359 adj_dot[br] += wd;
1360 }
1361 TapeOp::Cmp(_, _, _)
1363 | TapeOp::And(_, _)
1364 | TapeOp::Or(_, _)
1365 | TapeOp::Not(_) => {}
1366 TapeOp::Select(c, t, e) => {
1369 let br = if v[*c] != 0.0 { *t } else { *e };
1370 adj[br] += w;
1371 adj_dot[br] += wd;
1372 }
1373 TapeOp::Funcall(fc) => {
1374 let FuncallData { lib, name, args } = fc.as_ref();
1375 let call_args = funcall_to_ext_args(args, &v);
1376 let res = ext_eval_or_nan(lib, name, &call_args, args.len(), true, true);
1377 let derivs = res.derivs.expect("want_derivs=true returns derivs");
1378 let hes = res.hessian.expect("want_hes=true returns hessian");
1379 let real_tape: Vec<usize> = args
1380 .iter()
1381 .filter_map(|a| match a {
1382 TapeFuncallArg::Tape(t) => Some(*t),
1383 TapeFuncallArg::Str(_) => None,
1384 })
1385 .collect();
1386 for (k, &tk) in real_tape.iter().enumerate() {
1387 adj[tk] += w * derivs[k];
1388 let mut second_term = 0.0;
1389 for (l, &tl) in real_tape.iter().enumerate() {
1390 let (lo, hi) = if k <= l { (k, l) } else { (l, k) };
1391 let h_kl = hes[lo + hi * (hi + 1) / 2];
1392 second_term += h_kl * dot[tl];
1393 }
1394 adj_dot[tk] += wd * derivs[k] + w * second_term;
1395 }
1396 }
1397 }
1398 }
1399 }
1400 }
1401
1402 pub fn hessian_sparsity(&self) -> BTreeSet<(usize, usize)> {
1407 let n = self.ops.len();
1408 let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(n);
1409 let mut pairs: BTreeSet<(usize, usize)> = BTreeSet::new();
1410
1411 let emit_cross =
1412 |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
1413 for &v1 in s1 {
1414 for &v2 in s2 {
1415 let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
1416 pairs.insert((r, c));
1417 }
1418 }
1419 };
1420 let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
1421 let vars: Vec<usize> = s.iter().copied().collect();
1422 for (ai, &vi) in vars.iter().enumerate() {
1423 for &vj in &vars[..=ai] {
1424 let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
1425 pairs.insert((r, c));
1426 }
1427 }
1428 };
1429
1430 for op in &self.ops {
1431 let vset = match op {
1432 TapeOp::Const(_) => BTreeSet::new(),
1433 TapeOp::Var(j) => {
1434 let mut s = BTreeSet::new();
1435 s.insert(*j);
1436 s
1437 }
1438 TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
1439 var_sets[*a].union(&var_sets[*b]).copied().collect()
1440 }
1441 TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
1442 TapeOp::Mul(a, b) => {
1443 emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
1444 var_sets[*a].union(&var_sets[*b]).copied().collect()
1445 }
1446 TapeOp::Div(a, b) => {
1447 emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
1448 emit_self(&var_sets[*b], &mut pairs);
1449 var_sets[*a].union(&var_sets[*b]).copied().collect()
1450 }
1451 TapeOp::Pow(a, b) => {
1452 let combined: BTreeSet<usize> =
1453 var_sets[*a].union(&var_sets[*b]).copied().collect();
1454 emit_self(&combined, &mut pairs);
1455 combined
1456 }
1457 TapeOp::Sqrt(a)
1458 | TapeOp::Exp(a)
1459 | TapeOp::Log(a)
1460 | TapeOp::Log10(a)
1461 | TapeOp::Sin(a)
1462 | TapeOp::Cos(a)
1463 | TapeOp::Tan(a)
1464 | TapeOp::Atan(a)
1465 | TapeOp::Acos(a)
1466 | TapeOp::Sinh(a)
1467 | TapeOp::Cosh(a)
1468 | TapeOp::Tanh(a)
1469 | TapeOp::Asin(a)
1470 | TapeOp::Acosh(a)
1471 | TapeOp::Asinh(a)
1472 | TapeOp::Atanh(a) => {
1473 emit_self(&var_sets[*a], &mut pairs);
1474 var_sets[*a].clone()
1475 }
1476 TapeOp::Atan2(a, b) => {
1480 let combined: BTreeSet<usize> =
1481 var_sets[*a].union(&var_sets[*b]).copied().collect();
1482 emit_self(&combined, &mut pairs);
1483 combined
1484 }
1485 TapeOp::Cmp(_, _, _) | TapeOp::And(_, _) | TapeOp::Or(_, _) | TapeOp::Not(_) => {
1491 BTreeSet::new()
1492 }
1493 TapeOp::Select(_c, t, e) => var_sets[*t].union(&var_sets[*e]).copied().collect(),
1500 TapeOp::Min(a, b) | TapeOp::Max(a, b) => {
1507 var_sets[*a].union(&var_sets[*b]).copied().collect()
1508 }
1509 TapeOp::Funcall(fc) => {
1510 let args = &fc.args;
1511 let mut combined: BTreeSet<usize> = BTreeSet::new();
1512 for arg in args {
1513 if let TapeFuncallArg::Tape(t) = arg {
1514 for &vv in &var_sets[*t] {
1515 combined.insert(vv);
1516 }
1517 }
1518 }
1519 emit_self(&combined, &mut pairs);
1520 combined
1521 }
1522 };
1523 var_sets.push(vset);
1524 }
1525 pairs
1526 }
1527}
1528
1529fn build_recursive(
1530 expr: &Expr,
1531 ops: &mut Vec<TapeOp>,
1532 cache: &mut HashMap<*const Expr, usize>,
1533 resolver: &ExternalResolver,
1534) -> usize {
1535 match expr {
1536 Expr::Const(c) => {
1537 let idx = ops.len();
1538 ops.push(TapeOp::Const(*c));
1539 idx
1540 }
1541 Expr::Var(i) => {
1542 let idx = ops.len();
1543 ops.push(TapeOp::Var(*i));
1544 idx
1545 }
1546 Expr::Binary(op, a, b) => {
1547 if let BinOp::Pow = op {
1555 if let Some(c) = peek_const(b) {
1556 if let Some(idx) = try_emit_const_pow(a, c, ops, cache, resolver) {
1557 return idx;
1558 }
1559 }
1560 }
1561 let l = build_recursive(a, ops, cache, resolver);
1562 let r = build_recursive(b, ops, cache, resolver);
1563 let idx = ops.len();
1564 ops.push(match op {
1565 BinOp::Add => TapeOp::Add(l, r),
1566 BinOp::Sub => TapeOp::Sub(l, r),
1567 BinOp::Mul => TapeOp::Mul(l, r),
1568 BinOp::Div => TapeOp::Div(l, r),
1569 BinOp::Pow => TapeOp::Pow(l, r),
1570 BinOp::Atan2 => TapeOp::Atan2(l, r),
1571 });
1572 idx
1573 }
1574 Expr::Unary(op, a) => {
1575 let v = build_recursive(a, ops, cache, resolver);
1576 let idx = ops.len();
1577 ops.push(match op {
1578 UnaryOp::Neg => TapeOp::Neg(v),
1579 UnaryOp::Sqrt => TapeOp::Sqrt(v),
1580 UnaryOp::Log => TapeOp::Log(v),
1581 UnaryOp::Log10 => TapeOp::Log10(v),
1582 UnaryOp::Exp => TapeOp::Exp(v),
1583 UnaryOp::Abs => TapeOp::Abs(v),
1584 UnaryOp::Sin => TapeOp::Sin(v),
1585 UnaryOp::Cos => TapeOp::Cos(v),
1586 UnaryOp::Tan => TapeOp::Tan(v),
1587 UnaryOp::Atan => TapeOp::Atan(v),
1588 UnaryOp::Acos => TapeOp::Acos(v),
1589 UnaryOp::Sinh => TapeOp::Sinh(v),
1590 UnaryOp::Cosh => TapeOp::Cosh(v),
1591 UnaryOp::Tanh => TapeOp::Tanh(v),
1592 UnaryOp::Asin => TapeOp::Asin(v),
1593 UnaryOp::Acosh => TapeOp::Acosh(v),
1594 UnaryOp::Asinh => TapeOp::Asinh(v),
1595 UnaryOp::Atanh => TapeOp::Atanh(v),
1596 });
1597 idx
1598 }
1599 Expr::Sum(args) => {
1600 if args.is_empty() {
1601 let idx = ops.len();
1602 ops.push(TapeOp::Const(0.0));
1603 return idx;
1604 }
1605 let mut acc = build_recursive(&args[0], ops, cache, resolver);
1606 for a in &args[1..] {
1607 let next = build_recursive(a, ops, cache, resolver);
1608 let idx = ops.len();
1609 ops.push(TapeOp::Add(acc, next));
1610 acc = idx;
1611 }
1612 acc
1613 }
1614 Expr::MinList(args) | Expr::MaxList(args) => {
1622 let is_min = matches!(expr, Expr::MinList(_));
1623 if args.is_empty() {
1624 let idx = ops.len();
1625 ops.push(TapeOp::Const(0.0));
1626 return idx;
1627 }
1628 let mut acc = build_recursive(&args[0], ops, cache, resolver);
1629 for a in &args[1..] {
1630 let next = build_recursive(a, ops, cache, resolver);
1631 let idx = ops.len();
1632 ops.push(if is_min {
1633 TapeOp::Min(acc, next)
1634 } else {
1635 TapeOp::Max(acc, next)
1636 });
1637 acc = idx;
1638 }
1639 acc
1640 }
1641 Expr::Cse(body) => {
1642 let key = Rc::as_ptr(body) as *const Expr;
1649 if let Some(&idx) = cache.get(&key) {
1650 idx
1651 } else {
1652 let idx = build_recursive(body, ops, cache, resolver);
1653 cache.insert(key, idx);
1654 idx
1655 }
1656 }
1657 Expr::Compare(op, a, b) => {
1658 let l = build_recursive(a, ops, cache, resolver);
1659 let r = build_recursive(b, ops, cache, resolver);
1660 let idx = ops.len();
1661 ops.push(TapeOp::Cmp(*op, l, r));
1662 idx
1663 }
1664 Expr::And(a, b) => {
1665 let l = build_recursive(a, ops, cache, resolver);
1666 let r = build_recursive(b, ops, cache, resolver);
1667 let idx = ops.len();
1668 ops.push(TapeOp::And(l, r));
1669 idx
1670 }
1671 Expr::Or(a, b) => {
1672 let l = build_recursive(a, ops, cache, resolver);
1673 let r = build_recursive(b, ops, cache, resolver);
1674 let idx = ops.len();
1675 ops.push(TapeOp::Or(l, r));
1676 idx
1677 }
1678 Expr::Not(a) => {
1679 let v = build_recursive(a, ops, cache, resolver);
1680 let idx = ops.len();
1681 ops.push(TapeOp::Not(v));
1682 idx
1683 }
1684 Expr::Cond { cond, then_, else_ } => {
1685 let c = build_recursive(cond, ops, cache, resolver);
1686 let t = build_recursive(then_, ops, cache, resolver);
1687 let e = build_recursive(else_, ops, cache, resolver);
1688 let idx = ops.len();
1689 ops.push(TapeOp::Select(c, t, e));
1690 idx
1691 }
1692 Expr::Funcall { id, args } => {
1693 let (lib, name) = resolver
1694 .funcs_by_id
1695 .get(id)
1696 .unwrap_or_else(|| panic!("unresolved AMPL funcall id {id}"));
1697 let tape_args: Vec<TapeFuncallArg> = args
1698 .iter()
1699 .map(|a| match a {
1700 FuncallArg::Real(e) => {
1701 TapeFuncallArg::Tape(build_recursive(e, ops, cache, resolver))
1702 }
1703 FuncallArg::Str(s) => TapeFuncallArg::Str(s.clone()),
1704 })
1705 .collect();
1706 let idx = ops.len();
1707 ops.push(TapeOp::Funcall(Box::new(FuncallData {
1708 lib: Arc::clone(lib),
1709 name: name.clone(),
1710 args: tape_args,
1711 })));
1712 idx
1713 }
1714 }
1715}
1716
1717fn peek_const(e: &Expr) -> Option<f64> {
1721 match e {
1722 Expr::Const(c) => Some(*c),
1723 Expr::Cse(body) => peek_const(body),
1724 _ => None,
1725 }
1726}
1727
1728fn try_emit_const_pow(
1736 base_expr: &Expr,
1737 c: f64,
1738 ops: &mut Vec<TapeOp>,
1739 cache: &mut HashMap<*const Expr, usize>,
1740 resolver: &ExternalResolver,
1741) -> Option<usize> {
1742 if c == 0.0 {
1743 let idx = ops.len();
1744 ops.push(TapeOp::Const(1.0));
1745 return Some(idx);
1746 }
1747 if c == 1.0 {
1748 return Some(build_recursive(base_expr, ops, cache, resolver));
1749 }
1750 if c == 0.5 {
1751 let b = build_recursive(base_expr, ops, cache, resolver);
1752 let idx = ops.len();
1753 ops.push(TapeOp::Sqrt(b));
1754 return Some(idx);
1755 }
1756 if c.is_finite() && c.fract() == 0.0 && c.abs() <= 8.0 {
1761 let n = c.abs() as u32;
1762 if n == 0 {
1763 let idx = ops.len();
1765 ops.push(TapeOp::Const(1.0));
1766 return Some(idx);
1767 }
1768 let b = build_recursive(base_expr, ops, cache, resolver);
1769 let pos = emit_int_pow(b, n, ops);
1770 if c < 0.0 {
1771 let one_idx = ops.len();
1774 ops.push(TapeOp::Const(1.0));
1775 let idx = ops.len();
1776 ops.push(TapeOp::Div(one_idx, pos));
1777 return Some(idx);
1778 }
1779 return Some(pos);
1780 }
1781 None
1782}
1783
1784fn emit_int_pow(base: usize, n: u32, ops: &mut Vec<TapeOp>) -> usize {
1788 debug_assert!(n >= 1);
1789 if n == 1 {
1790 return base;
1791 }
1792 let half = emit_int_pow(base, n / 2, ops);
1793 let squared = ops.len();
1794 ops.push(TapeOp::Mul(half, half));
1795 if n % 2 == 1 {
1796 let idx = ops.len();
1797 ops.push(TapeOp::Mul(squared, base));
1798 idx
1799 } else {
1800 squared
1801 }
1802}
1803
/// One slot of a [`Summand`]'s op list.
#[derive(Debug, Clone)]
pub enum SummandOp {
    /// An op evaluated inside the summand. Its operand indices refer to other
    /// slots of the same summand's `ops` vector.
    Local(TapeOp),
    /// A reference to slot `k` of the shared `HybridTape::prelude`. During the
    /// forward pass its value is copied from the prelude value buffer.
    Shared(usize),
}
1841
/// One independently evaluated term of a [`HybridTape`].
#[derive(Debug, Clone)]
pub struct Summand {
    /// Local op list; operands precede the ops that use them.
    pub ops: Vec<SummandOp>,
    /// Index into `ops` of the slot holding the summand's value.
    pub root_slot: usize,
    /// Indices into `ops` reachable from `root_slot` (including it), ascending.
    pub local_reach: Vec<usize>,
    /// Prelude slots reachable from this summand's `Shared` references,
    /// ascending. Empty when the summand references no prelude slot.
    pub prelude_reach: Vec<usize>,
    /// Sorted, de-duplicated variable indices read by reachable local `Var` ops.
    pub local_vars: Vec<usize>,
    /// Sorted, de-duplicated variable indices read by reachable prelude `Var` ops.
    pub prelude_vars: Vec<usize>,
    /// Sorted union of `local_vars` and `prelude_vars`.
    pub all_vars: Vec<usize>,
}
1859
/// A tape split into a shared prelude plus one local op list per summand.
///
/// Sub-expressions (`Expr::Cse`) referenced from two or more root expressions
/// are emitted once into `prelude`. Everything else lives in the owning
/// summand's `ops`, which refers to the prelude through `SummandOp::Shared`.
#[derive(Debug)]
pub struct HybridTape {
    /// Shared ops; operand indices refer to other prelude slots.
    pub prelude: Vec<TapeOp>,
    /// One entry per root expression passed to `build_multi`, in input order.
    pub summands: Vec<Summand>,
}
1869
impl HybridTape {
    /// Build a hybrid tape with one [`Summand`] per expression in `exprs`.
    ///
    /// An `Expr::Cse` body that appears in at least two root expressions is
    /// emitted once into the shared `prelude` and referenced from each summand
    /// via [`SummandOp::Shared`]. All other nodes go into the owning summand's
    /// local op list. Afterwards, each summand's reachable slots and variable
    /// lists are precomputed.
    ///
    /// Conditional / logical / min-max expressions and AMPL external function
    /// calls are not supported on this path. `build_into_summand` panics when
    /// it lowers one of them into a summand.
    pub fn build_multi(exprs: &[Expr]) -> Self {
        // Pass 1: for each CSE body, count how many root expressions contain it.
        // A fresh `seen_in_root` per root means each root counts at most once.
        let mut cse_count: HashMap<*const Expr, usize> = HashMap::new();
        for e in exprs {
            let mut seen_in_root: HashSet<*const Expr> = HashSet::new();
            count_cse_appearances(e, &mut seen_in_root, &mut cse_count);
        }

        // Pass 2: lower every root into its own local op list. Multiply-used
        // CSEs are promoted into the shared prelude (deduplicated through
        // `prelude_map`).
        let mut prelude: Vec<TapeOp> = Vec::new();
        let mut prelude_map: HashMap<*const Expr, usize> = HashMap::new();
        let mut summands: Vec<Summand> = Vec::with_capacity(exprs.len());
        for e in exprs {
            let mut local: Vec<SummandOp> = Vec::new();
            let mut local_cache: HashMap<*const Expr, usize> = HashMap::new();
            let root_slot = build_into_summand(
                e,
                &mut local,
                &mut local_cache,
                &mut prelude,
                &mut prelude_map,
                &cse_count,
            );
            summands.push(Summand {
                ops: local,
                root_slot,
                local_reach: Vec::new(),
                prelude_reach: Vec::new(),
                local_vars: Vec::new(),
                prelude_vars: Vec::new(),
                all_vars: Vec::new(),
            });
        }

        // Pass 3: per-summand reachability and variable sets. `p_visited` is
        // stamped with an epoch that changes per summand, so the buffer never
        // needs clearing between summands.
        let mut p_visited: Vec<u32> = vec![0; prelude.len()];
        let mut p_epoch: u32 = 0;
        let mut p_stack: Vec<usize> = Vec::new();
        for s in &mut summands {
            let (local_reach, shared_refs) = compute_local_reach(&s.ops, s.root_slot);
            s.local_reach = local_reach;

            let mut lv: BTreeSet<usize> = BTreeSet::new();
            for &i in &s.local_reach {
                if let SummandOp::Local(TapeOp::Var(j)) = &s.ops[i] {
                    lv.insert(*j);
                }
            }
            s.local_vars = lv.iter().copied().collect();

            if !shared_refs.is_empty() {
                p_epoch += 1;
                let mut preach: Vec<usize> = Vec::new();
                for &start in &shared_refs {
                    bfs_prelude(
                        &prelude,
                        start,
                        &mut p_visited,
                        p_epoch,
                        &mut p_stack,
                        &mut preach,
                    );
                }
                // Ascending prelude order is a valid evaluation order.
                preach.sort_unstable();
                s.prelude_vars = vars_in(&prelude, &preach);
                s.prelude_reach = preach;
            }

            let mut av: BTreeSet<usize> = lv;
            for &v in &s.prelude_vars {
                av.insert(v);
            }
            s.all_vars = av.into_iter().collect();
        }

        HybridTape { prelude, summands }
    }

    /// Number of ops in the shared prelude.
    pub fn n_prelude_ops(&self) -> usize {
        self.prelude.len()
    }
    /// Number of summands.
    pub fn n_summands(&self) -> usize {
        self.summands.len()
    }
    /// Length of the longest summand op list (0 when there are no summands);
    /// the minimum size of a local scratch buffer shared by all summands.
    pub fn max_summand_ops(&self) -> usize {
        self.summands.iter().map(|s| s.ops.len()).max().unwrap_or(0)
    }
    /// Total number of local ops across all summands.
    pub fn total_local_ops(&self) -> usize {
        self.summands.iter().map(|s| s.ops.len()).sum()
    }

    /// Evaluate every prelude op, in order, at the point `x`, writing the slot
    /// values into `prelude_vals`. Its length must equal the prelude length
    /// (checked in debug builds).
    pub fn forward_prelude(&self, x: &[f64], prelude_vals: &mut [f64]) {
        debug_assert_eq!(prelude_vals.len(), self.prelude.len());
        for i in 0..self.prelude.len() {
            prelude_vals[i] = fwd_step(&self.prelude[i], x, prelude_vals);
        }
    }

    /// Evaluate all of `s.ops` at `x`, writing slot values into `local_vals`.
    ///
    /// `Shared` slots copy their value from `prelude_vals`, which is expected
    /// to have been filled by `forward_prelude` at the same `x`. `local_vals`
    /// must hold at least `s.ops.len()` entries (checked in debug builds).
    pub fn forward_summand(
        &self,
        s: &Summand,
        x: &[f64],
        prelude_vals: &[f64],
        local_vals: &mut [f64],
    ) {
        debug_assert!(local_vals.len() >= s.ops.len());
        for i in 0..s.ops.len() {
            local_vals[i] = match &s.ops[i] {
                SummandOp::Local(op) => fwd_step(op, x, local_vals),
                SummandOp::Shared(k) => prelude_vals[*k],
            };
        }
    }

    /// Value of summand `s` after `forward_summand` has filled `local_vals`.
    #[inline]
    pub fn root_value(&self, s: &Summand, local_vals: &[f64]) -> f64 {
        local_vals[s.root_slot]
    }

    /// Reverse sweep: accumulate `seed * d(summand)/dx` into `grad`.
    ///
    /// `prelude_vals` / `local_vals` are expected to hold values from
    /// `forward_prelude` / `forward_summand`. Only the `local_adj` /
    /// `prelude_adj` entries in `s.local_reach` / `s.prelude_reach` are reset
    /// and used. Those buffers must be at least as long as `s.ops` and the
    /// prelude. Returns without touching anything when `seed == 0.0`.
    // Signature takes several caller-provided buffers; kept flat for callers.
    #[allow(clippy::too_many_arguments)]
    pub fn gradient_summand(
        &self,
        s: &Summand,
        prelude_vals: &[f64],
        local_vals: &[f64],
        seed: f64,
        grad: &mut [f64],
        local_adj: &mut [f64],
        prelude_adj: &mut [f64],
    ) {
        if seed == 0.0 || s.local_reach.is_empty() {
            return;
        }
        for &i in &s.local_reach {
            local_adj[i] = 0.0;
        }
        for &i in &s.prelude_reach {
            prelude_adj[i] = 0.0;
        }
        local_adj[s.root_slot] = seed;
        // Local slots in reverse order; a `Shared` slot hands its adjoint to the
        // prelude slot it references.
        for &i in s.local_reach.iter().rev() {
            let a = local_adj[i];
            if a == 0.0 {
                continue;
            }
            match &s.ops[i] {
                SummandOp::Local(op) => rev_step(op, i, local_vals, local_adj, a, grad),
                SummandOp::Shared(k) => {
                    prelude_adj[*k] += a;
                }
            }
        }
        // Then the reachable prelude slots, also in reverse order.
        for &i in s.prelude_reach.iter().rev() {
            let a = prelude_adj[i];
            if a == 0.0 {
                continue;
            }
            rev_step(&self.prelude[i], i, prelude_vals, prelude_adj, a, grad);
        }
    }

    /// Accumulate this summand's weighted Hessian contribution into `values`.
    ///
    /// For every variable `j` in `s.all_vars`:
    /// 1. Zero the `*_dot`, `*_adj` and `*_adj_dot` scratch entries of the
    ///    reachable local and prelude slots.
    /// 2. Run a forward tangent sweep (`fwd_tan_step`) in the direction of `j`,
    ///    first over the reachable prelude slots and then over the reachable
    ///    local slots. A `Shared` slot copies its prelude tangent.
    /// 3. Seed the root adjoint with `1.0` and sweep the reachable local slots
    ///    in reverse with `ror_step`. Each `Shared` slot forwards its
    ///    (adjoint, adjoint-tangent) pair to its prelude slot. Then sweep the
    ///    reachable prelude slots in reverse.
    ///
    /// `values` is written inside `ror_step`, which is not part of this impl.
    /// Presumably it scatters `weight`-scaled entries for column `j` through
    /// `hess_map` — confirm against `ror_step`. Returns immediately when
    /// `weight == 0.0`.
    // Signature takes several caller-provided buffers; kept flat for callers.
    #[allow(clippy::too_many_arguments)]
    pub fn hessian_summand(
        &self,
        s: &Summand,
        prelude_vals: &[f64],
        local_vals: &[f64],
        weight: f64,
        hess_map: &HashMap<(usize, usize), usize>,
        values: &mut [f64],
        local_dot: &mut [f64],
        local_adj: &mut [f64],
        local_adj_dot: &mut [f64],
        prelude_dot: &mut [f64],
        prelude_adj: &mut [f64],
        prelude_adj_dot: &mut [f64],
    ) {
        if weight == 0.0 || s.local_reach.is_empty() {
            return;
        }
        for &j in &s.all_vars {
            // Reset scratch state for this direction.
            for &i in &s.local_reach {
                local_dot[i] = 0.0;
                local_adj[i] = 0.0;
                local_adj_dot[i] = 0.0;
            }
            for &i in &s.prelude_reach {
                prelude_dot[i] = 0.0;
                prelude_adj[i] = 0.0;
                prelude_adj_dot[i] = 0.0;
            }
            // Forward tangent sweep: prelude first, since local `Shared` slots
            // read prelude tangents.
            for &i in &s.prelude_reach {
                prelude_dot[i] = fwd_tan_step(&self.prelude[i], j, prelude_vals, prelude_dot, i);
            }
            for &i in &s.local_reach {
                local_dot[i] = match &s.ops[i] {
                    SummandOp::Local(op) => fwd_tan_step(op, j, local_vals, local_dot, i),
                    SummandOp::Shared(k) => prelude_dot[*k],
                };
            }
            // Reverse sweep over the local slots.
            local_adj[s.root_slot] = 1.0;
            for &i in s.local_reach.iter().rev() {
                let w = local_adj[i];
                let wd = local_adj_dot[i];
                if w == 0.0 && wd == 0.0 {
                    continue;
                }
                match &s.ops[i] {
                    SummandOp::Local(op) => {
                        ror_step(
                            op,
                            i,
                            j,
                            local_vals,
                            local_dot,
                            local_adj,
                            local_adj_dot,
                            w,
                            wd,
                            weight,
                            hess_map,
                            values,
                        );
                    }
                    SummandOp::Shared(k) => {
                        prelude_adj[*k] += w;
                        prelude_adj_dot[*k] += wd;
                    }
                }
            }
            // Reverse sweep over the reachable prelude slots.
            for &i in s.prelude_reach.iter().rev() {
                let w = prelude_adj[i];
                let wd = prelude_adj_dot[i];
                if w == 0.0 && wd == 0.0 {
                    continue;
                }
                ror_step(
                    &self.prelude[i],
                    i,
                    j,
                    prelude_vals,
                    prelude_dot,
                    prelude_adj,
                    prelude_adj_dot,
                    w,
                    wd,
                    weight,
                    hess_map,
                    values,
                );
            }
        }
    }

    /// Structural Hessian sparsity (lower-triangular `(row, col)` pairs) of
    /// the sum of all summands.
    ///
    /// Combines the pairs produced by `hessian_sparsity_impl` on the prelude
    /// with each summand's local pairs. A `Shared` slot inherits the variable
    /// set of the prelude slot it references.
    pub fn hessian_sparsity_all(&self) -> BTreeSet<(usize, usize)> {
        let mut pairs = hessian_sparsity_impl(&self.prelude);

        let prelude_var_sets = compute_var_sets(&self.prelude);

        for s in &self.summands {
            summand_sparsity(&s.ops, &prelude_var_sets, &mut pairs);
        }
        pairs
    }
}
2169
2170fn count_cse_appearances(
2175 e: &Expr,
2176 seen_in_root: &mut HashSet<*const Expr>,
2177 counts: &mut HashMap<*const Expr, usize>,
2178) {
2179 match e {
2180 Expr::Const(_) | Expr::Var(_) => {}
2181 Expr::Binary(_, a, b) => {
2182 count_cse_appearances(a, seen_in_root, counts);
2183 count_cse_appearances(b, seen_in_root, counts);
2184 }
2185 Expr::Unary(_, a) => count_cse_appearances(a, seen_in_root, counts),
2186 Expr::Sum(args) | Expr::MinList(args) | Expr::MaxList(args) => {
2187 for a in args {
2188 count_cse_appearances(a, seen_in_root, counts);
2189 }
2190 }
2191 Expr::Compare(_, a, b) | Expr::And(a, b) | Expr::Or(a, b) => {
2192 count_cse_appearances(a, seen_in_root, counts);
2193 count_cse_appearances(b, seen_in_root, counts);
2194 }
2195 Expr::Not(a) => count_cse_appearances(a, seen_in_root, counts),
2196 Expr::Cond { cond, then_, else_ } => {
2197 count_cse_appearances(cond, seen_in_root, counts);
2198 count_cse_appearances(then_, seen_in_root, counts);
2199 count_cse_appearances(else_, seen_in_root, counts);
2200 }
2201 Expr::Cse(body) => {
2202 let key = Rc::as_ptr(body) as *const Expr;
2203 if seen_in_root.insert(key) {
2204 *counts.entry(key).or_insert(0) += 1;
2205 count_cse_appearances(body, seen_in_root, counts);
2206 }
2207 }
2208 Expr::Funcall { args, .. } => {
2209 for arg in args {
2210 if let FuncallArg::Real(e) = arg {
2211 count_cse_appearances(e, seen_in_root, counts);
2212 }
2213 }
2214 }
2215 }
2216}
2217
2218fn build_into_summand(
2224 expr: &Expr,
2225 local: &mut Vec<SummandOp>,
2226 local_cache: &mut HashMap<*const Expr, usize>,
2227 prelude: &mut Vec<TapeOp>,
2228 prelude_map: &mut HashMap<*const Expr, usize>,
2229 cse_count: &HashMap<*const Expr, usize>,
2230) -> usize {
2231 match expr {
2232 Expr::Const(c) => {
2233 let i = local.len();
2234 local.push(SummandOp::Local(TapeOp::Const(*c)));
2235 i
2236 }
2237 Expr::Var(j) => {
2238 let i = local.len();
2239 local.push(SummandOp::Local(TapeOp::Var(*j)));
2240 i
2241 }
2242 Expr::Binary(op, a, b) => {
2243 if let BinOp::Pow = op {
2244 if let Some(c) = peek_const(b) {
2245 if let Some(i) = try_emit_const_pow_summand(
2246 a,
2247 c,
2248 local,
2249 local_cache,
2250 prelude,
2251 prelude_map,
2252 cse_count,
2253 ) {
2254 return i;
2255 }
2256 }
2257 }
2258 let l = build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
2259 let r = build_into_summand(b, local, local_cache, prelude, prelude_map, cse_count);
2260 let i = local.len();
2261 local.push(SummandOp::Local(match op {
2262 BinOp::Add => TapeOp::Add(l, r),
2263 BinOp::Sub => TapeOp::Sub(l, r),
2264 BinOp::Mul => TapeOp::Mul(l, r),
2265 BinOp::Div => TapeOp::Div(l, r),
2266 BinOp::Pow => TapeOp::Pow(l, r),
2267 BinOp::Atan2 => TapeOp::Atan2(l, r),
2268 }));
2269 i
2270 }
2271 Expr::Unary(op, a) => {
2272 let v = build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
2273 let i = local.len();
2274 local.push(SummandOp::Local(match op {
2275 UnaryOp::Neg => TapeOp::Neg(v),
2276 UnaryOp::Sqrt => TapeOp::Sqrt(v),
2277 UnaryOp::Log => TapeOp::Log(v),
2278 UnaryOp::Log10 => TapeOp::Log10(v),
2279 UnaryOp::Exp => TapeOp::Exp(v),
2280 UnaryOp::Abs => TapeOp::Abs(v),
2281 UnaryOp::Sin => TapeOp::Sin(v),
2282 UnaryOp::Cos => TapeOp::Cos(v),
2283 UnaryOp::Tan => TapeOp::Tan(v),
2284 UnaryOp::Atan => TapeOp::Atan(v),
2285 UnaryOp::Acos => TapeOp::Acos(v),
2286 UnaryOp::Sinh => TapeOp::Sinh(v),
2287 UnaryOp::Cosh => TapeOp::Cosh(v),
2288 UnaryOp::Tanh => TapeOp::Tanh(v),
2289 UnaryOp::Asin => TapeOp::Asin(v),
2290 UnaryOp::Acosh => TapeOp::Acosh(v),
2291 UnaryOp::Asinh => TapeOp::Asinh(v),
2292 UnaryOp::Atanh => TapeOp::Atanh(v),
2293 }));
2294 i
2295 }
2296 Expr::Sum(args) => {
2297 if args.is_empty() {
2298 let i = local.len();
2299 local.push(SummandOp::Local(TapeOp::Const(0.0)));
2300 return i;
2301 }
2302 let mut acc = build_into_summand(
2303 &args[0],
2304 local,
2305 local_cache,
2306 prelude,
2307 prelude_map,
2308 cse_count,
2309 );
2310 for a in &args[1..] {
2311 let nxt =
2312 build_into_summand(a, local, local_cache, prelude, prelude_map, cse_count);
2313 let i = local.len();
2314 local.push(SummandOp::Local(TapeOp::Add(acc, nxt)));
2315 acc = i;
2316 }
2317 acc
2318 }
2319 Expr::Cse(body) => {
2320 let key = Rc::as_ptr(body) as *const Expr;
2321 if let Some(&li) = local_cache.get(&key) {
2322 return li;
2323 }
2324 let promoted = cse_count.get(&key).copied().unwrap_or(0) >= 2;
2325 if promoted {
2326 let pslot =
2331 build_recursive(expr, prelude, prelude_map, &ExternalResolver::default());
2332 let li = local.len();
2333 local.push(SummandOp::Shared(pslot));
2334 local_cache.insert(key, li);
2335 li
2336 } else {
2337 let li =
2338 build_into_summand(body, local, local_cache, prelude, prelude_map, cse_count);
2339 local_cache.insert(key, li);
2340 li
2341 }
2342 }
2343 Expr::Compare(_, _, _)
2344 | Expr::And(_, _)
2345 | Expr::Or(_, _)
2346 | Expr::Not(_)
2347 | Expr::Cond { .. }
2348 | Expr::MinList(_)
2349 | Expr::MaxList(_) => {
2350 panic!(
2351 "HybridTape: conditional / logical / min-max opcodes (comparisons, \
2352 AND/OR/NOT, if-then-else, min/max lists) are not supported on the \
2353 hybrid (partial-separability) tape path. Build with \
2354 Tape::build_with_externals instead."
2355 );
2356 }
2357 Expr::Funcall { .. } => {
2358 panic!(
2359 "HybridTape: AMPL external function calls are not supported on the \
2360 hybrid (partial-separability) tape path. Build with Tape::build_with_externals \
2361 instead."
2362 );
2363 }
2364 }
2365}
2366
2367fn try_emit_const_pow_summand(
2370 base_expr: &Expr,
2371 c: f64,
2372 local: &mut Vec<SummandOp>,
2373 local_cache: &mut HashMap<*const Expr, usize>,
2374 prelude: &mut Vec<TapeOp>,
2375 prelude_map: &mut HashMap<*const Expr, usize>,
2376 cse_count: &HashMap<*const Expr, usize>,
2377) -> Option<usize> {
2378 if c == 0.0 {
2379 let i = local.len();
2380 local.push(SummandOp::Local(TapeOp::Const(1.0)));
2381 return Some(i);
2382 }
2383 if c == 1.0 {
2384 return Some(build_into_summand(
2385 base_expr,
2386 local,
2387 local_cache,
2388 prelude,
2389 prelude_map,
2390 cse_count,
2391 ));
2392 }
2393 if c == 0.5 {
2394 let b = build_into_summand(
2395 base_expr,
2396 local,
2397 local_cache,
2398 prelude,
2399 prelude_map,
2400 cse_count,
2401 );
2402 let i = local.len();
2403 local.push(SummandOp::Local(TapeOp::Sqrt(b)));
2404 return Some(i);
2405 }
2406 if c.is_finite() && c.fract() == 0.0 && c.abs() <= 8.0 {
2407 let n = c.abs() as u32;
2408 if n == 0 {
2409 let i = local.len();
2410 local.push(SummandOp::Local(TapeOp::Const(1.0)));
2411 return Some(i);
2412 }
2413 let b = build_into_summand(
2414 base_expr,
2415 local,
2416 local_cache,
2417 prelude,
2418 prelude_map,
2419 cse_count,
2420 );
2421 let pos = emit_int_pow_summand(b, n, local);
2422 if c < 0.0 {
2423 let one_idx = local.len();
2424 local.push(SummandOp::Local(TapeOp::Const(1.0)));
2425 let i = local.len();
2426 local.push(SummandOp::Local(TapeOp::Div(one_idx, pos)));
2427 return Some(i);
2428 }
2429 return Some(pos);
2430 }
2431 None
2432}
2433
2434fn emit_int_pow_summand(base: usize, n: u32, local: &mut Vec<SummandOp>) -> usize {
2435 debug_assert!(n >= 1);
2436 if n == 1 {
2437 return base;
2438 }
2439 let half = emit_int_pow_summand(base, n / 2, local);
2440 let squared = local.len();
2441 local.push(SummandOp::Local(TapeOp::Mul(half, half)));
2442 if n % 2 == 1 {
2443 let i = local.len();
2444 local.push(SummandOp::Local(TapeOp::Mul(squared, base)));
2445 i
2446 } else {
2447 squared
2448 }
2449}
2450
2451fn compute_local_reach(ops: &[SummandOp], root: usize) -> (Vec<usize>, Vec<usize>) {
2455 let mut visited = vec![false; ops.len()];
2456 let mut reach: Vec<usize> = Vec::new();
2457 let mut shared: BTreeSet<usize> = BTreeSet::new();
2458 let mut stack: Vec<usize> = Vec::with_capacity(16);
2459 visited[root] = true;
2460 reach.push(root);
2461 stack.push(root);
2462 while let Some(s) = stack.pop() {
2463 match &ops[s] {
2464 SummandOp::Local(op) => {
2465 let (a, b) = op_operands(op);
2466 if let Some(a) = a {
2467 if !visited[a] {
2468 visited[a] = true;
2469 reach.push(a);
2470 stack.push(a);
2471 }
2472 }
2473 if let Some(b) = b {
2474 if !visited[b] {
2475 visited[b] = true;
2476 reach.push(b);
2477 stack.push(b);
2478 }
2479 }
2480 }
2481 SummandOp::Shared(k) => {
2482 shared.insert(*k);
2483 }
2484 }
2485 }
2486 reach.sort_unstable();
2487 (reach, shared.into_iter().collect())
2488}
2489
2490fn bfs_prelude(
2494 prelude: &[TapeOp],
2495 start: usize,
2496 visited: &mut [u32],
2497 cur: u32,
2498 stack: &mut Vec<usize>,
2499 out: &mut Vec<usize>,
2500) {
2501 if visited[start] == cur {
2502 return;
2503 }
2504 visited[start] = cur;
2505 out.push(start);
2506 stack.push(start);
2507 while let Some(s) = stack.pop() {
2508 let (a, b) = op_operands(&prelude[s]);
2509 if let Some(a) = a {
2510 if visited[a] != cur {
2511 visited[a] = cur;
2512 out.push(a);
2513 stack.push(a);
2514 }
2515 }
2516 if let Some(b) = b {
2517 if visited[b] != cur {
2518 visited[b] = cur;
2519 out.push(b);
2520 stack.push(b);
2521 }
2522 }
2523 }
2524}
2525
2526fn compute_var_sets(ops: &[TapeOp]) -> Vec<BTreeSet<usize>> {
2530 let mut out: Vec<BTreeSet<usize>> = Vec::with_capacity(ops.len());
2531 for op in ops {
2532 let vs: BTreeSet<usize> = match op {
2533 TapeOp::Const(_) => BTreeSet::new(),
2534 TapeOp::Var(j) => {
2535 let mut s = BTreeSet::new();
2536 s.insert(*j);
2537 s
2538 }
2539 TapeOp::Add(a, b)
2540 | TapeOp::Sub(a, b)
2541 | TapeOp::Mul(a, b)
2542 | TapeOp::Div(a, b)
2543 | TapeOp::Pow(a, b)
2544 | TapeOp::Atan2(a, b) => out[*a].union(&out[*b]).copied().collect(),
2545 TapeOp::Neg(a)
2546 | TapeOp::Abs(a)
2547 | TapeOp::Sqrt(a)
2548 | TapeOp::Exp(a)
2549 | TapeOp::Log(a)
2550 | TapeOp::Log10(a)
2551 | TapeOp::Sin(a)
2552 | TapeOp::Cos(a)
2553 | TapeOp::Tan(a)
2554 | TapeOp::Atan(a)
2555 | TapeOp::Acos(a)
2556 | TapeOp::Sinh(a)
2557 | TapeOp::Cosh(a)
2558 | TapeOp::Tanh(a)
2559 | TapeOp::Asin(a)
2560 | TapeOp::Acosh(a)
2561 | TapeOp::Asinh(a)
2562 | TapeOp::Atanh(a) => out[*a].clone(),
2563 TapeOp::Cmp(_, _, _)
2564 | TapeOp::And(_, _)
2565 | TapeOp::Or(_, _)
2566 | TapeOp::Not(_)
2567 | TapeOp::Select(_, _, _)
2568 | TapeOp::Min(_, _)
2569 | TapeOp::Max(_, _) => unreachable!(
2570 "HybridTape prelude cannot contain conditional / logical / min-max \
2571 TapeOps; build_into_summand panics on those Expr variants."
2572 ),
2573 TapeOp::Funcall(_) => unreachable!(
2574 "HybridTape prelude cannot contain TapeOp::Funcall; \
2575 build_into_summand panics on Expr::Funcall."
2576 ),
2577 };
2578 out.push(vs);
2579 }
2580 out
2581}
2582
2583fn summand_sparsity(
2588 ops: &[SummandOp],
2589 prelude_var_sets: &[BTreeSet<usize>],
2590 pairs: &mut BTreeSet<(usize, usize)>,
2591) {
2592 let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(ops.len());
2593 let emit_cross =
2594 |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
2595 for &v1 in s1 {
2596 for &v2 in s2 {
2597 let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
2598 pairs.insert((r, c));
2599 }
2600 }
2601 };
2602 let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
2603 let vars: Vec<usize> = s.iter().copied().collect();
2604 for (ai, &vi) in vars.iter().enumerate() {
2605 for &vj in &vars[..=ai] {
2606 let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
2607 pairs.insert((r, c));
2608 }
2609 }
2610 };
2611 for so in ops {
2612 let vset: BTreeSet<usize> = match so {
2613 SummandOp::Shared(k) => prelude_var_sets[*k].clone(),
2614 SummandOp::Local(op) => match op {
2615 TapeOp::Const(_) => BTreeSet::new(),
2616 TapeOp::Var(j) => {
2617 let mut s = BTreeSet::new();
2618 s.insert(*j);
2619 s
2620 }
2621 TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
2622 var_sets[*a].union(&var_sets[*b]).copied().collect()
2623 }
2624 TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
2625 TapeOp::Mul(a, b) => {
2626 emit_cross(&var_sets[*a], &var_sets[*b], pairs);
2627 var_sets[*a].union(&var_sets[*b]).copied().collect()
2628 }
2629 TapeOp::Div(a, b) => {
2630 emit_cross(&var_sets[*a], &var_sets[*b], pairs);
2631 emit_self(&var_sets[*b], pairs);
2632 var_sets[*a].union(&var_sets[*b]).copied().collect()
2633 }
2634 TapeOp::Pow(a, b) | TapeOp::Atan2(a, b) => {
2635 let combined: BTreeSet<usize> =
2636 var_sets[*a].union(&var_sets[*b]).copied().collect();
2637 emit_self(&combined, pairs);
2638 combined
2639 }
2640 TapeOp::Sqrt(a)
2641 | TapeOp::Exp(a)
2642 | TapeOp::Log(a)
2643 | TapeOp::Log10(a)
2644 | TapeOp::Sin(a)
2645 | TapeOp::Cos(a)
2646 | TapeOp::Tan(a)
2647 | TapeOp::Atan(a)
2648 | TapeOp::Acos(a)
2649 | TapeOp::Sinh(a)
2650 | TapeOp::Cosh(a)
2651 | TapeOp::Tanh(a)
2652 | TapeOp::Asin(a)
2653 | TapeOp::Acosh(a)
2654 | TapeOp::Asinh(a)
2655 | TapeOp::Atanh(a) => {
2656 emit_self(&var_sets[*a], pairs);
2657 var_sets[*a].clone()
2658 }
2659 TapeOp::Cmp(_, _, _)
2660 | TapeOp::And(_, _)
2661 | TapeOp::Or(_, _)
2662 | TapeOp::Not(_)
2663 | TapeOp::Select(_, _, _)
2664 | TapeOp::Min(_, _)
2665 | TapeOp::Max(_, _) => unreachable!(
2666 "HybridTape summand cannot contain conditional / logical / min-max \
2667 TapeOps; build_into_summand panics on those Expr variants."
2668 ),
2669 TapeOp::Funcall(_) => unreachable!(
2670 "HybridTape summand cannot contain TapeOp::Funcall; \
2671 build_into_summand panics on Expr::Funcall."
2672 ),
2673 },
2674 };
2675 var_sets.push(vset);
2676 }
2677}
2678
2679#[inline]
2682fn op_operands(op: &TapeOp) -> (Option<usize>, Option<usize>) {
2683 match op {
2684 TapeOp::Const(_) | TapeOp::Var(_) => (None, None),
2685 TapeOp::Add(a, b)
2686 | TapeOp::Sub(a, b)
2687 | TapeOp::Mul(a, b)
2688 | TapeOp::Div(a, b)
2689 | TapeOp::Pow(a, b)
2690 | TapeOp::Atan2(a, b) => (Some(*a), Some(*b)),
2691 TapeOp::Neg(a)
2692 | TapeOp::Abs(a)
2693 | TapeOp::Sqrt(a)
2694 | TapeOp::Exp(a)
2695 | TapeOp::Log(a)
2696 | TapeOp::Log10(a)
2697 | TapeOp::Sin(a)
2698 | TapeOp::Cos(a)
2699 | TapeOp::Tan(a)
2700 | TapeOp::Atan(a)
2701 | TapeOp::Acos(a)
2702 | TapeOp::Sinh(a)
2703 | TapeOp::Cosh(a)
2704 | TapeOp::Tanh(a)
2705 | TapeOp::Asin(a)
2706 | TapeOp::Acosh(a)
2707 | TapeOp::Asinh(a)
2708 | TapeOp::Atanh(a) => (Some(*a), None),
2709 TapeOp::Cmp(_, a, b) | TapeOp::And(a, b) | TapeOp::Or(a, b) => (Some(*a), Some(*b)),
2715 TapeOp::Not(a) => (Some(*a), None),
2716 TapeOp::Select(_, _, _) => unreachable!(
2717 "op_operands: TapeOp::Select has three operands and is unsupported on \
2718 the HybridTape path"
2719 ),
2720 TapeOp::Min(_, _) | TapeOp::Max(_, _) => unreachable!(
2721 "op_operands: TapeOp::Min/Max are unsupported on the HybridTape path \
2722 (build_into_summand rejects min/max lists)"
2723 ),
2724 TapeOp::Funcall(_) => (None, None),
2725 }
2726}
2727
2728fn vars_in(ops: &[TapeOp], reach: &[usize]) -> Vec<usize> {
2729 let mut s: BTreeSet<usize> = BTreeSet::new();
2730 for &i in reach {
2731 if let TapeOp::Var(j) = &ops[i] {
2732 s.insert(*j);
2733 }
2734 }
2735 s.into_iter().collect()
2736}
2737
2738#[inline]
2741fn fwd_step(op: &TapeOp, x: &[f64], vals: &[f64]) -> f64 {
2742 match op {
2743 TapeOp::Const(c) => *c,
2744 TapeOp::Var(i) => x[*i],
2745 TapeOp::Add(a, b) => vals[*a] + vals[*b],
2746 TapeOp::Sub(a, b) => vals[*a] - vals[*b],
2747 TapeOp::Mul(a, b) => vals[*a] * vals[*b],
2748 TapeOp::Div(a, b) => vals[*a] / vals[*b],
2749 TapeOp::Pow(a, b) => vals[*a].powf(vals[*b]),
2750 TapeOp::Neg(a) => -vals[*a],
2751 TapeOp::Abs(a) => vals[*a].abs(),
2752 TapeOp::Sqrt(a) => vals[*a].sqrt(),
2753 TapeOp::Exp(a) => vals[*a].exp(),
2754 TapeOp::Log(a) => vals[*a].ln(),
2755 TapeOp::Log10(a) => vals[*a].log10(),
2756 TapeOp::Sin(a) => vals[*a].sin(),
2757 TapeOp::Cos(a) => vals[*a].cos(),
2758 TapeOp::Tan(a) => vals[*a].tan(),
2759 TapeOp::Atan(a) => vals[*a].atan(),
2760 TapeOp::Acos(a) => vals[*a].acos(),
2761 TapeOp::Sinh(a) => vals[*a].sinh(),
2762 TapeOp::Cosh(a) => vals[*a].cosh(),
2763 TapeOp::Tanh(a) => vals[*a].tanh(),
2764 TapeOp::Asin(a) => vals[*a].asin(),
2765 TapeOp::Acosh(a) => vals[*a].acosh(),
2766 TapeOp::Asinh(a) => vals[*a].asinh(),
2767 TapeOp::Atanh(a) => vals[*a].atanh(),
2768 TapeOp::Atan2(a, b) => vals[*a].atan2(vals[*b]),
2769 TapeOp::Cmp(_, _, _)
2770 | TapeOp::And(_, _)
2771 | TapeOp::Or(_, _)
2772 | TapeOp::Not(_)
2773 | TapeOp::Select(_, _, _)
2774 | TapeOp::Min(_, _)
2775 | TapeOp::Max(_, _) => panic!(
2776 "GlobalTape free-function kernels do not implement conditional / logical \
2777 / min-max TapeOps; use the Tape (build_with_externals) interpreter path \
2778 instead."
2779 ),
2780 TapeOp::Funcall(fc) => {
2781 let FuncallData { lib, name, args } = fc.as_ref();
2782 let call_args = funcall_to_ext_args(args, vals);
2783 let res = lib
2784 .eval(name, &call_args, false, false)
2785 .unwrap_or_else(|e| panic!("external function '{name}' eval failed: {e}"));
2786 res.value
2787 }
2788 }
2789}
2790
/// Reverse-mode (adjoint) step for tape slot `i`, whose op is `op`.
///
/// `a` is the adjoint of slot `i`. It is multiplied by the op's local partial
/// derivatives and accumulated into `adj` at the operand slots. For `Var(j)`
/// it is accumulated into `grad[j]` instead. `vals` holds the forward values
/// of all slots, so `vals[i]` is this op's own value.
///
/// # Panics
///
/// Panics on conditional / logical / min-max ops, which the GlobalTape
/// free-function kernels do not implement. Also panics when an external
/// function's derivative evaluation returns an error.
#[inline]
fn rev_step(op: &TapeOp, i: usize, vals: &[f64], adj: &mut [f64], a: f64, grad: &mut [f64]) {
    match op {
        TapeOp::Const(_) => {}
        TapeOp::Var(j) => {
            grad[*j] += a;
        }
        TapeOp::Add(l, r) => {
            adj[*l] += a;
            adj[*r] += a;
        }
        TapeOp::Sub(l, r) => {
            adj[*l] += a;
            adj[*r] -= a;
        }
        TapeOp::Mul(l, r) => {
            adj[*l] += a * vals[*r];
            adj[*r] += a * vals[*l];
        }
        TapeOp::Div(l, r) => {
            // d(l/r)/dl = 1/r, d(l/r)/dr = -l/r^2.
            let rv = vals[*r];
            adj[*l] += a / rv;
            adj[*r] -= a * vals[*l] / (rv * rv);
        }
        TapeOp::Pow(l, r) => {
            let lv = vals[*l];
            let rv = vals[*r];
            // Base partial r * l^(r-1); skipped when r == 0 (partial is 0).
            // NOTE(review): at l == 0 with r < 1 this adds an infinite adjoint,
            // whereas the tangent kernel drops the term there — confirm intended.
            if rv != 0.0 {
                adj[*l] += a * rv * lv.powf(rv - 1.0);
            }
            // Exponent partial l^r * ln(l) = vals[i] * ln(l); only for l > 0.
            if lv > 0.0 {
                adj[*r] += a * vals[i] * lv.ln();
            }
        }
        TapeOp::Neg(j) => {
            adj[*j] -= a;
        }
        TapeOp::Abs(j) => {
            // Sign of the operand; the value 0 takes the +1 branch.
            if vals[*j] >= 0.0 {
                adj[*j] += a;
            } else {
                adj[*j] -= a;
            }
        }
        TapeOp::Sqrt(j) => {
            // d sqrt(u) = 1 / (2 sqrt(u)); skipped where sqrt(u) == 0.
            let sv = vals[i];
            if sv > 0.0 {
                adj[*j] += a * 0.5 / sv;
            }
        }
        TapeOp::Exp(j) => {
            // exp' = exp, which is this op's own value.
            adj[*j] += a * vals[i];
        }
        TapeOp::Log(j) => {
            adj[*j] += a / vals[*j];
        }
        TapeOp::Log10(j) => {
            adj[*j] += a / (vals[*j] * std::f64::consts::LN_10);
        }
        TapeOp::Sin(j) => {
            adj[*j] += a * vals[*j].cos();
        }
        TapeOp::Cos(j) => {
            adj[*j] -= a * vals[*j].sin();
        }
        TapeOp::Tan(j) => {
            // tan' = 1 + tan^2, using this op's own value.
            let t = vals[i];
            adj[*j] += a * (1.0 + t * t);
        }
        TapeOp::Atan(j) => {
            let u = vals[*j];
            adj[*j] += a / (1.0 + u * u);
        }
        TapeOp::Acos(j) => {
            let u = vals[*j];
            adj[*j] -= a / (1.0 - u * u).sqrt();
        }
        TapeOp::Sinh(j) => {
            adj[*j] += a * vals[*j].cosh();
        }
        TapeOp::Cosh(j) => {
            adj[*j] += a * vals[*j].sinh();
        }
        TapeOp::Tanh(j) => {
            // tanh' = 1 - tanh^2, using this op's own value.
            let t = vals[i];
            adj[*j] += a * (1.0 - t * t);
        }
        TapeOp::Asin(j) => {
            let u = vals[*j];
            adj[*j] += a / (1.0 - u * u).sqrt();
        }
        TapeOp::Acosh(j) => {
            let u = vals[*j];
            adj[*j] += a / (u * u - 1.0).sqrt();
        }
        TapeOp::Asinh(j) => {
            let u = vals[*j];
            adj[*j] += a / (u * u + 1.0).sqrt();
        }
        TapeOp::Atanh(j) => {
            let u = vals[*j];
            adj[*j] += a / (1.0 - u * u);
        }
        TapeOp::Atan2(l, r) => {
            // atan2(y, x): d/dy = x / (x^2 + y^2), d/dx = -y / (x^2 + y^2).
            let y = vals[*l];
            let x = vals[*r];
            let d = y * y + x * x;
            adj[*l] += a * (x / d);
            adj[*r] += a * (-y / d);
        }
        TapeOp::Cmp(_, _, _)
        | TapeOp::And(_, _)
        | TapeOp::Or(_, _)
        | TapeOp::Not(_)
        | TapeOp::Select(_, _, _)
        | TapeOp::Min(_, _)
        | TapeOp::Max(_, _) => panic!(
            "GlobalTape free-function kernels do not implement conditional / logical \
             / min-max TapeOps; use the Tape (build_with_externals) interpreter path \
             instead."
        ),
        TapeOp::Funcall(fc) => {
            let FuncallData { lib, name, args } = fc.as_ref();
            let call_args = funcall_to_ext_args(args, vals);
            let res = lib
                .eval(name, &call_args, true, false)
                .unwrap_or_else(|e| panic!("external function '{name}' reverse eval failed: {e}"));
            let derivs = res.derivs.expect("want_derivs=true returns derivs");
            // derivs[k] is the partial w.r.t. the k-th tape (real) argument;
            // string arguments are skipped and do not advance k.
            let mut k = 0usize;
            for arg in args {
                if let TapeFuncallArg::Tape(idx) = arg {
                    adj[*idx] += a * derivs[k];
                    k += 1;
                }
            }
            let _ = i;
            let _ = grad;
        }
    }
}
2931
2932#[inline]
2933fn fwd_tan_step(op: &TapeOp, seed_var: usize, vals: &[f64], dot: &[f64], i: usize) -> f64 {
2934 match op {
2935 TapeOp::Const(_) => 0.0,
2936 TapeOp::Var(k) => {
2937 if *k == seed_var {
2938 1.0
2939 } else {
2940 0.0
2941 }
2942 }
2943 TapeOp::Add(a, b) => dot[*a] + dot[*b],
2944 TapeOp::Sub(a, b) => dot[*a] - dot[*b],
2945 TapeOp::Mul(a, b) => dot[*a] * vals[*b] + vals[*a] * dot[*b],
2946 TapeOp::Div(a, b) => {
2947 let vb = vals[*b];
2948 (dot[*a] * vb - vals[*a] * dot[*b]) / (vb * vb)
2949 }
2950 TapeOp::Pow(a, b) => {
2951 let u = vals[*a];
2952 let r = vals[*b];
2953 let du = dot[*a];
2954 let dr = dot[*b];
2955 let mut result = 0.0;
2956 if r != 0.0 && u != 0.0 {
2957 result += r * u.powf(r - 1.0) * du;
2958 }
2959 if u > 0.0 {
2960 result += vals[i] * u.ln() * dr;
2961 }
2962 result
2963 }
2964 TapeOp::Neg(a) => -dot[*a],
2965 TapeOp::Abs(a) => {
2966 if vals[*a] >= 0.0 {
2967 dot[*a]
2968 } else {
2969 -dot[*a]
2970 }
2971 }
2972 TapeOp::Sqrt(a) => {
2973 let sv = vals[i];
2974 if sv > 0.0 {
2975 dot[*a] * 0.5 / sv
2976 } else {
2977 0.0
2978 }
2979 }
2980 TapeOp::Exp(a) => dot[*a] * vals[i],
2981 TapeOp::Log(a) => dot[*a] / vals[*a],
2982 TapeOp::Log10(a) => dot[*a] / (vals[*a] * std::f64::consts::LN_10),
2983 TapeOp::Sin(a) => dot[*a] * vals[*a].cos(),
2984 TapeOp::Cos(a) => -dot[*a] * vals[*a].sin(),
2985 TapeOp::Tan(a) => {
2986 let t = vals[i];
2987 dot[*a] * (1.0 + t * t)
2988 }
2989 TapeOp::Atan(a) => {
2990 let u = vals[*a];
2991 dot[*a] / (1.0 + u * u)
2992 }
2993 TapeOp::Acos(a) => {
2994 let u = vals[*a];
2995 -dot[*a] / (1.0 - u * u).sqrt()
2996 }
2997 TapeOp::Sinh(a) => dot[*a] * vals[*a].cosh(),
2998 TapeOp::Cosh(a) => dot[*a] * vals[*a].sinh(),
2999 TapeOp::Tanh(a) => {
3000 let t = vals[i];
3001 dot[*a] * (1.0 - t * t)
3002 }
3003 TapeOp::Asin(a) => {
3004 let u = vals[*a];
3005 dot[*a] / (1.0 - u * u).sqrt()
3006 }
3007 TapeOp::Acosh(a) => {
3008 let u = vals[*a];
3009 dot[*a] / (u * u - 1.0).sqrt()
3010 }
3011 TapeOp::Asinh(a) => {
3012 let u = vals[*a];
3013 dot[*a] / (u * u + 1.0).sqrt()
3014 }
3015 TapeOp::Atanh(a) => {
3016 let u = vals[*a];
3017 dot[*a] / (1.0 - u * u)
3018 }
3019 TapeOp::Atan2(a, b) => {
3020 let y = vals[*a];
3021 let x = vals[*b];
3022 let d = y * y + x * x;
3023 (x * dot[*a] - y * dot[*b]) / d
3024 }
3025 TapeOp::Cmp(_, _, _)
3026 | TapeOp::And(_, _)
3027 | TapeOp::Or(_, _)
3028 | TapeOp::Not(_)
3029 | TapeOp::Select(_, _, _)
3030 | TapeOp::Min(_, _)
3031 | TapeOp::Max(_, _) => panic!(
3032 "GlobalTape free-function kernels do not implement conditional / logical \
3033 / min-max TapeOps; use the Tape (build_with_externals) interpreter path \
3034 instead."
3035 ),
3036 TapeOp::Funcall(fc) => {
3037 let FuncallData { lib, name, args } = fc.as_ref();
3038 let call_args = funcall_to_ext_args(args, vals);
3039 let res = lib
3040 .eval(name, &call_args, true, false)
3041 .unwrap_or_else(|e| panic!("external function '{name}' tangent eval failed: {e}"));
3042 let derivs = res.derivs.expect("want_derivs=true returns derivs");
3043 let mut acc = 0.0;
3044 let mut k = 0usize;
3045 for arg in args {
3046 if let TapeFuncallArg::Tape(idx) = arg {
3047 acc += derivs[k] * dot[*idx];
3048 k += 1;
3049 }
3050 }
3051 let _ = seed_var;
3052 acc
3053 }
3054 }
3055}
3056
3057#[allow(clippy::too_many_arguments)]
3058#[inline]
3059fn ror_step(
3060 op: &TapeOp,
3061 i: usize,
3062 seed_var: usize,
3063 vals: &[f64],
3064 dot: &[f64],
3065 adj: &mut [f64],
3066 adj_dot: &mut [f64],
3067 w: f64,
3068 wd: f64,
3069 weight: f64,
3070 hess_map: &HashMap<(usize, usize), usize>,
3071 values: &mut [f64],
3072) {
3073 match op {
3074 TapeOp::Const(_) => {}
3075 TapeOp::Var(k) => {
3076 if wd != 0.0 && *k >= seed_var {
3077 if let Some(&pos) = hess_map.get(&(*k, seed_var)) {
3078 values[pos] += weight * wd;
3079 }
3080 }
3081 }
3082 TapeOp::Add(a, b) => {
3083 adj[*a] += w;
3084 adj[*b] += w;
3085 adj_dot[*a] += wd;
3086 adj_dot[*b] += wd;
3087 }
3088 TapeOp::Sub(a, b) => {
3089 adj[*a] += w;
3090 adj[*b] -= w;
3091 adj_dot[*a] += wd;
3092 adj_dot[*b] -= wd;
3093 }
3094 TapeOp::Mul(a, b) => {
3095 adj[*a] += w * vals[*b];
3096 adj[*b] += w * vals[*a];
3097 adj_dot[*a] += wd * vals[*b] + w * dot[*b];
3098 adj_dot[*b] += wd * vals[*a] + w * dot[*a];
3099 }
3100 TapeOp::Div(a, b) => {
3101 let vb = vals[*b];
3102 let vb2 = vb * vb;
3103 let vb3 = vb2 * vb;
3104 adj[*a] += w / vb;
3105 adj_dot[*a] += wd / vb + w * (-dot[*b] / vb2);
3106 adj[*b] += w * (-vals[*a] / vb2);
3107 adj_dot[*b] +=
3108 wd * (-vals[*a] / vb2) + w * (-dot[*a] / vb2 + 2.0 * vals[*a] * dot[*b] / vb3);
3109 }
3110 TapeOp::Pow(a, b) => {
3111 let u = vals[*a];
3112 let r = vals[*b];
3113 let du = dot[*a];
3114 let dr = dot[*b];
3115 if r != 0.0 {
3116 if u != 0.0 {
3117 let p_a = r * u.powf(r - 1.0);
3118 adj[*a] += w * p_a;
3119 let mut dp_a = dr * u.powf(r - 1.0);
3120 if u > 0.0 {
3121 dp_a += r * u.powf(r - 1.0) * ((r - 1.0) * du / u + dr * u.ln());
3122 } else {
3123 dp_a += r * (r - 1.0) * u.powf(r - 2.0) * du;
3124 }
3125 adj_dot[*a] += wd * p_a + w * dp_a;
3126 } else if r >= 2.0 {
3127 let p_a = 0.0;
3128 adj[*a] += w * p_a;
3129 let dp_a = if r == 2.0 {
3130 2.0 * du
3131 } else {
3132 r * (r - 1.0) * (0.0_f64).powf(r - 2.0) * du
3133 };
3134 adj_dot[*a] += wd * p_a + w * dp_a;
3135 }
3136 }
3137 if u > 0.0 {
3138 let ln_u = u.ln();
3139 let p_b = vals[i] * ln_u;
3140 adj[*b] += w * p_b;
3141 let dur = vals[i] * (r * du / u + dr * ln_u);
3142 let dp_b = dur * ln_u + vals[i] * du / u;
3143 adj_dot[*b] += wd * p_b + w * dp_b;
3144 }
3145 }
3146 TapeOp::Neg(a) => {
3147 adj[*a] -= w;
3148 adj_dot[*a] -= wd;
3149 }
3150 TapeOp::Abs(a) => {
3151 let s = if vals[*a] >= 0.0 { 1.0 } else { -1.0 };
3152 adj[*a] += w * s;
3153 adj_dot[*a] += wd * s;
3154 }
3155 TapeOp::Sqrt(a) => {
3156 let sv = vals[i];
3157 if sv > 0.0 {
3158 let fp = 0.5 / sv;
3159 let fpp = -0.25 / (vals[*a] * sv);
3160 adj[*a] += w * fp;
3161 adj_dot[*a] += wd * fp + w * fpp * dot[*a];
3162 }
3163 }
3164 TapeOp::Exp(a) => {
3165 let ev = vals[i];
3166 adj[*a] += w * ev;
3167 adj_dot[*a] += wd * ev + w * ev * dot[*a];
3168 }
3169 TapeOp::Log(a) => {
3170 let u = vals[*a];
3171 adj[*a] += w / u;
3172 adj_dot[*a] += wd / u + w * (-1.0 / (u * u)) * dot[*a];
3173 }
3174 TapeOp::Log10(a) => {
3175 let u = vals[*a];
3176 let c = std::f64::consts::LN_10;
3177 adj[*a] += w / (u * c);
3178 adj_dot[*a] += wd / (u * c) + w * (-1.0 / (u * u * c)) * dot[*a];
3179 }
3180 TapeOp::Sin(a) => {
3181 let u = vals[*a];
3182 let cu = u.cos();
3183 adj[*a] += w * cu;
3184 adj_dot[*a] += wd * cu + w * (-u.sin()) * dot[*a];
3185 }
3186 TapeOp::Cos(a) => {
3187 let u = vals[*a];
3188 let su = u.sin();
3189 adj[*a] -= w * su;
3190 adj_dot[*a] += wd * (-su) + w * (-u.cos()) * dot[*a];
3191 }
3192 TapeOp::Tan(a) => {
3193 let t = vals[i];
3194 let gp = 1.0 + t * t;
3195 let gpp = 2.0 * t * gp;
3196 adj[*a] += w * gp;
3197 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3198 }
3199 TapeOp::Atan(a) => {
3200 let u = vals[*a];
3201 let d = 1.0 + u * u;
3202 let gp = 1.0 / d;
3203 let gpp = -2.0 * u / (d * d);
3204 adj[*a] += w * gp;
3205 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3206 }
3207 TapeOp::Acos(a) => {
3208 let u = vals[*a];
3209 let s = 1.0 - u * u;
3210 let r = s.sqrt();
3211 let gp = -1.0 / r;
3212 let gpp = -u / (s * r);
3213 adj[*a] += w * gp;
3214 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3215 }
3216 TapeOp::Sinh(a) => {
3217 let u = vals[*a];
3218 let gp = u.cosh();
3219 let gpp = vals[i]; adj[*a] += w * gp;
3221 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3222 }
3223 TapeOp::Cosh(a) => {
3224 let u = vals[*a];
3225 let gp = u.sinh();
3226 let gpp = vals[i]; adj[*a] += w * gp;
3228 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3229 }
3230 TapeOp::Tanh(a) => {
3231 let t = vals[i];
3232 let gp = 1.0 - t * t;
3233 let gpp = -2.0 * t * gp;
3234 adj[*a] += w * gp;
3235 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3236 }
3237 TapeOp::Asin(a) => {
3238 let u = vals[*a];
3239 let s = 1.0 - u * u;
3240 let r = s.sqrt();
3241 let gp = 1.0 / r;
3242 let gpp = u / (s * r);
3243 adj[*a] += w * gp;
3244 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3245 }
3246 TapeOp::Acosh(a) => {
3247 let u = vals[*a];
3248 let s = u * u - 1.0;
3249 let r = s.sqrt();
3250 let gp = 1.0 / r;
3251 let gpp = -u / (s * r);
3252 adj[*a] += w * gp;
3253 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3254 }
3255 TapeOp::Asinh(a) => {
3256 let u = vals[*a];
3257 let s = u * u + 1.0;
3258 let r = s.sqrt();
3259 let gp = 1.0 / r;
3260 let gpp = -u / (s * r);
3261 adj[*a] += w * gp;
3262 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3263 }
3264 TapeOp::Atanh(a) => {
3265 let u = vals[*a];
3266 let d = 1.0 - u * u;
3267 let gp = 1.0 / d;
3268 let gpp = 2.0 * u / (d * d);
3269 adj[*a] += w * gp;
3270 adj_dot[*a] += wd * gp + w * gpp * dot[*a];
3271 }
3272 TapeOp::Atan2(a, b) => {
3273 let y = vals[*a];
3274 let x = vals[*b];
3275 let d = y * y + x * x;
3276 let d2 = d * d;
3277 let fa = x / d;
3278 let fb = -y / d;
3279 let faa = -2.0 * x * y / d2;
3280 let fab = (y * y - x * x) / d2;
3281 let fbb = 2.0 * x * y / d2;
3282 adj[*a] += w * fa;
3283 adj[*b] += w * fb;
3284 adj_dot[*a] += wd * fa + w * (faa * dot[*a] + fab * dot[*b]);
3285 adj_dot[*b] += wd * fb + w * (fab * dot[*a] + fbb * dot[*b]);
3286 }
3287 TapeOp::Cmp(_, _, _)
3288 | TapeOp::And(_, _)
3289 | TapeOp::Or(_, _)
3290 | TapeOp::Not(_)
3291 | TapeOp::Select(_, _, _)
3292 | TapeOp::Min(_, _)
3293 | TapeOp::Max(_, _) => panic!(
3294 "GlobalTape free-function kernels do not implement conditional / logical \
3295 / min-max TapeOps; use the Tape (build_with_externals) interpreter path \
3296 instead."
3297 ),
3298 TapeOp::Funcall(fc) => {
3299 let FuncallData { lib, name, args } = fc.as_ref();
3300 let call_args = funcall_to_ext_args(args, vals);
3301 let res = lib.eval(name, &call_args, true, true).unwrap_or_else(|e| {
3302 panic!("external function '{name}' 2nd-order eval failed: {e}")
3303 });
3304 let derivs = res.derivs.expect("want_derivs=true returns derivs");
3305 let hes = res.hessian.expect("want_hes=true returns hessian");
3306 let real_tape: Vec<usize> = args
3307 .iter()
3308 .filter_map(|a| match a {
3309 TapeFuncallArg::Tape(t) => Some(*t),
3310 TapeFuncallArg::Str(_) => None,
3311 })
3312 .collect();
3313 for (k, &tk) in real_tape.iter().enumerate() {
3314 adj[tk] += w * derivs[k];
3315 let mut second_term = 0.0;
3316 for (l, &tl) in real_tape.iter().enumerate() {
3317 let (lo, hi) = if k <= l { (k, l) } else { (l, k) };
3318 let h_kl = hes[lo + hi * (hi + 1) / 2];
3319 second_term += h_kl * dot[tl];
3320 }
3321 adj_dot[tk] += wd * derivs[k] + w * second_term;
3322 }
3323 let _ = seed_var;
3324 let _ = hess_map;
3325 let _ = values;
3326 let _ = weight;
3327 let _ = i;
3328 }
3329 }
3330}
3331
3332fn hessian_sparsity_impl(ops: &[TapeOp]) -> BTreeSet<(usize, usize)> {
3336 let n = ops.len();
3337 let mut var_sets: Vec<BTreeSet<usize>> = Vec::with_capacity(n);
3338 let mut pairs: BTreeSet<(usize, usize)> = BTreeSet::new();
3339
3340 let emit_cross =
3341 |s1: &BTreeSet<usize>, s2: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
3342 for &v1 in s1 {
3343 for &v2 in s2 {
3344 let (r, c) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
3345 pairs.insert((r, c));
3346 }
3347 }
3348 };
3349 let emit_self = |s: &BTreeSet<usize>, pairs: &mut BTreeSet<(usize, usize)>| {
3350 let vars: Vec<usize> = s.iter().copied().collect();
3351 for (ai, &vi) in vars.iter().enumerate() {
3352 for &vj in &vars[..=ai] {
3353 let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
3354 pairs.insert((r, c));
3355 }
3356 }
3357 };
3358
3359 for op in ops {
3360 let vset = match op {
3361 TapeOp::Const(_) => BTreeSet::new(),
3362 TapeOp::Var(j) => {
3363 let mut s = BTreeSet::new();
3364 s.insert(*j);
3365 s
3366 }
3367 TapeOp::Add(a, b) | TapeOp::Sub(a, b) => {
3368 var_sets[*a].union(&var_sets[*b]).copied().collect()
3369 }
3370 TapeOp::Neg(a) | TapeOp::Abs(a) => var_sets[*a].clone(),
3371 TapeOp::Mul(a, b) => {
3372 emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
3373 var_sets[*a].union(&var_sets[*b]).copied().collect()
3374 }
3375 TapeOp::Div(a, b) => {
3376 emit_cross(&var_sets[*a], &var_sets[*b], &mut pairs);
3377 emit_self(&var_sets[*b], &mut pairs);
3378 var_sets[*a].union(&var_sets[*b]).copied().collect()
3379 }
3380 TapeOp::Pow(a, b) | TapeOp::Atan2(a, b) => {
3381 let combined: BTreeSet<usize> =
3382 var_sets[*a].union(&var_sets[*b]).copied().collect();
3383 emit_self(&combined, &mut pairs);
3384 combined
3385 }
3386 TapeOp::Sqrt(a)
3387 | TapeOp::Exp(a)
3388 | TapeOp::Log(a)
3389 | TapeOp::Log10(a)
3390 | TapeOp::Sin(a)
3391 | TapeOp::Cos(a)
3392 | TapeOp::Tan(a)
3393 | TapeOp::Atan(a)
3394 | TapeOp::Acos(a)
3395 | TapeOp::Sinh(a)
3396 | TapeOp::Cosh(a)
3397 | TapeOp::Tanh(a)
3398 | TapeOp::Asin(a)
3399 | TapeOp::Acosh(a)
3400 | TapeOp::Asinh(a)
3401 | TapeOp::Atanh(a) => {
3402 emit_self(&var_sets[*a], &mut pairs);
3403 var_sets[*a].clone()
3404 }
3405 TapeOp::Funcall(fc) => {
3406 let args = &fc.args;
3407 let mut combined: BTreeSet<usize> = BTreeSet::new();
3408 for arg in args {
3409 if let TapeFuncallArg::Tape(t) = arg {
3410 for &vv in &var_sets[*t] {
3411 combined.insert(vv);
3412 }
3413 }
3414 }
3415 emit_self(&combined, &mut pairs);
3416 combined
3417 }
3418 TapeOp::Cmp(_, _, _) | TapeOp::And(_, _) | TapeOp::Or(_, _) | TapeOp::Not(_) => {
3419 BTreeSet::new()
3422 }
3423 TapeOp::Select(_, t, e) => {
3424 var_sets[*t].union(&var_sets[*e]).copied().collect()
3427 }
3428 TapeOp::Min(a, b) | TapeOp::Max(a, b) => {
3429 var_sets[*a].union(&var_sets[*b]).copied().collect()
3432 }
3433 };
3434 var_sets.push(vset);
3435 }
3436 pairs
3437}
3438
3439#[cfg(test)]
3440mod tests {
3441 use super::*;
3442
3443 fn cnst(c: f64) -> Expr {
3444 Expr::Const(c)
3445 }
3446 fn var(i: usize) -> Expr {
3447 Expr::Var(i)
3448 }
3449 fn add(a: Expr, b: Expr) -> Expr {
3450 Expr::Binary(BinOp::Add, Box::new(a), Box::new(b))
3451 }
3452 fn mul(a: Expr, b: Expr) -> Expr {
3453 Expr::Binary(BinOp::Mul, Box::new(a), Box::new(b))
3454 }
3455 fn pow(a: Expr, b: Expr) -> Expr {
3456 Expr::Binary(BinOp::Pow, Box::new(a), Box::new(b))
3457 }
3458 fn div(a: Expr, b: Expr) -> Expr {
3459 Expr::Binary(BinOp::Div, Box::new(a), Box::new(b))
3460 }
3461 fn unary(op: UnaryOp, a: Expr) -> Expr {
3462 Expr::Unary(op, Box::new(a))
3463 }
3464 fn cmp(op: CmpOp, a: Expr, b: Expr) -> Expr {
3465 Expr::Compare(op, Box::new(a), Box::new(b))
3466 }
3467 fn cond(c: Expr, t: Expr, e: Expr) -> Expr {
3468 Expr::Cond {
3469 cond: Box::new(c),
3470 then_: Box::new(t),
3471 else_: Box::new(e),
3472 }
3473 }
3474
3475 #[test]
3476 fn polynomial_eval_and_grad() {
3477 let e = add(
3479 mul(cnst(3.0), pow(var(0), cnst(2.0))),
3480 mul(cnst(2.0), var(1)),
3481 );
3482 let t = Tape::build(&e);
3483 assert!((t.eval(&[2.0, 3.0]) - 18.0).abs() < 1e-12);
3484 let mut g = vec![0.0; 2];
3485 t.gradient_seed(&[2.0, 3.0], 1.0, &mut g);
3486 assert!((g[0] - 12.0).abs() < 1e-12);
3488 assert!((g[1] - 2.0).abs() < 1e-12);
3489 }
3490
3491 #[test]
3492 fn cse_shared_body_evaluated_once() {
3493 let body = Rc::new(add(var(0), var(1)));
3495 let e = add(
3496 pow(Expr::Cse(body.clone()), cnst(2.0)),
3497 Expr::Cse(body.clone()),
3498 );
3499 let t = Tape::build(&e);
3500 let n_body_adds = t
3502 .ops
3503 .iter()
3504 .filter(|op| {
3505 matches!(op, TapeOp::Add(a, b) if {
3506 matches!(t.ops[*a], TapeOp::Var(0)) && matches!(t.ops[*b], TapeOp::Var(1))
3507 })
3508 })
3509 .count();
3510 assert_eq!(n_body_adds, 1, "CSE body should be emitted exactly once");
3511
3512 assert!((t.eval(&[1.0, 2.0]) - 12.0).abs() < 1e-12);
3514 let mut g = vec![0.0; 2];
3515 t.gradient_seed(&[1.0, 2.0], 1.0, &mut g);
3516 assert!((g[0] - 7.0).abs() < 1e-12);
3518 assert!((g[1] - 7.0).abs() < 1e-12);
3519 }
3520
3521 fn fd_check(tape: &Tape, x: &[f64], n: usize, tol: f64) {
3522 let vars = tape.variables();
3523 let mut hess_map: HashMap<(usize, usize), usize> = HashMap::new();
3524 let mut pairs = Vec::new();
3525 for (ai, &vi) in vars.iter().enumerate() {
3526 for &vj in &vars[..=ai] {
3527 let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
3528 hess_map.entry((r, c)).or_insert_with(|| {
3529 let p = pairs.len();
3530 pairs.push((r, c));
3531 p
3532 });
3533 }
3534 }
3535 let nnz = pairs.len();
3536 let mut ad = vec![0.0; nnz];
3537 tape.hessian_accumulate(x, 1.0, &hess_map, &mut ad);
3538
3539 let mut fd = vec![0.0; nnz];
3540 let mut xp = x.to_vec();
3541 let mut gp = vec![0.0; n];
3542 let mut gm = vec![0.0; n];
3543 for &j in &vars {
3544 let h = (1e-7_f64).max(x[j].abs() * 1e-7);
3545 xp[j] = x[j] + h;
3546 gp.iter_mut().for_each(|v| *v = 0.0);
3547 tape.gradient_seed(&xp, 1.0, &mut gp);
3548 xp[j] = x[j] - h;
3549 gm.iter_mut().for_each(|v| *v = 0.0);
3550 tape.gradient_seed(&xp, 1.0, &mut gm);
3551 xp[j] = x[j];
3552 for &i in &vars {
3553 if i >= j {
3554 if let Some(&pos) = hess_map.get(&(i, j)) {
3555 fd[pos] = (gp[i] - gm[i]) / (2.0 * h);
3556 }
3557 }
3558 }
3559 }
3560 for (k, &(r, c)) in pairs.iter().enumerate() {
3561 let scale = fd[k].abs().max(1.0);
3562 assert!(
3563 (ad[k] - fd[k]).abs() / scale < tol,
3564 "H[{},{}]: AD={:.6e} FD={:.6e}",
3565 r,
3566 c,
3567 ad[k],
3568 fd[k]
3569 );
3570 }
3571 }
3572
3573 #[test]
3574 fn hessian_quadratic_matches_fd() {
3575 let e = add(
3577 add(
3578 mul(cnst(3.0), pow(var(0), cnst(2.0))),
3579 mul(cnst(2.0), mul(var(0), var(1))),
3580 ),
3581 pow(var(1), cnst(2.0)),
3582 );
3583 let t = Tape::build(&e);
3584 fd_check(&t, &[2.0, 3.0], 2, 1e-5);
3585 }
3586
3587 #[test]
3588 fn hessian_transcendental_matches_fd() {
3589 let e = Expr::Sum(vec![
3591 unary(UnaryOp::Exp, var(0)),
3592 unary(UnaryOp::Sin, var(1)),
3593 unary(UnaryOp::Log, var(0)),
3594 unary(UnaryOp::Sqrt, var(1)),
3595 mul(var(0), var(1)),
3596 ]);
3597 let t = Tape::build(&e);
3598 fd_check(&t, &[1.5, 2.0], 2, 1e-5);
3599 }
3600
3601 #[test]
3602 fn inverse_trig_grad_and_hessian_match_fd() {
3603 let e = Expr::Sum(vec![
3607 unary(UnaryOp::Tan, var(0)),
3608 unary(UnaryOp::Atan, var(1)),
3609 unary(UnaryOp::Acos, var(2)),
3610 mul(var(0), var(1)),
3611 ]);
3612 let t = Tape::build(&e);
3613 let x = [0.5, 1.3, 0.3];
3614
3615 let mut g = vec![0.0; 3];
3619 t.gradient_seed(&x, 1.0, &mut g);
3620 for j in 0..3 {
3621 let h = (1e-7_f64).max(x[j].abs() * 1e-7);
3622 let mut xp = x;
3623 let mut xm = x;
3624 xp[j] += h;
3625 xm[j] -= h;
3626 let fd = (t.eval(&xp) - t.eval(&xm)) / (2.0 * h);
3627 let scale = fd.abs().max(1.0);
3628 assert!(
3629 (g[j] - fd).abs() / scale < 1e-5,
3630 "grad[{j}]: AD={:.6e} FD={:.6e}",
3631 g[j],
3632 fd
3633 );
3634 }
3635
3636 fd_check(&t, &x, 3, 1e-5);
3638 }
3639
3640 fn grad_and_hess_match_fd(e: &Expr, x: &[f64], tol: f64) {
3643 let n = x.len();
3644 let t = Tape::build(e);
3645 let mut g = vec![0.0; n];
3646 t.gradient_seed(x, 1.0, &mut g);
3647 for j in 0..n {
3648 let h = (1e-7_f64).max(x[j].abs() * 1e-7);
3649 let mut xp = x.to_vec();
3650 let mut xm = x.to_vec();
3651 xp[j] += h;
3652 xm[j] -= h;
3653 let fd = (t.eval(&xp) - t.eval(&xm)) / (2.0 * h);
3654 let scale = fd.abs().max(1.0);
3655 assert!(
3656 (g[j] - fd).abs() / scale < tol,
3657 "grad[{j}]: AD={:.6e} FD={:.6e}",
3658 g[j],
3659 fd
3660 );
3661 }
3662 fd_check(&t, x, n, tol);
3663 }
3664
3665 #[test]
3666 fn hyperbolic_grad_and_hessian_match_fd() {
3667 let e = Expr::Sum(vec![
3670 unary(UnaryOp::Sinh, var(0)),
3671 unary(UnaryOp::Cosh, var(1)),
3672 unary(UnaryOp::Tanh, var(2)),
3673 unary(UnaryOp::Asinh, var(3)),
3674 mul(var(0), var(1)),
3675 mul(var(2), var(3)),
3676 ]);
3677 grad_and_hess_match_fd(&e, &[0.5, 0.7, 0.3, 1.1], 1e-5);
3678 }
3679
3680 #[test]
3681 fn restricted_inverse_grad_and_hessian_match_fd() {
3682 let e = Expr::Sum(vec![
3686 unary(UnaryOp::Asin, var(0)),
3687 unary(UnaryOp::Acosh, var(1)),
3688 unary(UnaryOp::Atanh, var(2)),
3689 mul(var(0), var(2)),
3690 ]);
3691 grad_and_hess_match_fd(&e, &[0.4, 1.8, 0.3], 1e-5);
3692 }
3693
3694 #[test]
3695 fn atan2_grad_and_hessian_match_fd() {
3696 let atan2 = |a: Expr, b: Expr| Expr::Binary(BinOp::Atan2, Box::new(a), Box::new(b));
3698 let e = Expr::Sum(vec![atan2(var(0), var(1)), mul(var(0), var(1))]);
3699 grad_and_hess_match_fd(&e, &[1.2, 0.7], 1e-5);
3700 }
3701
3702 #[test]
3703 fn minmax_grad_and_hessian_match_fd() {
3704 let e = Expr::Sum(vec![
3711 Expr::MinList(vec![var(0), var(1), var(2)]),
3712 Expr::MaxList(vec![var(1), var(2)]),
3713 mul(var(0), var(2)),
3714 ]);
3715 grad_and_hess_match_fd(&e, &[0.5, 3.0, 2.0], 1e-5);
3716 }
3717
3718 #[test]
3719 fn minmax_value_and_active_operand() {
3720 let e = Expr::Sum(vec![
3723 Expr::MinList(vec![var(0), var(1)]),
3724 Expr::MaxList(vec![var(0), var(1)]),
3725 ]);
3726 let t = Tape::build(&e);
3727 let x = [1.3, -0.4];
3729 assert!((t.eval(&x) - (x[0] + x[1])).abs() < 1e-12);
3730 let mut g = vec![0.0; 2];
3731 t.gradient_seed(&x, 1.0, &mut g);
3732 assert!((g[0] - 1.0).abs() < 1e-12, "g0={}", g[0]);
3735 assert!((g[1] - 1.0).abs() < 1e-12, "g1={}", g[1]);
3736 }
3737
3738 #[test]
3739 fn hessian_division_matches_fd() {
3740 let e = add(div(var(0), var(1)), unary(UnaryOp::Cos, var(0)));
3742 let t = Tape::build(&e);
3743 fd_check(&t, &[0.5, 1.2], 2, 1e-5);
3744 }
3745
3746 #[test]
3747 fn conditional_value_grad_hessian_active_branch() {
3748 let e = cond(
3752 cmp(CmpOp::Ge, var(0), cnst(1.0)),
3753 mul(var(0), var(1)),
3754 pow(var(1), cnst(2.0)),
3755 );
3756 let t = Tape::build(&e);
3757
3758 let x = [2.0, 5.0];
3760 assert!((t.eval(&x) - 10.0).abs() < 1e-12);
3761 let mut g = vec![0.0; 2];
3762 t.gradient_seed(&x, 1.0, &mut g);
3763 assert!((g[0] - 5.0).abs() < 1e-10);
3765 assert!((g[1] - 2.0).abs() < 1e-10);
3766 fd_check(&t, &x, 2, 1e-5);
3768
3769 let x2 = [0.0, 5.0];
3771 assert!((t.eval(&x2) - 25.0).abs() < 1e-12);
3772 let mut g2 = vec![0.0; 2];
3773 t.gradient_seed(&x2, 1.0, &mut g2);
3774 assert!(g2[0].abs() < 1e-10);
3775 assert!((g2[1] - 10.0).abs() < 1e-10);
3776 fd_check(&t, &x2, 2, 1e-5);
3777 }
3778
3779 #[test]
3780 fn comparison_and_logical_have_zero_derivative() {
3781 let lt = cmp(CmpOp::Lt, var(0), var(1));
3785 let and = Expr::And(
3786 Box::new(cmp(CmpOp::Gt, var(0), cnst(0.0))),
3787 Box::new(cmp(CmpOp::Gt, var(1), cnst(0.0))),
3788 );
3789 let notc = Expr::Not(Box::new(cmp(CmpOp::Eq, var(0), var(1))));
3790 let e = add(add(lt, and), notc);
3791 let t = Tape::build(&e);
3792
3793 let x = [1.0, 2.0];
3794 assert!((t.eval(&x) - 3.0).abs() < 1e-12);
3796 let mut g = vec![0.0; 2];
3797 t.gradient_seed(&x, 1.0, &mut g);
3798 assert!(g[0].abs() < 1e-12, "d/dx0 should be 0, got {}", g[0]);
3799 assert!(g[1].abs() < 1e-12, "d/dx1 should be 0, got {}", g[1]);
3800 }
3801
3802 #[test]
3803 fn logical_or_value() {
3804 let e = Expr::Or(
3806 Box::new(cmp(CmpOp::Gt, var(0), cnst(0.0))),
3807 Box::new(cmp(CmpOp::Gt, var(1), cnst(0.0))),
3808 );
3809 let t = Tape::build(&e);
3810 assert!((t.eval(&[-1.0, 3.0]) - 1.0).abs() < 1e-12);
3811 assert!((t.eval(&[-1.0, -3.0]) - 0.0).abs() < 1e-12);
3812 }
3813
3814 fn directional_matches_accumulate(tape: &Tape, x: &[f64], n: usize) {
3819 let vars = tape.variables();
3820 let mut hess_map: HashMap<(usize, usize), usize> = HashMap::new();
3821 let mut pairs = Vec::new();
3822 for (ai, &vi) in vars.iter().enumerate() {
3823 for &vj in &vars[..=ai] {
3824 let (r, c) = if vi >= vj { (vi, vj) } else { (vj, vi) };
3825 hess_map.entry((r, c)).or_insert_with(|| {
3826 let p = pairs.len();
3827 pairs.push((r, c));
3828 p
3829 });
3830 }
3831 }
3832 let nnz = pairs.len();
3833 let mut ad = vec![0.0; nnz];
3834 tape.hessian_accumulate(x, 1.0, &hess_map, &mut ad);
3835
3836 let nops = tape.ops.len();
3837 let mut vals = vec![0.0; nops];
3838 tape.forward_into(x, &mut vals);
3839 let mut dot = vec![0.0; nops];
3840 let mut adj = vec![0.0; nops];
3841 let mut adj_dot = vec![0.0; nops];
3842
3843 for &j in &vars {
3844 let mut seed = vec![0.0; n];
3845 seed[j] = 1.0;
3846 let mut col = vec![0.0; n];
3847 tape.hessian_directional(
3848 &vals,
3849 &seed,
3850 1.0,
3851 &mut col,
3852 &mut dot,
3853 &mut adj,
3854 &mut adj_dot,
3855 );
3856 for &i in &vars {
3857 let (r, c) = if i >= j { (i, j) } else { (j, i) };
3858 let expect = ad[hess_map[&(r, c)]];
3859 assert!(
3860 (col[i] - expect).abs() < 1e-10,
3861 "directional H[{i},{j}] = {} vs accumulate {}",
3862 col[i],
3863 expect
3864 );
3865 }
3866 }
3867 }
3868
3869 #[test]
3870 fn directional_quadratic_matches_accumulate() {
3871 let e = add(
3873 add(
3874 mul(cnst(3.0), pow(var(0), cnst(2.0))),
3875 mul(mul(cnst(2.0), var(0)), var(1)),
3876 ),
3877 pow(var(1), cnst(2.0)),
3878 );
3879 let t = Tape::build(&e);
3880 directional_matches_accumulate(&t, &[0.5, -0.3], 2);
3881 }
3882
3883 #[test]
3884 fn directional_transcendental_matches_accumulate() {
3885 let e = Expr::Sum(vec![
3886 unary(UnaryOp::Exp, var(0)),
3887 unary(UnaryOp::Sin, var(1)),
3888 unary(UnaryOp::Log, var(0)),
3889 unary(UnaryOp::Sqrt, var(1)),
3890 mul(var(0), var(1)),
3891 ]);
3892 let t = Tape::build(&e);
3893 directional_matches_accumulate(&t, &[1.5, 2.0], 2);
3894 }
3895
3896 #[test]
3897 fn directional_with_division_matches_accumulate() {
3898 let e = add(div(var(0), var(1)), unary(UnaryOp::Cos, var(0)));
3899 let t = Tape::build(&e);
3900 directional_matches_accumulate(&t, &[0.5, 1.2], 2);
3901 }
3902
3903 #[test]
3904 fn hessian_sparsity_separable() {
3905 let e = add(unary(UnaryOp::Sin, var(0)), mul(var(1), var(2)));
3907 let t = Tape::build(&e);
3908 let s = t.hessian_sparsity();
3909 assert!(s.contains(&(0, 0)));
3910 assert!(s.contains(&(2, 1)));
3911 assert!(!s.contains(&(1, 0)));
3912 assert!(!s.contains(&(2, 0)));
3913 }
3914
3915 fn count_op<F: Fn(&TapeOp) -> bool>(t: &Tape, pred: F) -> usize {
3916 t.ops.iter().filter(|o| pred(o)).count()
3917 }
3918
3919 #[test]
3920 fn pow_zero_const_folds_to_one() {
3921 let e = pow(var(0), cnst(0.0));
3923 let t = Tape::build(&e);
3924 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
3925 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Var(_))), 0);
3926 assert!((t.eval(&[7.0]) - 1.0).abs() < 1e-12);
3927 }
3928
3929 #[test]
3930 fn pow_one_passes_through() {
3931 let e = pow(var(0), cnst(1.0));
3933 let t = Tape::build(&e);
3934 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
3935 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Const(_))), 0);
3936 assert!((t.eval(&[3.5]) - 3.5).abs() < 1e-12);
3937 }
3938
3939 #[test]
3940 fn pow_half_lowers_to_sqrt() {
3941 let e = pow(var(0), cnst(0.5));
3942 let t = Tape::build(&e);
3943 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
3944 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Sqrt(_))), 1);
3945 assert!((t.eval(&[16.0]) - 4.0).abs() < 1e-12);
3946 }
3947
3948 #[test]
3949 fn pow_two_lowers_to_single_mul() {
3950 let e = pow(var(0), cnst(2.0));
3951 let t = Tape::build(&e);
3952 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
3953 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 1);
3954 assert!((t.eval(&[3.0]) - 9.0).abs() < 1e-12);
3955 }
3956
3957 #[test]
3958 fn pow_three_lowers_to_two_muls() {
3959 let e = pow(var(0), cnst(3.0));
3960 let t = Tape::build(&e);
3961 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
3962 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 2);
3963 assert!((t.eval(&[2.0]) - 8.0).abs() < 1e-12);
3964 }
3965
3966 #[test]
3967 fn pow_eight_lowers_to_three_muls() {
3968 let e = pow(var(0), cnst(8.0));
3970 let t = Tape::build(&e);
3971 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
3972 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 3);
3973 assert!((t.eval(&[2.0]) - 256.0).abs() < 1e-12);
3974 }
3975
3976 #[test]
3977 fn pow_negative_two_lowers_to_div() {
3978 let e = pow(var(0), cnst(-2.0));
3980 let t = Tape::build(&e);
3981 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
3982 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Div(..))), 1);
3983 assert!((t.eval(&[4.0]) - (1.0 / 16.0)).abs() < 1e-12);
3984 }
3985
3986 #[test]
3987 fn pow_large_const_stays_generic() {
3988 let e = pow(var(0), cnst(9.0));
3990 let t = Tape::build(&e);
3991 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 1);
3992 }
3993
3994 #[test]
3995 fn pow_non_integer_const_stays_generic() {
3996 let e = pow(var(0), cnst(1.5));
3998 let t = Tape::build(&e);
3999 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 1);
4000 }
4001
4002 #[test]
4003 fn pow_const_through_cse_const() {
4004 let two = Rc::new(cnst(2.0));
4006 let e = Expr::Binary(BinOp::Pow, Box::new(var(0)), Box::new(Expr::Cse(two)));
4007 let t = Tape::build(&e);
4008 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Pow(..))), 0);
4009 assert_eq!(count_op(&t, |o| matches!(o, TapeOp::Mul(..))), 1);
4010 }
4011
4012 #[test]
4013 fn hessian_pow_three_matches_fd() {
4014 let e = add(mul(cnst(5.0), pow(var(0), cnst(3.0))), mul(var(0), var(1)));
4016 let t = Tape::build(&e);
4017 fd_check(&t, &[1.7, 0.8], 2, 1e-5);
4018 }
4019
4020 #[test]
4021 fn hessian_pow_negative_matches_fd() {
4022 let e = add(pow(var(0), cnst(-2.0)), pow(var(1), cnst(2.0)));
4024 let t = Tape::build(&e);
4025 fd_check(&t, &[1.3, 2.4], 2, 1e-5);
4026 }
4027
4028 #[test]
4029 fn hessian_pow_half_matches_fd() {
4030 let e = add(pow(var(0), cnst(0.5)), mul(var(0), var(1)));
4032 let t = Tape::build(&e);
4033 fd_check(&t, &[2.5, 1.1], 2, 1e-5);
4034 }
4035
4036 #[test]
4037 fn hessian_sparsity_through_cse() {
4038 let body = Rc::new(add(var(0), var(1)));
4041 let e = add(
4042 pow(Expr::Cse(body.clone()), cnst(2.0)),
4043 Expr::Cse(body.clone()),
4044 );
4045 let t = Tape::build(&e);
4046 let s = t.hessian_sparsity();
4047 assert!(s.contains(&(0, 0)));
4048 assert!(s.contains(&(1, 0)));
4049 assert!(s.contains(&(1, 1)));
4050 assert_eq!(s.len(), 3);
4051 }
4052}