Skip to main content

heliosdb_proxy/
auth_scram.rs

1//! Proxy-side SCRAM-SHA-256 **server** authentication.
2//!
3//! When `[auth] mode = "scram"` is configured, the proxy terminates the
4//! client's SCRAM-SHA-256 exchange itself (it becomes the auth boundary)
5//! against verifiers loaded from an `auth_file`, instead of relaying the
6//! client's credentials straight through to the backend. This is the
7//! foundation for cross-client connection pooling (the backend connection
8//! is then established independently of the client's auth).
9//!
10//! The crypto mirrors the (RFC-5802-tested) client state machine in
11//! `backend::auth`, reusing its primitives. The state machine here is the
12//! server inverse: send server-first (salt/iterations/nonce), receive
13//! client-final, verify the `ClientProof`, return server-final.
14
15use std::collections::HashMap;
16
17use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
18
19use crate::backend::auth::{hmac_sha256, pbkdf2_hmac_sha256, sha256};
20
/// A SCRAM-SHA-256 verifier for one user: everything needed to validate a
/// `ClientProof` without knowing the plaintext password.
///
/// `Debug` is implemented by hand so that `stored_key` and `server_key`
/// are redacted. Those two keys are the authentication secrets and must
/// not end up in logs.
#[derive(Clone)]
pub struct ScramVerifier {
    /// Salt sent to the client in server-first (`s=`).
    pub salt: Vec<u8>,
    /// PBKDF2 iteration count sent to the client in server-first (`i=`).
    pub iterations: u32,
    /// `SHA-256(ClientKey)`; used to check the client's proof.
    pub stored_key: [u8; 32],
    /// `HMAC(SaltedPassword, "Server Key")`; used to sign server-final.
    pub server_key: [u8; 32],
}

impl std::fmt::Debug for ScramVerifier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScramVerifier")
            .field("salt", &self.salt)
            .field("iterations", &self.iterations)
            .field("stored_key", &format_args!("<redacted>"))
            .field("server_key", &format_args!("<redacted>"))
            .finish()
    }
}
30
31impl ScramVerifier {
32    /// Derive a verifier from a plaintext password (salt + iterations
33    /// chosen here). Used for `auth_file` entries that store a plaintext
34    /// secret rather than a pre-computed `SCRAM-SHA-256$...` verifier.
35    pub fn from_password(password: &str, salt: Vec<u8>, iterations: u32) -> Self {
36        let salted = pbkdf2_hmac_sha256(password.as_bytes(), &salt, iterations);
37        let client_key = hmac_sha256(&salted, b"Client Key");
38        let stored_key = sha256(&client_key);
39        let server_key = hmac_sha256(&salted, b"Server Key");
40        Self {
41            salt,
42            iterations,
43            stored_key,
44            server_key,
45        }
46    }
47
48    /// Parse a PostgreSQL-format verifier string:
49    /// `SCRAM-SHA-256$<iter>:<salt_b64>$<StoredKey_b64>:<ServerKey_b64>`
50    /// (this is exactly `pg_authid.rolpassword` for SCRAM users).
51    pub fn parse(s: &str) -> Option<Self> {
52        let rest = s.strip_prefix("SCRAM-SHA-256$")?;
53        let (params, keys) = rest.split_once('$')?;
54        let (iter_str, salt_b64) = params.split_once(':')?;
55        let (stored_b64, server_b64) = keys.split_once(':')?;
56        let iterations: u32 = iter_str.parse().ok()?;
57        let salt = BASE64.decode(salt_b64.trim()).ok()?;
58        let stored = BASE64.decode(stored_b64.trim()).ok()?;
59        let server = BASE64.decode(server_b64.trim()).ok()?;
60        if stored.len() != 32 || server.len() != 32 {
61            return None;
62        }
63        let mut stored_key = [0u8; 32];
64        stored_key.copy_from_slice(&stored);
65        let mut server_key = [0u8; 32];
66        server_key.copy_from_slice(&server);
67        Some(Self {
68            salt,
69            iterations,
70            stored_key,
71            server_key,
72        })
73    }
74}
75
76/// Map of username -> verifier, loaded from an `auth_file`.
77///
78/// File format, one entry per line (`#` comments and blank lines ignored):
79/// `username:secret` where `secret` is either a plaintext password or a
80/// `SCRAM-SHA-256$...` verifier string. Quoted pgbouncer-style values are
81/// accepted (surrounding double quotes are stripped).
82#[derive(Debug, Clone, Default)]
83pub struct AuthFile {
84    users: HashMap<String, ScramVerifier>,
85}
86
87impl AuthFile {
88    pub fn load(path: &str) -> Result<Self, String> {
89        let data = std::fs::read_to_string(path)
90            .map_err(|e| format!("reading auth_file {}: {}", path, e))?;
91        Self::parse_str(&data, path)
92    }
93
94    pub fn parse_str(data: &str, path: &str) -> Result<Self, String> {
95        let mut users = HashMap::new();
96        for (lineno, raw) in data.lines().enumerate() {
97            let line = raw.trim();
98            if line.is_empty() || line.starts_with('#') {
99                continue;
100            }
101            let (user, secret) = line
102                .split_once(':')
103                .ok_or_else(|| format!("{}:{}: expected `user:secret`", path, lineno + 1))?;
104            let user = unquote(user.trim());
105            let secret = unquote(secret.trim());
106            let verifier = if secret.starts_with("SCRAM-SHA-256$") {
107                ScramVerifier::parse(&secret)
108                    .ok_or_else(|| format!("{}:{}: malformed SCRAM verifier", path, lineno + 1))?
109            } else {
110                // Plaintext: derive a verifier with a fixed salt derived
111                // from the username (stable across restarts so the same
112                // client password always validates) and 4096 iterations.
113                let salt = sha256(user.as_bytes())[..16].to_vec();
114                ScramVerifier::from_password(&secret, salt, 4096)
115            };
116            users.insert(user, verifier);
117        }
118        Ok(Self { users })
119    }
120
121    pub fn get(&self, user: &str) -> Option<&ScramVerifier> {
122        self.users.get(user)
123    }
124
125    pub fn is_empty(&self) -> bool {
126        self.users.is_empty()
127    }
128}
129
/// Trim `s` and remove one pair of surrounding double quotes, if both are
/// present. A lone `"` is returned unchanged.
fn unquote(s: &str) -> String {
    let trimmed = s.trim();
    trimmed
        .strip_prefix('"')
        .and_then(|inner| inner.strip_suffix('"'))
        .unwrap_or(trimmed)
        .to_string()
}
138
139/// Server-side SCRAM-SHA-256 state machine. One per client handshake.
140pub struct ScramServer {
141    verifier: ScramVerifier,
142    combined_nonce: String,
143    client_first_bare: String,
144    server_first: String,
145}
146
147impl ScramServer {
148    /// Begin the exchange from the client's first message (the
149    /// SASLInitialResponse payload, e.g. `n,,n=,r=<clientnonce>`).
150    /// `server_nonce` must be a fresh random token. Returns the
151    /// `server-first` message to send back (AuthenticationSASLContinue).
152    pub fn start(
153        verifier: ScramVerifier,
154        client_first: &str,
155        server_nonce: &str,
156    ) -> Result<(Self, String), String> {
157        // Strip the gs2 header ("n,," / "y,," / "p=...,,"): the bare part
158        // is everything after the second comma.
159        let mut parts = client_first.splitn(3, ',');
160        let _gs2_cbind = parts.next();
161        let _gs2_authzid = parts.next();
162        let bare = parts
163            .next()
164            .ok_or_else(|| "malformed client-first (no bare part)".to_string())?;
165
166        let client_nonce = bare
167            .split(',')
168            .find_map(|f| f.strip_prefix("r="))
169            .ok_or_else(|| "client-first missing r=".to_string())?;
170        if client_nonce.is_empty() {
171            return Err("empty client nonce".to_string());
172        }
173
174        let combined_nonce = format!("{}{}", client_nonce, server_nonce);
175        let salt_b64 = BASE64.encode(&verifier.salt);
176        let server_first = format!(
177            "r={},s={},i={}",
178            combined_nonce, salt_b64, verifier.iterations
179        );
180
181        Ok((
182            Self {
183                verifier,
184                combined_nonce,
185                client_first_bare: bare.to_string(),
186                server_first: server_first.clone(),
187            },
188            server_first,
189        ))
190    }
191
192    /// Verify the client-final message (`c=<cb>,r=<nonce>,p=<proof>`).
193    /// On success returns the `server-final` message (`v=<sig>`) to send
194    /// in AuthenticationSASLFinal.
195    pub fn finish(&self, client_final: &str) -> Result<String, String> {
196        // Split off the trailing ",p=<proof>".
197        let proof_pos = client_final
198            .rfind(",p=")
199            .ok_or_else(|| "client-final missing p=".to_string())?;
200        let without_proof = &client_final[..proof_pos];
201        let proof_b64 = &client_final[proof_pos + 3..];
202
203        // Nonce echoed back must equal ours.
204        let echoed_nonce = without_proof
205            .split(',')
206            .find_map(|f| f.strip_prefix("r="))
207            .ok_or_else(|| "client-final missing r=".to_string())?;
208        if echoed_nonce != self.combined_nonce {
209            return Err("nonce mismatch".to_string());
210        }
211
212        let proof = BASE64
213            .decode(proof_b64.trim())
214            .map_err(|e| format!("bad proof base64: {}", e))?;
215        if proof.len() != 32 {
216            return Err("proof wrong length".to_string());
217        }
218
219        let auth_message = format!(
220            "{},{},{}",
221            self.client_first_bare, self.server_first, without_proof
222        );
223
224        // ClientSignature = HMAC(StoredKey, AuthMessage)
225        // ClientKey       = ClientProof XOR ClientSignature
226        // verify H(ClientKey) == StoredKey
227        let client_signature = hmac_sha256(&self.verifier.stored_key, auth_message.as_bytes());
228        let mut client_key = [0u8; 32];
229        for i in 0..32 {
230            client_key[i] = proof[i] ^ client_signature[i];
231        }
232        let derived_stored = sha256(&client_key);
233        if !constant_time_eq(&derived_stored, &self.verifier.stored_key) {
234            return Err("authentication failed (proof mismatch)".to_string());
235        }
236
237        // ServerSignature = HMAC(ServerKey, AuthMessage)
238        let server_signature = hmac_sha256(&self.verifier.server_key, auth_message.as_bytes());
239        Ok(format!("v={}", BASE64.encode(server_signature)))
240    }
241}
242
/// Compare two 32-byte digests for equality.
///
/// Every byte pair is XOR-ed into one accumulator before the result is
/// decided, rather than returning at the first differing byte, so the
/// loop does not reveal where a mismatch occurs.
fn constant_time_eq(a: &[u8; 32], b: &[u8; 32]) -> bool {
    let accumulated = a
        .iter()
        .zip(b.iter())
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    accumulated == 0
}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::auth::Scram;

    /// Pull the SCRAM client-first text out of a SASLInitialResponse body.
    /// The body is laid out as: mechanism name, NUL, 4-byte length, payload.
    fn client_first_payload(body: &[u8]) -> &str {
        let after_mech = body.iter().position(|&b| b == 0).unwrap() + 1;
        std::str::from_utf8(&body[after_mech + 4..]).unwrap()
    }

    /// Render `v` as a PostgreSQL `rolpassword`-style verifier string.
    fn pg_verifier_string(v: &ScramVerifier) -> String {
        format!(
            "SCRAM-SHA-256${}:{}${}:{}",
            v.iterations,
            BASE64.encode(&v.salt),
            BASE64.encode(v.stored_key),
            BASE64.encode(v.server_key),
        )
    }

    #[test]
    fn parse_pg_verifier_roundtrips_from_password() {
        // Derive from a password, render PG-style, then parse it back.
        let original =
            ScramVerifier::from_password("s3cret", b"0123456789abcdef".to_vec(), 4096);
        let reparsed =
            ScramVerifier::parse(&pg_verifier_string(&original)).expect("parses");
        assert_eq!(reparsed.iterations, original.iterations);
        assert_eq!(reparsed.salt, original.salt);
        assert_eq!(reparsed.stored_key, original.stored_key);
        assert_eq!(reparsed.server_key, original.server_key);
    }

    #[test]
    fn full_scram_handshake_client_vs_server() {
        // Drive the tested client (backend::auth::Scram) against our server.
        let password = "correct horse battery staple";
        let verifier =
            ScramVerifier::from_password(password, b"saltsaltsaltsalt".to_vec(), 4096);

        let (mut client, init) = Scram::client_first("clientNONCE123");
        let client_first = client_first_payload(&init.0);

        let (server, server_first) =
            ScramServer::start(verifier, client_first, "serverNONCE456").unwrap();

        let client_final = client
            .client_final(server_first.as_bytes(), password)
            .unwrap();
        let client_final_text = std::str::from_utf8(&client_final.0).unwrap();
        let server_final = server.finish(client_final_text).unwrap();

        // Mutual auth: the client must accept the server's signature.
        client.verify_server(server_final.as_bytes()).unwrap();
    }

    #[test]
    fn wrong_password_is_rejected() {
        let verifier =
            ScramVerifier::from_password("rightpw", b"saltsaltsaltsalt".to_vec(), 4096);
        let (mut client, init) = Scram::client_first("nonceAAA");
        let client_first = client_first_payload(&init.0);
        let (server, server_first) =
            ScramServer::start(verifier, client_first, "nonceBBB").unwrap();

        // The client answers with the WRONG password.
        let client_final = client
            .client_final(server_first.as_bytes(), "wrongpw")
            .unwrap();
        let outcome = server.finish(std::str::from_utf8(&client_final.0).unwrap());
        assert!(outcome.is_err(), "wrong password must be rejected");
    }

    #[test]
    fn auth_file_parses_plaintext_and_verifier() {
        let v = ScramVerifier::from_password("pw", b"0123456789abcdef".to_vec(), 4096);
        let carol_line = format!("carol:{}", pg_verifier_string(&v));
        let body = format!(
            "# comment\nalice:secret\n\nbob:\"quoted\"\n{}\n",
            carol_line
        );

        let af = AuthFile::parse_str(&body, "test").unwrap();
        for present in ["alice", "bob", "carol"] {
            assert!(af.get(present).is_some());
        }
        assert!(af.get("dave").is_none());
    }
}