use std::sync::Arc;

use ferrotorch_core::error::{FerrotorchError, FerrotorchResult};
use ferrotorch_core::gpu_dispatch::{GpuBackend, GpuBufferHandle, GpuRngState};

use crate::buffer::CudaBuffer;
use crate::device::GpuDevice;

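/// CUDA implementation of the `GpuBackend` trait.
///
/// Holds one `GpuDevice` handle per initialized device; `new` currently
/// opens only device ordinal 0.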
pub struct CudaBackendImpl {
    devices: Vec<Arc<GpuDevice>>,
}

impl CudaBackendImpl {
    pub fn new() -> FerrotorchResult<Self> {
        let device = Arc::new(
            GpuDevice::new(0).map_err(|e| FerrotorchError::InvalidArgument {
                message: format!("CUDA init failed: {e}"),
            })?,
        );
        Ok(Self {
            devices: vec![device],
        })
    }

    pub fn default_device(&self) -> FerrotorchResult<&Arc<GpuDevice>> {
        self.device(0)
    }

    fn device(&self, ordinal: usize) -> FerrotorchResult<&Arc<GpuDevice>> {
        self.devices
            .get(ordinal)
            .ok_or(FerrotorchError::InvalidArgument {
                message: format!("CUDA device {ordinal} not available"),
            })
    }

    // Type-erase a concrete CUDA buffer into a backend-agnostic handle.
    fn wrap_buffer(buf: CudaBuffer<f32>, ordinal: usize) -> GpuBufferHandle {
        let len = buf.len();
        GpuBufferHandle::new(Box::new(buf), ordinal, len)
    }

    fn wrap_buffer_f64(buf: CudaBuffer<f64>, ordinal: usize) -> GpuBufferHandle {
        let len = buf.len();
        GpuBufferHandle::new(Box::new(buf), ordinal, len)
    }

    fn unwrap_buffer(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<f32>> {
        handle
            .downcast_ref::<CudaBuffer<f32>>()
            .ok_or(FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<f32>".into(),
            })
    }

    fn unwrap_buffer_mut(handle: &mut GpuBufferHandle) -> FerrotorchResult<&mut CudaBuffer<f32>> {
        handle
            .downcast_mut::<CudaBuffer<f32>>()
            .ok_or(FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<f32>".into(),
            })
    }

    fn unwrap_buffer_f64(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<f64>> {
        handle
            .downcast_ref::<CudaBuffer<f64>>()
            .ok_or(FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<f64>".into(),
            })
    }

    // Flatten all GPU errors into InvalidArgument, preserving the original
    // error text in the message.
    fn map_gpu_err(e: crate::error::GpuError) -> FerrotorchError {
        FerrotorchError::InvalidArgument {
            message: format!("{e}"),
        }
    }
}

impl GpuBackend for CudaBackendImpl {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn cpu_to_gpu(
        &self,
        data: &[u8],
        elem_size: usize,
        device: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let dev = self.device(device)?;
        match elem_size {
            4 => {
                // Decode the bytes into an owned, properly aligned Vec<f32>.
                // Casting `data.as_ptr()` to `*const f32` would be unsound:
                // a `&[u8]` gives no 4-byte alignment guarantee.
                let f32_data: Vec<f32> = data
                    .chunks_exact(4)
                    .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
                    .collect();
                let buf =
                    crate::transfer::cpu_to_gpu(&f32_data, dev).map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer(buf, device))
            }
            8 => {
                let f64_data: Vec<f64> = data
                    .chunks_exact(8)
                    .map(|c| f64::from_ne_bytes(c.try_into().expect("chunk of 8 bytes")))
                    .collect();
                let buf =
                    crate::transfer::cpu_to_gpu(&f64_data, dev).map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer_f64(buf, device))
            }
            other => Err(FerrotorchError::InvalidArgument {
                message: format!("cpu_to_gpu: unsupported elem_size {other} (expected 4 or 8)"),
            }),
        }
    }

    fn cpu_to_gpu_pinned(
        &self,
        data: &[u8],
        elem_size: usize,
        device: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let dev = self.device(device)?;
        match elem_size {
            4 => {
                // As in `cpu_to_gpu`, decode into an owned aligned vector
                // instead of reinterpreting an unaligned byte pointer.
                let f32_data: Vec<f32> = data
                    .chunks_exact(4)
                    .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
                    .collect();
                let buf = crate::transfer::cpu_to_gpu_pinned(&f32_data, dev)
                    .map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer(buf, device))
            }
            8 => {
                let f64_data: Vec<f64> = data
                    .chunks_exact(8)
                    .map(|c| f64::from_ne_bytes(c.try_into().expect("chunk of 8 bytes")))
                    .collect();
                let buf = crate::transfer::cpu_to_gpu_pinned(&f64_data, dev)
                    .map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer_f64(buf, device))
            }
            other => Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "cpu_to_gpu_pinned: unsupported elem_size {other} (expected 4 or 8)"
                ),
            }),
        }
    }

    fn gpu_to_cpu(&self, handle: &GpuBufferHandle) -> FerrotorchResult<Vec<u8>> {
        let dev = self.device(handle.device_ordinal())?;

        if let Ok(buf) = Self::unwrap_buffer(handle) {
            let f32_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;
            // Copy into a fresh byte vector. Reinterpreting the f32 allocation
            // as a Vec<u8> via `Vec::from_raw_parts` would later deallocate it
            // with the wrong alignment, which is undefined behavior.
            let mut bytes = Vec::with_capacity(f32_data.len() * 4);
            for v in &f32_data {
                bytes.extend_from_slice(&v.to_ne_bytes());
            }
            Ok(bytes)
        } else if let Ok(buf) = Self::unwrap_buffer_f64(handle) {
            let f64_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;
            let mut bytes = Vec::with_capacity(f64_data.len() * 8);
            for v in &f64_data {
                bytes.extend_from_slice(&v.to_ne_bytes());
            }
            Ok(bytes)
        } else {
            Err(FerrotorchError::InvalidArgument {
                message: "gpu_to_cpu: handle is neither CudaBuffer<f32> nor CudaBuffer<f64>".into(),
            })
        }
    }

    fn clone_buffer(&self, handle: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        // Clone via a host round trip; a device-to-device copy would avoid
        // the PCIe traffic but is not implemented here.
        let bytes = self.gpu_to_cpu(handle)?;
        let elem_size = if handle.downcast_ref::<CudaBuffer<f64>>().is_some() {
            8
        } else {
            4
        };
        self.cpu_to_gpu(&bytes, elem_size, handle.device_ordinal())
    }

    fn alloc_zeros(
        &self,
        len: usize,
        elem_size: usize,
        device: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let dev = self.device(device)?;
        match elem_size {
            4 => {
                let buf = crate::transfer::alloc_zeros_f32(len, dev).map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer(buf, device))
            }
            8 => {
                let buf = crate::transfer::alloc_zeros_f64(len, dev).map_err(Self::map_gpu_err)?;
                Ok(Self::wrap_buffer_f64(buf, device))
            }
            other => Err(FerrotorchError::InvalidArgument {
                message: format!("alloc_zeros: unsupported elem_size {other} (expected 4 or 8)"),
            }),
        }
    }

    fn add_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_add(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn sub_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_sub(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn mul_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_mul(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn neg_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_neg(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn relu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_relu(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn div_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_div(a_buf, b_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn exp_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_exp(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn log_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_log(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn sqrt_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_sqrt(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn pow_f32(&self, a: &GpuBufferHandle, exponent: f32) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_pow(a_buf, exponent, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn abs_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_abs(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn sigmoid_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_sigmoid(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn tanh_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_tanh(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

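    /// Fused Adam step, updating `param` in place.
    ///
    /// For reference, a sketch of the update the kernel is expected to apply
    /// (the authoritative math lives in `crate::kernels::gpu_fused_adam`;
    /// `bc1`/`bc2` are the bias-correction terms `1 - beta1^t` / `1 - beta2^t`):
    ///
    /// ```text
    /// m = beta1 * m + (1 - beta1) * g
    /// v = beta2 * v + (1 - beta2) * g * g
    /// p -= lr * (m / bc1) / (sqrt(v / bc2) + eps)   // plus the weight_decay term
    /// ```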
    #[allow(clippy::too_many_arguments)]
    fn fused_adam_f32(
        &self,
        param: &mut GpuBufferHandle,
        grad: &GpuBufferHandle,
        exp_avg: &mut GpuBufferHandle,
        exp_avg_sq: &mut GpuBufferHandle,
        beta1: f32,
        beta2: f32,
        lr: f32,
        eps: f32,
        bc1: f32,
        bc2: f32,
        weight_decay: f32,
    ) -> FerrotorchResult<()> {
        let ordinal = param.device_ordinal();
        let dev = self.device(ordinal)?;
        let p_buf = Self::unwrap_buffer_mut(param)?;
        let g_buf = Self::unwrap_buffer(grad)?;
        let m_buf = Self::unwrap_buffer_mut(exp_avg)?;
        let v_buf = Self::unwrap_buffer_mut(exp_avg_sq)?;
        crate::kernels::gpu_fused_adam(
            p_buf, g_buf, m_buf, v_buf, beta1, beta2, lr, eps, bc1, bc2, weight_decay, dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok(())
    }

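    /// 2D max pooling over an NCHW input.
    ///
    /// Assuming the usual floor-division pooling arithmetic, the returned
    /// shape should be `[batch, channels, h_out, w_out]` with
    /// `h_out = (h_in + 2 * ph - kh) / sh + 1` and
    /// `w_out = (w_in + 2 * pw - kw) / sw + 1`; the kernel computes and
    /// returns the actual shape.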
    #[allow(clippy::too_many_arguments)]
    fn maxpool2d_f32(
        &self,
        input: &GpuBufferHandle,
        batch: usize,
        channels: usize,
        h_in: usize,
        w_in: usize,
        kh: usize,
        kw: usize,
        sh: usize,
        sw: usize,
        ph: usize,
        pw: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
        let buf = Self::unwrap_buffer(input)?;
        let dev = self.device(input.device_ordinal())?;
        let (out, shape) = crate::kernels::gpu_maxpool2d(
            buf, batch, channels, h_in, w_in, kh, kw, sh, sw, ph, pw, dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok((Self::wrap_buffer(out, input.device_ordinal()), shape))
    }

    #[allow(clippy::too_many_arguments)]
    fn avgpool2d_f32(
        &self,
        input: &GpuBufferHandle,
        batch: usize,
        channels: usize,
        h_in: usize,
        w_in: usize,
        kh: usize,
        kw: usize,
        sh: usize,
        sw: usize,
        ph: usize,
        pw: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
        let buf = Self::unwrap_buffer(input)?;
        let dev = self.device(input.device_ordinal())?;
        let (out, shape) = crate::kernels::gpu_avgpool2d(
            buf, batch, channels, h_in, w_in, kh, kw, sh, sw, ph, pw, dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok((Self::wrap_buffer(out, input.device_ordinal()), shape))
    }

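    /// 2D convolution over an NCHW input with an OIHW weight.
    ///
    /// Under standard convolution arithmetic the output shape is
    /// `[batch, out_channels, h_out, w_out]` with
    /// `h_out = (h_in + 2 * padding.0 - kh) / stride.0 + 1` (and likewise for
    /// width); the exact shape is computed and returned by
    /// `crate::conv::gpu_conv2d_f32`.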
    #[allow(clippy::too_many_arguments)]
    fn conv2d_f32(
        &self,
        input: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        bias: Option<&GpuBufferHandle>,
        input_shape: [usize; 4],
        weight_shape: [usize; 4],
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
        let input_buf = Self::unwrap_buffer(input)?;
        let weight_buf = Self::unwrap_buffer(weight)?;
        let bias_buf = match bias {
            Some(b) => Some(Self::unwrap_buffer(b)?),
            None => None,
        };
        let dev = self.device(input.device_ordinal())?;
        let (out_buf, out_shape) = crate::conv::gpu_conv2d_f32(
            input_buf,
            weight_buf,
            bias_buf,
            input_shape,
            weight_shape,
            stride,
            padding,
            dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok((Self::wrap_buffer(out_buf, input.device_ordinal()), out_shape))
    }

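    /// Fused GRU cell forward step.
    ///
    /// `input_gates` and `hidden_gates` are presumably the precomputed
    /// `x W_ih^T` and `h W_hh^T` projections (three stacked gates each).
    /// For reference, the standard GRU cell this is expected to fuse
    /// (the kernel in `crate::kernels` is the source of truth):
    ///
    /// ```text
    /// r  = sigmoid(i_r + b_ir + h_r + b_hr)
    /// z  = sigmoid(i_z + b_iz + h_z + b_hz)
    /// n  = tanh(i_n + b_in + r * (h_n + b_hn))
    /// hy = (1 - z) * n + z * hx
    /// ```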
    fn fused_gru_cell_f32(
        &self,
        input_gates: &GpuBufferHandle,
        hidden_gates: &GpuBufferHandle,
        bias_ih: &GpuBufferHandle,
        bias_hh: &GpuBufferHandle,
        hx: &GpuBufferHandle,
        hidden_size: usize,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
        let ig = Self::unwrap_buffer(input_gates)?;
        let hg = Self::unwrap_buffer(hidden_gates)?;
        let bih = Self::unwrap_buffer(bias_ih)?;
        let bhh = Self::unwrap_buffer(bias_hh)?;
        let hx_buf = Self::unwrap_buffer(hx)?;
        let dev = self.device(input_gates.device_ordinal())?;
        let (hy, ws) =
            crate::kernels::gpu_fused_gru_forward(ig, hg, bih, bhh, hx_buf, hidden_size, dev)
                .map_err(Self::map_gpu_err)?;
        let ord = input_gates.device_ordinal();
        Ok((Self::wrap_buffer(hy, ord), Self::wrap_buffer(ws, ord)))
    }

    fn synchronize(&self, device: usize) -> FerrotorchResult<()> {
        let dev = self.device(device)?;
        dev.stream()
            .synchronize()
            .map_err(|e| FerrotorchError::InvalidArgument {
                message: format!("CUDA synchronize failed: {e}"),
            })?;
        Ok(())
    }

    fn stream_count(&self, device: usize) -> usize {
        crate::stream::StreamPool::pool_size(device)
    }

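    /// Row-major matrix multiply: `a` is `(m, k)`, `b` is `(k, n)`, and the
    /// result is `(m, n)` (see `test_matmul_f32` below for a worked example).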
    fn matmul_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        m: usize,
        k: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::blas::gpu_matmul_f32(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn sum_f32(&self, a: &GpuBufferHandle, _len: usize) -> FerrotorchResult<GpuBufferHandle> {
        // `_len` is implied by the buffer itself; the kernel reduces the
        // whole buffer to a single element.
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_reduce_sum(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn matmul_f64(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        m: usize,
        k: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer_f64(a)?;
        let b_buf = Self::unwrap_buffer_f64(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::blas::gpu_matmul_f64(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
    }

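    /// Element-wise add with shape broadcasting.
    ///
    /// For example, `a_shape = [2, 3]` with `b_shape = [1, 3]` should
    /// broadcast `b` across rows to produce `out_shape = [2, 3]`; the exact
    /// broadcasting rules are implemented by the kernel.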
    fn broadcast_add_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        a_shape: &[usize],
        b_shape: &[usize],
        out_shape: &[usize],
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_broadcast_add(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
                .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn broadcast_sub_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        a_shape: &[usize],
        b_shape: &[usize],
        out_shape: &[usize],
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_broadcast_sub(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
                .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn broadcast_mul_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        a_shape: &[usize],
        b_shape: &[usize],
        out_shape: &[usize],
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_broadcast_mul(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
                .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn broadcast_div_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        a_shape: &[usize],
        b_shape: &[usize],
        out_shape: &[usize],
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_broadcast_div(a_buf, b_buf, a_shape, b_shape, out_shape, dev)
                .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

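    /// Row-wise softmax over a `(rows, cols)` matrix:
    /// `out[i][j] = exp(a[i][j]) / sum_k exp(a[i][k])`, applied independently
    /// to each of the `rows` rows.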
    fn softmax_f32(
        &self,
        a: &GpuBufferHandle,
        rows: usize,
        cols: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_softmax(a_buf, rows, cols, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

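    /// Dropout with a caller-supplied RNG `seed`.
    ///
    /// `threshold` and `scale` are assumed to be precomputed by the caller
    /// from the keep probability `p` (roughly `threshold ≈ (1 - p) * u32::MAX`
    /// and `scale = 1 / p`); the kernel presumably compares a per-element
    /// random `u32` against `threshold` and rescales survivors by `scale`.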
    fn dropout_f32(
        &self,
        a: &GpuBufferHandle,
        threshold: u32,
        scale: f32,
        seed: u32,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_dropout(a_buf, threshold, scale, seed, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn dropout_philox_f32(
        &self,
        a: &GpuBufferHandle,
        threshold: u32,
        scale: f32,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuRngState)> {
        let device_ordinal = a.device_ordinal();
        let n = a.len();

        // Snapshot the Philox state and advance it past this call. Each
        // Philox4x32 counter tick yields four u32 samples, hence div_ceil(4).
        // The lock is released before the kernel launch.
        let rng_state = {
            let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
                FerrotorchError::InvalidArgument {
                    message: "failed to lock CUDA RNG manager".into(),
                }
            })?;
            let philox_gen = mgr.generator(device_ordinal);
            let state = philox_gen.get_state();
            let counters_needed = n.div_ceil(4);
            philox_gen.advance(counters_needed as u64);
            state
        };

        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(device_ordinal)?;

        // The simple dropout kernel takes a 32-bit seed, so the Philox state
        // is collapsed into one by XOR-ing counter and seed.
        let derived_seed = (rng_state.counter ^ rng_state.seed) as u32;
        let result = crate::kernels::gpu_dropout(a_buf, threshold, scale, derived_seed, dev)
            .map_err(Self::map_gpu_err)?;

        let gpu_rng_state = GpuRngState {
            counter: rng_state.counter,
            seed: rng_state.seed,
            offset: rng_state.offset,
            device: device_ordinal,
        };

        Ok((Self::wrap_buffer(result, device_ordinal), gpu_rng_state))
    }

    fn transpose_2d_f32(
        &self,
        a: &GpuBufferHandle,
        m: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::kernels::gpu_transpose_2d(a_buf, m, n, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn permute_0213_f32(
        &self,
        a: &GpuBufferHandle,
        d0: usize,
        d1: usize,
        d2: usize,
        d3: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_permute_0213(a_buf, d0, d1, d2, d3, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn bmm_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        batch: usize,
        m: usize,
        k: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::blas::gpu_bmm_f32(a_buf, b_buf, batch, m, k, n, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn bmm_f16_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        batch: usize,
        m: usize,
        k: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::blas::gpu_bmm_f16(a_buf, b_buf, batch, m, k, n, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn gelu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_gelu(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn gelu_tanh_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_gelu_tanh(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn gelu_erf_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_gelu_erf(a_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn layernorm_f32(
        &self,
        input: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        bias: &GpuBufferHandle,
        rows: usize,
        cols: usize,
        eps: f32,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let in_buf = Self::unwrap_buffer(input)?;
        let w_buf = Self::unwrap_buffer(weight)?;
        let b_buf = Self::unwrap_buffer(bias)?;
        let dev = self.device(input.device_ordinal())?;
        let result = crate::kernels::gpu_layernorm(in_buf, w_buf, b_buf, rows, cols, eps, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, input.device_ordinal()))
    }

    fn slice_write_f32(
        &self,
        src: &GpuBufferHandle,
        dst: &mut GpuBufferHandle,
        n_batch: usize,
        d: usize,
        max_len: usize,
        pos: usize,
    ) -> FerrotorchResult<()> {
        let src_buf = Self::unwrap_buffer(src)?;
        let dst_buf = dst
            .downcast_mut::<CudaBuffer<f32>>()
            .ok_or(FerrotorchError::InvalidArgument {
                message: "slice_write_f32: dst is not CudaBuffer<f32>".into(),
            })?;
        let dev = self.device(src.device_ordinal())?;
        crate::kernels::gpu_slice_write(src_buf, dst_buf, n_batch, d, max_len, pos, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(())
    }

    fn slice_read_f32(
        &self,
        src: &GpuBufferHandle,
        n_batch: usize,
        d: usize,
        len: usize,
        max_len: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let src_buf = Self::unwrap_buffer(src)?;
        let dev = self.device(src.device_ordinal())?;
        let result = crate::kernels::gpu_slice_read(src_buf, n_batch, d, len, max_len, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, src.device_ordinal()))
    }

    fn embed_lookup_f32(
        &self,
        idx: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        d: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let idx_buf = Self::unwrap_buffer(idx)?;
        let w_buf = Self::unwrap_buffer(weight)?;
        let dev = self.device(idx.device_ordinal())?;
        let result =
            crate::kernels::gpu_embed_lookup(idx_buf, w_buf, d, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, idx.device_ordinal()))
    }

    fn embed_lookup_batch_f32(
        &self,
        indices: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        n: usize,
        d: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let idx_buf = Self::unwrap_buffer(indices)?;
        let w_buf = Self::unwrap_buffer(weight)?;
        let dev = self.device(indices.device_ordinal())?;
        let result = crate::kernels::gpu_embed_lookup_batch(idx_buf, w_buf, n, d, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, indices.device_ordinal()))
    }

    fn scatter_add_rows_f32(
        &self,
        grad_output: &GpuBufferHandle,
        indices: &GpuBufferHandle,
        num_embeddings: usize,
        d: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let go_buf = Self::unwrap_buffer(grad_output)?;
        let idx_buf = Self::unwrap_buffer(indices)?;
        let dev = self.device(grad_output.device_ordinal())?;
        let result = crate::kernels::gpu_scatter_add_rows(go_buf, idx_buf, num_embeddings, d, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad_output.device_ordinal()))
    }

    fn scale_f32(&self, a: &GpuBufferHandle, scalar: f32) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let result = crate::kernels::gpu_scale(a_buf, scalar, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn relu_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        input: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let input_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_relu_backward(grad_buf, input_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn gelu_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        input: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let input_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_gelu_backward(grad_buf, input_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn gelu_backward_tanh_f32(
        &self,
        grad: &GpuBufferHandle,
        input: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let input_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_gelu_backward_tanh(grad_buf, input_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn gelu_backward_erf_f32(
        &self,
        grad: &GpuBufferHandle,
        input: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let input_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_gelu_backward_erf(grad_buf, input_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn index_select_1d_f32(
        &self,
        input: &GpuBufferHandle,
        indices: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let input_buf = Self::unwrap_buffer(input)?;
        let idx_buf = Self::unwrap_buffer(indices)?;
        let dev = self.device(input.device_ordinal())?;
        let result = crate::kernels::gpu_index_select_1d(input_buf, idx_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, input.device_ordinal()))
    }

    fn scatter_add_1d_f32(
        &self,
        grad_output: &GpuBufferHandle,
        indices: &GpuBufferHandle,
        input_len: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let go_buf = Self::unwrap_buffer(grad_output)?;
        let idx_buf = Self::unwrap_buffer(indices)?;
        let dev = self.device(grad_output.device_ordinal())?;
        let result = crate::kernels::gpu_scatter_add_1d(go_buf, idx_buf, input_len, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad_output.device_ordinal()))
    }

    fn masked_fill_f32(
        &self,
        input: &GpuBufferHandle,
        mask: &GpuBufferHandle,
        value: f32,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let input_buf = Self::unwrap_buffer(input)?;
        let mask_buf = Self::unwrap_buffer(mask)?;
        let dev = self.device(input.device_ordinal())?;
        let result = crate::kernels::gpu_masked_fill(input_buf, mask_buf, value, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, input.device_ordinal()))
    }

    fn masked_zero_f32(
        &self,
        grad: &GpuBufferHandle,
        mask: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let mask_buf = Self::unwrap_buffer(mask)?;
        let dev = self.device(grad.device_ordinal())?;
        let result =
            crate::kernels::gpu_masked_zero(grad_buf, mask_buf, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn sigmoid_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        output: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let output_buf = Self::unwrap_buffer(output)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_sigmoid_backward(grad_buf, output_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn tanh_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        output: &GpuBufferHandle,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let output_buf = Self::unwrap_buffer(output)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_tanh_backward(grad_buf, output_buf, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn softmax_backward_f32(
        &self,
        grad: &GpuBufferHandle,
        output: &GpuBufferHandle,
        cols: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let grad_buf = Self::unwrap_buffer(grad)?;
        let output_buf = Self::unwrap_buffer(output)?;
        let dev = self.device(grad.device_ordinal())?;
        let result = crate::kernels::gpu_softmax_backward(grad_buf, output_buf, cols, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, grad.device_ordinal()))
    }

    fn layernorm_backward_f32(
        &self,
        input: &GpuBufferHandle,
        grad_output: &GpuBufferHandle,
        weight: &GpuBufferHandle,
        rows: usize,
        cols: usize,
        eps: f32,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
        let in_buf = Self::unwrap_buffer(input)?;
        let go_buf = Self::unwrap_buffer(grad_output)?;
        let w_buf = Self::unwrap_buffer(weight)?;
        let dev = self.device(input.device_ordinal())?;
        let (gi, gw, gb) =
            crate::kernels::gpu_layernorm_backward(in_buf, go_buf, w_buf, rows, cols, eps, dev)
                .map_err(Self::map_gpu_err)?;
        let ordinal = input.device_ordinal();
        Ok((
            Self::wrap_buffer(gi, ordinal),
            Self::wrap_buffer(gw, ordinal),
            Self::wrap_buffer(gb, ordinal),
        ))
    }

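    /// Reduce-sum along one axis of an arbitrary-rank tensor.
    ///
    /// The shape is factored into `(outer, axis_size, inner)` around `axis`.
    /// For example, `shape = [2, 3, 4]` with `axis = 1` gives `outer = 2`,
    /// `axis_size = 3`, `inner = 4`, producing a `[2, 4]` result; `.max(1)`
    /// keeps `inner` nonzero when `axis` is the last dimension.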
    fn sum_axis_f32(
        &self,
        a: &GpuBufferHandle,
        shape: &[usize],
        axis: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let dev = self.device(a.device_ordinal())?;
        let outer: usize = shape[..axis].iter().product();
        let axis_size = shape[axis];
        let inner: usize = shape[axis + 1..].iter().product::<usize>().max(1);
        let result = crate::kernels::gpu_sum_axis(a_buf, outer, axis_size, inner, dev)
            .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn matmul_f16_f32(
        &self,
        a: &GpuBufferHandle,
        b: &GpuBufferHandle,
        m: usize,
        k: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let a_buf = Self::unwrap_buffer(a)?;
        let b_buf = Self::unwrap_buffer(b)?;
        let dev = self.device(a.device_ordinal())?;
        let result =
            crate::blas::gpu_matmul_f16(a_buf, b_buf, m, k, n, dev).map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, a.device_ordinal()))
    }

    fn save_rng_state(&self, device: usize) -> FerrotorchResult<GpuRngState> {
        let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
            FerrotorchError::InvalidArgument {
                message: "failed to lock CUDA RNG manager".into(),
            }
        })?;
        let state = mgr.get_rng_state(device);
        Ok(GpuRngState {
            counter: state.counter,
            seed: state.seed,
            offset: state.offset,
            device,
        })
    }

    fn restore_rng_state(&self, state: GpuRngState) -> FerrotorchResult<()> {
        let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
            FerrotorchError::InvalidArgument {
                message: "failed to lock CUDA RNG manager".into(),
            }
        })?;
        mgr.set_rng_state(
            state.device,
            crate::rng::PhiloxState {
                counter: state.counter,
                seed: state.seed,
                offset: state.offset,
            },
        );
        Ok(())
    }

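    /// Strided split: extracts elements
    /// `[split_offset .. split_offset + split_size)` along an axis of logical
    /// extent `total_along_axis`, with `inner_size` contiguous elements per
    /// axis step (`n` is presumably the total number of elements to copy;
    /// the kernel defines the exact contract).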
    fn strided_split_f32(
        &self,
        input: &GpuBufferHandle,
        total_along_axis: usize,
        split_offset: usize,
        split_size: usize,
        inner_size: usize,
        n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        let in_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(input.device_ordinal())?;
        let result = crate::kernels::gpu_strided_split(
            in_buf,
            total_along_axis,
            split_offset,
            split_size,
            inner_size,
            n,
            dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok(Self::wrap_buffer(result, input.device_ordinal()))
    }

    fn strided_cat_f32(
        &self,
        input: &GpuBufferHandle,
        output: &mut GpuBufferHandle,
        total_along_axis: usize,
        cat_offset: usize,
        part_size: usize,
        inner_size: usize,
        n: usize,
    ) -> FerrotorchResult<()> {
        let in_buf = Self::unwrap_buffer(input)?;
        let dev = self.device(input.device_ordinal())?;
        let out_buf = output
            .downcast_mut::<CudaBuffer<f32>>()
            .ok_or(FerrotorchError::InvalidArgument {
                message: "strided_cat_f32: output is not CudaBuffer<f32>".into(),
            })?;
        crate::kernels::gpu_strided_cat(
            in_buf,
            out_buf,
            total_along_axis,
            cat_offset,
            part_size,
            inner_size,
            n,
            dev,
        )
        .map_err(Self::map_gpu_err)?;
        Ok(())
    }
}

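/// Returns the default `GpuDevice` of the globally registered CUDA backend.
///
/// Fails with `DeviceUnavailable` if no GPU backend is registered, or with
/// `InvalidArgument` if the registered backend is not a `CudaBackendImpl`.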
pub fn get_cuda_device() -> FerrotorchResult<Arc<GpuDevice>> {
    let backend =
        ferrotorch_core::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let cuda_backend = backend.as_any().downcast_ref::<CudaBackendImpl>().ok_or(
        FerrotorchError::InvalidArgument {
            message: "registered GPU backend is not CudaBackendImpl".into(),
        },
    )?;
    Ok(Arc::clone(cuda_backend.default_device()?))
}

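/// Registers the CUDA backend with `ferrotorch_core` unless a GPU backend is
/// already registered; safe to call more than once.
///
/// A minimal usage sketch (the crate path `ferrotorch_cuda` is assumed here):
///
/// ```text
/// ferrotorch_cuda::init_cuda_backend().expect("CUDA device available");
/// assert!(ferrotorch_core::gpu_dispatch::has_gpu_backend());
/// ```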
pub fn init_cuda_backend() -> FerrotorchResult<()> {
    if ferrotorch_core::gpu_dispatch::has_gpu_backend() {
        return Ok(());
    }
    let backend = CudaBackendImpl::new()?;
    // Ignore the registration result: losing a registration race to another
    // thread still leaves a usable backend in place.
    let _ = ferrotorch_core::gpu_dispatch::register_gpu_backend(Box::new(backend));
    Ok(())
}

#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
    use super::*;
    use ferrotorch_core::gpu_dispatch;

    fn ensure_init() {
        if !gpu_dispatch::has_gpu_backend() {
            init_cuda_backend().expect("init_cuda_backend");
        }
    }

    #[test]
    fn test_init_cuda_backend() {
        ensure_init();
        assert!(gpu_dispatch::has_gpu_backend());
    }

    #[test]
    fn test_gpu_backend_returns_some() {
        ensure_init();
        assert!(gpu_dispatch::gpu_backend().is_some());
    }

    #[test]
    fn test_roundtrip_cpu_gpu_cpu() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");

        let host: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        // SAFETY: viewing a `Vec<f32>` as bytes is always valid; alignment
        // only shrinks (4 -> 1) and every byte pattern is a valid u8.
        let bytes: &[u8] = unsafe {
            std::slice::from_raw_parts(
                host.as_ptr() as *const u8,
                host.len() * std::mem::size_of::<f32>(),
            )
        };

        let handle = backend.cpu_to_gpu(bytes, 4, 0).expect("cpu_to_gpu");
        assert_eq!(handle.len(), 5);
        assert_eq!(handle.device_ordinal(), 0);

        let back_bytes = backend.gpu_to_cpu(&handle).expect("gpu_to_cpu");
        // Decode safely instead of casting the byte pointer to `*const f32`,
        // which would assume an alignment that `Vec<u8>` does not guarantee.
        let back: Vec<f32> = back_bytes
            .chunks_exact(4)
            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect();
        assert_eq!(back, host);
    }

    #[test]
    fn test_add_f32() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");

        let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let b_data: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
        let expected: Vec<f32> = vec![11.0, 22.0, 33.0, 44.0];

        // SAFETY: f32 -> u8 byte views are always valid (see the roundtrip test).
        let a_bytes: &[u8] =
            unsafe { std::slice::from_raw_parts(a_data.as_ptr() as *const u8, a_data.len() * 4) };
        let b_bytes: &[u8] =
            unsafe { std::slice::from_raw_parts(b_data.as_ptr() as *const u8, b_data.len() * 4) };

        let a_handle = backend.cpu_to_gpu(a_bytes, 4, 0).expect("cpu_to_gpu a");
        let b_handle = backend.cpu_to_gpu(b_bytes, 4, 0).expect("cpu_to_gpu b");

        let result = backend.add_f32(&a_handle, &b_handle).expect("add_f32");
        assert_eq!(result.len(), 4);

        let result_bytes = backend.gpu_to_cpu(&result).expect("gpu_to_cpu");
        let result_f32: Vec<f32> = result_bytes
            .chunks_exact(4)
            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect();

        for (i, (&got, &exp)) in result_f32.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-6,
                "element {i}: got {got}, expected {exp}",
            );
        }
    }

    #[test]
    fn test_matmul_f32() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");

        // a: 2x3, b: 3x2 (both row-major); expected = a @ b, a 2x2 matrix.
        let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let expected: Vec<f32> = vec![58.0, 64.0, 139.0, 154.0];

        // SAFETY: f32 -> u8 byte views are always valid (see the roundtrip test).
        let a_bytes: &[u8] =
            unsafe { std::slice::from_raw_parts(a_data.as_ptr() as *const u8, a_data.len() * 4) };
        let b_bytes: &[u8] =
            unsafe { std::slice::from_raw_parts(b_data.as_ptr() as *const u8, b_data.len() * 4) };

        let a_handle = backend.cpu_to_gpu(a_bytes, 4, 0).expect("cpu_to_gpu a");
        let b_handle = backend.cpu_to_gpu(b_bytes, 4, 0).expect("cpu_to_gpu b");

        let result = backend
            .matmul_f32(&a_handle, &b_handle, 2, 3, 2)
            .expect("matmul_f32");
        assert_eq!(result.len(), 4);

        let result_bytes = backend.gpu_to_cpu(&result).expect("gpu_to_cpu");
        let result_f32: Vec<f32> = result_bytes
            .chunks_exact(4)
            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect();

        for (i, (&got, &exp)) in result_f32.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-3,
                "element {i}: got {got}, expected {exp}",
            );
        }
    }
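
    // An extra check sketched under the assumption that `alloc_zeros`
    // zero-fills device memory: allocate, copy back to the host, and verify.
    #[test]
    fn test_alloc_zeros_f32() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");

        let handle = backend.alloc_zeros(8, 4, 0).expect("alloc_zeros");
        assert_eq!(handle.len(), 8);

        let bytes = backend.gpu_to_cpu(&handle).expect("gpu_to_cpu");
        let vals: Vec<f32> = bytes
            .chunks_exact(4)
            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect();
        assert_eq!(vals, vec![0.0f32; 8]);
    }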
}