Skip to main content

ggplot_rs/geom/
smooth.rs

1use crate::aes::Aesthetic;
2use crate::coord::Coord;
3use crate::data::DataFrame;
4use crate::position::identity::PositionIdentity;
5use crate::position::Position;
6use crate::render::backend::{DrawBackend, LineStyle, Linetype, RectStyle};
7use crate::render::RenderError;
8use crate::scale::ScaleSet;
9use crate::stat::smooth::{SmoothMethod, StatSmooth};
10use crate::stat::Stat;
11use crate::theme::Theme;
12
13use super::{Geom, GeomParams};
14
15/// Smooth line with optional confidence ribbon.
16pub struct GeomSmooth {
17    pub color: (u8, u8, u8),
18    pub fill: (u8, u8, u8),
19    pub line_width: f64,
20    pub alpha: f64,
21    pub se: bool,
22    pub n_points: usize,
23    pub method: SmoothMethod,
24}
25
26impl Default for GeomSmooth {
27    fn default() -> Self {
28        GeomSmooth {
29            color: (51, 102, 204),
30            fill: (51, 102, 204),
31            line_width: 1.5,
32            alpha: 0.2,
33            se: true,
34            n_points: 80,
35            method: SmoothMethod::Lm,
36        }
37    }
38}
39
40impl GeomSmooth {
41    /// Use LOESS smoothing with the given span.
42    pub fn loess(mut self, span: f64) -> Self {
43        self.method = SmoothMethod::Loess { span };
44        self
45    }
46}
47
48impl Geom for GeomSmooth {
49    fn draw(
50        &self,
51        data: &DataFrame,
52        coord: &dyn Coord,
53        scales: &ScaleSet,
54        _theme: &Theme,
55        backend: &mut dyn DrawBackend,
56    ) -> Result<(), RenderError> {
57        let x_col = data
58            .column("x")
59            .ok_or(RenderError::MissingAesthetic("x".into()))?;
60        let y_col = data
61            .column("y")
62            .ok_or(RenderError::MissingAesthetic("y".into()))?;
63        let ymin_col = data.column("ymin");
64        let ymax_col = data.column("ymax");
65        let color_col = data.column("color");
66        let fill_col = data.column("fill");
67
68        let plot_area = backend.plot_area();
69        let x_scale = scales.get(&Aesthetic::X);
70        let y_scale = scales.get(&Aesthetic::Y);
71
72        // If there's a color/fill aesthetic, draw separate smooths per group
73        if let Some(cc) = color_col.or(fill_col) {
74            let mut groups: Vec<(String, Vec<usize>)> = Vec::new();
75            for (i, v) in cc.iter().enumerate() {
76                let key = v.to_group_key();
77                if let Some(entry) = groups.iter_mut().find(|(k, _)| k == &key) {
78                    entry.1.push(i);
79                } else {
80                    groups.push((key, vec![i]));
81                }
82            }
83
84            for (_, indices) in &groups {
85                let first_idx = indices[0];
86
87                // Determine colors from mapped aesthetics
88                let line_color = color_col
89                    .and_then(|c| scales.map_color(&Aesthetic::Color, &c[first_idx]))
90                    .unwrap_or(self.color);
91                let ribbon_fill = fill_col
92                    .and_then(|f| scales.map_color(&Aesthetic::Fill, &f[first_idx]))
93                    .or_else(|| {
94                        color_col.and_then(|c| scales.map_color(&Aesthetic::Color, &c[first_idx]))
95                    })
96                    .unwrap_or(self.fill);
97
98                // Draw confidence ribbon
99                if self.se {
100                    if let (Some(ymin), Some(ymax)) = (ymin_col, ymax_col) {
101                        let mut upper_points: Vec<(f64, f64)> = Vec::new();
102                        let mut lower_points: Vec<(f64, f64)> = Vec::new();
103
104                        for &i in indices {
105                            let nx = x_scale.map(|s| s.map(&x_col[i])).unwrap_or(0.0);
106                            let ny_max = y_scale.map(|s| s.map(&ymax[i])).unwrap_or(0.0);
107                            let ny_min = y_scale.map(|s| s.map(&ymin[i])).unwrap_or(0.0);
108
109                            upper_points.push(coord.transform((nx, ny_max), &plot_area));
110                            lower_points.push(coord.transform((nx, ny_min), &plot_area));
111                        }
112
113                        let mut polygon = upper_points;
114                        lower_points.reverse();
115                        polygon.extend(lower_points);
116
117                        if polygon.len() >= 3 {
118                            backend.draw_polygon(
119                                &polygon,
120                                &RectStyle {
121                                    fill: Some(ribbon_fill),
122                                    stroke: None,
123                                    stroke_width: 0.0,
124                                    alpha: self.alpha,
125                                    clip: true,
126                                },
127                            )?;
128                        }
129                    }
130                }
131
132                // Draw fitted line
133                let points: Vec<(f64, f64)> = indices
134                    .iter()
135                    .map(|&i| {
136                        let nx = x_scale.map(|s| s.map(&x_col[i])).unwrap_or(0.0);
137                        let ny = y_scale.map(|s| s.map(&y_col[i])).unwrap_or(0.0);
138                        coord.transform((nx, ny), &plot_area)
139                    })
140                    .collect();
141
142                if points.len() >= 2 {
143                    backend.draw_line(
144                        &points,
145                        &LineStyle {
146                            color: line_color,
147                            alpha: 1.0,
148                            width: self.line_width,
149                            linetype: Linetype::Solid,
150                        },
151                    )?;
152                }
153            }
154        } else {
155            // No grouping — original behavior with fixed colors
156
157            // Draw confidence ribbon first (behind line)
158            if self.se {
159                if let (Some(ymin), Some(ymax)) = (ymin_col, ymax_col) {
160                    let mut upper_points: Vec<(f64, f64)> = Vec::new();
161                    let mut lower_points: Vec<(f64, f64)> = Vec::new();
162
163                    for i in 0..data.nrows() {
164                        let nx = x_scale.map(|s| s.map(&x_col[i])).unwrap_or(0.0);
165                        let ny_max = y_scale.map(|s| s.map(&ymax[i])).unwrap_or(0.0);
166                        let ny_min = y_scale.map(|s| s.map(&ymin[i])).unwrap_or(0.0);
167
168                        upper_points.push(coord.transform((nx, ny_max), &plot_area));
169                        lower_points.push(coord.transform((nx, ny_min), &plot_area));
170                    }
171
172                    // Build polygon: upper left-to-right, then lower right-to-left
173                    let mut polygon = upper_points;
174                    lower_points.reverse();
175                    polygon.extend(lower_points);
176
177                    if polygon.len() >= 3 {
178                        backend.draw_polygon(
179                            &polygon,
180                            &RectStyle {
181                                fill: Some(self.fill),
182                                stroke: None,
183                                stroke_width: 0.0,
184                                alpha: self.alpha,
185                                clip: true,
186                            },
187                        )?;
188                    }
189                }
190            }
191
192            // Draw fitted line
193            let points: Vec<(f64, f64)> = (0..data.nrows())
194                .map(|i| {
195                    let nx = x_scale.map(|s| s.map(&x_col[i])).unwrap_or(0.0);
196                    let ny = y_scale.map(|s| s.map(&y_col[i])).unwrap_or(0.0);
197                    coord.transform((nx, ny), &plot_area)
198                })
199                .collect();
200
201            if points.len() >= 2 {
202                backend.draw_line(
203                    &points,
204                    &LineStyle {
205                        color: self.color,
206                        alpha: 1.0,
207                        width: self.line_width,
208                        linetype: Linetype::Solid,
209                    },
210                )?;
211            }
212        }
213
214        Ok(())
215    }
216
217    fn required_aes(&self) -> Vec<Aesthetic> {
218        vec![Aesthetic::X, Aesthetic::Y]
219    }
220
221    fn default_stat(&self) -> Box<dyn Stat> {
222        Box::new(StatSmooth {
223            n_points: self.n_points,
224            se: self.se,
225            method: self.method.clone(),
226        })
227    }
228
229    fn default_position(&self) -> Box<dyn Position> {
230        Box::new(PositionIdentity)
231    }
232
233    fn default_params(&self) -> GeomParams {
234        GeomParams::default()
235    }
236
237    fn name(&self) -> &str {
238        "smooth"
239    }
240
241    fn set_series_color(&mut self, color: (u8, u8, u8)) {
242        self.color = color;
243        self.fill = color;
244    }
245}