use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hyperparameters {
pub learning_rate: f64,
pub batch_size: usize,
pub epochs: usize,
#[serde(default)]
pub weight_decay: f64,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_grad_norm: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub warmup_steps: Option<usize>,
#[serde(default)]
pub custom: HashMap<String, HyperparamValue>,
}
impl Default for Hyperparameters {
fn default() -> Self {
Self {
learning_rate: 1e-3,
batch_size: 32,
epochs: 10,
weight_decay: 0.0,
max_grad_norm: None,
warmup_steps: None,
custom: HashMap::new(),
}
}
}
impl Hyperparameters {
#[must_use]
pub fn builder() -> HyperparametersBuilder {
HyperparametersBuilder::new()
}
pub fn set_custom(&mut self, name: impl Into<String>, value: HyperparamValue) {
self.custom.insert(name.into(), value);
}
#[must_use]
pub fn get_custom(&self, name: &str) -> Option<&HyperparamValue> {
self.custom.get(name)
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum HyperparamValue {
Float(f64),
Int(i64),
Bool(bool),
String(String),
List(Vec<HyperparamValue>),
}
impl HyperparamValue {
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub fn as_float(&self) -> Option<f64> {
match self {
Self::Float(f) => Some(*f),
Self::Int(i) => Some(*i as f64),
_ => None,
}
}
#[must_use]
#[allow(clippy::cast_possible_truncation)]
pub fn as_int(&self) -> Option<i64> {
match self {
Self::Int(i) => Some(*i),
Self::Float(f) => Some(*f as i64),
_ => None,
}
}
#[must_use]
pub fn as_bool(&self) -> Option<bool> {
match self {
Self::Bool(b) => Some(*b),
_ => None,
}
}
#[must_use]
pub fn as_string(&self) -> Option<&str> {
match self {
Self::String(s) => Some(s),
_ => None,
}
}
#[must_use]
pub fn as_list(&self) -> Option<&[HyperparamValue]> {
match self {
Self::List(l) => Some(l),
_ => None,
}
}
}
impl From<f64> for HyperparamValue {
fn from(v: f64) -> Self {
Self::Float(v)
}
}
impl From<i64> for HyperparamValue {
fn from(v: i64) -> Self {
Self::Int(v)
}
}
impl From<i32> for HyperparamValue {
fn from(v: i32) -> Self {
Self::Int(i64::from(v))
}
}
impl From<bool> for HyperparamValue {
fn from(v: bool) -> Self {
Self::Bool(v)
}
}
impl From<String> for HyperparamValue {
fn from(v: String) -> Self {
Self::String(v)
}
}
impl From<&str> for HyperparamValue {
fn from(v: &str) -> Self {
Self::String(v.to_string())
}
}
impl<T: Into<HyperparamValue>> From<Vec<T>> for HyperparamValue {
fn from(v: Vec<T>) -> Self {
Self::List(v.into_iter().map(Into::into).collect())
}
}
#[derive(Debug, Default)]
pub struct HyperparametersBuilder {
params: Hyperparameters,
}
impl HyperparametersBuilder {
#[must_use]
pub fn new() -> Self {
Self { params: Hyperparameters::default() }
}
#[must_use]
pub fn learning_rate(mut self, lr: f64) -> Self {
self.params.learning_rate = lr;
self
}
#[must_use]
pub fn batch_size(mut self, size: usize) -> Self {
self.params.batch_size = size;
self
}
#[must_use]
pub fn epochs(mut self, epochs: usize) -> Self {
self.params.epochs = epochs;
self
}
#[must_use]
pub fn weight_decay(mut self, decay: f64) -> Self {
self.params.weight_decay = decay;
self
}
#[must_use]
pub fn max_grad_norm(mut self, norm: f64) -> Self {
self.params.max_grad_norm = Some(norm);
self
}
#[must_use]
pub fn warmup_steps(mut self, steps: usize) -> Self {
self.params.warmup_steps = Some(steps);
self
}
#[must_use]
pub fn custom(mut self, name: impl Into<String>, value: impl Into<HyperparamValue>) -> Self {
self.params.custom.insert(name.into(), value.into());
self
}
#[must_use]
pub fn build(self) -> Hyperparameters {
self.params
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_hyperparameters_default() {
let params = Hyperparameters::default();
assert!((params.learning_rate - 1e-3).abs() < 1e-10);
assert_eq!(params.batch_size, 32);
assert_eq!(params.epochs, 10);
}
#[test]
fn test_hyperparameters_builder() {
let params = Hyperparameters::builder()
.learning_rate(2e-5)
.batch_size(64)
.epochs(3)
.weight_decay(0.01)
.max_grad_norm(1.0)
.warmup_steps(100)
.custom("dropout", 0.1)
.build();
assert!((params.learning_rate - 2e-5).abs() < 1e-10);
assert_eq!(params.batch_size, 64);
assert_eq!(params.epochs, 3);
assert!((params.weight_decay - 0.01).abs() < 1e-10);
assert_eq!(params.max_grad_norm, Some(1.0));
assert_eq!(params.warmup_steps, Some(100));
assert_eq!(params.get_custom("dropout").and_then(|v| v.as_float()), Some(0.1));
}
#[test]
fn test_hyperparam_value_types() {
let float_val = HyperparamValue::Float(3.14);
assert_eq!(float_val.as_float(), Some(3.14));
assert_eq!(float_val.as_int(), Some(3));
let int_val = HyperparamValue::Int(42);
assert_eq!(int_val.as_int(), Some(42));
assert_eq!(int_val.as_float(), Some(42.0));
let bool_val = HyperparamValue::Bool(true);
assert_eq!(bool_val.as_bool(), Some(true));
let string_val = HyperparamValue::String("test".to_string());
assert_eq!(string_val.as_string(), Some("test"));
let list_val =
HyperparamValue::List(vec![HyperparamValue::Int(1), HyperparamValue::Int(2)]);
assert_eq!(list_val.as_list().map(|l| l.len()), Some(2));
}
#[test]
fn test_hyperparam_value_from() {
let from_float: HyperparamValue = 3.14.into();
assert!(matches!(from_float, HyperparamValue::Float(_)));
let from_int: HyperparamValue = 42i64.into();
assert!(matches!(from_int, HyperparamValue::Int(_)));
let from_bool: HyperparamValue = true.into();
assert!(matches!(from_bool, HyperparamValue::Bool(_)));
let from_str: HyperparamValue = "test".into();
assert!(matches!(from_str, HyperparamValue::String(_)));
let from_vec: HyperparamValue = vec![1i64, 2i64, 3i64].into();
assert!(matches!(from_vec, HyperparamValue::List(_)));
}
#[test]
fn test_hyperparameters_serialization() {
let params = Hyperparameters::builder()
.learning_rate(1e-4)
.batch_size(16)
.custom("hidden_size", 768i64)
.build();
let json = serde_json::to_string(¶ms).unwrap();
let deserialized: Hyperparameters = serde_json::from_str(&json).unwrap();
assert!((params.learning_rate - deserialized.learning_rate).abs() < 1e-10);
assert_eq!(params.batch_size, deserialized.batch_size);
}
#[test]
fn test_set_get_custom() {
let mut params = Hyperparameters::default();
params.set_custom("test_param", HyperparamValue::Float(0.5));
let value = params.get_custom("test_param");
assert_eq!(value.and_then(|v| v.as_float()), Some(0.5));
assert!(params.get_custom("nonexistent").is_none());
}
}