use crate::{ComputeBackend, ModelExecutor, WeightLoader};
use async_trait::async_trait;
use ferrum_types::{ModelConfig, ModelInfo, ModelSource, Result};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};
/// Interface for components that construct [`ModelExecutor`] instances.
///
/// Implementors must be `Send + Sync`. The async methods are declared through
/// `#[async_trait]`. This trait only declares signatures; all behavior,
/// including error conditions, is defined by the implementor.
#[async_trait]
pub trait ModelBuilder: Send + Sync {
/// Builds an executor for `config` using the supplied compute backend and
/// weight loader.
///
/// # Errors
/// Failure conditions are implementation-defined.
async fn build_model(
&self,
config: &ModelConfig,
compute_backend: Arc<dyn ComputeBackend>,
weight_loader: Arc<dyn WeightLoader>,
) -> Result<Box<dyn ModelExecutor>>;
/// Builds an executor starting from a `ModelSource` instead of a
/// `ModelConfig`, with `build_options` passed along to the build.
///
/// NOTE(review): how `source` is resolved and which `build_options` fields
/// are honored is up to the implementor — confirm per implementation.
///
/// # Errors
/// Failure conditions are implementation-defined.
async fn build_from_source(
&self,
source: &ModelSource,
compute_backend: Arc<dyn ComputeBackend>,
weight_loader: Arc<dyn WeightLoader>,
build_options: &BuildOptions,
) -> Result<Box<dyn ModelExecutor>>;
/// Inspects `config` and returns the issues found as a list of
/// [`ValidationIssue`] values.
///
/// NOTE(review): presumably `Ok(vec![])` means "no issues" and `Err` means
/// validation itself could not be performed — confirm with implementors.
fn validate_config(&self, config: &ModelConfig) -> Result<Vec<ValidationIssue>>;
/// Returns the model types this builder reports as supported.
fn supported_model_types(&self) -> Vec<ferrum_types::ModelType>;
/// Returns a [`BuildTimeEstimate`] for building `config`.
///
/// # Errors
/// Failure conditions are implementation-defined.
async fn estimate_build_time(&self, config: &ModelConfig) -> Result<BuildTimeEstimate>;
/// Returns descriptive metadata about this builder.
fn builder_info(&self) -> BuilderInfo;
}
/// Options accepted by [`ModelBuilder::build_from_source`].
///
/// The `Default` impl in this file enables validation and optimization
/// (level 2), disables quantization and compression, sets the timeout to
/// `Some(3600)` and leaves `additional_options` empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuildOptions {
/// Whether validation is requested. Defaults to `true`.
pub enable_validation: bool,
/// Whether optimization is requested. Defaults to `true`.
pub enable_optimization: bool,
/// Requested optimization level. Defaults to `2`.
///
/// NOTE(review): the valid range and the meaning of each level are not
/// defined in this file.
pub optimization_level: u8,
/// Whether quantization is requested. Defaults to `false`.
pub enable_quantization: bool,
/// Optional quantization settings. Defaults to `None`.
///
/// NOTE(review): whether this must be `Some` when `enable_quantization`
/// is set is not specified here.
pub quantization_config: Option<ferrum_types::QuantizationConfig>,
/// Whether compression is requested. Defaults to `false`.
pub enable_compression: bool,
/// Optional build timeout, in seconds per the field name. Defaults to
/// `Some(3600)`.
///
/// NOTE(review): `None` presumably means "no timeout" — confirm.
pub build_timeout_seconds: Option<u64>,
/// Free-form extra options as JSON values keyed by name. Empty by default.
pub additional_options: HashMap<String, serde_json::Value>,
}
impl Default for BuildOptions {
    /// Returns the baseline options: validation and optimization enabled at
    /// optimization level 2, quantization and compression disabled, a build
    /// timeout of `Some(3600)`, and no additional options.
    fn default() -> Self {
        // Named so the one non-trivial default stands out from the flags.
        let default_timeout = Some(3600);

        Self {
            additional_options: HashMap::new(),
            build_timeout_seconds: default_timeout,
            enable_compression: false,
            quantization_config: None,
            enable_quantization: false,
            optimization_level: 2,
            enable_optimization: true,
            enable_validation: true,
        }
    }
}
/// A single finding reported by [`ModelBuilder::validate_config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationIssue {
/// How serious the issue is.
pub severity: ValidationSeverity,
/// Free-form category label. The set of categories is not defined in this
/// file.
pub category: String,
/// Human-readable description of the issue.
pub description: String,
/// Optional suggestion for resolving the issue.
pub suggested_fix: Option<String>,
/// Location of the offending setting within the configuration.
///
/// NOTE(review): the path syntax (dotted, JSON pointer, ...) is not
/// specified here — confirm with producers.
pub config_path: String,
}
/// Severity attached to a [`ValidationIssue`].
///
/// `PartialOrd`/`Ord` are derived, so comparison follows declaration order:
/// `Warning < Error < Critical`. Reordering the variants changes that ordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum ValidationSeverity {
/// Lowest severity.
Warning,
/// Middle severity.
Error,
/// Highest severity.
Critical,
}
/// Result of [`ModelBuilder::estimate_build_time`].
///
/// NOTE(review): no invariant such as `min <= expected <= max`, or that the
/// breakdown sums to the expected time, is enforced by this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuildTimeEstimate {
/// Lower bound of the estimate (seconds, per the field name).
pub min_time_seconds: u64,
/// Upper bound of the estimate (seconds, per the field name).
pub max_time_seconds: u64,
/// Expected build time (seconds, per the field name).
pub expected_time_seconds: u64,
/// Per-phase split of the estimate.
pub time_breakdown: BuildTimeBreakdown,
/// Factors reported alongside the estimate.
pub factors: Vec<BuildTimeFactor>,
}
/// Per-phase components of a [`BuildTimeEstimate`]; every field is a duration
/// in seconds, per the field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuildTimeBreakdown {
/// Time attributed to loading weights.
pub weight_loading_seconds: u64,
/// Time attributed to model initialization.
pub model_init_seconds: u64,
/// Time attributed to optimization.
pub optimization_seconds: u64,
/// Time attributed to validation.
pub validation_seconds: u64,
/// Time attributed to overhead not covered by the other fields.
pub overhead_seconds: u64,
}
/// One factor listed in [`BuildTimeEstimate::factors`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuildTimeFactor {
/// Name of the factor.
pub factor: String,
/// Magnitude of the factor's effect.
///
/// NOTE(review): the scale and sign convention (multiplier, fraction,
/// seconds, ...) are not defined in this file — confirm with producers.
pub impact: f32,
/// Human-readable explanation of the factor.
pub description: String,
}
/// Descriptive metadata returned by [`ModelBuilder::builder_info`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuilderInfo {
/// Builder name.
pub name: String,
/// Builder version string. The format is not specified here.
pub version: String,
/// Architectures the builder reports as supported.
pub supported_architectures: Vec<ModelArchitecture>,
/// Weight formats the builder reports as supported.
pub supported_weight_formats: Vec<crate::backend::WeightFormat>,
/// Optimization techniques the builder reports as supported.
pub supported_optimizations: Vec<OptimizationTechnique>,
/// Capability flags and limits of the builder.
pub capabilities: BuilderCapabilities,
}
/// Description of one architecture listed in
/// [`BuilderInfo::supported_architectures`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelArchitecture {
/// Architecture name.
pub name: String,
/// Broad family the architecture belongs to.
pub family: ModelArchitectureFamily,
/// Named variants of the architecture. The naming scheme is not defined
/// in this file.
pub variants: Vec<String>,
/// Feature names the architecture requires.
///
/// NOTE(review): what these strings are matched against is not visible
/// here — confirm with consumers.
pub required_features: Vec<String>,
}
/// Broad architecture family, used by [`ModelArchitecture`] and
/// [`ModelMetadata`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelArchitectureFamily {
/// Transformer-family architectures.
Transformer,
/// Convolutional neural networks.
CNN,
/// Recurrent neural networks.
RNN,
/// Graph neural networks.
GNN,
/// Diffusion models.
Diffusion,
/// Anything not covered by the other variants.
Custom,
}
/// Optimization techniques a builder can advertise through
/// [`BuilderInfo::supported_optimizations`].
///
/// Derives `PartialEq`/`Eq` like the other fieldless enums in this module, so
/// variants can be compared and looked up with `slice::contains`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum OptimizationTechnique {
    /// Operator fusion.
    OperatorFusion,
    /// Constant folding.
    ConstantFolding,
    /// Dead-code elimination.
    DeadCodeElimination,
    /// Memory-layout optimization.
    MemoryLayoutOptimization,
    /// Kernel selection.
    KernelSelection,
    /// Quantization.
    Quantization,
    /// Pruning.
    Pruning,
    /// Distillation.
    Distillation,
}
/// Capability flags and limits advertised in [`BuilderInfo::capabilities`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuilderCapabilities {
/// Optional upper limit on model size.
///
/// NOTE(review): the unit (bytes vs. parameter count) is not stated here,
/// and `None` presumably means "no limit" — confirm.
pub max_model_size: Option<u64>,
/// Whether dynamic shapes are supported.
pub supports_dynamic_shapes: bool,
/// Whether custom operations are supported.
pub supports_custom_ops: bool,
/// Whether mixed precision is supported.
pub supports_mixed_precision: bool,
/// Whether model parallelism is supported.
pub supports_model_parallelism: bool,
/// Whether parallel builds are supported.
pub supports_parallel_build: bool,
/// Whether incremental builds are supported.
pub supports_incremental_build: bool,
}
/// Extension of [`ModelBuilder`] with additional build entry points and
/// import/export of a [`ModelIR`].
///
/// Every implementor must also implement `ModelBuilder` (supertrait bound).
/// Only signatures are declared here; behavior and error conditions are
/// implementation-defined.
#[async_trait]
pub trait AdvancedModelBuilder: ModelBuilder {
/// Builds an executor for `config`, taking ownership of the caller-supplied
/// `custom_layers`.
///
/// NOTE(review): how custom layers are placed into the model is not
/// specified by this signature.
async fn build_with_custom_layers(
&self,
config: &ModelConfig,
custom_layers: Vec<Box<dyn CustomLayer>>,
compute_backend: Arc<dyn ComputeBackend>,
weight_loader: Arc<dyn WeightLoader>,
) -> Result<Box<dyn ModelExecutor>>;
/// Builds an executor for `config`, with `progress_callback` receiving
/// [`BuildProgress`] values.
///
/// The callback must be `Send + Sync`. NOTE(review): how often and from
/// which thread/task it is invoked is implementation-defined.
async fn build_incremental(
&self,
config: &ModelConfig,
compute_backend: Arc<dyn ComputeBackend>,
weight_loader: Arc<dyn WeightLoader>,
progress_callback: Box<dyn Fn(BuildProgress) + Send + Sync>,
) -> Result<Box<dyn ModelExecutor>>;
/// Builds an executor for `config` with a caller-supplied list of
/// optimization passes.
///
/// NOTE(review): whether passes run in list order or are reordered using
/// `OptimizationPass::dependencies` is not specified here.
async fn build_with_optimization(
&self,
config: &ModelConfig,
optimization_pipeline: Vec<Box<dyn OptimizationPass>>,
compute_backend: Arc<dyn ComputeBackend>,
weight_loader: Arc<dyn WeightLoader>,
) -> Result<Box<dyn ModelExecutor>>;
/// Produces a [`ModelIR`] describing the model for `config`.
async fn export_model_definition(&self, config: &ModelConfig) -> Result<ModelIR>;
/// Builds an executor from a previously produced [`ModelIR`].
async fn import_model_definition(
&self,
definition: &ModelIR,
compute_backend: Arc<dyn ComputeBackend>,
weight_loader: Arc<dyn WeightLoader>,
) -> Result<Box<dyn ModelExecutor>>;
}
/// A user-supplied layer accepted by
/// [`AdvancedModelBuilder::build_with_custom_layers`].
///
/// Implementors must be `Send + Sync`.
pub trait CustomLayer: Send + Sync {
/// Name of this layer.
fn name(&self) -> &str;
/// Type label of this layer. The set of recognized labels is not defined
/// in this file.
fn layer_type(&self) -> &str;
/// Shape the layer reports for its input.
///
/// NOTE(review): dimensions are `i64`; whether negative values (e.g. `-1`
/// for a dynamic dimension) are meaningful is not stated here.
fn input_shape(&self) -> Vec<i64>;
/// Shape of the output given `input_shape`.
fn output_shape(&self, input_shape: &[i64]) -> Vec<i64>;
/// Creates the layer's parameters and returns them as a name-to-tensor map.
///
/// # Errors
/// Failure conditions are implementation-defined.
fn initialize_parameters(&self) -> Result<HashMap<String, crate::TensorRef>>;
/// Layer configuration as a JSON value.
fn config(&self) -> serde_json::Value;
}
/// Progress report passed to the callback of
/// [`AdvancedModelBuilder::build_incremental`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuildProgress {
/// Phase the build is currently in.
pub phase: BuildPhase,
/// Progress value.
///
/// NOTE(review): presumably a fraction in `0.0..=1.0`, and it is unclear
/// whether it is per-phase or overall — confirm; the range is not enforced.
pub progress: f32,
/// Description of the operation in progress.
pub current_operation: String,
/// Time elapsed so far (seconds, per the field name).
pub elapsed_seconds: u64,
/// Optional estimate of remaining time (seconds, per the field name).
pub remaining_seconds: Option<u64>,
/// Free-form, phase-specific details as JSON values keyed by name.
pub phase_details: HashMap<String, serde_json::Value>,
}
/// Build phase reported in [`BuildProgress::phase`].
///
/// NOTE(review): the declaration order looks like the intended phase sequence,
/// but nothing in this file enforces an order (no `Ord` derive) — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BuildPhase {
/// Validation phase.
Validation,
/// Weight-loading phase.
WeightLoading,
/// Model-initialization phase.
ModelInitialization,
/// Layer-construction phase.
LayerConstruction,
/// Parameter-binding phase.
ParameterBinding,
/// Optimization phase.
Optimization,
/// Final-validation phase.
FinalValidation,
/// Finalization phase.
Finalization,
}
/// A transformation over a [`ModelIR`], supplied to
/// [`AdvancedModelBuilder::build_with_optimization`].
///
/// Implementors must be `Send + Sync`.
pub trait OptimizationPass: Send + Sync {
/// Name of this pass.
fn name(&self) -> &str;
/// Runs the pass on `definition`, which it may modify in place (it is
/// taken by `&mut`), and reports the outcome.
///
/// # Errors
/// Failure conditions are implementation-defined.
fn apply(&self, definition: &mut ModelIR) -> Result<OptimizationResult>;
/// Reports whether this pass applies to `definition`; takes it by shared
/// reference, so it cannot modify it.
fn is_applicable(&self, definition: &ModelIR) -> bool;
/// Dependencies of this pass, as strings.
///
/// NOTE(review): presumably the `name()` values of passes that must run
/// first — confirm with the code that schedules passes.
fn dependencies(&self) -> Vec<String>;
}
/// Outcome returned by [`OptimizationPass::apply`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationResult {
/// Whether the pass was applied.
pub applied: bool,
/// Statistics reported by the pass.
pub stats: OptimizationStats,
/// Warnings emitted by the pass.
pub warnings: Vec<String>,
}
/// Statistics carried in [`OptimizationResult::stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationStats {
/// Number of parameters eliminated.
pub parameters_eliminated: u64,
/// Number of operations eliminated.
pub operations_eliminated: usize,
/// Amount of memory saved.
///
/// NOTE(review): the unit (presumably bytes) is not stated here — confirm.
pub memory_saved: u64,
/// Estimated speedup.
///
/// NOTE(review): the convention (e.g. `1.0` = unchanged, `2.0` = twice as
/// fast) is not defined in this file — confirm with producers.
pub estimated_speedup: f32,
}
/// Serializable description of a model, exchanged through
/// [`AdvancedModelBuilder::export_model_definition`] and
/// [`AdvancedModelBuilder::import_model_definition`], and operated on by
/// [`OptimizationPass`].
///
/// NOTE(review): cross-references between the parts (e.g. layer parameter
/// names vs. `parameters`, graph nodes vs. `layers`) are plain strings and are
/// not validated by this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelIR {
/// Descriptive metadata about the model.
pub metadata: ModelMetadata,
/// Architecture name, dimensions and configuration.
pub architecture: ArchitectureDefinition,
/// Parameter specifications.
pub parameters: Vec<ParameterSpec>,
/// Layer definitions.
pub layers: Vec<LayerDefinition>,
/// Graph of nodes and edges.
pub graph: GraphDefinition,
}
/// Descriptive metadata stored in [`ModelIR::metadata`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
/// Model name.
pub name: String,
/// Model version string. The format is not specified here.
pub version: String,
/// Model type, as defined by `ferrum_types`.
pub model_type: ferrum_types::ModelType,
/// Broad architecture family.
pub architecture_family: ModelArchitectureFamily,
/// Optional free-text description.
pub description: Option<String>,
/// Optional author.
pub author: Option<String>,
/// Optional license identifier or text.
pub license: Option<String>,
/// Free-form additional metadata as JSON values keyed by name.
pub additional: HashMap<String, serde_json::Value>,
}
/// Architecture section of a [`ModelIR`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureDefinition {
/// Architecture name.
pub name: String,
/// Size-related dimensions of the model.
pub dimensions: ModelDimensions,
/// Free-form architecture configuration as JSON values keyed by name.
pub config: HashMap<String, serde_json::Value>,
}
/// Size-related dimensions stored in [`ArchitectureDefinition::dimensions`].
///
/// No relationship between the fields (e.g. divisibility of `hidden_size` by
/// `num_heads`) is enforced by this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDimensions {
/// Vocabulary size.
pub vocab_size: usize,
/// Hidden size.
pub hidden_size: usize,
/// Number of layers.
pub num_layers: usize,
/// Number of attention heads.
pub num_heads: usize,
/// Optional number of key/value heads.
///
/// NOTE(review): `None` presumably means "same as `num_heads`" — confirm
/// with consumers.
pub num_kv_heads: Option<usize>,
/// Optional intermediate size.
///
/// NOTE(review): the fallback used when this is `None` is not defined in
/// this file.
pub intermediate_size: Option<usize>,
/// Maximum sequence length.
pub max_sequence_length: usize,
}
/// Specification of one parameter in [`ModelIR::parameters`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterSpec {
/// Parameter name.
pub name: String,
/// Parameter shape.
///
/// NOTE(review): dimensions are `i64`; whether negative values (e.g. `-1`
/// for a dynamic dimension) are meaningful is not stated here.
pub shape: Vec<i64>,
/// Element data type, as defined by `ferrum_types`.
pub dtype: ferrum_types::DataType,
/// Whether the parameter is trainable.
pub trainable: bool,
/// How the parameter is to be initialized.
pub initialization: InitializationStrategy,
/// Free-form metadata as JSON values keyed by name.
pub metadata: HashMap<String, serde_json::Value>,
}
/// Initialization strategy recorded in [`ParameterSpec::initialization`].
///
/// This type only names the strategy; the code that performs the
/// initialization is not part of this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InitializationStrategy {
/// Zero initialization.
Zeros,
/// One initialization.
Ones,
/// Uniform distribution with the given bounds.
///
/// NOTE(review): whether the bounds are inclusive, and that `min <= max`,
/// is not specified or enforced here.
Uniform { min: f32, max: f32 },
/// Normal distribution with the given mean and standard deviation.
Normal { mean: f32, std: f32 },
/// Xavier initialization. The exact variant (uniform vs. normal) is not
/// specified here.
Xavier,
/// Kaiming initialization. The exact variant is not specified here.
Kaiming,
/// Custom strategy identified by a string whose interpretation is left to
/// the consumer.
Custom(String),
}
/// Definition of one layer in [`ModelIR::layers`].
///
/// Besides the individually documented fields, `config` holds free-form layer
/// configuration as JSON values keyed by name (it shares a source line with
/// `parameters`, so it is described here).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerDefinition {
/// Layer name.
pub name: String,
/// Layer type label. The set of recognized labels is not defined in this
/// file.
pub layer_type: String,
/// Specifications of the layer's input tensors.
pub inputs: Vec<TensorSpec>,
/// Specifications of the layer's output tensors.
pub outputs: Vec<TensorSpec>,
/// Names of the parameters used by the layer.
///
/// NOTE(review): presumably these match `ParameterSpec::name` entries of
/// the enclosing `ModelIR` — confirm; nothing here validates it.
pub parameters: Vec<String>, pub config: HashMap<String, serde_json::Value>,
}
/// Name, shape and element type of a tensor, used for the inputs and outputs
/// of a [`LayerDefinition`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorSpec {
/// Tensor name.
pub name: String,
/// Tensor shape.
///
/// NOTE(review): dimensions are `i64`; whether negative values (e.g. `-1`
/// for a dynamic dimension) are meaningful is not stated here.
pub shape: Vec<i64>,
/// Element data type, as defined by `ferrum_types`.
pub dtype: ferrum_types::DataType,
}
/// Graph section of a [`ModelIR`]: nodes, edges and the graph's inputs and
/// outputs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphDefinition {
/// Identifiers of the graph inputs.
///
/// NOTE(review): whether these are node ids or tensor names is not
/// specified here — confirm with producers.
pub inputs: Vec<String>,
/// Identifiers of the graph outputs (same caveat as `inputs`).
pub outputs: Vec<String>,
/// Nodes of the graph.
pub nodes: Vec<GraphNode>,
/// Edges of the graph.
pub edges: Vec<GraphEdge>,
}
/// One node in [`GraphDefinition::nodes`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
/// Node identifier. Uniqueness is not enforced by this type.
pub id: String,
/// Name of the layer this node refers to.
///
/// NOTE(review): presumably matches a `LayerDefinition::name` in the
/// enclosing `ModelIR` — confirm; nothing here validates it.
pub layer_name: String,
/// Free-form metadata as JSON values keyed by name.
pub metadata: HashMap<String, serde_json::Value>,
}
/// One edge in [`GraphDefinition::edges`], connecting `source` to `target`.
///
/// NOTE(review): `source` and `target` are presumably `GraphNode::id` values —
/// confirm; nothing here validates it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphEdge {
/// Identifier of the edge's source.
pub source: String,
/// Identifier of the edge's target.
pub target: String,
/// Optional index selecting an output of the source.
///
/// NOTE(review): the meaning of `None` (e.g. "first/only output") is not
/// defined in this file.
pub source_output: Option<usize>,
/// Optional index selecting an input of the target (same caveat for `None`).
pub target_input: Option<usize>,
}
/// Factory for [`ModelBuilder`] and [`AdvancedModelBuilder`] trait objects.
///
/// Implementors must be `Send + Sync`. The async methods are declared through
/// `#[async_trait]`; error conditions are implementation-defined.
#[async_trait]
pub trait ModelBuilderFactory: Send + Sync {
/// Creates a boxed [`ModelBuilder`].
async fn create_builder(&self) -> Result<Box<dyn ModelBuilder>>;
/// Creates a boxed [`AdvancedModelBuilder`].
async fn create_advanced_builder(&self) -> Result<Box<dyn AdvancedModelBuilder>>;
/// Returns the model types this factory reports as supported.
fn supported_types(&self) -> Vec<ferrum_types::ModelType>;
/// Creates a boxed [`ModelBuilder`] for the given `model_type`.
///
/// NOTE(review): presumably returns `Err` for types not listed by
/// `supported_types` — confirm with implementors.
async fn create_builder_for_type(
&self,
model_type: ferrum_types::ModelType,
) -> Result<Box<dyn ModelBuilder>>;
}
/// Registry of [`ModelExecutor`] instances keyed by `ModelId`.
///
/// Implementors must be `Send + Sync`. Mutating operations take `&mut self`,
/// so any sharing across threads needs external synchronization by the owner.
pub trait ModelRegistry: Send + Sync {
/// Registers `executor` under `model_id`, taking ownership of the executor.
///
/// NOTE(review): behavior when `model_id` is already registered (replace
/// vs. error) is implementation-defined.
fn register_model(
&mut self,
model_id: &ferrum_types::ModelId,
executor: Box<dyn ModelExecutor>,
) -> Result<()>;
/// Returns a borrowed reference to the executor for `model_id`, or `None`.
fn get_model(&self, model_id: &ferrum_types::ModelId) -> Option<&dyn ModelExecutor>;
/// Removes the entry for `model_id` and returns the owned executor, or
/// `None`.
fn remove_model(&mut self, model_id: &ferrum_types::ModelId) -> Option<Box<dyn ModelExecutor>>;
/// Returns the ids of the registered models. Ordering is not specified.
fn list_models(&self) -> Vec<ferrum_types::ModelId>;
/// Returns a borrowed reference to the `ModelInfo` for `model_id`, or `None`.
fn get_model_info(&self, model_id: &ferrum_types::ModelId) -> Option<&ModelInfo>;
/// Reports whether `model_id` is registered.
fn contains_model(&self, model_id: &ferrum_types::ModelId) -> bool;
}