// mysql_async/conn/mod.rs

// Copyright (c) 2016 Anatoly Ikorsky
//
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. All files in the project carrying such notice may not be copied,
// modified, or distributed except according to those terms.
8
9use futures_util::FutureExt;
10pub use mysql_common::named_params;
11
12use mysql_common::{
13    constants::DEFAULT_MAX_ALLOWED_PACKET,
14    crypto,
15    io::ParseBuf,
16    packets::{
17        AuthPlugin, AuthSwitchRequest, CommonOkPacket, ErrPacket, HandshakePacket,
18        HandshakeResponse, OkPacket, OkPacketDeserializer, OldAuthSwitchRequest, OldEofPacket,
19        ResultSetTerminator,
20    },
21    proto::MySerialize,
22    row::Row,
23};
24
25use std::{
26    borrow::Cow,
27    fmt,
28    future::Future,
29    mem::{self, replace},
30    pin::Pin,
31    str::FromStr,
32    sync::Arc,
33    time::{Duration, Instant},
34};
35
36use crate::{
37    buffer_pool::PooledBuf,
38    conn::{pool::Pool, stmt_cache::StmtCache},
39    consts::{CapabilityFlags, Command, StatusFlags},
40    error::*,
41    io::Stream,
42    opts::Opts,
43    queryable::{
44        query_result::{QueryResult, ResultSetMeta},
45        transaction::TxStatus,
46        BinaryProtocol, Queryable, TextProtocol,
47    },
48    ChangeUserOpts, InfileData, OptsBuilder,
49};
50
51use self::routines::Routine;
52
53#[cfg(feature = "binlog")]
54pub mod binlog_stream;
55pub mod pool;
56pub mod routines;
57pub mod stmt_cache;
58
/// Default `wait_timeout`, in seconds: 8 hours (28800 s). Presumably mirrors the MySQL
/// server's default `wait_timeout` — confirm against the code that consumes it.
const DEFAULT_WAIT_TIMEOUT: usize = 8 * 60 * 60;
60
61/// Helper that asynchronously disconnects the givent connection on the default tokio executor.
62fn disconnect(mut conn: Conn) {
63    let disconnected = conn.inner.disconnected;
64
65    // Mark conn as disconnected.
66    conn.inner.disconnected = true;
67
68    if !disconnected {
69        // We shouldn't call tokio::spawn if unwinding
70        if std::thread::panicking() {
71            return;
72        }
73
74        // Server will report broken connection if spawn fails.
75        // this might fail if, say, the runtime is shutting down, but we've done what we could
76        if let Ok(handle) = tokio::runtime::Handle::try_current() {
77            handle.spawn(async move {
78                if let Ok(conn) = conn.cleanup_for_pool().await {
79                    let _ = conn.disconnect().await;
80                }
81            });
82        }
83    }
84}
85
/// A result set that has not been fully consumed yet.
#[derive(Debug, Clone)]
pub(crate) enum PendingResult {
    /// There is a pending result set whose metadata has not been handed out yet.
    Pending(ResultSetMeta),
    /// Result set metadata was taken (shared via `Arc`, see `Conn::take_pending_result`)
    /// but the result set is not yet consumed.
    Taken(Arc<ResultSetMeta>),
}
94
/// Internal state of a MySQL connection (stored boxed inside `Conn`).
struct ConnInner {
    /// I/O stream; `None` once it has been taken (closed or dropped after an I/O error).
    stream: Option<Stream>,
    /// Connection id reported by the server handshake.
    id: u32,
    /// Set when the handshake carried a parseable MariaDB server version.
    is_mariadb: bool,
    /// Server version parsed from the handshake; `(0, 0, 0)` if it could not be parsed.
    version: (u16, u16, u16),
    /// Socket path taken from the connection options, if any.
    socket: Option<String>,
    /// Capability flags: the requested ones initially, then the intersection with the
    /// server's flags once the handshake has been handled.
    capabilities: CapabilityFlags,
    /// Status flags from the handshake or the last OK packet; cleared on an ERR packet.
    status: StatusFlags,
    /// Last OK packet received; cleared when an ERR packet is handled.
    last_ok_packet: Option<OkPacket<'static>>,
    /// Last server error received; cleared when an OK packet is handled.
    last_err_packet: Option<mysql_common::packets::ServerError<'static>>,
    /// Pool associated with this connection, if any.
    pool: Option<Pool>,
    /// `Ok(Some(_))` while a result set is pending, `Ok(None)` when there is none, or a
    /// stored server error that is handed out once to the next consumer.
    pending_result: std::result::Result<Option<PendingResult>, ServerError>,
    /// Current transaction status.
    tx_status: TxStatus,
    /// Whether to reset the connection when it is returned to a pool
    /// (set via `Conn::reset_connection`).
    reset_upon_returning_to_a_pool: bool,
    /// Connection options.
    opts: Opts,
    /// Deadline computed from the pool options when the state is created.
    ttl_deadline: Option<Instant>,
    /// Time of the last I/O, updated by `Conn::touch`.
    last_io: Instant,
    /// Wait timeout; initialized to zero in `ConnInner::empty`.
    wait_timeout: Duration,
    /// Prepared-statement cache, sized by `Opts::stmt_cache_size`.
    stmt_cache: StmtCache,
    /// Authentication scramble from the handshake or from the latest auth-switch request.
    nonce: Vec<u8>,
    /// Authentication plugin currently in use.
    auth_plugin: AuthPlugin<'static>,
    /// Set once an auth-switch request has been handled; only one switch is performed.
    auth_switched: bool,
    /// Key requested from the server during `caching_sha2_password` full authentication and
    /// used to encrypt the password when the transport is neither TLS nor a socket.
    server_key: Option<Vec<u8>>,
    /// Connection is already disconnected.
    pub(crate) disconnected: bool,
    /// One-time connection-level infile handler (set by `Conn::set_infile_handler`).
    infile_handler:
        Option<Pin<Box<dyn Future<Output = crate::Result<InfileData>> + Send + Sync + 'static>>>,
}
125
126impl fmt::Debug for ConnInner {
127    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128        f.debug_struct("Conn")
129            .field("connection id", &self.id)
130            .field("server version", &self.version)
131            .field("pool", &self.pool)
132            .field("pending_result", &self.pending_result)
133            .field("tx_status", &self.tx_status)
134            .field("stream", &self.stream)
135            .field("options", &self.opts)
136            .field("server_key", &self.server_key)
137            .field("auth_plugin", &self.auth_plugin)
138            .finish()
139    }
140}
141
142impl ConnInner {
143    /// Constructs an empty connection.
144    fn empty(opts: Opts) -> ConnInner {
145        let ttl_deadline = opts.pool_opts().new_connection_ttl_deadline();
146        ConnInner {
147            capabilities: opts.get_capabilities(),
148            status: StatusFlags::empty(),
149            last_ok_packet: None,
150            last_err_packet: None,
151            stream: None,
152            is_mariadb: false,
153            version: (0, 0, 0),
154            id: 0,
155            pending_result: Ok(None),
156            pool: None,
157            tx_status: TxStatus::None,
158            last_io: Instant::now(),
159            wait_timeout: Duration::from_secs(0),
160            stmt_cache: StmtCache::new(opts.stmt_cache_size()),
161            socket: opts.socket().map(Into::into),
162            opts,
163            ttl_deadline,
164            nonce: Vec::default(),
165            auth_plugin: AuthPlugin::MysqlNativePassword,
166            auth_switched: false,
167            disconnected: false,
168            server_key: None,
169            infile_handler: None,
170            reset_upon_returning_to_a_pool: false,
171        }
172    }
173
174    /// Returns mutable reference to a connection stream.
175    ///
176    /// Returns `DriverError::ConnectionClosed` if there is no stream.
177    fn stream_mut(&mut self) -> Result<&mut Stream> {
178        self.stream
179            .as_mut()
180            .ok_or_else(|| DriverError::ConnectionClosed.into())
181    }
182}
183
/// MySql server connection.
///
/// All connection state lives in a boxed `ConnInner`, so moving a `Conn` only moves a pointer.
#[derive(Debug)]
pub struct Conn {
    inner: Box<ConnInner>,
}
189
190impl Conn {
191    /// Returns connection identifier.
192    pub fn id(&self) -> u32 {
193        self.inner.id
194    }
195
196    /// Returns the ID generated by a query (usually `INSERT`) on a table with a column having the
197    /// `AUTO_INCREMENT` attribute. Returns `None` if there was no previous query on the connection
198    /// or if the query did not update an AUTO_INCREMENT value.
199    pub fn last_insert_id(&self) -> Option<u64> {
200        self.inner
201            .last_ok_packet
202            .as_ref()
203            .and_then(|ok| ok.last_insert_id())
204    }
205
206    /// Returns the number of rows affected by the last `INSERT`, `UPDATE`, `REPLACE` or `DELETE`
207    /// query.
208    pub fn affected_rows(&self) -> u64 {
209        self.inner
210            .last_ok_packet
211            .as_ref()
212            .map(|ok| ok.affected_rows())
213            .unwrap_or_default()
214    }
215
216    /// Text information, as reported by the server in the last OK packet, or an empty string.
217    pub fn info(&self) -> Cow<'_, str> {
218        self.inner
219            .last_ok_packet
220            .as_ref()
221            .and_then(|ok| ok.info_str())
222            .unwrap_or_else(|| "".into())
223    }
224
225    /// Number of warnings, as reported by the server in the last OK packet, or `0`.
226    pub fn get_warnings(&self) -> u16 {
227        self.inner
228            .last_ok_packet
229            .as_ref()
230            .map(|ok| ok.warnings())
231            .unwrap_or_default()
232    }
233
234    /// Returns a reference to the last OK packet.
235    pub fn last_ok_packet(&self) -> Option<&OkPacket<'static>> {
236        self.inner.last_ok_packet.as_ref()
237    }
238
239    /// Turns on/off automatic connection reset (see [`crate::PoolOpts::with_reset_connection`]).
240    ///
241    /// Only makes sense for pooled connections.
242    pub fn reset_connection(&mut self, reset_connection: bool) {
243        self.inner.reset_upon_returning_to_a_pool = reset_connection;
244    }
245
246    pub(crate) fn stream_mut(&mut self) -> Result<&mut Stream> {
247        self.inner.stream_mut()
248    }
249
250    pub(crate) fn capabilities(&self) -> CapabilityFlags {
251        self.inner.capabilities
252    }
253
254    /// Will update last IO time for this connection.
255    pub(crate) fn touch(&mut self) {
256        self.inner.last_io = Instant::now();
257    }
258
259    /// Will set packet sequence id to `0`.
260    pub(crate) fn reset_seq_id(&mut self) {
261        if let Some(stream) = self.inner.stream.as_mut() {
262            stream.reset_seq_id();
263        }
264    }
265
266    /// Will syncronize sequence ids between compressed and uncompressed codecs.
267    pub(crate) fn sync_seq_id(&mut self) {
268        if let Some(stream) = self.inner.stream.as_mut() {
269            stream.sync_seq_id();
270        }
271    }
272
273    /// Handles OK packet.
274    pub(crate) fn handle_ok(&mut self, ok_packet: OkPacket<'static>) {
275        self.inner.status = ok_packet.status_flags();
276        self.inner.last_err_packet = None;
277        self.inner.last_ok_packet = Some(ok_packet);
278    }
279
280    /// Handles ERR packet.
281    pub(crate) fn handle_err(&mut self, err_packet: ErrPacket<'_>) -> Result<()> {
282        match err_packet {
283            ErrPacket::Error(err) => {
284                self.inner.status = StatusFlags::empty();
285                self.inner.last_ok_packet = None;
286                self.inner.last_err_packet = Some(err.clone().into_owned());
287                Err(Error::from(err))
288            }
289            ErrPacket::Progress(_) => Ok(()),
290        }
291    }
292
293    /// Returns the current transaction status.
294    pub(crate) fn get_tx_status(&self) -> TxStatus {
295        self.inner.tx_status
296    }
297
298    /// Sets the given transaction status for this connection.
299    pub(crate) fn set_tx_status(&mut self, tx_status: TxStatus) {
300        self.inner.tx_status = tx_status;
301    }
302
303    /// Returns pending result metadata, if any.
304    ///
305    /// If `Some(_)`, then result is not yet consumed.
306    pub(crate) fn use_pending_result(
307        &mut self,
308    ) -> std::result::Result<Option<&PendingResult>, ServerError> {
309        if let Err(ref e) = self.inner.pending_result {
310            let e = e.clone();
311            self.inner.pending_result = Ok(None);
312            Err(e)
313        } else {
314            Ok(self.inner.pending_result.as_ref().unwrap().as_ref())
315        }
316    }
317
318    pub(crate) fn get_pending_result(
319        &self,
320    ) -> std::result::Result<Option<&PendingResult>, &ServerError> {
321        self.inner.pending_result.as_ref().map(|x| x.as_ref())
322    }
323
324    pub(crate) fn has_pending_result(&self) -> bool {
325        self.inner.pending_result.is_err() || matches!(self.inner.pending_result, Ok(Some(_)))
326    }
327
328    /// Sets the given pening result metadata for this connection. Returns the previous value.
329    pub(crate) fn set_pending_result(
330        &mut self,
331        meta: Option<ResultSetMeta>,
332    ) -> std::result::Result<Option<PendingResult>, ServerError> {
333        replace(
334            &mut self.inner.pending_result,
335            Ok(meta.map(PendingResult::Pending)),
336        )
337    }
338
339    pub(crate) fn set_pending_result_error(
340        &mut self,
341        error: ServerError,
342    ) -> std::result::Result<Option<PendingResult>, ServerError> {
343        replace(&mut self.inner.pending_result, Err(error))
344    }
345
346    /// Gives the currently pending result to a caller for consumption.
347    pub(crate) fn take_pending_result(
348        &mut self,
349    ) -> std::result::Result<Option<Arc<ResultSetMeta>>, ServerError> {
350        let mut output = None;
351
352        self.inner.pending_result = match replace(&mut self.inner.pending_result, Ok(None))? {
353            Some(PendingResult::Pending(x)) => {
354                let meta = Arc::new(x);
355                output = Some(meta.clone());
356                Ok(Some(PendingResult::Taken(meta)))
357            }
358            x => Ok(x),
359        };
360
361        Ok(output)
362    }
363
364    /// Returns current status flags.
365    pub(crate) fn status(&self) -> StatusFlags {
366        self.inner.status
367    }
368
369    pub(crate) async fn routine<'a, F, T>(&mut self, mut f: F) -> crate::Result<T>
370    where
371        F: Routine<T> + 'a,
372    {
373        self.inner.disconnected = true;
374        let result = f.call(&mut *self).await;
375        match result {
376            result @ Ok(_) | result @ Err(crate::Error::Server(_)) => {
377                // either OK or non-fatal error
378                self.inner.disconnected = false;
379                result
380            }
381            Err(err) => {
382                if self.inner.stream.is_some() {
383                    self.take_stream().close().await?;
384                }
385                Err(err)
386            }
387        }
388    }
389
390    /// Returns server version.
391    pub fn server_version(&self) -> (u16, u16, u16) {
392        self.inner.version
393    }
394
395    /// Returns connection options.
396    pub fn opts(&self) -> &Opts {
397        &self.inner.opts
398    }
399
400    /// Setup _local_ `LOCAL INFILE` handler (see ["LOCAL INFILE Handlers"][2] section
401    /// of the crate-level docs).
402    ///
403    /// It'll overwrite existing _local_ handler, if any.
404    ///
405    /// [2]: ../mysql_async/#local-infile-handlers
406    pub fn set_infile_handler<T>(&mut self, handler: T)
407    where
408        T: Future<Output = crate::Result<InfileData>>,
409        T: Send + Sync + 'static,
410    {
411        self.inner.infile_handler = Some(Box::pin(handler));
412    }
413
414    fn take_stream(&mut self) -> Stream {
415        self.inner.stream.take().unwrap()
416    }
417
418    /// Disconnects this connection from server.
419    pub async fn disconnect(mut self) -> Result<()> {
420        if !self.inner.disconnected {
421            self.inner.disconnected = true;
422            self.write_command_data(Command::COM_QUIT, &[]).await?;
423            let stream = self.take_stream();
424            stream.close().await?;
425        }
426        Ok(())
427    }
428
429    /// Closes the connection.
430    async fn close_conn(mut self) -> Result<()> {
431        self = self.cleanup_for_pool().await?;
432        self.disconnect().await
433    }
434
435    /// Returns true if io stream is encrypted.
436    fn is_secure(&self) -> bool {
437        #[cfg(any(
438            feature = "native-tls-tls",
439            feature = "rustls-tls",
440            feature = "wasmedge-tls"
441        ))]
442        {
443            self.inner
444                .stream
445                .as_ref()
446                .map(|x| x.is_secure())
447                .unwrap_or_default()
448        }
449
450        #[cfg(not(any(
451            feature = "native-tls-tls",
452            feature = "rustls-tls",
453            feature = "wasmedge-tls"
454        )))]
455        false
456    }
457
458    /// Returns true if io stream is socket.
459    fn is_socket(&self) -> bool {
460        #[cfg(unix)]
461        {
462            self.inner
463                .stream
464                .as_ref()
465                .map(|x| x.is_socket())
466                .unwrap_or_default()
467        }
468
469        #[cfg(not(unix))]
470        false
471    }
472
473    /// Hacky way to move connection through &mut. `self` becomes unusable.
474    fn take(&mut self) -> Conn {
475        mem::replace(self, Conn::empty(Default::default()))
476    }
477
478    fn empty(opts: Opts) -> Self {
479        Self {
480            inner: Box::new(ConnInner::empty(opts)),
481        }
482    }
483
484    /// Set `io::Stream` options as defined in the `Opts` of the connection.
485    ///
486    /// Requires that self.inner.stream is Some
487    fn setup_stream(&mut self) -> Result<()> {
488        debug_assert!(self.inner.stream.is_some());
489        if let Some(stream) = self.inner.stream.as_mut() {
490            stream.set_tcp_nodelay(self.inner.opts.tcp_nodelay())?;
491        }
492        Ok(())
493    }
494
495    async fn handle_handshake(&mut self) -> Result<()> {
496        let packet = self.read_packet().await?;
497        let handshake = ParseBuf(&packet).parse::<HandshakePacket>(())?;
498
499        // Handshake scramble is always 21 bytes length (20 + zero terminator)
500        self.inner.nonce = {
501            let mut nonce = Vec::from(handshake.scramble_1_ref());
502            nonce.extend_from_slice(handshake.scramble_2_ref().unwrap_or(&[][..]));
503            // Trim zero terminator. Fill with zeroes if nonce
504            // is somehow smaller than 20 bytes (this matches the server behavior).
505            nonce.resize(20, 0);
506            nonce
507        };
508
509        self.inner.capabilities = handshake.capabilities() & self.inner.opts.get_capabilities();
510        self.inner.version = handshake
511            .maria_db_server_version_parsed()
512            .map(|version| {
513                self.inner.is_mariadb = true;
514                version
515            })
516            .or_else(|| handshake.server_version_parsed())
517            .unwrap_or((0, 0, 0));
518        self.inner.id = handshake.connection_id();
519        self.inner.status = handshake.status_flags();
520
521        // Allow only CachingSha2Password and MysqlNativePassword here
522        // because sha256_password is deprecated and other plugins won't
523        // appear here.
524        self.inner.auth_plugin = match handshake.auth_plugin() {
525            Some(AuthPlugin::CachingSha2Password) => AuthPlugin::CachingSha2Password,
526            _ => AuthPlugin::MysqlNativePassword,
527        };
528
529        Ok(())
530    }
531
532    async fn switch_to_ssl_if_needed(&mut self) -> Result<()> {
533        if self
534            .inner
535            .opts
536            .get_capabilities()
537            .contains(CapabilityFlags::CLIENT_SSL)
538        {
539            if !self
540                .inner
541                .capabilities
542                .contains(CapabilityFlags::CLIENT_SSL)
543            {
544                return Err(DriverError::NoClientSslFlagFromServer.into());
545            }
546
547            let collation = if self.inner.version >= (5, 5, 3) {
548                mysql_common::constants::UTF8MB4_GENERAL_CI
549            } else {
550                crate::consts::UTF8MB4_GENERAL_CI
551            };
552
553            let ssl_request = mysql_common::packets::SslRequest::new(
554                self.inner.capabilities,
555                DEFAULT_MAX_ALLOWED_PACKET as u32,
556                collation as u8,
557            );
558            self.write_struct(&ssl_request).await?;
559            let conn = self;
560            let ssl_opts = conn.opts().ssl_opts().cloned().expect("unreachable");
561            let domain = conn.opts().ip_or_hostname().into();
562            conn.stream_mut()?.make_secure(domain, ssl_opts).await?;
563            Ok(())
564        } else {
565            Ok(())
566        }
567    }
568
569    async fn do_handshake_response(&mut self) -> Result<()> {
570        let auth_data = self
571            .inner
572            .auth_plugin
573            .gen_data(self.inner.opts.pass(), &self.inner.nonce);
574
575        let handshake_response = HandshakeResponse::new(
576            auth_data.as_deref(),
577            self.inner.version,
578            self.inner.opts.user().map(|x| x.as_bytes()),
579            self.inner.opts.db_name().map(|x| x.as_bytes()),
580            Some(self.inner.auth_plugin.borrow()),
581            self.capabilities(),
582            Default::default(), // TODO: Add support
583        );
584
585        // Serialize here to satisfy borrow checker.
586        let mut buf = crate::BUFFER_POOL.get();
587        handshake_response.serialize(buf.as_mut());
588
589        self.write_packet(buf).await?;
590        Ok(())
591    }
592
593    async fn perform_auth_switch(
594        &mut self,
595        auth_switch_request: AuthSwitchRequest<'_>,
596    ) -> Result<()> {
597        if !self.inner.auth_switched {
598            self.inner.auth_switched = true;
599            self.inner.nonce = auth_switch_request.plugin_data().to_vec();
600
601            if matches!(
602                auth_switch_request.auth_plugin(),
603                AuthPlugin::MysqlOldPassword
604            ) && self.inner.opts.secure_auth()
605            {
606                return Err(DriverError::MysqlOldPasswordDisabled.into());
607            }
608
609            self.inner.auth_plugin = auth_switch_request.auth_plugin().clone().into_owned();
610
611            let plugin_data = match &self.inner.auth_plugin {
612                x @ AuthPlugin::CachingSha2Password => {
613                    x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
614                }
615                x @ AuthPlugin::MysqlNativePassword => {
616                    x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
617                }
618                x @ AuthPlugin::MysqlOldPassword => {
619                    if self.inner.opts.secure_auth() {
620                        return Err(DriverError::MysqlOldPasswordDisabled.into());
621                    } else {
622                        x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
623                    }
624                }
625                x @ AuthPlugin::MysqlClearPassword => {
626                    if self.inner.opts.enable_cleartext_plugin() {
627                        x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
628                    } else {
629                        return Err(DriverError::CleartextPluginDisabled.into());
630                    }
631                }
632                x @ AuthPlugin::Other(_) => x.gen_data(self.inner.opts.pass(), &self.inner.nonce),
633            };
634
635            if let Some(plugin_data) = plugin_data {
636                self.write_struct(&plugin_data.into_owned()).await?;
637            } else {
638                self.write_packet(crate::BUFFER_POOL.get()).await?;
639            }
640
641            self.continue_auth().await?;
642
643            Ok(())
644        } else {
645            unreachable!("auth_switched flag should be checked by caller")
646        }
647    }
648
649    fn continue_auth(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
650        // NOTE: we need to box this since it may recurse
651        // see https://github.com/rust-lang/rust/issues/46415#issuecomment-528099782
652        Box::pin(async move {
653            match self.inner.auth_plugin {
654                AuthPlugin::MysqlNativePassword | AuthPlugin::MysqlOldPassword => {
655                    self.continue_mysql_native_password_auth().await?;
656                    Ok(())
657                }
658                AuthPlugin::CachingSha2Password => {
659                    self.continue_caching_sha2_password_auth().await?;
660                    Ok(())
661                }
662                AuthPlugin::MysqlClearPassword => {
663                    if self.inner.opts.enable_cleartext_plugin() {
664                        self.continue_mysql_native_password_auth().await?;
665                        Ok(())
666                    } else {
667                        Err(DriverError::CleartextPluginDisabled.into())
668                    }
669                }
670                AuthPlugin::Other(ref name) => Err(DriverError::UnknownAuthPlugin {
671                    name: String::from_utf8_lossy(name.as_ref()).to_string(),
672                }
673                .into()),
674            }
675        })
676    }
677
678    fn switch_to_compression(&mut self) -> Result<()> {
679        if self
680            .capabilities()
681            .contains(CapabilityFlags::CLIENT_COMPRESS)
682        {
683            if let Some(compression) = self.inner.opts.compression() {
684                if let Some(stream) = self.inner.stream.as_mut() {
685                    stream.compress(compression);
686                }
687            }
688        }
689        Ok(())
690    }
691
692    async fn continue_caching_sha2_password_auth(&mut self) -> Result<()> {
693        let packet = self.read_packet().await?;
694        match packet.first() {
695            Some(0x00) => {
696                // ok packet for empty password
697                Ok(())
698            }
699            Some(0x01) => match packet.get(1) {
700                Some(0x03) => {
701                    // auth ok
702                    self.drop_packet().await
703                }
704                Some(0x04) => {
705                    let pass = self.inner.opts.pass().unwrap_or_default();
706                    let mut pass = crate::BUFFER_POOL.get_with(pass.as_bytes());
707                    pass.as_mut().push(0);
708
709                    if self.is_secure() || self.is_socket() {
710                        self.write_packet(pass).await?;
711                    } else {
712                        if self.inner.server_key.is_none() {
713                            self.write_bytes(&[0x02][..]).await?;
714                            let packet = self.read_packet().await?;
715                            self.inner.server_key = Some(packet[1..].to_vec());
716                        }
717                        for (i, byte) in pass.as_mut().iter_mut().enumerate() {
718                            *byte ^= self.inner.nonce[i % self.inner.nonce.len()];
719                        }
720                        let encrypted_pass = crypto::encrypt(
721                            &pass,
722                            self.inner.server_key.as_deref().expect("unreachable"),
723                        );
724                        self.write_bytes(&encrypted_pass).await?;
725                    };
726                    self.drop_packet().await?;
727                    Ok(())
728                }
729                _ => Err(DriverError::UnexpectedPacket {
730                    payload: packet.to_vec(),
731                }
732                .into()),
733            },
734            Some(0xfe) if !self.inner.auth_switched => {
735                let auth_switch_request = ParseBuf(&packet).parse::<AuthSwitchRequest>(())?;
736                self.perform_auth_switch(auth_switch_request).await?;
737                Ok(())
738            }
739            _ => Err(DriverError::UnexpectedPacket {
740                payload: packet.to_vec(),
741            }
742            .into()),
743        }
744    }
745
746    async fn continue_mysql_native_password_auth(&mut self) -> Result<()> {
747        let packet = self.read_packet().await?;
748        match packet.first() {
749            Some(0x00) => Ok(()),
750            Some(0xfe) if !self.inner.auth_switched => {
751                let auth_switch = if packet.len() > 1 {
752                    ParseBuf(&packet).parse(())?
753                } else {
754                    let _ = ParseBuf(&packet).parse::<OldAuthSwitchRequest>(())?;
755                    // map OldAuthSwitch to AuthSwitch with mysql_old_password plugin
756                    AuthSwitchRequest::new(
757                        "mysql_old_password".as_bytes(),
758                        self.inner.nonce.clone(),
759                    )
760                };
761                self.perform_auth_switch(auth_switch).await
762            }
763            _ => Err(DriverError::UnexpectedPacket {
764                payload: packet.to_vec(),
765            }
766            .into()),
767        }
768    }
769
770    /// Returns `true` for ProgressReport packet.
771    fn handle_packet(&mut self, packet: &PooledBuf) -> Result<bool> {
772        let ok_packet = if self.has_pending_result() {
773            if self
774                .capabilities()
775                .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF)
776            {
777                ParseBuf(packet)
778                    .parse::<OkPacketDeserializer<ResultSetTerminator>>(self.capabilities())
779                    .map(|x| x.into_inner())
780            } else {
781                ParseBuf(packet)
782                    .parse::<OkPacketDeserializer<OldEofPacket>>(self.capabilities())
783                    .map(|x| x.into_inner())
784            }
785        } else {
786            ParseBuf(packet)
787                .parse::<OkPacketDeserializer<CommonOkPacket>>(self.capabilities())
788                .map(|x| x.into_inner())
789        };
790
791        if let Ok(ok_packet) = ok_packet {
792            self.handle_ok(ok_packet.into_owned());
793        } else {
794            let err_packet = ParseBuf(packet).parse::<ErrPacket>(self.capabilities());
795            if let Ok(err_packet) = err_packet {
796                self.handle_err(err_packet)?;
797                return Ok(true);
798            }
799        }
800
801        Ok(false)
802    }
803
804    pub(crate) async fn read_packet(&mut self) -> Result<PooledBuf> {
805        loop {
806            let packet = crate::io::ReadPacket::new(&mut *self)
807                .await
808                .map_err(|io_err| {
809                    self.inner.stream.take();
810                    self.inner.disconnected = true;
811                    Error::from(io_err)
812                })?;
813            if self.handle_packet(&packet)? {
814                // ignore progress report
815                continue;
816            } else {
817                return Ok(packet);
818            }
819        }
820    }
821
822    /// Returns future that reads packets from a server.
823    pub(crate) async fn read_packets(&mut self, n: usize) -> Result<Vec<PooledBuf>> {
824        let mut packets = Vec::with_capacity(n);
825        for _ in 0..n {
826            packets.push(self.read_packet().await?);
827        }
828        Ok(packets)
829    }
830
831    pub(crate) async fn write_packet(&mut self, data: PooledBuf) -> Result<()> {
832        crate::io::WritePacket::new(&mut *self, data)
833            .await
834            .map_err(|io_err| {
835                self.inner.stream.take();
836                self.inner.disconnected = true;
837                From::from(io_err)
838            })
839    }
840
841    /// Writes bytes to a server.
842    pub(crate) async fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> {
843        let buf = crate::BUFFER_POOL.get_with(bytes);
844        self.write_packet(buf).await
845    }
846
847    /// Sends a serializable structure to a server.
848    pub(crate) async fn write_struct<T: MySerialize>(&mut self, x: &T) -> Result<()> {
849        let mut buf = crate::BUFFER_POOL.get();
850        x.serialize(buf.as_mut());
851        self.write_packet(buf).await
852    }
853
854    /// Sends a command to a server.
855    pub(crate) async fn write_command<T: MySerialize>(&mut self, cmd: &T) -> Result<()> {
856        self.clean_dirty().await?;
857        self.reset_seq_id();
858        self.write_struct(cmd).await
859    }
860
861    /// Returns future that sends full command body to a server.
862    pub(crate) async fn write_command_raw(&mut self, body: PooledBuf) -> Result<()> {
863        debug_assert!(!body.is_empty());
864        self.clean_dirty().await?;
865        self.reset_seq_id();
866        self.write_packet(body).await
867    }
868
869    /// Returns future that writes command to a server.
870    pub(crate) async fn write_command_data<T>(&mut self, cmd: Command, cmd_data: T) -> Result<()>
871    where
872        T: AsRef<[u8]>,
873    {
874        let cmd_data = cmd_data.as_ref();
875        let mut buf = crate::BUFFER_POOL.get();
876        let body = buf.as_mut();
877        body.push(cmd as u8);
878        body.extend_from_slice(cmd_data);
879        self.write_command_raw(buf).await
880    }
881
882    async fn drop_packet(&mut self) -> Result<()> {
883        self.read_packet().await?;
884        Ok(())
885    }
886
887    async fn run_init_commands(&mut self) -> Result<()> {
888        let mut init = self.inner.opts.init().to_vec();
889
890        while let Some(query) = init.pop() {
891            self.query_drop(query).await?;
892        }
893
894        Ok(())
895    }
896
897    async fn run_setup_commands(&mut self) -> Result<()> {
898        let mut setup = self.inner.opts.setup().to_vec();
899
900        while let Some(query) = setup.pop() {
901            self.query_drop(query).await?;
902        }
903
904        Ok(())
905    }
906
907    /// Returns a future that resolves to [`Conn`].
908    pub fn new<T: Into<Opts>>(opts: T) -> crate::BoxFuture<'static, Conn> {
909        let opts = opts.into();
910        async move {
911            let mut conn = Conn::empty(opts.clone());
912
913            let stream = if let Some(_path) = opts.socket() {
914                #[cfg(unix)]
915                {
916                    Stream::connect_socket(_path.to_owned()).await?
917                }
918                #[cfg(target_os = "windows")]
919                return Err(crate::DriverError::NamedPipesDisabled.into());
920                #[cfg(target_os = "wasi")]
921                return Err(crate::DriverError::NamedPipesDisabled.into());
922            } else {
923                let keepalive = opts
924                    .tcp_keepalive()
925                    .map(|x| std::time::Duration::from_millis(x.into()));
926                Stream::connect_tcp(opts.hostport_or_url(), keepalive).await?
927            };
928
929            conn.inner.stream = Some(stream);
930            conn.setup_stream()?;
931            conn.handle_handshake().await?;
932            conn.switch_to_ssl_if_needed().await?;
933            conn.do_handshake_response().await?;
934            conn.continue_auth().await?;
935            conn.switch_to_compression()?;
936            conn.read_settings().await?;
937            conn.reconnect_via_socket_if_needed().await?;
938            conn.run_init_commands().await?;
939            conn.run_setup_commands().await?;
940
941            Ok(conn)
942        }
943        .boxed()
944    }
945
946    /// Returns a future that resolves to [`Conn`].
947    pub async fn from_url<T: AsRef<str>>(url: T) -> Result<Conn> {
948        Conn::new(Opts::from_str(url.as_ref())?).await
949    }
950
951    /// Will try to reconnect via socket using socket address in `self.inner.socket`.
952    ///
953    /// Won't try to reconnect if socket connection is already enforced in [`Opts`].
954    async fn reconnect_via_socket_if_needed(&mut self) -> Result<()> {
955        if let Some(socket) = self.inner.socket.as_ref() {
956            let opts = self.inner.opts.clone();
957            if opts.socket().is_none() {
958                let opts = OptsBuilder::from_opts(opts).socket(Some(&**socket));
959                if let Ok(conn) = Conn::new(opts).await {
960                    let old_conn = std::mem::replace(self, conn);
961                    // tidy up the old connection
962                    old_conn.close_conn().await?;
963                }
964            }
965        }
966        Ok(())
967    }
968
    /// Configures the connection based on server settings. In particular:
    ///
    /// * It reads and stores the socket address inside the connection, unless
    ///   `self.inner.socket` is already set or `prefer_socket` is `false`.
    ///
    /// * It reads `max_allowed_packet` from the server and applies it to the stream, unless it's
    ///   already in [`Opts`] (then the [`Opts`] value is applied instead).
    ///
    /// * It reads and stores `wait_timeout` in the connection, unless it's already in [`Opts`]
    ///   (then the [`Opts`] value is stored instead).
    ///
    /// Everything that has to come from the server is fetched with a single
    /// `SELECT @@var1,@@var2,...` query.
    async fn read_settings(&mut self) -> Result<()> {
        // What to do with a single setting: fetch it from the server, or apply a known value.
        enum Action {
            Load(Cfg),
            Apply(CfgData),
        }

        // A setting value that is already known (taken from `Opts`).
        enum CfgData {
            MaxAllowedPacket(usize),
            WaitTimeout(usize),
        }

        impl CfgData {
            // Stores the known value in the connection (or its stream).
            fn apply(&self, conn: &mut Conn) {
                match self {
                    Self::MaxAllowedPacket(value) => {
                        if let Some(stream) = conn.inner.stream.as_mut() {
                            stream.set_max_allowed_packet(*value);
                        }
                    }
                    Self::WaitTimeout(value) => {
                        conn.inner.wait_timeout = Duration::from_secs(*value as u64);
                    }
                }
            }
        }

        // A setting that must be loaded from the server.
        enum Cfg {
            Socket,
            MaxAllowedPacket,
            WaitTimeout,
        }

        impl Cfg {
            // Name of the server system variable holding this setting.
            const fn name(&self) -> &'static str {
                match self {
                    Self::Socket => "@@socket",
                    Self::MaxAllowedPacket => "@@max_allowed_packet",
                    Self::WaitTimeout => "@@wait_timeout",
                }
            }

            // Applies a value loaded from the server. A missing value (or one that converts to
            // `None`) falls back to the defaults used below; for `Socket` it leaves no socket.
            fn apply(&self, conn: &mut Conn, value: Option<crate::Value>) {
                match self {
                    Cfg::Socket => {
                        conn.inner.socket = value.and_then(crate::from_value);
                    }
                    Cfg::MaxAllowedPacket => {
                        if let Some(stream) = conn.inner.stream.as_mut() {
                            stream.set_max_allowed_packet(
                                value
                                    .and_then(crate::from_value)
                                    .unwrap_or(DEFAULT_MAX_ALLOWED_PACKET),
                            );
                        }
                    }
                    Cfg::WaitTimeout => {
                        conn.inner.wait_timeout = Duration::from_secs(
                            value
                                .and_then(crate::from_value)
                                .unwrap_or(DEFAULT_WAIT_TIMEOUT) as u64,
                        );
                    }
                }
            }
        }

        // Values present in `Opts` are applied as-is; everything else is loaded.
        let mut actions = vec![
            if let Some(x) = self.opts().max_allowed_packet() {
                Action::Apply(CfgData::MaxAllowedPacket(x))
            } else {
                Action::Load(Cfg::MaxAllowedPacket)
            },
            if let Some(x) = self.opts().wait_timeout() {
                Action::Apply(CfgData::WaitTimeout(x))
            } else {
                Action::Load(Cfg::WaitTimeout)
            },
        ];

        if self.inner.opts.prefer_socket() && self.inner.socket.is_none() {
            actions.push(Action::Load(Cfg::Socket))
        }

        // `loads` keeps the relative order of the `Action::Load` entries in `actions`.
        let loads = actions
            .iter()
            .filter_map(|x| match x {
                Action::Load(x) => Some(x),
                Action::Apply(_) => None,
            })
            .collect::<Vec<_>>();

        // Builds "SELECT @@a,@@b,..." (a space before the first name, commas afterwards).
        // The result columns follow the order of `loads`; the final loop relies on this to
        // pair each `Action::Load` with `loaded.next()`.
        let loaded = if !loads.is_empty() {
            let query = loads
                .iter()
                .zip(std::iter::once(' ').chain(std::iter::repeat(',')))
                .fold("SELECT".to_owned(), |mut acc, (cfg, prefix)| {
                    acc.push(prefix);
                    acc.push_str(cfg.name());
                    acc
                });

            // `Row::unwrap` yields the row's values; if no row comes back, every loaded
            // setting gets NULL (presumably converted to `None`, i.e. the default).
            self.query_internal::<Row, String>(query)
                .await?
                .map(|row| row.unwrap())
                .unwrap_or_else(|| vec![crate::Value::NULL; loads.len()])
        } else {
            vec![]
        };
        let mut loaded = loaded.into_iter();

        for action in actions {
            match action {
                Action::Load(cfg) => cfg.apply(self, loaded.next()),
                Action::Apply(cfg) => cfg.apply(self),
            }
        }

        Ok(())
    }
1097
1098    /// Returns true if time since last IO exceeds `wait_timeout`
1099    /// (or `conn_ttl` if specified in opts).
1100    fn expired(&self) -> bool {
1101        if let Some(deadline) = self.inner.ttl_deadline {
1102            if Instant::now() > deadline {
1103                return true;
1104            }
1105        }
1106        let ttl = self
1107            .inner
1108            .opts
1109            .conn_ttl()
1110            .unwrap_or(self.inner.wait_timeout);
1111        !ttl.is_zero() && self.idling() > ttl
1112    }
1113
1114    /// Returns duration since last IO.
1115    fn idling(&self) -> Duration {
1116        self.inner.last_io.elapsed()
1117    }
1118
1119    /// Executes [`COM_RESET_CONNECTION`][1].
1120    ///
1121    /// Returns `false` if command is not supported (requires MySql >5.7.2, MariaDb >10.2.3).
1122    /// For older versions consider using [`Conn::change_user`].
1123    ///
1124    /// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-reset-connection.html
1125    pub async fn reset(&mut self) -> Result<bool> {
1126        let supports_com_reset_connection = if self.inner.is_mariadb {
1127            self.inner.version >= (10, 2, 4)
1128        } else {
1129            // assuming mysql
1130            self.inner.version > (5, 7, 2)
1131        };
1132
1133        if supports_com_reset_connection {
1134            self.routine(routines::ResetRoutine).await?;
1135            self.inner.stmt_cache.clear();
1136            self.inner.infile_handler = None;
1137            self.run_setup_commands().await?;
1138        }
1139
1140        Ok(supports_com_reset_connection)
1141    }
1142
1143    /// Executes [`COM_CHANGE_USER`][1].
1144    ///
1145    /// This might be used as an older and slower alternative to `COM_RESET_CONNECTION` that
1146    /// works on MySql prior to 5.7.3 (MariaDb prior ot 10.2.4).
1147    ///
1148    /// ## Note
1149    ///
1150    /// * Using non-default `opts` for a pooled connection is discouraging.
1151    /// * Connection options will be permanently updated.
1152    ///
1153    /// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-change-user.html
1154    pub async fn change_user(&mut self, opts: ChangeUserOpts) -> Result<()> {
1155        // We'll kick this connection from a pool if opts are changed.
1156        if opts != ChangeUserOpts::default() {
1157            let mut opts_changed = false;
1158            if let Some(user) = opts.user() {
1159                opts_changed |= user != self.opts().user()
1160            };
1161            if let Some(pass) = opts.pass() {
1162                opts_changed |= pass != self.opts().pass()
1163            };
1164            if let Some(db_name) = opts.db_name() {
1165                opts_changed |= db_name != self.opts().db_name()
1166            };
1167            if opts_changed {
1168                if let Some(pool) = self.inner.pool.take() {
1169                    pool.cancel_connection();
1170                }
1171            }
1172        }
1173
1174        let conn_opts = &mut self.inner.opts;
1175        opts.update_opts(conn_opts);
1176        self.routine(routines::ChangeUser).await?;
1177        self.inner.stmt_cache.clear();
1178        self.inner.infile_handler = None;
1179        self.run_setup_commands().await?;
1180        Ok(())
1181    }
1182
1183    /// Resets the connection upon returning it to a pool.
1184    ///
1185    /// Will invoke `COM_CHANGE_USER` if `COM_RESET_CONNECTION` is not supported.
1186    async fn reset_for_pool(mut self) -> Result<Self> {
1187        if !self.reset().await? {
1188            self.change_user(Default::default()).await?;
1189        }
1190        Ok(self)
1191    }
1192
1193    /// Requires that `self.inner.tx_status != TxStatus::None`
1194    async fn rollback_transaction(&mut self) -> Result<()> {
1195        debug_assert_ne!(self.inner.tx_status, TxStatus::None);
1196        self.inner.tx_status = TxStatus::None;
1197        self.query_drop("ROLLBACK").await
1198    }
1199
1200    /// Returns `true` if `SERVER_MORE_RESULTS_EXISTS` flag is contained
1201    /// in status flags of the connection.
1202    pub(crate) fn more_results_exists(&self) -> bool {
1203        self.status()
1204            .contains(StatusFlags::SERVER_MORE_RESULTS_EXISTS)
1205    }
1206
    /// The purpose of this function is to clean up a pending result set
    /// for a prematurely dropped connection or query result.
    ///
    /// Requires that there are no other references to the pending result.
    ///
    /// # Panics
    ///
    /// Panics if the pending result is `Taken` and its `ResultSetMeta` is still shared
    /// (the `Arc` has other owners).
    pub(crate) async fn drop_result(&mut self) -> Result<()> {
        // Map everything into `PendingResult::Pending`: take the current pending result out
        // (presumably `set_pending_result` returns the previously stored value).
        let meta = match self.set_pending_result(None)? {
            Some(PendingResult::Pending(meta)) => Some(meta),
            Some(PendingResult::Taken(meta)) => {
                // This also asserts that there is only one reference left to the taken ResultSetMeta,
                // therefore this result set must be dropped here since it won't be dropped anywhere else.
                Some(Arc::try_unwrap(meta).expect("Conn::drop_result call on a pending result that may still be droped by someone else"))
            }
            None => None,
        };

        // Put it back as `Pending` so the `QueryResult` below can consume it.
        let _ = self.set_pending_result(meta);

        match self.use_pending_result() {
            Ok(Some(PendingResult::Pending(ResultSetMeta::Text(_)))) => {
                QueryResult::<'_, '_, TextProtocol>::new(self)
                    .drop_result()
                    .await
            }
            Ok(Some(PendingResult::Pending(ResultSetMeta::Binary(_)))) => {
                QueryResult::<'_, '_, BinaryProtocol>::new(self)
                    .drop_result()
                    .await
            }
            Ok(None) => Ok((/* this case does not require an action */)),
            Ok(Some(PendingResult::Taken(_))) | Err(_) => {
                unreachable!("this case must be handled earlier in this function")
            }
        }
    }
1242
1243    /// This function will drop pending result and rollback a transaction, if needed.
1244    ///
1245    /// The purpose of this function, is to cleanup the connection while returning it to a [`Pool`].
1246    async fn cleanup_for_pool(mut self) -> Result<Self> {
1247        loop {
1248            let result = if self.has_pending_result() {
1249                self.drop_result().await
1250            } else if self.inner.tx_status != TxStatus::None {
1251                self.rollback_transaction().await
1252            } else {
1253                break;
1254            };
1255
1256            // The connection was dropped and we assume that it was dropped intentionally,
1257            // so we'll ignore non-fatal errors during cleanup (also there is no direct caller
1258            // to return this error to).
1259            if let Err(err) = result {
1260                if err.is_fatal() {
1261                    // This means that connection is completely broken
1262                    // and shouldn't return to a pool.
1263                    return Err(err);
1264                }
1265            }
1266        }
1267        Ok(self)
1268    }
1269}
1270
1271#[cfg(test)]
1272mod test {
1273    use bytes::Bytes;
1274    use futures_util::stream::{self, StreamExt};
1275    use mysql_common::constants::MAX_PAYLOAD_LEN;
1276    use rand::Fill;
1277
1278    use crate::{
1279        from_row, params, prelude::*, test_misc::get_opts, ChangeUserOpts, Conn, Error,
1280        OptsBuilder, Pool, Value, WhiteListFsHandler,
1281    };
1282
    /// With `client_found_rows(true)`, an UPDATE that matches a row without changing it must
    /// still report that row in `affected_rows`.
    #[tokio::test]
    async fn should_return_found_rows_if_flag_is_set() -> super::Result<()> {
        let opts = get_opts().client_found_rows(true);
        let mut conn = Conn::new(opts).await.unwrap();

        "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)"
            .ignore(&mut conn)
            .await?;

        "INSERT INTO mysql.found_rows (val) VALUES (1)"
            .ignore(&mut conn)
            .await?;

        // Inserted one row, affected should be one.
        assert_eq!(conn.affected_rows(), 1);

        "UPDATE mysql.found_rows SET val = 1 WHERE val = 1"
            .ignore(&mut conn)
            .await?;

        // The query doesn't affect any rows, but due to us wanting FOUND rows,
        // this has to return one.
        assert_eq!(conn.affected_rows(), 1);

        Ok(())
    }
1309
    /// Without `client_found_rows`, an UPDATE that changes no values must report zero
    /// affected rows.
    #[tokio::test]
    async fn should_not_return_found_rows_if_flag_is_not_set() -> super::Result<()> {
        let mut conn = Conn::new(get_opts()).await.unwrap();

        "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)"
            .ignore(&mut conn)
            .await?;

        "INSERT INTO mysql.found_rows (val) VALUES (1)"
            .ignore(&mut conn)
            .await?;

        // Inserted one row, affected should be one.
        assert_eq!(conn.affected_rows(), 1);

        "UPDATE mysql.found_rows SET val = 1 WHERE val = 1"
            .ignore(&mut conn)
            .await?;

        // The query doesn't affect any rows.
        assert_eq!(conn.affected_rows(), 0);

        Ok(())
    }
1334
1335    #[test]
1336    fn opts_should_satisfy_send_and_sync() {
1337        struct A<T: Sync + Send>(T);
1338        #[allow(clippy::unnecessary_operation)]
1339        A(get_opts());
1340    }
1341
    /// Connecting must work both with no database name and with an empty database name.
    #[tokio::test]
    async fn should_connect_without_database() -> super::Result<()> {
        // no database name
        let mut conn: Conn = Conn::new(get_opts().db_name(None::<String>)).await?;
        conn.ping().await?;
        conn.disconnect().await?;

        // empty database name
        let mut conn: Conn = Conn::new(get_opts().db_name(Some(""))).await?;
        conn.ping().await?;
        conn.disconnect().await?;

        Ok(())
    }
1356
    /// Dropping an unconsumed query result, or a transaction with an unconsumed result, must
    /// leave the connection usable. The dropped transaction must not commit its INSERT.
    #[tokio::test]
    async fn should_clean_state_if_wrapper_is_dropeed() -> super::Result<()> {
        let mut conn: Conn = Conn::new(get_opts()).await?;

        conn.query_drop("CREATE TEMPORARY TABLE mysql.foo (id SERIAL)")
            .await?;

        // dropped query:
        conn.query_iter("SELECT 1").await?;
        conn.ping().await?;

        // dropped query in dropped transaction:
        let mut tx = conn.start_transaction(Default::default()).await?;
        tx.query_drop("INSERT INTO mysql.foo (id) VALUES (42)")
            .await?;
        tx.exec_iter("SELECT COUNT(*) FROM mysql.foo", ()).await?;
        drop(tx);
        conn.ping().await?;

        // The INSERT above must have been rolled back when `tx` was dropped.
        let count: u8 = conn
            .query_first("SELECT COUNT(*) FROM mysql.foo")
            .await?
            .unwrap_or_default();

        assert_eq!(count, 0);

        Ok(())
    }
1385
    /// Connects with every available combination of authentication plugin and empty or
    /// non-empty password. Also checks that compression/TLS show up in the `Debug` output
    /// when the test environment enables them.
    #[tokio::test]
    async fn should_connect() -> super::Result<()> {
        let mut conn: Conn = Conn::new(get_opts()).await?;
        conn.ping().await?;
        let plugins: Vec<String> = conn
            .query_map("SHOW PLUGINS", |mut row: crate::Row| {
                row.take("Name").unwrap()
            })
            .await?;

        // Should connect with any combination of supported plugin and empty-nonempty password.
        // The middle field is the `old_passwords` value used on servers outside 8.0.11..=9.0.0.
        let variants = vec![
            ("caching_sha2_password", 2_u8, "non-empty"),
            ("caching_sha2_password", 2_u8, ""),
            ("mysql_native_password", 0_u8, "non-empty"),
            ("mysql_native_password", 0_u8, ""),
        ]
        .into_iter()
        .filter(|variant| plugins.iter().any(|p| p == variant.0));

        for (plug, val, pass) in variants {
            let _ = conn.query_drop("DROP USER 'test_user'@'%'").await;

            let query = format!("CREATE USER 'test_user'@'%' IDENTIFIED WITH {}", plug);
            conn.query_drop(query).await.unwrap();

            if (8, 0, 11) <= conn.inner.version && conn.inner.version <= (9, 0, 0) {
                conn.query_drop(format!("SET PASSWORD FOR 'test_user'@'%' = '{}'", pass))
                    .await
                    .unwrap();
            } else {
                conn.query_drop(format!("SET old_passwords = {}", val))
                    .await
                    .unwrap();
                conn.query_drop(format!(
                    "SET PASSWORD FOR 'test_user'@'%' = PASSWORD('{}')",
                    pass
                ))
                .await
                .unwrap();
            };

            let opts = get_opts()
                .user(Some("test_user"))
                .pass(Some(pass))
                .db_name(None::<String>);
            let result = Conn::new(opts).await;

            // Drop the user before checking `result` so a failed connect doesn't leak it.
            conn.query_drop("DROP USER 'test_user'@'%'").await.unwrap();

            result?.disconnect().await?;
        }

        if crate::test_misc::test_compression() {
            assert!(format!("{:?}", conn).contains("Compression"));
        }

        if crate::test_misc::test_ssl() {
            assert!(format!("{:?}", conn).contains("Tls"));
        }

        conn.disconnect().await?;
        Ok(())
    }
1450
    /// The connect future is created outside of any runtime, then driven on a private
    /// current-thread runtime. Neither this nor dropping the connection and runtime afterwards
    /// may panic.
    #[test]
    fn should_not_panic_if_dropped_without_tokio_runtime() {
        let fut = Conn::new(get_opts());
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        runtime.block_on(async {
            fut.await.unwrap();
        });
        // NOTE(review): the `Conn` produced by `fut.await.unwrap()` is a discarded temporary,
        // so it is actually dropped inside `block_on` above; what drops here is the runtime
        // (together with any task it still holds).
    }
1463
    /// `init` commands from the options must have run by the time `Conn::new` resolves.
    #[tokio::test]
    async fn should_execute_init_queries_on_new_connection() -> super::Result<()> {
        let opts = OptsBuilder::from_opts(get_opts()).init(vec!["SET @a = 42", "SET @b = 'foo'"]);
        let mut conn = Conn::new(opts).await?;
        let result: Vec<(u8, String)> = conn.query("SELECT @a, @b").await?;
        conn.disconnect().await?;
        assert_eq!(result, vec![(42, "foo".into())]);
        Ok(())
    }
1473
    /// `setup` commands must take effect on a new connection, and again after `reset` (when
    /// supported) and after `change_user`, both of which wipe session variables.
    #[tokio::test]
    async fn should_execute_setup_queries_on_reset() -> super::Result<()> {
        let opts = OptsBuilder::from_opts(get_opts()).setup(vec!["SET @a = 42", "SET @b = 'foo'"]);
        let mut conn = Conn::new(opts).await?;

        // initial run
        let mut result: Vec<(u8, String)> = conn.query("SELECT @a, @b").await?;
        assert_eq!(result, vec![(42, "foo".into())]);

        // after reset
        if conn.reset().await? {
            result = conn.query("SELECT @a, @b").await?;
            assert_eq!(result, vec![(42, "foo".into())]);
        }

        // after change user
        conn.change_user(Default::default()).await?;
        result = conn.query("SELECT @a, @b").await?;
        assert_eq!(result, vec![(42, "foo".into())]);

        conn.disconnect().await?;
        Ok(())
    }
1497
    /// A supported `reset` must clear session variables. When `reset` reports `false` (not
    /// supported), the session must be left untouched.
    #[tokio::test]
    async fn should_reset_the_connection() -> super::Result<()> {
        let mut conn = Conn::new(get_opts()).await?;

        assert_eq!(
            conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
            Value::NULL
        );

        conn.query_drop("SET @foo = 'foo'").await?;

        assert_eq!(
            conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
            "foo",
        );

        if conn.reset().await? {
            assert_eq!(
                conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
                Value::NULL
            );
        } else {
            assert_eq!(
                conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
                "foo",
            );
        }

        conn.disconnect().await?;
        Ok(())
    }
1529
    /// `change_user` with default options must clear session variables. Switching to a freshly
    /// created user (for each supported auth plugin) must land on that user with no database
    /// selected.
    #[tokio::test]
    async fn should_change_user() -> super::Result<()> {
        let mut conn = Conn::new(get_opts()).await?;
        assert_eq!(
            conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
            Value::NULL
        );

        conn.query_drop("SET @foo = 'foo'").await?;

        assert_eq!(
            conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
            "foo",
        );

        conn.change_user(Default::default()).await?;
        assert_eq!(
            conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
            Value::NULL
        );

        // NOTE(review): there is no MySql 5.8, so `>= (5, 8, 0)` effectively selects MySql 8.0+.
        let plugins: &[&str] = if !conn.inner.is_mariadb && conn.server_version() >= (5, 8, 0) {
            &["mysql_native_password", "caching_sha2_password"]
        } else {
            &["mysql_native_password"]
        };

        for plugin in plugins {
            // Random 10-letter lowercase password ('a'..='z').
            let mut rng = rand::thread_rng();
            let mut pass = [0u8; 10];
            pass.try_fill(&mut rng).unwrap();
            let pass: String = IntoIterator::into_iter(pass)
                .map(|x| ((x % (123 - 97)) + 97) as char)
                .collect();

            conn.query_drop("DELETE FROM mysql.user WHERE user = '__mats'")
                .await
                .unwrap();
            conn.query_drop("FLUSH PRIVILEGES").await.unwrap();

            if conn.inner.is_mariadb || conn.server_version() < (5, 7, 0) {
                if matches!(conn.server_version(), (5, 6, _)) {
                    conn.query_drop("CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_old_password")
                        .await
                        .unwrap();
                    conn.query_drop(format!(
                        "SET PASSWORD FOR '__mats'@'%' = OLD_PASSWORD({})",
                        Value::from(pass.clone()).as_sql(false)
                    ))
                    .await
                    .unwrap();
                } else {
                    conn.query_drop("CREATE USER '__mats'@'%'").await.unwrap();
                    conn.query_drop(format!(
                        "SET PASSWORD FOR '__mats'@'%' = PASSWORD({})",
                        Value::from(pass.clone()).as_sql(false)
                    ))
                    .await
                    .unwrap();
                }
            } else {
                conn.query_drop(format!(
                    "CREATE USER '__mats'@'%' IDENTIFIED WITH {} BY {}",
                    plugin,
                    Value::from(pass.clone()).as_sql(false)
                ))
                .await
                .unwrap();
            };

            conn.query_drop("FLUSH PRIVILEGES").await.unwrap();

            // `secure_auth(false)` so the old-password variant above can authenticate.
            let mut conn2 = Conn::new(get_opts().secure_auth(false)).await.unwrap();
            conn2
                .change_user(
                    ChangeUserOpts::default()
                        .with_db_name(None)
                        .with_user(Some("__mats".into()))
                        .with_pass(Some(pass)),
                )
                .await
                .unwrap();
            let (db, user) = conn2
                .query_first::<(Option<String>, String), _>("SELECT DATABASE(), USER();")
                .await
                .unwrap()
                .unwrap();
            assert_eq!(db, None);
            assert!(user.starts_with("__mats"));

            conn2.disconnect().await.unwrap();
        }

        conn.disconnect().await?;
        Ok(())
    }
1626
    /// With `stmt_cache_size(0)`, no statement may stay in the cache. The server-side
    /// `Com_stmt_close` counter is checked after a mix of prepared and ad-hoc executions.
    #[tokio::test]
    async fn should_not_cache_statements_if_stmt_cache_size_is_zero() -> super::Result<()> {
        let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(0);

        let mut conn = Conn::new(opts).await?;
        conn.exec_drop("DO ?", (1_u8,)).await?;

        let stmt = conn.prep("DO 2").await?;
        conn.exec_drop(&stmt, ()).await?;
        conn.exec_drop(&stmt, ()).await?;
        conn.close(stmt).await?;

        conn.exec_drop("DO 3", ()).await?;
        conn.exec_batch("DO 4", vec![(), ()]).await?;
        conn.exec_first::<u8, _, _>("DO 5", ()).await?;
        let row: Option<(crate::Value, usize)> = conn
            .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
            .await?;

        assert_eq!(row.unwrap().1, 1);
        assert_eq!(conn.inner.stmt_cache.len(), 0);

        conn.disconnect().await?;

        Ok(())
    }
1653
    /// With `stmt_cache_size(3)`, executing six distinct statements evicts (and closes) three of
    /// them in least-recently-used order.
    ///
    /// After the sequence below the cache holds "DO 6", "DO 5" and "DO 3"; the assertion
    /// expects the cache iterator to yield the most recently used statement first.
    #[tokio::test]
    async fn should_hold_stmt_cache_size_bound() -> super::Result<()> {
        let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(3);
        let mut conn = Conn::new(opts).await?;
        conn.exec_drop("DO 1", ()).await?;
        conn.exec_drop("DO 2", ()).await?;
        conn.exec_drop("DO 3", ()).await?;
        conn.exec_drop("DO 1", ()).await?;
        conn.exec_drop("DO 4", ()).await?;
        conn.exec_drop("DO 3", ()).await?;
        conn.exec_drop("DO 5", ()).await?;
        conn.exec_drop("DO 6", ()).await?;
        let row_opt = conn
            .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
            .await?;
        let (_, count): (String, usize) = row_opt.unwrap();
        assert_eq!(count, 3);
        let order = conn
            .stmt_cache_ref()
            .iter()
            .map(|item| item.1.query.0.as_ref())
            .collect::<Vec<&[u8]>>();
        assert_eq!(order, &[b"DO 6", b"DO 5", b"DO 3"]);
        conn.disconnect().await?;
        Ok(())
    }
1680
    /// Round-trips string literals whose lengths surround `MAX_PAYLOAD_LEN`, presumably to
    /// exercise splitting of oversized packets in both directions.
    #[tokio::test]
    async fn should_perform_queries() -> super::Result<()> {
        let mut conn = Conn::new(get_opts()).await?;
        for x in (MAX_PAYLOAD_LEN - 2)..=(MAX_PAYLOAD_LEN + 2) {
            let long_string = "A".repeat(x);
            let result: Vec<(String, u8)> = conn
                .query(format!(r"SELECT '{}', 231", long_string))
                .await?;
            assert_eq!((long_string, 231_u8), result[0]);
        }
        conn.disconnect().await?;
        Ok(())
    }
1694
    /// `query_drop` must execute DDL and DML whose effects are visible to later queries.
    #[tokio::test]
    async fn should_query_drop() -> super::Result<()> {
        let mut conn = Conn::new(get_opts()).await?;
        conn.query_drop("CREATE TEMPORARY TABLE tmp (id int DEFAULT 10, name text)")
            .await?;
        conn.query_drop("INSERT INTO tmp VALUES (1, 'foo')").await?;
        let result: Option<u8> = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
        conn.disconnect().await?;
        assert_eq!(result, Some(1_u8));
        Ok(())
    }
1706
1707    #[tokio::test]
1708    async fn should_prepare_statement() -> super::Result<()> {
1709        let mut conn = Conn::new(get_opts()).await?;
1710        let stmt = conn.prep(r"SELECT ?").await?;
1711        conn.close(stmt).await?;
1712        conn.disconnect().await?;
1713
1714        let mut conn = Conn::new(get_opts()).await?;
1715        let stmt = conn.prep(r"SELECT :foo").await?;
1716
1717        {
1718            let query = String::from("SELECT ?, ?");
1719            let stmt = conn.prep(&*query).await?;
1720            conn.close(stmt).await?;
1721            {
1722                let mut conn = Conn::new(get_opts()).await?;
1723                let stmt = conn.prep(&*query).await?;
1724                conn.close(stmt).await?;
1725                conn.disconnect().await?;
1726            }
1727        }
1728
1729        conn.close(stmt).await?;
1730        conn.disconnect().await?;
1731
1732        Ok(())
1733    }
1734
1735    #[tokio::test]
1736    async fn should_execute_statement() -> super::Result<()> {
1737        let long_string = "A".repeat(18 * 1024 * 1024);
1738        let mut conn = Conn::new(get_opts()).await?;
1739        let stmt = conn.prep(r"SELECT ?").await?;
1740        let result = conn.exec_iter(&stmt, (&long_string,)).await?;
1741        let mut mapped = result.map_and_drop(from_row::<(String,)>).await?;
1742        assert_eq!(mapped.len(), 1);
1743        assert_eq!(mapped.pop(), Some((long_string,)));
1744        let result = conn.exec_iter(&stmt, (42_u8,)).await?;
1745        let collected = result.collect_and_drop::<(u8,)>().await?;
1746        assert_eq!(collected, vec![(42u8,)]);
1747        let result = conn.exec_iter(&stmt, (8_u8,)).await?;
1748        let reduced = result
1749            .reduce_and_drop(2, |mut acc, row| {
1750                acc += from_row::<i32>(row);
1751                acc
1752            })
1753            .await?;
1754        conn.close(stmt).await?;
1755        conn.disconnect().await?;
1756        assert_eq!(reduced, 10);
1757
1758        let mut conn = Conn::new(get_opts()).await?;
1759        let stmt = conn.prep(r"SELECT :foo, :bar, :foo, 3").await?;
1760        let result = conn
1761            .exec_iter(&stmt, params! { "foo" => "quux", "bar" => "baz" })
1762            .await?;
1763        let mut mapped = result
1764            .map_and_drop(from_row::<(String, String, String, u8)>)
1765            .await?;
1766        assert_eq!(mapped.len(), 1);
1767        assert_eq!(
1768            mapped.pop(),
1769            Some(("quux".into(), "baz".into(), "quux".into(), 3))
1770        );
1771        let result = conn
1772            .exec_iter(&stmt, params! { "foo" => 2, "bar" => 3 })
1773            .await?;
1774        let collected = result.collect_and_drop::<(u8, u8, u8, u8)>().await?;
1775        assert_eq!(collected, vec![(2, 3, 2, 3)]);
1776        let result = conn
1777            .exec_iter(&stmt, params! { "foo" => 2, "bar" => 3 })
1778            .await?;
1779        let reduced = result
1780            .reduce_and_drop(0, |acc, row| {
1781                let (a, b, c, d): (u8, u8, u8, u8) = from_row(row);
1782                acc + a + b + c + d
1783            })
1784            .await?;
1785        conn.close(stmt).await?;
1786        conn.disconnect().await?;
1787        assert_eq!(reduced, 10);
1788        Ok(())
1789    }
1790
1791    #[tokio::test]
1792    async fn should_prep_exec_statement() -> super::Result<()> {
1793        let mut conn = Conn::new(get_opts()).await?;
1794        let result = conn
1795            .exec_iter(r"SELECT :a, :b, :a", params! { "a" => 2, "b" => 3 })
1796            .await?;
1797        let output = result
1798            .map_and_drop(|row| {
1799                let (a, b, c): (u8, u8, u8) = from_row(row);
1800                a * b * c
1801            })
1802            .await?;
1803        conn.disconnect().await?;
1804        assert_eq!(output[0], 12u8);
1805        Ok(())
1806    }
1807
1808    #[tokio::test]
1809    async fn should_first_exec_statement() -> super::Result<()> {
1810        let mut conn = Conn::new(get_opts()).await?;
1811        let output = conn
1812            .exec_first(
1813                r"SELECT :a UNION ALL SELECT :b",
1814                params! { "a" => 2, "b" => 3 },
1815            )
1816            .await?;
1817        conn.disconnect().await?;
1818        assert_eq!(output, Some(2u8));
1819        Ok(())
1820    }
1821
1822    #[tokio::test]
1823    async fn issue_107() -> super::Result<()> {
1824        let mut conn = Conn::new(get_opts()).await?;
1825        conn.query_drop(
1826            r"CREATE TEMPORARY TABLE mysql.issue (
1827                    a BIGINT(20) UNSIGNED,
1828                    b VARBINARY(16),
1829                    c BINARY(32),
1830                    d BIGINT(20) UNSIGNED,
1831                    e BINARY(32)
1832                )",
1833        )
1834        .await?;
1835        conn.query_drop(
1836            r"INSERT INTO mysql.issue VALUES (
1837                    0,
1838                    0xC066F966B0860000,
1839                    0x7939DA98E524C5F969FC2DE8D905FD9501EBC6F20001B0A9C941E0BE6D50CF44,
1840                    0,
1841                    ''
1842                ), (
1843                    1,
1844                    '',
1845                    0x076311DF4D407B0854371BA13A5F3FB1A4555AC22B361375FD47B263F31822F2,
1846                    0,
1847                    ''
1848                )",
1849        )
1850        .await?;
1851
1852        let q = "SELECT b, c, d, e FROM mysql.issue";
1853        let result = conn.query_iter(q).await?;
1854
1855        let loaded_structs = result
1856            .map_and_drop(crate::from_row::<(Vec<u8>, Vec<u8>, u64, Vec<u8>)>)
1857            .await?;
1858
1859        conn.disconnect().await?;
1860
1861        assert_eq!(loaded_structs.len(), 2);
1862
1863        Ok(())
1864    }
1865
1866    #[tokio::test]
1867    async fn should_run_transactions() -> super::Result<()> {
1868        let mut conn = Conn::new(get_opts()).await?;
1869        conn.query_drop("CREATE TEMPORARY TABLE tmp (id INT, name TEXT)")
1870            .await?;
1871        let mut transaction = conn.start_transaction(Default::default()).await?;
1872        transaction
1873            .query_drop("INSERT INTO tmp VALUES (1, 'foo'), (2, 'bar')")
1874            .await?;
1875        assert_eq!(transaction.last_insert_id(), None);
1876        assert_eq!(transaction.affected_rows(), 2);
1877        assert_eq!(transaction.get_warnings(), 0);
1878        assert_eq!(transaction.info(), "Records: 2  Duplicates: 0  Warnings: 0");
1879        transaction.commit().await?;
1880        let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1881        assert_eq!(output_opt, Some((2u8,)));
1882        let mut transaction = conn.start_transaction(Default::default()).await?;
1883        transaction
1884            .query_drop("INSERT INTO tmp VALUES (3, 'baz'), (4, 'quux')")
1885            .await?;
1886        let output_opt = transaction
1887            .exec_first("SELECT COUNT(*) FROM tmp", ())
1888            .await?;
1889        assert_eq!(output_opt, Some((4u8,)));
1890        transaction.rollback().await?;
1891        let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1892        assert_eq!(output_opt, Some((2u8,)));
1893
1894        let mut transaction = conn.start_transaction(Default::default()).await?;
1895        transaction
1896            .query_drop("INSERT INTO tmp VALUES (3, 'baz')")
1897            .await?;
1898        drop(transaction); // implicit rollback
1899        let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1900        assert_eq!(output_opt, Some((2u8,)));
1901
1902        conn.disconnect().await?;
1903        Ok(())
1904    }
1905
1906    #[tokio::test]
1907    async fn should_handle_multiresult_set_with_error() -> super::Result<()> {
1908        const QUERY_FIRST: &str = "SELECT * FROM tmp; SELECT 1; SELECT 2;";
1909        const QUERY_MIDDLE: &str = "SELECT 1; SELECT * FROM tmp; SELECT 2";
1910        let mut conn = Conn::new(get_opts()).await.unwrap();
1911
1912        // if error is in the first result set, then query should return it immediately.
1913        let result = QUERY_FIRST.run(&mut conn).await;
1914        assert!(matches!(result, Err(Error::Server(_))));
1915
1916        let mut result = QUERY_MIDDLE.run(&mut conn).await.unwrap();
1917
1918        // first result set will contain one row
1919        let result_set: Vec<u8> = result.collect().await.unwrap();
1920        assert_eq!(result_set, vec![1]);
1921
1922        // second result set will contain an error.
1923        let result_set: super::Result<Vec<u8>> = result.collect().await;
1924        assert!(matches!(result_set, Err(Error::Server(_))));
1925
1926        // there will be no third result set
1927        assert!(result.is_empty());
1928
1929        conn.ping().await?;
1930        conn.disconnect().await?;
1931
1932        Ok(())
1933    }
1934
1935    #[tokio::test]
1936    async fn should_handle_binary_multiresult_set_with_error() -> super::Result<()> {
1937        const PROC_DEF_FIRST: &str =
1938            r#"CREATE PROCEDURE err_first() BEGIN SELECT * FROM tmp; SELECT 1; END"#;
1939        const PROC_DEF_MIDDLE: &str =
1940            r#"CREATE PROCEDURE err_middle() BEGIN SELECT 1; SELECT * FROM tmp; SELECT 2; END"#;
1941
1942        let mut conn = Conn::new(get_opts()).await.unwrap();
1943
1944        conn.query_drop("DROP PROCEDURE IF EXISTS err_first")
1945            .await?;
1946        conn.query_iter(PROC_DEF_FIRST).await?;
1947
1948        conn.query_drop("DROP PROCEDURE IF EXISTS err_middle")
1949            .await?;
1950        conn.query_iter(PROC_DEF_MIDDLE).await?;
1951
1952        // if error is in the first result set, then query should return it immediately.
1953        let result = conn.query_iter("CALL err_first()").await;
1954        assert!(matches!(result, Err(Error::Server(_))));
1955
1956        let mut result = conn.query_iter("CALL err_middle()").await?;
1957
1958        // first result set will contain one row
1959        let result_set: Vec<u8> = result.collect().await.unwrap();
1960        assert_eq!(result_set, vec![1]);
1961
1962        // second result set will contain an error.
1963        let result_set: super::Result<Vec<u8>> = result.collect().await;
1964        assert!(matches!(result_set, Err(Error::Server(_))));
1965
1966        // there will be no third result set
1967        assert!(result.is_empty());
1968
1969        conn.ping().await?;
1970        conn.disconnect().await?;
1971
1972        Ok(())
1973    }
1974
1975    #[tokio::test]
1976    async fn should_handle_multiresult_set_with_local_infile() -> super::Result<()> {
1977        use std::fs::write;
1978
1979        let file_path = tempfile::Builder::new().tempfile_in("").unwrap();
1980        let file_path = file_path.path();
1981        let file_name = file_path.file_name().unwrap();
1982
1983        write(file_name, b"AAAAAA\nBBBBBB\nCCCCCC\n")?;
1984
1985        let opts = get_opts().local_infile_handler(Some(WhiteListFsHandler::new(&[file_name][..])));
1986
1987        // LOCAL INFILE in the middle of a multi-result set should not break anything.
1988        let mut conn = Conn::new(opts).await.unwrap();
1989        "CREATE TEMPORARY TABLE tmp (a TEXT)".run(&mut conn).await?;
1990
1991        let query = format!(
1992            r#"SELECT * FROM tmp;
1993            LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;
1994            LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;
1995            SELECT * FROM tmp"#,
1996            file_name.to_str().unwrap(),
1997            file_name.to_str().unwrap(),
1998        );
1999
2000        let mut result = query.run(&mut conn).await?;
2001
2002        let result_set = result.collect::<String>().await?;
2003        assert_eq!(result_set.len(), 0);
2004
2005        let mut no_local_infile = false;
2006
2007        for _ in 0..2 {
2008            match result.collect::<String>().await {
2009                Ok(result_set) => {
2010                    assert_eq!(result.affected_rows(), 3);
2011                    assert!(result_set.is_empty())
2012                }
2013                Err(Error::Server(ref err)) if err.code == 1148 => {
2014                    // The used command is not allowed with this MySQL version
2015                    no_local_infile = true;
2016                    break;
2017                }
2018                Err(Error::Server(ref err)) if err.code == 3948 => {
2019                    // Loading local data is disabled;
2020                    // this must be enabled on both the client and server sides
2021                    no_local_infile = true;
2022                    break;
2023                }
2024                Err(err) => return Err(err),
2025            }
2026        }
2027
2028        if no_local_infile {
2029            assert!(result.is_empty());
2030            assert_eq!(result_set.len(), 0);
2031        } else {
2032            let result_set = result.collect::<String>().await?;
2033            assert_eq!(result_set.len(), 6);
2034            assert_eq!(result_set[0], "AAAAAA");
2035            assert_eq!(result_set[1], "BBBBBB");
2036            assert_eq!(result_set[2], "CCCCCC");
2037            assert_eq!(result_set[3], "AAAAAA");
2038            assert_eq!(result_set[4], "BBBBBB");
2039            assert_eq!(result_set[5], "CCCCCC");
2040        }
2041
2042        conn.ping().await?;
2043        conn.disconnect().await?;
2044
2045        Ok(())
2046    }
2047
2048    #[tokio::test]
2049    async fn should_provide_multiresult_set_metadata() -> super::Result<()> {
2050        let mut c = Conn::new(get_opts()).await?;
2051        c.query_drop("CREATE TEMPORARY TABLE tmp (id INT, foo TEXT)")
2052            .await?;
2053
2054        let mut result = c
2055            .query_iter("SELECT 1; SELECT id, foo FROM tmp WHERE 1 = 2; DO 42; SELECT 2;")
2056            .await?;
2057        assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 1);
2058
2059        result.for_each(drop).await?;
2060        assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 2);
2061
2062        result.for_each(drop).await?;
2063        assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 0);
2064
2065        result.for_each(drop).await?;
2066        assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 1);
2067
2068        c.disconnect().await?;
2069        Ok(())
2070    }
2071
2072    #[tokio::test]
2073    async fn should_expose_query_result_metadata() -> super::Result<()> {
2074        let pool = Pool::new(get_opts());
2075        let mut c = pool.get_conn().await?;
2076
2077        c.query_drop(
2078            r"
2079            CREATE TEMPORARY TABLE `foo`
2080                ( `id` SERIAL
2081                , `bar_id` varchar(36) NOT NULL
2082                , `baz_id` varchar(36) NOT NULL
2083                , `ctime` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP()
2084                , PRIMARY KEY (`id`)
2085                , KEY `bar_idx` (`bar_id`)
2086                , KEY `baz_idx` (`baz_id`)
2087            );",
2088        )
2089        .await?;
2090
2091        const QUERY: &str = "INSERT INTO foo (bar_id, baz_id) VALUES (?, ?)";
2092        let params = ("qwerty", "data.employee_id");
2093
2094        let query_result = c.exec_iter(QUERY, params).await?;
2095        assert_eq!(query_result.last_insert_id(), Some(1));
2096        query_result.drop_result().await?;
2097
2098        c.exec_drop(QUERY, params).await?;
2099        assert_eq!(c.last_insert_id(), Some(2));
2100
2101        let mut tx = c.start_transaction(Default::default()).await?;
2102
2103        tx.exec_drop(QUERY, params).await?;
2104        assert_eq!(tx.last_insert_id(), Some(3));
2105
2106        Ok(())
2107    }
2108
2109    #[tokio::test]
2110    async fn should_handle_local_infile_locally() -> super::Result<()> {
2111        let mut conn = Conn::new(get_opts()).await.unwrap();
2112        conn.query_drop("CREATE TEMPORARY TABLE tmp (a TEXT);")
2113            .await
2114            .unwrap();
2115
2116        conn.set_infile_handler(async move {
2117            Ok(
2118                stream::iter([Bytes::from("AAAAAA\n"), Bytes::from("BBBBBB\nCCCCCC\n")])
2119                    .map(Ok)
2120                    .boxed(),
2121            )
2122        });
2123
2124        match conn
2125            .query_drop(r#"LOAD DATA LOCAL INFILE "dummy" INTO TABLE tmp;"#)
2126            .await
2127        {
2128            Ok(_) => (),
2129            Err(super::Error::Server(ref err)) if err.code == 1148 => {
2130                // The used command is not allowed with this MySQL version
2131                return Ok(());
2132            }
2133            Err(super::Error::Server(ref err)) if err.code == 3948 => {
2134                // Loading local data is disabled;
2135                // this must be enabled on both the client and server sides
2136                return Ok(());
2137            }
2138            e @ Err(_) => e.unwrap(),
2139        };
2140
2141        let result: Vec<String> = conn.query("SELECT * FROM tmp").await?;
2142        assert_eq!(result.len(), 3);
2143        assert_eq!(result[0], "AAAAAA");
2144        assert_eq!(result[1], "BBBBBB");
2145        assert_eq!(result[2], "CCCCCC");
2146
2147        Ok(())
2148    }
2149
2150    #[tokio::test]
2151    async fn should_handle_local_infile_globally() -> super::Result<()> {
2152        use std::fs::write;
2153
2154        let file_path = tempfile::Builder::new().tempfile_in("").unwrap();
2155        let file_path = file_path.path();
2156        let file_name = file_path.file_name().unwrap();
2157
2158        write(file_name, b"AAAAAA\nBBBBBB\nCCCCCC\n")?;
2159
2160        let opts = get_opts().local_infile_handler(Some(WhiteListFsHandler::new(&[file_name][..])));
2161
2162        let mut conn = Conn::new(opts).await.unwrap();
2163        conn.query_drop("CREATE TEMPORARY TABLE tmp (a TEXT);")
2164            .await
2165            .unwrap();
2166
2167        match conn
2168            .query_drop(format!(
2169                r#"LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;"#,
2170                file_name.to_str().unwrap(),
2171            ))
2172            .await
2173        {
2174            Ok(_) => (),
2175            Err(super::Error::Server(ref err)) if err.code == 1148 => {
2176                // The used command is not allowed with this MySQL version
2177                return Ok(());
2178            }
2179            Err(super::Error::Server(ref err)) if err.code == 3948 => {
2180                // Loading local data is disabled;
2181                // this must be enabled on both the client and server sides
2182                return Ok(());
2183            }
2184            e @ Err(_) => e.unwrap(),
2185        };
2186
2187        let result: Vec<String> = conn.query("SELECT * FROM tmp").await?;
2188        assert_eq!(result.len(), 3);
2189        assert_eq!(result[0], "AAAAAA");
2190        assert_eq!(result[1], "BBBBBB");
2191        assert_eq!(result[2], "CCCCCC");
2192
2193        Ok(())
2194    }
2195
2196    #[cfg(feature = "nightly")]
2197    mod bench {
2198        use crate::{conn::Conn, queryable::Queryable, test_misc::get_opts};
2199
2200        #[bench]
2201        fn simple_exec(bencher: &mut test::Bencher) {
2202            let mut runtime = tokio::runtime::Runtime::new().unwrap();
2203            let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2204
2205            bencher.iter(|| {
2206                runtime.block_on(conn.query_drop("DO 1")).unwrap();
2207            });
2208
2209            runtime.block_on(conn.disconnect()).unwrap();
2210        }
2211
2212        #[bench]
2213        fn select_large_string(bencher: &mut test::Bencher) {
2214            let mut runtime = tokio::runtime::Runtime::new().unwrap();
2215            let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2216
2217            bencher.iter(|| {
2218                runtime
2219                    .block_on(conn.query_drop("SELECT REPEAT('A', 10000)"))
2220                    .unwrap();
2221            });
2222
2223            runtime.block_on(conn.disconnect()).unwrap();
2224        }
2225
2226        #[bench]
2227        fn prepared_exec(bencher: &mut test::Bencher) {
2228            let mut runtime = tokio::runtime::Runtime::new().unwrap();
2229            let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2230            let stmt = runtime.block_on(conn.prep("DO 1")).unwrap();
2231
2232            bencher.iter(|| {
2233                runtime.block_on(conn.exec_drop(&stmt, ())).unwrap();
2234            });
2235
2236            runtime.block_on(conn.close(stmt)).unwrap();
2237            runtime.block_on(conn.disconnect()).unwrap();
2238        }
2239
2240        #[bench]
2241        fn prepare_and_exec(bencher: &mut test::Bencher) {
2242            let mut runtime = tokio::runtime::Runtime::new().unwrap();
2243            let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2244
2245            bencher.iter(|| {
2246                runtime.block_on(conn.exec_drop("SELECT ?", (0,))).unwrap();
2247            });
2248
2249            runtime.block_on(conn.disconnect()).unwrap();
2250        }
2251    }
2252}