1use crate::common::IntegrateFloat;
7use crate::error::{IntegrateError, IntegrateResult};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
9use std::cell::RefCell;
10use std::collections::HashMap;
11use std::rc::Rc;
12
13#[derive(Debug, Clone)]
15pub enum Operation<F: IntegrateFloat> {
16 Variable(usize),
18 Constant(F),
20 Add(usize, usize),
22 Sub(usize, usize),
24 Mul(usize, usize),
26 Div(usize, usize),
28 Neg(usize),
30 Pow(usize, F),
32 PowGeneral(usize, usize),
34 Sin(usize),
36 Cos(usize),
38 Tan(usize),
40 Exp(usize),
42 Ln(usize),
44 Sqrt(usize),
46 Tanh(usize),
48 Sinh(usize),
50 Cosh(usize),
52 Atan2(usize, usize),
54 Abs(usize),
56 Max(usize, usize),
58 Min(usize, usize),
60}
61
62pub struct TapeNode<F: IntegrateFloat> {
64 pub value: F,
66 pub operation: Operation<F>,
68 pub gradient: RefCell<F>,
70}
71
72impl<F: IntegrateFloat> TapeNode<F> {
73 pub fn new(value: F, operation: Operation<F>) -> Self {
75 TapeNode {
76 value,
77 operation,
78 gradient: RefCell::new(F::zero()),
79 }
80 }
81}
82
83pub struct Tape<F: IntegrateFloat> {
85 nodes: Vec<Rc<TapeNode<F>>>,
87 var_map: HashMap<usize, usize>,
89}
90
91impl<F: IntegrateFloat> Tape<F> {
92 pub fn new() -> Self {
94 Tape {
95 nodes: Vec::new(),
96 var_map: HashMap::new(),
97 }
98 }
99
100 pub fn variable(&mut self, idx: usize, value: F) -> usize {
102 let nodeidx = self.nodes.len();
103 self.nodes
104 .push(Rc::new(TapeNode::new(value, Operation::Variable(idx))));
105 self.var_map.insert(idx, nodeidx);
106 nodeidx
107 }
108
109 pub fn constant(&mut self, value: F) -> usize {
111 let nodeidx = self.nodes.len();
112 self.nodes
113 .push(Rc::new(TapeNode::new(value, Operation::Constant(value))));
114 nodeidx
115 }
116
117 pub fn add(&mut self, a: usize, b: usize) -> usize {
119 let value = self.nodes[a].value + self.nodes[b].value;
120 let nodeidx = self.nodes.len();
121 self.nodes
122 .push(Rc::new(TapeNode::new(value, Operation::Add(a, b))));
123 nodeidx
124 }
125
126 pub fn sub(&mut self, a: usize, b: usize) -> usize {
128 let value = self.nodes[a].value - self.nodes[b].value;
129 let nodeidx = self.nodes.len();
130 self.nodes
131 .push(Rc::new(TapeNode::new(value, Operation::Sub(a, b))));
132 nodeidx
133 }
134
135 pub fn mul(&mut self, a: usize, b: usize) -> usize {
137 let value = self.nodes[a].value * self.nodes[b].value;
138 let nodeidx = self.nodes.len();
139 self.nodes
140 .push(Rc::new(TapeNode::new(value, Operation::Mul(a, b))));
141 nodeidx
142 }
143
144 pub fn div(&mut self, a: usize, b: usize) -> usize {
146 let value = self.nodes[a].value / self.nodes[b].value;
147 let nodeidx = self.nodes.len();
148 self.nodes
149 .push(Rc::new(TapeNode::new(value, Operation::Div(a, b))));
150 nodeidx
151 }
152
153 pub fn neg(&mut self, a: usize) -> usize {
155 let value = -self.nodes[a].value;
156 let nodeidx = self.nodes.len();
157 self.nodes
158 .push(Rc::new(TapeNode::new(value, Operation::Neg(a))));
159 nodeidx
160 }
161
162 pub fn pow(&mut self, a: usize, n: F) -> usize {
164 let value = self.nodes[a].value.powf(n);
165 let nodeidx = self.nodes.len();
166 self.nodes
167 .push(Rc::new(TapeNode::new(value, Operation::Pow(a, n))));
168 nodeidx
169 }
170
171 pub fn sin(&mut self, a: usize) -> usize {
173 let value = self.nodes[a].value.sin();
174 let nodeidx = self.nodes.len();
175 self.nodes
176 .push(Rc::new(TapeNode::new(value, Operation::Sin(a))));
177 nodeidx
178 }
179
180 pub fn cos(&mut self, a: usize) -> usize {
182 let value = self.nodes[a].value.cos();
183 let nodeidx = self.nodes.len();
184 self.nodes
185 .push(Rc::new(TapeNode::new(value, Operation::Cos(a))));
186 nodeidx
187 }
188
189 pub fn exp(&mut self, a: usize) -> usize {
191 let value = self.nodes[a].value.exp();
192 let nodeidx = self.nodes.len();
193 self.nodes
194 .push(Rc::new(TapeNode::new(value, Operation::Exp(a))));
195 nodeidx
196 }
197
198 pub fn ln(&mut self, a: usize) -> usize {
200 let value = self.nodes[a].value.ln();
201 let nodeidx = self.nodes.len();
202 self.nodes
203 .push(Rc::new(TapeNode::new(value, Operation::Ln(a))));
204 nodeidx
205 }
206
207 pub fn sqrt(&mut self, a: usize) -> usize {
209 let value = self.nodes[a].value.sqrt();
210 let nodeidx = self.nodes.len();
211 self.nodes
212 .push(Rc::new(TapeNode::new(value, Operation::Sqrt(a))));
213 nodeidx
214 }
215
216 pub fn pow_general(&mut self, a: usize, b: usize) -> usize {
218 let value = self.nodes[a].value.powf(self.nodes[b].value);
219 let nodeidx = self.nodes.len();
220 self.nodes
221 .push(Rc::new(TapeNode::new(value, Operation::PowGeneral(a, b))));
222 nodeidx
223 }
224
225 pub fn tan(&mut self, a: usize) -> usize {
227 let value = self.nodes[a].value.tan();
228 let nodeidx = self.nodes.len();
229 self.nodes
230 .push(Rc::new(TapeNode::new(value, Operation::Tan(a))));
231 nodeidx
232 }
233
234 pub fn tanh(&mut self, a: usize) -> usize {
236 let value = self.nodes[a].value.tanh();
237 let nodeidx = self.nodes.len();
238 self.nodes
239 .push(Rc::new(TapeNode::new(value, Operation::Tanh(a))));
240 nodeidx
241 }
242
243 pub fn sinh(&mut self, a: usize) -> usize {
245 let value = self.nodes[a].value.sinh();
246 let nodeidx = self.nodes.len();
247 self.nodes
248 .push(Rc::new(TapeNode::new(value, Operation::Sinh(a))));
249 nodeidx
250 }
251
252 pub fn cosh(&mut self, a: usize) -> usize {
254 let value = self.nodes[a].value.cosh();
255 let nodeidx = self.nodes.len();
256 self.nodes
257 .push(Rc::new(TapeNode::new(value, Operation::Cosh(a))));
258 nodeidx
259 }
260
261 pub fn atan2(&mut self, y: usize, x: usize) -> usize {
263 let value = self.nodes[y].value.atan2(self.nodes[x].value);
264 let nodeidx = self.nodes.len();
265 self.nodes
266 .push(Rc::new(TapeNode::new(value, Operation::Atan2(y, x))));
267 nodeidx
268 }
269
270 pub fn abs(&mut self, a: usize) -> usize {
272 let value = self.nodes[a].value.abs();
273 let nodeidx = self.nodes.len();
274 self.nodes
275 .push(Rc::new(TapeNode::new(value, Operation::Abs(a))));
276 nodeidx
277 }
278
279 pub fn max(&mut self, a: usize, b: usize) -> usize {
281 let value = self.nodes[a].value.max(self.nodes[b].value);
282 let nodeidx = self.nodes.len();
283 self.nodes
284 .push(Rc::new(TapeNode::new(value, Operation::Max(a, b))));
285 nodeidx
286 }
287
288 pub fn min(&mut self, a: usize, b: usize) -> usize {
290 let value = self.nodes[a].value.min(self.nodes[b].value);
291 let nodeidx = self.nodes.len();
292 self.nodes
293 .push(Rc::new(TapeNode::new(value, Operation::Min(a, b))));
294 nodeidx
295 }
296
297 pub fn value(&self, idx: usize) -> F {
299 self.nodes[idx].value
300 }
301
302 pub fn backward(&mut self, outputidx: usize, nvars: usize) -> Array1<F> {
304 for node in &self.nodes {
306 *node.gradient.borrow_mut() = F::zero();
307 }
308
309 *self.nodes[outputidx].gradient.borrow_mut() = F::one();
311
312 for i in (0..=outputidx).rev() {
314 let node = &self.nodes[i];
315 let grad = *node.gradient.borrow();
316
317 if grad.abs() < F::epsilon() {
318 continue;
319 }
320
321 match &node.operation {
322 Operation::Variable(_) | Operation::Constant(_) => {}
323 Operation::Add(a, b) => {
324 *self.nodes[*a].gradient.borrow_mut() += grad;
325 *self.nodes[*b].gradient.borrow_mut() += grad;
326 }
327 Operation::Sub(a, b) => {
328 *self.nodes[*a].gradient.borrow_mut() += grad;
329 *self.nodes[*b].gradient.borrow_mut() -= grad;
330 }
331 Operation::Mul(a, b) => {
332 *self.nodes[*a].gradient.borrow_mut() += grad * self.nodes[*b].value;
333 *self.nodes[*b].gradient.borrow_mut() += grad * self.nodes[*a].value;
334 }
335 Operation::Div(a, b) => {
336 let b_val = self.nodes[*b].value;
337 *self.nodes[*a].gradient.borrow_mut() += grad / b_val;
338 *self.nodes[*b].gradient.borrow_mut() -=
339 grad * self.nodes[*a].value / (b_val * b_val);
340 }
341 Operation::Neg(a) => {
342 *self.nodes[*a].gradient.borrow_mut() -= grad;
343 }
344 Operation::Pow(a, n) => {
345 *self.nodes[*a].gradient.borrow_mut() +=
346 grad * *n * self.nodes[*a].value.powf(*n - F::one());
347 }
348 Operation::Sin(a) => {
349 *self.nodes[*a].gradient.borrow_mut() += grad * self.nodes[*a].value.cos();
350 }
351 Operation::Cos(a) => {
352 *self.nodes[*a].gradient.borrow_mut() -= grad * self.nodes[*a].value.sin();
353 }
354 Operation::Exp(a) => {
355 *self.nodes[*a].gradient.borrow_mut() += grad * node.value;
356 }
357 Operation::Ln(a) => {
358 *self.nodes[*a].gradient.borrow_mut() += grad / self.nodes[*a].value;
359 }
360 Operation::Sqrt(a) => {
361 *self.nodes[*a].gradient.borrow_mut() +=
362 grad / (F::from(2.0).unwrap() * node.value);
363 }
364 Operation::PowGeneral(a, b) => {
365 let a_val = self.nodes[*a].value;
368 let b_val = self.nodes[*b].value;
369 *self.nodes[*a].gradient.borrow_mut() +=
370 grad * b_val * a_val.powf(b_val - F::one());
371 *self.nodes[*b].gradient.borrow_mut() += grad * node.value * a_val.ln();
372 }
373 Operation::Tan(a) => {
374 let cos_val = self.nodes[*a].value.cos();
376 *self.nodes[*a].gradient.borrow_mut() += grad / (cos_val * cos_val);
377 }
378 Operation::Tanh(a) => {
379 let tanh_val = node.value;
381 *self.nodes[*a].gradient.borrow_mut() +=
382 grad * (F::one() - tanh_val * tanh_val);
383 }
384 Operation::Sinh(a) => {
385 *self.nodes[*a].gradient.borrow_mut() += grad * self.nodes[*a].value.cosh();
387 }
388 Operation::Cosh(a) => {
389 *self.nodes[*a].gradient.borrow_mut() += grad * self.nodes[*a].value.sinh();
391 }
392 Operation::Atan2(y, x) => {
393 let x_val = self.nodes[*x].value;
396 let y_val = self.nodes[*y].value;
397 let denom = x_val * x_val + y_val * y_val;
398 *self.nodes[*y].gradient.borrow_mut() += grad * x_val / denom;
399 *self.nodes[*x].gradient.borrow_mut() -= grad * y_val / denom;
400 }
401 Operation::Abs(a) => {
402 let sign = if self.nodes[*a].value >= F::zero() {
404 F::one()
405 } else {
406 -F::one()
407 };
408 *self.nodes[*a].gradient.borrow_mut() += grad * sign;
409 }
410 Operation::Max(a, b) => {
411 if self.nodes[*a].value >= self.nodes[*b].value {
413 *self.nodes[*a].gradient.borrow_mut() += grad;
414 } else {
415 *self.nodes[*b].gradient.borrow_mut() += grad;
416 }
417 }
418 Operation::Min(a, b) => {
419 if self.nodes[*a].value <= self.nodes[*b].value {
421 *self.nodes[*a].gradient.borrow_mut() += grad;
422 } else {
423 *self.nodes[*b].gradient.borrow_mut() += grad;
424 }
425 }
426 }
427 }
428
429 let mut gradients = Array1::zeros(nvars);
431 for (varidx, &nodeidx) in &self.var_map {
432 if *varidx < nvars {
433 gradients[*varidx] = *self.nodes[nodeidx].gradient.borrow();
434 }
435 }
436
437 gradients
438 }
439}
440
441impl<F: IntegrateFloat> Default for Tape<F> {
442 fn default() -> Self {
443 Self::new()
444 }
445}
446
/// Checkpointing policy selectable on [`ReverseAD`].
///
/// NOTE(review): this file only stores the chosen strategy; nothing visible
/// here reads it, so the descriptions below are inferred from the variant
/// names and should be confirmed against whatever code consumes them.
#[derive(Debug, Clone, Copy)]
pub enum CheckpointStrategy {
    /// No checkpointing.
    None,
    /// Presumably: checkpoint at a fixed interval given by the payload.
    FixedInterval(usize),
    /// Presumably: checkpoint on a logarithmic schedule.
    Logarithmic,
    /// Presumably: checkpoint so that at most `max_nodes` nodes are kept.
    MemoryBased { max_nodes: usize },
}
459
460pub struct ReverseAD<F: IntegrateFloat> {
462 nvars: usize,
464 checkpoint_strategy: CheckpointStrategy,
466 _phantom: std::marker::PhantomData<F>,
467}
468
469impl<F: IntegrateFloat> ReverseAD<F> {
470 pub fn new(nvars: usize) -> Self {
472 ReverseAD {
473 nvars,
474 checkpoint_strategy: CheckpointStrategy::None,
475 _phantom: std::marker::PhantomData,
476 }
477 }
478
479 pub fn with_checkpoint_strategy(mut self, strategy: CheckpointStrategy) -> Self {
481 self.checkpoint_strategy = strategy;
482 self
483 }
484
485 pub fn gradient<Func>(&mut self, f: Func, x: ArrayView1<F>) -> IntegrateResult<Array1<F>>
487 where
488 Func: Fn(&mut Tape<F>, &[usize]) -> usize,
489 {
490 if x.len() != self.nvars {
491 return Err(IntegrateError::DimensionMismatch(format!(
492 "Expected {} variables, got {}",
493 self.nvars,
494 x.len()
495 )));
496 }
497
498 let mut tape = Tape::new();
499 let mut var_indices = Vec::new();
500
501 for (i, &val) in x.iter().enumerate() {
503 let idx = tape.variable(i, val);
504 var_indices.push(idx);
505 }
506
507 let outputidx = f(&mut tape, &var_indices);
509
510 Ok(tape.backward(outputidx, self.nvars))
512 }
513
514 pub fn jacobian<Func>(&mut self, f: Func, x: ArrayView1<F>) -> IntegrateResult<Array2<F>>
516 where
517 Func: Fn(&mut Tape<F>, &[usize]) -> Vec<usize>,
518 {
519 if x.len() != self.nvars {
520 return Err(IntegrateError::DimensionMismatch(format!(
521 "Expected {} variables, got {}",
522 self.nvars,
523 x.len()
524 )));
525 }
526
527 let mut tape = Tape::new();
528 let mut var_indices = Vec::new();
529
530 for (i, &val) in x.iter().enumerate() {
532 let idx = tape.variable(i, val);
533 var_indices.push(idx);
534 }
535
536 let output_indices = f(&mut tape, &var_indices);
538 let m = output_indices.len();
539
540 let mut jacobian = Array2::zeros((m, self.nvars));
541
542 for (i, &outputidx) in output_indices.iter().enumerate() {
544 let grad = tape.backward(outputidx, self.nvars);
545 jacobian.row_mut(i).assign(&grad);
546 }
547
548 Ok(jacobian)
549 }
550
551 pub fn hessian<Func>(&mut self, f: Func, x: ArrayView1<F>) -> IntegrateResult<Array2<F>>
553 where
554 Func: Fn(&mut Tape<F>, &[usize]) -> usize + Clone,
555 {
556 if x.len() != self.nvars {
557 return Err(IntegrateError::DimensionMismatch(format!(
558 "Expected {} variables, got {}",
559 self.nvars,
560 x.len()
561 )));
562 }
563
564 let mut hessian = Array2::zeros((self.nvars, self.nvars));
565 let eps = F::from(1e-8).unwrap();
566
567 for j in 0..self.nvars {
569 let mut x_plus = x.to_owned();
571 x_plus[j] += eps;
572
573 let grad_plus = self.gradient(f.clone(), x_plus.view())?;
574 let grad_base = self.gradient(f.clone(), x)?;
575
576 for i in 0..self.nvars {
578 hessian[[i, j]] = (grad_plus[i] - grad_base[i]) / eps;
579 }
580 }
581
582 for i in 0..self.nvars {
584 for j in (i + 1)..self.nvars {
585 let avg = (hessian[[i, j]] + hessian[[j, i]]) / F::from(2.0).unwrap();
586 hessian[[i, j]] = avg;
587 hessian[[j, i]] = avg;
588 }
589 }
590
591 Ok(hessian)
592 }
593
594 pub fn batch_gradient<Func>(
596 &mut self,
597 f: Func,
598 x_batch: &[Array1<F>],
599 ) -> IntegrateResult<Vec<Array1<F>>>
600 where
601 Func: Fn(&mut Tape<F>, &[usize]) -> usize + Clone,
602 {
603 let mut gradients = Vec::with_capacity(x_batch.len());
604
605 for x in x_batch {
606 gradients.push(self.gradient(f.clone(), x.view())?);
607 }
608
609 Ok(gradients)
610 }
611
612 pub fn jvp<Func>(
614 &mut self,
615 f: Func,
616 x: ArrayView1<F>,
617 v: ArrayView1<F>,
618 ) -> IntegrateResult<Array1<F>>
619 where
620 Func: Fn(&mut Tape<F>, &[usize]) -> Vec<usize>,
621 {
622 if x.len() != self.nvars || v.len() != self.nvars {
623 return Err(IntegrateError::DimensionMismatch(format!(
624 "Expected {} variables for both x and v",
625 self.nvars
626 )));
627 }
628
629 let eps = F::from(1e-8).unwrap();
631 let x_perturbed = &x + &(v.to_owned() * eps);
632
633 let mut tape = Tape::new();
634 let mut var_indices = Vec::new();
635 let mut var_indices_perturbed = Vec::new();
636
637 for (i, &val) in x.iter().enumerate() {
639 let idx = tape.variable(i, val);
640 var_indices.push(idx);
641 }
642
643 let output_base = f(&mut tape, &var_indices);
644
645 tape = Tape::new();
646 for (i, &val) in x_perturbed.iter().enumerate() {
647 let idx = tape.variable(i, val);
648 var_indices_perturbed.push(idx);
649 }
650
651 let output_perturbed = f(&mut tape, &var_indices_perturbed);
652
653 let mut jvp = Array1::zeros(output_base.len());
655 for (i, (&idx_base, &idx_pert)) in
656 output_base.iter().zip(output_perturbed.iter()).enumerate()
657 {
658 jvp[i] = (tape.value(idx_pert) - tape.value(idx_base)) / eps;
659 }
660
661 Ok(jvp)
662 }
663
664 pub fn vjp<Func>(
666 &mut self,
667 f: Func,
668 x: ArrayView1<F>,
669 v: ArrayView1<F>,
670 ) -> IntegrateResult<Array1<F>>
671 where
672 Func: Fn(&mut Tape<F>, &[usize]) -> Vec<usize>,
673 {
674 if x.len() != self.nvars {
675 return Err(IntegrateError::DimensionMismatch(format!(
676 "Expected {} variables",
677 self.nvars
678 )));
679 }
680
681 let mut tape = Tape::new();
682 let mut var_indices = Vec::new();
683
684 for (i, &val) in x.iter().enumerate() {
686 let idx = tape.variable(i, val);
687 var_indices.push(idx);
688 }
689
690 let output_indices = f(&mut tape, &var_indices);
692
693 if v.len() != output_indices.len() {
694 return Err(IntegrateError::DimensionMismatch(format!(
695 "Vector v length {} doesn't match output dimension {}",
696 v.len(),
697 output_indices.len()
698 )));
699 }
700
701 let mut weighted_sum = tape.constant(F::zero());
703 for (i, &outputidx) in output_indices.iter().enumerate() {
704 let v_i = tape.constant(v[i]);
705 let term = tape.mul(v_i, outputidx);
706 weighted_sum = tape.add(weighted_sum, term);
707 }
708
709 Ok(tape.backward(weighted_sum, self.nvars))
711 }
712}
713
714#[allow(dead_code)]
716pub fn reverse_gradient<F, Func>(f: Func, x: ArrayView1<F>) -> IntegrateResult<Array1<F>>
717where
718 F: IntegrateFloat,
719 Func: Fn(&mut Tape<F>, &[usize]) -> usize,
720{
721 let mut ad = ReverseAD::new(x.len());
722 ad.gradient(f, x)
723}
724
725#[allow(dead_code)]
727pub fn reverse_jacobian<F, Func>(f: Func, x: ArrayView1<F>) -> IntegrateResult<Array2<F>>
728where
729 F: IntegrateFloat,
730 Func: Fn(&mut Tape<F>, &[usize]) -> Vec<usize>,
731{
732 let mut ad = ReverseAD::new(x.len());
733 ad.jacobian(f, x)
734}
735
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by every comparison in this module.
    const TOL: f64 = 1e-10;

    /// Fails the test unless `actual` is strictly within `TOL` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn test_reverse_gradient() {
        // f(x, y) = x^2 + y^2, so the gradient is (2x, 2y).
        let f = |tape: &mut Tape<f64>, vars: &[usize]| {
            let (x, y) = (vars[0], vars[1]);
            let x_sq = tape.mul(x, x);
            let y_sq = tape.mul(y, y);
            tape.add(x_sq, y_sq)
        };

        let point = Array1::from_vec(vec![3.0, 4.0]);
        let grad = reverse_gradient(f, point.view()).unwrap();

        assert_close(grad[0], 6.0);
        assert_close(grad[1], 8.0);
    }

    #[test]
    fn test_reverse_jacobian() {
        // f(x, y) = (x^2, x*y, y^2), so the Jacobian rows are
        // (2x, 0), (y, x) and (0, 2y).
        let f = |tape: &mut Tape<f64>, vars: &[usize]| {
            let (x, y) = (vars[0], vars[1]);
            let x_sq = tape.mul(x, x);
            let xy = tape.mul(x, y);
            let y_sq = tape.mul(y, y);
            vec![x_sq, xy, y_sq]
        };

        let point = Array1::from_vec(vec![2.0, 3.0]);
        let jac = reverse_jacobian(f, point.view()).unwrap();

        let expected = [[4.0, 0.0], [3.0, 2.0], [0.0, 6.0]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_close(jac[[i, j]], want);
            }
        }
    }
}