use std::any::Any;
use std::sync::OnceLock;
use crate::error::{FerrotorchError, FerrotorchResult};
/// Snapshot of a GPU random-number-generator stream, used to save and
/// restore the state of stochastic kernels (e.g. dropout) for reproducibility.
///
/// NOTE(review): the exact counter/offset semantics are backend-defined; the
/// only in-file producer is the `dropout_philox_f32` fallback, which returns
/// an all-zero state — confirm field meanings against backend implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GpuRngState {
// Generator counter (Philox-style, presumably advanced per generated block).
pub counter: u64,
// Seed the generator stream was created with.
pub seed: u64,
// Offset within the generator stream.
pub offset: u64,
// Device ordinal this state was captured from.
pub device: usize,
}
/// Type-erased handle to a device buffer owned by a registered [`GpuBackend`].
///
/// The payload is a backend-specific buffer object; backends recover it via
/// the `downcast_*` helpers. NOTE(review): `len` appears to be an element
/// count rather than a byte count (inferred from `cpu_to_gpu` taking a
/// separate `elem_size`) — confirm against backend implementations.
pub struct GpuBufferHandle {
// Backend-specific buffer object, downcast on use.
pub(crate) inner: Box<dyn Any + Send + Sync>,
// Ordinal of the device that owns the allocation.
pub(crate) device_ordinal: usize,
// Buffer length (element count — see note above).
pub(crate) len: usize,
}
impl GpuBufferHandle {
    /// Wraps a backend-specific buffer object with its device ordinal and length.
    pub fn new(inner: Box<dyn Any + Send + Sync>, device_ordinal: usize, len: usize) -> Self {
        Self { inner, device_ordinal, len }
    }

    /// Ordinal of the device that owns this buffer.
    #[inline]
    pub fn device_ordinal(&self) -> usize {
        self.device_ordinal
    }

    /// Recorded length of the buffer.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the recorded length is zero.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Borrows the type-erased payload as `T`, or `None` on type mismatch.
    pub fn downcast_ref<T: 'static>(&self) -> Option<&T> {
        self.inner.as_ref().downcast_ref::<T>()
    }

    /// Mutably borrows the type-erased payload as `T`, or `None` on type mismatch.
    pub fn downcast_mut<T: 'static>(&mut self) -> Option<&mut T> {
        self.inner.as_mut().downcast_mut::<T>()
    }

    /// Consumes the handle and recovers the payload as `T`; on type mismatch
    /// the boxed payload is handed back unchanged.
    pub fn into_inner<T: 'static>(self) -> Result<T, Box<dyn Any + Send + Sync>> {
        self.inner.downcast::<T>().map(|boxed| *boxed)
    }
}
impl std::fmt::Debug for GpuBufferHandle {
    /// Formats the handle without touching the opaque payload — only the
    /// device ordinal and length are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("GpuBufferHandle");
        builder.field("device", &self.device_ordinal);
        builder.field("len", &self.len);
        builder.finish()
    }
}
/// Contract implemented by a concrete GPU compute backend.
///
/// Buffers cross this interface as type-erased [`GpuBufferHandle`]s that each
/// backend downcasts to its own buffer type. Methods with default bodies are
/// optional capabilities: unless overridden they either delegate to a simpler
/// kernel or return an `InvalidArgument` "not yet implemented" error.
///
/// Matrix arguments are described by explicit dimensions (`m`, `k`, `n`, …).
/// NOTE(review): element layout (row- vs column-major) is not specified at
/// this interface — confirm against the backend implementations.
pub trait GpuBackend: Send + Sync {
    /// Escape hatch so callers can downcast to the concrete backend type.
    fn as_any(&self) -> &dyn std::any::Any;

    // ----- memory transfer & allocation -------------------------------------

    /// Uploads `data` (raw bytes, `elem_size` bytes per element) to `device`.
    fn cpu_to_gpu(&self, data: &[u8], elem_size: usize, device: usize) -> FerrotorchResult<GpuBufferHandle>;
    /// Copies a device buffer back to the host as raw bytes.
    fn gpu_to_cpu(&self, handle: &GpuBufferHandle) -> FerrotorchResult<Vec<u8>>;
    /// Produces an independent device-side copy of `handle`.
    fn clone_buffer(&self, handle: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
    /// Allocates a zero-filled buffer of `len` elements of `elem_size` bytes on `device`.
    fn alloc_zeros(&self, len: usize, elem_size: usize, device: usize) -> FerrotorchResult<GpuBufferHandle>;

    // ----- required f32 kernels ----------------------------------------------

    /// Element-wise `a + b`.
    fn add_f32(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
    /// Element-wise `a - b`.
    fn sub_f32(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
    /// Element-wise `a * b`.
    fn mul_f32(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
    /// Element-wise negation.
    fn neg_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
    /// Element-wise `max(x, 0)`.
    fn relu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
    /// Matrix multiply: `a` is m×k, `b` is k×n, result is m×n.
    fn matmul_f32(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, m: usize, k: usize, n: usize) -> FerrotorchResult<GpuBufferHandle>;
    /// Mixed-precision matmul with f32 accumulation; backends without f16
    /// support inherit this fallback to the plain f32 kernel.
    fn matmul_f16_f32(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, m: usize, k: usize, n: usize) -> FerrotorchResult<GpuBufferHandle> {
        self.matmul_f32(a, b, m, k, n)
    }
    /// Sum-reduction over `len` elements.
    fn sum_f32(&self, a: &GpuBufferHandle, len: usize) -> FerrotorchResult<GpuBufferHandle>;

    // ----- optional f64 kernels (default: unimplemented) ---------------------

    /// Element-wise f64 `a + b`; optional.
    fn add_f64(&self, _a: &GpuBufferHandle, _b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "f64 GPU ops not yet implemented".into() })
    }
    /// Element-wise f64 `a - b`; optional.
    fn sub_f64(&self, _a: &GpuBufferHandle, _b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "f64 GPU ops not yet implemented".into() })
    }
    /// Element-wise f64 `a * b`; optional.
    fn mul_f64(&self, _a: &GpuBufferHandle, _b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "f64 GPU ops not yet implemented".into() })
    }
    /// Element-wise f64 negation; optional.
    fn neg_f64(&self, _a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "f64 GPU ops not yet implemented".into() })
    }
    /// Element-wise f64 ReLU; optional.
    fn relu_f64(&self, _a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "f64 GPU ops not yet implemented".into() })
    }
    /// f64 matrix multiply; optional.
    fn matmul_f64(&self, _a: &GpuBufferHandle, _b: &GpuBufferHandle, _m: usize, _k: usize, _n: usize) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "f64 GPU ops not yet implemented".into() })
    }
    /// f64 sum-reduction; optional.
    fn sum_f64(&self, _a: &GpuBufferHandle, _numel: usize) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "f64 GPU ops not yet implemented".into() })
    }

    // ----- required broadcast / NN kernels -----------------------------------

    /// Broadcasting `a + b` under NumPy-style shape rules.
    fn broadcast_add_f32(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle>;
    /// Broadcasting `a - b`.
    fn broadcast_sub_f32(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle>;
    /// Broadcasting `a * b`.
    fn broadcast_mul_f32(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle>;
    /// Row-wise softmax over a `rows` × `cols` buffer.
    fn softmax_f32(&self, a: &GpuBufferHandle, rows: usize, cols: usize) -> FerrotorchResult<GpuBufferHandle>;
    /// Dropout: zeroes elements whose random draw falls below `threshold`,
    /// scaling survivors by `scale`; `seed` drives the backend RNG.
    fn dropout_f32(&self, a: &GpuBufferHandle, threshold: u32, scale: f32, seed: u32) -> FerrotorchResult<GpuBufferHandle>;
    /// Dropout using backend-managed (Philox-style) RNG state, returning the
    /// state alongside the result so the mask can be regenerated in backward.
    ///
    /// Default fallback delegates to [`dropout_f32`](Self::dropout_f32) with a
    /// fixed seed of 0 and returns an all-zero [`GpuRngState`] — it cannot
    /// reproduce real stream tracking, so stateful backends should override it.
    fn dropout_philox_f32(
        &self,
        a: &GpuBufferHandle,
        threshold: u32,
        scale: f32,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuRngState)> {
        let result = self.dropout_f32(a, threshold, scale, 0)?;
        Ok((result, GpuRngState { counter: 0, seed: 0, offset: 0, device: 0 }))
    }
    /// Transposes an m×n matrix.
    fn transpose_2d_f32(&self, a: &GpuBufferHandle, m: usize, n: usize) -> FerrotorchResult<GpuBufferHandle>;
    /// Permutes a (d0, d1, d2, d3) tensor to axis order (0, 2, 1, 3) —
    /// the usual attention head split/merge shuffle.
    fn permute_0213_f32(&self, a: &GpuBufferHandle, d0: usize, d1: usize, d2: usize, d3: usize) -> FerrotorchResult<GpuBufferHandle>;
    /// Batched matmul over `batch` independent m×k · k×n products.
    fn bmm_f32(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, batch: usize, m: usize, k: usize, n: usize) -> FerrotorchResult<GpuBufferHandle>;
    /// Element-wise GELU activation.
    fn gelu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
    /// Row-wise layer normalisation with affine `weight`/`bias` and epsilon `eps`.
    fn layernorm_f32(&self, input: &GpuBufferHandle, weight: &GpuBufferHandle, bias: &GpuBufferHandle, rows: usize, cols: usize, eps: f32) -> FerrotorchResult<GpuBufferHandle>;
    /// Writes `src` into position `pos` of a (n_batch, max_len, d)-shaped
    /// destination (KV-cache style update).
    fn slice_write_f32(&self, src: &GpuBufferHandle, dst: &mut GpuBufferHandle, n_batch: usize, d: usize, max_len: usize, pos: usize) -> FerrotorchResult<()>;
    /// Reads the first `len` positions out of a (n_batch, max_len, d) buffer.
    fn slice_read_f32(&self, src: &GpuBufferHandle, n_batch: usize, d: usize, len: usize, max_len: usize) -> FerrotorchResult<GpuBufferHandle>;
    /// Gathers rows of width `d` from `weight` at positions `idx`.
    fn embed_lookup_f32(&self, idx: &GpuBufferHandle, weight: &GpuBufferHandle, d: usize) -> FerrotorchResult<GpuBufferHandle>;
    /// Batched embedding gather: `n` indices, rows of width `d`.
    fn embed_lookup_batch_f32(&self, indices: &GpuBufferHandle, weight: &GpuBufferHandle, n: usize, d: usize) -> FerrotorchResult<GpuBufferHandle>;
    /// Scatter-adds `grad_output` rows back into a (num_embeddings, d) gradient
    /// buffer — the backward of an embedding lookup.
    fn scatter_add_rows_f32(&self, grad_output: &GpuBufferHandle, indices: &GpuBufferHandle, num_embeddings: usize, d: usize) -> FerrotorchResult<GpuBufferHandle>;
    /// Element-wise multiplication by a scalar.
    fn scale_f32(&self, a: &GpuBufferHandle, scalar: f32) -> FerrotorchResult<GpuBufferHandle>;
    /// ReLU backward: passes `grad` where `input > 0` (gradient w.r.t. input).
    fn relu_backward_f32(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
    /// GELU backward (gradient w.r.t. input).
    fn gelu_backward_f32(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
    /// Gathers scalar elements of `input` at positions `indices`.
    fn index_select_1d_f32(&self, input: &GpuBufferHandle, indices: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
    /// Scatter-adds `grad_output` back into a length-`input_len` gradient
    /// buffer — the backward of `index_select_1d_f32`.
    fn scatter_add_1d_f32(&self, grad_output: &GpuBufferHandle, indices: &GpuBufferHandle, input_len: usize) -> FerrotorchResult<GpuBufferHandle>;
    /// Replaces masked positions of `input` with `value`.
    fn masked_fill_f32(&self, input: &GpuBufferHandle, mask: &GpuBufferHandle, value: f32) -> FerrotorchResult<GpuBufferHandle>;
    /// Zeroes masked positions of `grad` — the backward of `masked_fill_f32`.
    fn masked_zero_f32(&self, grad: &GpuBufferHandle, mask: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;

    // ----- optional f32 kernels (default: unimplemented) ---------------------

    /// Element-wise `a / b`; optional.
    fn div_f32(&self, _a: &GpuBufferHandle, _b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "div_f32 GPU op not yet implemented".into() })
    }
    /// Element-wise `exp`; optional.
    fn exp_f32(&self, _a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "exp_f32 GPU op not yet implemented".into() })
    }
    /// Element-wise natural log; optional.
    fn log_f32(&self, _a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "log_f32 GPU op not yet implemented".into() })
    }
    /// Element-wise square root; optional.
    fn sqrt_f32(&self, _a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "sqrt_f32 GPU op not yet implemented".into() })
    }
    /// Element-wise power with a scalar exponent; optional.
    fn pow_f32(&self, _a: &GpuBufferHandle, _exponent: f32) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "pow_f32 GPU op not yet implemented".into() })
    }
    /// Element-wise absolute value; optional.
    fn abs_f32(&self, _a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "abs_f32 GPU op not yet implemented".into() })
    }
    /// Element-wise sigmoid; optional.
    fn sigmoid_f32(&self, _a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "sigmoid_f32 GPU op not yet implemented".into() })
    }
    /// Element-wise tanh; optional.
    fn tanh_f32(&self, _a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "tanh_f32 GPU op not yet implemented".into() })
    }
    /// Sigmoid backward, taking the forward *output*; optional.
    fn sigmoid_backward_f32(&self, _grad: &GpuBufferHandle, _output: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "sigmoid_backward_f32 GPU op not yet implemented".into() })
    }
    /// Tanh backward, taking the forward *output*; optional.
    fn tanh_backward_f32(&self, _grad: &GpuBufferHandle, _output: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "tanh_backward_f32 GPU op not yet implemented".into() })
    }
    /// Softmax backward over rows of width `cols`; optional.
    fn softmax_backward_f32(&self, _grad: &GpuBufferHandle, _output: &GpuBufferHandle, _cols: usize) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "softmax_backward_f32 GPU op not yet implemented".into() })
    }
    /// LayerNorm backward, returning gradients w.r.t. (input, weight, bias);
    /// optional.
    fn layernorm_backward_f32(
        &self,
        _input: &GpuBufferHandle,
        _grad_output: &GpuBufferHandle,
        _weight: &GpuBufferHandle,
        _rows: usize,
        _cols: usize,
        _eps: f32,
    ) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
        Err(FerrotorchError::InvalidArgument { message: "layernorm_backward_f32 GPU op not yet implemented".into() })
    }
    /// Sum-reduction along one axis of an arbitrary-shape tensor; optional.
    fn sum_axis_f32(&self, _a: &GpuBufferHandle, _shape: &[usize], _axis: usize) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "sum_axis_f32 GPU op not yet implemented".into() })
    }
    /// Extracts a split of size `_split_size` at `_split_offset` along an axis
    /// of extent `_total_along_axis` (tensor-split primitive); optional.
    fn strided_split_f32(
        &self,
        _input: &GpuBufferHandle,
        _total_along_axis: usize,
        _split_offset: usize,
        _split_size: usize,
        _inner_size: usize,
        _n: usize,
    ) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "strided_split_f32 GPU op not yet implemented".into() })
    }
    /// Writes one part of a concatenation into `_output` at `_cat_offset`
    /// along an axis of extent `_total_along_axis`; optional.
    fn strided_cat_f32(
        &self,
        _input: &GpuBufferHandle,
        _output: &mut GpuBufferHandle,
        _total_along_axis: usize,
        _cat_offset: usize,
        _part_size: usize,
        _inner_size: usize,
        _n: usize,
    ) -> FerrotorchResult<()> {
        Err(FerrotorchError::InvalidArgument { message: "strided_cat_f32 GPU op not yet implemented".into() })
    }

    // ----- diagnostics & RNG state -------------------------------------------

    /// Returns `true` if any f32 element of `a` is NaN or ±infinity.
    ///
    /// Default implementation copies the buffer to the host and scans it;
    /// backends should override with a device-side reduction.
    fn has_inf_nan_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<bool> {
        let bytes = self.gpu_to_cpu(a)?;
        // Decode 4-byte groups instead of reinterpreting the pointer:
        // `Vec<u8>` gives no alignment guarantee for f32, so the previous
        // `slice::from_raw_parts(ptr as *const f32, ..)` was undefined
        // behavior on a misaligned allocation. `from_ne_bytes` performs the
        // same native-endian reinterpretation without the alignment hazard
        // (trailing bytes beyond a multiple of 4 are ignored, as before).
        Ok(bytes
            .chunks_exact(4)
            .map(|chunk| f32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .any(|v| !v.is_finite()))
    }
    /// Captures the RNG state of `device` for later restoration; optional.
    fn save_rng_state(&self, device: usize) -> FerrotorchResult<GpuRngState> {
        Err(FerrotorchError::InvalidArgument {
            message: format!("save_rng_state not implemented for device {device}"),
        })
    }
    /// Restores a previously captured RNG state; optional.
    fn restore_rng_state(&self, state: GpuRngState) -> FerrotorchResult<()> {
        let _ = state;
        Err(FerrotorchError::InvalidArgument {
            message: "restore_rng_state not implemented".into(),
        })
    }

    // ----- optional linear-algebra decompositions ----------------------------

    /// Singular value decomposition of an m×n matrix, returning (U, S, Vt);
    /// optional.
    fn svd_f32(&self, _a: &GpuBufferHandle, _m: usize, _n: usize) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
        Err(FerrotorchError::InvalidArgument { message: "svd_f32 GPU op not yet implemented".into() })
    }
    /// Cholesky factorisation of an n×n matrix; optional.
    fn cholesky_f32(&self, _a: &GpuBufferHandle, _n: usize) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "cholesky_f32 GPU op not yet implemented".into() })
    }
    /// Solves `A · X = B` for an n×n `A` with `nrhs` right-hand sides; optional.
    fn solve_f32(&self, _a: &GpuBufferHandle, _b: &GpuBufferHandle, _n: usize, _nrhs: usize) -> FerrotorchResult<GpuBufferHandle> {
        Err(FerrotorchError::InvalidArgument { message: "solve_f32 GPU op not yet implemented".into() })
    }
    /// QR factorisation of an m×n matrix, returning (Q, R); optional.
    fn qr_f32(&self, _a: &GpuBufferHandle, _m: usize, _n: usize) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
        Err(FerrotorchError::InvalidArgument { message: "qr_f32 GPU op not yet implemented".into() })
    }
}
/// Process-wide slot for the active GPU backend; set at most once.
static GPU_BACKEND: OnceLock<Box<dyn GpuBackend>> = OnceLock::new();

/// Installs `backend` as the global GPU backend.
///
/// Returns `Err` carrying the rejected backend if one was already registered.
pub fn register_gpu_backend(backend: Box<dyn GpuBackend>) -> Result<(), Box<dyn GpuBackend>> {
    GPU_BACKEND.set(backend)
}

/// Returns a reference to the registered backend, if any.
pub fn gpu_backend() -> Option<&'static dyn GpuBackend> {
    match GPU_BACKEND.get() {
        Some(boxed) => Some(boxed.as_ref()),
        None => None,
    }
}

/// `true` once a backend has been registered.
pub fn has_gpu_backend() -> bool {
    gpu_backend().is_some()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Accessors and typed downcast on a freshly built handle.
    #[test]
    fn test_gpu_buffer_handle() {
        let buf = GpuBufferHandle::new(Box::new(42u64), 0, 100);
        assert_eq!(buf.device_ordinal(), 0);
        assert_eq!(buf.len(), 100);
        assert!(!buf.is_empty());
        assert_eq!(buf.downcast_ref::<u64>(), Some(&42u64));
    }

    /// Debug output exposes the device ordinal but not the opaque payload.
    #[test]
    fn test_gpu_buffer_handle_debug() {
        let buf = GpuBufferHandle::new(Box::new(()), 1, 50);
        let rendered = format!("{:?}", buf);
        assert!(rendered.contains("device: 1"));
    }
}