Skip to main content

entrenar/yaml_mode/manifest/
optimizer.rs

1//! Optimizer Configuration
2//!
3//! Contains optimizer-related configuration types for training manifests.
4
5use serde::{Deserialize, Serialize};
6
7/// Optimizer configuration
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct OptimizerConfig {
10    /// Optimizer name (sgd, adam, adamw, rmsprop, adagrad, lamb)
11    pub name: String,
12
13    /// Learning rate
14    pub lr: f64,
15
16    /// Weight decay (L2 regularization)
17    #[serde(default, skip_serializing_if = "Option::is_none")]
18    pub weight_decay: Option<f64>,
19
20    /// Adam/AdamW betas
21    #[serde(default, skip_serializing_if = "Option::is_none")]
22    pub betas: Option<Vec<f64>>,
23
24    /// Adam epsilon
25    #[serde(default, skip_serializing_if = "Option::is_none")]
26    pub eps: Option<f64>,
27
28    /// AMSGrad variant
29    #[serde(default, skip_serializing_if = "Option::is_none")]
30    pub amsgrad: Option<bool>,
31
32    /// SGD momentum
33    #[serde(default, skip_serializing_if = "Option::is_none")]
34    pub momentum: Option<f64>,
35
36    /// Nesterov momentum
37    #[serde(default, skip_serializing_if = "Option::is_none")]
38    pub nesterov: Option<bool>,
39
40    /// SGD dampening
41    #[serde(default, skip_serializing_if = "Option::is_none")]
42    pub dampening: Option<f64>,
43
44    /// RMSprop alpha
45    #[serde(default, skip_serializing_if = "Option::is_none")]
46    pub alpha: Option<f64>,
47
48    /// RMSprop centered
49    #[serde(default, skip_serializing_if = "Option::is_none")]
50    pub centered: Option<bool>,
51
52    /// Per-parameter groups
53    #[serde(default, skip_serializing_if = "Option::is_none")]
54    pub param_groups: Option<Vec<ParamGroup>>,
55}
56
57/// Per-parameter group configuration
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct ParamGroup {
60    pub params: String,
61    #[serde(default, skip_serializing_if = "Option::is_none")]
62    pub lr: Option<f64>,
63    #[serde(default, skip_serializing_if = "Option::is_none")]
64    pub weight_decay: Option<f64>,
65}