1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
use custos::{
    opencl::enqueue_kernel,
    Buffer, CDatatype, CLDevice, cache::Cache,
};

/// Maps an implementing type to a string name.
///
/// No implementations are visible in this section of the file. The commented-out
/// blanket impls that follow this declaration sketch the intended mapping: the
/// OpenCL type string for `GenericOCL` types, and `"undefined"` for everything else.
///
/// The lifetime `'s` is chosen by the caller and is not tied to any argument.
/// Implementations therefore effectively have to return a `&'static str`.
trait Both {
    /// Returns the name associated with the implementing type.
    fn as_str<'s>() -> &'s str;
}

/*
impl <T: GenericOCL>Both for T {
    fn as_str<'a>() -> &'a str {
        T::as_ocl_type_str()
    }
}


impl <T: !GenericOCL>Both for T {
    fn as_str<'a, >() -> &'a str {
        "undefined"
    }
}
*/

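// TODO(idea): negative bounds such as `T: !GenericOCL` (above) are not valid Rust. A possible
// fallback is to compare `std::any::TypeId::of::<T>()` against every implementing type.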

/// Element-wise operations. The op/operation is usually "+", "-", "*", "/".
/// "tensor element-wise"
///
/// Generates an OpenCL kernel that computes `out[i] = lhs[i] {op} rhs[i]` for every
/// index of `lhs`, enqueues it on `device`, and returns `out`. The output buffer is
/// taken from custos' [`Cache`].
///
/// `op` is pasted verbatim into the generated OpenCL C source. It must therefore be
/// a binary operator that is valid for `T`. Never pass untrusted text here.
///
/// # Panics
///
/// Panics if `lhs` and `rhs` have different lengths. Without this check the kernel
/// would read past the end of the shorter `rhs`.
///
/// # Errors
///
/// Propagates any error returned by [`enqueue_kernel`]. That function receives the
/// generated source, so a malformed `op` presumably surfaces here.
///
/// # Example
/// ```
/// use custos::{CLDevice, Buffer, VecRead};
/// use custos_math::cl_tew;
///
/// fn main() -> Result<(), custos::Error> {
///     let device = CLDevice::new(0)?;
///     let lhs = Buffer::<i16>::from((&device, [15, 30, 21, 5, 8]));
///     let rhs = Buffer::<i16>::from((&device, [10, 9, 8, 6, 3]));
///
///     let result = cl_tew(&device, &lhs, &rhs, "+")?;
///     assert_eq!(vec![25, 39, 29, 11, 11], device.read(&result));
///     Ok(())
/// }
/// ```
pub fn cl_tew<'a, T: CDatatype>(
    device: &'a CLDevice,
    lhs: &Buffer<T>,
    rhs: &Buffer<T>,
    op: &str,
) -> custos::Result<Buffer<'a, T>> {
    // The kernel reads rhs[id] for every id in 0..lhs.len. A shorter rhs would be an
    // out-of-bounds device read (undefined behaviour), so reject mismatched lengths.
    assert_eq!(
        lhs.len, rhs.len,
        "cl_tew: lhs and rhs must have the same length"
    );

    let src = format!("
        __kernel void eop(__global {datatype}* self, __global const {datatype}* rhs, __global {datatype}* out) {{
            size_t id = get_global_id(0);
            out[id] = self[id]{op}rhs[id];
        }}
    ", datatype=T::as_c_type_str());

    // One work-item per element. The trailing zeros presumably mark unused dimensions.
    let gws = [lhs.len, 0, 0];
    let out = Cache::get::<T, _>(device, lhs.len);
    enqueue_kernel(device, &src, gws, None, &[lhs, rhs, &out])?;
    Ok(out)
}

/// Element-wise "assign" operations. The op/operation is usually "+", "-", "*", "/".
///
/// Generates an OpenCL kernel that computes `lhs[i] = lhs[i] {op} rhs[i]` in place for
/// every index of `lhs`, then enqueues it on `device`.
///
/// `op` is pasted verbatim into the generated OpenCL C source. It must therefore be
/// a binary operator that is valid for `T`. Never pass untrusted text here.
///
/// Empty buffers are a no-op: the function returns `Ok(())` without enqueueing anything.
///
/// # Panics
///
/// Panics if `lhs` and `rhs` have different lengths. Without this check the kernel
/// would read past the end of the shorter `rhs`.
///
/// # Errors
///
/// Propagates any error returned by [`enqueue_kernel`]. That function receives the
/// generated source, so a malformed `op` presumably surfaces here.
///
/// # Example
/// ```
/// use custos::{CLDevice, Buffer, VecRead};
/// use custos_math::cl_tew_self;
///
/// fn main() -> Result<(), custos::Error> {
///     let device = CLDevice::new(0)?;
///     let mut lhs = Buffer::<i16>::from((&device, [15, 30, 21, 5, 8]));
///     let rhs = Buffer::<i16>::from((&device, [10, 9, 8, 6, 3]));
///
///     cl_tew_self(&device, &mut lhs, &rhs, "+")?;
///     assert_eq!(vec![25, 39, 29, 11, 11], device.read(&lhs));
///     Ok(())
/// }
/// ```
pub fn cl_tew_self<T: CDatatype>(
    device: &CLDevice,
    lhs: &mut Buffer<T>,
    rhs: &Buffer<T>,
    op: &str,
) -> custos::Result<()> {
    // The kernel reads rhs[id] for every id in 0..lhs.len. A shorter rhs would be an
    // out-of-bounds device read (undefined behaviour), so reject mismatched lengths.
    assert_eq!(
        lhs.len, rhs.len,
        "cl_tew_self: lhs and rhs must have the same length"
    );

    // Nothing to update. Returning early also avoids enqueueing a zero-sized NDRange,
    // which OpenCL 1.x rejects (CL_INVALID_GLOBAL_WORK_SIZE).
    if lhs.len == 0 {
        return Ok(());
    }

    let src = format!(
        "
        __kernel void eop_self(__global {datatype}* self, __global const {datatype}* rhs) {{
            size_t id = get_global_id(0);
            self[id] = self[id]{op}rhs[id];
        }}
    ",
        datatype = T::as_c_type_str()
    );

    // One work-item per element. The trailing zeros presumably mark unused dimensions.
    let gws = [lhs.len, 0, 0];
    enqueue_kernel(device, &src, gws, None, &[lhs, rhs])?;
    Ok(())
}