1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
mod gemm;
mod switching;
mod tew;
pub use gemm::cl_gemm;
pub use switching::*;
pub use tew::*;
use custos::{
devices::opencl::cl_device::CLDevice,
opencl::{
api::{enqueue_write_buffer, wait_for_event},
enqueue_kernel, AsClCvoidPtr,
},
Buffer, CDatatype, Error, cache::Cache,
};
use crate::Matrix;
/// Runs a unary, element-wise OpenCL kernel whose body is given as a string.
///
/// The generated kernel `str_op` loads `lhs[id]` into a local named `x` and
/// stores the value of the expression `op` into `out[id]`, so `op` should be
/// an OpenCL C expression written in terms of `x`. `op` is spliced into the
/// kernel source verbatim; it is not validated or escaped here.
///
/// # Parameters
/// * `device` - device passed to both `Cache::get` and `enqueue_kernel`.
/// * `x` - input matrix; its `size()` is used as the output length and as the
///   first entry of the global work size.
/// * `op` - expression text inserted as the right-hand side of `out[id] = ...;`.
///
/// # Returns
/// A matrix built from the output buffer and the dimensions of `x`.
///
/// # Errors
/// Propagates any `Error` returned by `enqueue_kernel`.
pub fn cl_str_op<'a, T: CDatatype>(
    device: &'a CLDevice,
    x: &Matrix<T>,
    op: &str,
) -> Result<Matrix<'a, T>, Error> {
    // C-side name of `T`, substituted for every `{datatype}` placeholder below.
    let datatype = T::as_c_type_str();
    let kernel_src = format!(
        "
__kernel void str_op(__global const {datatype}* lhs, __global {datatype}* out) {{
size_t id = get_global_id(0);
{datatype} x = lhs[id];
out[id] = {op};
}}
"
    );

    let elements = x.size();
    let result = Cache::get::<T, _>(device, elements);
    // Kernel arguments are passed in the order of the kernel signature: input, output.
    enqueue_kernel(device, &kernel_src, [elements, 0, 0], None, &[x, &result])?;
    Ok((result, x.dims()).into())
}
/// Runs an element-wise OpenCL kernel combining every element of `x` with a scalar.
///
/// The generated kernel `scalar_r_op` computes `out[id] = x[id]{op}scalar`,
/// with the scalar on the right-hand side. `op` is spliced into the kernel
/// source verbatim (for example an infix operator such as `+`); it is not
/// validated or escaped here.
///
/// # Parameters
/// * `device` - device passed to both `Cache::get` and `enqueue_kernel`.
/// * `x` - input matrix; its `size()` is used as the output length and as the
///   first entry of the global work size.
/// * `scalar` - value passed as the kernel's `scalar` argument.
/// * `op` - text placed between `x[id]` and `scalar` in the kernel body.
///
/// # Returns
/// A matrix built from the output buffer and the dimensions of `x`.
///
/// # Errors
/// Propagates any `Error` returned by `enqueue_kernel`.
pub fn cl_scalar_op<'a, T: CDatatype>(
    device: &'a CLDevice,
    x: &Matrix<T>,
    scalar: T,
    op: &str,
) -> Result<Matrix<'a, T>, Error> {
    // C-side name of `T`, substituted for every `{datatype}` placeholder below.
    let datatype = T::as_c_type_str();
    let kernel_src = format!(
        "
__kernel void scalar_r_op(__global const {datatype}* x, const {datatype} scalar, __global {datatype}* out) {{
size_t id = get_global_id(0);
out[id] = x[id]{op}scalar;
}}
"
    );

    let elements = x.size();
    let result = Cache::get::<T, _>(device, elements);
    // Kernel arguments are passed in the order of the kernel signature: input, scalar, output.
    enqueue_kernel(
        device,
        &kernel_src,
        [elements, 0, 0],
        None,
        &[x, &scalar, &result],
    )?;
    Ok((result, x.dims()).into())
}
/// Writes `data` into the OpenCL buffer behind `x` and waits for the write event.
///
/// # Parameters
/// * `device` - device whose queue the write is enqueued on.
/// * `x` - destination buffer; the write targets `x.ptr.1`.
/// * `data` - host slice handed to `enqueue_write_buffer`.
///
/// # Panics
/// Panics if `enqueue_write_buffer` or `wait_for_event` returns an error.
pub fn cl_write<T>(device: &CLDevice, x: &mut Buffer<T>, data: &[T]) {
    // SAFETY: NOTE(review) - assumes `x.ptr.1` is a live OpenCL buffer with room
    // for `data.len()` elements of `T`; nothing in this function checks that.
    // The trailing `true` is presumably the blocking-write flag - confirm against custos.
    let write_event =
        unsafe { enqueue_write_buffer(&device.queue(), x.ptr.1, data, true) }.unwrap();
    wait_for_event(write_event).unwrap();
}
/// Lets a `Matrix` be passed directly as an OpenCL kernel argument.
impl<T> AsClCvoidPtr for Matrix<'_, T> {
    /// Returns the second element of the matrix's `ptr` tuple as a
    /// `*const c_void`.
    fn as_cvoid_ptr(&self) -> *const std::ffi::c_void {
        self.ptr.1
    }
}
/// Lets a `&Matrix` be passed as an OpenCL kernel argument.
impl<T> AsClCvoidPtr for &Matrix<'_, T> {
    /// Returns the same pointer as the implementation for `Matrix` itself,
    /// i.e. the second element of the referenced matrix's `ptr` tuple.
    fn as_cvoid_ptr(&self) -> *const std::ffi::c_void {
        // `self` is `&&Matrix`; peel one reference and reuse the by-value impl.
        <Matrix<'_, T> as AsClCvoidPtr>::as_cvoid_ptr(*self)
    }
}