1use std::sync::Arc;
15
16use ferrotorch_core::error::{FerrotorchError, FerrotorchResult};
17use ferrotorch_core::gpu_dispatch::{GpuBackend, GpuBufferHandle, GpuRngState};
18
19use crate::buffer::CudaBuffer;
20use crate::device::GpuDevice;
21
/// CUDA implementation of the `GpuBackend` trait.
///
/// Holds one `GpuDevice` per initialized CUDA ordinal; `new` currently
/// opens only device 0.
pub struct CudaBackendImpl {
    // Device table indexed by CUDA ordinal; `Arc` so device handles can be
    // shared with buffers/streams beyond a borrow of the backend.
    devices: Vec<Arc<GpuDevice>>,
}
34
35impl CudaBackendImpl {
36 pub fn new() -> FerrotorchResult<Self> {
43 let device = Arc::new(
44 GpuDevice::new(0).map_err(|e| FerrotorchError::InvalidArgument {
45 message: format!("CUDA init failed: {e}"),
46 })?,
47 );
48 Ok(Self {
49 devices: vec![device],
50 })
51 }
52
53 pub fn default_device(&self) -> FerrotorchResult<&Arc<GpuDevice>> {
55 self.device(0)
56 }
57
58 fn device(&self, ordinal: usize) -> FerrotorchResult<&Arc<GpuDevice>> {
60 self.devices
61 .get(ordinal)
62 .ok_or(FerrotorchError::InvalidArgument {
63 message: format!("CUDA device {ordinal} not available"),
64 })
65 }
66
67 fn wrap_buffer(buf: CudaBuffer<f32>, ordinal: usize) -> GpuBufferHandle {
69 let len = buf.len();
70 GpuBufferHandle::new(Box::new(buf), ordinal, len)
71 }
72
73 fn wrap_buffer_f64(buf: CudaBuffer<f64>, ordinal: usize) -> GpuBufferHandle {
75 let len = buf.len();
76 GpuBufferHandle::new(Box::new(buf), ordinal, len)
77 }
78
79 fn unwrap_buffer(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<f32>> {
81 handle
82 .downcast_ref::<CudaBuffer<f32>>()
83 .ok_or(FerrotorchError::InvalidArgument {
84 message: "GPU handle does not contain a CudaBuffer<f32>".into(),
85 })
86 }
87
88 fn unwrap_buffer_mut(handle: &mut GpuBufferHandle) -> FerrotorchResult<&mut CudaBuffer<f32>> {
90 handle
91 .downcast_mut::<CudaBuffer<f32>>()
92 .ok_or(FerrotorchError::InvalidArgument {
93 message: "GPU handle does not contain a CudaBuffer<f32>".into(),
94 })
95 }
96
97 fn unwrap_buffer_f64(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<f64>> {
99 handle
100 .downcast_ref::<CudaBuffer<f64>>()
101 .ok_or(FerrotorchError::InvalidArgument {
102 message: "GPU handle does not contain a CudaBuffer<f64>".into(),
103 })
104 }
105
106 fn map_gpu_err(e: crate::error::GpuError) -> FerrotorchError {
108 FerrotorchError::InvalidArgument {
109 message: format!("{e}"),
110 }
111 }
112}
113
114impl GpuBackend for CudaBackendImpl {
    /// Exposes `self` as `Any` so callers can downcast to the CUDA backend.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
122
123 fn cpu_to_gpu(
124 &self,
125 data: &[u8],
126 elem_size: usize,
127 device: usize,
128 ) -> FerrotorchResult<GpuBufferHandle> {
129 let dev = self.device(device)?;
130 match elem_size {
131 4 => {
132 let count = data.len() / 4;
135 let f32_data: &[f32] =
136 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, count) };
137 let buf = crate::transfer::cpu_to_gpu(f32_data, dev).map_err(Self::map_gpu_err)?;
138 Ok(Self::wrap_buffer(buf, device))
139 }
140 8 => {
141 let count = data.len() / 8;
144 let f64_data: &[f64] =
145 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f64, count) };
146 let buf = crate::transfer::cpu_to_gpu(f64_data, dev).map_err(Self::map_gpu_err)?;
147 Ok(Self::wrap_buffer_f64(buf, device))
148 }
149 other => Err(FerrotorchError::InvalidArgument {
150 message: format!("cpu_to_gpu: unsupported elem_size {other} (expected 4 or 8)"),
151 }),
152 }
153 }
154
155 fn cpu_to_gpu_pinned(
156 &self,
157 data: &[u8],
158 elem_size: usize,
159 device: usize,
160 ) -> FerrotorchResult<GpuBufferHandle> {
161 let dev = self.device(device)?;
162 match elem_size {
163 4 => {
164 let count = data.len() / 4;
165 let f32_data: &[f32] =
166 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, count) };
167 let buf = crate::transfer::cpu_to_gpu_pinned(f32_data, dev)
168 .map_err(Self::map_gpu_err)?;
169 Ok(Self::wrap_buffer(buf, device))
170 }
171 8 => {
172 let count = data.len() / 8;
173 let f64_data: &[f64] =
174 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f64, count) };
175 let buf = crate::transfer::cpu_to_gpu_pinned(f64_data, dev)
176 .map_err(Self::map_gpu_err)?;
177 Ok(Self::wrap_buffer_f64(buf, device))
178 }
179 other => Err(FerrotorchError::InvalidArgument {
180 message: format!(
181 "cpu_to_gpu_pinned: unsupported elem_size {other} (expected 4 or 8)"
182 ),
183 }),
184 }
185 }
186
187 fn gpu_to_cpu(&self, handle: &GpuBufferHandle) -> FerrotorchResult<Vec<u8>> {
188 let dev = self.device(handle.device_ordinal())?;
189
190 if let Ok(buf) = Self::unwrap_buffer(handle) {
192 let f32_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;
193
194 let bytes = unsafe {
199 let mut v = std::mem::ManuallyDrop::new(f32_data);
200 let ptr = v.as_mut_ptr() as *mut u8;
201 let len = v.len() * 4;
202 let cap = v.capacity() * 4;
203 Vec::from_raw_parts(ptr, len, cap)
204 };
205 Ok(bytes)
206 } else if let Ok(buf) = Self::unwrap_buffer_f64(handle) {
207 let f64_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;
208
209 let bytes = unsafe {
214 let mut v = std::mem::ManuallyDrop::new(f64_data);
215 let ptr = v.as_mut_ptr() as *mut u8;
216 let len = v.len() * 8;
217 let cap = v.capacity() * 8;
218 Vec::from_raw_parts(ptr, len, cap)
219 };
220 Ok(bytes)
221 } else {
222 Err(FerrotorchError::InvalidArgument {
223 message: "gpu_to_cpu: handle is neither CudaBuffer<f32> nor CudaBuffer<f64>".into(),
224 })
225 }
226 }
227
    /// Deep-copies a device buffer.
    ///
    /// NOTE(review): implemented as a host round-trip (D2H then H2D), not a
    /// device-to-device copy — correct, but slow for large buffers.
    fn clone_buffer(&self, handle: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let bytes = self.gpu_to_cpu(handle)?;
        // Recover the element width from the handle's concrete buffer type;
        // anything that isn't f64 is treated as f32.
        let elem_size = if handle.downcast_ref::<CudaBuffer<f64>>().is_some() {
            8
        } else {
            4
        };
        self.cpu_to_gpu(&bytes, elem_size, handle.device_ordinal())
    }
240
241 fn alloc_zeros(
242 &self,
243 len: usize,
244 elem_size: usize,
245 device: usize,
246 ) -> FerrotorchResult<GpuBufferHandle> {
247 let dev = self.device(device)?;
248 match elem_size {
249 4 => {
250 let buf = crate::transfer::alloc_zeros_f32(len, dev).map_err(Self::map_gpu_err)?;
251 Ok(Self::wrap_buffer(buf, device))
252 }
253 8 => {
254 let buf = crate::transfer::alloc_zeros_f64(len, dev).map_err(Self::map_gpu_err)?;
255 Ok(Self::wrap_buffer_f64(buf, device))
256 }
257 other => Err(FerrotorchError::InvalidArgument {
258 message: format!("alloc_zeros: unsupported elem_size {other} (expected 4 or 8)"),
259 }),
260 }
261 }
262
263 fn add_f32(
266 &self,
267 a: &GpuBufferHandle,
268 b: &GpuBufferHandle,
269 ) -> FerrotorchResult<GpuBufferHandle> {
270 let a_buf = Self::unwrap_buffer(a)?;
271 let b_buf = Self::unwrap_buffer(b)?;
272 let dev = self.device(a.device_ordinal())?;
273 let result = crate::kernels::gpu_add(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
274 Ok(Self::wrap_buffer(result, a.device_ordinal()))
275 }
276
277 fn sub_f32(
278 &self,
279 a: &GpuBufferHandle,
280 b: &GpuBufferHandle,
281 ) -> FerrotorchResult<GpuBufferHandle> {
282 let a_buf = Self::unwrap_buffer(a)?;
283 let b_buf = Self::unwrap_buffer(b)?;
284 let dev = self.device(a.device_ordinal())?;
285 let result = crate::kernels::gpu_sub(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
286 Ok(Self::wrap_buffer(result, a.device_ordinal()))
287 }
288
289 fn mul_f32(
290 &self,
291 a: &GpuBufferHandle,
292 b: &GpuBufferHandle,
293 ) -> FerrotorchResult<GpuBufferHandle> {
294 let a_buf = Self::unwrap_buffer(a)?;
295 let b_buf = Self::unwrap_buffer(b)?;
296 let dev = self.device(a.device_ordinal())?;
297 let result = crate::kernels::gpu_mul(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
298 Ok(Self::wrap_buffer(result, a.device_ordinal()))
299 }
300
301 fn neg_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
302 let a_buf = Self::unwrap_buffer(a)?;
303 let dev = self.device(a.device_ordinal())?;
304 let result = crate::kernels::gpu_neg(a_buf, dev).map_err(Self::map_gpu_err)?;
305 Ok(Self::wrap_buffer(result, a.device_ordinal()))
306 }
307
308 fn relu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
309 let a_buf = Self::unwrap_buffer(a)?;
310 let dev = self.device(a.device_ordinal())?;
311 let result = crate::kernels::gpu_relu(a_buf, dev).map_err(Self::map_gpu_err)?;
312 Ok(Self::wrap_buffer(result, a.device_ordinal()))
313 }
314
315 fn div_f32(
316 &self,
317 a: &GpuBufferHandle,
318 b: &GpuBufferHandle,
319 ) -> FerrotorchResult<GpuBufferHandle> {
320 let a_buf = Self::unwrap_buffer(a)?;
321 let b_buf = Self::unwrap_buffer(b)?;
322 let dev = self.device(a.device_ordinal())?;
323 let result = crate::kernels::gpu_div(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
324 Ok(Self::wrap_buffer(result, a.device_ordinal()))
325 }
326
327 fn exp_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
328 let a_buf = Self::unwrap_buffer(a)?;
329 let dev = self.device(a.device_ordinal())?;
330 let result = crate::kernels::gpu_exp(a_buf, dev).map_err(Self::map_gpu_err)?;
331 Ok(Self::wrap_buffer(result, a.device_ordinal()))
332 }
333
334 fn log_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
335 let a_buf = Self::unwrap_buffer(a)?;
336 let dev = self.device(a.device_ordinal())?;
337 let result = crate::kernels::gpu_log(a_buf, dev).map_err(Self::map_gpu_err)?;
338 Ok(Self::wrap_buffer(result, a.device_ordinal()))
339 }
340
341 fn sqrt_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
342 let a_buf = Self::unwrap_buffer(a)?;
343 let dev = self.device(a.device_ordinal())?;
344 let result = crate::kernels::gpu_sqrt(a_buf, dev).map_err(Self::map_gpu_err)?;
345 Ok(Self::wrap_buffer(result, a.device_ordinal()))
346 }
347
348 fn pow_f32(&self, a: &GpuBufferHandle, exponent: f32) -> FerrotorchResult<GpuBufferHandle> {
349 let a_buf = Self::unwrap_buffer(a)?;
350 let dev = self.device(a.device_ordinal())?;
351 let result =
352 crate::kernels::gpu_pow(a_buf, exponent, dev).map_err(Self::map_gpu_err)?;
353 Ok(Self::wrap_buffer(result, a.device_ordinal()))
354 }
355
356 fn abs_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
357 let a_buf = Self::unwrap_buffer(a)?;
358 let dev = self.device(a.device_ordinal())?;
359 let result = crate::kernels::gpu_abs(a_buf, dev).map_err(Self::map_gpu_err)?;
360 Ok(Self::wrap_buffer(result, a.device_ordinal()))
361 }
362
363 fn sigmoid_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
364 let a_buf = Self::unwrap_buffer(a)?;
365 let dev = self.device(a.device_ordinal())?;
366 let result = crate::kernels::gpu_sigmoid(a_buf, dev).map_err(Self::map_gpu_err)?;
367 Ok(Self::wrap_buffer(result, a.device_ordinal()))
368 }
369
370 fn tanh_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
371 let a_buf = Self::unwrap_buffer(a)?;
372 let dev = self.device(a.device_ordinal())?;
373 let result = crate::kernels::gpu_tanh(a_buf, dev).map_err(Self::map_gpu_err)?;
374 Ok(Self::wrap_buffer(result, a.device_ordinal()))
375 }
376
    /// Performs one fused Adam optimizer step entirely on the device.
    ///
    /// Mutates `param`, `exp_avg` (first moment), and `exp_avg_sq` (second
    /// moment) in place in a single kernel launch. `bc1`/`bc2` are the
    /// precomputed bias-correction factors. Only `param`'s device ordinal is
    /// consulted; assumes all four buffers live on that device — TODO confirm
    /// callers guarantee this.
    #[allow(clippy::too_many_arguments)]
    fn fused_adam_f32(
        &self,
        param: &mut GpuBufferHandle,
        grad: &GpuBufferHandle,
        exp_avg: &mut GpuBufferHandle,
        exp_avg_sq: &mut GpuBufferHandle,
        beta1: f32,
        beta2: f32,
        lr: f32,
        eps: f32,
        bc1: f32,
        bc2: f32,
        weight_decay: f32,
    ) -> FerrotorchResult<()> {
        // Read the ordinal before taking the mutable borrow of `param`.
        let ordinal = param.device_ordinal();
        let dev = self.device(ordinal)?;
        let p_buf = Self::unwrap_buffer_mut(param)?;
        let g_buf = Self::unwrap_buffer(grad)?;
        let m_buf = Self::unwrap_buffer_mut(exp_avg)?;
        let v_buf = Self::unwrap_buffer_mut(exp_avg_sq)?;
        crate::kernels::gpu_fused_adam(
            p_buf,
            g_buf,
            m_buf,
            v_buf,
            beta1,
            beta2,
            lr,
            eps,
            bc1,
            bc2,
            weight_decay,
            dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok(())
    }
415
    /// 2-D max pooling over a `[batch, channels, h_in, w_in]` f32 input.
    ///
    /// `kh/kw`, `sh/sw`, `ph/pw` are kernel size, stride, and padding.
    /// Returns the pooled buffer plus the output shape computed by the kernel.
    #[allow(clippy::too_many_arguments)]
    fn maxpool2d_f32(
        &self,
        input: &GpuBufferHandle,
        batch: usize,
        channels: usize,
        h_in: usize,
        w_in: usize,
        kh: usize,
        kw: usize,
        sh: usize,
        sw: usize,
        ph: usize,
        pw: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
        let buf = Self::unwrap_buffer(input)?;
        let dev = self.device(input.device_ordinal())?;
        let (out, shape) = crate::kernels::gpu_maxpool2d(
            buf, batch, channels, h_in, w_in, kh, kw, sh, sw, ph, pw, dev,
        ).map_err(Self::map_gpu_err)?;
        Ok((Self::wrap_buffer(out, input.device_ordinal()), shape))
    }
438
    /// 2-D average pooling over a `[batch, channels, h_in, w_in]` f32 input.
    ///
    /// Same parameter conventions as `maxpool2d_f32`; returns the pooled
    /// buffer plus the output shape computed by the kernel.
    #[allow(clippy::too_many_arguments)]
    fn avgpool2d_f32(
        &self,
        input: &GpuBufferHandle,
        batch: usize,
        channels: usize,
        h_in: usize,
        w_in: usize,
        kh: usize,
        kw: usize,
        sh: usize,
        sw: usize,
        ph: usize,
        pw: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
        let buf = Self::unwrap_buffer(input)?;
        let dev = self.device(input.device_ordinal())?;
        let (out, shape) = crate::kernels::gpu_avgpool2d(
            buf, batch, channels, h_in, w_in, kh, kw, sh, sw, ph, pw, dev,
        ).map_err(Self::map_gpu_err)?;
        Ok((Self::wrap_buffer(out, input.device_ordinal()), shape))
    }
461
    /// 2-D convolution with optional bias.
    ///
    /// `input_shape` and `weight_shape` are 4-D; the exact axis meaning
    /// (presumably NCHW input / OIHW weights) is defined by
    /// `crate::conv::gpu_conv2d_f32` — confirm there. Returns the output
    /// buffer and its 4-D shape.
    #[allow(clippy::too_many_arguments)]
    fn conv2d_f32(
        &self,
        input: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        bias: Option<&GpuBufferHandle>,
        input_shape: [usize; 4],
        weight_shape: [usize; 4],
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
        let input_buf = Self::unwrap_buffer(input)?;
        let weight_buf = Self::unwrap_buffer(weight)?;
        // Downcast the optional bias handle; propagate a downcast failure.
        let bias_buf = match bias {
            Some(b) => Some(Self::unwrap_buffer(b)?),
            None => None,
        };
        let dev = self.device(input.device_ordinal())?;
        let (out_buf, out_shape) = crate::conv::gpu_conv2d_f32(
            input_buf,
            weight_buf,
            bias_buf,
            input_shape,
            weight_shape,
            stride,
            padding,
            dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok((Self::wrap_buffer(out_buf, input.device_ordinal()), out_shape))
    }
493
    /// Fused GRU cell forward step.
    ///
    /// `input_gates`/`hidden_gates` are the precomputed gate projections and
    /// `hx` the previous hidden state. Returns the new hidden state plus a
    /// second buffer (presumably a workspace of saved activations for the
    /// backward pass — confirm against `gpu_fused_gru_forward`).
    fn fused_gru_cell_f32(
        &self,
        input_gates: &GpuBufferHandle,
        hidden_gates: &GpuBufferHandle,
        bias_ih: &GpuBufferHandle,
        bias_hh: &GpuBufferHandle,
        hx: &GpuBufferHandle,
        hidden_size: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
        let ig = Self::unwrap_buffer(input_gates)?;
        let hg = Self::unwrap_buffer(hidden_gates)?;
        let bih = Self::unwrap_buffer(bias_ih)?;
        let bhh = Self::unwrap_buffer(bias_hh)?;
        let hx_buf = Self::unwrap_buffer(hx)?;
        let dev = self.device(input_gates.device_ordinal())?;
        let (hy, ws) = crate::kernels::gpu_fused_gru_forward(
            ig, hg, bih, bhh, hx_buf, hidden_size, dev,
        )
        .map_err(Self::map_gpu_err)?;
        let ord = input_gates.device_ordinal();
        Ok((Self::wrap_buffer(hy, ord), Self::wrap_buffer(ws, ord)))
    }
516
517 fn synchronize(&self, device: usize) -> FerrotorchResult<()> {
518 let dev = self.device(device)?;
519 dev.stream()
520 .synchronize()
521 .map_err(|e| FerrotorchError::InvalidArgument {
522 message: format!("CUDA synchronize failed: {e}"),
523 })?;
524 Ok(())
525 }
526
    /// Number of CUDA streams in the per-device stream pool.
    fn stream_count(&self, device: usize) -> usize {
        crate::stream::StreamPool::pool_size(device)
    }
530
531 fn matmul_f32(
534 &self,
535 a: &GpuBufferHandle,
536 b: &GpuBufferHandle,
537 m: usize,
538 k: usize,
539 n: usize,
540 ) -> FerrotorchResult<GpuBufferHandle> {
541 let a_buf = Self::unwrap_buffer(a)?;
542 let b_buf = Self::unwrap_buffer(b)?;
543 let dev = self.device(a.device_ordinal())?;
544 let result =
545 crate::blas::gpu_matmul_f32(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
546 Ok(Self::wrap_buffer(result, a.device_ordinal()))
547 }
548
    /// Sums all elements of an f32 buffer into a device-side result.
    /// `_len` is unused here — presumably the kernel derives the length from
    /// the buffer itself; confirm before relying on the parameter.
    fn sum_f32(&self, a: &GpuBufferHandle, _len: usize) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_reduce_sum(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }
557
558 fn matmul_f64(
561 &self,
562 a: &GpuBufferHandle,
563 b: &GpuBufferHandle,
564 m: usize,
565 k: usize,
566 n: usize,
567 ) -> FerrotorchResult<GpuBufferHandle> {
568 let a_buf = Self::unwrap_buffer_f64(a)?;
569 let b_buf = Self::unwrap_buffer_f64(b)?;
570 let dev = self.device(a.device_ordinal())?;
571 let result =
572 crate::blas::gpu_matmul_f64(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
573 Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
574 }
575
576 fn broadcast_add_f32(
579 &self,
580 a: &GpuBufferHandle,
581 b: &GpuBufferHandle,
582 a_shape: &[usize],
583 b_shape: &[usize],
584 out_shape: &[usize],
585 ) -> FerrotorchResult<GpuBufferHandle> {
586 let a_buf = Self::unwrap_buffer(a)?;
587 let b_buf = Self::unwrap_buffer(b)?;
588 let dev = self.device(a.device_ordinal())?;
589 let result =
590 crate::kernels::gpu_broadcast_add(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
591 .map_err(Self::map_gpu_err)?;
592 Ok(Self::wrap_buffer(result, a.device_ordinal()))
593 }
594
595 fn broadcast_sub_f32(
596 &self,
597 a: &GpuBufferHandle,
598 b: &GpuBufferHandle,
599 a_shape: &[usize],
600 b_shape: &[usize],
601 out_shape: &[usize],
602 ) -> FerrotorchResult<GpuBufferHandle> {
603 let a_buf = Self::unwrap_buffer(a)?;
604 let b_buf = Self::unwrap_buffer(b)?;
605 let dev = self.device(a.device_ordinal())?;
606 let result =
607 crate::kernels::gpu_broadcast_sub(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
608 .map_err(Self::map_gpu_err)?;
609 Ok(Self::wrap_buffer(result, a.device_ordinal()))
610 }
611
612 fn broadcast_mul_f32(
613 &self,
614 a: &GpuBufferHandle,
615 b: &GpuBufferHandle,
616 a_shape: &[usize],
617 b_shape: &[usize],
618 out_shape: &[usize],
619 ) -> FerrotorchResult<GpuBufferHandle> {
620 let a_buf = Self::unwrap_buffer(a)?;
621 let b_buf = Self::unwrap_buffer(b)?;
622 let dev = self.device(a.device_ordinal())?;
623 let result =
624 crate::kernels::gpu_broadcast_mul(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
625 .map_err(Self::map_gpu_err)?;
626 Ok(Self::wrap_buffer(result, a.device_ordinal()))
627 }
628
629 fn broadcast_div_f32(
630 &self,
631 a: &GpuBufferHandle,
632 b: &GpuBufferHandle,
633 a_shape: &[usize],
634 b_shape: &[usize],
635 out_shape: &[usize],
636 ) -> FerrotorchResult<GpuBufferHandle> {
637 let a_buf = Self::unwrap_buffer(a)?;
638 let b_buf = Self::unwrap_buffer(b)?;
639 let dev = self.device(a.device_ordinal())?;
640 let result =
641 crate::kernels::gpu_broadcast_div(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
642 .map_err(Self::map_gpu_err)?;
643 Ok(Self::wrap_buffer(result, a.device_ordinal()))
644 }
645
646 fn softmax_f32(
647 &self,
648 a: &GpuBufferHandle,
649 rows: usize,
650 cols: usize,
651 ) -> FerrotorchResult<GpuBufferHandle> {
652 let a_buf = Self::unwrap_buffer(a)?;
653 let dev = self.device(a.device_ordinal())?;
654 let result =
655 crate::kernels::gpu_softmax(a_buf, rows, cols, dev).map_err(Self::map_gpu_err)?;
656 Ok(Self::wrap_buffer(result, a.device_ordinal()))
657 }
658
659 fn dropout_f32(
660 &self,
661 a: &GpuBufferHandle,
662 threshold: u32,
663 scale: f32,
664 seed: u32,
665 ) -> FerrotorchResult<GpuBufferHandle> {
666 let a_buf = Self::unwrap_buffer(a)?;
667 let dev = self.device(a.device_ordinal())?;
668 let result = crate::kernels::gpu_dropout(a_buf, threshold, scale, seed, dev)
669 .map_err(Self::map_gpu_err)?;
670 Ok(Self::wrap_buffer(result, a.device_ordinal()))
671 }
672
    /// Dropout driven by the per-device Philox RNG manager; returns the
    /// output buffer together with the RNG state consumed by this launch.
    fn dropout_philox_f32(
        &self,
        a: &GpuBufferHandle,
        threshold: u32,
        scale: f32,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuRngState)> {
        let device_ordinal = a.device_ordinal();
        let n = a.len();

        // Reserve a slice of the Philox counter stream before launching:
        // snapshot the current state, then advance the generator so
        // concurrent callers never reuse counters. The lock scope is kept
        // to this block only.
        let rng_state = {
            let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
                FerrotorchError::InvalidArgument {
                    message: "failed to lock CUDA RNG manager".into(),
                }
            })?;
            let philox_gen = mgr.generator(device_ordinal);
            let state = philox_gen.get_state();
            // div_ceil(4): presumably Philox yields 4 random values per
            // counter tick — confirm against the generator implementation.
            let counters_needed = n.div_ceil(4);
            philox_gen.advance(counters_needed as u64);
            state
        };

        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(device_ordinal)?;

        // NOTE(review): the kernel takes a single u32 seed, so the wider
        // counter/seed pair is folded via XOR and a truncating cast. Distinct
        // RNG states can collide after truncation — confirm acceptable.
        let derived_seed = (rng_state.counter ^ rng_state.seed) as u32;
        let result = crate::kernels::gpu_dropout(a_buf, threshold, scale, derived_seed, dev)
            .map_err(Self::map_gpu_err)?;

        // Report the pre-advance state (the one this launch consumed).
        let gpu_rng_state = GpuRngState {
            counter: rng_state.counter,
            seed: rng_state.seed,
            offset: rng_state.offset,
            device: device_ordinal,
        };

        Ok((Self::wrap_buffer(result, device_ordinal), gpu_rng_state))
    }
722
723 fn transpose_2d_f32(
724 &self,
725 a: &GpuBufferHandle,
726 m: usize,
727 n: usize,
728 ) -> FerrotorchResult<GpuBufferHandle> {
729 let a_buf = Self::unwrap_buffer(a)?;
730 let dev = self.device(a.device_ordinal())?;
731 let result =
732 crate::kernels::gpu_transpose_2d(a_buf, m, n, dev).map_err(Self::map_gpu_err)?;
733 Ok(Self::wrap_buffer(result, a.device_ordinal()))
734 }
735
736 fn permute_0213_f32(
737 &self,
738 a: &GpuBufferHandle,
739 d0: usize,
740 d1: usize,
741 d2: usize,
742 d3: usize,
743 ) -> FerrotorchResult<GpuBufferHandle> {
744 let a_buf = Self::unwrap_buffer(a)?;
745 let dev = self.device(a.device_ordinal())?;
746 let result = crate::kernels::gpu_permute_0213(a_buf, d0, d1, d2, d3, dev)
747 .map_err(Self::map_gpu_err)?;
748 Ok(Self::wrap_buffer(result, a.device_ordinal()))
749 }
750
751 fn bmm_f32(
752 &self,
753 a: &GpuBufferHandle,
754 b: &GpuBufferHandle,
755 batch: usize,
756 m: usize,
757 k: usize,
758 n: usize,
759 ) -> FerrotorchResult<GpuBufferHandle> {
760 let a_buf = Self::unwrap_buffer(a)?;
761 let b_buf = Self::unwrap_buffer(b)?;
762 let dev = self.device(a.device_ordinal())?;
763 let result = crate::blas::gpu_bmm_f32(a_buf, b_buf, batch, m, k, n, dev)
764 .map_err(Self::map_gpu_err)?;
765 Ok(Self::wrap_buffer(result, a.device_ordinal()))
766 }
767
    /// Batched matrix multiply with an f16 kernel path.
    ///
    /// Inputs and output are f32 buffers; presumably `gpu_bmm_f16`
    /// downconverts internally (e.g. for tensor cores) — confirm precision
    /// implications with the BLAS module.
    fn bmm_f16_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        batch: usize,
        m: usize,
        k: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::blas::gpu_bmm_f16(a_buf, b_buf, batch, m, k, n, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }
784
785 fn gelu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
786 let a_buf = Self::unwrap_buffer(a)?;
787 let dev = self.device(a.device_ordinal())?;
788 let result = crate::kernels::gpu_gelu(a_buf, dev).map_err(Self::map_gpu_err)?;
789 Ok(Self::wrap_buffer(result, a.device_ordinal()))
790 }
791
792 fn gelu_tanh_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
793 let a_buf = Self::unwrap_buffer(a)?;
794 let dev = self.device(a.device_ordinal())?;
795 let result = crate::kernels::gpu_gelu_tanh(a_buf, dev).map_err(Self::map_gpu_err)?;
796 Ok(Self::wrap_buffer(result, a.device_ordinal()))
797 }
798
799 fn gelu_erf_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
800 let a_buf = Self::unwrap_buffer(a)?;
801 let dev = self.device(a.device_ordinal())?;
802 let result = crate::kernels::gpu_gelu_erf(a_buf, dev).map_err(Self::map_gpu_err)?;
803 Ok(Self::wrap_buffer(result, a.device_ordinal()))
804 }
805
806 fn layernorm_f32(
807 &self,
808 input: &GpuBufferHandle,
809 weight: &GpuBufferHandle,
810 bias: &GpuBufferHandle,
811 rows: usize,
812 cols: usize,
813 eps: f32,
814 ) -> FerrotorchResult<GpuBufferHandle> {
815 let in_buf = Self::unwrap_buffer(input)?;
816 let w_buf = Self::unwrap_buffer(weight)?;
817 let b_buf = Self::unwrap_buffer(bias)?;
818 let dev = self.device(input.device_ordinal())?;
819 let result = crate::kernels::gpu_layernorm(in_buf, w_buf, b_buf, rows, cols, eps, dev)
820 .map_err(Self::map_gpu_err)?;
821 Ok(Self::wrap_buffer(result, input.device_ordinal()))
822 }
823
824 fn rmsnorm_f32(
825 &self,
826 input: &GpuBufferHandle,
827 weight: &GpuBufferHandle,
828 rows: usize,
829 cols: usize,
830 eps: f32,
831 ) -> FerrotorchResult<GpuBufferHandle> {
832 let in_buf = Self::unwrap_buffer(input)?;
833 let w_buf = Self::unwrap_buffer(weight)?;
834 let dev = self.device(input.device_ordinal())?;
835 let result = crate::kernels::gpu_rmsnorm(in_buf, w_buf, rows, cols, eps, dev)
836 .map_err(Self::map_gpu_err)?;
837 Ok(Self::wrap_buffer(result, input.device_ordinal()))
838 }
839
    /// Backward pass for RMSNorm: returns `(grad_input, grad_weight)`.
    fn rmsnorm_backward_f32(
        &self,
        input: &GpuBufferHandle,
        grad_output: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        rows: usize,
        cols: usize,
        eps: f32,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
        let in_buf = Self::unwrap_buffer(input)?;
        let go_buf = Self::unwrap_buffer(grad_output)?;
        let w_buf = Self::unwrap_buffer(weight)?;
        let dev = self.device(input.device_ordinal())?;
        let (gi, gw) =
            crate::kernels::gpu_rmsnorm_backward(in_buf, go_buf, w_buf, rows, cols, eps, dev)
                .map_err(Self::map_gpu_err)?;
        let ordinal = input.device_ordinal();
        Ok((Self::wrap_buffer(gi, ordinal), Self::wrap_buffer(gw, ordinal)))
    }
859
860 fn slice_write_f32(
861 &self,
862 src: &GpuBufferHandle,
863 dst: &mut GpuBufferHandle,
864 n_batch: usize,
865 d: usize,
866 max_len: usize,
867 pos: usize,
868 ) -> FerrotorchResult<()> {
869 let src_buf = Self::unwrap_buffer(src)?;
870 let dst_buf =
871 dst.downcast_mut::<CudaBuffer<f32>>()
872 .ok_or(FerrotorchError::InvalidArgument {
873 message: "slice_write_f32: dst is not CudaBuffer<f32>".into(),
874 })?;
875 let dev = self.device(src.device_ordinal())?;
876 crate::kernels::gpu_slice_write(src_buf, dst_buf, n_batch, d, max_len, pos, dev)
877 .map_err(Self::map_gpu_err)?;
878 Ok(())
879 }
880
881 fn slice_read_f32(
882 &self,
883 src: &GpuBufferHandle,
884 n_batch: usize,
885 d: usize,
886 len: usize,
887 max_len: usize,
888 ) -> FerrotorchResult<GpuBufferHandle> {
889 let src_buf = Self::unwrap_buffer(src)?;
890 let dev = self.device(src.device_ordinal())?;
891 let result = crate::kernels::gpu_slice_read(src_buf, n_batch, d, len, max_len, dev)
892 .map_err(Self::map_gpu_err)?;
893 Ok(Self::wrap_buffer(result, src.device_ordinal()))
894 }
895
    /// Embedding lookup: gathers rows of width `d` from `weight` at the
    /// positions given by `idx`. Indices arrive in an f32 buffer (the only
    /// type `unwrap_buffer` accepts); presumably the kernel casts them to
    /// integers — confirm.
    fn embed_lookup_f32(
        &self,
        idx: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        d: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let idx_buf = Self::unwrap_buffer(idx)?;
        let w_buf = Self::unwrap_buffer(weight)?;
        let dev = self.device(idx.device_ordinal())?;
        let result =
            crate::kernels::gpu_embed_lookup(idx_buf, w_buf, d, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, idx.device_ordinal()))
    }
909
    /// Batched embedding lookup of `n` indices, each selecting a row of
    /// width `d` from `weight`. Indices are f32-encoded (see
    /// `embed_lookup_f32`).
    fn embed_lookup_batch_f32(
        &self,
        indices: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        n: usize,
        d: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let idx_buf = Self::unwrap_buffer(indices)?;
        let w_buf = Self::unwrap_buffer(weight)?;
        let dev = self.device(indices.device_ordinal())?;
        let result = crate::kernels::gpu_embed_lookup_batch(idx_buf, w_buf, n, d, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, indices.device_ordinal()))
    }
924
    /// Embedding backward: scatter-adds rows of `grad_output` into a new
    /// `num_embeddings x d` gradient buffer at the rows named by `indices`
    /// (f32-encoded, as in the lookup methods).
    fn scatter_add_rows_f32(
        &self,
        grad_output: &GpuBufferHandle,
        indices: &GpuBufferHandle,
        num_embeddings: usize,
        d: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let go_buf = Self::unwrap_buffer(grad_output)?;
        let idx_buf = Self::unwrap_buffer(indices)?;
        let dev = self.device(grad_output.device_ordinal())?;
        let result = crate::kernels::gpu_scatter_add_rows(go_buf, idx_buf, num_embeddings, d, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad_output.device_ordinal()))
    }
939
940 fn scale_f32(&self, a: &GpuBufferHandle, scalar: f32) -> FerrotorchResult<GpuBufferHandle> {
941 let a_buf = Self::unwrap_buffer(a)?;
942 let dev = self.device(a.device_ordinal())?;
943 let result = crate::kernels::gpu_scale(a_buf, scalar, dev).map_err(Self::map_gpu_err)?;
944 Ok(Self::wrap_buffer(result, a.device_ordinal()))
945 }
946
947 fn relu_backward_f32(
948 &self,
949 grad: &GpuBufferHandle,
950 input: &GpuBufferHandle,
951 ) -> FerrotorchResult<GpuBufferHandle> {
952 let grad_buf = Self::unwrap_buffer(grad)?;
953 let input_buf = Self::unwrap_buffer(input)?;
954 let dev = self.device(grad.device_ordinal())?;
955 let result = crate::kernels::gpu_relu_backward(grad_buf, input_buf, dev)
956 .map_err(Self::map_gpu_err)?;
957 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
958 }
959
960 fn gelu_backward_f32(
961 &self,
962 grad: &GpuBufferHandle,
963 input: &GpuBufferHandle,
964 ) -> FerrotorchResult<GpuBufferHandle> {
965 let grad_buf = Self::unwrap_buffer(grad)?;
966 let input_buf = Self::unwrap_buffer(input)?;
967 let dev = self.device(grad.device_ordinal())?;
968 let result = crate::kernels::gpu_gelu_backward(grad_buf, input_buf, dev)
969 .map_err(Self::map_gpu_err)?;
970 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
971 }
972
973 fn gelu_backward_tanh_f32(
974 &self,
975 grad: &GpuBufferHandle,
976 input: &GpuBufferHandle,
977 ) -> FerrotorchResult<GpuBufferHandle> {
978 let grad_buf = Self::unwrap_buffer(grad)?;
979 let input_buf = Self::unwrap_buffer(input)?;
980 let dev = self.device(grad.device_ordinal())?;
981 let result = crate::kernels::gpu_gelu_backward_tanh(grad_buf, input_buf, dev)
982 .map_err(Self::map_gpu_err)?;
983 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
984 }
985
986 fn gelu_backward_erf_f32(
987 &self,
988 grad: &GpuBufferHandle,
989 input: &GpuBufferHandle,
990 ) -> FerrotorchResult<GpuBufferHandle> {
991 let grad_buf = Self::unwrap_buffer(grad)?;
992 let input_buf = Self::unwrap_buffer(input)?;
993 let dev = self.device(grad.device_ordinal())?;
994 let result = crate::kernels::gpu_gelu_backward_erf(grad_buf, input_buf, dev)
995 .map_err(Self::map_gpu_err)?;
996 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
997 }
998
999 fn cumsum_f32(
1000 &self,
1001 a: &GpuBufferHandle,
1002 outer: usize,
1003 dim_size: usize,
1004 inner: usize,
1005 ) -> FerrotorchResult<GpuBufferHandle> {
1006 let a_buf = Self::unwrap_buffer(a)?;
1007 let dev = self.device(a.device_ordinal())?;
1008 let result = crate::kernels::gpu_cumsum(a_buf, outer, dim_size, inner, dev)
1009 .map_err(Self::map_gpu_err)?;
1010 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1011 }
1012
1013 fn cumprod_f32(
1014 &self,
1015 a: &GpuBufferHandle,
1016 outer: usize,
1017 dim_size: usize,
1018 inner: usize,
1019 ) -> FerrotorchResult<GpuBufferHandle> {
1020 let a_buf = Self::unwrap_buffer(a)?;
1021 let dev = self.device(a.device_ordinal())?;
1022 let result = crate::kernels::gpu_cumprod(a_buf, outer, dim_size, inner, dev)
1023 .map_err(Self::map_gpu_err)?;
1024 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1025 }
1026
    /// Cumulative maximum along the middle axis; returns `(values, indices)`.
    /// Both results come back as f32 buffers — indices are f32-encoded
    /// (forced by `wrap_buffer`); confirm consumers expect that.
    fn cummax_f32(
        &self,
        a: &GpuBufferHandle,
        outer: usize,
        dim_size: usize,
        inner: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let (vals, idxs) = crate::kernels::gpu_cummax(a_buf, outer, dim_size, inner, dev)
            .map_err(Self::map_gpu_err)?;
        let ord = a.device_ordinal();
        Ok((Self::wrap_buffer(vals, ord), Self::wrap_buffer(idxs, ord)))
    }
1041
    /// Cumulative minimum along the middle axis; returns `(values, indices)`.
    /// Indices are f32-encoded, as in `cummax_f32`.
    fn cummin_f32(
        &self,
        a: &GpuBufferHandle,
        outer: usize,
        dim_size: usize,
        inner: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let (vals, idxs) = crate::kernels::gpu_cummin(a_buf, outer, dim_size, inner, dev)
            .map_err(Self::map_gpu_err)?;
        let ord = a.device_ordinal();
        Ok((Self::wrap_buffer(vals, ord), Self::wrap_buffer(idxs, ord)))
    }
1056
1057 fn logcumsumexp_f32(
1058 &self,
1059 a: &GpuBufferHandle,
1060 outer: usize,
1061 dim_size: usize,
1062 inner: usize,
1063 ) -> FerrotorchResult<GpuBufferHandle> {
1064 let a_buf = Self::unwrap_buffer(a)?;
1065 let dev = self.device(a.device_ordinal())?;
1066 let result = crate::kernels::gpu_logcumsumexp(a_buf, outer, dim_size, inner, dev)
1067 .map_err(Self::map_gpu_err)?;
1068 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1069 }
1070
1071 fn clamp_f32(
1072 &self,
1073 a: &GpuBufferHandle,
1074 min_val: f32,
1075 max_val: f32,
1076 ) -> FerrotorchResult<GpuBufferHandle> {
1077 let a_buf = Self::unwrap_buffer(a)?;
1078 let dev = self.device(a.device_ordinal())?;
1079 let result =
1080 crate::kernels::gpu_clamp(a_buf, min_val, max_val, dev).map_err(Self::map_gpu_err)?;
1081 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1082 }
1083
1084 fn silu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1085 let a_buf = Self::unwrap_buffer(a)?;
1086 let dev = self.device(a.device_ordinal())?;
1087 let result = crate::kernels::gpu_silu(a_buf, dev).map_err(Self::map_gpu_err)?;
1088 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1089 }
1090
1091 fn silu_backward_f32(
1092 &self,
1093 grad: &GpuBufferHandle,
1094 input: &GpuBufferHandle,
1095 ) -> FerrotorchResult<GpuBufferHandle> {
1096 let grad_buf = Self::unwrap_buffer(grad)?;
1097 let input_buf = Self::unwrap_buffer(input)?;
1098 let dev = self.device(grad.device_ordinal())?;
1099 let result = crate::kernels::gpu_silu_backward(grad_buf, input_buf, dev)
1100 .map_err(Self::map_gpu_err)?;
1101 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1102 }
1103
1104 fn elu_f32(&self, a: &GpuBufferHandle, alpha: f32) -> FerrotorchResult<GpuBufferHandle> {
1105 let a_buf = Self::unwrap_buffer(a)?;
1106 let dev = self.device(a.device_ordinal())?;
1107 let result = crate::kernels::gpu_elu(a_buf, alpha, dev).map_err(Self::map_gpu_err)?;
1108 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1109 }
1110
1111 fn elu_backward_f32(
1112 &self,
1113 grad: &GpuBufferHandle,
1114 input: &GpuBufferHandle,
1115 alpha: f32,
1116 ) -> FerrotorchResult<GpuBufferHandle> {
1117 let grad_buf = Self::unwrap_buffer(grad)?;
1118 let input_buf = Self::unwrap_buffer(input)?;
1119 let dev = self.device(grad.device_ordinal())?;
1120 let result = crate::kernels::gpu_elu_backward(grad_buf, input_buf, alpha, dev)
1121 .map_err(Self::map_gpu_err)?;
1122 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1123 }
1124
1125 fn mish_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
1126 let a_buf = Self::unwrap_buffer(a)?;
1127 let dev = self.device(a.device_ordinal())?;
1128 let result = crate::kernels::gpu_mish(a_buf, dev).map_err(Self::map_gpu_err)?;
1129 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1130 }
1131
1132 fn mish_backward_f32(
1133 &self,
1134 grad: &GpuBufferHandle,
1135 input: &GpuBufferHandle,
1136 ) -> FerrotorchResult<GpuBufferHandle> {
1137 let grad_buf = Self::unwrap_buffer(grad)?;
1138 let input_buf = Self::unwrap_buffer(input)?;
1139 let dev = self.device(grad.device_ordinal())?;
1140 let result = crate::kernels::gpu_mish_backward(grad_buf, input_buf, dev)
1141 .map_err(Self::map_gpu_err)?;
1142 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1143 }
1144
1145 fn log_softmax_f32(
1146 &self,
1147 a: &GpuBufferHandle,
1148 cols: usize,
1149 ) -> FerrotorchResult<GpuBufferHandle> {
1150 let a_buf = Self::unwrap_buffer(a)?;
1151 let dev = self.device(a.device_ordinal())?;
1152 let result =
1153 crate::kernels::gpu_log_softmax(a_buf, cols, dev).map_err(Self::map_gpu_err)?;
1154 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1155 }
1156
1157 fn log_softmax_backward_f32(
1158 &self,
1159 grad: &GpuBufferHandle,
1160 output: &GpuBufferHandle,
1161 cols: usize,
1162 ) -> FerrotorchResult<GpuBufferHandle> {
1163 let grad_buf = Self::unwrap_buffer(grad)?;
1164 let output_buf = Self::unwrap_buffer(output)?;
1165 let dev = self.device(grad.device_ordinal())?;
1166 let result =
1167 crate::kernels::gpu_log_softmax_backward(grad_buf, output_buf, cols, dev)
1168 .map_err(Self::map_gpu_err)?;
1169 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1170 }
1171
1172 fn index_select_1d_f32(
1173 &self,
1174 input: &GpuBufferHandle,
1175 indices: &GpuBufferHandle,
1176 ) -> FerrotorchResult<GpuBufferHandle> {
1177 let input_buf = Self::unwrap_buffer(input)?;
1178 let idx_buf = Self::unwrap_buffer(indices)?;
1179 let dev = self.device(input.device_ordinal())?;
1180 let result = crate::kernels::gpu_index_select_1d(input_buf, idx_buf, dev)
1181 .map_err(Self::map_gpu_err)?;
1182 Ok(Self::wrap_buffer(result, input.device_ordinal()))
1183 }
1184
1185 fn scatter_add_1d_f32(
1186 &self,
1187 grad_output: &GpuBufferHandle,
1188 indices: &GpuBufferHandle,
1189 input_len: usize,
1190 ) -> FerrotorchResult<GpuBufferHandle> {
1191 let go_buf = Self::unwrap_buffer(grad_output)?;
1192 let idx_buf = Self::unwrap_buffer(indices)?;
1193 let dev = self.device(grad_output.device_ordinal())?;
1194 let result = crate::kernels::gpu_scatter_add_1d(go_buf, idx_buf, input_len, dev)
1195 .map_err(Self::map_gpu_err)?;
1196 Ok(Self::wrap_buffer(result, grad_output.device_ordinal()))
1197 }
1198
1199 fn masked_fill_f32(
1200 &self,
1201 input: &GpuBufferHandle,
1202 mask: &GpuBufferHandle,
1203 value: f32,
1204 ) -> FerrotorchResult<GpuBufferHandle> {
1205 let input_buf = Self::unwrap_buffer(input)?;
1206 let mask_buf = Self::unwrap_buffer(mask)?;
1207 let dev = self.device(input.device_ordinal())?;
1208 let result = crate::kernels::gpu_masked_fill(input_buf, mask_buf, value, dev)
1209 .map_err(Self::map_gpu_err)?;
1210 Ok(Self::wrap_buffer(result, input.device_ordinal()))
1211 }
1212
1213 fn masked_zero_f32(
1214 &self,
1215 grad: &GpuBufferHandle,
1216 mask: &GpuBufferHandle,
1217 ) -> FerrotorchResult<GpuBufferHandle> {
1218 let grad_buf = Self::unwrap_buffer(grad)?;
1219 let mask_buf = Self::unwrap_buffer(mask)?;
1220 let dev = self.device(grad.device_ordinal())?;
1221 let result =
1222 crate::kernels::gpu_masked_zero(grad_buf, mask_buf, dev).map_err(Self::map_gpu_err)?;
1223 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1224 }
1225
1226 fn sigmoid_backward_f32(
1227 &self,
1228 grad: &GpuBufferHandle,
1229 output: &GpuBufferHandle,
1230 ) -> FerrotorchResult<GpuBufferHandle> {
1231 let grad_buf = Self::unwrap_buffer(grad)?;
1232 let output_buf = Self::unwrap_buffer(output)?;
1233 let dev = self.device(grad.device_ordinal())?;
1234 let result = crate::kernels::gpu_sigmoid_backward(grad_buf, output_buf, dev)
1235 .map_err(Self::map_gpu_err)?;
1236 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1237 }
1238
1239 fn tanh_backward_f32(
1240 &self,
1241 grad: &GpuBufferHandle,
1242 output: &GpuBufferHandle,
1243 ) -> FerrotorchResult<GpuBufferHandle> {
1244 let grad_buf = Self::unwrap_buffer(grad)?;
1245 let output_buf = Self::unwrap_buffer(output)?;
1246 let dev = self.device(grad.device_ordinal())?;
1247 let result = crate::kernels::gpu_tanh_backward(grad_buf, output_buf, dev)
1248 .map_err(Self::map_gpu_err)?;
1249 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1250 }
1251
1252 fn softmax_backward_f32(
1253 &self,
1254 grad: &GpuBufferHandle,
1255 output: &GpuBufferHandle,
1256 cols: usize,
1257 ) -> FerrotorchResult<GpuBufferHandle> {
1258 let grad_buf = Self::unwrap_buffer(grad)?;
1259 let output_buf = Self::unwrap_buffer(output)?;
1260 let dev = self.device(grad.device_ordinal())?;
1261 let result = crate::kernels::gpu_softmax_backward(grad_buf, output_buf, cols, dev)
1262 .map_err(Self::map_gpu_err)?;
1263 Ok(Self::wrap_buffer(result, grad.device_ordinal()))
1264 }
1265
1266 fn layernorm_backward_f32(
1267 &self,
1268 input: &GpuBufferHandle,
1269 grad_output: &GpuBufferHandle,
1270 weight: &GpuBufferHandle,
1271 rows: usize,
1272 cols: usize,
1273 eps: f32,
1274 ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
1275 let in_buf = Self::unwrap_buffer(input)?;
1276 let go_buf = Self::unwrap_buffer(grad_output)?;
1277 let w_buf = Self::unwrap_buffer(weight)?;
1278 let dev = self.device(input.device_ordinal())?;
1279 let (gi, gw, gb) =
1280 crate::kernels::gpu_layernorm_backward(in_buf, go_buf, w_buf, rows, cols, eps, dev)
1281 .map_err(Self::map_gpu_err)?;
1282 let ordinal = input.device_ordinal();
1283 Ok((
1284 Self::wrap_buffer(gi, ordinal),
1285 Self::wrap_buffer(gw, ordinal),
1286 Self::wrap_buffer(gb, ordinal),
1287 ))
1288 }
1289
1290 fn sum_axis_f32(
1291 &self,
1292 a: &GpuBufferHandle,
1293 shape: &[usize],
1294 axis: usize,
1295 ) -> FerrotorchResult<GpuBufferHandle> {
1296 let a_buf = Self::unwrap_buffer(a)?;
1297 let dev = self.device(a.device_ordinal())?;
1298 let outer: usize = shape[..axis].iter().product();
1299 let axis_size = shape[axis];
1300 let inner: usize = shape[axis + 1..].iter().product::<usize>().max(1);
1301 let result = crate::kernels::gpu_sum_axis(a_buf, outer, axis_size, inner, dev)
1302 .map_err(Self::map_gpu_err)?;
1303 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1304 }
1305
1306 fn matmul_f16_f32(
1307 &self,
1308 a: &GpuBufferHandle,
1309 b: &GpuBufferHandle,
1310 m: usize,
1311 k: usize,
1312 n: usize,
1313 ) -> FerrotorchResult<GpuBufferHandle> {
1314 let a_buf = Self::unwrap_buffer(a)?;
1315 let b_buf = Self::unwrap_buffer(b)?;
1316 let dev = self.device(a.device_ordinal())?;
1317 let result =
1318 crate::blas::gpu_matmul_f16(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
1319 Ok(Self::wrap_buffer(result, a.device_ordinal()))
1320 }
1321
1322 fn save_rng_state(&self, device: usize) -> FerrotorchResult<GpuRngState> {
1323 let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
1324 FerrotorchError::InvalidArgument {
1325 message: "failed to lock CUDA RNG manager".into(),
1326 }
1327 })?;
1328 let state = mgr.get_rng_state(device);
1329 Ok(GpuRngState {
1330 counter: state.counter,
1331 seed: state.seed,
1332 offset: state.offset,
1333 device,
1334 })
1335 }
1336
1337 fn restore_rng_state(&self, state: GpuRngState) -> FerrotorchResult<()> {
1338 let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
1339 FerrotorchError::InvalidArgument {
1340 message: "failed to lock CUDA RNG manager".into(),
1341 }
1342 })?;
1343 mgr.set_rng_state(
1344 state.device,
1345 crate::rng::PhiloxState {
1346 counter: state.counter,
1347 seed: state.seed,
1348 offset: state.offset,
1349 },
1350 );
1351 Ok(())
1352 }
1353
1354 fn strided_split_f32(
1355 &self,
1356 input: &GpuBufferHandle,
1357 total_along_axis: usize,
1358 split_offset: usize,
1359 split_size: usize,
1360 inner_size: usize,
1361 n: usize,
1362 ) -> FerrotorchResult<GpuBufferHandle> {
1363 let in_buf = Self::unwrap_buffer(input)?;
1364 let dev = self.device(input.device_ordinal())?;
1365 let result = crate::kernels::gpu_strided_split(
1366 in_buf,
1367 total_along_axis,
1368 split_offset,
1369 split_size,
1370 inner_size,
1371 n,
1372 dev,
1373 )
1374 .map_err(Self::map_gpu_err)?;
1375 Ok(Self::wrap_buffer(result, input.device_ordinal()))
1376 }
1377
1378 fn strided_cat_f32(
1379 &self,
1380 input: &GpuBufferHandle,
1381 output: &mut GpuBufferHandle,
1382 total_along_axis: usize,
1383 cat_offset: usize,
1384 part_size: usize,
1385 inner_size: usize,
1386 n: usize,
1387 ) -> FerrotorchResult<()> {
1388 let in_buf = Self::unwrap_buffer(input)?;
1389 let dev = self.device(input.device_ordinal())?;
1390 let out_buf =
1391 output
1392 .downcast_mut::<CudaBuffer<f32>>()
1393 .ok_or(FerrotorchError::InvalidArgument {
1394 message: "strided_cat_f32: output is not CudaBuffer<f32>".into(),
1395 })?;
1396 crate::kernels::gpu_strided_cat(
1397 in_buf,
1398 out_buf,
1399 total_along_axis,
1400 cat_offset,
1401 part_size,
1402 inner_size,
1403 n,
1404 dev,
1405 )
1406 .map_err(Self::map_gpu_err)?;
1407 Ok(())
1408 }
1409}
1410
1411pub fn get_cuda_device() -> FerrotorchResult<Arc<GpuDevice>> {
1422 let backend =
1423 ferrotorch_core::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
1424 let cuda_backend = backend.as_any().downcast_ref::<CudaBackendImpl>().ok_or(
1427 FerrotorchError::InvalidArgument {
1428 message: "registered GPU backend is not CudaBackendImpl".into(),
1429 },
1430 )?;
1431 Ok(Arc::clone(cuda_backend.default_device()?))
1432}
1433
1434pub fn init_cuda_backend() -> FerrotorchResult<()> {
1448 if ferrotorch_core::gpu_dispatch::has_gpu_backend() {
1450 return Ok(());
1451 }
1452 let backend = CudaBackendImpl::new()?;
1453 let _ = ferrotorch_core::gpu_dispatch::register_gpu_backend(Box::new(backend));
1457 Ok(())
1458}
1459
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
    use super::*;
    use ferrotorch_core::gpu_dispatch;

    /// Registers the CUDA backend once; safe to call from every test.
    fn ensure_init() {
        if !gpu_dispatch::has_gpu_backend() {
            init_cuda_backend().expect("init_cuda_backend");
        }
    }

    /// Encodes f32 values as native-endian bytes without `unsafe`.
    fn to_bytes(data: &[f32]) -> Vec<u8> {
        data.iter().flat_map(|v| v.to_ne_bytes()).collect()
    }

    /// Decodes native-endian bytes back into f32 values.
    ///
    /// Fix: the old code cast the returned `Vec<u8>`'s pointer to `*const f32`
    /// and built a slice with `slice::from_raw_parts` — a `Vec<u8>` is only
    /// 1-byte aligned, so that read was potential UB. This copy via
    /// `from_ne_bytes` is alignment-safe.
    fn from_bytes(bytes: &[u8]) -> Vec<f32> {
        bytes
            .chunks_exact(4)
            .map(|chunk| f32::from_ne_bytes(chunk.try_into().expect("4-byte chunk")))
            .collect()
    }

    #[test]
    fn test_init_cuda_backend() {
        ensure_init();
        assert!(gpu_dispatch::has_gpu_backend());
    }

    #[test]
    fn test_gpu_backend_returns_some() {
        ensure_init();
        assert!(gpu_dispatch::gpu_backend().is_some());
    }

    #[test]
    fn test_roundtrip_cpu_gpu_cpu() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");

        let host: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let handle = backend
            .cpu_to_gpu(&to_bytes(&host), 4, 0)
            .expect("cpu_to_gpu");
        assert_eq!(handle.len(), 5);
        assert_eq!(handle.device_ordinal(), 0);

        let back_bytes = backend.gpu_to_cpu(&handle).expect("gpu_to_cpu");
        assert_eq!(from_bytes(&back_bytes), host);
    }

    #[test]
    fn test_add_f32() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");

        let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let b_data: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
        let expected: Vec<f32> = vec![11.0, 22.0, 33.0, 44.0];

        let a_handle = backend
            .cpu_to_gpu(&to_bytes(&a_data), 4, 0)
            .expect("cpu_to_gpu a");
        let b_handle = backend
            .cpu_to_gpu(&to_bytes(&b_data), 4, 0)
            .expect("cpu_to_gpu b");

        let result = backend.add_f32(&a_handle, &b_handle).expect("add_f32");
        assert_eq!(result.len(), 4);

        let result_f32 = from_bytes(&backend.gpu_to_cpu(&result).expect("gpu_to_cpu"));
        for (i, (&got, &exp)) in result_f32.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-6,
                "element {i}: got {got}, expected {exp}",
            );
        }
    }

    #[test]
    fn test_matmul_f32() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");

        // a: 2x3, b: 3x2, expected: 2x2 (row-major).
        let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let expected: Vec<f32> = vec![58.0, 64.0, 139.0, 154.0];

        let a_handle = backend
            .cpu_to_gpu(&to_bytes(&a_data), 4, 0)
            .expect("cpu_to_gpu a");
        let b_handle = backend
            .cpu_to_gpu(&to_bytes(&b_data), 4, 0)
            .expect("cpu_to_gpu b");

        let result = backend
            .matmul_f32(&a_handle, &b_handle, 2, 3, 2)
            .expect("matmul_f32");
        assert_eq!(result.len(), 4);

        let result_f32 = from_bytes(&backend.gpu_to_cpu(&result).expect("gpu_to_cpu"));
        for (i, (&got, &exp)) in result_f32.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-3,
                "element {i}: got {got}, expected {exp}",
            );
        }
    }
}