// ggplot_rs/stat/ellipse.rs

1use crate::aes::Aesthetic;
2use crate::data::{DataFrame, Value};
3use crate::scale::ScaleSet;
4
5use super::Stat;
6
/// Confidence ellipse for a 2-D point cloud (analogous to R's `stat_ellipse`).
///
/// Assumes a bivariate normal distribution: the ellipse axes come from the
/// eigen-decomposition of the sample covariance matrix, scaled by the square
/// root of the chi-square quantile for `level` with 2 degrees of freedom.
/// (R's `stat_ellipse` defaults to a multivariate-t fit with an F-based radius;
/// this stat uses the plain normal / chi-square form.)
///
/// Emits `max(segments, 3) + 1` boundary points per group, forming a closed path.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct StatEllipse {
    /// Confidence level; meaningful only in the open interval (0, 1). Default 0.95.
    pub level: f64,
    /// Number of segments used to draw the ellipse (at least 3 are used). Default 51.
    pub segments: usize,
}

19impl Default for StatEllipse {
20    fn default() -> Self {
21        StatEllipse {
22            level: 0.95,
23            segments: 51,
24        }
25    }
26}
27
28impl StatEllipse {
29    pub fn new(level: f64) -> Self {
30        StatEllipse {
31            level,
32            ..Default::default()
33        }
34    }
35}
36
37impl Stat for StatEllipse {
38    fn compute_group(&self, data: &DataFrame, _scales: &ScaleSet) -> DataFrame {
39        let (xs, ys) = match (data.column("x"), data.column("y")) {
40            (Some(x), Some(y)) => (x, y),
41            _ => return DataFrame::new(),
42        };
43        let pts: Vec<(f64, f64)> = xs
44            .iter()
45            .zip(ys.iter())
46            .filter_map(|(a, b)| Some((a.as_f64()?, b.as_f64()?)))
47            .collect();
48        if pts.len() < 3 {
49            return DataFrame::new();
50        }
51
52        let n = pts.len() as f64;
53        let mx = pts.iter().map(|p| p.0).sum::<f64>() / n;
54        let my = pts.iter().map(|p| p.1).sum::<f64>() / n;
55
56        // Sample covariance (n - 1 denominator).
57        let mut sxx = 0.0;
58        let mut syy = 0.0;
59        let mut sxy = 0.0;
60        for &(x, y) in &pts {
61            sxx += (x - mx) * (x - mx);
62            syy += (y - my) * (y - my);
63            sxy += (x - mx) * (y - my);
64        }
65        let d = n - 1.0;
66        let (sxx, syy, sxy) = (sxx / d, syy / d, sxy / d);
67
68        // Eigen-decomposition of the symmetric 2x2 [[sxx, sxy], [sxy, syy]].
69        let trace = sxx + syy;
70        let det = sxx * syy - sxy * sxy;
71        let disc = ((trace * 0.5).powi(2) - det).max(0.0).sqrt();
72        let l1 = (trace * 0.5 + disc).max(0.0);
73        let l2 = (trace * 0.5 - disc).max(0.0);
74        let (v1x, v1y) = if sxy.abs() > 1e-12 {
75            let vx = l1 - syy;
76            let vy = sxy;
77            let norm = (vx * vx + vy * vy).sqrt();
78            (vx / norm, vy / norm)
79        } else if sxx >= syy {
80            (1.0, 0.0)
81        } else {
82            (0.0, 1.0)
83        };
84        // Second axis is perpendicular to the first.
85        let (v2x, v2y) = (-v1y, v1x);
86
87        // Chi-square quantile with 2 dof has the closed form -2 ln(1 - level).
88        let radius = (-2.0 * (1.0 - self.level).ln()).sqrt();
89        let a = radius * l1.sqrt();
90        let b = radius * l2.sqrt();
91
92        let steps = self.segments.max(3);
93        let mut x_vals = Vec::with_capacity(steps + 1);
94        let mut y_vals = Vec::with_capacity(steps + 1);
95        for i in 0..=steps {
96            let theta = 2.0 * std::f64::consts::PI * (i as f64) / (steps as f64);
97            let (c, s) = (theta.cos(), theta.sin());
98            let px = mx + a * c * v1x + b * s * v2x;
99            let py = my + a * c * v1y + b * s * v2y;
100            x_vals.push(Value::Float(px));
101            y_vals.push(Value::Float(py));
102        }
103
104        let nrows = x_vals.len();
105        let mut result = DataFrame::new();
106        result.add_column("x".to_string(), x_vals);
107        result.add_column("y".to_string(), y_vals);
108        for col_name in &["color", "fill", "group"] {
109            if let Some(col) = data.column(col_name) {
110                if let Some(first) = col.first() {
111                    result.add_column(col_name.to_string(), vec![first.clone(); nrows]);
112                }
113            }
114        }
115        result
116    }
117
118    fn required_aes(&self) -> Vec<Aesthetic> {
119        vec![Aesthetic::X, Aesthetic::Y]
120    }
121
122    fn name(&self) -> &str {
123        "ellipse"
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a frame with float `x`/`y` columns from `(x, y)` pairs.
    fn frame(pts: &[(f64, f64)]) -> DataFrame {
        let (xs, ys): (Vec<Value>, Vec<Value>) = pts
            .iter()
            .map(|&(x, y)| (Value::Float(x), Value::Float(y)))
            .unzip();
        let mut df = DataFrame::new();
        df.add_column("x".into(), xs);
        df.add_column("y".into(), ys);
        df
    }

    /// Numeric contents of column `name`; panics if the column is absent.
    fn numeric(df: &DataFrame, name: &str) -> Vec<f64> {
        df.column(name)
            .unwrap()
            .iter()
            .filter_map(|v| v.as_f64())
            .collect()
    }

    /// Arithmetic mean of a non-empty slice.
    fn mean(values: &[f64]) -> f64 {
        values.iter().sum::<f64>() / values.len() as f64
    }

    #[test]
    fn ellipse_of_circular_cloud_is_centered() {
        // Forty points evenly spaced on a unit circle around (5, 3).
        let ring: Vec<(f64, f64)> = (0..40)
            .map(|k| std::f64::consts::TAU * f64::from(k) / 40.0)
            .map(|t| (5.0 + t.cos(), 3.0 + t.sin()))
            .collect();
        let stat = StatEllipse::default();
        let out = stat.compute_group(&frame(&ring), &ScaleSet::new());
        assert_eq!(out.nrows(), stat.segments + 1);

        let (xs, ys) = (numeric(&out, "x"), numeric(&out, "y"));
        let (cx, cy) = (mean(&xs), mean(&ys));
        assert!((cx - 5.0).abs() < 0.2, "center x {cx}");
        assert!((cy - 3.0).abs() < 0.2, "center y {cy}");

        // The path is closed: the last vertex coincides with the first.
        let last = xs.len() - 1;
        assert!((xs[0] - xs[last]).abs() < 1e-9);
    }

    #[test]
    fn too_few_points_returns_empty() {
        let two_points = frame(&[(0.0, 0.0), (1.0, 1.0)]);
        let out = StatEllipse::default().compute_group(&two_points, &ScaleSet::new());
        assert_eq!(out.nrows(), 0);
    }

    #[test]
    fn higher_level_makes_larger_ellipse() {
        let pts: Vec<(f64, f64)> = (0..30)
            .map(f64::from)
            .map(|x| (x, (x * 0.7).sin() * 3.0))
            .collect();
        let data = frame(&pts);
        // Horizontal extent of the ellipse drawn at `level`.
        let x_span = |level: f64| {
            let out = StatEllipse::new(level).compute_group(&data, &ScaleSet::new());
            let xs = numeric(&out, "x");
            let hi = xs.iter().copied().fold(f64::MIN, f64::max);
            let lo = xs.iter().copied().fold(f64::MAX, f64::min);
            hi - lo
        };
        assert!(x_span(0.99) > x_span(0.5));
    }
}