api_keys_simplified/
validator.rs

1use crate::error::{ConfigError, Error, Result};
2use crate::token_parser::{parse_token, Parts};
3use crate::HashConfig;
4use argon2::{
5    password_hash::{PasswordHash, PasswordVerifier},
6    Argon2,
7};
8use base64::engine::general_purpose::URL_SAFE_NO_PAD;
9use base64::Engine;
10use password_hash::PasswordHashString;
11
12#[derive(Clone)]
13pub struct KeyValidator {
14    hash: PasswordHashString,
15    has_checksum: bool,
16}
17
/// Outcome of verifying an API key with [`KeyValidator::verify`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum KeyStatus {
    /// The key matched its stored hash and, if it carries an expiry, has not expired yet.
    Valid,
    /// The key was rejected: wrong key / hash mismatch, an unparseable key or stored hash,
    /// or an expiry that has already passed.
    Invalid,
}
26
27impl KeyValidator {
28    /// Maximum allowed length for API keys (prevents DoS via oversized inputs)
29    const MAX_KEY_LENGTH: usize = 512;
30    /// Maximum allowed length for password hashes (prevents DoS via malformed hashes)
31    const MAX_HASH_LENGTH: usize = 512;
32
33    pub fn new(
34        hash_config: &HashConfig,
35        has_checksum: bool,
36    ) -> std::result::Result<KeyValidator, ConfigError> {
37        let dummy_hash = format!("$argon2id$v=19$m={},t={},p={}$0bJKH8iokgID0PWXnrsXvw$oef42xfOKBQMkCpvoQTeVHLhsYf+EQWMc2u4Ebn1MUo", hash_config.memory_cost(), hash_config.time_cost(), hash_config.parallelism());
38        let hash =
39            PasswordHashString::new(&dummy_hash).map_err(|_| ConfigError::InvalidArgon2Hash)?;
40
41        Ok(KeyValidator { hash, has_checksum })
42    }
43
44    fn verify_expiry(&self, parts: Parts) -> Result<KeyStatus> {
45        if let Some(expiry) = parts.expiry_b64 {
46            let decoded = URL_SAFE_NO_PAD
47                .decode(expiry)
48                .or(Err(Error::InvalidFormat))?;
49            let expiry = i64::from_be_bytes(decoded.try_into().or(Err(Error::InvalidFormat))?);
50
51            // TODO(ARCHITECTURE): time libs are platform dependent.
52            // We should set an `infra` layer and abstract
53            // out these libs.
54            if chrono::Utc::now().timestamp() <= expiry {
55                Ok(KeyStatus::Valid)
56            } else {
57                Ok(KeyStatus::Invalid)
58            }
59        } else {
60            Ok(KeyStatus::Valid)
61        }
62    }
63
64    pub fn verify(&self, provided_key: &str, stored_hash: &str) -> Result<KeyStatus> {
65        // Input length validation to prevent DoS attacks
66        if provided_key.len() > Self::MAX_KEY_LENGTH {
67            self.dummy_load();
68            return Err(Error::InvalidFormat);
69        }
70        if stored_hash.len() > Self::MAX_HASH_LENGTH {
71            self.dummy_load();
72            return Err(Error::InvalidFormat);
73        }
74
75        let token_parts = match parse_token(provided_key.as_bytes(), self.has_checksum) {
76            Ok(token_parts) => token_parts.1,
77            Err(_) => {
78                self.dummy_load();
79                return Ok(KeyStatus::Invalid);
80            }
81        };
82
83        // Parse the stored hash - if parsing fails, perform dummy verification
84        // to ensure constant timing and prevent user enumeration attacks
85        let parsed_hash = match PasswordHash::new(stored_hash) {
86            Ok(h) => h,
87            Err(_) => {
88                self.dummy_load();
89                return Ok(KeyStatus::Invalid);
90            }
91        };
92        let result = Argon2::default()
93            .verify_password(provided_key.as_bytes(), &parsed_hash)
94            .is_ok();
95
96        let argon_result = if result {
97            KeyStatus::Valid
98        } else {
99            KeyStatus::Invalid
100        };
101
102        // SECURITY: Force evaluation of expiry check BEFORE the match to ensure
103        // constant-time execution. This prevents the compiler from short-circuiting
104        // the expiry check when argon_result is Invalid, which would create a timing oracle.
105        let expiry_result = self.verify_expiry(token_parts)?;
106
107        match (argon_result, expiry_result) {
108            (KeyStatus::Invalid, _) | (_, KeyStatus::Invalid) => Ok(KeyStatus::Invalid),
109            _ => Ok(KeyStatus::Valid),
110        }
111    }
112    fn dummy_load(&self) {
113        // SECURITY: Perform dummy Argon2 verification to match timing of real verification
114        // This prevents timing attacks that could distinguish between "invalid hash format"
115        // and "valid hash but wrong password" errors
116        let dummy_password =
117            b"text-v1-test-okphUY-aqllb-qHoZDC9mVlm5sY9lvmm.AAAAAGk2Mvg.a54368d6331bf42dc18c";
118        parse_token(dummy_password, self.has_checksum).ok();
119
120        Argon2::default()
121            .verify_password(dummy_password, &self.hash.password_hash())
122            .ok();
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ExposeSecret;
    use crate::{config::HashConfig, hasher::KeyHasher, SecureString};

    /// Validator under test: default Argon2 parameters, checksum flag enabled.
    fn default_validator() -> KeyValidator {
        KeyValidator::new(&HashConfig::default(), true).unwrap()
    }

    #[test]
    fn test_verification() {
        let key = SecureString::from("sk_live_testkey123".to_string());
        let hasher = KeyHasher::new(HashConfig::default());
        let hash = hasher.hash(&key).unwrap();
        let validator = default_validator();

        // The genuine key verifies; an unrelated key does not.
        let genuine = validator
            .verify(key.expose_secret(), hash.as_ref())
            .unwrap();
        assert_eq!(genuine, KeyStatus::Valid);

        let forged = validator.verify("wrong_key", hash.as_ref()).unwrap();
        assert_eq!(forged, KeyStatus::Invalid);
    }

    #[test]
    fn test_invalid_hash_format() {
        // After timing oracle fix: invalid hash format returns Ok(Invalid) instead of Err
        // to prevent timing-based user enumeration attacks
        let outcome = default_validator().verify("any_key", "invalid_hash");
        assert_eq!(outcome.ok(), Some(KeyStatus::Invalid));
    }

    #[test]
    fn test_oversized_key_rejection() {
        let valid_key = SecureString::from("valid_key".to_string());
        let hasher = KeyHasher::new(HashConfig::default());
        let hash = hasher.hash(&valid_key).unwrap();

        // One byte past the key limit must be refused before any parsing.
        let oversized_key = "a".repeat(KeyValidator::MAX_KEY_LENGTH + 1);
        let outcome = default_validator().verify(&oversized_key, hash.as_ref());
        assert!(matches!(outcome, Err(Error::InvalidFormat)));
    }

    #[test]
    fn test_oversized_hash_rejection() {
        // One byte past the stored-hash limit must be refused as well.
        let oversized_hash = "a".repeat(KeyValidator::MAX_HASH_LENGTH + 1);
        let outcome = default_validator().verify("valid_key", &oversized_hash);
        assert!(matches!(outcome, Err(Error::InvalidFormat)));
    }

    #[test]
    fn test_boundary_key_length() {
        let valid_key = SecureString::from("valid_key".to_string());
        let hasher = KeyHasher::new(HashConfig::default());
        let hash = hasher.hash(&valid_key).unwrap();
        let validator = default_validator();

        // Exactly at the limit: the length guard lets it through (no Err).
        let at_limit = "a".repeat(KeyValidator::MAX_KEY_LENGTH);
        assert!(validator.verify(&at_limit, hash.as_ref()).is_ok());

        // One byte over the limit: rejected with InvalidFormat.
        let over_limit = "a".repeat(KeyValidator::MAX_KEY_LENGTH + 1);
        let outcome = validator.verify(&over_limit, hash.as_ref());
        assert!(matches!(outcome, Err(Error::InvalidFormat)));
    }

    #[test]
    fn test_timing_oracle_protection() {
        let valid_key = SecureString::from("sk_live_testkey123".to_string());
        let hasher = KeyHasher::new(HashConfig::default());
        let valid_hash = hasher.hash(&valid_key).unwrap();
        let validator = default_validator();

        // A wrong key against a real hash, and the right key against malformed hashes,
        // must all be reported identically as Ok(Invalid).
        let cases: [(&str, &str); 3] = [
            ("wrong_key", valid_hash.as_ref()),
            (valid_key.expose_secret(), "invalid_hash_format"),
            (valid_key.expose_secret(), "not even close to valid"),
        ];
        for &(presented, stored) in cases.iter() {
            assert_eq!(
                validator.verify(presented, stored).ok(),
                Some(KeyStatus::Invalid)
            );
        }
    }
}