oxibonsai_model/gradient_checkpoint.rs
1//! Gradient checkpointing: trade compute for memory in training.
2//!
3//! Instead of storing all intermediate activations, checkpointing only
4//! stores "checkpoint" tensors at layer boundaries and recomputes
5//! intermediate values during backward pass.
6//!
7//! # Memory Trade-off
8//!
9//! For a network with N segments each producing an activation of size A:
10//! - Without checkpointing: stores N activations → N * A bytes
//! - With checkpointing: stores only the N segment inputs. When every segment's
//!   input and output have the same size this is no smaller, but for networks
//!   whose outputs are larger than their inputs the savings can be significant.
13//!
14//! The savings fraction = 1 - (sum of input sizes) / (sum of output sizes).
15
16use thiserror::Error;
17
18// ─── Error types ─────────────────────────────────────────────────────────────
19
/// Errors that can arise during checkpointed computation.
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute.
#[derive(Debug, Error)]
pub enum CheckpointError {
    /// The memory budget was exceeded when trying to allocate.
    ///
    /// `need` is the number of bytes requested and `available` is the number
    /// of bytes that were still free (see [`CheckpointBudget::allocate`]).
    #[error("memory budget exceeded: need {need}, available {available}")]
    BudgetExceeded { need: usize, available: usize },
    /// An empty segment list was provided where at least one is required.
    // NOTE(review): nothing in this file constructs this variant — confirm it
    // is used elsewhere before relying on it.
    #[error("empty segment list")]
    EmptySegments,
    /// The input vector length does not match the expected dimension.
    ///
    /// Returned by [`CheckpointSegment::forward`] when `input.len()` differs
    /// from the segment's `in_dim`.
    #[error("dimension mismatch: input has {got} elements, expected {expected}")]
    DimMismatch { expected: usize, got: usize },
    /// The pipeline has no segments.
    ///
    /// Returned by [`CheckpointedPipeline::forward`].
    #[error("empty pipeline")]
    EmptyPipeline,
}
36
37// ─── Recomputable trait ──────────────────────────────────────────────────────
38
/// A computation segment that can be re-run on demand.
///
/// Implementors represent a pure function from `Input` to `Output` that can be
/// called an arbitrary number of times — once during the forward pass, and once
/// more per segment during the backward pass to recover intermediate activations.
///
/// # Thread safety
///
/// `Self`, `Input`, and `Output` must all be `Send + Sync` so that checkpointed
/// networks can be used across threads (e.g., in data-parallel training).
/// `Input` and `Output` must additionally be `Clone`.
pub trait Recomputable: Send + Sync {
    /// The input type accepted by this segment.
    type Input: Clone + Send + Sync;
    /// The output type produced by this segment.
    type Output: Clone + Send + Sync;

    /// Compute the forward pass of this segment.
    ///
    /// This will be called at least twice: once during the main forward pass and
    /// once during the backward pass when the output is needed for gradient
    /// computation. Implementations must be **deterministic** — identical inputs
    /// must always produce identical outputs.
    fn forward(&self, input: &Self::Input) -> Self::Output;

    /// Estimate the memory footprint of one input value in bytes.
    ///
    /// This is an associated function (no `self`): the estimate depends only on
    /// the input value. It is called by [`Checkpoint::memory_bytes`] to report
    /// how much memory a saved input occupies.
    fn input_memory_bytes(input: &Self::Input) -> usize;
}
69
70// ─── Checkpoint ──────────────────────────────────────────────────────────────
71
72/// A single checkpointed computation segment.
73///
74/// Stores the segment implementation and its saved input so that the output can
75/// be recomputed at any time. The output itself is **not** stored; calling
76/// [`recompute`] always re-runs the forward pass.
77///
78/// [`recompute`]: Checkpoint::recompute
79pub struct Checkpoint<R: Recomputable> {
80 recomputable: R,
81 saved_input: R::Input,
82}
83
84impl<R: Recomputable> Checkpoint<R> {
85 /// Create a new checkpoint, saving `input` for later recomputation.
86 ///
87 /// The `recomputable` segment is stored alongside the input so that
88 /// [`recompute`] can call `recomputable.forward(&saved_input)`.
89 ///
90 /// [`recompute`]: Self::recompute
91 pub fn new(recomputable: R, input: R::Input) -> Self {
92 Self {
93 recomputable,
94 saved_input: input,
95 }
96 }
97
98 /// Recompute and return the output from the saved input.
99 ///
100 /// This is the key operation: instead of loading a cached output tensor,
101 /// we replay the forward pass from the checkpointed input.
102 pub fn recompute(&self) -> R::Output {
103 self.recomputable.forward(&self.saved_input)
104 }
105
106 /// Bytes consumed by the saved checkpoint (input only, not the output).
107 pub fn memory_bytes(&self) -> usize {
108 R::input_memory_bytes(&self.saved_input)
109 }
110}
111
112// ─── LinearSegment ───────────────────────────────────────────────────────────
113
/// A bias-free dense layer over flat `f32` vectors.
///
/// Represents `y = W x`, with `W` an `[out_dim × in_dim]` matrix laid out
/// row by row in `weights`.
///
/// It exists mainly so the checkpointing machinery can be exercised without a
/// tensor library.
#[derive(Clone)]
pub struct LinearSegment {
    /// Row-major weight matrix: `weights[i * in_dim + j]` holds `W[i, j]`.
    pub weights: Vec<f32>,
    /// Number of input features (columns of `W`).
    pub in_dim: usize,
    /// Number of output features (rows of `W`).
    pub out_dim: usize,
}

impl LinearSegment {
    /// Wrap caller-supplied weights in a `LinearSegment`.
    ///
    /// # Panics (debug only)
    ///
    /// In debug builds, panics when `weights.len() != in_dim * out_dim`.
    /// Release builds perform no check.
    pub fn new(weights: Vec<f32>, in_dim: usize, out_dim: usize) -> Self {
        debug_assert_eq!(
            weights.len(),
            in_dim * out_dim,
            "weights.len() must equal in_dim * out_dim"
        );
        Self {
            weights,
            in_dim,
            out_dim,
        }
    }

    /// Fill the weight matrix from a seeded 64-bit linear congruential
    /// generator, so no `rand` dependency is needed.
    ///
    /// Each LCG state contributes its upper 32 bits, scaled into `[0, 1]` and
    /// then stretched to the Xavier-style interval
    /// `[-sqrt(6/(in+out)), sqrt(6/(in+out))]`. The same `seed` always yields
    /// the same weights.
    pub fn random_init(in_dim: usize, out_dim: usize, seed: u64) -> Self {
        // LCG step: x_{n+1} = (a * x_n + c) mod 2^64.
        const LCG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;
        const LCG_INCREMENT: u64 = 1_442_695_040_888_963_407;

        let total = in_dim * out_dim;
        let xavier_limit = (6.0_f64 / (in_dim + out_dim) as f64).sqrt() as f32;

        let mut weights = Vec::with_capacity(total);
        let mut state = seed;
        for _ in 0..total {
            state = state
                .wrapping_mul(LCG_MULTIPLIER)
                .wrapping_add(LCG_INCREMENT);
            // Upper 32 bits → [0, 1] (the top value maps to exactly 1.0).
            let uniform = (state >> 32) as f32 / u32::MAX as f32;
            // Stretch [0, 1] onto [-limit, limit].
            weights.push(uniform * 2.0 * xavier_limit - xavier_limit);
        }

        Self {
            weights,
            in_dim,
            out_dim,
        }
    }
}
182
183impl Recomputable for LinearSegment {
184 type Input = Vec<f32>;
185 type Output = Vec<f32>;
186
187 /// Compute `y = W x` where `x` has length `in_dim` and `y` has length
188 /// `out_dim`.
189 fn forward(&self, input: &Vec<f32>) -> Vec<f32> {
190 let mut output = vec![0.0f32; self.out_dim];
191 // Each output neuron i: y[i] = sum_j W[i,j] * x[j]
192 for (i, out_val) in output.iter_mut().enumerate() {
193 let row_start = i * self.in_dim;
194 let row = &self.weights[row_start..row_start + self.in_dim];
195 let mut acc = 0.0f32;
196 for (w, x) in row.iter().zip(input.iter()) {
197 acc += w * x;
198 }
199 *out_val = acc;
200 }
201 output
202 }
203
204 /// Each `f32` is 4 bytes.
205 fn input_memory_bytes(input: &Vec<f32>) -> usize {
206 input.len() * 4
207 }
208}
209
210// ─── CheckpointedNetwork ─────────────────────────────────────────────────────
211
212/// A sequential network where every layer boundary is checkpointed.
213///
214/// During the forward pass, each segment's output is fed as input to the next
215/// segment, but only the inputs are retained — outputs are discarded. On
216/// demand (e.g., during the backward pass) any segment's output can be
217/// recovered by calling `checkpoint.recompute()`.
218///
219/// This type is generic over any `Recomputable` whose `Input` and `Output` are
220/// both `Vec<f32>`, making it suitable for chain-of-linear-segments networks.
221pub struct CheckpointedNetwork<R: Recomputable<Input = Vec<f32>, Output = Vec<f32>>> {
222 segments: Vec<Checkpoint<R>>,
223}
224
225impl<R: Recomputable<Input = Vec<f32>, Output = Vec<f32>>> CheckpointedNetwork<R> {
226 /// Construct the network from pre-built checkpoints.
227 pub fn new(segments: Vec<Checkpoint<R>>) -> Self {
228 Self { segments }
229 }
230
231 /// Execute the full forward pass, returning the output of the final segment.
232 ///
233 /// Internally each segment is run in order; the checkpointed inputs are
234 /// already stored so we merely recompute each in sequence.
235 ///
236 /// Returns an error if `segments` is empty.
237 pub fn forward(&self, _input: &[f32]) -> Vec<f32> {
238 if self.segments.is_empty() {
239 return Vec::new();
240 }
241 // Each segment already has its saved input; run them in order.
242 let mut output = self.segments[0].recompute();
243 for seg in self.segments.iter().skip(1) {
244 // Re-run the segment from its checkpointed input (ignores `output`
245 // from the previous iteration since inputs are already stored).
246 output = seg.recompute();
247 }
248 output
249 }
250
251 /// Total bytes used by all checkpointed inputs.
252 pub fn memory_bytes(&self) -> usize {
253 self.segments.iter().map(|s| s.memory_bytes()).sum()
254 }
255
256 /// Hypothetical memory if we stored every segment's **output** instead of
257 /// its input.
258 ///
259 /// For a typical expanding network (out_dim > in_dim) this will be larger
260 /// than [`memory_bytes`], demonstrating the checkpointing advantage.
261 ///
262 /// [`memory_bytes`]: Self::memory_bytes
263 pub fn full_memory_bytes(&self) -> usize {
264 self.segments
265 .iter()
266 .map(|s| {
267 // Recompute output and measure its size.
268 let out = s.recompute();
269 out.len() * 4
270 })
271 .sum()
272 }
273
274 /// Fraction of memory saved relative to storing all outputs.
275 ///
276 /// Returns a value in `[0, 1)`. A result of `0.0` means no savings (input
277 /// == output size everywhere); values near `1.0` mean the full-storage cost
278 /// would be much higher.
279 pub fn memory_savings(&self) -> f32 {
280 let full = self.full_memory_bytes() as f32;
281 if full <= 0.0 {
282 return 0.0;
283 }
284 let ckpt = self.memory_bytes() as f32;
285 ((full - ckpt) / full).max(0.0)
286 }
287}
288
289// ─── CheckpointBudget ────────────────────────────────────────────────────────
290
291/// Memory budget tracker for gradient checkpointing.
292///
293/// Maintains a running total of bytes allocated to checkpointed inputs and
294/// enforces an upper bound. Call [`allocate`] when a new checkpoint is
295/// created and [`free`] when it is discarded (e.g., after the backward pass
296/// for that segment completes).
297///
298/// [`allocate`]: CheckpointBudget::allocate
299/// [`free`]: CheckpointBudget::free
300#[derive(Debug, Clone)]
301pub struct CheckpointBudget {
302 /// Maximum permitted allocation in bytes.
303 pub max_bytes: usize,
304 /// Currently allocated bytes.
305 pub used_bytes: usize,
306}
307
308impl CheckpointBudget {
309 /// Create a fresh budget with `max_bytes` capacity and zero usage.
310 pub fn new(max_bytes: usize) -> Self {
311 Self {
312 max_bytes,
313 used_bytes: 0,
314 }
315 }
316
317 /// Bytes still available for allocation.
318 pub fn remaining(&self) -> usize {
319 self.max_bytes.saturating_sub(self.used_bytes)
320 }
321
322 /// Fraction of the budget that has been consumed (`used / max`).
323 ///
324 /// Returns `0.0` when `max_bytes == 0` to avoid division by zero.
325 pub fn utilization(&self) -> f32 {
326 if self.max_bytes == 0 {
327 return 0.0;
328 }
329 self.used_bytes as f32 / self.max_bytes as f32
330 }
331
332 /// Whether `bytes` can be allocated without exceeding the budget.
333 pub fn can_allocate(&self, bytes: usize) -> bool {
334 self.used_bytes.saturating_add(bytes) <= self.max_bytes
335 }
336
337 /// Attempt to allocate `bytes`.
338 ///
339 /// On success, `used_bytes` increases by `bytes`.
340 /// On failure, returns [`CheckpointError::BudgetExceeded`] and leaves
341 /// `used_bytes` unchanged.
342 pub fn allocate(&mut self, bytes: usize) -> Result<(), CheckpointError> {
343 if !self.can_allocate(bytes) {
344 return Err(CheckpointError::BudgetExceeded {
345 need: bytes,
346 available: self.remaining(),
347 });
348 }
349 self.used_bytes += bytes;
350 Ok(())
351 }
352
353 /// Release `bytes` back to the budget.
354 ///
355 /// Uses saturating subtraction to avoid underflow if `bytes` exceeds
356 /// `used_bytes` (which would indicate a programming error, but should not
357 /// panic in production).
358 pub fn free(&mut self, bytes: usize) {
359 self.used_bytes = self.used_bytes.saturating_sub(bytes);
360 }
361}
362
363// ─── CheckpointSegment (concrete, non-generic) ─────────────────────────────
364
365/// A recomputable segment: stores weights and dimensions so the forward
366/// pass (matrix-vector product `y = W * x`) can be re-executed cheaply.
367///
368/// Unlike the generic [`LinearSegment`] + [`Recomputable`] approach, this is
369/// a self-contained struct that carries everything needed for recomputation.
370pub struct CheckpointSegment {
371 /// Human-readable name for this segment (e.g. `"layer_3"`).
372 pub name: String,
373 /// Row-major weight matrix of shape `[out_dim, in_dim]`.
374 pub weights: Vec<f32>,
375 /// Input dimension.
376 pub in_dim: usize,
377 /// Output dimension.
378 pub out_dim: usize,
379}
380
381impl CheckpointSegment {
382 /// Create a segment with explicitly provided weights.
383 pub fn new(name: impl Into<String>, weights: Vec<f32>, in_dim: usize, out_dim: usize) -> Self {
384 Self {
385 name: name.into(),
386 weights,
387 in_dim,
388 out_dim,
389 }
390 }
391
392 /// Create a segment with LCG-initialised weights (no `rand` crate).
393 ///
394 /// Uses the Knuth MMIX LCG constants. Weights are mapped to `[-1, 1]`.
395 pub fn init_lcg(name: impl Into<String>, in_dim: usize, out_dim: usize, seed: u64) -> Self {
396 let count = in_dim * out_dim;
397 let mut state = seed;
398 let mut weights = Vec::with_capacity(count);
399 for _ in 0..count {
400 state = state
401 .wrapping_mul(6_364_136_223_846_793_005)
402 .wrapping_add(1_442_695_040_888_963_407);
403 let bits = (state >> 33) as i32;
404 weights.push(bits as f32 / (1u64 << 31) as f32);
405 }
406 Self {
407 name: name.into(),
408 weights,
409 in_dim,
410 out_dim,
411 }
412 }
413
414 /// Forward pass: compute `y = W * x` (matrix-vector product).
415 ///
416 /// `input` must have exactly `in_dim` elements.
417 /// Returns a vector of `out_dim` elements.
418 pub fn forward(&self, input: &[f32]) -> Result<Vec<f32>, CheckpointError> {
419 if input.len() != self.in_dim {
420 return Err(CheckpointError::DimMismatch {
421 expected: self.in_dim,
422 got: input.len(),
423 });
424 }
425 let mut output = vec![0.0f32; self.out_dim];
426 for (row, out_val) in output.iter_mut().enumerate() {
427 let row_offset = row * self.in_dim;
428 let mut acc = 0.0f32;
429 for (col, inp_val) in input.iter().enumerate() {
430 acc += self.weights[row_offset + col] * inp_val;
431 }
432 *out_val = acc;
433 }
434 Ok(output)
435 }
436
437 /// Memory (in bytes) required to store one full activation (output).
438 pub fn activation_memory(&self) -> usize {
439 self.out_dim * std::mem::size_of::<f32>()
440 }
441}
442
443// ─── CheckpointedActivation ─────────────────────────────────────────────────
444
445/// A checkpointed activation: stores only the input and recomputes the
446/// output on demand via the associated [`CheckpointSegment`].
447pub struct CheckpointedActivation {
448 segment: CheckpointSegment,
449 saved_input: Vec<f32>,
450}
451
452impl CheckpointedActivation {
453 /// Create a new checkpointed activation.
454 pub fn new(segment: CheckpointSegment, input: Vec<f32>) -> Self {
455 Self {
456 segment,
457 saved_input: input,
458 }
459 }
460
461 /// Recompute the output from the saved input.
462 pub fn recompute(&self) -> Result<Vec<f32>, CheckpointError> {
463 self.segment.forward(&self.saved_input)
464 }
465
466 /// Memory actually consumed by this checkpoint (input only, in bytes).
467 pub fn memory_bytes(&self) -> usize {
468 self.saved_input.len() * std::mem::size_of::<f32>()
469 }
470
471 /// Memory that would be consumed if both input and output were stored.
472 pub fn full_memory_bytes(&self) -> usize {
473 self.memory_bytes() + self.segment.activation_memory()
474 }
475
476 /// Fraction of memory saved compared to the full (non-checkpointed) case.
477 ///
478 /// Returns a value in `[0.0, 1.0]`.
479 pub fn memory_savings(&self) -> f32 {
480 let full = self.full_memory_bytes();
481 if full == 0 {
482 return 0.0;
483 }
484 1.0 - (self.memory_bytes() as f32 / full as f32)
485 }
486}
487
488// ─── CheckpointedPipeline ───────────────────────────────────────────────────
489
490/// A sequence of checkpointed layers that can be run end-to-end.
491///
492/// Stores only the per-layer inputs (not outputs) and recomputes
493/// activations as needed.
494pub struct CheckpointedPipeline {
495 segments: Vec<CheckpointSegment>,
496}
497
498impl CheckpointedPipeline {
499 /// Build a pipeline from a list of segments.
500 pub fn new(segments: Vec<CheckpointSegment>) -> Self {
501 Self { segments }
502 }
503
504 /// Run the full forward pass through all segments.
505 pub fn forward(&self, input: &[f32]) -> Result<Vec<f32>, CheckpointError> {
506 if self.segments.is_empty() {
507 return Err(CheckpointError::EmptyPipeline);
508 }
509 let mut current = input.to_vec();
510 for seg in &self.segments {
511 current = seg.forward(¤t)?;
512 }
513 Ok(current)
514 }
515
516 /// Number of segments in the pipeline.
517 pub fn num_segments(&self) -> usize {
518 self.segments.len()
519 }
520
521 /// Total checkpoint memory for a given input size.
522 ///
523 /// First layer saves the original input; subsequent layers save their
524 /// predecessor's output (= predecessor's `out_dim`).
525 pub fn total_checkpoint_memory(&self, input_size: usize) -> usize {
526 if self.segments.is_empty() {
527 return 0;
528 }
529 let f32_size = std::mem::size_of::<f32>();
530 let mut total = input_size * f32_size;
531 for i in 0..self.segments.len() - 1 {
532 total += self.segments[i].out_dim * f32_size;
533 }
534 total
535 }
536
537 /// Total memory if all activations (inputs **and** outputs) were stored.
538 pub fn total_full_memory(&self) -> usize {
539 let f32_size = std::mem::size_of::<f32>();
540 self.segments
541 .iter()
542 .map(|s| (s.in_dim + s.out_dim) * f32_size)
543 .sum()
544 }
545
546 /// Overall memory savings fraction for the whole pipeline.
547 pub fn overall_savings(&self, input_size: usize) -> f32 {
548 let full = self.total_full_memory();
549 if full == 0 {
550 return 0.0;
551 }
552 let ckpt = self.total_checkpoint_memory(input_size);
553 1.0 - (ckpt as f32 / full as f32)
554 }
555}
556
557// ─── CheckpointStrategy ────────────────────────────────────────────────────
558
559/// Strategy for selecting which layers to checkpoint.
560#[derive(Debug, Clone, Copy, PartialEq)]
561pub enum CheckpointStrategy {
562 /// Checkpoint every layer.
563 Every,
564 /// Checkpoint every N-th layer (layers 0, N, 2N, ...).
565 EveryNth(usize),
566 /// Checkpoint approximately sqrt(N) layers, evenly spaced.
567 Sqrt,
568 /// No checkpointing at all.
569 None,
570}
571
572impl CheckpointStrategy {
573 /// Given `total_layers`, return sorted indices of layers to checkpoint.
574 pub fn select_layers(&self, total_layers: usize) -> Vec<usize> {
575 match self {
576 CheckpointStrategy::Every => (0..total_layers).collect(),
577 CheckpointStrategy::EveryNth(n) => {
578 let step = if *n == 0 { 1 } else { *n };
579 (0..total_layers).filter(|i| i % step == 0).collect()
580 }
581 CheckpointStrategy::Sqrt => {
582 if total_layers == 0 {
583 return Vec::new();
584 }
585 let count = isqrt(total_layers).max(1);
586 if count >= total_layers {
587 return (0..total_layers).collect();
588 }
589 let step = total_layers / count;
590 let mut layers = Vec::with_capacity(count);
591 let mut idx = 0;
592 while idx < total_layers && layers.len() < count {
593 layers.push(idx);
594 idx += step;
595 }
596 layers
597 }
598 CheckpointStrategy::None => Vec::new(),
599 }
600 }
601}
602
/// Floor of the square root of `n`, found with Newton's iteration on integers.
fn isqrt(n: usize) -> usize {
    // 0 and 1 are their own square roots.
    if n < 2 {
        return n;
    }
    let mut current = n;
    let mut candidate = current.div_ceil(2);
    loop {
        // The estimates decrease until they stop improving; the last
        // improving estimate is the answer.
        if candidate >= current {
            return current;
        }
        current = candidate;
        candidate = (current + n / current) / 2;
    }
}
616
617// ─── Tests ───────────────────────────────────────────────────────────────────
618
// Unit tests for `LinearSegment`, `Checkpoint`, `CheckpointedNetwork` and
// `CheckpointBudget`.
#[cfg(test)]
mod tests {
    use super::*;

    /// `forward` yields one value per output feature.
    #[test]
    fn linear_segment_forward_shape() {
        let seg = LinearSegment::random_init(4, 8, 42);
        let input = vec![1.0f32; 4];
        let out = seg.forward(&input);
        assert_eq!(out.len(), 8, "output should have out_dim elements");
    }

    /// Two `forward` calls on the same input give identical results.
    #[test]
    fn linear_segment_forward_deterministic() {
        let seg = LinearSegment::random_init(4, 8, 99);
        let input = vec![0.5f32, -0.5, 1.0, -1.0];
        let out1 = seg.forward(&input);
        let out2 = seg.forward(&input);
        assert_eq!(out1, out2, "forward must be deterministic");
    }

    /// `Checkpoint::recompute` reproduces what a direct `forward` returned.
    #[test]
    fn checkpoint_recompute_equals_forward() {
        let seg = LinearSegment::random_init(3, 6, 7);
        let input = vec![1.0f32, 2.0, 3.0];
        // Capture the reference output before the segment is moved into the checkpoint.
        let expected = seg.forward(&input);
        let ckpt = Checkpoint::new(seg, input);
        let got = ckpt.recompute();
        assert_eq!(got, expected, "recompute must equal original forward");
    }

    /// A checkpoint's memory is the saved input only (5 floats × 4 bytes).
    #[test]
    fn checkpoint_memory_input_only() {
        let seg = LinearSegment::random_init(5, 10, 0);
        let input = vec![0.0f32; 5];
        let ckpt = Checkpoint::new(seg, input);
        assert_eq!(ckpt.memory_bytes(), 5 * 4, "checkpoint stores input only");
    }

    /// A two-segment network (4→8→4) runs and returns the last segment's width.
    #[test]
    fn network_forward_runs() {
        let seg1 = LinearSegment::random_init(4, 8, 1);
        let seg2 = LinearSegment::random_init(8, 4, 2);
        let input1 = vec![1.0f32; 4];
        // The second checkpoint's saved input is the first segment's output.
        let mid = seg1.forward(&input1);
        let input2 = mid.clone();
        let c1 = Checkpoint::new(seg1, input1);
        let c2 = Checkpoint::new(seg2, input2);
        let net = CheckpointedNetwork::new(vec![c1, c2]);
        let out = net.forward(&[1.0f32; 4]);
        assert_eq!(
            out.len(),
            4,
            "output should not panic and have correct length"
        );
    }

    /// An expanding network reports a strictly positive savings fraction.
    #[test]
    fn network_memory_savings_positive() {
        // Expanding network: 4→16, 16→64 — outputs are much larger than inputs.
        let seg1 = LinearSegment::random_init(4, 16, 10);
        let seg2 = LinearSegment::random_init(16, 64, 11);
        let input1 = vec![1.0f32; 4];
        let mid = seg1.forward(&input1);
        let c1 = Checkpoint::new(seg1, input1);
        let c2 = Checkpoint::new(seg2, mid);
        let net = CheckpointedNetwork::new(vec![c1, c2]);
        let savings = net.memory_savings();
        assert!(
            savings > 0.0,
            "expanding network should save memory, got {savings}"
        );
    }

    /// For an expanding network, storing outputs costs more than storing inputs.
    #[test]
    fn network_full_memory_greater() {
        let seg1 = LinearSegment::random_init(4, 16, 20);
        let seg2 = LinearSegment::random_init(16, 64, 21);
        let input1 = vec![0.5f32; 4];
        let mid = seg1.forward(&input1);
        let c1 = Checkpoint::new(seg1, input1);
        let c2 = Checkpoint::new(seg2, mid);
        let net = CheckpointedNetwork::new(vec![c1, c2]);
        assert!(
            net.full_memory_bytes() > net.memory_bytes(),
            "full storage must use more memory than checkpointed storage"
        );
    }

    /// A new budget starts empty and remembers its capacity.
    #[test]
    fn budget_new() {
        let b = CheckpointBudget::new(1024);
        assert_eq!(b.used_bytes, 0, "fresh budget should have used_bytes = 0");
        assert_eq!(b.max_bytes, 1024);
    }

    /// An allocation that fits succeeds and is added to `used_bytes`.
    #[test]
    fn budget_allocate_within() {
        let mut b = CheckpointBudget::new(1024);
        let result = b.allocate(256);
        assert!(result.is_ok(), "allocation within budget must succeed");
        assert_eq!(b.used_bytes, 256);
    }

    /// An oversized allocation fails with `BudgetExceeded` and changes nothing.
    #[test]
    fn budget_allocate_exceed() {
        let mut b = CheckpointBudget::new(100);
        let result = b.allocate(200);
        assert!(
            matches!(result, Err(CheckpointError::BudgetExceeded { .. })),
            "allocation exceeding budget must return BudgetExceeded"
        );
        assert_eq!(
            b.used_bytes, 0,
            "failed allocation must not change used_bytes"
        );
    }

    /// `free` returns bytes to the budget.
    #[test]
    fn budget_free() {
        let mut b = CheckpointBudget::new(1024);
        b.allocate(512).expect("allocation should succeed");
        b.free(256);
        assert_eq!(b.used_bytes, 256);
    }

    /// `utilization` is `used / max` (250 / 1000 = 0.25).
    #[test]
    fn budget_utilization() {
        let mut b = CheckpointBudget::new(1000);
        b.allocate(250).expect("allocation should succeed");
        let util = b.utilization();
        assert!(
            (util - 0.25).abs() < 1e-6,
            "utilization should be 0.25, got {util}"
        );
    }

    /// A network with a single checkpoint still produces output.
    #[test]
    fn network_single_segment() {
        let seg = LinearSegment::random_init(3, 3, 55);
        let input = vec![1.0f32, 0.0, -1.0];
        let c = Checkpoint::new(seg, input);
        let net = CheckpointedNetwork::new(vec![c]);
        let out = net.forward(&[1.0f32, 0.0, -1.0]);
        assert_eq!(out.len(), 3, "single-segment network should produce output");
    }
}
765}