cu29_runtime/
pool.rs

1use arrayvec::ArrayString;
2use bincode::de::Decoder;
3use bincode::enc::Encoder;
4use bincode::error::{DecodeError, EncodeError};
5use bincode::{Decode, Encode};
6use cu29_traits::CuResult;
7use object_pool::{Pool, ReusableOwned};
8use smallvec::SmallVec;
9use std::alloc::{alloc, dealloc, Layout};
10use std::collections::HashMap;
11use std::fmt::Debug;
12use std::ops::{Deref, DerefMut};
13use std::sync::{Arc, Mutex, OnceLock};
14
/// Identifier of a pool: an inline, fixed-capacity string of at most 64 bytes.
type PoolID = ArrayString<64>;
16
/// Trait a pool implements so that it can be observed through the monitoring API.
pub trait PoolMonitor: Send + Sync {
    /// A unique and descriptive identifier for the pool.
    fn id(&self) -> PoolID;

    /// Number of buffer slots left in the pool.
    fn space_left(&self) -> usize;

    /// Total size of the pool in number of buffers.
    fn total_size(&self) -> usize;

    /// Size of one buffer, in bytes.
    fn buffer_size(&self) -> usize;
}
31
/// Process-wide registry of monitored pools, keyed by the string form of the pool id.
/// It is created lazily by the first call to `register_pool`.
static POOL_REGISTRY: OnceLock<Mutex<HashMap<String, Arc<dyn PoolMonitor>>>> = OnceLock::new();
/// Inline capacity of the vector returned by `pools_statistics`.
/// NOTE(review): nothing visible in this file caps the number of registered pools at this value.
const MAX_POOLS: usize = 16;
34
35// Register a pool to the global registry.
36fn register_pool(pool: Arc<dyn PoolMonitor>) {
37    POOL_REGISTRY
38        .get_or_init(|| Mutex::new(HashMap::new()))
39        .lock()
40        .unwrap()
41        .insert(pool.id().to_string(), pool);
42}
43
/// Statistics of one pool: `(id, space_left, total_size, buffer_size)`.
type PoolStats = (PoolID, usize, usize, usize);
45
46/// Get the list of pools and their statistics.
47/// We use SmallVec here to avoid heap allocations while the stack is running.
48pub fn pools_statistics() -> SmallVec<[PoolStats; MAX_POOLS]> {
49    let registry = POOL_REGISTRY.get().unwrap().lock().unwrap();
50    let mut result = SmallVec::with_capacity(MAX_POOLS);
51    for pool in registry.values() {
52        result.push((
53            pool.id(),
54            pool.space_left(),
55            pool.total_size(),
56            pool.buffer_size(),
57        ));
58    }
59    result
60}
61
/// Basic type that can be used as the element of a buffer in a CuPool.
///
/// The two methods have the same shape as bincode's `Encode::encode` and `Decode::decode`.
pub trait ElementType: Default + Sized + Copy + Debug + Unpin + Send + Sync {
    /// Serializes this element into `encoder`.
    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError>;
    /// Deserializes one element from `decoder`.
    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError>;
}
67
68/// Blanket implementation for all types that are Sized, Copy, Encode, Decode and Debug.
69impl<T> ElementType for T
70where
71    T: Default + Sized + Copy + Debug + Unpin + Send + Sync,
72    T: Encode,
73    T: Decode<()>,
74{
75    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
76        self.encode(encoder)
77    }
78
79    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
80        Self::decode(decoder)
81    }
82}
83
/// A buffer type that dereferences (immutably and mutably) to a slice of `Element`s.
pub trait ArrayLike: Deref<Target = [Self::Element]> + DerefMut + Debug + Sync + Send {
    /// Type of the individual elements stored in the buffer.
    type Element: ElementType;
}
87
/// A handle to a buffer.
/// For onboard usage, the buffer should be `Pooled` (i.e. coming from a preallocated pool).
/// The `Detached` version is for offline usage where we don't really need a pool to deserialize it.
pub enum CuHandleInner<T: Debug> {
    /// Buffer that belongs to a pool, wrapped in object_pool's `ReusableOwned`.
    Pooled(ReusableOwned<T>),
    /// Buffer that is not tied to any pool.
    Detached(T), // Should only be used in offline cases (e.g. deserialization)
}
95
96impl<T> Debug for CuHandleInner<T>
97where
98    T: Debug,
99{
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        match self {
102            CuHandleInner::Pooled(r) => {
103                write!(f, "Pooled: {:?}", r.deref())
104            }
105            CuHandleInner::Detached(r) => write!(f, "Detached: {:?}", r),
106        }
107    }
108}
109
110impl<T: ArrayLike> Deref for CuHandleInner<T> {
111    type Target = [T::Element];
112
113    fn deref(&self) -> &Self::Target {
114        match self {
115            CuHandleInner::Pooled(pooled) => pooled,
116            CuHandleInner::Detached(detached) => detached,
117        }
118    }
119}
120
121impl<T: ArrayLike> DerefMut for CuHandleInner<T> {
122    fn deref_mut(&mut self) -> &mut Self::Target {
123        match self {
124            CuHandleInner::Pooled(pooled) => pooled.deref_mut(),
125            CuHandleInner::Detached(detached) => detached,
126        }
127    }
128}
129
/// A shareable handle to an Array coming from a pool (either host or device).
/// Cloning the handle clones the `Arc`, so all clones refer to the same mutex-protected buffer.
#[derive(Clone, Debug)]
pub struct CuHandle<T: ArrayLike>(Arc<Mutex<CuHandleInner<T>>>);
133
134impl<T: ArrayLike> Deref for CuHandle<T> {
135    type Target = Arc<Mutex<CuHandleInner<T>>>;
136
137    fn deref(&self) -> &Self::Target {
138        &self.0
139    }
140}
141
142impl<T: ArrayLike> CuHandle<T> {
143    /// Create a new CuHandle not part of a Pool (not for onboard usages, use pools instead)
144    pub fn new_detached(inner: T) -> Self {
145        CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(inner))))
146    }
147
148    /// Safely access the inner value, applying a closure to it.
149    pub fn with_inner<R>(&self, f: impl FnOnce(&CuHandleInner<T>) -> R) -> R {
150        let lock = self.lock().unwrap();
151        f(&*lock)
152    }
153
154    /// Mutably access the inner value, applying a closure to it.
155    pub fn with_inner_mut<R>(&self, f: impl FnOnce(&mut CuHandleInner<T>) -> R) -> R {
156        let mut lock = self.lock().unwrap();
157        f(&mut *lock)
158    }
159}
160
161impl<T: ArrayLike + Encode> Encode for CuHandle<T>
162where
163    <T as ArrayLike>::Element: 'static,
164{
165    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
166        let inner = self.lock().unwrap();
167        match inner.deref() {
168            CuHandleInner::Pooled(pooled) => pooled.deref().encode(encoder),
169            CuHandleInner::Detached(detached) => detached.encode(encoder),
170        }
171    }
172}
173
/// `Default` implementation that cannot actually produce a value.
/// NOTE(review): presumably present only to satisfy a `Default` bound required elsewhere — confirm.
///
/// # Panics
/// `default()` always panics.
impl<T: ArrayLike> Default for CuHandle<T> {
    fn default() -> Self {
        panic!("Cannot create a default CuHandle")
    }
}
179
180impl<U: ElementType + Decode<()> + 'static> Decode<()> for CuHandle<Vec<U>> {
181    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
182        let vec: Vec<U> = Vec::decode(decoder)?;
183        Ok(CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(vec)))))
184    }
185}
186
/// A CuPool is a pool of buffers that can be shared between different parts of the code.
/// Handles can be stored locally in the tasks and shared between them.
pub trait CuPool<T: ArrayLike>: PoolMonitor {
    /// Acquire a buffer from the pool.
    /// The implementations in this file return `None` when no buffer is available.
    fn acquire(&self) -> Option<CuHandle<T>>;

    /// Copy data from a handle to a new handle from the pool.
    /// The source may use a different buffer type `O` as long as the element type matches.
    fn copy_from<O>(&self, from: &mut CuHandle<O>) -> CuHandle<T>
    where
        O: ArrayLike<Element = T::Element>;
}
198
/// A device memory pool that can additionally copy data from a device buffer into a host buffer.
pub trait DeviceCuPool<T: ArrayLike>: CuPool<T> {
    /// Copies the content of the device buffer behind `from_device_handle` into the host
    /// buffer behind `to_host_handle`.
    /// The destination handle is supplied by the caller; nothing is acquired or returned
    /// besides the success/failure status.
    fn copy_to_host_pool<O>(
        &self,
        from_device_handle: &CuHandle<T>,
        to_host_handle: &mut CuHandle<O>,
    ) -> CuResult<()>
    where
        O: ArrayLike<Element = T::Element>;
}
211
/// A pool of host memory buffers.
pub struct CuHostMemoryPool<T> {
    /// Identifier reported through `PoolMonitor::id`.
    id: PoolID,
    /// Underlying pool of host buffers.
    // Being an Arc is a requirement of try_pull_owned() so buffers can refer back to the pool.
    pool: Arc<Pool<T>>,
    /// Number of buffers the pool was created with.
    size: usize,
    /// Size of one buffer in bytes (element count times element size).
    buffer_size: usize,
}
221
222impl<T: ArrayLike + 'static> CuHostMemoryPool<T> {
223    pub fn new<F>(id: &str, size: usize, buffer_initializer: F) -> CuResult<Arc<Self>>
224    where
225        F: Fn() -> T,
226    {
227        let pool = Arc::new(Pool::new(size, buffer_initializer));
228        let buffer_size = pool.try_pull().unwrap().len() * size_of::<T::Element>();
229
230        let og = Self {
231            id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
232            pool,
233            size,
234            buffer_size,
235        };
236        let og = Arc::new(og);
237        register_pool(og.clone());
238        Ok(og)
239    }
240}
241
/// Monitoring view of the host pool.
impl<T: ArrayLike> PoolMonitor for CuHostMemoryPool<T> {
    /// Identifier given at construction.
    fn id(&self) -> PoolID {
        self.id
    }

    /// Number of buffers currently sitting in the underlying pool (i.e. not handed out).
    fn space_left(&self) -> usize {
        self.pool.len()
    }

    /// Number of buffers the pool was created with.
    fn total_size(&self) -> usize {
        self.size
    }

    /// Size in bytes of one buffer, as measured at construction.
    fn buffer_size(&self) -> usize {
        self.buffer_size
    }
}
259
260impl<T: ArrayLike> CuPool<T> for CuHostMemoryPool<T> {
261    fn acquire(&self) -> Option<CuHandle<T>> {
262        let owned_object = self.pool.try_pull_owned(); // Use the owned version
263
264        owned_object.map(|reusable| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(reusable)))))
265    }
266
267    fn copy_from<O: ArrayLike<Element = T::Element>>(&self, from: &mut CuHandle<O>) -> CuHandle<T> {
268        let to_handle = self.acquire().expect("No available buffers in the pool");
269
270        match from.lock().unwrap().deref() {
271            CuHandleInner::Detached(source) => match to_handle.lock().unwrap().deref_mut() {
272                CuHandleInner::Detached(destination) => {
273                    destination.copy_from_slice(source);
274                }
275                CuHandleInner::Pooled(destination) => {
276                    destination.copy_from_slice(source);
277                }
278            },
279            CuHandleInner::Pooled(source) => match to_handle.lock().unwrap().deref_mut() {
280                CuHandleInner::Detached(destination) => {
281                    destination.copy_from_slice(source);
282                }
283                CuHandleInner::Pooled(destination) => {
284                    destination.copy_from_slice(source);
285                }
286            },
287        }
288        to_handle
289    }
290}
291
/// A `Vec` of elements is the plain host-memory `ArrayLike` buffer.
impl<E: ElementType + 'static> ArrayLike for Vec<E> {
    type Element = E;
}
295
296#[cfg(all(feature = "cuda", not(target_os = "macos")))]
297mod cuda {
298    use super::*;
299    use cu29_traits::CuError;
300    use cudarc::driver::{CudaDevice, CudaSlice, DeviceRepr, ValidAsZeroBits};
301    use std::sync::Arc;
302
    /// Newtype around cudarc's `CudaSlice<E>` so that it can be given the `ArrayLike` impl below.
    #[derive(Debug)]
    pub struct CudaSliceWrapper<E>(CudaSlice<E>);
305
    /// `Deref` to `[E]` is declared (it is a supertrait of `ArrayLike`) but never succeeds.
    impl<E> Deref for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        type Target = [E];

        /// # Panics
        /// Always panics: the content has to be copied to a host memory pool before it
        /// can be read.
        fn deref(&self) -> &Self::Target {
            // No slice is returned here; use `as_cuda_slice` or copy to a host pool instead.
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }
317
    /// `DerefMut` is declared (it is a supertrait of `ArrayLike`) but never succeeds.
    impl<E> DerefMut for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        /// # Panics
        /// Always panics: the content has to be copied to a host memory pool before it
        /// can be accessed.
        fn deref_mut(&mut self) -> &mut Self::Target {
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }
326
    /// Lets a wrapped CUDA slice be used wherever an `ArrayLike` buffer is expected
    /// (e.g. inside a `CuHandle`).
    impl<E: ElementType> ArrayLike for CudaSliceWrapper<E> {
        type Element = E;
    }
330
    impl<E> CudaSliceWrapper<E> {
        /// Shared access to the wrapped `CudaSlice`.
        pub fn as_cuda_slice(&self) -> &CudaSlice<E> {
            &self.0
        }

        /// Exclusive access to the wrapped `CudaSlice`.
        pub fn as_cuda_slice_mut(&mut self) -> &mut CudaSlice<E> {
            &mut self.0
        }
    }
340
    /// A pool of CUDA memory buffers.
    pub struct CuCudaPool<E>
    where
        E: ElementType + ValidAsZeroBits + DeviceRepr + Unpin,
    {
        /// Identifier reported through `PoolMonitor::id`.
        id: PoolID,
        /// Device the buffers were allocated on; also used to issue the copies.
        device: Arc<CudaDevice>,
        /// Underlying pool of device buffers.
        pool: Arc<Pool<CudaSliceWrapper<E>>>,
        /// Number of buffers the pool was created with.
        nb_buffers: usize,
        /// Number of `E` elements in each buffer.
        nb_element_per_buffer: usize,
    }
352
353    impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuCudaPool<E> {
354        #[allow(dead_code)]
355        pub fn new(
356            id: &'static str,
357            device: Arc<CudaDevice>,
358            nb_buffers: usize,
359            nb_element_per_buffer: usize,
360        ) -> CuResult<Self> {
361            let pool = (0..nb_buffers)
362                .map(|_| {
363                    device
364                        .alloc_zeros(nb_element_per_buffer)
365                        .map(CudaSliceWrapper)
366                        .map_err(|_| "Failed to allocate device memory")
367                })
368                .collect::<Result<Vec<_>, _>>()?;
369
370            Ok(Self {
371                id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
372                device: device.clone(),
373                pool: Arc::new(Pool::from_vec(pool)),
374                nb_buffers,
375                nb_element_per_buffer,
376            })
377        }
378    }
379
    /// Monitoring view of the CUDA pool.
    impl<E> PoolMonitor for CuCudaPool<E>
    where
        E: DeviceRepr + ElementType + ValidAsZeroBits,
    {
        /// Identifier given at construction.
        fn id(&self) -> PoolID {
            self.id
        }

        /// Number of buffers currently sitting in the underlying pool (i.e. not handed out).
        fn space_left(&self) -> usize {
            self.pool.len()
        }

        /// Number of buffers the pool was created with.
        fn total_size(&self) -> usize {
            self.nb_buffers
        }

        /// Size in bytes of one buffer: element count times the size of `E`.
        fn buffer_size(&self) -> usize {
            self.nb_element_per_buffer * size_of::<E>()
        }
    }
400
401    impl<E> CuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
402    where
403        E: DeviceRepr + ElementType + ValidAsZeroBits,
404    {
405        fn acquire(&self) -> Option<CuHandle<CudaSliceWrapper<E>>> {
406            self.pool
407                .try_pull_owned()
408                .map(|x| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(x)))))
409        }
410
411        fn copy_from<O>(&self, from_handle: &mut CuHandle<O>) -> CuHandle<CudaSliceWrapper<E>>
412        where
413            O: ArrayLike<Element = E>,
414        {
415            let to_handle = self.acquire().expect("No available buffers in the pool");
416
417            match from_handle.lock().unwrap().deref() {
418                CuHandleInner::Detached(from) => match to_handle.lock().unwrap().deref_mut() {
419                    CuHandleInner::Detached(CudaSliceWrapper(to)) => {
420                        self.device
421                            .htod_sync_copy_into(from, to)
422                            .expect("Failed to copy data to device");
423                    }
424                    CuHandleInner::Pooled(to) => {
425                        self.device
426                            .htod_sync_copy_into(from, to.as_cuda_slice_mut())
427                            .expect("Failed to copy data to device");
428                    }
429                },
430                CuHandleInner::Pooled(from) => match to_handle.lock().unwrap().deref_mut() {
431                    CuHandleInner::Detached(CudaSliceWrapper(to)) => {
432                        self.device
433                            .htod_sync_copy_into(from, to)
434                            .expect("Failed to copy data to device");
435                    }
436                    CuHandleInner::Pooled(to) => {
437                        self.device
438                            .htod_sync_copy_into(from, to.as_cuda_slice_mut())
439                            .expect("Failed to copy data to device");
440                    }
441                },
442            }
443            to_handle
444        }
445    }
446
447    impl<E> DeviceCuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
448    where
449        E: ElementType + ValidAsZeroBits + DeviceRepr,
450    {
451        /// Copy from device to host
452        fn copy_to_host_pool<O>(
453            &self,
454            device_handle: &CuHandle<CudaSliceWrapper<E>>,
455            host_handle: &mut CuHandle<O>,
456        ) -> Result<(), CuError>
457        where
458            O: ArrayLike<Element = E>,
459        {
460            match device_handle.lock().unwrap().deref() {
461                CuHandleInner::Pooled(source) => match host_handle.lock().unwrap().deref_mut() {
462                    CuHandleInner::Pooled(ref mut destination) => {
463                        self.device
464                            .dtoh_sync_copy_into(source.as_cuda_slice(), destination)
465                            .expect("Failed to copy data to device");
466                    }
467                    CuHandleInner::Detached(ref mut destination) => {
468                        self.device
469                            .dtoh_sync_copy_into(source.as_cuda_slice(), destination)
470                            .expect("Failed to copy data to device");
471                    }
472                },
473                CuHandleInner::Detached(source) => match host_handle.lock().unwrap().deref_mut() {
474                    CuHandleInner::Pooled(ref mut destination) => {
475                        self.device
476                            .dtoh_sync_copy_into(source.as_cuda_slice(), destination)
477                            .expect("Failed to copy data to device");
478                    }
479                    CuHandleInner::Detached(ref mut destination) => {
480                        self.device
481                            .dtoh_sync_copy_into(source.as_cuda_slice(), destination)
482                            .expect("Failed to copy data to device");
483                    }
484                },
485            }
486            Ok(())
487        }
488    }
489}
490
#[derive(Debug)]
/// A buffer of elements of type `E` whose allocation is aligned to a caller-chosen boundary.
pub struct AlignedBuffer<E: ElementType> {
    /// Start of the allocation obtained in `new`.
    ptr: *mut E,
    /// Number of elements (not bytes) in the buffer.
    size: usize,
    /// Layout the allocation was made with; used again to deallocate on drop.
    layout: Layout,
}
498
499impl<E: ElementType> AlignedBuffer<E> {
500    pub fn new(num_elements: usize, alignment: usize) -> Self {
501        let layout = Layout::from_size_align(num_elements * size_of::<E>(), alignment).unwrap();
502        let ptr = unsafe { alloc(layout) as *mut E };
503        if ptr.is_null() {
504            panic!("Failed to allocate memory");
505        }
506        Self {
507            ptr,
508            size: num_elements,
509            layout,
510        }
511    }
512}
513
514impl<E: ElementType> Deref for AlignedBuffer<E> {
515    type Target = [E];
516
517    fn deref(&self) -> &Self::Target {
518        unsafe { std::slice::from_raw_parts(self.ptr, self.size) }
519    }
520}
521
522impl<E: ElementType> DerefMut for AlignedBuffer<E> {
523    fn deref_mut(&mut self) -> &mut Self::Target {
524        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.size) }
525    }
526}
527
528impl<E: ElementType> Drop for AlignedBuffer<E> {
529    fn drop(&mut self) {
530        if !self.ptr.is_null() {
531            unsafe {
532                dealloc(self.ptr as *mut u8, self.layout);
533            }
534        }
535    }
536}
537
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(all(feature = "cuda", not(target_os = "macos")))]
    use crate::pool::cuda::CuCudaPool;
    use std::cell::RefCell;

    /// Host pool: acquired buffers come from the initializer, `space_left` tracks
    /// acquire/drop, and an exhausted pool yields `None`.
    #[test]
    fn test_pool() {
        // The initializer hands out these three vectors, one per buffer.
        let objs = RefCell::new(vec![vec![1], vec![2], vec![3]]);
        // Keep a copy to compare against, since the initializer empties `objs`.
        let holding = objs.borrow().clone();
        let objs_as_slices = holding.iter().map(|x| x.as_slice()).collect::<Vec<_>>();
        let pool = CuHostMemoryPool::new("mytestcudapool", 3, || objs.borrow_mut().pop().unwrap())
            .unwrap();

        let obj1 = pool.acquire().unwrap();
        {
            let obj2 = pool.acquire().unwrap();
            assert!(objs_as_slices.contains(&obj1.lock().unwrap().deref().deref()));
            assert!(objs_as_slices.contains(&obj2.lock().unwrap().deref().deref()));
            assert_eq!(pool.space_left(), 1);
        }
        // `obj2` went out of scope, so its buffer is back in the pool.
        assert_eq!(pool.space_left(), 2);

        let obj3 = pool.acquire().unwrap();
        assert!(objs_as_slices.contains(&obj3.lock().unwrap().deref().deref()));

        assert_eq!(pool.space_left(), 1);

        let _obj4 = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        // Nothing left: acquire must fail gracefully.
        let obj5 = pool.acquire();
        assert!(obj5.is_none());
    }

    /// CUDA pool: same acquire/drop accounting as the host pool test.
    #[cfg(all(feature = "cuda", not(target_os = "macos")))]
    #[test]
    #[ignore] // Can only be executed if a real CUDA device is present
    fn test_cuda_pool() {
        use cudarc::driver::CudaDevice;
        let device = CudaDevice::new(0).unwrap();
        let pool = CuCudaPool::<f32>::new("mytestcudapool", device, 3, 1).unwrap();

        let _obj1 = pool.acquire().unwrap();

        {
            let _obj2 = pool.acquire().unwrap();
            assert_eq!(pool.space_left(), 1);
        }
        // `_obj2` went out of scope, so its buffer is back in the pool.
        assert_eq!(pool.space_left(), 2);

        let _obj3 = pool.acquire().unwrap();

        assert_eq!(pool.space_left(), 1);

        let _obj4 = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        // Nothing left: acquire must fail gracefully.
        let obj5 = pool.acquire();
        assert!(obj5.is_none());
    }

    /// Writes a value in a host buffer, copies it to the device and back into another
    /// host buffer, and checks the value survived the round trip.
    #[cfg(all(feature = "cuda", not(target_os = "macos")))]
    #[test]
    #[ignore] // Can only be executed if a real CUDA device is present
    fn test_copy_roundtrip() {
        use cudarc::driver::CudaDevice;
        let device = CudaDevice::new(0).unwrap();
        let host_pool = CuHostMemoryPool::new("mytesthostpool", 3, || vec![0.0; 1]).unwrap();
        let cuda_pool = CuCudaPool::<f32>::new("mytestcudapool", device, 3, 1).unwrap();

        let cuda_handle = {
            let mut initial_handle = host_pool.acquire().unwrap();
            {
                // Scope the lock so it is released before the copy locks the handle again.
                let mut inner_initial_handle = initial_handle.lock().unwrap();
                if let CuHandleInner::Pooled(ref mut pooled) = *inner_initial_handle {
                    pooled[0] = 42.0;
                } else {
                    panic!();
                }
            }

            // send that to the GPU
            cuda_pool.copy_from(&mut initial_handle)
        };

        // get it back to the host
        let mut final_handle = host_pool.acquire().unwrap();
        cuda_pool
            .copy_to_host_pool(&cuda_handle, &mut final_handle)
            .unwrap();

        let value = final_handle.lock().unwrap().deref().deref()[0];
        assert_eq!(value, 42.0);
    }
}