netem_trace/model/
solve_truncate.rs

1//! This module provides the `solve` function to solve the following problem:
2//!
//! Given a lower bound (default 0), an upper bound (default +inf) and the std_dev of a normal distribution,
//! find a center of the distribution such that the mathematical expectation of the distribution (after
//! being truncated by the given lower and upper bounds) equals an expected value that lies between the lower
//! and upper bounds.
7//!
8//! Enable `truncated-normal` feature to use this module.
9//!
10//! Use example can be found in the doc of `model::bw::NormalizedBwConfig::build_truncated`.
11//!
12//!
13//! Notation:
14//! ```txt
//!     cdf(t, avg, sigma) is the cumulative distribution function of a normal distribution,
//!         whose center is avg and standard deviation is sigma, with respect to t
//!
//!     pdf(t, avg, sigma) is the probability density function of a normal distribution,
//!         whose center is avg and standard deviation is sigma, with respect to t
20//! ```
21use statrs::function::erf::erf;
22use std::f64::consts::PI;
23
24/// Calculates the mathmetical expectation of the following distribution of t,
25/// whose Cumulative Distribution Function (CDF(t)) is:
26/// ```text
27///     CDF(t) = 0, if t < lower
28///     CDF(t) = 1, if t > upper
29///     CDF(t) = cdf(t, avg, sigma)
30/// ```
31///
32/// if lower or upper are given as `None`, default values of 0.0 and +inf respectively are used.
33///
34/// The calculation is separated into three addicative parts.
35/// 1. $$\int_{\text{lower}}^{\text{upper}} t \times \text{pdf}(t, \text{avg}, \text{sigma} ) \ \text dt $$
36///
37///     The indefinite integral of above is calculated in `integral`
38///
39/// 2. $$upper \times (1 - cdf(upper, avg, sigma))$$
40///
41/// 3. $$lower \times cdf(lower, avg, sigma)$$
42///
43///
44fn truncated_bandwidth(avg: f64, sigma: f64, lower: Option<f64>, upper: Option<f64>) -> f64 {
45    //upper_integral - lower_integral is the integral described in the doc, which is part 1 of the calculation.
46    let upper_integral = if let Some(upper) = upper {
47        integral(avg, upper, sigma)
48    } else {
49        //default upper as +inf
50        avg * 0.5f64
51    };
52
53    let lower_integral = if let Some(lower) = lower {
54        integral(avg, lower, sigma)
55    } else {
56        integral(avg, 0f64, sigma)
57    };
58
59    // part 2 of the calculation as described in the doc.
60    let upper_truncate = if let Some(upper) = upper {
61        upper * (1f64 - cdf(upper, avg, sigma))
62    } else {
63        0.0f64
64    };
65
66    // part 3 of the calculation as described in the doc.
67    let lower_truncate = if let Some(lower) = lower {
68        lower * cdf(lower, avg, sigma)
69    } else {
70        0.0f64
71    };
72
73    upper_integral - lower_integral + lower_truncate + upper_truncate
74}
75
76/// An indefinite integral:
77///     $$\int t \times \text{pdf}(t, \text{avg}, \text{sigma} ) \ \text dt $$
78///
79/// Used in `truncated_bandwidth`.
80///
81fn integral(avg: f64, t: f64, sigma: f64) -> f64 {
82    let part1 = avg * 0.5f64 * erf((t - avg) / sigma / 2.0_f64.sqrt());
83    let part2 =
84        -sigma / (2.0f64 * PI).sqrt() * (-(t - avg) * (t - avg) * 0.5f64 / sigma / sigma).exp();
85
86    part1 + part2
87}
88
89/// The cumulative distribution function of a normal distribution,
90///     whose center is avg and standard derivation is sigma, with respect to t
91///
92/// Used in `truncated_bandwidth`.
93///
94///
95fn cdf(t: f64, avg: f64, sigma: f64) -> f64 {
96    0.5f64 * (1f64 + erf((t - avg) / sigma / 2f64.sqrt()))
97}
98
99/// The derivative of `truncated_bandwidth` with respect to `avg`.
100/// As `truncated_bandwidth` is calculated in addicative parts, here calculates the derivative of it in
101/// addicative parts, part by part.
102///
103fn deri_truncated_bandwidth(avg: f64, sigma: f64, lower: Option<f64>, upper: Option<f64>) -> f64 {
104    let upper_integral = if let Some(upper) = upper {
105        deri_integral(avg, upper, sigma)
106    } else {
107        //default upper as +inf
108        0.5f64
109    };
110
111    let lower_integral = if let Some(lower) = lower {
112        deri_integral(avg, lower, sigma)
113    } else {
114        deri_integral(avg, 0f64, sigma)
115    };
116
117    let upper_truncate = if let Some(upper) = upper {
118        upper * (-deri_cdf(upper, avg, sigma))
119    } else {
120        0.0f64
121    };
122
123    let lower_truncate = if let Some(lower) = lower {
124        lower * deri_cdf(lower, avg, sigma)
125    } else {
126        0.0f64
127    };
128
129    upper_integral - lower_integral + lower_truncate + upper_truncate
130}
131
132/// Partial derivative of the following respect to `avg`.
133///     $$\int t \times \text{pdf}(t, \text{avg}, \text{sigma} ) \ \text dt $$
134///
135/// Used in `derivative_truncated_bandwidth`.
136///
137fn deri_integral(avg: f64, t: f64, sigma: f64) -> f64 {
138    let part1 = 0.5f64 * erf((t - avg) / sigma / 2.0_f64.sqrt());
139    let part2 = (-(t - avg) * (t - avg) * 0.5f64 / sigma / sigma).exp() * (-t)
140        / (2.0f64 * PI).sqrt()
141        / sigma;
142    part1 + part2
143}
144
/// Partial derivative, with respect to `avg`, of
///     cdf(t, avg, sigma)
///
/// This equals the negated probability density of the normal distribution at `t`.
///
/// Used in `deri_truncated_bandwidth`.
///
fn deri_cdf(t: f64, avg: f64, sigma: f64) -> f64 {
    let offset = t - avg;
    let gaussian = (-offset * offset / 2.0f64 / sigma / sigma).exp();
    let normalizer = (2.0f64 * PI).sqrt();
    -gaussian / sigma / normalizer
}
153
154/// Solve the problem descirbed at the head of this file with Newtown's method, which requires f(x) and f'(x) to
155/// solve f(x) = 0. Here f(x) is `truncated_bandwidth` and f'(x) us `derivative_truncated_bandwidth`
156///
157///
158/// Parameters:
159///     x : target mathematical expectation of the truncated normal distribution
160///     sigma: the standard deviation of the normal distribution before truncation
161///     lower: the lower bound of the truncation, default 0 if None is provided
162///     upper: the upper bound of the truncation, default +inf if None is provided
163///
164/// Return value:
165///     if a solution is found for the problem, returns the cernter of the normal distribution before truncation
166///     else (aka the sanity check of the parameters failed), returns None.
167///
168/// The units of the parameters above should be consistent, which is the unit of the return value.
169///
170/// ## Examples
171///
172/// ```
173/// use netem_trace::model::solve_truncate::solve;
174/// let a = solve(8.0, 2.0, Some(4.0), Some(12.0)).unwrap();
175/// assert!((a-8.0).abs() < 0.000001);
176///
177/// let a = solve(10.0, 4.0, Some(4.0), Some(12.0)).unwrap();
178/// assert_eq!(a, 11.145871035156846);
179///
180/// let a = solve(10.0, 20.0, None, None).unwrap();
181/// assert_eq!(a, 3.7609851997619734);
182///
183/// let a = solve(5.0, 18.0, None, None).unwrap();
184/// assert_eq!(a, -4.888296757781897);
185///
186/// let a = solve(10.0, 20.0, Some(7.0), Some(15.0)).unwrap();
187/// assert_eq!(a, 4.584705225916618);
188///
189/// let a = solve(10.0, 0.01, Some(7.0), Some(15.0)).unwrap();
190/// assert_eq!(a, 10.0);
191///
192/// let a = solve(10.0, 0.01, None, Some(15.0)).unwrap();
193/// assert_eq!(a, 10.0);
194///
195/// let a = solve(10.0, 0.01, None, None).unwrap();
196/// assert_eq!(a, 10.0);
197///
198/// let a = solve(10.0, 0.01, Some(3.0), None).unwrap();
199/// assert_eq!(a, 10.0);
200/// ```
201///
202pub fn solve(x: f64, sigma: f64, mut lower: Option<f64>, upper: Option<f64>) -> Option<f64> {
203    if sigma.abs() <= f64::EPSILON {
204        return Some(x);
205    }
206    //sanity check
207    if lower.is_some_and(|lower| lower >= x * (1.0 + f64::EPSILON)) {
208        return lower;
209    }
210
211    if lower.is_none() && x <= f64::EPSILON {
212        return 0f64.into();
213    }
214
215    if upper.is_some_and(|upper| upper * (1.0 + f64::EPSILON) <= x) {
216        return upper;
217    }
218
219    let mut result = x;
220
221    if lower.is_some_and(|l| l < 0.0) || lower.is_none() {
222        lower = Some(0.0f64);
223    }
224
225    let mut last_diff = f64::MAX;
226    let mut run_cnt = 10;
227
228    while run_cnt > 0 {
229        let f_x = truncated_bandwidth(result, sigma, lower, upper);
230
231        let diff = (f_x - x).abs();
232        if diff < last_diff {
233            last_diff = diff;
234            run_cnt = 100;
235        } else {
236            run_cnt -= 1;
237        }
238
239        result = result - (f_x - x) / deri_truncated_bandwidth(result, sigma, lower, upper);
240    }
241
242    Some(result)
243}
244
#[cfg(test)]
mod tests {

    use super::*;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    /// Checks numerically that `deri` is the derivative of `func` on `[low, high)`.
    ///
    /// At 1000 points drawn from a fixed-seed generator, the finite difference of `func` over a
    /// small step is compared with the step multiplied by `deri` at the midpoint of the step.
    /// When both have the same sign their ratio must be within 1e-7 of 1, otherwise both
    /// must be smaller than `f32::EPSILON` in magnitude.
    fn test_deri<F, G>(func: F, deri: G, low: f64, high: f64)
    where
        F: Fn(f64) -> f64,
        G: Fn(f64) -> f64,
    {
        let mut rng = StdRng::seed_from_u64(42);
        // The step is proportional to the magnitude of the tested range.
        let eps = 5E-8 * (low + high);

        for _ in 0..1000 {
            let x = rng.random_range(low..high);
            let delta1 = func(x + eps) - func(x);
            let delta2 = eps * deri(x + eps * 0.5);
            if delta1 * delta2 > 0.0 {
                assert!(
                    delta1 / delta2 < 1.0000001,
                    "x = {x}: delta1 = {delta1}, delta2 = {delta2}"
                );
                assert!(
                    delta2 / delta1 < 1.0000001,
                    "x = {x}: delta1 = {delta1}, delta2 = {delta2}"
                );
            } else {
                assert!(
                    delta1.abs() < f32::EPSILON.into(),
                    "x = {x}: delta1 = {delta1}, delta2 = {delta2}"
                );
                assert!(
                    delta2.abs() < f32::EPSILON.into(),
                    "x = {x}: delta1 = {delta1}, delta2 = {delta2}"
                );
            }
        }
    }

    #[test]
    fn test_truncated_bandwidth() {
        assert_eq!(
            truncated_bandwidth(10.0, 5.0, None, None),
            10.042453513094314
        );

        // Default bounds.
        test_deri(
            |x| truncated_bandwidth(x, 3.0, None, None),
            |x| deri_truncated_bandwidth(x, 3.0, None, None),
            0.0,
            10.0,
        );

        // Lower bound only.
        test_deri(
            |x| truncated_bandwidth(x, 3.0, Some(3.0), None),
            |x| deri_truncated_bandwidth(x, 3.0, Some(3.0), None),
            0.0,
            10.0,
        );

        // Both bounds.
        test_deri(
            |x| truncated_bandwidth(x, 3.0, Some(3.0), Some(20.0)),
            |x| deri_truncated_bandwidth(x, 3.0, Some(3.0), Some(20.0)),
            0.0,
            10.0,
        );
    }

    #[test]
    fn test_integral() {
        assert_eq!(integral(10.0, 5.0, 2.0), -4.972959947732017);
        // Far away from the center the integral saturates at +/- avg * 0.5.
        assert_eq!(integral(10.0, 1E9, 2.0), 5.0);
        assert_eq!(integral(10.0, -1E9, 2.0), -5.0);

        test_deri(
            |x| integral(x, 8.0, 5.0),
            |x| deri_integral(x, 8.0, 5.0),
            6.0,
            10.0,
        );

        test_deri(
            |x| integral(x, 8.0, 15.0),
            |x| deri_integral(x, 8.0, 15.0),
            6.0,
            10.0,
        );
    }

    #[test]
    fn test_cdf() {
        assert_eq!(
            cdf(12.0, 12.0, 4.0) * 2.0 - 1.0, // 0 sigma
            0.0
        );
        assert_eq!(
            cdf(16.0, 12.0, 4.0) - cdf(8.0, 12.0, 4.0), // 1 sigma
            0.6826894921098856
        );
        assert_eq!(
            cdf(20.0, 12.0, 4.0) - cdf(4.0, 12.0, 4.0), // 2 sigma
            0.9544997361056748
        );
        assert_eq!(
            cdf(24.0, 12.0, 4.0) - cdf(0.0, 12.0, 4.0), // 3 sigma
            0.997300203936851
        );

        test_deri(|x| cdf(8.0, x, 5.0), |x| deri_cdf(8.0, x, 5.0), 6.0, 10.0);

        test_deri(
            |x| cdf(123.0, x, 15.0),
            |x| deri_cdf(123.0, x, 15.0),
            110.0,
            136.0,
        );
    }
}