1#![allow(dead_code)]
2
3use std::{
4 ffi::{c_void, CString},
5 mem::size_of,
6 usize, vec,
7};
8
9use crate::Error;
10
11use super::{ffi::*, OCLErrorKind};
12
13#[derive(Clone, Copy, Debug)]
14pub struct Platform(cl_platform_id);
15
16impl Platform {
17 pub fn as_ptr(self) -> *mut cl_platform_id {
18 self.0 as *mut cl_platform_id
19 }
20}
21
22pub fn get_platforms() -> Result<Vec<Platform>, Error> {
23 let mut platforms: cl_uint = 0;
24 let value = unsafe { clGetPlatformIDs(0, std::ptr::null_mut(), &mut platforms) };
25
26 if value != 0 {
27 return Err(Error::from(OCLErrorKind::from_value(value)));
28 }
29 let mut vec: Vec<usize> = vec![0; platforms as usize];
30 let (ptr, len, cap) = (vec.as_mut_ptr(), vec.len(), vec.capacity());
31
32 let mut platforms_vec: Vec<Platform> = unsafe {
33 core::mem::forget(vec);
34 Vec::from_raw_parts(ptr as *mut Platform, len, cap)
35 };
36
37 let value = unsafe {
38 clGetPlatformIDs(
39 platforms,
40 platforms_vec.as_mut_ptr() as *mut cl_platform_id,
41 std::ptr::null_mut(),
42 )
43 };
44 if value != 0 {
45 return Err(Error::from(OCLErrorKind::from_value(value)));
46 }
47 Ok(platforms_vec)
48}
49
/// Platform properties that can be queried with [`get_platform_info`].
///
/// Discriminants are the `cl_platform_info` constants from the OpenCL headers.
#[derive(Clone, Copy)]
pub enum PlatformInfo {
    /// `CL_PLATFORM_NAME` (0x0902). The previous value 0x0903 is
    /// `CL_PLATFORM_VENDOR`, which queried the vendor string instead.
    PlatformName = 0x0902,
}
54pub fn get_platform_info(platform: Platform, param_name: PlatformInfo) -> String {
55 let mut size: size_t = 0;
56 unsafe {
57 clGetPlatformInfo(
58 platform.0,
59 param_name as cl_platform_info,
60 0,
61 std::ptr::null_mut(),
62 &mut size,
63 );
64 };
65
66 let mut param_value = vec![32u8; size];
67
68 unsafe {
69 clGetPlatformInfo(
70 platform.0,
71 param_name as cl_platform_info,
72 size,
73 param_value.as_mut_ptr() as *mut c_void,
74 std::ptr::null_mut(),
75 );
76 };
77
78 println!("param value: {:?}", param_value);
79 String::from_utf8_lossy(¶m_value).to_string()
80}
81
/// Device categories for `clGetDeviceIDs`; each variant is one
/// `CL_DEVICE_TYPE_*` bit from the OpenCL headers.
pub enum DeviceType {
    /// `CL_DEVICE_TYPE_DEFAULT`
    DEFAULT = 1,
    /// `CL_DEVICE_TYPE_CPU`
    CPU = 1 << 1,
    /// `CL_DEVICE_TYPE_GPU`
    GPU = 1 << 2,
    /// `CL_DEVICE_TYPE_ACCELERATOR`
    ACCELERATOR = 1 << 3,
}
89
/// Device properties that can be queried with [`get_device_info`].
///
/// Discriminants are the `cl_device_info` constants from the OpenCL headers.
#[derive(Clone, Copy)]
pub enum DeviceInfo {
    /// `CL_DEVICE_MAX_MEM_ALLOC_SIZE` (`cl_ulong`)
    MaxMemAllocSize = 0x1010,
    /// `CL_DEVICE_GLOBAL_MEM_SIZE` (`cl_ulong`)
    GlobalMemSize = 0x101F,
    /// `CL_DEVICE_NAME` (string)
    NAME = 0x102B,
    /// `CL_DEVICE_VERSION` (string)
    VERSION = 0x102F,
    /// `CL_DEVICE_HOST_UNIFIED_MEMORY` (`cl_bool`)
    HostUnifiedMemory = 0x1035,
}
98#[derive(Clone, Copy, Debug, Hash)]
99pub struct CLIntDevice(pub cl_device_id);
100
101impl CLIntDevice {
102 pub fn get_name(self) -> Result<String, Error> {
103 Ok(get_device_info(self, DeviceInfo::NAME)?.string)
104 }
105 pub fn get_version(self) -> Result<String, Error> {
106 Ok(get_device_info(self, DeviceInfo::VERSION)?.string)
107 }
108 pub fn get_global_mem(self) -> Result<u64, Error> {
109 Ok(get_device_info(self, DeviceInfo::GlobalMemSize)?.size)
110 }
111 pub fn get_max_mem_alloc(self) -> Result<u64, Error> {
112 Ok(get_device_info(self, DeviceInfo::MaxMemAllocSize)?.size)
113 }
114 pub fn unified_mem(self) -> Result<bool, Error> {
115 Ok(get_device_info(self, DeviceInfo::HostUnifiedMemory)?.size != 0)
116 }
117}
118
119pub fn get_device_ids(platform: Platform, device_type: &u64) -> Result<Vec<CLIntDevice>, Error> {
120 let mut num_devices: cl_uint = 0;
121 let value = unsafe {
122 clGetDeviceIDs(
123 platform.0,
124 *device_type,
125 0,
126 std::ptr::null_mut(),
127 &mut num_devices,
128 )
129 };
130 if value != 0 {
131 return Err(Error::from(OCLErrorKind::from_value(value)));
132 }
133
134 let mut vec: Vec<usize> = vec![0; num_devices as usize];
135 let (ptr, len, cap) = (vec.as_mut_ptr(), vec.len(), vec.capacity());
136
137 let mut devices: Vec<CLIntDevice> = unsafe {
138 core::mem::forget(vec);
139 Vec::from_raw_parts(ptr as *mut CLIntDevice, len, cap)
140 };
141
142 let value = unsafe {
143 clGetDeviceIDs(
144 platform.0,
145 DeviceType::GPU as u64,
146 num_devices,
147 devices.as_mut_ptr() as *mut cl_device_id,
148 std::ptr::null_mut(),
149 )
150 };
151 if value != 0 {
152 return Err(Error::from(OCLErrorKind::from_value(value)));
153 }
154 Ok(devices)
155}
156
/// Result of a single `clGetDeviceInfo` query, as produced by [`get_device_info`].
pub struct DeviceReturnInfo {
    /// The returned bytes interpreted (lossily) as UTF-8 text; meaningful for
    /// string-valued queries.
    pub string: String,
    /// The returned bytes folded into an integer; meaningful for numeric queries.
    pub size: u64,
    /// The raw bytes written by OpenCL.
    pub data: Vec<u8>,
}
162
163pub fn get_device_info(
164 device: CLIntDevice,
165 param_name: DeviceInfo,
166) -> Result<DeviceReturnInfo, Error> {
167 let mut size: size_t = 0;
168 let value = unsafe {
169 clGetDeviceInfo(
170 device.0,
171 param_name as cl_device_info,
172 0,
173 std::ptr::null_mut(),
174 &mut size,
175 )
176 };
177 if value != 0 {
178 return Err(Error::from(OCLErrorKind::from_value(value)));
179 }
180 let mut param_value = vec![0; size];
181 let value = unsafe {
182 clGetDeviceInfo(
183 device.0,
184 param_name as cl_device_info,
185 size,
186 param_value.as_mut_ptr() as *mut c_void,
187 std::ptr::null_mut(),
188 )
189 };
190 if value != 0 {
191 return Err(Error::from(OCLErrorKind::from_value(value)));
192 }
193 let string = String::from_utf8_lossy(¶m_value).to_string();
194 let size = param_value.iter().fold(0, |x, &i| x << 4 | i as u64);
195
196 Ok(DeviceReturnInfo {
197 string,
198 size,
199 data: param_value,
200 })
201}
202
203#[derive(Debug, Hash)]
204pub struct Context(pub cl_context);
205
206impl Drop for Context {
207 fn drop(&mut self) {
208 unsafe { clReleaseContext(self.0) };
209 }
210}
211
212pub fn create_context(devices: &[CLIntDevice]) -> Result<Context, Error> {
213 let mut err = 0;
214 let r = unsafe {
215 clCreateContext(
216 std::ptr::null(),
217 devices.len() as u32,
218 devices.as_ptr() as *const *mut c_void,
219 std::ptr::null_mut(),
220 std::ptr::null_mut(),
221 &mut err,
222 )
223 };
224 if err != 0 {
225 return Err(Error::from(OCLErrorKind::from_value(err)));
226 }
227 Ok(Context(r))
228}
229
230#[derive(Debug , Clone)]
231pub struct CommandQueue(pub cl_command_queue);
232
233pub fn release_command_queue(cq: &mut CommandQueue) -> Result<(), Error> {
234 let err = unsafe { clReleaseCommandQueue(cq.0) };
235 if err != 0 {
236 return Err(OCLErrorKind::from_value(err).into());
237 }
238 Ok(())
239}
240
241impl Drop for CommandQueue {
242 fn drop(&mut self) {
243 release_command_queue(self).unwrap();
244 }
245}
246
247pub fn create_command_queue(context: &Context, device: CLIntDevice) -> Result<CommandQueue, Error> {
248 let mut err = 0;
249 let r = unsafe { clCreateCommandQueue(context.0, device.0, 0, &mut err) };
250
251 if err != 0 {
252 return Err(Error::from(OCLErrorKind::from_value(err)));
253 }
254 Ok(CommandQueue(r))
255}
256
257pub fn finish(cq: CommandQueue) {
258 unsafe { clFinish(cq.0) };
259}
260
261#[derive(Debug)]
262pub struct Event(pub cl_event);
263
264impl Event {
265 pub fn wait(self) -> Result<(), Error> {
266 wait_for_event(self)
267 }
268}
269
270impl Drop for Event {
271 fn drop(&mut self) {
272 release_event(self).unwrap();
273 }
274}
275
276pub fn wait_for_event(event: Event) -> Result<(), Error> {
277 let event_arr = [event];
278
279 let value = unsafe { clWaitForEvents(1, event_arr.as_ptr() as *mut cl_event) };
280
281 if value != 0 {
282 return Err(Error::from(OCLErrorKind::from_value(value)));
283 }
284
285 Ok(())
286}
287
288pub fn wait_for_events(events: &[Event]) -> Result<(), Error> {
289 let value = unsafe { clWaitForEvents(events.len() as u32, events.as_ptr() as *mut cl_event) };
290
291 if value != 0 {
292 return Err(Error::from(OCLErrorKind::from_value(value)));
293 }
294
295 Ok(())
296}
297
298pub fn release_event(event: &mut Event) -> Result<(), Error> {
299 let value = unsafe { clReleaseEvent(event.0) };
300 if value != 0 {
301 return Err(Error::from(OCLErrorKind::from_value(value)));
302 }
303 Ok(())
304}
305
/// `cl_mem_flags` bits for `clCreateBuffer` (values match the `CL_MEM_*` constants).
///
/// Combine flags with `|`. `MemFlags | MemFlags` yields a `u64`, and that `u64`
/// can be OR-ed with further `MemFlags`, so chains of any length work:
/// `MemFlags::MemReadOnly | MemFlags::MemCopyHostPtr | MemFlags::MemHostNoAccess`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MemFlags {
    /// `CL_MEM_READ_WRITE`
    MemReadWrite = 1,
    /// `CL_MEM_WRITE_ONLY`
    MemWriteOnly = 1 << 1,
    /// `CL_MEM_READ_ONLY`
    MemReadOnly = 1 << 2,
    /// `CL_MEM_USE_HOST_PTR`
    MemUseHostPtr = 1 << 3,
    /// `CL_MEM_ALLOC_HOST_PTR`
    MemAllocHostPtr = 1 << 4,
    /// `CL_MEM_COPY_HOST_PTR`
    MemCopyHostPtr = 1 << 5,
    /// `CL_MEM_HOST_WRITE_ONLY`
    MemHostWriteOnly = 1 << 7,
    /// `CL_MEM_HOST_READ_ONLY`
    MemHostReadOnly = 1 << 8,
    /// `CL_MEM_HOST_NO_ACCESS`
    MemHostNoAccess = 1 << 9,
}

impl core::ops::BitOr for MemFlags {
    type Output = u64;

    fn bitor(self, rhs: Self) -> Self::Output {
        self as u64 | rhs as u64
    }
}

/// Lets an already-combined flag mask be extended with another flag.
impl core::ops::BitOr<MemFlags> for u64 {
    type Output = u64;

    fn bitor(self, rhs: MemFlags) -> Self::Output {
        self | rhs as u64
    }
}
325
326pub fn create_buffer<T>(
327 context: &Context,
328 flag: u64,
329 size: usize,
330 data: Option<&[T]>,
331) -> Result<*mut c_void, Error> {
332 let mut err = 0;
333 let host_ptr = match data {
334 Some(d) => d.as_ptr() as cl_mem,
335 None => std::ptr::null_mut(),
336 };
337 let r = unsafe {
338 clCreateBuffer(
339 context.0,
340 flag,
341 size * core::mem::size_of::<T>(),
342 host_ptr,
343 &mut err,
344 )
345 };
346
347 if err != 0 {
348 return Err(Error::from(OCLErrorKind::from_value(err)));
349 }
350 Ok(r)
351}
352
353pub unsafe fn release_mem_object(ptr: *mut c_void) -> Result<(), Error> {
356 let value = clReleaseMemObject(ptr);
357 if value != 0 {
358 return Err(Error::from(OCLErrorKind::from_value(value)));
359 }
360 Ok(())
361}
362
363pub fn retain_mem_object(mem: *mut c_void) -> Result<(), Error> {
364 let value = unsafe { clRetainMemObject(mem) };
365 if value != 0 {
366 return Err(Error::from(OCLErrorKind::from_value(value)));
367 }
368 Ok(())
369}
370
371pub unsafe fn enqueue_write_buffer<T>(
374 cq: &CommandQueue,
375 mem: *mut c_void,
376 data: &[T],
377 block: bool,
378) -> Result<Event, Error> {
379 let mut events = [std::ptr::null_mut(); 1];
380
381 let value = clEnqueueWriteBuffer(
382 cq.0,
383 mem,
384 block as u32,
385 0,
386 data.len() * core::mem::size_of::<T>(),
387 data.as_ptr() as *mut c_void,
388 0,
389 std::ptr::null(),
390 events.as_mut_ptr() as *mut cl_event,
391 );
392 if value != 0 {
393 return Err(Error::from(OCLErrorKind::from_value(value)));
394 }
395 Ok(Event(events[0]))
396}
397
398pub unsafe fn enqueue_read_buffer<T>(
401 cq: &CommandQueue,
402 mem: *mut c_void,
403 data: &mut [T],
404 block: bool,
405) -> Result<Event, Error> {
406 let mut events = [std::ptr::null_mut(); 1];
407 let value = clEnqueueReadBuffer(
408 cq.0,
409 mem,
410 block as u32,
411 0,
412 data.len() * core::mem::size_of::<T>(),
413 data.as_ptr() as *mut c_void,
414 0,
415 std::ptr::null(),
416 events.as_mut_ptr() as *mut cl_event,
417 );
418 if value != 0 {
419 return Err(Error::from(OCLErrorKind::from_value(value)));
420 }
421 Ok(Event(events[0]))
422}
423
424pub fn enqueue_copy_buffer<T>(
425 cq: &CommandQueue,
426 src_mem: *mut c_void,
427 dst_mem: *mut c_void,
428 src_offset: usize,
429 dst_offset: usize,
430 size: usize,
431) -> Result<(), Error> {
432 let mut events = [std::ptr::null_mut(); 1];
433 let value = unsafe {
434 clEnqueueCopyBuffer(
435 cq.0,
436 src_mem,
437 dst_mem,
438 src_offset * size_of::<T>(),
439 dst_offset * size_of::<T>(),
440 size * size_of::<T>(),
441 0,
442 std::ptr::null(),
443 events.as_mut_ptr() as *mut cl_event,
444 )
445 };
446 if value != 0 {
447 return Err(Error::from(OCLErrorKind::from_value(value)));
448 }
449 wait_for_event(Event(events[0]))
450}
451
452pub fn enqueue_copy_buffers<T, I>(
453 cq: &CommandQueue,
454 src_mem: *mut c_void,
455 dst_mem: *mut c_void,
456 to_copy: I,
457) -> Result<(), Error>
458where
459 I: IntoIterator<Item = (usize, usize, usize)>,
460{
461 let to_copy = to_copy.into_iter();
462 let mut events = match to_copy.size_hint() {
463 (0, None) => Vec::new(),
464 (min, None) => Vec::with_capacity(min),
465 (_, Some(max)) => Vec::with_capacity(max),
466 };
467
468 for (src_offset, dst_offset, size) in to_copy {
469 let event = [std::ptr::null_mut(); 1];
470 events.push(event);
471
472 let value = unsafe {
473 clEnqueueCopyBuffer(
474 cq.0,
475 src_mem,
476 dst_mem,
477 src_offset * size_of::<T>(),
478 dst_offset * size_of::<T>(),
479 size * size_of::<T>(),
480 0,
481 std::ptr::null(),
482 events.last_mut().unwrap().as_mut_ptr() as *mut cl_event,
483 )
484 };
485
486 if value != 0 {
487 return Err(Error::from(OCLErrorKind::from_value(value)));
488 }
489 }
490
491 for event in &events {
493 wait_for_event(Event(event[0]))?;
494 }
495
496 Ok(())
497}
498
499#[inline]
500pub fn enqueue_full_copy_buffer<T>(
501 cq: &CommandQueue,
502 src_mem: *mut c_void,
503 dst_mem: *mut c_void,
504 size: usize,
505) -> Result<(), Error> {
506 enqueue_copy_buffer::<T>(cq, src_mem, dst_mem, 0, 0, size)
507}
508
509pub fn unified_ptr<T>(cq: &CommandQueue, ptr: *mut c_void, len: usize) -> Result<*mut T, Error> {
510 unsafe { enqueue_map_buffer::<T>(cq, ptr, true, 2 | 1, 0, len).map(|ptr| ptr as *mut T) }
511}
512
513pub unsafe fn enqueue_map_buffer<T>(
517 cq: &CommandQueue,
518 buffer: *mut c_void,
519 block: bool,
520 map_flags: u64,
521 offset: usize,
522 len: usize,
523) -> Result<*mut c_void, Error> {
524 let offset = offset * core::mem::size_of::<T>();
525 let size = len * core::mem::size_of::<T>();
526
527 let mut event = [std::ptr::null_mut(); 1];
528
529 let mut err = 0;
530
531 let ptr = clEnqueueMapBuffer(
532 cq.0,
533 buffer,
534 block as u32,
535 map_flags,
536 offset,
537 size,
538 0,
539 std::ptr::null(),
540 event.as_mut_ptr() as *mut cl_event,
541 &mut err,
542 );
543
544 if err != 0 {
545 return Err(Error::from(OCLErrorKind::from_value(err)));
546 }
547
548 let e = Event(event[0]);
549 wait_for_event(e)?;
550 Ok(ptr)
551}
552pub struct Program(pub cl_program);
565
566impl Drop for Program {
567 fn drop(&mut self) {
568 release_program(self).unwrap()
569 }
570}
571
/// `cl_program_info` query selectors (values match the OpenCL headers).
enum ProgramInfo {
    /// `CL_PROGRAM_BINARY_SIZES`
    BinarySizes = 0x1165,
    /// `CL_PROGRAM_BINARIES`
    Binaries = 0x1166,
}

/// `cl_program_build_info` query selectors (values match the OpenCL headers).
enum ProgramBuildInfo {
    /// `CL_PROGRAM_BUILD_STATUS`
    Status = 0x1181,
    /// `CL_PROGRAM_BUILD_LOG`
    BuildLog = 0x1183,
}
581
582pub fn release_program(program: &mut Program) -> Result<(), Error> {
583 let value = unsafe { clReleaseProgram(program.0) };
584 if value != 0 {
585 return Err(Error::from(OCLErrorKind::from_value(value)));
586 }
587 Ok(())
588}
589
590pub fn create_program_with_source(context: &Context, src: &str) -> Result<Program, Error> {
591 let mut err = 0;
592 let cs = CString::new(src).expect("No cstring for you!");
593 let lens = vec![cs.as_bytes().len()];
594 let cstring: Vec<*const _> = vec![cs.as_ptr()];
595 let r = unsafe {
596 clCreateProgramWithSource(
597 context.0,
598 1,
599 cstring.as_ptr() as *const *const _,
600 lens.as_ptr() as *const usize,
601 &mut err,
602 )
603 };
604 if err != 0 {
605 return Err(Error::from(OCLErrorKind::from_value(err)));
606 }
607 Ok(Program(r))
608}
609
610pub fn build_program(
611 program: &Program,
612 devices: &[CLIntDevice],
613 options: Option<&str>,
614) -> Result<(), Error> {
615 let len = devices.len();
616
617 let err = if let Some(options) = options {
618 let options = CString::new(options).unwrap();
619 unsafe {
620 clBuildProgram(
621 program.0,
622 len as u32,
623 devices.as_ptr() as *const *mut c_void,
624 options.as_ptr(),
625 std::ptr::null_mut(),
626 std::ptr::null_mut(),
627 )
628 }
629 } else {
630 unsafe {
631 clBuildProgram(
632 program.0,
633 len as u32,
634 devices.as_ptr() as *const *mut c_void,
635 std::ptr::null(),
636 std::ptr::null_mut(),
637 std::ptr::null_mut(),
638 )
639 }
640 };
641 if err != 0 {
642 return Err(Error::from(OCLErrorKind::from_value(err)));
643 }
644 Ok(())
645}
646
647#[derive(Debug)]
648pub struct Kernel(pub cl_kernel);
649
650impl Drop for Kernel {
651 fn drop(&mut self) {
652 release_kernel(self).unwrap()
653 }
654}
655
656unsafe impl Send for Kernel {}
657unsafe impl Sync for Kernel {}
658
659pub fn create_kernel(program: &Program, str: &str) -> Result<Kernel, Error> {
660 let mut err = 0;
661 let cstring = CString::new(str).unwrap();
662 let kernel = unsafe { clCreateKernel(program.0, cstring.as_ptr(), &mut err) };
663 if err != 0 {
664 return Err(Error::from(OCLErrorKind::from_value(err)));
665 }
666 Ok(Kernel(kernel))
667}
668pub fn create_kernels_in_program(program: &Program) -> Result<Vec<Kernel>, Error> {
669 let mut n_kernels: u32 = 0;
670 let value =
671 unsafe { clCreateKernelsInProgram(program.0, 0, std::ptr::null_mut(), &mut n_kernels) };
672 if value != 0 {
673 return Err(Error::from(OCLErrorKind::from_value(value)));
674 }
675
676 let mut vec: Vec<usize> = vec![0; n_kernels as usize];
677 let (ptr, len, cap) = (vec.as_mut_ptr(), vec.len(), vec.capacity());
678
679 let mut kernels: Vec<Kernel> = unsafe {
680 core::mem::forget(vec);
681 Vec::from_raw_parts(ptr as *mut Kernel, len, cap)
682 };
683 let value = unsafe {
684 clCreateKernelsInProgram(
685 program.0,
686 n_kernels,
687 kernels.as_mut_ptr() as *mut cl_kernel,
688 std::ptr::null_mut(),
689 )
690 };
691 if value != 0 {
692 return Err(Error::from(OCLErrorKind::from_value(value)));
693 }
694
695 Ok(kernels)
696}
697
698pub fn release_kernel(kernel: &mut Kernel) -> Result<(), Error> {
699 let value = unsafe { clReleaseKernel(kernel.0) };
700 if value != 0 {
701 return Err(Error::from(OCLErrorKind::from_value(value)));
702 }
703 Ok(())
704}
705
706pub fn set_kernel_arg(
707 kernel: &Kernel,
708 index: usize,
709 arg: *const c_void,
710 arg_size: usize,
711 is_num: bool,
712) -> Result<(), Error> {
713 let ptr = if is_num {
714 arg
715 } else {
716 &arg as *const *const c_void as *const c_void
717 };
718
719 let value = unsafe { clSetKernelArg(kernel.0, index as u32, arg_size, ptr) };
720 if value != 0 {
721 return Err(Error::from(OCLErrorKind::from_value(value)));
722 }
723 Ok(())
724}
725
726pub fn enqueue_nd_range_kernel(
727 cq: &CommandQueue,
728 kernel: &Kernel,
729 wd: usize,
730 gws: &[usize; 3],
731 lws: Option<&[usize; 3]>,
732 offset: Option<[usize; 3]>,
733) -> Result<(), Error> {
734 let mut events = [std::ptr::null_mut(); 1];
735 let lws = match lws {
736 Some(lws) => lws.as_ptr(),
737 None => std::ptr::null(),
738 };
739 let offset = match offset {
740 Some(offset) => offset.as_ptr(),
741 None => std::ptr::null(),
742 };
743
744 let value = unsafe {
745 clEnqueueNDRangeKernel(
746 cq.0,
747 kernel.0,
748 wd as u32,
749 offset,
750 gws.as_ptr(),
751 lws,
752 0,
753 std::ptr::null(),
754 events.as_mut_ptr() as *mut cl_event,
755 )
756 };
757 if value != 0 {
758 return Err(Error::from(OCLErrorKind::from_value(value)));
759 }
760 let e = Event(events[0]);
761 wait_for_event(e)
764}