// pacha/recipe/hyperparams.rs
use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct Hyperparameters {
9 pub learning_rate: f64,
11 pub batch_size: usize,
13 pub epochs: usize,
15 #[serde(default)]
17 pub weight_decay: f64,
18 #[serde(skip_serializing_if = "Option::is_none")]
20 pub max_grad_norm: Option<f64>,
21 #[serde(skip_serializing_if = "Option::is_none")]
23 pub warmup_steps: Option<usize>,
24 #[serde(default)]
26 pub custom: HashMap<String, HyperparamValue>,
27}
28
29impl Default for Hyperparameters {
30 fn default() -> Self {
31 Self {
32 learning_rate: 1e-3,
33 batch_size: 32,
34 epochs: 10,
35 weight_decay: 0.0,
36 max_grad_norm: None,
37 warmup_steps: None,
38 custom: HashMap::new(),
39 }
40 }
41}
42
43impl Hyperparameters {
44 #[must_use]
46 pub fn builder() -> HyperparametersBuilder {
47 HyperparametersBuilder::new()
48 }
49
50 pub fn set_custom(&mut self, name: impl Into<String>, value: HyperparamValue) {
52 self.custom.insert(name.into(), value);
53 }
54
55 #[must_use]
57 pub fn get_custom(&self, name: &str) -> Option<&HyperparamValue> {
58 self.custom.get(name)
59 }
60}
61
62#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
64#[serde(untagged)]
65pub enum HyperparamValue {
66 Float(f64),
68 Int(i64),
70 Bool(bool),
72 String(String),
74 List(Vec<HyperparamValue>),
76}
77
78impl HyperparamValue {
79 #[must_use]
81 #[allow(clippy::cast_precision_loss)]
82 pub fn as_float(&self) -> Option<f64> {
83 match self {
84 Self::Float(f) => Some(*f),
85 Self::Int(i) => Some(*i as f64),
86 _ => None,
87 }
88 }
89
90 #[must_use]
92 #[allow(clippy::cast_possible_truncation)]
93 pub fn as_int(&self) -> Option<i64> {
94 match self {
95 Self::Int(i) => Some(*i),
96 Self::Float(f) => Some(*f as i64),
97 _ => None,
98 }
99 }
100
101 #[must_use]
103 pub fn as_bool(&self) -> Option<bool> {
104 match self {
105 Self::Bool(b) => Some(*b),
106 _ => None,
107 }
108 }
109
110 #[must_use]
112 pub fn as_string(&self) -> Option<&str> {
113 match self {
114 Self::String(s) => Some(s),
115 _ => None,
116 }
117 }
118
119 #[must_use]
121 pub fn as_list(&self) -> Option<&[HyperparamValue]> {
122 match self {
123 Self::List(l) => Some(l),
124 _ => None,
125 }
126 }
127}
128
129impl From<f64> for HyperparamValue {
130 fn from(v: f64) -> Self {
131 Self::Float(v)
132 }
133}
134
135impl From<i64> for HyperparamValue {
136 fn from(v: i64) -> Self {
137 Self::Int(v)
138 }
139}
140
141impl From<i32> for HyperparamValue {
142 fn from(v: i32) -> Self {
143 Self::Int(i64::from(v))
144 }
145}
146
147impl From<bool> for HyperparamValue {
148 fn from(v: bool) -> Self {
149 Self::Bool(v)
150 }
151}
152
153impl From<String> for HyperparamValue {
154 fn from(v: String) -> Self {
155 Self::String(v)
156 }
157}
158
159impl From<&str> for HyperparamValue {
160 fn from(v: &str) -> Self {
161 Self::String(v.to_string())
162 }
163}
164
165impl<T: Into<HyperparamValue>> From<Vec<T>> for HyperparamValue {
166 fn from(v: Vec<T>) -> Self {
167 Self::List(v.into_iter().map(Into::into).collect())
168 }
169}
170
171#[derive(Debug, Default)]
173pub struct HyperparametersBuilder {
174 params: Hyperparameters,
175}
176
177impl HyperparametersBuilder {
178 #[must_use]
180 pub fn new() -> Self {
181 Self { params: Hyperparameters::default() }
182 }
183
184 #[must_use]
186 pub fn learning_rate(mut self, lr: f64) -> Self {
187 self.params.learning_rate = lr;
188 self
189 }
190
191 #[must_use]
193 pub fn batch_size(mut self, size: usize) -> Self {
194 self.params.batch_size = size;
195 self
196 }
197
198 #[must_use]
200 pub fn epochs(mut self, epochs: usize) -> Self {
201 self.params.epochs = epochs;
202 self
203 }
204
205 #[must_use]
207 pub fn weight_decay(mut self, decay: f64) -> Self {
208 self.params.weight_decay = decay;
209 self
210 }
211
212 #[must_use]
214 pub fn max_grad_norm(mut self, norm: f64) -> Self {
215 self.params.max_grad_norm = Some(norm);
216 self
217 }
218
219 #[must_use]
221 pub fn warmup_steps(mut self, steps: usize) -> Self {
222 self.params.warmup_steps = Some(steps);
223 self
224 }
225
226 #[must_use]
228 pub fn custom(mut self, name: impl Into<String>, value: impl Into<HyperparamValue>) -> Self {
229 self.params.custom.insert(name.into(), value.into());
230 self
231 }
232
233 #[must_use]
235 pub fn build(self) -> Hyperparameters {
236 self.params
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` yields the baseline learning rate, batch size and epochs.
    #[test]
    fn test_hyperparameters_default() {
        let params = Hyperparameters::default();
        assert!((params.learning_rate - 1e-3).abs() < 1e-10);
        assert_eq!(params.batch_size, 32);
        assert_eq!(params.epochs, 10);
    }

    /// Every builder setter lands in the built value.
    #[test]
    fn test_hyperparameters_builder() {
        let params = Hyperparameters::builder()
            .learning_rate(2e-5)
            .batch_size(64)
            .epochs(3)
            .weight_decay(0.01)
            .max_grad_norm(1.0)
            .warmup_steps(100)
            .custom("dropout", 0.1)
            .build();

        assert!((params.learning_rate - 2e-5).abs() < 1e-10);
        assert_eq!(params.batch_size, 64);
        assert_eq!(params.epochs, 3);
        assert!((params.weight_decay - 0.01).abs() < 1e-10);
        assert_eq!(params.max_grad_norm, Some(1.0));
        assert_eq!(params.warmup_steps, Some(100));
        assert_eq!(params.get_custom("dropout").and_then(|v| v.as_float()), Some(0.1));
    }

    /// Each accessor returns its own variant; the numeric ones also convert.
    #[test]
    fn test_hyperparam_value_types() {
        // 2.5 rather than 3.14: the latter trips clippy's `approx_constant`.
        let float_val = HyperparamValue::Float(2.5);
        assert_eq!(float_val.as_float(), Some(2.5));
        // The fractional part is discarded by the float -> int conversion.
        assert_eq!(float_val.as_int(), Some(2));

        let int_val = HyperparamValue::Int(42);
        assert_eq!(int_val.as_int(), Some(42));
        assert_eq!(int_val.as_float(), Some(42.0));

        let bool_val = HyperparamValue::Bool(true);
        assert_eq!(bool_val.as_bool(), Some(true));

        let string_val = HyperparamValue::String("test".to_string());
        assert_eq!(string_val.as_string(), Some("test"));

        let list_val =
            HyperparamValue::List(vec![HyperparamValue::Int(1), HyperparamValue::Int(2)]);
        assert_eq!(list_val.as_list().map(|l| l.len()), Some(2));
    }

    /// Each `From` impl selects the matching variant.
    #[test]
    fn test_hyperparam_value_from() {
        let from_float: HyperparamValue = 2.5.into();
        assert!(matches!(from_float, HyperparamValue::Float(_)));

        let from_int: HyperparamValue = 42i64.into();
        assert!(matches!(from_int, HyperparamValue::Int(_)));

        let from_bool: HyperparamValue = true.into();
        assert!(matches!(from_bool, HyperparamValue::Bool(_)));

        let from_str: HyperparamValue = "test".into();
        assert!(matches!(from_str, HyperparamValue::String(_)));

        let from_vec: HyperparamValue = vec![1i64, 2i64, 3i64].into();
        assert!(matches!(from_vec, HyperparamValue::List(_)));
    }

    /// A JSON round trip preserves the required scalar fields.
    #[test]
    fn test_hyperparameters_serialization() {
        let params = Hyperparameters::builder()
            .learning_rate(1e-4)
            .batch_size(16)
            .custom("hidden_size", 768i64)
            .build();

        let json = serde_json::to_string(&params).unwrap();
        let deserialized: Hyperparameters = serde_json::from_str(&json).unwrap();

        assert!((params.learning_rate - deserialized.learning_rate).abs() < 1e-10);
        assert_eq!(params.batch_size, deserialized.batch_size);
    }

    /// `set_custom` stores a value that `get_custom` can read back.
    #[test]
    fn test_set_get_custom() {
        let mut params = Hyperparameters::default();
        params.set_custom("test_param", HyperparamValue::Float(0.5));

        let value = params.get_custom("test_param");
        assert_eq!(value.and_then(|v| v.as_float()), Some(0.5));

        assert!(params.get_custom("nonexistent").is_none());
    }
}