Skip to main content

fraiseql_wire/auth/
scram.rs

1//! SCRAM-SHA-256 authentication implementation
2//!
//! Implements SCRAM-SHA-256 (Salted Challenge Response Authentication Mechanism), as
//! defined by RFC 5802 and RFC 7677, for PostgreSQL authentication (PostgreSQL 10+).
5
6use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use hmac::{Hmac, Mac};
8use pbkdf2::pbkdf2;
9use rand::{rngs::OsRng, Rng};
10use sha2::{Digest, Sha256};
11use std::fmt;
12
/// HMAC instantiated with SHA-256; the PRF used by every SCRAM key derivation in this file
/// (PBKDF2 for `SaltedPassword`, and the `Client Key` / `Server Key` / signature HMACs).
type HmacSha256 = Hmac<Sha256>;
14
/// Upper bound on the PBKDF2 iteration count accepted from the server (DoS protection).
///
/// The server picks the `i=` value in its server-first-message. Without a cap, a malicious
/// or misconfigured server could make the client burn seconds (or minutes) inside PBKDF2
/// before the handshake is rejected. One million leaves ample headroom above PostgreSQL's
/// default of 4096 and above the larger counts some deployments configure (e.g. 600,000).
const MAX_SCRAM_ITERATIONS: u32 = 1_000 * 1_000;
22
/// Errors produced while running a SCRAM-SHA-256 exchange.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ScramError {
    /// The server's signature did not verify.
    InvalidServerProof(String),
    /// A server message was malformed or violated the protocol.
    InvalidServerMessage(String),
    /// UTF-8 encoding/decoding error (this file also reports HMAC key-setup failures with it).
    Utf8Error(String),
    /// A base64-encoded field could not be decoded.
    Base64Error(String),
}

impl fmt::Display for ScramError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant renders as "<category>: <detail>".
        let (category, detail) = match self {
            Self::InvalidServerProof(detail) => ("invalid server proof", detail),
            Self::InvalidServerMessage(detail) => ("invalid server message", detail),
            Self::Utf8Error(detail) => ("UTF-8 error", detail),
            Self::Base64Error(detail) => ("Base64 error", detail),
        };
        write!(f, "{category}: {detail}")
    }
}

impl std::error::Error for ScramError {}
49
50/// Internal state needed for SCRAM authentication
51#[derive(Clone, Debug)]
52pub struct ScramState {
53    /// Combined authentication message (for verification)
54    auth_message: Vec<u8>,
55    /// Server key (for verification calculation)
56    server_key: Vec<u8>,
57}
58
59/// SCRAM-SHA-256 client implementation
60pub struct ScramClient {
61    username: String,
62    password: String,
63    nonce: String,
64}
65
66impl ScramClient {
67    /// Create a new SCRAM client
68    pub fn new(username: String, password: String) -> Self {
69        // SECURITY: OsRng guarantees OS-level entropy for SCRAM nonces.
70        let mut rng = OsRng;
71        let nonce_bytes: Vec<u8> = (0..24).map(|_| rng.gen()).collect();
72        let nonce = BASE64.encode(&nonce_bytes);
73
74        Self {
75            username,
76            password,
77            nonce,
78        }
79    }
80
81    /// Generate client first message (no proof)
82    pub fn client_first(&self) -> String {
83        // RFC 5802 format: gs2-header client-first-message-bare
84        // gs2-header = "n,," (n = no channel binding, empty authorization identity)
85        // client-first-message-bare = "n=<username>,r=<nonce>"
86        // RFC 5802 §5.1: username must have ',' escaped as '=2C' and '=' escaped as '=3D'.
87        let escaped_username = self.username.replace('=', "=3D").replace(',', "=2C");
88        format!("n,,n={},r={}", escaped_username, self.nonce)
89    }
90
91    /// Process server first message and generate client final message
92    ///
93    /// Returns (`client_final_message`, `internal_state`)
94    ///
95    /// # Errors
96    ///
97    /// Returns [`ScramError::InvalidServerMessage`] if the server message cannot be parsed,
98    /// the server nonce does not start with the client nonce, or the iteration count is
99    /// invalid or exceeds `MAX_SCRAM_ITERATIONS`. Returns [`ScramError::Base64Error`] if
100    /// the salt is not valid base64.
101    pub fn client_final(&mut self, server_first: &str) -> Result<(String, ScramState), ScramError> {
102        // Parse server first message: r=<client_nonce><server_nonce>,s=<salt>,i=<iterations>
103        let (server_nonce, salt, iterations) = parse_server_first(server_first)?;
104
105        // Verify server nonce starts with our client nonce
106        if !server_nonce.starts_with(&self.nonce) {
107            return Err(ScramError::InvalidServerMessage(
108                "server nonce doesn't contain client nonce".to_string(),
109            ));
110        }
111
112        // Decode salt and iterations
113        let salt_bytes = BASE64
114            .decode(&salt)
115            .map_err(|_| ScramError::Base64Error("invalid salt encoding".to_string()))?;
116        let iterations = iterations
117            .parse::<u32>()
118            .map_err(|_| ScramError::InvalidServerMessage("invalid iteration count".to_string()))?;
119
120        // SECURITY: Guard against server-supplied iteration counts large enough to
121        // cause a denial-of-service via excessive PBKDF2 CPU time.
122        if iterations > MAX_SCRAM_ITERATIONS {
123            return Err(ScramError::InvalidServerMessage(format!(
124                "server iteration count {iterations} exceeds maximum of {MAX_SCRAM_ITERATIONS}"
125            )));
126        }
127
128        // Build channel binding (no channel binding for SCRAM-SHA-256)
129        let channel_binding = BASE64.encode(b"n,,");
130
131        // Build client final without proof
132        let client_final_without_proof = format!("c={},r={}", channel_binding, server_nonce);
133
134        // Build auth message for signature calculation.
135        // client-first-message-bare is "n=<escaped_username>,r=<nonce>" (without gs2-header).
136        // SECURITY: Must use the RFC 5802 §5.1-escaped username (same as client_first()),
137        // not the raw username — otherwise an attacker who controls ',' or '=' in a username
138        // can inject arbitrary SCRAM attributes and break authentication.
139        let escaped_username = self.username.replace('=', "=3D").replace(',', "=2C");
140        let client_first_bare = format!("n={},r={}", escaped_username, self.nonce);
141        let auth_message = format!(
142            "{},{},{}",
143            client_first_bare, server_first, client_final_without_proof
144        );
145
146        // Calculate proof
147        let proof = calculate_client_proof(
148            &self.password,
149            &salt_bytes,
150            iterations,
151            auth_message.as_bytes(),
152        )?;
153
154        // Calculate server signature for later verification
155        let server_key = calculate_server_key(&self.password, &salt_bytes, iterations)?;
156
157        // Build client final message
158        let client_final = format!("{},p={}", client_final_without_proof, BASE64.encode(&proof));
159
160        let state = ScramState {
161            auth_message: auth_message.into_bytes(),
162            server_key,
163        };
164
165        Ok((client_final, state))
166    }
167
168    /// Verify server final message and confirm authentication
169    ///
170    /// # Errors
171    ///
172    /// Returns `ScramError::InvalidServerMessage` if the server final message is malformed.
173    /// Returns `ScramError::Base64Error` if the server signature is not valid base64.
174    /// Returns `ScramError::AuthenticationFailed` if the server signature does not match.
175    pub fn verify_server_final(
176        &self,
177        server_final: &str,
178        state: &ScramState,
179    ) -> Result<(), ScramError> {
180        // Parse server final: v=<server_signature>
181        let server_sig_encoded = server_final
182            .strip_prefix("v=")
183            .ok_or_else(|| ScramError::InvalidServerMessage("missing 'v=' prefix".to_string()))?;
184
185        let server_signature = BASE64.decode(server_sig_encoded).map_err(|_| {
186            ScramError::Base64Error("invalid server signature encoding".to_string())
187        })?;
188
189        // Calculate expected server signature
190        let expected_signature =
191            calculate_server_signature(&state.server_key, &state.auth_message)?;
192
193        // Constant-time comparison
194        if constant_time_compare(&server_signature, &expected_signature) {
195            Ok(())
196        } else {
197            Err(ScramError::InvalidServerProof(
198                "server signature verification failed".to_string(),
199            ))
200        }
201    }
202}
203
204/// Parse server first message format: r=<nonce>,s=<salt>,i=<iterations>
205fn parse_server_first(msg: &str) -> Result<(String, String, String), ScramError> {
206    let mut nonce = String::new();
207    let mut salt = String::new();
208    let mut iterations = String::new();
209
210    for part in msg.split(',') {
211        if let Some(value) = part.strip_prefix("r=") {
212            nonce = value.to_string();
213        } else if let Some(value) = part.strip_prefix("s=") {
214            salt = value.to_string();
215        } else if let Some(value) = part.strip_prefix("i=") {
216            iterations = value.to_string();
217        }
218    }
219
220    if nonce.is_empty() || salt.is_empty() || iterations.is_empty() {
221        return Err(ScramError::InvalidServerMessage(
222            "missing required fields in server first message".to_string(),
223        ));
224    }
225
226    Ok((nonce, salt, iterations))
227}
228
229/// Calculate SCRAM client proof
230fn calculate_client_proof(
231    password: &str,
232    salt: &[u8],
233    iterations: u32,
234    auth_message: &[u8],
235) -> Result<Vec<u8>, ScramError> {
236    // SaltedPassword := PBKDF2(password, salt, iterations, HMAC-SHA256)
237    let password_bytes = password.as_bytes();
238    let mut salted_password = vec![0u8; 32]; // SHA256 produces 32 bytes
239    let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
240
241    // ClientKey := HMAC(SaltedPassword, "Client Key")
242    let mut client_key_hmac = HmacSha256::new_from_slice(&salted_password)
243        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
244    client_key_hmac.update(b"Client Key");
245    let client_key = client_key_hmac.finalize().into_bytes();
246
247    // StoredKey := SHA256(ClientKey)
248    let stored_key = Sha256::digest(client_key.to_vec().as_slice());
249
250    // ClientSignature := HMAC(StoredKey, AuthMessage)
251    let mut client_sig_hmac = HmacSha256::new_from_slice(&stored_key)
252        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
253    client_sig_hmac.update(auth_message);
254    let client_signature = client_sig_hmac.finalize().into_bytes();
255
256    // ClientProof := ClientKey XOR ClientSignature
257    let mut proof = client_key.to_vec();
258    for (proof_byte, sig_byte) in proof.iter_mut().zip(client_signature.iter()) {
259        *proof_byte ^= sig_byte;
260    }
261
262    Ok(proof.clone())
263}
264
265/// Calculate server key for server signature verification
266fn calculate_server_key(
267    password: &str,
268    salt: &[u8],
269    iterations: u32,
270) -> Result<Vec<u8>, ScramError> {
271    // SaltedPassword := PBKDF2(password, salt, iterations, HMAC-SHA256)
272    let password_bytes = password.as_bytes();
273    let mut salted_password = vec![0u8; 32];
274    let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
275
276    // ServerKey := HMAC(SaltedPassword, "Server Key")
277    let mut server_key_hmac = HmacSha256::new_from_slice(&salted_password)
278        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
279    server_key_hmac.update(b"Server Key");
280
281    Ok(server_key_hmac.finalize().into_bytes().to_vec())
282}
283
284/// Calculate server signature for verification
285fn calculate_server_signature(
286    server_key: &[u8],
287    auth_message: &[u8],
288) -> Result<Vec<u8>, ScramError> {
289    let mut hmac = HmacSha256::new_from_slice(server_key)
290        .map_err(|_| ScramError::Utf8Error("invalid HMAC key for server signature".to_string()))?;
291    hmac.update(auth_message);
292    Ok(hmac.finalize().into_bytes().to_vec())
293}
294
295/// Constant-time comparison to prevent timing attacks.
296///
297/// Uses the `subtle` crate for verified constant-time operations.
298fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
299    use subtle::ConstantTimeEq;
300    a.ct_eq(b).into()
301}
302
303#[cfg(test)]
304mod tests {
305    #![allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable
306    use super::*;
307
308    #[test]
309    fn test_scram_client_creation() {
310        let client = ScramClient::new("user".to_string(), "password".to_string());
311        assert_eq!(client.username, "user");
312        assert_eq!(client.password, "password");
313        assert!(!client.nonce.is_empty());
314    }
315
316    #[test]
317    fn test_client_first_message_format() {
318        let client = ScramClient::new("alice".to_string(), "secret".to_string());
319        let first = client.client_first();
320
321        // RFC 5802 format: "n,,n=<username>,r=<nonce>"
322        assert!(first.starts_with("n,,n=alice,r="));
323        assert!(first.len() > 20);
324    }
325
326    #[test]
327    fn test_parse_server_first_valid() {
328        let server_first = "r=client_nonce_server_nonce,s=aW1hZ2luYXJ5c2FsdA==,i=4096";
329        let (nonce, salt, iterations) = parse_server_first(server_first).unwrap();
330
331        assert_eq!(nonce, "client_nonce_server_nonce");
332        assert_eq!(salt, "aW1hZ2luYXJ5c2FsdA==");
333        assert_eq!(iterations, "4096");
334    }
335
336    #[test]
337    fn test_parse_server_first_invalid() {
338        let server_first = "r=nonce,s=salt"; // missing iterations
339        let result = parse_server_first(server_first);
340        assert!(
341            matches!(result, Err(ScramError::InvalidServerMessage(_))),
342            "expected InvalidServerMessage error, got: {result:?}"
343        );
344    }
345
346    #[test]
347    fn test_constant_time_compare_equal() {
348        let a = b"test_value";
349        let b_arr = b"test_value";
350        assert!(constant_time_compare(a, b_arr));
351    }
352
353    #[test]
354    fn test_constant_time_compare_different() {
355        let a = b"test_value";
356        let b_arr = b"test_wrong";
357        assert!(!constant_time_compare(a, b_arr));
358    }
359
360    #[test]
361    fn test_constant_time_compare_different_length() {
362        let a = b"test";
363        let b_arr = b"test_longer";
364        assert!(!constant_time_compare(a, b_arr));
365    }
366
367    #[test]
368    fn test_scram_client_final_flow() {
369        let mut client = ScramClient::new("user".to_string(), "password".to_string());
370        let _client_first = client.client_first();
371
372        // Simulate server response
373        let server_nonce = format!("{}server_nonce_part", client.nonce);
374        let server_first = format!("r={},s={},i=4096", server_nonce, BASE64.encode(b"salty"));
375
376        // Should succeed with valid format
377        let result = client.client_final(&server_first);
378        let (client_final, state) = result.unwrap_or_else(|e| {
379            panic!("expected Ok for client_final with valid server message: {e}")
380        });
381        assert!(client_final.starts_with("c="));
382        assert!(!state.auth_message.is_empty());
383    }
384
385    #[test]
386    fn test_scram_iteration_count_too_high_is_rejected() {
387        // H3: A server-supplied i= value above MAX_SCRAM_ITERATIONS must be rejected
388        // to prevent PBKDF2-based denial-of-service.
389        let mut client = ScramClient::new("user".to_string(), "password".to_string());
390        let _client_first = client.client_first();
391
392        let server_nonce = format!("{}server_nonce_part", client.nonce);
393        let excessive_iterations = MAX_SCRAM_ITERATIONS + 1;
394        let server_first = format!(
395            "r={},s={},i={}",
396            server_nonce,
397            BASE64.encode(b"salty"),
398            excessive_iterations
399        );
400
401        let result = client.client_final(&server_first);
402        assert!(
403            matches!(result, Err(ScramError::InvalidServerMessage(_))),
404            "expected InvalidServerMessage for excessive iterations, got: {result:?}"
405        );
406    }
407
408    #[test]
409    fn test_scram_iteration_count_at_limit_is_accepted() {
410        // Exactly MAX_SCRAM_ITERATIONS must be accepted.
411        let mut client = ScramClient::new("user".to_string(), "password".to_string());
412        let _client_first = client.client_first();
413
414        let server_nonce = format!("{}server_nonce_part", client.nonce);
415        let server_first = format!(
416            "r={},s={},i={}",
417            server_nonce,
418            BASE64.encode(b"salty"),
419            MAX_SCRAM_ITERATIONS
420        );
421
422        // Should not fail on the iteration count check (may fail for other reasons if any)
423        let result = client.client_final(&server_first);
424        // We only care that it didn't fail with an iteration-count error
425        if let Err(ScramError::InvalidServerMessage(msg)) = &result {
426            assert!(
427                !msg.contains("iteration count"),
428                "unexpected iteration-count rejection at limit: {msg}"
429            );
430        }
431    }
432
433    #[test]
434    fn test_scram_username_escaping_in_auth_message() {
435        // H4: The auth message must use the RFC 5802-escaped username, not the raw one.
436        // A username containing ',' or '=' must be escaped in client_first_bare.
437        let mut client = ScramClient::new("user,admin=evil".to_string(), "password".to_string());
438        let client_first = client.client_first();
439        // client_first should have escaped username
440        assert!(
441            client_first.contains("user=2Cadmin=3Devil"),
442            "client_first should escape ',' and '=' in username, got: {client_first}"
443        );
444
445        // client_final should use the same escaped username in the auth message
446        let server_nonce = format!("{}server_nonce_part", client.nonce);
447        let server_first = format!("r={},s={},i=4096", server_nonce, BASE64.encode(b"salty"));
448
449        let result = client.client_final(&server_first);
450        let (_client_final, state) =
451            result.unwrap_or_else(|e| panic!("expected Ok for escaped-username client_final: {e}"));
452
453        // The auth message must contain the escaped username, not the raw one
454        let auth_message = String::from_utf8(state.auth_message).unwrap();
455        assert!(
456            auth_message.contains("user=2Cadmin=3Devil"),
457            "auth_message should contain escaped username, got: {auth_message}"
458        );
459        assert!(
460            !auth_message.contains("user,admin=evil"),
461            "auth_message must NOT contain raw (unescaped) username, got: {auth_message}"
462        );
463    }
464}