1use crate::error::{ConfigError, Error, Result};
2use crate::token_parser::{parse_token, Parts};
3use crate::SecureString;
4use argon2::{
5 password_hash::{PasswordHash, PasswordVerifier},
6 Argon2,
7};
8use base64::engine::general_purpose::URL_SAFE_NO_PAD;
9use base64::Engine;
10use password_hash::PasswordHashString;
11
12#[derive(Clone)]
13pub struct KeyValidator {
14 hash: PasswordHashString,
15 has_checksum: bool,
16 dummy_password: SecureString,
18}
19
/// Outcome of a key verification.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum KeyStatus {
    /// Every check performed on the key passed.
    Valid,
    /// At least one check rejected the key.
    Invalid,
}
28
29impl KeyValidator {
30 const MAX_KEY_LENGTH: usize = 512;
32 const MAX_HASH_LENGTH: usize = 512;
34
35 pub fn new(
36 has_checksum: bool,
37 dummy_key: SecureString,
38 dummy_hash: String,
39 ) -> std::result::Result<KeyValidator, ConfigError> {
40 let hash =
41 PasswordHashString::new(&dummy_hash).map_err(|_| ConfigError::InvalidArgon2Hash)?;
42
43 Ok(KeyValidator {
44 hash,
45 has_checksum,
46 dummy_password: dummy_key,
47 })
48 }
49
50 fn verify_expiry(
51 &self,
52 parts: Parts,
53 expiry_grace_period: std::time::Duration,
54 ) -> Result<KeyStatus> {
55 if let Some(expiry) = parts.expiry_b64 {
56 let decoded = URL_SAFE_NO_PAD
57 .decode(expiry)
58 .or(Err(Error::InvalidFormat))?;
59 let expiry_timestamp =
60 i64::from_be_bytes(decoded.try_into().or(Err(Error::InvalidFormat))?);
61
62 let current_time = chrono::Utc::now().timestamp();
63 let grace_seconds = expiry_grace_period.as_secs() as i64;
64
65 if expiry_timestamp + grace_seconds < current_time {
69 return Ok(KeyStatus::Invalid);
70 }
71 Ok(KeyStatus::Valid)
72 } else {
73 Ok(KeyStatus::Valid)
74 }
75 }
76
77 pub fn verify(
78 &self,
79 provided_key: &str,
80 stored_hash: &str,
81 expiry_grace_period: std::time::Duration,
82 ) -> Result<KeyStatus> {
83 if provided_key.len() > Self::MAX_KEY_LENGTH {
85 self.dummy_load();
86 return Err(Error::InvalidFormat);
87 }
88 if stored_hash.len() > Self::MAX_HASH_LENGTH {
89 self.dummy_load();
90 return Err(Error::InvalidFormat);
91 }
92
93 let token_parts = match parse_token(provided_key.as_bytes(), self.has_checksum) {
94 Ok(token_parts) => token_parts.1,
95 Err(_) => {
96 self.dummy_load();
97 return Ok(KeyStatus::Invalid);
98 }
99 };
100
101 let parsed_hash = match PasswordHash::new(stored_hash) {
104 Ok(h) => h,
105 Err(_) => {
106 self.dummy_load();
107 return Ok(KeyStatus::Invalid);
108 }
109 };
110 let result = Argon2::default()
111 .verify_password(provided_key.as_bytes(), &parsed_hash)
112 .is_ok();
113
114 let argon_result = if result {
115 KeyStatus::Valid
116 } else {
117 KeyStatus::Invalid
118 };
119
120 let expiry_result = self.verify_expiry(token_parts, expiry_grace_period)?;
124
125 match (argon_result, expiry_result) {
126 (KeyStatus::Invalid, _) | (_, KeyStatus::Invalid) => Ok(KeyStatus::Invalid),
127 _ => Ok(KeyStatus::Valid),
128 }
129 }
130 fn dummy_load(&self) {
131 use crate::ExposeSecret;
135 let dummy_bytes = self.dummy_password.expose_secret().as_bytes();
136 parse_token(dummy_bytes, self.has_checksum).ok();
137
138 Argon2::default()
139 .verify_password(dummy_bytes, &self.hash.password_hash())
140 .ok();
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ExposeSecret;
    use crate::{config::HashConfig, hasher::KeyHasher, SecureString};

    /// Grace period used by every test: none.
    const NO_GRACE: std::time::Duration = std::time::Duration::ZERO;

    /// Wraps `raw` in a `SecureString` and hashes it with the default config.
    fn key_and_hash(raw: &str) -> (SecureString, String) {
        let key = SecureString::from(raw.to_string());
        let hasher = KeyHasher::new(HashConfig::default());
        let hash = hasher.hash(&key).unwrap();
        (key, hash)
    }

    /// Dummy key/hash pair handed to the validator under test.
    fn dummy_key_and_hash() -> (SecureString, String) {
        key_and_hash("sk-live-dummy123test")
    }

    /// Validator built with `has_checksum = true` and the dummy pair.
    fn make_validator() -> KeyValidator {
        let (dummy_key, dummy_hash) = dummy_key_and_hash();
        KeyValidator::new(true, dummy_key, dummy_hash).unwrap()
    }

    #[test]
    fn test_verification() {
        let (key, hash) = key_and_hash("sk_live_testkey123");
        let validator = make_validator();

        let matching = validator.verify(key.expose_secret(), hash.as_ref(), NO_GRACE);
        assert_eq!(matching.unwrap(), KeyStatus::Valid);

        let mismatching = validator.verify("wrong_key", hash.as_ref(), NO_GRACE);
        assert_eq!(mismatching.unwrap(), KeyStatus::Invalid);
    }

    #[test]
    fn test_invalid_hash_format() {
        let validator = make_validator();

        let outcome = validator.verify("any_key", "invalid_hash", NO_GRACE);
        assert_eq!(outcome.unwrap(), KeyStatus::Invalid);
    }

    #[test]
    fn test_oversized_key_rejection() {
        let oversized_key = "a".repeat(513);
        let (_valid_key, hash) = key_and_hash("valid_key");
        let validator = make_validator();

        let outcome = validator.verify(&oversized_key, hash.as_ref(), NO_GRACE);
        assert!(matches!(outcome, Err(Error::InvalidFormat)));
    }

    #[test]
    fn test_oversized_hash_rejection() {
        let oversized_hash = "a".repeat(513);
        let validator = make_validator();

        let outcome = validator.verify("valid_key", &oversized_hash, NO_GRACE);
        assert!(matches!(outcome, Err(Error::InvalidFormat)));
    }

    #[test]
    fn test_boundary_key_length() {
        let (_valid_key, hash) = key_and_hash("valid_key");
        let validator = make_validator();

        // Exactly at the limit: must get past the length guard.
        let at_limit = "a".repeat(512);
        assert!(validator.verify(&at_limit, hash.as_ref(), NO_GRACE).is_ok());

        // One byte over the limit: rejected as malformed.
        let over_limit = "a".repeat(513);
        let outcome = validator.verify(&over_limit, hash.as_ref(), NO_GRACE);
        assert!(matches!(outcome, Err(Error::InvalidFormat)));
    }

    #[test]
    fn test_timing_oracle_protection() {
        let (valid_key, valid_hash) = key_and_hash("sk_live_testkey123");
        let validator = make_validator();

        let wrong_key = validator.verify("wrong_key", valid_hash.as_ref(), NO_GRACE);
        assert_eq!(wrong_key.unwrap(), KeyStatus::Invalid);

        // A malformed stored hash is reported as an invalid key, not an error.
        for malformed_hash in ["invalid_hash_format", "not even close to valid"] {
            let outcome = validator.verify(valid_key.expose_secret(), malformed_hash, NO_GRACE);
            assert_eq!(outcome.unwrap(), KeyStatus::Invalid);
        }
    }
}