Skip to main content

ggplot_rs/stat/
smooth.rs

1use crate::aes::Aesthetic;
2use crate::data::{DataFrame, Value};
3use crate::scale::ScaleSet;
4
5use super::Stat;
6
/// Selects the algorithm [`StatSmooth`] uses to fit a smooth curve.
///
/// `PartialEq` is derived so a configured method can be compared directly
/// (for example in tests or when deciding whether a cached fit is reusable).
/// `Eq` is intentionally not derived because `span` is an `f64`.
#[derive(Clone, Debug, Default, PartialEq)]
pub enum SmoothMethod {
    /// Linear regression (y = mx + b). This is the default method.
    #[default]
    Lm,
    /// LOESS with configurable span.
    Loess {
        /// Span value forwarded unchanged to the LOESS stat.
        span: f64,
    },
}
16
17/// Smoothing statistic — supports both linear regression and LOESS.
18pub struct StatSmooth {
19    /// Number of points to generate for the fitted line.
20    pub n_points: usize,
21    /// Whether to compute confidence interval.
22    pub se: bool,
23    /// Smoothing method.
24    pub method: SmoothMethod,
25}
26
27impl Default for StatSmooth {
28    fn default() -> Self {
29        StatSmooth {
30            n_points: 80,
31            se: true,
32            method: SmoothMethod::Lm,
33        }
34    }
35}
36
37impl Stat for StatSmooth {
38    fn compute_group(&self, data: &DataFrame, scales: &ScaleSet) -> DataFrame {
39        match &self.method {
40            SmoothMethod::Lm => self.compute_lm(data),
41            SmoothMethod::Loess { span } => {
42                let loess = super::loess::StatLoess {
43                    span: *span,
44                    n_points: self.n_points,
45                    se: self.se,
46                };
47                loess.compute_group(data, scales)
48            }
49        }
50    }
51
52    fn required_aes(&self) -> Vec<Aesthetic> {
53        vec![Aesthetic::X, Aesthetic::Y]
54    }
55
56    fn name(&self) -> &str {
57        "smooth"
58    }
59}
60
61impl StatSmooth {
62    fn compute_lm(&self, data: &DataFrame) -> DataFrame {
63        let x_col = match data.column("x") {
64            Some(c) => c,
65            None => return DataFrame::new(),
66        };
67        let y_col = match data.column("y") {
68            Some(c) => c,
69            None => return DataFrame::new(),
70        };
71
72        let pairs: Vec<(f64, f64)> = x_col
73            .iter()
74            .zip(y_col.iter())
75            .filter_map(|(x, y)| Some((x.as_f64()?, y.as_f64()?)))
76            .collect();
77
78        if pairs.len() < 2 {
79            return DataFrame::new();
80        }
81
82        let n = pairs.len() as f64;
83        let sum_x: f64 = pairs.iter().map(|(x, _)| x).sum();
84        let sum_y: f64 = pairs.iter().map(|(_, y)| y).sum();
85        let sum_xy: f64 = pairs.iter().map(|(x, y)| x * y).sum();
86        let sum_xx: f64 = pairs.iter().map(|(x, _)| x * x).sum();
87
88        let mean_x = sum_x / n;
89        let mean_y = sum_y / n;
90
91        let denom = sum_xx - sum_x * sum_x / n;
92        let (slope, intercept) = if denom.abs() < f64::EPSILON {
93            (0.0, mean_y)
94        } else {
95            let m = (sum_xy - sum_x * sum_y / n) / denom;
96            let b = mean_y - m * mean_x;
97            (m, b)
98        };
99
100        // Generate fitted values across x range
101        let x_min = pairs.iter().map(|(x, _)| *x).fold(f64::INFINITY, f64::min);
102        let x_max = pairs
103            .iter()
104            .map(|(x, _)| *x)
105            .fold(f64::NEG_INFINITY, f64::max);
106
107        let step = (x_max - x_min) / (self.n_points - 1).max(1) as f64;
108
109        // Compute standard error of prediction if requested
110        let se_values = if self.se && pairs.len() > 2 {
111            let residuals: Vec<f64> = pairs
112                .iter()
113                .map(|(x, y)| y - (slope * x + intercept))
114                .collect();
115            let sse: f64 = residuals.iter().map(|r| r * r).sum();
116            let mse = sse / (n - 2.0);
117            Some((mse, sum_xx, mean_x, n))
118        } else {
119            None
120        };
121
122        let mut x_vals = Vec::with_capacity(self.n_points);
123        let mut y_vals = Vec::with_capacity(self.n_points);
124        let mut ymin_vals = Vec::with_capacity(self.n_points);
125        let mut ymax_vals = Vec::with_capacity(self.n_points);
126
127        for i in 0..self.n_points {
128            let x = x_min + i as f64 * step;
129            let y = slope * x + intercept;
130            x_vals.push(Value::Float(x));
131            y_vals.push(Value::Float(y));
132
133            if let Some((mse, sum_xx, mean_x, n)) = se_values {
134                let se_pred = (mse
135                    * (1.0 / n + (x - mean_x).powi(2) / (sum_xx - n * mean_x * mean_x)))
136                    .sqrt();
137                // ~95% CI: t ≈ 1.96 for large n
138                let t_val = 1.96;
139                ymin_vals.push(Value::Float(y - t_val * se_pred));
140                ymax_vals.push(Value::Float(y + t_val * se_pred));
141            }
142        }
143
144        let mut result = DataFrame::new();
145        result.add_column("x".to_string(), x_vals);
146        result.add_column("y".to_string(), y_vals);
147        if !ymin_vals.is_empty() {
148            result.add_column("ymin".to_string(), ymin_vals);
149            result.add_column("ymax".to_string(), ymax_vals);
150        }
151        result
152    }
153}