bio/stats/probs/
adaptive_integration.rs

1// Copyright 2021-2022 Johannes Köster.
2// Licensed under the MIT license (http://opensource.org/licenses/MIT)
3// This file may not be copied, modified, or distributed
4// except according to those terms.
5
6use std::cmp;
7use std::collections::HashMap;
8use std::convert::{Into, TryFrom};
9use std::hash::Hash;
10use std::{
11    fmt::Debug,
12    ops::{Add, Div, Mul, Sub},
13};
14
15use crate::stats::probs::LogProb;
16use itertools::Itertools;
17use itertools_num::linspace;
18use ordered_float::NotNan;
19
20/// Integrate over an interval of type T with a given density function while trying to minimize
21/// the number of grid points evaluated and still hit the maximum likelihood point.
22/// This is achieved via a binary search over the grid points.
23/// The assumption is that the density is unimodal. If that is not the case,
24/// the binary search will not find the maximum and the integral can miss features.
25///
26/// # Example
27///
28/// ```rust
29/// use approx::abs_diff_eq;
30/// use bio::stats::probs::adaptive_integration::ln_integrate_exp;
31/// use bio::stats::probs::{LogProb, Prob};
32/// use ordered_float::NotNan;
33/// use statrs::distribution::{Continuous, Normal};
34/// use statrs::statistics::Distribution;
35///
36/// let ndist = Normal::new(0.0, 1.0).unwrap();
37///
38/// let integral = ln_integrate_exp(
39///     |x| LogProb::from(Prob(ndist.pdf(*x))),
40///     NotNan::new(-1.0).unwrap(),
41///     NotNan::new(1.0).unwrap(),
42///     NotNan::new(0.01).unwrap(),
43/// );
44/// abs_diff_eq!(integral.exp(), 0.682, epsilon = 0.01);
45/// ```
46pub fn ln_integrate_exp<T, F, E>(
47    mut density: F,
48    min_point: T,
49    max_point: T,
50    max_resolution: T,
51) -> LogProb
52where
53    T: Copy
54        + Add<Output = T>
55        + Sub<Output = T>
56        + Div<Output = T>
57        + Div<NotNan<f64>, Output = T>
58        + Mul<Output = T>
59        + Into<f64>
60        + TryFrom<f64, Error = E>
61        + Ord
62        + Debug
63        + Hash,
64    E: Debug,
65    F: FnMut(T) -> LogProb,
66    f64: From<T>,
67{
68    let mut probs = HashMap::new();
69
70    let mut grid_point = |point, probs: &mut HashMap<_, _>| {
71        probs.insert(point, density(point));
72        point
73    };
74    let middle_grid_point = |left: T, right: T| (right + left) / NotNan::new(2.0).unwrap();
75    // METHOD:
76    // Step 1: perform binary search for maximum likelihood point
77    // Remember all points.
78    let mut left = grid_point(min_point, &mut probs);
79    let mut right = grid_point(max_point, &mut probs);
80    let mut first_middle = None;
81    let mut middle = None;
82
83    while (((right - left) >= max_resolution) && left < right) || middle.is_none() {
84        middle = Some(grid_point(middle_grid_point(left, right), &mut probs));
85
86        if first_middle.is_none() {
87            first_middle = middle;
88        }
89
90        let left_prob = probs.get(&left).unwrap();
91        let right_prob = probs.get(&right).unwrap();
92
93        if left_prob > right_prob {
94            // investigate left window more closely
95            right = middle.unwrap();
96        } else {
97            // investigate right window more closely
98            left = middle.unwrap();
99        }
100    }
101    // After that loop, we are guaranteed that middle.is_some().
102    let middle = middle.unwrap();
103    let first_middle = first_middle.unwrap();
104    // METHOD: add additional grid point in the initially abandoned arm
105    if middle < first_middle {
106        grid_point(middle_grid_point(first_middle, max_point), &mut probs);
107    } else {
108        grid_point(middle_grid_point(min_point, first_middle), &mut probs);
109    }
110    // METHOD additionally investigate small interval around the optimum
111    for point in linspace(
112        cmp::max(
113            T::try_from(middle.into() - max_resolution.into() * 3.0).unwrap(),
114            min_point,
115        )
116        .into(),
117        middle.into(),
118        4,
119    )
120    .take(3)
121    .chain(
122        linspace(
123            middle.into(),
124            cmp::min(
125                T::try_from(middle.into() + max_resolution.into() * 3.0).unwrap(),
126                max_point,
127            )
128            .into(),
129            4,
130        )
131        .skip(1),
132    ) {
133        grid_point(T::try_from(point).unwrap(), &mut probs);
134    }
135
136    let sorted_grid_points: Vec<f64> = probs.keys().sorted().map(|point| (*point).into()).collect();
137
138    // METHOD:
139    // Step 2: integrate over grid points visited during the binary search.
140    LogProb::ln_trapezoidal_integrate_grid_exp::<f64, _>(
141        |_, g| *probs.get(&T::try_from(g).unwrap()).unwrap(),
142        &sorted_grid_points,
143    )
144}