1use bon::bon;
2
3use crate::{
4 components::{Axis, ColorBar, FacetConfig, FacetScales, Palette, Text},
5 ir::data::ColumnData,
6 ir::layout::LayoutIR,
7 ir::trace::{HeatMapIR, TraceIR},
8};
9use polars::frame::DataFrame;
10
#[derive(Clone)]
/// A heat-map plot built from an x, y and z column of a `DataFrame`, optionally faceted.
// NOTE(review): the reason for suppressing `dead_code` is not visible in this chunk —
// confirm it is still needed.
#[allow(dead_code)]
pub struct HeatMap {
    /// Traces produced at construction: a single heat map, or one per facet value when faceted.
    traces: Vec<TraceIR>,
    /// Layout produced at construction (titles plus either 2-D axes or a facet grid).
    layout: LayoutIR,
}
86
87#[bon]
88impl HeatMap {
89 #[builder(on(String, into), on(Text, into))]
90 pub fn new(
91 data: &DataFrame,
92 x: &str,
93 y: &str,
94 z: &str,
95 facet: Option<&str>,
96 facet_config: Option<&FacetConfig>,
97 auto_color_scale: Option<bool>,
98 color_bar: Option<&ColorBar>,
99 color_scale: Option<Palette>,
100 reverse_scale: Option<bool>,
101 show_scale: Option<bool>,
102 plot_title: Option<Text>,
103 x_title: Option<Text>,
104 y_title: Option<Text>,
105 x_axis: Option<&Axis>,
106 y_axis: Option<&Axis>,
107 ) -> Self {
108 let grid = facet.map(|facet_column| {
109 let config = facet_config.cloned().unwrap_or_default();
110 let facet_categories =
111 crate::data::get_unique_groups(data, facet_column, config.sorter);
112 let n_facets = facet_categories.len();
113 let (ncols, nrows) =
114 crate::faceting::calculate_grid_dimensions(n_facets, config.cols, config.rows);
115 crate::ir::facet::GridSpec {
116 kind: crate::ir::facet::FacetKind::Axis,
117 rows: nrows,
118 cols: ncols,
119 h_gap: config.h_gap,
120 v_gap: config.v_gap,
121 scales: config.scales.clone(),
122 n_facets,
123 facet_categories,
124 title_style: config.title_style.clone(),
125 x_title: x_title.clone(),
126 y_title: y_title.clone(),
127 x_axis: x_axis.cloned(),
128 y_axis: y_axis.cloned(),
129 legend_title: None,
130 legend: None,
131 }
132 });
133
134 let layout = LayoutIR {
135 title: plot_title.clone(),
136 x_title: if grid.is_some() {
137 None
138 } else {
139 x_title.clone()
140 },
141 y_title: if grid.is_some() {
142 None
143 } else {
144 y_title.clone()
145 },
146 y2_title: None,
147 z_title: None,
148 legend_title: None,
149 legend: None,
150 dimensions: None,
151 bar_mode: None,
152 box_mode: None,
153 box_gap: None,
154 margin_bottom: None,
155 axes_2d: if grid.is_some() {
156 None
157 } else {
158 Some(crate::ir::layout::Axes2dIR {
159 x_axis: x_axis.cloned(),
160 y_axis: y_axis.cloned(),
161 y2_axis: None,
162 })
163 },
164 scene_3d: None,
165 polar: None,
166 mapbox: None,
167 grid,
168 annotations: vec![],
169 };
170
171 let traces = match facet {
172 Some(facet_column) => {
173 let config = facet_config.cloned().unwrap_or_default();
174
175 Self::create_ir_traces_faceted(
176 data,
177 x,
178 y,
179 z,
180 facet_column,
181 &config,
182 auto_color_scale,
183 color_bar,
184 color_scale,
185 reverse_scale,
186 show_scale,
187 )
188 }
189 None => Self::create_ir_traces(
190 data,
191 x,
192 y,
193 z,
194 auto_color_scale,
195 color_bar,
196 color_scale,
197 reverse_scale,
198 show_scale,
199 ),
200 };
201
202 Self { traces, layout }
203 }
204}
205
206#[bon]
207impl HeatMap {
208 #[builder(
209 start_fn = try_builder,
210 finish_fn = try_build,
211 builder_type = HeatMapTryBuilder,
212 on(String, into),
213 on(Text, into),
214 )]
215 pub fn try_new(
216 data: &DataFrame,
217 x: &str,
218 y: &str,
219 z: &str,
220 facet: Option<&str>,
221 facet_config: Option<&FacetConfig>,
222 auto_color_scale: Option<bool>,
223 color_bar: Option<&ColorBar>,
224 color_scale: Option<Palette>,
225 reverse_scale: Option<bool>,
226 show_scale: Option<bool>,
227 plot_title: Option<Text>,
228 x_title: Option<Text>,
229 y_title: Option<Text>,
230 x_axis: Option<&Axis>,
231 y_axis: Option<&Axis>,
232 ) -> Result<Self, crate::io::PlotlarsError> {
233 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
234 Self::__orig_new(
235 data,
236 x,
237 y,
238 z,
239 facet,
240 facet_config,
241 auto_color_scale,
242 color_bar,
243 color_scale,
244 reverse_scale,
245 show_scale,
246 plot_title,
247 x_title,
248 y_title,
249 x_axis,
250 y_axis,
251 )
252 }))
253 .map_err(|panic| {
254 let msg = panic
255 .downcast_ref::<String>()
256 .cloned()
257 .or_else(|| panic.downcast_ref::<&str>().map(|s| s.to_string()))
258 .unwrap_or_else(|| "unknown error".to_string());
259 crate::io::PlotlarsError::PlotBuild { message: msg }
260 })
261 }
262}
263
264impl HeatMap {
265 #[allow(clippy::too_many_arguments)]
266 fn create_ir_traces(
267 data: &DataFrame,
268 x: &str,
269 y: &str,
270 z: &str,
271 auto_color_scale: Option<bool>,
272 color_bar: Option<&ColorBar>,
273 color_scale: Option<Palette>,
274 reverse_scale: Option<bool>,
275 show_scale: Option<bool>,
276 ) -> Vec<TraceIR> {
277 vec![TraceIR::HeatMap(HeatMapIR {
278 x: ColumnData::String(crate::data::get_string_column(data, x)),
279 y: ColumnData::String(crate::data::get_string_column(data, y)),
280 z: ColumnData::Numeric(crate::data::get_numeric_column(data, z)),
281 color_scale,
282 color_bar: color_bar.cloned(),
283 auto_color_scale,
284 reverse_scale,
285 show_scale,
286 z_min: None,
287 z_max: None,
288 subplot_ref: None,
289 })]
290 }
291
292 #[allow(clippy::too_many_arguments)]
293 fn create_ir_traces_faceted(
294 data: &DataFrame,
295 x: &str,
296 y: &str,
297 z: &str,
298 facet_column: &str,
299 config: &FacetConfig,
300 auto_color_scale: Option<bool>,
301 color_bar: Option<&ColorBar>,
302 color_scale: Option<Palette>,
303 reverse_scale: Option<bool>,
304 show_scale: Option<bool>,
305 ) -> Vec<TraceIR> {
306 const MAX_FACETS: usize = 8;
307
308 let facet_categories = crate::data::get_unique_groups(data, facet_column, config.sorter);
309
310 if facet_categories.len() > MAX_FACETS {
311 panic!(
312 "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} subplots",
313 facet_column,
314 facet_categories.len(),
315 MAX_FACETS
316 );
317 }
318
319 let use_global_z = !matches!(config.scales, FacetScales::Free);
320 let global_z_range = if use_global_z {
321 Some(Self::calculate_global_z_range(data, z))
322 } else {
323 None
324 };
325
326 let mut traces = Vec::new();
327
328 for (facet_idx, facet_value) in facet_categories.iter().enumerate() {
329 let facet_data = crate::data::filter_data_by_group(data, facet_column, facet_value);
330
331 let subplot_ref = format!(
332 "{}{}",
333 crate::faceting::get_axis_reference(facet_idx, "x"),
334 crate::faceting::get_axis_reference(facet_idx, "y")
335 );
336
337 let show_scale_for_trace = if facet_idx == 0 {
338 show_scale
339 } else {
340 Some(false)
341 };
342
343 let (z_min, z_max) = match global_z_range {
344 Some((zmin, zmax)) => (Some(zmin as f64), Some(zmax as f64)),
345 None => (None, None),
346 };
347
348 traces.push(TraceIR::HeatMap(HeatMapIR {
349 x: ColumnData::String(crate::data::get_string_column(&facet_data, x)),
350 y: ColumnData::String(crate::data::get_string_column(&facet_data, y)),
351 z: ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, z)),
352 color_scale,
353 color_bar: color_bar.cloned(),
354 auto_color_scale,
355 reverse_scale,
356 show_scale: show_scale_for_trace,
357 z_min,
358 z_max,
359 subplot_ref: Some(subplot_ref),
360 }));
361 }
362
363 traces
364 }
365
366 fn calculate_global_z_range(data: &DataFrame, z: &str) -> (f32, f32) {
367 let z_data = crate::data::get_numeric_column(data, z);
368
369 let values: Vec<f32> = z_data.iter().filter_map(|v| *v).collect();
370
371 if values.is_empty() {
372 return (0.0, 1.0);
373 }
374
375 let min = values.iter().cloned().fold(f32::INFINITY, f32::min);
376 let max = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
377
378 (min, max)
379 }
380}
381
382impl crate::Plot for HeatMap {
383 fn ir_traces(&self) -> &[TraceIR] {
384 &self.traces
385 }
386
387 fn ir_layout(&self) -> &LayoutIR {
388 &self.layout
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Plot;
    use polars::prelude::*;

    /// An unfaceted heat map yields exactly one trace, and it is a heat-map trace.
    #[test]
    fn test_basic_one_trace() {
        let frame = df![
            "x" => ["a", "b", "c"],
            "y" => ["d", "e", "f"],
            "z" => [1.0, 2.0, 3.0]
        ]
        .expect("valid test frame");

        let heat_map = HeatMap::builder().data(&frame).x("x").y("y").z("z").build();
        let traces = heat_map.ir_traces();

        assert_eq!(traces.len(), 1);
        assert!(matches!(traces[0], TraceIR::HeatMap(_)));
    }

    /// Without faceting, the layout carries 2-D axes.
    #[test]
    fn test_layout_has_axes() {
        let frame = df![
            "x" => ["a", "b"],
            "y" => ["c", "d"],
            "z" => [1.0, 2.0]
        ]
        .expect("valid test frame");

        let heat_map = HeatMap::builder().data(&frame).x("x").y("y").z("z").build();
        let layout = heat_map.ir_layout();

        assert!(layout.axes_2d.is_some());
    }

    /// A plot title passed to the builder ends up on the layout.
    #[test]
    fn test_layout_title() {
        let frame = df![
            "x" => ["a"],
            "y" => ["b"],
            "z" => [1.0]
        ]
        .expect("valid test frame");

        let heat_map = HeatMap::builder()
            .data(&frame)
            .x("x")
            .y("y")
            .z("z")
            .plot_title("Heat")
            .build();
        let layout = heat_map.ir_layout();

        assert!(layout.title.is_some());
    }
}