Skip to main content

heliosdb_proxy/backend/
auth.rs

1//! PostgreSQL client-side authentication helpers.
2//!
3//! Covers the two mechanisms we need today:
4//! - **MD5** (AuthenticationMD5Password, request code 5). Legacy but
5//!   still widely deployed. Payload is
6//!   `"md5" + hex(md5(hex(md5(password + username)) + salt))`.
7//! - **SCRAM-SHA-256** (AuthenticationSASL, mechanism
8//!   `SCRAM-SHA-256`, request code 10). The current PG default.
9//!
10//! Both implementations verify the server's end of the handshake where
11//! the protocol allows it — MD5 has no server-side verifier, SCRAM does
12//! (the server-final message includes `v=<server-signature>`).
13
14use super::error::{BackendError, BackendResult};
15use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
16use hmac::{Hmac, Mac};
17use sha2::{Digest, Sha256};
18
19type HmacSha256 = Hmac<Sha256>;
20
21// ---------------------------------------------------------------------
22// MD5 authentication
23// ---------------------------------------------------------------------
24
25/// Compute the response payload for `AuthenticationMD5Password`.
26///
27/// Returns the complete `PasswordMessage` payload (the null-terminated
28/// string the client sends back), excluding the tag and length prefix
29/// — the caller frames it.
30pub fn md5_password_response(user: &str, password: &str, salt: &[u8; 4]) -> Vec<u8> {
31    let mut out = Vec::with_capacity(35 + 1);
32    let inner = md5_hex(format!("{}{}", password, user).as_bytes());
33    let mut salted = Vec::with_capacity(inner.len() + 4);
34    salted.extend_from_slice(inner.as_bytes());
35    salted.extend_from_slice(salt);
36    out.extend_from_slice(b"md5");
37    out.extend_from_slice(md5_hex(&salted).as_bytes());
38    out.push(0);
39    out
40}
41
42fn md5_hex(bytes: &[u8]) -> String {
43    let digest = md5::Md5::digest(bytes);
44    let mut s = String::with_capacity(digest.len() * 2);
45    for b in digest {
46        s.push_str(&format!("{:02x}", b));
47    }
48    s
49}
50
51// ---------------------------------------------------------------------
52// SCRAM-SHA-256
53// ---------------------------------------------------------------------
54
/// SCRAM client state machine. Create with `Scram::client_first`, feed
/// the server-first into `client_final`, and feed the server-final into
/// `verify_server`.
pub struct Scram {
    /// The `n=,r=<nonce>` text sent in client-first. `client_final`
    /// uses it as the first component of the AuthMessage.
    client_first_bare: String,
    /// The client nonce. `client_final` checks that the nonce returned
    /// by the server starts with it.
    nonce: String,
    /// Set by `client_final` (all zeroes before that); keys the HMAC
    /// that `verify_server` compares against the server signature.
    server_key: [u8; 32],
    /// Full AuthMessage, set by `client_final` (empty before that).
    auth_message: String,
    /// Whether `client_final` ran (guards `verify_server`).
    finalised: bool,
}
70
/// Result of one SCRAM step: the opaque bytes to send to the server.
///
/// Produced by `Scram::client_first` and `Scram::client_final`. The
/// wrapped bytes are a message payload only; framing is left to the
/// caller.
#[derive(Debug)]
pub struct ScramMessage(pub Vec<u8>);
74
75impl Scram {
76    /// Build the SASL initial response for `SCRAM-SHA-256`.
77    ///
78    /// The returned bytes are the payload of a `PasswordMessage`
79    /// (tag `p`). `nonce` must be a unique random string per session —
80    /// the caller provides it so the function is testable.
81    pub fn client_first(nonce: impl Into<String>) -> (Self, ScramMessage) {
82        let nonce = nonce.into();
83        // gs2-header is "n,," (no channel binding, no authzid).
84        // client-first-bare is n=<user>,r=<nonce>. PG ignores <user>
85        // and takes the name from the StartupMessage; we send empty.
86        let client_first_bare = format!("n=,r={}", nonce);
87        let client_first = format!("n,,{}", client_first_bare);
88
89        // SASL format: mechanism + NUL + 4-byte BE length + bytes.
90        let mech = b"SCRAM-SHA-256\0";
91        let mut out = Vec::with_capacity(mech.len() + 4 + client_first.len());
92        out.extend_from_slice(mech);
93        out.extend_from_slice(&(client_first.len() as u32).to_be_bytes());
94        out.extend_from_slice(client_first.as_bytes());
95
96        (
97            Self {
98                client_first_bare,
99                nonce,
100                server_key: [0u8; 32],
101                auth_message: String::new(),
102                finalised: false,
103            },
104            ScramMessage(out),
105        )
106    }
107
108    /// Consume the server-first message and produce the client-final.
109    ///
110    /// `server_first` is the raw bytes the server sent (the payload of
111    /// an `AuthenticationSASLContinue` frame, minus the 4-byte type
112    /// code which the caller strips).
113    pub fn client_final(
114        &mut self,
115        server_first: &[u8],
116        password: &str,
117    ) -> BackendResult<ScramMessage> {
118        let server_first_str = std::str::from_utf8(server_first)
119            .map_err(|e| BackendError::Auth(format!("server-first is not UTF-8: {}", e)))?;
120
121        // Parse r=<combined-nonce>,s=<salt-base64>,i=<iteration-count>
122        let mut server_nonce = None;
123        let mut salt_b64 = None;
124        let mut iterations: Option<u32> = None;
125        for field in server_first_str.split(',') {
126            if let Some(rest) = field.strip_prefix("r=") {
127                server_nonce = Some(rest);
128            } else if let Some(rest) = field.strip_prefix("s=") {
129                salt_b64 = Some(rest);
130            } else if let Some(rest) = field.strip_prefix("i=") {
131                iterations = rest.parse().ok();
132            }
133        }
134        let server_nonce =
135            server_nonce.ok_or_else(|| BackendError::Auth("missing r= in server-first".into()))?;
136        let salt_b64 =
137            salt_b64.ok_or_else(|| BackendError::Auth("missing s= in server-first".into()))?;
138        let iterations = iterations
139            .ok_or_else(|| BackendError::Auth("missing/invalid i= in server-first".into()))?;
140
141        // The server must echo the client nonce as a prefix.
142        if !server_nonce.starts_with(&self.nonce) {
143            return Err(BackendError::Auth(
144                "server nonce does not extend client nonce".into(),
145            ));
146        }
147        if iterations < 1 {
148            return Err(BackendError::Auth("iteration count must be >= 1".into()));
149        }
150
151        let salt = BASE64
152            .decode(salt_b64)
153            .map_err(|e| BackendError::Auth(format!("bad salt base64: {}", e)))?;
154
155        // Derive keys per RFC 5802.
156        let salted_password = pbkdf2_hmac_sha256(password.as_bytes(), &salt, iterations);
157        let client_key = hmac_sha256(&salted_password, b"Client Key");
158        let stored_key = sha256(&client_key);
159        self.server_key = hmac_sha256(&salted_password, b"Server Key");
160
161        // channel-binding: "c=" + base64("n,,")
162        let channel_binding = BASE64.encode(b"n,,");
163
164        let client_final_without_proof = format!("c={},r={}", channel_binding, server_nonce);
165        self.auth_message = format!(
166            "{},{},{}",
167            self.client_first_bare, server_first_str, client_final_without_proof
168        );
169
170        let client_signature = hmac_sha256(&stored_key, self.auth_message.as_bytes());
171        let mut client_proof = [0u8; 32];
172        for i in 0..32 {
173            client_proof[i] = client_key[i] ^ client_signature[i];
174        }
175
176        let client_final = format!(
177            "{},p={}",
178            client_final_without_proof,
179            BASE64.encode(client_proof)
180        );
181
182        self.finalised = true;
183        Ok(ScramMessage(client_final.into_bytes()))
184    }
185
186    /// Verify the server-final message's `v=<server-signature>` tag.
187    /// Returns `Ok(())` only if the signature matches what we expect
188    /// from the derived `server_key`.
189    pub fn verify_server(&self, server_final: &[u8]) -> BackendResult<()> {
190        if !self.finalised {
191            return Err(BackendError::Auth(
192                "verify_server called before client_final".into(),
193            ));
194        }
195        let s = std::str::from_utf8(server_final)
196            .map_err(|e| BackendError::Auth(format!("server-final is not UTF-8: {}", e)))?;
197        // Server may send `e=<error>` on failure.
198        if let Some(err) = s.strip_prefix("e=") {
199            return Err(BackendError::Auth(format!("server reported: {}", err)));
200        }
201        let sig_b64 = s
202            .strip_prefix("v=")
203            .ok_or_else(|| BackendError::Auth("missing v= in server-final".into()))?
204            .split(',')
205            .next()
206            .unwrap_or("");
207        let received = BASE64
208            .decode(sig_b64)
209            .map_err(|e| BackendError::Auth(format!("bad v= base64: {}", e)))?;
210        let expected = hmac_sha256(&self.server_key, self.auth_message.as_bytes());
211        if received == expected {
212            Ok(())
213        } else {
214            Err(BackendError::Auth("server signature mismatch".into()))
215        }
216    }
217}
218
219// ---- crypto primitives --------------------------------------------------
220
221pub(crate) fn sha256(data: &[u8]) -> [u8; 32] {
222    let mut h = Sha256::new();
223    h.update(data);
224    h.finalize().into()
225}
226
227pub(crate) fn hmac_sha256(key: &[u8], data: &[u8]) -> [u8; 32] {
228    let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
229    mac.update(data);
230    let tag = mac.finalize().into_bytes();
231    let mut out = [0u8; 32];
232    out.copy_from_slice(&tag);
233    out
234}
235
236pub(crate) fn pbkdf2_hmac_sha256(password: &[u8], salt: &[u8], iters: u32) -> [u8; 32] {
237    // Single-block PBKDF2 (dkLen == hLen == 32) — exactly what SCRAM
238    // requires.
239    let mut mac = HmacSha256::new_from_slice(password).expect("HMAC accepts any key length");
240    mac.update(salt);
241    mac.update(&1u32.to_be_bytes());
242    let mut u: [u8; 32] = mac.finalize().into_bytes().into();
243    let mut out = u;
244    for _ in 1..iters {
245        let mut mac = HmacSha256::new_from_slice(password).expect("HMAC accepts any key length");
246        mac.update(&u);
247        u = mac.finalize().into_bytes().into();
248        for i in 0..32 {
249            out[i] ^= u[i];
250        }
251    }
252    out
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Known-answer MD5 auth per PostgreSQL docs:
    /// `concat('md5', md5(md5(password || username) || salt))`.
    #[test]
    fn test_md5_password_response_known_answer() {
        // username = "alice", password = "secret", salt = [0x01,0x02,0x03,0x04]
        let got = md5_password_response("alice", "secret", &[0x01, 0x02, 0x03, 0x04]);
        // Last byte is the cstring terminator.
        assert_eq!(got.last().copied(), Some(0u8));
        let body = std::str::from_utf8(&got[..got.len() - 1]).unwrap();
        assert!(body.starts_with("md5"));
        assert_eq!(body.len(), 3 + 32); // "md5" + 32 hex chars

        // Re-derive and compare.
        let inner = md5_hex(b"secretalice");
        let mut combined = inner.into_bytes();
        combined.extend_from_slice(&[0x01, 0x02, 0x03, 0x04]);
        let outer = md5_hex(&combined);
        assert_eq!(&body[3..], outer);
    }

    /// PBKDF2-HMAC-SHA-256 known answer for P="password", S="salt",
    /// c=1, dkLen=32. The inputs are the ones RFC 6070 uses for its
    /// SHA-1 vectors; the SHA-256 output below is the commonly
    /// published value for them (it is not taken from RFC 7914 or
    /// RFC 5802, neither of which lists this vector).
    #[test]
    fn test_pbkdf2_hmac_sha256_rfc_vector() {
        let got = pbkdf2_hmac_sha256(b"password", b"salt", 1);
        let expected = [
            0x12, 0x0f, 0xb6, 0xcf, 0xfc, 0xf8, 0xb3, 0x2c, 0x43, 0xe7, 0x22, 0x52, 0x56, 0xc4,
            0xf8, 0x37, 0xa8, 0x65, 0x48, 0xc9, 0x2c, 0xcc, 0x35, 0x48, 0x08, 0x05, 0x98, 0x7c,
            0xb7, 0x0b, 0xe1, 0x7b,
        ];
        assert_eq!(got, expected);
    }

    /// Higher iteration count — smoke test that the loop accumulates
    /// correctly. Same inputs as above with c=4096.
    #[test]
    fn test_pbkdf2_hmac_sha256_high_iters() {
        let got = pbkdf2_hmac_sha256(b"password", b"salt", 4096);
        let expected = [
            0xc5, 0xe4, 0x78, 0xd5, 0x92, 0x88, 0xc8, 0x41, 0xaa, 0x53, 0x0d, 0xb6, 0x84, 0x5c,
            0x4c, 0x8d, 0x96, 0x28, 0x93, 0xa0, 0x01, 0xce, 0x4e, 0x11, 0xa4, 0x96, 0x38, 0x73,
            0xaa, 0x98, 0x13, 0x4a,
        ];
        assert_eq!(got, expected);
    }

    /// Full SCRAM-SHA-256 round-trip against a synthetic server that
    /// follows RFC 5802 mechanics with PG-compatible message shape.
    /// This is the end-to-end property test: client_first -> server
    /// crafts server_first -> client_final -> server verifies +
    /// replies server_final -> client verify_server succeeds.
    #[test]
    fn test_scram_roundtrip_against_synthetic_server() {
        // Client nonce.
        let (mut scram, first) = Scram::client_first("fyko+d2lbbFgONRv9qkxdawL");
        // Parse the mechanism header out of client_first:
        // "SCRAM-SHA-256\0<u32 len><bytes>"
        let msg = &first.0;
        let mech_end = msg.iter().position(|&b| b == 0).unwrap();
        assert_eq!(&msg[..mech_end], b"SCRAM-SHA-256");
        let len = u32::from_be_bytes(msg[mech_end + 1..mech_end + 5].try_into().unwrap()) as usize;
        let cfirst = &msg[mech_end + 5..mech_end + 5 + len];
        let cfirst_str = std::str::from_utf8(cfirst).unwrap();
        assert!(cfirst_str.starts_with("n,,n=,r=fyko+d2lbbFgONRv9qkxdawL"));

        // ---- synthetic server ----
        let server_nonce_suffix = "3rfcNHYJY1ZVvWVs7j";
        let combined_nonce = format!("fyko+d2lbbFgONRv9qkxdawL{}", server_nonce_suffix);
        let salt: [u8; 16] = [
            0x41, 0x25, 0xc2, 0x47, 0xe4, 0x3a, 0xb1, 0xe9, 0x3c, 0x6d, 0xff, 0x76, 0xd1, 0x22,
            0x3a, 0x10,
        ];
        let iterations = 4096u32;
        let salt_b64 = BASE64.encode(salt);
        let server_first = format!("r={},s={},i={}", combined_nonce, salt_b64, iterations);

        let password = "pencil";
        let client_final = scram
            .client_final(server_first.as_bytes(), password)
            .expect("client_final");
        let cfinal_str = std::str::from_utf8(&client_final.0).unwrap();

        // Expected pieces present.
        assert!(cfinal_str.starts_with("c=biws,r=")); // base64("n,,") = "biws"
        assert!(cfinal_str.contains(&format!("r={}", combined_nonce)));
        assert!(cfinal_str.contains(",p="));

        // Server-side: derive the same server_key, build AuthMessage from
        // the pieces we know, then sign.
        let salted = pbkdf2_hmac_sha256(password.as_bytes(), &salt, iterations);
        let server_key = hmac_sha256(&salted, b"Server Key");
        let (cfinal_no_proof, _proof) = {
            let idx = cfinal_str.rfind(",p=").unwrap();
            (&cfinal_str[..idx], &cfinal_str[idx + 3..])
        };
        let auth_message = format!(
            "n=,r=fyko+d2lbbFgONRv9qkxdawL,{},{}",
            server_first, cfinal_no_proof
        );
        let server_sig = hmac_sha256(&server_key, auth_message.as_bytes());
        let server_final = format!("v={}", BASE64.encode(server_sig));

        // Client verifies.
        scram
            .verify_server(server_final.as_bytes())
            .expect("verify_server");
    }

    #[test]
    fn test_scram_rejects_nonce_mismatch() {
        let (mut scram, _) = Scram::client_first("client-nonce");
        let server_first = "r=OTHER-nonce,s=QUJD,i=4096";
        let err = scram
            .client_final(server_first.as_bytes(), "pw")
            .unwrap_err();
        assert!(matches!(err, BackendError::Auth(_)));
    }

    /// Every malformed server-first must be refused: a missing `r=`,
    /// `s=` or `i=`, a non-numeric or zero iteration count, and a salt
    /// that is not valid base64.
    #[test]
    fn test_scram_rejects_malformed_server_first() {
        let malformed = [
            "s=QUJD,i=4096",
            "r=abc-extension,i=4096",
            "r=abc-extension,s=QUJD",
            "r=abc-extension,s=QUJD,i=many",
            "r=abc-extension,s=QUJD,i=0",
            "r=abc-extension,s=!!!,i=4096",
        ];
        for server_first in malformed.iter() {
            let (mut scram, _) = Scram::client_first("abc");
            assert!(
                scram.client_final(server_first.as_bytes(), "pw").is_err(),
                "accepted malformed server-first: {}",
                server_first
            );
        }
    }

    /// `verify_server` must refuse to run before `client_final` has
    /// derived the server key.
    #[test]
    fn test_scram_verify_server_requires_client_final() {
        let (scram, _) = Scram::client_first("abc");
        let server_final = format!("v={}", BASE64.encode([0u8; 32]));
        assert!(scram.verify_server(server_final.as_bytes()).is_err());
    }

    #[test]
    fn test_scram_rejects_bad_server_signature() {
        let (mut scram, _) = Scram::client_first("abc");
        // Set up with valid server-first so client_final succeeds.
        let server_first = "r=abc-extension,s=QUJD,i=4096";
        let _ = scram.client_final(server_first.as_bytes(), "pw").unwrap();
        // Then fake a server-final with a wrong signature.
        let bad_sig = BASE64.encode([0u8; 32]);
        let server_final = format!("v={}", bad_sig);
        assert!(scram.verify_server(server_final.as_bytes()).is_err());
    }

    /// A server-final without `v=`, or with a signature of the wrong
    /// length, is rejected.
    #[test]
    fn test_scram_rejects_malformed_server_final() {
        let (mut scram, _) = Scram::client_first("abc");
        let server_first = "r=abc-extension,s=QUJD,i=4096";
        let _ = scram.client_final(server_first.as_bytes(), "pw").unwrap();
        assert!(scram.verify_server(b"x=unexpected").is_err());
        let short_sig = format!("v={}", BASE64.encode([0u8; 16]));
        assert!(scram.verify_server(short_sig.as_bytes()).is_err());
    }

    #[test]
    fn test_scram_rejects_server_error() {
        let (mut scram, _) = Scram::client_first("abc");
        let server_first = "r=abc-extension,s=QUJD,i=4096";
        let _ = scram.client_final(server_first.as_bytes(), "pw").unwrap();
        assert!(scram.verify_server(b"e=invalid-proof").is_err());
    }
}
395}