plotpy/
auxiliary.rs

/// Implements the sign function
///
/// ```text
///           │ -1   if x < 0
/// sign(x) = ┤  0   if x = 0
///           │  1   if x > 0
///
///           |x|    x
/// sign(x) = ——— = ———          (for x ≠ 0)
///            x    |x|
///
/// sign(x) = 2 · heaviside(x) - 1   (with heaviside(0) = ½)
/// ```
///
/// Both `+0.0` and `-0.0` map to `0.0`. A NaN input is returned as NaN so that invalid values
/// propagate instead of being silently reported as zero.
///
/// Note: unlike [`f64::signum`], which returns `±1.0` for `±0.0`, this function returns `0.0`
/// for zero.
///
/// Reference: <https://en.wikipedia.org/wiki/Sign_function>
pub fn sign(x: f64) -> f64 {
    if x.is_nan() {
        // NaN compares false against everything; without this check it would fall through to 0.0
        x
    } else if x < 0.0 {
        -1.0
    } else if x > 0.0 {
        1.0
    } else {
        0.0
    }
}
25
26/// Implements the superquadric function involving sin(x)
27///
28/// ```text
29/// suq_sin(x;k) = sign(sin(x)) · |sin(x)|ᵏ
30/// ```
31///
32/// `suq_sin(x;k)` is the `f(ω;m)` function from <https://en.wikipedia.org/wiki/Superquadrics>
33pub fn suq_sin(x: f64, k: f64) -> f64 {
34    sign(f64::sin(x)) * f64::powf(f64::abs(f64::sin(x)), k)
35}
36
37/// Implements the superquadric auxiliary involving cos(x)
38///
39/// ```text
40/// suq_cos(x;k) = sign(cos(x)) · |cos(x)|ᵏ
41/// ```
42///
43/// `suq_cos(x;k)` is the `g(ω;m)` function from <https://en.wikipedia.org/wiki/Superquadrics>
44pub fn suq_cos(x: f64, k: f64) -> f64 {
45    sign(f64::cos(x)) * f64::powf(f64::abs(f64::cos(x)), k)
46}
47
/// Returns `count` evenly spaced numbers over the closed interval `[start, stop]`
///
/// * `count == 0` returns an empty vector
/// * `count == 1` returns `[start]`
/// * otherwise the first and last entries are exactly `start` and `stop`, and the interior
///   entries are `start + i · (stop - start) / (count - 1)` for `i = 1, …, count - 2`
///
/// The endpoints are assigned directly rather than computed, so that floating-point round-off
/// in `i · step` cannot make the last value differ from `stop` (e.g. `49 · (1/49) ≠ 1`).
pub fn linspace(start: f64, stop: f64, count: usize) -> Vec<f64> {
    match count {
        0 => Vec::new(),
        1 => vec![start],
        _ => {
            let step = (stop - start) / ((count - 1) as f64);
            let mut res = Vec::with_capacity(count);
            res.push(start);
            // interior points only; the endpoint is pushed verbatim below
            res.extend((1..count - 1).map(|i| start + (i as f64) * step));
            res.push(stop);
            res
        }
    }
}
70
/// Generates 2d points (meshgrid)
///
/// # Input
///
/// * `xmin`, `xmax` -- range along x
/// * `ymin`, `ymax` -- range along y
/// * `nx` -- is the number of points along x
/// * `ny` -- is the number of points along y
///
/// Degenerate sizes are accepted. A count of `1` places the single point at the minimum of
/// that range. A count of `0` yields no points along that direction: `ny == 0` gives empty
/// outer vectors, while `nx == 0` gives `ny` empty rows.
///
/// # Output
///
/// * `x`, `y` -- (`ny` by `nx`) 2D arrays with `x[i][j] = xⱼ` and `y[i][j] = yᵢ`; when the
///   corresponding count is at least 2, the last column of `x` is exactly `xmax` and the last
///   row of `y` is exactly `ymax`
pub fn generate2d(xmin: f64, xmax: f64, ymin: f64, ymax: f64, nx: usize, ny: usize) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
    /// Returns `n` evenly spaced coordinates from `min` to `max` (the last one exactly `max`)
    fn coords(min: f64, max: f64, n: usize) -> Vec<f64> {
        match n {
            0 => Vec::new(),
            1 => vec![min],
            _ => {
                let step = (max - min) / ((n - 1) as f64);
                let mut c: Vec<f64> = (0..n - 1).map(|k| min + (k as f64) * step).collect();
                // assign the endpoint directly: `min + (n - 1) * step` may miss `max` by round-off
                c.push(max);
                c
            }
        }
    }

    let xs = coords(xmin, xmax, nx);
    let ys = coords(ymin, ymax, ny);
    // every row of `x` is the same x-axis; every row of `y` repeats one y value `nx` times
    let x = vec![xs; ny];
    let y: Vec<Vec<f64>> = ys.iter().map(|&v| vec![v; nx]).collect();
    (x, y)
}
109
/// Generates 3d points (function over meshgrid)
///
/// # Input
///
/// * `xmin`, `xmax` -- range along x
/// * `ymin`, `ymax` -- range along y
/// * `nx` -- is the number of points along x
/// * `ny` -- is the number of points along y
/// * `calc_z` -- is a function of (xij, yij) that calculates zij
///
/// Degenerate sizes are accepted. A count of `1` places the single point at the minimum of
/// that range. A count of `0` yields no points along that direction: `ny == 0` gives empty
/// outer vectors, while `nx == 0` gives `ny` empty rows and `calc_z` is never called.
///
/// `calc_z` is called exactly once per grid point in row-major order (all `j` for `i = 0`,
/// then all `j` for `i = 1`, …).
///
/// # Output
///
/// * `x`, `y`, `z` -- (`ny` by `nx`) 2D arrays; when the corresponding count is at least 2,
///   the last column of `x` is exactly `xmax` and the last row of `y` is exactly `ymax`
pub fn generate3d<F>(
    xmin: f64,
    xmax: f64,
    ymin: f64,
    ymax: f64,
    nx: usize,
    ny: usize,
    mut calc_z: F,
) -> (Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<f64>>)
where
    F: FnMut(f64, f64) -> f64,
{
    /// Returns `n` evenly spaced coordinates from `min` to `max` (the last one exactly `max`)
    fn coords(min: f64, max: f64, n: usize) -> Vec<f64> {
        match n {
            0 => Vec::new(),
            1 => vec![min],
            _ => {
                let step = (max - min) / ((n - 1) as f64);
                let mut c: Vec<f64> = (0..n - 1).map(|k| min + (k as f64) * step).collect();
                // assign the endpoint directly: `min + (n - 1) * step` may miss `max` by round-off
                c.push(max);
                c
            }
        }
    }

    let xs = coords(xmin, xmax, nx);
    let ys = coords(ymin, ymax, ny);
    let mut x = Vec::with_capacity(ny);
    let mut y = Vec::with_capacity(ny);
    let mut z = Vec::with_capacity(ny);
    // row-major traversal: `calc_z` is FnMut, so the call order is observable by the caller
    for &v in &ys {
        z.push(xs.iter().map(|&u| calc_z(u, v)).collect::<Vec<f64>>());
        x.push(xs.clone());
        y.push(vec![v; nx]);
    }
    (x, y, z)
}
162
163////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
164
#[cfg(test)]
mod tests {
    use super::{generate2d, generate3d, linspace, sign, suq_cos, suq_sin};

    /// Panics with the absolute difference when `|a - b| > tol`
    fn approx_eq(a: f64, b: f64, tol: f64) {
        let diff = f64::abs(a - b);
        if diff > tol {
            panic!("numbers are not approximately equal. diff = {:?}", diff);
        }
    }

    #[test]
    #[should_panic(expected = "numbers are not approximately equal. diff = 1.0")]
    fn approx_eq_captures_errors() {
        approx_eq(1.0, 2.0, 1e-15);
    }

    #[test]
    fn sign_works() {
        let xx = [-2.0, -1.6, -1.2, -0.8, -0.4, 0.0, 0.4, 0.8, 1.2, 1.6, 2.0];
        for x in xx {
            let s = sign(x);
            if x == 0.0 {
                assert_eq!(s, 0.0);
            } else {
                assert_eq!(s, f64::abs(x) / x);
            }
        }
    }

    #[test]
    fn suq_sin_and_cos_work() {
        const PI: f64 = std::f64::consts::PI;
        approx_eq(suq_sin(0.0, 1.0), 0.0, 1e-14);
        approx_eq(suq_sin(PI, 1.0), 0.0, 1e-14);
        approx_eq(suq_sin(PI / 2.0, 0.0), 1.0, 1e-14);
        approx_eq(suq_sin(PI / 2.0, 1.0), 1.0, 1e-14);
        approx_eq(suq_sin(PI / 2.0, 2.0), 1.0, 1e-14);
        approx_eq(suq_sin(PI / 4.0, 2.0), 0.5, 1e-14);
        approx_eq(suq_sin(-PI / 4.0, 2.0), -0.5, 1e-14);

        approx_eq(suq_cos(0.0, 1.0), 1.0, 1e-14);
        approx_eq(suq_cos(PI, 1.0), -1.0, 1e-14);
        approx_eq(suq_cos(PI / 2.0, 0.0), 1.0, 1e-14); // because sign(cos(pi/2))=1
        approx_eq(suq_cos(PI / 2.0, 1.0), 0.0, 1e-14);
        approx_eq(suq_cos(PI / 2.0, 2.0), 0.0, 1e-14);
        approx_eq(suq_cos(PI / 4.0, 2.0), 0.5, 1e-14);
        approx_eq(suq_cos(-PI / 4.0, 2.0), 0.5, 1e-14);
    }

    #[test]
    fn linspace_works() {
        let x = linspace(0.0, 1.0, 11);
        let correct = &[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0];
        // check the length first: zipping alone would silently accept a too-short result
        assert_eq!(x.len(), correct.len());
        for (v, c) in x.iter().zip(correct) {
            approx_eq(*v, *c, 1e-15);
        }

        let x = linspace(2.0, 3.0, 0);
        assert_eq!(x.len(), 0);

        let x = linspace(2.0, 3.0, 1);
        assert_eq!(x.len(), 1);
        assert_eq!(x[0], 2.0);

        let x = linspace(2.0, 3.0, 2);
        assert_eq!(x.len(), 2);
        assert_eq!(x[0], 2.0);
        assert_eq!(x[1], 3.0);

        let x = linspace(0.0, 10.0, 0);
        assert_eq!(x.len(), 0);

        let x = linspace(0.0, 10.0, 1);
        assert_eq!(x, &[0.0]);

        let x = linspace(0.0, 10.0, 2);
        assert_eq!(x, [0.0, 10.0]);

        let x = linspace(0.0, 10.0, 3);
        assert_eq!(x, [0.0, 5.0, 10.0]);
    }

    #[test]
    fn generate2d_edge_cases_work() {
        let (x, y) = generate2d(-1.0, 1.0, -3.0, 3.0, 0, 0);
        assert_eq!(x.len(), 0);
        assert_eq!(y.len(), 0);

        let (x, y) = generate2d(-1.0, 1.0, -3.0, 3.0, 1, 1);
        assert_eq!(x.len(), 1);
        assert_eq!(y.len(), 1);
        assert_eq!(x[0], &[-1.0]);
        assert_eq!(y[0], &[-3.0]);
    }

    #[test]
    fn generate2d_works() {
        let (x, y) = generate2d(-1.0, 1.0, -3.0, 3.0, 0, 2);
        assert_eq!(x.len(), 2);
        assert_eq!(y.len(), 2);
        assert_eq!(x[0].len(), 0);
        assert_eq!(x[1].len(), 0);
        assert_eq!(y[0].len(), 0);
        assert_eq!(y[1].len(), 0);

        let (x, y) = generate2d(-1.0, 1.0, -3.0, 3.0, 2, 0);
        assert_eq!(x.len(), 0);
        assert_eq!(y.len(), 0);

        let (x, y) = generate2d(-1.0, 1.0, -3.0, 3.0, 1, 2);
        assert_eq!(x, &[[-1.0], [-1.0]]);
        assert_eq!(y, &[[-3.0], [3.0]]);

        let (x, y) = generate2d(-1.0, 1.0, -3.0, 3.0, 2, 1);
        assert_eq!(x, &[[-1.0, 1.0]]);
        assert_eq!(y, &[[-3.0, -3.0]]);

        let (x, y) = generate2d(-1.0, 1.0, -3.0, 3.0, 2, 3);
        // -1.0, 1.0,
        // -1.0, 1.0,
        // -1.0, 1.0,
        assert_eq!(x, &[[-1.0, 1.0], [-1.0, 1.0], [-1.0, 1.0]]);
        // -3.0, -3.0,
        //  0.0,  0.0,
        //  3.0,  3.0,
        assert_eq!(y, &[[-3.0, -3.0], [0.0, 0.0], [3.0, 3.0]]);
    }

    fn calc_z(x: f64, y: f64) -> f64 {
        x + y
    }

    #[test]
    fn generate3d_edge_cases_work() {
        let (x, y, z) = generate3d(-1.0, 1.0, -3.0, 3.0, 0, 0, calc_z);
        assert_eq!(x.len(), 0);
        assert_eq!(y.len(), 0);
        assert_eq!(z.len(), 0);

        let (x, y, z) = generate3d(-1.0, 1.0, -3.0, 3.0, 1, 1, calc_z);
        assert_eq!(x.len(), 1);
        assert_eq!(y.len(), 1);
        assert_eq!(z.len(), 1);
        assert_eq!(x[0], &[-1.0]);
        assert_eq!(y[0], &[-3.0]);
        assert_eq!(z[0], &[-4.0]);
    }

    #[test]
    fn generate3d_works() {
        let (x, y, z) = generate3d(-1.0, 1.0, -3.0, 3.0, 0, 2, calc_z);
        assert_eq!(x.len(), 2);
        assert_eq!(y.len(), 2);
        assert_eq!(z.len(), 2);
        assert_eq!(x[0].len(), 0);
        assert_eq!(x[1].len(), 0);
        assert_eq!(y[0].len(), 0);
        assert_eq!(y[1].len(), 0);
        assert_eq!(z[0].len(), 0);
        assert_eq!(z[1].len(), 0);

        let (x, y, z) = generate3d(-1.0, 1.0, -3.0, 3.0, 2, 0, calc_z);
        assert_eq!(x.len(), 0);
        assert_eq!(y.len(), 0);
        assert_eq!(z.len(), 0);

        let (x, y, z) = generate3d(-1.0, 1.0, -3.0, 3.0, 1, 2, calc_z);
        assert_eq!(x.len(), 2);
        assert_eq!(y.len(), 2);
        assert_eq!(z.len(), 2);
        assert_eq!(x, &[[-1.0], [-1.0]]);
        assert_eq!(y, &[[-3.0], [3.0]]);
        assert_eq!(z, &[[-4.0], [2.0]]);

        let (x, y, z) = generate3d(-1.0, 1.0, -3.0, 3.0, 2, 1, calc_z);
        assert_eq!(x.len(), 1);
        assert_eq!(y.len(), 1);
        assert_eq!(z.len(), 1);
        assert_eq!(x, &[[-1.0, 1.0]]);
        assert_eq!(y, &[[-3.0, -3.0]]);
        assert_eq!(z, &[[-4.0, -2.0]]);

        let (x, y, z) = generate3d(-1.0, 1.0, -3.0, 3.0, 2, 3, calc_z);
        // -1.0, 1.0,
        // -1.0, 1.0,
        // -1.0, 1.0,
        assert_eq!(x, &[[-1.0, 1.0], [-1.0, 1.0], [-1.0, 1.0]]);
        // -3.0, -3.0,
        //  0.0,  0.0,
        //  3.0,  3.0,
        assert_eq!(y, &[[-3.0, -3.0], [0.0, 0.0], [3.0, 3.0]]);
        // -4.0, -2.0,
        // -1.0,  1.0,
        //  2.0,  4.0,
        assert_eq!(z, &[[-4.0, -2.0], [-1.0, 1.0], [2.0, 4.0]]);
    }
}