gemlab/util/
num_divisions.rs

1use crate::StrError;
2
3/// Computes the number of divisions for GridSearch aiming at a square/cubic containers
4///
5/// Note: **long** means the longest direction whereas
6///       **other** corresponds to the `not-long` directions.
7///
8/// ```text
9/// ndiv_other = truncate((delta_other/delta_long) * ndiv_long)
10/// ndiv_other = max(ndiv_min, ndiv_other)
11/// ```
12///
13/// # Input
14///
15/// * `ndiv_min` -- minimum number of divisions for the other directions (≥ 1)
16/// * `ndiv_long` -- number of divisions for the longest direction (> ndiv_min)
17/// * `xmin` -- min coordinates (len = ndim)
18/// * `xmax` -- max coordinates (len = ndim); must be greater than min
19///
20/// # Output
21///
22/// * `ndiv` -- the number of divisions along each direction (ndim)
23///
24/// # Examples
25///
26/// ```
27/// use gemlab::util::num_divisions;
28/// use gemlab::StrError;
29///
30/// fn main() -> Result<(), StrError> {
31///     let xmin = &[0.0, 0.0];
32///     let xmax = &[1.0, 1.0];
33///     let ndiv = num_divisions(2, 5, xmin, xmax)?;
34///     assert_eq!(ndiv, &[5, 5]);
35///
36///     let xmax = &[1.0, 10.0];
37///     let ndiv = num_divisions(2, 5, xmin, xmax)?;
38///     assert_eq!(ndiv, &[2, 5]);
39///     Ok(())
40/// }
41/// ```
42pub fn num_divisions(ndiv_min: usize, ndiv_long: usize, xmin: &[f64], xmax: &[f64]) -> Result<Vec<usize>, StrError> {
43    if ndiv_min < 1 {
44        return Err("ndiv_min must be ≥ 1");
45    }
46    if ndiv_long <= ndiv_min {
47        return Err("ndiv_long must be > ndiv_min");
48    }
49    if xmin.len() != xmax.len() {
50        return Err("xmin.len() must be equal to xmax.len()");
51    }
52    let ndim = xmin.len();
53    if ndim < 2 {
54        return Err("ndim must be ≥ 2");
55    }
56    let delta: Vec<_> = xmin.iter().zip(xmax).map(|(a, b)| *b - *a).collect();
57    let mut index_long = 0;
58    let mut delta_long = delta[index_long];
59    for i in 1..ndim {
60        if delta[i] > delta_long {
61            index_long = i;
62            delta_long = delta[i];
63        }
64    }
65    let mut ndiv = vec![0; ndim];
66    for i in 0..ndim {
67        if delta[i] <= 0.0 {
68            return Err("xmax must be greater than xmin");
69        }
70        if i == index_long {
71            ndiv[i] = ndiv_long;
72        } else {
73            ndiv[i] = ((delta[i] / delta_long) * (ndiv_long as f64)) as usize;
74            ndiv[i] = usize::max(ndiv_min, ndiv[i]);
75        }
76    }
77    Ok(ndiv)
78}
79
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
81
#[cfg(test)]
mod tests {
    use super::num_divisions;

    #[test]
    fn num_divisions_fails_on_errors() {
        // (ndiv_min, ndiv_long, xmin, xmax, expected error message)
        let cases: [(usize, usize, &[f64], &[f64], &str); 5] = [
            (0, 20, &[0.0, 0.0], &[1.0, 1.0], "ndiv_min must be ≥ 1"),
            (2, 1, &[0.0, 0.0], &[1.0, 1.0], "ndiv_long must be > ndiv_min"),
            (2, 3, &[0.0, 0.0], &[1.0, 1.0, 1.0], "xmin.len() must be equal to xmax.len()"),
            (2, 3, &[0.0], &[1.0], "ndim must be ≥ 2"),
            (2, 20, &[2.0, 2.0], &[0.0, 0.0], "xmax must be greater than xmin"),
        ];
        for &(ndiv_min, ndiv_long, xmin, xmax, message) in cases.iter() {
            assert_eq!(num_divisions(ndiv_min, ndiv_long, xmin, xmax).err(), Some(message));
        }
    }

    #[test]
    fn num_divisions_works() {
        // (ndiv_min, ndiv_long, xmin, xmax, expected divisions per direction)
        let cases: [(usize, usize, &[f64], &[f64], &[usize]); 14] = [
            (1, 10, &[0.0, 0.0], &[1.0, 1.0], &[10, 10]),
            (2, 10, &[0.0, 0.0], &[1.0, 1.0], &[10, 10]),
            (3, 20, &[0.0, 0.0], &[100.0, 100.0], &[20, 20]),
            (3, 20, &[0.0, 0.0], &[1.0, 2.0], &[10, 20]),
            (3, 20, &[0.0, 0.0], &[2.0, 1.0], &[20, 10]),
            (3, 20, &[0.0, 0.0], &[1.0, 4.0], &[5, 20]),
            (3, 20, &[0.0, 0.0], &[1.0, 5.0], &[4, 20]),
            (2, 20, &[0.0, 0.0], &[1.0, 10.0], &[2, 20]),
            (1, 20, &[0.0, 0.0], &[1.0, 100.0], &[1, 20]),
            (2, 20, &[0.0, 0.0], &[1.0, 100.0], &[2, 20]),
            (3, 20, &[0.0, 0.0], &[10000.0, 1.0], &[20, 3]),
            (3, 20, &[0.0, 0.0], &[10.0, 3.0], &[20, 6]),
            (3, 20, &[0.0, 0.0], &[10.0, 3.5], &[20, 7]),
            (3, 20, &[-1.0, -2.0], &[1.0, 2.0], &[10, 20]),
        ];
        for &(ndiv_min, ndiv_long, xmin, xmax, expected) in cases.iter() {
            let ndiv = num_divisions(ndiv_min, ndiv_long, xmin, xmax).unwrap();
            assert_eq!(ndiv, expected);
        }
    }
}