gpui_px/
scatter.rs

1//! Scatter chart - Plotly Express style API.
2
3use crate::error::ChartError;
4use crate::{
5    DEFAULT_COLOR, DEFAULT_HEIGHT, DEFAULT_PADDING_FRACTION, DEFAULT_TITLE_FONT_SIZE,
6    DEFAULT_WIDTH, TITLE_AREA_HEIGHT, ScaleType, extent_padded, validate_data_array,
7    validate_data_length, validate_dimensions, validate_positive,
8};
9use d3rs::color::D3Color;
10use d3rs::scale::{LinearScale, LogScale};
11use d3rs::shape::{ScatterConfig, ScatterPoint, render_scatter};
12use d3rs::text::{VectorFontConfig, render_vector_text};
13use gpui::prelude::*;
14use gpui::*;
15
16/// Scatter chart builder.
17#[derive(Debug, Clone)]
18pub struct ScatterChart {
19    x: Vec<f64>,
20    y: Vec<f64>,
21    title: Option<String>,
22    color: u32,
23    point_radius: f32,
24    opacity: f32,
25    width: f32,
26    height: f32,
27    x_scale_type: ScaleType,
28    y_scale_type: ScaleType,
29}
30
31impl ScatterChart {
32    /// Set chart title (rendered at top of chart).
33    pub fn title(mut self, title: impl Into<String>) -> Self {
34        self.title = Some(title.into());
35        self
36    }
37
38    /// Set point color as 24-bit RGB hex value (format: 0xRRGGBB).
39    ///
40    /// # Example
41    /// ```rust,no_run
42    /// use gpui_px::scatter;
43    /// let chart = scatter(&[1.0], &[1.0])
44    ///     .color(0x1f77b4)  // Plotly blue
45    ///     .build();
46    /// ```
47    pub fn color(mut self, hex: u32) -> Self {
48        self.color = hex;
49        self
50    }
51
52    /// Set point radius in pixels.
53    pub fn point_radius(mut self, radius: f32) -> Self {
54        self.point_radius = radius;
55        self
56    }
57
58    /// Set point opacity (0.0 - 1.0).
59    pub fn opacity(mut self, opacity: f32) -> Self {
60        self.opacity = opacity.clamp(0.0, 1.0);
61        self
62    }
63
64    /// Set chart dimensions.
65    pub fn size(mut self, width: f32, height: f32) -> Self {
66        self.width = width;
67        self.height = height;
68        self
69    }
70
71    /// Set X-axis scale type (linear or log).
72    ///
73    /// # Example
74    /// ```rust,no_run
75    /// use gpui_px::{scatter, ScaleType};
76    /// let chart = scatter(&[10.0, 100.0, 1000.0], &[1.0, 2.0, 3.0])
77    ///     .x_scale(ScaleType::Log)
78    ///     .build();
79    /// ```
80    pub fn x_scale(mut self, scale: ScaleType) -> Self {
81        self.x_scale_type = scale;
82        self
83    }
84
85    /// Set Y-axis scale type (linear or log).
86    pub fn y_scale(mut self, scale: ScaleType) -> Self {
87        self.y_scale_type = scale;
88        self
89    }
90
91    /// Build and validate the chart, returning renderable element.
92    pub fn build(self) -> Result<impl IntoElement, ChartError> {
93        // Validate inputs
94        validate_data_array(&self.x, "x")?;
95        validate_data_array(&self.y, "y")?;
96        validate_data_length(self.x.len(), self.y.len(), "x", "y")?;
97        validate_dimensions(self.width, self.height)?;
98
99        // Validate positive values for log scales
100        if self.x_scale_type == ScaleType::Log {
101            validate_positive(&self.x, "x")?;
102        }
103        if self.y_scale_type == ScaleType::Log {
104            validate_positive(&self.y, "y")?;
105        }
106
107        // Calculate plot area (reserve space for title if present)
108        let title_height = if self.title.is_some() {
109            TITLE_AREA_HEIGHT
110        } else {
111            0.0
112        };
113        let plot_height = self.height - title_height;
114
115        // Calculate domains with padding
116        let (x_min, x_max) = extent_padded(&self.x, DEFAULT_PADDING_FRACTION);
117        let (y_min, y_max) = extent_padded(&self.y, DEFAULT_PADDING_FRACTION);
118
119        // Create data points
120        let data: Vec<ScatterPoint> = self
121            .x
122            .iter()
123            .zip(self.y.iter())
124            .map(|(&x, &y)| ScatterPoint::new(x, y))
125            .collect();
126
127        // Create config
128        let config = ScatterConfig::new()
129            .fill_color(D3Color::from_hex(self.color))
130            .point_radius(self.point_radius)
131            .opacity(self.opacity);
132
133        // Build the element based on scale types
134        let scatter_element: AnyElement = match (self.x_scale_type, self.y_scale_type) {
135            (ScaleType::Linear, ScaleType::Linear) => {
136                let x_scale = LinearScale::new()
137                    .domain(x_min, x_max)
138                    .range(0.0, self.width as f64);
139                let y_scale = LinearScale::new()
140                    .domain(y_min, y_max)
141                    .range(plot_height as f64, 0.0);
142                render_scatter(&x_scale, &y_scale, &data, &config).into_any_element()
143            }
144            (ScaleType::Log, ScaleType::Linear) => {
145                let x_scale = LogScale::new()
146                    .domain(x_min.max(1e-10), x_max)
147                    .range(0.0, self.width as f64);
148                let y_scale = LinearScale::new()
149                    .domain(y_min, y_max)
150                    .range(plot_height as f64, 0.0);
151                render_scatter(&x_scale, &y_scale, &data, &config).into_any_element()
152            }
153            (ScaleType::Linear, ScaleType::Log) => {
154                let x_scale = LinearScale::new()
155                    .domain(x_min, x_max)
156                    .range(0.0, self.width as f64);
157                let y_scale = LogScale::new()
158                    .domain(y_min.max(1e-10), y_max)
159                    .range(plot_height as f64, 0.0);
160                render_scatter(&x_scale, &y_scale, &data, &config).into_any_element()
161            }
162            (ScaleType::Log, ScaleType::Log) => {
163                let x_scale = LogScale::new()
164                    .domain(x_min.max(1e-10), x_max)
165                    .range(0.0, self.width as f64);
166                let y_scale = LogScale::new()
167                    .domain(y_min.max(1e-10), y_max)
168                    .range(plot_height as f64, 0.0);
169                render_scatter(&x_scale, &y_scale, &data, &config).into_any_element()
170            }
171        };
172
173        // Build container with optional title
174        let mut container = div()
175            .w(px(self.width))
176            .h(px(self.height))
177            .relative()
178            .flex()
179            .flex_col();
180
181        // Add title if present
182        if let Some(title) = &self.title {
183            let font_config =
184                VectorFontConfig::horizontal(DEFAULT_TITLE_FONT_SIZE, hsla(0.0, 0.0, 0.2, 1.0));
185            container = container.child(
186                div()
187                    .w_full()
188                    .h(px(title_height))
189                    .flex()
190                    .justify_center()
191                    .items_center()
192                    .child(render_vector_text(title, &font_config)),
193            );
194        }
195
196        // Add plot area
197        container = container.child(
198            div()
199                .w(px(self.width))
200                .h(px(plot_height))
201                .relative()
202                .child(scatter_element),
203        );
204
205        Ok(container)
206    }
207}
208
209/// Create a scatter chart from x and y data.
210///
211/// # Example
212///
213/// ```rust,no_run
214/// use gpui_px::scatter;
215///
216/// let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
217/// let y = vec![2.0, 4.0, 3.0, 5.0, 4.5];
218///
219/// let chart = scatter(&x, &y)
220///     .title("My Scatter Plot")
221///     .color(0x1f77b4)
222///     .build()?;
223/// # Ok::<(), gpui_px::ChartError>(())
224/// ```
225pub fn scatter(x: &[f64], y: &[f64]) -> ScatterChart {
226    ScatterChart {
227        x: x.to_vec(),
228        y: y.to_vec(),
229        title: None,
230        color: DEFAULT_COLOR,
231        point_radius: 5.0,
232        opacity: 0.7,
233        width: DEFAULT_WIDTH,
234        height: DEFAULT_HEIGHT,
235        x_scale_type: ScaleType::Linear,
236        y_scale_type: ScaleType::Linear,
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// `reason` reported by the validators for NaN / Infinity input.
    const NON_FINITE: &str = "contains NaN or Infinity";
    /// `reason` reported when a log axis receives zero or negative values.
    const NON_POSITIVE_FOR_LOG: &str = "contains non-positive values for log scale";
    /// Small, strictly positive data set reused by several tests.
    const SAMPLE: [f64; 3] = [1.0, 2.0, 3.0];

    // ---- input validation ------------------------------------------------------

    #[test]
    fn test_scatter_empty_x_data() {
        let outcome = scatter(&[], &SAMPLE).build();
        assert!(matches!(outcome, Err(ChartError::EmptyData { field: "x" })));
    }

    #[test]
    fn test_scatter_empty_y_data() {
        let outcome = scatter(&SAMPLE, &[]).build();
        assert!(matches!(outcome, Err(ChartError::EmptyData { field: "y" })));
    }

    #[test]
    fn test_scatter_data_length_mismatch() {
        let outcome = scatter(&SAMPLE[..2], &SAMPLE).build();
        assert!(matches!(
            outcome,
            Err(ChartError::DataLengthMismatch {
                x_field: "x",
                y_field: "y",
                x_len: 2,
                y_len: 3,
            })
        ));
    }

    #[test]
    fn test_scatter_nan_in_x() {
        let outcome = scatter(&[1.0, f64::NAN, 3.0], &SAMPLE).build();
        assert!(matches!(
            outcome,
            Err(ChartError::InvalidData {
                field: "x",
                reason: NON_FINITE
            })
        ));
    }

    #[test]
    fn test_scatter_infinity_in_y() {
        let outcome = scatter(&SAMPLE, &[1.0, f64::INFINITY, 3.0]).build();
        assert!(matches!(
            outcome,
            Err(ChartError::InvalidData {
                field: "y",
                reason: NON_FINITE
            })
        ));
    }

    // ---- dimensions ------------------------------------------------------------

    #[test]
    fn test_scatter_zero_width() {
        let outcome = scatter(&SAMPLE, &SAMPLE).size(0.0, 400.0).build();
        assert!(matches!(
            outcome,
            Err(ChartError::InvalidDimension {
                field: "width",
                value: 0.0
            })
        ));
    }

    #[test]
    fn test_scatter_negative_height() {
        let outcome = scatter(&SAMPLE, &SAMPLE).size(600.0, -100.0).build();
        assert!(matches!(
            outcome,
            Err(ChartError::InvalidDimension {
                field: "height",
                value: -100.0
            })
        ));
    }

    // ---- successful builds -----------------------------------------------------

    #[test]
    fn test_scatter_successful_build() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [2.0, 4.0, 3.0, 5.0, 4.5];
        let outcome = scatter(&xs, &ys).title("Test Chart").color(0x1f77b4).build();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_scatter_builder_chain() {
        let outcome = scatter(&[1.0, 2.0], &[3.0, 4.0])
            .title("My Plot")
            .color(0xff0000)
            .point_radius(10.0)
            .opacity(0.5)
            .size(800.0, 600.0)
            .build();
        assert!(outcome.is_ok());
    }

    // ---- log scales ------------------------------------------------------------

    #[test]
    fn test_scatter_log_x_scale() {
        let xs = [10.0, 100.0, 1000.0, 10000.0];
        let ys = [1.0, 2.0, 3.0, 4.0];
        let outcome = scatter(&xs, &ys).x_scale(ScaleType::Log).build();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_scatter_log_y_scale() {
        let xs = [1.0, 2.0, 3.0, 4.0];
        let ys = [10.0, 100.0, 1000.0, 10000.0];
        let outcome = scatter(&xs, &ys).y_scale(ScaleType::Log).build();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_scatter_log_xy_scale() {
        let xs = [10.0, 100.0, 1000.0];
        let ys = [20.0, 200.0, 2000.0];
        let outcome = scatter(&xs, &ys)
            .x_scale(ScaleType::Log)
            .y_scale(ScaleType::Log)
            .build();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_scatter_log_x_negative_values() {
        let xs = [-10.0, -5.0, 5.0, 10.0];
        let ys = [1.0, 2.0, 3.0, 4.0];
        let outcome = scatter(&xs, &ys).x_scale(ScaleType::Log).build();
        assert!(matches!(
            outcome,
            Err(ChartError::InvalidData {
                field: "x",
                reason: NON_POSITIVE_FOR_LOG
            })
        ));
    }

    #[test]
    fn test_scatter_log_y_zero_value() {
        let xs = [1.0, 2.0, 3.0, 4.0];
        let ys = [0.0, 1.0, 2.0, 3.0];
        let outcome = scatter(&xs, &ys).y_scale(ScaleType::Log).build();
        assert!(matches!(
            outcome,
            Err(ChartError::InvalidData {
                field: "y",
                reason: NON_POSITIVE_FOR_LOG
            })
        ));
    }

    #[test]
    fn test_scatter_log_scale_with_title() {
        let xs = [10.0, 100.0, 1000.0];
        let ys = [1.0, 2.0, 3.0];
        let outcome = scatter(&xs, &ys)
            .title("Log Scale Plot")
            .x_scale(ScaleType::Log)
            .color(0x1f77b4)
            .build();
        assert!(outcome.is_ok());
    }
}