candle_cuda_vmm/
virtual_memory.rs

1//! Virtual memory pool for elastic memory allocation.
2
3use crate::cuda_ffi::{self, AccessFlags};
4use crate::error::{Result, VmmError};
5use crate::mapping::{map_memory, set_memory_access, unmap_memory, VirtualAddressRange};
6use crate::physical_memory::PhysicalMemoryHandle;
7use candle_core::{Device, DeviceLocation};
8use std::collections::HashMap;
9
10/// Helper function to extract device ordinal from Candle Device
11fn get_device_ordinal(device: &Device) -> Result<i32> {
12    match device.location() {
13        DeviceLocation::Cuda { gpu_id } => Ok(gpu_id as i32),
14        _ => Err(VmmError::other("Device must be a CUDA device")),
15    }
16}
17
/// Tracks whether a single page of the pool currently has physical backing.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum PageState {
    /// No physical memory is attached to this page.
    Free,
    /// Physical memory has been allocated and mapped into this page's virtual range.
    Allocated,
}
26
/// Elastic memory pool with virtual memory backing.
///
/// This pool reserves a large virtual address space up front but only allocates
/// physical memory on demand, page by page, when `allocate()` is called. This enables:
/// - Large virtual capacity (e.g., 128GB) with minimal initial physical usage
/// - Dynamic allocation/deallocation based on workload
/// - Reduced memory waste for bursty workloads
///
/// Byte ranges passed to `allocate()` / `deallocate()` are widened to whole pages:
/// every page the range touches is mapped or unmapped.
///
/// # Example
/// ```no_run
/// use candle_cuda_vmm::VirtualMemoryPool;
/// use candle_core::Device;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let device = Device::new_cuda(0)?;
/// let mut pool = VirtualMemoryPool::new(
///     128 * 1024 * 1024 * 1024, // 128GB virtual capacity
///     2 * 1024 * 1024,          // 2MB page size
///     device,
/// )?;
///
/// // Allocate 1GB of physical memory on-demand
/// let addr = pool.allocate(0, 1024 * 1024 * 1024)?;
/// println!("Physical usage: {} bytes", pool.physical_memory_usage());
///
/// // Deallocate when done
/// pool.deallocate(0, 1024 * 1024 * 1024)?;
/// # Ok(())
/// # }
/// ```
pub struct VirtualMemoryPool {
    /// Virtual address range reservation.
    ///
    /// NOTE(review): struct fields are dropped in declaration order, so this
    /// reservation is dropped before `physical_pages`. No `Drop` impl for this type is
    /// visible here to unmap pages first. Confirm the `mapping` / `physical_memory`
    /// drop impls make that order safe.
    virtual_range: VirtualAddressRange,
    /// Physical memory handle for each allocated page, keyed by page index.
    /// Removing an entry drops (and thereby frees) the handle.
    physical_pages: HashMap<usize, PhysicalMemoryHandle>,
    /// Per-page state, indexed by page number (`offset / page_size`).
    page_states: Vec<PageState>,
    /// Page size in bytes (validated in `new`: power of two, at least 64KB).
    page_size: usize,
    /// Total virtual capacity in bytes (rounded up to a multiple of `page_size`).
    total_capacity: usize,
    /// Bytes currently backed by physical memory (allocated pages * `page_size`).
    mapped_size: usize,
    /// CUDA device ordinal the pool was created for.
    device_ordinal: i32,
}
73
74impl VirtualMemoryPool {
75    /// Create a new virtual memory pool.
76    ///
77    /// # Arguments
78    /// * `capacity` - Maximum virtual address space (e.g., 128GB).
79    /// * `page_size` - Page granularity (e.g., 2MB for large pages).
80    /// * `device` - CUDA device.
81    ///
82    /// # Returns
83    /// Pool with reserved virtual address space, no physical memory allocated.
84    ///
85    /// # Errors
86    /// Returns error if:
87    /// - Device is not a CUDA device
88    /// - Page size is invalid (not power of 2 or < 64KB)
89    /// - Virtual address reservation fails
90    pub fn new(capacity: usize, page_size: usize, device: Device) -> Result<Self> {
91        // Validate device
92        let device_ordinal = get_device_ordinal(&device)?;
93
94        // Validate page size
95        if !page_size.is_power_of_two() || page_size < 64 * 1024 {
96            return Err(VmmError::InvalidPageSize(page_size));
97        }
98
99        // Ensure capacity is multiple of page size
100        let capacity = (capacity + page_size - 1) / page_size * page_size;
101
102        // Reserve virtual address space
103        let virtual_range = VirtualAddressRange::new(capacity, page_size)?;
104
105        // Calculate number of pages
106        let num_pages = capacity / page_size;
107
108        Ok(Self {
109            virtual_range,
110            physical_pages: HashMap::new(),
111            page_states: vec![PageState::Free; num_pages],
112            page_size,
113            total_capacity: capacity,
114            mapped_size: 0,
115            device_ordinal,
116        })
117    }
118
119    /// Allocate and map physical pages on-demand.
120    ///
121    /// # Arguments
122    /// * `offset` - Offset in virtual address space (bytes).
123    /// * `size` - Number of bytes to allocate.
124    ///
125    /// # Returns
126    /// Base virtual address of allocated region.
127    ///
128    /// # Errors
129    /// Returns error if:
130    /// - Offset/size out of bounds
131    /// - Region already allocated
132    /// - Physical memory allocation fails
133    pub fn allocate(&mut self, offset: usize, size: usize) -> Result<usize> {
134        // Validate parameters
135        if offset + size > self.total_capacity {
136            return Err(VmmError::InvalidOffset {
137                offset,
138                size,
139                capacity: self.total_capacity,
140            });
141        }
142
143        // Align offset and size to page boundaries
144        let start_page = offset / self.page_size;
145        let end_page = (offset + size + self.page_size - 1) / self.page_size;
146
147        // Check if any pages are already allocated
148        for page_idx in start_page..end_page {
149            if self.page_states[page_idx] == PageState::Allocated {
150                return Err(VmmError::AlreadyMapped {
151                    offset: page_idx * self.page_size,
152                    size: self.page_size,
153                });
154            }
155        }
156
157        // Allocate and map each page
158        for page_idx in start_page..end_page {
159            // Allocate physical memory for this page
160            let device = Device::new_cuda(self.device_ordinal as usize)?;
161            let physical_handle = PhysicalMemoryHandle::new(self.page_size, &device)?;
162
163            // Map physical memory to virtual address
164            let page_offset = page_idx * self.page_size;
165            map_memory(
166                &self.virtual_range,
167                page_offset,
168                &physical_handle,
169                0,
170                self.page_size,
171            )?;
172
173            // Set memory access permissions
174            set_memory_access(
175                &self.virtual_range,
176                page_offset,
177                self.page_size,
178                self.device_ordinal,
179                AccessFlags::ReadWrite,
180            )?;
181
182            // Store physical handle and update state
183            self.physical_pages.insert(page_idx, physical_handle);
184            self.page_states[page_idx] = PageState::Allocated;
185            self.mapped_size += self.page_size;
186        }
187
188        Ok(self.virtual_range.base_address() + offset)
189    }
190
191    /// Unmap and free physical pages.
192    ///
193    /// # Arguments
194    /// * `offset` - Offset in virtual address space (bytes).
195    /// * `size` - Number of bytes to free.
196    ///
197    /// # Errors
198    /// Returns error if:
199    /// - Offset/size out of bounds
200    /// - Region not allocated
201    pub fn deallocate(&mut self, offset: usize, size: usize) -> Result<()> {
202        // Validate parameters
203        if offset + size > self.total_capacity {
204            return Err(VmmError::InvalidOffset {
205                offset,
206                size,
207                capacity: self.total_capacity,
208            });
209        }
210
211        // Align offset and size to page boundaries
212        let start_page = offset / self.page_size;
213        let end_page = (offset + size + self.page_size - 1) / self.page_size;
214
215        // Check if all pages are allocated
216        for page_idx in start_page..end_page {
217            if self.page_states[page_idx] == PageState::Free {
218                return Err(VmmError::NotMapped {
219                    offset: page_idx * self.page_size,
220                    size: self.page_size,
221                });
222            }
223        }
224
225        // Unmap and free each page
226        for page_idx in start_page..end_page {
227            let page_offset = page_idx * self.page_size;
228
229            // Unmap virtual memory
230            unmap_memory(&self.virtual_range, page_offset, self.page_size)?;
231
232            // Remove physical handle (automatically freed via Drop)
233            self.physical_pages.remove(&page_idx);
234            self.page_states[page_idx] = PageState::Free;
235            self.mapped_size -= self.page_size;
236        }
237
238        Ok(())
239    }
240
241    /// Get current physical memory usage in bytes.
242    pub fn physical_memory_usage(&self) -> usize {
243        self.mapped_size
244    }
245
246    /// Get virtual address space capacity in bytes.
247    pub fn capacity(&self) -> usize {
248        self.total_capacity
249    }
250
251    /// Get base virtual address.
252    pub fn base_address(&self) -> usize {
253        self.virtual_range.base_address()
254    }
255
256    /// Get page size in bytes.
257    pub fn page_size(&self) -> usize {
258        self.page_size
259    }
260
261    /// Check if a range is currently mapped.
262    ///
263    /// # Arguments
264    /// * `offset` - Offset in virtual address space (bytes).
265    /// * `size` - Size to check (bytes).
266    ///
267    /// # Returns
268    /// True if entire range is mapped, false otherwise.
269    pub fn is_mapped(&self, offset: usize, size: usize) -> bool {
270        if offset + size > self.total_capacity {
271            return false;
272        }
273
274        let start_page = offset / self.page_size;
275        let end_page = (offset + size + self.page_size - 1) / self.page_size;
276
277        for page_idx in start_page..end_page {
278            if self.page_states[page_idx] != PageState::Allocated {
279                return false;
280            }
281        }
282
283        true
284    }
285
286    /// Compact pool by coalescing free pages (no-op for now, future optimization).
287    pub fn compact(&mut self) -> Result<()> {
288        // Future: Implement compaction to reduce fragmentation
289        Ok(())
290    }
291
292    /// Get memory statistics.
293    pub fn stats(&self) -> MemoryStats {
294        let allocated_pages = self
295            .page_states
296            .iter()
297            .filter(|&&state| state == PageState::Allocated)
298            .count();
299
300        let total_pages = self.page_states.len();
301        let fragmentation_ratio = if total_pages > 0 {
302            1.0 - (allocated_pages as f32 / total_pages as f32)
303        } else {
304            0.0
305        };
306
307        MemoryStats {
308            virtual_capacity: self.total_capacity,
309            physical_usage: self.mapped_size,
310            mapped_pages: allocated_pages,
311            fragmentation_ratio,
312        }
313    }
314}
315
/// Memory statistics for a single [`VirtualMemoryPool`], as returned by its
/// `stats()` method.
#[derive(Debug, Clone, PartialEq)]
pub struct MemoryStats {
    /// Virtual address space capacity in bytes.
    pub virtual_capacity: usize,
    /// Physical memory currently mapped, in bytes.
    pub physical_usage: usize,
    /// Number of pages currently backed by physical memory.
    pub mapped_pages: usize,
    /// Fraction of the pool's pages that are *not* mapped
    /// (`1.0 - mapped_pages / total_pages`; 0.0 for a pool with no pages).
    ///
    /// Despite the name, this is an occupancy measure: an empty pool reports 1.0 and
    /// a full pool 0.0. It does not describe how scattered the free pages are.
    pub fragmentation_ratio: f32,
}
328
329/// Shared memory pool for multiple models.
330///
331/// Manages multiple virtual memory pools with a global physical memory limit.
332/// Enables memory sharing across models with per-model statistics.
333///
334/// # Example
335/// ```no_run
336/// use candle_cuda_vmm::SharedMemoryPool;
337/// use candle_core::Device;
338///
339/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
340/// let device = Device::new_cuda(0)?;
341/// let mut shared_pool = SharedMemoryPool::new(
342///     32 * 1024 * 1024 * 1024, // 32GB global physical limit
343///     device,
344/// )?;
345///
346/// // Register models
347/// shared_pool.register_model("llama-7b", 64 * 1024 * 1024 * 1024)?; // 64GB virtual
348/// shared_pool.register_model("gpt2", 32 * 1024 * 1024 * 1024)?;     // 32GB virtual
349///
350/// // Allocate for specific model
351/// let addr = shared_pool.allocate_for_model("llama-7b", 1024 * 1024 * 1024)?;
352/// # Ok(())
353/// # }
354/// ```
355pub struct SharedMemoryPool {
356    /// Per-model virtual memory pools.
357    pools: HashMap<String, VirtualMemoryPool>,
358    /// Global physical memory limit in bytes.
359    global_physical_limit: usize,
360    /// Current global physical usage in bytes.
361    current_physical_usage: usize,
362    /// Device ordinal.
363    device_ordinal: i32,
364    /// Default page size for new pools.
365    default_page_size: usize,
366}
367
368impl SharedMemoryPool {
369    /// Create shared pool with global physical memory limit.
370    ///
371    /// # Arguments
372    /// * `physical_limit` - Global physical memory limit (bytes).
373    /// * `device` - CUDA device.
374    ///
375    /// # Returns
376    /// Shared memory pool.
377    pub fn new(physical_limit: usize, device: Device) -> Result<Self> {
378        let device_ordinal = get_device_ordinal(&device)?;
379
380        // Get recommended page size
381        let default_page_size = cuda_ffi::get_recommended_granularity(device_ordinal)?;
382
383        Ok(Self {
384            pools: HashMap::new(),
385            global_physical_limit: physical_limit,
386            current_physical_usage: 0,
387            device_ordinal,
388            default_page_size,
389        })
390    }
391
392    /// Register a model with virtual address space reservation.
393    ///
394    /// # Arguments
395    /// * `model_id` - Unique model identifier.
396    /// * `virtual_capacity` - Virtual address space for this model (bytes).
397    ///
398    /// # Errors
399    /// Returns error if model already registered.
400    pub fn register_model(&mut self, model_id: &str, virtual_capacity: usize) -> Result<()> {
401        if self.pools.contains_key(model_id) {
402            return Err(VmmError::ModelAlreadyExists(model_id.to_string()));
403        }
404
405        let device = Device::new_cuda(self.device_ordinal as usize)?;
406        let pool = VirtualMemoryPool::new(virtual_capacity, self.default_page_size, device)?;
407
408        self.pools.insert(model_id.to_string(), pool);
409        Ok(())
410    }
411
412    /// Allocate from specific model's pool.
413    ///
414    /// # Arguments
415    /// * `model_id` - Model identifier.
416    /// * `size` - Size to allocate (bytes).
417    ///
418    /// # Returns
419    /// Virtual address of allocated region.
420    ///
421    /// # Errors
422    /// Returns error if:
423    /// - Model not found
424    /// - Global physical limit exceeded
425    /// - Allocation fails
426    pub fn allocate_for_model(&mut self, model_id: &str, size: usize) -> Result<usize> {
427        let pool = self
428            .pools
429            .get_mut(model_id)
430            .ok_or_else(|| VmmError::ModelNotFound(model_id.to_string()))?;
431
432        // Check global physical limit
433        let rounded_size =
434            (size + self.default_page_size - 1) / self.default_page_size * self.default_page_size;
435        if self.current_physical_usage + rounded_size > self.global_physical_limit {
436            return Err(VmmError::OutOfPhysicalMemory {
437                requested: rounded_size,
438                available: self.global_physical_limit - self.current_physical_usage,
439            });
440        }
441
442        // Allocate from model's pool
443        let addr = pool.allocate(0, size)?;
444        self.current_physical_usage += rounded_size;
445
446        Ok(addr)
447    }
448
449    /// Free from specific model's pool.
450    ///
451    /// # Arguments
452    /// * `model_id` - Model identifier.
453    /// * `offset` - Offset in model's virtual address space (bytes).
454    /// * `size` - Size to free (bytes).
455    pub fn deallocate_for_model(
456        &mut self,
457        model_id: &str,
458        offset: usize,
459        size: usize,
460    ) -> Result<()> {
461        let pool = self
462            .pools
463            .get_mut(model_id)
464            .ok_or_else(|| VmmError::ModelNotFound(model_id.to_string()))?;
465
466        let rounded_size =
467            (size + self.default_page_size - 1) / self.default_page_size * self.default_page_size;
468
469        pool.deallocate(offset, size)?;
470        self.current_physical_usage = self.current_physical_usage.saturating_sub(rounded_size);
471
472        Ok(())
473    }
474
475    /// Get per-model memory statistics.
476    pub fn get_model_stats(&self, model_id: &str) -> Option<MemoryStats> {
477        self.pools.get(model_id).map(|pool| pool.stats())
478    }
479
480    /// Global memory statistics.
481    pub fn global_stats(&self) -> GlobalMemoryStats {
482        GlobalMemoryStats {
483            physical_limit: self.global_physical_limit,
484            physical_usage: self.current_physical_usage,
485            num_models: self.pools.len(),
486        }
487    }
488
489    /// Unregister a model and free its resources.
490    pub fn unregister_model(&mut self, model_id: &str) -> Result<()> {
491        if let Some(pool) = self.pools.remove(model_id) {
492            let usage = pool.physical_memory_usage();
493            self.current_physical_usage = self.current_physical_usage.saturating_sub(usage);
494            Ok(())
495        } else {
496            Err(VmmError::ModelNotFound(model_id.to_string()))
497        }
498    }
499}
500
/// Snapshot of a shared pool's global physical-memory accounting.
#[derive(Clone, Debug)]
pub struct GlobalMemoryStats {
    /// Configured cap on physical memory across all models, in bytes.
    pub physical_limit: usize,
    /// Physical memory currently charged against that cap, in bytes.
    pub physical_usage: usize,
    /// How many models are currently registered.
    pub num_models: usize,
}
511
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_stats() {
        let stats = MemoryStats {
            virtual_capacity: 1024 * 1024,
            physical_usage: 512 * 1024,
            mapped_pages: 256,
            fragmentation_ratio: 0.5,
        };

        assert_eq!(stats.virtual_capacity, 1024 * 1024);
        assert_eq!(stats.physical_usage, 512 * 1024);
        assert_eq!(stats.mapped_pages, 256);
        assert_eq!(stats.fragmentation_ratio, 0.5);
    }

    #[test]
    fn test_global_memory_stats() {
        // Values stay below 4 GiB so the constants also fit a 32-bit `usize`.
        let stats = GlobalMemoryStats {
            physical_limit: 2 * 1024 * 1024 * 1024,
            physical_usage: 1024 * 1024 * 1024,
            num_models: 3,
        };

        assert_eq!(stats.physical_limit, 2 * 1024 * 1024 * 1024);
        assert_eq!(stats.physical_usage, 1024 * 1024 * 1024);
        assert_eq!(stats.num_models, 3);
    }

    #[test]
    fn test_page_state_is_copy_and_comparable() {
        let state = PageState::Allocated;
        let copied = state;
        assert_eq!(state, copied);
        assert_ne!(PageState::Free, PageState::Allocated);
    }

    #[test]
    fn test_non_cuda_device_is_rejected() {
        // Device validation happens before any CUDA call, so no GPU is needed here.
        assert!(get_device_ordinal(&Device::Cpu).is_err());
        assert!(VirtualMemoryPool::new(64 * 1024 * 1024, 2 * 1024 * 1024, Device::Cpu).is_err());
        assert!(SharedMemoryPool::new(1024 * 1024 * 1024, Device::Cpu).is_err());
    }
}