#![deny(missing_docs)]
#![deny(rust_2018_idioms)]
extern crate mysql_common as myc;
use std::collections::HashMap;
use std::io;
use std::io::prelude::*;
use std::iter;
use std::net;
use myc::constants::CapabilityFlags;
pub use crate::myc::constants::{ColumnFlags, ColumnType, StatusFlags};
mod commands;
mod errorcodes;
mod packet;
mod params;
mod resultset;
#[cfg(feature = "tls")]
mod tls;
mod value;
mod writers;
/// Meta-information about a single column.
///
/// Within this file, slices of `Column` are handed to result-set writers to
/// describe the columns of a response sent to the client.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Column {
/// Name of the table this column is associated with (may be empty).
pub table: String,
/// Name of the column as reported to the client.
pub column: String,
/// The MySQL type of the column's values.
pub coltype: ColumnType,
/// Flags associated with this column (for example [`ColumnFlags::UNSIGNED_FLAG`]).
pub colflags: ColumnFlags,
}
pub use crate::errorcodes::ErrorKind;
pub use crate::params::{ParamParser, ParamValue, Params};
pub use crate::resultset::{InitWriter, QueryResultWriter, RowWriter, StatementMetaWriter};
pub use crate::value::{ToMysqlValue, Value, ValueInner};
/// Backend hooks invoked by [`MysqlIntermediary`] while it serves a client
/// connection over a stream of type `W`.
///
/// Each method corresponds to a client command (or a connection-setup step);
/// responses are sent to the client through the writer passed to the method.
pub trait MysqlShim<W: Read + Write> {
/// The error type produced by this shim.
///
/// It must be constructible from [`io::Error`] so that transport and
/// protocol errors raised by the intermediary can be propagated with `?`.
type Error: From<io::Error>;
/// Called when the client asks to prepare `query` for later execution.
///
/// `info` is the writer through which the outcome of the prepare request
/// is reported to the client.
fn on_prepare(
&mut self,
query: &str,
info: StatementMetaWriter<'_, W>,
) -> Result<(), Self::Error>;
/// Called when the client executes the previously prepared statement `id`.
///
/// `params` gives access to the parameters sent with the command, and
/// `results` is the writer used to respond to the client.
fn on_execute(
&mut self,
id: u32,
params: ParamParser<'_>,
results: QueryResultWriter<'_, W>,
) -> Result<(), Self::Error>;
/// Called when the client closes the prepared statement `stmt`.
///
/// The intermediary drops its own per-statement state right after this
/// call returns.
fn on_close(&mut self, stmt: u32);
/// Called when the client issues `query` for immediate execution.
///
/// `results` is the writer used to respond to the client.
fn on_query(
&mut self,
query: &str,
results: QueryResultWriter<'_, W>,
) -> Result<(), Self::Error>;
/// Called when the client selects a schema, either through an init command
/// or through a `USE ` query.
///
/// The default implementation does nothing and returns `Ok(())`.
fn on_init(&mut self, _: &str, _: InitWriter<'_, W>) -> Result<(), Self::Error> {
Ok(())
}
/// Returns the TLS configuration to use, or `None` to not offer TLS.
///
/// Queried once at the start of the connection handshake. The default
/// implementation returns `None`.
#[cfg(feature = "tls")]
fn tls_config(&self) -> Option<std::sync::Arc<rustls::ServerConfig>> {
None
}
/// Called once the client handshake (including the TLS upgrade, if any)
/// has been read, with the information gathered about the client.
///
/// Returning an error makes the intermediary send an access-denied error
/// to the client and abort the connection with that error. The default
/// implementation accepts every client.
fn after_authentication(
&mut self,
_context: &AuthenticationContext<'_>,
) -> Result<(), Self::Error> {
Ok(())
}
}
/// Information collected about a client during the connection handshake,
/// passed to [`MysqlShim::after_authentication`].
// NOTE(review): `Eq` is presumably withheld (and the lint silenced) so the
// derive list does not depend on which feature-gated fields exist — confirm.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Debug, Default, Clone, PartialEq)]
pub struct AuthenticationContext<'a> {
/// The username bytes taken from the client's handshake, if one was sent.
pub username: Option<Vec<u8>>,
/// The client certificates reported by the TLS layer, if the connection
/// was upgraded to TLS and certificates are available.
#[cfg(feature = "tls")]
pub tls_client_certs: Option<&'a [rustls::pki_types::CertificateDer<'a>]>,
// Placeholder that keeps the `'a` lifetime parameter in use when the
// `tls` feature (and with it the only borrowing field) is disabled.
#[cfg(not(feature = "tls"))]
_pd: Option<&'a std::marker::PhantomData<()>>,
}
/// A server-side connection handler that reads client commands from a stream
/// and delegates them to a backend implementing [`MysqlShim`].
///
/// Use one of the `run_on*` constructors to serve a connection.
pub struct MysqlIntermediary<B, RW: Read + Write> {
/// The backend that client commands are delegated to.
shim: B,
/// The packet-framed connection to the client.
rw: packet::PacketConn<RW>,
}
impl<B: MysqlShim<net::TcpStream>> MysqlIntermediary<B, net::TcpStream> {
pub fn run_on_tcp(shim: B, stream: net::TcpStream) -> Result<(), B::Error> {
MysqlIntermediary::run_on(shim, stream)
}
}
impl<B: MysqlShim<S>, S: Read + Write + Clone> MysqlIntermediary<B, S> {
    /// Serves a single client connected over the cloneable two-way `stream`,
    /// delegating its commands to `shim`.
    ///
    /// This is a convenience wrapper around [`MysqlIntermediary::run_on`];
    /// it returns when `run_on` does.
    pub fn run_on_stream(shim: B, stream: S) -> Result<(), B::Error> {
        Self::run_on(shim, stream)
    }
}
/// Per-prepared-statement state kept by the intermediary, keyed by statement
/// id in the map owned by `run`.
#[derive(Default)]
struct StatementData {
/// Data received through long-data packets, accumulated per parameter
/// index; cleared after each execution of the statement.
long_data: HashMap<u16, Vec<u8>>,
// NOTE(review): presumably the (type, flag) pair bound to each parameter;
// not read in this file — confirm against the `params` module.
bound_types: Vec<(myc::constants::ColumnType, bool)>,
// NOTE(review): presumably the statement's parameter count; not read in
// this file — confirm against the `resultset`/`params` modules.
params: u16,
}
impl<B: MysqlShim<RW>, RW: Read + Write> MysqlIntermediary<B, RW> {
pub fn run_on(shim: B, rw: RW) -> Result<(), B::Error> {
let rw = packet::PacketConn::new(rw);
let mut mi = MysqlIntermediary { shim, rw };
mi.init()?;
mi.run()
}
/// Performs the connection handshake with the client.
///
/// Sends the server greeting, reads the client's handshake response
/// (upgrading the connection to TLS first when the client requests it and
/// the shim supplied a TLS configuration), lets the shim vet the client via
/// [`MysqlShim::after_authentication`], and finally acknowledges with an OK
/// packet.
///
/// # Errors
///
/// Returns an error if the stream fails, the peer disconnects, the client
/// handshake cannot be parsed, the client requests SSL although it was not
/// advertised, or the shim rejects the client (in which case an
/// access-denied error is sent to the client first).
fn init(&mut self) -> Result<(), B::Error> {
    #[cfg(feature = "tls")]
    let tls_conf = self.shim.tls_config();

    // Server greeting, written field by field.
    // NOTE(review): the per-field notes below are a reading of the MySQL
    // HandshakeV10 layout — confirm against the protocol documentation.
    self.rw.write_all(&[10])?; // protocol version
    self.rw.write_all(&b"5.1.10-alpha-msql-proxy\0"[..])?; // NUL-terminated server version
    self.rw.write_all(&[0x08, 0x00, 0x00, 0x00])?; // presumably a fixed connection id
    self.rw.write_all(&b";X,po_k}\0"[..])?; // presumably first part of the auth seed
    let capabilities = &mut [0x00, 0x42];
    #[cfg(feature = "tls")]
    if tls_conf.is_some() {
        // Advertise SSL support only when the shim provided a TLS config.
        capabilities[1] |= 0x08;
    }
    self.rw.write_all(capabilities)?;
    self.rw.write_all(&[0x21])?; // presumably the character set
    self.rw.write_all(&[0x00, 0x00])?; // presumably status flags
    self.rw.write_all(&[0x00, 0x00])?; // presumably upper capability flags
    self.rw.write_all(&[0x00])?; // presumably auth plugin data length
    self.rw.write_all(&[0x00; 6][..])?; // filler
    self.rw.write_all(&[0x00; 4][..])?; // filler
    self.rw.write_all(&b">o6^Wz!/kM}N\0"[..])?; // presumably rest of the auth seed
    self.rw.flush()?;

    let mut auth_context = AuthenticationContext::default();
    {
        let (seq, handshake) = self.rw.next()?.ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::ConnectionAborted,
                "peer terminated connection",
            )
        })?;
        let handshake = commands::client_handshake(&handshake, false)
            .map_err(|e| match e {
                nom::Err::Incomplete(_) => io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "client sent incomplete handshake",
                ),
                nom::Err::Failure(nom_error) | nom::Err::Error(nom_error) => {
                    if let nom::error::ErrorKind::Eof = nom_error.code {
                        io::Error::new(
                            io::ErrorKind::UnexpectedEof,
                            format!(
                                "client did not complete handshake; got {:?}",
                                nom_error.input
                            ),
                        )
                    } else {
                        io::Error::new(
                            io::ErrorKind::InvalidData,
                            format!(
                                "bad client handshake; got {:?} ({:?})",
                                nom_error.input, nom_error.code
                            ),
                        )
                    }
                }
            })?
            .1;
        auth_context.username = handshake.username.map(|x| x.to_vec());
        // The sequence number comes from the client; use wrapping arithmetic
        // so the maximum value cannot trigger an overflow panic.
        self.rw.set_seq(seq.wrapping_add(1));

        #[cfg(not(feature = "tls"))]
        if handshake.capabilities.contains(CapabilityFlags::CLIENT_SSL) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "client requested SSL despite us not advertising support for it",
            )
            .into());
        }

        #[cfg(feature = "tls")]
        if handshake.capabilities.contains(CapabilityFlags::CLIENT_SSL) {
            let config = tls_conf.ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    "client requested SSL despite us not advertising support for it",
                )
            })?;
            self.rw.switch_to_tls(config)?;

            // After the TLS upgrade the client sends its handshake again,
            // this time over the encrypted channel.
            let (seq, handshake) = self.rw.next()?.ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::ConnectionAborted,
                    "peer terminated connection",
                )
            })?;
            let handshake = commands::client_handshake(&handshake, true)
                .map_err(|e| match e {
                    nom::Err::Incomplete(_) => io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "client sent incomplete handshake",
                    ),
                    nom::Err::Failure(nom_error) | nom::Err::Error(nom_error) => {
                        if let nom::error::ErrorKind::Eof = nom_error.code {
                            io::Error::new(
                                io::ErrorKind::UnexpectedEof,
                                format!(
                                    "client did not complete handshake; got {:?}",
                                    nom_error.input
                                ),
                            )
                        } else {
                            io::Error::new(
                                io::ErrorKind::InvalidData,
                                format!(
                                    "bad client handshake; got {:?} ({:?})",
                                    nom_error.input, nom_error.code
                                ),
                            )
                        }
                    }
                })?
                .1;
            auth_context.username = handshake.username.map(|x| x.to_vec());
            self.rw.set_seq(seq.wrapping_add(1));
            auth_context.tls_client_certs = self.rw.tls_certs();
        }

        // Give the shim the final say; tell the client before bailing out.
        if let Err(e) = self.shim.after_authentication(&auth_context) {
            writers::write_err(
                ErrorKind::ER_ACCESS_DENIED_ERROR,
                "client authentication failed".as_ref(),
                &mut self.rw,
            )?;
            self.rw.flush()?;
            return Err(e);
        }
    }

    writers::write_ok_packet(&mut self.rw, 0, 0, StatusFlags::empty())?;
    self.rw.flush()?;
    Ok(())
}
/// Processes client commands until the client quits or the stream ends.
///
/// Each packet is parsed into a command and dispatched:
///
/// * queries starting with `SELECT @@` are answered locally
///   (`max_allowed_packet` with a fixed value, anything else with an empty
///   completion);
/// * queries starting with `USE ` and init commands go to
///   [`MysqlShim::on_init`];
/// * all other queries, prepares, executes and closes go to the matching
///   shim method;
/// * long-data packets are buffered per statement parameter until the next
///   execute;
/// * pings are acknowledged with an OK packet.
///
/// # Errors
///
/// Returns an error if the stream fails, a packet cannot be parsed as a
/// command, a query is not valid UTF-8, a command refers to an unknown
/// statement, or the shim returns an error.
fn run(mut self) -> Result<(), B::Error> {
    use crate::commands::Command;

    let mut stmts: HashMap<u32, _> = HashMap::new();
    while let Some((seq, packet)) = self.rw.next()? {
        // The sequence number comes from the client; use wrapping arithmetic
        // so the maximum value cannot trigger an overflow panic.
        self.rw.set_seq(seq.wrapping_add(1));

        // Client packets are untrusted: report malformed ones as an error
        // instead of panicking the serving thread.
        let cmd = commands::parse(&packet)
            .map_err(|e| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("bad client command packet: {:?}", e),
                )
            })?
            .1;

        match cmd {
            Command::Query(q) => {
                if q.starts_with(b"SELECT @@") || q.starts_with(b"select @@") {
                    // Variable lookups are answered here, never by the shim.
                    let w = QueryResultWriter::new(&mut self.rw, false);
                    let var = &q[b"SELECT @@".len()..];
                    match var {
                        b"max_allowed_packet" => {
                            let cols = &[Column {
                                table: String::new(),
                                column: "@@max_allowed_packet".to_owned(),
                                coltype: myc::constants::ColumnType::MYSQL_TYPE_LONG,
                                colflags: myc::constants::ColumnFlags::UNSIGNED_FLAG,
                            }];
                            let mut w = w.start(cols)?;
                            w.write_row(iter::once(67108864u32))?;
                            w.finish()?;
                        }
                        _ => {
                            w.completed(0, 0)?;
                        }
                    }
                } else if q.starts_with(b"USE ") || q.starts_with(b"use ") {
                    let w = InitWriter {
                        writer: &mut self.rw,
                    };
                    let schema = ::std::str::from_utf8(&q[b"USE ".len()..])
                        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
                    // Strip surrounding whitespace, a trailing `;` and backticks.
                    let schema = schema.trim().trim_end_matches(';').trim_matches('`');
                    self.shim.on_init(schema, w)?;
                } else {
                    let w = QueryResultWriter::new(&mut self.rw, false);
                    self.shim.on_query(
                        ::std::str::from_utf8(q)
                            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
                        w,
                    )?;
                }
            }
            Command::Prepare(q) => {
                let w = StatementMetaWriter {
                    writer: &mut self.rw,
                    stmts: &mut stmts,
                };
                self.shim.on_prepare(
                    ::std::str::from_utf8(q)
                        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
                    w,
                )?;
            }
            Command::Execute { stmt, params } => {
                let state = stmts.get_mut(&stmt).ok_or_else(|| {
                    io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("asked to execute unknown statement {}", stmt),
                    )
                })?;
                {
                    let params = params::ParamParser::new(params, state);
                    let w = QueryResultWriter::new(&mut self.rw, true);
                    self.shim.on_execute(stmt, params, w)?;
                }
                // Buffered long data applies to a single execution only.
                state.long_data.clear();
            }
            Command::SendLongData { stmt, param, data } => {
                stmts
                    .get_mut(&stmt)
                    .ok_or_else(|| {
                        io::Error::new(
                            io::ErrorKind::InvalidData,
                            format!("got long data packet for unknown statement {}", stmt),
                        )
                    })?
                    .long_data
                    .entry(param)
                    .or_default()
                    .extend(data);
            }
            Command::Close(stmt) => {
                self.shim.on_close(stmt);
                stmts.remove(&stmt);
            }
            Command::ListFields(_) => {
                // Not supported: reply with a single placeholder column.
                let cols = &[Column {
                    table: String::new(),
                    column: "not implemented".to_owned(),
                    coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT,
                    colflags: myc::constants::ColumnFlags::UNSIGNED_FLAG,
                }];
                writers::write_column_definitions(cols, &mut self.rw, true, true)?;
            }
            Command::Init(schema) => {
                let w = InitWriter {
                    writer: &mut self.rw,
                };
                self.shim.on_init(
                    ::std::str::from_utf8(schema)
                        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
                    w,
                )?;
            }
            Command::Ping => {
                writers::write_ok_packet(&mut self.rw, 0, 0, StatusFlags::empty())?;
            }
            Command::Quit => {
                break;
            }
        }
        self.rw.flush()?;
    }
    Ok(())
}
}