1use std::{collections::HashMap, fmt::Display, str::FromStr, sync::Mutex, time::Instant};
2
3use base64::{Engine, prelude::BASE64_URL_SAFE};
4use sha2::{Digest, Sha256};
5use tracing::warn;
6
7use crate::{
8 Basileus,
9 err::{PkceAuthError, PkceTokenError},
10};
11
12#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
14#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
15pub struct CodeChallenge {
16 #[cfg_attr(feature = "serde", serde(rename = "code_challenge"))]
18 pub challenge: String,
19 #[cfg_attr(feature = "serde", serde(rename = "code_challenge_method"))]
21 pub method: CodeChallengeMethod,
22}
23
/// How a code verifier is transformed into a code challenge.
///
/// The textual names are `S256` and `plain`. They are used by the `Display` and
/// `FromStr` impls and, with the `serde` feature, by (de)serialization.
///
/// The enum is `Copy` because it is a plain fieldless tag, so it can be passed
/// and compared by value without `.clone()`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum CodeChallengeMethod {
    /// The challenge is the base64url-encoded SHA-256 hash of the verifier.
    #[cfg_attr(feature = "serde", serde(rename = "S256"))]
    S256,
    /// The challenge is the verifier itself.
    ///
    /// Rejected unless `PkceConfig::allow_plain` is set.
    #[cfg_attr(feature = "serde", serde(rename = "plain"))]
    Plain,
}
35
36impl Display for CodeChallengeMethod {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 match self {
39 CodeChallengeMethod::S256 => write!(f, "S256"),
40 CodeChallengeMethod::Plain => write!(f, "plain"),
41 }
42 }
43}
44
45impl FromStr for CodeChallengeMethod {
46 type Err = String;
47
48 fn from_str(s: &str) -> Result<Self, Self::Err> {
49 match s {
50 "S256" => Ok(CodeChallengeMethod::S256),
51 "plain" => Ok(CodeChallengeMethod::Plain),
52 _ => Err(format!(
53 "invalid code challenge method: {s}, must be either 'S256' or 'plain'"
54 )),
55 }
56 }
57}
58
59impl CodeChallenge {
60 pub fn new(challenge: String) -> Self {
62 Self {
63 challenge,
64 method: CodeChallengeMethod::S256,
65 }
66 }
67
68 pub fn verify(&self, code_verifier: &str) -> bool {
70 match self.method {
71 CodeChallengeMethod::S256 => {
72 let hash = Sha256::digest(code_verifier);
73 let encoded = BASE64_URL_SAFE.encode(hash);
74 self.challenge == encoded
75 }
76 CodeChallengeMethod::Plain => self.challenge == code_verifier,
77 }
78 }
79}
80
81impl Display for CodeChallenge {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 write!(f, "{}:{}", self.method, self.challenge)
84 }
85}
86
87pub struct Pkce {
89 pub user: String,
91 pub code_challenge: CodeChallenge,
93 pub begin: Instant,
95}
96
97impl Pkce {
98 pub fn new(user: String, code_challenge: CodeChallenge) -> Self {
100 Self {
101 user,
102 code_challenge,
103 begin: Instant::now(),
104 }
105 }
106
107 pub fn valid(&self) -> bool {
110 self.begin.elapsed().as_secs() <= 600
111 }
112}
113
/// Configuration for the PKCE module.
///
/// `Default` is derived, which gives `allow_plain: false`.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", serde_inline_default::serde_inline_default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PkceConfig {
    // The field is declared twice so that the `serde_inline_default` attribute
    // is only present when the `serde` feature is enabled.
    /// Whether clients may use the `plain` code challenge method (default `false`).
    #[cfg(feature = "serde")]
    #[serde_inline_default(false)]
    pub allow_plain: bool,
    /// Whether clients may use the `plain` code challenge method (default `false`).
    #[cfg(not(feature = "serde"))]
    pub allow_plain: bool,
}
136
137pub struct PkceModule {
138 pub config: PkceConfig,
139 pending: Mutex<HashMap<String, Pkce>>,
141}
142
143impl PkceModule {
144 pub fn new(config: PkceConfig) -> Self {
145 if config.allow_plain {
146 warn!(
147 "allowing `plain` transformation method for PKCE. This is a security vulnerability"
148 );
149 }
150 Self {
151 config,
152 pending: Mutex::new(HashMap::new()),
153 }
154 }
155}
156
157impl Basileus {
158 pub async fn pkce_auth_req(
162 &self,
163 user: &str,
164 pass: &str,
165 code_challenge: CodeChallenge,
166 ) -> Result<String, PkceAuthError> {
167 if code_challenge.method == CodeChallengeMethod::Plain && !self.pkce.config.allow_plain {
168 return Err(PkceAuthError::InsecurePlain);
169 }
170
171 if !self.verify_pass(user, pass).await? {
172 return Err(PkceAuthError::Unauthorized);
173 }
174
175 let auth_code = Sha256::digest(format!("{user}, {code_challenge}"));
176 let auth_code = BASE64_URL_SAFE.encode(auth_code);
177
178 let pkce = Pkce::new(user.into(), code_challenge);
179 self.pkce
180 .pending
181 .lock()
182 .unwrap()
183 .insert(auth_code.clone(), pkce);
184 Ok(auth_code)
185 }
186
187 pub fn pkce_token_req(
193 &self,
194 code: &str,
195 code_verifier: &str,
196 ) -> Result<String, PkceTokenError> {
197 let pkce = match self.pkce.pending.lock().unwrap().remove(code) {
198 Some(pkce) => pkce,
199 None => return Err(PkceTokenError::InvalidCode),
200 };
201 if !pkce.valid() {
202 return Err(PkceTokenError::ExpiredCode);
203 }
204 if !pkce.code_challenge.verify(code_verifier) {
205 return Err(PkceTokenError::InvalidVerifier);
206 }
207 let token = self.issue_token(&pkce.user);
208 Ok(token)
209 }
210}