// ferrum-interfaces 0.6.0 — core trait contracts for the Ferrum LLM inference engine.
// (A docs.rs line-number gutter from the documentation extraction was removed here;
// it carried no content.)
//! Backend abstraction split into compute and weight loading concerns
//!
//! This module separates the previous "fat" Backend trait into focused
//! interfaces: ComputeBackend for tensor operations and WeightLoader for
//! model weight management.

use crate::kernel_ops::KernelOps;
use crate::{TensorFactory, TensorOps, TensorRef};
use async_trait::async_trait;
use ferrum_types::{DataType, Device, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Compute backend for tensor operations and kernel execution
///
/// Object-safe facade over a device backend. Accessor methods expose the
/// tensor/memory sub-interfaces; async methods manage the lifecycle
/// (`initialize` → operate → `shutdown`) and device synchronization.
#[async_trait]
pub trait ComputeBackend: Send + Sync {
    /// Get backend name/identifier
    fn name(&self) -> &str;

    /// Get backend capabilities
    fn capabilities(&self) -> BackendCapabilities;

    /// Get tensor operations interface
    fn tensor_ops(&self) -> &dyn TensorOps;

    /// Get tensor factory for creating tensors
    fn tensor_factory(&self) -> &dyn TensorFactory;

    /// Get memory manager for this backend
    fn memory_manager(&self) -> &dyn crate::DeviceMemoryManager;

    /// Get kernel executor (if backend supports custom kernels)
    ///
    /// `None` means this backend cannot load/run user-supplied kernels.
    fn kernel_executor(&self) -> Option<&dyn KernelExecutor>;

    /// Get LLM-specific kernel operations (if backend provides optimized impls).
    ///
    /// Returns `None` by default — existing backends compile unchanged.
    /// Backends that implement `KernelOps` sub-traits (NormOps, PositionOps, etc.)
    /// return `Some` here to enable accelerated paths.
    fn kernel_ops(&self) -> Option<&dyn KernelOps> {
        None
    }

    /// Initialize backend with device
    ///
    /// Must be called before tensor/kernel operations are used.
    async fn initialize(&mut self, device: &Device) -> Result<()>;

    /// Check if backend supports specific device
    fn supports_device(&self, device: &Device) -> bool;

    /// Get backend version
    fn version(&self) -> String;

    /// Synchronize all pending operations
    ///
    /// Blocks (asynchronously) until work queued on `device` has completed.
    async fn synchronize(&self, device: &Device) -> Result<()>;

    /// Get backend status
    fn status(&self) -> BackendStatus;

    /// Shutdown backend gracefully
    async fn shutdown(&mut self) -> Result<()>;
}

/// Weight loading interface for model parameter management
///
/// Resolves [`TensorSpec`]s against a [`WeightSource`] and materializes
/// tensors on the target device.
#[async_trait]
pub trait WeightLoader: Send + Sync {
    /// Load tensor from weight specification
    async fn load_tensor(&self, spec: &TensorSpec) -> Result<TensorRef>;

    /// Load multiple tensors at once (batch loading)
    ///
    /// Implementations may parallelize or coalesce reads; results are
    /// returned in the same order as `specs`.
    async fn load_tensors(&self, specs: &[TensorSpec]) -> Result<Vec<TensorRef>>;

    /// Check if weight source is available
    ///
    /// Infallible by design: any error probing the source reads as `false`.
    async fn is_available(&self, source: &WeightSource) -> bool;

    /// Get metadata about weight source
    async fn get_metadata(&self, source: &WeightSource) -> Result<WeightMetadata>;

    /// Preload weights into cache/memory
    async fn preload(&self, source: &WeightSource) -> Result<()>;

    /// Get loader capabilities
    fn capabilities(&self) -> WeightLoaderCapabilities;
}

/// Tensor specification for weight loading
///
/// Fully describes one tensor to load: where it comes from, what it should
/// look like, and which transformations to apply after loading.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorSpec {
    /// Name/identifier of the tensor
    pub name: String,
    /// Expected tensor shape
    // NOTE(review): presumably loaders validate the loaded tensor against
    // this shape — confirm against implementations.
    pub shape: Vec<usize>,
    /// Target data type
    pub dtype: DataType,
    /// Target device
    pub device: Device,
    /// Weight source location
    pub source: WeightSource,
    /// Optional transformations to apply
    ///
    /// Applied in order after loading (see [`TensorTransformation`]).
    pub transformations: Vec<TensorTransformation>,
}

/// Weight source specification
///
/// Where weight bytes come from; each variant carries the addressing
/// details for one storage kind (see [`WeightSourceType`] for the tags).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WeightSource {
    /// Local file path
    File {
        path: String,
        /// Tensor name within file (for formats like safetensors)
        tensor_name: Option<String>,
    },
    /// URL for download
    Url {
        url: String,
        /// Extra HTTP headers to send with the request
        headers: HashMap<String, String>,
    },
    /// Hugging Face Hub
    HuggingFace {
        repo_id: String,
        filename: String,
        /// Git revision/branch/tag; `None` for the default branch
        revision: Option<String>,
        cache_dir: Option<String>,
    },
    /// Raw bytes in memory
    ///
    /// NOTE(review): serializing this variant embeds the full weight
    /// payload in the serialized form — likely intended for small tensors.
    Memory { data: Vec<u8>, format: WeightFormat },
    /// S3-compatible storage
    S3 {
        bucket: String,
        key: String,
        region: Option<String>,
        /// Custom endpoint for non-AWS S3-compatible services
        endpoint: Option<String>,
    },
}

/// Weight file formats
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum WeightFormat {
    /// PyTorch tensor format
    PyTorch,
    /// Safetensors format
    SafeTensors,
    /// NumPy array format
    Numpy,
    /// Raw binary data
    Raw,
    /// ONNX format
    Onnx,
    /// Custom format
    ///
    /// The `u32` is an opaque format tag agreed upon between a backend and
    /// its loader.
    Custom(u32),
}

/// Weight metadata information
///
/// Summary of a weight source as returned by [`WeightLoader::get_metadata`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightMetadata {
    /// Available tensor names and their shapes
    pub tensors: HashMap<String, Vec<usize>>,
    /// File format
    pub format: WeightFormat,
    /// Total size in bytes
    pub total_size_bytes: u64,
    /// Data types used
    pub dtypes: Vec<DataType>,
    /// Additional metadata
    ///
    /// Free-form key/value pairs (e.g. format-specific header fields).
    pub extra: HashMap<String, serde_json::Value>,
}

/// Transformations that can be applied to loaded tensors
///
/// Applied in the order listed in [`TensorSpec::transformations`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TensorTransformation {
    /// Transpose dimensions
    Transpose { dim0: usize, dim1: usize },
    /// Reshape to new shape
    Reshape { shape: Vec<usize> },
    /// Convert data type
    Cast { dtype: DataType },
    /// Quantize tensor
    Quantize { config: QuantizationConfig },
    /// Apply scaling
    Scale { factor: f32 },
    /// Slice tensor
    ///
    /// `None` bounds mean "from the beginning" / "to the end" of `dim`.
    Slice {
        dim: usize,
        start: Option<usize>,
        end: Option<usize>,
    },
}

/// Quantization configuration for weights
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuantizationConfig {
    /// INT8 uniform quantization
    INT8 { symmetric: bool },
    /// INT4 grouped quantization
    INT4 { group_size: usize },
    /// FP8 quantization
    ///
    /// `e4m3` selects the E4M3 encoding; `false` presumably means E5M2 —
    /// TODO confirm against the quantizer implementation.
    FP8 { e4m3: bool },
    /// GPTQ quantization
    GPTQ {
        bits: u8,
        group_size: usize,
        /// Activation-order (desc_act) mode
        desc_act: bool,
    },
    /// AWQ quantization
    AWQ { bits: u8, zero_point: bool },
}

/// Backend capabilities description
///
/// Advertised by [`ComputeBackend::capabilities`] and matched against
/// [`BackendRequirements`] via `meets_requirements`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendCapabilities {
    /// Supported data types
    pub supported_dtypes: Vec<DataType>,
    /// Supported devices
    pub supported_devices: Vec<Device>,
    /// Maximum tensor dimensions supported
    pub max_tensor_dims: usize,
    /// Whether backend supports FP16 operations
    pub supports_fp16: bool,
    /// Whether backend supports BF16 operations
    pub supports_bf16: bool,
    /// Whether backend supports INT8 quantization
    pub supports_int8: bool,
    /// Whether backend supports flash attention
    pub supports_flash_attention: bool,
    /// Whether backend supports paged attention
    pub supports_paged_attention: bool,
    /// Whether backend supports tensor parallelism
    pub supports_tensor_parallelism: bool,
    /// Whether backend supports pipeline parallelism
    pub supports_pipeline_parallelism: bool,
    /// Maximum batch size supported
    pub max_batch_size: usize,
    /// Maximum sequence length supported
    pub max_sequence_length: usize,
    /// Memory alignment requirements
    pub memory_alignment: usize,
    /// Whether backend supports custom kernels
    pub supports_custom_kernels: bool,
    /// Whether backend supports CUDA graphs
    pub supports_cuda_graphs: bool,
    /// Additional capabilities
    ///
    /// Free-form backend-specific flags not covered by the fields above.
    pub extra_capabilities: HashMap<String, serde_json::Value>,
}

impl BackendCapabilities {
    /// Check if capabilities meet requirements.
    ///
    /// Returns `true` only when every required device and dtype is
    /// supported, the requested minimum batch size and sequence length fit
    /// within this backend's maxima, and any attention features the caller
    /// requires are available.
    pub fn meets_requirements(&self, requirements: &BackendRequirements) -> bool {
        // Check devices
        if !requirements
            .required_devices
            .iter()
            .all(|dev| self.supported_devices.contains(dev))
        {
            return false;
        }

        // Check dtypes
        if !requirements
            .required_dtypes
            .iter()
            .all(|dtype| self.supported_dtypes.contains(dtype))
        {
            return false;
        }

        // Check batch size
        if requirements.min_batch_size > self.max_batch_size {
            return false;
        }

        // Check sequence length
        if requirements.min_sequence_length > self.max_sequence_length {
            return false;
        }

        // Bug fix: these requirement flags were previously ignored, so a
        // backend lacking flash/paged attention could be selected even when
        // the caller explicitly required those features.
        if requirements.requires_flash_attention && !self.supports_flash_attention {
            return false;
        }
        if requirements.requires_paged_attention && !self.supports_paged_attention {
            return false;
        }

        // NOTE(review): `extra_requirements` is not matched against
        // `extra_capabilities` here — callers relying on custom keys must
        // check them separately.
        true
    }
}

/// Requirements for backend selection
///
/// Consumed by [`BackendCapabilities::meets_requirements`] and
/// `BackendRegistry::find_best_compute_backend`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendRequirements {
    /// Required devices
    pub required_devices: Vec<Device>,
    /// Required data types
    pub required_dtypes: Vec<DataType>,
    /// Minimum batch size needed
    pub min_batch_size: usize,
    /// Minimum sequence length needed
    pub min_sequence_length: usize,
    /// Whether flash attention is required
    pub requires_flash_attention: bool,
    /// Whether paged attention is required
    pub requires_paged_attention: bool,
    /// Additional requirements
    ///
    /// Free-form keys; not evaluated by `meets_requirements`.
    pub extra_requirements: HashMap<String, serde_json::Value>,
}

/// Weight loader capabilities
///
/// Advertised by [`WeightLoader::capabilities`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightLoaderCapabilities {
    /// Supported weight formats
    pub supported_formats: Vec<WeightFormat>,
    /// Supported weight sources
    pub supported_sources: Vec<WeightSourceType>,
    /// Maximum single tensor size in bytes
    pub max_tensor_size: u64,
    /// Whether loader supports streaming/chunked loading
    pub supports_streaming: bool,
    /// Whether loader supports concurrent loading
    pub supports_concurrent: bool,
    /// Supported transformations
    pub supported_transformations: Vec<TransformationType>,
}

/// Weight source types
///
/// Discriminant-only mirror of [`WeightSource`], used to advertise which
/// source kinds a loader supports without carrying payload data.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum WeightSourceType {
    File,
    Url,
    HuggingFace,
    Memory,
    S3,
}

/// Transformation types
///
/// Discriminant-only mirror of [`TensorTransformation`], used in
/// [`WeightLoaderCapabilities::supported_transformations`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum TransformationType {
    Transpose,
    Reshape,
    Cast,
    Quantize,
    Scale,
    Slice,
}

/// Backend status information
///
/// Point-in-time snapshot returned by [`ComputeBackend::status`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendStatus {
    /// Whether backend is initialized
    pub is_initialized: bool,
    /// Whether backend is ready for operations
    pub is_ready: bool,
    /// Currently active devices
    pub active_devices: Vec<Device>,
    /// Memory usage per device
    // NOTE(review): non-string map keys cannot be represented in JSON —
    // whether serde_json serialization of this field works depends on
    // `Device`'s Serialize impl; verify.
    pub memory_usage: HashMap<Device, u64>,
    /// Number of operations completed
    pub operations_completed: u64,
    /// Last error (if any)
    pub last_error: Option<String>,
    /// Backend-specific status information
    pub backend_specific: HashMap<String, serde_json::Value>,
}

/// Kernel executor for custom GPU kernels
///
/// Lifecycle: `load_kernel` → `execute_kernel` (any number of times) →
/// `unload_kernel`. Handles are opaque [`KernelHandle`] tokens.
#[async_trait]
pub trait KernelExecutor: Send + Sync {
    /// Load kernel from source code
    ///
    /// Compiles `source` for `device` and returns a handle for later
    /// execution.
    async fn load_kernel(&self, source: &str, name: &str, device: &Device) -> Result<KernelHandle>;

    /// Execute kernel with arguments
    ///
    /// `grid_size`/`block_size` follow the (x, y, z) launch-dimension
    /// convention.
    async fn execute_kernel(
        &self,
        handle: KernelHandle,
        grid_size: (u32, u32, u32),
        block_size: (u32, u32, u32),
        args: &[KernelArg],
    ) -> Result<()>;

    /// Get kernel information
    ///
    /// `None` if `handle` is unknown or already unloaded.
    fn get_kernel_info(&self, handle: KernelHandle) -> Option<KernelInfo>;

    /// Unload kernel
    async fn unload_kernel(&self, handle: KernelHandle) -> Result<()>;
}

/// Handle for loaded kernel
///
/// Opaque id returned by [`KernelExecutor::load_kernel`]; cheap to copy
/// and hash, and usable as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct KernelHandle(pub u64);

/// Kernel argument types
///
/// Note: the raw pointer in `Buffer` makes this enum neither `Send` nor
/// `Sync` (raw pointers opt out of both auto traits), and the caller is
/// responsible for keeping the pointed-to memory alive and valid for the
/// duration of kernel execution.
#[derive(Debug, Clone)]
pub enum KernelArg {
    /// Tensor reference
    Tensor(TensorRef),
    /// Raw memory buffer
    Buffer { ptr: *const u8, size: usize },
    /// Scalar value
    Scalar(ScalarValue),
    /// Local/shared memory allocation
    ///
    /// Size in bytes of the per-block local/shared memory to allocate.
    LocalMemory(usize),
}

/// Scalar values for kernel arguments
///
/// Tagged union covering the primitive types a kernel parameter may take.
#[derive(Debug, Clone)]
pub enum ScalarValue {
    I8(i8),
    I16(i16),
    I32(i32),
    I64(i64),
    U8(u8),
    U16(u16),
    U32(u32),
    U64(u64),
    F32(f32),
    F64(f64),
    Bool(bool),
}

/// Kernel information and metadata
///
/// Returned by [`KernelExecutor::get_kernel_info`]; useful for choosing
/// launch dimensions.
#[derive(Debug, Clone)]
pub struct KernelInfo {
    /// Kernel name
    pub name: String,
    /// Maximum threads per block
    pub max_threads_per_block: u32,
    /// Shared memory size required
    pub shared_memory_size: usize,
    /// Register count per thread
    pub registers_per_thread: u32,
    /// Preferred block size
    ///
    /// (x, y, z) launch-dimension convention, matching `execute_kernel`.
    pub preferred_block_size: (u32, u32, u32),
}

/// Backend factory for creating backend instances
///
/// Abstract factory that builds both halves of the backend split:
/// [`ComputeBackend`] and [`WeightLoader`].
#[async_trait]
pub trait BackendFactory: Send + Sync {
    /// Create compute backend
    async fn create_compute_backend(
        &self,
        config: &BackendConfig,
    ) -> Result<Box<dyn ComputeBackend>>;

    /// Create weight loader
    async fn create_weight_loader(
        &self,
        config: &WeightLoaderConfig,
    ) -> Result<Box<dyn WeightLoader>>;

    /// Get supported backend types
    fn supported_backend_types(&self) -> Vec<BackendType>;

    /// Validate backend configuration
    ///
    /// Cheap, synchronous check intended to run before the async `create_*`
    /// calls.
    fn validate_config(&self, config: &BackendConfig) -> Result<()>;
}

/// Backend configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendConfig {
    /// Backend type
    pub backend_type: BackendType,
    /// Target device
    pub device: Device,
    /// Optimization level (0-3)
    // NOTE(review): range 0-3 is documented but not enforced here —
    // presumably validated by `BackendFactory::validate_config`.
    pub optimization_level: u8,
    /// Enable debugging
    pub enable_debug: bool,
    /// Memory configuration
    pub memory_config: BackendMemoryConfig,
    /// Backend-specific options
    pub backend_options: HashMap<String, serde_json::Value>,
}

/// Weight loader configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightLoaderConfig {
    /// Enable caching
    pub enable_caching: bool,
    /// Cache directory
    ///
    /// `None` presumably falls back to an implementation default — confirm.
    pub cache_dir: Option<String>,
    /// Maximum cache size in bytes
    pub max_cache_size: Option<u64>,
    /// Number of concurrent downloads
    pub max_concurrent_downloads: usize,
    /// Connection timeout for downloads
    pub download_timeout_seconds: u64,
    /// Enable integrity checks
    pub enable_integrity_checks: bool,
    /// Custom headers for HTTP requests
    ///
    /// Applied to URL-based sources in addition to any per-source headers.
    pub default_headers: HashMap<String, String>,
}

/// Memory configuration for backends
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendMemoryConfig {
    /// Memory pool size in bytes (None for auto)
    pub pool_size: Option<u64>,
    /// Memory alignment in bytes
    pub alignment: usize,
    /// Enable memory pooling
    pub enable_pooling: bool,
    /// Memory growth strategy
    pub growth_strategy: MemoryGrowthStrategy,
}

/// Memory growth strategies
///
/// Controls how a backend's memory pool expands over time.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum MemoryGrowthStrategy {
    /// Pre-allocate all memory upfront
    Static,
    /// Grow memory as needed
    Dynamic,
    /// Pre-allocate with incremental growth
    Incremental,
}

/// Backend types
///
/// The `Display` impl below renders each variant as a lowercase identifier
/// (e.g. `"onnx_runtime"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BackendType {
    /// Candle framework
    Candle,
    /// ONNX Runtime
    OnnxRuntime,
    /// TensorRT
    TensorRT,
    /// Custom Metal implementation
    Metal,
    /// Custom CPU implementation
    CPU,
    /// Custom backend
    Custom,
}

impl std::fmt::Display for BackendType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            BackendType::Candle => "candle",
            BackendType::OnnxRuntime => "onnx_runtime",
            BackendType::TensorRT => "tensorrt",
            BackendType::Metal => "metal",
            BackendType::CPU => "cpu",
            BackendType::Custom => "custom",
        };
        write!(f, "{}", name)
    }
}

/// Backend registry for managing multiple backends
///
/// Holds named [`ComputeBackend`] and [`WeightLoader`] instances and
/// supports requirement-based selection.
pub trait BackendRegistry: Send + Sync {
    /// Register compute backend
    fn register_compute_backend(
        &mut self,
        name: &str,
        backend: Box<dyn ComputeBackend>,
    ) -> Result<()>;

    /// Register weight loader
    fn register_weight_loader(&mut self, name: &str, loader: Box<dyn WeightLoader>) -> Result<()>;

    /// Get compute backend by name
    fn get_compute_backend(&self, name: &str) -> Option<&dyn ComputeBackend>;

    /// Get weight loader by name
    fn get_weight_loader(&self, name: &str) -> Option<&dyn WeightLoader>;

    /// Find best compute backend for requirements
    ///
    /// Returns a backend whose capabilities satisfy `requirements`, or
    /// `None` if no registered backend qualifies. Tie-breaking between
    /// multiple qualifying backends is implementation-defined.
    fn find_best_compute_backend(
        &self,
        requirements: &BackendRequirements,
    ) -> Option<&dyn ComputeBackend>;

    /// List all registered backend names
    fn list_backend_names(&self) -> (Vec<String>, Vec<String>); // (compute, weight)
}