ocl_stream/
lib.rs

1/*
2 * opencl stream executor
3 * Copyright (C) 2021 trivernis
4 * See LICENSE for more information
5 */
6
7pub mod executor;
8pub mod traits;
9pub mod utils;
10
11pub use executor::stream;
12pub use executor::OCLStreamExecutor;
13// reexport the ocl crate
14pub use ocl;
15
#[cfg(test)]
mod tests {
    use crate::executor::OCLStreamExecutor;
    use crate::traits::*;
    use crate::utils::shared_buffer::SharedBuffer;
    use ocl::ProQue;
    use std::ops::Deref;

    /// Number of buffer elements, global work items, and kernel loop iterations in the test.
    const ELEMENT_COUNT: u32 = 100;

    /// OpenCL source for the test kernel.
    ///
    /// Each work item adds `0 + 1 + ... + (limit - 1)` to its own element of `NUMBERS`.
    /// The element, accumulator, and loop-counter types are `uint` so they match the
    /// `u32` buffer and the `u32` limit passed from the host.
    const BENCH_INT_SRC: &str = "\
        __kernel void bench_int(const uint limit, __global uint *NUMBERS) {
            uint id = get_global_id(0);
            uint num = NUMBERS[id];
            for (uint i = 0; i < limit; i++) {
                num += i;
            }
            NUMBERS[id] = num;
        }";

    /// Runs `bench_int` once through the stream executor and checks two things.
    /// Every streamed value must equal the expected per-element sum.
    /// Exactly `ELEMENT_COUNT` values must arrive.
    #[test]
    fn it_streams_ocl_calculations() {
        let pro_que = ProQue::builder()
            .src(BENCH_INT_SRC)
            .dims(1)
            .build()
            .unwrap();
        let stream_executor = OCLStreamExecutor::new(pro_que);
        let input_buffer: SharedBuffer<u32> = vec![0u32; ELEMENT_COUNT as usize]
            .to_shared_buffer(stream_executor.pro_que())
            .unwrap();

        // NOTE(review): the first argument presumably bounds the result channel's
        // capacity — confirm against `OCLStreamExecutor::execute_bounded`.
        let mut stream = stream_executor.execute_bounded(10, move |ctx| {
            let pro_que = ctx.pro_que();
            let tx = ctx.sender();

            // The guard returned by `inner().lock()` is a temporary of this statement,
            // so the lock is released once the kernel has been built.
            let kernel = pro_que
                .kernel_builder("bench_int")
                .arg(ELEMENT_COUNT)
                .arg(input_buffer.inner().lock().deref())
                .global_work_size(ELEMENT_COUNT as usize)
                .build()?;
            // SAFETY: the kernel only reads and writes `NUMBERS[get_global_id(0)]`. The
            // global work size equals the length of the vector the shared buffer was
            // created from. This assumes `to_shared_buffer` allocates one device element
            // per vector entry, so every access stays in bounds.
            unsafe {
                kernel.enq()?;
            }

            let mut result = vec![0u32; ELEMENT_COUNT as usize];
            input_buffer.read(&mut result)?;

            for num in result {
                tx.send(num)?;
            }

            Ok(())
        });

        // Every element starts at 0 and the kernel adds 0 + 1 + ... + (ELEMENT_COUNT - 1).
        // The sum is computed exactly in integers rather than through an f32 formula.
        let expected: u32 = (0..ELEMENT_COUNT).sum();

        // Read until `next` reports an error. This is presumably the end of the stream once
        // the producer closure returns — NOTE(review): confirm against `OCLStream::next`.
        let mut count = 0;
        while let Ok(n) = stream.next() {
            assert_eq!(n, expected);
            count += 1;
        }
        assert_eq!(count, ELEMENT_COUNT);
    }
}