Skip to main content

ferrum_interfaces/
model_builder.rs

//! Model builder interfaces for constructing model executors.
//!
//! This module defines the interfaces used to build model executors from
//! configurations and weight sources, keeping model construction separate
//! from the compute-backend implementation.

7use crate::{ComputeBackend, ModelExecutor, WeightLoader};
8use async_trait::async_trait;
9use ferrum_types::{ModelConfig, ModelInfo, ModelSource, Result};
10use serde::{Deserialize, Serialize};
11use std::{collections::HashMap, sync::Arc};
12
/// Model builder for constructing model executors.
///
/// A builder turns a model description (a [`ModelConfig`] or a
/// [`ModelSource`]) into a boxed [`ModelExecutor`], using a caller-supplied
/// [`ComputeBackend`] and [`WeightLoader`]. This keeps model assembly
/// independent of any particular backend implementation.
///
/// The `Send + Sync` supertraits allow a builder to be shared across threads,
/// e.g. behind an `Arc<dyn ModelBuilder>`.
#[async_trait]
pub trait ModelBuilder: Send + Sync {
    /// Build a model executor from a model configuration.
    ///
    /// The backend and weight loader are passed as shared `Arc` handles, so
    /// the builder (and the executor it returns) may keep them alive.
    async fn build_model(
        &self,
        config: &ModelConfig,
        compute_backend: Arc<dyn ComputeBackend>,
        weight_loader: Arc<dyn WeightLoader>,
    ) -> Result<Box<dyn ModelExecutor>>;

    /// Build a model executor from a model source.
    ///
    /// `build_options` selects the optional build steps (validation,
    /// optimization, quantization, compression) and the build timeout; see
    /// [`BuildOptions`].
    async fn build_from_source(
        &self,
        source: &ModelSource,
        compute_backend: Arc<dyn ComputeBackend>,
        weight_loader: Arc<dyn WeightLoader>,
        build_options: &BuildOptions,
    ) -> Result<Box<dyn ModelExecutor>>;

    /// Validate a model configuration.
    ///
    /// On success, returns every [`ValidationIssue`] found; an empty vector
    /// means no issues were reported. Use [`ValidationSeverity`] to decide
    /// whether the configuration can still be built.
    /// NOTE(review): whether `Err` is reserved for failures of the validation
    /// process itself (as opposed to an invalid config) is left to
    /// implementors — confirm.
    fn validate_config(&self, config: &ModelConfig) -> Result<Vec<ValidationIssue>>;

    /// Model types this builder can construct.
    fn supported_model_types(&self) -> Vec<ferrum_types::ModelType>;

    /// Estimate the build time for `config`, with a per-phase breakdown.
    async fn estimate_build_time(&self, config: &ModelConfig) -> Result<BuildTimeEstimate>;

    /// Describe this builder: name, version, supported architectures, weight
    /// formats and optimizations, and its capabilities.
    fn builder_info(&self) -> BuilderInfo;
}

46/// Build options for model construction
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct BuildOptions {
49    /// Enable model validation after build
50    pub enable_validation: bool,
51    /// Enable model optimization
52    pub enable_optimization: bool,
53    /// Optimization level (0-3)
54    pub optimization_level: u8,
55    /// Enable model quantization
56    pub enable_quantization: bool,
57    /// Quantization configuration
58    pub quantization_config: Option<ferrum_types::QuantizationConfig>,
59    /// Enable model compression
60    pub enable_compression: bool,
61    /// Build timeout in seconds
62    pub build_timeout_seconds: Option<u64>,
63    /// Additional build options
64    pub additional_options: HashMap<String, serde_json::Value>,
65}
66
67impl Default for BuildOptions {
68    fn default() -> Self {
69        Self {
70            enable_validation: true,
71            enable_optimization: true,
72            optimization_level: 2,
73            enable_quantization: false,
74            quantization_config: None,
75            enable_compression: false,
76            build_timeout_seconds: Some(3600), // 1 hour
77            additional_options: HashMap::new(),
78        }
79    }
80}
81
/// A single problem reported while validating a model configuration
/// (see [`ModelBuilder::validate_config`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationIssue {
    /// How serious the issue is; see [`ValidationSeverity`].
    pub severity: ValidationSeverity,
    /// Free-form category label used to group related issues.
    pub category: String,
    /// Human-readable description of the problem.
    pub description: String,
    /// Suggested fix, if the validator can offer one.
    pub suggested_fix: Option<String>,
    /// Location in the configuration where the issue was found.
    ///
    /// NOTE(review): the path syntax (e.g. dotted keys) is not defined in
    /// this module — confirm with validators.
    pub config_path: String,
}

/// Severity of a [`ValidationIssue`].
///
/// Variants are declared from least to most severe and `PartialOrd`/`Ord`
/// are derived, so the ordering is `Warning < Error < Critical`. A check such
/// as `severity >= ValidationSeverity::Error` therefore selects exactly the
/// issues documented as preventing a build. Do not reorder the variants: that
/// would silently change the ordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum ValidationSeverity {
    /// Warning that doesn't prevent the build.
    Warning,
    /// Error that prevents the build.
    Error,
    /// Critical error indicating a serious misconfiguration (also prevents
    /// the build, since it orders above `Error`).
    Critical,
}

/// Estimated build time, as returned by [`ModelBuilder::estimate_build_time`].
///
/// NOTE(review): callers presumably expect
/// `min_time_seconds <= expected_time_seconds <= max_time_seconds`; this type
/// does not enforce it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuildTimeEstimate {
    /// Lower bound of the estimate, in seconds.
    pub min_time_seconds: u64,
    /// Upper bound of the estimate, in seconds.
    pub max_time_seconds: u64,
    /// Most likely build time, in seconds.
    pub expected_time_seconds: u64,
    /// Breakdown of the estimate by build phase.
    pub time_breakdown: BuildTimeBreakdown,
    /// Factors that influenced the estimate.
    pub factors: Vec<BuildTimeFactor>,
}

/// Build time broken down by phase, all values in seconds.
///
/// NOTE(review): whether the phases are meant to sum to
/// `BuildTimeEstimate::expected_time_seconds` is not specified — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuildTimeBreakdown {
    /// Time spent loading weights.
    pub weight_loading_seconds: u64,
    /// Time spent initializing the model.
    pub model_init_seconds: u64,
    /// Time spent on optimization.
    pub optimization_seconds: u64,
    /// Time spent on validation.
    pub validation_seconds: u64,
    /// Remaining overhead not attributed to the phases above.
    pub overhead_seconds: u64,
}

/// A single factor that affects the build-time estimate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuildTimeFactor {
    /// Short name of the factor.
    pub factor: String,
    /// Impact on build time, expressed as a multiplier (e.g. `1.5` would mean
    /// the factor makes the build 1.5x as long).
    pub impact: f32,
    /// Human-readable explanation of the factor.
    pub description: String,
}

/// Description of a model builder and what it supports, as returned by
/// [`ModelBuilder::builder_info`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuilderInfo {
    /// Builder name.
    pub name: String,
    /// Builder version string.
    pub version: String,
    /// Model architectures the builder can construct.
    pub supported_architectures: Vec<ModelArchitecture>,
    /// Weight formats the builder can consume
    /// (see [`crate::backend::WeightFormat`]).
    pub supported_weight_formats: Vec<crate::backend::WeightFormat>,
    /// Optimization techniques the builder can apply.
    pub supported_optimizations: Vec<OptimizationTechnique>,
    /// Additional capability flags and limits.
    pub capabilities: BuilderCapabilities,
}

/// One model architecture supported by a builder
/// (an entry of [`BuilderInfo::supported_architectures`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelArchitecture {
    /// Architecture name.
    pub name: String,
    /// Broad family the architecture belongs to.
    pub family: ModelArchitectureFamily,
    /// Names of the supported variants of this architecture.
    pub variants: Vec<String>,
    /// Names of features this architecture requires.
    ///
    /// NOTE(review): what these strings refer to (backend features, crate
    /// features, …) is not defined in this module — confirm.
    pub required_features: Vec<String>,
}

/// Broad families of model architectures.
///
/// The variant names (including the upper-case `CNN`, `RNN`, `GNN`) are part
/// of the public API and of the serialized form, since serde uses variant
/// names by default. Renaming them would break existing callers and stored
/// data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelArchitectureFamily {
    /// Transformer-based models.
    Transformer,
    /// Convolutional neural networks.
    CNN,
    /// Recurrent neural networks.
    RNN,
    /// Graph neural networks.
    GNN,
    /// Diffusion models.
    Diffusion,
    /// Any architecture not covered above.
    Custom,
}

196/// Optimization techniques
197#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
198pub enum OptimizationTechnique {
199    /// Operator fusion
200    OperatorFusion,
201    /// Constant folding
202    ConstantFolding,
203    /// Dead code elimination
204    DeadCodeElimination,
205    /// Memory layout optimization
206    MemoryLayoutOptimization,
207    /// Kernel selection optimization
208    KernelSelection,
209    /// Quantization
210    Quantization,
211    /// Pruning
212    Pruning,
213    /// Knowledge distillation
214    Distillation,
215}
216
/// Capability flags and limits of a model builder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuilderCapabilities {
    /// Maximum supported model size, measured in parameters.
    ///
    /// NOTE(review): `None` presumably means no stated limit — confirm.
    pub max_model_size: Option<u64>,
    /// Whether dynamic shapes are supported.
    pub supports_dynamic_shapes: bool,
    /// Whether custom operations are supported.
    pub supports_custom_ops: bool,
    /// Whether mixed precision is supported.
    pub supports_mixed_precision: bool,
    /// Whether model parallelism is supported.
    pub supports_model_parallelism: bool,
    /// Whether the builder can build in parallel.
    pub supports_parallel_build: bool,
    /// Whether the builder supports incremental builds
    /// (see [`AdvancedModelBuilder::build_incremental`]).
    pub supports_incremental_build: bool,
}

/// Model builder with additional, more specialised build modes.
///
/// Every `AdvancedModelBuilder` is also a [`ModelBuilder`].
#[async_trait]
pub trait AdvancedModelBuilder: ModelBuilder {
    /// Build a model that includes the given custom layers.
    ///
    /// The layers are passed by value (boxed trait objects), transferring
    /// ownership to the builder.
    async fn build_with_custom_layers(
        &self,
        config: &ModelConfig,
        custom_layers: Vec<Box<dyn CustomLayer>>,
        compute_backend: Arc<dyn ComputeBackend>,
        weight_loader: Arc<dyn WeightLoader>,
    ) -> Result<Box<dyn ModelExecutor>>;

    /// Build a model incrementally, intended for large models.
    ///
    /// `progress_callback` receives [`BuildProgress`] updates. It must be
    /// `Send + Sync`, so implementations may call it from other threads.
    /// NOTE(review): how often it is invoked is implementation-defined.
    async fn build_incremental(
        &self,
        config: &ModelConfig,
        compute_backend: Arc<dyn ComputeBackend>,
        weight_loader: Arc<dyn WeightLoader>,
        progress_callback: Box<dyn Fn(BuildProgress) + Send + Sync>,
    ) -> Result<Box<dyn ModelExecutor>>;

    /// Build a model using a caller-supplied optimization pipeline.
    ///
    /// NOTE(review): whether the builder reorders passes according to
    /// [`OptimizationPass::dependencies`] or runs them in the given order is
    /// not specified here — confirm.
    async fn build_with_optimization(
        &self,
        config: &ModelConfig,
        optimization_pipeline: Vec<Box<dyn OptimizationPass>>,
        compute_backend: Arc<dyn ComputeBackend>,
        weight_loader: Arc<dyn WeightLoader>,
    ) -> Result<Box<dyn ModelExecutor>>;

    /// Export the model's intermediate representation ([`ModelIR`]) for
    /// `config`, e.g. for debugging.
    async fn export_model_definition(&self, config: &ModelConfig) -> Result<ModelIR>;

    /// Build a model executor from a previously exported or hand-written
    /// [`ModelIR`], for custom builds.
    async fn import_model_definition(
        &self,
        definition: &ModelIR,
        compute_backend: Arc<dyn ComputeBackend>,
        weight_loader: Arc<dyn WeightLoader>,
    ) -> Result<Box<dyn ModelExecutor>>;
}

/// A user-defined layer that an [`AdvancedModelBuilder`] can include in a
/// model (see [`AdvancedModelBuilder::build_with_custom_layers`]).
pub trait CustomLayer: Send + Sync {
    /// Layer name.
    fn name(&self) -> &str;

    /// Layer type identifier.
    fn layer_type(&self) -> &str;

    /// Expected input shape.
    fn input_shape(&self) -> Vec<i64>; // -1 marks a dynamic dimension

    /// Output shape produced for the given input shape.
    fn output_shape(&self, input_shape: &[i64]) -> Vec<i64>;

    /// Create the layer's parameter tensors.
    ///
    /// NOTE(review): map keys are presumably parameter names — confirm.
    fn initialize_parameters(&self) -> Result<HashMap<String, crate::TensorRef>>;

    /// Layer configuration as a JSON value.
    fn config(&self) -> serde_json::Value;
}

/// Progress update passed to the callback of
/// [`AdvancedModelBuilder::build_incremental`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuildProgress {
    /// Build phase currently running.
    pub phase: BuildPhase,
    /// Fraction of the build completed, in `0.0..=1.0`.
    ///
    /// This is a fraction, not a percentage: 50% complete is `0.5`.
    pub progress: f32,
    /// Description of the operation currently in progress.
    pub current_operation: String,
    /// Time elapsed since the build started, in seconds.
    pub elapsed_seconds: u64,
    /// Estimated remaining time in seconds, if an estimate is available.
    pub remaining_seconds: Option<u64>,
    /// Free-form, phase-specific details as JSON values.
    pub phase_details: HashMap<String, serde_json::Value>,
}

/// Phases of a model build, as reported in [`BuildProgress::phase`].
///
/// The variants are listed in what is presumably the typical build order.
/// No `Ord` is derived, so phases cannot be compared by position.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BuildPhase {
    /// Configuration validation.
    Validation,
    /// Weight loading.
    WeightLoading,
    /// Model initialization.
    ModelInitialization,
    /// Layer construction.
    LayerConstruction,
    /// Binding loaded parameters to layers.
    ParameterBinding,
    /// Model optimization.
    Optimization,
    /// Validation of the built model.
    FinalValidation,
    /// Cleanup and finalization.
    Finalization,
}

/// A single pass in a custom optimization pipeline
/// (see [`AdvancedModelBuilder::build_with_optimization`]).
///
/// Passes operate on a [`ModelIR`].
pub trait OptimizationPass: Send + Sync {
    /// Name of the pass.
    fn name(&self) -> &str;

    /// Apply the pass to `definition`, modifying it in place, and report
    /// what was done.
    fn apply(&self, definition: &mut ModelIR) -> Result<OptimizationResult>;

    /// Whether the pass can be applied to `definition`.
    fn is_applicable(&self, definition: &ModelIR) -> bool;

    /// Passes that must run before this one.
    ///
    /// NOTE(review): entries are presumably matched against other passes'
    /// [`OptimizationPass::name`] — confirm.
    fn dependencies(&self) -> Vec<String>;
}

/// Outcome of running an [`OptimizationPass`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationResult {
    /// Whether the pass actually changed the model.
    pub applied: bool,
    /// Statistics describing the effect of the pass.
    pub stats: OptimizationStats,
    /// Warnings or issues raised while running the pass.
    pub warnings: Vec<String>,
}

/// Statistics reported by an optimization pass.
///
/// NOTE(review): the counters use different integer types
/// (`parameters_eliminated: u64`, `operations_eliminated: usize`). Changing
/// either would be a breaking API change, so they are left as-is.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationStats {
    /// Number of parameters eliminated.
    pub parameters_eliminated: u64,
    /// Number of operations eliminated.
    pub operations_eliminated: usize,
    /// Memory saved, in bytes.
    pub memory_saved: u64,
    /// Estimated speedup.
    ///
    /// NOTE(review): presumably a multiplier where `1.0` means no change —
    /// confirm.
    pub estimated_speedup: f32,
}

/// Model IR (intermediate representation) used for export and import.
///
/// This is distinct from `ferrum_models::ModelDefinition`, which is used to
/// parse HuggingFace `config.json` files. `ModelIR` is a complete model
/// definition, including the computational graph, used for building and
/// exporting models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelIR {
    /// Descriptive metadata about the model.
    pub metadata: ModelMetadata,
    /// Architecture name, dimensions and architecture-specific settings.
    pub architecture: ArchitectureDefinition,
    /// Specifications of all model parameters.
    pub parameters: Vec<ParameterSpec>,
    /// Definitions of all layers.
    pub layers: Vec<LayerDefinition>,
    /// How the layers are connected.
    pub graph: GraphDefinition,
}

/// Descriptive metadata stored in a [`ModelIR`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Model name.
    pub name: String,
    /// Model version string.
    pub version: String,
    /// Model type.
    pub model_type: ferrum_types::ModelType,
    /// Architecture family.
    pub architecture_family: ModelArchitectureFamily,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Optional author information.
    pub author: Option<String>,
    /// Optional license information.
    pub license: Option<String>,
    /// Any other metadata, as free-form JSON values.
    pub additional: HashMap<String, serde_json::Value>,
}

/// Architecture section of a [`ModelIR`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureDefinition {
    /// Architecture name.
    pub name: String,
    /// Core model dimensions and hyperparameters.
    pub dimensions: ModelDimensions,
    /// Architecture-specific settings, as free-form JSON values.
    pub config: HashMap<String, serde_json::Value>,
}

/// Model dimensions and hyperparameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDimensions {
    /// Vocabulary size.
    pub vocab_size: usize,
    /// Hidden (embedding) dimension.
    pub hidden_size: usize,
    /// Number of layers.
    pub num_layers: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Number of key/value heads, for grouped-query or multi-query attention.
    ///
    /// NOTE(review): `None` presumably means the same as `num_heads`
    /// (standard multi-head attention) — confirm.
    pub num_kv_heads: Option<usize>,
    /// Intermediate (feed-forward) dimension, if specified.
    pub intermediate_size: Option<usize>,
    /// Maximum sequence length.
    pub max_sequence_length: usize,
}

/// Specification of one model parameter in a [`ModelIR`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterSpec {
    /// Parameter name.
    pub name: String,
    /// Parameter shape.
    ///
    /// NOTE(review): by analogy with [`TensorSpec::shape`], `-1` may mark a
    /// dynamic dimension — confirm.
    pub shape: Vec<i64>,
    /// Element data type.
    pub dtype: ferrum_types::DataType,
    /// Whether the parameter is trainable.
    pub trainable: bool,
    /// How the parameter is initialized.
    pub initialization: InitializationStrategy,
    /// Any other parameter metadata, as free-form JSON values.
    pub metadata: HashMap<String, serde_json::Value>,
}

/// Parameter initialization strategies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InitializationStrategy {
    /// Fill with zeros.
    Zeros,
    /// Fill with ones.
    Ones,
    /// Random values from a uniform distribution bounded by `min` and `max`.
    Uniform { min: f32, max: f32 },
    /// Random values from a normal (Gaussian) distribution with the given
    /// mean and standard deviation.
    Normal { mean: f32, std: f32 },
    /// Xavier/Glorot initialization (uniform vs. normal variant and gain are
    /// not specified here).
    Xavier,
    /// Kaiming/He initialization (variant not specified here).
    Kaiming,
    /// Custom initialization identified by a string.
    ///
    /// NOTE(review): presumably a name interpreted by the builder — confirm.
    Custom(String),
}

/// Definition of one layer in a [`ModelIR`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerDefinition {
    /// Layer name.
    pub name: String,
    /// Layer type identifier.
    pub layer_type: String,
    /// Specifications of the layer's input tensors.
    pub inputs: Vec<TensorSpec>,
    /// Specifications of the layer's output tensors.
    pub outputs: Vec<TensorSpec>,
    /// Names of the parameters used by this layer.
    pub parameters: Vec<String>, // presumably match ParameterSpec::name in ModelIR::parameters
    /// Layer-specific settings, as free-form JSON values.
    pub config: HashMap<String, serde_json::Value>,
}

/// Name, shape and data type of a tensor (used for layer inputs and outputs).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorSpec {
    /// Tensor name.
    pub name: String,
    /// Tensor shape; `-1` marks a dynamic dimension.
    pub shape: Vec<i64>,
    /// Element data type.
    pub dtype: ferrum_types::DataType,
}

/// Computational graph of a [`ModelIR`]: layers as nodes, data flow as edges.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphDefinition {
    /// Graph input nodes.
    ///
    /// NOTE(review): presumably [`GraphNode::id`] values — confirm.
    pub inputs: Vec<String>,
    /// Graph output nodes (same identifier scheme as `inputs`).
    pub outputs: Vec<String>,
    /// Graph nodes; each represents a layer.
    pub nodes: Vec<GraphNode>,
    /// Directed connections between nodes.
    pub edges: Vec<GraphEdge>,
}

/// A graph node representing a layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Node identifier, referenced by [`GraphEdge::source`] and
    /// [`GraphEdge::target`].
    pub id: String,
    /// Name of the layer this node represents
    /// (presumably a [`LayerDefinition::name`]).
    pub layer_name: String,
    /// Free-form node metadata, as JSON values.
    pub metadata: HashMap<String, serde_json::Value>,
}

/// A directed graph edge connecting an output of one node to an input of
/// another.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphEdge {
    /// ID of the source node.
    pub source: String,
    /// ID of the target node.
    pub target: String,
    /// Index of the source node's output that feeds this edge.
    ///
    /// NOTE(review): `None` presumably means the node's default/only output —
    /// confirm.
    pub source_output: Option<usize>,
    /// Index of the target node's input that this edge feeds (`None` as for
    /// `source_output`).
    pub target_input: Option<usize>,
}

/// Factory that creates model builders.
#[async_trait]
pub trait ModelBuilderFactory: Send + Sync {
    /// Create a standard [`ModelBuilder`].
    async fn create_builder(&self) -> Result<Box<dyn ModelBuilder>>;

    /// Create an [`AdvancedModelBuilder`].
    async fn create_advanced_builder(&self) -> Result<Box<dyn AdvancedModelBuilder>>;

    /// Model types for which this factory can create builders.
    fn supported_types(&self) -> Vec<ferrum_types::ModelType>;

    /// Create a builder for a specific model type.
    ///
    /// NOTE(review): presumably returns an error for types not listed by
    /// [`ModelBuilderFactory::supported_types`] — confirm with implementors.
    async fn create_builder_for_type(
        &self,
        model_type: ferrum_types::ModelType,
    ) -> Result<Box<dyn ModelBuilder>>;
}

566/// Model registry for managing built models
567pub trait ModelRegistry: Send + Sync {
568    /// Register model executor
569    fn register_model(
570        &mut self,
571        model_id: &ferrum_types::ModelId,
572        executor: Box<dyn ModelExecutor>,
573    ) -> Result<()>;
574
575    /// Get model executor
576    fn get_model(&self, model_id: &ferrum_types::ModelId) -> Option<&dyn ModelExecutor>;
577
578    /// Remove model executor
579    fn remove_model(&mut self, model_id: &ferrum_types::ModelId) -> Option<Box<dyn ModelExecutor>>;
580
581    /// List registered models
582    fn list_models(&self) -> Vec<ferrum_types::ModelId>;
583
584    /// Get model information
585    fn get_model_info(&self, model_id: &ferrum_types::ModelId) -> Option<&ModelInfo>;
586
587    /// Check if model exists
588    fn contains_model(&self, model_id: &ferrum_types::ModelId) -> bool;
589}