Skip to main content

entrenar/lora/
paged_optim.rs

1//! Paged optimizer states for QLoRA (ENT-LoRA-010)
2//!
3//! Pages AdamW m/v (momentum/variance) states to CPU RAM when GPU VRAM pressure
4//! is detected. This enables training larger models on smaller GPUs.
5//!
6//! Architecture:
7//! - All optimizer states (m, v) live on CPU by default
8//! - Before optimizer step, relevant states are paged into GPU buffers
9//! - After step, states are paged back to CPU
10//! - VRAM budget tracks pressure and triggers paging
11//!
12//! Memory savings for 7B model at rank=16:
13//! - Full m/v on GPU: 2 × 7B × 4 bytes = 56 GB (impossible on consumer GPU)
14//! - LoRA m/v only: 2 × ~5.9M × 4 bytes = ~47 MB (always fits)
15//! - Paged: m/v on CPU, paged in per-layer = constant ~200KB GPU overhead
16
17use ndarray::Array1;
18
/// VRAM budget tracker for optimizer state paging decisions.
///
/// Sizes are accepted in decimal gigabytes (1 GB = 10^9 bytes) and stored as
/// byte counts.
#[derive(Debug, Clone)]
pub struct VramBudget {
    /// Total VRAM in bytes
    total_bytes: u64,
    /// Reserved for model weights and activations (bytes)
    reserved_bytes: u64,
    /// Fraction of `total_bytes` the budget may plan to use; `with_target`
    /// keeps it within [0.5, 0.95]
    target_utilization: f64,
}

impl VramBudget {
    /// Create a new VRAM budget of `total_vram_gb` decimal gigabytes.
    ///
    /// The target utilization starts at 0.85 and nothing is reserved. Negative
    /// or NaN sizes become 0 bytes, because Rust float-to-integer `as` casts
    /// saturate.
    pub fn new(total_vram_gb: f64) -> Self {
        Self {
            total_bytes: (total_vram_gb * 1e9) as u64,
            reserved_bytes: 0,
            target_utilization: 0.85,
        }
    }

    /// Set the reserved amount (model weights + activations), in decimal GB.
    pub fn with_reserved(mut self, reserved_gb: f64) -> Self {
        self.reserved_bytes = (reserved_gb * 1e9) as u64;
        self
    }

    /// Set the target utilization, clamped into [0.5, 0.95].
    ///
    /// A NaN `target` is ignored and the current target is kept.
    /// `f64::clamp` returns NaN for a NaN input, and a NaN target would make
    /// `available_bytes` silently collapse to 0.
    pub fn with_target(mut self, target: f64) -> Self {
        if !target.is_nan() {
            self.target_utilization = target.clamp(0.5, 0.95);
        }
        self
    }

    /// Bytes available for optimizer states.
    ///
    /// Computed as `total * target_utilization - reserved`, saturating at 0.
    pub fn available_bytes(&self) -> u64 {
        let budget = (self.total_bytes as f64 * self.target_utilization) as u64;
        budget.saturating_sub(self.reserved_bytes)
    }

    /// Check whether `bytes` would fit within [`VramBudget::available_bytes`].
    pub fn fits(&self, bytes: u64) -> bool {
        bytes <= self.available_bytes()
    }
}
63
/// Policy for moving optimizer states between CPU and GPU memory.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PagingStrategy {
    /// Keep every state on the CPU and page each one in per layer during the
    /// step. This is the safest and slowest option.
    FullyPaged,
    /// Keep states on the GPU while they fit, falling back to paging when they
    /// do not (adaptive).
    Adaptive,
    /// Never page. Fails with OOM if the states don't fit. This is the default
    /// and the fastest option.
    None,
}
74
75/// CPU-resident optimizer state for one parameter group
76#[derive(Debug, Clone)]
77pub struct PagedState {
78    /// First moment (m) stored on CPU
79    pub m: Option<Array1<f32>>,
80    /// Second moment (v) stored on CPU
81    pub v: Option<Array1<f32>>,
82    /// Number of elements
83    pub len: usize,
84    /// Whether this state is currently paged in to GPU
85    pub on_gpu: bool,
86}
87
88impl PagedState {
89    /// Create empty state for a parameter of given length
90    pub fn new(len: usize) -> Self {
91        Self { m: None, v: None, len, on_gpu: false }
92    }
93
94    /// Initialize states (lazy, called on first use)
95    pub fn ensure_initialized(&mut self) {
96        if self.m.is_none() {
97            self.m = Some(Array1::zeros(self.len));
98            self.v = Some(Array1::zeros(self.len));
99        }
100    }
101
102    /// Memory usage in bytes on CPU
103    pub fn cpu_bytes(&self) -> usize {
104        // m + v, each Array1<f32> = len * 4 bytes
105        if self.m.is_some() {
106            self.len * 8
107        } else {
108            0
109        }
110    }
111
112    /// Memory that would be needed on GPU
113    pub fn gpu_bytes(&self) -> usize {
114        self.len * 8 // m + v
115    }
116}
117
118/// Paged optimizer state manager
119///
120/// Wraps optimizer state storage with CPU↔GPU paging capability.
121/// On CPU-only systems, this is essentially a no-op wrapper.
122pub struct PagedOptimStates {
123    /// Per-parameter optimizer states (CPU-resident)
124    states: Vec<PagedState>,
125    /// VRAM budget for paging decisions
126    budget: VramBudget,
127    /// Paging strategy
128    strategy: PagingStrategy,
129    /// Number of page-in events (for monitoring)
130    page_in_count: u64,
131    /// Number of page-out events
132    page_out_count: u64,
133}
134
135impl PagedOptimStates {
136    /// Create a new paged optimizer state manager
137    pub fn new(budget: VramBudget, strategy: PagingStrategy) -> Self {
138        Self { states: Vec::new(), budget, strategy, page_in_count: 0, page_out_count: 0 }
139    }
140
141    /// Register a parameter group
142    pub fn register(&mut self, param_len: usize) -> usize {
143        let idx = self.states.len();
144        self.states.push(PagedState::new(param_len));
145        idx
146    }
147
148    /// Get mutable state for a parameter, paging in if necessary
149    pub fn get_state_mut(&mut self, idx: usize) -> &mut PagedState {
150        self.states[idx].ensure_initialized();
151
152        if self.strategy == PagingStrategy::FullyPaged && self.states[idx].on_gpu {
153            // State is on GPU, need to page out others first
154            self.page_out_count += 1;
155        }
156
157        if !self.states[idx].on_gpu && self.strategy != PagingStrategy::None {
158            self.page_in_count += 1;
159        }
160
161        &mut self.states[idx]
162    }
163
164    /// Get immutable state
165    pub fn get_state(&self, idx: usize) -> &PagedState {
166        &self.states[idx]
167    }
168
169    /// Total CPU memory used by all states (bytes)
170    pub fn total_cpu_bytes(&self) -> usize {
171        self.states.iter().map(PagedState::cpu_bytes).sum()
172    }
173
174    /// Number of registered parameter groups
175    pub fn num_states(&self) -> usize {
176        self.states.len()
177    }
178
179    /// Would all states fit on GPU simultaneously?
180    pub fn all_fit_on_gpu(&self) -> bool {
181        let total: u64 = self.states.iter().map(|s| s.gpu_bytes() as u64).sum();
182        self.budget.fits(total)
183    }
184
185    /// Get paging statistics
186    pub fn stats(&self) -> PagingStats {
187        PagingStats {
188            page_in_count: self.page_in_count,
189            page_out_count: self.page_out_count,
190            total_cpu_bytes: self.total_cpu_bytes(),
191            num_states: self.states.len(),
192            strategy: self.strategy,
193        }
194    }
195}
196
197/// Paging statistics for monitoring
198#[derive(Debug, Clone)]
199pub struct PagingStats {
200    /// Number of page-in events
201    pub page_in_count: u64,
202    /// Number of page-out events
203    pub page_out_count: u64,
204    /// Total CPU memory used (bytes)
205    pub total_cpu_bytes: usize,
206    /// Number of parameter groups
207    pub num_states: usize,
208    /// Active paging strategy
209    pub strategy: PagingStrategy,
210}
211
212impl PagingStats {
213    /// Format as human-readable string
214    pub fn summary(&self) -> String {
215        format!(
216            "Paged optimizer: {} states, {:.1} MB CPU, {} page-ins, {} page-outs, strategy={:?}",
217            self.num_states,
218            self.total_cpu_bytes as f64 / 1e6,
219            self.page_in_count,
220            self.page_out_count,
221            self.strategy,
222        )
223    }
224}
225
#[cfg(test)]
// Tests may unwrap freely: a panic here is a test failure, which is what we want.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_ent_lora_010_vram_budget_basic() {
        let budget = VramBudget::new(16.0).with_reserved(10.0);
        // 16 * 0.85 - 10 = 3.6 GB available
        let avail = budget.available_bytes();
        assert!(avail > 3_000_000_000);
        assert!(avail < 4_000_000_000);
    }

    #[test]
    fn test_ent_lora_010_vram_budget_fits() {
        let budget = VramBudget::new(16.0).with_reserved(10.0);
        assert!(budget.fits(1_000_000_000)); // 1GB fits
        assert!(!budget.fits(10_000_000_000)); // 10GB doesn't
    }

    #[test]
    fn test_ent_lora_010_paged_state_lifecycle() {
        let mut state = PagedState::new(1024);
        assert_eq!(state.cpu_bytes(), 0);

        state.ensure_initialized();
        assert_eq!(state.cpu_bytes(), 1024 * 8); // m + v
        assert_eq!(state.gpu_bytes(), 1024 * 8);
        assert!(state.m.is_some());
        assert!(state.v.is_some());
    }

    #[test]
    fn test_ent_lora_010_paged_optim_register() {
        let budget = VramBudget::new(16.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::Adaptive);

        let idx0 = paged.register(512);
        let idx1 = paged.register(1024);

        assert_eq!(idx0, 0);
        assert_eq!(idx1, 1);
        assert_eq!(paged.num_states(), 2);
    }

    #[test]
    fn test_ent_lora_010_paged_optim_get_state() {
        let budget = VramBudget::new(16.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::FullyPaged);
        paged.register(256);

        let state = paged.get_state_mut(0);
        assert!(state.m.is_some()); // Lazily initialized
        assert_eq!(state.m.as_ref().unwrap().len(), 256);
    }

    #[test]
    fn test_ent_lora_010_paged_optim_stats() {
        let budget = VramBudget::new(16.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::FullyPaged);
        paged.register(1024);
        let _ = paged.get_state_mut(0); // Triggers page-in

        let stats = paged.stats();
        assert_eq!(stats.num_states, 1);
        assert!(stats.total_cpu_bytes > 0);
        assert!(stats.page_in_count > 0);
        assert!(stats.summary().contains("Paged optimizer"));
    }

    #[test]
    fn test_ent_lora_010_all_fit_on_gpu() {
        let budget = VramBudget::new(16.0).with_reserved(0.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::Adaptive);
        // 1M params × 8 bytes = 8MB — easily fits in 16GB
        paged.register(1_000_000);
        assert!(paged.all_fit_on_gpu());
    }

    #[test]
    fn test_ent_lora_010_does_not_fit_on_gpu() {
        let budget = VramBudget::new(0.001); // 1MB VRAM
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::Adaptive);
        // 100M params × 8 bytes = 800MB — doesn't fit in 1MB
        paged.register(100_000_000);
        assert!(!paged.all_fit_on_gpu());
    }

    #[test]
    fn test_ent_lora_010_strategy_none() {
        let budget = VramBudget::new(16.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::None);
        paged.register(512);
        let _ = paged.get_state_mut(0);

        let stats = paged.stats();
        assert_eq!(stats.page_in_count, 0); // None strategy doesn't track pages
    }

    #[test]
    fn test_ent_lora_010_vram_budget_target_clamping() {
        let budget = VramBudget::new(16.0).with_target(0.1);
        assert!(budget.target_utilization >= 0.5);

        let budget = VramBudget::new(16.0).with_target(1.5);
        assert!(budget.target_utilization <= 0.95);
    }

    proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(50))]

        #[test]
        fn prop_paged_state_bytes_consistent(len in 1usize..10000) {
            let mut state = PagedState::new(len);
            prop_assert_eq!(state.cpu_bytes(), 0);

            state.ensure_initialized();
            prop_assert_eq!(state.cpu_bytes(), len * 8);
            prop_assert_eq!(state.gpu_bytes(), len * 8);
        }

        #[test]
        fn prop_budget_available_nonnegative(
            total_gb in 1.0f64..100.0,
            reserved_gb in 0.0f64..50.0,
        ) {
            let budget = VramBudget::new(total_gb).with_reserved(reserved_gb);
            let total_bytes = (total_gb * 1e9) as u64;
            let avail = budget.available_bytes();
            // Saturating subtraction must never wrap around: the optimizer share
            // can never exceed the whole card.
            prop_assert!(avail <= total_bytes);
            // Reserving at least the whole card leaves nothing for optimizer states.
            if reserved_gb >= total_gb {
                prop_assert_eq!(avail, 0);
            }
        }
    }
}