Skip to main content

entrenar/config/
builder.rs

1//! Build training components from configuration
2
3use super::schema::{OptimSpec, TrainSpec};
4use crate::error::{Error, Result};
5use crate::io::{load_model, Model, ModelMetadata};
6use crate::optim::{Adam, AdamW, Optimizer, SGD};
7use crate::Tensor;
8
// Keys looked up in `OptimSpec::params` by `build_optimizer` (CB-525).

/// Key for the `momentum` value read by the SGD branch.
const PARAM_MOMENTUM: &'static str = "momentum";
/// Key for the `beta1` value read by the Adam and AdamW branches.
const PARAM_BETA1: &'static str = "beta1";
/// Key for the `beta2` value read by the Adam and AdamW branches.
const PARAM_BETA2: &'static str = "beta2";
/// Key for the `eps` value read by the Adam and AdamW branches.
const PARAM_EPS: &'static str = "eps";
/// Key for the `weight_decay` value read by the AdamW branch.
const PARAM_WEIGHT_DECAY: &'static str = "weight_decay";
15
16/// Build optimizer from configuration
17pub fn build_optimizer(spec: &OptimSpec) -> Result<Box<dyn Optimizer>> {
18    match spec.name.to_lowercase().as_str() {
19        "sgd" => {
20            let momentum =
21                spec.params.get(PARAM_MOMENTUM).and_then(serde_json::Value::as_f64).unwrap_or(0.0)
22                    as f32;
23
24            Ok(Box::new(SGD::new(spec.lr, momentum)))
25        }
26        "adam" => {
27            let beta1 =
28                spec.params.get(PARAM_BETA1).and_then(serde_json::Value::as_f64).unwrap_or(0.9)
29                    as f32;
30
31            let beta2 =
32                spec.params.get(PARAM_BETA2).and_then(serde_json::Value::as_f64).unwrap_or(0.999)
33                    as f32;
34
35            let eps = spec.params.get(PARAM_EPS).and_then(serde_json::Value::as_f64).unwrap_or(1e-8)
36                as f32;
37
38            Ok(Box::new(Adam::new(spec.lr, beta1, beta2, eps)))
39        }
40        "adamw" => {
41            let beta1 =
42                spec.params.get(PARAM_BETA1).and_then(serde_json::Value::as_f64).unwrap_or(0.9)
43                    as f32;
44
45            let beta2 =
46                spec.params.get(PARAM_BETA2).and_then(serde_json::Value::as_f64).unwrap_or(0.999)
47                    as f32;
48
49            let eps = spec.params.get(PARAM_EPS).and_then(serde_json::Value::as_f64).unwrap_or(1e-8)
50                as f32;
51
52            let weight_decay = spec
53                .params
54                .get(PARAM_WEIGHT_DECAY)
55                .and_then(serde_json::Value::as_f64)
56                .unwrap_or(0.01) as f32;
57
58            Ok(Box::new(AdamW::new(spec.lr, beta1, beta2, eps, weight_decay)))
59        }
60        name => Err(Error::ConfigError(format!(
61            "Unknown optimizer: {name}. Supported: sgd, adam, adamw"
62        ))),
63    }
64}
65
66/// Build a model from configuration by loading from file
67///
68/// Loads the model from the path specified in the TrainSpec. Supports:
69/// - SafeTensors (.safetensors) - HuggingFace compatible binary format
70/// - JSON (.json) - Entrenar serialization format
71/// - YAML (.yaml, .yml) - Entrenar serialization format
72///
73/// Falls back to demo mode (simple MLP) if the model file doesn't exist,
74/// to support testing and demonstration workflows.
75pub fn build_model(spec: &TrainSpec) -> Result<Model> {
76    let model_path = &spec.model.path;
77
78    // Try to load the actual model if it exists
79    if model_path.exists() {
80        println!("Loading model from: {}", model_path.display());
81        let mut model = load_model(model_path)?;
82
83        // Add training metadata
84        model.metadata = model
85            .metadata
86            .with_custom("config_path", serde_json::json!(model_path))
87            .with_custom("optimizer", serde_json::json!(spec.optimizer.name))
88            .with_custom("learning_rate", serde_json::json!(spec.optimizer.lr))
89            .with_custom("batch_size", serde_json::json!(spec.data.batch_size));
90
91        // Enable gradients on all parameters for training
92        for (_, tensor) in &mut model.parameters {
93            tensor.set_requires_grad(true);
94        }
95
96        println!(
97            "Loaded model '{}' with {} parameters",
98            model.metadata.name,
99            model.parameters.len()
100        );
101
102        return Ok(model);
103    }
104
105    // Demo mode fallback: create a simple model for testing
106    eprintln!(
107        "Warning: Model file not found at '{}', using demo mode (simple MLP)",
108        model_path.display()
109    );
110
111    let params = vec![
112        ("layer1.weight".to_string(), Tensor::from_vec(vec![0.1, 0.2, 0.3, 0.4], true)),
113        ("layer1.bias".to_string(), Tensor::from_vec(vec![0.01, 0.02], true)),
114        ("layer2.weight".to_string(), Tensor::from_vec(vec![0.5, 0.6], true)),
115        ("layer2.bias".to_string(), Tensor::from_vec(vec![0.1], true)),
116    ];
117
118    let metadata =
119        ModelMetadata::new(format!("demo-model-from-{}", model_path.display()), "simple-mlp")
120            .with_custom("demo_mode", serde_json::json!(true))
121            .with_custom("config_path", serde_json::json!(model_path))
122            .with_custom("optimizer", serde_json::json!(spec.optimizer.name))
123            .with_custom("learning_rate", serde_json::json!(spec.optimizer.lr))
124            .with_custom("batch_size", serde_json::json!(spec.data.batch_size));
125
126    Ok(Model::new(metadata, params))
127}
128
#[cfg(test)]
mod tests {
    use super::super::schema::{DataConfig, ModelRef, TrainingParams};
    use super::*;
    use crate::io::{save_model, ModelFormat, SaveConfig};
    use std::collections::HashMap;
    use std::path::PathBuf;

    /// Training spec pointing at `model_path`, using Adam (lr 0.001) and batch size 8.
    fn train_spec(model_path: PathBuf) -> TrainSpec {
        TrainSpec {
            model: ModelRef { path: model_path, ..Default::default() },
            data: DataConfig {
                train: PathBuf::from("train.parquet"),
                batch_size: 8,
                ..Default::default()
            },
            optimizer: OptimSpec {
                name: "adam".to_string(),
                lr: 0.001,
                params: HashMap::new(),
            },
            lora: None,
            quantize: None,
            merge: None,
            training: TrainingParams::default(),
            publish: None,
        }
    }

    #[test]
    fn test_build_optimizer_adam() {
        let mut params = HashMap::new();
        params.insert("beta1".to_string(), serde_json::json!(0.9));
        params.insert("beta2".to_string(), serde_json::json!(0.999));

        let spec = OptimSpec { name: "adam".to_string(), lr: 0.001, params };

        let optimizer = build_optimizer(&spec).expect("operation should succeed");
        assert_eq!(optimizer.lr(), 0.001);
    }

    #[test]
    fn test_build_optimizer_sgd() {
        let mut params = HashMap::new();
        params.insert("momentum".to_string(), serde_json::json!(0.9));

        let spec = OptimSpec { name: "sgd".to_string(), lr: 0.01, params };

        let optimizer = build_optimizer(&spec).expect("operation should succeed");
        assert_eq!(optimizer.lr(), 0.01);
    }

    #[test]
    fn test_build_optimizer_sgd_accepts_integer_param() {
        let mut params = HashMap::new();
        params.insert("momentum".to_string(), serde_json::json!(0));

        let spec = OptimSpec { name: "sgd".to_string(), lr: 0.01, params };

        assert!(build_optimizer(&spec).is_ok());
    }

    #[test]
    fn test_build_optimizer_adamw() {
        let mut params = HashMap::new();
        params.insert("weight_decay".to_string(), serde_json::json!(0.01));

        let spec = OptimSpec { name: "adamw".to_string(), lr: 0.001, params };

        let optimizer = build_optimizer(&spec).expect("operation should succeed");
        assert_eq!(optimizer.lr(), 0.001);
    }

    #[test]
    fn test_build_optimizer_name_is_case_insensitive() {
        let spec = OptimSpec { name: "AdamW".to_string(), lr: 0.001, params: HashMap::new() };

        let optimizer = build_optimizer(&spec).expect("mixed-case name should be accepted");
        assert_eq!(optimizer.lr(), 0.001);
    }

    #[test]
    fn test_build_optimizer_unknown() {
        let spec = OptimSpec { name: "unknown".to_string(), lr: 0.001, params: HashMap::new() };

        assert!(matches!(build_optimizer(&spec), Err(Error::ConfigError(_))));
    }

    #[test]
    fn test_build_model_demo_mode() {
        // Point into a fresh temp dir so the file is guaranteed not to exist,
        // regardless of the working directory the tests run from.
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let spec = train_spec(dir.path().join("nonexistent.gguf"));

        let model = build_model(&spec).expect("operation should succeed");
        assert_eq!(model.parameters.len(), 4);
        assert!(model.get_parameter("layer1.weight").is_some());
        // Verify demo mode indicators
        assert_eq!(model.metadata.architecture, "simple-mlp");
        assert!(model.metadata.name.starts_with("demo-model"));
        assert!(model.metadata.custom.contains_key("demo_mode"));
    }

    #[test]
    fn test_build_model_loads_real_file() {
        let params = vec![
            ("embed.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], false)),
            ("attn.q".to_string(), Tensor::from_vec(vec![0.1, 0.2], false)),
            ("attn.k".to_string(), Tensor::from_vec(vec![0.3, 0.4], false)),
        ];
        let original = Model::new(ModelMetadata::new("test-transformer", "transformer"), params);

        // The temp dir (and the model saved into it) is removed when `dir` drops,
        // even if an assertion below fails.
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let model_path = dir.path().join("model.safetensors");
        save_model(&original, &model_path, &SaveConfig::new(ModelFormat::SafeTensors))
            .expect("save should succeed");

        let loaded = build_model(&train_spec(model_path)).expect("load should succeed");

        // Verify it loaded the real model, not demo mode
        assert_eq!(loaded.parameters.len(), 3);
        assert!(loaded.get_parameter("embed.weight").is_some());
        assert!(loaded.get_parameter("attn.q").is_some());
        assert!(loaded.get_parameter("attn.k").is_some());

        // Verify metadata was preserved
        assert_eq!(loaded.metadata.name, "test-transformer");
        assert_eq!(loaded.metadata.architecture, "transformer");

        // Verify gradients are enabled for training
        for (_, tensor) in &loaded.parameters {
            assert!(
                tensor.requires_grad(),
                "All parameters should have requires_grad=true for training"
            );
        }
    }

    #[test]
    fn test_build_model_adds_training_metadata() {
        let params = vec![("w".to_string(), Tensor::from_vec(vec![1.0], false))];
        let original = Model::new(ModelMetadata::new("meta-test", "linear"), params);

        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let model_path = dir.path().join("model.json");
        save_model(&original, &model_path, &SaveConfig::new(ModelFormat::Json))
            .expect("save should succeed");

        let mut spec = train_spec(model_path);
        spec.data.batch_size = 32;
        spec.optimizer =
            OptimSpec { name: "adamw".to_string(), lr: 0.0001, params: HashMap::new() };

        let loaded = build_model(&spec).expect("load should succeed");

        // Verify training metadata was added
        assert!(loaded.metadata.custom.contains_key("config_path"));
        assert!(loaded.metadata.custom.contains_key("optimizer"));
        assert!(loaded.metadata.custom.contains_key("learning_rate"));
        assert!(loaded.metadata.custom.contains_key("batch_size"));
    }
}