mysql_proxy/
lib.rs

1//! An extensible MySQL Proxy Server based on tokio-core
2
3#[macro_use]
4extern crate log;
5extern crate env_logger;
6#[macro_use]
7extern crate futures;
8#[macro_use]
9extern crate tokio_core;
10extern crate byteorder;
11
12use std::rc::Rc;
13use std::io::{self, Read, Write, Error, ErrorKind};
14use std::net::Shutdown;
15
16use futures::{Future, Poll, Async};
17use tokio_core::net::{TcpStream};
18use byteorder::*;
19
20/// Handlers return a variant of this enum to indicate how the proxy should handle the packet.
21#[derive(Debug,PartialEq)]
22pub enum Action {
23    /// drop the packet
24    Drop,
25    /// forward the packet unmodified
26    Forward,
27    /// forward a mutated packet
28    Mutate(Packet),
29    /// respond to the packet without forwarding
30    Respond(Vec<Packet>),
31    /// respond with an error packet
32    Error { code: u16, state: [u8; 5], msg: String },
33}
34
35/// Packet handlers need to implement this trait
36pub trait PacketHandler {
37    fn handle_request(&mut self, p: &Packet) -> Action;
38    fn handle_response(&mut self, p: &Packet) -> Action;
39}
40
/// A raw MySQL packet as it travels on the wire.
///
/// Packets produced by the proxy's reader hold the 4-byte header followed by the
/// payload. The header is a 3-byte little-endian payload length plus a 1-byte
/// sequence id.
///
/// `Clone` lets a handler derive an `Action::Mutate` replacement from the borrowed
/// packet it is given.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Packet {
    /// Header and payload bytes, in wire order.
    pub bytes: Vec<u8>
}
46
47impl Packet {
48
49    /// Create an error packet
50    pub fn error_packet(code: u16, state: [u8; 5], msg: String) -> Self {
51
52        // start building payload
53        let mut payload: Vec<u8> = Vec::with_capacity(9 + msg.len());
54        payload.push(0xff);  // packet type
55        payload.write_u16::<LittleEndian>(code).unwrap(); // error code
56        payload.extend_from_slice("#".as_bytes()); // sql_state_marker
57        payload.extend_from_slice(&state); // SQL STATE
58        payload.extend_from_slice(msg.as_bytes());
59
60        // create header with length and sequence id
61        let mut header: Vec<u8> = Vec::with_capacity(4 + 9 + msg.len());
62        header.write_u32::<LittleEndian>(payload.len() as u32).unwrap();
63        header.pop(); // we need 3 byte length, so discard last byte
64        header.push(1); // sequence_id
65
66        // combine the vectors
67        header.extend_from_slice(&payload);
68
69        // now move the vector into the packet
70        Packet { bytes: header }
71    }
72
73    pub fn sequence_id(&self) -> u8 {
74        self.bytes[3]
75    }
76
77    /// Determine the type of packet
78    pub fn packet_type(&self) -> Result<PacketType, Error> {
79        match self.bytes[4] {
80            0x00 => Ok(PacketType::ComSleep),
81            0x01 => Ok(PacketType::ComQuit),
82            0x02 => Ok(PacketType::ComInitDb),
83            0x03 => Ok(PacketType::ComQuery),
84            0x04 => Ok(PacketType::ComFieldList),
85            0x05 => Ok(PacketType::ComCreateDb),
86            0x06 => Ok(PacketType::ComDropDb),
87            0x07 => Ok(PacketType::ComRefresh),
88            0x08 => Ok(PacketType::ComShutdown),
89            0x09 => Ok(PacketType::ComStatistics),
90            0x0a => Ok(PacketType::ComProcessInfo),
91            0x0b => Ok(PacketType::ComConnect),
92            0x0c => Ok(PacketType::ComProcessKill),
93            0x0d => Ok(PacketType::ComDebug),
94            0x0e => Ok(PacketType::ComPing),
95            0x0f => Ok(PacketType::ComTime),
96            0x10 => Ok(PacketType::ComDelayedInsert),
97            0x11 => Ok(PacketType::ComChangeUser),
98            0x12 => Ok(PacketType::ComBinlogDump),
99            0x13 => Ok(PacketType::ComTableDump),
100            0x14 => Ok(PacketType::ComConnectOut),
101            0x15 => Ok(PacketType::ComRegisterSlave),
102            0x16 => Ok(PacketType::ComStmtPrepare),
103            0x17 => Ok(PacketType::ComStmtExecute),
104            0x18 => Ok(PacketType::ComStmtSendLongData),
105            0x19 => Ok(PacketType::ComStmtClose),
106            0x1a => Ok(PacketType::ComStmtReset),
107            0x1d => Ok(PacketType::ComDaemon),
108            0x1e => Ok(PacketType::ComBinlogDumpGtid),
109            0x1f => Ok(PacketType::ComResetConnection),
110            _ => Err(Error::new(ErrorKind::Other, "Invalid packet type"))
111        }
112    }
113
114}
115
/// Client command codes: the byte `Packet::packet_type` reads at offset 4 of a
/// packet (the first payload byte).
///
/// Each discriminant is the on-the-wire value, so `PacketType::ComQuery as u8 == 0x03`.
/// `PartialEq`/`Debug`/`Hash` let callers compare, log and key on the decoded type.
///
/// NOTE(review): 0x1b and 0x1c have no variant (presumably COM_SET_OPTION and
/// COM_STMT_FETCH); `packet_type` reports those bytes as invalid.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum PacketType {
    ComSleep = 0x00,
    ComQuit = 0x01,
    ComInitDb = 0x02,
    ComQuery = 0x03,
    ComFieldList = 0x04,
    ComCreateDb = 0x05,
    ComDropDb = 0x06,
    ComRefresh = 0x07,
    ComShutdown = 0x08,
    ComStatistics = 0x09,
    ComProcessInfo = 0x0a,
    ComConnect = 0x0b,
    ComProcessKill = 0x0c,
    ComDebug = 0x0d,
    ComPing = 0x0e,
    ComTime = 0x0f,
    ComDelayedInsert = 0x10,
    ComChangeUser = 0x11,
    ComBinlogDump = 0x12,
    ComTableDump = 0x13,
    ComConnectOut = 0x14,
    ComRegisterSlave = 0x15,
    ComStmtPrepare = 0x16,
    ComStmtExecute = 0x17,
    ComStmtSendLongData = 0x18,
    ComStmtClose = 0x19,
    ComStmtReset = 0x1a,
    ComDaemon = 0x1d,
    ComBinlogDumpGtid = 0x1e,
    ComResetConnection = 0x1f,
}
149
150
151/// Wrapper for TcpStream with some built-in buffering
152struct ConnReader {
153    stream: Rc<TcpStream>,
154    packet_buf: Vec<u8>,
155    read_buf: Vec<u8>,
156}
157
158/// Wrapper for TcpStream with some built-in buffering
159struct ConnWriter {
160    stream: Rc<TcpStream>,
161    write_buf: Vec<u8>,
162}
163
164impl ConnReader {
165
166    fn new(stream: Rc<TcpStream>) -> Self {
167        ConnReader {
168            stream: stream,
169            packet_buf: Vec::with_capacity(4096),
170            read_buf: vec![0_u8; 4096]
171        }
172    }
173
174    /// Read from the socket until the status is NotReady
175    fn read(&mut self) -> Poll<(), io::Error> {
176        debug!("read()");
177        loop {
178            match self.stream.poll_read() {
179                Async::Ready(_) => {
180                    let n = try_nb!((&*self.stream).read(&mut self.read_buf[..]));
181                    if n == 0 {
182                        return Err(Error::new(ErrorKind::Other, "connection closed"));
183                    }
184                    self.packet_buf.extend_from_slice(&self.read_buf[0..n]);
185                },
186                _ => return Ok(Async::NotReady),
187            }
188        }
189    }
190
191    fn next(&mut self) -> Option<Packet> {
192        debug!("next()");
193        // do we have a header
194        if self.packet_buf.len() > 3 {
195            let l = parse_packet_length(&self.packet_buf);
196            // do we have the whole packet?
197            let s = 4 + l;
198            if self.packet_buf.len() >= s {
199                let p = Packet { bytes: self.packet_buf.drain(0..s).collect() };
200                Some(p)
201            } else {
202                None
203            }
204        } else {
205            None
206        }
207    }
208}
209
210impl ConnWriter {
211
212    fn new(stream: Rc<TcpStream>) -> Self {
213        ConnWriter{
214            stream: stream,
215            write_buf: Vec::with_capacity(4096),
216        }
217    }
218
219    /// Write a packet to the write buffer
220    fn push(&mut self, p: &Packet) {
221        //        debug!("push() capacity: {} position: {} packet_size: {}",
222        //               self.write_buf.capacity(), self.write_pos, p.bytes.len());
223
224        self.write_buf.extend_from_slice(&p.bytes);
225        debug!("end push()");
226    }
227
228    /// Writes the contents of the write buffer to the socket
229    fn write(&mut self) -> Poll<(), io::Error> {
230        debug!("write()");
231        while self.write_buf.len() > 0 {
232            match self.stream.poll_write() {
233                Async::Ready(_) => {
234                    let s = try!((&*self.stream).write(&self.write_buf[..]));
235                    let _ : Vec<u8> = self.write_buf.drain(0..s).collect();
236                },
237                _ => return Ok(Async::NotReady)
238            }
239        }
240        return Ok(Async::Ready(()));
241    }
242}
243
244pub struct Pipe<H: PacketHandler + 'static> {
245    client_reader: ConnReader,
246    client_writer: ConnWriter,
247    server_reader: ConnReader,
248    server_writer: ConnWriter,
249    handler: H,
250}
251
252impl<H> Pipe<H> where H: PacketHandler + 'static {
253    pub fn new(client: Rc<TcpStream>,
254               server: Rc<TcpStream>,
255               handler: H
256    ) -> Pipe<H> {
257
258        Pipe {
259            client_reader: ConnReader::new(client.clone()),
260            client_writer: ConnWriter::new(client),
261            server_reader: ConnReader::new(server.clone()),
262            server_writer: ConnWriter::new(server),
263            handler: handler,
264        }
265    }
266}
267
268impl<H> Future for Pipe<H> where H: PacketHandler + 'static {
269    type Item = ();
270    type Error = Error;
271
272    fn poll(&mut self) -> Poll<(), Error> {
273        loop {
274            let client_read = self.client_reader.read();
275
276            // process buffered requests
277            while let Some(request) = self.client_reader.next() {
278                match self.handler.handle_request(&request) {
279                    Action::Drop => {},
280                    Action::Forward => self.server_writer.push(&request),
281                    Action::Mutate(ref p2) => self.server_writer.push(p2),
282                    Action::Respond(ref v) => {
283                        for p in v {
284                            self.client_writer.push(&p);
285                        }
286                    },
287                    Action::Error { code, state, msg } => {
288                        let error_packet = Packet::error_packet(code, state, msg);
289                        self.client_writer.push(&error_packet);
290                    }
291                };
292            }
293
294            // try reading from server
295            let server_read = self.server_reader.read();
296
297            // process buffered responses
298            while let Some(response) = self.server_reader.next() {
299                match self.handler.handle_response(&response) {
300                    Action::Drop => {},
301                    Action::Forward => self.client_writer.push(&response),
302                    Action::Mutate(ref p2) => self.client_writer.push(p2),
303                    Action::Respond(ref v) => {
304                        for p in v {
305                            self.server_writer.push(&p);
306                        }
307                    },
308                    Action::Error { code, state, msg } => {
309                        let error_packet = Packet::error_packet(code, state, msg);
310                        self.client_writer.push(&error_packet);
311                    }
312                };
313            }
314
315            // perform all of the writes at the end, since the request handlers may have
316            // queued packets in either, or both directions
317
318            // try writing to client
319            let client_write = self.client_writer.write();
320
321            // if the server connection has closed, close the client connection too
322            match &server_read {
323                &Err(ref e) => {
324                    debug!("Server closed connection: {}", e);
325                    match self.client_writer.stream.shutdown(Shutdown::Write) {
326                        Ok(_) => {},
327                        Err(_) => {}
328                    }
329                },
330                _ => {}
331            }
332
333            // try writing to server
334            let server_write = self.server_writer.write();
335
336            // if the client connection has closed, close the server connection too
337            match &client_read {
338                &Err(ref e) => {
339                    debug!("Client closed connection: {}", e);
340                    match self.server_writer.stream.shutdown(Shutdown::Write) {
341                        Ok(_) => {},
342                        Err(_) => {}
343                    }
344                },
345                _ => {}
346            }
347
348            try_ready!(client_read);
349            try_ready!(client_write);
350            try_ready!(server_read);
351            try_ready!(server_write);
352        }
353
354    }
355
356}
357
/// Decode the 3-byte little-endian payload length at the start of a MySQL packet
/// header.
///
/// # Panics
/// Panics if `header` holds fewer than three bytes.
fn parse_packet_length(header: &[u8]) -> usize {
    // walk from the most significant byte down, shifting the accumulator up each step
    header[..3]
        .iter()
        .rev()
        .fold(0usize, |acc, &byte| (acc << 8) | byte as usize)
}