use crate::{CurrentQuota, QuotaPolicy, QuotaResult, QuotaTooLow};
use std::sync::atomic::{AtomicU64, Ordering};
/// Attempts to consume `cost` tokens from the GCRA state whose theoretical arrival time (TAT),
/// in nanoseconds, is stored in `tat_ns`.
///
/// * `effective_now_ns` — the time (ns) the request is evaluated at.
/// * `cost` — number of tokens requested; `0` only reports the current quota and never
///   modifies `tat_ns`.
/// * `ns_per_token` — nanoseconds per token, used to convert `cost > 1` into a duration.
/// * `token_ns` — precomputed duration of a single token, used when `cost == 1`.
/// * `burst_tokens` / `burst_ns` — burst size in tokens and as a duration; forwarded to the
///   returned [`CurrentQuota`] snapshot, and `burst_ns` bounds how far the TAT may run ahead.
/// * `burst_after_one_ns` — admission offset used for `cost == 1` (presumably
///   `burst_ns - token_ns` precomputed by the caller — TODO confirm).
///
/// On success the TAT is advanced to `max(tat, now) + cost_ns` with a relaxed CAS retry loop,
/// and the snapshot reflects the new TAT.
///
/// # Errors
///
/// Returns [`QuotaTooLow`] (carrying a snapshot of the TAT that was observed) when `cost`
/// exceeds the whole burst, or when admitting it would push the TAT past the burst window.
#[inline]
pub(super) fn consume_tat(
    tat_ns: &AtomicU64,
    effective_now_ns: u64,
    cost: u64,
    ns_per_token: f64,
    token_ns: u64,
    burst_tokens: f64,
    burst_ns: u64,
    burst_after_one_ns: u64,
) -> QuotaResult {
    // Every snapshot shares the same clock and limit parameters; only the TAT differs.
    let snapshot = |tat: u64| {
        CurrentQuota::gcra(tat, effective_now_ns, burst_ns, ns_per_token, burst_tokens)
    };

    if cost == 0 {
        return Ok(snapshot(tat_ns.load(Ordering::Relaxed)));
    }

    // The cost and the admission limit do not depend on the stored TAT, so compute them once
    // rather than on every CAS retry.
    let cost_ns = if cost == 1 {
        token_ns
    } else {
        duration_ns(cost as f64, ns_per_token)
    };
    let allowed_tat_ns = if cost == 1 {
        effective_now_ns.saturating_add(burst_after_one_ns)
    } else if cost_ns > burst_ns {
        // Larger than the entire burst: can never be admitted, whatever the TAT is.
        return Err(QuotaTooLow::new(snapshot(tat_ns.load(Ordering::Relaxed))));
    } else {
        effective_now_ns.saturating_add(burst_ns - cost_ns)
    };

    // Idle time is not banked beyond `now`: the TAT restarts from `now` if it lags behind.
    let advance = |tat: u64| tat.max(effective_now_ns).saturating_add(cost_ns);

    match tat_ns.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |tat| {
        (tat <= allowed_tat_ns).then(|| advance(tat))
    }) {
        Ok(previous_tat_ns) => Ok(snapshot(advance(previous_tat_ns))),
        Err(current_tat_ns) => Err(QuotaTooLow::new(snapshot(current_tat_ns))),
    }
}
/// Tries to admit a request of `cost` tokens against the GCRA TAT stored in `tat_ns`.
///
/// Returns `true` and advances the TAT to `max(tat, now) + cost_ns` when the request fits in the
/// burst window. Returns `false`, leaving `tat_ns` untouched, when it does not. A `cost` of `0`
/// is always admitted without touching the atomic. A `cost` whose duration exceeds `burst_ns` is
/// always rejected.
///
/// `token_ns` and `burst_after_one_ns` are the precomputed single-token cost and admission
/// offset used on the `cost == 1` fast path.
#[inline]
pub(super) fn check_tat(
    tat_ns: &AtomicU64,
    effective_now_ns: u64,
    cost: u64,
    ns_per_token: f64,
    token_ns: u64,
    burst_ns: u64,
    burst_after_one_ns: u64,
) -> bool {
    let (cost_ns, limit_ns) = match cost {
        0 => return true,
        1 => (token_ns, effective_now_ns.saturating_add(burst_after_one_ns)),
        _ => {
            let needed_ns = duration_ns(cost as f64, ns_per_token);
            if needed_ns > burst_ns {
                return false;
            }
            (needed_ns, effective_now_ns.saturating_add(burst_ns - needed_ns))
        }
    };

    tat_ns
        .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |tat| {
            if tat > limit_ns {
                None
            } else {
                Some(tat.max(effective_now_ns).saturating_add(cost_ns))
            }
        })
        .is_ok()
}
/// Single-token fast path of the GCRA check.
///
/// Admits the request and advances the stored TAT by `token_ns` when the TAT does not exceed
/// `effective_now_ns + burst_after_one_ns` (saturating). Returns whether it was admitted.
/// A rejected request leaves `tat_ns` unchanged.
#[inline]
pub(super) fn check_one_tat(
    tat_ns: &AtomicU64,
    effective_now_ns: u64,
    token_ns: u64,
    burst_after_one_ns: u64,
) -> bool {
    let limit_ns = effective_now_ns.saturating_add(burst_after_one_ns);
    // A lagging TAT restarts from `now` before the token's cost is added.
    let admit = |tat: u64| {
        (tat <= limit_ns).then(|| effective_now_ns.max(tat).saturating_add(token_ns))
    };
    tat_ns
        .fetch_update(Ordering::Relaxed, Ordering::Relaxed, admit)
        .is_ok()
}
/// Returns the burst capacity, in tokens, for the GCRA limiter.
///
/// When `rate_ns` is not positive (`<= 0.0`), the capacity is just `initial_available_tokens`.
/// Otherwise it is `policy.capacity()`, raised to at least `initial_available_tokens` and never
/// below zero. Note that a NaN `rate_ns` fails the `<= 0.0` test and takes the second branch.
#[inline]
pub(super) fn burst_tokens(
    policy: &QuotaPolicy,
    initial_available_tokens: u64,
    rate_ns: f64,
) -> f64 {
    let initial_tokens = initial_available_tokens as f64;
    if rate_ns <= 0.0 {
        initial_tokens
    } else {
        policy.capacity().max(initial_tokens).max(0.0)
    }
}
/// Converts a rate (tokens per nanosecond) into nanoseconds per token.
///
/// The result is never below `1.0`. Rates that are non-finite, zero, negative or NaN fall back
/// to `1.0`.
#[inline]
pub(super) fn ns_per_token(rate_ns: f64) -> f64 {
    match rate_ns {
        rate if rate.is_finite() && rate > 0.0 => rate.recip().max(1.0),
        _ => 1.0,
    }
}
/// Converts `tokens` into a whole number of nanoseconds (`tokens * ns_per_token`, rounded up).
///
/// Non-positive token counts yield `0`, and negative products saturate to `0`. The result is
/// capped at `u64::MAX - 1`, and a NaN product is also mapped to that cap.
#[inline]
pub(super) fn duration_ns(tokens: f64, ns_per_token: f64) -> u64 {
    const MAX_DURATION_NS: u64 = u64::MAX - 1;
    if tokens <= 0.0 {
        return 0;
    }
    let ns = (tokens * ns_per_token).ceil();
    if ns.is_nan() {
        return MAX_DURATION_NS;
    }
    // Clamp on the integer side: `(u64::MAX - 1) as f64` rounds up to 2^64, so a float-side
    // clamp would let the saturating `as` cast produce `u64::MAX`.
    (ns as u64).min(MAX_DURATION_NS)
}
/// Computes the initial TAT offset (ns) for a limiter that starts with
/// `initial_available_tokens` of its `burst_tokens` capacity already available.
///
/// The missing tokens (`burst_tokens - available`) are converted to a duration with
/// `duration_ns`. `initial_available_tokens` is capped at `burst_tokens`, so a full (or
/// over-full) start yields `0`. A negative or NaN `burst_tokens` is treated as `0` instead of
/// panicking.
#[inline]
pub(super) fn initial_tat_ns(
    burst_tokens: f64,
    initial_available_tokens: u64,
    ns_per_token: f64,
) -> u64 {
    // `f64::clamp(0.0, burst_tokens)` panics when `burst_tokens` is below 0 or NaN. Normalise
    // the capacity first; since `u64 as f64` is never negative, `min` then gives the same
    // result as the clamp for every valid input.
    let capacity = burst_tokens.max(0.0);
    let initial_tokens = (initial_available_tokens as f64).min(capacity);
    duration_ns(capacity - initial_tokens, ns_per_token)
}