Skip to main content

oxigdal_gpu/
memory.rs

1//! Advanced GPU memory management for OxiGDAL.
2//!
3//! This module provides sophisticated memory management strategies including
4//! memory pooling, staging buffer management, defragmentation, and VRAM budget tracking.
5
6use crate::context::GpuContext;
7use crate::error::{GpuError, GpuResult};
8use std::collections::{HashMap, VecDeque};
9use std::sync::{Arc, Mutex};
10use tracing::{debug, trace};
11use wgpu::{Buffer, BufferDescriptor, BufferUsages};
12
/// Tuning knobs that control how a [`MemoryPool`] is sized and maintained.
#[derive(Debug, Clone)]
pub struct MemoryPoolConfig {
    /// Size in bytes of the buffer the pool is created with.
    pub initial_size: u64,
    /// Upper bound in bytes that the pool may never grow beyond.
    pub max_size: u64,
    /// Factor applied to the current pool size when the pool has to grow.
    pub growth_factor: f64,
    /// When `true`, the pool may defragment itself before growing.
    pub auto_defrag: bool,
    /// Fragmentation ratio at or above which automatic defragmentation is attempted.
    pub defrag_threshold: f64,
}

impl Default for MemoryPoolConfig {
    /// 64 MiB initial size, 2 GiB cap, 1.5 growth factor, auto-defrag at a 0.3 ratio.
    fn default() -> Self {
        const MIB: u64 = 1024 * 1024;

        Self {
            initial_size: 64 * MIB,
            max_size: 2048 * MIB,
            growth_factor: 1.5,
            auto_defrag: true,
            defrag_threshold: 0.3,
        }
    }
}
39
/// Snapshot of a memory pool's bookkeeping counters.
#[derive(Debug, Clone, Default)]
pub struct MemoryStats {
    /// Size in bytes of the pool's backing storage.
    pub total_allocated: u64,
    /// Bytes currently handed out to live allocations.
    pub bytes_in_use: u64,
    /// Bytes currently free inside the pool.
    pub bytes_available: u64,
    /// Count of live allocations.
    pub num_allocations: usize,
    /// How many times the pool has been grown.
    pub num_expansions: usize,
    /// How many defragmentation passes have run.
    pub num_defrags: usize,
    /// Fragmentation ratio (0.0 = no fragmentation, 1.0 = fully fragmented).
    pub fragmentation_ratio: f64,
}

impl MemoryStats {
    /// Share of the pool in use, as a percentage in `0.0..=100.0`.
    ///
    /// An empty pool (`total_allocated == 0`) reports `0.0` rather than dividing by zero.
    pub fn utilization(&self) -> f64 {
        match self.total_allocated {
            0 => 0.0,
            total => {
                let fraction = self.bytes_in_use as f64 / total as f64;
                fraction * 100.0
            }
        }
    }

    /// `true` when the fragmentation ratio has reached `threshold` (inclusive).
    pub fn needs_defrag(&self, threshold: f64) -> bool {
        threshold <= self.fragmentation_ratio
    }
}
73
/// A contiguous span of the pool's backing buffer that is either free or handed out.
#[derive(Debug)]
struct MemoryBlock {
    /// Byte offset of the span's first byte within the pool buffer.
    offset: u64,
    /// Length of the span in bytes.
    size: u64,
    /// `true` while the span belongs to an outstanding allocation.
    in_use: bool,
    /// Identifier used to match an allocation handle back to this span.
    id: u64,
}

impl MemoryBlock {
    /// Build a *free* block of `size` bytes starting at `offset`.
    fn new(offset: u64, size: u64, id: u64) -> Self {
        Self {
            id,
            offset,
            size,
            in_use: false,
        }
    }

    /// Whether this block is free and at least `size` bytes long.
    fn can_fit(&self, size: u64) -> bool {
        if self.in_use {
            return false;
        }
        size <= self.size
    }
}
101
102/// GPU memory pool for efficient buffer reuse.
103///
104/// This pool manages a large buffer and suballocates from it to avoid
105/// frequent GPU allocations.
106pub struct MemoryPool {
107    context: GpuContext,
108    config: MemoryPoolConfig,
109    buffer: Arc<Buffer>,
110    blocks: Vec<MemoryBlock>,
111    stats: MemoryStats,
112    next_block_id: u64,
113}
114
115impl MemoryPool {
116    /// Create a new memory pool.
117    ///
118    /// # Errors
119    ///
120    /// Returns an error if initial buffer creation fails.
121    pub fn new(context: &GpuContext, config: MemoryPoolConfig) -> GpuResult<Self> {
122        let buffer = Arc::new(context.device().create_buffer(&BufferDescriptor {
123            label: Some("Memory Pool"),
124            size: config.initial_size,
125            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
126            mapped_at_creation: false,
127        }));
128
129        let mut blocks = Vec::new();
130        blocks.push(MemoryBlock::new(0, config.initial_size, 0));
131
132        let stats = MemoryStats {
133            total_allocated: config.initial_size,
134            bytes_available: config.initial_size,
135            ..Default::default()
136        };
137
138        debug!(
139            "Created memory pool: {} MB",
140            config.initial_size / (1024 * 1024)
141        );
142
143        Ok(Self {
144            context: context.clone(),
145            config,
146            buffer,
147            blocks,
148            stats,
149            next_block_id: 1,
150        })
151    }
152
153    /// Allocate a block from the pool.
154    ///
155    /// # Errors
156    ///
157    /// Returns an error if allocation fails or pool is exhausted.
158    pub fn allocate(&mut self, size: u64, alignment: u64) -> GpuResult<MemoryAllocation> {
159        let aligned_size = Self::align_size(size, alignment);
160
161        // Find suitable block using first-fit strategy
162        if let Some(block_idx) = self.find_free_block(aligned_size) {
163            return self.allocate_from_block(block_idx, aligned_size);
164        }
165
166        // Try defragmentation if enabled
167        if self.config.auto_defrag && self.stats.needs_defrag(self.config.defrag_threshold) {
168            self.defragment()?;
169
170            // Try again after defragmentation
171            if let Some(block_idx) = self.find_free_block(aligned_size) {
172                return self.allocate_from_block(block_idx, aligned_size);
173            }
174        }
175
176        // Expand pool if possible
177        if self.stats.total_allocated < self.config.max_size {
178            self.expand_pool(aligned_size)?;
179
180            if let Some(block_idx) = self.find_free_block(aligned_size) {
181                return self.allocate_from_block(block_idx, aligned_size);
182            }
183        }
184
185        Err(GpuError::internal(format!(
186            "Failed to allocate {} bytes from pool",
187            aligned_size
188        )))
189    }
190
191    /// Free a memory allocation.
192    ///
193    /// # Errors
194    ///
195    /// Returns an error if the block ID is invalid.
196    pub fn free(&mut self, allocation: MemoryAllocation) -> GpuResult<()> {
197        let block = self
198            .blocks
199            .iter_mut()
200            .find(|b| b.id == allocation.block_id)
201            .ok_or_else(|| GpuError::invalid_buffer("Invalid block ID"))?;
202
203        if !block.in_use {
204            return Err(GpuError::invalid_buffer("Block already freed"));
205        }
206
207        block.in_use = false;
208        self.stats.bytes_in_use = self.stats.bytes_in_use.saturating_sub(block.size);
209        self.stats.bytes_available += block.size;
210        self.stats.num_allocations = self.stats.num_allocations.saturating_sub(1);
211
212        trace!("Freed {} bytes from pool", block.size);
213
214        // Try to merge adjacent free blocks
215        self.merge_adjacent_blocks();
216
217        Ok(())
218    }
219
220    /// Get current memory statistics.
221    pub fn stats(&self) -> &MemoryStats {
222        &self.stats
223    }
224
225    /// Manually trigger defragmentation.
226    ///
227    /// # Errors
228    ///
229    /// Returns an error if defragmentation fails.
230    pub fn defragment(&mut self) -> GpuResult<()> {
231        debug!("Starting memory pool defragmentation");
232
233        // Sort blocks by offset
234        self.blocks.sort_by_key(|b| b.offset);
235
236        // Merge all adjacent free blocks
237        let mut i = 0;
238        while i < self.blocks.len().saturating_sub(1) {
239            if !self.blocks[i].in_use && !self.blocks[i + 1].in_use {
240                let next_size = self.blocks[i + 1].size;
241                self.blocks[i].size += next_size;
242                self.blocks.remove(i + 1);
243            } else {
244                i += 1;
245            }
246        }
247
248        self.stats.num_defrags += 1;
249        self.update_fragmentation_ratio();
250
251        debug!(
252            "Defragmentation complete: {} blocks remaining",
253            self.blocks.len()
254        );
255
256        Ok(())
257    }
258
259    /// Reset the entire pool.
260    pub fn reset(&mut self) {
261        for block in &mut self.blocks {
262            block.in_use = false;
263        }
264
265        // Merge all blocks into one
266        self.blocks.clear();
267        self.blocks.push(MemoryBlock::new(
268            0,
269            self.stats.total_allocated,
270            self.next_block_id,
271        ));
272        self.next_block_id += 1;
273
274        self.stats.bytes_in_use = 0;
275        self.stats.bytes_available = self.stats.total_allocated;
276        self.stats.num_allocations = 0;
277        self.stats.fragmentation_ratio = 0.0;
278
279        debug!("Memory pool reset");
280    }
281
282    fn align_size(size: u64, alignment: u64) -> u64 {
283        ((size + alignment - 1) / alignment) * alignment
284    }
285
286    fn find_free_block(&self, size: u64) -> Option<usize> {
287        self.blocks.iter().position(|block| block.can_fit(size))
288    }
289
290    fn allocate_from_block(&mut self, block_idx: usize, size: u64) -> GpuResult<MemoryAllocation> {
291        let offset = self.blocks[block_idx].offset;
292        let block_id = self.blocks[block_idx].id;
293        let block_size = self.blocks[block_idx].size;
294
295        // Split block if there's leftover space
296        if block_size > size {
297            let remaining_size = block_size - size;
298            let new_offset = offset + size;
299
300            let new_block = MemoryBlock::new(new_offset, remaining_size, self.next_block_id);
301            self.next_block_id += 1;
302
303            self.blocks[block_idx].size = size;
304            self.blocks.insert(block_idx + 1, new_block);
305        }
306
307        self.blocks[block_idx].in_use = true;
308
309        self.stats.bytes_in_use += size;
310        self.stats.bytes_available = self.stats.bytes_available.saturating_sub(size);
311        self.stats.num_allocations += 1;
312
313        self.update_fragmentation_ratio();
314
315        trace!("Allocated {} bytes at offset {}", size, offset);
316
317        Ok(MemoryAllocation {
318            buffer: Arc::clone(&self.buffer),
319            offset,
320            size,
321            block_id,
322        })
323    }
324
325    fn expand_pool(&mut self, min_additional_size: u64) -> GpuResult<()> {
326        let current_size = self.stats.total_allocated;
327        let growth = (current_size as f64 * self.config.growth_factor) as u64;
328        let new_size = (current_size + growth.max(min_additional_size)).min(self.config.max_size);
329
330        if new_size <= current_size {
331            return Err(GpuError::internal("Cannot expand pool beyond maximum size"));
332        }
333
334        let additional_size = new_size - current_size;
335
336        // Create new larger buffer
337        let new_buffer = Arc::new(self.context.device().create_buffer(&BufferDescriptor {
338            label: Some("Expanded Memory Pool"),
339            size: new_size,
340            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
341            mapped_at_creation: false,
342        }));
343
344        // Copy existing data to new buffer
345        let mut encoder =
346            self.context
347                .device()
348                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
349                    label: Some("Pool Expansion Copy"),
350                });
351
352        encoder.copy_buffer_to_buffer(&self.buffer, 0, &new_buffer, 0, current_size);
353
354        self.context.queue().submit(Some(encoder.finish()));
355
356        // Update buffer and add new free block
357        self.buffer = new_buffer;
358        self.blocks.push(MemoryBlock::new(
359            current_size,
360            additional_size,
361            self.next_block_id,
362        ));
363        self.next_block_id += 1;
364
365        self.stats.total_allocated = new_size;
366        self.stats.bytes_available += additional_size;
367        self.stats.num_expansions += 1;
368
369        debug!(
370            "Expanded memory pool: {} MB -> {} MB",
371            current_size / (1024 * 1024),
372            new_size / (1024 * 1024)
373        );
374
375        Ok(())
376    }
377
378    fn merge_adjacent_blocks(&mut self) {
379        self.blocks.sort_by_key(|b| b.offset);
380
381        let mut i = 0;
382        while i < self.blocks.len().saturating_sub(1) {
383            if !self.blocks[i].in_use
384                && !self.blocks[i + 1].in_use
385                && self.blocks[i].offset + self.blocks[i].size == self.blocks[i + 1].offset
386            {
387                let next_size = self.blocks[i + 1].size;
388                self.blocks[i].size += next_size;
389                self.blocks.remove(i + 1);
390            } else {
391                i += 1;
392            }
393        }
394    }
395
396    fn update_fragmentation_ratio(&mut self) {
397        let free_blocks = self.blocks.iter().filter(|b| !b.in_use).count();
398        let total_blocks = self.blocks.len();
399
400        if total_blocks == 0 {
401            self.stats.fragmentation_ratio = 0.0;
402        } else {
403            self.stats.fragmentation_ratio = free_blocks as f64 / total_blocks as f64;
404        }
405    }
406}
407
408/// A suballocation from the memory pool.
409#[derive(Debug, Clone)]
410pub struct MemoryAllocation {
411    /// The underlying buffer.
412    pub buffer: Arc<Buffer>,
413    /// Offset into the buffer.
414    pub offset: u64,
415    /// Size of the allocation.
416    pub size: u64,
417    /// Block ID for freeing.
418    block_id: u64,
419}
420
421impl MemoryAllocation {
422    /// Get a slice of the buffer for this allocation.
423    pub fn slice(&self) -> wgpu::BufferSlice<'_> {
424        self.buffer.slice(self.offset..self.offset + self.size)
425    }
426}
427
428/// Staging buffer manager for efficient CPU-GPU transfers.
429///
430/// Manages a pool of staging buffers to optimize data transfers between
431/// CPU and GPU memory.
432pub struct StagingBufferManager {
433    context: GpuContext,
434    upload_buffers: VecDeque<Arc<Buffer>>,
435    download_buffers: VecDeque<Arc<Buffer>>,
436    buffer_size: u64,
437    max_buffers: usize,
438    stats: Arc<Mutex<StagingStats>>,
439}
440
441#[derive(Debug, Clone, Default)]
442struct StagingStats {
443    total_uploads: usize,
444    total_downloads: usize,
445    upload_bytes: u64,
446    download_bytes: u64,
447    buffer_reuses: usize,
448}
449
450impl StagingBufferManager {
451    /// Create a new staging buffer manager.
452    pub fn new(context: &GpuContext, buffer_size: u64, max_buffers: usize) -> Self {
453        Self {
454            context: context.clone(),
455            upload_buffers: VecDeque::new(),
456            download_buffers: VecDeque::new(),
457            buffer_size,
458            max_buffers,
459            stats: Arc::new(Mutex::new(StagingStats::default())),
460        }
461    }
462
463    /// Get or create an upload buffer.
464    ///
465    /// # Errors
466    ///
467    /// Returns an error if buffer creation fails.
468    pub fn get_upload_buffer(&mut self) -> GpuResult<Arc<Buffer>> {
469        if let Some(buffer) = self.upload_buffers.pop_front() {
470            if let Ok(mut stats) = self.stats.lock() {
471                stats.buffer_reuses += 1;
472            }
473            Ok(buffer)
474        } else {
475            self.create_upload_buffer()
476        }
477    }
478
479    /// Get or create a download buffer.
480    ///
481    /// # Errors
482    ///
483    /// Returns an error if buffer creation fails.
484    pub fn get_download_buffer(&mut self) -> GpuResult<Arc<Buffer>> {
485        if let Some(buffer) = self.download_buffers.pop_front() {
486            if let Ok(mut stats) = self.stats.lock() {
487                stats.buffer_reuses += 1;
488            }
489            Ok(buffer)
490        } else {
491            self.create_download_buffer()
492        }
493    }
494
495    /// Return an upload buffer to the pool.
496    pub fn return_upload_buffer(&mut self, buffer: Arc<Buffer>) {
497        if self.upload_buffers.len() < self.max_buffers {
498            self.upload_buffers.push_back(buffer);
499        }
500    }
501
502    /// Return a download buffer to the pool.
503    pub fn return_download_buffer(&mut self, buffer: Arc<Buffer>) {
504        if self.download_buffers.len() < self.max_buffers {
505            self.download_buffers.push_back(buffer);
506        }
507    }
508
509    /// Record an upload operation.
510    pub fn record_upload(&self, bytes: u64) {
511        if let Ok(mut stats) = self.stats.lock() {
512            stats.total_uploads += 1;
513            stats.upload_bytes += bytes;
514        }
515    }
516
517    /// Record a download operation.
518    pub fn record_download(&self, bytes: u64) {
519        if let Ok(mut stats) = self.stats.lock() {
520            stats.total_downloads += 1;
521            stats.download_bytes += bytes;
522        }
523    }
524
525    /// Get staging statistics.
526    pub fn stats(&self) -> StagingStats {
527        self.stats.lock().map(|s| s.clone()).unwrap_or_default()
528    }
529
530    /// Clear all cached buffers.
531    pub fn clear(&mut self) {
532        self.upload_buffers.clear();
533        self.download_buffers.clear();
534    }
535
536    fn create_upload_buffer(&self) -> GpuResult<Arc<Buffer>> {
537        let buffer = self.context.device().create_buffer(&BufferDescriptor {
538            label: Some("Staging Upload Buffer"),
539            size: self.buffer_size,
540            usage: BufferUsages::MAP_WRITE | BufferUsages::COPY_SRC,
541            mapped_at_creation: false,
542        });
543
544        Ok(Arc::new(buffer))
545    }
546
547    fn create_download_buffer(&self) -> GpuResult<Arc<Buffer>> {
548        let buffer = self.context.device().create_buffer(&BufferDescriptor {
549            label: Some("Staging Download Buffer"),
550            size: self.buffer_size,
551            usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
552            mapped_at_creation: false,
553        });
554
555        Ok(Arc::new(buffer))
556    }
557}
558
559/// VRAM budget manager to prevent out-of-memory errors.
560pub struct VramBudgetManager {
561    /// Total VRAM budget in bytes.
562    total_budget: u64,
563    /// Currently allocated VRAM in bytes.
564    allocated: Arc<Mutex<u64>>,
565    /// Allocation tracking.
566    allocations: Arc<Mutex<HashMap<u64, u64>>>,
567    next_id: Arc<Mutex<u64>>,
568}
569
570impl VramBudgetManager {
571    /// Create a new VRAM budget manager.
572    pub fn new(total_budget: u64) -> Self {
573        Self {
574            total_budget,
575            allocated: Arc::new(Mutex::new(0)),
576            allocations: Arc::new(Mutex::new(HashMap::new())),
577            next_id: Arc::new(Mutex::new(0)),
578        }
579    }
580
581    /// Try to allocate VRAM budget.
582    ///
583    /// # Errors
584    ///
585    /// Returns an error if budget is exceeded.
586    pub fn allocate(&self, size: u64) -> GpuResult<u64> {
587        let mut allocated = self
588            .allocated
589            .lock()
590            .map_err(|_| GpuError::internal("Lock poisoned"))?;
591
592        if *allocated + size > self.total_budget {
593            return Err(GpuError::internal(format!(
594                "VRAM budget exceeded: {} + {} > {}",
595                *allocated, size, self.total_budget
596            )));
597        }
598
599        let mut id = self
600            .next_id
601            .lock()
602            .map_err(|_| GpuError::internal("Lock poisoned"))?;
603        let allocation_id = *id;
604        *id += 1;
605
606        *allocated += size;
607
608        let mut allocations = self
609            .allocations
610            .lock()
611            .map_err(|_| GpuError::internal("Lock poisoned"))?;
612        allocations.insert(allocation_id, size);
613
614        trace!("VRAM allocated: {} bytes (total: {})", size, *allocated);
615
616        Ok(allocation_id)
617    }
618
619    /// Free VRAM budget.
620    ///
621    /// # Errors
622    ///
623    /// Returns an error if allocation ID is invalid.
624    pub fn free(&self, allocation_id: u64) -> GpuResult<()> {
625        let mut allocations = self
626            .allocations
627            .lock()
628            .map_err(|_| GpuError::internal("Lock poisoned"))?;
629
630        let size = allocations
631            .remove(&allocation_id)
632            .ok_or_else(|| GpuError::invalid_buffer("Invalid allocation ID"))?;
633
634        let mut allocated = self
635            .allocated
636            .lock()
637            .map_err(|_| GpuError::internal("Lock poisoned"))?;
638
639        *allocated = allocated.saturating_sub(size);
640
641        trace!("VRAM freed: {} bytes (total: {})", size, *allocated);
642
643        Ok(())
644    }
645
646    /// Get current allocated amount.
647    pub fn allocated(&self) -> u64 {
648        self.allocated.lock().map(|a| *a).unwrap_or(0)
649    }
650
651    /// Get total budget.
652    pub fn budget(&self) -> u64 {
653        self.total_budget
654    }
655
656    /// Get available budget.
657    pub fn available(&self) -> u64 {
658        self.total_budget.saturating_sub(self.allocated())
659    }
660
661    /// Get utilization percentage.
662    pub fn utilization(&self) -> f64 {
663        if self.total_budget == 0 {
664            return 0.0;
665        }
666        (self.allocated() as f64 / self.total_budget as f64) * 100.0
667    }
668
669    /// Check if allocation would fit in budget.
670    pub fn can_allocate(&self, size: u64) -> bool {
671        self.allocated() + size <= self.total_budget
672    }
673}
674
675#[cfg(test)]
676#[allow(clippy::panic)]
677mod tests {
678    use super::*;
679
680    #[tokio::test]
681    async fn test_memory_stats() {
682        let stats = MemoryStats {
683            total_allocated: 1024,
684            bytes_in_use: 512,
685            ..Default::default()
686        };
687
688        assert_eq!(stats.utilization(), 50.0);
689    }
690
691    #[tokio::test]
692    async fn test_vram_budget_manager() {
693        let manager = VramBudgetManager::new(1024);
694
695        let id1 = manager.allocate(512).unwrap_or_else(|e| panic!("{}", e));
696        assert_eq!(manager.allocated(), 512);
697        assert_eq!(manager.utilization(), 50.0);
698
699        let id2 = manager.allocate(256).unwrap_or_else(|e| panic!("{}", e));
700        assert_eq!(manager.allocated(), 768);
701
702        // Should fail - exceeds budget
703        assert!(manager.allocate(512).is_err());
704
705        manager.free(id1).unwrap_or_else(|e| panic!("{}", e));
706        assert_eq!(manager.allocated(), 256);
707
708        manager.free(id2).unwrap_or_else(|e| panic!("{}", e));
709        assert_eq!(manager.allocated(), 0);
710    }
711
712    #[test]
713    fn test_memory_pool_config() {
714        let config = MemoryPoolConfig::default();
715        assert_eq!(config.initial_size, 64 * 1024 * 1024);
716        assert_eq!(config.max_size, 2 * 1024 * 1024 * 1024);
717    }
718}