Skip to main content

heliosdb_proxy/backend/
auth.rs

1//! PostgreSQL client-side authentication helpers.
2//!
3//! Covers the two mechanisms we need today:
4//! - **MD5** (AuthenticationMD5Password, request code 5). Legacy but
5//!   still widely deployed. Payload is
6//!   `"md5" + hex(md5(hex(md5(password + username)) + salt))`.
7//! - **SCRAM-SHA-256** (AuthenticationSASL, mechanism
8//!   `SCRAM-SHA-256`, request code 10). The current PG default.
9//!
10//! Both implementations verify the server's end of the handshake where
11//! the protocol allows it — MD5 has no server-side verifier, SCRAM does
12//! (the server-final message includes `v=<server-signature>`).
13
14use super::error::{BackendError, BackendResult};
15use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
16use hmac::{Hmac, Mac};
17use sha2::{Digest, Sha256};
18
19type HmacSha256 = Hmac<Sha256>;
20
21// ---------------------------------------------------------------------
22// MD5 authentication
23// ---------------------------------------------------------------------
24
25/// Compute the response payload for `AuthenticationMD5Password`.
26///
27/// Returns the complete `PasswordMessage` payload (the null-terminated
28/// string the client sends back), excluding the tag and length prefix
29/// — the caller frames it.
30pub fn md5_password_response(user: &str, password: &str, salt: &[u8; 4]) -> Vec<u8> {
31    let mut out = Vec::with_capacity(35 + 1);
32    let inner = md5_hex(format!("{}{}", password, user).as_bytes());
33    let mut salted = Vec::with_capacity(inner.len() + 4);
34    salted.extend_from_slice(inner.as_bytes());
35    salted.extend_from_slice(salt);
36    out.extend_from_slice(b"md5");
37    out.extend_from_slice(md5_hex(&salted).as_bytes());
38    out.push(0);
39    out
40}
41
42fn md5_hex(bytes: &[u8]) -> String {
43    let digest = md5::Md5::digest(bytes);
44    let mut s = String::with_capacity(digest.len() * 2);
45    for b in digest {
46        s.push_str(&format!("{:02x}", b));
47    }
48    s
49}
50
51// ---------------------------------------------------------------------
52// SCRAM-SHA-256
53// ---------------------------------------------------------------------
54
/// Client half of a SCRAM-SHA-256 exchange.
///
/// Usage is strictly sequential: `Scram::client_first` creates the
/// state and the first message, `client_final` consumes the
/// server-first message, and `verify_server` checks the server-final
/// message.
pub struct Scram {
    /// The `n=,r=<nonce>` text sent in the first message; it becomes
    /// the first component of the AuthMessage.
    client_first_bare: String,
    /// The client nonce; the nonce echoed in server-first must start
    /// with it.
    nonce: String,
    /// All zeroes until `client_final` derives it; keys the HMAC used
    /// to check the server signature.
    server_key: [u8; 32],
    /// Empty until `client_final` assembles it; the text that both the
    /// client proof and the server signature are computed over.
    auth_message: String,
    /// Set once `client_final` has succeeded; `verify_server` refuses
    /// to run while this is `false`.
    finalised: bool,
}
70
/// Output of a single SCRAM step: the raw bytes the caller should
/// forward to the server. The wrapped vector is public so the caller
/// can frame it as needed.
#[derive(Debug)]
pub struct ScramMessage(pub Vec<u8>);
74
75impl Scram {
76    /// Build the SASL initial response for `SCRAM-SHA-256`.
77    ///
78    /// The returned bytes are the payload of a `PasswordMessage`
79    /// (tag `p`). `nonce` must be a unique random string per session —
80    /// the caller provides it so the function is testable.
81    pub fn client_first(nonce: impl Into<String>) -> (Self, ScramMessage) {
82        let nonce = nonce.into();
83        // gs2-header is "n,," (no channel binding, no authzid).
84        // client-first-bare is n=<user>,r=<nonce>. PG ignores <user>
85        // and takes the name from the StartupMessage; we send empty.
86        let client_first_bare = format!("n=,r={}", nonce);
87        let client_first = format!("n,,{}", client_first_bare);
88
89        // SASL format: mechanism + NUL + 4-byte BE length + bytes.
90        let mech = b"SCRAM-SHA-256\0";
91        let mut out = Vec::with_capacity(mech.len() + 4 + client_first.len());
92        out.extend_from_slice(mech);
93        out.extend_from_slice(&(client_first.len() as u32).to_be_bytes());
94        out.extend_from_slice(client_first.as_bytes());
95
96        (
97            Self {
98                client_first_bare,
99                nonce,
100                server_key: [0u8; 32],
101                auth_message: String::new(),
102                finalised: false,
103            },
104            ScramMessage(out),
105        )
106    }
107
108    /// Consume the server-first message and produce the client-final.
109    ///
110    /// `server_first` is the raw bytes the server sent (the payload of
111    /// an `AuthenticationSASLContinue` frame, minus the 4-byte type
112    /// code which the caller strips).
113    pub fn client_final(
114        &mut self,
115        server_first: &[u8],
116        password: &str,
117    ) -> BackendResult<ScramMessage> {
118        let server_first_str = std::str::from_utf8(server_first).map_err(|e| {
119            BackendError::Auth(format!("server-first is not UTF-8: {}", e))
120        })?;
121
122        // Parse r=<combined-nonce>,s=<salt-base64>,i=<iteration-count>
123        let mut server_nonce = None;
124        let mut salt_b64 = None;
125        let mut iterations: Option<u32> = None;
126        for field in server_first_str.split(',') {
127            if let Some(rest) = field.strip_prefix("r=") {
128                server_nonce = Some(rest);
129            } else if let Some(rest) = field.strip_prefix("s=") {
130                salt_b64 = Some(rest);
131            } else if let Some(rest) = field.strip_prefix("i=") {
132                iterations = rest.parse().ok();
133            }
134        }
135        let server_nonce = server_nonce
136            .ok_or_else(|| BackendError::Auth("missing r= in server-first".into()))?;
137        let salt_b64 = salt_b64
138            .ok_or_else(|| BackendError::Auth("missing s= in server-first".into()))?;
139        let iterations = iterations
140            .ok_or_else(|| BackendError::Auth("missing/invalid i= in server-first".into()))?;
141
142        // The server must echo the client nonce as a prefix.
143        if !server_nonce.starts_with(&self.nonce) {
144            return Err(BackendError::Auth(
145                "server nonce does not extend client nonce".into(),
146            ));
147        }
148        if iterations < 1 {
149            return Err(BackendError::Auth("iteration count must be >= 1".into()));
150        }
151
152        let salt = BASE64
153            .decode(salt_b64)
154            .map_err(|e| BackendError::Auth(format!("bad salt base64: {}", e)))?;
155
156        // Derive keys per RFC 5802.
157        let salted_password = pbkdf2_hmac_sha256(password.as_bytes(), &salt, iterations);
158        let client_key = hmac_sha256(&salted_password, b"Client Key");
159        let stored_key = sha256(&client_key);
160        self.server_key = hmac_sha256(&salted_password, b"Server Key");
161
162        // channel-binding: "c=" + base64("n,,")
163        let channel_binding = BASE64.encode(b"n,,");
164
165        let client_final_without_proof =
166            format!("c={},r={}", channel_binding, server_nonce);
167        self.auth_message = format!(
168            "{},{},{}",
169            self.client_first_bare, server_first_str, client_final_without_proof
170        );
171
172        let client_signature = hmac_sha256(&stored_key, self.auth_message.as_bytes());
173        let mut client_proof = [0u8; 32];
174        for i in 0..32 {
175            client_proof[i] = client_key[i] ^ client_signature[i];
176        }
177
178        let client_final = format!(
179            "{},p={}",
180            client_final_without_proof,
181            BASE64.encode(client_proof)
182        );
183
184        self.finalised = true;
185        Ok(ScramMessage(client_final.into_bytes()))
186    }
187
188    /// Verify the server-final message's `v=<server-signature>` tag.
189    /// Returns `Ok(())` only if the signature matches what we expect
190    /// from the derived `server_key`.
191    pub fn verify_server(&self, server_final: &[u8]) -> BackendResult<()> {
192        if !self.finalised {
193            return Err(BackendError::Auth(
194                "verify_server called before client_final".into(),
195            ));
196        }
197        let s = std::str::from_utf8(server_final).map_err(|e| {
198            BackendError::Auth(format!("server-final is not UTF-8: {}", e))
199        })?;
200        // Server may send `e=<error>` on failure.
201        if let Some(err) = s.strip_prefix("e=") {
202            return Err(BackendError::Auth(format!("server reported: {}", err)));
203        }
204        let sig_b64 = s
205            .strip_prefix("v=")
206            .ok_or_else(|| BackendError::Auth("missing v= in server-final".into()))?
207            .split(',')
208            .next()
209            .unwrap_or("");
210        let received = BASE64
211            .decode(sig_b64)
212            .map_err(|e| BackendError::Auth(format!("bad v= base64: {}", e)))?;
213        let expected = hmac_sha256(&self.server_key, self.auth_message.as_bytes());
214        if received == expected {
215            Ok(())
216        } else {
217            Err(BackendError::Auth("server signature mismatch".into()))
218        }
219    }
220}
221
222// ---- crypto primitives --------------------------------------------------
223
224fn sha256(data: &[u8]) -> [u8; 32] {
225    let mut h = Sha256::new();
226    h.update(data);
227    h.finalize().into()
228}
229
230fn hmac_sha256(key: &[u8], data: &[u8]) -> [u8; 32] {
231    let mut mac =
232        HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
233    mac.update(data);
234    let tag = mac.finalize().into_bytes();
235    let mut out = [0u8; 32];
236    out.copy_from_slice(&tag);
237    out
238}
239
240fn pbkdf2_hmac_sha256(password: &[u8], salt: &[u8], iters: u32) -> [u8; 32] {
241    // Single-block PBKDF2 (dkLen == hLen == 32) — exactly what SCRAM
242    // requires.
243    let mut mac = HmacSha256::new_from_slice(password)
244        .expect("HMAC accepts any key length");
245    mac.update(salt);
246    mac.update(&1u32.to_be_bytes());
247    let mut u: [u8; 32] = mac.finalize().into_bytes().into();
248    let mut out = u;
249    for _ in 1..iters {
250        let mut mac = HmacSha256::new_from_slice(password)
251            .expect("HMAC accepts any key length");
252        mac.update(&u);
253        u = mac.finalize().into_bytes().into();
254        for i in 0..32 {
255            out[i] ^= u[i];
256        }
257    }
258    out
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Structural check of the MD5 response against the PostgreSQL
    /// formula `concat('md5', md5(md5(password || username) || salt))`.
    ///
    /// The expected digest is re-derived through `md5_hex`, so this
    /// verifies how the pieces are composed and framed rather than the
    /// MD5 primitive itself.
    #[test]
    fn test_md5_password_response_known_answer() {
        // username = "alice", password = "secret", salt = [0x01,0x02,0x03,0x04]
        let got = md5_password_response("alice", "secret", &[0x01, 0x02, 0x03, 0x04]);
        // Last byte is the cstring terminator.
        assert_eq!(got.last().copied(), Some(0u8));
        let body = std::str::from_utf8(&got[..got.len() - 1]).unwrap();
        assert!(body.starts_with("md5"));
        assert_eq!(body.len(), 3 + 32); // "md5" + 32 hex chars
        // Re-derive and compare.
        let inner = md5_hex(b"secretalice");
        let mut combined = inner.into_bytes();
        combined.extend_from_slice(&[0x01, 0x02, 0x03, 0x04]);
        let outer = md5_hex(&combined);
        assert_eq!(&body[3..], outer);
    }

    /// PBKDF2-HMAC-SHA-256 known answer for P="password", S="salt",
    /// c=1, dkLen=32.
    ///
    /// NOTE(review): this vector does not appear to come from RFC 7914
    /// or RFC 5802; the inputs match the RFC 6070 (PBKDF2-HMAC-SHA-1)
    /// test cases and the output is the commonly published SHA-256
    /// value for them — confirm provenance.
    #[test]
    fn test_pbkdf2_hmac_sha256_rfc_vector() {
        let got = pbkdf2_hmac_sha256(b"password", b"salt", 1);
        let expected = [
            0x12, 0x0f, 0xb6, 0xcf, 0xfc, 0xf8, 0xb3, 0x2c, 0x43, 0xe7, 0x22, 0x52,
            0x56, 0xc4, 0xf8, 0x37, 0xa8, 0x65, 0x48, 0xc9, 0x2c, 0xcc, 0x35, 0x48,
            0x08, 0x05, 0x98, 0x7c, 0xb7, 0x0b, 0xe1, 0x7b,
        ];
        assert_eq!(got, expected);
    }

    /// Higher iteration count — smoke test that the loop accumulates
    /// correctly. Same inputs as the c=1 vector, with c=4096.
    #[test]
    fn test_pbkdf2_hmac_sha256_high_iters() {
        let got = pbkdf2_hmac_sha256(b"password", b"salt", 4096);
        let expected = [
            0xc5, 0xe4, 0x78, 0xd5, 0x92, 0x88, 0xc8, 0x41, 0xaa, 0x53, 0x0d, 0xb6,
            0x84, 0x5c, 0x4c, 0x8d, 0x96, 0x28, 0x93, 0xa0, 0x01, 0xce, 0x4e, 0x11,
            0xa4, 0x96, 0x38, 0x73, 0xaa, 0x98, 0x13, 0x4a,
        ];
        assert_eq!(got, expected);
    }

    /// Full SCRAM-SHA-256 round-trip against a synthetic server that
    /// follows RFC 5802 mechanics with PG-compatible message shape.
    /// This is the end-to-end property test: client_first -> server
    /// crafts server_first -> client_final -> server verifies +
    /// replies server_final -> client verify_server succeeds.
    #[test]
    fn test_scram_roundtrip_against_synthetic_server() {
        // Client nonce.
        let (mut scram, first) = Scram::client_first("fyko+d2lbbFgONRv9qkxdawL");
        // Parse the mechanism header out of client_first:
        // "SCRAM-SHA-256\0<u32 len><bytes>"
        let msg = &first.0;
        let mech_end = msg.iter().position(|&b| b == 0).unwrap();
        assert_eq!(&msg[..mech_end], b"SCRAM-SHA-256");
        let len =
            u32::from_be_bytes(msg[mech_end + 1..mech_end + 5].try_into().unwrap())
                as usize;
        let cfirst = &msg[mech_end + 5..mech_end + 5 + len];
        let cfirst_str = std::str::from_utf8(cfirst).unwrap();
        assert!(cfirst_str.starts_with("n,,n=,r=fyko+d2lbbFgONRv9qkxdawL"));

        // ---- synthetic server ----
        let server_nonce_suffix = "3rfcNHYJY1ZVvWVs7j";
        let combined_nonce =
            format!("fyko+d2lbbFgONRv9qkxdawL{}", server_nonce_suffix);
        let salt: [u8; 16] = [
            0x41, 0x25, 0xc2, 0x47, 0xe4, 0x3a, 0xb1, 0xe9, 0x3c, 0x6d, 0xff, 0x76,
            0xd1, 0x22, 0x3a, 0x10,
        ];
        let iterations = 4096u32;
        let salt_b64 = BASE64.encode(salt);
        let server_first = format!(
            "r={},s={},i={}",
            combined_nonce, salt_b64, iterations
        );

        let password = "pencil";
        let client_final = scram
            .client_final(server_first.as_bytes(), password)
            .expect("client_final");
        let cfinal_str = std::str::from_utf8(&client_final.0).unwrap();

        // Expected pieces present.
        assert!(cfinal_str.starts_with("c=biws,r=")); // base64("n,,") = "biws"
        assert!(cfinal_str.contains(&format!("r={}", combined_nonce)));
        assert!(cfinal_str.contains(",p="));

        // Server-side: derive the same server_key, build AuthMessage from
        // the pieces we know, then sign.
        let salted = pbkdf2_hmac_sha256(password.as_bytes(), &salt, iterations);
        let server_key = hmac_sha256(&salted, b"Server Key");
        let (cfinal_no_proof, _proof) = {
            let idx = cfinal_str.rfind(",p=").unwrap();
            (&cfinal_str[..idx], &cfinal_str[idx + 3..])
        };
        let auth_message = format!(
            "n=,r=fyko+d2lbbFgONRv9qkxdawL,{},{}",
            server_first, cfinal_no_proof
        );
        let server_sig = hmac_sha256(&server_key, auth_message.as_bytes());
        let server_final = format!("v={}", BASE64.encode(server_sig));

        // Client verifies.
        scram
            .verify_server(server_final.as_bytes())
            .expect("verify_server");
    }

    #[test]
    fn test_scram_rejects_nonce_mismatch() {
        let (mut scram, _) = Scram::client_first("client-nonce");
        let server_first = "r=OTHER-nonce,s=QUJD,i=4096";
        let err = scram.client_final(server_first.as_bytes(), "pw").unwrap_err();
        assert!(matches!(err, BackendError::Auth(_)));
    }

    /// Every required server-first field must be present and well
    /// formed; each variant below breaks exactly one of them.
    #[test]
    fn test_scram_rejects_malformed_server_first() {
        let cases: &[&str] = &[
            "s=QUJD,i=4096",                  // missing r=
            "r=abc-extension,i=4096",         // missing s=
            "r=abc-extension,s=QUJD",         // missing i=
            "r=abc-extension,s=QUJD,i=many",  // non-numeric i=
            "r=abc-extension,s=QUJD,i=0",     // zero iterations
            "r=abc-extension,s=!!!,i=4096",   // salt is not base64
        ];
        for server_first in cases {
            let (mut scram, _) = Scram::client_first("abc");
            let result = scram.client_final(server_first.as_bytes(), "pw");
            assert!(
                matches!(result, Err(BackendError::Auth(_))),
                "server-first {:?} should be rejected",
                server_first
            );
        }
    }

    /// `verify_server` must refuse to run until `client_final` has
    /// derived the server key.
    #[test]
    fn test_scram_verify_server_requires_client_final() {
        let (scram, _) = Scram::client_first("abc");
        let err = scram.verify_server(b"v=AAAA").unwrap_err();
        assert!(matches!(err, BackendError::Auth(_)));
    }

    #[test]
    fn test_scram_rejects_bad_server_signature() {
        let (mut scram, _) = Scram::client_first("abc");
        // Set up with valid server-first so client_final succeeds.
        let server_first = "r=abc-extension,s=QUJD,i=4096";
        let _ = scram.client_final(server_first.as_bytes(), "pw").unwrap();
        // Then fake a server-final with a wrong signature.
        let bad_sig = BASE64.encode([0u8; 32]);
        let server_final = format!("v={}", bad_sig);
        assert!(scram.verify_server(server_final.as_bytes()).is_err());
    }

    /// A server-final without `v=`, or with a signature of the wrong
    /// length, is rejected.
    #[test]
    fn test_scram_rejects_malformed_server_final() {
        let (mut scram, _) = Scram::client_first("abc");
        let server_first = "r=abc-extension,s=QUJD,i=4096";
        let _ = scram.client_final(server_first.as_bytes(), "pw").unwrap();

        assert!(scram.verify_server(b"garbage").is_err());

        let short_sig = format!("v={}", BASE64.encode([0u8; 16]));
        assert!(scram.verify_server(short_sig.as_bytes()).is_err());
    }

    #[test]
    fn test_scram_rejects_server_error() {
        let (mut scram, _) = Scram::client_first("abc");
        let server_first = "r=abc-extension,s=QUJD,i=4096";
        let _ = scram.client_final(server_first.as_bytes(), "pw").unwrap();
        assert!(scram.verify_server(b"e=invalid-proof").is_err());
    }
}