use crate::errors::VizError;
use crate::viz::report::PairwiseMatrix;
use image::{DynamicImage, ImageBuffer, Rgb, RgbImage};
use ndarray::Array2;
use std::path::Path;
/// Linearly rescales `data` so its smallest element maps to `0.0` and its largest to `1.0`.
///
/// The span is floored at `1e-8`, so constant (or empty) input never divides by zero;
/// constant input maps to all zeros. NaN elements are ignored when locating the extremes
/// (`f32::min`/`f32::max` skip NaN) but remain NaN in the output.
fn normalize(data: &Array2<f32>) -> Array2<f32> {
    let (lo, hi) = data
        .iter()
        .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &x| {
            (lo.min(x), hi.max(x))
        });
    let span = (hi - lo).max(1e-8);
    data.mapv(|x| (x - lo) / span)
}
/// Maps a scalar onto a blue → green → red colour ramp.
///
/// The input is clamped to `[0, 1]` first: `0.0` and below give pure blue, `0.5` pure
/// green, `1.0` and above pure red. Channels are truncated toward zero when converted
/// to `u8`; a NaN input yields black.
fn heatmap_color(v: f32) -> [u8; 3] {
    let t = v.clamp(0.0, 1.0);
    // Signed distance from the midpoint of the ramp, in [-1, 1].
    let d = 2.0 * t - 1.0;
    let to_byte = |channel: f32| (channel * 255.0) as u8;
    [
        to_byte(d.max(0.0)),
        to_byte(1.0 - d.abs()),
        to_byte((-d).max(0.0)),
    ]
}
/// Blends a heat-map rendering of `attention_map` over `original` and writes the
/// result to `output_path`.
///
/// The attention map is min–max normalised, coloured with the blue→green→red ramp and
/// stretched over the whole image with nearest-neighbour sampling, so it may have any
/// resolution. Each output channel is `original * (1 - alpha) + heat * alpha`,
/// truncated to `u8`. `alpha` is clamped to `[0, 1]` so the blend never extrapolates
/// past either source.
///
/// # Errors
/// Returns [`VizError::Png`] if `attention_map` has a zero-length axis, or if the
/// image cannot be written.
pub fn save_attention_overlay(
    original: &DynamicImage,
    attention_map: &Array2<f32>,
    output_path: &Path,
    alpha: f32,
) -> Result<(), VizError> {
    let (ah, aw) = attention_map.dim();
    // An empty map has no cell to sample; indexing into it would panic.
    if ah == 0 || aw == 0 {
        return Err(VizError::Png(format!(
            "Cannot overlay an empty {ah}x{aw} attention map"
        )));
    }
    let alpha = alpha.clamp(0.0, 1.0);
    let (ow, oh) = (original.width(), original.height());
    let rgb = original.to_rgb8();
    let norm = normalize(attention_map);
    // Nearest-neighbour source index along one axis. Exact integer arithmetic keeps
    // the result strictly below `len` whenever `pos < extent`, so no clamp is needed.
    let nearest = |pos: u32, extent: u32, len: usize| -> usize {
        (u64::from(pos) * len as u64 / u64::from(extent)) as usize
    };
    let mut out: RgbImage = ImageBuffer::new(ow, oh);
    for py in 0..oh {
        let ay = nearest(py, oh, ah);
        for px in 0..ow {
            let ax = nearest(px, ow, aw);
            let heat = heatmap_color(norm[[ay, ax]]);
            let orig = rgb.get_pixel(px, py);
            let blend =
                |c: usize| (orig[c] as f32 * (1.0 - alpha) + heat[c] as f32 * alpha) as u8;
            out.put_pixel(px, py, Rgb([blend(0), blend(1), blend(2)]));
        }
    }
    out.save(output_path)
        .map_err(|e| VizError::Png(format!("Failed to save {}: {e}", output_path.display())))?;
    Ok(())
}
/// Target edge length, in pixels, for images written by [`save_pca_rgb`].
///
/// The patch grid is upscaled by the integer factor
/// `max(1, PCA_MIN_OUTPUT_SIZE / grid_size)`. The output edge therefore equals this
/// value only when `grid_size` divides it evenly; otherwise it is slightly smaller.
/// Grids wider than this value are written at their native size.
const PCA_MIN_OUTPUT_SIZE: u32 = 448;

/// Renders the first three components of per-patch projections as an RGB image and
/// writes it to `output_path`.
///
/// `projections` is indexed `[patch, component]`. Patch `i` is drawn at grid column
/// `i % grid_size`, row `i / grid_size`. Only the first `grid_size * grid_size`
/// patches are used. Each channel is min–max scaled to `0..=255` independently.
/// When fewer than three components exist, the last available component fills the
/// missing channels (a single component yields a greyscale image). The grid is
/// upscaled by pixel replication (see [`PCA_MIN_OUTPUT_SIZE`]).
///
/// # Errors
/// Returns [`VizError::Png`] in any of these cases:
/// - `grid_size` is zero;
/// - `projections` holds fewer than `grid_size * grid_size` patches;
/// - `projections` has no components;
/// - the image cannot be written.
pub fn save_pca_rgb(
    projections: &Array2<f32>,
    grid_size: usize,
    output_path: &Path,
) -> Result<(), VizError> {
    // A zero grid would divide by zero when computing the upscale factor.
    if grid_size == 0 {
        return Err(VizError::Png(
            "Cannot render a PCA image with a grid size of 0".to_string(),
        ));
    }
    let (n, components) = projections.dim();
    let cells = grid_size * grid_size;
    if n < cells {
        return Err(VizError::Png(format!(
            "Expected {}x{}={} patches, got {}",
            grid_size, grid_size, cells, n
        )));
    }
    if components == 0 {
        return Err(VizError::Png(
            "Cannot render a PCA image from projections with no components".to_string(),
        ));
    }
    // Min–max scale one component over the grid's patches into 0..=255.
    let make_channel = |c: usize| -> Vec<u8> {
        let vals: Vec<f32> = (0..cells).map(|i| projections[[i, c]]).collect();
        let min = vals.iter().copied().fold(f32::INFINITY, f32::min);
        let max = vals.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let range = (max - min).max(1e-8);
        vals.iter()
            .map(|&v| ((v - min) / range * 255.0) as u8)
            .collect()
    };
    // Clamp every channel's component index so 1- and 2-component inputs never index
    // past the last column.
    let last = components - 1;
    let r_ch = make_channel(0);
    let g_ch = make_channel(1.min(last));
    let b_ch = make_channel(2.min(last));
    let scale = (PCA_MIN_OUTPUT_SIZE as usize / grid_size).max(1);
    let out_size = (grid_size * scale) as u32;
    let img: RgbImage = ImageBuffer::from_fn(out_size, out_size, |px, py| {
        // px, py < grid_size * scale, so both grid coordinates are < grid_size.
        let idx = (py as usize / scale) * grid_size + px as usize / scale;
        Rgb([r_ch[idx], g_ch[idx], b_ch[idx]])
    });
    img.save(output_path)
        .map_err(|e| VizError::Png(format!("Failed to save {}: {e}", output_path.display())))?;
    Ok(())
}
/// Renders `matrix` as a heat map with one pixel per element and writes it to
/// `output_path`.
///
/// Values are min–max normalised across the whole matrix before colouring, so the
/// smallest entry is blue and the largest red. Element `[row, col]` becomes pixel
/// `(x = col, y = row)`. A non-square matrix produces a `cols x rows` image instead
/// of being truncated or indexed out of bounds.
///
/// # Errors
/// Returns [`VizError::Png`] if the matrix has a zero-length axis, or if the image
/// cannot be written.
pub fn save_similarity_heatmap(matrix: &Array2<f32>, output_path: &Path) -> Result<(), VizError> {
    let (rows, cols) = matrix.dim();
    if rows == 0 || cols == 0 {
        return Err(VizError::Png(format!(
            "Cannot render an empty {rows}x{cols} similarity heatmap"
        )));
    }
    let norm = normalize(matrix);
    let img: RgbImage = ImageBuffer::from_fn(cols as u32, rows as u32, |x, y| {
        Rgb(heatmap_color(norm[[y as usize, x as usize]]))
    });
    img.save(output_path)
        .map_err(|e| VizError::Png(format!("Failed to save {}: {e}", output_path.display())))?;
    Ok(())
}
/// Renders a pairwise matrix as a grid of coloured square cells and writes it to
/// `output_path`.
///
/// Layout:
/// - Each cell is 48x48 px with 2 px gaps, drawn on an `Rgb([18, 23, 33])` background.
/// - Present values are normalised over a range widened if needed to include `1.0`,
///   then coloured with the blue→green→red ramp.
/// - Missing values (`None`) are drawn as `Rgb([72, 78, 88])`.
/// - The square canvas is sized from `matrix.len()`. It grows if needed to hold every
///   row and cell actually present in `matrix.rows`, so inconsistent input cannot
///   draw outside it.
///
/// # Errors
/// Returns [`VizError::Png`] if `matrix` is empty, or if the image cannot be written.
pub fn save_pairwise_heatmap(matrix: &PairwiseMatrix, output_path: &Path) -> Result<(), VizError> {
    if matrix.is_empty() {
        return Err(VizError::Png(
            "Cannot render an empty pairwise heatmap".to_string(),
        ));
    }
    const CELL_SIZE: u32 = 48;
    const GAP: u32 = 2;
    // Single pass over present values. Seeding with 1.0 forces the colour range to
    // include it, exactly as clamping the final min/max against 1.0 would.
    let (min, max) = matrix
        .rows
        .iter()
        .flat_map(|row| row.iter().flatten().copied())
        .fold((1.0_f32, 1.0_f32), |(lo, hi), v| (lo.min(v), hi.max(v)));
    let range = (max - min).max(1e-8);
    // Size the grid to fit every row/cell actually present so put_pixel never goes
    // out of bounds, even if rows disagree with `len()`.
    let widest_row = matrix.rows.iter().map(|row| row.len()).max().unwrap_or(0);
    let n = matrix.len().max(matrix.rows.len()).max(widest_row) as u32;
    let side = n * CELL_SIZE + n.saturating_sub(1) * GAP;
    let mut img = RgbImage::from_pixel(side, side, Rgb([18, 23, 33]));
    for (row_idx, row) in matrix.rows.iter().enumerate() {
        let y0 = row_idx as u32 * (CELL_SIZE + GAP);
        for (col_idx, value) in row.iter().enumerate() {
            let color = value
                .map(|value| heatmap_color((value - min) / range))
                .unwrap_or([72, 78, 88]);
            let x0 = col_idx as u32 * (CELL_SIZE + GAP);
            for y in y0..(y0 + CELL_SIZE) {
                for x in x0..(x0 + CELL_SIZE) {
                    img.put_pixel(x, y, Rgb(color));
                }
            }
        }
    }
    img.save(output_path)
        .map_err(|e| VizError::Png(format!("Failed to save {}: {e}", output_path.display())))?;
    Ok(())
}
/// Draws `values` as a vertical bar chart and writes it to `output_path`.
///
/// Bars are min–max scaled into a 180 px plot area (so the smallest value gets a
/// zero-height bar) and tinted along the heat-map ramp by position. An empty slice is
/// not rejected: it is drawn as a single zero-valued placeholder bar.
///
/// # Errors
/// Returns [`VizError::Png`] if the image cannot be written.
pub fn save_series_chart(values: &[f32], output_path: &Path) -> Result<(), VizError> {
    const BAR_WIDTH: u32 = 22;
    const GAP: u32 = 8;
    const PADDING: u32 = 16;
    const CHART_HEIGHT: u32 = 180;

    let placeholder = [0.0_f32];
    let series: &[f32] = if values.is_empty() { &placeholder } else { values };
    let count = series.len();

    let (lo, hi) = series
        .iter()
        .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &v| {
            (lo.min(v), hi.max(v))
        });
    let span = (hi - lo).max(1e-8);

    let canvas_width =
        PADDING * 2 + count as u32 * BAR_WIDTH + count.saturating_sub(1) as u32 * GAP;
    let canvas_height = CHART_HEIGHT + PADDING * 2;
    let baseline = PADDING + CHART_HEIGHT;
    let mut canvas = RgbImage::from_pixel(canvas_width, canvas_height, Rgb([13, 17, 23]));

    for (i, &v) in series.iter().enumerate() {
        let fraction = ((v - lo) / span).clamp(0.0, 1.0);
        let bar_height = (fraction * CHART_HEIGHT as f32).round() as u32;
        let left = PADDING + i as u32 * (BAR_WIDTH + GAP);
        let top = baseline - bar_height.min(CHART_HEIGHT);
        let tint = Rgb(heatmap_color((i as f32 + 1.0) / count as f32));
        for y in top..baseline {
            for x in left..left + BAR_WIDTH {
                canvas.put_pixel(x, y, tint);
            }
        }
    }

    canvas
        .save(output_path)
        .map_err(|e| VizError::Png(format!("Failed to save {}: {e}", output_path.display())))
}
/// Draws an explained-variance spectrum as a vertical bar chart and writes it to
/// `output_path`.
///
/// Each ratio is clamped to `[0, 1]` and drawn as a bar covering that fraction of a
/// 180 px plot area. Bars are tinted along the heat-map ramp by position.
///
/// # Errors
/// Returns [`VizError::Png`] if `ratios` is empty, or if the image cannot be written.
pub fn save_variance_spectrum_chart(ratios: &[f32], output_path: &Path) -> Result<(), VizError> {
    const BAR_WIDTH: u32 = 18;
    const GAP: u32 = 6;
    const PADDING: u32 = 16;
    const CHART_HEIGHT: u32 = 180;

    let count = ratios.len();
    if count == 0 {
        return Err(VizError::Png(
            "Cannot render a variance spectrum without ratios".to_string(),
        ));
    }

    let canvas_width =
        PADDING * 2 + count as u32 * BAR_WIDTH + count.saturating_sub(1) as u32 * GAP;
    let canvas_height = CHART_HEIGHT + PADDING * 2;
    let baseline = PADDING + CHART_HEIGHT;
    let mut canvas = RgbImage::from_pixel(canvas_width, canvas_height, Rgb([13, 17, 23]));

    for (i, ratio) in ratios.iter().copied().enumerate() {
        let bar_height = (ratio.clamp(0.0, 1.0) * CHART_HEIGHT as f32).round() as u32;
        let left = PADDING + i as u32 * (BAR_WIDTH + GAP);
        let top = PADDING + CHART_HEIGHT.saturating_sub(bar_height);
        let tint = Rgb(heatmap_color((i as f32 + 1.0) / count as f32));
        // Fill column by column; the final pixels are the same as a row-wise fill.
        for x in left..left + BAR_WIDTH {
            for y in top..baseline {
                canvas.put_pixel(x, y, tint);
            }
        }
    }

    canvas
        .save(output_path)
        .map_err(|e| VizError::Png(format!("Failed to save {}: {e}", output_path.display())))
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_heatmap_color_clamps_to_expected_extremes() {
        assert_eq!(heatmap_color(-1.0), [0, 0, 255]);
        assert_eq!(heatmap_color(0.5), [0, 255, 0]);
        assert_eq!(heatmap_color(2.0), [255, 0, 0]);
    }

    #[test]
    fn test_heatmap_color_endpoints_and_quarters() {
        assert_eq!(heatmap_color(0.0), [0, 0, 255]);
        assert_eq!(heatmap_color(0.25), [0, 127, 127]);
        assert_eq!(heatmap_color(0.75), [127, 127, 0]);
        assert_eq!(heatmap_color(1.0), [255, 0, 0]);
    }

    #[test]
    fn test_normalize_maps_extremes_to_unit_range() {
        let data = Array2::from_shape_vec((1, 3), vec![1.0, 2.0, 3.0]).unwrap();
        let norm = normalize(&data);
        assert_eq!(norm.iter().copied().collect::<Vec<_>>(), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_normalize_constant_input_is_all_zero() {
        let data = Array2::from_elem((2, 2), 4.0_f32);
        assert!(normalize(&data).iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_save_attention_overlay_matches_image_size() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("overlay.png");
        let original = DynamicImage::new_rgb8(16, 12);
        let attention = Array2::from_shape_fn((4, 4), |(i, j)| (i * 4 + j) as f32);
        save_attention_overlay(&original, &attention, &path, 0.5).unwrap();
        let written = image::open(&path).unwrap();
        assert_eq!((written.width(), written.height()), (16, 12));
    }

    #[test]
    fn test_save_pca_rgb_upscales_grid() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("pca.png");
        let projections = Array2::from_shape_fn((16, 3), |(i, c)| (i * 3 + c) as f32);
        save_pca_rgb(&projections, 4, &path).unwrap();
        let written = image::open(&path).unwrap();
        assert_eq!((written.width(), written.height()), (448, 448));
    }

    #[test]
    fn test_save_pca_rgb_rejects_too_few_patches() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("pca.png");
        let projections = Array2::<f32>::zeros((3, 3));
        assert!(save_pca_rgb(&projections, 2, &path).is_err());
        assert!(!path.exists());
    }

    #[test]
    fn test_save_similarity_heatmap() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("sim.png");
        let matrix = Array2::from_shape_fn((8, 8), |(i, j)| if i == j { 1.0 } else { 0.0 });
        save_similarity_heatmap(&matrix, &path).unwrap();
        let written = image::open(&path).unwrap();
        assert_eq!((written.width(), written.height()), (8, 8));
    }

    #[test]
    fn test_save_pairwise_heatmap() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("pairwise.png");
        let matrix = PairwiseMatrix {
            labels: vec!["a".into(), "b".into()],
            rows: vec![vec![Some(1.0), Some(0.5)], vec![Some(0.5), Some(1.0)]],
        };
        save_pairwise_heatmap(&matrix, &path).unwrap();
        assert!(path.exists());
    }

    #[test]
    fn test_save_variance_spectrum_chart() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("variance.png");
        save_variance_spectrum_chart(&[0.4, 0.25, 0.2, 0.1], &path).unwrap();
        assert!(path.exists());
    }

    #[test]
    fn test_save_variance_spectrum_chart_rejects_empty_input() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("variance_empty.png");
        assert!(save_variance_spectrum_chart(&[], &path).is_err());
        assert!(!path.exists());
    }

    #[test]
    fn test_save_series_chart() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("series.png");
        save_series_chart(&[0.2, 0.6, 0.4], &path).unwrap();
        assert!(path.exists());
    }

    #[test]
    fn test_save_series_chart_accepts_empty_input() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("series_empty.png");
        save_series_chart(&[], &path).unwrap();
        assert!(path.exists());
    }
}