Skip to main content

ggplot_rs/stat/
ydensity.rs

1use crate::aes::Aesthetic;
2use crate::data::{DataFrame, Value};
3use crate::scale::ScaleSet;
4
5use super::Stat;
6
/// Kernel density estimation on Y per group (for violin plots).
///
/// Outputs: x (group value), y (eval points), violinwidth (density normalized to
/// [0, 1]). The geom mirrors `violinwidth` around the group's x slot.
#[derive(Debug, Clone)]
pub struct StatYDensity {
    /// Number of evenly spaced y positions at which the density is evaluated;
    /// this is also the number of rows produced for each group.
    pub n_points: usize,
}
13
14impl Default for StatYDensity {
15    fn default() -> Self {
16        StatYDensity { n_points: 128 }
17    }
18}
19
20impl Stat for StatYDensity {
21    fn compute_group(&self, data: &DataFrame, _scales: &ScaleSet) -> DataFrame {
22        let x_col = data.column("x");
23        let y_col = match data.column("y") {
24            Some(c) => c,
25            None => return DataFrame::new(),
26        };
27
28        let values: Vec<f64> = y_col.iter().filter_map(|v| v.as_f64()).collect();
29        if values.len() < 2 {
30            return DataFrame::new();
31        }
32
33        // Keep the group's x *value* as-is (e.g. the discrete label "A"). The geom
34        // maps it through the X scale, exactly like boxplot — converting to f64 here
35        // would collapse every discrete group to 0.0.
36        let group_x = x_col
37            .and_then(|c| c.first())
38            .cloned()
39            .unwrap_or(Value::Float(0.0));
40
41        let n = values.len() as f64;
42        let mean = values.iter().sum::<f64>() / n;
43        let var = values.iter().map(|y| (y - mean).powi(2)).sum::<f64>() / (n - 1.0);
44        let sd = var.sqrt();
45
46        // Silverman's rule of thumb
47        let iqr_val = iqr(&values);
48        let bandwidth = 0.9 * sd.min(iqr_val / 1.34) * n.powf(-0.2);
49        let bandwidth = if bandwidth > 0.0 { bandwidth } else { sd * 0.5 };
50
51        let y_min = values.iter().cloned().fold(f64::INFINITY, f64::min) - 3.0 * bandwidth;
52        let y_max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max) + 3.0 * bandwidth;
53        let step = (y_max - y_min) / (self.n_points - 1) as f64;
54
55        let mut x_vals = Vec::with_capacity(self.n_points);
56        let mut y_vals = Vec::with_capacity(self.n_points);
57
58        // Compute density at each evaluation point
59        let mut densities = Vec::with_capacity(self.n_points);
60        let mut max_density: f64 = 0.0;
61        for i in 0..self.n_points {
62            let y = y_min + i as f64 * step;
63            let density: f64 = values
64                .iter()
65                .map(|yi| gaussian_kernel((y - yi) / bandwidth))
66                .sum::<f64>()
67                / (n * bandwidth);
68            densities.push((y, density));
69            if density > max_density {
70                max_density = density;
71            }
72        }
73
74        // Normalize density to [0, 1] (peak = 1). The geom scales this by the
75        // per-group slot half-width, so the widest point fills the group's slot.
76        let scale = if max_density > 0.0 {
77            1.0 / max_density
78        } else {
79            1.0
80        };
81
82        let mut width_vals = Vec::with_capacity(self.n_points);
83        for (y, density) in &densities {
84            x_vals.push(group_x.clone());
85            y_vals.push(Value::Float(*y));
86            width_vals.push(Value::Float(density * scale));
87        }
88
89        let mut result = DataFrame::new();
90        result.add_column("x".to_string(), x_vals);
91        result.add_column("y".to_string(), y_vals);
92        result.add_column("violinwidth".to_string(), width_vals);
93
94        // Carry over grouping columns
95        for col_name in &["color", "fill", "group"] {
96            if let Some(col) = data.column(col_name) {
97                if let Some(first) = col.first() {
98                    result.add_column(col_name.to_string(), vec![first.clone(); self.n_points]);
99                }
100            }
101        }
102
103        result
104    }
105
106    fn required_aes(&self) -> Vec<Aesthetic> {
107        vec![Aesthetic::X, Aesthetic::Y]
108    }
109
110    fn name(&self) -> &str {
111        "ydensity"
112    }
113}
114
/// Standard normal probability density evaluated at `x`:
/// `exp(-x² / 2) / sqrt(2π)`.
fn gaussian_kernel(x: f64) -> f64 {
    use std::f64::consts::PI;

    let normalizer = (2.0 * PI).sqrt();
    let exponent = -(x * x) / 2.0;
    exponent.exp() / normalizer
}
118
119fn iqr(values: &[f64]) -> f64 {
120    let mut sorted = values.to_vec();
121    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
122    quantile_type7(&sorted, 0.75) - quantile_type7(&sorted, 0.25)
123}
124
/// R-compatible type-7 quantile interpolation (R's default `quantile()` method).
///
/// `sorted` must already be in ascending order. `p` is the probability and is
/// clamped to `[0, 1]`, so an out-of-range request returns the minimum or the
/// maximum instead of extrapolating or indexing past the end of the slice.
///
/// Returns `0.0` for an empty slice and the sole element for a one-element
/// slice. A NaN `p` yields NaN for slices of two or more elements.
fn quantile_type7(sorted: &[f64], p: f64) -> f64 {
    match sorted {
        [] => 0.0,
        [only] => *only,
        _ => {
            let last = sorted.len() - 1;
            // Fractional index into the sorted data: (n - 1) * p.
            let h = last as f64 * p.clamp(0.0, 1.0);
            // The extra `min` keeps the index in bounds even if rounding were
            // to push `h` marginally above `last`.
            let lo = (h.floor() as usize).min(last);
            let hi = (lo + 1).min(last);
            let frac = h - lo as f64;
            sorted[lo] + frac * (sorted[hi] - sorted[lo])
        }
    }
}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Fifty evenly spaced y values in a single x group must produce a
    /// populated frame with all three output columns, and the normalized
    /// width must peak at 1.0.
    #[test]
    fn test_ydensity_basic() {
        let ys: Vec<Value> = (0..50).map(|i| Value::Float(i as f64)).collect();

        let mut input = DataFrame::new();
        input.add_column("x".to_string(), vec![Value::Float(1.0); 50]);
        input.add_column("y".to_string(), ys);

        let output = StatYDensity::default().compute_group(&input, &ScaleSet::new());

        assert!(output.nrows() > 0);
        for required in ["x", "y", "violinwidth"] {
            assert!(output.column(required).is_some());
        }

        // Normalized width peaks at 1.0.
        let widths = output.column("violinwidth").unwrap();
        let mut max_w = 0.0_f64;
        for w in widths.iter().filter_map(|v| v.as_f64()) {
            max_w = max_w.max(w);
        }
        assert!(
            (max_w - 1.0).abs() < 1e-9,
            "peak width should be 1.0, got {max_w}"
        );
    }
}