Skip to main content

ggplot_rs/stat/
contour_filled.rs

1use crate::aes::Aesthetic;
2use crate::data::{DataFrame, Value};
3use crate::scale::ScaleSet;
4
5use super::contour::build_grid_from_xyz;
6use super::Stat;
7
/// Filled contour bands from gridded (x, y, z) data (R's `stat_contour_filled` /
/// `geom_contour_filled`).
///
/// The domain is resampled onto a grid; each cell is split into two triangles,
/// and each triangle is clipped to every level band `[lo, hi]` (z is linear on a
/// triangle, so the band region is convex). Emits filled polygons with one
/// `group` per polygon and `fill` = the band midpoint — pair with a continuous
/// fill scale via `geom_contour_filled()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StatContourFilled {
    /// Number of grid cells along each axis of the resampling grid; the same
    /// count is used for x and y. Values below 1 are treated as 1.
    pub bins: usize,
    /// Number of equal-width level bands spanning the observed z range.
    /// Values below 1 are treated as 1.
    pub n_bands: usize,
}
20
21impl Default for StatContourFilled {
22    fn default() -> Self {
23        StatContourFilled {
24            bins: 24,
25            n_bands: 8,
26        }
27    }
28}
29
/// A polygon vertex as `(x, y, z)`.
type Vtx = (f64, f64, f64);

/// Sutherland–Hodgman clip of a polygon against the iso-level `thr`.
///
/// Keeps the side where `z >= thr` when `keep_ge` is true, or `z <= thr` when
/// it is false; vertices lying exactly on the level count as inside for both
/// orientations. Every edge that crosses the level contributes one new vertex,
/// interpolated linearly along the edge, whose `z` is set to exactly `thr`.
///
/// Returns the kept vertices in the input's traversal order. The result is
/// empty when `poly` is empty or lies entirely on the discarded side.
fn clip(poly: &[Vtx], thr: f64, keep_ge: bool) -> Vec<Vtx> {
    let inside = |z: f64| if keep_ge { z >= thr } else { z <= thr };
    let mut out = Vec::with_capacity(poly.len() + 1);
    for (i, &cur) in poly.iter().enumerate() {
        let nxt = poly[(i + 1) % poly.len()];
        let cur_in = inside(cur.2);
        if cur_in {
            out.push(cur);
        }
        if cur_in != inside(nxt.2) {
            // The endpoints sit on opposite sides of `thr`, so `dz` is non-zero
            // and `t` falls in [0, 1] regardless of the magnitude of z. An
            // absolute epsilon here would collapse every crossing onto `cur`
            // for data whose z range is tiny, so only an exact zero is guarded.
            let dz = nxt.2 - cur.2;
            let t = if dz == 0.0 { 0.0 } else { (thr - cur.2) / dz };
            out.push((
                cur.0 + (nxt.0 - cur.0) * t,
                cur.1 + (nxt.1 - cur.1) * t,
                thr,
            ));
        }
    }
    out
}
66
67impl Stat for StatContourFilled {
68    fn compute_group(&self, data: &DataFrame, _scales: &ScaleSet) -> DataFrame {
69        let (x_col, y_col, z_col) = match (data.column("x"), data.column("y"), data.column("z")) {
70            (Some(x), Some(y), Some(z)) => (x, y, z),
71            _ => return DataFrame::new(),
72        };
73        let points: Vec<Vtx> = x_col
74            .iter()
75            .zip(y_col.iter())
76            .zip(z_col.iter())
77            .filter_map(|((x, y), z)| {
78                let (xv, yv, zv) = (x.as_f64()?, y.as_f64()?, z.as_f64()?);
79                (xv.is_finite() && yv.is_finite() && zv.is_finite()).then_some((xv, yv, zv))
80            })
81            .collect();
82        if points.is_empty() {
83            return DataFrame::new();
84        }
85
86        let x_min = points.iter().map(|p| p.0).fold(f64::INFINITY, f64::min);
87        let x_max = points.iter().map(|p| p.0).fold(f64::NEG_INFINITY, f64::max);
88        let y_min = points.iter().map(|p| p.1).fold(f64::INFINITY, f64::min);
89        let y_max = points.iter().map(|p| p.1).fold(f64::NEG_INFINITY, f64::max);
90        let z_min = points.iter().map(|p| p.2).fold(f64::INFINITY, f64::min);
91        let z_max = points.iter().map(|p| p.2).fold(f64::NEG_INFINITY, f64::max);
92        if (x_max - x_min).abs() < f64::EPSILON
93            || (y_max - y_min).abs() < f64::EPSILON
94            || (z_max - z_min).abs() < f64::EPSILON
95        {
96            return DataFrame::new();
97        }
98
99        let nx = self.bins.max(1);
100        let ny = self.bins.max(1);
101        let dx = (x_max - x_min) / nx as f64;
102        let dy = (y_max - y_min) / ny as f64;
103        let grid = build_grid_from_xyz(&points, nx, ny, x_min, y_min, dx, dy, z_min, z_max);
104
105        let n_bands = self.n_bands.max(1);
106        let step = (z_max - z_min) / n_bands as f64;
107        let levels: Vec<f64> = (0..=n_bands).map(|k| z_min + k as f64 * step).collect();
108
109        let mut x_vals = Vec::new();
110        let mut y_vals = Vec::new();
111        let mut group_vals = Vec::new();
112        let mut fill_vals = Vec::new();
113        let mut gid: u64 = 0;
114        let at = |ix: usize, iy: usize| grid[iy * (nx + 1) + ix];
115
116        for iy in 0..ny {
117            for ix in 0..nx {
118                let x0 = x_min + ix as f64 * dx;
119                let x1 = x0 + dx;
120                let y0 = y_min + iy as f64 * dy;
121                let y1 = y0 + dy;
122                let p00 = (x0, y0, at(ix, iy));
123                let p10 = (x1, y0, at(ix + 1, iy));
124                let p01 = (x0, y1, at(ix, iy + 1));
125                let p11 = (x1, y1, at(ix + 1, iy + 1));
126
127                for tri in [[p00, p10, p11], [p00, p11, p01]] {
128                    for k in 0..n_bands {
129                        let (lo, hi) = (levels[k], levels[k + 1]);
130                        let band = clip(&clip(&tri, lo, true), hi, false);
131                        if band.len() >= 3 {
132                            let mid = 0.5 * (lo + hi);
133                            for v in &band {
134                                x_vals.push(Value::Float(v.0));
135                                y_vals.push(Value::Float(v.1));
136                                group_vals.push(Value::Str(format!("b{gid}")));
137                                fill_vals.push(Value::Float(mid));
138                            }
139                            gid += 1;
140                        }
141                    }
142                }
143            }
144        }
145
146        let mut result = DataFrame::new();
147        result.add_column("x".to_string(), x_vals);
148        result.add_column("y".to_string(), y_vals);
149        result.add_column("group".to_string(), group_vals);
150        result.add_column("fill".to_string(), fill_vals);
151        result
152    }
153
154    fn required_aes(&self) -> Vec<Aesthetic> {
155        vec![Aesthetic::X, Aesthetic::Y]
156    }
157
158    fn name(&self) -> &str {
159        "contour_filled"
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Samples the radial cone `z = -sqrt(x^2 + y^2)` on a 20×20 integer
    /// lattice covering `-10..=9` on both axes, which yields nested bands.
    fn cone_grid() -> DataFrame {
        let mut xs = Vec::new();
        let mut ys = Vec::new();
        let mut zs = Vec::new();
        for (i, j) in (0..20).flat_map(|i| (0..20).map(move |j| (i, j))) {
            let x = f64::from(i) - 10.0;
            let y = f64::from(j) - 10.0;
            xs.push(Value::Float(x));
            ys.push(Value::Float(y));
            zs.push(Value::Float(-(x * x + y * y).sqrt()));
        }

        let mut df = DataFrame::new();
        df.add_column(String::from("x"), xs);
        df.add_column(String::from("y"), ys);
        df.add_column(String::from("z"), zs);
        df
    }

    #[test]
    fn produces_filled_bands() {
        let stat = StatContourFilled {
            bins: 16,
            n_bands: 5,
        };
        let out = stat.compute_group(&cone_grid(), &ScaleSet::new());

        assert!(out.nrows() > 0);
        assert!(out.has_column("group") && out.has_column("fill"));

        // Several distinct band fill levels should be present; the values are
        // compared through their `Debug` rendering.
        let fill_column = out.column("fill").unwrap();
        let distinct: HashSet<String> = fill_column.iter().map(|v| format!("{v:?}")).collect();
        assert!(
            distinct.len() >= 3,
            "expected multiple bands, got {}",
            distinct.len()
        );
    }

    #[test]
    fn degenerate_z_returns_empty() {
        // Two points sharing the same z: the z extent is zero, so no bands exist.
        let column = |a: f64, b: f64| vec![Value::Float(a), Value::Float(b)];
        let mut df = DataFrame::new();
        df.add_column(String::from("x"), column(0.0, 1.0));
        df.add_column(String::from("y"), column(0.0, 1.0));
        df.add_column(String::from("z"), column(5.0, 5.0));

        let out = StatContourFilled::default().compute_group(&df, &ScaleSet::new());
        assert_eq!(out.nrows(), 0);
    }
}