Skip to main content

heliosdb_proxy/
auth_scram.rs

1//! Proxy-side SCRAM-SHA-256 **server** authentication.
2//!
3//! When `[auth] mode = "scram"` is configured, the proxy terminates the
4//! client's SCRAM-SHA-256 exchange itself (it becomes the auth boundary)
5//! against verifiers loaded from an `auth_file`, instead of relaying the
6//! client's credentials straight through to the backend. This is the
7//! foundation for cross-client connection pooling (the backend connection
8//! is then established independently of the client's auth).
9//!
10//! The crypto mirrors the (RFC-5802-tested) client state machine in
11//! `backend::auth`, reusing its primitives. The state machine here is the
12//! server inverse: send server-first (salt/iterations/nonce), receive
13//! client-final, verify the `ClientProof`, return server-final.
14
15use std::collections::HashMap;
16
17use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
18
19use crate::backend::auth::{hmac_sha256, pbkdf2_hmac_sha256, sha256};
20
21/// A SCRAM-SHA-256 verifier for one user: everything needed to validate a
22/// `ClientProof` without knowing the plaintext password.
23#[derive(Debug, Clone)]
24pub struct ScramVerifier {
25    pub salt: Vec<u8>,
26    pub iterations: u32,
27    pub stored_key: [u8; 32],
28    pub server_key: [u8; 32],
29}
30
31impl ScramVerifier {
32    /// Derive a verifier from a plaintext password (salt + iterations
33    /// chosen here). Used for `auth_file` entries that store a plaintext
34    /// secret rather than a pre-computed `SCRAM-SHA-256$...` verifier.
35    pub fn from_password(password: &str, salt: Vec<u8>, iterations: u32) -> Self {
36        let salted = pbkdf2_hmac_sha256(password.as_bytes(), &salt, iterations);
37        let client_key = hmac_sha256(&salted, b"Client Key");
38        let stored_key = sha256(&client_key);
39        let server_key = hmac_sha256(&salted, b"Server Key");
40        Self { salt, iterations, stored_key, server_key }
41    }
42
43    /// Parse a PostgreSQL-format verifier string:
44    /// `SCRAM-SHA-256$<iter>:<salt_b64>$<StoredKey_b64>:<ServerKey_b64>`
45    /// (this is exactly `pg_authid.rolpassword` for SCRAM users).
46    pub fn parse(s: &str) -> Option<Self> {
47        let rest = s.strip_prefix("SCRAM-SHA-256$")?;
48        let (params, keys) = rest.split_once('$')?;
49        let (iter_str, salt_b64) = params.split_once(':')?;
50        let (stored_b64, server_b64) = keys.split_once(':')?;
51        let iterations: u32 = iter_str.parse().ok()?;
52        let salt = BASE64.decode(salt_b64.trim()).ok()?;
53        let stored = BASE64.decode(stored_b64.trim()).ok()?;
54        let server = BASE64.decode(server_b64.trim()).ok()?;
55        if stored.len() != 32 || server.len() != 32 {
56            return None;
57        }
58        let mut stored_key = [0u8; 32];
59        stored_key.copy_from_slice(&stored);
60        let mut server_key = [0u8; 32];
61        server_key.copy_from_slice(&server);
62        Some(Self { salt, iterations, stored_key, server_key })
63    }
64}
65
66/// Map of username -> verifier, loaded from an `auth_file`.
67///
68/// File format, one entry per line (`#` comments and blank lines ignored):
69/// `username:secret` where `secret` is either a plaintext password or a
70/// `SCRAM-SHA-256$...` verifier string. Quoted pgbouncer-style values are
71/// accepted (surrounding double quotes are stripped).
72#[derive(Debug, Clone, Default)]
73pub struct AuthFile {
74    users: HashMap<String, ScramVerifier>,
75}
76
77impl AuthFile {
78    pub fn load(path: &str) -> Result<Self, String> {
79        let data = std::fs::read_to_string(path)
80            .map_err(|e| format!("reading auth_file {}: {}", path, e))?;
81        Self::parse_str(&data, path)
82    }
83
84    pub fn parse_str(data: &str, path: &str) -> Result<Self, String> {
85        let mut users = HashMap::new();
86        for (lineno, raw) in data.lines().enumerate() {
87            let line = raw.trim();
88            if line.is_empty() || line.starts_with('#') {
89                continue;
90            }
91            let (user, secret) = line
92                .split_once(':')
93                .ok_or_else(|| format!("{}:{}: expected `user:secret`", path, lineno + 1))?;
94            let user = unquote(user.trim());
95            let secret = unquote(secret.trim());
96            let verifier = if secret.starts_with("SCRAM-SHA-256$") {
97                ScramVerifier::parse(&secret).ok_or_else(|| {
98                    format!("{}:{}: malformed SCRAM verifier", path, lineno + 1)
99                })?
100            } else {
101                // Plaintext: derive a verifier with a fixed salt derived
102                // from the username (stable across restarts so the same
103                // client password always validates) and 4096 iterations.
104                let salt = sha256(user.as_bytes())[..16].to_vec();
105                ScramVerifier::from_password(&secret, salt, 4096)
106            };
107            users.insert(user, verifier);
108        }
109        Ok(Self { users })
110    }
111
112    pub fn get(&self, user: &str) -> Option<&ScramVerifier> {
113        self.users.get(user)
114    }
115
116    pub fn is_empty(&self) -> bool {
117        self.users.is_empty()
118    }
119}
120
/// Trim whitespace, then remove one pair of enclosing double quotes if the
/// trimmed text both starts and ends with `"` (pgbouncer-style quoting).
/// A lone `"` or a string quoted on one side only is returned trimmed but
/// otherwise unchanged.
fn unquote(s: &str) -> String {
    let trimmed = s.trim();
    match trimmed
        .strip_prefix('"')
        .and_then(|after_open| after_open.strip_suffix('"'))
    {
        Some(inner) => inner.to_string(),
        None => trimmed.to_string(),
    }
}
129
130/// Server-side SCRAM-SHA-256 state machine. One per client handshake.
131pub struct ScramServer {
132    verifier: ScramVerifier,
133    combined_nonce: String,
134    client_first_bare: String,
135    server_first: String,
136}
137
138impl ScramServer {
139    /// Begin the exchange from the client's first message (the
140    /// SASLInitialResponse payload, e.g. `n,,n=,r=<clientnonce>`).
141    /// `server_nonce` must be a fresh random token. Returns the
142    /// `server-first` message to send back (AuthenticationSASLContinue).
143    pub fn start(
144        verifier: ScramVerifier,
145        client_first: &str,
146        server_nonce: &str,
147    ) -> Result<(Self, String), String> {
148        // Strip the gs2 header ("n,," / "y,," / "p=...,,"): the bare part
149        // is everything after the second comma.
150        let mut parts = client_first.splitn(3, ',');
151        let _gs2_cbind = parts.next();
152        let _gs2_authzid = parts.next();
153        let bare = parts
154            .next()
155            .ok_or_else(|| "malformed client-first (no bare part)".to_string())?;
156
157        let client_nonce = bare
158            .split(',')
159            .find_map(|f| f.strip_prefix("r="))
160            .ok_or_else(|| "client-first missing r=".to_string())?;
161        if client_nonce.is_empty() {
162            return Err("empty client nonce".to_string());
163        }
164
165        let combined_nonce = format!("{}{}", client_nonce, server_nonce);
166        let salt_b64 = BASE64.encode(&verifier.salt);
167        let server_first = format!("r={},s={},i={}", combined_nonce, salt_b64, verifier.iterations);
168
169        Ok((
170            Self {
171                verifier,
172                combined_nonce,
173                client_first_bare: bare.to_string(),
174                server_first: server_first.clone(),
175            },
176            server_first,
177        ))
178    }
179
180    /// Verify the client-final message (`c=<cb>,r=<nonce>,p=<proof>`).
181    /// On success returns the `server-final` message (`v=<sig>`) to send
182    /// in AuthenticationSASLFinal.
183    pub fn finish(&self, client_final: &str) -> Result<String, String> {
184        // Split off the trailing ",p=<proof>".
185        let proof_pos = client_final
186            .rfind(",p=")
187            .ok_or_else(|| "client-final missing p=".to_string())?;
188        let without_proof = &client_final[..proof_pos];
189        let proof_b64 = &client_final[proof_pos + 3..];
190
191        // Nonce echoed back must equal ours.
192        let echoed_nonce = without_proof
193            .split(',')
194            .find_map(|f| f.strip_prefix("r="))
195            .ok_or_else(|| "client-final missing r=".to_string())?;
196        if echoed_nonce != self.combined_nonce {
197            return Err("nonce mismatch".to_string());
198        }
199
200        let proof = BASE64
201            .decode(proof_b64.trim())
202            .map_err(|e| format!("bad proof base64: {}", e))?;
203        if proof.len() != 32 {
204            return Err("proof wrong length".to_string());
205        }
206
207        let auth_message = format!(
208            "{},{},{}",
209            self.client_first_bare, self.server_first, without_proof
210        );
211
212        // ClientSignature = HMAC(StoredKey, AuthMessage)
213        // ClientKey       = ClientProof XOR ClientSignature
214        // verify H(ClientKey) == StoredKey
215        let client_signature = hmac_sha256(&self.verifier.stored_key, auth_message.as_bytes());
216        let mut client_key = [0u8; 32];
217        for i in 0..32 {
218            client_key[i] = proof[i] ^ client_signature[i];
219        }
220        let derived_stored = sha256(&client_key);
221        if !constant_time_eq(&derived_stored, &self.verifier.stored_key) {
222            return Err("authentication failed (proof mismatch)".to_string());
223        }
224
225        // ServerSignature = HMAC(ServerKey, AuthMessage)
226        let server_signature = hmac_sha256(&self.verifier.server_key, auth_message.as_bytes());
227        Ok(format!("v={}", BASE64.encode(server_signature)))
228    }
229}
230
/// Compare two 32-byte digests for equality.
///
/// Every byte pair is XORed and OR-accumulated with no early exit, so the
/// amount of work does not depend on where the first difference is. This is
/// best-effort: the optimizer is not formally prevented from
/// short-circuiting.
fn constant_time_eq(a: &[u8; 32], b: &[u8; 32]) -> bool {
    let accumulated = a
        .iter()
        .zip(b.iter())
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    accumulated == 0
}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::auth::Scram;

    /// Render `v` in the PostgreSQL `rolpassword` SCRAM format.
    fn pg_format(v: &ScramVerifier) -> String {
        format!(
            "SCRAM-SHA-256${}:{}${}:{}",
            v.iterations,
            BASE64.encode(&v.salt),
            BASE64.encode(v.stored_key),
            BASE64.encode(v.server_key),
        )
    }

    /// Extract the client-first payload from a SASLInitialResponse body.
    /// The body is laid out as: mechanism name, NUL, int32 length, payload.
    fn client_first_of(body: &[u8]) -> &str {
        let nul = body.iter().position(|&b| b == 0).unwrap();
        let payload_start = nul + 1 + 4;
        std::str::from_utf8(&body[payload_start..]).unwrap()
    }

    #[test]
    fn parse_pg_verifier_roundtrips_from_password() {
        // Derive from a password, render PG-style, and parse it back.
        let original = ScramVerifier::from_password("s3cret", b"0123456789abcdef".to_vec(), 4096);
        let reparsed = ScramVerifier::parse(&pg_format(&original)).expect("parses");
        assert_eq!(reparsed.iterations, original.iterations);
        assert_eq!(reparsed.salt, original.salt);
        assert_eq!(reparsed.stored_key, original.stored_key);
        assert_eq!(reparsed.server_key, original.server_key);
    }

    #[test]
    fn full_scram_handshake_client_vs_server() {
        // Drive the tested client (backend::auth::Scram) against our server.
        let password = "correct horse battery staple";
        let verifier = ScramVerifier::from_password(password, b"saltsaltsaltsalt".to_vec(), 4096);

        let (mut client, init) = Scram::client_first("clientNONCE123");
        let (server, server_first) =
            ScramServer::start(verifier, client_first_of(&init.0), "serverNONCE456").unwrap();

        let client_final = client.client_final(server_first.as_bytes(), password).unwrap();
        let server_final = server.finish(std::str::from_utf8(&client_final.0).unwrap()).unwrap();

        // The client checks the server signature: mutual auth complete.
        client.verify_server(server_final.as_bytes()).unwrap();
    }

    #[test]
    fn wrong_password_is_rejected() {
        let verifier = ScramVerifier::from_password("rightpw", b"saltsaltsaltsalt".to_vec(), 4096);
        let (mut client, init) = Scram::client_first("nonceAAA");
        let (server, server_first) =
            ScramServer::start(verifier, client_first_of(&init.0), "nonceBBB").unwrap();

        // The client answers with the WRONG password.
        let client_final = client.client_final(server_first.as_bytes(), "wrongpw").unwrap();
        let outcome = server.finish(std::str::from_utf8(&client_final.0).unwrap());
        assert!(outcome.is_err(), "wrong password must be rejected");
    }

    #[test]
    fn auth_file_parses_plaintext_and_verifier() {
        let carol = ScramVerifier::from_password("pw", b"0123456789abcdef".to_vec(), 4096);
        let body = format!(
            "# comment\nalice:secret\n\nbob:\"quoted\"\ncarol:{}\n",
            pg_format(&carol)
        );
        let af = AuthFile::parse_str(&body, "test").unwrap();
        for user in &["alice", "bob", "carol"] {
            assert!(af.get(user).is_some());
        }
        assert!(af.get("dave").is_none());
    }
}