use crate::nightly::cold_path;
use zerocopy::byteorder::little_endian::{U16 as U16LE, U32 as U32LE};
use zerocopy::{FromBytes, Immutable, KnownLayout};
use crate::buffer::BufferSet;
use crate::constant::{
CAPABILITIES_ALWAYS_ENABLED, CAPABILITIES_CONFIGURABLE, CapabilityFlags,
MARIADB_CAPABILITIES_ENABLED, MAX_ALLOWED_PACKET, MariadbCapabilityFlags, UTF8MB4_GENERAL_CI,
};
use crate::error::{Error, Result, eyre};
use crate::opts::Opts;
use crate::protocol::primitive::*;
use crate::protocol::response::ErrPayloadBytes;
/// Fixed-size portion of the server greeting that follows the NUL-terminated server
/// version string.
///
/// It is read zero-copy with `ref_from_prefix`, so the field order and sizes must match
/// the wire layout exactly: 31 bytes in total, with every integer little-endian
/// (`U16LE`/`U32LE`). `repr(C, packed)` keeps the alignment at 1, so it can be read from
/// any byte offset.
#[derive(Debug, Clone, Copy, FromBytes, KnownLayout, Immutable)]
#[repr(C, packed)]
struct HandshakeFixedFields {
    /// Connection id assigned by the server.
    connection_id: U32LE,
    /// First 8 bytes of the authentication scramble.
    auth_data_part1: [u8; 8],
    /// Filler byte (ignored).
    _filler1: u8,
    /// Lower 16 bits of the server capability flags.
    capability_flags_lower: U16LE,
    /// Server default character set / collation id.
    charset: u8,
    /// Server status flags.
    status_flags: U16LE,
    /// Upper 16 bits of the server capability flags.
    capability_flags_upper: U16LE,
    /// Scramble length announced by the server; used to size the second scramble part.
    auth_data_len: u8,
    /// Reserved bytes (ignored).
    _filler2: [u8; 6],
    /// MariaDB extended capability flags, carried in the last 4 reserved bytes.
    mariadb_capabilities: U32LE,
}
/// Decoded server greeting (the first packet the server sends on a new connection).
///
/// String fields are stored as byte ranges into the greeting payload that was parsed, so
/// the caller slices them out of the buffer that payload came from.
#[derive(Debug, Clone)]
pub struct InitialHandshake {
    /// First byte of the greeting payload.
    pub protocol_version: u8,
    /// Byte range of the server version string within the payload (NUL terminator excluded).
    pub server_version: std::ops::Range<usize>,
    /// Connection id assigned by the server.
    pub connection_id: u32,
    /// Full auth scramble: the 8-byte first part followed by the second part.
    pub auth_plugin_data: Vec<u8>,
    /// Capability flags advertised by the server (lower and upper halves combined).
    pub capability_flags: CapabilityFlags,
    /// MariaDB extended capability flags advertised by the server.
    pub mariadb_capabilities: MariadbCapabilityFlags,
    /// Server default character set / collation id.
    pub charset: u8,
    /// Server status flags; unknown bits are dropped.
    pub status_flags: crate::constant::ServerStatusFlags,
    /// Byte range of the default auth plugin name within the payload (NUL terminator excluded).
    pub auth_plugin_name: std::ops::Range<usize>,
}
/// Parses the server greeting (initial handshake packet) from `payload`.
///
/// `server_version` and `auth_plugin_name` in the result are byte ranges into `payload`,
/// not copies.
///
/// Capability bits this library does not know about are dropped rather than rejected.
/// The flags come from the server, so a newer server advertising extra capabilities must
/// not abort the connection. Unknown bits are never negotiated anyway, because the client
/// ANDs them with its own flags.
///
/// # Errors
/// - The server's ERR packet, if the first byte is `0xFF` (e.g. the host is refused).
/// - Any error from the primitive readers, presumably a truncated packet or a missing NUL
///   terminator.
/// - `Error::LibraryBug` if bytes remain after the auth plugin name.
pub fn read_initial_handshake(payload: &[u8]) -> Result<InitialHandshake> {
    let (protocol_version, data) = read_int_1(payload)?;
    if protocol_version == 0xFF {
        cold_path();
        return Err(ErrPayloadBytes(payload).into());
    }

    // Strings are returned as ranges into `payload`; offsets come from how much input the
    // readers have consumed so far.
    let server_version_start = payload.len() - data.len();
    let (server_version_bytes, data) = read_string_null(data)?;
    let server_version = server_version_start..server_version_start + server_version_bytes.len();

    let (fixed, data) = HandshakeFixedFields::ref_from_prefix(data)?;
    let connection_id = fixed.connection_id.get();
    let charset = fixed.charset;
    let status_flags = fixed.status_flags.get();
    let capability_flags = CapabilityFlags::from_bits_truncate(
        (u32::from(fixed.capability_flags_upper.get()) << 16)
            | u32::from(fixed.capability_flags_lower.get()),
    );
    let mariadb_capabilities =
        MariadbCapabilityFlags::from_bits_truncate(fixed.mariadb_capabilities.get());

    // The second scramble part is max(13, auth_data_len - 8) bytes, and its last byte is a
    // NUL. Read the data bytes here (one fewer) and skip the NUL separately below.
    let auth_data_2_len = (fixed.auth_data_len as usize).saturating_sub(9).max(12);
    let (auth_data_2, data) = read_string_fix(data, auth_data_2_len)?;
    let (_reserved, data) = read_int_1(data)?;
    let auth_plugin_data = [&fixed.auth_data_part1[..], auth_data_2].concat();

    let auth_plugin_name_start = payload.len() - data.len();
    let (auth_plugin_name_bytes, rest) = read_string_null(data)?;
    let auth_plugin_name =
        auth_plugin_name_start..auth_plugin_name_start + auth_plugin_name_bytes.len();
    if !rest.is_empty() {
        return Err(Error::LibraryBug(eyre!(
            "unexpected trailing data in handshake packet: {} bytes",
            rest.len()
        )));
    }

    Ok(InitialHandshake {
        protocol_version,
        server_version,
        connection_id,
        auth_plugin_data,
        capability_flags,
        mariadb_capabilities,
        charset,
        status_flags: crate::constant::ServerStatusFlags::from_bits_truncate(status_flags),
        auth_plugin_name,
    })
}
/// Parsed AuthSwitchRequest packet. Both fields borrow from the packet payload.
#[derive(Debug, Clone)]
pub struct AuthSwitchRequest<'buf> {
    /// Name of the auth plugin the server asks the client to use (NUL terminator excluded).
    pub plugin_name: &'buf [u8],
    /// Plugin-specific data (the new scramble for the supported plugins), with the
    /// packet's trailing NUL byte removed.
    pub plugin_data: &'buf [u8],
}
pub fn read_auth_switch_request(payload: &[u8]) -> Result<AuthSwitchRequest<'_>> {
let (header, mut data) = read_int_1(payload)?;
if header != 0xFE {
return Err(Error::LibraryBug(eyre!(
"expected auth switch header 0xFE, got 0x{:02X}",
header
)));
}
let (plugin_name, rest) = read_string_null(data)?;
data = rest;
if let Some(0) = data.last() {
Ok(AuthSwitchRequest {
plugin_name,
plugin_data: &data[..data.len() - 1],
})
} else {
Err(Error::LibraryBug(eyre!(
"auth switch request plugin data not null-terminated"
)))
}
}
/// Appends the body of an auth switch response to `out`.
///
/// The body is just the raw auth data, with no length prefix or terminator.
pub fn write_auth_switch_response(out: &mut Vec<u8>, auth_data: &[u8]) {
    out.extend(auth_data.iter().copied());
}
/// Computes the `mysql_native_password` auth token for `password` and the server scramble
/// `challenge`:
///
/// `SHA1(password) XOR SHA1(challenge || SHA1(SHA1(password)))`
///
/// An empty `password` yields an all-zero array, and no hashing is done.
/// NOTE: the reference MySQL client sends an *empty* auth response for an empty password.
/// Callers should special-case that rather than send these zeros.
pub fn auth_mysql_native_password(password: &str, challenge: &[u8]) -> [u8; 20] {
    use sha1::{Digest, Sha1};
    let mut scrambled = [0_u8; 20];
    if password.is_empty() {
        return scrambled;
    }
    let pw_sha1 = Sha1::digest(password.as_bytes());
    let pw_sha1_sha1 = Sha1::digest(pw_sha1);
    let mut mixer = Sha1::new();
    mixer.update(challenge);
    mixer.update(pw_sha1_sha1);
    let mask = mixer.finalize();
    scrambled
        .iter_mut()
        .zip(pw_sha1.iter().zip(mask.iter()))
        .for_each(|(dst, (p, m))| *dst = p ^ m);
    scrambled
}
/// Computes the `caching_sha2_password` fast-auth token for `password` and the server
/// scramble `challenge`:
///
/// `SHA256(password) XOR SHA256(SHA256(SHA256(password)) || challenge)`
///
/// An empty `password` yields an all-zero array, and no hashing is done.
/// NOTE: the reference MySQL client sends an *empty* (or single-NUL) response for an empty
/// password. Callers should special-case that rather than send these zeros.
pub fn auth_caching_sha2_password(password: &str, challenge: &[u8]) -> [u8; 32] {
    use sha2::{Digest, Sha256};
    let mut token = [0_u8; 32];
    if password.is_empty() {
        return token;
    }
    let digest1 = Sha256::digest(password.as_bytes());
    let digest2 = Sha256::digest(digest1);
    let mut mixer = Sha256::new();
    mixer.update(digest2);
    mixer.update(challenge);
    let mask = mixer.finalize();
    for (slot, (d, m)) in token.iter_mut().zip(digest1.iter().zip(mask.iter())) {
        *slot = d ^ m;
    }
    token
}
/// Outcome signalled by the status byte of a `caching_sha2_password` AuthMoreData packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CachingSha2PasswordFastAuthResult {
    /// Status byte `0x03`: fast auth succeeded; an OK packet is expected next.
    Success,
    /// Status byte `0x04`: the client must perform full authentication (cleartext password
    /// over TLS, or an RSA-encrypted password otherwise).
    FullAuthRequired,
}
pub fn read_caching_sha2_password_fast_auth_result(
payload: &[u8],
) -> Result<CachingSha2PasswordFastAuthResult> {
if payload.is_empty() {
return Err(Error::LibraryBug(eyre!(
"empty payload for caching_sha2_password fast auth result"
)));
}
match payload[0] {
0x03 => Ok(CachingSha2PasswordFastAuthResult::Success),
0x04 => Ok(CachingSha2PasswordFastAuthResult::FullAuthRequired),
_ => Err(Error::LibraryBug(eyre!(
"unexpected caching_sha2_password fast auth result: 0x{:02X}",
payload[0]
))),
}
}
/// Encrypts `password` for `caching_sha2_password` full authentication without TLS.
///
/// First, the password plus a trailing NUL byte is XOR-ed byte-wise with `scramble`,
/// repeated as needed. The result is then encrypted with RSA-OAEP (SHA-1 / MGF1-SHA-1)
/// under the PEM-encoded public key in `pem_str`.
///
/// # Errors
/// `Error::LibraryBug` in any of these cases:
/// - the PEM or DER key cannot be parsed;
/// - the OAEP key cannot be built;
/// - `scramble` is empty;
/// - encryption fails, for instance because the plaintext is too long for the key.
fn rsa_encrypt_password(password: &str, scramble: &[u8], pem_str: &str) -> Result<Vec<u8>> {
    use aws_lc_rs::rsa::{OAEP_SHA1_MGF1SHA1, OaepPublicEncryptingKey, PublicEncryptingKey};

    let parsed_pem = pem::parse(pem_str)
        .map_err(|e| Error::LibraryBug(eyre!("failed to parse RSA public key PEM: {}", e)))?;
    let der_key = PublicEncryptingKey::from_der(parsed_pem.contents())
        .map_err(|e| Error::LibraryBug(eyre!("failed to parse RSA public key DER: {}", e)))?;
    let encryptor = OaepPublicEncryptingKey::new(der_key)
        .map_err(|e| Error::LibraryBug(eyre!("failed to create OAEP key: {}", e)))?;
    if scramble.is_empty() {
        // Cycling an empty scramble yields nothing, so the XOR below would silently
        // produce an empty plaintext.
        return Err(Error::LibraryBug(eyre!(
            "empty scramble in rsa_encrypt_password"
        )));
    }

    let masked: Vec<u8> = password
        .as_bytes()
        .iter()
        .chain(std::iter::once(&0_u8))
        .zip(scramble.iter().cycle())
        .map(|(byte, key)| byte ^ key)
        .collect();

    let mut out = vec![0_u8; encryptor.ciphertext_size()];
    let ciphertext = encryptor
        .encrypt(&OAEP_SHA1_MGF1SHA1, &masked, &mut out, None)
        .map_err(|e| Error::LibraryBug(eyre!("RSA encryption failed: {}", e)))?;
    Ok(ciphertext.to_vec())
}
/// Writes an SSLRequest packet body to `out`.
///
/// The body is the leading, fixed part of a HandshakeResponse41:
/// - capability flags (4 bytes);
/// - max packet size (4 bytes);
/// - charset (1 byte);
/// - a 23-byte filler.
///
/// For MariaDB servers the last 4 filler bytes carry `mariadb_capabilities`; otherwise
/// they are zero.
fn write_ssl_request(
    out: &mut Vec<u8>,
    capability_flags: CapabilityFlags,
    mariadb_capabilities: MariadbCapabilityFlags,
) {
    let extended_caps = if capability_flags.is_mariadb() {
        mariadb_capabilities.bits()
    } else {
        0
    };
    write_int_4(out, capability_flags.bits());
    write_int_4(out, MAX_ALLOWED_PACKET);
    write_int_1(out, UTF8MB4_GENERAL_CI);
    out.extend_from_slice(&[0_u8; 19]);
    write_int_4(out, extended_caps);
}
/// What the caller must do before calling `Handshake::step` again.
pub enum HandshakeAction<'buf> {
    /// Read the next packet payload into this buffer.
    ReadPacket(&'buf mut Vec<u8>),
    /// Send the packet prepared in the buffer set's write buffer with this sequence id.
    /// The states that follow parse the server's reply from `read_buffer`, so the caller
    /// presumably reads it there next.
    WritePacket { sequence_id: u8 },
    /// Send the SSLRequest prepared in the write buffer with this sequence id, then upgrade
    /// the transport to TLS. The next `step` writes the handshake response without reading
    /// anything first.
    UpgradeTls { sequence_id: u8 },
    /// Authentication succeeded; call `Handshake::finish`.
    Finished,
}
/// Position of a `Handshake` in the connection protocol.
enum HandshakeState {
    /// Nothing done yet; the next step asks the caller to read the server greeting.
    Start,
    /// The greeting is expected in `BufferSet::initial_handshake`.
    WaitingInitialHandshake,
    /// The SSLRequest was sent; the next step writes the handshake response.
    WaitingTlsUpgrade,
    /// The handshake response was sent. Expecting OK, ERR, AuthMoreData (`0x01`) or
    /// AuthSwitchRequest (`0xFE`).
    WaitingAuthResult,
    /// A follow-up auth packet was sent and OK/ERR is expected. When `caching_sha2` is
    /// true, a caching_sha2_password AuthMoreData packet is also accepted.
    WaitingFinalAuthResult { caching_sha2: bool },
    /// caching_sha2_password fast auth succeeded; waiting for the OK packet.
    WaitingCachingSha2FastAuthOk,
    /// The server's RSA public key was requested; expecting it in an AuthMoreData packet.
    WaitingRsaPublicKey,
    /// Authentication completed successfully.
    Connected,
}
pub struct Handshake<'a> {
state: HandshakeState,
opts: &'a Opts,
initial_handshake: Option<InitialHandshake>,
next_sequence_id: u8,
capability_flags: Option<CapabilityFlags>,
mariadb_capabilities: Option<MariadbCapabilityFlags>,
}
impl<'a> Handshake<'a> {
pub fn new(opts: &'a Opts) -> Self {
Self {
state: HandshakeState::Start,
opts,
initial_handshake: None,
next_sequence_id: 1,
capability_flags: None,
mariadb_capabilities: None,
}
}
pub fn step<'buf>(&mut self, buffer_set: &'buf mut BufferSet) -> Result<HandshakeAction<'buf>> {
match &mut self.state {
HandshakeState::Start => {
self.state = HandshakeState::WaitingInitialHandshake;
Ok(HandshakeAction::ReadPacket(
&mut buffer_set.initial_handshake,
))
}
HandshakeState::WaitingInitialHandshake => {
let handshake = read_initial_handshake(&buffer_set.initial_handshake)?;
let mut client_caps = CAPABILITIES_ALWAYS_ENABLED
| (self.opts.capabilities & CAPABILITIES_CONFIGURABLE);
if self.opts.db.is_some() {
client_caps |= CapabilityFlags::CLIENT_CONNECT_WITH_DB;
}
if self.opts.tls {
client_caps |= CapabilityFlags::CLIENT_SSL;
}
let negotiated_caps = client_caps & handshake.capability_flags;
let mariadb_caps = if negotiated_caps.is_mariadb() {
if !handshake
.mariadb_capabilities
.contains(MARIADB_CAPABILITIES_ENABLED)
{
return Err(Error::Unsupported(format!(
"MariaDB server does not support the required capabilities. Server: {:?} Required: {:?}",
handshake.mariadb_capabilities, MARIADB_CAPABILITIES_ENABLED
)));
}
MARIADB_CAPABILITIES_ENABLED
} else {
MariadbCapabilityFlags::empty()
};
self.capability_flags = Some(negotiated_caps);
self.mariadb_capabilities = Some(mariadb_caps);
self.initial_handshake = Some(handshake);
if self.opts.tls && negotiated_caps.contains(CapabilityFlags::CLIENT_SSL) {
write_ssl_request(buffer_set.new_write_buffer(), negotiated_caps, mariadb_caps);
let seq = self.next_sequence_id;
self.next_sequence_id = self.next_sequence_id.wrapping_add(1);
self.state = HandshakeState::WaitingTlsUpgrade;
Ok(HandshakeAction::UpgradeTls { sequence_id: seq })
} else {
self.write_handshake_response(buffer_set)?;
let seq = self.next_sequence_id;
self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
self.state = HandshakeState::WaitingAuthResult;
Ok(HandshakeAction::WritePacket { sequence_id: seq })
}
}
HandshakeState::WaitingTlsUpgrade => {
self.write_handshake_response(buffer_set)?;
let seq = self.next_sequence_id;
self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
self.state = HandshakeState::WaitingAuthResult;
Ok(HandshakeAction::WritePacket { sequence_id: seq })
}
HandshakeState::WaitingAuthResult => {
let payload = &buffer_set.read_buffer[..];
if payload.is_empty() {
return Err(Error::LibraryBug(eyre!(
"empty payload while waiting for auth result"
)));
}
let initial_handshake = self.initial_handshake.as_ref().ok_or_else(|| {
Error::LibraryBug(eyre!("initial_handshake not set in WaitingAuthResult"))
})?;
let initial_plugin =
&buffer_set.initial_handshake[initial_handshake.auth_plugin_name.clone()];
match payload[0] {
0x00 => {
self.state = HandshakeState::Connected;
Ok(HandshakeAction::Finished)
}
0xFF => {
Err(ErrPayloadBytes(payload).into())
}
0x01 => {
if initial_plugin == b"caching_sha2_password" {
self.handle_auth_more_data(buffer_set)
} else {
Err(Error::LibraryBug(eyre!(
"unexpected AuthMoreData (0x01) for plugin {:?}",
String::from_utf8_lossy(initial_plugin)
)))
}
}
0xFE => {
let auth_switch = read_auth_switch_request(payload)?;
let (auth_response, is_caching_sha2) = match auth_switch.plugin_name {
b"mysql_native_password" => (
auth_mysql_native_password(
&self.opts.password,
auth_switch.plugin_data,
)
.to_vec(),
false,
),
b"caching_sha2_password" => (
auth_caching_sha2_password(
&self.opts.password,
auth_switch.plugin_data,
)
.to_vec(),
true,
),
plugin => {
return Err(Error::Unsupported(
String::from_utf8_lossy(plugin).to_string(),
));
}
};
write_auth_switch_response(buffer_set.new_write_buffer(), &auth_response);
let seq = self.next_sequence_id;
self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
self.state = HandshakeState::WaitingFinalAuthResult {
caching_sha2: is_caching_sha2,
};
Ok(HandshakeAction::WritePacket { sequence_id: seq })
}
header => Err(Error::LibraryBug(eyre!(
"unexpected packet header 0x{:02X} while waiting for auth result",
header
))),
}
}
HandshakeState::WaitingFinalAuthResult { caching_sha2 } => {
let payload = &buffer_set.read_buffer[..];
if payload.is_empty() {
return Err(Error::LibraryBug(eyre!(
"empty payload while waiting for final auth result"
)));
}
match payload[0] {
0x00 => {
self.state = HandshakeState::Connected;
Ok(HandshakeAction::Finished)
}
0xFF => {
Err(ErrPayloadBytes(payload).into())
}
0x01 if *caching_sha2 => self.handle_auth_more_data(buffer_set),
header => Err(Error::LibraryBug(eyre!(
"unexpected packet header 0x{:02X} while waiting for final auth result",
header
))),
}
}
HandshakeState::WaitingCachingSha2FastAuthOk => {
let payload = &buffer_set.read_buffer[..];
if payload.is_empty() {
return Err(Error::LibraryBug(eyre!(
"empty payload while waiting for caching_sha2 OK"
)));
}
match payload[0] {
0x00 => {
self.state = HandshakeState::Connected;
Ok(HandshakeAction::Finished)
}
0xFF => Err(ErrPayloadBytes(payload).into()),
header => Err(Error::LibraryBug(eyre!(
"unexpected packet header 0x{:02X} while waiting for caching_sha2 OK",
header
))),
}
}
HandshakeState::WaitingRsaPublicKey => {
let payload = &buffer_set.read_buffer[..];
if payload.is_empty() {
return Err(Error::LibraryBug(eyre!(
"empty payload while waiting for RSA public key"
)));
}
match payload[0] {
0xFF => return Err(ErrPayloadBytes(payload).into()),
0x01 if payload.len() >= 2 => {}
header => {
return Err(Error::LibraryBug(eyre!(
"expected AuthMoreData (0x01) with RSA public key, got 0x{:02X}",
header
)));
}
}
let pem = std::str::from_utf8(&payload[1..]).map_err(|e| {
Error::LibraryBug(eyre!("RSA public key is not valid UTF-8: {}", e))
})?;
let handshake = self
.initial_handshake
.as_ref()
.ok_or_else(|| Error::LibraryBug(eyre!("initial_handshake not set")))?;
let encrypted =
rsa_encrypt_password(&self.opts.password, &handshake.auth_plugin_data, pem)?;
let out = buffer_set.new_write_buffer();
out.extend_from_slice(&encrypted);
let seq = self.next_sequence_id;
self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
self.state = HandshakeState::WaitingFinalAuthResult {
caching_sha2: false,
};
Ok(HandshakeAction::WritePacket { sequence_id: seq })
}
HandshakeState::Connected => Err(Error::LibraryBug(eyre!(
"step() called after handshake completed"
))),
}
}
pub fn finish(self) -> Result<(InitialHandshake, CapabilityFlags, MariadbCapabilityFlags)> {
if !matches!(self.state, HandshakeState::Connected) {
return Err(Error::LibraryBug(eyre!(
"finish() called before handshake completed"
)));
}
let initial_handshake = self.initial_handshake.ok_or_else(|| {
Error::LibraryBug(eyre!("initial_handshake not set in Connected state"))
})?;
let capability_flags = self.capability_flags.ok_or_else(|| {
Error::LibraryBug(eyre!("capability_flags not set in Connected state"))
})?;
let mariadb_capabilities = self.mariadb_capabilities.ok_or_else(|| {
Error::LibraryBug(eyre!("mariadb_capabilities not set in Connected state"))
})?;
Ok((initial_handshake, capability_flags, mariadb_capabilities))
}
fn write_handshake_response(&self, buffer_set: &mut BufferSet) -> Result<()> {
buffer_set.new_write_buffer();
let handshake = self.initial_handshake.as_ref().ok_or_else(|| {
Error::LibraryBug(eyre!(
"initial_handshake not set in write_handshake_response"
))
})?;
let capability_flags = self.capability_flags.ok_or_else(|| {
Error::LibraryBug(eyre!(
"capability_flags not set in write_handshake_response"
))
})?;
let mariadb_capabilities = self.mariadb_capabilities.ok_or_else(|| {
Error::LibraryBug(eyre!(
"mariadb_capabilities not set in write_handshake_response"
))
})?;
let auth_plugin_name = &buffer_set.initial_handshake[handshake.auth_plugin_name.clone()];
let auth_response = {
match auth_plugin_name {
b"mysql_native_password" => {
auth_mysql_native_password(&self.opts.password, &handshake.auth_plugin_data)
.to_vec()
}
b"caching_sha2_password" => {
auth_caching_sha2_password(&self.opts.password, &handshake.auth_plugin_data)
.to_vec()
}
plugin => {
return Err(Error::Unsupported(
String::from_utf8_lossy(plugin).to_string(),
));
}
}
};
let out = &mut buffer_set.write_buffer;
write_int_4(out, capability_flags.bits());
write_int_4(out, MAX_ALLOWED_PACKET);
write_int_1(out, UTF8MB4_GENERAL_CI);
out.extend_from_slice(&[0_u8; 19]);
write_int_4(out, mariadb_capabilities.bits());
write_string_null(out, self.opts.user.as_bytes());
if capability_flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) {
write_bytes_lenenc(out, &auth_response);
} else {
write_int_1(out, auth_response.len() as u8);
out.extend_from_slice(&auth_response);
}
if let Some(db) = &self.opts.db {
write_string_null(out, db.as_bytes());
}
if capability_flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH) {
write_string_null(out, auth_plugin_name);
}
Ok(())
}
fn handle_auth_more_data<'buf>(
&mut self,
buffer_set: &'buf mut BufferSet,
) -> Result<HandshakeAction<'buf>> {
let payload = &buffer_set.read_buffer[..];
if payload.len() < 2 {
return Err(Error::LibraryBug(eyre!(
"AuthMoreData packet too short: {} bytes",
payload.len()
)));
}
let result = read_caching_sha2_password_fast_auth_result(&payload[1..])?;
match result {
CachingSha2PasswordFastAuthResult::Success => {
self.state = HandshakeState::WaitingCachingSha2FastAuthOk;
Ok(HandshakeAction::ReadPacket(&mut buffer_set.read_buffer))
}
CachingSha2PasswordFastAuthResult::FullAuthRequired => {
let capability_flags = self
.capability_flags
.ok_or_else(|| Error::LibraryBug(eyre!("capability_flags not set")))?;
if capability_flags.contains(CapabilityFlags::CLIENT_SSL) {
let out = buffer_set.new_write_buffer();
out.extend_from_slice(self.opts.password.as_bytes());
out.push(0);
let seq = self.next_sequence_id;
self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
self.state = HandshakeState::WaitingFinalAuthResult {
caching_sha2: false,
};
Ok(HandshakeAction::WritePacket { sequence_id: seq })
} else {
let out = buffer_set.new_write_buffer();
out.push(0x02);
let seq = self.next_sequence_id;
self.next_sequence_id = self.next_sequence_id.wrapping_add(2);
self.state = HandshakeState::WaitingRsaPublicKey;
Ok(HandshakeAction::WritePacket { sequence_id: seq })
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_macros::{check_eq, check_err};

    /// `ref_from_prefix` must accept any byte offset, so the wire struct may not impose
    /// an alignment requirement.
    #[test]
    fn handshake_fixed_fields_has_alignment_of_1() {
        let alignment = std::mem::align_of::<HandshakeFixedFields>();
        assert_eq!(alignment, 1);
    }

    /// Round-trips a password through `rsa_encrypt_password` with a freshly generated key
    /// pair. The decrypted plaintext must equal `(password || NUL) XOR scramble`.
    #[test]
    #[expect(clippy::unwrap_used)]
    fn rsa_encrypt_password_xors_and_encrypts() {
        use aws_lc_rs::encoding::AsDer;
        use aws_lc_rs::rsa::{
            KeySize, OAEP_SHA1_MGF1SHA1, OaepPrivateDecryptingKey, PrivateDecryptingKey,
        };
        use aws_lc_rs::signature::KeyPair;

        let key_pair = aws_lc_rs::rsa::KeyPair::generate(KeySize::Rsa2048).unwrap();
        let pkcs8 = key_pair.as_der().unwrap();
        let spki = key_pair.public_key().as_der().unwrap();
        let pem_string = pem::encode(&pem::Pem::new("PUBLIC KEY", spki.as_ref().to_vec()));

        let password = "test_password";
        let scramble = b"01234567890123456789";
        let ciphertext = super::rsa_encrypt_password(password, scramble, &pem_string).unwrap();

        let decryptor =
            OaepPrivateDecryptingKey::new(PrivateDecryptingKey::from_pkcs8(pkcs8.as_ref()).unwrap())
                .unwrap();
        let mut scratch = vec![0u8; ciphertext.len()];
        let decrypted = decryptor
            .decrypt(&OAEP_SHA1_MGF1SHA1, &ciphertext, &mut scratch, None)
            .unwrap();

        let expected: Vec<u8> = password
            .as_bytes()
            .iter()
            .chain(std::iter::once(&0u8))
            .zip(scramble.iter().cycle())
            .map(|(byte, key)| byte ^ key)
            .collect();
        assert_eq!(decrypted, expected);
    }

    /// Status bytes 0x03/0x04 map to the two outcomes; anything else, including an empty
    /// payload, is an error.
    #[test]
    fn fast_auth_result_parsing() -> crate::error::Result<()> {
        check_eq!(
            read_caching_sha2_password_fast_auth_result(&[0x03])?,
            CachingSha2PasswordFastAuthResult::Success,
        );
        check_eq!(
            read_caching_sha2_password_fast_auth_result(&[0x04])?,
            CachingSha2PasswordFastAuthResult::FullAuthRequired,
        );
        check_err!(read_caching_sha2_password_fast_auth_result(&[0x05]));
        check_err!(read_caching_sha2_password_fast_auth_result(&[]));
        Ok(())
    }
}