use crate::error::AuthenticationFailedError;
use crate::{Error, IdToken, Provider};
/// Selects the value sent as the `response_mode` parameter of the
/// authorization request.
///
/// The enum dereferences to the string that represents the chosen mode on the
/// wire, so a reference to it can be passed wherever a `&str` is expected.
#[derive(Clone, Debug)]
pub enum OidcResponseMode {
    /// Rendered as `"query"`.
    Query,
    /// Rendered as `"form_post"`.
    FormPost,
    /// Rendered as `"fragment"`.
    Fragment,
}

impl std::ops::Deref for OidcResponseMode {
    type Target = str;

    /// Returns the wire-format string for this response mode.
    fn deref(&self) -> &Self::Target {
        match *self {
            OidcResponseMode::Query => "query",
            OidcResponseMode::FormPost => "form_post",
            OidcResponseMode::Fragment => "fragment",
        }
    }
}
/// Selects the value sent as the optional `prompt` parameter of the
/// authorization request.
///
/// The enum dereferences to the string that represents the chosen prompt on
/// the wire, so a reference to it can be passed wherever a `&str` is expected.
#[derive(Clone, Debug)]
pub enum OidcPrompt {
    /// Rendered as `"none"`.
    NoPrompt,
    /// Rendered as `"login"`.
    Login,
    /// Rendered as `"consent"`.
    Consent,
    /// Rendered as `"select_account"`.
    SelectAccount,
}

impl std::ops::Deref for OidcPrompt {
    type Target = str;

    /// Returns the wire-format string for this prompt value.
    fn deref(&self) -> &Self::Target {
        match *self {
            OidcPrompt::NoPrompt => "none",
            OidcPrompt::Login => "login",
            OidcPrompt::Consent => "consent",
            OidcPrompt::SelectAccount => "select_account",
        }
    }
}
#[derive(Clone, Debug)]
pub struct Client<P: Provider> {
client_id: String,
client_secret: String,
redirect_uri: String,
response_mode: OidcResponseMode,
provider: P,
}
impl<P: Provider> Client<P> {
pub fn auth_url(&self, session: &Session, prompt: Option<OidcPrompt>) -> url::Url {
let mut authurl = self.provider.authorization_endpoint();
authurl
.query_pairs_mut()
.append_pair("scope", "openid profile email")
.append_pair("response_type", "code")
.append_pair("client_id", &self.client_id)
.append_pair("nonce", &session.nonce())
.append_pair("state", &session.state())
.append_pair("response_mode", &self.response_mode)
.append_pair("redirect_uri", &self.redirect_uri)
.append_pair("code_challenge_method", "S256")
.append_pair("code_challenge", &session.pkce_challenge());
if let Some(prompt) = prompt {
authurl.query_pairs_mut().append_pair("prompt", &prompt);
}
authurl
}
pub async fn authenticate<T>(
&self,
state: &str,
code: &str,
session: &Session,
) -> Result<IdToken<T>, Error>
where
T: serde::de::DeserializeOwned,
{
if state != session.state() {
log::warn!("state mismatch");
return Err(Error::BadRequest);
}
let code_verifier = session.pkce_verifier();
let params = vec![
("grant_type", "authorization_code"),
("code", code),
("client_id", &self.client_id),
("client_secret", &self.client_secret),
("redirect_uri", &self.redirect_uri),
("code_verifier", &code_verifier),
];
let response = reqwest::Client::new()
.post(self.provider.token_endpoint().clone())
.form(¶ms)
.send()
.await?;
if let Err(err) = response.error_for_status_ref() {
let err_body = response.text().await?;
log::warn!("Token endpoint returns error {}", err_body);
Err(err.into())
} else {
let token_response = response.json::<OidcTokenEndpointResponse>().await?;
log::debug!("Token endpoint returns {:?}", token_response);
let id_token = IdToken::<T>::decode_without_jws_validation(&token_response.id_token)?;
self.validate_claims(&id_token, session)?;
Ok(id_token)
}
}
fn validate_claims<T>(
&self,
id_token: &IdToken<T>,
session: &Session,
) -> Result<(), AuthenticationFailedError> {
use std::time::SystemTime;
if !self.provider.validate_iss(&id_token.iss) {
log::info!("Invalid iss {}", id_token.iss);
return Err(AuthenticationFailedError::ClaimValidationError);
}
if id_token.aud != self.client_id {
log::info!("Invalid aud {}", id_token.aud);
return Err(AuthenticationFailedError::ClaimValidationError);
}
if &id_token.nonce != &session.nonce() {
log::info!("Invalid nonce {}", id_token.nonce);
return Err(AuthenticationFailedError::ClaimValidationError);
}
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map_or(0, |t| t.as_secs());
if id_token.iat > now + 60 || now > id_token.exp {
log::info!(
"Invalid iat {} or exp {} : now = {}",
id_token.iat,
id_token.exp,
now
);
return Err(AuthenticationFailedError::ClaimValidationError);
}
Ok(())
}
}
/// Incrementally collects the settings needed to construct a [`Client`].
///
/// `client_id`, `client_secret` and `redirect_uri` start out unset and must
/// all be provided before [`ClientBuilder::build`] yields a client; the
/// response mode starts as [`OidcResponseMode::Query`].
pub struct ClientBuilder<P: Provider> {
    client_id: Option<String>,
    client_secret: Option<String>,
    redirect_uri: Option<String>,
    response_mode: OidcResponseMode,
    provider: P,
}

impl<P: Provider> ClientBuilder<P> {
    /// Starts a builder for `provider` with no credentials or redirect URI
    /// set and the `Query` response mode.
    pub(crate) fn from_provider(provider: P) -> Self {
        ClientBuilder {
            client_id: None,
            client_secret: None,
            redirect_uri: None,
            response_mode: OidcResponseMode::Query,
            provider,
        }
    }

    /// Consumes the builder and returns the finished [`Client`], or `None`
    /// when the client id, client secret or redirect URI was never set.
    pub fn build(self) -> Option<Client<P>> {
        let ClientBuilder {
            client_id,
            client_secret,
            redirect_uri,
            response_mode,
            provider,
        } = self;
        // Each `?` bails out with `None` as soon as a required value is missing.
        Some(Client {
            client_id: client_id?,
            client_secret: client_secret?,
            redirect_uri: redirect_uri?,
            response_mode,
            provider,
        })
    }

    /// Sets the client id, replacing any previous value.
    pub fn client_id(self, client_id: &str) -> Self {
        Self {
            client_id: Some(String::from(client_id)),
            ..self
        }
    }

    /// Sets the client secret, replacing any previous value.
    pub fn client_secret(self, client_secret: &str) -> Self {
        Self {
            client_secret: Some(String::from(client_secret)),
            ..self
        }
    }

    /// Sets the redirect URI, replacing any previous value.
    pub fn redirect_uri(self, redirect_uri: &str) -> Self {
        Self {
            redirect_uri: Some(String::from(redirect_uri)),
            ..self
        }
    }

    /// Sets the response mode, replacing the current one.
    pub fn response_mode(self, response_mode: OidcResponseMode) -> Self {
        Self {
            response_mode,
            ..self
        }
    }
}
/// Random material backing one login attempt.
///
/// The 144 bytes are split into four 36-byte parts, each exposed as unpadded
/// URL-safe base64: bytes `0..36` are the session key, `36..72` the `state`,
/// `72..108` the `nonce` and `108..144` the PKCE code verifier.
pub struct Session {
    rand_bytes: [u8; 144],
}

impl Session {
    /// End of the session-key part / start of the `state` part.
    const KEY_END: usize = 36;
    /// End of the `state` part / start of the `nonce` part.
    const STATE_END: usize = 72;
    /// End of the `nonce` part / start of the PKCE verifier part.
    const NONCE_END: usize = 108;
    /// Total number of random bytes held by a session.
    const TOTAL_LEN: usize = 144;

    /// Creates a session filled with bytes from `getrandom`.
    ///
    /// # Errors
    ///
    /// Logs the failure and returns `crate::Error::InternalError` when
    /// `getrandom::fill` fails.
    pub fn new_session() -> Result<Session, crate::Error> {
        let mut rand_bytes = [0u8; Self::TOTAL_LEN];
        getrandom::fill(&mut rand_bytes).map_err(|e| {
            log::error!("getrandom() failed with {:?}", e);
            crate::Error::InternalError
        })?;
        Ok(Session { rand_bytes })
    }

    /// Returns the `(key, value)` pair that [`Session::load_session`] turns
    /// back into this session: the encoded key and the encoding of all
    /// remaining bytes.
    pub fn save_session(&self) -> (String, String) {
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
        (
            self.key(),
            URL_SAFE_NO_PAD.encode(&self.rand_bytes[Self::KEY_END..]),
        )
    }

    /// Rebuilds a session from the pair produced by [`Session::save_session`].
    ///
    /// # Errors
    ///
    /// Returns a `base64::DecodeSliceError` when either input is not valid
    /// unpadded URL-safe base64, decodes to more bytes than its part holds,
    /// or decodes to fewer bytes than its part requires. The last case is
    /// rejected so that truncated input can never produce a session whose
    /// secrets are partly zero.
    pub fn load_session(
        session_key: &str,
        session_value: &str,
    ) -> Result<Self, base64::DecodeSliceError> {
        let mut rand_bytes = [0u8; Self::TOTAL_LEN];
        let (key_part, value_part) = rand_bytes.split_at_mut(Self::KEY_END);
        Self::decode_exact(session_key, key_part)?;
        Self::decode_exact(session_value, value_part)?;
        Ok(Self { rand_bytes })
    }

    /// Decodes `encoded` into `out`, failing unless `out` is filled completely.
    fn decode_exact(encoded: &str, out: &mut [u8]) -> Result<(), base64::DecodeSliceError> {
        use base64::{
            DecodeError, DecodeSliceError, Engine, engine::general_purpose::URL_SAFE_NO_PAD,
        };
        let written = URL_SAFE_NO_PAD.decode_slice(encoded, out)?;
        if written == out.len() {
            Ok(())
        } else {
            // `decode_slice` succeeds on short input and leaves the tail of
            // `out` untouched, so the shortfall has to be reported here.
            Err(DecodeSliceError::DecodeError(DecodeError::InvalidLength(
                encoded.len(),
            )))
        }
    }

    /// Returns the session key: the first 36 bytes, base64url-encoded.
    pub fn key(&self) -> String {
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
        URL_SAFE_NO_PAD.encode(&self.rand_bytes[..Self::KEY_END])
    }

    /// Returns the value used as the `state` request parameter.
    fn state(&self) -> String {
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
        URL_SAFE_NO_PAD.encode(&self.rand_bytes[Self::KEY_END..Self::STATE_END])
    }

    /// Returns the value used as the `nonce` request parameter.
    fn nonce(&self) -> String {
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
        URL_SAFE_NO_PAD.encode(&self.rand_bytes[Self::STATE_END..Self::NONCE_END])
    }

    /// Returns the PKCE challenge: the base64url-encoded SHA-256 digest of
    /// the code verifier string.
    fn pkce_challenge(&self) -> String {
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
        use sha2::{Digest, Sha256};
        let challenge_byte = Sha256::digest(self.pkce_verifier().as_bytes());
        URL_SAFE_NO_PAD.encode(&challenge_byte)
    }

    /// Returns the PKCE code verifier: the last 36 bytes, base64url-encoded.
    fn pkce_verifier(&self) -> String {
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
        URL_SAFE_NO_PAD.encode(&self.rand_bytes[Self::NONCE_END..Self::TOTAL_LEN])
    }
}
/// The part of the token endpoint's JSON response that this module reads.
#[derive(Debug, serde::Deserialize)]
struct OidcTokenEndpointResponse {
    /// The encoded ID token, taken from the `id_token` field of the response.
    id_token: String,
}