Skip to main content

shunt/
state.rs

1/// Runtime state: per-account cooldowns/disabling + conversation stickiness.
2///
3/// Thread-safe via Arc<Mutex<>>. Cooldowns and disables are persisted to disk;
4/// stickiness is ephemeral (lost on restart is acceptable).
5use crate::config::RoutingStrategy;
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet, VecDeque};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
11use parking_lot::Mutex;
12use std::sync::Arc;
13use std::time::{SystemTime, UNIX_EPOCH};
14use tracing::warn;
15
/// Milliseconds elapsed since the Unix epoch according to the system clock.
///
/// If the clock reads earlier than the epoch the result is 0 instead of a panic.
fn now_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis() as u64
}
22
/// Public version of `now_ms()` for use from other modules.
///
/// Returns the current wall-clock time as milliseconds since the Unix epoch,
/// exactly as the private `now_ms()` computes it.
pub fn now_ms_pub() -> u64 {
    now_ms()
}
27
28// ---------------------------------------------------------------------------
29// Routing snapshot (single-lock data for pick_account)
30// ---------------------------------------------------------------------------
31
/// Pre-computed per-account data for the router, taken from a single mutex lock.
///
/// Values are derived by `StateStore::routing_snapshot` at snapshot time and
/// are not updated afterwards.
#[derive(Debug, Clone)]
pub struct AccountRoutingData {
    /// Not disabled, not auth-failed, and any cooldown has elapsed.
    /// `true` when no `AccountState` exists for the account.
    pub available: bool,
    /// Copy of `AccountState::health_check_failed` (`false` when unknown).
    pub health_check_failed: bool,
    /// A 5h or 7d window reports "exhausted" and its reset is still in the future.
    pub exhausted: bool,
    /// Epoch-ms cooldown deadline (0 when the account has no state entry).
    pub cooldown_until_ms: u64,
    /// 5h utilization; 0.0 when unknown or when the window's reset has passed.
    pub util_5h: f64,
    /// 7d utilization; 0.0 when unknown or when the window's reset has passed.
    pub util_7d: f64,
    /// Epoch-seconds of the 5h reset, only if still in the future.
    pub reset_5h_secs: Option<u64>,
    /// Epoch-seconds of the 7d reset, only if still in the future.
    pub reset_7d_secs: Option<u64>,
    /// Number of burst-window timestamps recorded within the last 60 000 ms.
    pub burst_request_count: usize,
}
45
/// Snapshot of all routing-relevant state, taken with a single lock.
#[derive(Debug, Clone)]
pub struct RoutingSnapshot {
    /// Routing data keyed by account name; covers every name that appears in
    /// either the account-state map or the rate-limit map.
    pub accounts: HashMap<String, AccountRoutingData>,
    /// Epoch-seconds at which the snapshot was taken.
    pub now_secs: u64,
}
52
53// ---------------------------------------------------------------------------
54// On-disk data
55// ---------------------------------------------------------------------------
56
/// Per-account availability state.
///
/// Fields marked `#[serde(default)]` are serialized and tolerate being absent
/// on load; fields marked `#[serde(skip)]` are never serialized and start at
/// their default value after a reload.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct AccountState {
    /// Epoch-ms timestamp after which this account is usable again (0 = not cooling).
    #[serde(default)]
    pub cooldown_until_ms: u64,
    /// Permanently disabled (auth failure).
    #[serde(default)]
    pub disabled: bool,
    /// OAuth credentials are expired and need re-authorization via `shunt add-account`.
    #[serde(default)]
    pub auth_failed: bool,
    /// Account failed health-check probes — skip in routing until it recovers.
    #[serde(default)]
    pub health_check_failed: bool,
    /// Consecutive health-check failure count (for exponential backoff). Ephemeral.
    #[serde(skip)]
    pub health_check_failures: u32,
    /// Epoch-ms of the last health-check probe attempt. Ephemeral.
    #[serde(skip)]
    pub last_health_check_ms: u64,
}
78
/// Binding of one conversation fingerprint to the account that serves it.
#[derive(Serialize, Deserialize, Default, Clone)]
struct StickyEntry {
    /// Account the fingerprint is bound to.
    account_name: String,
    /// Epoch-ms after which the binding is ignored and eligible for pruning.
    expires_at_ms: u64,
}
84
/// Rolling 5-hour quota window per account.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct QuotaWindow {
    /// Epoch-ms when this window started (0 = never used).
    #[serde(default)]
    pub window_start_ms: u64,
    /// Input tokens accumulated since `window_start_ms`.
    #[serde(default)]
    pub input_tokens: u64,
    /// Output tokens accumulated since `window_start_ms`.
    #[serde(default)]
    pub output_tokens: u64,
}
96
97impl QuotaWindow {
98    pub fn total_tokens(&self) -> u64 {
99        self.input_tokens + self.output_tokens
100    }
101    pub fn window_expires_ms(&self) -> Option<u64> {
102        if self.window_start_ms == 0 { None } else { Some(self.window_start_ms + WINDOW_MS) }
103    }
104}
105
/// Length of one `QuotaWindow`, in milliseconds.
pub const WINDOW_MS: u64 = 5 * 60 * 60 * 1000; // 5 hours
107
108// ---------------------------------------------------------------------------
109// Request log
110// ---------------------------------------------------------------------------
111
/// A single proxied request recorded for the live monitor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestLog {
    /// Timestamp of the request — presumably epoch-ms like the other `*_ms` fields; confirm at the call site.
    pub ts_ms: u64,
    /// Name of the account that handled the request.
    pub account: String,
    /// Model name associated with the request.
    pub model: String,
    /// Status code of the response — presumably the upstream HTTP status; confirm at the call site.
    pub status: u16,
    /// Input tokens reported for the request.
    pub input_tokens: u64,
    /// Output tokens reported for the request.
    pub output_tokens: u64,
    /// Request duration in milliseconds.
    pub duration_ms: u64,
}
123
// NOTE(review): presumably the cap on `StateData::recent_requests`; the code
// that enforces it is not visible in this part of the file — confirm.
const MAX_RECENT: usize = 200;
125
/// Rate-limit info extracted from `anthropic-ratelimit-unified-*` response headers.
///
/// Every header-derived field is optional; `None` means the value was not available.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct RateLimitInfo {
    /// 5-hour window utilization 0.0–1.0
    pub utilization_5h: Option<f64>,
    /// Unix epoch seconds when 5h window resets
    pub reset_5h: Option<u64>,
    /// "allowed" | "exhausted"
    pub status_5h: Option<String>,
    /// 7-day window utilization 0.0–1.0
    pub utilization_7d: Option<f64>,
    /// Unix epoch seconds when 7d window resets
    pub reset_7d: Option<u64>,
    /// 7-day window status; compared against "exhausted" the same way as `status_5h`.
    pub status_7d: Option<String>,
    /// Extra usage (overage) status: "allowed" | "rejected"
    pub overage_status: Option<String>,
    /// Reason overage is disabled, when one is reported.
    pub overage_disabled_reason: Option<String>,
    /// Which claim is currently representative ("five_hour" | "seven_day")
    pub representative_claim: Option<String>,
    /// NOTE(review): presumably epoch-ms of when this record was captured — confirm at the producer.
    pub updated_ms: u64,
}
147
/// Per-day token and API-cost accumulator (all accounts combined).
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct DailyBucket {
    /// Input tokens accumulated for the day.
    pub input_tokens: u64,
    /// Output tokens accumulated for the day.
    pub output_tokens: u64,
    /// What those tokens would have cost on the public API (USD).
    pub api_cost_usd: f64,
}
156
/// Snapshot returned by `savings_snapshot()` for the status endpoint + CLI.
///
/// Three groups (today / week / all-time), each with input tokens, output
/// tokens and the equivalent API cost in USD.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct SavingsSnapshot {
    // Today.
    pub today_input: u64,
    pub today_output: u64,
    pub today_cost_usd: f64,
    // Current week — NOTE(review): exact window boundaries are defined by the producer, not visible here.
    pub week_input: u64,
    pub week_output: u64,
    pub week_cost_usd: f64,
    // All-time totals.
    pub all_time_input: u64,
    pub all_time_output: u64,
    pub all_time_cost_usd: f64,
}
170
/// Everything guarded by the store's mutex.
///
/// Fields marked `#[serde(default)]` take part in (de)serialization; fields
/// marked `#[serde(skip)]` are runtime-only and reset to their default on load.
#[derive(Serialize, Deserialize, Default, Clone)]
struct StateData {
    /// Availability state keyed by account name.
    #[serde(default)]
    accounts: HashMap<String, AccountState>,
    /// Conversation fingerprint → sticky account binding. Not `serde(skip)`,
    /// so it is part of the serialized form; expired entries are pruned on load.
    #[serde(default)]
    sticky: HashMap<String, StickyEntry>,
    /// Locally tracked token usage per account.
    #[serde(default)]
    quota: HashMap<String, QuotaWindow>,
    /// Latest rate-limit info per account.
    #[serde(default)]
    rate_limits: HashMap<String, RateLimitInfo>,
    /// If set, all requests are forced to this account (overrides routing).
    #[serde(default)]
    pinned_account: Option<String>,
    /// The most recent account that successfully handled a proxied request.
    #[serde(default)]
    last_used_account: Option<String>,
    /// Recent request log (ephemeral — not persisted to disk).
    #[serde(skip)]
    recent_requests: VecDeque<RequestLog>,
    /// Runtime model override — all requests use this model if set (ephemeral).
    #[serde(skip)]
    model_override: Option<String>,
    /// Runtime routing strategy override (ephemeral — not persisted).
    #[serde(skip)]
    routing_strategy_override: Option<RoutingStrategy>,
    /// Per-account burst window: timestamps of recent requests (ephemeral).
    #[serde(skip)]
    burst_windows: HashMap<String, VecDeque<u64>>,
    /// Runtime burst RPM limit override (ephemeral).
    #[serde(skip)]
    burst_rpm_limit_override: Option<u32>,
    /// Runtime fallback model override (ephemeral).
    /// `Some(Some("model"))` = explicit override, `Some(None)` = explicitly disabled, `None` = use config/auto.
    #[serde(skip)]
    fallback_model_override: Option<Option<String>>,
    /// Runtime effort override (ephemeral). None = passthrough, Some("max") = override.
    #[serde(skip)]
    effort_override: Option<String>,
    /// Runtime thinking mode override (ephemeral). None = passthrough, Some("adaptive"/"disabled") = override.
    #[serde(skip)]
    thinking_override: Option<String>,
    /// Daily token + cost buckets keyed by "YYYY-MM-DD" (all accounts combined).
    #[serde(default)]
    global_daily: HashMap<String, DailyBucket>,
    /// All-time totals.
    #[serde(default)]
    all_time_input: u64,
    #[serde(default)]
    all_time_output: u64,
    #[serde(default)]
    all_time_cost_usd: f64,
}
223
224// ---------------------------------------------------------------------------
225// Store
226// ---------------------------------------------------------------------------
227
/// Handle to the shared runtime state.
///
/// Every field except `path` is behind an `Arc`, so a clone of the handle
/// shares the same underlying state, flags and counter.
#[derive(Clone)]
pub struct StateStore {
    /// Location of the state file (`/dev/null` for the in-memory test store).
    path: PathBuf,
    /// The mutex-guarded state itself.
    inner: Arc<Mutex<StateData>>,
    /// Set to true when a write is needed; the background writer thread clears it.
    pending: Arc<AtomicBool>,
    /// Monotonically-increasing counter for round-robin account selection.
    round_robin: Arc<AtomicUsize>,
    /// When true, all daemon alert notifications are suppressed (ephemeral).
    alerts_muted: Arc<AtomicBool>,
}
239
240impl StateStore {
241    /// Create a fresh in-memory store with no backing file (useful for tests).
242    pub fn new_empty() -> Self {
243        // No background writer thread for the null store — writes are no-ops.
244        Self {
245            path: PathBuf::from("/dev/null"),
246            inner: Arc::new(Mutex::new(StateData::default())),
247            pending: Arc::new(AtomicBool::new(false)),
248            round_robin: Arc::new(AtomicUsize::new(0)),
249            alerts_muted: Arc::new(AtomicBool::new(false)),
250        }
251    }
252
253    pub fn load(path: &Path) -> Self {
254        let mut data: StateData = if path.exists() {
255            match std::fs::read_to_string(path) {
256                Ok(text) => serde_json::from_str(&text).unwrap_or_else(|e| {
257                    warn!("State file unreadable ({e}), starting fresh");
258                    StateData::default()
259                }),
260                Err(e) => {
261                    warn!("Cannot read state file ({e}), starting fresh");
262                    StateData::default()
263                }
264            }
265        } else {
266            StateData::default()
267        };
268        // Prune expired sticky entries so the file doesn't grow unbounded.
269        let now = now_ms();
270        data.sticky.retain(|_, v| v.expires_at_ms > now);
271
272        let store = Self {
273            path: path.to_owned(),
274            inner: Arc::new(Mutex::new(data)),
275            pending: Arc::new(AtomicBool::new(false)),
276            round_robin: Arc::new(AtomicUsize::new(0)),
277            alerts_muted: Arc::new(AtomicBool::new(false)),
278        };
279        store.start_writer_thread();
280        store
281    }
282
283    /// Spawn a single background thread that flushes state to disk at most every 100 ms.
284    /// This prevents unbounded thread spawning when many requests fire in rapid succession.
285    fn start_writer_thread(&self) {
286        let pending = Arc::clone(&self.pending);
287        let inner   = Arc::clone(&self.inner);
288        let path    = self.path.clone();
289        std::thread::spawn(move || {
290            loop {
291                std::thread::sleep(std::time::Duration::from_millis(100));
292                if pending.compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed).is_ok() {
293                    let data = inner.lock().clone();
294                    if let Err(e) = write_to_disk(&data, &path) {
295                        warn!("Failed to persist state: {e}");
296                    }
297                }
298            }
299        });
300    }
301
302    // -----------------------------------------------------------------------
303    // Availability
304    // -----------------------------------------------------------------------
305
306    pub fn is_available(&self, name: &str) -> bool {
307        let data = self.inner.lock();
308        match data.accounts.get(name) {
309            None => true,
310            Some(s) => !s.disabled && now_ms() >= s.cooldown_until_ms,
311        }
312    }
313
314    /// Returns true if the account's Anthropic quota is currently exhausted in any
315    /// active window (5h or 7d) — i.e. sending another request will get a 429.
316    pub fn is_exhausted(&self, name: &str) -> bool {
317        let now_secs = SystemTime::now()
318            .duration_since(UNIX_EPOCH)
319            .unwrap_or_default()
320            .as_secs();
321        let data = self.inner.lock();
322        let Some(rl) = data.rate_limits.get(name) else { return false };
323        // Only consider a window exhausted if its reset is still in the future
324        // (i.e. the window hasn't rolled over yet).
325        let exhausted_5h = rl.status_5h.as_deref() == Some("exhausted")
326            && rl.reset_5h.map(|t| t > now_secs).unwrap_or(false);
327        let exhausted_7d = rl.status_7d.as_deref() == Some("exhausted")
328            && rl.reset_7d.map(|t| t > now_secs).unwrap_or(false);
329        exhausted_5h || exhausted_7d
330    }
331
    /// Fetch-and-increment monotonic counter for round-robin account cycling.
    ///
    /// Returns the value held *before* the increment. Uses relaxed ordering:
    /// only the atomicity of the increment is relied upon.
    pub fn next_rr_index(&self) -> usize {
        self.round_robin.fetch_add(1, Ordering::Relaxed)
    }
336
    /// Returns a snapshot of all account states for the status endpoint.
    ///
    /// The map is cloned under the lock, so the result is independent of later changes.
    pub fn account_states(&self) -> HashMap<String, AccountState> {
        self.inner.lock().accounts.clone()
    }
341
342    /// Single-lock snapshot of everything the router needs for account selection.
343    /// Avoids per-account mutex acquisitions (O(N) → O(1) locks per pick_account call).
344    pub fn routing_snapshot(&self) -> RoutingSnapshot {
345        let now_ms  = now_ms();
346        let now_secs = now_ms / 1_000;
347        let mut data = self.inner.lock();
348
349        // Collect all account names from both accounts and rate_limits maps.
350        let all_names: Vec<String> = {
351            let mut names: HashSet<&String> = data.accounts.keys().collect();
352            names.extend(data.rate_limits.keys());
353            names.into_iter().cloned().collect()
354        };
355
356        // Pre-compute burst counts (needs mutable access for pruning).
357        let burst_counts: HashMap<String, usize> = all_names.iter()
358            .map(|name| {
359                let count = data.burst_windows.get_mut(name)
360                    .map(|deque| Self::burst_count_inner(deque, 60_000))
361                    .unwrap_or(0);
362                (name.clone(), count)
363            })
364            .collect();
365
366        let accounts: HashMap<String, AccountRoutingData> = all_names.iter().map(|name| {
367            let acc = data.accounts.get(name);
368            let available = acc.map(|a| !a.disabled && !a.auth_failed && now_ms >= a.cooldown_until_ms).unwrap_or(true);
369            let health_check_failed = acc.map(|a| a.health_check_failed).unwrap_or(false);
370            let cooldown_until_ms = acc.map(|a| a.cooldown_until_ms).unwrap_or(0);
371
372            let (util_5h, reset_5h, util_7d, reset_7d, exhausted) =
373                if let Some(rl) = data.rate_limits.get(name) {
374                    let r5 = rl.reset_5h.filter(|&t| t > now_secs);
375                    let r7 = rl.reset_7d.filter(|&t| t > now_secs);
376                    let u5 = if r5.is_some() { rl.utilization_5h.unwrap_or(0.0) } else { 0.0 };
377                    let u7 = if r7.is_some() { rl.utilization_7d.unwrap_or(0.0) } else { 0.0 };
378                    let ex = (rl.status_5h.as_deref() == Some("exhausted") && r5.is_some())
379                          || (rl.status_7d.as_deref() == Some("exhausted") && r7.is_some());
380                    (u5, r5, u7, r7, ex)
381                } else {
382                    (0.0, None, 0.0, None, false)
383                };
384
385            let burst_request_count = burst_counts.get(name).copied().unwrap_or(0);
386
387            (name.clone(), AccountRoutingData {
388                available,
389                health_check_failed,
390                exhausted,
391                cooldown_until_ms,
392                util_5h,
393                util_7d,
394                reset_5h_secs: reset_5h,
395                reset_7d_secs: reset_7d,
396                burst_request_count,
397            })
398        }).collect();
399
400        RoutingSnapshot { accounts, now_secs }
401    }
402
403    // -----------------------------------------------------------------------
404    // Burst window tracking
405    // -----------------------------------------------------------------------
406
407    /// Record a request timestamp for burst-rate tracking.
408    pub fn record_request_burst(&self, name: &str) {
409        let mut data = self.inner.lock();
410        data.burst_windows.entry(name.to_owned()).or_default().push_back(now_ms());
411    }
412
413    /// Count requests in the last `window_ms` for an account.
414    fn burst_count_inner(deque: &mut VecDeque<u64>, window_ms: u64) -> usize {
415        let cutoff = now_ms().saturating_sub(window_ms);
416        // Prune old entries from front
417        while deque.front().map(|&t| t < cutoff).unwrap_or(false) {
418            deque.pop_front();
419        }
420        deque.len()
421    }
422
423    // -----------------------------------------------------------------------
424    // Cooldown / disable
425    // -----------------------------------------------------------------------
426
427    pub fn set_cooldown(&self, name: &str, duration_ms: u64) {
428        {
429            let mut data = self.inner.lock();
430            let acc = data.accounts.entry(name.to_owned()).or_default();
431            acc.cooldown_until_ms = now_ms() + duration_ms;
432        }
433        self.persist();
434    }
435
436    /// Like `set_cooldown`, but staggers the deadline so it doesn't collide with
437    /// other accounts already cooling. Prevents the cascade where both accounts
438    /// expire simultaneously, both get 429'd again, and loop forever.
439    /// Adds 5s offset per account already cooling within ±5s of our target deadline.
440    pub fn set_cooldown_staggered(&self, name: &str, duration_ms: u64) {
441        const STAGGER_MS: u64 = 5_000;
442        {
443            let mut data = self.inner.lock();
444            let now = now_ms();
445            let target = now + duration_ms;
446
447            // Count other accounts with cooldowns expiring within STAGGER_MS of our target
448            let nearby_count = data.accounts.iter()
449                .filter(|(n, a)| {
450                    *n != name
451                        && a.cooldown_until_ms > now
452                        && (a.cooldown_until_ms as i64 - target as i64).unsigned_abs() < STAGGER_MS
453                })
454                .count() as u64;
455
456            let offset = nearby_count.saturating_mul(STAGGER_MS);
457            let acc = data.accounts.entry(name.to_owned()).or_default();
458            acc.cooldown_until_ms = target + offset;
459        }
460        self.persist();
461    }
462
463    pub fn disable_account(&self, name: &str) {
464        {
465            let mut data = self.inner.lock();
466            data.accounts.entry(name.to_owned()).or_default().disabled = true;
467        }
468        self.persist();
469    }
470
471    pub fn set_auth_failed(&self, name: &str) {
472        {
473            let mut data = self.inner.lock();
474            let acc = data.accounts.entry(name.to_owned()).or_default();
475            acc.auth_failed = true;
476            acc.disabled = true; // also disable so it's skipped in routing
477        }
478        self.persist();
479    }
480
481    /// Clear auth_failed + disabled for an account after a successful token refresh.
482    pub fn clear_auth_failed(&self, name: &str) {
483        {
484            let mut data = self.inner.lock();
485            if let Some(acc) = data.accounts.get_mut(name) {
486                acc.auth_failed = false;
487                acc.disabled = false;
488            }
489        }
490        self.persist();
491    }
492
493    /// Returns names of accounts (from the given list) that have auth_failed set.
494    pub fn auth_failed_accounts<'a>(&self, names: &[&'a str]) -> Vec<&'a str> {
495        let data = self.inner.lock();
496        names.iter()
497            .filter(|&&n| data.accounts.get(n).map(|s| s.auth_failed).unwrap_or(false))
498            .copied()
499            .collect()
500    }
501
502    // -----------------------------------------------------------------------
503    // Health check state
504    // -----------------------------------------------------------------------
505
506    pub fn is_health_check_failed(&self, name: &str) -> bool {
507        let data = self.inner.lock();
508        data.accounts.get(name).map(|s| s.health_check_failed).unwrap_or(false)
509    }
510
511    pub fn set_health_check_failed(&self, name: &str) {
512        {
513            let mut data = self.inner.lock();
514            let acc = data.accounts.entry(name.to_owned()).or_default();
515            acc.health_check_failed = true;
516        }
517        self.persist();
518    }
519
520    pub fn clear_health_check_failed(&self, name: &str) {
521        {
522            let mut data = self.inner.lock();
523            if let Some(acc) = data.accounts.get_mut(name) {
524                acc.health_check_failed = false;
525                acc.health_check_failures = 0;
526            }
527        }
528        self.persist();
529    }
530
531    /// Increment consecutive failure count and return the new value.
532    /// Sets `health_check_failed = true` once failures >= `threshold`.
533    pub fn record_health_check_failure(&self, name: &str, threshold: u32) -> u32 {
534        let count;
535        {
536            let mut data = self.inner.lock();
537            let acc = data.accounts.entry(name.to_owned()).or_default();
538            acc.health_check_failures = acc.health_check_failures.saturating_add(1);
539            count = acc.health_check_failures;
540            if count >= threshold {
541                acc.health_check_failed = true;
542            }
543        }
544        if count >= threshold {
545            self.persist();
546        }
547        count
548    }
549
550    /// Update last_health_check_ms to now. Returns the previous value.
551    pub fn update_last_health_check(&self, name: &str) -> u64 {
552        let mut data = self.inner.lock();
553        let acc = data.accounts.entry(name.to_owned()).or_default();
554        let prev = acc.last_health_check_ms;
555        acc.last_health_check_ms = now_ms();
556        prev
557    }
558
559    /// Get the last health check timestamp and consecutive failure count.
560    pub fn health_check_info(&self, name: &str) -> (u64, u32) {
561        let data = self.inner.lock();
562        match data.accounts.get(name) {
563            Some(acc) => (acc.last_health_check_ms, acc.health_check_failures),
564            None => (0, 0),
565        }
566    }
567
568    // -----------------------------------------------------------------------
569    // Stickiness (ephemeral — not persisted)
570    // -----------------------------------------------------------------------
571
572    pub fn get_sticky(&self, fingerprint: &str) -> Option<String> {
573        let data = self.inner.lock();
574        let entry = data.sticky.get(fingerprint)?;
575        if now_ms() < entry.expires_at_ms {
576            Some(entry.account_name.clone())
577        } else {
578            None
579        }
580    }
581
582    pub fn set_sticky(&self, fingerprint: &str, account_name: &str, ttl_ms: u64) {
583        const MAX_STICKY_ENTRIES: usize = 10_000;
584        {
585            let mut data = self.inner.lock();
586            // Prune expired entries if approaching limit
587            if data.sticky.len() >= MAX_STICKY_ENTRIES {
588                let now = now_ms();
589                data.sticky.retain(|_, v| v.expires_at_ms > now);
590                // If still at limit after pruning, clear oldest half to prevent DoS
591                if data.sticky.len() >= MAX_STICKY_ENTRIES {
592                    data.sticky.clear();
593                }
594            }
595            data.sticky.insert(
596                fingerprint.to_owned(),
597                StickyEntry { account_name: account_name.to_owned(), expires_at_ms: now_ms() + ttl_ms },
598            );
599        }
600        self.persist();
601    }
602
603    // -----------------------------------------------------------------------
604    // Quota tracking
605    // -----------------------------------------------------------------------
606
607    /// Unix epoch seconds when this account's 5h window resets.
608    /// Returns None if unknown or already past.
609    pub fn reset_5h_secs(&self, name: &str) -> Option<u64> {
610        let now_secs = SystemTime::now()
611            .duration_since(UNIX_EPOCH)
612            .unwrap_or_default()
613            .as_secs();
614        let data = self.inner.lock();
615        let reset = data.rate_limits.get(name)?.reset_5h?;
616        if reset > now_secs { Some(reset) } else { None }
617    }
618
619    /// 5-hour utilization 0.0–1.0 from the last upstream response headers.
620    /// Returns 0.0 for fresh accounts or when the reset window has already passed.
621    pub fn utilization_5h(&self, name: &str) -> f64 {
622        let now_secs = SystemTime::now()
623            .duration_since(UNIX_EPOCH)
624            .unwrap_or_default()
625            .as_secs();
626        let data = self.inner.lock();
627        let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
628        // If the reset time is in the past, the window has rolled over — treat as fresh
629        if rl.reset_5h.map(|t| t <= now_secs).unwrap_or(false) {
630            return 0.0;
631        }
632        rl.utilization_5h.unwrap_or(0.0)
633    }
634
635    /// 7-day utilization 0.0–1.0 from the last upstream response headers.
636    /// Returns 0.0 for fresh accounts or when the reset window has already passed.
637    pub fn utilization_7d(&self, name: &str) -> f64 {
638        let now_secs = SystemTime::now()
639            .duration_since(UNIX_EPOCH)
640            .unwrap_or_default()
641            .as_secs();
642        let data = self.inner.lock();
643        let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
644        if rl.reset_7d.map(|t| t <= now_secs).unwrap_or(false) {
645            return 0.0;
646        }
647        rl.utilization_7d.unwrap_or(0.0)
648    }
649
650    /// Unix epoch seconds when this account's 7d window resets.
651    /// Returns None if unknown or already past.
652    pub fn reset_7d_secs(&self, name: &str) -> Option<u64> {
653        let now_secs = SystemTime::now()
654            .duration_since(UNIX_EPOCH)
655            .unwrap_or_default()
656            .as_secs();
657        let data = self.inner.lock();
658        let reset = data.rate_limits.get(name)?.reset_7d?;
659        if reset > now_secs { Some(reset) } else { None }
660    }
661
662    /// Record token usage from a completed request.
663    /// Lazily resets the window if the 5-hour period has elapsed.
664    pub fn record_usage(&self, name: &str, input_tokens: u64, output_tokens: u64) {
665        if input_tokens == 0 && output_tokens == 0 {
666            return;
667        }
668        {
669            let mut data = self.inner.lock();
670            let quota = data.quota.entry(name.to_owned()).or_default();
671            let now = now_ms();
672            if quota.window_start_ms == 0 || now >= quota.window_start_ms + WINDOW_MS {
673                quota.window_start_ms = now;
674                quota.input_tokens = 0;
675                quota.output_tokens = 0;
676            }
677            quota.input_tokens = quota.input_tokens.saturating_add(input_tokens);
678            quota.output_tokens = quota.output_tokens.saturating_add(output_tokens);
679        }
680        self.persist();
681    }
682
    /// Snapshot of all quota windows for the status endpoint.
    ///
    /// The map is cloned under the lock. Windows are returned as stored:
    /// an elapsed window is not reset here (that happens lazily in `record_usage`).
    pub fn quota_snapshot(&self) -> HashMap<String, QuotaWindow> {
        self.inner.lock().quota.clone()
    }
687
688    // -----------------------------------------------------------------------
689    // Rate limit header tracking
690    // -----------------------------------------------------------------------
691
692    pub fn update_rate_limits(&self, name: &str, info: RateLimitInfo) {
693        let prev = self.inner.lock().rate_limits.get(name).cloned();
694
695        // Warn the first time utilization crosses 90% for each window.
696        let prev_5h = prev.as_ref().and_then(|p| p.utilization_5h).unwrap_or(0.0);
697        let prev_7d = prev.as_ref().and_then(|p| p.utilization_7d).unwrap_or(0.0);
698        if let Some(u) = info.utilization_5h {
699            if u >= 0.9 && prev_5h < 0.9 {
700                warn!(account = %name, utilization = %format!("{:.0}%", u * 100.0),
701                    "5h rate limit above 90% — approaching quota");
702            }
703        }
704        if let Some(u) = info.utilization_7d {
705            if u >= 0.9 && prev_7d < 0.9 {
706                warn!(account = %name, utilization = %format!("{:.0}%", u * 100.0),
707                    "7d rate limit above 90% — approaching quota");
708            }
709        }
710
711        {
712            let mut data = self.inner.lock();
713            data.rate_limits.insert(name.to_owned(), info);
714        }
715        self.persist();
716    }
717
    /// Returns a clone of the latest rate-limit info for every account, keyed by account name.
    ///
    /// Entries are returned as stored; windows whose reset has passed are not filtered out.
    pub fn rate_limit_snapshot(&self) -> HashMap<String, RateLimitInfo> {
        self.inner.lock().rate_limits.clone()
    }
721
722    // -----------------------------------------------------------------------
723    // Account pinning
724    // -----------------------------------------------------------------------
725
    /// Returns the name of the pinned account, or `None` when no account is pinned.
    pub fn get_pinned(&self) -> Option<String> {
        self.inner.lock().pinned_account.clone()
    }
729
730    pub fn set_pinned(&self, name: Option<String>) {
731        {
732            let mut data = self.inner.lock();
733            data.pinned_account = name;
734        }
735        self.persist();
736    }
737
738    // -----------------------------------------------------------------------
739    // Last-used tracking
740    // -----------------------------------------------------------------------
741
    /// Returns the account most recently recorded via `set_last_used`, if any.
    pub fn get_last_used(&self) -> Option<String> {
        self.inner.lock().last_used_account.clone()
    }
745
746    pub fn set_last_used(&self, name: &str) {
747        {
748            let mut data = self.inner.lock();
749            data.last_used_account = Some(name.to_owned());
750        }
751        self.persist();
752    }
753
754    // -----------------------------------------------------------------------
755    // Model override
756    // -----------------------------------------------------------------------
757
    /// Returns the runtime model override, or `None` when no override is set.
    pub fn get_model_override(&self) -> Option<String> {
        self.inner.lock().model_override.clone()
    }
761
    /// Sets the runtime model override. In-memory only: the field is
    /// `serde(skip)` and no persist is scheduled.
    pub fn set_model_override(&self, model: String) {
        self.inner.lock().model_override = Some(model);
    }
765
    /// Removes the runtime model override.
    pub fn clear_model_override(&self) {
        self.inner.lock().model_override = None;
    }
769
770    // -----------------------------------------------------------------------
771    // Routing strategy override
772    // -----------------------------------------------------------------------
773
    /// Returns the runtime routing-strategy override, or `None` when no override is set.
    pub fn get_routing_strategy(&self) -> Option<RoutingStrategy> {
        self.inner.lock().routing_strategy_override
    }
777
    /// Sets the runtime routing-strategy override. In-memory only: the field
    /// is `serde(skip)` and no persist is scheduled.
    pub fn set_routing_strategy(&self, strategy: RoutingStrategy) {
        self.inner.lock().routing_strategy_override = Some(strategy);
    }
781
    /// Removes the runtime routing-strategy override.
    pub fn clear_routing_strategy(&self) {
        self.inner.lock().routing_strategy_override = None;
    }
785
786    // -----------------------------------------------------------------------
787    // Burst RPM limit override
788    // -----------------------------------------------------------------------
789
790    pub fn get_burst_rpm_limit_override(&self) -> Option<u32> {
791        self.inner.lock().burst_rpm_limit_override
792    }
793
794    pub fn set_burst_rpm_limit_override(&self, limit: u32) {
795        self.inner.lock().burst_rpm_limit_override = Some(limit);
796    }
797
798    pub fn clear_burst_rpm_limit_override(&self) {
799        self.inner.lock().burst_rpm_limit_override = None;
800    }
801
802    // -----------------------------------------------------------------------
803    // Fallback model override
804    // -----------------------------------------------------------------------
805
806    /// Returns `Some(Some("model"))` for explicit override, `Some(None)` for explicitly disabled,
807    /// `None` for "use config/auto".
808    pub fn get_fallback_model_override(&self) -> Option<Option<String>> {
809        self.inner.lock().fallback_model_override.clone()
810    }
811
812    pub fn set_fallback_model_override(&self, model: Option<String>) {
813        self.inner.lock().fallback_model_override = Some(model);
814    }
815
816    pub fn clear_fallback_model_override(&self) {
817        self.inner.lock().fallback_model_override = None;
818    }
819
820    // -----------------------------------------------------------------------
821    // Effort override
822    // -----------------------------------------------------------------------
823
824    pub fn get_effort_override(&self) -> Option<String> {
825        self.inner.lock().effort_override.clone()
826    }
827
828    pub fn set_effort_override(&self, effort: String) {
829        self.inner.lock().effort_override = Some(effort);
830    }
831
832    pub fn clear_effort_override(&self) {
833        self.inner.lock().effort_override = None;
834    }
835
836    // -----------------------------------------------------------------------
837    // Thinking mode override
838    // -----------------------------------------------------------------------
839
840    pub fn get_thinking_override(&self) -> Option<String> {
841        self.inner.lock().thinking_override.clone()
842    }
843
844    pub fn set_thinking_override(&self, mode: String) {
845        self.inner.lock().thinking_override = Some(mode);
846    }
847
848    pub fn clear_thinking_override(&self) {
849        self.inner.lock().thinking_override = None;
850    }
851
852    // -----------------------------------------------------------------------
853    // Alerts mute
854    // -----------------------------------------------------------------------
855
856    pub fn get_alerts_muted(&self) -> bool {
857        self.alerts_muted.load(Ordering::Relaxed)
858    }
859
860    pub fn set_alerts_muted(&self, muted: bool) {
861        self.alerts_muted.store(muted, Ordering::Relaxed);
862    }
863
864    // -----------------------------------------------------------------------
865    // Request log
866    // -----------------------------------------------------------------------
867
868    pub fn record_request(&self, log: RequestLog) {
869        let mut data = self.inner.lock();
870        if data.recent_requests.len() >= MAX_RECENT {
871            data.recent_requests.pop_front();
872        }
873        data.recent_requests.push_back(log);
874    }
875
876    /// Most-recent first snapshot for the monitor / status endpoint.
877    pub fn recent_requests_snapshot(&self) -> Vec<RequestLog> {
878        let data = self.inner.lock();
879        data.recent_requests.iter().rev().cloned().collect()
880    }
881
882    // -----------------------------------------------------------------------
883    // Global savings tracking
884    // -----------------------------------------------------------------------
885
886    /// Record tokens + API cost globally (across all accounts) for the savings display.
887    pub fn record_global(&self, model: &str, input_tokens: u64, output_tokens: u64) {
888        if input_tokens == 0 && output_tokens == 0 {
889            return;
890        }
891        let cost = crate::pricing::api_cost_usd(model, input_tokens, output_tokens);
892        let key = today_key();
893        {
894            let mut data = self.inner.lock();
895            let bucket = data.global_daily.entry(key).or_default();
896            bucket.input_tokens  = bucket.input_tokens.saturating_add(input_tokens);
897            bucket.output_tokens = bucket.output_tokens.saturating_add(output_tokens);
898            bucket.api_cost_usd  += cost;
899            data.all_time_input  = data.all_time_input.saturating_add(input_tokens);
900            data.all_time_output = data.all_time_output.saturating_add(output_tokens);
901            data.all_time_cost_usd += cost;
902
903            // Prune buckets older than 90 days to prevent unbounded growth.
904            if data.global_daily.len() > 100 {
905                let cutoff = epoch_to_ymd(
906                    SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs()
907                        .saturating_sub(90 * 86400)
908                );
909                data.global_daily.retain(|k, _| k.as_str() >= cutoff.as_str());
910            }
911        }
912        self.persist();
913    }
914
915    /// Snapshot of daily and all-time savings for the status endpoint and CLI.
916    pub fn savings_snapshot(&self) -> SavingsSnapshot {
917        let now_secs = SystemTime::now()
918            .duration_since(UNIX_EPOCH)
919            .unwrap_or_default()
920            .as_secs();
921        let today   = today_key();
922        let week_ago = epoch_to_ymd(now_secs.saturating_sub(7 * 86400));
923
924        let data = self.inner.lock();
925
926        let today_bucket = data.global_daily.get(&today).cloned().unwrap_or_default();
927
928        let (week_input, week_output, week_cost) = data.global_daily.iter()
929            .filter(|(k, _)| k.as_str() >= week_ago.as_str())
930            .fold((0u64, 0u64, 0f64), |(i, o, c), (_, b)| {
931                (i + b.input_tokens, o + b.output_tokens, c + b.api_cost_usd)
932            });
933
934        SavingsSnapshot {
935            today_input:      today_bucket.input_tokens,
936            today_output:     today_bucket.output_tokens,
937            today_cost_usd:   today_bucket.api_cost_usd,
938            week_input,
939            week_output,
940            week_cost_usd:    week_cost,
941            all_time_input:   data.all_time_input,
942            all_time_output:  data.all_time_output,
943            all_time_cost_usd: data.all_time_cost_usd,
944        }
945    }
946
947    // -----------------------------------------------------------------------
948    // Persistence
949    // -----------------------------------------------------------------------
950
951    fn persist(&self) {
952        // Signal the background writer thread; it will flush within ~100 ms.
953        self.pending.store(true, Ordering::Release);
954    }
955}
956
#[cfg(test)]
mod tests {
    use super::*;

    /// Nanoseconds since the Unix epoch, used to build temp-file names that
    /// parallel tests are unlikely to share.
    fn epoch_nanos() -> u128 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos()
    }

    /// Sleeps long enough for the background writer (polls every 100 ms) to flush.
    fn wait_for_flush() {
        std::thread::sleep(std::time::Duration::from_millis(300));
    }

    #[test]
    fn test_sticky_ttl_expiry() {
        let fingerprint = "conv-fp-ttl";
        let store = StateStore::new_empty();
        store.set_sticky(fingerprint, "account1", 500); // TTL of 500 ms

        let before_expiry = store.get_sticky(fingerprint);
        assert_eq!(before_expiry.as_deref(), Some("account1"),
            "sticky should be available immediately");

        std::thread::sleep(std::time::Duration::from_millis(600));

        let after_expiry = store.get_sticky(fingerprint);
        assert!(after_expiry.is_none(),
            "sticky must expire after TTL elapses");
    }

    #[test]
    fn test_cooldown_blocks_availability() {
        let store = StateStore::new_empty();
        store.set_cooldown("acc", 5_000); // cooldown of 5 s
        let available = store.is_available("acc");
        assert!(!available, "account should be unavailable during cooldown");
    }

    #[test]
    fn test_disable_blocks_availability() {
        let store = StateStore::new_empty();
        store.disable_account("acc");
        let available = store.is_available("acc");
        assert!(!available, "disabled account must be unavailable");
    }

    #[test]
    fn test_quota_accumulates() {
        let store = StateStore::new_empty();
        store.record_usage("acc", 100, 50);
        store.record_usage("acc", 200, 75);

        let quotas = store.quota_snapshot();
        let quota = &quotas["acc"];
        assert_eq!(quota.input_tokens, 300);
        assert_eq!(quota.output_tokens, 125);
        assert_eq!(quota.total_tokens(), 425);
    }

    #[test]
    fn test_pinned_account_round_trip() {
        let store = StateStore::new_empty();
        assert!(store.get_pinned().is_none());

        store.set_pinned(Some(String::from("myaccount")));
        let pinned = store.get_pinned();
        assert_eq!(pinned.as_deref(), Some("myaccount"));

        store.set_pinned(None);
        assert!(store.get_pinned().is_none());
    }

    #[test]
    fn test_last_used_round_trip() {
        let store = StateStore::new_empty();
        assert!(store.get_last_used().is_none());

        store.set_last_used("acc1");
        let last_used = store.get_last_used();
        assert_eq!(last_used.as_deref(), Some("acc1"));
    }

    #[test]
    fn test_recent_requests_ring_buffer() {
        let store = StateStore::new_empty();
        // Push more entries than the buffer is allowed to hold.
        (0..=(MAX_RECENT + 5)).for_each(|seq| {
            let entry = RequestLog {
                ts_ms: seq as u64,
                account: "acc".into(),
                model: "m".into(),
                status: 200,
                input_tokens: 1,
                output_tokens: 1,
                duration_ms: 1,
            };
            store.record_request(entry);
        });

        let snap = store.recent_requests_snapshot();
        assert_eq!(snap.len(), MAX_RECENT, "buffer must not grow beyond MAX_RECENT");

        // Index 0 must hold the newest entry.
        let newest = &snap[0];
        let oldest = &snap[snap.len() - 1];
        assert!(newest.ts_ms > oldest.ts_ms, "snapshot must be newest-first");
    }

    #[test]
    fn test_health_check_failed_round_trip() {
        let store = StateStore::new_empty();
        let initially_failed = store.is_health_check_failed("acc");
        assert!(!initially_failed, "fresh account should not be health-check-failed");

        store.set_health_check_failed("acc");
        let failed_after_set = store.is_health_check_failed("acc");
        assert!(failed_after_set, "should be marked unhealthy after set");

        store.clear_health_check_failed("acc");
        let failed_after_clear = store.is_health_check_failed("acc");
        assert!(!failed_after_clear, "should be cleared after clear");
    }

    #[test]
    fn test_health_check_failure_threshold() {
        let store = StateStore::new_empty();

        // One failure against a threshold of two: counted, but not yet marked.
        let first = store.record_health_check_failure("acc", 2);
        assert_eq!(first, 1);
        assert!(!store.is_health_check_failed("acc"),
            "should not be marked after 1 failure (threshold=2)");

        // The second failure reaches the threshold, so the account is marked.
        let second = store.record_health_check_failure("acc", 2);
        assert_eq!(second, 2);
        assert!(store.is_health_check_failed("acc"),
            "should be marked after 2 failures (threshold=2)");
    }

    #[test]
    fn test_clear_health_check_resets_failure_count() {
        let store = StateStore::new_empty();
        store.record_health_check_failure("acc", 2);
        store.record_health_check_failure("acc", 2);
        assert!(store.is_health_check_failed("acc"));

        store.clear_health_check_failed("acc");
        assert!(!store.is_health_check_failed("acc"));

        let (_, failure_count) = store.health_check_info("acc");
        assert_eq!(failure_count, 0, "failure count must reset to 0 after clear");
    }

    #[test]
    fn test_health_check_info_and_last_check() {
        let store = StateStore::new_empty();

        let (initial_last, initial_failures) = store.health_check_info("acc");
        assert_eq!(initial_last, 0, "fresh account last_health_check_ms should be 0");
        assert_eq!(initial_failures, 0);

        let previous = store.update_last_health_check("acc");
        assert_eq!(previous, 0, "first update should return previous value 0");

        let (updated_last, _) = store.health_check_info("acc");
        assert!(updated_last > 0, "last_health_check_ms should be updated to now");
    }

    #[test]
    fn test_health_check_failed_persists() {
        let file_name = format!("shunt_test_hc_{}.json", epoch_nanos());
        let path = std::env::temp_dir().join(file_name);

        {
            let store = StateStore::load(&path);
            store.set_health_check_failed("acc");
            wait_for_flush();
        }

        let reloaded = StateStore::load(&path);
        assert!(reloaded.is_health_check_failed("acc"),
            "health_check_failed must survive restart");

        // The ephemeral counters must come back zeroed after a reload.
        let (last, failures) = reloaded.health_check_info("acc");
        assert_eq!(last, 0, "last_health_check_ms is ephemeral, should be 0 after reload");
        assert_eq!(failures, 0, "health_check_failures is ephemeral, should be 0 after reload");

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_state_persistence_roundtrip() {
        // A per-run file name keeps parallel tests from colliding.
        let file_name = format!("shunt_test_state_{}.json", epoch_nanos());
        let path = std::env::temp_dir().join(file_name);

        {
            let store = StateStore::load(&path);
            store.set_cooldown("acc", 999_999_000); // cooldown far in the future
            store.record_usage("acc", 111, 222);
            store.set_last_used("acc");
            wait_for_flush();
        }

        // A second store loaded from the same file must see the saved state.
        let reloaded = StateStore::load(&path);
        assert!(!reloaded.is_available("acc"), "cooldown must survive restart");

        let quotas = reloaded.quota_snapshot();
        assert_eq!(quotas["acc"].input_tokens, 111, "quota must survive restart");
        assert_eq!(quotas["acc"].output_tokens, 222);

        let last_used = reloaded.get_last_used();
        assert_eq!(last_used.as_deref(), Some("acc"),
            "last_used_account must survive restart");

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_burst_window_tracking() {
        let store = StateStore::new_empty();

        // Five burst records before any account state exists.
        (0..5).for_each(|_| {
            store.record_request_burst("acc");
        });

        // The account is either missing from the snapshot or reports zero.
        let before = store.routing_snapshot();
        let absent_or_zero = before
            .accounts
            .get("acc")
            .map_or(true, |entry| entry.burst_request_count == 0);
        assert!(absent_or_zero,
            "no account state yet, burst tracked separately");

        // Create the AccountState entry so the account shows up in snapshots.
        store.set_cooldown("acc", 0);
        (0..3).for_each(|_| {
            store.record_request_burst("acc");
        });

        let after = store.routing_snapshot();
        let entry = after.accounts.get("acc").expect("acc should exist in snapshot");
        // All 8 records (5 + 3) fall inside the 60 s window.
        assert_eq!(entry.burst_request_count, 8, "should count all recent requests");
    }
}
1178
1179/// "YYYY-MM-DD" string for today in UTC.
1180fn today_key() -> String {
1181    let secs = SystemTime::now()
1182        .duration_since(UNIX_EPOCH)
1183        .unwrap_or_default()
1184        .as_secs();
1185    epoch_to_ymd(secs)
1186}
1187
/// Formats Unix epoch seconds as a "YYYY-MM-DD" UTC date, using Hinnant's
/// `civil_from_days` algorithm.
fn epoch_to_ymd(secs: u64) -> String {
    // Shift the day count so the era starts on 0000-03-01. `secs` is unsigned,
    // so the shifted value is never negative and plain truncating division
    // and remainder give the floor results the algorithm requires.
    let shifted_days = (secs / 86400) as i64 + 719_468;
    let era = shifted_days / 146_097;
    let day_of_era = shifted_days % 146_097;

    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);

    // Months are counted from March (0) through February (11).
    let shifted_month = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * shifted_month + 2) / 5 + 1;

    // January and February belong to the following civil year.
    let (month, year_bump) = if shifted_month < 10 {
        (shifted_month + 3, 0)
    } else {
        (shifted_month - 9, 1)
    };
    let year = year_of_era + era * 400 + year_bump;

    format!("{year:04}-{month:02}-{day:02}")
}
1203
1204fn write_to_disk(data: &StateData, path: &Path) -> Result<()> {
1205    if let Some(parent) = path.parent() {
1206        std::fs::create_dir_all(parent)?;
1207    }
1208    let tmp = path.with_extension("tmp");
1209    std::fs::write(&tmp, serde_json::to_string_pretty(data)?)?;
1210    #[cfg(unix)]
1211    {
1212        use std::os::unix::fs::PermissionsExt;
1213        let _ = std::fs::set_permissions(&tmp, std::fs::Permissions::from_mode(0o600));
1214    }
1215    std::fs::rename(&tmp, path)?;
1216    Ok(())
1217}