opensrv_mysql/
lib.rs

1// Copyright 2021 Datafuse Labs.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Note to developers: you can find decent overviews of the protocol at
16//
17//   https://github.com/cwarden/mysql-proxy/blob/master/doc/protocol.rst
18//
19// and
20//
21//   https://mariadb.com/kb/en/library/clientserver-protocol/
22//
23// Wireshark also does a pretty good job at parsing the MySQL protocol.
24
25extern crate mysql_common as myc;
26
27use std::collections::HashMap;
28use std::io;
29use std::io::Write;
30use std::iter;
31
32use async_trait::async_trait;
33use tokio::io::AsyncRead;
34use tokio::io::AsyncWrite;
35#[cfg(feature = "tls")]
36use tokio_rustls::rustls::ServerConfig;
37
38pub use crate::myc::constants::{CapabilityFlags, ColumnFlags, ColumnType, StatusFlags};
39#[cfg(feature = "tls")]
40pub use crate::tls::{plain_run_with_options, secure_run_with_options};
41
42mod commands;
43mod errorcodes;
44mod packet_reader;
45mod packet_writer;
46mod params;
47mod resultset;
48#[cfg(feature = "tls")]
49mod tls;
50mod value;
51mod writers;
52
53#[cfg(test)]
54mod tests;
55
/// Maximum payload size of a single MySQL packet: `2^24 - 1` bytes, the largest value the
/// 3-byte payload-length field can hold.
pub const U24_MAX: usize = 16_777_215;
58
59/// Meta-information abot a single column, used either to describe a prepared statement parameter
60/// or an output column.
61#[derive(Debug, Clone, PartialEq, Eq)]
62pub struct Column {
63    /// This column's associated table.
64    ///
65    /// Note that this is *technically* the table's alias.
66    pub table: String,
67    /// This column's name.
68    ///
69    /// Note that this is *technically* the column's alias.
70    pub column: String,
71    /// This column's type>
72    pub coltype: ColumnType,
73    /// Any flags associated with this column.
74    ///
75    /// Of particular interest are `ColumnFlags::UNSIGNED_FLAG` and `ColumnFlags::NOT_NULL_FLAG`.
76    pub colflags: ColumnFlags,
77}
78
79/// QueryStatusInfo represents the status of a query.
80#[derive(Debug, Clone, PartialEq, Eq, Default)]
81pub struct OkResponse {
82    /// header
83    pub header: u8,
84    /// affected rows in update/insert
85    pub affected_rows: u64,
86    /// insert_id in update/insert
87    pub last_insert_id: u64,
88    /// StatusFlags associated with this query
89    pub status_flags: StatusFlags,
90    /// Warnings
91    pub warnings: u16,
92    /// Extra infomation
93    pub info: String,
94    /// session state change information
95    pub session_state_info: String,
96}
97
98pub use crate::errorcodes::ErrorKind;
99pub use crate::params::{ParamParser, ParamValue, Params};
100pub use crate::resultset::{InitWriter, QueryResultWriter, RowWriter, StatementMetaWriter};
101pub use crate::value::{decode::to_naive_datetime, ToMysqlValue, Value, ValueInner};
102use crate::{commands::ClientHandshake, packet_reader::PacketReader, packet_writer::PacketWriter};
103
/// Length in bytes of the authentication scramble (salt) exchanged during the handshake.
const SCRAMBLE_SIZE: usize = 20;
/// Name of the `mysql_native_password` authentication plugin; the default plugin of
/// `AsyncMysqlShim`.
const MYSQL_NATIVE_PASSWORD: &str = "mysql_native_password";
106
107#[async_trait]
108/// Implementors of this async-trait can be used to drive a MySQL-compatible database backend.
109pub trait AsyncMysqlShim<W: Send> {
110    /// The error type produced by operations on this shim.
111    ///
112    /// Must implement `From<io::Error>` so that transport-level errors can be lifted.
113    type Error: From<io::Error>;
114
115    /// Server version
116    fn version(&self) -> String {
117        // 5.1.10 because that's what Ruby's ActiveRecord requires
118        "5.1.10-alpha-msql-proxy".to_string()
119    }
120
121    /// Connection id
122    fn connect_id(&self) -> u32 {
123        u32::from_le_bytes([0x08, 0x00, 0x00, 0x00])
124    }
125
126    /// get auth plugin
127    fn default_auth_plugin(&self) -> &str {
128        MYSQL_NATIVE_PASSWORD
129    }
130
131    /// get auth plugin
132    async fn auth_plugin_for_username(&self, _user: &[u8]) -> &str {
133        MYSQL_NATIVE_PASSWORD
134    }
135
136    /// Default salt(scramble) for auth plugin
137    fn salt(&self) -> [u8; SCRAMBLE_SIZE] {
138        let bs = ";X,po_k}>o6^Wz!/kM}N".as_bytes();
139        let mut scramble: [u8; SCRAMBLE_SIZE] = [0; SCRAMBLE_SIZE];
140        for i in 0..SCRAMBLE_SIZE {
141            scramble[i] = bs[i];
142            if scramble[i] == b'\0' || scramble[i] == b'$' {
143                scramble[i] += 1;
144            }
145        }
146        scramble
147    }
148
149    /// authenticate method for the specified plugin
150    async fn authenticate(
151        &self,
152        _auth_plugin: &str,
153        _username: &[u8],
154        _salt: &[u8],
155        _auth_data: &[u8],
156    ) -> bool {
157        true
158    }
159
160    /// Called when the client issues a request to prepare `query` for later execution.
161    ///
162    /// The provided [`StatementMetaWriter`](struct.StatementMetaWriter.html) should be used to
163    /// notify the client of the statement id assigned to the prepared statement, as well as to
164    /// give metadata about the types of parameters and returned columns.
165    async fn on_prepare<'a>(
166        &'a mut self,
167        query: &'a str,
168        info: StatementMetaWriter<'a, W>,
169    ) -> Result<(), Self::Error>;
170
171    /// Called when the client executes a previously prepared statement.
172    ///
173    /// Any parameters included with the client's command is given in `params`.
174    /// A response to the query should be given using the provided
175    /// [`QueryResultWriter`](struct.QueryResultWriter.html).
176    async fn on_execute<'a>(
177        &'a mut self,
178        id: u32,
179        params: ParamParser<'a>,
180        results: QueryResultWriter<'a, W>,
181    ) -> Result<(), Self::Error>;
182
183    /// Called when the client wishes to deallocate resources associated with a previously prepared
184    /// statement.
185    async fn on_close<'a>(&'a mut self, stmt: u32)
186    where
187        W: 'async_trait;
188
189    /// Called when the client issues a query for immediate execution.
190    ///
191    /// Results should be returned using the given
192    /// [`QueryResultWriter`](struct.QueryResultWriter.html).
193    async fn on_query<'a>(
194        &'a mut self,
195        query: &'a str,
196        results: QueryResultWriter<'a, W>,
197    ) -> Result<(), Self::Error>;
198
199    /// Called when client switches database.
200    async fn on_init<'a>(
201        &'a mut self,
202        _: &'a str,
203        _: InitWriter<'a, W>,
204    ) -> Result<(), Self::Error> {
205        Ok(())
206    }
207}
208
/// Options accepted by [`AsyncMysqlIntermediary::run_with_options`].
///
/// Every option defaults to `false`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct IntermediaryOptions {
    /// Pass `USE <db>` statements to `AsyncMysqlShim::on_query` instead of handling them with
    /// `AsyncMysqlShim::on_init`.
    pub process_use_statement_on_query: bool,
    /// Reject the connection if the client does not name a database in its handshake.
    pub reject_connection_on_dbname_absence: bool,
}
217
218#[derive(Default)]
219struct StatementData {
220    long_data: HashMap<u16, Vec<u8>>,
221    bound_types: Vec<(myc::constants::ColumnType, bool)>,
222    params: u16,
223}
224
/// Number of scramble bytes sent in `auth-plugin-data-part-1` of the initial handshake packet;
/// the remaining bytes are sent in part 2.
const AUTH_PLUGIN_DATA_PART_1_LENGTH: usize = 8;
226
/// A server that speaks the MySQL/MariaDB protocol, and can delegate client commands to a backend
/// that implements [`AsyncMysqlShim`](trait.AsyncMysqlShim.html).
pub struct AsyncMysqlIntermediary<B, S: AsyncRead + Unpin, W> {
    /// Capability flags used when writing responses; set to the client's handshake capabilities
    /// by `init_after_ssl`.
    pub(crate) client_capabilities: CapabilityFlags,
    /// Copy of `IntermediaryOptions::process_use_statement_on_query`.
    process_use_statement_on_query: bool,
    /// Copy of `IntermediaryOptions::reject_connection_on_dbname_absence`.
    reject_connection_on_dbname_absence: bool,
    /// Backend that handles the client's commands.
    shim: B,
    /// Reads packets (with their sequence ids) from the client.
    reader: packet_reader::PacketReader<S>,
    /// Writes packets to the client; data goes out on `end_packet`/`flush_all`.
    writer: packet_writer::PacketWriter<W>,
}
237
238impl<B, R, W> AsyncMysqlIntermediary<B, R, W>
239where
240    W: AsyncWrite + Send + Unpin,
241    B: AsyncMysqlShim<W> + Send + Sync,
242    R: AsyncRead + Send + Unpin,
243{
244    /// Create a new server over two one-way channels and process client commands until the client
245    /// disconnects or an error occurs.
246    pub async fn run_on(shim: B, stream: R, output_stream: W) -> Result<(), B::Error> {
247        Self::run_with_options(shim, stream, output_stream, &Default::default()).await
248    }
249
250    /// Create a new server over two one-way channels and process client commands until the client
251    /// disconnects or an error occurs, with config options.
252    pub async fn run_with_options(
253        mut shim: B,
254        input_stream: R,
255        mut output_stream: W,
256        opts: &IntermediaryOptions,
257    ) -> Result<(), B::Error> {
258        let process_use_statement_on_query = opts.process_use_statement_on_query;
259        let reject_connection_on_dbname_absence = opts.reject_connection_on_dbname_absence;
260        let (_, (handshake, seq, client_capabilities, input_stream)) =
261            AsyncMysqlIntermediary::init_before_ssl(
262                &mut shim,
263                input_stream,
264                &mut output_stream,
265                #[cfg(feature = "tls")]
266                &None,
267            )
268            .await?;
269
270        let reader = PacketReader::new(input_stream);
271        let writer = PacketWriter::new(output_stream);
272
273        let mut mi = AsyncMysqlIntermediary {
274            client_capabilities,
275            process_use_statement_on_query,
276            reject_connection_on_dbname_absence,
277            shim,
278            reader,
279            writer,
280        };
281        mi.init_after_ssl(handshake, seq).await?;
282        mi.run().await
283    }
284
285    pub async fn init_before_ssl(
286        shim: &mut B,
287        input_stream: R,
288        output_stream: &mut W,
289        #[cfg(feature = "tls")] tls_conf: &Option<std::sync::Arc<ServerConfig>>,
290    ) -> Result<
291        (
292            bool,
293            (ClientHandshake, u8, CapabilityFlags, PacketReader<R>),
294        ),
295        B::Error,
296    > {
297        let mut reader = PacketReader::new(input_stream);
298        let mut writer = PacketWriter::new(output_stream);
299        // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeV10
300        writer.write_all(&[10])?; // protocol 10
301
302        writer.write_all(shim.version().as_bytes())?;
303        writer.write_all(&[0x00])?;
304
305        // connection_id (4 bytes)
306        writer.write_all(&shim.connect_id().to_le_bytes())?;
307
308        let server_capabilities = CapabilityFlags::CLIENT_PROTOCOL_41
309            | CapabilityFlags::CLIENT_SECURE_CONNECTION
310            | CapabilityFlags::CLIENT_PLUGIN_AUTH
311            | CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
312            | CapabilityFlags::CLIENT_CONNECT_WITH_DB
313            | CapabilityFlags::CLIENT_DEPRECATE_EOF;
314
315        #[cfg(feature = "tls")]
316        let server_capabilities = if tls_conf.is_some() {
317            server_capabilities | CapabilityFlags::CLIENT_SSL
318        } else {
319            server_capabilities
320        };
321
322        let server_capabilities_vec = server_capabilities.bits().to_le_bytes();
323        let default_auth_plugin = shim.default_auth_plugin();
324        let scramble = shim.salt();
325
326        writer.write_all(&scramble[0..AUTH_PLUGIN_DATA_PART_1_LENGTH])?; // auth-plugin-data-part-1
327        writer.write_all(&[0x00])?;
328
329        writer.write_all(&server_capabilities_vec[..2])?; // The lower 2 bytes of the Capabilities Flags, 0x42
330                                                          // self.writer.write_all(&[0x00, 0x42])?;
331        writer.write_all(&[0x21])?; // UTF8_GENERAL_CI
332        writer.write_all(&[0x00, 0x00])?; // status_flags
333        writer.write_all(&server_capabilities_vec[2..4])?; // The upper 2 bytes of the Capabilities Flags
334
335        if default_auth_plugin.is_empty() {
336            // no plugins
337            writer.write_all(&[0x00])?;
338        } else {
339            writer.write_all(&((scramble.len() + 1) as u8).to_le_bytes())?; // length of the combined auth_plugin_data(scramble), if auth_plugin_data_len is > 0
340        }
341        writer.write_all(&[0x00; 10][..])?; // 10 bytes filler
342
343        // Part2 of the auth_plugin_data
344        // $len=MAX(13, length of auth-plugin-data - 8)
345        writer.write_all(&scramble[AUTH_PLUGIN_DATA_PART_1_LENGTH..])?; // 12 bytes
346        writer.write_all(&[0x00])?;
347
348        // Plugin name
349        writer.write_all(default_auth_plugin.as_bytes())?;
350        writer.write_all(&[0x00])?;
351        writer.end_packet().await?;
352        writer.flush_all().await?;
353
354        let (seq, handshake) = reader.next_async().await?.ok_or_else(|| {
355            io::Error::new(
356                io::ErrorKind::ConnectionAborted,
357                "peer terminated connection",
358            )
359        })?;
360
361        let handshake = commands::client_handshake(&handshake, false)
362            .map_err(|e| match e {
363                nom::Err::Incomplete(_) => io::Error::new(
364                    io::ErrorKind::UnexpectedEof,
365                    "client sent incomplete handshake",
366                ),
367                nom::Err::Failure(nom_error) | nom::Err::Error(nom_error) => {
368                    if let nom::error::ErrorKind::Eof = nom_error.code {
369                        io::Error::new(
370                            io::ErrorKind::UnexpectedEof,
371                            format!(
372                                "client did not complete handshake; got {:?}",
373                                nom_error.input
374                            ),
375                        )
376                    } else {
377                        io::Error::new(
378                            io::ErrorKind::InvalidData,
379                            format!(
380                                "bad client handshake; got {:?} ({:?})",
381                                nom_error.input, nom_error.code
382                            ),
383                        )
384                    }
385                }
386            })?
387            .1;
388
389        writer.set_seq(seq + 1);
390
391        #[cfg(not(feature = "tls"))]
392        if handshake.capabilities.contains(CapabilityFlags::CLIENT_SSL) {
393            return Err(io::Error::new(
394                io::ErrorKind::InvalidData,
395                "client requested SSL despite us not advertising support for it",
396            )
397            .into());
398        }
399
400        #[cfg(feature = "tls")]
401        if handshake.capabilities.contains(CapabilityFlags::CLIENT_SSL) {
402            return Ok((true, (handshake, seq, server_capabilities, reader)));
403        }
404
405        Ok((false, (handshake, seq, server_capabilities, reader)))
406    }
407
408    pub async fn init_after_ssl(
409        &mut self,
410        #[cfg(feature = "tls")] mut handshake: ClientHandshake,
411        #[cfg(not(feature = "tls"))] handshake: ClientHandshake,
412        mut seq: u8,
413    ) -> Result<(), B::Error> {
414        #[cfg(feature = "tls")]
415        if handshake.capabilities.contains(CapabilityFlags::CLIENT_SSL) {
416            let (_seq, hs) = self.reader.next_async().await?.ok_or_else(|| {
417                io::Error::new(
418                    io::ErrorKind::ConnectionAborted,
419                    "peer terminated connection",
420                )
421            })?;
422            seq = _seq;
423
424            handshake = commands::client_handshake(&hs, true)
425                .map_err(|e| match e {
426                    nom::Err::Incomplete(_) => io::Error::new(
427                        io::ErrorKind::UnexpectedEof,
428                        "client sent incomplete handshake",
429                    ),
430                    nom::Err::Failure(nom_error) | nom::Err::Error(nom_error) => {
431                        if let nom::error::ErrorKind::Eof = nom_error.code {
432                            io::Error::new(
433                                io::ErrorKind::UnexpectedEof,
434                                format!(
435                                    "client did not complete handshake; got {:?}",
436                                    nom_error.input
437                                ),
438                            )
439                        } else {
440                            io::Error::new(
441                                io::ErrorKind::InvalidData,
442                                format!(
443                                    "bad client handshake; got {:?} ({:?})",
444                                    nom_error.input, nom_error.code
445                                ),
446                            )
447                        }
448                    }
449                })?
450                .1;
451
452            self.writer.set_seq(seq + 1);
453        }
454
455        let scramble = self.shim.salt();
456        {
457            if !handshake
458                .capabilities
459                .contains(CapabilityFlags::CLIENT_PROTOCOL_41)
460            {
461                let err = io::Error::new(
462                    io::ErrorKind::ConnectionAborted,
463                    "Required capability: CLIENT_PROTOCOL_41, please upgrade your MySQL client version",
464                );
465                return Err(err.into());
466            }
467
468            self.client_capabilities = handshake.capabilities;
469            let mut auth_response = handshake.auth_response.clone();
470            if let Some(username) = &handshake.username {
471                let auth_plugin_expect = self.shim.auth_plugin_for_username(username).await;
472
473                // auth switch
474                if !auth_plugin_expect.is_empty()
475                    && auth_response.is_empty()
476                    && handshake.auth_plugin != auth_plugin_expect.as_bytes()
477                {
478                    self.writer.set_seq(seq + 1);
479                    self.writer.write_all(&[0xfe])?;
480                    self.writer.write_all(auth_plugin_expect.as_bytes())?;
481                    self.writer.write_all(&[0x00])?;
482                    self.writer.write_all(&scramble)?;
483                    self.writer.write_all(&[0x00])?;
484
485                    self.writer.end_packet().await?;
486                    self.writer.flush_all().await?;
487
488                    {
489                        let (rseq, auth_response_data) =
490                            self.reader.next_async().await?.ok_or_else(|| {
491                                io::Error::new(
492                                    io::ErrorKind::ConnectionAborted,
493                                    "peer terminated connection",
494                                )
495                            })?;
496
497                        seq = rseq;
498                        auth_response = auth_response_data.to_vec();
499                    }
500                }
501
502                self.writer.set_seq(seq + 1);
503
504                if !self
505                    .shim
506                    .authenticate(
507                        auth_plugin_expect,
508                        username,
509                        &scramble,
510                        auth_response.as_slice(),
511                    )
512                    .await
513                {
514                    let err_msg = format!(
515                        "Authenticate failed, user: {:?}, auth_plugin: {:?}",
516                        String::from_utf8_lossy(username),
517                        auth_plugin_expect,
518                    );
519                    writers::write_err(
520                        ErrorKind::ER_ACCESS_DENIED_NO_PASSWORD_ERROR,
521                        err_msg.as_bytes(),
522                        &mut self.writer,
523                    )
524                    .await?;
525                    self.writer.flush_all().await?;
526                    return Err(io::Error::new(io::ErrorKind::PermissionDenied, err_msg).into());
527                }
528
529                if let Some(Ok(db)) = handshake.db.as_ref().map(|x| std::str::from_utf8(x)) {
530                    let w = InitWriter {
531                        client_capabilities: self.client_capabilities,
532                        writer: &mut self.writer,
533                    };
534                    self.shim.on_init(db, w).await?;
535                } else if self.reject_connection_on_dbname_absence {
536                    writers::write_err(
537                        ErrorKind::ER_DATABASE_NAME,
538                        "database required on connection".as_bytes(),
539                        &mut self.writer,
540                    )
541                    .await?;
542                } else {
543                    writers::write_ok_packet(
544                        &mut self.writer,
545                        self.client_capabilities,
546                        OkResponse::default(),
547                    )
548                    .await?;
549                }
550            }
551
552            self.writer.flush_all().await?;
553        };
554
555        Ok(())
556    }
557
558    async fn run(mut self) -> Result<(), B::Error> {
559        use crate::commands::Command;
560
561        let mut stmts: HashMap<u32, _> = HashMap::new();
562        while let Some((seq, packet)) = self.reader.next_async().await? {
563            self.writer.set_seq(seq + 1);
564            let res = commands::parse(&packet);
565            match res {
566                Ok(cmd) => {
567                    match cmd.1 {
568                        Command::Query(q) => {
569                            if q.starts_with(b"SELECT @@") || q.starts_with(b"select @@") {
570                                let w = QueryResultWriter::new(
571                                    &mut self.writer,
572                                    false,
573                                    self.client_capabilities,
574                                );
575
576                                let var = &q[b"SELECT @@".len()..];
577                                let var_with_at = &q[b"SELECT ".len()..];
578                                let cols = &[Column {
579                                    table: String::new(),
580                                    column: String::from_utf8_lossy(var_with_at).to_string(),
581                                    coltype: myc::constants::ColumnType::MYSQL_TYPE_LONG,
582                                    colflags: myc::constants::ColumnFlags::UNSIGNED_FLAG,
583                                }];
584
585                                match var {
586                                    b"max_allowed_packet" => {
587                                        let mut w = w.start(cols).await?;
588                                        w.write_row(iter::once(67108864u32)).await?;
589                                        w.finish().await?;
590                                    }
591                                    _ => {
592                                        self.shim
593                                            .on_query(
594                                                ::std::str::from_utf8(q).map_err(|e| {
595                                                    io::Error::new(io::ErrorKind::InvalidData, e)
596                                                })?,
597                                                w,
598                                            )
599                                            .await?;
600                                    }
601                                }
602                            } else if !self.process_use_statement_on_query
603                                && (q.starts_with(b"USE ") || q.starts_with(b"use "))
604                            {
605                                let w = InitWriter {
606                                    client_capabilities: self.client_capabilities,
607                                    writer: &mut self.writer,
608                                };
609                                let schema = ::std::str::from_utf8(&q[b"USE ".len()..])
610                                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
611                                let schema = schema.trim().trim_end_matches(';').trim_matches('`');
612                                self.shim.on_init(schema, w).await?;
613                            } else {
614                                let w = QueryResultWriter::new(
615                                    &mut self.writer,
616                                    false,
617                                    self.client_capabilities,
618                                );
619                                self.shim
620                                    .on_query(
621                                        ::std::str::from_utf8(q).map_err(|e| {
622                                            io::Error::new(io::ErrorKind::InvalidData, e)
623                                        })?,
624                                        w,
625                                    )
626                                    .await?;
627                            }
628                        }
629                        Command::Prepare(q) => {
630                            let w = StatementMetaWriter {
631                                writer: &mut self.writer,
632                                stmts: &mut stmts,
633                                client_capabilities: self.client_capabilities,
634                            };
635
636                            self.shim
637                                .on_prepare(
638                                    ::std::str::from_utf8(q).map_err(|e| {
639                                        io::Error::new(io::ErrorKind::InvalidData, e)
640                                    })?,
641                                    w,
642                                )
643                                .await?;
644                        }
645                        Command::Execute { stmt, params } => {
646                            let state = stmts.get_mut(&stmt).ok_or_else(|| {
647                                io::Error::new(
648                                    io::ErrorKind::InvalidData,
649                                    format!("asked to execute unknown statement {}", stmt),
650                                )
651                            })?;
652                            {
653                                let params = params::ParamParser::new(params, state);
654                                let w = QueryResultWriter::new(
655                                    &mut self.writer,
656                                    true,
657                                    self.client_capabilities,
658                                );
659                                self.shim.on_execute(stmt, params, w).await?;
660                            }
661                            state.long_data.clear();
662                        }
663                        Command::SendLongData { stmt, param, data } => {
664                            stmts
665                                .get_mut(&stmt)
666                                .ok_or_else(|| {
667                                    io::Error::new(
668                                        io::ErrorKind::InvalidData,
669                                        format!(
670                                            "got long data packet for unknown statement {}",
671                                            stmt
672                                        ),
673                                    )
674                                })?
675                                .long_data
676                                .entry(param)
677                                .or_insert_with(Vec::new)
678                                .extend(data);
679                        }
680                        Command::Close(stmt) => {
681                            self.shim.on_close(stmt).await;
682                            stmts.remove(&stmt);
683                            // NOTE: spec dictates no response from server
684                        }
685                        Command::ListFields(_) => {
686                            // mysql_list_fields (CommandByte::COM_FIELD_LIST / 0x04) has been deprecated in mysql 5.7
687                            // and will be removed in a future version.
688                            // The mysql command line tool issues one of these commands after switching databases with USE <DB>.
689                            // Return a invalid column definitions lead to incorrect mariadb-client behaviour,
690                            // see https://github.com/datafuselabs/databend/issues/4439
691                            let ok_packet = OkResponse {
692                                header: 0xfe,
693                                ..Default::default()
694                            };
695                            writers::write_ok_packet(
696                                &mut self.writer,
697                                self.client_capabilities,
698                                ok_packet,
699                            )
700                            .await?;
701                        }
702                        Command::Init(schema) => {
703                            let w = InitWriter {
704                                client_capabilities: self.client_capabilities,
705                                writer: &mut self.writer,
706                            };
707                            self.shim
708                                .on_init(
709                                    ::std::str::from_utf8(schema).map_err(|e| {
710                                        io::Error::new(io::ErrorKind::InvalidData, e)
711                                    })?,
712                                    w,
713                                )
714                                .await?;
715                        }
716                        Command::Ping => {
717                            writers::write_ok_packet(
718                                &mut self.writer,
719                                self.client_capabilities,
720                                OkResponse::default(),
721                            )
722                            .await?;
723                        }
724                        Command::Quit => {
725                            break;
726                        }
727                    }
728                    self.writer.flush_all().await?;
729                }
730                Err(_) => {
731                    // if parser err, we need also stay the conn,
732                    // because we can not support all command.
733                    writers::write_ok_packet(
734                        &mut self.writer,
735                        self.client_capabilities,
736                        OkResponse::default(),
737                    )
738                    .await?;
739                    self.writer.flush_all().await?;
740                }
741            }
742        }
743        Ok(())
744    }
745}