1use std::fmt;
2use libc::c_void;
3
4use crate::ffi::{
5 clCreateBuffer, clGetMemObjectInfo, cl_context, cl_int, cl_mem, cl_mem_flags, cl_mem_info,
6};
7
8use crate::cl_helpers::cl_get_info5;
9use crate::{
10 build_output, ClContext, ClNumber, ClPointer, ContextPtr, HostAccessMemFlags,
11 KernelAccessMemFlags, MemFlags, MemInfo, MemLocationMemFlags, Output, NumberType,
12 NumberTyped, ObjectWrapper,
13};
14
15
16pub unsafe fn cl_create_buffer_with_creator<T: ClNumber, B: BufferCreator<T>>(
21 context: cl_context,
22 mem_flags: cl_mem_flags,
23 buffer_creator: B,
24) -> Output<cl_mem> {
25 cl_create_buffer(
26 context,
27 mem_flags,
28 buffer_creator.buffer_byte_size(),
29 buffer_creator.buffer_ptr()
30 )
31}
32
33pub unsafe fn cl_create_buffer(
39 context: cl_context,
40 mem_flags: cl_mem_flags,
41 size_in_bytes: usize,
42 ptr: *mut c_void,
43) -> Output<cl_mem> {
44 let mut err_code: cl_int = 0;
45 let cl_mem_object: cl_mem =
46 clCreateBuffer(context, mem_flags, size_in_bytes, ptr, &mut err_code);
47 build_output(cl_mem_object, err_code)
48}
49
50pub fn cl_get_mem_object_info<T>(device_mem: cl_mem, flag: cl_mem_info) -> Output<ClPointer<T>>
51where
52 T: Copy,
53{
54 unsafe { cl_get_info5(device_mem, flag, clGetMemObjectInfo) }
55}
56
57pub trait BufferCreator<T: ClNumber>: Sized {
58 fn buffer_byte_size(&self) -> usize;
63 fn buffer_ptr(&self) -> *mut c_void;
64 fn mem_config(&self) -> MemConfig;
65}
66
67impl<T: ClNumber> BufferCreator<T> for &[T] {
68 fn buffer_byte_size(&self) -> usize {
69 std::mem::size_of::<T>() * self.len()
70 }
71
72 fn buffer_ptr(&self) -> *mut c_void {
73 self.as_ptr() as *const _ as *mut c_void
74
75 }
76
77 fn mem_config(&self) -> MemConfig {
78 MemConfig::for_data()
79 }
80}
81
82impl<T: ClNumber> BufferCreator<T> for &mut [T] {
83 fn buffer_byte_size(&self) -> usize {
84 std::mem::size_of::<T>() * self.len()
85 }
86
87 fn buffer_ptr(&self) -> *mut c_void {
88 self.as_ptr() as *const _ as *mut c_void
89
90 }
91
92 fn mem_config(&self) -> MemConfig {
93 MemConfig::for_data()
94 }
95}
96
97
98impl<T: ClNumber> BufferCreator<T> for usize {
99 fn buffer_byte_size(&self) -> usize {
100 std::mem::size_of::<T>() * *self
101 }
102
103 fn buffer_ptr(&self) -> *mut c_void {
104 std::ptr::null_mut()
105 }
106
107 fn mem_config(&self) -> MemConfig {
108 MemConfig::for_size()
109 }
110}
111
112pub unsafe trait MemPtr: NumberTyped {
118 unsafe fn mem_ptr(&self) -> cl_mem;
124
125 unsafe fn mem_ptr_ref(&self) -> &cl_mem;
131
132 unsafe fn get_info<I: Copy>(&self, flag: MemInfo) -> Output<ClPointer<I>> {
138 cl_get_mem_object_info::<I>(self.mem_ptr(), flag.into())
139 }
140
141 unsafe fn len(&self) -> Output<usize> {
146 let mem_size_in_bytes = self.size()?;
147 Ok(mem_size_in_bytes / self.number_type().size_of_t())
148 }
149
150 unsafe fn is_empty(&self) -> Output<bool> {
155 self.len().map(|l| l == 0)
156 }
157
158 unsafe fn context(&self) -> Output<ClContext> {
205 self.get_info::<cl_context>(MemInfo::Context)
206 .and_then(|cl_ptr| ClContext::retain_new(cl_ptr.into_one()))
207 }
208
209 unsafe fn reference_count(&self) -> Output<u32> {
214 self.get_info(MemInfo::ReferenceCount)
215 .map(|ret| ret.into_one())
216 }
217
218 unsafe fn size(&self) -> Output<usize> {
223 self.get_info(MemInfo::Size).map(|ret| ret.into_one())
224 }
225
226 unsafe fn offset(&self) -> Output<usize> {
231 self.get_info(MemInfo::Offset).map(|ret| ret.into_one())
232 }
233
234 unsafe fn flags(&self) -> Output<MemFlags> {
239 self.get_info(MemInfo::Flags).map(|ret| ret.into_one())
240 }
241
242 }
247
248#[derive(Eq, PartialEq)]
249pub struct ClMem {
250 inner: ObjectWrapper<cl_mem>,
251 t: NumberType,
252}
253
254impl NumberTyped for ClMem {
255 fn number_type(&self) -> NumberType {
256 self.t
257 }
258}
259
260impl ClMem {
261 pub unsafe fn new<T: ClNumber>(object: cl_mem) -> Output<ClMem> {
268 Ok(ClMem {
269 inner: ObjectWrapper::new(object)?,
270 t: T::number_type()
271 })
272 }
273
274 pub fn create<T: ClNumber, B: BufferCreator<T>>(
275 context: &ClContext,
276 buffer_creator: B,
277 host_access: HostAccess,
278 kernel_access: KernelAccess,
279 mem_location: MemLocation,
280 ) -> Output<ClMem> {
281 unsafe {
282 let mem_object = cl_create_buffer_with_creator(
283 context.context_ptr(),
284 cl_mem_flags::from(host_access)
285 | cl_mem_flags::from(kernel_access)
286 | cl_mem_flags::from(mem_location),
287 buffer_creator,
288 )?;
289 ClMem::new::<T>(mem_object)
290 }
291 }
292
293 pub unsafe fn create_with_config<T: ClNumber, B: BufferCreator<T>>(
301 context: &ClContext,
302 buffer_creator: B,
303 mem_config: MemConfig,
304 ) -> Output<ClMem> {
305 let mem_object = cl_create_buffer_with_creator(
306 context.context_ptr(),
307 mem_config.into(),
308 buffer_creator,
309 )?;
310 ClMem::new::<T>(mem_object)
311 }
312}
313
314unsafe impl MemPtr for ClMem {
315 unsafe fn mem_ptr(&self) -> cl_mem {
316 self.inner.cl_object()
317 }
318
319 unsafe fn mem_ptr_ref(&self) -> &cl_mem {
320 self.inner.cl_object_ref()
321 }
322}
323
// SAFETY: `ClMem` holds only an `ObjectWrapper<cl_mem>` and a `NumberType`.
// NOTE(review): sending it across threads relies on OpenCL memory-object API
// calls being thread-safe. The OpenCL spec's thread-safety appendix states
// that all API calls are thread-safe except a small set of kernel-argument
// setters. Confirm this holds for the targeted OpenCL version.
unsafe impl Send for ClMem {}
325
326impl fmt::Debug for ClMem {
327 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328 write!(f, "{:?}", unsafe { self.mem_ptr() })
329 }
330}
331
332
#[cfg(test)]
mod tests {
    use crate::*;

    /// A size-only (`usize`) creator with the default config builds a buffer.
    #[test]
    fn mem_can_be_created_with_len() {
        let (context, _devices) = ll_testing::get_context();
        let config = MemConfig::default();
        let created = unsafe { ClMem::create_with_config::<u32, usize>(&context, 10, config) };
        let _mem: ClMem = created.unwrap();
    }

    /// A slice creator with the data config builds a buffer.
    #[test]
    fn mem_can_be_created_with_slice() {
        let (context, _devices) = ll_testing::get_context();
        let data: Vec<u32> = (0..5).collect();
        let config = MemConfig::for_data();
        let _mem: ClMem =
            unsafe { ClMem::create_with_config(&context, data.as_slice(), config).unwrap() };
    }

    mod mem_ptr_trait {
        use crate::*;

        #[test]
        fn len_method_works() {
            let (_devices, _context, mem) = ll_testing::get_mem::<u32>(10);
            assert_eq!(unsafe { mem.len() }.unwrap(), 10);
        }

        #[test]
        fn reference_count_method_works() {
            let (_devices, _context, mem) = ll_testing::get_mem::<u32>(10);
            assert_eq!(unsafe { mem.reference_count() }.unwrap(), 1);
        }

        #[test]
        fn size_method_returns_size_in_bytes() {
            let (_devices, _context, mem) = ll_testing::get_mem::<u32>(10);
            let expected = std::mem::size_of::<u32>() * 10;
            assert_eq!(unsafe { mem.size() }.unwrap(), expected);
        }

        #[test]
        fn offset_method_works() {
            let (_devices, _context, mem) = ll_testing::get_mem::<u32>(10);
            assert_eq!(unsafe { mem.offset() }.unwrap(), 0);
        }

        #[test]
        fn flags_method_works() {
            let (_devices, _context, mem) = ll_testing::get_mem::<u32>(10);
            assert_eq!(
                unsafe { mem.flags() }.unwrap(),
                MemFlags::READ_WRITE_ALLOC_HOST_PTR
            );
        }
    }
}
393
/// How kernels are allowed to access a buffer.
///
/// Each variant converts into the matching `KernelAccessMemFlags` constant.
/// This is a plain fieldless enum, so it derives the usual value traits.
/// That makes it printable, comparable and freely copyable (for example when
/// reusing a `MemConfig` setting).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelAccess {
    /// Kernels may only read the buffer.
    ReadOnly,
    /// Kernels may only write the buffer.
    WriteOnly,
    /// Kernels may read and write the buffer.
    ReadWrite,
}
399
400impl From<KernelAccess> for KernelAccessMemFlags {
401 fn from(kernel_access: KernelAccess) -> KernelAccessMemFlags {
402 match kernel_access {
403 KernelAccess::ReadOnly => KernelAccessMemFlags::READ_ONLY,
404 KernelAccess::WriteOnly => KernelAccessMemFlags::WRITE_ONLY,
405 KernelAccess::ReadWrite => KernelAccessMemFlags::READ_WRITE,
406 }
407 }
408}
409
410impl From<KernelAccess> for MemFlags {
411 fn from(kernel_access: KernelAccess) -> MemFlags {
412 MemFlags::from(KernelAccessMemFlags::from(kernel_access))
413 }
414}
415
416impl From<KernelAccess> for cl_mem_flags {
417 fn from(kernel_access: KernelAccess) -> cl_mem_flags {
418 cl_mem_flags::from(MemFlags::from(kernel_access))
419 }
420}
421
/// How the host is allowed to access a buffer.
///
/// Each variant converts into the matching `HostAccessMemFlags` constant.
/// As a fieldless public enum it derives the usual value traits, so it can
/// be printed, compared and copied.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HostAccess {
    /// Host may only read the buffer.
    ReadOnly,
    /// Host may only write the buffer.
    WriteOnly,
    /// Host may neither read nor write the buffer.
    NoAccess,
    /// Host may read and write the buffer.
    ReadWrite,
}
428
429impl From<HostAccess> for HostAccessMemFlags {
430 fn from(host_access: HostAccess) -> HostAccessMemFlags {
431 match host_access {
432 HostAccess::ReadOnly => HostAccessMemFlags::READ_ONLY,
433 HostAccess::WriteOnly => HostAccessMemFlags::WRITE_ONLY,
434 HostAccess::NoAccess => HostAccessMemFlags::NO_ACCESS,
435 HostAccess::ReadWrite => HostAccessMemFlags::READ_WRITE,
436 }
437 }
438}
439
440impl From<HostAccess> for MemFlags {
441 fn from(host_access: HostAccess) -> MemFlags {
442 MemFlags::from(HostAccessMemFlags::from(host_access))
443 }
444}
445
446impl From<HostAccess> for cl_mem_flags {
447 fn from(host_access: HostAccess) -> cl_mem_flags {
448 cl_mem_flags::from(MemFlags::from(host_access))
449 }
450}
451
/// Where a buffer's storage lives and how host data is handled at creation.
///
/// Each variant converts into the `MemLocationMemFlags` constant of the same
/// name. NOTE(review): the exact OpenCL flag behind each variant is defined
/// by `MemLocationMemFlags`. As a fieldless public enum it derives the usual
/// value traits, so it can be printed, compared and copied.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemLocation {
    KeepInPlace,
    AllocOnDevice,
    CopyToDevice,
    ForceCopyToDevice,
}
462
463impl From<MemLocation> for MemLocationMemFlags {
464 fn from(mem_location: MemLocation) -> MemLocationMemFlags {
465 match mem_location {
466 MemLocation::KeepInPlace => MemLocationMemFlags::KEEP_IN_PLACE,
467 MemLocation::AllocOnDevice => MemLocationMemFlags::ALLOC_ON_DEVICE,
468 MemLocation::CopyToDevice => MemLocationMemFlags::COPY_TO_DEVICE,
469 MemLocation::ForceCopyToDevice => MemLocationMemFlags::FORCE_COPY_TO_DEVICE,
470 }
471 }
472}
473
474impl From<MemLocation> for MemFlags {
475 fn from(mem_location: MemLocation) -> MemFlags {
476 MemFlags::from(MemLocationMemFlags::from(mem_location))
477 }
478}
479
480impl From<MemLocation> for cl_mem_flags {
481 fn from(mem_location: MemLocation) -> cl_mem_flags {
482 cl_mem_flags::from(MemFlags::from(mem_location))
483 }
484}
485
486pub struct MemConfig {
487 pub host_access: HostAccess,
488 pub kernel_access: KernelAccess,
489 pub mem_location: MemLocation,
490}
491
492impl MemConfig {
493 pub fn build(
494 host_access: HostAccess,
495 kernel_access: KernelAccess,
496 mem_location: MemLocation,
497 ) -> MemConfig {
498 MemConfig {
499 host_access,
500 kernel_access,
501 mem_location,
502 }
503 }
504}
505
506impl From<MemConfig> for MemFlags {
507 fn from(mem_config: MemConfig) -> MemFlags {
508 unsafe { MemFlags::from_bits_unchecked(cl_mem_flags::from(mem_config)) }
509 }
510}
511
512impl From<MemConfig> for cl_mem_flags {
513 fn from(mem_config: MemConfig) -> cl_mem_flags {
514 cl_mem_flags::from(mem_config.host_access)
515 | cl_mem_flags::from(mem_config.kernel_access)
516 | cl_mem_flags::from(mem_config.mem_location)
517 }
518}
519
520impl Default for MemConfig {
521 fn default() -> MemConfig {
522 MemConfig {
523 host_access: HostAccess::ReadWrite,
524 kernel_access: KernelAccess::ReadWrite,
525 mem_location: MemLocation::AllocOnDevice,
526 }
527 }
528}
529
530impl MemConfig {
531 pub fn for_data() -> MemConfig {
532 MemConfig {
533 mem_location: MemLocation::CopyToDevice,
534 ..MemConfig::default()
535 }
536 }
537
538 pub fn for_size() -> MemConfig {
539 MemConfig {
540 mem_location: MemLocation::AllocOnDevice,
541 ..MemConfig::default()
542 }
543 }
544}
545
#[cfg(test)]
mod mem_flags_tests {
    use super::*;
    use crate::KernelAccessMemFlags;

    // Each test pins one enum variant to the flag constant it must convert to.

    #[test]
    fn kernel_access_read_only_conversion_into_kernel_access_mem_flag() {
        assert_eq!(
            KernelAccessMemFlags::READ_ONLY,
            KernelAccessMemFlags::from(KernelAccess::ReadOnly)
        );
    }

    #[test]
    fn kernel_access_write_only_conversion_into_kernel_access_mem_flag() {
        assert_eq!(
            KernelAccessMemFlags::WRITE_ONLY,
            KernelAccessMemFlags::from(KernelAccess::WriteOnly)
        );
    }

    #[test]
    fn kernel_access_convert_read_write_into_kernel_access_mem_flag() {
        assert_eq!(
            KernelAccessMemFlags::READ_WRITE,
            KernelAccessMemFlags::from(KernelAccess::ReadWrite)
        );
    }

    #[test]
    fn host_access_read_only_conversion_into_host_access_mem_flag() {
        assert_eq!(
            HostAccessMemFlags::READ_ONLY,
            HostAccessMemFlags::from(HostAccess::ReadOnly)
        );
    }

    #[test]
    fn host_access_write_only_conversion_into_host_access_mem_flag() {
        assert_eq!(
            HostAccessMemFlags::WRITE_ONLY,
            HostAccessMemFlags::from(HostAccess::WriteOnly)
        );
    }

    #[test]
    fn host_access_read_write_conversion_into_host_access_mem_flag() {
        assert_eq!(
            HostAccessMemFlags::READ_WRITE,
            HostAccessMemFlags::from(HostAccess::ReadWrite)
        );
    }

    #[test]
    fn host_access_no_access_conversion_into_host_access_mem_flag() {
        assert_eq!(
            HostAccessMemFlags::NO_ACCESS,
            HostAccessMemFlags::from(HostAccess::NoAccess)
        );
    }

    #[test]
    fn mem_location_keep_in_place_conversion_into_mem_location_mem_flag() {
        assert_eq!(
            MemLocationMemFlags::KEEP_IN_PLACE,
            MemLocationMemFlags::from(MemLocation::KeepInPlace)
        );
    }

    #[test]
    fn mem_location_alloc_on_device_conversion_into_mem_location_mem_flag() {
        assert_eq!(
            MemLocationMemFlags::ALLOC_ON_DEVICE,
            MemLocationMemFlags::from(MemLocation::AllocOnDevice)
        );
    }

    #[test]
    fn mem_location_copy_to_device_conversion_into_mem_location_mem_flag() {
        assert_eq!(
            MemLocationMemFlags::COPY_TO_DEVICE,
            MemLocationMemFlags::from(MemLocation::CopyToDevice)
        );
    }

    #[test]
    fn mem_location_force_copy_to_device_conversion_into_mem_location_mem_flag() {
        assert_eq!(
            MemLocationMemFlags::FORCE_COPY_TO_DEVICE,
            MemLocationMemFlags::from(MemLocation::ForceCopyToDevice)
        );
    }

    /// A full config converts to the union of its three parts' bits.
    #[test]
    fn mem_config_conversion_into_cl_mem_flags() {
        let config = MemConfig {
            host_access: HostAccess::ReadWrite,
            kernel_access: KernelAccess::ReadWrite,
            mem_location: MemLocation::AllocOnDevice,
        };
        let expected =
            (MemFlags::ALLOC_HOST_PTR | MemFlags::HOST_READ_WRITE | MemFlags::KERNEL_READ_WRITE)
                .bits();
        assert_eq!(expected, cl_mem_flags::from(config));
    }
}