1use serde::{Deserialize, Serialize};
2use std::sync::Mutex;
3
4use crate::{GpuError, GpuMemoryLocation, GpuUsage, handles::GpuBufferHandle};
5
6pub trait BufferPool: Send + Sync {
8 fn alloc(
9 &self,
10 size_bytes: u64,
11 usage: GpuUsage,
12 location: GpuMemoryLocation,
13 ) -> Result<GpuBufferHandle, GpuError>;
14 fn free(&self, handle: GpuBufferHandle);
15}
16
17pub struct SimpleBufferPool {
19 buckets: Mutex<Vec<Vec<GpuBufferHandle>>>,
20}
21
22impl SimpleBufferPool {
23 pub fn new() -> Self {
24 Self {
25 buckets: Mutex::new(vec![Vec::new(); 16]), }
27 }
28}
29
30impl Default for SimpleBufferPool {
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36impl BufferPool for SimpleBufferPool {
37 fn alloc(
38 &self,
39 size_bytes: u64,
40 usage: GpuUsage,
41 location: GpuMemoryLocation,
42 ) -> Result<GpuBufferHandle, GpuError> {
43 if usage.is_empty() {
44 return Err(GpuError::Unsupported);
45 }
46 let bucket_idx = bucket_for(size_bytes);
47 let mut buckets = self.buckets.lock().expect("buffer pool lock poisoned");
48 if let Some(bucket) = buckets.get_mut(bucket_idx)
49 && let Some((idx, _)) = bucket
50 .iter()
51 .enumerate()
52 .find(|(_, h)| h.usage == usage && h.location == location)
53 {
54 let handle = bucket.swap_remove(idx);
55 return Ok(handle);
56 }
57 Ok(GpuBufferHandle::new(size_bytes, location, usage))
58 }
59
60 fn free(&self, handle: GpuBufferHandle) {
61 let idx = bucket_for(handle.size_bytes);
62 let mut buckets = self.buckets.lock().expect("buffer pool lock poisoned");
63 if let Some(bucket) = buckets.get_mut(idx) {
64 bucket.push(handle);
65 }
66 }
67}
68
/// Maps a byte size to its size-class bucket index.
///
/// The index is the smallest `k` with `2^k >= size` (0 for sizes 0 and 1). It is capped at 15,
/// so every size above 16 KiB (2^14 bytes) shares the last bucket.
fn bucket_for(size: u64) -> usize {
    const LAST_BUCKET: u32 = 15;
    let exponent = match size {
        0 | 1 => 0,
        // Bit length of `n - 1` equals ceil(log2(n)) for n >= 2.
        n => u64::BITS - (n - 1).leading_zeros(),
    };
    exponent.min(LAST_BUCKET) as usize
}
77}
78
79#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
81pub struct TransferStats {
82 pub bytes_uploaded: u64,
83 pub bytes_downloaded: u64,
84}
85
86impl TransferStats {
87 pub fn record_upload(&mut self, bytes: u64) {
88 self.bytes_uploaded = self.bytes_uploaded.saturating_add(bytes);
89 }
90
91 pub fn record_download(&mut self, bytes: u64) {
92 self.bytes_downloaded = self.bytes_downloaded.saturating_add(bytes);
93 }
94
95 pub fn take(&mut self) -> TransferStats {
96 std::mem::take(self)
97 }
98}