Skip to main content

basileus/
pkce.rs

use std::{
    collections::HashMap,
    fmt::Display,
    str::FromStr,
    sync::{Mutex, PoisonError},
    time::{Duration, Instant},
};

use base64::{
    Engine,
    prelude::{BASE64_URL_SAFE, BASE64_URL_SAFE_NO_PAD},
};
use sha2::{Digest, Sha256};
use tracing::warn;

use crate::{
    Basileus,
    err::{PkceAuthError, PkceTokenError},
};
11
12/// A client PKCE code challenge, as defined in [RFC 7636](https://datatracker.ietf.org/doc/html/rfc7636#section-4.2).
13#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
14#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
15pub struct CodeChallenge {
16    /// The base64URL-encoded `code_challenge`.
17    #[cfg_attr(feature = "serde", serde(rename = "code_challenge"))]
18    pub challenge: String,
19    /// The `code_challenge_method`.
20    #[cfg_attr(feature = "serde", serde(rename = "code_challenge_method"))]
21    pub method: CodeChallengeMethod,
22}
23
/// How a `code_verifier` is transformed into a `code_challenge`, as defined in
/// [RFC 7636](https://datatracker.ietf.org/doc/html/rfc7636#section-4.2).
///
/// Variant order matters for the derived `Ord`: `S256` sorts before `Plain`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum CodeChallengeMethod {
    /// `code_challenge = BASE64URL-ENCODE(SHA256(ASCII(code_verifier)))`.
    #[cfg_attr(feature = "serde", serde(rename = "S256"))]
    S256,
    /// `code_challenge = code_verifier` (no transformation).
    #[cfg_attr(feature = "serde", serde(rename = "plain"))]
    Plain,
}

impl Display for CodeChallengeMethod {
    /// Writes the protocol name of the method: `S256` or `plain`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let wire_name = match self {
            Self::S256 => "S256",
            Self::Plain => "plain",
        };
        f.write_str(wire_name)
    }
}

impl FromStr for CodeChallengeMethod {
    type Err = String;

    /// Parses a protocol name (case-sensitive, `S256` or `plain`) into a method.
    ///
    /// Any other input yields a human-readable error message.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "S256" {
            Ok(Self::S256)
        } else if s == "plain" {
            Ok(Self::Plain)
        } else {
            Err(format!(
                "invalid code challenge method: {s}, must be either 'S256' or 'plain'"
            ))
        }
    }
}
58
59impl CodeChallenge {
60    /// Create a new `CodeChallenge` object with specified base64URL-encoded code challenge.
61    pub fn new(challenge: String) -> Self {
62        Self {
63            challenge,
64            method: CodeChallengeMethod::S256,
65        }
66    }
67
68    /// Verify the `code_verifier` by checking if the hash matches the stored `code_challenge`.
69    pub fn verify(&self, code_verifier: &str) -> bool {
70        match self.method {
71            CodeChallengeMethod::S256 => {
72                let hash = Sha256::digest(code_verifier);
73                let encoded = BASE64_URL_SAFE.encode(hash);
74                self.challenge == encoded
75            }
76            CodeChallengeMethod::Plain => self.challenge == code_verifier,
77        }
78    }
79}
80
81impl Display for CodeChallenge {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        write!(f, "{}:{}", self.method, self.challenge)
84    }
85}
86
87/// A pending PKCE authentication request.
88pub struct Pkce {
89    /// The authorized user name.
90    pub user: String,
91    /// The associated code challenge from the PKCE authorization request.
92    pub code_challenge: CodeChallenge,
93    /// Time of creation.
94    pub begin: Instant,
95}
96
97impl Pkce {
98    /// Create a new `Pkce` object with specified authorized user and code challenge.
99    pub fn new(user: String, code_challenge: CodeChallenge) -> Self {
100        Self {
101            user,
102            code_challenge,
103            begin: Instant::now(),
104        }
105    }
106
107    /// Check if the PKCE request is still valid.
108    /// Expiry time is set to 10 minutes (600 seconds).
109    pub fn valid(&self) -> bool {
110        self.begin.elapsed().as_secs() <= 600
111    }
112}
113
/// Configuration for PKCE handling.
///
/// Every field falls back to its [`Default`] value when omitted during deserialization.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(default))]
pub struct PkceConfig {
    /// Whether to allow the `plain` transformation method for PKCE code challenges.
    /// Defaults to `false`.
    ///
    /// **This is a security vulnerability and should always be avoided.**
    pub allow_plain: bool,
}
136
137pub struct PkceModule {
138    pub config: PkceConfig,
139    /// Map from PKCE challenges to their beloinging users.
140    pending: Mutex<HashMap<String, Pkce>>,
141}
142
143impl PkceModule {
144    pub fn new(config: PkceConfig) -> Self {
145        if config.allow_plain {
146            warn!(
147                "allowing `plain` transformation method for PKCE. This is a security vulnerability"
148            );
149        }
150        Self {
151            config,
152            pending: Mutex::new(HashMap::new()),
153        }
154    }
155}
156
157impl Basileus {
158    /// Handle a PKCE authorization request.
159    ///
160    /// If the authorization is successful, returns a base64URL-encoded [authorization code](https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2).
161    pub async fn pkce_auth_req(
162        &self,
163        user: &str,
164        pass: &str,
165        code_challenge: CodeChallenge,
166    ) -> Result<String, PkceAuthError> {
167        if code_challenge.method == CodeChallengeMethod::Plain && !self.pkce.config.allow_plain {
168            return Err(PkceAuthError::InsecurePlain);
169        }
170
171        if !self.verify_pass(user, pass).await? {
172            return Err(PkceAuthError::Unauthorized);
173        }
174
175        let auth_code = Sha256::digest(format!("{user}, {code_challenge}"));
176        let auth_code = BASE64_URL_SAFE.encode(auth_code);
177
178        let pkce = Pkce::new(user.into(), code_challenge);
179        self.pkce
180            .pending
181            .lock()
182            .unwrap()
183            .insert(auth_code.clone(), pkce);
184        Ok(auth_code)
185    }
186
187    /// Handle a PKCE access token request.
188    ///
189    /// A successful request requires a valid previously issued authorization code (through [`Self::pkce_auth_req`]) and a matching code verifier.
190    ///
191    /// Returns the token if successful.
192    pub fn pkce_token_req(
193        &self,
194        code: &str,
195        code_verifier: &str,
196    ) -> Result<String, PkceTokenError> {
197        let pkce = match self.pkce.pending.lock().unwrap().remove(code) {
198            Some(pkce) => pkce,
199            None => return Err(PkceTokenError::InvalidCode),
200        };
201        if !pkce.valid() {
202            return Err(PkceTokenError::ExpiredCode);
203        }
204        if !pkce.code_challenge.verify(code_verifier) {
205            return Err(PkceTokenError::InvalidVerifier);
206        }
207        let token = self.issue_token(&pkce.user);
208        Ok(token)
209    }
210}