use ndarray::Array1;
/// VRAM budget used to decide whether optimizer state fits on the device.
///
/// Sizes are supplied in decimal gigabytes (1 GB = 10^9 bytes) and stored as
/// byte counts. The usable budget is
/// `total_bytes * target_utilization - reserved_bytes`, saturating at zero.
#[derive(Debug, Clone)]
pub struct VramBudget {
    /// Total device memory in bytes.
    total_bytes: u64,
    /// Bytes subtracted from the utilization-scaled budget.
    reserved_bytes: u64,
    /// Fraction of `total_bytes` considered usable; kept within `[0.5, 0.95]`
    /// by `with_target`, defaults to 0.85.
    target_utilization: f64,
}

impl VramBudget {
    /// Default fraction of total VRAM treated as usable.
    const DEFAULT_TARGET: f64 = 0.85;
    /// Lowest target accepted by `with_target`.
    const MIN_TARGET: f64 = 0.5;
    /// Highest target accepted by `with_target`.
    const MAX_TARGET: f64 = 0.95;
    /// Bytes per decimal gigabyte.
    const BYTES_PER_GB: f64 = 1e9;

    /// Creates a budget for a device with `total_vram_gb` gigabytes, no
    /// reservation, and the default 0.85 target utilization.
    ///
    /// Negative or NaN sizes become 0 bytes (float-to-int `as` casts saturate).
    pub fn new(total_vram_gb: f64) -> Self {
        Self {
            total_bytes: Self::gb_to_bytes(total_vram_gb),
            reserved_bytes: 0,
            target_utilization: Self::DEFAULT_TARGET,
        }
    }

    /// Sets the number of gigabytes to hold back from the budget.
    ///
    /// Negative or NaN values become 0 bytes.
    pub fn with_reserved(mut self, reserved_gb: f64) -> Self {
        self.reserved_bytes = Self::gb_to_bytes(reserved_gb);
        self
    }

    /// Sets the target utilization, clamped to `[0.5, 0.95]`.
    ///
    /// A NaN `target` is ignored and the current target is kept.
    pub fn with_target(mut self, target: f64) -> Self {
        // `f64::clamp` returns NaN for a NaN input, which would make every
        // later budget computation collapse to 0 bytes.
        if !target.is_nan() {
            self.target_utilization = target.clamp(Self::MIN_TARGET, Self::MAX_TARGET);
        }
        self
    }

    /// Bytes usable after applying the target utilization and subtracting the
    /// reservation; never underflows (saturates at 0).
    pub fn available_bytes(&self) -> u64 {
        let budget = (self.total_bytes as f64 * self.target_utilization) as u64;
        budget.saturating_sub(self.reserved_bytes)
    }

    /// Returns `true` if `bytes` is no larger than [`VramBudget::available_bytes`].
    pub fn fits(&self, bytes: u64) -> bool {
        bytes <= self.available_bytes()
    }

    /// Converts decimal gigabytes to bytes. The `as` cast saturates:
    /// negative/NaN -> 0, too large -> `u64::MAX`.
    fn gb_to_bytes(gb: f64) -> u64 {
        (gb * Self::BYTES_PER_GB) as u64
    }
}
/// Paging policy for optimizer state.
///
/// It selects which accesses `PagedOptimStates::get_state_mut` records as
/// page-in / page-out events.
#[derive(Debug)]
#[derive(Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum PagingStrategy {
    /// Counts page-ins, and is the only strategy that also counts page-outs.
    FullyPaged,
    /// Counts page-ins only.
    Adaptive,
    /// Paging disabled: no events are counted.
    ///
    /// Shares its name with `Option::None`, so avoid glob-importing the variants.
    None,
}
/// Optimizer state for one parameter: two lazily allocated `f32` buffers
/// `m` and `v` of `len` elements each.
///
/// NOTE(review): `m`/`v` presumably hold Adam-style first/second moments —
/// confirm against the optimizer that consumes them.
#[derive(Debug, Clone)]
pub struct PagedState {
    /// First state buffer; `None` until allocated.
    pub m: Option<Array1<f32>>,
    /// Second state buffer; `None` until allocated.
    pub v: Option<Array1<f32>>,
    /// Number of elements in each buffer.
    pub len: usize,
    /// Residency flag consulted by `PagedOptimStates::get_state_mut` when
    /// counting page events; `new` initializes it to `false`.
    pub on_gpu: bool,
}

impl PagedState {
    /// Size in bytes of one element of a single `f32` state buffer.
    const BYTES_PER_BUFFER_ELEM: usize = std::mem::size_of::<f32>();
    /// Number of state buffers (`m` and `v`).
    const NUM_BUFFERS: usize = 2;

    /// Creates an unallocated state for a parameter with `len` elements.
    pub fn new(len: usize) -> Self {
        Self { m: None, v: None, len, on_gpu: false }
    }

    /// Allocates every missing buffer as zeros of length `len`.
    ///
    /// Each buffer is checked independently, so a partially initialized state
    /// ends up with both buffers allocated. Such a state is possible because
    /// the fields are public.
    pub fn ensure_initialized(&mut self) {
        let len = self.len;
        self.m.get_or_insert_with(|| Array1::zeros(len));
        self.v.get_or_insert_with(|| Array1::zeros(len));
    }

    /// Bytes accounted to CPU memory: `len * 4` for each of `m` and `v` that is
    /// allocated. This is `len * 8` when fully initialized and 0 before.
    pub fn cpu_bytes(&self) -> usize {
        let allocated = usize::from(self.m.is_some()) + usize::from(self.v.is_some());
        allocated * self.len * Self::BYTES_PER_BUFFER_ELEM
    }

    /// Bytes both buffers occupy on the device (`len * 8`), whether or not
    /// they are allocated yet.
    pub fn gpu_bytes(&self) -> usize {
        self.len * (Self::NUM_BUFFERS * Self::BYTES_PER_BUFFER_ELEM)
    }
}
/// Registry of per-parameter [`PagedState`]s with a VRAM budget and
/// page-in / page-out counters.
///
/// The counters are bookkeeping only; nothing in this type moves buffers
/// between host and device.
#[derive(Debug, Clone)]
pub struct PagedOptimStates {
    /// Registered states, indexed by the value returned from `register`.
    states: Vec<PagedState>,
    /// Budget consulted by `all_fit_on_gpu`.
    budget: VramBudget,
    /// Strategy that decides which accesses are counted.
    strategy: PagingStrategy,
    /// Page-in events counted by `get_state_mut`.
    page_in_count: u64,
    /// Page-out events counted by `get_state_mut`.
    page_out_count: u64,
}

impl PagedOptimStates {
    /// Creates an empty registry with zeroed counters.
    pub fn new(budget: VramBudget, strategy: PagingStrategy) -> Self {
        Self { states: Vec::new(), budget, strategy, page_in_count: 0, page_out_count: 0 }
    }

    /// Registers a parameter with `param_len` elements and returns its index.
    ///
    /// Buffers are not allocated until the first `get_state_mut` call.
    pub fn register(&mut self, param_len: usize) -> usize {
        let idx = self.states.len();
        self.states.push(PagedState::new(param_len));
        idx
    }

    /// Returns the state at `idx`, allocating its buffers if needed and
    /// updating the counters based on the state's `on_gpu` flag:
    /// - `FullyPaged` with `on_gpu == true` counts one page-out;
    /// - any strategy except `None` with `on_gpu == false` counts one page-in.
    ///
    /// NOTE(review): this type never sets `on_gpu`. Unless callers set the
    /// public field themselves, every access under a non-`None` strategy counts
    /// as a page-in, and page-outs are never recorded. Confirm this is intended.
    ///
    /// # Panics
    /// Panics if `idx` is out of range (not returned by `register`).
    pub fn get_state_mut(&mut self, idx: usize) -> &mut PagedState {
        let counts_page_out = self.strategy == PagingStrategy::FullyPaged;
        let counts_page_in = self.strategy != PagingStrategy::None;
        let state = &mut self.states[idx];
        state.ensure_initialized();
        if counts_page_out && state.on_gpu {
            self.page_out_count += 1;
        }
        if counts_page_in && !state.on_gpu {
            self.page_in_count += 1;
        }
        state
    }

    /// Returns the state at `idx` without allocating or touching counters.
    ///
    /// # Panics
    /// Panics if `idx` is out of range.
    pub fn get_state(&self, idx: usize) -> &PagedState {
        &self.states[idx]
    }

    /// Sum of [`PagedState::cpu_bytes`] over all registered states.
    pub fn total_cpu_bytes(&self) -> usize {
        self.states.iter().map(PagedState::cpu_bytes).sum()
    }

    /// Number of registered states.
    pub fn num_states(&self) -> usize {
        self.states.len()
    }

    /// Returns `true` if the device bytes of every registered state fit in the
    /// budget. Unallocated states are included.
    pub fn all_fit_on_gpu(&self) -> bool {
        // Saturate rather than overflow: a saturated total can never fit,
        // whereas a wrapped (release-mode) sum could wrongly report success.
        let total = self
            .states
            .iter()
            .map(|s| s.gpu_bytes() as u64)
            .fold(0u64, u64::saturating_add);
        self.budget.fits(total)
    }

    /// Snapshot of the counters and memory accounting.
    pub fn stats(&self) -> PagingStats {
        PagingStats {
            page_in_count: self.page_in_count,
            page_out_count: self.page_out_count,
            total_cpu_bytes: self.total_cpu_bytes(),
            num_states: self.num_states(),
            strategy: self.strategy,
        }
    }
}
/// Point-in-time snapshot of the paging counters and memory accounting, as
/// produced by `PagedOptimStates::stats`.
#[derive(Debug, Clone)]
pub struct PagingStats {
    /// Page-in events counted so far.
    pub page_in_count: u64,
    /// Page-out events counted so far.
    pub page_out_count: u64,
    /// Total bytes accounted to CPU memory across all states.
    pub total_cpu_bytes: usize,
    /// Number of registered states.
    pub num_states: usize,
    /// Strategy in effect when the snapshot was taken.
    pub strategy: PagingStrategy,
}

impl PagingStats {
    /// Renders the snapshot as one human-readable line. CPU usage is shown in
    /// decimal megabytes (10^6 bytes) with one decimal place.
    pub fn summary(&self) -> String {
        let PagingStats {
            page_in_count,
            page_out_count,
            total_cpu_bytes,
            num_states,
            strategy,
        } = self;
        let cpu_megabytes = *total_cpu_bytes as f64 / 1e6;
        format!(
            "Paged optimizer: {} states, {:.1} MB CPU, {} page-ins, {} page-outs, strategy={:?}",
            num_states, cpu_megabytes, page_in_count, page_out_count, strategy,
        )
    }
}
#[cfg(test)]
// Tests may unwrap freely: a panic is exactly the failure signal we want here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_ent_lora_010_vram_budget_basic() {
        // 16 GB * 0.85 default target - 10 GB reserved = 3.6 GB.
        let budget = VramBudget::new(16.0).with_reserved(10.0);
        let avail = budget.available_bytes();
        assert!(avail > 3_000_000_000);
        assert!(avail < 4_000_000_000);
    }

    #[test]
    fn test_ent_lora_010_vram_budget_fits() {
        let budget = VramBudget::new(16.0).with_reserved(10.0);
        assert!(budget.fits(1_000_000_000));
        assert!(!budget.fits(10_000_000_000));
    }

    #[test]
    fn test_ent_lora_010_paged_state_lifecycle() {
        let mut state = PagedState::new(1024);
        assert_eq!(state.cpu_bytes(), 0);
        state.ensure_initialized();
        // Two f32 buffers -> 8 bytes per element.
        assert_eq!(state.cpu_bytes(), 1024 * 8);
        assert_eq!(state.gpu_bytes(), 1024 * 8);
        assert!(state.m.is_some());
        assert!(state.v.is_some());
    }

    #[test]
    fn test_ent_lora_010_paged_optim_register() {
        let budget = VramBudget::new(16.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::Adaptive);
        let idx0 = paged.register(512);
        let idx1 = paged.register(1024);
        assert_eq!(idx0, 0);
        assert_eq!(idx1, 1);
        assert_eq!(paged.num_states(), 2);
    }

    #[test]
    fn test_ent_lora_010_paged_optim_get_state() {
        let budget = VramBudget::new(16.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::FullyPaged);
        paged.register(256);
        let state = paged.get_state_mut(0);
        assert!(state.m.is_some());
        assert_eq!(state.m.as_ref().unwrap().len(), 256);
    }

    #[test]
    fn test_ent_lora_010_paged_optim_stats() {
        let budget = VramBudget::new(16.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::FullyPaged);
        paged.register(1024);
        let _ = paged.get_state_mut(0);
        let stats = paged.stats();
        assert_eq!(stats.num_states, 1);
        assert!(stats.total_cpu_bytes > 0);
        assert!(stats.page_in_count > 0);
        assert!(stats.summary().contains("Paged optimizer"));
    }

    #[test]
    fn test_ent_lora_010_all_fit_on_gpu() {
        let budget = VramBudget::new(16.0).with_reserved(0.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::Adaptive);
        paged.register(1_000_000);
        assert!(paged.all_fit_on_gpu());
    }

    #[test]
    fn test_ent_lora_010_does_not_fit_on_gpu() {
        // 0.001 GB = 1 MB total; 100M elements need 800 MB.
        let budget = VramBudget::new(0.001);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::Adaptive);
        paged.register(100_000_000);
        assert!(!paged.all_fit_on_gpu());
    }

    #[test]
    fn test_ent_lora_010_strategy_none() {
        let budget = VramBudget::new(16.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::None);
        paged.register(512);
        let _ = paged.get_state_mut(0);
        let stats = paged.stats();
        assert_eq!(stats.page_in_count, 0);
    }

    #[test]
    fn test_ent_lora_010_vram_budget_target_clamping() {
        let budget = VramBudget::new(16.0).with_target(0.1);
        assert!(budget.target_utilization >= 0.5);
        let budget = VramBudget::new(16.0).with_target(1.5);
        assert!(budget.target_utilization <= 0.95);
    }

    proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(50))]

        // Accounting is 0 before allocation and 8 bytes/element afterwards.
        #[test]
        fn prop_paged_state_bytes_consistent(len in 1usize..10000) {
            let mut state = PagedState::new(len);
            prop_assert_eq!(state.cpu_bytes(), 0);
            state.ensure_initialized();
            prop_assert_eq!(state.cpu_bytes(), len * 8);
            prop_assert_eq!(state.gpu_bytes(), len * 8);
        }

        // `available_bytes` is u64 and thus trivially nonnegative, so check the
        // meaningful bounds: it never exceeds the device total, and `fits`
        // agrees with it exactly at the boundary.
        #[test]
        fn prop_budget_available_nonnegative(
            total_gb in 1.0f64..100.0,
            reserved_gb in 0.0f64..50.0,
        ) {
            let budget = VramBudget::new(total_gb).with_reserved(reserved_gb);
            let avail = budget.available_bytes();
            prop_assert!(avail <= budget.total_bytes);
            prop_assert!(budget.fits(avail));
            prop_assert!(!budget.fits(avail + 1));
        }
    }
}