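// gpui_px/scatter.rs
//! Scatter chart with a Plotly Express style builder API.
//!
//! Start with [`scatter`], configure with chained setters, and finish with
//! [`ScatterChart::build`], which validates the inputs.
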
use crate::error::ChartError;
use crate::{
    DEFAULT_COLOR, DEFAULT_HEIGHT, DEFAULT_PADDING_FRACTION, DEFAULT_TITLE_FONT_SIZE,
    DEFAULT_WIDTH, ScaleType, TITLE_AREA_HEIGHT, extent_padded, validate_data_array,
    validate_data_length, validate_dimensions, validate_positive,
};
use d3rs::axis::{AxisConfig, DefaultAxisTheme, render_axis};
use d3rs::color::D3Color;
use d3rs::grid::{GridConfig, render_grid};
use d3rs::scale::{LinearScale, LogScale};
use d3rs::shape::{ScatterConfig, ScatterPoint, render_scatter};
use d3rs::text::{VectorFontConfig, render_vector_text};
use gpui::prelude::*;
use gpui::*;

18/// Scatter chart builder.
19#[derive(Debug, Clone)]
20pub struct ScatterChart {
21    x: Vec<f64>,
22    y: Vec<f64>,
23    title: Option<String>,
24    color: u32,
25    point_radius: f32,
26    opacity: f32,
27    width: f32,
28    height: f32,
29    x_scale_type: ScaleType,
30    y_scale_type: ScaleType,
31}
32
33impl ScatterChart {
34    /// Set chart title (rendered at top of chart).
35    pub fn title(mut self, title: impl Into<String>) -> Self {
36        self.title = Some(title.into());
37        self
38    }
39
40    /// Set point color as 24-bit RGB hex value (format: 0xRRGGBB).
41    ///
42    /// # Example
43    /// ```rust,no_run
44    /// use gpui_px::scatter;
45    /// let chart = scatter(&[1.0], &[1.0])
46    ///     .color(0x1f77b4)  // Plotly blue
47    ///     .build();
48    /// ```
49    pub fn color(mut self, hex: u32) -> Self {
50        self.color = hex;
51        self
52    }
53
54    /// Set point radius in pixels.
55    pub fn point_radius(mut self, radius: f32) -> Self {
56        self.point_radius = radius;
57        self
58    }
59
60    /// Set point opacity (0.0 - 1.0).
61    pub fn opacity(mut self, opacity: f32) -> Self {
62        self.opacity = opacity.clamp(0.0, 1.0);
63        self
64    }
65
66    /// Set chart dimensions.
67    pub fn size(mut self, width: f32, height: f32) -> Self {
68        self.width = width;
69        self.height = height;
70        self
71    }
72
73    /// Set X-axis scale type (linear or log).
74    ///
75    /// # Example
76    /// ```rust,no_run
77    /// use gpui_px::{scatter, ScaleType};
78    /// let chart = scatter(&[10.0, 100.0, 1000.0], &[1.0, 2.0, 3.0])
79    ///     .x_scale(ScaleType::Log)
80    ///     .build();
81    /// ```
82    pub fn x_scale(mut self, scale: ScaleType) -> Self {
83        self.x_scale_type = scale;
84        self
85    }
86
87    /// Set Y-axis scale type (linear or log).
88    pub fn y_scale(mut self, scale: ScaleType) -> Self {
89        self.y_scale_type = scale;
90        self
91    }
92
93    /// Build and validate the chart, returning renderable element.
94    pub fn build(self) -> Result<impl IntoElement, ChartError> {
95        // Validate inputs
96        validate_data_array(&self.x, "x")?;
97        validate_data_array(&self.y, "y")?;
98        validate_data_length(self.x.len(), self.y.len(), "x", "y")?;
99        validate_dimensions(self.width, self.height)?;
100
101        // Validate positive values for log scales
102        if self.x_scale_type == ScaleType::Log {
103            validate_positive(&self.x, "x")?;
104        }
105        if self.y_scale_type == ScaleType::Log {
106            validate_positive(&self.y, "y")?;
107        }
108
109        // Define margins (TODO: Make configurable?)
110        let margin_left = 50.0;
111        let margin_bottom = 30.0;
112        let margin_top = 10.0;
113        let margin_right = 20.0;
114
115        // Calculate plot area (reserve space for title if present)
116        let title_height = if self.title.is_some() {
117            TITLE_AREA_HEIGHT
118        } else {
119            0.0
120        };
121
122        let plot_width = (self.width as f64 - margin_left - margin_right).max(0.0);
123        let plot_height =
124            (self.height as f64 - title_height as f64 - margin_top - margin_bottom).max(0.0);
125
126        // Calculate domains with padding
127        let (x_min, x_max) = extent_padded(&self.x, DEFAULT_PADDING_FRACTION);
128        let (y_min, y_max) = extent_padded(&self.y, DEFAULT_PADDING_FRACTION);
129
130        // Create data points
131        let data: Vec<ScatterPoint> = self
132            .x
133            .iter()
134            .zip(self.y.iter())
135            .map(|(&x, &y)| ScatterPoint::new(x, y))
136            .collect();
137
138        // Create config
139        let config = ScatterConfig::new()
140            .fill_color(D3Color::from_hex(self.color))
141            .point_radius(self.point_radius)
142            .opacity(self.opacity);
143
144        let theme = DefaultAxisTheme;
145
146        // Build the element based on scale types
147        let chart_content: AnyElement = match (self.x_scale_type, self.y_scale_type) {
148            (ScaleType::Linear, ScaleType::Linear) => {
149                let x_scale = LinearScale::new()
150                    .domain(x_min, x_max)
151                    .range(0.0, plot_width);
152                let y_scale = LinearScale::new()
153                    .domain(y_min, y_max)
154                    .range(plot_height, 0.0);
155
156                div()
157                    .flex()
158                    .child(render_axis(
159                        &y_scale,
160                        &AxisConfig::left(),
161                        plot_height as f32,
162                        &theme,
163                    ))
164                    .child(
165                        div()
166                            .flex()
167                            .flex_col()
168                            .child(
169                                div()
170                                    .w(px(plot_width as f32))
171                                    .h(px(plot_height as f32))
172                                    .relative()
173                                    .bg(rgb(0xf8f8f8)) // Light gray background
174                                    .child(render_grid(
175                                        &x_scale,
176                                        &y_scale,
177                                        &GridConfig::default(),
178                                        plot_width as f32,
179                                        plot_height as f32,
180                                        &theme,
181                                    ))
182                                    .child(render_scatter(&x_scale, &y_scale, &data, &config)),
183                            )
184                            .child(render_axis(
185                                &x_scale,
186                                &AxisConfig::bottom(),
187                                plot_width as f32,
188                                &theme,
189                            )),
190                    )
191                    .into_any_element()
192            }
193            (ScaleType::Log, ScaleType::Linear) => {
194                let x_scale = LogScale::new()
195                    .domain(x_min.max(1e-10), x_max)
196                    .range(0.0, plot_width);
197                let y_scale = LinearScale::new()
198                    .domain(y_min, y_max)
199                    .range(plot_height, 0.0);
200
201                div()
202                    .flex()
203                    .child(render_axis(
204                        &y_scale,
205                        &AxisConfig::left(),
206                        plot_height as f32,
207                        &theme,
208                    ))
209                    .child(
210                        div()
211                            .flex()
212                            .flex_col()
213                            .child(
214                                div()
215                                    .w(px(plot_width as f32))
216                                    .h(px(plot_height as f32))
217                                    .relative()
218                                    .bg(rgb(0xf8f8f8))
219                                    .child(render_grid(
220                                        &x_scale,
221                                        &y_scale,
222                                        &GridConfig::default(),
223                                        plot_width as f32,
224                                        plot_height as f32,
225                                        &theme,
226                                    ))
227                                    .child(render_scatter(&x_scale, &y_scale, &data, &config)),
228                            )
229                            .child(render_axis(
230                                &x_scale,
231                                &AxisConfig::bottom(),
232                                plot_width as f32,
233                                &theme,
234                            )),
235                    )
236                    .into_any_element()
237            }
238            (ScaleType::Linear, ScaleType::Log) => {
239                let x_scale = LinearScale::new()
240                    .domain(x_min, x_max)
241                    .range(0.0, plot_width);
242                let y_scale = LogScale::new()
243                    .domain(y_min.max(1e-10), y_max)
244                    .range(plot_height, 0.0);
245
246                div()
247                    .flex()
248                    .child(render_axis(
249                        &y_scale,
250                        &AxisConfig::left(),
251                        plot_height as f32,
252                        &theme,
253                    ))
254                    .child(
255                        div()
256                            .flex()
257                            .flex_col()
258                            .child(
259                                div()
260                                    .w(px(plot_width as f32))
261                                    .h(px(plot_height as f32))
262                                    .relative()
263                                    .bg(rgb(0xf8f8f8))
264                                    .child(render_grid(
265                                        &x_scale,
266                                        &y_scale,
267                                        &GridConfig::default(),
268                                        plot_width as f32,
269                                        plot_height as f32,
270                                        &theme,
271                                    ))
272                                    .child(render_scatter(&x_scale, &y_scale, &data, &config)),
273                            )
274                            .child(render_axis(
275                                &x_scale,
276                                &AxisConfig::bottom(),
277                                plot_width as f32,
278                                &theme,
279                            )),
280                    )
281                    .into_any_element()
282            }
283            (ScaleType::Log, ScaleType::Log) => {
284                let x_scale = LogScale::new()
285                    .domain(x_min.max(1e-10), x_max)
286                    .range(0.0, plot_width);
287                let y_scale = LogScale::new()
288                    .domain(y_min.max(1e-10), y_max)
289                    .range(plot_height, 0.0);
290
291                div()
292                    .flex()
293                    .child(render_axis(
294                        &y_scale,
295                        &AxisConfig::left(),
296                        plot_height as f32,
297                        &theme,
298                    ))
299                    .child(
300                        div()
301                            .flex()
302                            .flex_col()
303                            .child(
304                                div()
305                                    .w(px(plot_width as f32))
306                                    .h(px(plot_height as f32))
307                                    .relative()
308                                    .bg(rgb(0xf8f8f8))
309                                    .child(render_grid(
310                                        &x_scale,
311                                        &y_scale,
312                                        &GridConfig::default(),
313                                        plot_width as f32,
314                                        plot_height as f32,
315                                        &theme,
316                                    ))
317                                    .child(render_scatter(&x_scale, &y_scale, &data, &config)),
318                            )
319                            .child(render_axis(
320                                &x_scale,
321                                &AxisConfig::bottom(),
322                                plot_width as f32,
323                                &theme,
324                            )),
325                    )
326                    .into_any_element()
327            }
328        };
329
330        // Build container with optional title
331        let mut container = div()
332            .w(px(self.width))
333            .h(px(self.height))
334            .relative()
335            .flex()
336            .flex_col();
337
338        // Add title if present
339        if let Some(title) = &self.title {
340            let font_config =
341                VectorFontConfig::horizontal(DEFAULT_TITLE_FONT_SIZE, hsla(0.0, 0.0, 0.2, 1.0));
342            container = container.child(
343                div()
344                    .w_full()
345                    .h(px(title_height))
346                    .flex()
347                    .justify_center()
348                    .items_center()
349                    .child(render_vector_text(title, &font_config)),
350            );
351        }
352
353        // Add chart content
354        container = container.child(div().relative().child(chart_content));
355
356        Ok(container)
357    }
358}
359
360/// Create a scatter chart from x and y data.
361///
362/// # Example
363///
364/// ```rust,no_run
365/// use gpui_px::scatter;
366///
367/// let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
368/// let y = vec![2.0, 4.0, 3.0, 5.0, 4.5];
369///
370/// let chart = scatter(&x, &y)
371///     .title("My Scatter Plot")
372///     .color(0x1f77b4)
373///     .build()?;
374/// # Ok::<(), gpui_px::ChartError>(())
375/// ```
376pub fn scatter(x: &[f64], y: &[f64]) -> ScatterChart {
377    ScatterChart {
378        x: x.to_vec(),
379        y: y.to_vec(),
380        title: None,
381        color: DEFAULT_COLOR,
382        point_radius: 5.0,
383        opacity: 0.7,
384        width: DEFAULT_WIDTH,
385        height: DEFAULT_HEIGHT,
386        x_scale_type: ScaleType::Linear,
387        y_scale_type: ScaleType::Linear,
388    }
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Small, strictly positive series shared by many tests.
    const ONE_TO_THREE: [f64; 3] = [1.0, 2.0, 3.0];

    /// Assert that the builder passes validation and produces an element.
    fn assert_builds(chart: ScatterChart) {
        assert!(chart.build().is_ok());
    }

    #[test]
    fn test_scatter_empty_x_data() {
        let outcome = scatter(&[], &ONE_TO_THREE).build();
        assert!(matches!(outcome, Err(ChartError::EmptyData { field: "x" })));
    }

    #[test]
    fn test_scatter_empty_y_data() {
        let outcome = scatter(&ONE_TO_THREE, &[]).build();
        assert!(matches!(outcome, Err(ChartError::EmptyData { field: "y" })));
    }

    #[test]
    fn test_scatter_data_length_mismatch() {
        let outcome = scatter(&[1.0, 2.0], &ONE_TO_THREE).build();
        let is_expected = matches!(
            outcome,
            Err(ChartError::DataLengthMismatch {
                x_field: "x",
                y_field: "y",
                x_len: 2,
                y_len: 3,
            })
        );
        assert!(is_expected);
    }

    #[test]
    fn test_scatter_nan_in_x() {
        let outcome = scatter(&[1.0, f64::NAN, 3.0], &ONE_TO_THREE).build();
        let is_expected = matches!(
            outcome,
            Err(ChartError::InvalidData {
                field: "x",
                reason: "contains NaN or Infinity"
            })
        );
        assert!(is_expected);
    }

    #[test]
    fn test_scatter_infinity_in_y() {
        let outcome = scatter(&ONE_TO_THREE, &[1.0, f64::INFINITY, 3.0]).build();
        let is_expected = matches!(
            outcome,
            Err(ChartError::InvalidData {
                field: "y",
                reason: "contains NaN or Infinity"
            })
        );
        assert!(is_expected);
    }

    #[test]
    fn test_scatter_zero_width() {
        let outcome = scatter(&ONE_TO_THREE, &ONE_TO_THREE)
            .size(0.0, 400.0)
            .build();
        let is_expected = matches!(
            outcome,
            Err(ChartError::InvalidDimension {
                field: "width",
                value: 0.0
            })
        );
        assert!(is_expected);
    }

    #[test]
    fn test_scatter_negative_height() {
        let outcome = scatter(&ONE_TO_THREE, &ONE_TO_THREE)
            .size(600.0, -100.0)
            .build();
        let is_expected = matches!(
            outcome,
            Err(ChartError::InvalidDimension {
                field: "height",
                value: -100.0
            })
        );
        assert!(is_expected);
    }

    #[test]
    fn test_scatter_successful_build() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [2.0, 4.0, 3.0, 5.0, 4.5];
        assert_builds(scatter(&xs, &ys).title("Test Chart").color(0x1f77b4));
    }

    #[test]
    fn test_scatter_builder_chain() {
        assert_builds(
            scatter(&[1.0, 2.0], &[3.0, 4.0])
                .title("My Plot")
                .color(0xff0000)
                .point_radius(10.0)
                .opacity(0.5)
                .size(800.0, 600.0),
        );
    }

    #[test]
    fn test_scatter_log_x_scale() {
        let xs = [10.0, 100.0, 1000.0, 10000.0];
        let ys = [1.0, 2.0, 3.0, 4.0];
        assert_builds(scatter(&xs, &ys).x_scale(ScaleType::Log));
    }

    #[test]
    fn test_scatter_log_y_scale() {
        let xs = [1.0, 2.0, 3.0, 4.0];
        let ys = [10.0, 100.0, 1000.0, 10000.0];
        assert_builds(scatter(&xs, &ys).y_scale(ScaleType::Log));
    }

    #[test]
    fn test_scatter_log_xy_scale() {
        let xs = [10.0, 100.0, 1000.0];
        let ys = [20.0, 200.0, 2000.0];
        assert_builds(
            scatter(&xs, &ys)
                .x_scale(ScaleType::Log)
                .y_scale(ScaleType::Log),
        );
    }

    #[test]
    fn test_scatter_log_x_negative_values() {
        let xs = [-10.0, -5.0, 5.0, 10.0];
        let ys = [1.0, 2.0, 3.0, 4.0];
        let outcome = scatter(&xs, &ys).x_scale(ScaleType::Log).build();
        let is_expected = matches!(
            outcome,
            Err(ChartError::InvalidData {
                field: "x",
                reason: "contains non-positive values for log scale"
            })
        );
        assert!(is_expected);
    }

    #[test]
    fn test_scatter_log_y_zero_value() {
        let xs = [1.0, 2.0, 3.0, 4.0];
        let ys = [0.0, 1.0, 2.0, 3.0];
        let outcome = scatter(&xs, &ys).y_scale(ScaleType::Log).build();
        let is_expected = matches!(
            outcome,
            Err(ChartError::InvalidData {
                field: "y",
                reason: "contains non-positive values for log scale"
            })
        );
        assert!(is_expected);
    }

    #[test]
    fn test_scatter_log_scale_with_title() {
        let xs = [10.0, 100.0, 1000.0];
        let ys = ONE_TO_THREE;
        assert_builds(
            scatter(&xs, &ys)
                .title("Log Scale Plot")
                .x_scale(ScaleType::Log)
                .color(0x1f77b4),
        );
    }
}