use std::fs;
use std::io;
use std::path::{Path, PathBuf};
use std::time::Duration;
use serde::{Deserialize, Serialize};
/// Summary statistics recorded for one named benchmark.
///
/// Serialized with serde; fields are written in declaration order, so the order below is
/// also the key order of the stored JSON.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct Baseline {
    /// Benchmark identifier; the JSON store also uses it (sanitized) as the file name.
    pub name: String,
    /// Mean duration in nanoseconds; see [`Baseline::mean`] for a [`Duration`] view.
    pub mean_ns: u64,
    /// Presumably the number of measurements the statistics were computed from — confirm
    /// with the producer of these values.
    pub samples: u64,
    /// Throughput figure; presumably operations per second. Not validated by this type.
    pub ops_per_sec: f64,
}

impl Baseline {
    /// Returns `mean_ns` as a [`Duration`] (nanosecond precision, no rounding).
    pub fn mean(&self) -> Duration {
        let nanos = self.mean_ns;
        Duration::from_nanos(nanos)
    }
}
/// Persistence backend for benchmark [`Baseline`]s, addressed by a `scope` plus a baseline name.
///
/// NOTE(review): `scope` looks like it groups baselines per revision/branch (tests use
/// values such as `"abc1234"` and `"main"`) — confirm with callers.
pub trait BaselineStore {
/// Looks up the baseline called `name` within `scope`.
///
/// Implementations are expected to return `Ok(None)` when nothing is stored for that key
/// (as `JsonFileBaselineStore` does for a missing file) and `Err` for real failures.
fn load(&self, scope: &str, name: &str) -> io::Result<Option<Baseline>>;
/// Persists `baseline` under `scope`.
///
/// NOTE(review): the JSON implementation keys the entry by `baseline.name`; presumably
/// every implementation should, so that `load(scope, &baseline.name)` finds it.
fn save(&self, scope: &str, baseline: &Baseline) -> io::Result<()>;
}
/// [`BaselineStore`] that keeps each baseline as a pretty-printed JSON file located at
/// `<root>/<scope>/<name>.json`, where `scope` and `name` are first passed through `sanitize`.
pub struct JsonFileBaselineStore {
    /// Directory under which every scope directory lives; created on demand by `save`.
    root: PathBuf,
}

impl JsonFileBaselineStore {
    /// Creates a store rooted at `root`. Performs no filesystem access.
    pub fn new(root: impl Into<PathBuf>) -> Self {
        let root: PathBuf = root.into();
        JsonFileBaselineStore { root }
    }

    /// Computes where baseline `name` in `scope` lives on disk.
    fn path_for(&self, scope: &str, name: &str) -> PathBuf {
        let scope_dir = self.root.join(sanitize(scope));
        let file_name = format!("{}.json", sanitize(name));
        scope_dir.join(file_name)
    }
}
impl BaselineStore for JsonFileBaselineStore {
    /// Reads the file for `scope`/`name` and parses it as a [`Baseline`].
    ///
    /// Returns `Ok(None)` when the file does not exist.
    ///
    /// # Errors
    /// - `InvalidData` if the file exists but does not contain a valid baseline; the message
    ///   names the offending path.
    /// - Any other I/O error raised while reading the file.
    fn load(&self, scope: &str, name: &str) -> io::Result<Option<Baseline>> {
        let path = self.path_for(scope, name);
        let bytes = match fs::read(&path) {
            Ok(bytes) => bytes,
            Err(e) if e.kind() == io::ErrorKind::NotFound => return Ok(None),
            Err(e) => return Err(e),
        };
        let baseline: Baseline = serde_json::from_slice(&bytes).map_err(|e| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("invalid baseline at {}: {}", path.display(), e),
            )
        })?;
        Ok(Some(baseline))
    }

    /// Writes `baseline` as pretty-printed JSON to the file for `scope`/`baseline.name`,
    /// creating the scope directory if needed and replacing any existing file through
    /// `atomic_write`.
    ///
    /// # Errors
    /// - `InvalidInput` if `baseline.ops_per_sec` is NaN or infinite. serde_json encodes such
    ///   values as `null`, which `load` rejects, so storing one would replace a readable
    ///   baseline with an unreadable file. Nothing is written in this case.
    /// - `InvalidData` if serialization fails.
    /// - Any I/O error from creating the directory or writing the file.
    fn save(&self, scope: &str, baseline: &Baseline) -> io::Result<()> {
        // Validate before any filesystem side effect so a bad value never clobbers a good file.
        if !baseline.ops_per_sec.is_finite() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!(
                    "baseline {:?} has non-finite ops_per_sec ({}); it cannot be stored as JSON",
                    baseline.name, baseline.ops_per_sec
                ),
            ));
        }
        let path = self.path_for(scope, &baseline.name);
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent)?;
        }
        let bytes = serde_json::to_vec_pretty(baseline)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("serialize: {}", e)))?;
        atomic_write(&path, &bytes)
    }
}
/// Replaces the contents of `path` with `bytes` by writing a sibling temporary file and
/// renaming it over `path`, so readers see either the old contents or the complete new ones.
///
/// - The temporary name includes the process id and a per-process counter, so concurrent
///   writers targeting the same `path` never share (and truncate) one temp file.
/// - The data is flushed with `sync_all` before the rename, so a crash cannot leave the
///   renamed file without its contents.
/// - On failure the temporary file is removed on a best-effort basis.
///
/// The parent directory itself is not fsynced, so the rename may not survive a crash.
///
/// # Errors
/// Returns `InvalidInput` if `path` has no file-name component. Otherwise it propagates any
/// I/O error from creating, writing, syncing, or renaming the temporary file.
fn atomic_write(path: &Path, bytes: &[u8]) -> io::Result<()> {
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Creates (or truncates) `temp`, writes `bytes`, and flushes them to stable storage.
    fn write_synced(temp: &Path, bytes: &[u8]) -> io::Result<()> {
        use std::io::Write as _;
        let mut file = fs::File::create(temp)?;
        file.write_all(bytes)?;
        file.sync_all()
        // `file` is dropped (closed) here, before the caller renames it; required on Windows.
    }

    // Per-process sequence number; together with the pid it makes each temp name unique.
    static NEXT_TEMP_ID: AtomicU64 = AtomicU64::new(0);

    let parent = path.parent().unwrap_or(Path::new("."));
    let file_name = path
        .file_name()
        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "no file name"))?;
    let temp_id = NEXT_TEMP_ID.fetch_add(1, Ordering::Relaxed);
    let temp = parent.join(format!(
        ".{}.{}.{}.tmp",
        file_name.to_string_lossy(),
        std::process::id(),
        temp_id
    ));

    let result = write_synced(&temp, bytes).and_then(|()| fs::rename(&temp, path));
    if result.is_err() {
        // Best effort only: the caller needs the original error, not a cleanup failure.
        let _ = fs::remove_file(&temp);
    }
    result
}
/// Maps an arbitrary string to a single path component that cannot escape its parent.
///
/// Every character other than ASCII letters, digits, `_`, `-` and `.` becomes `_`, so the
/// result never contains a path separator. A non-empty result made up only of dots (`.`,
/// `..`, ...) would still be special to the filesystem: `..` would climb out of the store
/// root. Such results are therefore replaced by the same number of underscores. The empty
/// string maps to itself.
fn sanitize(s: &str) -> String {
    let mapped: String = s
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() || matches!(c, '_' | '-' | '.') {
                c
            } else {
                '_'
            }
        })
        .collect();
    // For "" this yields "_".repeat(0) == "", so empty input stays empty.
    if mapped.chars().all(|c| c == '.') {
        "_".repeat(mapped.len())
    } else {
        mapped
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn round_trip_baseline_through_json_store() {
        let dir = tempfile::tempdir().unwrap();
        let store = JsonFileBaselineStore::new(dir.path());
        let b = Baseline {
            name: "parse_query".into(),
            mean_ns: 1234,
            samples: 1000,
            ops_per_sec: 810_000.0,
        };
        store.save("abc1234", &b).unwrap();
        let back = store.load("abc1234", "parse_query").unwrap().unwrap();
        assert_eq!(back, b);
    }

    #[test]
    fn missing_baseline_returns_none() {
        let dir = tempfile::tempdir().unwrap();
        let store = JsonFileBaselineStore::new(dir.path());
        let r = store.load("anything", "absent").unwrap();
        assert!(r.is_none());
    }

    #[test]
    fn save_creates_parent_directories() {
        let dir = tempfile::tempdir().unwrap();
        let store = JsonFileBaselineStore::new(dir.path().join("not_yet_existing"));
        let b = Baseline {
            name: "x".into(),
            mean_ns: 1,
            samples: 1,
            ops_per_sec: 1.0,
        };
        store.save("main", &b).unwrap();
        let back = store.load("main", "x").unwrap().unwrap();
        assert_eq!(back, b);
    }

    #[test]
    fn save_overwrites_existing() {
        let dir = tempfile::tempdir().unwrap();
        let store = JsonFileBaselineStore::new(dir.path());
        let b1 = Baseline {
            name: "x".into(),
            mean_ns: 100,
            samples: 1,
            ops_per_sec: 10.0,
        };
        let b2 = Baseline {
            name: "x".into(),
            mean_ns: 200,
            samples: 2,
            ops_per_sec: 5.0,
        };
        store.save("main", &b1).unwrap();
        store.save("main", &b2).unwrap();
        let back = store.load("main", "x").unwrap().unwrap();
        assert_eq!(back, b2);
    }

    #[test]
    fn save_leaves_no_temporary_files() {
        let dir = tempfile::tempdir().unwrap();
        let store = JsonFileBaselineStore::new(dir.path());
        let b = Baseline {
            name: "x".into(),
            mean_ns: 1,
            samples: 1,
            ops_per_sec: 1.0,
        };
        store.save("main", &b).unwrap();
        store.save("main", &b).unwrap();
        let names: Vec<String> = fs::read_dir(dir.path().join("main"))
            .unwrap()
            .map(|e| e.unwrap().file_name().to_string_lossy().into_owned())
            .collect();
        assert_eq!(names, vec!["x.json".to_string()]);
    }

    #[test]
    fn sanitize_blocks_path_traversal_in_scope_and_name() {
        let dir = tempfile::tempdir().unwrap();
        let store = JsonFileBaselineStore::new(dir.path());
        let b = Baseline {
            name: "../escaped".into(),
            mean_ns: 1,
            samples: 1,
            ops_per_sec: 1.0,
        };
        store.save("../danger", &b).unwrap();

        // Inspect the store root itself rather than its shared parent (e.g. /tmp), which
        // may contain unrelated entries and made the previous check nondeterministic.
        let root_entries: Vec<String> = fs::read_dir(dir.path())
            .unwrap()
            .map(|e| e.unwrap().file_name().to_string_lossy().into_owned())
            .collect();
        assert_eq!(root_entries, vec![".._danger".to_string()]);

        let expected = dir.path().join(".._danger").join(".._escaped.json");
        assert!(expected.is_file(), "expected {} to exist", expected.display());
        let back = store.load("../danger", "../escaped").unwrap().unwrap();
        assert_eq!(back, b);
    }

    #[test]
    fn corrupt_baseline_yields_invalid_data_error() {
        let dir = tempfile::tempdir().unwrap();
        let store = JsonFileBaselineStore::new(dir.path());
        let path = store.path_for("main", "broken");
        fs::create_dir_all(path.parent().unwrap()).unwrap();
        fs::write(&path, b"{ this is not json").unwrap();
        let err = store.load("main", "broken").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn baseline_mean_returns_duration() {
        let b = Baseline {
            name: "x".into(),
            mean_ns: 5_000,
            samples: 1,
            ops_per_sec: 1.0,
        };
        assert_eq!(b.mean(), Duration::from_nanos(5_000));
    }
}