use std::collections::{HashMap, HashSet};
#[cfg(unix)]
use std::io::Write as _;
use std::path::{Path, PathBuf};
use rand::SeedableRng as _;
use rand_distr::{Beta, Distribution};
use serde::{Deserialize, Serialize};
/// Beta distribution `Beta(alpha, beta)` tracking one provider's observed outcomes.
///
/// [`ThompsonState::update`] adds `1.0` to `alpha` for each success and `1.0`
/// to `beta` for each failure, starting from the `Beta(1, 1)` default.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BetaDist {
/// First shape parameter; incremented on success.
pub alpha: f64,
/// Second shape parameter; incremented on failure.
pub beta: f64,
}
impl Default for BetaDist {
    /// Returns the uniform prior `Beta(1, 1)`: no evidence for or against a provider yet.
    fn default() -> Self {
        let uniform = 1.0;
        Self {
            alpha: uniform,
            beta: uniform,
        }
    }
}
impl BetaDist {
    /// Smallest shape parameter handed to the sampler; `rand_distr::Beta`
    /// rejects non-positive parameters, so such values are raised to this floor.
    const MIN_PARAM: f64 = 1e-6;
    /// Largest shape parameter handed to the sampler. Matches the cap that
    /// [`ThompsonState::load`] applies to persisted values, so infinite or
    /// absurdly large parameters (e.g. from caller-supplied priors) never
    /// reach the sampler.
    const MAX_PARAM: f64 = 1e9;

    /// Draws one sample from `Beta(alpha, beta)`; the result lies in `[0, 1]`.
    ///
    /// Parameters are sanitised first: `NaN` becomes the uniform-prior value
    /// `1.0`, and every other value (including `±inf`) is clamped to
    /// `[MIN_PARAM, MAX_PARAM]`. Should `rand_distr` still reject the pair,
    /// the uniform `Beta(1, 1)` is sampled instead.
    pub fn sample<R: rand::Rng>(&self, rng: &mut R) -> f64 {
        let alpha = Self::sanitize(self.alpha);
        let beta = Self::sanitize(self.beta);
        let dist = Beta::new(alpha, beta).unwrap_or_else(|_| {
            Beta::new(1.0, 1.0).expect("Beta(1, 1) is always a valid distribution")
        });
        dist.sample(rng)
    }

    /// Maps an arbitrary `f64` into the parameter range used by [`Self::sample`].
    fn sanitize(value: f64) -> f64 {
        if value.is_nan() {
            // `f64::clamp` propagates NaN, so handle it explicitly.
            1.0
        } else {
            value.clamp(Self::MIN_PARAM, Self::MAX_PARAM)
        }
    }
}
/// Result of [`ThompsonState::select`] or [`ThompsonState::select_with_priors`].
#[derive(Debug, Clone)]
pub struct ThompsonSelection {
/// Name of the provider whose Beta sample was highest.
pub provider: String,
/// `alpha` used for the winning provider in this selection (the caller's
/// override in `select_with_priors` if one was given, otherwise the stored
/// value or the `Beta(1, 1)` default).
pub alpha: f64,
/// `beta` used for the winning provider in this selection (same sourcing as `alpha`).
pub beta: f64,
/// `true` when the winner's mean `alpha / (alpha + beta)` is at least the
/// mean of every candidate; `false` when some other candidate had a higher
/// mean, i.e. the random draw picked a non-greedy option.
pub exploit: bool,
}
/// Per-provider Beta distributions used for Thompson-sampling provider selection.
///
/// Only `distributions` is (de)serialized; the RNG is skipped by serde and is
/// created lazily on the first selection. Note that the derived `PartialEq`
/// also compares the cached RNG, so two states with identical distributions
/// can compare unequal once one of them has been used for selection.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct ThompsonState {
/// Beta distribution per provider name; providers missing from this map are
/// treated as the `Beta(1, 1)` default.
distributions: HashMap<String, BetaDist>,
/// Lazily seeded sampling RNG; never persisted.
#[serde(skip)]
rng: Option<rand::rngs::SmallRng>,
}
impl ThompsonState {
#[must_use]
pub fn select(&mut self, providers: &[String]) -> Option<ThompsonSelection> {
if providers.is_empty() {
return None;
}
let rng = self
.rng
.get_or_insert_with(|| rand::rngs::SmallRng::from_rng(&mut rand::rng()));
let (best, _) = providers
.iter()
.map(|name| {
let dist = self.distributions.get(name).cloned().unwrap_or_default();
let sample = dist.sample(rng);
(name.clone(), sample)
})
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))?;
let best_dist = self.distributions.get(&best).cloned().unwrap_or_default();
let best_mean = best_dist.alpha / (best_dist.alpha + best_dist.beta);
let exploit = providers.iter().all(|name| {
let dist = self.distributions.get(name).cloned().unwrap_or_default();
best_mean >= dist.alpha / (dist.alpha + dist.beta)
});
Some(ThompsonSelection {
provider: best,
alpha: best_dist.alpha,
beta: best_dist.beta,
exploit,
})
}
#[must_use]
pub fn select_with_priors(
&mut self,
providers: &[String],
overrides: &std::collections::HashMap<String, (f64, f64)>,
) -> Option<ThompsonSelection> {
if providers.is_empty() {
return None;
}
let rng = self
.rng
.get_or_insert_with(|| rand::rngs::SmallRng::from_rng(&mut rand::rng()));
let (best, _) = providers
.iter()
.map(|name| {
let (alpha, beta) = overrides.get(name).copied().unwrap_or_else(|| {
let dist = self.distributions.get(name).cloned().unwrap_or_default();
(dist.alpha, dist.beta)
});
let dist = BetaDist {
alpha: alpha.max(1e-6),
beta: beta.max(1e-6),
};
let sample = dist.sample(rng);
(name.clone(), sample)
})
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))?;
let stored = self.distributions.get(&best).cloned().unwrap_or_default();
let (eff_alpha, eff_beta) = overrides
.get(&best)
.copied()
.unwrap_or((stored.alpha, stored.beta));
let best_mean = eff_alpha / (eff_alpha + eff_beta);
let exploit = providers.iter().all(|name| {
let (a, b) = overrides.get(name).copied().unwrap_or_else(|| {
let d = self.distributions.get(name).cloned().unwrap_or_default();
(d.alpha, d.beta)
});
best_mean >= a / (a + b)
});
Some(ThompsonSelection {
provider: best,
alpha: eff_alpha,
beta: eff_beta,
exploit,
})
}
pub fn update(&mut self, provider: &str, success: bool) {
let dist = self.distributions.entry(provider.to_owned()).or_default();
if success {
dist.alpha += 1.0;
} else {
dist.beta += 1.0;
}
}
#[must_use]
pub fn provider_stats(&self) -> Vec<(String, f64, f64)> {
let mut stats: Vec<(String, f64, f64)> = self
.distributions
.iter()
.map(|(name, dist)| (name.clone(), dist.alpha, dist.beta))
.collect();
stats.sort_by(|a, b| a.0.cmp(&b.0));
stats
}
#[must_use]
pub fn default_path() -> PathBuf {
dirs::home_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join(".zeph")
.join("router_thompson_state.json")
}
#[must_use]
pub fn get_distribution(&self, provider: &str) -> BetaDist {
self.distributions
.get(provider)
.cloned()
.unwrap_or_default()
}
pub fn prune(&mut self, known: &HashSet<String>) {
self.distributions.retain(|k, _| known.contains(k));
}
#[must_use]
pub fn load(path: &Path) -> Self {
let bytes = match std::fs::read(path) {
Ok(b) => b,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
return Self::default();
}
Err(e) => {
tracing::debug!(path = %path.display(), error = %e, "Thompson state file unreadable, using uniform priors");
return Self::default();
}
};
match serde_json::from_slice::<Self>(&bytes) {
Ok(mut state) => {
for dist in state.distributions.values_mut() {
dist.alpha = dist.alpha.clamp(0.5, 1e9);
dist.beta = dist.beta.clamp(0.5, 1e9);
if !dist.alpha.is_finite() {
dist.alpha = 1.0;
}
if !dist.beta.is_finite() {
dist.beta = 1.0;
}
}
state
}
Err(e) => {
tracing::warn!(
path = %path.display(),
error = %e,
"Thompson state file is corrupt; resetting to uniform priors"
);
Self::default()
}
}
}
pub fn save(&self, path: &Path) -> std::io::Result<()> {
let json = serde_json::to_vec(self).map_err(|e| std::io::Error::other(e.to_string()))?;
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
let tmp = path.with_extension("tmp");
#[cfg(unix)]
{
use std::os::unix::fs::OpenOptionsExt;
std::fs::OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.mode(0o600)
.open(&tmp)?
.write_all(&json)?;
}
#[cfg(not(unix))]
std::fs::write(&tmp, &json)?;
std::fs::rename(&tmp, path)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn select_empty_providers_returns_none() {
        let mut state = ThompsonState::default();
        assert!(state.select(&[]).is_none());
    }

    #[test]
    fn select_single_provider_returns_it() {
        let mut state = ThompsonState::default();
        let result = state.select(&["ollama".to_owned()]);
        assert_eq!(result.map(|s| s.provider).as_deref(), Some("ollama"));
    }

    #[test]
    fn select_single_provider_is_exploit() {
        let mut state = ThompsonState::default();
        let selection = state
            .select(&["only".to_owned()])
            .expect("non-empty providers must yield a selection");
        assert!(selection.exploit, "the only candidate always has the best mean");
    }

    #[test]
    fn select_returns_one_of_the_providers() {
        let mut state = ThompsonState::default();
        let providers = vec!["a".to_owned(), "b".to_owned(), "c".to_owned()];
        let selected = state.select(&providers).unwrap().provider;
        assert!(providers.contains(&selected));
    }

    #[test]
    fn update_success_increases_alpha() {
        let mut state = ThompsonState::default();
        state.update("p", true);
        let dist = &state.distributions["p"];
        assert!((dist.alpha - 2.0).abs() < f64::EPSILON);
        assert!((dist.beta - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn update_failure_increases_beta() {
        let mut state = ThompsonState::default();
        state.update("p", false);
        let dist = &state.distributions["p"];
        assert!((dist.alpha - 1.0).abs() < f64::EPSILON);
        assert!((dist.beta - 2.0).abs() < f64::EPSILON);
    }

    #[test]
    fn beta_dist_sample_in_range() {
        let dist = BetaDist::default();
        let mut rng = rand::rngs::SmallRng::seed_from_u64(42);
        for _ in 0..100 {
            let v = dist.sample(&mut rng);
            assert!((0.0..=1.0).contains(&v), "sample {v} out of [0,1]");
        }
    }

    #[test]
    fn save_and_load_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("thompson.json");
        let mut state = ThompsonState::default();
        state.update("provider_a", true);
        state.update("provider_a", true);
        state.update("provider_b", false);
        state.save(&path).unwrap();
        let loaded = ThompsonState::load(&path);
        assert!((loaded.distributions["provider_a"].alpha - 3.0).abs() < f64::EPSILON);
        assert!((loaded.distributions["provider_a"].beta - 1.0).abs() < f64::EPSILON);
        assert!((loaded.distributions["provider_b"].beta - 2.0).abs() < f64::EPSILON);
    }

    #[test]
    fn load_missing_file_returns_default() {
        // Use a private temp dir so a stray file at a fixed path can't break the test.
        let dir = tempfile::tempdir().unwrap();
        let state = ThompsonState::load(&dir.path().join("does-not-exist.json"));
        assert!(state.distributions.is_empty());
    }

    #[test]
    fn load_corrupt_file_returns_default() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("corrupt.json");
        std::fs::write(&path, b"not valid json {{{{").unwrap();
        let state = ThompsonState::load(&path);
        assert!(state.distributions.is_empty());
    }

    #[test]
    fn prune_removes_stale_entries() {
        let mut state = ThompsonState::default();
        state.update("provider_a", true);
        state.update("provider_b", false);
        state.update("provider_c", true);
        let known: HashSet<String> = ["provider_a".to_owned(), "provider_c".to_owned()]
            .into_iter()
            .collect();
        state.prune(&known);
        assert!(state.distributions.contains_key("provider_a"));
        assert!(!state.distributions.contains_key("provider_b"));
        assert!(state.distributions.contains_key("provider_c"));
    }

    #[test]
    fn provider_stats_returns_sorted_entries() {
        let mut state = ThompsonState::default();
        state.update("z_provider", true);
        state.update("a_provider", false);
        state.update("m_provider", true);
        let provider_stats = state.provider_stats();
        assert_eq!(provider_stats.len(), 3);
        assert_eq!(provider_stats[0].0, "a_provider");
        assert_eq!(provider_stats[1].0, "m_provider");
        assert_eq!(provider_stats[2].0, "z_provider");
    }

    #[test]
    fn high_alpha_provider_selected_disproportionately() {
        let mut state = ThompsonState::default();
        for _ in 0..50 {
            state.update("provider_a", true);
            state.update("provider_b", false);
        }
        let providers = vec!["provider_a".to_owned(), "provider_b".to_owned()];
        let trials: u32 = 1000;
        let mut a_wins: u32 = 0;
        for _ in 0..trials {
            if state
                .select(&providers)
                .is_some_and(|s| s.provider == "provider_a")
            {
                a_wins += 1;
            }
        }
        let ratio = f64::from(a_wins) / f64::from(trials);
        assert!(
            ratio > 0.90,
            "provider_a should be selected >90% of the time, got {ratio:.2}"
        );
    }

    #[test]
    fn select_is_mut_compatible_with_repeated_calls() {
        let mut state = ThompsonState::default();
        state.update("a", true);
        state.update("b", false);
        let providers = vec!["a".to_owned(), "b".to_owned()];
        for _ in 0..10 {
            let Some(selection) = state.select(&providers) else {
                panic!("select must return Some for non-empty providers");
            };
            assert!(providers.contains(&selection.provider));
        }
    }

    #[test]
    fn select_with_priors_empty_providers_returns_none() {
        let mut state = ThompsonState::default();
        assert!(state.select_with_priors(&[], &HashMap::new()).is_none());
    }

    #[test]
    fn select_with_priors_overrides_stored_distribution() {
        let mut state = ThompsonState::default();
        for _ in 0..50 {
            state.update("a", true);
            state.update("b", false);
        }
        let providers = vec!["a".to_owned(), "b".to_owned()];
        // Overrides invert the stored evidence: "b" becomes the clear favourite.
        let overrides = HashMap::from([
            ("a".to_owned(), (1.0, 1000.0)),
            ("b".to_owned(), (1000.0, 1.0)),
        ]);
        for _ in 0..20 {
            let selection = state
                .select_with_priors(&providers, &overrides)
                .expect("non-empty providers must yield a selection");
            assert_eq!(selection.provider, "b");
            assert!((selection.alpha - 1000.0).abs() < f64::EPSILON);
            assert!((selection.beta - 1.0).abs() < f64::EPSILON);
            assert!(selection.exploit);
        }
        // Overrides must not leak into the stored state.
        let stored = state.get_distribution("a");
        assert!((stored.alpha - 51.0).abs() < f64::EPSILON);
        assert!((stored.beta - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn save_leaves_no_tmp_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("state.json");
        let tmp = path.with_extension("tmp");
        let mut state = ThompsonState::default();
        state.update("p", true);
        state.save(&path).unwrap();
        assert!(path.exists(), "state file must exist after save");
        assert!(
            !tmp.exists(),
            "tmp file must be cleaned up after atomic rename"
        );
    }

    #[cfg(unix)]
    #[test]
    fn save_creates_owner_only_file() {
        use std::os::unix::fs::PermissionsExt as _;
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("state.json");
        ThompsonState::default().save(&path).unwrap();
        let mode = std::fs::metadata(&path).unwrap().permissions().mode();
        assert_eq!(
            mode & 0o077,
            0,
            "state file must not be group/world accessible, mode {mode:o}"
        );
    }

    #[test]
    fn load_clamps_out_of_range_values() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("state.json");
        std::fs::write(
            &path,
            br#"{"distributions":{"p":{"alpha":-1.0,"beta":2000000000000.0}}}"#,
        )
        .unwrap();
        let state = ThompsonState::load(&path);
        let dist = &state.distributions["p"];
        assert!(
            (dist.alpha - 0.5).abs() < f64::EPSILON,
            "alpha must be clamped to exactly 0.5, got {}",
            dist.alpha
        );
        assert!(
            (dist.beta - 1e9).abs() < f64::EPSILON,
            "beta must be clamped to exactly 1e9, got {}",
            dist.beta
        );
    }
}