Skip to main content

ferrum_models/builder/
factory.rs

1//! Model builder factory - MVP with working stub executor
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use ferrum_interfaces::{
7    backend::WeightFormat,
8    model_builder::{
9        BuildOptions, BuildTimeBreakdown, BuildTimeEstimate, BuilderCapabilities, BuilderInfo,
10        ModelArchitecture, ModelArchitectureFamily, ValidationIssue, ValidationSeverity,
11    },
12    ComputeBackend, ModelBuilder, ModelExecutor, WeightLoader,
13};
14use ferrum_types::{ModelConfig, ModelSource, ModelType, Result};
15use tracing::debug;
16
17use crate::executor::StubModelExecutor;
18
/// Simple model builder - MVP implementation.
///
/// A stateless unit struct whose [`ModelBuilder`] implementation hands back
/// `StubModelExecutor` instances. Because it holds no data it is freely
/// cloneable, matching [`DefaultModelBuilderFactory`].
#[derive(Debug, Default, Clone)]
pub struct SimpleModelBuilder;
22
23#[async_trait]
24impl ModelBuilder for SimpleModelBuilder {
25    async fn build_model(
26        &self,
27        config: &ModelConfig,
28        compute_backend: Arc<dyn ComputeBackend>,
29        _weight_loader: Arc<dyn WeightLoader>,
30    ) -> Result<Box<dyn ModelExecutor>> {
31        debug!(
32            "Building stub model: model_id={:?}, type={:?}",
33            config.model_id, config.model_type
34        );
35
36        let vocab_size = config
37            .extra_config
38            .get("vocab_size")
39            .and_then(|v| v.as_u64())
40            .unwrap_or(32000) as usize;
41
42        let executor = StubModelExecutor::new(config.model_id.clone(), vocab_size, compute_backend);
43
44        debug!("Built stub model executor");
45        Ok(Box::new(executor))
46    }
47
48    async fn build_from_source(
49        &self,
50        source: &ModelSource,
51        compute_backend: Arc<dyn ComputeBackend>,
52        _weight_loader: Arc<dyn WeightLoader>,
53        _build_options: &BuildOptions,
54    ) -> Result<Box<dyn ModelExecutor>> {
55        debug!("Building model from source: {:?}", source);
56
57        let model_id = match source {
58            ModelSource::Local(path) => path.clone(),
59            ModelSource::HuggingFace { repo_id, .. } => repo_id.clone(),
60            ModelSource::Url { url, .. } => url.clone(),
61            ModelSource::S3 { key, .. } => key.clone(),
62        };
63
64        let executor = StubModelExecutor::new(model_id, 32000, compute_backend);
65        Ok(Box::new(executor))
66    }
67
68    fn validate_config(&self, config: &ModelConfig) -> Result<Vec<ValidationIssue>> {
69        let mut issues = Vec::new();
70
71        if config.model_path.is_empty() {
72            issues.push(ValidationIssue {
73                severity: ValidationSeverity::Error,
74                category: "configuration".into(),
75                description: "Model path cannot be empty".into(),
76                suggested_fix: Some("Provide a valid model path".into()),
77                config_path: "model_path".into(),
78            });
79        }
80
81        Ok(issues)
82    }
83
84    fn supported_model_types(&self) -> Vec<ModelType> {
85        vec![
86            ModelType::Llama,
87            ModelType::Mistral,
88            ModelType::Qwen,
89            ModelType::Phi,
90            ModelType::Custom("stub".into()),
91        ]
92    }
93
94    async fn estimate_build_time(&self, _config: &ModelConfig) -> Result<BuildTimeEstimate> {
95        Ok(BuildTimeEstimate {
96            min_time_seconds: 1,
97            max_time_seconds: 10,
98            expected_time_seconds: 3,
99            time_breakdown: BuildTimeBreakdown {
100                weight_loading_seconds: 1,
101                model_init_seconds: 1,
102                optimization_seconds: 0,
103                validation_seconds: 1,
104                overhead_seconds: 0,
105            },
106            factors: vec![],
107        })
108    }
109
110    fn builder_info(&self) -> BuilderInfo {
111        BuilderInfo {
112            name: "SimpleModelBuilder".into(),
113            version: env!("CARGO_PKG_VERSION").into(),
114            supported_architectures: vec![
115                ModelArchitecture {
116                    name: "Llama".into(),
117                    family: ModelArchitectureFamily::Transformer,
118                    variants: vec!["llama-7b".into(), "llama-13b".into()],
119                    required_features: vec![],
120                },
121                ModelArchitecture {
122                    name: "Mistral".into(),
123                    family: ModelArchitectureFamily::Transformer,
124                    variants: vec!["mistral-7b".into()],
125                    required_features: vec![],
126                },
127            ],
128            supported_weight_formats: vec![WeightFormat::SafeTensors],
129            supported_optimizations: vec![],
130            capabilities: BuilderCapabilities {
131                max_model_size: Some(70_000_000_000),
132                supports_dynamic_shapes: false,
133                supports_custom_ops: false,
134                supports_mixed_precision: true,
135                supports_model_parallelism: false,
136                supports_parallel_build: false,
137                supports_incremental_build: false,
138            },
139        }
140    }
141}
142
/// Factory for the crate's default model builder.
///
/// A stateless unit type: it carries no configuration, so every instance is
/// interchangeable and cloning it is free.
#[derive(Clone, Debug, Default)]
pub struct DefaultModelBuilderFactory;
146
147impl DefaultModelBuilderFactory {
148    pub fn new() -> Self {
149        Self
150    }
151
152    pub fn create(&self) -> Arc<dyn ModelBuilder + Send + Sync> {
153        Arc::new(SimpleModelBuilder)
154    }
155
156    pub fn create_for_model_type(
157        &self,
158        _model_type: &ModelType,
159    ) -> Arc<dyn ModelBuilder + Send + Sync> {
160        Arc::new(SimpleModelBuilder)
161    }
162}