use bytes::buf::Chain;
use bytes::Bytes;
use digest::Digest;
use sha1::Sha1;
use sha2::Sha256;
use crate::connection::stream::MySqlStream;
use crate::error::Error;
use crate::protocol::auth::AuthPlugin;
use crate::protocol::Packet;
impl AuthPlugin {
pub(super) async fn scramble(
self,
stream: &mut MySqlStream,
password: &str,
nonce: &Chain<Bytes, Bytes>,
) -> Result<Vec<u8>, Error> {
match self {
AuthPlugin::CachingSha2Password => Ok(scramble_sha256(password, nonce).to_vec()),
AuthPlugin::MySqlNativePassword => Ok(scramble_sha1(password, nonce).to_vec()),
AuthPlugin::Sha256Password => encrypt_rsa(stream, 0x01, password, nonce).await,
AuthPlugin::MySqlClearPassword => {
let mut pw_bytes = password.as_bytes().to_owned();
pw_bytes.push(0); Ok(pw_bytes)
}
}
}
pub(super) async fn handle(
self,
stream: &mut MySqlStream,
packet: Packet<Bytes>,
password: &str,
nonce: &Chain<Bytes, Bytes>,
) -> Result<bool, Error> {
match self {
AuthPlugin::CachingSha2Password if packet[0] == 0x01 => {
match packet[1] {
0x03 => Ok(false),
0x04 => {
let payload = encrypt_rsa(stream, 0x02, password, nonce).await?;
stream.write_packet(&*payload)?;
stream.flush().await?;
Ok(false)
}
v => {
Err(err_protocol!("unexpected result from fast authentication 0x{:x} when expecting 0x03 (fast_auth_success) or 0x04 (perform_full_authentication)", v))
}
}
}
_ => Err(err_protocol!(
"unexpected packet 0x{:02x} for auth plugin '{}' during authentication",
packet[0],
self.name()
)),
}
}
}
/// `mysql_native_password` scramble:
/// `SHA1(password) XOR SHA1(nonce || SHA1(SHA1(password)))`.
///
/// The two halves of `nonce` are hashed in order (first, then last).
fn scramble_sha1(password: &str, nonce: &Chain<Bytes, Bytes>) -> Vec<u8> {
    let stage1 = Sha1::digest(password);
    let stage2 = Sha1::digest(stage1);

    // For this plugin the nonce is hashed *before* the double hash.
    let xor_pad = Sha1::new()
        .chain_update(nonce.first_ref())
        .chain_update(nonce.last_ref())
        .chain_update(stage2)
        .finalize();

    let mut scrambled = stage1.to_vec();
    xor_eq(&mut scrambled, &xor_pad);
    scrambled
}
/// `caching_sha2_password` fast-auth scramble:
/// `SHA256(password) XOR SHA256(SHA256(SHA256(password)) || nonce)`.
///
/// Unlike the SHA-1 variant, the double hash is fed *before* the nonce halves.
fn scramble_sha256(password: &str, nonce: &Chain<Bytes, Bytes>) -> Vec<u8> {
    let stage1 = Sha256::digest(password);
    let stage2 = Sha256::digest(stage1);

    let xor_pad = Sha256::new()
        .chain_update(stage2)
        .chain_update(nonce.first_ref())
        .chain_update(nonce.last_ref())
        .finalize();

    let mut scrambled = stage1.to_vec();
    xor_eq(&mut scrambled, &xor_pad);
    scrambled
}
async fn encrypt_rsa<'s>(
stream: &'s mut MySqlStream,
public_key_request_id: u8,
password: &'s str,
nonce: &'s Chain<Bytes, Bytes>,
) -> Result<Vec<u8>, Error> {
if stream.is_tls {
return Ok(to_asciz(password));
}
stream.write_packet(&[public_key_request_id][..])?;
stream.flush().await?;
let packet = stream.recv_packet().await?;
let rsa_pub_key = &packet[1..];
let mut pass = to_asciz(password);
let (a, b) = (nonce.first_ref(), nonce.last_ref());
let mut nonce = Vec::with_capacity(a.len() + b.len());
nonce.extend_from_slice(a);
nonce.extend_from_slice(b);
xor_eq(&mut pass, &nonce);
rsa_backend::encrypt(rsa_pub_key, &pass)
}
/// XORs `x` in place with `y`, repeating `y` cyclically when it is shorter than `x`.
///
/// If `y` is empty, `x` is left unchanged. The previous modulo-based indexing
/// panicked with a division by zero in that case.
fn xor_eq(x: &mut [u8], y: &[u8]) {
    // `cycle()` over an empty iterator yields nothing, so `zip` stops immediately.
    for (dst, key) in x.iter_mut().zip(y.iter().cycle()) {
        *dst ^= key;
    }
}
/// Returns the UTF-8 bytes of `s` followed by a single NUL byte.
fn to_asciz(s: &str) -> Vec<u8> {
    let bytes = s.as_bytes();
    let mut terminated = Vec::with_capacity(bytes.len() + 1);
    terminated.extend_from_slice(bytes);
    terminated.push(0);
    terminated
}
#[cfg(feature = "rsa")]
mod rsa_backend {
    use rsa::{pkcs8::DecodePublicKey, Oaep, RsaPublicKey};

    use super::Error;

    /// Encrypts `pass` with the PEM-encoded public key `rsa_pub_key`, using OAEP
    /// padding with SHA-1.
    ///
    /// # Errors
    /// Returns a protocol error if the key cannot be parsed or encryption fails.
    pub(super) fn encrypt(rsa_pub_key: &[u8], pass: &[u8]) -> Result<Vec<u8>, Error> {
        let public_key = parse_rsa_pub_key(rsa_pub_key)?;
        public_key
            .encrypt(&mut rand::rng(), Oaep::<sha1::Sha1>::new(), pass)
            .map_err(Error::protocol)
    }

    /// Decodes `key` as UTF-8 PEM text and parses it with `from_public_key_pem`.
    fn parse_rsa_pub_key(key: &[u8]) -> Result<RsaPublicKey, Error> {
        std::str::from_utf8(key)
            .map_err(Error::protocol)
            .and_then(|pem| RsaPublicKey::from_public_key_pem(pem).map_err(Error::protocol))
    }
}
#[cfg(not(feature = "rsa"))]
mod rsa_backend {
    use super::Error;

    /// Stand-in used when RSA support is compiled out: always returns a configuration error.
    pub(super) fn encrypt(_rsa_pub_key: &[u8], _pass: &[u8]) -> Result<Vec<u8>, Error> {
        let message = "RSA auth backend disabled; enable feature `mysql-rsa` (or `rsa` if using sqlx-mysql directly) or use TLS.";
        Err(Error::Configuration(message.into()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Buf;
    use sha2::{Digest, Sha256};

    /// The server recovers SHA256(password) by XOR-ing the client's scramble with
    /// SHA256(SHA256(SHA256(password)) || nonce). Checks that this round-trip holds.
    #[test]
    fn scramble_sha256_is_invertible_by_server() {
        let secret = "my_pwd";
        let head = Bytes::from_static(b"0123456789");
        let tail = Bytes::from_static(&[0xAB; 10]);
        let seed = head.clone().chain(tail.clone());

        let expected_stage1 = Sha256::digest(secret.as_bytes());
        let stage2 = Sha256::digest(expected_stage1);
        let pad = Sha256::new()
            .chain_update(stage2)
            .chain_update(&head)
            .chain_update(&tail)
            .finalize();

        let mut recovered = scramble_sha256(secret, &seed);
        xor_eq(&mut recovered, &pad);

        assert_eq!(&recovered[..], &expected_stage1[..]);
    }
}