Skip to main content

pacha/recipe/
hyperparams.rs

1//! Hyperparameter types for training recipes.
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6/// Training hyperparameters.
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct Hyperparameters {
9    /// Learning rate.
10    pub learning_rate: f64,
11    /// Batch size.
12    pub batch_size: usize,
13    /// Number of epochs.
14    pub epochs: usize,
15    /// Weight decay (L2 regularization).
16    #[serde(default)]
17    pub weight_decay: f64,
18    /// Gradient clipping norm.
19    #[serde(skip_serializing_if = "Option::is_none")]
20    pub max_grad_norm: Option<f64>,
21    /// Warmup steps.
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub warmup_steps: Option<usize>,
24    /// Custom parameters.
25    #[serde(default)]
26    pub custom: HashMap<String, HyperparamValue>,
27}
28
29impl Default for Hyperparameters {
30    fn default() -> Self {
31        Self {
32            learning_rate: 1e-3,
33            batch_size: 32,
34            epochs: 10,
35            weight_decay: 0.0,
36            max_grad_norm: None,
37            warmup_steps: None,
38            custom: HashMap::new(),
39        }
40    }
41}
42
43impl Hyperparameters {
44    /// Create a new hyperparameters builder.
45    #[must_use]
46    pub fn builder() -> HyperparametersBuilder {
47        HyperparametersBuilder::new()
48    }
49
50    /// Set a custom parameter.
51    pub fn set_custom(&mut self, name: impl Into<String>, value: HyperparamValue) {
52        self.custom.insert(name.into(), value);
53    }
54
55    /// Get a custom parameter.
56    #[must_use]
57    pub fn get_custom(&self, name: &str) -> Option<&HyperparamValue> {
58        self.custom.get(name)
59    }
60}
61
62/// A hyperparameter value that can be one of several types.
63#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
64#[serde(untagged)]
65pub enum HyperparamValue {
66    /// Floating point value.
67    Float(f64),
68    /// Integer value.
69    Int(i64),
70    /// Boolean value.
71    Bool(bool),
72    /// String value.
73    String(String),
74    /// List of values.
75    List(Vec<HyperparamValue>),
76}
77
78impl HyperparamValue {
79    /// Try to get as float.
80    #[must_use]
81    #[allow(clippy::cast_precision_loss)]
82    pub fn as_float(&self) -> Option<f64> {
83        match self {
84            Self::Float(f) => Some(*f),
85            Self::Int(i) => Some(*i as f64),
86            _ => None,
87        }
88    }
89
90    /// Try to get as integer.
91    #[must_use]
92    #[allow(clippy::cast_possible_truncation)]
93    pub fn as_int(&self) -> Option<i64> {
94        match self {
95            Self::Int(i) => Some(*i),
96            Self::Float(f) => Some(*f as i64),
97            _ => None,
98        }
99    }
100
101    /// Try to get as boolean.
102    #[must_use]
103    pub fn as_bool(&self) -> Option<bool> {
104        match self {
105            Self::Bool(b) => Some(*b),
106            _ => None,
107        }
108    }
109
110    /// Try to get as string.
111    #[must_use]
112    pub fn as_string(&self) -> Option<&str> {
113        match self {
114            Self::String(s) => Some(s),
115            _ => None,
116        }
117    }
118
119    /// Try to get as list.
120    #[must_use]
121    pub fn as_list(&self) -> Option<&[HyperparamValue]> {
122        match self {
123            Self::List(l) => Some(l),
124            _ => None,
125        }
126    }
127}
128
129impl From<f64> for HyperparamValue {
130    fn from(v: f64) -> Self {
131        Self::Float(v)
132    }
133}
134
135impl From<i64> for HyperparamValue {
136    fn from(v: i64) -> Self {
137        Self::Int(v)
138    }
139}
140
141impl From<i32> for HyperparamValue {
142    fn from(v: i32) -> Self {
143        Self::Int(i64::from(v))
144    }
145}
146
147impl From<bool> for HyperparamValue {
148    fn from(v: bool) -> Self {
149        Self::Bool(v)
150    }
151}
152
153impl From<String> for HyperparamValue {
154    fn from(v: String) -> Self {
155        Self::String(v)
156    }
157}
158
159impl From<&str> for HyperparamValue {
160    fn from(v: &str) -> Self {
161        Self::String(v.to_string())
162    }
163}
164
165impl<T: Into<HyperparamValue>> From<Vec<T>> for HyperparamValue {
166    fn from(v: Vec<T>) -> Self {
167        Self::List(v.into_iter().map(Into::into).collect())
168    }
169}
170
171/// Builder for hyperparameters.
172#[derive(Debug, Default)]
173pub struct HyperparametersBuilder {
174    params: Hyperparameters,
175}
176
177impl HyperparametersBuilder {
178    /// Create a new builder.
179    #[must_use]
180    pub fn new() -> Self {
181        Self { params: Hyperparameters::default() }
182    }
183
184    /// Set learning rate.
185    #[must_use]
186    pub fn learning_rate(mut self, lr: f64) -> Self {
187        self.params.learning_rate = lr;
188        self
189    }
190
191    /// Set batch size.
192    #[must_use]
193    pub fn batch_size(mut self, size: usize) -> Self {
194        self.params.batch_size = size;
195        self
196    }
197
198    /// Set epochs.
199    #[must_use]
200    pub fn epochs(mut self, epochs: usize) -> Self {
201        self.params.epochs = epochs;
202        self
203    }
204
205    /// Set weight decay.
206    #[must_use]
207    pub fn weight_decay(mut self, decay: f64) -> Self {
208        self.params.weight_decay = decay;
209        self
210    }
211
212    /// Set max gradient norm.
213    #[must_use]
214    pub fn max_grad_norm(mut self, norm: f64) -> Self {
215        self.params.max_grad_norm = Some(norm);
216        self
217    }
218
219    /// Set warmup steps.
220    #[must_use]
221    pub fn warmup_steps(mut self, steps: usize) -> Self {
222        self.params.warmup_steps = Some(steps);
223        self
224    }
225
226    /// Add a custom parameter.
227    #[must_use]
228    pub fn custom(mut self, name: impl Into<String>, value: impl Into<HyperparamValue>) -> Self {
229        self.params.custom.insert(name.into(), value.into());
230        self
231    }
232
233    /// Build the hyperparameters.
234    #[must_use]
235    pub fn build(self) -> Hyperparameters {
236        self.params
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hyperparameters_default() {
        let params = Hyperparameters::default();
        assert!((params.learning_rate - 1e-3).abs() < 1e-10);
        assert_eq!(params.batch_size, 32);
        assert_eq!(params.epochs, 10);
        assert!(params.weight_decay.abs() < f64::EPSILON);
        assert_eq!(params.max_grad_norm, None);
        assert_eq!(params.warmup_steps, None);
        assert!(params.custom.is_empty());
    }

    #[test]
    fn test_hyperparameters_builder() {
        let params = Hyperparameters::builder()
            .learning_rate(2e-5)
            .batch_size(64)
            .epochs(3)
            .weight_decay(0.01)
            .max_grad_norm(1.0)
            .warmup_steps(100)
            .custom("dropout", 0.1)
            .build();

        assert!((params.learning_rate - 2e-5).abs() < 1e-10);
        assert_eq!(params.batch_size, 64);
        assert_eq!(params.epochs, 3);
        assert!((params.weight_decay - 0.01).abs() < 1e-10);
        assert_eq!(params.max_grad_norm, Some(1.0));
        assert_eq!(params.warmup_steps, Some(100));
        assert_eq!(params.get_custom("dropout").and_then(|v| v.as_float()), Some(0.1));
    }

    #[test]
    fn test_hyperparam_value_types() {
        // 2.5 rather than 3.14: clippy's deny-by-default `approx_constant` lint rejects
        // literals that look like an approximation of PI.
        let float_val = HyperparamValue::Float(2.5);
        assert_eq!(float_val.as_float(), Some(2.5));
        assert_eq!(float_val.as_int(), Some(2));
        assert_eq!(HyperparamValue::Float(-2.5).as_int(), Some(-2));
        assert_eq!(float_val.as_bool(), None);

        let int_val = HyperparamValue::Int(42);
        assert_eq!(int_val.as_int(), Some(42));
        assert_eq!(int_val.as_float(), Some(42.0));
        assert_eq!(int_val.as_string(), None);

        let bool_val = HyperparamValue::Bool(true);
        assert_eq!(bool_val.as_bool(), Some(true));
        assert_eq!(bool_val.as_int(), None);

        let string_val = HyperparamValue::String("test".to_string());
        assert_eq!(string_val.as_string(), Some("test"));
        assert_eq!(string_val.as_float(), None);

        let list_val =
            HyperparamValue::List(vec![HyperparamValue::Int(1), HyperparamValue::Int(2)]);
        assert_eq!(list_val.as_list().map(|l| l.len()), Some(2));
        assert!(int_val.as_list().is_none());
    }

    #[test]
    fn test_hyperparam_value_from() {
        let from_float: HyperparamValue = 2.5.into();
        assert_eq!(from_float, HyperparamValue::Float(2.5));

        let from_int: HyperparamValue = 42i64.into();
        assert_eq!(from_int, HyperparamValue::Int(42));

        let from_i32: HyperparamValue = 7i32.into();
        assert_eq!(from_i32, HyperparamValue::Int(7));

        let from_bool: HyperparamValue = true.into();
        assert_eq!(from_bool, HyperparamValue::Bool(true));

        let from_str: HyperparamValue = "test".into();
        assert_eq!(from_str, HyperparamValue::String("test".to_string()));

        let from_vec: HyperparamValue = vec![1i64, 2i64, 3i64].into();
        assert_eq!(
            from_vec,
            HyperparamValue::List(vec![
                HyperparamValue::Int(1),
                HyperparamValue::Int(2),
                HyperparamValue::Int(3),
            ])
        );
    }

    #[test]
    fn test_hyperparameters_serialization() {
        let params = Hyperparameters::builder()
            .learning_rate(1e-4)
            .batch_size(16)
            .custom("hidden_size", 768i64)
            .custom("dropout", 0.5)
            .build();

        let json = serde_json::to_string(&params).unwrap();
        // `None` optionals are skipped on output and still read back as `None`.
        assert!(!json.contains("max_grad_norm"));
        assert!(!json.contains("warmup_steps"));

        let deserialized: Hyperparameters = serde_json::from_str(&json).unwrap();

        assert!((params.learning_rate - deserialized.learning_rate).abs() < 1e-10);
        assert_eq!(params.batch_size, deserialized.batch_size);
        assert_eq!(params.epochs, deserialized.epochs);
        assert!((params.weight_decay - deserialized.weight_decay).abs() < 1e-10);
        assert_eq!(deserialized.max_grad_norm, None);
        assert_eq!(deserialized.warmup_steps, None);
        assert_eq!(
            deserialized.get_custom("dropout"),
            Some(&HyperparamValue::Float(0.5))
        );
        assert_eq!(
            deserialized.get_custom("hidden_size").and_then(HyperparamValue::as_int),
            Some(768)
        );
    }

    #[test]
    fn test_set_get_custom() {
        let mut params = Hyperparameters::default();
        params.set_custom("test_param", HyperparamValue::Float(0.5));

        let value = params.get_custom("test_param");
        assert_eq!(value.and_then(|v| v.as_float()), Some(0.5));

        // A second insert under the same name replaces the first.
        params.set_custom("test_param", HyperparamValue::Bool(false));
        assert_eq!(
            params.get_custom("test_param").and_then(HyperparamValue::as_bool),
            Some(false)
        );
        assert_eq!(params.custom.len(), 1);

        assert!(params.get_custom("nonexistent").is_none());
    }
}