use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
/// Numeric identifier of a [`Trial`].
///
/// Because of `#[serde(transparent)]` the wrapper is invisible on the wire:
/// the id (de)serializes exactly like the inner `u64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Serialize, Deserialize)]
#[serde(transparent)]
pub struct TrialId(
    /// Raw numeric value of the identifier.
    pub u64,
);
impl std::fmt::Display for TrialId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "trial-{:06}", self.0)
}
}
/// Lifecycle state of a [`Trial`].
///
/// Serialized in serde's adjacently tagged form: the variant name is stored
/// under `"kind"` and the variant payload (if any) under `"detail"`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(tag = "kind", content = "detail")]
pub enum TrialStatus {
    /// Initial state assigned by `Trial::new`; the only non-terminal state.
    Running,
    /// Set by `Trial::complete`.
    Completed,
    /// Set by `Trial::fail`; carries the failure message supplied by the caller.
    Failed(String),
    /// Set by `Trial::prune`.
    Pruned,
}
impl TrialStatus {
    /// Returns `true` for every state other than [`TrialStatus::Running`],
    /// i.e. for `Completed`, `Failed(_)` and `Pruned`.
    #[must_use]
    pub fn is_terminal(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Running => false,
            Self::Completed | Self::Failed(_) | Self::Pruned => true,
        }
    }
}
/// A single trial: its identifier, configuration, outcome and timestamps.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct Trial {
    /// Identifier of this trial.
    pub id: TrialId,
    /// Arbitrary JSON values keyed by string name.
    pub config: HashMap<String, JsonValue>,
    /// `None` on construction; `Trial::complete` stores the reported value here.
    pub metric: Option<f64>,
    /// Current lifecycle state.
    pub status: TrialStatus,
    /// UTC timestamp captured when `Trial::new` ran.
    pub started_at: DateTime<Utc>,
    /// UTC timestamp captured by `complete`, `fail` or `prune`; `None` until
    /// one of them is called.
    pub finished_at: Option<DateTime<Utc>>,
}
impl Trial {
#[must_use]
pub fn new(id: TrialId, config: HashMap<String, JsonValue>) -> Self {
Self {
id,
config,
metric: None,
status: TrialStatus::Running,
started_at: Utc::now(),
finished_at: None,
}
}
pub fn complete(&mut self, metric: f64) {
self.metric = Some(metric);
self.status = TrialStatus::Completed;
self.finished_at = Some(Utc::now());
}
pub fn fail(&mut self, msg: impl Into<String>) {
self.status = TrialStatus::Failed(msg.into());
self.finished_at = Some(Utc::now());
}
pub fn prune(&mut self) {
self.status = TrialStatus::Pruned;
self.finished_at = Some(Utc::now());
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A fresh trial is running, unfinished and has no metric.
    #[test]
    fn new_trial_is_running_and_unfinished() {
        let t = Trial::new(TrialId(3), HashMap::new());
        assert_eq!(t.id, TrialId(3));
        assert!(t.config.is_empty());
        assert_eq!(t.metric, None);
        assert_eq!(t.status, TrialStatus::Running);
        assert!(!t.status.is_terminal());
        assert!(t.finished_at.is_none());
    }

    /// Completing a trial stores the metric and moves it to `Completed`.
    #[test]
    fn lifecycle_completion_sets_metric_and_terminal() {
        let mut t = Trial::new(
            TrialId(7),
            HashMap::from([("lr".to_string(), json!(0.001))]),
        );
        assert_eq!(t.metric, None);
        assert!(matches!(t.status, TrialStatus::Running));
        t.complete(0.42);
        assert_eq!(t.metric, Some(0.42));
        assert_eq!(t.status, TrialStatus::Completed);
        assert!(t.status.is_terminal());
        assert!(t.finished_at.is_some());
    }

    /// Failing a trial keeps the message and leaves the metric unset.
    #[test]
    fn failure_path_records_message() {
        let mut t = Trial::new(TrialId(1), HashMap::new());
        t.fail("evaluator exploded");
        match &t.status {
            TrialStatus::Failed(msg) => assert_eq!(msg, "evaluator exploded"),
            other => panic!("unexpected status: {other:?}"),
        }
        assert!(t.status.is_terminal());
        assert!(t.finished_at.is_some());
        assert_eq!(t.metric, None);
    }

    /// Pruning a trial is terminal and leaves the metric unset.
    #[test]
    fn prune_path_is_terminal_without_metric() {
        let mut t = Trial::new(TrialId(2), HashMap::new());
        t.prune();
        assert_eq!(t.status, TrialStatus::Pruned);
        assert!(t.status.is_terminal());
        assert!(t.finished_at.is_some());
        assert_eq!(t.metric, None);
    }

    /// The display form is `trial-` plus the id padded to at least six digits.
    #[test]
    fn trial_id_display_format_is_stable() {
        assert_eq!(TrialId(0).to_string(), "trial-000000");
        assert_eq!(TrialId(42).to_string(), "trial-000042");
        assert_eq!(TrialId(1_234_567).to_string(), "trial-1234567");
    }

    /// Pins the JSON shapes produced by the serde attributes: a transparent
    /// id and an adjacently tagged status (`kind` / `detail`).
    #[test]
    fn serde_wire_format_is_stable() {
        assert_eq!(serde_json::to_value(TrialId(7)).unwrap(), json!(7));
        assert_eq!(
            serde_json::to_value(TrialStatus::Running).unwrap(),
            json!({ "kind": "Running" })
        );
        assert_eq!(
            serde_json::to_value(TrialStatus::Failed("boom".to_string())).unwrap(),
            json!({ "kind": "Failed", "detail": "boom" })
        );

        let parsed: TrialStatus =
            serde_json::from_value(json!({ "kind": "Failed", "detail": "boom" })).unwrap();
        assert_eq!(parsed, TrialStatus::Failed("boom".to_string()));
    }
}