1use super::{GpuBuffer, GpuConfig};
4use anyhow::{anyhow, Result};
5use std::collections::VecDeque;
6use std::sync::{Arc, Mutex};
7
/// Pool of fixed-size GPU buffers for a single device.
///
/// Buffers are recycled through `available_buffers` instead of being
/// re-allocated; `used_memory` tracks bytes consumed against the
/// `total_memory` budget.
#[derive(Debug)]
pub struct GpuMemoryPool {
    // Device the buffers are allocated on (copied from `GpuConfig::device_id`).
    device_id: i32,
    // Recycled buffers awaiting reuse, served FIFO.
    available_buffers: Arc<Mutex<VecDeque<GpuBuffer>>>,
    // Intended to track buffers currently handed out to callers.
    allocated_buffers: Arc<Mutex<Vec<GpuBuffer>>>,
    // Byte budget for the whole pool (`GpuConfig::memory_pool_size`).
    total_memory: usize,
    // Bytes consumed by buffers created so far.
    used_memory: usize,
    // Buffer length in f32 elements (byte sizes scale by size_of::<f32>()).
    buffer_size: usize,
    // Cap on buffer count implied by the byte budget.
    max_buffers: usize,
}
19
20impl GpuMemoryPool {
21 pub fn new(config: &GpuConfig, buffer_size: usize) -> Result<Self> {
23 let max_buffers = config.memory_pool_size / (buffer_size * std::mem::size_of::<f32>());
24
25 Ok(Self {
26 device_id: config.device_id,
27 available_buffers: Arc::new(Mutex::new(VecDeque::new())),
28 allocated_buffers: Arc::new(Mutex::new(Vec::new())),
29 total_memory: config.memory_pool_size,
30 used_memory: 0,
31 buffer_size,
32 max_buffers,
33 })
34 }
35
36 pub fn get_buffer(&mut self) -> Result<GpuBuffer> {
38 {
40 let mut available = self
41 .available_buffers
42 .lock()
43 .map_err(|e| anyhow!("Failed to lock available buffers: {}", e))?;
44
45 if let Some(buffer) = available.pop_front() {
46 return Ok(buffer);
48 }
49 }
50
51 if self.allocated_buffers.lock().unwrap().len() >= self.max_buffers {
53 return Err(anyhow!("Memory pool exhausted"));
54 }
55
56 let buffer = GpuBuffer::new(self.buffer_size, self.device_id)?;
58 self.used_memory += self.buffer_size * std::mem::size_of::<f32>();
59
60 Ok(buffer)
62 }
63
64 pub fn return_buffer(&mut self, buffer: GpuBuffer) -> Result<()> {
66 {
68 let mut allocated = self
69 .allocated_buffers
70 .lock()
71 .map_err(|e| anyhow!("Failed to lock allocated buffers: {}", e))?;
72
73 allocated.retain(|b| b.ptr() != buffer.ptr());
75 }
76
77 self.available_buffers
79 .lock()
80 .map_err(|e| anyhow!("Failed to lock available buffers: {}", e))?
81 .push_back(buffer);
82
83 Ok(())
84 }
85
86 pub fn stats(&self) -> MemoryPoolStats {
88 let allocated_count = self.allocated_buffers.lock().unwrap().len();
89 let available_count = self.available_buffers.lock().unwrap().len();
90
91 MemoryPoolStats {
92 total_buffers: allocated_count + available_count,
93 allocated_buffers: allocated_count,
94 available_buffers: available_count,
95 total_memory: self.total_memory,
96 used_memory: self.used_memory,
97 buffer_size: self.buffer_size,
98 utilization: if self.total_memory > 0 {
99 self.used_memory as f64 / self.total_memory as f64
100 } else {
101 0.0
102 },
103 }
104 }
105
106 pub fn preallocate(&mut self, count: usize) -> Result<()> {
108 let effective_count = count.min(self.max_buffers);
109
110 for _ in 0..effective_count {
111 let buffer = GpuBuffer::new(self.buffer_size, self.device_id)?;
112 self.used_memory += self.buffer_size * std::mem::size_of::<f32>();
113
114 self.available_buffers
115 .lock()
116 .map_err(|e| anyhow!("Failed to lock available buffers: {}", e))?
117 .push_back(buffer);
118 }
119
120 Ok(())
121 }
122
123 pub fn clear(&mut self) {
125 self.available_buffers.lock().unwrap().clear();
127 self.allocated_buffers.lock().unwrap().clear();
128 self.used_memory = 0;
129 }
130
131 pub fn has_capacity(&self) -> bool {
133 let total_buffers = self.available_buffers.lock().unwrap().len()
134 + self.allocated_buffers.lock().unwrap().len();
135 total_buffers < self.max_buffers
136 }
137
138 pub fn memory_usage(&self) -> usize {
140 self.used_memory
141 }
142
143 pub fn utilization(&self) -> f64 {
145 if self.total_memory > 0 {
146 self.used_memory as f64 / self.total_memory as f64
147 } else {
148 0.0
149 }
150 }
151
152 pub fn defragment(&mut self) -> Result<()> {
154 let mut available = self
157 .available_buffers
158 .lock()
159 .map_err(|e| anyhow!("Failed to lock available buffers: {}", e))?;
160
161 let mut buffers: Vec<GpuBuffer> = available.drain(..).collect();
163 buffers.sort_by_key(|b| b.ptr() as usize);
164
165 for buffer in buffers {
166 available.push_back(buffer);
167 }
168
169 Ok(())
170 }
171}
172
/// Point-in-time usage snapshot produced by `GpuMemoryPool::stats`.
#[derive(Debug, Clone)]
pub struct MemoryPoolStats {
    // allocated + available.
    pub total_buffers: usize,
    // Buffers currently handed out to callers.
    pub allocated_buffers: usize,
    // Buffers sitting in the pool ready for reuse.
    pub available_buffers: usize,
    // Byte budget of the pool.
    pub total_memory: usize,
    // Bytes consumed by created buffers.
    pub used_memory: usize,
    // Buffer length in f32 elements.
    pub buffer_size: usize,
    // used_memory / total_memory, in [0, 1] (0.0 when the budget is zero).
    pub utilization: f64,
}
184
185impl MemoryPoolStats {
186 pub fn is_under_pressure(&self) -> bool {
188 self.utilization > 0.8 || self.available_buffers < 2
189 }
190
191 pub fn remaining_capacity(&self) -> usize {
193 if self.total_memory > self.used_memory {
194 let remaining_memory = self.total_memory - self.used_memory;
195 remaining_memory / (self.buffer_size * std::mem::size_of::<f32>())
196 } else {
197 0
198 }
199 }
200
201 pub fn print(&self) {
203 println!("GPU Memory Pool Statistics:");
204 println!(" Total buffers: {}", self.total_buffers);
205 println!(
206 " Allocated: {}, Available: {}",
207 self.allocated_buffers, self.available_buffers
208 );
209 println!(
210 " Memory usage: {:.2} MB / {:.2} MB ({:.1}%)",
211 self.used_memory as f64 / 1024.0 / 1024.0,
212 self.total_memory as f64 / 1024.0 / 1024.0,
213 self.utilization * 100.0
214 );
215 println!(
216 " Buffer size: {:.2} KB",
217 self.buffer_size as f64 * 4.0 / 1024.0
218 );
219 println!(
220 " Remaining capacity: {} buffers",
221 self.remaining_capacity()
222 );
223
224 if self.is_under_pressure() {
225 println!(" ⚠️ Memory pool is under pressure!");
226 }
227 }
228}
229
/// Multi-size pool: one `GpuMemoryPool` per configured buffer size, all on
/// the same device.
#[derive(Debug)]
pub struct AdvancedGpuMemoryPool {
    // One pool per entry of `buffer_sizes`, index-aligned with it.
    pools: Vec<GpuMemoryPool>,
    // Buffer size (in f32 elements) served by the pool at the same index.
    buffer_sizes: Vec<usize>,
    // Device all pools allocate on.
    device_id: i32,
}
237
238impl AdvancedGpuMemoryPool {
239 pub fn new(config: &GpuConfig, buffer_sizes: Vec<usize>) -> Result<Self> {
241 let mut pools = Vec::new();
242
243 for &size in &buffer_sizes {
244 let pool = GpuMemoryPool::new(config, size)?;
245 pools.push(pool);
246 }
247
248 Ok(Self {
249 pools,
250 buffer_sizes: buffer_sizes.clone(),
251 device_id: config.device_id,
252 })
253 }
254
255 pub fn get_buffer(&mut self, required_size: usize) -> Result<GpuBuffer> {
257 let pool_index = self
259 .buffer_sizes
260 .iter()
261 .position(|&size| size >= required_size)
262 .ok_or_else(|| anyhow!("No buffer size large enough for request"))?;
263
264 self.pools[pool_index].get_buffer()
265 }
266
267 pub fn return_buffer(&mut self, buffer: GpuBuffer) -> Result<()> {
269 let buffer_size = buffer.size();
270
271 let pool_index = self
273 .buffer_sizes
274 .iter()
275 .position(|&size| size == buffer_size)
276 .ok_or_else(|| anyhow!("Buffer size does not match any pool"))?;
277
278 self.pools[pool_index].return_buffer(buffer)
279 }
280
281 pub fn combined_stats(&self) -> AdvancedMemoryPoolStats {
283 let mut total_buffers = 0;
284 let mut total_allocated = 0;
285 let mut total_available = 0;
286 let mut total_memory = 0;
287 let mut total_used = 0;
288 let mut pool_stats = Vec::new();
289
290 for pool in &self.pools {
291 let stats = pool.stats();
292 total_buffers += stats.total_buffers;
293 total_allocated += stats.allocated_buffers;
294 total_available += stats.available_buffers;
295 total_memory += stats.total_memory;
296 total_used += stats.used_memory;
297 pool_stats.push(stats);
298 }
299
300 AdvancedMemoryPoolStats {
301 pool_stats,
302 total_buffers,
303 total_allocated,
304 total_available,
305 total_memory,
306 total_used,
307 utilization: if total_memory > 0 {
308 total_used as f64 / total_memory as f64
309 } else {
310 0.0
311 },
312 }
313 }
314
315 pub fn preallocate_all(&mut self, buffers_per_pool: usize) -> Result<()> {
317 for pool in &mut self.pools {
318 pool.preallocate(buffers_per_pool)?;
319 }
320 Ok(())
321 }
322}
323
/// Combined snapshot across every pool of an `AdvancedGpuMemoryPool`.
#[derive(Debug, Clone)]
pub struct AdvancedMemoryPoolStats {
    // Per-pool snapshots, in pool order.
    pub pool_stats: Vec<MemoryPoolStats>,
    // Sums of the corresponding per-pool counters.
    pub total_buffers: usize,
    pub total_allocated: usize,
    pub total_available: usize,
    pub total_memory: usize,
    pub total_used: usize,
    // total_used / total_memory, in [0, 1] (0.0 when the budget is zero).
    pub utilization: f64,
}
335
336impl AdvancedMemoryPoolStats {
337 pub fn print_detailed(&self) {
339 println!("Advanced GPU Memory Pool Statistics:");
340 println!(
341 " Overall: {} buffers, {:.1}% utilization",
342 self.total_buffers,
343 self.utilization * 100.0
344 );
345 println!(
346 " Total memory: {:.2} MB",
347 self.total_memory as f64 / 1024.0 / 1024.0
348 );
349
350 for (i, stats) in self.pool_stats.iter().enumerate() {
351 println!(
352 " Pool {}: {:.2} KB buffers, {} total, {:.1}% util",
353 i,
354 stats.buffer_size as f64 * 4.0 / 1024.0,
355 stats.total_buffers,
356 stats.utilization * 100.0
357 );
358 }
359 }
360}