// sqlx-mysql 0.9.0
//
// MySQL driver implementation for SQLx. Not intended for direct use; see the `sqlx` crate
// documentation for details.
use bytes::buf::Chain;
use bytes::Bytes;
use digest::Digest;
use sha1::Sha1;
use sha2::Sha256;

use crate::connection::stream::MySqlStream;
use crate::error::Error;
use crate::protocol::auth::AuthPlugin;
use crate::protocol::Packet;

impl AuthPlugin {
    pub(super) async fn scramble(
        self,
        stream: &mut MySqlStream,
        password: &str,
        nonce: &Chain<Bytes, Bytes>,
    ) -> Result<Vec<u8>, Error> {
        match self {
            // https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/
            AuthPlugin::CachingSha2Password => Ok(scramble_sha256(password, nonce).to_vec()),

            AuthPlugin::MySqlNativePassword => Ok(scramble_sha1(password, nonce).to_vec()),

            // https://mariadb.com/kb/en/sha256_password-plugin/
            AuthPlugin::Sha256Password => encrypt_rsa(stream, 0x01, password, nonce).await,

            AuthPlugin::MySqlClearPassword => {
                let mut pw_bytes = password.as_bytes().to_owned();
                pw_bytes.push(0); // null terminate
                Ok(pw_bytes)
            }
        }
    }

    pub(super) async fn handle(
        self,
        stream: &mut MySqlStream,
        packet: Packet<Bytes>,
        password: &str,
        nonce: &Chain<Bytes, Bytes>,
    ) -> Result<bool, Error> {
        match self {
            AuthPlugin::CachingSha2Password if packet[0] == 0x01 => {
                match packet[1] {
                    // fast_auth_success — the server still sends a trailing
                    // OK_Packet, so yield back to the handshake loop and let
                    // it consume the OK on the next iteration.
                    0x03 => Ok(false),

                    // perform_full_authentication
                    0x04 => {
                        let payload = encrypt_rsa(stream, 0x02, password, nonce).await?;

                        stream.write_packet(&*payload)?;
                        stream.flush().await?;

                        Ok(false)
                    }

                    v => {
                        Err(err_protocol!("unexpected result from fast authentication 0x{:x} when expecting 0x03 (fast_auth_success) or 0x04 (perform_full_authentication)", v))
                    }
                }
            }

            _ => Err(err_protocol!(
                "unexpected packet 0x{:02x} for auth plugin '{}' during authentication",
                packet[0],
                self.name()
            )),
        }
    }
}

/// Computes the `mysql_native_password` auth response:
/// `SHA1(password) XOR SHA1(nonce || SHA1(SHA1(password)))`.
///
/// Both nonce halves are hashed (first, then last) before the double hash.
/// See <https://mariadb.com/kb/en/connection/#mysql_native_password-plugin>.
fn scramble_sha1(password: &str, nonce: &Chain<Bytes, Bytes>) -> Vec<u8> {
    let stage1 = Sha1::digest(password);
    let stage2 = Sha1::digest(&stage1);

    let pad = Sha1::new()
        .chain_update(nonce.first_ref())
        .chain_update(nonce.last_ref())
        .chain_update(&stage2)
        .finalize();

    let mut scramble = stage1.to_vec();
    xor_eq(&mut scramble, &pad);
    scramble
}

/// Computes the `caching_sha2_password` auth response:
/// `SHA256(password) XOR SHA256(SHA256(SHA256(password)) || nonce)`.
///
/// The double hash is fed into the digest *before* the nonce. This order matches the
/// server-side verification in MySQL's sha2_password (`generate_sha2_scramble`):
/// stage2 digest first, then the nonce.
fn scramble_sha256(password: &str, nonce: &Chain<Bytes, Bytes>) -> Vec<u8> {
    let stage1 = Sha256::digest(password);
    let stage2 = Sha256::digest(&stage1);

    let pad = Sha256::new()
        .chain_update(&stage2)
        .chain_update(nonce.first_ref())
        .chain_update(nonce.last_ref())
        .finalize();

    let mut scramble = stage1.to_vec();
    xor_eq(&mut scramble, &pad);
    scramble
}

async fn encrypt_rsa<'s>(
    stream: &'s mut MySqlStream,
    public_key_request_id: u8,
    password: &'s str,
    nonce: &'s Chain<Bytes, Bytes>,
) -> Result<Vec<u8>, Error> {
    // https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/

    if stream.is_tls {
        // If in a TLS stream, send the password directly in clear text
        return Ok(to_asciz(password));
    }

    // client sends a public key request
    stream.write_packet(&[public_key_request_id][..])?;
    stream.flush().await?;

    // server sends a public key response
    let packet = stream.recv_packet().await?;
    let rsa_pub_key = &packet[1..];

    // xor the password with the given nonce
    let mut pass = to_asciz(password);

    let (a, b) = (nonce.first_ref(), nonce.last_ref());
    let mut nonce = Vec::with_capacity(a.len() + b.len());
    nonce.extend_from_slice(a);
    nonce.extend_from_slice(b);

    xor_eq(&mut pass, &nonce);

    // client sends an RSA encrypted password
    rsa_backend::encrypt(rsa_pub_key, &pass)
}

/// XORs `y` into `x` in place: `x[i] ^= y[i % y.len()]`.
///
/// When `y` is shorter than `x`, it wraps around. If `y` is empty, `x` is left unchanged
/// rather than panicking on a modulo-by-zero.
fn xor_eq(x: &mut [u8], y: &[u8]) {
    for (dst, src) in x.iter_mut().zip(y.iter().cycle()) {
        *dst ^= *src;
    }
}

/// Returns the UTF-8 bytes of `s` followed by a single NUL (`0x00`) terminator.
fn to_asciz(s: &str) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(s.len() + 1);
    bytes.extend_from_slice(s.as_bytes());
    bytes.push(0);
    bytes
}

#[cfg(feature = "rsa")]
mod rsa_backend {
    use rsa::{pkcs8::DecodePublicKey, Oaep, RsaPublicKey};

    use super::Error;

    /// RSA-encrypts `pass` with the server's PEM public key using OAEP padding over SHA-1.
    ///
    /// Key-parsing and encryption failures are reported as protocol errors.
    pub(super) fn encrypt(rsa_pub_key: &[u8], pass: &[u8]) -> Result<Vec<u8>, Error> {
        let public_key = parse_rsa_pub_key(rsa_pub_key)?;
        let oaep = Oaep::<sha1::Sha1>::new();
        let mut rng = rand::rng();

        public_key
            .encrypt(&mut rng, oaep, pass)
            .map_err(Error::protocol)
    }

    /// Decodes `key` as UTF-8 and parses it with `DecodePublicKey::from_public_key_pem`.
    ///
    /// This relies on MySQL always sending its RSA public key as a PEM-encoded public key
    /// document.
    fn parse_rsa_pub_key(key: &[u8]) -> Result<RsaPublicKey, Error> {
        std::str::from_utf8(key)
            .map_err(Error::protocol)
            .and_then(|pem| RsaPublicKey::from_public_key_pem(pem).map_err(Error::protocol))
    }
}

#[cfg(not(feature = "rsa"))]
mod rsa_backend {
    use super::Error;

    /// Stand-in used when the `rsa` feature is disabled.
    ///
    /// Always fails with a configuration error that explains how to enable RSA password
    /// exchange, or suggests using TLS instead.
    pub(super) fn encrypt(_rsa_pub_key: &[u8], _pass: &[u8]) -> Result<Vec<u8>, Error> {
        const DISABLED: &str = "RSA auth backend disabled; enable feature `mysql-rsa` (or `rsa` if using sqlx-mysql directly) or use TLS.";

        Err(Error::Configuration(DISABLED.into()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Buf;
    use sha2::{Digest, Sha256};

    /// Regression test for <https://github.com/launchbadge/sqlx/issues/4244>.
    ///
    /// For `caching_sha2_password` fast auth, the server must be able to recover
    /// `stage1 = SHA256(password)` as `scramble XOR SHA256(stage2 || nonce)`, where
    /// `stage2 = SHA256(stage1)`. This checks the client scramble survives that inversion.
    #[test]
    fn scramble_sha256_is_invertible_by_server() {
        let password = "my_pwd";
        let head = Bytes::from_static(b"0123456789");
        let tail = Bytes::from_static(&[0xAB; 10]);
        let nonce = head.clone().chain(tail.clone());

        let stage1 = Sha256::digest(password.as_bytes());
        let stage2 = Sha256::digest(stage1);
        let pad = Sha256::new()
            .chain_update(stage2)
            .chain_update(&head)
            .chain_update(&tail)
            .finalize();

        let mut recovered = scramble_sha256(password, &nonce);
        xor_eq(&mut recovered, &pad);

        assert_eq!(&recovered[..], &stage1[..]);
    }
}