Skip to main content

plotlars_core/plots/
heatmap.rs

1use bon::bon;
2
3use crate::{
4    components::{Axis, ColorBar, FacetConfig, FacetScales, Palette, Text},
5    ir::data::ColumnData,
6    ir::layout::LayoutIR,
7    ir::trace::{HeatMapIR, TraceIR},
8};
9use polars::frame::DataFrame;
10
11/// A structure representing a heat map.
12///
13/// The `HeatMap` struct enables the creation of heat map visualizations with options for color scaling,
14/// axis customization, legend adjustments, and data value formatting. Users can customize the color
15/// scale, adjust the color bar, and set titles for the plot and axes, as well as format ticks and scales
16/// for improved data readability.
17///
18/// # Backend Support
19///
20/// | Backend | Supported |
21/// |---------|-----------|
22/// | Plotly  | Yes       |
23/// | Plotters| Yes       |
24///
25/// # Arguments
26///
27/// * `data` - A reference to the `DataFrame` containing the data to be plotted.
28/// * `x` - A string slice specifying the column name for x-axis values.
29/// * `y` - A string slice specifying the column name for y-axis values.
30/// * `z` - A string slice specifying the column name for z-axis values, which are represented by the color intensity.
31/// * `facet` - An optional string slice specifying the column name to be used for faceting (creating multiple subplots).
32/// * `facet_config` - An optional reference to a `FacetConfig` struct for customizing facet behavior (grid dimensions, scales, gaps, etc.).
33/// * `auto_color_scale` - An optional boolean for enabling automatic color scaling based on data.
34/// * `color_bar` - An optional reference to a `ColorBar` struct for customizing the color bar appearance.
35/// * `color_scale` - An optional `Palette` enum for specifying the color scale (e.g., Viridis).
36/// * `reverse_scale` - An optional boolean to reverse the color scale direction.
37/// * `show_scale` - An optional boolean to display the color scale on the plot.
38/// * `plot_title` - An optional `Text` struct for setting the title of the plot.
39/// * `x_title` - An optional `Text` struct for labeling the x-axis.
40/// * `y_title` - An optional `Text` struct for labeling the y-axis.
41/// * `x_axis` - An optional reference to an `Axis` struct for customizing x-axis appearance.
42/// * `y_axis` - An optional reference to an `Axis` struct for customizing y-axis appearance.
43///
44/// # Example
45///
46/// ```rust
47/// use plotlars::{ColorBar, HeatMap, Palette, Plot, Text, ValueExponent};
48/// use polars::prelude::*;
49///
50/// let dataset = LazyCsvReader::new(PlRefPath::new("data/heatmap.csv"))
51///     .finish()
52///     .unwrap()
53///     .collect()
54///     .unwrap();
55///
56/// HeatMap::builder()
57///     .data(&dataset)
58///     .x("x")
59///     .y("y")
60///     .z("z")
61///     .color_bar(
62///         &ColorBar::new()
63///             .length(0.7)
64///             .value_exponent(ValueExponent::None)
65///             .separate_thousands(true)
66///             .tick_length(5)
67///             .tick_step(2500.0)
68///     )
69///     .plot_title(
70///         Text::from("Heat Map")
71///             .font("Arial")
72///             .size(18)
73///     )
74///     .color_scale(Palette::Viridis)
75///     .build()
76///     .plot();
77/// ```
78///
79/// ![Example](https://imgur.com/5uFih4M.png)
80#[derive(Clone)]
81#[allow(dead_code)]
82pub struct HeatMap {
83    traces: Vec<TraceIR>,
84    layout: LayoutIR,
85}
86
87#[bon]
88impl HeatMap {
89    #[builder(on(String, into), on(Text, into))]
90    pub fn new(
91        data: &DataFrame,
92        x: &str,
93        y: &str,
94        z: &str,
95        facet: Option<&str>,
96        facet_config: Option<&FacetConfig>,
97        auto_color_scale: Option<bool>,
98        color_bar: Option<&ColorBar>,
99        color_scale: Option<Palette>,
100        reverse_scale: Option<bool>,
101        show_scale: Option<bool>,
102        plot_title: Option<Text>,
103        x_title: Option<Text>,
104        y_title: Option<Text>,
105        x_axis: Option<&Axis>,
106        y_axis: Option<&Axis>,
107    ) -> Self {
108        let grid = facet.map(|facet_column| {
109            let config = facet_config.cloned().unwrap_or_default();
110            let facet_categories =
111                crate::data::get_unique_groups(data, facet_column, config.sorter);
112            let n_facets = facet_categories.len();
113            let (ncols, nrows) =
114                crate::faceting::calculate_grid_dimensions(n_facets, config.cols, config.rows);
115            crate::ir::facet::GridSpec {
116                kind: crate::ir::facet::FacetKind::Axis,
117                rows: nrows,
118                cols: ncols,
119                h_gap: config.h_gap,
120                v_gap: config.v_gap,
121                scales: config.scales.clone(),
122                n_facets,
123                facet_categories,
124                title_style: config.title_style.clone(),
125                x_title: x_title.clone(),
126                y_title: y_title.clone(),
127                x_axis: x_axis.cloned(),
128                y_axis: y_axis.cloned(),
129                legend_title: None,
130                legend: None,
131            }
132        });
133
134        let layout = LayoutIR {
135            title: plot_title.clone(),
136            x_title: if grid.is_some() {
137                None
138            } else {
139                x_title.clone()
140            },
141            y_title: if grid.is_some() {
142                None
143            } else {
144                y_title.clone()
145            },
146            y2_title: None,
147            z_title: None,
148            legend_title: None,
149            legend: None,
150            dimensions: None,
151            bar_mode: None,
152            box_mode: None,
153            box_gap: None,
154            margin_bottom: None,
155            axes_2d: if grid.is_some() {
156                None
157            } else {
158                Some(crate::ir::layout::Axes2dIR {
159                    x_axis: x_axis.cloned(),
160                    y_axis: y_axis.cloned(),
161                    y2_axis: None,
162                })
163            },
164            scene_3d: None,
165            polar: None,
166            mapbox: None,
167            grid,
168            annotations: vec![],
169        };
170
171        let traces = match facet {
172            Some(facet_column) => {
173                let config = facet_config.cloned().unwrap_or_default();
174
175                Self::create_ir_traces_faceted(
176                    data,
177                    x,
178                    y,
179                    z,
180                    facet_column,
181                    &config,
182                    auto_color_scale,
183                    color_bar,
184                    color_scale,
185                    reverse_scale,
186                    show_scale,
187                )
188            }
189            None => Self::create_ir_traces(
190                data,
191                x,
192                y,
193                z,
194                auto_color_scale,
195                color_bar,
196                color_scale,
197                reverse_scale,
198                show_scale,
199            ),
200        };
201
202        Self { traces, layout }
203    }
204}
205
206#[bon]
207impl HeatMap {
208    #[builder(
209        start_fn = try_builder,
210        finish_fn = try_build,
211        builder_type = HeatMapTryBuilder,
212        on(String, into),
213        on(Text, into),
214    )]
215    pub fn try_new(
216        data: &DataFrame,
217        x: &str,
218        y: &str,
219        z: &str,
220        facet: Option<&str>,
221        facet_config: Option<&FacetConfig>,
222        auto_color_scale: Option<bool>,
223        color_bar: Option<&ColorBar>,
224        color_scale: Option<Palette>,
225        reverse_scale: Option<bool>,
226        show_scale: Option<bool>,
227        plot_title: Option<Text>,
228        x_title: Option<Text>,
229        y_title: Option<Text>,
230        x_axis: Option<&Axis>,
231        y_axis: Option<&Axis>,
232    ) -> Result<Self, crate::io::PlotlarsError> {
233        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
234            Self::__orig_new(
235                data,
236                x,
237                y,
238                z,
239                facet,
240                facet_config,
241                auto_color_scale,
242                color_bar,
243                color_scale,
244                reverse_scale,
245                show_scale,
246                plot_title,
247                x_title,
248                y_title,
249                x_axis,
250                y_axis,
251            )
252        }))
253        .map_err(|panic| {
254            let msg = panic
255                .downcast_ref::<String>()
256                .cloned()
257                .or_else(|| panic.downcast_ref::<&str>().map(|s| s.to_string()))
258                .unwrap_or_else(|| "unknown error".to_string());
259            crate::io::PlotlarsError::PlotBuild { message: msg }
260        })
261    }
262}
263
264impl HeatMap {
265    #[allow(clippy::too_many_arguments)]
266    fn create_ir_traces(
267        data: &DataFrame,
268        x: &str,
269        y: &str,
270        z: &str,
271        auto_color_scale: Option<bool>,
272        color_bar: Option<&ColorBar>,
273        color_scale: Option<Palette>,
274        reverse_scale: Option<bool>,
275        show_scale: Option<bool>,
276    ) -> Vec<TraceIR> {
277        vec![TraceIR::HeatMap(HeatMapIR {
278            x: ColumnData::String(crate::data::get_string_column(data, x)),
279            y: ColumnData::String(crate::data::get_string_column(data, y)),
280            z: ColumnData::Numeric(crate::data::get_numeric_column(data, z)),
281            color_scale,
282            color_bar: color_bar.cloned(),
283            auto_color_scale,
284            reverse_scale,
285            show_scale,
286            z_min: None,
287            z_max: None,
288            subplot_ref: None,
289        })]
290    }
291
292    #[allow(clippy::too_many_arguments)]
293    fn create_ir_traces_faceted(
294        data: &DataFrame,
295        x: &str,
296        y: &str,
297        z: &str,
298        facet_column: &str,
299        config: &FacetConfig,
300        auto_color_scale: Option<bool>,
301        color_bar: Option<&ColorBar>,
302        color_scale: Option<Palette>,
303        reverse_scale: Option<bool>,
304        show_scale: Option<bool>,
305    ) -> Vec<TraceIR> {
306        const MAX_FACETS: usize = 8;
307
308        let facet_categories = crate::data::get_unique_groups(data, facet_column, config.sorter);
309
310        if facet_categories.len() > MAX_FACETS {
311            panic!(
312                "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} subplots",
313                facet_column,
314                facet_categories.len(),
315                MAX_FACETS
316            );
317        }
318
319        let use_global_z = !matches!(config.scales, FacetScales::Free);
320        let global_z_range = if use_global_z {
321            Some(Self::calculate_global_z_range(data, z))
322        } else {
323            None
324        };
325
326        let mut traces = Vec::new();
327
328        for (facet_idx, facet_value) in facet_categories.iter().enumerate() {
329            let facet_data = crate::data::filter_data_by_group(data, facet_column, facet_value);
330
331            let subplot_ref = format!(
332                "{}{}",
333                crate::faceting::get_axis_reference(facet_idx, "x"),
334                crate::faceting::get_axis_reference(facet_idx, "y")
335            );
336
337            let show_scale_for_trace = if facet_idx == 0 {
338                show_scale
339            } else {
340                Some(false)
341            };
342
343            let (z_min, z_max) = match global_z_range {
344                Some((zmin, zmax)) => (Some(zmin as f64), Some(zmax as f64)),
345                None => (None, None),
346            };
347
348            traces.push(TraceIR::HeatMap(HeatMapIR {
349                x: ColumnData::String(crate::data::get_string_column(&facet_data, x)),
350                y: ColumnData::String(crate::data::get_string_column(&facet_data, y)),
351                z: ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, z)),
352                color_scale,
353                color_bar: color_bar.cloned(),
354                auto_color_scale,
355                reverse_scale,
356                show_scale: show_scale_for_trace,
357                z_min,
358                z_max,
359                subplot_ref: Some(subplot_ref),
360            }));
361        }
362
363        traces
364    }
365
366    fn calculate_global_z_range(data: &DataFrame, z: &str) -> (f32, f32) {
367        let z_data = crate::data::get_numeric_column(data, z);
368
369        let values: Vec<f32> = z_data.iter().filter_map(|v| *v).collect();
370
371        if values.is_empty() {
372            return (0.0, 1.0);
373        }
374
375        let min = values.iter().cloned().fold(f32::INFINITY, f32::min);
376        let max = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
377
378        (min, max)
379    }
380}
381
382impl crate::Plot for HeatMap {
383    fn ir_traces(&self) -> &[TraceIR] {
384        &self.traces
385    }
386
387    fn ir_layout(&self) -> &LayoutIR {
388        &self.layout
389    }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Plot;
    use polars::prelude::*;

    #[test]
    fn test_basic_one_trace() {
        let df = df![
            "x" => ["a", "b", "c"],
            "y" => ["d", "e", "f"],
            "z" => [1.0, 2.0, 3.0]
        ]
        .unwrap();
        let plot = HeatMap::builder().data(&df).x("x").y("y").z("z").build();
        assert_eq!(plot.ir_traces().len(), 1);
        assert!(matches!(plot.ir_traces()[0], TraceIR::HeatMap(_)));
    }

    #[test]
    fn test_layout_has_axes() {
        let df = df![
            "x" => ["a", "b"],
            "y" => ["c", "d"],
            "z" => [1.0, 2.0]
        ]
        .unwrap();
        let plot = HeatMap::builder().data(&df).x("x").y("y").z("z").build();
        assert!(plot.ir_layout().axes_2d.is_some());
    }

    #[test]
    fn test_layout_title() {
        let df = df![
            "x" => ["a"],
            "y" => ["b"],
            "z" => [1.0]
        ]
        .unwrap();
        let plot = HeatMap::builder()
            .data(&df)
            .x("x")
            .y("y")
            .z("z")
            .plot_title("Heat")
            .build();
        assert!(plot.ir_layout().title.is_some());
    }

    /// Faceting yields one trace per group and moves the axes onto a grid.
    #[test]
    fn test_faceted_one_trace_per_group() {
        let df = df![
            "x" => ["a", "b", "a", "b"],
            "y" => ["c", "d", "c", "d"],
            "z" => [1.0, 2.0, 3.0, 4.0],
            "g" => ["p", "p", "q", "q"]
        ]
        .unwrap();
        let plot = HeatMap::builder()
            .data(&df)
            .x("x")
            .y("y")
            .z("z")
            .facet("g")
            .build();
        assert_eq!(plot.ir_traces().len(), 2);
        assert!(plot.ir_layout().grid.is_some());
        assert!(plot.ir_layout().axes_2d.is_none());
    }

    /// More than 8 facet groups is rejected; `try_build` reports it instead of panicking.
    #[test]
    fn test_try_build_rejects_too_many_facets() {
        let groups: Vec<String> = (0..9).map(|i| format!("g{i}")).collect();
        let df = df![
            "x" => vec!["a"; 9],
            "y" => vec!["b"; 9],
            "z" => vec![1.0; 9],
            "g" => groups
        ]
        .unwrap();
        let result = HeatMap::try_builder()
            .data(&df)
            .x("x")
            .y("y")
            .z("z")
            .facet("g")
            .try_build();
        assert!(result.is_err());
    }
}