1use crate::{ComputeBackend, ModelExecutor, WeightLoader};
8use async_trait::async_trait;
9use ferrum_types::{ModelConfig, ModelInfo, ModelSource, Result};
10use serde::{Deserialize, Serialize};
11use std::{collections::HashMap, sync::Arc};
12
/// Builds executable models ([`ModelExecutor`]s).
///
/// `#[async_trait]` boxes the futures returned by the async methods, which lets
/// the trait be used as a trait object (e.g. `Box<dyn ModelBuilder>`).
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait ModelBuilder: Send + Sync {
    /// Builds an executor for `config` with the given compute backend and
    /// weight loader.
    async fn build_model(
        &self,
        config: &ModelConfig,
        compute_backend: Arc<dyn ComputeBackend>,
        weight_loader: Arc<dyn WeightLoader>,
    ) -> Result<Box<dyn ModelExecutor>>;

    /// Builds an executor starting from a [`ModelSource`] (instead of a
    /// [`ModelConfig`]), with the build tuned by `build_options`.
    async fn build_from_source(
        &self,
        source: &ModelSource,
        compute_backend: Arc<dyn ComputeBackend>,
        weight_loader: Arc<dyn WeightLoader>,
        build_options: &BuildOptions,
    ) -> Result<Box<dyn ModelExecutor>>;

    /// Checks `config` and returns the problems found as [`ValidationIssue`]s.
    ///
    /// NOTE(review): presumably an empty `Ok` list means the config is valid,
    /// and `Err` means validation itself could not run. Confirm per
    /// implementation.
    fn validate_config(&self, config: &ModelConfig) -> Result<Vec<ValidationIssue>>;

    /// Returns the model types this builder reports as supported.
    fn supported_model_types(&self) -> Vec<ferrum_types::ModelType>;

    /// Estimates how long building `config` would take (see
    /// [`BuildTimeEstimate`]).
    async fn estimate_build_time(&self, config: &ModelConfig) -> Result<BuildTimeEstimate>;

    /// Describes this builder: name, version, supported architectures, weight
    /// formats and optimizations, and capabilities.
    fn builder_info(&self) -> BuilderInfo;
}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct BuildOptions {
49 pub enable_validation: bool,
51 pub enable_optimization: bool,
53 pub optimization_level: u8,
55 pub enable_quantization: bool,
57 pub quantization_config: Option<ferrum_types::QuantizationConfig>,
59 pub enable_compression: bool,
61 pub build_timeout_seconds: Option<u64>,
63 pub additional_options: HashMap<String, serde_json::Value>,
65}
66
67impl Default for BuildOptions {
68 fn default() -> Self {
69 Self {
70 enable_validation: true,
71 enable_optimization: true,
72 optimization_level: 2,
73 enable_quantization: false,
74 quantization_config: None,
75 enable_compression: false,
76 build_timeout_seconds: Some(3600), additional_options: HashMap::new(),
78 }
79 }
80}
81
/// A single problem reported by [`ModelBuilder::validate_config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationIssue {
    /// How serious the issue is.
    pub severity: ValidationSeverity,
    /// Free-form category label for grouping issues.
    pub category: String,
    /// Human-readable description of the problem.
    pub description: String,
    /// Optional suggested remediation.
    pub suggested_fix: Option<String>,
    /// Location in the config the issue refers to. The format is not defined
    /// here (presumably a dotted field path; confirm).
    pub config_path: String,
}
96
/// Severity of a [`ValidationIssue`].
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so
/// `Warning < Error < Critical`. Code can therefore take e.g. the `max()`
/// severity of a list of issues. Reordering the variants changes that
/// ordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum ValidationSeverity {
    /// Least severe level.
    Warning,
    /// More severe than `Warning`.
    Error,
    /// Most severe level.
    Critical,
}
107
/// Estimated build duration returned by [`ModelBuilder::estimate_build_time`].
///
/// NOTE(review): presumably `min <= expected <= max`; the type does not
/// enforce this.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuildTimeEstimate {
    /// Lower bound of the estimate, in seconds.
    pub min_time_seconds: u64,
    /// Upper bound of the estimate, in seconds.
    pub max_time_seconds: u64,
    /// Most likely duration, in seconds.
    pub expected_time_seconds: u64,
    /// Per-phase split of the estimate.
    pub time_breakdown: BuildTimeBreakdown,
    /// Factors that influenced the estimate.
    pub factors: Vec<BuildTimeFactor>,
}
122
/// Per-phase durations, in seconds, making up a [`BuildTimeEstimate`].
///
/// Whether these sum to the estimate's expected time is not enforced here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuildTimeBreakdown {
    /// Time spent loading weights.
    pub weight_loading_seconds: u64,
    /// Time spent initializing the model.
    pub model_init_seconds: u64,
    /// Time spent optimizing.
    pub optimization_seconds: u64,
    /// Time spent validating.
    pub validation_seconds: u64,
    /// Time not attributed to any phase above.
    pub overhead_seconds: u64,
}
137
/// One factor contributing to a [`BuildTimeEstimate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuildTimeFactor {
    /// Short name of the factor.
    pub factor: String,
    /// Magnitude of the factor's effect. The scale and sign convention are not
    /// defined here (multiplier vs. fraction). TODO confirm.
    pub impact: f32,
    /// Human-readable explanation.
    pub description: String,
}
148
/// Self-description of a builder, returned by [`ModelBuilder::builder_info`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuilderInfo {
    /// Builder name.
    pub name: String,
    /// Builder version string.
    pub version: String,
    /// Architectures the builder supports.
    pub supported_architectures: Vec<ModelArchitecture>,
    /// Weight formats the builder can consume.
    pub supported_weight_formats: Vec<crate::backend::WeightFormat>,
    /// Optimization techniques the builder supports.
    pub supported_optimizations: Vec<OptimizationTechnique>,
    /// Feature flags and limits of the builder.
    pub capabilities: BuilderCapabilities,
}
165
/// A model architecture supported by a builder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelArchitecture {
    /// Architecture name.
    pub name: String,
    /// Broad family the architecture belongs to.
    pub family: ModelArchitectureFamily,
    /// Named variants of the architecture.
    pub variants: Vec<String>,
    /// Features required to build this architecture (free-form identifiers).
    pub required_features: Vec<String>,
}
178
/// Broad family of a model architecture.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelArchitectureFamily {
    /// Transformer-based models.
    Transformer,
    /// Convolutional neural networks.
    CNN,
    /// Recurrent neural networks.
    RNN,
    /// Graph neural networks.
    GNN,
    /// Diffusion models.
    Diffusion,
    /// Anything not covered by the other variants.
    Custom,
}
195
196#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
198pub enum OptimizationTechnique {
199 OperatorFusion,
201 ConstantFolding,
203 DeadCodeElimination,
205 MemoryLayoutOptimization,
207 KernelSelection,
209 Quantization,
211 Pruning,
213 Distillation,
215}
216
/// Feature flags and limits advertised by a builder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuilderCapabilities {
    /// Maximum supported model size. `None` presumably means no limit; the
    /// unit (bytes vs. parameter count) is not specified here. Confirm.
    pub max_model_size: Option<u64>,
    /// Whether dynamic tensor shapes are supported.
    pub supports_dynamic_shapes: bool,
    /// Whether custom operators/layers are supported.
    pub supports_custom_ops: bool,
    /// Whether mixed-precision models are supported.
    pub supports_mixed_precision: bool,
    /// Whether model parallelism is supported.
    pub supports_model_parallelism: bool,
    /// Whether the build itself can run in parallel.
    pub supports_parallel_build: bool,
    /// Whether incremental builds are supported.
    pub supports_incremental_build: bool,
}
235
/// Extension of [`ModelBuilder`] with customization hooks. It covers
/// user-supplied layers, progress reporting, explicit optimization pipelines,
/// and export/import of the [`ModelIR`] representation.
#[async_trait]
pub trait AdvancedModelBuilder: ModelBuilder {
    /// Builds an executor for `config` that incorporates `custom_layers`.
    async fn build_with_custom_layers(
        &self,
        config: &ModelConfig,
        custom_layers: Vec<Box<dyn CustomLayer>>,
        compute_backend: Arc<dyn ComputeBackend>,
        weight_loader: Arc<dyn WeightLoader>,
    ) -> Result<Box<dyn ModelExecutor>>;

    /// Builds an executor for `config`, reporting [`BuildProgress`] updates
    /// through `progress_callback`.
    ///
    /// The callback is `Send + Sync`, so implementations may invoke it from
    /// other threads or tasks.
    async fn build_incremental(
        &self,
        config: &ModelConfig,
        compute_backend: Arc<dyn ComputeBackend>,
        weight_loader: Arc<dyn WeightLoader>,
        progress_callback: Box<dyn Fn(BuildProgress) + Send + Sync>,
    ) -> Result<Box<dyn ModelExecutor>>;

    /// Builds an executor for `config`, applying the passes in
    /// `optimization_pipeline`.
    ///
    /// NOTE(review): whether passes run in vector order or are reordered using
    /// [`OptimizationPass::dependencies`] is implementation-defined. Confirm.
    async fn build_with_optimization(
        &self,
        config: &ModelConfig,
        optimization_pipeline: Vec<Box<dyn OptimizationPass>>,
        compute_backend: Arc<dyn ComputeBackend>,
        weight_loader: Arc<dyn WeightLoader>,
    ) -> Result<Box<dyn ModelExecutor>>;

    /// Produces a [`ModelIR`] describing the model for `config`.
    async fn export_model_definition(&self, config: &ModelConfig) -> Result<ModelIR>;

    /// Builds an executor from the IR `definition`.
    async fn import_model_definition(
        &self,
        definition: &ModelIR,
        compute_backend: Arc<dyn ComputeBackend>,
        weight_loader: Arc<dyn WeightLoader>,
    ) -> Result<Box<dyn ModelExecutor>>;
}
277
/// A user-defined layer passed to
/// [`AdvancedModelBuilder::build_with_custom_layers`].
pub trait CustomLayer: Send + Sync {
    /// Name of this layer instance.
    fn name(&self) -> &str;

    /// Layer type identifier (free-form string).
    fn layer_type(&self) -> &str;

    /// Expected input shape. Dimensions are signed; presumably negative values
    /// (e.g. `-1`) mark dynamic dimensions. Confirm.
    fn input_shape(&self) -> Vec<i64>;

    /// Output shape this layer produces for the given `input_shape`.
    fn output_shape(&self, input_shape: &[i64]) -> Vec<i64>;

    /// Creates this layer's tensors, keyed by name (presumably parameter names).
    fn initialize_parameters(&self) -> Result<HashMap<String, crate::TensorRef>>;

    /// Layer configuration as an arbitrary JSON value.
    fn config(&self) -> serde_json::Value;
}
298
/// A progress update passed to the `progress_callback` of
/// [`AdvancedModelBuilder::build_incremental`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuildProgress {
    /// Phase the build is currently in.
    pub phase: BuildPhase,
    /// Progress value. Presumably a fraction in `0.0..=1.0` (or a percentage).
    /// TODO confirm the scale with implementations.
    pub progress: f32,
    /// Human-readable description of the current operation.
    pub current_operation: String,
    /// Elapsed time, in seconds.
    pub elapsed_seconds: u64,
    /// Estimated remaining time in seconds, if an estimate is available.
    pub remaining_seconds: Option<u64>,
    /// Phase-specific extra details as arbitrary JSON values.
    pub phase_details: HashMap<String, serde_json::Value>,
}
315
/// Phase of a model build, reported in [`BuildProgress::phase`].
///
/// Declaration order presumably mirrors execution order. No `Ord` is derived,
/// so nothing relies on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BuildPhase {
    /// Initial validation.
    Validation,
    /// Loading weights.
    WeightLoading,
    /// Initializing the model.
    ModelInitialization,
    /// Constructing layers.
    LayerConstruction,
    /// Binding parameters to layers.
    ParameterBinding,
    /// Running optimizations.
    Optimization,
    /// Validation after construction/optimization.
    FinalValidation,
    /// Final wrap-up.
    Finalization,
}
336
/// A transformation over a [`ModelIR`], used as an element of the pipeline
/// given to [`AdvancedModelBuilder::build_with_optimization`].
pub trait OptimizationPass: Send + Sync {
    /// Name identifying this pass.
    fn name(&self) -> &str;

    /// Applies the pass to `definition` in place and reports the outcome.
    fn apply(&self, definition: &mut ModelIR) -> Result<OptimizationResult>;

    /// Returns whether this pass can be applied to `definition`.
    fn is_applicable(&self, definition: &ModelIR) -> bool;

    /// Identifiers of passes this one depends on. NOTE(review): presumably
    /// other passes' [`OptimizationPass::name`] values that must run first.
    /// Confirm.
    fn dependencies(&self) -> Vec<String>;
}
351
/// Outcome of [`OptimizationPass::apply`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationResult {
    /// Whether the pass was applied.
    pub applied: bool,
    /// Statistics about what the pass changed.
    pub stats: OptimizationStats,
    /// Non-fatal warnings raised by the pass.
    pub warnings: Vec<String>,
}
362
/// Statistics reported by an optimization pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationStats {
    /// Number of parameters removed.
    pub parameters_eliminated: u64,
    /// Number of operations removed. Note: this is `usize`, unlike the `u64`
    /// counters beside it.
    pub operations_eliminated: usize,
    /// Memory saved. The unit is not specified here (presumably bytes). Confirm.
    pub memory_saved: u64,
    /// Estimated speedup. Presumably a multiplier (e.g. `1.2` = 20% faster).
    /// Confirm.
    pub estimated_speedup: f32,
}
375
/// Serializable intermediate representation (IR) of a model.
///
/// It bundles the model's metadata, architecture, parameter specs, layers,
/// and the graph connecting them. It is produced by
/// [`AdvancedModelBuilder::export_model_definition`], consumed by
/// [`AdvancedModelBuilder::import_model_definition`], and transformed by
/// [`OptimizationPass`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelIR {
    /// Descriptive metadata.
    pub metadata: ModelMetadata,
    /// Architecture name, dimensions, and config.
    pub architecture: ArchitectureDefinition,
    /// Specifications of the model's parameters.
    pub parameters: Vec<ParameterSpec>,
    /// Layer definitions.
    pub layers: Vec<LayerDefinition>,
    /// Graph of nodes and edges connecting the layers.
    pub graph: GraphDefinition,
}
394
/// Descriptive metadata carried in a [`ModelIR`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Model name.
    pub name: String,
    /// Model version string.
    pub version: String,
    /// Model type.
    pub model_type: ferrum_types::ModelType,
    /// Broad architecture family.
    pub architecture_family: ModelArchitectureFamily,
    /// Optional free-form description.
    pub description: Option<String>,
    /// Optional author.
    pub author: Option<String>,
    /// Optional license identifier/text.
    pub license: Option<String>,
    /// Additional metadata as arbitrary JSON values.
    pub additional: HashMap<String, serde_json::Value>,
}
415
/// Architecture section of a [`ModelIR`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureDefinition {
    /// Architecture name.
    pub name: String,
    /// Core size parameters.
    pub dimensions: ModelDimensions,
    /// Architecture-specific settings as arbitrary JSON values.
    pub config: HashMap<String, serde_json::Value>,
}
426
/// Core size parameters of a model architecture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDimensions {
    /// Vocabulary size.
    pub vocab_size: usize,
    /// Hidden (model) dimension.
    pub hidden_size: usize,
    /// Number of layers.
    pub num_layers: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Number of key/value heads, when specified (e.g. for grouped-query
    /// attention). Presumably `None` means it equals `num_heads`. Confirm.
    pub num_kv_heads: Option<usize>,
    /// Intermediate (feed-forward) dimension, when specified.
    pub intermediate_size: Option<usize>,
    /// Maximum sequence length.
    pub max_sequence_length: usize,
}
445
/// Specification of one model parameter in a [`ModelIR`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterSpec {
    /// Parameter name.
    pub name: String,
    /// Parameter shape. Dimensions are signed; presumably negative values mark
    /// dynamic dimensions. Confirm.
    pub shape: Vec<i64>,
    /// Element data type.
    pub dtype: ferrum_types::DataType,
    /// Whether the parameter is trainable.
    pub trainable: bool,
    /// How the parameter's initial values are produced.
    pub initialization: InitializationStrategy,
    /// Extra metadata as arbitrary JSON values.
    pub metadata: HashMap<String, serde_json::Value>,
}
462
/// Strategy for producing a parameter's initial values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InitializationStrategy {
    /// All zeros.
    Zeros,
    /// All ones.
    Ones,
    /// Uniform distribution between `min` and `max` (bound inclusivity is up
    /// to the implementation).
    Uniform { min: f32, max: f32 },
    /// Normal distribution with the given `mean` and standard deviation `std`.
    Normal { mean: f32, std: f32 },
    /// Xavier (Glorot) initialization. The exact variant is up to the
    /// implementation.
    Xavier,
    /// Kaiming (He) initialization. The exact variant is up to the
    /// implementation.
    Kaiming,
    /// Implementation-defined strategy identified by the string.
    Custom(String),
}
481
/// Definition of one layer in a [`ModelIR`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerDefinition {
    /// Layer name.
    pub name: String,
    /// Layer type identifier (free-form string).
    pub layer_type: String,
    /// Input tensors.
    pub inputs: Vec<TensorSpec>,
    /// Output tensors.
    pub outputs: Vec<TensorSpec>,
    /// Parameters used by this layer. Presumably names of entries in
    /// `ModelIR::parameters`. Confirm.
    pub parameters: Vec<String>,
    /// Layer-specific settings as arbitrary JSON values.
    pub config: HashMap<String, serde_json::Value>,
}
498
/// Name, shape and data type of a layer input/output tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorSpec {
    /// Tensor name.
    pub name: String,
    /// Tensor shape. Dimensions are signed; presumably negative values mark
    /// dynamic dimensions. Confirm.
    pub shape: Vec<i64>,
    /// Element data type.
    pub dtype: ferrum_types::DataType,
}
509
/// Computation graph of a [`ModelIR`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphDefinition {
    /// Graph inputs, by name (presumably tensor or node names; confirm).
    pub inputs: Vec<String>,
    /// Graph outputs, by name (presumably tensor or node names; confirm).
    pub outputs: Vec<String>,
    /// Graph nodes.
    pub nodes: Vec<GraphNode>,
    /// Directed edges between nodes.
    pub edges: Vec<GraphEdge>,
}
522
/// A node in a [`GraphDefinition`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Node identifier, referenced by [`GraphEdge`] endpoints.
    pub id: String,
    /// Layer this node executes. Presumably a [`LayerDefinition`] name. Confirm.
    pub layer_name: String,
    /// Extra metadata as arbitrary JSON values.
    pub metadata: HashMap<String, serde_json::Value>,
}
533
/// A directed edge from one [`GraphNode`] to another.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphEdge {
    /// Source node id.
    pub source: String,
    /// Target node id.
    pub target: String,
    /// Index of the source node's output. Presumably `None` means the
    /// default/sole output. Confirm.
    pub source_output: Option<usize>,
    /// Index of the target node's input. Presumably `None` means the
    /// default/sole input. Confirm.
    pub target_input: Option<usize>,
}
546
/// Creates [`ModelBuilder`]s and [`AdvancedModelBuilder`]s as trait objects.
#[async_trait]
pub trait ModelBuilderFactory: Send + Sync {
    /// Creates a builder.
    async fn create_builder(&self) -> Result<Box<dyn ModelBuilder>>;

    /// Creates a builder exposing the [`AdvancedModelBuilder`] API.
    async fn create_advanced_builder(&self) -> Result<Box<dyn AdvancedModelBuilder>>;

    /// Returns the model types this factory reports as supported.
    fn supported_types(&self) -> Vec<ferrum_types::ModelType>;

    /// Creates a builder for `model_type`.
    ///
    /// NOTE(review): presumably returns `Err` for types not listed by
    /// `supported_types`. Confirm per implementation.
    async fn create_builder_for_type(
        &self,
        model_type: ferrum_types::ModelType,
    ) -> Result<Box<dyn ModelBuilder>>;
}
565
566pub trait ModelRegistry: Send + Sync {
568 fn register_model(
570 &mut self,
571 model_id: &ferrum_types::ModelId,
572 executor: Box<dyn ModelExecutor>,
573 ) -> Result<()>;
574
575 fn get_model(&self, model_id: &ferrum_types::ModelId) -> Option<&dyn ModelExecutor>;
577
578 fn remove_model(&mut self, model_id: &ferrum_types::ModelId) -> Option<Box<dyn ModelExecutor>>;
580
581 fn list_models(&self) -> Vec<ferrum_types::ModelId>;
583
584 fn get_model_info(&self, model_id: &ferrum_types::ModelId) -> Option<&ModelInfo>;
586
587 fn contains_model(&self, model_id: &ferrum_types::ModelId) -> bool;
589}