Skip to main content

cu29_runtime/
pool.rs

1use arrayvec::ArrayString;
2use bincode::de::Decoder;
3use bincode::enc::Encoder;
4use bincode::error::{DecodeError, EncodeError};
5use bincode::{Decode, Encode};
6use cu29_traits::CuResult;
7use hashbrown::HashMap;
8use object_pool::{Pool, ReusableOwned};
9use serde::de::{self, MapAccess, SeqAccess, Visitor};
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11use smallvec::SmallVec;
12use std::alloc::{Layout, alloc, dealloc};
13use std::cell::Cell;
14use std::cell::UnsafeCell;
15use std::fmt::Debug;
16use std::fs::OpenOptions;
17use std::marker::PhantomData;
18use std::mem::{align_of, size_of};
19use std::ops::{Deref, DerefMut};
20use std::path::{Path, PathBuf};
21use std::sync::{Arc, Mutex, MutexGuard, OnceLock};
22
23use memmap2::{MmapMut, MmapOptions};
24use tempfile::NamedTempFile;
25
26type PoolID = ArrayString<64>;
27
/// Trait exposed by a pool so it can be monitored by the monitoring API.
pub trait PoolMonitor: Send + Sync {
    /// A unique and descriptive identifier for the pool.
    fn id(&self) -> PoolID;

    /// Number of buffer slots left in the pool.
    fn space_left(&self) -> usize;

    /// Total size of the pool in number of buffers.
    fn total_size(&self) -> usize;

    /// Size of one buffer, in bytes.
    fn buffer_size(&self) -> usize;
}
42
43static POOL_REGISTRY: OnceLock<Mutex<HashMap<String, Arc<dyn PoolMonitor>>>> = OnceLock::new();
44const MAX_POOLS: usize = 16;
45
/// Lock `mutex`, recovering the guard even if another thread panicked while
/// holding it (poisoning is ignored on purpose: the protected data is still
/// usable for this file's bookkeeping structures).
fn lock_unpoison<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    mutex.lock().unwrap_or_else(|poison| poison.into_inner())
}
52
53// Register a pool to the global registry.
54fn register_pool(pool: Arc<dyn PoolMonitor>) {
55    POOL_REGISTRY
56        .get_or_init(|| Mutex::new(HashMap::new()))
57        .lock()
58        .unwrap_or_else(|poison| poison.into_inner())
59        .insert(pool.id().to_string(), pool);
60}
61
62type PoolStats = (PoolID, usize, usize, usize);
63
64/// Get the list of pools and their statistics.
65/// We use SmallVec here to avoid heap allocations while the stack is running.
66pub fn pools_statistics() -> SmallVec<[PoolStats; MAX_POOLS]> {
67    // Safely get the registry, returning empty stats if not initialized.
68    let registry_lock = match POOL_REGISTRY.get() {
69        Some(lock) => lock_unpoison(lock),
70        None => return SmallVec::new(), // Return empty if registry is not initialized
71    };
72    let mut result = SmallVec::with_capacity(MAX_POOLS);
73    for pool in registry_lock.values() {
74        result.push((
75            pool.id(),
76            pool.space_left(),
77            pool.total_size(),
78            pool.buffer_size(),
79        ));
80    }
81    result
82}
83
/// Basic Type that can be used in a buffer in a CuPool.
///
/// Mirrors bincode's `Encode`/`Decode` so element streams can be
/// (de)serialized without naming the bincode traits directly as bounds.
pub trait ElementType: Default + Sized + Copy + Debug + Unpin + Send + Sync {
    /// Serialize this element with the given bincode encoder.
    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError>;
    /// Deserialize one element from the given bincode decoder.
    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError>;
}
89
/// Blanket implementation for all types that are Sized, Copy, Encode, Decode and Debug.
impl<T> ElementType for T
where
    T: Default + Sized + Copy + Debug + Unpin + Send + Sync,
    T: Encode,
    T: Decode<()>,
{
    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
        // NOTE(review): intended to forward to `bincode::Encode::encode`;
        // confirm method resolution picks the `Encode` impl here and does not
        // recurse into `ElementType::encode` itself.
        self.encode(encoder)
    }

    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
        // Forwards to `bincode::Decode::decode` for the concrete type.
        Self::decode(decoder)
    }
}
105
/// Anything that behaves like a mutable, debuggable, thread-safe slice of
/// `ElementType` values (e.g. `Vec<E>` or `CuSharedMemoryBuffer<E>`).
pub trait ArrayLike: Deref<Target = [Self::Element]> + DerefMut + Debug + Sync + Send {
    /// The element type stored in the array.
    type Element: ElementType;
}
109
thread_local! {
    // Per-thread flag: when true, shared-memory handles serialize as
    // descriptors (path + offset) instead of copying the element data.
    static SHARED_HANDLE_SERIALIZATION_ENABLED: Cell<bool> = const { Cell::new(false) };
}
113
/// RAII guard returned by `enable_shared_handle_serialization`; restores the
/// previous per-thread flag value when dropped.
pub struct SharedHandleSerializationGuard {
    // Flag value to restore on drop.
    previous: bool,
}
117
impl Drop for SharedHandleSerializationGuard {
    fn drop(&mut self) {
        // Restore the thread-local flag to its value before the guard existed.
        SHARED_HANDLE_SERIALIZATION_ENABLED.with(|enabled| enabled.set(self.previous));
    }
}
123
124pub fn enable_shared_handle_serialization() -> SharedHandleSerializationGuard {
125    let previous = SHARED_HANDLE_SERIALIZATION_ENABLED.with(|enabled| {
126        let previous = enabled.get();
127        enabled.set(true);
128        previous
129    });
130    SharedHandleSerializationGuard { previous }
131}
132
133fn shared_handle_serialization_enabled() -> bool {
134    SHARED_HANDLE_SERIALIZATION_ENABLED.with(Cell::get)
135}
136
/// The primitive element types a shared-memory descriptor can carry.
///
/// Serialized in `snake_case` (e.g. `"u8"`, `"f32"`) inside
/// `CuSharedMemoryHandleDescriptor`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CuSharedMemoryElementType {
    U8,
    U16,
    U32,
    U64,
    I8,
    I16,
    I32,
    I64,
    F32,
    F64,
}
151
152impl CuSharedMemoryElementType {
153    pub fn of<E: ElementType + 'static>() -> Option<Self> {
154        let type_id = core::any::TypeId::of::<E>();
155        if type_id == core::any::TypeId::of::<u8>() {
156            Some(Self::U8)
157        } else if type_id == core::any::TypeId::of::<u16>() {
158            Some(Self::U16)
159        } else if type_id == core::any::TypeId::of::<u32>() {
160            Some(Self::U32)
161        } else if type_id == core::any::TypeId::of::<u64>() {
162            Some(Self::U64)
163        } else if type_id == core::any::TypeId::of::<i8>() {
164            Some(Self::I8)
165        } else if type_id == core::any::TypeId::of::<i16>() {
166            Some(Self::I16)
167        } else if type_id == core::any::TypeId::of::<i32>() {
168            Some(Self::I32)
169        } else if type_id == core::any::TypeId::of::<i64>() {
170            Some(Self::I64)
171        } else if type_id == core::any::TypeId::of::<f32>() {
172            Some(Self::F32)
173        } else if type_id == core::any::TypeId::of::<f64>() {
174            Some(Self::F64)
175        } else {
176            None
177        }
178    }
179}
180
/// Wire representation of a shared-memory buffer: a marker key (so a map can
/// be told apart from plain element data on deserialization), the backing
/// file path, and the slice location/shape inside that file.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CuSharedMemoryHandleDescriptor {
    /// Always `true`; the renamed key tags the serialized map as a shm handle.
    #[serde(rename = "__cu_shm_handle__")]
    pub marker: bool,
    /// Filesystem path of the mmap-backed file.
    pub path: String,
    /// Byte offset of the slice within the mapped region.
    pub offset_bytes: usize,
    /// Number of elements (not bytes) in the slice.
    pub len_elements: usize,
    /// Element type tag, validated when the descriptor is reopened.
    pub element_type: CuSharedMemoryElementType,
}
190
impl CuSharedMemoryHandleDescriptor {
    /// Build a descriptor; `marker` is always set to `true`.
    fn new(
        path: String,
        offset_bytes: usize,
        len_elements: usize,
        element_type: CuSharedMemoryElementType,
    ) -> Self {
        Self {
            marker: true,
            path,
            offset_bytes,
            len_elements,
            element_type,
        }
    }
}
207
/// A memory-mapped file shared between processes. Created regions own their
/// temp backing file; opened regions only reference the path.
struct CuSharedMemoryRegion {
    // Path of the mapped file; also the key in the process-wide region cache.
    path: PathBuf,
    // UnsafeCell because buffer views alias the mapping; see the Send/Sync
    // SAFETY notes below.
    mmap: UnsafeCell<MmapMut>,
    // Keeps the temp file alive (and the mapping valid) for creators; `None`
    // for regions opened by path.
    _backing_file: Option<NamedTempFile>,
}
213
impl CuSharedMemoryRegion {
    /// Create a region of `byte_len` bytes backed by a fresh temp file, map it
    /// read/write, and cache it so reopening by path is cheap.
    fn create(byte_len: usize) -> CuResult<Arc<Self>> {
        let file = NamedTempFile::new()
            .map_err(|e| cu29_traits::CuError::new_with_cause("create shared memory file", e))?;
        file.as_file()
            .set_len(byte_len as u64)
            .map_err(|e| cu29_traits::CuError::new_with_cause("size shared memory file", e))?;
        // SAFETY (map_mut): the file was just created and sized by us; the
        // mapping stays valid because `file` is retained in `_backing_file`.
        let mmap = unsafe {
            MmapOptions::new()
                .len(byte_len)
                .map_mut(file.as_file())
                .map_err(|e| cu29_traits::CuError::new_with_cause("map shared memory file", e))?
        };
        let region = Arc::new(Self {
            path: file.path().to_path_buf(),
            mmap: UnsafeCell::new(mmap),
            _backing_file: Some(file),
        });
        cache_shared_region(region.clone());
        Ok(region)
    }

    /// Open an existing region by path, reusing a live cached mapping when one
    /// exists so the same file is never mapped twice in this process.
    fn open(path: &Path) -> CuResult<Arc<Self>> {
        if let Some(region) = cached_shared_region(path) {
            return Ok(region);
        }

        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .open(path)
            .map_err(|e| cu29_traits::CuError::new_with_cause("open shared memory file", e))?;
        // Map the whole file as it currently exists on disk.
        let len = file
            .metadata()
            .map_err(|e| cu29_traits::CuError::new_with_cause("stat shared memory file", e))?
            .len() as usize;
        // SAFETY (map_mut): NOTE(review) — assumes the backing file is not
        // truncated by another process while mapped; confirm the producer
        // keeps it alive for the mapping's lifetime.
        let mmap = unsafe {
            MmapOptions::new()
                .len(len)
                .map_mut(&file)
                .map_err(|e| cu29_traits::CuError::new_with_cause("map shared memory file", e))?
        };
        let region = Arc::new(Self {
            path: path.to_path_buf(),
            mmap: UnsafeCell::new(mmap),
            _backing_file: None,
        });
        cache_shared_region(region.clone());
        Ok(region)
    }
}
265
266impl Debug for CuSharedMemoryRegion {
267    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
268        f.debug_struct("CuSharedMemoryRegion")
269            .field("path", &self.path)
270            .finish_non_exhaustive()
271    }
272}
273
// SAFETY:
// The raw mapping is never handed out directly; access to the mapped bytes is
// mediated through Copper handles and pool slot leasing, so cross-thread
// aliasing follows the same external synchronization as other mutable payload
// buffers.
unsafe impl Send for CuSharedMemoryRegion {}
// SAFETY:
// See `Send` rationale above.
unsafe impl Sync for CuSharedMemoryRegion {}
282
283fn shared_region_cache() -> &'static Mutex<HashMap<PathBuf, std::sync::Weak<CuSharedMemoryRegion>>>
284{
285    static CACHE: OnceLock<Mutex<HashMap<PathBuf, std::sync::Weak<CuSharedMemoryRegion>>>> =
286        OnceLock::new();
287    CACHE.get_or_init(|| Mutex::new(HashMap::new()))
288}
289
290fn cache_shared_region(region: Arc<CuSharedMemoryRegion>) {
291    lock_unpoison(shared_region_cache()).insert(region.path.clone(), Arc::downgrade(&region));
292}
293
294fn cached_shared_region(path: &Path) -> Option<Arc<CuSharedMemoryRegion>> {
295    lock_unpoison(shared_region_cache())
296        .get(path)
297        .and_then(std::sync::Weak::upgrade)
298}
299
300fn shared_slot_stride<E: ElementType>(len_elements: usize) -> usize {
301    let raw_bytes = len_elements
302        .checked_mul(size_of::<E>())
303        .expect("shared memory slot size overflow");
304    let alignment = align_of::<E>().max(1);
305    raw_bytes.div_ceil(alignment) * alignment
306}
307
/// A typed slice view (`len_elements` values of `E` starting at
/// `offset_bytes`) into a shared-memory region. Keeps the region alive via `Arc`.
#[derive(Debug)]
pub struct CuSharedMemoryBuffer<E: ElementType> {
    region: Arc<CuSharedMemoryRegion>,
    offset_bytes: usize,
    len_elements: usize,
    // Carries the element type without storing any `E`.
    _marker: PhantomData<E>,
}
315
impl<E: ElementType + 'static> CuSharedMemoryBuffer<E> {
    /// View `len_elements` values of `E` starting `offset_bytes` into `region`.
    /// Callers must ensure the range is in bounds and suitably aligned.
    fn from_region(
        region: Arc<CuSharedMemoryRegion>,
        offset_bytes: usize,
        len_elements: usize,
    ) -> Self {
        Self {
            region,
            offset_bytes,
            len_elements,
            _marker: PhantomData,
        }
    }

    /// Copy `data` into a brand-new single-slot shared-memory region that is
    /// not managed by any pool (used mainly when deserializing).
    pub fn from_vec_detached(data: Vec<E>) -> CuResult<Self> {
        let len_elements = data.len();
        // `.max(1)` guarantees a non-empty backing file even for empty data.
        let slot_stride = shared_slot_stride::<E>(len_elements.max(1));
        let region = CuSharedMemoryRegion::create(slot_stride)?;
        let mut buffer = Self::from_region(region, 0, len_elements);
        if !data.is_empty() {
            buffer.copy_from_slice(&data);
        }
        Ok(buffer)
    }

    /// Re-open the buffer described by `descriptor`, validating that the
    /// descriptor's element type matches `E`.
    ///
    /// # Errors
    /// Fails when `E` is not a supported primitive, on an element type
    /// mismatch, or when the backing file cannot be opened/mapped.
    pub fn from_descriptor(descriptor: &CuSharedMemoryHandleDescriptor) -> CuResult<Self> {
        let expected = CuSharedMemoryElementType::of::<E>()
            .ok_or_else(|| cu29_traits::CuError::from("unsupported shared memory element type"))?;
        if descriptor.element_type != expected {
            return Err(cu29_traits::CuError::from(
                "shared memory descriptor element type mismatch",
            ));
        }
        let region = CuSharedMemoryRegion::open(Path::new(&descriptor.path))?;
        Ok(Self::from_region(
            region,
            descriptor.offset_bytes,
            descriptor.len_elements,
        ))
    }

    /// Descriptor for handing this buffer to another process, or `None` when
    /// `E` is not one of the supported primitive element types.
    pub fn descriptor(&self) -> Option<CuSharedMemoryHandleDescriptor>
    where
        E: 'static,
    {
        CuSharedMemoryElementType::of::<E>().map(|element_type| {
            CuSharedMemoryHandleDescriptor::new(
                self.region.path.display().to_string(),
                self.offset_bytes,
                self.len_elements,
                element_type,
            )
        })
    }
}
371
impl<E: ElementType> Deref for CuSharedMemoryBuffer<E> {
    type Target = [E];

    fn deref(&self) -> &Self::Target {
        // SAFETY: the buffer was constructed over an in-bounds range of the
        // mapping, and the `Arc<CuSharedMemoryRegion>` keeps the map alive for
        // `&self`'s lifetime. Cross-thread aliasing is externally synchronized
        // (see the region's Send/Sync notes).
        let ptr = unsafe { (*self.region.mmap.get()).as_ptr().add(self.offset_bytes) as *const E };
        unsafe { std::slice::from_raw_parts(ptr, self.len_elements) }
    }
}
380
impl<E: ElementType> DerefMut for CuSharedMemoryBuffer<E> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: same range/lifetime argument as `deref`; `&mut self` plus
        // the pool's slot leasing provide the exclusivity required by
        // `from_raw_parts_mut`.
        let ptr = unsafe {
            (*self.region.mmap.get())
                .as_mut_ptr()
                .add(self.offset_bytes) as *mut E
        };
        unsafe { std::slice::from_raw_parts_mut(ptr, self.len_elements) }
    }
}
391
// A shared-memory buffer is usable anywhere an ArrayLike payload is expected.
impl<E: ElementType> ArrayLike for CuSharedMemoryBuffer<E> {
    type Element = E;
}
395
impl<E: ElementType> Encode for CuSharedMemoryBuffer<E> {
    /// Encodes as a `u64` length followed by each element — i.e. the contents
    /// are copied out; the shared-memory identity is not preserved by bincode.
    fn encode<Enc: Encoder>(&self, encoder: &mut Enc) -> Result<(), EncodeError> {
        let len = self.len_elements as u64;
        Encode::encode(&len, encoder)?;
        for value in self.deref() {
            value.encode(encoder)?;
        }
        Ok(())
    }
}
406
407impl<E: ElementType + 'static> Decode<()> for CuSharedMemoryBuffer<E> {
408    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
409        let len = <u64 as Decode<()>>::decode(decoder)? as usize;
410        let mut vec = Vec::with_capacity(len);
411        for _ in 0..len {
412            vec.push(E::decode(decoder)?);
413        }
414        Self::from_vec_detached(vec).map_err(|e| DecodeError::OtherString(e.to_string()))
415    }
416}
417
/// A Handle to a Buffer.
/// For onboard usages, the buffer should be Pooled (ie, coming from a preallocated pool).
/// The Detached version is for offline usages where we don't really need a pool to deserialize them.
pub enum CuHandleInner<T: Debug> {
    /// Leased from a pool; the buffer returns to the pool when dropped.
    Pooled(ReusableOwned<T>),
    /// Owned outright. Should only be used in offline cases (e.g. deserialization).
    Detached(T),
}
425
426impl<T> Debug for CuHandleInner<T>
427where
428    T: Debug,
429{
430    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
431        match self {
432            CuHandleInner::Pooled(r) => {
433                write!(f, "Pooled: {:?}", r.deref())
434            }
435            CuHandleInner::Detached(r) => write!(f, "Detached: {r:?}"),
436        }
437    }
438}
439
440impl<T: ArrayLike> Deref for CuHandleInner<T> {
441    type Target = [T::Element];
442
443    fn deref(&self) -> &Self::Target {
444        match self {
445            CuHandleInner::Pooled(pooled) => pooled,
446            CuHandleInner::Detached(detached) => detached,
447        }
448    }
449}
450
451impl<T: ArrayLike> DerefMut for CuHandleInner<T> {
452    fn deref_mut(&mut self) -> &mut Self::Target {
453        match self {
454            CuHandleInner::Pooled(pooled) => pooled.deref_mut(),
455            CuHandleInner::Detached(detached) => detached,
456        }
457    }
458}
459
/// A shareable handle to an Array coming from a pool (either host or device).
/// Cloning is cheap (an `Arc` bump); all clones see the same buffer.
#[derive(Debug)]
pub struct CuHandle<T: ArrayLike>(Arc<Mutex<CuHandleInner<T>>>);
463
464impl<T: ArrayLike> Clone for CuHandle<T> {
465    fn clone(&self) -> Self {
466        Self(self.0.clone())
467    }
468}
469
470impl<T: ArrayLike> Deref for CuHandle<T> {
471    type Target = Arc<Mutex<CuHandleInner<T>>>;
472
473    fn deref(&self) -> &Self::Target {
474        &self.0
475    }
476}
477
478impl<T: ArrayLike> CuHandle<T> {
479    /// Create a new CuHandle not part of a Pool (not for onboard usages, use pools instead)
480    pub fn new_detached(inner: T) -> Self {
481        CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(inner))))
482    }
483
484    /// Safely access the inner value, applying a closure to it.
485    pub fn with_inner<R>(&self, f: impl FnOnce(&CuHandleInner<T>) -> R) -> R {
486        let lock = lock_unpoison(&self.0);
487        f(&*lock)
488    }
489
490    /// Mutably access the inner value, applying a closure to it.
491    pub fn with_inner_mut<R>(&self, f: impl FnOnce(&mut CuHandleInner<T>) -> R) -> R {
492        let mut lock = lock_unpoison(&self.0);
493        f(&mut *lock)
494    }
495}
496
497impl<U> Serialize for CuHandle<Vec<U>>
498where
499    U: ElementType + Serialize + 'static,
500{
501    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
502        let inner = lock_unpoison(&self.0);
503        match inner.deref() {
504            CuHandleInner::Pooled(pooled) => pooled.deref().serialize(serializer),
505            CuHandleInner::Detached(detached) => detached.serialize(serializer),
506        }
507    }
508}
509
510impl<'de, U> Deserialize<'de> for CuHandle<Vec<U>>
511where
512    U: ElementType + Deserialize<'de> + 'static,
513{
514    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
515        Vec::<U>::deserialize(deserializer).map(CuHandle::new_detached)
516    }
517}
518
impl<U> Serialize for CuHandle<CuSharedMemoryBuffer<U>>
where
    U: ElementType + Serialize + 'static,
{
    /// Serializes either a zero-copy descriptor (when the per-thread flag set
    /// by `enable_shared_handle_serialization` is on and `U` is a supported
    /// primitive) or, otherwise, the raw element sequence.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let inner = lock_unpoison(&self.0);
        let buffer = match inner.deref() {
            CuHandleInner::Pooled(pooled) => pooled.deref(),
            CuHandleInner::Detached(detached) => detached,
        };

        if shared_handle_serialization_enabled()
            && let Some(descriptor) = buffer.descriptor()
        {
            return descriptor.serialize(serializer);
        }

        // Fallback: copy the data out as a plain element sequence.
        buffer.deref().serialize(serializer)
    }
}
539
impl<'de, U> Deserialize<'de> for CuHandle<CuSharedMemoryBuffer<U>>
where
    U: ElementType + Deserialize<'de> + 'static,
{
    /// Accepts either wire form produced by the matching `Serialize` impl: a
    /// descriptor map (re-attached to the existing region, zero-copy) or a
    /// plain element sequence (copied into a fresh detached region).
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        // Untagged union of the two possible wire representations.
        enum Repr<U> {
            Descriptor(CuSharedMemoryHandleDescriptor),
            Data(Vec<U>),
        }

        impl<'de, U> Deserialize<'de> for Repr<U>
        where
            U: ElementType + Deserialize<'de>,
        {
            fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
                struct ReprVisitor<U>(PhantomData<U>);

                impl<'de, U> Visitor<'de> for ReprVisitor<U>
                where
                    U: ElementType + Deserialize<'de>,
                {
                    type Value = Repr<U>;

                    fn expecting(
                        &self,
                        formatter: &mut std::fmt::Formatter<'_>,
                    ) -> std::fmt::Result {
                        formatter
                            .write_str("a shared-memory handle descriptor or an element sequence")
                    }

                    // A sequence means raw element data.
                    fn visit_seq<A: SeqAccess<'de>>(self, seq: A) -> Result<Self::Value, A::Error> {
                        let data =
                            Vec::<U>::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
                        Ok(Repr::Data(data))
                    }

                    // A map means a descriptor (tagged by "__cu_shm_handle__").
                    fn visit_map<A: MapAccess<'de>>(self, map: A) -> Result<Self::Value, A::Error> {
                        let descriptor = CuSharedMemoryHandleDescriptor::deserialize(
                            de::value::MapAccessDeserializer::new(map),
                        )?;
                        Ok(Repr::Descriptor(descriptor))
                    }
                }

                // `deserialize_any` requires a self-describing format (e.g. JSON).
                deserializer.deserialize_any(ReprVisitor(PhantomData))
            }
        }

        match Repr::<U>::deserialize(deserializer)? {
            Repr::Descriptor(descriptor) => CuSharedMemoryBuffer::from_descriptor(&descriptor)
                .map(CuHandle::new_detached)
                .map_err(de::Error::custom),
            Repr::Data(data) => CuSharedMemoryBuffer::from_vec_detached(data)
                .map(CuHandle::new_detached)
                .map_err(de::Error::custom),
        }
    }
}
599
impl<T: ArrayLike + Encode> Encode for CuHandle<T>
where
    <T as ArrayLike>::Element: 'static,
{
    /// Encodes the inner buffer, first reporting its payload size in bytes to
    /// the monitoring subsystem.
    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
        let inner = lock_unpoison(&self.0);
        crate::monitoring::record_payload_handle_bytes(
            inner.deref().len() * size_of::<T::Element>(),
        );
        match inner.deref() {
            CuHandleInner::Pooled(pooled) => pooled.deref().encode(encoder),
            CuHandleInner::Detached(detached) => detached.encode(encoder),
        }
    }
}
615
impl<T: ArrayLike> Default for CuHandle<T> {
    /// Deliberately panics: a handle is meaningless without a buffer.
    // NOTE(review): presumably this impl only exists to satisfy `Default`
    // bounds elsewhere — confirm before ever calling it.
    fn default() -> Self {
        panic!("Cannot create a default CuHandle")
    }
}
621
622impl<U: ElementType + Decode<()> + 'static> Decode<()> for CuHandle<Vec<U>> {
623    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
624        let vec: Vec<U> = Vec::decode(decoder)?;
625        Ok(CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(vec)))))
626    }
627}
628
629impl<U: ElementType + Decode<()> + 'static> Decode<()> for CuHandle<CuSharedMemoryBuffer<U>> {
630    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
631        let buffer = CuSharedMemoryBuffer::<U>::decode(decoder)?;
632        Ok(CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(
633            buffer,
634        )))))
635    }
636}
637
/// A CuPool is a pool of buffers that can be shared between different parts of the code.
/// Handles can be stored locally in the tasks and shared between them.
pub trait CuPool<T: ArrayLike>: PoolMonitor {
    /// Acquire a buffer from the pool; `None` when the pool is exhausted.
    fn acquire(&self) -> Option<CuHandle<T>>;

    /// Copy data from a handle to a new handle from the pool.
    // NOTE(review): both impls in this file only read `from`; the `&mut` is
    // presumably part of the intended contract — confirm before relaxing it.
    fn copy_from<O>(&self, from: &mut CuHandle<O>) -> CuHandle<T>
    where
        O: ArrayLike<Element = T::Element>;
}
649
/// A device memory pool can copy data from a device to a host memory pool on top.
pub trait DeviceCuPool<T: ArrayLike>: CuPool<T> {
    /// Takes a handle to a device buffer and copies it into a host buffer pool.
    /// It returns a new handle from the host pool with the data from the device handle given.
    fn copy_to_host_pool<O>(
        &self,
        from_device_handle: &CuHandle<T>,
        to_host_handle: &mut CuHandle<O>,
    ) -> CuResult<()>
    where
        O: ArrayLike<Element = T::Element>;
}
662
/// A pool of host memory buffers.
pub struct CuHostMemoryPool<T> {
    /// Unique identifier reported to the monitoring API.
    id: PoolID,
    /// Underlying pool of host buffers.
    // Being an Arc is a requirement of try_pull_owned() so buffers can refer back to the pool.
    pool: Arc<Pool<T>>,
    /// Total number of buffers in the pool.
    size: usize,
    /// Size of one buffer, in bytes.
    buffer_size: usize,
}
672
673impl<T: ArrayLike + 'static> CuHostMemoryPool<T> {
674    pub fn new<F>(id: &str, size: usize, buffer_initializer: F) -> CuResult<Arc<Self>>
675    where
676        F: Fn() -> T,
677    {
678        let pool = Arc::new(Pool::new(size, buffer_initializer));
679        let buffer_size = pool.try_pull().unwrap().len() * size_of::<T::Element>();
680
681        let og = Self {
682            id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
683            pool,
684            size,
685            buffer_size,
686        };
687        let og = Arc::new(og);
688        register_pool(og.clone());
689        Ok(og)
690    }
691}
692
impl<T: ArrayLike> PoolMonitor for CuHostMemoryPool<T> {
    fn id(&self) -> PoolID {
        self.id
    }

    /// Buffers currently available (not leased out).
    fn space_left(&self) -> usize {
        self.pool.len()
    }

    fn total_size(&self) -> usize {
        self.size
    }

    /// Size of one buffer, in bytes.
    fn buffer_size(&self) -> usize {
        self.buffer_size
    }
}
710
711impl<T: ArrayLike> CuPool<T> for CuHostMemoryPool<T> {
712    fn acquire(&self) -> Option<CuHandle<T>> {
713        let owned_object = self.pool.try_pull_owned(); // Use the owned version
714
715        owned_object.map(|reusable| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(reusable)))))
716    }
717
718    fn copy_from<O: ArrayLike<Element = T::Element>>(&self, from: &mut CuHandle<O>) -> CuHandle<T> {
719        let to_handle = self.acquire().expect("No available buffers in the pool");
720
721        match lock_unpoison(&from.0).deref() {
722            CuHandleInner::Detached(source) => match lock_unpoison(&to_handle.0).deref_mut() {
723                CuHandleInner::Detached(destination) => {
724                    destination.copy_from_slice(source);
725                }
726                CuHandleInner::Pooled(destination) => {
727                    destination.copy_from_slice(source);
728                }
729            },
730            CuHandleInner::Pooled(source) => match lock_unpoison(&to_handle.0).deref_mut() {
731                CuHandleInner::Detached(destination) => {
732                    destination.copy_from_slice(source);
733                }
734                CuHandleInner::Pooled(destination) => {
735                    destination.copy_from_slice(source);
736                }
737            },
738        }
739        to_handle
740    }
741}
742
/// A pool of fixed-size shared-memory buffers that can be leased to a child
/// process without copying the underlying bytes.
pub struct CuSharedMemoryPool<E: ElementType> {
    /// Unique identifier reported to the monitoring API.
    id: PoolID,
    /// Underlying pool of slot buffers, all carved from one mmap region.
    pool: Arc<Pool<CuSharedMemoryBuffer<E>>>,
    /// Total number of buffers.
    size: usize,
    /// Size of one buffer, in bytes.
    buffer_size: usize,
}
751
impl<E: ElementType + 'static> CuSharedMemoryPool<E> {
    /// Build a pool of `size` shared-memory buffers, each holding
    /// `elements_per_buffer` values of `E`, all carved out of a single region.
    ///
    /// # Errors
    /// Fails when the total region size overflows `usize` or when `id` does
    /// not fit a `PoolID`.
    pub fn new(id: &str, size: usize, elements_per_buffer: usize) -> CuResult<Arc<Self>> {
        // `.max(1)` keeps the stride non-zero even for zero-length buffers.
        let slot_stride = shared_slot_stride::<E>(elements_per_buffer.max(1));
        let region = CuSharedMemoryRegion::create(
            slot_stride
                .checked_mul(size)
                .ok_or_else(|| cu29_traits::CuError::from("shared memory pool size overflow"))?,
        )?;
        // Each new pool buffer claims the next free slot. NOTE(review): the
        // atomic presumably guards against the pool calling the initializer
        // from multiple threads — confirm against object_pool's behavior.
        let next_slot = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let initializer_region = region.clone();
        let initializer_next_slot = next_slot.clone();
        let pool = Arc::new(Pool::new(size, move || {
            let slot = initializer_next_slot.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            // More than `size` initializations would walk off the region end.
            assert!(slot < size, "shared memory pool slot index overflow");
            CuSharedMemoryBuffer::from_region(
                initializer_region.clone(),
                slot * slot_stride,
                elements_per_buffer,
            )
        }));

        let pool = Arc::new(Self {
            id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
            pool,
            size,
            buffer_size: elements_per_buffer * size_of::<E>(),
        });
        register_pool(pool.clone());
        Ok(pool)
    }
}
783
impl<E: ElementType> PoolMonitor for CuSharedMemoryPool<E> {
    fn id(&self) -> PoolID {
        self.id
    }

    /// Buffers currently available (not leased out).
    fn space_left(&self) -> usize {
        self.pool.len()
    }

    fn total_size(&self) -> usize {
        self.size
    }

    /// Size of one buffer, in bytes.
    fn buffer_size(&self) -> usize {
        self.buffer_size
    }
}
801
802impl<E: ElementType> CuPool<CuSharedMemoryBuffer<E>> for CuSharedMemoryPool<E> {
803    fn acquire(&self) -> Option<CuHandle<CuSharedMemoryBuffer<E>>> {
804        self.pool
805            .try_pull_owned()
806            .map(|reusable| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(reusable)))))
807    }
808
809    fn copy_from<O>(&self, from: &mut CuHandle<O>) -> CuHandle<CuSharedMemoryBuffer<E>>
810    where
811        O: ArrayLike<Element = E>,
812    {
813        let to_handle = self.acquire().expect("No available buffers in the pool");
814
815        match lock_unpoison(&from.0).deref() {
816            CuHandleInner::Detached(source) => match lock_unpoison(&to_handle.0).deref_mut() {
817                CuHandleInner::Detached(destination) => {
818                    destination.copy_from_slice(source);
819                }
820                CuHandleInner::Pooled(destination) => {
821                    destination.copy_from_slice(source);
822                }
823            },
824            CuHandleInner::Pooled(source) => match lock_unpoison(&to_handle.0).deref_mut() {
825                CuHandleInner::Detached(destination) => {
826                    destination.copy_from_slice(source);
827                }
828                CuHandleInner::Pooled(destination) => {
829                    destination.copy_from_slice(source);
830                }
831            },
832        }
833        to_handle
834    }
835}
836
// Plain heap-allocated vectors are valid pool buffers too.
impl<E: ElementType + 'static> ArrayLike for Vec<E> {
    type Element = E;
}
840
841#[cfg(all(feature = "cuda", not(target_os = "macos")))]
842mod cuda {
843    use super::*;
844    use cu29_traits::CuError;
845    use cudarc::driver::{
846        CudaContext, CudaSlice, CudaStream, DeviceRepr, HostSlice, SyncOnDrop, ValidAsZeroBits,
847    };
848    use std::sync::Arc;
849
    /// Newtype over a device-side `CudaSlice`; the data must be copied into a
    /// host pool before its contents can be read on the CPU.
    #[derive(Debug)]
    pub struct CudaSliceWrapper<E>(CudaSlice<E>);
852
    impl<E> Deref for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        type Target = [E];

        /// Host access to device memory is not possible here; always panics.
        fn deref(&self) -> &Self::Target {
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }
864
    impl<E> DerefMut for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        /// Host access to device memory is not possible here; always panics.
        fn deref_mut(&mut self) -> &mut Self::Target {
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }
873
    // Device slices participate in the pool machinery like any ArrayLike.
    impl<E: ElementType> ArrayLike for CudaSliceWrapper<E> {
        type Element = E;
    }
877
    impl<E> CudaSliceWrapper<E> {
        /// Borrow the underlying device slice (for kernel launches, copies, …).
        pub fn as_cuda_slice(&self) -> &CudaSlice<E> {
            &self.0
        }

        /// Mutably borrow the underlying device slice.
        pub fn as_cuda_slice_mut(&mut self) -> &mut CudaSlice<E> {
            &mut self.0
        }
    }
887
    // Create a wrapper type to bridge between ArrayLike and HostSlice
    // (read-only flavor: wraps a shared borrow).
    pub struct HostSliceWrapper<'a, T: ArrayLike> {
        inner: &'a T,
    }
892
    impl<T: ArrayLike> HostSlice<T::Element> for HostSliceWrapper<'_, T> {
        fn len(&self) -> usize {
            self.inner.len()
        }

        // SAFETY: HostSlice requires the returned slice to remain valid for 'b;
        // the shared borrow held in `self.inner` guarantees that.
        unsafe fn stream_synced_slice<'b>(
            &'b self,
            stream: &'b CudaStream,
        ) -> (&'b [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref(), SyncOnDrop::sync_stream(stream))
        }

        // SAFETY: This wrapper cannot provide mutable access; callers must not rely on this.
        unsafe fn stream_synced_mut_slice<'b>(
            &'b mut self,
            _stream: &'b CudaStream,
        ) -> (&'b mut [T::Element], SyncOnDrop<'b>) {
            panic!("Cannot get mutable reference from immutable wrapper")
        }
    }
914
    // Mutable wrapper: same bridge as `HostSliceWrapper`, but holding an
    // exclusive borrow so device-to-host copies can write into the buffer.
    pub struct HostSliceMutWrapper<'a, T: ArrayLike> {
        inner: &'a mut T,
    }
919
    impl<T: ArrayLike> HostSlice<T::Element> for HostSliceMutWrapper<'_, T> {
        fn len(&self) -> usize {
            self.inner.len()
        }

        // SAFETY: HostSlice requires the returned slice to remain valid for 'b.
        // The shared slice borrows from `self.inner`, which outlives 'b.
        unsafe fn stream_synced_slice<'b>(
            &'b self,
            stream: &'b CudaStream,
        ) -> (&'b [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref(), SyncOnDrop::sync_stream(stream))
        }

        // SAFETY: HostSlice requires the returned slice to remain valid for 'b.
        // The exclusive slice borrows from `self.inner`, which outlives 'b.
        unsafe fn stream_synced_mut_slice<'b>(
            &'b mut self,
            stream: &'b CudaStream,
        ) -> (&'b mut [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref_mut(), SyncOnDrop::sync_stream(stream))
        }
    }
941
    // Add helper methods to the CuCudaPool implementation
    impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuCudaPool<E> {
        // Helper method to get a HostSliceWrapper from a CuHandleInner.
        // Both variants expose the same host-side `O`, so the wrapper is
        // built identically either way; the match only destructures.
        fn get_host_slice_wrapper<O: ArrayLike<Element = E>>(
            handle_inner: &CuHandleInner<O>,
        ) -> HostSliceWrapper<'_, O> {
            match handle_inner {
                CuHandleInner::Pooled(pooled) => HostSliceWrapper { inner: pooled },
                CuHandleInner::Detached(detached) => HostSliceWrapper { inner: detached },
            }
        }

        // Helper method to get a HostSliceMutWrapper from a CuHandleInner
        // (mutable counterpart of `get_host_slice_wrapper`, used by D2H copies).
        fn get_host_slice_mut_wrapper<O: ArrayLike<Element = E>>(
            handle_inner: &mut CuHandleInner<O>,
        ) -> HostSliceMutWrapper<'_, O> {
            match handle_inner {
                CuHandleInner::Pooled(pooled) => HostSliceMutWrapper { inner: pooled },
                CuHandleInner::Detached(detached) => HostSliceMutWrapper { inner: detached },
            }
        }
    }
    /// A pool of CUDA memory buffers.
    pub struct CuCudaPool<E>
    where
        // NOTE(review): the impl blocks below omit `Unpin`; presumably
        // `ElementType` already implies it — confirm against its definition.
        E: ElementType + ValidAsZeroBits + DeviceRepr + Unpin,
    {
        id: PoolID,                           // identifier reported to monitoring
        stream: Arc<CudaStream>,              // stream used for allocation and copies
        pool: Arc<Pool<CudaSliceWrapper<E>>>, // free-list of device buffers
        nb_buffers: usize,                    // total buffers allocated at construction
        nb_element_per_buffer: usize,         // elements (not bytes) per buffer
    }
975
976    impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuCudaPool<E> {
977        #[allow(dead_code)]
978        pub fn new(
979            id: &'static str,
980            ctx: Arc<CudaContext>,
981            nb_buffers: usize,
982            nb_element_per_buffer: usize,
983        ) -> CuResult<Self> {
984            let stream = ctx.default_stream();
985            let pool = (0..nb_buffers)
986                .map(|_| {
987                    stream
988                        .alloc_zeros(nb_element_per_buffer)
989                        .map(CudaSliceWrapper)
990                        .map_err(|_| "Failed to allocate device memory")
991                })
992                .collect::<Result<Vec<_>, _>>()?;
993
994            Ok(Self {
995                id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
996                stream,
997                pool: Arc::new(Pool::from_vec(pool)),
998                nb_buffers,
999                nb_element_per_buffer,
1000            })
1001        }
1002    }
1003
    impl<E> PoolMonitor for CuCudaPool<E>
    where
        E: DeviceRepr + ElementType + ValidAsZeroBits,
    {
        /// Identifier reported to the monitoring API.
        fn id(&self) -> PoolID {
            self.id
        }

        /// Number of buffers currently sitting idle in the pool.
        fn space_left(&self) -> usize {
            self.pool.len()
        }

        /// Total number of buffers this pool was created with.
        fn total_size(&self) -> usize {
            self.nb_buffers
        }

        /// Size of one buffer in bytes (element count times element size).
        fn buffer_size(&self) -> usize {
            self.nb_element_per_buffer * size_of::<E>()
        }
    }
1024
1025    impl<E> CuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
1026    where
1027        E: DeviceRepr + ElementType + ValidAsZeroBits,
1028    {
1029        fn acquire(&self) -> Option<CuHandle<CudaSliceWrapper<E>>> {
1030            self.pool
1031                .try_pull_owned()
1032                .map(|x| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(x)))))
1033        }
1034
1035        fn copy_from<O>(&self, from_handle: &mut CuHandle<O>) -> CuHandle<CudaSliceWrapper<E>>
1036        where
1037            O: ArrayLike<Element = E>,
1038        {
1039            let to_handle = self.acquire().expect("No available buffers in the pool");
1040
1041            {
1042                let from_lock = lock_unpoison(&from_handle.0);
1043                let mut to_lock = lock_unpoison(&to_handle.0);
1044
1045                match &mut *to_lock {
1046                    CuHandleInner::Detached(CudaSliceWrapper(to)) => {
1047                        let wrapper = Self::get_host_slice_wrapper(&*from_lock);
1048                        self.stream
1049                            .memcpy_htod(&wrapper, to)
1050                            .expect("Failed to copy data to device");
1051                    }
1052                    CuHandleInner::Pooled(to) => {
1053                        let wrapper = Self::get_host_slice_wrapper(&*from_lock);
1054                        self.stream
1055                            .memcpy_htod(&wrapper, to.as_cuda_slice_mut())
1056                            .expect("Failed to copy data to device");
1057                    }
1058                }
1059            } // locks are dropped here
1060            to_handle // now we can safely return to_handle
1061        }
1062    }
1063
    impl<E> DeviceCuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
    where
        E: ElementType + ValidAsZeroBits + DeviceRepr,
    {
        /// Copy from device to host
        ///
        /// Locks the device handle first, then the host handle; both guards
        /// are held across the copy. Errors if either mutex is poisoned or
        /// the device-to-host memcpy fails.
        fn copy_to_host_pool<O>(
            &self,
            device_handle: &CuHandle<CudaSliceWrapper<E>>,
            host_handle: &mut CuHandle<O>,
        ) -> Result<(), CuError>
        where
            O: ArrayLike<Element = E>,
        {
            let device_lock = device_handle.lock().map_err(|e| {
                CuError::from("Device handle mutex poisoned").add_cause(&e.to_string())
            })?;
            let mut host_lock = host_handle.lock().map_err(|e| {
                CuError::from("Host handle mutex poisoned").add_cause(&e.to_string())
            })?;
            // The device-side source slice is the same regardless of whether
            // the handle is pooled or detached.
            let src = match &*device_lock {
                CuHandleInner::Pooled(source) => source.as_cuda_slice(),
                CuHandleInner::Detached(source) => source.as_cuda_slice(),
            };
            let mut wrapper = Self::get_host_slice_mut_wrapper(&mut *host_lock);
            self.stream.memcpy_dtoh(src, &mut wrapper).map_err(|e| {
                CuError::from("Failed to copy data from device to host").add_cause(&e.to_string())
            })?;
            Ok(())
        }
    }
1094}
1095
#[derive(Debug)]
/// A buffer that is aligned to a specific size with the Element of type E.
///
/// Invariants: `ptr` was allocated with exactly `layout` and points to
/// `size` initialized elements (established by `new`); the allocation is
/// exclusively owned by this struct and released in `Drop`.
pub struct AlignedBuffer<E: ElementType> {
    ptr: *mut E,    // base of the allocation, aligned to `layout.align()`
    size: usize,    // number of initialized elements
    layout: Layout, // exact layout passed to `alloc`, reused by `dealloc`
}
1103
1104impl<E: ElementType> AlignedBuffer<E> {
1105    pub fn new(num_elements: usize, alignment: usize) -> Self {
1106        assert!(
1107            num_elements > 0 && size_of::<E>() > 0,
1108            "AlignedBuffer requires a non-zero element count and non-zero-sized element type"
1109        );
1110        let alignment = alignment.max(align_of::<E>());
1111        let alloc_size = num_elements
1112            .checked_mul(size_of::<E>())
1113            .expect("AlignedBuffer allocation size overflow");
1114        let layout = Layout::from_size_align(alloc_size, alignment).unwrap();
1115        // SAFETY: layout describes a valid, non-zero allocation request.
1116        let ptr = unsafe { alloc(layout) as *mut E };
1117        if ptr.is_null() {
1118            panic!("Failed to allocate memory");
1119        }
1120        // SAFETY: ptr is valid for writes of `num_elements` elements.
1121        unsafe {
1122            for i in 0..num_elements {
1123                std::ptr::write(ptr.add(i), E::default());
1124            }
1125        }
1126        Self {
1127            ptr,
1128            size: num_elements,
1129            layout,
1130        }
1131    }
1132}
1133
impl<E: ElementType> Deref for AlignedBuffer<E> {
    type Target = [E];

    /// Views the buffer as a shared slice of its `size` elements.
    fn deref(&self) -> &Self::Target {
        // SAFETY: `new` initializes all elements and keeps the pointer aligned.
        unsafe { std::slice::from_raw_parts(self.ptr, self.size) }
    }
}
1142
impl<E: ElementType> DerefMut for AlignedBuffer<E> {
    /// Views the buffer as an exclusive slice of its `size` elements.
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: `new` initializes all elements and keeps the pointer aligned.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.size) }
    }
}
1149
1150impl<E: ElementType> Drop for AlignedBuffer<E> {
1151    fn drop(&mut self) {
1152        // SAFETY: `ptr` was allocated with `layout` in `new`.
1153        unsafe { dealloc(self.ptr as *mut u8, self.layout) }
1154    }
1155}
1156
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool() {
        use std::cell::RefCell;
        // Backing store the pool's factory drains one Vec at a time.
        let backing = RefCell::new(vec![vec![1], vec![2], vec![3]]);
        let snapshot = backing.borrow().clone();
        let expected: Vec<&[i32]> = snapshot.iter().map(|v| v.as_slice()).collect();
        let pool =
            CuHostMemoryPool::new("mytestcudapool", 3, || backing.borrow_mut().pop().unwrap())
                .unwrap();

        let first = pool.acquire().unwrap();
        {
            let second = pool.acquire().unwrap();
            assert!(expected.contains(&first.lock().unwrap().deref().deref()));
            assert!(expected.contains(&second.lock().unwrap().deref().deref()));
            assert_eq!(pool.space_left(), 1);
        } // `second` returns to the pool here
        assert_eq!(pool.space_left(), 2);

        let third = pool.acquire().unwrap();
        assert!(expected.contains(&third.lock().unwrap().deref().deref()));

        assert_eq!(pool.space_left(), 1);

        let _fourth = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        // Exhausted pool: acquire must fail rather than block.
        assert!(pool.acquire().is_none());
    }

    #[cfg(all(feature = "cuda", has_nvidia_gpu))]
    #[test]
    fn test_cuda_pool() {
        use crate::pool::cuda::CuCudaPool;
        use cudarc::driver::CudaContext;
        let ctx = CudaContext::new(0).unwrap();
        let pool = CuCudaPool::<f32>::new("mytestcudapool", ctx, 3, 1).unwrap();

        let _held_a = pool.acquire().unwrap();

        {
            let _held_b = pool.acquire().unwrap();
            assert_eq!(pool.space_left(), 1);
        } // inner buffer released back to the pool
        assert_eq!(pool.space_left(), 2);

        let _held_c = pool.acquire().unwrap();

        assert_eq!(pool.space_left(), 1);

        let _held_d = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        // All three buffers are out: no more handles available.
        assert!(pool.acquire().is_none());
    }

    #[cfg(all(feature = "cuda", has_nvidia_gpu))]
    #[test]
    fn test_copy_roundtrip() {
        use crate::pool::cuda::CuCudaPool;
        use cudarc::driver::CudaContext;
        let ctx = CudaContext::new(0).unwrap();
        let host_pool = CuHostMemoryPool::new("mytesthostpool", 3, || vec![0.0; 1]).unwrap();
        let cuda_pool = CuCudaPool::<f32>::new("mytestcudapool", ctx, 3, 1).unwrap();

        let device_handle = {
            let mut src = host_pool.acquire().unwrap();
            {
                let mut guard = src.lock().unwrap();
                match *guard {
                    CuHandleInner::Pooled(ref mut buf) => buf[0] = 42.0,
                    _ => panic!(),
                }
            }

            // Host -> device.
            cuda_pool.copy_from(&mut src)
        };

        // Device -> host, into a fresh host buffer.
        let mut dst = host_pool.acquire().unwrap();
        cuda_pool
            .copy_to_host_pool(&device_handle, &mut dst)
            .unwrap();

        assert_eq!(dst.lock().unwrap().deref().deref()[0], 42.0);
    }
}