use std::{
fs::OpenOptions,
io::{BufRead, BufReader, Seek, Write},
path::{Path, PathBuf},
};
use crate::{error::TuneError, trial::Trial};
/// Append-only journal of [`Trial`] records, stored as one JSON object per line.
#[derive(Debug)]
pub struct TrialJournal {
    /// Location of the journal file; kept so the handle can be re-opened.
    path: PathBuf,
    /// Handle opened in append mode that new records are written to.
    file: std::fs::File,
}
impl TrialJournal {
    /// Options shared by [`TrialJournal::open`] and [`TrialJournal::reopen`]:
    /// create the file if missing, append writes, and allow reads.
    fn append_options() -> OpenOptions {
        let mut opts = OpenOptions::new();
        opts.create(true).append(true).read(true);
        opts
    }

    /// Opens the journal file at `path` for appending, creating it if it does not exist.
    ///
    /// Existing contents are preserved; new records are appended after them.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be created or opened.
    pub fn open(path: impl AsRef<Path>) -> Result<Self, TuneError> {
        let path = path.as_ref().to_path_buf();
        let file = Self::append_options().open(&path)?;
        Ok(Self { path, file })
    }

    /// Appends `trial` to the journal as a single JSON line terminated by `\n`.
    ///
    /// The record is fully serialized before anything is written, so a
    /// serialization failure leaves the file untouched.
    ///
    /// # Errors
    ///
    /// Returns an error if `trial` cannot be serialized or the write fails.
    pub fn record(&mut self, trial: &Trial) -> Result<(), TuneError> {
        let mut line = serde_json::to_string(trial)?;
        line.push('\n');
        // Hand the JSON and its terminator to the OS in one `write_all` call.
        self.file.write_all(line.as_bytes())?;
        // `File` has no userspace buffer, so this does not force data to disk
        // (that would require `sync_data`/`sync_all`).
        self.file.flush()?;
        Ok(())
    }

    /// Reads the journal at `path` and returns the latest state of every trial.
    ///
    /// Each non-blank line is parsed as a [`Trial`]. When the same trial id
    /// appears on several lines, the last occurrence wins. The result is sorted
    /// by ascending id. A missing file yields an empty vector.
    ///
    /// NOTE(review): a final line left incomplete by an interrupted write is not
    /// tolerated; it fails to parse and aborts the whole replay.
    ///
    /// # Errors
    ///
    /// Returns an error on I/O failure or if any non-blank line fails to parse.
    /// The 1-based line number of a parse failure is logged.
    pub fn replay(path: impl AsRef<Path>) -> Result<Vec<Trial>, TuneError> {
        let file = match OpenOptions::new().read(true).open(path.as_ref()) {
            Ok(f) => f,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()),
            Err(e) => return Err(e.into()),
        };
        let reader = BufReader::new(file);
        // Keyed by raw id so later records overwrite earlier ones and output is id-ordered.
        let mut by_id: std::collections::BTreeMap<u64, Trial> = std::collections::BTreeMap::new();
        for (lineno, line) in reader.lines().enumerate() {
            let line = line?;
            if line.trim().is_empty() {
                continue;
            }
            let trial: Trial = serde_json::from_str(&line).map_err(|e| {
                tracing::error!(line = lineno + 1, "journal parse error: {e}");
                TuneError::Serde(e)
            })?;
            by_id.insert(trial.id.0, trial);
        }
        Ok(by_id.into_values().collect())
    }

    /// Re-opens the journal file at the stored path, replacing the current handle.
    ///
    /// Uses the same options as [`TrialJournal::open`]. If the file was removed
    /// or renamed (e.g. by log rotation), a fresh empty file is created at the
    /// original path. If opening fails, the existing handle is kept.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or the seek fails.
    pub fn reopen(&mut self) -> Result<(), TuneError> {
        self.file = Self::append_options().open(&self.path)?;
        // Append mode already directs writes to end-of-file; this only moves the
        // read cursor there as well.
        self.file.seek(std::io::SeekFrom::End(0))?;
        Ok(())
    }

    /// Returns the path of the journal file.
    #[must_use]
    pub fn path(&self) -> &Path {
        &self.path
    }
}
#[cfg(test)]
// Test values are small integers (0..5), so their `as f64` conversions are exact.
#[allow(clippy::cast_precision_loss)]
mod tests {
    use super::*;
    use crate::trial::{TrialId, TrialStatus};
    use serde_json::json;
    use std::collections::HashMap;

    #[test]
    fn journal_replay_reconstructs_history() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("trials.jsonl");
        {
            let mut j = TrialJournal::open(&path).unwrap();
            for i in 0..5 {
                let mut t = Trial::new(
                    TrialId(i),
                    HashMap::from([("lr".to_string(), json!(0.001 * (i + 1) as f64))]),
                );
                // Record the initial state, then the completed state; replay must keep the latter.
                j.record(&t).unwrap();
                t.complete(0.5 - i as f64 * 0.05);
                j.record(&t).unwrap();
            }
        }
        let replayed = TrialJournal::replay(&path).unwrap();
        assert_eq!(replayed.len(), 5);
        for (i, t) in replayed.iter().enumerate() {
            assert_eq!(t.id, TrialId(i as u64));
            assert!(matches!(t.status, TrialStatus::Completed));
            assert!(t.metric.is_some());
        }
    }

    #[test]
    fn journal_reopen_keeps_appending() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("trials.jsonl");
        let mut j = TrialJournal::open(&path).unwrap();
        assert_eq!(j.path(), path.as_path());
        j.record(&Trial::new(
            TrialId(0),
            HashMap::from([("lr".to_string(), json!(0.001))]),
        ))
        .unwrap();
        j.reopen().unwrap();
        j.record(&Trial::new(
            TrialId(1),
            HashMap::from([("lr".to_string(), json!(0.002))]),
        ))
        .unwrap();
        let replayed = TrialJournal::replay(&path).unwrap();
        assert_eq!(replayed.len(), 2);
        assert_eq!(replayed[0].id, TrialId(0));
        assert_eq!(replayed[1].id, TrialId(1));
    }

    #[test]
    fn journal_replay_missing_file_yields_empty() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("does-not-exist.jsonl");
        let replayed = TrialJournal::replay(&path).unwrap();
        assert!(replayed.is_empty());
    }
}