fraiseql-wire 2.3.1

Streaming JSON query engine for Postgres 17
Documentation
//! SCRAM-SHA-256 authentication implementation
//!
//! Implements SCRAM-SHA-256 (RFC 7677), the SHA-256 variant of the Salted Challenge
//! Response Authentication Mechanism defined in RFC 5802, as used for PostgreSQL
//! password authentication (PostgreSQL 10+).

use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use hmac::{Hmac, Mac};
use pbkdf2::pbkdf2;
use rand::Rng;
use sha2::{Digest, Sha256};
use std::fmt;
use zeroize::Zeroizing;

/// HMAC-SHA-256: the PRF handed to PBKDF2 and the MAC used for every SCRAM key and
/// signature computation in this module.
type HmacSha256 = Hmac<Sha256>;

/// Maximum PBKDF2 iteration count accepted from the server (DoS protection).
///
/// A malicious server can supply a very large `i=` value in its SCRAM first message,
/// causing the client to spend seconds (or minutes) in PBKDF2 before the connection
/// is rejected. Counts strictly greater than this cap are refused before any key
/// derivation runs.
///
/// 1,000,000 is roughly 244x PostgreSQL's default iteration count of 4096. It is also
/// still above the ~600,000 rounds that hardened deployments configure, so legitimate
/// servers are not affected.
pub(crate) const MAX_SCRAM_ITERATIONS: u32 = 1_000_000;

/// Errors produced while running the client side of a SCRAM exchange.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum ScramError {
    /// The server's signature did not match the value computed locally
    InvalidServerProof(String),
    /// A server message was malformed or violated the protocol
    InvalidServerMessage(String),
    /// UTF-8 encoding/decoding failure (this module also reports HMAC key setup
    /// failures through this variant)
    Utf8Error(String),
    /// A base64-encoded field could not be decoded
    Base64Error(String),
}

impl fmt::Display for ScramError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant renders as "<label>: <detail>".
        let (label, detail) = match self {
            Self::InvalidServerProof(detail) => ("invalid server proof", detail),
            Self::InvalidServerMessage(detail) => ("invalid server message", detail),
            Self::Utf8Error(detail) => ("UTF-8 error", detail),
            Self::Base64Error(detail) => ("Base64 error", detail),
        };
        write!(f, "{label}: {detail}")
    }
}

/// `ScramError` wraps no underlying error, so the trait's default methods are used.
impl std::error::Error for ScramError {}

/// State carried from [`ScramClient::client_final`] to
/// [`ScramClient::verify_server_final`].
///
/// `Debug` is implemented by hand rather than derived. The server key is derived from
/// the password and, together with the auth message, is all that is needed to produce
/// a valid server signature. It must therefore never end up in logs or panic messages.
#[derive(Clone)]
pub struct ScramState {
    /// Combined authentication message (for verification)
    auth_message: Vec<u8>,
    /// ServerKey = HMAC(SaltedPassword, "Server Key"), used to recompute the expected
    /// server signature
    server_key: Vec<u8>,
}

impl fmt::Debug for ScramState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Report only the size of the auth message and never the key bytes.
        f.debug_struct("ScramState")
            .field("auth_message", &format_args!("<{} bytes>", self.auth_message.len()))
            .field("server_key", &format_args!("<redacted>"))
            .finish()
    }
}

/// SCRAM-SHA-256 client implementation
///
/// Holds the credentials and the random client nonce that [`ScramClient::new`]
/// generates once per client.
pub struct ScramClient {
    /// Login name; escaped per RFC 5802 §5.1 (`=` as `=3D`, `,` as `=2C`) when it is
    /// placed into SCRAM messages.
    username: String,
    /// Password is stored as `Zeroizing<String>` so the key material is
    /// overwritten with zeros when `ScramClient` is dropped (S38).
    password: Zeroizing<String>,
    /// Base64-encoded random client nonce, sent as the `r=` attribute.
    nonce: String,
}

impl ScramClient {
    /// Create a new SCRAM client for the given credentials.
    ///
    /// A fresh random client nonce (24 bytes, base64-encoded) is generated here. It is
    /// reused by every message this client builds, so use one `ScramClient` per
    /// authentication exchange.
    #[must_use]
    pub fn new(username: String, password: String) -> Self {
        // SECURITY: rand::rng() is backed by OS-level entropy for SCRAM nonces.
        let mut rng = rand::rng();
        let nonce_bytes: Vec<u8> = (0..24).map(|_| rng.random()).collect();
        let nonce = BASE64.encode(&nonce_bytes);

        Self {
            username,
            password: Zeroizing::new(password),
            nonce,
        }
    }

    /// Build `client-first-message-bare`: `n=<escaped username>,r=<client nonce>`.
    ///
    /// RFC 5802 §5.1 requires `=` to be sent as `=3D` and `,` as `=2C` inside the
    /// username. `=` is escaped first so the `=` introduced by `=2C` is not re-escaped.
    /// Both the client-first message and the AuthMessage use this single helper, so
    /// they can never disagree about the escaping.
    fn client_first_message_bare(&self) -> String {
        let escaped_username = self.username.replace('=', "=3D").replace(',', "=2C");
        format!("n={},r={}", escaped_username, self.nonce)
    }

    /// Generate client first message (no proof).
    ///
    /// Format: gs2-header `n,,` (no channel binding, empty authorization identity)
    /// followed by `client-first-message-bare`.
    #[must_use]
    pub fn client_first(&self) -> String {
        format!("n,,{}", self.client_first_message_bare())
    }

    /// Process server first message and generate client final message
    ///
    /// Returns (`client_final_message`, `internal_state`). Pass the state to
    /// [`Self::verify_server_final`] to authenticate the server.
    ///
    /// # Errors
    ///
    /// Returns [`ScramError::InvalidServerMessage`] in these cases:
    /// - the server message cannot be parsed;
    /// - the server nonce does not start with the client nonce;
    /// - the iteration count is not a positive integer, or it exceeds
    ///   `MAX_SCRAM_ITERATIONS`.
    ///
    /// Returns [`ScramError::Base64Error`] if the salt is not valid base64.
    pub fn client_final(&mut self, server_first: &str) -> Result<(String, ScramState), ScramError> {
        // Parse server first message: r=<client_nonce><server_nonce>,s=<salt>,i=<iterations>
        let (server_nonce, salt, iterations) = parse_server_first(server_first)?;

        // Verify server nonce starts with our client nonce
        if !server_nonce.starts_with(&self.nonce) {
            return Err(ScramError::InvalidServerMessage(
                "server nonce doesn't contain client nonce".to_string(),
            ));
        }

        let salt_bytes = BASE64
            .decode(&salt)
            .map_err(|_| ScramError::Base64Error("invalid salt encoding".to_string()))?;

        // RFC 5802 defines the iteration count as a positive number. 0 is therefore
        // malformed and is rejected here instead of being handed to PBKDF2.
        let iterations = iterations
            .parse::<u32>()
            .ok()
            .filter(|&count| count > 0)
            .ok_or_else(|| {
                ScramError::InvalidServerMessage("invalid iteration count".to_string())
            })?;

        // SECURITY: Guard against server-supplied iteration counts large enough to
        // cause a denial-of-service via excessive PBKDF2 CPU time.
        if iterations > MAX_SCRAM_ITERATIONS {
            return Err(ScramError::InvalidServerMessage(format!(
                "server iteration count {iterations} exceeds maximum of {MAX_SCRAM_ITERATIONS}"
            )));
        }

        // c= carries the base64-encoded gs2-header; "n,," means no channel binding.
        let channel_binding = BASE64.encode(b"n,,");
        let client_final_without_proof = format!("c={},r={}", channel_binding, server_nonce);

        // AuthMessage := client-first-message-bare "," server-first-message ","
        //                client-final-message-without-proof
        // SECURITY: the bare message must carry the same RFC 5802 §5.1-escaped username
        // that client_first() sent. Otherwise a username containing ',' or '=' could
        // inject SCRAM attributes and break authentication.
        let auth_message = format!(
            "{},{},{}",
            self.client_first_message_bare(),
            server_first,
            client_final_without_proof
        );

        let proof = calculate_client_proof(
            &self.password,
            &salt_bytes,
            iterations,
            auth_message.as_bytes(),
        )?;

        // Kept so the server's signature can be checked in verify_server_final().
        let server_key = calculate_server_key(&self.password, &salt_bytes, iterations)?;

        let client_final = format!("{},p={}", client_final_without_proof, BASE64.encode(&proof));

        let state = ScramState {
            auth_message: auth_message.into_bytes(),
            server_key,
        };

        Ok((client_final, state))
    }

    /// Verify server final message and confirm authentication
    ///
    /// Per RFC 5802 the message is `(server-error / verifier) ["," extensions]`. Only
    /// the first attribute is examined, and any trailing extensions are ignored.
    ///
    /// # Errors
    ///
    /// Returns `ScramError::InvalidServerMessage` if the server reports a SCRAM error
    /// (`e=...`) or the message has no `v=` verifier.
    /// Returns `ScramError::Base64Error` if the server signature is not valid base64.
    /// Returns `ScramError::InvalidServerProof` if the server signature does not match.
    pub fn verify_server_final(
        &self,
        server_final: &str,
        state: &ScramState,
    ) -> Result<(), ScramError> {
        let first_attr = server_final
            .split_once(',')
            .map_or(server_final, |(head, _extensions)| head);

        if let Some(server_error) = first_attr.strip_prefix("e=") {
            return Err(ScramError::InvalidServerMessage(format!(
                "server reported SCRAM error: {server_error}"
            )));
        }

        // Parse verifier: v=<server_signature>
        let server_sig_encoded = first_attr
            .strip_prefix("v=")
            .ok_or_else(|| ScramError::InvalidServerMessage("missing 'v=' prefix".to_string()))?;

        let server_signature = BASE64.decode(server_sig_encoded).map_err(|_| {
            ScramError::Base64Error("invalid server signature encoding".to_string())
        })?;

        let expected_signature =
            calculate_server_signature(&state.server_key, &state.auth_message)?;

        // Constant-time comparison so the check does not leak how many bytes matched.
        if constant_time_compare(&server_signature, &expected_signature) {
            Ok(())
        } else {
            Err(ScramError::InvalidServerProof(
                "server signature verification failed".to_string(),
            ))
        }
    }
}

/// Parse server first message format: r=<nonce>,s=<salt>,i=<iterations>
///
/// Returns `(nonce, salt, iterations)` as raw strings. The caller base64-decodes the
/// salt and parses the iteration count.
///
/// Attributes are matched by name, not position, and unrecognised attributes
/// (extensions) are skipped. If an attribute repeats, its last value wins.
///
/// # Errors
///
/// Returns [`ScramError::InvalidServerMessage`] in two cases:
/// - any of `r`, `s` or `i` is missing or empty;
/// - the message contains the reserved mandatory-extension attribute `m=`, which
///   RFC 5802 §5.1 says MUST cause authentication failure.
pub(crate) fn parse_server_first(msg: &str) -> Result<(String, String, String), ScramError> {
    let mut nonce = "";
    let mut salt = "";
    let mut iterations = "";

    for part in msg.split(',') {
        if part.starts_with("m=") {
            return Err(ScramError::InvalidServerMessage(
                "unsupported mandatory extension in server first message".to_string(),
            ));
        } else if let Some(value) = part.strip_prefix("r=") {
            nonce = value;
        } else if let Some(value) = part.strip_prefix("s=") {
            salt = value;
        } else if let Some(value) = part.strip_prefix("i=") {
            iterations = value;
        }
    }

    if nonce.is_empty() || salt.is_empty() || iterations.is_empty() {
        return Err(ScramError::InvalidServerMessage(
            "missing required fields in server first message".to_string(),
        ));
    }

    Ok((nonce.to_string(), salt.to_string(), iterations.to_string()))
}

/// Calculate SCRAM client proof
fn calculate_client_proof(
    password: &str,
    salt: &[u8],
    iterations: u32,
    auth_message: &[u8],
) -> Result<Vec<u8>, ScramError> {
    // SaltedPassword := PBKDF2(password, salt, iterations, HMAC-SHA256)
    let password_bytes = password.as_bytes();
    let mut salted_password = vec![0u8; 32]; // SHA256 produces 32 bytes
    let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);

    // ClientKey := HMAC(SaltedPassword, "Client Key")
    let mut client_key_hmac = HmacSha256::new_from_slice(&salted_password)
        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
    client_key_hmac.update(b"Client Key");
    let client_key = client_key_hmac.finalize().into_bytes();

    // StoredKey := SHA256(ClientKey)
    let stored_key = Sha256::digest(client_key.to_vec().as_slice());

    // ClientSignature := HMAC(StoredKey, AuthMessage)
    let mut client_sig_hmac = HmacSha256::new_from_slice(&stored_key)
        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
    client_sig_hmac.update(auth_message);
    let client_signature = client_sig_hmac.finalize().into_bytes();

    // ClientProof := ClientKey XOR ClientSignature
    let mut proof = client_key.to_vec();
    for (proof_byte, sig_byte) in proof.iter_mut().zip(client_signature.iter()) {
        *proof_byte ^= sig_byte;
    }

    Ok(proof.clone())
}

/// Calculate server key for server signature verification
fn calculate_server_key(
    password: &str,
    salt: &[u8],
    iterations: u32,
) -> Result<Vec<u8>, ScramError> {
    // SaltedPassword := PBKDF2(password, salt, iterations, HMAC-SHA256)
    let password_bytes = password.as_bytes();
    let mut salted_password = vec![0u8; 32];
    let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);

    // ServerKey := HMAC(SaltedPassword, "Server Key")
    let mut server_key_hmac = HmacSha256::new_from_slice(&salted_password)
        .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
    server_key_hmac.update(b"Server Key");

    Ok(server_key_hmac.finalize().into_bytes().to_vec())
}

/// Compute `ServerSignature := HMAC(ServerKey, AuthMessage)` (RFC 5802 §3), the value
/// the server must echo back in its `v=` attribute.
///
/// # Errors
///
/// Returns [`ScramError::Utf8Error`] if the HMAC implementation rejects `server_key`.
fn calculate_server_signature(
    server_key: &[u8],
    auth_message: &[u8],
) -> Result<Vec<u8>, ScramError> {
    let keyed_mac = HmacSha256::new_from_slice(server_key).map_err(|_| {
        ScramError::Utf8Error("invalid HMAC key for server signature".to_string())
    })?;
    let signature = keyed_mac.chain_update(auth_message).finalize().into_bytes();
    Ok(signature.to_vec())
}

/// Compare two byte strings without a data-dependent early exit, so that response
/// timing does not reveal how many leading bytes matched.
///
/// Delegates to `subtle::ConstantTimeEq`. Its slice implementation is documented to
/// report slices of different lengths as unequal, which reveals only the length.
pub(crate) fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
    use subtle::ConstantTimeEq;
    let equal = a.ct_eq(b);
    bool::from(equal)
}

// Unit tests for this module live in a separate `tests` module file and are only
// compiled for test builds.
#[cfg(test)]
mod tests;