Skip to main content

ggplot_rs/stat/
summary2d.rs

1use crate::aes::Aesthetic;
2use crate::data::{DataFrame, Value};
3use crate::scale::ScaleSet;
4
5use super::summary::SummaryFun;
6use super::Stat;
7
8/// 2-D binned summary (analogous to R's `stat_summary_2d`): bins x/y into a grid
9/// and applies a summary function to the `z` values in each cell, emitting
10/// `xmin/xmax/ymin/ymax/fill` (fill = the per-cell summary) for `geom_bin2d`.
11pub struct StatSummary2d {
12    pub bins_x: usize,
13    pub bins_y: usize,
14    pub fun: SummaryFun,
15}
16
17impl Default for StatSummary2d {
18    fn default() -> Self {
19        StatSummary2d {
20            bins_x: 30,
21            bins_y: 30,
22            fun: SummaryFun::Mean,
23        }
24    }
25}
26
27impl StatSummary2d {
28    pub fn new(fun: SummaryFun) -> Self {
29        StatSummary2d {
30            fun,
31            ..Default::default()
32        }
33    }
34
35    pub fn with_bins(mut self, bins_x: usize, bins_y: usize) -> Self {
36        self.bins_x = bins_x.max(1);
37        self.bins_y = bins_y.max(1);
38        self
39    }
40}
41
42impl Stat for StatSummary2d {
43    fn compute_group(&self, data: &DataFrame, _scales: &ScaleSet) -> DataFrame {
44        let (x_col, y_col, z_col) = match (data.column("x"), data.column("y"), data.column("z")) {
45            (Some(x), Some(y), Some(z)) => (x, y, z),
46            _ => return DataFrame::new(),
47        };
48        let rows: Vec<(f64, f64, f64)> = x_col
49            .iter()
50            .zip(y_col.iter())
51            .zip(z_col.iter())
52            .filter_map(|((x, y), z)| Some((x.as_f64()?, y.as_f64()?, z.as_f64()?)))
53            .collect();
54        if rows.is_empty() {
55            return DataFrame::new();
56        }
57
58        let x_min = rows.iter().map(|r| r.0).fold(f64::INFINITY, f64::min);
59        let x_max = rows.iter().map(|r| r.0).fold(f64::NEG_INFINITY, f64::max);
60        let y_min = rows.iter().map(|r| r.1).fold(f64::INFINITY, f64::min);
61        let y_max = rows.iter().map(|r| r.1).fold(f64::NEG_INFINITY, f64::max);
62        let (x_min, x_max) = if (x_max - x_min).abs() < f64::EPSILON {
63            (x_min - 0.5, x_max + 0.5)
64        } else {
65            (x_min, x_max)
66        };
67        let (y_min, y_max) = if (y_max - y_min).abs() < f64::EPSILON {
68            (y_min - 0.5, y_max + 0.5)
69        } else {
70            (y_min, y_max)
71        };
72        let bw_x = (x_max - x_min) / self.bins_x as f64;
73        let bw_y = (y_max - y_min) / self.bins_y as f64;
74
75        let mut cells: Vec<Vec<Vec<f64>>> = vec![vec![Vec::new(); self.bins_y]; self.bins_x];
76        for &(x, y, z) in &rows {
77            let bx = (((x - x_min) / bw_x).floor() as usize).min(self.bins_x - 1);
78            let by = (((y - y_min) / bw_y).floor() as usize).min(self.bins_y - 1);
79            cells[bx][by].push(z);
80        }
81
82        let mut xmin_vals = Vec::new();
83        let mut xmax_vals = Vec::new();
84        let mut ymin_vals = Vec::new();
85        let mut ymax_vals = Vec::new();
86        let mut fill_vals = Vec::new();
87        for (bx, col) in cells.iter().enumerate() {
88            for (by, zs) in col.iter().enumerate() {
89                if zs.is_empty() {
90                    continue;
91                }
92                let cell_xmin = x_min + bx as f64 * bw_x;
93                let cell_ymin = y_min + by as f64 * bw_y;
94                xmin_vals.push(Value::Float(cell_xmin));
95                xmax_vals.push(Value::Float(cell_xmin + bw_x));
96                ymin_vals.push(Value::Float(cell_ymin));
97                ymax_vals.push(Value::Float(cell_ymin + bw_y));
98                fill_vals.push(Value::Float(self.fun.apply(zs)));
99            }
100        }
101
102        let mut result = DataFrame::new();
103        result.add_column("xmin".to_string(), xmin_vals);
104        result.add_column("xmax".to_string(), xmax_vals);
105        result.add_column("ymin".to_string(), ymin_vals);
106        result.add_column("ymax".to_string(), ymax_vals);
107        result.add_column("fill".to_string(), fill_vals);
108        result
109    }
110
111    fn required_aes(&self) -> Vec<Aesthetic> {
112        vec![Aesthetic::X, Aesthetic::Y]
113    }
114
115    fn name(&self) -> &str {
116        "summary_2d"
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps plain floats as a column of `Value::Float`.
    fn float_column(values: &[f64]) -> Vec<Value> {
        values.iter().copied().map(Value::Float).collect()
    }

    #[test]
    fn summarises_z_per_cell() {
        // Two clusters in opposite corners of a 2×2 grid: z = 10 near the
        // origin and z = 20 near (9, 9), so exactly two cells are occupied.
        let mut frame = DataFrame::new();
        frame.add_column("x".into(), float_column(&[0.0, 0.1, 0.2, 9.0, 9.1, 9.2]));
        frame.add_column("y".into(), float_column(&[0.0, 0.1, 0.0, 9.0, 9.1, 9.0]));
        frame.add_column("z".into(), float_column(&[10.0, 10.0, 10.0, 20.0, 20.0, 20.0]));

        let stat = StatSummary2d::new(SummaryFun::Mean).with_bins(2, 2);
        let summary = stat.compute_group(&frame, &ScaleSet::new());

        let mut fills = Vec::new();
        for value in summary.column("fill").unwrap().iter() {
            if let Some(fill) = value.as_f64() {
                fills.push(fill);
            }
        }
        assert_eq!(fills.len(), 2);
        assert!(fills.contains(&10.0) && fills.contains(&20.0), "{fills:?}");
    }

    #[test]
    fn missing_z_returns_empty() {
        // Only x and y are supplied; without z there is nothing to summarise.
        let mut frame = DataFrame::new();
        frame.add_column("x".into(), vec![Value::Float(1.0)]);
        frame.add_column("y".into(), vec![Value::Float(1.0)]);

        let summary = StatSummary2d::default().compute_group(&frame, &ScaleSet::new());
        assert_eq!(summary.nrows(), 0);
    }
}
156}