use super::{HeatmapColorScale, HeatmapState};
/// Number of histogram bins (heatmap rows) used unless `DistributionMap::with_bins` overrides it.
const DEFAULT_BIN_COUNT: usize = 50;
/// Builder that turns several labelled snapshots of values into a distribution heatmap:
/// one column per snapshot, one row per histogram bin.
#[derive(Debug, Clone)]
pub struct DistributionMap {
    /// `(label, values)` pairs in insertion order; each becomes one heatmap column.
    snapshots: Vec<(String, Vec<f64>)>,
    /// Number of histogram bins (heatmap rows); always at least 1 when set via the builder.
    bin_count: usize,
    /// Fixed `(min, max)` range for the bins, or `None` to derive it from the data.
    range: Option<(f64, f64)>,
}
impl DistributionMap {
    /// Creates an empty map with [`DEFAULT_BIN_COUNT`] bins and an automatically
    /// computed value range.
    #[must_use]
    pub fn new() -> Self {
        Self {
            snapshots: Vec::new(),
            bin_count: DEFAULT_BIN_COUNT,
            range: None,
        }
    }

    /// Appends a snapshot; each snapshot becomes one heatmap column, in insertion order.
    ///
    /// `values` is copied, so the caller keeps ownership of its slice.
    #[must_use]
    pub fn add_snapshot(mut self, label: impl Into<String>, values: &[f64]) -> Self {
        self.snapshots.push((label.into(), values.to_vec()));
        self
    }

    /// Sets the number of histogram bins (heatmap rows). A value of 0 is clamped to 1.
    #[must_use]
    pub fn with_bins(mut self, bins: usize) -> Self {
        self.bin_count = bins.max(1);
        self
    }

    /// Fixes the value range spanned by the bins instead of deriving it from the data.
    ///
    /// If `min > max` the bounds are swapped so bins always run from low to high.
    #[must_use]
    pub fn with_range(mut self, min: f64, max: f64) -> Self {
        // An inverted range would otherwise produce negative bin widths, an
        // upside-down heatmap and oddly formatted row labels.
        self.range = Some(if max < min { (max, min) } else { (min, max) });
        self
    }

    /// Bins every snapshot over one shared range and builds the heatmap.
    ///
    /// The result has one column per snapshot (labelled with the snapshot label) and
    /// one row per bin, with row 0 holding the highest-valued bin. It uses the Inferno
    /// color scale and the title "Distribution Map". Returns `HeatmapState::default()`
    /// when no snapshots have been added.
    #[must_use]
    pub fn to_heatmap(&self) -> HeatmapState {
        if self.snapshots.is_empty() {
            return HeatmapState::default();
        }
        let (global_min, global_max) = self.compute_range();
        let bin_count = self.bin_count;
        let mut grid = vec![vec![0.0; self.snapshots.len()]; bin_count];
        for (col_idx, (_label, values)) in self.snapshots.iter().enumerate() {
            let histogram = bin_values(values, bin_count, global_min, global_max);
            // Flip bin order so the highest-valued bin lands in row 0 (top of the heatmap).
            for (bin_idx, &count) in histogram.iter().enumerate() {
                let row_idx = bin_count - 1 - bin_idx;
                grid[row_idx][col_idx] = count as f64;
            }
        }
        let row_labels = build_bin_labels(bin_count, global_min, global_max);
        let col_labels: Vec<String> = self
            .snapshots
            .iter()
            .map(|(label, _)| label.clone())
            .collect();
        HeatmapState::with_data(grid)
            .with_row_labels(row_labels)
            .with_col_labels(col_labels)
            .with_color_scale(HeatmapColorScale::Inferno)
            .with_title("Distribution Map")
    }

    /// Returns the `(min, max)` range the bins should span.
    ///
    /// Uses the fixed range when one was set. Otherwise it returns the extremes over all
    /// finite values of all snapshots. NaN and ±infinity are ignored so a single bad
    /// sample cannot stretch the range to infinity. Returns `(0.0, 0.0)` when there are
    /// no finite values at all.
    fn compute_range(&self) -> (f64, f64) {
        if let Some(range) = self.range {
            return range;
        }
        let (lo, hi) = self
            .snapshots
            .iter()
            .flat_map(|(_, values)| values.iter().copied())
            .filter(|v| v.is_finite())
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
                (lo.min(v), hi.max(v))
            });
        // `lo > hi` only when the fold saw no finite value and kept its seeds.
        if lo > hi {
            (0.0, 0.0)
        } else {
            (lo, hi)
        }
    }
}
impl Default for DistributionMap {
fn default() -> Self {
Self::new()
}
}
/// Counts how many of `values` fall into each of `bin_count` equal-width bins spanning
/// `min..=max`, returning one count per bin (index 0 = lowest bin).
///
/// * Assuming `min < max`, values below `min` are counted in the first bin and values at
///   or above `max` in the last bin. Out-of-range samples are therefore clamped, not dropped.
/// * NaN values are skipped. They have no position on the axis and would otherwise be
///   silently counted in the first bin.
/// * If `min == max` the range is degenerate and every non-NaN value goes into bin 0.
/// * If `bin_count` is 0, an empty vector is returned.
fn bin_values(values: &[f64], bin_count: usize, min: f64, max: f64) -> Vec<usize> {
    let mut bins = vec![0usize; bin_count];
    if bin_count == 0 {
        return bins;
    }
    let last = bin_count - 1;
    let samples = values.iter().copied().filter(|v| !v.is_nan());
    let range = max - min;
    // Only an exactly-zero width is degenerate. A tiny but positive range (e.g. 1e-20)
    // still has distinct bins and must not be collapsed into one.
    if range == 0.0 {
        bins[0] = samples.count();
        return bins;
    }
    for v in samples {
        let normalized = (v - min) / range;
        // Float->int `as` casts saturate (negative -> 0, +inf -> usize::MAX). `min(last)`
        // then pins the upper side, so every sample lands in a valid bin.
        let idx = ((normalized * bin_count as f64).floor() as usize).min(last);
        bins[idx] += 1;
    }
    bins
}
/// Builds one `"lo..hi"` label per bin, ordered from the highest bin to the lowest.
///
/// Bins are `bin_count` equal slices of `min..max`. When `min == max`, every bin is
/// labelled with the zero-width range `min..min`.
fn build_bin_labels(bin_count: usize, min: f64, max: f64) -> Vec<String> {
    let range = max - min;
    // Exact zero check (not an absolute epsilon) so a narrow but non-degenerate range
    // still gets distinct bin edges.
    let bin_width = if range == 0.0 {
        0.0
    } else {
        range / bin_count as f64
    };
    // Iterate bins in reverse so the first label describes the highest bin.
    (0..bin_count)
        .rev()
        .map(|i| {
            let lo = min + i as f64 * bin_width;
            format_range(lo, lo + bin_width)
        })
        .collect()
}
/// Formats a bin's bounds as `"lo..hi"`, picking the number of decimals from the bin width.
///
/// Narrower bins need more digits:
/// * width >= 1 (or effectively zero): 1 decimal
/// * width >= 0.1: 2 decimals
/// * width >= 0.01: 3 decimals
/// * anything narrower: 4 decimals
///
/// The width is taken as an absolute value. A descending pair (`lo > hi`) therefore uses
/// the same precision as its ascending counterpart.
fn format_range(lo: f64, hi: f64) -> String {
    let width = (hi - lo).abs();
    let precision = if width >= 1.0 || width < f64::EPSILON {
        1
    } else if width >= 0.1 {
        2
    } else if width >= 0.01 {
        3
    } else {
        4
    };
    format!("{lo:.precision$}..{hi:.precision$}")
}
#[cfg(test)]
mod tests {
// Unit tests for the DistributionMap builder and the free helpers `bin_values` and
// `format_range`. Row labels are checked through `to_heatmap`.
use super::*;
#[test]
fn test_basic_usage_with_three_snapshots() {
let state = DistributionMap::new()
.add_snapshot("Step 0", &[-1.0, -0.5, 0.0, 0.5, 1.0])
.add_snapshot("Step 1", &[-0.8, -0.3, 0.1, 0.4, 0.9])
.add_snapshot("Step 2", &[-0.5, -0.1, 0.2, 0.3, 0.6])
.with_bins(10)
.to_heatmap();
assert_eq!(state.rows(), 10);
assert_eq!(state.cols(), 3);
assert_eq!(state.col_labels().len(), 3);
assert_eq!(state.row_labels().len(), 10);
assert_eq!(state.col_labels()[0], "Step 0");
assert_eq!(state.col_labels()[1], "Step 1");
assert_eq!(state.col_labels()[2], "Step 2");
}
#[test]
fn test_correct_heatmap_dimensions() {
let state = DistributionMap::new()
.add_snapshot("A", &[1.0, 2.0, 3.0, 4.0, 5.0])
.add_snapshot("B", &[2.0, 3.0, 4.0, 5.0, 6.0])
.with_bins(20)
.to_heatmap();
assert_eq!(state.rows(), 20);
assert_eq!(state.cols(), 2);
}
#[test]
fn test_default_color_scale_is_inferno() {
let state = DistributionMap::new()
.add_snapshot("X", &[0.0, 1.0])
.with_bins(5)
.to_heatmap();
assert_eq!(state.color_scale(), &HeatmapColorScale::Inferno);
}
#[test]
fn test_default_title_is_distribution_map() {
let state = DistributionMap::new()
.add_snapshot("X", &[0.0, 1.0])
.with_bins(5)
.to_heatmap();
assert_eq!(state.title(), Some("Distribution Map"));
}
#[test]
fn test_custom_bin_count() {
let state = DistributionMap::new()
.add_snapshot("T=0", &[0.0, 0.25, 0.5, 0.75, 1.0])
.with_bins(4)
.to_heatmap();
assert_eq!(state.rows(), 4);
assert_eq!(state.cols(), 1);
}
#[test]
fn test_single_bin() {
// With one bin, every value of the snapshot is counted in the only row.
let state = DistributionMap::new()
.add_snapshot("All", &[1.0, 2.0, 3.0])
.with_bins(1)
.to_heatmap();
assert_eq!(state.rows(), 1);
assert_eq!(state.get(0, 0), Some(3.0));
}
#[test]
fn test_bin_count_zero_clamped_to_one() {
// `with_bins` clamps 0 up to 1, so the heatmap still has a row.
let state = DistributionMap::new()
.add_snapshot("X", &[1.0, 2.0])
.with_bins(0)
.to_heatmap();
assert_eq!(state.rows(), 1);
}
#[test]
fn test_default_bin_count_is_fifty() {
let state = DistributionMap::new()
.add_snapshot("X", &[0.0, 100.0])
.to_heatmap();
assert_eq!(state.rows(), DEFAULT_BIN_COUNT);
}
#[test]
fn test_fixed_range() {
// Labels run high -> low: the last label is the bottom bin starting at 0.0,
// the first label is the top bin ending at 10.0.
let state = DistributionMap::new()
.add_snapshot("A", &[2.0, 3.0])
.with_bins(10)
.with_range(0.0, 10.0)
.to_heatmap();
assert_eq!(state.rows(), 10);
assert_eq!(state.cols(), 1);
let labels = state.row_labels();
assert!(labels.last().unwrap().starts_with("0.0"));
assert!(labels.first().unwrap().contains("10.0"));
}
#[test]
fn test_fixed_range_values_outside_range_clamped() {
// -5.0 and 15.0 lie outside [0, 10]. They are clamped into the edge bins rather
// than dropped, so the column still sums to 2.
let state = DistributionMap::new()
.add_snapshot("A", &[-5.0, 15.0])
.with_bins(5)
.with_range(0.0, 10.0)
.to_heatmap();
assert_eq!(state.rows(), 5);
let total: f64 = (0..5).map(|r| state.get(r, 0).unwrap_or(0.0)).sum();
assert_eq!(total, 2.0);
}
#[test]
fn test_no_snapshots_returns_empty_heatmap() {
let state = DistributionMap::new().to_heatmap();
assert_eq!(state.rows(), 0);
assert_eq!(state.cols(), 0);
}
#[test]
fn test_empty_values_snapshot() {
// A snapshot with no values still produces a full column of zero counts.
let state = DistributionMap::new()
.add_snapshot("Empty", &[])
.with_bins(5)
.to_heatmap();
assert_eq!(state.rows(), 5);
assert_eq!(state.cols(), 1);
for r in 0..5 {
assert_eq!(state.get(r, 0), Some(0.0));
}
}
#[test]
fn test_mix_of_empty_and_nonempty_snapshots() {
let state = DistributionMap::new()
.add_snapshot("Has Data", &[1.0, 2.0, 3.0])
.add_snapshot("Empty", &[])
.with_bins(5)
.to_heatmap();
assert_eq!(state.rows(), 5);
assert_eq!(state.cols(), 2);
for r in 0..5 {
assert_eq!(state.get(r, 1), Some(0.0));
}
}
#[test]
fn test_values_distributed_across_bins() {
// 5.0 equals the range maximum and must land in the last bin, so all 6 values count.
let state = DistributionMap::new()
.add_snapshot("T", &[0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
.with_bins(5)
.with_range(0.0, 5.0)
.to_heatmap();
let total: f64 = (0..5).map(|r| state.get(r, 0).unwrap_or(0.0)).sum();
assert_eq!(total, 6.0);
}
#[test]
fn test_all_same_values() {
// Identical values give a zero-width auto range; every value must still be counted.
let state = DistributionMap::new()
.add_snapshot("Same", &[5.0, 5.0, 5.0])
.with_bins(3)
.to_heatmap();
assert_eq!(state.rows(), 3);
let total: f64 = (0..3).map(|r| state.get(r, 0).unwrap_or(0.0)).sum();
assert_eq!(total, 3.0);
}
#[test]
fn test_single_value_snapshot() {
let state = DistributionMap::new()
.add_snapshot("One", &[42.0])
.with_bins(10)
.to_heatmap();
assert_eq!(state.rows(), 10);
let total: f64 = (0..10).map(|r| state.get(r, 0).unwrap_or(0.0)).sum();
assert_eq!(total, 1.0);
}
#[test]
fn test_row_labels_are_reversed_high_to_low() {
// Row 0 is the top of the heatmap and therefore the highest bin.
let state = DistributionMap::new()
.add_snapshot("X", &[0.0, 10.0])
.with_bins(5)
.with_range(0.0, 10.0)
.to_heatmap();
let labels = state.row_labels();
assert_eq!(labels.len(), 5);
assert!(labels[0].contains("10.0"));
assert!(labels[4].starts_with("0.0"));
}
#[test]
fn test_column_labels_match_snapshot_order() {
let state = DistributionMap::new()
.add_snapshot("First", &[1.0])
.add_snapshot("Second", &[2.0])
.add_snapshot("Third", &[3.0])
.with_bins(5)
.to_heatmap();
assert_eq!(state.col_labels(), &["First", "Second", "Third"]);
}
#[test]
fn test_default_creates_empty() {
let dm = DistributionMap::default();
let state = dm.to_heatmap();
assert_eq!(state.rows(), 0);
assert_eq!(state.cols(), 0);
}
#[test]
fn test_builder_chaining() {
// Builder calls may come in any order; settings apply to snapshots added later.
let state = DistributionMap::new()
.with_bins(8)
.with_range(-1.0, 1.0)
.add_snapshot("A", &[0.0])
.add_snapshot("B", &[0.5])
.to_heatmap();
assert_eq!(state.rows(), 8);
assert_eq!(state.cols(), 2);
}
#[test]
fn test_bin_values_even_distribution() {
let bins = bin_values(&[0.0, 1.0, 2.0, 3.0], 4, 0.0, 4.0);
assert_eq!(bins, vec![1, 1, 1, 1]);
}
#[test]
fn test_bin_values_max_value_in_last_bin() {
// v == max normalizes to 1.0, i.e. index bin_count, which is clamped to the last bin.
let bins = bin_values(&[10.0], 5, 0.0, 10.0);
assert_eq!(bins[4], 1);
assert_eq!(bins.iter().sum::<usize>(), 1);
}
#[test]
fn test_bin_values_all_same() {
let bins = bin_values(&[5.0, 5.0, 5.0], 3, 5.0, 5.0);
assert_eq!(bins[0], 3);
}
#[test]
fn test_bin_values_empty_input() {
let bins = bin_values(&[], 5, 0.0, 10.0);
assert_eq!(bins, vec![0, 0, 0, 0, 0]);
}
// format_range precision tiers: width >= 1 -> 1 decimal, >= 0.1 -> 2, >= 0.01 -> 3, else 4.
#[test]
fn test_format_range_large_bins() {
let label = format_range(0.0, 2.0);
assert_eq!(label, "0.0..2.0");
}
#[test]
fn test_format_range_small_bins() {
let label = format_range(0.0, 0.5);
assert_eq!(label, "0.00..0.50");
}
#[test]
fn test_format_range_tiny_bins() {
let label = format_range(0.0, 0.05);
assert_eq!(label, "0.000..0.050");
}
#[test]
fn test_format_range_very_tiny_bins() {
let label = format_range(0.0, 0.005);
assert_eq!(label, "0.0000..0.0050");
}
#[test]
fn test_heatmap_has_selection_set() {
// NOTE(review): relies on HeatmapState::with_data selecting (0, 0) by default; that
// behavior is defined outside this module — confirm there.
let state = DistributionMap::new()
.add_snapshot("X", &[0.0, 1.0])
.with_bins(5)
.to_heatmap();
assert_eq!(state.selected(), Some((0, 0)));
}
#[test]
fn test_gradient_distribution_use_case() {
// Three epochs of 101 samples each, centred on 0 with shrinking spread (+-1, +-0.5,
// +-0.25). All samples lie within the fixed [-1, 1] range, so every column sums to 101.
let epoch_1: Vec<f64> = (-50..=50).map(|i| i as f64 * 0.02).collect();
let epoch_5: Vec<f64> = (-50..=50).map(|i| i as f64 * 0.01).collect();
let epoch_10: Vec<f64> = (-50..=50).map(|i| i as f64 * 0.005).collect();
let state = DistributionMap::new()
.add_snapshot("Epoch 1", &epoch_1)
.add_snapshot("Epoch 5", &epoch_5)
.add_snapshot("Epoch 10", &epoch_10)
.with_bins(25)
.with_range(-1.0, 1.0)
.to_heatmap();
assert_eq!(state.rows(), 25);
assert_eq!(state.cols(), 3);
assert_eq!(state.col_labels(), &["Epoch 1", "Epoch 5", "Epoch 10"]);
assert_eq!(state.color_scale(), &HeatmapColorScale::Inferno);
assert_eq!(state.title(), Some("Distribution Map"));
for col in 0..3 {
let total: f64 = (0..25).map(|r| state.get(r, col).unwrap_or(0.0)).sum();
assert_eq!(total, 101.0);
}
}
}