api_keys_simplified/
validator.rs

1use crate::error::{ConfigError, Error, Result};
2use crate::token_parser::{parse_token, Parts};
3use crate::SecureString;
4use argon2::{
5    password_hash::{PasswordHash, PasswordVerifier},
6    Argon2,
7};
8use base64::engine::general_purpose::URL_SAFE_NO_PAD;
9use base64::Engine;
10use password_hash::PasswordHashString;
11
12#[derive(Clone)]
13pub struct KeyValidator {
14    hash: PasswordHashString,
15    has_checksum: bool,
16    /// Dummy password for timing attack protection (should be a generated API key)
17    dummy_password: SecureString,
18}
19
/// Outcome of checking a presented API key.
///
/// Only two outcomes exist. Every non-error rejection maps to [`KeyStatus::Invalid`]:
/// a hash mismatch, an unparseable token, an unparseable stored hash, or an expiry
/// beyond the grace period.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum KeyStatus {
    /// The key matched the stored hash and is not expired beyond the grace period
    /// (or carries no expiry at all).
    Valid,
    /// The key was rejected.
    Invalid,
}
28
29impl KeyValidator {
30    /// Maximum allowed length for API keys (prevents DoS via oversized inputs)
31    const MAX_KEY_LENGTH: usize = 512;
32    /// Maximum allowed length for password hashes (prevents DoS via malformed hashes)
33    const MAX_HASH_LENGTH: usize = 512;
34
35    pub fn new(
36        has_checksum: bool,
37        dummy_key: SecureString,
38        dummy_hash: String,
39    ) -> std::result::Result<KeyValidator, ConfigError> {
40        let hash =
41            PasswordHashString::new(&dummy_hash).map_err(|_| ConfigError::InvalidArgon2Hash)?;
42
43        Ok(KeyValidator {
44            hash,
45            has_checksum,
46            dummy_password: dummy_key,
47        })
48    }
49
50    fn verify_expiry(
51        &self,
52        parts: Parts,
53        expiry_grace_period: std::time::Duration,
54    ) -> Result<KeyStatus> {
55        if let Some(expiry) = parts.expiry_b64 {
56            let decoded = URL_SAFE_NO_PAD
57                .decode(expiry)
58                .or(Err(Error::InvalidFormat))?;
59            let expiry_timestamp =
60                i64::from_be_bytes(decoded.try_into().or(Err(Error::InvalidFormat))?);
61
62            let current_time = chrono::Utc::now().timestamp();
63            let grace_seconds = expiry_grace_period.as_secs() as i64;
64
65            // Key is invalid if it expired more than grace_period ago
66            // This ensures once a key expires beyond the grace period, it stays expired
67            // even if the clock goes backwards
68            if expiry_timestamp + grace_seconds < current_time {
69                return Ok(KeyStatus::Invalid);
70            }
71            Ok(KeyStatus::Valid)
72        } else {
73            Ok(KeyStatus::Valid)
74        }
75    }
76
77    pub fn verify(
78        &self,
79        provided_key: &str,
80        stored_hash: &str,
81        expiry_grace_period: std::time::Duration,
82    ) -> Result<KeyStatus> {
83        // Input length validation to prevent DoS attacks
84        if provided_key.len() > Self::MAX_KEY_LENGTH {
85            self.dummy_load();
86            return Err(Error::InvalidFormat);
87        }
88        if stored_hash.len() > Self::MAX_HASH_LENGTH {
89            self.dummy_load();
90            return Err(Error::InvalidFormat);
91        }
92
93        let token_parts = match parse_token(provided_key.as_bytes(), self.has_checksum) {
94            Ok(token_parts) => token_parts.1,
95            Err(_) => {
96                self.dummy_load();
97                return Ok(KeyStatus::Invalid);
98            }
99        };
100
101        // Parse the stored hash - if parsing fails, perform dummy verification
102        // to ensure constant timing and prevent user enumeration attacks
103        let parsed_hash = match PasswordHash::new(stored_hash) {
104            Ok(h) => h,
105            Err(_) => {
106                self.dummy_load();
107                return Ok(KeyStatus::Invalid);
108            }
109        };
110        let result = Argon2::default()
111            .verify_password(provided_key.as_bytes(), &parsed_hash)
112            .is_ok();
113
114        let argon_result = if result {
115            KeyStatus::Valid
116        } else {
117            KeyStatus::Invalid
118        };
119
120        // SECURITY: Force evaluation of expiry check BEFORE the match to ensure
121        // constant-time execution. This prevents the compiler from short-circuiting
122        // the expiry check when argon_result is Invalid, which would create a timing oracle.
123        let expiry_result = self.verify_expiry(token_parts, expiry_grace_period)?;
124
125        match (argon_result, expiry_result) {
126            (KeyStatus::Invalid, _) | (_, KeyStatus::Invalid) => Ok(KeyStatus::Invalid),
127            _ => Ok(KeyStatus::Valid),
128        }
129    }
130    fn dummy_load(&self) {
131        // SECURITY: Perform dummy Argon2 verification to match timing of real verification
132        // This prevents timing attacks that could distinguish between "invalid hash format"
133        // and "valid hash but wrong password" errors
134        use crate::ExposeSecret;
135        let dummy_bytes = self.dummy_password.expose_secret().as_bytes();
136        parse_token(dummy_bytes, self.has_checksum).ok();
137
138        Argon2::default()
139            .verify_password(dummy_bytes, &self.hash.password_hash())
140            .ok();
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ExposeSecret;
    use crate::{config::HashConfig, hasher::KeyHasher, SecureString};

    /// Grace period used throughout: expiry is enforced with no slack.
    const NO_GRACE: std::time::Duration = std::time::Duration::ZERO;

    /// Hashes `key` with the default hashing configuration.
    fn hash_secure(key: &SecureString) -> String {
        KeyHasher::new(HashConfig::default()).hash(key).unwrap()
    }

    /// Decoy key and its hash, as required by `KeyValidator::new`.
    fn dummy_key_and_hash() -> (SecureString, String) {
        let key = SecureString::from("sk-live-dummy123test".to_string());
        let hash = hash_secure(&key);
        (key, hash)
    }

    /// Validator with checksums enabled and a freshly generated decoy pair.
    fn validator() -> KeyValidator {
        let (dummy_key, dummy_hash) = dummy_key_and_hash();
        KeyValidator::new(true, dummy_key, dummy_hash).unwrap()
    }

    #[test]
    fn test_verification() {
        let key = SecureString::from("sk_live_testkey123".to_string());
        let stored = hash_secure(&key);
        let v = validator();

        let accepted = v.verify(key.expose_secret(), &stored, NO_GRACE).unwrap();
        assert_eq!(accepted, KeyStatus::Valid);

        let rejected = v.verify("wrong_key", &stored, NO_GRACE).unwrap();
        assert_eq!(rejected, KeyStatus::Invalid);
    }

    #[test]
    fn test_invalid_hash_format() {
        // An unparseable stored hash yields Ok(Invalid) rather than Err, so it cannot be
        // distinguished from a wrong key (prevents timing-based user enumeration).
        let outcome = validator().verify("any_key", "invalid_hash", NO_GRACE);
        assert_eq!(outcome.ok(), Some(KeyStatus::Invalid));
    }

    #[test]
    fn test_oversized_key_rejection() {
        let stored = hash_secure(&SecureString::from("valid_key".to_string()));
        let too_long = "a".repeat(513); // one past MAX_KEY_LENGTH

        let outcome = validator().verify(&too_long, &stored, NO_GRACE);
        assert!(matches!(outcome, Err(Error::InvalidFormat)));
    }

    #[test]
    fn test_oversized_hash_rejection() {
        let too_long = "a".repeat(513); // one past MAX_HASH_LENGTH

        let outcome = validator().verify("valid_key", &too_long, NO_GRACE);
        assert!(matches!(outcome, Err(Error::InvalidFormat)));
    }

    #[test]
    fn test_boundary_key_length() {
        let stored = hash_secure(&SecureString::from("valid_key".to_string()));
        let v = validator();

        // Exactly at the limit: passes the length gate (may still be Invalid).
        let at_limit = "a".repeat(KeyValidator::MAX_KEY_LENGTH);
        assert!(v.verify(&at_limit, &stored, NO_GRACE).is_ok());

        // One byte over the limit: rejected with InvalidFormat.
        let over_limit = "a".repeat(KeyValidator::MAX_KEY_LENGTH + 1);
        let outcome = v.verify(&over_limit, &stored, NO_GRACE);
        assert!(matches!(outcome, Err(Error::InvalidFormat)));
    }

    #[test]
    fn test_timing_oracle_protection() {
        let key = SecureString::from("sk_live_testkey123".to_string());
        let stored = hash_secure(&key);
        let secret: &str = key.expose_secret();
        let v = validator();

        // Wrong key vs. malformed stored hash must be indistinguishable: both Ok(Invalid).
        let cases = [
            ("wrong_key", stored.as_str()),
            (secret, "invalid_hash_format"),
            (secret, "not even close to valid"),
        ];
        for &(provided, stored_hash) in cases.iter() {
            let outcome = v.verify(provided, stored_hash, NO_GRACE);
            assert_eq!(outcome.ok(), Some(KeyStatus::Invalid));
        }
    }
}