1use crate::config::RoutingStrategy;
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet, VecDeque};
9use std::path::{Path, PathBuf};
10use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
11use parking_lot::Mutex;
12use std::sync::Arc;
13use std::time::{SystemTime, UNIX_EPOCH};
14use tracing::warn;
15
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before 1970 yields 0 instead of panicking.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
22
23pub fn now_ms_pub() -> u64 {
25 now_ms()
26}
27
/// Per-account routing inputs, captured under a single lock by
/// `StateStore::routing_snapshot`.
#[derive(Debug, Clone)]
pub struct AccountRoutingData {
    /// Not disabled, not auth-failed, and past any cooldown (unknown accounts are available).
    pub available: bool,
    /// Account has been marked unhealthy by health checks.
    pub health_check_failed: bool,
    /// A rate-limit window reports status `"exhausted"` and its reset time is still ahead.
    pub exhausted: bool,
    /// Cooldown deadline in Unix milliseconds (0 when none recorded).
    pub cooldown_until_ms: u64,
    /// 5-hour utilization; 0.0 unless a future 5-hour reset time is known.
    pub util_5h: f64,
    /// 7-day utilization; 0.0 unless a future 7-day reset time is known.
    pub util_7d: f64,
    /// Upcoming 5-hour reset in Unix seconds; `None` if unknown or already passed.
    pub reset_5h_secs: Option<u64>,
    /// Upcoming 7-day reset in Unix seconds; `None` if unknown or already passed.
    pub reset_7d_secs: Option<u64>,
    /// Requests recorded for this account within the last 60 seconds.
    pub burst_request_count: usize,
}
45
/// Point-in-time routing view of every account that has recorded state or
/// rate-limit information.
#[derive(Debug, Clone)]
pub struct RoutingSnapshot {
    /// Routing data keyed by account name.
    pub accounts: HashMap<String, AccountRoutingData>,
    /// Unix time (seconds) at which the snapshot was taken.
    pub now_secs: u64,
}
52
/// Persisted per-account status flags plus runtime-only health-check bookkeeping.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct AccountState {
    /// Unix-ms time before which the account is treated as unavailable.
    #[serde(default)]
    pub cooldown_until_ms: u64,
    /// Account is switched off (via `disable_account`, or together with `auth_failed`).
    #[serde(default)]
    pub disabled: bool,
    /// Authentication failed; `set_auth_failed` sets this together with `disabled`.
    #[serde(default)]
    pub auth_failed: bool,
    /// Marked unhealthy by health checks; survives restarts.
    #[serde(default)]
    pub health_check_failed: bool,
    /// Failures counted by `record_health_check_failure` since the last clear (not persisted).
    #[serde(skip)]
    pub health_check_failures: u32,
    /// Unix-ms time of the most recent health check, 0 if none (not persisted).
    #[serde(skip)]
    pub last_health_check_ms: u64,
}
78
/// Sticky account assignment for one fingerprint (the key in `StateData::sticky`).
#[derive(Serialize, Deserialize, Default, Clone)]
struct StickyEntry {
    /// Account recorded for this fingerprint.
    account_name: String,
    /// Unix-ms expiry; after this the entry is ignored by lookups and dropped on load.
    expires_at_ms: u64,
}
84
/// Locally counted token usage for one account within a fixed window of `WINDOW_MS`.
///
/// `StateStore::record_usage` starts a new window (resetting both counters) on
/// first use or once the current window has elapsed.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct QuotaWindow {
    /// Unix-ms start of the current window; 0 means no window has started yet.
    #[serde(default)]
    pub window_start_ms: u64,
    /// Input tokens counted in the current window.
    #[serde(default)]
    pub input_tokens: u64,
    /// Output tokens counted in the current window.
    #[serde(default)]
    pub output_tokens: u64,
}
96
97impl QuotaWindow {
98 pub fn total_tokens(&self) -> u64 {
99 self.input_tokens + self.output_tokens
100 }
101 pub fn window_expires_ms(&self) -> Option<u64> {
102 if self.window_start_ms == 0 { None } else { Some(self.window_start_ms + WINDOW_MS) }
103 }
104}
105
/// Length of a `QuotaWindow`: 5 hours, in milliseconds.
pub const WINDOW_MS: u64 = 5 * 60 * 60 * 1000;

/// One entry in the recent-request log kept by `StateStore`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestLog {
    /// Request timestamp; presumably Unix milliseconds — TODO confirm against callers.
    pub ts_ms: u64,
    /// Account the request was attributed to.
    pub account: String,
    /// Model name associated with the request.
    pub model: String,
    /// Response status; presumably the HTTP status code — confirm against callers.
    pub status: u16,
    /// Input tokens reported for the request.
    pub input_tokens: u64,
    /// Output tokens reported for the request.
    pub output_tokens: u64,
    /// Request duration in milliseconds.
    pub duration_ms: u64,
}
123
/// Capacity of the recent-request ring buffer (`StateData::recent_requests`);
/// older entries are evicted from the front once it is full.
const MAX_RECENT: usize = 200;
125
/// Latest rate-limit information recorded for an account via
/// `StateStore::update_rate_limits`.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct RateLimitInfo {
    /// 5-hour window utilization as a fraction (values >= 0.9 trigger a warning).
    pub utilization_5h: Option<f64>,
    /// 5-hour window reset time, in Unix seconds.
    pub reset_5h: Option<u64>,
    /// 5-hour window status; `"exhausted"` marks the account exhausted until `reset_5h`.
    pub status_5h: Option<String>,
    /// 7-day window utilization as a fraction (values >= 0.9 trigger a warning).
    pub utilization_7d: Option<f64>,
    /// 7-day window reset time, in Unix seconds.
    pub reset_7d: Option<u64>,
    /// 7-day window status; `"exhausted"` marks the account exhausted until `reset_7d`.
    pub status_7d: Option<String>,
    /// Stored as reported; not interpreted by the state store.
    pub overage_status: Option<String>,
    /// Stored as reported; not interpreted by the state store.
    pub overage_disabled_reason: Option<String>,
    /// Stored as reported; not interpreted by the state store.
    pub representative_claim: Option<String>,
    /// Update timestamp set by the caller; presumably Unix milliseconds — TODO confirm.
    pub updated_ms: u64,
}
147
/// Global token totals and estimated API cost for one UTC day, stored in
/// `StateData::global_daily` under a `YYYY-MM-DD` key.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct DailyBucket {
    /// Input tokens recorded that day.
    pub input_tokens: u64,
    /// Output tokens recorded that day.
    pub output_tokens: u64,
    /// Sum of `crate::pricing::api_cost_usd` over that day's recorded usage.
    pub api_cost_usd: f64,
}
156
/// Aggregated usage totals returned by `StateStore::savings_snapshot`:
/// today, the trailing week of daily buckets, and all time.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct SavingsSnapshot {
    /// Input tokens recorded today (UTC).
    pub today_input: u64,
    /// Output tokens recorded today (UTC).
    pub today_output: u64,
    /// Estimated API cost recorded today (UTC).
    pub today_cost_usd: f64,
    /// Input tokens across the trailing week of daily buckets.
    pub week_input: u64,
    /// Output tokens across the trailing week of daily buckets.
    pub week_output: u64,
    /// Estimated API cost across the trailing week of daily buckets.
    pub week_cost_usd: f64,
    /// Cumulative input tokens (not affected by daily-bucket pruning).
    pub all_time_input: u64,
    /// Cumulative output tokens (not affected by daily-bucket pruning).
    pub all_time_output: u64,
    /// Cumulative estimated API cost (not affected by daily-bucket pruning).
    pub all_time_cost_usd: f64,
}
170
/// Everything guarded by `StateStore::inner`.
///
/// Fields marked `#[serde(skip)]` are runtime-only and reset on restart; all
/// other fields are written to the state file.
#[derive(Serialize, Deserialize, Default, Clone)]
struct StateData {
    /// Per-account status flags, keyed by account name.
    #[serde(default)]
    accounts: HashMap<String, AccountState>,
    /// Sticky account assignments, keyed by request fingerprint.
    #[serde(default)]
    sticky: HashMap<String, StickyEntry>,
    /// Locally counted token usage per account.
    #[serde(default)]
    quota: HashMap<String, QuotaWindow>,
    /// Latest rate-limit info per account.
    #[serde(default)]
    rate_limits: HashMap<String, RateLimitInfo>,
    /// Account pinned via `set_pinned`, if any.
    #[serde(default)]
    pinned_account: Option<String>,
    /// Account most recently passed to `set_last_used`.
    #[serde(default)]
    last_used_account: Option<String>,
    /// Recent-request ring buffer: oldest at the front, capped at `MAX_RECENT`.
    #[serde(default)]
    recent_requests: VecDeque<RequestLog>,
    /// Runtime model override.
    #[serde(skip)]
    model_override: Option<String>,
    /// Runtime routing-strategy override.
    #[serde(skip)]
    routing_strategy_override: Option<RoutingStrategy>,
    /// Per-account request timestamps (Unix ms) used for burst counting.
    #[serde(skip)]
    burst_windows: HashMap<String, VecDeque<u64>>,
    /// Runtime burst requests-per-minute limit override.
    #[serde(skip)]
    burst_rpm_limit_override: Option<u32>,
    /// Outer `None` = no override; `Some(inner)` = an override whose value is
    /// `inner`, which may itself be `None`.
    #[serde(skip)]
    fallback_model_override: Option<Option<String>>,
    /// Runtime effort override.
    #[serde(skip)]
    effort_override: Option<String>,
    /// Runtime thinking-mode override.
    #[serde(skip)]
    thinking_override: Option<String>,
    /// Global per-day totals keyed by UTC date `YYYY-MM-DD`; pruned by `record_global`.
    #[serde(default)]
    global_daily: HashMap<String, DailyBucket>,
    /// Cumulative input tokens (never pruned).
    #[serde(default)]
    all_time_input: u64,
    /// Cumulative output tokens (never pruned).
    #[serde(default)]
    all_time_output: u64,
    /// Cumulative estimated API cost (never pruned).
    #[serde(default)]
    all_time_cost_usd: f64,
}
223
/// Shared handle to account and routing state.
///
/// Every field is an `Arc`, so clones share the same state. Mutating methods
/// set `pending`, which the background writer started by `load` polls in
/// order to persist the state to `path`.
#[derive(Clone)]
pub struct StateStore {
    /// State file location (`/dev/null` for stores built by `new_empty`).
    path: PathBuf,
    /// All mutable state, behind a single lock.
    inner: Arc<Mutex<StateData>>,
    /// Dirty flag: set by `persist`, cleared by the writer thread before it writes.
    pending: Arc<AtomicBool>,
    /// Counter backing `next_rr_index`.
    round_robin: Arc<AtomicUsize>,
    /// Alert mute flag; lives outside `StateData`, so it is never persisted.
    alerts_muted: Arc<AtomicBool>,
}
239
240impl StateStore {
241 pub fn new_empty() -> Self {
243 Self {
245 path: PathBuf::from("/dev/null"),
246 inner: Arc::new(Mutex::new(StateData::default())),
247 pending: Arc::new(AtomicBool::new(false)),
248 round_robin: Arc::new(AtomicUsize::new(0)),
249 alerts_muted: Arc::new(AtomicBool::new(false)),
250 }
251 }
252
253 pub fn load(path: &Path) -> Self {
254 let mut data: StateData = if path.exists() {
255 match std::fs::read_to_string(path) {
256 Ok(text) => serde_json::from_str(&text).unwrap_or_else(|e| {
257 warn!("State file unreadable ({e}), starting fresh");
258 StateData::default()
259 }),
260 Err(e) => {
261 warn!("Cannot read state file ({e}), starting fresh");
262 StateData::default()
263 }
264 }
265 } else {
266 StateData::default()
267 };
268 let now = now_ms();
270 data.sticky.retain(|_, v| v.expires_at_ms > now);
271 while data.recent_requests.len() > MAX_RECENT {
273 data.recent_requests.pop_front();
274 }
275
276 let store = Self {
277 path: path.to_owned(),
278 inner: Arc::new(Mutex::new(data)),
279 pending: Arc::new(AtomicBool::new(false)),
280 round_robin: Arc::new(AtomicUsize::new(0)),
281 alerts_muted: Arc::new(AtomicBool::new(false)),
282 };
283 store.start_writer_thread();
284 store
285 }
286
287 fn start_writer_thread(&self) {
290 let pending = Arc::clone(&self.pending);
291 let inner = Arc::clone(&self.inner);
292 let path = self.path.clone();
293 std::thread::spawn(move || {
294 loop {
295 std::thread::sleep(std::time::Duration::from_millis(100));
296 if pending.compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed).is_ok() {
297 let data = inner.lock().clone();
298 if let Err(e) = write_to_disk(&data, &path) {
299 warn!("Failed to persist state: {e}");
300 }
301 }
302 }
303 });
304 }
305
306 pub fn flush_sync(&self) {
310 let data = self.inner.lock().clone();
311 if let Err(e) = write_to_disk(&data, &self.path) {
312 warn!("Final state flush failed: {e}");
313 }
314 }
315
316 pub fn is_available(&self, name: &str) -> bool {
321 let data = self.inner.lock();
322 match data.accounts.get(name) {
323 None => true,
324 Some(s) => !s.disabled && now_ms() >= s.cooldown_until_ms,
325 }
326 }
327
328 pub fn is_exhausted(&self, name: &str) -> bool {
331 let now_secs = SystemTime::now()
332 .duration_since(UNIX_EPOCH)
333 .unwrap_or_default()
334 .as_secs();
335 let data = self.inner.lock();
336 let Some(rl) = data.rate_limits.get(name) else { return false };
337 let exhausted_5h = rl.status_5h.as_deref() == Some("exhausted")
340 && rl.reset_5h.map(|t| t > now_secs).unwrap_or(false);
341 let exhausted_7d = rl.status_7d.as_deref() == Some("exhausted")
342 && rl.reset_7d.map(|t| t > now_secs).unwrap_or(false);
343 exhausted_5h || exhausted_7d
344 }
345
346 pub fn next_rr_index(&self) -> usize {
348 self.round_robin.fetch_add(1, Ordering::Relaxed)
349 }
350
351 pub fn account_states(&self) -> HashMap<String, AccountState> {
353 self.inner.lock().accounts.clone()
354 }
355
356 pub fn routing_snapshot(&self) -> RoutingSnapshot {
359 let now_ms = now_ms();
360 let now_secs = now_ms / 1_000;
361 let mut data = self.inner.lock();
362
363 let all_names: Vec<String> = {
365 let mut names: HashSet<&String> = data.accounts.keys().collect();
366 names.extend(data.rate_limits.keys());
367 names.into_iter().cloned().collect()
368 };
369
370 let burst_counts: HashMap<String, usize> = all_names.iter()
372 .map(|name| {
373 let count = data.burst_windows.get_mut(name)
374 .map(|deque| Self::burst_count_inner(deque, 60_000))
375 .unwrap_or(0);
376 (name.clone(), count)
377 })
378 .collect();
379
380 let accounts: HashMap<String, AccountRoutingData> = all_names.iter().map(|name| {
381 let acc = data.accounts.get(name);
382 let available = acc.map(|a| !a.disabled && !a.auth_failed && now_ms >= a.cooldown_until_ms).unwrap_or(true);
383 let health_check_failed = acc.map(|a| a.health_check_failed).unwrap_or(false);
384 let cooldown_until_ms = acc.map(|a| a.cooldown_until_ms).unwrap_or(0);
385
386 let (util_5h, reset_5h, util_7d, reset_7d, exhausted) =
387 if let Some(rl) = data.rate_limits.get(name) {
388 let r5 = rl.reset_5h.filter(|&t| t > now_secs);
389 let r7 = rl.reset_7d.filter(|&t| t > now_secs);
390 let u5 = if r5.is_some() { rl.utilization_5h.unwrap_or(0.0) } else { 0.0 };
391 let u7 = if r7.is_some() { rl.utilization_7d.unwrap_or(0.0) } else { 0.0 };
392 let ex = (rl.status_5h.as_deref() == Some("exhausted") && r5.is_some())
393 || (rl.status_7d.as_deref() == Some("exhausted") && r7.is_some());
394 (u5, r5, u7, r7, ex)
395 } else {
396 (0.0, None, 0.0, None, false)
397 };
398
399 let burst_request_count = burst_counts.get(name).copied().unwrap_or(0);
400
401 (name.clone(), AccountRoutingData {
402 available,
403 health_check_failed,
404 exhausted,
405 cooldown_until_ms,
406 util_5h,
407 util_7d,
408 reset_5h_secs: reset_5h,
409 reset_7d_secs: reset_7d,
410 burst_request_count,
411 })
412 }).collect();
413
414 RoutingSnapshot { accounts, now_secs }
415 }
416
417 pub fn record_request_burst(&self, name: &str) {
423 let mut data = self.inner.lock();
424 data.burst_windows.entry(name.to_owned()).or_default().push_back(now_ms());
425 }
426
427 fn burst_count_inner(deque: &mut VecDeque<u64>, window_ms: u64) -> usize {
429 let cutoff = now_ms().saturating_sub(window_ms);
430 while deque.front().map(|&t| t < cutoff).unwrap_or(false) {
432 deque.pop_front();
433 }
434 deque.len()
435 }
436
437 pub fn set_cooldown(&self, name: &str, duration_ms: u64) {
442 {
443 let mut data = self.inner.lock();
444 let acc = data.accounts.entry(name.to_owned()).or_default();
445 acc.cooldown_until_ms = now_ms() + duration_ms;
446 }
447 self.persist();
448 }
449
450 pub fn set_cooldown_staggered(&self, name: &str, duration_ms: u64) {
455 const STAGGER_MS: u64 = 5_000;
456 {
457 let mut data = self.inner.lock();
458 let now = now_ms();
459 let target = now + duration_ms;
460
461 let nearby_count = data.accounts.iter()
463 .filter(|(n, a)| {
464 *n != name
465 && a.cooldown_until_ms > now
466 && (a.cooldown_until_ms as i64 - target as i64).unsigned_abs() < STAGGER_MS
467 })
468 .count() as u64;
469
470 let offset = nearby_count.saturating_mul(STAGGER_MS);
471 let acc = data.accounts.entry(name.to_owned()).or_default();
472 acc.cooldown_until_ms = target + offset;
473 }
474 self.persist();
475 }
476
477 pub fn disable_account(&self, name: &str) {
478 {
479 let mut data = self.inner.lock();
480 data.accounts.entry(name.to_owned()).or_default().disabled = true;
481 }
482 self.persist();
483 }
484
485 pub fn set_auth_failed(&self, name: &str) {
486 {
487 let mut data = self.inner.lock();
488 let acc = data.accounts.entry(name.to_owned()).or_default();
489 acc.auth_failed = true;
490 acc.disabled = true; }
492 self.persist();
493 }
494
495 pub fn clear_auth_failed(&self, name: &str) {
497 {
498 let mut data = self.inner.lock();
499 if let Some(acc) = data.accounts.get_mut(name) {
500 acc.auth_failed = false;
501 acc.disabled = false;
502 }
503 }
504 self.persist();
505 }
506
507 pub fn auth_failed_accounts<'a>(&self, names: &[&'a str]) -> Vec<&'a str> {
509 let data = self.inner.lock();
510 names.iter()
511 .filter(|&&n| data.accounts.get(n).map(|s| s.auth_failed).unwrap_or(false))
512 .copied()
513 .collect()
514 }
515
516 pub fn is_health_check_failed(&self, name: &str) -> bool {
521 let data = self.inner.lock();
522 data.accounts.get(name).map(|s| s.health_check_failed).unwrap_or(false)
523 }
524
525 pub fn set_health_check_failed(&self, name: &str) {
526 {
527 let mut data = self.inner.lock();
528 let acc = data.accounts.entry(name.to_owned()).or_default();
529 acc.health_check_failed = true;
530 }
531 self.persist();
532 }
533
534 pub fn clear_health_check_failed(&self, name: &str) {
535 {
536 let mut data = self.inner.lock();
537 if let Some(acc) = data.accounts.get_mut(name) {
538 acc.health_check_failed = false;
539 acc.health_check_failures = 0;
540 }
541 }
542 self.persist();
543 }
544
545 pub fn record_health_check_failure(&self, name: &str, threshold: u32) -> u32 {
548 let count;
549 {
550 let mut data = self.inner.lock();
551 let acc = data.accounts.entry(name.to_owned()).or_default();
552 acc.health_check_failures = acc.health_check_failures.saturating_add(1);
553 count = acc.health_check_failures;
554 if count >= threshold {
555 acc.health_check_failed = true;
556 }
557 }
558 if count >= threshold {
559 self.persist();
560 }
561 count
562 }
563
564 pub fn update_last_health_check(&self, name: &str) -> u64 {
566 let mut data = self.inner.lock();
567 let acc = data.accounts.entry(name.to_owned()).or_default();
568 let prev = acc.last_health_check_ms;
569 acc.last_health_check_ms = now_ms();
570 prev
571 }
572
573 pub fn health_check_info(&self, name: &str) -> (u64, u32) {
575 let data = self.inner.lock();
576 match data.accounts.get(name) {
577 Some(acc) => (acc.last_health_check_ms, acc.health_check_failures),
578 None => (0, 0),
579 }
580 }
581
582 pub fn get_sticky(&self, fingerprint: &str) -> Option<String> {
587 let data = self.inner.lock();
588 let entry = data.sticky.get(fingerprint)?;
589 if now_ms() < entry.expires_at_ms {
590 Some(entry.account_name.clone())
591 } else {
592 None
593 }
594 }
595
596 pub fn set_sticky(&self, fingerprint: &str, account_name: &str, ttl_ms: u64) {
597 const MAX_STICKY_ENTRIES: usize = 10_000;
598 {
599 let mut data = self.inner.lock();
600 if data.sticky.len() >= MAX_STICKY_ENTRIES {
602 let now = now_ms();
603 data.sticky.retain(|_, v| v.expires_at_ms > now);
604 if data.sticky.len() >= MAX_STICKY_ENTRIES {
606 data.sticky.clear();
607 }
608 }
609 data.sticky.insert(
610 fingerprint.to_owned(),
611 StickyEntry { account_name: account_name.to_owned(), expires_at_ms: now_ms() + ttl_ms },
612 );
613 }
614 self.persist();
615 }
616
617 pub fn reset_5h_secs(&self, name: &str) -> Option<u64> {
624 let now_secs = SystemTime::now()
625 .duration_since(UNIX_EPOCH)
626 .unwrap_or_default()
627 .as_secs();
628 let data = self.inner.lock();
629 let reset = data.rate_limits.get(name)?.reset_5h?;
630 if reset > now_secs { Some(reset) } else { None }
631 }
632
633 pub fn utilization_5h(&self, name: &str) -> f64 {
636 let now_secs = SystemTime::now()
637 .duration_since(UNIX_EPOCH)
638 .unwrap_or_default()
639 .as_secs();
640 let data = self.inner.lock();
641 let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
642 if rl.reset_5h.map(|t| t <= now_secs).unwrap_or(false) {
644 return 0.0;
645 }
646 rl.utilization_5h.unwrap_or(0.0)
647 }
648
649 pub fn utilization_7d(&self, name: &str) -> f64 {
652 let now_secs = SystemTime::now()
653 .duration_since(UNIX_EPOCH)
654 .unwrap_or_default()
655 .as_secs();
656 let data = self.inner.lock();
657 let Some(rl) = data.rate_limits.get(name) else { return 0.0 };
658 if rl.reset_7d.map(|t| t <= now_secs).unwrap_or(false) {
659 return 0.0;
660 }
661 rl.utilization_7d.unwrap_or(0.0)
662 }
663
664 pub fn reset_7d_secs(&self, name: &str) -> Option<u64> {
667 let now_secs = SystemTime::now()
668 .duration_since(UNIX_EPOCH)
669 .unwrap_or_default()
670 .as_secs();
671 let data = self.inner.lock();
672 let reset = data.rate_limits.get(name)?.reset_7d?;
673 if reset > now_secs { Some(reset) } else { None }
674 }
675
676 pub fn record_usage(&self, name: &str, input_tokens: u64, output_tokens: u64) {
679 if input_tokens == 0 && output_tokens == 0 {
680 return;
681 }
682 {
683 let mut data = self.inner.lock();
684 let quota = data.quota.entry(name.to_owned()).or_default();
685 let now = now_ms();
686 if quota.window_start_ms == 0 || now >= quota.window_start_ms + WINDOW_MS {
687 quota.window_start_ms = now;
688 quota.input_tokens = 0;
689 quota.output_tokens = 0;
690 }
691 quota.input_tokens = quota.input_tokens.saturating_add(input_tokens);
692 quota.output_tokens = quota.output_tokens.saturating_add(output_tokens);
693 }
694 self.persist();
695 }
696
697 pub fn quota_snapshot(&self) -> HashMap<String, QuotaWindow> {
699 self.inner.lock().quota.clone()
700 }
701
702 pub fn update_rate_limits(&self, name: &str, info: RateLimitInfo) {
707 let prev = self.inner.lock().rate_limits.get(name).cloned();
708
709 let prev_5h = prev.as_ref().and_then(|p| p.utilization_5h).unwrap_or(0.0);
711 let prev_7d = prev.as_ref().and_then(|p| p.utilization_7d).unwrap_or(0.0);
712 if let Some(u) = info.utilization_5h {
713 if u >= 0.9 && prev_5h < 0.9 {
714 warn!(account = %name, utilization = %format!("{:.0}%", u * 100.0),
715 "5h rate limit above 90% — approaching quota");
716 }
717 }
718 if let Some(u) = info.utilization_7d {
719 if u >= 0.9 && prev_7d < 0.9 {
720 warn!(account = %name, utilization = %format!("{:.0}%", u * 100.0),
721 "7d rate limit above 90% — approaching quota");
722 }
723 }
724
725 {
726 let mut data = self.inner.lock();
727 data.rate_limits.insert(name.to_owned(), info);
728 }
729 self.persist();
730 }
731
732 pub fn rate_limit_snapshot(&self) -> HashMap<String, RateLimitInfo> {
733 self.inner.lock().rate_limits.clone()
734 }
735
736 pub fn get_pinned(&self) -> Option<String> {
741 self.inner.lock().pinned_account.clone()
742 }
743
744 pub fn set_pinned(&self, name: Option<String>) {
745 {
746 let mut data = self.inner.lock();
747 data.pinned_account = name;
748 }
749 self.persist();
750 }
751
752 pub fn get_last_used(&self) -> Option<String> {
757 self.inner.lock().last_used_account.clone()
758 }
759
760 pub fn set_last_used(&self, name: &str) {
761 {
762 let mut data = self.inner.lock();
763 data.last_used_account = Some(name.to_owned());
764 }
765 self.persist();
766 }
767
768 pub fn get_model_override(&self) -> Option<String> {
773 self.inner.lock().model_override.clone()
774 }
775
776 pub fn set_model_override(&self, model: String) {
777 self.inner.lock().model_override = Some(model);
778 }
779
780 pub fn clear_model_override(&self) {
781 self.inner.lock().model_override = None;
782 }
783
784 pub fn get_routing_strategy(&self) -> Option<RoutingStrategy> {
789 self.inner.lock().routing_strategy_override
790 }
791
792 pub fn set_routing_strategy(&self, strategy: RoutingStrategy) {
793 self.inner.lock().routing_strategy_override = Some(strategy);
794 }
795
796 pub fn clear_routing_strategy(&self) {
797 self.inner.lock().routing_strategy_override = None;
798 }
799
800 pub fn get_burst_rpm_limit_override(&self) -> Option<u32> {
805 self.inner.lock().burst_rpm_limit_override
806 }
807
808 pub fn set_burst_rpm_limit_override(&self, limit: u32) {
809 self.inner.lock().burst_rpm_limit_override = Some(limit);
810 }
811
812 pub fn clear_burst_rpm_limit_override(&self) {
813 self.inner.lock().burst_rpm_limit_override = None;
814 }
815
816 pub fn get_fallback_model_override(&self) -> Option<Option<String>> {
823 self.inner.lock().fallback_model_override.clone()
824 }
825
826 pub fn set_fallback_model_override(&self, model: Option<String>) {
827 self.inner.lock().fallback_model_override = Some(model);
828 }
829
830 pub fn clear_fallback_model_override(&self) {
831 self.inner.lock().fallback_model_override = None;
832 }
833
834 pub fn get_effort_override(&self) -> Option<String> {
839 self.inner.lock().effort_override.clone()
840 }
841
842 pub fn set_effort_override(&self, effort: String) {
843 self.inner.lock().effort_override = Some(effort);
844 }
845
846 pub fn clear_effort_override(&self) {
847 self.inner.lock().effort_override = None;
848 }
849
850 pub fn get_thinking_override(&self) -> Option<String> {
855 self.inner.lock().thinking_override.clone()
856 }
857
858 pub fn set_thinking_override(&self, mode: String) {
859 self.inner.lock().thinking_override = Some(mode);
860 }
861
862 pub fn clear_thinking_override(&self) {
863 self.inner.lock().thinking_override = None;
864 }
865
866 pub fn get_alerts_muted(&self) -> bool {
871 self.alerts_muted.load(Ordering::Relaxed)
872 }
873
874 pub fn set_alerts_muted(&self, muted: bool) {
875 self.alerts_muted.store(muted, Ordering::Relaxed);
876 }
877
878 pub fn record_request(&self, log: RequestLog) {
883 let mut data = self.inner.lock();
884 if data.recent_requests.len() >= MAX_RECENT {
885 data.recent_requests.pop_front();
886 }
887 data.recent_requests.push_back(log);
888 }
889
890 pub fn recent_requests_snapshot(&self) -> Vec<RequestLog> {
892 let data = self.inner.lock();
893 data.recent_requests.iter().rev().cloned().collect()
894 }
895
896 pub fn record_global(&self, model: &str, input_tokens: u64, output_tokens: u64) {
902 if input_tokens == 0 && output_tokens == 0 {
903 return;
904 }
905 let cost = crate::pricing::api_cost_usd(model, input_tokens, output_tokens);
906 let key = today_key();
907 {
908 let mut data = self.inner.lock();
909 let bucket = data.global_daily.entry(key).or_default();
910 bucket.input_tokens = bucket.input_tokens.saturating_add(input_tokens);
911 bucket.output_tokens = bucket.output_tokens.saturating_add(output_tokens);
912 bucket.api_cost_usd += cost;
913 data.all_time_input = data.all_time_input.saturating_add(input_tokens);
914 data.all_time_output = data.all_time_output.saturating_add(output_tokens);
915 data.all_time_cost_usd += cost;
916
917 if data.global_daily.len() > 100 {
919 let cutoff = epoch_to_ymd(
920 SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs()
921 .saturating_sub(90 * 86400)
922 );
923 data.global_daily.retain(|k, _| k.as_str() >= cutoff.as_str());
924 }
925 }
926 self.persist();
927 }
928
929 pub fn savings_snapshot(&self) -> SavingsSnapshot {
931 let now_secs = SystemTime::now()
932 .duration_since(UNIX_EPOCH)
933 .unwrap_or_default()
934 .as_secs();
935 let today = today_key();
936 let week_ago = epoch_to_ymd(now_secs.saturating_sub(7 * 86400));
937
938 let data = self.inner.lock();
939
940 let today_bucket = data.global_daily.get(&today).cloned().unwrap_or_default();
941
942 let (week_input, week_output, week_cost) = data.global_daily.iter()
943 .filter(|(k, _)| k.as_str() >= week_ago.as_str())
944 .fold((0u64, 0u64, 0f64), |(i, o, c), (_, b)| {
945 (i + b.input_tokens, o + b.output_tokens, c + b.api_cost_usd)
946 });
947
948 SavingsSnapshot {
949 today_input: today_bucket.input_tokens,
950 today_output: today_bucket.output_tokens,
951 today_cost_usd: today_bucket.api_cost_usd,
952 week_input,
953 week_output,
954 week_cost_usd: week_cost,
955 all_time_input: data.all_time_input,
956 all_time_output: data.all_time_output,
957 all_time_cost_usd: data.all_time_cost_usd,
958 }
959 }
960
961 fn persist(&self) {
966 self.pending.store(true, Ordering::Release);
968 }
969}
970
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Unique state-file path in the OS temp dir: `<prefix>_<nanos>.json`.
    fn unique_temp_path(prefix: &str) -> std::path::PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        std::env::temp_dir().join(format!("{prefix}_{nanos}.json"))
    }

    /// Minimal successful request-log entry with the given timestamp.
    fn sample_log(ts_ms: u64) -> RequestLog {
        RequestLog {
            ts_ms,
            account: "acc".into(),
            model: "m".into(),
            status: 200,
            input_tokens: 1,
            output_tokens: 1,
            duration_ms: 1,
        }
    }

    #[test]
    fn test_sticky_ttl_expiry() {
        let store = StateStore::new_empty();
        let fingerprint = "conv-fp-ttl";
        store.set_sticky(fingerprint, "account1", 500);
        assert_eq!(
            store.get_sticky(fingerprint).as_deref(),
            Some("account1"),
            "sticky should be available immediately"
        );
        std::thread::sleep(Duration::from_millis(600));
        assert!(
            store.get_sticky(fingerprint).is_none(),
            "sticky must expire after TTL elapses"
        );
    }

    #[test]
    fn test_cooldown_blocks_availability() {
        let store = StateStore::new_empty();
        store.set_cooldown("acc", 5_000);
        assert!(
            !store.is_available("acc"),
            "account should be unavailable during cooldown"
        );
    }

    #[test]
    fn test_disable_blocks_availability() {
        let store = StateStore::new_empty();
        store.disable_account("acc");
        assert!(!store.is_available("acc"), "disabled account must be unavailable");
    }

    #[test]
    fn test_quota_accumulates() {
        let store = StateStore::new_empty();
        for (input, output) in [(100, 50), (200, 75)] {
            store.record_usage("acc", input, output);
        }
        let quotas = store.quota_snapshot();
        let window = &quotas["acc"];
        assert_eq!(window.input_tokens, 300);
        assert_eq!(window.output_tokens, 125);
        assert_eq!(window.total_tokens(), 425);
    }

    #[test]
    fn test_pinned_account_round_trip() {
        let store = StateStore::new_empty();
        assert!(store.get_pinned().is_none());

        store.set_pinned(Some("myaccount".into()));
        assert_eq!(store.get_pinned().as_deref(), Some("myaccount"));

        store.set_pinned(None);
        assert!(store.get_pinned().is_none());
    }

    #[test]
    fn test_last_used_round_trip() {
        let store = StateStore::new_empty();
        assert!(store.get_last_used().is_none());
        store.set_last_used("acc1");
        assert_eq!(store.get_last_used().as_deref(), Some("acc1"));
    }

    #[test]
    fn test_recent_requests_ring_buffer() {
        let store = StateStore::new_empty();
        (0..=MAX_RECENT + 5).for_each(|i| store.record_request(sample_log(i as u64)));

        let snapshot = store.recent_requests_snapshot();
        assert_eq!(snapshot.len(), MAX_RECENT, "buffer must not grow beyond MAX_RECENT");
        let (newest, oldest) = (&snapshot[0], &snapshot[snapshot.len() - 1]);
        assert!(newest.ts_ms > oldest.ts_ms, "snapshot must be newest-first");
    }

    #[test]
    fn test_health_check_failed_round_trip() {
        let store = StateStore::new_empty();
        assert!(
            !store.is_health_check_failed("acc"),
            "fresh account should not be health-check-failed"
        );

        store.set_health_check_failed("acc");
        assert!(store.is_health_check_failed("acc"), "should be marked unhealthy after set");

        store.clear_health_check_failed("acc");
        assert!(!store.is_health_check_failed("acc"), "should be cleared after clear");
    }

    #[test]
    fn test_health_check_failure_threshold() {
        let store = StateStore::new_empty();

        let first = store.record_health_check_failure("acc", 2);
        assert_eq!(first, 1);
        assert!(
            !store.is_health_check_failed("acc"),
            "should not be marked after 1 failure (threshold=2)"
        );

        let second = store.record_health_check_failure("acc", 2);
        assert_eq!(second, 2);
        assert!(
            store.is_health_check_failed("acc"),
            "should be marked after 2 failures (threshold=2)"
        );
    }

    #[test]
    fn test_clear_health_check_resets_failure_count() {
        let store = StateStore::new_empty();
        for _ in 0..2 {
            store.record_health_check_failure("acc", 2);
        }
        assert!(store.is_health_check_failed("acc"));

        store.clear_health_check_failed("acc");
        assert!(!store.is_health_check_failed("acc"));

        let (_, failure_count) = store.health_check_info("acc");
        assert_eq!(failure_count, 0, "failure count must reset to 0 after clear");
    }

    #[test]
    fn test_health_check_info_and_last_check() {
        let store = StateStore::new_empty();
        let (initial_last, initial_failures) = store.health_check_info("acc");
        assert_eq!(initial_last, 0, "fresh account last_health_check_ms should be 0");
        assert_eq!(initial_failures, 0);

        let previous = store.update_last_health_check("acc");
        assert_eq!(previous, 0, "first update should return previous value 0");

        let (updated_last, _) = store.health_check_info("acc");
        assert!(updated_last > 0, "last_health_check_ms should be updated to now");
    }

    #[test]
    fn test_health_check_failed_persists() {
        let path = unique_temp_path("shunt_test_hc");

        {
            let writer = StateStore::load(&path);
            writer.set_health_check_failed("acc");
            std::thread::sleep(Duration::from_millis(300));
        }

        let reloaded = StateStore::load(&path);
        assert!(
            reloaded.is_health_check_failed("acc"),
            "health_check_failed must survive restart"
        );

        let (last_check, failure_count) = reloaded.health_check_info("acc");
        assert_eq!(last_check, 0, "last_health_check_ms is ephemeral, should be 0 after reload");
        assert_eq!(
            failure_count, 0,
            "health_check_failures is ephemeral, should be 0 after reload"
        );

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_state_persistence_roundtrip() {
        let path = unique_temp_path("shunt_test_state");

        {
            let writer = StateStore::load(&path);
            writer.set_cooldown("acc", 999_999_000);
            writer.record_usage("acc", 111, 222);
            writer.set_last_used("acc");
            std::thread::sleep(Duration::from_millis(300));
        }

        let reloaded = StateStore::load(&path);
        assert!(!reloaded.is_available("acc"), "cooldown must survive restart");

        let quotas = reloaded.quota_snapshot();
        assert_eq!(quotas["acc"].input_tokens, 111, "quota must survive restart");
        assert_eq!(quotas["acc"].output_tokens, 222);
        assert_eq!(
            reloaded.get_last_used().as_deref(),
            Some("acc"),
            "last_used_account must survive restart"
        );

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_burst_window_tracking() {
        let store = StateStore::new_empty();
        (0..5).for_each(|_| store.record_request_burst("acc"));

        let before = store.routing_snapshot();
        assert!(
            before
                .accounts
                .get("acc")
                .map_or(true, |d| d.burst_request_count == 0),
            "no account state yet, burst tracked separately"
        );

        store.set_cooldown("acc", 0);
        (0..3).for_each(|_| store.record_request_burst("acc"));

        let after = store.routing_snapshot();
        let entry = after.accounts.get("acc").expect("acc should exist in snapshot");
        assert_eq!(entry.burst_request_count, 8, "should count all recent requests");
    }
}
1192
1193fn today_key() -> String {
1195 let secs = SystemTime::now()
1196 .duration_since(UNIX_EPOCH)
1197 .unwrap_or_default()
1198 .as_secs();
1199 epoch_to_ymd(secs)
1200}
1201
/// Converts Unix seconds to a UTC calendar date formatted as `YYYY-MM-DD`.
///
/// This is Howard Hinnant's `civil_from_days`. Days are counted from 0000-03-01,
/// so leap days fall at the end of each shifted year, and the proleptic
/// Gregorian calendar is split into 400-year eras of 146 097 days.
fn epoch_to_ymd(secs: u64) -> String {
    const DAYS_PER_ERA: i64 = 146_097;

    // Days since 1970-01-01, rebased to 0000-03-01.
    let day_number = (secs / 86400) as i64 + 719_468;
    let era = if day_number >= 0 {
        day_number / DAYS_PER_ERA
    } else {
        (day_number - (DAYS_PER_ERA - 1)) / DAYS_PER_ERA
    };
    let day_of_era = day_number - era * DAYS_PER_ERA; // [0, 146096]
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    let month_index = (5 * day_of_year + 2) / 153; // 0 = March ... 11 = February
    let day = day_of_year - (153 * month_index + 2) / 5 + 1;
    let month = if month_index < 10 { month_index + 3 } else { month_index - 9 };
    // January and February belong to the following civil year.
    let year = year_of_era + era * 400 + i64::from(month <= 2);
    format!("{year:04}-{month:02}-{day:02}")
}
1217
1218fn write_to_disk(data: &StateData, path: &Path) -> Result<()> {
1219 if let Some(parent) = path.parent() {
1220 std::fs::create_dir_all(parent)?;
1221 }
1222 let tmp = path.with_extension("tmp");
1223 std::fs::write(&tmp, serde_json::to_string_pretty(data)?)?;
1224 #[cfg(unix)]
1225 {
1226 use std::os::unix::fs::PermissionsExt;
1227 let _ = std::fs::set_permissions(&tmp, std::fs::Permissions::from_mode(0o600));
1228 }
1229 std::fs::rename(&tmp, path)?;
1230 Ok(())
1231}