use std::sync::Arc;

use ferrotorch_core::error::{FerrotorchError, FerrotorchResult};
use ferrotorch_core::gpu_dispatch::{GpuBackend, GpuBufferHandle, GpuRngState};

use crate::buffer::CudaBuffer;
use crate::device::GpuDevice;

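/// CUDA implementation of the `GpuBackend` dispatch trait from `ferrotorch_core`.
///
/// Holds one `Arc<GpuDevice>` per initialized CUDA ordinal (currently only
/// device 0).
///
/// A minimal usage sketch (assumes a CUDA-capable machine; error handling
/// elided, not compiled as a doctest):
///
/// ```ignore
/// let backend = CudaBackendImpl::new()?;
/// let xs = [1.0f32, 2.0, 3.0, 4.0];
/// // Upload raw bytes with elem_size = 4 (f32) onto device 0.
/// let bytes: Vec<u8> = xs.iter().flat_map(|v| v.to_ne_bytes()).collect();
/// let a = backend.cpu_to_gpu(&bytes, 4, 0)?;
/// let b = backend.cpu_to_gpu(&bytes, 4, 0)?;
/// let sum = backend.add_f32(&a, &b)?;
/// let out_bytes = backend.gpu_to_cpu(&sum)?; // native-endian f32 bytes
/// ```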
pub struct CudaBackendImpl {
    devices: Vec<Arc<GpuDevice>>,
}

impl CudaBackendImpl {
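    /// Initializes CUDA device 0; driver failures are surfaced as
    /// `InvalidArgument` errors.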
    pub fn new() -> FerrotorchResult<Self> {
        let device = Arc::new(
            GpuDevice::new(0).map_err(|e| FerrotorchError::InvalidArgument {
                message: format!("CUDA init failed: {e}"),
            })?,
        );
        Ok(Self {
            devices: vec![device],
        })
    }

    pub fn default_device(&self) -> FerrotorchResult<&Arc<GpuDevice>> {
        self.device(0)
    }

    fn device(&self, ordinal: usize) -> FerrotorchResult<&Arc<GpuDevice>> {
        self.devices
            .get(ordinal)
            .ok_or(FerrotorchError::InvalidArgument {
                message: format!("CUDA device {ordinal} not available"),
            })
    }

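    // Boxing helpers: wrap a typed `CudaBuffer<T>` into the type-erased
    // `GpuBufferHandle`, recording the owning device ordinal and element count.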
    fn wrap_buffer(buf: CudaBuffer<f32>, ordinal: usize) -> GpuBufferHandle {
        let len = buf.len();
        GpuBufferHandle::new(Box::new(buf), ordinal, len)
    }

    fn wrap_buffer_f64(buf: CudaBuffer<f64>, ordinal: usize) -> GpuBufferHandle {
        let len = buf.len();
        GpuBufferHandle::new(Box::new(buf), ordinal, len)
    }

    fn unwrap_buffer(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<f32>> {
        handle
            .downcast_ref::<CudaBuffer<f32>>()
            .ok_or(FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<f32>".into(),
            })
    }

    fn unwrap_buffer_mut(handle: &mut GpuBufferHandle) -> FerrotorchResult<&mut CudaBuffer<f32>> {
        handle
            .downcast_mut::<CudaBuffer<f32>>()
            .ok_or(FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<f32>".into(),
            })
    }

    fn unwrap_buffer_f64_mut(handle: &mut GpuBufferHandle) -> FerrotorchResult<&mut CudaBuffer<f64>> {
        handle
            .downcast_mut::<CudaBuffer<f64>>()
            .ok_or(FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<f64>".into(),
            })
    }

    fn unwrap_buffer_f64(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<f64>> {
        handle
            .downcast_ref::<CudaBuffer<f64>>()
            .ok_or(FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<f64>".into(),
            })
    }

    fn map_gpu_err(e: crate::error::GpuError) -> FerrotorchError {
        FerrotorchError::InvalidArgument {
            message: format!("{e}"),
        }
    }
}

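// `GpuBackend` is the type-erased dispatch trait from `ferrotorch_core`. Each
// method below downcasts the incoming handles to typed CUDA buffers, launches
// the matching kernel or cuBLAS call, and re-wraps the result on the same
// device ordinal.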
impl GpuBackend for CudaBackendImpl {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

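    // The raw-pointer accessors return a null pointer (rather than an error)
    // when the handle is not a CUDA f32/f64 buffer or the device lookup fails;
    // callers must check for null.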
    fn raw_device_ptr(&self, handle: &GpuBufferHandle) -> *const std::ffi::c_void {
        use cudarc::driver::DevicePtr;
        let dev = match self.device(handle.device_ordinal()) {
            Ok(d) => d,
            Err(_) => return std::ptr::null(),
        };
        let stream = dev.stream();
        if let Ok(buf) = Self::unwrap_buffer(handle) {
            let (ptr, _sync) = buf.inner().device_ptr(&stream);
            ptr as *const std::ffi::c_void
        } else if let Ok(buf) = Self::unwrap_buffer_f64(handle) {
            let (ptr, _sync) = buf.inner().device_ptr(&stream);
            ptr as *const std::ffi::c_void
        } else {
            std::ptr::null()
        }
    }

    fn raw_device_ptr_mut(&self, handle: &mut GpuBufferHandle) -> *mut std::ffi::c_void {
        use cudarc::driver::DevicePtrMut;
        let ordinal = handle.device_ordinal();
        let dev = match self.device(ordinal) {
            Ok(d) => d,
            Err(_) => return std::ptr::null_mut(),
        };
        let stream = dev.stream();
        if let Some(buf) = handle.downcast_mut::<CudaBuffer<f32>>() {
            let (ptr, _sync) = buf.inner_mut().device_ptr_mut(&stream);
            ptr as *mut std::ffi::c_void
        } else if let Some(buf) = handle.downcast_mut::<CudaBuffer<f64>>() {
            let (ptr, _sync) = buf.inner_mut().device_ptr_mut(&stream);
            ptr as *mut std::ffi::c_void
        } else {
            std::ptr::null_mut()
        }
    }

    fn buffer_elem_size(&self, handle: &GpuBufferHandle) -> usize {
        if Self::unwrap_buffer(handle).is_ok() {
            4
        } else if Self::unwrap_buffer_f64(handle).is_ok() {
            8
        } else {
            0
        }
    }

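    // Host-to-device uploads receive the tensor as raw bytes plus an element
    // size and reinterpret the byte slice as f32/f64 in place; this assumes the
    // caller's allocation is suitably aligned for the element type.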
    fn cpu_to_gpu(
        &self,
        data: &[u8],
        elem_size: usize,
        device: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let dev = self.device(device)?;
        match elem_size {
            4 => {
                let count = data.len() / 4;
                let f32_data: &[f32] =
                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, count) };
                let buf = crate::transfer::cpu_to_gpu(f32_data, dev).map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer(buf, device))
            }
            8 => {
                let count = data.len() / 8;
                let f64_data: &[f64] =
                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f64, count) };
                let buf = crate::transfer::cpu_to_gpu(f64_data, dev).map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer_f64(buf, device))
            }
            other => Err(FerrotorchError::InvalidArgument {
                message: format!("cpu_to_gpu: unsupported elem_size {other} (expected 4 or 8)"),
            }),
        }
    }

    fn cpu_to_gpu_pinned(
        &self,
        data: &[u8],
        elem_size: usize,
        device: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let dev = self.device(device)?;
        match elem_size {
            4 => {
                let count = data.len() / 4;
                let f32_data: &[f32] =
                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, count) };
                let buf = crate::transfer::cpu_to_gpu_pinned(f32_data, dev)
                    .map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer(buf, device))
            }
            8 => {
                let count = data.len() / 8;
                let f64_data: &[f64] =
                    unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f64, count) };
                let buf = crate::transfer::cpu_to_gpu_pinned(f64_data, dev)
                    .map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer_f64(buf, device))
            }
            other => Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "cpu_to_gpu_pinned: unsupported elem_size {other} (expected 4 or 8)"
                ),
            }),
        }
    }

    fn gpu_to_cpu(&self, handle: &GpuBufferHandle) -> FerrotorchResult<Vec<u8>> {
        let dev = self.device(handle.device_ordinal())?;

        if let Ok(buf) = Self::unwrap_buffer(handle) {
            let f32_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;
            // Copy into a fresh byte vector: rebuilding a Vec<u8> over the f32
            // allocation with Vec::from_raw_parts would violate the allocator's
            // alignment contract.
            let bytes: Vec<u8> = f32_data.iter().flat_map(|v| v.to_ne_bytes()).collect();
            Ok(bytes)
        } else if let Ok(buf) = Self::unwrap_buffer_f64(handle) {
            let f64_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;
            let bytes: Vec<u8> = f64_data.iter().flat_map(|v| v.to_ne_bytes()).collect();
            Ok(bytes)
        } else {
            Err(FerrotorchError::InvalidArgument {
                message: "gpu_to_cpu: handle is neither CudaBuffer<f32> nor CudaBuffer<f64>".into(),
            })
        }
    }
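
    // `clone_buffer` duplicates a device buffer by round-tripping through host
    // memory (download, then re-upload); there is no device-to-device copy here.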
    fn clone_buffer(&self, handle: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let bytes = self.gpu_to_cpu(handle)?;
        let elem_size = if handle.downcast_ref::<CudaBuffer<f64>>().is_some() {
            8
        } else {
            4
        };
        self.cpu_to_gpu(&bytes, elem_size, handle.device_ordinal())
    }

    fn alloc_zeros(
        &self,
        len: usize,
        elem_size: usize,
        device: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let dev = self.device(device)?;
        match elem_size {
            4 => {
                let buf = crate::transfer::alloc_zeros_f32(len, dev).map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer(buf, device))
            }
            8 => {
                let buf = crate::transfer::alloc_zeros_f64(len, dev).map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer_f64(buf, device))
            }
            other => Err(FerrotorchError::InvalidArgument {
                message: format!("alloc_zeros: unsupported elem_size {other} (expected 4 or 8)"),
            }),
        }
    }

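    // Element-wise f32 ops: unwrap the operands, launch the matching kernel
    // from `crate::kernels`, and wrap the output on `a`'s device ordinal. The
    // operands are assumed to be same-length buffers on the same device; this
    // is not re-validated here.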
319 fn add_f32(
322 &self,
323 a: &GpuBufferHandle,
324 b: &GpuBufferHandle,
325 ) -> FerrotorchResult<GpuBufferHandle> {
326 let a_buf = Self::unwrap_buffer(a)?;
327 let b_buf = Self::unwrap_buffer(b)?;
328 let dev = self.device(a.device_ordinal())?;
329 let result = crate::kernels::gpu_add(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
330 Ok(Self::wrap_buffer(result, a.device_ordinal()))
331 }
332
333 fn sub_f32(
334 &self,
335 a: &GpuBufferHandle,
336 b: &GpuBufferHandle,
337 ) -> FerrotorchResult<GpuBufferHandle> {
338 let a_buf = Self::unwrap_buffer(a)?;
339 let b_buf = Self::unwrap_buffer(b)?;
340 let dev = self.device(a.device_ordinal())?;
341 let result = crate::kernels::gpu_sub(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
342 Ok(Self::wrap_buffer(result, a.device_ordinal()))
343 }
344
345 fn mul_f32(
346 &self,
347 a: &GpuBufferHandle,
348 b: &GpuBufferHandle,
349 ) -> FerrotorchResult<GpuBufferHandle> {
350 let a_buf = Self::unwrap_buffer(a)?;
351 let b_buf = Self::unwrap_buffer(b)?;
352 let dev = self.device(a.device_ordinal())?;
353 let result = crate::kernels::gpu_mul(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
354 Ok(Self::wrap_buffer(result, a.device_ordinal()))
355 }
356
357 fn neg_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
358 let a_buf = Self::unwrap_buffer(a)?;
359 let dev = self.device(a.device_ordinal())?;
360 let result = crate::kernels::gpu_neg(a_buf, dev).map_err(Self::map_gpu_err)?;
361 Ok(Self::wrap_buffer(result, a.device_ordinal()))
362 }
363
364 fn relu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
365 let a_buf = Self::unwrap_buffer(a)?;
366 let dev = self.device(a.device_ordinal())?;
367 let result = crate::kernels::gpu_relu(a_buf, dev).map_err(Self::map_gpu_err)?;
368 Ok(Self::wrap_buffer(result, a.device_ordinal()))
369 }
370
371 fn div_f32(
372 &self,
373 a: &GpuBufferHandle,
374 b: &GpuBufferHandle,
375 ) -> FerrotorchResult<GpuBufferHandle> {
376 let a_buf = Self::unwrap_buffer(a)?;
377 let b_buf = Self::unwrap_buffer(b)?;
378 let dev = self.device(a.device_ordinal())?;
379 let result = crate::kernels::gpu_div(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
380 Ok(Self::wrap_buffer(result, a.device_ordinal()))
381 }
382
383 fn exp_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
384 let a_buf = Self::unwrap_buffer(a)?;
385 let dev = self.device(a.device_ordinal())?;
386 let result = crate::kernels::gpu_exp(a_buf, dev).map_err(Self::map_gpu_err)?;
387 Ok(Self::wrap_buffer(result, a.device_ordinal()))
388 }
389
390 fn log_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
391 let a_buf = Self::unwrap_buffer(a)?;
392 let dev = self.device(a.device_ordinal())?;
393 let result = crate::kernels::gpu_log(a_buf, dev).map_err(Self::map_gpu_err)?;
394 Ok(Self::wrap_buffer(result, a.device_ordinal()))
395 }
396
397 fn sqrt_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
398 let a_buf = Self::unwrap_buffer(a)?;
399 let dev = self.device(a.device_ordinal())?;
400 let result = crate::kernels::gpu_sqrt(a_buf, dev).map_err(Self::map_gpu_err)?;
401 Ok(Self::wrap_buffer(result, a.device_ordinal()))
402 }
403
404 fn pow_f32(&self, a: &GpuBufferHandle, exponent: f32) -> FerrotorchResult<GpuBufferHandle> {
405 let a_buf = Self::unwrap_buffer(a)?;
406 let dev = self.device(a.device_ordinal())?;
407 let result =
408 crate::kernels::gpu_pow(a_buf, exponent, dev).map_err(Self::map_gpu_err)?;
409 Ok(Self::wrap_buffer(result, a.device_ordinal()))
410 }
411
412 fn abs_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
413 let a_buf = Self::unwrap_buffer(a)?;
414 let dev = self.device(a.device_ordinal())?;
415 let result = crate::kernels::gpu_abs(a_buf, dev).map_err(Self::map_gpu_err)?;
416 Ok(Self::wrap_buffer(result, a.device_ordinal()))
417 }
418
419 fn sigmoid_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
420 let a_buf = Self::unwrap_buffer(a)?;
421 let dev = self.device(a.device_ordinal())?;
422 let result = crate::kernels::gpu_sigmoid(a_buf, dev).map_err(Self::map_gpu_err)?;
423 Ok(Self::wrap_buffer(result, a.device_ordinal()))
424 }
425
426 fn tanh_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
427 let a_buf = Self::unwrap_buffer(a)?;
428 let dev = self.device(a.device_ordinal())?;
429 let result = crate::kernels::gpu_tanh(a_buf, dev).map_err(Self::map_gpu_err)?;
430 Ok(Self::wrap_buffer(result, a.device_ordinal()))
431 }
432
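    // The f64 element-wise ops mirror the f32 set above, dispatching to the
    // *_f64 kernel variants.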
433 fn add_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
438 let a_buf = Self::unwrap_buffer_f64(a)?;
439 let b_buf = Self::unwrap_buffer_f64(b)?;
440 let dev = self.device(a.device_ordinal())?;
441 let result = crate::kernels::gpu_add_f64(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
442 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
443 }
444
445 fn sub_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
446 let a_buf = Self::unwrap_buffer_f64(a)?;
447 let b_buf = Self::unwrap_buffer_f64(b)?;
448 let dev = self.device(a.device_ordinal())?;
449 let result = crate::kernels::gpu_sub_f64(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
450 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
451 }
452
453 fn mul_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
454 let a_buf = Self::unwrap_buffer_f64(a)?;
455 let b_buf = Self::unwrap_buffer_f64(b)?;
456 let dev = self.device(a.device_ordinal())?;
457 let result = crate::kernels::gpu_mul_f64(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
458 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
459 }
460
461 fn div_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
462 let a_buf = Self::unwrap_buffer_f64(a)?;
463 let b_buf = Self::unwrap_buffer_f64(b)?;
464 let dev = self.device(a.device_ordinal())?;
465 let result = crate::kernels::gpu_div_f64(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
466 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
467 }
468
469 fn neg_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
470 let a_buf = Self::unwrap_buffer_f64(a)?;
471 let dev = self.device(a.device_ordinal())?;
472 let result = crate::kernels::gpu_neg_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
473 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
474 }
475
476 fn relu_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
477 let a_buf = Self::unwrap_buffer_f64(a)?;
478 let dev = self.device(a.device_ordinal())?;
479 let result = crate::kernels::gpu_relu_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
480 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
481 }
482
483 fn scale_f64(&self, a: &GpuBufferHandle, scalar: f64) -> FerrotorchResult<GpuBufferHandle> {
484 let a_buf = Self::unwrap_buffer_f64(a)?;
485 let dev = self.device(a.device_ordinal())?;
486 let result = crate::kernels::gpu_scale_f64(a_buf, scalar, dev).map_err(Self::map_gpu_err)?;
487 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
488 }
489
490 fn exp_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
491 let a_buf = Self::unwrap_buffer_f64(a)?;
492 let dev = self.device(a.device_ordinal())?;
493 let result = crate::kernels::gpu_exp_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
494 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
495 }
496
497 fn log_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
498 let a_buf = Self::unwrap_buffer_f64(a)?;
499 let dev = self.device(a.device_ordinal())?;
500 let result = crate::kernels::gpu_log_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
501 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
502 }
503
504 fn sqrt_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
505 let a_buf = Self::unwrap_buffer_f64(a)?;
506 let dev = self.device(a.device_ordinal())?;
507 let result = crate::kernels::gpu_sqrt_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
508 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
509 }
510
511 fn pow_f64(&self, a: &GpuBufferHandle, exponent: f64) -> FerrotorchResult<GpuBufferHandle> {
512 let a_buf = Self::unwrap_buffer_f64(a)?;
513 let dev = self.device(a.device_ordinal())?;
514 let result = crate::kernels::gpu_pow_f64(a_buf, exponent, dev).map_err(Self::map_gpu_err)?;
515 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
516 }
517
518 fn abs_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
519 let a_buf = Self::unwrap_buffer_f64(a)?;
520 let dev = self.device(a.device_ordinal())?;
521 let result = crate::kernels::gpu_abs_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
522 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
523 }
524
525 fn sigmoid_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
526 let a_buf = Self::unwrap_buffer_f64(a)?;
527 let dev = self.device(a.device_ordinal())?;
528 let result = crate::kernels::gpu_sigmoid_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
529 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
530 }
531
532 fn tanh_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
533 let a_buf = Self::unwrap_buffer_f64(a)?;
534 let dev = self.device(a.device_ordinal())?;
535 let result = crate::kernels::gpu_tanh_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
536 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
537 }
538
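    // Backward kernels take the upstream gradient plus the saved forward input
    // (or the saved output, for sigmoid/tanh/softmax) and return the gradient
    // with respect to the input.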
539 fn relu_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
541 let g_buf = Self::unwrap_buffer_f64(grad)?;
542 let i_buf = Self::unwrap_buffer_f64(input)?;
543 let dev = self.device(grad.device_ordinal())?;
544 let result = crate::kernels::gpu_relu_backward_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
545 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
546 }
547
548 fn sigmoid_backward_f64(&self, grad: &GpuBufferHandle, output: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
549 let g_buf = Self::unwrap_buffer_f64(grad)?;
550 let o_buf = Self::unwrap_buffer_f64(output)?;
551 let dev = self.device(grad.device_ordinal())?;
552 let result = crate::kernels::gpu_sigmoid_backward_f64(g_buf, o_buf, dev).map_err(Self::map_gpu_err)?;
553 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
554 }
555
556 fn tanh_backward_f64(&self, grad: &GpuBufferHandle, output: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
557 let g_buf = Self::unwrap_buffer_f64(grad)?;
558 let o_buf = Self::unwrap_buffer_f64(output)?;
559 let dev = self.device(grad.device_ordinal())?;
560 let result = crate::kernels::gpu_tanh_backward_f64(g_buf, o_buf, dev).map_err(Self::map_gpu_err)?;
561 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
562 }
563
564 fn gelu_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
567 let a_buf = Self::unwrap_buffer_f64(a)?;
568 let dev = self.device(a.device_ordinal())?;
569 let result = crate::kernels::gpu_gelu_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
570 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
571 }
572
573 fn gelu_tanh_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
574 let a_buf = Self::unwrap_buffer_f64(a)?;
575 let dev = self.device(a.device_ordinal())?;
576 let result = crate::kernels::gpu_gelu_tanh_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
577 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
578 }
579
580 fn gelu_erf_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
581 let a_buf = Self::unwrap_buffer_f64(a)?;
582 let dev = self.device(a.device_ordinal())?;
583 let result = crate::kernels::gpu_gelu_erf_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
584 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
585 }
586
587 fn silu_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
588 let a_buf = Self::unwrap_buffer_f64(a)?;
589 let dev = self.device(a.device_ordinal())?;
590 let result = crate::kernels::gpu_silu_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
591 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
592 }
593
594 fn elu_f64(&self, a: &GpuBufferHandle, alpha: f64) -> FerrotorchResult<GpuBufferHandle> {
595 let a_buf = Self::unwrap_buffer_f64(a)?;
596 let dev = self.device(a.device_ordinal())?;
597 let result = crate::kernels::gpu_elu_f64(a_buf, alpha, dev).map_err(Self::map_gpu_err)?;
598 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
599 }
600
601 fn mish_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
602 let a_buf = Self::unwrap_buffer_f64(a)?;
603 let dev = self.device(a.device_ordinal())?;
604 let result = crate::kernels::gpu_mish_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
605 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
606 }
607
608 fn clamp_f64(&self, a: &GpuBufferHandle, min_val: f64, max_val: f64) -> FerrotorchResult<GpuBufferHandle> {
609 let a_buf = Self::unwrap_buffer_f64(a)?;
610 let dev = self.device(a.device_ordinal())?;
611 let result = crate::kernels::gpu_clamp_f64(a_buf, min_val, max_val, dev).map_err(Self::map_gpu_err)?;
612 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
613 }
614
615 fn gelu_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
618 let g_buf = Self::unwrap_buffer_f64(grad)?;
619 let i_buf = Self::unwrap_buffer_f64(input)?;
620 let dev = self.device(grad.device_ordinal())?;
621 let result = crate::kernels::gpu_gelu_backward_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
622 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
623 }
624
625 fn gelu_backward_tanh_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
626 let g_buf = Self::unwrap_buffer_f64(grad)?;
627 let i_buf = Self::unwrap_buffer_f64(input)?;
628 let dev = self.device(grad.device_ordinal())?;
629 let result = crate::kernels::gpu_gelu_backward_tanh_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
630 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
631 }
632
633 fn gelu_backward_erf_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
634 let g_buf = Self::unwrap_buffer_f64(grad)?;
635 let i_buf = Self::unwrap_buffer_f64(input)?;
636 let dev = self.device(grad.device_ordinal())?;
637 let result = crate::kernels::gpu_gelu_backward_erf_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
638 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
639 }
640
641 fn silu_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
642 let g_buf = Self::unwrap_buffer_f64(grad)?;
643 let i_buf = Self::unwrap_buffer_f64(input)?;
644 let dev = self.device(grad.device_ordinal())?;
645 let result = crate::kernels::gpu_silu_backward_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
646 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
647 }
648
649 fn elu_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle, alpha: f64) -> FerrotorchResult<GpuBufferHandle> {
650 let g_buf = Self::unwrap_buffer_f64(grad)?;
651 let i_buf = Self::unwrap_buffer_f64(input)?;
652 let dev = self.device(grad.device_ordinal())?;
653 let result = crate::kernels::gpu_elu_backward_f64(g_buf, i_buf, alpha, dev).map_err(Self::map_gpu_err)?;
654 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
655 }
656
657 fn mish_backward_f64(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
658 let g_buf = Self::unwrap_buffer_f64(grad)?;
659 let i_buf = Self::unwrap_buffer_f64(input)?;
660 let dev = self.device(grad.device_ordinal())?;
661 let result = crate::kernels::gpu_mish_backward_f64(g_buf, i_buf, dev).map_err(Self::map_gpu_err)?;
662 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
663 }
664
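    // Scan ops (cumsum/cumprod/cummax/cummin/logcumsumexp) run along one axis of
    // a tensor flattened as (outer, dim_size, inner); cummax/cummin also return
    // an index buffer alongside the values.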
665 fn cumsum_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<GpuBufferHandle> {
667 let a_buf = Self::unwrap_buffer_f64(a)?;
668 let dev = self.device(a.device_ordinal())?;
669 let result = crate::kernels::gpu_cumsum_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
670 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
671 }
672
673 fn cumprod_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<GpuBufferHandle> {
674 let a_buf = Self::unwrap_buffer_f64(a)?;
675 let dev = self.device(a.device_ordinal())?;
676 let result = crate::kernels::gpu_cumprod_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
677 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
678 }
679
680 fn cummax_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
681 let a_buf = Self::unwrap_buffer_f64(a)?;
682 let dev = self.device(a.device_ordinal())?;
683 let (vals, idxs) = crate::kernels::gpu_cummax_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
684 let ord = a.device_ordinal();
685 Ok((Self::wrap_buffer_f64(vals, ord), Self::wrap_buffer_f64(idxs, ord)))
686 }
687
688 fn cummin_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
689 let a_buf = Self::unwrap_buffer_f64(a)?;
690 let dev = self.device(a.device_ordinal())?;
691 let (vals, idxs) = crate::kernels::gpu_cummin_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
692 let ord = a.device_ordinal();
693 Ok((Self::wrap_buffer_f64(vals, ord), Self::wrap_buffer_f64(idxs, ord)))
694 }
695
696 fn logcumsumexp_f64(&self, a: &GpuBufferHandle, outer: usize, dim_size: usize, inner: usize) -> FerrotorchResult<GpuBufferHandle> {
697 let a_buf = Self::unwrap_buffer_f64(a)?;
698 let dev = self.device(a.device_ordinal())?;
699 let result = crate::kernels::gpu_logcumsumexp_f64(a_buf, outer, dim_size, inner, dev).map_err(Self::map_gpu_err)?;
700 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
701 }
702
703 fn transpose_2d_f64(&self, a: &GpuBufferHandle, m: usize, n: usize) -> FerrotorchResult<GpuBufferHandle> {
705 let a_buf = Self::unwrap_buffer_f64(a)?;
706 let dev = self.device(a.device_ordinal())?;
707 let result = crate::kernels::gpu_transpose_2d_f64(a_buf, m, n, dev).map_err(Self::map_gpu_err)?;
708 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
709 }
710
711 fn permute_0213_f64(&self, a: &GpuBufferHandle, d0: usize, d1: usize, d2: usize, d3: usize) -> FerrotorchResult<GpuBufferHandle> {
712 let a_buf = Self::unwrap_buffer_f64(a)?;
713 let dev = self.device(a.device_ordinal())?;
714 let result = crate::kernels::gpu_permute_0213_f64(a_buf, d0, d1, d2, d3, dev).map_err(Self::map_gpu_err)?;
715 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
716 }
717
718 fn broadcast_add_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle> {
720 let a_buf = Self::unwrap_buffer_f64(a)?;
721 let b_buf = Self::unwrap_buffer_f64(b)?;
722 let dev = self.device(a.device_ordinal())?;
723 let result = crate::kernels::gpu_broadcast_add_f64(a_buf, b_buf, a_shape, b_shape, out_shape, dev).map_err(Self::map_gpu_err)?;
724 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
725 }
726
727 fn broadcast_sub_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle> {
728 let a_buf = Self::unwrap_buffer_f64(a)?;
729 let b_buf = Self::unwrap_buffer_f64(b)?;
730 let dev = self.device(a.device_ordinal())?;
731 let result = crate::kernels::gpu_broadcast_sub_f64(a_buf, b_buf, a_shape, b_shape, out_shape, dev).map_err(Self::map_gpu_err)?;
732 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
733 }
734
735 fn broadcast_mul_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle> {
736 let a_buf = Self::unwrap_buffer_f64(a)?;
737 let b_buf = Self::unwrap_buffer_f64(b)?;
738 let dev = self.device(a.device_ordinal())?;
739 let result = crate::kernels::gpu_broadcast_mul_f64(a_buf, b_buf, a_shape, b_shape, out_shape, dev).map_err(Self::map_gpu_err)?;
740 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
741 }
742
743 fn broadcast_div_f64(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle> {
744 let a_buf = Self::unwrap_buffer_f64(a)?;
745 let b_buf = Self::unwrap_buffer_f64(b)?;
746 let dev = self.device(a.device_ordinal())?;
747 let result = crate::kernels::gpu_broadcast_div_f64(a_buf, b_buf, a_shape, b_shape, out_shape, dev).map_err(Self::map_gpu_err)?;
748 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
749 }
750
751 fn sum_f64(&self, a: &GpuBufferHandle, _n: usize) -> FerrotorchResult<GpuBufferHandle> {
753 let a_buf = Self::unwrap_buffer_f64(a)?;
754 let dev = self.device(a.device_ordinal())?;
755 let result = crate::kernels::gpu_reduce_sum_f64(a_buf, dev).map_err(Self::map_gpu_err)?;
756 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
757 }
758
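    // Axis reduction flattens the shape around `axis` into (outer, axis_size,
    // inner). Worked example: shape [2, 3, 4] with axis = 1 gives outer = 2,
    // axis_size = 3, inner = 4, and the reduced output holds outer * inner = 8
    // elements.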
759 fn sum_axis_f64(&self, a: &GpuBufferHandle, shape: &[usize], axis: usize) -> FerrotorchResult<GpuBufferHandle> {
760 let a_buf = Self::unwrap_buffer_f64(a)?;
761 let dev = self.device(a.device_ordinal())?;
762 let outer: usize = shape[..axis].iter().product();
763 let axis_size = shape[axis];
764 let inner: usize = shape[axis + 1..].iter().product();
765 let result = crate::kernels::gpu_sum_axis_f64(a_buf, outer, axis_size, inner, dev).map_err(Self::map_gpu_err)?;
766 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
767 }
768
769 fn softmax_f64(&self, a: &GpuBufferHandle, rows: usize, cols: usize) -> FerrotorchResult<GpuBufferHandle> {
772 let a_buf = Self::unwrap_buffer_f64(a)?;
773 let dev = self.device(a.device_ordinal())?;
774 let result = crate::kernels::gpu_softmax_f64(a_buf, rows, cols, dev).map_err(Self::map_gpu_err)?;
775 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
776 }
777
778 fn softmax_backward_f64(&self, grad: &GpuBufferHandle, output: &GpuBufferHandle, cols: usize) -> FerrotorchResult<GpuBufferHandle> {
779 let grad_buf = Self::unwrap_buffer_f64(grad)?;
780 let output_buf = Self::unwrap_buffer_f64(output)?;
781 let dev = self.device(grad.device_ordinal())?;
782 let result = crate::kernels::gpu_softmax_backward_f64(grad_buf, output_buf, cols, dev).map_err(Self::map_gpu_err)?;
783 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
784 }
785
786 fn log_softmax_f64(&self, a: &GpuBufferHandle, cols: usize) -> FerrotorchResult<GpuBufferHandle> {
787 let a_buf = Self::unwrap_buffer_f64(a)?;
788 let dev = self.device(a.device_ordinal())?;
789 let result = crate::kernels::gpu_log_softmax_f64(a_buf, cols, dev).map_err(Self::map_gpu_err)?;
790 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
791 }
792
793 fn log_softmax_backward_f64(&self, grad: &GpuBufferHandle, output: &GpuBufferHandle, cols: usize) -> FerrotorchResult<GpuBufferHandle> {
794 let grad_buf = Self::unwrap_buffer_f64(grad)?;
795 let output_buf = Self::unwrap_buffer_f64(output)?;
796 let dev = self.device(grad.device_ordinal())?;
797 let result = crate::kernels::gpu_log_softmax_backward_f64(grad_buf, output_buf, cols, dev).map_err(Self::map_gpu_err)?;
798 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
799 }
800
801 fn layernorm_f64(
802 &self,
803 input: &GpuBufferHandle,
804 weight: &GpuBufferHandle,
805 bias: &GpuBufferHandle,
806 rows: usize,
807 cols: usize,
808 eps: f64,
809 ) -> FerrotorchResult<GpuBufferHandle> {
810 let in_buf = Self::unwrap_buffer_f64(input)?;
811 let w_buf = Self::unwrap_buffer_f64(weight)?;
812 let b_buf = Self::unwrap_buffer_f64(bias)?;
813 let dev = self.device(input.device_ordinal())?;
814 let result = crate::kernels::gpu_layernorm_f64(in_buf, w_buf, b_buf, rows, cols, eps, dev)
815 .map_err(Self::map_gpu_err)?;
816 Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
817 }
818
819 fn layernorm_backward_f64(
820 &self,
821 input: &GpuBufferHandle,
822 grad_output: &GpuBufferHandle,
823 weight: &GpuBufferHandle,
824 rows: usize,
825 cols: usize,
826 eps: f64,
827 ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
828 let in_buf = Self::unwrap_buffer_f64(input)?;
829 let go_buf = Self::unwrap_buffer_f64(grad_output)?;
830 let w_buf = Self::unwrap_buffer_f64(weight)?;
831 let dev = self.device(input.device_ordinal())?;
832 let (gi, gw, gb) =
833 crate::kernels::gpu_layernorm_backward_f64(in_buf, go_buf, w_buf, rows, cols, eps, dev)
834 .map_err(Self::map_gpu_err)?;
835 let ordinal = input.device_ordinal();
836 Ok((
837 Self::wrap_buffer_f64(gi, ordinal),
838 Self::wrap_buffer_f64(gw, ordinal),
839 Self::wrap_buffer_f64(gb, ordinal),
840 ))
841 }
842
843 fn rmsnorm_f64(
844 &self,
845 input: &GpuBufferHandle,
846 weight: &GpuBufferHandle,
847 rows: usize,
848 cols: usize,
849 eps: f64,
850 ) -> FerrotorchResult<GpuBufferHandle> {
851 let in_buf = Self::unwrap_buffer_f64(input)?;
852 let w_buf = Self::unwrap_buffer_f64(weight)?;
853 let dev = self.device(input.device_ordinal())?;
854 let result = crate::kernels::gpu_rmsnorm_f64(in_buf, w_buf, rows, cols, eps, dev)
855 .map_err(Self::map_gpu_err)?;
856 Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
857 }
858
859 fn rmsnorm_backward_f64(
860 &self,
861 input: &GpuBufferHandle,
862 grad_output: &GpuBufferHandle,
863 weight: &GpuBufferHandle,
864 rows: usize,
865 cols: usize,
866 eps: f64,
867 ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
868 let in_buf = Self::unwrap_buffer_f64(input)?;
869 let go_buf = Self::unwrap_buffer_f64(grad_output)?;
870 let w_buf = Self::unwrap_buffer_f64(weight)?;
871 let dev = self.device(input.device_ordinal())?;
872 let (gi, gw) =
873 crate::kernels::gpu_rmsnorm_backward_f64(in_buf, go_buf, w_buf, rows, cols, eps, dev)
874 .map_err(Self::map_gpu_err)?;
875 let ordinal = input.device_ordinal();
876 Ok((Self::wrap_buffer_f64(gi, ordinal), Self::wrap_buffer_f64(gw, ordinal)))
877 }
878
879 fn embed_lookup_f64(
882 &self,
883 idx: &GpuBufferHandle,
884 weight: &GpuBufferHandle,
885 d: usize,
886 ) -> FerrotorchResult<GpuBufferHandle> {
887 let idx_buf = Self::unwrap_buffer(idx)?;
889 let w_buf = Self::unwrap_buffer_f64(weight)?;
890 let dev = self.device(idx.device_ordinal())?;
891 let result =
892 crate::kernels::gpu_embed_lookup_f64(idx_buf, w_buf, d, dev).map_err(Self::map_gpu_err)?;
893 Ok(Self::wrap_buffer_f64(result, idx.device_ordinal()))
894 }
895
896 fn embed_lookup_batch_f64(
897 &self,
898 indices: &GpuBufferHandle,
899 weight: &GpuBufferHandle,
900 n: usize,
901 d: usize,
902 ) -> FerrotorchResult<GpuBufferHandle> {
903 let idx_buf = Self::unwrap_buffer(indices)?;
905 let w_buf = Self::unwrap_buffer_f64(weight)?;
906 let dev = self.device(indices.device_ordinal())?;
907 let result = crate::kernels::gpu_embed_lookup_batch_f64(idx_buf, w_buf, n, d, dev)
908 .map_err(Self::map_gpu_err)?;
909 Ok(Self::wrap_buffer_f64(result, indices.device_ordinal()))
910 }
911
912 fn scatter_add_rows_f64(
913 &self,
914 grad_output: &GpuBufferHandle,
915 indices: &GpuBufferHandle,
916 num_embeddings: usize,
917 d: usize,
918 ) -> FerrotorchResult<GpuBufferHandle> {
919 let go_buf = Self::unwrap_buffer_f64(grad_output)?;
920 let idx_buf = Self::unwrap_buffer(indices)?;
922 let dev = self.device(grad_output.device_ordinal())?;
923 let result = crate::kernels::gpu_scatter_add_rows_f64(go_buf, idx_buf, num_embeddings, d, dev)
924 .map_err(Self::map_gpu_err)?;
925 Ok(Self::wrap_buffer_f64(result, grad_output.device_ordinal()))
926 }
927
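    // The mask arrives as an f32 buffer; it is downloaded to the host, squashed
    // to 0/1 bytes, and re-uploaded before the masked kernel runs, so each call
    // pays a host round-trip for the mask.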
928 fn masked_fill_f64(
935 &self,
936 input: &GpuBufferHandle,
937 mask: &GpuBufferHandle,
938 value: f64,
939 ) -> FerrotorchResult<GpuBufferHandle> {
940 let input_buf = Self::unwrap_buffer_f64(input)?;
941 let mask_f32 = Self::unwrap_buffer(mask)?;
942 let dev = self.device(input.device_ordinal())?;
943 let mask_host = crate::transfer::gpu_to_cpu(mask_f32, dev).map_err(Self::map_gpu_err)?;
945 let mask_u8: Vec<u8> = mask_host.iter().map(|&v| if v != 0.0 { 1u8 } else { 0u8 }).collect();
946 let mask_gpu = crate::transfer::cpu_to_gpu(&mask_u8, dev).map_err(Self::map_gpu_err)?;
947 let result = crate::kernels::gpu_masked_fill_f64(input_buf, &mask_gpu, value, dev)
948 .map_err(Self::map_gpu_err)?;
949 Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
950 }
951
952 fn masked_zero_f64(
953 &self,
954 grad: &GpuBufferHandle,
955 mask: &GpuBufferHandle,
956 ) -> FerrotorchResult<GpuBufferHandle> {
957 let grad_buf = Self::unwrap_buffer_f64(grad)?;
958 let mask_f32 = Self::unwrap_buffer(mask)?;
959 let dev = self.device(grad.device_ordinal())?;
960 let mask_host = crate::transfer::gpu_to_cpu(mask_f32, dev).map_err(Self::map_gpu_err)?;
962 let mask_u8: Vec<u8> = mask_host.iter().map(|&v| if v != 0.0 { 1u8 } else { 0u8 }).collect();
963 let mask_gpu = crate::transfer::cpu_to_gpu(&mask_u8, dev).map_err(Self::map_gpu_err)?;
964 let result = crate::kernels::gpu_masked_zero_f64(grad_buf, &mask_gpu, dev)
965 .map_err(Self::map_gpu_err)?;
966 Ok(Self::wrap_buffer_f64(result, grad.device_ordinal()))
967 }
968
969 fn slice_write_f64(
972 &self,
973 src: &GpuBufferHandle,
974 dst: &mut GpuBufferHandle,
975 n_batch: usize,
976 d: usize,
977 max_len: usize,
978 pos: usize,
979 ) -> FerrotorchResult<()> {
980 let src_buf = Self::unwrap_buffer_f64(src)?;
981 let dst_buf = Self::unwrap_buffer_f64_mut(dst)?;
982 let dev = self.device(src.device_ordinal())?;
983 crate::kernels::gpu_slice_write_f64(src_buf, dst_buf, n_batch, d, max_len, pos, dev)
984 .map_err(Self::map_gpu_err)?;
985 Ok(())
986 }
987
988 fn slice_read_f64(
989 &self,
990 src: &GpuBufferHandle,
991 n_batch: usize,
992 d: usize,
993 len: usize,
994 max_len: usize,
995 ) -> FerrotorchResult<GpuBufferHandle> {
996 let src_buf = Self::unwrap_buffer_f64(src)?;
997 let dev = self.device(src.device_ordinal())?;
998 let result = crate::kernels::gpu_slice_read_f64(src_buf, n_batch, d, len, max_len, dev)
999 .map_err(Self::map_gpu_err)?;
1000 Ok(Self::wrap_buffer_f64(result, src.device_ordinal()))
1001 }
1002
1003 fn strided_split_f64(
1006 &self,
1007 input: &GpuBufferHandle,
1008 total_along_axis: usize,
1009 split_offset: usize,
1010 split_size: usize,
1011 inner_size: usize,
1012 n: usize,
1013 ) -> FerrotorchResult<GpuBufferHandle> {
1014 let in_buf = Self::unwrap_buffer_f64(input)?;
1015 let dev = self.device(input.device_ordinal())?;
1016 let result = crate::kernels::gpu_strided_split_f64(
1017 in_buf,
1018 total_along_axis,
1019 split_offset,
1020 split_size,
1021 inner_size,
1022 n,
1023 dev,
1024 )
1025 .map_err(Self::map_gpu_err)?;
1026 Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
1027 }
1028
1029 fn strided_cat_f64(
1030 &self,
1031 input: &GpuBufferHandle,
1032 output: &mut GpuBufferHandle,
1033 total_along_axis: usize,
1034 cat_offset: usize,
1035 part_size: usize,
1036 inner_size: usize,
1037 n: usize,
1038 ) -> FerrotorchResult<()> {
1039 let in_buf = Self::unwrap_buffer_f64(input)?;
1040 let dev = self.device(input.device_ordinal())?;
1041 let out_buf = Self::unwrap_buffer_f64_mut(output)?;
1042 crate::kernels::gpu_strided_cat_f64(
1043 in_buf,
1044 out_buf,
1045 total_along_axis,
1046 cat_offset,
1047 part_size,
1048 inner_size,
1049 n,
1050 dev,
1051 )
1052 .map_err(Self::map_gpu_err)?;
1053 Ok(())
1054 }
1055
1056 fn index_select_1d_f64(
1059 &self,
1060 input: &GpuBufferHandle,
1061 indices: &GpuBufferHandle,
1062 ) -> FerrotorchResult<GpuBufferHandle> {
1063 let input_buf = Self::unwrap_buffer_f64(input)?;
1064 let idx_buf = Self::unwrap_buffer(indices)?;
1066 let dev = self.device(input.device_ordinal())?;
1067 let result = crate::kernels::gpu_index_select_1d_f64(input_buf, idx_buf, dev)
1068 .map_err(Self::map_gpu_err)?;
1069 Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
1070 }
1071
1072 fn scatter_add_1d_f64(
1073 &self,
1074 grad_output: &GpuBufferHandle,
1075 indices: &GpuBufferHandle,
1076 input_len: usize,
1077 ) -> FerrotorchResult<GpuBufferHandle> {
1078 let go_buf = Self::unwrap_buffer_f64(grad_output)?;
1079 let idx_buf = Self::unwrap_buffer(indices)?;
1081 let dev = self.device(grad_output.device_ordinal())?;
1082 let result = crate::kernels::gpu_scatter_add_1d_f64(go_buf, idx_buf, input_len, dev)
1083 .map_err(Self::map_gpu_err)?;
1084 Ok(Self::wrap_buffer_f64(result, grad_output.device_ordinal()))
1085 }
1086
1087 fn bmm_f64(
1088 &self,
1089 a: &GpuBufferHandle,
1090 b: &GpuBufferHandle,
1091 batch: usize,
1092 m: usize,
1093 k: usize,
1094 n: usize,
1095 ) -> FerrotorchResult<GpuBufferHandle> {
1096 let a_buf = Self::unwrap_buffer_f64(a)?;
1097 let b_buf = Self::unwrap_buffer_f64(b)?;
1098 let dev = self.device(a.device_ordinal())?;
1099 let result = crate::blas::gpu_bmm_f64(a_buf, b_buf, batch, m, k, n, dev)
1100 .map_err(Self::map_gpu_err)?;
1101 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
1102 }
1103
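    // Fused Adam updates the parameter and both moment buffers in place on the
    // GPU; `bc1`/`bc2` appear to be precomputed bias-correction factors supplied
    // by the optimizer.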
1104 #[allow(clippy::too_many_arguments)]
1105 fn fused_adam_f32(
1106 &self,
1107 param: &mut GpuBufferHandle,
1108 grad: &GpuBufferHandle,
1109 exp_avg: &mut GpuBufferHandle,
1110 exp_avg_sq: &mut GpuBufferHandle,
1111 beta1: f32,
1112 beta2: f32,
1113 lr: f32,
1114 eps: f32,
1115 bc1: f32,
1116 bc2: f32,
1117 weight_decay: f32,
1118 ) -> FerrotorchResult<()> {
1119 let ordinal = param.device_ordinal();
1120 let dev = self.device(ordinal)?;
1121 let p_buf = Self::unwrap_buffer_mut(param)?;
1122 let g_buf = Self::unwrap_buffer(grad)?;
1123 let m_buf = Self::unwrap_buffer_mut(exp_avg)?;
1124 let v_buf = Self::unwrap_buffer_mut(exp_avg_sq)?;
1125 crate::kernels::gpu_fused_adam(
1126 p_buf,
1127 g_buf,
1128 m_buf,
1129 v_buf,
1130 beta1,
1131 beta2,
1132 lr,
1133 eps,
1134 bc1,
1135 bc2,
1136 weight_decay,
1137 dev,
1138 )
1139 .map_err(Self::map_gpu_err)?;
1140 Ok(())
1141 }
1142
1143 #[allow(clippy::too_many_arguments)]
1144 fn maxpool2d_f32(
1145 &self,
1146 input: &GpuBufferHandle,
1147 batch: usize,
1148 channels: usize,
1149 h_in: usize,
1150 w_in: usize,
1151 kh: usize,
1152 kw: usize,
1153 sh: usize,
1154 sw: usize,
1155 ph: usize,
1156 pw: usize,
1157 ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
1158 let buf = Self::unwrap_buffer(input)?;
1159 let dev = self.device(input.device_ordinal())?;
1160 let (out, shape) = crate::kernels::gpu_maxpool2d(
1161 buf, batch, channels, h_in, w_in, kh, kw, sh, sw, ph, pw, dev,
1162 ).map_err(Self::map_gpu_err)?;
1163 Ok((Self::wrap_buffer(out, input.device_ordinal()), shape))
1164 }
1165
1166 #[allow(clippy::too_many_arguments)]
1167 fn avgpool2d_f32(
1168 &self,
1169 input: &GpuBufferHandle,
1170 batch: usize,
1171 channels: usize,
1172 h_in: usize,
1173 w_in: usize,
1174 kh: usize,
1175 kw: usize,
1176 sh: usize,
1177 sw: usize,
1178 ph: usize,
1179 pw: usize,
1180 ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
1181 let buf = Self::unwrap_buffer(input)?;
1182 let dev = self.device(input.device_ordinal())?;
1183 let (out, shape) = crate::kernels::gpu_avgpool2d(
1184 buf, batch, channels, h_in, w_in, kh, kw, sh, sw, ph, pw, dev,
1185 ).map_err(Self::map_gpu_err)?;
1186 Ok((Self::wrap_buffer(out, input.device_ordinal()), shape))
1187 }
1188
1189 #[allow(clippy::too_many_arguments)]
1190 fn conv2d_f32(
1191 &self,
1192 input: &GpuBufferHandle,
1193 weight: &GpuBufferHandle,
1194 bias: Option<&GpuBufferHandle>,
1195 input_shape: [usize; 4],
1196 weight_shape: [usize; 4],
1197 stride: (usize, usize),
1198 padding: (usize, usize),
1199 ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
1200 let input_buf = Self::unwrap_buffer(input)?;
1201 let weight_buf = Self::unwrap_buffer(weight)?;
1202 let bias_buf = match bias {
1203 Some(b) => Some(Self::unwrap_buffer(b)?),
1204 None => None,
1205 };
1206 let dev = self.device(input.device_ordinal())?;
1207 let (out_buf, out_shape) = crate::conv::gpu_conv2d_f32(
1208 input_buf,
1209 weight_buf,
1210 bias_buf,
1211 input_shape,
1212 weight_shape,
1213 stride,
1214 padding,
1215 dev,
1216 )
1217 .map_err(Self::map_gpu_err)?;
1218 Ok((Self::wrap_buffer(out_buf, input.device_ordinal()), out_shape))
1219 }
1220
1221 fn fused_gru_cell_f32(
1222 &self,
1223 input_gates: &GpuBufferHandle,
1224 hidden_gates: &GpuBufferHandle,
1225 bias_ih: &GpuBufferHandle,
1226 bias_hh: &GpuBufferHandle,
1227 hx: &GpuBufferHandle,
1228 hidden_size: usize,
1229 ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
1230 let ig = Self::unwrap_buffer(input_gates)?;
1231 let hg = Self::unwrap_buffer(hidden_gates)?;
1232 let bih = Self::unwrap_buffer(bias_ih)?;
1233 let bhh = Self::unwrap_buffer(bias_hh)?;
1234 let hx_buf = Self::unwrap_buffer(hx)?;
1235 let dev = self.device(input_gates.device_ordinal())?;
1236 let (hy, ws) = crate::kernels::gpu_fused_gru_forward(
1237 ig, hg, bih, bhh, hx_buf, hidden_size, dev,
1238 )
1239 .map_err(Self::map_gpu_err)?;
1240 let ord = input_gates.device_ordinal();
1241 Ok((Self::wrap_buffer(hy, ord), Self::wrap_buffer(ws, ord)))
1242 }
1243
1244 fn synchronize(&self, device: usize) -> FerrotorchResult<()> {
1245 let dev = self.device(device)?;
1246 dev.stream()
1247 .synchronize()
1248 .map_err(|e| FerrotorchError::InvalidArgument {
1249 message: format!("CUDA synchronize failed: {e}"),
1250 })?;
1251 Ok(())
1252 }
1253
1254 fn stream_count(&self, device: usize) -> usize {
1255 crate::stream::StreamPool::pool_size(device)
1256 }
1257
1258 fn matmul_f32(
1261 &self,
1262 a: &GpuBufferHandle,
1263 b: &GpuBufferHandle,
1264 m: usize,
1265 k: usize,
1266 n: usize,
1267 ) -> FerrotorchResult<GpuBufferHandle> {
1268 let a_buf = Self::unwrap_buffer(a)?;
1269 let b_buf = Self::unwrap_buffer(b)?;
1270 let dev = self.device(a.device_ordinal())?;
1271 let result =
1272 crate::blas::gpu_matmul_f32(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
1273 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1274 }
1275
1276 fn sum_f32(&self, a: &GpuBufferHandle, _len: usize) -> FerrotorchResult<GpuBufferHandle> {
1279 let a_buf = Self::unwrap_buffer(a)?;
1280 let dev = self.device(a.device_ordinal())?;
1281 let result = crate::kernels::gpu_reduce_sum(a_buf, dev).map_err(Self::map_gpu_err)?;
1282 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1283 }
1284
1285 fn matmul_f64(
1288 &self,
1289 a: &GpuBufferHandle,
1290 b: &GpuBufferHandle,
1291 m: usize,
1292 k: usize,
1293 n: usize,
1294 ) -> FerrotorchResult<GpuBufferHandle> {
1295 let a_buf = Self::unwrap_buffer_f64(a)?;
1296 let b_buf = Self::unwrap_buffer_f64(b)?;
1297 let dev = self.device(a.device_ordinal())?;
1298 let result =
1299 crate::blas::gpu_matmul_f64(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
1300 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
1301 }
1302
1303 fn broadcast_add_f32(
1306 &self,
1307 a: &GpuBufferHandle,
1308 b: &GpuBufferHandle,
1309 a_shape: &[usize],
1310 b_shape: &[usize],
1311 out_shape: &[usize],
1312 ) -> FerrotorchResult<GpuBufferHandle> {
1313 let a_buf = Self::unwrap_buffer(a)?;
1314 let b_buf = Self::unwrap_buffer(b)?;
1315 let dev = self.device(a.device_ordinal())?;
1316 let result =
1317 crate::kernels::gpu_broadcast_add(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
1318 .map_err(Self::map_gpu_err)?;
1319 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1320 }
1321
1322 fn broadcast_sub_f32(
1323 &self,
1324 a: &GpuBufferHandle,
1325 b: &GpuBufferHandle,
1326 a_shape: &[usize],
1327 b_shape: &[usize],
1328 out_shape: &[usize],
1329 ) -> FerrotorchResult<GpuBufferHandle> {
1330 let a_buf = Self::unwrap_buffer(a)?;
1331 let b_buf = Self::unwrap_buffer(b)?;
1332 let dev = self.device(a.device_ordinal())?;
1333 let result =
1334 crate::kernels::gpu_broadcast_sub(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
1335 .map_err(Self::map_gpu_err)?;
1336 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1337 }
1338
1339 fn broadcast_mul_f32(
1340 &self,
1341 a: &GpuBufferHandle,
1342 b: &GpuBufferHandle,
1343 a_shape: &[usize],
1344 b_shape: &[usize],
1345 out_shape: &[usize],
1346 ) -> FerrotorchResult<GpuBufferHandle> {
1347 let a_buf = Self::unwrap_buffer(a)?;
1348 let b_buf = Self::unwrap_buffer(b)?;
1349 let dev = self.device(a.device_ordinal())?;
1350 let result =
1351 crate::kernels::gpu_broadcast_mul(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
1352 .map_err(Self::map_gpu_err)?;
1353 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1354 }
1355
1356 fn broadcast_div_f32(
1357 &self,
1358 a: &GpuBufferHandle,
1359 b: &GpuBufferHandle,
1360 a_shape: &[usize],
1361 b_shape: &[usize],
1362 out_shape: &[usize],
1363 ) -> FerrotorchResult<GpuBufferHandle> {
1364 let a_buf = Self::unwrap_buffer(a)?;
1365 let b_buf = Self::unwrap_buffer(b)?;
1366 let dev = self.device(a.device_ordinal())?;
1367 let result =
1368 crate::kernels::gpu_broadcast_div(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
1369 .map_err(Self::map_gpu_err)?;
1370 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1371 }
1372
1373 fn softmax_f32(
1374 &self,
1375 a: &GpuBufferHandle,
1376 rows: usize,
1377 cols: usize,
1378 ) -> FerrotorchResult<GpuBufferHandle> {
1379 let a_buf = Self::unwrap_buffer(a)?;
1380 let dev = self.device(a.device_ordinal())?;
1381 let result =
1382 crate::kernels::gpu_softmax(a_buf, rows, cols, dev).map_err(Self::map_gpu_err)?;
1383 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1384 }
1385
1386 fn dropout_f32(
1387 &self,
1388 a: &GpuBufferHandle,
1389 threshold: u32,
1390 scale: f32,
1391 seed: u32,
1392 ) -> FerrotorchResult<GpuBufferHandle> {
1393 let a_buf = Self::unwrap_buffer(a)?;
1394 let dev = self.device(a.device_ordinal())?;
1395 let result = crate::kernels::gpu_dropout(a_buf, threshold, scale, seed, dev)
1396 .map_err(Self::map_gpu_err)?;
1397 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1398 }
1399
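    // Philox-style dropout: the per-device RNG manager hands out a counter-based
    // state and is advanced by ceil(n / 4) counters (Philox yields four 32-bit
    // values per counter). The returned GpuRngState records that state so the
    // same mask can be reproduced later; the kernel itself is seeded with
    // counter ^ seed.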
1400 fn dropout_philox_f32(
1401 &self,
1402 a: &GpuBufferHandle,
1403 threshold: u32,
1404 scale: f32,
1405 ) -> FerrotorchResult<(GpuBufferHandle, GpuRngState)> {
1406 let device_ordinal = a.device_ordinal();
1407 let n = a.len();
1408
1409 let rng_state = {
1411 let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
1412 FerrotorchError::InvalidArgument {
1413 message: "failed to lock CUDA RNG manager".into(),
1414 }
1415 })?;
1416 let philox_gen = mgr.generator(device_ordinal);
1417 let state = philox_gen.get_state();
1418 let counters_needed = n.div_ceil(4);
1420 philox_gen.advance(counters_needed as u64);
1421 state
1422 };
1423
1424 let a_buf = Self::unwrap_buffer(a)?;
1432 let dev = self.device(device_ordinal)?;
1433
1434 let derived_seed = (rng_state.counter ^ rng_state.seed) as u32;
1437 let result = crate::kernels::gpu_dropout(a_buf, threshold, scale, derived_seed, dev)
1438 .map_err(Self::map_gpu_err)?;
1439
1440 let gpu_rng_state = GpuRngState {
1441 counter: rng_state.counter,
1442 seed: rng_state.seed,
1443 offset: rng_state.offset,
1444 device: device_ordinal,
1445 };
1446
1447 Ok((Self::wrap_buffer(result, device_ordinal), gpu_rng_state))
1448 }
1449
1450 fn dropout_f64(
1451 &self,
1452 a: &GpuBufferHandle,
1453 threshold: u32,
1454 scale: f64,
1455 seed: u32,
1456 ) -> FerrotorchResult<GpuBufferHandle> {
1457 let a_buf = Self::unwrap_buffer_f64(a)?;
1458 let dev = self.device(a.device_ordinal())?;
1459 let result = crate::kernels::gpu_dropout_f64(a_buf, threshold, scale, seed, dev)
1460 .map_err(Self::map_gpu_err)?;
1461 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
1462 }
1463
1464 fn dropout_philox_f64(
1465 &self,
1466 a: &GpuBufferHandle,
1467 threshold: u32,
1468 scale: f64,
1469 ) -> FerrotorchResult<(GpuBufferHandle, GpuRngState)> {
1470 let device_ordinal = a.device_ordinal();
1471 let n = a.len();
1472
1473 let rng_state = {
1474 let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
1475 FerrotorchError::InvalidArgument {
1476 message: "failed to lock CUDA RNG manager".into(),
1477 }
1478 })?;
1479 let philox_gen = mgr.generator(device_ordinal);
1480 let state = philox_gen.get_state();
1481 let counters_needed = n.div_ceil(4);
1482 philox_gen.advance(counters_needed as u64);
1483 state
1484 };
1485
1486 let a_buf = Self::unwrap_buffer_f64(a)?;
1487 let dev = self.device(device_ordinal)?;
1488 let derived_seed = (rng_state.counter ^ rng_state.seed) as u32;
1489 let result = crate::kernels::gpu_dropout_f64(a_buf, threshold, scale, derived_seed, dev)
1490 .map_err(Self::map_gpu_err)?;
1491
1492 let gpu_rng_state = GpuRngState {
1493 counter: rng_state.counter,
1494 seed: rng_state.seed,
1495 offset: rng_state.offset,
1496 device: device_ordinal,
1497 };
1498
1499 Ok((Self::wrap_buffer_f64(result, device_ordinal), gpu_rng_state))
1500 }
1501
1502 fn transpose_2d_f32(
1503 &self,
1504 a: &GpuBufferHandle,
1505 m: usize,
1506 n: usize,
1507 ) -> FerrotorchResult<GpuBufferHandle> {
1508 let a_buf = Self::unwrap_buffer(a)?;
1509 let dev = self.device(a.device_ordinal())?;
1510 let result =
1511 crate::kernels::gpu_transpose_2d(a_buf, m, n, dev).map_err(Self::map_gpu_err)?;
1512 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1513 }
1514
1515 fn permute_0213_f32(
1516 &self,
1517 a: &GpuBufferHandle,
1518 d0: usize,
1519 d1: usize,
1520 d2: usize,
1521 d3: usize,
1522 ) -> FerrotorchResult<GpuBufferHandle> {
1523 let a_buf = Self::unwrap_buffer(a)?;
1524 let dev = self.device(a.device_ordinal())?;
1525 let result = crate::kernels::gpu_permute_0213(a_buf, d0, d1, d2, d3, dev)
1526 .map_err(Self::map_gpu_err)?;
1527 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1528 }
1529
1530 fn bmm_f32(
1531 &self,
1532 a: &GpuBufferHandle,
1533 b: &GpuBufferHandle,
1534 batch: usize,
1535 m: usize,
1536 k: usize,
1537 n: usize,
1538 ) -> FerrotorchResult<GpuBufferHandle> {
1539 let a_buf = Self::unwrap_buffer(a)?;
1540 let b_buf = Self::unwrap_buffer(b)?;
1541 let dev = self.device(a.device_ordinal())?;
1542 let result = crate::blas::gpu_bmm_f32(a_buf, b_buf, batch, m, k, n, dev)
1543 .map_err(Self::map_gpu_err)?;
1544 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1545 }
1546
1547 fn bmm_f16_f32(
1548 &self,
1549 a: &GpuBufferHandle,
1550 b: &GpuBufferHandle,
1551 batch: usize,
1552 m: usize,
1553 k: usize,
1554 n: usize,
1555 ) -> FerrotorchResult<GpuBufferHandle> {
1556 let a_buf = Self::unwrap_buffer(a)?;
1557 let b_buf = Self::unwrap_buffer(b)?;
1558 let dev = self.device(a.device_ordinal())?;
1559 let result = crate::blas::gpu_bmm_f16(a_buf, b_buf, batch, m, k, n, dev)
1560 .map_err(Self::map_gpu_err)?;
1561 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1562 }
1563
1564 fn gelu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1565 let a_buf = Self::unwrap_buffer(a)?;
1566 let dev = self.device(a.device_ordinal())?;
1567 let result = crate::kernels::gpu_gelu(a_buf, dev).map_err(Self::map_gpu_err)?;
1568 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1569 }
1570
1571 fn gelu_tanh_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1572 let a_buf = Self::unwrap_buffer(a)?;
1573 let dev = self.device(a.device_ordinal())?;
1574 let result = crate::kernels::gpu_gelu_tanh(a_buf, dev).map_err(Self::map_gpu_err)?;
1575 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1576 }
1577
1578 fn gelu_erf_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1579 let a_buf = Self::unwrap_buffer(a)?;
1580 let dev = self.device(a.device_ordinal())?;
1581 let result = crate::kernels::gpu_gelu_erf(a_buf, dev).map_err(Self::map_gpu_err)?;
1582 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1583 }
1584
1585 fn layernorm_f32(
1586 &self,
1587 input: &GpuBufferHandle,
1588 weight: &GpuBufferHandle,
1589 bias: &GpuBufferHandle,
1590 rows: usize,
1591 cols: usize,
1592 eps: f32,
1593 ) -> FerrotorchResult<GpuBufferHandle> {
1594 let in_buf = Self::unwrap_buffer(input)?;
1595 let w_buf = Self::unwrap_buffer(weight)?;
1596 let b_buf = Self::unwrap_buffer(bias)?;
1597 let dev = self.device(input.device_ordinal())?;
1598 let result = crate::kernels::gpu_layernorm(in_buf, w_buf, b_buf, rows, cols, eps, dev)
1599 .map_err(Self::map_gpu_err)?;
1600 Ok(Self::wrap_buffer(result, input.device_ordinal()))
1601 }
1602
1603 fn rmsnorm_f32(
1604 &self,
1605 input: &GpuBufferHandle,
1606 weight: &GpuBufferHandle,
1607 rows: usize,
1608 cols: usize,
1609 eps: f32,
1610 ) -> FerrotorchResult<GpuBufferHandle> {
1611 let in_buf = Self::unwrap_buffer(input)?;
1612 let w_buf = Self::unwrap_buffer(weight)?;
1613 let dev = self.device(input.device_ordinal())?;
1614 let result = crate::kernels::gpu_rmsnorm(in_buf, w_buf, rows, cols, eps, dev)
1615 .map_err(Self::map_gpu_err)?;
1616 Ok(Self::wrap_buffer(result, input.device_ordinal()))
1617 }
1618
    fn rmsnorm_backward_f32(
        &self,
        input: &GpuBufferHandle,
        grad_output: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        rows: usize,
        cols: usize,
        eps: f32,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
        let in_buf = Self::unwrap_buffer(input)?;
        let go_buf = Self::unwrap_buffer(grad_output)?;
        let w_buf = Self::unwrap_buffer(weight)?;
        let dev = self.device(input.device_ordinal())?;
        let (gi, gw) =
            crate::kernels::gpu_rmsnorm_backward(in_buf, go_buf, w_buf, rows, cols, eps, dev)
                .map_err(Self::map_gpu_err)?;
        let ordinal = input.device_ordinal();
        Ok((Self::wrap_buffer(gi, ordinal), Self::wrap_buffer(gw, ordinal)))
    }

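    /// Writes `src` in place into `dst` at position `pos`; the destination is
    /// dimensioned by `n_batch`, `d`, and `max_len` (see `gpu_slice_write`).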
    fn slice_write_f32(
        &self,
        src: &GpuBufferHandle,
        dst: &mut GpuBufferHandle,
        n_batch: usize,
        d: usize,
        max_len: usize,
        pos: usize,
    ) -> FerrotorchResult<()> {
        let src_buf = Self::unwrap_buffer(src)?;
        let dst_buf =
            dst.downcast_mut::<CudaBuffer<f32>>()
                .ok_or(FerrotorchError::InvalidArgument {
                    message: "slice_write_f32: dst is not CudaBuffer<f32>".into(),
                })?;
        let dev = self.device(src.device_ordinal())?;
        crate::kernels::gpu_slice_write(src_buf, dst_buf, n_batch, d, max_len, pos, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(())
    }

    fn slice_read_f32(
        &self,
        src: &GpuBufferHandle,
        n_batch: usize,
        d: usize,
        len: usize,
        max_len: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let src_buf = Self::unwrap_buffer(src)?;
        let dev = self.device(src.device_ordinal())?;
        let result = crate::kernels::gpu_slice_read(src_buf, n_batch, d, len, max_len, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, src.device_ordinal()))
    }

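    /// Embedding lookup: gathers the `d`-wide rows of `weight` selected by `idx`.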
    fn embed_lookup_f32(
        &self,
        idx: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        d: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let idx_buf = Self::unwrap_buffer(idx)?;
        let w_buf = Self::unwrap_buffer(weight)?;
        let dev = self.device(idx.device_ordinal())?;
        let result =
            crate::kernels::gpu_embed_lookup(idx_buf, w_buf, d, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, idx.device_ordinal()))
    }

    fn embed_lookup_batch_f32(
        &self,
        indices: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        n: usize,
        d: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let idx_buf = Self::unwrap_buffer(indices)?;
        let w_buf = Self::unwrap_buffer(weight)?;
        let dev = self.device(indices.device_ordinal())?;
        let result = crate::kernels::gpu_embed_lookup_batch(idx_buf, w_buf, n, d, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, indices.device_ordinal()))
    }

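    /// Scatter-add of `d`-wide rows of `grad_output` into a `num_embeddings` × `d`
    /// buffer at the positions given by `indices` (embedding backward).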
    fn scatter_add_rows_f32(
        &self,
        grad_output: &GpuBufferHandle,
        indices: &GpuBufferHandle,
        num_embeddings: usize,
        d: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let go_buf = Self::unwrap_buffer(grad_output)?;
        let idx_buf = Self::unwrap_buffer(indices)?;
        let dev = self.device(grad_output.device_ordinal())?;
        let result = crate::kernels::gpu_scatter_add_rows(go_buf, idx_buf, num_embeddings, d, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad_output.device_ordinal()))
    }

    fn scale_f32(&self, a: &GpuBufferHandle, scalar: f32) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_scale(a_buf, scalar, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn relu_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        input: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let input_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_relu_backward(grad_buf, input_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn gelu_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        input: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let input_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_gelu_backward(grad_buf, input_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn gelu_backward_tanh_f32(
        &self,
        grad: &GpuBufferHandle,
        input: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let input_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_gelu_backward_tanh(grad_buf, input_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn gelu_backward_erf_f32(
        &self,
        grad: &GpuBufferHandle,
        input: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let input_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_gelu_backward_erf(grad_buf, input_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

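    /// Cumulative sum along one dimension of a tensor flattened as
    /// `outer` × `dim_size` × `inner`; the scan runs over the `dim_size` axis.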
    fn cumsum_f32(
        &self,
        a: &GpuBufferHandle,
        outer: usize,
        dim_size: usize,
        inner: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_cumsum(a_buf, outer, dim_size, inner, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn cumprod_f32(
        &self,
        a: &GpuBufferHandle,
        outer: usize,
        dim_size: usize,
        inner: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_cumprod(a_buf, outer, dim_size, inner, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn cummax_f32(
        &self,
        a: &GpuBufferHandle,
        outer: usize,
        dim_size: usize,
        inner: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let (vals, idxs) = crate::kernels::gpu_cummax(a_buf, outer, dim_size, inner, dev)
            .map_err(Self::map_gpu_err)?;
        let ord = a.device_ordinal();
        Ok((Self::wrap_buffer(vals, ord), Self::wrap_buffer(idxs, ord)))
    }

    fn cummin_f32(
        &self,
        a: &GpuBufferHandle,
        outer: usize,
        dim_size: usize,
        inner: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let (vals, idxs) = crate::kernels::gpu_cummin(a_buf, outer, dim_size, inner, dev)
            .map_err(Self::map_gpu_err)?;
        let ord = a.device_ordinal();
        Ok((Self::wrap_buffer(vals, ord), Self::wrap_buffer(idxs, ord)))
    }

    fn logcumsumexp_f32(
        &self,
        a: &GpuBufferHandle,
        outer: usize,
        dim_size: usize,
        inner: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_logcumsumexp(a_buf, outer, dim_size, inner, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

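    /// Element-wise clamp of `a` to the closed range [`min_val`, `max_val`].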
    fn clamp_f32(
        &self,
        a: &GpuBufferHandle,
        min_val: f32,
        max_val: f32,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_clamp(a_buf, min_val, max_val, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn silu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_silu(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn silu_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        input: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let input_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_silu_backward(grad_buf, input_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn elu_f32(&self, a: &GpuBufferHandle, alpha: f32) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_elu(a_buf, alpha, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn elu_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        input: &GpuBufferHandle,
        alpha: f32,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let input_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_elu_backward(grad_buf, input_buf, alpha, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn mish_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_mish(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn mish_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        input: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let input_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_mish_backward(grad_buf, input_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn log_softmax_f32(
        &self,
        a: &GpuBufferHandle,
        cols: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_log_softmax(a_buf, cols, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn log_softmax_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        output: &GpuBufferHandle,
        cols: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let output_buf = Self::unwrap_buffer(output)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_log_softmax_backward(grad_buf, output_buf, cols, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn index_select_1d_f32(
        &self,
        input: &GpuBufferHandle,
        indices: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let input_buf = Self::unwrap_buffer(input)?;
        let idx_buf = Self::unwrap_buffer(indices)?;
        let dev = self.device(input.device_ordinal())?;
        let result = crate::kernels::gpu_index_select_1d(input_buf, idx_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, input.device_ordinal()))
    }

    fn scatter_add_1d_f32(
        &self,
        grad_output: &GpuBufferHandle,
        indices: &GpuBufferHandle,
        input_len: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let go_buf = Self::unwrap_buffer(grad_output)?;
        let idx_buf = Self::unwrap_buffer(indices)?;
        let dev = self.device(grad_output.device_ordinal())?;
        let result = crate::kernels::gpu_scatter_add_1d(go_buf, idx_buf, input_len, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad_output.device_ordinal()))
    }

    fn masked_fill_f32(
        &self,
        input: &GpuBufferHandle,
        mask: &GpuBufferHandle,
        value: f32,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let input_buf = Self::unwrap_buffer(input)?;
        let mask_buf = Self::unwrap_buffer(mask)?;
        let dev = self.device(input.device_ordinal())?;
        let result = crate::kernels::gpu_masked_fill(input_buf, mask_buf, value, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, input.device_ordinal()))
    }

    fn masked_zero_f32(
        &self,
        grad: &GpuBufferHandle,
        mask: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let mask_buf = Self::unwrap_buffer(mask)?;
        let dev = self.device(grad.device_ordinal())?;
        let result =
            crate::kernels::gpu_masked_zero(grad_buf, mask_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn sigmoid_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        output: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let output_buf = Self::unwrap_buffer(output)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_sigmoid_backward(grad_buf, output_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn tanh_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        output: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let output_buf = Self::unwrap_buffer(output)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_tanh_backward(grad_buf, output_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn softmax_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        output: &GpuBufferHandle,
        cols: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let output_buf = Self::unwrap_buffer(output)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_softmax_backward(grad_buf, output_buf, cols, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

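    /// Backward pass of layer normalization; returns
    /// `(grad_input, grad_weight, grad_bias)`.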
    fn layernorm_backward_f32(
        &self,
        input: &GpuBufferHandle,
        grad_output: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        rows: usize,
        cols: usize,
        eps: f32,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
        let in_buf = Self::unwrap_buffer(input)?;
        let go_buf = Self::unwrap_buffer(grad_output)?;
        let w_buf = Self::unwrap_buffer(weight)?;
        let dev = self.device(input.device_ordinal())?;
        let (gi, gw, gb) =
            crate::kernels::gpu_layernorm_backward(in_buf, go_buf, w_buf, rows, cols, eps, dev)
                .map_err(Self::map_gpu_err)?;
        let ordinal = input.device_ordinal();
        Ok((
            Self::wrap_buffer(gi, ordinal),
            Self::wrap_buffer(gw, ordinal),
            Self::wrap_buffer(gb, ordinal),
        ))
    }

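    /// Sums `a` over `axis` of `shape`, collapsing the dimensions before the axis
    /// into `outer` and those after it into `inner` before launching the kernel.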
    fn sum_axis_f32(
        &self,
        a: &GpuBufferHandle,
        shape: &[usize],
        axis: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let outer: usize = shape[..axis].iter().product();
        let axis_size = shape[axis];
        let inner: usize = shape[axis + 1..].iter().product::<usize>().max(1);
        let result = crate::kernels::gpu_sum_axis(a_buf, outer, axis_size, inner, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn matmul_f16_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        m: usize,
        k: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::blas::gpu_matmul_f16(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

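    /// Snapshots the Philox RNG state (counter, seed, offset) of `device` so it
    /// can later be restored with `restore_rng_state`.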
    fn save_rng_state(&self, device: usize) -> FerrotorchResult<GpuRngState> {
        let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
            FerrotorchError::InvalidArgument {
                message: "failed to lock CUDA RNG manager".into(),
            }
        })?;
        let state = mgr.get_rng_state(device);
        Ok(GpuRngState {
            counter: state.counter,
            seed: state.seed,
            offset: state.offset,
            device,
        })
    }

    fn restore_rng_state(&self, state: GpuRngState) -> FerrotorchResult<()> {
        let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
            FerrotorchError::InvalidArgument {
                message: "failed to lock CUDA RNG manager".into(),
            }
        })?;
        mgr.set_rng_state(
            state.device,
            crate::rng::PhiloxState {
                counter: state.counter,
                seed: state.seed,
                offset: state.offset,
            },
        );
        Ok(())
    }

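    /// Extracts the split `[split_offset, split_offset + split_size)` along an axis
    /// of length `total_along_axis`, with `inner_size` contiguous elements per axis step.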
    fn strided_split_f32(
        &self,
        input: &GpuBufferHandle,
        total_along_axis: usize,
        split_offset: usize,
        split_size: usize,
        inner_size: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let in_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(input.device_ordinal())?;
        let result = crate::kernels::gpu_strided_split(
            in_buf,
            total_along_axis,
            split_offset,
            split_size,
            inner_size,
            n,
            dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, input.device_ordinal()))
    }

    fn strided_copy_f32(
        &self,
        input: &GpuBufferHandle,
        out_shape: &[usize],
        src_strides: &[isize],
        src_offset: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let in_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(input.device_ordinal())?;
        let result =
            crate::kernels::gpu_strided_copy(in_buf, out_shape, src_strides, src_offset, dev)
                .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, input.device_ordinal()))
    }

    fn strided_copy_f64(
        &self,
        input: &GpuBufferHandle,
        out_shape: &[usize],
        src_strides: &[isize],
        src_offset: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let in_buf = Self::unwrap_buffer_f64(input)?;
        let dev = self.device(input.device_ordinal())?;
        let result = crate::kernels::gpu_strided_copy_f64(
            in_buf,
            out_shape,
            src_strides,
            src_offset,
            dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, input.device_ordinal()))
    }

    fn strided_cat_f32(
        &self,
        input: &GpuBufferHandle,
        output: &mut GpuBufferHandle,
        total_along_axis: usize,
        cat_offset: usize,
        part_size: usize,
        inner_size: usize,
        n: usize,
    ) -> FerrotorchResult<()> {
        let in_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(input.device_ordinal())?;
        let out_buf =
            output
                .downcast_mut::<CudaBuffer<f32>>()
                .ok_or(FerrotorchError::InvalidArgument {
                    message: "strided_cat_f32: output is not CudaBuffer<f32>".into(),
                })?;
        crate::kernels::gpu_strided_cat(
            in_buf,
            out_buf,
            total_along_axis,
            cat_offset,
            part_size,
            inner_size,
            n,
            dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok(())
    }

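    /// Singular value decomposition of an `m` × `n` matrix. The current
    /// implementation stages the data through host memory, calls the cuSOLVER
    /// wrapper, and then uploads U, S, and Vᵀ back to the device.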
    fn svd_f32(
        &self,
        a: &GpuBufferHandle,
        m: usize,
        n: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
        let (u, s, vt) =
            crate::cusolver::gpu_svd_f32(&a_host, m, n, dev).map_err(Self::map_gpu_err)?;
        let u_buf = crate::transfer::cpu_to_gpu(&u, dev).map_err(Self::map_gpu_err)?;
        let s_buf = crate::transfer::cpu_to_gpu(&s, dev).map_err(Self::map_gpu_err)?;
        let vt_buf = crate::transfer::cpu_to_gpu(&vt, dev).map_err(Self::map_gpu_err)?;
        let ord = a.device_ordinal();
        Ok((
            Self::wrap_buffer(u_buf, ord),
            Self::wrap_buffer(s_buf, ord),
            Self::wrap_buffer(vt_buf, ord),
        ))
    }

    fn svd_f64(
        &self,
        a: &GpuBufferHandle,
        m: usize,
        n: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
        let (u, s, vt) =
            crate::cusolver::gpu_svd_f64(&a_host, m, n, dev).map_err(Self::map_gpu_err)?;
        let u_buf = crate::transfer::cpu_to_gpu(&u, dev).map_err(Self::map_gpu_err)?;
        let s_buf = crate::transfer::cpu_to_gpu(&s, dev).map_err(Self::map_gpu_err)?;
        let vt_buf = crate::transfer::cpu_to_gpu(&vt, dev).map_err(Self::map_gpu_err)?;
        let ord = a.device_ordinal();
        Ok((
            Self::wrap_buffer_f64(u_buf, ord),
            Self::wrap_buffer_f64(s_buf, ord),
            Self::wrap_buffer_f64(vt_buf, ord),
        ))
    }

    fn cholesky_f32(&self, a: &GpuBufferHandle, n: usize) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
        let l = crate::cusolver::gpu_cholesky_f32(&a_host, n, dev).map_err(Self::map_gpu_err)?;
        let l_buf = crate::transfer::cpu_to_gpu(&l, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(l_buf, a.device_ordinal()))
    }

    fn cholesky_f64(&self, a: &GpuBufferHandle, n: usize) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
        let l = crate::cusolver::gpu_cholesky_f64(&a_host, n, dev).map_err(Self::map_gpu_err)?;
        let l_buf = crate::transfer::cpu_to_gpu(&l, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(l_buf, a.device_ordinal()))
    }

    fn solve_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        n: usize,
        nrhs: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
        let b_host = crate::transfer::gpu_to_cpu(b_buf, dev).map_err(Self::map_gpu_err)?;
        let x = crate::cusolver::gpu_solve_f32(&a_host, &b_host, n, nrhs, dev)
            .map_err(Self::map_gpu_err)?;
        let x_buf = crate::transfer::cpu_to_gpu(&x, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(x_buf, a.device_ordinal()))
    }

    fn solve_f64(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        n: usize,
        nrhs: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let b_buf = Self::unwrap_buffer_f64(b)?;
        let dev = self.device(a.device_ordinal())?;
        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
        let b_host = crate::transfer::gpu_to_cpu(b_buf, dev).map_err(Self::map_gpu_err)?;
        let x = crate::cusolver::gpu_solve_f64(&a_host, &b_host, n, nrhs, dev)
            .map_err(Self::map_gpu_err)?;
        let x_buf = crate::transfer::cpu_to_gpu(&x, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(x_buf, a.device_ordinal()))
    }

    fn qr_f32(
        &self,
        a: &GpuBufferHandle,
        m: usize,
        n: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
        let (q, r) =
            crate::cusolver::gpu_qr_f32(&a_host, m, n, dev).map_err(Self::map_gpu_err)?;
        let q_buf = crate::transfer::cpu_to_gpu(&q, dev).map_err(Self::map_gpu_err)?;
        let r_buf = crate::transfer::cpu_to_gpu(&r, dev).map_err(Self::map_gpu_err)?;
        let ord = a.device_ordinal();
        Ok((Self::wrap_buffer(q_buf, ord), Self::wrap_buffer(r_buf, ord)))
    }

    fn qr_f64(
        &self,
        a: &GpuBufferHandle,
        m: usize,
        n: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let dev = self.device(a.device_ordinal())?;
        let a_host = crate::transfer::gpu_to_cpu(a_buf, dev).map_err(Self::map_gpu_err)?;
        let (q, r) =
            crate::cusolver::gpu_qr_f64(&a_host, m, n, dev).map_err(Self::map_gpu_err)?;
        let q_buf = crate::transfer::cpu_to_gpu(&q, dev).map_err(Self::map_gpu_err)?;
        let r_buf = crate::transfer::cpu_to_gpu(&r, dev).map_err(Self::map_gpu_err)?;
        let ord = a.device_ordinal();
        Ok((
            Self::wrap_buffer_f64(q_buf, ord),
            Self::wrap_buffer_f64(r_buf, ord),
        ))
    }
}

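/// Returns the default CUDA device of the globally registered backend, or an
/// error if no backend is registered or the registered backend is not CUDA.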
pub fn get_cuda_device() -> FerrotorchResult<Arc<GpuDevice>> {
    let backend =
        ferrotorch_core::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let cuda_backend = backend.as_any().downcast_ref::<CudaBackendImpl>().ok_or(
        FerrotorchError::InvalidArgument {
            message: "registered GPU backend is not CudaBackendImpl".into(),
        },
    )?;
    Ok(Arc::clone(cuda_backend.default_device()?))
}

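/// Registers a `CudaBackendImpl` as the process-wide GPU backend if none is
/// registered yet; calling it again is a no-op.
///
/// A minimal usage sketch (paths assume this crate and `ferrotorch_core` are in scope):
///
/// ```ignore
/// init_cuda_backend()?;
/// assert!(ferrotorch_core::gpu_dispatch::has_gpu_backend());
/// let backend = ferrotorch_core::gpu_dispatch::gpu_backend().expect("backend registered");
/// ```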
pub fn init_cuda_backend() -> FerrotorchResult<()> {
    if ferrotorch_core::gpu_dispatch::has_gpu_backend() {
        return Ok(());
    }
    let backend = CudaBackendImpl::new()?;
    let _ = ferrotorch_core::gpu_dispatch::register_gpu_backend(Box::new(backend));
    Ok(())
}

#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
    use super::*;
    use ferrotorch_core::gpu_dispatch;

    fn ensure_init() {
        if !gpu_dispatch::has_gpu_backend() {
            init_cuda_backend().expect("init_cuda_backend");
        }
    }

    #[test]
    fn test_init_cuda_backend() {
        ensure_init();
        assert!(gpu_dispatch::has_gpu_backend());
    }

    #[test]
    fn test_gpu_backend_returns_some() {
        ensure_init();
        assert!(gpu_dispatch::gpu_backend().is_some());
    }

    #[test]
    fn test_roundtrip_cpu_gpu_cpu() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");

        let host: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let bytes: &[u8] = unsafe {
            std::slice::from_raw_parts(
                host.as_ptr() as *const u8,
                host.len() * std::mem::size_of::<f32>(),
            )
        };

        let handle = backend.cpu_to_gpu(bytes, 4, 0).expect("cpu_to_gpu");
        assert_eq!(handle.len(), 5);
        assert_eq!(handle.device_ordinal(), 0);

        let back_bytes = backend.gpu_to_cpu(&handle).expect("gpu_to_cpu");
        let back: &[f32] = unsafe {
            std::slice::from_raw_parts(back_bytes.as_ptr() as *const f32, back_bytes.len() / 4)
        };
        assert_eq!(back, &host[..]);
    }

    #[test]
    fn test_add_f32() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");

        let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let b_data: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
        let expected: Vec<f32> = vec![11.0, 22.0, 33.0, 44.0];

        let a_bytes: &[u8] =
            unsafe { std::slice::from_raw_parts(a_data.as_ptr() as *const u8, a_data.len() * 4) };
        let b_bytes: &[u8] =
            unsafe { std::slice::from_raw_parts(b_data.as_ptr() as *const u8, b_data.len() * 4) };

        let a_handle = backend.cpu_to_gpu(a_bytes, 4, 0).expect("cpu_to_gpu a");
        let b_handle = backend.cpu_to_gpu(b_bytes, 4, 0).expect("cpu_to_gpu b");

        let result = backend.add_f32(&a_handle, &b_handle).expect("add_f32");
        assert_eq!(result.len(), 4);

        let result_bytes = backend.gpu_to_cpu(&result).expect("gpu_to_cpu");
        let result_f32: &[f32] = unsafe {
            std::slice::from_raw_parts(result_bytes.as_ptr() as *const f32, result_bytes.len() / 4)
        };

        for (i, (&got, &exp)) in result_f32.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-6,
                "element {i}: got {got}, expected {exp}",
            );
        }
    }

    #[test]
    fn test_matmul_f32() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");

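        // a is 2x3 ([[1, 2, 3], [4, 5, 6]]) and b is 3x2 ([[7, 8], [9, 10], [11, 12]]),
        // so a · b is the 2x2 matrix [[58, 64], [139, 154]] checked below.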
        let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let expected: Vec<f32> = vec![58.0, 64.0, 139.0, 154.0];

        let a_bytes: &[u8] =
            unsafe { std::slice::from_raw_parts(a_data.as_ptr() as *const u8, a_data.len() * 4) };
        let b_bytes: &[u8] =
            unsafe { std::slice::from_raw_parts(b_data.as_ptr() as *const u8, b_data.len() * 4) };

        let a_handle = backend.cpu_to_gpu(a_bytes, 4, 0).expect("cpu_to_gpu a");
        let b_handle = backend.cpu_to_gpu(b_bytes, 4, 0).expect("cpu_to_gpu b");

        let result = backend
            .matmul_f32(&a_handle, &b_handle, 2, 3, 2)
            .expect("matmul_f32");
        assert_eq!(result.len(), 4);

        let result_bytes = backend.gpu_to_cpu(&result).expect("gpu_to_cpu");
        let result_f32: &[f32] = unsafe {
            std::slice::from_raw_parts(result_bytes.as_ptr() as *const f32, result_bytes.len() / 4)
        };

        for (i, (&got, &exp)) in result_f32.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-3,
                "element {i}: got {got}, expected {exp}",
            );
        }
    }
}