use crate::{coverage_pick_under_sampled, CoverageConfig};
use crate::{novelty_pick_unseen, stable_hash64, LatencyGuardrailConfig};
/// Order in which the policy pipeline runs its two stages: the novelty/coverage
/// prepick and the latency guardrail (see `policy_plan_generic`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum PipelineOrder {
    /// Prepick from the full arm list first, then apply the (non-strict)
    /// guardrail only to the arms that were not prepicked. Prepicked arms
    /// therefore bypass the guardrail.
    #[default]
    NoveltyFirst,
    /// Apply the strict guardrail first, then prepick only among the arms that
    /// survived it.
    GuardrailFirst,
}
/// Result of the planning stage produced by `policy_plan_generic`.
#[derive(Debug, Clone)]
pub struct PolicyPlan {
    /// Arms picked up front by the novelty/coverage prepick. `policy_fill_generic`
    /// puts them first in its final selection.
    pub prechosen: Vec<String>,
    /// Non-prechosen arms returned by the latency guardrail stage. Under
    /// `NoveltyFirst` this may be the guardrail's unfiltered fallback list.
    pub eligible: Vec<String>,
    /// Set when the guardrail asked to stop rather than fall back.
    pub stop_early: bool,
}
/// Final outcome of `policy_fill_generic`: the selected arms plus how they were
/// obtained.
#[derive(Debug, Clone)]
pub struct PolicyFill {
    /// Selected arms: the plan's prechosen arms first, then picks from the
    /// caller's algorithm. Never more than the requested `k`.
    pub chosen: Vec<String>,
    /// The plan this fill was built from.
    pub plan: PolicyPlan,
    /// Candidate pool handed to the caller's pick step. Empty when the fill
    /// returned before that step.
    pub eligible_used: Vec<String>,
    /// True when the pool had to be replaced by every not-yet-chosen arm because
    /// the guardrail left nothing eligible.
    pub fallback_used: bool,
    /// True when filling halted before reaching `k` because of a stop request or
    /// the guard configuration.
    pub stopped_early: bool,
}
pub fn guardrail_filter_observed<F>(
seed: u64,
arms: &[String],
guard: LatencyGuardrailConfig,
mut observed: F,
) -> (Vec<String>, bool)
where
F: FnMut(&str) -> (u64, f64),
{
let Some(max_ms) = guard.max_mean_ms else {
return (arms.to_vec(), false);
};
let mut eligible: Vec<String> = arms
.iter()
.filter(|b| {
let (calls, mean_ms) = observed(b.as_str());
if guard.require_measured && calls == 0 {
return false;
}
mean_ms <= max_ms
})
.cloned()
.collect();
eligible.sort_by_key(|b| stable_hash64(seed ^ 0x4755_4152, b));
if eligible.is_empty() {
if guard.allow_fewer {
return (Vec::new(), true);
}
return (arms.to_vec(), false);
}
(eligible, false)
}
/// Strict latency guardrail: filters `arms` by observed mean latency and never
/// falls back to the unfiltered list.
///
/// `observed(arm)` returns `(calls, mean_ms)` and is called once per arm. An arm
/// passes when its mean is `<= guard.max_mean_ms` (a NaN mean never passes). If
/// `guard.require_measured` is set, the arm must also have at least one call.
/// `guard.allow_fewer` is not consulted.
///
/// Returns `(arms, stop_early)`:
/// * no `max_mean_ms` configured → all `arms` in original order, `false`;
/// * some arm passes → the passing arms in a seed-dependent order
///   (sorted by `stable_hash64`), `false`;
/// * no arm passes → an empty list, `true`.
pub fn guardrail_filter_observed_strict<F>(
    seed: u64,
    arms: &[String],
    guard: LatencyGuardrailConfig,
    mut observed: F,
) -> (Vec<String>, bool)
where
    F: FnMut(&str) -> (u64, f64),
{
    let ceiling_ms = match guard.max_mean_ms {
        Some(ms) => ms,
        // No ceiling configured: the guardrail is a no-op.
        None => return (arms.to_vec(), false),
    };
    let mut passing: Vec<String> = Vec::new();
    for arm in arms {
        let (calls, mean_ms) = observed(arm.as_str());
        let rejected_unmeasured = guard.require_measured && calls == 0;
        if !rejected_unmeasured && mean_ms <= ceiling_ms {
            passing.push(arm.clone());
        }
    }
    if passing.is_empty() {
        return (passing, true);
    }
    passing.sort_by_key(|arm| stable_hash64(seed ^ 0x4755_4152, arm));
    (passing, false)
}
/// `guardrail_filter_observed` for callers that track total elapsed time rather
/// than a mean.
///
/// `observed(arm)` returns `(calls, elapsed_ms_sum)`. The mean is
/// `elapsed_ms_sum / calls`, or `0.0` when `calls == 0`. An unmeasured arm
/// therefore clears any non-negative ceiling and is excluded only through
/// `guard.require_measured`.
pub fn guardrail_filter_observed_elapsed<F>(
    seed: u64,
    arms: &[String],
    guard: LatencyGuardrailConfig,
    mut observed: F,
) -> (Vec<String>, bool)
where
    F: FnMut(&str) -> (u64, u64),
{
    fn mean_latency_ms(calls: u64, elapsed_ms_sum: u64) -> f64 {
        match calls {
            0 => 0.0,
            n => elapsed_ms_sum as f64 / n as f64,
        }
    }

    guardrail_filter_observed(seed, arms, guard, |arm| {
        let (calls, elapsed_ms_sum) = observed(arm);
        (calls, mean_latency_ms(calls, elapsed_ms_sum))
    })
}
/// `guardrail_filter_observed_strict` for callers that track total elapsed time
/// rather than a mean.
///
/// `observed(arm)` returns `(calls, elapsed_ms_sum)`. The mean is
/// `elapsed_ms_sum / calls`, or `0.0` when `calls == 0`. An unmeasured arm
/// therefore clears any non-negative ceiling and is excluded only through
/// `guard.require_measured`.
pub fn guardrail_filter_observed_strict_elapsed<F>(
    seed: u64,
    arms: &[String],
    guard: LatencyGuardrailConfig,
    mut observed: F,
) -> (Vec<String>, bool)
where
    F: FnMut(&str) -> (u64, u64),
{
    fn mean_latency_ms(calls: u64, elapsed_ms_sum: u64) -> f64 {
        match calls {
            0 => 0.0,
            n => elapsed_ms_sum as f64 / n as f64,
        }
    }

    guardrail_filter_observed_strict(seed, arms, guard, |arm| {
        let (calls, elapsed_ms_sum) = observed(arm);
        (calls, mean_latency_ms(calls, elapsed_ms_sum))
    })
}
/// Plans a `k`-arm selection by combining the novelty/coverage prepick with the
/// latency guardrail in the given `order`.
///
/// `observed(arm)` returns `(calls, elapsed_ms_sum)` for an arm.
///
/// * `GuardrailFirst`: arms go through the strict guardrail (no fallback). If
///   none survive, the plan is empty with `stop_early = true`. Otherwise the
///   prepick runs on the survivors, and `eligible` holds the survivors that were
///   not prepicked.
/// * `NoveltyFirst`: the prepick runs on all `arms`, so prepicked arms bypass the
///   guardrail. The non-strict guardrail, which may fall back to its unfiltered
///   input, then runs on the remaining arms to produce `eligible`.
// Eight arguments exceed clippy's default limit; they are the pipeline's knobs.
#[allow(clippy::too_many_arguments)]
pub fn policy_plan_generic<F>(
    seed: u64,
    arms: &[String],
    k: usize,
    novelty_enabled: bool,
    coverage: CoverageConfig,
    guard: LatencyGuardrailConfig,
    order: PipelineOrder,
    mut observed: F,
) -> PolicyPlan
where
    F: FnMut(&str) -> (u64, u64),
{
    // Arms of `pool` that are not in `taken`, keeping `pool` order.
    fn arms_not_in(pool: &[String], taken: &[String]) -> Vec<String> {
        pool.iter()
            .filter(|&arm| !taken.contains(arm))
            .cloned()
            .collect()
    }

    // The guardrail stage uses its own seed, distinct from the prepick seed
    // (presumably so the two orderings are decorrelated).
    let guard_seed = seed ^ 0x504C_414E;

    match order {
        PipelineOrder::NoveltyFirst => {
            let prechosen =
                prepick_novelty_coverage(seed, arms, k, novelty_enabled, coverage, &mut observed);
            let not_prechosen = arms_not_in(arms, &prechosen);
            let (eligible, stop_early) =
                guardrail_filter_observed_elapsed(guard_seed, &not_prechosen, guard, observed);
            PolicyPlan {
                prechosen,
                eligible,
                stop_early,
            }
        }
        PipelineOrder::GuardrailFirst => {
            let (survivors, stop_early) =
                guardrail_filter_observed_strict_elapsed(guard_seed, arms, guard, &mut observed);
            if stop_early || survivors.is_empty() {
                return PolicyPlan {
                    prechosen: Vec::new(),
                    eligible: Vec::new(),
                    stop_early: true,
                };
            }
            let prechosen = prepick_novelty_coverage(
                seed,
                &survivors,
                k,
                novelty_enabled,
                coverage,
                &mut observed,
            );
            let eligible = arms_not_in(&survivors, &prechosen);
            PolicyPlan {
                prechosen,
                eligible,
                stop_early: false,
            }
        }
    }
}
/// Picks up to `k` arms before the main algorithm runs.
///
/// First come arms chosen by `novelty_pick_unseen`. Then, if `coverage.enabled`,
/// arms chosen by `coverage_pick_under_sampled` from the rest, within the budget
/// left after novelty. Only the call count (`observed(arm).0`) is passed to the
/// helpers. The result has no duplicates from the coverage stage and never
/// exceeds `k` once coverage runs.
fn prepick_novelty_coverage<F>(
    seed: u64,
    arms: &[String],
    k: usize,
    novelty_enabled: bool,
    coverage: CoverageConfig,
    observed: &mut F,
) -> Vec<String>
where
    F: FnMut(&str) -> (u64, u64),
{
    let mut picked = novelty_pick_unseen(seed, arms, k, novelty_enabled, |arm| observed(arm).0);
    if !coverage.enabled {
        return picked;
    }

    let not_yet_picked: Vec<String> = arms
        .iter()
        .filter(|&arm| !picked.contains(arm))
        .cloned()
        .collect();
    let coverage_budget = k.saturating_sub(picked.len());
    let coverage_picks = coverage_pick_under_sampled(
        seed ^ 0x434F_5645,
        &not_yet_picked,
        coverage_budget,
        coverage,
        |arm| observed(arm).0,
    );

    for arm in coverage_picks {
        if picked.len() >= k {
            break;
        }
        if !picked.contains(&arm) {
            picked.push(arm);
        }
    }
    picked
}
#[allow(clippy::too_many_arguments)]
pub fn policy_fill_generic<F, P>(
seed: u64,
arms: &[String],
k: usize,
novelty_enabled: bool,
coverage: CoverageConfig,
guard: LatencyGuardrailConfig,
order: PipelineOrder,
mut observed: F,
mut pick_rest: P,
) -> PolicyFill
where
F: FnMut(&str) -> (u64, u64),
P: FnMut(&[String], usize) -> Vec<String>,
{
let plan = policy_plan_generic(
seed,
arms,
k,
novelty_enabled,
coverage,
guard,
order,
&mut observed,
);
let mut chosen = plan.prechosen.clone();
if chosen.len() >= k {
return PolicyFill {
chosen,
plan,
eligible_used: Vec::new(),
fallback_used: false,
stopped_early: false,
};
}
if plan.stop_early && !chosen.is_empty() {
return PolicyFill {
chosen,
eligible_used: Vec::new(),
plan,
fallback_used: false,
stopped_early: true,
};
}
if order == PipelineOrder::GuardrailFirst {
if plan.stop_early {
return PolicyFill {
chosen,
plan,
eligible_used: Vec::new(),
fallback_used: false,
stopped_early: true,
};
}
let eligible_used = plan.eligible.clone();
let remaining_k = k.saturating_sub(chosen.len());
if remaining_k > 0 && !eligible_used.is_empty() {
let rest = pick_rest(&eligible_used, remaining_k);
for b in rest {
if chosen.len() >= k {
break;
}
if !eligible_used.contains(&b) {
continue;
}
if chosen.contains(&b) {
continue;
}
chosen.push(b);
}
}
return PolicyFill {
chosen,
plan,
eligible_used,
fallback_used: false,
stopped_early: false,
};
}
let mut eligible_used = plan.eligible.clone();
let mut fallback_used = false;
let mut stopped_early = false;
if eligible_used.is_empty() {
if guard.require_measured {
stopped_early = true;
return PolicyFill {
chosen,
eligible_used,
plan,
fallback_used,
stopped_early,
};
}
if guard.allow_fewer && !chosen.is_empty() {
stopped_early = true;
return PolicyFill {
chosen,
eligible_used,
plan,
fallback_used,
stopped_early,
};
}
eligible_used = arms
.iter()
.filter(|b| !chosen.contains(*b))
.cloned()
.collect();
fallback_used = true;
}
let remaining_k = k.saturating_sub(chosen.len());
if remaining_k > 0 && !eligible_used.is_empty() {
let rest = pick_rest(&eligible_used, remaining_k);
for b in rest {
if chosen.len() >= k {
break;
}
if !eligible_used.contains(&b) {
continue;
}
if chosen.contains(&b) {
continue;
}
chosen.push(b);
}
}
PolicyFill {
chosen,
plan,
eligible_used,
fallback_used,
stopped_early,
}
}
/// `NoveltyFirst` fill using the default `CoverageConfig`.
///
/// Equivalent to `policy_fill_k_observed_with_coverage` called with
/// `CoverageConfig::default()`.
#[cfg(feature = "contextual")]
pub(crate) fn policy_fill_k_observed_with<F, P>(
    seed: u64,
    arms: &[String],
    k: usize,
    novelty_enabled: bool,
    guard: LatencyGuardrailConfig,
    observed: F,
    pick_rest: P,
) -> PolicyFill
where
    F: FnMut(&str) -> (u64, u64),
    P: FnMut(&[String], usize) -> Vec<String>,
{
    let default_coverage = CoverageConfig::default();
    policy_fill_k_observed_with_coverage(
        seed,
        arms,
        k,
        novelty_enabled,
        default_coverage,
        guard,
        observed,
        pick_rest,
    )
}
/// Runs `policy_fill_generic` with the `NoveltyFirst` pipeline order.
///
/// See `policy_fill_generic` for the meaning of the arguments and of the
/// returned `PolicyFill`.
// Eight arguments exceed clippy's default limit; they mirror `policy_fill_generic`.
#[allow(clippy::too_many_arguments)]
pub fn policy_fill_k_observed_with_coverage<F, P>(
    seed: u64,
    arms: &[String],
    k: usize,
    novelty_enabled: bool,
    coverage: CoverageConfig,
    guard: LatencyGuardrailConfig,
    observed: F,
    pick_rest: P,
) -> PolicyFill
where
    F: FnMut(&str) -> (u64, u64),
    P: FnMut(&[String], usize) -> Vec<String>,
{
    let order = PipelineOrder::NoveltyFirst;
    policy_fill_generic(
        seed, arms, k, novelty_enabled, coverage, guard, order, observed, pick_rest,
    )
}
/// Runs `policy_fill_generic` with the `GuardrailFirst` pipeline order.
///
/// Prepicked arms must also pass the strict guardrail. See `policy_fill_generic`
/// for the meaning of the arguments and of the returned `PolicyFill`.
// Eight arguments exceed clippy's default limit; they mirror `policy_fill_generic`.
#[allow(clippy::too_many_arguments)]
pub fn policy_fill_k_observed_guardrail_first_with_coverage<F, P>(
    seed: u64,
    arms: &[String],
    k: usize,
    novelty_enabled: bool,
    coverage: CoverageConfig,
    guard: LatencyGuardrailConfig,
    observed: F,
    pick_rest: P,
) -> PolicyFill
where
    F: FnMut(&str) -> (u64, u64),
    P: FnMut(&[String], usize) -> Vec<String>,
{
    let order = PipelineOrder::GuardrailFirst;
    policy_fill_generic(
        seed, arms, k, novelty_enabled, coverage, guard, order, observed, pick_rest,
    )
}
/// Selects up to `k` distinct items from `items` by repeatedly asking `pick` for
/// a batch, keeping the metadata `pick` attaches to each accepted item.
///
/// Each round calls `pick(seed ^ selected_so_far, remaining, need)`:
/// * `selected_so_far` is the number of items accepted so far;
/// * `remaining` is the not-yet-selected items;
/// * `need` is how many are still wanted.
///
/// Batch entries not in `remaining` (unknown, or already selected) are ignored.
/// Accepting an item removes every copy of it from `remaining`. Selection stops
/// once `k` items are chosen, `remaining` is exhausted, or a round accepts
/// nothing, so a misbehaving `pick` cannot loop forever.
///
/// Returns an empty vector when `k == 0` or `items` is empty.
pub(crate) fn select_k_without_replacement_by_with_meta<F, M>(
    seed: u64,
    items: &[String],
    k: usize,
    mut pick: F,
) -> Vec<(String, M)>
where
    F: FnMut(u64, &[String], usize) -> Vec<(String, M)>,
{
    if k == 0 || items.is_empty() {
        return Vec::new();
    }
    let mut remaining: Vec<String> = items.to_vec();
    let mut out: Vec<(String, M)> = Vec::with_capacity(k.min(items.len()));
    while !remaining.is_empty() && out.len() < k {
        let need = k - out.len();
        let batch = pick(seed ^ (out.len() as u64), &remaining, need);
        let accepted_before = out.len();
        for (b, meta) in batch {
            if out.len() >= k {
                break;
            }
            // Accepted items are removed from `remaining`, so this single
            // membership test also rejects repeats within or across batches.
            if !remaining.contains(&b) {
                continue;
            }
            remaining.retain(|x| x != &b);
            out.push((b, meta));
        }
        if out.len() == accepted_before {
            break;
        }
    }
    out
}
/// Selects up to `k` distinct items from `items` using a picker that returns
/// bare item names.
///
/// This is `select_k_without_replacement_by_with_meta` with unit metadata. See
/// that function for how `pick` is called per round and when selection stops.
#[must_use]
pub fn select_k_without_replacement_by<F>(
    seed: u64,
    items: &[String],
    k: usize,
    mut pick: F,
) -> Vec<String>
where
    F: FnMut(u64, &[String], usize) -> Vec<String>,
{
    let with_unit_meta = |round_seed: u64, remaining: &[String], need: usize| {
        pick(round_seed, remaining, need)
            .into_iter()
            .map(|item| (item, ()))
            .collect::<Vec<_>>()
    };
    let selected: Vec<(String, ())> =
        select_k_without_replacement_by_with_meta(seed, items, k, with_unit_meta);
    selected.into_iter().map(|(item, ())| item).collect()
}
/// Result of `policy_fill_k_contextual`: the fill plus the context and scores
/// that drove its ranking.
#[cfg(feature = "contextual")]
#[derive(Debug, Clone)]
pub struct ContextualPolicyFill {
    /// Fill from the `NoveltyFirst` pipeline, with non-prechosen slots ranked by
    /// LinUCB score.
    pub fill: PolicyFill,
    /// Copy of the context vector the scores were computed for.
    pub context: Vec<f64>,
    /// Per-arm score tuples as returned by `LinUcb::scores`. The first component
    /// is the ranking (UCB) score. NOTE(review): the other two components are
    /// not used here; presumably mean/bonus terms, so confirm in `LinUcb`.
    pub scores: std::collections::BTreeMap<String, (f64, f64, f64)>,
}
/// Fills up to `k` arms with the `NoveltyFirst` pipeline (default coverage
/// config), ranking the remaining eligible arms by LinUCB score for `context`.
///
/// Eligible arms are ordered by the first component of their score tuple,
/// highest first. Ties are broken by arm name in ascending order, so the
/// ranking is deterministic. Arms missing from the score map rank with `0.0`.
/// The top `remaining_k` arms are returned to the fill.
#[cfg(feature = "contextual")]
// Eight arguments exceed clippy's default limit.
#[allow(clippy::too_many_arguments)]
pub fn policy_fill_k_contextual<F>(
    seed: u64,
    arms: &[String],
    k: usize,
    novelty_enabled: bool,
    guard: LatencyGuardrailConfig,
    linucb: &mut crate::LinUcb,
    context: &[f64],
    observed: F,
) -> ContextualPolicyFill
where
    F: FnMut(&str) -> (u64, u64),
{
    let scores = linucb.scores(arms, context);
    let rank_by_ucb = |eligible: &[String], remaining_k: usize| -> Vec<String> {
        let mut ranked: Vec<(f64, String)> = Vec::with_capacity(eligible.len());
        for arm in eligible {
            let ucb = match scores.get(arm) {
                Some(&(ucb, _, _)) => ucb,
                None => 0.0,
            };
            ranked.push((ucb, arm.clone()));
        }
        ranked.sort_by(|(score_a, arm_a), (score_b, arm_b)| {
            score_b.total_cmp(score_a).then_with(|| arm_a.cmp(arm_b))
        });
        ranked.truncate(remaining_k);
        ranked.into_iter().map(|(_, arm)| arm).collect()
    };
    let fill =
        policy_fill_k_observed_with(seed, arms, k, novelty_enabled, guard, observed, rank_by_ucb);
    ContextualPolicyFill {
        fill,
        context: context.to_vec(),
        scores,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two arms: "a" and "b".
    fn arms2() -> Vec<String> {
        vec!["a".to_string(), "b".to_string()]
    }

    /// Three arms: "a", "b" and "c".
    fn arms3() -> Vec<String> {
        vec!["a".to_string(), "b".to_string(), "c".to_string()]
    }

    /// Ceiling of `max_ms`, measurements required, never allows fewer arms.
    fn guard_strict(max_ms: f64) -> LatencyGuardrailConfig {
        LatencyGuardrailConfig {
            max_mean_ms: Some(max_ms),
            require_measured: true,
            allow_fewer: false,
        }
    }

    /// Ceiling of `max_ms`, measurements optional, allows fewer arms.
    fn guard_soft(max_ms: f64) -> LatencyGuardrailConfig {
        LatencyGuardrailConfig {
            max_mean_ms: Some(max_ms),
            require_measured: false,
            allow_fewer: true,
        }
    }

    #[test]
    fn novelty_first_unseen_bypasses_require_measured() {
        let arms = vec!["unseen".to_string(), "slow".to_string(), "fast".to_string()];
        let plan = policy_plan_generic(
            42,
            &arms,
            3,
            true,
            CoverageConfig::default(),
            guard_strict(50.0),
            PipelineOrder::NoveltyFirst,
            |b| match b {
                "unseen" => (0, 0),
                "slow" => (10, 1000), // mean 100 ms
                "fast" => (10, 200),  // mean 20 ms
                _ => (0, 0),
            },
        );
        assert!(
            plan.prechosen.contains(&"unseen".to_string()),
            "NoveltyFirst: unseen arm should be prechosen despite require_measured"
        );
        assert!(
            plan.eligible.contains(&"fast".to_string()),
            "fast arm should be eligible after guardrail"
        );
        assert!(
            !plan.eligible.contains(&"slow".to_string()),
            "slow arm (100ms > 50ms) must not be eligible"
        );
    }

    #[test]
    fn guardrail_first_blocks_unseen_with_require_measured() {
        let arms = vec!["unseen".to_string(), "fast".to_string()];
        let plan = policy_plan_generic(
            42,
            &arms,
            2,
            true,
            CoverageConfig::default(),
            guard_strict(50.0),
            PipelineOrder::GuardrailFirst,
            |b| {
                if b == "unseen" {
                    (0, 0)
                } else {
                    (10, 200) // mean 20 ms
                }
            },
        );
        assert!(
            !plan.prechosen.contains(&"unseen".to_string()),
            "GuardrailFirst: unseen arm must not bypass require_measured"
        );
        assert!(
            plan.eligible.contains(&"fast".to_string()),
            "measured fast arm should be eligible"
        );
    }

    #[test]
    fn guardrail_first_all_unmeasured_stops_early() {
        let plan = policy_plan_generic(
            42,
            &arms2(),
            2,
            true,
            CoverageConfig::default(),
            LatencyGuardrailConfig {
                max_mean_ms: Some(50.0),
                require_measured: true,
                allow_fewer: true,
            },
            PipelineOrder::GuardrailFirst,
            |_| (0, 0),
        );
        assert!(plan.stop_early);
        assert!(plan.prechosen.is_empty());
        assert!(plan.eligible.is_empty());
    }

    #[test]
    fn novelty_first_guardrail_applies_to_remainder() {
        let arms = arms3();
        let plan = policy_plan_generic(
            42,
            &arms,
            3,
            true,
            CoverageConfig::default(),
            guard_soft(50.0),
            PipelineOrder::NoveltyFirst,
            |b| match b {
                "a" => (0, 0),
                "b" => (5, 100), // mean 20 ms
                "c" => (5, 750), // mean 150 ms
                _ => (0, 0),
            },
        );
        assert!(
            plan.prechosen.contains(&"a".to_string()),
            "unseen arm prechosen"
        );
        assert!(
            plan.eligible.contains(&"b".to_string()),
            "fast arm eligible"
        );
        assert!(
            !plan.eligible.contains(&"c".to_string()),
            "slow arm filtered"
        );
    }

    #[test]
    fn policy_fill_generic_prechosen_fills_without_algorithm() {
        let arms = arms2();
        let fill = policy_fill_generic(
            42,
            &arms,
            1,
            true,
            CoverageConfig::default(),
            LatencyGuardrailConfig::default(),
            PipelineOrder::NoveltyFirst,
            |b| if b == "a" { (0, 0) } else { (5, 100) },
            |_eligible, _k| panic!("pick_rest must not be called when prechosen fills k"),
        );
        assert_eq!(fill.chosen, vec!["a".to_string()]);
        assert!(!fill.stopped_early);
        assert!(!fill.fallback_used);
    }

    #[test]
    fn policy_fill_generic_combines_prechosen_and_algorithm() {
        let arms = arms3();
        let fill = policy_fill_generic(
            42,
            &arms,
            2,
            true,
            CoverageConfig::default(),
            LatencyGuardrailConfig::default(),
            PipelineOrder::NoveltyFirst,
            |b| if b == "a" { (0, 0) } else { (5, 100) },
            |eligible, _k| eligible.to_vec(),
        );
        assert_eq!(fill.chosen.len(), 2);
        assert!(
            fill.chosen.contains(&"a".to_string()),
            "prechosen novelty arm included"
        );
        let mut s = fill.chosen.clone();
        s.sort();
        s.dedup();
        assert_eq!(s.len(), fill.chosen.len(), "chosen must be unique");
    }

    #[test]
    fn guardrail_first_stop_early_halts_fill() {
        let arms = arms2();
        let fill = policy_fill_generic(
            42,
            &arms,
            2,
            false,
            CoverageConfig::default(),
            LatencyGuardrailConfig {
                max_mean_ms: Some(10.0),
                require_measured: true,
                allow_fewer: true,
            },
            PipelineOrder::GuardrailFirst,
            |_| (0, 0), // nothing measured
            |_eligible, _k| unreachable!("must not reach algorithm"),
        );
        assert!(fill.stopped_early);
        assert!(fill.chosen.is_empty());
    }

    #[test]
    fn guardrail_filter_observed_filters_and_falls_back() {
        let arms = arms2();
        let (eligible, stop_early) = guardrail_filter_observed(
            42,
            &arms,
            LatencyGuardrailConfig {
                max_mean_ms: Some(10.0),
                require_measured: false,
                allow_fewer: false,
            },
            |_| (1, 100.0), // every arm is over the ceiling
        );
        assert!(!stop_early);
        assert_eq!(eligible, arms, "fallback returns original arms");
    }

    #[test]
    fn guardrail_filter_observed_strict_returns_empty_no_fallback() {
        let arms = arms2();
        let (eligible, stop_early) = guardrail_filter_observed_strict(
            42,
            &arms,
            LatencyGuardrailConfig {
                max_mean_ms: Some(10.0),
                require_measured: false,
                allow_fewer: true,
            },
            |_| (1, 100.0),
        );
        assert!(stop_early);
        assert!(eligible.is_empty());
    }

    #[test]
    fn guardrail_without_ceiling_is_a_no_op() {
        let arms = arms3();
        let guard = LatencyGuardrailConfig {
            max_mean_ms: None,
            require_measured: true,
            allow_fewer: true,
        };
        // With no ceiling, even unmeasured, arbitrarily slow arms pass, and the
        // original order is kept.
        let (eligible, stop_early) = guardrail_filter_observed(42, &arms, guard, |_| (0, 1.0e9));
        assert!(!stop_early);
        assert_eq!(eligible, arms);
        let (eligible, stop_early) =
            guardrail_filter_observed_strict(42, &arms, guard, |_| (0, 1.0e9));
        assert!(!stop_early);
        assert_eq!(eligible, arms);
    }

    #[test]
    fn guardrail_filter_observed_allow_fewer_stops_early() {
        let (eligible, stop_early) =
            guardrail_filter_observed(42, &arms2(), guard_soft(10.0), |_| (1, 100.0));
        assert!(stop_early);
        assert!(eligible.is_empty());
    }

    #[test]
    fn guardrail_elapsed_uses_mean_and_require_measured() {
        let arms = vec!["unseen".to_string(), "fast".to_string(), "slow".to_string()];
        let observed = |b: &str| -> (u64, u64) {
            match b {
                "fast" => (10, 100),  // mean 10 ms
                "slow" => (10, 1000), // mean 100 ms
                _ => (0, 0),
            }
        };
        // Strict guard: unmeasured arm excluded, slow arm over the ceiling.
        let (eligible, stop_early) =
            guardrail_filter_observed_elapsed(42, &arms, guard_strict(50.0), observed);
        assert!(!stop_early);
        assert_eq!(eligible, vec!["fast".to_string()]);
        // Soft guard: unmeasured arm has mean 0.0 and passes.
        let (mut eligible, stop_early) =
            guardrail_filter_observed_elapsed(42, &arms, guard_soft(50.0), observed);
        assert!(!stop_early);
        eligible.sort();
        assert_eq!(eligible, vec!["fast".to_string(), "unseen".to_string()]);
    }

    #[test]
    fn select_k_without_replacement_by_dedups_and_caps_at_k() {
        let arms = arms3();
        // Each round offers the first remaining arm twice plus an unknown arm;
        // the repeat and the unknown arm must be ignored.
        let picked = select_k_without_replacement_by(7, &arms, 2, |_seed, remaining, _need| {
            vec![remaining[0].clone(), remaining[0].clone(), "zzz".to_string()]
        });
        assert_eq!(picked, vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn select_k_without_replacement_by_stops_without_progress() {
        let arms = arms3();
        let none = select_k_without_replacement_by(7, &arms, 2, |_, _, _| vec!["zzz".to_string()]);
        assert!(none.is_empty());
        let zero = select_k_without_replacement_by(7, &arms, 0, |_, _, _| Vec::new());
        assert!(zero.is_empty());
    }
}