use std::collections::HashMap;
use super::Graph;
/// Inclusive span of op indices over which a buffer's contents must be kept.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Interval {
    /// Index of the first op that writes the buffer.
    pub first_write_op: u32,
    /// Index of the op at which the buffer's live range ends (inclusive).
    pub last_read_op: u32,
}

/// Live intervals keyed by raw buffer id.
#[derive(Debug, Default)]
pub struct Lifetimes {
    /// Raw buffer id -> live interval of that buffer.
    pub intervals: HashMap<u32, Interval>,
}
pub fn analyze_lifetimes(graph: &Graph) -> Lifetimes {
let mut first_write: HashMap<u32, u32> = HashMap::new();
let mut last_read: HashMap<u32, u32> = HashMap::new();
for (op_idx, op) in graph.ops.iter().enumerate() {
let op_idx = op_idx as u32;
for buf in op.writes_raw() {
first_write.entry(buf).or_insert(op_idx);
}
for buf in op.reads_raw() {
if first_write.contains_key(&buf) {
last_read.insert(buf, op_idx);
}
}
}
let mut intervals = HashMap::with_capacity(first_write.len());
for (buf, first) in first_write {
let last = last_read.get(&buf).copied().unwrap_or(first);
intervals.insert(
buf,
Interval {
first_write_op: first,
last_read_op: last,
},
);
}
Lifetimes { intervals }
}
/// Color index assigned by `greedy_color`; buffers that receive the same
/// color have non-overlapping live intervals.
pub type ColorId = u32;

/// Output of `greedy_color`: one color per buffer plus the number of colors used.
#[derive(Debug, Default)]
pub struct ColoringMap {
    /// Raw buffer id -> assigned color.
    pub bufid_to_color: HashMap<u32, ColorId>,
    /// Number of distinct colors handed out; every assigned color is below it.
    pub color_count: u32,
}
/// Assigns a color to every buffer so that buffers whose intervals overlap
/// never share a color.
///
/// Intervals are visited in ascending `(first_write_op, buffer id)` order.
/// Each one takes the smallest color not held by a still-live interval.
///
/// An earlier interval counts as live while its `last_read_op` is `>=` the
/// current interval's `first_write_op`. Two intervals meeting at a single op
/// are therefore treated as overlapping.
///
/// `color_count` in the result is one past the highest color handed out, or
/// zero for empty input.
pub fn greedy_color(lifetimes: &Lifetimes) -> ColoringMap {
    let mut order: Vec<(u32, Interval)> = lifetimes
        .intervals
        .iter()
        .map(|(&id, &iv)| (id, iv))
        .collect();
    // Tie-break on buffer id so the result is independent of HashMap order.
    order.sort_by_key(|&(id, iv)| (iv.first_write_op, id));

    let mut live: Vec<(ColorId, u32)> = Vec::new();
    let mut color_count: ColorId = 0;
    let mut bufid_to_color = HashMap::with_capacity(order.len());

    for (id, iv) in order {
        // Forget intervals that ended strictly before this one begins.
        live.retain(|&(_, end)| end >= iv.first_write_op);

        let mut taken: Vec<ColorId> = live.iter().map(|&(c, _)| c).collect();
        taken.sort_unstable();
        // Live colors are distinct. So the first slot whose color differs
        // from its position is the lowest free color. If there is no such
        // slot, 0..taken.len() are all in use and taken.len() is next.
        let color = taken
            .iter()
            .zip(0..)
            .find_map(|(&c, want)| (c != want).then_some(want))
            .unwrap_or(taken.len() as ColorId);

        color_count = color_count.max(color + 1);
        live.push((color, iv.last_read_op));
        bufid_to_color.insert(id, color);
    }

    ColoringMap {
        bufid_to_color,
        color_count,
    }
}
#[cfg(test)]
mod tests {
    use super::super::buftype::{
        Buf, BufId, ConvOutBuf, HiddenBuf, OProjOutBuf, ResidualBuf,
    };
    use super::super::Op;
    use super::*;

    /// Builds a typed buffer id from a raw index.
    fn buf<B: Buf>(n: u32) -> BufId<B> {
        BufId::from_raw(n)
    }

    /// Single-token, single-dim residual add over raw buffers `a`, `b` -> `out`.
    fn resid(a: u32, b: u32, out: u32) -> Op {
        Op::ResidualAddNTokens {
            label: "test",
            a: buf::<OProjOutBuf>(a),
            b: buf::<HiddenBuf>(b).into(),
            out: buf::<ResidualBuf>(out),
            n_tokens: 1,
            dim: 1,
        }
    }

    /// Builds a `Lifetimes` from `(buffer id, first_write_op, last_read_op)` triples.
    fn lifetimes_from(entries: &[(u32, u32, u32)]) -> Lifetimes {
        let mut lifetimes = Lifetimes::default();
        for &(b, fw, lr) in entries {
            lifetimes.intervals.insert(
                b,
                Interval {
                    first_write_op: fw,
                    last_read_op: lr,
                },
            );
        }
        lifetimes
    }

    #[test]
    fn empty_graph_has_no_intervals() {
        let g = Graph::new();
        let lifetimes = analyze_lifetimes(&g);
        assert!(lifetimes.intervals.is_empty());
    }

    #[test]
    fn pure_input_bufids_absent_from_intervals() {
        let mut g = Graph::new();
        g.push(resid(0, 1, 2));
        let lifetimes = analyze_lifetimes(&g);
        assert_eq!(lifetimes.intervals.len(), 1);
        assert!(lifetimes.intervals.contains_key(&2));
        assert!(!lifetimes.intervals.contains_key(&0));
        assert!(!lifetimes.intervals.contains_key(&1));
        assert_eq!(
            lifetimes.intervals[&2],
            Interval { first_write_op: 0, last_read_op: 0 }
        );
    }

    #[test]
    fn chain_of_residuals_has_one_interval_per_intermediate() {
        let mut g = Graph::new();
        g.push(resid(0, 1, 5));
        g.push(resid(5, 2, 6));
        g.push(resid(6, 3, 7));
        g.push(resid(7, 4, 8));
        let lifetimes = analyze_lifetimes(&g);
        assert_eq!(lifetimes.intervals.len(), 4);
        assert_eq!(
            lifetimes.intervals[&5],
            Interval { first_write_op: 0, last_read_op: 1 }
        );
        assert_eq!(
            lifetimes.intervals[&6],
            Interval { first_write_op: 1, last_read_op: 2 }
        );
        assert_eq!(
            lifetimes.intervals[&7],
            Interval { first_write_op: 2, last_read_op: 3 }
        );
        assert_eq!(
            lifetimes.intervals[&8],
            Interval { first_write_op: 3, last_read_op: 3 }
        );
    }

    #[test]
    fn rmw_bufid_has_single_point_interval() {
        let mut g = Graph::new();
        g.push(Op::RmsNormQkNTokens {
            label: "qk",
            x: buf::<ConvOutBuf>(0),
            num_k_heads: 4,
            key_dim: 128,
            key_offset_per_token: 512,
            per_token_total: 1024,
            n_tokens: 1,
        });
        let lifetimes = analyze_lifetimes(&g);
        assert_eq!(
            lifetimes.intervals[&0],
            Interval { first_write_op: 0, last_read_op: 0 }
        );
    }

    #[test]
    fn coloring_empty_lifetimes() {
        let cm = greedy_color(&Lifetimes::default());
        assert_eq!(cm.color_count, 0);
        assert!(cm.bufid_to_color.is_empty());
    }

    #[test]
    fn coloring_disjoint_intervals_reuses_color() {
        let lifetimes = lifetimes_from(&[(0, 0, 1), (1, 2, 3)]);
        let cm = greedy_color(&lifetimes);
        assert_eq!(cm.color_count, 1);
        assert_eq!(cm.bufid_to_color[&0], 0);
        assert_eq!(cm.bufid_to_color[&1], 0);
    }

    #[test]
    fn coloring_overlapping_intervals_use_two_colors() {
        let lifetimes = lifetimes_from(&[(0, 0, 2), (1, 1, 3)]);
        let cm = greedy_color(&lifetimes);
        assert_eq!(cm.color_count, 2);
        assert_ne!(cm.bufid_to_color[&0], cm.bufid_to_color[&1]);
    }

    #[test]
    fn coloring_intervals_touching_at_one_op_do_not_share() {
        // Intervals are inclusive: ending and starting at op 1 is an overlap.
        let lifetimes = lifetimes_from(&[(0, 0, 1), (1, 1, 2)]);
        let cm = greedy_color(&lifetimes);
        assert_eq!(cm.color_count, 2);
        assert_ne!(cm.bufid_to_color[&0], cm.bufid_to_color[&1]);
    }

    #[test]
    fn coloring_fills_lowest_freed_color() {
        // Buffer 0 (color 0) ends before buffer 2 starts. Buffer 1 (color 1)
        // is still live then, so buffer 2 should take color 0 back.
        let lifetimes = lifetimes_from(&[(0, 0, 1), (1, 0, 5), (2, 2, 3)]);
        let cm = greedy_color(&lifetimes);
        assert_eq!(cm.bufid_to_color[&0], 0);
        assert_eq!(cm.bufid_to_color[&1], 1);
        assert_eq!(cm.bufid_to_color[&2], 0);
        assert_eq!(cm.color_count, 2);
    }

    #[test]
    fn coloring_ping_pong_chain_uses_two_colors() {
        let mut g = Graph::new();
        for i in 0..10 {
            let a = if i == 0 { 0 } else { 100 + i - 1 };
            let out = 100 + i;
            g.push(resid(a, i + 1, out));
        }
        let lifetimes = analyze_lifetimes(&g);
        let cm = greedy_color(&lifetimes);
        assert!(
            cm.color_count <= 2,
            "expected ≤ 2 colors, got {}",
            cm.color_count
        );
    }

    #[test]
    fn coloring_residual_chain_compresses_intermediates() {
        let mut g = Graph::new();
        g.push(resid(0, 1, 5));
        g.push(resid(5, 2, 6));
        g.push(resid(6, 3, 7));
        g.push(resid(7, 4, 8));
        let lifetimes = analyze_lifetimes(&g);
        let cm = greedy_color(&lifetimes);
        assert_eq!(cm.color_count, 2);
    }

    #[test]
    fn coloring_is_deterministic_across_runs() {
        let lifetimes_a = lifetimes_from(&[(10, 0, 1), (20, 1, 2), (30, 2, 3)]);
        let lifetimes_b = lifetimes_from(&[(30, 2, 3), (10, 0, 1), (20, 1, 2)]);
        let cm_a = greedy_color(&lifetimes_a);
        let cm_b = greedy_color(&lifetimes_b);
        assert_eq!(cm_a.color_count, cm_b.color_count);
        assert_eq!(cm_a.bufid_to_color, cm_b.bufid_to_color);
    }
}