// ghostflow_cuda/memory.rs

//! GPU memory management
//!
//! Efficient memory allocation and pooling for GPU tensors

use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
7
/// GPU memory pool for efficient allocation.
///
/// Freed blocks are kept in a size-keyed pool and reused by later
/// allocations of the same (power-of-two-rounded) size, avoiding
/// repeated backend alloc/free calls.
#[derive(Debug)]
pub struct GpuMemoryPool {
    /// Free (pooled) memory blocks keyed by block size in bytes
    free_blocks: Arc<Mutex<HashMap<usize, Vec<*mut u8>>>>,
    /// Currently live blocks, mapping pointer -> size in bytes
    allocated: Arc<Mutex<HashMap<*mut u8, usize>>>,
    /// Total bytes currently allocated (excludes pooled free blocks)
    total_allocated: Arc<Mutex<usize>>,
    /// High-water mark of `total_allocated`
    peak_usage: Arc<Mutex<usize>>,
}

impl GpuMemoryPool {
    /// Create a new, empty GPU memory pool.
    pub fn new() -> Self {
        Self {
            free_blocks: Arc::new(Mutex::new(HashMap::new())),
            allocated: Arc::new(Mutex::new(HashMap::new())),
            total_allocated: Arc::new(Mutex::new(0)),
            peak_usage: Arc::new(Mutex::new(0)),
        }
    }

    /// Allocate `size` bytes of GPU memory (with pooling).
    ///
    /// The request is rounded up to the next power of two so that freed
    /// blocks can be reused by similar-sized requests.
    ///
    /// # Errors
    /// Returns an error string if the backend allocation fails.
    pub fn allocate(&self, size: usize) -> Result<*mut u8, String> {
        // Round up to the nearest power of two for better pooling.
        let rounded_size = size.next_power_of_two();

        // Try to reuse a block from the pool. The lock is scoped so we
        // never hold `free_blocks` and `allocated` at the same time:
        // `free` touches the same two maps, and taking both locks here
        // (in the opposite order, as the previous version did) could
        // deadlock under concurrent allocate/free.
        let reused = {
            let mut free_blocks = self.free_blocks.lock().unwrap();
            free_blocks
                .get_mut(&rounded_size)
                .and_then(|blocks| blocks.pop())
        };

        if let Some(ptr) = reused {
            // Reused blocks must be counted again: `free` subtracted
            // their size from `total_allocated` when pooling them.
            self.track_allocation(ptr, rounded_size);
            return Ok(ptr);
        }

        // Allocate a new block from the backend.
        #[cfg(feature = "cuda")]
        let ptr = unsafe {
            let mut ptr: *mut u8 = std::ptr::null_mut();
            use crate::ffi;
            let result = ffi::cudaMalloc(
                &mut ptr as *mut *mut u8 as *mut *mut std::ffi::c_void,
                rounded_size,
            );
            if result != 0 {
                return Err(format!("CUDA malloc failed with error code: {}", result));
            }
            ptr
        };

        #[cfg(not(feature = "cuda"))]
        let ptr = {
            // CPU fallback: leak a heap allocation of exactly
            // `rounded_size` bytes; it is reclaimed in `clear`.
            vec![0u8; rounded_size].leak().as_mut_ptr()
        };

        self.track_allocation(ptr, rounded_size);
        Ok(ptr)
    }

    /// Record `ptr` as live and update the usage counters.
    fn track_allocation(&self, ptr: *mut u8, size: usize) {
        self.allocated.lock().unwrap().insert(ptr, size);

        let mut total = self.total_allocated.lock().unwrap();
        *total += size;

        let mut peak = self.peak_usage.lock().unwrap();
        if *total > *peak {
            *peak = *total;
        }
    }

    /// Free GPU memory by returning the block to the pool
    /// (the backend allocation is kept for reuse; see `clear`).
    ///
    /// # Errors
    /// Returns an error if `ptr` was not allocated by this pool.
    pub fn free(&self, ptr: *mut u8) -> Result<(), String> {
        // Remove from the live set first and release that lock before
        // touching `free_blocks`, keeping a one-lock-at-a-time
        // discipline consistent with `allocate`.
        let size = {
            let mut allocated = self.allocated.lock().unwrap();
            allocated
                .remove(&ptr)
                .ok_or_else(|| "Attempted to free unallocated pointer".to_string())?
        };

        self.free_blocks
            .lock()
            .unwrap()
            .entry(size)
            .or_default()
            .push(ptr);

        let mut total = self.total_allocated.lock().unwrap();
        // saturating_sub guards against counter corruption ever driving
        // this below zero (which would panic in debug builds).
        *total = total.saturating_sub(size);

        Ok(())
    }

    /// Clear the memory pool, actually releasing every pooled (free)
    /// block back to the backend allocator. Live allocations are left
    /// untouched.
    pub fn clear(&self) -> Result<(), String> {
        let mut free_blocks = self.free_blocks.lock().unwrap();

        for (&size, blocks) in free_blocks.iter() {
            for &ptr in blocks {
                #[cfg(feature = "cuda")]
                unsafe {
                    let _ = size; // size only needed by the CPU fallback
                    use crate::ffi;
                    ffi::cudaFree(ptr as *mut std::ffi::c_void);
                }

                #[cfg(not(feature = "cuda"))]
                unsafe {
                    // SAFETY: `ptr` came from `vec![0u8; size].leak()` in
                    // `allocate`, so it owns a heap allocation of exactly
                    // `size` bytes. Reconstructing with capacity `size`
                    // (not 0, which would never deallocate) lets the Vec
                    // free it on drop.
                    drop(Vec::from_raw_parts(ptr, 0, size));
                }
            }
        }

        free_blocks.clear();
        Ok(())
    }

    /// Get current memory usage in bytes (live allocations only).
    pub fn current_usage(&self) -> usize {
        *self.total_allocated.lock().unwrap()
    }

    /// Get peak memory usage in bytes since the pool was created.
    pub fn peak_usage(&self) -> usize {
        *self.peak_usage.lock().unwrap()
    }
}

impl Default for GpuMemoryPool {
    fn default() -> Self {
        Self::new()
    }
}
137
138impl Drop for GpuMemoryPool {
139    fn drop(&mut self) {
140        #[cfg(feature = "cuda")]
141        {
142            let _ = self.clear();
143        }
144    }
145}
146
/// GPU tensor wrapper: an f32 buffer owned by a [`GpuMemoryPool`].
pub struct GpuTensor {
    /// Pointer to the tensor's f32 buffer, obtained from `pool.allocate`
    /// (device memory under the `cuda` feature; host memory otherwise)
    ptr: *mut f32,
    /// Shape of the tensor; the element count is the product of these dims
    shape: Vec<usize>,
    /// Pool the buffer was allocated from, kept alive so `Drop` can
    /// return the buffer to it
    #[allow(dead_code)]
    pool: Arc<GpuMemoryPool>,
}
157
158impl GpuTensor {
159    /// Create a new GPU tensor
160    #[cfg(feature = "cuda")]
161    pub fn new(shape: Vec<usize>, pool: Arc<GpuMemoryPool>) -> Result<Self, String> {
162        let numel: usize = shape.iter().product();
163        let size_bytes = numel * std::mem::size_of::<f32>();
164        
165        let ptr = pool.allocate(size_bytes)? as *mut f32;
166        
167        Ok(Self { ptr, shape, pool })
168    }
169
170    /// Copy data from CPU to GPU
171    pub fn copy_from_cpu(&mut self, data: &[f32]) -> Result<(), String> {
172        let numel: usize = self.shape.iter().product();
173        if data.len() != numel {
174            return Err("Data size mismatch".to_string());
175        }
176        
177        #[cfg(feature = "cuda")]
178        unsafe {
179            use crate::ffi;
180            let result = ffi::cudaMemcpy(
181                self.ptr as *mut std::ffi::c_void,
182                data.as_ptr() as *const std::ffi::c_void,
183                numel * std::mem::size_of::<f32>(),
184                ffi::cudaMemcpyKind::cudaMemcpyHostToDevice,
185            );
186            if result != 0 {
187                return Err(format!("CUDA memcpy H2D failed with error code: {}", result));
188            }
189        }
190        
191        #[cfg(not(feature = "cuda"))]
192        unsafe {
193            // CPU fallback - just copy memory
194            std::ptr::copy_nonoverlapping(data.as_ptr(), self.ptr, numel);
195        }
196        
197        Ok(())
198    }
199
200    /// Copy data from GPU to CPU
201    pub fn copy_to_cpu(&self) -> Result<Vec<f32>, String> {
202        let numel: usize = self.shape.iter().product();
203        let mut data = vec![0.0f32; numel];
204        
205        #[cfg(feature = "cuda")]
206        unsafe {
207            use crate::ffi;
208            let result = ffi::cudaMemcpy(
209                data.as_mut_ptr() as *mut std::ffi::c_void,
210                self.ptr as *const std::ffi::c_void,
211                numel * std::mem::size_of::<f32>(),
212                ffi::cudaMemcpyKind::cudaMemcpyDeviceToHost,
213            );
214            if result != 0 {
215                return Err(format!("CUDA memcpy D2H failed with error code: {}", result));
216            }
217        }
218        
219        #[cfg(not(feature = "cuda"))]
220        unsafe {
221            // CPU fallback - just copy memory
222            std::ptr::copy_nonoverlapping(self.ptr, data.as_mut_ptr(), numel);
223        }
224        
225        Ok(data)
226    }
227
228    /// Get raw pointer
229    pub fn as_ptr(&self) -> *const f32 {
230        self.ptr
231    }
232
233    /// Get mutable raw pointer
234    pub fn as_mut_ptr(&mut self) -> *mut f32 {
235        self.ptr
236    }
237
238    /// Get shape
239    pub fn shape(&self) -> &[usize] {
240        &self.shape
241    }
242}
243
244impl Drop for GpuTensor {
245    fn drop(&mut self) {
246        #[cfg(feature = "cuda")]
247        {
248            let _ = self.pool.free(self.ptr as *mut u8);
249        }
250    }
251}
252
253/// Global GPU memory pool
254static mut GLOBAL_GPU_POOL: Option<Arc<GpuMemoryPool>> = None;
255
256/// Get or create the global GPU memory pool
257#[allow(static_mut_refs)]
258pub fn get_global_gpu_pool() -> Arc<GpuMemoryPool> {
259    unsafe {
260        if GLOBAL_GPU_POOL.is_none() {
261            GLOBAL_GPU_POOL = Some(Arc::new(GpuMemoryPool::new()));
262        }
263        GLOBAL_GPU_POOL.as_ref().unwrap().clone()
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh pool reports zero current and peak usage.
    #[test]
    fn test_memory_pool() {
        let pool = GpuMemoryPool::new();
        assert_eq!(pool.current_usage(), 0);
        assert_eq!(pool.peak_usage(), 0);
    }

    /// Allocations round up to a power of two and frees return the
    /// bytes to the counter. CPU-fallback only: with the `cuda` feature
    /// enabled this would require a physical device.
    #[cfg(not(feature = "cuda"))]
    #[test]
    fn test_allocate_rounds_and_frees() {
        let pool = GpuMemoryPool::new();
        let ptr = pool.allocate(100).expect("allocation should succeed");
        assert_eq!(pool.current_usage(), 128);
        pool.free(ptr).expect("free should succeed");
        assert_eq!(pool.current_usage(), 0);
        pool.clear().expect("clear should succeed");
    }
}