use crate::eval::loader::DatasetId;
use crate::muxer_harness as mh;
use muxer::{Exp3IxState, LinUcbState, MonitoredWindow, Outcome, Summary};
use std::collections::{BTreeMap, VecDeque};
use std::fs;
use std::path::{Path, PathBuf};
/// A (failure kind, occurrence count) pair reported when summarizing the
/// most common failure kinds recorded for a backend.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FailKindCount {
    // Free-form failure-kind label (e.g. "timeout", "backend" in the tests below).
    pub kind: String,
    // Number of occurrences of this kind within the retained window.
    pub count: u64,
}
/// A bounded FIFO of recent `Outcome`s for one history key
/// (a backend name or a `backend@@dataset` composite key).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HistoryWindow {
    // Maximum number of retained outcomes; `push` treats it as at least 1.
    // `serde(default)` means a missing field deserializes as 0.
    #[serde(default)]
    pub cap: usize,
    // Retained outcomes, oldest first.
    #[serde(default)]
    pub buf: VecDeque<Outcome>,
}
impl HistoryWindow {
    /// Creates an empty window retaining at most `cap` outcomes
    /// (a cap of 0 is bumped to 1).
    #[must_use]
    pub fn new(cap: usize) -> Self {
        Self {
            cap: cap.max(1),
            buf: VecDeque::new(),
        }
    }

    /// Appends `o`, then evicts from the front until the buffer fits the
    /// (at-least-1) capacity. The max(1) guard also covers windows that were
    /// deserialized with a zero cap.
    pub fn push(&mut self, o: Outcome) {
        self.buf.push_back(o);
        let limit = self.cap.max(1);
        while self.buf.len() > limit {
            self.buf.pop_front();
        }
    }

    /// Aggregates the retained outcomes into a `Summary`.
    ///
    /// Counters use saturating adds so pathological inputs cannot overflow.
    /// `mean_quality_score` averages only outcomes that carry a
    /// `quality_score`, and is `None` when none do.
    #[must_use]
    pub fn summary(&self) -> Summary {
        let mut agg = Summary::default();
        let mut quality_total = 0.0f64;
        let mut quality_count = 0u64;
        for outcome in &self.buf {
            agg.calls = agg.calls.saturating_add(1);
            agg.ok = agg.ok.saturating_add(u64::from(outcome.ok));
            agg.junk = agg.junk.saturating_add(u64::from(outcome.junk));
            agg.hard_junk = agg.hard_junk.saturating_add(u64::from(outcome.hard_junk));
            agg.cost_units = agg.cost_units.saturating_add(outcome.cost_units);
            agg.elapsed_ms_sum = agg.elapsed_ms_sum.saturating_add(outcome.elapsed_ms);
            if let Some(score) = outcome.quality_score {
                quality_total += score;
                quality_count += 1;
            }
        }
        agg.mean_quality_score = if quality_count == 0 {
            None
        } else {
            Some(quality_total / quality_count as f64)
        };
        agg
    }
}
/// Rolling per-backend outcome history persisted as JSON (current on-disk
/// version: 3), plus optional bandit state blobs.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct BackendHistory {
    // On-disk schema version; loaders normalize everything to 3.
    #[serde(default)]
    pub version: u32,
    // Capacity enforced on every window and fail-kind buffer.
    pub window_cap: usize,
    // Outcome windows keyed by backend name or "backend@@dataset".
    pub windows: BTreeMap<String, HistoryWindow>,
    // Per-key ring of optional failure-kind labels (None = call without a
    // recorded failure kind); trimmed to `window_cap` like the windows.
    #[serde(default)]
    pub fail_kinds: BTreeMap<String, VecDeque<Option<String>>>,
    // Persisted EXP3-IX bandit state, if any.
    #[serde(default)]
    pub exp3ix_state: Option<Exp3IxState>,
    // Persisted LinUCB bandit state, if any.
    #[serde(default)]
    pub linucb_state: Option<LinUcbState>,
}
impl BackendHistory {
pub fn try_load(path: &PathBuf, window_cap: usize) -> Result<Self, String> {
#[derive(Debug, Clone, serde::Deserialize)]
struct BackendHistorySerde {
#[serde(default)]
version: u32,
#[serde(default)]
window_cap: usize,
#[serde(default)]
windows: BTreeMap<String, HistoryWindow>,
#[serde(default)]
fail_kinds: BTreeMap<String, VecDeque<Option<String>>>,
#[serde(default)]
exp3ix_state: Option<Exp3IxState>,
#[serde(default)]
linucb_state: Option<LinUcbState>,
}
let cap = window_cap.max(1);
let bytes = match fs::read(path) {
Ok(b) => b,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
return Ok(Self {
version: 3,
window_cap: cap,
windows: BTreeMap::new(),
fail_kinds: BTreeMap::new(),
exp3ix_state: None,
linucb_state: None,
});
}
Err(e) => return Err(format!("muxer history: read {}: {e}", path.display())),
};
let h = serde_json::from_slice::<BackendHistorySerde>(&bytes)
.map_err(|e| format!("muxer history: parse {}: {e}", path.display()))?;
let _ = h.window_cap;
let mut windows: BTreeMap<String, HistoryWindow> = BTreeMap::new();
let mut fail_kinds: BTreeMap<String, VecDeque<Option<String>>> = BTreeMap::new();
for (k, w) in h.windows {
let mut out = HistoryWindow::new(cap);
for mut o in w.buf {
if h.version <= 1 {
o.ok = !o.hard_junk && !o.junk;
} else {
o.ok = o.ok && !o.hard_junk && !o.junk;
}
out.push(o);
}
windows.insert(k, out);
}
for (k, mut fk) in h.fail_kinds {
while fk.len() > cap {
fk.pop_front();
}
fail_kinds.insert(k, fk);
}
Ok(Self {
version: 3,
window_cap: cap,
windows,
fail_kinds,
exp3ix_state: h.exp3ix_state,
linucb_state: h.linucb_state,
})
}
pub fn load(path: &PathBuf, window_cap: usize) -> Self {
Self::try_load(path, window_cap).unwrap_or_else(|_e| Self {
version: 3,
window_cap: window_cap.max(1),
windows: BTreeMap::new(),
fail_kinds: BTreeMap::new(),
exp3ix_state: None,
linucb_state: None,
})
}
pub fn save(&self, path: &PathBuf) {
if let Some(parent) = path.parent() {
let _ = fs::create_dir_all(parent);
}
let Ok(bytes) = serde_json::to_vec_pretty(self) else {
return;
};
let tmp_path = {
let mut name = path.file_name().unwrap_or_default().to_owned();
name.push(".tmp");
path.with_file_name(name)
};
if fs::write(&tmp_path, &bytes).is_ok() {
let _ = fs::rename(&tmp_path, path);
}
}
pub fn push_with_fail_kind(&mut self, key: &str, o: Outcome, fail_kind: Option<String>) {
let w = self
.windows
.entry(key.to_string())
.or_insert_with(|| HistoryWindow::new(self.window_cap));
w.push(o);
let cap = self.window_cap;
let fk = self.fail_kinds.entry(key.to_string()).or_default();
fk.push_back(fail_kind);
while fk.len() > cap {
fk.pop_front();
}
}
pub fn dataset_key(backend: &str, dataset: DatasetId) -> String {
format!("{backend}@@{:?}", dataset)
}
pub fn observed_summary_for(
&self,
backend: &str,
datasets: Option<&[DatasetId]>,
per_dataset: bool,
) -> Summary {
if per_dataset {
if let Some(datasets) = datasets {
let mut agg = Summary::default();
for &ds in datasets {
let k = Self::dataset_key(backend, ds);
if let Some(w) = self.windows.get(&k) {
let s = w.summary();
if s.calls == 0 {
continue;
}
agg.calls = agg.calls.saturating_add(s.calls);
agg.ok = agg.ok.saturating_add(s.ok);
agg.junk = agg.junk.saturating_add(s.junk);
agg.hard_junk = agg.hard_junk.saturating_add(s.hard_junk);
agg.cost_units = agg.cost_units.saturating_add(s.cost_units);
agg.elapsed_ms_sum = agg.elapsed_ms_sum.saturating_add(s.elapsed_ms_sum);
}
}
if agg.calls > 0 {
return agg;
}
}
}
self.windows
.get(backend)
.map(|w| w.summary())
.unwrap_or_default()
}
fn base_prior_summary_for(prior: &BackendHistory, backend: &str) -> Summary {
prior
.windows
.get(backend)
.map(|w| w.summary())
.unwrap_or_default()
}
fn facet_prior_summary_for(
prior: &BackendHistory,
backend: &str,
lang: &'static str,
dom: &'static str,
) -> Summary {
let prefix = format!("{backend}@@");
let mut agg = Summary::default();
for (k, w) in &prior.windows {
let Some(suffix) = k.strip_prefix(&prefix) else {
continue;
};
let Ok(ds) = suffix.parse::<DatasetId>() else {
continue;
};
if ds.language() != lang || ds.domain() != dom {
continue;
}
let s = w.summary();
if s.calls == 0 {
continue;
}
agg.calls = agg.calls.saturating_add(s.calls);
agg.ok = agg.ok.saturating_add(s.ok);
agg.junk = agg.junk.saturating_add(s.junk);
agg.hard_junk = agg.hard_junk.saturating_add(s.hard_junk);
agg.cost_units = agg.cost_units.saturating_add(s.cost_units);
agg.elapsed_ms_sum = agg.elapsed_ms_sum.saturating_add(s.elapsed_ms_sum);
}
agg
}
pub fn summaries_for(
&self,
prior: Option<&BackendHistory>,
arms: &[String],
datasets: Option<&[DatasetId]>,
per_dataset: bool,
prior_calls: u64,
) -> BTreeMap<String, Summary> {
let mut out = BTreeMap::new();
for a in arms {
let mut s = self.observed_summary_for(a, datasets, per_dataset);
if prior_calls > 0 {
if let Some(prior) = prior {
let mut prior_s = Self::base_prior_summary_for(prior, a);
if mh::prior_by_facets_from_env() && per_dataset {
if let Some(datasets) = datasets {
if let Some((lang, dom)) = mh::facet_prior_filter(datasets) {
let facet = Self::facet_prior_summary_for(prior, a, lang, dom);
if facet.calls > 0 {
prior_s = facet;
}
}
}
}
mh::apply_prior_counts_to_summary(&mut s, prior_s, prior_calls);
}
}
out.insert(a.clone(), s);
}
out
}
pub fn monitored_for_backends(
&self,
backends: &[String],
recent_cap: usize,
) -> BTreeMap<String, MonitoredWindow> {
let baseline_cap = self.window_cap.max(1);
let recent_cap = recent_cap.max(1).min(baseline_cap);
let mut out: BTreeMap<String, MonitoredWindow> = BTreeMap::new();
for b in backends {
let mut mw = MonitoredWindow::new(baseline_cap, recent_cap);
if let Some(w) = self.windows.get(b) {
for o in &w.buf {
mw.push(*o);
}
}
out.insert(b.clone(), mw);
}
out
}
pub fn chosen_fail_kinds_top_for(
&self,
backend: &str,
datasets: Option<&[DatasetId]>,
per_dataset: bool,
top: usize,
) -> Option<Vec<FailKindCount>> {
let mut counts: BTreeMap<String, u64> = BTreeMap::new();
let mut saw_any = false;
if per_dataset {
if let Some(datasets) = datasets {
for &ds in datasets {
let k = Self::dataset_key(backend, ds);
if let Some(buf) = self.fail_kinds.get(&k) {
saw_any = true;
for kind in buf.iter().flatten() {
*counts.entry(kind.clone()).or_insert(0) += 1;
}
}
}
}
if saw_any && !counts.is_empty() {
return Some(top_counts(counts, top));
}
}
if let Some(buf) = self.fail_kinds.get(backend) {
for kind in buf.iter().flatten() {
*counts.entry(kind.clone()).or_insert(0) += 1;
}
}
if counts.is_empty() {
None
} else {
Some(top_counts(counts, top))
}
}
}
/// Converts a kind→count tally into rows sorted by count descending (ties
/// broken by kind ascending), keeping at most `top.max(1)` rows.
fn top_counts(counts: BTreeMap<String, u64>, top: usize) -> Vec<FailKindCount> {
    let mut rows: Vec<(String, u64)> = counts.into_iter().collect();
    rows.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    rows.truncate(top.max(1));
    rows.into_iter()
        .map(|(kind, count)| FailKindCount { kind, count })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::str::FromStr;
    use std::time::{SystemTime, UNIX_EPOCH};

    // Guards the Debug/FromStr round-trip that dataset keys depend on.
    #[test]
    fn dataset_key_round_trips_variant() {
        let key = BackendHistory::dataset_key("stacked", DatasetId::Wnut17);
        let (_, suffix) = key
            .split_once("@@")
            .expect("dataset_key should contain @@ separator");
        let parsed = DatasetId::from_str(suffix).expect("dataset key should parse");
        assert_eq!(parsed, DatasetId::Wnut17);
    }

    // Writes a unique temp JSON file on creation, removes it on Drop.
    struct TempJsonFile {
        path: PathBuf,
    }
    impl TempJsonFile {
        fn new(tag: &str, json: serde_json::Value) -> Self {
            // pid + nanos make the filename unique across concurrent test runs.
            let nanos = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_nanos();
            let mut path = std::env::temp_dir();
            path.push(format!(
                "anno-muxer-history-{tag}-pid{}-{nanos}.json",
                std::process::id()
            ));
            let content = serde_json::to_vec(&json).expect("temp json should serialize");
            std::fs::write(&path, content).expect("write temp json file");
            Self { path }
        }
    }
    impl Drop for TempJsonFile {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.path);
        }
    }

    // v0/v1 files: `ok` is derived from the junk flags during migration.
    #[test]
    fn history_upgrade_v0_overrides_ok_to_not_junk() {
        let tmp = TempJsonFile::new(
            "upgrade-v0",
            serde_json::json!({
                "version": 0,
                "windows": {
                    "a": {
                        "cap": 999,
                        "buf": [
                            { "ok": false, "junk": false, "hard_junk": false, "cost_units": 1, "elapsed_ms": 2 }
                        ]
                    }
                }
            }),
        );
        let h = BackendHistory::try_load(&tmp.path, 10).expect("try_load");
        let w = h.windows.get("a").expect("window a");
        assert_eq!(h.version, 3);
        assert_eq!(h.window_cap, 10);
        // Caller's cap wins over the on-disk cap (999).
        assert_eq!(w.cap, 10);
        assert_eq!(w.buf.len(), 1);
        assert!(
            w.buf[0].ok,
            "v0 upgrade should set ok := !junk && !hard_junk"
        );
    }

    // v2+ files: `ok` is kept but clamped so it cannot contradict junk flags.
    #[test]
    fn history_upgrade_v2_clamps_ok_by_junk_flags() {
        let tmp = TempJsonFile::new(
            "upgrade-v2",
            serde_json::json!({
                "version": 2,
                "windows": {
                    "a": {
                        "buf": [
                            { "ok": true, "junk": true, "hard_junk": false, "cost_units": 0, "elapsed_ms": 0 },
                            { "ok": true, "junk": false, "hard_junk": false, "cost_units": 0, "elapsed_ms": 1 },
                            { "ok": false, "junk": false, "hard_junk": false, "cost_units": 0, "elapsed_ms": 2 }
                        ]
                    }
                }
            }),
        );
        let h = BackendHistory::try_load(&tmp.path, 10).expect("try_load");
        let w = h.windows.get("a").expect("window a");
        assert_eq!(w.buf.len(), 3);
        assert!(!w.buf[0].ok, "junk=true must force ok=false after upgrade");
        assert!(w.buf[1].ok, "clean ok should remain ok");
        assert!(!w.buf[2].ok, "explicit ok=false should remain false");
    }

    // Loading with a smaller cap keeps only the most recent entries of both
    // the outcome window and the fail-kind buffer.
    #[test]
    fn history_load_truncates_windows_and_fail_kinds_to_cap() {
        let tmp = TempJsonFile::new(
            "truncate-cap",
            serde_json::json!({
                "version": 3,
                "windows": {
                    "a": {
                        "cap": 999,
                        "buf": [
                            { "ok": true, "junk": false, "hard_junk": false, "cost_units": 0, "elapsed_ms": 1 },
                            { "ok": true, "junk": false, "hard_junk": false, "cost_units": 0, "elapsed_ms": 2 },
                            { "ok": true, "junk": false, "hard_junk": false, "cost_units": 0, "elapsed_ms": 3 },
                            { "ok": true, "junk": false, "hard_junk": false, "cost_units": 0, "elapsed_ms": 4 },
                            { "ok": true, "junk": false, "hard_junk": false, "cost_units": 0, "elapsed_ms": 5 }
                        ]
                    }
                },
                "fail_kinds": {
                    "a": [null, "timeout", "backend", "low_signal"]
                }
            }),
        );
        let h = BackendHistory::try_load(&tmp.path, 3).expect("try_load");
        let w = h.windows.get("a").expect("window a");
        assert_eq!(w.cap, 3);
        assert_eq!(w.buf.len(), 3);
        // Newest three survive (elapsed_ms 3..=5); oldest two are dropped.
        assert_eq!(w.buf[0].elapsed_ms, 3);
        assert_eq!(w.buf[1].elapsed_ms, 4);
        assert_eq!(w.buf[2].elapsed_ms, 5);
        let fk = h.fail_kinds.get("a").expect("fail kinds a");
        assert_eq!(fk.len(), 3);
        assert_eq!(fk[0].as_deref(), Some("timeout"));
        assert_eq!(fk[1].as_deref(), Some("backend"));
        assert_eq!(fk[2].as_deref(), Some("low_signal"));
    }

    // Mean quality averages only the outcomes that carry a score.
    #[test]
    fn history_window_summary_aggregates_mean_quality_score() {
        let mut w = HistoryWindow::new(10);
        w.push(muxer::Outcome::with_quality(true, false, false, 0, 0, 0.8));
        w.push(muxer::Outcome::with_quality(true, false, false, 0, 0, 0.4));
        w.push(muxer::Outcome::new(false, true, false, 0, 0));
        let s = w.summary();
        assert_eq!(s.calls, 3);
        let mean = s
            .mean_quality_score
            .expect("mean_quality_score should be Some");
        assert!(
            (mean - 0.6).abs() < 1e-9,
            "mean of 0.8 and 0.4 should be 0.6, got {mean}"
        );
    }

    #[test]
    fn history_window_summary_no_quality_scores_gives_none() {
        let mut w = HistoryWindow::new(5);
        for _ in 0..3 {
            w.push(muxer::Outcome::new(true, false, false, 1, 1));
        }
        let s = w.summary();
        assert!(
            s.mean_quality_score.is_none(),
            "mean_quality_score must be None when no outcomes carry a score"
        );
    }

    // save() must leave no ".tmp" sibling behind and round-trip via try_load.
    #[test]
    fn backend_history_save_is_atomic_and_round_trips() {
        use std::path::PathBuf;
        use std::time::{SystemTime, UNIX_EPOCH};
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos();
        let mut path = std::env::temp_dir();
        path.push(format!(
            "anno-muxer-history-save-test-pid{}-{nanos}.json",
            std::process::id()
        ));
        let tmp_path = PathBuf::from(format!("{}.tmp", path.display()));
        let mut h = BackendHistory {
            version: 3,
            window_cap: 5,
            windows: std::collections::BTreeMap::new(),
            fail_kinds: std::collections::BTreeMap::new(),
            exp3ix_state: None,
            linucb_state: None,
        };
        let mut w = HistoryWindow::new(5);
        w.push(muxer::Outcome::with_quality(true, false, false, 7, 42, 0.9));
        h.windows.insert("test_arm".to_string(), w);
        h.save(&path);
        assert!(
            path.exists(),
            "saved file should exist at {}",
            path.display()
        );
        assert!(
            !tmp_path.exists(),
            ".tmp sibling should be gone after atomic rename"
        );
        let loaded = BackendHistory::try_load(&path, 5).expect("reload");
        let arm = loaded.windows.get("test_arm").expect("test_arm window");
        assert_eq!(arm.buf.len(), 1);
        assert_eq!(arm.buf[0].cost_units, 7);
        assert_eq!(arm.buf[0].elapsed_ms, 42);
        assert!((arm.buf[0].quality_score.unwrap() - 0.9).abs() < 1e-9);
        let _ = std::fs::remove_file(&path);
    }
}