use std::collections::HashMap;
use std::fmt;
/// Default set of retry wait candidates, in milliseconds, used by `RetryController::new`.
pub const DEFAULT_CANDIDATE_WAITS_MS: [u64; 8] = [0_u64, 1, 2, 5, 10, 20, 50, 100];
/// Exclusive upper bound on the value returned by `ContentionBucketKey::bucket_index`.
pub const MAX_CONTENTION_BUCKETS: usize = 16;
/// Conflict count at or above which `RetryController::new` flags a transaction as starving.
pub const DEFAULT_STARVATION_THRESHOLD: u32 = 5;
/// Maximum number of distinct transactions whose conflict counts are tracked at once.
const MAX_TRACKED_CONFLICT_TXNS: usize = 1 << 12;
/// Outcome of a retry decision.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryAction {
    /// Give up immediately instead of retrying.
    FailNow,
    /// Retry after waiting `wait_ms` milliseconds.
    RetryAfter { wait_ms: u64 },
}

impl fmt::Display for RetryAction {
    /// Renders `FailNow` or `RetryAfter(<n>ms)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            RetryAction::RetryAfter { wait_ms } => write!(f, "RetryAfter({}ms)", wait_ms),
            RetryAction::FailNow => f.write_str("FailNow"),
        }
    }
}
/// Beta posterior over the success probability of one retry arm.
///
/// Each update first multiplies both pseudo-counts by a forgetting factor, so
/// older observations gradually lose weight relative to recent ones.
#[derive(Debug, Clone, Copy)]
pub struct BetaPosterior {
    /// Success pseudo-count (the Beta distribution's `alpha`).
    pub alpha: f64,
    /// Failure pseudo-count (the Beta distribution's `beta`).
    pub beta: f64,
}

impl Default for BetaPosterior {
    /// The uniform prior `Beta(1, 1)`.
    fn default() -> Self {
        Self {
            alpha: 1.0,
            beta: 1.0,
        }
    }
}

impl BetaPosterior {
    /// Forgetting factor applied by [`BetaPosterior::observe`] before each update.
    pub const DEFAULT_DECAY: f64 = 0.95;

    /// Builds a posterior from explicit pseudo-counts (not validated).
    #[must_use]
    pub fn new(alpha: f64, beta: f64) -> Self {
        Self { alpha, beta }
    }

    /// Records one outcome using [`Self::DEFAULT_DECAY`] as the forgetting factor.
    pub fn observe(&mut self, success: bool) {
        self.observe_with_decay(success, Self::DEFAULT_DECAY);
    }

    /// Records one outcome after scaling both pseudo-counts by `decay`.
    ///
    /// `decay == 1.0` gives plain Bayesian counting; values below 1.0 make
    /// the posterior adapt faster to recent outcomes.
    pub fn observe_with_decay(&mut self, success: bool, decay: f64) {
        self.alpha *= decay;
        self.beta *= decay;
        if success {
            self.alpha += 1.0;
        } else {
            self.beta += 1.0;
        }
    }

    /// Posterior mean `alpha / (alpha + beta)`; NaN if both pseudo-counts are zero.
    #[must_use]
    pub fn mean(&self) -> f64 {
        self.alpha / (self.alpha + self.beta)
    }
}
/// Quantized description of the current contention level.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ContentionBucketKey {
    /// Active-transaction count clamped to `1..=16` by `from_raw`.
    pub n_active_quantized: u8,
    /// `floor(m2_hat * 15)` clamped to `0..=15` by `from_raw`.
    pub m2_hat_quantized: u8,
}

impl ContentionBucketKey {
    /// Fine levels per dimension produced by `from_raw` (values `0..LEVELS` after shifting).
    const LEVELS: u8 = 16;
    /// Coarse bands per dimension; `BANDS * BANDS` (16) buckets in total.
    const BANDS: u8 = 4;

    /// Quantizes raw contention measurements into a key.
    ///
    /// `n_active` is clamped to `1..=16`. `m2_hat` is scaled by 15, floored and
    /// clamped to `0..=15`. The final `as u8` maps NaN to 0.
    #[must_use]
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    pub fn from_raw(n_active: u32, m2_hat: f64) -> Self {
        let n_clamped = n_active.clamp(1, 16);
        let n_quantized = u8::try_from(n_clamped).unwrap_or(16);
        let m2_q = (m2_hat * 15.0).floor().clamp(0.0, 15.0) as u8;
        Self {
            n_active_quantized: n_quantized,
            m2_hat_quantized: m2_q,
        }
    }

    /// Maps the key to a bucket in `0..16`, using both dimensions.
    ///
    /// Each dimension is reduced to 4 coarse bands and the bands form a 4×4
    /// grid. A `(n - 1) * 16 + m2` layout reduced modulo 16 would collapse to
    /// `m2` alone and silently discard `n_active`.
    #[must_use]
    pub fn bucket_index(&self) -> u8 {
        let width = Self::LEVELS / Self::BANDS;
        // Clamp so hand-constructed keys with out-of-range fields still land in range.
        let n_level = self.n_active_quantized.clamp(1, Self::LEVELS) - 1;
        let m2_level = self.m2_hat_quantized.min(Self::LEVELS - 1);
        (n_level / width) * Self::BANDS + m2_level / width
    }
}
/// Cost weights used when scoring retry decisions.
#[derive(Debug, Clone, Copy)]
pub struct RetryCostParams {
    /// Cost charged when the transaction fails. It is paid in full by `FailNow`,
    /// and weighted by the failure probability for a retry.
    pub c_fail: f64,
    /// Fixed overhead of one retry attempt, added on top of the wait time.
    pub c_try: f64,
}

impl Default for RetryCostParams {
    /// `c_fail = 100`, `c_try = 1`.
    fn default() -> Self {
        let c_fail = 100.0;
        let c_try = 1.0;
        Self { c_fail, c_try }
    }
}

/// Expected loss of giving up immediately: the full failure cost.
#[must_use]
pub fn expected_loss_failnow(params: &RetryCostParams) -> f64 {
    let RetryCostParams { c_fail, .. } = *params;
    c_fail
}

/// Expected loss of retrying after `wait_ms`: `wait_ms + c_try + (1 - p_succ) * c_fail`.
///
/// The product and sum are evaluated with a single fused multiply-add.
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub fn expected_loss_retry(wait_ms: u64, p_succ: f64, params: &RetryCostParams) -> f64 {
    let p_fail = 1.0 - p_succ;
    let fixed_cost = wait_ms as f64 + params.c_try;
    p_fail.mul_add(params.c_fail, fixed_cost)
}
/// Exponential-hazard model: a conflict clears at constant rate `lambda` per millisecond.
#[derive(Debug, Clone, Copy)]
pub struct HazardModelParams {
    /// Hazard rate per millisecond.
    pub lambda: f64,
}

impl HazardModelParams {
    /// Builds a model with the given hazard rate (not validated).
    #[must_use]
    pub fn new(lambda: f64) -> Self {
        Self { lambda }
    }

    /// Probability that the conflict has cleared after `wait_ms`: `1 - exp(-lambda * wait_ms)`.
    #[must_use]
    pub fn p_succ(&self, wait_ms: f64) -> f64 {
        let survival = (-self.lambda * wait_ms).exp();
        1.0 - survival
    }

    /// Closed-form wait `ln(lambda * c_fail) / lambda`.
    ///
    /// This is the stationary point of `t + c_fail * exp(-lambda * t)`, the
    /// retry loss with this model's `p_succ`, ignoring `c_try`. Returns 0 when
    /// `lambda * c_fail <= 1`, where waiting never pays off.
    #[must_use]
    pub fn optimal_wait_ms(&self, c_fail: f64) -> f64 {
        let product = self.lambda * c_fail;
        if product <= 1.0 {
            return 0.0;
        }
        product.ln() / self.lambda
    }

    /// Snaps the optimal wait, clamped to `budget_ms`, to the nearest candidate within budget.
    ///
    /// On ties the earliest candidate wins. Returns 0 when no candidate fits the
    /// budget or when the target is NaN.
    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn optimal_wait_clamped(&self, c_fail: f64, budget_ms: u64, candidates: &[u64]) -> u64 {
        let target = self.optimal_wait_ms(c_fail).clamp(0.0, budget_ms as f64);
        let (nearest, _) = candidates
            .iter()
            .copied()
            .filter(|&t| t <= budget_ms)
            .fold((0_u64, f64::MAX), |(best, best_dist), t| {
                let dist = (target - t as f64).abs();
                if dist < best_dist {
                    (t, dist)
                } else {
                    (best, best_dist)
                }
            });
        nearest
    }
}
/// One recorded retry decision, with the per-candidate inputs that produced it.
#[derive(Debug, Clone)]
pub struct RetryEvidenceEntry {
    /// Transaction the decision was made for.
    pub txn_id: u64,
    /// Candidate waits (ms) that were scored; empty when none were eligible.
    pub candidate_set: Vec<u64>,
    /// Per-candidate success estimate, index-aligned with `candidate_set`.
    pub p_hat: Vec<f64>,
    /// Per-candidate expected loss of retrying, index-aligned with `candidate_set`.
    pub expected_losses: Vec<f64>,
    /// Action that was returned to the caller.
    pub chosen_action: RetryAction,
    /// Expected loss of `FailNow` at decision time.
    pub expected_loss_failnow: f64,
    /// Regime identifier supplied with the decision.
    pub regime_id: u64,
    /// Optional contention label supplied with the decision.
    pub bucket_key: Option<ContentionBucketKey>,
    /// Posterior `alpha` per candidate at decision time.
    pub alpha_values: Vec<f64>,
    /// Posterior `beta` per candidate at decision time.
    pub beta_values: Vec<f64>,
    /// Whether the transaction had reached the starvation threshold.
    pub starvation_escalation: bool,
}

impl RetryEvidenceEntry {
    /// `true` if at least one candidate was scored and every per-candidate
    /// vector has the same length as `candidate_set`.
    #[must_use]
    pub fn is_complete(&self) -> bool {
        let n = self.candidate_set.len();
        if n == 0 {
            return false;
        }
        [
            self.p_hat.len(),
            self.expected_losses.len(),
            self.alpha_values.len(),
            self.beta_values.len(),
        ]
        .iter()
        .all(|&len| len == n)
    }
}
pub struct RetryController {
pub params: RetryCostParams,
candidates: Vec<u64>,
posteriors: Vec<BetaPosterior>,
starvation_threshold: u32,
ledger: Vec<RetryEvidenceEntry>,
conflict_counts: HashMap<u64, u32>,
#[allow(dead_code)]
current_regime_id: u64,
}
impl RetryController {
#[must_use]
pub fn new(params: RetryCostParams) -> Self {
let candidates = DEFAULT_CANDIDATE_WAITS_MS.to_vec();
let posteriors = vec![BetaPosterior::default(); candidates.len()];
Self {
params,
candidates,
posteriors,
starvation_threshold: DEFAULT_STARVATION_THRESHOLD,
ledger: Vec::new(),
conflict_counts: HashMap::new(),
current_regime_id: 0,
}
}
#[must_use]
pub fn with_candidates(
params: RetryCostParams,
candidates: Vec<u64>,
starvation_threshold: u32,
) -> Self {
let posteriors = vec![BetaPosterior::default(); candidates.len()];
Self {
params,
candidates,
posteriors,
starvation_threshold,
ledger: Vec::new(),
conflict_counts: HashMap::new(),
current_regime_id: 0,
}
}
#[allow(clippy::cast_precision_loss)]
pub fn decide(
&mut self,
txn_id: u64,
budget_ms: u64,
regime_id: u64,
bucket_key: Option<ContentionBucketKey>,
) -> RetryAction {
if regime_id != self.current_regime_id {
self.current_regime_id = regime_id;
for p in &mut self.posteriors {
p.alpha = 1.0;
p.beta = 1.0;
}
}
let conflict_count = self.increment_conflict(txn_id);
let starvation_escalation = conflict_count >= self.starvation_threshold;
if budget_ms == 0 {
let entry = self.build_evidence(
txn_id,
&[],
&[],
&[],
RetryAction::FailNow,
regime_id,
bucket_key,
starvation_escalation,
);
self.ledger.push(entry);
self.clear_conflict(txn_id);
return RetryAction::FailNow;
}
let eligible: Vec<u64> = self
.candidates
.iter()
.copied()
.filter(|&t| t <= budget_ms)
.collect();
if eligible.is_empty() {
let entry = self.build_evidence(
txn_id,
&[],
&[],
&[],
RetryAction::FailNow,
regime_id,
bucket_key,
starvation_escalation,
);
self.ledger.push(entry);
self.clear_conflict(txn_id);
return RetryAction::FailNow;
}
let el_fail = expected_loss_failnow(&self.params);
let mut p_hats = Vec::with_capacity(eligible.len());
let mut losses = Vec::with_capacity(eligible.len());
for &t in &eligible {
let idx = self.candidate_index(t);
let posterior = &self.posteriors[idx];
let p_gittins = gittins_index_approx(posterior.alpha, posterior.beta).min(1.0);
let el = expected_loss_retry(t, p_gittins, &self.params);
p_hats.push(p_gittins);
losses.push(el);
}
let mut best_action = RetryAction::FailNow;
let mut best_loss = el_fail;
for (i, &el) in losses.iter().enumerate() {
if el < best_loss {
best_loss = el;
best_action = RetryAction::RetryAfter {
wait_ms: eligible[i],
};
}
}
let entry = self.build_evidence(
txn_id,
&eligible,
&p_hats,
&losses,
best_action,
regime_id,
bucket_key,
starvation_escalation,
);
self.ledger.push(entry);
if best_action == RetryAction::FailNow {
self.clear_conflict(txn_id);
}
best_action
}
pub fn observe(&mut self, wait_ms: u64, success: bool) {
let idx = self.candidate_index(wait_ms);
self.posteriors[idx].observe(success);
}
#[allow(clippy::cast_precision_loss)]
pub fn decide_with_cx(
&mut self,
txn_id: u64,
budget_ms: u64,
regime_id: u64,
bucket_key: Option<ContentionBucketKey>,
cx_cancelled: bool,
) -> RetryAction {
if cx_cancelled {
return RetryAction::FailNow;
}
self.decide(txn_id, budget_ms, regime_id, bucket_key)
}
pub fn clear_conflict(&mut self, txn_id: u64) {
self.conflict_counts.remove(&txn_id);
}
#[must_use]
pub fn ledger(&self) -> &[RetryEvidenceEntry] {
&self.ledger
}
#[must_use]
pub fn posterior(&self, wait_ms: u64) -> &BetaPosterior {
let idx = self.candidate_index(wait_ms);
&self.posteriors[idx]
}
#[must_use]
pub fn tracked_conflicts(&self) -> usize {
self.conflict_counts.len()
}
#[must_use]
pub fn is_starvation_escalated(&self, txn_id: u64) -> bool {
self.conflict_counts
.get(&txn_id)
.is_some_and(|count| *count >= self.starvation_threshold)
}
fn candidate_index(&self, wait_ms: u64) -> usize {
debug_assert!(
!self.candidates.is_empty(),
"retry candidate set must not be empty"
);
if let Some(exact) = self
.candidates
.iter()
.position(|&candidate| candidate == wait_ms)
{
return exact;
}
self.candidates
.iter()
.enumerate()
.min_by_key(|(_, candidate)| candidate.abs_diff(wait_ms))
.map_or(0, |(idx, _)| idx)
}
fn increment_conflict(&mut self, txn_id: u64) -> u32 {
if self.conflict_counts.len() >= MAX_TRACKED_CONFLICT_TXNS
&& !self.conflict_counts.contains_key(&txn_id)
{
let evict_txn = self
.conflict_counts
.iter()
.min_by_key(|(id, count)| (**count, **id))
.map(|(id, _)| *id);
if let Some(evict_txn) = evict_txn {
self.conflict_counts.remove(&evict_txn);
}
}
let count = self.conflict_counts.entry(txn_id).or_insert(0);
*count = count.saturating_add(1);
*count
}
#[allow(clippy::too_many_arguments)]
fn build_evidence(
&self,
txn_id: u64,
eligible: &[u64],
p_hats: &[f64],
losses: &[f64],
chosen_action: RetryAction,
regime_id: u64,
bucket_key: Option<ContentionBucketKey>,
starvation_escalation: bool,
) -> RetryEvidenceEntry {
let alphas: Vec<f64> = eligible
.iter()
.map(|&t| self.posteriors[self.candidate_index(t)].alpha)
.collect();
let betas: Vec<f64> = eligible
.iter()
.map(|&t| self.posteriors[self.candidate_index(t)].beta)
.collect();
RetryEvidenceEntry {
txn_id,
candidate_set: eligible.to_vec(),
p_hat: p_hats.to_vec(),
expected_losses: losses.to_vec(),
chosen_action,
expected_loss_failnow: expected_loss_failnow(&self.params),
regime_id,
bucket_key,
alpha_values: alphas,
beta_values: betas,
starvation_escalation,
}
}
}
/// Optimistic success index for a `Beta(alpha, beta)` arm.
///
/// Returns the posterior mean plus `sqrt(2 * variance)`, where
/// `variance = alpha * beta / ((alpha + beta)^2 * (alpha + beta + 1))`.
/// The result is not capped and can exceed 1.
#[must_use]
pub fn gittins_index_approx(alpha: f64, beta: f64) -> f64 {
    let n = alpha + beta;
    let posterior_mean = alpha / n;
    let doubled_variance = 2.0 * alpha * beta / (n * n * (n + 1.0));
    posterior_mean + doubled_variance.sqrt()
}

/// Success-probability threshold `max(0, 1 - c_try / c_fail)`; 0 when `c_fail <= 0`.
#[must_use]
pub fn gittins_threshold(c_try: f64, c_fail: f64) -> f64 {
    if c_fail <= 0.0 {
        0.0
    } else {
        f64::max(1.0 - c_try / c_fail, 0.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_beta_posterior_default() {
        let bp = BetaPosterior::default();
        assert!((bp.mean() - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_beta_posterior_observe() {
        let mut bp = BetaPosterior::default();
        bp.observe(true);
        assert!((bp.mean() - 1.95 / 2.90).abs() < 1e-10);
        bp.observe(false);
        assert!((bp.mean() - 1.8525 / 3.755).abs() < 1e-10);
    }

    #[test]
    fn test_expected_loss_failnow_equals_cfail() {
        let params = RetryCostParams {
            c_fail: 42.0,
            c_try: 1.0,
        };
        assert!((expected_loss_failnow(&params) - 42.0).abs() < 1e-10);
    }

    #[test]
    fn test_expected_loss_retry_formula() {
        let params = RetryCostParams {
            c_fail: 100.0,
            c_try: 5.0,
        };
        let el = expected_loss_retry(10, 0.8, &params);
        assert!((el - 35.0).abs() < 1e-10);
    }

    #[test]
    fn test_hazard_model_optimal() {
        let hm = HazardModelParams::new(0.5);
        let t_star = hm.optimal_wait_ms(100.0);
        let expected = 2.0_f64.mul_add(50.0_f64.ln(), 0.0);
        assert!((t_star - expected).abs() < 0.01);
    }

    #[test]
    fn test_hazard_model_no_retry() {
        let hm = HazardModelParams::new(0.01);
        assert!((hm.optimal_wait_ms(50.0)).abs() < 1e-10);
    }

    #[test]
    fn test_contention_bucket_deterministic() {
        let k1 = ContentionBucketKey::from_raw(4, 0.025);
        let k2 = ContentionBucketKey::from_raw(4, 0.025);
        assert_eq!(k1.bucket_index(), k2.bucket_index());
    }

    #[test]
    fn test_contention_bucket_bounded() {
        for n in 0..=20 {
            for m2_step in 0..=20 {
                let m2 = f64::from(m2_step) / 20.0;
                let k = ContentionBucketKey::from_raw(n, m2);
                assert!(
                    usize::from(k.bucket_index()) < MAX_CONTENTION_BUCKETS,
                    "bucket_index={} for n={n} m2={m2}",
                    k.bucket_index()
                );
            }
        }
    }

    #[test]
    fn test_controller_budget_exhausted() {
        let mut ctrl = RetryController::new(RetryCostParams::default());
        let action = ctrl.decide(1, 0, 0, None);
        assert_eq!(action, RetryAction::FailNow);
    }

    #[test]
    fn test_controller_argmin() {
        let params = RetryCostParams {
            c_fail: 100.0,
            c_try: 1.0,
        };
        let mut ctrl = RetryController::new(params);
        for _ in 0..50 {
            ctrl.observe(5, true);
        }
        for _ in 0..2 {
            ctrl.observe(5, false);
        }
        let action = ctrl.decide(1, 100, 0, None);
        assert!(
            matches!(action, RetryAction::RetryAfter { wait_ms: 5 }),
            "Expected RetryAfter(5ms), got {action:?}"
        );
    }

    #[test]
    fn test_observe_non_candidate_wait_uses_nearest_arm() {
        let params = RetryCostParams::default();
        let mut ctrl = RetryController::with_candidates(params, vec![0, 5, 10], 3);
        let alpha_zero_before = ctrl.posterior(0).alpha;
        let alpha_five_before = ctrl.posterior(5).alpha;
        let alpha_ten_before = ctrl.posterior(10).alpha;
        ctrl.observe(6, true);
        assert!((ctrl.posterior(0).alpha - alpha_zero_before).abs() < 1e-10);
        assert!((ctrl.posterior(10).alpha - alpha_ten_before).abs() < 1e-10);
        assert!((ctrl.posterior(5).alpha - (alpha_five_before * 0.95 + 1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_evidence_ledger_complete() {
        let mut ctrl = RetryController::new(RetryCostParams::default());
        let _ = ctrl.decide(42, 50, 7, Some(ContentionBucketKey::from_raw(4, 0.1)));
        assert!(!ctrl.ledger().is_empty());
        let entry = &ctrl.ledger()[0];
        assert!(entry.is_complete() || entry.candidate_set.is_empty());
        assert_eq!(entry.txn_id, 42);
        assert_eq!(entry.regime_id, 7);
    }

    #[test]
    fn test_starvation_escalation() {
        let params = RetryCostParams::default();
        let mut ctrl =
            RetryController::with_candidates(params, vec![0, 5, 10], DEFAULT_STARVATION_THRESHOLD);
        for i in 0..DEFAULT_STARVATION_THRESHOLD {
            let _ = ctrl.decide(99, 100, 0, None);
            assert_eq!(
                ctrl.is_starvation_escalated(99),
                i + 1 >= DEFAULT_STARVATION_THRESHOLD
            );
        }
        assert!(ctrl.is_starvation_escalated(99));
        assert!(ctrl.ledger().last().unwrap().starvation_escalation);
    }

    #[test]
    fn test_fail_now_clears_conflict_tracking() {
        let mut ctrl = RetryController::new(RetryCostParams::default());
        let action = ctrl.decide(77, 0, 0, None);
        assert_eq!(action, RetryAction::FailNow);
        assert_eq!(ctrl.tracked_conflicts(), 0);
    }

    #[test]
    fn test_conflict_tracking_is_bounded() {
        let mut ctrl = RetryController::new(RetryCostParams::default());
        for txn_id in 1..=(MAX_TRACKED_CONFLICT_TXNS + 256) {
            let _ = ctrl.decide(u64::try_from(txn_id).unwrap(), 100, 0, None);
        }
        assert!(ctrl.tracked_conflicts() <= MAX_TRACKED_CONFLICT_TXNS);
    }

    #[test]
    fn test_gittins_index_basic() {
        let gi = gittins_index_approx(1.0, 1.0);
        assert!(
            gi > 0.5,
            "Gittins index should exceed mean for uniform prior"
        );
        let gi_strong = gittins_index_approx(100.0, 1.0);
        assert!(gi_strong > 0.95);
    }

    #[test]
    fn test_beta_bernoulli_posterior_mean() {
        let mut bp = BetaPosterior::new(1.0, 1.0);
        bp.observe(true);
        bp.observe(true);
        bp.observe(true);
        bp.observe(false);
        let expected = 3.52438125 / (3.52438125 + 1.81450625);
        assert!(
            (bp.mean() - expected).abs() < 1e-10,
            "bead_id=bd-1p75 case=posterior_mean p={} expected={expected}",
            bp.mean()
        );
    }

    #[test]
    fn test_budget_clamp() {
        let mut ctrl = RetryController::new(RetryCostParams::default());
        let _ = ctrl.decide(1, 3, 0, None);
        let entry = ctrl.ledger().last().expect("should have entry");
        for &t in &entry.candidate_set {
            assert!(
                t <= 3,
                "bead_id=bd-1p75 case=budget_clamp candidate={t} budget=3"
            );
        }
        assert!(
            !entry.candidate_set.contains(&5),
            "bead_id=bd-1p75 case=budget_clamp 5ms should be excluded"
        );
    }

    #[test]
    fn test_hazard_model_clamp_budget() {
        let hm = HazardModelParams::new(0.001);
        let clamped = hm.optimal_wait_clamped(10_000.0, 50, &DEFAULT_CANDIDATE_WAITS_MS);
        assert_eq!(
            clamped, 50,
            "bead_id=bd-1p75 case=hazard_clamp_budget clamped={clamped}"
        );
    }

    #[test]
    fn test_no_priority_for_retries() {
        let params = RetryCostParams::default();
        let mut ctrl_fresh = RetryController::new(params);
        let mut ctrl_retried = RetryController::new(RetryCostParams::default());
        for _ in 0..10 {
            ctrl_fresh.observe(5, true);
            ctrl_retried.observe(5, true);
        }
        let action_fresh = ctrl_fresh.decide(1, 100, 0, None);
        let action_retried = ctrl_retried.decide(2, 100, 0, None);
        assert_eq!(
            action_fresh, action_retried,
            "bead_id=bd-1p75 case=no_priority fresh={action_fresh:?} retried={action_retried:?}"
        );
    }

    #[test]
    fn test_evidence_ledger_starvation() {
        let params = RetryCostParams::default();
        let mut ctrl = RetryController::with_candidates(params, vec![0, 5, 10], 3);
        for _ in 0..3 {
            let _ = ctrl.decide(77, 100, 0, None);
        }
        let last = ctrl.ledger().last().expect("should have entry");
        assert!(
            last.starvation_escalation,
            "bead_id=bd-1p75 case=ledger_starvation expected=true"
        );
        assert_eq!(
            last.txn_id, 77,
            "bead_id=bd-1p75 case=ledger_starvation_txn"
        );
    }

    #[test]
    fn test_gittins_index_threshold() {
        let threshold = gittins_threshold(1.0, 100.0);
        assert!(
            (threshold - 0.99).abs() < 1e-10,
            "bead_id=bd-1p75 case=gittins_threshold threshold={threshold}"
        );
        let gi = gittins_index_approx(100.0, 1.0);
        assert!(
            gi > threshold,
            "bead_id=bd-1p75 case=gittins_retry gi={gi} threshold={threshold}"
        );
        let gi_low = gittins_index_approx(1.0, 100.0);
        assert!(
            gi_low < threshold,
            "bead_id=bd-1p75 case=gittins_no_retry gi={gi_low} threshold={threshold}"
        );
    }

    #[test]
    fn test_cx_deadline_respected() {
        let mut ctrl = RetryController::new(RetryCostParams::default());
        let action = ctrl.decide_with_cx(1, 1000, 0, None, true);
        assert_eq!(
            action,
            RetryAction::FailNow,
            "bead_id=bd-1p75 case=cx_deadline"
        );
        let action2 = ctrl.decide_with_cx(2, 2, 0, None, false);
        let entry = ctrl.ledger().last().expect("should have entry");
        for &t in &entry.candidate_set {
            assert!(
                t <= 2,
                "bead_id=bd-1p75 case=cx_effective_budget candidate={t}"
            );
        }
        assert!(
            matches!(
                action2,
                RetryAction::FailNow | RetryAction::RetryAfter { .. }
            ),
            "bead_id=bd-1p75 case=cx_effective_budget action={action2:?}"
        );
    }
}