pgwire_replication/auth/
scram.rs

1//! SCRAM-SHA-256 authentication implementation.
2//!
3//! This module implements the SCRAM-SHA-256 authentication mechanism as specified
4//! in RFC 5802 and RFC 7677, used by PostgreSQL for secure password authentication.
5//!
6//! # Protocol Overview
7//!
8//! SCRAM (Salted Challenge Response Authentication Mechanism) provides:
9//! - Password never sent in plaintext
10//! - Mutual authentication (client verifies server)
11//! - Protection against replay attacks via nonces
12//!
13//! # Example Flow
14//!
15//! ```no_run
16//! use pgwire_replication::auth::scram::ScramClient;
17//!
18//! fn main() -> Result<(), Box<dyn std::error::Error>> {
19//!     let client = ScramClient::new("postgres");
20//!
21//!     // Send to server: client.client_first.as_bytes()
22//!     let server_first = String::new(); // received from server
23//!
24//!     let (client_final, auth_msg, salted_pw) =
25//!         client.client_final("password", &server_first)?;
26//!
27//!     // Send to server: client_final.as_bytes()
28//!     let server_final = String::new(); // received from server
29//!
30//!     ScramClient::verify_server_final(&server_final, &salted_pw, &auth_msg)?;
31//!     Ok(())
32//! }
33//! ```
34
35#[cfg(feature = "scram")]
36use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
37#[cfg(feature = "scram")]
38use hmac::{Hmac, Mac};
39#[cfg(feature = "scram")]
40use rand::RngCore;
41#[cfg(feature = "scram")]
42use sha2::{Digest, Sha256};
43
44use crate::error::{PgWireError, Result};
45
/// HMAC instantiated with SHA-256; the keyed MAC primitive used by the helpers
/// in this module.
#[cfg(feature = "scram")]
type HmacSha256 = Hmac<Sha256>;
48
49/// SCRAM-SHA-256 client state.
50///
51/// Holds the client nonce and first message needed for the authentication exchange.
52#[cfg(feature = "scram")]
#[derive(Clone, Debug)]
pub struct ScramClient {
    /// Client nonce in base64 form (`new` derives it from 18 random bytes).
    /// The server's combined nonce must begin with this value.
    pub client_nonce_b64: String,
    /// The client-first-message-bare, i.e. `n=<user>,r=<nonce>` with no GS2
    /// header. It becomes the first component of the SCRAM auth message.
    pub client_first_bare: String,
    /// The full client-first-message (`n,,` header followed by the bare
    /// message); this is what gets sent to the server.
    pub client_first: String,
}
62
#[cfg(feature = "scram")]
impl ScramClient {
    /// Create a new SCRAM client with a random nonce.
    ///
    /// The nonce is 18 random bytes, base64-encoded. The resulting
    /// client-first-message uses the GS2 header `n,,` (no channel binding,
    /// no authorization identity).
    ///
    /// # Arguments
    /// * `username` - PostgreSQL username (will be SASL-escaped)
    pub fn new(username: &str) -> ScramClient {
        let mut nonce = [0u8; 18];
        rand::rng().fill_bytes(&mut nonce);
        Self::from_nonce(username, B64.encode(nonce))
    }

    /// Create a SCRAM client with a specific nonce (for testing).
    #[cfg(test)]
    pub(crate) fn with_nonce(username: &str, nonce_b64: &str) -> ScramClient {
        Self::from_nonce(username, nonce_b64.to_string())
    }

    /// Build the client state and both client-first message forms from an
    /// already base64-encoded nonce. Shared by `new` and `with_nonce` so the
    /// message layout is defined in exactly one place.
    fn from_nonce(username: &str, client_nonce_b64: String) -> ScramClient {
        let user = sasl_escape_username(username);
        let client_first_bare = format!("n={user},r={client_nonce_b64}");
        let client_first = format!("n,,{client_first_bare}");

        ScramClient {
            client_nonce_b64,
            client_first_bare,
            client_first,
        }
    }

    /// Parse server-first-message.
    ///
    /// The message is split on `,` and the following attributes are extracted
    /// (in any order; unknown attributes are ignored):
    /// - `r`: Combined nonce (client nonce + server nonce)
    /// - `s`: Base64-encoded salt (returned undecoded)
    /// - `i`: Iteration count
    ///
    /// # Errors
    /// Returns [`PgWireError::Auth`] if `r=` or `s=` is missing, or if `i=` is
    /// missing, not a valid `u32`, or zero.
    pub fn parse_server_first(server_first: &str) -> Result<(String, String, u32)> {
        let mut nonce = None;
        let mut salt = None;
        let mut iterations = None;

        for attr in server_first.split(',') {
            if let Some(v) = attr.strip_prefix("r=") {
                nonce = Some(v.to_string());
            } else if let Some(v) = attr.strip_prefix("s=") {
                salt = Some(v.to_string());
            } else if let Some(v) = attr.strip_prefix("i=") {
                // RFC 5802 defines the iteration count as a positive number.
                // A zero count is treated as malformed instead of being
                // silently run as a single iteration by the key derivation.
                iterations = v.parse::<u32>().ok().filter(|&n| n > 0);
            }
        }

        let nonce = nonce
            .ok_or_else(|| PgWireError::Auth("SCRAM server-first missing nonce (r=)".into()))?;
        let salt =
            salt.ok_or_else(|| PgWireError::Auth("SCRAM server-first missing salt (s=)".into()))?;
        let iterations = iterations.ok_or_else(|| {
            PgWireError::Auth("SCRAM server-first missing or invalid iteration count (i=)".into())
        })?;

        Ok((nonce, salt, iterations))
    }

    /// Compute client-final-message.
    ///
    /// # Arguments
    /// * `password` - User's password (used as raw UTF-8 bytes)
    /// * `server_first` - Server-first-message received from server
    ///
    /// # Returns
    /// Tuple of:
    /// - `client_final`: Message to send to server
    /// - `auth_message`: Full auth message (needed for server verification)
    /// - `salted_password`: Derived key (needed for server verification)
    ///
    /// # Errors
    /// - Server-first-message is missing a field or has an invalid iteration count
    /// - Nonce doesn't start with client nonce (possible MITM)
    /// - Invalid base64 in salt
    ///
    /// # Note
    /// The iteration count comes from the server and is not capped here, so
    /// the cost of the key derivation is controlled by the peer.
    pub fn client_final(
        &self,
        password: &str,
        server_first: &str,
    ) -> Result<(String, String, Vec<u8>)> {
        let (rnonce, salt_b64, iters) = Self::parse_server_first(server_first)?;

        // Security check: the combined nonce must start with the nonce we sent,
        // otherwise the reply does not belong to this exchange.
        if !rnonce.starts_with(&self.client_nonce_b64) {
            return Err(PgWireError::Auth(
                "SCRAM nonce mismatch: server nonce doesn't include client nonce".into(),
            ));
        }

        let salt = B64
            .decode(salt_b64.as_bytes())
            .map_err(|e| PgWireError::Auth(format!("SCRAM invalid salt base64: {e}")))?;

        // "biws" = base64("n,,"): the GS2 header from client-first echoed back,
        // which states that no channel binding is in use.
        let channel_binding = "biws";
        let client_final_wo_proof = format!("c={channel_binding},r={rnonce}");

        // AuthMessage = client-first-bare "," server-first "," client-final-without-proof.
        // The server-first text is used verbatim, exactly as received.
        let auth_message = format!(
            "{},{},{}",
            self.client_first_bare, server_first, client_final_wo_proof
        );

        // SCRAM key derivation
        let salted_password = hi_sha256(password.as_bytes(), &salt, iters);
        let client_key = hmac_sha256(&salted_password, b"Client Key");
        let stored_key = Sha256::digest(&client_key);

        // ClientProof = ClientKey XOR HMAC(StoredKey, AuthMessage)
        let client_sig = hmac_sha256(stored_key.as_slice(), auth_message.as_bytes());
        let proof = xor_bytes(&client_key, &client_sig);
        let proof_b64 = B64.encode(proof);

        let client_final = format!("{client_final_wo_proof},p={proof_b64}");
        Ok((client_final, auth_message, salted_password))
    }

    /// Verify server-final-message.
    ///
    /// This provides mutual authentication - ensures we're talking to a server
    /// that knows the password, not an impostor.
    ///
    /// # Arguments
    /// * `server_final` - Server-final-message received
    /// * `salted_password` - From `client_final()` return value
    /// * `auth_message` - From `client_final()` return value
    ///
    /// # Errors
    /// - Server reported an error (`e=` attribute)
    /// - Missing server signature
    /// - Invalid base64
    /// - Signature mismatch (server doesn't know password)
    pub fn verify_server_final(
        server_final: &str,
        salted_password: &[u8],
        auth_message: &str,
    ) -> Result<()> {
        // An `e=` attribute takes precedence over any signature in the message.
        if let Some(err) = server_final.split(',').find_map(|p| p.strip_prefix("e=")) {
            return Err(PgWireError::Auth(format!("SCRAM server error: {err}")));
        }

        let v = server_final
            .split(',')
            .find_map(|p| p.strip_prefix("v="))
            .ok_or_else(|| PgWireError::Auth("SCRAM server-final missing signature (v=)".into()))?;

        // Surrounding whitespace is trimmed before decoding.
        let server_sig = B64.decode(v.trim().as_bytes()).map_err(|e| {
            PgWireError::Auth(format!("SCRAM invalid server signature base64: {e}"))
        })?;

        // ServerSignature = HMAC(HMAC(SaltedPassword, "Server Key"), AuthMessage)
        let server_key = hmac_sha256(salted_password, b"Server Key");
        let expected = hmac_sha256(&server_key, auth_message.as_bytes());

        // Constant-time comparison to prevent timing attacks
        if !constant_time_eq(&server_sig, &expected) {
            return Err(PgWireError::Auth(
                "SCRAM server signature mismatch: server may not know the password".into(),
            ));
        }

        Ok(())
    }
}
237
238/// SASL-escape a username per RFC 5802.
239///
240/// Escapes `=` as `=3D` and `,` as `=2C`.
241#[cfg(feature = "scram")]
fn sasl_escape_username(u: &str) -> String {
    // Single pass: each `=` or `,` is swapped for its escape sequence and
    // every other character is copied through untouched.
    let mut escaped = String::with_capacity(u.len());
    for ch in u.chars() {
        match ch {
            '=' => escaped.push_str("=3D"),
            ',' => escaped.push_str("=2C"),
            other => escaped.push(other),
        }
    }
    escaped
}
245
/// Hi() function from RFC 5802 - essentially PBKDF2-HMAC-SHA256 producing a
/// single output block.
///
/// Computes `U1 = HMAC(password, salt || INT(1))` and
/// `Ui = HMAC(password, U(i-1))`, returning `U1 XOR U2 XOR ... XOR U(iters)`.
///
/// An `iters` of 0 gives the same result as 1 (only `U1` is computed).
#[cfg(feature = "scram")]
fn hi_sha256(password: &[u8], salt: &[u8], iters: u32) -> Vec<u8> {
    // Key the HMAC once. Every round starts from a clone of this state, so the
    // key schedule is not recomputed on each iteration.
    let keyed = HmacSha256::new_from_slice(password).expect("HMAC key length is always valid");

    // U1 = HMAC(password, salt || INT(1)); feeding the two parts separately is
    // equivalent to hashing their concatenation.
    let mut mac = keyed.clone();
    mac.update(salt);
    mac.update(&1u32.to_be_bytes());
    let mut u = mac.finalize().into_bytes();
    let mut out = u.to_vec();

    // Ui = HMAC(password, U(i-1)), result = U1 XOR U2 XOR ... XOR Ui
    for _ in 1..iters {
        let mut mac = keyed.clone();
        mac.update(u.as_slice());
        u = mac.finalize().into_bytes();
        for (o, ui) in out.iter_mut().zip(u.iter()) {
            *o ^= *ui;
        }
    }

    out
}
269
/// One-shot HMAC-SHA-256 of `msg` under `key`, returned as an owned byte vector.
#[cfg(feature = "scram")]
fn hmac_sha256(key: &[u8], msg: &[u8]) -> Vec<u8> {
    let keyed = HmacSha256::new_from_slice(key);
    let mut state = keyed.expect("HMAC key length is always valid");
    state.update(msg);

    let tag = state.finalize().into_bytes();
    tag.to_vec()
}
277
278/// XOR two byte slices of equal length.
279#[cfg(feature = "scram")]
fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    debug_assert_eq!(a.len(), b.len(), "XOR operands must have equal length");

    // Pair the bytes up positionally; pairing stops at the shorter operand.
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (&lhs, &rhs) in a.iter().zip(b) {
        out.push(lhs ^ rhs);
    }
    out
}
284
285/// Constant-time byte slice comparison.
286///
287/// Returns true if slices are equal, using constant-time comparison
288/// to prevent timing side-channel attacks.
289#[cfg(feature = "scram")]
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    // A length mismatch is reported immediately; only the byte contents are
    // compared without short-circuiting.
    if a.len() == b.len() {
        // Accumulate the XOR of every byte pair; any differing bit leaves
        // `diff` non-zero, and the loop always walks the full length.
        let mut diff = 0u8;
        for (&x, &y) in a.iter().zip(b) {
            diff |= x ^ y;
        }
        diff == 0
    } else {
        false
    }
}
303
#[cfg(test)]
#[cfg(feature = "scram")]
mod tests {
    //! Unit tests for the SCRAM client and its helpers.
    //!
    //! Besides structural checks, the crypto paths are pinned against
    //! published vectors (RFC 7677 example exchange, RFC 4231 HMAC test case 2,
    //! and the commonly cited PBKDF2-HMAC-SHA256 "password"/"salt" vectors) so
    //! that a wrong derivation cannot pass on shape alone.

    use super::*;

    /// Lowercase hex encoding, used to compare digests against published vectors.
    fn hex(bytes: &[u8]) -> String {
        bytes.iter().map(|b| format!("{b:02x}")).collect()
    }

    // ==================== ScramClient::new tests ====================

    #[test]
    fn scram_builds_first_message() {
        let c = ScramClient::new("user");
        assert!(c.client_first.starts_with("n,,n=user,r="));
        assert!(c.client_first_bare.starts_with("n=user,r="));
        assert!(!c.client_nonce_b64.is_empty());
    }

    #[test]
    fn scram_escapes_special_chars_in_username() {
        let c = ScramClient::new("user=name,test");
        // = becomes =3D, , becomes =2C
        assert!(c.client_first.contains("n=user=3Dname=2Ctest,r="));
    }

    #[test]
    fn scram_unique_nonces() {
        let c1 = ScramClient::new("user");
        let c2 = ScramClient::new("user");
        assert_ne!(c1.client_nonce_b64, c2.client_nonce_b64);
    }

    // ==================== parse_server_first tests ====================

    #[test]
    fn parse_server_first_valid() {
        let (r, s, i) = ScramClient::parse_server_first("r=abc123,s=c2FsdA==,i=4096").unwrap();
        assert_eq!(r, "abc123");
        assert_eq!(s, "c2FsdA==");
        assert_eq!(i, 4096);
    }

    #[test]
    fn parse_server_first_different_order() {
        // Fields can appear in any order
        let (r, s, i) = ScramClient::parse_server_first("i=1000,s=Zm9v,r=xyz").unwrap();
        assert_eq!(r, "xyz");
        assert_eq!(s, "Zm9v");
        assert_eq!(i, 1000);
    }

    #[test]
    fn parse_server_first_with_extensions() {
        // Should ignore unknown extensions
        let (r, s, i) =
            ScramClient::parse_server_first("r=nonce,s=c2FsdA==,i=4096,x=unknown").unwrap();
        assert_eq!(r, "nonce");
        assert_eq!(s, "c2FsdA==");
        assert_eq!(i, 4096);
    }

    #[test]
    fn parse_server_first_missing_nonce() {
        let err = ScramClient::parse_server_first("s=c2FsdA==,i=4096").unwrap_err();
        assert!(err.to_string().contains("nonce"));
    }

    #[test]
    fn parse_server_first_missing_salt() {
        let err = ScramClient::parse_server_first("r=abc,i=4096").unwrap_err();
        assert!(err.to_string().contains("salt"));
    }

    #[test]
    fn parse_server_first_missing_iterations() {
        let err = ScramClient::parse_server_first("r=abc,s=c2FsdA==").unwrap_err();
        assert!(err.to_string().contains("iteration"));
    }

    #[test]
    fn parse_server_first_invalid_iterations() {
        let err = ScramClient::parse_server_first("r=abc,s=c2FsdA==,i=notanumber").unwrap_err();
        assert!(err.to_string().contains("iteration"));
    }

    // ==================== client_final tests ====================

    #[test]
    fn client_final_computes_proof() {
        // Use deterministic nonce for reproducible test
        let client = ScramClient::with_nonce("user", "rOprNGfwEbeRWgbNEkqO");

        let server_first = "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096";

        let (client_final, auth_message, salted_password) =
            client.client_final("pencil", server_first).unwrap();

        // Verify structure
        assert!(client_final.starts_with("c=biws,r="));
        assert!(client_final.contains(",p="));

        // Auth message should contain all three messages
        assert!(auth_message.contains(&client.client_first_bare));
        assert!(auth_message.contains(server_first));

        // Salted password should be 32 bytes (SHA-256 output)
        assert_eq!(salted_password.len(), 32);
    }

    #[test]
    fn client_final_matches_rfc7677_example() {
        // Full example exchange from RFC 7677, section 3 (user "user", password "pencil").
        let client = ScramClient::with_nonce("user", "rOprNGfwEbeRWgbNEkqO");
        assert_eq!(client.client_first, "n,,n=user,r=rOprNGfwEbeRWgbNEkqO");

        let server_first = "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096";

        let (client_final, auth_message, salted_password) =
            client.client_final("pencil", server_first).unwrap();

        assert_eq!(
            client_final,
            "c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="
        );

        // The server signature published in the RFC must verify against our derived keys.
        ScramClient::verify_server_final(
            "v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=",
            &salted_password,
            &auth_message,
        )
        .unwrap();
    }

    #[test]
    fn client_final_rejects_nonce_mismatch() {
        let client = ScramClient::with_nonce("user", "clientnonce");

        // Server returns nonce that doesn't start with client nonce
        let server_first = "r=differentnonce,s=c2FsdA==,i=4096";

        let err = client.client_final("password", server_first).unwrap_err();
        assert!(err.to_string().contains("nonce mismatch"));
    }

    #[test]
    fn client_final_rejects_invalid_salt_base64() {
        let client = ScramClient::with_nonce("user", "abc");

        let server_first = "r=abcdef,s=!!!invalid!!!,i=4096";

        let err = client.client_final("password", server_first).unwrap_err();
        assert!(err.to_string().contains("base64"));
    }

    // ==================== verify_server_final tests ====================

    #[test]
    fn verify_server_final_accepts_valid_signature() {
        // Round-trip check: the signature is computed here with the same
        // helpers, so this verifies message plumbing rather than the crypto.
        let client = ScramClient::with_nonce("user", "fyko+d2lbbFgONRv9qkxdawL");

        let server_first = "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096";

        let (_, auth_message, salted_password) =
            client.client_final("pencil", server_first).unwrap();

        // Compute expected server signature manually
        let server_key = hmac_sha256(&salted_password, b"Server Key");
        let server_sig = hmac_sha256(&server_key, auth_message.as_bytes());
        let server_final = format!("v={}", B64.encode(&server_sig));

        // Should succeed
        ScramClient::verify_server_final(&server_final, &salted_password, &auth_message).unwrap();
    }

    #[test]
    fn verify_server_final_rejects_wrong_signature() {
        let salted_password = vec![0u8; 32];
        let auth_message = "test";
        let server_final = "v=AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; // wrong

        let err = ScramClient::verify_server_final(server_final, &salted_password, auth_message)
            .unwrap_err();
        assert!(err.to_string().contains("signature mismatch"));
    }

    #[test]
    fn verify_server_final_rejects_missing_signature() {
        let err = ScramClient::verify_server_final("", &[], "").unwrap_err();
        assert!(err.to_string().contains("missing signature"));
    }

    #[test]
    fn verify_server_final_handles_server_error() {
        let err = ScramClient::verify_server_final("e=invalid-proof", &[], "").unwrap_err();
        assert!(err.to_string().contains("server error"));
        assert!(err.to_string().contains("invalid-proof"));
    }

    #[test]
    fn verify_server_final_rejects_invalid_base64() {
        let err = ScramClient::verify_server_final("v=!!!invalid!!!", &[], "").unwrap_err();
        assert!(err.to_string().contains("base64"));
    }

    // ==================== Helper function tests ====================

    #[test]
    fn sasl_escape_username_escapes_equals() {
        assert_eq!(sasl_escape_username("a=b"), "a=3Db");
    }

    #[test]
    fn sasl_escape_username_escapes_comma() {
        assert_eq!(sasl_escape_username("a,b"), "a=2Cb");
    }

    #[test]
    fn sasl_escape_username_escapes_both() {
        assert_eq!(sasl_escape_username("a=b,c"), "a=3Db=2Cc");
    }

    #[test]
    fn sasl_escape_username_preserves_normal() {
        assert_eq!(sasl_escape_username("normal_user123"), "normal_user123");
    }

    #[test]
    fn hi_sha256_single_iteration() {
        // With 1 iteration, result is just HMAC(password, salt || 0x00000001)
        let result = hi_sha256(b"password", b"salt", 1);
        assert_eq!(result.len(), 32);

        // PBKDF2-HMAC-SHA256("password", "salt", c = 1, dkLen = 32)
        assert_eq!(
            hex(&result),
            "120fb6cffcf8b32c43e7225256c4f837a86548c92ccc35480805987cb70be17b"
        );
    }

    #[test]
    fn hi_sha256_multiple_iterations() {
        let result = hi_sha256(b"password", b"salt", 4096);
        assert_eq!(result.len(), 32);

        // PBKDF2-HMAC-SHA256("password", "salt", c = 4096, dkLen = 32)
        assert_eq!(
            hex(&result),
            "c5e478d59288c841aa530db6845c4c8d962893a001ce4e11a4963873aa98134a"
        );

        // More iterations should produce different result
        let result2 = hi_sha256(b"password", b"salt", 1000);
        assert_ne!(result, result2);
    }

    #[test]
    fn hmac_sha256_produces_correct_length() {
        let result = hmac_sha256(b"key", b"message");
        assert_eq!(result.len(), 32);
    }

    #[test]
    fn hmac_sha256_matches_rfc4231_case_2() {
        // RFC 4231, test case 2.
        let result = hmac_sha256(b"Jefe", b"what do ya want for nothing?");
        assert_eq!(
            hex(&result),
            "5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843"
        );
    }

    #[test]
    fn xor_bytes_works() {
        assert_eq!(xor_bytes(&[0xFF, 0x00], &[0x0F, 0xF0]), vec![0xF0, 0xF0]);
        assert_eq!(xor_bytes(&[0x00], &[0x00]), vec![0x00]);
    }

    #[test]
    fn constant_time_eq_equal() {
        assert!(constant_time_eq(&[1, 2, 3], &[1, 2, 3]));
        assert!(constant_time_eq(&[], &[]));
    }

    #[test]
    fn constant_time_eq_not_equal() {
        assert!(!constant_time_eq(&[1, 2, 3], &[1, 2, 4]));
        assert!(!constant_time_eq(&[1, 2, 3], &[1, 2]));
    }

    #[test]
    fn constant_time_eq_different_lengths() {
        assert!(!constant_time_eq(&[1, 2, 3], &[1, 2, 3, 4]));
    }
}