gpui_px/
heatmap.rs

1//! Heatmap chart - Plotly Express style API.
2
3use crate::color_scale::ColorScale;
4use crate::error::ChartError;
5use crate::{
6    DEFAULT_HEIGHT, DEFAULT_TITLE_FONT_SIZE, DEFAULT_WIDTH, ScaleType, TITLE_AREA_HEIGHT,
7    extent_padded, validate_data_array, validate_dimensions, validate_grid_dimensions,
8    validate_monotonic, validate_positive,
9};
10use d3rs::axis::{AxisConfig, DefaultAxisTheme, render_axis};
11use d3rs::grid::{GridConfig, render_grid};
12use d3rs::scale::{LinearScale, LogScale};
13use d3rs::shape::{ContourConfig, HeatmapData, render_heatmap};
14use d3rs::text::{VectorFontConfig, render_vector_text};
15use gpui::prelude::*;
16use gpui::*;
17
18/// Heatmap chart builder.
19#[derive(Clone)]
20pub struct HeatmapChart {
21    z: Vec<f64>,
22    grid_width: usize,
23    grid_height: usize,
24    x_values: Option<Vec<f64>>,
25    y_values: Option<Vec<f64>>,
26    x_scale_type: ScaleType,
27    y_scale_type: ScaleType,
28    color_scale: ColorScale,
29    title: Option<String>,
30    opacity: f32,
31    width: f32,
32    height: f32,
33}
34
35impl std::fmt::Debug for HeatmapChart {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        f.debug_struct("HeatmapChart")
38            .field("grid_width", &self.grid_width)
39            .field("grid_height", &self.grid_height)
40            .field("x_scale_type", &self.x_scale_type)
41            .field("y_scale_type", &self.y_scale_type)
42            .field("color_scale", &self.color_scale)
43            .field("title", &self.title)
44            .field("opacity", &self.opacity)
45            .field("width", &self.width)
46            .field("height", &self.height)
47            .finish()
48    }
49}
50
51impl HeatmapChart {
52    /// Set custom x axis values.
53    ///
54    /// Values must be strictly monotonically increasing.
55    /// Length must match grid_width.
56    pub fn x(mut self, values: &[f64]) -> Self {
57        self.x_values = Some(values.to_vec());
58        self
59    }
60
61    /// Set custom y axis values.
62    ///
63    /// Values must be strictly monotonically increasing.
64    /// Length must match grid_height.
65    pub fn y(mut self, values: &[f64]) -> Self {
66        self.y_values = Some(values.to_vec());
67        self
68    }
69
70    /// Set x-axis scale type.
71    pub fn x_scale(mut self, scale: ScaleType) -> Self {
72        self.x_scale_type = scale;
73        self
74    }
75
76    /// Set y-axis scale type.
77    pub fn y_scale(mut self, scale: ScaleType) -> Self {
78        self.y_scale_type = scale;
79        self
80    }
81
82    /// Set color scale.
83    pub fn color_scale(mut self, scale: ColorScale) -> Self {
84        self.color_scale = scale;
85        self
86    }
87
88    /// Set chart title (rendered at top of chart).
89    pub fn title(mut self, title: impl Into<String>) -> Self {
90        self.title = Some(title.into());
91        self
92    }
93
94    /// Set fill opacity (0.0 - 1.0).
95    pub fn opacity(mut self, opacity: f32) -> Self {
96        self.opacity = opacity.clamp(0.0, 1.0);
97        self
98    }
99
100    /// Set chart dimensions.
101    pub fn size(mut self, width: f32, height: f32) -> Self {
102        self.width = width;
103        self.height = height;
104        self
105    }
106
107    /// Build and validate the chart, returning renderable element.
108    pub fn build(self) -> Result<impl IntoElement, ChartError> {
109        // Validate inputs
110        validate_data_array(&self.z, "z")?;
111        validate_grid_dimensions(&self.z, self.grid_width, self.grid_height)?;
112        validate_dimensions(self.width, self.height)?;
113
114        // Generate or validate x values
115        let x_values = match self.x_values {
116            Some(ref v) => {
117                if v.len() != self.grid_width {
118                    return Err(ChartError::DataLengthMismatch {
119                        x_field: "x",
120                        y_field: "grid_width",
121                        x_len: v.len(),
122                        y_len: self.grid_width,
123                    });
124                }
125                validate_data_array(v, "x")?;
126                validate_monotonic(v, "x")?;
127                if self.x_scale_type == ScaleType::Log {
128                    validate_positive(v, "x")?;
129                }
130                v.clone()
131            }
132            None => (0..self.grid_width).map(|i| i as f64).collect(),
133        };
134
135        // Generate or validate y values
136        let y_values = match self.y_values {
137            Some(ref v) => {
138                if v.len() != self.grid_height {
139                    return Err(ChartError::DataLengthMismatch {
140                        x_field: "y",
141                        y_field: "grid_height",
142                        x_len: v.len(),
143                        y_len: self.grid_height,
144                    });
145                }
146                validate_data_array(v, "y")?;
147                validate_monotonic(v, "y")?;
148                if self.y_scale_type == ScaleType::Log {
149                    validate_positive(v, "y")?;
150                }
151                v.clone()
152            }
153            None => (0..self.grid_height).map(|i| i as f64).collect(),
154        };
155
156        // Define margins
157        let margin_left = 50.0;
158        let margin_bottom = 30.0;
159        let margin_top = 10.0;
160        let margin_right = 20.0;
161
162        // Calculate plot area (reserve space for title if present)
163        let title_height = if self.title.is_some() {
164            TITLE_AREA_HEIGHT
165        } else {
166            0.0
167        };
168
169        let plot_width = (self.width as f64 - margin_left - margin_right).max(0.0);
170        let plot_height =
171            (self.height as f64 - title_height as f64 - margin_top - margin_bottom).max(0.0);
172
173        // Calculate domains with padding
174        let (x_min, x_max) = extent_padded(&x_values, 0.0);
175        let (y_min, y_max) = extent_padded(&y_values, 0.0);
176
177        // Create HeatmapData
178        let heatmap_data = HeatmapData::new(x_values, y_values, self.z.clone());
179
180        // Build config with color scale
181        let color_fn = self.color_scale.to_fn();
182        let config = ContourConfig::new()
183            .fill(true)
184            .fill_opacity(self.opacity)
185            .color_scale(color_fn);
186
187        let theme = DefaultAxisTheme;
188
189        // Build the element based on scale types
190        let chart_content: AnyElement = match (self.x_scale_type, self.y_scale_type) {
191            (ScaleType::Linear, ScaleType::Linear) => {
192                let x_scale = LinearScale::new()
193                    .domain(x_min, x_max)
194                    .range(0.0, plot_width);
195                let y_scale = LinearScale::new()
196                    .domain(y_min, y_max)
197                    .range(plot_height, 0.0);
198
199                div()
200                    .flex()
201                    .child(render_axis(
202                        &y_scale,
203                        &AxisConfig::left(),
204                        plot_height as f32,
205                        &theme,
206                    ))
207                    .child(
208                        div()
209                            .flex()
210                            .flex_col()
211                            .child(
212                                div()
213                                    .w(px(plot_width as f32))
214                                    .h(px(plot_height as f32))
215                                    .relative()
216                                    .bg(rgb(0xf8f8f8))
217                                    .child(render_grid(
218                                        &x_scale,
219                                        &y_scale,
220                                        &GridConfig::default(),
221                                        plot_width as f32,
222                                        plot_height as f32,
223                                        &theme,
224                                    ))
225                                    .child(div().absolute().inset_0().child(render_heatmap(
226                                        heatmap_data,
227                                        &x_scale,
228                                        &y_scale,
229                                        &config,
230                                    ))),
231                            )
232                            .child(render_axis(
233                                &x_scale,
234                                &AxisConfig::bottom(),
235                                plot_width as f32,
236                                &theme,
237                            )),
238                    )
239                    .into_any_element()
240            }
241            (ScaleType::Log, ScaleType::Linear) => {
242                let x_scale = LogScale::new()
243                    .domain(x_min.max(1e-10), x_max)
244                    .range(0.0, plot_width);
245                let y_scale = LinearScale::new()
246                    .domain(y_min, y_max)
247                    .range(plot_height, 0.0);
248
249                div()
250                    .flex()
251                    .child(render_axis(
252                        &y_scale,
253                        &AxisConfig::left(),
254                        plot_height as f32,
255                        &theme,
256                    ))
257                    .child(
258                        div()
259                            .flex()
260                            .flex_col()
261                            .child(
262                                div()
263                                    .w(px(plot_width as f32))
264                                    .h(px(plot_height as f32))
265                                    .relative()
266                                    .bg(rgb(0xf8f8f8))
267                                    .child(render_grid(
268                                        &x_scale,
269                                        &y_scale,
270                                        &GridConfig::default(),
271                                        plot_width as f32,
272                                        plot_height as f32,
273                                        &theme,
274                                    ))
275                                    .child(div().absolute().inset_0().child(render_heatmap(
276                                        heatmap_data,
277                                        &x_scale,
278                                        &y_scale,
279                                        &config,
280                                    ))),
281                            )
282                            .child(render_axis(
283                                &x_scale,
284                                &AxisConfig::bottom(),
285                                plot_width as f32,
286                                &theme,
287                            )),
288                    )
289                    .into_any_element()
290            }
291            (ScaleType::Linear, ScaleType::Log) => {
292                let x_scale = LinearScale::new()
293                    .domain(x_min, x_max)
294                    .range(0.0, plot_width);
295                let y_scale = LogScale::new()
296                    .domain(y_min.max(1e-10), y_max)
297                    .range(plot_height, 0.0);
298
299                div()
300                    .flex()
301                    .child(render_axis(
302                        &y_scale,
303                        &AxisConfig::left(),
304                        plot_height as f32,
305                        &theme,
306                    ))
307                    .child(
308                        div()
309                            .flex()
310                            .flex_col()
311                            .child(
312                                div()
313                                    .w(px(plot_width as f32))
314                                    .h(px(plot_height as f32))
315                                    .relative()
316                                    .bg(rgb(0xf8f8f8))
317                                    .child(render_grid(
318                                        &x_scale,
319                                        &y_scale,
320                                        &GridConfig::default(),
321                                        plot_width as f32,
322                                        plot_height as f32,
323                                        &theme,
324                                    ))
325                                    .child(div().absolute().inset_0().child(render_heatmap(
326                                        heatmap_data,
327                                        &x_scale,
328                                        &y_scale,
329                                        &config,
330                                    ))),
331                            )
332                            .child(render_axis(
333                                &x_scale,
334                                &AxisConfig::bottom(),
335                                plot_width as f32,
336                                &theme,
337                            )),
338                    )
339                    .into_any_element()
340            }
341            (ScaleType::Log, ScaleType::Log) => {
342                let x_scale = LogScale::new()
343                    .domain(x_min.max(1e-10), x_max)
344                    .range(0.0, plot_width);
345                let y_scale = LogScale::new()
346                    .domain(y_min.max(1e-10), y_max)
347                    .range(plot_height, 0.0);
348
349                div()
350                    .flex()
351                    .child(render_axis(
352                        &y_scale,
353                        &AxisConfig::left(),
354                        plot_height as f32,
355                        &theme,
356                    ))
357                    .child(
358                        div()
359                            .flex()
360                            .flex_col()
361                            .child(
362                                div()
363                                    .w(px(plot_width as f32))
364                                    .h(px(plot_height as f32))
365                                    .relative()
366                                    .bg(rgb(0xf8f8f8))
367                                    .child(render_grid(
368                                        &x_scale,
369                                        &y_scale,
370                                        &GridConfig::default(),
371                                        plot_width as f32,
372                                        plot_height as f32,
373                                        &theme,
374                                    ))
375                                    .child(div().absolute().inset_0().child(render_heatmap(
376                                        heatmap_data,
377                                        &x_scale,
378                                        &y_scale,
379                                        &config,
380                                    ))),
381                            )
382                            .child(render_axis(
383                                &x_scale,
384                                &AxisConfig::bottom(),
385                                plot_width as f32,
386                                &theme,
387                            )),
388                    )
389                    .into_any_element()
390            }
391        };
392
393        // Build container with optional title
394        let mut container = div()
395            .w(px(self.width))
396            .h(px(self.height))
397            .relative()
398            .flex()
399            .flex_col();
400
401        // Add title if present
402        if let Some(title) = &self.title {
403            let font_config =
404                VectorFontConfig::horizontal(DEFAULT_TITLE_FONT_SIZE, hsla(0.0, 0.0, 0.2, 1.0));
405            container = container.child(
406                div()
407                    .w_full()
408                    .h(px(title_height))
409                    .flex()
410                    .justify_center()
411                    .items_center()
412                    .child(render_vector_text(title, &font_config)),
413            );
414        }
415
416        // Add chart content
417        container = container.child(div().relative().child(chart_content));
418
419        Ok(container)
420    }
421}
422
423/// Create a heatmap chart from z data with grid dimensions.
424///
425/// Data is in row-major order: `z[row * width + col]` where row 0 is at the bottom.
426///
427/// # Example
428///
429/// ```rust,no_run
430/// use gpui_px::{heatmap, ColorScale, ScaleType};
431///
432/// // 3x3 grid
433/// let z = vec![
434///     1.0, 2.0, 3.0,  // row 0 (bottom)
435///     4.0, 5.0, 6.0,  // row 1
436///     7.0, 8.0, 9.0,  // row 2 (top)
437/// ];
438///
439/// let chart = heatmap(&z, 3, 3)
440///     .title("My Heatmap")
441///     .color_scale(ColorScale::Inferno)
442///     .build()?;
443/// # Ok::<(), gpui_px::ChartError>(())
444/// ```
445///
446/// # With custom axes
447///
448/// ```rust,no_run
449/// use gpui_px::{heatmap, ColorScale, ScaleType};
450///
451/// let freq_bins = vec![20.0, 100.0, 1000.0, 10000.0, 20000.0];
452/// let time_bins = vec![0.0, 1.0, 2.0, 3.0];
453/// let z = vec![0.0; 20]; // 5x4 grid
454///
455/// let chart = heatmap(&z, 5, 4)
456///     .x(&freq_bins)
457///     .y(&time_bins)
458///     .x_scale(ScaleType::Log)
459///     .color_scale(ColorScale::Viridis)
460///     .build()?;
461/// # Ok::<(), gpui_px::ChartError>(())
462/// ```
463pub fn heatmap(z: &[f64], grid_width: usize, grid_height: usize) -> HeatmapChart {
464    HeatmapChart {
465        z: z.to_vec(),
466        grid_width,
467        grid_height,
468        x_values: None,
469        y_values: None,
470        x_scale_type: ScaleType::Linear,
471        y_scale_type: ScaleType::Linear,
472        color_scale: ColorScale::default(),
473        title: None,
474        opacity: 1.0,
475        width: DEFAULT_WIDTH,
476        height: DEFAULT_HEIGHT,
477    }
478}
479
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty z slice is reported as empty data for field "z".
    #[test]
    fn test_heatmap_empty_z() {
        let no_cells: [f64; 0] = [];
        assert!(matches!(
            heatmap(&no_cells, 0, 0).build(),
            Err(ChartError::EmptyData { field: "z" })
        ));
    }

    /// Five values cannot fill a 2x3 grid, which needs six.
    #[test]
    fn test_heatmap_grid_mismatch() {
        let cells = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert!(matches!(
            heatmap(&cells, 2, 3).build(),
            Err(ChartError::GridDimensionMismatch {
                z_len: 5,
                width: 2,
                height: 3,
                expected: 6,
            })
        ));
    }

    /// Three x values are rejected for a grid that is two columns wide.
    #[test]
    fn test_heatmap_x_length_mismatch() {
        let cells = [1.0; 6];
        let xs = [0.0, 1.0, 2.0];
        assert!(matches!(
            heatmap(&cells, 2, 3).x(&xs).build(),
            Err(ChartError::DataLengthMismatch {
                x_field: "x",
                y_field: "grid_width",
                x_len: 3,
                y_len: 2,
            })
        ));
    }

    /// Two y values are rejected for a grid that is three rows tall.
    #[test]
    fn test_heatmap_y_length_mismatch() {
        let cells = [1.0; 6];
        let ys = [0.0, 1.0];
        assert!(matches!(
            heatmap(&cells, 2, 3).y(&ys).build(),
            Err(ChartError::DataLengthMismatch {
                x_field: "y",
                y_field: "grid_height",
                x_len: 2,
                y_len: 3,
            })
        ));
    }

    /// Decreasing x values fail the monotonicity check.
    #[test]
    fn test_heatmap_non_monotonic_x() {
        let cells = [1.0; 4];
        let xs = [1.0, 0.0];
        assert!(matches!(
            heatmap(&cells, 2, 2).x(&xs).build(),
            Err(ChartError::InvalidData {
                field: "x",
                reason: "must be strictly monotonically increasing"
            })
        ));
    }

    /// A log x scale refuses axis values that are not positive.
    #[test]
    fn test_heatmap_log_scale_negative() {
        let cells = [1.0; 4];
        let xs = [-1.0, 1.0];
        assert!(matches!(
            heatmap(&cells, 2, 2).x(&xs).x_scale(ScaleType::Log).build(),
            Err(ChartError::InvalidData {
                field: "x",
                reason: "log scale requires positive values"
            })
        ));
    }

    /// A well-formed 2x3 grid with a title and color scale builds.
    #[test]
    fn test_heatmap_successful_build() {
        let cells = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let built = heatmap(&cells, 2, 3)
            .title("Test Heatmap")
            .color_scale(ColorScale::Viridis)
            .build();
        assert!(built.is_ok());
    }

    /// Custom axes whose lengths match the grid are accepted.
    #[test]
    fn test_heatmap_with_custom_axes() {
        let cells = [1.0; 6];
        let xs = [10.0, 100.0];
        let ys = [0.0, 1.0, 2.0];
        assert!(heatmap(&cells, 2, 3).x(&xs).y(&ys).build().is_ok());
    }

    /// Positive axis values build with log scales on both axes.
    #[test]
    fn test_heatmap_log_scale() {
        let cells = [1.0; 4];
        let xs = [10.0, 100.0];
        let ys = [1.0, 10.0];
        let built = heatmap(&cells, 2, 2)
            .x(&xs)
            .y(&ys)
            .x_scale(ScaleType::Log)
            .y_scale(ScaleType::Log)
            .build();
        assert!(built.is_ok());
    }

    /// The appearance and size setters chain together into a valid chart.
    #[test]
    fn test_heatmap_builder_chain() {
        let cells = [1.0; 9];
        let built = heatmap(&cells, 3, 3)
            .title("My Heatmap")
            .color_scale(ColorScale::Plasma)
            .opacity(0.8)
            .size(800.0, 600.0)
            .build();
        assert!(built.is_ok());
    }
}