gpui_px/
heatmap.rs

1//! Heatmap chart - Plotly Express style API.
2
3use crate::color_scale::ColorScale;
4use crate::error::ChartError;
5use crate::{
6    DEFAULT_HEIGHT, DEFAULT_TITLE_FONT_SIZE, DEFAULT_WIDTH, ScaleType, TITLE_AREA_HEIGHT,
7    extent_padded, validate_data_array, validate_dimensions, validate_grid_dimensions,
8    validate_monotonic, validate_positive,
9};
10use d3rs::scale::{LinearScale, LogScale};
11use d3rs::shape::{ContourConfig, HeatmapData, render_heatmap};
12use d3rs::text::{VectorFontConfig, render_vector_text};
13use gpui::prelude::*;
14use gpui::*;
15
16/// Heatmap chart builder.
17#[derive(Clone)]
18pub struct HeatmapChart {
19    z: Vec<f64>,
20    grid_width: usize,
21    grid_height: usize,
22    x_values: Option<Vec<f64>>,
23    y_values: Option<Vec<f64>>,
24    x_scale_type: ScaleType,
25    y_scale_type: ScaleType,
26    color_scale: ColorScale,
27    title: Option<String>,
28    opacity: f32,
29    width: f32,
30    height: f32,
31}
32
33impl std::fmt::Debug for HeatmapChart {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("HeatmapChart")
36            .field("grid_width", &self.grid_width)
37            .field("grid_height", &self.grid_height)
38            .field("x_scale_type", &self.x_scale_type)
39            .field("y_scale_type", &self.y_scale_type)
40            .field("color_scale", &self.color_scale)
41            .field("title", &self.title)
42            .field("opacity", &self.opacity)
43            .field("width", &self.width)
44            .field("height", &self.height)
45            .finish()
46    }
47}
48
49impl HeatmapChart {
50    /// Set custom x axis values.
51    ///
52    /// Values must be strictly monotonically increasing.
53    /// Length must match grid_width.
54    pub fn x(mut self, values: &[f64]) -> Self {
55        self.x_values = Some(values.to_vec());
56        self
57    }
58
59    /// Set custom y axis values.
60    ///
61    /// Values must be strictly monotonically increasing.
62    /// Length must match grid_height.
63    pub fn y(mut self, values: &[f64]) -> Self {
64        self.y_values = Some(values.to_vec());
65        self
66    }
67
68    /// Set x-axis scale type.
69    pub fn x_scale(mut self, scale: ScaleType) -> Self {
70        self.x_scale_type = scale;
71        self
72    }
73
74    /// Set y-axis scale type.
75    pub fn y_scale(mut self, scale: ScaleType) -> Self {
76        self.y_scale_type = scale;
77        self
78    }
79
80    /// Set color scale.
81    pub fn color_scale(mut self, scale: ColorScale) -> Self {
82        self.color_scale = scale;
83        self
84    }
85
86    /// Set chart title (rendered at top of chart).
87    pub fn title(mut self, title: impl Into<String>) -> Self {
88        self.title = Some(title.into());
89        self
90    }
91
92    /// Set fill opacity (0.0 - 1.0).
93    pub fn opacity(mut self, opacity: f32) -> Self {
94        self.opacity = opacity.clamp(0.0, 1.0);
95        self
96    }
97
98    /// Set chart dimensions.
99    pub fn size(mut self, width: f32, height: f32) -> Self {
100        self.width = width;
101        self.height = height;
102        self
103    }
104
105    /// Build and validate the chart, returning renderable element.
106    pub fn build(self) -> Result<impl IntoElement, ChartError> {
107        // Validate inputs
108        validate_data_array(&self.z, "z")?;
109        validate_grid_dimensions(&self.z, self.grid_width, self.grid_height)?;
110        validate_dimensions(self.width, self.height)?;
111
112        // Generate or validate x values
113        let x_values = match self.x_values {
114            Some(ref v) => {
115                if v.len() != self.grid_width {
116                    return Err(ChartError::DataLengthMismatch {
117                        x_field: "x",
118                        y_field: "grid_width",
119                        x_len: v.len(),
120                        y_len: self.grid_width,
121                    });
122                }
123                validate_data_array(v, "x")?;
124                validate_monotonic(v, "x")?;
125                if self.x_scale_type == ScaleType::Log {
126                    validate_positive(v, "x")?;
127                }
128                v.clone()
129            }
130            None => (0..self.grid_width).map(|i| i as f64).collect(),
131        };
132
133        // Generate or validate y values
134        let y_values = match self.y_values {
135            Some(ref v) => {
136                if v.len() != self.grid_height {
137                    return Err(ChartError::DataLengthMismatch {
138                        x_field: "y",
139                        y_field: "grid_height",
140                        x_len: v.len(),
141                        y_len: self.grid_height,
142                    });
143                }
144                validate_data_array(v, "y")?;
145                validate_monotonic(v, "y")?;
146                if self.y_scale_type == ScaleType::Log {
147                    validate_positive(v, "y")?;
148                }
149                v.clone()
150            }
151            None => (0..self.grid_height).map(|i| i as f64).collect(),
152        };
153
154        // Calculate plot area (reserve space for title if present)
155        let title_height = if self.title.is_some() {
156            TITLE_AREA_HEIGHT
157        } else {
158            0.0
159        };
160        let plot_height = self.height - title_height;
161
162        // Calculate domains with padding
163        let (x_min, x_max) = extent_padded(&x_values, 0.0);
164        let (y_min, y_max) = extent_padded(&y_values, 0.0);
165
166        // Create HeatmapData
167        let heatmap_data = HeatmapData::new(x_values, y_values, self.z.clone());
168
169        // Build config with color scale
170        let color_fn = self.color_scale.to_fn();
171        let config = ContourConfig::new()
172            .fill(true)
173            .fill_opacity(self.opacity)
174            .color_scale(color_fn);
175
176        // Build the element based on scale types
177        let heatmap_element: AnyElement = match (self.x_scale_type, self.y_scale_type) {
178            (ScaleType::Linear, ScaleType::Linear) => {
179                let x_scale = LinearScale::new()
180                    .domain(x_min, x_max)
181                    .range(0.0, self.width as f64);
182                let y_scale = LinearScale::new()
183                    .domain(y_min, y_max)
184                    .range(plot_height as f64, 0.0);
185                render_heatmap(heatmap_data, &x_scale, &y_scale, &config).into_any_element()
186            }
187            (ScaleType::Log, ScaleType::Linear) => {
188                let x_scale = LogScale::new()
189                    .domain(x_min.max(1e-10), x_max)
190                    .range(0.0, self.width as f64);
191                let y_scale = LinearScale::new()
192                    .domain(y_min, y_max)
193                    .range(plot_height as f64, 0.0);
194                render_heatmap(heatmap_data, &x_scale, &y_scale, &config).into_any_element()
195            }
196            (ScaleType::Linear, ScaleType::Log) => {
197                let x_scale = LinearScale::new()
198                    .domain(x_min, x_max)
199                    .range(0.0, self.width as f64);
200                let y_scale = LogScale::new()
201                    .domain(y_min.max(1e-10), y_max)
202                    .range(plot_height as f64, 0.0);
203                render_heatmap(heatmap_data, &x_scale, &y_scale, &config).into_any_element()
204            }
205            (ScaleType::Log, ScaleType::Log) => {
206                let x_scale = LogScale::new()
207                    .domain(x_min.max(1e-10), x_max)
208                    .range(0.0, self.width as f64);
209                let y_scale = LogScale::new()
210                    .domain(y_min.max(1e-10), y_max)
211                    .range(plot_height as f64, 0.0);
212                render_heatmap(heatmap_data, &x_scale, &y_scale, &config).into_any_element()
213            }
214        };
215
216        // Build container with optional title
217        let mut container = div()
218            .w(px(self.width))
219            .h(px(self.height))
220            .relative()
221            .flex()
222            .flex_col();
223
224        // Add title if present
225        if let Some(title) = &self.title {
226            let font_config =
227                VectorFontConfig::horizontal(DEFAULT_TITLE_FONT_SIZE, hsla(0.0, 0.0, 0.2, 1.0));
228            container = container.child(
229                div()
230                    .w_full()
231                    .h(px(title_height))
232                    .flex()
233                    .justify_center()
234                    .items_center()
235                    .child(render_vector_text(title, &font_config)),
236            );
237        }
238
239        // Add plot area
240        container = container.child(
241            div()
242                .w(px(self.width))
243                .h(px(plot_height))
244                .relative()
245                .child(heatmap_element),
246        );
247
248        Ok(container)
249    }
250}
251
252/// Create a heatmap chart from z data with grid dimensions.
253///
254/// Data is in row-major order: `z[row * width + col]` where row 0 is at the bottom.
255///
256/// # Example
257///
258/// ```rust,no_run
259/// use gpui_px::{heatmap, ColorScale, ScaleType};
260///
261/// // 3x3 grid
262/// let z = vec![
263///     1.0, 2.0, 3.0,  // row 0 (bottom)
264///     4.0, 5.0, 6.0,  // row 1
265///     7.0, 8.0, 9.0,  // row 2 (top)
266/// ];
267///
268/// let chart = heatmap(&z, 3, 3)
269///     .title("My Heatmap")
270///     .color_scale(ColorScale::Inferno)
271///     .build()?;
272/// # Ok::<(), gpui_px::ChartError>(())
273/// ```
274///
275/// # With custom axes
276///
277/// ```rust,no_run
278/// use gpui_px::{heatmap, ColorScale, ScaleType};
279///
280/// let freq_bins = vec![20.0, 100.0, 1000.0, 10000.0, 20000.0];
281/// let time_bins = vec![0.0, 1.0, 2.0, 3.0];
282/// let z = vec![0.0; 20]; // 5x4 grid
283///
284/// let chart = heatmap(&z, 5, 4)
285///     .x(&freq_bins)
286///     .y(&time_bins)
287///     .x_scale(ScaleType::Log)
288///     .color_scale(ColorScale::Viridis)
289///     .build()?;
290/// # Ok::<(), gpui_px::ChartError>(())
291/// ```
292pub fn heatmap(z: &[f64], grid_width: usize, grid_height: usize) -> HeatmapChart {
293    HeatmapChart {
294        z: z.to_vec(),
295        grid_width,
296        grid_height,
297        x_values: None,
298        y_values: None,
299        x_scale_type: ScaleType::Linear,
300        y_scale_type: ScaleType::Linear,
301        color_scale: ColorScale::default(),
302        title: None,
303        opacity: 1.0,
304        width: DEFAULT_WIDTH,
305        height: DEFAULT_HEIGHT,
306    }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `build` on a configured chart and reports whether it succeeded.
    fn builds(chart: HeatmapChart) -> bool {
        chart.build().is_ok()
    }

    #[test]
    fn test_heatmap_empty_z() {
        let no_cells: [f64; 0] = [];
        let outcome = heatmap(&no_cells, 0, 0).build();
        assert!(matches!(outcome, Err(ChartError::EmptyData { field: "z" })));
    }

    #[test]
    fn test_heatmap_grid_mismatch() {
        // Five values cannot fill a 2x3 grid, which needs six.
        let cells = [1.0, 2.0, 3.0, 4.0, 5.0];
        let outcome = heatmap(&cells, 2, 3).build();
        assert!(matches!(
            outcome,
            Err(ChartError::GridDimensionMismatch {
                z_len: 5,
                width: 2,
                height: 3,
                expected: 6,
            })
        ));
    }

    #[test]
    fn test_heatmap_x_length_mismatch() {
        // A 2x3 grid needs two x values, not three.
        let cells = [1.0; 6];
        let xs = [0.0, 1.0, 2.0];
        let outcome = heatmap(&cells, 2, 3).x(&xs).build();
        assert!(matches!(
            outcome,
            Err(ChartError::DataLengthMismatch {
                x_field: "x",
                y_field: "grid_width",
                x_len: 3,
                y_len: 2,
            })
        ));
    }

    #[test]
    fn test_heatmap_y_length_mismatch() {
        // A 2x3 grid needs three y values, not two.
        let cells = [1.0; 6];
        let ys = [0.0, 1.0];
        let outcome = heatmap(&cells, 2, 3).y(&ys).build();
        assert!(matches!(
            outcome,
            Err(ChartError::DataLengthMismatch {
                x_field: "y",
                y_field: "grid_height",
                x_len: 2,
                y_len: 3,
            })
        ));
    }

    #[test]
    fn test_heatmap_non_monotonic_x() {
        let cells = [1.0; 4];
        let descending = [1.0, 0.0];
        let outcome = heatmap(&cells, 2, 2).x(&descending).build();
        assert!(matches!(
            outcome,
            Err(ChartError::InvalidData {
                field: "x",
                reason: "must be strictly monotonically increasing"
            })
        ));
    }

    #[test]
    fn test_heatmap_log_scale_negative() {
        let cells = [1.0; 4];
        let xs = [-1.0, 1.0];
        let outcome = heatmap(&cells, 2, 2)
            .x(&xs)
            .x_scale(ScaleType::Log)
            .build();
        assert!(matches!(
            outcome,
            Err(ChartError::InvalidData {
                field: "x",
                reason: "log scale requires positive values"
            })
        ));
    }

    #[test]
    fn test_heatmap_successful_build() {
        let cells = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let chart = heatmap(&cells, 2, 3)
            .title("Test Heatmap")
            .color_scale(ColorScale::Viridis);
        assert!(builds(chart));
    }

    #[test]
    fn test_heatmap_with_custom_axes() {
        let cells = [1.0; 6];
        let xs = [10.0, 100.0];
        let ys = [0.0, 1.0, 2.0];
        assert!(builds(heatmap(&cells, 2, 3).x(&xs).y(&ys)));
    }

    #[test]
    fn test_heatmap_log_scale() {
        let cells = [1.0; 4];
        let xs = [10.0, 100.0];
        let ys = [1.0, 10.0];
        let chart = heatmap(&cells, 2, 2)
            .x(&xs)
            .y(&ys)
            .x_scale(ScaleType::Log)
            .y_scale(ScaleType::Log);
        assert!(builds(chart));
    }

    #[test]
    fn test_heatmap_builder_chain() {
        let cells = [1.0; 9];
        let chart = heatmap(&cells, 3, 3)
            .title("My Heatmap")
            .color_scale(ColorScale::Plasma)
            .opacity(0.8)
            .size(800.0, 600.0);
        assert!(builds(chart));
    }
}