use std::env;
use std::io::Read;
use std::str;
use std::sync::Arc;
use std::time::Duration;
use backoff::{Error as BackoffError, ExponentialBackoff, Operation};
use failure::ResultExt;
use reqwest::{StatusCode, Url};
use reqwest::blocking::{Client, Response};
use url::ParseError;
use crate::parsers::*;
use crate::{AccessToken, InitializationError, InitializationResult, TokenInfo};
use crate::{TokenInfoError, TokenInfoErrorKind, TokenInfoResult, TokenInfoService};
#[cfg(feature = "async")]
use crate::async_client::AsyncTokenInfoServiceClientLight;
#[cfg(feature = "metrix")]
use crate::metrics::metrix::MetrixCollector;
#[cfg(feature = "async")]
use crate::metrics::{DevNullMetricsCollector, MetricsCollector};
#[cfg(feature = "metrix")]
use metrix::processor::{AggregatesProcessors, ProcessorMount};
/// Builder for a [`TokenInfoServiceClient`] (and, with the `async` feature,
/// for the asynchronous light client).
///
/// `parser` and `endpoint` are mandatory: the build methods return an
/// `InitializationError` when either of them is unset.
pub struct TokenInfoServiceClientBuilder<P: TokenInfoParser> {
    /// Parser that turns a successful introspection response body into a `TokenInfo`.
    pub parser: Option<P>,
    /// Address of the token introspection endpoint.
    pub endpoint: Option<String>,
    /// Name of the query parameter that carries the token. When unset, the
    /// blocking client appends the token to the endpoint as a path segment.
    pub query_parameter: Option<String>,
    /// Optional second endpoint that the blocking client tries when a request
    /// to `endpoint` fails (except for `Client` errors).
    pub fallback_endpoint: Option<String>,
}
impl<P> TokenInfoServiceClientBuilder<P>
where
    P: TokenInfoParser + Clone + Sync + Send + 'static,
{
    /// Creates a builder whose parser is already set to `parser`.
    pub fn new(parser: P) -> Self {
        let mut fresh = Self::default();
        fresh.with_parser(parser);
        fresh
    }

    /// Sets the response parser, replacing any previously configured one.
    pub fn with_parser(&mut self, parser: P) -> &mut Self {
        self.parser = Some(parser);
        self
    }

    /// Sets the introspection endpoint, replacing any previously configured one.
    pub fn with_endpoint<T: Into<String>>(&mut self, endpoint: T) -> &mut Self {
        let endpoint: String = endpoint.into();
        self.endpoint = Some(endpoint);
        self
    }

    /// Sets the fallback endpoint, replacing any previously configured one.
    pub fn with_fallback_endpoint<T: Into<String>>(&mut self, endpoint: T) -> &mut Self {
        let endpoint: String = endpoint.into();
        self.fallback_endpoint = Some(endpoint);
        self
    }

    /// Sets the name of the query parameter that carries the token.
    pub fn with_query_parameter<T: Into<String>>(&mut self, parameter: T) -> &mut Self {
        let parameter: String = parameter.into();
        self.query_parameter = Some(parameter);
        self
    }

    /// Splits the builder into `(parser, endpoint, query_parameter, fallback_endpoint)`.
    ///
    /// The parser is checked first ("No token info parser."), then the
    /// endpoint ("No endpoint.").
    fn into_parts(self) -> InitializationResult<(P, String, Option<String>, Option<String>)> {
        let parser = match self.parser {
            Some(parser) => parser,
            None => return Err(InitializationError("No token info parser.".into())),
        };
        let endpoint = match self.endpoint {
            Some(endpoint) => endpoint,
            None => return Err(InitializationError("No endpoint.".into())),
        };
        Ok((parser, endpoint, self.query_parameter, self.fallback_endpoint))
    }

    /// Builds a blocking [`TokenInfoServiceClient`].
    ///
    /// # Errors
    ///
    /// Fails when the parser or endpoint is missing, or when
    /// `TokenInfoServiceClient::new` rejects the configuration.
    pub fn build(self) -> InitializationResult<TokenInfoServiceClient> {
        let (parser, endpoint, query_parameter, fallback_endpoint) = self.into_parts()?;
        TokenInfoServiceClient::new::<P>(
            &endpoint,
            query_parameter.as_ref().map(String::as_str),
            fallback_endpoint.as_ref().map(String::as_str),
            parser,
        )
    }

    /// Builds an asynchronous client that uses `DevNullMetricsCollector`.
    #[cfg(feature = "async")]
    pub fn build_async(
        self,
    ) -> InitializationResult<AsyncTokenInfoServiceClientLight<P, DevNullMetricsCollector>> {
        let collector = DevNullMetricsCollector;
        self.build_async_with_metrics(collector)
    }

    /// Builds an asynchronous client that reports to `metrics_collector`.
    ///
    /// # Errors
    ///
    /// Fails when the parser or endpoint is missing, or when the async client
    /// rejects the configuration.
    #[cfg(feature = "async")]
    pub fn build_async_with_metrics<M>(
        self,
        metrics_collector: M,
    ) -> InitializationResult<AsyncTokenInfoServiceClientLight<P, M>>
    where
        M: MetricsCollector + Clone + Send + 'static,
    {
        let (parser, endpoint, query_parameter, fallback_endpoint) = self.into_parts()?;
        AsyncTokenInfoServiceClientLight::with_metrics(
            &endpoint,
            query_parameter.as_ref().map(String::as_str),
            fallback_endpoint.as_ref().map(String::as_str),
            parser,
            metrics_collector,
        )
    }

    /// Builds an asynchronous client whose metrics go into `takes_metrics`.
    ///
    /// With a `group_name`, the collector is placed in a new `ProcessorMount`
    /// of that name, which is then added to `takes_metrics`; otherwise the
    /// collector is created directly on `takes_metrics`.
    #[cfg(all(feature = "async", feature = "metrix"))]
    pub fn build_async_with_metrix<M, T>(
        self,
        takes_metrics: &mut M,
        group_name: Option<T>,
    ) -> InitializationResult<AsyncTokenInfoServiceClientLight<P, MetrixCollector>>
    where
        M: AggregatesProcessors,
        T: Into<String>,
    {
        let metrics_collector = match group_name {
            None => MetrixCollector::new(takes_metrics),
            Some(group) => {
                let mut mount = ProcessorMount::new(group);
                let collector = MetrixCollector::new(&mut mount);
                takes_metrics.add_processor(mount);
                collector
            }
        };
        self.build_async_with_metrics(metrics_collector)
    }

    /// Creates a builder from environment variables; the parser is left unset.
    ///
    /// * `TOKKIT_TOKEN_INTROSPECTION_ENDPOINT` (required)
    /// * `TOKKIT_TOKEN_INTROSPECTION_QUERY_PARAMETER` (optional)
    /// * `TOKKIT_TOKEN_INTROSPECTION_FALLBACK_ENDPOINT` (optional)
    ///
    /// # Errors
    ///
    /// Fails when the endpoint variable cannot be read, or when an optional
    /// variable is present but cannot be read (e.g. not valid Unicode).
    pub fn from_env() -> InitializationResult<Self> {
        let endpoint = match env::var("TOKKIT_TOKEN_INTROSPECTION_ENDPOINT") {
            Ok(value) => value,
            Err(err) => {
                return Err(InitializationError(format!(
                    "'TOKKIT_TOKEN_INTROSPECTION_ENDPOINT': {}",
                    err
                )))
            }
        };
        // An absent optional variable is fine; any other read error is not.
        let query_parameter = env::var("TOKKIT_TOKEN_INTROSPECTION_QUERY_PARAMETER")
            .map(Some)
            .or_else(|err| match err {
                env::VarError::NotPresent => Ok(None),
                other => Err(InitializationError(format!(
                    "'TOKKIT_TOKEN_INTROSPECTION_QUERY_PARAMETER': {}",
                    other
                ))),
            })?;
        let fallback_endpoint = env::var("TOKKIT_TOKEN_INTROSPECTION_FALLBACK_ENDPOINT")
            .map(Some)
            .or_else(|err| match err {
                env::VarError::NotPresent => Ok(None),
                other => Err(InitializationError(format!(
                    "'TOKKIT_TOKEN_INTROSPECTION_FALLBACK_ENDPOINT': {}",
                    other
                ))),
            })?;
        Ok(Self {
            parser: None,
            endpoint: Some(endpoint),
            query_parameter,
            fallback_endpoint,
        })
    }
}
impl TokenInfoServiceClientBuilder<PlanBTokenInfoParser> {
    /// Creates a Plan B builder for `endpoint`; the token is sent in the
    /// `access_token` query parameter.
    pub fn plan_b(endpoint: String) -> TokenInfoServiceClientBuilder<PlanBTokenInfoParser> {
        let mut plan_b = Self::default();
        plan_b
            .with_parser(PlanBTokenInfoParser)
            .with_endpoint(endpoint)
            .with_query_parameter("access_token");
        plan_b
    }

    /// Reads the configuration from the environment (see `from_env`), then sets
    /// the Plan B parser and the `access_token` query parameter. The latter
    /// overrides any query parameter taken from the environment.
    pub fn plan_b_from_env(
    ) -> InitializationResult<TokenInfoServiceClientBuilder<PlanBTokenInfoParser>> {
        Self::from_env().map(|mut env_builder| {
            env_builder
                .with_parser(PlanBTokenInfoParser)
                .with_query_parameter("access_token");
            env_builder
        })
    }
}
impl TokenInfoServiceClientBuilder<GoogleV3TokenInfoParser> {
    /// Creates a builder preconfigured for
    /// `https://www.googleapis.com/oauth2/v3/tokeninfo`, sending the token in
    /// the `access_token` query parameter.
    pub fn google_v3() -> TokenInfoServiceClientBuilder<GoogleV3TokenInfoParser> {
        let mut google = Self::default();
        google
            .with_parser(GoogleV3TokenInfoParser)
            .with_endpoint("https://www.googleapis.com/oauth2/v3/tokeninfo")
            .with_query_parameter("access_token");
        google
    }
}
impl TokenInfoServiceClientBuilder<AmazonTokenInfoParser> {
    /// Creates a builder preconfigured for
    /// `https://api.amazon.com/auth/O2/tokeninfo`, sending the token in the
    /// `access_token` query parameter.
    pub fn amazon() -> TokenInfoServiceClientBuilder<AmazonTokenInfoParser> {
        let mut amazon = Self::default();
        amazon
            .with_parser(AmazonTokenInfoParser)
            .with_endpoint("https://api.amazon.com/auth/O2/tokeninfo")
            .with_query_parameter("access_token");
        amazon
    }
}
impl<P: TokenInfoParser> Default for TokenInfoServiceClientBuilder<P> {
    /// An empty builder: no parser, endpoint, query parameter or fallback.
    ///
    /// Written by hand because `#[derive(Default)]` would add a `P: Default` bound.
    fn default() -> Self {
        Self {
            parser: None,
            endpoint: None,
            query_parameter: None,
            fallback_endpoint: None,
        }
    }
}
/// Blocking token introspection client with an optional fallback endpoint.
///
/// The URL prefixes and the parser are held in `Arc`s, so clones share them.
pub struct TokenInfoServiceClient {
    /// Prefix for the primary endpoint; the access token is appended to it.
    url_prefix: Arc<String>,
    /// Prefix for the fallback endpoint, if one was configured.
    fallback_url_prefix: Option<Arc<String>>,
    /// Blocking HTTP client used for every request.
    http_client: Client,
    /// Parser for `200 OK` response bodies.
    parser: Arc<dyn TokenInfoParser + Sync + Send + 'static>,
}
impl TokenInfoServiceClient {
pub fn new<P>(
endpoint: &str,
query_parameter: Option<&str>,
fallback_endpoint: Option<&str>,
parser: P,
) -> InitializationResult<TokenInfoServiceClient>
where
P: TokenInfoParser + Sync + Send + 'static,
{
let url_prefix = assemble_url_prefix(endpoint, &query_parameter)
.map_err(InitializationError)?;
let fallback_url_prefix = if let Some(fallback_endpoint_address) = fallback_endpoint {
Some(
assemble_url_prefix(fallback_endpoint_address, &query_parameter)
.map_err(InitializationError)?,
)
} else {
None
};
let client = Client::new();
Ok(TokenInfoServiceClient {
url_prefix: Arc::new(url_prefix),
fallback_url_prefix: fallback_url_prefix.map(Arc::new),
http_client: client,
parser: Arc::new(parser),
})
}
}
/// Builds the URL prefix to which an access token is later appended.
///
/// * With a `query_parameter`, the prefix ends in `<param>=`:
///   - For an endpoint without a query string, a trailing `/` is dropped and
///     `?<param>=` is appended (`https://host/info` becomes
///     `https://host/info?access_token=`).
///   - For an endpoint that already has a query string
///     (`https://host/info?realm=x`), the parameter is added with `&`, or with
///     nothing after a trailing `?`/`&`. This avoids a second `?` that would
///     fold the token into the preceding parameter's value.
/// * Without a `query_parameter`, the token becomes the last path segment, so
///   a trailing `/` is ensured.
///
/// # Errors
///
/// Returns `"Invalid URL: ..."` when the prefix plus a placeholder token does
/// not parse as a URL.
pub(crate) fn assemble_url_prefix(
    endpoint: &str,
    query_parameter: &Option<&str>,
) -> ::std::result::Result<String, String> {
    let mut url_prefix = String::from(endpoint);
    if let Some(parameter) = *query_parameter {
        if !url_prefix.contains('?') {
            if url_prefix.ends_with('/') {
                url_prefix.pop();
            }
            url_prefix.push('?');
        } else if !url_prefix.ends_with('?') && !url_prefix.ends_with('&') {
            url_prefix.push('&');
        }
        url_prefix.push_str(parameter);
        url_prefix.push('=');
    } else if !url_prefix.ends_with('/') {
        url_prefix.push('/');
    }
    // Validate the final shape once, up front, using a placeholder token.
    let test_url = format!("{}test_token", url_prefix);
    let _ = test_url
        .parse::<Url>()
        .map_err(|err| format!("Invalid URL: {}", err))?;
    Ok(url_prefix)
}
impl TokenInfoService for TokenInfoServiceClient {
    /// Introspects `token` at the primary endpoint, using the fallback endpoint
    /// (if configured) as a second chance.
    ///
    /// Both request URLs are built before any request is sent, so a token that
    /// cannot form a valid URL fails without network traffic.
    fn introspect(&self, token: &AccessToken) -> TokenInfoResult<TokenInfo> {
        let primary = complete_url(&self.url_prefix, token)?;
        let fallback = self
            .fallback_url_prefix
            .as_ref()
            .map(|prefix| complete_url(prefix, token))
            .transpose()?;
        get_with_fallback(primary, fallback, &self.http_client, &*self.parser)
    }
}
impl Clone for TokenInfoServiceClient {
fn clone(&self) -> Self {
TokenInfoServiceClient {
url_prefix: self.url_prefix.clone(),
fallback_url_prefix: self.fallback_url_prefix.clone(),
http_client: self.http_client.clone(),
parser: self.parser.clone(),
}
}
}
/// Builds the request URL for `token` by appending it to `url_prefix`.
///
/// The token is untrusted input, so it is percent-encoded. Only RFC 3986
/// "unreserved" characters (`A-Z a-z 0-9 - . _ ~`) pass through unchanged.
/// Without encoding:
///
/// * `&` in a token would start an extra query parameter;
/// * `#` would cut the token off as a URL fragment;
/// * `+` would typically be decoded as a space by the server.
///
/// The encoding is valid whether the prefix ends in `?param=` or in `/`.
///
/// # Errors
///
/// Returns a `UrlError` if the resulting string is not a valid URL.
fn complete_url(url_prefix: &str, token: &AccessToken) -> TokenInfoResult<Url> {
    let raw_token: &str = token.0.as_ref();
    let mut url_str = String::with_capacity(url_prefix.len() + raw_token.len());
    url_str.push_str(url_prefix);
    push_percent_encoded(&mut url_str, raw_token);
    let url: Url = url_str.parse()?;
    Ok(url)
}

/// Appends `raw` to `out`, replacing every byte outside the RFC 3986
/// unreserved set with `%XX` (upper-case hex).
fn push_percent_encoded(out: &mut String, raw: &str) {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    for &byte in raw.as_bytes() {
        let unreserved = byte.is_ascii_alphanumeric()
            || byte == b'-'
            || byte == b'.'
            || byte == b'_'
            || byte == b'~';
        if unreserved {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
    }
}
/// Introspects against `url`; if that fails, tries `fallback_url` when one is given.
///
/// The primary error is returned unchanged when there is no fallback, or when
/// its kind is `TokenInfoErrorKind::Client`. Otherwise the fallback's outcome
/// (success or error) is returned.
///
/// NOTE(review): `NotAuthenticated` (HTTP 401) is not excluded, so a rejected
/// token is re-checked against the fallback. Confirm this is intended.
fn get_with_fallback(
    url: Url,
    fallback_url: Option<Url>,
    client: &Client,
    parser: &dyn TokenInfoParser,
) -> TokenInfoResult<TokenInfo> {
    let primary_err = match get_from_remote(url, client, parser) {
        Ok(token_info) => return Ok(token_info),
        Err(err) => err,
    };
    let is_client_error = match *primary_err.kind() {
        TokenInfoErrorKind::Client(_) => true,
        _ => false,
    };
    if is_client_error {
        return Err(primary_err);
    }
    match fallback_url {
        Some(fallback) => get_from_remote(fallback, client, parser),
        None => Err(primary_err),
    }
}
fn get_from_remote<P>(
url: Url,
http_client: &Client,
parser: &P,
) -> TokenInfoResult<TokenInfo>
where
P: TokenInfoParser + ?Sized,
{
let mut op = || match get_from_remote_no_retry(url.clone(), http_client, parser) {
Ok(token_info) => Ok(token_info),
Err(err) => match *err.kind() {
TokenInfoErrorKind::InvalidResponseContent(_) => Err(BackoffError::Permanent(err)),
TokenInfoErrorKind::UrlError(_) => Err(BackoffError::Permanent(err)),
TokenInfoErrorKind::NotAuthenticated(_) => Err(BackoffError::Permanent(err)),
TokenInfoErrorKind::Client(_) => Err(BackoffError::Permanent(err)),
_ => Err(BackoffError::Transient(err)),
},
};
let mut backoff = ExponentialBackoff::default();
backoff.max_elapsed_time = Some(Duration::from_millis(200));
backoff.initial_interval = Duration::from_millis(10);
backoff.multiplier = 1.5;
let notify = |err, _| {
warn!("Retry on token info service: {}", err);
};
let retry_result = op.retry_notify(&mut backoff, notify);
match retry_result {
Ok(token_info) => Ok(token_info),
Err(BackoffError::Transient(err)) => Err(err),
Err(BackoffError::Permanent(err)) => Err(err),
}
}
/// Sends a single GET request to `url` and interprets the response.
///
/// A failure to send the request becomes a `Connection` error; otherwise the
/// response is handed to `process_response`.
fn get_from_remote_no_retry<P>(
    url: Url,
    http_client: &Client,
    parser: &P,
) -> TokenInfoResult<TokenInfo>
where
    P: TokenInfoParser + ?Sized,
{
    let mut response = match http_client.get(url).send() {
        Ok(response) => response,
        Err(send_err) => {
            return Err(TokenInfoErrorKind::Connection(send_err.to_string()).into())
        }
    };
    process_response(&mut response, parser)
}
/// Turns a received HTTP `response` into a `TokenInfo` or a classified error.
///
/// The whole body is read first. A `200 OK` body is handed to `parser`; a
/// parse failure becomes `InvalidResponseContent`. Every other status maps to
/// an error kind, which callers use to decide about retries and fallbacks:
///
/// * `401` -> `NotAuthenticated`
/// * other `4xx` -> `Client`
/// * `5xx` -> `Server`
/// * anything else -> `Other`
///
/// # Errors
///
/// Besides the mappings above, a failure while reading the body is an `Io` error.
fn process_response<P>(
    response: &mut Response,
    parser: &P,
) -> TokenInfoResult<TokenInfo>
where
    P: TokenInfoParser + ?Sized,
{
    let mut body = Vec::new();
    response
        .read_to_end(&mut body)
        .context(TokenInfoErrorKind::Io(
            "Could not read response body".to_string(),
        ))?;

    let status = response.status();
    if status == StatusCode::OK {
        return match parser.parse(&body) {
            Ok(info) => Ok(info),
            Err(msg) => Err(TokenInfoErrorKind::InvalidResponseContent(msg.to_string()).into()),
        };
    }

    // Decode lossily: the status code, not the body's encoding, must decide
    // the error kind. A strict decode would turn e.g. a retryable 5xx with a
    // non-UTF-8 body into a permanent `InvalidResponseContent` error.
    let msg = String::from_utf8_lossy(&body);
    let kind = if status == StatusCode::UNAUTHORIZED {
        TokenInfoErrorKind::NotAuthenticated(format!("The server refused the token: {}", msg))
    } else if status.is_client_error() {
        TokenInfoErrorKind::Client(msg.into_owned())
    } else if status.is_server_error() {
        TokenInfoErrorKind::Server(msg.into_owned())
    } else {
        TokenInfoErrorKind::Other(msg.into_owned())
    };
    Err(kind.into())
}
impl From<ParseError> for TokenInfoError {
    /// Wraps a URL parse failure as `TokenInfoErrorKind::UrlError`, keeping its message.
    fn from(err: ParseError) -> Self {
        let message = err.to_string();
        TokenInfoErrorKind::UrlError(message).into()
    }
}
impl From<str::Utf8Error> for TokenInfoError {
    /// Wraps a UTF-8 decoding failure as
    /// `TokenInfoErrorKind::InvalidResponseContent`, keeping its message.
    fn from(err: str::Utf8Error) -> Self {
        let message = err.to_string();
        TokenInfoErrorKind::InvalidResponseContent(message).into()
    }
}