1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
use custos::{cache::Cache, get_device, number::Number, CDatatype, CPU};

#[cfg(feature = "opencl")]
use custos::CLDevice;

use crate::Matrix;
#[cfg(feature = "cuda")]
use custos::{cuda::launch_kernel1d, Buffer, CudaDevice};

impl<'a, T: CDatatype> Matrix<'a, T> {
    /// Clips the elements of this matrix against `min` and `max`.
    ///
    /// The call is forwarded to the [`ClipOp`] implementation that
    /// `get_device!` resolves for this matrix's device.
    pub fn clip(&self, min: T, max: T) -> Matrix<T> {
        let device = get_device!(self.device(), ClipOp<T>);
        device.clip(self, min, max)
    }
}

/// Element-wise clipping of a matrix, implemented once per compute device.
///
/// This file provides implementations for `CPU` and, behind the `opencl` and
/// `cuda` feature flags, for `CLDevice` and `CudaDevice`.
pub trait ClipOp<T> {
    /// Returns a matrix with the dimensions of `x` in which every value below
    /// `min` is replaced by `min` and every value above `max` is replaced by
    /// `max`; all other values are copied unchanged.
    fn clip(&self, x: &Matrix<T>, min: T, max: T) -> Matrix<T>;
}

impl<T: Number> ClipOp<T> for CPU {
    /// Clips `x` element by element on the host.
    ///
    /// Values below `min` become `min`, values above `max` become `max`, and
    /// everything else is copied through. The lower bound is tested first.
    /// The output buffer is obtained from `Cache::get` with the size of `x`.
    fn clip(&self, x: &Matrix<T>, min: T, max: T) -> Matrix<T> {
        let mut out = Cache::get::<T, CPU>(self, x.size(), x.node.idx);

        // Walk input and output in lockstep; both hold `x.size()` elements.
        for (dst, &src) in out.as_mut_slice().iter_mut().zip(x.as_slice()) {
            *dst = if src < min {
                min
            } else if src > max {
                max
            } else {
                src
            };
        }

        (out, x.dims()).into()
    }
}

/// Clips `x` against `min` and `max` with an OpenCL kernel.
///
/// The bounds are written into the kernel source as `MIN` / `MAX` macros, so
/// every distinct pair of bounds yields a different source string. The kernel
/// tests the lower bound first, then the upper bound, and otherwise copies the
/// input element.
///
/// # Errors
///
/// Propagates any error returned by `enqueue_kernel`.
#[cfg(feature = "opencl")]
fn cl_clip<'a, T: CDatatype>(
    device: &'a CLDevice,
    x: &Matrix<T>,
    min: T,
    max: T,
) -> custos::Result<Matrix<'a, T>> {
    use custos::opencl::enqueue_kernel;

    let c_type = T::as_c_type_str();
    let kernel_src = format!(
        "
        #define MIN {min}
        #define MAX {max}
        __kernel void clip(__global const {datatype}* input, __global {datatype}* output) {{

            size_t id = get_global_id(0);
            if (input[id] < MIN) {{
                output[id] = MIN;
            }} else if (input[id] > MAX) {{
                output[id] = MAX;
            }} else {{
                output[id] = input[id];
            }} 
        }}
    ",
        datatype = c_type
    );

    let output = Cache::get::<T, _>(device, x.size(), x.node.idx);
    // NOTE(review): `[x.size(), 0, 0]` is presumably the global work size,
    // i.e. one work item per element - confirm against `enqueue_kernel`.
    enqueue_kernel(device, &kernel_src, [x.size(), 0, 0], None, &[x, &output])?;

    Ok((output, x.dims()).into())
}

#[cfg(feature = "opencl")]
impl<T: CDatatype> ClipOp<T> for CLDevice {
    /// Clips `x` on the OpenCL device via `cl_clip`.
    ///
    /// # Panics
    ///
    /// Panics if `cl_clip` returns an error.
    fn clip(&self, x: &Matrix<T>, min: T, max: T) -> Matrix<T> {
        let clipped = cl_clip(self, x, min, max);
        clipped.unwrap()
    }
}

/// Clips the elements of `x` against `min` and `max` with a CUDA kernel.
///
/// Each in-range thread reads one element of `x` and writes `min` if the
/// value is below `min`, `max` if it is above `max`, and the value itself
/// otherwise. The result is written to a buffer obtained from `Cache::get`
/// with the length of `x`.
///
/// The lower bound is tested before the upper bound. This matches the CPU and
/// OpenCL implementations in this file, so all backends produce the same
/// result even when `min > max`.
///
/// # Errors
///
/// Propagates any error returned by `launch_kernel1d`.
#[cfg(feature = "cuda")]
pub fn cu_clip<'a, T: CDatatype>(
    device: &'a CudaDevice,
    x: &Buffer<T>,
    min: T,
    max: T,
) -> custos::Result<Buffer<'a, T>> {
    // The branch order (min first, then max) is deliberate: it keeps this
    // kernel consistent with the CPU loop and the OpenCL kernel.
    let src = format!(
        r#"extern "C" __global__ void clip({datatype}* lhs, {datatype} min, {datatype} max, {datatype}* out, int numElements)
            {{
                int idx = blockDim.x * blockIdx.x + threadIdx.x;
                if (idx < numElements) {{
                    {datatype} value = lhs[idx];
                    if (value < min) {{
                        out[idx] = min;
                    }} else if (value > max) {{
                        out[idx] = max;
                    }} else {{
                        out[idx] = value;
                    }}
                }}
            }}
    "#,
        datatype = T::as_c_type_str()
    );

    let out = Cache::get::<T, _>(device, x.len(), x.node.idx);
    // Kernel arguments follow the parameter order of `clip` above:
    // input, lower bound, upper bound, output, element count.
    launch_kernel1d(
        x.len(),
        device,
        &src,
        "clip",
        &[x, &min, &max, &out, &x.len],
    )?;
    Ok(out)
}

#[cfg(feature = "cuda")]
impl<T: CDatatype> ClipOp<T> for CudaDevice {
    /// Clips `x` on the CUDA device via `cu_clip` and attaches the dimensions
    /// of `x` to the resulting buffer.
    ///
    /// # Panics
    ///
    /// Panics if `cu_clip` returns an error.
    fn clip(&self, x: &Matrix<T>, min: T, max: T) -> Matrix<T> {
        (cu_clip(self, x, min, max).unwrap(), x.dims()).into()
    }
}