use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::sync::mpsc::{Receiver, Sender};
use std::time::Duration;
use crate::lockorder::{RankedMutex, rank};
use crate::providers::{Provider, ThirdPartyStats};
use super::fetch::{
FetchError, UsageInfo, UsageWindow, epoch_secs_to_iso, fetch_raw, iso_to_epoch_secs,
load_disk_cache, now_epoch_secs, now_ms, write_disk_cache,
};
/// How long the refresher thread sleeps between two scheduler ticks.
const TICK_INTERVAL: Duration = Duration::from_secs(1);
/// Spacing, in milliseconds, added to a profile's last-fetched stamp to
/// compute when it is next due.
pub(crate) const REFRESH_INTERVAL_MS: u64 = 60_000;
/// Upper bound (15 minutes, in milliseconds) applied to a retry-after delay
/// before it is used to defer the next fetch.
const MAX_RETRY_AFTER_MS: u64 = 15 * 60 * 1000;
/// An absolute timestamp stored as a raw millisecond count.
///
/// Kept distinct from [`IntervalMs`] so that timestamps and durations cannot
/// be mixed up by accident.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub(crate) struct EpochMs(u64);

/// A length of time stored as a raw millisecond count.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub(crate) struct IntervalMs(u64);

impl EpochMs {
    /// Wraps a raw millisecond timestamp.
    pub(crate) const fn from_millis(ms: u64) -> Self {
        EpochMs(ms)
    }

    /// Returns the raw millisecond timestamp.
    pub(crate) const fn as_millis(self) -> u64 {
        let EpochMs(raw) = self;
        raw
    }

    /// Moves the timestamp forward by `interval`, clamping at `u64::MAX`
    /// instead of overflowing.
    pub(crate) const fn saturating_add(self, interval: IntervalMs) -> EpochMs {
        let IntervalMs(delta) = interval;
        Self::from_millis(self.0.saturating_add(delta))
    }
}

impl IntervalMs {
    /// Wraps a raw millisecond duration.
    pub(crate) const fn from_millis(ms: u64) -> Self {
        IntervalMs(ms)
    }
}
/// Shared map from profile name to the usage info currently held in memory.
pub(crate) type UsageStore = Arc<RankedMutex<HashMap<String, UsageInfo>, rank::UsageStore>>;
/// Shared map from profile name to the status of its most recent fetch.
pub(crate) type StatusStore = Arc<RankedMutex<HashMap<String, FetchStatus>, rank::UsageStatus>>;
/// Shared list of OAuth token entries the scheduler fetches usage for.
pub(crate) type TokenList = Arc<RankedMutex<Vec<TokenEntry>, rank::Tokens>>;
/// Shared map from profile name to its last-fetched stamp; the next fetch is
/// due `REFRESH_INTERVAL_MS` after this stamp.
pub(crate) type LastFetchedAt = Arc<RankedMutex<HashMap<String, EpochMs>, rank::LastFetched>>;
/// Shared set of profile names to fetch on the next tick regardless of their
/// schedule; the tick drains it.
pub(crate) type RefetchQueue = Arc<RankedMutex<HashSet<String>, rank::RefetchQueue>>;
/// Shared set of profile names selected as auto-switch targets.
pub(crate) type PendingSwitch = Arc<RankedMutex<HashSet<String>, rank::PendingSwitch>>;
/// Shared flag raised when the auto-switch scan decides on the `Off` action.
pub(crate) type PendingSwitchOff = Arc<RankedMutex<bool, rank::PendingSwitchOff>>;
/// Credentials and scheduling flags for one OAuth-backed profile.
#[derive(Clone)]
pub(crate) struct TokenEntry {
/// Profile name; used as the key in every per-profile store.
pub(crate) name: String,
/// Access token passed to the usage fetch.
pub(crate) access_token: String,
/// Refresh token, if any; needed to rotate tokens after a 401.
pub(crate) refresh_token: Option<String>,
/// When set, a scheduled fetch may first try the auto-start kick.
pub(crate) auto_start: bool,
/// Expiry of the access token, forwarded to the auto-start kick.
/// NOTE(review): unit/epoch not visible here — presumably epoch seconds; confirm.
pub(crate) access_expires_at: Option<i64>,
}
/// Credentials for one profile backed by a third-party provider.
#[derive(Clone)]
pub(crate) struct ThirdPartyEntry {
/// Profile name; used as the key in every per-profile store.
pub(crate) name: String,
/// Provider whose usage endpoint is queried.
pub(crate) provider: Provider,
/// API key handed to the provider's usage fetch.
pub(crate) api_key: String,
}
/// Anything the scheduler can key by profile name.
///
/// Lets the partitioning helpers work over both OAuth and third-party entries.
trait NamedEntry {
    /// The profile name this entry belongs to.
    fn name(&self) -> &str;
}

impl NamedEntry for TokenEntry {
    fn name(&self) -> &str {
        self.name.as_str()
    }
}

impl NamedEntry for ThirdPartyEntry {
    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Shared list of third-party entries the scheduler fetches usage for.
pub(crate) type ThirdPartyList = Arc<RankedMutex<Vec<ThirdPartyEntry>, rank::ThirdParty>>;
/// Shared map from profile name to the third-party stats currently held in memory.
pub(crate) type ThirdPartyUsageStore =
Arc<RankedMutex<HashMap<String, ThirdPartyStats>, rank::ThirdPartyUsageStore>>;
/// Shared map from profile name to the status of its most recent third-party fetch.
pub(crate) type ThirdPartyStatusStore =
Arc<RankedMutex<HashMap<String, FetchStatus>, rank::ThirdPartyStatus>>;
/// Shared map from profile name to the time (milliseconds) of its next scheduled refresh.
pub(crate) type NextRefreshPerProfile = Arc<RankedMutex<HashMap<String, u64>, rank::NextRefresh>>;
/// Shared map from profile name to its current non-idle activity; a missing
/// entry means the profile is idle.
pub(crate) type ActivityStore = Arc<RankedMutex<HashMap<String, ProfileActivity>, rank::Activity>>;
/// What a profile is currently busy with, as tracked in the `ActivityStore`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ProfileActivity {
/// Not busy. Never stored: `mark_activity` removes the entry instead.
Idle,
/// A usage fetch is in flight.
Fetching,
/// An OAuth token refresh is in flight; such profiles are skipped by the scheduler.
Refreshing,
/// A profile switch is in progress; such profiles are skipped by the scheduler.
Switching,
// NOTE(review): only produced via `ActivityKind::as_activity` in the visible
// code, hence the dead-code allowance — confirm whether it is still needed.
#[allow(dead_code)]
Starting,
}
/// The kind of operation an `OpResult` reports on; mirrors the non-idle
/// variants of `ProfileActivity`.
// NOTE(review): no variant is constructed in the visible code, hence the
// dead-code allowance — confirm against callers outside this file.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ActivityKind {
Fetching,
Refreshing,
Switching,
Starting,
}
impl ActivityKind {
    /// Maps this operation kind onto the `ProfileActivity` variant of the same name.
    #[allow(dead_code)]
    pub(crate) fn as_activity(self) -> ProfileActivity {
        match self {
            Self::Fetching => ProfileActivity::Fetching,
            Self::Refreshing => ProfileActivity::Refreshing,
            Self::Switching => ProfileActivity::Switching,
            Self::Starting => ProfileActivity::Starting,
        }
    }
}
/// Result of one per-profile operation, as sent over an `OpResultSender`.
#[derive(Debug)]
pub(crate) struct OpResult {
/// Profile the operation ran for.
pub(crate) name: String,
/// Which kind of operation this result belongs to.
pub(crate) kind: ActivityKind,
/// `Ok(())` on success, otherwise the error the operation ended with.
pub(crate) outcome: anyhow::Result<()>,
}
/// Sending half of the channel carrying `OpResult` values.
pub(crate) type OpResultSender = Sender<OpResult>;
/// Receiving half of the channel carrying `OpResult` values.
pub(crate) type OpResultReceiver = Receiver<OpResult>;
/// Messages sent over a `StartupSender` during application startup.
// NOTE(review): the producers and consumers are outside this file; the
// per-variant notes below are inferred from the names — confirm.
#[derive(Debug)]
pub(crate) enum StartupSignal {
/// Presumably: reconciliation finished without needing user input.
ReconcileDone,
/// Presumably: reconciliation needs a prompt; `active` names the profile concerned.
ReconcileNeedsPrompt { active: String },
/// Presumably: the bootstrap phase finished.
BootstrapDone,
}
/// Sending half of the channel carrying `StartupSignal` values.
pub(crate) type StartupSender = Sender<StartupSignal>;
/// Receiving half of the channel carrying `StartupSignal` values.
pub(crate) type StartupReceiver = Receiver<StartupSignal>;
/// Records what `name` is currently doing.
///
/// `Idle` is represented by the absence of an entry, so marking a profile
/// idle removes it. Does nothing if the store's lock cannot be taken.
pub(crate) fn mark_activity(store: &ActivityStore, name: &str, activity: ProfileActivity) {
    let Ok(mut map) = store.lock() else {
        return;
    };
    match activity {
        ProfileActivity::Idle => {
            map.remove(name);
        }
        busy => {
            map.insert(name.to_string(), busy);
        }
    }
}
/// Marks `name` as idle by dropping its entry from the activity store.
///
/// Does nothing if the store's lock cannot be taken.
pub(crate) fn clear_activity(store: &ActivityStore, name: &str) {
    let Ok(mut map) = store.lock() else {
        return;
    };
    map.remove(name);
}
/// Returns `true` when `name` has no recorded activity.
///
/// If the store's lock cannot be taken the profile is conservatively reported
/// as not idle.
pub(crate) fn is_idle(store: &ActivityStore, name: &str) -> bool {
    store.lock().is_ok_and(|map| !map.contains_key(name))
}
/// Returns `true` when at least one profile has a recorded activity.
///
/// If the store's lock cannot be taken the answer is conservatively `true`.
pub(crate) fn any_busy(store: &ActivityStore) -> bool {
    let Ok(map) = store.lock() else {
        return true;
    };
    !map.is_empty()
}
/// Outcome classification of the most recent fetch for a profile.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum FetchStatus {
/// The data came from a successful live fetch.
Fresh,
/// The live fetch failed and data from the disk cache is being used.
Cached,
/// The live fetch failed and no cached data was available.
Failed,
/// The fetch was rejected as rate limited.
RateLimited,
}
/// A freshly rotated token pair: `(access_token, optional refresh_token)`.
pub(crate) type RotatedTokens = (String, Option<String>);
/// Loads the disk cache for `name` and pairs it with a status.
///
/// The requested `status` is kept only when cached data exists; without a
/// cache entry the result is `(None, FetchStatus::Failed)`.
fn load_cached_with_status(name: &str, status: FetchStatus) -> (Option<UsageInfo>, FetchStatus) {
    let cached = load_disk_cache(name);
    let effective = if cached.is_some() {
        status
    } else {
        FetchStatus::Failed
    };
    (cached, effective)
}
/// Fetches usage for `name`, rotating the OAuth tokens once if the first
/// attempt is answered with HTTP 401.
///
/// Flow:
/// 1. Try `fetch_raw` with the current access token. Success yields a live
///    outcome; a rate limit or any non-401 error falls back to the disk cache.
/// 2. On 401, and only if a refresh token exists, the rotation guard can be
///    acquired and `has_live_session(name)` is false, refresh the tokens and
///    persist them via `apply_rotated_tokens_locked`.
/// 3. Retry the fetch with the new access token. If that retry fails for a
///    reason other than rate limiting, `name` is queued in `refetch`.
///
/// The returned outcome carries the rotated tokens whenever step 2 completed.
fn fetch_with_rotation(
    config: &crate::profile::ConfigHandle,
    name: &str,
    access_token: &str,
    refresh_token: Option<&str>,
    refetch: &RefetchQueue,
    activity: &ActivityStore,
) -> FetchOutcome {
    // Shared give-up path: serve the disk cache with `Cached` status.
    let stale = |rotated: Option<RotatedTokens>| {
        FetchOutcome::cached(name, FetchStatus::Cached, rotated, None)
    };

    match fetch_raw(access_token) {
        Ok(info) => return FetchOutcome::live(name, info, None),
        Err(FetchError::RateLimited { retry_after }) => {
            return FetchOutcome::cached(name, FetchStatus::RateLimited, None, retry_after);
        }
        // Only an unauthorized response continues to the refresh path below.
        Err(FetchError::Status(401)) => {}
        Err(_) => return stale(None),
    }

    let Some(rt) = refresh_token else {
        return stale(None);
    };
    // Held until this function returns.
    let _rotation_guard = match crate::runtime::RotationGuard::acquire(name) {
        Ok(guard) => guard,
        Err(_) => return stale(None),
    };
    if crate::runtime::has_live_session(name) {
        return stale(None);
    }

    // Surface the refresh in the activity store only while it is running.
    mark_activity(activity, name, ProfileActivity::Refreshing);
    let refreshed = crate::oauth::refresh(rt);
    mark_activity(activity, name, ProfileActivity::Fetching);
    let Ok(tok) = refreshed else {
        return stale(None);
    };

    // Copy the tokens out before `tok` is consumed by the persist call.
    let new_access = tok.access_token.clone();
    let new_refresh = tok.refresh_token.clone();
    if crate::oauth::apply_rotated_tokens_locked(config, name, tok).is_err() {
        return stale(None);
    }
    let rotated: Option<RotatedTokens> = Some((new_access.clone(), Some(new_refresh)));

    match fetch_raw(&new_access) {
        Ok(info) => FetchOutcome::live(name, info, rotated),
        Err(FetchError::RateLimited { retry_after }) => {
            FetchOutcome::cached(name, FetchStatus::RateLimited, rotated, retry_after)
        }
        Err(_) => {
            // Ask the scheduler to try this profile again on the next tick.
            if let Ok(mut queue) = refetch.lock() {
                queue.insert(name.to_string());
            }
            stale(rotated)
        }
    }
}
/// Everything one fetch attempt produced for a single profile.
struct FetchOutcome {
/// Profile the outcome belongs to.
name: String,
/// Usage data, either live or from the disk cache; `None` if neither was available.
info: Option<UsageInfo>,
/// Classification of how `info` was obtained.
status: FetchStatus,
/// New token pair if tokens were rotated while producing this outcome.
rotated: Option<RotatedTokens>,
/// `true` when `info` came from a live fetch rather than the disk cache.
from_fetch: bool,
/// Retry-after delay reported with a rate-limit error, if any.
retry_after: Option<Duration>,
}
impl FetchOutcome {
fn live(name: &str, info: UsageInfo, rotated: Option<RotatedTokens>) -> Self {
Self {
name: name.to_string(),
info: Some(info),
status: FetchStatus::Fresh,
rotated,
from_fetch: true,
retry_after: None,
}
}
fn cached(
name: &str,
status: FetchStatus,
rotated: Option<RotatedTokens>,
retry_after: Option<Duration>,
) -> Self {
let (info, status) = load_cached_with_status(name, status);
Self {
name: name.to_string(),
info,
status,
rotated,
from_fetch: false,
retry_after,
}
}
}
/// Reports whether the stored five-hour window for `name` is no longer live.
///
/// Returns `false` when the store cannot be locked or holds nothing for
/// `name`. Otherwise the window counts as lapsed unless it has a parseable
/// `resets_at` that lies strictly after `now_secs`.
fn window_lapsed(store: &UsageStore, name: &str, now_secs: i64) -> bool {
    let Ok(usage) = store.lock() else {
        return false;
    };
    let Some(window) = usage.get(name).map(|info| info.five_hour.as_ref()) else {
        return false;
    };
    let resets_at = window
        .and_then(|w| w.resets_at.as_deref())
        .and_then(iso_to_epoch_secs);
    match resets_at {
        Some(deadline) => now_secs >= deadline,
        None => true,
    }
}
/// Runs one complete fetch for `entry`, optionally preceded by an auto-start kick.
///
/// The kick is attempted only when `entry.auto_start` is set, a usage `store`
/// was supplied, and the stored five-hour window has lapsed. Tokens rotated by
/// the kick are used for the fetch that follows and are reported in the
/// outcome unless the fetch itself rotated tokens again.
fn run_fetch(
    config: &crate::profile::ConfigHandle,
    mut entry: TokenEntry,
    store: Option<&UsageStore>,
    refetch: &RefetchQueue,
    activity: &ActivityStore,
) -> FetchOutcome {
    let mut kick_rotated: Option<RotatedTokens> = None;

    if let (true, Some(usage)) = (entry.auto_start, store) {
        let now_secs = now_epoch_secs();
        if window_lapsed(usage, &entry.name, now_secs) {
            let kicked = crate::oauth::auto_start_kick(
                config,
                &entry.name,
                &entry.access_token,
                entry.refresh_token.as_deref(),
                entry.access_expires_at,
                Some(activity),
            );
            let opened = kicked.opened;
            if let Some(pair) = kicked.rotated {
                // Continue with the tokens the kick just rotated in.
                entry.access_token = pair.0.clone();
                entry.refresh_token = pair.1.clone();
                kick_rotated = Some(pair);
            }
            if opened {
                mark_window_open(usage, &entry.name, now_secs);
            }
        }
    }

    let mut outcome = fetch_with_rotation(
        config,
        &entry.name,
        &entry.access_token,
        entry.refresh_token.as_deref(),
        refetch,
        activity,
    );
    // A rotation performed by the fetch itself is newer and takes precedence.
    outcome.rotated = outcome.rotated.or(kick_rotated);
    outcome
}
/// Publishes a finished fetch outcome into the shared stores.
///
/// * Live data is written to the disk cache and always replaces the in-memory
///   entry; cached data only fills a missing in-memory entry.
/// * The last-fetched stamp is set to "now", pushed forward when a
///   retry-after was reported so that the next scheduled fetch (stamp +
///   `REFRESH_INTERVAL_MS`) lands at roughly now + retry-after, capped at
///   `MAX_RETRY_AFTER_MS`.
/// * The status store receives the outcome's status.
///
/// Each store update is skipped silently if its lock cannot be taken.
fn apply_outcome(
    outcome: FetchOutcome,
    store: &UsageStore,
    status: &StatusStore,
    last_fetched: &LastFetchedAt,
) {
    let now = EpochMs::from_millis(now_ms());
    let is_fresh = outcome.from_fetch;
    if is_fresh && let Some(info) = &outcome.info {
        write_disk_cache(&outcome.name, info);
    }
    if let Ok(mut s) = store.lock()
        && let Some(info) = &outcome.info
    {
        if is_fresh || !s.contains_key(&outcome.name) {
            s.insert(outcome.name.clone(), info.clone());
        }
    }
    let defer = IntervalMs::from_millis(outcome.retry_after.map_or(0, |ra| {
        // `as_millis` returns a u128. Saturate rather than truncate with `as`,
        // so an absurdly large retry-after is capped instead of wrapping.
        u64::try_from(ra.as_millis())
            .unwrap_or(u64::MAX)
            .min(MAX_RETRY_AFTER_MS)
            .saturating_sub(REFRESH_INTERVAL_MS)
    }));
    if let Ok(mut lf) = last_fetched.lock() {
        lf.insert(outcome.name.clone(), now.saturating_add(defer));
    }
    if let Ok(mut st) = status.lock() {
        st.insert(outcome.name.clone(), outcome.status);
    }
}
/// Records that a new five-hour window has just been opened for `name`.
///
/// An entry for `name` is created if missing. If the stored window is still
/// live (its `resets_at` parses and lies strictly after `now_secs`) it is left
/// untouched; otherwise it is replaced by a window with zero utilization that
/// resets five hours after `now_secs`. Does nothing if the lock cannot be taken.
fn mark_window_open(store: &UsageStore, name: &str, now_secs: i64) {
    let Ok(mut usage) = store.lock() else {
        return;
    };
    let info = usage.entry(name.to_string()).or_default();
    let resets_at = info
        .five_hour
        .as_ref()
        .and_then(|w| w.resets_at.as_deref())
        .and_then(iso_to_epoch_secs);
    let still_open = matches!(resets_at, Some(deadline) if now_secs < deadline);
    if !still_open {
        info.five_hour = Some(UsageWindow {
            utilization: 0.0,
            resets_at: Some(epoch_secs_to_iso(now_secs + 5 * 3600)),
        });
    }
}
/// Fetches usage for every entry in `tokens` concurrently and publishes the results.
///
/// Each entry is marked `Fetching`, fetched on its own thread (without the
/// auto-start kick, since no usage store is handed to `run_fetch`), and then
/// joined in input order. Every joined worker clears the profile's activity;
/// successful ones also have their outcome applied to the stores. A worker
/// that panicked only has its activity cleared.
pub(crate) fn fetch_all_into(
    config: &crate::profile::ConfigHandle,
    tokens: &[TokenEntry],
    store: &UsageStore,
    status: &StatusStore,
    last_fetched: &LastFetchedAt,
    refetch: &RefetchQueue,
    activity: &ActivityStore,
) {
    if tokens.is_empty() {
        return;
    }

    tokens
        .iter()
        .for_each(|t| mark_activity(activity, &t.name, ProfileActivity::Fetching));

    let mut workers = Vec::with_capacity(tokens.len());
    for entry in tokens.iter().cloned() {
        let label = entry.name.clone();
        let cfg = Arc::clone(config);
        let queue = Arc::clone(refetch);
        let act = Arc::clone(activity);
        let worker = std::thread::spawn(move || run_fetch(&cfg, entry, None, &queue, &act));
        workers.push((label, worker));
    }

    for (label, worker) in workers {
        let Ok(outcome) = worker.join() else {
            // The worker panicked; there is no outcome to publish.
            clear_activity(activity, &label);
            continue;
        };
        clear_activity(activity, &outcome.name);
        apply_outcome(outcome, store, status, last_fetched);
    }
}
/// Extracts the third-party entries from `profiles`.
///
/// A profile contributes an entry only if it has both a provider and an API
/// key; all others are skipped. Input order is preserved.
pub(crate) fn collect_third_party_entries(
    profiles: &[crate::profile::Profile],
) -> Vec<ThirdPartyEntry> {
    let mut entries = Vec::new();
    for profile in profiles {
        let Some(provider) = profile.provider else {
            continue;
        };
        let Some(api_key) = profile.api_key.clone() else {
            continue;
        };
        entries.push(ThirdPartyEntry {
            name: profile.name.to_string(),
            provider,
            api_key,
        });
    }
    entries
}
/// Fetches third-party usage for every entry in `due` concurrently and
/// publishes the results into `state`.
///
/// Per profile, after its worker thread is joined and its activity cleared:
/// * success — stats are written to the disk cache and the in-memory store,
///   status becomes `Fresh`, and the last-fetched stamp is set;
/// * fetch error — the disk cache (if any) fills a missing in-memory entry,
///   status becomes `RateLimited`, `Cached` or `Failed`, and the stamp is set
///   with the reported retry-after;
/// * worker panic — nothing else is updated.
fn fetch_third_party_due(state: &SchedulerState, due: Vec<ThirdPartyEntry>) {
    due.iter()
        .for_each(|e| mark_activity(&state.activity, &e.name, ProfileActivity::Fetching));

    let mut workers = Vec::with_capacity(due.len());
    for entry in due {
        let label = entry.name.clone();
        let worker = std::thread::spawn(move || {
            crate::providers::fetch_third_party_usage(entry.provider, &entry.api_key)
        });
        workers.push((label, worker));
    }

    for (name, worker) in workers {
        let joined = worker.join();
        // Whatever happened, the profile is no longer fetching.
        clear_activity(&state.activity, &name);

        match joined {
            Ok(Ok(stats)) => {
                crate::providers::write_third_party_disk_cache(&name, &stats);
                if let Ok(mut usage) = state.third_party_usage_store.lock() {
                    usage.insert(name.clone(), stats);
                }
                if let Ok(mut statuses) = state.third_party_status.lock() {
                    statuses.insert(name.clone(), FetchStatus::Fresh);
                }
                stamp_last_fetched(&state.last_fetched, name, None);
            }
            Ok(Err(err)) => {
                let cached = crate::providers::load_third_party_disk_cache(&name);
                let (status, retry_after) =
                    if let crate::providers::ThirdPartyError::RateLimited { retry_after } = &err {
                        (FetchStatus::RateLimited, *retry_after)
                    } else if cached.is_some() {
                        (FetchStatus::Cached, None)
                    } else {
                        (FetchStatus::Failed, None)
                    };
                if let Some(stale) = cached {
                    if let Ok(mut usage) = state.third_party_usage_store.lock() {
                        // Never overwrite data that is already in memory.
                        usage.entry(name.clone()).or_insert(stale);
                    }
                }
                if let Ok(mut statuses) = state.third_party_status.lock() {
                    statuses.insert(name.clone(), status);
                }
                stamp_last_fetched(&state.last_fetched, name, retry_after);
            }
            Err(_) => {}
        }
    }
}
/// Records "now" as the last-fetched stamp for `name`, deferred by any retry-after.
///
/// The scheduler treats a profile as due `REFRESH_INTERVAL_MS` after its
/// stamp, so the stamp is pushed forward by `retry_after` (capped at
/// `MAX_RETRY_AFTER_MS`) minus `REFRESH_INTERVAL_MS`, never below zero. The
/// next fetch then lands at roughly now + retry-after. Does nothing if the
/// lock cannot be taken.
fn stamp_last_fetched(last_fetched: &LastFetchedAt, name: String, retry_after: Option<Duration>) {
    let defer_ms = retry_after.map_or(0, |ra| {
        // `as_millis` returns a u128. Saturate rather than truncate with `as`,
        // so an absurdly large retry-after is capped instead of wrapping.
        u64::try_from(ra.as_millis())
            .unwrap_or(u64::MAX)
            .min(MAX_RETRY_AFTER_MS)
            .saturating_sub(REFRESH_INTERVAL_MS)
    });
    let defer = IntervalMs::from_millis(defer_ms);
    if let Ok(mut lf) = last_fetched.lock() {
        lf.insert(name, EpochMs::from_millis(now_ms()).saturating_add(defer));
    }
}
/// Computes which entries of `snapshot` must be fetched now, plus the
/// next-refresh time per profile.
///
/// Combines the regular schedule (`partition_due`) with the forced refetches
/// (`merge_forced`). An empty snapshot short-circuits without touching any lock.
fn partition_and_merge<T: NamedEntry + Clone>(
    snapshot: &[T],
    forced: &HashSet<String>,
    state: &SchedulerState,
    now: u64,
) -> (Vec<T>, HashMap<String, u64>) {
    let nothing_to_schedule = snapshot.is_empty();
    if nothing_to_schedule {
        return (Vec::new(), HashMap::new());
    }

    let (mut due_now, mut next_at) =
        partition_due(snapshot, now, &state.last_fetched, &state.activity);
    merge_forced(
        snapshot,
        forced,
        &mut due_now,
        &mut next_at,
        &state.activity,
        now,
    );
    (due_now, next_at)
}
/// Replaces the published next-refresh times with the given OAuth and
/// third-party schedules.
///
/// If a name appears in both maps the third-party value wins. Does nothing if
/// the lock cannot be taken.
fn publish_countdowns(
    nrpp: &NextRefreshPerProfile,
    oauth: HashMap<String, u64>,
    third_party: HashMap<String, u64>,
) {
    let Ok(mut countdowns) = nrpp.lock() else {
        return;
    };
    countdowns.clear();
    countdowns.extend(oauth.into_iter().chain(third_party));
}
/// All shared handles the refresher thread needs for one scheduler tick.
pub(crate) struct SchedulerState {
/// Configuration handle; used for token persistence and the auto-switch chain.
config: crate::profile::ConfigHandle,
/// OAuth entries to fetch; rotated tokens are written back here.
tokens: TokenList,
/// In-memory OAuth usage data.
store: UsageStore,
/// Status of the latest OAuth fetch per profile.
status: StatusStore,
/// Published next-refresh time per profile (OAuth and third-party).
next_refresh_per_profile: NextRefreshPerProfile,
/// Current non-idle activity per profile.
activity: ActivityStore,
/// Last-fetched stamp per profile (OAuth and third-party).
last_fetched: LastFetchedAt,
/// Auto-switch targets selected by the scan.
pending_switch: PendingSwitch,
/// Set when the scan decides on the `Off` action.
pending_switch_off: PendingSwitchOff,
/// Names to fetch on the next tick regardless of schedule.
refetch_queue: RefetchQueue,
/// Third-party entries to fetch.
third_party_tokens: ThirdPartyList,
/// In-memory third-party usage data.
third_party_usage_store: ThirdPartyUsageStore,
/// Status of the latest third-party fetch per profile.
third_party_status: ThirdPartyStatusStore,
}
/// Runs one scheduler pass: decides which profiles are due, fetches them,
/// publishes results, and refreshes the per-profile countdowns.
///
/// OAuth fetches run on one thread per due profile and are joined before the
/// function continues; third-party fetches follow afterwards.
fn tick(state: &SchedulerState) {
// Drain the forced-refetch queue; names queued from here on wait for the next tick.
let forced: HashSet<String> = state
.refetch_queue
.lock()
.ok()
.map(|mut q| std::mem::take(&mut *q))
.unwrap_or_default();
// Work on snapshots so the entry lists are not locked while fetching.
let oauth_snapshot: Vec<TokenEntry> =
state.tokens.lock().map(|t| t.clone()).unwrap_or_default();
let tp_snapshot: Vec<ThirdPartyEntry> = state
.third_party_tokens
.lock()
.map(|t| t.clone())
.unwrap_or_default();
let now = now_ms();
let (oauth_due, oauth_next) = partition_and_merge(&oauth_snapshot, &forced, state, now);
let (tp_due, tp_next) = partition_and_merge(&tp_snapshot, &forced, state, now);
// Publish the schedule as it stands before any fetch runs.
publish_countdowns(&state.next_refresh_per_profile, oauth_next, tp_next);
// Remember whether anything will be fetched: the `*_due` vectors are consumed below.
let fetched = !oauth_due.is_empty() || !tp_due.is_empty();
if !oauth_due.is_empty() {
for entry in &oauth_due {
mark_activity(&state.activity, &entry.name, ProfileActivity::Fetching);
}
// One worker thread per due OAuth profile; the name is kept alongside the
// handle so activity can be cleared even if the worker panics.
let handles: Vec<_> = oauth_due
.into_iter()
.map(|entry| {
let name = entry.name.clone();
let config = Arc::clone(&state.config);
let store = Arc::clone(&state.store);
let refetch_queue = Arc::clone(&state.refetch_queue);
let activity = Arc::clone(&state.activity);
let h = std::thread::spawn(move || {
run_fetch(&config, entry, Some(&store), &refetch_queue, &activity)
});
(name, h)
})
.collect();
for (name, h) in handles {
match h.join() {
Ok(outcome) => {
clear_activity(&state.activity, &outcome.name);
// Write rotated tokens back into the shared list so later ticks use them.
if let Some((new_access, new_refresh)) = &outcome.rotated
&& let Ok(mut t) = state.tokens.lock()
&& let Some(entry) = t.iter_mut().find(|e| e.name == outcome.name)
{
entry.access_token = new_access.clone();
entry.refresh_token = new_refresh.clone();
}
apply_outcome(outcome, &state.store, &state.status, &state.last_fetched);
}
// The worker panicked: nothing to publish, just clear its activity.
Err(_) => {
clear_activity(&state.activity, &name);
}
}
}
// Auto-switch is evaluated only on ticks that fetched OAuth usage.
scan_auto_switch(
&state.config,
&state.store,
&state.activity,
&state.pending_switch,
&state.pending_switch_off,
);
}
if !tp_due.is_empty() {
fetch_third_party_due(state, tp_due);
}
// The fetches moved the last-fetched stamps, so recompute and republish the
// countdowns with a fresh clock reading.
if fetched {
let now = now_ms();
let (_, oauth_after) =
partition_due(&oauth_snapshot, now, &state.last_fetched, &state.activity);
let (_, tp_after) = partition_due(&tp_snapshot, now, &state.last_fetched, &state.activity);
publish_countdowns(&state.next_refresh_per_profile, oauth_after, tp_after);
}
}
/// Starts the background refresher thread.
///
/// The handles are bundled into a `SchedulerState` that is moved into a
/// detached thread, which then sleeps `TICK_INTERVAL` and runs `tick` forever.
// The long parameter list mirrors the fields of `SchedulerState` one-to-one.
#[allow(clippy::too_many_arguments)]
pub(crate) fn spawn_refresher(
    config: crate::profile::ConfigHandle,
    tokens: TokenList,
    store: UsageStore,
    status: StatusStore,
    next_refresh_per_profile: NextRefreshPerProfile,
    activity: ActivityStore,
    last_fetched: LastFetchedAt,
    pending_switch: PendingSwitch,
    pending_switch_off: PendingSwitchOff,
    refetch_queue: RefetchQueue,
    third_party_tokens: ThirdPartyList,
    third_party_usage_store: ThirdPartyUsageStore,
    third_party_status: ThirdPartyStatusStore,
) {
    let scheduler = SchedulerState {
        config,
        tokens,
        store,
        status,
        next_refresh_per_profile,
        activity,
        last_fetched,
        pending_switch,
        pending_switch_off,
        refetch_queue,
        third_party_tokens,
        third_party_usage_store,
        third_party_status,
    };

    // The join handle is dropped on purpose: the thread runs until the process exits.
    std::thread::spawn(move || loop {
        std::thread::sleep(TICK_INTERVAL);
        tick(&scheduler);
    });
}
/// Decides whether an automatic profile switch should be requested and, if
/// so, records the request.
///
/// The scan is skipped when a switch is already pending, the "switch off"
/// flag is already raised, any profile is currently `Switching`, or any of
/// the involved locks cannot be taken. Otherwise the fallback chain is
/// snapshotted from the config and the chosen action is recorded in
/// `pending_switch` or `pending_switch_off`. Each lock is released before the
/// next one is taken.
fn scan_auto_switch(
    config: &crate::profile::ConfigHandle,
    store: &UsageStore,
    activity: &ActivityStore,
    pending_switch: &PendingSwitch,
    pending_switch_off: &PendingSwitchOff,
) {
    let switch_queued = match pending_switch.lock() {
        Ok(pending) => !pending.is_empty(),
        Err(_) => return,
    };
    if switch_queued {
        return;
    }

    let off_queued = match pending_switch_off.lock() {
        Ok(flag) => *flag,
        Err(_) => return,
    };
    if off_queued {
        return;
    }

    let switch_running = match activity.lock() {
        Ok(map) => map
            .values()
            .any(|v| matches!(v, ProfileActivity::Switching)),
        Err(_) => return,
    };
    if switch_running {
        return;
    }

    // Keep the config lock scoped to the snapshot so it is released before
    // the usage store is consulted.
    let chain = {
        let Ok(cfg) = config.lock() else {
            return;
        };
        crate::fallback::snapshot_chain(&cfg)
    };
    let Some(chain) = chain else {
        return;
    };

    let Some(action) = crate::fallback::next_auto_switch_target(&chain, store) else {
        return;
    };
    match action {
        crate::fallback::SwitchAction::To(name) => {
            if let Ok(mut pending) = pending_switch.lock() {
                pending.insert(name);
            }
        }
        crate::fallback::SwitchAction::Off => {
            if let Ok(mut flag) = pending_switch_off.lock() {
                *flag = true;
            }
        }
    }
}
/// Splits `snapshot` into the entries due at `now` and computes every
/// entry's next-refresh time.
///
/// An entry's next refresh is its last-fetched stamp (zero if never fetched)
/// plus `REFRESH_INTERVAL_MS`. The returned map contains that time for every
/// entry. The returned vector contains only entries whose time has come and
/// that are not currently `Switching` or `Refreshing`. If the activity lock
/// cannot be taken nothing is due; if the last-fetched lock cannot be taken
/// both results are empty.
fn partition_due<T: NamedEntry + Clone>(
    snapshot: &[T],
    now: u64,
    last_fetched: &LastFetchedAt,
    activity: &ActivityStore,
) -> (Vec<T>, HashMap<String, u64>) {
    let now = EpochMs::from_millis(now);
    let Ok(stamps) = last_fetched.lock() else {
        return (Vec::new(), HashMap::new());
    };
    // Taken once and held for the whole loop, together with `stamps`.
    let activities = activity.lock();
    let interval = IntervalMs::from_millis(REFRESH_INTERVAL_MS);

    let mut due = Vec::new();
    let mut schedule = HashMap::with_capacity(snapshot.len());
    for entry in snapshot {
        let name = entry.name();
        let next = stamps
            .get(name)
            .copied()
            .unwrap_or_default()
            .saturating_add(interval);
        schedule.insert(name.to_string(), next.as_millis());

        // Unknown activity (lock failure) is treated as "busy".
        let blocked = activities.as_ref().map_or(true, |map| {
            matches!(
                map.get(name),
                Some(ProfileActivity::Switching | ProfileActivity::Refreshing)
            )
        });
        if !blocked && now >= next {
            due.push(entry.clone());
        }
    }
    (due, schedule)
}
/// Adds forced refetches to the due list.
///
/// Every snapshot entry whose name is in `forced` is appended to `due` and
/// its next-refresh time set to `now`, unless it is already due or currently
/// `Switching`/`Refreshing`. If the activity lock cannot be taken, every
/// snapshot entry is treated as blocked and nothing is added.
fn merge_forced<T: NamedEntry + Clone>(
    snapshot: &[T],
    forced: &HashSet<String>,
    due: &mut Vec<T>,
    per_profile_next: &mut HashMap<String, u64>,
    activity: &ActivityStore,
    now: u64,
) {
    if forced.is_empty() {
        return;
    }

    let blocked: HashSet<String> = match activity.lock() {
        Ok(map) => map
            .iter()
            .filter(|(_, state)| {
                matches!(
                    state,
                    ProfileActivity::Switching | ProfileActivity::Refreshing
                )
            })
            .map(|(profile, _)| profile.clone())
            .collect(),
        Err(_) => snapshot.iter().map(|e| e.name().to_string()).collect(),
    };

    // Collected separately so the "already due" check sees only the original list.
    let mut extras: Vec<T> = Vec::with_capacity(forced.len());
    for entry in snapshot {
        let name = entry.name();
        if !forced.contains(name) || blocked.contains(name) {
            continue;
        }
        if due.iter().any(|d| d.name() == name) {
            continue;
        }
        per_profile_next.insert(name.to_string(), now);
        extras.push(entry.clone());
    }
    due.extend(extras);
}
// Unit tests live in a separate file, pulled in via `#[path]` and compiled
// only for test builds.
#[cfg(test)]
#[path = "../../tests/inline/scheduler.rs"]
mod tests;