custos_math/ops/
clip.rs

1use custos::{impl_stack, number::Number, CDatatype, Device, MainMemory, Shape, CPU};
2
3#[cfg(feature = "stack")]
4use custos::Stack;
5
6#[cfg(feature = "opencl")]
7use custos::OpenCL;
8
9use crate::Matrix;
10#[cfg(feature = "cuda")]
11use custos::{cuda::launch_kernel1d, Buffer, CUDA};
12
13impl<'a, T, S: Shape, D: ClipOp<T, S>> Matrix<'a, T, D, S> {
14    pub fn clip(&self, min: T, max: T) -> Matrix<T, D, S> {
15        self.device().clip(self, min, max)
16    }
17}
18
19pub trait ClipOp<T, S: Shape = (), D: Device = Self>: Device {
20    fn clip(&self, x: &Matrix<T, D, S>, min: T, max: T) -> Matrix<T, Self, S>;
21}
22
23#[impl_stack]
24impl<T: Number, D: MainMemory, S: Shape> ClipOp<T, S, D> for CPU {
25    fn clip(&self, x: &Matrix<T, D, S>, min: T, max: T) -> Matrix<T, Self, S> {
26        let mut out = self.retrieve(x.size(), x.node.idx);
27        let out_slice = &mut out[..];
28
29        for (idx, value) in x.iter().enumerate() {
30            if *value < min {
31                out_slice[idx] = min;
32            } else if *value > max {
33                out_slice[idx] = max;
34            } else {
35                out_slice[idx] = *value;
36            }
37        }
38        (out, x.dims()).into()
39    }
40}
41
/// Clips every element of `x` on an OpenCL device.
///
/// Elements below `min` become `min`, elements above `max` become `max`, and
/// all other elements are copied unchanged into the returned matrix, which
/// has the dimensions of `x`.
///
/// # Errors
/// Returns the error reported by `enqueue_kernel` if the kernel cannot be
/// built or launched.
#[cfg(feature = "opencl")]
fn cl_clip<'a, T: CDatatype>(
    device: &'a OpenCL,
    x: &Matrix<T, OpenCL>,
    min: T,
    max: T,
) -> custos::Result<Matrix<'a, T, OpenCL>> {
    use custos::opencl::enqueue_kernel;

    // The bounds are passed as kernel arguments instead of being formatted
    // into the source text. Rust's `Display` output for floats (`NaN`, `inf`,
    // or a long digit string without exponent for large magnitudes) is not a
    // valid OpenCL C literal, so such bounds used to break the kernel build.
    // It also keeps the source independent of the runtime values.
    let src = format!(
        "
        __kernel void clip(__global const {datatype}* input, const {datatype} lower, const {datatype} upper, __global {datatype}* output) {{
            size_t id = get_global_id(0);
            {datatype} value = input[id];
            if (value < lower) {{
                output[id] = lower;
            }} else if (value > upper) {{
                output[id] = upper;
            }} else {{
                output[id] = value;
            }}
        }}
    ",
        datatype = T::as_c_type_str()
    );

    let out = device.retrieve::<T, ()>(x.size(), x.node.idx);
    // One work item per element; the argument order matches the kernel signature.
    enqueue_kernel(
        device,
        &src,
        [x.size(), 0, 0],
        None,
        &[x, &min, &max, &out],
    )?;
    Ok((out, x.dims()).into())
}
74
#[cfg(feature = "opencl")]
impl<T: CDatatype> ClipOp<T> for OpenCL {
    /// Clips `x` on the OpenCL device by running `cl_clip`.
    ///
    /// # Panics
    /// Panics if `cl_clip` returns an error.
    fn clip(&self, x: &Matrix<T, Self>, min: T, max: T) -> Matrix<T, Self> {
        let clipped = cl_clip(self, x, min, max);
        clipped.unwrap()
    }
}
81
/// Clips every element of `x` on a CUDA device.
///
/// Elements below `min` become `min`, elements above `max` become `max`, and
/// all other elements are copied unchanged into the returned buffer, which is
/// requested with `x.len()` elements.
///
/// # Errors
/// Returns the error reported by `launch_kernel1d` if the kernel cannot be
/// compiled or launched.
#[cfg(feature = "cuda")]
pub fn cu_clip<'a, T: CDatatype>(
    device: &'a CUDA,
    x: &Buffer<T, CUDA>,
    min: T,
    max: T,
) -> custos::Result<Buffer<'a, T, CUDA>> {
    // The lower bound is tested first, in the same order as the CPU and
    // OpenCL implementations, so all backends agree even when `min > max`.
    let src = format!(
        r#"extern "C" __global__ void clip({datatype}* lhs, {datatype} min, {datatype} max, {datatype}* out, int numElements)
            {{
                int idx = blockDim.x * blockIdx.x + threadIdx.x;
                if (idx < numElements) {{
                    {datatype} value = lhs[idx];
                    if (value < min) {{
                        out[idx] = min;
                    }} else if (value > max) {{
                        out[idx] = max;
                    }} else {{
                        out[idx] = value;
                    }}
                }}
            }}
    "#,
        datatype = T::as_c_type_str()
    );

    let out = device.retrieve::<T, ()>(x.len(), x);
    // The argument order matches the kernel signature.
    // NOTE(review): `x.len()` is passed by reference while the kernel declares
    // `int numElements` — confirm the launch helper handles the width difference.
    launch_kernel1d(
        x.len(),
        device,
        &src,
        "clip",
        &[x, &min, &max, &out, &x.len()],
    )?;
    Ok(out)
}
119
#[cfg(feature = "cuda")]
impl<T: CDatatype> ClipOp<T> for CUDA {
    /// Clips `x` on the CUDA device by running `cu_clip` and wraps the
    /// resulting buffer in a matrix with the dimensions of `x`.
    ///
    /// # Panics
    /// Panics if `cu_clip` returns an error.
    fn clip(&self, x: &Matrix<T, CUDA>, min: T, max: T) -> Matrix<T, CUDA> {
        cu_clip(self, x, min, max)
            .map(|clipped| (clipped, x.dims()).into())
            .unwrap()
    }
}
126}