Skip to main content

ggplot_rs/stat/
smooth.rs

1use crate::aes::Aesthetic;
2use crate::data::{DataFrame, Value};
3use crate::scale::ScaleSet;
4
5use super::Stat;
6
/// Smoothing method used by `StatSmooth` to fit a trend line.
///
/// The default is [`SmoothMethod::Lm`]. Values can be compared with `==`
/// (two `Loess` values are equal when their spans are equal).
#[derive(Clone, Debug, Default, PartialEq)]
pub enum SmoothMethod {
    /// Linear regression (y = mx + b).
    #[default]
    Lm,
    /// LOESS with configurable span.
    Loess { span: f64 },
    /// Generalized linear model via anofox-regression (Gaussian or Poisson).
    #[cfg(feature = "regression")]
    Glm { family: SmoothFamily },
    /// Robust linear regression (Huber M-estimator) via anofox-regression.
    #[cfg(feature = "regression")]
    Rlm,
}
22
/// Response family selected by `SmoothMethod::Glm`.
///
/// Only compiled when the `regression` feature is enabled; defaults to
/// [`SmoothFamily::Gaussian`].
#[cfg(feature = "regression")]
#[derive(Debug, Default, Clone, Copy, PartialEq)]
pub enum SmoothFamily {
    /// Gaussian family: ordinary least squares with the identity link.
    #[default]
    Gaussian,
    /// Poisson family: log-link regression intended for count responses.
    Poisson,
}
33
34/// Smoothing statistic — supports both linear regression and LOESS.
35pub struct StatSmooth {
36    /// Number of points to generate for the fitted line.
37    pub n_points: usize,
38    /// Whether to compute confidence interval.
39    pub se: bool,
40    /// Smoothing method.
41    pub method: SmoothMethod,
42}
43
44impl Default for StatSmooth {
45    fn default() -> Self {
46        StatSmooth {
47            n_points: 80,
48            se: true,
49            method: SmoothMethod::Lm,
50        }
51    }
52}
53
54impl Stat for StatSmooth {
55    fn compute_group(&self, data: &DataFrame, scales: &ScaleSet) -> DataFrame {
56        match &self.method {
57            SmoothMethod::Lm => self.compute_lm(data),
58            SmoothMethod::Loess { span } => {
59                let loess = super::loess::StatLoess {
60                    span: *span,
61                    n_points: self.n_points,
62                    se: self.se,
63                };
64                loess.compute_group(data, scales)
65            }
66            #[cfg(feature = "regression")]
67            SmoothMethod::Glm { family } => self.compute_glm(data, Some(*family)),
68            #[cfg(feature = "regression")]
69            SmoothMethod::Rlm => self.compute_glm(data, None),
70        }
71    }
72
73    fn required_aes(&self) -> Vec<Aesthetic> {
74        vec![Aesthetic::X, Aesthetic::Y]
75    }
76
77    fn name(&self) -> &str {
78        "smooth"
79    }
80}
81
82impl StatSmooth {
83    fn compute_lm(&self, data: &DataFrame) -> DataFrame {
84        let x_col = match data.column("x") {
85            Some(c) => c,
86            None => return DataFrame::new(),
87        };
88        let y_col = match data.column("y") {
89            Some(c) => c,
90            None => return DataFrame::new(),
91        };
92
93        let pairs: Vec<(f64, f64)> = x_col
94            .iter()
95            .zip(y_col.iter())
96            .filter_map(|(x, y)| Some((x.as_f64()?, y.as_f64()?)))
97            .collect();
98
99        if pairs.len() < 2 {
100            return DataFrame::new();
101        }
102
103        let n = pairs.len() as f64;
104        let sum_x: f64 = pairs.iter().map(|(x, _)| x).sum();
105        let sum_y: f64 = pairs.iter().map(|(_, y)| y).sum();
106        let sum_xy: f64 = pairs.iter().map(|(x, y)| x * y).sum();
107        let sum_xx: f64 = pairs.iter().map(|(x, _)| x * x).sum();
108
109        let mean_x = sum_x / n;
110        let mean_y = sum_y / n;
111
112        let denom = sum_xx - sum_x * sum_x / n;
113        let (slope, intercept) = if denom.abs() < f64::EPSILON {
114            (0.0, mean_y)
115        } else {
116            let m = (sum_xy - sum_x * sum_y / n) / denom;
117            let b = mean_y - m * mean_x;
118            (m, b)
119        };
120
121        // Generate fitted values across x range
122        let x_min = pairs.iter().map(|(x, _)| *x).fold(f64::INFINITY, f64::min);
123        let x_max = pairs
124            .iter()
125            .map(|(x, _)| *x)
126            .fold(f64::NEG_INFINITY, f64::max);
127
128        let step = (x_max - x_min) / (self.n_points - 1).max(1) as f64;
129
130        // Compute standard error of prediction if requested
131        let se_values = if self.se && pairs.len() > 2 {
132            let residuals: Vec<f64> = pairs
133                .iter()
134                .map(|(x, y)| y - (slope * x + intercept))
135                .collect();
136            let sse: f64 = residuals.iter().map(|r| r * r).sum();
137            let mse = sse / (n - 2.0);
138            Some((mse, sum_xx, mean_x, n))
139        } else {
140            None
141        };
142
143        let mut x_vals = Vec::with_capacity(self.n_points);
144        let mut y_vals = Vec::with_capacity(self.n_points);
145        let mut ymin_vals = Vec::with_capacity(self.n_points);
146        let mut ymax_vals = Vec::with_capacity(self.n_points);
147
148        for i in 0..self.n_points {
149            let x = x_min + i as f64 * step;
150            let y = slope * x + intercept;
151            x_vals.push(Value::Float(x));
152            y_vals.push(Value::Float(y));
153
154            if let Some((mse, sum_xx, mean_x, n)) = se_values {
155                let se_pred = (mse
156                    * (1.0 / n + (x - mean_x).powi(2) / (sum_xx - n * mean_x * mean_x)))
157                    .sqrt();
158                // ~95% CI: t ≈ 1.96 for large n
159                let t_val = 1.96;
160                ymin_vals.push(Value::Float(y - t_val * se_pred));
161                ymax_vals.push(Value::Float(y + t_val * se_pred));
162            }
163        }
164
165        let mut result = DataFrame::new();
166        result.add_column("x".to_string(), x_vals);
167        result.add_column("y".to_string(), y_vals);
168        if !ymin_vals.is_empty() {
169            result.add_column("ymin".to_string(), ymin_vals);
170            result.add_column("ymax".to_string(), ymax_vals);
171        }
172        result
173    }
174
175    /// GLM / robust-linear smoothing backed by anofox-regression. `family = None`
176    /// selects the robust (Huber) fit; `Some(..)` selects a GLM family. A
177    /// confidence-interval ribbon (ymin/ymax) is emitted when `self.se` is set.
178    #[cfg(feature = "regression")]
179    fn compute_glm(&self, data: &DataFrame, family: Option<SmoothFamily>) -> DataFrame {
180        use anofox_regression::solvers::{
181            FittedRegressor, HuberRegressor, OlsRegressor, PoissonRegressor, Regressor,
182        };
183        use anofox_regression::{IntervalType, PoissonFamily, RegressionOptions};
184        use faer::{Col, Mat};
185
186        let (x_col, y_col) = match (data.column("x"), data.column("y")) {
187            (Some(x), Some(y)) => (x, y),
188            _ => return DataFrame::new(),
189        };
190        let pairs: Vec<(f64, f64)> = x_col
191            .iter()
192            .zip(y_col.iter())
193            .filter_map(|(x, y)| Some((x.as_f64()?, y.as_f64()?)))
194            .collect();
195        if pairs.len() < 2 {
196            return DataFrame::new();
197        }
198
199        let n = pairs.len();
200        let x = Mat::from_fn(n, 1, |i, _| pairs[i].0);
201        let y = Col::from_fn(n, |i| pairs[i].1);
202        let x_min = pairs.iter().map(|p| p.0).fold(f64::INFINITY, f64::min);
203        let x_max = pairs.iter().map(|p| p.0).fold(f64::NEG_INFINITY, f64::max);
204        let steps = self.n_points.max(2);
205        let grid = Mat::from_fn(steps, 1, |k, _| {
206            x_min + (x_max - x_min) * k as f64 / (steps - 1) as f64
207        });
208        let interval = if self.se {
209            Some(IntervalType::Confidence)
210        } else {
211            None
212        };
213
214        // Fit the requested model and predict (with interval) over the grid.
215        let pred = match family {
216            None => match HuberRegressor::new().fit(&x, &y) {
217                Ok(f) => f.predict_with_interval(&grid, interval, 0.95),
218                Err(_) => return DataFrame::new(),
219            },
220            Some(SmoothFamily::Gaussian) => {
221                match OlsRegressor::new(RegressionOptions::default()).fit(&x, &y) {
222                    Ok(f) => f.predict_with_interval(&grid, interval, 0.95),
223                    Err(_) => return DataFrame::new(),
224                }
225            }
226            Some(SmoothFamily::Poisson) => {
227                let reg =
228                    PoissonRegressor::new(RegressionOptions::default(), PoissonFamily::default());
229                match reg.fit(&x, &y) {
230                    Ok(f) => f.predict_with_interval(&grid, interval, 0.95),
231                    Err(_) => return DataFrame::new(),
232                }
233            }
234        };
235
236        let mut x_vals = Vec::with_capacity(steps);
237        let mut y_vals = Vec::with_capacity(steps);
238        let mut ymin_vals = Vec::with_capacity(steps);
239        let mut ymax_vals = Vec::with_capacity(steps);
240        for k in 0..steps {
241            x_vals.push(Value::Float(grid[(k, 0)]));
242            y_vals.push(Value::Float(pred.fit[k]));
243            if self.se {
244                ymin_vals.push(Value::Float(pred.lower[k]));
245                ymax_vals.push(Value::Float(pred.upper[k]));
246            }
247        }
248
249        let mut result = DataFrame::new();
250        result.add_column("x".to_string(), x_vals);
251        result.add_column("y".to_string(), y_vals);
252        if self.se {
253            result.add_column("ymin".to_string(), ymin_vals);
254            result.add_column("ymax".to_string(), ymax_vals);
255        }
256        for col_name in &["color", "fill", "group"] {
257            if let Some(col) = data.column(col_name) {
258                if let Some(first) = col.first() {
259                    result.add_column(col_name.to_string(), vec![first.clone(); steps]);
260                }
261            }
262        }
263        result
264    }
265}