min_cl/api/
cl.rs

1#![allow(dead_code)]
2
3use std::{
4    ffi::{c_void, CString},
5    mem::size_of,
6    usize, vec,
7};
8
9use crate::Error;
10
11use super::{ffi::*, OCLErrorKind};
12
13#[derive(Clone, Copy, Debug)]
14pub struct Platform(cl_platform_id);
15
16impl Platform {
17    pub fn as_ptr(self) -> *mut cl_platform_id {
18        self.0 as *mut cl_platform_id
19    }
20}
21
22pub fn get_platforms() -> Result<Vec<Platform>, Error> {
23    let mut platforms: cl_uint = 0;
24    let value = unsafe { clGetPlatformIDs(0, std::ptr::null_mut(), &mut platforms) };
25
26    if value != 0 {
27        return Err(Error::from(OCLErrorKind::from_value(value)));
28    }
29    let mut vec: Vec<usize> = vec![0; platforms as usize];
30    let (ptr, len, cap) = (vec.as_mut_ptr(), vec.len(), vec.capacity());
31
32    let mut platforms_vec: Vec<Platform> = unsafe {
33        core::mem::forget(vec);
34        Vec::from_raw_parts(ptr as *mut Platform, len, cap)
35    };
36
37    let value = unsafe {
38        clGetPlatformIDs(
39            platforms,
40            platforms_vec.as_mut_ptr() as *mut cl_platform_id,
41            std::ptr::null_mut(),
42        )
43    };
44    if value != 0 {
45        return Err(Error::from(OCLErrorKind::from_value(value)));
46    }
47    Ok(platforms_vec)
48}
49
/// Selector for the value requested from `clGetPlatformInfo`.
///
/// The discriminant of a variant is the value passed as `param_name`.
#[derive(Copy, Clone)]
pub enum PlatformInfo {
    /// Discriminant `0x0903` (presumably `CL_PLATFORM_NAME` — confirm against the OpenCL headers).
    PlatformName = 0x903,
}
54pub fn get_platform_info(platform: Platform, param_name: PlatformInfo) -> String {
55    let mut size: size_t = 0;
56    unsafe {
57        clGetPlatformInfo(
58            platform.0,
59            param_name as cl_platform_info,
60            0,
61            std::ptr::null_mut(),
62            &mut size,
63        );
64    };
65
66    let mut param_value = vec![32u8; size];
67
68    unsafe {
69        clGetPlatformInfo(
70            platform.0,
71            param_name as cl_platform_info,
72            size,
73            param_value.as_mut_ptr() as *mut c_void,
74            std::ptr::null_mut(),
75        );
76    };
77
78    println!("param value: {:?}", param_value);
79    String::from_utf8_lossy(&param_value).to_string()
80}
81
/// Device categories; the discriminants are single-bit values that this module
/// passes to `clGetDeviceIDs` as the device type.
pub enum DeviceType {
    /// Bit 0.
    DEFAULT = 0x1,
    /// Bit 1.
    CPU = 0x2,
    /// Bit 2.
    GPU = 0x4,
    /// Bit 3.
    ACCELERATOR = 0x8,
    // An `ALL` value (0xFFFFFFFF) is not exposed as a variant.
}
89
/// Selector for the value requested from `clGetDeviceInfo`.
///
/// The discriminant of a variant is the value passed as `param_name`.
#[derive(Clone, Copy)]
pub enum DeviceInfo {
    /// Discriminant `0x1010`.
    MaxMemAllocSize = 0x1010,
    /// Discriminant `0x101F`.
    GlobalMemSize = 0x101F,
    /// Discriminant `0x102B`.
    NAME = 0x102B,
    /// Discriminant `0x102F`.
    VERSION = 0x102F,
    /// Discriminant `0x1035`.
    HostUnifiedMemory = 0x1035,
}
98#[derive(Clone, Copy, Debug, Hash)]
99pub struct CLIntDevice(pub cl_device_id);
100
101impl CLIntDevice {
102    pub fn get_name(self) -> Result<String, Error> {
103        Ok(get_device_info(self, DeviceInfo::NAME)?.string)
104    }
105    pub fn get_version(self) -> Result<String, Error> {
106        Ok(get_device_info(self, DeviceInfo::VERSION)?.string)
107    }
108    pub fn get_global_mem(self) -> Result<u64, Error> {
109        Ok(get_device_info(self, DeviceInfo::GlobalMemSize)?.size)
110    }
111    pub fn get_max_mem_alloc(self) -> Result<u64, Error> {
112        Ok(get_device_info(self, DeviceInfo::MaxMemAllocSize)?.size)
113    }
114    pub fn unified_mem(self) -> Result<bool, Error> {
115        Ok(get_device_info(self, DeviceInfo::HostUnifiedMemory)?.size != 0)
116    }
117}
118
119pub fn get_device_ids(platform: Platform, device_type: &u64) -> Result<Vec<CLIntDevice>, Error> {
120    let mut num_devices: cl_uint = 0;
121    let value = unsafe {
122        clGetDeviceIDs(
123            platform.0,
124            *device_type,
125            0,
126            std::ptr::null_mut(),
127            &mut num_devices,
128        )
129    };
130    if value != 0 {
131        return Err(Error::from(OCLErrorKind::from_value(value)));
132    }
133
134    let mut vec: Vec<usize> = vec![0; num_devices as usize];
135    let (ptr, len, cap) = (vec.as_mut_ptr(), vec.len(), vec.capacity());
136
137    let mut devices: Vec<CLIntDevice> = unsafe {
138        core::mem::forget(vec);
139        Vec::from_raw_parts(ptr as *mut CLIntDevice, len, cap)
140    };
141
142    let value = unsafe {
143        clGetDeviceIDs(
144            platform.0,
145            DeviceType::GPU as u64,
146            num_devices,
147            devices.as_mut_ptr() as *mut cl_device_id,
148            std::ptr::null_mut(),
149        )
150    };
151    if value != 0 {
152        return Err(Error::from(OCLErrorKind::from_value(value)));
153    }
154    Ok(devices)
155}
156
/// Result of a device info query, offering several views of the same value.
pub struct DeviceReturnInfo {
    /// Textual interpretation of the queried value.
    pub string: String,
    /// Integer interpretation of the queried value.
    pub size: u64,
    /// The queried value as raw bytes.
    pub data: Vec<u8>,
}
162
163pub fn get_device_info(
164    device: CLIntDevice,
165    param_name: DeviceInfo,
166) -> Result<DeviceReturnInfo, Error> {
167    let mut size: size_t = 0;
168    let value = unsafe {
169        clGetDeviceInfo(
170            device.0,
171            param_name as cl_device_info,
172            0,
173            std::ptr::null_mut(),
174            &mut size,
175        )
176    };
177    if value != 0 {
178        return Err(Error::from(OCLErrorKind::from_value(value)));
179    }
180    let mut param_value = vec![0; size];
181    let value = unsafe {
182        clGetDeviceInfo(
183            device.0,
184            param_name as cl_device_info,
185            size,
186            param_value.as_mut_ptr() as *mut c_void,
187            std::ptr::null_mut(),
188        )
189    };
190    if value != 0 {
191        return Err(Error::from(OCLErrorKind::from_value(value)));
192    }
193    let string = String::from_utf8_lossy(&param_value).to_string();
194    let size = param_value.iter().fold(0, |x, &i| x << 4 | i as u64);
195
196    Ok(DeviceReturnInfo {
197        string,
198        size,
199        data: param_value,
200    })
201}
202
203#[derive(Debug, Hash)]
204pub struct Context(pub cl_context);
205
206impl Drop for Context {
207    fn drop(&mut self) {
208        unsafe { clReleaseContext(self.0) };
209    }
210}
211
212pub fn create_context(devices: &[CLIntDevice]) -> Result<Context, Error> {
213    let mut err = 0;
214    let r = unsafe {
215        clCreateContext(
216            std::ptr::null(),
217            devices.len() as u32,
218            devices.as_ptr() as *const *mut c_void,
219            std::ptr::null_mut(),
220            std::ptr::null_mut(),
221            &mut err,
222        )
223    };
224    if err != 0 {
225        return Err(Error::from(OCLErrorKind::from_value(err)));
226    }
227    Ok(Context(r))
228}
229
230#[derive(Debug /* remove: */, Clone)]
231pub struct CommandQueue(pub cl_command_queue);
232
233pub fn release_command_queue(cq: &mut CommandQueue) -> Result<(), Error> {
234    let err = unsafe { clReleaseCommandQueue(cq.0) };
235    if err != 0 {
236        return Err(OCLErrorKind::from_value(err).into());
237    }
238    Ok(())
239}
240
241impl Drop for CommandQueue {
242    fn drop(&mut self) {
243        release_command_queue(self).unwrap();
244    }
245}
246
247pub fn create_command_queue(context: &Context, device: CLIntDevice) -> Result<CommandQueue, Error> {
248    let mut err = 0;
249    let r = unsafe { clCreateCommandQueue(context.0, device.0, 0, &mut err) };
250
251    if err != 0 {
252        return Err(Error::from(OCLErrorKind::from_value(err)));
253    }
254    Ok(CommandQueue(r))
255}
256
257pub fn finish(cq: CommandQueue) {
258    unsafe { clFinish(cq.0) };
259}
260
261#[derive(Debug)]
262pub struct Event(pub cl_event);
263
264impl Event {
265    pub fn wait(self) -> Result<(), Error> {
266        wait_for_event(self)
267    }
268}
269
270impl Drop for Event {
271    fn drop(&mut self) {
272        release_event(self).unwrap();
273    }
274}
275
276pub fn wait_for_event(event: Event) -> Result<(), Error> {
277    let event_arr = [event];
278
279    let value = unsafe { clWaitForEvents(1, event_arr.as_ptr() as *mut cl_event) };
280
281    if value != 0 {
282        return Err(Error::from(OCLErrorKind::from_value(value)));
283    }
284
285    Ok(())
286}
287
288pub fn wait_for_events(events: &[Event]) -> Result<(), Error> {
289    let value = unsafe { clWaitForEvents(events.len() as u32, events.as_ptr() as *mut cl_event) };
290
291    if value != 0 {
292        return Err(Error::from(OCLErrorKind::from_value(value)));
293    }
294
295    Ok(())
296}
297
298pub fn release_event(event: &mut Event) -> Result<(), Error> {
299    let value = unsafe { clReleaseEvent(event.0) };
300    if value != 0 {
301        return Err(Error::from(OCLErrorKind::from_value(value)));
302    }
303    Ok(())
304}
305
/// Memory flags whose discriminants are distinct single-bit values.
///
/// Flags combine with `|` into the plain `u64` bit-field that
/// `create_buffer` takes as its `flag` argument, e.g.
/// `MemFlags::MemReadWrite | MemFlags::MemCopyHostPtr`.
pub enum MemFlags {
    MemReadWrite = 1,
    MemWriteOnly = 1 << 1,
    MemReadOnly = 1 << 2,
    MemUseHostPtr = 1 << 3,
    MemAllocHostPtr = 1 << 4,
    MemCopyHostPtr = 1 << 5,
    MemHostWriteOnly = 1 << 7,
    MemHostReadOnly = 1 << 8,
    MemHostNoAccess = 1 << 9,
}

impl core::ops::BitOr for MemFlags {
    type Output = u64;

    /// Combines two flags into a raw bit-field.
    fn bitor(self, rhs: Self) -> Self::Output {
        self as u64 | rhs as u64
    }
}

impl core::ops::BitOr<MemFlags> for u64 {
    type Output = u64;

    /// Adds a flag to an already combined bit-field, so that chains of three
    /// or more flags (`a | b | c`) compile.
    fn bitor(self, rhs: MemFlags) -> Self::Output {
        self | rhs as u64
    }
}
325
326pub fn create_buffer<T>(
327    context: &Context,
328    flag: u64,
329    size: usize,
330    data: Option<&[T]>,
331) -> Result<*mut c_void, Error> {
332    let mut err = 0;
333    let host_ptr = match data {
334        Some(d) => d.as_ptr() as cl_mem,
335        None => std::ptr::null_mut(),
336    };
337    let r = unsafe {
338        clCreateBuffer(
339            context.0,
340            flag,
341            size * core::mem::size_of::<T>(),
342            host_ptr,
343            &mut err,
344        )
345    };
346
347    if err != 0 {
348        return Err(Error::from(OCLErrorKind::from_value(err)));
349    }
350    Ok(r)
351}
352
353/// # Safety
354/// valid mem object
355pub unsafe fn release_mem_object(ptr: *mut c_void) -> Result<(), Error> {
356    let value = clReleaseMemObject(ptr);
357    if value != 0 {
358        return Err(Error::from(OCLErrorKind::from_value(value)));
359    }
360    Ok(())
361}
362
363pub fn retain_mem_object(mem: *mut c_void) -> Result<(), Error> {
364    let value = unsafe { clRetainMemObject(mem) };
365    if value != 0 {
366        return Err(Error::from(OCLErrorKind::from_value(value)));
367    }
368    Ok(())
369}
370
371/// # Safety
372/// valid mem object
373pub unsafe fn enqueue_write_buffer<T>(
374    cq: &CommandQueue,
375    mem: *mut c_void,
376    data: &[T],
377    block: bool,
378) -> Result<Event, Error> {
379    let mut events = [std::ptr::null_mut(); 1];
380
381    let value = clEnqueueWriteBuffer(
382        cq.0,
383        mem,
384        block as u32,
385        0,
386        data.len() * core::mem::size_of::<T>(),
387        data.as_ptr() as *mut c_void,
388        0,
389        std::ptr::null(),
390        events.as_mut_ptr() as *mut cl_event,
391    );
392    if value != 0 {
393        return Err(Error::from(OCLErrorKind::from_value(value)));
394    }
395    Ok(Event(events[0]))
396}
397
398/// # Safety
399/// valid mem object
400pub unsafe fn enqueue_read_buffer<T>(
401    cq: &CommandQueue,
402    mem: *mut c_void,
403    data: &mut [T],
404    block: bool,
405) -> Result<Event, Error> {
406    let mut events = [std::ptr::null_mut(); 1];
407    let value = clEnqueueReadBuffer(
408        cq.0,
409        mem,
410        block as u32,
411        0,
412        data.len() * core::mem::size_of::<T>(),
413        data.as_ptr() as *mut c_void,
414        0,
415        std::ptr::null(),
416        events.as_mut_ptr() as *mut cl_event,
417    );
418    if value != 0 {
419        return Err(Error::from(OCLErrorKind::from_value(value)));
420    }
421    Ok(Event(events[0]))
422}
423
424pub fn enqueue_copy_buffer<T>(
425    cq: &CommandQueue,
426    src_mem: *mut c_void,
427    dst_mem: *mut c_void,
428    src_offset: usize,
429    dst_offset: usize,
430    size: usize,
431) -> Result<(), Error> {
432    let mut events = [std::ptr::null_mut(); 1];
433    let value = unsafe {
434        clEnqueueCopyBuffer(
435            cq.0,
436            src_mem,
437            dst_mem,
438            src_offset * size_of::<T>(),
439            dst_offset * size_of::<T>(),
440            size * size_of::<T>(),
441            0,
442            std::ptr::null(),
443            events.as_mut_ptr() as *mut cl_event,
444        )
445    };
446    if value != 0 {
447        return Err(Error::from(OCLErrorKind::from_value(value)));
448    }
449    wait_for_event(Event(events[0]))
450}
451
452pub fn enqueue_copy_buffers<T, I>(
453    cq: &CommandQueue,
454    src_mem: *mut c_void,
455    dst_mem: *mut c_void,
456    to_copy: I,
457) -> Result<(), Error>
458where
459    I: IntoIterator<Item = (usize, usize, usize)>,
460{
461    let to_copy = to_copy.into_iter();
462    let mut events = match to_copy.size_hint() {
463        (0, None) => Vec::new(),
464        (min, None) => Vec::with_capacity(min),
465        (_, Some(max)) => Vec::with_capacity(max),
466    };
467
468    for (src_offset, dst_offset, size) in to_copy {
469        let event = [std::ptr::null_mut(); 1];
470        events.push(event);
471
472        let value = unsafe {
473            clEnqueueCopyBuffer(
474                cq.0,
475                src_mem,
476                dst_mem,
477                src_offset * size_of::<T>(),
478                dst_offset * size_of::<T>(),
479                size * size_of::<T>(),
480                0,
481                std::ptr::null(),
482                events.last_mut().unwrap().as_mut_ptr() as *mut cl_event,
483            )
484        };
485
486        if value != 0 {
487            return Err(Error::from(OCLErrorKind::from_value(value)));
488        }
489    }
490
491    // borrow to avoid moving while the event is still in progress
492    for event in &events {
493        wait_for_event(Event(event[0]))?;
494    }
495
496    Ok(())
497}
498
499#[inline]
500pub fn enqueue_full_copy_buffer<T>(
501    cq: &CommandQueue,
502    src_mem: *mut c_void,
503    dst_mem: *mut c_void,
504    size: usize,
505) -> Result<(), Error> {
506    enqueue_copy_buffer::<T>(cq, src_mem, dst_mem, 0, 0, size)
507}
508
509pub fn unified_ptr<T>(cq: &CommandQueue, ptr: *mut c_void, len: usize) -> Result<*mut T, Error> {
510    unsafe { enqueue_map_buffer::<T>(cq, ptr, true, 2 | 1, 0, len).map(|ptr| ptr as *mut T) }
511}
512
513/// map_flags: Read: 1, Write: 2,
514/// # Safety
515/// valid mem object
516pub unsafe fn enqueue_map_buffer<T>(
517    cq: &CommandQueue,
518    buffer: *mut c_void,
519    block: bool,
520    map_flags: u64,
521    offset: usize,
522    len: usize,
523) -> Result<*mut c_void, Error> {
524    let offset = offset * core::mem::size_of::<T>();
525    let size = len * core::mem::size_of::<T>();
526
527    let mut event = [std::ptr::null_mut(); 1];
528
529    let mut err = 0;
530
531    let ptr = clEnqueueMapBuffer(
532        cq.0,
533        buffer,
534        block as u32,
535        map_flags,
536        offset,
537        size,
538        0,
539        std::ptr::null(),
540        event.as_mut_ptr() as *mut cl_event,
541        &mut err,
542    );
543
544    if err != 0 {
545        return Err(Error::from(OCLErrorKind::from_value(err)));
546    }
547
548    let e = Event(event[0]);
549    wait_for_event(e)?;
550    Ok(ptr)
551}
552/*
553pub fn enqueue_fill_buffer<T>(cq: &CommandQueue, mem: &Mem, pattern: Vec<T>) -> Event {
554    let mut events = vec![std::ptr::null_mut();1];
555    let offset = 0;
556    let pattern_size = core::mem::size_of::<T>();
557    let size = pattern_size*pattern.len();
558    let err = unsafe {clEnqueueFillBuffer(cq.0, mem.0, pattern.as_ptr() as *mut c_void, pattern_size, offset, size, 0, std::ptr::null(), events.as_mut_ptr() as *mut cl_event)};
559    println!("err enq copy bff: {}", err);
560    Event(events[0])
561}
562*/
563
564pub struct Program(pub cl_program);
565
566impl Drop for Program {
567    fn drop(&mut self) {
568        release_program(self).unwrap()
569    }
570}
571
/// Numeric selectors for program info queries (presumably the `param_name`
/// values of `clGetProgramInfo` — confirm against the OpenCL headers).
enum ProgramInfo {
    /// Discriminant `0x1165`.
    BinarySizes = 0x1165,
    /// Discriminant `0x1166`.
    Binaries = 0x1166,
}
576
/// Numeric selectors for program build info queries (presumably the
/// `param_name` values of `clGetProgramBuildInfo` — confirm against the OpenCL
/// headers).
enum ProgramBuildInfo {
    /// Discriminant `0x1181`.
    Status = 0x1181,
    /// Discriminant `0x1183`.
    BuildLog = 0x1183,
}
581
582pub fn release_program(program: &mut Program) -> Result<(), Error> {
583    let value = unsafe { clReleaseProgram(program.0) };
584    if value != 0 {
585        return Err(Error::from(OCLErrorKind::from_value(value)));
586    }
587    Ok(())
588}
589
590pub fn create_program_with_source(context: &Context, src: &str) -> Result<Program, Error> {
591    let mut err = 0;
592    let cs = CString::new(src).expect("No cstring for you!");
593    let lens = vec![cs.as_bytes().len()];
594    let cstring: Vec<*const _> = vec![cs.as_ptr()];
595    let r = unsafe {
596        clCreateProgramWithSource(
597            context.0,
598            1,
599            cstring.as_ptr() as *const *const _,
600            lens.as_ptr() as *const usize,
601            &mut err,
602        )
603    };
604    if err != 0 {
605        return Err(Error::from(OCLErrorKind::from_value(err)));
606    }
607    Ok(Program(r))
608}
609
610pub fn build_program(
611    program: &Program,
612    devices: &[CLIntDevice],
613    options: Option<&str>,
614) -> Result<(), Error> {
615    let len = devices.len();
616
617    let err = if let Some(options) = options {
618        let options = CString::new(options).unwrap();
619        unsafe {
620            clBuildProgram(
621                program.0,
622                len as u32,
623                devices.as_ptr() as *const *mut c_void,
624                options.as_ptr(),
625                std::ptr::null_mut(),
626                std::ptr::null_mut(),
627            )
628        }
629    } else {
630        unsafe {
631            clBuildProgram(
632                program.0,
633                len as u32,
634                devices.as_ptr() as *const *mut c_void,
635                std::ptr::null(),
636                std::ptr::null_mut(),
637                std::ptr::null_mut(),
638            )
639        }
640    };
641    if err != 0 {
642        return Err(Error::from(OCLErrorKind::from_value(err)));
643    }
644    Ok(())
645}
646
647#[derive(Debug)]
648pub struct Kernel(pub cl_kernel);
649
650impl Drop for Kernel {
651    fn drop(&mut self) {
652        release_kernel(self).unwrap()
653    }
654}
655
656unsafe impl Send for Kernel {}
657unsafe impl Sync for Kernel {}
658
659pub fn create_kernel(program: &Program, str: &str) -> Result<Kernel, Error> {
660    let mut err = 0;
661    let cstring = CString::new(str).unwrap();
662    let kernel = unsafe { clCreateKernel(program.0, cstring.as_ptr(), &mut err) };
663    if err != 0 {
664        return Err(Error::from(OCLErrorKind::from_value(err)));
665    }
666    Ok(Kernel(kernel))
667}
668pub fn create_kernels_in_program(program: &Program) -> Result<Vec<Kernel>, Error> {
669    let mut n_kernels: u32 = 0;
670    let value =
671        unsafe { clCreateKernelsInProgram(program.0, 0, std::ptr::null_mut(), &mut n_kernels) };
672    if value != 0 {
673        return Err(Error::from(OCLErrorKind::from_value(value)));
674    }
675
676    let mut vec: Vec<usize> = vec![0; n_kernels as usize];
677    let (ptr, len, cap) = (vec.as_mut_ptr(), vec.len(), vec.capacity());
678
679    let mut kernels: Vec<Kernel> = unsafe {
680        core::mem::forget(vec);
681        Vec::from_raw_parts(ptr as *mut Kernel, len, cap)
682    };
683    let value = unsafe {
684        clCreateKernelsInProgram(
685            program.0,
686            n_kernels,
687            kernels.as_mut_ptr() as *mut cl_kernel,
688            std::ptr::null_mut(),
689        )
690    };
691    if value != 0 {
692        return Err(Error::from(OCLErrorKind::from_value(value)));
693    }
694
695    Ok(kernels)
696}
697
698pub fn release_kernel(kernel: &mut Kernel) -> Result<(), Error> {
699    let value = unsafe { clReleaseKernel(kernel.0) };
700    if value != 0 {
701        return Err(Error::from(OCLErrorKind::from_value(value)));
702    }
703    Ok(())
704}
705
706pub fn set_kernel_arg(
707    kernel: &Kernel,
708    index: usize,
709    arg: *const c_void,
710    arg_size: usize,
711    is_num: bool,
712) -> Result<(), Error> {
713    let ptr = if is_num {
714        arg
715    } else {
716        &arg as *const *const c_void as *const c_void
717    };
718
719    let value = unsafe { clSetKernelArg(kernel.0, index as u32, arg_size, ptr) };
720    if value != 0 {
721        return Err(Error::from(OCLErrorKind::from_value(value)));
722    }
723    Ok(())
724}
725
726pub fn enqueue_nd_range_kernel(
727    cq: &CommandQueue,
728    kernel: &Kernel,
729    wd: usize,
730    gws: &[usize; 3],
731    lws: Option<&[usize; 3]>,
732    offset: Option<[usize; 3]>,
733) -> Result<(), Error> {
734    let mut events = [std::ptr::null_mut(); 1];
735    let lws = match lws {
736        Some(lws) => lws.as_ptr(),
737        None => std::ptr::null(),
738    };
739    let offset = match offset {
740        Some(offset) => offset.as_ptr(),
741        None => std::ptr::null(),
742    };
743
744    let value = unsafe {
745        clEnqueueNDRangeKernel(
746            cq.0,
747            kernel.0,
748            wd as u32,
749            offset,
750            gws.as_ptr(),
751            lws,
752            0,
753            std::ptr::null(),
754            events.as_mut_ptr() as *mut cl_event,
755        )
756    };
757    if value != 0 {
758        return Err(Error::from(OCLErrorKind::from_value(value)));
759    }
760    let e = Event(events[0]);
761    // e.release();
762    // Ok(())
763    wait_for_event(e)
764}