Skip to main content

oximedia_gpu/
memory.rs

1//! GPU memory management and allocation tracking
2//!
3//! This module provides memory allocation tracking, usage statistics,
4//! and memory pool management for GPU buffers.
5
6use crate::{GpuBuffer, GpuDevice, Result};
7use parking_lot::RwLock;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11
/// Point-in-time snapshot of GPU memory allocation counters.
///
/// Byte fields hold raw byte counts; the `*_mb` helpers divide by
/// `1024 * 1024` to express the same values in megabytes.
#[derive(Debug, Clone, Copy, Default)]
pub struct MemoryStats {
    /// Total bytes allocated
    pub total_allocated: u64,
    /// Total bytes freed
    pub total_freed: u64,
    /// Current bytes in use
    pub current_usage: u64,
    /// Peak memory usage
    pub peak_usage: u64,
    /// Number of active allocations
    pub allocation_count: u64,
}

impl MemoryStats {
    /// Divisor used by the `*_mb` accessors (bytes per megabyte).
    const BYTES_PER_MB: f64 = 1024.0 * 1024.0;

    /// Returns the bytes currently in use.
    #[must_use]
    pub fn current_bytes(&self) -> u64 {
        let Self { current_usage, .. } = *self;
        current_usage
    }

    /// Returns the bytes currently in use, expressed in megabytes.
    #[must_use]
    pub fn current_mb(&self) -> f64 {
        let bytes = self.current_usage as f64;
        bytes / Self::BYTES_PER_MB
    }

    /// Returns the highest recorded usage in bytes.
    #[must_use]
    pub fn peak_bytes(&self) -> u64 {
        let Self { peak_usage, .. } = *self;
        peak_usage
    }

    /// Returns the highest recorded usage, expressed in megabytes.
    #[must_use]
    pub fn peak_mb(&self) -> f64 {
        let bytes = self.peak_usage as f64;
        bytes / Self::BYTES_PER_MB
    }
}
52
53/// Memory allocator for GPU buffers
54pub struct MemoryAllocator {
55    device: Arc<wgpu::Device>,
56    total_allocated: AtomicU64,
57    total_freed: AtomicU64,
58    current_usage: AtomicU64,
59    peak_usage: AtomicU64,
60    allocation_count: AtomicU64,
61}
62
63impl MemoryAllocator {
64    /// Create a new memory allocator
65    #[must_use]
66    pub fn new(device: &GpuDevice) -> Self {
67        Self {
68            device: Arc::clone(device.device()),
69            total_allocated: AtomicU64::new(0),
70            total_freed: AtomicU64::new(0),
71            current_usage: AtomicU64::new(0),
72            peak_usage: AtomicU64::new(0),
73            allocation_count: AtomicU64::new(0),
74        }
75    }
76
77    /// Track a memory allocation
78    pub fn track_allocation(&self, size: u64) {
79        self.total_allocated.fetch_add(size, Ordering::Relaxed);
80        let current = self.current_usage.fetch_add(size, Ordering::Relaxed) + size;
81        self.allocation_count.fetch_add(1, Ordering::Relaxed);
82
83        // Update peak usage
84        let mut peak = self.peak_usage.load(Ordering::Relaxed);
85        while current > peak {
86            match self.peak_usage.compare_exchange_weak(
87                peak,
88                current,
89                Ordering::Relaxed,
90                Ordering::Relaxed,
91            ) {
92                Ok(_) => break,
93                Err(x) => peak = x,
94            }
95        }
96    }
97
98    /// Track a memory deallocation
99    pub fn track_deallocation(&self, size: u64) {
100        self.total_freed.fetch_add(size, Ordering::Relaxed);
101        self.current_usage.fetch_sub(size, Ordering::Relaxed);
102        self.allocation_count.fetch_sub(1, Ordering::Relaxed);
103    }
104
105    /// Get current memory statistics
106    pub fn stats(&self) -> MemoryStats {
107        MemoryStats {
108            total_allocated: self.total_allocated.load(Ordering::Relaxed),
109            total_freed: self.total_freed.load(Ordering::Relaxed),
110            current_usage: self.current_usage.load(Ordering::Relaxed),
111            peak_usage: self.peak_usage.load(Ordering::Relaxed),
112            allocation_count: self.allocation_count.load(Ordering::Relaxed),
113        }
114    }
115
116    /// Reset statistics
117    pub fn reset_stats(&self) {
118        self.total_allocated.store(0, Ordering::Relaxed);
119        self.total_freed.store(0, Ordering::Relaxed);
120        self.current_usage.store(0, Ordering::Relaxed);
121        self.peak_usage.store(0, Ordering::Relaxed);
122        self.allocation_count.store(0, Ordering::Relaxed);
123    }
124
125    /// Get the device reference
126    pub fn device(&self) -> &Arc<wgpu::Device> {
127        &self.device
128    }
129}
130
131/// Memory pool for reusing GPU buffers
132pub struct MemoryPool {
133    #[allow(dead_code)]
134    device: Arc<wgpu::Device>,
135    allocator: Arc<MemoryAllocator>,
136    pools: RwLock<HashMap<u64, Vec<GpuBuffer>>>,
137}
138
139impl MemoryPool {
140    /// Create a new memory pool
141    #[must_use]
142    pub fn new(device: &GpuDevice) -> Self {
143        Self {
144            device: Arc::clone(device.device()),
145            allocator: Arc::new(MemoryAllocator::new(device)),
146            pools: RwLock::new(HashMap::new()),
147        }
148    }
149
150    /// Allocate a buffer from the pool
151    ///
152    /// If a buffer of the requested size is available in the pool, it will be reused.
153    /// Otherwise, a new buffer will be allocated.
154    ///
155    /// # Arguments
156    ///
157    /// * `device` - GPU device
158    /// * `size` - Buffer size in bytes
159    /// * `buffer_type` - Type of buffer to allocate
160    ///
161    /// # Errors
162    ///
163    /// Returns an error if buffer allocation fails.
164    pub fn allocate(
165        &self,
166        device: &GpuDevice,
167        size: u64,
168        buffer_type: crate::buffer::BufferType,
169    ) -> Result<GpuBuffer> {
170        // Try to reuse a buffer from the pool
171        {
172            let mut pools = self.pools.write();
173            if let Some(pool) = pools.get_mut(&size) {
174                if let Some(buffer) = pool.pop() {
175                    return Ok(buffer);
176                }
177            }
178        }
179
180        // Allocate a new buffer
181        let buffer = GpuBuffer::new(device, size, buffer_type)?;
182        self.allocator.track_allocation(size);
183
184        Ok(buffer)
185    }
186
187    /// Return a buffer to the pool for reuse
188    ///
189    /// # Arguments
190    ///
191    /// * `buffer` - Buffer to return to the pool
192    pub fn deallocate(&self, buffer: GpuBuffer) {
193        let size = buffer.size();
194        let mut pools = self.pools.write();
195        pools.entry(size).or_default().push(buffer);
196    }
197
198    /// Clear the memory pool
199    pub fn clear(&self) {
200        let mut pools = self.pools.write();
201        for (size, buffers) in pools.drain() {
202            let total_size = size * buffers.len() as u64;
203            self.allocator.track_deallocation(total_size);
204        }
205    }
206
207    /// Get the number of buffers in the pool
208    pub fn pool_size(&self) -> usize {
209        let pools = self.pools.read();
210        pools.values().map(std::vec::Vec::len).sum()
211    }
212
213    /// Get memory statistics
214    pub fn stats(&self) -> MemoryStats {
215        self.allocator.stats()
216    }
217
218    /// Get the allocator
219    pub fn allocator(&self) -> &Arc<MemoryAllocator> {
220        &self.allocator
221    }
222}
223
224/// RAII wrapper for automatic buffer deallocation
225pub struct ManagedBuffer {
226    buffer: Option<GpuBuffer>,
227    pool: Arc<MemoryPool>,
228}
229
230impl ManagedBuffer {
231    /// Create a new managed buffer
232    pub fn new(buffer: GpuBuffer, pool: Arc<MemoryPool>) -> Self {
233        Self {
234            buffer: Some(buffer),
235            pool,
236        }
237    }
238
239    /// Get a reference to the buffer
240    #[must_use]
241    pub fn buffer(&self) -> &GpuBuffer {
242        self.buffer.as_ref().expect("Buffer already released")
243    }
244
245    /// Take ownership of the buffer, preventing automatic deallocation
246    #[must_use]
247    pub fn take(mut self) -> GpuBuffer {
248        self.buffer.take().expect("Buffer already released")
249    }
250}
251
252impl Drop for ManagedBuffer {
253    fn drop(&mut self) {
254        if let Some(buffer) = self.buffer.take() {
255            self.pool.deallocate(buffer);
256        }
257    }
258}
259
260impl std::ops::Deref for ManagedBuffer {
261    type Target = GpuBuffer;
262
263    fn deref(&self) -> &Self::Target {
264        self.buffer()
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_stats() {
        let stats = MemoryStats {
            total_allocated: 1024 * 1024 * 100, // 100 MB
            total_freed: 1024 * 1024 * 20,      // 20 MB
            current_usage: 1024 * 1024 * 80,    // 80 MB
            peak_usage: 1024 * 1024 * 90,       // 90 MB
            allocation_count: 10,
        };

        assert_eq!(stats.current_bytes(), 1024 * 1024 * 80);
        assert!((stats.current_mb() - 80.0).abs() < 0.01);
        assert_eq!(stats.peak_bytes(), 1024 * 1024 * 90);
        assert!((stats.peak_mb() - 90.0).abs() < 0.01);
    }

    /// A default snapshot reports zero usage in both bytes and megabytes.
    #[test]
    fn test_memory_stats_default_is_zeroed() {
        let stats = MemoryStats::default();

        assert_eq!(stats.current_bytes(), 0);
        assert_eq!(stats.peak_bytes(), 0);
        assert_eq!(stats.allocation_count, 0);
        assert!(stats.current_mb().abs() < f64::EPSILON);
        assert!(stats.peak_mb().abs() < f64::EPSILON);
    }

    #[test]
    #[ignore] // Requires GPU hardware; run with --ignored
    fn test_memory_allocator_tracking() {
        let Ok(gpu_device) = crate::device::GpuDevice::new(None) else {
            return;
        };
        let allocator = MemoryAllocator::new(&gpu_device);

        allocator.track_allocation(1024);
        allocator.track_allocation(2048);

        let stats = allocator.stats();
        assert_eq!(stats.total_allocated, 3072);
        assert_eq!(stats.current_usage, 3072);
        assert_eq!(stats.allocation_count, 2);

        allocator.track_deallocation(1024);

        let stats = allocator.stats();
        assert_eq!(stats.total_freed, 1024);
        assert_eq!(stats.current_usage, 2048);
        assert_eq!(stats.allocation_count, 1);
    }

    /// `peak_usage` keeps the high-water mark across deallocations, and
    /// `reset_stats` zeroes every counter.
    #[test]
    #[ignore] // Requires GPU hardware; run with --ignored
    fn test_memory_allocator_peak_and_reset() {
        let Ok(gpu_device) = crate::device::GpuDevice::new(None) else {
            return;
        };
        let allocator = MemoryAllocator::new(&gpu_device);

        allocator.track_allocation(1024);
        allocator.track_allocation(2048);
        allocator.track_deallocation(2048);
        allocator.track_allocation(512);

        let stats = allocator.stats();
        // High-water mark was 1024 + 2048, reached before the deallocation.
        assert_eq!(stats.peak_usage, 3072);
        assert_eq!(stats.current_usage, 1536);
        assert_eq!(stats.allocation_count, 2);

        allocator.reset_stats();

        let stats = allocator.stats();
        assert_eq!(stats.total_allocated, 0);
        assert_eq!(stats.total_freed, 0);
        assert_eq!(stats.current_usage, 0);
        assert_eq!(stats.peak_usage, 0);
        assert_eq!(stats.allocation_count, 0);
    }

    /// Checks the counter arithmetic without a GPU.
    ///
    /// This drives standalone atomics that mirror the `track_allocation` /
    /// `track_deallocation` bodies; it does not call `MemoryAllocator` itself,
    /// since constructing one requires a device.
    #[test]
    fn test_memory_allocator_tracking_no_gpu() {
        let total_allocated = AtomicU64::new(0);
        let total_freed = AtomicU64::new(0);
        let current_usage = AtomicU64::new(0);
        let allocation_count = AtomicU64::new(0);

        // Simulate track_allocation(1024)
        total_allocated.fetch_add(1024, Ordering::Relaxed);
        current_usage.fetch_add(1024, Ordering::Relaxed);
        allocation_count.fetch_add(1, Ordering::Relaxed);

        // Simulate track_allocation(2048)
        total_allocated.fetch_add(2048, Ordering::Relaxed);
        current_usage.fetch_add(2048, Ordering::Relaxed);
        allocation_count.fetch_add(1, Ordering::Relaxed);

        assert_eq!(total_allocated.load(Ordering::Relaxed), 3072);
        assert_eq!(current_usage.load(Ordering::Relaxed), 3072);
        assert_eq!(allocation_count.load(Ordering::Relaxed), 2);

        // Simulate track_deallocation(1024)
        total_freed.fetch_add(1024, Ordering::Relaxed);
        current_usage.fetch_sub(1024, Ordering::Relaxed);
        allocation_count.fetch_sub(1, Ordering::Relaxed);

        assert_eq!(total_freed.load(Ordering::Relaxed), 1024);
        assert_eq!(current_usage.load(Ordering::Relaxed), 2048);
        assert_eq!(allocation_count.load(Ordering::Relaxed), 1);
    }
}