1use crate::error::ChartError;
4use crate::{
5 DEFAULT_COLOR, DEFAULT_HEIGHT, DEFAULT_PADDING_FRACTION, DEFAULT_TITLE_FONT_SIZE,
6 DEFAULT_WIDTH, TITLE_AREA_HEIGHT, ScaleType, extent_padded, validate_data_array,
7 validate_data_length, validate_dimensions, validate_positive,
8};
9use d3rs::color::D3Color;
10use d3rs::scale::{LinearScale, LogScale};
11use d3rs::shape::{ScatterConfig, ScatterPoint, render_scatter};
12use d3rs::text::{VectorFontConfig, render_vector_text};
13use gpui::prelude::*;
14use gpui::*;
15
16#[derive(Debug, Clone)]
18pub struct ScatterChart {
19 x: Vec<f64>,
20 y: Vec<f64>,
21 title: Option<String>,
22 color: u32,
23 point_radius: f32,
24 opacity: f32,
25 width: f32,
26 height: f32,
27 x_scale_type: ScaleType,
28 y_scale_type: ScaleType,
29}
30
31impl ScatterChart {
32 pub fn title(mut self, title: impl Into<String>) -> Self {
34 self.title = Some(title.into());
35 self
36 }
37
38 pub fn color(mut self, hex: u32) -> Self {
48 self.color = hex;
49 self
50 }
51
52 pub fn point_radius(mut self, radius: f32) -> Self {
54 self.point_radius = radius;
55 self
56 }
57
58 pub fn opacity(mut self, opacity: f32) -> Self {
60 self.opacity = opacity.clamp(0.0, 1.0);
61 self
62 }
63
64 pub fn size(mut self, width: f32, height: f32) -> Self {
66 self.width = width;
67 self.height = height;
68 self
69 }
70
71 pub fn x_scale(mut self, scale: ScaleType) -> Self {
81 self.x_scale_type = scale;
82 self
83 }
84
85 pub fn y_scale(mut self, scale: ScaleType) -> Self {
87 self.y_scale_type = scale;
88 self
89 }
90
91 pub fn build(self) -> Result<impl IntoElement, ChartError> {
93 validate_data_array(&self.x, "x")?;
95 validate_data_array(&self.y, "y")?;
96 validate_data_length(self.x.len(), self.y.len(), "x", "y")?;
97 validate_dimensions(self.width, self.height)?;
98
99 if self.x_scale_type == ScaleType::Log {
101 validate_positive(&self.x, "x")?;
102 }
103 if self.y_scale_type == ScaleType::Log {
104 validate_positive(&self.y, "y")?;
105 }
106
107 let title_height = if self.title.is_some() {
109 TITLE_AREA_HEIGHT
110 } else {
111 0.0
112 };
113 let plot_height = self.height - title_height;
114
115 let (x_min, x_max) = extent_padded(&self.x, DEFAULT_PADDING_FRACTION);
117 let (y_min, y_max) = extent_padded(&self.y, DEFAULT_PADDING_FRACTION);
118
119 let data: Vec<ScatterPoint> = self
121 .x
122 .iter()
123 .zip(self.y.iter())
124 .map(|(&x, &y)| ScatterPoint::new(x, y))
125 .collect();
126
127 let config = ScatterConfig::new()
129 .fill_color(D3Color::from_hex(self.color))
130 .point_radius(self.point_radius)
131 .opacity(self.opacity);
132
133 let scatter_element: AnyElement = match (self.x_scale_type, self.y_scale_type) {
135 (ScaleType::Linear, ScaleType::Linear) => {
136 let x_scale = LinearScale::new()
137 .domain(x_min, x_max)
138 .range(0.0, self.width as f64);
139 let y_scale = LinearScale::new()
140 .domain(y_min, y_max)
141 .range(plot_height as f64, 0.0);
142 render_scatter(&x_scale, &y_scale, &data, &config).into_any_element()
143 }
144 (ScaleType::Log, ScaleType::Linear) => {
145 let x_scale = LogScale::new()
146 .domain(x_min.max(1e-10), x_max)
147 .range(0.0, self.width as f64);
148 let y_scale = LinearScale::new()
149 .domain(y_min, y_max)
150 .range(plot_height as f64, 0.0);
151 render_scatter(&x_scale, &y_scale, &data, &config).into_any_element()
152 }
153 (ScaleType::Linear, ScaleType::Log) => {
154 let x_scale = LinearScale::new()
155 .domain(x_min, x_max)
156 .range(0.0, self.width as f64);
157 let y_scale = LogScale::new()
158 .domain(y_min.max(1e-10), y_max)
159 .range(plot_height as f64, 0.0);
160 render_scatter(&x_scale, &y_scale, &data, &config).into_any_element()
161 }
162 (ScaleType::Log, ScaleType::Log) => {
163 let x_scale = LogScale::new()
164 .domain(x_min.max(1e-10), x_max)
165 .range(0.0, self.width as f64);
166 let y_scale = LogScale::new()
167 .domain(y_min.max(1e-10), y_max)
168 .range(plot_height as f64, 0.0);
169 render_scatter(&x_scale, &y_scale, &data, &config).into_any_element()
170 }
171 };
172
173 let mut container = div()
175 .w(px(self.width))
176 .h(px(self.height))
177 .relative()
178 .flex()
179 .flex_col();
180
181 if let Some(title) = &self.title {
183 let font_config =
184 VectorFontConfig::horizontal(DEFAULT_TITLE_FONT_SIZE, hsla(0.0, 0.0, 0.2, 1.0));
185 container = container.child(
186 div()
187 .w_full()
188 .h(px(title_height))
189 .flex()
190 .justify_center()
191 .items_center()
192 .child(render_vector_text(title, &font_config)),
193 );
194 }
195
196 container = container.child(
198 div()
199 .w(px(self.width))
200 .h(px(plot_height))
201 .relative()
202 .child(scatter_element),
203 );
204
205 Ok(container)
206 }
207}
208
209pub fn scatter(x: &[f64], y: &[f64]) -> ScatterChart {
226 ScatterChart {
227 x: x.to_vec(),
228 y: y.to_vec(),
229 title: None,
230 color: DEFAULT_COLOR,
231 point_radius: 5.0,
232 opacity: 0.7,
233 width: DEFAULT_WIDTH,
234 height: DEFAULT_HEIGHT,
235 x_scale_type: ScaleType::Linear,
236 y_scale_type: ScaleType::Linear,
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Small, valid series used wherever the exact values do not matter.
    const ONE_TWO_THREE: [f64; 3] = [1.0, 2.0, 3.0];
    /// Reason reported for NaN / infinite samples.
    const NON_FINITE: &str = "contains NaN or Infinity";
    /// Reason reported for zero / negative samples on a log axis.
    const NON_POSITIVE: &str = "contains non-positive values for log scale";

    /// True when the chart builds without error.
    fn builds(chart: ScatterChart) -> bool {
        chart.build().is_ok()
    }

    #[test]
    fn test_scatter_empty_x_data() {
        let failure = scatter(&[], &ONE_TWO_THREE).build().err();
        assert!(matches!(failure, Some(ChartError::EmptyData { field: "x" })));
    }

    #[test]
    fn test_scatter_empty_y_data() {
        let failure = scatter(&ONE_TWO_THREE, &[]).build().err();
        assert!(matches!(failure, Some(ChartError::EmptyData { field: "y" })));
    }

    #[test]
    fn test_scatter_data_length_mismatch() {
        let failure = scatter(&[1.0, 2.0], &ONE_TWO_THREE).build().err();
        let expected_shape = matches!(
            failure,
            Some(ChartError::DataLengthMismatch {
                x_field: "x",
                y_field: "y",
                x_len: 2,
                y_len: 3,
            })
        );
        assert!(expected_shape);
    }

    #[test]
    fn test_scatter_nan_in_x() {
        let failure = scatter(&[1.0, f64::NAN, 3.0], &ONE_TWO_THREE).build().err();
        assert!(matches!(
            failure,
            Some(ChartError::InvalidData {
                field: "x",
                reason: NON_FINITE
            })
        ));
    }

    #[test]
    fn test_scatter_infinity_in_y() {
        let failure = scatter(&ONE_TWO_THREE, &[1.0, f64::INFINITY, 3.0]).build().err();
        assert!(matches!(
            failure,
            Some(ChartError::InvalidData {
                field: "y",
                reason: NON_FINITE
            })
        ));
    }

    #[test]
    fn test_scatter_zero_width() {
        let chart = scatter(&ONE_TWO_THREE, &ONE_TWO_THREE).size(0.0, 400.0);
        let failure = chart.build().err();
        assert!(matches!(
            failure,
            Some(ChartError::InvalidDimension {
                field: "width",
                value: 0.0
            })
        ));
    }

    #[test]
    fn test_scatter_negative_height() {
        let chart = scatter(&ONE_TWO_THREE, &ONE_TWO_THREE).size(600.0, -100.0);
        let failure = chart.build().err();
        assert!(matches!(
            failure,
            Some(ChartError::InvalidDimension {
                field: "height",
                value: -100.0
            })
        ));
    }

    #[test]
    fn test_scatter_successful_build() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [2.0, 4.0, 3.0, 5.0, 4.5];
        assert!(builds(scatter(&xs, &ys).title("Test Chart").color(0x1f77b4)));
    }

    #[test]
    fn test_scatter_builder_chain() {
        let chart = scatter(&[1.0, 2.0], &[3.0, 4.0])
            .title("My Plot")
            .color(0xff0000)
            .point_radius(10.0)
            .opacity(0.5)
            .size(800.0, 600.0);
        assert!(builds(chart));
    }

    #[test]
    fn test_scatter_log_x_scale() {
        let xs = [10.0, 100.0, 1000.0, 10000.0];
        let ys = [1.0, 2.0, 3.0, 4.0];
        assert!(builds(scatter(&xs, &ys).x_scale(ScaleType::Log)));
    }

    #[test]
    fn test_scatter_log_y_scale() {
        let xs = [1.0, 2.0, 3.0, 4.0];
        let ys = [10.0, 100.0, 1000.0, 10000.0];
        assert!(builds(scatter(&xs, &ys).y_scale(ScaleType::Log)));
    }

    #[test]
    fn test_scatter_log_xy_scale() {
        let xs = [10.0, 100.0, 1000.0];
        let ys = [20.0, 200.0, 2000.0];
        let chart = scatter(&xs, &ys)
            .x_scale(ScaleType::Log)
            .y_scale(ScaleType::Log);
        assert!(builds(chart));
    }

    #[test]
    fn test_scatter_log_x_negative_values() {
        let xs = [-10.0, -5.0, 5.0, 10.0];
        let ys = [1.0, 2.0, 3.0, 4.0];
        let failure = scatter(&xs, &ys).x_scale(ScaleType::Log).build().err();
        assert!(matches!(
            failure,
            Some(ChartError::InvalidData {
                field: "x",
                reason: NON_POSITIVE
            })
        ));
    }

    #[test]
    fn test_scatter_log_y_zero_value() {
        let xs = [1.0, 2.0, 3.0, 4.0];
        let ys = [0.0, 1.0, 2.0, 3.0];
        let failure = scatter(&xs, &ys).y_scale(ScaleType::Log).build().err();
        assert!(matches!(
            failure,
            Some(ChartError::InvalidData {
                field: "y",
                reason: NON_POSITIVE
            })
        ));
    }

    #[test]
    fn test_scatter_log_scale_with_title() {
        let xs = [10.0, 100.0, 1000.0];
        let chart = scatter(&xs, &ONE_TWO_THREE)
            .title("Log Scale Plot")
            .x_scale(ScaleType::Log)
            .color(0x1f77b4);
        assert!(builds(chart));
    }
}