use std::collections::HashSet;
use crate::breverse::BReverse;
use crate::bytecode_tape::{BtapeGuard, BtapeThreadLocal, BytecodeTape};
use crate::float::Float;
/// Computes the gradient of `loss` applied to the state reached by iterating
/// `step` `num_steps` times from `x0`, storing only selected intermediate
/// states (checkpoints) instead of the whole trajectory.
///
/// The forward sweep evaluates `step` on primal values only and keeps a copy
/// of the state at step 0 plus the states at the positions chosen by
/// `revolve_schedule`. The backward sweep is delegated to
/// `backward_from_checkpoints`, which recomputes each segment from its
/// checkpoint and pulls the adjoint back through it.
///
/// # Arguments
/// * `step` - one iteration of the state update; must preserve the state length.
/// * `loss` - scalar function of the final state.
/// * `x0` - initial state.
/// * `num_steps` - number of times `step` is applied. With `0`, `loss` is
///   recorded directly at `x0`.
/// * `num_checkpoints` - requested number of checkpoint positions; clamped to
///   the range `1..=num_steps`.
///
/// # Panics
/// Panics if `step` returns a vector whose length differs from `x0.len()`.
pub fn grad_checkpointed<F: Float + BtapeThreadLocal>(
    step: impl Fn(&[BReverse<F>]) -> Vec<BReverse<F>>,
    loss: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x0: &[F],
    num_steps: usize,
    num_checkpoints: usize,
) -> Vec<F> {
    let dim = x0.len();

    // No steps at all: differentiate the loss directly at the initial state.
    if num_steps == 0 {
        let (mut tape, _) = crate::api::record(loss, x0);
        return tape.gradient(x0);
    }

    let num_checkpoints = num_checkpoints.max(1).min(num_steps);

    // `revolve_schedule` returns its positions in ascending order and may
    // return more than `num_checkpoints` of them; only the first (smallest)
    // `num_checkpoints` positions are kept.
    let mut all_positions = revolve_schedule(num_steps, num_checkpoints);
    all_positions.truncate(num_checkpoints);
    let checkpoint_positions: HashSet<usize> = all_positions.into_iter().collect();

    // The initial state is always stored as the checkpoint for step 0.
    let mut checkpoints: Vec<(usize, Vec<F>)> = Vec::with_capacity(num_checkpoints + 1);
    checkpoints.push((0, x0.to_vec()));

    let mut current_state = x0.to_vec();
    for s in 0..num_steps {
        current_state = step_forward_primal(&step, &current_state);
        assert_eq!(
            current_state.len(),
            dim,
            "step must preserve dimension: expected {}, got {}",
            dim,
            current_state.len()
        );

        // The final state is handed to the backward sweep separately, so it
        // is never stored as a checkpoint.
        let next_step = s + 1;
        if next_step < num_steps && checkpoint_positions.contains(&next_step) {
            checkpoints.push((next_step, current_state.clone()));
        }
    }

    let final_state = current_state;
    backward_from_checkpoints(&step, loss, &final_state, &checkpoints, num_steps)
}
/// Checkpointed gradient for an iteration whose length is not known in
/// advance: `step` is applied until `stop` returns `true`.
///
/// Checkpoints are kept in a buffer of at most `num_checkpoints` entries.
/// A state is stored whenever the step index is a multiple of the current
/// `spacing`. When the buffer is full, every other stored checkpoint after
/// the initial one is discarded and `spacing` is doubled, so the retained
/// checkpoints stay on multiples of the new spacing.
///
/// # Arguments
/// * `step` - one iteration of the state update; must preserve the state length.
/// * `stop` - called as `stop(state, step_index)`; first with `(x0, 0)` and
///   then after every step. Iteration ends when it returns `true`.
/// * `loss` - scalar function of the final state.
/// * `x0` - initial state.
/// * `num_checkpoints` - capacity of the checkpoint buffer; must be at least 2.
///
/// # Panics
/// Panics if `num_checkpoints < 2`, or if `step` returns a vector whose
/// length differs from `x0.len()`.
pub fn grad_checkpointed_online<F: Float + BtapeThreadLocal>(
    step: impl Fn(&[BReverse<F>]) -> Vec<BReverse<F>>,
    stop: impl Fn(&[F], usize) -> bool,
    loss: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x0: &[F],
    num_checkpoints: usize,
) -> Vec<F> {
    assert!(
        num_checkpoints >= 2,
        "online checkpointing requires at least 2 checkpoint slots, got {}",
        num_checkpoints,
    );
    let dim = x0.len();

    // Stopping immediately means no steps: differentiate the loss at `x0`.
    if stop(x0, 0) {
        let (mut tape, _) = crate::api::record(loss, x0);
        return tape.gradient(x0);
    }

    let mut buffer: Vec<(usize, Vec<F>)> = Vec::with_capacity(num_checkpoints);
    buffer.push((0, x0.to_vec()));
    let mut spacing = 1usize;
    let mut current_state = x0.to_vec();
    let mut step_index = 0usize;

    loop {
        current_state = step_forward_primal(&step, &current_state);
        step_index += 1;
        assert_eq!(
            current_state.len(),
            dim,
            "step must preserve dimension: expected {}, got {}",
            dim,
            current_state.len()
        );

        if step_index.is_multiple_of(spacing) {
            buffer.push((step_index, current_state.clone()));
        }

        if stop(&current_state, step_index) {
            break;
        }

        // Buffer full: keep entry 0 and every second entry after it
        // (indices 2, 4, ...), then store only half as often from now on.
        if buffer.len() >= num_checkpoints {
            let tail: Vec<(usize, Vec<F>)> =
                buffer[1..].iter().skip(1).step_by(2).cloned().collect();
            buffer.truncate(1);
            buffer.extend(tail);
            spacing *= 2;
        }
    }

    let num_steps = step_index;
    let final_state = current_state;
    backward_from_checkpoints(&step, loss, &final_state, &buffer, num_steps)
}
/// Checkpointed gradient where the caller forces checkpoints at specific
/// step positions.
///
/// The required positions split `0..num_steps` into intervals. Checkpoint
/// slots left over after the required positions are distributed across the
/// intervals in proportion to their length (`largest_remainder_alloc`), and
/// each interval with more than one step and at least one slot gets its own
/// `revolve_schedule`, shifted by the interval start.
///
/// # Arguments
/// * `step` - one iteration of the state update; must preserve the state length.
/// * `loss` - scalar function of the final state.
/// * `x0` - initial state.
/// * `num_steps` - number of times `step` is applied. With `0`, `loss` is
///   recorded directly at `x0`.
/// * `num_checkpoints` - requested number of checkpoint positions; clamped to
///   the range `1..=num_steps`.
/// * `required_positions` - step indices that must be checkpointed. Entries
///   outside `1..num_steps` are ignored and duplicates are removed.
///
/// # Panics
/// Panics if the number of distinct in-range required positions exceeds the
/// clamped `num_checkpoints`, or if `step` returns a vector whose length
/// differs from `x0.len()`.
pub fn grad_checkpointed_with_hints<F: Float + BtapeThreadLocal>(
    step: impl Fn(&[BReverse<F>]) -> Vec<BReverse<F>>,
    loss: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x0: &[F],
    num_steps: usize,
    num_checkpoints: usize,
    required_positions: &[usize],
) -> Vec<F> {
    let dim = x0.len();

    // No steps at all: differentiate the loss directly at the initial state.
    if num_steps == 0 {
        let (mut tape, _) = crate::api::record(loss, x0);
        return tape.gradient(x0);
    }

    let num_checkpoints = num_checkpoints.max(1).min(num_steps);

    // Normalise the hints: in range, sorted, unique.
    let mut required: Vec<usize> = required_positions
        .iter()
        .copied()
        .filter(|&p| p >= 1 && p < num_steps)
        .collect();
    required.sort_unstable();
    required.dedup();
    assert!(
        required.len() <= num_checkpoints,
        "required positions ({}) exceed available checkpoint slots ({})",
        required.len(),
        num_checkpoints,
    );
    let free = num_checkpoints.saturating_sub(required.len());

    // Interval boundaries: 0, every required position, num_steps.
    let mut boundaries = Vec::with_capacity(required.len() + 2);
    boundaries.push(0);
    boundaries.extend_from_slice(&required);
    boundaries.push(num_steps);
    boundaries.dedup();

    let intervals: Vec<(usize, usize)> = boundaries.windows(2).map(|w| (w[0], w[1])).collect();
    let interval_lengths: Vec<usize> = intervals.iter().map(|(s, e)| e - s).collect();
    let total_len: usize = interval_lengths.iter().sum();
    let slot_alloc = largest_remainder_alloc(free, &interval_lengths, total_len);

    // Required positions plus a per-interval schedule for the free slots.
    let mut all_positions: HashSet<usize> = required.iter().copied().collect();
    for (i, &(start, end)) in intervals.iter().enumerate() {
        let sub_steps = end - start;
        let sub_slots = slot_alloc[i];
        if sub_steps > 1 && sub_slots > 0 {
            // The schedule is sorted ascending; keep at most `sub_slots`
            // of its smallest positions.
            let mut sub_positions = revolve_schedule(sub_steps, sub_slots);
            sub_positions.truncate(sub_slots);
            all_positions.extend(sub_positions.iter().map(|&p| p + start));
        }
    }

    // The initial state is always stored as the checkpoint for step 0.
    let mut checkpoints: Vec<(usize, Vec<F>)> = Vec::with_capacity(all_positions.len() + 1);
    checkpoints.push((0, x0.to_vec()));

    let mut current_state = x0.to_vec();
    for s in 0..num_steps {
        current_state = step_forward_primal(&step, &current_state);
        assert_eq!(
            current_state.len(),
            dim,
            "step must preserve dimension: expected {}, got {}",
            dim,
            current_state.len()
        );
        let next_step = s + 1;
        if next_step < num_steps && all_positions.contains(&next_step) {
            checkpoints.push((next_step, current_state.clone()));
        }
    }

    let final_state = current_state;
    backward_from_checkpoints(&step, loss, &final_state, &checkpoints, num_steps)
}
/// Splits `total` units across buckets in proportion to `weights` using the
/// largest-remainder method.
///
/// Every bucket first receives the floor of its exact share
/// `w * total / weight_sum`. The units still unassigned are then handed out
/// one each to the buckets with the largest fractional part (computed in
/// `f64`); buckets with equal fractional parts keep their index order.
///
/// Returns all zeros when `weights` is empty or `weight_sum` is zero.
/// `weight_sum` is expected to be the sum of `weights`.
fn largest_remainder_alloc(total: usize, weights: &[usize], weight_sum: usize) -> Vec<usize> {
    if weights.is_empty() || weight_sum == 0 {
        return vec![0; weights.len()];
    }

    // Integer (floor) share of every bucket.
    let mut shares: Vec<usize> = Vec::with_capacity(weights.len());
    for &w in weights {
        shares.push((w * total) / weight_sum);
    }

    let handed_out: usize = shares.iter().sum();
    let leftover = total - handed_out;
    if leftover == 0 {
        return shares;
    }

    // Fractional part each bucket lost to flooring, tagged with its index.
    let mut by_fraction: Vec<(usize, f64)> = weights
        .iter()
        .zip(&shares)
        .enumerate()
        .map(|(idx, (&w, &floor))| {
            let exact = (w as f64 * total as f64) / weight_sum as f64;
            (idx, exact - floor as f64)
        })
        .collect();

    // Stable sort, largest fraction first.
    by_fraction.sort_by(|lhs, rhs| lhs.1.partial_cmp(&rhs.1).unwrap().reverse());

    for &(idx, _) in by_fraction.iter().take(leftover) {
        shares[idx] += 1;
    }
    shares
}
/// Checkpointed gradient that stores its checkpoints as files in `dir`
/// instead of keeping them in memory.
///
/// The forward sweep writes the state at step 0 and at every selected
/// schedule position to `dir/ckpt_<step>.bin`. The backward sweep walks the
/// checkpoints from last to first, reloads each one, recomputes the states of
/// its segment and pulls the adjoint back through them. All files created
/// here are tracked by a `DiskCheckpointGuard`, which removes them on normal
/// return and when the guard is dropped during a panic.
///
/// # Arguments
/// * `step` - one iteration of the state update; must preserve the state length.
/// * `loss` - scalar function of the final state.
/// * `x0` - initial state.
/// * `num_steps` - number of times `step` is applied. With `0`, `loss` is
///   recorded directly at `x0` and no file is written.
/// * `num_checkpoints` - requested number of checkpoint positions; clamped to
///   the range `1..=num_steps`.
/// * `dir` - existing directory that receives the checkpoint files.
///
/// # Panics
/// Panics if `dir` is not a directory, if `step` changes the state length, or
/// if a checkpoint file cannot be written or read back with the expected size.
pub fn grad_checkpointed_disk<F: Float + BtapeThreadLocal>(
    step: impl Fn(&[BReverse<F>]) -> Vec<BReverse<F>>,
    loss: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x0: &[F],
    num_steps: usize,
    num_checkpoints: usize,
    dir: &std::path::Path,
) -> Vec<F> {
    assert!(
        dir.is_dir(),
        "checkpoint directory does not exist: {}",
        dir.display()
    );
    let dim = x0.len();

    // No steps at all: differentiate the loss directly at the initial state.
    if num_steps == 0 {
        let (mut tape, _) = crate::api::record(loss, x0);
        return tape.gradient(x0);
    }

    let num_checkpoints = num_checkpoints.max(1).min(num_steps);

    // The schedule is sorted ascending; keep at most `num_checkpoints` of its
    // smallest positions.
    let mut all_positions = revolve_schedule(num_steps, num_checkpoints);
    all_positions.truncate(num_checkpoints);
    let checkpoint_positions: HashSet<usize> = all_positions.into_iter().collect();

    let mut guard = DiskCheckpointGuard { files: Vec::new() };
    // Steps that have a file on disk, in ascending order.
    let mut ckpt_steps: Vec<usize> = vec![0];

    // Each path is registered with the guard *before* the write so that a
    // failed or partial write is still cleaned up.
    let path_0 = dir.join("ckpt_0.bin");
    guard.files.push(path_0.clone());
    write_checkpoint(x0, &path_0);

    let mut current_state = x0.to_vec();
    for s in 0..num_steps {
        current_state = step_forward_primal(&step, &current_state);
        assert_eq!(
            current_state.len(),
            dim,
            "step must preserve dimension: expected {}, got {}",
            dim,
            current_state.len()
        );
        let next_step = s + 1;
        if next_step < num_steps && checkpoint_positions.contains(&next_step) {
            let path = dir.join(format!("ckpt_{}.bin", next_step));
            guard.files.push(path.clone());
            write_checkpoint(&current_state, &path);
            ckpt_steps.push(next_step);
        }
    }
    let final_state = current_state;

    // Seed the adjoint with the gradient of the loss at the final state.
    let mut adjoint = {
        let (mut tape, _) = crate::api::record(loss, &final_state);
        tape.gradient(&final_state)
    };

    for (seg, &ckpt_step) in ckpt_steps.iter().enumerate().rev() {
        // A segment runs from its checkpoint to the next one (or to the end).
        let seg_end = ckpt_steps.get(seg + 1).copied().unwrap_or(num_steps);
        let seg_len = seg_end - ckpt_step;

        let path = dir.join(format!("ckpt_{}.bin", ckpt_step));
        let ckpt_state = read_checkpoint::<F>(&path, dim);

        // Only the input state of each step in the segment is needed for the
        // pullback, so the state at `seg_end` is not recomputed.
        let mut states: Vec<Vec<F>> = Vec::with_capacity(seg_len.max(1));
        states.push(ckpt_state);
        for i in 1..seg_len {
            let next = step_forward_primal(&step, &states[i - 1]);
            states.push(next);
        }

        for i in (0..seg_len).rev() {
            adjoint = vjp_step(&step, &states[i], &adjoint);
        }
    }

    guard.cleanup();
    adjoint
}
/// Writes the in-memory bytes of `state` to the file at `path`, replacing any
/// existing content.
///
/// # Panics
/// Panics with "checkpoint write failed" if the file cannot be written.
fn write_checkpoint<F: Float>(state: &[F], path: &std::path::Path) {
    let byte_len = std::mem::size_of_val(state);
    let first_byte = state.as_ptr().cast::<u8>();
    // SAFETY: the pointer and length cover exactly the memory of `state`,
    // which stays borrowed for as long as `raw` is used.
    // NOTE(review): this assumes every `F` is plain data without padding
    // bytes — confirm for all `Float` implementors.
    let raw: &[u8] = unsafe { std::slice::from_raw_parts(first_byte, byte_len) };
    std::fs::write(path, raw).expect("checkpoint write failed");
}
/// Reads a state of `dim` elements back from the file at `path`.
///
/// The file content is copied byte for byte into a freshly allocated vector,
/// so it must have been produced from values of the same type `F`.
///
/// # Panics
/// Panics if the file cannot be read or if its size is not
/// `dim * size_of::<F>()` bytes.
fn read_checkpoint<F: Float>(path: &std::path::Path, dim: usize) -> Vec<F> {
    let raw = std::fs::read(path).expect("checkpoint read failed");
    let expected_len = dim * std::mem::size_of::<F>();
    assert_eq!(
        raw.len(),
        expected_len,
        "checkpoint file size mismatch: expected {}, got {}",
        expected_len,
        raw.len()
    );

    let mut restored = vec![F::zero(); dim];
    let dst = restored.as_mut_ptr().cast::<u8>();
    // SAFETY: `restored` owns `dim * size_of::<F>()` bytes, which the
    // assertion above shows equals `raw.len()`; the two buffers are separate
    // allocations and cannot overlap.
    // NOTE(review): this assumes any byte pattern written by
    // `write_checkpoint` is a valid `F` — confirm for all `Float` implementors.
    unsafe {
        std::ptr::copy_nonoverlapping(raw.as_ptr(), dst, raw.len());
    }
    restored
}
/// Tracks files created on disk so that they are removed again, either
/// explicitly through `cleanup` or implicitly when the guard is dropped.
struct DiskCheckpointGuard {
    /// Paths of the files to delete.
    files: Vec<std::path::PathBuf>,
}

impl DiskCheckpointGuard {
    /// Deletes every tracked file, in insertion order, and empties the list.
    /// Deletion errors (for example a file that no longer exists) are ignored.
    fn cleanup(&mut self) {
        self.files.drain(..).for_each(|path| {
            let _ = std::fs::remove_file(path);
        });
    }
}

impl Drop for DiskCheckpointGuard {
    /// Removes any files that are still tracked when the guard goes away.
    fn drop(&mut self) {
        self.cleanup();
    }
}
/// Backward sweep shared by the in-memory checkpointing drivers.
///
/// Starts from the gradient of `loss` at `final_state` and walks the
/// checkpoints from last to first. For each checkpoint the states of its
/// segment (up to the next checkpoint, or up to `num_steps` for the last
/// one) are recomputed with `step_forward_primal`, and the adjoint is then
/// pulled back through the segment step by step with `vjp_step`.
///
/// # Arguments
/// * `step` - one iteration of the state update.
/// * `loss` - scalar function of the final state.
/// * `final_state` - state after `num_steps` applications of `step`.
/// * `checkpoints` - `(step_index, state)` pairs; expected in ascending
///   `step_index` order and starting at step 0.
/// * `num_steps` - total number of steps taken in the forward sweep.
///
/// Returns the adjoint after all segments have been processed.
fn backward_from_checkpoints<F: Float + BtapeThreadLocal>(
    step: &impl Fn(&[BReverse<F>]) -> Vec<BReverse<F>>,
    loss: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    final_state: &[F],
    checkpoints: &[(usize, Vec<F>)],
    num_steps: usize,
) -> Vec<F> {
    // Seed the adjoint with the gradient of the loss at the final state.
    let mut adjoint = {
        let (mut tape, _) = crate::api::record(loss, final_state);
        tape.gradient(final_state)
    };

    for (seg, (ckpt_step, ckpt_state)) in checkpoints.iter().enumerate().rev() {
        // A segment runs from its checkpoint to the next one (or to the end).
        let seg_end = checkpoints.get(seg + 1).map_or(num_steps, |next| next.0);
        let seg_len = seg_end - *ckpt_step;
        if seg_len == 0 {
            // A checkpoint taken at the very end has nothing to pull back.
            continue;
        }

        // Only the input state of each step in the segment is needed for the
        // pullback, so the state at `seg_end` is not recomputed.
        let mut states: Vec<Vec<F>> = Vec::with_capacity(seg_len);
        states.push(ckpt_state.clone());
        for i in 1..seg_len {
            let next = step_forward_primal(step, &states[i - 1]);
            states.push(next);
        }

        for state in states.iter().rev() {
            adjoint = vjp_step(step, state, &adjoint);
        }
    }

    adjoint
}
/// Returns the checkpoint positions for `num_steps` steps and
/// `num_checkpoints` slots, sorted ascending and without duplicates.
///
/// When there are at least as many slots as steps, every interior position
/// `1..num_steps` is returned. Otherwise the positions are the split points
/// collected by `schedule_recursive`; that list can contain more entries
/// than `num_checkpoints`.
fn revolve_schedule(num_steps: usize, num_checkpoints: usize) -> Vec<usize> {
    if num_steps <= num_checkpoints {
        // Enough slots to store every interior state.
        (1..num_steps).collect()
    } else {
        let mut split_points = Vec::new();
        schedule_recursive(0, num_steps, num_checkpoints, &mut split_points);
        split_points.sort_unstable();
        split_points.dedup();
        split_points
    }
}
/// Collects split points for the step range `start..end` into `positions`.
///
/// Each round picks a split `optimal_advance` steps after the current start,
/// records it, schedules the part before the split with one checkpoint
/// fewer, and then continues with the part after the split using the same
/// number of checkpoints. Positions are appended in visiting order, not
/// sorted.
fn schedule_recursive(start: usize, end: usize, checkpoints: usize, positions: &mut Vec<usize>) {
    let mut lo = start;
    loop {
        let span = end - lo;
        if span <= 1 || checkpoints == 0 {
            return;
        }

        let split = lo + optimal_advance(span, checkpoints);
        // Stop unless the split lies strictly inside the current range.
        if split <= lo || split >= end {
            return;
        }

        positions.push(split);
        schedule_recursive(lo, split, checkpoints - 1, positions);
        // The right-hand part is handled by the next loop iteration.
        lo = split;
    }
}
/// Returns how many steps to advance before placing the next split in a
/// range of `steps` steps with `c` checkpoints.
///
/// With `c == 0` or `steps <= 1` the whole range (`steps`) is returned.
/// Otherwise `t` is the smallest value `>= 1` with `beta(t, c) >= steps`,
/// and the result is `beta(t - 1, c - 1)` limited to `1..=steps - 1`.
fn optimal_advance(steps: usize, c: usize) -> usize {
    if steps <= 1 || c == 0 {
        return steps;
    }

    let mut repetitions = 1usize;
    loop {
        if beta(repetitions, c) >= steps {
            break;
        }
        repetitions += 1;
    }

    // Here `steps >= 2`, so the clamp bounds satisfy `1 <= steps - 1`.
    beta(repetitions - 1, c - 1).clamp(1, steps - 1)
}
/// Capacity function used by the checkpoint schedule.
///
/// * `c == 0` gives `s + 1`.
/// * `s == 0` (with `c > 0`) gives `1`.
/// * Otherwise the result is the binomial coefficient `C(s + c, c)`, built up
///   one factor at a time so that every intermediate division is exact.
///
/// Returns `usize::MAX` if an intermediate product overflows.
fn beta(s: usize, c: usize) -> usize {
    match (s, c) {
        (_, 0) => s + 1,
        (0, _) => 1,
        _ => (0..c)
            .try_fold(1usize, |acc, i| {
                // acc == C(s + c, i); multiply by (s + c - i), divide by (i + 1).
                acc.checked_mul(s + c - i).map(|product| product / (i + 1))
            })
            .unwrap_or(usize::MAX),
    }
}
/// Applies `step` to `state` and returns only the resulting values.
///
/// Each element of `state` is registered as an input on a scratch tape, and
/// `step` runs while a `BtapeGuard` built from that tape is alive
/// (presumably making it the active thread-local tape — confirm against
/// `BtapeGuard`). The tape itself is discarded afterwards.
fn step_forward_primal<F: Float + BtapeThreadLocal>(
    step: &impl Fn(&[BReverse<F>]) -> Vec<BReverse<F>>,
    state: &[F],
) -> Vec<F> {
    let mut tape = BytecodeTape::with_capacity(state.len() * 10);

    // Register every state element as a tape input, in order.
    let mut inputs: Vec<BReverse<F>> = Vec::with_capacity(state.len());
    for &value in state {
        let slot = tape.new_input(value);
        inputs.push(BReverse::from_tape(value, slot));
    }

    let values: Vec<F> = {
        // The guard must outlive the call to `step`.
        let _guard = BtapeGuard::new(&mut tape);
        let outputs = step(&inputs);
        outputs.iter().map(|out| out.value).collect()
    };
    values
}
/// Pulls the adjoint `w` back through one application of `step` at `state`.
///
/// `step` is recorded on a fresh tape, the scalar `sum_i w[i] * output[i]`
/// is formed on that tape and set as its output, and the tape's gradient at
/// `state` is returned — i.e. the vector-Jacobian product of `w` with the
/// Jacobian of `step`.
///
/// # Panics
/// Panics if `step` returns a vector whose length differs from
/// `state.len()`, or if `w` has fewer elements than `state`.
fn vjp_step<F: Float + BtapeThreadLocal>(
    step: &impl Fn(&[BReverse<F>]) -> Vec<BReverse<F>>,
    state: &[F],
    w: &[F],
) -> Vec<F> {
    let dim = state.len();
    let mut tape = BytecodeTape::with_capacity(dim * 10);

    // Register every state element as a tape input, in order.
    let mut inputs: Vec<BReverse<F>> = Vec::with_capacity(dim);
    for &value in state {
        let slot = tape.new_input(value);
        inputs.push(BReverse::from_tape(value, slot));
    }

    let scalar_index = {
        // The guard must stay alive while operations are recorded.
        let _guard = BtapeGuard::new(&mut tape);
        let outputs = step(&inputs);
        assert_eq!(
            outputs.len(),
            dim,
            "step must preserve dimension: expected {}, got {}",
            dim,
            outputs.len()
        );

        // Weighted sum of the outputs; its gradient is the pulled-back adjoint.
        let mut weighted_sum = BReverse::constant(F::zero());
        for (i, &out) in outputs.iter().enumerate() {
            weighted_sum += BReverse::constant(w[i]) * out;
        }
        weighted_sum.index
    };

    tape.set_output(scalar_index);
    tape.gradient(state)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that a schedule is non-empty and that every position lies
    /// strictly between 0 and `num_steps`.
    fn assert_interior_schedule(positions: &[usize], num_steps: usize) {
        assert!(!positions.is_empty());
        for &p in positions {
            assert!(p > 0 && p < num_steps);
        }
    }

    /// Pins the base cases and a few small binomial values of `beta`.
    #[test]
    fn beta_base_cases() {
        // (s, c, expected)
        let cases = [
            (0, 0, 1),
            (1, 0, 2),
            (5, 0, 6),
            (0, 1, 1),
            (0, 5, 1),
            (1, 1, 2),
            (2, 2, 6),
            (3, 2, 10),
        ];
        for &(s, c, expected) in cases.iter() {
            assert_eq!(beta(s, c), expected);
        }
    }

    /// With as many slots as steps, every interior position is stored.
    #[test]
    fn revolve_schedule_store_all() {
        assert_eq!(revolve_schedule(5, 5), vec![1, 2, 3, 4]);
    }

    /// A single slot still yields interior positions only.
    #[test]
    fn revolve_schedule_one_checkpoint() {
        let positions = revolve_schedule(4, 1);
        assert_interior_schedule(&positions, 4);
    }

    /// Two slots over ten steps yield interior positions only.
    #[test]
    fn revolve_schedule_two_checkpoints() {
        let positions = revolve_schedule(10, 2);
        assert_interior_schedule(&positions, 10);
    }
}