netem_trace/model/solve_truncate.rs
1//! This module provides the `solve` function to solve the following problem:
2//!
//! Given a lower bound (default 0), an upper bound (default +inf) and the standard deviation of a
//! normal distribution, find the center of the distribution such that the mathematical expectation
//! of the distribution, after it is truncated by the given lower and upper bounds (samples falling
//! outside the bounds are clamped to the nearest bound), equals an expected value that lies between
//! the lower and upper bounds.
7//!
8//! Enable `truncated-normal` feature to use this module.
9//!
//! A usage example can be found in the doc of `model::bw::NormalizedBwConfig::build_truncated`.
11//!
12//!
13//! Notation:
14//! ```txt
//! cdf(t, avg, sigma) is the cumulative distribution function of a normal distribution,
//! whose center is avg and standard deviation is sigma, with respect to t
//!
//! pdf(t, avg, sigma) is the probability density function of a normal distribution,
//! whose center is avg and standard deviation is sigma, with respect to t
20//! ```
21use statrs::function::erf::erf;
22use std::f64::consts::PI;
23
24/// Calculates the mathmetical expectation of the following distribution of t,
25/// whose Cumulative Distribution Function (CDF(t)) is:
26/// ```text
27/// CDF(t) = 0, if t < lower
28/// CDF(t) = 1, if t > upper
29/// CDF(t) = cdf(t, avg, sigma)
30/// ```
31///
32/// if lower or upper are given as `None`, default values of 0.0 and +inf respectively are used.
33///
34/// The calculation is separated into three addicative parts.
35/// 1. $$\int_{\text{lower}}^{\text{upper}} t \times \text{pdf}(t, \text{avg}, \text{sigma} ) \ \text dt $$
36///
37/// The indefinite integral of above is calculated in `integral`
38///
39/// 2. $$upper \times (1 - cdf(upper, avg, sigma))$$
40///
41/// 3. $$lower \times cdf(lower, avg, sigma)$$
42///
43///
44fn truncated_bandwidth(avg: f64, sigma: f64, lower: Option<f64>, upper: Option<f64>) -> f64 {
45 //upper_integral - lower_integral is the integral described in the doc, which is part 1 of the calculation.
46 let upper_integral = if let Some(upper) = upper {
47 integral(avg, upper, sigma)
48 } else {
49 //default upper as +inf
50 avg * 0.5f64
51 };
52
53 let lower_integral = if let Some(lower) = lower {
54 integral(avg, lower, sigma)
55 } else {
56 integral(avg, 0f64, sigma)
57 };
58
59 // part 2 of the calculation as described in the doc.
60 let upper_truncate = if let Some(upper) = upper {
61 upper * (1f64 - cdf(upper, avg, sigma))
62 } else {
63 0.0f64
64 };
65
66 // part 3 of the calculation as described in the doc.
67 let lower_truncate = if let Some(lower) = lower {
68 lower * cdf(lower, avg, sigma)
69 } else {
70 0.0f64
71 };
72
73 upper_integral - lower_integral + lower_truncate + upper_truncate
74}
75
76/// An indefinite integral:
77/// $$\int t \times \text{pdf}(t, \text{avg}, \text{sigma} ) \ \text dt $$
78///
79/// Used in `truncated_bandwidth`.
80///
81fn integral(avg: f64, t: f64, sigma: f64) -> f64 {
82 let part1 = avg * 0.5f64 * erf((t - avg) / sigma / 2.0_f64.sqrt());
83 let part2 =
84 -sigma / (2.0f64 * PI).sqrt() * (-(t - avg) * (t - avg) * 0.5f64 / sigma / sigma).exp();
85
86 part1 + part2
87}
88
89/// The cumulative distribution function of a normal distribution,
90/// whose center is avg and standard derivation is sigma, with respect to t
91///
92/// Used in `truncated_bandwidth`.
93///
94///
95fn cdf(t: f64, avg: f64, sigma: f64) -> f64 {
96 0.5f64 * (1f64 + erf((t - avg) / sigma / 2f64.sqrt()))
97}
98
99/// The derivative of `truncated_bandwidth` with respect to `avg`.
100/// As `truncated_bandwidth` is calculated in addicative parts, here calculates the derivative of it in
101/// addicative parts, part by part.
102///
103fn deri_truncated_bandwidth(avg: f64, sigma: f64, lower: Option<f64>, upper: Option<f64>) -> f64 {
104 let upper_integral = if let Some(upper) = upper {
105 deri_integral(avg, upper, sigma)
106 } else {
107 //default upper as +inf
108 0.5f64
109 };
110
111 let lower_integral = if let Some(lower) = lower {
112 deri_integral(avg, lower, sigma)
113 } else {
114 deri_integral(avg, 0f64, sigma)
115 };
116
117 let upper_truncate = if let Some(upper) = upper {
118 upper * (-deri_cdf(upper, avg, sigma))
119 } else {
120 0.0f64
121 };
122
123 let lower_truncate = if let Some(lower) = lower {
124 lower * deri_cdf(lower, avg, sigma)
125 } else {
126 0.0f64
127 };
128
129 upper_integral - lower_integral + lower_truncate + upper_truncate
130}
131
132/// Partial derivative of the following respect to `avg`.
133/// $$\int t \times \text{pdf}(t, \text{avg}, \text{sigma} ) \ \text dt $$
134///
135/// Used in `derivative_truncated_bandwidth`.
136///
137fn deri_integral(avg: f64, t: f64, sigma: f64) -> f64 {
138 let part1 = 0.5f64 * erf((t - avg) / sigma / 2.0_f64.sqrt());
139 let part2 = (-(t - avg) * (t - avg) * 0.5f64 / sigma / sigma).exp() * (-t)
140 / (2.0f64 * PI).sqrt()
141 / sigma;
142 part1 + part2
143}
144
/// Partial derivative of `cdf(t, avg, sigma)` with respect to `avg`.
///
/// `cdf` depends on `avg` only through `t - avg`, so this equals `-pdf(t, avg, sigma)`.
///
/// Used in `deri_truncated_bandwidth`.
fn deri_cdf(t: f64, avg: f64, sigma: f64) -> f64 {
    let offset = t - avg;
    let pdf = (-offset * offset / 2.0f64 / sigma / sigma).exp() / sigma / (2.0f64 * PI).sqrt();
    -pdf
}
153
154/// Solve the problem descirbed at the head of this file with Newtown's method, which requires f(x) and f'(x) to
155/// solve f(x) = 0. Here f(x) is `truncated_bandwidth` and f'(x) us `derivative_truncated_bandwidth`
156///
157///
158/// Parameters:
159/// x : target mathematical expectation of the truncated normal distribution
160/// sigma: the standard deviation of the normal distribution before truncation
161/// lower: the lower bound of the truncation, default 0 if None is provided
162/// upper: the upper bound of the truncation, default +inf if None is provided
163///
164/// Return value:
165/// if a solution is found for the problem, returns the cernter of the normal distribution before truncation
166/// else (aka the sanity check of the parameters failed), returns None.
167///
168/// The units of the parameters above should be consistent, which is the unit of the return value.
169///
170/// ## Examples
171///
172/// ```
173/// use netem_trace::model::solve_truncate::solve;
174/// let a = solve(8.0, 2.0, Some(4.0), Some(12.0)).unwrap();
175/// assert!((a-8.0).abs() < 0.000001);
176///
177/// let a = solve(10.0, 4.0, Some(4.0), Some(12.0)).unwrap();
178/// assert_eq!(a, 11.145871035156846);
179///
180/// let a = solve(10.0, 20.0, None, None).unwrap();
181/// assert_eq!(a, 3.7609851997619734);
182///
183/// let a = solve(5.0, 18.0, None, None).unwrap();
184/// assert_eq!(a, -4.888296757781897);
185///
186/// let a = solve(10.0, 20.0, Some(7.0), Some(15.0)).unwrap();
187/// assert_eq!(a, 4.584705225916618);
188///
189/// let a = solve(10.0, 0.01, Some(7.0), Some(15.0)).unwrap();
190/// assert_eq!(a, 10.0);
191///
192/// let a = solve(10.0, 0.01, None, Some(15.0)).unwrap();
193/// assert_eq!(a, 10.0);
194///
195/// let a = solve(10.0, 0.01, None, None).unwrap();
196/// assert_eq!(a, 10.0);
197///
198/// let a = solve(10.0, 0.01, Some(3.0), None).unwrap();
199/// assert_eq!(a, 10.0);
200/// ```
201///
202pub fn solve(x: f64, sigma: f64, mut lower: Option<f64>, upper: Option<f64>) -> Option<f64> {
203 if sigma.abs() <= f64::EPSILON {
204 return Some(x);
205 }
206 //sanity check
207 if lower.is_some_and(|lower| lower >= x * (1.0 + f64::EPSILON)) {
208 return lower;
209 }
210
211 if lower.is_none() && x <= f64::EPSILON {
212 return 0f64.into();
213 }
214
215 if upper.is_some_and(|upper| upper * (1.0 + f64::EPSILON) <= x) {
216 return upper;
217 }
218
219 let mut result = x;
220
221 if lower.is_some_and(|l| l < 0.0) || lower.is_none() {
222 lower = Some(0.0f64);
223 }
224
225 let mut last_diff = f64::MAX;
226 let mut run_cnt = 10;
227
228 while run_cnt > 0 {
229 let f_x = truncated_bandwidth(result, sigma, lower, upper);
230
231 let diff = (f_x - x).abs();
232 if diff < last_diff {
233 last_diff = diff;
234 run_cnt = 100;
235 } else {
236 run_cnt -= 1;
237 }
238
239 result = result - (f_x - x) / deri_truncated_bandwidth(result, sigma, lower, upper);
240 }
241
242 Some(result)
243}
244
#[cfg(test)]
mod tests {

    use super::*;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    /// Checks that `deri` is the derivative of `func` on `[low, high)`.
    ///
    /// For 1000 seeded random points `x`, compares the forward difference
    /// `func(x + eps) - func(x)` with the midpoint estimate `eps * deri(x + eps / 2)`.
    /// When both have the same sign, their ratio must be within 1e-7 of 1; otherwise both
    /// must be negligibly small.
    fn test_deri<F, G>(func: F, deri: G, low: f64, high: f64)
    where
        F: Fn(f64) -> f64,
        G: Fn(f64) -> f64,
    {
        let mut rng = StdRng::seed_from_u64(42);
        // Step size scaled to the sampled range; loop-invariant.
        let eps = 5E-8 * (low + high);

        for _ in 0..1000 {
            let x = rng.random_range(low..high);
            let delta1 = func(x + eps) - func(x);
            let delta2 = eps * deri(x + eps * 0.5);
            // Values are reported in the assertion messages instead of being printed with
            // `dbg!` on every one of the 1000 iterations.
            if delta1 * delta2 > 0.0 {
                assert!(
                    delta1 / delta2 < 1.0000001 && delta2 / delta1 < 1.0000001,
                    "derivative mismatch at x = {x}: finite difference {delta1}, derivative {delta2}"
                );
            } else {
                assert!(
                    delta1.abs() < f32::EPSILON.into() && delta2.abs() < f32::EPSILON.into(),
                    "sign mismatch at x = {x}: finite difference {delta1}, derivative {delta2}"
                );
            }
        }
    }

    #[test]
    fn test_truncated_bandwidth() {
        assert_eq!(
            truncated_bandwidth(10.0, 5.0, None, None),
            10.042453513094314
        );

        test_deri(
            |x| truncated_bandwidth(x, 3.0, None, None),
            |x| deri_truncated_bandwidth(x, 3.0, None, None),
            0.0,
            10.0,
        );

        test_deri(
            |x| truncated_bandwidth(x, 3.0, Some(3.0), None),
            |x| deri_truncated_bandwidth(x, 3.0, Some(3.0), None),
            0.0,
            10.0,
        );

        test_deri(
            |x| truncated_bandwidth(x, 3.0, Some(3.0), Some(20.0)),
            |x| deri_truncated_bandwidth(x, 3.0, Some(3.0), Some(20.0)),
            0.0,
            10.0,
        );
    }

    #[test]
    fn test_integral() {
        assert_eq!(integral(10.0, 5.0, 2.0), -4.972959947732017);
        assert_eq!(integral(10.0, 1E9, 2.0), 5.0);
        assert_eq!(integral(10.0, -1E9, 2.0), -5.0);

        test_deri(
            |x| integral(x, 8.0, 5.0),
            |x| deri_integral(x, 8.0, 5.0),
            6.0,
            10.0,
        );

        test_deri(
            |x| integral(x, 8.0, 15.0),
            |x| deri_integral(x, 8.0, 15.0),
            6.0,
            10.0,
        );
    }

    #[test]
    fn test_cdf() {
        assert_eq!(
            cdf(12.0, 12.0, 4.0) * 2.0 - 1.0, // 0 sigma
            0.0
        );
        assert_eq!(
            cdf(16.0, 12.0, 4.0) - cdf(8.0, 12.0, 4.0), // 1 sigma
            0.6826894921098856
        );
        assert_eq!(
            cdf(20.0, 12.0, 4.0) - cdf(4.0, 12.0, 4.0), // 2 sigma
            0.9544997361056748
        );
        assert_eq!(
            cdf(24.0, 12.0, 4.0) - cdf(0.0, 12.0, 4.0), // 3 sigma
            0.997300203936851
        );

        test_deri(|x| cdf(8.0, x, 5.0), |x| deri_cdf(8.0, x, 5.0), 6.0, 10.0);

        test_deri(
            |x| cdf(123.0, x, 15.0),
            |x| deri_cdf(123.0, x, 15.0),
            110.0,
            136.0,
        );
    }
}