use crate::kernel_ops::KernelOps;
use crate::{TensorFactory, TensorOps, TensorRef};
use async_trait::async_trait;
use ferrum_types::{DataType, Device, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Core abstraction over a compute backend (see [`BackendType`] for the known
/// implementations). A backend bundles tensor ops, tensor construction,
/// device memory management, and optional custom-kernel execution.
#[async_trait]
pub trait ComputeBackend: Send + Sync {
    /// Stable, human-readable backend name (used for registry lookups).
    fn name(&self) -> &str;
    /// Static description of what this backend supports.
    fn capabilities(&self) -> BackendCapabilities;
    /// Tensor arithmetic/manipulation operations for this backend.
    fn tensor_ops(&self) -> &dyn TensorOps;
    /// Factory for creating tensors on this backend's devices.
    fn tensor_factory(&self) -> &dyn TensorFactory;
    /// Device memory allocator/manager for this backend.
    fn memory_manager(&self) -> &dyn crate::DeviceMemoryManager;
    /// Custom-kernel executor, or `None` if the backend has no kernel support.
    fn kernel_executor(&self) -> Option<&dyn KernelExecutor>;
    /// Higher-level kernel operations; defaults to `None` so backends without
    /// custom kernels need not implement it.
    fn kernel_ops(&self) -> Option<&dyn KernelOps> {
        None
    }
    /// Initializes the backend for the given device. Must be called before use
    /// (see [`BackendStatus::is_initialized`]).
    async fn initialize(&mut self, device: &Device) -> Result<()>;
    /// Returns whether this backend can run on `device`.
    fn supports_device(&self, device: &Device) -> bool;
    /// Backend version string (implementation-defined format).
    fn version(&self) -> String;
    /// Waits until all pending work on `device` has completed.
    async fn synchronize(&self, device: &Device) -> Result<()>;
    /// Current runtime status snapshot.
    fn status(&self) -> BackendStatus;
    /// Releases backend resources; the backend must be re-initialized before reuse.
    async fn shutdown(&mut self) -> Result<()>;
}
/// Loads model weights from a [`WeightSource`] into device tensors.
#[async_trait]
pub trait WeightLoader: Send + Sync {
    /// Loads a single tensor described by `spec` (source, shape, dtype, device,
    /// plus any [`TensorTransformation`]s).
    async fn load_tensor(&self, spec: &TensorSpec) -> Result<TensorRef>;
    /// Loads multiple tensors; result order is expected to match `specs`
    /// (NOTE(review): ordering contract not visible here — confirm with impls).
    async fn load_tensors(&self, specs: &[TensorSpec]) -> Result<Vec<TensorRef>>;
    /// Returns whether `source` is reachable/loadable right now.
    async fn is_available(&self, source: &WeightSource) -> bool;
    /// Fetches metadata (tensor names/shapes, format, size) without loading data.
    async fn get_metadata(&self, source: &WeightSource) -> Result<WeightMetadata>;
    /// Warms caches for `source` (e.g. downloads ahead of time).
    async fn preload(&self, source: &WeightSource) -> Result<()>;
    /// Static description of formats, sources, and transformations supported.
    fn capabilities(&self) -> WeightLoaderCapabilities;
}
/// Full description of one tensor to load: where it comes from, what it looks
/// like, where it should live, and how to transform it after loading.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorSpec {
    /// Tensor name (e.g. the key inside a checkpoint file).
    pub name: String,
    /// Expected shape, one extent per dimension.
    pub shape: Vec<usize>,
    /// Expected element data type.
    pub dtype: DataType,
    /// Target device the tensor should be placed on.
    pub device: Device,
    /// Where to read the raw weight data from.
    pub source: WeightSource,
    /// Transformations applied after loading, presumably in list order
    /// (TODO confirm ordering with loader implementations).
    pub transformations: Vec<TensorTransformation>,
}
/// Location of raw weight data. Mirrors [`WeightSourceType`], which is the
/// payload-free tag used in capability listings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WeightSource {
    /// Local file; `tensor_name` selects one tensor inside a multi-tensor file.
    File {
        path: String,
        tensor_name: Option<String>,
    },
    /// HTTP(S) download with optional extra request headers.
    Url {
        url: String,
        headers: HashMap<String, String>,
    },
    /// Hugging Face Hub artifact, optionally pinned to a `revision` and cached
    /// under `cache_dir`.
    HuggingFace {
        repo_id: String,
        filename: String,
        revision: Option<String>,
        cache_dir: Option<String>,
    },
    /// Already-in-memory bytes in the given serialization format.
    Memory { data: Vec<u8>, format: WeightFormat },
    /// S3 object; `endpoint` allows S3-compatible stores (e.g. MinIO).
    S3 {
        bucket: String,
        key: String,
        region: Option<String>,
        endpoint: Option<String>,
    },
}
/// Serialization format of weight data.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum WeightFormat {
    /// PyTorch checkpoint (pickle-based).
    PyTorch,
    /// safetensors format.
    SafeTensors,
    /// NumPy `.npy`/`.npz`.
    Numpy,
    /// Raw bytes with no container format.
    Raw,
    /// ONNX model file.
    Onnx,
    /// Implementation-defined format identified by a numeric tag.
    Custom(u32),
}
/// Metadata about a weight source, obtainable without loading tensor data
/// (see [`WeightLoader::get_metadata`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightMetadata {
    /// Tensor name -> shape for every tensor in the source.
    pub tensors: HashMap<String, Vec<usize>>,
    /// Serialization format of the source.
    pub format: WeightFormat,
    /// Total size of the weight data in bytes.
    pub total_size_bytes: u64,
    /// Data types present in the source
    /// (NOTE(review): per-tensor mapping is not captured here — only a list).
    pub dtypes: Vec<DataType>,
    /// Loader/format-specific extra metadata as free-form JSON.
    pub extra: HashMap<String, serde_json::Value>,
}
/// A post-load transformation applied to a tensor. Mirrors
/// [`TransformationType`], the payload-free tag used in capability listings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TensorTransformation {
    /// Swap two dimensions.
    Transpose { dim0: usize, dim1: usize },
    /// Reshape to the given shape.
    Reshape { shape: Vec<usize> },
    /// Convert elements to another data type.
    Cast { dtype: DataType },
    /// Quantize according to the given scheme.
    Quantize { config: QuantizationConfig },
    /// Multiply all elements by `factor`.
    Scale { factor: f32 },
    /// Slice along `dim`; `None` bounds presumably mean "from start" / "to end"
    /// (TODO confirm with loader implementations).
    Slice {
        dim: usize,
        start: Option<usize>,
        end: Option<usize>,
    },
}
/// Quantization scheme for [`TensorTransformation::Quantize`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuantizationConfig {
    /// 8-bit integer quantization; `symmetric` selects symmetric vs asymmetric
    /// range mapping.
    INT8 { symmetric: bool },
    /// 4-bit integer quantization with per-group scales of `group_size` elements.
    INT4 { group_size: usize },
    /// 8-bit float; `e4m3` selects the E4M3 layout (otherwise presumably E5M2 —
    /// TODO confirm with the quantizer implementation).
    FP8 { e4m3: bool },
    /// GPTQ quantization; `desc_act` toggles activation-order (desc_act) mode.
    GPTQ {
        bits: u8,
        group_size: usize,
        desc_act: bool,
    },
    /// AWQ quantization with optional zero-point.
    AWQ { bits: u8, zero_point: bool },
}
/// Static capability description of a [`ComputeBackend`]. Matched against
/// [`BackendRequirements`] via [`BackendCapabilities::meets_requirements`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendCapabilities {
    /// Element data types the backend can compute with.
    pub supported_dtypes: Vec<DataType>,
    /// Devices the backend can run on.
    pub supported_devices: Vec<Device>,
    /// Maximum tensor rank (number of dimensions).
    pub max_tensor_dims: usize,
    pub supports_fp16: bool,
    pub supports_bf16: bool,
    pub supports_int8: bool,
    pub supports_flash_attention: bool,
    pub supports_paged_attention: bool,
    pub supports_tensor_parallelism: bool,
    pub supports_pipeline_parallelism: bool,
    /// Largest batch size the backend accepts.
    pub max_batch_size: usize,
    /// Longest sequence length the backend accepts.
    pub max_sequence_length: usize,
    /// Required memory alignment in bytes
    /// (NOTE(review): units not stated in source — presumably bytes, confirm).
    pub memory_alignment: usize,
    pub supports_custom_kernels: bool,
    pub supports_cuda_graphs: bool,
    /// Backend-specific capability flags as free-form JSON.
    pub extra_capabilities: HashMap<String, serde_json::Value>,
}
impl BackendCapabilities {
    /// Returns `true` when this backend satisfies every hard requirement in
    /// `requirements`: all required devices and dtypes are supported, the
    /// batch/sequence limits are large enough, and the flash-/paged-attention
    /// flags are honored.
    ///
    /// `extra_requirements` is intentionally not interpreted here; callers
    /// with backend-specific needs must compare it against
    /// `extra_capabilities` themselves.
    pub fn meets_requirements(&self, requirements: &BackendRequirements) -> bool {
        if !requirements
            .required_devices
            .iter()
            .all(|dev| self.supported_devices.contains(dev))
        {
            return false;
        }
        if !requirements
            .required_dtypes
            .iter()
            .all(|dtype| self.supported_dtypes.contains(dtype))
        {
            return false;
        }
        if requirements.min_batch_size > self.max_batch_size {
            return false;
        }
        if requirements.min_sequence_length > self.max_sequence_length {
            return false;
        }
        // Fix: these two requirement flags were previously ignored, so a
        // backend without flash/paged attention could be selected for a
        // workload that explicitly requires them.
        if requirements.requires_flash_attention && !self.supports_flash_attention {
            return false;
        }
        if requirements.requires_paged_attention && !self.supports_paged_attention {
            return false;
        }
        true
    }
}
/// Hard requirements a workload places on a backend; checked against
/// [`BackendCapabilities`] in `meets_requirements` and used by
/// [`BackendRegistry::find_best_compute_backend`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendRequirements {
    /// Devices the backend must support (all of them).
    pub required_devices: Vec<Device>,
    /// Data types the backend must support (all of them).
    pub required_dtypes: Vec<DataType>,
    /// Minimum acceptable `max_batch_size`.
    pub min_batch_size: usize,
    /// Minimum acceptable `max_sequence_length`.
    pub min_sequence_length: usize,
    pub requires_flash_attention: bool,
    pub requires_paged_attention: bool,
    /// Backend-specific requirements as free-form JSON
    /// (NOTE(review): not checked by `meets_requirements` in this file).
    pub extra_requirements: HashMap<String, serde_json::Value>,
}
/// Static capability description of a [`WeightLoader`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightLoaderCapabilities {
    /// Serialization formats the loader can parse.
    pub supported_formats: Vec<WeightFormat>,
    /// Source kinds (file, URL, HF Hub, ...) the loader can read from.
    pub supported_sources: Vec<WeightSourceType>,
    /// Largest single tensor the loader can handle, in bytes.
    pub max_tensor_size: u64,
    /// Whether the loader can stream data rather than buffering whole files.
    pub supports_streaming: bool,
    /// Whether multiple loads may run concurrently.
    pub supports_concurrent: bool,
    /// Transformations the loader can apply post-load.
    pub supported_transformations: Vec<TransformationType>,
}
/// Payload-free tag identifying a [`WeightSource`] variant; used in
/// [`WeightLoaderCapabilities::supported_sources`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum WeightSourceType {
    File,
    Url,
    HuggingFace,
    Memory,
    S3,
}
/// Payload-free tag identifying a [`TensorTransformation`] variant; used in
/// [`WeightLoaderCapabilities::supported_transformations`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum TransformationType {
    Transpose,
    Reshape,
    Cast,
    Quantize,
    Scale,
    Slice,
}
/// Runtime status snapshot returned by [`ComputeBackend::status`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendStatus {
    /// Whether `initialize` has completed successfully.
    pub is_initialized: bool,
    /// Whether the backend is ready to accept work.
    pub is_ready: bool,
    /// Devices the backend is currently active on.
    pub active_devices: Vec<Device>,
    /// Per-device memory usage (presumably bytes — TODO confirm with impls).
    pub memory_usage: HashMap<Device, u64>,
    /// Count of operations completed since initialization.
    pub operations_completed: u64,
    /// Most recent error message, if any.
    pub last_error: Option<String>,
    /// Backend-specific status fields as free-form JSON.
    pub backend_specific: HashMap<String, serde_json::Value>,
}
/// Compiles/loads and executes custom device kernels. Obtained from
/// [`ComputeBackend::kernel_executor`] when the backend supports it.
#[async_trait]
pub trait KernelExecutor: Send + Sync {
    /// Loads a kernel named `name` from `source` (presumably source text or
    /// precompiled module — TODO confirm per backend) for `device`, returning
    /// an opaque handle for later execution.
    async fn load_kernel(&self, source: &str, name: &str, device: &Device) -> Result<KernelHandle>;
    /// Launches a previously loaded kernel with the given 3-D grid and block
    /// dimensions and argument list.
    async fn execute_kernel(
        &self,
        handle: KernelHandle,
        grid_size: (u32, u32, u32),
        block_size: (u32, u32, u32),
        args: &[KernelArg],
    ) -> Result<()>;
    /// Introspection data for a loaded kernel; `None` if the handle is unknown.
    fn get_kernel_info(&self, handle: KernelHandle) -> Option<KernelInfo>;
    /// Unloads a kernel and releases its resources; the handle becomes invalid.
    async fn unload_kernel(&self, handle: KernelHandle) -> Result<()>;
}
/// Opaque identifier for a loaded kernel (newtype over a backend-assigned id).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct KernelHandle(pub u64);
/// One argument passed to a kernel launch.
///
/// NOTE(review): the raw pointer in `Buffer` makes this enum automatically
/// `!Send`/`!Sync`, and nothing here ties the pointer's lifetime to the
/// launch — callers must keep the buffer alive until execution completes.
#[derive(Debug, Clone)]
pub enum KernelArg {
    /// A device tensor argument.
    Tensor(TensorRef),
    /// A raw host/device buffer of `size` bytes at `ptr`.
    Buffer { ptr: *const u8, size: usize },
    /// A scalar value argument.
    Scalar(ScalarValue),
    /// Request for dynamically sized local/shared memory of the given byte count.
    LocalMemory(usize),
}
/// Typed scalar value for kernel arguments, covering the common integer,
/// float, and boolean widths.
#[derive(Debug, Clone)]
pub enum ScalarValue {
    I8(i8),
    I16(i16),
    I32(i32),
    I64(i64),
    U8(u8),
    U16(u16),
    U32(u32),
    U64(u64),
    F32(f32),
    F64(f64),
    Bool(bool),
}
/// Introspection data about a loaded kernel, returned by
/// [`KernelExecutor::get_kernel_info`].
#[derive(Debug, Clone)]
pub struct KernelInfo {
    /// Kernel name as loaded.
    pub name: String,
    /// Hardware/compiler limit on threads per block for this kernel.
    pub max_threads_per_block: u32,
    /// Static shared/local memory used, presumably in bytes (TODO confirm).
    pub shared_memory_size: usize,
    /// Registers consumed per thread.
    pub registers_per_thread: u32,
    /// Block size the backend recommends for launches.
    pub preferred_block_size: (u32, u32, u32),
}
/// Factory for constructing backends and weight loaders from configuration.
#[async_trait]
pub trait BackendFactory: Send + Sync {
    /// Builds a [`ComputeBackend`] from `config`
    /// (NOTE(review): whether the result is already initialized is not
    /// specified here — confirm with implementations).
    async fn create_compute_backend(
        &self,
        config: &BackendConfig,
    ) -> Result<Box<dyn ComputeBackend>>;
    /// Builds a [`WeightLoader`] from `config`.
    async fn create_weight_loader(
        &self,
        config: &WeightLoaderConfig,
    ) -> Result<Box<dyn WeightLoader>>;
    /// Backend types this factory can construct.
    fn supported_backend_types(&self) -> Vec<BackendType>;
    /// Validates `config` without constructing anything.
    fn validate_config(&self, config: &BackendConfig) -> Result<()>;
}
/// Configuration for constructing a [`ComputeBackend`] via [`BackendFactory`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendConfig {
    /// Which backend implementation to construct.
    pub backend_type: BackendType,
    /// Primary device the backend should target.
    pub device: Device,
    /// Optimization level (semantics and range are backend-defined —
    /// NOTE(review): not specified in this file).
    pub optimization_level: u8,
    /// Enables backend debug features/logging.
    pub enable_debug: bool,
    /// Memory pool/allocation settings.
    pub memory_config: BackendMemoryConfig,
    /// Backend-specific options as free-form JSON.
    pub backend_options: HashMap<String, serde_json::Value>,
}
/// Configuration for constructing a [`WeightLoader`] via [`BackendFactory`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightLoaderConfig {
    /// Whether downloaded/parsed weights may be cached on disk.
    pub enable_caching: bool,
    /// Cache directory; `None` presumably means an implementation default.
    pub cache_dir: Option<String>,
    /// Cache size cap in bytes; `None` means unlimited
    /// (NOTE(review): units/unlimited semantics not stated here — confirm).
    pub max_cache_size: Option<u64>,
    /// Maximum number of downloads in flight at once.
    pub max_concurrent_downloads: usize,
    /// Per-download timeout, in seconds.
    pub download_timeout_seconds: u64,
    /// Whether to verify checksums/integrity of fetched weights.
    pub enable_integrity_checks: bool,
    /// Headers added to every HTTP request (e.g. auth tokens).
    pub default_headers: HashMap<String, String>,
}
/// Memory-management settings embedded in [`BackendConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendMemoryConfig {
    /// Initial pool size in bytes; `None` presumably lets the backend choose.
    pub pool_size: Option<u64>,
    /// Allocation alignment (presumably bytes — TODO confirm with impls).
    pub alignment: usize,
    /// Whether to use a memory pool rather than direct allocation.
    pub enable_pooling: bool,
    /// How the pool grows when exhausted.
    pub growth_strategy: MemoryGrowthStrategy,
}
/// Pool growth policy for [`BackendMemoryConfig`].
/// NOTE(review): exact semantics of each variant are defined by the memory
/// manager implementation, not in this file.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum MemoryGrowthStrategy {
    /// Fixed-size pool; never grows.
    Static,
    /// Grows on demand.
    Dynamic,
    /// Grows in fixed increments.
    Incremental,
}
/// Known compute-backend implementations. `Display` renders the lowercase
/// identifier form (e.g. `onnx_runtime`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BackendType {
    Candle,
    OnnxRuntime,
    TensorRT,
    Metal,
    CPU,
    Custom,
}
impl std::fmt::Display for BackendType {
    /// Writes the lowercase identifier for this backend type
    /// (e.g. `BackendType::OnnxRuntime` -> `"onnx_runtime"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            BackendType::Candle => "candle",
            BackendType::OnnxRuntime => "onnx_runtime",
            BackendType::TensorRT => "tensorrt",
            BackendType::Metal => "metal",
            BackendType::CPU => "cpu",
            BackendType::Custom => "custom",
        })
    }
}
/// Registry of named compute backends and weight loaders.
pub trait BackendRegistry: Send + Sync {
    /// Registers a compute backend under `name`
    /// (NOTE(review): behavior on duplicate names is not specified here).
    fn register_compute_backend(
        &mut self,
        name: &str,
        backend: Box<dyn ComputeBackend>,
    ) -> Result<()>;
    /// Registers a weight loader under `name`.
    fn register_weight_loader(&mut self, name: &str, loader: Box<dyn WeightLoader>) -> Result<()>;
    /// Looks up a compute backend by name.
    fn get_compute_backend(&self, name: &str) -> Option<&dyn ComputeBackend>;
    /// Looks up a weight loader by name.
    fn get_weight_loader(&self, name: &str) -> Option<&dyn WeightLoader>;
    /// Finds a registered backend whose capabilities satisfy `requirements`
    /// (selection criteria among multiple matches are implementation-defined).
    fn find_best_compute_backend(
        &self,
        requirements: &BackendRequirements,
    ) -> Option<&dyn ComputeBackend>;
    /// Returns the registered names as `(compute_backend_names, weight_loader_names)`
    /// (NOTE(review): tuple-element order inferred from registration method
    /// order — confirm with implementations).
    fn list_backend_names(&self) -> (Vec<String>, Vec<String>);
}