use std::borrow::Cow;
use std::collections::HashMap;
use std::convert::TryInto;
use std::ops::Range;
use futures_core::future::BoxFuture;
use sha1::Sha1;
use crate::connection::{Connect, Connection};
use crate::executor::Executor;
use crate::mysql::protocol::{
AuthPlugin, AuthSwitch, Capabilities, ComPing, Handshake, HandshakeResponse,
};
use crate::mysql::stream::MySqlStream;
use crate::mysql::util::xor_eq;
use crate::mysql::{rsa, tls};
use crate::url::Url;
/// Value sent as `max_packet_size` in the client's `HandshakeResponse`.
pub(super) const MAX_PACKET_SIZE: u32 = 1024;
/// Collation id sent as `client_collation` in the client's `HandshakeResponse`.
/// Per its name this selects `utf8mb4_unicode_ci`, matching the
/// `SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci` issued after connecting.
pub(super) const COLLATE_UTF8MB4_UNICODE_CI: u8 = 224;
/// A connection to a MySQL server.
///
/// Constructed by `MySqlConnection::new` once `establish` has completed the
/// connection-phase handshake on the underlying stream.
pub struct MySqlConnection {
    /// Protocol stream used for all packet I/O on this connection.
    pub(super) stream: MySqlStream,
    /// Readiness flag; initialized to `true` by `new`.
    /// NOTE(review): its exact meaning is defined by code outside this file — confirm.
    pub(super) is_ready: bool,
    /// Cache keyed by SQL text. Presumably maps a statement to its server-side
    /// prepared-statement id — verify against the executor.
    pub(super) cache_statement: HashMap<Box<str>, u32>,
    /// Scratch buffer of optional byte ranges, preallocated with capacity 10 in
    /// `new`. Presumably one entry per column of the current row, `None` for
    /// SQL NULL — TODO confirm.
    pub(super) current_row_values: Vec<Option<Range<usize>>>,
}
/// Returns the UTF-8 bytes of `s` followed by a single NUL (`0x00`) terminator.
fn to_asciz(s: &str) -> Vec<u8> {
    // Reserve room for the payload plus the trailing terminator up front.
    let mut bytes = Vec::with_capacity(s.len() + 1);
    bytes.extend_from_slice(s.as_bytes());
    bytes.push(0);
    bytes
}
async fn rsa_encrypt_with_nonce(
stream: &mut MySqlStream,
public_key_request_id: u8,
password: &str,
nonce: &[u8],
) -> crate::Result<Vec<u8>> {
if stream.is_tls() {
return Ok(to_asciz(password));
}
stream.send(&[public_key_request_id][..], false).await?;
let packet = stream.receive().await?;
let rsa_pub_key = &packet[1..];
let mut pass = to_asciz(password);
xor_eq(&mut pass, nonce);
rsa::encrypt::<Sha1>(rsa_pub_key, &pass)
}
/// Builds the auth-response bytes for `plugin` from `password` and the
/// server-provided `nonce`.
///
/// An empty password always yields an empty response.
/// - `mysql_native_password` and `caching_sha2_password` use the plugin's
///   `scramble`.
/// - `sha256_password` goes through `rsa_encrypt_with_nonce` with public-key
///   request id `0x01`. This may perform I/O on `stream`.
async fn make_auth_response(
    stream: &mut MySqlStream,
    plugin: &AuthPlugin,
    password: &str,
    nonce: &[u8],
) -> crate::Result<Vec<u8>> {
    if password.is_empty() {
        // Nothing to scramble or encrypt.
        return Ok(Vec::new());
    }

    match plugin {
        AuthPlugin::Sha256Password => {
            rsa_encrypt_with_nonce(stream, 0x01, password, nonce).await
        }
        AuthPlugin::MySqlNativePassword | AuthPlugin::CachingSha2Password => {
            Ok(plugin.scramble(password, nonce))
        }
    }
}
/// Runs the MySQL connection phase on a freshly opened `stream`.
///
/// Steps, in order:
/// 1. Read the server's initial `Handshake`.
/// 2. Narrow the client capability flags to those the server advertises,
///    always forcing `PROTOCOL_41` on.
/// 3. Call `tls::upgrade_if_needed`.
/// 4. Send a `HandshakeResponse` built from the URL. The username defaults to
///    `root` and the password to the empty string.
/// 5. Service server replies until an OK packet (`0x00`) arrives:
///    - An auth switch request (`0xFE`) is answered with a fresh auth response.
///    - A `caching_sha2_password` result (`0x01`) either continues (`0x03`) or
///      triggers full authentication (`0x04`).
///
/// # Errors
///
/// Returns an error if:
/// - any I/O fails;
/// - the server sends an ERR packet (`0xFF`);
/// - a packet is empty or truncated;
/// - any other unexpected packet arrives.
async fn establish(stream: &mut MySqlStream, url: &Url) -> crate::Result<()> {
    let handshake = Handshake::read(stream.receive().await?)?;

    // May be replaced by an auth switch request later in the exchange.
    let mut auth_plugin = handshake.auth_plugin;
    let mut auth_plugin_data = handshake.auth_plugin_data;

    // Use only capabilities both sides support; PROTOCOL_41 is always required.
    stream.capabilities &= handshake.server_capabilities;
    stream.capabilities |= Capabilities::PROTOCOL_41;

    log::trace!("using capability flags: {:?}", stream.capabilities);

    tls::upgrade_if_needed(stream, url).await?;

    let password = &*url.password().unwrap_or_default();

    let auth_response =
        make_auth_response(stream, &auth_plugin, password, &auth_plugin_data).await?;

    stream
        .send(
            HandshakeResponse {
                client_collation: COLLATE_UTF8MB4_UNICODE_CI,
                max_packet_size: MAX_PACKET_SIZE,
                username: &url.username().unwrap_or(Cow::Borrowed("root")),
                database: url.database(),
                auth_plugin: &auth_plugin,
                auth_response: &auth_response,
            },
            false,
        )
        .await?;

    loop {
        let packet = stream.receive().await?;

        // Indexing `packet[0]` would panic on an empty payload; report it instead.
        let header = match packet.first() {
            Some(&header) => header,
            None => {
                return Err(
                    protocol_err!("unexpected empty packet during authentication").into(),
                );
            }
        };

        match header {
            // OK: authentication finished.
            0x00 => break,

            // ERR: let the stream turn the packet into an error.
            0xFF => return stream.handle_err(),

            // Auth switch request: answer with a response for the new plugin.
            0xFE => {
                let auth = AuthSwitch::read(packet)?;
                auth_plugin = auth.auth_plugin;
                auth_plugin_data = auth.auth_plugin_data;

                let auth_response =
                    make_auth_response(stream, &auth_plugin, password, &auth_plugin_data).await?;

                stream.send(&*auth_response, false).await?;
            }

            // caching_sha2_password result: the second byte says what happens next.
            0x01 if auth_plugin == AuthPlugin::CachingSha2Password => {
                match packet.get(1).copied() {
                    // Fast auth succeeded; keep reading until the OK packet.
                    Some(0x03) => {}

                    // Full authentication required; send the (possibly RSA-encrypted) password.
                    Some(0x04) => {
                        let enc = rsa_encrypt_with_nonce(stream, 0x02, password, &auth_plugin_data)
                            .await?;

                        stream.send(&*enc, false).await?;
                    }

                    Some(unk) => {
                        return Err(protocol_err!("unexpected result from 'fast' authentication 0x{:x} when expecting OK (0x03) or CONTINUE (0x04)", unk).into());
                    }

                    // A bare 0x01 with no status byte would previously panic on `packet[1]`.
                    None => {
                        return Err(protocol_err!("missing result byte from 'fast' authentication when expecting OK (0x03) or CONTINUE (0x04)").into());
                    }
                }
            }

            _ => {
                return stream.handle_unexpected();
            }
        }
    }

    Ok(())
}
/// Flushes the stream and then shuts it down, consuming it.
///
/// # Errors
///
/// Propagates any error returned by `flush` or `shutdown`.
async fn close(mut stream: MySqlStream) -> crate::Result<()> {
    // Flush before shutting down, presumably so buffered output is not lost.
    stream.flush().await?;
    stream.shutdown()?;
    Ok(())
}
async fn ping(stream: &mut MySqlStream) -> crate::Result<()> {
stream.wait_until_ready().await?;
stream.is_ready = false;
stream.send(ComPing, true).await?;
match stream.receive().await?[0] {
0x00 | 0xFE => stream.handle_ok().map(drop),
0xFF => stream.handle_err(),
_ => stream.handle_unexpected(),
}
}
impl MySqlConnection {
    /// Opens a new connection to the server described by `url`.
    ///
    /// `url` arrives still wrapped in its parse result. A parse failure is
    /// therefore reported here, through the returned future, rather than by
    /// the caller.
    ///
    /// The connection is set up in three steps:
    /// 1. Open the stream.
    /// 2. Run the `establish` handshake.
    /// 3. Configure the session: extra SQL modes, a `+00:00` time zone, and
    ///    `utf8mb4` with `utf8mb4_unicode_ci`.
    pub(super) async fn new(url: std::result::Result<Url, url::ParseError>) -> crate::Result<Self> {
        // Session statements applied to every new connection.
        const SESSION_SETUP: &str = r#"
SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE'));
SET time_zone = '+00:00';
SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci;
"#;

        let url = url?;

        let mut stream = MySqlStream::new(&url).await?;
        establish(&mut stream, &url).await?;

        let mut conn = Self {
            stream,
            is_ready: true,
            cache_statement: HashMap::new(),
            current_row_values: Vec::with_capacity(10),
        };

        conn.execute(SESSION_SETUP).await?;

        Ok(conn)
    }
}
impl Connect for MySqlConnection {
    /// Converts `url` into a `Url` and returns a boxed future that opens the
    /// connection.
    ///
    /// The conversion result is passed to `MySqlConnection::new` unchanged, so
    /// a URL parse error surfaces when the future is awaited.
    fn connect<T>(url: T) -> BoxFuture<'static, crate::Result<MySqlConnection>>
    where
        T: TryInto<Url, Error = url::ParseError>,
        Self: Sized,
    {
        let parsed: std::result::Result<Url, url::ParseError> = url.try_into();
        let connecting = Self::new(parsed);
        Box::pin(connecting)
    }
}
impl Connection for MySqlConnection {
    /// Consumes the connection, flushing and shutting down its stream via the
    /// module-level `close` helper.
    #[inline]
    fn close(self) -> BoxFuture<'static, crate::Result<()>> {
        let Self { stream, .. } = self;
        Box::pin(close(stream))
    }

    /// Sends a `COM_PING` over this connection's stream via the module-level
    /// `ping` helper.
    #[inline]
    fn ping(&mut self) -> BoxFuture<crate::Result<()>> {
        let stream = &mut self.stream;
        Box::pin(ping(stream))
    }
}