//! VRAM pool: tracked device allocations with LRU eviction, pin/dirty
//! bookkeeping, and a CPU-only stub when the `cuda` feature is disabled.

#[cfg(feature = "cuda")]
use std::collections::HashMap;
#[cfg(feature = "cuda")]
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
#[cfg(feature = "cuda")]
use std::sync::{Arc, RwLock};

#[cfg(feature = "cuda")]
use cudarc::driver::{CudaSlice, CudaStream};

#[cfg(feature = "cuda")]
use crate::gpu::{GpuError, GpuMemoryConfig};

/// Minimal error type used when the `cuda` feature is disabled.
#[cfg(not(feature = "cuda"))]
#[derive(Debug, Clone)]
pub enum GpuError {
    /// GPU support was not compiled in.
    NotAvailable,
}

/// Minimal memory-config stand-in used when the `cuda` feature is disabled.
#[cfg(not(feature = "cuda"))]
#[derive(Clone, Debug, Default)]
pub struct GpuMemoryConfig {
    /// Safe VRAM usage limit in bytes.
    pub safe_limit: usize,
}

/// Tunables for [`VramPool`] capacity and eviction behavior.
#[derive(Clone, Debug)]
pub struct VramPoolConfig {
    /// Fraction of the safe VRAM limit the pool may consume.
    pub max_usage_ratio: f64,
    /// Evict least-recently-used, unpinned allocations when space runs out.
    pub enable_eviction: bool,
    /// Headroom in bytes kept free below the ratio-scaled limit.
    pub min_free_bytes: usize,
    /// Enable asynchronous operations.
    pub enable_async: bool,
}

impl Default for VramPoolConfig {
    fn default() -> Self {
        Self {
            max_usage_ratio: 0.80,
            enable_eviction: true,
            min_free_bytes: 256 * 1024 * 1024,
            enable_async: true,
        }
    }
}

/// Opaque handle identifying one allocation in a [`VramPool`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct VramHandle {
    /// Unique allocation id within the pool.
    pub id: u64,
    /// Size of the allocation in bytes.
    pub size: usize,
}

#[cfg(feature = "cuda")]
impl VramHandle {
    fn new(id: u64, size: usize) -> Self {
        Self { id, size }
    }
}

/// Bookkeeping record for one pooled allocation.
#[cfg(feature = "cuda")]
#[derive(Debug)]
struct VramAllocation {
    /// Handle this record describes.
    handle: VramHandle,
    /// Last access time; drives LRU eviction ordering.
    last_access: std::time::Instant,
    /// Whether the device copy has been modified.
    dirty: bool,
    /// Pinned allocations are never evicted.
    pinned: bool,
}

#[cfg(feature = "cuda")]
impl VramAllocation {
    fn new(handle: VramHandle) -> Self {
        Self {
            handle,
            last_access: std::time::Instant::now(),
            dirty: false,
            pinned: false,
        }
    }

    /// Refreshes the LRU timestamp.
    fn touch(&mut self) {
        self.last_access = std::time::Instant::now();
    }
}

/// Pool of device (VRAM) allocations with LRU eviction.
///
/// Tracks each allocation by id, enforces a byte budget derived from
/// [`GpuMemoryConfig`] and [`VramPoolConfig`], and evicts unpinned
/// buffers in least-recently-used order when space runs out.
#[cfg(feature = "cuda")]
pub struct VramPool {
    /// CUDA stream used for allocations and transfers.
    stream: Arc<CudaStream>,
    /// Pool tuning parameters.
    config: VramPoolConfig,
    /// Device memory limits (source of the safe VRAM limit).
    memory_config: GpuMemoryConfig,
    /// Device buffers keyed by handle id.
    allocations: RwLock<HashMap<u64, CudaSlice<u8>>>,
    /// Per-allocation bookkeeping (LRU timestamps, pin/dirty flags).
    metadata: RwLock<HashMap<u64, VramAllocation>>,
    /// Monotonic id source for new handles.
    next_id: AtomicU64,
    /// Sum of live allocation sizes in bytes.
    allocated_bytes: AtomicUsize,
}

#[cfg(feature = "cuda")]
impl VramPool {
    /// Creates a pool that allocates on `stream` within the given limits.
    pub fn new(
        stream: Arc<CudaStream>,
        memory_config: GpuMemoryConfig,
        config: VramPoolConfig,
    ) -> Self {
        Self {
            stream,
            config,
            memory_config,
            allocations: RwLock::new(HashMap::new()),
            metadata: RwLock::new(HashMap::new()),
            next_id: AtomicU64::new(1),
            allocated_bytes: AtomicUsize::new(0),
        }
    }

    /// Maximum bytes the pool may hold: the safe limit scaled by
    /// `max_usage_ratio`, minus `min_free_bytes` of headroom.
    pub fn max_usable_bytes(&self) -> usize {
        let safe_limit = self.memory_config.safe_limit;
        let from_ratio = (safe_limit as f64 * self.config.max_usage_ratio) as usize;
        from_ratio.saturating_sub(self.config.min_free_bytes)
    }
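
    // Worked example for the budget above (illustrative numbers, not a
    // probed device): with the default config and an assumed safe limit of
    // 8 GiB, the budget is 8 GiB * 0.80 = 6.4 GiB, minus 256 MiB of
    // headroom, i.e. roughly 6.15 GiB of live allocations before eviction
    // kicks in.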

    /// Bytes currently held by live allocations.
    pub fn allocated_bytes(&self) -> usize {
        self.allocated_bytes.load(Ordering::Relaxed)
    }

    /// Bytes still available before the pool reaches its budget.
    pub fn available_bytes(&self) -> usize {
        self.max_usable_bytes()
            .saturating_sub(self.allocated_bytes())
    }

    /// Returns `true` if `size` bytes fit in the remaining budget.
    pub fn can_allocate(&self, size: usize) -> bool {
        size <= self.available_bytes()
    }

    /// Allocates `size` zeroed bytes on the device.
    ///
    /// If the request does not fit and eviction is enabled, unpinned
    /// allocations are evicted in LRU order until it does (or no
    /// evictable allocations remain).
    pub fn allocate(&self, size: usize) -> Result<VramHandle, GpuError> {
        if !self.can_allocate(size) {
            if self.config.enable_eviction {
                self.evict_until_available(size)?;
            }

            if !self.can_allocate(size) {
                return Err(GpuError::MemoryAlloc(format!(
                    "VRAM pool exhausted: need {} bytes, available {} bytes",
                    size,
                    self.available_bytes()
                )));
            }
        }

        let device_buffer: CudaSlice<u8> = self
            .stream
            .alloc_zeros(size)
            .map_err(|e| GpuError::MemoryAlloc(e.to_string()))?;

        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
        let handle = VramHandle::new(id, size);
        let allocation = VramAllocation::new(handle);

        {
            // Take both locks in the same order as `free` to avoid deadlock.
            let mut allocs = self.allocations.write().unwrap();
            let mut meta = self.metadata.write().unwrap();
            allocs.insert(id, device_buffer);
            meta.insert(id, allocation);
        }

        self.allocated_bytes.fetch_add(size, Ordering::SeqCst);

        Ok(handle)
    }

    /// Releases the allocation behind `handle`.
    pub fn free(&self, handle: VramHandle) -> Result<(), GpuError> {
        let mut allocs = self.allocations.write().unwrap();
        let mut meta = self.metadata.write().unwrap();

        if allocs.remove(&handle.id).is_some() {
            meta.remove(&handle.id);
            self.allocated_bytes
                .fetch_sub(handle.size, Ordering::SeqCst);
            Ok(())
        } else {
            Err(GpuError::InvalidValue(format!(
                "VRAM handle {} not found",
                handle.id
            )))
        }
    }

    /// Copies `data` from the host into the device buffer behind `handle`.
    ///
    /// `data.len()` must exactly match the allocation size.
    pub fn upload(&self, handle: &VramHandle, data: &[u8]) -> Result<(), GpuError> {
        if data.len() != handle.size {
            return Err(GpuError::InvalidValue(format!(
                "Data size {} doesn't match allocation size {}",
                data.len(),
                handle.size
            )));
        }

        {
            // Refresh the LRU timestamp before the transfer.
            let mut meta = self.metadata.write().unwrap();
            if let Some(alloc) = meta.get_mut(&handle.id) {
                alloc.touch();
            }
        }

        let mut allocs = self.allocations.write().unwrap();
        let device_buf = allocs.get_mut(&handle.id).ok_or_else(|| {
            GpuError::InvalidValue(format!("VRAM handle {} not found", handle.id))
        })?;

        self.stream
            .memcpy_htod(data, device_buf)
            .map_err(|e| GpuError::MemoryCopy(e.to_string()))?;

        Ok(())
    }

    /// Copies the device buffer behind `handle` back into a host `Vec`.
    pub fn download(&self, handle: &VramHandle) -> Result<Vec<u8>, GpuError> {
        {
            // Refresh the LRU timestamp before the transfer.
            let mut meta = self.metadata.write().unwrap();
            if let Some(alloc) = meta.get_mut(&handle.id) {
                alloc.touch();
            }
        }

        let allocs = self.allocations.read().unwrap();
        let device_buf = allocs.get(&handle.id).ok_or_else(|| {
            GpuError::InvalidValue(format!("VRAM handle {} not found", handle.id))
        })?;

        let data = self
            .stream
            .clone_dtoh(device_buf)
            .map_err(|e| GpuError::MemoryCopy(e.to_string()))?;

        Ok(data)
    }

    /// Pins `handle`, excluding it from eviction.
    pub fn pin(&self, handle: &VramHandle) -> Result<(), GpuError> {
        let mut meta = self.metadata.write().unwrap();
        let alloc = meta.get_mut(&handle.id).ok_or_else(|| {
            GpuError::InvalidValue(format!("VRAM handle {} not found", handle.id))
        })?;
        alloc.pinned = true;
        Ok(())
    }

    /// Unpins `handle`, making it evictable again.
    pub fn unpin(&self, handle: &VramHandle) -> Result<(), GpuError> {
        let mut meta = self.metadata.write().unwrap();
        let alloc = meta.get_mut(&handle.id).ok_or_else(|| {
            GpuError::InvalidValue(format!("VRAM handle {} not found", handle.id))
        })?;
        alloc.pinned = false;
        Ok(())
    }

    /// Marks `handle` as modified on the device.
    pub fn mark_dirty(&self, handle: &VramHandle) -> Result<(), GpuError> {
        let mut meta = self.metadata.write().unwrap();
        let alloc = meta.get_mut(&handle.id).ok_or_else(|| {
            GpuError::InvalidValue(format!("VRAM handle {} not found", handle.id))
        })?;
        alloc.dirty = true;
        Ok(())
    }

    /// Returns whether `handle` has been marked dirty.
    pub fn is_dirty(&self, handle: &VramHandle) -> Result<bool, GpuError> {
        let meta = self.metadata.read().unwrap();
        let alloc = meta.get(&handle.id).ok_or_else(|| {
            GpuError::InvalidValue(format!("VRAM handle {} not found", handle.id))
        })?;
        Ok(alloc.dirty)
    }

    /// Evicts unpinned allocations in LRU order until at least `needed`
    /// bytes are available, or fails if everything left is pinned.
    fn evict_until_available(&self, needed: usize) -> Result<(), GpuError> {
        while self.available_bytes() < needed {
            // Pick the least-recently-used unpinned allocation as the victim.
            let to_evict = {
                let meta = self.metadata.read().unwrap();
                meta.values()
                    .filter(|a| !a.pinned)
                    .min_by_key(|a| a.last_access)
                    .map(|a| a.handle)
            };

            match to_evict {
                Some(handle) => {
                    self.free(handle)?;
                }
                None => {
                    return Err(GpuError::MemoryAlloc(
                        "Cannot evict: all allocations are pinned".to_string(),
                    ));
                }
            }
        }
        Ok(())
    }

    /// Snapshot of pool occupancy and bookkeeping counters.
    pub fn stats(&self) -> VramPoolStats {
        let meta = self.metadata.read().unwrap();
        let num_allocations = meta.len();
        let num_pinned = meta.values().filter(|a| a.pinned).count();
        let num_dirty = meta.values().filter(|a| a.dirty).count();

        VramPoolStats {
            total_capacity: self.max_usable_bytes(),
            allocated_bytes: self.allocated_bytes(),
            available_bytes: self.available_bytes(),
            num_allocations,
            num_pinned,
            num_dirty,
        }
    }
}
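
// Illustrative round-trip under the `cuda` feature (a sketch, not tested
// code: `ctx` stands in for however the caller obtains a cudarc context,
// and `memory_config` for a populated GpuMemoryConfig):
//
//     let stream = ctx.default_stream();
//     let pool = VramPool::new(stream, memory_config, VramPoolConfig::default());
//     let data = vec![0u8; 1024];
//     let handle = pool.allocate(data.len())?;
//     pool.upload(&handle, &data)?;
//     let roundtrip = pool.download(&handle)?;
//     assert_eq!(roundtrip, data);
//     pool.free(handle)?;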

/// Point-in-time snapshot of [`VramPool`] occupancy.
#[derive(Clone, Debug)]
pub struct VramPoolStats {
    /// Maximum bytes the pool may hold.
    pub total_capacity: usize,
    /// Bytes currently allocated.
    pub allocated_bytes: usize,
    /// Bytes still available.
    pub available_bytes: usize,
    /// Number of live allocations.
    pub num_allocations: usize,
    /// Number of pinned allocations.
    pub num_pinned: usize,
    /// Number of allocations marked dirty.
    pub num_dirty: usize,
}

/// Stub pool used when the `cuda` feature is disabled; allocation and
/// transfer calls always fail with [`GpuError::NotAvailable`].
#[cfg(not(feature = "cuda"))]
pub struct VramPool {
    _private: (),
}

#[cfg(not(feature = "cuda"))]
impl VramPool {
    pub fn new(_stream: (), _memory_config: GpuMemoryConfig, _config: VramPoolConfig) -> Self {
        Self { _private: () }
    }

    pub fn allocate(&self, _size: usize) -> Result<VramHandle, GpuError> {
        Err(GpuError::NotAvailable)
    }

    pub fn free(&self, _handle: VramHandle) -> Result<(), GpuError> {
        Err(GpuError::NotAvailable)
    }

    pub fn upload(&self, _handle: &VramHandle, _data: &[u8]) -> Result<(), GpuError> {
        Err(GpuError::NotAvailable)
    }

    pub fn download(&self, _handle: &VramHandle) -> Result<Vec<u8>, GpuError> {
        Err(GpuError::NotAvailable)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vram_handle() {
        let h1 = VramHandle { id: 1, size: 1024 };
        let h2 = VramHandle { id: 2, size: 2048 };

        assert_eq!(h1.id, 1);
        assert_eq!(h1.size, 1024);
        assert_ne!(h1, h2);
    }

    #[test]
    fn test_vram_pool_config_default() {
        let config = VramPoolConfig::default();
        assert!((config.max_usage_ratio - 0.80).abs() < 0.001);
        assert!(config.enable_eviction);
        assert_eq!(config.min_free_bytes, 256 * 1024 * 1024);
    }
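
    #[test]
    fn test_budget_saturates_on_small_limits() {
        // Sketch mirroring the arithmetic in VramPool::max_usable_bytes()
        // with illustrative numbers: when the ratio-scaled limit is smaller
        // than min_free_bytes, the budget saturates to zero instead of
        // underflowing.
        let config = VramPoolConfig::default();
        let safe_limit: usize = 64 * 1024 * 1024; // assumed 64 MiB safe limit
        let from_ratio = (safe_limit as f64 * config.max_usage_ratio) as usize;
        assert_eq!(from_ratio.saturating_sub(config.min_free_bytes), 0);
    }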

    #[cfg(feature = "cuda")]
    #[test]
    fn test_vram_allocation_lru() {
        let handle = VramHandle { id: 1, size: 100 };
        let mut alloc = VramAllocation::new(handle);

        let t1 = alloc.last_access;
        std::thread::sleep(std::time::Duration::from_millis(10));
        alloc.touch();
        let t2 = alloc.last_access;

        assert!(t2 > t1);
    }
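
    #[cfg(feature = "cuda")]
    #[test]
    fn test_lru_victim_is_oldest_unpinned() {
        // Sketch of the victim selection used by evict_until_available():
        // pinned entries are filtered out, then the oldest last_access wins.
        let mut a = VramAllocation::new(VramHandle { id: 1, size: 10 });
        std::thread::sleep(std::time::Duration::from_millis(5));
        let b = VramAllocation::new(VramHandle { id: 2, size: 10 });
        a.pinned = true; // the older entry is pinned, so it is never evicted

        let candidates = [a, b];
        let victim = candidates
            .iter()
            .filter(|x| !x.pinned)
            .min_by_key(|x| x.last_access)
            .map(|x| x.handle);

        assert_eq!(victim, Some(VramHandle { id: 2, size: 10 }));
    }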

    #[test]
    fn test_vram_pool_stats() {
        let stats = VramPoolStats {
            total_capacity: 1024 * 1024 * 1024,
            allocated_bytes: 512 * 1024 * 1024,
            available_bytes: 512 * 1024 * 1024,
            num_allocations: 10,
            num_pinned: 2,
            num_dirty: 1,
        };

        assert_eq!(stats.num_allocations, 10);
        assert_eq!(stats.num_pinned, 2);
    }
}