1use crate::error::ChartError;
4use crate::{
5 DEFAULT_COLOR, DEFAULT_HEIGHT, DEFAULT_PADDING_FRACTION, DEFAULT_TITLE_FONT_SIZE,
6 DEFAULT_WIDTH, ScaleType, TITLE_AREA_HEIGHT, extent_padded, validate_data_array,
7 validate_data_length, validate_dimensions, validate_positive,
8};
9use d3rs::axis::{AxisConfig, DefaultAxisTheme, render_axis};
10use d3rs::color::D3Color;
11use d3rs::grid::{GridConfig, render_grid};
12use d3rs::scale::{LinearScale, LogScale};
13use d3rs::shape::{ScatterConfig, ScatterPoint, render_scatter};
14use d3rs::text::{VectorFontConfig, render_vector_text};
15use gpui::prelude::*;
16use gpui::*;
17
18#[derive(Debug, Clone)]
20pub struct ScatterChart {
21 x: Vec<f64>,
22 y: Vec<f64>,
23 title: Option<String>,
24 color: u32,
25 point_radius: f32,
26 opacity: f32,
27 width: f32,
28 height: f32,
29 x_scale_type: ScaleType,
30 y_scale_type: ScaleType,
31}
32
33impl ScatterChart {
34 pub fn title(mut self, title: impl Into<String>) -> Self {
36 self.title = Some(title.into());
37 self
38 }
39
40 pub fn color(mut self, hex: u32) -> Self {
50 self.color = hex;
51 self
52 }
53
54 pub fn point_radius(mut self, radius: f32) -> Self {
56 self.point_radius = radius;
57 self
58 }
59
60 pub fn opacity(mut self, opacity: f32) -> Self {
62 self.opacity = opacity.clamp(0.0, 1.0);
63 self
64 }
65
66 pub fn size(mut self, width: f32, height: f32) -> Self {
68 self.width = width;
69 self.height = height;
70 self
71 }
72
73 pub fn x_scale(mut self, scale: ScaleType) -> Self {
83 self.x_scale_type = scale;
84 self
85 }
86
87 pub fn y_scale(mut self, scale: ScaleType) -> Self {
89 self.y_scale_type = scale;
90 self
91 }
92
93 pub fn build(self) -> Result<impl IntoElement, ChartError> {
95 validate_data_array(&self.x, "x")?;
97 validate_data_array(&self.y, "y")?;
98 validate_data_length(self.x.len(), self.y.len(), "x", "y")?;
99 validate_dimensions(self.width, self.height)?;
100
101 if self.x_scale_type == ScaleType::Log {
103 validate_positive(&self.x, "x")?;
104 }
105 if self.y_scale_type == ScaleType::Log {
106 validate_positive(&self.y, "y")?;
107 }
108
109 let margin_left = 50.0;
111 let margin_bottom = 30.0;
112 let margin_top = 10.0;
113 let margin_right = 20.0;
114
115 let title_height = if self.title.is_some() {
117 TITLE_AREA_HEIGHT
118 } else {
119 0.0
120 };
121
122 let plot_width = (self.width as f64 - margin_left - margin_right).max(0.0);
123 let plot_height =
124 (self.height as f64 - title_height as f64 - margin_top - margin_bottom).max(0.0);
125
126 let (x_min, x_max) = extent_padded(&self.x, DEFAULT_PADDING_FRACTION);
128 let (y_min, y_max) = extent_padded(&self.y, DEFAULT_PADDING_FRACTION);
129
130 let data: Vec<ScatterPoint> = self
132 .x
133 .iter()
134 .zip(self.y.iter())
135 .map(|(&x, &y)| ScatterPoint::new(x, y))
136 .collect();
137
138 let config = ScatterConfig::new()
140 .fill_color(D3Color::from_hex(self.color))
141 .point_radius(self.point_radius)
142 .opacity(self.opacity);
143
144 let theme = DefaultAxisTheme;
145
146 let chart_content: AnyElement = match (self.x_scale_type, self.y_scale_type) {
148 (ScaleType::Linear, ScaleType::Linear) => {
149 let x_scale = LinearScale::new()
150 .domain(x_min, x_max)
151 .range(0.0, plot_width);
152 let y_scale = LinearScale::new()
153 .domain(y_min, y_max)
154 .range(plot_height, 0.0);
155
156 div()
157 .flex()
158 .child(render_axis(
159 &y_scale,
160 &AxisConfig::left(),
161 plot_height as f32,
162 &theme,
163 ))
164 .child(
165 div()
166 .flex()
167 .flex_col()
168 .child(
169 div()
170 .w(px(plot_width as f32))
171 .h(px(plot_height as f32))
172 .relative()
173 .bg(rgb(0xf8f8f8)) .child(render_grid(
175 &x_scale,
176 &y_scale,
177 &GridConfig::default(),
178 plot_width as f32,
179 plot_height as f32,
180 &theme,
181 ))
182 .child(render_scatter(&x_scale, &y_scale, &data, &config)),
183 )
184 .child(render_axis(
185 &x_scale,
186 &AxisConfig::bottom(),
187 plot_width as f32,
188 &theme,
189 )),
190 )
191 .into_any_element()
192 }
193 (ScaleType::Log, ScaleType::Linear) => {
194 let x_scale = LogScale::new()
195 .domain(x_min.max(1e-10), x_max)
196 .range(0.0, plot_width);
197 let y_scale = LinearScale::new()
198 .domain(y_min, y_max)
199 .range(plot_height, 0.0);
200
201 div()
202 .flex()
203 .child(render_axis(
204 &y_scale,
205 &AxisConfig::left(),
206 plot_height as f32,
207 &theme,
208 ))
209 .child(
210 div()
211 .flex()
212 .flex_col()
213 .child(
214 div()
215 .w(px(plot_width as f32))
216 .h(px(plot_height as f32))
217 .relative()
218 .bg(rgb(0xf8f8f8))
219 .child(render_grid(
220 &x_scale,
221 &y_scale,
222 &GridConfig::default(),
223 plot_width as f32,
224 plot_height as f32,
225 &theme,
226 ))
227 .child(render_scatter(&x_scale, &y_scale, &data, &config)),
228 )
229 .child(render_axis(
230 &x_scale,
231 &AxisConfig::bottom(),
232 plot_width as f32,
233 &theme,
234 )),
235 )
236 .into_any_element()
237 }
238 (ScaleType::Linear, ScaleType::Log) => {
239 let x_scale = LinearScale::new()
240 .domain(x_min, x_max)
241 .range(0.0, plot_width);
242 let y_scale = LogScale::new()
243 .domain(y_min.max(1e-10), y_max)
244 .range(plot_height, 0.0);
245
246 div()
247 .flex()
248 .child(render_axis(
249 &y_scale,
250 &AxisConfig::left(),
251 plot_height as f32,
252 &theme,
253 ))
254 .child(
255 div()
256 .flex()
257 .flex_col()
258 .child(
259 div()
260 .w(px(plot_width as f32))
261 .h(px(plot_height as f32))
262 .relative()
263 .bg(rgb(0xf8f8f8))
264 .child(render_grid(
265 &x_scale,
266 &y_scale,
267 &GridConfig::default(),
268 plot_width as f32,
269 plot_height as f32,
270 &theme,
271 ))
272 .child(render_scatter(&x_scale, &y_scale, &data, &config)),
273 )
274 .child(render_axis(
275 &x_scale,
276 &AxisConfig::bottom(),
277 plot_width as f32,
278 &theme,
279 )),
280 )
281 .into_any_element()
282 }
283 (ScaleType::Log, ScaleType::Log) => {
284 let x_scale = LogScale::new()
285 .domain(x_min.max(1e-10), x_max)
286 .range(0.0, plot_width);
287 let y_scale = LogScale::new()
288 .domain(y_min.max(1e-10), y_max)
289 .range(plot_height, 0.0);
290
291 div()
292 .flex()
293 .child(render_axis(
294 &y_scale,
295 &AxisConfig::left(),
296 plot_height as f32,
297 &theme,
298 ))
299 .child(
300 div()
301 .flex()
302 .flex_col()
303 .child(
304 div()
305 .w(px(plot_width as f32))
306 .h(px(plot_height as f32))
307 .relative()
308 .bg(rgb(0xf8f8f8))
309 .child(render_grid(
310 &x_scale,
311 &y_scale,
312 &GridConfig::default(),
313 plot_width as f32,
314 plot_height as f32,
315 &theme,
316 ))
317 .child(render_scatter(&x_scale, &y_scale, &data, &config)),
318 )
319 .child(render_axis(
320 &x_scale,
321 &AxisConfig::bottom(),
322 plot_width as f32,
323 &theme,
324 )),
325 )
326 .into_any_element()
327 }
328 };
329
330 let mut container = div()
332 .w(px(self.width))
333 .h(px(self.height))
334 .relative()
335 .flex()
336 .flex_col();
337
338 if let Some(title) = &self.title {
340 let font_config =
341 VectorFontConfig::horizontal(DEFAULT_TITLE_FONT_SIZE, hsla(0.0, 0.0, 0.2, 1.0));
342 container = container.child(
343 div()
344 .w_full()
345 .h(px(title_height))
346 .flex()
347 .justify_center()
348 .items_center()
349 .child(render_vector_text(title, &font_config)),
350 );
351 }
352
353 container = container.child(div().relative().child(chart_content));
355
356 Ok(container)
357 }
358}
359
360pub fn scatter(x: &[f64], y: &[f64]) -> ScatterChart {
377 ScatterChart {
378 x: x.to_vec(),
379 y: y.to_vec(),
380 title: None,
381 color: DEFAULT_COLOR,
382 point_radius: 5.0,
383 opacity: 0.7,
384 width: DEFAULT_WIDTH,
385 height: DEFAULT_HEIGHT,
386 x_scale_type: ScaleType::Linear,
387 y_scale_type: ScaleType::Linear,
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Three well-formed, strictly positive values shared by the validation tests.
    const SAMPLE: [f64; 3] = [1.0, 2.0, 3.0];

    #[test]
    fn test_scatter_empty_x_data() {
        assert!(matches!(
            scatter(&[], &SAMPLE).build(),
            Err(ChartError::EmptyData { field: "x" })
        ));
    }

    #[test]
    fn test_scatter_empty_y_data() {
        assert!(matches!(
            scatter(&SAMPLE, &[]).build(),
            Err(ChartError::EmptyData { field: "y" })
        ));
    }

    #[test]
    fn test_scatter_data_length_mismatch() {
        assert!(matches!(
            scatter(&[1.0, 2.0], &SAMPLE).build(),
            Err(ChartError::DataLengthMismatch {
                x_field: "x",
                y_field: "y",
                x_len: 2,
                y_len: 3,
            })
        ));
    }

    #[test]
    fn test_scatter_nan_in_x() {
        assert!(matches!(
            scatter(&[1.0, f64::NAN, 3.0], &SAMPLE).build(),
            Err(ChartError::InvalidData {
                field: "x",
                reason: "contains NaN or Infinity"
            })
        ));
    }

    #[test]
    fn test_scatter_infinity_in_y() {
        assert!(matches!(
            scatter(&SAMPLE, &[1.0, f64::INFINITY, 3.0]).build(),
            Err(ChartError::InvalidData {
                field: "y",
                reason: "contains NaN or Infinity"
            })
        ));
    }

    #[test]
    fn test_scatter_zero_width() {
        assert!(matches!(
            scatter(&SAMPLE, &SAMPLE).size(0.0, 400.0).build(),
            Err(ChartError::InvalidDimension {
                field: "width",
                value: 0.0
            })
        ));
    }

    #[test]
    fn test_scatter_negative_height() {
        assert!(matches!(
            scatter(&SAMPLE, &SAMPLE).size(600.0, -100.0).build(),
            Err(ChartError::InvalidDimension {
                field: "height",
                value: -100.0
            })
        ));
    }

    #[test]
    fn test_scatter_successful_build() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [2.0, 4.0, 3.0, 5.0, 4.5];
        assert!(scatter(&xs, &ys)
            .title("Test Chart")
            .color(0x1f77b4)
            .build()
            .is_ok());
    }

    #[test]
    fn test_scatter_builder_chain() {
        assert!(scatter(&[1.0, 2.0], &[3.0, 4.0])
            .title("My Plot")
            .color(0xff0000)
            .point_radius(10.0)
            .opacity(0.5)
            .size(800.0, 600.0)
            .build()
            .is_ok());
    }

    #[test]
    fn test_scatter_log_x_scale() {
        let xs = [10.0, 100.0, 1000.0, 10000.0];
        let ys = [1.0, 2.0, 3.0, 4.0];
        assert!(scatter(&xs, &ys).x_scale(ScaleType::Log).build().is_ok());
    }

    #[test]
    fn test_scatter_log_y_scale() {
        let xs = [1.0, 2.0, 3.0, 4.0];
        let ys = [10.0, 100.0, 1000.0, 10000.0];
        assert!(scatter(&xs, &ys).y_scale(ScaleType::Log).build().is_ok());
    }

    #[test]
    fn test_scatter_log_xy_scale() {
        let xs = [10.0, 100.0, 1000.0];
        let ys = [20.0, 200.0, 2000.0];
        assert!(scatter(&xs, &ys)
            .x_scale(ScaleType::Log)
            .y_scale(ScaleType::Log)
            .build()
            .is_ok());
    }

    #[test]
    fn test_scatter_log_x_negative_values() {
        let xs = [-10.0, -5.0, 5.0, 10.0];
        let ys = [1.0, 2.0, 3.0, 4.0];
        assert!(matches!(
            scatter(&xs, &ys).x_scale(ScaleType::Log).build(),
            Err(ChartError::InvalidData {
                field: "x",
                reason: "contains non-positive values for log scale"
            })
        ));
    }

    #[test]
    fn test_scatter_log_y_zero_value() {
        let xs = [1.0, 2.0, 3.0, 4.0];
        let ys = [0.0, 1.0, 2.0, 3.0];
        assert!(matches!(
            scatter(&xs, &ys).y_scale(ScaleType::Log).build(),
            Err(ChartError::InvalidData {
                field: "y",
                reason: "contains non-positive values for log scale"
            })
        ));
    }

    #[test]
    fn test_scatter_log_scale_with_title() {
        let xs = [10.0, 100.0, 1000.0];
        assert!(scatter(&xs, &SAMPLE)
            .title("Log Scale Plot")
            .x_scale(ScaleType::Log)
            .color(0x1f77b4)
            .build()
            .is_ok());
    }
}
559}