msql_srv/
lib.rs

1//! Bindings for emulating a MySQL/MariaDB server.
2//!
3//! When developing new databases or caching layers, it can be immensely useful to test your system
4//! using existing applications. However, this often requires significant work modifying
5//! applications to use your database over the existing ones. This crate solves that problem by
6//! acting as a MySQL server, and delegating operations such as querying and query execution to
7//! user-defined logic.
8//!
9//! To start, implement `MysqlShim` for your backend, and create a `MysqlIntermediary` over an
10//! instance of your backend and a connection stream. The appropriate methods will be called on
11//! your backend whenever a client issues a `QUERY`, `PREPARE`, or `EXECUTE` command, and you will
12//! have a chance to respond appropriately. For example, to write a shim that always responds to
13//! all commands with a "no results" reply:
14//!
15//! ```
16//! # extern crate msql_srv;
17//! extern crate mysql;
18//! # use std::io;
19//! # use std::net;
20//! # use std::thread;
21//! use msql_srv::*;
22//! use mysql::prelude::*;
23//! use mysql::Opts;
24//!
25//! struct Backend;
26//! impl<W: io::Read + io::Write> MysqlShim<W> for Backend {
27//!     type Error = io::Error;
28//!
29//!     fn on_prepare(&mut self, _: &str, info: StatementMetaWriter<W>) -> io::Result<()> {
30//!         info.reply(42, &[], &[])
31//!     }
32//!     fn on_execute(
33//!         &mut self,
34//!         _: u32,
35//!         _: ParamParser,
36//!         results: QueryResultWriter<W>,
37//!     ) -> io::Result<()> {
38//!         results.completed(0, 0)
39//!     }
40//!     fn on_close(&mut self, _: u32) {}
41//!
42//!     fn on_init(&mut self, _: &str, writer: InitWriter<W>) -> io::Result<()> { Ok(()) }
43//!
44//!     fn on_query(&mut self, _: &str, results: QueryResultWriter<W>) -> io::Result<()> {
45//!         let cols = [
46//!             Column {
47//!                 table: "foo".to_string(),
48//!                 column: "a".to_string(),
49//!                 coltype: ColumnType::MYSQL_TYPE_LONGLONG,
50//!                 colflags: ColumnFlags::empty(),
51//!             },
52//!             Column {
53//!                 table: "foo".to_string(),
54//!                 column: "b".to_string(),
55//!                 coltype: ColumnType::MYSQL_TYPE_STRING,
56//!                 colflags: ColumnFlags::empty(),
57//!             },
58//!         ];
59//!
60//!         let mut rw = results.start(&cols)?;
61//!         rw.write_col(42)?;
62//!         rw.write_col("b's value")?;
63//!         rw.finish()
64//!     }
65//! }
66//!
67//! fn main() {
68//!     let listener = net::TcpListener::bind("127.0.0.1:0").unwrap();
69//!     let port = listener.local_addr().unwrap().port();
70//!
71//!     let jh = thread::spawn(move || {
72//!         if let Ok((s, _)) = listener.accept() {
73//!             MysqlIntermediary::run_on_tcp(Backend, s).unwrap();
74//!         }
75//!     });
76//!
77//!     let mut db = mysql::Conn::new(Opts::from_url(&format!("mysql://127.0.0.1:{}", port)).unwrap()).unwrap();
78//!     assert_eq!(db.ping(), true);
79//!     assert_eq!(db.query_iter("SELECT a, b FROM foo").unwrap().count(), 1);
80//!     drop(db);
81//!     jh.join().unwrap();
82//! }
83//! ```
84#![deny(missing_docs)]
85#![deny(rust_2018_idioms)]
86
87// Note to developers: you can find decent overviews of the protocol at
88//
89//   https://github.com/cwarden/mysql-proxy/blob/master/doc/protocol.rst
90//
91// and
92//
93//   https://mariadb.com/kb/en/library/clientserver-protocol/
94//
95// Wireshark also does a pretty good job at parsing the MySQL protocol.
96
97extern crate mysql_common as myc;
98
99use std::collections::HashMap;
100use std::io;
101use std::io::prelude::*;
102use std::iter;
103use std::net;
104
105use myc::constants::CapabilityFlags;
106
107pub use crate::myc::constants::{ColumnFlags, ColumnType, StatusFlags};
108
109mod commands;
110mod errorcodes;
111mod packet;
112mod params;
113mod resultset;
114#[cfg(feature = "tls")]
115mod tls;
116mod value;
117mod writers;
118
119/// Meta-information abot a single column, used either to describe a prepared statement parameter
120/// or an output column.
121#[derive(Debug, Clone, PartialEq, Eq)]
122pub struct Column {
123    /// This column's associated table.
124    ///
125    /// Note that this is *technically* the table's alias.
126    pub table: String,
127    /// This column's name.
128    ///
129    /// Note that this is *technically* the column's alias.
130    pub column: String,
131    /// This column's type>
132    pub coltype: ColumnType,
133    /// Any flags associated with this column.
134    ///
135    /// Of particular interest are `ColumnFlags::UNSIGNED_FLAG` and `ColumnFlags::NOT_NULL_FLAG`.
136    pub colflags: ColumnFlags,
137}
138
139pub use crate::errorcodes::ErrorKind;
140pub use crate::params::{ParamParser, ParamValue, Params};
141pub use crate::resultset::{InitWriter, QueryResultWriter, RowWriter, StatementMetaWriter};
142pub use crate::value::{ToMysqlValue, Value, ValueInner};
143
144/// Implementors of this trait can be used to drive a MySQL-compatible database backend.
145pub trait MysqlShim<W: Read + Write> {
146    /// The error type produced by operations on this shim.
147    ///
148    /// Must implement `From<io::Error>` so that transport-level errors can be lifted.
149    type Error: From<io::Error>;
150
151    /// Called when the client issues a request to prepare `query` for later execution.
152    ///
153    /// The provided [`StatementMetaWriter`](struct.StatementMetaWriter.html) should be used to
154    /// notify the client of the statement id assigned to the prepared statement, as well as to
155    /// give metadata about the types of parameters and returned columns.
156    fn on_prepare(
157        &mut self,
158        query: &str,
159        info: StatementMetaWriter<'_, W>,
160    ) -> Result<(), Self::Error>;
161
162    /// Called when the client executes a previously prepared statement.
163    ///
164    /// Any parameters included with the client's command is given in `params`.
165    /// A response to the query should be given using the provided
166    /// [`QueryResultWriter`](struct.QueryResultWriter.html).
167    fn on_execute(
168        &mut self,
169        id: u32,
170        params: ParamParser<'_>,
171        results: QueryResultWriter<'_, W>,
172    ) -> Result<(), Self::Error>;
173
174    /// Called when the client wishes to deallocate resources associated with a previously prepared
175    /// statement.
176    fn on_close(&mut self, stmt: u32);
177
178    /// Called when the client issues a query for immediate execution.
179    ///
180    /// Results should be returned using the given
181    /// [`QueryResultWriter`](struct.QueryResultWriter.html).
182    fn on_query(
183        &mut self,
184        query: &str,
185        results: QueryResultWriter<'_, W>,
186    ) -> Result<(), Self::Error>;
187
188    /// Called when client switches database.
189    fn on_init(&mut self, _: &str, _: InitWriter<'_, W>) -> Result<(), Self::Error> {
190        Ok(())
191    }
192
193    /// Provides the TLS configuration, if we want to support TLS.
194    #[cfg(feature = "tls")]
195    fn tls_config(&self) -> Option<std::sync::Arc<rustls::ServerConfig>> {
196        None
197    }
198
199    /// Called after successful authentication (including TLS if applicable) passing relevant
200    /// information to allow additional logic in the MySqlShim implementation.
201    fn after_authentication(
202        &mut self,
203        _context: &AuthenticationContext<'_>,
204    ) -> Result<(), Self::Error> {
205        Ok(())
206    }
207}
208
/// Information about an authenticated user.
///
/// An instance is passed to
/// [`MysqlShim::after_authentication`](trait.MysqlShim.html#method.after_authentication)
/// once the client's handshake response has been processed.
// Clippy suggests also deriving `Eq`; that lint is silenced here.
// NOTE(review): the reason is not recorded — presumably to keep `Eq` out of the public API so
// that fields which are not `Eq` can be added later without a breaking change; confirm.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Debug, Default, Clone, PartialEq)]
pub struct AuthenticationContext<'a> {
    /// The username exactly as passed by the client, if the client sent one.
    pub username: Option<Vec<u8>>,
    /// The TLS certificate chain presented by the client.
    ///
    /// Only set when the connection was upgraded to TLS.
    #[cfg(feature = "tls")]
    pub tls_client_certs: Option<&'a [rustls::pki_types::CertificateDer<'a>]>,
    // Without the `tls` feature, no other field mentions `'a`. This placeholder keeps the
    // lifetime parameter in use, so the type's signature is the same in both configurations.
    #[cfg(not(feature = "tls"))]
    _pd: Option<&'a std::marker::PhantomData<()>>,
}
221
222/// A server that speaks the MySQL/MariaDB protocol, and can delegate client commands to a backend
223/// that implements [`MysqlShim`](trait.MysqlShim.html).
224pub struct MysqlIntermediary<B, RW: Read + Write> {
225    shim: B,
226    rw: packet::PacketConn<RW>,
227}
228
229impl<B: MysqlShim<net::TcpStream>> MysqlIntermediary<B, net::TcpStream> {
230    /// Create a new server over a TCP stream and process client commands until the client
231    /// disconnects or an error occurs. See also
232    /// [`MysqlIntermediary::run_on`](struct.MysqlIntermediary.html#method.run_on).
233    pub fn run_on_tcp(shim: B, stream: net::TcpStream) -> Result<(), B::Error> {
234        MysqlIntermediary::run_on(shim, stream)
235    }
236}
237
238impl<B: MysqlShim<S>, S: Read + Write + Clone> MysqlIntermediary<B, S> {
239    /// Create a new server over a two-way stream and process client commands until the client
240    /// disconnects or an error occurs. See also
241    /// [`MysqlIntermediary::run_on`](struct.MysqlIntermediary.html#method.run_on).
242    pub fn run_on_stream(shim: B, stream: S) -> Result<(), B::Error> {
243        MysqlIntermediary::run_on(shim, stream)
244    }
245}
246
247#[derive(Default)]
248struct StatementData {
249    long_data: HashMap<u16, Vec<u8>>,
250    bound_types: Vec<(myc::constants::ColumnType, bool)>,
251    params: u16,
252}
253
254impl<B: MysqlShim<RW>, RW: Read + Write> MysqlIntermediary<B, RW> {
255    /// Create a new server over a two-way channel and process client commands until the client
256    /// disconnects or an error occurs.
257    pub fn run_on(shim: B, rw: RW) -> Result<(), B::Error> {
258        let rw = packet::PacketConn::new(rw);
259        let mut mi = MysqlIntermediary { shim, rw };
260        mi.init()?;
261        mi.run()
262    }
263
264    fn init(&mut self) -> Result<(), B::Error> {
265        #[cfg(feature = "tls")]
266        let tls_conf = self.shim.tls_config();
267
268        self.rw.write_all(&[10])?; // protocol 10
269
270        // 5.1.10 because that's what Ruby's ActiveRecord requires
271        self.rw.write_all(&b"5.1.10-alpha-msql-proxy\0"[..])?;
272
273        self.rw.write_all(&[0x08, 0x00, 0x00, 0x00])?; // TODO: connection ID
274        self.rw.write_all(&b";X,po_k}\0"[..])?; // auth seed
275        let capabilities = &mut [0x00, 0x42]; // 4.1 proto
276        #[cfg(feature = "tls")]
277        if tls_conf.is_some() {
278            capabilities[1] |= 0x08; // SSL support flag
279        }
280        self.rw.write_all(capabilities)?;
281        self.rw.write_all(&[0x21])?; // UTF8_GENERAL_CI
282        self.rw.write_all(&[0x00, 0x00])?; // status flags
283        self.rw.write_all(&[0x00, 0x00])?; // extended capabilities
284        self.rw.write_all(&[0x00])?; // no plugins
285        self.rw.write_all(&[0x00; 6][..])?; // filler
286        self.rw.write_all(&[0x00; 4][..])?; // filler
287        self.rw.write_all(&b">o6^Wz!/kM}N\0"[..])?; // 4.1+ servers must extend salt
288        self.rw.flush()?;
289
290        let mut auth_context = AuthenticationContext::default();
291
292        {
293            let (seq, handshake) = self.rw.next()?.ok_or_else(|| {
294                io::Error::new(
295                    io::ErrorKind::ConnectionAborted,
296                    "peer terminated connection",
297                )
298            })?;
299            let handshake = commands::client_handshake(&handshake, false)
300                .map_err(|e| match e {
301                    nom::Err::Incomplete(_) => io::Error::new(
302                        io::ErrorKind::UnexpectedEof,
303                        "client sent incomplete handshake",
304                    ),
305                    nom::Err::Failure(nom_error) | nom::Err::Error(nom_error) => {
306                        if let nom::error::ErrorKind::Eof = nom_error.code {
307                            io::Error::new(
308                                io::ErrorKind::UnexpectedEof,
309                                format!(
310                                    "client did not complete handshake; got {:?}",
311                                    nom_error.input
312                                ),
313                            )
314                        } else {
315                            io::Error::new(
316                                io::ErrorKind::InvalidData,
317                                format!(
318                                    "bad client handshake; got {:?} ({:?})",
319                                    nom_error.input, nom_error.code
320                                ),
321                            )
322                        }
323                    }
324                })?
325                .1;
326
327            auth_context.username = handshake.username.map(|x| x.to_vec());
328
329            self.rw.set_seq(seq + 1);
330
331            #[cfg(not(feature = "tls"))]
332            if handshake.capabilities.contains(CapabilityFlags::CLIENT_SSL) {
333                return Err(io::Error::new(
334                    io::ErrorKind::InvalidData,
335                    "client requested SSL despite us not advertising support for it",
336                )
337                .into());
338            }
339
340            #[cfg(feature = "tls")]
341            if handshake.capabilities.contains(CapabilityFlags::CLIENT_SSL) {
342                let config = tls_conf.ok_or_else(|| {
343                    io::Error::new(
344                        io::ErrorKind::InvalidData,
345                        "client requested SSL despite us not advertising support for it",
346                    )
347                })?;
348
349                self.rw.switch_to_tls(config)?;
350
351                let (seq, handshake) = self.rw.next()?.ok_or_else(|| {
352                    io::Error::new(
353                        io::ErrorKind::ConnectionAborted,
354                        "peer terminated connection",
355                    )
356                })?;
357
358                let handshake = commands::client_handshake(&handshake, true)
359                    .map_err(|e| match e {
360                        nom::Err::Incomplete(_) => io::Error::new(
361                            io::ErrorKind::UnexpectedEof,
362                            "client sent incomplete handshake",
363                        ),
364                        nom::Err::Failure(nom_error) | nom::Err::Error(nom_error) => {
365                            if let nom::error::ErrorKind::Eof = nom_error.code {
366                                io::Error::new(
367                                    io::ErrorKind::UnexpectedEof,
368                                    format!(
369                                        "client did not complete handshake; got {:?}",
370                                        nom_error.input
371                                    ),
372                                )
373                            } else {
374                                io::Error::new(
375                                    io::ErrorKind::InvalidData,
376                                    format!(
377                                        "bad client handshake; got {:?} ({:?})",
378                                        nom_error.input, nom_error.code
379                                    ),
380                                )
381                            }
382                        }
383                    })?
384                    .1;
385
386                auth_context.username = handshake.username.map(|x| x.to_vec());
387
388                self.rw.set_seq(seq + 1);
389
390                auth_context.tls_client_certs = self.rw.tls_certs();
391            }
392
393            if let Err(e) = self.shim.after_authentication(&auth_context) {
394                writers::write_err(
395                    ErrorKind::ER_ACCESS_DENIED_ERROR,
396                    "client authentication failed".as_ref(),
397                    &mut self.rw,
398                )?;
399                self.rw.flush()?;
400                return Err(e);
401            }
402        }
403
404        writers::write_ok_packet(&mut self.rw, 0, 0, StatusFlags::empty())?;
405        self.rw.flush()?;
406
407        Ok(())
408    }
409
410    fn run(mut self) -> Result<(), B::Error> {
411        use crate::commands::Command;
412
413        let mut stmts: HashMap<u32, _> = HashMap::new();
414        while let Some((seq, packet)) = self.rw.next()? {
415            self.rw.set_seq(seq + 1);
416            let cmd = commands::parse(&packet).unwrap().1;
417            match cmd {
418                Command::Query(q) => {
419                    if q.starts_with(b"SELECT @@") || q.starts_with(b"select @@") {
420                        let w = QueryResultWriter::new(&mut self.rw, false);
421                        let var = &q[b"SELECT @@".len()..];
422                        match var {
423                            b"max_allowed_packet" => {
424                                let cols = &[Column {
425                                    table: String::new(),
426                                    column: "@@max_allowed_packet".to_owned(),
427                                    coltype: myc::constants::ColumnType::MYSQL_TYPE_LONG,
428                                    colflags: myc::constants::ColumnFlags::UNSIGNED_FLAG,
429                                }];
430                                let mut w = w.start(cols)?;
431                                w.write_row(iter::once(67108864u32))?;
432                                w.finish()?;
433                            }
434                            _ => {
435                                w.completed(0, 0)?;
436                            }
437                        }
438                    } else if q.starts_with(b"USE ") || q.starts_with(b"use ") {
439                        let w = InitWriter {
440                            writer: &mut self.rw,
441                        };
442                        let schema = ::std::str::from_utf8(&q[b"USE ".len()..])
443                            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
444                        let schema = schema.trim().trim_end_matches(';').trim_matches('`');
445                        self.shim.on_init(schema, w)?;
446                    } else {
447                        let w = QueryResultWriter::new(&mut self.rw, false);
448                        self.shim.on_query(
449                            ::std::str::from_utf8(q)
450                                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
451                            w,
452                        )?;
453                    }
454                }
455                Command::Prepare(q) => {
456                    let w = StatementMetaWriter {
457                        writer: &mut self.rw,
458                        stmts: &mut stmts,
459                    };
460
461                    self.shim.on_prepare(
462                        ::std::str::from_utf8(q)
463                            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
464                        w,
465                    )?;
466                }
467                Command::Execute { stmt, params } => {
468                    let state = stmts.get_mut(&stmt).ok_or_else(|| {
469                        io::Error::new(
470                            io::ErrorKind::InvalidData,
471                            format!("asked to execute unknown statement {}", stmt),
472                        )
473                    })?;
474                    {
475                        let params = params::ParamParser::new(params, state);
476                        let w = QueryResultWriter::new(&mut self.rw, true);
477                        self.shim.on_execute(stmt, params, w)?;
478                    }
479                    state.long_data.clear();
480                }
481                Command::SendLongData { stmt, param, data } => {
482                    stmts
483                        .get_mut(&stmt)
484                        .ok_or_else(|| {
485                            io::Error::new(
486                                io::ErrorKind::InvalidData,
487                                format!("got long data packet for unknown statement {}", stmt),
488                            )
489                        })?
490                        .long_data
491                        .entry(param)
492                        .or_insert_with(Vec::new)
493                        .extend(data);
494                }
495                Command::Close(stmt) => {
496                    self.shim.on_close(stmt);
497                    stmts.remove(&stmt);
498                    // NOTE: spec dictates no response from server
499                }
500                Command::ListFields(_) => {
501                    let cols = &[Column {
502                        table: String::new(),
503                        column: "not implemented".to_owned(),
504                        coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT,
505                        colflags: myc::constants::ColumnFlags::UNSIGNED_FLAG,
506                    }];
507                    writers::write_column_definitions(cols, &mut self.rw, true, true)?;
508                }
509                Command::Init(schema) => {
510                    let w = InitWriter {
511                        writer: &mut self.rw,
512                    };
513                    self.shim.on_init(
514                        ::std::str::from_utf8(schema)
515                            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
516                        w,
517                    )?;
518                }
519                Command::Ping => {
520                    writers::write_ok_packet(&mut self.rw, 0, 0, StatusFlags::empty())?;
521                }
522                Command::Quit => {
523                    break;
524                }
525            }
526            self.rw.flush()?;
527        }
528        Ok(())
529    }
530}