use std::ops::RangeInclusive;
use std::time::Duration;
use reqwest::Url;
/// Configuration values for the token client.
///
/// Note: this struct holds secrets (`client_secret`, `token_encryption_key`);
/// avoid adding a derived `Debug` that would print them.
#[derive(Clone)]
pub struct Config {
    /// URL of the token endpoint.
    pub token_url: Url,
    /// Caller identifier.
    pub caller: String,
    /// Audience value.
    pub audience: String,
    /// Which cache backend to use (`Redis` with a connection string, or `Inmemory`).
    pub cache_type: CacheType,
    /// Key used for token encryption (the test fixture uses a 32-character value).
    pub token_encryption_key: String,
    /// Interval between checks.
    pub check_interval: Duration,
    /// Range of fractions from which a staleness-check percentage is drawn.
    pub staleness_check_percentage: StalenessCheckPercentage,
    /// Client identifier.
    pub client_id: String,
    /// Client secret.
    pub client_secret: String,
    /// URL of the JWKS endpoint.
    pub jwks_url: Url,
    /// Optional scope; `None` when not configured.
    pub scope: Option<String>,
}
impl Config {
    /// Returns the configured `token_url`.
    pub fn token_url(&self) -> &Url {
        &self.token_url
    }

    /// Returns the configured `caller`.
    pub fn caller(&self) -> &str {
        &self.caller
    }

    /// Returns the configured `audience`.
    pub fn audience(&self) -> &str {
        &self.audience
    }

    /// Returns the configured cache backend selection.
    pub fn cache_type(&self) -> &CacheType {
        &self.cache_type
    }

    /// Returns the configured `token_encryption_key`.
    pub fn token_encryption_key(&self) -> &str {
        &self.token_encryption_key
    }

    /// Returns the configured `check_interval`.
    pub fn check_interval(&self) -> &Duration {
        &self.check_interval
    }

    /// Returns the configured `client_id`.
    pub fn client_id(&self) -> &str {
        &self.client_id
    }

    /// Returns the configured `client_secret`.
    pub fn client_secret(&self) -> &str {
        &self.client_secret
    }

    /// Returns the configured `jwks_url`.
    pub fn jwks_url(&self) -> &Url {
        &self.jwks_url
    }

    /// Returns the configured `scope`, borrowed as `&str`, or `None` if unset.
    ///
    /// Added so every field has a matching accessor.
    pub fn scope(&self) -> Option<&str> {
        self.scope.as_deref()
    }

    /// Returns the configured staleness-check percentage range.
    pub fn staleness_check_percentage(&self) -> &StalenessCheckPercentage {
        &self.staleness_check_percentage
    }

    /// Returns `true` when the cache backend is [`CacheType::Inmemory`].
    pub fn is_inmemory_cache(&self) -> bool {
        // Pattern test rather than `==`: avoids building a temporary value and
        // does not depend on `CacheType: PartialEq`.
        matches!(self.cache_type, CacheType::Inmemory)
    }

    /// Test-only fixture: a `Config` whose `token_url` and `jwks_url` point at
    /// `<server>/token` and `<server>/jwks` on the given mockito server, with
    /// fixed placeholder values for the remaining fields and an in-memory cache.
    #[cfg(test)]
    pub fn test_config(server: &mockito::Server) -> Config {
        use std::str::FromStr;
        Config {
            token_url: Url::from_str(&format!("{}/{}", server.url().as_str(), "token")).unwrap(),
            jwks_url: Url::from_str(&format!("{}/{}", server.url().as_str(), "jwks")).unwrap(),
            caller: "caller".to_string(),
            audience: "audience".to_string(),
            cache_type: CacheType::Inmemory,
            token_encryption_key: "32char_long_token_encryption_key".to_string(),
            check_interval: Duration::from_secs(10),
            staleness_check_percentage: StalenessCheckPercentage::default(),
            client_id: "client_id".to_string(),
            client_secret: "client_secret".to_string(),
            scope: None,
        }
    }
}
/// Cache backend selection used by `Config`.
///
/// Not `Debug` on purpose-neutral grounds is not assumed here; note that the
/// `Redis` payload is a connection string, which may embed credentials.
#[derive(Clone, Eq, PartialEq)]
pub enum CacheType {
    /// Redis-backed cache; the `String` is the connection string.
    Redis(String),
    /// In-memory cache (carries no connection string).
    Inmemory,
}

impl CacheType {
    /// Returns the Redis connection string, or `None` for [`CacheType::Inmemory`].
    ///
    /// Non-panicking alternative to [`CacheType::redis_connection_url`].
    pub fn redis_url(&self) -> Option<&str> {
        match self {
            CacheType::Redis(url) => Some(url.as_str()),
            CacheType::Inmemory => None,
        }
    }

    /// Returns the Redis connection string.
    ///
    /// # Panics
    ///
    /// Panics if called on [`CacheType::Inmemory`]. Use [`CacheType::redis_url`]
    /// to query without panicking.
    pub fn redis_connection_url(&self) -> &str {
        self.redis_url()
            .expect("Something went wrong getting Redis connection string")
    }
}
/// Inclusive range of fractions from which a staleness-check percentage is drawn.
///
/// The inner range is private. Outside this module the value can only be built
/// through `StalenessCheckPercentage::new`, `Default`, or `From<RangeInclusive<f64>>`.
#[derive(Clone)]
pub struct StalenessCheckPercentage(
    // Lower and upper bounds, both inclusive.
    RangeInclusive<f64>,
);
impl StalenessCheckPercentage {
    /// Creates the inclusive range `min..=max` of fractions.
    ///
    /// # Panics
    ///
    /// Panics if `min` or `max` lies outside `[0.0, 1.0]` (NaN is rejected as
    /// well), or if `min > max`. The panic message includes the offending values.
    pub fn new(min: f64, max: f64) -> Self {
        assert!(
            (0.0..=1.0).contains(&min),
            "staleness check min must be within [0.0, 1.0], got {}",
            min
        );
        assert!(
            (0.0..=1.0).contains(&max),
            "staleness check max must be within [0.0, 1.0], got {}",
            max
        );
        assert!(
            min <= max,
            "staleness check min ({}) must not exceed max ({})",
            min,
            max
        );
        Self(min..=max)
    }

    /// Returns a value sampled from the inclusive range via `rand::thread_rng()`.
    ///
    /// Each call draws a fresh value. The constructors keep `start <= end`, which
    /// `gen_range` requires for a non-empty range.
    pub fn random_value_between(&self) -> f64 {
        use rand::Rng;
        rand::thread_rng().gen_range(self.0.clone())
    }
}
impl Default for StalenessCheckPercentage {
fn default() -> Self {
Self(0.6..=0.9)
}
}
impl From<RangeInclusive<f64>> for StalenessCheckPercentage {
    /// Builds a percentage range from `start..=end` via `StalenessCheckPercentage::new`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as `new`: a bound outside `[0.0, 1.0]`,
    /// or `start > end`.
    fn from(range: RangeInclusive<f64>) -> Self {
        let (lower, upper) = range.into_inner();
        Self::new(lower, upper)
    }
}