cartel-pg 0.1.1

Async PostgreSQL driver for the dope runtime
use base64::Engine;
use base64::engine::general_purpose::STANDARD;
use hmac::{Hmac, Mac};
use rand::RngCore;
use sha2::{Digest, Sha256};

use crate::Error;

/// SASL mechanism name for SCRAM-SHA-256 without channel binding; the only
/// mechanism `Scram::pick_mechanism` accepts.
const MECH: &str = "SCRAM-SHA-256";
/// Channel-binding variant of the mechanism. It is recognised only so that
/// `Scram::pick_mechanism` can report a specific "not supported" error.
const MECH_PLUS: &str = "SCRAM-SHA-256-PLUS";
/// GS2 header prepended to the client-first message (no channel binding flag,
/// empty authorization identity).
const GS2_HEADER: &str = "n,,";
/// Base64 encoding of `GS2_HEADER`, sent as the `c=` attribute of the
/// client-final message.
const GS2_HEADER_B64: &str = "biws";

/// HMAC instantiated with SHA-256, used for every keyed hash in this module.
type Hs = Hmac<Sha256>;

/// Client-side state of a single SCRAM-SHA-256 exchange.
///
/// The intended call order is `new` -> `client_first` -> `client_final` ->
/// `verify_server_final`; the last two fields are only meaningful once
/// `client_final` has succeeded.
pub(super) struct Scram {
    /// Raw UTF-8 bytes of the password, used as the PBKDF2 input.
    password: Vec<u8>,
    /// Base64 text of the random client nonce generated in `new`.
    client_nonce: String,
    /// The client-first message without the GS2 header (`n=,r=<nonce>`).
    client_first_bare: String,
    /// `client-first-bare,server-first,client-final-without-proof`; empty
    /// until `client_final` fills it in.
    auth_message: String,
    /// Server signature the client expects in the server-final message;
    /// all zeros until `client_final` computes it.
    server_signature: [u8; 32],
}

impl Scram {
    pub(super) fn new(password: &str) -> Self {
        let mut nonce_raw = [0u8; 18];
        rand::rngs::OsRng.fill_bytes(&mut nonce_raw);
        let client_nonce = STANDARD.encode(nonce_raw);
        let client_first_bare = format!("n=,r={client_nonce}");
        Self {
            password: password.as_bytes().to_vec(),
            client_nonce,
            client_first_bare,
            auth_message: String::new(),
            server_signature: [0u8; 32],
        }
    }

    pub(super) fn pick_mechanism(&self, offered: &[&str]) -> Result<&'static str, Error> {
        if offered.contains(&MECH) {
            Ok(MECH)
        } else if offered.contains(&MECH_PLUS) {
            Err(Error::Auth(
                "server only offers SCRAM-SHA-256-PLUS (channel binding) which is not supported"
                    .into(),
            ))
        } else {
            Err(Error::Auth(format!(
                "no compatible SASL mechanism offered: {:?}",
                offered
            )))
        }
    }

    pub(super) fn client_first(&self) -> String {
        format!("{GS2_HEADER}{}", self.client_first_bare)
    }

    pub(super) fn client_final(&mut self, server_first: &[u8]) -> Result<String, Error> {
        let server_first_str = std::str::from_utf8(server_first)
            .map_err(|_| Error::Auth("server-first not utf-8".into()))?;
        let mut nonce_b64 = "";
        let mut salt_b64 = "";
        let mut iterations = 0u32;
        for attr in server_first_str.split(',') {
            if let Some(v) = attr.strip_prefix("r=") {
                nonce_b64 = v;
            } else if let Some(v) = attr.strip_prefix("s=") {
                salt_b64 = v;
            } else if let Some(v) = attr.strip_prefix("i=") {
                iterations = v
                    .parse()
                    .map_err(|_| Error::Auth("server-first: bad iteration count".into()))?;
            }
        }
        if nonce_b64.is_empty() || salt_b64.is_empty() || iterations == 0 {
            return Err(Error::Auth("server-first missing fields".into()));
        }
        if !nonce_b64.starts_with(&self.client_nonce) {
            return Err(Error::Auth(
                "server nonce does not extend client nonce".into(),
            ));
        }
        let salt = STANDARD
            .decode(salt_b64.as_bytes())
            .map_err(|_| Error::Auth("server-first: bad base64 salt".into()))?;

        let salted = pbkdf2_sha256_32(&self.password, &salt, iterations);
        let client_key = hmac_sha256(&salted, b"Client Key");
        let stored_key = sha256(&client_key);
        let server_key = hmac_sha256(&salted, b"Server Key");

        let client_final_no_proof = format!("c={GS2_HEADER_B64},r={nonce_b64}");
        self.auth_message = format!(
            "{},{},{}",
            self.client_first_bare, server_first_str, client_final_no_proof
        );

        let client_signature = hmac_sha256(&stored_key, self.auth_message.as_bytes());
        let mut client_proof = client_key;
        for (a, b) in client_proof.iter_mut().zip(client_signature.iter()) {
            *a ^= *b;
        }
        let client_proof_b64 = STANDARD.encode(client_proof);
        let server_signature_v = hmac_sha256(&server_key, self.auth_message.as_bytes());
        self.server_signature = server_signature_v;

        Ok(format!("{client_final_no_proof},p={client_proof_b64}"))
    }

    pub(super) fn verify_server_final(&self, server_final: &[u8]) -> Result<(), Error> {
        let s = std::str::from_utf8(server_final)
            .map_err(|_| Error::Auth("server-final not utf-8".into()))?;
        let v_b64 = s
            .split(',')
            .find_map(|attr| attr.strip_prefix("v="))
            .ok_or(Error::Auth("server-final missing v=".into()))?;
        let sig = STANDARD
            .decode(v_b64.as_bytes())
            .map_err(|_| Error::Auth("server-final: bad base64 v=".into()))?;
        if sig.as_slice() != self.server_signature.as_slice() {
            return Err(Error::Auth("server signature mismatch".into()));
        }
        Ok(())
    }
}

/// Returns the SHA-256 digest of `input` as a fixed 32-byte array.
fn sha256(input: &[u8]) -> [u8; 32] {
    // One-shot form of new/update/finalize.
    Sha256::digest(input).into()
}

/// Computes HMAC-SHA-256 over `data` using `key` and returns the 32-byte tag.
///
/// # Panics
/// Panics if the HMAC constructor rejects `key`; the `expect` message records
/// the assumption that no key length is ever rejected.
fn hmac_sha256(key: &[u8], data: &[u8]) -> [u8; 32] {
    let keyed = <Hs as Mac>::new_from_slice(key);
    let mut mac = keyed.expect("hmac key length always valid");
    mac.update(data);
    let tag = mac.finalize();
    tag.into_bytes().into()
}

/// PBKDF2 with HMAC-SHA-256 producing a single 32-byte output block.
///
/// Computes `U1 = HMAC(password, salt || INT(1))`, then
/// `Ui = HMAC(password, U(i-1))` for the remaining rounds, and returns the XOR
/// of all `Ui`. An `iterations` value of 0 yields the same result as 1 because
/// the loop body never runs.
///
/// # Panics
/// Panics if the HMAC constructor rejects `password`; the `expect` message
/// records the assumption that no key length is ever rejected.
fn pbkdf2_sha256_32(password: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
    // Key the HMAC once and clone it for every round: re-deriving the key
    // schedule from `password` on each iteration is loop-invariant work.
    let prf = <Hs as Mac>::new_from_slice(password).expect("hmac key length always valid");

    // U1: the salt followed by the big-endian block index 1.
    let mut u = {
        let mut m = prf.clone();
        m.update(salt);
        m.update(&1u32.to_be_bytes());
        m.finalize().into_bytes()
    };
    // `t` accumulates U1 ^ U2 ^ ... ^ Un.
    let mut t = u;
    for _ in 1..iterations {
        u = {
            let mut m = prf.clone();
            m.update(&u);
            m.finalize().into_bytes()
        };
        for (a, b) in t.iter_mut().zip(u.iter()) {
            *a ^= *b;
        }
    }
    t.into()
}