// ggplot_rs/stat/loess.rs

1use crate::aes::Aesthetic;
2use crate::data::{DataFrame, Value};
3use crate::scale::ScaleSet;
4
5use super::Stat;
6
/// LOESS (locally estimated scatterplot smoothing) stat.
///
/// Fits a tricube-weighted local *linear* regression around each evaluation point and emits the
/// fitted curve as `x`/`y` columns, plus `ymin`/`ymax` columns when [`StatLoess::se`] is enabled.
#[derive(Debug, Clone, PartialEq)]
pub struct StatLoess {
    /// Span parameter controlling smoothness, nominally in (0, 1]. Smaller = more flexible.
    ///
    /// Each local fit uses `ceil(span * n)` nearest observations, bounded below by 3 and above
    /// by the number of observations `n`.
    pub span: f64,
    /// Number of evenly spaced points (from the smallest to the largest `x`) at which the
    /// fitted curve is evaluated.
    pub n_points: usize,
    /// Whether to compute the approximate confidence band (`ymin`/`ymax`).
    pub se: bool,
}
16
17impl Default for StatLoess {
18    fn default() -> Self {
19        StatLoess {
20            span: 0.75,
21            n_points: 80,
22            se: true,
23        }
24    }
25}
26
27impl Stat for StatLoess {
28    fn compute_group(&self, data: &DataFrame, _scales: &ScaleSet) -> DataFrame {
29        let x_col = match data.column("x") {
30            Some(c) => c,
31            None => return DataFrame::new(),
32        };
33        let y_col = match data.column("y") {
34            Some(c) => c,
35            None => return DataFrame::new(),
36        };
37
38        let pairs: Vec<(f64, f64)> = x_col
39            .iter()
40            .zip(y_col.iter())
41            .filter_map(|(x, y)| Some((x.as_f64()?, y.as_f64()?)))
42            .collect();
43
44        if pairs.len() < 3 {
45            return DataFrame::new();
46        }
47
48        let n = pairs.len();
49        let x_min = pairs.iter().map(|(x, _)| *x).fold(f64::INFINITY, f64::min);
50        let x_max = pairs
51            .iter()
52            .map(|(x, _)| *x)
53            .fold(f64::NEG_INFINITY, f64::max);
54        let step = (x_max - x_min) / (self.n_points - 1).max(1) as f64;
55
56        // Number of neighbors to use
57        let k = ((self.span * n as f64).ceil() as usize).max(3).min(n);
58
59        let mut x_vals = Vec::with_capacity(self.n_points);
60        let mut y_vals = Vec::with_capacity(self.n_points);
61        let mut ymin_vals = Vec::with_capacity(self.n_points);
62        let mut ymax_vals = Vec::with_capacity(self.n_points);
63
64        // Compute residual variance for SE estimation
65        let residual_var = if self.se {
66            let mut sse = 0.0;
67            for &(xi, yi) in &pairs {
68                let y_hat = local_regression(&pairs, xi, k);
69                sse += (yi - y_hat).powi(2);
70            }
71            Some(sse / (n as f64 - 2.0).max(1.0))
72        } else {
73            None
74        };
75
76        for i in 0..self.n_points {
77            let x = x_min + i as f64 * step;
78            let y = local_regression(&pairs, x, k);
79            x_vals.push(Value::Float(x));
80            y_vals.push(Value::Float(y));
81
82            if let Some(var) = residual_var {
83                // Approximate SE using residual variance and effective degrees of freedom
84                let se = var.sqrt() * (1.0 / k as f64 + 1.0 / n as f64).sqrt() * 1.5;
85                let t_val = 1.96;
86                ymin_vals.push(Value::Float(y - t_val * se));
87                ymax_vals.push(Value::Float(y + t_val * se));
88            }
89        }
90
91        let mut result = DataFrame::new();
92        result.add_column("x".to_string(), x_vals);
93        result.add_column("y".to_string(), y_vals);
94        if !ymin_vals.is_empty() {
95            result.add_column("ymin".to_string(), ymin_vals);
96            result.add_column("ymax".to_string(), ymax_vals);
97        }
98        result
99    }
100
101    fn required_aes(&self) -> Vec<Aesthetic> {
102        vec![Aesthetic::X, Aesthetic::Y]
103    }
104
105    fn name(&self) -> &str {
106        "loess"
107    }
108}
109
/// Perform local weighted linear regression at point `x0`.
///
/// The `k` observations nearest to `x0` (by `|x - x0|`) are weighted with the tricube kernel
/// `(1 - u^3)^3`, where `u` is the distance divided by the distance of the `k`-th nearest
/// observation, and a weighted least-squares line `y = a + b * x` is fitted and evaluated at
/// `x0`.
///
/// `k` is clamped to `1..=pairs.len()`. Returns `NaN` when `pairs` is empty. Falls back to the
/// weighted mean of `y` when the selected `x` values carry no spread, and to the unweighted
/// mean of all `y` values when every weight vanishes.
fn local_regression(pairs: &[(f64, f64)], x0: f64, k: usize) -> f64 {
    if pairs.is_empty() {
        return f64::NAN;
    }
    // Guard the `dists[k - 1]` lookup below against `k == 0` and `k > pairs.len()`.
    let k = k.clamp(1, pairs.len());

    // Order observations by distance to x0. `total_cmp` is a genuine total order (NaN
    // distances sort last), which `sort_by` requires; the stable sort keeps ties in input order.
    let mut dists: Vec<(usize, f64)> = pairs
        .iter()
        .enumerate()
        .map(|(i, &(x, _))| (i, (x - x0).abs()))
        .collect();
    dists.sort_by(|a, b| a.1.total_cmp(&b.1));

    // Kernel bandwidth: distance to the k-th nearest observation. When all k neighbors sit
    // (numerically) on x0, use 1.0 so the division below stays well defined.
    let bandwidth = match dists[k - 1].1 {
        d if d < f64::EPSILON => 1.0,
        d => d,
    };

    // Accumulate the weighted sums for the normal equations in a single pass.
    let mut sum_w = 0.0;
    let mut sum_wx = 0.0;
    let mut sum_wy = 0.0;
    let mut sum_wxx = 0.0;
    let mut sum_wxy = 0.0;
    for &(i, d) in &dists[..k] {
        // `min(1.0)` also maps a NaN ratio to 1.0, i.e. to a zero weight.
        let u = (d / bandwidth).min(1.0);
        let w = (1.0 - u * u * u).powi(3);
        let (x, y) = pairs[i];
        sum_w += w;
        sum_wx += w * x;
        sum_wy += w * y;
        sum_wxx += w * x * x;
        sum_wxy += w * x * y;
    }

    // No usable weight at all: fall back to the plain mean of every observation.
    if sum_w < f64::EPSILON {
        return pairs.iter().map(|&(_, y)| y).sum::<f64>() / pairs.len() as f64;
    }

    let mean_x = sum_wx / sum_w;
    let mean_y = sum_wy / sum_w;

    // Weighted sum of squared deviations of x; (near) zero means the slope is undetermined.
    let denom = sum_wxx - sum_wx * sum_wx / sum_w;
    if denom.abs() < f64::EPSILON {
        mean_y
    } else {
        let b = (sum_wxy - sum_wx * sum_wy / sum_w) / denom;
        let a = mean_y - b * mean_x;
        a + b * x0
    }
}
160}