tokkit 0.17.0

A simple(simplistic) OAUTH toolkit.
Documentation
use std::collections::BTreeMap;
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::time::{Duration, UNIX_EPOCH};

mod request_scheduler;
mod token_updater;

use super::*;
use crate::token_manager::token_provider::AccessTokenProvider;

pub type EpochMillis = u64;

pub fn initialize<
    T: Eq + Ord + Send + Sync + Clone + Display + 'static,
    C: Clock + Clone + Send + 'static,
>(
    groups: Vec<ManagedTokenGroup<T>>,
    clock: C,
) -> (Inner<T>, mpsc::Sender<ManagerCommand<T>>) {
    let tokens = Arc::new(create_tokens(&groups));
    let rows = create_rows(groups, clock.now());

    let (tx, rx) = mpsc::channel::<ManagerCommand<T>>();

    let is_running = Arc::new(AtomicBool::new(true));

    let inner = Inner { tokens, is_running };

    start(rows, inner.clone(), tx.clone(), rx, clock);

    (inner, tx)
}

fn create_rows<T: Clone>(
    groups: Vec<ManagedTokenGroup<T>>,
    now: EpochMillis,
) -> Vec<Mutex<TokenRow<T>>> {
    let mut states = Vec::new();
    for group in groups {
        for managed_token in group.managed_tokens {
            states.push(Mutex::new(TokenRow {
                token_id: managed_token.token_id.clone(),
                scopes: managed_token.scopes,
                refresh_threshold: group.refresh_threshold,
                warning_threshold: group.warning_threshold,
                last_touched: now,
                refresh_at: now,
                warn_at: now,
                expires_at: now,
                scheduled_for: now,
                token_state: TokenState::Uninitialized,
                last_notification_at: None,
                token_provider: group.token_provider.clone(),
            }));
        }
    }
    states
}

fn create_tokens<T: Eq + Ord + Clone + Display>(
    groups: &[ManagedTokenGroup<T>],
) -> BTreeMap<T, (usize, Mutex<StdResult<AccessToken, TokenErrorKind>>)> {
    let mut tokens: BTreeMap<T, (usize, Mutex<StdResult<AccessToken, TokenErrorKind>>)> =
        Default::default();
    let mut idx = 0;
    for group in groups {
        for managed_token in &group.managed_tokens {
            tokens.insert(
                managed_token.token_id.clone(),
                (
                    idx,
                    Mutex::new(Err(TokenErrorKind::NotInitialized(
                        managed_token.token_id.to_string(),
                    ))),
                ),
            );
            idx += 1;
        }
    }
    tokens
}

fn start<
    T: Eq + Ord + Send + Sync + Clone + Display + 'static,
    C: Clock + Clone + Send + 'static,
>(
    rows: Vec<Mutex<TokenRow<T>>>,
    inner: Inner<T>,
    sender: mpsc::Sender<ManagerCommand<T>>,
    receiver: mpsc::Receiver<ManagerCommand<T>>,
    clock: C,
) {
    let rows1 = Arc::new(rows);
    let rows2 = rows1.clone();
    let inner1 = inner.clone();
    let clock1 = clock.clone();
    thread::spawn(move || {
        let scheduler = request_scheduler::RefreshScheduler::new(
            &*rows1,
            &sender,
            500,
            10_000,
            &inner1.is_running,
            &clock1,
        );
        scheduler.start();
    });
    thread::spawn(move || {
        let token_updater = token_updater::TokenUpdater::new(
            &*rows2,
            &inner.tokens,
            receiver,
            &inner.is_running,
            &clock,
        );
        token_updater.start();
    });
}

#[derive(Clone)]
pub struct Inner<T> {
    pub tokens: Arc<BTreeMap<T, (usize, Mutex<StdResult<AccessToken, TokenErrorKind>>)>>,
    pub is_running: Arc<AtomicBool>,
}

impl<T: Eq + Ord + Clone + Display> Inner<T> {
    pub fn get_access_token(&self, token_id: &T) -> TokenResult<AccessToken> {
        match self.tokens.get(&token_id) {
            Some((_, guard)) => match &*guard.lock().unwrap() {
                Ok(token) => Ok(token.clone()),
                Err(err) => Err(err.clone().into()),
            },
            None => Err(TokenErrorKind::NoToken(token_id.to_string()).into()),
        }
    }
}

#[derive(PartialEq, Eq, Debug)]
pub enum TokenState {
    Uninitialized,
    Initializing,
    Ok,
    OkPending,
    Error,
    ErrorPending,
}

impl TokenState {
    pub fn is_refresh_pending(&self) -> bool {
        match *self {
            TokenState::Initializing | TokenState::OkPending | TokenState::ErrorPending => true,
            _ => false,
        }
    }

    pub fn is_uninitialized(&self) -> bool {
        match *self {
            TokenState::Uninitialized | TokenState::Initializing => true,
            _ => false,
        }
    }
}

pub struct TokenRow<T> {
    token_id: T,
    scopes: Vec<Scope>,
    refresh_threshold: f32,
    warning_threshold: f32,
    last_touched: EpochMillis,
    refresh_at: EpochMillis,
    warn_at: EpochMillis,
    expires_at: EpochMillis,
    scheduled_for: EpochMillis,
    token_state: TokenState,
    last_notification_at: Option<EpochMillis>,
    token_provider: Arc<dyn AccessTokenProvider + Send + Sync + 'static>,
}

#[derive(Debug, PartialEq)]
pub enum ManagerCommand<T> {
    ScheduledRefresh(usize, u64),
    ForceRefresh(T, u64),
    RefreshOnError(usize, u64),
}

pub trait Clock {
    fn now(&self) -> EpochMillis;
}

pub struct SystemClock;

impl Clock for SystemClock {
    fn now(&self) -> EpochMillis {
        millis_from_duration(UNIX_EPOCH.elapsed().unwrap())
    }
}

impl Clone for SystemClock {
    fn clone(&self) -> Self
    where
        Self: Sized,
    {
        SystemClock
    }
}

fn diff_millis(start_millis: u64, end_millis: u64) -> u64 {
    if start_millis > end_millis {
        0
    } else {
        end_millis - start_millis
    }
}

fn minus_millis(from: u64, subtract: u64) -> u64 {
    if subtract > from {
        0
    } else {
        from - subtract
    }
}

fn elapsed_millis_from(start_millis: u64, clock: &dyn Clock) -> u64 {
    diff_millis(start_millis, clock.now())
}

fn millis_from_duration(d: Duration) -> u64 {
    (d.as_secs() * 1000) + d.subsec_millis() as u64
}