// embeddenator_vsa/vram_pool.rs

//! GPU VRAM Memory Pool for Persistent Engrams
//!
//! This module provides arena-style GPU memory management for persistent engrams,
//! enabling unified memory access across CPU RAM, GPU VRAM, and storage.
//!
//! # Design Goals
//!
//! 1. **Arena-style allocation**: Pre-allocate VRAM chunks to reduce allocation overhead
//! 2. **Eviction policy**: Automatically evict least-recently-used engrams under VRAM pressure
//! 3. **Safe limits**: Respect GPU memory constraints from `GpuMemoryConfig`
//! 4. **Integration**: Wire into existing `GpuBackend` infrastructure
//!
//! # Example
//!
//! ```rust,ignore
//! use embeddenator_vsa::{VramPool, VramPoolConfig, GpuBackend, GpuConfig};
//!
//! let gpu = GpuBackend::new(GpuConfig::default())?;
//!
//! // Build the pool from a CUDA stream and the GPU memory limits obtained
//! // during backend setup (the constructor itself is infallible):
//! let pool = VramPool::new(stream, memory_config, VramPoolConfig::default());
//!
//! // Allocate space for an engram
//! let handle = pool.allocate(1024 * 1024)?; // 1 MiB
//!
//! // Upload data (length must match the allocation size)
//! pool.upload(&handle, &my_data)?;
//!
//! // Download data
//! let data = pool.download(&handle)?;
//!
//! // Free when done
//! pool.free(handle)?;
//! ```

#[cfg(feature = "cuda")]
use std::collections::HashMap;
#[cfg(feature = "cuda")]
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
#[cfg(feature = "cuda")]
use std::sync::{Arc, RwLock};

#[cfg(feature = "cuda")]
use cudarc::driver::{CudaSlice, CudaStream};

#[cfg(feature = "cuda")]
use crate::gpu::{GpuError, GpuMemoryConfig};

// Stub types for non-CUDA builds
#[cfg(not(feature = "cuda"))]
#[derive(Debug, Clone)]
pub enum GpuError {
    NotAvailable,
}

#[cfg(not(feature = "cuda"))]
#[derive(Clone, Debug, Default)]
pub struct GpuMemoryConfig {
    pub safe_limit: usize,
}

/// Configuration for VRAM pool
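///
/// # Example
///
/// A minimal sketch of overriding the defaults via struct update
/// (the values here are illustrative, not tuned recommendations):
///
/// ```rust,ignore
/// let config = VramPoolConfig {
///     max_usage_ratio: 0.50,  // cap the pool at half of the safe VRAM limit
///     enable_eviction: false, // fail fast instead of evicting under pressure
///     ..VramPoolConfig::default()
/// };
/// ```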
#[derive(Clone, Debug)]
pub struct VramPoolConfig {
    /// Maximum fraction of safe VRAM to use (0.0 - 1.0)
    pub max_usage_ratio: f64,
    /// Enable LRU eviction when pool is full
    pub enable_eviction: bool,
    /// Minimum free space to maintain (bytes)
    pub min_free_bytes: usize,
    /// Enable async transfers
    pub enable_async: bool,
}

impl Default for VramPoolConfig {
    fn default() -> Self {
        Self {
            max_usage_ratio: 0.80, // Use up to 80% of safe VRAM
            enable_eviction: true,
            min_free_bytes: 256 * 1024 * 1024, // 256 MiB minimum free
            enable_async: true,
        }
    }
}

/// Handle to a VRAM allocation
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct VramHandle {
    /// Unique allocation ID
    pub id: u64,
    /// Size in bytes
    pub size: usize,
}

#[cfg(feature = "cuda")]
impl VramHandle {
    /// Create a new handle
    fn new(id: u64, size: usize) -> Self {
        Self { id, size }
    }
}

/// Metadata for a VRAM allocation
#[cfg(feature = "cuda")]
#[derive(Debug)]
struct VramAllocation {
    /// Handle to this allocation
    handle: VramHandle,
    /// Last access time for LRU eviction
    last_access: std::time::Instant,
    /// Whether the data is dirty (modified on device)
    dirty: bool,
    /// Whether the allocation is pinned (cannot be evicted)
    pinned: bool,
}

#[cfg(feature = "cuda")]
impl VramAllocation {
    fn new(handle: VramHandle) -> Self {
        Self {
            handle,
            last_access: std::time::Instant::now(),
            dirty: false,
            pinned: false,
        }
    }

    fn touch(&mut self) {
        self.last_access = std::time::Instant::now();
    }
}

/// VRAM memory pool for persistent GPU allocations
///
/// Provides arena-style memory management with LRU eviction
/// for GPU VRAM, enabling persistent engram storage.
#[cfg(feature = "cuda")]
pub struct VramPool {
    /// CUDA stream for transfers
    stream: Arc<CudaStream>,
    /// Pool configuration
    config: VramPoolConfig,
    /// Memory limits from GPU
    memory_config: GpuMemoryConfig,
    /// Active allocations (handle ID -> device buffer)
    allocations: RwLock<HashMap<u64, CudaSlice<u8>>>,
    /// Allocation metadata for LRU tracking
    metadata: RwLock<HashMap<u64, VramAllocation>>,
    /// Next allocation ID
    next_id: AtomicU64,
    /// Current total allocated bytes
    allocated_bytes: AtomicUsize,
}

#[cfg(feature = "cuda")]
impl VramPool {
    /// Create a new VRAM pool with the given GPU stream
    pub fn new(
        stream: Arc<CudaStream>,
        memory_config: GpuMemoryConfig,
        config: VramPoolConfig,
    ) -> Self {
        Self {
            stream,
            config,
            memory_config,
            allocations: RwLock::new(HashMap::new()),
            metadata: RwLock::new(HashMap::new()),
            next_id: AtomicU64::new(1),
            allocated_bytes: AtomicUsize::new(0),
        }
    }

    /// Get the maximum usable VRAM based on config
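    ///
    /// For example, with a 10 GiB safe limit and the default config
    /// (`max_usage_ratio = 0.80`, `min_free_bytes` = 256 MiB), this yields
    /// 10 GiB * 0.80 - 256 MiB = 7.75 GiB of usable pool capacity.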
    pub fn max_usable_bytes(&self) -> usize {
        let safe_limit = self.memory_config.safe_limit;
        let from_ratio = (safe_limit as f64 * self.config.max_usage_ratio) as usize;
        from_ratio.saturating_sub(self.config.min_free_bytes)
    }

    /// Get current allocated bytes
    pub fn allocated_bytes(&self) -> usize {
        self.allocated_bytes.load(Ordering::Relaxed)
    }

    /// Get available bytes for allocation
    pub fn available_bytes(&self) -> usize {
        self.max_usable_bytes()
            .saturating_sub(self.allocated_bytes())
    }

    /// Check if an allocation of the given size would fit
    pub fn can_allocate(&self, size: usize) -> bool {
        size <= self.available_bytes()
    }

    /// Allocate VRAM for the given size
    ///
    /// Returns a handle to the allocation. If eviction is enabled and
    /// there isn't enough space, LRU allocations will be evicted first.
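    ///
    /// ```rust,ignore
    /// // Minimal sketch: allocate a buffer, then release it.
    /// let handle = pool.allocate(64 * 1024)?; // 64 KiB
    /// assert_eq!(handle.size, 64 * 1024);
    /// pool.free(handle)?;
    /// ```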
    pub fn allocate(&self, size: usize) -> Result<VramHandle, GpuError> {
        // Check if we have space (potentially after eviction)
        if !self.can_allocate(size) {
            if self.config.enable_eviction {
                self.evict_until_available(size)?;
            }

            if !self.can_allocate(size) {
                return Err(GpuError::MemoryAlloc(format!(
                    "VRAM pool exhausted: need {} bytes, available {} bytes",
                    size,
                    self.available_bytes()
                )));
            }
        }

        // Allocate on device
        let device_buffer: CudaSlice<u8> = self
            .stream
            .alloc_zeros(size)
            .map_err(|e| GpuError::MemoryAlloc(e.to_string()))?;

        // Create handle and track allocation
        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
        let handle = VramHandle::new(id, size);
        let allocation = VramAllocation::new(handle);

        // Store allocation
        {
            let mut allocs = self.allocations.write().unwrap();
            let mut meta = self.metadata.write().unwrap();
            allocs.insert(id, device_buffer);
            meta.insert(id, allocation);
        }

        self.allocated_bytes.fetch_add(size, Ordering::SeqCst);

        Ok(handle)
    }

    /// Free a VRAM allocation
    pub fn free(&self, handle: VramHandle) -> Result<(), GpuError> {
        let mut allocs = self.allocations.write().unwrap();
        let mut meta = self.metadata.write().unwrap();

        if allocs.remove(&handle.id).is_some() {
            meta.remove(&handle.id);
            self.allocated_bytes
                .fetch_sub(handle.size, Ordering::SeqCst);
            Ok(())
        } else {
            Err(GpuError::InvalidValue(format!(
                "VRAM handle {} not found",
                handle.id
            )))
        }
    }

    /// Upload data from host to a VRAM allocation
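    ///
    /// The slice length must exactly match the allocation size.
    ///
    /// ```rust,ignore
    /// // Round-trip sketch: the buffer length must equal `handle.size`.
    /// let bytes = vec![0u8; handle.size];
    /// pool.upload(&handle, &bytes)?;
    /// let back = pool.download(&handle)?;
    /// assert_eq!(back, bytes);
    /// ```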
    pub fn upload(&self, handle: &VramHandle, data: &[u8]) -> Result<(), GpuError> {
        if data.len() != handle.size {
            return Err(GpuError::InvalidValue(format!(
                "Data size {} doesn't match allocation size {}",
                data.len(),
                handle.size
            )));
        }

        // Touch metadata for LRU
        {
            let mut meta = self.metadata.write().unwrap();
            if let Some(alloc) = meta.get_mut(&handle.id) {
                alloc.touch();
            }
        }

        // Get the device buffer and copy data
        let mut allocs = self.allocations.write().unwrap();
        let device_buf = allocs.get_mut(&handle.id).ok_or_else(|| {
            GpuError::InvalidValue(format!("VRAM handle {} not found", handle.id))
        })?;

        // Copy host data to device buffer
        self.stream
            .memcpy_htod(data, device_buf)
            .map_err(|e| GpuError::MemoryCopy(e.to_string()))?;

        Ok(())
    }

    /// Download data from a VRAM allocation to host
    pub fn download(&self, handle: &VramHandle) -> Result<Vec<u8>, GpuError> {
        // Touch metadata for LRU
        {
            let mut meta = self.metadata.write().unwrap();
            if let Some(alloc) = meta.get_mut(&handle.id) {
                alloc.touch();
            }
        }

        let allocs = self.allocations.read().unwrap();
        let device_buf = allocs.get(&handle.id).ok_or_else(|| {
            GpuError::InvalidValue(format!("VRAM handle {} not found", handle.id))
        })?;

        let data = self
            .stream
            .clone_dtoh(device_buf)
            .map_err(|e| GpuError::MemoryCopy(e.to_string()))?;

        Ok(data)
    }

    /// Pin an allocation (prevent eviction)
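    ///
    /// ```rust,ignore
    /// // Sketch: pin so LRU eviction skips this buffer; unpin when done.
    /// pool.pin(&handle)?;
    /// // ... launch kernels that read from the buffer ...
    /// pool.unpin(&handle)?;
    /// ```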
    pub fn pin(&self, handle: &VramHandle) -> Result<(), GpuError> {
        let mut meta = self.metadata.write().unwrap();
        let alloc = meta.get_mut(&handle.id).ok_or_else(|| {
            GpuError::InvalidValue(format!("VRAM handle {} not found", handle.id))
        })?;
        alloc.pinned = true;
        Ok(())
    }

    /// Unpin an allocation (allow eviction)
    pub fn unpin(&self, handle: &VramHandle) -> Result<(), GpuError> {
        let mut meta = self.metadata.write().unwrap();
        let alloc = meta.get_mut(&handle.id).ok_or_else(|| {
            GpuError::InvalidValue(format!("VRAM handle {} not found", handle.id))
        })?;
        alloc.pinned = false;
        Ok(())
    }

    /// Mark an allocation as dirty (modified on device)
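    ///
    /// ```rust,ignore
    /// // Sketch: after device-side writes, flag the buffer so callers
    /// // know any host copy is stale.
    /// pool.mark_dirty(&handle)?;
    /// assert!(pool.is_dirty(&handle)?);
    /// ```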
    pub fn mark_dirty(&self, handle: &VramHandle) -> Result<(), GpuError> {
        let mut meta = self.metadata.write().unwrap();
        let alloc = meta.get_mut(&handle.id).ok_or_else(|| {
            GpuError::InvalidValue(format!("VRAM handle {} not found", handle.id))
        })?;
        alloc.dirty = true;
        Ok(())
    }

    /// Check if an allocation is dirty
    pub fn is_dirty(&self, handle: &VramHandle) -> Result<bool, GpuError> {
        let meta = self.metadata.read().unwrap();
        let alloc = meta.get(&handle.id).ok_or_else(|| {
            GpuError::InvalidValue(format!("VRAM handle {} not found", handle.id))
        })?;
        Ok(alloc.dirty)
    }

    /// Evict allocations until we have at least `needed` bytes available
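    ///
    /// Note: eviction frees device buffers without writing dirty data back
    /// to the host; callers that rely on the `dirty` flag must persist such
    /// data before it becomes eligible for eviction.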
    fn evict_until_available(&self, needed: usize) -> Result<(), GpuError> {
        while self.available_bytes() < needed {
            // Find LRU non-pinned allocation
            let to_evict = {
                let meta = self.metadata.read().unwrap();
                meta.values()
                    .filter(|a| !a.pinned)
                    .min_by_key(|a| a.last_access)
                    .map(|a| a.handle)
            };

            match to_evict {
                Some(handle) => {
                    self.free(handle)?;
                }
                None => {
                    // No evictable allocations
                    return Err(GpuError::MemoryAlloc(
                        "Cannot evict: all allocations are pinned".to_string(),
                    ));
                }
            }
        }
        Ok(())
    }

    /// Get statistics about the pool
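    ///
    /// ```rust,ignore
    /// // Sketch: log pool pressure before a large allocation.
    /// let s = pool.stats();
    /// println!(
    ///     "VRAM pool: {}/{} bytes used across {} allocations",
    ///     s.allocated_bytes, s.total_capacity, s.num_allocations
    /// );
    /// ```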
    pub fn stats(&self) -> VramPoolStats {
        let meta = self.metadata.read().unwrap();
        let num_allocations = meta.len();
        let num_pinned = meta.values().filter(|a| a.pinned).count();
        let num_dirty = meta.values().filter(|a| a.dirty).count();

        VramPoolStats {
            total_capacity: self.max_usable_bytes(),
            allocated_bytes: self.allocated_bytes(),
            available_bytes: self.available_bytes(),
            num_allocations,
            num_pinned,
            num_dirty,
        }
    }
}

/// Statistics about VRAM pool usage
#[derive(Clone, Debug)]
pub struct VramPoolStats {
    /// Total usable capacity in bytes
    pub total_capacity: usize,
    /// Currently allocated bytes
    pub allocated_bytes: usize,
    /// Available bytes for new allocations
    pub available_bytes: usize,
    /// Number of active allocations
    pub num_allocations: usize,
    /// Number of pinned allocations
    pub num_pinned: usize,
    /// Number of dirty allocations
    pub num_dirty: usize,
}

// Stub for non-CUDA builds
#[cfg(not(feature = "cuda"))]
pub struct VramPool {
    _private: (),
}

#[cfg(not(feature = "cuda"))]
impl VramPool {
    pub fn new(_stream: (), _memory_config: GpuMemoryConfig, _config: VramPoolConfig) -> Self {
        Self { _private: () }
    }

    pub fn allocate(&self, _size: usize) -> Result<VramHandle, GpuError> {
        Err(GpuError::NotAvailable)
    }

    pub fn free(&self, _handle: VramHandle) -> Result<(), GpuError> {
        Err(GpuError::NotAvailable)
    }

    pub fn upload(&self, _handle: &VramHandle, _data: &[u8]) -> Result<(), GpuError> {
        Err(GpuError::NotAvailable)
    }

    pub fn download(&self, _handle: &VramHandle) -> Result<Vec<u8>, GpuError> {
        Err(GpuError::NotAvailable)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vram_handle() {
        let h1 = VramHandle { id: 1, size: 1024 };
        let h2 = VramHandle { id: 2, size: 2048 };

        assert_eq!(h1.id, 1);
        assert_eq!(h1.size, 1024);
        assert_ne!(h1, h2);
    }

    #[test]
    fn test_vram_pool_config_default() {
        let config = VramPoolConfig::default();
        assert!((config.max_usage_ratio - 0.80).abs() < 0.001);
        assert!(config.enable_eviction);
        assert_eq!(config.min_free_bytes, 256 * 1024 * 1024);
    }
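
    // Exercises the struct-update override pattern sketched in the
    // `VramPoolConfig` docs (uses only fields defined in this module).
    #[test]
    fn test_vram_pool_config_override() {
        let config = VramPoolConfig {
            enable_eviction: false,
            ..VramPoolConfig::default()
        };
        assert!(!config.enable_eviction);
        assert!((config.max_usage_ratio - 0.80).abs() < 0.001);
    }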

    #[cfg(feature = "cuda")]
    #[test]
    fn test_vram_allocation_lru() {
        let handle = VramHandle { id: 1, size: 100 };
        let mut alloc = VramAllocation::new(handle);

        let t1 = alloc.last_access;
        std::thread::sleep(std::time::Duration::from_millis(10));
        alloc.touch();
        let t2 = alloc.last_access;

        assert!(t2 > t1);
    }

    #[test]
    fn test_vram_pool_stats() {
        let stats = VramPoolStats {
            total_capacity: 1024 * 1024 * 1024,
            allocated_bytes: 512 * 1024 * 1024,
            available_bytes: 512 * 1024 * 1024,
            num_allocations: 10,
            num_pinned: 2,
            num_dirty: 1,
        };

        assert_eq!(stats.num_allocations, 10);
        assert_eq!(stats.num_pinned, 2);
    }
}