cu29_runtime/
pool.rs

1use arrayvec::ArrayString;
2use bincode::de::Decoder;
3use bincode::enc::Encoder;
4use bincode::error::{DecodeError, EncodeError};
5use bincode::{Decode, Encode};
6use cu29_traits::CuResult;
7use object_pool::{Pool, ReusableOwned};
8use smallvec::SmallVec;
9use std::alloc::{alloc, dealloc, Layout};
10use std::collections::HashMap;
11use std::fmt::Debug;
12use std::ops::{Deref, DerefMut};
13use std::sync::{Arc, Mutex, OnceLock};
14
15type PoolID = ArrayString<64>;
16
17/// Trait for a Pool to exposed to be monitored by the monitoring API.
18pub trait PoolMonitor: Send + Sync {
19    /// A unique and descriptive identifier for the pool.
20    fn id(&self) -> PoolID;
21
22    /// Number of buffer slots left in the pool.
23    fn space_left(&self) -> usize;
24
25    /// Total size of the pool in number of buffers.
26    fn total_size(&self) -> usize;
27
28    /// Size of one buffer
29    fn buffer_size(&self) -> usize;
30}
31
32static POOL_REGISTRY: OnceLock<Mutex<HashMap<String, Arc<dyn PoolMonitor>>>> = OnceLock::new();
33const MAX_POOLS: usize = 16;
34
35// Register a pool to the global registry.
36fn register_pool(pool: Arc<dyn PoolMonitor>) {
37    POOL_REGISTRY
38        .get_or_init(|| Mutex::new(HashMap::new()))
39        .lock()
40        .unwrap()
41        .insert(pool.id().to_string(), pool);
42}
43
44type PoolStats = (PoolID, usize, usize, usize);
45
46/// Get the list of pools and their statistics.
47/// We use SmallVec here to avoid heap allocations while the stack is running.
48pub fn pools_statistics() -> SmallVec<[PoolStats; MAX_POOLS]> {
49    let registry = POOL_REGISTRY.get().unwrap().lock().unwrap();
50    let mut result = SmallVec::with_capacity(MAX_POOLS);
51    for pool in registry.values() {
52        result.push((
53            pool.id(),
54            pool.space_left(),
55            pool.total_size(),
56            pool.buffer_size(),
57        ));
58    }
59    result
60}
61
62/// Basic Type that can be used in a buffer in a CuPool.
63pub trait ElementType:
64    Default + Sized + Copy + Encode + Decode + Debug + Unpin + Send + Sync
65{
66}
67
68/// Blanket implementation for all types that are Sized, Copy, Encode, Decode and Debug.
69impl<T> ElementType for T where
70    T: Default + Sized + Copy + Encode + Decode + Debug + Unpin + Send + Sync
71{
72}
73
74pub trait ArrayLike: Deref<Target = [Self::Element]> + DerefMut + Debug + Sync + Send {
75    type Element: ElementType;
76}
77
78/// A Handle to a Buffer.
79/// For onboard usages, the buffer should be Pooled (ie, coming from a preallocated pool).
80/// The Detached version is for offline usages where we don't really need a pool to deserialize them.
81pub enum CuHandleInner<T: Debug> {
82    Pooled(ReusableOwned<T>),
83    Detached(T), // Should only be used in offline cases (e.g. deserialization)
84}
85
86impl<T> Debug for CuHandleInner<T>
87where
88    T: Debug,
89{
90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91        match self {
92            CuHandleInner::Pooled(r) => {
93                write!(f, "Pooled: {:?}", r.deref())
94            }
95            CuHandleInner::Detached(r) => write!(f, "Detached: {:?}", r),
96        }
97    }
98}
99
100impl<T: ArrayLike> Deref for CuHandleInner<T> {
101    type Target = [T::Element];
102
103    fn deref(&self) -> &Self::Target {
104        match self {
105            CuHandleInner::Pooled(pooled) => pooled,
106            CuHandleInner::Detached(detached) => detached,
107        }
108    }
109}
110
111impl<T: ArrayLike> DerefMut for CuHandleInner<T> {
112    fn deref_mut(&mut self) -> &mut Self::Target {
113        match self {
114            CuHandleInner::Pooled(pooled) => pooled.deref_mut(),
115            CuHandleInner::Detached(detached) => detached,
116        }
117    }
118}
119
120/// A shareable handle to an Array coming from a pool (either host or device).
121#[derive(Clone, Debug)]
122pub struct CuHandle<T: ArrayLike>(Arc<Mutex<CuHandleInner<T>>>);
123
124impl<T: ArrayLike> Deref for CuHandle<T> {
125    type Target = Arc<Mutex<CuHandleInner<T>>>;
126
127    fn deref(&self) -> &Self::Target {
128        &self.0
129    }
130}
131
132impl<T: ArrayLike> CuHandle<T> {
133    /// Create a new CuHandle not part of a Pool (not for onboard usages, use pools instead)
134    pub fn new_detached(inner: T) -> Self {
135        CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(inner))))
136    }
137
138    /// Safely access the inner value, applying a closure to it.
139    pub fn with_inner<R>(&self, f: impl FnOnce(&CuHandleInner<T>) -> R) -> R {
140        let lock = self.lock().unwrap();
141        f(&*lock)
142    }
143
144    /// Mutably access the inner value, applying a closure to it.
145    pub fn with_inner_mut<R>(&self, f: impl FnOnce(&mut CuHandleInner<T>) -> R) -> R {
146        let mut lock = self.lock().unwrap();
147        f(&mut *lock)
148    }
149}
150
151impl<T: ArrayLike> Encode for CuHandle<T>
152where
153    <T as ArrayLike>::Element: 'static,
154{
155    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
156        let inner = self.lock().unwrap();
157        match inner.deref() {
158            CuHandleInner::Pooled(pooled) => pooled.encode(encoder),
159            CuHandleInner::Detached(detached) => detached.encode(encoder),
160        }
161    }
162}
163
164impl<T: ArrayLike> Default for CuHandle<T> {
165    fn default() -> Self {
166        panic!("Cannot create a default CuHandle")
167    }
168}
169
170impl<U: ElementType + 'static> Decode for CuHandle<Vec<U>> {
171    fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, DecodeError> {
172        let vec: Vec<U> = Vec::decode(decoder)?;
173        Ok(CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(vec)))))
174    }
175}
176
177/// A CuPool is a pool of buffers that can be shared between different parts of the code.
178/// Handles can be stored locally in the tasks and shared between them.
179pub trait CuPool<T: ArrayLike>: PoolMonitor {
180    /// Acquire a buffer from the pool.
181    fn acquire(&self) -> Option<CuHandle<T>>;
182
183    /// Copy data from a handle to a new handle from the pool.
184    fn copy_from<O>(&self, from: &mut CuHandle<O>) -> CuHandle<T>
185    where
186        O: ArrayLike<Element = T::Element>;
187}
188
189/// A device memory pool can copy data from a device to a host memory pool on top.
190pub trait DeviceCuPool<T: ArrayLike>: CuPool<T> {
191    /// Takes a handle to a device buffer and copies it into a host buffer pool.
192    /// It returns a new handle from the host pool with the data from the device handle given.
193    fn copy_to_host_pool<O>(
194        &self,
195        from_device_handle: &CuHandle<T>,
196        to_host_handle: &mut CuHandle<O>,
197    ) -> CuResult<()>
198    where
199        O: ArrayLike<Element = T::Element>;
200}
201
202/// A pool of host memory buffers.
203pub struct CuHostMemoryPool<T> {
204    /// Underlying pool of host buffers.
205    // Being an Arc is a requirement of try_pull_owned() so buffers can refer back to the pool.
206    id: PoolID,
207    pool: Arc<Pool<T>>,
208    size: usize,
209    buffer_size: usize,
210}
211
212impl<T: ArrayLike + 'static> CuHostMemoryPool<T> {
213    pub fn new<F>(id: &str, size: usize, buffer_initializer: F) -> CuResult<Arc<Self>>
214    where
215        F: Fn() -> T,
216    {
217        let pool = Arc::new(Pool::new(size, buffer_initializer));
218        let buffer_size = pool.try_pull().unwrap().len() * size_of::<T::Element>();
219
220        let og = Self {
221            id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
222            pool,
223            size,
224            buffer_size,
225        };
226        let og = Arc::new(og);
227        register_pool(og.clone());
228        Ok(og)
229    }
230}
231
232impl<T: ArrayLike> PoolMonitor for CuHostMemoryPool<T> {
233    fn id(&self) -> PoolID {
234        self.id
235    }
236
237    fn space_left(&self) -> usize {
238        self.pool.len()
239    }
240
241    fn total_size(&self) -> usize {
242        self.size
243    }
244
245    fn buffer_size(&self) -> usize {
246        self.buffer_size
247    }
248}
249
250impl<T: ArrayLike> CuPool<T> for CuHostMemoryPool<T> {
251    fn acquire(&self) -> Option<CuHandle<T>> {
252        let owned_object = self.pool.try_pull_owned(); // Use the owned version
253
254        owned_object.map(|reusable| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(reusable)))))
255    }
256
257    fn copy_from<O: ArrayLike<Element = T::Element>>(&self, from: &mut CuHandle<O>) -> CuHandle<T> {
258        let to_handle = self.acquire().expect("No available buffers in the pool");
259
260        match from.lock().unwrap().deref() {
261            CuHandleInner::Detached(source) => match to_handle.lock().unwrap().deref_mut() {
262                CuHandleInner::Detached(destination) => {
263                    destination.copy_from_slice(source);
264                }
265                CuHandleInner::Pooled(destination) => {
266                    destination.copy_from_slice(source);
267                }
268            },
269            CuHandleInner::Pooled(source) => match to_handle.lock().unwrap().deref_mut() {
270                CuHandleInner::Detached(destination) => {
271                    destination.copy_from_slice(source);
272                }
273                CuHandleInner::Pooled(destination) => {
274                    destination.copy_from_slice(source);
275                }
276            },
277        }
278        to_handle
279    }
280}
281
282impl<E: ElementType + 'static> ArrayLike for Vec<E> {
283    type Element = E;
284}
285
286#[cfg(all(feature = "cuda", not(target_os = "macos")))]
287mod cuda {
288    use super::*;
289    use cu29_traits::CuError;
290    use cudarc::driver::{CudaDevice, CudaSlice, DeviceRepr, ValidAsZeroBits};
291    use std::sync::Arc;
292
293    #[derive(Debug)]
294    pub struct CudaSliceWrapper<E>(CudaSlice<E>);
295
296    impl<E> Deref for CudaSliceWrapper<E>
297    where
298        E: ElementType,
299    {
300        type Target = [E];
301
302        fn deref(&self) -> &Self::Target {
303            // Implement logic to return a slice
304            panic!("You need to copy data to host memory pool before accessing it.");
305        }
306    }
307
308    impl<E> DerefMut for CudaSliceWrapper<E>
309    where
310        E: ElementType,
311    {
312        fn deref_mut(&mut self) -> &mut Self::Target {
313            panic!("You need to copy data to host memory pool before accessing it.");
314        }
315    }
316
317    impl<E: ElementType> ArrayLike for CudaSliceWrapper<E> {
318        type Element = E;
319    }
320
321    impl<E> CudaSliceWrapper<E> {
322        pub fn as_cuda_slice(&self) -> &CudaSlice<E> {
323            &self.0
324        }
325
326        pub fn as_cuda_slice_mut(&mut self) -> &mut CudaSlice<E> {
327            &mut self.0
328        }
329    }
330
331    /// A pool of CUDA memory buffers.
332    pub struct CuCudaPool<E>
333    where
334        E: ElementType + ValidAsZeroBits + DeviceRepr + Unpin,
335    {
336        id: PoolID,
337        device: Arc<CudaDevice>,
338        pool: Arc<Pool<CudaSliceWrapper<E>>>,
339        nb_buffers: usize,
340        nb_element_per_buffer: usize,
341    }
342
343    impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuCudaPool<E> {
344        #[allow(dead_code)]
345        pub fn new(
346            id: &'static str,
347            device: Arc<CudaDevice>,
348            nb_buffers: usize,
349            nb_element_per_buffer: usize,
350        ) -> CuResult<Self> {
351            let pool = (0..nb_buffers)
352                .map(|_| {
353                    device
354                        .alloc_zeros(nb_element_per_buffer)
355                        .map(CudaSliceWrapper)
356                        .map_err(|_| "Failed to allocate device memory")
357                })
358                .collect::<Result<Vec<_>, _>>()?;
359
360            Ok(Self {
361                id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
362                device: device.clone(),
363                pool: Arc::new(Pool::from_vec(pool)),
364                nb_buffers,
365                nb_element_per_buffer,
366            })
367        }
368    }
369
370    impl<E> PoolMonitor for CuCudaPool<E>
371    where
372        E: DeviceRepr + ElementType + ValidAsZeroBits,
373    {
374        fn id(&self) -> PoolID {
375            self.id
376        }
377
378        fn space_left(&self) -> usize {
379            self.pool.len()
380        }
381
382        fn total_size(&self) -> usize {
383            self.nb_buffers
384        }
385
386        fn buffer_size(&self) -> usize {
387            self.nb_element_per_buffer * size_of::<E>()
388        }
389    }
390
391    impl<E> CuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
392    where
393        E: DeviceRepr + ElementType + ValidAsZeroBits,
394    {
395        fn acquire(&self) -> Option<CuHandle<CudaSliceWrapper<E>>> {
396            self.pool
397                .try_pull_owned()
398                .map(|x| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(x)))))
399        }
400
401        fn copy_from<O>(&self, from_handle: &mut CuHandle<O>) -> CuHandle<CudaSliceWrapper<E>>
402        where
403            O: ArrayLike<Element = E>,
404        {
405            let to_handle = self.acquire().expect("No available buffers in the pool");
406
407            match from_handle.lock().unwrap().deref() {
408                CuHandleInner::Detached(from) => match to_handle.lock().unwrap().deref_mut() {
409                    CuHandleInner::Detached(CudaSliceWrapper(to)) => {
410                        self.device
411                            .htod_sync_copy_into(from, to)
412                            .expect("Failed to copy data to device");
413                    }
414                    CuHandleInner::Pooled(to) => {
415                        self.device
416                            .htod_sync_copy_into(from, to.as_cuda_slice_mut())
417                            .expect("Failed to copy data to device");
418                    }
419                },
420                CuHandleInner::Pooled(from) => match to_handle.lock().unwrap().deref_mut() {
421                    CuHandleInner::Detached(CudaSliceWrapper(to)) => {
422                        self.device
423                            .htod_sync_copy_into(from, to)
424                            .expect("Failed to copy data to device");
425                    }
426                    CuHandleInner::Pooled(to) => {
427                        self.device
428                            .htod_sync_copy_into(from, to.as_cuda_slice_mut())
429                            .expect("Failed to copy data to device");
430                    }
431                },
432            }
433            to_handle
434        }
435    }
436
437    impl<E> DeviceCuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
438    where
439        E: ElementType + ValidAsZeroBits + DeviceRepr,
440    {
441        /// Copy from device to host
442        fn copy_to_host_pool<O>(
443            &self,
444            device_handle: &CuHandle<CudaSliceWrapper<E>>,
445            host_handle: &mut CuHandle<O>,
446        ) -> Result<(), CuError>
447        where
448            O: ArrayLike<Element = E>,
449        {
450            match device_handle.lock().unwrap().deref() {
451                CuHandleInner::Pooled(source) => match host_handle.lock().unwrap().deref_mut() {
452                    CuHandleInner::Pooled(ref mut destination) => {
453                        self.device
454                            .dtoh_sync_copy_into(source.as_cuda_slice(), destination)
455                            .expect("Failed to copy data to device");
456                    }
457                    CuHandleInner::Detached(ref mut destination) => {
458                        self.device
459                            .dtoh_sync_copy_into(source.as_cuda_slice(), destination)
460                            .expect("Failed to copy data to device");
461                    }
462                },
463                CuHandleInner::Detached(source) => match host_handle.lock().unwrap().deref_mut() {
464                    CuHandleInner::Pooled(ref mut destination) => {
465                        self.device
466                            .dtoh_sync_copy_into(source.as_cuda_slice(), destination)
467                            .expect("Failed to copy data to device");
468                    }
469                    CuHandleInner::Detached(ref mut destination) => {
470                        self.device
471                            .dtoh_sync_copy_into(source.as_cuda_slice(), destination)
472                            .expect("Failed to copy data to device");
473                    }
474                },
475            }
476            Ok(())
477        }
478    }
479}
480
481#[derive(Debug)]
482/// A buffer that is aligned to a specific size with the Element of type E.
483pub struct AlignedBuffer<E: ElementType> {
484    ptr: *mut E,
485    size: usize,
486    layout: Layout,
487}
488
489impl<E: ElementType> AlignedBuffer<E> {
490    pub fn new(num_elements: usize, alignment: usize) -> Self {
491        let layout = Layout::from_size_align(num_elements * size_of::<E>(), alignment).unwrap();
492        let ptr = unsafe { alloc(layout) as *mut E };
493        if ptr.is_null() {
494            panic!("Failed to allocate memory");
495        }
496        Self {
497            ptr,
498            size: num_elements,
499            layout,
500        }
501    }
502}
503
504impl<E: ElementType> Deref for AlignedBuffer<E> {
505    type Target = [E];
506
507    fn deref(&self) -> &Self::Target {
508        unsafe { std::slice::from_raw_parts(self.ptr, self.size) }
509    }
510}
511
512impl<E: ElementType> DerefMut for AlignedBuffer<E> {
513    fn deref_mut(&mut self) -> &mut Self::Target {
514        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.size) }
515    }
516}
517
518impl<E: ElementType> Drop for AlignedBuffer<E> {
519    fn drop(&mut self) {
520        if !self.ptr.is_null() {
521            unsafe {
522                dealloc(self.ptr as *mut u8, self.layout);
523            }
524        }
525    }
526}
527
528#[cfg(test)]
529mod tests {
530    use super::*;
531    #[cfg(all(feature = "cuda", not(target_os = "macos")))]
532    use crate::pool::cuda::CuCudaPool;
533    use std::cell::RefCell;
534
535    #[test]
536    fn test_pool() {
537        let objs = RefCell::new(vec![vec![1], vec![2], vec![3]]);
538        let holding = objs.borrow().clone();
539        let objs_as_slices = holding.iter().map(|x| x.as_slice()).collect::<Vec<_>>();
540        let pool = CuHostMemoryPool::new("mytestcudapool", 3, || objs.borrow_mut().pop().unwrap())
541            .unwrap();
542
543        let obj1 = pool.acquire().unwrap();
544        {
545            let obj2 = pool.acquire().unwrap();
546            assert!(objs_as_slices.contains(&obj1.lock().unwrap().deref().deref()));
547            assert!(objs_as_slices.contains(&obj2.lock().unwrap().deref().deref()));
548            assert_eq!(pool.space_left(), 1);
549        }
550        assert_eq!(pool.space_left(), 2);
551
552        let obj3 = pool.acquire().unwrap();
553        assert!(objs_as_slices.contains(&obj3.lock().unwrap().deref().deref()));
554
555        assert_eq!(pool.space_left(), 1);
556
557        let _obj4 = pool.acquire().unwrap();
558        assert_eq!(pool.space_left(), 0);
559
560        let obj5 = pool.acquire();
561        assert!(obj5.is_none());
562    }
563
564    #[cfg(all(feature = "cuda", not(target_os = "macos")))]
565    #[test]
566    #[ignore] // Can only be executed if a real CUDA device is present
567    fn test_cuda_pool() {
568        use cudarc::driver::CudaDevice;
569        let device = CudaDevice::new(0).unwrap();
570        let pool = CuCudaPool::<f32>::new("mytestcudapool", device, 3, 1).unwrap();
571
572        let _obj1 = pool.acquire().unwrap();
573
574        {
575            let _obj2 = pool.acquire().unwrap();
576            assert_eq!(pool.space_left(), 1);
577        }
578        assert_eq!(pool.space_left(), 2);
579
580        let _obj3 = pool.acquire().unwrap();
581
582        assert_eq!(pool.space_left(), 1);
583
584        let _obj4 = pool.acquire().unwrap();
585        assert_eq!(pool.space_left(), 0);
586
587        let obj5 = pool.acquire();
588        assert!(obj5.is_none());
589    }
590
591    #[cfg(all(feature = "cuda", not(target_os = "macos")))]
592    #[test]
593    #[ignore] // Can only be executed if a real CUDA device is present
594    fn test_copy_roundtrip() {
595        use cudarc::driver::CudaDevice;
596        let device = CudaDevice::new(0).unwrap();
597        let host_pool = CuHostMemoryPool::new("mytesthostpool", 3, || vec![0.0; 1]).unwrap();
598        let cuda_pool = CuCudaPool::<f32>::new("mytestcudapool", device, 3, 1).unwrap();
599
600        let cuda_handle = {
601            let mut initial_handle = host_pool.acquire().unwrap();
602            {
603                let mut inner_initial_handle = initial_handle.lock().unwrap();
604                if let CuHandleInner::Pooled(ref mut pooled) = *inner_initial_handle {
605                    pooled[0] = 42.0;
606                } else {
607                    panic!();
608                }
609            }
610
611            // send that to the GPU
612            cuda_pool.copy_from(&mut initial_handle)
613        };
614
615        // get it back to the host
616        let mut final_handle = host_pool.acquire().unwrap();
617        cuda_pool
618            .copy_to_host_pool(&cuda_handle, &mut final_handle)
619            .unwrap();
620
621        let value = final_handle.lock().unwrap().deref().deref()[0];
622        assert_eq!(value, 42.0);
623    }
624}