pub mod discovery;
pub mod error;
pub mod issuer;
pub mod token;
pub use crate::error::Error;
use biscuit::{Empty, SingleOrMultiple};
use biscuit::jwa::{self, SignatureAlgorithm};
use biscuit::jwk::{AlgorithmParameters, JWKSet};
use biscuit::jws::{Compact, Secret};
use chrono::{Duration, NaiveDate, Utc};
use inth_oauth2::token::Token as _t;
use reqwest::Url;
use serde_derive::{Deserialize, Serialize};
use validator::Validate;
use validator_derive::Validate;
use crate::discovery::{Config, Discovered};
use crate::error::{Decode, Expiry, Mismatch, Missing, Validation};
use crate::token::{Claims, Token};
/// An ID token in JWS compact serialization: `Claims` is the payload type and
/// `Empty` is the type used for the header's extra fields. A value may be held
/// encoded or decoded (`Compact::Decoded`); `Client::decode_token` turns the
/// former into the latter.
type IdToken = Compact<Claims, Empty>;
/// OpenID Connect relying-party client.
///
/// Pairs an OAuth 2.0 client, whose provider wraps the discovered `Config`,
/// with the provider's JSON Web Key Set that `decode_token` uses to pick the
/// key for an ID token.
pub struct Client {
/// Underlying OAuth 2.0 client; `config()` reads the provider configuration from it.
oauth: inth_oauth2::Client<Discovered>,
/// Provider keys (fetched from `jwks_uri` by `discover`) used by `decode_token`.
jwks: JWKSet<Empty>,
}
/// Expands to `Err(e.into())`, where `e` is an `error::Jose::WrongKeyType`
/// recording both arguments in their `Debug` form.
///
/// The conversion target is inferred from context, so the macro can be used as
/// `return wrong_key!(expected, actual);` or directly as a `match` arm value.
macro_rules! wrong_key {
    ($expected:expr, $actual:expr) => {{
        let mismatch = error::Jose::WrongKeyType {
            expected: format!("{:?}", $expected),
            actual: format!("{:?}", $actual),
        };
        Err(mismatch.into())
    }};
}
impl Client {
pub fn discover(id: String, secret: String, redirect: Url, issuer: Url) -> Result<Self, Error> {
discovery::secure(&redirect)?;
let client = reqwest::Client::new();
let config = discovery::discover(&client, issuer)?;
let jwks = discovery::jwks(&client, config.jwks_uri.clone())?;
let provider = Discovered(config);
Ok(Self::new(id, secret, redirect, provider, jwks))
}
pub fn new(id: String, secret:
String, redirect: Url, provider: Discovered, jwks: JWKSet<Empty>) -> Self {
Client {
oauth: inth_oauth2::Client::new(
provider,
id,
secret,
Some(redirect.into_string())),
jwks
}
}
pub fn redirect_url(&self) -> &str {
self.oauth.redirect_uri.as_ref().expect("We always require a redirect to construct client!")
}
pub fn request_token(&self,
client: &reqwest::Client,
auth_code: &str,
) -> Result<Token, Error> {
self.oauth.request_token(client, auth_code).map_err(Error::from)
}
pub fn config(&self) -> &Config {
&self.oauth.provider.0
}
pub fn auth_url(&self, options: &Options) -> Url {
let scope = match options.scope {
Some(ref scope) => {
if !scope.contains("openid") {
String::from("openid ") + scope
} else {
scope.clone()
}
}
None => String::from("openid")
};
let mut url = self.oauth.auth_uri(Some(&scope), options.state.as_ref().map(String::as_str));
{
let mut query = url.query_pairs_mut();
if let Some(ref nonce) = options.nonce {
query.append_pair("nonce", nonce.as_str());
}
if let Some(ref display) = options.display {
query.append_pair("display", display.as_str());
}
if let Some(ref prompt) = options.prompt {
let s = prompt.iter().map(|s| s.as_str()).collect::<Vec<_>>().join(" ");
query.append_pair("prompt", s.as_str());
}
if let Some(max_age) = options.max_age {
query.append_pair("max_age", max_age.num_seconds().to_string().as_str());
}
if let Some(ref ui_locales) = options.ui_locales {
query.append_pair("ui_locales", ui_locales.as_str());
}
if let Some(ref claims_locales) = options.claims_locales {
query.append_pair("claims_locales", claims_locales.as_str());
}
if let Some(ref id_token_hint) = options.id_token_hint {
query.append_pair("id_token_hint", id_token_hint.as_str());
}
if let Some(ref login_hint) = options.login_hint {
query.append_pair("login_hint", login_hint.as_str());
}
if let Some(ref acr_values) = options.acr_values {
query.append_pair("acr_values", acr_values.as_str());
}
}
url
}
pub fn authenticate(&self, auth_code: &str, nonce: Option<&str>, max_age: Option<&Duration>
) -> Result<Token, Error> {
let client = reqwest::Client::new();
let mut token = self.request_token(&client, auth_code)?;
self.decode_token(&mut token.id_token)?;
self.validate_token(&token.id_token, nonce, max_age)?;
Ok(token)
}
pub fn decode_token(&self, token: &mut IdToken) -> Result<(), Error> {
if let Compact::Decoded { .. } = *token {
return Ok(())
}
let header = token.unverified_header()?;
let key = if self.jwks.keys.len() > 1 {
let token_kid = header.registered.key_id.ok_or(Decode::MissingKid)?;
self.jwks.find(&token_kid).ok_or(Decode::MissingKey(token_kid))?
} else {
self.jwks.keys.first().as_ref().ok_or(Decode::EmptySet)?
};
if let Some(alg) = key.common.algorithm.as_ref() {
if let &jwa::Algorithm::Signature(sig) = alg {
if header.registered.algorithm != sig {
return wrong_key!(sig, header.registered.algorithm);
}
} else {
return wrong_key!(SignatureAlgorithm::default(), alg);
}
}
let alg = header.registered.algorithm;
match key.algorithm {
AlgorithmParameters::OctectKey { ref value, .. } => {
match alg {
SignatureAlgorithm::HS256 |
SignatureAlgorithm::HS384 |
SignatureAlgorithm::HS512 => {
*token = token.decode(&Secret::Bytes(value.clone()), alg)?;
Ok(())
}
_ => wrong_key!("HS256 | HS384 | HS512", alg)
}
}
AlgorithmParameters::RSA(ref params) => {
match alg {
SignatureAlgorithm::RS256 |
SignatureAlgorithm::RS384 |
SignatureAlgorithm::RS512 => {
let pkcs = Secret::RSAModulusExponent {
n: params.n.clone(),
e: params.e.clone(),
};
*token = token.decode(&pkcs, alg)?;
Ok(())
}
_ => wrong_key!("RS256 | RS384 | RS512", alg)
}
}
AlgorithmParameters::EllipticCurve(_) => unimplemented!("No support for EC keys yet"),
}
}
pub fn validate_token(
&self,
token: &IdToken,
nonce: Option<&str>,
max_age: Option<&Duration>
) -> Result<(), Error> {
let claims = token.payload()?;
if claims.iss != self.config().issuer {
let expected = self.config().issuer.as_str().to_string();
let actual = claims.iss.as_str().to_string();
return Err(Validation::Mismatch(Mismatch::Issuer { expected, actual }).into());
}
match nonce {
Some(expected) => match claims.nonce {
Some(ref actual) => {
if expected != actual {
let expected = expected.to_string();
let actual = actual.to_string();
return Err(Validation::Mismatch(
Mismatch::Nonce { expected, actual }).into());
}
}
None => return Err(Validation::Missing(Missing::Nonce).into()),
}
None => if claims.nonce.is_some() {
return Err(Validation::Missing(Missing::Nonce).into())
}
}
if !claims.aud.contains(&self.oauth.client_id) {
return Err(Validation::Missing(Missing::Audience).into());
}
if let SingleOrMultiple::Multiple(_) = claims.aud {
if let None = claims.azp {
return Err(Validation::Missing(Missing::AuthorizedParty).into());
}
}
if let Some(ref actual) = claims.azp {
if actual != &self.oauth.client_id {
let expected = self.oauth.client_id.to_string();
let actual = actual.to_string();
return Err(Validation::Mismatch(Mismatch::AuthorizedParty {
expected, actual
}).into());
}
}
let now = Utc::now();
if now.timestamp() < 1504758600 {
panic!("chrono::Utc::now() can never be before this was written!")
}
if claims.exp <= now.timestamp() {
return Err(Validation::Expired(
Expiry::Expires(
chrono::naive::NaiveDateTime::from_timestamp(claims.exp, 0))).into());
}
if let Some(max) = max_age {
match claims.auth_time {
Some(time) => {
let age = chrono::Duration::seconds(now.timestamp() - time);
if age >= *max {
return Err(error::Validation::Expired(Expiry::MaxAge(age)).into());
}
}
None => return Err(Validation::Missing(Missing::AuthTime).into()),
}
}
Ok(())
}
pub fn request_userinfo(&self, client: &reqwest::Client, token: &Token
) -> Result<Userinfo, Error> {
match self.config().userinfo_endpoint {
Some(ref url) => {
discovery::secure(&url)?;
let claims = token.id_token.payload()?;
let auth_code = token.access_token().to_string();
let mut resp = client.get(url.clone())
.header_011(reqwest::hyper_011::header::Authorization(reqwest::hyper_011::header::Bearer { token: auth_code }))
.send()?;
let info: Userinfo = resp.json()?;
if claims.sub != info.sub {
let expected = info.sub.clone();
let actual = claims.sub.clone();
return Err(error::Userinfo::MismatchSubject { expected, actual }.into())
}
Ok(info)
}
None => Err(error::Userinfo::NoUrl.into())
}
}
}
/// Optional parameters for `Client::auth_url`.
///
/// Fields left as `None` (the `Default`) are omitted from the authorization URL,
/// except `scope`, which then falls back to requesting just `openid`.
#[derive(Default)]
pub struct Options {
/// Requested scope values (space-separated); `auth_url` makes sure `openid` is included.
pub scope: Option<String>,
/// Passed to the OAuth client's `auth_uri` as its `state` argument.
pub state: Option<String>,
/// Sent as `nonce`; pass the same value to `authenticate`/`validate_token` to have it checked.
pub nonce: Option<String>,
/// Sent as `display`.
pub display: Option<Display>,
/// Sent as `prompt`, with the values joined by single spaces.
pub prompt: Option<std::collections::HashSet<Prompt>>,
/// Sent as `max_age` in whole seconds (`num_seconds`, so sub-second precision is dropped).
pub max_age: Option<Duration>,
/// Sent verbatim as `ui_locales`.
pub ui_locales: Option<String>,
/// Sent verbatim as `claims_locales`.
pub claims_locales: Option<String>,
/// Sent verbatim as `id_token_hint`.
pub id_token_hint: Option<String>,
/// Sent verbatim as `login_hint`.
pub login_hint: Option<String>,
/// Sent verbatim as `acr_values`.
pub acr_values: Option<String>,
}
/// Claims returned by the provider's UserInfo endpoint (see `Client::request_userinfo`).
///
/// Only `sub` is required for deserialization to succeed. Every other field has
/// `#[serde(default)]`, so an absent claim becomes `None`, or `false` for the two
/// `*_verified` flags.
#[derive(Debug, Deserialize, Serialize, Validate)]
pub struct Userinfo {
/// Subject identifier; `request_userinfo` rejects a response whose `sub` differs
/// from the ID token's `sub`.
pub sub: String,
#[serde(default)] pub name: Option<String>,
#[serde(default)] pub given_name: Option<String>,
#[serde(default)] pub family_name: Option<String>,
#[serde(default)] pub middle_name: Option<String>,
#[serde(default)] pub nickname: Option<String>,
#[serde(default)] pub preferred_username: Option<String>,
// The three URL claims below are (de)serialized through the `url_serde` adapter.
#[serde(default)] #[serde(with = "url_serde")] pub profile: Option<Url>,
#[serde(default)] #[serde(with = "url_serde")] pub picture: Option<Url>,
#[serde(default)] #[serde(with = "url_serde")] pub website: Option<Url>,
/// Email address. Its format is only checked when `Validate::validate` is called,
/// which `request_userinfo` does not do.
#[serde(default)] #[validate(email)] pub email: Option<String>,
/// `false` when the claim is absent, so `false` does not distinguish
/// "not verified" from "not reported".
#[serde(default)] pub email_verified: bool,
#[serde(default)] pub gender: Option<String>,
/// Parsed as a `chrono::NaiveDate`.
#[serde(default)] pub birthdate: Option<NaiveDate>,
#[serde(default)] pub zoneinfo: Option<String>,
#[serde(default)] pub locale: Option<String>,
#[serde(default)] pub phone_number: Option<String>,
/// `false` when the claim is absent (same caveat as `email_verified`).
#[serde(default)] pub phone_number_verified: bool,
/// Postal address claim, if provided.
#[serde(default)] pub address: Option<Address>,
/// Kept as a raw integer; presumably seconds since the Unix epoch — confirm per provider.
#[serde(default)] pub updated_at: Option<i64>,
}
/// Value of the `display` parameter that `Client::auth_url` sends when
/// `Options::display` is set.
///
/// Each variant is sent as its lowercase name. See the `display` parameter in
/// OpenID Connect Core 1.0 §3.1.2.1 for how providers interpret it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Display {
    /// Sent as `page`.
    Page,
    /// Sent as `popup`.
    Popup,
    /// Sent as `touch`.
    Touch,
    /// Sent as `wap`.
    Wap,
}

impl Display {
    /// Returns the query-string value for this option.
    fn as_str(&self) -> &'static str {
        match *self {
            Display::Page => "page",
            Display::Popup => "popup",
            Display::Touch => "touch",
            Display::Wap => "wap",
        }
    }
}
/// One value of the `prompt` parameter that `Client::auth_url` sends when
/// `Options::prompt` is set (several values are space-joined).
///
/// See the `prompt` parameter in OpenID Connect Core 1.0 §3.1.2.1 for how
/// providers interpret each value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Prompt {
    /// Sent as `none`.
    None,
    /// Sent as `login`.
    Login,
    /// Sent as `consent`.
    Consent,
    /// Sent as `select_account`.
    SelectAccount,
}

impl Prompt {
    /// Returns the query-string value for this option.
    ///
    /// Variants are matched with qualified paths so `Prompt::None` is never
    /// confused with `Option::None`.
    fn as_str(&self) -> &'static str {
        match *self {
            Prompt::None => "none",
            Prompt::Login => "login",
            Prompt::Consent => "consent",
            Prompt::SelectAccount => "select_account",
        }
    }
}
/// Postal address claim, carried in `Userinfo::address`.
///
/// Every component has `#[serde(default)]`, so any of them may be absent
/// (`None`) in the provider's response.
#[derive(Debug, Deserialize, Serialize)]
pub struct Address {
/// Full address as a single display string.
#[serde(default)] pub formatted: Option<String>,
#[serde(default)] pub street_address: Option<String>,
#[serde(default)] pub locality: Option<String>,
#[serde(default)] pub region: Option<String>,
#[serde(default)] pub postal_code: Option<String>,
#[serde(default)] pub country: Option<String>,
}
#[cfg(test)]
mod tests {
    use crate::{issuer, Client};
    use reqwest::Url;

    // Generates one `#[test]` per provider. Each test runs discovery against the
    // issuer returned by `crate::issuer::<name>()` and then builds an
    // authorization URL with default options. These tests contact the real
    // providers, so they need network access.
    macro_rules! issuer_test {
        ($issuer:ident) => {
            #[test]
            fn $issuer() {
                let client_id = String::from("test");
                let client_secret = String::from("a secret to everybody");
                let redirect_uri: Url = "https://example.com/re".parse().unwrap();
                let discovered = Client::discover(client_id, client_secret, redirect_uri, issuer::$issuer());
                let client = discovered.unwrap();
                let _ = client.auth_url(&Default::default());
            }
        };
    }

    issuer_test!(google);
    issuer_test!(microsoft);
    issuer_test!(paypal);
    issuer_test!(salesforce);
    issuer_test!(yahoo);
}