use std::convert::TryFrom;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use chrono::{DateTime, Duration, Utc};
use hmac::{Hmac, Mac};
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use crate::error::{Error, ErrorCode};
use crate::rest::RestInner;
use crate::{http, rest, Result};
/// Maximum accepted length, in bytes, of a token string returned by
/// `Auth::request_token` (128 KiB); longer tokens are rejected with a 401 error.
const MAX_TOKEN_LENGTH: usize = 128 * 1024;
mod duration {
use std::fmt;
use super::*;
use serde::{de, Deserializer, Serializer};
#[derive(Debug)]
pub struct MilliSecondsTimestampVisitor;
impl<'de> de::Visitor<'de> for MilliSecondsTimestampVisitor {
type Value = Duration;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a duration in milliseconds")
}
fn visit_i64<E>(self, value: i64) -> std::result::Result<Self::Value, E>
where
E: de::Error,
{
Ok(Duration::milliseconds(value))
}
}
pub fn deserialize<'de, D>(d: D) -> std::result::Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
d.deserialize_u64(MilliSecondsTimestampVisitor)
}
pub fn serialize<S>(d: &Duration, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
let n = d.num_milliseconds();
serializer.serialize_i64(n)
}
}
/// A source of authentication for the client.
///
/// `Auth::request_token` turns each variant into `TokenDetails`:
/// - `TokenDetails` is used as-is.
/// - `TokenRequest` is exchanged with the REST API.
/// - `Callback` is invoked with the requested `TokenParams`.
/// - `Key` signs the params and exchanges the result.
/// - `Url` is fetched as an auth URL.
#[derive(Clone)]
pub enum Credential {
    TokenDetails(TokenDetails),
    TokenRequest(TokenRequest),
    Callback(Arc<dyn AuthCallback>),
    Key(Key),
    Url(reqwest::Url),
}

impl std::fmt::Debug for Credential {
    /// Formats as a single-field tuple named after the variant.
    ///
    /// The callback itself is not `Debug`, so it is shown as the placeholder `"Fn"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = |name: &str, inner: &dyn std::fmt::Debug| {
            f.debug_tuple(name).field(inner).finish()
        };
        match self {
            Self::TokenDetails(details) => tuple("TokenDetails", details),
            Self::TokenRequest(request) => tuple("TokenRequest", request),
            Self::Key(key) => tuple("Key", key),
            Self::Callback(_) => tuple("Callback", &"Fn"),
            Self::Url(url) => tuple("Url", url),
        }
    }
}
/// Options controlling how `Auth` obtains a token.
#[derive(Debug, Clone, Default)]
pub struct AuthOptions {
/// Credential used by `Auth::request_token` and `Auth::create_token_request`;
/// `None` makes `request_token` fail with `NoWayToRenewAuthToken`.
pub token: Option<Credential>,
/// NOTE(review): `headers`, `method` and `params` are not read by any code in
/// this file (e.g. `request_url` does not use them) — confirm they are used elsewhere.
pub headers: Option<http::HeaderMap>,
pub method: http::Method,
pub params: Option<http::UrlQuery>,
}
/// An API key, made up of a public `name` and a secret `value`.
///
/// `Debug` is implemented by hand rather than derived, so that the secret
/// `value` is never written to logs or error output (e.g. via `{:?}` of a
/// `Credential` or `AuthOptions`). Only `name` is shown.
#[derive(Clone, Deserialize, PartialEq, Eq)]
pub struct Key {
    /// Key name; when parsed by `Key::new`, the text before the first `:`.
    /// Deserialized from the `keyName` field.
    #[serde(rename(deserialize = "keyName"))]
    pub name: String,
    /// Key secret; when parsed by `Key::new`, the text after the first `:`.
    /// Used as the HMAC key when signing token requests.
    pub value: String,
}

impl std::fmt::Debug for Key {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Key")
            .field("name", &self.name)
            .field("value", &"<redacted>")
            .finish()
    }
}
impl Key {
pub fn new(s: &str) -> Result<Self> {
if let [name, value] = s.splitn(2, ':').collect::<Vec<&str>>()[..] {
Ok(Key {
name: name.to_string(),
value: value.to_string(),
})
} else {
Err(Error::new(ErrorCode::BadRequest, "Invalid key"))
}
}
}
impl TryFrom<&str> for Key {
type Error = Error;
fn try_from(s: &str) -> Result<Self> {
Self::new(s)
}
}
impl Key {
pub fn sign(&self, params: &TokenParams) -> Result<TokenRequest> {
params.sign(self)
}
}
/// Authentication operations (token requests, signing, auth headers) performed
/// through a borrowed REST client.
#[derive(Clone, Debug)]
pub struct Auth<'a> {
/// REST client used to send token requests and to read client options.
pub(crate) rest: &'a rest::Rest,
}
impl<'a> Auth<'a> {
pub fn new(rest: &'a rest::Rest) -> Self {
Self { rest }
}
fn inner(&self) -> &RestInner {
&self.rest.inner
}
pub fn create_token_request(
&self,
params: &TokenParams,
options: &AuthOptions,
) -> Result<TokenRequest> {
let key = match &options.token {
Some(Credential::Key(k)) => k,
_ => {
return Err(Error::new(
ErrorCode::UnableToObtainCredentialsFromGivenParameters,
"API key is required to create signed token requests",
))
}
};
params.sign(key)
}
pub(crate) fn exchange(
&self,
req: &TokenRequest,
) -> Pin<Box<dyn Future<Output = Result<TokenDetails>> + Send + 'a>> {
let req = self
.rest
.request(
http::Method::POST,
&format!("/keys/{}/requestToken", req.key_name),
)
.authenticate(false)
.body(req);
Box::pin(async move { req.send().await?.body().await.map_err(Into::into) })
}
fn request_url<'b>(
&'b self,
url: &'b reqwest::Url,
) -> Pin<Box<dyn Future<Output = Result<TokenDetails>> + Send + 'b>> {
let fut = async move {
let res = self
.rest
.request_url(Default::default(), url.clone())
.authenticate(false)
.send()
.await?;
let content_type = res.content_type().ok_or_else(|| {
Error::new(
ErrorCode::ErrorFromClientTokenCallback,
"authUrl response is missing a content-type header",
)
})?;
match content_type.essence_str() {
"application/json" => {
let token: RequestOrDetails = res.json().await?;
match token {
RequestOrDetails::Request(r) => self.exchange(&r).await,
RequestOrDetails::Details(d) => Ok(d),
}
},
"text/plain" | "application/jwt" => {
let token = res.text().await?;
Ok(TokenDetails::from(token))
},
_ => Err(Error::new(ErrorCode::ErrorFromClientTokenCallback, format!("authUrl responded with unacceptable content-type {}, should be either text/plain, application/jwt or application/json", content_type))),
}
};
Box::pin(fut)
}
pub async fn request_token(
&self,
params: &TokenParams,
options: &AuthOptions,
) -> Result<TokenDetails> {
let token = options.token.as_ref().ok_or_else(|| {
Error::new(
ErrorCode::NoWayToRenewAuthToken,
"no means provided to renew auth token",
)
})?;
let mut details = match token {
Credential::TokenDetails(token) => Ok(token.clone()),
Credential::TokenRequest(r) => self.exchange(r).await,
Credential::Callback(f) => match f.token(params).await {
Ok(token) => token.into_details(self).await,
Err(e) => Err(e),
},
Credential::Key(k) => self.exchange(¶ms.sign(k)?).await,
Credential::Url(url) => self.request_url(url).await,
};
if matches!(token, Credential::Callback(_) | Credential::Url(_)) {
if let Err(ref mut err) = details {
if err.code == ErrorCode::BadRequest {
err.code = ErrorCode::ErrorFromClientTokenCallback;
err.status_code = Some(401);
}
};
}
let details = details?;
if details.token.len() > MAX_TOKEN_LENGTH {
return Err(Error::with_status(
ErrorCode::ErrorFromClientTokenCallback,
401,
format!(
"Token string exceeded max permitted length (was {} bytes)",
details.token.len()
),
));
}
Ok(details)
}
pub(crate) async fn with_auth_headers(&self, req: &mut reqwest::Request) -> Result<()> {
if let Credential::Key(k) = &self.inner().opts.credential {
return Self::set_basic_auth(req, k);
}
let options = AuthOptions {
token: Some(self.inner().opts.credential.clone()),
..Default::default()
};
let res = self.request_token(&Default::default(), &options).await?;
Self::set_bearer_auth(req, &res.token)
}
fn set_bearer_auth(req: &mut reqwest::Request, token: &str) -> Result<()> {
Self::set_header(
req,
reqwest::header::AUTHORIZATION,
format!("Bearer {}", token),
)
}
fn set_basic_auth(req: &mut reqwest::Request, key: &Key) -> Result<()> {
let encoded = base64::encode(format!("{}:{}", key.name, key.value));
Self::set_header(
req,
reqwest::header::AUTHORIZATION,
format!("Basic {}", encoded),
)
}
fn set_header(req: &mut reqwest::Request, key: http::HeaderName, value: String) -> Result<()> {
req.headers_mut().append(key, value.parse()?);
Ok(())
}
fn generate_nonce() -> String {
thread_rng()
.sample_iter(&Alphanumeric)
.take(16)
.map(char::from)
.collect()
}
fn compute_mac(
key: &Key,
ttl: Duration,
capability: &str,
client_id: Option<&str>,
timestamp: DateTime<Utc>,
nonce: &str,
) -> Result<String> {
let mut mac = Hmac::<Sha256>::new_from_slice(key.value.as_bytes())?;
mac.update(key.name.as_bytes());
mac.update(b"\n");
mac.update(ttl.num_milliseconds().to_string().as_bytes());
mac.update(b"\n");
mac.update(capability.as_bytes());
mac.update(b"\n");
mac.update(client_id.map(|c| c.as_bytes()).unwrap_or_default());
mac.update(b"\n");
mac.update(timestamp.timestamp_millis().to_string().as_bytes());
mac.update(b"\n");
mac.update(nonce.as_bytes());
mac.update(b"\n");
Ok(base64::encode(mac.finalize().into_bytes()))
}
}
/// Parameters for a requested token, signed into a `TokenRequest` by `Key::sign`.
#[derive(Clone, Debug)]
pub struct TokenParams {
/// Capability JSON string; defaults to `{"*":["*"]}`.
pub capability: String,
/// Client ID to bind the token to; `Some("")` is rejected when signing.
pub client_id: Option<String>,
/// Nonce to sign; when `None`, a random 16-character nonce is generated at signing time.
pub nonce: Option<String>,
/// Timestamp to sign; when `None`, `Utc::now()` is used at signing time.
pub timestamp: Option<DateTime<Utc>>,
/// Requested token lifetime; defaults to 60 minutes.
pub ttl: Duration,
}
impl Default for TokenParams {
    /// Returns the defaults:
    /// - capability `{"*":["*"]}`
    /// - a 60-minute TTL
    /// - no client ID, nonce or timestamp
    fn default() -> Self {
        let ttl = Duration::minutes(60);
        Self {
            capability: String::from(r#"{"*":["*"]}"#),
            client_id: None,
            nonce: None,
            timestamp: None,
            ttl,
        }
    }
}
impl TokenParams {
    /// Same as [`TokenParams::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the capability JSON string.
    pub fn capability(self, capability: &str) -> Self {
        Self {
            capability: capability.to_string(),
            ..self
        }
    }

    /// Sets the client ID the token will be bound to.
    pub fn client_id(self, client_id: &str) -> Self {
        Self {
            client_id: Some(client_id.to_string()),
            ..self
        }
    }

    /// Sets the requested token lifetime.
    pub fn ttl(self, ttl: Duration) -> Self {
        Self { ttl, ..self }
    }

    /// Sets an explicit signing timestamp instead of the current time.
    pub fn timestamp(self, timestamp: DateTime<Utc>) -> Self {
        Self {
            timestamp: Some(timestamp),
            ..self
        }
    }

    /// Signs these params with `key`, producing a `TokenRequest`.
    ///
    /// A missing nonce is generated randomly, and a missing timestamp defaults to
    /// the current time. Both are filled in before the MAC is computed.
    ///
    /// # Errors
    ///
    /// `InvalidClientID` if the client ID is present but empty. Errors from MAC
    /// computation are propagated.
    fn sign(&self, key: &Key) -> Result<TokenRequest> {
        if let Some("") = self.client_id.as_deref() {
            return Err(Error::new(
                ErrorCode::InvalidClientID,
                "client_id can’t be an empty string",
            ));
        }

        let nonce = match &self.nonce {
            Some(given) => given.clone(),
            None => Auth::generate_nonce(),
        };
        let timestamp = match self.timestamp {
            Some(given) => given,
            None => Utc::now(),
        };

        let mac = Auth::compute_mac(
            key,
            self.ttl,
            &self.capability,
            self.client_id.as_deref(),
            timestamp,
            &nonce,
        )?;

        Ok(TokenRequest {
            key_name: key.name.clone(),
            timestamp,
            capability: self.capability.clone(),
            client_id: self.client_id.clone(),
            mac,
            nonce,
            ttl: self.ttl,
        })
    }
}
/// A signed token request, as produced by `TokenParams::sign`.
///
/// It is sent by `Auth::exchange` to `/keys/{key_name}/requestToken`, with
/// camelCase field names.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct TokenRequest {
/// Name of the key whose secret produced `mac`.
pub key_name: String,
/// Signed timestamp, (de)serialized as milliseconds since the Unix epoch.
#[serde(with = "chrono::serde::ts_milliseconds")]
pub timestamp: DateTime<Utc>,
/// Signed capability JSON string.
pub capability: String,
/// Signed client ID; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub client_id: Option<String>,
/// Base64 HMAC-SHA256 signature, computed by `Auth::compute_mac`.
pub mac: String,
/// Signed nonce.
pub nonce: String,
/// Requested token lifetime, (de)serialized as integer milliseconds.
#[serde(with = "duration")]
pub ttl: Duration,
}
/// A token string, optionally with metadata about it.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TokenDetails {
/// The token itself, sent as `Authorization: Bearer <token>`.
pub token: String,
/// `TokenMetadata` fields read from the same JSON object as `token`
/// (`#[serde(flatten)]`). Absent (`None`) when built from a bare token string.
#[serde(flatten)]
pub metadata: Option<TokenMetadata>,
}
impl TokenDetails {
pub fn token(s: String) -> Self {
Self {
token: s,
metadata: None,
}
}
}
impl From<String> for TokenDetails {
fn from(token: String) -> Self {
TokenDetails {
token,
metadata: None,
}
}
}
/// Metadata accompanying a token, read alongside `token` in `TokenDetails`.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TokenMetadata {
/// Expiry time, deserialized from milliseconds since the Unix epoch.
#[serde(with = "chrono::serde::ts_milliseconds")]
pub expires: DateTime<Utc>,
/// Issue time, deserialized from milliseconds since the Unix epoch.
#[serde(with = "chrono::serde::ts_milliseconds")]
pub issued: DateTime<Utc>,
/// Capability JSON string granted to the token.
pub capability: String,
/// Client ID the token is bound to, if any.
// NOTE(review): `skip_serializing_if` has no effect here because this type only
// derives `Deserialize`.
#[serde(skip_serializing_if = "Option::is_none")]
pub client_id: Option<String>,
}
/// Either a token request still to be exchanged, or ready-to-use token details.
///
/// This is what an `AuthCallback` returns, and what a JSON auth URL response is
/// decoded into. Because it is `untagged`, deserialization tries `Request`
/// first, then `Details`.
#[derive(Clone, Debug, Deserialize)]
#[serde(untagged)]
pub enum RequestOrDetails {
Request(TokenRequest),
Details(TokenDetails),
}
impl RequestOrDetails {
    /// Resolves to `TokenDetails`.
    ///
    /// Details are returned unchanged. A token request is first exchanged
    /// through `auth`.
    async fn into_details(self, auth: &Auth<'_>) -> Result<TokenDetails> {
        match self {
            Self::Details(details) => Ok(details),
            Self::Request(request) => auth.exchange(&request).await,
        }
    }
}
/// A user-supplied token source, used through `Credential::Callback`.
pub trait AuthCallback: Send + Sync {
/// Produces a token for `params`.
///
/// `Auth::request_token` handles the result as follows:
/// - A returned `TokenRequest` is exchanged for `TokenDetails`.
/// - Returned `TokenDetails` are used as-is.
/// - A `BadRequest` error, from the callback or from the exchange, is
///   re-reported as `ErrorFromClientTokenCallback` with status 401.
fn token<'a>(
&'a self,
params: &'a TokenParams,
) -> Pin<Box<dyn Send + Future<Output = Result<RequestOrDetails>> + 'a>>;
}