use std::sync::atomic::{AtomicBool, Ordering};
use tokio::time::{interval, Duration};
use crate::constants::DEFAULT_RECOVERY_PROBE_INTERVAL_SECONDS;
/// Escalating rate-limit state shared between workers and the orchestrator.
///
/// The tier only moves upward through the `report_*` methods on
/// [`RateLimiter`] and drops back to [`RateLimitTier::None`] on recovery.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum RateLimitTier {
    /// No active rate limit; dispatch and orchestration proceed normally.
    #[default]
    None,
    /// Workers reported a rate limit; new dispatch is paused.
    WorkerLimited,
    /// The orchestrator reported a rate limit; dispatch and orchestration are paused.
    OrchestratorLimited,
    /// Both workers and the orchestrator reported rate limits.
    FullHibernate,
}
/// Tracks the current [`RateLimitTier`] and drives periodic recovery probing.
#[derive(Debug)]
pub struct RateLimiter {
    /// Current tier; the source of truth for the `is_*_paused` queries.
    tier: std::sync::Mutex<RateLimitTier>,
    /// Set by every rate-limit report and cleared on recovery (always written
    /// while the `tier` lock is held); read by the recovery probe to decide
    /// whether to call the API.
    paused: AtomicBool,
    /// Seconds between recovery-probe ticks.
    recovery_interval: u64,
}
impl RateLimiter {
pub fn new(recovery_interval: u64) -> Self {
Self {
tier: std::sync::Mutex::new(RateLimitTier::None),
paused: AtomicBool::new(false),
recovery_interval,
}
}
pub fn current_tier(&self) -> RateLimitTier {
*self.tier.lock().unwrap()
}
pub fn is_dispatch_paused(&self) -> bool {
let tier = self.current_tier();
matches!(
tier,
RateLimitTier::WorkerLimited
| RateLimitTier::OrchestratorLimited
| RateLimitTier::FullHibernate
)
}
pub fn is_orchestrator_paused(&self) -> bool {
let tier = self.current_tier();
matches!(
tier,
RateLimitTier::OrchestratorLimited | RateLimitTier::FullHibernate
)
}
pub fn report_worker_rate_limit(&self) {
let mut tier = self.tier.lock().unwrap();
match *tier {
RateLimitTier::None => {
*tier = RateLimitTier::WorkerLimited;
}
RateLimitTier::OrchestratorLimited => {
*tier = RateLimitTier::FullHibernate;
}
_ => {}
}
self.paused.store(true, Ordering::Relaxed);
}
pub fn report_orchestrator_rate_limit(&self) {
let mut tier = self.tier.lock().unwrap();
match *tier {
RateLimitTier::None => {
*tier = RateLimitTier::OrchestratorLimited;
}
RateLimitTier::WorkerLimited => {
*tier = RateLimitTier::FullHibernate;
}
_ => {}
}
self.paused.store(true, Ordering::Relaxed);
}
pub fn report_recovery(&self) {
let mut tier = self.tier.lock().unwrap();
*tier = RateLimitTier::None;
self.paused.store(false, Ordering::Relaxed);
}
pub async fn start_recovery_probe(&self) {
let mut tick = interval(Duration::from_secs(self.recovery_interval));
loop {
tick.tick().await;
if !self.paused.load(Ordering::Relaxed) {
continue; }
match probe_api().await {
Ok(()) => {
self.report_recovery();
tracing::info!("Rate limit recovery: API probe succeeded, resuming");
}
Err(e) => {
tracing::warn!("Rate limit recovery probe failed: {}", e);
}
}
}
}
}
impl Default for RateLimiter {
fn default() -> Self {
Self::new(DEFAULT_RECOVERY_PROBE_INTERVAL_SECONDS)
}
}
async fn probe_api() -> Result<(), String> {
let output = tokio::process::Command::new("gh")
.args(["api", "rate_limit"])
.output()
.await
.map_err(|e| e.to_string())?;
if output.status.success() {
Ok(())
} else {
Err("API probe failed".into())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Limiter shared by every test; the interval is irrelevant because no
    /// test starts the async recovery probe.
    fn fresh() -> RateLimiter {
        RateLimiter::new(300)
    }

    /// Reads the raw pause flag consulted by the recovery probe.
    fn paused_flag(limiter: &RateLimiter) -> bool {
        limiter.paused.load(Ordering::Relaxed)
    }

    #[test]
    fn new_rate_limiter_starts_at_none() {
        assert_eq!(fresh().current_tier(), RateLimitTier::None);
    }

    #[test]
    fn report_worker_rate_limit_escalates_to_worker_limited() {
        let limiter = fresh();
        limiter.report_worker_rate_limit();
        assert_eq!(limiter.current_tier(), RateLimitTier::WorkerLimited);
    }

    #[test]
    fn report_orchestrator_rate_limit_escalates_to_orchestrator_limited() {
        let limiter = fresh();
        limiter.report_orchestrator_rate_limit();
        assert_eq!(limiter.current_tier(), RateLimitTier::OrchestratorLimited);
    }

    #[test]
    fn both_reports_escalate_to_full_hibernate() {
        let limiter = fresh();
        limiter.report_worker_rate_limit();
        limiter.report_orchestrator_rate_limit();
        assert_eq!(limiter.current_tier(), RateLimitTier::FullHibernate);
    }

    #[test]
    fn both_reports_reverse_order_escalate_to_full_hibernate() {
        let limiter = fresh();
        limiter.report_orchestrator_rate_limit();
        limiter.report_worker_rate_limit();
        assert_eq!(limiter.current_tier(), RateLimitTier::FullHibernate);
    }

    #[test]
    fn report_recovery_resets_to_none() {
        let limiter = fresh();
        limiter.report_worker_rate_limit();
        limiter.report_orchestrator_rate_limit();
        assert_eq!(limiter.current_tier(), RateLimitTier::FullHibernate);

        limiter.report_recovery();
        assert_eq!(limiter.current_tier(), RateLimitTier::None);
    }

    #[test]
    fn is_dispatch_paused_true_when_worker_limited_or_higher() {
        let limiter = fresh();
        assert!(!limiter.is_dispatch_paused());

        limiter.report_worker_rate_limit();
        assert!(limiter.is_dispatch_paused());

        limiter.report_recovery();
        limiter.report_orchestrator_rate_limit();
        assert!(limiter.is_dispatch_paused());

        limiter.report_worker_rate_limit();
        assert!(limiter.is_dispatch_paused());
    }

    #[test]
    fn is_orchestrator_paused_true_when_orchestrator_limited_or_higher() {
        let limiter = fresh();
        assert!(!limiter.is_orchestrator_paused());

        limiter.report_worker_rate_limit();
        assert!(!limiter.is_orchestrator_paused());

        limiter.report_recovery();
        limiter.report_orchestrator_rate_limit();
        assert!(limiter.is_orchestrator_paused());

        limiter.report_worker_rate_limit();
        assert_eq!(limiter.current_tier(), RateLimitTier::FullHibernate);
        assert!(limiter.is_orchestrator_paused());
    }

    #[test]
    fn paused_flag_tracks_tier_state() {
        let limiter = fresh();
        assert!(!paused_flag(&limiter));

        limiter.report_worker_rate_limit();
        assert!(paused_flag(&limiter));

        limiter.report_recovery();
        assert!(!paused_flag(&limiter));
    }
}