Skip to main content

fraiseql_wire/auth/scram/
mod.rs

1//! SCRAM-SHA-256 authentication implementation
2//!
//! Implements the client side of SCRAM-SHA-256 (Salted Challenge Response Authentication
//! Mechanism, RFC 5802, instantiated with SHA-256 per RFC 7677) for PostgreSQL
//! authentication (Postgres 10+).
5
6use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use hmac::{Hmac, Mac};
8use pbkdf2::pbkdf2;
9use rand::Rng;
10use sha2::{Digest, Sha256};
11use std::fmt;
12use zeroize::Zeroizing;
13
14type HmacSha256 = Hmac<Sha256>;
15
/// Upper bound on the PBKDF2 iteration count this client accepts from a server.
///
/// The server chooses the `i=` value in its SCRAM first message, and the client must run
/// PBKDF2 for that many rounds before it can answer. Without a cap, a malicious server could
/// make the client spend seconds (or minutes) of CPU before the connection is rejected.
/// One million rounds stays orders of magnitude above typical PostgreSQL defaults
/// (4096–600,000) while closing off that denial-of-service vector.
pub(crate) const MAX_SCRAM_ITERATIONS: u32 = 1_000 * 1_000;
23
/// Errors produced while running the client side of a SCRAM exchange.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ScramError {
    /// The server's signature (`v=`) did not verify.
    InvalidServerProof(String),
    /// A server message was malformed or carried unacceptable parameters.
    InvalidServerMessage(String),
    /// UTF-8 encoding/decoding error (this module also reports HMAC key errors with it).
    Utf8Error(String),
    /// A base64-encoded field could not be decoded.
    Base64Error(String),
}

impl fmt::Display for ScramError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant renders as "<category>: <detail>".
        let (category, detail) = match self {
            Self::InvalidServerProof(detail) => ("invalid server proof", detail),
            Self::InvalidServerMessage(detail) => ("invalid server message", detail),
            Self::Utf8Error(detail) => ("UTF-8 error", detail),
            Self::Base64Error(detail) => ("Base64 error", detail),
        };
        write!(f, "{category}: {detail}")
    }
}

impl std::error::Error for ScramError {}
50
/// Internal state carried from [`ScramClient::client_final`] to
/// [`ScramClient::verify_server_final`].
///
/// `Debug` is implemented by hand so that `server_key` is never printed: it is derived
/// from the password, and anyone holding it can compute a valid server signature for
/// this exchange.
#[derive(Clone)]
pub struct ScramState {
    /// Combined authentication message (for verification)
    auth_message: Vec<u8>,
    /// Server key (for verification calculation)
    server_key: Vec<u8>,
}

impl fmt::Debug for ScramState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // SECURITY: redact key material; the auth message only holds values that were
        // already exchanged in the clear.
        f.debug_struct("ScramState")
            .field("auth_message", &String::from_utf8_lossy(&self.auth_message))
            .field("server_key", &"<redacted>")
            .finish()
    }
}
59
60/// SCRAM-SHA-256 client implementation
61pub struct ScramClient {
62    username: String,
63    /// Password is stored as `Zeroizing<String>` so the key material is
64    /// overwritten with zeros when `ScramClient` is dropped (S38).
65    password: Zeroizing<String>,
66    nonce: String,
67}
68
69impl ScramClient {
70    /// Create a new SCRAM client
71    #[must_use]
72    pub fn new(username: String, password: String) -> Self {
73        // SECURITY: rand::rng() is backed by OS-level entropy for SCRAM nonces.
74        let mut rng = rand::rng();
75        let nonce_bytes: Vec<u8> = (0..24).map(|_| rng.random()).collect();
76        let nonce = BASE64.encode(&nonce_bytes);
77
78        Self {
79            username,
80            password: Zeroizing::new(password),
81            nonce,
82        }
83    }
84
85    /// Generate client first message (no proof)
86    #[must_use]
87    pub fn client_first(&self) -> String {
88        // RFC 5802 format: gs2-header client-first-message-bare
89        // gs2-header = "n,," (n = no channel binding, empty authorization identity)
90        // client-first-message-bare = "n=<username>,r=<nonce>"
91        // RFC 5802 §5.1: username must have ',' escaped as '=2C' and '=' escaped as '=3D'.
92        let escaped_username = self.username.replace('=', "=3D").replace(',', "=2C");
93        format!("n,,n={},r={}", escaped_username, self.nonce)
94    }
95
96    /// Process server first message and generate client final message
97    ///
98    /// Returns (`client_final_message`, `internal_state`)
99    ///
100    /// # Errors
101    ///
102    /// Returns [`ScramError::InvalidServerMessage`] if the server message cannot be parsed,
103    /// the server nonce does not start with the client nonce, or the iteration count is
104    /// invalid or exceeds `MAX_SCRAM_ITERATIONS`. Returns [`ScramError::Base64Error`] if
105    /// the salt is not valid base64.
106    pub fn client_final(&mut self, server_first: &str) -> Result<(String, ScramState), ScramError> {
107        // Parse server first message: r=<client_nonce><server_nonce>,s=<salt>,i=<iterations>
108        let (server_nonce, salt, iterations) = parse_server_first(server_first)?;
109
110        // Verify server nonce starts with our client nonce
111        if !server_nonce.starts_with(&self.nonce) {
112            return Err(ScramError::InvalidServerMessage(
113                "server nonce doesn't contain client nonce".to_string(),
114            ));
115        }
116
117        // Decode salt and iterations
118        let salt_bytes = BASE64
119            .decode(&salt)
120            .map_err(|_| ScramError::Base64Error("invalid salt encoding".to_string()))?;
121        let iterations = iterations
122            .parse::<u32>()
123            .map_err(|_| ScramError::InvalidServerMessage("invalid iteration count".to_string()))?;
124
125        // SECURITY: Guard against server-supplied iteration counts large enough to
126        // cause a denial-of-service via excessive PBKDF2 CPU time.
127        if iterations > MAX_SCRAM_ITERATIONS {
128            return Err(ScramError::InvalidServerMessage(format!(
129                "server iteration count {iterations} exceeds maximum of {MAX_SCRAM_ITERATIONS}"
130            )));
131        }
132
133        // Build channel binding (no channel binding for SCRAM-SHA-256)
134        let channel_binding = BASE64.encode(b"n,,");
135
136        // Build client final without proof
137        let client_final_without_proof = format!("c={},r={}", channel_binding, server_nonce);
138
139        // Build auth message for signature calculation.
140        // client-first-message-bare is "n=<escaped_username>,r=<nonce>" (without gs2-header).
141        // SECURITY: Must use the RFC 5802 §5.1-escaped username (same as client_first()),
142        // not the raw username — otherwise an attacker who controls ',' or '=' in a username
143        // can inject arbitrary SCRAM attributes and break authentication.
144        let escaped_username = self.username.replace('=', "=3D").replace(',', "=2C");
145        let client_first_bare = format!("n={},r={}", escaped_username, self.nonce);
146        let auth_message = format!(
147            "{},{},{}",
148            client_first_bare, server_first, client_final_without_proof
149        );
150
151        // Calculate proof
152        let proof = calculate_client_proof(
153            &self.password,
154            &salt_bytes,
155            iterations,
156            auth_message.as_bytes(),
157        )?;
158
159        // Calculate server signature for later verification
160        let server_key = calculate_server_key(&self.password, &salt_bytes, iterations)?;
161
162        // Build client final message
163        let client_final = format!("{},p={}", client_final_without_proof, BASE64.encode(&proof));
164
165        let state = ScramState {
166            auth_message: auth_message.into_bytes(),
167            server_key,
168        };
169
170        Ok((client_final, state))
171    }
172
173    /// Verify server final message and confirm authentication
174    ///
175    /// # Errors
176    ///
177    /// Returns `ScramError::InvalidServerMessage` if the server final message is malformed.
178    /// Returns `ScramError::Base64Error` if the server signature is not valid base64.
179    /// Returns `ScramError::AuthenticationFailed` if the server signature does not match.
180    pub fn verify_server_final(
181        &self,
182        server_final: &str,
183        state: &ScramState,
184    ) -> Result<(), ScramError> {
185        // Parse server final: v=<server_signature>
186        let server_sig_encoded = server_final
187            .strip_prefix("v=")
188            .ok_or_else(|| ScramError::InvalidServerMessage("missing 'v=' prefix".to_string()))?;
189
190        let server_signature = BASE64.decode(server_sig_encoded).map_err(|_| {
191            ScramError::Base64Error("invalid server signature encoding".to_string())
192        })?;
193
194        // Calculate expected server signature
195        let expected_signature =
196            calculate_server_signature(&state.server_key, &state.auth_message)?;
197
198        // Constant-time comparison
199        if constant_time_compare(&server_signature, &expected_signature) {
200            Ok(())
201        } else {
202            Err(ScramError::InvalidServerProof(
203                "server signature verification failed".to_string(),
204            ))
205        }
206    }
207}
208
209/// Parse server first message format: r=<nonce>,s=<salt>,i=<iterations>
210pub(crate) fn parse_server_first(msg: &str) -> Result<(String, String, String), ScramError> {
211    let mut nonce = String::new();
212    let mut salt = String::new();
213    let mut iterations = String::new();
214
215    for part in msg.split(',') {
216        if let Some(value) = part.strip_prefix("r=") {
217            nonce = value.to_string();
218        } else if let Some(value) = part.strip_prefix("s=") {
219            salt = value.to_string();
220        } else if let Some(value) = part.strip_prefix("i=") {
221            iterations = value.to_string();
222        }
223    }
224
225    if nonce.is_empty() || salt.is_empty() || iterations.is_empty() {
226        return Err(ScramError::InvalidServerMessage(
227            "missing required fields in server first message".to_string(),
228        ));
229    }
230
231    Ok((nonce, salt, iterations))
232}
233
234/// Calculate SCRAM client proof
235fn calculate_client_proof(
236    password: &str,
237    salt: &[u8],
238    iterations: u32,
239    auth_message: &[u8],
240) -> Result<Vec<u8>, ScramError> {
241    // SaltedPassword := PBKDF2(password, salt, iterations, HMAC-SHA256)
242    let password_bytes = password.as_bytes();
243    let mut salted_password = vec![0u8; 32]; // SHA256 produces 32 bytes
244    let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
245
246    // ClientKey := HMAC(SaltedPassword, "Client Key")
247    let mut client_key_hmac = HmacSha256::new_from_slice(&salted_password)
248        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
249    client_key_hmac.update(b"Client Key");
250    let client_key = client_key_hmac.finalize().into_bytes();
251
252    // StoredKey := SHA256(ClientKey)
253    let stored_key = Sha256::digest(client_key.to_vec().as_slice());
254
255    // ClientSignature := HMAC(StoredKey, AuthMessage)
256    let mut client_sig_hmac = HmacSha256::new_from_slice(&stored_key)
257        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
258    client_sig_hmac.update(auth_message);
259    let client_signature = client_sig_hmac.finalize().into_bytes();
260
261    // ClientProof := ClientKey XOR ClientSignature
262    let mut proof = client_key.to_vec();
263    for (proof_byte, sig_byte) in proof.iter_mut().zip(client_signature.iter()) {
264        *proof_byte ^= sig_byte;
265    }
266
267    Ok(proof.clone())
268}
269
270/// Calculate server key for server signature verification
271fn calculate_server_key(
272    password: &str,
273    salt: &[u8],
274    iterations: u32,
275) -> Result<Vec<u8>, ScramError> {
276    // SaltedPassword := PBKDF2(password, salt, iterations, HMAC-SHA256)
277    let password_bytes = password.as_bytes();
278    let mut salted_password = vec![0u8; 32];
279    let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
280
281    // ServerKey := HMAC(SaltedPassword, "Server Key")
282    let mut server_key_hmac = HmacSha256::new_from_slice(&salted_password)
283        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
284    server_key_hmac.update(b"Server Key");
285
286    Ok(server_key_hmac.finalize().into_bytes().to_vec())
287}
288
289/// Calculate server signature for verification
290fn calculate_server_signature(
291    server_key: &[u8],
292    auth_message: &[u8],
293) -> Result<Vec<u8>, ScramError> {
294    let mut hmac = HmacSha256::new_from_slice(server_key)
295        .map_err(|_| ScramError::Utf8Error("invalid HMAC key for server signature".to_string()))?;
296    hmac.update(auth_message);
297    Ok(hmac.finalize().into_bytes().to_vec())
298}
299
300/// Constant-time comparison to prevent timing attacks.
301///
302/// Uses the `subtle` crate for verified constant-time operations.
303pub(crate) fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
304    use subtle::ConstantTimeEq;
305    a.ct_eq(b).into()
306}
307
308#[cfg(test)]
309mod tests;