Skip to main content

oxibonsai_model/
gradient_checkpoint.rs

//! Gradient checkpointing: trade compute for memory in training.
//!
//! Instead of storing every intermediate activation, checkpointing stores
//! only "checkpoint" tensors at layer boundaries and recomputes the
//! intermediate values during the backward pass.
//!
//! # Memory Trade-off
//!
//! For a network with N segments, each producing an activation of size A:
//! - Without checkpointing: stores N activations → N * A bytes.
//! - With checkpointing: stores the N segment inputs instead. When input and
//!   output dimensions match this costs the same, but for networks whose
//!   outputs are larger than their inputs the savings can be significant.
//!
//! Savings fraction = 1 - (sum of input sizes) / (sum of output sizes).
15
16use thiserror::Error;
17
18// ─── Error types ─────────────────────────────────────────────────────────────
19
20/// Errors that can arise during checkpointed computation.
21#[derive(Debug, Error)]
22pub enum CheckpointError {
23    /// The memory budget was exceeded when trying to allocate.
24    #[error("memory budget exceeded: need {need}, available {available}")]
25    BudgetExceeded { need: usize, available: usize },
26    /// An empty segment list was provided where at least one is required.
27    #[error("empty segment list")]
28    EmptySegments,
29    /// The input vector length does not match the expected dimension.
30    #[error("dimension mismatch: input has {got} elements, expected {expected}")]
31    DimMismatch { expected: usize, got: usize },
32    /// The pipeline has no segments.
33    #[error("empty pipeline")]
34    EmptyPipeline,
35}
36
37// ─── Recomputable trait ──────────────────────────────────────────────────────
38
/// A segment of computation whose output can be regenerated on demand.
///
/// An implementor models a pure function from [`Recomputable::Input`] to
/// [`Recomputable::Output`].  Because checkpointing keeps only inputs, the
/// function may be invoked repeatedly: during the main forward pass and again
/// whenever the backward pass needs the activation back.
///
/// # Thread safety
///
/// `Self`, `Input`, and `Output` are all required to be `Send + Sync`, so a
/// checkpointed network can be shared across threads (e.g., in data-parallel
/// training).
pub trait Recomputable: Send + Sync {
    /// Value consumed by [`forward`](Recomputable::forward).
    type Input: Clone + Send + Sync;
    /// Value produced by [`forward`](Recomputable::forward).
    type Output: Clone + Send + Sync;

    /// Run the forward computation of this segment on `input`.
    ///
    /// Expect this to be called more than once for the same input: once in
    /// the main forward pass and again when the output has to be rebuilt for
    /// gradient computation.  Implementations must therefore be
    /// **deterministic** — identical inputs must always produce identical
    /// outputs.
    fn forward(&self, input: &Self::Input) -> Self::Output;

    /// Estimate, in bytes, how much memory one `input` value occupies.
    ///
    /// [`Checkpoint::memory_bytes`] reports this figure for its saved input.
    fn input_memory_bytes(input: &Self::Input) -> usize;
}
69
70// ─── Checkpoint ──────────────────────────────────────────────────────────────
71
72/// A single checkpointed computation segment.
73///
74/// Stores the segment implementation and its saved input so that the output can
75/// be recomputed at any time.  The output itself is **not** stored; calling
76/// [`recompute`] always re-runs the forward pass.
77///
78/// [`recompute`]: Checkpoint::recompute
79pub struct Checkpoint<R: Recomputable> {
80    recomputable: R,
81    saved_input: R::Input,
82}
83
84impl<R: Recomputable> Checkpoint<R> {
85    /// Create a new checkpoint, saving `input` for later recomputation.
86    ///
87    /// The `recomputable` segment is stored alongside the input so that
88    /// [`recompute`] can call `recomputable.forward(&saved_input)`.
89    ///
90    /// [`recompute`]: Self::recompute
91    pub fn new(recomputable: R, input: R::Input) -> Self {
92        Self {
93            recomputable,
94            saved_input: input,
95        }
96    }
97
98    /// Recompute and return the output from the saved input.
99    ///
100    /// This is the key operation: instead of loading a cached output tensor,
101    /// we replay the forward pass from the checkpointed input.
102    pub fn recompute(&self) -> R::Output {
103        self.recomputable.forward(&self.saved_input)
104    }
105
106    /// Bytes consumed by the saved checkpoint (input only, not the output).
107    pub fn memory_bytes(&self) -> usize {
108        R::input_memory_bytes(&self.saved_input)
109    }
110}
111
112// ─── LinearSegment ───────────────────────────────────────────────────────────
113
/// A bias-free dense layer over flat `f32` vectors.
///
/// Represents `y = W x`, where `W` has shape `[out_dim × in_dim]` and is kept
/// in row-major order.
///
/// The type exists mainly so gradient-checkpointing mechanics can be
/// exercised without depending on a tensor library.
#[derive(Clone)]
pub struct LinearSegment {
    /// Row-major weights: element `W[i, j]` lives at `weights[i * in_dim + j]`.
    pub weights: Vec<f32>,
    /// Length of the input vector.
    pub in_dim: usize,
    /// Length of the output vector.
    pub out_dim: usize,
}

impl LinearSegment {
    /// Build a segment from caller-supplied weights.
    ///
    /// # Panics (debug only)
    ///
    /// In debug builds, panics when `weights.len() != in_dim * out_dim`.
    /// Release builds perform no check.
    pub fn new(weights: Vec<f32>, in_dim: usize, out_dim: usize) -> Self {
        debug_assert_eq!(
            weights.len(),
            in_dim * out_dim,
            "weights.len() must equal in_dim * out_dim"
        );
        Self {
            weights,
            in_dim,
            out_dim,
        }
    }

    /// Build a segment whose weights come from a 64-bit linear congruential
    /// generator seeded with `seed`.
    ///
    /// The multiplier and increment are the constants commonly attributed to
    /// Knuth's MMIX generator.  Each draw is mapped onto the Xavier-style
    /// interval `[-sqrt(6/(in+out)), sqrt(6/(in+out))]`.
    ///
    /// The same `(in_dim, out_dim, seed)` always yields the same weights, and
    /// no external `rand` crate is involved.
    pub fn random_init(in_dim: usize, out_dim: usize, seed: u64) -> Self {
        const LCG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;
        const LCG_INCREMENT: u64 = 1_442_695_040_888_963_407;

        let total = in_dim * out_dim;
        let limit = (6.0_f64 / (in_dim + out_dim) as f64).sqrt() as f32;

        let mut weights = Vec::with_capacity(total);
        let mut lcg = seed;
        for _ in 0..total {
            // x_{n+1} = (a * x_n + c) mod 2^64, via wrapping arithmetic.
            lcg = lcg.wrapping_mul(LCG_MULTIPLIER).wrapping_add(LCG_INCREMENT);
            // The upper 32 bits become a value in [0, 1] ...
            let unit = (lcg >> 32) as f32 / u32::MAX as f32;
            // ... which is then stretched onto [-limit, limit].
            weights.push(unit * 2.0 * limit - limit);
        }

        Self {
            weights,
            in_dim,
            out_dim,
        }
    }
}
182
183impl Recomputable for LinearSegment {
184    type Input = Vec<f32>;
185    type Output = Vec<f32>;
186
187    /// Compute `y = W x` where `x` has length `in_dim` and `y` has length
188    /// `out_dim`.
189    fn forward(&self, input: &Vec<f32>) -> Vec<f32> {
190        let mut output = vec![0.0f32; self.out_dim];
191        // Each output neuron i: y[i] = sum_j W[i,j] * x[j]
192        for (i, out_val) in output.iter_mut().enumerate() {
193            let row_start = i * self.in_dim;
194            let row = &self.weights[row_start..row_start + self.in_dim];
195            let mut acc = 0.0f32;
196            for (w, x) in row.iter().zip(input.iter()) {
197                acc += w * x;
198            }
199            *out_val = acc;
200        }
201        output
202    }
203
204    /// Each `f32` is 4 bytes.
205    fn input_memory_bytes(input: &Vec<f32>) -> usize {
206        input.len() * 4
207    }
208}
209
210// ─── CheckpointedNetwork ─────────────────────────────────────────────────────
211
212/// A sequential network where every layer boundary is checkpointed.
213///
214/// During the forward pass, each segment's output is fed as input to the next
215/// segment, but only the inputs are retained — outputs are discarded.  On
216/// demand (e.g., during the backward pass) any segment's output can be
217/// recovered by calling `checkpoint.recompute()`.
218///
219/// This type is generic over any `Recomputable` whose `Input` and `Output` are
220/// both `Vec<f32>`, making it suitable for chain-of-linear-segments networks.
221pub struct CheckpointedNetwork<R: Recomputable<Input = Vec<f32>, Output = Vec<f32>>> {
222    segments: Vec<Checkpoint<R>>,
223}
224
225impl<R: Recomputable<Input = Vec<f32>, Output = Vec<f32>>> CheckpointedNetwork<R> {
226    /// Construct the network from pre-built checkpoints.
227    pub fn new(segments: Vec<Checkpoint<R>>) -> Self {
228        Self { segments }
229    }
230
231    /// Execute the full forward pass, returning the output of the final segment.
232    ///
233    /// Internally each segment is run in order; the checkpointed inputs are
234    /// already stored so we merely recompute each in sequence.
235    ///
236    /// Returns an error if `segments` is empty.
237    pub fn forward(&self, _input: &[f32]) -> Vec<f32> {
238        if self.segments.is_empty() {
239            return Vec::new();
240        }
241        // Each segment already has its saved input; run them in order.
242        let mut output = self.segments[0].recompute();
243        for seg in self.segments.iter().skip(1) {
244            // Re-run the segment from its checkpointed input (ignores `output`
245            // from the previous iteration since inputs are already stored).
246            output = seg.recompute();
247        }
248        output
249    }
250
251    /// Total bytes used by all checkpointed inputs.
252    pub fn memory_bytes(&self) -> usize {
253        self.segments.iter().map(|s| s.memory_bytes()).sum()
254    }
255
256    /// Hypothetical memory if we stored every segment's **output** instead of
257    /// its input.
258    ///
259    /// For a typical expanding network (out_dim > in_dim) this will be larger
260    /// than [`memory_bytes`], demonstrating the checkpointing advantage.
261    ///
262    /// [`memory_bytes`]: Self::memory_bytes
263    pub fn full_memory_bytes(&self) -> usize {
264        self.segments
265            .iter()
266            .map(|s| {
267                // Recompute output and measure its size.
268                let out = s.recompute();
269                out.len() * 4
270            })
271            .sum()
272    }
273
274    /// Fraction of memory saved relative to storing all outputs.
275    ///
276    /// Returns a value in `[0, 1)`.  A result of `0.0` means no savings (input
277    /// == output size everywhere); values near `1.0` mean the full-storage cost
278    /// would be much higher.
279    pub fn memory_savings(&self) -> f32 {
280        let full = self.full_memory_bytes() as f32;
281        if full <= 0.0 {
282            return 0.0;
283        }
284        let ckpt = self.memory_bytes() as f32;
285        ((full - ckpt) / full).max(0.0)
286    }
287}
288
289// ─── CheckpointBudget ────────────────────────────────────────────────────────
290
291/// Memory budget tracker for gradient checkpointing.
292///
293/// Maintains a running total of bytes allocated to checkpointed inputs and
294/// enforces an upper bound.  Call [`allocate`] when a new checkpoint is
295/// created and [`free`] when it is discarded (e.g., after the backward pass
296/// for that segment completes).
297///
298/// [`allocate`]: CheckpointBudget::allocate
299/// [`free`]: CheckpointBudget::free
300#[derive(Debug, Clone)]
301pub struct CheckpointBudget {
302    /// Maximum permitted allocation in bytes.
303    pub max_bytes: usize,
304    /// Currently allocated bytes.
305    pub used_bytes: usize,
306}
307
308impl CheckpointBudget {
309    /// Create a fresh budget with `max_bytes` capacity and zero usage.
310    pub fn new(max_bytes: usize) -> Self {
311        Self {
312            max_bytes,
313            used_bytes: 0,
314        }
315    }
316
317    /// Bytes still available for allocation.
318    pub fn remaining(&self) -> usize {
319        self.max_bytes.saturating_sub(self.used_bytes)
320    }
321
322    /// Fraction of the budget that has been consumed (`used / max`).
323    ///
324    /// Returns `0.0` when `max_bytes == 0` to avoid division by zero.
325    pub fn utilization(&self) -> f32 {
326        if self.max_bytes == 0 {
327            return 0.0;
328        }
329        self.used_bytes as f32 / self.max_bytes as f32
330    }
331
332    /// Whether `bytes` can be allocated without exceeding the budget.
333    pub fn can_allocate(&self, bytes: usize) -> bool {
334        self.used_bytes.saturating_add(bytes) <= self.max_bytes
335    }
336
337    /// Attempt to allocate `bytes`.
338    ///
339    /// On success, `used_bytes` increases by `bytes`.
340    /// On failure, returns [`CheckpointError::BudgetExceeded`] and leaves
341    /// `used_bytes` unchanged.
342    pub fn allocate(&mut self, bytes: usize) -> Result<(), CheckpointError> {
343        if !self.can_allocate(bytes) {
344            return Err(CheckpointError::BudgetExceeded {
345                need: bytes,
346                available: self.remaining(),
347            });
348        }
349        self.used_bytes += bytes;
350        Ok(())
351    }
352
353    /// Release `bytes` back to the budget.
354    ///
355    /// Uses saturating subtraction to avoid underflow if `bytes` exceeds
356    /// `used_bytes` (which would indicate a programming error, but should not
357    /// panic in production).
358    pub fn free(&mut self, bytes: usize) {
359        self.used_bytes = self.used_bytes.saturating_sub(bytes);
360    }
361}
362
363// ─── CheckpointSegment (concrete, non-generic) ─────────────────────────────
364
365/// A recomputable segment: stores weights and dimensions so the forward
366/// pass (matrix-vector product `y = W * x`) can be re-executed cheaply.
367///
368/// Unlike the generic [`LinearSegment`] + [`Recomputable`] approach, this is
369/// a self-contained struct that carries everything needed for recomputation.
370pub struct CheckpointSegment {
371    /// Human-readable name for this segment (e.g. `"layer_3"`).
372    pub name: String,
373    /// Row-major weight matrix of shape `[out_dim, in_dim]`.
374    pub weights: Vec<f32>,
375    /// Input dimension.
376    pub in_dim: usize,
377    /// Output dimension.
378    pub out_dim: usize,
379}
380
381impl CheckpointSegment {
382    /// Create a segment with explicitly provided weights.
383    pub fn new(name: impl Into<String>, weights: Vec<f32>, in_dim: usize, out_dim: usize) -> Self {
384        Self {
385            name: name.into(),
386            weights,
387            in_dim,
388            out_dim,
389        }
390    }
391
392    /// Create a segment with LCG-initialised weights (no `rand` crate).
393    ///
394    /// Uses the Knuth MMIX LCG constants. Weights are mapped to `[-1, 1]`.
395    pub fn init_lcg(name: impl Into<String>, in_dim: usize, out_dim: usize, seed: u64) -> Self {
396        let count = in_dim * out_dim;
397        let mut state = seed;
398        let mut weights = Vec::with_capacity(count);
399        for _ in 0..count {
400            state = state
401                .wrapping_mul(6_364_136_223_846_793_005)
402                .wrapping_add(1_442_695_040_888_963_407);
403            let bits = (state >> 33) as i32;
404            weights.push(bits as f32 / (1u64 << 31) as f32);
405        }
406        Self {
407            name: name.into(),
408            weights,
409            in_dim,
410            out_dim,
411        }
412    }
413
414    /// Forward pass: compute `y = W * x` (matrix-vector product).
415    ///
416    /// `input` must have exactly `in_dim` elements.
417    /// Returns a vector of `out_dim` elements.
418    pub fn forward(&self, input: &[f32]) -> Result<Vec<f32>, CheckpointError> {
419        if input.len() != self.in_dim {
420            return Err(CheckpointError::DimMismatch {
421                expected: self.in_dim,
422                got: input.len(),
423            });
424        }
425        let mut output = vec![0.0f32; self.out_dim];
426        for (row, out_val) in output.iter_mut().enumerate() {
427            let row_offset = row * self.in_dim;
428            let mut acc = 0.0f32;
429            for (col, inp_val) in input.iter().enumerate() {
430                acc += self.weights[row_offset + col] * inp_val;
431            }
432            *out_val = acc;
433        }
434        Ok(output)
435    }
436
437    /// Memory (in bytes) required to store one full activation (output).
438    pub fn activation_memory(&self) -> usize {
439        self.out_dim * std::mem::size_of::<f32>()
440    }
441}
442
443// ─── CheckpointedActivation ─────────────────────────────────────────────────
444
445/// A checkpointed activation: stores only the input and recomputes the
446/// output on demand via the associated [`CheckpointSegment`].
447pub struct CheckpointedActivation {
448    segment: CheckpointSegment,
449    saved_input: Vec<f32>,
450}
451
452impl CheckpointedActivation {
453    /// Create a new checkpointed activation.
454    pub fn new(segment: CheckpointSegment, input: Vec<f32>) -> Self {
455        Self {
456            segment,
457            saved_input: input,
458        }
459    }
460
461    /// Recompute the output from the saved input.
462    pub fn recompute(&self) -> Result<Vec<f32>, CheckpointError> {
463        self.segment.forward(&self.saved_input)
464    }
465
466    /// Memory actually consumed by this checkpoint (input only, in bytes).
467    pub fn memory_bytes(&self) -> usize {
468        self.saved_input.len() * std::mem::size_of::<f32>()
469    }
470
471    /// Memory that would be consumed if both input and output were stored.
472    pub fn full_memory_bytes(&self) -> usize {
473        self.memory_bytes() + self.segment.activation_memory()
474    }
475
476    /// Fraction of memory saved compared to the full (non-checkpointed) case.
477    ///
478    /// Returns a value in `[0.0, 1.0]`.
479    pub fn memory_savings(&self) -> f32 {
480        let full = self.full_memory_bytes();
481        if full == 0 {
482            return 0.0;
483        }
484        1.0 - (self.memory_bytes() as f32 / full as f32)
485    }
486}
487
488// ─── CheckpointedPipeline ───────────────────────────────────────────────────
489
490/// A sequence of checkpointed layers that can be run end-to-end.
491///
492/// Stores only the per-layer inputs (not outputs) and recomputes
493/// activations as needed.
494pub struct CheckpointedPipeline {
495    segments: Vec<CheckpointSegment>,
496}
497
498impl CheckpointedPipeline {
499    /// Build a pipeline from a list of segments.
500    pub fn new(segments: Vec<CheckpointSegment>) -> Self {
501        Self { segments }
502    }
503
504    /// Run the full forward pass through all segments.
505    pub fn forward(&self, input: &[f32]) -> Result<Vec<f32>, CheckpointError> {
506        if self.segments.is_empty() {
507            return Err(CheckpointError::EmptyPipeline);
508        }
509        let mut current = input.to_vec();
510        for seg in &self.segments {
511            current = seg.forward(&current)?;
512        }
513        Ok(current)
514    }
515
516    /// Number of segments in the pipeline.
517    pub fn num_segments(&self) -> usize {
518        self.segments.len()
519    }
520
521    /// Total checkpoint memory for a given input size.
522    ///
523    /// First layer saves the original input; subsequent layers save their
524    /// predecessor's output (= predecessor's `out_dim`).
525    pub fn total_checkpoint_memory(&self, input_size: usize) -> usize {
526        if self.segments.is_empty() {
527            return 0;
528        }
529        let f32_size = std::mem::size_of::<f32>();
530        let mut total = input_size * f32_size;
531        for i in 0..self.segments.len() - 1 {
532            total += self.segments[i].out_dim * f32_size;
533        }
534        total
535    }
536
537    /// Total memory if all activations (inputs **and** outputs) were stored.
538    pub fn total_full_memory(&self) -> usize {
539        let f32_size = std::mem::size_of::<f32>();
540        self.segments
541            .iter()
542            .map(|s| (s.in_dim + s.out_dim) * f32_size)
543            .sum()
544    }
545
546    /// Overall memory savings fraction for the whole pipeline.
547    pub fn overall_savings(&self, input_size: usize) -> f32 {
548        let full = self.total_full_memory();
549        if full == 0 {
550            return 0.0;
551        }
552        let ckpt = self.total_checkpoint_memory(input_size);
553        1.0 - (ckpt as f32 / full as f32)
554    }
555}
556
557// ─── CheckpointStrategy ────────────────────────────────────────────────────
558
559/// Strategy for selecting which layers to checkpoint.
560#[derive(Debug, Clone, Copy, PartialEq)]
561pub enum CheckpointStrategy {
562    /// Checkpoint every layer.
563    Every,
564    /// Checkpoint every N-th layer (layers 0, N, 2N, ...).
565    EveryNth(usize),
566    /// Checkpoint approximately sqrt(N) layers, evenly spaced.
567    Sqrt,
568    /// No checkpointing at all.
569    None,
570}
571
572impl CheckpointStrategy {
573    /// Given `total_layers`, return sorted indices of layers to checkpoint.
574    pub fn select_layers(&self, total_layers: usize) -> Vec<usize> {
575        match self {
576            CheckpointStrategy::Every => (0..total_layers).collect(),
577            CheckpointStrategy::EveryNth(n) => {
578                let step = if *n == 0 { 1 } else { *n };
579                (0..total_layers).filter(|i| i % step == 0).collect()
580            }
581            CheckpointStrategy::Sqrt => {
582                if total_layers == 0 {
583                    return Vec::new();
584                }
585                let count = isqrt(total_layers).max(1);
586                if count >= total_layers {
587                    return (0..total_layers).collect();
588                }
589                let step = total_layers / count;
590                let mut layers = Vec::with_capacity(count);
591                let mut idx = 0;
592                while idx < total_layers && layers.len() < count {
593                    layers.push(idx);
594                    idx += step;
595                }
596                layers
597            }
598            CheckpointStrategy::None => Vec::new(),
599        }
600    }
601}
602
/// Floor of the square root of `n`, computed with Newton's iteration on
/// integers.
///
/// The estimate starts at `n` and is refined while each new estimate is
/// strictly smaller than the previous one; the last estimate before that
/// stops being true is the result.
fn isqrt(n: usize) -> usize {
    match n {
        // 0 and 1 are their own square roots and would stall the iteration.
        0 | 1 => n,
        _ => {
            let mut estimate = n;
            let mut candidate = n.div_ceil(2);
            while candidate < estimate {
                estimate = candidate;
                candidate = (estimate + n / estimate) / 2;
            }
            estimate
        }
    }
}
616
617// ─── Tests ───────────────────────────────────────────────────────────────────
618
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-layer 4 → 16 → 64 network.  The first checkpoint stores a constant
    /// input of `fill`; the second stores the first layer's output for it.
    fn expanding_network(
        seed_a: u64,
        seed_b: u64,
        fill: f32,
    ) -> CheckpointedNetwork<LinearSegment> {
        let first = LinearSegment::random_init(4, 16, seed_a);
        let second = LinearSegment::random_init(16, 64, seed_b);
        let first_input = vec![fill; 4];
        let second_input = first.forward(&first_input);
        CheckpointedNetwork::new(vec![
            Checkpoint::new(first, first_input),
            Checkpoint::new(second, second_input),
        ])
    }

    // ── LinearSegment ────────────────────────────────────────────────────

    #[test]
    fn linear_segment_forward_shape() {
        let layer = LinearSegment::random_init(4, 8, 42);
        let result = layer.forward(&vec![1.0f32; 4]);
        assert_eq!(result.len(), 8, "output should have out_dim elements");
    }

    #[test]
    fn linear_segment_forward_deterministic() {
        let layer = LinearSegment::random_init(4, 8, 99);
        let x = vec![0.5f32, -0.5, 1.0, -1.0];
        let first_run = layer.forward(&x);
        let second_run = layer.forward(&x);
        assert_eq!(first_run, second_run, "forward must be deterministic");
    }

    // ── Checkpoint ───────────────────────────────────────────────────────

    #[test]
    fn checkpoint_recompute_equals_forward() {
        let layer = LinearSegment::random_init(3, 6, 7);
        let x = vec![1.0f32, 2.0, 3.0];
        let direct = layer.forward(&x);
        let replayed = Checkpoint::new(layer, x).recompute();
        assert_eq!(replayed, direct, "recompute must equal original forward");
    }

    #[test]
    fn checkpoint_memory_input_only() {
        let layer = LinearSegment::random_init(5, 10, 0);
        let ckpt = Checkpoint::new(layer, vec![0.0f32; 5]);
        assert_eq!(ckpt.memory_bytes(), 5 * 4, "checkpoint stores input only");
    }

    // ── CheckpointedNetwork ──────────────────────────────────────────────

    #[test]
    fn network_forward_runs() {
        let first = LinearSegment::random_init(4, 8, 1);
        let second = LinearSegment::random_init(8, 4, 2);
        let first_input = vec![1.0f32; 4];
        let second_input = first.forward(&first_input);
        let net = CheckpointedNetwork::new(vec![
            Checkpoint::new(first, first_input),
            Checkpoint::new(second, second_input),
        ]);
        let result = net.forward(&[1.0f32; 4]);
        assert_eq!(
            result.len(),
            4,
            "output should not panic and have correct length"
        );
    }

    #[test]
    fn network_memory_savings_positive() {
        // Expanding network: outputs are much larger than inputs.
        let net = expanding_network(10, 11, 1.0);
        let savings = net.memory_savings();
        assert!(
            savings > 0.0,
            "expanding network should save memory, got {savings}"
        );
    }

    #[test]
    fn network_full_memory_greater() {
        let net = expanding_network(20, 21, 0.5);
        assert!(
            net.full_memory_bytes() > net.memory_bytes(),
            "full storage must use more memory than checkpointed storage"
        );
    }

    #[test]
    fn network_single_segment() {
        let layer = LinearSegment::random_init(3, 3, 55);
        let ckpt = Checkpoint::new(layer, vec![1.0f32, 0.0, -1.0]);
        let net = CheckpointedNetwork::new(vec![ckpt]);
        let result = net.forward(&[1.0f32, 0.0, -1.0]);
        assert_eq!(
            result.len(),
            3,
            "single-segment network should produce output"
        );
    }

    // ── CheckpointBudget ─────────────────────────────────────────────────

    #[test]
    fn budget_new() {
        let budget = CheckpointBudget::new(1024);
        assert_eq!(
            budget.used_bytes, 0,
            "fresh budget should have used_bytes = 0"
        );
        assert_eq!(budget.max_bytes, 1024);
    }

    #[test]
    fn budget_allocate_within() {
        let mut budget = CheckpointBudget::new(1024);
        let outcome = budget.allocate(256);
        assert!(outcome.is_ok(), "allocation within budget must succeed");
        assert_eq!(budget.used_bytes, 256);
    }

    #[test]
    fn budget_allocate_exceed() {
        let mut budget = CheckpointBudget::new(100);
        let outcome = budget.allocate(200);
        assert!(
            matches!(outcome, Err(CheckpointError::BudgetExceeded { .. })),
            "allocation exceeding budget must return BudgetExceeded"
        );
        assert_eq!(
            budget.used_bytes, 0,
            "failed allocation must not change used_bytes"
        );
    }

    #[test]
    fn budget_free() {
        let mut budget = CheckpointBudget::new(1024);
        budget.allocate(512).expect("allocation should succeed");
        budget.free(256);
        assert_eq!(budget.used_bytes, 256);
    }

    #[test]
    fn budget_utilization() {
        let mut budget = CheckpointBudget::new(1000);
        budget.allocate(250).expect("allocation should succeed");
        let util = budget.utilization();
        assert!(
            (util - 0.25).abs() < 1e-6,
            "utilization should be 0.25, got {util}"
        );
    }
}