1use crate::aes::Aesthetic;
2use crate::data::{DataFrame, Value};
3use crate::scale::ScaleSet;
4
5use super::Stat;
6
/// Smoothing algorithm used by [`StatSmooth`].
///
/// `PartialEq` is derived so callers and tests can compare configured methods
/// directly (e.g. `method == SmoothMethod::Lm`).
#[derive(Clone, Debug, Default, PartialEq)]
pub enum SmoothMethod {
    /// Straight-line least-squares fit. This is the default method.
    #[default]
    Lm,
    /// Local regression, delegated to `super::loess::StatLoess`.
    ///
    /// `span` is forwarded unchanged to that stat; presumably it is the
    /// fraction of observations used for each local fit — confirm against
    /// `StatLoess`.
    Loess { span: f64 },
}
16
17pub struct StatSmooth {
19 pub n_points: usize,
21 pub se: bool,
23 pub method: SmoothMethod,
25}
26
27impl Default for StatSmooth {
28 fn default() -> Self {
29 StatSmooth {
30 n_points: 80,
31 se: true,
32 method: SmoothMethod::Lm,
33 }
34 }
35}
36
37impl Stat for StatSmooth {
38 fn compute_group(&self, data: &DataFrame, scales: &ScaleSet) -> DataFrame {
39 match &self.method {
40 SmoothMethod::Lm => self.compute_lm(data),
41 SmoothMethod::Loess { span } => {
42 let loess = super::loess::StatLoess {
43 span: *span,
44 n_points: self.n_points,
45 se: self.se,
46 };
47 loess.compute_group(data, scales)
48 }
49 }
50 }
51
52 fn required_aes(&self) -> Vec<Aesthetic> {
53 vec![Aesthetic::X, Aesthetic::Y]
54 }
55
56 fn name(&self) -> &str {
57 "smooth"
58 }
59}
60
61impl StatSmooth {
62 fn compute_lm(&self, data: &DataFrame) -> DataFrame {
63 let x_col = match data.column("x") {
64 Some(c) => c,
65 None => return DataFrame::new(),
66 };
67 let y_col = match data.column("y") {
68 Some(c) => c,
69 None => return DataFrame::new(),
70 };
71
72 let pairs: Vec<(f64, f64)> = x_col
73 .iter()
74 .zip(y_col.iter())
75 .filter_map(|(x, y)| Some((x.as_f64()?, y.as_f64()?)))
76 .collect();
77
78 if pairs.len() < 2 {
79 return DataFrame::new();
80 }
81
82 let n = pairs.len() as f64;
83 let sum_x: f64 = pairs.iter().map(|(x, _)| x).sum();
84 let sum_y: f64 = pairs.iter().map(|(_, y)| y).sum();
85 let sum_xy: f64 = pairs.iter().map(|(x, y)| x * y).sum();
86 let sum_xx: f64 = pairs.iter().map(|(x, _)| x * x).sum();
87
88 let mean_x = sum_x / n;
89 let mean_y = sum_y / n;
90
91 let denom = sum_xx - sum_x * sum_x / n;
92 let (slope, intercept) = if denom.abs() < f64::EPSILON {
93 (0.0, mean_y)
94 } else {
95 let m = (sum_xy - sum_x * sum_y / n) / denom;
96 let b = mean_y - m * mean_x;
97 (m, b)
98 };
99
100 let x_min = pairs.iter().map(|(x, _)| *x).fold(f64::INFINITY, f64::min);
102 let x_max = pairs
103 .iter()
104 .map(|(x, _)| *x)
105 .fold(f64::NEG_INFINITY, f64::max);
106
107 let step = (x_max - x_min) / (self.n_points - 1).max(1) as f64;
108
109 let se_values = if self.se && pairs.len() > 2 {
111 let residuals: Vec<f64> = pairs
112 .iter()
113 .map(|(x, y)| y - (slope * x + intercept))
114 .collect();
115 let sse: f64 = residuals.iter().map(|r| r * r).sum();
116 let mse = sse / (n - 2.0);
117 Some((mse, sum_xx, mean_x, n))
118 } else {
119 None
120 };
121
122 let mut x_vals = Vec::with_capacity(self.n_points);
123 let mut y_vals = Vec::with_capacity(self.n_points);
124 let mut ymin_vals = Vec::with_capacity(self.n_points);
125 let mut ymax_vals = Vec::with_capacity(self.n_points);
126
127 for i in 0..self.n_points {
128 let x = x_min + i as f64 * step;
129 let y = slope * x + intercept;
130 x_vals.push(Value::Float(x));
131 y_vals.push(Value::Float(y));
132
133 if let Some((mse, sum_xx, mean_x, n)) = se_values {
134 let se_pred = (mse
135 * (1.0 / n + (x - mean_x).powi(2) / (sum_xx - n * mean_x * mean_x)))
136 .sqrt();
137 let t_val = 1.96;
139 ymin_vals.push(Value::Float(y - t_val * se_pred));
140 ymax_vals.push(Value::Float(y + t_val * se_pred));
141 }
142 }
143
144 let mut result = DataFrame::new();
145 result.add_column("x".to_string(), x_vals);
146 result.add_column("y".to_string(), y_vals);
147 if !ymin_vals.is_empty() {
148 result.add_column("ymin".to_string(), ymin_vals);
149 result.add_column("ymax".to_string(), ymax_vals);
150 }
151 result
152 }
153}