1#[macro_use]
4extern crate log;
5extern crate env_logger;
6#[macro_use]
7extern crate futures;
8#[macro_use]
9extern crate tokio_core;
10extern crate byteorder;
11
12use std::rc::Rc;
13use std::io::{self, Read, Write, Error, ErrorKind};
14use std::net::Shutdown;
15
16use futures::{Future, Poll, Async};
17use tokio_core::net::{TcpStream};
18use byteorder::*;
19
20#[derive(Debug,PartialEq)]
22pub enum Action {
23 Drop,
25 Forward,
27 Mutate(Packet),
29 Respond(Vec<Packet>),
31 Error { code: u16, state: [u8; 5], msg: String },
33}
34
35pub trait PacketHandler {
37 fn handle_request(&mut self, p: &Packet) -> Action;
38 fn handle_response(&mut self, p: &Packet) -> Action;
39}
40
/// A single wire packet, kept verbatim: a 4-byte header (3-byte
/// little-endian payload length, then the sequence id) followed by the
/// payload.
#[derive(Debug, PartialEq)]
pub struct Packet {
    /// Raw packet bytes, header included.
    pub bytes: Vec<u8>,
}
46
47impl Packet {
48
49 pub fn error_packet(code: u16, state: [u8; 5], msg: String) -> Self {
51
52 let mut payload: Vec<u8> = Vec::with_capacity(9 + msg.len());
54 payload.push(0xff); payload.write_u16::<LittleEndian>(code).unwrap(); payload.extend_from_slice("#".as_bytes()); payload.extend_from_slice(&state); payload.extend_from_slice(msg.as_bytes());
59
60 let mut header: Vec<u8> = Vec::with_capacity(4 + 9 + msg.len());
62 header.write_u32::<LittleEndian>(payload.len() as u32).unwrap();
63 header.pop(); header.push(1); header.extend_from_slice(&payload);
68
69 Packet { bytes: header }
71 }
72
73 pub fn sequence_id(&self) -> u8 {
74 self.bytes[3]
75 }
76
77 pub fn packet_type(&self) -> Result<PacketType, Error> {
79 match self.bytes[4] {
80 0x00 => Ok(PacketType::ComSleep),
81 0x01 => Ok(PacketType::ComQuit),
82 0x02 => Ok(PacketType::ComInitDb),
83 0x03 => Ok(PacketType::ComQuery),
84 0x04 => Ok(PacketType::ComFieldList),
85 0x05 => Ok(PacketType::ComCreateDb),
86 0x06 => Ok(PacketType::ComDropDb),
87 0x07 => Ok(PacketType::ComRefresh),
88 0x08 => Ok(PacketType::ComShutdown),
89 0x09 => Ok(PacketType::ComStatistics),
90 0x0a => Ok(PacketType::ComProcessInfo),
91 0x0b => Ok(PacketType::ComConnect),
92 0x0c => Ok(PacketType::ComProcessKill),
93 0x0d => Ok(PacketType::ComDebug),
94 0x0e => Ok(PacketType::ComPing),
95 0x0f => Ok(PacketType::ComTime),
96 0x10 => Ok(PacketType::ComDelayedInsert),
97 0x11 => Ok(PacketType::ComChangeUser),
98 0x12 => Ok(PacketType::ComBinlogDump),
99 0x13 => Ok(PacketType::ComTableDump),
100 0x14 => Ok(PacketType::ComConnectOut),
101 0x15 => Ok(PacketType::ComRegisterSlave),
102 0x16 => Ok(PacketType::ComStmtPrepare),
103 0x17 => Ok(PacketType::ComStmtExecute),
104 0x18 => Ok(PacketType::ComStmtSendLongData),
105 0x19 => Ok(PacketType::ComStmtClose),
106 0x1a => Ok(PacketType::ComStmtReset),
107 0x1d => Ok(PacketType::ComDaemon),
108 0x1e => Ok(PacketType::ComBinlogDumpGtid),
109 0x1f => Ok(PacketType::ComResetConnection),
110 _ => Err(Error::new(ErrorKind::Other, "Invalid packet type"))
111 }
112 }
113
114}
115
/// Command byte found at offset 4 of a client packet; each discriminant is
/// the byte value that `Packet::packet_type` maps to the variant.
///
/// `Debug`/`PartialEq`/`Eq`/`Hash` are derived so handlers can log and
/// compare commands directly instead of casting to `u8`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum PacketType {
    ComSleep = 0x00,
    ComQuit = 0x01,
    ComInitDb = 0x02,
    ComQuery = 0x03,
    ComFieldList = 0x04,
    ComCreateDb = 0x05,
    ComDropDb = 0x06,
    ComRefresh = 0x07,
    ComShutdown = 0x08,
    ComStatistics = 0x09,
    ComProcessInfo = 0x0a,
    ComConnect = 0x0b,
    ComProcessKill = 0x0c,
    ComDebug = 0x0d,
    ComPing = 0x0e,
    ComTime = 0x0f,
    ComDelayedInsert = 0x10,
    ComChangeUser = 0x11,
    ComBinlogDump = 0x12,
    ComTableDump = 0x13,
    ComConnectOut = 0x14,
    ComRegisterSlave = 0x15,
    ComStmtPrepare = 0x16,
    ComStmtExecute = 0x17,
    ComStmtSendLongData = 0x18,
    ComStmtClose = 0x19,
    ComStmtReset = 0x1a,
    ComDaemon = 0x1d,
    ComBinlogDumpGtid = 0x1e,
    ComResetConnection = 0x1f,
}
149
150
151struct ConnReader {
153 stream: Rc<TcpStream>,
154 packet_buf: Vec<u8>,
155 read_buf: Vec<u8>,
156}
157
158struct ConnWriter {
160 stream: Rc<TcpStream>,
161 write_buf: Vec<u8>,
162}
163
164impl ConnReader {
165
166 fn new(stream: Rc<TcpStream>) -> Self {
167 ConnReader {
168 stream: stream,
169 packet_buf: Vec::with_capacity(4096),
170 read_buf: vec![0_u8; 4096]
171 }
172 }
173
174 fn read(&mut self) -> Poll<(), io::Error> {
176 debug!("read()");
177 loop {
178 match self.stream.poll_read() {
179 Async::Ready(_) => {
180 let n = try_nb!((&*self.stream).read(&mut self.read_buf[..]));
181 if n == 0 {
182 return Err(Error::new(ErrorKind::Other, "connection closed"));
183 }
184 self.packet_buf.extend_from_slice(&self.read_buf[0..n]);
185 },
186 _ => return Ok(Async::NotReady),
187 }
188 }
189 }
190
191 fn next(&mut self) -> Option<Packet> {
192 debug!("next()");
193 if self.packet_buf.len() > 3 {
195 let l = parse_packet_length(&self.packet_buf);
196 let s = 4 + l;
198 if self.packet_buf.len() >= s {
199 let p = Packet { bytes: self.packet_buf.drain(0..s).collect() };
200 Some(p)
201 } else {
202 None
203 }
204 } else {
205 None
206 }
207 }
208}
209
210impl ConnWriter {
211
212 fn new(stream: Rc<TcpStream>) -> Self {
213 ConnWriter{
214 stream: stream,
215 write_buf: Vec::with_capacity(4096),
216 }
217 }
218
219 fn push(&mut self, p: &Packet) {
221 self.write_buf.extend_from_slice(&p.bytes);
225 debug!("end push()");
226 }
227
228 fn write(&mut self) -> Poll<(), io::Error> {
230 debug!("write()");
231 while self.write_buf.len() > 0 {
232 match self.stream.poll_write() {
233 Async::Ready(_) => {
234 let s = try!((&*self.stream).write(&self.write_buf[..]));
235 let _ : Vec<u8> = self.write_buf.drain(0..s).collect();
236 },
237 _ => return Ok(Async::NotReady)
238 }
239 }
240 return Ok(Async::Ready(()));
241 }
242}
243
244pub struct Pipe<H: PacketHandler + 'static> {
245 client_reader: ConnReader,
246 client_writer: ConnWriter,
247 server_reader: ConnReader,
248 server_writer: ConnWriter,
249 handler: H,
250}
251
252impl<H> Pipe<H> where H: PacketHandler + 'static {
253 pub fn new(client: Rc<TcpStream>,
254 server: Rc<TcpStream>,
255 handler: H
256 ) -> Pipe<H> {
257
258 Pipe {
259 client_reader: ConnReader::new(client.clone()),
260 client_writer: ConnWriter::new(client),
261 server_reader: ConnReader::new(server.clone()),
262 server_writer: ConnWriter::new(server),
263 handler: handler,
264 }
265 }
266}
267
268impl<H> Future for Pipe<H> where H: PacketHandler + 'static {
269 type Item = ();
270 type Error = Error;
271
272 fn poll(&mut self) -> Poll<(), Error> {
273 loop {
274 let client_read = self.client_reader.read();
275
276 while let Some(request) = self.client_reader.next() {
278 match self.handler.handle_request(&request) {
279 Action::Drop => {},
280 Action::Forward => self.server_writer.push(&request),
281 Action::Mutate(ref p2) => self.server_writer.push(p2),
282 Action::Respond(ref v) => {
283 for p in v {
284 self.client_writer.push(&p);
285 }
286 },
287 Action::Error { code, state, msg } => {
288 let error_packet = Packet::error_packet(code, state, msg);
289 self.client_writer.push(&error_packet);
290 }
291 };
292 }
293
294 let server_read = self.server_reader.read();
296
297 while let Some(response) = self.server_reader.next() {
299 match self.handler.handle_response(&response) {
300 Action::Drop => {},
301 Action::Forward => self.client_writer.push(&response),
302 Action::Mutate(ref p2) => self.client_writer.push(p2),
303 Action::Respond(ref v) => {
304 for p in v {
305 self.server_writer.push(&p);
306 }
307 },
308 Action::Error { code, state, msg } => {
309 let error_packet = Packet::error_packet(code, state, msg);
310 self.client_writer.push(&error_packet);
311 }
312 };
313 }
314
315 let client_write = self.client_writer.write();
320
321 match &server_read {
323 &Err(ref e) => {
324 debug!("Server closed connection: {}", e);
325 match self.client_writer.stream.shutdown(Shutdown::Write) {
326 Ok(_) => {},
327 Err(_) => {}
328 }
329 },
330 _ => {}
331 }
332
333 let server_write = self.server_writer.write();
335
336 match &client_read {
338 &Err(ref e) => {
339 debug!("Client closed connection: {}", e);
340 match self.server_writer.stream.shutdown(Shutdown::Write) {
341 Ok(_) => {},
342 Err(_) => {}
343 }
344 },
345 _ => {}
346 }
347
348 try_ready!(client_read);
349 try_ready!(client_write);
350 try_ready!(server_read);
351 try_ready!(server_write);
352 }
353
354 }
355
356}
357
/// Decodes the 3-byte little-endian payload length at the start of a packet
/// header. Bytes beyond the third are ignored.
///
/// # Panics
/// Panics if `header` is shorter than 3 bytes.
fn parse_packet_length(header: &[u8]) -> usize {
    // Walk from the most significant byte down, shifting the accumulator left.
    header[..3]
        .iter()
        .rev()
        .fold(0_usize, |acc, &byte| (acc << 8) | byte as usize)
}