1use super::schema::{OptimSpec, TrainSpec};
4use crate::error::{Error, Result};
5use crate::io::{load_model, Model, ModelMetadata};
6use crate::optim::{Adam, AdamW, Optimizer, SGD};
7use crate::Tensor;
8
/// Key in `OptimSpec::params` that `build_optimizer` reads for the SGD momentum.
const PARAM_MOMENTUM: &str = "momentum";
/// Key in `OptimSpec::params` read for the `beta1` argument of Adam and AdamW.
const PARAM_BETA1: &str = "beta1";
/// Key in `OptimSpec::params` read for the `beta2` argument of Adam and AdamW.
const PARAM_BETA2: &str = "beta2";
/// Key in `OptimSpec::params` read for the `eps` argument of Adam and AdamW.
const PARAM_EPS: &str = "eps";
/// Key in `OptimSpec::params` read for the AdamW weight-decay argument.
const PARAM_WEIGHT_DECAY: &str = "weight_decay";
15
16pub fn build_optimizer(spec: &OptimSpec) -> Result<Box<dyn Optimizer>> {
18 match spec.name.to_lowercase().as_str() {
19 "sgd" => {
20 let momentum =
21 spec.params.get(PARAM_MOMENTUM).and_then(serde_json::Value::as_f64).unwrap_or(0.0)
22 as f32;
23
24 Ok(Box::new(SGD::new(spec.lr, momentum)))
25 }
26 "adam" => {
27 let beta1 =
28 spec.params.get(PARAM_BETA1).and_then(serde_json::Value::as_f64).unwrap_or(0.9)
29 as f32;
30
31 let beta2 =
32 spec.params.get(PARAM_BETA2).and_then(serde_json::Value::as_f64).unwrap_or(0.999)
33 as f32;
34
35 let eps = spec.params.get(PARAM_EPS).and_then(serde_json::Value::as_f64).unwrap_or(1e-8)
36 as f32;
37
38 Ok(Box::new(Adam::new(spec.lr, beta1, beta2, eps)))
39 }
40 "adamw" => {
41 let beta1 =
42 spec.params.get(PARAM_BETA1).and_then(serde_json::Value::as_f64).unwrap_or(0.9)
43 as f32;
44
45 let beta2 =
46 spec.params.get(PARAM_BETA2).and_then(serde_json::Value::as_f64).unwrap_or(0.999)
47 as f32;
48
49 let eps = spec.params.get(PARAM_EPS).and_then(serde_json::Value::as_f64).unwrap_or(1e-8)
50 as f32;
51
52 let weight_decay = spec
53 .params
54 .get(PARAM_WEIGHT_DECAY)
55 .and_then(serde_json::Value::as_f64)
56 .unwrap_or(0.01) as f32;
57
58 Ok(Box::new(AdamW::new(spec.lr, beta1, beta2, eps, weight_decay)))
59 }
60 name => Err(Error::ConfigError(format!(
61 "Unknown optimizer: {name}. Supported: sgd, adam, adamw"
62 ))),
63 }
64}
65
66pub fn build_model(spec: &TrainSpec) -> Result<Model> {
76 let model_path = &spec.model.path;
77
78 if model_path.exists() {
80 println!("Loading model from: {}", model_path.display());
81 let mut model = load_model(model_path)?;
82
83 model.metadata = model
85 .metadata
86 .with_custom("config_path", serde_json::json!(model_path))
87 .with_custom("optimizer", serde_json::json!(spec.optimizer.name))
88 .with_custom("learning_rate", serde_json::json!(spec.optimizer.lr))
89 .with_custom("batch_size", serde_json::json!(spec.data.batch_size));
90
91 for (_, tensor) in &mut model.parameters {
93 tensor.set_requires_grad(true);
94 }
95
96 println!(
97 "Loaded model '{}' with {} parameters",
98 model.metadata.name,
99 model.parameters.len()
100 );
101
102 return Ok(model);
103 }
104
105 eprintln!(
107 "Warning: Model file not found at '{}', using demo mode (simple MLP)",
108 model_path.display()
109 );
110
111 let params = vec![
112 ("layer1.weight".to_string(), Tensor::from_vec(vec![0.1, 0.2, 0.3, 0.4], true)),
113 ("layer1.bias".to_string(), Tensor::from_vec(vec![0.01, 0.02], true)),
114 ("layer2.weight".to_string(), Tensor::from_vec(vec![0.5, 0.6], true)),
115 ("layer2.bias".to_string(), Tensor::from_vec(vec![0.1], true)),
116 ];
117
118 let metadata =
119 ModelMetadata::new(format!("demo-model-from-{}", model_path.display()), "simple-mlp")
120 .with_custom("demo_mode", serde_json::json!(true))
121 .with_custom("config_path", serde_json::json!(model_path))
122 .with_custom("optimizer", serde_json::json!(spec.optimizer.name))
123 .with_custom("learning_rate", serde_json::json!(spec.optimizer.lr))
124 .with_custom("batch_size", serde_json::json!(spec.data.batch_size));
125
126 Ok(Model::new(metadata, params))
127}
128
#[cfg(test)]
mod tests {
    use super::super::schema::{DataConfig, ModelRef, TrainingParams};
    use super::*;
    use crate::io::{save_model, ModelFormat, SaveConfig};
    use std::collections::HashMap;
    use std::path::{Path, PathBuf};

    /// Baseline spec shared by the `build_model` tests: Adam, lr 0.001,
    /// batch size 8, pointing at `model_path`. Tests override fields as needed.
    fn train_spec_for(model_path: &Path) -> TrainSpec {
        TrainSpec {
            model: ModelRef { path: model_path.to_path_buf(), ..Default::default() },
            data: DataConfig {
                train: PathBuf::from("train.parquet"),
                batch_size: 8,
                ..Default::default()
            },
            optimizer: OptimSpec {
                name: "adam".to_string(),
                lr: 0.001,
                params: HashMap::new(),
            },
            lora: None,
            quantize: None,
            merge: None,
            training: TrainingParams::default(),
            publish: None,
        }
    }

    #[test]
    fn test_build_optimizer_adam() {
        let mut params = HashMap::new();
        params.insert("beta1".to_string(), serde_json::json!(0.9));
        params.insert("beta2".to_string(), serde_json::json!(0.999));

        let spec = OptimSpec { name: "adam".to_string(), lr: 0.001, params };

        let optimizer = build_optimizer(&spec).expect("operation should succeed");
        assert_eq!(optimizer.lr(), 0.001);
    }

    #[test]
    fn test_build_optimizer_sgd() {
        let mut params = HashMap::new();
        params.insert("momentum".to_string(), serde_json::json!(0.9));

        let spec = OptimSpec { name: "sgd".to_string(), lr: 0.01, params };

        let optimizer = build_optimizer(&spec).expect("operation should succeed");
        assert_eq!(optimizer.lr(), 0.01);
    }

    #[test]
    fn test_build_optimizer_adamw() {
        let mut params = HashMap::new();
        params.insert("weight_decay".to_string(), serde_json::json!(0.01));

        let spec = OptimSpec { name: "adamw".to_string(), lr: 0.001, params };

        let optimizer = build_optimizer(&spec).expect("operation should succeed");
        assert_eq!(optimizer.lr(), 0.001);
    }

    #[test]
    fn test_build_optimizer_name_is_case_insensitive() {
        let spec = OptimSpec { name: "AdamW".to_string(), lr: 0.001, params: HashMap::new() };

        assert!(build_optimizer(&spec).is_ok());
    }

    #[test]
    fn test_build_optimizer_unknown() {
        let spec = OptimSpec { name: "unknown".to_string(), lr: 0.001, params: HashMap::new() };

        let err = build_optimizer(&spec).err().expect("unknown optimizer must be rejected");
        assert!(
            matches!(err, Error::ConfigError(ref message) if message.contains("unknown")),
            "error should be a ConfigError naming the rejected optimizer"
        );
    }

    #[test]
    fn test_build_model_demo_mode() {
        let spec = train_spec_for(Path::new("nonexistent.gguf"));

        let model = build_model(&spec).expect("operation should succeed");
        assert_eq!(model.parameters.len(), 4);
        assert!(model.get_parameter("layer1.weight").is_some());
        assert_eq!(model.metadata.architecture, "simple-mlp");
        assert!(model.metadata.name.starts_with("demo-model"));
        assert!(model.metadata.custom.contains_key("demo_mode"));
    }

    #[test]
    fn test_build_model_loads_real_file() {
        let params = vec![
            ("embed.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], false)),
            ("attn.q".to_string(), Tensor::from_vec(vec![0.1, 0.2], false)),
            ("attn.k".to_string(), Tensor::from_vec(vec![0.3, 0.4], false)),
        ];
        let original = Model::new(ModelMetadata::new("test-transformer", "transformer"), params);

        // The directory (and the model file inside it) is removed when
        // `model_dir` drops, including when an assertion below panics.
        let model_dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let model_path = model_dir.path().join("model.safetensors");

        let config = SaveConfig::new(ModelFormat::SafeTensors);
        save_model(&original, &model_path, &config).expect("save should succeed");

        let spec = train_spec_for(&model_path);
        let loaded = build_model(&spec).expect("load should succeed");

        assert_eq!(loaded.parameters.len(), 3);
        assert!(loaded.get_parameter("embed.weight").is_some());
        assert!(loaded.get_parameter("attn.q").is_some());
        assert!(loaded.get_parameter("attn.k").is_some());

        assert_eq!(loaded.metadata.name, "test-transformer");
        assert_eq!(loaded.metadata.architecture, "transformer");

        // Saved with requires_grad=false; build_model must re-enable it.
        for (_, tensor) in &loaded.parameters {
            assert!(
                tensor.requires_grad(),
                "All parameters should have requires_grad=true for training"
            );
        }
    }

    #[test]
    fn test_build_model_adds_training_metadata() {
        let params = vec![("w".to_string(), Tensor::from_vec(vec![1.0], false))];
        let original = Model::new(ModelMetadata::new("meta-test", "linear"), params);

        // Cleaned up automatically when `model_dir` drops.
        let model_dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let model_path = model_dir.path().join("model.json");

        let config = SaveConfig::new(ModelFormat::Json);
        save_model(&original, &model_path, &config).expect("save should succeed");

        let mut spec = train_spec_for(&model_path);
        spec.optimizer.name = "adamw".to_string();
        spec.optimizer.lr = 0.0001;
        spec.data.batch_size = 32;

        let loaded = build_model(&spec).expect("load should succeed");

        assert!(loaded.metadata.custom.contains_key("config_path"));
        assert!(loaded.metadata.custom.contains_key("optimizer"));
        assert!(loaded.metadata.custom.contains_key("learning_rate"));
        assert!(loaded.metadata.custom.contains_key("batch_size"));
    }
}