1use crate::kernel_ops::KernelOps;
8use crate::{TensorFactory, TensorOps, TensorRef};
9use async_trait::async_trait;
10use ferrum_types::{DataType, Device, Result};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
/// A pluggable compute backend (e.g. CPU, CUDA, Metal) that bundles tensor
/// operations, tensor creation, device memory management and optional
/// custom-kernel execution behind one object-safe interface.
#[async_trait]
pub trait ComputeBackend: Send + Sync {
    /// Human-readable identifier for this backend (used e.g. for registry lookup).
    fn name(&self) -> &str;

    /// Static description of what this backend supports.
    fn capabilities(&self) -> BackendCapabilities;

    /// Tensor operations implemented by this backend.
    fn tensor_ops(&self) -> &dyn TensorOps;

    /// Factory for creating tensors on this backend's devices.
    fn tensor_factory(&self) -> &dyn TensorFactory;

    /// Device memory allocator/manager used by this backend.
    fn memory_manager(&self) -> &dyn crate::DeviceMemoryManager;

    /// Executor for custom kernels, or `None` if the backend has none.
    fn kernel_executor(&self) -> Option<&dyn KernelExecutor>;

    /// Optional higher-level kernel operations; backends without them
    /// inherit this default and report `None`.
    fn kernel_ops(&self) -> Option<&dyn KernelOps> {
        None
    }

    /// Prepare the backend for use on `device`.
    /// NOTE(review): presumably required before any other call — confirm
    /// with implementations.
    async fn initialize(&mut self, device: &Device) -> Result<()>;

    /// Whether this backend can run on the given device.
    fn supports_device(&self, device: &Device) -> bool;

    /// Backend version string (format is implementation-defined).
    fn version(&self) -> String;

    /// Wait until all pending work on `device` has completed.
    async fn synchronize(&self, device: &Device) -> Result<()>;

    /// Snapshot of the backend's current runtime status.
    fn status(&self) -> BackendStatus;

    /// Release all resources held by the backend.
    async fn shutdown(&mut self) -> Result<()>;
}
62
/// Loads model weights from a [`WeightSource`] into device tensors.
#[async_trait]
pub trait WeightLoader: Send + Sync {
    /// Load a single tensor described by `spec` (including any requested
    /// transformations) and return a handle to it.
    async fn load_tensor(&self, spec: &TensorSpec) -> Result<TensorRef>;

    /// Load several tensors; results are returned in the same order as `specs`.
    async fn load_tensors(&self, specs: &[TensorSpec]) -> Result<Vec<TensorRef>>;

    /// Whether the given source is currently reachable/usable.
    async fn is_available(&self, source: &WeightSource) -> bool;

    /// Fetch metadata (tensor names, shapes, sizes) without loading weights.
    async fn get_metadata(&self, source: &WeightSource) -> Result<WeightMetadata>;

    /// Warm the loader for `source` (e.g. prefetch/download ahead of use).
    async fn preload(&self, source: &WeightSource) -> Result<()>;

    /// Static description of what this loader supports.
    fn capabilities(&self) -> WeightLoaderCapabilities;
}
84
/// Full description of one tensor to load: identity, layout, placement,
/// origin and the transformations to apply after loading.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorSpec {
    /// Name of the tensor (e.g. a key within the weight file).
    pub name: String,
    /// Expected shape after loading.
    pub shape: Vec<usize>,
    /// Expected element data type.
    pub dtype: DataType,
    /// Device the tensor should end up on.
    pub device: Device,
    /// Where the raw weight bytes come from.
    pub source: WeightSource,
    /// Transformations applied after loading, in order.
    pub transformations: Vec<TensorTransformation>,
}
101
/// Origin of raw weight data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WeightSource {
    /// A file on the local filesystem.
    File {
        path: String,
        /// Specific tensor within the file; `None` presumably means the
        /// whole file / a single-tensor file — TODO confirm with loaders.
        tensor_name: Option<String>,
    },
    /// A URL fetched over HTTP(S) with optional extra request headers.
    Url {
        url: String,
        headers: HashMap<String, String>,
    },
    /// A file hosted in a Hugging Face Hub repository.
    HuggingFace {
        repo_id: String,
        filename: String,
        /// Branch/tag/commit; `None` falls back to the loader's default.
        revision: Option<String>,
        /// Local cache directory override.
        cache_dir: Option<String>,
    },
    /// Weight bytes already resident in memory, in the given serialization format.
    Memory { data: Vec<u8>, format: WeightFormat },
    /// An object in an S3(-compatible) bucket.
    S3 {
        bucket: String,
        key: String,
        region: Option<String>,
        /// Custom endpoint for S3-compatible stores (e.g. MinIO).
        endpoint: Option<String>,
    },
}
133
/// Serialization format of a weight payload.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum WeightFormat {
    /// PyTorch checkpoint (`.pt`/`.bin`).
    PyTorch,
    /// safetensors format.
    SafeTensors,
    /// NumPy `.npy`/`.npz`.
    Numpy,
    /// Raw bytes with no container format.
    Raw,
    /// ONNX model file.
    Onnx,
    /// Loader-defined format, discriminated by an opaque tag.
    Custom(u32),
}
150
/// Metadata about a weight source, obtainable without loading the weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightMetadata {
    /// Tensor name -> shape.
    pub tensors: HashMap<String, Vec<usize>>,
    /// Serialization format of the source.
    pub format: WeightFormat,
    /// Total payload size in bytes.
    pub total_size_bytes: u64,
    /// Data types present in the source.
    pub dtypes: Vec<DataType>,
    /// Loader/format-specific extra metadata.
    pub extra: HashMap<String, serde_json::Value>,
}
165
/// A post-load transformation applied to a tensor, in declaration order
/// (see [`TensorSpec::transformations`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TensorTransformation {
    /// Swap two dimensions.
    Transpose { dim0: usize, dim1: usize },
    /// Reshape to the given shape.
    Reshape { shape: Vec<usize> },
    /// Cast elements to another data type.
    Cast { dtype: DataType },
    /// Quantize according to the given scheme.
    Quantize { config: QuantizationConfig },
    /// Multiply every element by `factor`.
    Scale { factor: f32 },
    /// Take a sub-range along `dim`; `None` bounds presumably mean
    /// "from the start" / "to the end" — TODO confirm with implementations.
    Slice {
        dim: usize,
        start: Option<usize>,
        end: Option<usize>,
    },
}
186
/// Quantization scheme and its parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuantizationConfig {
    /// 8-bit integer quantization; `symmetric` selects symmetric vs.
    /// asymmetric (zero-point) mode.
    INT8 { symmetric: bool },
    /// 4-bit integer quantization with per-group scales.
    INT4 { group_size: usize },
    /// 8-bit float; `e4m3` selects the E4M3 layout (otherwise presumably
    /// E5M2 — TODO confirm).
    FP8 { e4m3: bool },
    /// GPTQ quantization parameters.
    GPTQ {
        bits: u8,
        group_size: usize,
        /// Activation-order ("desc_act") reordering flag.
        desc_act: bool,
    },
    /// AWQ quantization parameters.
    AWQ { bits: u8, zero_point: bool },
}
205
/// Static feature/limit description of a [`ComputeBackend`], matched against
/// [`BackendRequirements`] when selecting a backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendCapabilities {
    /// Element data types this backend can compute with.
    pub supported_dtypes: Vec<DataType>,
    /// Devices this backend can run on.
    pub supported_devices: Vec<Device>,
    /// Maximum tensor rank supported.
    pub max_tensor_dims: usize,
    pub supports_fp16: bool,
    pub supports_bf16: bool,
    pub supports_int8: bool,
    pub supports_flash_attention: bool,
    pub supports_paged_attention: bool,
    pub supports_tensor_parallelism: bool,
    pub supports_pipeline_parallelism: bool,
    /// Largest batch size the backend accepts.
    pub max_batch_size: usize,
    /// Longest sequence length the backend accepts.
    pub max_sequence_length: usize,
    /// Required memory alignment in bytes.
    pub memory_alignment: usize,
    pub supports_custom_kernels: bool,
    pub supports_cuda_graphs: bool,
    /// Backend-specific capabilities not covered by the fields above.
    pub extra_capabilities: HashMap<String, serde_json::Value>,
}
242
243impl BackendCapabilities {
244 pub fn meets_requirements(&self, requirements: &BackendRequirements) -> bool {
246 if !requirements
248 .required_devices
249 .iter()
250 .all(|dev| self.supported_devices.contains(dev))
251 {
252 return false;
253 }
254
255 if !requirements
257 .required_dtypes
258 .iter()
259 .all(|dtype| self.supported_dtypes.contains(dtype))
260 {
261 return false;
262 }
263
264 if requirements.min_batch_size > self.max_batch_size {
266 return false;
267 }
268
269 if requirements.min_sequence_length > self.max_sequence_length {
271 return false;
272 }
273
274 true
275 }
276}
277
/// Requirements a caller has of a backend; checked against
/// [`BackendCapabilities`] via `meets_requirements`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendRequirements {
    /// Devices the backend must support (all of them).
    pub required_devices: Vec<Device>,
    /// Data types the backend must support (all of them).
    pub required_dtypes: Vec<DataType>,
    /// Minimum batch size the backend must accept.
    pub min_batch_size: usize,
    /// Minimum sequence length the backend must accept.
    pub min_sequence_length: usize,
    pub requires_flash_attention: bool,
    pub requires_paged_attention: bool,
    /// Free-form extra requirements; not evaluated by `meets_requirements`.
    pub extra_requirements: HashMap<String, serde_json::Value>,
}
296
/// Static description of what a [`WeightLoader`] supports.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightLoaderCapabilities {
    /// Serialization formats this loader can parse.
    pub supported_formats: Vec<WeightFormat>,
    /// Source kinds (file, URL, S3, …) this loader can read from.
    pub supported_sources: Vec<WeightSourceType>,
    /// Largest single tensor the loader will handle, in bytes.
    pub max_tensor_size: u64,
    /// Whether the loader can stream rather than buffer whole payloads.
    pub supports_streaming: bool,
    /// Whether multiple loads may run concurrently.
    pub supports_concurrent: bool,
    /// Transformations the loader can apply after loading.
    pub supported_transformations: Vec<TransformationType>,
}
313
/// Discriminant-only mirror of [`WeightSource`], used to describe supported
/// source kinds without carrying payload data.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum WeightSourceType {
    File,
    Url,
    HuggingFace,
    Memory,
    S3,
}
323
/// Discriminant-only mirror of [`TensorTransformation`], used to describe
/// supported transformations without carrying parameters.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum TransformationType {
    Transpose,
    Reshape,
    Cast,
    Quantize,
    Scale,
    Slice,
}
334
/// Runtime status snapshot reported by [`ComputeBackend::status`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendStatus {
    /// Whether `initialize` has completed successfully.
    pub is_initialized: bool,
    /// Whether the backend is ready to accept work.
    pub is_ready: bool,
    /// Devices currently in use by the backend.
    pub active_devices: Vec<Device>,
    /// Per-device memory usage; units presumably bytes — TODO confirm
    /// with implementations.
    pub memory_usage: HashMap<Device, u64>,
    /// Total operations completed since initialization.
    pub operations_completed: u64,
    /// Most recent error message, if any.
    pub last_error: Option<String>,
    /// Backend-specific status fields.
    pub backend_specific: HashMap<String, serde_json::Value>,
}
353
/// Compiles/loads and launches custom device kernels.
#[async_trait]
pub trait KernelExecutor: Send + Sync {
    /// Compile/load kernel `name` from `source` for `device` and return an
    /// opaque handle for later execution.
    /// NOTE(review): whether `source` is source text or a path is
    /// implementation-defined — confirm with implementations.
    async fn load_kernel(&self, source: &str, name: &str, device: &Device) -> Result<KernelHandle>;

    /// Launch a previously loaded kernel with the given launch configuration
    /// (grid/block dimensions as (x, y, z)) and argument list.
    async fn execute_kernel(
        &self,
        handle: KernelHandle,
        grid_size: (u32, u32, u32),
        block_size: (u32, u32, u32),
        args: &[KernelArg],
    ) -> Result<()>;

    /// Introspection info for a loaded kernel; `None` if the handle is unknown.
    fn get_kernel_info(&self, handle: KernelHandle) -> Option<KernelInfo>;

    /// Unload the kernel and free its resources; the handle becomes invalid.
    async fn unload_kernel(&self, handle: KernelHandle) -> Result<()>;
}
375
/// Opaque identifier for a kernel loaded via [`KernelExecutor::load_kernel`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct KernelHandle(pub u64);
379
/// One argument passed to a kernel launch.
///
/// NOTE(review): the raw pointer in `Buffer` makes this enum `!Send`/`!Sync`
/// by auto-trait rules, and nothing here ties the pointer's lifetime to the
/// launch — callers must keep the buffer alive until execution completes.
#[derive(Debug, Clone)]
pub enum KernelArg {
    /// A tensor argument.
    Tensor(TensorRef),
    /// A raw device/host buffer of `size` bytes at `ptr`.
    Buffer { ptr: *const u8, size: usize },
    /// A scalar value.
    Scalar(ScalarValue),
    /// Dynamically sized local/shared memory allocation, in bytes.
    LocalMemory(usize),
}
392
/// A typed scalar kernel argument.
#[derive(Debug, Clone)]
pub enum ScalarValue {
    I8(i8),
    I16(i16),
    I32(i32),
    I64(i64),
    U8(u8),
    U16(u16),
    U32(u32),
    U64(u64),
    F32(f32),
    F64(f64),
    Bool(bool),
}
408
/// Introspection data about a loaded kernel, reported by
/// [`KernelExecutor::get_kernel_info`].
#[derive(Debug, Clone)]
pub struct KernelInfo {
    /// Kernel name as given at load time.
    pub name: String,
    /// Hardware limit on threads per block for this kernel.
    pub max_threads_per_block: u32,
    /// Static shared/local memory used per block, in bytes.
    pub shared_memory_size: usize,
    /// Registers consumed per thread.
    pub registers_per_thread: u32,
    /// Suggested (x, y, z) block size for good occupancy.
    pub preferred_block_size: (u32, u32, u32),
}
423
/// Constructs compute backends and weight loaders from configuration.
#[async_trait]
pub trait BackendFactory: Send + Sync {
    /// Build a [`ComputeBackend`] from `config`.
    async fn create_compute_backend(
        &self,
        config: &BackendConfig,
    ) -> Result<Box<dyn ComputeBackend>>;

    /// Build a [`WeightLoader`] from `config`.
    async fn create_weight_loader(
        &self,
        config: &WeightLoaderConfig,
    ) -> Result<Box<dyn WeightLoader>>;

    /// Backend types this factory can construct.
    fn supported_backend_types(&self) -> Vec<BackendType>;

    /// Validate `config` without constructing a backend.
    fn validate_config(&self, config: &BackendConfig) -> Result<()>;
}
445
/// Configuration used by [`BackendFactory::create_compute_backend`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendConfig {
    /// Which backend implementation to construct.
    pub backend_type: BackendType,
    /// Primary device to target.
    pub device: Device,
    /// Optimization level; meaning of the levels is backend-defined.
    pub optimization_level: u8,
    /// Enable debug instrumentation/validation.
    pub enable_debug: bool,
    /// Memory allocator configuration.
    pub memory_config: BackendMemoryConfig,
    /// Free-form backend-specific options.
    pub backend_options: HashMap<String, serde_json::Value>,
}
462
/// Configuration used by [`BackendFactory::create_weight_loader`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightLoaderConfig {
    /// Cache downloaded/parsed weights between loads.
    pub enable_caching: bool,
    /// Cache location override; `None` uses the loader's default.
    pub cache_dir: Option<String>,
    /// Cache size cap in bytes; `None` means unbounded.
    pub max_cache_size: Option<u64>,
    /// Maximum number of simultaneous downloads.
    pub max_concurrent_downloads: usize,
    /// Per-download timeout in seconds.
    pub download_timeout_seconds: u64,
    /// Verify payload integrity (e.g. checksums) after download.
    pub enable_integrity_checks: bool,
    /// Headers added to every HTTP request.
    pub default_headers: HashMap<String, String>,
}
481
/// Memory allocator settings for a compute backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendMemoryConfig {
    /// Initial pool size in bytes; `None` lets the backend choose.
    pub pool_size: Option<u64>,
    /// Allocation alignment in bytes.
    pub alignment: usize,
    /// Use pooled allocation rather than raw per-allocation calls.
    pub enable_pooling: bool,
    /// How the pool grows when exhausted.
    pub growth_strategy: MemoryGrowthStrategy,
}
494
/// How a memory pool behaves when it runs out of space.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum MemoryGrowthStrategy {
    /// Fixed-size pool; allocation fails when exhausted.
    Static,
    /// Grow on demand; growth policy is backend-defined.
    Dynamic,
    /// Grow in fixed-size increments.
    Incremental,
}
505
/// Known compute backend implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BackendType {
    /// Candle (Rust-native tensor library).
    Candle,
    /// ONNX Runtime.
    OnnxRuntime,
    /// NVIDIA TensorRT.
    TensorRT,
    /// Apple Metal.
    Metal,
    /// Plain CPU backend.
    CPU,
    /// User-provided backend.
    Custom,
}
522
523impl std::fmt::Display for BackendType {
524 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
525 let name = match self {
526 BackendType::Candle => "candle",
527 BackendType::OnnxRuntime => "onnx_runtime",
528 BackendType::TensorRT => "tensorrt",
529 BackendType::Metal => "metal",
530 BackendType::CPU => "cpu",
531 BackendType::Custom => "custom",
532 };
533 write!(f, "{}", name)
534 }
535}
536
/// Registry of named compute backends and weight loaders.
pub trait BackendRegistry: Send + Sync {
    /// Register a compute backend under `name`.
    /// NOTE(review): behavior on duplicate names (replace vs. error) is
    /// implementation-defined — confirm with implementations.
    fn register_compute_backend(
        &mut self,
        name: &str,
        backend: Box<dyn ComputeBackend>,
    ) -> Result<()>;

    /// Register a weight loader under `name`.
    fn register_weight_loader(&mut self, name: &str, loader: Box<dyn WeightLoader>) -> Result<()>;

    /// Look up a compute backend by name.
    fn get_compute_backend(&self, name: &str) -> Option<&dyn ComputeBackend>;

    /// Look up a weight loader by name.
    fn get_weight_loader(&self, name: &str) -> Option<&dyn WeightLoader>;

    /// Find a registered compute backend whose capabilities satisfy
    /// `requirements`; selection order among multiple matches is
    /// implementation-defined.
    fn find_best_compute_backend(
        &self,
        requirements: &BackendRequirements,
    ) -> Option<&dyn ComputeBackend>;

    /// All registered names; presumably (compute backends, weight loaders)
    /// — TODO confirm tuple order with implementations.
    fn list_backend_names(&self) -> (Vec<String>, Vec<String>);
}