1use serde::{Deserialize, Serialize};
28use std::collections::HashMap;
29
/// A single logged parameter value: a string, float, integer, or boolean.
///
/// Equality is the derived `PartialEq`, so it follows `f64` semantics for `Float`:
/// `Float(f64::NAN) != Float(f64::NAN)`, and an `Int` never equals a `Float`
/// even when they hold the same number.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ParamValue {
    /// Free-form text, e.g. a model or optimizer name.
    String(String),
    /// Floating-point value; `f32` inputs are widened to `f64` by `From<f32>`.
    Float(f64),
    /// Integer value; `i32` inputs are widened to `i64` by `From<i32>`.
    Int(i64),
    /// Boolean flag.
    Bool(bool),
}
46
47impl ParamValue {
48 pub fn type_name(&self) -> &'static str {
50 match self {
51 ParamValue::String(_) => "string",
52 ParamValue::Float(_) => "float",
53 ParamValue::Int(_) => "int",
54 ParamValue::Bool(_) => "bool",
55 }
56 }
57}
58
59impl std::fmt::Display for ParamValue {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 ParamValue::String(s) => write!(f, "{s}"),
63 ParamValue::Float(v) => write!(f, "{v}"),
64 ParamValue::Int(v) => write!(f, "{v}"),
65 ParamValue::Bool(v) => write!(f, "{v}"),
66 }
67 }
68}
69
70impl From<&str> for ParamValue {
73 fn from(s: &str) -> Self {
74 ParamValue::String(s.to_string())
75 }
76}
77
78impl From<String> for ParamValue {
79 fn from(s: String) -> Self {
80 ParamValue::String(s)
81 }
82}
83
84impl From<f64> for ParamValue {
85 fn from(v: f64) -> Self {
86 ParamValue::Float(v)
87 }
88}
89
90impl From<f32> for ParamValue {
91 fn from(v: f32) -> Self {
92 ParamValue::Float(f64::from(v))
93 }
94}
95
96impl From<i64> for ParamValue {
97 fn from(v: i64) -> Self {
98 ParamValue::Int(v)
99 }
100}
101
102impl From<i32> for ParamValue {
103 fn from(v: i32) -> Self {
104 ParamValue::Int(i64::from(v))
105 }
106}
107
108impl From<bool> for ParamValue {
109 fn from(v: bool) -> Self {
110 ParamValue::Bool(v)
111 }
112}
113
/// The differences between two [`ParamLogger`]s, as returned by `ParamLogger::diff`.
///
/// "Base" below means the logger `diff` was called on (`self`); "other" is its argument.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ParamDiff {
    /// Keys present in both loggers whose values differ, mapped to `(base_value, other_value)`.
    pub changed: HashMap<String, (ParamValue, ParamValue)>,
    /// Keys present only in the other logger, with the other logger's value.
    pub added: HashMap<String, ParamValue>,
    /// Keys present only in the base logger, with the base logger's value.
    pub removed: HashMap<String, ParamValue>,
}
132
133impl ParamDiff {
134 pub fn is_empty(&self) -> bool {
136 self.changed.is_empty() && self.added.is_empty() && self.removed.is_empty()
137 }
138
139 pub fn len(&self) -> usize {
141 self.changed.len() + self.added.len() + self.removed.len()
142 }
143}
144
/// Records named parameter values, keyed by parameter name.
///
/// Logging a key that was already logged replaces its previous value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParamLogger {
    // Logged values by name. Iteration order of this HashMap is unspecified;
    // `to_json` sorts keys before serializing.
    params: HashMap<String, ParamValue>,
}
157
158impl ParamLogger {
159 pub fn new() -> Self {
161 Self { params: HashMap::new() }
162 }
163
164 pub fn log_param(&mut self, key: &str, value: impl Into<ParamValue>) {
166 self.params.insert(key.to_string(), value.into());
167 }
168
169 pub fn log_params(&mut self, params: HashMap<String, ParamValue>) {
171 self.params.extend(params);
172 }
173
174 pub fn get_param(&self, key: &str) -> Option<&ParamValue> {
176 self.params.get(key)
177 }
178
179 pub fn get_all_params(&self) -> &HashMap<String, ParamValue> {
181 &self.params
182 }
183
184 pub fn len(&self) -> usize {
186 self.params.len()
187 }
188
189 pub fn is_empty(&self) -> bool {
191 self.params.is_empty()
192 }
193
194 pub fn to_json(&self) -> String {
198 let sorted: std::collections::BTreeMap<&String, &ParamValue> = self.params.iter().collect();
200 serde_json::to_string_pretty(&sorted).unwrap_or_else(|e| {
201 eprintln!("ParamLogger JSON serialization failed: {e}");
202 "{}".to_string()
203 })
204 }
205
206 pub fn diff(&self, other: &ParamLogger) -> ParamDiff {
212 let mut changed = HashMap::new();
213 let mut added = HashMap::new();
214 let mut removed = HashMap::new();
215
216 for (key, self_val) in &self.params {
218 match other.params.get(key) {
219 Some(other_val) if self_val != other_val => {
220 changed.insert(key.clone(), (self_val.clone(), other_val.clone()));
221 }
222 None => {
223 removed.insert(key.clone(), self_val.clone());
224 }
225 _ => {} }
227 }
228
229 for (key, other_val) in &other.params {
231 if !self.params.contains_key(key) {
232 added.insert(key.clone(), other_val.clone());
233 }
234 }
235
236 ParamDiff { changed, added, removed }
237 }
238}
239
240impl Default for ParamLogger {
241 fn default() -> Self {
242 Self::new()
243 }
244}
245
#[cfg(test)]
mod tests {
    //! Unit tests for `ParamValue`, `ParamDiff`, and `ParamLogger`.
    use super::*;

    #[test]
    fn test_param_logger_new_is_empty() {
        let logger = ParamLogger::new();
        assert!(logger.is_empty());
        assert_eq!(logger.len(), 0);
    }

    #[test]
    fn test_log_param_string() {
        let mut logger = ParamLogger::new();
        logger.log_param("model", "llama-7b");
        assert_eq!(logger.get_param("model"), Some(&ParamValue::String("llama-7b".to_string())));
    }

    #[test]
    fn test_log_param_float() {
        let mut logger = ParamLogger::new();
        logger.log_param("lr", 1e-4_f64);
        assert_eq!(logger.get_param("lr"), Some(&ParamValue::Float(1e-4)));
    }

    #[test]
    fn test_log_param_f32_converts_to_f64() {
        let mut logger = ParamLogger::new();
        logger.log_param("weight_decay", 0.01_f32);
        assert_eq!(logger.get_param("weight_decay"), Some(&ParamValue::Float(f64::from(0.01_f32))));
    }

    #[test]
    fn test_log_param_int() {
        let mut logger = ParamLogger::new();
        logger.log_param("epochs", 10_i64);
        assert_eq!(logger.get_param("epochs"), Some(&ParamValue::Int(10)));
    }

    #[test]
    fn test_log_param_i32_converts_to_i64() {
        let mut logger = ParamLogger::new();
        logger.log_param("batch_size", 32_i32);
        assert_eq!(logger.get_param("batch_size"), Some(&ParamValue::Int(32)));
    }

    #[test]
    fn test_log_param_bool() {
        let mut logger = ParamLogger::new();
        logger.log_param("use_lora", true);
        assert_eq!(logger.get_param("use_lora"), Some(&ParamValue::Bool(true)));
    }

    #[test]
    fn test_log_param_owned_string() {
        let mut logger = ParamLogger::new();
        logger.log_param("optimizer", String::from("adamw"));
        assert_eq!(logger.get_param("optimizer"), Some(&ParamValue::String("adamw".to_string())));
    }

    #[test]
    fn test_log_param_overwrites() {
        let mut logger = ParamLogger::new();
        logger.log_param("lr", 1e-3_f64);
        logger.log_param("lr", 1e-4_f64);
        assert_eq!(logger.get_param("lr"), Some(&ParamValue::Float(1e-4)));
        assert_eq!(logger.len(), 1);
    }

    #[test]
    fn test_get_param_missing_returns_none() {
        let logger = ParamLogger::new();
        assert_eq!(logger.get_param("nonexistent"), None);
    }

    #[test]
    fn test_log_params_bulk() {
        let mut logger = ParamLogger::new();
        let mut params = HashMap::new();
        params.insert("lr".to_string(), ParamValue::Float(1e-4));
        params.insert("epochs".to_string(), ParamValue::Int(10));
        params.insert("model".to_string(), ParamValue::String("gpt2".to_string()));
        logger.log_params(params);

        assert_eq!(logger.len(), 3);
        assert_eq!(logger.get_param("lr"), Some(&ParamValue::Float(1e-4)));
        assert_eq!(logger.get_param("epochs"), Some(&ParamValue::Int(10)));
    }

    #[test]
    fn test_get_all_params() {
        let mut logger = ParamLogger::new();
        logger.log_param("a", 1_i64);
        logger.log_param("b", 2_i64);

        let all = logger.get_all_params();
        assert_eq!(all.len(), 2);
        assert!(all.contains_key("a"));
        assert!(all.contains_key("b"));
    }

    #[test]
    fn test_to_json_deterministic() {
        let mut logger = ParamLogger::new();
        logger.log_param("z_param", 1_i64);
        logger.log_param("a_param", 2_i64);
        logger.log_param("m_param", 3_i64);

        let json = logger.to_json();
        let a_pos = json.find("a_param").expect("a_param not found");
        let m_pos = json.find("m_param").expect("m_param not found");
        let z_pos = json.find("z_param").expect("z_param not found");
        assert!(a_pos < m_pos, "a_param should come before m_param");
        assert!(m_pos < z_pos, "m_param should come before z_param");
    }

    #[test]
    fn test_to_json_contains_values() {
        let mut logger = ParamLogger::new();
        logger.log_param("lr", 0.001_f64);
        logger.log_param("use_lora", true);
        logger.log_param("model", "gpt2");

        let json = logger.to_json();
        assert!(json.contains("0.001"));
        assert!(json.contains("true"));
        assert!(json.contains("gpt2"));
    }

    #[test]
    fn test_to_json_empty() {
        let logger = ParamLogger::new();
        let json = logger.to_json();
        assert_eq!(json, "{}");
    }

    #[test]
    fn test_to_json_roundtrip() {
        let mut logger = ParamLogger::new();
        logger.log_param("lr", 1e-4_f64);
        logger.log_param("epochs", 10_i64);
        logger.log_param("model", "llama");
        logger.log_param("lora", true);

        let json = logger.to_json();
        let deserialized: std::collections::BTreeMap<String, ParamValue> =
            serde_json::from_str(&json).expect("should deserialize");

        assert_eq!(deserialized.len(), 4);
        assert_eq!(deserialized.get("lr"), Some(&ParamValue::Float(1e-4)));
        assert_eq!(deserialized.get("epochs"), Some(&ParamValue::Int(10)));
        assert_eq!(deserialized.get("model"), Some(&ParamValue::String("llama".to_string())));
        assert_eq!(deserialized.get("lora"), Some(&ParamValue::Bool(true)));
    }

    #[test]
    fn test_diff_identical_is_empty() {
        let mut a = ParamLogger::new();
        a.log_param("lr", 1e-4_f64);
        a.log_param("epochs", 10_i64);

        let mut b = ParamLogger::new();
        b.log_param("lr", 1e-4_f64);
        b.log_param("epochs", 10_i64);

        let diff = a.diff(&b);
        assert!(diff.is_empty());
        assert_eq!(diff.len(), 0);
    }

    #[test]
    fn test_diff_empty_loggers() {
        let a = ParamLogger::new();
        let b = ParamLogger::new();
        let diff = a.diff(&b);
        assert!(diff.is_empty());
    }

    #[test]
    fn test_diff_changed_values() {
        let mut a = ParamLogger::new();
        a.log_param("lr", 1e-3_f64);
        a.log_param("epochs", 10_i64);

        let mut b = ParamLogger::new();
        b.log_param("lr", 1e-4_f64);
        b.log_param("epochs", 10_i64);

        let diff = a.diff(&b);
        assert_eq!(diff.changed.len(), 1);
        assert_eq!(
            diff.changed.get("lr"),
            Some(&(ParamValue::Float(1e-3), ParamValue::Float(1e-4)))
        );
        assert!(diff.added.is_empty());
        assert!(diff.removed.is_empty());
    }

    #[test]
    fn test_diff_added_params() {
        let mut a = ParamLogger::new();
        a.log_param("lr", 1e-4_f64);

        let mut b = ParamLogger::new();
        b.log_param("lr", 1e-4_f64);
        b.log_param("warmup", 100_i64);

        let diff = a.diff(&b);
        assert!(diff.changed.is_empty());
        assert_eq!(diff.added.len(), 1);
        assert_eq!(diff.added.get("warmup"), Some(&ParamValue::Int(100)));
        assert!(diff.removed.is_empty());
    }

    #[test]
    fn test_diff_removed_params() {
        let mut a = ParamLogger::new();
        a.log_param("lr", 1e-4_f64);
        a.log_param("warmup", 100_i64);

        let mut b = ParamLogger::new();
        b.log_param("lr", 1e-4_f64);

        let diff = a.diff(&b);
        assert!(diff.changed.is_empty());
        assert!(diff.added.is_empty());
        assert_eq!(diff.removed.len(), 1);
        assert_eq!(diff.removed.get("warmup"), Some(&ParamValue::Int(100)));
    }

    #[test]
    fn test_diff_mixed_changes() {
        let mut a = ParamLogger::new();
        a.log_param("lr", 1e-3_f64);
        a.log_param("old_param", "remove_me");
        a.log_param("same", 42_i64);

        let mut b = ParamLogger::new();
        b.log_param("lr", 1e-4_f64);
        b.log_param("new_param", true);
        b.log_param("same", 42_i64);

        let diff = a.diff(&b);
        assert_eq!(diff.changed.len(), 1);
        assert_eq!(diff.added.len(), 1);
        assert_eq!(diff.removed.len(), 1);
        assert_eq!(diff.len(), 3);
        assert!(!diff.is_empty());

        assert!(diff.changed.contains_key("lr"));
        assert!(diff.added.contains_key("new_param"));
        assert!(diff.removed.contains_key("old_param"));
    }

    #[test]
    fn test_diff_type_change_counts_as_changed() {
        let mut a = ParamLogger::new();
        a.log_param("value", 10_i64);

        let mut b = ParamLogger::new();
        b.log_param("value", 10.0_f64);

        let diff = a.diff(&b);
        assert_eq!(diff.changed.len(), 1);
        assert_eq!(
            diff.changed.get("value"),
            Some(&(ParamValue::Int(10), ParamValue::Float(10.0)))
        );
    }

    #[test]
    fn test_param_value_type_name() {
        assert_eq!(ParamValue::String("x".into()).type_name(), "string");
        assert_eq!(ParamValue::Float(1.0).type_name(), "float");
        assert_eq!(ParamValue::Int(1).type_name(), "int");
        assert_eq!(ParamValue::Bool(true).type_name(), "bool");
    }

    #[test]
    fn test_param_value_display() {
        assert_eq!(ParamValue::String("hello".into()).to_string(), "hello");
        // 2.5 rather than 3.14: clippy's deny-by-default `approx_constant` lint rejects
        // literals that approximate PI.
        assert_eq!(ParamValue::Float(2.5).to_string(), "2.5");
        assert_eq!(ParamValue::Int(42).to_string(), "42");
        assert_eq!(ParamValue::Bool(false).to_string(), "false");
    }

    #[test]
    fn test_param_value_serde_roundtrip() {
        let values = [
            ParamValue::String("test".into()),
            ParamValue::Float(1.23),
            ParamValue::Int(-5),
            ParamValue::Bool(true),
        ];
        for val in &values {
            let json = serde_json::to_string(val).expect("serialize");
            let back: ParamValue = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(&back, val);
        }
    }

    #[test]
    fn test_param_diff_is_empty_and_len() {
        let diff =
            ParamDiff { changed: HashMap::new(), added: HashMap::new(), removed: HashMap::new() };
        assert!(diff.is_empty());
        assert_eq!(diff.len(), 0);

        let mut diff2 =
            ParamDiff { changed: HashMap::new(), added: HashMap::new(), removed: HashMap::new() };
        diff2.added.insert("x".to_string(), ParamValue::Int(1));
        assert!(!diff2.is_empty());
        assert_eq!(diff2.len(), 1);
    }

    #[test]
    fn test_default_impl() {
        let logger = ParamLogger::default();
        assert!(logger.is_empty());
    }
}