Skip to main content

entrenar/monitor/
params.rs

1//! Parameter Logging API (GH-73)
2//!
3//! Provides structured parameter tracking for training experiments.
4//! Parameters are stored as typed key-value pairs with JSON serialization
5//! and diff support for comparing experiment configurations.
6//!
7//! # Example
8//!
9//! ```
10//! use entrenar::monitor::params::{ParamLogger, ParamValue};
11//!
12//! let mut logger = ParamLogger::new();
13//! logger.log_param("learning_rate", 1e-4_f64);
14//! logger.log_param("epochs", 10_i64);
15//! logger.log_param("model", "llama-7b");
16//! logger.log_param("use_lora", true);
17//!
18//! assert_eq!(
19//!     logger.get_param("learning_rate"),
20//!     Some(&ParamValue::Float(1e-4))
21//! );
22//!
23//! let json = logger.to_json();
24//! assert!(json.contains("learning_rate"));
25//! ```
26
27use serde::{Deserialize, Serialize};
28use std::collections::HashMap;
29
30// =============================================================================
31// ParamValue
32// =============================================================================
33
34/// A typed parameter value supporting common training hyperparameter types.
35#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
36pub enum ParamValue {
37    /// String parameter (e.g., model name, optimizer type)
38    String(String),
39    /// Floating-point parameter (e.g., learning rate, weight decay)
40    Float(f64),
41    /// Integer parameter (e.g., epochs, batch size, seed)
42    Int(i64),
43    /// Boolean parameter (e.g., use_lora, freeze_base)
44    Bool(bool),
45}
46
47impl ParamValue {
48    /// Returns the type name of this value as a static string.
49    pub fn type_name(&self) -> &'static str {
50        match self {
51            ParamValue::String(_) => "string",
52            ParamValue::Float(_) => "float",
53            ParamValue::Int(_) => "int",
54            ParamValue::Bool(_) => "bool",
55        }
56    }
57}
58
59impl std::fmt::Display for ParamValue {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        match self {
62            ParamValue::String(s) => write!(f, "{s}"),
63            ParamValue::Float(v) => write!(f, "{v}"),
64            ParamValue::Int(v) => write!(f, "{v}"),
65            ParamValue::Bool(v) => write!(f, "{v}"),
66        }
67    }
68}
69
70// -- From impls for ergonomic `log_param` calls --
71
72impl From<&str> for ParamValue {
73    fn from(s: &str) -> Self {
74        ParamValue::String(s.to_string())
75    }
76}
77
78impl From<String> for ParamValue {
79    fn from(s: String) -> Self {
80        ParamValue::String(s)
81    }
82}
83
84impl From<f64> for ParamValue {
85    fn from(v: f64) -> Self {
86        ParamValue::Float(v)
87    }
88}
89
90impl From<f32> for ParamValue {
91    fn from(v: f32) -> Self {
92        ParamValue::Float(f64::from(v))
93    }
94}
95
96impl From<i64> for ParamValue {
97    fn from(v: i64) -> Self {
98        ParamValue::Int(v)
99    }
100}
101
102impl From<i32> for ParamValue {
103    fn from(v: i32) -> Self {
104        ParamValue::Int(i64::from(v))
105    }
106}
107
108impl From<bool> for ParamValue {
109    fn from(v: bool) -> Self {
110        ParamValue::Bool(v)
111    }
112}
113
114// =============================================================================
115// ParamDiff
116// =============================================================================
117
118/// Result of comparing two `ParamLogger` instances.
119///
120/// Captures which parameters were changed, added, or removed between
121/// two experiment configurations.
122#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
123pub struct ParamDiff {
124    /// Parameters present in both loggers but with different values.
125    /// Maps key -> (old_value, new_value).
126    pub changed: HashMap<String, (ParamValue, ParamValue)>,
127    /// Parameters present only in the *other* logger (new additions).
128    pub added: HashMap<String, ParamValue>,
129    /// Parameters present only in *self* (removed in other).
130    pub removed: HashMap<String, ParamValue>,
131}
132
133impl ParamDiff {
134    /// Returns `true` if there are no differences.
135    pub fn is_empty(&self) -> bool {
136        self.changed.is_empty() && self.added.is_empty() && self.removed.is_empty()
137    }
138
139    /// Total number of differences (changed + added + removed).
140    pub fn len(&self) -> usize {
141        self.changed.len() + self.added.len() + self.removed.len()
142    }
143}
144
145// =============================================================================
146// ParamLogger
147// =============================================================================
148
149/// Structured parameter logger for training experiments.
150///
151/// Stores hyperparameters, configuration flags, and other experiment metadata
152/// as typed key-value pairs. Supports JSON serialization and diff comparison.
153#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct ParamLogger {
155    params: HashMap<String, ParamValue>,
156}
157
158impl ParamLogger {
159    /// Create a new, empty parameter logger.
160    pub fn new() -> Self {
161        Self { params: HashMap::new() }
162    }
163
164    /// Log a single parameter. Overwrites any existing value for the key.
165    pub fn log_param(&mut self, key: &str, value: impl Into<ParamValue>) {
166        self.params.insert(key.to_string(), value.into());
167    }
168
169    /// Log multiple parameters at once. Overwrites existing values.
170    pub fn log_params(&mut self, params: HashMap<String, ParamValue>) {
171        self.params.extend(params);
172    }
173
174    /// Retrieve a parameter by key.
175    pub fn get_param(&self, key: &str) -> Option<&ParamValue> {
176        self.params.get(key)
177    }
178
179    /// Retrieve all parameters as a reference to the underlying map.
180    pub fn get_all_params(&self) -> &HashMap<String, ParamValue> {
181        &self.params
182    }
183
184    /// Returns the number of logged parameters.
185    pub fn len(&self) -> usize {
186        self.params.len()
187    }
188
189    /// Returns `true` if no parameters have been logged.
190    pub fn is_empty(&self) -> bool {
191        self.params.is_empty()
192    }
193
194    /// Serialize all parameters to a JSON string.
195    ///
196    /// Keys are sorted for deterministic output.
197    pub fn to_json(&self) -> String {
198        // Use BTreeMap for sorted keys -> deterministic JSON
199        let sorted: std::collections::BTreeMap<&String, &ParamValue> = self.params.iter().collect();
200        serde_json::to_string_pretty(&sorted).unwrap_or_else(|e| {
201            eprintln!("ParamLogger JSON serialization failed: {e}");
202            "{}".to_string()
203        })
204    }
205
206    /// Compute the diff between `self` and `other`.
207    ///
208    /// - **changed**: keys present in both with different values
209    /// - **added**: keys in `other` but not in `self`
210    /// - **removed**: keys in `self` but not in `other`
211    pub fn diff(&self, other: &ParamLogger) -> ParamDiff {
212        let mut changed = HashMap::new();
213        let mut added = HashMap::new();
214        let mut removed = HashMap::new();
215
216        // Find changed and removed
217        for (key, self_val) in &self.params {
218            match other.params.get(key) {
219                Some(other_val) if self_val != other_val => {
220                    changed.insert(key.clone(), (self_val.clone(), other_val.clone()));
221                }
222                None => {
223                    removed.insert(key.clone(), self_val.clone());
224                }
225                _ => {} // Same value, no diff
226            }
227        }
228
229        // Find added (in other but not in self)
230        for (key, other_val) in &other.params {
231            if !self.params.contains_key(key) {
232                added.insert(key.clone(), other_val.clone());
233            }
234        }
235
236        ParamDiff { changed, added, removed }
237    }
238}
239
240impl Default for ParamLogger {
241    fn default() -> Self {
242        Self::new()
243    }
244}
245
246// =============================================================================
247// Tests
248// =============================================================================
249
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly constructed logger holds nothing.
    #[test]
    fn test_param_logger_new_is_empty() {
        let logger = ParamLogger::new();
        assert!(logger.is_empty());
        assert_eq!(logger.len(), 0);
    }

    // `&str` values are stored as `ParamValue::String`.
    #[test]
    fn test_log_param_string() {
        let mut logger = ParamLogger::new();
        logger.log_param("model", "llama-7b");
        assert_eq!(logger.get_param("model"), Some(&ParamValue::String("llama-7b".to_string())));
    }

    // `f64` values are stored as `ParamValue::Float`.
    #[test]
    fn test_log_param_float() {
        let mut logger = ParamLogger::new();
        logger.log_param("lr", 1e-4_f64);
        assert_eq!(logger.get_param("lr"), Some(&ParamValue::Float(1e-4)));
    }

    // `f32` values are widened to `f64` before being stored.
    #[test]
    fn test_log_param_f32_converts_to_f64() {
        let mut logger = ParamLogger::new();
        logger.log_param("weight_decay", 0.01_f32);
        assert_eq!(logger.get_param("weight_decay"), Some(&ParamValue::Float(f64::from(0.01_f32))));
    }

    // `i64` values are stored as `ParamValue::Int`.
    #[test]
    fn test_log_param_int() {
        let mut logger = ParamLogger::new();
        logger.log_param("epochs", 10_i64);
        assert_eq!(logger.get_param("epochs"), Some(&ParamValue::Int(10)));
    }

    // `i32` values are widened to `i64` before being stored.
    #[test]
    fn test_log_param_i32_converts_to_i64() {
        let mut logger = ParamLogger::new();
        logger.log_param("batch_size", 32_i32);
        assert_eq!(logger.get_param("batch_size"), Some(&ParamValue::Int(32)));
    }

    // `bool` values are stored as `ParamValue::Bool`.
    #[test]
    fn test_log_param_bool() {
        let mut logger = ParamLogger::new();
        logger.log_param("use_lora", true);
        assert_eq!(logger.get_param("use_lora"), Some(&ParamValue::Bool(true)));
    }

    // Owned `String` values are accepted as well as `&str`.
    #[test]
    fn test_log_param_owned_string() {
        let mut logger = ParamLogger::new();
        logger.log_param("optimizer", String::from("adamw"));
        assert_eq!(logger.get_param("optimizer"), Some(&ParamValue::String("adamw".to_string())));
    }

    // Logging the same key twice keeps only the latest value.
    #[test]
    fn test_log_param_overwrites() {
        let mut logger = ParamLogger::new();
        logger.log_param("lr", 1e-3_f64);
        logger.log_param("lr", 1e-4_f64);
        assert_eq!(logger.get_param("lr"), Some(&ParamValue::Float(1e-4)));
        assert_eq!(logger.len(), 1);
    }

    // Looking up a key that was never logged yields `None`.
    #[test]
    fn test_get_param_missing_returns_none() {
        let logger = ParamLogger::new();
        assert_eq!(logger.get_param("nonexistent"), None);
    }

    // Bulk logging stores every entry of the supplied map.
    #[test]
    fn test_log_params_bulk() {
        let mut logger = ParamLogger::new();
        let mut params = HashMap::new();
        params.insert("lr".to_string(), ParamValue::Float(1e-4));
        params.insert("epochs".to_string(), ParamValue::Int(10));
        params.insert("model".to_string(), ParamValue::String("gpt2".to_string()));
        logger.log_params(params);

        assert_eq!(logger.len(), 3);
        assert_eq!(logger.get_param("lr"), Some(&ParamValue::Float(1e-4)));
        assert_eq!(logger.get_param("epochs"), Some(&ParamValue::Int(10)));
        assert_eq!(logger.get_param("model"), Some(&ParamValue::String("gpt2".to_string())));
    }

    // The full-map accessor exposes every logged key.
    #[test]
    fn test_get_all_params() {
        let mut logger = ParamLogger::new();
        logger.log_param("a", 1_i64);
        logger.log_param("b", 2_i64);

        let all = logger.get_all_params();
        assert_eq!(all.len(), 2);
        assert!(all.contains_key("a"));
        assert!(all.contains_key("b"));
    }

    // JSON output lists keys alphabetically regardless of insertion order.
    #[test]
    fn test_to_json_deterministic() {
        let mut logger = ParamLogger::new();
        logger.log_param("z_param", 1_i64);
        logger.log_param("a_param", 2_i64);
        logger.log_param("m_param", 3_i64);

        let json = logger.to_json();
        // Keys should be sorted alphabetically
        let a_pos = json.find("a_param").expect("a_param not found");
        let m_pos = json.find("m_param").expect("m_param not found");
        let z_pos = json.find("z_param").expect("z_param not found");
        assert!(a_pos < m_pos, "a_param should come before m_param");
        assert!(m_pos < z_pos, "m_param should come before z_param");
    }

    // JSON output contains the textual form of each logged value.
    #[test]
    fn test_to_json_contains_values() {
        let mut logger = ParamLogger::new();
        logger.log_param("lr", 0.001_f64);
        logger.log_param("use_lora", true);
        logger.log_param("model", "gpt2");

        let json = logger.to_json();
        assert!(json.contains("0.001"));
        assert!(json.contains("true"));
        assert!(json.contains("gpt2"));
    }

    // An empty logger serializes to an empty JSON object.
    #[test]
    fn test_to_json_empty() {
        let logger = ParamLogger::new();
        let json = logger.to_json();
        assert_eq!(json, "{}");
    }

    // JSON produced by `to_json` parses back into the same typed values.
    #[test]
    fn test_to_json_roundtrip() {
        let mut logger = ParamLogger::new();
        logger.log_param("lr", 1e-4_f64);
        logger.log_param("epochs", 10_i64);
        logger.log_param("model", "llama");
        logger.log_param("lora", true);

        let json = logger.to_json();
        let deserialized: std::collections::BTreeMap<String, ParamValue> =
            serde_json::from_str(&json).expect("should deserialize");

        assert_eq!(deserialized.len(), 4);
        assert_eq!(deserialized.get("lr"), Some(&ParamValue::Float(1e-4)));
        assert_eq!(deserialized.get("epochs"), Some(&ParamValue::Int(10)));
        assert_eq!(deserialized.get("model"), Some(&ParamValue::String("llama".to_string())));
        assert_eq!(deserialized.get("lora"), Some(&ParamValue::Bool(true)));
    }

    // =========================================================================
    // Diff tests
    // =========================================================================

    // Two loggers with the same entries produce an empty diff.
    #[test]
    fn test_diff_identical_is_empty() {
        let mut a = ParamLogger::new();
        a.log_param("lr", 1e-4_f64);
        a.log_param("epochs", 10_i64);

        let mut b = ParamLogger::new();
        b.log_param("lr", 1e-4_f64);
        b.log_param("epochs", 10_i64);

        let diff = a.diff(&b);
        assert!(diff.is_empty());
        assert_eq!(diff.len(), 0);
    }

    // Two empty loggers produce an empty diff.
    #[test]
    fn test_diff_empty_loggers() {
        let a = ParamLogger::new();
        let b = ParamLogger::new();
        let diff = a.diff(&b);
        assert!(diff.is_empty());
    }

    // A key with different values on each side lands in `changed` as (old, new).
    #[test]
    fn test_diff_changed_values() {
        let mut a = ParamLogger::new();
        a.log_param("lr", 1e-3_f64);
        a.log_param("epochs", 10_i64);

        let mut b = ParamLogger::new();
        b.log_param("lr", 1e-4_f64);
        b.log_param("epochs", 10_i64);

        let diff = a.diff(&b);
        assert_eq!(diff.changed.len(), 1);
        assert_eq!(
            diff.changed.get("lr"),
            Some(&(ParamValue::Float(1e-3), ParamValue::Float(1e-4)))
        );
        assert!(diff.added.is_empty());
        assert!(diff.removed.is_empty());
    }

    // A key present only in the other logger lands in `added`.
    #[test]
    fn test_diff_added_params() {
        let mut a = ParamLogger::new();
        a.log_param("lr", 1e-4_f64);

        let mut b = ParamLogger::new();
        b.log_param("lr", 1e-4_f64);
        b.log_param("warmup", 100_i64);

        let diff = a.diff(&b);
        assert!(diff.changed.is_empty());
        assert_eq!(diff.added.len(), 1);
        assert_eq!(diff.added.get("warmup"), Some(&ParamValue::Int(100)));
        assert!(diff.removed.is_empty());
    }

    // A key present only in `self` lands in `removed`.
    #[test]
    fn test_diff_removed_params() {
        let mut a = ParamLogger::new();
        a.log_param("lr", 1e-4_f64);
        a.log_param("warmup", 100_i64);

        let mut b = ParamLogger::new();
        b.log_param("lr", 1e-4_f64);

        let diff = a.diff(&b);
        assert!(diff.changed.is_empty());
        assert!(diff.added.is_empty());
        assert_eq!(diff.removed.len(), 1);
        assert_eq!(diff.removed.get("warmup"), Some(&ParamValue::Int(100)));
    }

    // Changed, added, and removed keys are all reported in a single diff.
    #[test]
    fn test_diff_mixed_changes() {
        let mut a = ParamLogger::new();
        a.log_param("lr", 1e-3_f64);
        a.log_param("old_param", "remove_me");
        a.log_param("same", 42_i64);

        let mut b = ParamLogger::new();
        b.log_param("lr", 1e-4_f64);
        b.log_param("new_param", true);
        b.log_param("same", 42_i64);

        let diff = a.diff(&b);
        assert_eq!(diff.changed.len(), 1);
        assert_eq!(diff.added.len(), 1);
        assert_eq!(diff.removed.len(), 1);
        assert_eq!(diff.len(), 3);
        assert!(!diff.is_empty());

        assert!(diff.changed.contains_key("lr"));
        assert!(diff.added.contains_key("new_param"));
        assert!(diff.removed.contains_key("old_param"));
    }

    // The same number stored under a different variant counts as a change.
    #[test]
    fn test_diff_type_change_counts_as_changed() {
        let mut a = ParamLogger::new();
        a.log_param("value", 10_i64);

        let mut b = ParamLogger::new();
        b.log_param("value", 10.0_f64);

        let diff = a.diff(&b);
        assert_eq!(diff.changed.len(), 1);
        assert_eq!(
            diff.changed.get("value"),
            Some(&(ParamValue::Int(10), ParamValue::Float(10.0)))
        );
    }

    // =========================================================================
    // ParamValue tests
    // =========================================================================

    // Each variant reports its own lowercase type label.
    #[test]
    fn test_param_value_type_name() {
        assert_eq!(ParamValue::String("x".into()).type_name(), "string");
        assert_eq!(ParamValue::Float(1.0).type_name(), "float");
        assert_eq!(ParamValue::Int(1).type_name(), "int");
        assert_eq!(ParamValue::Bool(true).type_name(), "bool");
    }

    // `Display` prints the bare payload of each variant.
    #[test]
    fn test_param_value_display() {
        assert_eq!(format!("{}", ParamValue::String("hello".into())), "hello");
        // 2.5 rather than 3.14: a PI look-alike literal trips clippy's `approx_constant`.
        assert_eq!(format!("{}", ParamValue::Float(2.5)), "2.5");
        assert_eq!(format!("{}", ParamValue::Int(42)), "42");
        assert_eq!(format!("{}", ParamValue::Bool(false)), "false");
    }

    // Every variant survives a serialize/deserialize cycle unchanged.
    #[test]
    fn test_param_value_serde_roundtrip() {
        // A fixed-size array is enough here; no heap-allocated Vec is needed.
        let values = [
            ParamValue::String("test".into()),
            ParamValue::Float(1.23),
            ParamValue::Int(-5),
            ParamValue::Bool(true),
        ];
        for val in &values {
            let json = serde_json::to_string(val).expect("serialize");
            let back: ParamValue = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(&back, val);
        }
    }

    // `is_empty` and `len` agree for both an empty and a one-entry diff.
    #[test]
    fn test_param_diff_is_empty_and_len() {
        let diff =
            ParamDiff { changed: HashMap::new(), added: HashMap::new(), removed: HashMap::new() };
        assert!(diff.is_empty());
        assert_eq!(diff.len(), 0);

        let mut diff2 =
            ParamDiff { changed: HashMap::new(), added: HashMap::new(), removed: HashMap::new() };
        diff2.added.insert("x".to_string(), ParamValue::Int(1));
        assert!(!diff2.is_empty());
        assert_eq!(diff2.len(), 1);
    }

    // `Default` yields an empty logger.
    #[test]
    fn test_default_impl() {
        let logger = ParamLogger::default();
        assert!(logger.is_empty());
    }
}