1use arrayvec::ArrayString;
2use bincode::de::Decoder;
3use bincode::enc::Encoder;
4use bincode::error::{DecodeError, EncodeError};
5use bincode::{Decode, Encode};
6use cu29_traits::CuResult;
7use object_pool::{Pool, ReusableOwned};
8use smallvec::SmallVec;
9use std::alloc::{alloc, dealloc, Layout};
10use std::collections::HashMap;
11use std::fmt::Debug;
12use std::ops::{Deref, DerefMut};
13use std::sync::{Arc, Mutex, OnceLock};
14
15type PoolID = ArrayString<64>;
16
17pub trait PoolMonitor: Send + Sync {
19 fn id(&self) -> PoolID;
21
22 fn space_left(&self) -> usize;
24
25 fn total_size(&self) -> usize;
27
28 fn buffer_size(&self) -> usize;
30}
31
32static POOL_REGISTRY: OnceLock<Mutex<HashMap<String, Arc<dyn PoolMonitor>>>> = OnceLock::new();
33const MAX_POOLS: usize = 16;
34
35fn register_pool(pool: Arc<dyn PoolMonitor>) {
37 POOL_REGISTRY
38 .get_or_init(|| Mutex::new(HashMap::new()))
39 .lock()
40 .unwrap()
41 .insert(pool.id().to_string(), pool);
42}
43
44type PoolStats = (PoolID, usize, usize, usize);
45
46pub fn pools_statistics() -> SmallVec<[PoolStats; MAX_POOLS]> {
49 let registry = POOL_REGISTRY.get().unwrap().lock().unwrap();
50 let mut result = SmallVec::with_capacity(MAX_POOLS);
51 for pool in registry.values() {
52 result.push((
53 pool.id(),
54 pool.space_left(),
55 pool.total_size(),
56 pool.buffer_size(),
57 ));
58 }
59 result
60}
61
62pub trait ElementType: Default + Sized + Copy + Debug + Unpin + Send + Sync {
64 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError>;
65 fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError>;
66}
67
68impl<T> ElementType for T
70where
71 T: Default + Sized + Copy + Debug + Unpin + Send + Sync,
72 T: Encode,
73 T: Decode<()>,
74{
75 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
76 self.encode(encoder)
77 }
78
79 fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
80 Self::decode(decoder)
81 }
82}
83
84pub trait ArrayLike: Deref<Target = [Self::Element]> + DerefMut + Debug + Sync + Send {
85 type Element: ElementType;
86}
87
88pub enum CuHandleInner<T: Debug> {
92 Pooled(ReusableOwned<T>),
93 Detached(T), }
95
96impl<T> Debug for CuHandleInner<T>
97where
98 T: Debug,
99{
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 match self {
102 CuHandleInner::Pooled(r) => {
103 write!(f, "Pooled: {:?}", r.deref())
104 }
105 CuHandleInner::Detached(r) => write!(f, "Detached: {:?}", r),
106 }
107 }
108}
109
110impl<T: ArrayLike> Deref for CuHandleInner<T> {
111 type Target = [T::Element];
112
113 fn deref(&self) -> &Self::Target {
114 match self {
115 CuHandleInner::Pooled(pooled) => pooled,
116 CuHandleInner::Detached(detached) => detached,
117 }
118 }
119}
120
121impl<T: ArrayLike> DerefMut for CuHandleInner<T> {
122 fn deref_mut(&mut self) -> &mut Self::Target {
123 match self {
124 CuHandleInner::Pooled(pooled) => pooled.deref_mut(),
125 CuHandleInner::Detached(detached) => detached,
126 }
127 }
128}
129
130#[derive(Clone, Debug)]
132pub struct CuHandle<T: ArrayLike>(Arc<Mutex<CuHandleInner<T>>>);
133
134impl<T: ArrayLike> Deref for CuHandle<T> {
135 type Target = Arc<Mutex<CuHandleInner<T>>>;
136
137 fn deref(&self) -> &Self::Target {
138 &self.0
139 }
140}
141
142impl<T: ArrayLike> CuHandle<T> {
143 pub fn new_detached(inner: T) -> Self {
145 CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(inner))))
146 }
147
148 pub fn with_inner<R>(&self, f: impl FnOnce(&CuHandleInner<T>) -> R) -> R {
150 let lock = self.lock().unwrap();
151 f(&*lock)
152 }
153
154 pub fn with_inner_mut<R>(&self, f: impl FnOnce(&mut CuHandleInner<T>) -> R) -> R {
156 let mut lock = self.lock().unwrap();
157 f(&mut *lock)
158 }
159}
160
161impl<T: ArrayLike + Encode> Encode for CuHandle<T>
162where
163 <T as ArrayLike>::Element: 'static,
164{
165 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
166 let inner = self.lock().unwrap();
167 match inner.deref() {
168 CuHandleInner::Pooled(pooled) => pooled.deref().encode(encoder),
169 CuHandleInner::Detached(detached) => detached.encode(encoder),
170 }
171 }
172}
173
174impl<T: ArrayLike> Default for CuHandle<T> {
175 fn default() -> Self {
176 panic!("Cannot create a default CuHandle")
177 }
178}
179
180impl<U: ElementType + Decode<()> + 'static> Decode<()> for CuHandle<Vec<U>> {
181 fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
182 let vec: Vec<U> = Vec::decode(decoder)?;
183 Ok(CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(vec)))))
184 }
185}
186
187pub trait CuPool<T: ArrayLike>: PoolMonitor {
190 fn acquire(&self) -> Option<CuHandle<T>>;
192
193 fn copy_from<O>(&self, from: &mut CuHandle<O>) -> CuHandle<T>
195 where
196 O: ArrayLike<Element = T::Element>;
197}
198
199pub trait DeviceCuPool<T: ArrayLike>: CuPool<T> {
201 fn copy_to_host_pool<O>(
204 &self,
205 from_device_handle: &CuHandle<T>,
206 to_host_handle: &mut CuHandle<O>,
207 ) -> CuResult<()>
208 where
209 O: ArrayLike<Element = T::Element>;
210}
211
212pub struct CuHostMemoryPool<T> {
214 id: PoolID,
217 pool: Arc<Pool<T>>,
218 size: usize,
219 buffer_size: usize,
220}
221
222impl<T: ArrayLike + 'static> CuHostMemoryPool<T> {
223 pub fn new<F>(id: &str, size: usize, buffer_initializer: F) -> CuResult<Arc<Self>>
224 where
225 F: Fn() -> T,
226 {
227 let pool = Arc::new(Pool::new(size, buffer_initializer));
228 let buffer_size = pool.try_pull().unwrap().len() * size_of::<T::Element>();
229
230 let og = Self {
231 id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
232 pool,
233 size,
234 buffer_size,
235 };
236 let og = Arc::new(og);
237 register_pool(og.clone());
238 Ok(og)
239 }
240}
241
242impl<T: ArrayLike> PoolMonitor for CuHostMemoryPool<T> {
243 fn id(&self) -> PoolID {
244 self.id
245 }
246
247 fn space_left(&self) -> usize {
248 self.pool.len()
249 }
250
251 fn total_size(&self) -> usize {
252 self.size
253 }
254
255 fn buffer_size(&self) -> usize {
256 self.buffer_size
257 }
258}
259
260impl<T: ArrayLike> CuPool<T> for CuHostMemoryPool<T> {
261 fn acquire(&self) -> Option<CuHandle<T>> {
262 let owned_object = self.pool.try_pull_owned(); owned_object.map(|reusable| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(reusable)))))
265 }
266
267 fn copy_from<O: ArrayLike<Element = T::Element>>(&self, from: &mut CuHandle<O>) -> CuHandle<T> {
268 let to_handle = self.acquire().expect("No available buffers in the pool");
269
270 match from.lock().unwrap().deref() {
271 CuHandleInner::Detached(source) => match to_handle.lock().unwrap().deref_mut() {
272 CuHandleInner::Detached(destination) => {
273 destination.copy_from_slice(source);
274 }
275 CuHandleInner::Pooled(destination) => {
276 destination.copy_from_slice(source);
277 }
278 },
279 CuHandleInner::Pooled(source) => match to_handle.lock().unwrap().deref_mut() {
280 CuHandleInner::Detached(destination) => {
281 destination.copy_from_slice(source);
282 }
283 CuHandleInner::Pooled(destination) => {
284 destination.copy_from_slice(source);
285 }
286 },
287 }
288 to_handle
289 }
290}
291
292impl<E: ElementType + 'static> ArrayLike for Vec<E> {
293 type Element = E;
294}
295
#[cfg(all(feature = "cuda", not(target_os = "macos")))]
mod cuda {
    use super::*;
    use cu29_traits::CuError;
    use cudarc::driver::{CudaDevice, CudaSlice, DeviceRepr, ValidAsZeroBits};
    use std::sync::Arc;

    /// Pool-storable wrapper around a device allocation.
    ///
    /// It implements [`ArrayLike`] so it can live inside a [`CuHandle`], but its
    /// `Deref`/`DerefMut` to `[E]` always panic. Use `as_cuda_slice` /
    /// `as_cuda_slice_mut` for device operations, or copy the data into a host
    /// pool first.
    #[derive(Debug)]
    pub struct CudaSliceWrapper<E>(CudaSlice<E>);

    impl<E> Deref for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        type Target = [E];

        /// Always panics: device memory is not readable as a host slice.
        fn deref(&self) -> &Self::Target {
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }

    impl<E> DerefMut for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        /// Always panics: device memory is not writable as a host slice.
        fn deref_mut(&mut self) -> &mut Self::Target {
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }

    impl<E: ElementType> ArrayLike for CudaSliceWrapper<E> {
        type Element = E;
    }

    impl<E> CudaSliceWrapper<E> {
        /// Borrows the underlying device slice.
        pub fn as_cuda_slice(&self) -> &CudaSlice<E> {
            &self.0
        }

        /// Mutably borrows the underlying device slice.
        pub fn as_cuda_slice_mut(&mut self) -> &mut CudaSlice<E> {
            &mut self.0
        }
    }

    /// Fixed-size pool of zero-initialized buffers allocated on one CUDA device.
    ///
    /// NOTE(review): unlike `CuHostMemoryPool::new`, `CuCudaPool::new` does not
    /// call `register_pool`, so these pools do not appear in `pools_statistics`.
    pub struct CuCudaPool<E>
    where
        E: ElementType + ValidAsZeroBits + DeviceRepr + Unpin,
    {
        id: PoolID,
        device: Arc<CudaDevice>,
        pool: Arc<Pool<CudaSliceWrapper<E>>>,
        nb_buffers: usize,
        nb_element_per_buffer: usize,
    }

    impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuCudaPool<E> {
        /// Allocates `nb_buffers` zeroed device buffers of `nb_element_per_buffer`
        /// elements each on `device`.
        ///
        /// # Errors
        /// Fails if any device allocation fails or `id` does not fit in a `PoolID`.
        // Presumably only constructed from tests/downstream crates, hence the allow.
        #[allow(dead_code)]
        pub fn new(
            id: &'static str,
            device: Arc<CudaDevice>,
            nb_buffers: usize,
            nb_element_per_buffer: usize,
        ) -> CuResult<Self> {
            let buffers = (0..nb_buffers)
                .map(|_| {
                    device
                        .alloc_zeros(nb_element_per_buffer)
                        .map(CudaSliceWrapper)
                        .map_err(|_| "Failed to allocate device memory")
                })
                .collect::<Result<Vec<_>, _>>()?;

            Ok(Self {
                id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
                // `device` is already owned here; no clone needed.
                device,
                pool: Arc::new(Pool::from_vec(buffers)),
                nb_buffers,
                nb_element_per_buffer,
            })
        }
    }

    impl<E> PoolMonitor for CuCudaPool<E>
    where
        E: DeviceRepr + ElementType + ValidAsZeroBits,
    {
        fn id(&self) -> PoolID {
            self.id
        }

        fn space_left(&self) -> usize {
            self.pool.len()
        }

        fn total_size(&self) -> usize {
            self.nb_buffers
        }

        fn buffer_size(&self) -> usize {
            self.nb_element_per_buffer * size_of::<E>()
        }
    }

    impl<E> CuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
    where
        E: DeviceRepr + ElementType + ValidAsZeroBits,
    {
        fn acquire(&self) -> Option<CuHandle<CudaSliceWrapper<E>>> {
            self.pool
                .try_pull_owned()
                .map(|x| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(x)))))
        }

        /// Acquires a device buffer and copies the host data of `from_handle` into it.
        ///
        /// # Panics
        /// If the pool is exhausted, if the host-to-device copy fails, or if a
        /// mutex is poisoned.
        fn copy_from<O>(&self, from_handle: &mut CuHandle<O>) -> CuHandle<CudaSliceWrapper<E>>
        where
            O: ArrayLike<Element = E>,
        {
            let to_handle = self.acquire().expect("No available buffers in the pool");
            {
                let source_guard = from_handle.lock().unwrap();
                // Host side: the generic slice view is valid for both variants.
                let source: &[E] = &source_guard;

                let mut destination_guard = to_handle.lock().unwrap();
                // Device side: borrow the `CudaSlice` directly; its `[E]` Deref panics.
                let destination: &mut CudaSlice<E> = match &mut *destination_guard {
                    CuHandleInner::Pooled(pooled) => pooled.as_cuda_slice_mut(),
                    CuHandleInner::Detached(detached) => detached.as_cuda_slice_mut(),
                };

                self.device
                    .htod_sync_copy_into(source, destination)
                    .expect("Failed to copy data to device");
            } // guards released before `to_handle` is moved out
            to_handle
        }
    }

    impl<E> DeviceCuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
    where
        E: ElementType + ValidAsZeroBits + DeviceRepr,
    {
        /// Copies the device buffer behind `device_handle` into the host buffer
        /// behind `host_handle`.
        ///
        /// # Errors
        /// Returns an error if the device-to-host copy fails.
        ///
        /// # Panics
        /// If either handle's mutex is poisoned.
        fn copy_to_host_pool<O>(
            &self,
            device_handle: &CuHandle<CudaSliceWrapper<E>>,
            host_handle: &mut CuHandle<O>,
        ) -> Result<(), CuError>
        where
            O: ArrayLike<Element = E>,
        {
            let device_guard = device_handle.lock().unwrap();
            // Device side: borrow the `CudaSlice` directly; its `[E]` Deref panics.
            let source: &CudaSlice<E> = match &*device_guard {
                CuHandleInner::Pooled(pooled) => pooled.as_cuda_slice(),
                CuHandleInner::Detached(detached) => detached.as_cuda_slice(),
            };

            let mut host_guard = host_handle.lock().unwrap();
            // Host side: the generic mutable slice view is valid for both variants.
            let destination: &mut [E] = &mut host_guard;

            // This function returns a Result, so report copy failures instead of panicking.
            self.device
                .dtoh_sync_copy_into(source, destination)
                .map_err(|_| CuError::from("Failed to copy data from device to host"))?;
            Ok(())
        }
    }
}
490
491#[derive(Debug)]
492pub struct AlignedBuffer<E: ElementType> {
494 ptr: *mut E,
495 size: usize,
496 layout: Layout,
497}
498
499impl<E: ElementType> AlignedBuffer<E> {
500 pub fn new(num_elements: usize, alignment: usize) -> Self {
501 let layout = Layout::from_size_align(num_elements * size_of::<E>(), alignment).unwrap();
502 let ptr = unsafe { alloc(layout) as *mut E };
503 if ptr.is_null() {
504 panic!("Failed to allocate memory");
505 }
506 Self {
507 ptr,
508 size: num_elements,
509 layout,
510 }
511 }
512}
513
514impl<E: ElementType> Deref for AlignedBuffer<E> {
515 type Target = [E];
516
517 fn deref(&self) -> &Self::Target {
518 unsafe { std::slice::from_raw_parts(self.ptr, self.size) }
519 }
520}
521
522impl<E: ElementType> DerefMut for AlignedBuffer<E> {
523 fn deref_mut(&mut self) -> &mut Self::Target {
524 unsafe { std::slice::from_raw_parts_mut(self.ptr, self.size) }
525 }
526}
527
528impl<E: ElementType> Drop for AlignedBuffer<E> {
529 fn drop(&mut self) {
530 if !self.ptr.is_null() {
531 unsafe {
532 dealloc(self.ptr as *mut u8, self.layout);
533 }
534 }
535 }
536}
537
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(all(feature = "cuda", not(target_os = "macos")))]
    use crate::pool::cuda::CuCudaPool;
    use std::cell::RefCell;

    #[test]
    fn test_pool() {
        // The pool initializer pops these one by one to build its buffers.
        let seeds = RefCell::new(vec![vec![1], vec![2], vec![3]]);
        let snapshot = seeds.borrow().clone();
        let expected: Vec<&[i32]> = snapshot.iter().map(Vec::as_slice).collect();
        let pool = CuHostMemoryPool::new("mytestcudapool", 3, || seeds.borrow_mut().pop().unwrap())
            .unwrap();

        // True when the handle's buffer is one of the seeded buffers.
        let is_seeded = |handle: &CuHandle<Vec<i32>>| {
            let guard = handle.lock().unwrap();
            let view: &[i32] = &guard;
            expected.contains(&view)
        };

        let first = pool.acquire().unwrap();
        {
            let second = pool.acquire().unwrap();
            assert!(is_seeded(&first));
            assert!(is_seeded(&second));
            assert_eq!(pool.space_left(), 1);
        }
        // Dropping `second` hands its buffer back to the pool.
        assert_eq!(pool.space_left(), 2);

        let third = pool.acquire().unwrap();
        assert!(is_seeded(&third));
        assert_eq!(pool.space_left(), 1);

        let _fourth = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        assert!(pool.acquire().is_none());
    }

    #[cfg(all(feature = "cuda", not(target_os = "macos")))]
    #[test]
    #[ignore]
    fn test_cuda_pool() {
        use cudarc::driver::CudaDevice;
        let gpu = CudaDevice::new(0).unwrap();
        let pool = CuCudaPool::<f32>::new("mytestcudapool", gpu, 3, 1).unwrap();

        let _a = pool.acquire().unwrap();
        {
            let _b = pool.acquire().unwrap();
            assert_eq!(pool.space_left(), 1);
        }
        assert_eq!(pool.space_left(), 2);

        let _c = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 1);

        let _d = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        assert!(pool.acquire().is_none());
    }

    #[cfg(all(feature = "cuda", not(target_os = "macos")))]
    #[test]
    #[ignore]
    fn test_copy_roundtrip() {
        use cudarc::driver::CudaDevice;
        let gpu = CudaDevice::new(0).unwrap();
        let host_pool = CuHostMemoryPool::new("mytesthostpool", 3, || vec![0.0; 1]).unwrap();
        let cuda_pool = CuCudaPool::<f32>::new("mytestcudapool", gpu, 3, 1).unwrap();

        // Stage a known value in a host buffer and push it to the device.
        let mut staging = host_pool.acquire().unwrap();
        staging.with_inner_mut(|inner| match inner {
            CuHandleInner::Pooled(pooled) => pooled[0] = 42.0,
            CuHandleInner::Detached(_) => panic!(),
        });
        let device_handle = cuda_pool.copy_from(&mut staging);
        drop(staging);

        // Pull it back into a fresh host buffer and check it survived the trip.
        let mut readback = host_pool.acquire().unwrap();
        cuda_pool
            .copy_to_host_pool(&device_handle, &mut readback)
            .unwrap();

        let value = readback.with_inner(|inner| inner[0]);
        assert_eq!(value, 42.0);
    }
}