use crate::memory::MemoryPlan;
use rlx_ir::NodeId;
/// Renders a [`MemoryPlan`] as a standalone SVG document.
///
/// The picture starts with a header line showing `arena_size`,
/// `total_unshared_bytes()` and `bytes_saved()`. It then has one row per
/// buffer assignment, sorted by `(offset, size, id)`. Each row is a rectangle
/// positioned at the buffer's offset and as wide as its size, both scaled by
/// a common bytes-per-pixel factor. A `%id: off=… sz=…` label is drawn
/// inside the rectangle.
///
/// Scaling:
/// - Arenas of 800 bytes or more map to a drawing area between 800 and 1599
///   pixels wide.
/// - Smaller arenas use one byte per pixel, with a 200-pixel minimum width.
/// - Every buffer is drawn at least 2 pixels wide so tiny allocations stay
///   visible.
pub fn render(plan: &MemoryPlan) -> String {
    let row_height = 24u32;
    let pad = 8u32;

    // Scale factor from bytes to pixels (always >= 1). Byte quantities are
    // divided by it *before* being narrowed to `u32` (see `bytes_to_px`), so
    // arenas, offsets and sizes of 4 GiB or more are scaled, not truncated.
    let bytes_per_pixel = (plan.arena_size.max(1) / 800).max(1);
    let width = bytes_to_px(plan.arena_size, bytes_per_pixel)
        .max(200)
        .saturating_add(2 * pad);

    let n_buffers = plan.assignments.len() as u32;
    // Header row + one row per buffer + padding above and below.
    let height = n_buffers * row_height + 2 * pad + row_height;

    // Rows are laid out top-to-bottom in address order.
    let mut rows: Vec<(usize, usize, NodeId)> = plan
        .assignments
        .iter()
        .map(|(id, s)| (s.offset, s.size, *id))
        .collect();
    rows.sort();

    let mut s = String::new();
    s.push_str(&format!(
        r##"<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}" font-family="monospace" font-size="11">"##
    ));
    s.push_str(r##"<rect x="0" y="0" width="100%" height="100%" fill="#fafafa"/>"##);
    s.push_str(&format!(
        r##"<text x="{}" y="{}" fill="#333">arena_size={} ({} unshared, saved {})</text>"##,
        pad,
        pad + 12,
        plan.arena_size,
        plan.total_unshared_bytes(),
        plan.bytes_saved(),
    ));

    // The first buffer row sits directly below the header row.
    let track_y_start = pad + row_height;
    for (i, &(offset, size, id)) in rows.iter().enumerate() {
        let x = pad.saturating_add(bytes_to_px(offset, bytes_per_pixel));
        let w = bytes_to_px(size, bytes_per_pixel).max(2);
        let y = track_y_start + (i as u32 * row_height);
        let color = color_for(id);
        s.push_str(&format!(
            r##"<rect x="{x}" y="{y}" width="{w}" height="{rh}" fill="{color}" fill-opacity="0.8" stroke="#333" stroke-width="0.5"/>"##,
            rh = row_height - 2,
        ));
        s.push_str(&format!(
            r##"<text x="{}" y="{}" fill="#222">%{}: off={} sz={}</text>"##,
            x.saturating_add(4),
            y + 14,
            id.0,
            offset,
            size,
        ));
    }
    s.push_str("</svg>");
    s
}

/// Converts a byte count to whole pixels at `bytes_per_pixel` bytes per pixel.
///
/// The division is done in `usize`. The quotient is then clamped to
/// `u32::MAX`, so an out-of-range value saturates instead of wrapping.
/// `bytes_per_pixel` must be non-zero; `render` always passes a value >= 1.
fn bytes_to_px(bytes: usize, bytes_per_pixel: usize) -> u32 {
    // The clamp makes the narrowing cast lossless.
    (bytes / bytes_per_pixel).min(u32::MAX as usize) as u32
}
/// Picks the fill colour used for a buffer's row, keyed by its node id.
///
/// Colours come from a fixed eight-entry palette, indexed by `id.0` modulo
/// the palette length. The same id therefore always gets the same colour,
/// and ids that differ by a multiple of eight share one.
fn color_for(id: NodeId) -> &'static str {
    static COLORS: [&str; 8] = [
        "#a5d8ff", // 0
        "#b2f2bb", // 1
        "#ffd8a8", // 2
        "#fcc2d7", // 3
        "#d0bfff", // 4
        "#ffe066", // 5
        "#ced4da", // 6
        "#a5b4fc", // 7
    ];
    let slot = id.0 as usize % COLORS.len();
    COLORS[slot]
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::plan_memory;
    use rlx_ir::*;

    /// Renders a small matmul graph's plan and checks the document structure:
    /// the SVG envelope, the exact header value, one rect per buffer (plus
    /// the background), and a label for every assignment.
    #[test]
    fn render_emits_svg() {
        let mut g = Graph::new("svg-test");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[8, 8], f));
        let w = g.param("w", Shape::new(&[8, 8], f));
        let mm = g.matmul(x, w, Shape::new(&[8, 8], f));
        g.set_outputs(vec![mm]);
        let plan = plan_memory(&g);
        let svg = render(&plan);

        assert!(svg.starts_with("<svg"));
        assert!(svg.ends_with("</svg>"));
        assert!(svg.contains(&format!("arena_size={} (", plan.arena_size)));
        // One full-canvas background rect plus one rect per assigned buffer.
        assert_eq!(svg.matches("<rect").count(), plan.assignments.len() + 1);
        // Every assignment gets a label carrying its id, offset and size.
        for (id, slot) in plan.assignments.iter() {
            let label = format!("%{}: off={} sz={}", id.0, slot.offset, slot.size);
            assert!(svg.contains(&label), "missing label {label:?}");
        }
    }

    #[test]
    fn color_is_stable_for_same_id() {
        let a = color_for(NodeId(7));
        let b = color_for(NodeId(7));
        assert_eq!(a, b);
    }

    /// The palette has eight entries, so ids eight apart share a colour
    /// while neighbouring ids get different ones.
    #[test]
    fn color_cycles_through_palette() {
        assert_eq!(color_for(NodeId(1)), color_for(NodeId(9)));
        assert_ne!(color_for(NodeId(1)), color_for(NodeId(2)));
    }
}