Skip to main content

plotpy/
auxiliary.rs

/// Returns the sign of `x` as `-1.0`, `0.0`, or `1.0`
///
/// ```text
///           │ -1   if x < 0
/// sign(x) = ┤  0   if x = 0
///           │  1   if x > 0
///
///           |x|    x
/// sign(x) = ——— = ———          (x ≠ 0)
///            x    |x|
///
/// sign(x) = 2 · heaviside(x) - 1   (taking heaviside(0) = ½)
/// ```
///
/// Unlike [`f64::signum`], both `0.0` and `-0.0` map to `0.0`. A `NaN` input compares
/// neither below nor above zero, so it also yields `0.0`.
///
/// Reference: <https://en.wikipedia.org/wiki/Sign_function>
pub fn sign(x: f64) -> f64 {
    use std::cmp::Ordering;
    match x.partial_cmp(&0.0) {
        Some(Ordering::Less) => -1.0,
        Some(Ordering::Greater) => 1.0,
        // Signed zeros compare Equal; NaN is unordered (None). Both collapse to 0.0.
        Some(Ordering::Equal) | None => 0.0,
    }
}
25
26/// Implements the superquadric auxiliary function involving sin(x)
27///
28/// ```text
29/// suq_sin(x;k) = sign(sin(x)) · |sin(x)|ᵏ
30/// ```
31///
32/// This is the angular shaping function for [superquadrics](https://en.wikipedia.org/wiki/Superquadrics).
33/// When `k = 2` the result is the standard parametric form of a sphere/ellipsoid.
34/// Values `k < 1` produce a "pinched" shape; `k > 2` produces a "squared" shape.
35///
36/// `suq_sin(x;k)` corresponds to `f(ω;m)` in the superquadric literature.
37///
38/// See also: [`suq_cos`]
39pub fn suq_sin(x: f64, k: f64) -> f64 {
40    sign(f64::sin(x)) * f64::powf(f64::abs(f64::sin(x)), k)
41}
42
43/// Implements the superquadric auxiliary function involving cos(x)
44///
45/// ```text
46/// suq_cos(x;k) = sign(cos(x)) · |cos(x)|ᵏ
47/// ```
48///
49/// This is the angular shaping function for [superquadrics](https://en.wikipedia.org/wiki/Superquadrics).
50/// When `k = 2` the result is the standard parametric form of a sphere/ellipsoid.
51/// Values `k < 1` produce a "pinched" shape; `k > 2` produces a "squared" shape.
52///
53/// `suq_cos(x;k)` corresponds to `g(ω;m)` in the superquadric literature.
54///
55/// See also: [`suq_sin`]
56pub fn suq_cos(x: f64, k: f64) -> f64 {
57    sign(f64::cos(x)) * f64::powf(f64::abs(f64::cos(x)), k)
58}
59
/// Returns evenly spaced numbers over a specified closed interval
///
/// Analogous to [numpy.linspace](https://numpy.org/doc/stable/reference/generated/numpy.linspace.html).
/// Both `start` and `stop` are included in the output (closed interval).
///
/// The first and last entries are exactly `start` and `stop`. Only the interior points are
/// computed, as `start + i · step` with `step = (stop - start) / (count - 1)`.
///
/// # Examples
///
/// ```
/// use plotpy::linspace;
///
/// assert_eq!(linspace(0.0, 1.0, 3), vec![0.0, 0.5, 1.0]);
/// assert_eq!(linspace(0.0, 1.0, 5), vec![0.0, 0.25, 0.5, 0.75, 1.0]);
/// ```
///
/// # Edge cases
///
/// - `count == 0` returns an empty vector
/// - `count == 1` returns `[start]`
/// - `count == 2` returns `[start, stop]`
pub fn linspace(start: f64, stop: f64, count: usize) -> Vec<f64> {
    match count {
        0 => Vec::new(),
        1 => vec![start],
        _ => {
            let step = (stop - start) / ((count - 1) as f64);
            let mut res = Vec::with_capacity(count);
            res.push(start);
            res.extend((1..count - 1).map(|i| start + (i as f64) * step));
            // Store `stop` verbatim. Recomputing it as `start + (count - 1) · step` can miss
            // it by a rounding error (e.g. 49.0 * (1.0 / 49.0) != 1.0).
            res.push(stop);
            res
        }
    }
}

/// Generates 2d meshgrid points
///
/// This is analogous to [numpy.meshgrid](https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html)
/// with its default Cartesian indexing (`indexing='xy'`), applied to `linspace(xmin, xmax, nx)`
/// and `linspace(ymin, ymax, ny)`. It produces two (`ny` × `nx`) matrices. Each row has the
/// same `y` value and each column has the same `x` value.
///
/// # Input
///
/// * `xmin`, `xmax` -- range along x
/// * `ymin`, `ymax` -- range along y
/// * `nx` -- number of points along x
/// * `ny` -- number of points along y
///
/// A count of `1` places the single point at `xmin` (respectively `ymin`).
/// `nx == 0` yields `ny` empty rows, and `ny == 0` yields no rows at all.
///
/// # Output
///
/// * `x`, `y` -- (`ny` by `nx`) 2D arrays such that `x[i][j] = xmin + j·dx` and
///   `y[i][j] = ymin + i·dy`. The last column and last row land exactly on `xmax` and
///   `ymax` (see [`linspace`]).
///
/// # Example
///
/// ```
/// use plotpy::generate2d;
///
/// let (x, y) = generate2d(-1.0, 1.0, -3.0, 3.0, 2, 3);
/// assert_eq!(x, vec![vec![-1.0, 1.0], vec![-1.0, 1.0], vec![-1.0, 1.0]]);
/// assert_eq!(y, vec![vec![-3.0, -3.0], vec![0.0, 0.0], vec![3.0, 3.0]]);
/// ```
///
/// See also: [`generate3d`]
pub fn generate2d(xmin: f64, xmax: f64, ymin: f64, ymax: f64, nx: usize, ny: usize) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
    let xs = linspace(xmin, xmax, nx);
    let ys = linspace(ymin, ymax, ny);
    // Every row of `x` is the same x-axis vector.
    let x = vec![xs; ny];
    // Row `i` of `y` repeats the i-th y coordinate across all `nx` columns.
    let y: Vec<Vec<f64>> = ys.iter().map(|&v| vec![v; nx]).collect();
    (x, y)
}

/// Generates 3d points by evaluating a function over a 2d meshgrid
///
/// Creates the same `(x, y)` grid as [`generate2d`], then evaluates `calc_z(x, y)` at each
/// grid point to produce the `z` matrix. This is the typical input for [`Surface::draw`](crate::Surface::draw)
/// and [`Contour::draw`](crate::Contour::draw).
///
/// # Input
///
/// * `xmin`, `xmax` -- range along x
/// * `ymin`, `ymax` -- range along y
/// * `nx` -- number of points along x (`0` and `1` are handled as in [`generate2d`])
/// * `ny` -- number of points along y (`0` and `1` are handled as in [`generate2d`])
/// * `calc_z` -- function `f(x, y)` that returns `z` at each grid point. It is called once per
///   point, row by row (increasing `y`), and left to right (increasing `x`) within a row.
///
/// # Output
///
/// * `x`, `y`, `z` -- (`ny` by `nx`) 2D arrays
///
/// # Example
///
/// ```
/// use plotpy::generate3d;
///
/// let (x, y, z) = generate3d(-1.0, 1.0, -1.0, 1.0, 3, 3, |x, y| x * x + y * y);
/// assert_eq!(z, vec![
///     vec![2.0, 1.0, 2.0],
///     vec![1.0, 0.0, 1.0],
///     vec![2.0, 1.0, 2.0],
/// ]);
/// ```
///
/// See also: [`generate2d`]
pub fn generate3d<F>(
    xmin: f64,
    xmax: f64,
    ymin: f64,
    ymax: f64,
    nx: usize,
    ny: usize,
    mut calc_z: F,
) -> (Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<f64>>)
where
    F: FnMut(f64, f64) -> f64,
{
    let (x, y) = generate2d(xmin, xmax, ymin, ymax, nx, ny);
    // Row-major evaluation keeps the same call order as a nested `for i { for j { .. } }`
    // loop, which matters for a stateful `calc_z`.
    let z: Vec<Vec<f64>> = x
        .iter()
        .zip(&y)
        .map(|(xrow, yrow)| xrow.iter().zip(yrow).map(|(&u, &v)| calc_z(u, v)).collect())
        .collect();
    (x, y, z)
}
227
228////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
229
#[cfg(test)]
mod tests {
    use super::{generate2d, generate3d, linspace, sign, suq_cos, suq_sin};

    /// Panics unless `a` and `b` differ by at most `tol`.
    fn approx_eq(a: f64, b: f64, tol: f64) {
        let diff = f64::abs(a - b);
        if diff > tol {
            panic!("numbers are not approximately equal. diff = {:?}", diff);
        }
    }

    #[test]
    #[should_panic(expected = "numbers are not approximately equal. diff = 1.0")]
    fn approx_eq_captures_errors() {
        approx_eq(1.0, 2.0, 1e-15);
    }

    #[test]
    fn sign_works() {
        let xx = [-2.0, -1.6, -1.2, -0.8, -0.4, 0.0, 0.4, 0.8, 1.2, 1.6, 2.0];
        for x in xx {
            let s = sign(x);
            if x == 0.0 {
                assert_eq!(s, 0.0);
            } else {
                assert_eq!(s, f64::abs(x) / x);
            }
        }
    }

    #[test]
    fn suq_sin_and_cos_work() {
        const PI: f64 = std::f64::consts::PI;
        approx_eq(suq_sin(0.0, 1.0), 0.0, 1e-14);
        approx_eq(suq_sin(PI, 1.0), 0.0, 1e-14);
        approx_eq(suq_sin(PI / 2.0, 0.0), 1.0, 1e-14);
        approx_eq(suq_sin(PI / 2.0, 1.0), 1.0, 1e-14);
        approx_eq(suq_sin(PI / 2.0, 2.0), 1.0, 1e-14);
        approx_eq(suq_sin(PI / 4.0, 2.0), 0.5, 1e-14);
        approx_eq(suq_sin(-PI / 4.0, 2.0), -0.5, 1e-14);

        approx_eq(suq_cos(0.0, 1.0), 1.0, 1e-14);
        approx_eq(suq_cos(PI, 1.0), -1.0, 1e-14);
        approx_eq(suq_cos(PI / 2.0, 0.0), 1.0, 1e-14); // because sign(cos(pi/2))=1
        approx_eq(suq_cos(PI / 2.0, 1.0), 0.0, 1e-14);
        approx_eq(suq_cos(PI / 2.0, 2.0), 0.0, 1e-14);
        approx_eq(suq_cos(PI / 4.0, 2.0), 0.5, 1e-14);
        approx_eq(suq_cos(-PI / 4.0, 2.0), 0.5, 1e-14);
    }

    #[test]
    fn linspace_works() {
        let x = linspace(0.0, 1.0, 11);
        let correct = &[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0];
        // Check the length first: zipping alone would silently accept a short (or empty) result.
        assert_eq!(x.len(), correct.len());
        for (v, c) in x.iter().zip(correct) {
            approx_eq(*v, *c, 1e-15);
        }

        let x = linspace(2.0, 3.0, 0);
        assert_eq!(x.len(), 0);

        let x = linspace(2.0, 3.0, 1);
        assert_eq!(x.len(), 1);
        assert_eq!(x[0], 2.0);

        let x = linspace(2.0, 3.0, 2);
        assert_eq!(x.len(), 2);
        assert_eq!(x[0], 2.0);
        assert_eq!(x[1], 3.0);

        let x = linspace(0.0, 10.0, 0);
        assert_eq!(x.len(), 0);

        let x = linspace(0.0, 10.0, 1);
        assert_eq!(x, &[0.0]);

        let x = linspace(0.0, 10.0, 2);
        assert_eq!(x, [0.0, 10.0]);

        let x = linspace(0.0, 10.0, 3);
        assert_eq!(x, [0.0, 5.0, 10.0]);
    }

    #[test]
    fn generate2d_edge_cases_work() {
        let (x, y) = generate2d(-1.0, 1.0, -3.0, 3.0, 0, 0);
        assert_eq!(x.len(), 0);
        assert_eq!(y.len(), 0);

        let (x, y) = generate2d(-1.0, 1.0, -3.0, 3.0, 1, 1);
        assert_eq!(x.len(), 1);
        assert_eq!(y.len(), 1);
        assert_eq!(x[0], &[-1.0]);
        assert_eq!(y[0], &[-3.0]);
    }

    #[test]
    fn generate2d_works() {
        let (x, y) = generate2d(-1.0, 1.0, -3.0, 3.0, 0, 2);
        assert_eq!(x.len(), 2);
        assert_eq!(y.len(), 2);
        assert_eq!(x[0].len(), 0);
        assert_eq!(x[1].len(), 0);
        assert_eq!(y[0].len(), 0);
        assert_eq!(y[1].len(), 0);

        let (x, y) = generate2d(-1.0, 1.0, -3.0, 3.0, 2, 0);
        assert_eq!(x.len(), 0);
        assert_eq!(y.len(), 0);

        let (x, y) = generate2d(-1.0, 1.0, -3.0, 3.0, 1, 2);
        assert_eq!(x, &[[-1.0], [-1.0]]);
        assert_eq!(y, &[[-3.0], [3.0]]);

        let (x, y) = generate2d(-1.0, 1.0, -3.0, 3.0, 2, 1);
        assert_eq!(x, &[[-1.0, 1.0]]);
        assert_eq!(y, &[[-3.0, -3.0]]);

        let (x, y) = generate2d(-1.0, 1.0, -3.0, 3.0, 2, 3);
        // -1.0, 1.0,
        // -1.0, 1.0,
        // -1.0, 1.0,
        assert_eq!(x, &[[-1.0, 1.0], [-1.0, 1.0], [-1.0, 1.0]]);
        // -3.0, -3.0,
        //  0.0,  0.0,
        //  3.0,  3.0,
        assert_eq!(y, &[[-3.0, -3.0], [0.0, 0.0], [3.0, 3.0]]);
    }

    /// Simple plane `z = x + y` used by the `generate3d` tests.
    fn calc_z(x: f64, y: f64) -> f64 {
        x + y
    }

    #[test]
    fn generate3d_edge_cases_work() {
        let (x, y, z) = generate3d(-1.0, 1.0, -3.0, 3.0, 0, 0, calc_z);
        assert_eq!(x.len(), 0);
        assert_eq!(y.len(), 0);
        assert_eq!(z.len(), 0);

        let (x, y, z) = generate3d(-1.0, 1.0, -3.0, 3.0, 1, 1, calc_z);
        assert_eq!(x.len(), 1);
        assert_eq!(y.len(), 1);
        assert_eq!(z.len(), 1);
        assert_eq!(x[0], &[-1.0]);
        assert_eq!(y[0], &[-3.0]);
        assert_eq!(z[0], &[-4.0]);
    }

    #[test]
    fn generate3d_works() {
        let (x, y, z) = generate3d(-1.0, 1.0, -3.0, 3.0, 0, 2, calc_z);
        assert_eq!(x.len(), 2);
        assert_eq!(y.len(), 2);
        assert_eq!(z.len(), 2);
        assert_eq!(x[0].len(), 0);
        assert_eq!(x[1].len(), 0);
        assert_eq!(y[0].len(), 0);
        assert_eq!(y[1].len(), 0);
        assert_eq!(z[0].len(), 0);
        assert_eq!(z[1].len(), 0);

        let (x, y, z) = generate3d(-1.0, 1.0, -3.0, 3.0, 2, 0, calc_z);
        assert_eq!(x.len(), 0);
        assert_eq!(y.len(), 0);
        assert_eq!(z.len(), 0);

        let (x, y, z) = generate3d(-1.0, 1.0, -3.0, 3.0, 1, 2, calc_z);
        assert_eq!(x.len(), 2);
        assert_eq!(y.len(), 2);
        assert_eq!(z.len(), 2);
        assert_eq!(x, &[[-1.0], [-1.0]]);
        assert_eq!(y, &[[-3.0], [3.0]]);
        assert_eq!(z, &[[-4.0], [2.0]]);

        let (x, y, z) = generate3d(-1.0, 1.0, -3.0, 3.0, 2, 1, calc_z);
        assert_eq!(x.len(), 1);
        assert_eq!(y.len(), 1);
        assert_eq!(z.len(), 1);
        assert_eq!(x, &[[-1.0, 1.0]]);
        assert_eq!(y, &[[-3.0, -3.0]]);
        assert_eq!(z, &[[-4.0, -2.0]]);

        let (x, y, z) = generate3d(-1.0, 1.0, -3.0, 3.0, 2, 3, calc_z);
        // -1.0, 1.0,
        // -1.0, 1.0,
        // -1.0, 1.0,
        assert_eq!(x, &[[-1.0, 1.0], [-1.0, 1.0], [-1.0, 1.0]]);
        // -3.0, -3.0,
        //  0.0,  0.0,
        //  3.0,  3.0,
        assert_eq!(y, &[[-3.0, -3.0], [0.0, 0.0], [3.0, 3.0]]);
        // -4.0, -2.0,
        // -1.0,  1.0,
        //  2.0,  4.0,
        assert_eq!(z, &[[-4.0, -2.0], [-1.0, 1.0], [2.0, 4.0]]);
    }
}