1use std::sync::Arc;
4
5use async_trait::async_trait;
6use ferrum_interfaces::{
7 backend::WeightFormat,
8 model_builder::{
9 BuildOptions, BuildTimeBreakdown, BuildTimeEstimate, BuilderCapabilities, BuilderInfo,
10 ModelArchitecture, ModelArchitectureFamily, ValidationIssue, ValidationSeverity,
11 },
12 ComputeBackend, ModelBuilder, ModelExecutor, WeightLoader,
13};
14use ferrum_types::{ModelConfig, ModelSource, ModelType, Result};
15use tracing::debug;
16
17use crate::executor::StubModelExecutor;
18
/// Stateless model builder that produces stub executors.
///
/// The type carries no fields, so cloning it is free; `Clone` is derived to
/// match the other zero-sized helper types in this module.
#[derive(Debug, Default, Clone)]
pub struct SimpleModelBuilder;
22
23#[async_trait]
24impl ModelBuilder for SimpleModelBuilder {
25 async fn build_model(
26 &self,
27 config: &ModelConfig,
28 compute_backend: Arc<dyn ComputeBackend>,
29 _weight_loader: Arc<dyn WeightLoader>,
30 ) -> Result<Box<dyn ModelExecutor>> {
31 debug!(
32 "Building stub model: model_id={:?}, type={:?}",
33 config.model_id, config.model_type
34 );
35
36 let vocab_size = config
37 .extra_config
38 .get("vocab_size")
39 .and_then(|v| v.as_u64())
40 .unwrap_or(32000) as usize;
41
42 let executor = StubModelExecutor::new(config.model_id.clone(), vocab_size, compute_backend);
43
44 debug!("Built stub model executor");
45 Ok(Box::new(executor))
46 }
47
48 async fn build_from_source(
49 &self,
50 source: &ModelSource,
51 compute_backend: Arc<dyn ComputeBackend>,
52 _weight_loader: Arc<dyn WeightLoader>,
53 _build_options: &BuildOptions,
54 ) -> Result<Box<dyn ModelExecutor>> {
55 debug!("Building model from source: {:?}", source);
56
57 let model_id = match source {
58 ModelSource::Local(path) => path.clone(),
59 ModelSource::HuggingFace { repo_id, .. } => repo_id.clone(),
60 ModelSource::Url { url, .. } => url.clone(),
61 ModelSource::S3 { key, .. } => key.clone(),
62 };
63
64 let executor = StubModelExecutor::new(model_id, 32000, compute_backend);
65 Ok(Box::new(executor))
66 }
67
68 fn validate_config(&self, config: &ModelConfig) -> Result<Vec<ValidationIssue>> {
69 let mut issues = Vec::new();
70
71 if config.model_path.is_empty() {
72 issues.push(ValidationIssue {
73 severity: ValidationSeverity::Error,
74 category: "configuration".into(),
75 description: "Model path cannot be empty".into(),
76 suggested_fix: Some("Provide a valid model path".into()),
77 config_path: "model_path".into(),
78 });
79 }
80
81 Ok(issues)
82 }
83
84 fn supported_model_types(&self) -> Vec<ModelType> {
85 vec![
86 ModelType::Llama,
87 ModelType::Mistral,
88 ModelType::Qwen,
89 ModelType::Phi,
90 ModelType::Custom("stub".into()),
91 ]
92 }
93
94 async fn estimate_build_time(&self, _config: &ModelConfig) -> Result<BuildTimeEstimate> {
95 Ok(BuildTimeEstimate {
96 min_time_seconds: 1,
97 max_time_seconds: 10,
98 expected_time_seconds: 3,
99 time_breakdown: BuildTimeBreakdown {
100 weight_loading_seconds: 1,
101 model_init_seconds: 1,
102 optimization_seconds: 0,
103 validation_seconds: 1,
104 overhead_seconds: 0,
105 },
106 factors: vec![],
107 })
108 }
109
110 fn builder_info(&self) -> BuilderInfo {
111 BuilderInfo {
112 name: "SimpleModelBuilder".into(),
113 version: env!("CARGO_PKG_VERSION").into(),
114 supported_architectures: vec![
115 ModelArchitecture {
116 name: "Llama".into(),
117 family: ModelArchitectureFamily::Transformer,
118 variants: vec!["llama-7b".into(), "llama-13b".into()],
119 required_features: vec![],
120 },
121 ModelArchitecture {
122 name: "Mistral".into(),
123 family: ModelArchitectureFamily::Transformer,
124 variants: vec!["mistral-7b".into()],
125 required_features: vec![],
126 },
127 ],
128 supported_weight_formats: vec![WeightFormat::SafeTensors],
129 supported_optimizations: vec![],
130 capabilities: BuilderCapabilities {
131 max_model_size: Some(70_000_000_000),
132 supports_dynamic_shapes: false,
133 supports_custom_ops: false,
134 supports_mixed_precision: true,
135 supports_model_parallelism: false,
136 supports_parallel_build: false,
137 supports_incremental_build: false,
138 },
139 }
140 }
141}
142
/// Zero-sized factory type for obtaining model builders.
///
/// It holds no state, so `Default::default()` and `clone()` both yield an
/// equivalent value.
#[derive(Clone, Debug, Default)]
pub struct DefaultModelBuilderFactory;
146
147impl DefaultModelBuilderFactory {
148 pub fn new() -> Self {
149 Self
150 }
151
152 pub fn create(&self) -> Arc<dyn ModelBuilder + Send + Sync> {
153 Arc::new(SimpleModelBuilder)
154 }
155
156 pub fn create_for_model_type(
157 &self,
158 _model_type: &ModelType,
159 ) -> Arc<dyn ModelBuilder + Send + Sync> {
160 Arc::new(SimpleModelBuilder)
161 }
162}