use std::collections::BTreeMap;
use std::env;
use std::fmt::Display;
use std::result::Result as StdResult;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::Sender;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};
use crate::{AccessToken, Scope};
mod error;
mod internals;
pub mod token_provider;
pub use self::error::*;
use self::token_provider::*;
use super::{InitializationError, InitializationResult};
/// Builder for a [`ManagedToken`]: collects the token identifier and the scopes to request.
///
/// `token_id` is mandatory; `build` fails if it was never set. `scopes` may stay empty.
pub struct ManagedTokenBuilder<T> {
/// Identifier under which the token is managed and later looked up.
pub token_id: Option<T>,
/// Scopes to request for the token.
pub scopes: Vec<Scope>,
}
impl<T: Eq + Send + Clone + Display> ManagedTokenBuilder<T> {
pub fn with_identifier(&mut self, token_id: T) -> &mut Self {
self.token_id = Some(token_id);
self
}
pub fn with_scope(&mut self, scope: Scope) -> &mut Self {
self.scopes.push(scope);
self
}
pub fn with_scopes(&mut self, scopes: Vec<Scope>) -> &mut Self {
for scope in scopes {
self.scopes.push(scope);
}
self
}
pub fn with_scopes_from_env(&mut self) -> StdResult<&mut Self, InitializationError> {
self.with_scopes_from_selected_env_var("TOKKIT_MANAGED_TOKEN_SCOPES")
}
pub fn with_scopes_from_selected_env_var(
&mut self,
env_name: &str,
) -> StdResult<&mut Self, InitializationError> {
match env::var(env_name) {
Ok(v) => {
let scopes = split_scopes(&v);
self.with_scopes(scopes)
}
Err(err) => return Err(InitializationError(err.to_string())),
};
Ok(self)
}
pub fn build(self) -> StdResult<ManagedToken<T>, InitializationError> {
let token_id = if let Some(token_id) = self.token_id {
token_id
} else {
return Err(InitializationError("Token name is mandatory".to_string()));
};
Ok(ManagedToken {
token_id,
scopes: self.scopes,
})
}
}
/// Splits a whitespace-separated list of scope names into [`Scope`]s.
///
/// Any run of whitespace (spaces, tabs, newlines) separates two scopes, and leading or
/// trailing whitespace is ignored. A value such as `"read write\n"`, as often read from an
/// environment variable or file, yields `read` and `write` rather than a scope with a
/// trailing newline. Empty or all-whitespace input yields no scopes.
fn split_scopes(input: &str) -> Vec<Scope> {
    input.split_whitespace().map(Scope::new).collect()
}
impl ManagedTokenBuilder<String> {
    /// Sets the identifier from the `TOKKIT_MANAGED_TOKEN_ID` environment variable.
    ///
    /// # Errors
    ///
    /// Fails if the variable is not set or its value is not valid unicode.
    pub fn with_id_from_env(&mut self) -> StdResult<&mut Self, InitializationError> {
        self.with_id_from_selected_env_var("TOKKIT_MANAGED_TOKEN_ID")
    }

    /// Sets the identifier to the value of the environment variable `env_name`.
    ///
    /// # Errors
    ///
    /// Fails with the `env::var` error message if the variable is missing or not unicode;
    /// the previously set identifier (if any) is kept in that case.
    pub fn with_id_from_selected_env_var(
        &mut self,
        env_name: &str,
    ) -> StdResult<&mut Self, InitializationError> {
        let id = env::var(env_name).map_err(|err| InitializationError(err.to_string()))?;
        self.token_id = Some(id);
        Ok(self)
    }
}
impl<T: Eq + Send + Clone + Display> Default for ManagedTokenBuilder<T> {
fn default() -> Self {
ManagedTokenBuilder {
token_id: Default::default(),
scopes: Default::default(),
}
}
}
/// A token to be managed: its identifier and the scopes requested for it.
pub struct ManagedToken<T> {
/// Identifier of the token. `AccessTokenManager::start` rejects identifiers that occur
/// more than once across all groups.
pub token_id: T,
/// Scopes requested for the token.
pub scopes: Vec<Scope>,
}
/// Builder for a [`ManagedTokenGroup`]: tokens that share one token provider and the same
/// refresh/warning thresholds.
pub struct ManagedTokenGroupBuilder<T, S: AccessTokenProvider + 'static> {
// Mandatory; `build` fails if it is `None`.
token_provider: Option<Arc<S>>,
// `build` fails if this is empty.
managed_tokens: Vec<ManagedToken<T>>,
// `build` rejects values <= 0 or > 1; `Default` uses 0.75.
// NOTE(review): presumably a fraction of the token lifetime, interpreted in `internals`.
refresh_threshold: f32,
// `build` rejects values <= 0 or > 1; `Default` uses 0.85.
warning_threshold: f32,
}
impl<T: Eq + Send + Clone + Display, S: AccessTokenProvider + Send + Sync + 'static>
    ManagedTokenGroupBuilder<T, S>
{
    /// Sets the provider used to fetch all tokens of the group. Mandatory.
    pub fn with_token_provider(&mut self, token_provider: S) -> &mut Self {
        self.token_provider = Some(Arc::new(token_provider));
        self
    }

    /// Adds a token to the group.
    pub fn with_managed_token(&mut self, managed_token: ManagedToken<T>) -> &mut Self {
        self.managed_tokens.push(managed_token);
        self
    }

    /// Sets the refresh threshold; `build` requires it to be in `(0, 1]`.
    pub fn with_refresh_threshold(&mut self, refresh_threshold: f32) -> &mut Self {
        self.refresh_threshold = refresh_threshold;
        self
    }

    /// Sets the warning threshold; `build` requires it to be in `(0, 1]`.
    pub fn with_warning_threshold(&mut self, warning_threshold: f32) -> &mut Self {
        // Must target `warning_threshold`: writing `refresh_threshold` here would drop the
        // caller's warning setting and clobber the refresh setting.
        self.warning_threshold = warning_threshold;
        self
    }

    /// Builds a token from `builder` and adds it to the group.
    ///
    /// # Errors
    ///
    /// Propagates the error of `ManagedTokenBuilder::build` (e.g. a missing identifier).
    pub fn with_managed_token_from_builder(
        &mut self,
        builder: ManagedTokenBuilder<T>,
    ) -> StdResult<&mut Self, InitializationError> {
        let managed_token = builder.build()?;
        Ok(self.with_managed_token(managed_token))
    }

    /// Creates a builder for a group holding exactly one token with the given scopes.
    pub fn single_token(token_id: T, scopes: Vec<Scope>, token_provider: S) -> Self {
        let managed_token = ManagedToken { token_id, scopes };
        let mut builder = Self::default();
        builder.with_managed_token(managed_token);
        builder.with_token_provider(token_provider);
        builder
    }

    /// Creates a builder for a group holding one token whose scopes are read from the
    /// `TOKKIT_MANAGED_TOKEN_SCOPES` environment variable.
    ///
    /// # Errors
    ///
    /// Fails if the environment variable cannot be read.
    pub fn single_token_from_env(
        token_id: T,
        token_provider: S,
    ) -> StdResult<Self, InitializationError> {
        let mut managed_token_builder = ManagedTokenBuilder::default();
        managed_token_builder
            .with_identifier(token_id)
            .with_scopes_from_env()?;
        let managed_token = managed_token_builder.build()?;
        let mut builder = Self::default();
        builder.with_managed_token(managed_token);
        builder.with_token_provider(token_provider);
        Ok(builder)
    }

    /// Validates the configuration and creates the [`ManagedTokenGroup`].
    ///
    /// # Errors
    ///
    /// Fails if no token provider was set, if no tokens were added, or if either threshold
    /// is not in `(0, 1]` (NaN included).
    pub fn build(self) -> StdResult<ManagedTokenGroup<T>, InitializationError> {
        let token_provider = self
            .token_provider
            .ok_or_else(|| InitializationError("Token service is mandatory".to_string()))?;
        if self.managed_tokens.is_empty() {
            return Err(InitializationError(
                "Managed Tokens must not be empty".to_string(),
            ));
        }
        if !Self::is_valid_threshold(self.refresh_threshold) {
            return Err(InitializationError(
                "Refresh threshold must be of (0;1]".to_string(),
            ));
        }
        if !Self::is_valid_threshold(self.warning_threshold) {
            return Err(InitializationError(
                "Warning threshold must be of (0;1]".to_string(),
            ));
        }
        Ok(ManagedTokenGroup {
            token_provider,
            managed_tokens: self.managed_tokens,
            refresh_threshold: self.refresh_threshold,
            warning_threshold: self.warning_threshold,
        })
    }

    /// Returns `true` if `threshold` lies in `(0, 1]`.
    ///
    /// Written as a positive range check so that NaN is rejected: every comparison involving
    /// NaN is `false`, which the negated form `t <= 0.0 || t > 1.0` would let through.
    fn is_valid_threshold(threshold: f32) -> bool {
        threshold > 0.0 && threshold <= 1.0
    }
}
impl<T: Eq + Send + Clone + Display, S: AccessTokenProvider + 'static> Default
    for ManagedTokenGroupBuilder<T, S>
{
    /// An empty builder: no provider, no tokens, refresh threshold `0.75` and warning
    /// threshold `0.85`.
    fn default() -> Self {
        Self {
            token_provider: None,
            managed_tokens: Vec::new(),
            refresh_threshold: 0.75,
            warning_threshold: 0.85,
        }
    }
}
/// A group of tokens that share one token provider and the same thresholds.
///
/// Usually created via `ManagedTokenGroupBuilder::build`, which checks that the group has at
/// least one token and that neither threshold is `<= 0` or `> 1`. The fields are public, so a
/// directly constructed group bypasses those checks.
pub struct ManagedTokenGroup<T> {
/// Provider used to fetch every token of this group.
pub token_provider: Arc<dyn AccessTokenProvider + Send + Sync + 'static>,
/// The tokens of this group.
pub managed_tokens: Vec<ManagedToken<T>>,
/// NOTE(review): presumably the fraction of a token's lifetime after which it is
/// refreshed; interpreted in `internals` — confirm.
pub refresh_threshold: f32,
/// NOTE(review): presumably the fraction of a token's lifetime after which a warning is
/// emitted; interpreted in `internals` — confirm.
pub warning_threshold: f32,
}
/// Clears a shared "is running" flag when dropped.
///
/// Token sources hold it behind an `Arc`, so the flag becomes `false` once the last source
/// sharing this guard is dropped. NOTE(review): presumably the background loop in
/// `internals` polls this flag to stop — confirm.
struct IsRunningGuard {
is_running: Arc<AtomicBool>,
}
impl Default for IsRunningGuard {
    /// Creates a guard that owns a fresh flag, initially `true` (running).
    fn default() -> IsRunningGuard {
        let flag = AtomicBool::new(true);
        IsRunningGuard {
            is_running: Arc::new(flag),
        }
    }
}
impl Drop for IsRunningGuard {
fn drop(&mut self) {
self.is_running.store(false, Ordering::Relaxed);
}
}
/// Source of access tokens that are looked up by their identifier.
pub trait GivesAccessTokensById<T: Eq + Ord + Clone + Display> {
/// Returns the current token for `token_id`, or the error recorded for it.
///
/// The implementations in this module fail with `TokenErrorKind::NoToken` for an unknown
/// `token_id`.
fn get_access_token(&self, token_id: &T) -> TokenResult<AccessToken>;
/// Requests a refresh of the token `name`.
///
/// The implementations in this module only log a warning if the request cannot be sent.
fn refresh(&self, name: &T);
}
/// Cloneable handle to the tokens managed by an [`AccessTokenManager`].
///
/// All clones (and sources created from them via `synced`) share the same token storage
/// and the same `IsRunningGuard`; dropping the last of them clears the running flag.
#[derive(Clone)]
pub struct AccessTokenSource<T> {
// Token id -> (index, latest token or error). NOTE(review): the index presumably
// identifies the token inside `internals` — confirm.
tokens: Arc<BTreeMap<T, (usize, Mutex<StdResult<AccessToken, TokenErrorKind>>)>>,
// Command channel to the manager; used for refresh requests.
sender: Sender<internals::ManagerCommand<T>>,
is_running: Arc<IsRunningGuard>,
}
impl<T: Eq + Ord + Clone + Display> AccessTokenSource<T> {
pub fn single_source_for(&self, token_id: &T) -> TokenResult<FixedAccessTokenSource<T>> {
match self.tokens.get(token_id) {
Some(_) => Ok(FixedAccessTokenSource {
token_source: self.clone(),
token_id: token_id.clone(),
}),
None => Err(TokenErrorKind::NoToken(token_id.to_string()).into()),
}
}
pub fn single_source_sync_for(
&self,
token_id: &T,
) -> TokenResult<FixedAccessTokenSourceSync<T>> {
match self.tokens.get(token_id) {
Some(_) => Ok(FixedAccessTokenSourceSync {
token_source: self.synced(),
token_id: token_id.clone(),
}),
None => Err(TokenErrorKind::NoToken(token_id.to_string()).into()),
}
}
pub fn synced(&self) -> AccessTokenSourceSync<T> {
AccessTokenSourceSync {
tokens: self.tokens.clone(),
sender: Arc::new(Mutex::new(self.sender.clone())),
is_running: self.is_running.clone(),
}
}
pub fn new_detached(tokens: &[(T, AccessToken)]) -> AccessTokenSource<T> {
let mut tokens_map = BTreeMap::new();
for (i, (id, token)) in tokens.iter().enumerate() {
let item = (i, Mutex::new(Ok(token.clone())));
tokens_map.insert(id.clone(), item);
}
let (tx, _) = ::std::sync::mpsc::channel::<internals::ManagerCommand<T>>();
AccessTokenSource {
tokens: Arc::new(tokens_map),
is_running: Default::default(),
sender: tx,
}
}
}
impl<T: Eq + Ord + Clone + Display> GivesAccessTokensById<T> for AccessTokenSource<T> {
    /// Returns a clone of the latest token for `token_id`, or a clone of its stored error.
    ///
    /// # Errors
    ///
    /// `TokenErrorKind::NoToken` if `token_id` is unknown; otherwise the error stored for it.
    ///
    /// # Panics
    ///
    /// Panics if the token's mutex is poisoned.
    fn get_access_token(&self, token_id: &T) -> TokenResult<AccessToken> {
        match self.tokens.get(token_id) {
            Some((_, guard)) => match &*guard.lock().unwrap() {
                Ok(token) => Ok(token.clone()),
                Err(err) => Err(err.clone().into()),
            },
            None => Err(TokenErrorKind::NoToken(token_id.to_string()).into()),
        }
    }

    /// Asks the manager to refresh `name` now.
    ///
    /// If the command cannot be sent (the receiving side is gone, as for detached sources),
    /// a warning is logged and the request is otherwise ignored.
    fn refresh(&self, name: &T) {
        let command = internals::ManagerCommand::ForceRefresh(
            name.clone(),
            internals::Clock::now(&internals::SystemClock),
        );
        if let Err(err) = self.sender.send(command) {
            warn!("Could not send refresh command for {}: {}", name, err);
        }
    }
}
/// Variant of [`AccessTokenSource`] whose command sender is wrapped in an `Arc<Mutex<_>>`.
///
/// NOTE(review): presumably exists so the handle can be shared across threads (`Sender` was
/// not `Sync` on older Rust releases) — confirm. Shares tokens and the running guard with the
/// source it was created from.
#[derive(Clone)]
pub struct AccessTokenSourceSync<T> {
// Token id -> (index, latest token or error).
tokens: Arc<BTreeMap<T, (usize, Mutex<StdResult<AccessToken, TokenErrorKind>>)>>,
// Command channel to the manager; locked for each refresh request.
sender: Arc<Mutex<Sender<internals::ManagerCommand<T>>>>,
is_running: Arc<IsRunningGuard>,
}
impl<T: Eq + Ord + Clone + Display> AccessTokenSourceSync<T> {
pub fn single_source_sync_for(
&self,
token_id: &T,
) -> TokenResult<FixedAccessTokenSourceSync<T>> {
match self.tokens.get(token_id) {
Some(_) => Ok(FixedAccessTokenSourceSync {
token_source: self.clone(),
token_id: token_id.clone(),
}),
None => Err(TokenErrorKind::NoToken(token_id.to_string()).into()),
}
}
pub fn new_detached(tokens: &[(T, AccessToken)]) -> AccessTokenSourceSync<T> {
let mut tokens_map = BTreeMap::new();
for (i, (id, token)) in tokens.iter().enumerate() {
let item = (i, Mutex::new(Ok(token.clone())));
tokens_map.insert(id.clone(), item);
}
let (tx, _) = ::std::sync::mpsc::channel::<internals::ManagerCommand<T>>();
AccessTokenSourceSync {
tokens: Arc::new(tokens_map),
is_running: Default::default(),
sender: Arc::new(Mutex::new(tx)),
}
}
}
impl<T: Eq + Ord + Clone + Display> GivesAccessTokensById<T> for AccessTokenSourceSync<T> {
    /// Returns a clone of the latest token for `token_id`, or a clone of its stored error.
    ///
    /// # Errors
    ///
    /// `TokenErrorKind::NoToken` if `token_id` is unknown; otherwise the error stored for it.
    ///
    /// # Panics
    ///
    /// Panics if the token's mutex is poisoned.
    fn get_access_token(&self, token_id: &T) -> TokenResult<AccessToken> {
        match self.tokens.get(token_id) {
            Some((_, guard)) => match &*guard.lock().unwrap() {
                Ok(token) => Ok(token.clone()),
                Err(err) => Err(err.clone().into()),
            },
            None => Err(TokenErrorKind::NoToken(token_id.to_string()).into()),
        }
    }

    /// Asks the manager to refresh `name` now.
    ///
    /// If the command cannot be sent (the receiving side is gone, as for detached sources),
    /// a warning is logged and the request is otherwise ignored.
    ///
    /// # Panics
    ///
    /// Panics if the sender mutex is poisoned.
    fn refresh(&self, name: &T) {
        let command = internals::ManagerCommand::ForceRefresh(
            name.clone(),
            internals::Clock::now(&internals::SystemClock),
        );
        // The guard is a temporary of this statement, so the lock is held only for `send`,
        // not while building the command or logging.
        let sent = self.sender.lock().unwrap().send(command);
        if let Err(err) = sent {
            warn!("Could not send refresh command for {}: {}", name, err);
        }
    }
}
/// Source of access tokens that is bound to one token identifier.
pub trait GivesFixedAccessToken<T: Eq + Ord + Clone + Display> {
/// Returns the current token for the bound identifier.
fn get_access_token(&self) -> TokenResult<AccessToken>;
/// Requests a refresh of the bound token.
fn refresh(&self);
}
/// An [`AccessTokenSource`] bound to a single token identifier.
#[derive(Clone)]
pub struct FixedAccessTokenSource<T> {
token_source: AccessTokenSource<T>,
// Identifier every lookup and refresh is forwarded with.
token_id: T,
}
impl<T: Eq + Ord + Clone + Display> FixedAccessTokenSource<T> {
    /// Creates a fixed source serving `token` under `token_id`, not connected to a manager.
    ///
    /// Backed by `AccessTokenSource::new_detached`, so refresh requests cannot be delivered.
    pub fn new_detached(token_id: T, token: AccessToken) -> FixedAccessTokenSource<T> {
        let entries = [(token_id.clone(), token)];
        FixedAccessTokenSource {
            token_source: AccessTokenSource::new_detached(&entries),
            token_id,
        }
    }
}
impl<T: Eq + Ord + Clone + Display> GivesFixedAccessToken<T> for FixedAccessTokenSource<T> {
fn get_access_token(&self) -> TokenResult<AccessToken> {
self.token_source.get_access_token(&self.token_id)
}
fn refresh(&self) {
self.token_source.refresh(&self.token_id)
}
}
/// An [`AccessTokenSourceSync`] bound to a single token identifier.
#[derive(Clone)]
pub struct FixedAccessTokenSourceSync<T> {
token_source: AccessTokenSourceSync<T>,
// Identifier every lookup and refresh is forwarded with.
token_id: T,
}
impl<T: Eq + Ord + Clone + Display> FixedAccessTokenSourceSync<T> {
    /// Creates a fixed source serving `token` under `token_id`, not connected to a manager.
    ///
    /// Backed by `AccessTokenSourceSync::new_detached`, so refresh requests cannot be
    /// delivered.
    pub fn new_detached(token_id: T, token: AccessToken) -> FixedAccessTokenSourceSync<T> {
        let entries = [(token_id.clone(), token)];
        FixedAccessTokenSourceSync {
            token_source: AccessTokenSourceSync::new_detached(&entries),
            token_id,
        }
    }
}
impl<T: Eq + Ord + Clone + Display> GivesFixedAccessToken<T> for FixedAccessTokenSourceSync<T> {
fn get_access_token(&self) -> TokenResult<AccessToken> {
self.token_source.get_access_token(&self.token_id)
}
fn refresh(&self) {
self.token_source.refresh(&self.token_id)
}
}
/// Entry point that starts token management (via `internals::initialize`) for a set of
/// [`ManagedTokenGroup`]s and hands out an [`AccessTokenSource`] for the tokens.
pub struct AccessTokenManager;
impl AccessTokenManager {
pub fn start<T: Eq + Ord + Send + Sync + Clone + Display + 'static>(
groups: Vec<ManagedTokenGroup<T>>,
) -> InitializationResult<AccessTokenSource<T>> {
{
let mut seen = BTreeMap::default();
for group in &groups {
for managed_token in &group.managed_tokens {
let token_id = &managed_token.token_id;
if seen.contains_key(token_id) {
return Err(InitializationError(format!(
"Token id '{}' is used more than once.",
token_id
)));
} else {
seen.insert(token_id, ());
}
}
}
}
let (inner, sender) = internals::initialize(groups, internals::SystemClock);
Ok(AccessTokenSource {
tokens: inner.tokens,
sender,
is_running: Arc::new(IsRunningGuard {
is_running: inner.is_running,
}),
})
}
pub fn start_and_wait_for_tokens<T: Eq + Ord + Send + Sync + Clone + Display + 'static>(
groups: Vec<ManagedTokenGroup<T>>,
timeout_in: Duration,
) -> InitializationResult<AccessTokenSource<T>> {
{
let mut seen = BTreeMap::default();
for group in &groups {
for managed_token in &group.managed_tokens {
let token_id = &managed_token.token_id;
if seen.contains_key(token_id) {
return Err(InitializationError(format!(
"Token id '{}' is used more than once.",
token_id
)));
} else {
seen.insert(token_id, ());
}
}
}
}
let (inner, sender) = internals::initialize(groups, internals::SystemClock);
let start = Instant::now();
loop {
if start.elapsed() >= timeout_in {
return Err(InitializationError(
"Not all tokens were initialized within the \
given time."
.into(),
));
}
let all_initialized = inner.tokens.keys().all(|id| {
if let Err(token_error) = inner.get_access_token(id) {
if let TokenErrorKind::NotInitialized(_) = *token_error.kind() {
false
} else {
true
}
} else {
true
}
});
if all_initialized {
break;
}
::std::thread::sleep(Duration::from_millis(5));
}
Ok(AccessTokenSource {
tokens: inner.tokens,
sender,
is_running: Arc::new(IsRunningGuard {
is_running: inner.is_running,
}),
})
}
}