1use arrayvec::ArrayString;
2use bincode::de::Decoder;
3use bincode::enc::Encoder;
4use bincode::error::{DecodeError, EncodeError};
5use bincode::{Decode, Encode};
6use cu29_traits::CuResult;
7use object_pool::{Pool, ReusableOwned};
8use smallvec::SmallVec;
9use std::alloc::{alloc, dealloc, Layout};
10use std::collections::HashMap;
11use std::fmt::Debug;
12use std::ops::{Deref, DerefMut};
13use std::sync::{Arc, Mutex, OnceLock};
14
15type PoolID = ArrayString<64>;
16
17pub trait PoolMonitor: Send + Sync {
19 fn id(&self) -> PoolID;
21
22 fn space_left(&self) -> usize;
24
25 fn total_size(&self) -> usize;
27
28 fn buffer_size(&self) -> usize;
30}
31
32static POOL_REGISTRY: OnceLock<Mutex<HashMap<String, Arc<dyn PoolMonitor>>>> = OnceLock::new();
33const MAX_POOLS: usize = 16;
34
35fn register_pool(pool: Arc<dyn PoolMonitor>) {
37 POOL_REGISTRY
38 .get_or_init(|| Mutex::new(HashMap::new()))
39 .lock()
40 .unwrap()
41 .insert(pool.id().to_string(), pool);
42}
43
44type PoolStats = (PoolID, usize, usize, usize);
45
46pub fn pools_statistics() -> SmallVec<[PoolStats; MAX_POOLS]> {
49 let registry = POOL_REGISTRY.get().unwrap().lock().unwrap();
50 let mut result = SmallVec::with_capacity(MAX_POOLS);
51 for pool in registry.values() {
52 result.push((
53 pool.id(),
54 pool.space_left(),
55 pool.total_size(),
56 pool.buffer_size(),
57 ));
58 }
59 result
60}
61
62pub trait ElementType:
64 Default + Sized + Copy + Encode + Decode + Debug + Unpin + Send + Sync
65{
66}
67
68impl<T> ElementType for T where
70 T: Default + Sized + Copy + Encode + Decode + Debug + Unpin + Send + Sync
71{
72}
73
74pub trait ArrayLike: Deref<Target = [Self::Element]> + DerefMut + Debug + Sync + Send {
75 type Element: ElementType;
76}
77
78pub enum CuHandleInner<T: Debug> {
82 Pooled(ReusableOwned<T>),
83 Detached(T), }
85
86impl<T> Debug for CuHandleInner<T>
87where
88 T: Debug,
89{
90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 match self {
92 CuHandleInner::Pooled(r) => {
93 write!(f, "Pooled: {:?}", r.deref())
94 }
95 CuHandleInner::Detached(r) => write!(f, "Detached: {:?}", r),
96 }
97 }
98}
99
100impl<T: ArrayLike> Deref for CuHandleInner<T> {
101 type Target = [T::Element];
102
103 fn deref(&self) -> &Self::Target {
104 match self {
105 CuHandleInner::Pooled(pooled) => pooled,
106 CuHandleInner::Detached(detached) => detached,
107 }
108 }
109}
110
111impl<T: ArrayLike> DerefMut for CuHandleInner<T> {
112 fn deref_mut(&mut self) -> &mut Self::Target {
113 match self {
114 CuHandleInner::Pooled(pooled) => pooled.deref_mut(),
115 CuHandleInner::Detached(detached) => detached,
116 }
117 }
118}
119
120#[derive(Clone, Debug)]
122pub struct CuHandle<T: ArrayLike>(Arc<Mutex<CuHandleInner<T>>>);
123
124impl<T: ArrayLike> Deref for CuHandle<T> {
125 type Target = Arc<Mutex<CuHandleInner<T>>>;
126
127 fn deref(&self) -> &Self::Target {
128 &self.0
129 }
130}
131
132impl<T: ArrayLike> CuHandle<T> {
133 pub fn new_detached(inner: T) -> Self {
135 CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(inner))))
136 }
137
138 pub fn with_inner<R>(&self, f: impl FnOnce(&CuHandleInner<T>) -> R) -> R {
140 let lock = self.lock().unwrap();
141 f(&*lock)
142 }
143
144 pub fn with_inner_mut<R>(&self, f: impl FnOnce(&mut CuHandleInner<T>) -> R) -> R {
146 let mut lock = self.lock().unwrap();
147 f(&mut *lock)
148 }
149}
150
151impl<T: ArrayLike> Encode for CuHandle<T>
152where
153 <T as ArrayLike>::Element: 'static,
154{
155 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
156 let inner = self.lock().unwrap();
157 match inner.deref() {
158 CuHandleInner::Pooled(pooled) => pooled.encode(encoder),
159 CuHandleInner::Detached(detached) => detached.encode(encoder),
160 }
161 }
162}
163
164impl<T: ArrayLike> Default for CuHandle<T> {
165 fn default() -> Self {
166 panic!("Cannot create a default CuHandle")
167 }
168}
169
170impl<U: ElementType + 'static> Decode for CuHandle<Vec<U>> {
171 fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, DecodeError> {
172 let vec: Vec<U> = Vec::decode(decoder)?;
173 Ok(CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(vec)))))
174 }
175}
176
177pub trait CuPool<T: ArrayLike>: PoolMonitor {
180 fn acquire(&self) -> Option<CuHandle<T>>;
182
183 fn copy_from<O>(&self, from: &mut CuHandle<O>) -> CuHandle<T>
185 where
186 O: ArrayLike<Element = T::Element>;
187}
188
189pub trait DeviceCuPool<T: ArrayLike>: CuPool<T> {
191 fn copy_to_host_pool<O>(
194 &self,
195 from_device_handle: &CuHandle<T>,
196 to_host_handle: &mut CuHandle<O>,
197 ) -> CuResult<()>
198 where
199 O: ArrayLike<Element = T::Element>;
200}
201
202pub struct CuHostMemoryPool<T> {
204 id: PoolID,
207 pool: Arc<Pool<T>>,
208 size: usize,
209 buffer_size: usize,
210}
211
212impl<T: ArrayLike + 'static> CuHostMemoryPool<T> {
213 pub fn new<F>(id: &str, size: usize, buffer_initializer: F) -> CuResult<Arc<Self>>
214 where
215 F: Fn() -> T,
216 {
217 let pool = Arc::new(Pool::new(size, buffer_initializer));
218 let buffer_size = pool.try_pull().unwrap().len() * size_of::<T::Element>();
219
220 let og = Self {
221 id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
222 pool,
223 size,
224 buffer_size,
225 };
226 let og = Arc::new(og);
227 register_pool(og.clone());
228 Ok(og)
229 }
230}
231
232impl<T: ArrayLike> PoolMonitor for CuHostMemoryPool<T> {
233 fn id(&self) -> PoolID {
234 self.id
235 }
236
237 fn space_left(&self) -> usize {
238 self.pool.len()
239 }
240
241 fn total_size(&self) -> usize {
242 self.size
243 }
244
245 fn buffer_size(&self) -> usize {
246 self.buffer_size
247 }
248}
249
250impl<T: ArrayLike> CuPool<T> for CuHostMemoryPool<T> {
251 fn acquire(&self) -> Option<CuHandle<T>> {
252 let owned_object = self.pool.try_pull_owned(); owned_object.map(|reusable| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(reusable)))))
255 }
256
257 fn copy_from<O: ArrayLike<Element = T::Element>>(&self, from: &mut CuHandle<O>) -> CuHandle<T> {
258 let to_handle = self.acquire().expect("No available buffers in the pool");
259
260 match from.lock().unwrap().deref() {
261 CuHandleInner::Detached(source) => match to_handle.lock().unwrap().deref_mut() {
262 CuHandleInner::Detached(destination) => {
263 destination.copy_from_slice(source);
264 }
265 CuHandleInner::Pooled(destination) => {
266 destination.copy_from_slice(source);
267 }
268 },
269 CuHandleInner::Pooled(source) => match to_handle.lock().unwrap().deref_mut() {
270 CuHandleInner::Detached(destination) => {
271 destination.copy_from_slice(source);
272 }
273 CuHandleInner::Pooled(destination) => {
274 destination.copy_from_slice(source);
275 }
276 },
277 }
278 to_handle
279 }
280}
281
282impl<E: ElementType + 'static> ArrayLike for Vec<E> {
283 type Element = E;
284}
285
#[cfg(all(feature = "cuda", not(target_os = "macos")))]
mod cuda {
    use super::*;
    use cu29_traits::CuError;
    use cudarc::driver::{CudaDevice, CudaSlice, DeviceRepr, ValidAsZeroBits};
    use std::sync::Arc;

    /// Device-side buffer that can be stored in a [`CuHandle`].
    ///
    /// Its contents live in device memory, so the `Deref`/`DerefMut` impls required by
    /// [`ArrayLike`] always panic. Use [`CudaSliceWrapper::as_cuda_slice`] /
    /// [`CudaSliceWrapper::as_cuda_slice_mut`], or copy the data into a host pool instead.
    #[derive(Debug)]
    pub struct CudaSliceWrapper<E>(CudaSlice<E>);

    impl<E> Deref for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        type Target = [E];

        /// # Panics
        /// Always panics: device memory cannot be viewed as a host slice.
        fn deref(&self) -> &Self::Target {
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }

    impl<E> DerefMut for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        /// # Panics
        /// Always panics: device memory cannot be viewed as a host slice.
        fn deref_mut(&mut self) -> &mut Self::Target {
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }

    impl<E: ElementType> ArrayLike for CudaSliceWrapper<E> {
        type Element = E;
    }

    impl<E> CudaSliceWrapper<E> {
        /// Borrows the wrapped device slice.
        pub fn as_cuda_slice(&self) -> &CudaSlice<E> {
            &self.0
        }

        /// Mutably borrows the wrapped device slice.
        pub fn as_cuda_slice_mut(&mut self) -> &mut CudaSlice<E> {
            &mut self.0
        }
    }

    /// Returns the device slice behind a handle's storage, pooled or detached.
    fn device_slice<E>(inner: &CuHandleInner<CudaSliceWrapper<E>>) -> &CudaSlice<E>
    where
        E: ElementType + ValidAsZeroBits + DeviceRepr,
    {
        match inner {
            CuHandleInner::Pooled(pooled) => pooled.as_cuda_slice(),
            CuHandleInner::Detached(detached) => detached.as_cuda_slice(),
        }
    }

    /// Mutable counterpart of [`device_slice`].
    fn device_slice_mut<E>(inner: &mut CuHandleInner<CudaSliceWrapper<E>>) -> &mut CudaSlice<E>
    where
        E: ElementType + ValidAsZeroBits + DeviceRepr,
    {
        match inner {
            CuHandleInner::Pooled(pooled) => pooled.as_cuda_slice_mut(),
            CuHandleInner::Detached(detached) => detached.as_cuda_slice_mut(),
        }
    }

    /// Pool of pre-allocated, zero-initialized device buffers.
    pub struct CuCudaPool<E>
    where
        E: ElementType + ValidAsZeroBits + DeviceRepr + Unpin,
    {
        id: PoolID,
        device: Arc<CudaDevice>,
        pool: Arc<Pool<CudaSliceWrapper<E>>>,
        nb_buffers: usize,
        nb_element_per_buffer: usize,
    }

    impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuCudaPool<E> {
        /// Allocates `nb_buffers` zero-initialized buffers of `nb_element_per_buffer`
        /// elements each on `device`.
        ///
        /// Unlike the host pool constructor, this does not add the pool to the global
        /// registry.
        ///
        /// # Errors
        /// Returns an error if `id` does not fit in a [`PoolID`] or if a device allocation
        /// fails.
        // NOTE(review): `mod cuda` is private, so presumably only tests reach this
        // constructor; the allow silences the resulting dead-code warning — confirm.
        #[allow(dead_code)]
        pub fn new(
            id: &str,
            device: Arc<CudaDevice>,
            nb_buffers: usize,
            nb_element_per_buffer: usize,
        ) -> CuResult<Self> {
            // Validate the id before touching device memory.
            let id = PoolID::from(id).map_err(|_| "Failed to create PoolID")?;
            let buffers = (0..nb_buffers)
                .map(|_| {
                    device
                        .alloc_zeros(nb_element_per_buffer)
                        .map(CudaSliceWrapper)
                        .map_err(|_| "Failed to allocate device memory")
                })
                .collect::<Result<Vec<_>, _>>()?;

            Ok(Self {
                id,
                device,
                pool: Arc::new(Pool::from_vec(buffers)),
                nb_buffers,
                nb_element_per_buffer,
            })
        }
    }

    impl<E> PoolMonitor for CuCudaPool<E>
    where
        E: DeviceRepr + ElementType + ValidAsZeroBits,
    {
        fn id(&self) -> PoolID {
            self.id
        }

        fn space_left(&self) -> usize {
            self.pool.len()
        }

        fn total_size(&self) -> usize {
            self.nb_buffers
        }

        fn buffer_size(&self) -> usize {
            self.nb_element_per_buffer * size_of::<E>()
        }
    }

    impl<E> CuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
    where
        E: DeviceRepr + ElementType + ValidAsZeroBits,
    {
        fn acquire(&self) -> Option<CuHandle<CudaSliceWrapper<E>>> {
            self.pool
                .try_pull_owned()
                .map(|buffer| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(buffer)))))
        }

        /// Acquires a device buffer and uploads the host contents of `from_handle` into it.
        ///
        /// # Panics
        /// Panics if the pool has no free buffer or if the host-to-device copy fails; this
        /// trait method has no way to report an error.
        fn copy_from<O>(&self, from_handle: &mut CuHandle<O>) -> CuHandle<CudaSliceWrapper<E>>
        where
            O: ArrayLike<Element = E>,
        {
            let to_handle = self.acquire().expect("No available buffers in the pool");
            {
                // Host storage (pooled or detached) derefs straight to a slice.
                let source_guard = from_handle.lock().unwrap();
                let source: &[E] = &source_guard;
                let mut destination_guard = to_handle.lock().unwrap();
                self.device
                    .htod_sync_copy_into(source, device_slice_mut(&mut *destination_guard))
                    .expect("Failed to copy data to device");
            }
            to_handle
        }
    }

    impl<E> DeviceCuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
    where
        E: ElementType + ValidAsZeroBits + DeviceRepr,
    {
        /// Downloads the device contents of `device_handle` into `host_handle`.
        ///
        /// NOTE(review): how a length mismatch between the two buffers is handled is up to
        /// cudarc — confirm.
        ///
        /// # Errors
        /// Returns an error if the device-to-host copy fails (previously this panicked with
        /// a misleading "to device" message despite the `Result` return type).
        fn copy_to_host_pool<O>(
            &self,
            device_handle: &CuHandle<CudaSliceWrapper<E>>,
            host_handle: &mut CuHandle<O>,
        ) -> Result<(), CuError>
        where
            O: ArrayLike<Element = E>,
        {
            let device_guard = device_handle.lock().unwrap();
            let mut host_guard = host_handle.lock().unwrap();
            let destination: &mut [E] = &mut host_guard;
            self.device
                .dtoh_sync_copy_into(device_slice(&*device_guard), destination)
                .map_err(|_| "Failed to copy data from device to host")?;
            Ok(())
        }
    }
}
480
481#[derive(Debug)]
482pub struct AlignedBuffer<E: ElementType> {
484 ptr: *mut E,
485 size: usize,
486 layout: Layout,
487}
488
489impl<E: ElementType> AlignedBuffer<E> {
490 pub fn new(num_elements: usize, alignment: usize) -> Self {
491 let layout = Layout::from_size_align(num_elements * size_of::<E>(), alignment).unwrap();
492 let ptr = unsafe { alloc(layout) as *mut E };
493 if ptr.is_null() {
494 panic!("Failed to allocate memory");
495 }
496 Self {
497 ptr,
498 size: num_elements,
499 layout,
500 }
501 }
502}
503
504impl<E: ElementType> Deref for AlignedBuffer<E> {
505 type Target = [E];
506
507 fn deref(&self) -> &Self::Target {
508 unsafe { std::slice::from_raw_parts(self.ptr, self.size) }
509 }
510}
511
512impl<E: ElementType> DerefMut for AlignedBuffer<E> {
513 fn deref_mut(&mut self) -> &mut Self::Target {
514 unsafe { std::slice::from_raw_parts_mut(self.ptr, self.size) }
515 }
516}
517
518impl<E: ElementType> Drop for AlignedBuffer<E> {
519 fn drop(&mut self) {
520 if !self.ptr.is_null() {
521 unsafe {
522 dealloc(self.ptr as *mut u8, self.layout);
523 }
524 }
525 }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(all(feature = "cuda", not(target_os = "macos")))]
    use crate::pool::cuda::CuCudaPool;
    use std::cell::RefCell;

    /// Host pool: buffers leave the pool on `acquire` and come back when their handle drops.
    #[test]
    fn test_pool() {
        let source_buffers = RefCell::new(vec![vec![1], vec![2], vec![3]]);
        let originals = source_buffers.borrow().clone();
        let known: Vec<&[i32]> = originals.iter().map(|buffer| buffer.as_slice()).collect();
        let pool = CuHostMemoryPool::new("mytestcudapool", 3, || {
            source_buffers.borrow_mut().pop().unwrap()
        })
        .unwrap();

        let first = pool.acquire().unwrap();
        {
            let second = pool.acquire().unwrap();
            assert!(known.contains(&first.lock().unwrap().deref().deref()));
            assert!(known.contains(&second.lock().unwrap().deref().deref()));
            assert_eq!(pool.space_left(), 1);
        }
        // `second` is out of scope, so its buffer is back in the pool.
        assert_eq!(pool.space_left(), 2);

        let third = pool.acquire().unwrap();
        assert!(known.contains(&third.lock().unwrap().deref().deref()));
        assert_eq!(pool.space_left(), 1);

        let _fourth = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        // Every buffer is held now.
        assert!(pool.acquire().is_none());
    }

    /// Device pool: same acquire/release accounting as the host pool.
    #[cfg(all(feature = "cuda", not(target_os = "macos")))]
    #[test]
    #[ignore] // needs a CUDA device (opened with `CudaDevice::new(0)`)
    fn test_cuda_pool() {
        use cudarc::driver::CudaDevice;
        let device = CudaDevice::new(0).unwrap();
        let pool = CuCudaPool::<f32>::new("mytestcudapool", device, 3, 1).unwrap();

        let _first = pool.acquire().unwrap();
        {
            let _second = pool.acquire().unwrap();
            assert_eq!(pool.space_left(), 1);
        }
        assert_eq!(pool.space_left(), 2);

        let _third = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 1);

        let _fourth = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        assert!(pool.acquire().is_none());
    }

    /// A value written on the host survives a host -> device -> host round trip.
    #[cfg(all(feature = "cuda", not(target_os = "macos")))]
    #[test]
    #[ignore] // needs a CUDA device (opened with `CudaDevice::new(0)`)
    fn test_copy_roundtrip() {
        use cudarc::driver::CudaDevice;
        let device = CudaDevice::new(0).unwrap();
        let host_pool = CuHostMemoryPool::new("mytesthostpool", 3, || vec![0.0; 1]).unwrap();
        let cuda_pool = CuCudaPool::<f32>::new("mytestcudapool", device, 3, 1).unwrap();

        let cuda_handle = {
            let mut upload_handle = host_pool.acquire().unwrap();
            {
                // Scoped so the lock is released before `copy_from` locks the handle again.
                let mut guard = upload_handle.lock().unwrap();
                let CuHandleInner::Pooled(ref mut pooled) = *guard else {
                    panic!();
                };
                pooled[0] = 42.0;
            }
            cuda_pool.copy_from(&mut upload_handle)
        };

        let mut download_handle = host_pool.acquire().unwrap();
        cuda_pool
            .copy_to_host_pool(&cuda_handle, &mut download_handle)
            .unwrap();

        let value = download_handle.lock().unwrap().deref().deref()[0];
        assert_eq!(value, 42.0);
    }
}