1use crate::aes::Aesthetic;
2use crate::data::{DataFrame, Value};
3use crate::scale::ScaleSet;
4
5use super::Stat;
6
/// Locally weighted smoothing statistic.
///
/// Each group is smoothed by evaluating a weighted local quadratic regression on an
/// evenly spaced grid of x positions. An error band can optionally be attached.
pub struct StatLoess {
    /// Fraction of the group's points used in each local neighbourhood
    /// (at least 3 points are always used).
    pub span: f64,
    /// Number of evenly spaced x positions at which the smooth is evaluated.
    pub n_points: usize,
    /// When `true`, `ymin`/`ymax` band columns are emitted alongside the fit.
    pub se: bool,
}
16
17impl Default for StatLoess {
18 fn default() -> Self {
19 StatLoess {
20 span: 0.75,
21 n_points: 80,
22 se: true,
23 }
24 }
25}
26
27impl Stat for StatLoess {
28 fn compute_group(&self, data: &DataFrame, _scales: &ScaleSet) -> DataFrame {
29 let x_col = match data.column("x") {
30 Some(c) => c,
31 None => return DataFrame::new(),
32 };
33 let y_col = match data.column("y") {
34 Some(c) => c,
35 None => return DataFrame::new(),
36 };
37
38 let pairs: Vec<(f64, f64)> = x_col
39 .iter()
40 .zip(y_col.iter())
41 .filter_map(|(x, y)| Some((x.as_f64()?, y.as_f64()?)))
42 .collect();
43
44 if pairs.len() < 3 {
45 return DataFrame::new();
46 }
47
48 let n = pairs.len();
49 let x_min = pairs.iter().map(|(x, _)| *x).fold(f64::INFINITY, f64::min);
50 let x_max = pairs
51 .iter()
52 .map(|(x, _)| *x)
53 .fold(f64::NEG_INFINITY, f64::max);
54 let step = (x_max - x_min) / (self.n_points - 1).max(1) as f64;
55
56 let k = ((self.span * n as f64).ceil() as usize).max(3).min(n);
58
59 let mut x_vals = Vec::with_capacity(self.n_points);
60 let mut y_vals = Vec::with_capacity(self.n_points);
61 let mut ymin_vals = Vec::with_capacity(self.n_points);
62 let mut ymax_vals = Vec::with_capacity(self.n_points);
63
64 let residual_var = if self.se {
66 let mut sse = 0.0;
67 for &(xi, yi) in &pairs {
68 let y_hat = local_regression(&pairs, xi, k);
69 sse += (yi - y_hat).powi(2);
70 }
71 Some(sse / (n as f64 - 2.0).max(1.0))
72 } else {
73 None
74 };
75
76 for i in 0..self.n_points {
77 let x = x_min + i as f64 * step;
78 let y = local_regression(&pairs, x, k);
79 x_vals.push(Value::Float(x));
80 y_vals.push(Value::Float(y));
81
82 if let Some(var) = residual_var {
83 let se = var.sqrt() * (1.0 / k as f64 + 1.0 / n as f64).sqrt() * 1.5;
85 let t_val = 1.96;
86 ymin_vals.push(Value::Float(y - t_val * se));
87 ymax_vals.push(Value::Float(y + t_val * se));
88 }
89 }
90
91 let mut result = DataFrame::new();
92 result.add_column("x".to_string(), x_vals);
93 result.add_column("y".to_string(), y_vals);
94 if !ymin_vals.is_empty() {
95 result.add_column("ymin".to_string(), ymin_vals);
96 result.add_column("ymax".to_string(), ymax_vals);
97 }
98 result
99 }
100
101 fn required_aes(&self) -> Vec<Aesthetic> {
102 vec![Aesthetic::X, Aesthetic::Y]
103 }
104
105 fn name(&self) -> &str {
106 "loess"
107 }
108}
109
/// Estimates a smooth curve's value at `x0` by weighted local quadratic regression.
///
/// Only the `k` points of `pairs` whose x lies nearest to `x0` take part. Each is
/// weighted with the tricube kernel `(1 - u³)³`. Here `u` is the point's distance
/// from `x0` divided by the largest of those `k` distances, so the farthest
/// selected point gets zero weight. The fitted polynomial's value at `x0` is
/// returned.
///
/// Degenerate neighbourhoods degrade step by step:
/// - to a weighted linear fit when the quadratic normal equations are singular;
/// - to the weighted mean of `y` when the linear ones are singular too;
/// - to the plain mean of *all* `y` values when every selected neighbour has
///   (numerically) zero weight.
///
/// `k` is clamped to `1..=pairs.len()`. An empty `pairs` yields `f64::NAN`.
fn local_regression(pairs: &[(f64, f64)], x0: f64, k: usize) -> f64 {
    // Determinant magnitude below which a normal-equation system counts as singular.
    const SINGULAR_TOL: f64 = 1e-12;

    let n = pairs.len();
    if n == 0 {
        return f64::NAN;
    }
    let k = k.clamp(1, n);

    // Partition by (distance, index). The first `k` entries are then exactly the
    // neighbours a stable sort by distance would select, in O(n) instead of
    // O(n log n). `total_cmp` keeps the order total even for NaN distances.
    let mut dists: Vec<(usize, f64)> = pairs
        .iter()
        .enumerate()
        .map(|(i, &(x, _))| (i, (x - x0).abs()))
        .collect();
    dists.select_nth_unstable_by(k - 1, |a, b| a.1.total_cmp(&b.1).then(a.0.cmp(&b.0)));
    let neighbours = &dists[..k];

    // Bandwidth = distance to the k-th nearest point. If every neighbour sits on x0
    // itself, use a unit bandwidth so all of them get (almost) full weight.
    let max_dist = neighbours[k - 1].1;
    let max_dist = if max_dist < f64::EPSILON { 1.0 } else { max_dist };

    // Build (t, y, w) per neighbour, with t = (x - x0) / bandwidth in [-1, 1].
    // Rescaling t leaves the intercept (the value at x0) unchanged, but keeps the
    // singularity tests below independent of the units of x.
    let weighted: Vec<(f64, f64, f64)> = neighbours
        .iter()
        .map(|&(i, d)| {
            let u = (d / max_dist).min(1.0);
            let w = (1.0 - u * u * u).powi(3);
            let (x, y) = pairs[i];
            ((x - x0) / max_dist, y, w)
        })
        .collect();

    let s0: f64 = weighted.iter().map(|&(_, _, w)| w).sum();
    if s0 < f64::EPSILON {
        // Every selected neighbour sits on the kernel boundary.
        return pairs.iter().map(|&(_, y)| y).sum::<f64>() / n as f64;
    }

    // Weighted moments: s_j = Σ w·t^j and ty_j = Σ w·t^j·y.
    let (mut s1, mut s2, mut s3, mut s4) = (0.0, 0.0, 0.0, 0.0);
    let (mut ty0, mut ty1, mut ty2) = (0.0, 0.0, 0.0);
    for &(t, y, w) in &weighted {
        let t2 = t * t;
        s1 += w * t;
        s2 += w * t2;
        s3 += w * t2 * t;
        s4 += w * t2 * t2;
        ty0 += w * y;
        ty1 += w * t * y;
        ty2 += w * t2 * y;
    }

    // Quadratic fit y ≈ a + b·t + c·t². Cramer's rule yields the intercept `a`.
    let det = s0 * (s2 * s4 - s3 * s3) - s1 * (s1 * s4 - s3 * s2) + s2 * (s1 * s3 - s2 * s2);
    if det.abs() < SINGULAR_TOL {
        // Too few distinct weighted x values for a quadratic: fit a line instead.
        let denom = s0 * s2 - s1 * s1;
        if denom.abs() < SINGULAR_TOL {
            // Not even a line is determined: weighted mean of y.
            return ty0 / s0;
        }
        let b = (s0 * ty1 - s1 * ty0) / denom;
        return (ty0 - b * s1) / s0;
    }
    let det_a = ty0 * (s2 * s4 - s3 * s3) - s1 * (ty1 * s4 - s3 * ty2)
        + s2 * (ty1 * s3 - s2 * ty2);
    det_a / det
}