entrenar/lora/
paged_optim.rs1use ndarray::Array1;
18
/// Memory budget used to decide how many bytes of device memory may be used.
///
/// All gigabyte inputs are decimal (1 GB = 1e9 bytes).
#[derive(Debug, Clone)]
pub struct VramBudget {
    /// Total device memory, in bytes.
    total_bytes: u64,
    /// Bytes subtracted from the usable budget in [`VramBudget::available_bytes`].
    reserved_bytes: u64,
    /// Fraction of `total_bytes` that may be used; kept within `[0.5, 0.95]`.
    target_utilization: f64,
}

impl VramBudget {
    /// Utilization applied until [`VramBudget::with_target`] overrides it.
    const DEFAULT_TARGET_UTILIZATION: f64 = 0.85;
    /// Lowest utilization [`VramBudget::with_target`] will accept.
    const MIN_TARGET_UTILIZATION: f64 = 0.5;
    /// Highest utilization [`VramBudget::with_target`] will accept.
    const MAX_TARGET_UTILIZATION: f64 = 0.95;

    /// Converts decimal gigabytes to whole bytes.
    ///
    /// Float-to-integer `as` casts saturate: NaN and negative inputs become 0
    /// and values above `u64::MAX` become `u64::MAX`.
    fn gb_to_bytes(gb: f64) -> u64 {
        (gb * 1e9) as u64
    }

    /// Creates a budget for a device with `total_vram_gb` gigabytes of memory.
    ///
    /// Nothing is reserved and the target utilization starts at 0.85.
    pub fn new(total_vram_gb: f64) -> Self {
        Self {
            total_bytes: Self::gb_to_bytes(total_vram_gb),
            reserved_bytes: 0,
            target_utilization: Self::DEFAULT_TARGET_UTILIZATION,
        }
    }

    /// Sets the amount of memory, in gigabytes, that is excluded from the budget.
    pub fn with_reserved(mut self, reserved_gb: f64) -> Self {
        self.reserved_bytes = Self::gb_to_bytes(reserved_gb);
        self
    }

    /// Sets the target utilization, clamped to `[0.5, 0.95]`.
    ///
    /// A NaN `target` is ignored and the current value is kept: `f64::clamp`
    /// propagates NaN, which would break the range invariant and make
    /// [`VramBudget::available_bytes`] collapse to zero.
    pub fn with_target(mut self, target: f64) -> Self {
        if !target.is_nan() {
            self.target_utilization = target.clamp(
                Self::MIN_TARGET_UTILIZATION,
                Self::MAX_TARGET_UTILIZATION,
            );
        }
        self
    }

    /// Returns the usable bytes: `total * target_utilization - reserved`.
    ///
    /// The subtraction saturates, so an over-reserved budget yields 0 rather
    /// than wrapping around.
    pub fn available_bytes(&self) -> u64 {
        let ceiling = (self.total_bytes as f64 * self.target_utilization) as u64;
        ceiling.saturating_sub(self.reserved_bytes)
    }

    /// Returns `true` when `bytes` does not exceed [`VramBudget::available_bytes`].
    pub fn fits(&self, bytes: u64) -> bool {
        bytes <= self.available_bytes()
    }
}
63
/// Selects which state accesses `PagedOptimStates::get_state_mut` counts as
/// transfers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PagingStrategy {
    /// An access to a state that is on the GPU counts as a page-out; an
    /// access to a state that is not on the GPU counts as a page-in.
    FullyPaged,
    /// Only accesses to states that are not on the GPU are counted (as
    /// page-ins); no page-outs are recorded.
    Adaptive,
    /// No paging is counted at all.
    None,
}
74
75#[derive(Debug, Clone)]
77pub struct PagedState {
78 pub m: Option<Array1<f32>>,
80 pub v: Option<Array1<f32>>,
82 pub len: usize,
84 pub on_gpu: bool,
86}
87
88impl PagedState {
89 pub fn new(len: usize) -> Self {
91 Self { m: None, v: None, len, on_gpu: false }
92 }
93
94 pub fn ensure_initialized(&mut self) {
96 if self.m.is_none() {
97 self.m = Some(Array1::zeros(self.len));
98 self.v = Some(Array1::zeros(self.len));
99 }
100 }
101
102 pub fn cpu_bytes(&self) -> usize {
104 if self.m.is_some() {
106 self.len * 8
107 } else {
108 0
109 }
110 }
111
112 pub fn gpu_bytes(&self) -> usize {
114 self.len * 8 }
116}
117
118pub struct PagedOptimStates {
123 states: Vec<PagedState>,
125 budget: VramBudget,
127 strategy: PagingStrategy,
129 page_in_count: u64,
131 page_out_count: u64,
133}
134
135impl PagedOptimStates {
136 pub fn new(budget: VramBudget, strategy: PagingStrategy) -> Self {
138 Self { states: Vec::new(), budget, strategy, page_in_count: 0, page_out_count: 0 }
139 }
140
141 pub fn register(&mut self, param_len: usize) -> usize {
143 let idx = self.states.len();
144 self.states.push(PagedState::new(param_len));
145 idx
146 }
147
148 pub fn get_state_mut(&mut self, idx: usize) -> &mut PagedState {
150 self.states[idx].ensure_initialized();
151
152 if self.strategy == PagingStrategy::FullyPaged && self.states[idx].on_gpu {
153 self.page_out_count += 1;
155 }
156
157 if !self.states[idx].on_gpu && self.strategy != PagingStrategy::None {
158 self.page_in_count += 1;
159 }
160
161 &mut self.states[idx]
162 }
163
164 pub fn get_state(&self, idx: usize) -> &PagedState {
166 &self.states[idx]
167 }
168
169 pub fn total_cpu_bytes(&self) -> usize {
171 self.states.iter().map(PagedState::cpu_bytes).sum()
172 }
173
174 pub fn num_states(&self) -> usize {
176 self.states.len()
177 }
178
179 pub fn all_fit_on_gpu(&self) -> bool {
181 let total: u64 = self.states.iter().map(|s| s.gpu_bytes() as u64).sum();
182 self.budget.fits(total)
183 }
184
185 pub fn stats(&self) -> PagingStats {
187 PagingStats {
188 page_in_count: self.page_in_count,
189 page_out_count: self.page_out_count,
190 total_cpu_bytes: self.total_cpu_bytes(),
191 num_states: self.states.len(),
192 strategy: self.strategy,
193 }
194 }
195}
196
197#[derive(Debug, Clone)]
199pub struct PagingStats {
200 pub page_in_count: u64,
202 pub page_out_count: u64,
204 pub total_cpu_bytes: usize,
206 pub num_states: usize,
208 pub strategy: PagingStrategy,
210}
211
212impl PagingStats {
213 pub fn summary(&self) -> String {
215 format!(
216 "Paged optimizer: {} states, {:.1} MB CPU, {} page-ins, {} page-outs, strategy={:?}",
217 self.num_states,
218 self.total_cpu_bytes as f64 / 1e6,
219 self.page_in_count,
220 self.page_out_count,
221 self.strategy,
222 )
223 }
224}
225
#[cfg(test)]
// Unwrapping is fine in tests: a panic is reported as a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// 16 GB at the default 0.85 target is 13.6 GB; reserving 10 GB leaves
    /// roughly 3.6 GB.
    #[test]
    fn test_ent_lora_010_vram_budget_basic() {
        let budget = VramBudget::new(16.0).with_reserved(10.0);
        let avail = budget.available_bytes();
        assert!(avail > 3_000_000_000);
        assert!(avail < 4_000_000_000);
    }

    /// 1 GB fits in the ~3.6 GB that remain, 10 GB does not.
    #[test]
    fn test_ent_lora_010_vram_budget_fits() {
        let budget = VramBudget::new(16.0).with_reserved(10.0);
        assert!(budget.fits(1_000_000_000));
        assert!(!budget.fits(10_000_000_000));
    }

    /// A fresh state uses no CPU memory; initialization allocates both buffers.
    #[test]
    fn test_ent_lora_010_paged_state_lifecycle() {
        let mut state = PagedState::new(1024);
        assert_eq!(state.cpu_bytes(), 0);

        state.ensure_initialized();
        assert_eq!(state.cpu_bytes(), 1024 * 8);
        assert_eq!(state.gpu_bytes(), 1024 * 8);
        assert!(state.m.is_some());
        assert!(state.v.is_some());
    }

    /// Registration hands out consecutive indices starting at zero.
    #[test]
    fn test_ent_lora_010_paged_optim_register() {
        let budget = VramBudget::new(16.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::Adaptive);

        let idx0 = paged.register(512);
        let idx1 = paged.register(1024);

        assert_eq!(idx0, 0);
        assert_eq!(idx1, 1);
        assert_eq!(paged.num_states(), 2);
    }

    /// Mutable access allocates the buffers with the registered length.
    #[test]
    fn test_ent_lora_010_paged_optim_get_state() {
        let budget = VramBudget::new(16.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::FullyPaged);
        paged.register(256);

        let state = paged.get_state_mut(0);
        assert!(state.m.is_some());
        assert_eq!(state.m.as_ref().unwrap().len(), 256);
    }

    /// One access to a CPU-side state is reflected in the stats snapshot.
    #[test]
    fn test_ent_lora_010_paged_optim_stats() {
        let budget = VramBudget::new(16.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::FullyPaged);
        paged.register(1024);
        let _ = paged.get_state_mut(0);

        let stats = paged.stats();
        assert_eq!(stats.num_states, 1);
        assert!(stats.total_cpu_bytes > 0);
        assert!(stats.page_in_count > 0);
        assert!(stats.summary().contains("Paged optimizer"));
    }

    /// 1M elements need 8 MB, far below a 16 GB budget.
    #[test]
    fn test_ent_lora_010_all_fit_on_gpu() {
        let budget = VramBudget::new(16.0).with_reserved(0.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::Adaptive);
        paged.register(1_000_000);
        assert!(paged.all_fit_on_gpu());
    }

    /// 100M elements need 800 MB, far above a 1 MB budget.
    #[test]
    fn test_ent_lora_010_does_not_fit_on_gpu() {
        let budget = VramBudget::new(0.001);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::Adaptive);
        paged.register(100_000_000);
        assert!(!paged.all_fit_on_gpu());
    }

    /// With paging disabled, an access moves neither counter.
    #[test]
    fn test_ent_lora_010_strategy_none() {
        let budget = VramBudget::new(16.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::None);
        paged.register(512);
        let _ = paged.get_state_mut(0);

        let stats = paged.stats();
        assert_eq!(stats.page_in_count, 0);
        assert_eq!(stats.page_out_count, 0);
    }

    /// Out-of-range targets are pulled back into `[0.5, 0.95]`.
    #[test]
    fn test_ent_lora_010_vram_budget_target_clamping() {
        let budget = VramBudget::new(16.0).with_target(0.1);
        assert!(budget.target_utilization >= 0.5);

        let budget = VramBudget::new(16.0).with_target(1.5);
        assert!(budget.target_utilization <= 0.95);
    }

    proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(50))]

        /// Byte accounting is 0 before allocation and `len * 8` afterwards.
        #[test]
        fn prop_paged_state_bytes_consistent(len in 1usize..10000) {
            let mut state = PagedState::new(len);
            prop_assert_eq!(state.cpu_bytes(), 0);

            state.ensure_initialized();
            prop_assert_eq!(state.cpu_bytes(), len * 8);
            prop_assert_eq!(state.gpu_bytes(), len * 8);
        }

        /// The available budget never exceeds the device total and is exactly
        /// the boundary of `fits`.
        ///
        /// A `u64` cannot be negative, so the checks target the bounds that
        /// can actually be violated.
        #[test]
        fn prop_budget_available_nonnegative(
            total_gb in 1.0f64..100.0,
            reserved_gb in 0.0f64..50.0,
        ) {
            let budget = VramBudget::new(total_gb).with_reserved(reserved_gb);
            let available = budget.available_bytes();

            prop_assert!(available <= budget.total_bytes);
            prop_assert!(budget.fits(available));
            prop_assert!(!budget.fits(available + 1));
        }
    }
}