use std::any::Any;
use std::sync::OnceLock;
use crate::error::{FerrotorchError, FerrotorchResult};
/// Opaque, type-erased handle to a device-resident buffer.
///
/// Each backend stores its own concrete buffer type behind `dyn Any`;
/// consumers recover it through the `downcast_*` accessors. The handle
/// also records which device owns the buffer and how many elements it
/// holds.
pub struct GpuBufferHandle {
    pub(crate) inner: Box<dyn Any + Send + Sync>,
    pub(crate) device_ordinal: usize,
    pub(crate) len: usize,
}

impl GpuBufferHandle {
    /// Wraps a backend-specific buffer together with its owning device
    /// ordinal and element count.
    pub fn new(inner: Box<dyn Any + Send + Sync>, device_ordinal: usize, len: usize) -> Self {
        GpuBufferHandle { inner, device_ordinal, len }
    }

    /// Ordinal of the device that owns this buffer.
    #[inline]
    pub fn device_ordinal(&self) -> usize {
        self.device_ordinal
    }

    /// Number of elements recorded for this buffer.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the buffer holds no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Borrows the underlying buffer as `T`, if that is its concrete type.
    pub fn downcast_ref<T: 'static>(&self) -> Option<&T> {
        self.inner.downcast_ref::<T>()
    }

    /// Mutably borrows the underlying buffer as `T`, if that is its
    /// concrete type.
    pub fn downcast_mut<T: 'static>(&mut self) -> Option<&mut T> {
        self.inner.downcast_mut::<T>()
    }

    /// Consumes the handle, yielding the concrete buffer on success, or
    /// handing the type-erased box back when the requested type does not
    /// match (note: device ordinal and length are dropped either way).
    pub fn into_inner<T: 'static>(self) -> Result<T, Box<dyn Any + Send + Sync>> {
        match self.inner.downcast::<T>() {
            Ok(boxed) => Ok(*boxed),
            Err(erased) => Err(erased),
        }
    }
}

impl std::fmt::Debug for GpuBufferHandle {
    // Shows device and length only; the payload itself is opaque.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("GpuBufferHandle");
        builder.field("device", &self.device_ordinal);
        builder.field("len", &self.len);
        builder.finish()
    }
}
/// Shared body for the `*_f64` default methods below: one place to build
/// the uniform "unimplemented" error instead of seven copies.
fn f64_unimplemented() -> FerrotorchResult<GpuBufferHandle> {
    Err(FerrotorchError::InvalidArgument { message: "f64 GPU ops not yet implemented".into() })
}

/// Contract a GPU compute backend must implement to power tensor ops.
///
/// Buffers cross this boundary as type-erased [`GpuBufferHandle`]s; each
/// backend downcasts them to its own concrete buffer type. All `*_f32`
/// kernels are required; the `*_f64` kernels default to an
/// "unimplemented" error until a backend provides them.
pub trait GpuBackend: Send + Sync {
    /// Escape hatch for downcasting a backend to its concrete type.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Uploads raw host bytes (elements of `elem_size` bytes) to `device`.
    fn cpu_to_gpu(&self, data: &[u8], elem_size: usize, device: usize) -> FerrotorchResult<GpuBufferHandle>;
    /// Downloads a device buffer back to host memory as raw bytes.
    fn gpu_to_cpu(&self, handle: &GpuBufferHandle) -> FerrotorchResult<Vec<u8>>;
    /// Device-side deep copy of `handle`.
    fn clone_buffer(&self, handle: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
    /// Allocates a zero-filled buffer of `len` elements on `device`.
    fn alloc_zeros(&self, len: usize, elem_size: usize, device: usize) -> FerrotorchResult<GpuBufferHandle>;

    /// Element-wise `a + b` over f32 buffers.
    fn add_f32(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
    /// Element-wise `a - b` over f32 buffers.
    fn sub_f32(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
    /// Element-wise `a * b` over f32 buffers.
    fn mul_f32(&self, a: &GpuBufferHandle, b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
    /// Element-wise negation of an f32 buffer.
    fn neg_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
    /// Element-wise ReLU of an f32 buffer.
    fn relu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
    /// f32 matrix multiply: `a` is (m, k), `b` is (k, n), result (m, n).
    fn matmul_f32(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, m: usize, k: usize, n: usize) -> FerrotorchResult<GpuBufferHandle>;
    /// Sum-reduction over the first `len` f32 elements of `a`.
    fn sum_f32(&self, a: &GpuBufferHandle, len: usize) -> FerrotorchResult<GpuBufferHandle>;

    /// Element-wise `a + b` over f64 buffers. Default: unimplemented error.
    fn add_f64(&self, _a: &GpuBufferHandle, _b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        f64_unimplemented()
    }
    /// Element-wise `a - b` over f64 buffers. Default: unimplemented error.
    fn sub_f64(&self, _a: &GpuBufferHandle, _b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        f64_unimplemented()
    }
    /// Element-wise `a * b` over f64 buffers. Default: unimplemented error.
    fn mul_f64(&self, _a: &GpuBufferHandle, _b: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        f64_unimplemented()
    }
    /// Element-wise negation of an f64 buffer. Default: unimplemented error.
    fn neg_f64(&self, _a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        f64_unimplemented()
    }
    /// Element-wise ReLU of an f64 buffer. Default: unimplemented error.
    fn relu_f64(&self, _a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle> {
        f64_unimplemented()
    }
    /// f64 matrix multiply. Default: unimplemented error.
    fn matmul_f64(&self, _a: &GpuBufferHandle, _b: &GpuBufferHandle, _m: usize, _k: usize, _n: usize) -> FerrotorchResult<GpuBufferHandle> {
        f64_unimplemented()
    }
    /// Sum-reduction over f64 elements. Default: unimplemented error.
    /// (Parameter renamed `_numel` -> `_len` for consistency with `sum_f32`.)
    fn sum_f64(&self, _a: &GpuBufferHandle, _len: usize) -> FerrotorchResult<GpuBufferHandle> {
        f64_unimplemented()
    }

    /// Broadcasting f32 add: shapes follow NumPy-style broadcast rules to
    /// `out_shape` — exact stride handling is backend-defined.
    fn broadcast_add_f32(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle>;
    /// Broadcasting f32 subtract; see `broadcast_add_f32`.
    fn broadcast_sub_f32(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle>;
    /// Broadcasting f32 multiply; see `broadcast_add_f32`.
    fn broadcast_mul_f32(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, a_shape: &[usize], b_shape: &[usize], out_shape: &[usize]) -> FerrotorchResult<GpuBufferHandle>;
    /// Row-wise softmax over a (rows, cols) f32 matrix.
    fn softmax_f32(&self, a: &GpuBufferHandle, rows: usize, cols: usize) -> FerrotorchResult<GpuBufferHandle>;
    /// Dropout with survivor scaling by `scale`. `threshold` is compared
    /// against a u32 RNG draw seeded by `seed`; exact RNG semantics are
    /// backend-defined.
    fn dropout_f32(&self, a: &GpuBufferHandle, threshold: u32, scale: f32, seed: u32) -> FerrotorchResult<GpuBufferHandle>;
    /// Transpose of an (m, n) f32 matrix into (n, m).
    fn transpose_2d_f32(&self, a: &GpuBufferHandle, m: usize, n: usize) -> FerrotorchResult<GpuBufferHandle>;
    /// Permutes a (d0, d1, d2, d3) f32 tensor to axis order (0, 2, 1, 3)
    /// (the usual attention head-swap layout change).
    fn permute_0213_f32(&self, a: &GpuBufferHandle, d0: usize, d1: usize, d2: usize, d3: usize) -> FerrotorchResult<GpuBufferHandle>;
    /// Batched f32 matmul: `batch` independent (m, k) x (k, n) products.
    fn bmm_f32(&self, a: &GpuBufferHandle, b: &GpuBufferHandle, batch: usize, m: usize, k: usize, n: usize) -> FerrotorchResult<GpuBufferHandle>;
    /// Element-wise GELU of an f32 buffer.
    fn gelu_f32(&self, a: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
    /// Layer normalization over each of `rows` rows of `cols` f32 values,
    /// with per-column `weight`/`bias` and numerical-stability `eps`.
    fn layernorm_f32(&self, input: &GpuBufferHandle, weight: &GpuBufferHandle, bias: &GpuBufferHandle, rows: usize, cols: usize, eps: f32) -> FerrotorchResult<GpuBufferHandle>;
    /// Writes `src` into `dst` at time-step `pos` of an (n_batch, max_len, d)
    /// buffer — cache-style update; exact layout is backend-defined.
    fn slice_write_f32(&self, src: &GpuBufferHandle, dst: &mut GpuBufferHandle, n_batch: usize, d: usize, max_len: usize, pos: usize) -> FerrotorchResult<()>;
    /// Reads the first `len` time-steps out of an (n_batch, max_len, d)
    /// buffer; counterpart to `slice_write_f32`.
    fn slice_read_f32(&self, src: &GpuBufferHandle, n_batch: usize, d: usize, len: usize, max_len: usize) -> FerrotorchResult<GpuBufferHandle>;
    /// Gathers rows of width `d` from `weight` at the indices in `idx`
    /// (embedding lookup).
    fn embed_lookup_f32(&self, idx: &GpuBufferHandle, weight: &GpuBufferHandle, d: usize) -> FerrotorchResult<GpuBufferHandle>;
    /// Multiplies every f32 element by `scalar`.
    fn scale_f32(&self, a: &GpuBufferHandle, scalar: f32) -> FerrotorchResult<GpuBufferHandle>;
    /// ReLU backward: passes `grad` through where `input` was positive.
    fn relu_backward_f32(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
    /// GELU backward: `grad` times the GELU derivative at `input`.
    fn gelu_backward_f32(&self, grad: &GpuBufferHandle, input: &GpuBufferHandle) -> FerrotorchResult<GpuBufferHandle>;
}
/// Process-wide registration slot for the GPU backend; written at most once.
static GPU_BACKEND: OnceLock<Box<dyn GpuBackend>> = OnceLock::new();

/// Installs `backend` as the global GPU backend.
///
/// Returns `Err` carrying the rejected backend if one was already
/// registered (a `OnceLock` can only be set once).
pub fn register_gpu_backend(backend: Box<dyn GpuBackend>) -> Result<(), Box<dyn GpuBackend>> {
    GPU_BACKEND.set(backend)
}

/// Returns the registered backend, or `None` when none has been installed.
pub fn gpu_backend() -> Option<&'static dyn GpuBackend> {
    match GPU_BACKEND.get() {
        Some(boxed) => Some(boxed.as_ref()),
        None => None,
    }
}

/// Reports whether a GPU backend has been registered.
pub fn has_gpu_backend() -> bool {
    gpu_backend().is_some()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Accessors report the values the handle was constructed with, and
    /// downcasting recovers the concrete payload.
    #[test]
    fn test_gpu_buffer_handle() {
        let h = GpuBufferHandle::new(Box::new(42u64), 0, 100);
        assert!(!h.is_empty());
        assert_eq!((h.device_ordinal(), h.len()), (0, 100));
        assert_eq!(h.downcast_ref::<u64>(), Some(&42));
    }

    /// Debug output exposes the device ordinal (payload stays opaque).
    #[test]
    fn test_gpu_buffer_handle_debug() {
        let rendered = format!("{:?}", GpuBufferHandle::new(Box::new(()), 1, 50));
        assert!(rendered.contains("device: 1"));
    }
}