1#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
7use super::super::runtime;
8use super::super::shaders;
9use super::GpuDevice;
10
11impl GpuDevice {
12 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
14 pub fn relu(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
15 runtime::block_on(async {
16 self.execute_element_wise_op("ReLU", shaders::RELU_SHADER, input, result, None).await
17 })
18 }
19
20 pub async fn relu_async(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
22 self.execute_element_wise_op("ReLU", shaders::RELU_SHADER, input, result, None).await
23 }
24
25 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
27 pub fn leaky_relu(
28 &self,
29 input: &[f32],
30 result: &mut [f32],
31 negative_slope: f32,
32 ) -> Result<(), String> {
33 runtime::block_on(self.leaky_relu_async(input, result, negative_slope))
34 }
35
36 pub async fn leaky_relu_async(
38 &self,
39 input: &[f32],
40 result: &mut [f32],
41 negative_slope: f32,
42 ) -> Result<(), String> {
43 #[repr(C)]
44 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
45 struct LeakyReluParams {
46 negative_slope: f32,
47 }
48
49 let params = LeakyReluParams { negative_slope };
50 let uniform_data = bytemuck::bytes_of(¶ms);
51
52 self.execute_element_wise_op(
53 "LeakyReLU",
54 shaders::LEAKY_RELU_SHADER,
55 input,
56 result,
57 Some(uniform_data),
58 )
59 .await
60 }
61
62 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
64 pub fn elu(&self, input: &[f32], result: &mut [f32], alpha: f32) -> Result<(), String> {
65 runtime::block_on(self.elu_async(input, result, alpha))
66 }
67
68 pub async fn elu_async(
70 &self,
71 input: &[f32],
72 result: &mut [f32],
73 alpha: f32,
74 ) -> Result<(), String> {
75 #[repr(C)]
76 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
77 struct EluParams {
78 alpha: f32,
79 }
80
81 let params = EluParams { alpha };
82 let uniform_data = bytemuck::bytes_of(¶ms);
83
84 self.execute_element_wise_op("ELU", shaders::ELU_SHADER, input, result, Some(uniform_data))
85 .await
86 }
87
88 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
90 pub fn sigmoid(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
91 runtime::block_on(self.sigmoid_async(input, result))
92 }
93
94 pub async fn sigmoid_async(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
96 self.execute_element_wise_op("Sigmoid", shaders::SIGMOID_SHADER, input, result, None).await
97 }
98
99 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
101 pub fn tanh(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
102 runtime::block_on(self.tanh_async(input, result))
103 }
104
105 pub async fn tanh_async(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
107 self.execute_element_wise_op("Tanh", shaders::TANH_SHADER, input, result, None).await
108 }
109
110 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
112 pub fn swish(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
113 runtime::block_on(self.swish_async(input, result))
114 }
115
116 pub async fn swish_async(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
118 self.execute_element_wise_op("Swish", shaders::SWISH_SHADER, input, result, None).await
119 }
120
121 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
123 pub fn gelu(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
124 runtime::block_on(self.gelu_async(input, result))
125 }
126
127 pub async fn gelu_async(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
129 self.execute_element_wise_op("GELU", shaders::GELU_SHADER, input, result, None).await
130 }
131
132 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
134 pub fn clip(
135 &self,
136 input: &[f32],
137 result: &mut [f32],
138 min_val: f32,
139 max_val: f32,
140 ) -> Result<(), String> {
141 runtime::block_on(self.clip_async(input, result, min_val, max_val))
142 }
143
144 pub async fn clip_async(
146 &self,
147 input: &[f32],
148 result: &mut [f32],
149 min_val: f32,
150 max_val: f32,
151 ) -> Result<(), String> {
152 #[repr(C)]
153 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
154 struct ClipParams {
155 min_val: f32,
156 max_val: f32,
157 }
158
159 let params = ClipParams { min_val, max_val };
160 let uniform_data = bytemuck::bytes_of(¶ms);
161
162 self.execute_element_wise_op(
163 "Clip",
164 shaders::CLIP_SHADER,
165 input,
166 result,
167 Some(uniform_data),
168 )
169 .await
170 }
171
172 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
180 pub fn softmax(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
181 runtime::block_on(async { self.softmax_async(input, result).await })
182 }
183
184 pub async fn softmax_async(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
186 let max_val = self.reduce_max(input).await?;
188
189 let exp_vals = self.compute_exp_subtract(input, max_val).await?;
191
192 let sum_exp = self.reduce_sum(&exp_vals).await?;
194
195 self.normalize_by_sum(&exp_vals, result, sum_exp).await?;
197
198 Ok(())
199 }
200
201 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
209 pub fn log_softmax(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
210 runtime::block_on(async { self.log_softmax_async(input, result).await })
211 }
212
213 pub async fn log_softmax_async(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
215 let max_val = self.reduce_max(input).await?;
217
218 let exp_vals = self.compute_exp_subtract(input, max_val).await?;
220
221 let sum_exp = self.reduce_sum(&exp_vals).await?;
223
224 let log_sum_exp = sum_exp.max(f32::EPSILON).ln();
226
227 #[repr(C)]
228 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
229 struct LogSoftmaxParams {
230 max_val: f32,
231 log_sum_exp: f32,
232 }
233
234 let params = LogSoftmaxParams { max_val, log_sum_exp };
235 let uniform_data = bytemuck::bytes_of(¶ms);
236
237 self.execute_element_wise_op(
238 "LogSoftmax",
239 shaders::LOG_SOFTMAX_SHADER,
240 input,
241 result,
242 Some(uniform_data),
243 )
244 .await?;
245
246 Ok(())
247 }
248
249 pub(super) async fn compute_exp_subtract(
251 &self,
252 input: &[f32],
253 max_val: f32,
254 ) -> Result<Vec<f32>, String> {
255 #[repr(C)]
256 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
257 struct MaxValue {
258 max_val: f32,
259 }
260
261 let params = MaxValue { max_val };
262 let uniform_data = bytemuck::bytes_of(¶ms);
263
264 let mut result = vec![0.0f32; input.len()];
265 self.execute_element_wise_op(
266 "SoftmaxExp",
267 shaders::SOFTMAX_EXP_SHADER,
268 input,
269 &mut result,
270 Some(uniform_data),
271 )
272 .await?;
273
274 Ok(result)
275 }
276
277 pub(super) async fn normalize_by_sum(
279 &self,
280 input: &[f32],
281 result: &mut [f32],
282 sum_val: f32,
283 ) -> Result<(), String> {
284 #[repr(C)]
285 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
286 struct SumValue {
287 sum_val: f32,
288 }
289
290 let params = SumValue { sum_val };
291 let uniform_data = bytemuck::bytes_of(¶ms);
292
293 self.execute_element_wise_op(
294 "SoftmaxNormalize",
295 shaders::SOFTMAX_NORMALIZE_SHADER,
296 input,
297 result,
298 Some(uniform_data),
299 )
300 .await?;
301
302 Ok(())
303 }
304}