use crate::cuda_ffi::{self, AccessFlags};
use crate::error::{Result, VmmError};
use crate::mapping::{map_memory, set_memory_access, unmap_memory, VirtualAddressRange};
use crate::physical_memory::PhysicalMemoryHandle;
use candle_core::{Device, DeviceLocation};
use std::collections::HashMap;
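/// Extracts the CUDA ordinal from a candle `Device`, rejecting non-CUDA devices.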
fn get_device_ordinal(device: &Device) -> Result<i32> {
match device.location() {
DeviceLocation::Cuda { gpu_id } => Ok(gpu_id as i32),
_ => Err(VmmError::other("Device must be a CUDA device")),
}
}
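/// Mapping state of a single page inside the pool's virtual address range.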
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PageState {
Free,
Allocated,
}
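/// A growable CUDA allocation: a contiguous virtual address range is reserved
/// up front and physical pages are mapped into it on demand, page by page.
///
/// A minimal usage sketch (assumes a CUDA device is available; the sizes are
/// illustrative, and the snippet is marked `ignore` because the crate path is
/// not shown here):
///
/// ```ignore
/// use candle_core::Device;
///
/// let device = Device::new_cuda(0).unwrap();
/// let page_size = 2 * 1024 * 1024; // power of two, at least 64 KiB
/// let mut pool = VirtualMemoryPool::new(64 * 1024 * 1024, page_size, device).unwrap();
/// let addr = pool.allocate(0, 4 * 1024 * 1024).unwrap(); // maps two 2 MiB pages
/// assert!(pool.is_mapped(0, 4 * 1024 * 1024));
/// pool.deallocate(0, 4 * 1024 * 1024).unwrap();          // unmaps and frees them
/// ```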
pub struct VirtualMemoryPool {
/// Reserved virtual address range the pool maps pages into.
virtual_range: VirtualAddressRange,
/// Physical allocations currently mapped, keyed by page index.
physical_pages: HashMap<usize, PhysicalMemoryHandle>,
/// Per-page mapping state, indexed by page number.
page_states: Vec<PageState>,
/// Size of one page in bytes (a power of two, at least 64 KiB).
page_size: usize,
/// Reserved virtual capacity in bytes (a multiple of `page_size`).
total_capacity: usize,
/// Bytes currently backed by physical memory.
mapped_size: usize,
/// CUDA device ordinal the pool's memory lives on.
device_ordinal: i32,
}
impl VirtualMemoryPool {
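/// Reserves a virtual address range of `capacity` bytes (rounded up to a whole
/// number of pages) on the given CUDA device. `page_size` must be a power of
/// two and at least 64 KiB; no physical memory is mapped yet.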
pub fn new(capacity: usize, page_size: usize, device: Device) -> Result<Self> {
let device_ordinal = get_device_ordinal(&device)?;
// Mapping happens in whole pages, so reject page sizes that are not a power
// of two or are smaller than 64 KiB.
if !page_size.is_power_of_two() || page_size < 64 * 1024 {
return Err(VmmError::InvalidPageSize(page_size));
}
// Round the requested capacity up to a whole number of pages.
let capacity = (capacity + page_size - 1) / page_size * page_size;
let virtual_range = VirtualAddressRange::new(capacity, page_size)?;
let num_pages = capacity / page_size;
Ok(Self {
virtual_range,
physical_pages: HashMap::new(),
page_states: vec![PageState::Free; num_pages],
page_size,
total_capacity: capacity,
mapped_size: 0,
device_ordinal,
})
}
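/// Backs the byte range `[offset, offset + size)` with physical memory and
/// returns the device address of `offset`. Every page overlapping the range
/// must currently be free, otherwise `AlreadyMapped` is returned. If a mapping
/// step fails partway through, pages mapped earlier in the same call stay mapped.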
pub fn allocate(&mut self, offset: usize, size: usize) -> Result<usize> {
if offset + size > self.total_capacity {
return Err(VmmError::InvalidOffset {
offset,
size,
capacity: self.total_capacity,
});
}
// Translate the byte range into the half-open page range [start_page, end_page).
let start_page = offset / self.page_size;
let end_page = (offset + size + self.page_size - 1) / self.page_size;
// Fail before touching the driver if any page in the range is already backed.
for page_idx in start_page..end_page {
if self.page_states[page_idx] == PageState::Allocated {
return Err(VmmError::AlreadyMapped {
offset: page_idx * self.page_size,
size: self.page_size,
});
}
}
// Create the device handle once, then back each page with freshly allocated
// physical memory, map it at the matching virtual offset, and enable access.
let device = Device::new_cuda(self.device_ordinal as usize)?;
for page_idx in start_page..end_page {
let physical_handle = PhysicalMemoryHandle::new(self.page_size, &device)?;
let page_offset = page_idx * self.page_size;
map_memory(
&self.virtual_range,
page_offset,
&physical_handle,
0,
self.page_size,
)?;
set_memory_access(
&self.virtual_range,
page_offset,
self.page_size,
self.device_ordinal,
AccessFlags::ReadWrite,
)?;
self.physical_pages.insert(page_idx, physical_handle);
self.page_states[page_idx] = PageState::Allocated;
self.mapped_size += self.page_size;
}
Ok(self.virtual_range.base_address() + offset)
}
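/// Unmaps the byte range `[offset, offset + size)` and releases its physical
/// pages. Every page overlapping the range must currently be mapped, otherwise
/// `NotMapped` is returned and nothing is unmapped.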
pub fn deallocate(&mut self, offset: usize, size: usize) -> Result<()> {
if offset + size > self.total_capacity {
return Err(VmmError::InvalidOffset {
offset,
size,
capacity: self.total_capacity,
});
}
let start_page = offset / self.page_size;
let end_page = (offset + size + self.page_size - 1) / self.page_size;
// Verify every page in the range is currently mapped before unmapping anything.
for page_idx in start_page..end_page {
if self.page_states[page_idx] == PageState::Free {
return Err(VmmError::NotMapped {
offset: page_idx * self.page_size,
size: self.page_size,
});
}
}
// Unmap each page and drop its physical handle, releasing the memory.
for page_idx in start_page..end_page {
let page_offset = page_idx * self.page_size;
unmap_memory(&self.virtual_range, page_offset, self.page_size)?;
self.physical_pages.remove(&page_idx);
self.page_states[page_idx] = PageState::Free;
self.mapped_size -= self.page_size;
}
Ok(())
}
/// Bytes of physical memory currently mapped into the pool.
pub fn physical_memory_usage(&self) -> usize {
self.mapped_size
}
/// Total reserved virtual capacity of the pool, in bytes.
pub fn capacity(&self) -> usize {
self.total_capacity
}
/// Base device address of the reserved virtual range.
pub fn base_address(&self) -> usize {
self.virtual_range.base_address()
}
/// Page size used for mapping, in bytes.
pub fn page_size(&self) -> usize {
self.page_size
}
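/// Returns `true` only if every page overlapping `[offset, offset + size)` is
/// currently backed by physical memory.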
pub fn is_mapped(&self, offset: usize, size: usize) -> bool {
if offset + size > self.total_capacity {
return false;
}
let start_page = offset / self.page_size;
let end_page = (offset + size + self.page_size - 1) / self.page_size;
for page_idx in start_page..end_page {
if self.page_states[page_idx] != PageState::Allocated {
return false;
}
}
true
}
/// Placeholder for a future page-compaction pass; currently a no-op.
pub fn compact(&mut self) -> Result<()> {
Ok(())
}
pub fn stats(&self) -> MemoryStats {
let allocated_pages = self
.page_states
.iter()
.filter(|&&state| state == PageState::Allocated)
.count();
let total_pages = self.page_states.len();
// Not a true fragmentation metric: this is simply the fraction of pages that
// have no physical backing (0.0 = fully mapped, 1.0 = nothing mapped).
let fragmentation_ratio = if total_pages > 0 {
1.0 - (allocated_pages as f32 / total_pages as f32)
} else {
0.0
};
MemoryStats {
virtual_capacity: self.total_capacity,
physical_usage: self.mapped_size,
mapped_pages: allocated_pages,
fragmentation_ratio,
}
}
}
/// Point-in-time statistics for a single [`VirtualMemoryPool`].
#[derive(Debug, Clone)]
pub struct MemoryStats {
/// Reserved virtual capacity, in bytes.
pub virtual_capacity: usize,
/// Physical memory currently mapped, in bytes.
pub physical_usage: usize,
/// Number of pages currently backed by physical memory.
pub mapped_pages: usize,
/// Fraction of pages with no physical backing (0.0 = fully mapped).
pub fragmentation_ratio: f32,
}
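/// A registry of per-model [`VirtualMemoryPool`]s that share one physical
/// memory budget on a single CUDA device. Each model gets its own virtual
/// range; physical usage is tracked globally and allocations are rejected once
/// the configured limit would be exceeded.
///
/// A minimal usage sketch (assumes a CUDA device; the model id and sizes are
/// illustrative):
///
/// ```ignore
/// use candle_core::Device;
///
/// let device = Device::new_cuda(0).unwrap();
/// let mut shared = SharedMemoryPool::new(1 << 30, device).unwrap(); // 1 GiB budget
/// shared.register_model("model-a", 256 << 20).unwrap();             // 256 MiB virtual
/// let addr = shared.allocate_for_model("model-a", 16 << 20).unwrap();
/// assert!(shared.global_stats().physical_usage <= shared.global_stats().physical_limit);
/// shared.unregister_model("model-a").unwrap();
/// ```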
pub struct SharedMemoryPool {
/// One pool per registered model, keyed by model id.
pools: HashMap<String, VirtualMemoryPool>,
/// Upper bound on combined physical usage across all pools, in bytes.
global_physical_limit: usize,
/// Physical memory currently attributed to registered models, in bytes.
current_physical_usage: usize,
device_ordinal: i32,
/// Mapping granularity recommended by the driver for this device.
default_page_size: usize,
}
impl SharedMemoryPool {
pub fn new(physical_limit: usize, device: Device) -> Result<Self> {
let device_ordinal = get_device_ordinal(&device)?;
let default_page_size = cuda_ffi::get_recommended_granularity(device_ordinal)?;
Ok(Self {
pools: HashMap::new(),
global_physical_limit: physical_limit,
current_physical_usage: 0,
device_ordinal,
default_page_size,
})
}
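/// Reserves a dedicated virtual address range of `virtual_capacity` bytes for
/// `model_id`, using the device's recommended granularity as the page size.
/// Fails if the model is already registered.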
pub fn register_model(&mut self, model_id: &str, virtual_capacity: usize) -> Result<()> {
if self.pools.contains_key(model_id) {
return Err(VmmError::ModelAlreadyExists(model_id.to_string()));
}
let device = Device::new_cuda(self.device_ordinal as usize)?;
let pool = VirtualMemoryPool::new(virtual_capacity, self.default_page_size, device)?;
self.pools.insert(model_id.to_string(), pool);
Ok(())
}
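/// Maps `size` bytes (rounded up to whole pages) for `model_id`, charging the
/// shared physical budget. Note that the range is always placed at offset 0 of
/// the model's virtual range, so a model holds at most one live allocation at a
/// time; a second call before `deallocate_for_model` fails with `AlreadyMapped`.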
pub fn allocate_for_model(&mut self, model_id: &str, size: usize) -> Result<usize> {
let pool = self
.pools
.get_mut(model_id)
.ok_or_else(|| VmmError::ModelNotFound(model_id.to_string()))?;
// Charge the budget in whole pages, since the pool maps page-granular ranges.
let rounded_size =
(size + self.default_page_size - 1) / self.default_page_size * self.default_page_size;
if self.current_physical_usage + rounded_size > self.global_physical_limit {
return Err(VmmError::OutOfPhysicalMemory {
requested: rounded_size,
available: self.global_physical_limit - self.current_physical_usage,
});
}
// The range is always placed at offset 0 of the model's virtual range; see the
// method documentation for what this implies about repeated allocations.
let addr = pool.allocate(0, size)?;
self.current_physical_usage += rounded_size;
Ok(addr)
}
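/// Unmaps `size` bytes at `offset` in the model's virtual range and credits the
/// shared physical budget with the page-rounded amount.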
pub fn deallocate_for_model(
&mut self,
model_id: &str,
offset: usize,
size: usize,
) -> Result<()> {
let pool = self
.pools
.get_mut(model_id)
.ok_or_else(|| VmmError::ModelNotFound(model_id.to_string()))?;
let rounded_size =
(size + self.default_page_size - 1) / self.default_page_size * self.default_page_size;
pool.deallocate(offset, size)?;
self.current_physical_usage = self.current_physical_usage.saturating_sub(rounded_size);
Ok(())
}
pub fn get_model_stats(&self, model_id: &str) -> Option<MemoryStats> {
self.pools.get(model_id).map(|pool| pool.stats())
}
pub fn global_stats(&self) -> GlobalMemoryStats {
GlobalMemoryStats {
physical_limit: self.global_physical_limit,
physical_usage: self.current_physical_usage,
num_models: self.pools.len(),
}
}
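/// Removes and drops the model's pool and returns its tracked physical usage to
/// the shared budget.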
pub fn unregister_model(&mut self, model_id: &str) -> Result<()> {
if let Some(pool) = self.pools.remove(model_id) {
let usage = pool.physical_memory_usage();
self.current_physical_usage = self.current_physical_usage.saturating_sub(usage);
Ok(())
} else {
Err(VmmError::ModelNotFound(model_id.to_string()))
}
}
}
#[derive(Debug, Clone)]
pub struct GlobalMemoryStats {
/// Configured cap on combined physical usage, in bytes.
pub physical_limit: usize,
/// Physical memory currently attributed to registered models, in bytes.
pub physical_usage: usize,
/// Number of registered models.
pub num_models: usize,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_memory_stats() {
let stats = MemoryStats {
virtual_capacity: 1024 * 1024,
physical_usage: 512 * 1024,
mapped_pages: 256,
fragmentation_ratio: 0.5,
};
assert_eq!(stats.virtual_capacity, 1024 * 1024);
assert_eq!(stats.physical_usage, 512 * 1024);
}
#[test]
fn test_global_memory_stats() {
let stats = GlobalMemoryStats {
physical_limit: 32 * 1024 * 1024 * 1024,
physical_usage: 16 * 1024 * 1024 * 1024,
num_models: 3,
};
assert_eq!(stats.num_models, 3);
}
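// A hedged allocate/deallocate smoke test for `VirtualMemoryPool`. It assumes
// a working CUDA device 0 and a built VMM backend, so it is `#[ignore]`d by
// default; run it explicitly with `cargo test -- --ignored` on a GPU machine.
#[test]
#[ignore]
fn test_pool_allocate_roundtrip() {
let device = Device::new_cuda(0).expect("CUDA device 0 required");
let page_size = 2 * 1024 * 1024; // power of two, >= 64 KiB
let mut pool = VirtualMemoryPool::new(16 * page_size, page_size, device).unwrap();
assert_eq!(pool.physical_memory_usage(), 0);
// Map three pages and check the returned address and bookkeeping.
let addr = pool.allocate(0, 3 * page_size).unwrap();
assert_eq!(addr, pool.base_address());
assert!(pool.is_mapped(0, 3 * page_size));
assert_eq!(pool.physical_memory_usage(), 3 * page_size);
// Unmap them again and confirm the pool is empty.
pool.deallocate(0, 3 * page_size).unwrap();
assert!(!pool.is_mapped(0, page_size));
assert_eq!(pool.physical_memory_usage(), 0);
}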
}