use sha1::Sha1;
use sha2::{Digest, Sha256};
use rand::rngs::OsRng;
use rsa::RsaPublicKey;
use rsa::pkcs1::DecodeRsaPublicKey;
use rsa::pkcs8::DecodePublicKey;
pub mod plugins {
    //! Plugin-name strings for the MySQL authentication methods known to this module.

    /// SHA-1 based challenge/response plugin (see `mysql_native_password`).
    pub const MYSQL_NATIVE_PASSWORD: &str = "mysql_native_password";

    /// SHA-256 based plugin with a fast path (see `caching_sha2_password`).
    pub const CACHING_SHA2_PASSWORD: &str = "caching_sha2_password";

    /// SHA-256 plugin whose password exchange is RSA-encrypted (see `sha256_password_rsa`).
    pub const SHA256_PASSWORD: &str = "sha256_password";

    /// Plugin that transmits the password in clear text.
    pub const MYSQL_CLEAR_PASSWORD: &str = "mysql_clear_password";
}
pub mod caching_sha2 {
    //! Single-byte protocol markers exchanged during `caching_sha2_password` authentication.

    /// Marker byte that requests the server's RSA public key.
    pub const REQUEST_PUBLIC_KEY: u8 = 0x02;

    /// Marker byte signalling that the fast (cached) authentication path succeeded.
    pub const FAST_AUTH_SUCCESS: u8 = 0x03;

    /// Marker byte signalling that full authentication must be performed.
    pub const PERFORM_FULL_AUTH: u8 = 0x04;
}
/// Builds the `mysql_native_password` authentication response.
///
/// The response is `SHA1(password) XOR SHA1(seed ‖ SHA1(SHA1(password)))`. `seed` is
/// `auth_data` cut down to at most its first 20 bytes, so any extra bytes, such as a trailing
/// NUL, are ignored.
///
/// An empty password produces an empty response. Otherwise the result is always 20 bytes.
pub fn mysql_native_password(password: &str, auth_data: &[u8]) -> Vec<u8> {
    if password.is_empty() {
        return Vec::new();
    }

    // Never feed more than 20 bytes of challenge into the hash.
    let seed = &auth_data[..auth_data.len().min(20)];

    let password_sha1: [u8; 20] = Sha1::digest(password.as_bytes()).into();
    let double_sha1: [u8; 20] = Sha1::digest(password_sha1).into();
    let mask: [u8; 20] = Sha1::new()
        .chain_update(seed)
        .chain_update(double_sha1)
        .finalize()
        .into();

    let mut response = password_sha1.to_vec();
    for (byte, m) in response.iter_mut().zip(mask.iter()) {
        *byte ^= *m;
    }
    response
}
/// Builds the `caching_sha2_password` fast-auth scramble.
///
/// The scramble is `SHA256(password) XOR SHA256(SHA256(SHA256(password)) ‖ seed)`.
/// When `auth_data` is exactly 21 bytes and ends in a NUL, only the first 20 bytes are used as
/// `seed`. Any other `auth_data` is used verbatim.
///
/// An empty password produces an empty response. Otherwise the result is always 32 bytes.
pub fn caching_sha2_password(password: &str, auth_data: &[u8]) -> Vec<u8> {
    if password.is_empty() {
        return Vec::new();
    }

    // Drop the NUL terminator that follows a 20-byte challenge.
    let seed = match auth_data {
        [challenge @ .., 0] if challenge.len() == 20 => challenge,
        _ => auth_data,
    };

    let digest1: [u8; 32] = Sha256::digest(password.as_bytes()).into();
    let digest2: [u8; 32] = Sha256::digest(digest1).into();
    let mask: [u8; 32] = Sha256::new()
        .chain_update(digest2)
        .chain_update(seed)
        .finalize()
        .into();

    let mut response = digest1.to_vec();
    for (byte, m) in response.iter_mut().zip(mask.iter()) {
        *byte ^= *m;
    }
    response
}
/// Returns `length` random bytes drawn from the operating-system RNG (`OsRng`).
///
/// Any byte value, including `0`, may appear in the output.
pub fn generate_nonce(length: usize) -> Vec<u8> {
    use rand::RngCore;

    let mut nonce = vec![0u8; length];
    rand::rngs::OsRng.fill_bytes(&mut nonce);
    nonce
}
/// Obfuscates `password` with `seed` and RSA-encrypts it with the server's public key.
///
/// The password plus a NUL terminator is XORed byte-wise with `seed`, with the seed repeating
/// as needed. The result is then encrypted with the key in `public_key_pem`, using OAEP (SHA-1)
/// when `use_oaep` is true and PKCS#1 v1.5 otherwise. The key may be either a
/// SubjectPublicKeyInfo PEM (`BEGIN PUBLIC KEY`) or a PKCS#1 PEM (`BEGIN RSA PUBLIC KEY`).
///
/// A 21-byte `seed` ending in NUL is reduced to its first 20 bytes, the same trimming
/// `caching_sha2_password` applies. Raw handshake auth data can therefore be passed directly.
///
/// # Errors
///
/// Returns `Err` when `seed` is empty, `public_key_pem` is not UTF-8, the key cannot be parsed
/// in either PEM form, or RSA encryption fails (e.g. the message is too long for the key).
pub fn sha256_password_rsa(
    password: &str,
    seed: &[u8],
    public_key_pem: &[u8],
    use_oaep: bool,
) -> Result<Vec<u8>, String> {
    // Only the 20-byte scramble may be used as the XOR key. Keeping a trailing NUL would XOR
    // byte 20 of the plaintext with 0 instead of seed[0], corrupting passwords of 20+ bytes.
    let seed = match seed {
        [scramble @ .., 0] if scramble.len() == 20 => scramble,
        _ => seed,
    };
    if seed.is_empty() {
        return Err("Seed is empty".to_string());
    }

    // The NUL terminator is part of the plaintext and is XORed along with the password.
    let obfuscated: Vec<u8> = password
        .as_bytes()
        .iter()
        .chain(std::iter::once(&0u8))
        .zip(seed.iter().cycle())
        .map(|(&byte, &key)| byte ^ key)
        .collect();

    let pem = std::str::from_utf8(public_key_pem)
        .map_err(|e| format!("Public key is not valid UTF-8 PEM: {e}"))?;
    // Try SubjectPublicKeyInfo first, then fall back to PKCS#1.
    let pub_key = RsaPublicKey::from_public_key_pem(pem)
        .or_else(|_| RsaPublicKey::from_pkcs1_pem(pem))
        .map_err(|e| format!("Failed to parse RSA public key PEM: {e}"))?;

    if use_oaep {
        pub_key
            .encrypt(&mut OsRng, rsa::Oaep::new::<Sha1>(), &obfuscated)
            .map_err(|e| format!("RSA OAEP encryption failed: {e}"))
    } else {
        pub_key
            .encrypt(&mut OsRng, rsa::Pkcs1v15Encrypt, &obfuscated)
            .map_err(|e| format!("RSA PKCS1v1.5 encryption failed: {e}"))
    }
}
/// XORs `password` with `seed`, cycling the seed, and appends a NUL terminator.
///
/// The terminator is appended *after* the XOR, so the last output byte is always a literal `0`.
/// The output is `password.len() + 1` bytes long.
///
/// An empty `seed` is treated as an all-zero key. The password bytes are then copied unchanged,
/// where `i % seed.len()` would otherwise panic with a remainder by zero.
///
/// NOTE(review): unlike `sha256_password_rsa`, the terminator is not XORed here. Confirm this is
/// what callers expect.
pub fn xor_password_with_seed(password: &str, seed: &[u8]) -> Vec<u8> {
    let password_bytes = password.as_bytes();
    let mut result = Vec::with_capacity(password_bytes.len() + 1);
    if seed.is_empty() {
        result.extend_from_slice(password_bytes);
    } else {
        result.extend(
            password_bytes
                .iter()
                .zip(seed.iter().cycle())
                .map(|(&byte, &key)| byte ^ key),
        );
    }
    result.push(0);
    result
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mysql_native_password_empty() {
        let result = mysql_native_password("", &[0; 20]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_mysql_native_password() {
        let seed = [0u8; 20];
        let result = mysql_native_password("secret", &seed);
        assert_eq!(result.len(), 20);
        let result2 = mysql_native_password("secret", &seed);
        assert_eq!(result, result2);
    }

    #[test]
    fn test_mysql_native_password_real_seed() {
        let seed = [
            0x3d, 0x4c, 0x5e, 0x2f, 0x1a, 0x0b, 0x7c, 0x8d, 0x9e, 0xaf, 0x10, 0x21, 0x32, 0x43,
            0x54, 0x65, 0x76, 0x87, 0x98, 0xa9,
        ];
        let result = mysql_native_password("mypassword", &seed);
        assert_eq!(result.len(), 20);
        let result2 = mysql_native_password("otherpassword", &seed);
        assert_ne!(result, result2);
    }

    /// Bytes past the 20-byte challenge (e.g. a trailing NUL) must not affect the response.
    #[test]
    fn test_mysql_native_password_ignores_bytes_past_20() {
        let seed: Vec<u8> = (1u8..=20).collect();
        let mut padded = seed.clone();
        padded.push(0);
        assert_eq!(
            mysql_native_password("secret", &padded),
            mysql_native_password("secret", &seed)
        );
    }

    #[test]
    fn test_caching_sha2_password_empty() {
        let result = caching_sha2_password("", &[0; 20]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_caching_sha2_password() {
        let seed = [0u8; 20];
        let result = caching_sha2_password("secret", &seed);
        assert_eq!(result.len(), 32);
        let result2 = caching_sha2_password("secret", &seed);
        assert_eq!(result, result2);
    }

    #[test]
    fn test_caching_sha2_password_with_nul() {
        let mut seed = vec![0u8; 20];
        seed.push(0);
        let result = caching_sha2_password("secret", &seed);
        assert_eq!(result.len(), 32);
        let result2 = caching_sha2_password("secret", &seed[..20]);
        assert_eq!(result, result2);
    }

    #[test]
    fn test_generate_nonce() {
        let nonce1 = generate_nonce(20);
        let nonce2 = generate_nonce(20);
        assert_eq!(nonce1.len(), 20);
        assert_eq!(nonce2.len(), 20);
        assert_ne!(nonce1, nonce2);
    }

    #[test]
    fn test_xor_password_with_seed() {
        let password = "test";
        let seed = [1, 2, 3, 4, 5, 6, 7, 8];
        let result = xor_password_with_seed(password, &seed);
        assert_eq!(result.len(), 5);
        assert_eq!(result[4], 0);
        let recovered: Vec<u8> = result[..4]
            .iter()
            .enumerate()
            .map(|(i, &b)| b ^ seed[i % seed.len()])
            .collect();
        assert_eq!(recovered, password.as_bytes());
    }

    #[test]
    fn test_plugin_names() {
        assert_eq!(plugins::MYSQL_NATIVE_PASSWORD, "mysql_native_password");
        assert_eq!(plugins::CACHING_SHA2_PASSWORD, "caching_sha2_password");
        assert_eq!(plugins::SHA256_PASSWORD, "sha256_password");
        assert_eq!(plugins::MYSQL_CLEAR_PASSWORD, "mysql_clear_password");
    }

    #[test]
    fn test_caching_sha2_markers() {
        assert_eq!(caching_sha2::REQUEST_PUBLIC_KEY, 0x02);
        assert_eq!(caching_sha2::FAST_AUTH_SUCCESS, 0x03);
        assert_eq!(caching_sha2::PERFORM_FULL_AUTH, 0x04);
    }
}