1use crate::{GpuBuffer, GpuDevice, GpuError, Result};
7use parking_lot::RwLock;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11
/// Point-in-time snapshot of memory accounting counters.
///
/// All sizes are expressed in bytes.
#[derive(Debug, Clone, Copy, Default)]
pub struct MemoryStats {
    /// Cumulative number of bytes recorded as allocated.
    pub total_allocated: u64,
    /// Cumulative number of bytes recorded as freed.
    pub total_freed: u64,
    /// Number of bytes in use at the time of the snapshot.
    pub current_usage: u64,
    /// Highest value of `current_usage` recorded so far.
    pub peak_usage: u64,
    /// Allocation counter value at the time of the snapshot.
    pub allocation_count: u64,
}

impl MemoryStats {
    /// Number of bytes in one megabyte as used by the `*_mb` accessors (1024 * 1024).
    const BYTES_PER_MB: f64 = 1024.0 * 1024.0;

    /// Converts a byte count into megabytes (units of 1024 * 1024 bytes).
    fn to_mb(bytes: u64) -> f64 {
        bytes as f64 / Self::BYTES_PER_MB
    }

    /// Returns the current usage in bytes.
    #[must_use]
    pub fn current_bytes(&self) -> u64 {
        self.current_usage
    }

    /// Returns the current usage in megabytes (1 MB = 1024 * 1024 bytes).
    #[must_use]
    pub fn current_mb(&self) -> f64 {
        Self::to_mb(self.current_usage)
    }

    /// Returns the peak usage in bytes.
    #[must_use]
    pub fn peak_bytes(&self) -> u64 {
        self.peak_usage
    }

    /// Returns the peak usage in megabytes (1 MB = 1024 * 1024 bytes).
    #[must_use]
    pub fn peak_mb(&self) -> f64 {
        Self::to_mb(self.peak_usage)
    }
}
52
53pub struct MemoryAllocator {
55 device: Arc<wgpu::Device>,
56 total_allocated: AtomicU64,
57 total_freed: AtomicU64,
58 current_usage: AtomicU64,
59 peak_usage: AtomicU64,
60 allocation_count: AtomicU64,
61}
62
63impl MemoryAllocator {
64 #[must_use]
66 pub fn new(device: &GpuDevice) -> Self {
67 Self {
68 device: Arc::clone(device.device()),
69 total_allocated: AtomicU64::new(0),
70 total_freed: AtomicU64::new(0),
71 current_usage: AtomicU64::new(0),
72 peak_usage: AtomicU64::new(0),
73 allocation_count: AtomicU64::new(0),
74 }
75 }
76
77 pub fn track_allocation(&self, size: u64) {
79 self.total_allocated.fetch_add(size, Ordering::Relaxed);
80 let current = self.current_usage.fetch_add(size, Ordering::Relaxed) + size;
81 self.allocation_count.fetch_add(1, Ordering::Relaxed);
82
83 let mut peak = self.peak_usage.load(Ordering::Relaxed);
85 while current > peak {
86 match self.peak_usage.compare_exchange_weak(
87 peak,
88 current,
89 Ordering::Relaxed,
90 Ordering::Relaxed,
91 ) {
92 Ok(_) => break,
93 Err(x) => peak = x,
94 }
95 }
96 }
97
98 pub fn track_deallocation(&self, size: u64) {
100 self.total_freed.fetch_add(size, Ordering::Relaxed);
101 self.current_usage.fetch_sub(size, Ordering::Relaxed);
102 self.allocation_count.fetch_sub(1, Ordering::Relaxed);
103 }
104
105 pub fn stats(&self) -> MemoryStats {
107 MemoryStats {
108 total_allocated: self.total_allocated.load(Ordering::Relaxed),
109 total_freed: self.total_freed.load(Ordering::Relaxed),
110 current_usage: self.current_usage.load(Ordering::Relaxed),
111 peak_usage: self.peak_usage.load(Ordering::Relaxed),
112 allocation_count: self.allocation_count.load(Ordering::Relaxed),
113 }
114 }
115
116 pub fn reset_stats(&self) {
118 self.total_allocated.store(0, Ordering::Relaxed);
119 self.total_freed.store(0, Ordering::Relaxed);
120 self.current_usage.store(0, Ordering::Relaxed);
121 self.peak_usage.store(0, Ordering::Relaxed);
122 self.allocation_count.store(0, Ordering::Relaxed);
123 }
124
125 pub fn device(&self) -> &Arc<wgpu::Device> {
127 &self.device
128 }
129}
130
131pub struct MemoryPool {
133 #[allow(dead_code)]
134 device: Arc<wgpu::Device>,
135 allocator: Arc<MemoryAllocator>,
136 pools: RwLock<HashMap<u64, Vec<GpuBuffer>>>,
137}
138
139impl MemoryPool {
140 #[must_use]
142 pub fn new(device: &GpuDevice) -> Self {
143 Self {
144 device: Arc::clone(device.device()),
145 allocator: Arc::new(MemoryAllocator::new(device)),
146 pools: RwLock::new(HashMap::new()),
147 }
148 }
149
150 pub fn allocate(
165 &self,
166 device: &GpuDevice,
167 size: u64,
168 buffer_type: crate::buffer::BufferType,
169 ) -> Result<GpuBuffer> {
170 {
172 let mut pools = self.pools.write();
173 if let Some(pool) = pools.get_mut(&size) {
174 if let Some(buffer) = pool.pop() {
175 return Ok(buffer);
176 }
177 }
178 }
179
180 let buffer = GpuBuffer::new(device, size, buffer_type)?;
182 self.allocator.track_allocation(size);
183
184 Ok(buffer)
185 }
186
187 pub fn deallocate(&self, buffer: GpuBuffer) {
193 let size = buffer.size();
194 let mut pools = self.pools.write();
195 pools.entry(size).or_default().push(buffer);
196 }
197
198 pub fn clear(&self) {
200 let mut pools = self.pools.write();
201 for (size, buffers) in pools.drain() {
202 let total_size = size * buffers.len() as u64;
203 self.allocator.track_deallocation(total_size);
204 }
205 }
206
207 pub fn pool_size(&self) -> usize {
209 let pools = self.pools.read();
210 pools.values().map(std::vec::Vec::len).sum()
211 }
212
213 pub fn stats(&self) -> MemoryStats {
215 self.allocator.stats()
216 }
217
218 pub fn allocator(&self) -> &Arc<MemoryAllocator> {
220 &self.allocator
221 }
222}
223
224pub struct ManagedBuffer {
226 buffer: Option<GpuBuffer>,
227 pool: Arc<MemoryPool>,
228}
229
230impl ManagedBuffer {
231 pub fn new(buffer: GpuBuffer, pool: Arc<MemoryPool>) -> Self {
233 Self {
234 buffer: Some(buffer),
235 pool,
236 }
237 }
238
239 #[must_use]
246 pub fn buffer(&self) -> &GpuBuffer {
247 self.buffer
248 .as_ref()
249 .unwrap_or_else(|| unreachable!("ManagedBuffer accessed after buffer was released"))
250 }
251
252 pub fn try_buffer(&self) -> Result<&GpuBuffer> {
258 self.buffer
259 .as_ref()
260 .ok_or_else(|| GpuError::Internal("Buffer already released".to_string()))
261 }
262
263 pub fn take(mut self) -> Result<GpuBuffer> {
269 self.buffer
270 .take()
271 .ok_or_else(|| GpuError::Internal("Buffer already released".to_string()))
272 }
273}
274
275impl Drop for ManagedBuffer {
276 fn drop(&mut self) {
277 if let Some(buffer) = self.buffer.take() {
278 self.pool.deallocate(buffer);
279 }
280 }
281}
282
283impl std::ops::Deref for ManagedBuffer {
284 type Target = GpuBuffer;
285
286 fn deref(&self) -> &Self::Target {
287 self.buffer()
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    /// The byte and megabyte accessors of `MemoryStats` agree with its fields.
    #[test]
    fn test_memory_stats() {
        const MB: u64 = 1024 * 1024;

        let snapshot = MemoryStats {
            total_allocated: MB * 100,
            total_freed: MB * 20,
            current_usage: MB * 80,
            peak_usage: MB * 90,
            allocation_count: 10,
        };

        assert_eq!(snapshot.current_bytes(), MB * 80);
        assert!((snapshot.current_mb() - 80.0).abs() < 0.01);
        assert_eq!(snapshot.peak_bytes(), MB * 90);
        assert!((snapshot.peak_mb() - 90.0).abs() < 0.01);
    }

    /// Exercises the allocator counters against a real device.
    ///
    /// Ignored by default; when run explicitly it returns early if no GPU
    /// device can be created.
    #[test]
    #[ignore]
    fn test_memory_allocator_tracking() {
        let gpu_device = match crate::device::GpuDevice::new(None) {
            Ok(found) => found,
            Err(_) => return,
        };
        let tracker = MemoryAllocator::new(&gpu_device);

        tracker.track_allocation(1024);
        tracker.track_allocation(2048);

        let after_alloc = tracker.stats();
        assert_eq!(after_alloc.total_allocated, 3072);
        assert_eq!(after_alloc.current_usage, 3072);
        assert_eq!(after_alloc.allocation_count, 2);

        tracker.track_deallocation(1024);

        let after_free = tracker.stats();
        assert_eq!(after_free.total_freed, 1024);
        assert_eq!(after_free.current_usage, 2048);
        assert_eq!(after_free.allocation_count, 1);
    }

    /// Replays the counter arithmetic on bare atomics so it runs without a GPU.
    #[test]
    fn test_memory_allocator_tracking_no_gpu() {
        let allocated = AtomicU64::new(0);
        let freed = AtomicU64::new(0);
        let in_use = AtomicU64::new(0);
        let live = AtomicU64::new(0);

        // Two allocations: 1024 bytes, then 2048 bytes.
        for bytes in [1024_u64, 2048] {
            allocated.fetch_add(bytes, Ordering::Relaxed);
            in_use.fetch_add(bytes, Ordering::Relaxed);
            live.fetch_add(1, Ordering::Relaxed);
        }

        assert_eq!(allocated.load(Ordering::Relaxed), 3072);
        assert_eq!(in_use.load(Ordering::Relaxed), 3072);
        assert_eq!(live.load(Ordering::Relaxed), 2);

        // One deallocation of the 1024-byte allocation.
        freed.fetch_add(1024, Ordering::Relaxed);
        in_use.fetch_sub(1024, Ordering::Relaxed);
        live.fetch_sub(1, Ordering::Relaxed);

        assert_eq!(freed.load(Ordering::Relaxed), 1024);
        assert_eq!(in_use.load(Ordering::Relaxed), 2048);
        assert_eq!(live.load(Ordering::Relaxed), 1);
    }
}