Skip to main content

shunt/
state.rs

//! Runtime state: per-account cooldowns/disabling + conversation stickiness.
//!
//! Thread-safe via `Arc<Mutex<_>>`. Cooldowns and disables are persisted to disk.
//! Sticky entries are serialized alongside them but are best-effort: expired ones
//! are pruned on load, and losing them on restart is acceptable.
5use crate::config::RoutingStrategy;
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
11use parking_lot::Mutex;
12use std::sync::Arc;
13use std::time::{SystemTime, UNIX_EPOCH};
14use tracing::warn;
15
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields 0 if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis() as u64
}
22
23/// Public version of `now_ms()` for use from other modules.
24pub fn now_ms_pub() -> u64 {
25    now_ms()
26}
27
28// ---------------------------------------------------------------------------
29// On-disk data
30// ---------------------------------------------------------------------------
31
32#[derive(Debug, Serialize, Deserialize, Default, Clone)]
33pub struct AccountState {
34    /// Epoch-ms timestamp after which this account is usable again (0 = not cooling).
35    #[serde(default)]
36    pub cooldown_until_ms: u64,
37    /// Permanently disabled (auth failure).
38    #[serde(default)]
39    pub disabled: bool,
40    /// OAuth credentials are expired and need re-authorization via `shunt add-account`.
41    #[serde(default)]
42    pub auth_failed: bool,
43    /// Account failed health-check probes — skip in routing until it recovers.
44    #[serde(default)]
45    pub health_check_failed: bool,
46    /// Consecutive health-check failure count (for exponential backoff). Ephemeral.
47    #[serde(skip)]
48    pub health_check_failures: u32,
49    /// Epoch-ms of the last health-check probe attempt. Ephemeral.
50    #[serde(skip)]
51    pub last_health_check_ms: u64,
52}
53
54#[derive(Serialize, Deserialize, Default, Clone)]
55struct StickyEntry {
56    account_name: String,
57    expires_at_ms: u64,
58}
59
60/// Rolling 5-hour quota window per account.
61#[derive(Debug, Serialize, Deserialize, Default, Clone)]
62pub struct QuotaWindow {
63    /// Epoch-ms when this window started (0 = never used).
64    #[serde(default)]
65    pub window_start_ms: u64,
66    #[serde(default)]
67    pub input_tokens: u64,
68    #[serde(default)]
69    pub output_tokens: u64,
70}
71
72impl QuotaWindow {
73    pub fn total_tokens(&self) -> u64 {
74        self.input_tokens + self.output_tokens
75    }
76    pub fn window_expires_ms(&self) -> Option<u64> {
77        if self.window_start_ms == 0 { None } else { Some(self.window_start_ms + WINDOW_MS) }
78    }
79}
80
/// Length of one rolling quota window, in milliseconds (5 hours).
pub const WINDOW_MS: u64 = 5 * 60 * 60 * 1_000;
82
83// ---------------------------------------------------------------------------
84// Request log
85// ---------------------------------------------------------------------------
86
87/// A single proxied request recorded for the live monitor.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct RequestLog {
90    pub ts_ms: u64,
91    pub account: String,
92    pub model: String,
93    pub status: u16,
94    pub input_tokens: u64,
95    pub output_tokens: u64,
96    pub duration_ms: u64,
97}
98
/// Upper bound on the number of entries kept in the in-memory recent-request log.
const MAX_RECENT: usize = 200;
100
101/// Rate-limit info extracted from `anthropic-ratelimit-unified-*` response headers.
102#[derive(Debug, Serialize, Deserialize, Default, Clone)]
103pub struct RateLimitInfo {
104    /// 5-hour window utilization 0.0–1.0
105    pub utilization_5h: Option<f64>,
106    /// Unix epoch seconds when 5h window resets
107    pub reset_5h: Option<u64>,
108    /// "allowed" | "exhausted"
109    pub status_5h: Option<String>,
110    /// 7-day window utilization 0.0–1.0
111    pub utilization_7d: Option<f64>,
112    /// Unix epoch seconds when 7d window resets
113    pub reset_7d: Option<u64>,
114    pub status_7d: Option<String>,
115    /// Extra usage (overage) status: "allowed" | "rejected"
116    pub overage_status: Option<String>,
117    pub overage_disabled_reason: Option<String>,
118    /// Which claim is currently representative ("five_hour" | "seven_day")
119    pub representative_claim: Option<String>,
120    pub updated_ms: u64,
121}
122
123/// Per-day token and API-cost accumulator (all accounts combined).
124#[derive(Debug, Serialize, Deserialize, Default, Clone)]
125pub struct DailyBucket {
126    pub input_tokens: u64,
127    pub output_tokens: u64,
128    /// What those tokens would have cost on the public API (USD).
129    pub api_cost_usd: f64,
130}
131
132/// Snapshot returned by `savings_snapshot()` for the status endpoint + CLI.
133#[derive(Debug, Serialize, Deserialize, Default, Clone)]
134pub struct SavingsSnapshot {
135    pub today_input: u64,
136    pub today_output: u64,
137    pub today_cost_usd: f64,
138    pub week_input: u64,
139    pub week_output: u64,
140    pub week_cost_usd: f64,
141    pub all_time_input: u64,
142    pub all_time_output: u64,
143    pub all_time_cost_usd: f64,
144}
145
146#[derive(Serialize, Deserialize, Default, Clone)]
147struct StateData {
148    #[serde(default)]
149    accounts: HashMap<String, AccountState>,
150    #[serde(default)]
151    sticky: HashMap<String, StickyEntry>,
152    #[serde(default)]
153    quota: HashMap<String, QuotaWindow>,
154    #[serde(default)]
155    rate_limits: HashMap<String, RateLimitInfo>,
156    /// If set, all requests are forced to this account (overrides routing).
157    #[serde(default)]
158    pinned_account: Option<String>,
159    /// The most recent account that successfully handled a proxied request.
160    #[serde(default)]
161    last_used_account: Option<String>,
162    /// Recent request log (ephemeral — not persisted to disk).
163    #[serde(skip)]
164    recent_requests: VecDeque<RequestLog>,
165    /// Runtime model override — all requests use this model if set (ephemeral).
166    #[serde(skip)]
167    model_override: Option<String>,
168    /// Runtime routing strategy override (ephemeral — not persisted).
169    #[serde(skip)]
170    routing_strategy_override: Option<RoutingStrategy>,
171    /// Daily token + cost buckets keyed by "YYYY-MM-DD" (all accounts combined).
172    #[serde(default)]
173    global_daily: HashMap<String, DailyBucket>,
174    /// All-time totals.
175    #[serde(default)]
176    all_time_input: u64,
177    #[serde(default)]
178    all_time_output: u64,
179    #[serde(default)]
180    all_time_cost_usd: f64,
181}
182
183// ---------------------------------------------------------------------------
184// Store
185// ---------------------------------------------------------------------------
186
187#[derive(Clone)]
188pub struct StateStore {
189    path: PathBuf,
190    inner: Arc<Mutex<StateData>>,
191    /// Set to true when a write is needed; the background writer thread clears it.
192    pending: Arc<AtomicBool>,
193    /// Monotonically-increasing counter for round-robin account selection.
194    round_robin: Arc<AtomicUsize>,
195}
196
197impl StateStore {
198    /// Create a fresh in-memory store with no backing file (useful for tests).
199    pub fn new_empty() -> Self {
200        // No background writer thread for the null store — writes are no-ops.
201        Self {
202            path: PathBuf::from("/dev/null"),
203            inner: Arc::new(Mutex::new(StateData::default())),
204            pending: Arc::new(AtomicBool::new(false)),
205            round_robin: Arc::new(AtomicUsize::new(0)),
206        }
207    }
208
209    pub fn load(path: &Path) -> Self {
210        let mut data: StateData = if path.exists() {
211            match std::fs::read_to_string(path) {
212                Ok(text) => serde_json::from_str(&text).unwrap_or_else(|e| {
213                    warn!("State file unreadable ({e}), starting fresh");
214                    StateData::default()
215                }),
216                Err(e) => {
217                    warn!("Cannot read state file ({e}), starting fresh");
218                    StateData::default()
219                }
220            }
221        } else {
222            StateData::default()
223        };
224        // Prune expired sticky entries so the file doesn't grow unbounded.
225        let now = now_ms();
226        data.sticky.retain(|_, v| v.expires_at_ms > now);
227
228        let store = Self {
229            path: path.to_owned(),
230            inner: Arc::new(Mutex::new(data)),
231            pending: Arc::new(AtomicBool::new(false)),
232            round_robin: Arc::new(AtomicUsize::new(0)),
233        };
234        store.start_writer_thread();
235        store
236    }
237
238    /// Spawn a single background thread that flushes state to disk at most every 100 ms.
239    /// This prevents unbounded thread spawning when many requests fire in rapid succession.
240    fn start_writer_thread(&self) {
241        let pending = Arc::clone(&self.pending);
242        let inner   = Arc::clone(&self.inner);
243        let path    = self.path.clone();
244        std::thread::spawn(move || {
245            loop {
246                std::thread::sleep(std::time::Duration::from_millis(100));
247                if pending.compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed).is_ok() {
248                    let data = inner.lock().clone();
249                    if let Err(e) = write_to_disk(&data, &path) {
250                        warn!("Failed to persist state: {e}");
251                    }
252                }
253            }
254        });
255    }
256
257    // -----------------------------------------------------------------------
258    // Availability
259    // -----------------------------------------------------------------------
260
261    pub fn is_available(&self, name: &str) -> bool {
262        let data = self.inner.lock();
263        match data.accounts.get(name) {
264            None => true,
265            Some(s) => !s.disabled && now_ms() >= s.cooldown_until_ms,
266        }
267    }
268
269    /// Returns true if the account's Anthropic quota is currently exhausted in any
270    /// active window (5h or 7d) — i.e. sending another request will get a 429.
271    pub fn is_exhausted(&self, name: &str) -> bool {
272        let now_secs = SystemTime::now()
273            .duration_since(UNIX_EPOCH)
274            .unwrap_or_default()
275            .as_secs();
276        let data = self.inner.lock();
277        let Some(rl) = data.rate_limits.get(name) else { return false };
278        // Only consider a window exhausted if its reset is still in the future
279        // (i.e. the window hasn't rolled over yet).
280        let exhausted_5h = rl.status_5h.as_deref() == Some("exhausted")
281            && rl.reset_5h.map(|t| t > now_secs).unwrap_or(false);
282        let exhausted_7d = rl.status_7d.as_deref() == Some("exhausted")
283            && rl.reset_7d.map(|t| t > now_secs).unwrap_or(false);
284        exhausted_5h || exhausted_7d
285    }
286
287    /// Fetch-and-increment monotonic counter for round-robin account cycling.
288    pub fn next_rr_index(&self) -> usize {
289        self.round_robin.fetch_add(1, Ordering::Relaxed)
290    }
291
292    /// Returns a snapshot of all account states for the status endpoint.
293    pub fn account_states(&self) -> HashMap<String, AccountState> {
294        self.inner.lock().accounts.clone()
295    }
296
297    // -----------------------------------------------------------------------
298    // Cooldown / disable
299    // -----------------------------------------------------------------------
300
301    pub fn set_cooldown(&self, name: &str, duration_ms: u64) {
302        {
303            let mut data = self.inner.lock();
304            let acc = data.accounts.entry(name.to_owned()).or_default();
305            acc.cooldown_until_ms = now_ms() + duration_ms;
306        }
307        self.persist();
308    }
309
310    pub fn disable_account(&self, name: &str) {
311        {
312            let mut data = self.inner.lock();
313            data.accounts.entry(name.to_owned()).or_default().disabled = true;
314        }
315        self.persist();
316    }
317
318    pub fn set_auth_failed(&self, name: &str) {
319        {
320            let mut data = self.inner.lock();
321            let acc = data.accounts.entry(name.to_owned()).or_default();
322            acc.auth_failed = true;
323            acc.disabled = true; // also disable so it's skipped in routing
324        }
325        self.persist();
326    }
327
328    /// Clear auth_failed + disabled for an account after a successful token refresh.
329    pub fn clear_auth_failed(&self, name: &str) {
330        {
331            let mut data = self.inner.lock();
332            if let Some(acc) = data.accounts.get_mut(name) {
333                acc.auth_failed = false;
334                acc.disabled = false;
335            }
336        }
337        self.persist();
338    }
339
340    /// Returns names of accounts (from the given list) that have auth_failed set.
341    pub fn auth_failed_accounts<'a>(&self, names: &[&'a str]) -> Vec<&'a str> {
342        let data = self.inner.lock();
343        names.iter()
344            .filter(|&&n| data.accounts.get(n).map(|s| s.auth_failed).unwrap_or(false))
345            .copied()
346            .collect()
347    }
348
349    // -----------------------------------------------------------------------
350    // Health check state
351    // -----------------------------------------------------------------------
352
353    pub fn is_health_check_failed(&self, name: &str) -> bool {
354        let data = self.inner.lock();
355        data.accounts.get(name).map(|s| s.health_check_failed).unwrap_or(false)
356    }
357
358    pub fn set_health_check_failed(&self, name: &str) {
359        {
360            let mut data = self.inner.lock();
361            let acc = data.accounts.entry(name.to_owned()).or_default();
362            acc.health_check_failed = true;
363        }
364        self.persist();
365    }
366
367    pub fn clear_health_check_failed(&self, name: &str) {
368        {
369            let mut data = self.inner.lock();
370            if let Some(acc) = data.accounts.get_mut(name) {
371                acc.health_check_failed = false;
372                acc.health_check_failures = 0;
373            }
374        }
375        self.persist();
376    }
377
378    /// Increment consecutive failure count and return the new value.
379    /// Sets `health_check_failed = true` once failures >= `threshold`.
380    pub fn record_health_check_failure(&self, name: &str, threshold: u32) -> u32 {
381        let count;
382        {
383            let mut data = self.inner.lock();
384            let acc = data.accounts.entry(name.to_owned()).or_default();
385            acc.health_check_failures = acc.health_check_failures.saturating_add(1);
386            count = acc.health_check_failures;
387            if count >= threshold {
388                acc.health_check_failed = true;
389            }
390        }
391        if count >= threshold {
392            self.persist();
393        }
394        count
395    }
396
397    /// Update last_health_check_ms to now. Returns the previous value.
398    pub fn update_last_health_check(&self, name: &str) -> u64 {
399        let mut data = self.inner.lock();
400        let acc = data.accounts.entry(name.to_owned()).or_default();
401        let prev = acc.last_health_check_ms;
402        acc.last_health_check_ms = now_ms();
403        prev
404    }
405
406    /// Get the last health check timestamp and consecutive failure count.
407    pub fn health_check_info(&self, name: &str) -> (u64, u32) {
408        let data = self.inner.lock();
409        match data.accounts.get(name) {
410            Some(acc) => (acc.last_health_check_ms, acc.health_check_failures),
411            None => (0, 0),
412        }
413    }
414
415    // -----------------------------------------------------------------------
416    // Stickiness (ephemeral — not persisted)
417    // -----------------------------------------------------------------------
418
419    pub fn get_sticky(&self, fingerprint: &str) -> Option<String> {
420        let data = self.inner.lock();
421        let entry = data.sticky.get(fingerprint)?;
422        if now_ms() < entry.expires_at_ms {
423            Some(entry.account_name.clone())
424        } else {
425            None
426        }
427    }
428
429    pub fn set_sticky(&self, fingerprint: &str, account_name: &str, ttl_ms: u64) {
430        const MAX_STICKY_ENTRIES: usize = 10_000;
431        {
432            let mut data = self.inner.lock();
433            // Prune expired entries if approaching limit
434            if data.sticky.len() >= MAX_STICKY_ENTRIES {
435                let now = now_ms();
436                data.sticky.retain(|_, v| v.expires_at_ms > now);
437                // If still at limit after pruning, clear oldest half to prevent DoS
438                if data.sticky.len() >= MAX_STICKY_ENTRIES {
439                    data.sticky.clear();
440                }
441            }
442            data.sticky.insert(
443                fingerprint.to_owned(),
444                StickyEntry { account_name: account_name.to_owned(), expires_at_ms: now_ms() + ttl_ms },
445            );
446        }
447        self.persist();
448    }
449
450    // -----------------------------------------------------------------------
451    // Quota tracking
452    // -----------------------------------------------------------------------
453
454    /// Unix epoch seconds when this account's 5h window resets.
455    /// Returns None if unknown or already past.
456    pub fn reset_5h_secs(&self, name: &str) -> Option<u64> {
457        let now_secs = SystemTime::now()
458            .duration_since(UNIX_EPOCH)
459            .unwrap_or_default()
460            .as_secs();
461        let data = self.inner.lock();
462        let reset = data.rate_limits.get(name)?.reset_5h?;
463        if reset > now_secs { Some(reset) } else { None }
464    }
465
466    /// 5-hour utilization 0.0–1.0 from the last upstream response headers.
467    /// Returns 0.0 for fresh accounts or when the reset window has already passed.
468    pub fn utilization_5h(&self, name: &str) -> f64 {
469        let now_secs = SystemTime::now()
470            .duration_since(UNIX_EPOCH)
471            .unwrap_or_default()
472            .as_secs();
473        let data = self.inner.lock();
474        let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
475        // If the reset time is in the past, the window has rolled over — treat as fresh
476        if rl.reset_5h.map(|t| t <= now_secs).unwrap_or(false) {
477            return 0.0;
478        }
479        rl.utilization_5h.unwrap_or(0.0)
480    }
481
482    /// 7-day utilization 0.0–1.0 from the last upstream response headers.
483    /// Returns 0.0 for fresh accounts or when the reset window has already passed.
484    pub fn utilization_7d(&self, name: &str) -> f64 {
485        let now_secs = SystemTime::now()
486            .duration_since(UNIX_EPOCH)
487            .unwrap_or_default()
488            .as_secs();
489        let data = self.inner.lock();
490        let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
491        if rl.reset_7d.map(|t| t <= now_secs).unwrap_or(false) {
492            return 0.0;
493        }
494        rl.utilization_7d.unwrap_or(0.0)
495    }
496
497    /// Unix epoch seconds when this account's 7d window resets.
498    /// Returns None if unknown or already past.
499    pub fn reset_7d_secs(&self, name: &str) -> Option<u64> {
500        let now_secs = SystemTime::now()
501            .duration_since(UNIX_EPOCH)
502            .unwrap_or_default()
503            .as_secs();
504        let data = self.inner.lock();
505        let reset = data.rate_limits.get(name)?.reset_7d?;
506        if reset > now_secs { Some(reset) } else { None }
507    }
508
509    /// Record token usage from a completed request.
510    /// Lazily resets the window if the 5-hour period has elapsed.
511    pub fn record_usage(&self, name: &str, input_tokens: u64, output_tokens: u64) {
512        if input_tokens == 0 && output_tokens == 0 {
513            return;
514        }
515        {
516            let mut data = self.inner.lock();
517            let quota = data.quota.entry(name.to_owned()).or_default();
518            let now = now_ms();
519            if quota.window_start_ms == 0 || now >= quota.window_start_ms + WINDOW_MS {
520                quota.window_start_ms = now;
521                quota.input_tokens = 0;
522                quota.output_tokens = 0;
523            }
524            quota.input_tokens = quota.input_tokens.saturating_add(input_tokens);
525            quota.output_tokens = quota.output_tokens.saturating_add(output_tokens);
526        }
527        self.persist();
528    }
529
530    /// Snapshot of all quota windows for the status endpoint.
531    pub fn quota_snapshot(&self) -> HashMap<String, QuotaWindow> {
532        self.inner.lock().quota.clone()
533    }
534
535    // -----------------------------------------------------------------------
536    // Rate limit header tracking
537    // -----------------------------------------------------------------------
538
539    pub fn update_rate_limits(&self, name: &str, info: RateLimitInfo) {
540        let prev = self.inner.lock().rate_limits.get(name).cloned();
541
542        // Warn the first time utilization crosses 90% for each window.
543        let prev_5h = prev.as_ref().and_then(|p| p.utilization_5h).unwrap_or(0.0);
544        let prev_7d = prev.as_ref().and_then(|p| p.utilization_7d).unwrap_or(0.0);
545        if let Some(u) = info.utilization_5h {
546            if u >= 0.9 && prev_5h < 0.9 {
547                warn!(account = %name, utilization = %format!("{:.0}%", u * 100.0),
548                    "5h rate limit above 90% — approaching quota");
549            }
550        }
551        if let Some(u) = info.utilization_7d {
552            if u >= 0.9 && prev_7d < 0.9 {
553                warn!(account = %name, utilization = %format!("{:.0}%", u * 100.0),
554                    "7d rate limit above 90% — approaching quota");
555            }
556        }
557
558        {
559            let mut data = self.inner.lock();
560            data.rate_limits.insert(name.to_owned(), info);
561        }
562        self.persist();
563    }
564
565    pub fn rate_limit_snapshot(&self) -> HashMap<String, RateLimitInfo> {
566        self.inner.lock().rate_limits.clone()
567    }
568
569    // -----------------------------------------------------------------------
570    // Account pinning
571    // -----------------------------------------------------------------------
572
573    pub fn get_pinned(&self) -> Option<String> {
574        self.inner.lock().pinned_account.clone()
575    }
576
577    pub fn set_pinned(&self, name: Option<String>) {
578        {
579            let mut data = self.inner.lock();
580            data.pinned_account = name;
581        }
582        self.persist();
583    }
584
585    // -----------------------------------------------------------------------
586    // Last-used tracking
587    // -----------------------------------------------------------------------
588
589    pub fn get_last_used(&self) -> Option<String> {
590        self.inner.lock().last_used_account.clone()
591    }
592
593    pub fn set_last_used(&self, name: &str) {
594        {
595            let mut data = self.inner.lock();
596            data.last_used_account = Some(name.to_owned());
597        }
598        self.persist();
599    }
600
601    // -----------------------------------------------------------------------
602    // Model override
603    // -----------------------------------------------------------------------
604
605    pub fn get_model_override(&self) -> Option<String> {
606        self.inner.lock().model_override.clone()
607    }
608
609    pub fn set_model_override(&self, model: String) {
610        self.inner.lock().model_override = Some(model);
611    }
612
613    pub fn clear_model_override(&self) {
614        self.inner.lock().model_override = None;
615    }
616
617    // -----------------------------------------------------------------------
618    // Routing strategy override
619    // -----------------------------------------------------------------------
620
621    pub fn get_routing_strategy(&self) -> Option<RoutingStrategy> {
622        self.inner.lock().routing_strategy_override
623    }
624
625    pub fn set_routing_strategy(&self, strategy: RoutingStrategy) {
626        self.inner.lock().routing_strategy_override = Some(strategy);
627    }
628
629    pub fn clear_routing_strategy(&self) {
630        self.inner.lock().routing_strategy_override = None;
631    }
632
633    // -----------------------------------------------------------------------
634    // Request log
635    // -----------------------------------------------------------------------
636
637    pub fn record_request(&self, log: RequestLog) {
638        let mut data = self.inner.lock();
639        if data.recent_requests.len() >= MAX_RECENT {
640            data.recent_requests.pop_front();
641        }
642        data.recent_requests.push_back(log);
643    }
644
645    /// Most-recent first snapshot for the monitor / status endpoint.
646    pub fn recent_requests_snapshot(&self) -> Vec<RequestLog> {
647        let data = self.inner.lock();
648        data.recent_requests.iter().rev().cloned().collect()
649    }
650
651    // -----------------------------------------------------------------------
652    // Global savings tracking
653    // -----------------------------------------------------------------------
654
655    /// Record tokens + API cost globally (across all accounts) for the savings display.
656    pub fn record_global(&self, model: &str, input_tokens: u64, output_tokens: u64) {
657        if input_tokens == 0 && output_tokens == 0 {
658            return;
659        }
660        let cost = crate::pricing::api_cost_usd(model, input_tokens, output_tokens);
661        let key = today_key();
662        {
663            let mut data = self.inner.lock();
664            let bucket = data.global_daily.entry(key).or_default();
665            bucket.input_tokens  = bucket.input_tokens.saturating_add(input_tokens);
666            bucket.output_tokens = bucket.output_tokens.saturating_add(output_tokens);
667            bucket.api_cost_usd  += cost;
668            data.all_time_input  = data.all_time_input.saturating_add(input_tokens);
669            data.all_time_output = data.all_time_output.saturating_add(output_tokens);
670            data.all_time_cost_usd += cost;
671
672            // Prune buckets older than 90 days to prevent unbounded growth.
673            if data.global_daily.len() > 100 {
674                let cutoff = epoch_to_ymd(
675                    SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs()
676                        .saturating_sub(90 * 86400)
677                );
678                data.global_daily.retain(|k, _| k.as_str() >= cutoff.as_str());
679            }
680        }
681        self.persist();
682    }
683
684    /// Snapshot of daily and all-time savings for the status endpoint and CLI.
685    pub fn savings_snapshot(&self) -> SavingsSnapshot {
686        let now_secs = SystemTime::now()
687            .duration_since(UNIX_EPOCH)
688            .unwrap_or_default()
689            .as_secs();
690        let today   = today_key();
691        let week_ago = epoch_to_ymd(now_secs.saturating_sub(7 * 86400));
692
693        let data = self.inner.lock();
694
695        let today_bucket = data.global_daily.get(&today).cloned().unwrap_or_default();
696
697        let (week_input, week_output, week_cost) = data.global_daily.iter()
698            .filter(|(k, _)| k.as_str() >= week_ago.as_str())
699            .fold((0u64, 0u64, 0f64), |(i, o, c), (_, b)| {
700                (i + b.input_tokens, o + b.output_tokens, c + b.api_cost_usd)
701            });
702
703        SavingsSnapshot {
704            today_input:      today_bucket.input_tokens,
705            today_output:     today_bucket.output_tokens,
706            today_cost_usd:   today_bucket.api_cost_usd,
707            week_input,
708            week_output,
709            week_cost_usd:    week_cost,
710            all_time_input:   data.all_time_input,
711            all_time_output:  data.all_time_output,
712            all_time_cost_usd: data.all_time_cost_usd,
713        }
714    }
715
716    // -----------------------------------------------------------------------
717    // Persistence
718    // -----------------------------------------------------------------------
719
720    fn persist(&self) {
721        // Signal the background writer thread; it will flush within ~100 ms.
722        self.pending.store(true, Ordering::Release);
723    }
724}
725
726#[cfg(test)]
727mod tests {
728    use super::*;
729
/// A sticky binding is returned while its TTL is running and hidden once it has elapsed.
///
/// Uses two bindings instead of sleeping: a long TTL for the "still live" check
/// (immune to scheduler stalls) and a zero TTL, which is already expired by the
/// time it can be read back.
#[test]
fn test_sticky_ttl_expiry() {
    let store = StateStore::new_empty();

    store.set_sticky("conv-fp-ttl", "account1", 60_000); // 60 s TTL
    assert_eq!(store.get_sticky("conv-fp-ttl").as_deref(), Some("account1"),
        "sticky should be available immediately");

    store.set_sticky("conv-fp-expired", "account1", 0); // expires at write time
    assert!(store.get_sticky("conv-fp-expired").is_none(),
        "sticky must expire after TTL elapses");
}
741
/// An account inside its cooldown window is reported as unavailable.
#[test]
fn test_cooldown_blocks_availability() {
    let state = StateStore::new_empty();
    let cooldown_ms = 5_000; // 5 s — far longer than the test takes
    state.set_cooldown("acc", cooldown_ms);
    let available = state.is_available("acc");
    assert!(!available, "account should be unavailable during cooldown");
}
748
/// A disabled account is reported as unavailable.
#[test]
fn test_disable_blocks_availability() {
    let state = StateStore::new_empty();
    state.disable_account("acc");
    let available = state.is_available("acc");
    assert!(!available, "disabled account must be unavailable");
}
755
/// Successive `record_usage` calls add up within one quota window.
#[test]
fn test_quota_accumulates() {
    let state = StateStore::new_empty();
    for (input, output) in [(100, 50), (200, 75)] {
        state.record_usage("acc", input, output);
    }
    let windows = state.quota_snapshot();
    let window = &windows["acc"];
    assert_eq!(window.input_tokens, 300);
    assert_eq!(window.output_tokens, 125);
    assert_eq!(window.total_tokens(), 425);
}
767
/// The pinned account starts unset, can be set, and can be cleared again.
#[test]
fn test_pinned_account_round_trip() {
    let state = StateStore::new_empty();
    assert!(state.get_pinned().is_none());

    state.set_pinned(Some(String::from("myaccount")));
    let pinned = state.get_pinned();
    assert_eq!(pinned.as_deref(), Some("myaccount"));

    state.set_pinned(None);
    assert!(state.get_pinned().is_none());
}
777
/// The last-used account starts unset and reflects the most recent `set_last_used`.
#[test]
fn test_last_used_round_trip() {
    let state = StateStore::new_empty();
    assert!(state.get_last_used().is_none());

    state.set_last_used("acc1");
    let last_used = state.get_last_used();
    assert_eq!(last_used.as_deref(), Some("acc1"));
}
785
/// The recent-request log is capped at `MAX_RECENT` and snapshots are newest-first.
#[test]
fn test_recent_requests_ring_buffer() {
    let state = StateStore::new_empty();

    // Push more entries than the log can hold; `ts_ms` doubles as insertion order.
    let overfill = MAX_RECENT + 5;
    (0..=overfill)
        .map(|seq| RequestLog {
            ts_ms: seq as u64,
            account: "acc".into(),
            model: "m".into(),
            status: 200,
            input_tokens: 1,
            output_tokens: 1,
            duration_ms: 1,
        })
        .for_each(|entry| state.record_request(entry));

    let recent = state.recent_requests_snapshot();
    assert_eq!(recent.len(), MAX_RECENT, "buffer must not grow beyond MAX_RECENT");

    let newest = &recent[0];
    let oldest = &recent[recent.len() - 1];
    assert!(newest.ts_ms > oldest.ts_ms, "snapshot must be newest-first");
}
806
807    #[test]
808    fn test_health_check_failed_round_trip() {
809        let store = StateStore::new_empty();
810        assert!(!store.is_health_check_failed("acc"), "fresh account should not be health-check-failed");
811
812        store.set_health_check_failed("acc");
813        assert!(store.is_health_check_failed("acc"), "should be marked unhealthy after set");
814
815        store.clear_health_check_failed("acc");
816        assert!(!store.is_health_check_failed("acc"), "should be cleared after clear");
817    }
818
819    #[test]
820    fn test_health_check_failure_threshold() {
821        let store = StateStore::new_empty();
822
823        // First failure: count=1, threshold=2 → not yet marked
824        let count = store.record_health_check_failure("acc", 2);
825        assert_eq!(count, 1);
826        assert!(!store.is_health_check_failed("acc"),
827            "should not be marked after 1 failure (threshold=2)");
828
829        // Second failure: count=2, threshold=2 → now marked
830        let count = store.record_health_check_failure("acc", 2);
831        assert_eq!(count, 2);
832        assert!(store.is_health_check_failed("acc"),
833            "should be marked after 2 failures (threshold=2)");
834    }
835
836    #[test]
837    fn test_clear_health_check_resets_failure_count() {
838        let store = StateStore::new_empty();
839        store.record_health_check_failure("acc", 2);
840        store.record_health_check_failure("acc", 2);
841        assert!(store.is_health_check_failed("acc"));
842
843        store.clear_health_check_failed("acc");
844        assert!(!store.is_health_check_failed("acc"));
845
846        let (_, failures) = store.health_check_info("acc");
847        assert_eq!(failures, 0, "failure count must reset to 0 after clear");
848    }
849
850    #[test]
851    fn test_health_check_info_and_last_check() {
852        let store = StateStore::new_empty();
853        let (last, failures) = store.health_check_info("acc");
854        assert_eq!(last, 0, "fresh account last_health_check_ms should be 0");
855        assert_eq!(failures, 0);
856
857        let prev = store.update_last_health_check("acc");
858        assert_eq!(prev, 0, "first update should return previous value 0");
859
860        let (last2, _) = store.health_check_info("acc");
861        assert!(last2 > 0, "last_health_check_ms should be updated to now");
862    }
863
864    #[test]
865    fn test_health_check_failed_persists() {
866        let path = std::env::temp_dir().join(format!(
867            "shunt_test_hc_{}.json",
868            std::time::SystemTime::now()
869                .duration_since(std::time::UNIX_EPOCH)
870                .unwrap()
871                .as_nanos()
872        ));
873
874        {
875            let store = StateStore::load(&path);
876            store.set_health_check_failed("acc");
877            std::thread::sleep(std::time::Duration::from_millis(300));
878        }
879
880        let store2 = StateStore::load(&path);
881        assert!(store2.is_health_check_failed("acc"),
882            "health_check_failed must survive restart");
883
884        // Ephemeral fields should NOT persist
885        let (last, failures) = store2.health_check_info("acc");
886        assert_eq!(last, 0, "last_health_check_ms is ephemeral, should be 0 after reload");
887        assert_eq!(failures, 0, "health_check_failures is ephemeral, should be 0 after reload");
888
889        let _ = std::fs::remove_file(&path);
890    }
891
892    #[test]
893    fn test_state_persistence_roundtrip() {
894        // Use a unique temp path so parallel tests don't collide
895        let path = std::env::temp_dir().join(format!(
896            "shunt_test_state_{}.json",
897            std::time::SystemTime::now()
898                .duration_since(std::time::UNIX_EPOCH)
899                .unwrap()
900                .as_nanos()
901        ));
902
903        {
904            let store = StateStore::load(&path);
905            store.set_cooldown("acc", 999_999_000); // far-future cooldown
906            store.record_usage("acc", 111, 222);
907            store.set_last_used("acc");
908            // Wait for the background writer (polls every 100 ms) to flush
909            std::thread::sleep(std::time::Duration::from_millis(300));
910        }
911
912        // Load a fresh store from the persisted file
913        let store2 = StateStore::load(&path);
914        assert!(!store2.is_available("acc"), "cooldown must survive restart");
915        let snap = store2.quota_snapshot();
916        assert_eq!(snap["acc"].input_tokens, 111, "quota must survive restart");
917        assert_eq!(snap["acc"].output_tokens, 222);
918        assert_eq!(store2.get_last_used().as_deref(), Some("acc"),
919            "last_used_account must survive restart");
920
921        let _ = std::fs::remove_file(&path);
922    }
923}
924
925/// "YYYY-MM-DD" string for today in UTC.
926fn today_key() -> String {
927    let secs = SystemTime::now()
928        .duration_since(UNIX_EPOCH)
929        .unwrap_or_default()
930        .as_secs();
931    epoch_to_ymd(secs)
932}
933
/// Format Unix epoch seconds as a UTC calendar date, "YYYY-MM-DD".
///
/// Implements Hinnant's `civil_from_days`: the day count is re-based so that each
/// 400-year era (146,097 days) starts on March 1st, which places the leap day at
/// the end of the algorithm's internal year.
fn epoch_to_ymd(secs: u64) -> String {
    const DAYS_PER_ERA: i64 = 146_097;

    // Whole days since 1970-01-01, shifted by the algorithm's epoch offset.
    let shifted = (secs / 86400) as i64 + 719_468;

    // Floor division / non-negative remainder split the day count into era + day-of-era.
    let era = shifted.div_euclid(DAYS_PER_ERA);
    let day_of_era = shifted.rem_euclid(DAYS_PER_ERA);

    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);

    // Month index counted from March (0 = March, 11 = February).
    let month_index = (5 * day_of_year + 2) / 153;
    let d = day_of_year - (153 * month_index + 2) / 5 + 1;
    let m = if month_index < 10 {
        month_index + 3
    } else {
        month_index - 9
    };

    // January and February fall in the civil year after the March-based year.
    let y = year_of_era + era * 400 + i64::from(m <= 2);
    format!("{y:04}-{m:02}-{d:02}")
}
949
950fn write_to_disk(data: &StateData, path: &Path) -> Result<()> {
951    if let Some(parent) = path.parent() {
952        std::fs::create_dir_all(parent)?;
953    }
954    let tmp = path.with_extension("tmp");
955    std::fs::write(&tmp, serde_json::to_string_pretty(data)?)?;
956    #[cfg(unix)]
957    {
958        use std::os::unix::fs::PermissionsExt;
959        let _ = std::fs::set_permissions(&tmp, std::fs::Permissions::from_mode(0o600));
960    }
961    std::fs::rename(&tmp, path)?;
962    Ok(())
963}