use crate::error::AutogradError;
use std::collections::BTreeMap;
/// Per-node metadata consumed by the checkpoint planners in this module.
#[derive(Debug, Clone)]
pub struct NodeMeta {
    /// Identifier of the node. [`CheckpointedGraph::recomputation_cost`] looks
    /// this value up in the checkpoint set.
    pub index: usize,
    /// Size of the activation produced by this node (units are chosen by the
    /// caller; this module only sums and compares the values).
    pub activation_size: usize,
    /// Cost of running this node's forward pass once (caller-defined units).
    pub forward_cost: f64,
    /// Optional human-readable label; `None` until [`NodeMeta::with_name`] is used.
    pub name: Option<String>,
}

impl NodeMeta {
    /// Builds an unnamed node description from its index, activation size and
    /// forward cost.
    pub fn new(index: usize, activation_size: usize, forward_cost: f64) -> Self {
        let name = None;
        Self {
            index,
            activation_size,
            forward_cost,
            name,
        }
    }

    /// Consumes the node and returns it with `name` set to the given label.
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }
}
/// A linear sequence of forward nodes together with the set of indices whose
/// activations are kept (checkpointed) rather than recomputed.
#[derive(Debug, Clone, Default)]
pub struct CheckpointedGraph {
    /// Nodes in insertion order.
    nodes: Vec<NodeMeta>,
    /// Checkpointed indices; the `BTreeSet` keeps them sorted and unique.
    checkpointed: std::collections::BTreeSet<usize>,
}

impl CheckpointedGraph {
    /// Creates an empty graph with no nodes and no checkpoints.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `meta` to the end of the node sequence.
    pub fn add_node(&mut self, meta: NodeMeta) {
        self.nodes.push(meta);
    }

    /// Marks `index` as checkpointed.
    ///
    /// The index is not validated against the number of nodes; indices that do
    /// not refer to a node are ignored by [`CheckpointedGraph::peak_memory`].
    pub fn checkpoint_node(&mut self, index: usize) {
        self.checkpointed.insert(index);
    }

    /// Clears the checkpoint mark on `index`; a no-op if it was not set.
    pub fn remove_checkpoint(&mut self, index: usize) {
        self.checkpointed.remove(&index);
    }

    /// Number of nodes added so far.
    pub fn num_nodes(&self) -> usize {
        self.nodes.len()
    }

    /// Number of distinct indices currently marked as checkpointed
    /// (including any that do not correspond to a node).
    pub fn checkpoint_count(&self) -> usize {
        self.checkpointed.len()
    }

    /// Returns `true` if `index` is currently marked as checkpointed.
    pub fn is_checkpointed(&self, index: usize) -> bool {
        self.checkpointed.contains(&index)
    }

    /// Iterates over the nodes in insertion order.
    pub fn nodes(&self) -> impl Iterator<Item = &NodeMeta> {
        self.nodes.iter()
    }

    /// Borrows the sorted set of checkpointed indices.
    pub fn checkpointed_indices(&self) -> &std::collections::BTreeSet<usize> {
        &self.checkpointed
    }

    /// Estimates peak activation memory under the current checkpoint set.
    ///
    /// The estimate is the sum of the activation sizes of all checkpointed
    /// nodes plus the largest total activation size of any run of consecutive
    /// non-checkpointed nodes. Checkpoints are interpreted as positions in the
    /// node sequence. With no (valid) checkpoints the result is the sum of all
    /// activation sizes; an empty graph yields `0`.
    ///
    /// Checkpoint indices that are `>= num_nodes()` are ignored, so this method
    /// never panics on a stale or out-of-range checkpoint.
    pub fn peak_memory(&self) -> usize {
        let n = self.nodes.len();
        if n == 0 {
            return 0;
        }

        // Only positions that refer to an existing node can hold memory or
        // delimit a segment; `range(..n)` yields them in ascending order.
        let cp: Vec<usize> = self.checkpointed.range(..n).copied().collect();
        if cp.is_empty() {
            return self.nodes.iter().map(|nd| nd.activation_size).sum();
        }

        let span_mem = |start: usize, end: usize| -> usize {
            self.nodes[start..end]
                .iter()
                .map(|nd| nd.activation_size)
                .sum()
        };

        let checkpoint_mem: usize = cp.iter().map(|&i| self.nodes[i].activation_size).sum();

        // Walk the gaps before, between and after the checkpoints. `cp` is
        // strictly increasing and every entry is < n, so `seg_start <= c` and
        // `seg_start <= n` always hold and the slices are in bounds.
        let mut max_segment_mem = 0usize;
        let mut seg_start = 0usize;
        for &c in &cp {
            max_segment_mem = max_segment_mem.max(span_mem(seg_start, c));
            seg_start = c + 1;
        }
        max_segment_mem = max_segment_mem.max(span_mem(seg_start, n));

        checkpoint_mem + max_segment_mem
    }

    /// Sums the forward cost of every node that is not checkpointed.
    ///
    /// NOTE(review): membership is tested with each node's `index` field,
    /// whereas `peak_memory` treats checkpoints as positions in the node
    /// sequence. The two agree only when every node's `index` equals its
    /// position — confirm that callers uphold this.
    pub fn recomputation_cost(&self) -> f64 {
        self.nodes
            .iter()
            .filter(|node| !self.checkpointed.contains(&node.index))
            .fold(0.0f64, |total, node| total + node.forward_cost)
    }
}
/// Runs `f` exactly once and bundles its return value with `checkpoint_inputs`.
///
/// The inputs are stored unchanged in [`CheckpointRecord::stored_inputs`]; the
/// closure's result becomes [`CheckpointRecord::output_index`].
pub fn checkpoint(checkpoint_inputs: Vec<usize>, f: impl FnOnce() -> usize) -> CheckpointRecord {
    CheckpointRecord {
        output_index: f(),
        stored_inputs: checkpoint_inputs,
    }
}
/// Result of a [`checkpoint`] call: the inputs that were kept and the index
/// returned by the wrapped closure.
#[derive(Debug, Clone)]
pub struct CheckpointRecord {
    /// The `checkpoint_inputs` vector passed to [`checkpoint`], unchanged.
    pub stored_inputs: Vec<usize>,
    /// The value returned by the closure given to [`checkpoint`].
    pub output_index: usize,
}
/// How much memory a schedule builder is allowed to spend on checkpoints.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MemoryBudget {
    /// An absolute amount of memory. The schedulers in this file convert it to
    /// a slot count by dividing by the average activation size, and reject `0`.
    Absolute(usize),
    /// A fraction of the graph's total activation memory. The schedulers in
    /// this file reject values outside `[0.0, 1.0]`.
    Fraction(f64),
    /// A direct number of checkpoint slots. The schedulers in this file reject `0`.
    Slots(usize),
}
/// A chosen set of checkpoints for a graph, with the estimates computed for it.
#[derive(Debug, Clone)]
pub struct CheckpointSchedule {
    /// Number of nodes in the graph the schedule was built for.
    pub num_nodes: usize,
    /// Indices of the nodes selected as checkpoints.
    pub checkpoints: Vec<usize>,
    /// Peak memory estimate for the graph with exactly these checkpoints applied.
    pub estimated_peak_memory: usize,
    /// Recomputation cost estimate for the graph with exactly these checkpoints applied.
    pub estimated_recomputation_cost: f64,
}
/// One point on the memory/compute trade-off curve produced by [`analyse_tradeoff`].
#[derive(Debug, Clone)]
pub struct TradeoffPoint {
    /// Slot budget that produced this point.
    pub slots: usize,
    /// Estimated peak memory of the schedule found for that budget.
    pub peak_memory: usize,
    /// Estimated recomputation cost divided by the baseline forward cost
    /// (`0.0` when the baseline cost is not positive).
    pub recomputation_overhead: f64,
}
/// Output of [`analyse_tradeoff`]: the selected trade-off points plus the
/// baseline figures they are measured against.
#[derive(Debug, Clone)]
pub struct MemoryComputeTradeoff {
    /// Trade-off points retained by the analysis, in ascending peak memory.
    pub pareto_points: Vec<TradeoffPoint>,
    /// Sum of all activation sizes in the graph.
    pub baseline_memory: usize,
    /// Sum of all forward costs in the graph.
    pub baseline_forward_cost: f64,
}
/// Builds a schedule that checkpoints every `ceil(sqrt(n))`-th node.
///
/// Checkpoints are placed at positions `0, k, 2k, ...` with
/// `k = ceil(sqrt(n))`, and the last node is always added if the stride did
/// not land on it. An empty graph produces an empty schedule with zeroed
/// estimates.
pub fn sqrt_schedule(graph: &CheckpointedGraph) -> CheckpointSchedule {
    let n = graph.num_nodes();
    let last = match n.checked_sub(1) {
        Some(last) => last,
        None => {
            return CheckpointSchedule {
                num_nodes: 0,
                checkpoints: Vec::new(),
                estimated_peak_memory: 0,
                estimated_recomputation_cost: 0.0,
            };
        }
    };

    // `max(1)` keeps the stride usable as a `step_by` argument.
    let stride = ((n as f64).sqrt().ceil() as usize).max(1);
    let mut checkpoints: Vec<usize> = (0..n).step_by(stride).collect();
    if checkpoints.last() != Some(&last) {
        checkpoints.push(last);
    }
    build_schedule(graph, checkpoints, n)
}
/// Builds a schedule that checkpoints every `interval`-th node.
///
/// Checkpoints are placed at positions `0, interval, 2 * interval, ...`, and
/// the last node is always added if the stride did not land on it. An empty
/// graph produces an empty schedule with zeroed estimates.
///
/// # Errors
///
/// Returns `AutogradError::OperationError` when `interval` is `0`; this check
/// runs before the graph is inspected.
pub fn uniform_schedule(
    graph: &CheckpointedGraph,
    interval: usize,
) -> Result<CheckpointSchedule, AutogradError> {
    if interval == 0 {
        return Err(AutogradError::OperationError(String::from(
            "uniform_schedule: interval must be > 0",
        )));
    }

    let n = graph.num_nodes();
    let last = match n.checked_sub(1) {
        Some(last) => last,
        None => {
            return Ok(CheckpointSchedule {
                num_nodes: 0,
                checkpoints: Vec::new(),
                estimated_peak_memory: 0,
                estimated_recomputation_cost: 0.0,
            });
        }
    };

    let mut checkpoints: Vec<usize> = (0..n).step_by(interval).collect();
    if checkpoints.last() != Some(&last) {
        checkpoints.push(last);
    }
    Ok(build_schedule(graph, checkpoints, n))
}
/// Builds a schedule whose checkpoints are spaced at roughly equal cumulative
/// forward cost, limited by `budget`.
///
/// The budget is first converted to a slot count:
/// * `Slots(s)` is used directly;
/// * `Absolute(mem)` becomes `mem / average_activation_size` (at least 1);
/// * `Fraction(f)` becomes `(total_memory * f) / average_activation_size`
///   (at least 1).
///
/// When the integer average activation size is `0`, both memory-based budgets
/// resolve to one slot per node.
///
/// If the slot count covers every node, all nodes are checkpointed. Otherwise
/// node 0 is checkpointed, further checkpoints are added each time the running
/// forward cost crosses the next multiple of `total_cost / (slots + 1)` until
/// `slots` checkpoints have been chosen, and the last node is then always
/// appended. The schedule therefore holds at most `slots + 1` checkpoints.
///
/// An empty graph yields an empty schedule without validating `budget`.
///
/// # Errors
///
/// Returns `AutogradError::OperationError` when the budget is `Slots(0)`,
/// `Absolute(0)`, or a `Fraction` outside `[0.0, 1.0]` (including NaN).
pub fn binomial_schedule(
    graph: &CheckpointedGraph,
    budget: MemoryBudget,
) -> Result<CheckpointSchedule, AutogradError> {
    let n = graph.num_nodes();
    if n == 0 {
        return Ok(CheckpointSchedule {
            num_nodes: 0,
            checkpoints: Vec::new(),
            estimated_peak_memory: 0,
            estimated_recomputation_cost: 0.0,
        });
    }

    let total_mem: usize = graph.nodes().map(|nd| nd.activation_size).sum();
    // Integer average; zero means activations are too small to meter, in which
    // case every node gets a slot.
    let avg_activation = total_mem / n;
    let slots_for_memory = |mem: usize| -> usize {
        if avg_activation == 0 {
            n
        } else {
            (mem / avg_activation).max(1)
        }
    };

    let slots = match budget {
        MemoryBudget::Slots(s) => {
            if s == 0 {
                return Err(AutogradError::OperationError(
                    "binomial_schedule: slot budget must be > 0".to_string(),
                ));
            }
            s
        }
        MemoryBudget::Absolute(mem) => {
            if mem == 0 {
                return Err(AutogradError::OperationError(
                    "binomial_schedule: memory budget must be > 0".to_string(),
                ));
            }
            slots_for_memory(mem)
        }
        MemoryBudget::Fraction(f) => {
            if !(0.0..=1.0).contains(&f) {
                return Err(AutogradError::OperationError(format!(
                    "binomial_schedule: fraction {} must be in [0.0, 1.0]",
                    f
                )));
            }
            slots_for_memory((total_mem as f64 * f) as usize)
        }
    };

    let checkpoints = if slots >= n {
        (0..n).collect::<Vec<_>>()
    } else {
        let costs: Vec<f64> = graph.nodes().map(|nd| nd.forward_cost).collect();
        let total_cost: f64 = costs.iter().sum();
        let target_interval = total_cost / (slots as f64 + 1.0);

        let mut cps = vec![0usize];
        let mut cumulative = 0.0f64;
        let mut next_target = target_interval;
        for (i, &c) in costs.iter().enumerate() {
            // Check the cap before pushing: node 0 already occupies a slot, so
            // a one-slot budget must not pick up an extra interior checkpoint.
            if cps.len() >= slots {
                break;
            }
            cumulative += c;
            if cumulative >= next_target && cps.last().copied() != Some(i) {
                cps.push(i);
                next_target += target_interval;
            }
        }

        // The final node is always kept, even if that uses one slot beyond the cap.
        if cps.last().copied() != Some(n - 1) {
            cps.push(n - 1);
        }
        cps
    };

    Ok(build_schedule(graph, checkpoints, n))
}
/// Chooses checkpoints by dynamic programming so that the total forward cost
/// of the non-checkpointed nodes is minimal for the given budget.
///
/// The budget is converted to a maximum slot count, capped at the number of
/// nodes: `Slots(s)` is used directly, while `Absolute(mem)` and `Fraction(f)`
/// are divided by the average activation size (treated as at least 1) and
/// raised to at least one slot.
///
/// The DP state `dp[i][k]` is the cheapest cost of the nodes skipped between
/// checkpoints when node `i` is a checkpoint and at most `k` checkpoints have
/// been used, with node 0 always a checkpoint. The answer is read at the last
/// node; the chosen chain is rebuilt from the `prev` table and node 0 is added
/// if the walk did not reach it.
///
/// If no finite cost exists at the last node (for example a single slot on a
/// multi-node graph), the function falls back to [`sqrt_schedule`].
///
/// An empty graph yields an empty schedule without validating the budget.
///
/// # Errors
///
/// Returns `AutogradError::OperationError` when the budget is `Slots(0)`,
/// `Absolute(0)`, or a `Fraction` outside `[0.0, 1.0]` (including NaN).
pub fn optimal_checkpointing_schedule(
    graph: &CheckpointedGraph,
    memory_budget: MemoryBudget,
) -> Result<CheckpointSchedule, AutogradError> {
    let n = graph.num_nodes();
    if n == 0 {
        return Ok(CheckpointSchedule {
            num_nodes: 0,
            checkpoints: Vec::new(),
            estimated_peak_memory: 0,
            estimated_recomputation_cost: 0.0,
        });
    }

    let total_mem: usize = graph.nodes().map(|nd| nd.activation_size).sum();
    // `n > 0` here, so the division is safe; clamp so it can be a divisor.
    let avg = (total_mem / n).max(1);

    let max_slots = match memory_budget {
        MemoryBudget::Slots(0) => {
            return Err(AutogradError::OperationError(
                "optimal_checkpointing_schedule: slot budget must be > 0".to_string(),
            ));
        }
        MemoryBudget::Slots(s) => s.min(n),
        MemoryBudget::Absolute(0) => {
            return Err(AutogradError::OperationError(
                "optimal_checkpointing_schedule: memory budget must be > 0".to_string(),
            ));
        }
        MemoryBudget::Absolute(mem) => (mem / avg).max(1).min(n),
        MemoryBudget::Fraction(f) => {
            if !(0.0..=1.0).contains(&f) {
                return Err(AutogradError::OperationError(format!(
                    "optimal_checkpointing_schedule: fraction {} must be in [0.0, 1.0]",
                    f
                )));
            }
            let target = (total_mem as f64 * f) as usize;
            (target / avg).max(1).min(n)
        }
    };

    // prefix[i] = forward cost of nodes 0..i, accumulated left to right.
    let mut prefix: Vec<f64> = Vec::with_capacity(n + 1);
    prefix.push(0.0f64);
    let mut running = 0.0f64;
    for nd in graph.nodes() {
        running += nd.forward_cost;
        prefix.push(running);
    }

    let inf = f64::INFINITY;
    let mut dp: Vec<Vec<f64>> = vec![vec![inf; max_slots + 1]; n];
    let mut prev: Vec<Vec<Option<usize>>> = vec![vec![None; max_slots + 1]; n];

    // Node 0 as the only checkpoint costs nothing, whatever the slot allowance.
    for cell in dp[0].iter_mut().skip(1) {
        *cell = 0.0;
    }

    for i in 1..n {
        for k in 1..=max_slots {
            let mut best = inf;
            let mut best_from: Option<usize> = None;
            for j in 0..i {
                let base = dp[j][k - 1];
                if base < inf {
                    // Nodes strictly between checkpoints j and i are recomputed.
                    let cost = base + (prefix[i] - prefix[j + 1]);
                    if cost < best {
                        best = cost;
                        best_from = Some(j);
                    }
                }
            }
            dp[i][k] = best;
            prev[i][k] = best_from;
        }
    }

    let last = n - 1;
    let final_row = &dp[last];
    let (best_k, best_cost) = (1..=max_slots)
        .map(|k| (k, final_row[k]))
        .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
        .unwrap_or((1, inf));
    if best_cost == inf {
        return Ok(sqrt_schedule(graph));
    }

    // Walk the predecessor chain back from the last node.
    let mut checkpoints = vec![last];
    let mut pos = last;
    let mut k_rem = best_k;
    while k_rem > 1 && pos != 0 {
        match prev[pos][k_rem] {
            Some(j) => {
                pos = j;
                k_rem -= 1;
                checkpoints.push(pos);
            }
            None => break,
        }
    }
    if !checkpoints.contains(&0) {
        checkpoints.push(0);
    }
    checkpoints.sort_unstable();
    checkpoints.dedup();

    Ok(build_schedule(graph, checkpoints, n))
}
/// Sweeps every slot budget from 1 to the number of nodes and returns the
/// Pareto-optimal memory/recomputation points.
///
/// For each budget, [`optimal_checkpointing_schedule`] is run with
/// `MemoryBudget::Slots(slots)`; budgets for which it returns an error are
/// skipped. A point is kept only if no other point has peak memory that is
/// less than or equal to its own together with a strictly lower recomputation
/// cost. Among points with identical memory and cost, the one with the fewest
/// slots is kept. Points are returned in ascending peak memory.
///
/// `recomputation_overhead` is the schedule's estimated recomputation cost
/// divided by the baseline forward cost, or `0.0` when that baseline is not
/// positive. An empty graph yields no points.
pub fn analyse_tradeoff(graph: &CheckpointedGraph) -> MemoryComputeTradeoff {
    let n = graph.num_nodes();
    let baseline_memory: usize = graph.nodes().map(|nd| nd.activation_size).sum();
    let baseline_forward_cost: f64 = graph.nodes().map(|nd| nd.forward_cost).sum();
    if n == 0 {
        return MemoryComputeTradeoff {
            pareto_points: Vec::new(),
            baseline_memory,
            baseline_forward_cost,
        };
    }

    // (peak memory, recomputation cost, slot budget) for every feasible budget.
    let mut raw_points: Vec<(usize, f64, usize)> = (1..=n)
        .filter_map(|slots| {
            optimal_checkpointing_schedule(graph, MemoryBudget::Slots(slots))
                .ok()
                .map(|s| (s.estimated_peak_memory, s.estimated_recomputation_cost, slots))
        })
        .collect();

    // Order by memory, then by cost. Without the cost tie-break, a costlier
    // point with the same memory could be visited first and wrongly kept even
    // though a later point dominates it.
    raw_points.sort_by(|a, b| {
        a.0.cmp(&b.0)
            .then_with(|| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
    });

    let mut pareto: Vec<TradeoffPoint> = Vec::new();
    let mut min_cost = f64::INFINITY;
    for &(mem, cost, slots) in &raw_points {
        // Memory never decreases along the sweep, so a point survives only if
        // it strictly lowers the best cost seen so far.
        if cost < min_cost {
            min_cost = cost;
            pareto.push(TradeoffPoint {
                slots,
                peak_memory: mem,
                recomputation_overhead: if baseline_forward_cost > 0.0 {
                    cost / baseline_forward_cost
                } else {
                    0.0
                },
            });
        }
    }

    MemoryComputeTradeoff {
        pareto_points: pareto,
        baseline_memory,
        baseline_forward_cost,
    }
}
/// Applies `checkpoints` to a copy of `graph` and packages the resulting
/// estimates into a [`CheckpointSchedule`].
///
/// The input graph is not modified. On the copy, any existing checkpoint marks
/// in `0..n` are cleared before the new ones are set; marks outside that range
/// are left as they were. Callers in this file pass `n = graph.num_nodes()`.
fn build_schedule(
    graph: &CheckpointedGraph,
    checkpoints: Vec<usize>,
    n: usize,
) -> CheckpointSchedule {
    let mut working = graph.clone();
    (0..n).for_each(|i| working.remove_checkpoint(i));
    checkpoints
        .iter()
        .for_each(|&cp| working.checkpoint_node(cp));

    let estimated_peak_memory = working.peak_memory();
    let estimated_recomputation_cost = working.recomputation_cost();
    CheckpointSchedule {
        num_nodes: n,
        checkpoints,
        estimated_peak_memory,
        estimated_recomputation_cost,
    }
}
// Unit tests for the checkpoint graph, the schedule builders and the
// trade-off analysis defined in the parent module.
#[cfg(test)]
mod tests {
use super::*;
// Builds a graph of `n` nodes that all share the same activation size and
// forward cost; each node's `index` equals its position.
fn make_uniform_graph(n: usize, act_size: usize, fwd_cost: f64) -> CheckpointedGraph {
let mut g = CheckpointedGraph::new();
for i in 0..n {
g.add_node(NodeMeta::new(i, act_size, fwd_cost));
}
g
}
// Adding and removing checkpoint marks is reflected by the count and by
// `is_checkpointed`.
#[test]
fn test_graph_checkpoint_tracking() {
let mut g = make_uniform_graph(5, 10, 1.0);
g.checkpoint_node(0);
g.checkpoint_node(2);
g.checkpoint_node(4);
assert_eq!(g.checkpoint_count(), 3);
assert!(g.is_checkpointed(0));
assert!(!g.is_checkpointed(1));
g.remove_checkpoint(2);
assert_eq!(g.checkpoint_count(), 2);
}
// With every node checkpointed, peak memory is the sum of all activations.
#[test]
fn test_peak_memory_all_checkpointed() {
let mut g = make_uniform_graph(4, 10, 1.0);
for i in 0..4 {
g.checkpoint_node(i);
}
assert_eq!(g.peak_memory(), 40);
}
// With no checkpoints at all, peak memory is also the sum of all activations.
#[test]
fn test_peak_memory_no_checkpoints() {
let g = make_uniform_graph(4, 10, 1.0);
assert_eq!(g.peak_memory(), 40);
}
// Only the two non-checkpointed nodes (cost 2.0 each) count towards recomputation.
#[test]
fn test_recomputation_cost() {
let mut g = make_uniform_graph(4, 10, 2.0);
g.checkpoint_node(0);
g.checkpoint_node(3);
assert!((g.recomputation_cost() - 4.0).abs() < 1e-9);
}
// `checkpoint` stores the inputs unchanged and records the closure's result.
#[test]
fn test_checkpoint_record() {
let record = checkpoint(vec![0, 5, 10], || 15);
assert_eq!(record.stored_inputs, vec![0, 5, 10]);
assert_eq!(record.output_index, 15);
}
// For 9 nodes the stride is 3, so 0 and 3 are checkpoints, and the last
// node (8) is always included.
#[test]
fn test_sqrt_schedule_9_nodes() {
let g = make_uniform_graph(9, 10, 1.0);
let sched = sqrt_schedule(&g);
assert!(sched.checkpoints.contains(&0), "should have 0");
assert!(sched.checkpoints.contains(&3), "should have 3");
assert!(sched.checkpoints.contains(&8), "should have last");
}
// An empty graph yields an empty schedule.
#[test]
fn test_sqrt_schedule_empty() {
let g = CheckpointedGraph::new();
let sched = sqrt_schedule(&g);
assert_eq!(sched.num_nodes, 0);
assert!(sched.checkpoints.is_empty());
}
// Interval 2 over 6 nodes gives 0, 2, 4 plus the appended last node 5.
#[test]
fn test_uniform_schedule_interval_2() {
let g = make_uniform_graph(6, 10, 1.0);
let sched = uniform_schedule(&g, 2).expect("uniform sched");
assert!(sched.checkpoints.contains(&0));
assert!(sched.checkpoints.contains(&2));
assert!(sched.checkpoints.contains(&4));
assert!(sched.checkpoints.contains(&5));
}
// A zero interval is rejected.
#[test]
fn test_uniform_schedule_zero_interval_error() {
let g = make_uniform_graph(4, 10, 1.0);
let r = uniform_schedule(&g, 0);
assert!(r.is_err());
}
// A budget of 4 slots may produce at most 5 checkpoints, because the last
// node is appended on top of the slot cap.
#[test]
fn test_binomial_schedule_slots() {
let g = make_uniform_graph(16, 1, 1.0);
let sched = binomial_schedule(&g, MemoryBudget::Slots(4)).expect("binomial sched");
assert!(sched.checkpoints.len() <= 5, "too many cps: {:?}", sched.checkpoints);
}
// A valid fraction budget produces a non-empty schedule.
#[test]
fn test_binomial_schedule_fraction() {
let g = make_uniform_graph(10, 10, 1.0);
let sched = binomial_schedule(&g, MemoryBudget::Fraction(0.5)).expect("binomial frac");
assert!(!sched.checkpoints.is_empty());
}
// Fractions above 1.0 are rejected.
#[test]
fn test_binomial_invalid_fraction_error() {
let g = make_uniform_graph(4, 10, 1.0);
let r = binomial_schedule(&g, MemoryBudget::Fraction(1.5));
assert!(r.is_err());
}
// The DP schedule for 3 slots is non-empty and stays within 4 checkpoints.
#[test]
fn test_optimal_schedule_small() {
let g = make_uniform_graph(8, 10, 1.0);
let sched = optimal_checkpointing_schedule(&g, MemoryBudget::Slots(3))
.expect("optimal sched");
assert!(!sched.checkpoints.is_empty());
assert!(sched.checkpoints.len() <= 4);
}
// With as many slots as nodes, every node is checkpointed and nothing is recomputed.
#[test]
fn test_optimal_schedule_all_slots() {
let g = make_uniform_graph(5, 10, 1.0);
let sched = optimal_checkpointing_schedule(&g, MemoryBudget::Slots(5))
.expect("optimal all slots");
assert_eq!(sched.checkpoints.len(), 5);
assert!((sched.estimated_recomputation_cost - 0.0).abs() < 1e-9);
}
// A zero slot budget is rejected.
#[test]
fn test_optimal_schedule_zero_slots_error() {
let g = make_uniform_graph(4, 10, 1.0);
let r = optimal_checkpointing_schedule(&g, MemoryBudget::Slots(0));
assert!(r.is_err());
}
// The analysis returns at least one point and reports the baseline totals
// (8 nodes x 10 memory, 8 nodes x 1.0 cost).
#[test]
fn test_analyse_tradeoff_non_empty() {
let g = make_uniform_graph(8, 10, 1.0);
let tradeoff = analyse_tradeoff(&g);
assert!(!tradeoff.pareto_points.is_empty());
assert_eq!(tradeoff.baseline_memory, 80);
assert!((tradeoff.baseline_forward_cost - 8.0).abs() < 1e-9);
}
// Returned points are ordered by non-decreasing peak memory.
#[test]
fn test_analyse_tradeoff_monotone_memory() {
let g = make_uniform_graph(8, 10, 1.0);
let tradeoff = analyse_tradeoff(&g);
let mems: Vec<usize> = tradeoff.pareto_points.iter().map(|p| p.peak_memory).collect();
for w in mems.windows(2) {
assert!(w[0] <= w[1], "memory not monotone: {:?}", mems);
}
}
// An empty graph yields no trade-off points.
#[test]
fn test_analyse_tradeoff_empty_graph() {
let g = CheckpointedGraph::new();
let tradeoff = analyse_tradeoff(&g);
assert!(tradeoff.pareto_points.is_empty());
}
}