use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use super::relevance::Bucket;
/// Default uncertainty penalty: a posterior scores `mean - gamma * sqrt(variance)`.
pub const DEFAULT_GAMMA: f64 = 0.5;

/// Default margin a peer's score must beat the local score by before delegating.
pub const DEFAULT_DELTA: f64 = 0.05;

/// Default prior strength (total pseudo-count) when seeding from self-confidence.
pub const DEFAULT_KAPPA: f64 = 2.0;

/// Default factor applied to existing evidence before each new outcome;
/// `1.0` keeps all past evidence.
pub const DEFAULT_LAMBDA: f64 = 1.0;
/// Beta posterior over the probability that an outcome recorded for one
/// belief cell is a success (`true`).
///
/// The prior is seeded from a self-reported confidence. Each recorded outcome
/// then adds one pseudo-count to `alpha` (success) or `beta` (failure).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BetaPosterior {
    /// Success pseudo-count, including prior mass and any cold-start shrinkage.
    pub alpha: f64,
    /// Failure pseudo-count, including prior mass and any cold-start shrinkage.
    pub beta: f64,
    /// Number of outcomes recorded via `update`; prior mass is not counted.
    pub n: u64,
    /// Self-confidence the prior was seeded from, clamped to `[0, 1]`.
    pub c_self: f64,
    /// Prior strength: the seeded prior is `Beta(kappa * c_self, kappa * (1 - c_self))`.
    pub kappa: f64,
    /// Time of construction or of the most recent `update` call.
    pub last_update: DateTime<Utc>,
}
impl BetaPosterior {
    /// Seeds a posterior from an agent's self-reported confidence.
    ///
    /// The prior is `Beta(kappa * c, kappa * (1 - c))`. Here `c` is `c_self`
    /// clamped to `[0, 1]`, so `kappa` is the total pseudo-count the prior is
    /// worth. A NaN `c_self` is treated as the neutral `0.5`. Otherwise NaN
    /// would spread into every later mean, variance and score.
    pub fn from_self_confidence(c_self: f64, kappa: f64) -> Self {
        // `f64::clamp` returns NaN for a NaN input, and NaN survives every
        // subsequent `update`, so it has to be filtered out here.
        let c = if c_self.is_nan() {
            0.5
        } else {
            c_self.clamp(0.0, 1.0)
        };
        Self {
            alpha: kappa * c,
            beta: kappa * (1.0 - c),
            n: 0,
            c_self: c,
            kappa,
            last_update: Utc::now(),
        }
    }

    /// Posterior mean `alpha / (alpha + beta)`, or `0.0` when the total mass
    /// is not positive.
    pub fn mean(&self) -> f64 {
        let total = self.alpha + self.beta;
        if total <= 0.0 {
            return 0.0;
        }
        self.alpha / total
    }

    /// Posterior variance `alpha * beta / ((alpha + beta)^2 * (alpha + beta + 1))`,
    /// or `0.0` when the total mass is not positive.
    pub fn variance(&self) -> f64 {
        let total = self.alpha + self.beta;
        if total <= 0.0 {
            return 0.0;
        }
        let denom = total * total * (total + 1.0);
        (self.alpha * self.beta) / denom
    }

    /// Uncertainty-penalised score `mean - gamma * sqrt(variance)`.
    ///
    /// A larger `gamma` penalises thinly observed posteriors more heavily.
    pub fn score(&self, gamma: f64) -> f64 {
        self.mean() - gamma * self.variance().sqrt()
    }

    /// Records one observed outcome.
    ///
    /// Existing mass in `alpha` and `beta` is first multiplied by `lambda`.
    /// A value below `1.0` discounts older evidence, and `1.0` keeps it all.
    /// One pseudo-count is then added to `alpha` on success or to `beta` on
    /// failure. The call also increments `n` and refreshes `last_update`.
    pub fn update(&mut self, outcome: bool, lambda: f64) {
        self.alpha *= lambda;
        self.beta *= lambda;
        if outcome {
            self.alpha += 1.0;
        } else {
            self.beta += 1.0;
        }
        // Saturate instead of overflowing, which would panic in debug builds.
        self.n = self.n.saturating_add(1);
        self.last_update = Utc::now();
    }
}
/// Tuning parameters for scoring beliefs and deciding delegation.
///
/// When a numeric field is missing during deserialization, it falls back to
/// its `DEFAULT_*` constant. `enabled` falls back to `false`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DelegationConfig {
    /// Uncertainty penalty: scores are `mean - gamma * sqrt(variance)`.
    #[serde(default = "default_gamma")]
    pub gamma: f64,
    /// Margin by which a peer's score must exceed the local score before
    /// `DelegationState::delegate_to` selects that peer.
    #[serde(default = "default_delta")]
    pub delta: f64,
    /// Prior strength used when seeding a new posterior: the prior is
    /// `Beta(kappa * c, kappa * (1 - c))` for self-confidence `c`.
    #[serde(default = "default_kappa")]
    pub kappa: f64,
    /// Factor applied to existing `alpha`/`beta` before each new outcome is
    /// added; `1.0` keeps all past evidence.
    #[serde(default = "default_lambda")]
    pub lambda: f64,
    /// NOTE(review): not read by any code visible in this file; presumably a
    /// caller-side on/off switch for delegation — confirm. Defaults to `false`.
    #[serde(default)]
    pub enabled: bool,
}
/// Serde fallback for a missing `DelegationConfig::gamma`.
const fn default_gamma() -> f64 {
    DEFAULT_GAMMA
}

/// Serde fallback for a missing `DelegationConfig::delta`.
const fn default_delta() -> f64 {
    DEFAULT_DELTA
}

/// Serde fallback for a missing `DelegationConfig::kappa`.
const fn default_kappa() -> f64 {
    DEFAULT_KAPPA
}

/// Serde fallback for a missing `DelegationConfig::lambda`.
const fn default_lambda() -> f64 {
    DEFAULT_LAMBDA
}
impl Default for DelegationConfig {
fn default() -> Self {
Self {
gamma: DEFAULT_GAMMA,
delta: DEFAULT_DELTA,
kappa: DEFAULT_KAPPA,
lambda: DEFAULT_LAMBDA,
enabled: false,
}
}
}
/// Structured (agent, skill, bucket) identity of a belief cell.
///
/// NOTE(review): nothing in this file's visible code uses this alias.
/// `DelegationState::beliefs` is keyed by the flattened `String` from
/// `DelegationState::key` instead — confirm whether the alias is still needed.
pub type BeliefKey = (String, String, Bucket);
/// Beta beliefs for each (agent, skill, bucket) cell, plus the tuning
/// parameters used to score them and decide delegation.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DelegationState {
    /// Posteriors keyed by the flat string built by `DelegationState::key`
    /// (`agent|skill|difficulty|dependency|tool_use`).
    #[serde(default)]
    pub beliefs: BTreeMap<String, BetaPosterior>,
    /// Scoring and delegation parameters. Falls back to
    /// `DelegationConfig::default()` when absent during deserialization.
    #[serde(default)]
    pub config: DelegationConfig,
}
impl DelegationState {
    /// Creates an empty belief table that uses `config` for scoring and
    /// delegation decisions.
    pub fn with_config(config: DelegationConfig) -> Self {
        Self {
            beliefs: BTreeMap::new(),
            config,
        }
    }

    /// Builds the `beliefs` map key for one (agent, skill, bucket) cell.
    ///
    /// The key has the form `agent|skill|difficulty|dependency|tool_use`,
    /// using each bucket dimension's `as_str()` form.
    ///
    /// NOTE(review): a `|` inside `agent` or `skill` is not escaped. So agent
    /// `"a|b"` with skill `"c"` gives the same key as agent `"a"` with skill
    /// `"b|c"`. The format is left unchanged because serialized `beliefs`
    /// maps are keyed by these exact strings.
    pub fn key(agent: &str, skill: &str, bucket: Bucket) -> String {
        format!(
            "{agent}|{skill}|{}|{}|{}",
            bucket.difficulty.as_str(),
            bucket.dependency.as_str(),
            bucket.tool_use.as_str(),
        )
    }

    /// Returns the cell's posterior. If the cell is absent, it is first
    /// seeded from `c_self` and the configured `kappa`. For a cell that
    /// already exists, `c_self` is ignored.
    pub fn ensure(
        &mut self,
        agent: &str,
        skill: &str,
        bucket: Bucket,
        c_self: f64,
    ) -> &mut BetaPosterior {
        let key = Self::key(agent, skill, bucket);
        let kappa = self.config.kappa;
        self.beliefs
            .entry(key)
            .or_insert_with(|| BetaPosterior::from_self_confidence(c_self, kappa))
    }

    /// Returns the cell's uncertainty-penalised score under the configured
    /// `gamma`, or `None` if the cell has no posterior yet.
    pub fn score(&self, agent: &str, skill: &str, bucket: Bucket) -> Option<f64> {
        let key = Self::key(agent, skill, bucket);
        self.beliefs.get(&key).map(|p| p.score(self.config.gamma))
    }

    /// Records one outcome for the cell under the configured decay `lambda`.
    /// If the cell is absent, a neutral prior (`c_self = 0.5`) is seeded first.
    pub fn update(&mut self, agent: &str, skill: &str, bucket: Bucket, outcome: bool) {
        let lambda = self.config.lambda;
        self.ensure(agent, skill, bucket, 0.5).update(outcome, lambda);
    }

    /// Chooses a peer to hand the task to, or `None` to keep it local.
    ///
    /// A peer qualifies only when its score beats `local`'s score by more than
    /// the configured `delta`. Among qualifying peers the highest score wins,
    /// and ties go to the earliest entry in `peers`. Cells without a posterior
    /// score `0.0`, and entries equal to `local` are skipped. A NaN peer score
    /// never qualifies, because the margin comparison is false for NaN.
    pub fn delegate_to<'a>(
        &self,
        local: &'a str,
        peers: &'a [&'a str],
        skill: &str,
        bucket: Bucket,
    ) -> Option<&'a str> {
        // Minimum score (exclusive) a peer must reach; loop-invariant.
        let threshold = self.score(local, skill, bucket).unwrap_or(0.0) + self.config.delta;
        let mut best: Option<(&'a str, f64)> = None;
        for &peer in peers {
            if peer == local {
                continue;
            }
            let peer_score = self.score(peer, skill, bucket).unwrap_or(0.0);
            if peer_score > threshold {
                match best {
                    Some((_, current_best)) if current_best >= peer_score => {}
                    _ => best = Some((peer, peer_score)),
                }
            }
        }
        best.map(|(peer, _)| peer)
    }

    /// Picks the highest-scoring candidate, or `None` for an empty slice.
    ///
    /// Cells without a posterior score `0.0`. Ties go to the earliest
    /// candidate, so on a cold start the first candidate is returned. A NaN
    /// score ranks below every real score, so it cannot displace a valid best.
    pub fn rank_candidates<'a>(
        &self,
        candidates: &'a [&'a str],
        skill: &str,
        bucket: Bucket,
    ) -> Option<&'a str> {
        let mut best: Option<(&'a str, f64)> = None;
        for &name in candidates {
            let raw = self.score(name, skill, bucket).unwrap_or(0.0);
            // `current >= NaN` is always false, which would let a NaN score
            // replace the current best; rank NaN last instead.
            let score = if raw.is_nan() { f64::NEG_INFINITY } else { raw };
            match best {
                Some((_, current)) if current >= score => {}
                _ => best = Some((name, score)),
            }
        }
        best.map(|(name, _)| name)
    }

    /// Warm-starts an unobserved cell from the same agent/skill's observed
    /// neighbouring buckets.
    ///
    /// `m_z` is clamped to `[0, 2]` and is the total pseudo-count added. The
    /// call does nothing in these cases:
    /// - `m_z` is NaN or not positive;
    /// - the target cell already has observations (`n > 0`);
    /// - no neighbour other than `bucket` itself has observations.
    ///
    /// Otherwise let `p` be the average posterior mean of the observed
    /// neighbours. The call adds `m_z * p` to the target's `alpha` and
    /// `m_z * (1 - p)` to its `beta`. An absent target is first seeded with a
    /// neutral prior (`c_self = 0.5`).
    ///
    /// NOTE(review): the target's `n` stays 0, so repeating the call before
    /// any `update` adds the neighbour mass again.
    pub fn shrink_cold_start(
        &mut self,
        agent: &str,
        skill: &str,
        bucket: Bucket,
        neighbors: &[Bucket],
        m_z: f64,
    ) {
        let m_z = m_z.clamp(0.0, 2.0);
        // `clamp` passes NaN through unchanged, and `NaN <= 0.0` is false.
        // NaN must therefore be rejected explicitly, or it would poison
        // alpha and beta.
        if m_z.is_nan() || m_z <= 0.0 {
            return;
        }
        let own_key = Self::key(agent, skill, bucket);
        if matches!(self.beliefs.get(&own_key), Some(own) if own.n > 0) {
            return;
        }
        let mut sum_alpha = 0.0;
        let mut sum_beta = 0.0;
        let mut contributors = 0.0;
        for nb in neighbors {
            if *nb == bucket {
                continue;
            }
            let nb_key = Self::key(agent, skill, *nb);
            if let Some(post) = self.beliefs.get(&nb_key) {
                // Only neighbours with real observations contribute.
                if post.n > 0 {
                    sum_alpha += post.mean();
                    sum_beta += 1.0 - post.mean();
                    contributors += 1.0;
                }
            }
        }
        if contributors <= 0.0 {
            return;
        }
        let avg_alpha = sum_alpha / contributors;
        let avg_beta = sum_beta / contributors;
        let kappa = self.config.kappa;
        let post = self
            .beliefs
            .entry(own_key)
            .or_insert_with(|| BetaPosterior::from_self_confidence(0.5, kappa));
        post.alpha += avg_alpha * m_z;
        post.beta += avg_beta * m_z;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::relevance::{Dependency, Difficulty, ToolUse};

    /// Baseline bucket shared by most tests.
    fn bucket() -> Bucket {
        Bucket {
            difficulty: Difficulty::Easy,
            dependency: Dependency::Isolated,
            tool_use: ToolUse::No,
        }
    }

    #[test]
    fn beta_update_increments_success_count() {
        let mut post = BetaPosterior::from_self_confidence(0.5, 2.0);
        post.update(true, 1.0);
        assert_eq!(post.n, 1);
        assert!((post.alpha - 2.0).abs() < 1e-9);
        assert!((post.beta - 1.0).abs() < 1e-9);
    }

    #[test]
    fn beta_score_penalises_uncertainty() {
        let mut thin = BetaPosterior::from_self_confidence(0.8, 2.0);
        let mut thick = BetaPosterior::from_self_confidence(0.5, 2.0);
        for _ in 0..100 {
            thick.update(true, 1.0);
            thick.update(false, 1.0);
        }
        thin.update(false, 1.0);
        let gamma = 0.5;
        assert!(thin.score(gamma) < thick.score(gamma));
    }

    #[test]
    fn delegation_state_update_seeds_and_records() {
        let mut state = DelegationState::with_config(DelegationConfig::default());
        state.update("openai", "model_call", bucket(), true);
        let score = state
            .score("openai", "model_call", bucket())
            .expect("update must seed the posterior");
        assert!(score.is_finite());
    }

    #[test]
    fn delegate_to_respects_margin() {
        let mut state = DelegationState::with_config(DelegationConfig::default());
        let b = bucket();
        for _ in 0..20 {
            state.update("local", "skill", b, true);
            state.update("local", "skill", b, false);
        }
        for _ in 0..20 {
            state.update("peer", "skill", b, true);
            state.update("peer", "skill", b, false);
        }
        for _ in 0..2 {
            state.update("peer", "skill", b, true);
        }
        let peers = ["peer"];
        // Peer scores ~0.4855 vs local ~0.4619: a lead of ~0.024, inside the
        // default 0.05 margin, so the task must stay local.
        assert_eq!(state.delegate_to("local", &peers, "skill", b), None);

        for _ in 0..20 {
            state.update("peer", "skill", b, true);
        }
        // Peer now scores ~0.643, a lead of ~0.18, well past the margin.
        assert_eq!(state.delegate_to("local", &peers, "skill", b), Some("peer"));
    }

    #[test]
    fn shrink_cold_start_pulls_neighbour_mass() {
        let mut state = DelegationState::with_config(DelegationConfig::default());
        let b1 = bucket();
        let b2 = Bucket {
            difficulty: Difficulty::Medium,
            ..b1
        };
        for _ in 0..10 {
            state.update("agent", "skill", b2, true);
        }
        assert!(
            state
                .beliefs
                .get(&DelegationState::key("agent", "skill", b1))
                .map(|p| p.n)
                .unwrap_or(0)
                == 0
        );
        state.shrink_cold_start("agent", "skill", b1, &[b2], 2.0);
        let post = state
            .beliefs
            .get(&DelegationState::key("agent", "skill", b1))
            .unwrap();
        assert!(post.alpha > post.beta);
    }

    #[test]
    fn rank_candidates_picks_first_on_cold_start() {
        let state = DelegationState::with_config(DelegationConfig::default());
        let pick = state.rank_candidates(&["a", "b", "c"], "swarm_dispatch", bucket());
        assert_eq!(pick, Some("a"));
    }

    #[test]
    fn rank_candidates_prefers_best_scoring_once_warm() {
        let mut state = DelegationState::with_config(DelegationConfig::default());
        let b = bucket();
        for _ in 0..5 {
            state.update("b", "swarm_dispatch", b, true);
        }
        for _ in 0..5 {
            state.update("a", "swarm_dispatch", b, false);
        }
        let pick = state.rank_candidates(&["a", "b"], "swarm_dispatch", b);
        assert_eq!(pick, Some("b"));
    }

    #[test]
    fn rank_candidates_is_none_for_empty_input() {
        let state = DelegationState::with_config(DelegationConfig::default());
        assert!(
            state
                .rank_candidates(&[], "swarm_dispatch", bucket())
                .is_none()
        );
    }

    #[test]
    fn config_defaults_match_documented_constants() {
        let cfg = DelegationConfig::default();
        assert!((cfg.gamma - DEFAULT_GAMMA).abs() < 1e-9);
        assert!((cfg.delta - DEFAULT_DELTA).abs() < 1e-9);
        assert!((cfg.kappa - DEFAULT_KAPPA).abs() < 1e-9);
        assert!((cfg.lambda - DEFAULT_LAMBDA).abs() < 1e-9);
        assert!(!cfg.enabled);
    }
}