Skip to main content

ferrum_interfaces/
backend.rs

1//! Backend abstraction split into compute and weight loading concerns
2//!
3//! This module separates the previous "fat" Backend trait into focused
4//! interfaces: ComputeBackend for tensor operations and WeightLoader for
5//! model weight management.
6
7use crate::kernel_ops::KernelOps;
8use crate::{TensorFactory, TensorOps, TensorRef};
9use async_trait::async_trait;
10use ferrum_types::{DataType, Device, Result};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
/// Compute backend for tensor operations and kernel execution
///
/// One implementation exists per execution target (CPU, CUDA, Metal, ...).
/// Accessor methods hand out the focused sub-interfaces (tensor ops,
/// tensor factory, memory manager, optional kernel support); the lifecycle
/// is driven through `initialize` / `synchronize` / `shutdown`.
#[async_trait]
pub trait ComputeBackend: Send + Sync {
    /// Get backend name/identifier
    fn name(&self) -> &str;

    /// Get backend capabilities
    ///
    /// Returned by value; callers may hold the snapshot across await points.
    fn capabilities(&self) -> BackendCapabilities;

    /// Get tensor operations interface
    fn tensor_ops(&self) -> &dyn TensorOps;

    /// Get tensor factory for creating tensors
    fn tensor_factory(&self) -> &dyn TensorFactory;

    /// Get memory manager for this backend
    fn memory_manager(&self) -> &dyn crate::DeviceMemoryManager;

    /// Get kernel executor (if backend supports custom kernels)
    ///
    /// `None` means the backend has no custom-kernel support.
    fn kernel_executor(&self) -> Option<&dyn KernelExecutor>;

    /// Get LLM-specific kernel operations (if backend provides optimized impls).
    ///
    /// Returns `None` by default — existing backends compile unchanged.
    /// Backends that implement `KernelOps` sub-traits (NormOps, PositionOps, etc.)
    /// return `Some` here to enable accelerated paths.
    fn kernel_ops(&self) -> Option<&dyn KernelOps> {
        None
    }

    /// Initialize backend with device
    ///
    /// Takes `&mut self`, so it must run before the backend is shared
    /// behind `Arc`/trait objects.
    async fn initialize(&mut self, device: &Device) -> Result<()>;

    /// Check if backend supports specific device
    fn supports_device(&self, device: &Device) -> bool;

    /// Get backend version
    fn version(&self) -> String;

    /// Synchronize all pending operations
    async fn synchronize(&self, device: &Device) -> Result<()>;

    /// Get backend status
    fn status(&self) -> BackendStatus;

    /// Shutdown backend gracefully
    async fn shutdown(&mut self) -> Result<()>;
}
62
63/// Weight loading interface for model parameter management
64#[async_trait]
65pub trait WeightLoader: Send + Sync {
66    /// Load tensor from weight specification
67    async fn load_tensor(&self, spec: &TensorSpec) -> Result<TensorRef>;
68
69    /// Load multiple tensors at once (batch loading)
70    async fn load_tensors(&self, specs: &[TensorSpec]) -> Result<Vec<TensorRef>>;
71
72    /// Check if weight source is available
73    async fn is_available(&self, source: &WeightSource) -> bool;
74
75    /// Get metadata about weight source
76    async fn get_metadata(&self, source: &WeightSource) -> Result<WeightMetadata>;
77
78    /// Preload weights into cache/memory
79    async fn preload(&self, source: &WeightSource) -> Result<()>;
80
81    /// Get loader capabilities
82    fn capabilities(&self) -> WeightLoaderCapabilities;
83}
84
/// Tensor specification for weight loading
///
/// Describes one tensor to be materialized by a `WeightLoader`: where it
/// comes from, what shape/dtype/device it should end up with, and which
/// transformations to apply (presumably in listed order — confirm with the
/// loader implementation).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorSpec {
    /// Name/identifier of the tensor
    pub name: String,
    /// Expected tensor shape
    pub shape: Vec<usize>,
    /// Target data type
    pub dtype: DataType,
    /// Target device
    pub device: Device,
    /// Weight source location
    pub source: WeightSource,
    /// Optional transformations to apply
    pub transformations: Vec<TensorTransformation>,
}
101
/// Weight source specification
///
/// Serializable description of where tensor bytes live. `WeightSourceType`
/// is the discriminant-only counterpart used in capability lists.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WeightSource {
    /// Local file path
    File {
        path: String,
        /// Tensor name within file (for formats like safetensors)
        tensor_name: Option<String>,
    },
    /// URL for download
    Url {
        url: String,
        /// Extra HTTP headers sent with the request (e.g. auth tokens)
        headers: HashMap<String, String>,
    },
    /// Hugging Face Hub
    HuggingFace {
        repo_id: String,
        filename: String,
        /// Branch/tag/commit; None presumably means the default branch — confirm
        revision: Option<String>,
        cache_dir: Option<String>,
    },
    /// Raw bytes in memory
    Memory { data: Vec<u8>, format: WeightFormat },
    /// S3-compatible storage
    S3 {
        bucket: String,
        key: String,
        region: Option<String>,
        /// Custom endpoint for S3-compatible services (MinIO, R2, ...)
        endpoint: Option<String>,
    },
}
133
/// Weight file formats
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum WeightFormat {
    /// PyTorch tensor format
    PyTorch,
    /// Safetensors format
    SafeTensors,
    /// NumPy array format
    Numpy,
    /// Raw binary data
    Raw,
    /// ONNX format
    Onnx,
    /// Custom format
    ///
    /// The `u32` identifier's meaning is loader-defined — not interpreted here.
    Custom(u32),
}
150
/// Weight metadata information
///
/// Summary of a weight source as returned by `WeightLoader::get_metadata`,
/// without loading tensor data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightMetadata {
    /// Available tensor names and their shapes
    pub tensors: HashMap<String, Vec<usize>>,
    /// File format
    pub format: WeightFormat,
    /// Total size in bytes
    pub total_size_bytes: u64,
    /// Data types used
    pub dtypes: Vec<DataType>,
    /// Additional metadata
    pub extra: HashMap<String, serde_json::Value>,
}
165
/// Transformations that can be applied to loaded tensors
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TensorTransformation {
    /// Transpose dimensions
    Transpose { dim0: usize, dim1: usize },
    /// Reshape to new shape
    Reshape { shape: Vec<usize> },
    /// Convert data type
    Cast { dtype: DataType },
    /// Quantize tensor
    Quantize { config: QuantizationConfig },
    /// Apply scaling
    Scale { factor: f32 },
    /// Slice tensor
    ///
    /// `start`/`end` of None presumably mean the beginning/end of the
    /// dimension — confirm against the loader implementation.
    Slice {
        dim: usize,
        start: Option<usize>,
        end: Option<usize>,
    },
}
186
/// Quantization configuration for weights
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuantizationConfig {
    /// INT8 uniform quantization
    INT8 { symmetric: bool },
    /// INT4 grouped quantization
    INT4 { group_size: usize },
    /// FP8 quantization
    ///
    /// `e4m3 == true` selects the e4m3 layout; false presumably means
    /// e5m2 — confirm with the backend implementations.
    FP8 { e4m3: bool },
    /// GPTQ quantization
    GPTQ {
        bits: u8,
        group_size: usize,
        /// Activation-order (desc_act) quantization flag
        desc_act: bool,
    },
    /// AWQ quantization
    AWQ { bits: u8, zero_point: bool },
}
205
/// Backend capabilities description
///
/// Advertised by `ComputeBackend::capabilities` and matched against
/// `BackendRequirements` during backend selection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendCapabilities {
    /// Supported data types
    pub supported_dtypes: Vec<DataType>,
    /// Supported devices
    pub supported_devices: Vec<Device>,
    /// Maximum tensor dimensions supported
    pub max_tensor_dims: usize,
    /// Whether backend supports FP16 operations
    pub supports_fp16: bool,
    /// Whether backend supports BF16 operations
    pub supports_bf16: bool,
    /// Whether backend supports INT8 quantization
    pub supports_int8: bool,
    /// Whether backend supports flash attention
    pub supports_flash_attention: bool,
    /// Whether backend supports paged attention
    pub supports_paged_attention: bool,
    /// Whether backend supports tensor parallelism
    pub supports_tensor_parallelism: bool,
    /// Whether backend supports pipeline parallelism
    pub supports_pipeline_parallelism: bool,
    /// Maximum batch size supported
    pub max_batch_size: usize,
    /// Maximum sequence length supported
    pub max_sequence_length: usize,
    /// Memory alignment requirements
    pub memory_alignment: usize,
    /// Whether backend supports custom kernels
    pub supports_custom_kernels: bool,
    /// Whether backend supports CUDA graphs
    pub supports_cuda_graphs: bool,
    /// Additional capabilities
    ///
    /// Free-form, backend-specific extensions keyed by name.
    pub extra_capabilities: HashMap<String, serde_json::Value>,
}
242
243impl BackendCapabilities {
244    /// Check if capabilities meet requirements
245    pub fn meets_requirements(&self, requirements: &BackendRequirements) -> bool {
246        // Check devices
247        if !requirements
248            .required_devices
249            .iter()
250            .all(|dev| self.supported_devices.contains(dev))
251        {
252            return false;
253        }
254
255        // Check dtypes
256        if !requirements
257            .required_dtypes
258            .iter()
259            .all(|dtype| self.supported_dtypes.contains(dtype))
260        {
261            return false;
262        }
263
264        // Check batch size
265        if requirements.min_batch_size > self.max_batch_size {
266            return false;
267        }
268
269        // Check sequence length
270        if requirements.min_sequence_length > self.max_sequence_length {
271            return false;
272        }
273
274        true
275    }
276}
277
/// Requirements for backend selection
///
/// Matched against `BackendCapabilities` when choosing among registered
/// backends (see `BackendRegistry::find_best_compute_backend`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendRequirements {
    /// Required devices
    pub required_devices: Vec<Device>,
    /// Required data types
    pub required_dtypes: Vec<DataType>,
    /// Minimum batch size needed
    pub min_batch_size: usize,
    /// Minimum sequence length needed
    pub min_sequence_length: usize,
    /// Whether flash attention is required
    pub requires_flash_attention: bool,
    /// Whether paged attention is required
    pub requires_paged_attention: bool,
    /// Additional requirements
    ///
    /// Free-form, backend-specific requirements keyed by name.
    pub extra_requirements: HashMap<String, serde_json::Value>,
}
296
/// Weight loader capabilities
///
/// Advertised by `WeightLoader::capabilities` so callers can pick a loader
/// that handles a given source, format, and transformation set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightLoaderCapabilities {
    /// Supported weight formats
    pub supported_formats: Vec<WeightFormat>,
    /// Supported weight sources
    pub supported_sources: Vec<WeightSourceType>,
    /// Maximum single tensor size in bytes
    pub max_tensor_size: u64,
    /// Whether loader supports streaming/chunked loading
    pub supports_streaming: bool,
    /// Whether loader supports concurrent loading
    pub supports_concurrent: bool,
    /// Supported transformations
    pub supported_transformations: Vec<TransformationType>,
}
313
/// Weight source types
///
/// Discriminant-only mirror of `WeightSource`, used in capability lists.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum WeightSourceType {
    /// Local file path
    File,
    /// HTTP(S) URL
    Url,
    /// Hugging Face Hub
    HuggingFace,
    /// Raw in-memory bytes
    Memory,
    /// S3-compatible storage
    S3,
}
323
/// Transformation types
///
/// Discriminant-only mirror of `TensorTransformation`, used in capability lists.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum TransformationType {
    Transpose,
    Reshape,
    Cast,
    Quantize,
    Scale,
    Slice,
}
334
/// Backend status information
///
/// Point-in-time snapshot returned by `ComputeBackend::status`.
///
/// NOTE(review): `memory_usage` uses `Device` as a map key, which requires
/// `Device: Eq + Hash` and will fail JSON serialization unless `Device`
/// serializes to a string key — verify against ferrum_types.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendStatus {
    /// Whether backend is initialized
    pub is_initialized: bool,
    /// Whether backend is ready for operations
    pub is_ready: bool,
    /// Currently active devices
    pub active_devices: Vec<Device>,
    /// Memory usage per device
    pub memory_usage: HashMap<Device, u64>,
    /// Number of operations completed
    pub operations_completed: u64,
    /// Last error (if any)
    pub last_error: Option<String>,
    /// Backend-specific status information
    pub backend_specific: HashMap<String, serde_json::Value>,
}
353
/// Kernel executor for custom GPU kernels
///
/// Kernels are compiled from source via `load_kernel`, referenced through
/// opaque `KernelHandle`s, and launched with explicit grid/block dimensions.
#[async_trait]
pub trait KernelExecutor: Send + Sync {
    /// Load kernel from source code
    ///
    /// `name` identifies the entry point within `source`.
    async fn load_kernel(&self, source: &str, name: &str, device: &Device) -> Result<KernelHandle>;

    /// Execute kernel with arguments
    ///
    /// `grid_size` / `block_size` are (x, y, z) launch dimensions.
    async fn execute_kernel(
        &self,
        handle: KernelHandle,
        grid_size: (u32, u32, u32),
        block_size: (u32, u32, u32),
        args: &[KernelArg],
    ) -> Result<()>;

    /// Get kernel information
    ///
    /// Returns `None` for unknown (e.g. already unloaded) handles.
    fn get_kernel_info(&self, handle: KernelHandle) -> Option<KernelInfo>;

    /// Unload kernel
    async fn unload_kernel(&self, handle: KernelHandle) -> Result<()>;
}
375
/// Handle for loaded kernel
///
/// Opaque id issued by `KernelExecutor::load_kernel`; `Copy` + `Hash` so it
/// can be freely passed around and used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct KernelHandle(pub u64);
379
/// Kernel argument types
///
/// NOTE(review): the raw pointer in `Buffer` suppresses the `Send`/`Sync`
/// auto-traits for this enum, which constrains how futures borrowing
/// `&[KernelArg]` (see `KernelExecutor::execute_kernel`) can be spawned —
/// confirm this is intentional.
#[derive(Debug, Clone)]
pub enum KernelArg {
    /// Tensor reference
    Tensor(TensorRef),
    /// Raw memory buffer
    ///
    /// Presumably dereferenced by the executor: the caller must keep the
    /// pointed-to buffer alive and valid for `size` bytes for the duration
    /// of the launch — TODO confirm the exact lifetime contract.
    Buffer { ptr: *const u8, size: usize },
    /// Scalar value
    Scalar(ScalarValue),
    /// Local/shared memory allocation
    LocalMemory(usize),
}
392
/// Scalar values for kernel arguments
///
/// Type-tagged primitive values passed by value to kernels.
#[derive(Debug, Clone)]
pub enum ScalarValue {
    I8(i8),
    I16(i16),
    I32(i32),
    I64(i64),
    U8(u8),
    U16(u16),
    U32(u32),
    U64(u64),
    F32(f32),
    F64(f64),
    Bool(bool),
}
408
/// Kernel information and metadata
///
/// Returned by `KernelExecutor::get_kernel_info`; useful for choosing
/// launch dimensions.
#[derive(Debug, Clone)]
pub struct KernelInfo {
    /// Kernel name
    pub name: String,
    /// Maximum threads per block
    pub max_threads_per_block: u32,
    /// Shared memory size required
    pub shared_memory_size: usize,
    /// Register count per thread
    pub registers_per_thread: u32,
    /// Preferred block size
    ///
    /// Same (x, y, z) layout as `block_size` in `execute_kernel`.
    pub preferred_block_size: (u32, u32, u32),
}
423
/// Backend factory for creating backend instances
///
/// Entry point for constructing `ComputeBackend` / `WeightLoader`
/// implementations from serializable configuration.
#[async_trait]
pub trait BackendFactory: Send + Sync {
    /// Create compute backend
    async fn create_compute_backend(
        &self,
        config: &BackendConfig,
    ) -> Result<Box<dyn ComputeBackend>>;

    /// Create weight loader
    async fn create_weight_loader(
        &self,
        config: &WeightLoaderConfig,
    ) -> Result<Box<dyn WeightLoader>>;

    /// Get supported backend types
    fn supported_backend_types(&self) -> Vec<BackendType>;

    /// Validate backend configuration
    ///
    /// Cheap, non-constructive check; `Ok(())` means the config is
    /// acceptable to `create_compute_backend`.
    fn validate_config(&self, config: &BackendConfig) -> Result<()>;
}
445
/// Backend configuration
///
/// Input to `BackendFactory::create_compute_backend`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendConfig {
    /// Backend type
    pub backend_type: BackendType,
    /// Target device
    pub device: Device,
    /// Optimization level (0-3)
    pub optimization_level: u8,
    /// Enable debugging
    pub enable_debug: bool,
    /// Memory configuration
    pub memory_config: BackendMemoryConfig,
    /// Backend-specific options
    ///
    /// Free-form options interpreted by the concrete backend.
    pub backend_options: HashMap<String, serde_json::Value>,
}
462
/// Weight loader configuration
///
/// Input to `BackendFactory::create_weight_loader`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightLoaderConfig {
    /// Enable caching
    pub enable_caching: bool,
    /// Cache directory
    ///
    /// `None` presumably selects a loader-default location — confirm.
    pub cache_dir: Option<String>,
    /// Maximum cache size in bytes
    pub max_cache_size: Option<u64>,
    /// Number of concurrent downloads
    pub max_concurrent_downloads: usize,
    /// Connection timeout for downloads
    pub download_timeout_seconds: u64,
    /// Enable integrity checks
    pub enable_integrity_checks: bool,
    /// Custom headers for HTTP requests
    pub default_headers: HashMap<String, String>,
}
481
/// Memory configuration for backends
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendMemoryConfig {
    /// Memory pool size in bytes (None for auto)
    pub pool_size: Option<u64>,
    /// Memory alignment in bytes
    pub alignment: usize,
    /// Enable memory pooling
    pub enable_pooling: bool,
    /// Memory growth strategy
    pub growth_strategy: MemoryGrowthStrategy,
}
494
/// Memory growth strategies
///
/// Controls how the backend's memory pool acquires capacity over time.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum MemoryGrowthStrategy {
    /// Pre-allocate all memory upfront
    Static,
    /// Grow memory as needed
    Dynamic,
    /// Pre-allocate with incremental growth
    Incremental,
}
505
/// Backend types
///
/// The `Display` impl below renders the lowercase identifier used in
/// configuration and logs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BackendType {
    /// Candle framework
    Candle,
    /// ONNX Runtime
    OnnxRuntime,
    /// TensorRT
    TensorRT,
    /// Custom Metal implementation
    Metal,
    /// Custom CPU implementation
    CPU,
    /// Custom backend
    Custom,
}
522
523impl std::fmt::Display for BackendType {
524    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
525        let name = match self {
526            BackendType::Candle => "candle",
527            BackendType::OnnxRuntime => "onnx_runtime",
528            BackendType::TensorRT => "tensorrt",
529            BackendType::Metal => "metal",
530            BackendType::CPU => "cpu",
531            BackendType::Custom => "custom",
532        };
533        write!(f, "{}", name)
534    }
535}
536
/// Backend registry for managing multiple backends
///
/// Holds named `ComputeBackend`s and `WeightLoader`s and supports
/// requirement-driven selection among the registered compute backends.
pub trait BackendRegistry: Send + Sync {
    /// Register compute backend
    ///
    /// Behavior when `name` is already registered is implementation-defined
    /// (signalled through the `Result`).
    fn register_compute_backend(
        &mut self,
        name: &str,
        backend: Box<dyn ComputeBackend>,
    ) -> Result<()>;

    /// Register weight loader
    fn register_weight_loader(&mut self, name: &str, loader: Box<dyn WeightLoader>) -> Result<()>;

    /// Get compute backend by name
    fn get_compute_backend(&self, name: &str) -> Option<&dyn ComputeBackend>;

    /// Get weight loader by name
    fn get_weight_loader(&self, name: &str) -> Option<&dyn WeightLoader>;

    /// Find best compute backend for requirements
    ///
    /// Returns `None` when no registered backend satisfies `requirements`.
    fn find_best_compute_backend(
        &self,
        requirements: &BackendRequirements,
    ) -> Option<&dyn ComputeBackend>;

    /// List all registered backend names
    ///
    /// Returns two lists: compute backend names first, weight loader
    /// names second.
    fn list_backend_names(&self) -> (Vec<String>, Vec<String>); // (compute, weight)
}