Skip to main content

fraiseql_wire/auth/
scram.rs

1//! SCRAM-SHA-256 authentication implementation
2//!
3//! Implements the SCRAM-SHA-256 (Salted Challenge Response Authentication Mechanism)
//! as defined in RFC 5802 (SCRAM) and RFC 7677 (SHA-256 profile), for PostgreSQL
//! authentication (Postgres 10+).
5
6use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use hmac::{Hmac, Mac};
8use pbkdf2::pbkdf2;
9use rand::{rngs::OsRng, Rng};
10use sha2::{Digest, Sha256};
11use std::fmt;
12
/// HMAC-SHA-256: the PRF used for PBKDF2 and for every HMAC step in this module.
type HmacSha256 = Hmac<Sha256>;
14
/// Errors that can occur while running a SCRAM-SHA-256 exchange.
#[derive(Clone, Debug)]
pub enum ScramError {
    /// The server's final signature did not match the expected value.
    InvalidServerProof(String),
    /// A server message was malformed or failed a protocol check.
    InvalidServerMessage(String),
    /// UTF-8 encoding/decoding error (this module also reports HMAC key errors with it).
    Utf8Error(String),
    /// A Base64-encoded field could not be decoded.
    Base64Error(String),
}

impl fmt::Display for ScramError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the human-readable label for the variant, then render "<label>: <detail>".
        let (label, detail) = match self {
            Self::InvalidServerProof(detail) => ("invalid server proof", detail),
            Self::InvalidServerMessage(detail) => ("invalid server message", detail),
            Self::Utf8Error(detail) => ("UTF-8 error", detail),
            Self::Base64Error(detail) => ("Base64 error", detail),
        };
        write!(f, "{}: {}", label, detail)
    }
}

impl std::error::Error for ScramError {}
40
/// Internal state needed for SCRAM authentication.
///
/// Produced by [`ScramClient::client_final`] and consumed by
/// [`ScramClient::verify_server_final`] to check the server signature.
#[derive(Clone)]
pub struct ScramState {
    /// Combined authentication message (for verification)
    auth_message: Vec<u8>,
    /// Server key (for verification calculation)
    server_key: Vec<u8>,
}

impl fmt::Debug for ScramState {
    /// Debug output that never reveals `server_key`.
    ///
    /// The server key is derived from the password and, together with the auth
    /// message, is enough to forge the server signature, so it must not end up
    /// in logs. The auth message is shown as (lossy) text for readability.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ScramState")
            .field("auth_message", &String::from_utf8_lossy(&self.auth_message))
            .field("server_key", &format_args!("<redacted>"))
            .finish()
    }
}
49
/// Client side of a single SCRAM-SHA-256 exchange.
///
/// Holds the login credentials together with the client nonce chosen at
/// construction time.
pub struct ScramClient {
    /// Base64-encoded random client nonce; the server's combined nonce must start with it.
    nonce: String,
    /// Login name, sent (escaped) in the client-first message.
    username: String,
    /// Plaintext password from which the SCRAM keys are derived.
    password: String,
}
56
57impl ScramClient {
58    /// Create a new SCRAM client
59    pub fn new(username: String, password: String) -> Self {
60        // SECURITY: OsRng guarantees OS-level entropy for SCRAM nonces.
61        let mut rng = OsRng;
62        let nonce_bytes: Vec<u8> = (0..24).map(|_| rng.gen()).collect();
63        let nonce = BASE64.encode(&nonce_bytes);
64
65        Self {
66            username,
67            password,
68            nonce,
69        }
70    }
71
72    /// Generate client first message (no proof)
73    pub fn client_first(&self) -> String {
74        // RFC 5802 format: gs2-header client-first-message-bare
75        // gs2-header = "n,," (n = no channel binding, empty authorization identity)
76        // client-first-message-bare = "n=<username>,r=<nonce>"
77        // RFC 5802 ยง5.1: username must have ',' escaped as '=2C' and '=' escaped as '=3D'.
78        let escaped_username = self.username.replace('=', "=3D").replace(',', "=2C");
79        format!("n,,n={},r={}", escaped_username, self.nonce)
80    }
81
82    /// Process server first message and generate client final message
83    ///
84    /// Returns (`client_final_message`, `internal_state`)
85    ///
86    /// # Errors
87    ///
88    /// Returns `ScramError::InvalidServerMessage` if the server message is malformed or the nonce is invalid.
89    /// Returns `ScramError::Base64Error` if the salt encoding is invalid.
90    /// Returns `ScramError::Utf8Error` if HMAC key derivation fails.
91    pub fn client_final(&mut self, server_first: &str) -> Result<(String, ScramState), ScramError> {
92        // Parse server first message: r=<client_nonce><server_nonce>,s=<salt>,i=<iterations>
93        let (server_nonce, salt, iterations) = parse_server_first(server_first)?;
94
95        // Verify server nonce starts with our client nonce
96        if !server_nonce.starts_with(&self.nonce) {
97            return Err(ScramError::InvalidServerMessage(
98                "server nonce doesn't contain client nonce".to_string(),
99            ));
100        }
101
102        // Decode salt and iterations
103        let salt_bytes = BASE64
104            .decode(&salt)
105            .map_err(|_| ScramError::Base64Error("invalid salt encoding".to_string()))?;
106        let iterations = iterations
107            .parse::<u32>()
108            .map_err(|_| ScramError::InvalidServerMessage("invalid iteration count".to_string()))?;
109
110        // Build channel binding (no channel binding for SCRAM-SHA-256)
111        let channel_binding = BASE64.encode(b"n,,");
112
113        // Build client final without proof
114        let client_final_without_proof = format!("c={},r={}", channel_binding, server_nonce);
115
116        // Build auth message for signature calculation
117        // client-first-message-bare is "n=<username>,r=<nonce>" (without gs2-header)
118        let client_first_bare = format!("n={},r={}", self.username, self.nonce);
119        let auth_message = format!(
120            "{},{},{}",
121            client_first_bare, server_first, client_final_without_proof
122        );
123
124        // Calculate proof
125        let proof = calculate_client_proof(
126            &self.password,
127            &salt_bytes,
128            iterations,
129            auth_message.as_bytes(),
130        )?;
131
132        // Calculate server signature for later verification
133        let server_key = calculate_server_key(&self.password, &salt_bytes, iterations)?;
134
135        // Build client final message
136        let client_final = format!("{},p={}", client_final_without_proof, BASE64.encode(&proof));
137
138        let state = ScramState {
139            auth_message: auth_message.into_bytes(),
140            server_key,
141        };
142
143        Ok((client_final, state))
144    }
145
146    /// Verify server final message and confirm authentication
147    ///
148    /// # Errors
149    ///
150    /// Returns `ScramError::InvalidServerMessage` if the server final message is missing the `v=` prefix.
151    /// Returns `ScramError::Base64Error` if the server signature encoding is invalid.
152    /// Returns `ScramError::InvalidServerProof` if the server signature does not match.
153    pub fn verify_server_final(
154        &self,
155        server_final: &str,
156        state: &ScramState,
157    ) -> Result<(), ScramError> {
158        // Parse server final: v=<server_signature>
159        let server_sig_encoded = server_final
160            .strip_prefix("v=")
161            .ok_or_else(|| ScramError::InvalidServerMessage("missing 'v=' prefix".to_string()))?;
162
163        let server_signature = BASE64.decode(server_sig_encoded).map_err(|_| {
164            ScramError::Base64Error("invalid server signature encoding".to_string())
165        })?;
166
167        // Calculate expected server signature
168        let expected_signature =
169            calculate_server_signature(&state.server_key, &state.auth_message)?;
170
171        // Constant-time comparison
172        if constant_time_compare(&server_signature, &expected_signature) {
173            Ok(())
174        } else {
175            Err(ScramError::InvalidServerProof(
176                "server signature verification failed".to_string(),
177            ))
178        }
179    }
180}
181
182/// Parse server first message format: r=<nonce>,s=<salt>,i=<iterations>
183fn parse_server_first(msg: &str) -> Result<(String, String, String), ScramError> {
184    let mut nonce = String::new();
185    let mut salt = String::new();
186    let mut iterations = String::new();
187
188    for part in msg.split(',') {
189        if let Some(value) = part.strip_prefix("r=") {
190            nonce = value.to_string();
191        } else if let Some(value) = part.strip_prefix("s=") {
192            salt = value.to_string();
193        } else if let Some(value) = part.strip_prefix("i=") {
194            iterations = value.to_string();
195        }
196    }
197
198    if nonce.is_empty() || salt.is_empty() || iterations.is_empty() {
199        return Err(ScramError::InvalidServerMessage(
200            "missing required fields in server first message".to_string(),
201        ));
202    }
203
204    Ok((nonce, salt, iterations))
205}
206
207/// Calculate SCRAM client proof
208fn calculate_client_proof(
209    password: &str,
210    salt: &[u8],
211    iterations: u32,
212    auth_message: &[u8],
213) -> Result<Vec<u8>, ScramError> {
214    // SaltedPassword := PBKDF2(password, salt, iterations, HMAC-SHA256)
215    let password_bytes = password.as_bytes();
216    let mut salted_password = vec![0u8; 32]; // SHA256 produces 32 bytes
217    let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
218
219    // ClientKey := HMAC(SaltedPassword, "Client Key")
220    let mut client_key_hmac = HmacSha256::new_from_slice(&salted_password)
221        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
222    client_key_hmac.update(b"Client Key");
223    let client_key = client_key_hmac.finalize().into_bytes();
224
225    // StoredKey := SHA256(ClientKey)
226    let stored_key = Sha256::digest(client_key.to_vec().as_slice());
227
228    // ClientSignature := HMAC(StoredKey, AuthMessage)
229    let mut client_sig_hmac = HmacSha256::new_from_slice(&stored_key)
230        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
231    client_sig_hmac.update(auth_message);
232    let client_signature = client_sig_hmac.finalize().into_bytes();
233
234    // ClientProof := ClientKey XOR ClientSignature
235    let mut proof = client_key.to_vec();
236    for (proof_byte, sig_byte) in proof.iter_mut().zip(client_signature.iter()) {
237        *proof_byte ^= sig_byte;
238    }
239
240    Ok(proof.clone())
241}
242
243/// Calculate server key for server signature verification
244fn calculate_server_key(
245    password: &str,
246    salt: &[u8],
247    iterations: u32,
248) -> Result<Vec<u8>, ScramError> {
249    // SaltedPassword := PBKDF2(password, salt, iterations, HMAC-SHA256)
250    let password_bytes = password.as_bytes();
251    let mut salted_password = vec![0u8; 32];
252    let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
253
254    // ServerKey := HMAC(SaltedPassword, "Server Key")
255    let mut server_key_hmac = HmacSha256::new_from_slice(&salted_password)
256        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
257    server_key_hmac.update(b"Server Key");
258
259    Ok(server_key_hmac.finalize().into_bytes().to_vec())
260}
261
262/// Calculate server signature for verification
263fn calculate_server_signature(
264    server_key: &[u8],
265    auth_message: &[u8],
266) -> Result<Vec<u8>, ScramError> {
267    let mut hmac = HmacSha256::new_from_slice(server_key)
268        .map_err(|_| ScramError::Utf8Error("invalid HMAC key for server signature".to_string()))?;
269    hmac.update(auth_message);
270    Ok(hmac.finalize().into_bytes().to_vec())
271}
272
273/// Constant-time comparison to prevent timing attacks.
274///
275/// Uses the `subtle` crate for verified constant-time operations.
276fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
277    use subtle::ConstantTimeEq;
278    a.ct_eq(b).into()
279}
280
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable
    use super::*;

    /// Shorthand for building a client from string slices.
    fn make_client(user: &str, pass: &str) -> ScramClient {
        ScramClient::new(user.to_owned(), pass.to_owned())
    }

    #[test]
    fn test_scram_client_creation() {
        let c = make_client("user", "password");
        assert_eq!(c.username, "user");
        assert_eq!(c.password, "password");
        assert!(!c.nonce.is_empty());
    }

    #[test]
    fn test_client_first_message_format() {
        let msg = make_client("alice", "secret").client_first();

        // RFC 5802 format: "n,,n=<username>,r=<nonce>"
        assert!(msg.starts_with("n,,n=alice,r="));
        assert!(msg.len() > 20);
    }

    #[test]
    fn test_parse_server_first_valid() {
        let parsed =
            parse_server_first("r=client_nonce_server_nonce,s=aW1hZ2luYXJ5c2FsdA==,i=4096")
                .unwrap();

        let expected = (
            "client_nonce_server_nonce".to_string(),
            "aW1hZ2luYXJ5c2FsdA==".to_string(),
            "4096".to_string(),
        );
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_parse_server_first_invalid() {
        // The "i=" attribute is absent.
        assert!(parse_server_first("r=nonce,s=salt").is_err());
    }

    #[test]
    fn test_constant_time_compare_equal() {
        assert!(constant_time_compare(b"test_value", b"test_value"));
    }

    #[test]
    fn test_constant_time_compare_different() {
        assert!(!constant_time_compare(b"test_value", b"test_wrong"));
    }

    #[test]
    fn test_constant_time_compare_different_length() {
        assert!(!constant_time_compare(b"test", b"test_longer"));
    }

    #[test]
    fn test_scram_client_final_flow() {
        let mut c = make_client("user", "password");
        let _ = c.client_first();

        // Simulated server reply: our nonce plus a server part, a salt, and an iteration count.
        let server_first = format!(
            "r={}server_nonce_part,s={},i=4096",
            c.nonce,
            BASE64.encode(b"salty")
        );

        let (client_final, state) = c.client_final(&server_first).unwrap();
        assert!(client_final.starts_with("c="));
        assert!(!state.auth_message.is_empty());
    }
}