use crate::error::AuthenticationFailedError;
use crate::{Error, IdToken, Provider};
/// How the authorization response is delivered back to the redirect URI.
///
/// The selected mode is sent as the `response_mode` query parameter of the
/// authorization request; its wire value is `form_post` or `fragment`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum OidcResponseMode {
    /// Wire value `form_post`.
    FormPost,
    /// Wire value `fragment`.
    Fragment,
}
impl std::ops::Deref for OidcResponseMode {
type Target = str;
fn deref(&self) -> &str {
match self {
Self::FormPost => "form_post",
Self::Fragment => "fragment",
}
}
}
/// An OpenID Connect relying-party client bound to one provider.
///
/// Instances are produced by `ClientBuilder::build`. `Debug` is implemented
/// by hand so that `client_secret` is never printed.
#[derive(Clone)]
pub struct Client<P: Provider> {
    client_id: String,
    client_secret: String,
    redirect_uri: String,
    response_mode: OidcResponseMode,
    provider: P,
}

impl<P: Provider + std::fmt::Debug> std::fmt::Debug for Client<P> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("client_id", &self.client_id)
            // The client secret is a long-lived credential; keep it out of logs.
            .field("client_secret", &"<redacted>")
            .field("redirect_uri", &self.redirect_uri)
            .field("response_mode", &self.response_mode)
            .field("provider", &self.provider)
            .finish()
    }
}
impl<P: Provider> Client<P> {
pub fn auth_url(&self, session: &Session) -> url::Url {
let mut authurl = self.provider.authorization_endpoint();
authurl
.query_pairs_mut()
.append_pair("scope", "openid profile email")
.append_pair("response_type", "code")
.append_pair("client_id", &self.client_id)
.append_pair("nonce", &session.nonce())
.append_pair("state", &session.state())
.append_pair("response_mode", &self.response_mode)
.append_pair("redirect_uri", &self.redirect_uri)
.append_pair("code_challenge_method", "S256")
.append_pair("code_challenge", &session.pkce_challenge());
authurl
}
pub async fn authenticate(
&self,
state: &str,
code: &str,
session: &Session,
) -> Result<IdToken, Error> {
if state != session.state() {
return Err(Error::BadRequest);
}
let code_verifier = session.pkce_verifier();
let params = vec![
("grant_type", "authorization_code"),
("code", code),
("client_id", &self.client_id),
("client_secret", &self.client_secret),
("redirect_uri", &self.redirect_uri),
("code_verifier", &code_verifier),
];
let response = reqwest::Client::new()
.post(self.provider.token_endpoint().clone())
.form(¶ms)
.send()
.await?;
if let Err(err) = response.error_for_status_ref() {
let err_body = response.text().await?;
log::warn!("Token endpoint returns error {}", err_body);
Err(err.into())
} else {
let token_response = response.json::<OidcTokenEndpointResponse>().await?;
log::debug!("Token endpoint returns {:?}", token_response);
let id_token = IdToken::decode_without_jws_validation(&token_response.id_token)?;
self.validate_claims(&id_token, session)?;
Ok(id_token)
}
}
fn validate_claims(
&self,
id_token: &IdToken,
session: &Session,
) -> Result<(), AuthenticationFailedError> {
use std::time::SystemTime;
if !self.provider.validate_iss(&id_token.iss) {
log::info!("Invalid iss {}", id_token.iss);
return Err(AuthenticationFailedError::ClaimValidationError.into());
}
if id_token.aud != self.client_id {
log::info!("Invalid aud {}", id_token.aud);
return Err(AuthenticationFailedError::ClaimValidationError);
}
if &id_token.nonce != &session.nonce() {
log::info!("Invalid nonce {}", id_token.nonce);
return Err(AuthenticationFailedError::ClaimValidationError);
}
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map_or(0, |t| t.as_secs());
if id_token.iat > now + 60 || now > id_token.exp {
log::info!(
"Invalid iat {} or exp {} : now = {}",
id_token.iat,
id_token.exp,
now
);
return Err(AuthenticationFailedError::ClaimValidationError);
}
Ok(())
}
}
/// Collects the settings needed to create a [`Client`].
///
/// The provider and response mode are always present. The three `Option`
/// fields must all be `Some` for `build` to succeed.
pub struct ClientBuilder<P: Provider> {
    /// Provider the built client talks to.
    provider: P,
    /// Value sent as the `response_mode` authorization parameter.
    response_mode: OidcResponseMode,
    /// Client identifier; required.
    client_id: Option<String>,
    /// Client secret; required.
    client_secret: Option<String>,
    /// Redirect URI; required.
    redirect_uri: Option<String>,
}
impl<P: Provider> ClientBuilder<P> {
    /// Starts a builder for `provider`.
    ///
    /// No client id, client secret or redirect URI is set. The response mode
    /// defaults to [`OidcResponseMode::FormPost`].
    #[must_use]
    pub(crate) fn from_provider(provider: P) -> Self {
        Self {
            client_id: None,
            client_secret: None,
            redirect_uri: None,
            response_mode: OidcResponseMode::FormPost,
            provider,
        }
    }

    /// Finishes the builder.
    ///
    /// Returns `None` if `client_id`, `client_secret` or `redirect_uri` has
    /// not been set.
    #[must_use]
    pub fn build(self) -> Option<Client<P>> {
        Some(Client {
            client_id: self.client_id?,
            client_secret: self.client_secret?,
            redirect_uri: self.redirect_uri?,
            response_mode: self.response_mode,
            provider: self.provider,
        })
    }

    /// Sets the client id (required).
    #[must_use]
    pub fn client_id(mut self, client_id: &str) -> Self {
        self.client_id = Some(client_id.to_owned());
        self
    }

    /// Sets the client secret (required).
    #[must_use]
    pub fn client_secret(mut self, client_secret: &str) -> Self {
        self.client_secret = Some(client_secret.to_owned());
        self
    }

    /// Sets the redirect URI (required).
    #[must_use]
    pub fn redirect_uri(mut self, redirect_uri: &str) -> Self {
        self.redirect_uri = Some(redirect_uri.to_owned());
        self
    }

    /// Overrides the response mode (default: [`OidcResponseMode::FormPost`]).
    #[must_use]
    pub fn response_mode(mut self, response_mode: OidcResponseMode) -> Self {
        self.response_mode = response_mode;
        self
    }
}
/// Random secret material for one authorization-code flow.
///
/// The 144 bytes are split into four 36-byte regions: session key
/// (`[0, 36)`), `state` (`[36, 72)`), `nonce` (`[72, 108)`) and PKCE code
/// verifier (`[108, 144)`).
pub struct Session {
    rand_bytes: [u8; 36 * 4],
}
impl Session {
    /// Creates a session whose 144 bytes come from `rand_core::OsRng`.
    pub fn new_session() -> Session {
        use rand_core::{OsRng, RngCore};
        let mut rand_bytes = [0u8; 144];
        OsRng.fill_bytes(&mut rand_bytes);
        Session { rand_bytes }
    }

    /// Serializes the session as `(session_key, session_value)`.
    ///
    /// Both halves are URL-safe base64 without padding. `session_key` encodes
    /// the first 36 bytes (48 characters) and `session_value` the remaining
    /// 108 bytes (144 characters). [`Session::load_session`] reverses this.
    pub fn save_session(&self) -> (String, String) {
        let (key_bytes, value_bytes) = self.rand_bytes.split_at(36);
        (
            base64::encode_config(key_bytes, base64::URL_SAFE_NO_PAD),
            base64::encode_config(value_bytes, base64::URL_SAFE_NO_PAD),
        )
    }

    /// Rebuilds a session from the pair produced by [`Session::save_session`].
    ///
    /// # Errors
    ///
    /// Returns `base64::DecodeError::InvalidLength` in either of these cases:
    /// the strings are not 48 / 144 bytes long, or they do not decode to
    /// exactly 36 / 108 bytes. Any other base64 decoding error is propagated.
    pub fn load_session(
        session_key: &str,
        session_value: &str,
    ) -> Result<Self, base64::DecodeError> {
        use base64::URL_SAFE_NO_PAD;

        if session_key.len() != 48 || session_value.len() != 144 {
            return Err(base64::DecodeError::InvalidLength);
        }
        let mut rand_bytes = [0u8; 144];
        let key_written =
            base64::decode_config_slice(session_key, URL_SAFE_NO_PAD, &mut rand_bytes[..36])?;
        let value_written =
            base64::decode_config_slice(session_value, URL_SAFE_NO_PAD, &mut rand_bytes[36..])?;
        // A string of the expected length can still decode to fewer bytes,
        // e.g. if the decoder accepts trailing `=` padding. That would leave
        // zero bytes inside the state/nonce/verifier secrets, so reject it.
        if key_written != 36 || value_written != 108 {
            return Err(base64::DecodeError::InvalidLength);
        }
        Ok(Self { rand_bytes })
    }

    /// `state` parameter: bytes `[36, 72)` as unpadded URL-safe base64.
    fn state(&self) -> String {
        base64::encode_config(&self.rand_bytes[36..72], base64::URL_SAFE_NO_PAD)
    }

    /// `nonce` parameter: bytes `[72, 108)` as unpadded URL-safe base64.
    fn nonce(&self) -> String {
        base64::encode_config(&self.rand_bytes[72..108], base64::URL_SAFE_NO_PAD)
    }

    /// PKCE `S256` challenge: unpadded URL-safe base64 of SHA-256(verifier).
    fn pkce_challenge(&self) -> String {
        use sha2::{Digest, Sha256};
        let challenge_bytes = Sha256::digest(self.pkce_verifier().as_bytes());
        base64::encode_config(&challenge_bytes, base64::URL_SAFE_NO_PAD)
    }

    /// PKCE code verifier: bytes `[108, 144)` as unpadded URL-safe base64
    /// (48 characters).
    fn pkce_verifier(&self) -> String {
        base64::encode_config(&self.rand_bytes[108..144], base64::URL_SAFE_NO_PAD)
    }
}
/// The fields of the token endpoint's JSON response that this module reads.
///
/// `Debug` is implemented by hand because this value is debug-logged after
/// every successful token request. The access token is a bearer credential
/// and must not reach the logs; only its presence is shown.
#[derive(serde::Deserialize)]
struct OidcTokenEndpointResponse {
    /// Access token returned alongside the ID token, if any.
    access_token: Option<String>,
    /// Raw ID token string as returned by the token endpoint.
    id_token: String,
}

impl std::fmt::Debug for OidcTokenEndpointResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OidcTokenEndpointResponse")
            .field(
                "access_token",
                &self.access_token.as_ref().map(|_| "<redacted>"),
            )
            .field("id_token", &self.id_token)
            .finish()
    }
}