use kuva::backend::svg::SvgBackend;
use kuva::plot::{ColorMap, Heatmap, PhyloTree};
use kuva::render::figure::Figure;
use kuva::render::layout::Layout;
use kuva::render::plots::Plot;
use kuva::render::render::render_multiple;
/// Smoke test: renders a 3x3 heatmap configured with `with_values()` and the
/// Viridis color map, saves the SVG under `test_outputs/`, and checks that the
/// output contains an `<svg` element.
#[test]
fn test_heatmap_colorbar_values() {
    let data = vec![
        vec![10.0, 20.0, 30.0],
        vec![4.0, 50.0, 6.0],
        vec![7.0, 8.0, 90.0],
    ];
    let heatmap = Heatmap::new()
        .with_data(data)
        .with_values()
        .with_color_map(ColorMap::Viridis);

    let plots = vec![Plot::Heatmap(heatmap)];
    let layout = Layout::auto_from_plots(&plots).with_title("Heatmap");
    let svg = SvgBackend.render_scene(&render_multiple(plots, layout));

    // Create the output directory here so the test passes on a fresh checkout
    // and does not depend on another test having created it first.
    std::fs::create_dir_all("test_outputs").unwrap();
    std::fs::write("test_outputs/heatmap_values.svg", &svg).unwrap();

    assert!(svg.contains("<svg"));
}
/// Smoke test: renders a 3x3 Viridis heatmap, saves the SVG under
/// `test_outputs/`, and checks that the output contains both an `<svg`
/// element and at least one `<rect`.
#[test]
fn test_heatmap_colorbar() {
    let data = vec![
        vec![10.0, 20.0, 30.0],
        vec![4.0, 50.0, 6.0],
        vec![7.0, 8.0, 90.0],
    ];
    let heatmap = Heatmap::new()
        .with_data(data)
        .with_color_map(ColorMap::Viridis);

    let plots = vec![Plot::Heatmap(heatmap)];
    let layout = Layout::auto_from_plots(&plots).with_title("Heatmap with Colorbar");
    let svg = SvgBackend.render_scene(&render_multiple(plots, layout));

    // Create the output directory here so the test passes on a fresh checkout
    // and does not depend on another test having created it first.
    std::fs::create_dir_all("test_outputs").unwrap();
    std::fs::write("test_outputs/heatmap_colorbar.svg", &svg).unwrap();

    assert!(svg.contains("<svg"));
    assert!(svg.contains("<rect"));
}
/// Checks the row layout of a labelled 3x3 heatmap after
/// `with_y_categories(["C", "B", "A"])` (the desired top-to-bottom order).
///
/// The assertions pin that data row 0 holds A's values, row 1 holds B's and
/// row 2 holds C's, and that `row_labels` reads `["A", "B", "C"]`. The heatmap
/// is then rendered with those row labels as the layout's y categories and the
/// SVG is saved under `test_outputs/`.
#[test]
fn test_heatmap_with_y_categories_reorders_data() {
    // First column values (99, 3, 6) identify rows A, B and C respectively.
    let data = vec![
        vec![99.0, 1.0, 2.0],
        vec![3.0, 4.0, 5.0],
        vec![6.0, 7.0, 8.0],
    ];
    let row_labels: Vec<String> = ["A", "B", "C"].iter().map(ToString::to_string).collect();
    let col_labels: Vec<String> = ["x", "y", "z"].iter().map(ToString::to_string).collect();
    let desired_top_to_bottom: Vec<String> =
        ["C", "B", "A"].iter().map(ToString::to_string).collect();

    let heatmap = Heatmap::new()
        .with_data(data)
        .with_labels(row_labels, col_labels)
        .with_y_categories(desired_top_to_bottom);

    assert_eq!(heatmap.data[0][0], 99.0, "data row 0 (bottom) should be A");
    assert_eq!(heatmap.data[1][0], 3.0, "data row 1 should be B");
    assert_eq!(heatmap.data[2][0], 6.0, "data row 2 (top) should be C");

    let expected_row_labels: Vec<String> =
        ["A", "B", "C"].iter().map(ToString::to_string).collect();
    assert_eq!(
        heatmap.row_labels.as_deref(),
        Some(expected_row_labels.as_slice())
    );

    // The layout gets the heatmap's own (post-reorder) row labels.
    let layout_cats = heatmap.row_labels.clone().unwrap();
    let plots = vec![Plot::Heatmap(heatmap)];
    let layout = Layout::auto_from_plots(&plots)
        .with_title("Heatmap — C top, B mid, A bottom")
        .with_y_categories(layout_cats);
    let svg = SvgBackend.render_scene(&render_multiple(plots, layout));

    // Create the output directory here so the test passes on a fresh checkout
    // and does not depend on another test having created it first.
    std::fs::create_dir_all("test_outputs").unwrap();
    std::fs::write("test_outputs/heatmap_y_categories.svg", &svg).unwrap();

    assert!(svg.contains("<svg"));
}
/// Checks that `with_x_categories(["z", "x", "y"])` moves the original `z`
/// column to column 0 of every data row and that `col_labels` matches the
/// requested order. The heatmap is then rendered with the same x categories on
/// the layout and the SVG is saved under `test_outputs/`.
#[test]
fn test_heatmap_with_x_categories_reorders_data() {
    // The last column (10, 20, 30) is the `z` column for rows A, B and C.
    let data = vec![
        vec![1.0, 2.0, 10.0],
        vec![3.0, 4.0, 20.0],
        vec![5.0, 6.0, 30.0],
    ];
    let row_labels: Vec<String> = ["A", "B", "C"].iter().map(ToString::to_string).collect();
    let col_labels: Vec<String> = ["x", "y", "z"].iter().map(ToString::to_string).collect();
    let desired: Vec<String> = ["z", "x", "y"].iter().map(ToString::to_string).collect();

    let heatmap = Heatmap::new()
        .with_data(data)
        .with_labels(row_labels, col_labels)
        .with_x_categories(desired.clone());

    assert_eq!(
        heatmap.data[0][0], 10.0,
        "col 0 row 0 should be z-value for A"
    );
    assert_eq!(
        heatmap.data[1][0], 20.0,
        "col 0 row 1 should be z-value for B"
    );
    assert_eq!(
        heatmap.data[2][0], 30.0,
        "col 0 row 2 should be z-value for C"
    );
    assert_eq!(heatmap.col_labels.as_deref(), Some(desired.as_slice()));

    let plots = vec![Plot::Heatmap(heatmap)];
    let layout = Layout::auto_from_plots(&plots)
        .with_title("Heatmap — cols reordered z, x, y")
        .with_x_categories(desired);
    let svg = SvgBackend.render_scene(&render_multiple(plots, layout));

    // Create the output directory here so the test passes on a fresh checkout
    // and does not depend on another test having created it first.
    std::fs::create_dir_all("test_outputs").unwrap();
    std::fs::write("test_outputs/heatmap_x_categories.svg", &svg).unwrap();

    assert!(svg.contains("<svg"));
}
/// Renders a tree and a distance-matrix heatmap side by side with the heatmap
/// rows ordered by the tree's leaf order.
///
/// The tree is built from a 4x4 distance matrix via
/// `PhyloTree::from_distance_matrix(..).with_phylogram()`, and its
/// `leaf_labels_top_to_bottom()` result is passed to the heatmap's
/// `with_y_categories`. The test asserts that data row 0 has `0.0` in the
/// column of the last leaf (i.e. row 0 is that leaf's own row of the matrix),
/// then renders both plots in a 1x2 `Figure` and saves the SVG under
/// `test_outputs/`.
#[test]
fn test_phylo_heatmap_alignment() {
    let labels_str = ["Wolf", "Cat", "Whale", "Human"];
    let labels: Vec<String> = labels_str.iter().map(ToString::to_string).collect();
    let dist = vec![
        vec![0.0, 0.5, 0.9, 0.8],
        vec![0.5, 0.0, 0.9, 0.8],
        vec![0.9, 0.9, 0.0, 0.7],
        vec![0.8, 0.8, 0.7, 0.0],
    ];

    let tree = PhyloTree::from_distance_matrix(&labels_str, &dist).with_phylogram();
    let leaf_order = tree.leaf_labels_top_to_bottom();

    let heatmap = Heatmap::new()
        .with_data(dist)
        .with_labels(labels, vec![])
        .with_y_categories(leaf_order.clone());

    // Columns are not reordered here, so the last leaf's column index is its
    // position in the original label list.
    let last_leaf = leaf_order.last().unwrap().as_str();
    let last_leaf_idx_in_original = labels_str.iter().position(|&s| s == last_leaf).unwrap();
    assert_eq!(
        heatmap.data[0][last_leaf_idx_in_original], 0.0,
        "diagonal must be 0.0: bottom-of-tree leaf should be in data row 0"
    );

    let layout_cats = heatmap.row_labels.clone().unwrap();
    let tree_plots = vec![Plot::PhyloTree(tree)];
    let heatmap_plots = vec![Plot::Heatmap(heatmap)];
    let tree_layout = Layout::auto_from_plots(&tree_plots).with_title("UPGMA Tree");
    let heatmap_layout = Layout::auto_from_plots(&heatmap_plots)
        .with_title("Distance Matrix")
        .with_y_categories(layout_cats);

    let figure = Figure::new(1, 2)
        .with_plots(vec![tree_plots, heatmap_plots])
        .with_layouts(vec![tree_layout, heatmap_layout])
        .with_title("Phylo + Heatmap — aligned leaf order");
    let svg = SvgBackend.render_scene(&figure.render());

    // Create the output directory here so the test passes on a fresh checkout
    // and does not depend on another test having created it first.
    std::fs::create_dir_all("test_outputs").unwrap();
    std::fs::write("test_outputs/heatmap_phylo_alignment.svg", &svg).unwrap();

    assert!(svg.contains("<svg"));
}
/// Renders a 2x3 heatmap with an explicit x range of `[-10, 10]` and checks
/// that the SVG was produced and mentions the range bound.
#[test]
fn test_heatmap_x_range() {
    let grid = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
    let heatmap = Heatmap::new().with_data(grid).with_x_range(-10.0, 10.0);

    let plots = vec![Plot::Heatmap(heatmap)];
    let layout = Layout::auto_from_plots(&plots).with_x_label("X");
    let scene = render_multiple(plots, layout);
    let svg = SvgBackend.render_scene(&scene);

    assert!(svg.contains("<svg"));
    // NOTE(review): this is a loose check; any occurrence of "10" in the SVG
    // satisfies it, since "-10" itself contains "10".
    assert!(svg.contains("-10") || svg.contains("10"));
}
/// Renders a 3x2 heatmap with an explicit y range of `[-4, 4]` and checks
/// that the SVG was produced and mentions the range bound.
#[test]
fn test_heatmap_y_range() {
    let grid = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
    let heatmap = Heatmap::new().with_data(grid).with_y_range(-4.0, 4.0);

    let plots = vec![Plot::Heatmap(heatmap)];
    let layout = Layout::auto_from_plots(&plots).with_y_label("Y");
    let scene = render_multiple(plots, layout);
    let svg = SvgBackend.render_scene(&scene);

    assert!(svg.contains("<svg"));
    // NOTE(review): this is a loose check; any occurrence of "4" in the SVG
    // satisfies it, since "-4" itself contains "4".
    assert!(svg.contains("-4") || svg.contains("4"));
}
/// Smoke test: renders a 4x4 heatmap with explicit x and y ranges plus a
/// title and axis labels, saves the SVG under `test_outputs/`, and checks
/// that the output contains an `<svg` element.
#[test]
fn test_heatmap_xy_range() {
    let field = vec![
        vec![10.0, 20.0, 30.0, 40.0],
        vec![50.0, 60.0, 70.0, 80.0],
        vec![90.0, 80.0, 70.0, 60.0],
        vec![50.0, 40.0, 30.0, 20.0],
    ];
    let heatmap = Heatmap::new()
        .with_data(field)
        .with_x_range(-10.0, 10.0)
        .with_y_range(-4.0, 4.0);

    let plots = vec![Plot::Heatmap(heatmap)];
    let layout = Layout::auto_from_plots(&plots)
        .with_title("Scalar Field")
        .with_x_label("X (m)")
        .with_y_label("Y (m)");
    let scene = render_multiple(plots, layout);
    let svg = SvgBackend.render_scene(&scene);

    // Make sure the output directory exists before writing into it.
    std::fs::create_dir_all("test_outputs").unwrap();
    std::fs::write("test_outputs/heatmap_xy_range.svg", &svg).unwrap();

    assert!(svg.contains("<svg"));
}
/// A 2x2 heatmap without explicit ranges must report bounds of
/// `((0.5, 2.5), (0.5, 2.5))`.
#[test]
fn test_heatmap_default_range_unchanged() {
    let grid = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let plot = Plot::Heatmap(Heatmap::new().with_data(grid));

    let bounds = plot.bounds().unwrap();

    assert_eq!(bounds, ((0.5, 2.5), (0.5, 2.5)));
}
/// A heatmap with explicit x and y ranges must report exactly those ranges
/// as its bounds.
#[test]
fn test_heatmap_custom_range_bounds() {
    let grid = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
    let heatmap = Heatmap::new()
        .with_data(grid)
        .with_x_range(-10.0, 10.0)
        .with_y_range(-4.0, 4.0);
    let plot = Plot::Heatmap(heatmap);

    let bounds = plot.bounds().unwrap();

    assert_eq!(bounds, ((-10.0, 10.0), (-4.0, 4.0)));
}
/// A heatmap built without `with_cell_size` must have `cell_size == 0.99`
/// (compared with a `1e-9` tolerance).
#[test]
fn test_heatmap_cell_size_default() {
    let grid = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let heatmap = Heatmap::new().with_data(grid);

    let deviation = (heatmap.cell_size - 0.99).abs();

    assert!(deviation < 1e-9, "default cell_size should be 0.99");
}
/// Extracts the `(x, width)` pair of every `<rect ...>` tag in `svg`, in
/// document order.
///
/// Only the tag's own attribute list (the text between `<rect ` and the next
/// `>`) is inspected, so attributes of elements that follow a rect are never
/// attributed to it, and text before the first `<rect ` is ignored. Rects
/// whose `x` or `width` attribute is missing or does not parse as `f64`
/// (for example `width="100%"`) are skipped.
fn parse_rect_xw(svg: &str) -> Vec<(f64, f64)> {
    svg.split("<rect ")
        // The first piece is whatever precedes the first `<rect `, not a rect.
        .skip(1)
        .filter_map(|chunk| {
            // Cut the chunk at the end of the tag; the rest belongs to later elements.
            let tag = chunk.split('>').next().unwrap_or(chunk);
            let x = rect_attr_f64(tag, "x")?;
            let width = rect_attr_f64(tag, "width")?;
            Some((x, width))
        })
        .collect()
}

/// Returns the value of attribute `name` in `tag` parsed as `f64`.
///
/// The attribute name must start the tag text or follow whitespace, so that
/// looking up `x` does not match `rx="..."` and looking up `width` does not
/// match `stroke-width="..."`.
fn rect_attr_f64(tag: &str, name: &str) -> Option<f64> {
    let needle = format!("{name}=\"");
    let mut from = 0;
    while let Some(offset) = tag[from..].find(&needle) {
        let start = from + offset;
        let value_start = start + needle.len();
        let on_boundary = start == 0 || tag[..start].ends_with(char::is_whitespace);
        if on_boundary {
            return tag[value_start..].split('"').next()?.parse().ok();
        }
        // The match was the tail of a longer attribute name; keep looking.
        from = value_start;
    }
    None
}
/// Returns the first run of exactly `n` consecutive rects with similar widths.
///
/// A run starts at some rect and extends while each following rect's width is
/// within `1.0` of the run's first width. Runs shorter or longer than `n` are
/// skipped as a whole. Returns an empty vector when no such run exists or
/// when `n` is zero.
fn extract_cell_rects_n(rects: &[(f64, f64)], n: usize) -> Vec<(f64, f64)> {
    // A zero-length run is never meaningful, and without this guard the loop
    // below would index one past the end of `rects`.
    if n == 0 {
        return Vec::new();
    }

    let mut start = 0;
    while start + n <= rects.len() {
        let first_width = rects[start].1;
        let run_len = rects[start..]
            .iter()
            .take_while(|&&(_, w)| (w - first_width).abs() < 1.0)
            .count();
        if run_len == n {
            return rects[start..start + n].to_vec();
        }
        // `run_len` is 0 only when `first_width` is NaN; still advance by one.
        start += run_len.max(1);
    }
    Vec::new()
}
/// With the default cell size, a 1x4 heatmap rendered at width 500 must
/// produce four cell rects whose right edges do not extend past the next
/// cell's left edge (within a `1e-3` tolerance).
#[test]
fn test_heatmap_cell_size_gap_default() {
    let heatmap = Heatmap::new().with_data(vec![vec![1.0, 2.0, 3.0, 4.0]]);
    let plots = vec![Plot::Heatmap(heatmap)];
    let layout = Layout::auto_from_plots(&plots).with_width(500.0);
    let scene = render_multiple(plots, layout);
    let svg = SvgBackend.render_scene(&scene);

    std::fs::create_dir_all("test_outputs").unwrap();
    std::fs::write("test_outputs/heatmap_cell_size_gap.svg", &svg).unwrap();

    let cells = extract_cell_rects_n(&parse_rect_xw(&svg), 4);
    assert_eq!(cells.len(), 4, "expected 4 cell rects");

    // Order the cells left to right so neighbours can be compared pairwise.
    let mut left_to_right = cells;
    left_to_right.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

    for pair in left_to_right.windows(2) {
        let (x, width) = pair[0];
        let right = x + width;
        let next_left = pair[1].0;
        assert!(
            right < next_left + 1e-3,
            "default cell_size=0.99: right edge {right:.3} should be < next left {next_left:.3}"
        );
    }
}
/// With `with_cell_size(1.0)`, a 1x4 heatmap rendered at width 500 must
/// produce four cell rects where each right edge lies strictly past the next
/// cell's left edge.
#[test]
fn test_heatmap_cell_size_flush() {
    let data = vec![vec![1.0, 2.0, 3.0, 4.0]];
    let hm = Heatmap::new().with_data(data).with_cell_size(1.0);
    assert!((hm.cell_size - 1.0).abs() < 1e-9);

    let plots = vec![Plot::Heatmap(hm)];
    let layout = Layout::auto_from_plots(&plots).with_width(500.0);
    let svg = SvgBackend.render_scene(&render_multiple(plots, layout));

    // Create the output directory here so the test passes on a fresh checkout
    // and does not depend on another test having created it first.
    std::fs::create_dir_all("test_outputs").unwrap();
    std::fs::write("test_outputs/heatmap_flush.svg", &svg).unwrap();

    let all_rects = parse_rect_xw(&svg);
    let cells = extract_cell_rects_n(&all_rects, 4);
    assert_eq!(cells.len(), 4, "expected 4 cell rects");

    // Order the cells left to right so neighbours can be compared pairwise.
    let mut sorted = cells;
    sorted.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

    for w in sorted.windows(2) {
        let right = w[0].0 + w[0].1;
        let next_left = w[1].0;
        assert!(
            right > next_left,
            "cell_size=1.0: right edge {right:.3} should overlap next left {next_left:.3}"
        );
    }
}
/// `with_cell_size` must clamp its argument: `2.0` becomes `1.0` and `0.1`
/// becomes `0.5` (each compared with a `1e-9` tolerance).
#[test]
fn test_heatmap_cell_size_clamp() {
    // Above the upper limit.
    let oversized = Heatmap::new()
        .with_data(vec![vec![1.0]])
        .with_cell_size(2.0);
    let upper_deviation = (oversized.cell_size - 1.0).abs();
    assert!(upper_deviation < 1e-9, "cell_size should be clamped to 1.0");

    // Below the lower limit.
    let undersized = Heatmap::new()
        .with_data(vec![vec![1.0]])
        .with_cell_size(0.1);
    let lower_deviation = (undersized.cell_size - 0.5).abs();
    assert!(lower_deviation < 1e-9, "cell_size should be clamped to 0.5");
}