use crate::core::Result as PlotResult;
use crate::core::style_utils::StyleResolver;
use crate::plots::traits::{PlotArea, PlotConfig, PlotData, PlotRender};
use crate::render::skia::SkiaRenderer;
use crate::render::{Color, ColorMap, Theme};
/// Interpolation mode that can be stored on a [`HeatmapConfig`].
///
/// NOTE(review): the render methods visible in this file draw one flat
/// rectangle per cell and never read this setting; if it is honoured, that
/// happens elsewhere — confirm before relying on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Interpolation {
/// Default variant (selected via `#[default]`).
#[default]
Nearest,
/// Presumably bilinear smoothing between neighbouring cells; no
/// implementation is visible in this file — TODO confirm.
Bilinear,
}
/// User-facing options for a heatmap plot.
///
/// Built with `HeatmapConfig::new()` / `Default` and the chained setter
/// methods. Only `colormap`, `vmin`, `vmax` and `alpha` are read by the code
/// in this file; the remaining fields are stored here for consumers elsewhere
/// (NOTE(review): confirm where colorbar, tick-label, annotation and aspect
/// settings are actually applied).
#[derive(Debug, Clone)]
pub struct HeatmapConfig {
/// Colormap sampled by `HeatmapData::get_color` to turn values into colors.
pub colormap: ColorMap,
/// Lower bound of the color scale; `None` means "use the data minimum"
/// (resolved in `process_heatmap`).
pub vmin: Option<f64>,
/// Upper bound of the color scale; `None` means "use the data maximum"
/// (resolved in `process_heatmap`).
pub vmax: Option<f64>,
/// Whether a colorbar is requested (default `true`).
pub colorbar: bool,
/// Optional label text for the colorbar.
pub colorbar_label: Option<String>,
/// Font size for colorbar tick labels; the setter enforces a minimum of 1.0.
pub colorbar_tick_font_size: f32,
/// Font size for the colorbar label; the setter enforces a minimum of 1.0.
pub colorbar_label_font_size: f32,
/// Optional tick labels for the x axis (presumably one per column — TODO confirm).
pub xticklabels: Option<Vec<String>>,
/// Optional tick labels for the y axis (presumably one per row — TODO confirm).
pub yticklabels: Option<Vec<String>>,
/// Requested interpolation mode (not read by the rendering code in this file).
pub interpolation: Interpolation,
/// Whether cell value annotations are requested (default `false`).
pub annotate: bool,
/// Format pattern for annotations; defaults to the literal string `"{:.2}"`.
pub annotation_format: String,
/// Optional aspect ratio; `None` leaves it unspecified.
pub aspect: Option<f64>,
/// Cell opacity; the setter clamps it to `0.0..=1.0`.
pub alpha: f32,
}
impl Default for HeatmapConfig {
    /// Baseline configuration: viridis colormap, automatic value range,
    /// opaque cells, colorbar enabled, annotations disabled.
    fn default() -> Self {
        Self {
            // Value-to-color mapping.
            colormap: ColorMap::viridis(),
            vmin: None,
            vmax: None,
            alpha: 1.0,

            // Colorbar.
            colorbar: true,
            colorbar_label: None,
            colorbar_tick_font_size: 12.0,
            colorbar_label_font_size: 14.0,

            // Axis tick labels.
            xticklabels: None,
            yticklabels: None,

            // Cell presentation.
            interpolation: Interpolation::Nearest,
            annotate: false,
            annotation_format: String::from("{:.2}"),
            aspect: None,
        }
    }
}
impl HeatmapConfig {
    /// Creates a configuration holding the [`Default`] values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the colormap used to color cells.
    pub fn colormap(self, colormap: ColorMap) -> Self {
        Self { colormap, ..self }
    }

    /// Fixes the lower bound of the color scale.
    pub fn vmin(self, vmin: f64) -> Self {
        Self {
            vmin: Some(vmin),
            ..self
        }
    }

    /// Fixes the upper bound of the color scale.
    pub fn vmax(self, vmax: f64) -> Self {
        Self {
            vmax: Some(vmax),
            ..self
        }
    }

    /// Turns the colorbar request on or off.
    pub fn colorbar(self, show: bool) -> Self {
        Self {
            colorbar: show,
            ..self
        }
    }

    /// Sets the colorbar label text.
    pub fn colorbar_label<S: Into<String>>(self, label: S) -> Self {
        Self {
            colorbar_label: Some(label.into()),
            ..self
        }
    }

    /// Sets the colorbar tick font size; values below 1.0 are raised to 1.0.
    pub fn colorbar_tick_font_size(self, size: f32) -> Self {
        Self {
            colorbar_tick_font_size: size.max(1.0),
            ..self
        }
    }

    /// Sets the colorbar label font size; values below 1.0 are raised to 1.0.
    pub fn colorbar_label_font_size(self, size: f32) -> Self {
        Self {
            colorbar_label_font_size: size.max(1.0),
            ..self
        }
    }

    /// Stores tick labels for the x axis.
    pub fn xticklabels(self, labels: Vec<String>) -> Self {
        Self {
            xticklabels: Some(labels),
            ..self
        }
    }

    /// Stores tick labels for the y axis.
    pub fn yticklabels(self, labels: Vec<String>) -> Self {
        Self {
            yticklabels: Some(labels),
            ..self
        }
    }

    /// Selects the interpolation mode.
    pub fn interpolation(self, method: Interpolation) -> Self {
        Self {
            interpolation: method,
            ..self
        }
    }

    /// Turns the cell annotation request on or off.
    pub fn annotate(self, show: bool) -> Self {
        Self {
            annotate: show,
            ..self
        }
    }

    /// Sets the annotation format pattern.
    pub fn annotation_format<S: Into<String>>(self, format: S) -> Self {
        Self {
            annotation_format: format.into(),
            ..self
        }
    }

    /// Sets the aspect ratio.
    pub fn aspect(self, ratio: f64) -> Self {
        Self {
            aspect: Some(ratio),
            ..self
        }
    }

    /// Sets the cell opacity, clamped to `0.0..=1.0`.
    pub fn alpha(self, alpha: f32) -> Self {
        Self {
            alpha: alpha.clamp(0.0, 1.0),
            ..self
        }
    }
}
// Marker implementation: the empty body relies entirely on whatever
// `PlotConfig` provides by default, so `HeatmapConfig` can be used wherever a
// `PlotConfig` bound is required.
impl PlotConfig for HeatmapConfig {}
/// A validated heatmap grid together with its resolved color range.
///
/// Produced by `process_heatmap` / `process_heatmap_flat`. All fields are
/// public, so the invariants those functions establish (rectangular grid,
/// `n_rows`/`n_cols` matching `values`) are not enforced for values built by
/// hand.
#[derive(Debug, Clone)]
pub struct HeatmapData {
/// Cell values, indexed as `values[row][col]`.
pub values: Vec<Vec<f64>>,
/// Number of rows in `values`.
pub n_rows: usize,
/// Number of columns in every row of `values`.
pub n_cols: usize,
/// Smallest finite value found in the input.
pub data_min: f64,
/// Largest finite value found in the input.
pub data_max: f64,
/// Lower bound of the color scale: `config.vmin` if set, else `data_min`.
pub vmin: f64,
/// Upper bound of the color scale: `config.vmax` if set, else `data_max`.
pub vmax: f64,
/// The configuration this data was processed with.
pub config: HeatmapConfig,
}
impl HeatmapData {
    /// Maps `value` to a color from the configured colormap.
    ///
    /// The value is normalized against `vmin..vmax` and clamped to `[0, 1]`
    /// before sampling. When the range is (numerically) zero wide, the
    /// colormap midpoint is used instead to avoid dividing by zero.
    pub fn get_color(&self, value: f64) -> Color {
        let span = self.vmax - self.vmin;
        let position = if span.abs() < f64::EPSILON {
            0.5
        } else {
            ((value - self.vmin) / span).clamp(0.0, 1.0)
        };
        self.config.colormap.sample(position)
    }

    /// Chooses black or white text for legibility on top of `background`.
    ///
    /// Uses the weighted sum `0.299 R + 0.587 G + 0.114 B`; black is returned
    /// only when that sum is strictly greater than 128.
    pub fn get_text_color(&self, background: Color) -> Color {
        let red_part = 0.299 * (background.r as f64);
        let green_part = 0.587 * (background.g as f64);
        let blue_part = 0.114 * (background.b as f64);
        let luma = red_part + green_part + blue_part;

        if luma > 128.0 {
            return Color::BLACK;
        }
        Color::WHITE
    }
}
/// Validates a 2-D grid and resolves the color range for it.
///
/// `data` is indexed as `data[row][col]`. The finite minimum and maximum of
/// the grid are computed (non-finite entries are ignored), and the color range
/// is taken from `config.vmin` / `config.vmax` when set, falling back to the
/// data extremes otherwise.
///
/// # Errors
///
/// Returns an error message when:
/// - `data` has no rows,
/// - rows differ in length,
/// - the rows have zero columns,
/// - no entry in the grid is finite.
pub fn process_heatmap(data: &[Vec<f64>], config: HeatmapConfig) -> Result<HeatmapData, String> {
    let n_rows = data.len();
    let n_cols = data
        .first()
        .map(Vec::len)
        .ok_or_else(|| "Heatmap data is empty".to_string())?;

    // Every row must match the width of the first one.
    if let Some((i, row)) = data
        .iter()
        .enumerate()
        .find(|(_, row)| row.len() != n_cols)
    {
        return Err(format!(
            "Row {} has {} columns, expected {}",
            i,
            row.len(),
            n_cols
        ));
    }

    // A grid of empty rows has nothing to draw; report that directly instead
    // of letting it fall through to the "only non-finite values" error below.
    if n_cols == 0 {
        return Err("Heatmap data has no columns".to_string());
    }

    // Single pass over all cells, skipping NaN and infinities.
    let (data_min, data_max) = data
        .iter()
        .flatten()
        .copied()
        .filter(|value| value.is_finite())
        .fold(
            (f64::INFINITY, f64::NEG_INFINITY),
            |(lo, hi), value| (lo.min(value), hi.max(value)),
        );

    // The sentinels survive only if no finite value was seen.
    if !data_min.is_finite() || !data_max.is_finite() {
        return Err("Heatmap data contains only non-finite values".to_string());
    }

    let vmin = config.vmin.unwrap_or(data_min);
    let vmax = config.vmax.unwrap_or(data_max);

    Ok(HeatmapData {
        values: data.to_vec(),
        n_rows,
        n_cols,
        data_min,
        data_max,
        vmin,
        vmax,
        config,
    })
}
/// Builds heatmap data from a flat, row-major slice.
///
/// `data` must contain exactly `n_rows * n_cols` values; row `r` is
/// `data[r * n_cols..(r + 1) * n_cols]`. The reshaped grid is then passed to
/// [`process_heatmap`].
///
/// # Errors
///
/// Returns an error message when the slice length does not match the
/// dimensions (including when `n_rows * n_cols` overflows `usize`), when
/// either dimension is zero, or when [`process_heatmap`] rejects the grid.
pub fn process_heatmap_flat(
    data: &[f64],
    n_rows: usize,
    n_cols: usize,
    config: HeatmapConfig,
) -> Result<HeatmapData, String> {
    // An overflowing product can never equal a real slice length, so treat it
    // as a mismatch rather than panicking (debug) or wrapping (release).
    if n_rows.checked_mul(n_cols) != Some(data.len()) {
        return Err(format!(
            "Data length {} does not match dimensions {}x{}",
            data.len(),
            n_rows,
            n_cols
        ));
    }

    // Reject degenerate shapes up front: with `n_cols == 0` any `n_rows`
    // passes the length check, and materialising that many empty rows could
    // exhaust memory before the grid is rejected.
    if n_rows == 0 {
        return Err("Heatmap data is empty".to_string());
    }
    if n_cols == 0 {
        return Err("Heatmap data has no columns".to_string());
    }

    // The length is an exact multiple of `n_cols`, so this yields `n_rows`
    // full rows with no remainder.
    let values: Vec<Vec<f64>> = data
        .chunks_exact(n_cols)
        .map(|row| row.to_vec())
        .collect();

    process_heatmap(&values, config)
}
impl PlotData for HeatmapData {
    /// Returns `((x_min, x_max), (y_min, y_max))`: each cell is one unit wide
    /// and tall, so the grid spans `0..n_cols` in x and `0..n_rows` in y.
    fn data_bounds(&self) -> ((f64, f64), (f64, f64)) {
        let x_range = (0.0, self.n_cols as f64);
        let y_range = (0.0, self.n_rows as f64);
        (x_range, y_range)
    }

    /// True when there are no rows, or the first row has no cells.
    fn is_empty(&self) -> bool {
        if self.values.is_empty() {
            return true;
        }
        self.values[0].is_empty()
    }
}
impl PlotRender for HeatmapData {
    /// Draws every finite cell as a filled rectangle using the opacity from
    /// the stored configuration. The theme and color arguments are unused.
    fn render(
        &self,
        renderer: &mut SkiaRenderer,
        area: &PlotArea,
        _theme: &Theme,
        _color: Color,
    ) -> PlotResult<()> {
        if self.is_empty() {
            return Ok(());
        }
        self.draw_cells(renderer, area, self.config.alpha)
    }

    /// Same as [`render`](PlotRender::render), but lets the caller override
    /// the opacity.
    ///
    /// An `alpha` of exactly `1.0` is treated as "not overridden" and the
    /// configured opacity is used instead. The color and line-width arguments
    /// are unused.
    fn render_styled(
        &self,
        renderer: &mut SkiaRenderer,
        area: &PlotArea,
        theme: &Theme,
        _color: Color,
        alpha: f32,
        _line_width: Option<f32>,
    ) -> PlotResult<()> {
        if self.is_empty() {
            return Ok(());
        }

        // Constructed but not consulted, as in the existing behaviour.
        let _resolver = StyleResolver::new(theme);

        let effective_alpha = if alpha != 1.0 {
            alpha
        } else {
            self.config.alpha
        };
        self.draw_cells(renderer, area, effective_alpha)
    }
}

impl HeatmapData {
    /// Fills one rectangle per finite cell at the given opacity.
    ///
    /// Cell `(row, col)` covers `col..col + 1` in x. Rows are flipped in y so
    /// that row 0 occupies the highest band of the data range
    /// (`n_rows - 1..n_rows`). Non-finite cells are skipped.
    fn draw_cells(
        &self,
        renderer: &mut SkiaRenderer,
        area: &PlotArea,
        alpha: f32,
    ) -> PlotResult<()> {
        let total_rows = self.n_rows;

        for row in 0..total_rows {
            // The vertical band is the same for every cell in this row.
            let band_upper = (total_rows - row) as f64;
            let band_lower = (total_rows - row - 1) as f64;

            for col in 0..self.n_cols {
                let value = self.values[row][col];
                if value.is_finite() {
                    let fill = self.get_color(value).with_alpha(alpha);

                    // Two opposite corners of the cell in screen coordinates.
                    let (left, top) = area.data_to_screen(col as f64, band_upper);
                    let (right, bottom) = area.data_to_screen((col + 1) as f64, band_lower);

                    renderer.draw_rectangle(left, top, right - left, bottom - top, fill, true)?;
                }
            }
        }

        Ok(())
    }
}
// Unit tests for heatmap configuration and data processing. Rendering is not
// exercised here.
#[cfg(test)]
mod tests {
use super::*;
// Default config: colorbar on, annotations off, nearest interpolation,
// automatic value range.
#[test]
fn test_heatmap_config_defaults() {
let config = HeatmapConfig::default();
assert!(config.colorbar);
assert!(!config.annotate);
assert_eq!(config.interpolation, Interpolation::Nearest);
assert!(config.vmin.is_none());
assert!(config.vmax.is_none());
}
// Chained setters store the values they are given.
#[test]
fn test_heatmap_config_builder() {
let config = HeatmapConfig::new()
.colormap(ColorMap::plasma())
.vmin(0.0)
.vmax(100.0)
.colorbar(true)
.colorbar_label("Temperature")
.annotate(true);
assert_eq!(config.vmin, Some(0.0));
assert_eq!(config.vmax, Some(100.0));
assert!(config.colorbar);
assert_eq!(config.colorbar_label, Some("Temperature".to_string()));
assert!(config.annotate);
}
// A 2x3 grid reports its dimensions and finite min/max.
#[test]
fn test_process_heatmap() {
let data = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
let config = HeatmapConfig::default();
let result = process_heatmap(&data, config).unwrap();
assert_eq!(result.n_rows, 2);
assert_eq!(result.n_cols, 3);
assert!((result.data_min - 1.0).abs() < f64::EPSILON);
assert!((result.data_max - 6.0).abs() < f64::EPSILON);
}
// Explicit vmin/vmax override the range derived from the data.
#[test]
fn test_process_heatmap_with_vmin_vmax() {
let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
let config = HeatmapConfig::new().vmin(0.0).vmax(10.0);
let result = process_heatmap(&data, config).unwrap();
assert!((result.vmin - 0.0).abs() < f64::EPSILON);
assert!((result.vmax - 10.0).abs() < f64::EPSILON);
}
// A grid with no rows is rejected.
#[test]
fn test_process_heatmap_empty() {
let data: Vec<Vec<f64>> = vec![];
let config = HeatmapConfig::default();
assert!(process_heatmap(&data, config).is_err());
}
// Rows of differing length are rejected.
#[test]
fn test_process_heatmap_jagged() {
let data = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0]]; let config = HeatmapConfig::default();
assert!(process_heatmap(&data, config).is_err());
}
// The two ends of the value range map to different colors.
#[test]
fn test_heatmap_get_color() {
let data = vec![vec![0.0, 1.0]];
let config = HeatmapConfig::new().vmin(0.0).vmax(1.0);
let heatmap = process_heatmap(&data, config).unwrap();
let color_min = heatmap.get_color(0.0);
let color_max = heatmap.get_color(1.0);
assert!(color_min != color_max);
}
// Text color contrasts with the background: white on black, black on white.
#[test]
fn test_get_text_color() {
let data = vec![vec![0.0, 1.0]];
let config = HeatmapConfig::default();
let heatmap = process_heatmap(&data, config).unwrap();
let white_text = heatmap.get_text_color(Color::BLACK);
assert_eq!(white_text, Color::WHITE);
let black_text = heatmap.get_text_color(Color::WHITE);
assert_eq!(black_text, Color::BLACK);
}
// Flat row-major input is reshaped into the expected rows.
#[test]
fn test_process_heatmap_flat() {
let flat_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let config = HeatmapConfig::default();
let result = process_heatmap_flat(&flat_data, 2, 3, config).unwrap();
assert_eq!(result.n_rows, 2);
assert_eq!(result.n_cols, 3);
assert_eq!(result.values[0], vec![1.0, 2.0, 3.0]);
assert_eq!(result.values[1], vec![4.0, 5.0, 6.0]);
}
// `Interpolation::default()` is `Nearest`.
#[test]
fn test_interpolation_enum() {
assert_eq!(Interpolation::default(), Interpolation::Nearest);
}
// Compile-time check that `HeatmapConfig` satisfies the `PlotConfig` bound.
#[test]
fn test_heatmap_config_implements_plot_config() {
fn assert_plot_config<T: PlotConfig>() {}
assert_plot_config::<HeatmapConfig>();
}
// `data_bounds` spans 0..n_cols by 0..n_rows and the data is non-empty.
#[test]
fn test_heatmap_plot_data_trait() {
let data = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
let config = HeatmapConfig::default();
let heatmap = process_heatmap(&data, config).unwrap();
let ((x_min, x_max), (y_min, y_max)) = heatmap.data_bounds();
assert!((x_min - 0.0).abs() < 0.001);
assert!((x_max - 3.0).abs() < 0.001);
assert!((y_min - 0.0).abs() < 0.001);
assert!((y_max - 2.0).abs() < 0.001);
assert!(!heatmap.is_empty());
}
}