qail-pg 0.27.4

Fastest async PostgreSQL driver - AST to wire protocol, optional io_uring on Linux
Documentation
//! SCRAM-SHA-256 Authentication
//!
//! Implements the SASL SCRAM-SHA-256 authentication mechanism for PostgreSQL.
//! Reference: RFC 5802, PostgreSQL SASL documentation.

use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
use hmac::{Hmac, Mac};
use rand::RngExt;
use sha2::{Digest, Sha256};

/// HMAC instantiated with SHA-256; used for every keyed-hash step in this module.
type HmacSha256 = Hmac<Sha256>;
/// GS2 header used when the client is created without channel binding data.
const GS2_HEADER_NO_CHANNEL_BINDING: &str = "n,,";
/// GS2 header used when the client is created with `tls-server-end-point`
/// channel binding data.
const GS2_HEADER_TLS_SERVER_END_POINT: &str = "p=tls-server-end-point,,";

/// SCRAM-SHA-256 client state machine.
///
/// Intended call order: [`ScramClient::client_first_message`], then
/// [`ScramClient::process_server_first`], then
/// [`ScramClient::verify_server_final`]. The `Option` fields start as `None`
/// and are filled in by `process_server_first`.
pub struct ScramClient {
    /// User name placed in the `n=` attribute of the client-first-message.
    username: String,
    /// Password fed to PBKDF2 when deriving the salted password.
    password: String,
    /// Client nonce (random)
    client_nonce: String,
    /// Combined nonce (client + server)
    combined_nonce: Option<String>,
    /// Salt from server
    salt: Option<Vec<u8>>,
    /// Iteration count from server
    iterations: Option<u32>,
    /// Auth message for signature verification
    auth_message: Option<String>,
    /// Salted password (cached for verification)
    salted_password: Option<Vec<u8>>,
    /// Channel binding bytes for SCRAM-SHA-256-PLUS (`tls-server-end-point`).
    channel_binding_data: Option<Vec<u8>>,
    /// GS2 header prefix (`n,,` or `p=tls-server-end-point,,`).
    gs2_header: &'static str,
}

impl ScramClient {
    /// Create a new SCRAM client for authentication.
    ///
    /// No channel binding data is supplied, so the client uses the `n,,`
    /// GS2 header. A fresh random client nonce is generated on every call.
    pub fn new(username: &str, password: &str) -> Self {
        Self::new_inner(username, password, None)
    }

    /// Create a SCRAM client using `tls-server-end-point` channel binding.
    ///
    /// `channel_binding_data` is appended verbatim after the GS2 header when
    /// the `c=` attribute of the client-final-message is built.
    /// NOTE(review): presumably this is the hash of the server's TLS
    /// certificate as computed by the caller — confirm against the call site.
    pub fn new_with_tls_server_end_point(
        username: &str,
        password: &str,
        channel_binding_data: Vec<u8>,
    ) -> Self {
        Self::new_inner(username, password, Some(channel_binding_data))
    }

    /// Shared constructor behind the public `new*` functions.
    ///
    /// Generates a 24-character client nonce drawn from `[A-Za-z0-9]` using
    /// `rand::rng()`, and selects the GS2 header: the `tls-server-end-point`
    /// variant when `channel_binding_data` is `Some`, `n,,` otherwise. All
    /// server-derived state starts out as `None`.
    fn new_inner(username: &str, password: &str, channel_binding_data: Option<Vec<u8>>) -> Self {
        const NONCE_LEN: usize = 24;
        const NONCE_ALPHABET: &[u8] =
            b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";

        let mut rng = rand::rng();
        let mut client_nonce = String::with_capacity(NONCE_LEN);
        for _ in 0..NONCE_LEN {
            let pick = rng.random_range(0..NONCE_ALPHABET.len());
            client_nonce.push(char::from(NONCE_ALPHABET[pick]));
        }

        // Decide the header before the binding data is moved into the struct.
        let gs2_header = match channel_binding_data {
            Some(_) => GS2_HEADER_TLS_SERVER_END_POINT,
            None => GS2_HEADER_NO_CHANNEL_BINDING,
        };

        Self {
            username: username.to_owned(),
            password: password.to_owned(),
            client_nonce,
            combined_nonce: None,
            salt: None,
            iterations: None,
            auth_message: None,
            salted_password: None,
            channel_binding_data,
            gs2_header,
        }
    }

    /// Build the client-first-message: GS2 header followed by the
    /// client-first-message-bare.
    ///
    /// Format: `n,,n=<user>,r=<nonce>` (the header is
    /// `p=tls-server-end-point,,` instead of `n,,` when channel binding is in
    /// use). Returns the message as raw bytes.
    pub fn client_first_message(&self) -> Vec<u8> {
        // client-first-message = gs2-header + client-first-message-bare.
        // Reusing the bare message guarantees that the bytes sent here and the
        // bytes hashed into the auth message can never drift apart.
        format!("{}{}", self.gs2_header, self.client_first_message_bare()).into_bytes()
    }

    /// Get the client-first-message-bare (for auth message construction).
    ///
    /// The username is encoded as an RFC 5802 `saslname`: `=` becomes `=3D`
    /// and `,` becomes `=2C`. Without this, a username containing either
    /// character would produce a message the peer cannot parse, because `,`
    /// separates attributes and `=` introduces escapes.
    fn client_first_message_bare(&self) -> String {
        // '=' must be escaped first; otherwise the '=' introduced by the
        // "=2C" replacement would itself be escaped a second time.
        let saslname = self.username.replace('=', "=3D").replace(',', "=2C");
        format!("n={},r={}", saslname, self.client_nonce)
    }

    /// Bytes that get base64-encoded into the `c=` attribute of the
    /// client-final-message: the GS2 header followed by the channel binding
    /// data, or the header alone when no binding data is set.
    fn channel_binding_input(&self) -> Vec<u8> {
        let header = self.gs2_header.as_bytes();
        let binding: &[u8] = self.channel_binding_data.as_deref().unwrap_or(&[]);
        [header, binding].concat()
    }

    /// Consume the server-first-message and produce the client-final-message.
    ///
    /// Expected input: `r=<nonce>,s=<salt>,i=<iterations>`. Attributes are
    /// recognised by prefix in any order; unrecognised attributes are ignored.
    /// On success the combined nonce, salt, iteration count, salted password
    /// and auth message are stored on `self` for `verify_server_final`, and
    /// the returned bytes are `c=<binding>,r=<nonce>,p=<proof>`.
    ///
    /// # Errors
    ///
    /// Returns a descriptive `String` when the message is not UTF-8, the salt
    /// is not valid base64, the iteration count does not parse as `u32`, any
    /// of the three attributes is missing, the iteration count lies outside
    /// `4096..=100_000`, the nonce does not start with the client nonce, or
    /// HMAC initialisation fails.
    pub fn process_server_first(&mut self, server_msg: &[u8]) -> Result<Vec<u8>, String> {
        let server_str =
            std::str::from_utf8(server_msg).map_err(|_| "Invalid UTF-8 in server message")?;

        let mut nonce: Option<String> = None;
        let mut salt: Option<Vec<u8>> = None;
        let mut iterations: Option<u32> = None;

        for attribute in server_str.split(',') {
            match attribute.split_once('=') {
                Some(("r", value)) => nonce = Some(value.to_string()),
                Some(("s", value)) => {
                    salt = Some(BASE64.decode(value).map_err(|_| "Invalid salt base64")?);
                }
                Some(("i", value)) => {
                    let count = value
                        .parse::<u32>()
                        .map_err(|_| "Invalid iteration count")?;
                    iterations = Some(count);
                }
                _ => {}
            }
        }

        let nonce = nonce.ok_or("Missing nonce in server message")?;
        let salt = salt.ok_or("Missing salt in server message")?;
        let iterations = iterations.ok_or("Missing iterations in server message")?;

        // R9-C: keep the PBKDF2 work factor within sane bounds.
        // Below 4096 the key stretching is too weak; above 100_000 a rogue
        // server could burn client CPU (e.g. i=999999999).
        if iterations < 4096 {
            return Err(format!(
                "SCRAM iteration count too low: {} (minimum 4096)",
                iterations,
            ));
        }
        if iterations > 100_000 {
            return Err(format!(
                "SCRAM iteration count too high: {} (maximum 100000)",
                iterations,
            ));
        }

        // The combined nonce must begin with the nonce we sent.
        if !nonce.starts_with(self.client_nonce.as_str()) {
            return Err("Server nonce doesn't contain client nonce".to_string());
        }

        self.combined_nonce = Some(nonce.clone());
        self.salt = Some(salt.clone());
        self.iterations = Some(iterations);

        // PBKDF2 over the password; cached for the server-signature check.
        let salted_password = self.derive_salted_password(&salt, iterations);
        self.salted_password = Some(salted_password.clone());

        // ClientKey = HMAC(SaltedPassword, "Client Key"); StoredKey = H(ClientKey).
        let client_key = self.hmac(&salted_password, b"Client Key")?;
        let stored_key = Self::sha256(&client_key);

        // AuthMessage = client-first-bare "," server-first "," client-final-without-proof.
        let client_first_bare = self.client_first_message_bare();
        let client_final_without_proof = format!(
            "c={},r={}",
            BASE64.encode(self.channel_binding_input()),
            nonce
        );
        let auth_message = [
            client_first_bare.as_str(),
            server_str,
            client_final_without_proof.as_str(),
        ]
        .join(",");
        self.auth_message = Some(auth_message.clone());

        // ClientProof = ClientKey XOR HMAC(StoredKey, AuthMessage).
        let client_signature = self.hmac(&stored_key, auth_message.as_bytes())?;
        let mut client_proof = client_key;
        for (proof_byte, signature_byte) in client_proof.iter_mut().zip(&client_signature) {
            *proof_byte ^= *signature_byte;
        }

        let mut client_final = client_final_without_proof;
        client_final.push_str(",p=");
        client_final.push_str(&BASE64.encode(&client_proof));

        Ok(client_final.into_bytes())
    }

    /// Verify the server-final-message (server signature).
    pub fn verify_server_final(&self, server_msg: &[u8]) -> Result<(), String> {
        let server_str =
            std::str::from_utf8(server_msg).map_err(|_| "Invalid UTF-8 in server final message")?;

        let verifier = server_str
            .strip_prefix("v=")
            .ok_or("Missing verifier in server final message")?;

        let expected_signature = BASE64
            .decode(verifier)
            .map_err(|_| "Invalid base64 in server signature")?;

        // Compute expected server signature
        let salted_password = self
            .salted_password
            .as_ref()
            .ok_or("Missing salted password")?;
        let auth_message = self.auth_message.as_ref().ok_or("Missing auth message")?;

        let server_key = self.hmac(salted_password, b"Server Key")?;
        let computed_signature = self.hmac(&server_key, auth_message.as_bytes())?;

        if computed_signature != expected_signature {
            return Err("Server signature verification failed".to_string());
        }

        Ok(())
    }

    /// Compute the 32-byte salted password with PBKDF2-HMAC-SHA256 over the
    /// stored password, using the given `salt` and `iterations`.
    fn derive_salted_password(&self, salt: &[u8], iterations: u32) -> Vec<u8> {
        let mut derived = vec![0u8; 32];
        pbkdf2::pbkdf2_hmac::<Sha256>(
            self.password.as_bytes(),
            salt,
            iterations,
            derived.as_mut_slice(),
        );
        derived
    }

    /// Compute HMAC-SHA-256 of `data` under `key` and return the tag bytes.
    ///
    /// # Errors
    ///
    /// Returns an error string if the MAC cannot be initialised from `key`.
    fn hmac(&self, key: &[u8], data: &[u8]) -> Result<Vec<u8>, String> {
        HmacSha256::new_from_slice(key)
            .map(|mut keyed| {
                keyed.update(data);
                keyed.finalize().into_bytes().to_vec()
            })
            .map_err(|_| "HMAC init failed for SCRAM-SHA-256".to_string())
    }

    /// Return the SHA-256 digest of `data`.
    fn sha256(data: &[u8]) -> Vec<u8> {
        Sha256::digest(data).to_vec()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a server-first-message for `client` the way a server would:
    /// `r=<client nonce + "ServerPart">,s=<base64("randomsalt")>,i=4096`.
    fn fake_server_first(client: &ScramClient) -> Vec<u8> {
        let combined_nonce = format!("{}ServerPart", client.client_nonce);
        let salt_b64 = BASE64.encode(b"randomsalt");
        format!("r={},s={},i=4096", combined_nonce, salt_b64).into_bytes()
    }

    /// Without channel binding the first message starts with the `n,,` header.
    #[test]
    fn test_client_first_message() {
        let first = ScramClient::new("user", "password").client_first_message();
        let text = String::from_utf8(first).unwrap();

        assert!(text.starts_with("n,,n=user,r="));
    }

    /// Full client side of the exchange up to the client-final-message.
    #[test]
    fn test_scram_flow() {
        let mut client = ScramClient::new("testuser", "testpass");

        // Step 1: client-first-message carries the user name.
        let first = String::from_utf8(client.client_first_message()).unwrap();
        assert!(first.contains("n=testuser"));

        // Step 2: simulated server-first-message, then the client's reply.
        let server_first = fake_server_first(&client);
        let reply = client.process_server_first(&server_first).unwrap();
        let reply_text = String::from_utf8(reply).unwrap();

        // "biws" is base64("n,,").
        assert!(reply_text.starts_with("c=biws,r="));
        assert!(reply_text.contains(",p="));
    }

    /// With channel binding the first message uses the `p=...` header.
    #[test]
    fn test_client_first_message_plus() {
        let client =
            ScramClient::new_with_tls_server_end_point("user", "password", vec![1, 2, 3, 4]);
        let text = String::from_utf8(client.client_first_message()).unwrap();

        assert!(text.starts_with("p=tls-server-end-point,,n=user,r="));
    }

    /// The `c=` attribute must decode to GS2 header + channel binding bytes.
    #[test]
    fn test_scram_plus_final_channel_binding_payload() {
        let cb_data = vec![0xde, 0xad, 0xbe, 0xef];
        let mut client =
            ScramClient::new_with_tls_server_end_point("testuser", "testpass", cb_data.clone());

        let server_first = fake_server_first(&client);
        let reply = client.process_server_first(&server_first).unwrap();
        let reply_text = String::from_utf8(reply).unwrap();

        let encoded_cb = reply_text
            .split(',')
            .find_map(|part| part.strip_prefix("c="))
            .unwrap();
        let decoded = BASE64.decode(encoded_cb).unwrap();

        let expected = [
            GS2_HEADER_TLS_SERVER_END_POINT.as_bytes(),
            cb_data.as_slice(),
        ]
        .concat();
        assert_eq!(decoded, expected);
    }
}