#[macro_use]
extern crate log;
extern crate env_logger;
#[macro_use]
extern crate futures;
#[macro_use]
extern crate tokio_core;
extern crate byteorder;
use std::rc::Rc;
use std::io::{self, Read, Write, Error, ErrorKind};
use std::net::Shutdown;
use futures::{Future, Poll, Async};
use tokio_core::net::{TcpStream};
use byteorder::*;
#[derive(Debug,PartialEq)]
pub enum Action {
Drop,
Forward,
Mutate(Packet),
Respond(Vec<Packet>),
Error { code: u16, state: [u8; 5], msg: String },
}
pub trait PacketHandler {
fn handle_request(&mut self, p: &Packet) -> Action;
fn handle_response(&mut self, p: &Packet) -> Action;
}
#[derive(Debug,PartialEq)]
pub struct Packet {
pub bytes: Vec<u8>
}
impl Packet {
pub fn error_packet(code: u16, state: [u8; 5], msg: String) -> Self {
let mut payload: Vec<u8> = Vec::with_capacity(9 + msg.len());
payload.push(0xff); payload.write_u16::<LittleEndian>(code).unwrap(); payload.extend_from_slice("#".as_bytes()); payload.extend_from_slice(&state); payload.extend_from_slice(msg.as_bytes());
let mut header: Vec<u8> = Vec::with_capacity(4 + 9 + msg.len());
header.write_u32::<LittleEndian>(payload.len() as u32).unwrap();
header.pop(); header.push(1);
header.extend_from_slice(&payload);
Packet { bytes: header }
}
pub fn sequence_id(&self) -> u8 {
self.bytes[3]
}
pub fn packet_type(&self) -> Result<PacketType, Error> {
match self.bytes[4] {
0x00 => Ok(PacketType::ComSleep),
0x01 => Ok(PacketType::ComQuit),
0x02 => Ok(PacketType::ComInitDb),
0x03 => Ok(PacketType::ComQuery),
0x04 => Ok(PacketType::ComFieldList),
0x05 => Ok(PacketType::ComCreateDb),
0x06 => Ok(PacketType::ComDropDb),
0x07 => Ok(PacketType::ComRefresh),
0x08 => Ok(PacketType::ComShutdown),
0x09 => Ok(PacketType::ComStatistics),
0x0a => Ok(PacketType::ComProcessInfo),
0x0b => Ok(PacketType::ComConnect),
0x0c => Ok(PacketType::ComProcessKill),
0x0d => Ok(PacketType::ComDebug),
0x0e => Ok(PacketType::ComPing),
0x0f => Ok(PacketType::ComTime),
0x10 => Ok(PacketType::ComDelayedInsert),
0x11 => Ok(PacketType::ComChangeUser),
0x12 => Ok(PacketType::ComBinlogDump),
0x13 => Ok(PacketType::ComTableDump),
0x14 => Ok(PacketType::ComConnectOut),
0x15 => Ok(PacketType::ComRegisterSlave),
0x16 => Ok(PacketType::ComStmtPrepare),
0x17 => Ok(PacketType::ComStmtExecute),
0x18 => Ok(PacketType::ComStmtSendLongData),
0x19 => Ok(PacketType::ComStmtClose),
0x1a => Ok(PacketType::ComStmtReset),
0x1d => Ok(PacketType::ComDaemon),
0x1e => Ok(PacketType::ComBinlogDumpGtid),
0x1f => Ok(PacketType::ComResetConnection),
_ => Err(Error::new(ErrorKind::Other, "Invalid packet type"))
}
}
}
#[derive(Copy,Clone)]
pub enum PacketType {
ComSleep = 0x00,
ComQuit = 0x01,
ComInitDb = 0x02,
ComQuery = 0x03,
ComFieldList = 0x04,
ComCreateDb = 0x05,
ComDropDb = 0x06,
ComRefresh = 0x07,
ComShutdown = 0x08,
ComStatistics = 0x09,
ComProcessInfo = 0x0a,
ComConnect = 0x0b,
ComProcessKill= 0x0c,
ComDebug = 0x0d,
ComPing = 0x0e,
ComTime = 0x0f,
ComDelayedInsert = 0x10,
ComChangeUser = 0x11,
ComBinlogDump = 0x12,
ComTableDump = 0x13,
ComConnectOut = 0x14,
ComRegisterSlave = 0x15,
ComStmtPrepare = 0x16,
ComStmtExecute = 0x17,
ComStmtSendLongData = 0x18,
ComStmtClose = 0x19,
ComStmtReset = 0x1a,
ComDaemon= 0x1d,
ComBinlogDumpGtid = 0x1e,
ComResetConnection = 0x1f,
}
struct ConnReader {
stream: Rc<TcpStream>,
packet_buf: Vec<u8>,
read_buf: Vec<u8>,
}
struct ConnWriter {
stream: Rc<TcpStream>,
write_buf: Vec<u8>,
}
impl ConnReader {
fn new(stream: Rc<TcpStream>) -> Self {
ConnReader {
stream: stream,
packet_buf: Vec::with_capacity(4096),
read_buf: vec![0_u8; 4096]
}
}
fn read(&mut self) -> Poll<(), io::Error> {
debug!("read()");
loop {
match self.stream.poll_read() {
Async::Ready(_) => {
let n = try_nb!((&*self.stream).read(&mut self.read_buf[..]));
if n == 0 {
return Err(Error::new(ErrorKind::Other, "connection closed"));
}
self.packet_buf.extend_from_slice(&self.read_buf[0..n]);
},
_ => return Ok(Async::NotReady),
}
}
}
fn next(&mut self) -> Option<Packet> {
debug!("next()");
if self.packet_buf.len() > 3 {
let l = parse_packet_length(&self.packet_buf);
let s = 4 + l;
if self.packet_buf.len() >= s {
let p = Packet { bytes: self.packet_buf.drain(0..s).collect() };
Some(p)
} else {
None
}
} else {
None
}
}
}
impl ConnWriter {
fn new(stream: Rc<TcpStream>) -> Self {
ConnWriter{
stream: stream,
write_buf: Vec::with_capacity(4096),
}
}
fn push(&mut self, p: &Packet) {
self.write_buf.extend_from_slice(&p.bytes);
debug!("end push()");
}
fn write(&mut self) -> Poll<(), io::Error> {
debug!("write()");
while self.write_buf.len() > 0 {
match self.stream.poll_write() {
Async::Ready(_) => {
let s = try!((&*self.stream).write(&self.write_buf[..]));
let _ : Vec<u8> = self.write_buf.drain(0..s).collect();
},
_ => return Ok(Async::NotReady)
}
}
return Ok(Async::Ready(()));
}
}
pub struct Pipe<H: PacketHandler + 'static> {
client_reader: ConnReader,
client_writer: ConnWriter,
server_reader: ConnReader,
server_writer: ConnWriter,
handler: H,
}
impl<H> Pipe<H> where H: PacketHandler + 'static {
pub fn new(client: Rc<TcpStream>,
server: Rc<TcpStream>,
handler: H
) -> Pipe<H> {
Pipe {
client_reader: ConnReader::new(client.clone()),
client_writer: ConnWriter::new(client),
server_reader: ConnReader::new(server.clone()),
server_writer: ConnWriter::new(server),
handler: handler,
}
}
}
impl<H> Future for Pipe<H> where H: PacketHandler + 'static {
type Item = ();
type Error = Error;
fn poll(&mut self) -> Poll<(), Error> {
loop {
let client_read = self.client_reader.read();
match &client_read {
&Err(ref e) => {
debug!("Client closed connection: {}", e);
match self.server_writer.stream.shutdown(Shutdown::Write) {
Ok(_) => {},
Err(_) => {}
}
},
_ => {}
}
while let Some(request) = self.client_reader.next() {
match self.handler.handle_request(&request) {
Action::Drop => {},
Action::Forward => self.server_writer.push(&request),
Action::Mutate(ref p2) => self.server_writer.push(p2),
Action::Respond(ref v) => {
for p in v {
self.client_writer.push(&p);
}
},
Action::Error { code, state, msg } => {
let error_packet = Packet::error_packet(code, state, msg);
self.client_writer.push(&error_packet);
}
};
}
let server_read = self.server_reader.read();
match &server_read {
&Err(ref e) => {
debug!("Server closed connection: {}", e);
match self.client_writer.stream.shutdown(Shutdown::Write) {
Ok(_) => {},
Err(_) => {}
}
},
_ => {}
}
while let Some(response) = self.server_reader.next() {
match self.handler.handle_response(&response) {
Action::Drop => {},
Action::Forward => self.client_writer.push(&response),
Action::Mutate(ref p2) => self.client_writer.push(p2),
Action::Respond(ref v) => {
for p in v {
self.server_writer.push(&p);
}
},
Action::Error { code, state, msg } => {
let error_packet = Packet::error_packet(code, state, msg);
self.client_writer.push(&error_packet);
}
};
}
let client_write = self.client_writer.write();
let server_write = self.server_writer.write();
try_ready!(client_read);
try_ready!(client_write);
try_ready!(server_read);
try_ready!(server_write);
}
}
}
fn parse_packet_length(header: &[u8]) -> usize {
(((header[2] as u32) << 16) |
((header[1] as u32) << 8) |
header[0] as u32) as usize
}