Skip to main content

ggplot_rs/stat/
loess.rs

1use crate::aes::Aesthetic;
2use crate::data::{DataFrame, Value};
3use crate::scale::ScaleSet;
4
5use super::Stat;
6
/// LOESS (locally estimated scatterplot smoothing) via local weighted polynomial regression.
///
/// Each fitted value is a tricube-weighted local quadratic fit over the nearest
/// neighbors of the evaluation point.
#[derive(Debug, Clone)]
pub struct StatLoess {
    /// Fraction of the data used in each local fit. The fit uses the
    /// `ceil(span * n)` nearest points, clamped to `[3, n]`, so values above 1
    /// behave like 1. Smaller = more flexible (wigglier) curve.
    pub span: f64,
    /// Number of evenly spaced x positions, from the smallest to the largest
    /// observed x, at which the fitted curve is evaluated.
    pub n_points: usize,
    /// Whether to also emit `ymin`/`ymax` columns with an approximate
    /// confidence band.
    pub se: bool,
}
16
17impl Default for StatLoess {
18    fn default() -> Self {
19        StatLoess {
20            span: 0.75,
21            n_points: 80,
22            se: true,
23        }
24    }
25}
26
27impl Stat for StatLoess {
28    fn compute_group(&self, data: &DataFrame, _scales: &ScaleSet) -> DataFrame {
29        let x_col = match data.column("x") {
30            Some(c) => c,
31            None => return DataFrame::new(),
32        };
33        let y_col = match data.column("y") {
34            Some(c) => c,
35            None => return DataFrame::new(),
36        };
37
38        let pairs: Vec<(f64, f64)> = x_col
39            .iter()
40            .zip(y_col.iter())
41            .filter_map(|(x, y)| Some((x.as_f64()?, y.as_f64()?)))
42            .collect();
43
44        if pairs.len() < 3 {
45            return DataFrame::new();
46        }
47
48        let n = pairs.len();
49        let x_min = pairs.iter().map(|(x, _)| *x).fold(f64::INFINITY, f64::min);
50        let x_max = pairs
51            .iter()
52            .map(|(x, _)| *x)
53            .fold(f64::NEG_INFINITY, f64::max);
54        let step = (x_max - x_min) / (self.n_points - 1).max(1) as f64;
55
56        // Number of neighbors to use
57        let k = ((self.span * n as f64).ceil() as usize).max(3).min(n);
58
59        let mut x_vals = Vec::with_capacity(self.n_points);
60        let mut y_vals = Vec::with_capacity(self.n_points);
61        let mut ymin_vals = Vec::with_capacity(self.n_points);
62        let mut ymax_vals = Vec::with_capacity(self.n_points);
63
64        // Compute residual variance for SE estimation
65        let residual_var = if self.se {
66            let mut sse = 0.0;
67            for &(xi, yi) in &pairs {
68                let y_hat = local_regression(&pairs, xi, k);
69                sse += (yi - y_hat).powi(2);
70            }
71            Some(sse / (n as f64 - 2.0).max(1.0))
72        } else {
73            None
74        };
75
76        for i in 0..self.n_points {
77            let x = x_min + i as f64 * step;
78            let y = local_regression(&pairs, x, k);
79            x_vals.push(Value::Float(x));
80            y_vals.push(Value::Float(y));
81
82            if let Some(var) = residual_var {
83                // Approximate SE using residual variance and effective degrees of freedom
84                let se = var.sqrt() * (1.0 / k as f64 + 1.0 / n as f64).sqrt() * 1.5;
85                let t_val = 1.96;
86                ymin_vals.push(Value::Float(y - t_val * se));
87                ymax_vals.push(Value::Float(y + t_val * se));
88            }
89        }
90
91        let mut result = DataFrame::new();
92        result.add_column("x".to_string(), x_vals);
93        result.add_column("y".to_string(), y_vals);
94        if !ymin_vals.is_empty() {
95            result.add_column("ymin".to_string(), ymin_vals);
96            result.add_column("ymax".to_string(), ymax_vals);
97        }
98        result
99    }
100
101    fn required_aes(&self) -> Vec<Aesthetic> {
102        vec![Aesthetic::X, Aesthetic::Y]
103    }
104
105    fn name(&self) -> &str {
106        "loess"
107    }
108}
109
/// Fit a tricube-weighted local quadratic (degree 2, like R's loess) over the
/// `k` nearest neighbors of `x0` and return the fitted value at `x0`.
///
/// The neighborhood radius is the distance to the `k`-th nearest point, so that
/// point (and any point tied with it) gets zero weight. Offsets from `x0` are
/// expressed in units of this radius before solving, which leaves the fitted
/// intercept unchanged but makes the singularity checks independent of the
/// scale of `x`.
///
/// Fallbacks:
/// - weighted local linear fit when the quadratic system is singular (fewer
///   than three distinct positively-weighted x values);
/// - weighted mean of `y` when the linear system is singular too;
/// - unweighted mean of all `y` when every neighbor has zero weight.
///
/// `k` is clamped to `1..=pairs.len()`; an empty `pairs` yields `NaN`.
fn local_regression(pairs: &[(f64, f64)], x0: f64, k: usize) -> f64 {
    if pairs.is_empty() {
        return f64::NAN;
    }
    let k = k.clamp(1, pairs.len());

    // Sort by distance to x0 and take the k nearest. `total_cmp` is a total
    // order even if a NaN slips in (it sorts last), as `sort_by` requires.
    let mut dists: Vec<(usize, f64)> = pairs
        .iter()
        .enumerate()
        .map(|(i, (x, _))| (i, (x - x0).abs()))
        .collect();
    dists.sort_by(|a, b| a.1.total_cmp(&b.1));

    // Bandwidth = distance to the k-th nearest neighbor. If all k neighbors sit
    // exactly at x0, any positive radius gives each of them weight 1.
    let radius = dists[k - 1].1;
    let radius = if radius > 0.0 { radius } else { 1.0 };

    // (t, y, w) per neighbor: t is the signed offset from x0 in units of the
    // radius (|t| <= 1), w the tricube weight. Zero-weight neighbors contribute
    // nothing and are skipped so a non-finite t cannot poison the sums.
    let neighbors: Vec<(f64, f64, f64)> = dists[..k]
        .iter()
        .filter_map(|&(i, d)| {
            let (x, y) = pairs[i];
            let u = (d / radius).min(1.0);
            let w = (1.0 - u * u * u).powi(3);
            (w > 0.0).then(|| ((x - x0) / radius, y, w))
        })
        .collect();

    let sum_w: f64 = neighbors.iter().map(|&(_, _, w)| w).sum();
    if sum_w < f64::EPSILON {
        return pairs.iter().map(|&(_, y)| y).sum::<f64>() / pairs.len() as f64;
    }

    // Weighted local quadratic regression centered at x0, so the prediction is
    // the intercept `a` of y ≈ a + b·t + c·t². Accumulate the weighted moments
    // of t and the cross-moments with y for the 3×3 normal equations.
    let (mut s1, mut s2, mut s3, mut s4) = (0.0, 0.0, 0.0, 0.0);
    let (mut ty0, mut ty1, mut ty2) = (0.0, 0.0, 0.0);
    for &(t, y, w) in &neighbors {
        let t2 = t * t;
        s1 += w * t;
        s2 += w * t2;
        s3 += w * t2 * t;
        s4 += w * t2 * t2;
        ty0 += w * y;
        ty1 += w * t * y;
        ty2 += w * t2 * y;
    }
    let s0 = sum_w;
    let mean_y = ty0 / s0;

    // With |t| <= 1, det(M) <= s0³ and the linear denominator <= s0², so these
    // relative tolerances detect singularity regardless of the scale of x.
    const REL_TOL: f64 = 1e-12;

    // Matrix M = [[s0,s1,s2],[s1,s2,s3],[s2,s3,s4]], RHS = [ty0,ty1,ty2].
    let det = s0 * (s2 * s4 - s3 * s3) - s1 * (s1 * s4 - s3 * s2) + s2 * (s1 * s3 - s2 * s2);
    if det.abs() <= REL_TOL * s0 * s0 * s0 {
        // Singular (too few distinct t): fall back to weighted linear.
        let denom = s0 * s2 - s1 * s1;
        if denom.abs() <= REL_TOL * s0 * s0 {
            return mean_y;
        }
        let b = (s0 * ty1 - s1 * ty0) / denom;
        return (ty0 - b * s1) / s0;
    }

    // Cramer's rule for `a` (column 0 replaced by the RHS).
    let det_a = ty0 * (s2 * s4 - s3 * s3) - s1 * (ty1 * s4 - s3 * ty2)
        + s2 * (ty1 * s3 - s2 * ty2);
    det_a / det
}
177}