laddu_extensions/optimize/
ganesh.rs1use ganesh::traits::{CostFunction, Gradient, LogDensity};
2use laddu_core::{LadduError, LadduResult};
3use nalgebra::DVector;
4
5use crate::{
6 likelihood::{LikelihoodTerm, StochasticNLL},
7 optimize::MaybeThreadPool,
8 LikelihoodExpression, NLL,
9};
10
11impl CostFunction<MaybeThreadPool, LadduError> for NLL {
12 fn evaluate(&self, parameters: &DVector<f64>, args: &MaybeThreadPool) -> LadduResult<f64> {
13 args.install(|| LikelihoodTerm::evaluate(self, parameters.into()))
14 }
15}
16impl Gradient<MaybeThreadPool, LadduError> for NLL {
17 fn gradient(
18 &self,
19 parameters: &DVector<f64>,
20 args: &MaybeThreadPool,
21 ) -> LadduResult<DVector<f64>> {
22 args.install(|| LikelihoodTerm::evaluate_gradient(self, parameters.into()))
23 }
24}
25impl LogDensity<MaybeThreadPool, LadduError> for NLL {
26 fn log_density(&self, parameters: &DVector<f64>, args: &MaybeThreadPool) -> LadduResult<f64> {
27 Ok(-args.install(|| LikelihoodTerm::evaluate(self, parameters.into()))?)
28 }
29}
30
31impl CostFunction<MaybeThreadPool, LadduError> for StochasticNLL {
32 fn evaluate(&self, parameters: &DVector<f64>, args: &MaybeThreadPool) -> LadduResult<f64> {
33 args.install(|| LikelihoodTerm::evaluate(self, parameters.into()))
34 }
35}
36impl Gradient<MaybeThreadPool, LadduError> for StochasticNLL {
37 fn gradient(
38 &self,
39 parameters: &DVector<f64>,
40 args: &MaybeThreadPool,
41 ) -> LadduResult<DVector<f64>> {
42 args.install(|| LikelihoodTerm::evaluate_gradient(self, parameters.into()))
43 }
44}
45impl LogDensity<MaybeThreadPool, LadduError> for StochasticNLL {
46 fn log_density(&self, parameters: &DVector<f64>, args: &MaybeThreadPool) -> LadduResult<f64> {
47 Ok(-args.install(|| LikelihoodTerm::evaluate(self, parameters.into()))?)
48 }
49}
50
51impl CostFunction<MaybeThreadPool, LadduError> for LikelihoodExpression {
52 fn evaluate(&self, parameters: &DVector<f64>, args: &MaybeThreadPool) -> LadduResult<f64> {
53 args.install(|| LikelihoodTerm::evaluate(self, parameters.into()))
54 }
55}
56impl Gradient<MaybeThreadPool, LadduError> for LikelihoodExpression {
57 fn gradient(
58 &self,
59 parameters: &DVector<f64>,
60 args: &MaybeThreadPool,
61 ) -> LadduResult<DVector<f64>> {
62 args.install(|| LikelihoodTerm::evaluate_gradient(self, parameters.into()))
63 }
64}
65impl LogDensity<MaybeThreadPool, LadduError> for LikelihoodExpression {
66 fn log_density(&self, parameters: &DVector<f64>, args: &MaybeThreadPool) -> LadduResult<f64> {
67 Ok(-args.install(|| LikelihoodTerm::evaluate(self, parameters.into()))?)
68 }
69}
70
#[cfg(test)]
mod tests {
    use super::MaybeThreadPool;

    /// Runs 64 back-to-back `install` calls with trivial closures on a two-thread pool.
    /// Every call must succeed and hand back its closure's value unchanged.
    #[test]
    fn maybe_thread_pool_handles_repeated_short_installs() {
        let pool = MaybeThreadPool::new(2);

        let mut total = 0usize;
        for index in 0usize..64 {
            let value: usize = pool
                .install(|| Ok(index + 1))
                .expect("repeated install should succeed");
            total += value;
        }

        let expected: usize = (1usize..=64).sum();
        assert_eq!(total, expected);
    }
}