use std::sync::atomic::Ordering;
use std::sync::Arc;
use rand::Rng;
use super::inner::{now_millis, KeyInner, STATE_DEAD};
use super::lease::KeyLease;
use crate::config::PoolConfig;
use crate::error::ApiError;
/// A pool of API keys that hands out token-budget leases.
///
/// Each key tracks its in-flight token reservations (`tpm_inflight`) against a per-key
/// `tpm_limit`. `acquire` reserves capacity lock-free with a compare-and-swap. The
/// `report_*` methods feed request outcomes back into the key's health state.
pub struct KeyPool {
    keys: Vec<Arc<KeyInner>>,
    config: PoolConfig,
}

impl KeyPool {
    /// Builds a pool over `keys` with the given tuning `config`.
    pub(crate) fn new(keys: Vec<Arc<KeyInner>>, config: PoolConfig) -> Self {
        Self { keys, config }
    }

    /// Tries to reserve `estimated_tokens` of capacity on some available key.
    ///
    /// How the scan works:
    /// - It starts at a random key to spread load and visits each key at most once.
    /// - It skips keys whose `is_available()` is false.
    /// - It skips keys whose in-flight tokens plus `estimated_tokens` would exceed
    ///   `tpm_limit`.
    /// - At most `config.max_cas_attempts` compare-and-swap reservations are tried.
    ///   Keys skipped for being unavailable or full do not count against that budget.
    /// - A lost CAS race moves on to the next key instead of retrying the contended one.
    ///
    /// Returns `None` when:
    /// - the pool is empty,
    /// - no key has room, or
    /// - the CAS budget is exhausted.
    ///
    /// On success the returned `KeyLease` records the reserved token count.
    /// NOTE(review): the tokens are presumably given back to `tpm_inflight` when the
    /// lease is dropped — confirm in the lease module.
    pub(crate) fn acquire(&self, estimated_tokens: u32) -> Option<KeyLease> {
        let n = self.keys.len();
        if n == 0 {
            return None;
        }
        let start = rand::thread_rng().gen_range(0..n);
        let mut cas_attempts = 0;
        for i in 0..n {
            if cas_attempts >= self.config.max_cas_attempts {
                break;
            }
            let key = &self.keys[(start + i) % n];
            if !key.is_available() {
                continue;
            }
            let cur = key.tpm_inflight.load(Ordering::Acquire);
            // `checked_add` prevents u32 wrap-around. Otherwise an oversized reservation
            // could wrap, slip under the limit, and be granted (or panic in debug builds).
            let next = match cur.checked_add(estimated_tokens) {
                Some(next) if next <= key.tpm_limit => next,
                _ => continue,
            };
            cas_attempts += 1;
            if key
                .tpm_inflight
                .compare_exchange(cur, next, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok()
            {
                return Some(KeyLease {
                    inner: Arc::clone(key),
                    reserved_tokens: estimated_tokens,
                });
            }
            // Another acquirer changed this key's counter first; back off briefly and
            // try the next key rather than fighting over this one.
            std::hint::spin_loop();
        }
        None
    }

    /// Feeds a failed request's error back into the health state of the lease's key.
    ///
    /// Effect of each error kind:
    /// - `Unauthorized`: the key's `state` is set to `STATE_DEAD`.
    /// - `RateLimited { retry_after }`: `cool_down_until` is pushed out to now plus
    ///   `retry_after`. An existing later deadline is never shortened.
    /// - `Provider` / `PrimitiveProvider`: counts toward the circuit breaker. Once
    ///   `config.circuit_breaker_threshold` consecutive failures accumulate, the
    ///   following happens:
    ///   - `failure_cool_down_until` is set to now plus `config.circuit_breaker_cooldown`
    ///     (never shortening a later deadline);
    ///   - the failure counter restarts from zero.
    /// - `Cancelled` / `Protocol`: no effect on the key.
    ///
    /// Deadlines are millisecond timestamps on the `now_millis()` clock. They saturate
    /// at `u64::MAX` instead of overflowing when the delay is absurdly large.
    pub(crate) fn report_error(&self, lease: &KeyLease, err: &ApiError) {
        let key = &lease.inner;
        match err {
            ApiError::Unauthorized => {
                key.state.store(STATE_DEAD, Ordering::Release);
            }
            ApiError::RateLimited { retry_after } => {
                let until = deadline_in(retry_after.as_millis());
                key.cool_down_until.fetch_max(until, Ordering::AcqRel);
            }
            ApiError::Provider(_) | ApiError::PrimitiveProvider(_) => {
                let failures = key
                    .consecutive_failures
                    .fetch_add(1, Ordering::Relaxed)
                    .saturating_add(1);
                if failures >= self.config.circuit_breaker_threshold {
                    let until = deadline_in(self.config.circuit_breaker_cooldown.as_millis());
                    key.failure_cool_down_until.fetch_max(until, Ordering::AcqRel);
                    key.consecutive_failures.store(0, Ordering::Relaxed);
                }
            }
            // These outcomes are not attributed to the key, so its health is untouched.
            ApiError::Cancelled | ApiError::Protocol(_) => {}
        }
    }

    /// Records a successful request on the lease's key by resetting its
    /// consecutive-failure count, so the circuit breaker starts counting from zero.
    pub(crate) fn report_success(&self, lease: &KeyLease) {
        lease.inner.consecutive_failures.store(0, Ordering::Relaxed);
    }

    /// Returns a snapshot of every key's state, in pool order.
    ///
    /// Each field is loaded independently with `Relaxed` ordering. The snapshot is
    /// therefore not atomic across fields or across keys.
    pub fn status(&self) -> Vec<KeyStatus> {
        self.keys
            .iter()
            .map(|k| KeyStatus {
                label: k.label.clone(),
                available: k.is_available(),
                tpm_inflight: k.tpm_inflight.load(Ordering::Relaxed),
                tpm_limit: k.tpm_limit,
                cool_down_until: k.cool_down_until.load(Ordering::Relaxed),
                failure_cool_down_until: k.failure_cool_down_until.load(Ordering::Relaxed),
                consecutive_failures: k.consecutive_failures.load(Ordering::Relaxed),
            })
            .collect()
    }
}

/// Returns the `now_millis()` timestamp that lies `delay_ms` milliseconds in the future.
///
/// The result saturates at `u64::MAX` rather than truncating the `u128` delay or
/// overflowing the addition.
fn deadline_in(delay_ms: u128) -> u64 {
    let delay_ms = delay_ms.min(u128::from(u64::MAX)) as u64;
    now_millis().saturating_add(delay_ms)
}
/// Snapshot of one key's state, as returned by `KeyPool::status`.
///
/// Fields are copied out of the key at snapshot time. They may be mutually
/// inconsistent if the key was being updated concurrently.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KeyStatus {
    /// The key's label.
    pub label: String,
    /// Result of the key's `is_available()` check at snapshot time.
    pub available: bool,
    /// Tokens currently reserved by outstanding leases.
    pub tpm_inflight: u32,
    /// Ceiling that `tpm_inflight` may not exceed when a new lease is granted.
    pub tpm_limit: u32,
    /// Rate-limit cooldown deadline, in milliseconds on the `now_millis()` clock.
    pub cool_down_until: u64,
    /// Circuit-breaker cooldown deadline, in milliseconds on the `now_millis()` clock.
    pub failure_cool_down_until: u64,
    /// Provider failures recorded since the last success or breaker trip.
    pub consecutive_failures: u32,
}