1use custos::{impl_stack, number::Number, CDatatype, Device, MainMemory, Shape, CPU};
2
3#[cfg(feature = "stack")]
4use custos::Stack;
5
6#[cfg(feature = "opencl")]
7use custos::OpenCL;
8
9use crate::Matrix;
10#[cfg(feature = "cuda")]
11use custos::{cuda::launch_kernel1d, Buffer, CUDA};
12
13impl<'a, T, S: Shape, D: ClipOp<T, S>> Matrix<'a, T, D, S> {
14 pub fn clip(&self, min: T, max: T) -> Matrix<T, D, S> {
15 self.device().clip(self, min, max)
16 }
17}
18
19pub trait ClipOp<T, S: Shape = (), D: Device = Self>: Device {
20 fn clip(&self, x: &Matrix<T, D, S>, min: T, max: T) -> Matrix<T, Self, S>;
21}
22
23#[impl_stack]
24impl<T: Number, D: MainMemory, S: Shape> ClipOp<T, S, D> for CPU {
25 fn clip(&self, x: &Matrix<T, D, S>, min: T, max: T) -> Matrix<T, Self, S> {
26 let mut out = self.retrieve(x.size(), x.node.idx);
27 let out_slice = &mut out[..];
28
29 for (idx, value) in x.iter().enumerate() {
30 if *value < min {
31 out_slice[idx] = min;
32 } else if *value > max {
33 out_slice[idx] = max;
34 } else {
35 out_slice[idx] = *value;
36 }
37 }
38 (out, x.dims()).into()
39 }
40}
41
/// Clips every element of `x` to `min`/`max` with an OpenCL kernel.
///
/// The comparison order matches the CPU implementation: values below `min` become
/// `min`, then values above `max` become `max`, and everything else is copied.
///
/// # Errors
/// Returns any error produced while building or enqueueing the kernel.
#[cfg(feature = "opencl")]
fn cl_clip<'a, T: CDatatype>(
    device: &'a OpenCL,
    x: &Matrix<T, OpenCL>,
    min: T,
    max: T,
) -> custos::Result<Matrix<'a, T, OpenCL>> {
    use custos::opencl::enqueue_kernel;

    // The bounds are passed as kernel arguments instead of being formatted into the
    // source with `#define`. This gives two benefits:
    // - The source depends only on `T`, so it no longer changes for every distinct
    //   (min, max) pair. A changed source requires a separate program build.
    // - Non-finite bounds still work. Their `Display` output ("inf", "NaN") is not
    //   valid OpenCL C and made the kernel fail to compile.
    let src = format!(
        "
        __kernel void clip(__global const {datatype}* input,
                           const {datatype} min_value,
                           const {datatype} max_value,
                           __global {datatype}* output) {{
            size_t id = get_global_id(0);
            {datatype} value = input[id];
            if (value < min_value) {{
                output[id] = min_value;
            }} else if (value > max_value) {{
                output[id] = max_value;
            }} else {{
                output[id] = value;
            }}
        }}
        ",
        datatype = T::as_c_type_str()
    );

    let out = device.retrieve::<T, ()>(x.size(), x.node.idx);
    // The global work size equals the element count, so every work item maps to exactly
    // one element and the kernel needs no bounds check.
    enqueue_kernel(device, &src, [x.size(), 0, 0], None, &[x, &min, &max, &out])?;
    Ok((out, x.dims()).into())
}
74
#[cfg(feature = "opencl")]
impl<T: CDatatype> ClipOp<T> for OpenCL {
    /// Clips `x` on the OpenCL device by delegating to `cl_clip`.
    ///
    /// # Panics
    /// Panics if building or enqueueing the OpenCL kernel fails.
    fn clip(&self, x: &Matrix<T, Self>, min: T, max: T) -> Matrix<T, Self> {
        let clipped = cl_clip(self, x, min, max);
        clipped.unwrap()
    }
}
81
/// Clips every element of `x` to `min`/`max` with a CUDA kernel.
///
/// Values below `min` become `min`. Otherwise, values above `max` become `max`.
/// Everything else is copied. This is the same precedence as the CPU and OpenCL
/// backends, so all backends agree even when `min > max`.
///
/// # Errors
/// Returns any error produced while compiling or launching the kernel.
#[cfg(feature = "cuda")]
pub fn cu_clip<'a, T: CDatatype>(
    device: &'a CUDA,
    x: &Buffer<T, CUDA>,
    min: T,
    max: T,
) -> custos::Result<Buffer<'a, T, CUDA>> {
    // `numElements` receives a pointer to a host `usize` (`&x.len()` below). Declaring it
    // `unsigned long long` makes the kernel read the whole 8-byte value instead of the
    // low 4 bytes that an `int` parameter would take (assumes a 64-bit host, where
    // `usize` is 8 bytes). The thread index is computed in the same width so it cannot
    // overflow `int` on very large launches.
    let src = format!(
        r#"extern "C" __global__ void clip({datatype}* input, {datatype} min_value, {datatype} max_value, {datatype}* out, unsigned long long numElements)
        {{
            unsigned long long idx = (unsigned long long) blockDim.x * blockIdx.x + threadIdx.x;
            if (idx < numElements) {{
                {datatype} value = input[idx];
                if (value < min_value) {{
                    out[idx] = min_value;
                }} else if (value > max_value) {{
                    out[idx] = max_value;
                }} else {{
                    out[idx] = value;
                }}
            }}
        }}
        "#,
        datatype = T::as_c_type_str()
    );

    let out = device.retrieve::<T, ()>(x.len(), x);
    launch_kernel1d(
        x.len(),
        device,
        &src,
        "clip",
        &[x, &min, &max, &out, &x.len()],
    )?;
    Ok(out)
}
119
#[cfg(feature = "cuda")]
impl<T: CDatatype> ClipOp<T> for CUDA {
    /// Clips `x` on the CUDA device via `cu_clip`, then re-attaches the dimensions
    /// of `x` to the resulting buffer.
    ///
    /// # Panics
    /// Panics if compiling or launching the CUDA kernel fails.
    fn clip(&self, x: &Matrix<T, CUDA>, min: T, max: T) -> Matrix<T, CUDA> {
        let clipped = cu_clip(self, x, min, max);
        let buffer = clipped.unwrap();
        let dims = x.dims();
        (buffer, dims).into()
    }
}