1use crate::color_scale::ColorScale;
4use crate::error::ChartError;
5use crate::{
6 DEFAULT_HEIGHT, DEFAULT_TITLE_FONT_SIZE, DEFAULT_WIDTH, ScaleType, TITLE_AREA_HEIGHT,
7 extent_padded, validate_data_array, validate_dimensions, validate_grid_dimensions,
8 validate_monotonic, validate_positive,
9};
10use d3rs::axis::{AxisConfig, DefaultAxisTheme, render_axis};
11use d3rs::grid::{GridConfig, render_grid};
12use d3rs::scale::{LinearScale, LogScale};
13use d3rs::shape::{ContourConfig, HeatmapData, render_heatmap};
14use d3rs::text::{VectorFontConfig, render_vector_text};
15use gpui::prelude::*;
16use gpui::*;
17
18#[derive(Clone)]
20pub struct HeatmapChart {
21 z: Vec<f64>,
22 grid_width: usize,
23 grid_height: usize,
24 x_values: Option<Vec<f64>>,
25 y_values: Option<Vec<f64>>,
26 x_scale_type: ScaleType,
27 y_scale_type: ScaleType,
28 color_scale: ColorScale,
29 title: Option<String>,
30 opacity: f32,
31 width: f32,
32 height: f32,
33}
34
35impl std::fmt::Debug for HeatmapChart {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 f.debug_struct("HeatmapChart")
38 .field("grid_width", &self.grid_width)
39 .field("grid_height", &self.grid_height)
40 .field("x_scale_type", &self.x_scale_type)
41 .field("y_scale_type", &self.y_scale_type)
42 .field("color_scale", &self.color_scale)
43 .field("title", &self.title)
44 .field("opacity", &self.opacity)
45 .field("width", &self.width)
46 .field("height", &self.height)
47 .finish()
48 }
49}
50
51impl HeatmapChart {
52 pub fn x(mut self, values: &[f64]) -> Self {
57 self.x_values = Some(values.to_vec());
58 self
59 }
60
61 pub fn y(mut self, values: &[f64]) -> Self {
66 self.y_values = Some(values.to_vec());
67 self
68 }
69
70 pub fn x_scale(mut self, scale: ScaleType) -> Self {
72 self.x_scale_type = scale;
73 self
74 }
75
76 pub fn y_scale(mut self, scale: ScaleType) -> Self {
78 self.y_scale_type = scale;
79 self
80 }
81
82 pub fn color_scale(mut self, scale: ColorScale) -> Self {
84 self.color_scale = scale;
85 self
86 }
87
88 pub fn title(mut self, title: impl Into<String>) -> Self {
90 self.title = Some(title.into());
91 self
92 }
93
94 pub fn opacity(mut self, opacity: f32) -> Self {
96 self.opacity = opacity.clamp(0.0, 1.0);
97 self
98 }
99
100 pub fn size(mut self, width: f32, height: f32) -> Self {
102 self.width = width;
103 self.height = height;
104 self
105 }
106
107 pub fn build(self) -> Result<impl IntoElement, ChartError> {
109 validate_data_array(&self.z, "z")?;
111 validate_grid_dimensions(&self.z, self.grid_width, self.grid_height)?;
112 validate_dimensions(self.width, self.height)?;
113
114 let x_values = match self.x_values {
116 Some(ref v) => {
117 if v.len() != self.grid_width {
118 return Err(ChartError::DataLengthMismatch {
119 x_field: "x",
120 y_field: "grid_width",
121 x_len: v.len(),
122 y_len: self.grid_width,
123 });
124 }
125 validate_data_array(v, "x")?;
126 validate_monotonic(v, "x")?;
127 if self.x_scale_type == ScaleType::Log {
128 validate_positive(v, "x")?;
129 }
130 v.clone()
131 }
132 None => (0..self.grid_width).map(|i| i as f64).collect(),
133 };
134
135 let y_values = match self.y_values {
137 Some(ref v) => {
138 if v.len() != self.grid_height {
139 return Err(ChartError::DataLengthMismatch {
140 x_field: "y",
141 y_field: "grid_height",
142 x_len: v.len(),
143 y_len: self.grid_height,
144 });
145 }
146 validate_data_array(v, "y")?;
147 validate_monotonic(v, "y")?;
148 if self.y_scale_type == ScaleType::Log {
149 validate_positive(v, "y")?;
150 }
151 v.clone()
152 }
153 None => (0..self.grid_height).map(|i| i as f64).collect(),
154 };
155
156 let margin_left = 50.0;
158 let margin_bottom = 30.0;
159 let margin_top = 10.0;
160 let margin_right = 20.0;
161
162 let title_height = if self.title.is_some() {
164 TITLE_AREA_HEIGHT
165 } else {
166 0.0
167 };
168
169 let plot_width = (self.width as f64 - margin_left - margin_right).max(0.0);
170 let plot_height =
171 (self.height as f64 - title_height as f64 - margin_top - margin_bottom).max(0.0);
172
173 let (x_min, x_max) = extent_padded(&x_values, 0.0);
175 let (y_min, y_max) = extent_padded(&y_values, 0.0);
176
177 let heatmap_data = HeatmapData::new(x_values, y_values, self.z.clone());
179
180 let color_fn = self.color_scale.to_fn();
182 let config = ContourConfig::new()
183 .fill(true)
184 .fill_opacity(self.opacity)
185 .color_scale(color_fn);
186
187 let theme = DefaultAxisTheme;
188
189 let chart_content: AnyElement = match (self.x_scale_type, self.y_scale_type) {
191 (ScaleType::Linear, ScaleType::Linear) => {
192 let x_scale = LinearScale::new()
193 .domain(x_min, x_max)
194 .range(0.0, plot_width);
195 let y_scale = LinearScale::new()
196 .domain(y_min, y_max)
197 .range(plot_height, 0.0);
198
199 div()
200 .flex()
201 .child(render_axis(
202 &y_scale,
203 &AxisConfig::left(),
204 plot_height as f32,
205 &theme,
206 ))
207 .child(
208 div()
209 .flex()
210 .flex_col()
211 .child(
212 div()
213 .w(px(plot_width as f32))
214 .h(px(plot_height as f32))
215 .relative()
216 .bg(rgb(0xf8f8f8))
217 .child(render_grid(
218 &x_scale,
219 &y_scale,
220 &GridConfig::default(),
221 plot_width as f32,
222 plot_height as f32,
223 &theme,
224 ))
225 .child(div().absolute().inset_0().child(render_heatmap(
226 heatmap_data,
227 &x_scale,
228 &y_scale,
229 &config,
230 ))),
231 )
232 .child(render_axis(
233 &x_scale,
234 &AxisConfig::bottom(),
235 plot_width as f32,
236 &theme,
237 )),
238 )
239 .into_any_element()
240 }
241 (ScaleType::Log, ScaleType::Linear) => {
242 let x_scale = LogScale::new()
243 .domain(x_min.max(1e-10), x_max)
244 .range(0.0, plot_width);
245 let y_scale = LinearScale::new()
246 .domain(y_min, y_max)
247 .range(plot_height, 0.0);
248
249 div()
250 .flex()
251 .child(render_axis(
252 &y_scale,
253 &AxisConfig::left(),
254 plot_height as f32,
255 &theme,
256 ))
257 .child(
258 div()
259 .flex()
260 .flex_col()
261 .child(
262 div()
263 .w(px(plot_width as f32))
264 .h(px(plot_height as f32))
265 .relative()
266 .bg(rgb(0xf8f8f8))
267 .child(render_grid(
268 &x_scale,
269 &y_scale,
270 &GridConfig::default(),
271 plot_width as f32,
272 plot_height as f32,
273 &theme,
274 ))
275 .child(div().absolute().inset_0().child(render_heatmap(
276 heatmap_data,
277 &x_scale,
278 &y_scale,
279 &config,
280 ))),
281 )
282 .child(render_axis(
283 &x_scale,
284 &AxisConfig::bottom(),
285 plot_width as f32,
286 &theme,
287 )),
288 )
289 .into_any_element()
290 }
291 (ScaleType::Linear, ScaleType::Log) => {
292 let x_scale = LinearScale::new()
293 .domain(x_min, x_max)
294 .range(0.0, plot_width);
295 let y_scale = LogScale::new()
296 .domain(y_min.max(1e-10), y_max)
297 .range(plot_height, 0.0);
298
299 div()
300 .flex()
301 .child(render_axis(
302 &y_scale,
303 &AxisConfig::left(),
304 plot_height as f32,
305 &theme,
306 ))
307 .child(
308 div()
309 .flex()
310 .flex_col()
311 .child(
312 div()
313 .w(px(plot_width as f32))
314 .h(px(plot_height as f32))
315 .relative()
316 .bg(rgb(0xf8f8f8))
317 .child(render_grid(
318 &x_scale,
319 &y_scale,
320 &GridConfig::default(),
321 plot_width as f32,
322 plot_height as f32,
323 &theme,
324 ))
325 .child(div().absolute().inset_0().child(render_heatmap(
326 heatmap_data,
327 &x_scale,
328 &y_scale,
329 &config,
330 ))),
331 )
332 .child(render_axis(
333 &x_scale,
334 &AxisConfig::bottom(),
335 plot_width as f32,
336 &theme,
337 )),
338 )
339 .into_any_element()
340 }
341 (ScaleType::Log, ScaleType::Log) => {
342 let x_scale = LogScale::new()
343 .domain(x_min.max(1e-10), x_max)
344 .range(0.0, plot_width);
345 let y_scale = LogScale::new()
346 .domain(y_min.max(1e-10), y_max)
347 .range(plot_height, 0.0);
348
349 div()
350 .flex()
351 .child(render_axis(
352 &y_scale,
353 &AxisConfig::left(),
354 plot_height as f32,
355 &theme,
356 ))
357 .child(
358 div()
359 .flex()
360 .flex_col()
361 .child(
362 div()
363 .w(px(plot_width as f32))
364 .h(px(plot_height as f32))
365 .relative()
366 .bg(rgb(0xf8f8f8))
367 .child(render_grid(
368 &x_scale,
369 &y_scale,
370 &GridConfig::default(),
371 plot_width as f32,
372 plot_height as f32,
373 &theme,
374 ))
375 .child(div().absolute().inset_0().child(render_heatmap(
376 heatmap_data,
377 &x_scale,
378 &y_scale,
379 &config,
380 ))),
381 )
382 .child(render_axis(
383 &x_scale,
384 &AxisConfig::bottom(),
385 plot_width as f32,
386 &theme,
387 )),
388 )
389 .into_any_element()
390 }
391 };
392
393 let mut container = div()
395 .w(px(self.width))
396 .h(px(self.height))
397 .relative()
398 .flex()
399 .flex_col();
400
401 if let Some(title) = &self.title {
403 let font_config =
404 VectorFontConfig::horizontal(DEFAULT_TITLE_FONT_SIZE, hsla(0.0, 0.0, 0.2, 1.0));
405 container = container.child(
406 div()
407 .w_full()
408 .h(px(title_height))
409 .flex()
410 .justify_center()
411 .items_center()
412 .child(render_vector_text(title, &font_config)),
413 );
414 }
415
416 container = container.child(div().relative().child(chart_content));
418
419 Ok(container)
420 }
421}
422
423pub fn heatmap(z: &[f64], grid_width: usize, grid_height: usize) -> HeatmapChart {
464 HeatmapChart {
465 z: z.to_vec(),
466 grid_width,
467 grid_height,
468 x_values: None,
469 y_values: None,
470 x_scale_type: ScaleType::Linear,
471 y_scale_type: ScaleType::Linear,
472 color_scale: ColorScale::default(),
473 title: None,
474 opacity: 1.0,
475 width: DEFAULT_WIDTH,
476 height: DEFAULT_HEIGHT,
477 }
478}
479
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty `z` slice is reported as `EmptyData` for field `z`.
    #[test]
    fn test_heatmap_empty_z() {
        let outcome = heatmap(&[], 0, 0).build();
        assert!(matches!(outcome, Err(ChartError::EmptyData { field: "z" })));
    }

    /// Five values cannot fill a 2 x 3 grid.
    #[test]
    fn test_heatmap_grid_mismatch() {
        let grid = [1.0, 2.0, 3.0, 4.0, 5.0];
        let failure = heatmap(&grid, 2, 3).build().err();
        assert!(matches!(
            failure,
            Some(ChartError::GridDimensionMismatch {
                z_len: 5,
                width: 2,
                height: 3,
                expected: 6,
            })
        ));
    }

    /// Three x coordinates are rejected for a grid that is two columns wide.
    #[test]
    fn test_heatmap_x_length_mismatch() {
        let grid = [1.0; 6];
        let xs = [0.0, 1.0, 2.0];
        let failure = heatmap(&grid, 2, 3).x(&xs).build().err();
        assert!(matches!(
            failure,
            Some(ChartError::DataLengthMismatch {
                x_field: "x",
                y_field: "grid_width",
                x_len: 3,
                y_len: 2,
            })
        ));
    }

    /// Two y coordinates are rejected for a grid that is three rows high.
    #[test]
    fn test_heatmap_y_length_mismatch() {
        let grid = [1.0; 6];
        let ys = [0.0, 1.0];
        let failure = heatmap(&grid, 2, 3).y(&ys).build().err();
        assert!(matches!(
            failure,
            Some(ChartError::DataLengthMismatch {
                x_field: "y",
                y_field: "grid_height",
                x_len: 2,
                y_len: 3,
            })
        ));
    }

    /// Decreasing x coordinates fail the monotonicity check.
    #[test]
    fn test_heatmap_non_monotonic_x() {
        let grid = [1.0; 4];
        let xs = [1.0, 0.0];
        let failure = heatmap(&grid, 2, 2).x(&xs).build().err();
        assert!(matches!(
            failure,
            Some(ChartError::InvalidData {
                field: "x",
                reason: "must be strictly monotonically increasing"
            })
        ));
    }

    /// A negative x coordinate is rejected on a log-scaled x axis.
    #[test]
    fn test_heatmap_log_scale_negative() {
        let grid = [1.0; 4];
        let xs = [-1.0, 1.0];
        let failure = heatmap(&grid, 2, 2)
            .x(&xs)
            .x_scale(ScaleType::Log)
            .build()
            .err();
        assert!(matches!(
            failure,
            Some(ChartError::InvalidData {
                field: "x",
                reason: "log scale requires positive values"
            })
        ));
    }

    /// A matching grid with a title and color scale builds successfully.
    #[test]
    fn test_heatmap_successful_build() {
        let grid = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let chart = heatmap(&grid, 2, 3)
            .title("Test Heatmap")
            .color_scale(ColorScale::Viridis);
        assert!(chart.build().is_ok());
    }

    /// Explicit coordinates whose lengths match the grid are accepted.
    #[test]
    fn test_heatmap_with_custom_axes() {
        let grid = [1.0; 6];
        let xs = [10.0, 100.0];
        let ys = [0.0, 1.0, 2.0];
        let chart = heatmap(&grid, 2, 3).x(&xs).y(&ys);
        assert!(chart.build().is_ok());
    }

    /// Positive coordinates build with log scales on both axes.
    #[test]
    fn test_heatmap_log_scale() {
        let grid = [1.0; 4];
        let xs = [10.0, 100.0];
        let ys = [1.0, 10.0];
        let chart = heatmap(&grid, 2, 2)
            .x(&xs)
            .y(&ys)
            .x_scale(ScaleType::Log)
            .y_scale(ScaleType::Log);
        assert!(chart.build().is_ok());
    }

    /// The title, color scale, opacity and size setters can be chained.
    #[test]
    fn test_heatmap_builder_chain() {
        let grid = [1.0; 9];
        let chart = heatmap(&grid, 3, 3)
            .title("My Heatmap")
            .color_scale(ColorScale::Plasma)
            .opacity(0.8)
            .size(800.0, 600.0);
        assert!(chart.build().is_ok());
    }
}