use loadwise_core::{Node, SelectionContext, Strategy, Weighted};
use serde::{Deserialize, Serialize};
use std::sync::Mutex;
/// One selectable account in an [`AccountSelector`] pool.
#[derive(Clone, Debug)]
pub struct AccountNode {
    /// Account identifier; also serves as the node's `Node::id`.
    pub account_id: String,
    /// Relative weight reported through `Weighted::weight`.
    pub weight: u32,
}
impl AccountNode {
    /// Creates a node for `account_id` with the default weight of `1`.
    #[must_use]
    pub fn new(account_id: impl Into<String>) -> Self {
        let account_id = account_id.into();
        Self { account_id, weight: 1 }
    }

    /// Returns this node with its weight replaced by `weight`.
    ///
    /// A requested weight of `0` is clamped up to `1`.
    #[must_use]
    pub fn with_weight(self, weight: u32) -> Self {
        Self {
            weight: weight.max(1),
            ..self
        }
    }
}
/// Identifies a node by its account id.
impl Node for AccountNode {
    type Id = String;

    fn id(&self) -> &Self::Id {
        &self.account_id
    }
}
/// Exposes the stored `weight` field to weighted strategies.
impl Weighted for AccountNode {
    fn weight(&self) -> u32 {
        let Self { weight, .. } = self;
        *weight
    }
}
/// Which selection algorithm an [`AccountSelector`] uses.
///
/// Serialized in `snake_case` (for example `weighted_round_robin`). The
/// default is [`StrategyKind::RoundRobin`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StrategyKind {
    /// Cycle through the pool in order.
    #[default]
    RoundRobin,
    /// Round-robin that takes each account's weight into account.
    WeightedRoundRobin,
    /// Random pick (`loadwise_core::strategy::Random`).
    Random,
    /// Random pick that takes each account's weight into account.
    WeightedRandom,
    /// Always the first account in the pool.
    Priority,
}
/// Routing configuration for one provider: which accounts are eligible and
/// how one of them is chosen.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RoutingPolicy {
    /// Provider this policy belongs to.
    pub provider: String,
    /// Optional family qualifier; [`AccountSelector`] does not read it.
    #[serde(default)]
    pub family: Option<String>,
    /// Selection algorithm; defaults to round-robin.
    #[serde(default)]
    pub strategy: StrategyKind,
    /// Allow-list of account ids. An empty list means every available
    /// account is eligible.
    #[serde(default)]
    pub accounts: Vec<String>,
    /// Per-account weights. Accounts that are not listed get weight `1`.
    #[serde(default)]
    pub weights: std::collections::HashMap<String, u32>,
}
/// Chooses account ids from a fixed pool using a configured [`Strategy`].
///
/// The most recent successful pick is remembered and can be read back
/// through `last_picked`.
pub struct AccountSelector {
    /// Eligible accounts; the strategy returns indices into this list.
    nodes: Vec<AccountNode>,
    /// Selection algorithm built from the policy's [`StrategyKind`].
    strategy: Box<dyn Strategy<AccountNode> + Send + Sync>,
    /// Id returned by the most recent successful `pick`, if any.
    last_picked: Mutex<Option<String>>,
}
impl AccountSelector {
    /// Builds a selector over the accounts in `available_accounts` that
    /// `policy` allows.
    ///
    /// When `policy.accounts` is empty, every available account is eligible
    /// and the pool keeps the order of `available_accounts`. Otherwise the
    /// pool contains only the listed accounts that are also available, in
    /// the order the policy lists them. This matters for
    /// [`StrategyKind::Priority`], which always picks the first pool entry.
    /// Repeated ids are collapsed to their first occurrence.
    ///
    /// Each account's weight comes from `policy.weights` and defaults to `1`.
    /// [`AccountNode::with_weight`] raises a configured `0` to `1`.
    #[must_use]
    pub fn new(policy: &RoutingPolicy, available_accounts: &[&str]) -> Self {
        let candidates: Vec<&str> = if policy.accounts.is_empty() {
            available_accounts.to_vec()
        } else {
            // Iterate the policy list rather than the availability list so
            // that the configured order, which `Priority` relies on, survives.
            policy
                .accounts
                .iter()
                .map(String::as_str)
                .filter(|id| available_accounts.contains(id))
                .collect()
        };

        // A repeated id would otherwise become two nodes with the same
        // `Node::id` and receive a double share of picks.
        let mut seen = std::collections::HashSet::new();
        let nodes: Vec<AccountNode> = candidates
            .into_iter()
            .filter(|id| seen.insert(*id))
            .map(|id| {
                let weight = policy.weights.get(id).copied().unwrap_or(1);
                AccountNode::new(id).with_weight(weight)
            })
            .collect();

        Self {
            strategy: build_strategy(policy.strategy),
            nodes,
            last_picked: Mutex::new(None),
        }
    }

    /// Number of accounts in the pool.
    #[must_use]
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// `true` when no account is eligible.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Asks the strategy for the next account and returns its id.
    ///
    /// Returns `None` in three cases: the pool is empty, the strategy
    /// declines to choose, or the strategy returns an index outside the pool.
    /// A successful pick is also recorded for [`Self::last_picked`].
    #[must_use]
    pub fn pick(&self, ctx: &SelectionContext) -> Option<String> {
        let idx = self.strategy.select(&self.nodes, ctx)?;
        let picked = self.nodes.get(idx)?.account_id.clone();
        *self.last_slot() = Some(picked.clone());
        Some(picked)
    }

    /// Id returned by the most recent successful [`Self::pick`], or `None`
    /// if nothing has been picked yet.
    #[must_use]
    pub fn last_picked(&self) -> Option<String> {
        self.last_slot().clone()
    }

    /// Locks the last-picked slot and ignores mutex poisoning.
    ///
    /// Ignoring poisoning is safe here. The guarded value is a plain
    /// `Option<String>` that is only ever replaced by a single assignment,
    /// so a panic in another holder cannot leave it half-written.
    fn last_slot(&self) -> std::sync::MutexGuard<'_, Option<String>> {
        self.last_picked
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
impl From<byokey_config::PolicyStrategyKind> for StrategyKind {
fn from(value: byokey_config::PolicyStrategyKind) -> Self {
use byokey_config::PolicyStrategyKind as P;
match value {
P::RoundRobin => Self::RoundRobin,
P::WeightedRoundRobin => Self::WeightedRoundRobin,
P::Random => Self::Random,
P::WeightedRandom => Self::WeightedRandom,
P::Priority => Self::Priority,
}
}
}
impl From<&byokey_config::RoutingPolicyEntry> for RoutingPolicy {
fn from(entry: &byokey_config::RoutingPolicyEntry) -> Self {
Self {
provider: entry.provider.to_string(),
family: entry.family.clone(),
strategy: entry.strategy.into(),
accounts: entry.accounts.clone(),
weights: entry.weights.clone(),
}
}
}
/// Strict-priority strategy: always chooses the first node in the pool.
///
/// Returns `None` only when the pool is empty. The selection context is
/// ignored.
struct PriorityStrategy;

impl Strategy<AccountNode> for PriorityStrategy {
    fn select(&self, nodes: &[AccountNode], _ctx: &SelectionContext) -> Option<usize> {
        (!nodes.is_empty()).then_some(0)
    }
}
/// Instantiates the strategy implementation for `kind`.
///
/// [`StrategyKind::Priority`] uses the local [`PriorityStrategy`]. Every other
/// kind comes from `loadwise_core::strategy`.
fn build_strategy(kind: StrategyKind) -> Box<dyn Strategy<AccountNode> + Send + Sync> {
    use loadwise_core::strategy::{Random, RoundRobin, WeightedRandom, WeightedRoundRobin};

    match kind {
        StrategyKind::RoundRobin => Box::new(RoundRobin::new()),
        StrategyKind::WeightedRoundRobin => Box::new(WeightedRoundRobin::new()),
        StrategyKind::Random => Box::new(Random::new()),
        StrategyKind::WeightedRandom => Box::new(WeightedRandom::new()),
        StrategyKind::Priority => Box::new(PriorityStrategy),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a `claude` policy with the given strategy and allow-list and no
    /// explicit weights.
    fn policy(strategy: StrategyKind, accounts: &[&str]) -> RoutingPolicy {
        RoutingPolicy {
            provider: String::from("claude"),
            family: None,
            strategy,
            accounts: accounts.iter().copied().map(String::from).collect(),
            weights: HashMap::new(),
        }
    }

    /// Performs `n` picks that share one default selection context.
    fn draw(selector: &AccountSelector, n: usize) -> Vec<Option<String>> {
        let ctx = SelectionContext::default();
        (0..n).map(|_| selector.pick(&ctx)).collect()
    }

    /// Wraps each id as a successful pick, for comparing against `draw`.
    fn some(ids: &[&str]) -> Vec<Option<String>> {
        ids.iter().map(|id| Some((*id).to_string())).collect()
    }

    #[test]
    fn round_robin_cycles_through_pool() {
        let selector =
            AccountSelector::new(&policy(StrategyKind::RoundRobin, &[]), &["a", "b", "c"]);
        assert_eq!(draw(&selector, 4), some(&["a", "b", "c", "a"]));
    }

    #[test]
    fn accounts_filter_restricts_pool() {
        let selector =
            AccountSelector::new(&policy(StrategyKind::RoundRobin, &["a", "c"]), &["a", "b", "c"]);
        assert_eq!(selector.len(), 2);
        assert_eq!(draw(&selector, 2), some(&["a", "c"]));
    }

    #[test]
    fn empty_pool_returns_none() {
        let selector = AccountSelector::new(&policy(StrategyKind::RoundRobin, &[]), &[]);
        assert_eq!(draw(&selector, 1), vec![None]);
    }

    #[test]
    fn last_picked_tracks_selection() {
        let selector = AccountSelector::new(&policy(StrategyKind::RoundRobin, &[]), &["x", "y"]);
        assert!(selector.last_picked().is_none());
        let _ = draw(&selector, 1);
        assert_eq!(selector.last_picked().as_deref(), Some("x"));
    }

    #[test]
    fn weighted_round_robin_respects_weights() {
        let mut weighted = policy(StrategyKind::WeightedRoundRobin, &[]);
        weighted.weights.extend([("a".to_string(), 3), ("b".to_string(), 1)]);
        let selector = AccountSelector::new(&weighted, &["a", "b"]);

        let picks: Vec<String> = draw(&selector, 4).into_iter().flatten().collect();
        let wins = |id: &str| picks.iter().filter(|pick| pick.as_str() == id).count();

        assert_eq!(wins("a"), 3, "a should win 3 of 4 picks");
        assert_eq!(wins("b"), 1, "b should win 1 of 4 picks");
    }
}