bio/stats/probs/
adaptive_integration.rs1use std::cmp;
7use std::collections::HashMap;
8use std::convert::{Into, TryFrom};
9use std::hash::Hash;
10use std::{
11 fmt::Debug,
12 ops::{Add, Div, Mul, Sub},
13};
14
15use crate::stats::probs::LogProb;
16use itertools::Itertools;
17use itertools_num::linspace;
18use ordered_float::NotNan;
19
20pub fn ln_integrate_exp<T, F, E>(
47 mut density: F,
48 min_point: T,
49 max_point: T,
50 max_resolution: T,
51) -> LogProb
52where
53 T: Copy
54 + Add<Output = T>
55 + Sub<Output = T>
56 + Div<Output = T>
57 + Div<NotNan<f64>, Output = T>
58 + Mul<Output = T>
59 + Into<f64>
60 + TryFrom<f64, Error = E>
61 + Ord
62 + Debug
63 + Hash,
64 E: Debug,
65 F: FnMut(T) -> LogProb,
66 f64: From<T>,
67{
68 let mut probs = HashMap::new();
69
70 let mut grid_point = |point, probs: &mut HashMap<_, _>| {
71 probs.insert(point, density(point));
72 point
73 };
74 let middle_grid_point = |left: T, right: T| (right + left) / NotNan::new(2.0).unwrap();
75 let mut left = grid_point(min_point, &mut probs);
79 let mut right = grid_point(max_point, &mut probs);
80 let mut first_middle = None;
81 let mut middle = None;
82
83 while (((right - left) >= max_resolution) && left < right) || middle.is_none() {
84 middle = Some(grid_point(middle_grid_point(left, right), &mut probs));
85
86 if first_middle.is_none() {
87 first_middle = middle;
88 }
89
90 let left_prob = probs.get(&left).unwrap();
91 let right_prob = probs.get(&right).unwrap();
92
93 if left_prob > right_prob {
94 right = middle.unwrap();
96 } else {
97 left = middle.unwrap();
99 }
100 }
101 let middle = middle.unwrap();
103 let first_middle = first_middle.unwrap();
104 if middle < first_middle {
106 grid_point(middle_grid_point(first_middle, max_point), &mut probs);
107 } else {
108 grid_point(middle_grid_point(min_point, first_middle), &mut probs);
109 }
110 for point in linspace(
112 cmp::max(
113 T::try_from(middle.into() - max_resolution.into() * 3.0).unwrap(),
114 min_point,
115 )
116 .into(),
117 middle.into(),
118 4,
119 )
120 .take(3)
121 .chain(
122 linspace(
123 middle.into(),
124 cmp::min(
125 T::try_from(middle.into() + max_resolution.into() * 3.0).unwrap(),
126 max_point,
127 )
128 .into(),
129 4,
130 )
131 .skip(1),
132 ) {
133 grid_point(T::try_from(point).unwrap(), &mut probs);
134 }
135
136 let sorted_grid_points: Vec<f64> = probs.keys().sorted().map(|point| (*point).into()).collect();
137
138 LogProb::ln_trapezoidal_integrate_grid_exp::<f64, _>(
141 |_, g| *probs.get(&T::try_from(g).unwrap()).unwrap(),
142 &sorted_grid_points,
143 )
144}