1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
mod gemm;
mod tew;
mod switching;

pub use gemm::cl_gemm;
pub use tew::*;
pub use switching::*;

use custos::{
    libs::opencl::{cl_device::CLDevice, KernelOptions},
    Error, CDatatype, opencl::{api::{enqueue_write_buffer, wait_for_event}, KernelArg}, Buffer,
};

use crate::Matrix;

/// Evaluates the C expression `op` for every element of `x` on the OpenCL device and
/// returns the results as a new matrix with the same dimensions as `x`.
///
/// `op` is pasted verbatim into the generated kernel as the right-hand side of the
/// assignment to each output element. Inside it, the current input element is available
/// as the variable `x` (for example `"x * x"`). One work item is launched per element
/// (`x.size()` global work items).
///
/// # Errors
/// Returns an error if building or launching the kernel fails.
///
/// # Panics
/// Panics (via `unwrap`) if the kernel run returns no output buffer.
pub fn cl_str_op<T: CDatatype>(
    device: &CLDevice,
    x: &Matrix<T>,
    op: &str,
) -> Result<Matrix<T>, Error> {
    let src = format!(
        "
        __kernel void str_op(__global const {datatype}* lhs, __global {datatype}* out) {{
            size_t id = get_global_id(0);
            {datatype} x = lhs[id];
            out[id] = {op};
        }}
    ",
        datatype = T::as_c_type_str()
    );

    let launch = KernelOptions::new(device, x.as_buf(), [x.size(), 0, 0], &src)?
        .with_output(x.size());
    let produced = launch.run()?;
    let out_buf = produced.unwrap();

    let result: Matrix<T> = (out_buf, x.dims()).into();
    Ok(result)
}

/// Computes `x[i] <op> scalar` for every element of `x` on the OpenCL device and returns
/// the results as a new matrix with the same dimensions as `x`.
///
/// `op` is pasted verbatim between the element and the scalar in the generated kernel, so
/// it is expected to be a C infix operator such as `"+"`, `"-"`, `"*"` or `"/"`. One work
/// item is launched per element (`x.size()` global work items).
///
/// # Errors
/// Returns an error if building or launching the kernel fails.
///
/// # Panics
/// Panics (via `unwrap`) if the kernel run returns no output buffer.
pub fn cl_scalar_op<T: CDatatype>(
    device: &CLDevice,
    x: &Matrix<T>,
    scalar: T,
    op: &str,
) -> Result<Matrix<T>, Error> {
    let src = format!("
        __kernel void scalar_r_op(__global const {datatype}* x, const {datatype} scalar, __global {datatype}* out) {{
            size_t id = get_global_id(0);
            
            out[id] = x[id]{op}scalar;
        }}
    ", datatype=T::as_c_type_str());

    // Propagate launch errors with `?`, mirroring `cl_str_op`. The output buffer is
    // presumably always present because `with_output` was requested — TODO confirm
    // against custos' `KernelOptions::run`.
    let buf = KernelOptions::new(device, x.as_buf(), [x.size(), 0, 0], &src)?
        .add_arg(&scalar)
        .with_output(x.size())
        .run()?
        .unwrap();
    Ok((buf, x.dims()).into())
}

/// Copies the host slice `data` into the OpenCL buffer `x` using `device`'s command
/// queue, then waits for the write's completion event.
///
/// # Panics
/// Panics if enqueueing the write or waiting on its event fails.
pub fn cl_write<T>(device: &CLDevice, x: &mut Buffer<T>, data: &[T]) {
    // NOTE(review): the safety contract of `enqueue_write_buffer` is not visible here.
    // The `true` flag appears to request a blocking write, so `data` would only need to
    // outlive this call. Callers presumably must ensure `data` fits inside `x`.
    // TODO: confirm both points against custos.
    let write_event =
        unsafe { enqueue_write_buffer(&device.queue(), x.ptr.1, data, true) }.unwrap();
    wait_for_event(write_event).unwrap();
}

/// Lets a [`Matrix`] be passed directly as a kernel argument by exposing the device
/// buffer returned by [`Matrix::as_buf`].
impl<'a, T: Copy> KernelArg<'a, T> for Matrix<T> {
    fn some_buf(&'a self) -> Option<&'a Buffer<T>> {
        Option::from(self.as_buf())
    }
}

/// Lets a borrowed [`Matrix`] be passed as a kernel argument. The returned buffer
/// reference carries the matrix borrow's lifetime `'a`, not the shorter borrow of `self`.
impl<'a, T: Copy> KernelArg<'a, T> for &'a Matrix<T> {
    fn some_buf(&self) -> Option<&'a Buffer<T>> {
        // Copy the inner `&'a Matrix<T>` out of `&self` so the buffer borrow is tied to `'a`.
        let matrix: &'a Matrix<T> = *self;
        Some(matrix.as_buf())
    }
}