1use arrayvec::ArrayString;
2use bincode::de::Decoder;
3use bincode::enc::Encoder;
4use bincode::error::{DecodeError, EncodeError};
5use bincode::{Decode, Encode};
6use cu29_traits::CuResult;
7use hashbrown::HashMap;
8use object_pool::{Pool, ReusableOwned};
9use serde::de::{self, MapAccess, SeqAccess, Visitor};
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11use smallvec::SmallVec;
12use std::alloc::{Layout, alloc, dealloc};
13use std::cell::Cell;
14use std::cell::UnsafeCell;
15use std::fmt::Debug;
16use std::fs::OpenOptions;
17use std::marker::PhantomData;
18use std::mem::{align_of, size_of};
19use std::ops::{Deref, DerefMut};
20use std::path::{Path, PathBuf};
21use std::sync::{Arc, Mutex, MutexGuard, OnceLock};
22
23use memmap2::{MmapMut, MmapOptions};
24use tempfile::NamedTempFile;
25
/// Fixed-capacity identifier for a pool (no heap allocation for names).
type PoolID = ArrayString<64>;

/// Runtime-introspection interface implemented by every pool so it can be
/// listed in the global registry and reported by `pools_statistics`.
pub trait PoolMonitor: Send + Sync {
    /// Unique identifier of this pool.
    fn id(&self) -> PoolID;

    /// Number of buffers currently available for acquisition.
    fn space_left(&self) -> usize;

    /// Total number of buffers managed by the pool.
    fn total_size(&self) -> usize;

    /// Size of a single buffer, in bytes.
    fn buffer_size(&self) -> usize;
}

// Global registry of live pools, keyed by pool id; populated by `register_pool`.
static POOL_REGISTRY: OnceLock<Mutex<HashMap<String, Arc<dyn PoolMonitor>>>> = OnceLock::new();
// Inline capacity for statistics snapshots (more pools spill to the heap).
const MAX_POOLS: usize = 16;
45
/// Locks `mutex`, recovering the inner data even when another thread panicked
/// while holding the lock; poisoning is deliberately ignored so pool state
/// stays usable after a panic elsewhere.
fn lock_unpoison<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    mutex
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}
52
53fn register_pool(pool: Arc<dyn PoolMonitor>) {
55 POOL_REGISTRY
56 .get_or_init(|| Mutex::new(HashMap::new()))
57 .lock()
58 .unwrap_or_else(|poison| poison.into_inner())
59 .insert(pool.id().to_string(), pool);
60}
61
62type PoolStats = (PoolID, usize, usize, usize);
63
64pub fn pools_statistics() -> SmallVec<[PoolStats; MAX_POOLS]> {
67 let registry_lock = match POOL_REGISTRY.get() {
69 Some(lock) => lock_unpoison(lock),
70 None => return SmallVec::new(), };
72 let mut result = SmallVec::with_capacity(MAX_POOLS);
73 for pool in registry_lock.values() {
74 result.push((
75 pool.id(),
76 pool.space_left(),
77 pool.total_size(),
78 pool.buffer_size(),
79 ));
80 }
81 result
82}
83
/// Trait for the element types a pool buffer may contain: plain, copyable
/// values with bincode-mirroring (de)serialization hooks.
pub trait ElementType: Default + Sized + Copy + Debug + Unpin + Send + Sync {
    /// Serializes one element with bincode.
    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError>;
    /// Deserializes one element with bincode.
    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError>;
}
89
90impl<T> ElementType for T
92where
93 T: Default + Sized + Copy + Debug + Unpin + Send + Sync,
94 T: Encode,
95 T: Decode<()>,
96{
97 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
98 self.encode(encoder)
99 }
100
101 fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
102 Self::decode(decoder)
103 }
104}
105
/// A buffer type usable by pools: behaves like a mutable slice of elements and
/// can cross thread boundaries.
pub trait ArrayLike: Deref<Target = [Self::Element]> + DerefMut + Debug + Sync + Send {
    /// Element type stored in the buffer.
    type Element: ElementType;
}

thread_local! {
    // Per-thread switch: when true, serializing a shared-memory handle emits
    // its descriptor (path/offset/len) instead of copying the element data.
    static SHARED_HANDLE_SERIALIZATION_ENABLED: Cell<bool> = const { Cell::new(false) };
}
113
/// RAII guard returned by `enable_shared_handle_serialization`; restores the
/// previous per-thread serialization mode when dropped.
pub struct SharedHandleSerializationGuard {
    // Mode that was active before the guard was created.
    previous: bool,
}
117
118impl Drop for SharedHandleSerializationGuard {
119 fn drop(&mut self) {
120 SHARED_HANDLE_SERIALIZATION_ENABLED.with(|enabled| enabled.set(self.previous));
121 }
122}
123
124pub fn enable_shared_handle_serialization() -> SharedHandleSerializationGuard {
125 let previous = SHARED_HANDLE_SERIALIZATION_ENABLED.with(|enabled| {
126 let previous = enabled.get();
127 enabled.set(true);
128 previous
129 });
130 SharedHandleSerializationGuard { previous }
131}
132
133fn shared_handle_serialization_enabled() -> bool {
134 SHARED_HANDLE_SERIALIZATION_ENABLED.with(Cell::get)
135}
136
/// Primitive element types that may cross process boundaries via a
/// shared-memory handle descriptor (serialized in snake_case).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CuSharedMemoryElementType {
    U8,
    U16,
    U32,
    U64,
    I8,
    I16,
    I32,
    I64,
    F32,
    F64,
}
151
152impl CuSharedMemoryElementType {
153 pub fn of<E: ElementType + 'static>() -> Option<Self> {
154 let type_id = core::any::TypeId::of::<E>();
155 if type_id == core::any::TypeId::of::<u8>() {
156 Some(Self::U8)
157 } else if type_id == core::any::TypeId::of::<u16>() {
158 Some(Self::U16)
159 } else if type_id == core::any::TypeId::of::<u32>() {
160 Some(Self::U32)
161 } else if type_id == core::any::TypeId::of::<u64>() {
162 Some(Self::U64)
163 } else if type_id == core::any::TypeId::of::<i8>() {
164 Some(Self::I8)
165 } else if type_id == core::any::TypeId::of::<i16>() {
166 Some(Self::I16)
167 } else if type_id == core::any::TypeId::of::<i32>() {
168 Some(Self::I32)
169 } else if type_id == core::any::TypeId::of::<i64>() {
170 Some(Self::I64)
171 } else if type_id == core::any::TypeId::of::<f32>() {
172 Some(Self::F32)
173 } else if type_id == core::any::TypeId::of::<f64>() {
174 Some(Self::F64)
175 } else {
176 None
177 }
178 }
179}
180
/// Wire representation of a shared-memory buffer: enough information for the
/// receiving side to re-map the same region instead of copying the data.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CuSharedMemoryHandleDescriptor {
    /// Sentinel field distinguishing a descriptor map from plain element data
    /// during self-describing deserialization.
    #[serde(rename = "__cu_shm_handle__")]
    pub marker: bool,
    /// Filesystem path of the backing shared-memory file.
    pub path: String,
    /// Byte offset of the buffer within the mapped region.
    pub offset_bytes: usize,
    /// Number of elements in the buffer.
    pub len_elements: usize,
    /// Element type tag used to validate the receiving side.
    pub element_type: CuSharedMemoryElementType,
}
190
191impl CuSharedMemoryHandleDescriptor {
192 fn new(
193 path: String,
194 offset_bytes: usize,
195 len_elements: usize,
196 element_type: CuSharedMemoryElementType,
197 ) -> Self {
198 Self {
199 marker: true,
200 path,
201 offset_bytes,
202 len_elements,
203 element_type,
204 }
205 }
206}
207
/// A memory-mapped file shared between buffers (and potentially processes).
///
/// The mapping sits in an `UnsafeCell` because multiple buffers hand out
/// slices into sub-ranges of the same mapping.
struct CuSharedMemoryRegion {
    // Path of the backing file; also the key in the region cache.
    path: PathBuf,
    // The writable mapping itself; aliased by every buffer in this region.
    mmap: UnsafeCell<MmapMut>,
    // Keeps the temp file (and its path) alive for regions we created;
    // `None` for regions opened from an existing path.
    _backing_file: Option<NamedTempFile>,
}
213
impl CuSharedMemoryRegion {
    /// Creates a region backed by a fresh temporary file of `byte_len` bytes,
    /// maps it read-write, and records it in the process-wide cache.
    fn create(byte_len: usize) -> CuResult<Arc<Self>> {
        let file = NamedTempFile::new()
            .map_err(|e| cu29_traits::CuError::new_with_cause("create shared memory file", e))?;
        file.as_file()
            .set_len(byte_len as u64)
            .map_err(|e| cu29_traits::CuError::new_with_cause("size shared memory file", e))?;
        // SAFETY-relevant: mapping a file we just created and sized; the mapping
        // stays valid because the NamedTempFile is kept alive in `_backing_file`.
        let mmap = unsafe {
            MmapOptions::new()
                .len(byte_len)
                .map_mut(file.as_file())
                .map_err(|e| cu29_traits::CuError::new_with_cause("map shared memory file", e))?
        };
        let region = Arc::new(Self {
            path: file.path().to_path_buf(),
            mmap: UnsafeCell::new(mmap),
            _backing_file: Some(file),
        });
        cache_shared_region(region.clone());
        Ok(region)
    }

    /// Opens (or fetches from cache) the region backing `path`, mapping the
    /// whole file read-write. The file must already exist.
    fn open(path: &Path) -> CuResult<Arc<Self>> {
        // Reuse any live mapping so all handles to the same path alias one region.
        if let Some(region) = cached_shared_region(path) {
            return Ok(region);
        }

        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .open(path)
            .map_err(|e| cu29_traits::CuError::new_with_cause("open shared memory file", e))?;
        let len = file
            .metadata()
            .map_err(|e| cu29_traits::CuError::new_with_cause("stat shared memory file", e))?
            .len() as usize;
        // SAFETY-relevant: the kernel keeps the mapping alive after `file` is
        // dropped; assumes no other process truncates the file while mapped.
        let mmap = unsafe {
            MmapOptions::new()
                .len(len)
                .map_mut(&file)
                .map_err(|e| cu29_traits::CuError::new_with_cause("map shared memory file", e))?
        };
        let region = Arc::new(Self {
            path: path.to_path_buf(),
            mmap: UnsafeCell::new(mmap),
            _backing_file: None,
        });
        cache_shared_region(region.clone());
        Ok(region)
    }
}
265
266impl Debug for CuSharedMemoryRegion {
267 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
268 f.debug_struct("CuSharedMemoryRegion")
269 .field("path", &self.path)
270 .finish_non_exhaustive()
271 }
272}
273
// SAFETY: the raw mapping is only reached through `CuSharedMemoryBuffer`
// accessors, and pooled buffers are handed out behind a `Mutex` in `CuHandle`.
// NOTE(review): the `UnsafeCell<MmapMut>` still permits aliased access from
// overlapping buffers — confirm slot ranges never overlap before relying on this.
unsafe impl Send for CuSharedMemoryRegion {}
// SAFETY: see the note on `Send` above; shared access only touches disjoint slots.
unsafe impl Sync for CuSharedMemoryRegion {}
282
283fn shared_region_cache() -> &'static Mutex<HashMap<PathBuf, std::sync::Weak<CuSharedMemoryRegion>>>
284{
285 static CACHE: OnceLock<Mutex<HashMap<PathBuf, std::sync::Weak<CuSharedMemoryRegion>>>> =
286 OnceLock::new();
287 CACHE.get_or_init(|| Mutex::new(HashMap::new()))
288}
289
290fn cache_shared_region(region: Arc<CuSharedMemoryRegion>) {
291 lock_unpoison(shared_region_cache()).insert(region.path.clone(), Arc::downgrade(®ion));
292}
293
294fn cached_shared_region(path: &Path) -> Option<Arc<CuSharedMemoryRegion>> {
295 lock_unpoison(shared_region_cache())
296 .get(path)
297 .and_then(std::sync::Weak::upgrade)
298}
299
300fn shared_slot_stride<E: ElementType>(len_elements: usize) -> usize {
301 let raw_bytes = len_elements
302 .checked_mul(size_of::<E>())
303 .expect("shared memory slot size overflow");
304 let alignment = align_of::<E>().max(1);
305 raw_bytes.div_ceil(alignment) * alignment
306}
307
/// A typed view over a sub-range of a shared-memory region.
///
/// Dereferences to `[E]`; the backing region is kept alive by the `Arc`.
#[derive(Debug)]
pub struct CuSharedMemoryBuffer<E: ElementType> {
    // Region backing this buffer (shared with other slots).
    region: Arc<CuSharedMemoryRegion>,
    // Byte offset of this buffer inside the region.
    offset_bytes: usize,
    // Number of `E` elements exposed.
    len_elements: usize,
    _marker: PhantomData<E>,
}
315
impl<E: ElementType + 'static> CuSharedMemoryBuffer<E> {
    // Internal constructor: the caller guarantees
    // `[offset_bytes, offset_bytes + len_elements * size_of::<E>())` lies
    // within `region`'s mapping.
    fn from_region(
        region: Arc<CuSharedMemoryRegion>,
        offset_bytes: usize,
        len_elements: usize,
    ) -> Self {
        Self {
            region,
            offset_bytes,
            len_elements,
            _marker: PhantomData,
        }
    }

    /// Creates a stand-alone (pool-less) shared-memory buffer holding a copy
    /// of `data`.
    pub fn from_vec_detached(data: Vec<E>) -> CuResult<Self> {
        let len_elements = data.len();
        // `max(1)` keeps the backing file non-empty even for an empty vec.
        let slot_stride = shared_slot_stride::<E>(len_elements.max(1));
        let region = CuSharedMemoryRegion::create(slot_stride)?;
        let mut buffer = Self::from_region(region, 0, len_elements);
        if !data.is_empty() {
            buffer.copy_from_slice(&data);
        }
        Ok(buffer)
    }

    /// Re-attaches to the buffer described by `descriptor` (typically received
    /// from another process), validating the element-type tag first.
    pub fn from_descriptor(descriptor: &CuSharedMemoryHandleDescriptor) -> CuResult<Self> {
        let expected = CuSharedMemoryElementType::of::<E>()
            .ok_or_else(|| cu29_traits::CuError::from("unsupported shared memory element type"))?;
        if descriptor.element_type != expected {
            return Err(cu29_traits::CuError::from(
                "shared memory descriptor element type mismatch",
            ));
        }
        // NOTE(review): offset/len from the descriptor are trusted to lie
        // within the mapped file — confirm bounds checking if input is untrusted.
        let region = CuSharedMemoryRegion::open(Path::new(&descriptor.path))?;
        Ok(Self::from_region(
            region,
            descriptor.offset_bytes,
            descriptor.len_elements,
        ))
    }

    /// Builds the wire descriptor for this buffer, or `None` when `E` is not
    /// a supported primitive element type.
    pub fn descriptor(&self) -> Option<CuSharedMemoryHandleDescriptor>
    where
        E: 'static,
    {
        CuSharedMemoryElementType::of::<E>().map(|element_type| {
            CuSharedMemoryHandleDescriptor::new(
                self.region.path.display().to_string(),
                self.offset_bytes,
                self.len_elements,
                element_type,
            )
        })
    }
}
371
impl<E: ElementType> Deref for CuSharedMemoryBuffer<E> {
    type Target = [E];

    fn deref(&self) -> &Self::Target {
        // SAFETY: the mapping outlives `self` (held via `Arc`) and the
        // offset/len were fixed at construction to lie within the mapping.
        // NOTE(review): alignment of `offset_bytes` for `E` comes from the
        // stride computation — confirm for descriptor-created buffers.
        let ptr = unsafe { (*self.region.mmap.get()).as_ptr().add(self.offset_bytes) as *const E };
        unsafe { std::slice::from_raw_parts(ptr, self.len_elements) }
    }
}
380
impl<E: ElementType> DerefMut for CuSharedMemoryBuffer<E> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: same bounds argument as `Deref`; `&mut self` gives exclusive
        // access to this buffer's slot (other slots are disjoint sub-ranges).
        let ptr = unsafe {
            (*self.region.mmap.get())
                .as_mut_ptr()
                .add(self.offset_bytes) as *mut E
        };
        unsafe { std::slice::from_raw_parts_mut(ptr, self.len_elements) }
    }
}
391
// A shared-memory buffer is a pool-compatible element container.
impl<E: ElementType> ArrayLike for CuSharedMemoryBuffer<E> {
    type Element = E;
}
395
396impl<E: ElementType> Encode for CuSharedMemoryBuffer<E> {
397 fn encode<Enc: Encoder>(&self, encoder: &mut Enc) -> Result<(), EncodeError> {
398 let len = self.len_elements as u64;
399 Encode::encode(&len, encoder)?;
400 for value in self.deref() {
401 value.encode(encoder)?;
402 }
403 Ok(())
404 }
405}
406
407impl<E: ElementType + 'static> Decode<()> for CuSharedMemoryBuffer<E> {
408 fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
409 let len = <u64 as Decode<()>>::decode(decoder)? as usize;
410 let mut vec = Vec::with_capacity(len);
411 for _ in 0..len {
412 vec.push(E::decode(decoder)?);
413 }
414 Self::from_vec_detached(vec).map_err(|e| DecodeError::OtherString(e.to_string()))
415 }
416}
417
/// Storage behind a `CuHandle`: either a buffer checked out of a pool
/// (returned on drop) or a free-standing buffer owned directly.
pub enum CuHandleInner<T: Debug> {
    /// Buffer borrowed from an `object_pool::Pool`; returns to it when dropped.
    Pooled(ReusableOwned<T>),
    /// Buffer owned outright, not tied to any pool.
    Detached(T),
}
425
426impl<T> Debug for CuHandleInner<T>
427where
428 T: Debug,
429{
430 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
431 match self {
432 CuHandleInner::Pooled(r) => {
433 write!(f, "Pooled: {:?}", r.deref())
434 }
435 CuHandleInner::Detached(r) => write!(f, "Detached: {r:?}"),
436 }
437 }
438}
439
440impl<T: ArrayLike> Deref for CuHandleInner<T> {
441 type Target = [T::Element];
442
443 fn deref(&self) -> &Self::Target {
444 match self {
445 CuHandleInner::Pooled(pooled) => pooled,
446 CuHandleInner::Detached(detached) => detached,
447 }
448 }
449}
450
451impl<T: ArrayLike> DerefMut for CuHandleInner<T> {
452 fn deref_mut(&mut self) -> &mut Self::Target {
453 match self {
454 CuHandleInner::Pooled(pooled) => pooled.deref_mut(),
455 CuHandleInner::Detached(detached) => detached,
456 }
457 }
458}
459
/// Shared, clonable handle to a pool buffer; all access is synchronized
/// through the inner `Mutex`.
#[derive(Debug)]
pub struct CuHandle<T: ArrayLike>(Arc<Mutex<CuHandleInner<T>>>);
463
464impl<T: ArrayLike> Clone for CuHandle<T> {
465 fn clone(&self) -> Self {
466 Self(self.0.clone())
467 }
468}
469
impl<T: ArrayLike> Deref for CuHandle<T> {
    type Target = Arc<Mutex<CuHandleInner<T>>>;

    // Exposes the synchronized interior directly; callers lock it themselves
    // (see `with_inner`/`with_inner_mut` for poison-tolerant access).
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
477
478impl<T: ArrayLike> CuHandle<T> {
479 pub fn new_detached(inner: T) -> Self {
481 CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(inner))))
482 }
483
484 pub fn with_inner<R>(&self, f: impl FnOnce(&CuHandleInner<T>) -> R) -> R {
486 let lock = lock_unpoison(&self.0);
487 f(&*lock)
488 }
489
490 pub fn with_inner_mut<R>(&self, f: impl FnOnce(&mut CuHandleInner<T>) -> R) -> R {
492 let mut lock = lock_unpoison(&self.0);
493 f(&mut *lock)
494 }
495}
496
497impl<U> Serialize for CuHandle<Vec<U>>
498where
499 U: ElementType + Serialize + 'static,
500{
501 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
502 let inner = lock_unpoison(&self.0);
503 match inner.deref() {
504 CuHandleInner::Pooled(pooled) => pooled.deref().serialize(serializer),
505 CuHandleInner::Detached(detached) => detached.serialize(serializer),
506 }
507 }
508}
509
510impl<'de, U> Deserialize<'de> for CuHandle<Vec<U>>
511where
512 U: ElementType + Deserialize<'de> + 'static,
513{
514 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
515 Vec::<U>::deserialize(deserializer).map(CuHandle::new_detached)
516 }
517}
518
519impl<U> Serialize for CuHandle<CuSharedMemoryBuffer<U>>
520where
521 U: ElementType + Serialize + 'static,
522{
523 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
524 let inner = lock_unpoison(&self.0);
525 let buffer = match inner.deref() {
526 CuHandleInner::Pooled(pooled) => pooled.deref(),
527 CuHandleInner::Detached(detached) => detached,
528 };
529
530 if shared_handle_serialization_enabled()
531 && let Some(descriptor) = buffer.descriptor()
532 {
533 return descriptor.serialize(serializer);
534 }
535
536 buffer.deref().serialize(serializer)
537 }
538}
539
impl<'de, U> Deserialize<'de> for CuHandle<CuSharedMemoryBuffer<U>>
where
    U: ElementType + Deserialize<'de> + 'static,
{
    /// Accepts either wire form produced by the matching `Serialize` impl:
    /// a descriptor map (re-attaches to the shared region) or a plain element
    /// sequence (copies into a fresh detached buffer).
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        // Untagged-enum-like intermediate: which wire form did we get?
        enum Repr<U> {
            Descriptor(CuSharedMemoryHandleDescriptor),
            Data(Vec<U>),
        }

        impl<'de, U> Deserialize<'de> for Repr<U>
        where
            U: ElementType + Deserialize<'de>,
        {
            fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
                struct ReprVisitor<U>(PhantomData<U>);

                impl<'de, U> Visitor<'de> for ReprVisitor<U>
                where
                    U: ElementType + Deserialize<'de>,
                {
                    type Value = Repr<U>;

                    fn expecting(
                        &self,
                        formatter: &mut std::fmt::Formatter<'_>,
                    ) -> std::fmt::Result {
                        formatter
                            .write_str("a shared-memory handle descriptor or an element sequence")
                    }

                    // A sequence means raw element data.
                    fn visit_seq<A: SeqAccess<'de>>(self, seq: A) -> Result<Self::Value, A::Error> {
                        let data =
                            Vec::<U>::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
                        Ok(Repr::Data(data))
                    }

                    // A map means a handle descriptor (`__cu_shm_handle__` etc.).
                    fn visit_map<A: MapAccess<'de>>(self, map: A) -> Result<Self::Value, A::Error> {
                        let descriptor = CuSharedMemoryHandleDescriptor::deserialize(
                            de::value::MapAccessDeserializer::new(map),
                        )?;
                        Ok(Repr::Descriptor(descriptor))
                    }
                }

                // Requires a self-describing format (e.g. JSON): `deserialize_any`
                // lets the input's own shape pick the branch.
                deserializer.deserialize_any(ReprVisitor(PhantomData))
            }
        }

        match Repr::<U>::deserialize(deserializer)? {
            Repr::Descriptor(descriptor) => CuSharedMemoryBuffer::from_descriptor(&descriptor)
                .map(CuHandle::new_detached)
                .map_err(de::Error::custom),
            Repr::Data(data) => CuSharedMemoryBuffer::from_vec_detached(data)
                .map(CuHandle::new_detached)
                .map_err(de::Error::custom),
        }
    }
}
599
600impl<T: ArrayLike + Encode> Encode for CuHandle<T>
601where
602 <T as ArrayLike>::Element: 'static,
603{
604 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
605 let inner = lock_unpoison(&self.0);
606 crate::monitoring::record_payload_handle_bytes(
607 inner.deref().len() * size_of::<T::Element>(),
608 );
609 match inner.deref() {
610 CuHandleInner::Pooled(pooled) => pooled.deref().encode(encoder),
611 CuHandleInner::Detached(detached) => detached.encode(encoder),
612 }
613 }
614}
615
impl<T: ArrayLike> Default for CuHandle<T> {
    /// Always panics: a handle must wrap an actual buffer (pooled or detached),
    /// so there is no meaningful default. Provided only to satisfy `Default`
    /// bounds; never call it.
    fn default() -> Self {
        panic!("Cannot create a default CuHandle")
    }
}
621
622impl<U: ElementType + Decode<()> + 'static> Decode<()> for CuHandle<Vec<U>> {
623 fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
624 let vec: Vec<U> = Vec::decode(decoder)?;
625 Ok(CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(vec)))))
626 }
627}
628
629impl<U: ElementType + Decode<()> + 'static> Decode<()> for CuHandle<CuSharedMemoryBuffer<U>> {
630 fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
631 let buffer = CuSharedMemoryBuffer::<U>::decode(decoder)?;
632 Ok(CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(
633 buffer,
634 )))))
635 }
636}
637
/// A pool of pre-allocated buffers of type `T`.
pub trait CuPool<T: ArrayLike>: PoolMonitor {
    /// Checks out a buffer, or `None` when the pool is exhausted.
    fn acquire(&self) -> Option<CuHandle<T>>;

    /// Copies the contents of `from` into a freshly acquired buffer.
    ///
    /// NOTE(review): the implementations in this file panic when the pool is
    /// exhausted — confirm callers can tolerate that.
    fn copy_from<O>(&self, from: &mut CuHandle<O>) -> CuHandle<T>
    where
        O: ArrayLike<Element = T::Element>;
}

/// A device-side (e.g. GPU) pool that can transfer data back to host memory.
pub trait DeviceCuPool<T: ArrayLike>: CuPool<T> {
    /// Copies `from_device_handle`'s contents into `to_host_handle`.
    fn copy_to_host_pool<O>(
        &self,
        from_device_handle: &CuHandle<T>,
        to_host_handle: &mut CuHandle<O>,
    ) -> CuResult<()>
    where
        O: ArrayLike<Element = T::Element>;
}
662
/// A host-RAM pool of reusable buffers.
pub struct CuHostMemoryPool<T> {
    // Registry identifier.
    id: PoolID,
    // Underlying object pool holding the buffers.
    pool: Arc<Pool<T>>,
    // Total number of buffers managed.
    size: usize,
    // Size of one buffer, in bytes (measured at construction).
    buffer_size: usize,
}
672
673impl<T: ArrayLike + 'static> CuHostMemoryPool<T> {
674 pub fn new<F>(id: &str, size: usize, buffer_initializer: F) -> CuResult<Arc<Self>>
675 where
676 F: Fn() -> T,
677 {
678 let pool = Arc::new(Pool::new(size, buffer_initializer));
679 let buffer_size = pool.try_pull().unwrap().len() * size_of::<T::Element>();
680
681 let og = Self {
682 id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
683 pool,
684 size,
685 buffer_size,
686 };
687 let og = Arc::new(og);
688 register_pool(og.clone());
689 Ok(og)
690 }
691}
692
impl<T: ArrayLike> PoolMonitor for CuHostMemoryPool<T> {
    fn id(&self) -> PoolID {
        self.id
    }

    // Buffers currently available for acquisition.
    fn space_left(&self) -> usize {
        self.pool.len()
    }

    fn total_size(&self) -> usize {
        self.size
    }

    // Bytes per buffer, measured once at construction.
    fn buffer_size(&self) -> usize {
        self.buffer_size
    }
}
710
711impl<T: ArrayLike> CuPool<T> for CuHostMemoryPool<T> {
712 fn acquire(&self) -> Option<CuHandle<T>> {
713 let owned_object = self.pool.try_pull_owned(); owned_object.map(|reusable| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(reusable)))))
716 }
717
718 fn copy_from<O: ArrayLike<Element = T::Element>>(&self, from: &mut CuHandle<O>) -> CuHandle<T> {
719 let to_handle = self.acquire().expect("No available buffers in the pool");
720
721 match lock_unpoison(&from.0).deref() {
722 CuHandleInner::Detached(source) => match lock_unpoison(&to_handle.0).deref_mut() {
723 CuHandleInner::Detached(destination) => {
724 destination.copy_from_slice(source);
725 }
726 CuHandleInner::Pooled(destination) => {
727 destination.copy_from_slice(source);
728 }
729 },
730 CuHandleInner::Pooled(source) => match lock_unpoison(&to_handle.0).deref_mut() {
731 CuHandleInner::Detached(destination) => {
732 destination.copy_from_slice(source);
733 }
734 CuHandleInner::Pooled(destination) => {
735 destination.copy_from_slice(source);
736 }
737 },
738 }
739 to_handle
740 }
741}
742
/// A pool whose buffers all live in one memory-mapped shared region, so they
/// can be handed to other processes by descriptor instead of by copy.
pub struct CuSharedMemoryPool<E: ElementType> {
    // Registry identifier.
    id: PoolID,
    // Object pool of slot-backed shared-memory buffers.
    pool: Arc<Pool<CuSharedMemoryBuffer<E>>>,
    // Total number of buffers.
    size: usize,
    // Bytes of payload per buffer.
    buffer_size: usize,
}
751
impl<E: ElementType + 'static> CuSharedMemoryPool<E> {
    /// Creates a registered pool of `size` buffers, each holding
    /// `elements_per_buffer` elements, all carved out of a single
    /// shared-memory region.
    pub fn new(id: &str, size: usize, elements_per_buffer: usize) -> CuResult<Arc<Self>> {
        // `max(1)` keeps slot strides (and the backing file) non-empty.
        let slot_stride = shared_slot_stride::<E>(elements_per_buffer.max(1));
        let region = CuSharedMemoryRegion::create(
            slot_stride
                .checked_mul(size)
                .ok_or_else(|| cu29_traits::CuError::from("shared memory pool size overflow"))?,
        )?;
        // Each initializer call claims the next disjoint slot; the object pool
        // calls it exactly `size` times, which the assert below enforces.
        let next_slot = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let initializer_region = region.clone();
        let initializer_next_slot = next_slot.clone();
        let pool = Arc::new(Pool::new(size, move || {
            let slot = initializer_next_slot.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            assert!(slot < size, "shared memory pool slot index overflow");
            CuSharedMemoryBuffer::from_region(
                initializer_region.clone(),
                slot * slot_stride,
                elements_per_buffer,
            )
        }));

        let pool = Arc::new(Self {
            id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
            pool,
            size,
            buffer_size: elements_per_buffer * size_of::<E>(),
        });
        register_pool(pool.clone());
        Ok(pool)
    }
}
783
impl<E: ElementType> PoolMonitor for CuSharedMemoryPool<E> {
    fn id(&self) -> PoolID {
        self.id
    }

    // Buffers currently available for acquisition.
    fn space_left(&self) -> usize {
        self.pool.len()
    }

    fn total_size(&self) -> usize {
        self.size
    }

    // Bytes of payload per buffer.
    fn buffer_size(&self) -> usize {
        self.buffer_size
    }
}
801
802impl<E: ElementType> CuPool<CuSharedMemoryBuffer<E>> for CuSharedMemoryPool<E> {
803 fn acquire(&self) -> Option<CuHandle<CuSharedMemoryBuffer<E>>> {
804 self.pool
805 .try_pull_owned()
806 .map(|reusable| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(reusable)))))
807 }
808
809 fn copy_from<O>(&self, from: &mut CuHandle<O>) -> CuHandle<CuSharedMemoryBuffer<E>>
810 where
811 O: ArrayLike<Element = E>,
812 {
813 let to_handle = self.acquire().expect("No available buffers in the pool");
814
815 match lock_unpoison(&from.0).deref() {
816 CuHandleInner::Detached(source) => match lock_unpoison(&to_handle.0).deref_mut() {
817 CuHandleInner::Detached(destination) => {
818 destination.copy_from_slice(source);
819 }
820 CuHandleInner::Pooled(destination) => {
821 destination.copy_from_slice(source);
822 }
823 },
824 CuHandleInner::Pooled(source) => match lock_unpoison(&to_handle.0).deref_mut() {
825 CuHandleInner::Detached(destination) => {
826 destination.copy_from_slice(source);
827 }
828 CuHandleInner::Pooled(destination) => {
829 destination.copy_from_slice(source);
830 }
831 },
832 }
833 to_handle
834 }
835}
836
// Plain heap vectors are valid pool buffers too (used by the host memory pool).
impl<E: ElementType + 'static> ArrayLike for Vec<E> {
    type Element = E;
}
840
841#[cfg(all(feature = "cuda", not(target_os = "macos")))]
842mod cuda {
843 use super::*;
844 use cu29_traits::CuError;
845 use cudarc::driver::{
846 CudaContext, CudaSlice, CudaStream, DeviceRepr, HostSlice, SyncOnDrop, ValidAsZeroBits,
847 };
848 use std::sync::Arc;
849
    /// Newtype over a device-side `CudaSlice`; device memory is not directly
    /// addressable from the host, hence the panicking `Deref` impls below.
    #[derive(Debug)]
    pub struct CudaSliceWrapper<E>(CudaSlice<E>);

    impl<E> Deref for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        type Target = [E];

        /// Always panics: device memory must be copied to a host pool first.
        fn deref(&self) -> &Self::Target {
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }

    impl<E> DerefMut for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        /// Always panics: device memory must be copied to a host pool first.
        fn deref_mut(&mut self) -> &mut Self::Target {
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }

    // Device buffers participate in the pool machinery like any other buffer.
    impl<E: ElementType> ArrayLike for CudaSliceWrapper<E> {
        type Element = E;
    }

    impl<E> CudaSliceWrapper<E> {
        /// Borrows the underlying device slice.
        pub fn as_cuda_slice(&self) -> &CudaSlice<E> {
            &self.0
        }

        /// Mutably borrows the underlying device slice.
        pub fn as_cuda_slice_mut(&mut self) -> &mut CudaSlice<E> {
            &mut self.0
        }
    }
887
    /// Immutable adapter letting any `ArrayLike` act as a cudarc `HostSlice`
    /// source for host-to-device copies.
    pub struct HostSliceWrapper<'a, T: ArrayLike> {
        inner: &'a T,
    }

    impl<T: ArrayLike> HostSlice<T::Element> for HostSliceWrapper<'_, T> {
        fn len(&self) -> usize {
            self.inner.len()
        }

        unsafe fn stream_synced_slice<'b>(
            &'b self,
            stream: &'b CudaStream,
        ) -> (&'b [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref(), SyncOnDrop::sync_stream(stream))
        }

        /// Always panics: this wrapper only holds a shared reference.
        unsafe fn stream_synced_mut_slice<'b>(
            &'b mut self,
            _stream: &'b CudaStream,
        ) -> (&'b mut [T::Element], SyncOnDrop<'b>) {
            panic!("Cannot get mutable reference from immutable wrapper")
        }
    }

    /// Mutable adapter letting any `ArrayLike` act as a cudarc `HostSlice`
    /// destination for device-to-host copies.
    pub struct HostSliceMutWrapper<'a, T: ArrayLike> {
        inner: &'a mut T,
    }

    impl<T: ArrayLike> HostSlice<T::Element> for HostSliceMutWrapper<'_, T> {
        fn len(&self) -> usize {
            self.inner.len()
        }

        unsafe fn stream_synced_slice<'b>(
            &'b self,
            stream: &'b CudaStream,
        ) -> (&'b [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref(), SyncOnDrop::sync_stream(stream))
        }

        unsafe fn stream_synced_mut_slice<'b>(
            &'b mut self,
            stream: &'b CudaStream,
        ) -> (&'b mut [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref_mut(), SyncOnDrop::sync_stream(stream))
        }
    }
941
    impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuCudaPool<E> {
        // Wraps either handle variant as an immutable host-slice source.
        fn get_host_slice_wrapper<O: ArrayLike<Element = E>>(
            handle_inner: &CuHandleInner<O>,
        ) -> HostSliceWrapper<'_, O> {
            match handle_inner {
                CuHandleInner::Pooled(pooled) => HostSliceWrapper { inner: pooled },
                CuHandleInner::Detached(detached) => HostSliceWrapper { inner: detached },
            }
        }

        // Wraps either handle variant as a mutable host-slice destination.
        fn get_host_slice_mut_wrapper<O: ArrayLike<Element = E>>(
            handle_inner: &mut CuHandleInner<O>,
        ) -> HostSliceMutWrapper<'_, O> {
            match handle_inner {
                CuHandleInner::Pooled(pooled) => HostSliceMutWrapper { inner: pooled },
                CuHandleInner::Detached(detached) => HostSliceMutWrapper { inner: detached },
            }
        }
    }
    /// A pool of pre-allocated CUDA device buffers tied to one stream.
    pub struct CuCudaPool<E>
    where
        E: ElementType + ValidAsZeroBits + DeviceRepr + Unpin,
    {
        // Registry identifier.
        id: PoolID,
        // Stream used for all copies issued by this pool.
        stream: Arc<CudaStream>,
        // Underlying object pool of device slices.
        pool: Arc<Pool<CudaSliceWrapper<E>>>,
        // Total number of buffers.
        nb_buffers: usize,
        // Elements per buffer.
        nb_element_per_buffer: usize,
    }
975
    impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuCudaPool<E> {
        /// Creates a pool of `nb_buffers` zero-initialized device buffers of
        /// `nb_element_per_buffer` elements each, on `ctx`'s default stream.
        #[allow(dead_code)]
        pub fn new(
            id: &'static str,
            ctx: Arc<CudaContext>,
            nb_buffers: usize,
            nb_element_per_buffer: usize,
        ) -> CuResult<Self> {
            let stream = ctx.default_stream();
            let pool = (0..nb_buffers)
                .map(|_| {
                    stream
                        .alloc_zeros(nb_element_per_buffer)
                        .map(CudaSliceWrapper)
                        .map_err(|_| "Failed to allocate device memory")
                })
                .collect::<Result<Vec<_>, _>>()?;

            Ok(Self {
                id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
                stream,
                pool: Arc::new(Pool::from_vec(pool)),
                nb_buffers,
                nb_element_per_buffer,
            })
        }
    }

    impl<E> PoolMonitor for CuCudaPool<E>
    where
        E: DeviceRepr + ElementType + ValidAsZeroBits,
    {
        fn id(&self) -> PoolID {
            self.id
        }

        // Device buffers currently available.
        fn space_left(&self) -> usize {
            self.pool.len()
        }

        fn total_size(&self) -> usize {
            self.nb_buffers
        }

        // Bytes per device buffer.
        fn buffer_size(&self) -> usize {
            self.nb_element_per_buffer * size_of::<E>()
        }
    }
1024
    impl<E> CuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
    where
        E: DeviceRepr + ElementType + ValidAsZeroBits,
    {
        fn acquire(&self) -> Option<CuHandle<CudaSliceWrapper<E>>> {
            self.pool
                .try_pull_owned()
                .map(|x| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(x)))))
        }

        /// Copies a host-side buffer into a freshly acquired device buffer
        /// on this pool's stream.
        fn copy_from<O>(&self, from_handle: &mut CuHandle<O>) -> CuHandle<CudaSliceWrapper<E>>
        where
            O: ArrayLike<Element = E>,
        {
            let to_handle = self.acquire().expect("No available buffers in the pool");

            {
                let from_lock = lock_unpoison(&from_handle.0);
                let mut to_lock = lock_unpoison(&to_handle.0);

                match &mut *to_lock {
                    CuHandleInner::Detached(CudaSliceWrapper(to)) => {
                        let wrapper = Self::get_host_slice_wrapper(&*from_lock);
                        self.stream
                            .memcpy_htod(&wrapper, to)
                            .expect("Failed to copy data to device");
                    }
                    CuHandleInner::Pooled(to) => {
                        let wrapper = Self::get_host_slice_wrapper(&*from_lock);
                        self.stream
                            .memcpy_htod(&wrapper, to.as_cuda_slice_mut())
                            .expect("Failed to copy data to device");
                    }
                }
            }
            to_handle
        }
    }

    impl<E> DeviceCuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
    where
        E: ElementType + ValidAsZeroBits + DeviceRepr,
    {
        /// Copies `device_handle`'s contents into `host_handle`.
        ///
        /// Unlike the rest of the file, lock poisoning is reported as an error
        /// here rather than recovered. NOTE(review): confirm this asymmetry
        /// with `lock_unpoison` elsewhere is intentional.
        fn copy_to_host_pool<O>(
            &self,
            device_handle: &CuHandle<CudaSliceWrapper<E>>,
            host_handle: &mut CuHandle<O>,
        ) -> Result<(), CuError>
        where
            O: ArrayLike<Element = E>,
        {
            let device_lock = device_handle.lock().map_err(|e| {
                CuError::from("Device handle mutex poisoned").add_cause(&e.to_string())
            })?;
            let mut host_lock = host_handle.lock().map_err(|e| {
                CuError::from("Host handle mutex poisoned").add_cause(&e.to_string())
            })?;
            let src = match &*device_lock {
                CuHandleInner::Pooled(source) => source.as_cuda_slice(),
                CuHandleInner::Detached(source) => source.as_cuda_slice(),
            };
            let mut wrapper = Self::get_host_slice_mut_wrapper(&mut *host_lock);
            self.stream.memcpy_dtoh(src, &mut wrapper).map_err(|e| {
                CuError::from("Failed to copy data from device to host").add_cause(&e.to_string())
            })?;
            Ok(())
        }
    }
1094}
1095
/// Heap buffer of `E` with a caller-chosen (over-)alignment, e.g. for DMA or
/// SIMD requirements. Holds a raw pointer, so it is not `Send`/`Sync`.
#[derive(Debug)]
pub struct AlignedBuffer<E: ElementType> {
    // Start of the allocation; valid for `size` initialized elements.
    ptr: *mut E,
    // Number of elements.
    size: usize,
    // Layout used for allocation; needed again for deallocation in `Drop`.
    layout: Layout,
}
1103
1104impl<E: ElementType> AlignedBuffer<E> {
1105 pub fn new(num_elements: usize, alignment: usize) -> Self {
1106 assert!(
1107 num_elements > 0 && size_of::<E>() > 0,
1108 "AlignedBuffer requires a non-zero element count and non-zero-sized element type"
1109 );
1110 let alignment = alignment.max(align_of::<E>());
1111 let alloc_size = num_elements
1112 .checked_mul(size_of::<E>())
1113 .expect("AlignedBuffer allocation size overflow");
1114 let layout = Layout::from_size_align(alloc_size, alignment).unwrap();
1115 let ptr = unsafe { alloc(layout) as *mut E };
1117 if ptr.is_null() {
1118 panic!("Failed to allocate memory");
1119 }
1120 unsafe {
1122 for i in 0..num_elements {
1123 std::ptr::write(ptr.add(i), E::default());
1124 }
1125 }
1126 Self {
1127 ptr,
1128 size: num_elements,
1129 layout,
1130 }
1131 }
1132}
1133
impl<E: ElementType> Deref for AlignedBuffer<E> {
    type Target = [E];

    fn deref(&self) -> &Self::Target {
        // SAFETY: `ptr` points to `size` initialized elements allocated in `new`.
        unsafe { std::slice::from_raw_parts(self.ptr, self.size) }
    }
}

impl<E: ElementType> DerefMut for AlignedBuffer<E> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: as in `Deref`; `&mut self` guarantees exclusive access.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.size) }
    }
}

impl<E: ElementType> Drop for AlignedBuffer<E> {
    fn drop(&mut self) {
        // SAFETY: `ptr`/`layout` come unchanged from `new`. `E: Copy` (via
        // `ElementType`) implies elements need no destructors before freeing.
        unsafe { dealloc(self.ptr as *mut u8, self.layout) }
    }
}
1156
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool() {
        use std::cell::RefCell;
        // Hand the pool three pre-built buffers via a RefCell'd stack so the
        // initializer closure can pop one per call.
        let objs = RefCell::new(vec![vec![1], vec![2], vec![3]]);
        let holding = objs.borrow().clone();
        let objs_as_slices = holding.iter().map(|x| x.as_slice()).collect::<Vec<_>>();
        let pool = CuHostMemoryPool::new("mytestcudapool", 3, || objs.borrow_mut().pop().unwrap())
            .unwrap();

        let obj1 = pool.acquire().unwrap();
        {
            let obj2 = pool.acquire().unwrap();
            assert!(objs_as_slices.contains(&obj1.lock().unwrap().deref().deref()));
            assert!(objs_as_slices.contains(&obj2.lock().unwrap().deref().deref()));
            assert_eq!(pool.space_left(), 1);
        }
        // obj2 dropped above: its buffer returns to the pool.
        assert_eq!(pool.space_left(), 2);

        let obj3 = pool.acquire().unwrap();
        assert!(objs_as_slices.contains(&obj3.lock().unwrap().deref().deref()));

        assert_eq!(pool.space_left(), 1);

        let _obj4 = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        // Exhausted pool: acquire must fail rather than block or allocate.
        let obj5 = pool.acquire();
        assert!(obj5.is_none());
    }

    // Only runs on machines with an NVIDIA GPU and the cuda feature enabled.
    #[cfg(all(feature = "cuda", has_nvidia_gpu))]
    #[test]
    fn test_cuda_pool() {
        use crate::pool::cuda::CuCudaPool;
        use cudarc::driver::CudaContext;
        let ctx = CudaContext::new(0).unwrap();
        let pool = CuCudaPool::<f32>::new("mytestcudapool", ctx, 3, 1).unwrap();

        let _obj1 = pool.acquire().unwrap();

        {
            let _obj2 = pool.acquire().unwrap();
            assert_eq!(pool.space_left(), 1);
        }
        assert_eq!(pool.space_left(), 2);

        let _obj3 = pool.acquire().unwrap();

        assert_eq!(pool.space_left(), 1);

        let _obj4 = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        let obj5 = pool.acquire();
        assert!(obj5.is_none());
    }

    // Host -> device -> host roundtrip through both pools.
    #[cfg(all(feature = "cuda", has_nvidia_gpu))]
    #[test]
    fn test_copy_roundtrip() {
        use crate::pool::cuda::CuCudaPool;
        use cudarc::driver::CudaContext;
        let ctx = CudaContext::new(0).unwrap();
        let host_pool = CuHostMemoryPool::new("mytesthostpool", 3, || vec![0.0; 1]).unwrap();
        let cuda_pool = CuCudaPool::<f32>::new("mytestcudapool", ctx, 3, 1).unwrap();

        let cuda_handle = {
            let mut initial_handle = host_pool.acquire().unwrap();
            {
                let mut inner_initial_handle = initial_handle.lock().unwrap();
                if let CuHandleInner::Pooled(ref mut pooled) = *inner_initial_handle {
                    pooled[0] = 42.0;
                } else {
                    panic!();
                }
            }

            cuda_pool.copy_from(&mut initial_handle)
        };

        let mut final_handle = host_pool.acquire().unwrap();
        cuda_pool
            .copy_to_host_pool(&cuda_handle, &mut final_handle)
            .unwrap();

        let value = final_handle.lock().unwrap().deref().deref()[0];
        assert_eq!(value, 42.0);
    }
}