1use std::borrow::Cow;
20
21use data_encoding::BASE64URL_NOPAD;
22use mas_iana::oauth::PkceCodeChallengeMethod;
23use serde::{Deserialize, Serialize};
24use sha2::{Digest, Sha256};
25use thiserror::Error;
26
27#[derive(Debug, Error, PartialEq, Eq)]
29pub enum CodeChallengeError {
30 #[error("code_verifier should be at least 43 characters long")]
32 TooShort,
33
34 #[error("code_verifier should be at most 128 characters long")]
36 TooLong,
37
38 #[error("code_verifier contains invalid characters")]
40 InvalidCharacters,
41
42 #[error("challenge verification failed")]
44 VerificationFailed,
45
46 #[error("unknown challenge method")]
48 UnknownChallengeMethod,
49}
50
51fn validate_verifier(verifier: &str) -> Result<(), CodeChallengeError> {
52 if verifier.len() < 43 {
53 return Err(CodeChallengeError::TooShort);
54 }
55
56 if verifier.len() > 128 {
57 return Err(CodeChallengeError::TooLong);
58 }
59
60 if !verifier
61 .chars()
62 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~')
63 {
64 return Err(CodeChallengeError::InvalidCharacters);
65 }
66
67 Ok(())
68}
69
70pub trait CodeChallengeMethodExt {
72 fn compute_challenge<'a>(&self, verifier: &'a str) -> Result<Cow<'a, str>, CodeChallengeError>;
79
80 fn verify(&self, challenge: &str, verifier: &str) -> Result<(), CodeChallengeError>
88 where
89 Self: Sized,
90 {
91 if self.compute_challenge(verifier)? == challenge {
92 Ok(())
93 } else {
94 Err(CodeChallengeError::VerificationFailed)
95 }
96 }
97}
98
99impl CodeChallengeMethodExt for PkceCodeChallengeMethod {
100 fn compute_challenge<'a>(&self, verifier: &'a str) -> Result<Cow<'a, str>, CodeChallengeError> {
101 validate_verifier(verifier)?;
102
103 let challenge = match self {
104 Self::Plain => verifier.into(),
105 Self::S256 => {
106 let mut hasher = Sha256::new();
107 hasher.update(verifier.as_bytes());
108 let hash = hasher.finalize();
109 let verifier = BASE64URL_NOPAD.encode(&hash);
110 verifier.into()
111 }
112 _ => return Err(CodeChallengeError::UnknownChallengeMethod),
113 };
114
115 Ok(challenge)
116 }
117}
118
119#[derive(Clone, Serialize, Deserialize)]
121pub struct AuthorizationRequest {
122 pub code_challenge_method: PkceCodeChallengeMethod,
124
125 pub code_challenge: String,
127}
128
129#[derive(Clone, Serialize, Deserialize)]
131pub struct TokenRequest {
132 pub code_challenge_verifier: String,
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Verifier from the worked example in RFC 7636, Appendix B.
    const RFC_VERIFIER: &str = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
    /// S256 challenge for [`RFC_VERIFIER`], from the same appendix.
    const RFC_CHALLENGE: &str = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";

    #[test]
    fn test_pkce_verification() {
        use PkceCodeChallengeMethod::{Plain, S256};
        let challenge = RFC_CHALLENGE;

        assert!(S256.verify(challenge, RFC_VERIFIER).is_ok());

        assert!(Plain.verify(challenge, challenge).is_ok());

        assert_eq!(
            S256.verify(challenge, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"),
            Err(CodeChallengeError::VerificationFailed),
        );

        assert_eq!(
            S256.verify(challenge, "tooshort"),
            Err(CodeChallengeError::TooShort),
        );

        assert_eq!(
            S256.verify(challenge, "toolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolongtoolong"),
            Err(CodeChallengeError::TooLong),
        );

        assert_eq!(
            S256.verify(
                challenge,
                "this is long enough but has invalid characters in it"
            ),
            Err(CodeChallengeError::InvalidCharacters),
        );
    }

    #[test]
    fn test_compute_challenge() {
        use PkceCodeChallengeMethod::{Plain, S256};

        assert_eq!(S256.compute_challenge(RFC_VERIFIER).unwrap(), RFC_CHALLENGE);

        // `plain` should hand the verifier back without allocating.
        assert!(matches!(
            Plain.compute_challenge(RFC_VERIFIER),
            Ok(std::borrow::Cow::Borrowed(v)) if v == RFC_VERIFIER
        ));
    }

    #[test]
    fn test_verifier_length_boundaries() {
        use PkceCodeChallengeMethod::Plain;

        // Both ends of the 43..=128 range are accepted.
        let shortest = "a".repeat(43);
        let longest = "a".repeat(128);
        assert!(Plain.verify(&shortest, &shortest).is_ok());
        assert!(Plain.verify(&longest, &longest).is_ok());

        // One past either end is rejected.
        assert_eq!(
            Plain.verify(&shortest, &"a".repeat(42)),
            Err(CodeChallengeError::TooShort),
        );
        assert_eq!(
            Plain.verify(&longest, &"a".repeat(129)),
            Err(CodeChallengeError::TooLong),
        );
    }

    #[test]
    fn test_verifier_character_set() {
        use PkceCodeChallengeMethod::Plain;

        // Every allowed punctuation character is accepted.
        let punctuation = "-._~".repeat(11);
        assert!(Plain.verify(&punctuation, &punctuation).is_ok());

        // Non-ASCII letters are rejected even when the byte length is in range.
        let accented = "é".repeat(43);
        assert_eq!(
            Plain.verify(&accented, &accented),
            Err(CodeChallengeError::InvalidCharacters),
        );
    }
}