1use crate::context::GpuContext;
7use crate::error::{GpuError, GpuResult};
8use std::collections::{HashMap, VecDeque};
9use std::sync::{Arc, Mutex};
10use tracing::{debug, trace};
11use wgpu::{Buffer, BufferDescriptor, BufferUsages};
12
/// Tunable sizing and defragmentation policy for a [`MemoryPool`].
#[derive(Debug, Clone)]
pub struct MemoryPoolConfig {
    /// Size in bytes of the backing buffer created up front.
    pub initial_size: u64,
    /// Hard upper bound, in bytes, the pool may ever grow to.
    pub max_size: u64,
    /// Multiplier applied when the pool needs to expand.
    pub growth_factor: f64,
    /// Whether allocation pressure triggers automatic defragmentation.
    pub auto_defrag: bool,
    /// Fragmentation ratio at or above which defragmentation kicks in.
    pub defrag_threshold: f64,
}

/// Defaults: 64 MiB initial buffer, 2 GiB cap, 1.5x growth,
/// auto-defragmentation once 30% of blocks are free.
impl Default for MemoryPoolConfig {
    fn default() -> Self {
        Self {
            initial_size: 64 * 1024 * 1024,
            max_size: 2 * 1024 * 1024 * 1024,
            growth_factor: 1.5,
            auto_defrag: true,
            defrag_threshold: 0.3,
        }
    }
}
39
/// Running bookkeeping counters for a [`MemoryPool`].
#[derive(Debug, Clone, Default)]
pub struct MemoryStats {
    /// Total bytes in the backing buffer (grows on expansion).
    pub total_allocated: u64,
    /// Bytes currently handed out to live allocations.
    pub bytes_in_use: u64,
    /// Bytes sitting in free blocks.
    pub bytes_available: u64,
    /// Number of live allocations.
    pub num_allocations: usize,
    /// How many times the pool has expanded.
    pub num_expansions: usize,
    /// How many defragmentation passes have run.
    pub num_defrags: usize,
    /// Fraction of blocks that are free, in `0.0..=1.0`.
    pub fragmentation_ratio: f64,
}

impl MemoryStats {
    /// Percentage (0-100) of the pool currently in use; 0 for an empty pool.
    pub fn utilization(&self) -> f64 {
        match self.total_allocated {
            0 => 0.0,
            total => self.bytes_in_use as f64 / total as f64 * 100.0,
        }
    }

    /// True once the fragmentation ratio has reached `threshold`.
    pub fn needs_defrag(&self, threshold: f64) -> bool {
        self.fragmentation_ratio >= threshold
    }
}
73
/// One contiguous sub-range of the pool's backing buffer.
#[derive(Debug)]
struct MemoryBlock {
    /// Byte offset within the backing buffer.
    offset: u64,
    /// Length of the range in bytes.
    size: u64,
    /// Whether the range is currently handed out.
    in_use: bool,
    /// Stable identifier used to locate the block again on free.
    id: u64,
}

impl MemoryBlock {
    /// Creates a free block covering `size` bytes starting at `offset`.
    fn new(offset: u64, size: u64, id: u64) -> Self {
        Self {
            offset,
            size,
            in_use: false,
            id,
        }
    }

    /// True when the block is free and large enough to hold `size` bytes.
    fn can_fit(&self, size: u64) -> bool {
        !self.in_use && size <= self.size
    }
}
101
/// Sub-allocating arena over a single `wgpu` buffer.
///
/// The buffer is partitioned into `blocks` that tile it contiguously;
/// allocation splits a free block, freeing merges adjacent free blocks.
pub struct MemoryPool {
    // Device/queue handle used for buffer creation and expansion copies.
    context: GpuContext,
    // Sizing and defragmentation policy for this pool.
    config: MemoryPoolConfig,
    // The single backing GPU buffer that all allocations point into.
    buffer: Arc<Buffer>,
    // Bookkeeping of sub-ranges; together the blocks cover the whole buffer.
    blocks: Vec<MemoryBlock>,
    // Running counters exposed via `stats()`.
    stats: MemoryStats,
    // Monotonic source for `MemoryBlock::id` values.
    next_block_id: u64,
}
114
115impl MemoryPool {
116 pub fn new(context: &GpuContext, config: MemoryPoolConfig) -> GpuResult<Self> {
122 let buffer = Arc::new(context.device().create_buffer(&BufferDescriptor {
123 label: Some("Memory Pool"),
124 size: config.initial_size,
125 usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
126 mapped_at_creation: false,
127 }));
128
129 let mut blocks = Vec::new();
130 blocks.push(MemoryBlock::new(0, config.initial_size, 0));
131
132 let stats = MemoryStats {
133 total_allocated: config.initial_size,
134 bytes_available: config.initial_size,
135 ..Default::default()
136 };
137
138 debug!(
139 "Created memory pool: {} MB",
140 config.initial_size / (1024 * 1024)
141 );
142
143 Ok(Self {
144 context: context.clone(),
145 config,
146 buffer,
147 blocks,
148 stats,
149 next_block_id: 1,
150 })
151 }
152
153 pub fn allocate(&mut self, size: u64, alignment: u64) -> GpuResult<MemoryAllocation> {
159 let aligned_size = Self::align_size(size, alignment);
160
161 if let Some(block_idx) = self.find_free_block(aligned_size) {
163 return self.allocate_from_block(block_idx, aligned_size);
164 }
165
166 if self.config.auto_defrag && self.stats.needs_defrag(self.config.defrag_threshold) {
168 self.defragment()?;
169
170 if let Some(block_idx) = self.find_free_block(aligned_size) {
172 return self.allocate_from_block(block_idx, aligned_size);
173 }
174 }
175
176 if self.stats.total_allocated < self.config.max_size {
178 self.expand_pool(aligned_size)?;
179
180 if let Some(block_idx) = self.find_free_block(aligned_size) {
181 return self.allocate_from_block(block_idx, aligned_size);
182 }
183 }
184
185 Err(GpuError::internal(format!(
186 "Failed to allocate {} bytes from pool",
187 aligned_size
188 )))
189 }
190
191 pub fn free(&mut self, allocation: MemoryAllocation) -> GpuResult<()> {
197 let block = self
198 .blocks
199 .iter_mut()
200 .find(|b| b.id == allocation.block_id)
201 .ok_or_else(|| GpuError::invalid_buffer("Invalid block ID"))?;
202
203 if !block.in_use {
204 return Err(GpuError::invalid_buffer("Block already freed"));
205 }
206
207 block.in_use = false;
208 self.stats.bytes_in_use = self.stats.bytes_in_use.saturating_sub(block.size);
209 self.stats.bytes_available += block.size;
210 self.stats.num_allocations = self.stats.num_allocations.saturating_sub(1);
211
212 trace!("Freed {} bytes from pool", block.size);
213
214 self.merge_adjacent_blocks();
216
217 Ok(())
218 }
219
220 pub fn stats(&self) -> &MemoryStats {
222 &self.stats
223 }
224
225 pub fn defragment(&mut self) -> GpuResult<()> {
231 debug!("Starting memory pool defragmentation");
232
233 self.blocks.sort_by_key(|b| b.offset);
235
236 let mut i = 0;
238 while i < self.blocks.len().saturating_sub(1) {
239 if !self.blocks[i].in_use && !self.blocks[i + 1].in_use {
240 let next_size = self.blocks[i + 1].size;
241 self.blocks[i].size += next_size;
242 self.blocks.remove(i + 1);
243 } else {
244 i += 1;
245 }
246 }
247
248 self.stats.num_defrags += 1;
249 self.update_fragmentation_ratio();
250
251 debug!(
252 "Defragmentation complete: {} blocks remaining",
253 self.blocks.len()
254 );
255
256 Ok(())
257 }
258
259 pub fn reset(&mut self) {
261 for block in &mut self.blocks {
262 block.in_use = false;
263 }
264
265 self.blocks.clear();
267 self.blocks.push(MemoryBlock::new(
268 0,
269 self.stats.total_allocated,
270 self.next_block_id,
271 ));
272 self.next_block_id += 1;
273
274 self.stats.bytes_in_use = 0;
275 self.stats.bytes_available = self.stats.total_allocated;
276 self.stats.num_allocations = 0;
277 self.stats.fragmentation_ratio = 0.0;
278
279 debug!("Memory pool reset");
280 }
281
282 fn align_size(size: u64, alignment: u64) -> u64 {
283 ((size + alignment - 1) / alignment) * alignment
284 }
285
286 fn find_free_block(&self, size: u64) -> Option<usize> {
287 self.blocks.iter().position(|block| block.can_fit(size))
288 }
289
290 fn allocate_from_block(&mut self, block_idx: usize, size: u64) -> GpuResult<MemoryAllocation> {
291 let offset = self.blocks[block_idx].offset;
292 let block_id = self.blocks[block_idx].id;
293 let block_size = self.blocks[block_idx].size;
294
295 if block_size > size {
297 let remaining_size = block_size - size;
298 let new_offset = offset + size;
299
300 let new_block = MemoryBlock::new(new_offset, remaining_size, self.next_block_id);
301 self.next_block_id += 1;
302
303 self.blocks[block_idx].size = size;
304 self.blocks.insert(block_idx + 1, new_block);
305 }
306
307 self.blocks[block_idx].in_use = true;
308
309 self.stats.bytes_in_use += size;
310 self.stats.bytes_available = self.stats.bytes_available.saturating_sub(size);
311 self.stats.num_allocations += 1;
312
313 self.update_fragmentation_ratio();
314
315 trace!("Allocated {} bytes at offset {}", size, offset);
316
317 Ok(MemoryAllocation {
318 buffer: Arc::clone(&self.buffer),
319 offset,
320 size,
321 block_id,
322 })
323 }
324
325 fn expand_pool(&mut self, min_additional_size: u64) -> GpuResult<()> {
326 let current_size = self.stats.total_allocated;
327 let growth = (current_size as f64 * self.config.growth_factor) as u64;
328 let new_size = (current_size + growth.max(min_additional_size)).min(self.config.max_size);
329
330 if new_size <= current_size {
331 return Err(GpuError::internal("Cannot expand pool beyond maximum size"));
332 }
333
334 let additional_size = new_size - current_size;
335
336 let new_buffer = Arc::new(self.context.device().create_buffer(&BufferDescriptor {
338 label: Some("Expanded Memory Pool"),
339 size: new_size,
340 usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
341 mapped_at_creation: false,
342 }));
343
344 let mut encoder =
346 self.context
347 .device()
348 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
349 label: Some("Pool Expansion Copy"),
350 });
351
352 encoder.copy_buffer_to_buffer(&self.buffer, 0, &new_buffer, 0, current_size);
353
354 self.context.queue().submit(Some(encoder.finish()));
355
356 self.buffer = new_buffer;
358 self.blocks.push(MemoryBlock::new(
359 current_size,
360 additional_size,
361 self.next_block_id,
362 ));
363 self.next_block_id += 1;
364
365 self.stats.total_allocated = new_size;
366 self.stats.bytes_available += additional_size;
367 self.stats.num_expansions += 1;
368
369 debug!(
370 "Expanded memory pool: {} MB -> {} MB",
371 current_size / (1024 * 1024),
372 new_size / (1024 * 1024)
373 );
374
375 Ok(())
376 }
377
378 fn merge_adjacent_blocks(&mut self) {
379 self.blocks.sort_by_key(|b| b.offset);
380
381 let mut i = 0;
382 while i < self.blocks.len().saturating_sub(1) {
383 if !self.blocks[i].in_use
384 && !self.blocks[i + 1].in_use
385 && self.blocks[i].offset + self.blocks[i].size == self.blocks[i + 1].offset
386 {
387 let next_size = self.blocks[i + 1].size;
388 self.blocks[i].size += next_size;
389 self.blocks.remove(i + 1);
390 } else {
391 i += 1;
392 }
393 }
394 }
395
396 fn update_fragmentation_ratio(&mut self) {
397 let free_blocks = self.blocks.iter().filter(|b| !b.in_use).count();
398 let total_blocks = self.blocks.len();
399
400 if total_blocks == 0 {
401 self.stats.fragmentation_ratio = 0.0;
402 } else {
403 self.stats.fragmentation_ratio = free_blocks as f64 / total_blocks as f64;
404 }
405 }
406}
407
/// A sub-range of the pool's backing buffer handed out by `MemoryPool::allocate`.
#[derive(Debug, Clone)]
pub struct MemoryAllocation {
    /// The pool's backing buffer at the time of allocation.
    pub buffer: Arc<Buffer>,
    /// Byte offset of this allocation within `buffer`.
    pub offset: u64,
    /// Size of the allocation in bytes.
    pub size: u64,
    // Id of the pool block backing this allocation; consumed by
    // `MemoryPool::free` to locate the block again.
    block_id: u64,
}
420
421impl MemoryAllocation {
422 pub fn slice(&self) -> wgpu::BufferSlice<'_> {
424 self.buffer.slice(self.offset..self.offset + self.size)
425 }
426}
427
/// Recycles staging buffers for CPU<->GPU transfers, one pool per direction.
pub struct StagingBufferManager {
    // Device handle used to create new staging buffers on demand.
    context: GpuContext,
    // Recycled CPU-writable (MAP_WRITE | COPY_SRC) buffers awaiting reuse.
    upload_buffers: VecDeque<Arc<Buffer>>,
    // Recycled CPU-readable (MAP_READ | COPY_DST) buffers awaiting reuse.
    download_buffers: VecDeque<Arc<Buffer>>,
    // Fixed size in bytes of every staging buffer this manager creates.
    buffer_size: u64,
    // Per-direction cap on how many returned buffers are retained.
    max_buffers: usize,
    // Shared transfer counters, updatable through `&self` methods.
    stats: Arc<Mutex<StagingStats>>,
}
440
/// Aggregate counters describing staging-buffer traffic.
///
/// Declared `pub` (with `pub` fields) because `StagingBufferManager::stats`
/// is a `pub fn` returning this type by value — as a private type, callers
/// outside the module could neither name nor inspect the result
/// (private-type-in-public-interface). Backward compatible: widening
/// visibility breaks no existing code.
#[derive(Debug, Clone, Default)]
pub struct StagingStats {
    /// Number of uploads recorded via `record_upload`.
    pub total_uploads: usize,
    /// Number of downloads recorded via `record_download`.
    pub total_downloads: usize,
    /// Total bytes recorded as uploaded.
    pub upload_bytes: u64,
    /// Total bytes recorded as downloaded.
    pub download_bytes: u64,
    /// How many times a pooled buffer was reused instead of newly created.
    pub buffer_reuses: usize,
}
449
450impl StagingBufferManager {
451 pub fn new(context: &GpuContext, buffer_size: u64, max_buffers: usize) -> Self {
453 Self {
454 context: context.clone(),
455 upload_buffers: VecDeque::new(),
456 download_buffers: VecDeque::new(),
457 buffer_size,
458 max_buffers,
459 stats: Arc::new(Mutex::new(StagingStats::default())),
460 }
461 }
462
463 pub fn get_upload_buffer(&mut self) -> GpuResult<Arc<Buffer>> {
469 if let Some(buffer) = self.upload_buffers.pop_front() {
470 if let Ok(mut stats) = self.stats.lock() {
471 stats.buffer_reuses += 1;
472 }
473 Ok(buffer)
474 } else {
475 self.create_upload_buffer()
476 }
477 }
478
479 pub fn get_download_buffer(&mut self) -> GpuResult<Arc<Buffer>> {
485 if let Some(buffer) = self.download_buffers.pop_front() {
486 if let Ok(mut stats) = self.stats.lock() {
487 stats.buffer_reuses += 1;
488 }
489 Ok(buffer)
490 } else {
491 self.create_download_buffer()
492 }
493 }
494
495 pub fn return_upload_buffer(&mut self, buffer: Arc<Buffer>) {
497 if self.upload_buffers.len() < self.max_buffers {
498 self.upload_buffers.push_back(buffer);
499 }
500 }
501
502 pub fn return_download_buffer(&mut self, buffer: Arc<Buffer>) {
504 if self.download_buffers.len() < self.max_buffers {
505 self.download_buffers.push_back(buffer);
506 }
507 }
508
509 pub fn record_upload(&self, bytes: u64) {
511 if let Ok(mut stats) = self.stats.lock() {
512 stats.total_uploads += 1;
513 stats.upload_bytes += bytes;
514 }
515 }
516
517 pub fn record_download(&self, bytes: u64) {
519 if let Ok(mut stats) = self.stats.lock() {
520 stats.total_downloads += 1;
521 stats.download_bytes += bytes;
522 }
523 }
524
525 pub fn stats(&self) -> StagingStats {
527 self.stats.lock().map(|s| s.clone()).unwrap_or_default()
528 }
529
530 pub fn clear(&mut self) {
532 self.upload_buffers.clear();
533 self.download_buffers.clear();
534 }
535
536 fn create_upload_buffer(&self) -> GpuResult<Arc<Buffer>> {
537 let buffer = self.context.device().create_buffer(&BufferDescriptor {
538 label: Some("Staging Upload Buffer"),
539 size: self.buffer_size,
540 usage: BufferUsages::MAP_WRITE | BufferUsages::COPY_SRC,
541 mapped_at_creation: false,
542 });
543
544 Ok(Arc::new(buffer))
545 }
546
547 fn create_download_buffer(&self) -> GpuResult<Arc<Buffer>> {
548 let buffer = self.context.device().create_buffer(&BufferDescriptor {
549 label: Some("Staging Download Buffer"),
550 size: self.buffer_size,
551 usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
552 mapped_at_creation: false,
553 });
554
555 Ok(Arc::new(buffer))
556 }
557}
558
/// Tracks VRAM reservations against a fixed byte budget.
///
/// Purely bookkeeping: it never touches the GPU, it only accounts for bytes
/// callers report as allocated and freed.
pub struct VramBudgetManager {
    // Immutable total budget in bytes.
    total_budget: u64,
    // Bytes currently reserved across all live allocations.
    allocated: Arc<Mutex<u64>>,
    // allocation id -> reserved size, so `free` knows how much to release.
    allocations: Arc<Mutex<HashMap<u64, u64>>>,
    // Monotonic source for allocation ids.
    next_id: Arc<Mutex<u64>>,
}
569
570impl VramBudgetManager {
571 pub fn new(total_budget: u64) -> Self {
573 Self {
574 total_budget,
575 allocated: Arc::new(Mutex::new(0)),
576 allocations: Arc::new(Mutex::new(HashMap::new())),
577 next_id: Arc::new(Mutex::new(0)),
578 }
579 }
580
581 pub fn allocate(&self, size: u64) -> GpuResult<u64> {
587 let mut allocated = self
588 .allocated
589 .lock()
590 .map_err(|_| GpuError::internal("Lock poisoned"))?;
591
592 if *allocated + size > self.total_budget {
593 return Err(GpuError::internal(format!(
594 "VRAM budget exceeded: {} + {} > {}",
595 *allocated, size, self.total_budget
596 )));
597 }
598
599 let mut id = self
600 .next_id
601 .lock()
602 .map_err(|_| GpuError::internal("Lock poisoned"))?;
603 let allocation_id = *id;
604 *id += 1;
605
606 *allocated += size;
607
608 let mut allocations = self
609 .allocations
610 .lock()
611 .map_err(|_| GpuError::internal("Lock poisoned"))?;
612 allocations.insert(allocation_id, size);
613
614 trace!("VRAM allocated: {} bytes (total: {})", size, *allocated);
615
616 Ok(allocation_id)
617 }
618
619 pub fn free(&self, allocation_id: u64) -> GpuResult<()> {
625 let mut allocations = self
626 .allocations
627 .lock()
628 .map_err(|_| GpuError::internal("Lock poisoned"))?;
629
630 let size = allocations
631 .remove(&allocation_id)
632 .ok_or_else(|| GpuError::invalid_buffer("Invalid allocation ID"))?;
633
634 let mut allocated = self
635 .allocated
636 .lock()
637 .map_err(|_| GpuError::internal("Lock poisoned"))?;
638
639 *allocated = allocated.saturating_sub(size);
640
641 trace!("VRAM freed: {} bytes (total: {})", size, *allocated);
642
643 Ok(())
644 }
645
646 pub fn allocated(&self) -> u64 {
648 self.allocated.lock().map(|a| *a).unwrap_or(0)
649 }
650
651 pub fn budget(&self) -> u64 {
653 self.total_budget
654 }
655
656 pub fn available(&self) -> u64 {
658 self.total_budget.saturating_sub(self.allocated())
659 }
660
661 pub fn utilization(&self) -> f64 {
663 if self.total_budget == 0 {
664 return 0.0;
665 }
666 (self.allocated() as f64 / self.total_budget as f64) * 100.0
667 }
668
669 pub fn can_allocate(&self, size: u64) -> bool {
671 self.allocated() + size <= self.total_budget
672 }
673}
674
#[cfg(test)]
#[allow(clippy::panic)]
mod tests {
    use super::*;

    // These tests are fully synchronous (no `.await` anywhere), so plain
    // `#[test]` replaces `#[tokio::test]` — the latter spun up an async
    // runtime per test for no benefit.

    /// `utilization` reports bytes_in_use as a percentage of total_allocated.
    #[test]
    fn test_memory_stats() {
        let stats = MemoryStats {
            total_allocated: 1024,
            bytes_in_use: 512,
            ..Default::default()
        };

        assert_eq!(stats.utilization(), 50.0);
    }

    /// Budget accounting: reservations accumulate, over-budget requests are
    /// rejected, and frees release exactly what was reserved.
    #[test]
    fn test_vram_budget_manager() {
        let manager = VramBudgetManager::new(1024);

        let id1 = manager.allocate(512).unwrap_or_else(|e| panic!("{}", e));
        assert_eq!(manager.allocated(), 512);
        assert_eq!(manager.utilization(), 50.0);

        let id2 = manager.allocate(256).unwrap_or_else(|e| panic!("{}", e));
        assert_eq!(manager.allocated(), 768);

        // 768 + 512 would exceed the 1024-byte budget.
        assert!(manager.allocate(512).is_err());

        manager.free(id1).unwrap_or_else(|e| panic!("{}", e));
        assert_eq!(manager.allocated(), 256);

        manager.free(id2).unwrap_or_else(|e| panic!("{}", e));
        assert_eq!(manager.allocated(), 0);
    }

    /// The documented `MemoryPoolConfig` defaults stay stable.
    #[test]
    fn test_memory_pool_config() {
        let config = MemoryPoolConfig::default();
        assert_eq!(config.initial_size, 64 * 1024 * 1024);
        assert_eq!(config.max_size, 2 * 1024 * 1024 * 1024);
    }
}