1extern crate mysql_common as myc;
26
27use std::collections::HashMap;
28use std::io;
29use std::io::Write;
30use std::iter;
31
32use async_trait::async_trait;
33use tokio::io::AsyncRead;
34use tokio::io::AsyncWrite;
35#[cfg(feature = "tls")]
36use tokio_rustls::rustls::ServerConfig;
37
38pub use crate::myc::constants::{CapabilityFlags, ColumnFlags, ColumnType, StatusFlags};
39#[cfg(feature = "tls")]
40pub use crate::tls::{plain_run_with_options, secure_run_with_options};
41
42mod commands;
43mod errorcodes;
44mod packet_reader;
45mod packet_writer;
46mod params;
47mod resultset;
48#[cfg(feature = "tls")]
49mod tls;
50mod value;
51mod writers;
52
53#[cfg(test)]
54mod tests;
55
/// Largest value representable in an unsigned 24-bit integer (2^24 - 1).
///
/// NOTE(review): presumably used as the MySQL packet payload-length limit, since the
/// packet header stores the length in 3 bytes — confirm against its users.
pub const U24_MAX: usize = (1 << 24) - 1;
58
59#[derive(Debug, Clone, PartialEq, Eq)]
62pub struct Column {
63 pub table: String,
67 pub column: String,
71 pub coltype: ColumnType,
73 pub colflags: ColumnFlags,
77}
78
79#[derive(Debug, Clone, PartialEq, Eq, Default)]
81pub struct OkResponse {
82 pub header: u8,
84 pub affected_rows: u64,
86 pub last_insert_id: u64,
88 pub status_flags: StatusFlags,
90 pub warnings: u16,
92 pub info: String,
94 pub session_state_info: String,
96}
97
98pub use crate::errorcodes::ErrorKind;
99pub use crate::params::{ParamParser, ParamValue, Params};
100pub use crate::resultset::{InitWriter, QueryResultWriter, RowWriter, StatementMetaWriter};
101pub use crate::value::{decode::to_naive_datetime, ToMysqlValue, Value, ValueInner};
102use crate::{commands::ClientHandshake, packet_reader::PacketReader, packet_writer::PacketWriter};
103
/// Length in bytes of the authentication scramble (salt) sent to clients.
const SCRAMBLE_SIZE: usize = 20;
/// Name of the default authentication plugin.
const MYSQL_NATIVE_PASSWORD: &str = "mysql_native_password";
106
107#[async_trait]
108pub trait AsyncMysqlShim<W: Send> {
110 type Error: From<io::Error>;
114
115 fn version(&self) -> String {
117 "5.1.10-alpha-msql-proxy".to_string()
119 }
120
121 fn connect_id(&self) -> u32 {
123 u32::from_le_bytes([0x08, 0x00, 0x00, 0x00])
124 }
125
126 fn default_auth_plugin(&self) -> &str {
128 MYSQL_NATIVE_PASSWORD
129 }
130
131 async fn auth_plugin_for_username(&self, _user: &[u8]) -> &str {
133 MYSQL_NATIVE_PASSWORD
134 }
135
136 fn salt(&self) -> [u8; SCRAMBLE_SIZE] {
138 let bs = ";X,po_k}>o6^Wz!/kM}N".as_bytes();
139 let mut scramble: [u8; SCRAMBLE_SIZE] = [0; SCRAMBLE_SIZE];
140 for i in 0..SCRAMBLE_SIZE {
141 scramble[i] = bs[i];
142 if scramble[i] == b'\0' || scramble[i] == b'$' {
143 scramble[i] += 1;
144 }
145 }
146 scramble
147 }
148
149 async fn authenticate(
151 &self,
152 _auth_plugin: &str,
153 _username: &[u8],
154 _salt: &[u8],
155 _auth_data: &[u8],
156 ) -> bool {
157 true
158 }
159
160 async fn on_prepare<'a>(
166 &'a mut self,
167 query: &'a str,
168 info: StatementMetaWriter<'a, W>,
169 ) -> Result<(), Self::Error>;
170
171 async fn on_execute<'a>(
177 &'a mut self,
178 id: u32,
179 params: ParamParser<'a>,
180 results: QueryResultWriter<'a, W>,
181 ) -> Result<(), Self::Error>;
182
183 async fn on_close<'a>(&'a mut self, stmt: u32)
186 where
187 W: 'async_trait;
188
189 async fn on_query<'a>(
194 &'a mut self,
195 query: &'a str,
196 results: QueryResultWriter<'a, W>,
197 ) -> Result<(), Self::Error>;
198
199 async fn on_init<'a>(
201 &'a mut self,
202 _: &'a str,
203 _: InitWriter<'a, W>,
204 ) -> Result<(), Self::Error> {
205 Ok(())
206 }
207}
208
/// Connection-handling options for [`AsyncMysqlIntermediary::run_with_options`].
///
/// All options default to `false`.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct IntermediaryOptions {
    /// When `true`, `USE <db>` queries go to `on_query` like any other query. When
    /// `false`, the intermediary intercepts them and calls `on_init` with the name.
    pub process_use_statement_on_query: bool,
    /// When `true`, a client whose handshake does not name a (UTF-8) database is sent an
    /// error during connection setup.
    pub reject_connection_on_dbname_absence: bool,
}
217
218#[derive(Default)]
219struct StatementData {
220 long_data: HashMap<u16, Vec<u8>>,
221 bound_types: Vec<(myc::constants::ColumnType, bool)>,
222 params: u16,
223}
224
/// Number of scramble bytes sent in the first auth-plugin-data field of the initial
/// handshake; the remaining bytes follow later in the same packet.
const AUTH_PLUGIN_DATA_PART_1_LENGTH: usize = 8;
226
227pub struct AsyncMysqlIntermediary<B, S: AsyncRead + Unpin, W> {
230 pub(crate) client_capabilities: CapabilityFlags,
231 process_use_statement_on_query: bool,
232 reject_connection_on_dbname_absence: bool,
233 shim: B,
234 reader: packet_reader::PacketReader<S>,
235 writer: packet_writer::PacketWriter<W>,
236}
237
238impl<B, R, W> AsyncMysqlIntermediary<B, R, W>
239where
240 W: AsyncWrite + Send + Unpin,
241 B: AsyncMysqlShim<W> + Send + Sync,
242 R: AsyncRead + Send + Unpin,
243{
244 pub async fn run_on(shim: B, stream: R, output_stream: W) -> Result<(), B::Error> {
247 Self::run_with_options(shim, stream, output_stream, &Default::default()).await
248 }
249
250 pub async fn run_with_options(
253 mut shim: B,
254 input_stream: R,
255 mut output_stream: W,
256 opts: &IntermediaryOptions,
257 ) -> Result<(), B::Error> {
258 let process_use_statement_on_query = opts.process_use_statement_on_query;
259 let reject_connection_on_dbname_absence = opts.reject_connection_on_dbname_absence;
260 let (_, (handshake, seq, client_capabilities, input_stream)) =
261 AsyncMysqlIntermediary::init_before_ssl(
262 &mut shim,
263 input_stream,
264 &mut output_stream,
265 #[cfg(feature = "tls")]
266 &None,
267 )
268 .await?;
269
270 let reader = PacketReader::new(input_stream);
271 let writer = PacketWriter::new(output_stream);
272
273 let mut mi = AsyncMysqlIntermediary {
274 client_capabilities,
275 process_use_statement_on_query,
276 reject_connection_on_dbname_absence,
277 shim,
278 reader,
279 writer,
280 };
281 mi.init_after_ssl(handshake, seq).await?;
282 mi.run().await
283 }
284
285 pub async fn init_before_ssl(
286 shim: &mut B,
287 input_stream: R,
288 output_stream: &mut W,
289 #[cfg(feature = "tls")] tls_conf: &Option<std::sync::Arc<ServerConfig>>,
290 ) -> Result<
291 (
292 bool,
293 (ClientHandshake, u8, CapabilityFlags, PacketReader<R>),
294 ),
295 B::Error,
296 > {
297 let mut reader = PacketReader::new(input_stream);
298 let mut writer = PacketWriter::new(output_stream);
299 writer.write_all(&[10])?; writer.write_all(shim.version().as_bytes())?;
303 writer.write_all(&[0x00])?;
304
305 writer.write_all(&shim.connect_id().to_le_bytes())?;
307
308 let server_capabilities = CapabilityFlags::CLIENT_PROTOCOL_41
309 | CapabilityFlags::CLIENT_SECURE_CONNECTION
310 | CapabilityFlags::CLIENT_PLUGIN_AUTH
311 | CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
312 | CapabilityFlags::CLIENT_CONNECT_WITH_DB
313 | CapabilityFlags::CLIENT_DEPRECATE_EOF;
314
315 #[cfg(feature = "tls")]
316 let server_capabilities = if tls_conf.is_some() {
317 server_capabilities | CapabilityFlags::CLIENT_SSL
318 } else {
319 server_capabilities
320 };
321
322 let server_capabilities_vec = server_capabilities.bits().to_le_bytes();
323 let default_auth_plugin = shim.default_auth_plugin();
324 let scramble = shim.salt();
325
326 writer.write_all(&scramble[0..AUTH_PLUGIN_DATA_PART_1_LENGTH])?; writer.write_all(&[0x00])?;
328
329 writer.write_all(&server_capabilities_vec[..2])?; writer.write_all(&[0x21])?; writer.write_all(&[0x00, 0x00])?; writer.write_all(&server_capabilities_vec[2..4])?; if default_auth_plugin.is_empty() {
336 writer.write_all(&[0x00])?;
338 } else {
339 writer.write_all(&((scramble.len() + 1) as u8).to_le_bytes())?; }
341 writer.write_all(&[0x00; 10][..])?; writer.write_all(&scramble[AUTH_PLUGIN_DATA_PART_1_LENGTH..])?; writer.write_all(&[0x00])?;
347
348 writer.write_all(default_auth_plugin.as_bytes())?;
350 writer.write_all(&[0x00])?;
351 writer.end_packet().await?;
352 writer.flush_all().await?;
353
354 let (seq, handshake) = reader.next_async().await?.ok_or_else(|| {
355 io::Error::new(
356 io::ErrorKind::ConnectionAborted,
357 "peer terminated connection",
358 )
359 })?;
360
361 let handshake = commands::client_handshake(&handshake, false)
362 .map_err(|e| match e {
363 nom::Err::Incomplete(_) => io::Error::new(
364 io::ErrorKind::UnexpectedEof,
365 "client sent incomplete handshake",
366 ),
367 nom::Err::Failure(nom_error) | nom::Err::Error(nom_error) => {
368 if let nom::error::ErrorKind::Eof = nom_error.code {
369 io::Error::new(
370 io::ErrorKind::UnexpectedEof,
371 format!(
372 "client did not complete handshake; got {:?}",
373 nom_error.input
374 ),
375 )
376 } else {
377 io::Error::new(
378 io::ErrorKind::InvalidData,
379 format!(
380 "bad client handshake; got {:?} ({:?})",
381 nom_error.input, nom_error.code
382 ),
383 )
384 }
385 }
386 })?
387 .1;
388
389 writer.set_seq(seq + 1);
390
391 #[cfg(not(feature = "tls"))]
392 if handshake.capabilities.contains(CapabilityFlags::CLIENT_SSL) {
393 return Err(io::Error::new(
394 io::ErrorKind::InvalidData,
395 "client requested SSL despite us not advertising support for it",
396 )
397 .into());
398 }
399
400 #[cfg(feature = "tls")]
401 if handshake.capabilities.contains(CapabilityFlags::CLIENT_SSL) {
402 return Ok((true, (handshake, seq, server_capabilities, reader)));
403 }
404
405 Ok((false, (handshake, seq, server_capabilities, reader)))
406 }
407
408 pub async fn init_after_ssl(
409 &mut self,
410 #[cfg(feature = "tls")] mut handshake: ClientHandshake,
411 #[cfg(not(feature = "tls"))] handshake: ClientHandshake,
412 mut seq: u8,
413 ) -> Result<(), B::Error> {
414 #[cfg(feature = "tls")]
415 if handshake.capabilities.contains(CapabilityFlags::CLIENT_SSL) {
416 let (_seq, hs) = self.reader.next_async().await?.ok_or_else(|| {
417 io::Error::new(
418 io::ErrorKind::ConnectionAborted,
419 "peer terminated connection",
420 )
421 })?;
422 seq = _seq;
423
424 handshake = commands::client_handshake(&hs, true)
425 .map_err(|e| match e {
426 nom::Err::Incomplete(_) => io::Error::new(
427 io::ErrorKind::UnexpectedEof,
428 "client sent incomplete handshake",
429 ),
430 nom::Err::Failure(nom_error) | nom::Err::Error(nom_error) => {
431 if let nom::error::ErrorKind::Eof = nom_error.code {
432 io::Error::new(
433 io::ErrorKind::UnexpectedEof,
434 format!(
435 "client did not complete handshake; got {:?}",
436 nom_error.input
437 ),
438 )
439 } else {
440 io::Error::new(
441 io::ErrorKind::InvalidData,
442 format!(
443 "bad client handshake; got {:?} ({:?})",
444 nom_error.input, nom_error.code
445 ),
446 )
447 }
448 }
449 })?
450 .1;
451
452 self.writer.set_seq(seq + 1);
453 }
454
455 let scramble = self.shim.salt();
456 {
457 if !handshake
458 .capabilities
459 .contains(CapabilityFlags::CLIENT_PROTOCOL_41)
460 {
461 let err = io::Error::new(
462 io::ErrorKind::ConnectionAborted,
463 "Required capability: CLIENT_PROTOCOL_41, please upgrade your MySQL client version",
464 );
465 return Err(err.into());
466 }
467
468 self.client_capabilities = handshake.capabilities;
469 let mut auth_response = handshake.auth_response.clone();
470 if let Some(username) = &handshake.username {
471 let auth_plugin_expect = self.shim.auth_plugin_for_username(username).await;
472
473 if !auth_plugin_expect.is_empty()
475 && auth_response.is_empty()
476 && handshake.auth_plugin != auth_plugin_expect.as_bytes()
477 {
478 self.writer.set_seq(seq + 1);
479 self.writer.write_all(&[0xfe])?;
480 self.writer.write_all(auth_plugin_expect.as_bytes())?;
481 self.writer.write_all(&[0x00])?;
482 self.writer.write_all(&scramble)?;
483 self.writer.write_all(&[0x00])?;
484
485 self.writer.end_packet().await?;
486 self.writer.flush_all().await?;
487
488 {
489 let (rseq, auth_response_data) =
490 self.reader.next_async().await?.ok_or_else(|| {
491 io::Error::new(
492 io::ErrorKind::ConnectionAborted,
493 "peer terminated connection",
494 )
495 })?;
496
497 seq = rseq;
498 auth_response = auth_response_data.to_vec();
499 }
500 }
501
502 self.writer.set_seq(seq + 1);
503
504 if !self
505 .shim
506 .authenticate(
507 auth_plugin_expect,
508 username,
509 &scramble,
510 auth_response.as_slice(),
511 )
512 .await
513 {
514 let err_msg = format!(
515 "Authenticate failed, user: {:?}, auth_plugin: {:?}",
516 String::from_utf8_lossy(username),
517 auth_plugin_expect,
518 );
519 writers::write_err(
520 ErrorKind::ER_ACCESS_DENIED_NO_PASSWORD_ERROR,
521 err_msg.as_bytes(),
522 &mut self.writer,
523 )
524 .await?;
525 self.writer.flush_all().await?;
526 return Err(io::Error::new(io::ErrorKind::PermissionDenied, err_msg).into());
527 }
528
529 if let Some(Ok(db)) = handshake.db.as_ref().map(|x| std::str::from_utf8(x)) {
530 let w = InitWriter {
531 client_capabilities: self.client_capabilities,
532 writer: &mut self.writer,
533 };
534 self.shim.on_init(db, w).await?;
535 } else if self.reject_connection_on_dbname_absence {
536 writers::write_err(
537 ErrorKind::ER_DATABASE_NAME,
538 "database required on connection".as_bytes(),
539 &mut self.writer,
540 )
541 .await?;
542 } else {
543 writers::write_ok_packet(
544 &mut self.writer,
545 self.client_capabilities,
546 OkResponse::default(),
547 )
548 .await?;
549 }
550 }
551
552 self.writer.flush_all().await?;
553 };
554
555 Ok(())
556 }
557
558 async fn run(mut self) -> Result<(), B::Error> {
559 use crate::commands::Command;
560
561 let mut stmts: HashMap<u32, _> = HashMap::new();
562 while let Some((seq, packet)) = self.reader.next_async().await? {
563 self.writer.set_seq(seq + 1);
564 let res = commands::parse(&packet);
565 match res {
566 Ok(cmd) => {
567 match cmd.1 {
568 Command::Query(q) => {
569 if q.starts_with(b"SELECT @@") || q.starts_with(b"select @@") {
570 let w = QueryResultWriter::new(
571 &mut self.writer,
572 false,
573 self.client_capabilities,
574 );
575
576 let var = &q[b"SELECT @@".len()..];
577 let var_with_at = &q[b"SELECT ".len()..];
578 let cols = &[Column {
579 table: String::new(),
580 column: String::from_utf8_lossy(var_with_at).to_string(),
581 coltype: myc::constants::ColumnType::MYSQL_TYPE_LONG,
582 colflags: myc::constants::ColumnFlags::UNSIGNED_FLAG,
583 }];
584
585 match var {
586 b"max_allowed_packet" => {
587 let mut w = w.start(cols).await?;
588 w.write_row(iter::once(67108864u32)).await?;
589 w.finish().await?;
590 }
591 _ => {
592 self.shim
593 .on_query(
594 ::std::str::from_utf8(q).map_err(|e| {
595 io::Error::new(io::ErrorKind::InvalidData, e)
596 })?,
597 w,
598 )
599 .await?;
600 }
601 }
602 } else if !self.process_use_statement_on_query
603 && (q.starts_with(b"USE ") || q.starts_with(b"use "))
604 {
605 let w = InitWriter {
606 client_capabilities: self.client_capabilities,
607 writer: &mut self.writer,
608 };
609 let schema = ::std::str::from_utf8(&q[b"USE ".len()..])
610 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
611 let schema = schema.trim().trim_end_matches(';').trim_matches('`');
612 self.shim.on_init(schema, w).await?;
613 } else {
614 let w = QueryResultWriter::new(
615 &mut self.writer,
616 false,
617 self.client_capabilities,
618 );
619 self.shim
620 .on_query(
621 ::std::str::from_utf8(q).map_err(|e| {
622 io::Error::new(io::ErrorKind::InvalidData, e)
623 })?,
624 w,
625 )
626 .await?;
627 }
628 }
629 Command::Prepare(q) => {
630 let w = StatementMetaWriter {
631 writer: &mut self.writer,
632 stmts: &mut stmts,
633 client_capabilities: self.client_capabilities,
634 };
635
636 self.shim
637 .on_prepare(
638 ::std::str::from_utf8(q).map_err(|e| {
639 io::Error::new(io::ErrorKind::InvalidData, e)
640 })?,
641 w,
642 )
643 .await?;
644 }
645 Command::Execute { stmt, params } => {
646 let state = stmts.get_mut(&stmt).ok_or_else(|| {
647 io::Error::new(
648 io::ErrorKind::InvalidData,
649 format!("asked to execute unknown statement {}", stmt),
650 )
651 })?;
652 {
653 let params = params::ParamParser::new(params, state);
654 let w = QueryResultWriter::new(
655 &mut self.writer,
656 true,
657 self.client_capabilities,
658 );
659 self.shim.on_execute(stmt, params, w).await?;
660 }
661 state.long_data.clear();
662 }
663 Command::SendLongData { stmt, param, data } => {
664 stmts
665 .get_mut(&stmt)
666 .ok_or_else(|| {
667 io::Error::new(
668 io::ErrorKind::InvalidData,
669 format!(
670 "got long data packet for unknown statement {}",
671 stmt
672 ),
673 )
674 })?
675 .long_data
676 .entry(param)
677 .or_insert_with(Vec::new)
678 .extend(data);
679 }
680 Command::Close(stmt) => {
681 self.shim.on_close(stmt).await;
682 stmts.remove(&stmt);
683 }
685 Command::ListFields(_) => {
686 let ok_packet = OkResponse {
692 header: 0xfe,
693 ..Default::default()
694 };
695 writers::write_ok_packet(
696 &mut self.writer,
697 self.client_capabilities,
698 ok_packet,
699 )
700 .await?;
701 }
702 Command::Init(schema) => {
703 let w = InitWriter {
704 client_capabilities: self.client_capabilities,
705 writer: &mut self.writer,
706 };
707 self.shim
708 .on_init(
709 ::std::str::from_utf8(schema).map_err(|e| {
710 io::Error::new(io::ErrorKind::InvalidData, e)
711 })?,
712 w,
713 )
714 .await?;
715 }
716 Command::Ping => {
717 writers::write_ok_packet(
718 &mut self.writer,
719 self.client_capabilities,
720 OkResponse::default(),
721 )
722 .await?;
723 }
724 Command::Quit => {
725 break;
726 }
727 }
728 self.writer.flush_all().await?;
729 }
730 Err(_) => {
731 writers::write_ok_packet(
734 &mut self.writer,
735 self.client_capabilities,
736 OkResponse::default(),
737 )
738 .await?;
739 self.writer.flush_all().await?;
740 }
741 }
742 }
743 Ok(())
744 }
745}