use super::schema::{OptimSpec, TrainSpec};
use crate::error::{Error, Result};
use crate::io::{load_model, Model, ModelMetadata};
use crate::optim::{Adam, AdamW, Optimizer, SGD};
use crate::Tensor;
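
// JSON keys for optional optimizer hyperparameters carried in `OptimSpec::params`.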
const PARAM_MOMENTUM: &str = "momentum";
const PARAM_BETA1: &str = "beta1";
const PARAM_BETA2: &str = "beta2";
const PARAM_EPS: &str = "eps";
const PARAM_WEIGHT_DECAY: &str = "weight_decay";
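
/// Looks up a numeric hyperparameter in `spec.params`, falling back to
/// `default` when the key is absent or not a JSON number. Factors out the
/// repeated lookup-and-cast pattern used by every optimizer arm below.
fn f32_param(spec: &OptimSpec, key: &str, default: f64) -> f32 {
    spec.params.get(key).and_then(serde_json::Value::as_f64).unwrap_or(default) as f32
}

/// Builds a boxed [`Optimizer`] from an [`OptimSpec`], dispatching on the
/// case-insensitive optimizer name. Missing hyperparameters fall back to
/// conventional defaults (e.g. `beta1 = 0.9`, `beta2 = 0.999`, `eps = 1e-8`).
///
/// Illustrative sketch (field values are placeholders; marked `ignore` so it
/// is not compiled as a doctest):
///
/// ```ignore
/// let spec = OptimSpec {
///     name: "adamw".to_string(),
///     lr: 1e-3,
///     params: std::collections::HashMap::new(),
/// };
/// let optimizer = build_optimizer(&spec)?;
/// ```
///
/// # Errors
///
/// Returns [`Error::ConfigError`] for an unrecognized optimizer name.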
pub fn build_optimizer(spec: &OptimSpec) -> Result<Box<dyn Optimizer>> {
    match spec.name.to_lowercase().as_str() {
        "sgd" => {
            let momentum = f32_param(spec, PARAM_MOMENTUM, 0.0);
            Ok(Box::new(SGD::new(spec.lr, momentum)))
        }
        "adam" => {
            let beta1 = f32_param(spec, PARAM_BETA1, 0.9);
            let beta2 = f32_param(spec, PARAM_BETA2, 0.999);
            let eps = f32_param(spec, PARAM_EPS, 1e-8);
            Ok(Box::new(Adam::new(spec.lr, beta1, beta2, eps)))
        }
        "adamw" => {
            let beta1 = f32_param(spec, PARAM_BETA1, 0.9);
            let beta2 = f32_param(spec, PARAM_BETA2, 0.999);
            let eps = f32_param(spec, PARAM_EPS, 1e-8);
            let weight_decay = f32_param(spec, PARAM_WEIGHT_DECAY, 0.01);
            Ok(Box::new(AdamW::new(spec.lr, beta1, beta2, eps, weight_decay)))
        }
        name => Err(Error::ConfigError(format!(
            "Unknown optimizer: {name}. Supported: sgd, adam, adamw"
        ))),
    }
}
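
/// Builds a [`Model`] for training from a [`TrainSpec`]. If the file at
/// `spec.model.path` exists it is loaded, its parameters are switched to
/// `requires_grad = true`, and training metadata (optimizer name, learning
/// rate, batch size) is attached. Otherwise a tiny demo MLP is constructed
/// so the rest of the pipeline can still be exercised.
///
/// Illustrative sketch (assumes `spec` was already parsed from a training
/// config; marked `ignore` so it is not compiled as a doctest):
///
/// ```ignore
/// let model = build_model(&spec)?;
/// println!("training {} parameters", model.parameters.len());
/// ```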
pub fn build_model(spec: &TrainSpec) -> Result<Model> {
let model_path = &spec.model.path;
if model_path.exists() {
println!("Loading model from: {}", model_path.display());
let mut model = load_model(model_path)?;
model.metadata = model
.metadata
.with_custom("config_path", serde_json::json!(model_path))
.with_custom("optimizer", serde_json::json!(spec.optimizer.name))
.with_custom("learning_rate", serde_json::json!(spec.optimizer.lr))
.with_custom("batch_size", serde_json::json!(spec.data.batch_size));
for (_, tensor) in &mut model.parameters {
tensor.set_requires_grad(true);
}
println!(
"Loaded model '{}' with {} parameters",
model.metadata.name,
model.parameters.len()
);
return Ok(model);
}
eprintln!(
"Warning: Model file not found at '{}', using demo mode (simple MLP)",
model_path.display()
);
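    // Demo mode: placeholder parameters for a tiny two-layer MLP, created
    // with gradients enabled so the training loop has something to optimize.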
let params = vec![
("layer1.weight".to_string(), Tensor::from_vec(vec![0.1, 0.2, 0.3, 0.4], true)),
("layer1.bias".to_string(), Tensor::from_vec(vec![0.01, 0.02], true)),
("layer2.weight".to_string(), Tensor::from_vec(vec![0.5, 0.6], true)),
("layer2.bias".to_string(), Tensor::from_vec(vec![0.1], true)),
];
let metadata =
ModelMetadata::new(format!("demo-model-from-{}", model_path.display()), "simple-mlp")
.with_custom("demo_mode", serde_json::json!(true))
.with_custom("config_path", serde_json::json!(model_path))
.with_custom("optimizer", serde_json::json!(spec.optimizer.name))
.with_custom("learning_rate", serde_json::json!(spec.optimizer.lr))
.with_custom("batch_size", serde_json::json!(spec.data.batch_size));
Ok(Model::new(metadata, params))
}
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
#[test]
fn test_build_optimizer_adam() {
let mut params = std::collections::HashMap::new();
params.insert("beta1".to_string(), serde_json::json!(0.9));
params.insert("beta2".to_string(), serde_json::json!(0.999));
let spec = OptimSpec { name: "adam".to_string(), lr: 0.001, params };
        let optimizer = build_optimizer(&spec).expect("adam optimizer should build");
assert_eq!(optimizer.lr(), 0.001);
}
#[test]
fn test_build_optimizer_sgd() {
let mut params = std::collections::HashMap::new();
params.insert("momentum".to_string(), serde_json::json!(0.9));
let spec = OptimSpec { name: "sgd".to_string(), lr: 0.01, params };
        let optimizer = build_optimizer(&spec).expect("sgd optimizer should build");
assert_eq!(optimizer.lr(), 0.01);
}
#[test]
fn test_build_optimizer_adamw() {
let mut params = std::collections::HashMap::new();
params.insert("weight_decay".to_string(), serde_json::json!(0.01));
let spec = OptimSpec { name: "adamw".to_string(), lr: 0.001, params };
        let optimizer = build_optimizer(&spec).expect("adamw optimizer should build");
assert_eq!(optimizer.lr(), 0.001);
}
#[test]
fn test_build_optimizer_unknown() {
let spec = OptimSpec {
name: "unknown".to_string(),
lr: 0.001,
params: std::collections::HashMap::new(),
};
let result = build_optimizer(&spec);
assert!(result.is_err());
}
#[test]
fn test_build_model_demo_mode() {
use super::super::schema::{DataConfig, ModelRef, TrainSpec, TrainingParams};
let spec = TrainSpec {
model: ModelRef { path: PathBuf::from("nonexistent.gguf"), ..Default::default() },
data: DataConfig {
train: PathBuf::from("train.parquet"),
batch_size: 8,
..Default::default()
},
optimizer: OptimSpec {
name: "adam".to_string(),
lr: 0.001,
params: std::collections::HashMap::new(),
},
lora: None,
quantize: None,
merge: None,
training: TrainingParams::default(),
publish: None,
};
        let model = build_model(&spec).expect("demo-mode model should build");
assert_eq!(model.parameters.len(), 4);
assert!(model.get_parameter("layer1.weight").is_some());
assert_eq!(model.metadata.architecture, "simple-mlp");
assert!(model.metadata.name.starts_with("demo-model"));
}
#[test]
fn test_build_model_loads_real_file() {
use super::super::schema::{DataConfig, ModelRef, TrainSpec, TrainingParams};
use crate::io::{save_model, ModelFormat, SaveConfig};
use tempfile::NamedTempFile;
let params = vec![
("embed.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], false)),
("attn.q".to_string(), Tensor::from_vec(vec![0.1, 0.2], false)),
("attn.k".to_string(), Tensor::from_vec(vec![0.3, 0.4], false)),
];
let original = Model::new(ModelMetadata::new("test-transformer", "transformer"), params);
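        // `NamedTempFile` has no extension, so derive a sibling `.safetensors`
        // path for the save; it is removed manually at the end of the test.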
let temp_file = NamedTempFile::new().expect("temp file creation should succeed");
let temp_path = temp_file.path().with_extension("safetensors");
let config = SaveConfig::new(ModelFormat::SafeTensors);
save_model(&original, &temp_path, &config).expect("save should succeed");
let spec = TrainSpec {
model: ModelRef { path: temp_path.clone(), ..Default::default() },
data: DataConfig {
train: PathBuf::from("train.parquet"),
batch_size: 8,
..Default::default()
},
optimizer: OptimSpec {
name: "adam".to_string(),
lr: 0.001,
params: std::collections::HashMap::new(),
},
lora: None,
quantize: None,
merge: None,
training: TrainingParams::default(),
publish: None,
};
let loaded = build_model(&spec).expect("load should succeed");
assert_eq!(loaded.parameters.len(), 3);
assert!(loaded.get_parameter("embed.weight").is_some());
assert!(loaded.get_parameter("attn.q").is_some());
assert!(loaded.get_parameter("attn.k").is_some());
assert_eq!(loaded.metadata.name, "test-transformer");
assert_eq!(loaded.metadata.architecture, "transformer");
for (_, tensor) in &loaded.parameters {
assert!(
tensor.requires_grad(),
"All parameters should have requires_grad=true for training"
);
}
std::fs::remove_file(temp_path).ok();
}
#[test]
fn test_build_model_adds_training_metadata() {
use super::super::schema::{DataConfig, ModelRef, TrainSpec, TrainingParams};
use crate::io::{save_model, ModelFormat, SaveConfig};
use tempfile::NamedTempFile;
let params = vec![("w".to_string(), Tensor::from_vec(vec![1.0], false))];
let original = Model::new(ModelMetadata::new("meta-test", "linear"), params);
let temp_file = NamedTempFile::new().expect("temp file creation should succeed");
let temp_path = temp_file.path().with_extension("json");
let config = SaveConfig::new(ModelFormat::Json);
save_model(&original, &temp_path, &config).expect("save should succeed");
let spec = TrainSpec {
model: ModelRef { path: temp_path.clone(), ..Default::default() },
data: DataConfig {
train: PathBuf::from("train.parquet"),
batch_size: 32,
..Default::default()
},
optimizer: OptimSpec {
name: "adamw".to_string(),
lr: 0.0001,
params: std::collections::HashMap::new(),
},
lora: None,
quantize: None,
merge: None,
training: TrainingParams::default(),
publish: None,
};
let loaded = build_model(&spec).expect("load should succeed");
assert!(loaded.metadata.custom.contains_key("optimizer"));
assert!(loaded.metadata.custom.contains_key("learning_rate"));
assert!(loaded.metadata.custom.contains_key("batch_size"));
std::fs::remove_file(temp_path).ok();
}
}