use thiserror::Error;
/// Error values produced by the checkpointing types in this file.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub enum CheckpointError {
/// A requested allocation did not fit: `need` is the requested byte count,
/// `available` is the number of bytes that were still free.
#[error("memory budget exceeded: need {need}, available {available}")]
BudgetExceeded { need: usize, available: usize },
/// Reports an empty segment list.
// NOTE(review): no code visible in this file constructs this variant — confirm it is still needed.
#[error("empty segment list")]
EmptySegments,
/// An input slice had `got` elements where `expected` elements were required.
#[error("dimension mismatch: input has {got} elements, expected {expected}")]
DimMismatch { expected: usize, got: usize },
/// A forward pass was requested on a pipeline that contains no segments.
#[error("empty pipeline")]
EmptyPipeline,
}
/// An operation whose output can be regenerated from its input on demand.
///
/// Implementors must be `Send + Sync`, and so must their input and output
/// types, which additionally have to be `Clone`.
pub trait Recomputable: Send + Sync {
/// Value the operation consumes.
type Input: Clone + Send + Sync;
/// Value the operation produces.
type Output: Clone + Send + Sync;
/// Computes the output for `input` without taking ownership of it.
fn forward(&self, input: &Self::Input) -> Self::Output;
/// Returns the number of bytes attributed to holding `input` in memory.
///
/// This is an associated function (no `self`): the figure depends only on
/// the input value, not on a particular implementor instance.
fn input_memory_bytes(input: &Self::Input) -> usize;
}
/// Pairs a [`Recomputable`] operation with the input it should be replayed on,
/// so the operation's output can be regenerated instead of being kept around.
pub struct Checkpoint<R: Recomputable> {
    /// Operation that is replayed by [`Checkpoint::recompute`].
    recomputable: R,
    /// Input retained for the replay.
    saved_input: R::Input,
}

impl<R: Recomputable> Checkpoint<R> {
    /// Builds a checkpoint that takes ownership of both the operation and its input.
    pub fn new(recomputable: R, input: R::Input) -> Self {
        Checkpoint {
            saved_input: input,
            recomputable,
        }
    }

    /// Runs the stored operation on the stored input and returns the fresh output.
    ///
    /// Nothing is cached: every call performs a full forward evaluation.
    pub fn recompute(&self) -> R::Output {
        let Self {
            recomputable,
            saved_input,
        } = self;
        recomputable.forward(saved_input)
    }

    /// Returns the byte count the operation type reports for the stored input.
    pub fn memory_bytes(&self) -> usize {
        <R as Recomputable>::input_memory_bytes(&self.saved_input)
    }
}
/// Dense linear layer without bias, described by a flat weight buffer and its
/// input/output dimensions.
#[derive(Clone)]
pub struct LinearSegment {
    /// Flat weight buffer; the constructors produce `in_dim * out_dim` entries.
    pub weights: Vec<f32>,
    /// Number of input elements.
    pub in_dim: usize,
    /// Number of output elements.
    pub out_dim: usize,
}

impl LinearSegment {
    /// Wraps an existing weight buffer.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled, panics when `weights.len()`
    /// differs from `in_dim * out_dim`. Builds without debug assertions perform
    /// no length check.
    pub fn new(weights: Vec<f32>, in_dim: usize, out_dim: usize) -> Self {
        debug_assert_eq!(
            weights.len(),
            in_dim * out_dim,
            "weights.len() must equal in_dim * out_dim"
        );
        LinearSegment {
            weights,
            in_dim,
            out_dim,
        }
    }

    /// Fills a layer with deterministic pseudo-random weights derived from `seed`.
    ///
    /// A 64-bit linear congruential generator is advanced once per weight. The
    /// upper 32 bits of its state are scaled into `[0, 1]` and then mapped into
    /// `[-limit, limit]`, where `limit = sqrt(6 / (in_dim + out_dim))`.
    pub fn random_init(in_dim: usize, out_dim: usize, seed: u64) -> Self {
        const LCG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;
        const LCG_INCREMENT: u64 = 1_442_695_040_888_963_407;

        let total = in_dim * out_dim;
        let bound = (6.0_f64 / (in_dim + out_dim) as f64).sqrt() as f32;

        let mut lcg = seed;
        let mut weights = Vec::with_capacity(total);
        for _ in 0..total {
            lcg = lcg.wrapping_mul(LCG_MULTIPLIER).wrapping_add(LCG_INCREMENT);
            // Only the high half of the state is used for the sample.
            let unit = (lcg >> 32) as f32 / u32::MAX as f32;
            weights.push(unit * 2.0 * bound - bound);
        }

        LinearSegment {
            weights,
            in_dim,
            out_dim,
        }
    }
}
impl Recomputable for LinearSegment {
    type Input = Vec<f32>;
    type Output = Vec<f32>;

    /// Multiplies the weight matrix by `input` and returns `out_dim` values.
    ///
    /// Weights are read row by row: output `i` is the dot product of
    /// `weights[i * in_dim .. (i + 1) * in_dim]` with `input`. When `input` is
    /// shorter than `in_dim`, only the leading pairs contribute; surplus input
    /// elements are ignored.
    ///
    /// # Panics
    ///
    /// Panics if the weight buffer is too short to supply a full row.
    fn forward(&self, input: &Vec<f32>) -> Vec<f32> {
        (0..self.out_dim)
            .map(|row_idx| {
                let start = row_idx * self.in_dim;
                let weights_row = &self.weights[start..start + self.in_dim];
                // Explicit fold keeps the left-to-right accumulation starting at +0.0.
                weights_row
                    .iter()
                    .zip(input.iter())
                    .fold(0.0f32, |dot, (w, x)| dot + w * x)
            })
            .collect()
    }

    /// Returns the size in bytes of the `f32` elements held by `input`.
    fn input_memory_bytes(input: &Vec<f32>) -> usize {
        std::mem::size_of::<f32>() * input.len()
    }
}
/// Ordered chain of checkpoints whose operations map `Vec<f32>` to `Vec<f32>`.
pub struct CheckpointedNetwork<R: Recomputable<Input = Vec<f32>, Output = Vec<f32>>> {
    /// Checkpoints in execution order.
    segments: Vec<Checkpoint<R>>,
}

impl<R: Recomputable<Input = Vec<f32>, Output = Vec<f32>>> CheckpointedNetwork<R> {
    /// Builds a network from checkpoints listed in execution order.
    pub fn new(segments: Vec<Checkpoint<R>>) -> Self {
        Self { segments }
    }

    /// Runs `input` through every segment's operation in order and returns the
    /// final output.
    ///
    /// The output of each segment becomes the input of the next one. The
    /// inputs saved inside the checkpoints are not consulted here; they only
    /// drive [`Checkpoint::recompute`] and the memory accounting below.
    ///
    /// Returns an empty vector when the network has no segments.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        if self.segments.is_empty() {
            return Vec::new();
        }
        self.segments
            .iter()
            .fold(input.to_vec(), |activation, checkpoint| {
                checkpoint.recomputable.forward(&activation)
            })
    }

    /// Total bytes reported by the checkpoints for their saved inputs.
    pub fn memory_bytes(&self) -> usize {
        self.segments.iter().map(|s| s.memory_bytes()).sum()
    }

    /// Total bytes of every segment's output, i.e. what storing all outputs
    /// would cost.
    ///
    /// Each segment is recomputed from its saved input to measure its output
    /// length, so this call performs one forward evaluation per segment.
    pub fn full_memory_bytes(&self) -> usize {
        self.segments
            .iter()
            .map(|s| s.recompute().len() * std::mem::size_of::<f32>())
            .sum()
    }

    /// Fraction of [`Self::full_memory_bytes`] saved by keeping only the
    /// checkpoint inputs, clamped to be non-negative.
    ///
    /// Returns `0.0` when the full figure is zero.
    pub fn memory_savings(&self) -> f32 {
        let full = self.full_memory_bytes() as f32;
        if full <= 0.0 {
            return 0.0;
        }
        let ckpt = self.memory_bytes() as f32;
        ((full - ckpt) / full).max(0.0)
    }
}
/// Tracks how many bytes of a fixed budget are currently in use.
#[derive(Debug, Clone)]
pub struct CheckpointBudget {
    /// Upper bound on `used_bytes` enforced by [`CheckpointBudget::allocate`].
    pub max_bytes: usize,
    /// Bytes currently accounted as in use.
    pub used_bytes: usize,
}

impl CheckpointBudget {
    /// Creates an empty budget with the given capacity.
    pub fn new(max_bytes: usize) -> Self {
        Self {
            max_bytes,
            used_bytes: 0,
        }
    }

    /// Bytes still available; `0` if `used_bytes` already exceeds `max_bytes`
    /// (possible because both fields are public).
    pub fn remaining(&self) -> usize {
        self.max_bytes.saturating_sub(self.used_bytes)
    }

    /// Ratio `used_bytes / max_bytes`, or `0.0` for a zero-capacity budget.
    pub fn utilization(&self) -> f32 {
        if self.max_bytes == 0 {
            return 0.0;
        }
        self.used_bytes as f32 / self.max_bytes as f32
    }

    /// Returns `true` when `bytes` more can be reserved without exceeding the
    /// budget.
    ///
    /// The sum is computed with `checked_add`: a request whose total would not
    /// fit in `usize` is rejected rather than being clamped to `usize::MAX`,
    /// which would wrongly pass when `max_bytes == usize::MAX`.
    pub fn can_allocate(&self, bytes: usize) -> bool {
        matches!(
            self.used_bytes.checked_add(bytes),
            Some(total) if total <= self.max_bytes
        )
    }

    /// Reserves `bytes` from the budget.
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointError::BudgetExceeded`] and leaves the budget
    /// untouched when the request does not fit (including when the new total
    /// would overflow `usize`).
    pub fn allocate(&mut self, bytes: usize) -> Result<(), CheckpointError> {
        match self.used_bytes.checked_add(bytes) {
            Some(total) if total <= self.max_bytes => {
                self.used_bytes = total;
                Ok(())
            }
            _ => Err(CheckpointError::BudgetExceeded {
                need: bytes,
                available: self.remaining(),
            }),
        }
    }

    /// Releases `bytes`, never dropping `used_bytes` below zero.
    pub fn free(&mut self, bytes: usize) {
        self.used_bytes = self.used_bytes.saturating_sub(bytes);
    }
}
/// Named dense linear layer (no bias) used as one stage of a pipeline.
pub struct CheckpointSegment {
    /// Label for the segment.
    pub name: String,
    /// Flat weight buffer read row by row: `out_dim` rows of `in_dim` values.
    pub weights: Vec<f32>,
    /// Required input length.
    pub in_dim: usize,
    /// Produced output length.
    pub out_dim: usize,
}

impl CheckpointSegment {
    /// Wraps the given weights and dimensions; no validation is performed.
    pub fn new(name: impl Into<String>, weights: Vec<f32>, in_dim: usize, out_dim: usize) -> Self {
        let name = name.into();
        CheckpointSegment {
            name,
            weights,
            in_dim,
            out_dim,
        }
    }

    /// Creates a segment with `in_dim * out_dim` deterministic weights derived
    /// from `seed`.
    ///
    /// A 64-bit linear congruential generator is advanced once per weight; the
    /// top 31 bits of the state are divided by `2^31`, so every weight is
    /// non-negative and at most `1.0`.
    pub fn init_lcg(name: impl Into<String>, in_dim: usize, out_dim: usize, seed: u64) -> Self {
        const LCG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;
        const LCG_INCREMENT: u64 = 1_442_695_040_888_963_407;

        let mut lcg = seed;
        let weights: Vec<f32> = (0..in_dim * out_dim)
            .map(|_| {
                lcg = lcg.wrapping_mul(LCG_MULTIPLIER).wrapping_add(LCG_INCREMENT);
                let top_bits = (lcg >> 33) as i32;
                top_bits as f32 / (1u64 << 31) as f32
            })
            .collect();

        CheckpointSegment {
            name: name.into(),
            weights,
            in_dim,
            out_dim,
        }
    }

    /// Multiplies the weight matrix by `input`.
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointError::DimMismatch`] when `input.len()` differs from
    /// `in_dim`.
    ///
    /// # Panics
    ///
    /// Panics if `weights` holds fewer than `in_dim * out_dim` values.
    pub fn forward(&self, input: &[f32]) -> Result<Vec<f32>, CheckpointError> {
        let got = input.len();
        if got != self.in_dim {
            return Err(CheckpointError::DimMismatch {
                expected: self.in_dim,
                got,
            });
        }

        let output = (0..self.out_dim)
            .map(|row| {
                let start = row * self.in_dim;
                let row_weights = &self.weights[start..start + self.in_dim];
                row_weights
                    .iter()
                    .zip(input)
                    .fold(0.0f32, |acc, (w, x)| acc + w * x)
            })
            .collect();
        Ok(output)
    }

    /// Bytes needed to hold this segment's output (`out_dim` `f32` values).
    pub fn activation_memory(&self) -> usize {
        std::mem::size_of::<f32>() * self.out_dim
    }
}
/// A segment together with the input it was applied to, so the segment's
/// output can be regenerated instead of stored.
pub struct CheckpointedActivation {
    /// Layer that is re-run on demand.
    segment: CheckpointSegment,
    /// Input kept for the replay.
    saved_input: Vec<f32>,
}

impl CheckpointedActivation {
    /// Stores `segment` and the `input` it should later be replayed on.
    pub fn new(segment: CheckpointSegment, input: Vec<f32>) -> Self {
        CheckpointedActivation {
            saved_input: input,
            segment,
        }
    }

    /// Re-runs the segment on the saved input.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the segment's `forward` reports.
    pub fn recompute(&self) -> Result<Vec<f32>, CheckpointError> {
        let replay_input: &[f32] = self.saved_input.as_slice();
        self.segment.forward(replay_input)
    }

    /// Bytes occupied by the saved input's `f32` elements.
    pub fn memory_bytes(&self) -> usize {
        std::mem::size_of::<f32>() * self.saved_input.len()
    }

    /// Bytes needed if the segment's output were stored alongside the input.
    pub fn full_memory_bytes(&self) -> usize {
        let input_bytes = self.memory_bytes();
        let output_bytes = self.segment.activation_memory();
        input_bytes + output_bytes
    }

    /// Fraction of [`Self::full_memory_bytes`] avoided by storing only the
    /// input; `0.0` when the full figure is zero.
    pub fn memory_savings(&self) -> f32 {
        match self.full_memory_bytes() {
            0 => 0.0,
            full => {
                let retained = self.memory_bytes() as f32 / full as f32;
                1.0 - retained
            }
        }
    }
}
/// Sequence of segments applied one after another.
pub struct CheckpointedPipeline {
    /// Segments in execution order.
    segments: Vec<CheckpointSegment>,
}

impl CheckpointedPipeline {
    /// Builds a pipeline from segments listed in execution order.
    pub fn new(segments: Vec<CheckpointSegment>) -> Self {
        Self { segments }
    }

    /// Feeds `input` through every segment, passing each output to the next
    /// segment, and returns the final output.
    ///
    /// # Errors
    ///
    /// * [`CheckpointError::EmptyPipeline`] when there are no segments.
    /// * Any error returned by a segment's `forward`; evaluation stops at the
    ///   first failing segment.
    pub fn forward(&self, input: &[f32]) -> Result<Vec<f32>, CheckpointError> {
        if self.segments.is_empty() {
            return Err(CheckpointError::EmptyPipeline);
        }
        self.segments
            .iter()
            .try_fold(input.to_vec(), |activation, seg| seg.forward(&activation))
    }

    /// Number of segments in the pipeline.
    pub fn num_segments(&self) -> usize {
        self.segments.len()
    }

    /// Bytes needed to keep one checkpoint per segment input: the pipeline
    /// input (`input_size` elements) plus the output of every segment except
    /// the last. Returns `0` for an empty pipeline.
    pub fn total_checkpoint_memory(&self, input_size: usize) -> usize {
        // The last segment's output feeds no further segment, so it is excluded.
        let Some((_last, feeding)) = self.segments.split_last() else {
            return 0;
        };
        let f32_size = std::mem::size_of::<f32>();
        let intermediate: usize = feeding.iter().map(|s| s.out_dim * f32_size).sum();
        input_size * f32_size + intermediate
    }

    /// Bytes needed to keep both the input and the output of every segment,
    /// computed from the declared `in_dim` and `out_dim` values.
    pub fn total_full_memory(&self) -> usize {
        let f32_size = std::mem::size_of::<f32>();
        self.segments
            .iter()
            .map(|s| (s.in_dim + s.out_dim) * f32_size)
            .sum()
    }

    /// Fraction of [`Self::total_full_memory`] saved by checkpointing, given a
    /// pipeline input of `input_size` elements.
    ///
    /// Returns `0.0` when the full figure is zero. The value is not clamped:
    /// it is negative when the checkpoint figure exceeds the full figure.
    pub fn overall_savings(&self, input_size: usize) -> f32 {
        let full = self.total_full_memory();
        if full == 0 {
            return 0.0;
        }
        let ckpt = self.total_checkpoint_memory(input_size);
        1.0 - (ckpt as f32 / full as f32)
    }
}
/// Policy for choosing which layer indices receive a checkpoint.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CheckpointStrategy {
    /// Checkpoint every layer.
    Every,
    /// Checkpoint layers whose index is a multiple of the given step; a step
    /// of `0` is treated as `1`.
    EveryNth(usize),
    /// Checkpoint about `sqrt(total_layers)` evenly strided layers.
    Sqrt,
    /// Checkpoint nothing.
    None,
}

impl CheckpointStrategy {
    /// Returns the ascending list of layer indices (all `< total_layers`) that
    /// this strategy selects.
    ///
    /// For [`CheckpointStrategy::Sqrt`] the number of selected layers is the
    /// integer square root of `total_layers` (at least one for a non-empty
    /// network), spaced `total_layers / count` apart starting at index `0`.
    pub fn select_layers(&self, total_layers: usize) -> Vec<usize> {
        match *self {
            Self::None => Vec::new(),
            Self::Every => (0..total_layers).collect(),
            Self::EveryNth(n) => (0..total_layers).step_by(n.max(1)).collect(),
            Self::Sqrt => {
                let wanted = isqrt(total_layers).max(1);
                if wanted >= total_layers {
                    // Covers the empty network and networks too small to thin out.
                    return (0..total_layers).collect();
                }
                // `wanted < total_layers` here, so the stride is at least 1.
                let stride = total_layers / wanted;
                (0..total_layers).step_by(stride).take(wanted).collect()
            }
        }
    }
}
/// Integer square root: the largest `r` with `r * r <= n`.
///
/// Uses Newton's iteration on integers, starting from `n` and stopping as soon
/// as the next estimate is no longer smaller than the current one.
fn isqrt(n: usize) -> usize {
    if n < 2 {
        return n;
    }
    let mut estimate = n;
    // First Newton step from `n`, written as ceil(n / 2) so `n + 1` is never formed.
    let mut candidate = n.div_ceil(2);
    loop {
        if candidate >= estimate {
            break estimate;
        }
        estimate = candidate;
        candidate = (estimate + n / estimate) / 2;
    }
}
// Unit tests for the linear segment, checkpoint, checkpointed network and
// budget types defined above.
#[cfg(test)]
mod tests {
use super::*;
// A 4 -> 8 layer must return exactly `out_dim` values.
#[test]
fn linear_segment_forward_shape() {
let seg = LinearSegment::random_init(4, 8, 42);
let input = vec![1.0f32; 4];
let out = seg.forward(&input);
assert_eq!(out.len(), 8, "output should have out_dim elements");
}
// Two forward passes over the same input must agree exactly.
#[test]
fn linear_segment_forward_deterministic() {
let seg = LinearSegment::random_init(4, 8, 99);
let input = vec![0.5f32, -0.5, 1.0, -1.0];
let out1 = seg.forward(&input);
let out2 = seg.forward(&input);
assert_eq!(out1, out2, "forward must be deterministic");
}
// Recomputing from a checkpoint must reproduce the direct forward result.
#[test]
fn checkpoint_recompute_equals_forward() {
let seg = LinearSegment::random_init(3, 6, 7);
let input = vec![1.0f32, 2.0, 3.0];
let expected = seg.forward(&input);
let ckpt = Checkpoint::new(seg, input);
let got = ckpt.recompute();
assert_eq!(got, expected, "recompute must equal original forward");
}
// A checkpoint accounts only for its saved input: 5 f32 values = 20 bytes.
#[test]
fn checkpoint_memory_input_only() {
let seg = LinearSegment::random_init(5, 10, 0);
let input = vec![0.0f32; 5];
let ckpt = Checkpoint::new(seg, input);
assert_eq!(ckpt.memory_bytes(), 5 * 4, "checkpoint stores input only");
}
// A 4 -> 8 -> 4 network, whose second checkpoint stores the first segment's
// output, must yield a 4-element result.
#[test]
fn network_forward_runs() {
let seg1 = LinearSegment::random_init(4, 8, 1);
let seg2 = LinearSegment::random_init(8, 4, 2);
let input1 = vec![1.0f32; 4];
let mid = seg1.forward(&input1);
let input2 = mid.clone();
let c1 = Checkpoint::new(seg1, input1);
let c2 = Checkpoint::new(seg2, input2);
let net = CheckpointedNetwork::new(vec![c1, c2]);
let out = net.forward(&[1.0f32; 4]);
assert_eq!(
out.len(),
4,
"output should not panic and have correct length"
);
}
// With widening layers (4 -> 16 -> 64) the saved inputs are smaller than the
// outputs, so the reported savings must be strictly positive.
#[test]
fn network_memory_savings_positive() {
let seg1 = LinearSegment::random_init(4, 16, 10);
let seg2 = LinearSegment::random_init(16, 64, 11);
let input1 = vec![1.0f32; 4];
let mid = seg1.forward(&input1);
let c1 = Checkpoint::new(seg1, input1);
let c2 = Checkpoint::new(seg2, mid);
let net = CheckpointedNetwork::new(vec![c1, c2]);
let savings = net.memory_savings();
assert!(
savings > 0.0,
"expanding network should save memory, got {savings}"
);
}
// Same widening network: the full-storage figure must exceed the
// checkpointed figure.
#[test]
fn network_full_memory_greater() {
let seg1 = LinearSegment::random_init(4, 16, 20);
let seg2 = LinearSegment::random_init(16, 64, 21);
let input1 = vec![0.5f32; 4];
let mid = seg1.forward(&input1);
let c1 = Checkpoint::new(seg1, input1);
let c2 = Checkpoint::new(seg2, mid);
let net = CheckpointedNetwork::new(vec![c1, c2]);
assert!(
net.full_memory_bytes() > net.memory_bytes(),
"full storage must use more memory than checkpointed storage"
);
}
// A new budget starts with nothing used and the requested capacity.
#[test]
fn budget_new() {
let b = CheckpointBudget::new(1024);
assert_eq!(b.used_bytes, 0, "fresh budget should have used_bytes = 0");
assert_eq!(b.max_bytes, 1024);
}
// An allocation that fits succeeds and is added to `used_bytes`.
#[test]
fn budget_allocate_within() {
let mut b = CheckpointBudget::new(1024);
let result = b.allocate(256);
assert!(result.is_ok(), "allocation within budget must succeed");
assert_eq!(b.used_bytes, 256);
}
// An oversized allocation fails with `BudgetExceeded` and leaves the budget
// unchanged.
#[test]
fn budget_allocate_exceed() {
let mut b = CheckpointBudget::new(100);
let result = b.allocate(200);
assert!(
matches!(result, Err(CheckpointError::BudgetExceeded { .. })),
"allocation exceeding budget must return BudgetExceeded"
);
assert_eq!(
b.used_bytes, 0,
"failed allocation must not change used_bytes"
);
}
// Freeing part of an allocation reduces `used_bytes` by that amount.
#[test]
fn budget_free() {
let mut b = CheckpointBudget::new(1024);
b.allocate(512).expect("allocation should succeed");
b.free(256);
assert_eq!(b.used_bytes, 256);
}
// 250 of 1000 bytes in use must report a utilization of 0.25.
#[test]
fn budget_utilization() {
let mut b = CheckpointBudget::new(1000);
b.allocate(250).expect("allocation should succeed");
let util = b.utilization();
assert!(
(util - 0.25).abs() < 1e-6,
"utilization should be 0.25, got {util}"
);
}
// A network with a single 3 -> 3 segment must still produce a 3-element output.
#[test]
fn network_single_segment() {
let seg = LinearSegment::random_init(3, 3, 55);
let input = vec![1.0f32, 0.0, -1.0];
let c = Checkpoint::new(seg, input);
let net = CheckpointedNetwork::new(vec![c]);
let out = net.forward(&[1.0f32, 0.0, -1.0]);
assert_eq!(out.len(), 3, "single-segment network should produce output");
}
}