use bon::bon;
use crate::{
components::{Axis, ColorBar, FacetConfig, FacetScales, Palette, Text},
ir::data::ColumnData,
ir::layout::LayoutIR,
ir::trace::{HeatMapIR, TraceIR},
};
use polars::frame::DataFrame;
/// A heat map plot, stored as its intermediate-representation traces and layout.
///
/// Build one with `HeatMap::builder()` (which may panic on invalid input) or with
/// `HeatMap::try_builder()`, which turns such panics into `PlotlarsError`.
/// The stored data is exposed through the `crate::Plot` trait.
#[derive(Clone)]
#[allow(dead_code)]
pub struct HeatMap {
    /// One `TraceIR::HeatMap` for an unfaceted plot, or one per facet value.
    traces: Vec<TraceIR>,
    /// Plot title, axis titles/settings and, when faceted, the subplot grid.
    layout: LayoutIR,
}
#[bon]
impl HeatMap {
    /// Builds a heat map from the `x`, `y` and `z` columns of `data`.
    ///
    /// `x` and `y` are read as string columns and `z` as a numeric column.
    /// When `facet` names a column, the plot is split into one subplot per
    /// distinct value of that column, arranged according to `facet_config`.
    /// In that case the axis titles and `x_axis`/`y_axis` settings are attached
    /// to the subplot grid, and the top-level layout leaves them unset.
    ///
    /// # Panics
    ///
    /// Panics if a faceted plot has more than eight facet values. The
    /// `crate::data` column helpers may panic on bad input as well.
    /// NOTE(review): e.g. a missing column — confirm.
    /// Use `try_builder` for a `Result`-returning variant.
    #[builder(on(String, into), on(Text, into))]
    pub fn new(
        data: &DataFrame,
        x: &str,
        y: &str,
        z: &str,
        facet: Option<&str>,
        facet_config: Option<&FacetConfig>,
        auto_color_scale: Option<bool>,
        color_bar: Option<&ColorBar>,
        color_scale: Option<Palette>,
        reverse_scale: Option<bool>,
        show_scale: Option<bool>,
        plot_title: Option<Text>,
        x_title: Option<Text>,
        y_title: Option<Text>,
        x_axis: Option<&Axis>,
        y_axis: Option<&Axis>,
    ) -> Self {
        // Subplot grid; only present when the plot is faceted.
        let grid = facet.map(|column| {
            let cfg = facet_config.cloned().unwrap_or_default();
            let categories = crate::data::get_unique_groups(data, column, cfg.sorter);
            let facet_count = categories.len();
            let (cols, rows) =
                crate::faceting::calculate_grid_dimensions(facet_count, cfg.cols, cfg.rows);
            crate::ir::facet::GridSpec {
                kind: crate::ir::facet::FacetKind::Axis,
                rows,
                cols,
                h_gap: cfg.h_gap,
                v_gap: cfg.v_gap,
                scales: cfg.scales.clone(),
                n_facets: facet_count,
                facet_categories: categories,
                title_style: cfg.title_style.clone(),
                x_title: x_title.clone(),
                y_title: y_title.clone(),
                x_axis: x_axis.cloned(),
                y_axis: y_axis.cloned(),
                legend_title: None,
                legend: None,
            }
        });

        // A grid already carries the axis titles/settings; otherwise they belong
        // on the top-level layout.
        let (main_x_title, main_y_title, axes_2d) = if grid.is_some() {
            (None, None, None)
        } else {
            let axes = crate::ir::layout::Axes2dIR {
                x_axis: x_axis.cloned(),
                y_axis: y_axis.cloned(),
                y2_axis: None,
            };
            (x_title, y_title, Some(axes))
        };

        let layout = LayoutIR {
            title: plot_title,
            x_title: main_x_title,
            y_title: main_y_title,
            y2_title: None,
            z_title: None,
            legend_title: None,
            legend: None,
            dimensions: None,
            bar_mode: None,
            box_mode: None,
            box_gap: None,
            margin_bottom: None,
            axes_2d,
            scene_3d: None,
            polar: None,
            mapbox: None,
            grid,
            annotations: vec![],
        };

        let traces = if let Some(column) = facet {
            let cfg = facet_config.cloned().unwrap_or_default();
            Self::create_ir_traces_faceted(
                data,
                x,
                y,
                z,
                column,
                &cfg,
                auto_color_scale,
                color_bar,
                color_scale,
                reverse_scale,
                show_scale,
            )
        } else {
            Self::create_ir_traces(
                data,
                x,
                y,
                z,
                auto_color_scale,
                color_bar,
                color_scale,
                reverse_scale,
                show_scale,
            )
        };

        Self { traces, layout }
    }
}
#[bon]
impl HeatMap {
    /// Fallible counterpart of `HeatMap::builder()`, started via `try_builder()`
    /// and finished with `try_build()`.
    ///
    /// This delegates to the same construction logic as `new`, catching any
    /// unwinding panic it raises. A caught panic is reported as
    /// `PlotlarsError::PlotBuild`, whose message is the panic payload when
    /// that payload is a `String` or `&str`, and `"unknown error"` otherwise.
    ///
    /// # Errors
    ///
    /// Returns `PlotlarsError::PlotBuild` when building the plot panics.
    /// `std::panic::catch_unwind` only intercepts unwinding panics: with
    /// `panic = "abort"` the process still aborts. The installed panic hook
    /// also still runs before the error is returned.
    #[builder(
        start_fn = try_builder,
        finish_fn = try_build,
        builder_type = HeatMapTryBuilder,
        on(String, into),
        on(Text, into),
    )]
    pub fn try_new(
        data: &DataFrame,
        x: &str,
        y: &str,
        z: &str,
        facet: Option<&str>,
        facet_config: Option<&FacetConfig>,
        auto_color_scale: Option<bool>,
        color_bar: Option<&ColorBar>,
        color_scale: Option<Palette>,
        reverse_scale: Option<bool>,
        show_scale: Option<bool>,
        plot_title: Option<Text>,
        x_title: Option<Text>,
        y_title: Option<Text>,
        x_axis: Option<&Axis>,
        y_axis: Option<&Axis>,
    ) -> Result<Self, crate::io::PlotlarsError> {
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            Self::__orig_new(
                data,
                x,
                y,
                z,
                facet,
                facet_config,
                auto_color_scale,
                color_bar,
                color_scale,
                reverse_scale,
                show_scale,
                plot_title,
                x_title,
                y_title,
                x_axis,
                y_axis,
            )
        }));
        outcome.map_err(|payload| {
            // `panic!` with format args yields a `String`; a literal yields `&str`.
            let message = if let Some(text) = payload.downcast_ref::<String>() {
                text.clone()
            } else if let Some(text) = payload.downcast_ref::<&str>() {
                (*text).to_string()
            } else {
                String::from("unknown error")
            };
            crate::io::PlotlarsError::PlotBuild { message }
        })
    }
}
impl HeatMap {
    /// Builds the single trace of an unfaceted heat map.
    ///
    /// `x` and `y` are read as string columns and `z` as a numeric column of
    /// `data`. The colour options are forwarded as-is, and no explicit z range
    /// (`z_min`/`z_max`) is set.
    // Each argument mirrors a builder option of `HeatMap::new`, so the argument
    // count is inherent to the API.
    #[allow(clippy::too_many_arguments)]
    fn create_ir_traces(
        data: &DataFrame,
        x: &str,
        y: &str,
        z: &str,
        auto_color_scale: Option<bool>,
        color_bar: Option<&ColorBar>,
        color_scale: Option<Palette>,
        reverse_scale: Option<bool>,
        show_scale: Option<bool>,
    ) -> Vec<TraceIR> {
        vec![TraceIR::HeatMap(HeatMapIR {
            x: ColumnData::String(crate::data::get_string_column(data, x)),
            y: ColumnData::String(crate::data::get_string_column(data, y)),
            z: ColumnData::Numeric(crate::data::get_numeric_column(data, z)),
            color_scale,
            color_bar: color_bar.cloned(),
            auto_color_scale,
            reverse_scale,
            show_scale,
            z_min: None,
            z_max: None,
            subplot_ref: None,
        })]
    }

    /// Builds one heat map trace per distinct value of `facet_column`.
    ///
    /// Each trace is built only from the rows of its facet. It targets the
    /// subplot formed from `crate::faceting::get_axis_reference` for its index.
    ///
    /// Unless `config.scales` is `FacetScales::Free`, every trace gets the
    /// z range of the whole (unfiltered) `z` column, so colours are comparable
    /// across subplots.
    ///
    /// Only the first trace keeps the caller's `show_scale`; the others are
    /// forced to `Some(false)`.
    ///
    /// # Panics
    ///
    /// Panics if `facet_column` has more than `MAX_FACETS` (8) distinct values.
    // Each argument mirrors a builder option of `HeatMap::new`, so the argument
    // count is inherent to the API.
    #[allow(clippy::too_many_arguments)]
    fn create_ir_traces_faceted(
        data: &DataFrame,
        x: &str,
        y: &str,
        z: &str,
        facet_column: &str,
        config: &FacetConfig,
        auto_color_scale: Option<bool>,
        color_bar: Option<&ColorBar>,
        color_scale: Option<Palette>,
        reverse_scale: Option<bool>,
        show_scale: Option<bool>,
    ) -> Vec<TraceIR> {
        const MAX_FACETS: usize = 8;
        let facet_categories = crate::data::get_unique_groups(data, facet_column, config.sorter);
        if facet_categories.len() > MAX_FACETS {
            panic!(
                "Facet column '{}' has {} unique values, but plotly.rs supports maximum {} subplots",
                facet_column,
                facet_categories.len(),
                MAX_FACETS
            );
        }
        // The shared z range does not depend on the facet, so compute it once
        // here rather than re-deriving it on every loop iteration.
        let (z_min, z_max) = if matches!(config.scales, FacetScales::Free) {
            (None, None)
        } else {
            let (lo, hi) = Self::calculate_global_z_range(data, z);
            (Some(f64::from(lo)), Some(f64::from(hi)))
        };
        let mut traces = Vec::with_capacity(facet_categories.len());
        for (facet_idx, facet_value) in facet_categories.iter().enumerate() {
            let facet_data = crate::data::filter_data_by_group(data, facet_column, facet_value);
            let subplot_ref = format!(
                "{}{}",
                crate::faceting::get_axis_reference(facet_idx, "x"),
                crate::faceting::get_axis_reference(facet_idx, "y")
            );
            // Only the first subplot may show a colour scale.
            let show_scale_for_trace = if facet_idx == 0 {
                show_scale
            } else {
                Some(false)
            };
            traces.push(TraceIR::HeatMap(HeatMapIR {
                x: ColumnData::String(crate::data::get_string_column(&facet_data, x)),
                y: ColumnData::String(crate::data::get_string_column(&facet_data, y)),
                z: ColumnData::Numeric(crate::data::get_numeric_column(&facet_data, z)),
                color_scale,
                color_bar: color_bar.cloned(),
                auto_color_scale,
                reverse_scale,
                show_scale: show_scale_for_trace,
                z_min,
                z_max,
                subplot_ref: Some(subplot_ref),
            }));
        }
        traces
    }

    /// Returns `(min, max)` over the finite values of the numeric column `z`.
    ///
    /// Missing entries, NaN and ±infinity are skipped. Folding NaN through
    /// `f32::min`/`f32::max` from ±INFINITY would turn an all-NaN column into
    /// the inverted range `(inf, -inf)`. An infinite bound is not a usable
    /// colour-axis limit either.
    ///
    /// Falls back to `(0.0, 1.0)` when no finite value exists.
    fn calculate_global_z_range(data: &DataFrame, z: &str) -> (f32, f32) {
        crate::data::get_numeric_column(data, z)
            .iter()
            .filter_map(|v| *v)
            .filter(|v| v.is_finite())
            .fold(None::<(f32, f32)>, |range, v| match range {
                None => Some((v, v)),
                Some((lo, hi)) => Some((lo.min(v), hi.max(v))),
            })
            .unwrap_or((0.0, 1.0))
    }
}
impl crate::Plot for HeatMap {
    /// Borrows the heat map traces prepared at construction time.
    fn ir_traces(&self) -> &[TraceIR] {
        self.traces.as_slice()
    }

    /// Borrows the layout prepared at construction time.
    fn ir_layout(&self) -> &LayoutIR {
        &self.layout
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Plot;
    use polars::prelude::*;

    #[test]
    fn test_basic_one_trace() {
        let df = df![
            "x" => ["a", "b", "c"],
            "y" => ["d", "e", "f"],
            "z" => [1.0, 2.0, 3.0]
        ]
        .unwrap();
        let plot = HeatMap::builder().data(&df).x("x").y("y").z("z").build();
        assert_eq!(plot.ir_traces().len(), 1);
        assert!(matches!(plot.ir_traces()[0], TraceIR::HeatMap(_)));
    }

    #[test]
    fn test_layout_has_axes() {
        let df = df![
            "x" => ["a", "b"],
            "y" => ["c", "d"],
            "z" => [1.0, 2.0]
        ]
        .unwrap();
        let plot = HeatMap::builder().data(&df).x("x").y("y").z("z").build();
        assert!(plot.ir_layout().axes_2d.is_some());
    }

    #[test]
    fn test_layout_title() {
        let df = df![
            "x" => ["a"],
            "y" => ["b"],
            "z" => [1.0]
        ]
        .unwrap();
        let plot = HeatMap::builder()
            .data(&df)
            .x("x")
            .y("y")
            .z("z")
            .plot_title("Heat")
            .build();
        assert!(plot.ir_layout().title.is_some());
    }

    /// Faceting yields one trace per group. Axes move from the top-level
    /// layout onto the grid, and only the first facet may show its scale.
    #[test]
    fn test_facet_one_trace_per_group() {
        let df = df![
            "x" => ["a", "b", "a", "b"],
            "y" => ["c", "d", "c", "d"],
            "z" => [1.0, 2.0, 3.0, 4.0],
            "g" => ["p", "p", "q", "q"]
        ]
        .unwrap();
        let plot = HeatMap::builder()
            .data(&df)
            .x("x")
            .y("y")
            .z("z")
            .facet("g")
            .build();
        assert_eq!(plot.ir_traces().len(), 2);
        assert!(plot.ir_layout().grid.is_some());
        assert!(plot.ir_layout().axes_2d.is_none());
        match &plot.ir_traces()[1] {
            TraceIR::HeatMap(ir) => assert_eq!(ir.show_scale, Some(false)),
            _ => panic!("expected a heat map trace"),
        }
    }

    /// More than eight facets panics in `build()`; `try_build()` must turn
    /// that panic into an error instead.
    #[test]
    fn test_try_build_rejects_too_many_facets() {
        let groups = ["a", "b", "c", "d", "e", "f", "g", "h", "i"];
        let df = df![
            "x" => groups,
            "y" => groups,
            "z" => [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
            "g" => groups
        ]
        .unwrap();
        let result = HeatMap::try_builder()
            .data(&df)
            .x("x")
            .y("y")
            .z("z")
            .facet("g")
            .try_build();
        assert!(result.is_err());
    }
}