use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};
/// How a `CredentialRouter` chooses among the keys that are currently ready.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum RoutingStrategy {
    /// Rotate through the ready keys in turn. This is the default strategy.
    #[default]
    RoundRobin,
    /// Always use the first ready key in configuration order, moving on to a later key
    /// only while the earlier ones are unavailable.
    FillFirst,
}
/// Availability of a single configured key.
#[derive(Clone, Debug, Eq, PartialEq)]
enum KeyState {
    /// The key may be handed out.
    Ready,
    /// The key is unavailable until `until`; once that instant has passed the router
    /// treats it as ready again.
    Cooldown { until: Instant },
    /// The key is unavailable until explicitly cleared; it is never promoted by time.
    Blocked,
    /// Like `Blocked`, the key stays unavailable until explicitly cleared.
    /// NOTE(review): how this differs from `Blocked` is not visible here — confirm with callers.
    Disabled,
}
/// Hands out credential keys from a fixed list, skipping keys that are cooling down,
/// blocked, or disabled.
pub struct CredentialRouter {
    /// Selection policy applied when picking the next key.
    strategy: RoutingStrategy,
    /// Configured keys, in the order supplied at construction.
    keys: Vec<String>,
    /// Availability of each key; entry `i` describes `keys[i]`.
    states: Mutex<Vec<KeyState>>,
    /// Position at which the round-robin scan starts.
    index: AtomicUsize,
    /// Cooldown length applied when a key is marked as errored without an explicit delay.
    cooldown_duration: Duration,
    /// Optional retry limit; stored and reported back, not used for key selection.
    max_retry: Option<usize>,
}
impl CredentialRouter {
#[must_use]
pub fn new(keys: Vec<String>, cooldown_duration: Duration) -> Self {
assert!(
!keys.is_empty(),
"CredentialRouter requires at least one key"
);
let len = keys.len();
Self {
keys,
states: Mutex::new(vec![KeyState::Ready; len]),
index: AtomicUsize::new(0),
cooldown_duration,
strategy: RoutingStrategy::default(),
max_retry: None,
}
}
#[must_use]
pub fn with_strategy(mut self, strategy: RoutingStrategy) -> Self {
self.strategy = strategy;
self
}
#[must_use]
pub fn with_max_retry(mut self, max: usize) -> Self {
self.max_retry = Some(max);
self
}
#[must_use]
pub fn len(&self) -> usize {
self.keys.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.keys.is_empty()
}
#[must_use]
pub fn max_retry(&self) -> Option<usize> {
self.max_retry
}
pub fn next_key(&self) -> Option<&str> {
let len = self.keys.len();
let now = Instant::now();
let mut states = self.states.lock().expect("states lock");
for state in states.iter_mut() {
if let KeyState::Cooldown { until } = *state
&& now >= until
{
*state = KeyState::Ready;
}
}
match self.strategy {
RoutingStrategy::RoundRobin => {
let start = self.index.fetch_add(1, Ordering::Relaxed) % len;
for i in 0..len {
let idx = (start + i) % len;
if states[idx] == KeyState::Ready {
return Some(&self.keys[idx]);
}
}
None
}
RoutingStrategy::FillFirst => {
for (idx, state) in states.iter().enumerate() {
if *state == KeyState::Ready {
return Some(&self.keys[idx]);
}
}
None
}
}
}
pub fn mark_error(&self, key: &str) {
if let Some(idx) = self.keys.iter().position(|k| k == key) {
let mut states = self.states.lock().expect("states lock");
states[idx] = KeyState::Cooldown {
until: Instant::now() + self.cooldown_duration,
};
}
}
pub fn mark_error_with_delay(&self, key: &str, delay: Duration) {
if let Some(idx) = self.keys.iter().position(|k| k == key) {
let mut states = self.states.lock().expect("states lock");
states[idx] = KeyState::Cooldown {
until: Instant::now() + delay,
};
}
}
pub fn mark_blocked(&self, key: &str) {
if let Some(idx) = self.keys.iter().position(|k| k == key) {
let mut states = self.states.lock().expect("states lock");
states[idx] = KeyState::Blocked;
}
}
pub fn mark_disabled(&self, key: &str) {
if let Some(idx) = self.keys.iter().position(|k| k == key) {
let mut states = self.states.lock().expect("states lock");
states[idx] = KeyState::Disabled;
}
}
pub fn clear_cooldown(&self, key: &str) {
if let Some(idx) = self.keys.iter().position(|k| k == key) {
let mut states = self.states.lock().expect("states lock");
states[idx] = KeyState::Ready;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a router (default strategy) over the given key names.
    fn router_over(names: &[&str], cooldown: Duration) -> CredentialRouter {
        let keys: Vec<String> = names.iter().map(|&name| name.to_owned()).collect();
        CredentialRouter::new(keys, cooldown)
    }

    /// Builds a fill-first router with a one-minute cooldown.
    fn fill_first_over(names: &[&str]) -> CredentialRouter {
        router_over(names, Duration::from_mins(1)).with_strategy(RoutingStrategy::FillFirst)
    }

    #[test]
    fn test_round_robin() {
        let router = router_over(&["key-a", "key-b", "key-c"], Duration::from_mins(1));
        let picks: Vec<String> = (0..4)
            .map(|_| router.next_key().unwrap().to_string())
            .collect();
        assert_eq!(picks, ["key-a", "key-b", "key-c", "key-a"]);
    }

    #[test]
    fn test_cooldown_skips_key() {
        let router = router_over(&["key-a", "key-b"], Duration::from_mins(1));
        router.mark_error("key-a");
        assert_eq!(router.next_key().unwrap(), "key-b");
    }

    #[test]
    fn test_all_cooled_returns_none() {
        let router = router_over(&["key-a", "key-b"], Duration::from_mins(1));
        for name in ["key-a", "key-b"] {
            router.mark_error(name);
        }
        assert!(router.next_key().is_none());
    }

    #[test]
    fn test_clear_cooldown() {
        let router = router_over(&["key-a"], Duration::from_mins(1));
        router.mark_error("key-a");
        assert!(router.next_key().is_none());
        router.clear_cooldown("key-a");
        assert!(router.next_key().is_some());
    }

    #[test]
    fn test_single_key() {
        let router = router_over(&["only-key"], Duration::from_mins(1));
        for _ in 0..2 {
            assert_eq!(router.next_key().unwrap(), "only-key");
        }
    }

    #[test]
    fn test_len() {
        let router = router_over(&["a", "b"], Duration::from_secs(1));
        assert_eq!(router.len(), 2);
        assert!(!router.is_empty());
    }

    #[test]
    #[should_panic(expected = "at least one key")]
    fn test_empty_keys_panics() {
        let _ = CredentialRouter::new(Vec::new(), Duration::from_secs(1));
    }

    #[test]
    fn test_fill_first_always_picks_first_ready() {
        let router = fill_first_over(&["key-a", "key-b", "key-c"]);
        for _ in 0..3 {
            assert_eq!(router.next_key().unwrap(), "key-a");
        }
    }

    #[test]
    fn test_fill_first_skips_cooled() {
        let router = fill_first_over(&["key-a", "key-b", "key-c"]);
        let steps = [
            ("key-a", Some("key-b")),
            ("key-b", Some("key-c")),
            ("key-c", None),
        ];
        for (cooled, expected) in steps {
            router.mark_error(cooled);
            assert_eq!(router.next_key(), expected);
        }
    }

    #[test]
    fn test_cooldown_auto_promotion() {
        let router = router_over(&["key-a"], Duration::from_millis(1));
        router.mark_error("key-a");
        std::thread::sleep(Duration::from_millis(5));
        assert_eq!(router.next_key().unwrap(), "key-a");
    }

    #[test]
    fn test_blocked_not_auto_promoted() {
        let router = router_over(&["key-a", "key-b"], Duration::from_millis(1));
        router.mark_blocked("key-a");
        std::thread::sleep(Duration::from_millis(5));
        assert_eq!(router.next_key().unwrap(), "key-b");
    }

    #[test]
    fn test_disabled_not_auto_promoted() {
        let router = router_over(&["key-a", "key-b"], Duration::from_millis(1));
        router.mark_disabled("key-a");
        std::thread::sleep(Duration::from_millis(5));
        assert_eq!(router.next_key().unwrap(), "key-b");
    }

    #[test]
    fn test_clear_cooldown_restores_blocked() {
        let router = router_over(&["key-a"], Duration::from_mins(1));
        router.mark_blocked("key-a");
        assert!(router.next_key().is_none());
        router.clear_cooldown("key-a");
        assert_eq!(router.next_key().unwrap(), "key-a");
    }

    #[test]
    fn test_clear_cooldown_restores_disabled() {
        let router = router_over(&["key-a"], Duration::from_mins(1));
        router.mark_disabled("key-a");
        assert!(router.next_key().is_none());
        router.clear_cooldown("key-a");
        assert_eq!(router.next_key().unwrap(), "key-a");
    }

    #[test]
    fn test_max_retry_configuration() {
        let router = router_over(&["a", "b", "c"], Duration::from_secs(1));
        assert!(router.max_retry().is_none());
        let router = router.with_max_retry(2);
        assert_eq!(router.max_retry(), Some(2));
    }

    #[test]
    fn test_default_strategy_is_round_robin() {
        let router = router_over(&["a"], Duration::from_secs(1));
        assert_eq!(router.next_key().unwrap(), "a");
    }

    #[test]
    fn test_with_strategy_builder() {
        let router =
            router_over(&["a"], Duration::from_secs(1)).with_strategy(RoutingStrategy::FillFirst);
        assert_eq!(router.next_key().unwrap(), "a");
    }

    #[test]
    fn test_all_blocked_returns_none() {
        let router = router_over(&["key-a", "key-b"], Duration::from_mins(1));
        for name in ["key-a", "key-b"] {
            router.mark_blocked(name);
        }
        assert!(router.next_key().is_none());
    }

    #[test]
    fn test_mixed_states() {
        let router = fill_first_over(&["key-a", "key-b", "key-c", "key-d"]);
        router.mark_disabled("key-a");
        router.mark_blocked("key-b");
        router.mark_error("key-c");
        assert_eq!(router.next_key().unwrap(), "key-d");
    }
}