Skip to main content

laddu_extensions/optimize/ganesh.rs

1use ganesh::traits::{CostFunction, Gradient, LogDensity};
2use laddu_core::{LadduError, LadduResult};
3use nalgebra::DVector;
4
5use crate::{
6    likelihood::{LikelihoodTerm, StochasticNLL},
7    optimize::MaybeThreadPool,
8    LikelihoodExpression, NLL,
9};
10
11impl CostFunction<MaybeThreadPool, LadduError> for NLL {
12    fn evaluate(&self, parameters: &DVector<f64>, args: &MaybeThreadPool) -> LadduResult<f64> {
13        args.install(|| LikelihoodTerm::evaluate(self, parameters.into()))
14    }
15}
16impl Gradient<MaybeThreadPool, LadduError> for NLL {
17    fn gradient(
18        &self,
19        parameters: &DVector<f64>,
20        args: &MaybeThreadPool,
21    ) -> LadduResult<DVector<f64>> {
22        args.install(|| LikelihoodTerm::evaluate_gradient(self, parameters.into()))
23    }
24}
25impl LogDensity<MaybeThreadPool, LadduError> for NLL {
26    fn log_density(&self, parameters: &DVector<f64>, args: &MaybeThreadPool) -> LadduResult<f64> {
27        Ok(-args.install(|| LikelihoodTerm::evaluate(self, parameters.into()))?)
28    }
29}
30
31impl CostFunction<MaybeThreadPool, LadduError> for StochasticNLL {
32    fn evaluate(&self, parameters: &DVector<f64>, args: &MaybeThreadPool) -> LadduResult<f64> {
33        args.install(|| LikelihoodTerm::evaluate(self, parameters.into()))
34    }
35}
36impl Gradient<MaybeThreadPool, LadduError> for StochasticNLL {
37    fn gradient(
38        &self,
39        parameters: &DVector<f64>,
40        args: &MaybeThreadPool,
41    ) -> LadduResult<DVector<f64>> {
42        args.install(|| LikelihoodTerm::evaluate_gradient(self, parameters.into()))
43    }
44}
45impl LogDensity<MaybeThreadPool, LadduError> for StochasticNLL {
46    fn log_density(&self, parameters: &DVector<f64>, args: &MaybeThreadPool) -> LadduResult<f64> {
47        Ok(-args.install(|| LikelihoodTerm::evaluate(self, parameters.into()))?)
48    }
49}
50
51impl CostFunction<MaybeThreadPool, LadduError> for LikelihoodExpression {
52    fn evaluate(&self, parameters: &DVector<f64>, args: &MaybeThreadPool) -> LadduResult<f64> {
53        args.install(|| LikelihoodTerm::evaluate(self, parameters.into()))
54    }
55}
56impl Gradient<MaybeThreadPool, LadduError> for LikelihoodExpression {
57    fn gradient(
58        &self,
59        parameters: &DVector<f64>,
60        args: &MaybeThreadPool,
61    ) -> LadduResult<DVector<f64>> {
62        args.install(|| LikelihoodTerm::evaluate_gradient(self, parameters.into()))
63    }
64}
65impl LogDensity<MaybeThreadPool, LadduError> for LikelihoodExpression {
66    fn log_density(&self, parameters: &DVector<f64>, args: &MaybeThreadPool) -> LadduResult<f64> {
67        Ok(-args.install(|| LikelihoodTerm::evaluate(self, parameters.into()))?)
68    }
69}
70
#[cfg(test)]
mod tests {
    use super::MaybeThreadPool;

    /// Issues many back-to-back, trivially short `install` calls on a single two-thread pool
    /// and checks that every call succeeds and yields its closure's value.
    #[test]
    fn maybe_thread_pool_handles_repeated_short_installs() {
        let pool = MaybeThreadPool::new(2);
        let mut total = 0usize;
        for index in 0usize..64 {
            let value = pool
                .install(|| Ok(index + 1))
                .expect("repeated install should succeed");
            total += value;
        }
        let expected: usize = (1usize..=64).sum();
        assert_eq!(total, expected);
    }
}