entrenar 0.7.10

Training & Optimization library with autograd, LoRA, quantization, and model merging
Documentation
//! Paged optimizer states for QLoRA (ENT-LoRA-010)
//!
//! Pages AdamW m/v (momentum/variance) states to CPU RAM when GPU VRAM pressure
//! is detected. This enables training larger models on smaller GPUs.
//!
//! Architecture:
//! - All optimizer states (m, v) live on CPU by default
//! - Before optimizer step, relevant states are paged into GPU buffers
//! - After step, states are paged back to CPU
//! - VRAM budget tracks pressure and triggers paging
//!
//! Memory savings for 7B model at rank=16:
//! - Full m/v on GPU: 2 × 7B × 4 bytes = 56 GB (impossible on consumer GPU)
//! - LoRA m/v only: 2 × ~5.9M × 4 bytes = ~47 MB (always fits)
//! - Paged: m/v on CPU, paged in per-layer = constant ~200KB GPU overhead

use ndarray::Array1;

/// VRAM budget tracker for optimizer state paging decisions
#[derive(Debug, Clone)]
pub struct VramBudget {
    /// Total VRAM in bytes
    total_bytes: u64,
    /// Reserved for model weights and activations (bytes)
    reserved_bytes: u64,
    /// Target utilization (0.0 - 1.0)
    target_utilization: f64,
}

impl VramBudget {
    /// Create a new VRAM budget
    pub fn new(total_vram_gb: f64) -> Self {
        Self {
            total_bytes: (total_vram_gb * 1e9) as u64,
            reserved_bytes: 0,
            target_utilization: 0.85,
        }
    }

    /// Set reserved bytes (model weights + activations)
    pub fn with_reserved(mut self, reserved_gb: f64) -> Self {
        self.reserved_bytes = (reserved_gb * 1e9) as u64;
        self
    }

    /// Set target utilization
    pub fn with_target(mut self, target: f64) -> Self {
        self.target_utilization = target.clamp(0.5, 0.95);
        self
    }

    /// Available bytes for optimizer states
    pub fn available_bytes(&self) -> u64 {
        let budget = (self.total_bytes as f64 * self.target_utilization) as u64;
        budget.saturating_sub(self.reserved_bytes)
    }

    /// Check if a given number of bytes would fit on GPU
    pub fn fits(&self, bytes: u64) -> bool {
        bytes <= self.available_bytes()
    }
}

/// Paging strategy for optimizer states
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PagingStrategy {
    /// All states on CPU, page in per-layer during step (safest, slowest)
    FullyPaged,
    /// Keep states on GPU if they fit, fall back to paging (adaptive)
    Adaptive,
    /// Never page — fail with OOM if states don't fit (default, fastest)
    None,
}

/// CPU-resident optimizer state for one parameter group
#[derive(Debug, Clone)]
pub struct PagedState {
    /// First moment (m) stored on CPU
    pub m: Option<Array1<f32>>,
    /// Second moment (v) stored on CPU
    pub v: Option<Array1<f32>>,
    /// Number of elements
    pub len: usize,
    /// Whether this state is currently paged in to GPU
    pub on_gpu: bool,
}

impl PagedState {
    /// Create empty state for a parameter of given length
    pub fn new(len: usize) -> Self {
        Self { m: None, v: None, len, on_gpu: false }
    }

    /// Initialize states (lazy, called on first use)
    pub fn ensure_initialized(&mut self) {
        if self.m.is_none() {
            self.m = Some(Array1::zeros(self.len));
            self.v = Some(Array1::zeros(self.len));
        }
    }

    /// Memory usage in bytes on CPU
    pub fn cpu_bytes(&self) -> usize {
        // m + v, each Array1<f32> = len * 4 bytes
        if self.m.is_some() {
            self.len * 8
        } else {
            0
        }
    }

    /// Memory that would be needed on GPU
    pub fn gpu_bytes(&self) -> usize {
        self.len * 8 // m + v
    }
}

/// Paged optimizer state manager
///
/// Wraps optimizer state storage with CPU↔GPU paging capability.
/// On CPU-only systems, this is essentially a no-op wrapper.
pub struct PagedOptimStates {
    /// Per-parameter optimizer states (CPU-resident)
    states: Vec<PagedState>,
    /// VRAM budget for paging decisions
    budget: VramBudget,
    /// Paging strategy
    strategy: PagingStrategy,
    /// Number of page-in events (for monitoring)
    page_in_count: u64,
    /// Number of page-out events
    page_out_count: u64,
}

impl PagedOptimStates {
    /// Create a new paged optimizer state manager
    pub fn new(budget: VramBudget, strategy: PagingStrategy) -> Self {
        Self { states: Vec::new(), budget, strategy, page_in_count: 0, page_out_count: 0 }
    }

    /// Register a parameter group
    pub fn register(&mut self, param_len: usize) -> usize {
        let idx = self.states.len();
        self.states.push(PagedState::new(param_len));
        idx
    }

    /// Get mutable state for a parameter, paging in if necessary
    pub fn get_state_mut(&mut self, idx: usize) -> &mut PagedState {
        self.states[idx].ensure_initialized();

        if self.strategy == PagingStrategy::FullyPaged && self.states[idx].on_gpu {
            // State is on GPU, need to page out others first
            self.page_out_count += 1;
        }

        if !self.states[idx].on_gpu && self.strategy != PagingStrategy::None {
            self.page_in_count += 1;
        }

        &mut self.states[idx]
    }

    /// Get immutable state
    pub fn get_state(&self, idx: usize) -> &PagedState {
        &self.states[idx]
    }

    /// Total CPU memory used by all states (bytes)
    pub fn total_cpu_bytes(&self) -> usize {
        self.states.iter().map(PagedState::cpu_bytes).sum()
    }

    /// Number of registered parameter groups
    pub fn num_states(&self) -> usize {
        self.states.len()
    }

    /// Would all states fit on GPU simultaneously?
    pub fn all_fit_on_gpu(&self) -> bool {
        let total: u64 = self.states.iter().map(|s| s.gpu_bytes() as u64).sum();
        self.budget.fits(total)
    }

    /// Get paging statistics
    pub fn stats(&self) -> PagingStats {
        PagingStats {
            page_in_count: self.page_in_count,
            page_out_count: self.page_out_count,
            total_cpu_bytes: self.total_cpu_bytes(),
            num_states: self.states.len(),
            strategy: self.strategy,
        }
    }
}

/// Paging statistics for monitoring
#[derive(Debug, Clone)]
pub struct PagingStats {
    /// Number of page-in events
    pub page_in_count: u64,
    /// Number of page-out events
    pub page_out_count: u64,
    /// Total CPU memory used (bytes)
    pub total_cpu_bytes: usize,
    /// Number of parameter groups
    pub num_states: usize,
    /// Active paging strategy
    pub strategy: PagingStrategy,
}

impl PagingStats {
    /// Format as human-readable string
    pub fn summary(&self) -> String {
        format!(
            "Paged optimizer: {} states, {:.1} MB CPU, {} page-ins, {} page-outs, strategy={:?}",
            self.num_states,
            self.total_cpu_bytes as f64 / 1e6,
            self.page_in_count,
            self.page_out_count,
            self.strategy,
        )
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_ent_lora_010_vram_budget_basic() {
        let budget = VramBudget::new(16.0).with_reserved(10.0);
        // 16 * 0.85 - 10 = 3.6 GB available
        let avail = budget.available_bytes();
        assert!(avail > 3_000_000_000);
        assert!(avail < 4_000_000_000);
    }

    #[test]
    fn test_ent_lora_010_vram_budget_fits() {
        let budget = VramBudget::new(16.0).with_reserved(10.0);
        assert!(budget.fits(1_000_000_000)); // 1GB fits
        assert!(!budget.fits(10_000_000_000)); // 10GB doesn't
    }

    #[test]
    fn test_ent_lora_010_paged_state_lifecycle() {
        let mut state = PagedState::new(1024);
        assert_eq!(state.cpu_bytes(), 0);

        state.ensure_initialized();
        assert_eq!(state.cpu_bytes(), 1024 * 8); // m + v
        assert_eq!(state.gpu_bytes(), 1024 * 8);
        assert!(state.m.is_some());
        assert!(state.v.is_some());
    }

    #[test]
    fn test_ent_lora_010_paged_optim_register() {
        let budget = VramBudget::new(16.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::Adaptive);

        let idx0 = paged.register(512);
        let idx1 = paged.register(1024);

        assert_eq!(idx0, 0);
        assert_eq!(idx1, 1);
        assert_eq!(paged.num_states(), 2);
    }

    #[test]
    fn test_ent_lora_010_paged_optim_get_state() {
        let budget = VramBudget::new(16.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::FullyPaged);
        paged.register(256);

        let state = paged.get_state_mut(0);
        assert!(state.m.is_some()); // Lazily initialized
        assert_eq!(state.m.as_ref().unwrap().len(), 256);
    }

    #[test]
    fn test_ent_lora_010_paged_optim_stats() {
        let budget = VramBudget::new(16.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::FullyPaged);
        paged.register(1024);
        let _ = paged.get_state_mut(0); // Triggers page-in

        let stats = paged.stats();
        assert_eq!(stats.num_states, 1);
        assert!(stats.total_cpu_bytes > 0);
        assert!(stats.page_in_count > 0);
        assert!(stats.summary().contains("Paged optimizer"));
    }

    #[test]
    fn test_ent_lora_010_all_fit_on_gpu() {
        let budget = VramBudget::new(16.0).with_reserved(0.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::Adaptive);
        // 1M params × 8 bytes = 8MB — easily fits in 16GB
        paged.register(1_000_000);
        assert!(paged.all_fit_on_gpu());
    }

    #[test]
    fn test_ent_lora_010_does_not_fit_on_gpu() {
        let budget = VramBudget::new(0.001); // 1MB VRAM
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::Adaptive);
        // 100M params × 8 bytes = 800MB — doesn't fit in 1MB
        paged.register(100_000_000);
        assert!(!paged.all_fit_on_gpu());
    }

    #[test]
    fn test_ent_lora_010_strategy_none() {
        let budget = VramBudget::new(16.0);
        let mut paged = PagedOptimStates::new(budget, PagingStrategy::None);
        paged.register(512);
        let _ = paged.get_state_mut(0);

        let stats = paged.stats();
        assert_eq!(stats.page_in_count, 0); // None strategy doesn't track pages
    }

    #[test]
    fn test_ent_lora_010_vram_budget_target_clamping() {
        let budget = VramBudget::new(16.0).with_target(0.1);
        assert!(budget.target_utilization >= 0.5);

        let budget = VramBudget::new(16.0).with_target(1.5);
        assert!(budget.target_utilization <= 0.95);
    }

    proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(50))]

        #[test]
        fn prop_paged_state_bytes_consistent(len in 1usize..10000) {
            let mut state = PagedState::new(len);
            prop_assert_eq!(state.cpu_bytes(), 0);

            state.ensure_initialized();
            prop_assert_eq!(state.cpu_bytes(), len * 8);
            prop_assert_eq!(state.gpu_bytes(), len * 8);
        }

        #[test]
        fn prop_budget_available_nonnegative(
            total_gb in 1.0f64..100.0,
            reserved_gb in 0.0f64..50.0,
        ) {
            let budget = VramBudget::new(total_gb).with_reserved(reserved_gb);
            // available_bytes uses saturating_sub, so always >= 0
            let _ = budget.available_bytes(); // Just verify no panic
        }
    }
}