echidna_optim/
objective.rs1use echidna::{BytecodeTape, Float};
2
3pub trait Objective<F: num_traits::Float> {
8 fn dim(&self) -> usize;
10
11 fn eval_grad(&mut self, x: &[F]) -> (F, Vec<F>);
15
16 fn eval_hessian(&mut self, x: &[F]) -> (F, Vec<F>, Vec<Vec<F>>) {
22 let _ = x;
23 unimplemented!("eval_hessian not implemented for this objective")
24 }
25
26 fn hvp(&mut self, x: &[F], v: &[F]) -> (Vec<F>, Vec<F>) {
32 let _ = (x, v);
33 unimplemented!("hvp not implemented for this objective")
34 }
35}
36
37pub struct TapeObjective<F: Float> {
39 tape: BytecodeTape<F>,
40 func_evals: usize,
41}
42
43impl<F: Float> TapeObjective<F> {
44 pub fn new(tape: BytecodeTape<F>) -> Self {
46 TapeObjective {
47 tape,
48 func_evals: 0,
49 }
50 }
51
52 pub fn func_evals(&self) -> usize {
54 self.func_evals
55 }
56
57 pub fn tape(&self) -> &BytecodeTape<F> {
59 &self.tape
60 }
61}
62
63impl<F: Float> Objective<F> for TapeObjective<F> {
64 fn dim(&self) -> usize {
65 self.tape.num_inputs()
66 }
67
68 fn eval_grad(&mut self, x: &[F]) -> (F, Vec<F>) {
69 self.func_evals += 1;
70 let grad = self.tape.gradient(x);
71 let value = self.tape.output_value();
72 (value, grad)
73 }
74
75 fn eval_hessian(&mut self, x: &[F]) -> (F, Vec<F>, Vec<Vec<F>>) {
76 self.func_evals += 1;
77 self.tape.hessian(x)
78 }
79
80 fn hvp(&mut self, x: &[F], v: &[F]) -> (Vec<F>, Vec<F>) {
81 self.func_evals += 1;
82 self.tape.hvp(x, v)
83 }
84}