use std::{collections::HashMap, fmt::Display, str::FromStr, sync::Mutex, time::Instant};
use base64::{Engine, prelude::BASE64_URL_SAFE};
use sha2::{Digest, Sha256};
use tracing::warn;
use crate::{
Basileus,
err::{PkceAuthError, PkceTokenError},
};
/// A PKCE code challenge paired with the transformation method used to check
/// a code verifier against it (see `CodeChallenge::verify`).
///
/// With the `serde` feature enabled the fields are (de)serialized under the
/// names `code_challenge` and `code_challenge_method`.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CodeChallenge {
/// The challenge string supplied with the authorization request.
#[cfg_attr(feature = "serde", serde(rename = "code_challenge"))]
pub challenge: String,
/// How a code verifier is transformed before being compared to `challenge`.
#[cfg_attr(feature = "serde", serde(rename = "code_challenge_method"))]
pub method: CodeChallengeMethod,
}
/// Transformation applied to a code verifier before it is compared with the
/// stored challenge.
///
/// The textual forms (`Display`, `FromStr`, and serde when the `serde`
/// feature is enabled) are `"S256"` and `"plain"`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum CodeChallengeMethod {
/// The challenge is the URL-safe base64 encoding of the SHA-256 digest of
/// the verifier.
#[cfg_attr(feature = "serde", serde(rename = "S256"))]
S256,
/// The challenge is the verifier itself, compared verbatim.
#[cfg_attr(feature = "serde", serde(rename = "plain"))]
Plain,
}
impl Display for CodeChallengeMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CodeChallengeMethod::S256 => write!(f, "S256"),
CodeChallengeMethod::Plain => write!(f, "plain"),
}
}
}
impl FromStr for CodeChallengeMethod {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"S256" => Ok(CodeChallengeMethod::S256),
"plain" => Ok(CodeChallengeMethod::Plain),
_ => Err(format!(
"invalid code challenge method: {s}, must be either 'S256' or 'plain'"
)),
}
}
}
impl CodeChallenge {
    /// Creates a challenge that uses the `S256` transformation method.
    pub fn new(challenge: String) -> Self {
        Self {
            challenge,
            method: CodeChallengeMethod::S256,
        }
    }

    /// Returns `true` when `code_verifier` matches this challenge.
    ///
    /// * `S256` — the verifier is hashed with SHA-256 and encoded as URL-safe
    ///   base64; the result is compared with the stored challenge while
    ///   ignoring trailing `=` padding on both sides.
    /// * `Plain` — the verifier is compared with the challenge verbatim.
    pub fn verify(&self, code_verifier: &str) -> bool {
        match self.method {
            CodeChallengeMethod::S256 => {
                let digest = Sha256::digest(code_verifier);
                let encoded = BASE64_URL_SAFE.encode(digest);
                // RFC 7636 specifies base64url *without* padding, so a
                // conforming client sends 43 characters while the padded
                // engine used here produces 44 (one trailing '='). Comparing
                // the unpadded forms accepts spec-compliant challenges and
                // keeps accepting padded ones.
                self.challenge.trim_end_matches('=') == encoded.trim_end_matches('=')
            }
            CodeChallengeMethod::Plain => self.challenge == code_verifier,
        }
    }
}
impl Display for CodeChallenge {
    /// Renders the challenge as `<method>:<challenge>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { challenge, method } = self;
        write!(f, "{method}:{challenge}")
    }
}
/// State of one pending PKCE authorization, kept between the authorization
/// request and the token request.
pub struct Pkce {
/// Name of the user the authorization was started for.
pub user: String,
/// Challenge that the code verifier of the token request is checked against.
pub code_challenge: CodeChallenge,
/// Instant at which this entry was created; `Pkce::valid` measures the
/// entry's age from here.
pub begin: Instant,
}
impl Pkce {
    /// Maximum age of a pending authorization (600 seconds).
    const MAX_AGE: std::time::Duration = std::time::Duration::from_secs(600);

    /// Creates a pending authorization for `user`, timestamped with the
    /// current instant.
    pub fn new(user: String, code_challenge: CodeChallenge) -> Self {
        Self {
            user,
            code_challenge,
            begin: Instant::now(),
        }
    }

    /// Returns `true` while this entry is at most 600 seconds old.
    ///
    /// The comparison is done on whole `Duration`s. Truncating to seconds
    /// first would keep an entry valid until just under 601 seconds.
    pub fn valid(&self) -> bool {
        self.begin.elapsed() <= Self::MAX_AGE
    }
}
/// Configuration of the PKCE module.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", serde_inline_default::serde_inline_default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PkceConfig {
/// Whether authorization requests using the `plain` challenge method are
/// accepted. Falls back to `false` when absent from deserialized input.
#[cfg(feature = "serde")]
#[serde_inline_default(false)]
pub allow_plain: bool,
/// Whether authorization requests using the `plain` challenge method are
/// accepted.
#[cfg(not(feature = "serde"))]
pub allow_plain: bool,
}
impl Default for PkceConfig {
    /// Default configuration: the `plain` challenge method is not allowed.
    fn default() -> Self {
        let allow_plain = false;
        Self { allow_plain }
    }
}
/// PKCE state: the configuration plus the table of pending authorizations.
pub struct PkceModule {
/// Settings this module was created with.
pub config: PkceConfig,
// Pending authorizations keyed by the authorization code handed out by
// `pkce_auth_req`; entries are removed again by `pkce_token_req`.
pending: Mutex<HashMap<String, Pkce>>,
}
impl PkceModule {
pub fn new(config: PkceConfig) -> Self {
if config.allow_plain {
warn!(
"allowing `plain` transformation method for PKCE. This is a security vulnerability"
);
}
Self {
config,
pending: Mutex::new(HashMap::new()),
}
}
}
impl Basileus {
    /// Handles a PKCE authorization request and returns the authorization
    /// code to present in the subsequent token request.
    ///
    /// The request is rejected before the password is checked when it uses
    /// the `plain` challenge method and the configuration does not allow it.
    ///
    /// # Errors
    ///
    /// * `PkceAuthError::InsecurePlain` — `plain` method requested but not
    ///   allowed by the configuration.
    /// * `PkceAuthError::Unauthorized` — `verify_pass` reported a mismatch.
    /// * Any error propagated from `verify_pass`.
    ///
    /// # Panics
    ///
    /// Panics if the mutex guarding the pending table is poisoned.
    pub async fn pkce_auth_req(
        &self,
        user: &str,
        pass: &str,
        code_challenge: CodeChallenge,
    ) -> Result<String, PkceAuthError> {
        if code_challenge.method == CodeChallengeMethod::Plain && !self.pkce.config.allow_plain {
            return Err(PkceAuthError::InsecurePlain);
        }
        if !self.verify_pass(user, pass).await? {
            return Err(PkceAuthError::Unauthorized);
        }

        // NOTE(review): the code is a pure function of the user name and the
        // challenge (no nonce or timestamp), so it is predictable from those
        // inputs and a repeated request overwrites the earlier pending entry.
        // Consider mixing in random data.
        let digest = Sha256::digest(format!("{user}, {code_challenge}"));
        let auth_code = BASE64_URL_SAFE.encode(digest);
        let pkce = Pkce::new(user.into(), code_challenge);

        {
            // The guard lives only inside this block and is never held
            // across an await point.
            let mut pending = self.pkce.pending.lock().unwrap();
            // Entries used to leave the table only when redeemed, so
            // abandoned flows accumulated forever. Drop everything that has
            // expired before adding the new entry.
            pending.retain(|_, entry| entry.valid());
            pending.insert(auth_code.clone(), pkce);
        }

        Ok(auth_code)
    }

    /// Exchanges an authorization code plus code verifier for a token.
    ///
    /// The pending entry is removed from the table before any check is made,
    /// so a code can be presented at most once, whether or not it succeeds.
    ///
    /// # Errors
    ///
    /// * `PkceTokenError::InvalidCode` — no pending entry exists for `code`.
    ///   This includes expired entries already purged by a later
    ///   authorization request.
    /// * `PkceTokenError::ExpiredCode` — the entry exists but is too old.
    /// * `PkceTokenError::InvalidVerifier` — `code_verifier` does not match
    ///   the stored challenge.
    ///
    /// # Panics
    ///
    /// Panics if the mutex guarding the pending table is poisoned.
    pub fn pkce_token_req(
        &self,
        code: &str,
        code_verifier: &str,
    ) -> Result<String, PkceTokenError> {
        let pkce = self
            .pkce
            .pending
            .lock()
            .unwrap()
            .remove(code)
            .ok_or(PkceTokenError::InvalidCode)?;

        if !pkce.valid() {
            return Err(PkceTokenError::ExpiredCode);
        }
        if !pkce.code_challenge.verify(code_verifier) {
            return Err(PkceTokenError::InvalidVerifier);
        }

        Ok(self.issue_token(&pkce.user))
    }
}