use std::sync::{Arc, OnceLock};
use ferrotorch_core::dtype::DType;
use ferrotorch_core::error::{FerrotorchError, FerrotorchResult};
use ferrotorch_core::gpu_dispatch::{GpuBackend, GpuBufferHandle, GpuRngState};
use crate::buffer::CudaBuffer;
#[cfg(all(feature = "cuda", feature = "cusparselt"))]
use crate::cusparselt::CusparseLtHandle;
use crate::device::GpuDevice;
#[cfg(feature = "cuda")]
use crate::sparse::CusparseHandle;
/// CUDA implementation of [`GpuBackend`].
///
/// Owns the opened devices plus lazily-created sparse-library handles.
/// NOTE(review): fields drop in declaration order, so `devices` is released
/// before the cuSPARSE / cuSPARSELt handles — confirm the handles do not need
/// the device/context alive when they are destroyed.
pub struct CudaBackendImpl {
// Opened devices, indexed by CUDA ordinal (`new` only populates ordinal 0).
devices: Vec<Arc<GpuDevice>>,
// cuSPARSE handle, created on first use by `cusparse()`.
#[cfg(feature = "cuda")]
cusparse_handle: OnceLock<CusparseHandle>,
// cuSPARSELt handle, created on first use by `cusparselt()`.
#[cfg(all(feature = "cuda", feature = "cusparselt"))]
cusparselt_handle: OnceLock<CusparseLtHandle>,
}
impl CudaBackendImpl {
    /// Create a backend bound to CUDA device ordinal 0.
    ///
    /// Only device 0 is opened; any other ordinal is reported as unavailable
    /// by `device`. The cuSPARSE / cuSPARSELt handles are created lazily on
    /// first use, not here.
    ///
    /// # Errors
    /// `InvalidArgument` carrying the driver message if device 0 cannot be
    /// initialised.
    pub fn new() -> FerrotorchResult<Self> {
        let device = GpuDevice::new(0).map_err(|e| FerrotorchError::InvalidArgument {
            message: format!("CUDA init failed: {e}"),
        })?;
        Ok(Self {
            devices: vec![Arc::new(device)],
            #[cfg(feature = "cuda")]
            cusparse_handle: OnceLock::new(),
            #[cfg(all(feature = "cuda", feature = "cusparselt"))]
            cusparselt_handle: OnceLock::new(),
        })
    }

    /// Return the shared cuSPARSELt handle, creating it on first use.
    ///
    /// If two threads race on first use, both may build a handle; the
    /// `OnceLock` keeps whichever is published first and the other is dropped.
    ///
    /// # Errors
    /// `InvalidArgument` if handle creation fails.
    #[cfg(all(feature = "cuda", feature = "cusparselt"))]
    fn cusparselt(&self) -> FerrotorchResult<&CusparseLtHandle> {
        if let Some(handle) = self.cusparselt_handle.get() {
            return Ok(handle);
        }
        // `OnceLock::get_or_try_init` is unstable, so build fallibly first,
        // then publish; `get_or_init` cannot leave the slot empty.
        let fresh = CusparseLtHandle::new().map_err(Self::map_gpu_err)?;
        Ok(self.cusparselt_handle.get_or_init(|| fresh))
    }

    /// Return the shared cuSPARSE handle, creating it on first use.
    ///
    /// Same racing semantics as `cusparselt`: a losing thread's handle is
    /// dropped and the published one is returned.
    ///
    /// # Errors
    /// `InvalidArgument` if handle creation fails.
    #[cfg(feature = "cuda")]
    fn cusparse(&self) -> FerrotorchResult<&CusparseHandle> {
        if let Some(handle) = self.cusparse_handle.get() {
            return Ok(handle);
        }
        let fresh = CusparseHandle::new().map_err(Self::map_gpu_err)?;
        Ok(self.cusparse_handle.get_or_init(|| fresh))
    }

    /// Device 0, the only device opened by `new`.
    ///
    /// # Errors
    /// `InvalidArgument` if device 0 is not registered.
    pub fn default_device(&self) -> FerrotorchResult<&Arc<GpuDevice>> {
        self.device(0)
    }

    /// Look up the opened device for `ordinal`.
    ///
    /// # Errors
    /// `InvalidArgument` if no device with that ordinal was opened.
    fn device(&self, ordinal: usize) -> FerrotorchResult<&Arc<GpuDevice>> {
        // Runs on every backend call: build the error message lazily so the
        // success path never allocates.
        self.devices
            .get(ordinal)
            .ok_or_else(|| FerrotorchError::InvalidArgument {
                message: format!("CUDA device {ordinal} not available"),
            })
    }

    /// Wrap an owned `CudaBuffer<f32>` as an F32-tagged handle on `ordinal`.
    fn wrap_buffer(buf: CudaBuffer<f32>, ordinal: usize) -> GpuBufferHandle {
        let len = buf.len();
        GpuBufferHandle::new(Box::new(buf), ordinal, len, DType::F32)
    }

    /// Wrap an owned `CudaBuffer<f64>` as an F64-tagged handle on `ordinal`.
    fn wrap_buffer_f64(buf: CudaBuffer<f64>, ordinal: usize) -> GpuBufferHandle {
        let len = buf.len();
        GpuBufferHandle::new(Box::new(buf), ordinal, len, DType::F64)
    }

    /// Borrow the `CudaBuffer<f32>` inside an F32-tagged handle.
    ///
    /// # Errors
    /// `InvalidArgument` if the tag is not F32 or the payload type differs.
    fn unwrap_buffer(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<f32>> {
        if handle.dtype() != DType::F32 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("expected F32 buffer, handle is tagged {}", handle.dtype()),
            });
        }
        handle
            .downcast_ref::<CudaBuffer<f32>>()
            .ok_or_else(|| FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<f32>".into(),
            })
    }

    /// Mutably borrow the `CudaBuffer<f32>` inside an F32-tagged handle.
    ///
    /// # Errors
    /// `InvalidArgument` if the tag is not F32 or the payload type differs.
    fn unwrap_buffer_mut(handle: &mut GpuBufferHandle) -> FerrotorchResult<&mut CudaBuffer<f32>> {
        if handle.dtype() != DType::F32 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("expected F32 buffer, handle is tagged {}", handle.dtype()),
            });
        }
        handle
            .downcast_mut::<CudaBuffer<f32>>()
            .ok_or_else(|| FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<f32>".into(),
            })
    }

    /// Mutably borrow the `CudaBuffer<f64>` inside an F64-tagged handle.
    ///
    /// # Errors
    /// `InvalidArgument` if the tag is not F64 or the payload type differs.
    fn unwrap_buffer_f64_mut(
        handle: &mut GpuBufferHandle,
    ) -> FerrotorchResult<&mut CudaBuffer<f64>> {
        if handle.dtype() != DType::F64 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("expected F64 buffer, handle is tagged {}", handle.dtype()),
            });
        }
        handle
            .downcast_mut::<CudaBuffer<f64>>()
            .ok_or_else(|| FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<f64>".into(),
            })
    }

    /// Borrow the `CudaBuffer<f64>` inside an F64-tagged handle.
    ///
    /// # Errors
    /// `InvalidArgument` if the tag is not F64 or the payload type differs.
    fn unwrap_buffer_f64(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<f64>> {
        if handle.dtype() != DType::F64 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("expected F64 buffer, handle is tagged {}", handle.dtype()),
            });
        }
        handle
            .downcast_ref::<CudaBuffer<f64>>()
            .ok_or_else(|| FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<f64>".into(),
            })
    }

    /// Wrap a raw `u16` device slice as a BF16-tagged handle on `ordinal`.
    #[cfg(feature = "cuda")]
    fn wrap_buffer_bf16(buf: cudarc::driver::CudaSlice<u16>, ordinal: usize) -> GpuBufferHandle {
        let len = buf.len();
        GpuBufferHandle::new(Box::new(buf), ordinal, len, DType::BF16)
    }

    /// Borrow the raw `u16` slice inside a BF16-tagged handle.
    ///
    /// # Errors
    /// `InvalidArgument` if the tag is not BF16 or the payload type differs.
    #[cfg(feature = "cuda")]
    fn unwrap_buffer_bf16(
        handle: &GpuBufferHandle,
    ) -> FerrotorchResult<&cudarc::driver::CudaSlice<u16>> {
        if handle.dtype() != DType::BF16 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("expected BF16 buffer, handle is tagged {}", handle.dtype()),
            });
        }
        handle
            .downcast_ref::<cudarc::driver::CudaSlice<u16>>()
            .ok_or_else(|| FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaSlice<u16> (bf16)".into(),
            })
    }

    /// Wrap a raw `u16` device slice as an F16-tagged handle on `ordinal`.
    #[cfg(feature = "cuda")]
    fn wrap_buffer_f16(buf: cudarc::driver::CudaSlice<u16>, ordinal: usize) -> GpuBufferHandle {
        let len = buf.len();
        GpuBufferHandle::new(Box::new(buf), ordinal, len, DType::F16)
    }

    /// Borrow the raw `u16` slice inside an F16-tagged handle.
    ///
    /// # Errors
    /// `InvalidArgument` if the tag is not F16 or the payload type differs.
    #[cfg(feature = "cuda")]
    fn unwrap_buffer_f16(
        handle: &GpuBufferHandle,
    ) -> FerrotorchResult<&cudarc::driver::CudaSlice<u16>> {
        if handle.dtype() != DType::F16 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("expected F16 buffer, handle is tagged {}", handle.dtype()),
            });
        }
        handle
            .downcast_ref::<cudarc::driver::CudaSlice<u16>>()
            .ok_or_else(|| FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaSlice<u16> (f16)".into(),
            })
    }

    /// Wrap an owned `CudaBuffer<i32>` as an I32-tagged handle on `ordinal`.
    fn wrap_buffer_i32(buf: CudaBuffer<i32>, ordinal: usize) -> GpuBufferHandle {
        let len = buf.len();
        GpuBufferHandle::new(Box::new(buf), ordinal, len, DType::I32)
    }

    /// Borrow the `CudaBuffer<i32>` inside an I32-tagged handle.
    ///
    /// # Errors
    /// `InvalidArgument` if the tag is not I32 or the payload type differs.
    fn unwrap_buffer_i32(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<i32>> {
        if handle.dtype() != DType::I32 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("expected I32 buffer, handle is tagged {}", handle.dtype()),
            });
        }
        handle
            .downcast_ref::<CudaBuffer<i32>>()
            .ok_or_else(|| FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<i32>".into(),
            })
    }

    /// Wrap an owned `CudaBuffer<i64>` as an I64-tagged handle on `ordinal`.
    fn wrap_buffer_i64(buf: CudaBuffer<i64>, ordinal: usize) -> GpuBufferHandle {
        let len = buf.len();
        GpuBufferHandle::new(Box::new(buf), ordinal, len, DType::I64)
    }

    /// Borrow the `CudaBuffer<i64>` inside an I64-tagged handle.
    ///
    /// # Errors
    /// `InvalidArgument` if the tag is not I64 or the payload type differs.
    fn unwrap_buffer_i64(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<i64>> {
        if handle.dtype() != DType::I64 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("expected I64 buffer, handle is tagged {}", handle.dtype()),
            });
        }
        handle
            .downcast_ref::<CudaBuffer<i64>>()
            .ok_or_else(|| FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<i64>".into(),
            })
    }

    /// Wrap an owned `CudaBuffer<u8>` as a Bool-tagged handle on `ordinal`.
    fn wrap_buffer_bool(buf: CudaBuffer<u8>, ordinal: usize) -> GpuBufferHandle {
        let len = buf.len();
        GpuBufferHandle::new(Box::new(buf), ordinal, len, DType::Bool)
    }

    /// Borrow the `CudaBuffer<u8>` inside a Bool-tagged handle.
    ///
    /// # Errors
    /// `InvalidArgument` if the tag is not Bool or the payload type differs.
    fn unwrap_buffer_bool(handle: &GpuBufferHandle) -> FerrotorchResult<&CudaBuffer<u8>> {
        if handle.dtype() != DType::Bool {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("expected Bool buffer, handle is tagged {}", handle.dtype()),
            });
        }
        handle
            .downcast_ref::<CudaBuffer<u8>>()
            .ok_or_else(|| FerrotorchError::InvalidArgument {
                message: "GPU handle does not contain a CudaBuffer<u8> (bool)".into(),
            })
    }

    /// Adopt a bare `u8` device slice as an unpooled Bool handle.
    #[cfg(feature = "cuda")]
    fn wrap_slice_bool(slice: cudarc::driver::CudaSlice<u8>, ordinal: usize) -> GpuBufferHandle {
        let len = slice.len();
        let buf = CudaBuffer {
            data: Some(slice),
            len,
            alloc_len: len,
            device_ordinal: ordinal,
            pool_fn: None,
        };
        Self::wrap_buffer_bool(buf, ordinal)
    }

    /// Adopt a bare `i32` device slice as an unpooled I32 handle.
    #[cfg(feature = "cuda")]
    fn wrap_slice_i32(slice: cudarc::driver::CudaSlice<i32>, ordinal: usize) -> GpuBufferHandle {
        let len = slice.len();
        let buf = CudaBuffer {
            data: Some(slice),
            len,
            alloc_len: len,
            device_ordinal: ordinal,
            pool_fn: None,
        };
        Self::wrap_buffer_i32(buf, ordinal)
    }

    /// Adopt a bare `i64` device slice as an unpooled I64 handle.
    #[cfg(feature = "cuda")]
    fn wrap_slice_i64(slice: cudarc::driver::CudaSlice<i64>, ordinal: usize) -> GpuBufferHandle {
        let len = slice.len();
        let buf = CudaBuffer {
            data: Some(slice),
            len,
            alloc_len: len,
            device_ordinal: ordinal,
            pool_fn: None,
        };
        Self::wrap_buffer_i64(buf, ordinal)
    }

    /// Adopt a bare `f32` device slice as an unpooled F32 handle.
    #[cfg(feature = "cuda")]
    fn wrap_slice_f32(slice: cudarc::driver::CudaSlice<f32>, ordinal: usize) -> GpuBufferHandle {
        let len = slice.len();
        let buf = CudaBuffer {
            data: Some(slice),
            len,
            alloc_len: len,
            device_ordinal: ordinal,
            pool_fn: None,
        };
        Self::wrap_buffer(buf, ordinal)
    }

    /// Adopt a bare `f64` device slice as an unpooled F64 handle.
    #[cfg(feature = "cuda")]
    fn wrap_slice_f64(slice: cudarc::driver::CudaSlice<f64>, ordinal: usize) -> GpuBufferHandle {
        let len = slice.len();
        let buf = CudaBuffer {
            data: Some(slice),
            len,
            alloc_len: len,
            device_ordinal: ordinal,
            pool_fn: None,
        };
        Self::wrap_buffer_f64(buf, ordinal)
    }

    /// Shared implementation of `gather` (`is_gather == true`) and
    /// `index_select` (`is_gather == false`).
    ///
    /// Dispatches on the value dtype of `src` (F32, F64, I32, I64, F16, BF16)
    /// and the index dtype of `index` (I32 or I64) to the matching
    /// `crate::gather_int` kernel, run on `src`'s device. The extents
    /// `outer`, `in_dim`, `out_dim` and `inner` are forwarded unchanged;
    /// presumably `src` is laid out as `[outer, in_dim, inner]` and the output
    /// as `[outer, out_dim, inner]` — confirm against the kernels.
    ///
    /// # Errors
    /// `InvalidArgument` for an unsupported index or value dtype, a payload
    /// that does not match its tag, an unavailable device, or a kernel failure.
    #[cfg(feature = "cuda")]
    // Each of the four extents plus the mode flag is a distinct kernel input.
    #[allow(clippy::too_many_arguments)]
    fn gather_or_select(
        &self,
        src: &GpuBufferHandle,
        index: &GpuBufferHandle,
        outer: usize,
        in_dim: usize,
        out_dim: usize,
        inner: usize,
        is_gather: bool,
    ) -> FerrotorchResult<GpuBufferHandle> {
        use crate::gather_int as gi;
        let ord = src.device_ordinal();
        let dev = self.device(ord)?;
        let op = if is_gather { "gather" } else { "index_select" };
        if !matches!(index.dtype(), DType::I32 | DType::I64) {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("{op}: index dtype must be I32/I64, got {}", index.dtype()),
            });
        }
        let i32idx = index.dtype() == DType::I32;
        // Picks one of four kernels (gather/select x i32/i64 index) for the
        // already-unwrapped value slice `$val`, then wraps the result.
        macro_rules! run {
            ($val:expr, $g32:path, $g64:path, $s32:path, $s64:path, $wrap:expr) => {{
                let r = if is_gather {
                    if i32idx {
                        $g32(
                            $val,
                            Self::unwrap_buffer_i32(index)?.inner(),
                            outer,
                            in_dim,
                            out_dim,
                            inner,
                            dev,
                        )
                    } else {
                        $g64(
                            $val,
                            Self::unwrap_buffer_i64(index)?.inner(),
                            outer,
                            in_dim,
                            out_dim,
                            inner,
                            dev,
                        )
                    }
                } else if i32idx {
                    $s32(
                        $val,
                        Self::unwrap_buffer_i32(index)?.inner(),
                        outer,
                        in_dim,
                        out_dim,
                        inner,
                        dev,
                    )
                } else {
                    $s64(
                        $val,
                        Self::unwrap_buffer_i64(index)?.inner(),
                        outer,
                        in_dim,
                        out_dim,
                        inner,
                        dev,
                    )
                }
                .map_err(Self::map_gpu_err)?;
                Ok($wrap(r, ord))
            }};
        }
        match src.dtype() {
            DType::F32 => run!(
                Self::unwrap_buffer(src)?.inner(),
                gi::gather_f32_i32,
                gi::gather_f32_i64,
                gi::isel_f32_i32,
                gi::isel_f32_i64,
                Self::wrap_slice_f32
            ),
            DType::F64 => run!(
                Self::unwrap_buffer_f64(src)?.inner(),
                gi::gather_f64_i32,
                gi::gather_f64_i64,
                gi::isel_f64_i32,
                gi::isel_f64_i64,
                Self::wrap_slice_f64
            ),
            DType::I32 => run!(
                Self::unwrap_buffer_i32(src)?.inner(),
                gi::gather_i32_i32,
                gi::gather_i32_i64,
                gi::isel_i32_i32,
                gi::isel_i32_i64,
                Self::wrap_slice_i32
            ),
            DType::I64 => run!(
                Self::unwrap_buffer_i64(src)?.inner(),
                gi::gather_i64_i32,
                gi::gather_i64_i64,
                gi::isel_i64_i32,
                gi::isel_i64_i64,
                Self::wrap_slice_i64
            ),
            // Half-precision values are moved as raw u16 bit patterns, so both
            // F16 and BF16 share the u16 kernels.
            DType::F16 => run!(
                Self::unwrap_buffer_f16(src)?,
                gi::gather_u16_i32,
                gi::gather_u16_i64,
                gi::isel_u16_i32,
                gi::isel_u16_i64,
                Self::wrap_buffer_f16
            ),
            DType::BF16 => run!(
                Self::unwrap_buffer_bf16(src)?,
                gi::gather_u16_i32,
                gi::gather_u16_i64,
                gi::isel_u16_i32,
                gi::isel_u16_i64,
                Self::wrap_buffer_bf16
            ),
            other => Err(FerrotorchError::InvalidArgument {
                message: format!("{op}: unsupported value dtype {other}"),
            }),
        }
    }

    /// Convert a crate-level GPU error into `FerrotorchError::InvalidArgument`
    /// carrying its display text.
    fn map_gpu_err(e: crate::error::GpuError) -> FerrotorchError {
        FerrotorchError::InvalidArgument {
            message: format!("{e}"),
        }
    }
}
impl GpuBackend for CudaBackendImpl {
/// Expose the concrete backend as `&dyn Any` so callers can downcast it.
fn as_any(&self) -> &dyn std::any::Any {
    self as &dyn std::any::Any
}
fn raw_device_ptr(&self, handle: &GpuBufferHandle) -> *const std::ffi::c_void {
use cudarc::driver::DevicePtr;
let dev = match self.device(handle.device_ordinal()) {
Ok(d) => d,
Err(_) => return std::ptr::null(),
};
let stream = dev.stream();
if let Ok(buf) = Self::unwrap_buffer(handle) {
let (ptr, _sync) = buf.inner().device_ptr(&stream);
ptr as *const std::ffi::c_void
} else if let Ok(buf) = Self::unwrap_buffer_f64(handle) {
let (ptr, _sync) = buf.inner().device_ptr(&stream);
ptr as *const std::ffi::c_void
} else if let Ok(slice) = Self::unwrap_buffer_bf16(handle) {
let (ptr, _sync) = slice.device_ptr(&stream);
ptr as *const std::ffi::c_void
} else if let Ok(slice) = Self::unwrap_buffer_f16(handle) {
let (ptr, _sync) = slice.device_ptr(&stream);
ptr as *const std::ffi::c_void
} else if let Ok(buf) = Self::unwrap_buffer_i32(handle) {
let (ptr, _sync) = buf.inner().device_ptr(&stream);
ptr as *const std::ffi::c_void
} else if let Ok(buf) = Self::unwrap_buffer_i64(handle) {
let (ptr, _sync) = buf.inner().device_ptr(&stream);
ptr as *const std::ffi::c_void
} else if let Ok(buf) = Self::unwrap_buffer_bool(handle) {
let (ptr, _sync) = buf.inner().device_ptr(&stream);
ptr as *const std::ffi::c_void
} else {
std::ptr::null()
}
}
/// Mutable raw CUDA device address of `handle`'s storage, or null.
///
/// The payload is identified purely by downcasting (the dtype tag is not
/// consulted): `CudaBuffer<f32|f64|i32|i64|u8>` or a raw `CudaSlice<u16>`
/// (BF16/F16). Null is returned for an unopened device or any other payload.
fn raw_device_ptr_mut(&self, handle: &mut GpuBufferHandle) -> *mut std::ffi::c_void {
    use cudarc::driver::DevicePtrMut;
    let Ok(dev) = self.device(handle.device_ordinal()) else {
        return std::ptr::null_mut();
    };
    let stream = dev.stream();
    if let Some(b) = handle.downcast_mut::<CudaBuffer<f32>>() {
        let (p, _guard) = b.inner_mut().device_ptr_mut(&stream);
        return p as *mut std::ffi::c_void;
    }
    if let Some(b) = handle.downcast_mut::<CudaBuffer<f64>>() {
        let (p, _guard) = b.inner_mut().device_ptr_mut(&stream);
        return p as *mut std::ffi::c_void;
    }
    // BF16 and F16 share the same raw u16 slice payload.
    if let Some(s) = handle.downcast_mut::<cudarc::driver::CudaSlice<u16>>() {
        let (p, _guard) = s.device_ptr_mut(&stream);
        return p as *mut std::ffi::c_void;
    }
    if let Some(b) = handle.downcast_mut::<CudaBuffer<i32>>() {
        let (p, _guard) = b.inner_mut().device_ptr_mut(&stream);
        return p as *mut std::ffi::c_void;
    }
    if let Some(b) = handle.downcast_mut::<CudaBuffer<i64>>() {
        let (p, _guard) = b.inner_mut().device_ptr_mut(&stream);
        return p as *mut std::ffi::c_void;
    }
    if let Some(b) = handle.downcast_mut::<CudaBuffer<u8>>() {
        let (p, _guard) = b.inner_mut().device_ptr_mut(&stream);
        return p as *mut std::ffi::c_void;
    }
    std::ptr::null_mut()
}
/// Bytes per element of `handle`, taken from its dtype tag alone.
fn buffer_elem_size(&self, handle: &GpuBufferHandle) -> usize {
    let dtype = handle.dtype();
    dtype.size_of()
}
fn cpu_to_gpu(
&self,
data: &[u8],
dtype: DType,
device: usize,
) -> FerrotorchResult<GpuBufferHandle> {
let dev = self.device(device)?;
match dtype {
DType::F32 => {
let count = data.len() / 4;
let f32_data: &[f32] =
unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, count) };
let buf = crate::transfer::cpu_to_gpu(f32_data, dev).map_err(Self::map_gpu_err)?;
Ok(Self::wrap_buffer(buf, device))
}
DType::F64 => {
let count = data.len() / 8;
let f64_data: &[f64] =
unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f64, count) };
let buf = crate::transfer::cpu_to_gpu(f64_data, dev).map_err(Self::map_gpu_err)?;
Ok(Self::wrap_buffer_f64(buf, device))
}
DType::BF16 => {
let count = data.len() / 2;
let u16_vec: Vec<u16> = unsafe {
let slice = std::slice::from_raw_parts(data.as_ptr() as *const u16, count);
slice.to_vec()
};
let slice = dev
.stream()
.clone_htod(&u16_vec)
.map_err(|e| Self::map_gpu_err(crate::error::GpuError::Driver(e)))?;
Ok(Self::wrap_buffer_bf16(slice, device))
}
DType::F16 => {
let count = data.len() / 2;
let u16_vec: Vec<u16> = unsafe {
let slice = std::slice::from_raw_parts(data.as_ptr() as *const u16, count);
slice.to_vec()
};
let slice = dev
.stream()
.clone_htod(&u16_vec)
.map_err(|e| Self::map_gpu_err(crate::error::GpuError::Driver(e)))?;
Ok(Self::wrap_buffer_f16(slice, device))
}
DType::I32 => {
let count = data.len() / 4;
let i32_data: &[i32] =
unsafe { std::slice::from_raw_parts(data.as_ptr() as *const i32, count) };
let buf = crate::transfer::cpu_to_gpu(i32_data, dev).map_err(Self::map_gpu_err)?;
Ok(Self::wrap_buffer_i32(buf, device))
}
DType::I64 => {
let count = data.len() / 8;
let i64_data: &[i64] =
unsafe { std::slice::from_raw_parts(data.as_ptr() as *const i64, count) };
let buf = crate::transfer::cpu_to_gpu(i64_data, dev).map_err(Self::map_gpu_err)?;
Ok(Self::wrap_buffer_i64(buf, device))
}
DType::Bool => {
let count = data.len(); debug_assert_eq!(count, data.len());
let buf = crate::transfer::cpu_to_gpu(data, dev).map_err(Self::map_gpu_err)?;
Ok(Self::wrap_buffer_bool(buf, device))
}
other => Err(FerrotorchError::InvalidArgument {
message: format!(
"cpu_to_gpu: dtype {other} not supported on CUDA \
(supported: F32, F64, BF16, F16, I32, I64, Bool)"
),
}),
}
}
/// Upload host bytes to `device` through `crate::transfer::cpu_to_gpu_pinned`.
///
/// `data` holds native-endian elements of `dtype`. Misaligned input is first
/// copied into an aligned host `Vec`, because viewing a misaligned `&[u8]` as
/// `&[f32]` (etc.) is undefined behaviour. Aligned input is passed through
/// without an extra copy.
///
/// # Errors
/// `InvalidArgument` in any of these cases:
/// - `device` is unavailable.
/// - `data.len()` is not a multiple of the element size.
/// - `dtype` is not one of F32, F64, I32, I64, Bool.
/// - The transfer fails.
fn cpu_to_gpu_pinned(
    &self,
    data: &[u8],
    dtype: DType,
    device: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    /// View `bytes` as `[T]`: borrowed when aligned, copied into an aligned
    /// `Vec<T>` otherwise. Rejects a trailing partial element.
    ///
    /// Only instantiated with `f32`, `f64`, `i32` and `i64`, for which every
    /// bit pattern is a valid value.
    fn typed<'a, T: Copy + 'static>(
        bytes: &'a [u8],
        op: &str,
    ) -> FerrotorchResult<std::borrow::Cow<'a, [T]>> {
        let size = std::mem::size_of::<T>();
        if bytes.len() % size != 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "{op}: byte length {} is not a multiple of element size {size}",
                    bytes.len()
                ),
            });
        }
        let count = bytes.len() / size;
        if (bytes.as_ptr() as usize) % std::mem::align_of::<T>() == 0 {
            // SAFETY: the pointer is aligned for `T`; `count * size` equals
            // `bytes.len()`, so the view stays inside `bytes`; the returned
            // borrow is tied to `bytes`; `T` accepts any bit pattern.
            let view = unsafe { std::slice::from_raw_parts(bytes.as_ptr().cast::<T>(), count) };
            Ok(std::borrow::Cow::Borrowed(view))
        } else {
            let mut owned = Vec::<T>::with_capacity(count);
            // SAFETY: `owned` has room for `count` elements (`bytes.len()`
            // bytes), the regions cannot overlap, and every copied bit
            // pattern is a valid `T`, so all `count` elements are initialised
            // before `set_len`.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    bytes.as_ptr(),
                    owned.as_mut_ptr().cast::<u8>(),
                    bytes.len(),
                );
                owned.set_len(count);
            }
            Ok(std::borrow::Cow::Owned(owned))
        }
    }

    let dev = self.device(device)?;
    match dtype {
        DType::F32 => {
            let values = typed::<f32>(data, "cpu_to_gpu_pinned")?;
            let buf =
                crate::transfer::cpu_to_gpu_pinned(&*values, dev).map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_buffer(buf, device))
        }
        DType::F64 => {
            let values = typed::<f64>(data, "cpu_to_gpu_pinned")?;
            let buf =
                crate::transfer::cpu_to_gpu_pinned(&*values, dev).map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_buffer_f64(buf, device))
        }
        DType::I32 => {
            let values = typed::<i32>(data, "cpu_to_gpu_pinned")?;
            let buf =
                crate::transfer::cpu_to_gpu_pinned(&*values, dev).map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_buffer_i32(buf, device))
        }
        DType::I64 => {
            let values = typed::<i64>(data, "cpu_to_gpu_pinned")?;
            let buf =
                crate::transfer::cpu_to_gpu_pinned(&*values, dev).map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_buffer_i64(buf, device))
        }
        // Bool is stored one byte per element, so the input is used as-is.
        DType::Bool => {
            let buf =
                crate::transfer::cpu_to_gpu_pinned(data, dev).map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_buffer_bool(buf, device))
        }
        other => Err(FerrotorchError::InvalidArgument {
            message: format!(
                "cpu_to_gpu_pinned: dtype {other} not supported on CUDA \
                 (supported: F32, F64, I32, I64, Bool)"
            ),
        }),
    }
}
fn gpu_to_cpu(&self, handle: &GpuBufferHandle) -> FerrotorchResult<Vec<u8>> {
let dev = self.device(handle.device_ordinal())?;
if let Ok(buf) = Self::unwrap_buffer(handle) {
let f32_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;
let bytes = unsafe {
let mut v = std::mem::ManuallyDrop::new(f32_data);
let ptr = v.as_mut_ptr() as *mut u8;
let len = v.len() * 4;
let cap = v.capacity() * 4;
Vec::from_raw_parts(ptr, len, cap)
};
Ok(bytes)
} else if let Ok(buf) = Self::unwrap_buffer_f64(handle) {
let f64_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;
let bytes = unsafe {
let mut v = std::mem::ManuallyDrop::new(f64_data);
let ptr = v.as_mut_ptr() as *mut u8;
let len = v.len() * 8;
let cap = v.capacity() * 8;
Vec::from_raw_parts(ptr, len, cap)
};
Ok(bytes)
} else if let Ok(slice) = Self::unwrap_buffer_bf16(handle) {
let u16_data = dev
.stream()
.clone_dtoh(slice)
.map_err(|e| Self::map_gpu_err(crate::error::GpuError::Driver(e)))?;
let bytes = unsafe {
let mut v = std::mem::ManuallyDrop::new(u16_data);
let ptr = v.as_mut_ptr() as *mut u8;
let len = v.len() * 2;
let cap = v.capacity() * 2;
Vec::from_raw_parts(ptr, len, cap)
};
Ok(bytes)
} else if let Ok(slice) = Self::unwrap_buffer_f16(handle) {
let u16_data = dev
.stream()
.clone_dtoh(slice)
.map_err(|e| Self::map_gpu_err(crate::error::GpuError::Driver(e)))?;
let bytes = unsafe {
let mut v = std::mem::ManuallyDrop::new(u16_data);
let ptr = v.as_mut_ptr() as *mut u8;
let len = v.len() * 2;
let cap = v.capacity() * 2;
Vec::from_raw_parts(ptr, len, cap)
};
Ok(bytes)
} else if let Ok(buf) = Self::unwrap_buffer_i32(handle) {
let i32_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;
let bytes = unsafe {
let mut v = std::mem::ManuallyDrop::new(i32_data);
let ptr = v.as_mut_ptr() as *mut u8;
let len = v.len() * 4;
let cap = v.capacity() * 4;
Vec::from_raw_parts(ptr, len, cap)
};
Ok(bytes)
} else if let Ok(buf) = Self::unwrap_buffer_i64(handle) {
let i64_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;
let bytes = unsafe {
let mut v = std::mem::ManuallyDrop::new(i64_data);
let ptr = v.as_mut_ptr() as *mut u8;
let len = v.len() * 8;
let cap = v.capacity() * 8;
Vec::from_raw_parts(ptr, len, cap)
};
Ok(bytes)
} else if let Ok(buf) = Self::unwrap_buffer_bool(handle) {
let u8_data = crate::transfer::gpu_to_cpu(buf, dev).map_err(Self::map_gpu_err)?;
Ok(u8_data)
} else {
Err(FerrotorchError::InvalidArgument {
message: "gpu_to_cpu: handle is not a recognised dtype \
(expected CudaBuffer<f32>, CudaBuffer<f64>, \
CudaSlice<u16> for bf16/f16, CudaBuffer<i32>/<i64>, \
or CudaBuffer<u8> for bool)"
.into(),
})
}
}
fn clone_buffer(&self, handle: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
let ordinal = handle.device_ordinal();
let map_drv = |e| Self::map_gpu_err(crate::error::GpuError::Driver(e));
match handle.dtype() {
DType::F32 => {
let buf = Self::unwrap_buffer(handle)?;
let slice = buf.inner().try_clone().map_err(map_drv)?;
let cloned = CudaBuffer {
data: Some(slice),
len: buf.len(),
alloc_len: buf.alloc_len(),
device_ordinal: ordinal,
pool_fn: None,
};
Ok(Self::wrap_buffer(cloned, ordinal))
}
DType::F64 => {
let buf = Self::unwrap_buffer_f64(handle)?;
let slice = buf.inner().try_clone().map_err(map_drv)?;
let cloned = CudaBuffer {
data: Some(slice),
len: buf.len(),
alloc_len: buf.alloc_len(),
device_ordinal: ordinal,
pool_fn: None,
};
Ok(Self::wrap_buffer_f64(cloned, ordinal))
}
DType::BF16 => {
let slice = Self::unwrap_buffer_bf16(handle)?;
let cloned = slice.try_clone().map_err(map_drv)?;
Ok(Self::wrap_buffer_bf16(cloned, ordinal))
}
DType::F16 => {
let slice = Self::unwrap_buffer_f16(handle)?;
let cloned = slice.try_clone().map_err(map_drv)?;
Ok(Self::wrap_buffer_f16(cloned, ordinal))
}
DType::I32 => {
let buf = Self::unwrap_buffer_i32(handle)?;
let slice = buf.inner().try_clone().map_err(map_drv)?;
let cloned = CudaBuffer {
data: Some(slice),
len: buf.len(),
alloc_len: buf.alloc_len(),
device_ordinal: ordinal,
pool_fn: None,
};
Ok(Self::wrap_buffer_i32(cloned, ordinal))
}
DType::I64 => {
let buf = Self::unwrap_buffer_i64(handle)?;
let slice = buf.inner().try_clone().map_err(map_drv)?;
let cloned = CudaBuffer {
data: Some(slice),
len: buf.len(),
alloc_len: buf.alloc_len(),
device_ordinal: ordinal,
pool_fn: None,
};
Ok(Self::wrap_buffer_i64(cloned, ordinal))
}
DType::Bool => {
let buf = Self::unwrap_buffer_bool(handle)?;
let slice = buf.inner().try_clone().map_err(map_drv)?;
let cloned = CudaBuffer {
data: Some(slice),
len: buf.len(),
alloc_len: buf.alloc_len(),
device_ordinal: ordinal,
pool_fn: None,
};
Ok(Self::wrap_buffer_bool(cloned, ordinal))
}
other => Err(FerrotorchError::InvalidArgument {
message: format!(
"clone_buffer: dtype {other} has no device-to-device copy \
path on CUDA (supported: F32, F64, BF16, F16, I32, I64, Bool)"
),
}),
}
}
/// Run `kernels::gpu_has_inf_nan` over the F32 buffer `a` on `a`'s device
/// and return its boolean result.
///
/// # Errors
/// `InvalidArgument` if `a` is not F32, its device is unavailable, or the
/// kernel fails.
fn has_inf_nan_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<bool> {
    let ordinal = a.device_ordinal();
    let values = Self::unwrap_buffer(a)?;
    crate::kernels::gpu_has_inf_nan(values, self.device(ordinal)?).map_err(Self::map_gpu_err)
}
fn alloc_zeros(
&self,
len: usize,
dtype: DType,
device: usize,
) -> FerrotorchResult<GpuBufferHandle> {
let dev = self.device(device)?;
match dtype {
DType::BF16 => {
let slice =
crate::transfer::alloc_zeros_bf16(len, dev).map_err(Self::map_gpu_err)?;
Ok(Self::wrap_buffer_bf16(slice, device))
}
DType::F16 => {
let slice =
crate::transfer::alloc_zeros_bf16(len, dev).map_err(Self::map_gpu_err)?;
Ok(Self::wrap_buffer_f16(slice, device))
}
DType::F32 => {
let buf = crate::transfer::alloc_zeros_f32(len, dev).map_err(Self::map_gpu_err)?;
Ok(Self::wrap_buffer(buf, device))
}
DType::F64 => {
let buf = crate::transfer::alloc_zeros_f64(len, dev).map_err(Self::map_gpu_err)?;
Ok(Self::wrap_buffer_f64(buf, device))
}
DType::I32 => {
let buf: CudaBuffer<i32> =
crate::transfer::alloc_zeros(len, dev).map_err(Self::map_gpu_err)?;
Ok(Self::wrap_buffer_i32(buf, device))
}
DType::I64 => {
let buf: CudaBuffer<i64> =
crate::transfer::alloc_zeros(len, dev).map_err(Self::map_gpu_err)?;
Ok(Self::wrap_buffer_i64(buf, device))
}
DType::Bool => {
let buf: CudaBuffer<u8> =
crate::transfer::alloc_zeros(len, dev).map_err(Self::map_gpu_err)?;
Ok(Self::wrap_buffer_bool(buf, device))
}
other => Err(FerrotorchError::InvalidArgument {
message: format!(
"alloc_zeros: dtype {other} not supported on CUDA \
(supported: F32, F64, BF16, F16, I32, I64, Bool)"
),
}),
}
}
/// `add` for F32 buffers: runs `kernels::gpu_add(a, b)` on `a`'s device and
/// wraps the result as an F32 handle there.
///
/// # Errors
/// `InvalidArgument` if either handle is not F32, `a`'s device is
/// unavailable, or the kernel fails.
fn add_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer(a)?, Self::unwrap_buffer(b)?);
    crate::kernels::gpu_add(lhs, rhs, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `sub` for F32 buffers: runs `kernels::gpu_sub(a, b)` on `a`'s device and
/// wraps the result as an F32 handle there.
///
/// # Errors
/// `InvalidArgument` if either handle is not F32, `a`'s device is
/// unavailable, or the kernel fails.
fn sub_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer(a)?, Self::unwrap_buffer(b)?);
    crate::kernels::gpu_sub(lhs, rhs, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `mul` for F32 buffers: runs `kernels::gpu_mul(a, b)` on `a`'s device and
/// wraps the result as an F32 handle there.
///
/// # Errors
/// `InvalidArgument` if either handle is not F32, `a`'s device is
/// unavailable, or the kernel fails.
fn mul_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer(a)?, Self::unwrap_buffer(b)?);
    crate::kernels::gpu_mul(lhs, rhs, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `neg` for an F32 buffer via `kernels::gpu_neg` on `a`'s device.
///
/// # Errors
/// `InvalidArgument` if `a` is not F32, its device is unavailable, or the
/// kernel fails.
fn neg_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    crate::kernels::gpu_neg(input, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `relu` for an F32 buffer via `kernels::gpu_relu` on `a`'s device.
///
/// # Errors
/// `InvalidArgument` if `a` is not F32, its device is unavailable, or the
/// kernel fails.
fn relu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    crate::kernels::gpu_relu(input, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `div` for F32 buffers: runs `kernels::gpu_div(a, b)` on `a`'s device and
/// wraps the result as an F32 handle there.
///
/// # Errors
/// `InvalidArgument` if either handle is not F32, `a`'s device is
/// unavailable, or the kernel fails.
fn div_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer(a)?, Self::unwrap_buffer(b)?);
    crate::kernels::gpu_div(lhs, rhs, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `exp` for an F32 buffer via `kernels::gpu_exp` on `a`'s device.
///
/// # Errors
/// `InvalidArgument` if `a` is not F32, its device is unavailable, or the
/// kernel fails.
fn exp_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    crate::kernels::gpu_exp(input, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `log` for an F32 buffer via `kernels::gpu_log` on `a`'s device.
///
/// # Errors
/// `InvalidArgument` if `a` is not F32, its device is unavailable, or the
/// kernel fails.
fn log_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    crate::kernels::gpu_log(input, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `sqrt` for an F32 buffer via `kernels::gpu_sqrt` on `a`'s device.
///
/// # Errors
/// `InvalidArgument` if `a` is not F32, its device is unavailable, or the
/// kernel fails.
fn sqrt_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    crate::kernels::gpu_sqrt(input, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `pow` for an F32 buffer with a scalar `exponent`, via `kernels::gpu_pow`
/// on `a`'s device.
///
/// # Errors
/// `InvalidArgument` if `a` is not F32, its device is unavailable, or the
/// kernel fails.
fn pow_f32(&self, a: &GpuBufferHandle, exponent: f32) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let base = Self::unwrap_buffer(a)?;
    crate::kernels::gpu_pow(base, exponent, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `abs` for an F32 buffer via `kernels::gpu_abs` on `a`'s device.
///
/// # Errors
/// `InvalidArgument` if `a` is not F32, its device is unavailable, or the
/// kernel fails.
fn abs_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    crate::kernels::gpu_abs(input, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `sigmoid` for an F32 buffer via `kernels::gpu_sigmoid` on `a`'s device.
///
/// # Errors
/// `InvalidArgument` if `a` is not F32, its device is unavailable, or the
/// kernel fails.
fn sigmoid_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    crate::kernels::gpu_sigmoid(input, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `tanh` for an F32 buffer via `kernels::gpu_tanh` on `a`'s device.
///
/// # Errors
/// `InvalidArgument` if `a` is not F32, its device is unavailable, or the
/// kernel fails.
fn tanh_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    crate::kernels::gpu_tanh(input, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `add` for F64 buffers: runs `kernels::gpu_add_f64(a, b)` on `a`'s device
/// and wraps the result as an F64 handle there.
///
/// # Errors
/// `InvalidArgument` if either handle is not F64, `a`'s device is
/// unavailable, or the kernel fails.
fn add_f64(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_f64(a)?, Self::unwrap_buffer_f64(b)?);
    crate::kernels::gpu_add_f64(lhs, rhs, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `sub` for F64 buffers: runs `kernels::gpu_sub_f64(a, b)` on `a`'s device
/// and wraps the result as an F64 handle there.
///
/// # Errors
/// `InvalidArgument` if either handle is not F64, `a`'s device is
/// unavailable, or the kernel fails.
fn sub_f64(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_f64(a)?, Self::unwrap_buffer_f64(b)?);
    crate::kernels::gpu_sub_f64(lhs, rhs, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `mul` for F64 buffers: runs `kernels::gpu_mul_f64(a, b)` on `a`'s device
/// and wraps the result as an F64 handle there.
///
/// # Errors
/// `InvalidArgument` if either handle is not F64, `a`'s device is
/// unavailable, or the kernel fails.
fn mul_f64(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_f64(a)?, Self::unwrap_buffer_f64(b)?);
    crate::kernels::gpu_mul_f64(lhs, rhs, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `div` for F64 buffers: runs `kernels::gpu_div_f64(a, b)` on `a`'s device
/// and wraps the result as an F64 handle there.
///
/// # Errors
/// `InvalidArgument` if either handle is not F64, `a`'s device is
/// unavailable, or the kernel fails.
fn div_f64(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_f64(a)?, Self::unwrap_buffer_f64(b)?);
    crate::kernels::gpu_div_f64(lhs, rhs, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `neg` for an F64 buffer via `kernels::gpu_neg_f64` on `a`'s device.
///
/// # Errors
/// `InvalidArgument` if `a` is not F64, its device is unavailable, or the
/// kernel fails.
fn neg_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f64(a)?;
    crate::kernels::gpu_neg_f64(input, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `relu` for an F64 buffer via `kernels::gpu_relu_f64` on `a`'s device.
///
/// # Errors
/// `InvalidArgument` if `a` is not F64, its device is unavailable, or the
/// kernel fails.
fn relu_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f64(a)?;
    crate::kernels::gpu_relu_f64(input, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `scale` for an F64 buffer by `scalar`, via `kernels::gpu_scale_f64` on
/// `a`'s device.
///
/// # Errors
/// `InvalidArgument` if `a` is not F64, its device is unavailable, or the
/// kernel fails.
fn scale_f64(&self, a: &GpuBufferHandle, scalar: f64) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f64(a)?;
    crate::kernels::gpu_scale_f64(input, scalar, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `exp` for an F64 buffer via `kernels::gpu_exp_f64` on `a`'s device.
///
/// # Errors
/// `InvalidArgument` if `a` is not F64, its device is unavailable, or the
/// kernel fails.
fn exp_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f64(a)?;
    crate::kernels::gpu_exp_f64(input, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `log` for an F64 buffer via `kernels::gpu_log_f64` on `a`'s device.
///
/// # Errors
/// `InvalidArgument` if `a` is not F64, its device is unavailable, or the
/// kernel fails.
fn log_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f64(a)?;
    crate::kernels::gpu_log_f64(input, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `sqrt` for an F64 buffer via `kernels::gpu_sqrt_f64` on `a`'s device.
///
/// # Errors
/// `InvalidArgument` if `a` is not F64, its device is unavailable, or the
/// kernel fails.
fn sqrt_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f64(a)?;
    crate::kernels::gpu_sqrt_f64(input, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `pow` for an F64 buffer with a scalar `exponent`, via
/// `kernels::gpu_pow_f64` on `a`'s device.
///
/// # Errors
/// `InvalidArgument` if `a` is not F64, its device is unavailable, or the
/// kernel fails.
fn pow_f64(&self, a: &GpuBufferHandle, exponent: f64) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let base = Self::unwrap_buffer_f64(a)?;
    crate::kernels::gpu_pow_f64(base, exponent, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `abs` for an F64 buffer via `kernels::gpu_abs_f64` on `a`'s device.
///
/// # Errors
/// `InvalidArgument` if `a` is not F64, its device is unavailable, or the
/// kernel fails.
fn abs_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f64(a)?;
    crate::kernels::gpu_abs_f64(input, self.device(ordinal)?)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Element-wise `gpu_sigmoid_f64` over the f64 buffer in `a`; the result stays on `a`'s device.
fn sigmoid_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_sigmoid_f64(src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Element-wise `gpu_tanh_f64` over the f64 buffer in `a`; the result stays on `a`'s device.
fn tanh_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_tanh_f64(src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_relu_backward_f64` on (`grad`, `input`) using `grad`'s device; the
/// result is wrapped on that same ordinal.
fn relu_backward_f64(
    &self,
    grad: &GpuBufferHandle,
    input: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let grad_src = Self::unwrap_buffer_f64(grad)?;
    let input_src = Self::unwrap_buffer_f64(input)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_relu_backward_f64(grad_src, input_src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_abs_backward_f64` on (`grad`, `input`) using `grad`'s device; the
/// result is wrapped on that same ordinal.
fn abs_backward_f64(
    &self,
    grad: &GpuBufferHandle,
    input: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let grad_src = Self::unwrap_buffer_f64(grad)?;
    let input_src = Self::unwrap_buffer_f64(input)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_abs_backward_f64(grad_src, input_src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_sigmoid_backward_f64` on (`grad`, forward `output`) using `grad`'s
/// device; the result is wrapped on that same ordinal.
fn sigmoid_backward_f64(
    &self,
    grad: &GpuBufferHandle,
    output: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let grad_src = Self::unwrap_buffer_f64(grad)?;
    let output_src = Self::unwrap_buffer_f64(output)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_sigmoid_backward_f64(grad_src, output_src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_tanh_backward_f64` on (`grad`, forward `output`) using `grad`'s
/// device; the result is wrapped on that same ordinal.
fn tanh_backward_f64(
    &self,
    grad: &GpuBufferHandle,
    output: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let grad_src = Self::unwrap_buffer_f64(grad)?;
    let output_src = Self::unwrap_buffer_f64(output)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_tanh_backward_f64(grad_src, output_src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Element-wise `gpu_gelu_f64` over the f64 buffer in `a`; the result stays on `a`'s device.
fn gelu_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_gelu_f64(src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Element-wise `gpu_gelu_tanh_f64` over the f64 buffer in `a`; the result stays on `a`'s device.
fn gelu_tanh_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_gelu_tanh_f64(src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Element-wise `gpu_gelu_erf_f64` over the f64 buffer in `a`; the result stays on `a`'s device.
fn gelu_erf_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_gelu_erf_f64(src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Element-wise `gpu_silu_f64` over the f64 buffer in `a`; the result stays on `a`'s device.
fn silu_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_silu_f64(src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_elu_f64` with `alpha` over the f64 buffer in `a`, on (and for) `a`'s device.
fn elu_f64(&self, a: &GpuBufferHandle, alpha: f64) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_elu_f64(src, alpha, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Element-wise `gpu_mish_f64` over the f64 buffer in `a`; the result stays on `a`'s device.
fn mish_f64(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_mish_f64(src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_clamp_f64` with bounds (`min_val`, `max_val`) over the f64 buffer in
/// `a`, on (and for) `a`'s device.
fn clamp_f64(
    &self,
    a: &GpuBufferHandle,
    min_val: f64,
    max_val: f64,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_clamp_f64(src, min_val, max_val, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_clamp_backward_f64` on (`grad`, `input`) with the forward bounds.
///
/// Unlike most backward wrappers here, the device is taken from `input`, not `grad`.
fn clamp_backward_f64(
    &self,
    grad: &GpuBufferHandle,
    input: &GpuBufferHandle,
    min_val: f64,
    max_val: f64,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let grad_src = Self::unwrap_buffer_f64(grad)?;
    let input_src = Self::unwrap_buffer_f64(input)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_clamp_backward_f64(grad_src, input_src, min_val, max_val, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_gelu_backward_f64` on (`grad`, `input`) using `grad`'s device; the
/// result is wrapped on that same ordinal.
fn gelu_backward_f64(
    &self,
    grad: &GpuBufferHandle,
    input: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let grad_src = Self::unwrap_buffer_f64(grad)?;
    let input_src = Self::unwrap_buffer_f64(input)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_gelu_backward_f64(grad_src, input_src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_gelu_backward_tanh_f64` on (`grad`, `input`) using `grad`'s device;
/// the result is wrapped on that same ordinal.
fn gelu_backward_tanh_f64(
    &self,
    grad: &GpuBufferHandle,
    input: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let grad_src = Self::unwrap_buffer_f64(grad)?;
    let input_src = Self::unwrap_buffer_f64(input)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_gelu_backward_tanh_f64(grad_src, input_src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_gelu_backward_erf_f64` on (`grad`, `input`) using `grad`'s device;
/// the result is wrapped on that same ordinal.
fn gelu_backward_erf_f64(
    &self,
    grad: &GpuBufferHandle,
    input: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let grad_src = Self::unwrap_buffer_f64(grad)?;
    let input_src = Self::unwrap_buffer_f64(input)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_gelu_backward_erf_f64(grad_src, input_src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_silu_backward_f64` on (`grad`, `input`) using `grad`'s device; the
/// result is wrapped on that same ordinal.
fn silu_backward_f64(
    &self,
    grad: &GpuBufferHandle,
    input: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let grad_src = Self::unwrap_buffer_f64(grad)?;
    let input_src = Self::unwrap_buffer_f64(input)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_silu_backward_f64(grad_src, input_src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_elu_backward_f64` on (`grad`, `input`) with `alpha`, using `grad`'s
/// device; the result is wrapped on that same ordinal.
fn elu_backward_f64(
    &self,
    grad: &GpuBufferHandle,
    input: &GpuBufferHandle,
    alpha: f64,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let grad_src = Self::unwrap_buffer_f64(grad)?;
    let input_src = Self::unwrap_buffer_f64(input)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_elu_backward_f64(grad_src, input_src, alpha, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_mish_backward_f64` on (`grad`, `input`) using `grad`'s device; the
/// result is wrapped on that same ordinal.
fn mish_backward_f64(
    &self,
    grad: &GpuBufferHandle,
    input: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let grad_src = Self::unwrap_buffer_f64(grad)?;
    let input_src = Self::unwrap_buffer_f64(input)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_mish_backward_f64(grad_src, input_src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_cumsum_f64` over `a` viewed as (`outer`, `dim_size`, `inner`), on
/// (and for) `a`'s device.
fn cumsum_f64(
    &self,
    a: &GpuBufferHandle,
    outer: usize,
    dim_size: usize,
    inner: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_cumsum_f64(src, outer, dim_size, inner, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_cumprod_f64` over `a` viewed as (`outer`, `dim_size`, `inner`), on
/// (and for) `a`'s device.
fn cumprod_f64(
    &self,
    a: &GpuBufferHandle,
    outer: usize,
    dim_size: usize,
    inner: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_cumprod_f64(src, outer, dim_size, inner, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_cummax_f64` over `a` viewed as (`outer`, `dim_size`, `inner`).
///
/// Returns `(values, indices)`, both wrapped as f64 handles on `a`'s device.
fn cummax_f64(
    &self,
    a: &GpuBufferHandle,
    outer: usize,
    dim_size: usize,
    inner: usize,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_cummax_f64(src, outer, dim_size, inner, gpu)
        .map(|(values, indices)| {
            (
                Self::wrap_buffer_f64(values, ordinal),
                Self::wrap_buffer_f64(indices, ordinal),
            )
        })
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_cummin_f64` over `a` viewed as (`outer`, `dim_size`, `inner`).
///
/// Returns `(values, indices)`, both wrapped as f64 handles on `a`'s device.
fn cummin_f64(
    &self,
    a: &GpuBufferHandle,
    outer: usize,
    dim_size: usize,
    inner: usize,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_cummin_f64(src, outer, dim_size, inner, gpu)
        .map(|(values, indices)| {
            (
                Self::wrap_buffer_f64(values, ordinal),
                Self::wrap_buffer_f64(indices, ordinal),
            )
        })
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_logcumsumexp_f64` over `a` viewed as (`outer`, `dim_size`, `inner`),
/// on (and for) `a`'s device.
fn logcumsumexp_f64(
    &self,
    a: &GpuBufferHandle,
    outer: usize,
    dim_size: usize,
    inner: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_logcumsumexp_f64(src, outer, dim_size, inner, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_transpose_2d_f64` on `a` with dimensions (`m`, `n`), on (and for)
/// `a`'s device.
fn transpose_2d_f64(
    &self,
    a: &GpuBufferHandle,
    m: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_transpose_2d_f64(src, m, n, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_permute_0213_f64` on `a` with dimensions (`d0`, `d1`, `d2`, `d3`), on
/// (and for) `a`'s device.
fn permute_0213_f64(
    &self,
    a: &GpuBufferHandle,
    d0: usize,
    d1: usize,
    d2: usize,
    d3: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_permute_0213_f64(src, d0, d1, d2, d3, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_broadcast_add_f64` on `a` and `b` with their shapes and the output
/// shape; executes on `a`'s device and wraps the result there.
fn broadcast_add_f64(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let lhs = Self::unwrap_buffer_f64(a)?;
    let rhs = Self::unwrap_buffer_f64(b)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_broadcast_add_f64(lhs, rhs, a_shape, b_shape, out_shape, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_broadcast_sub_f64` on `a` and `b` with their shapes and the output
/// shape; executes on `a`'s device and wraps the result there.
fn broadcast_sub_f64(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let lhs = Self::unwrap_buffer_f64(a)?;
    let rhs = Self::unwrap_buffer_f64(b)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_broadcast_sub_f64(lhs, rhs, a_shape, b_shape, out_shape, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_broadcast_mul_f64` on `a` and `b` with their shapes and the output
/// shape; executes on `a`'s device and wraps the result there.
fn broadcast_mul_f64(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let lhs = Self::unwrap_buffer_f64(a)?;
    let rhs = Self::unwrap_buffer_f64(b)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_broadcast_mul_f64(lhs, rhs, a_shape, b_shape, out_shape, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_broadcast_div_f64` on `a` and `b` with their shapes and the output
/// shape; executes on `a`'s device and wraps the result there.
fn broadcast_div_f64(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let lhs = Self::unwrap_buffer_f64(a)?;
    let rhs = Self::unwrap_buffer_f64(b)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_broadcast_div_f64(lhs, rhs, a_shape, b_shape, out_shape, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Full reduction of `a` via `gpu_reduce_sum_f64`; `_n` is not used by this backend.
fn sum_f64(&self, a: &GpuBufferHandle, _n: usize) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_reduce_sum_f64(src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Full reduction of `a` via `gpu_reduce_prod_f64`; `_n` is not used by this backend.
fn prod_f64(&self, a: &GpuBufferHandle, _n: usize) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_reduce_prod_f64(src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_prod_backward_f64` on (`input`, `grad_output`) using `input`'s
/// device; the result is wrapped on that same ordinal.
fn prod_backward_f64(
    &self,
    input: &GpuBufferHandle,
    grad_output: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let input_src = Self::unwrap_buffer_f64(input)?;
    let upstream = Self::unwrap_buffer_f64(grad_output)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_prod_backward_f64(input_src, upstream, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Full reduction of `a` via `gpu_reduce_min_f64`; `_n` is not used by this backend.
fn min_f64(&self, a: &GpuBufferHandle, _n: usize) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_reduce_min_f64(src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Full reduction of `a` via `gpu_reduce_max_f64`; `_n` is not used by this backend.
fn max_f64(&self, a: &GpuBufferHandle, _n: usize) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_reduce_max_f64(src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Masked reduction via `gpu_masked_reduce_min_f64`; the mask is passed as an f64
/// buffer. Runs on `data`'s device; `_len` is not used by this backend.
fn masked_min_f64(
    &self,
    data: &GpuBufferHandle,
    mask_f: &GpuBufferHandle,
    _len: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = data.device_ordinal();
    let values = Self::unwrap_buffer_f64(data)?;
    let mask = Self::unwrap_buffer_f64(mask_f)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_masked_reduce_min_f64(values, mask, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Masked reduction via `gpu_masked_reduce_max_f64`; the mask is passed as an f64
/// buffer. Runs on `data`'s device; `_len` is not used by this backend.
fn masked_max_f64(
    &self,
    data: &GpuBufferHandle,
    mask_f: &GpuBufferHandle,
    _len: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = data.device_ordinal();
    let values = Self::unwrap_buffer_f64(data)?;
    let mask = Self::unwrap_buffer_f64(mask_f)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_masked_reduce_max_f64(values, mask, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
fn sum_axis_f64(
&self,
a: &GpuBufferHandle,
shape: &[usize],
axis: usize,
) -> FerrotorchResult<GpuBufferHandle> {
let a_buf = Self::unwrap_buffer_f64(a)?;
let dev = self.device(a.device_ordinal())?;
let outer: usize = shape[..axis].iter().product();
let axis_size = shape[axis];
let inner: usize = shape[axis + 1..].iter().product();
let result = crate::kernels::gpu_sum_axis_f64(a_buf, outer, axis_size, inner, dev)
.map_err(Self::map_gpu_err)?;
Ok(Self::wrap_buffer_f64(result, a.device_ordinal()))
}
/// Runs `gpu_softmax_f64` on `a` with `rows`×`cols` dimensions, on (and for) `a`'s device.
fn softmax_f64(
    &self,
    a: &GpuBufferHandle,
    rows: usize,
    cols: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_softmax_f64(src, rows, cols, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_softmax_backward_f64` on (`grad`, forward `output`) with row width
/// `cols`, using `grad`'s device.
fn softmax_backward_f64(
    &self,
    grad: &GpuBufferHandle,
    output: &GpuBufferHandle,
    cols: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let grad_src = Self::unwrap_buffer_f64(grad)?;
    let output_src = Self::unwrap_buffer_f64(output)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_softmax_backward_f64(grad_src, output_src, cols, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_log_softmax_f64` on `a` with row width `cols`, on (and for) `a`'s device.
fn log_softmax_f64(
    &self,
    a: &GpuBufferHandle,
    cols: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer_f64(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_log_softmax_f64(src, cols, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_log_softmax_backward_f64` on (`grad`, forward `output`) with row
/// width `cols`, using `grad`'s device.
fn log_softmax_backward_f64(
    &self,
    grad: &GpuBufferHandle,
    output: &GpuBufferHandle,
    cols: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let grad_src = Self::unwrap_buffer_f64(grad)?;
    let output_src = Self::unwrap_buffer_f64(output)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_log_softmax_backward_f64(grad_src, output_src, cols, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_layernorm_f64` over `input` (`rows`×`cols`) with `weight`, `bias`
/// and `eps`, on (and for) `input`'s device.
fn layernorm_f64(
    &self,
    input: &GpuBufferHandle,
    weight: &GpuBufferHandle,
    bias: &GpuBufferHandle,
    rows: usize,
    cols: usize,
    eps: f64,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let x = Self::unwrap_buffer_f64(input)?;
    let gamma = Self::unwrap_buffer_f64(weight)?;
    let beta = Self::unwrap_buffer_f64(bias)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_layernorm_f64(x, gamma, beta, rows, cols, eps, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_layernorm_backward_f64` and returns the kernel's three gradient
/// buffers (in kernel order), each wrapped on `input`'s device.
fn layernorm_backward_f64(
    &self,
    input: &GpuBufferHandle,
    grad_output: &GpuBufferHandle,
    weight: &GpuBufferHandle,
    rows: usize,
    cols: usize,
    eps: f64,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
    let ordinal = input.device_ordinal();
    let x = Self::unwrap_buffer_f64(input)?;
    let upstream = Self::unwrap_buffer_f64(grad_output)?;
    let gamma = Self::unwrap_buffer_f64(weight)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_layernorm_backward_f64(x, upstream, gamma, rows, cols, eps, gpu)
        .map(|(grad_in, grad_w, grad_b)| {
            (
                Self::wrap_buffer_f64(grad_in, ordinal),
                Self::wrap_buffer_f64(grad_w, ordinal),
                Self::wrap_buffer_f64(grad_b, ordinal),
            )
        })
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_rmsnorm_f64` over `input` (`rows`×`cols`) with `weight` and `eps`,
/// on (and for) `input`'s device.
fn rmsnorm_f64(
    &self,
    input: &GpuBufferHandle,
    weight: &GpuBufferHandle,
    rows: usize,
    cols: usize,
    eps: f64,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let x = Self::unwrap_buffer_f64(input)?;
    let gamma = Self::unwrap_buffer_f64(weight)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_rmsnorm_f64(x, gamma, rows, cols, eps, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_rmsnorm_backward_f64` and returns the kernel's two gradient buffers
/// (in kernel order), each wrapped on `input`'s device.
fn rmsnorm_backward_f64(
    &self,
    input: &GpuBufferHandle,
    grad_output: &GpuBufferHandle,
    weight: &GpuBufferHandle,
    rows: usize,
    cols: usize,
    eps: f64,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
    let ordinal = input.device_ordinal();
    let x = Self::unwrap_buffer_f64(input)?;
    let upstream = Self::unwrap_buffer_f64(grad_output)?;
    let gamma = Self::unwrap_buffer_f64(weight)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_rmsnorm_backward_f64(x, upstream, gamma, rows, cols, eps, gpu)
        .map(|(grad_in, grad_w)| {
            (
                Self::wrap_buffer_f64(grad_in, ordinal),
                Self::wrap_buffer_f64(grad_w, ordinal),
            )
        })
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_embed_lookup_f64` with the index held in an f32 buffer (`idx`) and
/// an f64 `weight` table of row width `d`. Executes on `idx`'s device.
fn embed_lookup_f64(
    &self,
    idx: &GpuBufferHandle,
    weight: &GpuBufferHandle,
    d: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = idx.device_ordinal();
    let index_src = Self::unwrap_buffer(idx)?;
    let table = Self::unwrap_buffer_f64(weight)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_embed_lookup_f64(index_src, table, d, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_embed_lookup_batch_f64` for `n` indices (f32 buffer) against an f64
/// `weight` table of row width `d`. Executes on `indices`' device.
fn embed_lookup_batch_f64(
    &self,
    indices: &GpuBufferHandle,
    weight: &GpuBufferHandle,
    n: usize,
    d: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = indices.device_ordinal();
    let index_src = Self::unwrap_buffer(indices)?;
    let table = Self::unwrap_buffer_f64(weight)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_embed_lookup_batch_f64(index_src, table, n, d, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_scatter_add_rows_f64` with the f64 `grad_output` and f32 `indices`
/// into `num_embeddings` rows of width `d`. Executes on `grad_output`'s device.
fn scatter_add_rows_f64(
    &self,
    grad_output: &GpuBufferHandle,
    indices: &GpuBufferHandle,
    num_embeddings: usize,
    d: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad_output.device_ordinal();
    let upstream = Self::unwrap_buffer_f64(grad_output)?;
    let index_src = Self::unwrap_buffer(indices)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_scatter_add_rows_f64(upstream, index_src, num_embeddings, d, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_masked_fill_f64` on `input` with `value`, using `mask` as the selector.
///
/// `mask` arrives as an f32 buffer. It is copied to the host, turned into a byte
/// mask (non-zero → 1, zero → 0) and uploaded again as the kernel's `u8` mask.
/// Runs on `input`'s device.
fn masked_fill_f64(
    &self,
    input: &GpuBufferHandle,
    mask: &GpuBufferHandle,
    value: f64,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let source = Self::unwrap_buffer_f64(input)?;
    let float_mask = Self::unwrap_buffer(mask)?;
    let gpu = self.device(ordinal)?;
    let host_mask = crate::transfer::gpu_to_cpu(float_mask, gpu).map_err(Self::map_gpu_err)?;
    let byte_mask: Vec<u8> = host_mask.iter().map(|&m| u8::from(m != 0.0)).collect();
    let device_mask = crate::transfer::cpu_to_gpu(&byte_mask, gpu).map_err(Self::map_gpu_err)?;
    crate::kernels::gpu_masked_fill_f64(source, &device_mask, value, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_masked_zero_f64` on `grad`, using `mask` as the selector.
///
/// `mask` arrives as an f32 buffer. It is copied to the host, turned into a byte
/// mask (non-zero → 1, zero → 0) and uploaded again as the kernel's `u8` mask.
/// Runs on `grad`'s device.
fn masked_zero_f64(
    &self,
    grad: &GpuBufferHandle,
    mask: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let source = Self::unwrap_buffer_f64(grad)?;
    let float_mask = Self::unwrap_buffer(mask)?;
    let gpu = self.device(ordinal)?;
    let host_mask = crate::transfer::gpu_to_cpu(float_mask, gpu).map_err(Self::map_gpu_err)?;
    let byte_mask: Vec<u8> = host_mask.iter().map(|&m| u8::from(m != 0.0)).collect();
    let device_mask = crate::transfer::cpu_to_gpu(&byte_mask, gpu).map_err(Self::map_gpu_err)?;
    crate::kernels::gpu_masked_zero_f64(source, &device_mask, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Writes `src` into `dst` in place via `gpu_slice_write_f64`, with
/// (`n_batch`, `d`, `max_len`, `pos`). The kernel runs on `src`'s device.
fn slice_write_f64(
    &self,
    src: &GpuBufferHandle,
    dst: &mut GpuBufferHandle,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos: usize,
) -> FerrotorchResult<()> {
    let ordinal = src.device_ordinal();
    let source = Self::unwrap_buffer_f64(src)?;
    let target = Self::unwrap_buffer_f64_mut(dst)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_slice_write_f64(source, target, n_batch, d, max_len, pos, gpu)
        .map_err(Self::map_gpu_err)?;
    Ok(())
}
/// Reads a slice of `src` via `gpu_slice_read_f64` with (`n_batch`, `d`, `len`,
/// `max_len`), returning a new handle on `src`'s device.
fn slice_read_f64(
    &self,
    src: &GpuBufferHandle,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = src.device_ordinal();
    let source = Self::unwrap_buffer_f64(src)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_slice_read_f64(source, n_batch, d, len, max_len, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Extracts one split of `input` via `gpu_strided_split_f64` and returns it as a
/// new handle on `input`'s device.
fn strided_split_f64(
    &self,
    input: &GpuBufferHandle,
    total_along_axis: usize,
    split_offset: usize,
    split_size: usize,
    inner_size: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let source = Self::unwrap_buffer_f64(input)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_strided_split_f64(
        source,
        total_along_axis,
        split_offset,
        split_size,
        inner_size,
        n,
        gpu,
    )
    .map(|out| Self::wrap_buffer_f64(out, ordinal))
    .map_err(Self::map_gpu_err)
}
/// Runs `gpu_index_select_1d_f64` over the f64 `input` with f32 `indices`, on
/// (and for) `input`'s device.
fn index_select_1d_f64(
    &self,
    input: &GpuBufferHandle,
    indices: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let source = Self::unwrap_buffer_f64(input)?;
    let index_src = Self::unwrap_buffer(indices)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_index_select_1d_f64(source, index_src, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_scatter_add_1d_f64` with the f64 `grad_output` and f32 `indices`
/// into an output of length `input_len`, on `grad_output`'s device.
fn scatter_add_1d_f64(
    &self,
    grad_output: &GpuBufferHandle,
    indices: &GpuBufferHandle,
    input_len: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad_output.device_ordinal();
    let upstream = Self::unwrap_buffer_f64(grad_output)?;
    let index_src = Self::unwrap_buffer(indices)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_scatter_add_1d_f64(upstream, index_src, input_len, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_index_select_dim_f64` over `input` viewed as
/// (`outer`, `in_dim_size`, `inner`), producing `out_dim_size` along the selected
/// dimension. Executes on `input`'s device.
fn index_select_dim_f64(
    &self,
    input: &GpuBufferHandle,
    indices: &GpuBufferHandle,
    outer: usize,
    in_dim_size: usize,
    out_dim_size: usize,
    inner: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let source = Self::unwrap_buffer_f64(input)?;
    let index_src = Self::unwrap_buffer(indices)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_index_select_dim_f64(
        source,
        index_src,
        outer,
        in_dim_size,
        out_dim_size,
        inner,
        gpu,
    )
    .map(|out| Self::wrap_buffer_f64(out, ordinal))
    .map_err(Self::map_gpu_err)
}
/// Batched matmul via `crate::blas::gpu_bmm_f64` with (`batch`, `m`, `k`, `n`),
/// executed on `a`'s device.
fn bmm_f64(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let lhs = Self::unwrap_buffer_f64(a)?;
    let rhs = Self::unwrap_buffer_f64(b)?;
    let gpu = self.device(ordinal)?;
    crate::blas::gpu_bmm_f64(lhs, rhs, batch, m, k, n, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Batched matmul with broadcast leading dimensions via
/// `crate::blas::gpu_broadcast_bmm_f64`, executed on `a`'s device.
fn broadcast_bmm_f64(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    a_lead: &[usize],
    b_lead: &[usize],
    out_lead: &[usize],
    m: usize,
    k: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let lhs = Self::unwrap_buffer_f64(a)?;
    let rhs = Self::unwrap_buffer_f64(b)?;
    let gpu = self.device(ordinal)?;
    crate::blas::gpu_broadcast_bmm_f64(lhs, rhs, a_lead, b_lead, out_lead, m, k, n, gpu)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Single fused Adam step via `gpu_fused_adam`.
///
/// `param`, `exp_avg` and `exp_avg_sq` are updated in place; `grad` is read-only.
/// The kernel runs on `param`'s device.
// The long parameter list mirrors the backend method signature being implemented.
#[allow(clippy::too_many_arguments)]
fn fused_adam_f32(
    &self,
    param: &mut GpuBufferHandle,
    grad: &GpuBufferHandle,
    exp_avg: &mut GpuBufferHandle,
    exp_avg_sq: &mut GpuBufferHandle,
    beta1: f32,
    beta2: f32,
    lr: f32,
    eps: f32,
    bc1: f32,
    bc2: f32,
    weight_decay: f32,
) -> FerrotorchResult<()> {
    let gpu = self.device(param.device_ordinal())?;
    let weights = Self::unwrap_buffer_mut(param)?;
    let gradient = Self::unwrap_buffer(grad)?;
    let first_moment = Self::unwrap_buffer_mut(exp_avg)?;
    let second_moment = Self::unwrap_buffer_mut(exp_avg_sq)?;
    crate::kernels::gpu_fused_adam(
        weights,
        gradient,
        first_moment,
        second_moment,
        beta1,
        beta2,
        lr,
        eps,
        bc1,
        bc2,
        weight_decay,
        gpu,
    )
    .map_err(Self::map_gpu_err)?;
    Ok(())
}
/// 2-D max pooling via `gpu_maxpool2d`. Returns the output handle on `input`'s
/// device together with the shape reported by the kernel.
#[allow(clippy::too_many_arguments)]
fn maxpool2d_f32(
    &self,
    input: &GpuBufferHandle,
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
    let ordinal = input.device_ordinal();
    let source = Self::unwrap_buffer(input)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_maxpool2d(
        source, batch, channels, h_in, w_in, kh, kw, sh, sw, ph, pw, gpu,
    )
    .map(|(pooled, shape)| (Self::wrap_buffer(pooled, ordinal), shape))
    .map_err(Self::map_gpu_err)
}
/// 2-D average pooling via `gpu_avgpool2d`. Returns the output handle on
/// `input`'s device together with the shape reported by the kernel.
#[allow(clippy::too_many_arguments)]
fn avgpool2d_f32(
    &self,
    input: &GpuBufferHandle,
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
    let ordinal = input.device_ordinal();
    let source = Self::unwrap_buffer(input)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_avgpool2d(
        source, batch, channels, h_in, w_in, kh, kw, sh, sw, ph, pw, gpu,
    )
    .map(|(pooled, shape)| (Self::wrap_buffer(pooled, ordinal), shape))
    .map_err(Self::map_gpu_err)
}
/// 2-D convolution via `crate::conv::gpu_conv2d_f32`.
///
/// `bias` is optional. When present it is unwrapped as an f32 buffer like the
/// other operands, and an unwrap failure is propagated. Returns the output handle
/// (on `input`'s device) and the output shape reported by the conv routine.
#[allow(clippy::too_many_arguments)]
fn conv2d_f32(
    &self,
    input: &GpuBufferHandle,
    weight: &GpuBufferHandle,
    bias: Option<&GpuBufferHandle>,
    input_shape: [usize; 4],
    weight_shape: [usize; 4],
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
    groups: usize,
) -> FerrotorchResult<(GpuBufferHandle, [usize; 4])> {
    let ordinal = input.device_ordinal();
    let input_buf = Self::unwrap_buffer(input)?;
    let weight_buf = Self::unwrap_buffer(weight)?;
    // Option<Result<..>> -> Result<Option<..>>: fails only if a bias was given and is invalid.
    let bias_buf = bias.map(Self::unwrap_buffer).transpose()?;
    let dev = self.device(ordinal)?;
    let (out_buf, out_shape) = crate::conv::gpu_conv2d_f32(
        input_buf,
        weight_buf,
        bias_buf,
        input_shape,
        weight_shape,
        stride,
        padding,
        dilation,
        groups,
        dev,
    )
    .map_err(Self::map_gpu_err)?;
    Ok((Self::wrap_buffer(out_buf, ordinal), out_shape))
}
/// One fused GRU cell step via `gpu_fused_gru_forward`.
///
/// Returns the kernel's two output buffers `(hy, workspace)`, both wrapped on
/// `input_gates`' device.
fn fused_gru_cell_f32(
    &self,
    input_gates: &GpuBufferHandle,
    hidden_gates: &GpuBufferHandle,
    bias_ih: &GpuBufferHandle,
    bias_hh: &GpuBufferHandle,
    hx: &GpuBufferHandle,
    hidden_size: usize,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
    let ordinal = input_gates.device_ordinal();
    let gates_in = Self::unwrap_buffer(input_gates)?;
    let gates_hidden = Self::unwrap_buffer(hidden_gates)?;
    let b_in = Self::unwrap_buffer(bias_ih)?;
    let b_hidden = Self::unwrap_buffer(bias_hh)?;
    let h_prev = Self::unwrap_buffer(hx)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_fused_gru_forward(
        gates_in,
        gates_hidden,
        b_in,
        b_hidden,
        h_prev,
        hidden_size,
        gpu,
    )
    .map(|(hy, workspace)| {
        (
            Self::wrap_buffer(hy, ordinal),
            Self::wrap_buffer(workspace, ordinal),
        )
    })
    .map_err(Self::map_gpu_err)
}
/// Calls `synchronize()` on the stream of device `device`.
///
/// A stream failure is reported as `FerrotorchError::InvalidArgument`.
fn synchronize(&self, device: usize) -> FerrotorchResult<()> {
    let gpu = self.device(device)?;
    if let Err(e) = gpu.stream().synchronize() {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("CUDA synchronize failed: {e}"),
        });
    }
    Ok(())
}
fn stream_count(&self, device: usize) -> usize {
crate::stream::StreamPool::pool_size(device)
}
/// Matrix multiply via `crate::blas::gpu_matmul_f32` with dimensions (`m`, `k`,
/// `n`), executed on `a`'s device.
fn matmul_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    m: usize,
    k: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let lhs = Self::unwrap_buffer(a)?;
    let rhs = Self::unwrap_buffer(b)?;
    let gpu = self.device(ordinal)?;
    crate::blas::gpu_matmul_f32(lhs, rhs, m, k, n, gpu)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Full reduction of `a` via `gpu_reduce_sum`; `_len` is not used by this backend.
fn sum_f32(&self, a: &GpuBufferHandle, _len: usize) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_reduce_sum(src, gpu)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Full reduction of `a` via `gpu_reduce_prod`; `_len` is not used by this backend.
fn prod_f32(&self, a: &GpuBufferHandle, _len: usize) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_reduce_prod(src, gpu)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Runs `gpu_prod_backward_f32` on (`input`, `grad_output`) using `input`'s
/// device; the result is wrapped on that same ordinal.
fn prod_backward_f32(
    &self,
    input: &GpuBufferHandle,
    grad_output: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let input_src = Self::unwrap_buffer(input)?;
    let upstream = Self::unwrap_buffer(grad_output)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_prod_backward_f32(input_src, upstream, gpu)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Full reduction of `a` via `gpu_reduce_min`; `_len` is not used by this backend.
fn min_f32(&self, a: &GpuBufferHandle, _len: usize) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_reduce_min(src, gpu)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Full reduction of `a` via `gpu_reduce_max`; `_len` is not used by this backend.
fn max_f32(&self, a: &GpuBufferHandle, _len: usize) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let src = Self::unwrap_buffer(a)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_reduce_max(src, gpu)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Masked reduction via `gpu_masked_reduce_min`, with the mask as an f32 buffer.
/// Runs on `data`'s device; `_len` is not used by this backend.
fn masked_min_f32(
    &self,
    data: &GpuBufferHandle,
    mask_f: &GpuBufferHandle,
    _len: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = data.device_ordinal();
    let values = Self::unwrap_buffer(data)?;
    let mask = Self::unwrap_buffer(mask_f)?;
    let gpu = self.device(ordinal)?;
    crate::kernels::gpu_masked_reduce_min(values, mask, gpu)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Masked max reduction via `kernels::gpu_masked_reduce_max`.
/// `mask_f` is a float mask buffer; `_len` is unused.
fn masked_max_f32(
    &self,
    data: &GpuBufferHandle,
    mask_f: &GpuBufferHandle,
    _len: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = data.device_ordinal();
    let values = Self::unwrap_buffer(data)?;
    let mask = Self::unwrap_buffer(mask_f)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_masked_reduce_max(values, mask, device)
        .map(|reduced| Self::wrap_buffer(reduced, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 matrix multiply through `blas::gpu_matmul_f64` with dimensions `m`, `k`, `n`.
fn matmul_f64(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    m: usize,
    k: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_f64(a)?, Self::unwrap_buffer_f64(b)?);
    let device = self.device(ordinal)?;
    crate::blas::gpu_matmul_f64(lhs, rhs, m, k, n, device)
        .map(|product| Self::wrap_buffer_f64(product, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f32 dot product of `n` elements through `blas::gpu_dot_f32`.
fn dot_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer(a)?, Self::unwrap_buffer(b)?);
    let device = self.device(ordinal)?;
    crate::blas::gpu_dot_f32(lhs, rhs, n, device)
        .map(|scalar| Self::wrap_buffer(scalar, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 dot product of `n` elements through `blas::gpu_dot_f64`.
fn dot_f64(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_f64(a)?, Self::unwrap_buffer_f64(b)?);
    let device = self.device(ordinal)?;
    crate::blas::gpu_dot_f64(lhs, rhs, n, device)
        .map(|scalar| Self::wrap_buffer_f64(scalar, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f32 matrix-vector product through `blas::gpu_mv_f32` (matrix `a`, vector `x`).
fn mv_f32(
    &self,
    a: &GpuBufferHandle,
    x: &GpuBufferHandle,
    m: usize,
    k: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (matrix, vector) = (Self::unwrap_buffer(a)?, Self::unwrap_buffer(x)?);
    let device = self.device(ordinal)?;
    crate::blas::gpu_mv_f32(matrix, vector, m, k, device)
        .map(|product| Self::wrap_buffer(product, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 matrix-vector product through `blas::gpu_mv_f64` (matrix `a`, vector `x`).
fn mv_f64(
    &self,
    a: &GpuBufferHandle,
    x: &GpuBufferHandle,
    m: usize,
    k: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (matrix, vector) = (Self::unwrap_buffer_f64(a)?, Self::unwrap_buffer_f64(x)?);
    let device = self.device(ordinal)?;
    crate::blas::gpu_mv_f64(matrix, vector, m, k, device)
        .map(|product| Self::wrap_buffer_f64(product, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f32 vector-matrix product through `blas::gpu_vm_f32`; the device comes from `x`.
fn vm_f32(
    &self,
    x: &GpuBufferHandle,
    b: &GpuBufferHandle,
    k: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = x.device_ordinal();
    let (vector, matrix) = (Self::unwrap_buffer(x)?, Self::unwrap_buffer(b)?);
    let device = self.device(ordinal)?;
    crate::blas::gpu_vm_f32(vector, matrix, k, n, device)
        .map(|product| Self::wrap_buffer(product, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 vector-matrix product through `blas::gpu_vm_f64`; the device comes from `x`.
fn vm_f64(
    &self,
    x: &GpuBufferHandle,
    b: &GpuBufferHandle,
    k: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = x.device_ordinal();
    let (vector, matrix) = (Self::unwrap_buffer_f64(x)?, Self::unwrap_buffer_f64(b)?);
    let device = self.device(ordinal)?;
    crate::blas::gpu_vm_f64(vector, matrix, k, n, device)
        .map(|product| Self::wrap_buffer_f64(product, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Broadcasting element-wise add via `kernels::gpu_broadcast_add`.
/// The operand shapes and the output shape are passed straight through.
fn broadcast_add_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer(a)?, Self::unwrap_buffer(b)?);
    let device = self.device(ordinal)?;
    crate::kernels::gpu_broadcast_add(lhs, rhs, a_shape, b_shape, out_shape, device)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Broadcasting element-wise subtract via `kernels::gpu_broadcast_sub`.
/// The operand shapes and the output shape are passed straight through.
fn broadcast_sub_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer(a)?, Self::unwrap_buffer(b)?);
    let device = self.device(ordinal)?;
    crate::kernels::gpu_broadcast_sub(lhs, rhs, a_shape, b_shape, out_shape, device)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Broadcasting element-wise multiply via `kernels::gpu_broadcast_mul`.
/// The operand shapes and the output shape are passed straight through.
fn broadcast_mul_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer(a)?, Self::unwrap_buffer(b)?);
    let device = self.device(ordinal)?;
    crate::kernels::gpu_broadcast_mul(lhs, rhs, a_shape, b_shape, out_shape, device)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Broadcasting element-wise divide via `kernels::gpu_broadcast_div`.
/// The operand shapes and the output shape are passed straight through.
fn broadcast_div_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer(a)?, Self::unwrap_buffer(b)?);
    let device = self.device(ordinal)?;
    crate::kernels::gpu_broadcast_div(lhs, rhs, a_shape, b_shape, out_shape, device)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Softmax over a `rows` x `cols` buffer via `kernels::gpu_softmax`.
fn softmax_f32(
    &self,
    a: &GpuBufferHandle,
    rows: usize,
    cols: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let logits = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_softmax(logits, rows, cols, device)
        .map(|probs| Self::wrap_buffer(probs, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Seeded f32 dropout via `kernels::gpu_dropout`.
/// `threshold`, `scale` and `seed` are forwarded unchanged.
fn dropout_f32(
    &self,
    a: &GpuBufferHandle,
    threshold: u32,
    scale: f32,
    seed: u32,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_dropout(input, threshold, scale, seed, device)
        .map(|dropped| Self::wrap_buffer(dropped, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f32 dropout driven by the per-device Philox generator from `crate::rng`.
///
/// Snapshots the generator state for `a`'s device and advances it by
/// `ceil(len / 4)` counters. It then runs `kernels::gpu_dropout` with a seed
/// derived from that snapshot. Returns the output buffer together with the
/// *pre-advance* state as a `GpuRngState`.
///
/// # Errors
/// - Fails if `a` is not an f32 CUDA buffer, if its device ordinal is unknown,
///   if the RNG manager mutex is poisoned, or if the kernel launch fails.
/// - Input validation happens before the RNG is touched, so a rejected handle
///   does not consume Philox counters.
fn dropout_philox_f32(
    &self,
    a: &GpuBufferHandle,
    threshold: u32,
    scale: f32,
) -> FerrotorchResult<(GpuBufferHandle, GpuRngState)> {
    let device_ordinal = a.device_ordinal();
    let n = a.len();
    // Validate the handle and device first: advancing the shared generator
    // for a call that is then rejected would silently shift every later
    // random stream on this device.
    let a_buf = Self::unwrap_buffer(a)?;
    let dev = self.device(device_ordinal)?;
    let rng_state = {
        let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
            FerrotorchError::InvalidArgument {
                message: "failed to lock CUDA RNG manager".into(),
            }
        })?;
        let philox_gen = mgr.generator(device_ordinal);
        let state = philox_gen.get_state();
        // One counter is reserved per 4 elements; usize -> u64 is lossless
        // on every supported target.
        let counters_needed = n.div_ceil(4);
        philox_gen.advance(counters_needed as u64);
        state
        // The manager lock is released here, before the kernel launch.
    };
    // NOTE(review): only the low 32 bits of `counter ^ seed` reach the kernel.
    let derived_seed = (rng_state.counter ^ rng_state.seed) as u32;
    let result = crate::kernels::gpu_dropout(a_buf, threshold, scale, derived_seed, dev)
        .map_err(Self::map_gpu_err)?;
    let gpu_rng_state = GpuRngState::new(
        rng_state.counter,
        rng_state.seed,
        rng_state.offset,
        device_ordinal,
    );
    Ok((Self::wrap_buffer(result, device_ordinal), gpu_rng_state))
}
/// Seeded f64 dropout via `kernels::gpu_dropout_f64`.
/// `threshold`, `scale` and `seed` are forwarded unchanged.
fn dropout_f64(
    &self,
    a: &GpuBufferHandle,
    threshold: u32,
    scale: f64,
    seed: u32,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f64(a)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_dropout_f64(input, threshold, scale, seed, device)
        .map(|dropped| Self::wrap_buffer_f64(dropped, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 dropout driven by the per-device Philox generator from `crate::rng`.
///
/// Snapshots the generator state for `a`'s device and advances it by
/// `ceil(len / 4)` counters. It then runs `kernels::gpu_dropout_f64` with a
/// seed derived from that snapshot. Returns the output buffer together with
/// the *pre-advance* state as a `GpuRngState`.
///
/// # Errors
/// - Fails if `a` is not an f64 CUDA buffer, if its device ordinal is unknown,
///   if the RNG manager mutex is poisoned, or if the kernel launch fails.
/// - Input validation happens before the RNG is touched, so a rejected handle
///   does not consume Philox counters.
fn dropout_philox_f64(
    &self,
    a: &GpuBufferHandle,
    threshold: u32,
    scale: f64,
) -> FerrotorchResult<(GpuBufferHandle, GpuRngState)> {
    let device_ordinal = a.device_ordinal();
    let n = a.len();
    // Validate the handle and device first: advancing the shared generator
    // for a call that is then rejected would silently shift every later
    // random stream on this device.
    let a_buf = Self::unwrap_buffer_f64(a)?;
    let dev = self.device(device_ordinal)?;
    let rng_state = {
        let mut mgr = crate::rng::cuda_rng_manager().lock().map_err(|_| {
            FerrotorchError::InvalidArgument {
                message: "failed to lock CUDA RNG manager".into(),
            }
        })?;
        let philox_gen = mgr.generator(device_ordinal);
        let state = philox_gen.get_state();
        // One counter is reserved per 4 elements; usize -> u64 is lossless
        // on every supported target.
        let counters_needed = n.div_ceil(4);
        philox_gen.advance(counters_needed as u64);
        state
        // The manager lock is released here, before the kernel launch.
    };
    // NOTE(review): only the low 32 bits of `counter ^ seed` reach the kernel.
    let derived_seed = (rng_state.counter ^ rng_state.seed) as u32;
    let result = crate::kernels::gpu_dropout_f64(a_buf, threshold, scale, derived_seed, dev)
        .map_err(Self::map_gpu_err)?;
    let gpu_rng_state = GpuRngState::new(
        rng_state.counter,
        rng_state.seed,
        rng_state.offset,
        device_ordinal,
    );
    Ok((Self::wrap_buffer_f64(result, device_ordinal), gpu_rng_state))
}
/// Transposes an `m` x `n` buffer via `kernels::gpu_transpose_2d`.
fn transpose_2d_f32(
    &self,
    a: &GpuBufferHandle,
    m: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_transpose_2d(input, m, n, device)
        .map(|transposed| Self::wrap_buffer(transposed, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Axis permutation (0, 2, 1, 3) of a `d0 x d1 x d2 x d3` buffer via
/// `kernels::gpu_permute_0213`.
fn permute_0213_f32(
    &self,
    a: &GpuBufferHandle,
    d0: usize,
    d1: usize,
    d2: usize,
    d3: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_permute_0213(input, d0, d1, d2, d3, device)
        .map(|permuted| Self::wrap_buffer(permuted, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Batched f32 matmul (`batch` products of `m x k` by `k x n`) via `blas::gpu_bmm_f32`.
fn bmm_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer(a)?, Self::unwrap_buffer(b)?);
    let device = self.device(ordinal)?;
    crate::blas::gpu_bmm_f32(lhs, rhs, batch, m, k, n, device)
        .map(|product| Self::wrap_buffer(product, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Batched f32 matmul whose leading (batch) dimensions broadcast, via
/// `blas::gpu_broadcast_bmm_f32`.
fn broadcast_bmm_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    a_lead: &[usize],
    b_lead: &[usize],
    out_lead: &[usize],
    m: usize,
    k: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer(a)?, Self::unwrap_buffer(b)?);
    let device = self.device(ordinal)?;
    crate::blas::gpu_broadcast_bmm_f32(lhs, rhs, a_lead, b_lead, out_lead, m, k, n, device)
        .map(|product| Self::wrap_buffer(product, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Batched matmul through `blas::gpu_bmm_f16` on buffers unwrapped with the
/// f32 accessor.
fn bmm_f16_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer(a)?, Self::unwrap_buffer(b)?);
    let device = self.device(ordinal)?;
    crate::blas::gpu_bmm_f16(lhs, rhs, batch, m, k, n, device)
        .map(|product| Self::wrap_buffer(product, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Matmul through `blas::gpu_matmul_bf16` on buffers unwrapped with the f32
/// accessor.
fn matmul_bf16_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    m: usize,
    k: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer(a)?, Self::unwrap_buffer(b)?);
    let device = self.device(ordinal)?;
    crate::blas::gpu_matmul_bf16(lhs, rhs, m, k, n, device)
        .map(|product| Self::wrap_buffer(product, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Batched matmul through `blas::gpu_bmm_bf16` on buffers unwrapped with the
/// f32 accessor.
fn bmm_bf16_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer(a)?, Self::unwrap_buffer(b)?);
    let device = self.device(ordinal)?;
    crate::blas::gpu_bmm_bf16(lhs, rhs, batch, m, k, n, device)
        .map(|product| Self::wrap_buffer(product, ordinal))
        .map_err(Self::map_gpu_err)
}
/// bf16-in / bf16-out matmul through `blas::gpu_matmul_bf16_bf16`.
fn matmul_bf16_bf16(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    m: usize,
    k: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_bf16(a)?, Self::unwrap_buffer_bf16(b)?);
    let device = self.device(ordinal)?;
    crate::blas::gpu_matmul_bf16_bf16(lhs, rhs, m, k, n, device)
        .map(|product| Self::wrap_buffer_bf16(product, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Batched bf16 matmul as a strided-batched GEMM.
///
/// Each batch entry of `a` is `m * k` elements and each entry of `b` is
/// `k * n`; these become the strides, and the scalar passed is `1.0`.
fn bmm_bf16_bf16(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_bf16(a)?, Self::unwrap_buffer_bf16(b)?);
    let device = self.device(ordinal)?;
    let (a_stride, b_stride) = (m * k, k * n);
    crate::blas::gpu_matmul_bf16_bf16_strided_batched(
        lhs, rhs, m, k, n, batch, a_stride, b_stride, 1.0, device,
    )
    .map(|product| Self::wrap_buffer_bf16(product, ordinal))
    .map_err(Self::map_gpu_err)
}
/// Softmax over a `rows` x `cols` bf16 buffer via `kernels::gpu_softmax_bf16_f32`.
///
/// The input handle must hold a `CudaSlice<u16>` (bf16 bit patterns); the
/// result is wrapped with `Self::wrap_buffer` on the input's device ordinal.
///
/// # Errors
/// `InvalidArgument` if the handle holds a different buffer type; kernel
/// failures are mapped through `map_gpu_err`.
fn softmax_bf16_f32(
    &self,
    a: &GpuBufferHandle,
    rows: usize,
    cols: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    // `ok_or_else` builds (and allocates) the error only on the failure path.
    let buf = a
        .downcast_ref::<cudarc::driver::CudaSlice<u16>>()
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "softmax_bf16_f32: GPU handle does not contain a CudaSlice<u16> (bf16)"
                .into(),
        })?;
    let dev = self.device(a.device_ordinal())?;
    let result = crate::kernels::gpu_softmax_bf16_f32(buf, rows, cols, dev)
        .map_err(Self::map_gpu_err)?;
    Ok(Self::wrap_buffer(result, a.device_ordinal()))
}
/// Element-wise add of two bf16 buffers (`CudaSlice<u16>`) of `n` elements via
/// `kernels::gpu_add_bf16_f32`; the result is wrapped with `Self::wrap_buffer`.
///
/// # Errors
/// `InvalidArgument` if either handle is not a `CudaSlice<u16>`; kernel
/// failures are mapped through `map_gpu_err`.
fn add_bf16_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    // Lazy error construction: no String allocation on the success path.
    let a_buf = a
        .downcast_ref::<cudarc::driver::CudaSlice<u16>>()
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "add_bf16_f32: handle `a` does not contain CudaSlice<u16> (bf16)".into(),
        })?;
    let b_buf = b
        .downcast_ref::<cudarc::driver::CudaSlice<u16>>()
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "add_bf16_f32: handle `b` does not contain CudaSlice<u16> (bf16)".into(),
        })?;
    let dev = self.device(a.device_ordinal())?;
    let result =
        crate::kernels::gpu_add_bf16_f32(a_buf, b_buf, n, dev).map_err(Self::map_gpu_err)?;
    Ok(Self::wrap_buffer(result, a.device_ordinal()))
}
/// Element-wise subtract of two bf16 buffers (`CudaSlice<u16>`) of `n` elements
/// via `kernels::gpu_sub_bf16_f32`; the result is wrapped with `Self::wrap_buffer`.
///
/// # Errors
/// `InvalidArgument` if either handle is not a `CudaSlice<u16>`; kernel
/// failures are mapped through `map_gpu_err`.
fn sub_bf16_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    // Lazy error construction: no String allocation on the success path.
    let a_buf = a
        .downcast_ref::<cudarc::driver::CudaSlice<u16>>()
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "sub_bf16_f32: handle `a` does not contain CudaSlice<u16> (bf16)".into(),
        })?;
    let b_buf = b
        .downcast_ref::<cudarc::driver::CudaSlice<u16>>()
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "sub_bf16_f32: handle `b` does not contain CudaSlice<u16> (bf16)".into(),
        })?;
    let dev = self.device(a.device_ordinal())?;
    let result =
        crate::kernels::gpu_sub_bf16_f32(a_buf, b_buf, n, dev).map_err(Self::map_gpu_err)?;
    Ok(Self::wrap_buffer(result, a.device_ordinal()))
}
/// Element-wise multiply of two bf16 buffers (`CudaSlice<u16>`) of `n` elements
/// via `kernels::gpu_mul_bf16_f32`; the result is wrapped with `Self::wrap_buffer`.
///
/// # Errors
/// `InvalidArgument` if either handle is not a `CudaSlice<u16>`; kernel
/// failures are mapped through `map_gpu_err`.
fn mul_bf16_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    // Lazy error construction: no String allocation on the success path.
    let a_buf = a
        .downcast_ref::<cudarc::driver::CudaSlice<u16>>()
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "mul_bf16_f32: handle `a` does not contain CudaSlice<u16> (bf16)".into(),
        })?;
    let b_buf = b
        .downcast_ref::<cudarc::driver::CudaSlice<u16>>()
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "mul_bf16_f32: handle `b` does not contain CudaSlice<u16> (bf16)".into(),
        })?;
    let dev = self.device(a.device_ordinal())?;
    let result =
        crate::kernels::gpu_mul_bf16_f32(a_buf, b_buf, n, dev).map_err(Self::map_gpu_err)?;
    Ok(Self::wrap_buffer(result, a.device_ordinal()))
}
/// Element-wise divide of two bf16 buffers (`CudaSlice<u16>`) of `n` elements
/// via `kernels::gpu_div_bf16_f32`; the result is wrapped with `Self::wrap_buffer`.
///
/// # Errors
/// `InvalidArgument` if either handle is not a `CudaSlice<u16>`; kernel
/// failures are mapped through `map_gpu_err`.
fn div_bf16_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    // Lazy error construction: no String allocation on the success path.
    let a_buf = a
        .downcast_ref::<cudarc::driver::CudaSlice<u16>>()
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "div_bf16_f32: handle `a` does not contain CudaSlice<u16> (bf16)".into(),
        })?;
    let b_buf = b
        .downcast_ref::<cudarc::driver::CudaSlice<u16>>()
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "div_bf16_f32: handle `b` does not contain CudaSlice<u16> (bf16)".into(),
        })?;
    let dev = self.device(a.device_ordinal())?;
    let result =
        crate::kernels::gpu_div_bf16_f32(a_buf, b_buf, n, dev).map_err(Self::map_gpu_err)?;
    Ok(Self::wrap_buffer(result, a.device_ordinal()))
}
/// Sum along one axis of a bf16 buffer (`CudaSlice<u16>`) via
/// `kernels::gpu_sum_axis_bf16_f32`.
///
/// The layout is given as `outer`, `axis_size` and `inner`. The result is
/// wrapped with `Self::wrap_buffer`.
///
/// # Errors
/// `InvalidArgument` if the handle is not a `CudaSlice<u16>`; kernel failures
/// are mapped through `map_gpu_err`.
fn sum_axis_bf16_f32(
    &self,
    a: &GpuBufferHandle,
    outer: usize,
    axis_size: usize,
    inner: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    // Lazy error construction: no String allocation on the success path.
    let a_buf = a
        .downcast_ref::<cudarc::driver::CudaSlice<u16>>()
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "sum_axis_bf16_f32: handle does not contain CudaSlice<u16> (bf16)".into(),
        })?;
    let dev = self.device(a.device_ordinal())?;
    let result = crate::kernels::gpu_sum_axis_bf16_f32(a_buf, outer, axis_size, inner, dev)
        .map_err(Self::map_gpu_err)?;
    Ok(Self::wrap_buffer(result, a.device_ordinal()))
}
/// Mean along one axis of a bf16 buffer (`CudaSlice<u16>`) via
/// `kernels::gpu_mean_axis_bf16_f32`.
///
/// The layout is given as `outer`, `axis_size` and `inner`. The result is
/// wrapped with `Self::wrap_buffer`.
///
/// # Errors
/// `InvalidArgument` if the handle is not a `CudaSlice<u16>`; kernel failures
/// are mapped through `map_gpu_err`.
fn mean_axis_bf16_f32(
    &self,
    a: &GpuBufferHandle,
    outer: usize,
    axis_size: usize,
    inner: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    // Lazy error construction: no String allocation on the success path.
    let a_buf = a
        .downcast_ref::<cudarc::driver::CudaSlice<u16>>()
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "mean_axis_bf16_f32: handle does not contain CudaSlice<u16> (bf16)".into(),
        })?;
    let dev = self.device(a.device_ordinal())?;
    let result = crate::kernels::gpu_mean_axis_bf16_f32(a_buf, outer, axis_size, inner, dev)
        .map_err(Self::map_gpu_err)?;
    Ok(Self::wrap_buffer(result, a.device_ordinal()))
}
/// ReLU over `n` bf16 elements (`CudaSlice<u16>`) via `kernels::gpu_relu_bf16_f32`.
/// The result is wrapped with `Self::wrap_buffer`.
///
/// # Errors
/// `InvalidArgument` if the handle is not a `CudaSlice<u16>`; kernel failures
/// are mapped through `map_gpu_err`.
fn relu_bf16_f32(&self, a: &GpuBufferHandle, n: usize) -> FerrotorchResult<GpuBufferHandle> {
    // Lazy error construction: no String allocation on the success path.
    let a_buf = a
        .downcast_ref::<cudarc::driver::CudaSlice<u16>>()
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "relu_bf16_f32: handle does not contain CudaSlice<u16> (bf16)".into(),
        })?;
    let dev = self.device(a.device_ordinal())?;
    let result = crate::kernels::gpu_relu_bf16_f32(a_buf, n, dev).map_err(Self::map_gpu_err)?;
    Ok(Self::wrap_buffer(result, a.device_ordinal()))
}
/// Sigmoid over `n` bf16 elements (`CudaSlice<u16>`) via
/// `kernels::gpu_sigmoid_bf16_f32`; the result is wrapped with `Self::wrap_buffer`.
///
/// # Errors
/// `InvalidArgument` if the handle is not a `CudaSlice<u16>`; kernel failures
/// are mapped through `map_gpu_err`.
fn sigmoid_bf16_f32(&self, a: &GpuBufferHandle, n: usize) -> FerrotorchResult<GpuBufferHandle> {
    // Lazy error construction: no String allocation on the success path.
    let a_buf = a
        .downcast_ref::<cudarc::driver::CudaSlice<u16>>()
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "sigmoid_bf16_f32: handle does not contain CudaSlice<u16> (bf16)".into(),
        })?;
    let dev = self.device(a.device_ordinal())?;
    let result =
        crate::kernels::gpu_sigmoid_bf16_f32(a_buf, n, dev).map_err(Self::map_gpu_err)?;
    Ok(Self::wrap_buffer(result, a.device_ordinal()))
}
/// GELU activation via `kernels::gpu_gelu`.
fn gelu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_gelu(input, device)
        .map(|activated| Self::wrap_buffer(activated, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Tanh-approximated GELU via `kernels::gpu_gelu_tanh`.
fn gelu_tanh_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_gelu_tanh(input, device)
        .map(|activated| Self::wrap_buffer(activated, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Erf-based GELU via `kernels::gpu_gelu_erf`.
fn gelu_erf_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_gelu_erf(input, device)
        .map(|activated| Self::wrap_buffer(activated, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Layer normalisation of a `rows` x `cols` buffer with affine `weight` and
/// `bias`, via `kernels::gpu_layernorm`.
fn layernorm_f32(
    &self,
    input: &GpuBufferHandle,
    weight: &GpuBufferHandle,
    bias: &GpuBufferHandle,
    rows: usize,
    cols: usize,
    eps: f32,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let x = Self::unwrap_buffer(input)?;
    let gamma = Self::unwrap_buffer(weight)?;
    let beta = Self::unwrap_buffer(bias)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_layernorm(x, gamma, beta, rows, cols, eps, device)
        .map(|normed| Self::wrap_buffer(normed, ordinal))
        .map_err(Self::map_gpu_err)
}
/// RMS normalisation of a `rows` x `cols` buffer scaled by `weight`, via
/// `kernels::gpu_rmsnorm`.
fn rmsnorm_f32(
    &self,
    input: &GpuBufferHandle,
    weight: &GpuBufferHandle,
    rows: usize,
    cols: usize,
    eps: f32,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let (x, gamma) = (Self::unwrap_buffer(input)?, Self::unwrap_buffer(weight)?);
    let device = self.device(ordinal)?;
    crate::kernels::gpu_rmsnorm(x, gamma, rows, cols, eps, device)
        .map(|normed| Self::wrap_buffer(normed, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Backward pass of RMS norm via `kernels::gpu_rmsnorm_backward`.
///
/// Returns the kernel's two output buffers in order (named `gi` / `gw` by the
/// original code), both wrapped on `input`'s device ordinal.
fn rmsnorm_backward_f32(
    &self,
    input: &GpuBufferHandle,
    grad_output: &GpuBufferHandle,
    weight: &GpuBufferHandle,
    rows: usize,
    cols: usize,
    eps: f32,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
    let ordinal = input.device_ordinal();
    let x = Self::unwrap_buffer(input)?;
    let upstream = Self::unwrap_buffer(grad_output)?;
    let gamma = Self::unwrap_buffer(weight)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_rmsnorm_backward(x, upstream, gamma, rows, cols, eps, device)
        .map(|(grad_in, grad_w)| {
            (
                Self::wrap_buffer(grad_in, ordinal),
                Self::wrap_buffer(grad_w, ordinal),
            )
        })
        .map_err(Self::map_gpu_err)
}
/// Writes `src` into the mutable f32 buffer behind `dst` at position `pos`,
/// via `kernels::gpu_slice_write`.
///
/// The layout arguments `n_batch`, `d` and `max_len` are forwarded unchanged.
/// The device is chosen from `src`'s ordinal.
///
/// # Errors
/// - Fails if `src` is not an f32 CUDA buffer.
/// - `InvalidArgument` if `dst` is not a `CudaBuffer<f32>`.
/// - Kernel failures are mapped through `map_gpu_err`.
fn slice_write_f32(
    &self,
    src: &GpuBufferHandle,
    dst: &mut GpuBufferHandle,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos: usize,
) -> FerrotorchResult<()> {
    let src_buf = Self::unwrap_buffer(src)?;
    // Lazy error construction: no String allocation on the success path.
    let dst_buf = dst.downcast_mut::<CudaBuffer<f32>>().ok_or_else(|| {
        FerrotorchError::InvalidArgument {
            message: "slice_write_f32: dst is not CudaBuffer<f32>".into(),
        }
    })?;
    // NOTE(review): `dst`'s device ordinal is not compared against `src`'s.
    let dev = self.device(src.device_ordinal())?;
    crate::kernels::gpu_slice_write(src_buf, dst_buf, n_batch, d, max_len, pos, dev)
        .map_err(Self::map_gpu_err)?;
    Ok(())
}
/// Reads `len` positions out of `src` via `kernels::gpu_slice_read`.
/// `n_batch`, `d` and `max_len` describe the source layout.
fn slice_read_f32(
    &self,
    src: &GpuBufferHandle,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = src.device_ordinal();
    let source = Self::unwrap_buffer(src)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_slice_read(source, n_batch, d, len, max_len, device)
        .map(|slice| Self::wrap_buffer(slice, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Embedding lookup of `idx` into `weight` (row width `d`) via
/// `kernels::gpu_embed_lookup`; the device comes from `idx`.
fn embed_lookup_f32(
    &self,
    idx: &GpuBufferHandle,
    weight: &GpuBufferHandle,
    d: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = idx.device_ordinal();
    let (index, table) = (Self::unwrap_buffer(idx)?, Self::unwrap_buffer(weight)?);
    let device = self.device(ordinal)?;
    crate::kernels::gpu_embed_lookup(index, table, d, device)
        .map(|rows| Self::wrap_buffer(rows, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Batched embedding lookup of `n` indices into `weight` (row width `d`) via
/// `kernels::gpu_embed_lookup_batch`; the device comes from `indices`.
fn embed_lookup_batch_f32(
    &self,
    indices: &GpuBufferHandle,
    weight: &GpuBufferHandle,
    n: usize,
    d: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = indices.device_ordinal();
    let (index, table) = (Self::unwrap_buffer(indices)?, Self::unwrap_buffer(weight)?);
    let device = self.device(ordinal)?;
    crate::kernels::gpu_embed_lookup_batch(index, table, n, d, device)
        .map(|rows| Self::wrap_buffer(rows, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Scatter-adds rows of `grad_output` into a `num_embeddings` x `d` result at
/// `indices`, via `kernels::gpu_scatter_add_rows`.
fn scatter_add_rows_f32(
    &self,
    grad_output: &GpuBufferHandle,
    indices: &GpuBufferHandle,
    num_embeddings: usize,
    d: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad_output.device_ordinal();
    let (upstream, index) = (
        Self::unwrap_buffer(grad_output)?,
        Self::unwrap_buffer(indices)?,
    );
    let device = self.device(ordinal)?;
    crate::kernels::gpu_scatter_add_rows(upstream, index, num_embeddings, d, device)
        .map(|accum| Self::wrap_buffer(accum, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Scales `a` by `scalar` via `kernels::gpu_scale`.
fn scale_f32(&self, a: &GpuBufferHandle, scalar: f32) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_scale(input, scalar, device)
        .map(|scaled| Self::wrap_buffer(scaled, ordinal))
        .map_err(Self::map_gpu_err)
}
/// ReLU gradient from the upstream `grad` and forward `input`, via
/// `kernels::gpu_relu_backward`; the device comes from `grad`.
fn relu_backward_f32(
    &self,
    grad: &GpuBufferHandle,
    input: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let (upstream, x) = (Self::unwrap_buffer(grad)?, Self::unwrap_buffer(input)?);
    let device = self.device(ordinal)?;
    crate::kernels::gpu_relu_backward(upstream, x, device)
        .map(|dx| Self::wrap_buffer(dx, ordinal))
        .map_err(Self::map_gpu_err)
}
/// `abs` gradient from the upstream `grad` and forward `input`, via
/// `kernels::gpu_abs_backward`; the device comes from `grad`.
fn abs_backward_f32(
    &self,
    grad: &GpuBufferHandle,
    input: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let (upstream, x) = (Self::unwrap_buffer(grad)?, Self::unwrap_buffer(input)?);
    let device = self.device(ordinal)?;
    crate::kernels::gpu_abs_backward(upstream, x, device)
        .map(|dx| Self::wrap_buffer(dx, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Allocates an `n`-element f32 buffer filled with `scalar` on device `ordinal`.
fn fill_f32(&self, n: usize, scalar: f32, ordinal: usize) -> FerrotorchResult<GpuBufferHandle> {
    let device = self.device(ordinal)?;
    crate::kernels::gpu_fill_f32(n, scalar, device)
        .map(|filled| Self::wrap_buffer(filled, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Allocates an `n`-element f64 buffer filled with `scalar` on device `ordinal`.
fn fill_f64(&self, n: usize, scalar: f64, ordinal: usize) -> FerrotorchResult<GpuBufferHandle> {
    let device = self.device(ordinal)?;
    crate::kernels::gpu_fill_f64(n, scalar, device)
        .map(|filled| Self::wrap_buffer_f64(filled, ordinal))
        .map_err(Self::map_gpu_err)
}
/// GELU gradient via `kernels::gpu_gelu_backward`; the device comes from `grad`.
fn gelu_backward_f32(
    &self,
    grad: &GpuBufferHandle,
    input: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let (upstream, x) = (Self::unwrap_buffer(grad)?, Self::unwrap_buffer(input)?);
    let device = self.device(ordinal)?;
    crate::kernels::gpu_gelu_backward(upstream, x, device)
        .map(|dx| Self::wrap_buffer(dx, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Tanh-approximated GELU gradient via `kernels::gpu_gelu_backward_tanh`;
/// the device comes from `grad`.
fn gelu_backward_tanh_f32(
    &self,
    grad: &GpuBufferHandle,
    input: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let (upstream, x) = (Self::unwrap_buffer(grad)?, Self::unwrap_buffer(input)?);
    let device = self.device(ordinal)?;
    crate::kernels::gpu_gelu_backward_tanh(upstream, x, device)
        .map(|dx| Self::wrap_buffer(dx, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Erf-based GELU gradient via `kernels::gpu_gelu_backward_erf`; the device
/// comes from `grad`.
fn gelu_backward_erf_f32(
    &self,
    grad: &GpuBufferHandle,
    input: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let (upstream, x) = (Self::unwrap_buffer(grad)?, Self::unwrap_buffer(input)?);
    let device = self.device(ordinal)?;
    crate::kernels::gpu_gelu_backward_erf(upstream, x, device)
        .map(|dx| Self::wrap_buffer(dx, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Cumulative sum along the `dim_size` axis (layout `outer` x `dim_size` x `inner`)
/// via `kernels::gpu_cumsum`.
fn cumsum_f32(
    &self,
    a: &GpuBufferHandle,
    outer: usize,
    dim_size: usize,
    inner: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_cumsum(input, outer, dim_size, inner, device)
        .map(|scanned| Self::wrap_buffer(scanned, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Cumulative product along the `dim_size` axis (layout `outer` x `dim_size` x
/// `inner`) via `kernels::gpu_cumprod`.
fn cumprod_f32(
    &self,
    a: &GpuBufferHandle,
    outer: usize,
    dim_size: usize,
    inner: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_cumprod(input, outer, dim_size, inner, device)
        .map(|scanned| Self::wrap_buffer(scanned, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Running maximum along the `dim_size` axis via `kernels::gpu_cummax`.
/// Returns `(values, indices)`, both wrapped as f32 buffers on `a`'s device.
fn cummax_f32(
    &self,
    a: &GpuBufferHandle,
    outer: usize,
    dim_size: usize,
    inner: usize,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_cummax(input, outer, dim_size, inner, device)
        .map(|(values, positions)| {
            (
                Self::wrap_buffer(values, ordinal),
                Self::wrap_buffer(positions, ordinal),
            )
        })
        .map_err(Self::map_gpu_err)
}
/// Running minimum along the `dim_size` axis via `kernels::gpu_cummin`.
/// Returns `(values, indices)`, both wrapped as f32 buffers on `a`'s device.
fn cummin_f32(
    &self,
    a: &GpuBufferHandle,
    outer: usize,
    dim_size: usize,
    inner: usize,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_cummin(input, outer, dim_size, inner, device)
        .map(|(values, positions)| {
            (
                Self::wrap_buffer(values, ordinal),
                Self::wrap_buffer(positions, ordinal),
            )
        })
        .map_err(Self::map_gpu_err)
}
/// Cumulative log-sum-exp along the `dim_size` axis via `kernels::gpu_logcumsumexp`.
fn logcumsumexp_f32(
    &self,
    a: &GpuBufferHandle,
    outer: usize,
    dim_size: usize,
    inner: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_logcumsumexp(input, outer, dim_size, inner, device)
        .map(|scanned| Self::wrap_buffer(scanned, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Rolls elements along the `dim_size` axis by `shift_norm` via `roll::gpu_roll_f32`.
/// The name `shift_norm` suggests the caller pre-normalises the shift.
fn roll_f32(
    &self,
    a: &GpuBufferHandle,
    outer: usize,
    dim_size: usize,
    inner: usize,
    shift_norm: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::roll::gpu_roll_f32(input, outer, dim_size, inner, shift_norm, device)
        .map(|rolled| Self::wrap_buffer(rolled, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Clamps `a` into `[min_val, max_val]` via `kernels::gpu_clamp`.
fn clamp_f32(
    &self,
    a: &GpuBufferHandle,
    min_val: f32,
    max_val: f32,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_clamp(input, min_val, max_val, device)
        .map(|clamped| Self::wrap_buffer(clamped, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Clamp gradient via `kernels::gpu_clamp_backward`.
/// Unlike most backward ops here, the device comes from `input`, not `grad`.
fn clamp_backward_f32(
    &self,
    grad: &GpuBufferHandle,
    input: &GpuBufferHandle,
    min_val: f32,
    max_val: f32,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let (upstream, x) = (Self::unwrap_buffer(grad)?, Self::unwrap_buffer(input)?);
    let device = self.device(ordinal)?;
    crate::kernels::gpu_clamp_backward(upstream, x, min_val, max_val, device)
        .map(|dx| Self::wrap_buffer(dx, ordinal))
        .map_err(Self::map_gpu_err)
}
/// SiLU activation via `kernels::gpu_silu`.
fn silu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_silu(input, device)
        .map(|activated| Self::wrap_buffer(activated, ordinal))
        .map_err(Self::map_gpu_err)
}
/// SiLU gradient via `kernels::gpu_silu_backward`; the device comes from `grad`.
fn silu_backward_f32(
    &self,
    grad: &GpuBufferHandle,
    input: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let (upstream, x) = (Self::unwrap_buffer(grad)?, Self::unwrap_buffer(input)?);
    let device = self.device(ordinal)?;
    crate::kernels::gpu_silu_backward(upstream, x, device)
        .map(|dx| Self::wrap_buffer(dx, ordinal))
        .map_err(Self::map_gpu_err)
}
/// ELU activation with parameter `alpha` via `kernels::gpu_elu`.
fn elu_f32(&self, a: &GpuBufferHandle, alpha: f32) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_elu(input, alpha, device)
        .map(|activated| Self::wrap_buffer(activated, ordinal))
        .map_err(Self::map_gpu_err)
}
/// ELU gradient with parameter `alpha` via `kernels::gpu_elu_backward`; the
/// device comes from `grad`.
fn elu_backward_f32(
    &self,
    grad: &GpuBufferHandle,
    input: &GpuBufferHandle,
    alpha: f32,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let (upstream, x) = (Self::unwrap_buffer(grad)?, Self::unwrap_buffer(input)?);
    let device = self.device(ordinal)?;
    crate::kernels::gpu_elu_backward(upstream, x, alpha, device)
        .map(|dx| Self::wrap_buffer(dx, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Mish activation via `kernels::gpu_mish`.
fn mish_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_mish(input, device)
        .map(|activated| Self::wrap_buffer(activated, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Mish gradient via `kernels::gpu_mish_backward`; the device comes from `grad`.
fn mish_backward_f32(
    &self,
    grad: &GpuBufferHandle,
    input: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let (upstream, x) = (Self::unwrap_buffer(grad)?, Self::unwrap_buffer(input)?);
    let device = self.device(ordinal)?;
    crate::kernels::gpu_mish_backward(upstream, x, device)
        .map(|dx| Self::wrap_buffer(dx, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Log-softmax over rows of width `cols` via `kernels::gpu_log_softmax`.
fn log_softmax_f32(
    &self,
    a: &GpuBufferHandle,
    cols: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let logits = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_log_softmax(logits, cols, device)
        .map(|log_probs| Self::wrap_buffer(log_probs, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Log-softmax gradient from upstream `grad` and forward `output`, via
/// `kernels::gpu_log_softmax_backward`; the device comes from `grad`.
fn log_softmax_backward_f32(
    &self,
    grad: &GpuBufferHandle,
    output: &GpuBufferHandle,
    cols: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let (upstream, y) = (Self::unwrap_buffer(grad)?, Self::unwrap_buffer(output)?);
    let device = self.device(ordinal)?;
    crate::kernels::gpu_log_softmax_backward(upstream, y, cols, device)
        .map(|dx| Self::wrap_buffer(dx, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Gathers elements of `input` at `indices` via `kernels::gpu_index_select_1d`.
fn index_select_1d_f32(
    &self,
    input: &GpuBufferHandle,
    indices: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let (source, index) = (Self::unwrap_buffer(input)?, Self::unwrap_buffer(indices)?);
    let device = self.device(ordinal)?;
    crate::kernels::gpu_index_select_1d(source, index, device)
        .map(|gathered| Self::wrap_buffer(gathered, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Scatter-adds `grad_output` into an `input_len`-element result at `indices`,
/// via `kernels::gpu_scatter_add_1d`.
fn scatter_add_1d_f32(
    &self,
    grad_output: &GpuBufferHandle,
    indices: &GpuBufferHandle,
    input_len: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad_output.device_ordinal();
    let (upstream, index) = (
        Self::unwrap_buffer(grad_output)?,
        Self::unwrap_buffer(indices)?,
    );
    let device = self.device(ordinal)?;
    crate::kernels::gpu_scatter_add_1d(upstream, index, input_len, device)
        .map(|accum| Self::wrap_buffer(accum, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Index-select along one dimension of `input` via `kernels::gpu_index_select_dim`.
///
/// The tensor is described by the flattened `(outer, in_dim_size, inner)`
/// geometry. `out_dim_size` is forwarded as-is (presumably the number of
/// selected indices).
fn index_select_dim_f32(
    &self,
    input: &GpuBufferHandle,
    indices: &GpuBufferHandle,
    outer: usize,
    in_dim_size: usize,
    out_dim_size: usize,
    inner: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let src = Self::unwrap_buffer(input)?;
    let idx = Self::unwrap_buffer(indices)?;
    let device = self.device(ordinal)?;
    let selected = crate::kernels::gpu_index_select_dim(
        src,
        idx,
        outer,
        in_dim_size,
        out_dim_size,
        inner,
        device,
    );
    selected
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Masked fill of `input` with `value` via `kernels::gpu_masked_fill`.
///
/// `mask` is unwrapped as an f32 buffer. The result is a new handle on
/// `input`'s device, and `input` itself is only borrowed.
fn masked_fill_f32(
    &self,
    input: &GpuBufferHandle,
    mask: &GpuBufferHandle,
    value: f32,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let src = Self::unwrap_buffer(input)?;
    let selector = Self::unwrap_buffer(mask)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_masked_fill(src, selector, value, device)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Masked zeroing of `grad` via `kernels::gpu_masked_zero`; the result lives on `grad`'s device.
fn masked_zero_f32(
    &self,
    grad: &GpuBufferHandle,
    mask: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let upstream = Self::unwrap_buffer(grad)?;
    let selector = Self::unwrap_buffer(mask)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_masked_zero(upstream, selector, device)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Sigmoid backward via `kernels::gpu_sigmoid_backward`.
///
/// `output` is presumably the forward sigmoid result. The result lives on
/// `grad`'s device.
fn sigmoid_backward_f32(
    &self,
    grad: &GpuBufferHandle,
    output: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let upstream = Self::unwrap_buffer(grad)?;
    let forward_out = Self::unwrap_buffer(output)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_sigmoid_backward(upstream, forward_out, device)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Tanh backward via `kernels::gpu_tanh_backward`.
///
/// `output` is presumably the forward tanh result. The result lives on
/// `grad`'s device.
fn tanh_backward_f32(
    &self,
    grad: &GpuBufferHandle,
    output: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let upstream = Self::unwrap_buffer(grad)?;
    let forward_out = Self::unwrap_buffer(output)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_tanh_backward(upstream, forward_out, device)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Softmax backward via `kernels::gpu_softmax_backward`; `cols` is forwarded unchanged.
///
/// The result lives on `grad`'s device.
fn softmax_backward_f32(
    &self,
    grad: &GpuBufferHandle,
    output: &GpuBufferHandle,
    cols: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = grad.device_ordinal();
    let upstream = Self::unwrap_buffer(grad)?;
    let forward_out = Self::unwrap_buffer(output)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_softmax_backward(upstream, forward_out, cols, device)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// LayerNorm backward via `kernels::gpu_layernorm_backward`.
///
/// Returns the kernel's three outputs in kernel order, each tagged with
/// `input`'s device ordinal. They are presumably grad-input, grad-weight and
/// grad-bias.
fn layernorm_backward_f32(
    &self,
    input: &GpuBufferHandle,
    grad_output: &GpuBufferHandle,
    weight: &GpuBufferHandle,
    rows: usize,
    cols: usize,
    eps: f32,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
    let ordinal = input.device_ordinal();
    let x = Self::unwrap_buffer(input)?;
    let dy = Self::unwrap_buffer(grad_output)?;
    let gamma = Self::unwrap_buffer(weight)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_layernorm_backward(x, dy, gamma, rows, cols, eps, device)
        .map(|(first, second, third)| {
            (
                Self::wrap_buffer(first, ordinal),
                Self::wrap_buffer(second, ordinal),
                Self::wrap_buffer(third, ordinal),
            )
        })
        .map_err(Self::map_gpu_err)
}
fn sum_axis_f32(
&self,
a: &GpuBufferHandle,
shape: &[usize],
axis: usize,
) -> FerrotorchResult<GpuBufferHandle> {
let a_buf = Self::unwrap_buffer(a)?;
let dev = self.device(a.device_ordinal())?;
let outer: usize = shape[..axis].iter().product();
let axis_size = shape[axis];
let inner: usize = shape[axis + 1..].iter().product::<usize>().max(1);
let result = crate::kernels::gpu_sum_axis(a_buf, outer, axis_size, inner, dev)
.map_err(Self::map_gpu_err)?;
Ok(Self::wrap_buffer(result, a.device_ordinal()))
}
/// Matrix multiply of f32 buffers `a` and `b` via `blas::gpu_matmul_f16`.
///
/// `m`, `k`, `n` are forwarded unchanged. The f16 handling happens inside the
/// BLAS helper and is not visible here. The result lives on `a`'s device.
fn matmul_f16_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    m: usize,
    k: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let lhs = Self::unwrap_buffer(a)?;
    let rhs = Self::unwrap_buffer(b)?;
    let device = self.device(ordinal)?;
    crate::blas::gpu_matmul_f16(lhs, rhs, m, k, n, device)
        .map(|product| Self::wrap_buffer(product, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Snapshots the CUDA RNG counter, seed and offset for `device` from the
/// global `rng::cuda_rng_manager`.
///
/// # Errors
/// Returns `InvalidArgument` if the manager's mutex cannot be locked
/// (for example, it is poisoned).
fn save_rng_state(&self, device: usize) -> FerrotorchResult<GpuRngState> {
    let lock_failed = |_| FerrotorchError::InvalidArgument {
        message: "failed to lock CUDA RNG manager".into(),
    };
    let mut manager = crate::rng::cuda_rng_manager().lock().map_err(lock_failed)?;
    let snapshot = manager.get_rng_state(device);
    Ok(GpuRngState::new(
        snapshot.counter,
        snapshot.seed,
        snapshot.offset,
        device,
    ))
}
/// Restores a saved CUDA RNG state.
///
/// The saved counter, seed and offset are rebuilt into a `rng::PhiloxState`
/// and installed for `state.device()`.
///
/// # Errors
/// Returns `InvalidArgument` if the manager's mutex cannot be locked, or a
/// mapped GPU error if `PhiloxState::from_parts` rejects the saved parts.
fn restore_rng_state(&self, state: GpuRngState) -> FerrotorchResult<()> {
    let mut manager = match crate::rng::cuda_rng_manager().lock() {
        Ok(guard) => guard,
        Err(_) => {
            return Err(FerrotorchError::InvalidArgument {
                message: "failed to lock CUDA RNG manager".into(),
            });
        }
    };
    let (counter, seed, offset) = (state.counter(), state.seed(), state.offset());
    let philox =
        crate::rng::PhiloxState::from_parts(counter, seed, offset).map_err(Self::map_gpu_err)?;
    manager.set_rng_state(state.device(), philox);
    Ok(())
}
/// Extracts one split slab from `input` via `kernels::gpu_strided_split`.
///
/// The slab is `split_size` entries starting at `split_offset` out of
/// `total_along_axis`, with `inner_size` and `n` forwarded unchanged. The
/// result lives on `input`'s device.
fn strided_split_f32(
    &self,
    input: &GpuBufferHandle,
    total_along_axis: usize,
    split_offset: usize,
    split_size: usize,
    inner_size: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let src = Self::unwrap_buffer(input)?;
    let device = self.device(ordinal)?;
    let piece = crate::kernels::gpu_strided_split(
        src,
        total_along_axis,
        split_offset,
        split_size,
        inner_size,
        n,
        device,
    );
    piece
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Materializes a strided view of `input` into a new f32 buffer via `kernels::gpu_strided_copy`.
///
/// The view is shape `out_shape` with strides `src_strides`, starting at
/// `src_offset`.
fn strided_copy_f32(
    &self,
    input: &GpuBufferHandle,
    out_shape: &[usize],
    src_strides: &[isize],
    src_offset: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let src = Self::unwrap_buffer(input)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_strided_copy(src, out_shape, src_strides, src_offset, device)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 counterpart of `strided_copy_f32`, backed by `kernels::gpu_strided_copy_f64`.
fn strided_copy_f64(
    &self,
    input: &GpuBufferHandle,
    out_shape: &[usize],
    src_strides: &[isize],
    src_offset: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let src = Self::unwrap_buffer_f64(input)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_strided_copy_f64(src, out_shape, src_strides, src_offset, device)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Writes `src` into the strided view of `dst` via `kernels::gpu_strided_scatter`.
///
/// The view is shape `view_shape` with strides `dst_strides`, starting at
/// `dst_offset`.
///
/// # Errors
/// Returns `DeviceMismatch` if `src` and `dst` are on different CUDA devices.
/// Unwrap, device-lookup and kernel errors are propagated.
fn strided_scatter_f32(
    &self,
    src: &GpuBufferHandle,
    dst: &mut GpuBufferHandle,
    view_shape: &[usize],
    dst_strides: &[isize],
    dst_offset: usize,
) -> FerrotorchResult<()> {
    let ord = src.device_ordinal();
    if dst.device_ordinal() != ord {
        return Err(FerrotorchError::DeviceMismatch {
            expected: ferrotorch_core::Device::Cuda(ord),
            got: ferrotorch_core::Device::Cuda(dst.device_ordinal()),
        });
    }
    // `src` is a shared borrow and `dst` an exclusive borrow of two distinct
    // handles, so both buffers can be held at once without a raw-pointer /
    // `unsafe` detour.
    let src_buf = Self::unwrap_buffer(src)?;
    let dst_buf = Self::unwrap_buffer_mut(dst)?;
    let dev = self.device(ord)?;
    crate::kernels::gpu_strided_scatter(
        src_buf,
        dst_buf,
        view_shape,
        dst_strides,
        dst_offset,
        dev,
    )
    .map_err(Self::map_gpu_err)
}
/// f64 counterpart of `strided_scatter_f32`, backed by `kernels::gpu_strided_scatter_f64`.
///
/// # Errors
/// Returns `DeviceMismatch` if `src` and `dst` are on different CUDA devices.
/// Unwrap, device-lookup and kernel errors are propagated.
fn strided_scatter_f64(
    &self,
    src: &GpuBufferHandle,
    dst: &mut GpuBufferHandle,
    view_shape: &[usize],
    dst_strides: &[isize],
    dst_offset: usize,
) -> FerrotorchResult<()> {
    let ord = src.device_ordinal();
    if dst.device_ordinal() != ord {
        return Err(FerrotorchError::DeviceMismatch {
            expected: ferrotorch_core::Device::Cuda(ord),
            got: ferrotorch_core::Device::Cuda(dst.device_ordinal()),
        });
    }
    // Shared borrow of `src` and exclusive borrow of `dst` are disjoint, so the
    // borrow checker accepts both references directly; no `unsafe` needed.
    let src_buf = Self::unwrap_buffer_f64(src)?;
    let dst_buf = Self::unwrap_buffer_f64_mut(dst)?;
    let dev = self.device(ord)?;
    crate::kernels::gpu_strided_scatter_f64(
        src_buf,
        dst_buf,
        view_shape,
        dst_strides,
        dst_offset,
        dev,
    )
    .map_err(Self::map_gpu_err)
}
fn strided_cat(
&self,
src: &GpuBufferHandle,
dst: &mut GpuBufferHandle,
total_along_axis: usize,
offset: usize,
t_axis_size: usize,
inner: usize,
t_numel: usize,
elem_size: usize,
) -> FerrotorchResult<()> {
let dev = self.device(src.device_ordinal())?;
match elem_size {
2 => {
let in_slice = Self::unwrap_buffer_bf16(src)?;
let out_slice = dst.downcast_mut::<cudarc::driver::CudaSlice<u16>>().ok_or(
FerrotorchError::InvalidArgument {
message: "strided_cat: output is not a 2-byte (u16) buffer".into(),
},
)?;
crate::kernels::gpu_strided_cat_u16(
in_slice,
out_slice,
total_along_axis,
offset,
t_axis_size,
inner,
t_numel,
dev,
)
.map_err(Self::map_gpu_err)?;
Ok(())
}
4 => {
let in_buf = Self::unwrap_buffer(src)?;
let out_buf = dst.downcast_mut::<CudaBuffer<f32>>().ok_or(
FerrotorchError::InvalidArgument {
message: "strided_cat: output is not CudaBuffer<f32>".into(),
},
)?;
crate::kernels::gpu_strided_cat(
in_buf,
out_buf,
total_along_axis,
offset,
t_axis_size,
inner,
t_numel,
dev,
)
.map_err(Self::map_gpu_err)?;
Ok(())
}
8 => {
let in_buf = Self::unwrap_buffer_f64(src)?;
let out_buf = Self::unwrap_buffer_f64_mut(dst)?;
crate::kernels::gpu_strided_cat_f64(
in_buf,
out_buf,
total_along_axis,
offset,
t_axis_size,
inner,
t_numel,
dev,
)
.map_err(Self::map_gpu_err)?;
Ok(())
}
other => Err(FerrotorchError::InvalidArgument {
message: format!(
"strided_cat: unsupported elem_size={other} on CUDA (supported: 2, 4, 8)"
),
}),
}
}
/// SVD of the `m` x `n` matrix `a` via `cusolver::gpu_svd_f32_dev`.
///
/// Returns `(u, s, vt)` in solver order, each tagged with `a`'s device ordinal.
fn svd_f32(
    &self,
    a: &GpuBufferHandle,
    m: usize,
    n: usize,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
    let ordinal = a.device_ordinal();
    let matrix = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    let (u, s, vt) =
        crate::cusolver::gpu_svd_f32_dev(matrix, m, n, device).map_err(Self::map_gpu_err)?;
    let wrap = |buf| Self::wrap_buffer(buf, ordinal);
    Ok((wrap(u), wrap(s), wrap(vt)))
}
/// f64 SVD via `cusolver::gpu_svd_f64_dev`; returns `(u, s, vt)` on `a`'s device.
fn svd_f64(
    &self,
    a: &GpuBufferHandle,
    m: usize,
    n: usize,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
    let ordinal = a.device_ordinal();
    let matrix = Self::unwrap_buffer_f64(a)?;
    let device = self.device(ordinal)?;
    let (u, s, vt) =
        crate::cusolver::gpu_svd_f64_dev(matrix, m, n, device).map_err(Self::map_gpu_err)?;
    let wrap = |buf| Self::wrap_buffer_f64(buf, ordinal);
    Ok((wrap(u), wrap(s), wrap(vt)))
}
/// Cholesky factor of the `n` x `n` matrix `a` via `cusolver::gpu_cholesky_f32_dev`.
fn cholesky_f32(&self, a: &GpuBufferHandle, n: usize) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let matrix = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::cusolver::gpu_cholesky_f32_dev(matrix, n, device)
        .map(|factor| Self::wrap_buffer(factor, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 Cholesky factor via `cusolver::gpu_cholesky_f64_dev`.
fn cholesky_f64(&self, a: &GpuBufferHandle, n: usize) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let matrix = Self::unwrap_buffer_f64(a)?;
    let device = self.device(ordinal)?;
    crate::cusolver::gpu_cholesky_f64_dev(matrix, n, device)
        .map(|factor| Self::wrap_buffer_f64(factor, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Solves the `n` x `n` system `a`·x = `b` via `cusolver::gpu_solve_f32_dev`.
///
/// `b` has `nrhs` right-hand sides. The solution lives on `a`'s device.
fn solve_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    n: usize,
    nrhs: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let coeffs = Self::unwrap_buffer(a)?;
    let rhs = Self::unwrap_buffer(b)?;
    let device = self.device(ordinal)?;
    crate::cusolver::gpu_solve_f32_dev(coeffs, rhs, n, nrhs, device)
        .map(|x| Self::wrap_buffer(x, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 linear solve via `cusolver::gpu_solve_f64_dev`.
fn solve_f64(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    n: usize,
    nrhs: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let coeffs = Self::unwrap_buffer_f64(a)?;
    let rhs = Self::unwrap_buffer_f64(b)?;
    let device = self.device(ordinal)?;
    crate::cusolver::gpu_solve_f64_dev(coeffs, rhs, n, nrhs, device)
        .map(|x| Self::wrap_buffer_f64(x, ordinal))
        .map_err(Self::map_gpu_err)
}
/// QR decomposition of the `m` x `n` matrix `a` via `cusolver::gpu_qr_f32_dev`.
///
/// Returns `(q, r)` on `a`'s device.
fn qr_f32(
    &self,
    a: &GpuBufferHandle,
    m: usize,
    n: usize,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
    let ordinal = a.device_ordinal();
    let matrix = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::cusolver::gpu_qr_f32_dev(matrix, m, n, device)
        .map(|(q, r)| (Self::wrap_buffer(q, ordinal), Self::wrap_buffer(r, ordinal)))
        .map_err(Self::map_gpu_err)
}
/// f64 QR decomposition via `cusolver::gpu_qr_f64_dev`; returns `(q, r)`.
fn qr_f64(
    &self,
    a: &GpuBufferHandle,
    m: usize,
    n: usize,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
    let ordinal = a.device_ordinal();
    let matrix = Self::unwrap_buffer_f64(a)?;
    let device = self.device(ordinal)?;
    crate::cusolver::gpu_qr_f64_dev(matrix, m, n, device)
        .map(|(q, r)| {
            (
                Self::wrap_buffer_f64(q, ordinal),
                Self::wrap_buffer_f64(r, ordinal),
            )
        })
        .map_err(Self::map_gpu_err)
}
/// LU factorization of the `n` x `n` matrix `a` via `cusolver::gpu_lu_factor_f32`.
///
/// The factored matrix stays on `a`'s device. The pivot vector is copied back
/// to the host with `transfer::gpu_to_cpu`.
fn lu_factor_f32(
    &self,
    a: &GpuBufferHandle,
    n: usize,
) -> FerrotorchResult<(GpuBufferHandle, Vec<i32>)> {
    let ordinal = a.device_ordinal();
    let matrix = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    let (packed, pivots_dev) =
        crate::cusolver::gpu_lu_factor_f32(matrix, n, device).map_err(Self::map_gpu_err)?;
    let pivots = crate::transfer::gpu_to_cpu(&pivots_dev, device).map_err(Self::map_gpu_err)?;
    Ok((Self::wrap_buffer(packed, ordinal), pivots))
}
/// f64 LU factorization via `cusolver::gpu_lu_factor_f64`; pivots are returned on the host.
fn lu_factor_f64(
    &self,
    a: &GpuBufferHandle,
    n: usize,
) -> FerrotorchResult<(GpuBufferHandle, Vec<i32>)> {
    let ordinal = a.device_ordinal();
    let matrix = Self::unwrap_buffer_f64(a)?;
    let device = self.device(ordinal)?;
    let (packed, pivots_dev) =
        crate::cusolver::gpu_lu_factor_f64(matrix, n, device).map_err(Self::map_gpu_err)?;
    let pivots = crate::transfer::gpu_to_cpu(&pivots_dev, device).map_err(Self::map_gpu_err)?;
    Ok((Self::wrap_buffer_f64(packed, ordinal), pivots))
}
/// Least-squares solve for the `m` x `n` matrix `a` via `cusolver::gpu_lstsq_f32`.
///
/// `b` has `nrhs` right-hand sides. The solution lives on `a`'s device.
fn lstsq_f32(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    m: usize,
    n: usize,
    nrhs: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let design = Self::unwrap_buffer(a)?;
    let rhs = Self::unwrap_buffer(b)?;
    let device = self.device(ordinal)?;
    crate::cusolver::gpu_lstsq_f32(design, rhs, m, n, nrhs, device)
        .map(|x| Self::wrap_buffer(x, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 least-squares solve via `cusolver::gpu_lstsq_f64`.
fn lstsq_f64(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    m: usize,
    n: usize,
    nrhs: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let design = Self::unwrap_buffer_f64(a)?;
    let rhs = Self::unwrap_buffer_f64(b)?;
    let device = self.device(ordinal)?;
    crate::cusolver::gpu_lstsq_f64(design, rhs, m, n, nrhs, device)
        .map(|x| Self::wrap_buffer_f64(x, ordinal))
        .map_err(Self::map_gpu_err)
}
/// General eigendecomposition of the `n` x `n` matrix `a` via `cusolver::gpu_eig_f32`.
///
/// Returns the solver's `(w, v)` pair on `a`'s device; presumably eigenvalues
/// and eigenvectors.
fn eig_f32(
    &self,
    a: &GpuBufferHandle,
    n: usize,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
    let ordinal = a.device_ordinal();
    let matrix = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::cusolver::gpu_eig_f32(matrix, n, device)
        .map(|(w, v)| (Self::wrap_buffer(w, ordinal), Self::wrap_buffer(v, ordinal)))
        .map_err(Self::map_gpu_err)
}
/// f64 general eigendecomposition via `cusolver::gpu_eig_f64`; returns `(w, v)`.
fn eig_f64(
    &self,
    a: &GpuBufferHandle,
    n: usize,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
    let ordinal = a.device_ordinal();
    let matrix = Self::unwrap_buffer_f64(a)?;
    let device = self.device(ordinal)?;
    crate::cusolver::gpu_eig_f64(matrix, n, device)
        .map(|(w, v)| (Self::wrap_buffer_f64(w, ordinal), Self::wrap_buffer_f64(v, ordinal)))
        .map_err(Self::map_gpu_err)
}
/// Symmetric/Hermitian eigendecomposition via `cusolver::gpu_eigh_f32`; returns `(w, v)`.
fn eigh_f32(
    &self,
    a: &GpuBufferHandle,
    n: usize,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
    let ordinal = a.device_ordinal();
    let matrix = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::cusolver::gpu_eigh_f32(matrix, n, device)
        .map(|(w, v)| (Self::wrap_buffer(w, ordinal), Self::wrap_buffer(v, ordinal)))
        .map_err(Self::map_gpu_err)
}
/// f64 counterpart of `eigh_f32`, backed by `cusolver::gpu_eigh_f64`.
fn eigh_f64(
    &self,
    a: &GpuBufferHandle,
    n: usize,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
    let ordinal = a.device_ordinal();
    let matrix = Self::unwrap_buffer_f64(a)?;
    let device = self.device(ordinal)?;
    crate::cusolver::gpu_eigh_f64(matrix, n, device)
        .map(|(w, v)| (Self::wrap_buffer_f64(w, ordinal), Self::wrap_buffer_f64(v, ordinal)))
        .map_err(Self::map_gpu_err)
}
/// Eigenvalues only, for a symmetric/Hermitian matrix, via `cusolver::gpu_eigvalsh_f32`.
fn eigvalsh_f32(&self, a: &GpuBufferHandle, n: usize) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let matrix = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::cusolver::gpu_eigvalsh_f32(matrix, n, device)
        .map(|w| Self::wrap_buffer(w, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 counterpart of `eigvalsh_f32`, backed by `cusolver::gpu_eigvalsh_f64`.
fn eigvalsh_f64(&self, a: &GpuBufferHandle, n: usize) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let matrix = Self::unwrap_buffer_f64(a)?;
    let device = self.device(ordinal)?;
    crate::cusolver::gpu_eigvalsh_f64(matrix, n, device)
        .map(|w| Self::wrap_buffer_f64(w, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Batched 1-D complex-to-complex FFT via `cufft::gpu_fft_c2c_f32`.
///
/// Runs `batch` transforms of length `n`; `inverse` selects the direction.
/// The result lives on `a`'s device.
fn fft_c2c_f32(
    &self,
    a: &GpuBufferHandle,
    batch: usize,
    n: usize,
    inverse: bool,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let signal = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::cufft::gpu_fft_c2c_f32(signal, batch, n, inverse, device)
        .map(|spectrum| Self::wrap_buffer(spectrum, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 counterpart of `fft_c2c_f32`, backed by `cufft::gpu_fft_c2c_f64`.
fn fft_c2c_f64(
    &self,
    a: &GpuBufferHandle,
    batch: usize,
    n: usize,
    inverse: bool,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let signal = Self::unwrap_buffer_f64(a)?;
    let device = self.device(ordinal)?;
    crate::cufft::gpu_fft_c2c_f64(signal, batch, n, inverse, device)
        .map(|spectrum| Self::wrap_buffer_f64(spectrum, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Resizes each of `batch` complex rows from `src_n` to `dst_n` entries.
///
/// Backed by `kernels::gpu_pad_truncate_complex_f32`; judging by the name it
/// pads or truncates. The result lives on `src`'s device.
fn pad_truncate_complex_f32(
    &self,
    src: &GpuBufferHandle,
    batch: usize,
    src_n: usize,
    dst_n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = src.device_ordinal();
    let rows = Self::unwrap_buffer(src)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_pad_truncate_complex_f32(rows, batch, src_n, dst_n, device)
        .map(|resized| Self::wrap_buffer(resized, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 counterpart of `pad_truncate_complex_f32`, backed by `kernels::gpu_pad_truncate_complex_f64`.
fn pad_truncate_complex_f64(
    &self,
    src: &GpuBufferHandle,
    batch: usize,
    src_n: usize,
    dst_n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = src.device_ordinal();
    let rows = Self::unwrap_buffer_f64(src)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_pad_truncate_complex_f64(rows, batch, src_n, dst_n, device)
        .map(|resized| Self::wrap_buffer_f64(resized, ordinal))
        .map_err(Self::map_gpu_err)
}
/// 2-D complex-to-complex FFT of an `h` x `w` grid via `cufft::gpu_fft2_c2c_f32`.
fn fft2_c2c_f32(
    &self,
    a: &GpuBufferHandle,
    h: usize,
    w: usize,
    inverse: bool,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let grid = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::cufft::gpu_fft2_c2c_f32(grid, h, w, inverse, device)
        .map(|spectrum| Self::wrap_buffer(spectrum, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 counterpart of `fft2_c2c_f32`, backed by `cufft::gpu_fft2_c2c_f64`.
fn fft2_c2c_f64(
    &self,
    a: &GpuBufferHandle,
    h: usize,
    w: usize,
    inverse: bool,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let grid = Self::unwrap_buffer_f64(a)?;
    let device = self.device(ordinal)?;
    crate::cufft::gpu_fft2_c2c_f64(grid, h, w, inverse, device)
        .map(|spectrum| Self::wrap_buffer_f64(spectrum, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Repeats `input` `repeat_count` times along one dimension via `kernels::gpu_repeat_along_dim`.
///
/// The tensor is described by the flattened `(outer, inner)` geometry.
fn repeat_along_dim_f32(
    &self,
    input: &GpuBufferHandle,
    outer: usize,
    repeat_count: usize,
    inner: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let src = Self::unwrap_buffer(input)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_repeat_along_dim(src, outer, repeat_count, inner, device)
        .map(|tiled| Self::wrap_buffer(tiled, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 counterpart of `repeat_along_dim_f32`, backed by `kernels::gpu_repeat_along_dim_f64`.
fn repeat_along_dim_f64(
    &self,
    input: &GpuBufferHandle,
    outer: usize,
    repeat_count: usize,
    inner: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let src = Self::unwrap_buffer_f64(input)?;
    let device = self.device(ordinal)?;
    crate::kernels::gpu_repeat_along_dim_f64(src, outer, repeat_count, inner, device)
        .map(|tiled| Self::wrap_buffer_f64(tiled, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Batched real-to-complex FFT (`batch` signals of length `n`) via `cufft::gpu_rfft_r2c_f32`.
fn rfft_r2c_f32(
    &self,
    a: &GpuBufferHandle,
    batch: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let signal = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::cufft::gpu_rfft_r2c_f32(signal, batch, n, device)
        .map(|spectrum| Self::wrap_buffer(spectrum, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 counterpart of `rfft_r2c_f32`, backed by `cufft::gpu_rfft_r2c_f64`.
fn rfft_r2c_f64(
    &self,
    a: &GpuBufferHandle,
    batch: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let signal = Self::unwrap_buffer_f64(a)?;
    let device = self.device(ordinal)?;
    crate::cufft::gpu_rfft_r2c_f64(signal, batch, n, device)
        .map(|spectrum| Self::wrap_buffer_f64(spectrum, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Batched complex-to-real inverse FFT producing `n_out` samples per row, via `cufft::gpu_irfft_c2r_f32`.
fn irfft_c2r_f32(
    &self,
    a: &GpuBufferHandle,
    batch: usize,
    n_out: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let spectrum = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::cufft::gpu_irfft_c2r_f32(spectrum, batch, n_out, device)
        .map(|signal| Self::wrap_buffer(signal, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 counterpart of `irfft_c2r_f32`, backed by `cufft::gpu_irfft_c2r_f64`.
fn irfft_c2r_f64(
    &self,
    a: &GpuBufferHandle,
    batch: usize,
    n_out: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let spectrum = Self::unwrap_buffer_f64(a)?;
    let device = self.device(ordinal)?;
    crate::cufft::gpu_irfft_c2r_f64(spectrum, batch, n_out, device)
        .map(|signal| Self::wrap_buffer_f64(signal, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Hermitian FFT via `cufft::gpu_hfft_f32`.
///
/// Takes `half_in` input entries per row and produces `n_out` outputs per row,
/// for `batch` rows.
fn hfft_f32(
    &self,
    a: &GpuBufferHandle,
    batch: usize,
    half_in: usize,
    n_out: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let half_spectrum = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::cufft::gpu_hfft_f32(half_spectrum, batch, half_in, n_out, device)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 counterpart of `hfft_f32`, backed by `cufft::gpu_hfft_f64`.
fn hfft_f64(
    &self,
    a: &GpuBufferHandle,
    batch: usize,
    half_in: usize,
    n_out: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let half_spectrum = Self::unwrap_buffer_f64(a)?;
    let device = self.device(ordinal)?;
    crate::cufft::gpu_hfft_f64(half_spectrum, batch, half_in, n_out, device)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Inverse Hermitian FFT over `batch` rows of length `n`, via `cufft::gpu_ihfft_f32`.
fn ihfft_f32(
    &self,
    a: &GpuBufferHandle,
    batch: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let signal = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::cufft::gpu_ihfft_f32(signal, batch, n, device)
        .map(|out| Self::wrap_buffer(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 counterpart of `ihfft_f32`, backed by `cufft::gpu_ihfft_f64`.
fn ihfft_f64(
    &self,
    a: &GpuBufferHandle,
    batch: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let signal = Self::unwrap_buffer_f64(a)?;
    let device = self.device(ordinal)?;
    crate::cufft::gpu_ihfft_f64(signal, batch, n, device)
        .map(|out| Self::wrap_buffer_f64(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// 3-D complex-to-complex FFT of a `d` x `h` x `w` volume via `cufft::gpu_fftn3d_c2c_f32`.
fn fftn3d_c2c_f32(
    &self,
    a: &GpuBufferHandle,
    d: usize,
    h: usize,
    w: usize,
    inverse: bool,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let volume = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::cufft::gpu_fftn3d_c2c_f32(volume, d, h, w, inverse, device)
        .map(|spectrum| Self::wrap_buffer(spectrum, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 counterpart of `fftn3d_c2c_f32`, backed by `cufft::gpu_fftn3d_c2c_f64`.
fn fftn3d_c2c_f64(
    &self,
    a: &GpuBufferHandle,
    d: usize,
    h: usize,
    w: usize,
    inverse: bool,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let volume = Self::unwrap_buffer_f64(a)?;
    let device = self.device(ordinal)?;
    crate::cufft::gpu_fftn3d_c2c_f64(volume, d, h, w, inverse, device)
        .map(|spectrum| Self::wrap_buffer_f64(spectrum, ordinal))
        .map_err(Self::map_gpu_err)
}
/// 2-D complex-to-complex FFT via the N-d path `cufft::gpu_fftn2d_c2c_f32`.
fn fftn2d_c2c_f32(
    &self,
    a: &GpuBufferHandle,
    h: usize,
    w: usize,
    inverse: bool,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let grid = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::cufft::gpu_fftn2d_c2c_f32(grid, h, w, inverse, device)
        .map(|spectrum| Self::wrap_buffer(spectrum, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 counterpart of `fftn2d_c2c_f32`, backed by `cufft::gpu_fftn2d_c2c_f64`.
fn fftn2d_c2c_f64(
    &self,
    a: &GpuBufferHandle,
    h: usize,
    w: usize,
    inverse: bool,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let grid = Self::unwrap_buffer_f64(a)?;
    let device = self.device(ordinal)?;
    crate::cufft::gpu_fftn2d_c2c_f64(grid, h, w, inverse, device)
        .map(|spectrum| Self::wrap_buffer_f64(spectrum, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Complex-to-complex FFT over the listed `axes` of a tensor with shape `shape`.
///
/// Backed by `cufft::gpu_fftn_axes_c2c_f32`.
fn fftn_axes_c2c_f32(
    &self,
    a: &GpuBufferHandle,
    shape: &[usize],
    axes: &[usize],
    inverse: bool,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let data = Self::unwrap_buffer(a)?;
    let device = self.device(ordinal)?;
    crate::cufft::gpu_fftn_axes_c2c_f32(data, shape, axes, inverse, device)
        .map(|spectrum| Self::wrap_buffer(spectrum, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 counterpart of `fftn_axes_c2c_f32`, backed by `cufft::gpu_fftn_axes_c2c_f64`.
fn fftn_axes_c2c_f64(
    &self,
    a: &GpuBufferHandle,
    shape: &[usize],
    axes: &[usize],
    inverse: bool,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let data = Self::unwrap_buffer_f64(a)?;
    let device = self.device(ordinal)?;
    crate::cufft::gpu_fftn_axes_c2c_f64(data, shape, axes, inverse, device)
        .map(|spectrum| Self::wrap_buffer_f64(spectrum, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Multiplies a CSR sparse matrix by the device-resident `dense` through cuSPARSE.
///
/// The CSR matrix is given as host slices. The call goes to
/// `sparse::gpu_spmm_csr_f32` with `m`, `k`, `n` forwarded unchanged, and the
/// product lives on `dense`'s device.
fn spmm_csr_f32(
    &self,
    crow_indices: &[u32],
    col_indices: &[u32],
    values: &[f32],
    dense: &GpuBufferHandle,
    m: usize,
    k: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = dense.device_ordinal();
    let rhs = Self::unwrap_buffer(dense)?;
    let (device, sp_handle) = (self.device(ordinal)?, self.cusparse()?);
    crate::sparse::gpu_spmm_csr_f32(
        sp_handle,
        crow_indices,
        col_indices,
        values,
        rhs,
        m,
        k,
        n,
        device,
    )
    .map(|product| Self::wrap_buffer(product, ordinal))
    .map_err(Self::map_gpu_err)
}
/// f64 counterpart of `spmm_csr_f32`, backed by `sparse::gpu_spmm_csr_f64`.
fn spmm_csr_f64(
    &self,
    crow_indices: &[u32],
    col_indices: &[u32],
    values: &[f64],
    dense: &GpuBufferHandle,
    m: usize,
    k: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = dense.device_ordinal();
    let rhs = Self::unwrap_buffer_f64(dense)?;
    let (device, sp_handle) = (self.device(ordinal)?, self.cusparse()?);
    crate::sparse::gpu_spmm_csr_f64(
        sp_handle,
        crow_indices,
        col_indices,
        values,
        rhs,
        m,
        k,
        n,
        device,
    )
    .map(|product| Self::wrap_buffer_f64(product, ordinal))
    .map_err(Self::map_gpu_err)
}
/// Expands host-side CSR arrays into an `m` x `n` dense buffer on `device_ordinal`.
///
/// Backed by `sparse::gpu_sparse_to_dense_csr_f32`.
fn sparse_to_dense_csr_f32(
    &self,
    crow_indices: &[u32],
    col_indices: &[u32],
    values: &[f32],
    device_ordinal: usize,
    m: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let (device, sp_handle) = (self.device(device_ordinal)?, self.cusparse()?);
    let densified = crate::sparse::gpu_sparse_to_dense_csr_f32(
        sp_handle,
        crow_indices,
        col_indices,
        values,
        m,
        n,
        device,
    );
    densified
        .map(|buf| Self::wrap_buffer(buf, device_ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 counterpart of `sparse_to_dense_csr_f32`, backed by `sparse::gpu_sparse_to_dense_csr_f64`.
fn sparse_to_dense_csr_f64(
    &self,
    crow_indices: &[u32],
    col_indices: &[u32],
    values: &[f64],
    device_ordinal: usize,
    m: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let (device, sp_handle) = (self.device(device_ordinal)?, self.cusparse()?);
    let densified = crate::sparse::gpu_sparse_to_dense_csr_f64(
        sp_handle,
        crow_indices,
        col_indices,
        values,
        m,
        n,
        device,
    );
    densified
        .map(|buf| Self::wrap_buffer_f64(buf, device_ordinal))
        .map_err(Self::map_gpu_err)
}
/// Compresses the `m` x `n` device buffer `dense` into host-side CSR vectors.
///
/// Backed by `sparse::gpu_dense_to_sparse_csr_f32`; returns
/// `(crow_indices, col_indices, values)`.
fn dense_to_sparse_csr_f32(
    &self,
    dense: &GpuBufferHandle,
    m: usize,
    n: usize,
) -> FerrotorchResult<(Vec<u32>, Vec<u32>, Vec<f32>)> {
    let matrix = Self::unwrap_buffer(dense)?;
    let (device, sp_handle) = (self.device(dense.device_ordinal())?, self.cusparse()?);
    let csr = crate::sparse::gpu_dense_to_sparse_csr_f32(sp_handle, matrix, m, n, device);
    csr.map_err(Self::map_gpu_err)
}
/// f64 counterpart of `dense_to_sparse_csr_f32`, backed by `sparse::gpu_dense_to_sparse_csr_f64`.
fn dense_to_sparse_csr_f64(
    &self,
    dense: &GpuBufferHandle,
    m: usize,
    n: usize,
) -> FerrotorchResult<(Vec<u32>, Vec<u32>, Vec<f64>)> {
    let matrix = Self::unwrap_buffer_f64(dense)?;
    let (device, sp_handle) = (self.device(dense.device_ordinal())?, self.cusparse()?);
    let csr = crate::sparse::gpu_dense_to_sparse_csr_f64(sp_handle, matrix, m, n, device);
    csr.map_err(Self::map_gpu_err)
}
/// Expands host-side CSC arrays into an `m` x `n` dense buffer on `device_ordinal`.
///
/// Backed by `sparse::gpu_csc_to_dense_f32`.
fn csc_to_dense_f32(
    &self,
    col_ptrs: &[u32],
    row_indices: &[u32],
    values: &[f32],
    device_ordinal: usize,
    m: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let (device, sp_handle) = (self.device(device_ordinal)?, self.cusparse()?);
    crate::sparse::gpu_csc_to_dense_f32(sp_handle, col_ptrs, row_indices, values, m, n, device)
        .map(|buf| Self::wrap_buffer(buf, device_ordinal))
        .map_err(Self::map_gpu_err)
}
/// f64 counterpart of `csc_to_dense_f32`, backed by `sparse::gpu_csc_to_dense_f64`.
fn csc_to_dense_f64(
    &self,
    col_ptrs: &[u32],
    row_indices: &[u32],
    values: &[f64],
    device_ordinal: usize,
    m: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let (device, sp_handle) = (self.device(device_ordinal)?, self.cusparse()?);
    crate::sparse::gpu_csc_to_dense_f64(sp_handle, col_ptrs, row_indices, values, m, n, device)
        .map(|buf| Self::wrap_buffer_f64(buf, device_ordinal))
        .map_err(Self::map_gpu_err)
}
/// Converts host-side CSR arrays to CSC using the GPU on `device_ordinal`.
///
/// Backed by `sparse::gpu_csr_to_csc_f32`; the result is returned as host vectors.
fn csr_to_csc_f32(
    &self,
    crow_indices: &[u32],
    col_indices: &[u32],
    values: &[f32],
    device_ordinal: usize,
    m: usize,
    n: usize,
) -> FerrotorchResult<(Vec<u32>, Vec<u32>, Vec<f32>)> {
    let (device, sp_handle) = (self.device(device_ordinal)?, self.cusparse()?);
    let converted =
        crate::sparse::gpu_csr_to_csc_f32(sp_handle, crow_indices, col_indices, values, m, n, device);
    converted.map_err(Self::map_gpu_err)
}
/// f64 counterpart of `csr_to_csc_f32`, backed by `sparse::gpu_csr_to_csc_f64`.
fn csr_to_csc_f64(
    &self,
    crow_indices: &[u32],
    col_indices: &[u32],
    values: &[f64],
    device_ordinal: usize,
    m: usize,
    n: usize,
) -> FerrotorchResult<(Vec<u32>, Vec<u32>, Vec<f64>)> {
    let (device, sp_handle) = (self.device(device_ordinal)?, self.cusparse()?);
    let converted =
        crate::sparse::gpu_csr_to_csc_f64(sp_handle, crow_indices, col_indices, values, m, n, device);
    converted.map_err(Self::map_gpu_err)
}
/// Converts host-side COO arrays to CSR using the GPU on `device_ordinal`.
///
/// Backed by `sparse::gpu_coo_to_csr_f32`; the result is returned as host vectors.
fn coo_to_csr_f32(
    &self,
    row_indices: &[u32],
    col_indices: &[u32],
    values: &[f32],
    device_ordinal: usize,
    m: usize,
    n: usize,
) -> FerrotorchResult<(Vec<u32>, Vec<u32>, Vec<f32>)> {
    let (device, sp_handle) = (self.device(device_ordinal)?, self.cusparse()?);
    let converted =
        crate::sparse::gpu_coo_to_csr_f32(sp_handle, row_indices, col_indices, values, m, n, device);
    converted.map_err(Self::map_gpu_err)
}
/// f64 counterpart of `coo_to_csr_f32`, backed by `sparse::gpu_coo_to_csr_f64`.
fn coo_to_csr_f64(
    &self,
    row_indices: &[u32],
    col_indices: &[u32],
    values: &[f64],
    device_ordinal: usize,
    m: usize,
    n: usize,
) -> FerrotorchResult<(Vec<u32>, Vec<u32>, Vec<f64>)> {
    let (device, sp_handle) = (self.device(device_ordinal)?, self.cusparse()?);
    let converted =
        crate::sparse::gpu_coo_to_csr_f64(sp_handle, row_indices, col_indices, values, m, n, device);
    converted.map_err(Self::map_gpu_err)
}
/// Converts host-side CSR arrays to COO using the GPU on `device_ordinal`.
///
/// Backed by `sparse::gpu_csr_to_coo_f32`; the result is returned as host vectors.
fn csr_to_coo_f32(
    &self,
    crow_indices: &[u32],
    col_indices: &[u32],
    values: &[f32],
    device_ordinal: usize,
    m: usize,
    n: usize,
) -> FerrotorchResult<(Vec<u32>, Vec<u32>, Vec<f32>)> {
    let (device, sp_handle) = (self.device(device_ordinal)?, self.cusparse()?);
    let converted =
        crate::sparse::gpu_csr_to_coo_f32(sp_handle, crow_indices, col_indices, values, m, n, device);
    converted.map_err(Self::map_gpu_err)
}
/// f64 counterpart of `csr_to_coo_f32`, backed by `sparse::gpu_csr_to_coo_f64`.
fn csr_to_coo_f64(
    &self,
    crow_indices: &[u32],
    col_indices: &[u32],
    values: &[f64],
    device_ordinal: usize,
    m: usize,
    n: usize,
) -> FerrotorchResult<(Vec<u32>, Vec<u32>, Vec<f64>)> {
    let (device, sp_handle) = (self.device(device_ordinal)?, self.cusparse()?);
    let converted =
        crate::sparse::gpu_csr_to_coo_f64(sp_handle, crow_indices, col_indices, values, m, n, device);
    converted.map_err(Self::map_gpu_err)
}
/// Flash-attention forward pass via `flash_attention::gpu_flash_attention_f32`.
///
/// The result lives on `query`'s device.
///
/// NOTE(review): the literals `1` (after `d_v`) and `false` (after `scale`)
/// are hard-coded here. Their meaning (presumably a batch/head count and a
/// causal-mask flag) is defined by the kernel's signature, which is not
/// visible in this file; confirm.
fn flash_attention_forward_f32(
    &self,
    query: &GpuBufferHandle,
    key: &GpuBufferHandle,
    value: &GpuBufferHandle,
    seq_q: usize,
    seq_k: usize,
    d: usize,
    d_v: usize,
    scale: f32,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = query.device_ordinal();
    let q = Self::unwrap_buffer(query)?;
    let k = Self::unwrap_buffer(key)?;
    let v = Self::unwrap_buffer(value)?;
    let device = self.device(ordinal)?;
    crate::flash_attention::gpu_flash_attention_f32(
        q, k, v, seq_q, seq_k, d, d_v, 1, scale, false, device,
    )
    .map(|attn| Self::wrap_buffer(attn, ordinal))
    .map_err(Self::map_gpu_err)
}
/// f64 counterpart of `flash_attention_forward_f32`, backed by
/// `flash_attention::gpu_flash_attention_f64`.
///
/// NOTE(review): it passes the same hard-coded `1` / `false` literals; confirm
/// their meaning against the kernel signature.
fn flash_attention_forward_f64(
    &self,
    query: &GpuBufferHandle,
    key: &GpuBufferHandle,
    value: &GpuBufferHandle,
    seq_q: usize,
    seq_k: usize,
    d: usize,
    d_v: usize,
    scale: f64,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = query.device_ordinal();
    let q = Self::unwrap_buffer_f64(query)?;
    let k = Self::unwrap_buffer_f64(key)?;
    let v = Self::unwrap_buffer_f64(value)?;
    let device = self.device(ordinal)?;
    crate::flash_attention::gpu_flash_attention_f64(
        q, k, v, seq_q, seq_k, d, d_v, 1, scale, false, device,
    )
    .map(|attn| Self::wrap_buffer_f64(attn, ordinal))
    .map_err(Self::map_gpu_err)
}
/// cuSPARSELt matmul of `a` with `b_dense_decompressed` via
/// `cusparselt::gpu_sparse_matmul_24::<f32>`, using `CuSpLtDType::F32`.
///
/// "24" presumably refers to 2:4 structured sparsity. The cuSPARSELt handle is
/// created lazily by `self.cusparselt()`, and the result lives on `a`'s device.
#[cfg(feature = "cusparselt")]
fn sparse_matmul_24_f32(
    &self,
    a: &GpuBufferHandle,
    b_dense_decompressed: &GpuBufferHandle,
    m: usize,
    k: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let lhs = Self::unwrap_buffer(a)?;
    let rhs = Self::unwrap_buffer(b_dense_decompressed)?;
    let device = self.device(ordinal)?;
    let lt_handle = self.cusparselt()?;
    crate::cusparselt::gpu_sparse_matmul_24::<f32>(
        lt_handle,
        lhs,
        rhs,
        m,
        k,
        n,
        crate::cusparselt::CuSpLtDType::F32,
        device,
    )
    .map(|product| Self::wrap_buffer(product, ordinal))
    .map_err(Self::map_gpu_err)
}
/// bf16 x bf16 matmul via `blas::gpu_matmul_bf16_bf16_nt`.
///
/// The "nt" suffix presumably means `b` is used transposed. Both inputs and
/// the output are bf16 buffers on `a`'s device.
#[cfg(feature = "cuda")]
fn matmul_bf16_bf16_nt(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    m: usize,
    k: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let lhs = Self::unwrap_buffer_bf16(a)?;
    let rhs = Self::unwrap_buffer_bf16(b)?;
    let device = self.device(ordinal)?;
    crate::blas::gpu_matmul_bf16_bf16_nt(lhs, rhs, m, k, n, device)
        .map(|product| Self::wrap_buffer_bf16(product, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Softmax over a bf16 buffer, forwarding `rows`/`cols` to
/// `crate::bf16::gpu_softmax_bf16`. The result lives on `a`'s device.
#[cfg(feature = "cuda")]
fn softmax_bf16_bf16(
    &self,
    a: &GpuBufferHandle,
    rows: usize,
    cols: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_bf16(a)?;
    let device = self.device(ordinal)?;
    crate::bf16::gpu_softmax_bf16(input, rows, cols, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Layer normalisation of a bf16 buffer with `gamma`/`beta` affine
/// parameters, delegating to `crate::bf16::gpu_layernorm_bf16`.
/// The output is tagged with `input`'s device ordinal.
#[cfg(feature = "cuda")]
fn layernorm_bf16_bf16(
    &self,
    input: &GpuBufferHandle,
    gamma: &GpuBufferHandle,
    beta: &GpuBufferHandle,
    rows: usize,
    cols: usize,
    eps: f32,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let x = Self::unwrap_buffer_bf16(input)?;
    let weight = Self::unwrap_buffer_bf16(gamma)?;
    let bias = Self::unwrap_buffer_bf16(beta)?;
    let device = self.device(ordinal)?;
    crate::bf16::gpu_layernorm_bf16(x, weight, bias, rows, cols, eps, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// GELU on a bf16 buffer via `crate::bf16::gpu_gelu_bf16`.
#[cfg(feature = "cuda")]
fn gelu_bf16_bf16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_bf16(a)?;
    let device = self.device(ordinal)?;
    crate::bf16::gpu_gelu_bf16(input, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// SiLU on a bf16 buffer via `crate::bf16::gpu_silu_bf16`.
#[cfg(feature = "cuda")]
fn silu_bf16_bf16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_bf16(a)?;
    let device = self.device(ordinal)?;
    crate::bf16::gpu_silu_bf16(input, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// ReLU on a bf16 buffer via `crate::bf16::gpu_relu_bf16`.
#[cfg(feature = "cuda")]
fn relu_bf16_bf16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_bf16(a)?;
    let device = self.device(ordinal)?;
    crate::bf16::gpu_relu_bf16(input, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Adds two bf16 buffers via `crate::bf16::gpu_add_bf16`, on `a`'s device.
#[cfg(feature = "cuda")]
fn add_bf16_bf16(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_bf16(a)?, Self::unwrap_buffer_bf16(b)?);
    let device = self.device(ordinal)?;
    crate::bf16::gpu_add_bf16(lhs, rhs, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Multiplies two bf16 buffers via `crate::bf16::gpu_mul_bf16`, on `a`'s device.
#[cfg(feature = "cuda")]
fn mul_bf16_bf16(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_bf16(a)?, Self::unwrap_buffer_bf16(b)?);
    let device = self.device(ordinal)?;
    crate::bf16::gpu_mul_bf16(lhs, rhs, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Scales a bf16 buffer by `scalar` via `crate::bf16::gpu_scale_bf16`.
#[cfg(feature = "cuda")]
fn scale_bf16_bf16(
    &self,
    a: &GpuBufferHandle,
    scalar: f32,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_bf16(a)?;
    let device = self.device(ordinal)?;
    crate::bf16::gpu_scale_bf16(input, scalar, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Subtracts bf16 buffer `b` from `a` via `crate::bf16::gpu_sub_bf16`.
#[cfg(feature = "cuda")]
fn sub_bf16_bf16(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_bf16(a)?, Self::unwrap_buffer_bf16(b)?);
    let device = self.device(ordinal)?;
    crate::bf16::gpu_sub_bf16(lhs, rhs, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Divides bf16 buffer `a` by `b` via `crate::bf16::gpu_div_bf16`.
#[cfg(feature = "cuda")]
fn div_bf16_bf16(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_bf16(a)?, Self::unwrap_buffer_bf16(b)?);
    let device = self.device(ordinal)?;
    crate::bf16::gpu_div_bf16(lhs, rhs, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Negates a bf16 buffer via `crate::bf16::gpu_neg_bf16`.
#[cfg(feature = "cuda")]
fn neg_bf16_bf16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_bf16(a)?;
    let device = self.device(ordinal)?;
    crate::bf16::gpu_neg_bf16(input, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Broadcasting add of two bf16 buffers. The three shapes are forwarded
/// unchanged to `crate::bf16::gpu_broadcast_add_bf16`.
#[cfg(feature = "cuda")]
fn broadcast_add_bf16(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_bf16(a)?, Self::unwrap_buffer_bf16(b)?);
    let device = self.device(ordinal)?;
    crate::bf16::gpu_broadcast_add_bf16(lhs, rhs, a_shape, b_shape, out_shape, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Broadcasting subtract of two bf16 buffers. The three shapes are forwarded
/// unchanged to `crate::bf16::gpu_broadcast_sub_bf16`.
#[cfg(feature = "cuda")]
fn broadcast_sub_bf16(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_bf16(a)?, Self::unwrap_buffer_bf16(b)?);
    let device = self.device(ordinal)?;
    crate::bf16::gpu_broadcast_sub_bf16(lhs, rhs, a_shape, b_shape, out_shape, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Broadcasting multiply of two bf16 buffers. The three shapes are forwarded
/// unchanged to `crate::bf16::gpu_broadcast_mul_bf16`.
#[cfg(feature = "cuda")]
fn broadcast_mul_bf16(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_bf16(a)?, Self::unwrap_buffer_bf16(b)?);
    let device = self.device(ordinal)?;
    crate::bf16::gpu_broadcast_mul_bf16(lhs, rhs, a_shape, b_shape, out_shape, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Broadcasting divide of two bf16 buffers. The three shapes are forwarded
/// unchanged to `crate::bf16::gpu_broadcast_div_bf16`.
#[cfg(feature = "cuda")]
fn broadcast_div_bf16(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_bf16(a)?, Self::unwrap_buffer_bf16(b)?);
    let device = self.device(ordinal)?;
    crate::bf16::gpu_broadcast_div_bf16(lhs, rhs, a_shape, b_shape, out_shape, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Full reduction (sum) of a bf16 buffer via `crate::bf16::gpu_sum_bf16`.
#[cfg(feature = "cuda")]
fn sum_bf16_bf16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_bf16(a)?;
    let device = self.device(ordinal)?;
    crate::bf16::gpu_sum_bf16(input, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Full reduction (mean) of a bf16 buffer via `crate::bf16::gpu_mean_bf16`.
#[cfg(feature = "cuda")]
fn mean_bf16_bf16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_bf16(a)?;
    let device = self.device(ordinal)?;
    crate::bf16::gpu_mean_bf16(input, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Sums a bf16 buffer along `axis` of `shape`.
///
/// `shape` is collapsed into three values passed to
/// `crate::bf16::gpu_sum_axis_bf16_bf16`:
/// - `outer`: the product of the dims before `axis`;
/// - `axis_size`: the dim at `axis`;
/// - `inner`: the product of the dims after `axis`.
///
/// # Errors
/// Returns `InvalidArgument` if `axis >= shape.len()`. Otherwise, unwrap,
/// device and kernel errors are propagated.
#[cfg(feature = "cuda")]
fn sum_axis_bf16_bf16(
    &self,
    a: &GpuBufferHandle,
    shape: &[usize],
    axis: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let Some(&axis_size) = shape.get(axis) else {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "sum_axis_bf16_bf16: axis {axis} out of bounds for shape {shape:?}"
            ),
        });
    };
    let (leading, rest) = shape.split_at(axis);
    let outer: usize = leading.iter().product();
    // `rest` is non-empty because `axis < shape.len()`; skip the reduced dim itself.
    let inner: usize = rest[1..].iter().product();
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_bf16(a)?;
    let device = self.device(ordinal)?;
    crate::bf16::gpu_sum_axis_bf16_bf16(input, outer, axis_size, inner, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Averages a bf16 buffer along `axis` of `shape`.
///
/// `shape` is collapsed into three values passed to
/// `crate::bf16::gpu_mean_axis_bf16_bf16`:
/// - `outer`: the product of the dims before `axis`;
/// - `axis_size`: the dim at `axis`;
/// - `inner`: the product of the dims after `axis`.
///
/// # Errors
/// Returns `InvalidArgument` if `axis >= shape.len()`. Otherwise, unwrap,
/// device and kernel errors are propagated.
#[cfg(feature = "cuda")]
fn mean_axis_bf16_bf16(
    &self,
    a: &GpuBufferHandle,
    shape: &[usize],
    axis: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let Some(&axis_size) = shape.get(axis) else {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "mean_axis_bf16_bf16: axis {axis} out of bounds for shape {shape:?}"
            ),
        });
    };
    let (leading, rest) = shape.split_at(axis);
    let outer: usize = leading.iter().product();
    // `rest` is non-empty because `axis < shape.len()`; skip the reduced dim itself.
    let inner: usize = rest[1..].iter().product();
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_bf16(a)?;
    let device = self.device(ordinal)?;
    crate::bf16::gpu_mean_axis_bf16_bf16(input, outer, axis_size, inner, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Exponential of a bf16 buffer via `crate::bf16::gpu_exp_bf16`.
#[cfg(feature = "cuda")]
fn exp_bf16_bf16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_bf16(a)?;
    let device = self.device(ordinal)?;
    crate::bf16::gpu_exp_bf16(input, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Logarithm of a bf16 buffer via `crate::bf16::gpu_log_bf16`.
#[cfg(feature = "cuda")]
fn log_bf16_bf16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_bf16(a)?;
    let device = self.device(ordinal)?;
    crate::bf16::gpu_log_bf16(input, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Hyperbolic tangent of a bf16 buffer via `crate::bf16::gpu_tanh_bf16`.
#[cfg(feature = "cuda")]
fn tanh_bf16_bf16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_bf16(a)?;
    let device = self.device(ordinal)?;
    crate::bf16::gpu_tanh_bf16(input, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Sigmoid of a bf16 buffer via `crate::bf16::gpu_sigmoid_bf16`.
#[cfg(feature = "cuda")]
fn sigmoid_bf16_bf16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_bf16(a)?;
    let device = self.device(ordinal)?;
    crate::bf16::gpu_sigmoid_bf16(input, device)
        .map(|out| Self::wrap_buffer_bf16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Adds two f16 buffers via `crate::f16::gpu_add_f16`, on `a`'s device.
#[cfg(feature = "cuda")]
fn add_f16(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_f16(a)?, Self::unwrap_buffer_f16(b)?);
    let device = self.device(ordinal)?;
    crate::f16::gpu_add_f16(lhs, rhs, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Subtracts f16 buffer `b` from `a` via `crate::f16::gpu_sub_f16`.
#[cfg(feature = "cuda")]
fn sub_f16(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_f16(a)?, Self::unwrap_buffer_f16(b)?);
    let device = self.device(ordinal)?;
    crate::f16::gpu_sub_f16(lhs, rhs, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Multiplies two f16 buffers via `crate::f16::gpu_mul_f16`.
#[cfg(feature = "cuda")]
fn mul_f16(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_f16(a)?, Self::unwrap_buffer_f16(b)?);
    let device = self.device(ordinal)?;
    crate::f16::gpu_mul_f16(lhs, rhs, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Divides f16 buffer `a` by `b` via `crate::f16::gpu_div_f16`.
#[cfg(feature = "cuda")]
fn div_f16(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_f16(a)?, Self::unwrap_buffer_f16(b)?);
    let device = self.device(ordinal)?;
    crate::f16::gpu_div_f16(lhs, rhs, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Negates an f16 buffer via `crate::f16::gpu_neg_f16`.
#[cfg(feature = "cuda")]
fn neg_f16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f16(a)?;
    let device = self.device(ordinal)?;
    crate::f16::gpu_neg_f16(input, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Scales an f16 buffer by `scale` via `crate::f16::gpu_scale_f16`.
#[cfg(feature = "cuda")]
fn scale_f16(&self, a: &GpuBufferHandle, scale: f32) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f16(a)?;
    let device = self.device(ordinal)?;
    crate::f16::gpu_scale_f16(input, scale, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Broadcasting add of two f16 buffers. The three shapes are forwarded
/// unchanged to `crate::f16::gpu_broadcast_add_f16`.
#[cfg(feature = "cuda")]
fn broadcast_add_f16(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_f16(a)?, Self::unwrap_buffer_f16(b)?);
    let device = self.device(ordinal)?;
    crate::f16::gpu_broadcast_add_f16(lhs, rhs, a_shape, b_shape, out_shape, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Broadcasting subtract of two f16 buffers. The three shapes are forwarded
/// unchanged to `crate::f16::gpu_broadcast_sub_f16`.
#[cfg(feature = "cuda")]
fn broadcast_sub_f16(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_f16(a)?, Self::unwrap_buffer_f16(b)?);
    let device = self.device(ordinal)?;
    crate::f16::gpu_broadcast_sub_f16(lhs, rhs, a_shape, b_shape, out_shape, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Broadcasting multiply of two f16 buffers. The three shapes are forwarded
/// unchanged to `crate::f16::gpu_broadcast_mul_f16`.
#[cfg(feature = "cuda")]
fn broadcast_mul_f16(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_f16(a)?, Self::unwrap_buffer_f16(b)?);
    let device = self.device(ordinal)?;
    crate::f16::gpu_broadcast_mul_f16(lhs, rhs, a_shape, b_shape, out_shape, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Broadcasting divide of two f16 buffers. The three shapes are forwarded
/// unchanged to `crate::f16::gpu_broadcast_div_f16`.
#[cfg(feature = "cuda")]
fn broadcast_div_f16(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_f16(a)?, Self::unwrap_buffer_f16(b)?);
    let device = self.device(ordinal)?;
    crate::f16::gpu_broadcast_div_f16(lhs, rhs, a_shape, b_shape, out_shape, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Full reduction (sum) of an f16 buffer via `crate::f16::gpu_sum_f16`.
#[cfg(feature = "cuda")]
fn sum_f16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f16(a)?;
    let device = self.device(ordinal)?;
    crate::f16::gpu_sum_f16(input, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Full reduction (mean) of an f16 buffer via `crate::f16::gpu_mean_f16`.
#[cfg(feature = "cuda")]
fn mean_f16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f16(a)?;
    let device = self.device(ordinal)?;
    crate::f16::gpu_mean_f16(input, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Sums an f16 buffer along `axis` of `shape`.
///
/// `shape` is collapsed into three values passed to
/// `crate::f16::gpu_sum_axis_f16`:
/// - `outer`: the product of the dims before `axis`;
/// - `axis_size`: the dim at `axis`;
/// - `inner`: the product of the dims after `axis`.
///
/// # Errors
/// Returns `InvalidArgument` if `axis >= shape.len()`.
#[cfg(feature = "cuda")]
fn sum_axis_f16(
    &self,
    a: &GpuBufferHandle,
    shape: &[usize],
    axis: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let Some(&axis_size) = shape.get(axis) else {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("sum_axis_f16: axis {axis} out of bounds for shape {shape:?}"),
        });
    };
    let (leading, rest) = shape.split_at(axis);
    let outer: usize = leading.iter().product();
    // `rest` is non-empty because `axis < shape.len()`; skip the reduced dim itself.
    let inner: usize = rest[1..].iter().product();
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f16(a)?;
    let device = self.device(ordinal)?;
    crate::f16::gpu_sum_axis_f16(input, outer, axis_size, inner, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Averages an f16 buffer along `axis` of `shape`.
///
/// `shape` is collapsed into three values passed to
/// `crate::f16::gpu_mean_axis_f16`:
/// - `outer`: the product of the dims before `axis`;
/// - `axis_size`: the dim at `axis`;
/// - `inner`: the product of the dims after `axis`.
///
/// # Errors
/// Returns `InvalidArgument` if `axis >= shape.len()`.
#[cfg(feature = "cuda")]
fn mean_axis_f16(
    &self,
    a: &GpuBufferHandle,
    shape: &[usize],
    axis: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let Some(&axis_size) = shape.get(axis) else {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("mean_axis_f16: axis {axis} out of bounds for shape {shape:?}"),
        });
    };
    let (leading, rest) = shape.split_at(axis);
    let outer: usize = leading.iter().product();
    // `rest` is non-empty because `axis < shape.len()`; skip the reduced dim itself.
    let inner: usize = rest[1..].iter().product();
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f16(a)?;
    let device = self.device(ordinal)?;
    crate::f16::gpu_mean_axis_f16(input, outer, axis_size, inner, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Exponential of an f16 buffer via `crate::f16::gpu_exp_f16`.
#[cfg(feature = "cuda")]
fn exp_f16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f16(a)?;
    let device = self.device(ordinal)?;
    crate::f16::gpu_exp_f16(input, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Logarithm of an f16 buffer via `crate::f16::gpu_log_f16`.
#[cfg(feature = "cuda")]
fn log_f16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f16(a)?;
    let device = self.device(ordinal)?;
    crate::f16::gpu_log_f16(input, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Hyperbolic tangent of an f16 buffer via `crate::f16::gpu_tanh_f16`.
#[cfg(feature = "cuda")]
fn tanh_f16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f16(a)?;
    let device = self.device(ordinal)?;
    crate::f16::gpu_tanh_f16(input, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Sigmoid of an f16 buffer via `crate::f16::gpu_sigmoid_f16`.
#[cfg(feature = "cuda")]
fn sigmoid_f16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f16(a)?;
    let device = self.device(ordinal)?;
    crate::f16::gpu_sigmoid_f16(input, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Square root of an f16 buffer via `crate::f16::gpu_sqrt_f16`.
#[cfg(feature = "cuda")]
fn sqrt_f16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f16(a)?;
    let device = self.device(ordinal)?;
    crate::f16::gpu_sqrt_f16(input, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// ReLU of an f16 buffer via `crate::f16::gpu_relu_f16`.
#[cfg(feature = "cuda")]
fn relu_f16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f16(a)?;
    let device = self.device(ordinal)?;
    crate::f16::gpu_relu_f16(input, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// SiLU of an f16 buffer via `crate::f16::gpu_silu_f16`.
#[cfg(feature = "cuda")]
fn silu_f16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f16(a)?;
    let device = self.device(ordinal)?;
    crate::f16::gpu_silu_f16(input, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// GELU of an f16 buffer via `crate::f16::gpu_gelu_f16`.
#[cfg(feature = "cuda")]
fn gelu_f16(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f16(a)?;
    let device = self.device(ordinal)?;
    crate::f16::gpu_gelu_f16(input, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Softmax over an f16 buffer, forwarding `rows`/`cols` to
/// `crate::f16::gpu_softmax_f16`.
#[cfg(feature = "cuda")]
fn softmax_f16(
    &self,
    a: &GpuBufferHandle,
    rows: usize,
    cols: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let input = Self::unwrap_buffer_f16(a)?;
    let device = self.device(ordinal)?;
    crate::f16::gpu_softmax_f16(input, rows, cols, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Layer normalisation of an f16 buffer with `gamma`/`beta` parameters,
/// delegating to `crate::f16::gpu_layernorm_f16`. The output is tagged with
/// `input`'s device ordinal.
#[cfg(feature = "cuda")]
fn layernorm_f16(
    &self,
    input: &GpuBufferHandle,
    gamma: &GpuBufferHandle,
    beta: &GpuBufferHandle,
    rows: usize,
    cols: usize,
    eps: f32,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let x = Self::unwrap_buffer_f16(input)?;
    let weight = Self::unwrap_buffer_f16(gamma)?;
    let bias = Self::unwrap_buffer_f16(beta)?;
    let device = self.device(ordinal)?;
    crate::f16::gpu_layernorm_f16(x, weight, bias, rows, cols, eps, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// RMS normalisation of an f16 buffer scaled by `weight`, delegating to
/// `crate::f16::gpu_rmsnorm_f16`.
#[cfg(feature = "cuda")]
fn rmsnorm_f16(
    &self,
    input: &GpuBufferHandle,
    weight: &GpuBufferHandle,
    rows: usize,
    cols: usize,
    eps: f32,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = input.device_ordinal();
    let (x, w) = (Self::unwrap_buffer_f16(input)?, Self::unwrap_buffer_f16(weight)?);
    let device = self.device(ordinal)?;
    crate::f16::gpu_rmsnorm_f16(x, w, rows, cols, eps, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// f16 × f16 matmul delegating to `crate::blas::gpu_matmul_f16_f16` with
/// dimensions `m`, `k`, `n`.
#[cfg(feature = "cuda")]
fn matmul_f16_f16(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    m: usize,
    k: usize,
    n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let (lhs, rhs) = (Self::unwrap_buffer_f16(a)?, Self::unwrap_buffer_f16(b)?);
    let device = self.device(ordinal)?;
    crate::blas::gpu_matmul_f16_f16(lhs, rhs, m, k, n, device)
        .map(|out| Self::wrap_buffer_f16(out, ordinal))
        .map_err(Self::map_gpu_err)
}
/// Integer addition for I32/I64 buffers, dispatched on `a`'s dtype to
/// `crate::int_kernels::gpu_add_{i32,i64}`. `b` is unwrapped with the same
/// dtype as `a`.
///
/// # Errors
/// `NotImplementedOnCuda` for any other dtype of `a`.
#[cfg(feature = "cuda")]
fn int_add(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let device = self.device(ordinal)?;
    match a.dtype() {
        DType::I64 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i64(a)?, Self::unwrap_buffer_i64(b)?);
            crate::int_kernels::gpu_add_i64(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i64(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        DType::I32 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i32(a)?, Self::unwrap_buffer_i32(b)?);
            crate::int_kernels::gpu_add_i32(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i32(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        _ => Err(FerrotorchError::NotImplementedOnCuda { op: "int_add" }),
    }
}
/// Integer subtraction for I32/I64 buffers, dispatched on `a`'s dtype to
/// `crate::int_kernels::gpu_sub_{i32,i64}`.
///
/// # Errors
/// `NotImplementedOnCuda` for any other dtype of `a`.
#[cfg(feature = "cuda")]
fn int_sub(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let device = self.device(ordinal)?;
    match a.dtype() {
        DType::I64 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i64(a)?, Self::unwrap_buffer_i64(b)?);
            crate::int_kernels::gpu_sub_i64(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i64(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        DType::I32 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i32(a)?, Self::unwrap_buffer_i32(b)?);
            crate::int_kernels::gpu_sub_i32(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i32(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        _ => Err(FerrotorchError::NotImplementedOnCuda { op: "int_sub" }),
    }
}
/// Integer multiplication for I32/I64 buffers, dispatched on `a`'s dtype to
/// `crate::int_kernels::gpu_mul_{i32,i64}`.
///
/// # Errors
/// `NotImplementedOnCuda` for any other dtype of `a`.
#[cfg(feature = "cuda")]
fn int_mul(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let device = self.device(ordinal)?;
    match a.dtype() {
        DType::I64 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i64(a)?, Self::unwrap_buffer_i64(b)?);
            crate::int_kernels::gpu_mul_i64(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i64(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        DType::I32 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i32(a)?, Self::unwrap_buffer_i32(b)?);
            crate::int_kernels::gpu_mul_i32(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i32(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        _ => Err(FerrotorchError::NotImplementedOnCuda { op: "int_mul" }),
    }
}
/// Integer negation for I32/I64 buffers via
/// `crate::int_kernels::gpu_neg_{i32,i64}`.
///
/// # Errors
/// `NotImplementedOnCuda` for any other dtype.
#[cfg(feature = "cuda")]
fn int_neg(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let device = self.device(ordinal)?;
    match a.dtype() {
        DType::I64 => {
            let input = Self::unwrap_buffer_i64(a)?;
            crate::int_kernels::gpu_neg_i64(input.inner(), device)
                .map(|out| Self::wrap_slice_i64(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        DType::I32 => {
            let input = Self::unwrap_buffer_i32(a)?;
            crate::int_kernels::gpu_neg_i32(input.inner(), device)
                .map(|out| Self::wrap_slice_i32(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        _ => Err(FerrotorchError::NotImplementedOnCuda { op: "int_neg" }),
    }
}
/// Integer floor division for I32/I64 buffers, dispatched on `a`'s dtype to
/// `crate::int_kernels::gpu_floor_div_{i32,i64}`.
///
/// # Errors
/// `NotImplementedOnCuda` for any other dtype of `a`.
#[cfg(feature = "cuda")]
fn int_floor_div(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let device = self.device(ordinal)?;
    match a.dtype() {
        DType::I64 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i64(a)?, Self::unwrap_buffer_i64(b)?);
            crate::int_kernels::gpu_floor_div_i64(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i64(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        DType::I32 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i32(a)?, Self::unwrap_buffer_i32(b)?);
            crate::int_kernels::gpu_floor_div_i32(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i32(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        _ => Err(FerrotorchError::NotImplementedOnCuda {
            op: "int_floor_div",
        }),
    }
}
/// Integer remainder for I32/I64 buffers, dispatched on `a`'s dtype to
/// `crate::int_kernels::gpu_remainder_{i32,i64}`.
///
/// # Errors
/// `NotImplementedOnCuda` for any other dtype of `a`.
#[cfg(feature = "cuda")]
fn int_remainder(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let device = self.device(ordinal)?;
    match a.dtype() {
        DType::I64 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i64(a)?, Self::unwrap_buffer_i64(b)?);
            crate::int_kernels::gpu_remainder_i64(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i64(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        DType::I32 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i32(a)?, Self::unwrap_buffer_i32(b)?);
            crate::int_kernels::gpu_remainder_i32(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i32(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        _ => Err(FerrotorchError::NotImplementedOnCuda {
            op: "int_remainder",
        }),
    }
}
/// Bitwise AND for I32/I64 buffers, dispatched on `a`'s dtype to
/// `crate::int_kernels::gpu_bitand_{i32,i64}`.
///
/// # Errors
/// `NotImplementedOnCuda` for any other dtype of `a`.
#[cfg(feature = "cuda")]
fn int_bitand(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let device = self.device(ordinal)?;
    match a.dtype() {
        DType::I64 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i64(a)?, Self::unwrap_buffer_i64(b)?);
            crate::int_kernels::gpu_bitand_i64(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i64(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        DType::I32 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i32(a)?, Self::unwrap_buffer_i32(b)?);
            crate::int_kernels::gpu_bitand_i32(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i32(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        _ => Err(FerrotorchError::NotImplementedOnCuda { op: "int_bitand" }),
    }
}
/// Bitwise OR for I32/I64 buffers, dispatched on `a`'s dtype to
/// `crate::int_kernels::gpu_bitor_{i32,i64}`.
///
/// # Errors
/// `NotImplementedOnCuda` for any other dtype of `a`.
#[cfg(feature = "cuda")]
fn int_bitor(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let device = self.device(ordinal)?;
    match a.dtype() {
        DType::I64 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i64(a)?, Self::unwrap_buffer_i64(b)?);
            crate::int_kernels::gpu_bitor_i64(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i64(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        DType::I32 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i32(a)?, Self::unwrap_buffer_i32(b)?);
            crate::int_kernels::gpu_bitor_i32(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i32(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        _ => Err(FerrotorchError::NotImplementedOnCuda { op: "int_bitor" }),
    }
}
/// Bitwise XOR for I32/I64 buffers, dispatched on `a`'s dtype to
/// `crate::int_kernels::gpu_bitxor_{i32,i64}`.
///
/// # Errors
/// `NotImplementedOnCuda` for any other dtype of `a`.
#[cfg(feature = "cuda")]
fn int_bitxor(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let device = self.device(ordinal)?;
    match a.dtype() {
        DType::I64 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i64(a)?, Self::unwrap_buffer_i64(b)?);
            crate::int_kernels::gpu_bitxor_i64(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i64(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        DType::I32 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i32(a)?, Self::unwrap_buffer_i32(b)?);
            crate::int_kernels::gpu_bitxor_i32(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i32(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        _ => Err(FerrotorchError::NotImplementedOnCuda { op: "int_bitxor" }),
    }
}
/// Bitwise NOT for I32/I64 buffers via
/// `crate::int_kernels::gpu_bitnot_{i32,i64}`.
///
/// # Errors
/// `NotImplementedOnCuda` for any other dtype.
#[cfg(feature = "cuda")]
fn int_bitnot(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let device = self.device(ordinal)?;
    match a.dtype() {
        DType::I64 => {
            let input = Self::unwrap_buffer_i64(a)?;
            crate::int_kernels::gpu_bitnot_i64(input.inner(), device)
                .map(|out| Self::wrap_slice_i64(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        DType::I32 => {
            let input = Self::unwrap_buffer_i32(a)?;
            crate::int_kernels::gpu_bitnot_i32(input.inner(), device)
                .map(|out| Self::wrap_slice_i32(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        _ => Err(FerrotorchError::NotImplementedOnCuda { op: "int_bitnot" }),
    }
}
/// Left shift of `a` by `b` for I32/I64 buffers, dispatched on `a`'s dtype to
/// `crate::int_kernels::gpu_shl_{i32,i64}`.
///
/// # Errors
/// `NotImplementedOnCuda` for any other dtype of `a`.
#[cfg(feature = "cuda")]
fn int_shl(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let device = self.device(ordinal)?;
    match a.dtype() {
        DType::I64 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i64(a)?, Self::unwrap_buffer_i64(b)?);
            crate::int_kernels::gpu_shl_i64(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i64(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        DType::I32 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i32(a)?, Self::unwrap_buffer_i32(b)?);
            crate::int_kernels::gpu_shl_i32(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i32(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        _ => Err(FerrotorchError::NotImplementedOnCuda { op: "int_shl" }),
    }
}
/// Right shift of `a` by `b` for I32/I64 buffers, dispatched on `a`'s dtype
/// to `crate::int_kernels::gpu_shr_{i32,i64}`.
///
/// # Errors
/// `NotImplementedOnCuda` for any other dtype of `a`.
#[cfg(feature = "cuda")]
fn int_shr(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let device = self.device(ordinal)?;
    match a.dtype() {
        DType::I64 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i64(a)?, Self::unwrap_buffer_i64(b)?);
            crate::int_kernels::gpu_shr_i64(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i64(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        DType::I32 => {
            let (lhs, rhs) = (Self::unwrap_buffer_i32(a)?, Self::unwrap_buffer_i32(b)?);
            crate::int_kernels::gpu_shr_i32(lhs.inner(), rhs.inner(), device)
                .map(|out| Self::wrap_slice_i32(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        _ => Err(FerrotorchError::NotImplementedOnCuda { op: "int_shr" }),
    }
}
/// Full sum of an I32/I64 buffer via `crate::int_kernels::gpu_sum_{i32,i64}`.
/// The result keeps the input's integer width.
///
/// # Errors
/// `NotImplementedOnCuda` for any other dtype.
#[cfg(feature = "cuda")]
fn int_sum(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let device = self.device(ordinal)?;
    match a.dtype() {
        DType::I64 => {
            let input = Self::unwrap_buffer_i64(a)?;
            crate::int_kernels::gpu_sum_i64(input.inner(), device)
                .map(|out| Self::wrap_slice_i64(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        DType::I32 => {
            let input = Self::unwrap_buffer_i32(a)?;
            crate::int_kernels::gpu_sum_i32(input.inner(), device)
                .map(|out| Self::wrap_slice_i32(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        _ => Err(FerrotorchError::NotImplementedOnCuda { op: "int_sum" }),
    }
}
/// Full product of an I32/I64 buffer via
/// `crate::int_kernels::gpu_prod_{i32,i64}`.
///
/// # Errors
/// `NotImplementedOnCuda` for any other dtype.
#[cfg(feature = "cuda")]
fn int_prod(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let device = self.device(ordinal)?;
    match a.dtype() {
        DType::I64 => {
            let input = Self::unwrap_buffer_i64(a)?;
            crate::int_kernels::gpu_prod_i64(input.inner(), device)
                .map(|out| Self::wrap_slice_i64(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        DType::I32 => {
            let input = Self::unwrap_buffer_i32(a)?;
            crate::int_kernels::gpu_prod_i32(input.inner(), device)
                .map(|out| Self::wrap_slice_i32(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        _ => Err(FerrotorchError::NotImplementedOnCuda { op: "int_prod" }),
    }
}
/// Minimum of an I32/I64 buffer via `crate::int_kernels::gpu_min_{i32,i64}`.
///
/// # Errors
/// `NotImplementedOnCuda` for any other dtype.
#[cfg(feature = "cuda")]
fn int_min(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let device = self.device(ordinal)?;
    match a.dtype() {
        DType::I64 => {
            let input = Self::unwrap_buffer_i64(a)?;
            crate::int_kernels::gpu_min_i64(input.inner(), device)
                .map(|out| Self::wrap_slice_i64(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        DType::I32 => {
            let input = Self::unwrap_buffer_i32(a)?;
            crate::int_kernels::gpu_min_i32(input.inner(), device)
                .map(|out| Self::wrap_slice_i32(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        _ => Err(FerrotorchError::NotImplementedOnCuda { op: "int_min" }),
    }
}
/// Maximum of an I32/I64 buffer via `crate::int_kernels::gpu_max_{i32,i64}`.
///
/// # Errors
/// `NotImplementedOnCuda` for any other dtype.
#[cfg(feature = "cuda")]
fn int_max(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let device = self.device(ordinal)?;
    match a.dtype() {
        DType::I64 => {
            let input = Self::unwrap_buffer_i64(a)?;
            crate::int_kernels::gpu_max_i64(input.inner(), device)
                .map(|out| Self::wrap_slice_i64(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        DType::I32 => {
            let input = Self::unwrap_buffer_i32(a)?;
            crate::int_kernels::gpu_max_i32(input.inner(), device)
                .map(|out| Self::wrap_slice_i32(out, ordinal))
                .map_err(Self::map_gpu_err)
        }
        _ => Err(FerrotorchError::NotImplementedOnCuda { op: "int_max" }),
    }
}
/// Arg-max reduction dispatched by `src`'s dtype to
/// `crate::reduce_arg::gpu_argmax_*`.
///
/// `outer`, `dim_size` and `inner` are forwarded unchanged. They are
/// presumably the sizes before, along and after the reduced axis; confirm in
/// `reduce_arg`. Supported inputs are F32, F64, F16, BF16, I32 and I64. The
/// indices are always returned as an I64 buffer on `src`'s device.
///
/// # Errors
/// `InvalidArgument` for any other dtype; kernel errors via `map_gpu_err`.
#[cfg(feature = "cuda")]
fn argmax(
    &self,
    src: &GpuBufferHandle,
    outer: usize,
    dim_size: usize,
    inner: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    use crate::reduce_arg as kernels;

    let ordinal = src.device_ordinal();
    let device = self.device(ordinal)?;
    let launched = match src.dtype() {
        DType::F32 => kernels::gpu_argmax_f32(
            Self::unwrap_buffer(src)?.inner(),
            outer,
            dim_size,
            inner,
            device,
        ),
        DType::F64 => kernels::gpu_argmax_f64(
            Self::unwrap_buffer_f64(src)?.inner(),
            outer,
            dim_size,
            inner,
            device,
        ),
        DType::F16 => kernels::gpu_argmax_f16(
            Self::unwrap_buffer_f16(src)?,
            outer,
            dim_size,
            inner,
            device,
        ),
        DType::BF16 => kernels::gpu_argmax_bf16(
            Self::unwrap_buffer_bf16(src)?,
            outer,
            dim_size,
            inner,
            device,
        ),
        DType::I32 => kernels::gpu_argmax_i32(
            Self::unwrap_buffer_i32(src)?.inner(),
            outer,
            dim_size,
            inner,
            device,
        ),
        DType::I64 => kernels::gpu_argmax_i64(
            Self::unwrap_buffer_i64(src)?.inner(),
            outer,
            dim_size,
            inner,
            device,
        ),
        other => {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("argmax: unsupported value dtype {other}"),
            });
        }
    };
    let indices = launched.map_err(Self::map_gpu_err)?;
    Ok(Self::wrap_slice_i64(indices, ordinal))
}
/// Arg-min reduction, dispatched to `reduce_arg::gpu_argmin_*` per value dtype.
///
/// `outer`, `dim_size` and `inner` are forwarded unchanged to the kernel.
/// Judging by the names, the buffer is presumably viewed as
/// `[outer, dim_size, inner]` and reduced over the middle axis; the kernels
/// define the exact layout. Indices are always wrapped as an `I64` buffer on
/// `src`'s device ordinal.
///
/// # Errors
/// * `InvalidArgument` for value dtypes other than F32/F64/F16/BF16/I32/I64.
/// * Device-lookup, buffer-unwrap and kernel errors are propagated.
#[cfg(feature = "cuda")]
fn argmin(
    &self,
    src: &GpuBufferHandle,
    outer: usize,
    dim_size: usize,
    inner: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    use crate::reduce_arg as ra;
    let ordinal = src.device_ordinal();
    let device = self.device(ordinal)?;
    let launched = match src.dtype() {
        DType::F32 => {
            let values = Self::unwrap_buffer(src)?;
            ra::gpu_argmin_f32(values.inner(), outer, dim_size, inner, device)
        }
        DType::F64 => {
            let values = Self::unwrap_buffer_f64(src)?;
            ra::gpu_argmin_f64(values.inner(), outer, dim_size, inner, device)
        }
        DType::F16 => {
            let values = Self::unwrap_buffer_f16(src)?;
            ra::gpu_argmin_f16(values, outer, dim_size, inner, device)
        }
        DType::BF16 => {
            let values = Self::unwrap_buffer_bf16(src)?;
            ra::gpu_argmin_bf16(values, outer, dim_size, inner, device)
        }
        DType::I32 => {
            let values = Self::unwrap_buffer_i32(src)?;
            ra::gpu_argmin_i32(values.inner(), outer, dim_size, inner, device)
        }
        DType::I64 => {
            let values = Self::unwrap_buffer_i64(src)?;
            ra::gpu_argmin_i64(values.inner(), outer, dim_size, inner, device)
        }
        unsupported => {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("argmin: unsupported value dtype {unsupported}"),
            });
        }
    };
    let indices = launched.map_err(Self::map_gpu_err)?;
    Ok(Self::wrap_slice_i64(indices, ordinal))
}
/// `index_select` driven by an integer index buffer.
///
/// Forwards every argument unchanged to the shared `gather_or_select` helper,
/// with its trailing flag set to `false`. The sibling `gather_intidx` passes
/// `true`.
#[cfg(feature = "cuda")]
fn index_select_intidx(
    &self,
    src: &GpuBufferHandle,
    index: &GpuBufferHandle,
    outer: usize,
    in_dim: usize,
    out_dim: usize,
    inner: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    // `false` selects the index_select flavour of the shared helper.
    let as_gather = false;
    self.gather_or_select(src, index, outer, in_dim, out_dim, inner, as_gather)
}
/// `gather` driven by an integer index buffer.
///
/// Forwards every argument unchanged to the shared `gather_or_select` helper,
/// with its trailing flag set to `true`. The sibling `index_select_intidx`
/// passes `false`.
#[cfg(feature = "cuda")]
fn gather_intidx(
    &self,
    src: &GpuBufferHandle,
    index: &GpuBufferHandle,
    outer: usize,
    in_dim: usize,
    out_dim: usize,
    inner: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    // `true` selects the gather flavour of the shared helper.
    let as_gather = true;
    self.gather_or_select(src, index, outer, in_dim, out_dim, inner, as_gather)
}
/// Casts a floating-point buffer (`F32`, `F64`, `F16`, `BF16`) to `I32` or
/// `I64` using the `cast_kernels::cast_*_to_i*` kernels.
///
/// The kernels receive `src.len()` as the element count, and the result is
/// tagged with `src`'s device ordinal. Rounding, NaN and overflow handling are
/// whatever the kernels implement.
///
/// # Errors
/// `InvalidArgument` for any other `(src dtype, dst)` pair. Device-lookup,
/// unwrap and kernel errors are propagated.
#[cfg(feature = "cuda")]
fn cast_f_to_i(&self, src: &GpuBufferHandle, dst: DType) -> FerrotorchResult<GpuBufferHandle> {
    use crate::cast_kernels as ck;
    let ordinal = src.device_ordinal();
    let device = self.device(ordinal)?;
    let count = src.len();
    match (src.dtype(), dst) {
        (DType::F32, DType::I32) => {
            let out = ck::cast_f32_to_i32(Self::unwrap_buffer(src)?.inner(), count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_i32(out, ordinal))
        }
        (DType::F32, DType::I64) => {
            let out = ck::cast_f32_to_i64(Self::unwrap_buffer(src)?.inner(), count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_i64(out, ordinal))
        }
        (DType::F64, DType::I32) => {
            let out = ck::cast_f64_to_i32(Self::unwrap_buffer_f64(src)?.inner(), count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_i32(out, ordinal))
        }
        (DType::F64, DType::I64) => {
            let out = ck::cast_f64_to_i64(Self::unwrap_buffer_f64(src)?.inner(), count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_i64(out, ordinal))
        }
        (DType::F16, DType::I32) => {
            let out = ck::cast_f16_to_i32(Self::unwrap_buffer_f16(src)?, count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_i32(out, ordinal))
        }
        (DType::F16, DType::I64) => {
            let out = ck::cast_f16_to_i64(Self::unwrap_buffer_f16(src)?, count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_i64(out, ordinal))
        }
        (DType::BF16, DType::I32) => {
            let out = ck::cast_bf16_to_i32(Self::unwrap_buffer_bf16(src)?, count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_i32(out, ordinal))
        }
        (DType::BF16, DType::I64) => {
            let out = ck::cast_bf16_to_i64(Self::unwrap_buffer_bf16(src)?, count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_i64(out, ordinal))
        }
        (from, to) => Err(FerrotorchError::InvalidArgument {
            message: format!("cast_f_to_i: unsupported {from} -> {to}"),
        }),
    }
}
/// Casts an integer buffer (`I32` / `I64`) to `F32`, `F64`, `F16` or `BF16`
/// using the `cast_kernels::cast_i*_to_*` kernels.
///
/// The kernels receive `src.len()` as the element count, and the result is
/// tagged with `src`'s device ordinal.
///
/// # Errors
/// `InvalidArgument` for any other `(src dtype, dst)` pair. Device-lookup,
/// unwrap and kernel errors are propagated.
#[cfg(feature = "cuda")]
fn cast_i_to_f(&self, src: &GpuBufferHandle, dst: DType) -> FerrotorchResult<GpuBufferHandle> {
    use crate::cast_kernels as ck;
    let ordinal = src.device_ordinal();
    let device = self.device(ordinal)?;
    let count = src.len();
    match (src.dtype(), dst) {
        (DType::I32, DType::F32) => {
            let out = ck::cast_i32_to_f32(Self::unwrap_buffer_i32(src)?.inner(), count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_f32(out, ordinal))
        }
        (DType::I32, DType::F64) => {
            let out = ck::cast_i32_to_f64(Self::unwrap_buffer_i32(src)?.inner(), count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_f64(out, ordinal))
        }
        (DType::I32, DType::F16) => {
            let out = ck::cast_i32_to_f16(Self::unwrap_buffer_i32(src)?.inner(), count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_buffer_f16(out, ordinal))
        }
        (DType::I32, DType::BF16) => {
            let out = ck::cast_i32_to_bf16(Self::unwrap_buffer_i32(src)?.inner(), count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_buffer_bf16(out, ordinal))
        }
        (DType::I64, DType::F32) => {
            let out = ck::cast_i64_to_f32(Self::unwrap_buffer_i64(src)?.inner(), count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_f32(out, ordinal))
        }
        (DType::I64, DType::F64) => {
            let out = ck::cast_i64_to_f64(Self::unwrap_buffer_i64(src)?.inner(), count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_f64(out, ordinal))
        }
        (DType::I64, DType::F16) => {
            let out = ck::cast_i64_to_f16(Self::unwrap_buffer_i64(src)?.inner(), count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_buffer_f16(out, ordinal))
        }
        (DType::I64, DType::BF16) => {
            let out = ck::cast_i64_to_bf16(Self::unwrap_buffer_i64(src)?.inner(), count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_buffer_bf16(out, ordinal))
        }
        (from, to) => Err(FerrotorchError::InvalidArgument {
            message: format!("cast_i_to_f: unsupported {from} -> {to}"),
        }),
    }
}
/// Converts between integer dtypes (`I32` <-> `I64`).
///
/// A same-dtype request goes through the `cast_*_copy` kernels, which are
/// presumably a plain device copy. The kernels receive `src.len()` as the
/// element count, and the result is tagged with `src`'s device ordinal.
///
/// # Errors
/// `InvalidArgument` for any other `(src dtype, dst)` pair. Device-lookup,
/// unwrap and kernel errors are propagated.
#[cfg(feature = "cuda")]
fn cast_i_to_i(&self, src: &GpuBufferHandle, dst: DType) -> FerrotorchResult<GpuBufferHandle> {
    use crate::cast_kernels as ck;
    let ordinal = src.device_ordinal();
    let device = self.device(ordinal)?;
    let count = src.len();
    match (src.dtype(), dst) {
        (DType::I32, DType::I32) => {
            let out = ck::cast_i32_copy(Self::unwrap_buffer_i32(src)?.inner(), count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_i32(out, ordinal))
        }
        (DType::I64, DType::I64) => {
            let out = ck::cast_i64_copy(Self::unwrap_buffer_i64(src)?.inner(), count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_i64(out, ordinal))
        }
        (DType::I32, DType::I64) => {
            let out = ck::cast_i32_to_i64(Self::unwrap_buffer_i32(src)?.inner(), count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_i64(out, ordinal))
        }
        (DType::I64, DType::I32) => {
            let out = ck::cast_i64_to_i32(Self::unwrap_buffer_i64(src)?.inner(), count, device)
                .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_i32(out, ordinal))
        }
        (from, to) => Err(FerrotorchError::InvalidArgument {
            message: format!("cast_i_to_i: unsupported {from} -> {to}"),
        }),
    }
}
/// Elementwise comparison of two buffers with the same dtype and length.
///
/// `op.suffix()` is passed to the `bool_kernels::gpu_cmp_*` kernel
/// (presumably selecting the comparison variant). The result is wrapped as a
/// `Bool` buffer on `a`'s device ordinal. Supported dtypes: F32, F64, I32,
/// I64, BF16, F16.
///
/// # Errors
/// * `InvalidArgument` if the operand dtypes or element counts differ.
/// * `NotImplementedOnCuda { op: "compare" }` for any other dtype.
/// * Device-lookup, buffer-unwrap and kernel errors are propagated.
#[cfg(feature = "cuda")]
fn compare(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
    op: ferrotorch_core::gpu_dispatch::CompareOp,
) -> FerrotorchResult<GpuBufferHandle> {
    use crate::bool_kernels as bk;
    if a.dtype() != b.dtype() {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "compare: operand dtypes differ ({} vs {})",
                a.dtype(),
                b.dtype()
            ),
        });
    }
    // The kernels receive only the two raw slices, and nothing else in this
    // method checks their lengths. Reject a mismatch up front, as
    // `where_cond` / `masked_fill_dt` do, instead of handing it to the GPU.
    if a.len() != b.len() {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "compare: operand numel differ ({} vs {})",
                a.len(),
                b.len()
            ),
        });
    }
    let dev = self.device(a.device_ordinal())?;
    let ord = a.device_ordinal();
    let suffix = op.suffix();
    let r = match a.dtype() {
        DType::F32 => bk::gpu_cmp_f32(
            Self::unwrap_buffer(a)?.inner(),
            Self::unwrap_buffer(b)?.inner(),
            suffix,
            dev,
        ),
        DType::F64 => bk::gpu_cmp_f64(
            Self::unwrap_buffer_f64(a)?.inner(),
            Self::unwrap_buffer_f64(b)?.inner(),
            suffix,
            dev,
        ),
        DType::I32 => bk::gpu_cmp_i32(
            Self::unwrap_buffer_i32(a)?.inner(),
            Self::unwrap_buffer_i32(b)?.inner(),
            suffix,
            dev,
        ),
        DType::I64 => bk::gpu_cmp_i64(
            Self::unwrap_buffer_i64(a)?.inner(),
            Self::unwrap_buffer_i64(b)?.inner(),
            suffix,
            dev,
        ),
        DType::BF16 => bk::gpu_cmp_bf16(
            Self::unwrap_buffer_bf16(a)?,
            Self::unwrap_buffer_bf16(b)?,
            suffix,
            dev,
        ),
        DType::F16 => bk::gpu_cmp_f16(
            Self::unwrap_buffer_f16(a)?,
            Self::unwrap_buffer_f16(b)?,
            suffix,
            dev,
        ),
        _ => return Err(FerrotorchError::NotImplementedOnCuda { op: "compare" }),
    };
    Ok(Self::wrap_slice_bool(r.map_err(Self::map_gpu_err)?, ord))
}
/// Elementwise logical AND of two bool buffers (`bool_kernels::gpu_and_bool`).
///
/// The result is wrapped as a `Bool` buffer on `a`'s device ordinal.
///
/// # Errors
/// * Device-lookup and `unwrap_buffer_bool` errors are propagated (the
///   latter presumably when an operand is not stored as a bool buffer).
/// * `InvalidArgument` if the operands have different element counts.
/// * Kernel errors are mapped through `map_gpu_err`.
#[cfg(feature = "cuda")]
fn bool_and(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let dev = self.device(a.device_ordinal())?;
    let lhs = Self::unwrap_buffer_bool(a)?;
    let rhs = Self::unwrap_buffer_bool(b)?;
    // Only the raw slices reach the kernel, so validate lengths here. This runs
    // after the unwraps, so their errors keep taking precedence.
    if a.len() != b.len() {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "bool_and: operand numel differ ({} vs {})",
                a.len(),
                b.len()
            ),
        });
    }
    let r = crate::bool_kernels::gpu_and_bool(lhs.inner(), rhs.inner(), dev)
        .map_err(Self::map_gpu_err)?;
    Ok(Self::wrap_slice_bool(r, a.device_ordinal()))
}
/// Elementwise logical OR of two bool buffers (`bool_kernels::gpu_or_bool`).
///
/// The result is wrapped as a `Bool` buffer on `a`'s device ordinal.
///
/// # Errors
/// * Device-lookup and `unwrap_buffer_bool` errors are propagated (the
///   latter presumably when an operand is not stored as a bool buffer).
/// * `InvalidArgument` if the operands have different element counts.
/// * Kernel errors are mapped through `map_gpu_err`.
#[cfg(feature = "cuda")]
fn bool_or(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let dev = self.device(a.device_ordinal())?;
    let lhs = Self::unwrap_buffer_bool(a)?;
    let rhs = Self::unwrap_buffer_bool(b)?;
    // Only the raw slices reach the kernel, so validate lengths here. This runs
    // after the unwraps, so their errors keep taking precedence.
    if a.len() != b.len() {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "bool_or: operand numel differ ({} vs {})",
                a.len(),
                b.len()
            ),
        });
    }
    let r = crate::bool_kernels::gpu_or_bool(lhs.inner(), rhs.inner(), dev)
        .map_err(Self::map_gpu_err)?;
    Ok(Self::wrap_slice_bool(r, a.device_ordinal()))
}
/// Elementwise logical XOR of two bool buffers (`bool_kernels::gpu_xor_bool`).
///
/// The result is wrapped as a `Bool` buffer on `a`'s device ordinal.
///
/// # Errors
/// * Device-lookup and `unwrap_buffer_bool` errors are propagated (the
///   latter presumably when an operand is not stored as a bool buffer).
/// * `InvalidArgument` if the operands have different element counts.
/// * Kernel errors are mapped through `map_gpu_err`.
#[cfg(feature = "cuda")]
fn bool_xor(
    &self,
    a: &GpuBufferHandle,
    b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    let dev = self.device(a.device_ordinal())?;
    let lhs = Self::unwrap_buffer_bool(a)?;
    let rhs = Self::unwrap_buffer_bool(b)?;
    // Only the raw slices reach the kernel, so validate lengths here. This runs
    // after the unwraps, so their errors keep taking precedence.
    if a.len() != b.len() {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "bool_xor: operand numel differ ({} vs {})",
                a.len(),
                b.len()
            ),
        });
    }
    let r = crate::bool_kernels::gpu_xor_bool(lhs.inner(), rhs.inner(), dev)
        .map_err(Self::map_gpu_err)?;
    Ok(Self::wrap_slice_bool(r, a.device_ordinal()))
}
/// Elementwise logical NOT of a bool buffer (`bool_kernels::gpu_not_bool`).
///
/// The result is wrapped as a `Bool` buffer on `a`'s device ordinal.
/// Device-lookup, unwrap and kernel errors are propagated.
#[cfg(feature = "cuda")]
fn bool_not(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let device = self.device(ordinal)?;
    let bits = Self::unwrap_buffer_bool(a)?;
    let negated =
        crate::bool_kernels::gpu_not_bool(bits.inner(), device).map_err(Self::map_gpu_err)?;
    Ok(Self::wrap_slice_bool(negated, ordinal))
}
/// "Any" reduction over a bool buffer (`bool_kernels::gpu_any_bool`).
///
/// The kernel output is wrapped as a `Bool` buffer on `a`'s device ordinal.
/// Its length is decided by the kernel (presumably a single element).
/// Device-lookup, unwrap and kernel errors are propagated.
#[cfg(feature = "cuda")]
fn bool_any(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let device = self.device(ordinal)?;
    let bits = Self::unwrap_buffer_bool(a)?;
    let reduced =
        crate::bool_kernels::gpu_any_bool(bits.inner(), device).map_err(Self::map_gpu_err)?;
    Ok(Self::wrap_slice_bool(reduced, ordinal))
}
/// "All" reduction over a bool buffer (`bool_kernels::gpu_all_bool`).
///
/// The kernel output is wrapped as a `Bool` buffer on `a`'s device ordinal.
/// Its length is decided by the kernel (presumably a single element).
/// Device-lookup, unwrap and kernel errors are propagated.
#[cfg(feature = "cuda")]
fn bool_all(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
    let ordinal = a.device_ordinal();
    let device = self.device(ordinal)?;
    let bits = Self::unwrap_buffer_bool(a)?;
    let reduced =
        crate::bool_kernels::gpu_all_bool(bits.inner(), device).map_err(Self::map_gpu_err)?;
    Ok(Self::wrap_slice_bool(reduced, ordinal))
}
/// Casts a `Bool` buffer to `F32`, `F64`, `F16` or `BF16` via the
/// `cast_kernels::cast_bool_to_*` kernels.
///
/// The kernels receive `src.len()` as the element count, and the result is
/// tagged with `src`'s device ordinal.
///
/// # Errors
/// * `InvalidArgument` if `src` is not tagged `Bool` (checked before any
///   device work), or if `dst` is not one of the four float dtypes.
/// * Device-lookup, unwrap and kernel errors are propagated.
#[cfg(feature = "cuda")]
fn cast_bool_to_f(
    &self,
    src: &GpuBufferHandle,
    dst: DType,
) -> FerrotorchResult<GpuBufferHandle> {
    use crate::cast_kernels as ck;
    let tag = src.dtype();
    if tag != DType::Bool {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("cast_bool_to_f: src is tagged {tag}, expected Bool"),
        });
    }
    let ordinal = src.device_ordinal();
    let device = self.device(ordinal)?;
    let flags = Self::unwrap_buffer_bool(src)?.inner();
    let count = src.len();
    let converted = match dst {
        DType::F32 => Self::wrap_slice_f32(
            ck::cast_bool_to_f32(flags, count, device).map_err(Self::map_gpu_err)?,
            ordinal,
        ),
        DType::F64 => Self::wrap_slice_f64(
            ck::cast_bool_to_f64(flags, count, device).map_err(Self::map_gpu_err)?,
            ordinal,
        ),
        DType::F16 => Self::wrap_buffer_f16(
            ck::cast_bool_to_f16(flags, count, device).map_err(Self::map_gpu_err)?,
            ordinal,
        ),
        DType::BF16 => Self::wrap_buffer_bf16(
            ck::cast_bool_to_bf16(flags, count, device).map_err(Self::map_gpu_err)?,
            ordinal,
        ),
        other => {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("cast_bool_to_f: unsupported Bool -> {other}"),
            });
        }
    };
    Ok(converted)
}
/// Dtype-generic masked fill via `masked_kernels::masked_fill_*`.
///
/// `value` is narrowed to the element type with `as`. That gives `f32` for
/// F32/F16/BF16 inputs and `i32`/`i64` for the integer inputs; float-to-int
/// `as` saturates and maps NaN to 0. The result keeps the input dtype and is
/// tagged with `input`'s device ordinal.
///
/// # Errors
/// * `InvalidArgument` if `mask` is not tagged `Bool`, or if `input` and `mask`
///   have different element counts. Both checks run before any device work.
/// * `NotImplementedOnCuda { op: "masked_fill_dt" }` for unsupported input
///   dtypes.
/// * Device-lookup, unwrap and kernel errors are propagated.
#[cfg(feature = "cuda")]
fn masked_fill_dt(
    &self,
    input: &GpuBufferHandle,
    mask: &GpuBufferHandle,
    value: f64,
) -> FerrotorchResult<GpuBufferHandle> {
    use crate::masked_kernels as mk;
    let mask_tag = mask.dtype();
    if mask_tag != DType::Bool {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("masked_fill: mask is tagged {mask_tag}, expected Bool"),
        });
    }
    let (numel, mask_numel) = (input.len(), mask.len());
    if numel != mask_numel {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("masked_fill: input numel {numel} != mask numel {mask_numel}"),
        });
    }
    let ordinal = input.device_ordinal();
    let device = self.device(ordinal)?;
    let mask_bits = Self::unwrap_buffer_bool(mask)?.inner();
    let fill_f32 = value as f32;
    match input.dtype() {
        DType::F32 => {
            let filled =
                mk::masked_fill_f32(Self::unwrap_buffer(input)?.inner(), mask_bits, fill_f32, device)
                    .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_f32(filled, ordinal))
        }
        DType::F64 => {
            let filled =
                mk::masked_fill_f64(Self::unwrap_buffer_f64(input)?.inner(), mask_bits, value, device)
                    .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_f64(filled, ordinal))
        }
        DType::F16 => {
            let filled =
                mk::masked_fill_f16(Self::unwrap_buffer_f16(input)?, mask_bits, fill_f32, device)
                    .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_buffer_f16(filled, ordinal))
        }
        DType::BF16 => {
            let filled =
                mk::masked_fill_bf16(Self::unwrap_buffer_bf16(input)?, mask_bits, fill_f32, device)
                    .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_buffer_bf16(filled, ordinal))
        }
        DType::I32 => {
            let filled = mk::masked_fill_i32(
                Self::unwrap_buffer_i32(input)?.inner(),
                mask_bits,
                value as i32,
                device,
            )
            .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_i32(filled, ordinal))
        }
        DType::I64 => {
            let filled = mk::masked_fill_i64(
                Self::unwrap_buffer_i64(input)?.inner(),
                mask_bits,
                value as i64,
                device,
            )
            .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_i64(filled, ordinal))
        }
        _ => Err(FerrotorchError::NotImplementedOnCuda {
            op: "masked_fill_dt",
        }),
    }
}
/// Elementwise select between `x` and `y` driven by a `Bool` buffer `cond`,
/// via the `masked_kernels::where_*` kernels.
///
/// The result has `x`'s dtype and is tagged with `x`'s device ordinal.
///
/// # Errors
/// * `InvalidArgument` if `cond` is not tagged `Bool`, if `x` and `y` have
///   different dtypes, or if `cond`, `x` and `y` do not all have the same
///   element count. All three checks run before any device work.
/// * `NotImplementedOnCuda { op: "where_cond" }` for unsupported dtypes.
/// * Device-lookup, unwrap and kernel errors are propagated.
#[cfg(feature = "cuda")]
fn where_cond(
    &self,
    cond: &GpuBufferHandle,
    x: &GpuBufferHandle,
    y: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
    use crate::masked_kernels as mk;
    let cond_tag = cond.dtype();
    if cond_tag != DType::Bool {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("where_cond: cond is tagged {cond_tag}, expected Bool"),
        });
    }
    let (x_tag, y_tag) = (x.dtype(), y.dtype());
    if x_tag != y_tag {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("where_cond: x/y dtypes differ ({x_tag} vs {y_tag})"),
        });
    }
    let (cond_n, x_n, y_n) = (cond.len(), x.len(), y.len());
    let lengths_agree = x_n == y_n && x_n == cond_n;
    if !lengths_agree {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("where_cond: numel mismatch (cond {cond_n}, x {x_n}, y {y_n})"),
        });
    }
    let ordinal = x.device_ordinal();
    let device = self.device(ordinal)?;
    let cond_bits = Self::unwrap_buffer_bool(cond)?.inner();
    match x_tag {
        DType::F32 => {
            let picked = mk::where_32::<f32>(
                cond_bits,
                Self::unwrap_buffer(x)?.inner(),
                Self::unwrap_buffer(y)?.inner(),
                device,
            )
            .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_f32(picked, ordinal))
        }
        DType::F64 => {
            let picked = mk::where_64::<f64>(
                cond_bits,
                Self::unwrap_buffer_f64(x)?.inner(),
                Self::unwrap_buffer_f64(y)?.inner(),
                device,
            )
            .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_f64(picked, ordinal))
        }
        DType::F16 => {
            let picked = mk::where_16(
                cond_bits,
                Self::unwrap_buffer_f16(x)?,
                Self::unwrap_buffer_f16(y)?,
                device,
            )
            .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_buffer_f16(picked, ordinal))
        }
        DType::BF16 => {
            let picked = mk::where_16(
                cond_bits,
                Self::unwrap_buffer_bf16(x)?,
                Self::unwrap_buffer_bf16(y)?,
                device,
            )
            .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_buffer_bf16(picked, ordinal))
        }
        DType::I32 => {
            let picked = mk::where_32::<i32>(
                cond_bits,
                Self::unwrap_buffer_i32(x)?.inner(),
                Self::unwrap_buffer_i32(y)?.inner(),
                device,
            )
            .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_i32(picked, ordinal))
        }
        DType::I64 => {
            let picked = mk::where_64::<i64>(
                cond_bits,
                Self::unwrap_buffer_i64(x)?.inner(),
                Self::unwrap_buffer_i64(y)?.inner(),
                device,
            )
            .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_i64(picked, ordinal))
        }
        _ => Err(FerrotorchError::NotImplementedOnCuda { op: "where_cond" }),
    }
}
/// Gathers the elements of `input` selected by a `Bool` mask, via the
/// `masked_kernels::masked_select_*` kernels.
///
/// Returns the output buffer (input dtype, `input`'s device ordinal) together
/// with the `usize` count reported by the kernel.
///
/// # Errors
/// * `InvalidArgument` if `mask` is not tagged `Bool`, or if `input` and `mask`
///   have different element counts. Both checks run before any device work.
/// * `NotImplementedOnCuda { op: "masked_select" }` for unsupported dtypes.
/// * Device-lookup, unwrap and kernel errors are propagated.
#[cfg(feature = "cuda")]
fn masked_select(
    &self,
    input: &GpuBufferHandle,
    mask: &GpuBufferHandle,
) -> FerrotorchResult<(GpuBufferHandle, usize)> {
    use crate::masked_kernels as mk;
    let mask_tag = mask.dtype();
    if mask_tag != DType::Bool {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("masked_select: mask is tagged {mask_tag}, expected Bool"),
        });
    }
    let (numel, mask_numel) = (input.len(), mask.len());
    if numel != mask_numel {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("masked_select: input numel {numel} != mask numel {mask_numel}"),
        });
    }
    let ordinal = input.device_ordinal();
    let device = self.device(ordinal)?;
    let mask_bits = Self::unwrap_buffer_bool(mask)?.inner();
    let (handle, selected) = match input.dtype() {
        DType::F32 => {
            let (out, count) =
                mk::masked_select_32::<f32>(Self::unwrap_buffer(input)?.inner(), mask_bits, device)
                    .map_err(Self::map_gpu_err)?;
            (Self::wrap_slice_f32(out, ordinal), count)
        }
        DType::F64 => {
            let (out, count) = mk::masked_select_64::<f64>(
                Self::unwrap_buffer_f64(input)?.inner(),
                mask_bits,
                device,
            )
            .map_err(Self::map_gpu_err)?;
            (Self::wrap_slice_f64(out, ordinal), count)
        }
        DType::F16 => {
            let (out, count) =
                mk::masked_select_16(Self::unwrap_buffer_f16(input)?, mask_bits, device)
                    .map_err(Self::map_gpu_err)?;
            (Self::wrap_buffer_f16(out, ordinal), count)
        }
        DType::BF16 => {
            let (out, count) =
                mk::masked_select_16(Self::unwrap_buffer_bf16(input)?, mask_bits, device)
                    .map_err(Self::map_gpu_err)?;
            (Self::wrap_buffer_bf16(out, ordinal), count)
        }
        DType::I32 => {
            let (out, count) = mk::masked_select_32::<i32>(
                Self::unwrap_buffer_i32(input)?.inner(),
                mask_bits,
                device,
            )
            .map_err(Self::map_gpu_err)?;
            (Self::wrap_slice_i32(out, ordinal), count)
        }
        DType::I64 => {
            let (out, count) = mk::masked_select_64::<i64>(
                Self::unwrap_buffer_i64(input)?.inner(),
                mask_bits,
                device,
            )
            .map_err(Self::map_gpu_err)?;
            (Self::wrap_slice_i64(out, ordinal), count)
        }
        _ => {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "masked_select",
            });
        }
    };
    Ok((handle, selected))
}
/// Scatters the compact buffer `grad_compact` back into a buffer of
/// `out_numel` elements under a `Bool` mask, via the
/// `masked_kernels::masked_scatter_*` kernels.
///
/// The result has `grad_compact`'s dtype and device ordinal. Which positions
/// the mask leaves unset, and what they hold, is decided by the kernels.
///
/// # Errors
/// * `InvalidArgument` if `mask` is not tagged `Bool`, or if `mask.len()` is
///   not `out_numel`. Both checks run before any device work.
/// * `NotImplementedOnCuda { op: "masked_scatter" }` for unsupported dtypes.
/// * Device-lookup, unwrap and kernel errors are propagated.
#[cfg(feature = "cuda")]
fn masked_scatter(
    &self,
    grad_compact: &GpuBufferHandle,
    mask: &GpuBufferHandle,
    out_numel: usize,
) -> FerrotorchResult<GpuBufferHandle> {
    use crate::masked_kernels as mk;
    let mask_tag = mask.dtype();
    if mask_tag != DType::Bool {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("masked_scatter: mask is tagged {mask_tag}, expected Bool"),
        });
    }
    let mask_numel = mask.len();
    if mask_numel != out_numel {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("masked_scatter: mask numel {mask_numel} != out_numel {out_numel}"),
        });
    }
    let ordinal = grad_compact.device_ordinal();
    let device = self.device(ordinal)?;
    let mask_bits = Self::unwrap_buffer_bool(mask)?.inner();
    match grad_compact.dtype() {
        DType::F32 => {
            let expanded = mk::masked_scatter_32::<f32>(
                Self::unwrap_buffer(grad_compact)?.inner(),
                mask_bits,
                out_numel,
                device,
            )
            .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_f32(expanded, ordinal))
        }
        DType::F64 => {
            let expanded = mk::masked_scatter_64::<f64>(
                Self::unwrap_buffer_f64(grad_compact)?.inner(),
                mask_bits,
                out_numel,
                device,
            )
            .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_f64(expanded, ordinal))
        }
        DType::F16 => {
            let expanded = mk::masked_scatter_16(
                Self::unwrap_buffer_f16(grad_compact)?,
                mask_bits,
                out_numel,
                device,
            )
            .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_buffer_f16(expanded, ordinal))
        }
        DType::BF16 => {
            let expanded = mk::masked_scatter_16(
                Self::unwrap_buffer_bf16(grad_compact)?,
                mask_bits,
                out_numel,
                device,
            )
            .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_buffer_bf16(expanded, ordinal))
        }
        DType::I32 => {
            let expanded = mk::masked_scatter_32::<i32>(
                Self::unwrap_buffer_i32(grad_compact)?.inner(),
                mask_bits,
                out_numel,
                device,
            )
            .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_i32(expanded, ordinal))
        }
        DType::I64 => {
            let expanded = mk::masked_scatter_64::<i64>(
                Self::unwrap_buffer_i64(grad_compact)?.inner(),
                mask_bits,
                out_numel,
                device,
            )
            .map_err(Self::map_gpu_err)?;
            Ok(Self::wrap_slice_i64(expanded, ordinal))
        }
        _ => Err(FerrotorchError::NotImplementedOnCuda {
            op: "masked_scatter",
        }),
    }
}
}
/// Returns the default CUDA device (ordinal 0) of the registered GPU backend.
///
/// # Errors
/// * `DeviceUnavailable` if no GPU backend has been registered.
/// * `InvalidArgument` if the registered backend is not a `CudaBackendImpl`.
/// * Any error from `CudaBackendImpl::default_device` is propagated.
pub fn get_cuda_device() -> FerrotorchResult<Arc<GpuDevice>> {
    let backend =
        ferrotorch_core::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    // Build the error lazily, so the message `String` is allocated only on the
    // failure path rather than on every successful call.
    let cuda_backend = backend
        .as_any()
        .downcast_ref::<CudaBackendImpl>()
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "registered GPU backend is not CudaBackendImpl".into(),
        })?;
    Ok(Arc::clone(cuda_backend.default_device()?))
}
/// Registers a freshly created `CudaBackendImpl` (device 0) as the global GPU
/// backend, unless a GPU backend is already registered.
///
/// The value returned by `register_gpu_backend` is deliberately discarded,
/// for example when another caller registered a backend in the meantime, so
/// the function still returns `Ok(())` in that case.
///
/// # Errors
/// Propagates the failure from `CudaBackendImpl::new` (CUDA initialisation).
pub fn init_cuda_backend() -> FerrotorchResult<()> {
    use ferrotorch_core::gpu_dispatch::{has_gpu_backend, register_gpu_backend};
    if !has_gpu_backend() {
        let backend = Box::new(CudaBackendImpl::new()?);
        let _ = register_gpu_backend(backend);
    }
    Ok(())
}
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
    // These tests require a CUDA device. Host <-> byte conversions go through
    // the safe helpers below instead of `slice::from_raw_parts`. The buffer
    // returned by `gpu_to_cpu` is presumably a byte vector, which only
    // guarantees 1-byte alignment, so viewing it as `&[f32]` / `&[u16]` would
    // be undefined behaviour.
    use super::*;
    use ferrotorch_core::gpu_dispatch;

    /// Registers the CUDA backend unless a GPU backend already exists.
    ///
    /// Parallel tests may race here. `init_cuda_backend` discards the result
    /// of `register_gpu_backend`, so losing the race does not fail the test.
    fn ensure_init() {
        if !gpu_dispatch::has_gpu_backend() {
            init_cuda_backend().expect("init_cuda_backend");
        }
    }

    /// Native-endian byte encoding of `f32` values, for `cpu_to_gpu`.
    fn f32_to_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_ne_bytes()).collect()
    }

    /// Native-endian byte encoding of raw 16-bit patterns (f16 / bf16 bits).
    fn u16_to_bytes(values: &[u16]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_ne_bytes()).collect()
    }

    /// Decodes native-endian `f32`s without any alignment requirement on
    /// `bytes`. A trailing partial chunk is ignored.
    fn bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
        bytes
            .chunks_exact(4)
            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
            .collect()
    }

    /// Decodes native-endian `u16`s without any alignment requirement on
    /// `bytes`. A trailing odd byte is ignored.
    fn bytes_to_u16(bytes: &[u8]) -> Vec<u16> {
        bytes
            .chunks_exact(2)
            .map(|c| u16::from_ne_bytes([c[0], c[1]]))
            .collect()
    }

    #[test]
    fn test_init_cuda_backend() {
        ensure_init();
        assert!(gpu_dispatch::has_gpu_backend());
    }

    #[test]
    fn test_gpu_backend_returns_some() {
        ensure_init();
        assert!(gpu_dispatch::gpu_backend().is_some());
    }

    #[test]
    fn test_roundtrip_cpu_gpu_cpu() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");
        let host: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let handle = backend
            .cpu_to_gpu(&f32_to_bytes(&host), DType::F32, 0)
            .expect("cpu_to_gpu");
        assert_eq!(handle.len(), 5);
        assert_eq!(handle.device_ordinal(), 0);
        let back_bytes = backend.gpu_to_cpu(&handle).expect("gpu_to_cpu");
        assert_eq!(bytes_to_f32(&back_bytes), host);
    }

    #[test]
    fn test_add_f32() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");
        let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let b_data: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
        let expected: Vec<f32> = vec![11.0, 22.0, 33.0, 44.0];
        let a_handle = backend
            .cpu_to_gpu(&f32_to_bytes(&a_data), DType::F32, 0)
            .expect("cpu_to_gpu a");
        let b_handle = backend
            .cpu_to_gpu(&f32_to_bytes(&b_data), DType::F32, 0)
            .expect("cpu_to_gpu b");
        let result = backend.add_f32(&a_handle, &b_handle).expect("add_f32");
        assert_eq!(result.len(), 4);
        let result_bytes = backend.gpu_to_cpu(&result).expect("gpu_to_cpu");
        let result_f32 = bytes_to_f32(&result_bytes);
        for (i, (&got, &exp)) in result_f32.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-6,
                "element {i}: got {got}, expected {exp}",
            );
        }
    }

    #[test]
    fn test_matmul_f32() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");
        // [2x3] @ [3x2] -> [2x2]
        let a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b_data: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let expected: Vec<f32> = vec![58.0, 64.0, 139.0, 154.0];
        let a_handle = backend
            .cpu_to_gpu(&f32_to_bytes(&a_data), DType::F32, 0)
            .expect("cpu_to_gpu a");
        let b_handle = backend
            .cpu_to_gpu(&f32_to_bytes(&b_data), DType::F32, 0)
            .expect("cpu_to_gpu b");
        let result = backend
            .matmul_f32(&a_handle, &b_handle, 2, 3, 2)
            .expect("matmul_f32");
        assert_eq!(result.len(), 4);
        let result_bytes = backend.gpu_to_cpu(&result).expect("gpu_to_cpu");
        let result_f32 = bytes_to_f32(&result_bytes);
        for (i, (&got, &exp)) in result_f32.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-3,
                "element {i}: got {got}, expected {exp}",
            );
        }
    }

    #[test]
    fn test_roundtrip_bf16() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");
        let host: Vec<u16> = [0.0_f32, 1.0, -1.0, 2.5, -3.5, 100.0]
            .iter()
            .map(|&x| half::bf16::from_f32(x).to_bits())
            .collect();
        let handle = backend
            .cpu_to_gpu(&u16_to_bytes(&host), DType::BF16, 0)
            .expect("cpu_to_gpu bf16");
        assert_eq!(handle.len(), host.len());
        assert_eq!(handle.device_ordinal(), 0);
        assert_eq!(backend.buffer_elem_size(&handle), 2);
        let back_bytes = backend.gpu_to_cpu(&handle).expect("gpu_to_cpu bf16");
        assert_eq!(back_bytes.len(), host.len() * 2);
        assert_eq!(bytes_to_u16(&back_bytes), host);
        let cloned = backend.clone_buffer(&handle).expect("clone_buffer bf16");
        assert_eq!(cloned.len(), host.len());
        let cloned_bytes = backend.gpu_to_cpu(&cloned).expect("gpu_to_cpu cloned");
        assert_eq!(bytes_to_u16(&cloned_bytes), host);
    }

    #[test]
    fn test_f16_bf16_tag_disambiguation() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");
        let host: Vec<u16> = [0.0_f32, 1.0, -1.0, 2.5, -3.5, 100.0]
            .iter()
            .map(|&x| half::f16::from_f32(x).to_bits())
            .collect();
        let f16_handle = backend
            .cpu_to_gpu(&u16_to_bytes(&host), DType::F16, 0)
            .expect("cpu_to_gpu f16");
        assert_eq!(f16_handle.dtype(), DType::F16);
        assert_ne!(f16_handle.dtype(), DType::BF16);
        assert_eq!(backend.buffer_elem_size(&f16_handle), 2);
        let back_bytes = backend.gpu_to_cpu(&f16_handle).expect("gpu_to_cpu f16");
        assert_eq!(
            bytes_to_u16(&back_bytes),
            host,
            "f16 round-trip must be bit-exact"
        );

        // An F16-tagged handle must be rejected by the BF16 kernel entry point.
        match backend.add_bf16_bf16(&f16_handle, &f16_handle) {
            Ok(_) => panic!("F16-tagged handle fed to add_bf16_bf16 must Err, got Ok"),
            Err(e) => {
                let msg = e.to_string();
                assert!(
                    msg.contains("BF16"),
                    "error must name the BF16 tag mismatch, got: {msg}"
                );
            }
        }

        let bf16_host: Vec<u16> = [0.0_f32, 1.0]
            .iter()
            .map(|&x| half::bf16::from_f32(x).to_bits())
            .collect();
        let bf16_handle = backend
            .cpu_to_gpu(&u16_to_bytes(&bf16_host), DType::BF16, 0)
            .expect("cpu_to_gpu bf16");
        assert_eq!(bf16_handle.dtype(), DType::BF16);

        // And the reverse: a BF16-tagged handle must be rejected by add_f16.
        match backend.add_f16(&bf16_handle, &bf16_handle) {
            Ok(_) => panic!("BF16-tagged handle fed to add_f16 must Err, got Ok"),
            Err(e) => {
                let msg = e.to_string();
                // NOTE(review): "BF16" itself contains "F16", so this check
                // cannot distinguish which tag the message names.
                assert!(
                    msg.contains("F16"),
                    "error must name the F16 tag mismatch, got: {msg}"
                );
            }
        }
    }

    #[test]
    fn test_alloc_zeros_bf16() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");
        let handle = backend
            .alloc_zeros(8, DType::BF16, 0)
            .expect("alloc_zeros bf16");
        assert_eq!(handle.len(), 8);
        assert_eq!(backend.buffer_elem_size(&handle), 2);
        let bytes = backend.gpu_to_cpu(&handle).expect("gpu_to_cpu zeros");
        assert_eq!(bytes.len(), 16);
        assert!(bytes.iter().all(|&b| b == 0));
    }

    #[test]
    fn test_matmul_bf16_bf16_dispatcher() {
        ensure_init();
        let backend = gpu_dispatch::gpu_backend().expect("backend registered");
        // Same [2x3] @ [3x2] problem as `test_matmul_f32`, in bf16.
        let a_f32: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b_f32: Vec<f32> = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let expected_f32: Vec<f32> = vec![58.0, 64.0, 139.0, 154.0];
        let a_bf16: Vec<u16> = a_f32
            .iter()
            .map(|&x| half::bf16::from_f32(x).to_bits())
            .collect();
        let b_bf16: Vec<u16> = b_f32
            .iter()
            .map(|&x| half::bf16::from_f32(x).to_bits())
            .collect();
        let a_handle = backend
            .cpu_to_gpu(&u16_to_bytes(&a_bf16), DType::BF16, 0)
            .expect("cpu_to_gpu a");
        let b_handle = backend
            .cpu_to_gpu(&u16_to_bytes(&b_bf16), DType::BF16, 0)
            .expect("cpu_to_gpu b");
        let result = backend
            .matmul_bf16_bf16(&a_handle, &b_handle, 2, 3, 2)
            .expect("matmul_bf16_bf16");
        assert_eq!(result.len(), 4);
        assert_eq!(backend.buffer_elem_size(&result), 2);
        let result_bytes = backend.gpu_to_cpu(&result).expect("gpu_to_cpu result");
        let result_bf16 = bytes_to_u16(&result_bytes);
        for (i, (&got_bits, &exp)) in result_bf16.iter().zip(expected_f32.iter()).enumerate() {
            let got = half::bf16::from_bits(got_bits).to_f32();
            // bf16 has ~8 bits of mantissa, so allow a coarse tolerance.
            assert!(
                (got - exp).abs() < 1.0,
                "element {i}: got {got}, expected {exp}"
            );
        }
    }
}