1use std::sync::Arc;
15
16use ferrotorch_core::error::{FerrotorchError, FerrotorchResult};
17use ferrotorch_core::gpu_dispatch::{GpuBackend, GpuBufferHandle, GpuRngState};
18
19use crate::buffer::CudaBuffer;
20use crate::device::GpuDevice;
21
/// CUDA implementation of the `GpuBackend` trait.
///
/// Only ordinal 0 is initialised today (see `new`); `device()` errors for any
/// other ordinal.
pub struct CudaBackendImpl {
    // Devices indexed by CUDA ordinal, shared via `Arc` so buffers/streams can
    // keep their device alive.
    devices: Vec<Arc<GpuDevice>>,
}
34
35impl CudaBackendImpl {
36 pub fn new() -> FerrotorchResult<Self> {
43 let device = Arc::new(
44 GpuDevice::new(0).map_err(|e| FerrotorchError::InvalidArgument {
45 message: format!("CUDA init failed: {e}"),
46 })?,
47 );
48 Ok(Self {
49 devices: vec![device],
50 })
51 }
52
53 pub fn default_device(&self) -> FerrotorchResult<&Arc<GpuDevice>> {
55 self.device(0)
56 }
57
58 fn device(&self, ordinal: usize) -> FerrotorchResult<&Arc<GpuDevice>> {
60 self.devices
61 .get(ordinal)
62 .ok_or(FerrotorchError::InvalidArgument {
63 message: format!("CUDA device {ordinal} not available"),
64 })
65 }
66
67 fn wrap_buffer(buf: CudaBuffer<f32>, ordinal: usize) -> GpuBufferHandle {
69 let len = buf.len();
70 GpuBufferHandle::new(Box::new(buf), ordinal, len)
71 }
72
73 fn wrap_buffer_f64(buf: CudaBuffer<f64>, ordinal: usize) -> GpuBufferHandle {
75 let len = buf.len();
76 GpuBufferHandle::new(Box::new(buf), ordinal, len)
77 }
78
79 fn unwrap_buffer(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<f32>> {
81 handle
82 .downcast_ref::<CudaBuffer<f32>>()
83 .ok_or(FerrotorchError::InvalidArgument {
84 message: "GPU handle does not contain a CudaBuffer<f32>".into(),
85 })
86 }
87
88 fn unwrap_buffer_mut(handle: &mut GpuBufferHandle) -> FerrotorchResult<&mut CudaBuffer<f32>> {
90 handle
91 .downcast_mut::<CudaBuffer<f32>>()
92 .ok_or(FerrotorchError::InvalidArgument {
93 message: "GPU handle does not contain a CudaBuffer<f32>".into(),
94 })
95 }
96
97 fn unwrap_buffer_f64_mut(handle: &mut GpuBufferHandle) -> FerrotorchResult<&mut CudaBuffer<f64>> {
99 handle
100 .downcast_mut::<CudaBuffer<f64>>()
101 .ok_or(FerrotorchError::InvalidArgument {
102 message: "GPU handle does not contain a CudaBuffer<f64>".into(),
103 })
104 }
105
106 fn unwrap_buffer_f64(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<f64>> {
108 handle
109 .downcast_ref::<CudaBuffer<f64>>()
110 .ok_or(FerrotorchError::InvalidArgument {
111 message: "GPU handle does not contain a CudaBuffer<f64>".into(),
112 })
113 }
114
115 fn map_gpu_err(e: crate::error::GpuError) -> FerrotorchError {
117 FerrotorchError::InvalidArgument {
118 message: format!("{e}"),
119 }
120 }
121}
122
123impl GpuBackend for CudaBackendImpl {
    /// Expose the concrete backend so callers can downcast from `dyn GpuBackend`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
131
    /// Raw const device pointer for `handle`, or null when the device lookup
    /// or the buffer downcast fails (this API has no `Result` channel).
    ///
    /// NOTE(review): the synchronization token returned by `device_ptr` is
    /// dropped immediately (`_sync`); confirm callers synchronize the stream
    /// before dereferencing the returned pointer.
    fn raw_device_ptr(&self, handle: &GpuBufferHandle) -> *const std::ffi::c_void {
        use cudarc::driver::DevicePtr;
        let dev = match self.device(handle.device_ordinal()) {
            Ok(d) => d,
            Err(_) => return std::ptr::null(),
        };
        let stream = dev.stream();
        // Try the two supported element types in turn.
        if let Ok(buf) = Self::unwrap_buffer(handle) {
            let (ptr, _sync) = buf.inner().device_ptr(&stream);
            ptr as *const std::ffi::c_void
        } else if let Ok(buf) = Self::unwrap_buffer_f64(handle) {
            let (ptr, _sync) = buf.inner().device_ptr(&stream);
            ptr as *const std::ffi::c_void
        } else {
            std::ptr::null()
        }
    }
149
    /// Raw mutable device pointer for `handle`, or null when the device lookup
    /// or the buffer downcast fails (this API has no `Result` channel).
    ///
    /// NOTE(review): as in `raw_device_ptr`, the `_sync` token from
    /// `device_ptr_mut` is discarded — verify stream synchronization happens
    /// at the call site.
    fn raw_device_ptr_mut(&self, handle: &mut GpuBufferHandle) -> *mut std::ffi::c_void {
        use cudarc::driver::DevicePtrMut;
        let ordinal = handle.device_ordinal();
        let dev = match self.device(ordinal) {
            Ok(d) => d,
            Err(_) => return std::ptr::null_mut(),
        };
        let stream = dev.stream();
        // Direct downcasts here (not the unwrap_* helpers) to avoid holding a
        // mutable borrow of `handle` across both attempts.
        if let Some(buf) = handle.downcast_mut::<CudaBuffer<f32>>() {
            let (ptr, _sync) = buf.inner_mut().device_ptr_mut(&stream);
            ptr as *mut std::ffi::c_void
        } else if let Some(buf) = handle.downcast_mut::<CudaBuffer<f64>>() {
            let (ptr, _sync) = buf.inner_mut().device_ptr_mut(&stream);
            ptr as *mut std::ffi::c_void
        } else {
            std::ptr::null_mut()
        }
    }
168
169 fn buffer_elem_size(&self, handle: &GpuBufferHandle) -> usize {
170 if Self::unwrap_buffer(handle).is_ok() {
171 4 } else if Self::unwrap_buffer_f64(handle).is_ok() {
173 8 } else {
175 0
176 }
177 }
178
179 fn cpu_to_gpu(
180 &self,
181 data: &[u8],
182 elem_size: usize,
183 device: usize,
184 ) -> FerrotorchResult<GpuBufferHandle> {
185 let dev = self.device(device)?;
186 match elem_size {
187 4 => {
188 let count = data.len() / 4;
191 let f32_data: &[f32] =
192 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, count) };
193 let buf = crate::transfer::cpu_to_gpu(f32_data, dev).map_err(Self::map_gpu_err)?;
194 Ok(Self::wrap_buffer(buf, device))
195 }
196 8 => {
197 let count = data.len() / 8;
200 let f64_data: &[f64] =
201 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f64, count) };
202 let buf = crate::transfer::cpu_to_gpu(f64_data, dev).map_err(Self::map_gpu_err)?;
203 Ok(Self::wrap_buffer_f64(buf, device))
204 }
205 other => Err(FerrotorchError::InvalidArgument {
206 message: format!("cpu_to_gpu: unsupported elem_size {other} (expected 4 or 8)"),
207 }),
208 }
209 }
210
211 fn cpu_to_gpu_pinned(
212 &self,
213 data: &[u8],
214 elem_size: usize,
215 device: usize,
216 ) -> FerrotorchResult<GpuBufferHandle> {
217 let dev = self.device(device)?;
218 match elem_size {
219 4 => {
220 let count = data.len() / 4;
221 let f32_data: &[f32] =
222 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, count) };
223 let buf = crate::transfer::cpu_to_gpu_pinned(f32_data, dev)
224 .map_err(Self::map_gpu_err)?;
225 Ok(Self::wrap_buffer(buf, device))
226 }
227 8 => {
228 let count = data.len() / 8;
229 let f64_data: &[f64] =
230 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f64, count) };
231 let buf = crate::transfer::cpu_to_gpu_pinned(f64_data, dev)
232 .map_err(Self::map_gpu_err)?;
233 Ok(Self::wrap_buffer_f64(buf, device))
234 }
235 other => Err(FerrotorchError::InvalidArgument {
236 message: format!(
237 "cpu_to_gpu_pinned: unsupported elem_size {other} (expected 4 or 8)"
238 ),
239 }),
240 }
241 }
242
243 fn gpu_to_cpu(&self, handle: &GpuBufferHandle) -> FerrotorchResult<Vec<u8>> {
244 let dev = self.device(handle.device_ordinal())?;
245
246 if let Ok(buf) = Self::unwrap_buffer(handle) {
248 let f32_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;
249
250 let bytes = unsafe {
255 let mut v = std::mem::ManuallyDrop::new(f32_data);
256 let ptr = v.as_mut_ptr() as *mut u8;
257 let len = v.len() * 4;
258 let cap = v.capacity() * 4;
259 Vec::from_raw_parts(ptr, len, cap)
260 };
261 Ok(bytes)
262 } else if let Ok(buf) = Self::unwrap_buffer_f64(handle) {
263 let f64_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;
264
265 let bytes = unsafe {
270 let mut v = std::mem::ManuallyDrop::new(f64_data);
271 let ptr = v.as_mut_ptr() as *mut u8;
272 let len = v.len() * 8;
273 let cap = v.capacity() * 8;
274 Vec::from_raw_parts(ptr, len, cap)
275 };
276 Ok(bytes)
277 } else {
278 Err(FerrotorchError::InvalidArgument {
279 message: "gpu_to_cpu: handle is neither CudaBuffer<f32> nor CudaBuffer<f64>".into(),
280 })
281 }
282 }
283
284 fn clone_buffer(&self, handle: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
285 let bytes = self.gpu_to_cpu(handle)?;
288 let elem_size = if handle.downcast_ref::<CudaBuffer<f64>>().is_some() {
290 8
291 } else {
292 4
293 };
294 self.cpu_to_gpu(&bytes, elem_size, handle.device_ordinal())
295 }
296
297 fn alloc_zeros(
298 &self,
299 len: usize,
300 elem_size: usize,
301 device: usize,
302 ) -> FerrotorchResult<GpuBufferHandle> {
303 let dev = self.device(device)?;
304 match elem_size {
305 4 => {
306 let buf = crate::transfer::alloc_zeros_f32(len, dev).map_err(Self::map_gpu_err)?;
307 Ok(Self::wrap_buffer(buf, device))
308 }
309 8 => {
310 let buf = crate::transfer::alloc_zeros_f64(len, dev).map_err(Self::map_gpu_err)?;
311 Ok(Self::wrap_buffer_f64(buf, device))
312 }
313 other => Err(FerrotorchError::InvalidArgument {
314 message: format!("alloc_zeros: unsupported elem_size {other} (expected 4 or 8)"),
315 }),
316 }
317 }
318
319 fn add_f32(
322 &self,
323 a: &GpuBufferHandle,
324 b: &GpuBufferHandle,
325 ) -> FerrotorchResult<GpuBufferHandle> {
326 let a_buf = Self::unwrap_buffer(a)?;
327 let b_buf = Self::unwrap_buffer(b)?;
328 let dev = self.device(a.device_ordinal())?;
329 let result = crate::kernels::gpu_add(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
330 Ok(Self::wrap_buffer(result, a.device_ordinal()))
331 }
332
333 fn sub_f32(
334 &self,
335 a: &GpuBufferHandle,
336 b: &GpuBufferHandle,
337 ) -> FerrotorchResult<GpuBufferHandle> {
338 let a_buf = Self::unwrap_buffer(a)?;
339 let b_buf = Self::unwrap_buffer(b)?;
340 let dev = self.device(a.device_ordinal())?;
341 let result = crate::kernels::gpu_sub(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
342 Ok(Self::wrap_buffer(result, a.device_ordinal()))
343 }
344
345 fn mul_f32(
346 &self,
347 a: &GpuBufferHandle,
348 b: &GpuBufferHandle,
349 ) -> FerrotorchResult<GpuBufferHandle> {
350 let a_buf = Self::unwrap_buffer(a)?;
351 let b_buf = Self::unwrap_buffer(b)?;
352 let dev = self.device(a.device_ordinal())?;
353 let result = crate::kernels::gpu_mul(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
354 Ok(Self::wrap_buffer(result, a.device_ordinal()))
355 }
356
357 fn neg_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
358 let a_buf = Self::unwrap_buffer(a)?;
359 let dev = self.device(a.device_ordinal())?;
360 let result = crate::kernels::gpu_neg(a_buf, dev).map_err(Self::map_gpu_err)?;
361 Ok(Self::wrap_buffer(result, a.device_ordinal()))
362 }
363
364 fn relu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
365 let a_buf = Self::unwrap_buffer(a)?;
366 let dev = self.device(a.device_ordinal())?;
367 let result = crate::kernels::gpu_relu(a_buf, dev).map_err(Self::map_gpu_err)?;
368 Ok(Self::wrap_buffer(result, a.device_ordinal()))
369 }
370
371 fn div_f32(
372 &self,
373 a: &GpuBufferHandle,
374 b: &GpuBufferHandle,
375 ) -> FerrotorchResult<GpuBufferHandle> {
376 let a_buf = Self::unwrap_buffer(a)?;
377 let b_buf = Self::unwrap_buffer(b)?;
378 let dev = self.device(a.device_ordinal())?;
379 let result = crate::kernels::gpu_div(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
380 Ok(Self::wrap_buffer(result, a.device_ordinal()))
381 }
382
383 fn exp_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
384 let a_buf = Self::unwrap_buffer(a)?;
385 let dev = self.device(a.device_ordinal())?;
386 let result = crate::kernels::gpu_exp(a_buf, dev).map_err(Self::map_gpu_err)?;
387 Ok(Self::wrap_buffer(result, a.device_ordinal()))
388 }
389
390 fn log_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
391 let a_buf = Self::unwrap_buffer(a)?;
392 let dev = self.device(a.device_ordinal())?;
393 let result = crate::kernels::gpu_log(a_buf, dev).map_err(Self::map_gpu_err)?;
394 Ok(Self::wrap_buffer(result, a.device_ordinal()))
395 }
396
397 fn sqrt_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
398 let a_buf = Self::unwrap_buffer(a)?;
399 let dev = self.device(a.device_ordinal())?;
400 let result = crate::kernels::gpu_sqrt(a_buf, dev).map_err(Self::map_gpu_err)?;
401 Ok(Self::wrap_buffer(result, a.device_ordinal()))
402 }
403
404 fn pow_f32(&self, a: &GpuBufferHandle, exponent: f32) -> FerrotorchResult<GpuBufferHandle> {
405 let a_buf = Self::unwrap_buffer(a)?;
406 let dev = self.device(a.device_ordinal())?;
407 let result =
408 crate::kernels::gpu_pow(a_buf, exponent, dev).map_err(Self::map_gpu_err)?;
409 Ok(Self::wrap_buffer(result, a.device_ordinal()))
410 }
411
412 fn abs_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
413 let a_buf = Self::unwrap_buffer(a)?;
414 let dev = self.device(a.device_ordinal())?;
415 let result = crate::kernels::gpu_abs(a_buf, dev).map_err(Self::map_gpu_err)?;
416 Ok(Self::wrap_buffer(result, a.device_ordinal()))
417 }
418
419 fn sigmoid_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
420 let a_buf = Self::unwrap_buffer(a)?;
421 let dev = self.device(a.device_ordinal())?;
422 let result = crate::kernels::gpu_sigmoid(a_buf, dev).map_err(Self::map_gpu_err)?;
423 Ok(Self::wrap_buffer(result, a.device_ordinal()))
424 }
425
426 fn tanh_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
427 let a_buf = Self::unwrap_buffer(a)?;
428 let dev = self.device(a.device_ordinal())?;
429 let result = crate::kernels::gpu_tanh(a_buf, dev).map_err(Self::map_gpu_err)?;
430 Ok(Self::wrap_buffer(result, a.device_ordinal()))
431 }
432
433 fn add_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
438 let a_buf = Self::unwrap_buffer_f64(a)?;
439 let b_buf = Self::unwrap_buffer_f64(b)?;
440 let dev = self.device(a.device_ordinal())?;
441 let result = crate::kernels::gpu_add_f64(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
442 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
443 }
444
445 fn sub_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
446 let a_buf = Self::unwrap_buffer_f64(a)?;
447 let b_buf = Self::unwrap_buffer_f64(b)?;
448 let dev = self.device(a.device_ordinal())?;
449 let result = crate::kernels::gpu_sub_f64(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
450 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
451 }
452
453 fn mul_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
454 let a_buf = Self::unwrap_buffer_f64(a)?;
455 let b_buf = Self::unwrap_buffer_f64(b)?;
456 let dev = self.device(a.device_ordinal())?;
457 let result = crate::kernels::gpu_mul_f64(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
458 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
459 }
460
461 fn div_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
462 let a_buf = Self::unwrap_buffer_f64(a)?;
463 let b_buf = Self::unwrap_buffer_f64(b)?;
464 let dev = self.device(a.device_ordinal())?;
465 let result = crate::kernels::gpu_div_f64(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
466 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
467 }
468
469 fn neg_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
470 let a_buf = Self::unwrap_buffer_f64(a)?;
471 let dev = self.device(a.device_ordinal())?;
472 let result = crate::kernels::gpu_neg_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
473 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
474 }
475
476 fn relu_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
477 let a_buf = Self::unwrap_buffer_f64(a)?;
478 let dev = self.device(a.device_ordinal())?;
479 let result = crate::kernels::gpu_relu_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
480 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
481 }
482
483 fn scale_f64(&self, a: &GpuBufferHandle, scalar: f64) -> FerrotorchResult<GpuBufferHandle> {
484 let a_buf = Self::unwrap_buffer_f64(a)?;
485 let dev = self.device(a.device_ordinal())?;
486 let result = crate::kernels::gpu_scale_f64(a_buf, scalar, dev).map_err(Self::map_gpu_err)?;
487 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
488 }
489
490 fn exp_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
491 let a_buf = Self::unwrap_buffer_f64(a)?;
492 let dev = self.device(a.device_ordinal())?;
493 let result = crate::kernels::gpu_exp_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
494 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
495 }
496
497 fn log_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
498 let a_buf = Self::unwrap_buffer_f64(a)?;
499 let dev = self.device(a.device_ordinal())?;
500 let result = crate::kernels::gpu_log_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
501 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
502 }
503
504 fn sqrt_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
505 let a_buf = Self::unwrap_buffer_f64(a)?;
506 let dev = self.device(a.device_ordinal())?;
507 let result = crate::kernels::gpu_sqrt_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
508 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
509 }
510
511 fn pow_f64(&self, a: &GpuBufferHandle, exponent: f64) -> FerrotorchResult<GpuBufferHandle> {
512 let a_buf = Self::unwrap_buffer_f64(a)?;
513 let dev = self.device(a.device_ordinal())?;
514 let result = crate::kernels::gpu_pow_f64(a_buf, exponent, dev).map_err(Self::map_gpu_err)?;
515 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
516 }
517
518 fn abs_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
519 let a_buf = Self::unwrap_buffer_f64(a)?;
520 let dev = self.device(a.device_ordinal())?;
521 let result = crate::kernels::gpu_abs_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
522 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
523 }
524
525 fn sigmoid_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
526 let a_buf = Self::unwrap_buffer_f64(a)?;
527 let dev = self.device(a.device_ordinal())?;
528 let result = crate::kernels::gpu_sigmoid_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
529 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
530 }
531
532 fn tanh_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
533 let a_buf = Self::unwrap_buffer_f64(a)?;
534 let dev = self.device(a.device_ordinal())?;
535 let result = crate::kernels::gpu_tanh_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
536 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
537 }
538
539 fn relu_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
541 let g_buf = Self::unwrap_buffer_f64(grad)?;
542 let i_buf = Self::unwrap_buffer_f64(input)?;
543 let dev = self.device(grad.device_ordinal())?;
544 let result = crate::kernels::gpu_relu_backward_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
545 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
546 }
547
548 fn sigmoid_backward_f64(&self, grad: &GpuBufferHandle, output: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
549 let g_buf = Self::unwrap_buffer_f64(grad)?;
550 let o_buf = Self::unwrap_buffer_f64(output)?;
551 let dev = self.device(grad.device_ordinal())?;
552 let result = crate::kernels::gpu_sigmoid_backward_f64(g_buf, o_buf, dev).map_err(Self::map_gpu_err)?;
553 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
554 }
555
556 fn tanh_backward_f64(&self, grad: &GpuBufferHandle, output: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
557 let g_buf = Self::unwrap_buffer_f64(grad)?;
558 let o_buf = Self::unwrap_buffer_f64(output)?;
559 let dev = self.device(grad.device_ordinal())?;
560 let result = crate::kernels::gpu_tanh_backward_f64(g_buf, o_buf, dev).map_err(Self::map_gpu_err)?;
561 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
562 }
563
564 fn gelu_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
567 let a_buf = Self::unwrap_buffer_f64(a)?;
568 let dev = self.device(a.device_ordinal())?;
569 let result = crate::kernels::gpu_gelu_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
570 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
571 }
572
573 fn gelu_tanh_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
574 let a_buf = Self::unwrap_buffer_f64(a)?;
575 let dev = self.device(a.device_ordinal())?;
576 let result = crate::kernels::gpu_gelu_tanh_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
577 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
578 }
579
580 fn gelu_erf_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
581 let a_buf = Self::unwrap_buffer_f64(a)?;
582 let dev = self.device(a.device_ordinal())?;
583 let result = crate::kernels::gpu_gelu_erf_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
584 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
585 }
586
587 fn silu_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
588 let a_buf = Self::unwrap_buffer_f64(a)?;
589 let dev = self.device(a.device_ordinal())?;
590 let result = crate::kernels::gpu_silu_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
591 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
592 }
593
594 fn elu_f64(&self, a: &GpuBufferHandle, alpha: f64) -> FerrotorchResult<GpuBufferHandle> {
595 let a_buf = Self::unwrap_buffer_f64(a)?;
596 let dev = self.device(a.device_ordinal())?;
597 let result = crate::kernels::gpu_elu_f64(a_buf, alpha, dev).map_err(Self::map_gpu_err)?;
598 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
599 }
600
601 fn mish_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
602 let a_buf = Self::unwrap_buffer_f64(a)?;
603 let dev = self.device(a.device_ordinal())?;
604 let result = crate::kernels::gpu_mish_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
605 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
606 }
607
608 fn clamp_f64(&self, a: &GpuBufferHandle, min_val: f64, max_val: f64) -> FerrotorchResult<GpuBufferHandle> {
609 let a_buf = Self::unwrap_buffer_f64(a)?;
610 let dev = self.device(a.device_ordinal())?;
611 let result = crate::kernels::gpu_clamp_f64(a_buf, min_val, max_val, dev).map_err(Self::map_gpu_err)?;
612 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
613 }
614
615 fn gelu_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
618 let g_buf = Self::unwrap_buffer_f64(grad)?;
619 let i_buf = Self::unwrap_buffer_f64(input)?;
620 let dev = self.device(grad.device_ordinal())?;
621 let result = crate::kernels::gpu_gelu_backward_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
622 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
623 }
624
625 fn gelu_backward_tanh_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
626 let g_buf = Self::unwrap_buffer_f64(grad)?;
627 let i_buf = Self::unwrap_buffer_f64(input)?;
628 let dev = self.device(grad.device_ordinal())?;
629 let result = crate::kernels::gpu_gelu_backward_tanh_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
630 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
631 }
632
633 fn gelu_backward_erf_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
634 let g_buf = Self::unwrap_buffer_f64(grad)?;
635 let i_buf = Self::unwrap_buffer_f64(input)?;
636 let dev = self.device(grad.device_ordinal())?;
637 let result = crate::kernels::gpu_gelu_backward_erf_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
638 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
639 }
640
641 fn silu_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
642 let g_buf = Self::unwrap_buffer_f64(grad)?;
643 let i_buf = Self::unwrap_buffer_f64(input)?;
644 let dev = self.device(grad.device_ordinal())?;
645 let result = crate::kernels::gpu_silu_backward_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
646 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
647 }
648
649 fn elu_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle, alpha: f64) -> FerrotorchResult<GpuBufferHandle> {
650 let g_buf = Self::unwrap_buffer_f64(grad)?;
651 let i_buf = Self::unwrap_buffer_f64(input)?;
652 let dev = self.device(grad.device_ordinal())?;
653 let result = crate::kernels::gpu_elu_backward_f64(g_buf, i_buf, alpha, dev).map_err(Self::map_gpu_err)?;
654 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
655 }
656
657 fn mish_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
658 let g_buf = Self::unwrap_buffer_f64(grad)?;
659 let i_buf = Self::unwrap_buffer_f64(input)?;
660 let dev = self.device(grad.device_ordinal())?;
661 let result = crate::kernels::gpu_mish_backward_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
662 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
663 }
664
665 fn cumsum_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<GpuBufferHandle> {
667 let a_buf = Self::unwrap_buffer_f64(a)?;
668 let dev = self.device(a.device_ordinal())?;
669 let result = crate::kernels::gpu_cumsum_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
670 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
671 }
672
673 fn cumprod_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<GpuBufferHandle> {
674 let a_buf = Self::unwrap_buffer_f64(a)?;
675 let dev = self.device(a.device_ordinal())?;
676 let result = crate::kernels::gpu_cumprod_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
677 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
678 }
679
680 fn cummax_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
681 let a_buf = Self::unwrap_buffer_f64(a)?;
682 let dev = self.device(a.device_ordinal())?;
683 let (vals, idxs) = crate::kernels::gpu_cummax_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
684 let ord = a.device_ordinal();
685 Ok((Self::wrap_buffer_f64(vals, ord), Self::wrap_buffer_f64(idxs, ord)))
686 }
687
688 fn cummin_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
689 let a_buf = Self::unwrap_buffer_f64(a)?;
690 let dev = self.device(a.device_ordinal())?;
691 let (vals, idxs) = crate::kernels::gpu_cummin_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
692 let ord = a.device_ordinal();
693 Ok((Self::wrap_buffer_f64(vals, ord), Self::wrap_buffer_f64(idxs, ord)))
694 }
695
696 fn logcumsumexp_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<GpuBufferHandle> {
697 let a_buf = Self::unwrap_buffer_f64(a)?;
698 let dev = self.device(a.device_ordinal())?;
699 let result = crate::kernels::gpu_logcumsumexp_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
700 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
701 }
702
703 fn transpose_2d_f64(&self, a: &GpuBufferHandle, m: usize, n: usize) -> FerrotorchResult<GpuBufferHandle> {
705 let a_buf = Self::unwrap_buffer_f64(a)?;
706 let dev = self.device(a.device_ordinal())?;
707 let result = crate::kernels::gpu_transpose_2d_f64(a_buf, m, n, dev).map_err(Self::map_gpu_err)?;
708 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
709 }
710
711 fn permute_0213_f64(&self, a: &GpuBufferHandle, d0: usize, d1: usize, d2: usize, d3: usize) -> FerrotorchResult<GpuBufferHandle> {
712 let a_buf = Self::unwrap_buffer_f64(a)?;
713 let dev = self.device(a.device_ordinal())?;
714 let result = crate::kernels::gpu_permute_0213_f64(a_buf, d0, d1, d2, d3, dev).map_err(Self::map_gpu_err)?;
715 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
716 }
717
718 fn broadcast_add_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle> {
720 let a_buf = Self::unwrap_buffer_f64(a)?;
721 let b_buf = Self::unwrap_buffer_f64(b)?;
722 let dev = self.device(a.device_ordinal())?;
723 let result = crate::kernels::gpu_broadcast_add_f64(a_buf, b_buf, a_shape, b_shape, out_shape, dev).map_err(Self::map_gpu_err)?;
724 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
725 }
726
727 fn broadcast_sub_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle> {
728 let a_buf = Self::unwrap_buffer_f64(a)?;
729 let b_buf = Self::unwrap_buffer_f64(b)?;
730 let dev = self.device(a.device_ordinal())?;
731 let result = crate::kernels::gpu_broadcast_sub_f64(a_buf, b_buf, a_shape, b_shape, out_shape, dev).map_err(Self::map_gpu_err)?;
732 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
733 }
734
735 fn broadcast_mul_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle> {
736 let a_buf = Self::unwrap_buffer_f64(a)?;
737 let b_buf = Self::unwrap_buffer_f64(b)?;
738 let dev = self.device(a.device_ordinal())?;
739 let result = crate::kernels::gpu_broadcast_mul_f64(a_buf, b_buf, a_shape, b_shape, out_shape, dev).map_err(Self::map_gpu_err)?;
740 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
741 }
742
743 fn broadcast_div_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle> {
744 let a_buf = Self::unwrap_buffer_f64(a)?;
745 let b_buf = Self::unwrap_buffer_f64(b)?;
746 let dev = self.device(a.device_ordinal())?;
747 let result = crate::kernels::gpu_broadcast_div_f64(a_buf, b_buf, a_shape, b_shape, out_shape, dev).map_err(Self::map_gpu_err)?;
748 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
749 }
750
751 fn sum_f64(&self, a: &GpuBufferHandle, _n: usize) -> FerrotorchResult<GpuBufferHandle> {
753 let a_buf = Self::unwrap_buffer_f64(a)?;
754 let dev = self.device(a.device_ordinal())?;
755 let result = crate::kernels::gpu_reduce_sum_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
756 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
757 }
758
759 fn sum_axis_f64(&self, a: &GpuBufferHandle, shape: &[usize], axis: usize) -> FerrotorchResult<GpuBufferHandle> {
760 let a_buf = Self::unwrap_buffer_f64(a)?;
761 let dev = self.device(a.device_ordinal())?;
762 let outer: usize = shape[..axis].iter().product();
763 let axis_size = shape[axis];
764 let inner: usize = shape[axis + 1..].iter().product();
765 let result = crate::kernels::gpu_sum_axis_f64(a_buf, outer, axis_size, inner, dev).map_err(Self::map_gpu_err)?;
766 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
767 }
768
769 fn softmax_f64(&self, a: &GpuBufferHandle, rows: usize, cols: usize) -> FerrotorchResult<GpuBufferHandle> {
772 let a_buf = Self::unwrap_buffer_f64(a)?;
773 let dev = self.device(a.device_ordinal())?;
774 let result = crate::kernels::gpu_softmax_f64(a_buf, rows, cols, dev).map_err(Self::map_gpu_err)?;
775 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
776 }
777
778 fn softmax_backward_f64(&self, grad: &GpuBufferHandle, output: &GpuBufferHandle, cols: usize) -> FerrotorchResult<GpuBufferHandle> {
779 let grad_buf = Self::unwrap_buffer_f64(grad)?;
780 let output_buf = Self::unwrap_buffer_f64(output)?;
781 let dev = self.device(grad.device_ordinal())?;
782 let result = crate::kernels::gpu_softmax_backward_f64(grad_buf, output_buf, cols, dev).map_err(Self::map_gpu_err)?;
783 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
784 }
785
786 fn log_softmax_f64(&self, a: &GpuBufferHandle, cols: usize) -> FerrotorchResult<GpuBufferHandle> {
787 let a_buf = Self::unwrap_buffer_f64(a)?;
788 let dev = self.device(a.device_ordinal())?;
789 let result = crate::kernels::gpu_log_softmax_f64(a_buf, cols, dev).map_err(Self::map_gpu_err)?;
790 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
791 }
792
793 fn log_softmax_backward_f64(&self, grad: &GpuBufferHandle, output: &GpuBufferHandle, cols: usize) -> FerrotorchResult<GpuBufferHandle> {
794 let grad_buf = Self::unwrap_buffer_f64(grad)?;
795 let output_buf = Self::unwrap_buffer_f64(output)?;
796 let dev = self.device(grad.device_ordinal())?;
797 let result = crate::kernels::gpu_log_softmax_backward_f64(grad_buf, output_buf, cols, dev).map_err(Self::map_gpu_err)?;
798 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
799 }
800
801 fn layernorm_f64(
802 &self,
803 input: &GpuBufferHandle,
804 weight: &GpuBufferHandle,
805 bias: &GpuBufferHandle,
806 rows: usize,
807 cols: usize,
808 eps: f64,
809 ) -> FerrotorchResult<GpuBufferHandle> {
810 let in_buf = Self::unwrap_buffer_f64(input)?;
811 let w_buf = Self::unwrap_buffer_f64(weight)?;
812 let b_buf = Self::unwrap_buffer_f64(bias)?;
813 let dev = self.device(input.device_ordinal())?;
814 let result = crate::kernels::gpu_layernorm_f64(in_buf, w_buf, b_buf, rows, cols, eps, dev)
815 .map_err(Self::map_gpu_err)?;
816 Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
817 }
818
819 fn layernorm_backward_f64(
820 &self,
821 input: &GpuBufferHandle,
822 grad_output: &GpuBufferHandle,
823 weight: &GpuBufferHandle,
824 rows: usize,
825 cols: usize,
826 eps: f64,
827 ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
828 let in_buf = Self::unwrap_buffer_f64(input)?;
829 let go_buf = Self::unwrap_buffer_f64(grad_output)?;
830 let w_buf = Self::unwrap_buffer_f64(weight)?;
831 let dev = self.device(input.device_ordinal())?;
832 let (gi, gw, gb) =
833 crate::kernels::gpu_layernorm_backward_f64(in_buf, go_buf, w_buf, rows, cols, eps, dev)
834 .map_err(Self::map_gpu_err)?;
835 let ordinal = input.device_ordinal();
836 Ok((
837 Self::wrap_buffer_f64(gi, ordinal),
838 Self::wrap_buffer_f64(gw, ordinal),
839 Self::wrap_buffer_f64(gb, ordinal),
840 ))
841 }
842
843 fn rmsnorm_f64(
844 &self,
845 input: &GpuBufferHandle,
846 weight: &GpuBufferHandle,
847 rows: usize,
848 cols: usize,
849 eps: f64,
850 ) -> FerrotorchResult<GpuBufferHandle> {
851 let in_buf = Self::unwrap_buffer_f64(input)?;
852 let w_buf = Self::unwrap_buffer_f64(weight)?;
853 let dev = self.device(input.device_ordinal())?;
854 let result = crate::kernels::gpu_rmsnorm_f64(in_buf, w_buf, rows, cols, eps, dev)
855 .map_err(Self::map_gpu_err)?;
856 Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
857 }
858
859 fn rmsnorm_backward_f64(
860 &self,
861 input: &GpuBufferHandle,
862 grad_output: &GpuBufferHandle,
863 weight: &GpuBufferHandle,
864 rows: usize,
865 cols: usize,
866 eps: f64,
867 ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
868 let in_buf = Self::unwrap_buffer_f64(input)?;
869 let go_buf = Self::unwrap_buffer_f64(grad_output)?;
870 let w_buf = Self::unwrap_buffer_f64(weight)?;
871 let dev = self.device(input.device_ordinal())?;
872 let (gi, gw) =
873 crate::kernels::gpu_rmsnorm_backward_f64(in_buf, go_buf, w_buf, rows, cols, eps, dev)
874 .map_err(Self::map_gpu_err)?;
875 let ordinal = input.device_ordinal();
876 Ok((Self::wrap_buffer_f64(gi, ordinal), Self::wrap_buffer_f64(gw, ordinal)))
877 }
878
879 fn embed_lookup_f64(
882 &self,
883 idx: &GpuBufferHandle,
884 weight: &GpuBufferHandle,
885 d: usize,
886 ) -> FerrotorchResult<GpuBufferHandle> {
887 let idx_buf = Self::unwrap_buffer(idx)?;
889 let w_buf = Self::unwrap_buffer_f64(weight)?;
890 let dev = self.device(idx.device_ordinal())?;
891 let result =
892 crate::kernels::gpu_embed_lookup_f64(idx_buf, w_buf, d, dev).map_err(Self::map_gpu_err)?;
893 Ok(Self::wrap_buffer_f64(result, idx.device_ordinal()))
894 }
895
896 fn embed_lookup_batch_f64(
897 &self,
898 indices: &GpuBufferHandle,
899 weight: &GpuBufferHandle,
900 n: usize,
901 d: usize,
902 ) -> FerrotorchResult<GpuBufferHandle> {
903 let idx_buf = Self::unwrap_buffer(indices)?;
905 let w_buf = Self::unwrap_buffer_f64(weight)?;
906 let dev = self.device(indices.device_ordinal())?;
907 let result = crate::kernels::gpu_embed_lookup_batch_f64(idx_buf, w_buf, n, d, dev)
908 .map_err(Self::map_gpu_err)?;
909 Ok(Self::wrap_buffer_f64(result, indices.device_ordinal()))
910 }
911
912 fn scatter_add_rows_f64(
913 &self,
914 grad_output: &GpuBufferHandle,
915 indices: &GpuBufferHandle,
916 num_embeddings: usize,
917 d: usize,
918 ) -> FerrotorchResult<GpuBufferHandle> {
919 let go_buf = Self::unwrap_buffer_f64(grad_output)?;
920 let idx_buf = Self::unwrap_buffer(indices)?;
922 let dev = self.device(grad_output.device_ordinal())?;
923 let result = crate::kernels::gpu_scatter_add_rows_f64(go_buf, idx_buf, num_embeddings, d, dev)
924 .map_err(Self::map_gpu_err)?;
925 Ok(Self::wrap_buffer_f64(result, grad_output.device_ordinal()))
926 }
927
    fn masked_fill_f64(
        &self,
        input: &GpuBufferHandle,
        mask: &GpuBufferHandle,
        value: f64,
    ) -> FerrotorchResult<GpuBufferHandle> {
        // Masked fill (f64): runs the gpu_masked_fill_f64 kernel with `value` and a
        // u8 mask. The incoming mask is a CudaBuffer<f32>, so it is converted here:
        // any non-zero element becomes 1, zero stays 0.
        // NOTE(review): the conversion round-trips the mask through host memory
        // (gpu_to_cpu -> Vec<u8> -> cpu_to_gpu) on every call; a device-side
        // f32->u8 cast kernel would avoid the synchronizing transfers.
        let input_buf = Self::unwrap_buffer_f64(input)?;
        let mask_f32 = Self::unwrap_buffer(mask)?;
        let dev = self.device(input.device_ordinal())?;
        let mask_host = crate::transfer::gpu_to_cpu(mask_f32, dev).map_err(Self::map_gpu_err)?;
        // Non-zero => selected. (NaN compares unequal to 0.0 and thus maps to 1.)
        let mask_u8: Vec<u8> = mask_host.iter().map(|&v| if v != 0.0 { 1u8 } else { 0u8 }).collect();
        let mask_gpu = crate::transfer::cpu_to_gpu(&mask_u8, dev).map_err(Self::map_gpu_err)?;
        let result = crate::kernels::gpu_masked_fill_f64(input_buf, &mask_gpu, value, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
    }
951
    fn masked_zero_f64(
        &self,
        grad: &GpuBufferHandle,
        mask: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        // Masked zeroing (f64), the gradient counterpart of masked_fill_f64.
        // The f32 mask is converted to u8 flags (non-zero -> 1) before the kernel.
        // NOTE(review): same host round-trip as masked_fill_f64 — the mask is copied
        // GPU -> CPU -> GPU on every call; consider a device-side conversion kernel.
        let grad_buf = Self::unwrap_buffer_f64(grad)?;
        let mask_f32 = Self::unwrap_buffer(mask)?;
        let dev = self.device(grad.device_ordinal())?;
        let mask_host = crate::transfer::gpu_to_cpu(mask_f32, dev).map_err(Self::map_gpu_err)?;
        let mask_u8: Vec<u8> = mask_host.iter().map(|&v| if v != 0.0 { 1u8 } else { 0u8 }).collect();
        let mask_gpu = crate::transfer::cpu_to_gpu(&mask_u8, dev).map_err(Self::map_gpu_err)?;
        let result = crate::kernels::gpu_masked_zero_f64(grad_buf, &mask_gpu, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
    }
968
969 fn slice_write_f64(
972 &self,
973 src: &GpuBufferHandle,
974 dst: &mut GpuBufferHandle,
975 n_batch: usize,
976 d: usize,
977 max_len: usize,
978 pos: usize,
979 ) -> FerrotorchResult<()> {
980 let src_buf = Self::unwrap_buffer_f64(src)?;
981 let dst_buf = Self::unwrap_buffer_f64_mut(dst)?;
982 let dev = self.device(src.device_ordinal())?;
983 crate::kernels::gpu_slice_write_f64(src_buf, dst_buf, n_batch, d, max_len, pos, dev)
984 .map_err(Self::map_gpu_err)?;
985 Ok(())
986 }
987
988 fn slice_read_f64(
989 &self,
990 src: &GpuBufferHandle,
991 n_batch: usize,
992 d: usize,
993 len: usize,
994 max_len: usize,
995 ) -> FerrotorchResult<GpuBufferHandle> {
996 let src_buf = Self::unwrap_buffer_f64(src)?;
997 let dev = self.device(src.device_ordinal())?;
998 let result = crate::kernels::gpu_slice_read_f64(src_buf, n_batch, d, len, max_len, dev)
999 .map_err(Self::map_gpu_err)?;
1000 Ok(Self::wrap_buffer_f64(result, src.device_ordinal()))
1001 }
1002
1003 fn strided_split_f64(
1006 &self,
1007 input: &GpuBufferHandle,
1008 total_along_axis: usize,
1009 split_offset: usize,
1010 split_size: usize,
1011 inner_size: usize,
1012 n: usize,
1013 ) -> FerrotorchResult<GpuBufferHandle> {
1014 let in_buf = Self::unwrap_buffer_f64(input)?;
1015 let dev = self.device(input.device_ordinal())?;
1016 let result = crate::kernels::gpu_strided_split_f64(
1017 in_buf,
1018 total_along_axis,
1019 split_offset,
1020 split_size,
1021 inner_size,
1022 n,
1023 dev,
1024 )
1025 .map_err(Self::map_gpu_err)?;
1026 Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
1027 }
1028
1029 fn strided_cat_f64(
1030 &self,
1031 input: &GpuBufferHandle,
1032 output: &mut GpuBufferHandle,
1033 total_along_axis: usize,
1034 cat_offset: usize,
1035 part_size: usize,
1036 inner_size: usize,
1037 n: usize,
1038 ) -> FerrotorchResult<()> {
1039 let in_buf = Self::unwrap_buffer_f64(input)?;
1040 let dev = self.device(input.device_ordinal())?;
1041 let out_buf = Self::unwrap_buffer_f64_mut(output)?;
1042 crate::kernels::gpu_strided_cat_f64(
1043 in_buf,
1044 out_buf,
1045 total_along_axis,
1046 cat_offset,
1047 part_size,
1048 inner_size,
1049 n,
1050 dev,
1051 )
1052 .map_err(Self::map_gpu_err)?;
1053 Ok(())
1054 }
1055
1056 fn index_select_1d_f64(
1059 &self,
1060 input: &GpuBufferHandle,
1061 indices: &GpuBufferHandle,
1062 ) -> FerrotorchResult<GpuBufferHandle> {
1063 let input_buf = Self::unwrap_buffer_f64(input)?;
1064 let idx_buf = Self::unwrap_buffer(indices)?;
1066 let dev = self.device(input.device_ordinal())?;
1067 let result = crate::kernels::gpu_index_select_1d_f64(input_buf, idx_buf, dev)
1068 .map_err(Self::map_gpu_err)?;
1069 Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
1070 }
1071
1072 fn scatter_add_1d_f64(
1073 &self,
1074 grad_output: &GpuBufferHandle,
1075 indices: &GpuBufferHandle,
1076 input_len: usize,
1077 ) -> FerrotorchResult<GpuBufferHandle> {
1078 let go_buf = Self::unwrap_buffer_f64(grad_output)?;
1079 let idx_buf = Self::unwrap_buffer(indices)?;
1081 let dev = self.device(grad_output.device_ordinal())?;
1082 let result = crate::kernels::gpu_scatter_add_1d_f64(go_buf, idx_buf, input_len, dev)
1083 .map_err(Self::map_gpu_err)?;
1084 Ok(Self::wrap_buffer_f64(result, grad_output.device_ordinal()))
1085 }
1086
1087 fn bmm_f64(
1088 &self,
1089 a: &GpuBufferHandle,
1090 b: &GpuBufferHandle,
1091 batch: usize,
1092 m: usize,
1093 k: usize,
1094 n: usize,
1095 ) -> FerrotorchResult<GpuBufferHandle> {
1096 let a_buf = Self::unwrap_buffer_f64(a)?;
1097 let b_buf = Self::unwrap_buffer_f64(b)?;
1098 let dev = self.device(a.device_ordinal())?;
1099 let result = crate::blas::gpu_bmm_f64(a_buf, b_buf, batch, m, k, n, dev)
1100 .map_err(Self::map_gpu_err)?;
1101 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
1102 }
1103
    #[allow(clippy::too_many_arguments)]
    fn fused_adam_f32(
        &self,
        param: &mut GpuBufferHandle,
        grad: &GpuBufferHandle,
        exp_avg: &mut GpuBufferHandle,
        exp_avg_sq: &mut GpuBufferHandle,
        beta1: f32,
        beta2: f32,
        lr: f32,
        eps: f32,
        bc1: f32,
        bc2: f32,
        weight_decay: f32,
    ) -> FerrotorchResult<()> {
        // One fused Adam optimizer step: updates `param` and the moment buffers
        // (`exp_avg`, `exp_avg_sq`) in place on the GPU.
        // `bc1`/`bc2` are presumably the precomputed bias-correction factors for the
        // two moments — TODO confirm against the gpu_fused_adam kernel.
        // The device ordinal is read *before* the mutable downcasts below take
        // exclusive borrows of the handles.
        let ordinal = param.device_ordinal();
        let dev = self.device(ordinal)?;
        let p_buf = Self::unwrap_buffer_mut(param)?;
        let g_buf = Self::unwrap_buffer(grad)?;
        let m_buf = Self::unwrap_buffer_mut(exp_avg)?;
        let v_buf = Self::unwrap_buffer_mut(exp_avg_sq)?;
        crate::kernels::gpu_fused_adam(
            p_buf,
            g_buf,
            m_buf,
            v_buf,
            beta1,
            beta2,
            lr,
            eps,
            bc1,
            bc2,
            weight_decay,
            dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok(())
    }
1142
1143 #[allow(clippy::too_many_arguments)]
1144 fn maxpool2d_f32(
1145 &self,
1146 input: &GpuBufferHandle,
1147 batch: usize,
1148 channels: usize,
1149 h_in: usize,
1150 w_in: usize,
1151 kh: usize,
1152 kw: usize,
1153 sh: usize,
1154 sw: usize,
1155 ph: usize,
1156 pw: usize,
1157 ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
1158 let buf = Self::unwrap_buffer(input)?;
1159 let dev = self.device(input.device_ordinal())?;
1160 let (out, shape) = crate::kernels::gpu_maxpool2d(
1161 buf, batch, channels, h_in, w_in, kh, kw, sh, sw, ph, pw, dev,
1162 ).map_err(Self::map_gpu_err)?;
1163 Ok((Self::wrap_buffer(out, input.device_ordinal()), shape))
1164 }
1165
1166 #[allow(clippy::too_many_arguments)]
1167 fn avgpool2d_f32(
1168 &self,
1169 input: &GpuBufferHandle,
1170 batch: usize,
1171 channels: usize,
1172 h_in: usize,
1173 w_in: usize,
1174 kh: usize,
1175 kw: usize,
1176 sh: usize,
1177 sw: usize,
1178 ph: usize,
1179 pw: usize,
1180 ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
1181 let buf = Self::unwrap_buffer(input)?;
1182 let dev = self.device(input.device_ordinal())?;
1183 let (out, shape) = crate::kernels::gpu_avgpool2d(
1184 buf, batch, channels, h_in, w_in, kh, kw, sh, sw, ph, pw, dev,
1185 ).map_err(Self::map_gpu_err)?;
1186 Ok((Self::wrap_buffer(out, input.device_ordinal()), shape))
1187 }
1188
1189 #[allow(clippy::too_many_arguments)]
1190 fn conv2d_f32(
1191 &self,
1192 input: &GpuBufferHandle,
1193 weight: &GpuBufferHandle,
1194 bias: Option<&GpuBufferHandle>,
1195 input_shape: [usize; 4],
1196 weight_shape: [usize; 4],
1197 stride: (usize, usize),
1198 padding: (usize, usize),
1199 ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
1200 let input_buf = Self::unwrap_buffer(input)?;
1201 let weight_buf = Self::unwrap_buffer(weight)?;
1202 let bias_buf = match bias {
1203 Some(b) => Some(Self::unwrap_buffer(b)?),
1204 None => None,
1205 };
1206 let dev = self.device(input.device_ordinal())?;
1207 let (out_buf, out_shape) = crate::conv::gpu_conv2d_f32(
1208 input_buf,
1209 weight_buf,
1210 bias_buf,
1211 input_shape,
1212 weight_shape,
1213 stride,
1214 padding,
1215 dev,
1216 )
1217 .map_err(Self::map_gpu_err)?;
1218 Ok((Self::wrap_buffer(out_buf, input.device_ordinal()), out_shape))
1219 }
1220
1221 fn fused_gru_cell_f32(
1222 &self,
1223 input_gates: &GpuBufferHandle,
1224 hidden_gates: &GpuBufferHandle,
1225 bias_ih: &GpuBufferHandle,
1226 bias_hh: &GpuBufferHandle,
1227 hx: &GpuBufferHandle,
1228 hidden_size: usize,
1229 ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
1230 let ig = Self::unwrap_buffer(input_gates)?;
1231 let hg = Self::unwrap_buffer(hidden_gates)?;
1232 let bih = Self::unwrap_buffer(bias_ih)?;
1233 let bhh = Self::unwrap_buffer(bias_hh)?;
1234 let hx_buf = Self::unwrap_buffer(hx)?;
1235 let dev = self.device(input_gates.device_ordinal())?;
1236 let (hy, ws) = crate::kernels::gpu_fused_gru_forward(
1237 ig, hg, bih, bhh, hx_buf, hidden_size, dev,
1238 )
1239 .map_err(Self::map_gpu_err)?;
1240 let ord = input_gates.device_ordinal();
1241 Ok((Self::wrap_buffer(hy, ord), Self::wrap_buffer(ws, ord)))
1242 }
1243
1244 fn synchronize(&self, device: usize) -> FerrotorchResult<()> {
1245 let dev = self.device(device)?;
1246 dev.stream()
1247 .synchronize()
1248 .map_err(|e| FerrotorchError::InvalidArgument {
1249 message: format!("CUDA synchronize failed: {e}"),
1250 })?;
1251 Ok(())
1252 }
1253
    fn stream_count(&self, device: usize) -> usize {
        // Number of streams the global pool holds for `device`. Delegates to the
        // StreamPool and does not consult `self.devices`.
        crate::stream::StreamPool::pool_size(device)
    }
1257
1258 fn matmul_f32(
1261 &self,
1262 a: &GpuBufferHandle,
1263 b: &GpuBufferHandle,
1264 m: usize,
1265 k: usize,
1266 n: usize,
1267 ) -> FerrotorchResult<GpuBufferHandle> {
1268 let a_buf = Self::unwrap_buffer(a)?;
1269 let b_buf = Self::unwrap_buffer(b)?;
1270 let dev = self.device(a.device_ordinal())?;
1271 let result =
1272 crate::blas::gpu_matmul_f32(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
1273 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1274 }
1275
1276 fn sum_f32(&self, a: &GpuBufferHandle, _len: usize) -> FerrotorchResult<GpuBufferHandle> {
1279 let a_buf = Self::unwrap_buffer(a)?;
1280 let dev = self.device(a.device_ordinal())?;
1281 let result = crate::kernels::gpu_reduce_sum(a_buf, dev).map_err(Self::map_gpu_err)?;
1282 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1283 }
1284
1285 fn matmul_f64(
1288 &self,
1289 a: &GpuBufferHandle,
1290 b: &GpuBufferHandle,
1291 m: usize,
1292 k: usize,
1293 n: usize,
1294 ) -> FerrotorchResult<GpuBufferHandle> {
1295 let a_buf = Self::unwrap_buffer_f64(a)?;
1296 let b_buf = Self::unwrap_buffer_f64(b)?;
1297 let dev = self.device(a.device_ordinal())?;
1298 let result =
1299 crate::blas::gpu_matmul_f64(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
1300 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
1301 }
1302
1303 fn broadcast_add_f32(
1306 &self,
1307 a: &GpuBufferHandle,
1308 b: &GpuBufferHandle,
1309 a_shape: &[usize],
1310 b_shape: &[usize],
1311 out_shape: &[usize],
1312 ) -> FerrotorchResult<GpuBufferHandle> {
1313 let a_buf = Self::unwrap_buffer(a)?;
1314 let b_buf = Self::unwrap_buffer(b)?;
1315 let dev = self.device(a.device_ordinal())?;
1316 let result =
1317 crate::kernels::gpu_broadcast_add(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
1318 .map_err(Self::map_gpu_err)?;
1319 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1320 }
1321
1322 fn broadcast_sub_f32(
1323 &self,
1324 a: &GpuBufferHandle,
1325 b: &GpuBufferHandle,
1326 a_shape: &[usize],
1327 b_shape: &[usize],
1328 out_shape: &[usize],
1329 ) -> FerrotorchResult<GpuBufferHandle> {
1330 let a_buf = Self::unwrap_buffer(a)?;
1331 let b_buf = Self::unwrap_buffer(b)?;
1332 let dev = self.device(a.device_ordinal())?;
1333 let result =
1334 crate::kernels::gpu_broadcast_sub(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
1335 .map_err(Self::map_gpu_err)?;
1336 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1337 }
1338
1339 fn broadcast_mul_f32(
1340 &self,
1341 a: &GpuBufferHandle,
1342 b: &GpuBufferHandle,
1343 a_shape: &[usize],
1344 b_shape: &[usize],
1345 out_shape: &[usize],
1346 ) -> FerrotorchResult<GpuBufferHandle> {
1347 let a_buf = Self::unwrap_buffer(a)?;
1348 let b_buf = Self::unwrap_buffer(b)?;
1349 let dev = self.device(a.device_ordinal())?;
1350 let result =
1351 crate::kernels::gpu_broadcast_mul(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
1352 .map_err(Self::map_gpu_err)?;
1353 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1354 }
1355
1356 fn broadcast_div_f32(
1357 &self,
1358 a: &GpuBufferHandle,
1359 b: &GpuBufferHandle,
1360 a_shape: &[usize],
1361 b_shape: &[usize],
1362 out_shape: &[usize],
1363 ) -> FerrotorchResult<GpuBufferHandle> {
1364 let a_buf = Self::unwrap_buffer(a)?;
1365 let b_buf = Self::unwrap_buffer(b)?;
1366 let dev = self.device(a.device_ordinal())?;
1367 let result =
1368 crate::kernels::gpu_broadcast_div(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
1369 .map_err(Self::map_gpu_err)?;
1370 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1371 }
1372
1373 fn softmax_f32(
1374 &self,
1375 a: &GpuBufferHandle,
1376 rows: usize,
1377 cols: usize,
1378 ) -> FerrotorchResult<GpuBufferHandle> {
1379 let a_buf = Self::unwrap_buffer(a)?;
1380 let dev = self.device(a.device_ordinal())?;
1381 let result =
1382 crate::kernels::gpu_softmax(a_buf, rows, cols, dev).map_err(Self::map_gpu_err)?;
1383 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1384 }
1385
1386 fn dropout_f32(
1387 &self,
1388 a: &GpuBufferHandle,
1389 threshold: u32,
1390 scale: f32,
1391 seed: u32,
1392 ) -> FerrotorchResult<GpuBufferHandle> {
1393 let a_buf = Self::unwrap_buffer(a)?;
1394 let dev = self.device(a.device_ordinal())?;
1395 let result = crate::kernels::gpu_dropout(a_buf, threshold, scale, seed, dev)
1396 .map_err(Self::map_gpu_err)?;
1397 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1398 }
1399
    fn dropout_philox_f32(
        &self,
        a: &GpuBufferHandle,
        threshold: u32,
        scale: f32,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuRngState)> {
        // Dropout (f32) with Philox-style RNG bookkeeping: reserves a counter range
        // from the per-device RNG manager, runs the dropout kernel, and returns the
        // RNG state that was consumed so callers can replay the mask.
        let device_ordinal = a.device_ordinal();
        let n = a.len();

        // Snapshot the generator state, then advance past the counters this call
        // consumes. div_ceil(4) presumably reflects 4 random u32s per Philox
        // counter — TODO confirm against the generator implementation.
        // The manager lock is held only for this block.
        let rng_state = {
            let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
                FerrotorchError::InvalidArgument {
                    message: "failed to lock CUDA RNG manager".into(),
                }
            })?;
            let philox_gen = mgr.generator(device_ordinal);
            let state = philox_gen.get_state();
            let counters_needed = n.div_ceil(4);
            philox_gen.advance(counters_needed as u64);
            state
        };

        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(device_ordinal)?;

        // NOTE(review): the kernel takes a single u32 seed, so the 64-bit
        // counter^seed value is truncated here; verify this matches whatever
        // replays the mask from the returned GpuRngState.
        let derived_seed = (rng_state.counter ^ rng_state.seed) as u32;
        let result = crate::kernels::gpu_dropout(a_buf, threshold, scale, derived_seed, dev)
            .map_err(Self::map_gpu_err)?;

        // Return the pre-advance state — i.e. the state this call actually used.
        let gpu_rng_state = GpuRngState {
            counter: rng_state.counter,
            seed: rng_state.seed,
            offset: rng_state.offset,
            device: device_ordinal,
        };

        Ok((Self::wrap_buffer(result, device_ordinal), gpu_rng_state))
    }
1449
1450 fn dropout_f64(
1451 &self,
1452 a: &GpuBufferHandle,
1453 threshold: u32,
1454 scale: f64,
1455 seed: u32,
1456 ) -> FerrotorchResult<GpuBufferHandle> {
1457 let a_buf = Self::unwrap_buffer_f64(a)?;
1458 let dev = self.device(a.device_ordinal())?;
1459 let result = crate::kernels::gpu_dropout_f64(a_buf, threshold, scale, seed, dev)
1460 .map_err(Self::map_gpu_err)?;
1461 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
1462 }
1463
    fn dropout_philox_f64(
        &self,
        a: &GpuBufferHandle,
        threshold: u32,
        scale: f64,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuRngState)> {
        // f64 twin of dropout_philox_f32: snapshot the per-device Philox state,
        // advance it past the counters this call consumes, run the kernel, and
        // return the consumed state for mask replay.
        let device_ordinal = a.device_ordinal();
        let n = a.len();

        // Lock is held only for the snapshot + advance. div_ceil(4) presumably
        // reflects 4 random u32s per Philox counter — TODO confirm.
        let rng_state = {
            let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
                FerrotorchError::InvalidArgument {
                    message: "failed to lock CUDA RNG manager".into(),
                }
            })?;
            let philox_gen = mgr.generator(device_ordinal);
            let state = philox_gen.get_state();
            let counters_needed = n.div_ceil(4);
            philox_gen.advance(counters_needed as u64);
            state
        };

        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(device_ordinal)?;
        // NOTE(review): 64-bit counter^seed truncated to the kernel's u32 seed —
        // same caveat as dropout_philox_f32.
        let derived_seed = (rng_state.counter ^ rng_state.seed) as u32;
        let result = crate::kernels::gpu_dropout_f64(a_buf, threshold, scale, derived_seed, dev)
            .map_err(Self::map_gpu_err)?;

        // Pre-advance state, i.e. the state this call actually used.
        let gpu_rng_state = GpuRngState {
            counter: rng_state.counter,
            seed: rng_state.seed,
            offset: rng_state.offset,
            device: device_ordinal,
        };

        Ok((Self::wrap_buffer_f64(result, device_ordinal), gpu_rng_state))
    }
1501
1502 fn transpose_2d_f32(
1503 &self,
1504 a: &GpuBufferHandle,
1505 m: usize,
1506 n: usize,
1507 ) -> FerrotorchResult<GpuBufferHandle> {
1508 let a_buf = Self::unwrap_buffer(a)?;
1509 let dev = self.device(a.device_ordinal())?;
1510 let result =
1511 crate::kernels::gpu_transpose_2d(a_buf, m, n, dev).map_err(Self::map_gpu_err)?;
1512 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1513 }
1514
1515 fn permute_0213_f32(
1516 &self,
1517 a: &GpuBufferHandle,
1518 d0: usize,
1519 d1: usize,
1520 d2: usize,
1521 d3: usize,
1522 ) -> FerrotorchResult<GpuBufferHandle> {
1523 let a_buf = Self::unwrap_buffer(a)?;
1524 let dev = self.device(a.device_ordinal())?;
1525 let result = crate::kernels::gpu_permute_0213(a_buf, d0, d1, d2, d3, dev)
1526 .map_err(Self::map_gpu_err)?;
1527 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1528 }
1529
1530 fn bmm_f32(
1531 &self,
1532 a: &GpuBufferHandle,
1533 b: &GpuBufferHandle,
1534 batch: usize,
1535 m: usize,
1536 k: usize,
1537 n: usize,
1538 ) -> FerrotorchResult<GpuBufferHandle> {
1539 let a_buf = Self::unwrap_buffer(a)?;
1540 let b_buf = Self::unwrap_buffer(b)?;
1541 let dev = self.device(a.device_ordinal())?;
1542 let result = crate::blas::gpu_bmm_f32(a_buf, b_buf, batch, m, k, n, dev)
1543 .map_err(Self::map_gpu_err)?;
1544 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1545 }
1546
1547 fn bmm_f16_f32(
1548 &self,
1549 a: &GpuBufferHandle,
1550 b: &GpuBufferHandle,
1551 batch: usize,
1552 m: usize,
1553 k: usize,
1554 n: usize,
1555 ) -> FerrotorchResult<GpuBufferHandle> {
1556 let a_buf = Self::unwrap_buffer(a)?;
1557 let b_buf = Self::unwrap_buffer(b)?;
1558 let dev = self.device(a.device_ordinal())?;
1559 let result = crate::blas::gpu_bmm_f16(a_buf, b_buf, batch, m, k, n, dev)
1560 .map_err(Self::map_gpu_err)?;
1561 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1562 }
1563
1564 fn gelu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1565 let a_buf = Self::unwrap_buffer(a)?;
1566 let dev = self.device(a.device_ordinal())?;
1567 let result = crate::kernels::gpu_gelu(a_buf, dev).map_err(Self::map_gpu_err)?;
1568 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1569 }
1570
1571 fn gelu_tanh_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1572 let a_buf = Self::unwrap_buffer(a)?;
1573 let dev = self.device(a.device_ordinal())?;
1574 let result = crate::kernels::gpu_gelu_tanh(a_buf, dev).map_err(Self::map_gpu_err)?;
1575 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1576 }
1577
1578 fn gelu_erf_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1579 let a_buf = Self::unwrap_buffer(a)?;
1580 let dev = self.device(a.device_ordinal())?;
1581 let result = crate::kernels::gpu_gelu_erf(a_buf, dev).map_err(Self::map_gpu_err)?;
1582 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1583 }
1584
1585 fn layernorm_f32(
1586 &self,
1587 input: &GpuBufferHandle,
1588 weight: &GpuBufferHandle,
1589 bias: &GpuBufferHandle,
1590 rows: usize,
1591 cols: usize,
1592 eps: f32,
1593 ) -> FerrotorchResult<GpuBufferHandle> {
1594 let in_buf = Self::unwrap_buffer(input)?;
1595 let w_buf = Self::unwrap_buffer(weight)?;
1596 let b_buf = Self::unwrap_buffer(bias)?;
1597 let dev = self.device(input.device_ordinal())?;
1598 let result = crate::kernels::gpu_layernorm(in_buf, w_buf, b_buf, rows, cols, eps, dev)
1599 .map_err(Self::map_gpu_err)?;
1600 Ok(Self::wrap_buffer(result, input.device_ordinal()))
1601 }
1602
1603 fn rmsnorm_f32(
1604 &self,
1605 input: &GpuBufferHandle,
1606 weight: &GpuBufferHandle,
1607 rows: usize,
1608 cols: usize,
1609 eps: f32,
1610 ) -> FerrotorchResult<GpuBufferHandle> {
1611 let in_buf = Self::unwrap_buffer(input)?;
1612 let w_buf = Self::unwrap_buffer(weight)?;
1613 let dev = self.device(input.device_ordinal())?;
1614 let result = crate::kernels::gpu_rmsnorm(in_buf, w_buf, rows, cols, eps, dev)
1615 .map_err(Self::map_gpu_err)?;
1616 Ok(Self::wrap_buffer(result, input.device_ordinal()))
1617 }
1618
1619 fn rmsnorm_backward_f32(
1620 &self,
1621 input: &GpuBufferHandle,
1622 grad_output: &GpuBufferHandle,
1623 weight: &GpuBufferHandle,
1624 rows: usize,
1625 cols: usize,
1626 eps: f32,
1627 ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
1628 let in_buf = Self::unwrap_buffer(input)?;
1629 let go_buf = Self::unwrap_buffer(grad_output)?;
1630 let w_buf = Self::unwrap_buffer(weight)?;
1631 let dev = self.device(input.device_ordinal())?;
1632 let (gi, gw) =
1633 crate::kernels::gpu_rmsnorm_backward(in_buf, go_buf, w_buf, rows, cols, eps, dev)
1634 .map_err(Self::map_gpu_err)?;
1635 let ordinal = input.device_ordinal();
1636 Ok((Self::wrap_buffer(gi, ordinal), Self::wrap_buffer(gw, ordinal)))
1637 }
1638
1639 fn slice_write_f32(
1640 &self,
1641 src: &GpuBufferHandle,
1642 dst: &mut GpuBufferHandle,
1643 n_batch: usize,
1644 d: usize,
1645 max_len: usize,
1646 pos: usize,
1647 ) -> FerrotorchResult<()> {
1648 let src_buf = Self::unwrap_buffer(src)?;
1649 let dst_buf =
1650 dst.downcast_mut::<CudaBuffer<f32>>()
1651 .ok_or(FerrotorchError::InvalidArgument {
1652 message: "slice_write_f32: dst is not CudaBuffer<f32>".into(),
1653 })?;
1654 let dev = self.device(src.device_ordinal())?;
1655 crate::kernels::gpu_slice_write(src_buf, dst_buf, n_batch, d, max_len, pos, dev)
1656 .map_err(Self::map_gpu_err)?;
1657 Ok(())
1658 }
1659
1660 fn slice_read_f32(
1661 &self,
1662 src: &GpuBufferHandle,
1663 n_batch: usize,
1664 d: usize,
1665 len: usize,
1666 max_len: usize,
1667 ) -> FerrotorchResult<GpuBufferHandle> {
1668 let src_buf = Self::unwrap_buffer(src)?;
1669 let dev = self.device(src.device_ordinal())?;
1670 let result = crate::kernels::gpu_slice_read(src_buf, n_batch, d, len, max_len, dev)
1671 .map_err(Self::map_gpu_err)?;
1672 Ok(Self::wrap_buffer(result, src.device_ordinal()))
1673 }
1674
1675 fn embed_lookup_f32(
1676 &self,
1677 idx: &GpuBufferHandle,
1678 weight: &GpuBufferHandle,
1679 d: usize,
1680 ) -> FerrotorchResult<GpuBufferHandle> {
1681 let idx_buf = Self::unwrap_buffer(idx)?;
1682 let w_buf = Self::unwrap_buffer(weight)?;
1683 let dev = self.device(idx.device_ordinal())?;
1684 let result =
1685 crate::kernels::gpu_embed_lookup(idx_buf, w_buf, d, dev).map_err(Self::map_gpu_err)?;
1686 Ok(Self::wrap_buffer(result, idx.device_ordinal()))
1687 }
1688
1689 fn embed_lookup_batch_f32(
1690 &self,
1691 indices: &GpuBufferHandle,
1692 weight: &GpuBufferHandle,
1693 n: usize,
1694 d: usize,
1695 ) -> FerrotorchResult<GpuBufferHandle> {
1696 let idx_buf = Self::unwrap_buffer(indices)?;
1697 let w_buf = Self::unwrap_buffer(weight)?;
1698 let dev = self.device(indices.device_ordinal())?;
1699 let result = crate::kernels::gpu_embed_lookup_batch(idx_buf, w_buf, n, d, dev)
1700 .map_err(Self::map_gpu_err)?;
1701 Ok(Self::wrap_buffer(result, indices.device_ordinal()))
1702 }
1703
1704 fn scatter_add_rows_f32(
1705 &self,
1706 grad_output: &GpuBufferHandle,
1707 indices: &GpuBufferHandle,
1708 num_embeddings: usize,
1709 d: usize,
1710 ) -> FerrotorchResult<GpuBufferHandle> {
1711 let go_buf = Self::unwrap_buffer(grad_output)?;
1712 let idx_buf = Self::unwrap_buffer(indices)?;
1713 let dev = self.device(grad_output.device_ordinal())?;
1714 let result = crate::kernels::gpu_scatter_add_rows(go_buf, idx_buf, num_embeddings, d, dev)
1715 .map_err(Self::map_gpu_err)?;
1716 Ok(Self::wrap_buffer(result, grad_output.device_ordinal()))
1717 }
1718
1719 fn scale_f32(&self, a: &GpuBufferHandle, scalar: f32) -> FerrotorchResult<GpuBufferHandle> {
1720 let a_buf = Self::unwrap_buffer(a)?;
1721 let dev = self.device(a.device_ordinal())?;
1722 let result = crate::kernels::gpu_scale(a_buf, scalar, dev).map_err(Self::map_gpu_err)?;
1723 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1724 }
1725
1726 fn relu_backward_f32(
1727 &self,
1728 grad: &GpuBufferHandle,
1729 input: &GpuBufferHandle,
1730 ) -> FerrotorchResult<GpuBufferHandle> {
1731 let grad_buf = Self::unwrap_buffer(grad)?;
1732 let input_buf = Self::unwrap_buffer(input)?;
1733 let dev = self.device(grad.device_ordinal())?;
1734 let result = crate::kernels::gpu_relu_backward(grad_buf, input_buf, dev)
1735 .map_err(Self::map_gpu_err)?;
1736 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1737 }
1738
1739 fn abs_backward_f32(
1740 &self,
1741 grad: &GpuBufferHandle,
1742 input: &GpuBufferHandle,
1743 ) -> FerrotorchResult<GpuBufferHandle> {
1744 let grad_buf = Self::unwrap_buffer(grad)?;
1745 let input_buf = Self::unwrap_buffer(input)?;
1746 let dev = self.device(grad.device_ordinal())?;
1747 let result = crate::kernels::gpu_abs_backward(grad_buf, input_buf, dev)
1748 .map_err(Self::map_gpu_err)?;
1749 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1750 }
1751
1752 fn fill_f32(
1753 &self,
1754 n: usize,
1755 scalar: f32,
1756 ordinal: usize,
1757 ) -> FerrotorchResult<GpuBufferHandle> {
1758 let dev = self.device(ordinal)?;
1759 let result = crate::kernels::gpu_fill_f32(n, scalar, dev).map_err(Self::map_gpu_err)?;
1760 Ok(Self::wrap_buffer(result, ordinal))
1761 }
1762
1763 fn gelu_backward_f32(
1764 &self,
1765 grad: &GpuBufferHandle,
1766 input: &GpuBufferHandle,
1767 ) -> FerrotorchResult<GpuBufferHandle> {
1768 let grad_buf = Self::unwrap_buffer(grad)?;
1769 let input_buf = Self::unwrap_buffer(input)?;
1770 let dev = self.device(grad.device_ordinal())?;
1771 let result = crate::kernels::gpu_gelu_backward(grad_buf, input_buf, dev)
1772 .map_err(Self::map_gpu_err)?;
1773 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1774 }
1775
1776 fn gelu_backward_tanh_f32(
1777 &self,
1778 grad: &GpuBufferHandle,
1779 input: &GpuBufferHandle,
1780 ) -> FerrotorchResult<GpuBufferHandle> {
1781 let grad_buf = Self::unwrap_buffer(grad)?;
1782 let input_buf = Self::unwrap_buffer(input)?;
1783 let dev = self.device(grad.device_ordinal())?;
1784 let result = crate::kernels::gpu_gelu_backward_tanh(grad_buf, input_buf, dev)
1785 .map_err(Self::map_gpu_err)?;
1786 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1787 }
1788
1789 fn gelu_backward_erf_f32(
1790 &self,
1791 grad: &GpuBufferHandle,
1792 input: &GpuBufferHandle,
1793 ) -> FerrotorchResult<GpuBufferHandle> {
1794 let grad_buf = Self::unwrap_buffer(grad)?;
1795 let input_buf = Self::unwrap_buffer(input)?;
1796 let dev = self.device(grad.device_ordinal())?;
1797 let result = crate::kernels::gpu_gelu_backward_erf(grad_buf, input_buf, dev)
1798 .map_err(Self::map_gpu_err)?;
1799 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1800 }
1801
1802 fn cumsum_f32(
1803 &self,
1804 a: &GpuBufferHandle,
1805 outer: usize,
1806 dim_size: usize,
1807 inner: usize,
1808 ) -> FerrotorchResult<GpuBufferHandle> {
1809 let a_buf = Self::unwrap_buffer(a)?;
1810 let dev = self.device(a.device_ordinal())?;
1811 let result = crate::kernels::gpu_cumsum(a_buf, outer, dim_size, inner, dev)
1812 .map_err(Self::map_gpu_err)?;
1813 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1814 }
1815
1816 fn cumprod_f32(
1817 &self,
1818 a: &GpuBufferHandle,
1819 outer: usize,
1820 dim_size: usize,
1821 inner: usize,
1822 ) -> FerrotorchResult<GpuBufferHandle> {
1823 let a_buf = Self::unwrap_buffer(a)?;
1824 let dev = self.device(a.device_ordinal())?;
1825 let result = crate::kernels::gpu_cumprod(a_buf, outer, dim_size, inner, dev)
1826 .map_err(Self::map_gpu_err)?;
1827 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1828 }
1829
1830 fn cummax_f32(
1831 &self,
1832 a: &GpuBufferHandle,
1833 outer: usize,
1834 dim_size: usize,
1835 inner: usize,
1836 ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
1837 let a_buf = Self::unwrap_buffer(a)?;
1838 let dev = self.device(a.device_ordinal())?;
1839 let (vals, idxs) = crate::kernels::gpu_cummax(a_buf, outer, dim_size, inner, dev)
1840 .map_err(Self::map_gpu_err)?;
1841 let ord = a.device_ordinal();
1842 Ok((Self::wrap_buffer(vals, ord), Self::wrap_buffer(idxs, ord)))
1843 }
1844
1845 fn cummin_f32(
1846 &self,
1847 a: &GpuBufferHandle,
1848 outer: usize,
1849 dim_size: usize,
1850 inner: usize,
1851 ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
1852 let a_buf = Self::unwrap_buffer(a)?;
1853 let dev = self.device(a.device_ordinal())?;
1854 let (vals, idxs) = crate::kernels::gpu_cummin(a_buf, outer, dim_size, inner, dev)
1855 .map_err(Self::map_gpu_err)?;
1856 let ord = a.device_ordinal();
1857 Ok((Self::wrap_buffer(vals, ord), Self::wrap_buffer(idxs, ord)))
1858 }
1859
1860 fn logcumsumexp_f32(
1861 &self,
1862 a: &GpuBufferHandle,
1863 outer: usize,
1864 dim_size: usize,
1865 inner: usize,
1866 ) -> FerrotorchResult<GpuBufferHandle> {
1867 let a_buf = Self::unwrap_buffer(a)?;
1868 let dev = self.device(a.device_ordinal())?;
1869 let result = crate::kernels::gpu_logcumsumexp(a_buf, outer, dim_size, inner, dev)
1870 .map_err(Self::map_gpu_err)?;
1871 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1872 }
1873
1874 fn clamp_f32(
1875 &self,
1876 a: &GpuBufferHandle,
1877 min_val: f32,
1878 max_val: f32,
1879 ) -> FerrotorchResult<GpuBufferHandle> {
1880 let a_buf = Self::unwrap_buffer(a)?;
1881 let dev = self.device(a.device_ordinal())?;
1882 let result =
1883 crate::kernels::gpu_clamp(a_buf, min_val, max_val, dev).map_err(Self::map_gpu_err)?;
1884 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1885 }
1886
1887 fn silu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1888 let a_buf = Self::unwrap_buffer(a)?;
1889 let dev = self.device(a.device_ordinal())?;
1890 let result = crate::kernels::gpu_silu(a_buf, dev).map_err(Self::map_gpu_err)?;
1891 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1892 }
1893
1894 fn silu_backward_f32(
1895 &self,
1896 grad: &GpuBufferHandle,
1897 input: &GpuBufferHandle,
1898 ) -> FerrotorchResult<GpuBufferHandle> {
1899 let grad_buf = Self::unwrap_buffer(grad)?;
1900 let input_buf = Self::unwrap_buffer(input)?;
1901 let dev = self.device(grad.device_ordinal())?;
1902 let result = crate::kernels::gpu_silu_backward(grad_buf, input_buf, dev)
1903 .map_err(Self::map_gpu_err)?;
1904 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1905 }
1906
1907 fn elu_f32(&self, a: &GpuBufferHandle, alpha: f32) -> FerrotorchResult<GpuBufferHandle> {
1908 let a_buf = Self::unwrap_buffer(a)?;
1909 let dev = self.device(a.device_ordinal())?;
1910 let result = crate::kernels::gpu_elu(a_buf, alpha, dev).map_err(Self::map_gpu_err)?;
1911 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1912 }
1913
1914 fn elu_backward_f32(
1915 &self,
1916 grad: &GpuBufferHandle,
1917 input: &GpuBufferHandle,
1918 alpha: f32,
1919 ) -> FerrotorchResult<GpuBufferHandle> {
1920 let grad_buf = Self::unwrap_buffer(grad)?;
1921 let input_buf = Self::unwrap_buffer(input)?;
1922 let dev = self.device(grad.device_ordinal())?;
1923 let result = crate::kernels::gpu_elu_backward(grad_buf, input_buf, alpha, dev)
1924 .map_err(Self::map_gpu_err)?;
1925 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1926 }
1927
1928 fn mish_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1929 let a_buf = Self::unwrap_buffer(a)?;
1930 let dev = self.device(a.device_ordinal())?;
1931 let result = crate::kernels::gpu_mish(a_buf, dev).map_err(Self::map_gpu_err)?;
1932 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1933 }
1934
1935 fn mish_backward_f32(
1936 &self,
1937 grad: &GpuBufferHandle,
1938 input: &GpuBufferHandle,
1939 ) -> FerrotorchResult<GpuBufferHandle> {
1940 let grad_buf = Self::unwrap_buffer(grad)?;
1941 let input_buf = Self::unwrap_buffer(input)?;
1942 let dev = self.device(grad.device_ordinal())?;
1943 let result = crate::kernels::gpu_mish_backward(grad_buf, input_buf, dev)
1944 .map_err(Self::map_gpu_err)?;
1945 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1946 }
1947
1948 fn log_softmax_f32(
1949 &self,
1950 a: &GpuBufferHandle,
1951 cols: usize,
1952 ) -> FerrotorchResult<GpuBufferHandle> {
1953 let a_buf = Self::unwrap_buffer(a)?;
1954 let dev = self.device(a.device_ordinal())?;
1955 let result =
1956 crate::kernels::gpu_log_softmax(a_buf, cols, dev).map_err(Self::map_gpu_err)?;
1957 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1958 }
1959
1960 fn log_softmax_backward_f32(
1961 &self,
1962 grad: &GpuBufferHandle,
1963 output: &GpuBufferHandle,
1964 cols: usize,
1965 ) -> FerrotorchResult<GpuBufferHandle> {
1966 let grad_buf = Self::unwrap_buffer(grad)?;
1967 let output_buf = Self::unwrap_buffer(output)?;
1968 let dev = self.device(grad.device_ordinal())?;
1969 let result =
1970 crate::kernels::gpu_log_softmax_backward(grad_buf, output_buf, cols, dev)
1971 .map_err(Self::map_gpu_err)?;
1972 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1973 }
1974
1975 fn index_select_1d_f32(
1976 &self,
1977 input: &GpuBufferHandle,
1978 indices: &GpuBufferHandle,
1979 ) -> FerrotorchResult<GpuBufferHandle> {
1980 let input_buf = Self::unwrap_buffer(input)?;
1981 let idx_buf = Self::unwrap_buffer(indices)?;
1982 let dev = self.device(input.device_ordinal())?;
1983 let result = crate::kernels::gpu_index_select_1d(input_buf, idx_buf, dev)
1984 .map_err(Self::map_gpu_err)?;
1985 Ok(Self::wrap_buffer(result, input.device_ordinal()))
1986 }
1987
1988 fn scatter_add_1d_f32(
1989 &self,
1990 grad_output: &GpuBufferHandle,
1991 indices: &GpuBufferHandle,
1992 input_len: usize,
1993 ) -> FerrotorchResult<GpuBufferHandle> {
1994 let go_buf = Self::unwrap_buffer(grad_output)?;
1995 let idx_buf = Self::unwrap_buffer(indices)?;
1996 let dev = self.device(grad_output.device_ordinal())?;
1997 let result = crate::kernels::gpu_scatter_add_1d(go_buf, idx_buf, input_len, dev)
1998 .map_err(Self::map_gpu_err)?;
1999 Ok(Self::wrap_buffer(result, grad_output.device_ordinal()))
2000 }
2001
2002 fn masked_fill_f32(
2003 &self,
2004 input: &GpuBufferHandle,
2005 mask: &GpuBufferHandle,
2006 value: f32,
2007 ) -> FerrotorchResult<GpuBufferHandle> {
2008 let input_buf = Self::unwrap_buffer(input)?;
2009 let mask_buf = Self::unwrap_buffer(mask)?;
2010 let dev = self.device(input.device_ordinal())?;
2011 let result = crate::kernels::gpu_masked_fill(input_buf, mask_buf, value, dev)
2012 .map_err(Self::map_gpu_err)?;
2013 Ok(Self::wrap_buffer(result, input.device_ordinal()))
2014 }
2015
2016 fn masked_zero_f32(
2017 &self,
2018 grad: &GpuBufferHandle,
2019 mask: &GpuBufferHandle,
2020 ) -> FerrotorchResult<GpuBufferHandle> {
2021 let grad_buf = Self::unwrap_buffer(grad)?;
2022 let mask_buf = Self::unwrap_buffer(mask)?;
2023 let dev = self.device(grad.device_ordinal())?;
2024 let result =
2025 crate::kernels::gpu_masked_zero(grad_buf, mask_buf, dev).map_err(Self::map_gpu_err)?;
2026 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
2027 }
2028
2029 fn sigmoid_backward_f32(
2030 &self,
2031 grad: &GpuBufferHandle,
2032 output: &GpuBufferHandle,
2033 ) -> FerrotorchResult<GpuBufferHandle> {
2034 let grad_buf = Self::unwrap_buffer(grad)?;
2035 let output_buf = Self::unwrap_buffer(output)?;
2036 let dev = self.device(grad.device_ordinal())?;
2037 let result = crate::kernels::gpu_sigmoid_backward(grad_buf, output_buf, dev)
2038 .map_err(Self::map_gpu_err)?;
2039 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
2040 }
2041
2042 fn tanh_backward_f32(
2043 &self,
2044 grad: &GpuBufferHandle,
2045 output: &GpuBufferHandle,
2046 ) -> FerrotorchResult<GpuBufferHandle> {
2047 let grad_buf = Self::unwrap_buffer(grad)?;
2048 let output_buf = Self::unwrap_buffer(output)?;
2049 let dev = self.device(grad.device_ordinal())?;
2050 let result = crate::kernels::gpu_tanh_backward(grad_buf, output_buf, dev)
2051 .map_err(Self::map_gpu_err)?;
2052 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
2053 }
2054
2055 fn softmax_backward_f32(
2056 &self,
2057 grad: &GpuBufferHandle,
2058 output: &GpuBufferHandle,
2059 cols: usize,
2060 ) -> FerrotorchResult<GpuBufferHandle> {
2061 let grad_buf = Self::unwrap_buffer(grad)?;
2062 let output_buf = Self::unwrap_buffer(output)?;
2063 let dev = self.device(grad.device_ordinal())?;
2064 let result = crate::kernels::gpu_softmax_backward(grad_buf, output_buf, cols, dev)
2065 .map_err(Self::map_gpu_err)?;
2066 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
2067 }
2068
2069 fn layernorm_backward_f32(
2070 &self,
2071 input: &GpuBufferHandle,
2072 grad_output: &GpuBufferHandle,
2073 weight: &GpuBufferHandle,
2074 rows: usize,
2075 cols: usize,
2076 eps: f32,
2077 ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
2078 let in_buf = Self::unwrap_buffer(input)?;
2079 let go_buf = Self::unwrap_buffer(grad_output)?;
2080 let w_buf = Self::unwrap_buffer(weight)?;
2081 let dev = self.device(input.device_ordinal())?;
2082 let (gi, gw, gb) =
2083 crate::kernels::gpu_layernorm_backward(in_buf, go_buf, w_buf, rows, cols, eps, dev)
2084 .map_err(Self::map_gpu_err)?;
2085 let ordinal = input.device_ordinal();
2086 Ok((
2087 Self::wrap_buffer(gi, ordinal),
2088 Self::wrap_buffer(gw, ordinal),
2089 Self::wrap_buffer(gb, ordinal),
2090 ))
2091 }
2092
2093 fn sum_axis_f32(
2094 &self,
2095 a: &GpuBufferHandle,
2096 shape: &[usize],
2097 axis: usize,
2098 ) -> FerrotorchResult<GpuBufferHandle> {
2099 let a_buf = Self::unwrap_buffer(a)?;
2100 let dev = self.device(a.device_ordinal())?;
2101 let outer: usize = shape[..axis].iter().product();
2102 let axis_size = shape[axis];
2103 let inner: usize = shape[axis + 1..].iter().product::<usize>().max(1);
2104 let result = crate::kernels::gpu_sum_axis(a_buf, outer, axis_size, inner, dev)
2105 .map_err(Self::map_gpu_err)?;
2106 Ok(Self::wrap_buffer(result, a.device_ordinal()))
2107 }
2108
2109 fn matmul_f16_f32(
2110 &self,
2111 a: &GpuBufferHandle,
2112 b: &GpuBufferHandle,
2113 m: usize,
2114 k: usize,
2115 n: usize,
2116 ) -> FerrotorchResult<GpuBufferHandle> {
2117 let a_buf = Self::unwrap_buffer(a)?;
2118 let b_buf = Self::unwrap_buffer(b)?;
2119 let dev = self.device(a.device_ordinal())?;
2120 let result =
2121 crate::blas::gpu_matmul_f16(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
2122 Ok(Self::wrap_buffer(result, a.device_ordinal()))
2123 }
2124
2125 fn save_rng_state(&self, device: usize) -> FerrotorchResult<GpuRngState> {
2126 let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
2127 FerrotorchError::InvalidArgument {
2128 message: "failed to lock CUDA RNG manager".into(),
2129 }
2130 })?;
2131 let state = mgr.get_rng_state(device);
2132 Ok(GpuRngState {
2133 counter: state.counter,
2134 seed: state.seed,
2135 offset: state.offset,
2136 device,
2137 })
2138 }
2139
2140 fn restore_rng_state(&self, state: GpuRngState) -> FerrotorchResult<()> {
2141 let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
2142 FerrotorchError::InvalidArgument {
2143 message: "failed to lock CUDA RNG manager".into(),
2144 }
2145 })?;
2146 mgr.set_rng_state(
2147 state.device,
2148 crate::rng::PhiloxState {
2149 counter: state.counter,
2150 seed: state.seed,
2151 offset: state.offset,
2152 },
2153 );
2154 Ok(())
2155 }
2156
2157 fn strided_split_f32(
2158 &self,
2159 input: &GpuBufferHandle,
2160 total_along_axis: usize,
2161 split_offset: usize,
2162 split_size: usize,
2163 inner_size: usize,
2164 n: usize,
2165 ) -> FerrotorchResult<GpuBufferHandle> {
2166 let in_buf = Self::unwrap_buffer(input)?;
2167 let dev = self.device(input.device_ordinal())?;
2168 let result = crate::kernels::gpu_strided_split(
2169 in_buf,
2170 total_along_axis,
2171 split_offset,
2172 split_size,
2173 inner_size,
2174 n,
2175 dev,
2176 )
2177 .map_err(Self::map_gpu_err)?;
2178 Ok(Self::wrap_buffer(result, input.device_ordinal()))
2179 }
2180
2181 fn strided_copy_f32(
2182 &self,
2183 input: &GpuBufferHandle,
2184 out_shape: &[usize],
2185 src_strides: &[isize],
2186 src_offset: usize,
2187 ) -> FerrotorchResult<GpuBufferHandle> {
2188 let in_buf = Self::unwrap_buffer(input)?;
2189 let dev = self.device(input.device_ordinal())?;
2190 let result =
2191 crate::kernels::gpu_strided_copy(in_buf, out_shape, src_strides, src_offset, dev)
2192 .map_err(Self::map_gpu_err)?;
2193 Ok(Self::wrap_buffer(result, input.device_ordinal()))
2194 }
2195
2196 fn strided_copy_f64(
2197 &self,
2198 input: &GpuBufferHandle,
2199 out_shape: &[usize],
2200 src_strides: &[isize],
2201 src_offset: usize,
2202 ) -> FerrotorchResult<GpuBufferHandle> {
2203 let in_buf = Self::unwrap_buffer_f64(input)?;
2204 let dev = self.device(input.device_ordinal())?;
2205 let result = crate::kernels::gpu_strided_copy_f64(
2206 in_buf,
2207 out_shape,
2208 src_strides,
2209 src_offset,
2210 dev,
2211 )
2212 .map_err(Self::map_gpu_err)?;
2213 Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
2214 }
2215
2216 fn strided_cat_f32(
2217 &self,
2218 input: &GpuBufferHandle,
2219 output: &mut GpuBufferHandle,
2220 total_along_axis: usize,
2221 cat_offset: usize,
2222 part_size: usize,
2223 inner_size: usize,
2224 n: usize,
2225 ) -> FerrotorchResult<()> {
2226 let in_buf = Self::unwrap_buffer(input)?;
2227 let dev = self.device(input.device_ordinal())?;
2228 let out_buf =
2229 output
2230 .downcast_mut::<CudaBuffer<f32>>()
2231 .ok_or(FerrotorchError::InvalidArgument {
2232 message: "strided_cat_f32: output is not CudaBuffer<f32>".into(),
2233 })?;
2234 crate::kernels::gpu_strided_cat(
2235 in_buf,
2236 out_buf,
2237 total_along_axis,
2238 cat_offset,
2239 part_size,
2240 inner_size,
2241 n,
2242 dev,
2243 )
2244 .map_err(Self::map_gpu_err)?;
2245 Ok(())
2246 }
2247
2248 fn svd_f32(
2251 &self,
2252 a: &GpuBufferHandle,
2253 m: usize,
2254 n: usize,
2255 ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
2256 let a_buf = Self::unwrap_buffer(a)?;
2257 let dev = self.device(a.device_ordinal())?;
2258 let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2259 let (u, s, vt) =
2260 crate::cusolver::gpu_svd_f32(&a_host, m, n, dev).map_err(Self::map_gpu_err)?;
2261 let u_buf = crate::transfer::cpu_to_gpu(&u, dev).map_err(Self::map_gpu_err)?;
2262 let s_buf = crate::transfer::cpu_to_gpu(&s, dev).map_err(Self::map_gpu_err)?;
2263 let vt_buf = crate::transfer::cpu_to_gpu(&vt, dev).map_err(Self::map_gpu_err)?;
2264 let ord = a.device_ordinal();
2265 Ok((
2266 Self::wrap_buffer(u_buf, ord),
2267 Self::wrap_buffer(s_buf, ord),
2268 Self::wrap_buffer(vt_buf, ord),
2269 ))
2270 }
2271
2272 fn svd_f64(
2273 &self,
2274 a: &GpuBufferHandle,
2275 m: usize,
2276 n: usize,
2277 ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
2278 let a_buf = Self::unwrap_buffer_f64(a)?;
2279 let dev = self.device(a.device_ordinal())?;
2280 let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2281 let (u, s, vt) =
2282 crate::cusolver::gpu_svd_f64(&a_host, m, n, dev).map_err(Self::map_gpu_err)?;
2283 let u_buf = crate::transfer::cpu_to_gpu(&u, dev).map_err(Self::map_gpu_err)?;
2284 let s_buf = crate::transfer::cpu_to_gpu(&s, dev).map_err(Self::map_gpu_err)?;
2285 let vt_buf = crate::transfer::cpu_to_gpu(&vt, dev).map_err(Self::map_gpu_err)?;
2286 let ord = a.device_ordinal();
2287 Ok((
2288 Self::wrap_buffer_f64(u_buf, ord),
2289 Self::wrap_buffer_f64(s_buf, ord),
2290 Self::wrap_buffer_f64(vt_buf, ord),
2291 ))
2292 }
2293
2294 fn cholesky_f32(&self, a: &GpuBufferHandle, n: usize) -> FerrotorchResult<GpuBufferHandle> {
2295 let a_buf = Self::unwrap_buffer(a)?;
2296 let dev = self.device(a.device_ordinal())?;
2297 let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2298 let l = crate::cusolver::gpu_cholesky_f32(&a_host, n, dev).map_err(Self::map_gpu_err)?;
2299 let l_buf = crate::transfer::cpu_to_gpu(&l, dev).map_err(Self::map_gpu_err)?;
2300 Ok(Self::wrap_buffer(l_buf, a.device_ordinal()))
2301 }
2302
2303 fn cholesky_f64(&self, a: &GpuBufferHandle, n: usize) -> FerrotorchResult<GpuBufferHandle> {
2304 let a_buf = Self::unwrap_buffer_f64(a)?;
2305 let dev = self.device(a.device_ordinal())?;
2306 let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2307 let l = crate::cusolver::gpu_cholesky_f64(&a_host, n, dev).map_err(Self::map_gpu_err)?;
2308 let l_buf = crate::transfer::cpu_to_gpu(&l, dev).map_err(Self::map_gpu_err)?;
2309 Ok(Self::wrap_buffer_f64(l_buf, a.device_ordinal()))
2310 }
2311
2312 fn solve_f32(
2313 &self,
2314 a: &GpuBufferHandle,
2315 b: &GpuBufferHandle,
2316 n: usize,
2317 nrhs: usize,
2318 ) -> FerrotorchResult<GpuBufferHandle> {
2319 let a_buf = Self::unwrap_buffer(a)?;
2320 let b_buf = Self::unwrap_buffer(b)?;
2321 let dev = self.device(a.device_ordinal())?;
2322 let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2323 let b_host = crate::transfer::gpu_to_cpu(b_buf, dev).map_err(Self::map_gpu_err)?;
2324 let x =
2325 crate::cusolver::gpu_solve_f32(&a_host, &b_host, n, nrhs, dev)
2326 .map_err(Self::map_gpu_err)?;
2327 let x_buf = crate::transfer::cpu_to_gpu(&x, dev).map_err(Self::map_gpu_err)?;
2328 Ok(Self::wrap_buffer(x_buf, a.device_ordinal()))
2329 }
2330
2331 fn solve_f64(
2332 &self,
2333 a: &GpuBufferHandle,
2334 b: &GpuBufferHandle,
2335 n: usize,
2336 nrhs: usize,
2337 ) -> FerrotorchResult<GpuBufferHandle> {
2338 let a_buf = Self::unwrap_buffer_f64(a)?;
2339 let b_buf = Self::unwrap_buffer_f64(b)?;
2340 let dev = self.device(a.device_ordinal())?;
2341 let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2342 let b_host = crate::transfer::gpu_to_cpu(b_buf, dev).map_err(Self::map_gpu_err)?;
2343 let x =
2344 crate::cusolver::gpu_solve_f64(&a_host, &b_host, n, nrhs, dev)
2345 .map_err(Self::map_gpu_err)?;
2346 let x_buf = crate::transfer::cpu_to_gpu(&x, dev).map_err(Self::map_gpu_err)?;
2347 Ok(Self::wrap_buffer_f64(x_buf, a.device_ordinal()))
2348 }
2349
2350 fn qr_f32(
2351 &self,
2352 a: &GpuBufferHandle,
2353 m: usize,
2354 n: usize,
2355 ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
2356 let a_buf = Self::unwrap_buffer(a)?;
2357 let dev = self.device(a.device_ordinal())?;
2358 let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2359 let (q, r) =
2360 crate::cusolver::gpu_qr_f32(&a_host, m, n, dev).map_err(Self::map_gpu_err)?;
2361 let q_buf = crate::transfer::cpu_to_gpu(&q, dev).map_err(Self::map_gpu_err)?;
2362 let r_buf = crate::transfer::cpu_to_gpu(&r, dev).map_err(Self::map_gpu_err)?;
2363 let ord = a.device_ordinal();
2364 Ok((Self::wrap_buffer(q_buf, ord), Self::wrap_buffer(r_buf, ord)))
2365 }
2366
2367 fn qr_f64(
2368 &self,
2369 a: &GpuBufferHandle,
2370 m: usize,
2371 n: usize,
2372 ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
2373 let a_buf = Self::unwrap_buffer_f64(a)?;
2374 let dev = self.device(a.device_ordinal())?;
2375 let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
2376 let (q, r) =
2377 crate::cusolver::gpu_qr_f64(&a_host, m, n, dev).map_err(Self::map_gpu_err)?;
2378 let q_buf = crate::transfer::cpu_to_gpu(&q, dev).map_err(Self::map_gpu_err)?;
2379 let r_buf = crate::transfer::cpu_to_gpu(&r, dev).map_err(Self::map_gpu_err)?;
2380 let ord = a.device_ordinal();
2381 Ok((
2382 Self::wrap_buffer_f64(q_buf, ord),
2383 Self::wrap_buffer_f64(r_buf, ord),
2384 ))
2385 }
2386}
2387
2388pub fn get_cuda_device() -> FerrotorchResult<Arc<GpuDevice>> {
2399 let backend =
2400 ferrotorch_core::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
2401 let cuda_backend = backend.as_any().downcast_ref::<CudaBackendImpl>().ok_or(
2404 FerrotorchError::InvalidArgument {
2405 message: "registered GPU backend is not CudaBackendImpl".into(),
2406 },
2407 )?;
2408 Ok(Arc::clone(cuda_backend.default_device()?))
2409}
2410
2411pub fn init_cuda_backend() -> FerrotorchResult<()> {
2425 if ferrotorch_core::gpu_dispatch::has_gpu_backend() {
2427 return Ok(());
2428 }
2429 let backend = CudaBackendImpl::new()?;
2430 let _ = ferrotorch_core::gpu_dispatch::register_gpu_backend(Box::new(backend));
2434 Ok(())
2435}
2436
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
    use super::*;
    use ferrotorch_core::gpu_dispatch;

    /// Serialize f32 values to native-endian bytes without `unsafe`.
    fn f32s_to_bytes(vals: &[f32]) -> Vec<u8> {
        vals.iter().flat_map(|v| v.to_ne_bytes()).collect()
    }

    /// Deserialize native-endian bytes into f32 values without `unsafe`.
    ///
    /// The previous implementation cast the byte pointer to `*const f32`
    /// and built a slice from it; a `Vec<u8>` allocation only guarantees
    /// 1-byte alignment, so that read was undefined behavior whenever the
    /// buffer was not 4-byte aligned. `chunks_exact` + `from_ne_bytes` is
    /// alignment-independent and fully safe.
    fn bytes_to_f32s(bytes: &[u8]) -> Vec<f32> {
        bytes
            .chunks_exact(std::mem::size_of::<f32>())
            .map(|c| f32::from_ne_bytes(c.try_into().expect("chunk is 4 bytes")))
            .collect()
    }

    /// Register the CUDA backend once for the whole test run.
    fn ensure_init() {
        if !gpu_dispatch::has_gpu_backend() {
            init_cuda_backend().expect("init_cuda_backend");
        }
    }

    #[test]
    fn test_init_cuda_backend() {
        ensure_init();
        assert!(gpu_dispatch::has_gpu_backend());
    }

    #[test]
    fn test_gpu_backend_returns_some() {
        ensure_init();
        assert!(gpu_dispatch::gpu_backend().is_some());
    }

    #[test]
    fn test_roundtrip_cpu_gpu_cpu() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");

        let host: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let bytes = f32s_to_bytes(&host);

        let handle = backend.cpu_to_gpu(&bytes, 4, 0).expect("cpu_to_gpu");
        assert_eq!(handle.len(), 5);
        assert_eq!(handle.device_ordinal(), 0);

        let back_bytes = backend.gpu_to_cpu(&handle).expect("gpu_to_cpu");
        assert_eq!(bytes_to_f32s(&back_bytes), host);
    }

    #[test]
    fn test_add_f32() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");

        let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let b_data: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
        let expected: Vec<f32> = vec![11.0, 22.0, 33.0, 44.0];

        let a_handle = backend
            .cpu_to_gpu(&f32s_to_bytes(&a_data), 4, 0)
            .expect("cpu_to_gpu a");
        let b_handle = backend
            .cpu_to_gpu(&f32s_to_bytes(&b_data), 4, 0)
            .expect("cpu_to_gpu b");

        let result = backend.add_f32(&a_handle, &b_handle).expect("add_f32");
        assert_eq!(result.len(), 4);

        let result_f32 = bytes_to_f32s(&backend.gpu_to_cpu(&result).expect("gpu_to_cpu"));
        for (i, (&got, &exp)) in result_f32.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-6,
                "element {i}: got {got}, expected {exp}",
            );
        }
    }

    #[test]
    fn test_matmul_f32() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");

        // 2x3 times 3x2 -> 2x2.
        let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let expected: Vec<f32> = vec![58.0, 64.0, 139.0, 154.0];

        let a_handle = backend
            .cpu_to_gpu(&f32s_to_bytes(&a_data), 4, 0)
            .expect("cpu_to_gpu a");
        let b_handle = backend
            .cpu_to_gpu(&f32s_to_bytes(&b_data), 4, 0)
            .expect("cpu_to_gpu b");

        let result = backend
            .matmul_f32(&a_handle, &b_handle, 2, 3, 2)
            .expect("matmul_f32");
        assert_eq!(result.len(), 4);

        let result_f32 = bytes_to_f32s(&backend.gpu_to_cpu(&result).expect("gpu_to_cpu"));
        for (i, (&got, &exp)) in result_f32.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-3,
                "element {i}: got {got}, expected {exp}",
            );
        }
    }
}