msql_srv/lib.rs
1//! Bindings for emulating a MySQL/MariaDB server.
2//!
3//! When developing new databases or caching layers, it can be immensely useful to test your system
4//! using existing applications. However, this often requires significant work modifying
5//! applications to use your database over the existing ones. This crate solves that problem by
6//! acting as a MySQL server, and delegating operations such as querying and query execution to
7//! user-defined logic.
8//!
9//! To start, implement `MysqlShim` for your backend, and create a `MysqlIntermediary` over an
10//! instance of your backend and a connection stream. The appropriate methods will be called on
11//! your backend whenever a client issues a `QUERY`, `PREPARE`, or `EXECUTE` command, and you will
12//! have a chance to respond appropriately. For example, to write a shim that always responds to
13//! all commands with a "no results" reply:
14//!
15//! ```
16//! # extern crate msql_srv;
17//! extern crate mysql;
18//! # use std::io;
19//! # use std::net;
20//! # use std::thread;
21//! use msql_srv::*;
22//! use mysql::prelude::*;
23//! use mysql::Opts;
24//!
25//! struct Backend;
26//! impl<W: io::Read + io::Write> MysqlShim<W> for Backend {
27//! type Error = io::Error;
28//!
29//! fn on_prepare(&mut self, _: &str, info: StatementMetaWriter<W>) -> io::Result<()> {
30//! info.reply(42, &[], &[])
31//! }
32//! fn on_execute(
33//! &mut self,
34//! _: u32,
35//! _: ParamParser,
36//! results: QueryResultWriter<W>,
37//! ) -> io::Result<()> {
38//! results.completed(0, 0)
39//! }
40//! fn on_close(&mut self, _: u32) {}
41//!
42//! fn on_init(&mut self, _: &str, writer: InitWriter<W>) -> io::Result<()> { Ok(()) }
43//!
44//! fn on_query(&mut self, _: &str, results: QueryResultWriter<W>) -> io::Result<()> {
45//! let cols = [
46//! Column {
47//! table: "foo".to_string(),
48//! column: "a".to_string(),
49//! coltype: ColumnType::MYSQL_TYPE_LONGLONG,
50//! colflags: ColumnFlags::empty(),
51//! },
52//! Column {
53//! table: "foo".to_string(),
54//! column: "b".to_string(),
55//! coltype: ColumnType::MYSQL_TYPE_STRING,
56//! colflags: ColumnFlags::empty(),
57//! },
58//! ];
59//!
60//! let mut rw = results.start(&cols)?;
61//! rw.write_col(42)?;
62//! rw.write_col("b's value")?;
63//! rw.finish()
64//! }
65//! }
66//!
67//! fn main() {
68//! let listener = net::TcpListener::bind("127.0.0.1:0").unwrap();
69//! let port = listener.local_addr().unwrap().port();
70//!
71//! let jh = thread::spawn(move || {
72//! if let Ok((s, _)) = listener.accept() {
73//! MysqlIntermediary::run_on_tcp(Backend, s).unwrap();
74//! }
75//! });
76//!
77//! let mut db = mysql::Conn::new(Opts::from_url(&format!("mysql://127.0.0.1:{}", port)).unwrap()).unwrap();
78//! assert_eq!(db.ping(), true);
79//! assert_eq!(db.query_iter("SELECT a, b FROM foo").unwrap().count(), 1);
80//! drop(db);
81//! jh.join().unwrap();
82//! }
83//! ```
84#![deny(missing_docs)]
85#![deny(rust_2018_idioms)]
86
87// Note to developers: you can find decent overviews of the protocol at
88//
89// https://github.com/cwarden/mysql-proxy/blob/master/doc/protocol.rst
90//
91// and
92//
93// https://mariadb.com/kb/en/library/clientserver-protocol/
94//
95// Wireshark also does a pretty good job at parsing the MySQL protocol.
96
97extern crate mysql_common as myc;
98
99use std::collections::HashMap;
100use std::io;
101use std::io::prelude::*;
102use std::iter;
103use std::net;
104
105use myc::constants::CapabilityFlags;
106
107pub use crate::myc::constants::{ColumnFlags, ColumnType, StatusFlags};
108
109mod commands;
110mod errorcodes;
111mod packet;
112mod params;
113mod resultset;
114#[cfg(feature = "tls")]
115mod tls;
116mod value;
117mod writers;
118
119/// Meta-information abot a single column, used either to describe a prepared statement parameter
120/// or an output column.
121#[derive(Debug, Clone, PartialEq, Eq)]
122pub struct Column {
123 /// This column's associated table.
124 ///
125 /// Note that this is *technically* the table's alias.
126 pub table: String,
127 /// This column's name.
128 ///
129 /// Note that this is *technically* the column's alias.
130 pub column: String,
131 /// This column's type>
132 pub coltype: ColumnType,
133 /// Any flags associated with this column.
134 ///
135 /// Of particular interest are `ColumnFlags::UNSIGNED_FLAG` and `ColumnFlags::NOT_NULL_FLAG`.
136 pub colflags: ColumnFlags,
137}
138
139pub use crate::errorcodes::ErrorKind;
140pub use crate::params::{ParamParser, ParamValue, Params};
141pub use crate::resultset::{InitWriter, QueryResultWriter, RowWriter, StatementMetaWriter};
142pub use crate::value::{ToMysqlValue, Value, ValueInner};
143
144/// Implementors of this trait can be used to drive a MySQL-compatible database backend.
145pub trait MysqlShim<W: Read + Write> {
146 /// The error type produced by operations on this shim.
147 ///
148 /// Must implement `From<io::Error>` so that transport-level errors can be lifted.
149 type Error: From<io::Error>;
150
151 /// Called when the client issues a request to prepare `query` for later execution.
152 ///
153 /// The provided [`StatementMetaWriter`](struct.StatementMetaWriter.html) should be used to
154 /// notify the client of the statement id assigned to the prepared statement, as well as to
155 /// give metadata about the types of parameters and returned columns.
156 fn on_prepare(
157 &mut self,
158 query: &str,
159 info: StatementMetaWriter<'_, W>,
160 ) -> Result<(), Self::Error>;
161
162 /// Called when the client executes a previously prepared statement.
163 ///
164 /// Any parameters included with the client's command is given in `params`.
165 /// A response to the query should be given using the provided
166 /// [`QueryResultWriter`](struct.QueryResultWriter.html).
167 fn on_execute(
168 &mut self,
169 id: u32,
170 params: ParamParser<'_>,
171 results: QueryResultWriter<'_, W>,
172 ) -> Result<(), Self::Error>;
173
174 /// Called when the client wishes to deallocate resources associated with a previously prepared
175 /// statement.
176 fn on_close(&mut self, stmt: u32);
177
178 /// Called when the client issues a query for immediate execution.
179 ///
180 /// Results should be returned using the given
181 /// [`QueryResultWriter`](struct.QueryResultWriter.html).
182 fn on_query(
183 &mut self,
184 query: &str,
185 results: QueryResultWriter<'_, W>,
186 ) -> Result<(), Self::Error>;
187
188 /// Called when client switches database.
189 fn on_init(&mut self, _: &str, _: InitWriter<'_, W>) -> Result<(), Self::Error> {
190 Ok(())
191 }
192
193 /// Provides the TLS configuration, if we want to support TLS.
194 #[cfg(feature = "tls")]
195 fn tls_config(&self) -> Option<std::sync::Arc<rustls::ServerConfig>> {
196 None
197 }
198
199 /// Called after successful authentication (including TLS if applicable) passing relevant
200 /// information to allow additional logic in the MySqlShim implementation.
201 fn after_authentication(
202 &mut self,
203 _context: &AuthenticationContext<'_>,
204 ) -> Result<(), Self::Error> {
205 Ok(())
206 }
207}
208
209/// Information about an authenticated user
210#[allow(clippy::derive_partial_eq_without_eq)]
211#[derive(Debug, Default, Clone, PartialEq)]
212pub struct AuthenticationContext<'a> {
213 /// The username exactly as passed by the client,
214 pub username: Option<Vec<u8>>,
215 #[cfg(feature = "tls")]
216 /// The TLS certificate chain presented by the client.
217 pub tls_client_certs: Option<&'a [rustls::pki_types::CertificateDer<'a>]>,
218 #[cfg(not(feature = "tls"))]
219 _pd: Option<&'a std::marker::PhantomData<()>>,
220}
221
222/// A server that speaks the MySQL/MariaDB protocol, and can delegate client commands to a backend
223/// that implements [`MysqlShim`](trait.MysqlShim.html).
224pub struct MysqlIntermediary<B, RW: Read + Write> {
225 shim: B,
226 rw: packet::PacketConn<RW>,
227}
228
229impl<B: MysqlShim<net::TcpStream>> MysqlIntermediary<B, net::TcpStream> {
230 /// Create a new server over a TCP stream and process client commands until the client
231 /// disconnects or an error occurs. See also
232 /// [`MysqlIntermediary::run_on`](struct.MysqlIntermediary.html#method.run_on).
233 pub fn run_on_tcp(shim: B, stream: net::TcpStream) -> Result<(), B::Error> {
234 MysqlIntermediary::run_on(shim, stream)
235 }
236}
237
238impl<B: MysqlShim<S>, S: Read + Write + Clone> MysqlIntermediary<B, S> {
239 /// Create a new server over a two-way stream and process client commands until the client
240 /// disconnects or an error occurs. See also
241 /// [`MysqlIntermediary::run_on`](struct.MysqlIntermediary.html#method.run_on).
242 pub fn run_on_stream(shim: B, stream: S) -> Result<(), B::Error> {
243 MysqlIntermediary::run_on(shim, stream)
244 }
245}
246
247#[derive(Default)]
248struct StatementData {
249 long_data: HashMap<u16, Vec<u8>>,
250 bound_types: Vec<(myc::constants::ColumnType, bool)>,
251 params: u16,
252}
253
254impl<B: MysqlShim<RW>, RW: Read + Write> MysqlIntermediary<B, RW> {
255 /// Create a new server over a two-way channel and process client commands until the client
256 /// disconnects or an error occurs.
257 pub fn run_on(shim: B, rw: RW) -> Result<(), B::Error> {
258 let rw = packet::PacketConn::new(rw);
259 let mut mi = MysqlIntermediary { shim, rw };
260 mi.init()?;
261 mi.run()
262 }
263
264 fn init(&mut self) -> Result<(), B::Error> {
265 #[cfg(feature = "tls")]
266 let tls_conf = self.shim.tls_config();
267
268 self.rw.write_all(&[10])?; // protocol 10
269
270 // 5.1.10 because that's what Ruby's ActiveRecord requires
271 self.rw.write_all(&b"5.1.10-alpha-msql-proxy\0"[..])?;
272
273 self.rw.write_all(&[0x08, 0x00, 0x00, 0x00])?; // TODO: connection ID
274 self.rw.write_all(&b";X,po_k}\0"[..])?; // auth seed
275 let capabilities = &mut [0x00, 0x42]; // 4.1 proto
276 #[cfg(feature = "tls")]
277 if tls_conf.is_some() {
278 capabilities[1] |= 0x08; // SSL support flag
279 }
280 self.rw.write_all(capabilities)?;
281 self.rw.write_all(&[0x21])?; // UTF8_GENERAL_CI
282 self.rw.write_all(&[0x00, 0x00])?; // status flags
283 self.rw.write_all(&[0x00, 0x00])?; // extended capabilities
284 self.rw.write_all(&[0x00])?; // no plugins
285 self.rw.write_all(&[0x00; 6][..])?; // filler
286 self.rw.write_all(&[0x00; 4][..])?; // filler
287 self.rw.write_all(&b">o6^Wz!/kM}N\0"[..])?; // 4.1+ servers must extend salt
288 self.rw.flush()?;
289
290 let mut auth_context = AuthenticationContext::default();
291
292 {
293 let (seq, handshake) = self.rw.next()?.ok_or_else(|| {
294 io::Error::new(
295 io::ErrorKind::ConnectionAborted,
296 "peer terminated connection",
297 )
298 })?;
299 let handshake = commands::client_handshake(&handshake, false)
300 .map_err(|e| match e {
301 nom::Err::Incomplete(_) => io::Error::new(
302 io::ErrorKind::UnexpectedEof,
303 "client sent incomplete handshake",
304 ),
305 nom::Err::Failure(nom_error) | nom::Err::Error(nom_error) => {
306 if let nom::error::ErrorKind::Eof = nom_error.code {
307 io::Error::new(
308 io::ErrorKind::UnexpectedEof,
309 format!(
310 "client did not complete handshake; got {:?}",
311 nom_error.input
312 ),
313 )
314 } else {
315 io::Error::new(
316 io::ErrorKind::InvalidData,
317 format!(
318 "bad client handshake; got {:?} ({:?})",
319 nom_error.input, nom_error.code
320 ),
321 )
322 }
323 }
324 })?
325 .1;
326
327 auth_context.username = handshake.username.map(|x| x.to_vec());
328
329 self.rw.set_seq(seq + 1);
330
331 #[cfg(not(feature = "tls"))]
332 if handshake.capabilities.contains(CapabilityFlags::CLIENT_SSL) {
333 return Err(io::Error::new(
334 io::ErrorKind::InvalidData,
335 "client requested SSL despite us not advertising support for it",
336 )
337 .into());
338 }
339
340 #[cfg(feature = "tls")]
341 if handshake.capabilities.contains(CapabilityFlags::CLIENT_SSL) {
342 let config = tls_conf.ok_or_else(|| {
343 io::Error::new(
344 io::ErrorKind::InvalidData,
345 "client requested SSL despite us not advertising support for it",
346 )
347 })?;
348
349 self.rw.switch_to_tls(config)?;
350
351 let (seq, handshake) = self.rw.next()?.ok_or_else(|| {
352 io::Error::new(
353 io::ErrorKind::ConnectionAborted,
354 "peer terminated connection",
355 )
356 })?;
357
358 let handshake = commands::client_handshake(&handshake, true)
359 .map_err(|e| match e {
360 nom::Err::Incomplete(_) => io::Error::new(
361 io::ErrorKind::UnexpectedEof,
362 "client sent incomplete handshake",
363 ),
364 nom::Err::Failure(nom_error) | nom::Err::Error(nom_error) => {
365 if let nom::error::ErrorKind::Eof = nom_error.code {
366 io::Error::new(
367 io::ErrorKind::UnexpectedEof,
368 format!(
369 "client did not complete handshake; got {:?}",
370 nom_error.input
371 ),
372 )
373 } else {
374 io::Error::new(
375 io::ErrorKind::InvalidData,
376 format!(
377 "bad client handshake; got {:?} ({:?})",
378 nom_error.input, nom_error.code
379 ),
380 )
381 }
382 }
383 })?
384 .1;
385
386 auth_context.username = handshake.username.map(|x| x.to_vec());
387
388 self.rw.set_seq(seq + 1);
389
390 auth_context.tls_client_certs = self.rw.tls_certs();
391 }
392
393 if let Err(e) = self.shim.after_authentication(&auth_context) {
394 writers::write_err(
395 ErrorKind::ER_ACCESS_DENIED_ERROR,
396 "client authentication failed".as_ref(),
397 &mut self.rw,
398 )?;
399 self.rw.flush()?;
400 return Err(e);
401 }
402 }
403
404 writers::write_ok_packet(&mut self.rw, 0, 0, StatusFlags::empty())?;
405 self.rw.flush()?;
406
407 Ok(())
408 }
409
410 fn run(mut self) -> Result<(), B::Error> {
411 use crate::commands::Command;
412
413 let mut stmts: HashMap<u32, _> = HashMap::new();
414 while let Some((seq, packet)) = self.rw.next()? {
415 self.rw.set_seq(seq + 1);
416 let cmd = commands::parse(&packet).unwrap().1;
417 match cmd {
418 Command::Query(q) => {
419 if q.starts_with(b"SELECT @@") || q.starts_with(b"select @@") {
420 let w = QueryResultWriter::new(&mut self.rw, false);
421 let var = &q[b"SELECT @@".len()..];
422 match var {
423 b"max_allowed_packet" => {
424 let cols = &[Column {
425 table: String::new(),
426 column: "@@max_allowed_packet".to_owned(),
427 coltype: myc::constants::ColumnType::MYSQL_TYPE_LONG,
428 colflags: myc::constants::ColumnFlags::UNSIGNED_FLAG,
429 }];
430 let mut w = w.start(cols)?;
431 w.write_row(iter::once(67108864u32))?;
432 w.finish()?;
433 }
434 _ => {
435 w.completed(0, 0)?;
436 }
437 }
438 } else if q.starts_with(b"USE ") || q.starts_with(b"use ") {
439 let w = InitWriter {
440 writer: &mut self.rw,
441 };
442 let schema = ::std::str::from_utf8(&q[b"USE ".len()..])
443 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
444 let schema = schema.trim().trim_end_matches(';').trim_matches('`');
445 self.shim.on_init(schema, w)?;
446 } else {
447 let w = QueryResultWriter::new(&mut self.rw, false);
448 self.shim.on_query(
449 ::std::str::from_utf8(q)
450 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
451 w,
452 )?;
453 }
454 }
455 Command::Prepare(q) => {
456 let w = StatementMetaWriter {
457 writer: &mut self.rw,
458 stmts: &mut stmts,
459 };
460
461 self.shim.on_prepare(
462 ::std::str::from_utf8(q)
463 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
464 w,
465 )?;
466 }
467 Command::Execute { stmt, params } => {
468 let state = stmts.get_mut(&stmt).ok_or_else(|| {
469 io::Error::new(
470 io::ErrorKind::InvalidData,
471 format!("asked to execute unknown statement {}", stmt),
472 )
473 })?;
474 {
475 let params = params::ParamParser::new(params, state);
476 let w = QueryResultWriter::new(&mut self.rw, true);
477 self.shim.on_execute(stmt, params, w)?;
478 }
479 state.long_data.clear();
480 }
481 Command::SendLongData { stmt, param, data } => {
482 stmts
483 .get_mut(&stmt)
484 .ok_or_else(|| {
485 io::Error::new(
486 io::ErrorKind::InvalidData,
487 format!("got long data packet for unknown statement {}", stmt),
488 )
489 })?
490 .long_data
491 .entry(param)
492 .or_insert_with(Vec::new)
493 .extend(data);
494 }
495 Command::Close(stmt) => {
496 self.shim.on_close(stmt);
497 stmts.remove(&stmt);
498 // NOTE: spec dictates no response from server
499 }
500 Command::ListFields(_) => {
501 let cols = &[Column {
502 table: String::new(),
503 column: "not implemented".to_owned(),
504 coltype: myc::constants::ColumnType::MYSQL_TYPE_SHORT,
505 colflags: myc::constants::ColumnFlags::UNSIGNED_FLAG,
506 }];
507 writers::write_column_definitions(cols, &mut self.rw, true, true)?;
508 }
509 Command::Init(schema) => {
510 let w = InitWriter {
511 writer: &mut self.rw,
512 };
513 self.shim.on_init(
514 ::std::str::from_utf8(schema)
515 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
516 w,
517 )?;
518 }
519 Command::Ping => {
520 writers::write_ok_packet(&mut self.rw, 0, 0, StatusFlags::empty())?;
521 }
522 Command::Quit => {
523 break;
524 }
525 }
526 self.rw.flush()?;
527 }
528 Ok(())
529 }
530}