use std::any::Any;
use std::sync::OnceLock;
use crate::error::{FerrotorchError, FerrotorchResult};
/// Snapshot of a GPU pseudo-random-number-generator stream.
///
/// Produced by `GpuBackend::save_rng_state` / `GpuBackend::dropout_philox_f32`
/// and consumed by `GpuBackend::restore_rng_state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GpuRngState {
/// RNG counter value (exact semantics are backend-defined).
pub counter: u64,
/// Seed the RNG stream was initialized with.
pub seed: u64,
/// Offset into the stream (units are backend-defined).
pub offset: u64,
/// Ordinal of the device this state belongs to.
pub device: usize,
}
/// Opaque handle to a backend-specific, device-resident buffer.
pub struct GpuBufferHandle {
/// Type-erased backend allocation; recover the concrete type with
/// `downcast_ref` / `downcast_mut` / `into_inner`.
pub(crate) inner: Box<dyn Any + Send + Sync>,
/// Ordinal of the device that owns the allocation.
pub(crate) device_ordinal: usize,
/// Buffer length — presumably an element count, not bytes (mirrors the
/// `len`/`elem_size` split in `GpuBackend::alloc_zeros`); TODO confirm.
pub(crate) len: usize,
}
impl GpuBufferHandle {
pub fn new(inner: Box<dyn Any + Send + Sync>, device_ordinal: usize, len: usize) -> Self {
Self {
inner,
device_ordinal,
len,
}
}
#[inline]
pub fn device_ordinal(&self) -> usize {
self.device_ordinal
}
#[inline]
pub fn len(&self) -> usize {
self.len
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn downcast_ref<T: 'static>(&self) -> Option<&T> {
self.inner.downcast_ref()
}
pub fn downcast_mut<T: 'static>(&mut self) -> Option<&mut T> {
self.inner.downcast_mut()
}
pub fn into_inner<T: 'static>(self) -> Result<T, Box<dyn Any + Send + Sync>> {
self.inner.downcast::<T>().map(|b| *b)
}
}
impl std::fmt::Debug for GpuBufferHandle {
    /// Formats the handle without touching the type-erased payload:
    /// only the device ordinal and length are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("GpuBufferHandle");
        builder.field("device", &self.device_ordinal);
        builder.field("len", &self.len);
        builder.finish()
    }
}
pub trait GpuBackend: Send + Sync {
fn as_any(&self) -> &dyn std::any::Any;
fn cpu_to_gpu(
&self,
data: &[u8],
elem_size: usize,
device: usize,
) -> FerrotorchResult<GpuBufferHandle>;
fn gpu_to_cpu(&self, handle: &GpuBufferHandle) -> FerrotorchResult<Vec<u8>>;
fn clone_buffer(&self, handle: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
fn alloc_zeros(
&self,
len: usize,
elem_size: usize,
device: usize,
) -> FerrotorchResult<GpuBufferHandle>;
fn add_f32(
&self,
a: &GpuBufferHandle,
b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle>;
fn sub_f32(
&self,
a: &GpuBufferHandle,
b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle>;
fn mul_f32(
&self,
a: &GpuBufferHandle,
b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle>;
fn neg_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
fn relu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
fn matmul_f32(
&self,
a: &GpuBufferHandle,
b: &GpuBufferHandle,
m: usize,
k: usize,
n: usize,
) -> FerrotorchResult<GpuBufferHandle>;
fn matmul_f16_f32(
&self,
a: &GpuBufferHandle,
b: &GpuBufferHandle,
m: usize,
k: usize,
n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
self.matmul_f32(a, b, m, k, n)
}
fn sum_f32(&self, a: &GpuBufferHandle, len: usize) -> FerrotorchResult<GpuBufferHandle>;
fn add_f64(
&self,
_a: &GpuBufferHandle,
_b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "f64 GPU ops not yet implemented".into(),
})
}
fn sub_f64(
&self,
_a: &GpuBufferHandle,
_b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "f64 GPU ops not yet implemented".into(),
})
}
fn mul_f64(
&self,
_a: &GpuBufferHandle,
_b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "f64 GPU ops not yet implemented".into(),
})
}
fn neg_f64(&self, _a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "f64 GPU ops not yet implemented".into(),
})
}
fn relu_f64(&self, _a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "f64 GPU ops not yet implemented".into(),
})
}
fn matmul_f64(
&self,
_a: &GpuBufferHandle,
_b: &GpuBufferHandle,
_m: usize,
_k: usize,
_n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "f64 GPU ops not yet implemented".into(),
})
}
fn sum_f64(&self, _a: &GpuBufferHandle, _numel: usize) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "f64 GPU ops not yet implemented".into(),
})
}
fn broadcast_add_f32(
&self,
a: &GpuBufferHandle,
b: &GpuBufferHandle,
a_shape: &[usize],
b_shape: &[usize],
out_shape: &[usize],
) -> FerrotorchResult<GpuBufferHandle>;
fn broadcast_sub_f32(
&self,
a: &GpuBufferHandle,
b: &GpuBufferHandle,
a_shape: &[usize],
b_shape: &[usize],
out_shape: &[usize],
) -> FerrotorchResult<GpuBufferHandle>;
fn broadcast_mul_f32(
&self,
a: &GpuBufferHandle,
b: &GpuBufferHandle,
a_shape: &[usize],
b_shape: &[usize],
out_shape: &[usize],
) -> FerrotorchResult<GpuBufferHandle>;
fn softmax_f32(
&self,
a: &GpuBufferHandle,
rows: usize,
cols: usize,
) -> FerrotorchResult<GpuBufferHandle>;
fn dropout_f32(
&self,
a: &GpuBufferHandle,
threshold: u32,
scale: f32,
seed: u32,
) -> FerrotorchResult<GpuBufferHandle>;
fn dropout_philox_f32(
&self,
a: &GpuBufferHandle,
threshold: u32,
scale: f32,
) -> FerrotorchResult<(GpuBufferHandle, GpuRngState)> {
let result = self.dropout_f32(a, threshold, scale, 0)?;
Ok((
result,
GpuRngState {
counter: 0,
seed: 0,
offset: 0,
device: 0,
},
))
}
fn transpose_2d_f32(
&self,
a: &GpuBufferHandle,
m: usize,
n: usize,
) -> FerrotorchResult<GpuBufferHandle>;
fn permute_0213_f32(
&self,
a: &GpuBufferHandle,
d0: usize,
d1: usize,
d2: usize,
d3: usize,
) -> FerrotorchResult<GpuBufferHandle>;
fn bmm_f32(
&self,
a: &GpuBufferHandle,
b: &GpuBufferHandle,
batch: usize,
m: usize,
k: usize,
n: usize,
) -> FerrotorchResult<GpuBufferHandle>;
fn gelu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
fn layernorm_f32(
&self,
input: &GpuBufferHandle,
weight: &GpuBufferHandle,
bias: &GpuBufferHandle,
rows: usize,
cols: usize,
eps: f32,
) -> FerrotorchResult<GpuBufferHandle>;
fn slice_write_f32(
&self,
src: &GpuBufferHandle,
dst: &mut GpuBufferHandle,
n_batch: usize,
d: usize,
max_len: usize,
pos: usize,
) -> FerrotorchResult<()>;
fn slice_read_f32(
&self,
src: &GpuBufferHandle,
n_batch: usize,
d: usize,
len: usize,
max_len: usize,
) -> FerrotorchResult<GpuBufferHandle>;
fn embed_lookup_f32(
&self,
idx: &GpuBufferHandle,
weight: &GpuBufferHandle,
d: usize,
) -> FerrotorchResult<GpuBufferHandle>;
fn embed_lookup_batch_f32(
&self,
indices: &GpuBufferHandle,
weight: &GpuBufferHandle,
n: usize,
d: usize,
) -> FerrotorchResult<GpuBufferHandle>;
fn scatter_add_rows_f32(
&self,
grad_output: &GpuBufferHandle,
indices: &GpuBufferHandle,
num_embeddings: usize,
d: usize,
) -> FerrotorchResult<GpuBufferHandle>;
fn scale_f32(&self, a: &GpuBufferHandle, scalar: f32) -> FerrotorchResult<GpuBufferHandle>;
fn relu_backward_f32(
&self,
grad: &GpuBufferHandle,
input: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle>;
fn gelu_backward_f32(
&self,
grad: &GpuBufferHandle,
input: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle>;
fn index_select_1d_f32(
&self,
input: &GpuBufferHandle,
indices: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle>;
fn scatter_add_1d_f32(
&self,
grad_output: &GpuBufferHandle,
indices: &GpuBufferHandle,
input_len: usize,
) -> FerrotorchResult<GpuBufferHandle>;
fn masked_fill_f32(
&self,
input: &GpuBufferHandle,
mask: &GpuBufferHandle,
value: f32,
) -> FerrotorchResult<GpuBufferHandle>;
fn masked_zero_f32(
&self,
grad: &GpuBufferHandle,
mask: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle>;
fn div_f32(
&self,
_a: &GpuBufferHandle,
_b: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "div_f32 GPU op not yet implemented".into(),
})
}
fn exp_f32(&self, _a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "exp_f32 GPU op not yet implemented".into(),
})
}
fn log_f32(&self, _a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "log_f32 GPU op not yet implemented".into(),
})
}
fn sqrt_f32(&self, _a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "sqrt_f32 GPU op not yet implemented".into(),
})
}
fn pow_f32(&self, _a: &GpuBufferHandle, _exponent: f32) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "pow_f32 GPU op not yet implemented".into(),
})
}
fn abs_f32(&self, _a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "abs_f32 GPU op not yet implemented".into(),
})
}
fn sigmoid_f32(&self, _a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "sigmoid_f32 GPU op not yet implemented".into(),
})
}
fn tanh_f32(&self, _a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "tanh_f32 GPU op not yet implemented".into(),
})
}
fn sigmoid_backward_f32(
&self,
_grad: &GpuBufferHandle,
_output: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "sigmoid_backward_f32 GPU op not yet implemented".into(),
})
}
fn tanh_backward_f32(
&self,
_grad: &GpuBufferHandle,
_output: &GpuBufferHandle,
) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "tanh_backward_f32 GPU op not yet implemented".into(),
})
}
fn softmax_backward_f32(
&self,
_grad: &GpuBufferHandle,
_output: &GpuBufferHandle,
_cols: usize,
) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "softmax_backward_f32 GPU op not yet implemented".into(),
})
}
fn layernorm_backward_f32(
&self,
_input: &GpuBufferHandle,
_grad_output: &GpuBufferHandle,
_weight: &GpuBufferHandle,
_rows: usize,
_cols: usize,
_eps: f32,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
Err(FerrotorchError::InvalidArgument {
message: "layernorm_backward_f32 GPU op not yet implemented".into(),
})
}
fn sum_axis_f32(
&self,
_a: &GpuBufferHandle,
_shape: &[usize],
_axis: usize,
) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "sum_axis_f32 GPU op not yet implemented".into(),
})
}
fn strided_split_f32(
&self,
_input: &GpuBufferHandle,
_total_along_axis: usize,
_split_offset: usize,
_split_size: usize,
_inner_size: usize,
_n: usize,
) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "strided_split_f32 GPU op not yet implemented".into(),
})
}
#[allow(clippy::too_many_arguments)]
fn strided_cat_f32(
&self,
_input: &GpuBufferHandle,
_output: &mut GpuBufferHandle,
_total_along_axis: usize,
_cat_offset: usize,
_part_size: usize,
_inner_size: usize,
_n: usize,
) -> FerrotorchResult<()> {
Err(FerrotorchError::InvalidArgument {
message: "strided_cat_f32 GPU op not yet implemented".into(),
})
}
fn has_inf_nan_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<bool> {
let bytes = self.gpu_to_cpu(a)?;
let floats: &[f32] =
unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, bytes.len() / 4) };
Ok(floats.iter().any(|v| !v.is_finite()))
}
fn save_rng_state(&self, device: usize) -> FerrotorchResult<GpuRngState> {
Err(FerrotorchError::InvalidArgument {
message: format!("save_rng_state not implemented for device {device}"),
})
}
fn restore_rng_state(&self, state: GpuRngState) -> FerrotorchResult<()> {
let _ = state;
Err(FerrotorchError::InvalidArgument {
message: "restore_rng_state not implemented".into(),
})
}
fn svd_f32(
&self,
_a: &GpuBufferHandle,
_m: usize,
_n: usize,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle, GpuBufferHandle)> {
Err(FerrotorchError::InvalidArgument {
message: "svd_f32 GPU op not yet implemented".into(),
})
}
fn cholesky_f32(&self, _a: &GpuBufferHandle, _n: usize) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "cholesky_f32 GPU op not yet implemented".into(),
})
}
fn solve_f32(
&self,
_a: &GpuBufferHandle,
_b: &GpuBufferHandle,
_n: usize,
_nrhs: usize,
) -> FerrotorchResult<GpuBufferHandle> {
Err(FerrotorchError::InvalidArgument {
message: "solve_f32 GPU op not yet implemented".into(),
})
}
fn qr_f32(
&self,
_a: &GpuBufferHandle,
_m: usize,
_n: usize,
) -> FerrotorchResult<(GpuBufferHandle, GpuBufferHandle)> {
Err(FerrotorchError::InvalidArgument {
message: "qr_f32 GPU op not yet implemented".into(),
})
}
}
/// Process-wide GPU backend; installed at most once via `register_gpu_backend`.
static GPU_BACKEND: OnceLock<Box<dyn GpuBackend>> = OnceLock::new();

/// Installs `backend` as the global GPU backend.
///
/// Returns `Err` carrying the rejected backend if one was already registered.
pub fn register_gpu_backend(backend: Box<dyn GpuBackend>) -> Result<(), Box<dyn GpuBackend>> {
    GPU_BACKEND.set(backend)
}

/// Borrows the registered global backend, or `None` if none was installed.
pub fn gpu_backend() -> Option<&'static dyn GpuBackend> {
    let boxed = GPU_BACKEND.get()?;
    Some(boxed.as_ref())
}

/// True once a global backend has been registered.
pub fn has_gpu_backend() -> bool {
    gpu_backend().is_some()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Accessors and downcasting round-trip on a freshly built handle.
    #[test]
    fn test_gpu_buffer_handle() {
        let h = GpuBufferHandle::new(Box::new(42u64), 0, 100);
        assert_eq!(h.device_ordinal(), 0);
        assert_eq!(h.len(), 100);
        assert!(!h.is_empty());
        assert_eq!(h.downcast_ref::<u64>(), Some(&42u64));
    }

    /// Debug output exposes the device ordinal (but not the payload).
    #[test]
    fn test_gpu_buffer_handle_debug() {
        let h = GpuBufferHandle::new(Box::new(()), 1, 50);
        let rendered = format!("{:?}", h);
        assert!(rendered.contains("device: 1"));
    }
}