mysql_async/conn/
mod.rs

1// Copyright (c) 2016 Anatoly Ikorsky
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8
9use futures_util::FutureExt;
10pub use mysql_common::named_params;
11
12use mysql_common::{
13    constants::{DEFAULT_MAX_ALLOWED_PACKET, UTF8MB4_GENERAL_CI, UTF8_GENERAL_CI},
14    crypto,
15    io::ParseBuf,
16    packets::{
17        binlog_request::BinlogRequest, AuthPlugin, AuthSwitchRequest, CommonOkPacket, ErrPacket,
18        HandshakePacket, HandshakeResponse, OkPacket, OkPacketDeserializer, OldAuthSwitchRequest,
19        OldEofPacket, ResultSetTerminator, SslRequest,
20    },
21    proto::MySerialize,
22    row::Row,
23};
24
25use std::{
26    borrow::Cow,
27    fmt,
28    future::Future,
29    mem::{self, replace},
30    pin::Pin,
31    str::FromStr,
32    sync::Arc,
33    time::{Duration, Instant},
34};
35
36use crate::{
37    buffer_pool::PooledBuf,
38    conn::{pool::Pool, stmt_cache::StmtCache},
39    consts::{CapabilityFlags, Command, StatusFlags},
40    error::*,
41    io::Stream,
42    opts::Opts,
43    queryable::{
44        query_result::{QueryResult, ResultSetMeta},
45        transaction::TxStatus,
46        BinaryProtocol, Queryable, TextProtocol,
47    },
48    BinlogStream, ChangeUserOpts, InfileData, OptsBuilder,
49};
50
51use self::routines::Routine;
52
53pub mod binlog_stream;
54pub mod pool;
55pub mod routines;
56pub mod stmt_cache;
57
/// Fallback value for the connection wait timeout.
///
/// NOTE(review): presumably expressed in seconds (28800 s = 8 h) — the code that consumes this
/// constant is outside of this chunk; confirm against its use.
const DEFAULT_WAIT_TIMEOUT: usize = 28800;
59
60/// Helper that asynchronously disconnects the givent connection on the default tokio executor.
61fn disconnect(mut conn: Conn) {
62    let disconnected = conn.inner.disconnected;
63
64    // Mark conn as disconnected.
65    conn.inner.disconnected = true;
66
67    if !disconnected {
68        // We shouldn't call tokio::spawn if unwinding
69        if std::thread::panicking() {
70            return;
71        }
72
73        // Server will report broken connection if spawn fails.
74        // this might fail if, say, the runtime is shutting down, but we've done what we could
75        if let Ok(handle) = tokio::runtime::Handle::try_current() {
76            handle.spawn(async move {
77                if let Ok(conn) = conn.cleanup_for_pool().await {
78                    let _ = conn.disconnect().await;
79                }
80            });
81        }
82    }
83}
84
/// State of a result set that the connection still has to deal with.
///
/// Stored inside `ConnInner::pending_result`; `Conn::take_pending_result` turns `Pending`
/// into `Taken`.
#[derive(Debug, Clone)]
pub(crate) enum PendingResult {
    /// There is a pending result set and nobody has taken its metadata yet.
    Pending(ResultSetMeta),
    /// Result set metadata was taken (shared through an `Arc`) but not yet consumed.
    Taken(Arc<ResultSetMeta>),
}
93
/// Internal state of a MySql connection (boxed inside [`Conn`]).
struct ConnInner {
    /// Underlying IO stream; `None` when there is no stream (never connected, taken or dropped
    /// after an IO error).
    stream: Option<Stream>,
    /// Connection id as reported by the server handshake (`0` before the handshake).
    id: u32,
    /// Set when the handshake carried a parsable MariaDB server version.
    is_mariadb: bool,
    /// Server version from the handshake, `(0, 0, 0)` if it could not be parsed.
    version: (u16, u16, u16),
    /// Socket path; initialized from `Opts::socket`.
    socket: Option<String>,
    /// Capabilities in effect: the client's ones at first, the intersection of the server's
    /// and the client's ones once the handshake is handled.
    capabilities: CapabilityFlags,
    /// Status flags from the handshake or from the most recent OK packet; emptied when a
    /// server error is handled.
    status: StatusFlags,
    /// Most recent OK packet; cleared when a server error is handled.
    last_ok_packet: Option<OkPacket<'static>>,
    /// Most recent server error; cleared when an OK packet is handled.
    last_err_packet: Option<mysql_common::packets::ServerError<'static>>,
    /// NOTE(review): presumably the pool this connection belongs to — it is only initialized
    /// to `None` within this chunk; confirm against the pool code.
    pool: Option<Pool>,
    /// Pending result set state, or a server error stored in its place
    /// (see `Conn::set_pending_result_error`).
    pending_result: std::result::Result<Option<PendingResult>, ServerError>,
    /// Current transaction status.
    tx_status: TxStatus,
    /// Flag toggled by `Conn::reset_connection`.
    reset_upon_returning_to_a_pool: bool,
    /// Options this connection was created with.
    opts: Opts,
    /// Time of the last IO, as recorded by `Conn::touch`.
    last_io: Instant,
    /// NOTE(review): presumably the server-side wait timeout — it starts at zero and is
    /// assigned outside of this chunk; confirm.
    wait_timeout: Duration,
    /// Statement cache, sized by `Opts::stmt_cache_size`.
    stmt_cache: StmtCache,
    /// Auth scramble: taken from the handshake (normalized to 20 bytes) or from the plugin
    /// data of an auth switch request.
    nonce: Vec<u8>,
    /// Authentication plugin currently in use.
    auth_plugin: AuthPlugin<'static>,
    /// Whether an auth switch was already performed on this connection.
    auth_switched: bool,
    /// Key requested from the server during `caching_sha2_password` authentication over a
    /// channel that is neither secure nor a socket; used to encrypt the password.
    server_key: Option<Vec<u8>>,
    /// Connection is already disconnected.
    pub(crate) disconnected: bool,
    /// One-time connection-level infile handler.
    infile_handler:
        Option<Pin<Box<dyn Future<Output = crate::Result<InfileData>> + Send + Sync + 'static>>>,
}
123
124impl fmt::Debug for ConnInner {
125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126        f.debug_struct("Conn")
127            .field("connection id", &self.id)
128            .field("server version", &self.version)
129            .field("pool", &self.pool)
130            .field("pending_result", &self.pending_result)
131            .field("tx_status", &self.tx_status)
132            .field("stream", &self.stream)
133            .field("options", &self.opts)
134            .field("server_key", &self.server_key)
135            .field("auth_plugin", &self.auth_plugin)
136            .finish()
137    }
138}
139
140impl ConnInner {
141    /// Constructs an empty connection.
142    fn empty(opts: Opts) -> ConnInner {
143        ConnInner {
144            capabilities: opts.get_capabilities(),
145            status: StatusFlags::empty(),
146            last_ok_packet: None,
147            last_err_packet: None,
148            stream: None,
149            is_mariadb: false,
150            version: (0, 0, 0),
151            id: 0,
152            pending_result: Ok(None),
153            pool: None,
154            tx_status: TxStatus::None,
155            last_io: Instant::now(),
156            wait_timeout: Duration::from_secs(0),
157            stmt_cache: StmtCache::new(opts.stmt_cache_size()),
158            socket: opts.socket().map(Into::into),
159            opts,
160            nonce: Vec::default(),
161            auth_plugin: AuthPlugin::MysqlNativePassword,
162            auth_switched: false,
163            disconnected: false,
164            server_key: None,
165            infile_handler: None,
166            reset_upon_returning_to_a_pool: false,
167        }
168    }
169
170    /// Returns mutable reference to a connection stream.
171    ///
172    /// Returns `DriverError::ConnectionClosed` if there is no stream.
173    fn stream_mut(&mut self) -> Result<&mut Stream> {
174        self.stream
175            .as_mut()
176            .ok_or_else(|| DriverError::ConnectionClosed.into())
177    }
178}
179
/// MySql server connection.
///
/// A thin handle around the boxed connection state ([`ConnInner`]).
#[derive(Debug)]
pub struct Conn {
    /// Boxed connection state.
    inner: Box<ConnInner>,
}
185
186impl Conn {
    /// Returns connection identifier.
    ///
    /// This is the id taken from the server handshake packet (`0` until the handshake
    /// has been handled).
    pub fn id(&self) -> u32 {
        self.inner.id
    }
191
192    /// Returns the ID generated by a query (usually `INSERT`) on a table with a column having the
193    /// `AUTO_INCREMENT` attribute. Returns `None` if there was no previous query on the connection
194    /// or if the query did not update an AUTO_INCREMENT value.
195    pub fn last_insert_id(&self) -> Option<u64> {
196        self.inner
197            .last_ok_packet
198            .as_ref()
199            .and_then(|ok| ok.last_insert_id())
200    }
201
202    /// Returns the number of rows affected by the last `INSERT`, `UPDATE`, `REPLACE` or `DELETE`
203    /// query.
204    pub fn affected_rows(&self) -> u64 {
205        self.inner
206            .last_ok_packet
207            .as_ref()
208            .map(|ok| ok.affected_rows())
209            .unwrap_or_default()
210    }
211
212    /// Text information, as reported by the server in the last OK packet, or an empty string.
213    pub fn info(&self) -> Cow<'_, str> {
214        self.inner
215            .last_ok_packet
216            .as_ref()
217            .and_then(|ok| ok.info_str())
218            .unwrap_or_else(|| "".into())
219    }
220
221    /// Number of warnings, as reported by the server in the last OK packet, or `0`.
222    pub fn get_warnings(&self) -> u16 {
223        self.inner
224            .last_ok_packet
225            .as_ref()
226            .map(|ok| ok.warnings())
227            .unwrap_or_default()
228    }
229
    /// Returns a reference to the last OK packet.
    ///
    /// `None` if no OK packet was handled yet, or if a server error was handled after the
    /// last one (handling an error clears the stored packet).
    pub fn last_ok_packet(&self) -> Option<&OkPacket<'static>> {
        self.inner.last_ok_packet.as_ref()
    }
234
    /// Turns on/off automatic connection reset (see [`crate::PoolOpts::with_reset_connection`]).
    ///
    /// Only makes sense for pooled connections. This merely stores the flag on the connection;
    /// the code acting on it is not part of this method.
    pub fn reset_connection(&mut self, reset_connection: bool) {
        self.inner.reset_upon_returning_to_a_pool = reset_connection;
    }
241
    /// Returns mutable reference to the connection stream.
    ///
    /// Returns `DriverError::ConnectionClosed` if there is no stream.
    pub(crate) fn stream_mut(&mut self) -> Result<&mut Stream> {
        self.inner.stream_mut()
    }
245
    /// Returns the capability flags currently in effect for this connection.
    ///
    /// These are the client's flags at first and the intersection of the server's and the
    /// client's flags once the handshake was handled.
    pub(crate) fn capabilities(&self) -> CapabilityFlags {
        self.inner.capabilities
    }
249
    /// Will update last IO time for this connection (sets it to `Instant::now()`).
    pub(crate) fn touch(&mut self) {
        self.inner.last_io = Instant::now();
    }
254
255    /// Will set packet sequence id to `0`.
256    pub(crate) fn reset_seq_id(&mut self) {
257        if let Some(stream) = self.inner.stream.as_mut() {
258            stream.reset_seq_id();
259        }
260    }
261
262    /// Will syncronize sequence ids between compressed and uncompressed codecs.
263    pub(crate) fn sync_seq_id(&mut self) {
264        if let Some(stream) = self.inner.stream.as_mut() {
265            stream.sync_seq_id();
266        }
267    }
268
269    /// Handles OK packet.
270    pub(crate) fn handle_ok(&mut self, ok_packet: OkPacket<'static>) {
271        self.inner.status = ok_packet.status_flags();
272        self.inner.last_err_packet = None;
273        self.inner.last_ok_packet = Some(ok_packet);
274    }
275
276    /// Handles ERR packet.
277    pub(crate) fn handle_err(&mut self, err_packet: ErrPacket<'_>) -> Result<()> {
278        match err_packet {
279            ErrPacket::Error(err) => {
280                self.inner.status = StatusFlags::empty();
281                self.inner.last_ok_packet = None;
282                self.inner.last_err_packet = Some(err.clone().into_owned());
283                Err(Error::from(err))
284            }
285            ErrPacket::Progress(_) => Ok(()),
286        }
287    }
288
    /// Returns the current transaction status (as last stored by [`Conn::set_tx_status`],
    /// `TxStatus::None` for a fresh connection).
    pub(crate) fn get_tx_status(&self) -> TxStatus {
        self.inner.tx_status
    }
293
    /// Sets the given transaction status for this connection.
    ///
    /// Only the stored status is updated; no command is sent to the server here.
    pub(crate) fn set_tx_status(&mut self, tx_status: TxStatus) {
        self.inner.tx_status = tx_status;
    }
298
299    /// Returns pending result metadata, if any.
300    ///
301    /// If `Some(_)`, then result is not yet consumed.
302    pub(crate) fn use_pending_result(
303        &mut self,
304    ) -> std::result::Result<Option<&PendingResult>, ServerError> {
305        if let Err(ref e) = self.inner.pending_result {
306            let e = e.clone();
307            self.inner.pending_result = Ok(None);
308            return Err(e);
309        } else {
310            Ok(self.inner.pending_result.as_ref().unwrap().as_ref())
311        }
312    }
313
314    pub(crate) fn get_pending_result(
315        &self,
316    ) -> std::result::Result<Option<&PendingResult>, &ServerError> {
317        self.inner.pending_result.as_ref().map(|x| x.as_ref())
318    }
319
320    pub(crate) fn has_pending_result(&self) -> bool {
321        matches!(self.inner.pending_result, Err(_))
322            || matches!(self.inner.pending_result, Ok(Some(_)))
323    }
324
325    /// Sets the given pening result metadata for this connection. Returns the previous value.
326    pub(crate) fn set_pending_result(
327        &mut self,
328        meta: Option<ResultSetMeta>,
329    ) -> std::result::Result<Option<PendingResult>, ServerError> {
330        replace(
331            &mut self.inner.pending_result,
332            Ok(meta.map(PendingResult::Pending)),
333        )
334    }
335
336    pub(crate) fn set_pending_result_error(
337        &mut self,
338        error: ServerError,
339    ) -> std::result::Result<Option<PendingResult>, ServerError> {
340        replace(&mut self.inner.pending_result, Err(error))
341    }
342
343    /// Gives the currently pending result to a caller for consumption.
344    pub(crate) fn take_pending_result(
345        &mut self,
346    ) -> std::result::Result<Option<Arc<ResultSetMeta>>, ServerError> {
347        let mut output = None;
348
349        self.inner.pending_result = match replace(&mut self.inner.pending_result, Ok(None))? {
350            Some(PendingResult::Pending(x)) => {
351                let meta = Arc::new(x);
352                output = Some(meta.clone());
353                Ok(Some(PendingResult::Taken(meta)))
354            }
355            x => Ok(x),
356        };
357
358        Ok(output)
359    }
360
    /// Returns current status flags (from the handshake or the last OK packet; empty after a
    /// server error was handled).
    pub(crate) fn status(&self) -> StatusFlags {
        self.inner.status
    }
365
366    pub(crate) async fn routine<'a, F, T>(&mut self, mut f: F) -> crate::Result<T>
367    where
368        F: Routine<T> + 'a,
369    {
370        self.inner.disconnected = true;
371        let result = f.call(&mut *self).await;
372        match result {
373            result @ Ok(_) | result @ Err(crate::Error::Server(_)) => {
374                // either OK or non-fatal error
375                self.inner.disconnected = false;
376                result
377            }
378            Err(err) => {
379                if self.inner.stream.is_some() {
380                    self.take_stream().close().await?;
381                }
382                Err(err)
383            }
384        }
385    }
386
    /// Returns server version.
    ///
    /// The version is parsed from the handshake; `(0, 0, 0)` means it is not known (no
    /// handshake yet, or the version could not be parsed).
    pub fn server_version(&self) -> (u16, u16, u16) {
        self.inner.version
    }
391
    /// Returns connection options (the `Opts` this connection was created with).
    pub fn opts(&self) -> &Opts {
        &self.inner.opts
    }
396
397    /// Setup _local_ `LOCAL INFILE` handler (see ["LOCAL INFILE Handlers"][2] section
398    /// of the crate-level docs).
399    ///
400    /// It'll overwrite existing _local_ handler, if any.
401    ///
402    /// [2]: ../mysql_async/#local-infile-handlers
403    pub fn set_infile_handler<T>(&mut self, handler: T)
404    where
405        T: Future<Output = crate::Result<InfileData>>,
406        T: Send + Sync + 'static,
407    {
408        self.inner.infile_handler = Some(Box::pin(handler));
409    }
410
411    fn take_stream(&mut self) -> Stream {
412        self.inner.stream.take().unwrap()
413    }
414
415    /// Disconnects this connection from server.
416    pub async fn disconnect(mut self) -> Result<()> {
417        if !self.inner.disconnected {
418            self.inner.disconnected = true;
419            self.write_command_data(Command::COM_QUIT, &[]).await?;
420            let stream = self.take_stream();
421            stream.close().await?;
422        }
423        Ok(())
424    }
425
426    /// Closes the connection.
427    async fn close_conn(mut self) -> Result<()> {
428        self = self.cleanup_for_pool().await?;
429        self.disconnect().await
430    }
431
432    /// Returns true if io stream is encrypted.
433    fn is_secure(&self) -> bool {
434        #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
435        {
436            self.inner
437                .stream
438                .as_ref()
439                .map(|x| x.is_secure())
440                .unwrap_or_default()
441        }
442
443        #[cfg(not(any(feature = "native-tls-tls", feature = "rustls-tls")))]
444        false
445    }
446
447    /// Returns true if io stream is socket.
448    fn is_socket(&self) -> bool {
449        #[cfg(unix)]
450        {
451            self.inner
452                .stream
453                .as_ref()
454                .map(|x| x.is_socket())
455                .unwrap_or_default()
456        }
457
458        #[cfg(not(unix))]
459        false
460    }
461
462    /// Hacky way to move connection through &mut. `self` becomes unusable.
463    fn take(&mut self) -> Conn {
464        mem::replace(self, Conn::empty(Default::default()))
465    }
466
467    fn empty(opts: Opts) -> Self {
468        Self {
469            inner: Box::new(ConnInner::empty(opts)),
470        }
471    }
472
473    /// Set `io::Stream` options as defined in the `Opts` of the connection.
474    ///
475    /// Requires that self.inner.stream is Some
476    fn setup_stream(&mut self) -> Result<()> {
477        debug_assert!(self.inner.stream.is_some());
478        if let Some(stream) = self.inner.stream.as_mut() {
479            stream.set_tcp_nodelay(self.inner.opts.tcp_nodelay())?;
480        }
481        Ok(())
482    }
483
484    async fn handle_handshake(&mut self) -> Result<()> {
485        let packet = self.read_packet().await?;
486        let handshake = ParseBuf(&*packet).parse::<HandshakePacket>(())?;
487
488        // Handshake scramble is always 21 bytes length (20 + zero terminator)
489        self.inner.nonce = {
490            let mut nonce = Vec::from(handshake.scramble_1_ref());
491            nonce.extend_from_slice(handshake.scramble_2_ref().unwrap_or(&[][..]));
492            // Trim zero terminator. Fill with zeroes if nonce
493            // is somehow smaller than 20 bytes (this matches the server behavior).
494            nonce.resize(20, 0);
495            nonce
496        };
497
498        self.inner.capabilities = handshake.capabilities() & self.inner.opts.get_capabilities();
499        self.inner.version = handshake
500            .maria_db_server_version_parsed()
501            .map(|version| {
502                self.inner.is_mariadb = true;
503                version
504            })
505            .or_else(|| handshake.server_version_parsed())
506            .unwrap_or((0, 0, 0));
507        self.inner.id = handshake.connection_id();
508        self.inner.status = handshake.status_flags();
509
510        // Allow only CachingSha2Password and MysqlNativePassword here
511        // because sha256_password is deprecated and other plugins won't
512        // appear here.
513        self.inner.auth_plugin = match handshake.auth_plugin() {
514            Some(AuthPlugin::CachingSha2Password) => AuthPlugin::CachingSha2Password,
515            _ => AuthPlugin::MysqlNativePassword,
516        };
517
518        Ok(())
519    }
520
521    async fn switch_to_ssl_if_needed(&mut self) -> Result<()> {
522        if self
523            .inner
524            .opts
525            .get_capabilities()
526            .contains(CapabilityFlags::CLIENT_SSL)
527        {
528            if !self
529                .inner
530                .capabilities
531                .contains(CapabilityFlags::CLIENT_SSL)
532            {
533                return Err(DriverError::NoClientSslFlagFromServer.into());
534            }
535
536            let collation = if self.inner.version >= (5, 5, 3) {
537                UTF8MB4_GENERAL_CI
538            } else {
539                UTF8_GENERAL_CI
540            };
541
542            let ssl_request = SslRequest::new(
543                self.inner.capabilities,
544                DEFAULT_MAX_ALLOWED_PACKET as u32,
545                collation as u8,
546            );
547            self.write_struct(&ssl_request).await?;
548            let conn = self;
549            let ssl_opts = conn.opts().ssl_opts().cloned().expect("unreachable");
550            let domain = conn.opts().ip_or_hostname().into();
551            conn.stream_mut()?.make_secure(domain, ssl_opts).await?;
552            Ok(())
553        } else {
554            Ok(())
555        }
556    }
557
558    async fn do_handshake_response(&mut self) -> Result<()> {
559        let auth_data = self
560            .inner
561            .auth_plugin
562            .gen_data(self.inner.opts.pass(), &*self.inner.nonce);
563
564        let handshake_response = HandshakeResponse::new(
565            auth_data.as_deref(),
566            self.inner.version,
567            self.inner.opts.user().map(|x| x.as_bytes()),
568            self.inner.opts.db_name().map(|x| x.as_bytes()),
569            Some(self.inner.auth_plugin.borrow()),
570            self.capabilities(),
571            Default::default(), // TODO: Add support
572        );
573
574        // Serialize here to satisfy borrow checker.
575        let mut buf = crate::BUFFER_POOL.get();
576        handshake_response.serialize(buf.as_mut());
577
578        self.write_packet(buf).await?;
579        Ok(())
580    }
581
582    async fn perform_auth_switch(
583        &mut self,
584        auth_switch_request: AuthSwitchRequest<'_>,
585    ) -> Result<()> {
586        if !self.inner.auth_switched {
587            self.inner.auth_switched = true;
588            self.inner.nonce = auth_switch_request.plugin_data().to_vec();
589
590            if matches!(
591                auth_switch_request.auth_plugin(),
592                AuthPlugin::MysqlOldPassword
593            ) {
594                if self.inner.opts.secure_auth() {
595                    return Err(DriverError::MysqlOldPasswordDisabled.into());
596                }
597            }
598
599            self.inner.auth_plugin = auth_switch_request.auth_plugin().clone().into_owned();
600
601            let plugin_data = match &self.inner.auth_plugin {
602                x @ AuthPlugin::CachingSha2Password => {
603                    x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
604                }
605                x @ AuthPlugin::MysqlNativePassword => {
606                    x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
607                }
608                x @ AuthPlugin::MysqlOldPassword => {
609                    if self.inner.opts.secure_auth() {
610                        return Err(DriverError::MysqlOldPasswordDisabled.into());
611                    } else {
612                        x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
613                    }
614                }
615                x @ AuthPlugin::MysqlClearPassword => {
616                    if self.inner.opts.enable_cleartext_plugin() {
617                        x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
618                    } else {
619                        return Err(DriverError::CleartextPluginDisabled.into());
620                    }
621                }
622                x @ AuthPlugin::Other(_) => x.gen_data(self.inner.opts.pass(), &self.inner.nonce),
623            };
624
625            if let Some(plugin_data) = plugin_data {
626                self.write_struct(&plugin_data.into_owned()).await?;
627            } else {
628                self.write_packet(crate::BUFFER_POOL.get()).await?;
629            }
630
631            self.continue_auth().await?;
632
633            Ok(())
634        } else {
635            unreachable!("auth_switched flag should be checked by caller")
636        }
637    }
638
639    fn continue_auth(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
640        // NOTE: we need to box this since it may recurse
641        // see https://github.com/rust-lang/rust/issues/46415#issuecomment-528099782
642        Box::pin(async move {
643            match self.inner.auth_plugin {
644                AuthPlugin::MysqlNativePassword | AuthPlugin::MysqlOldPassword => {
645                    self.continue_mysql_native_password_auth().await?;
646                    Ok(())
647                }
648                AuthPlugin::CachingSha2Password => {
649                    self.continue_caching_sha2_password_auth().await?;
650                    Ok(())
651                }
652                AuthPlugin::MysqlClearPassword => {
653                    if self.inner.opts.enable_cleartext_plugin() {
654                        self.continue_mysql_native_password_auth().await?;
655                        Ok(())
656                    } else {
657                        Err(DriverError::CleartextPluginDisabled.into())
658                    }
659                }
660                AuthPlugin::Other(ref name) => Err(DriverError::UnknownAuthPlugin {
661                    name: String::from_utf8_lossy(name.as_ref()).to_string(),
662                }
663                .into()),
664            }
665        })
666    }
667
668    fn switch_to_compression(&mut self) -> Result<()> {
669        if self
670            .capabilities()
671            .contains(CapabilityFlags::CLIENT_COMPRESS)
672        {
673            if let Some(compression) = self.inner.opts.compression() {
674                if let Some(stream) = self.inner.stream.as_mut() {
675                    stream.compress(compression);
676                }
677            }
678        }
679        Ok(())
680    }
681
682    async fn continue_caching_sha2_password_auth(&mut self) -> Result<()> {
683        let packet = self.read_packet().await?;
684        match packet.get(0) {
685            Some(0x00) => {
686                // ok packet for empty password
687                Ok(())
688            }
689            Some(0x01) => match packet.get(1) {
690                Some(0x03) => {
691                    // auth ok
692                    self.drop_packet().await
693                }
694                Some(0x04) => {
695                    let pass = self.inner.opts.pass().unwrap_or_default();
696                    let mut pass = crate::BUFFER_POOL.get_with(pass.as_bytes());
697                    pass.as_mut().push(0);
698
699                    if self.is_secure() || self.is_socket() {
700                        self.write_packet(pass).await?;
701                    } else {
702                        if self.inner.server_key.is_none() {
703                            self.write_bytes(&[0x02][..]).await?;
704                            let packet = self.read_packet().await?;
705                            self.inner.server_key = Some(packet[1..].to_vec());
706                        }
707                        for (i, byte) in pass.as_mut().iter_mut().enumerate() {
708                            *byte ^= self.inner.nonce[i % self.inner.nonce.len()];
709                        }
710                        let encrypted_pass = crypto::encrypt(
711                            &*pass,
712                            self.inner.server_key.as_deref().expect("unreachable"),
713                        );
714                        self.write_bytes(&*encrypted_pass).await?;
715                    };
716                    self.drop_packet().await?;
717                    Ok(())
718                }
719                _ => Err(DriverError::UnexpectedPacket {
720                    payload: packet.to_vec(),
721                }
722                .into()),
723            },
724            Some(0xfe) if !self.inner.auth_switched => {
725                let auth_switch_request = ParseBuf(&*packet).parse::<AuthSwitchRequest>(())?;
726                self.perform_auth_switch(auth_switch_request).await?;
727                Ok(())
728            }
729            _ => Err(DriverError::UnexpectedPacket {
730                payload: packet.to_vec(),
731            }
732            .into()),
733        }
734    }
735
736    async fn continue_mysql_native_password_auth(&mut self) -> Result<()> {
737        let packet = self.read_packet().await?;
738        match packet.get(0) {
739            Some(0x00) => Ok(()),
740            Some(0xfe) if !self.inner.auth_switched => {
741                let auth_switch = if packet.len() > 1 {
742                    ParseBuf(&*packet).parse(())?
743                } else {
744                    let _ = ParseBuf(&*packet).parse::<OldAuthSwitchRequest>(())?;
745                    // map OldAuthSwitch to AuthSwitch with mysql_old_password plugin
746                    AuthSwitchRequest::new(
747                        "mysql_old_password".as_bytes(),
748                        self.inner.nonce.clone(),
749                    )
750                };
751                self.perform_auth_switch(auth_switch).await
752            }
753            _ => Err(DriverError::UnexpectedPacket {
754                payload: packet.to_vec(),
755            }
756            .into()),
757        }
758    }
759
760    /// Returns `true` for ProgressReport packet.
761    fn handle_packet(&mut self, packet: &PooledBuf) -> Result<bool> {
762        let ok_packet = if self.has_pending_result() {
763            if self
764                .capabilities()
765                .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF)
766            {
767                ParseBuf(&*packet)
768                    .parse::<OkPacketDeserializer<ResultSetTerminator>>(self.capabilities())
769                    .map(|x| x.into_inner())
770            } else {
771                ParseBuf(&*packet)
772                    .parse::<OkPacketDeserializer<OldEofPacket>>(self.capabilities())
773                    .map(|x| x.into_inner())
774            }
775        } else {
776            ParseBuf(&*packet)
777                .parse::<OkPacketDeserializer<CommonOkPacket>>(self.capabilities())
778                .map(|x| x.into_inner())
779        };
780
781        if let Ok(ok_packet) = ok_packet {
782            self.handle_ok(ok_packet.into_owned());
783        } else {
784            let err_packet = ParseBuf(&*packet).parse::<ErrPacket>(self.capabilities());
785            if let Ok(err_packet) = err_packet {
786                self.handle_err(err_packet)?;
787                return Ok(true);
788            }
789        }
790
791        Ok(false)
792    }
793
794    pub(crate) async fn read_packet(&mut self) -> Result<PooledBuf> {
795        loop {
796            let packet = crate::io::ReadPacket::new(&mut *self)
797                .await
798                .map_err(|io_err| {
799                    self.inner.stream.take();
800                    self.inner.disconnected = true;
801                    Error::from(io_err)
802                })?;
803            if self.handle_packet(&packet)? {
804                // ignore progress report
805                continue;
806            } else {
807                return Ok(packet);
808            }
809        }
810    }
811
812    /// Returns future that reads packets from a server.
813    pub(crate) async fn read_packets(&mut self, n: usize) -> Result<Vec<PooledBuf>> {
814        let mut packets = Vec::with_capacity(n);
815        for _ in 0..n {
816            packets.push(self.read_packet().await?);
817        }
818        Ok(packets)
819    }
820
821    pub(crate) async fn write_packet(&mut self, data: PooledBuf) -> Result<()> {
822        crate::io::WritePacket::new(&mut *self, data)
823            .await
824            .map_err(|io_err| {
825                self.inner.stream.take();
826                self.inner.disconnected = true;
827                From::from(io_err)
828            })
829    }
830
831    /// Writes bytes to a server.
832    pub(crate) async fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> {
833        let buf = crate::BUFFER_POOL.get_with(bytes);
834        self.write_packet(buf).await
835    }
836
837    /// Sends a serializable structure to a server.
838    pub(crate) async fn write_struct<T: MySerialize>(&mut self, x: &T) -> Result<()> {
839        let mut buf = crate::BUFFER_POOL.get();
840        x.serialize(buf.as_mut());
841        self.write_packet(buf).await
842    }
843
    /// Sends a command to a server.
    ///
    /// Runs `clean_dirty` (defined outside of this chunk) first, then resets the packet
    /// sequence id and finally serializes and writes `cmd`.
    pub(crate) async fn write_command<T: MySerialize>(&mut self, cmd: &T) -> Result<()> {
        self.clean_dirty().await?;
        self.reset_seq_id();
        self.write_struct(cmd).await
    }
850
    /// Sends full command body to a server.
    ///
    /// `body` must not be empty (checked in debug builds only). Runs `clean_dirty` (defined
    /// outside of this chunk) first, then resets the packet sequence id and writes the body.
    pub(crate) async fn write_command_raw(&mut self, body: PooledBuf) -> Result<()> {
        debug_assert!(!body.is_empty());
        self.clean_dirty().await?;
        self.reset_seq_id();
        self.write_packet(body).await
    }
858
859    /// Returns future that writes command to a server.
860    pub(crate) async fn write_command_data<T>(&mut self, cmd: Command, cmd_data: T) -> Result<()>
861    where
862        T: AsRef<[u8]>,
863    {
864        let cmd_data = cmd_data.as_ref();
865        let mut buf = crate::BUFFER_POOL.get();
866        let body = buf.as_mut();
867        body.push(cmd as u8);
868        body.extend_from_slice(cmd_data);
869        self.write_command_raw(buf).await
870    }
871
872    async fn drop_packet(&mut self) -> Result<()> {
873        self.read_packet().await?;
874        Ok(())
875    }
876
877    async fn run_init_commands(&mut self) -> Result<()> {
878        let mut init = self.inner.opts.init().to_vec();
879
880        while let Some(query) = init.pop() {
881            self.query_drop(query).await?;
882        }
883
884        Ok(())
885    }
886
887    async fn run_setup_commands(&mut self) -> Result<()> {
888        let mut setup = self.inner.opts.setup().to_vec();
889
890        while let Some(query) = setup.pop() {
891            self.query_drop(query).await?;
892        }
893
894        Ok(())
895    }
896
897    /// Returns a future that resolves to [`Conn`].
898    pub fn new<T: Into<Opts>>(opts: T) -> crate::BoxFuture<'static, Conn> {
899        let opts = opts.into();
900        async move {
901            let mut conn = Conn::empty(opts.clone());
902
903            let stream = if let Some(_path) = opts.socket() {
904                #[cfg(unix)]
905                {
906                    Stream::connect_socket(_path.to_owned()).await?
907                }
908                #[cfg(target_os = "windows")]
909                return Err(crate::DriverError::NamedPipesDisabled.into());
910            } else {
911                let keepalive = opts
912                    .tcp_keepalive()
913                    .map(|x| std::time::Duration::from_millis(x.into()));
914                Stream::connect_tcp(opts.hostport_or_url(), keepalive).await?
915            };
916
917            conn.inner.stream = Some(stream);
918            conn.setup_stream()?;
919            conn.handle_handshake().await?;
920            conn.switch_to_ssl_if_needed().await?;
921            conn.do_handshake_response().await?;
922            conn.continue_auth().await?;
923            conn.switch_to_compression()?;
924            conn.read_settings().await?;
925            conn.reconnect_via_socket_if_needed().await?;
926            conn.run_init_commands().await?;
927            conn.run_setup_commands().await?;
928
929            Ok(conn)
930        }
931        .boxed()
932    }
933
934    /// Returns a future that resolves to [`Conn`].
935    pub async fn from_url<T: AsRef<str>>(url: T) -> Result<Conn> {
936        Conn::new(Opts::from_str(url.as_ref())?).await
937    }
938
939    /// Will try to reconnect via socket using socket address in `self.inner.socket`.
940    ///
941    /// Won't try to reconnect if socket connection is already enforced in [`Opts`].
942    async fn reconnect_via_socket_if_needed(&mut self) -> Result<()> {
943        if let Some(socket) = self.inner.socket.as_ref() {
944            let opts = self.inner.opts.clone();
945            if opts.socket().is_none() {
946                let opts = OptsBuilder::from_opts(opts).socket(Some(&**socket));
947                if let Ok(conn) = Conn::new(opts).await {
948                    let old_conn = std::mem::replace(self, conn);
949                    // tidy up the old connection
950                    old_conn.close_conn().await?;
951                }
952            }
953        }
954        Ok(())
955    }
956
957    /// Configures the connection based on server settings. In particular:
958    ///
959    /// * It reads and stores socket address inside the connection unless if socket address is
960    /// already in [`Opts`] or if `prefer_socket` is `false`.
961    ///
962    /// * It reads and stores `max_allowed_packet` in the connection unless it's already in [`Opts`]
963    ///
964    /// * It reads and stores `wait_timeout` in the connection unless it's already in [`Opts`]
965    ///
966    async fn read_settings(&mut self) -> Result<()> {
967        enum Action {
968            Load(Cfg),
969            Apply(CfgData),
970        }
971
972        enum CfgData {
973            MaxAllowedPacket(usize),
974            WaitTimeout(usize),
975        }
976
977        impl CfgData {
978            fn apply(&self, conn: &mut Conn) {
979                match self {
980                    Self::MaxAllowedPacket(value) => {
981                        if let Some(stream) = conn.inner.stream.as_mut() {
982                            stream.set_max_allowed_packet(*value);
983                        }
984                    }
985                    Self::WaitTimeout(value) => {
986                        conn.inner.wait_timeout = Duration::from_secs(*value as u64);
987                    }
988                }
989            }
990        }
991
992        enum Cfg {
993            Socket,
994            MaxAllowedPacket,
995            WaitTimeout,
996        }
997
998        impl Cfg {
999            const fn name(&self) -> &'static str {
1000                match self {
1001                    Self::Socket => "@@socket",
1002                    Self::MaxAllowedPacket => "@@max_allowed_packet",
1003                    Self::WaitTimeout => "@@wait_timeout",
1004                }
1005            }
1006
1007            fn apply(&self, conn: &mut Conn, value: Option<crate::Value>) {
1008                match self {
1009                    Cfg::Socket => {
1010                        conn.inner.socket = value.map(crate::from_value).flatten();
1011                    }
1012                    Cfg::MaxAllowedPacket => {
1013                        if let Some(stream) = conn.inner.stream.as_mut() {
1014                            stream.set_max_allowed_packet(
1015                                value
1016                                    .map(crate::from_value)
1017                                    .flatten()
1018                                    .unwrap_or(DEFAULT_MAX_ALLOWED_PACKET),
1019                            );
1020                        }
1021                    }
1022                    Cfg::WaitTimeout => {
1023                        conn.inner.wait_timeout = Duration::from_secs(
1024                            value
1025                                .map(crate::from_value)
1026                                .flatten()
1027                                .unwrap_or(DEFAULT_WAIT_TIMEOUT) as u64,
1028                        );
1029                    }
1030                }
1031            }
1032        }
1033
1034        let mut actions = vec![
1035            if let Some(x) = self.opts().max_allowed_packet() {
1036                Action::Apply(CfgData::MaxAllowedPacket(x))
1037            } else {
1038                Action::Load(Cfg::MaxAllowedPacket)
1039            },
1040            if let Some(x) = self.opts().wait_timeout() {
1041                Action::Apply(CfgData::WaitTimeout(x))
1042            } else {
1043                Action::Load(Cfg::WaitTimeout)
1044            },
1045        ];
1046
1047        if self.inner.opts.prefer_socket() && self.inner.socket.is_none() {
1048            actions.push(Action::Load(Cfg::Socket))
1049        }
1050
1051        let loads = actions
1052            .iter()
1053            .filter_map(|x| match x {
1054                Action::Load(x) => Some(x),
1055                Action::Apply(_) => None,
1056            })
1057            .collect::<Vec<_>>();
1058
1059        let loaded = if !loads.is_empty() {
1060            let query = loads
1061                .iter()
1062                .zip(std::iter::once(' ').chain(std::iter::repeat(',')))
1063                .fold("SELECT".to_owned(), |mut acc, (cfg, prefix)| {
1064                    acc.push(prefix);
1065                    acc.push_str(cfg.name());
1066                    acc
1067                });
1068
1069            self.query_internal::<Row, String>(query)
1070                .await?
1071                .map(|row| row.unwrap())
1072                .unwrap_or_else(|| vec![crate::Value::NULL; loads.len()])
1073        } else {
1074            vec![]
1075        };
1076        let mut loaded = loaded.into_iter();
1077
1078        for action in actions {
1079            match action {
1080                Action::Load(cfg) => cfg.apply(self, loaded.next()),
1081                Action::Apply(cfg) => cfg.apply(self),
1082            }
1083        }
1084
1085        Ok(())
1086    }
1087
1088    /// Returns true if time since last IO exceeds `wait_timeout`
1089    /// (or `conn_ttl` if specified in opts).
1090    fn expired(&self) -> bool {
1091        let ttl = self
1092            .inner
1093            .opts
1094            .conn_ttl()
1095            .unwrap_or(self.inner.wait_timeout);
1096        !ttl.is_zero() && self.idling() > ttl
1097    }
1098
1099    /// Returns duration since last IO.
1100    fn idling(&self) -> Duration {
1101        self.inner.last_io.elapsed()
1102    }
1103
1104    /// Executes [`COM_RESET_CONNECTION`][1].
1105    ///
1106    /// Returns `false` if command is not supported (requires MySql >5.7.2, MariaDb >10.2.3).
1107    /// For older versions consider using [`Conn::change_user`].
1108    ///
1109    /// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-reset-connection.html
1110    pub async fn reset(&mut self) -> Result<bool> {
1111        let supports_com_reset_connection = if self.inner.is_mariadb {
1112            self.inner.version >= (10, 2, 4)
1113        } else {
1114            // assuming mysql
1115            self.inner.version > (5, 7, 2)
1116        };
1117
1118        if supports_com_reset_connection {
1119            self.routine(routines::ResetRoutine).await?;
1120            self.inner.stmt_cache.clear();
1121            self.inner.infile_handler = None;
1122            self.run_setup_commands().await?;
1123        }
1124
1125        Ok(supports_com_reset_connection)
1126    }
1127
1128    /// Executes [`COM_CHANGE_USER`][1].
1129    ///
1130    /// This might be used as an older and slower alternative to `COM_RESET_CONNECTION` that
1131    /// works on MySql prior to 5.7.3 (MariaDb prior ot 10.2.4).
1132    ///
1133    /// ## Note
1134    ///
1135    /// * Using non-default `opts` for a pooled connection is discouraging.
1136    /// * Connection options will be permanently updated.
1137    ///
1138    /// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-change-user.html
1139    pub async fn change_user(&mut self, opts: ChangeUserOpts) -> Result<()> {
1140        // We'll kick this connection from a pool if opts are changed.
1141        if opts != ChangeUserOpts::default() {
1142            let mut opts_changed = false;
1143            if let Some(user) = opts.user() {
1144                opts_changed |= user != self.opts().user()
1145            };
1146            if let Some(pass) = opts.pass() {
1147                opts_changed |= pass != self.opts().pass()
1148            };
1149            if let Some(db_name) = opts.db_name() {
1150                opts_changed |= db_name != self.opts().db_name()
1151            };
1152            if opts_changed {
1153                if let Some(pool) = self.inner.pool.take() {
1154                    pool.cancel_connection();
1155                }
1156            }
1157        }
1158
1159        let conn_opts = &mut self.inner.opts;
1160        opts.update_opts(conn_opts);
1161        self.routine(routines::ChangeUser).await?;
1162        self.inner.stmt_cache.clear();
1163        self.inner.infile_handler = None;
1164        self.run_setup_commands().await?;
1165        Ok(())
1166    }
1167
1168    /// Resets the connection upon returning it to a pool.
1169    ///
1170    /// Will invoke `COM_CHANGE_USER` if `COM_RESET_CONNECTION` is not supported.
1171    async fn reset_for_pool(mut self) -> Result<Self> {
1172        if !self.reset().await? {
1173            self.change_user(Default::default()).await?;
1174        }
1175        Ok(self)
1176    }
1177
1178    /// Requires that `self.inner.tx_status != TxStatus::None`
1179    async fn rollback_transaction(&mut self) -> Result<()> {
1180        debug_assert_ne!(self.inner.tx_status, TxStatus::None);
1181        self.inner.tx_status = TxStatus::None;
1182        self.query_drop("ROLLBACK").await
1183    }
1184
1185    /// Returns `true` if `SERVER_MORE_RESULTS_EXISTS` flag is contained
1186    /// in status flags of the connection.
1187    pub(crate) fn more_results_exists(&self) -> bool {
1188        self.status()
1189            .contains(StatusFlags::SERVER_MORE_RESULTS_EXISTS)
1190    }
1191
1192    /// The purpose of this function is to cleanup a pending result set
1193    /// for prematurely dropeed connection or query result.
1194    ///
1195    /// Requires that there are no other references to the pending result.
1196    pub(crate) async fn drop_result(&mut self) -> Result<()> {
1197        // Map everything into `PendingResult::Pending`
1198        let meta = match self.set_pending_result(None)? {
1199            Some(PendingResult::Pending(meta)) => Some(meta),
1200            Some(PendingResult::Taken(meta)) => {
1201                // This also asserts that there is only one reference left to the taken ResultSetMeta,
1202                // therefore this result set must be dropped here since it won't be dropped anywhere else.
1203                Some(Arc::try_unwrap(meta).expect("Conn::drop_result call on a pending result that may still be droped by someone else"))
1204            }
1205            None => None,
1206        };
1207
1208        let _ = self.set_pending_result(meta);
1209
1210        match self.use_pending_result() {
1211            Ok(Some(PendingResult::Pending(ResultSetMeta::Text(_)))) => {
1212                QueryResult::<'_, '_, TextProtocol>::new(self)
1213                    .drop_result()
1214                    .await
1215            }
1216            Ok(Some(PendingResult::Pending(ResultSetMeta::Binary(_)))) => {
1217                QueryResult::<'_, '_, BinaryProtocol>::new(self)
1218                    .drop_result()
1219                    .await
1220            }
1221            Ok(None) => Ok((/* this case does not require an action */)),
1222            Ok(Some(PendingResult::Taken(_))) | Err(_) => {
1223                unreachable!("this case must be handled earlier in this function")
1224            }
1225        }
1226    }
1227
1228    /// This function will drop pending result and rollback a transaction, if needed.
1229    ///
1230    /// The purpose of this function, is to cleanup the connection while returning it to a [`Pool`].
1231    async fn cleanup_for_pool(mut self) -> Result<Self> {
1232        loop {
1233            let result = if self.has_pending_result() {
1234                self.drop_result().await
1235            } else if self.inner.tx_status != TxStatus::None {
1236                self.rollback_transaction().await
1237            } else {
1238                break;
1239            };
1240
1241            // The connection was dropped and we assume that it was dropped intentionally,
1242            // so we'll ignore non-fatal errors during cleanup (also there is no direct caller
1243            // to return this error to).
1244            if let Err(err) = result {
1245                if err.is_fatal() {
1246                    // This means that connection is completely broken
1247                    // and shouldn't return to a pool.
1248                    return Err(err);
1249                }
1250            }
1251        }
1252        Ok(self)
1253    }
1254
1255    async fn register_as_slave(&mut self, server_id: u32) -> Result<()> {
1256        use mysql_common::packets::ComRegisterSlave;
1257
1258        self.query_drop("SET @master_binlog_checksum='ALL'").await?;
1259        self.write_command(&ComRegisterSlave::new(server_id))
1260            .await?;
1261
1262        // Server will respond with OK.
1263        self.read_packet().await?;
1264
1265        Ok(())
1266    }
1267
1268    async fn request_binlog(&mut self, request: BinlogRequest<'_>) -> Result<()> {
1269        self.register_as_slave(request.server_id()).await?;
1270        self.write_command(&request.as_cmd()).await?;
1271        Ok(())
1272    }
1273
1274    pub async fn get_binlog_stream(mut self, request: BinlogRequest<'_>) -> Result<BinlogStream> {
1275        self.request_binlog(request).await?;
1276
1277        Ok(BinlogStream::new(self))
1278    }
1279}
1280
1281#[cfg(test)]
1282mod test {
1283    use bytes::Bytes;
1284    use futures_util::stream::{self, StreamExt};
1285    use mysql_common::{binlog::events::EventData, constants::MAX_PAYLOAD_LEN};
1286    use rand::Fill;
1287    use tokio::time::timeout;
1288
1289    use std::time::Duration;
1290
1291    use crate::{
1292        from_row, params, prelude::*, test_misc::get_opts, BinlogDumpFlags, BinlogRequest,
1293        ChangeUserOpts, Conn, Error, OptsBuilder, Pool, Value, WhiteListFsHandler,
1294    };
1295
1296    async fn gen_dummy_data() -> super::Result<()> {
1297        let mut conn = Conn::new(get_opts()).await?;
1298
1299        "CREATE TABLE IF NOT EXISTS customers (customer_id int not null)"
1300            .ignore(&mut conn)
1301            .await?;
1302
1303        for i in 0_u8..100 {
1304            "INSERT INTO customers(customer_id) VALUES (?)"
1305                .with((i,))
1306                .ignore(&mut conn)
1307                .await?;
1308        }
1309
1310        "DROP TABLE customers".ignore(&mut conn).await?;
1311
1312        Ok(())
1313    }
1314
1315    async fn create_binlog_stream_conn(pool: Option<&Pool>) -> super::Result<(Conn, Vec<u8>, u64)> {
1316        let mut conn = match pool {
1317            None => Conn::new(get_opts()).await.unwrap(),
1318            Some(pool) => pool.get_conn().await.unwrap(),
1319        };
1320
1321        if let Ok(Some(gtid_mode)) = "SELECT @@GLOBAL.GTID_MODE"
1322            .first::<String, _>(&mut conn)
1323            .await
1324        {
1325            if !gtid_mode.starts_with("ON") {
1326                panic!(
1327                    "GTID_MODE is disabled \
1328                        (enable using --gtid_mode=ON --enforce_gtid_consistency=ON)"
1329                );
1330            }
1331        }
1332
1333        let row: crate::Row = "SHOW BINARY LOGS".first(&mut conn).await.unwrap().unwrap();
1334        let filename = row.get(0).unwrap();
1335        let position = row.get(1).unwrap();
1336
1337        gen_dummy_data().await.unwrap();
1338        Ok((conn, filename, position))
1339    }
1340
1341    #[tokio::test]
1342    async fn should_read_binlog() -> super::Result<()> {
1343        read_binlog_streams_and_close_their_connections(None, (12, 13, 14))
1344            .await
1345            .unwrap();
1346
1347        let pool = Pool::new(get_opts());
1348        read_binlog_streams_and_close_their_connections(Some(&pool), (15, 16, 17))
1349            .await
1350            .unwrap();
1351
1352        // Disconnecting the pool verifies that closing the binlog connections
1353        // left the pool in a sane state.
1354        timeout(Duration::from_secs(10), pool.disconnect())
1355            .await
1356            .unwrap()
1357            .unwrap();
1358
1359        Ok(())
1360    }
1361
1362    #[tokio::test]
1363    async fn should_return_found_rows_if_flag_is_set() -> super::Result<()> {
1364        let opts = get_opts().client_found_rows(true);
1365        let mut conn = Conn::new(opts).await.unwrap();
1366
1367        "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)"
1368            .ignore(&mut conn)
1369            .await?;
1370
1371        "INSERT INTO mysql.found_rows (val) VALUES (1)"
1372            .ignore(&mut conn)
1373            .await?;
1374
1375        // Inserted one row, affected should be one.
1376        assert_eq!(conn.affected_rows(), 1);
1377
1378        "UPDATE mysql.found_rows SET val = 1 WHERE val = 1"
1379            .ignore(&mut conn)
1380            .await?;
1381
1382        // The query doesn't affect any rows, but due to us wanting FOUND rows,
1383        // this has to return one.
1384        assert_eq!(conn.affected_rows(), 1);
1385
1386        Ok(())
1387    }
1388
1389    #[tokio::test]
1390    async fn should_not_return_found_rows_if_flag_is_not_set() -> super::Result<()> {
1391        let mut conn = Conn::new(get_opts()).await.unwrap();
1392
1393        "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)"
1394            .ignore(&mut conn)
1395            .await?;
1396
1397        "INSERT INTO mysql.found_rows (val) VALUES (1)"
1398            .ignore(&mut conn)
1399            .await?;
1400
1401        // Inserted one row, affected should be one.
1402        assert_eq!(conn.affected_rows(), 1);
1403
1404        "UPDATE mysql.found_rows SET val = 1 WHERE val = 1"
1405            .ignore(&mut conn)
1406            .await?;
1407
1408        // The query doesn't affect any rows.
1409        assert_eq!(conn.affected_rows(), 0);
1410
1411        Ok(())
1412    }
1413
1414    async fn read_binlog_streams_and_close_their_connections(
1415        pool: Option<&Pool>,
1416        binlog_server_ids: (u32, u32, u32),
1417    ) -> super::Result<()> {
1418        // iterate using COM_BINLOG_DUMP
1419        let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap();
1420        let is_mariadb = conn.inner.is_mariadb;
1421
1422        let mut binlog_stream = conn
1423            .get_binlog_stream(
1424                BinlogRequest::new(binlog_server_ids.0)
1425                    .with_filename(filename)
1426                    .with_pos(pos),
1427            )
1428            .await
1429            .unwrap();
1430
1431        let mut events_num = 0;
1432        while let Ok(Some(event)) = timeout(Duration::from_secs(10), binlog_stream.next()).await {
1433            let event = event.unwrap();
1434            events_num += 1;
1435
1436            // assert that event type is known
1437            event.header().event_type().unwrap();
1438
1439            // iterate over rows of an event
1440            match event.read_data()?.unwrap() {
1441                EventData::RowsEvent(re) => {
1442                    let tme = binlog_stream.get_tme(re.table_id());
1443                    for row in re.rows(tme.unwrap()) {
1444                        row.unwrap();
1445                    }
1446                }
1447                _ => (),
1448            }
1449        }
1450        assert!(events_num > 0);
1451        timeout(Duration::from_secs(10), binlog_stream.close())
1452            .await
1453            .unwrap()
1454            .unwrap();
1455
1456        if !is_mariadb {
1457            // iterate using COM_BINLOG_DUMP_GTID
1458            let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap();
1459
1460            let mut binlog_stream = conn
1461                .get_binlog_stream(
1462                    BinlogRequest::new(binlog_server_ids.1)
1463                        .with_use_gtid(true)
1464                        .with_filename(filename)
1465                        .with_pos(pos),
1466                )
1467                .await
1468                .unwrap();
1469
1470            events_num = 0;
1471            while let Ok(Some(event)) = timeout(Duration::from_secs(10), binlog_stream.next()).await
1472            {
1473                let event = event.unwrap();
1474                events_num += 1;
1475
1476                // assert that event type is known
1477                event.header().event_type().unwrap();
1478
1479                // iterate over rows of an event
1480                match event.read_data()?.unwrap() {
1481                    EventData::RowsEvent(re) => {
1482                        let tme = binlog_stream.get_tme(re.table_id());
1483                        for row in re.rows(tme.unwrap()) {
1484                            row.unwrap();
1485                        }
1486                    }
1487                    _ => (),
1488                }
1489            }
1490            assert!(events_num > 0);
1491            timeout(Duration::from_secs(10), binlog_stream.close())
1492                .await
1493                .unwrap()
1494                .unwrap();
1495        }
1496
1497        // iterate using COM_BINLOG_DUMP with BINLOG_DUMP_NON_BLOCK flag
1498        let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap();
1499
1500        let mut binlog_stream = conn
1501            .get_binlog_stream(
1502                BinlogRequest::new(binlog_server_ids.2)
1503                    .with_filename(filename)
1504                    .with_pos(pos)
1505                    .with_flags(BinlogDumpFlags::BINLOG_DUMP_NON_BLOCK),
1506            )
1507            .await
1508            .unwrap();
1509
1510        events_num = 0;
1511        while let Some(event) = binlog_stream.next().await {
1512            let event = event.unwrap();
1513            events_num += 1;
1514            event.header().event_type().unwrap();
1515            event.read_data().unwrap();
1516        }
1517        assert!(events_num > 0);
1518        timeout(Duration::from_secs(10), binlog_stream.close())
1519            .await
1520            .unwrap()
1521            .unwrap();
1522
1523        Ok(())
1524    }
1525
#[test]
fn opts_should_satisfy_send_and_sync() {
    // Compiles only if the value returned by `get_opts` is both `Send` and `Sync`.
    fn assert_send_and_sync<T: Sync + Send>(_value: T) {}
    assert_send_and_sync(get_opts());
}
1531
1532    #[tokio::test]
1533    async fn should_connect_without_database() -> super::Result<()> {
1534        // no database name
1535        let mut conn: Conn = Conn::new(get_opts().db_name(None::<String>)).await?;
1536        conn.ping().await?;
1537        conn.disconnect().await?;
1538
1539        // empty database name
1540        let mut conn: Conn = Conn::new(get_opts().db_name(Some(""))).await?;
1541        conn.ping().await?;
1542        conn.disconnect().await?;
1543
1544        Ok(())
1545    }
1546
1547    #[tokio::test]
1548    async fn should_clean_state_if_wrapper_is_dropeed() -> super::Result<()> {
1549        let mut conn: Conn = Conn::new(get_opts()).await?;
1550
1551        conn.query_drop("CREATE TEMPORARY TABLE mysql.foo (id SERIAL)")
1552            .await?;
1553
1554        // dropped query:
1555        conn.query_iter("SELECT 1").await?;
1556        conn.ping().await?;
1557
1558        // dropped query in dropped transaction:
1559        let mut tx = conn.start_transaction(Default::default()).await?;
1560        tx.query_drop("INSERT INTO mysql.foo (id) VALUES (42)")
1561            .await?;
1562        tx.exec_iter("SELECT COUNT(*) FROM mysql.foo", ()).await?;
1563        drop(tx);
1564        conn.ping().await?;
1565
1566        let count: u8 = conn
1567            .query_first("SELECT COUNT(*) FROM mysql.foo")
1568            .await?
1569            .unwrap_or_default();
1570
1571        assert_eq!(count, 0);
1572
1573        Ok(())
1574    }
1575
1576    #[tokio::test]
1577    async fn should_connect() -> super::Result<()> {
1578        let mut conn: Conn = Conn::new(get_opts()).await?;
1579        conn.ping().await?;
1580        let plugins: Vec<String> = conn
1581            .query_map("SHOW PLUGINS", |mut row: crate::Row| {
1582                row.take("Name").unwrap()
1583            })
1584            .await?;
1585
1586        // Should connect with any combination of supported plugin and empty-nonempty password.
1587        let variants = vec![
1588            ("caching_sha2_password", 2_u8, "non-empty"),
1589            ("caching_sha2_password", 2_u8, ""),
1590            ("mysql_native_password", 0_u8, "non-empty"),
1591            ("mysql_native_password", 0_u8, ""),
1592        ]
1593        .into_iter()
1594        .filter(|variant| plugins.iter().any(|p| p == variant.0));
1595
1596        for (plug, val, pass) in variants {
1597            let _ = conn.query_drop("DROP USER 'test_user'@'%'").await;
1598
1599            let query = format!("CREATE USER 'test_user'@'%' IDENTIFIED WITH {}", plug);
1600            conn.query_drop(query).await.unwrap();
1601
1602            if (8, 0, 11) <= conn.inner.version && conn.inner.version <= (9, 0, 0) {
1603                conn.query_drop(format!("SET PASSWORD FOR 'test_user'@'%' = '{}'", pass))
1604                    .await
1605                    .unwrap();
1606            } else {
1607                conn.query_drop(format!("SET old_passwords = {}", val))
1608                    .await
1609                    .unwrap();
1610                conn.query_drop(format!(
1611                    "SET PASSWORD FOR 'test_user'@'%' = PASSWORD('{}')",
1612                    pass
1613                ))
1614                .await
1615                .unwrap();
1616            };
1617
1618            let opts = get_opts()
1619                .user(Some("test_user"))
1620                .pass(Some(pass))
1621                .db_name(None::<String>);
1622            let result = Conn::new(opts).await;
1623
1624            conn.query_drop("DROP USER 'test_user'@'%'").await.unwrap();
1625
1626            result?.disconnect().await?;
1627        }
1628
1629        if crate::test_misc::test_compression() {
1630            assert!(format!("{:?}", conn).contains("Compression"));
1631        }
1632
1633        if crate::test_misc::test_ssl() {
1634            assert!(format!("{:?}", conn).contains("Tls"));
1635        }
1636
1637        conn.disconnect().await?;
1638        Ok(())
1639    }
1640
#[test]
fn should_not_panic_if_dropped_without_tokio_runtime() {
    let fut = Conn::new(get_opts());
    let runtime = tokio::runtime::Runtime::new().unwrap();

    // Move the connection out of `block_on`. A value dropped inside the future is
    // dropped while the runtime context is entered, which is not the situation this
    // test is about.
    let conn = runtime.block_on(fut).unwrap();

    // The connection is dropped here, on a thread with no entered runtime context.
    drop(conn);
}
1650
1651    #[tokio::test]
1652    async fn should_execute_init_queries_on_new_connection() -> super::Result<()> {
1653        let opts = OptsBuilder::from_opts(get_opts()).init(vec!["SET @a = 42", "SET @b = 'foo'"]);
1654        let mut conn = Conn::new(opts).await?;
1655        let result: Vec<(u8, String)> = conn.query("SELECT @a, @b").await?;
1656        conn.disconnect().await?;
1657        assert_eq!(result, vec![(42, "foo".into())]);
1658        Ok(())
1659    }
1660
1661    #[tokio::test]
1662    async fn should_execute_setup_queries_on_reset() -> super::Result<()> {
1663        let opts = OptsBuilder::from_opts(get_opts()).setup(vec!["SET @a = 42", "SET @b = 'foo'"]);
1664        let mut conn = Conn::new(opts).await?;
1665
1666        // initial run
1667        let mut result: Vec<(u8, String)> = conn.query("SELECT @a, @b").await?;
1668        assert_eq!(result, vec![(42, "foo".into())]);
1669
1670        // after reset
1671        if conn.reset().await? {
1672            result = conn.query("SELECT @a, @b").await?;
1673            assert_eq!(result, vec![(42, "foo".into())]);
1674        }
1675
1676        // after change user
1677        conn.change_user(Default::default()).await?;
1678        result = conn.query("SELECT @a, @b").await?;
1679        assert_eq!(result, vec![(42, "foo".into())]);
1680
1681        conn.disconnect().await?;
1682        Ok(())
1683    }
1684
1685    #[tokio::test]
1686    async fn should_reset_the_connection() -> super::Result<()> {
1687        let mut conn = Conn::new(get_opts()).await?;
1688
1689        assert_eq!(
1690            conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
1691            Value::NULL
1692        );
1693
1694        conn.query_drop("SET @foo = 'foo'").await?;
1695
1696        assert_eq!(
1697            conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
1698            "foo",
1699        );
1700
1701        if conn.reset().await? {
1702            assert_eq!(
1703                conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
1704                Value::NULL
1705            );
1706        } else {
1707            assert_eq!(
1708                conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
1709                "foo",
1710            );
1711        }
1712
1713        conn.disconnect().await?;
1714        Ok(())
1715    }
1716
1717    #[tokio::test]
1718    async fn should_change_user() -> super::Result<()> {
1719        let mut conn = Conn::new(get_opts()).await?;
1720        assert_eq!(
1721            conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
1722            Value::NULL
1723        );
1724
1725        conn.query_drop("SET @foo = 'foo'").await?;
1726
1727        assert_eq!(
1728            conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
1729            "foo",
1730        );
1731
1732        conn.change_user(Default::default()).await?;
1733        assert_eq!(
1734            conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
1735            Value::NULL
1736        );
1737
1738        let plugins: &[&str] = if !conn.inner.is_mariadb && conn.server_version() >= (5, 8, 0) {
1739            &["mysql_native_password", "caching_sha2_password"]
1740        } else {
1741            &["mysql_native_password"]
1742        };
1743
1744        for plugin in plugins {
1745            let mut rng = rand::thread_rng();
1746            let mut pass = [0u8; 10];
1747            pass.try_fill(&mut rng).unwrap();
1748            let pass: String = IntoIterator::into_iter(pass)
1749                .map(|x| ((x % (123 - 97)) + 97) as char)
1750                .collect();
1751
1752            conn.query_drop("DELETE FROM mysql.user WHERE user = '__mats'")
1753                .await
1754                .unwrap();
1755            conn.query_drop("FLUSH PRIVILEGES").await.unwrap();
1756
1757            if conn.inner.is_mariadb || conn.server_version() < (5, 7, 0) {
1758                if matches!(conn.server_version(), (5, 6, _)) {
1759                    conn.query_drop("CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_old_password")
1760                        .await
1761                        .unwrap();
1762                    conn.query_drop(format!(
1763                        "SET PASSWORD FOR '__mats'@'%' = OLD_PASSWORD({})",
1764                        Value::from(pass.clone()).as_sql(false)
1765                    ))
1766                    .await
1767                    .unwrap();
1768                } else {
1769                    conn.query_drop("CREATE USER '__mats'@'%'").await.unwrap();
1770                    conn.query_drop(format!(
1771                        "SET PASSWORD FOR '__mats'@'%' = PASSWORD({})",
1772                        Value::from(pass.clone()).as_sql(false)
1773                    ))
1774                    .await
1775                    .unwrap();
1776                }
1777            } else {
1778                conn.query_drop(format!(
1779                    "CREATE USER '__mats'@'%' IDENTIFIED WITH {} BY {}",
1780                    plugin,
1781                    Value::from(pass.clone()).as_sql(false)
1782                ))
1783                .await
1784                .unwrap();
1785            };
1786
1787            conn.query_drop("FLUSH PRIVILEGES").await.unwrap();
1788
1789            let mut conn2 = Conn::new(get_opts().secure_auth(false)).await.unwrap();
1790            conn2
1791                .change_user(
1792                    ChangeUserOpts::default()
1793                        .with_db_name(None)
1794                        .with_user(Some("__mats".into()))
1795                        .with_pass(Some(pass)),
1796                )
1797                .await
1798                .unwrap();
1799            let (db, user) = conn2
1800                .query_first::<(Option<String>, String), _>("SELECT DATABASE(), USER();")
1801                .await
1802                .unwrap()
1803                .unwrap();
1804            assert_eq!(db, None);
1805            assert!(user.starts_with("__mats"));
1806
1807            conn2.disconnect().await.unwrap();
1808        }
1809
1810        conn.disconnect().await?;
1811        Ok(())
1812    }
1813
1814    #[tokio::test]
1815    async fn should_not_cache_statements_if_stmt_cache_size_is_zero() -> super::Result<()> {
1816        let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(0);
1817
1818        let mut conn = Conn::new(opts).await?;
1819        conn.exec_drop("DO ?", (1_u8,)).await?;
1820
1821        let stmt = conn.prep("DO 2").await?;
1822        conn.exec_drop(&stmt, ()).await?;
1823        conn.exec_drop(&stmt, ()).await?;
1824        conn.close(stmt).await?;
1825
1826        conn.exec_drop("DO 3", ()).await?;
1827        conn.exec_batch("DO 4", vec![(), ()]).await?;
1828        conn.exec_first::<u8, _, _>("DO 5", ()).await?;
1829        let row: Option<(crate::Value, usize)> = conn
1830            .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
1831            .await?;
1832
1833        assert_eq!(row.unwrap().1, 1);
1834        assert_eq!(conn.inner.stmt_cache.len(), 0);
1835
1836        conn.disconnect().await?;
1837
1838        Ok(())
1839    }
1840
1841    #[tokio::test]
1842    async fn should_hold_stmt_cache_size_bound() -> super::Result<()> {
1843        let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(3);
1844        let mut conn = Conn::new(opts).await?;
1845        conn.exec_drop("DO 1", ()).await?;
1846        conn.exec_drop("DO 2", ()).await?;
1847        conn.exec_drop("DO 3", ()).await?;
1848        conn.exec_drop("DO 1", ()).await?;
1849        conn.exec_drop("DO 4", ()).await?;
1850        conn.exec_drop("DO 3", ()).await?;
1851        conn.exec_drop("DO 5", ()).await?;
1852        conn.exec_drop("DO 6", ()).await?;
1853        let row_opt = conn
1854            .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
1855            .await?;
1856        let (_, count): (String, usize) = row_opt.unwrap();
1857        assert_eq!(count, 3);
1858        let order = conn
1859            .stmt_cache_ref()
1860            .iter()
1861            .map(|item| item.1.query.0.as_ref())
1862            .collect::<Vec<&[u8]>>();
1863        assert_eq!(order, &[b"DO 6", b"DO 5", b"DO 3"]);
1864        conn.disconnect().await?;
1865        Ok(())
1866    }
1867
1868    #[tokio::test]
1869    async fn should_perform_queries() -> super::Result<()> {
1870        let mut conn = Conn::new(get_opts()).await?;
1871        for x in (MAX_PAYLOAD_LEN - 2)..=(MAX_PAYLOAD_LEN + 2) {
1872            let long_string = ::std::iter::repeat('A').take(x).collect::<String>();
1873            let result: Vec<(String, u8)> = conn
1874                .query(format!(r"SELECT '{}', 231", long_string))
1875                .await?;
1876            assert_eq!((long_string, 231_u8), result[0]);
1877        }
1878        conn.disconnect().await?;
1879        Ok(())
1880    }
1881
1882    #[tokio::test]
1883    async fn should_query_drop() -> super::Result<()> {
1884        let mut conn = Conn::new(get_opts()).await?;
1885        conn.query_drop("CREATE TEMPORARY TABLE tmp (id int DEFAULT 10, name text)")
1886            .await?;
1887        conn.query_drop("INSERT INTO tmp VALUES (1, 'foo')").await?;
1888        let result: Option<u8> = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1889        conn.disconnect().await?;
1890        assert_eq!(result, Some(1_u8));
1891        Ok(())
1892    }
1893
1894    #[tokio::test]
1895    async fn should_prepare_statement() -> super::Result<()> {
1896        let mut conn = Conn::new(get_opts()).await?;
1897        let stmt = conn.prep(r"SELECT ?").await?;
1898        conn.close(stmt).await?;
1899        conn.disconnect().await?;
1900
1901        let mut conn = Conn::new(get_opts()).await?;
1902        let stmt = conn.prep(r"SELECT :foo").await?;
1903
1904        {
1905            let query = String::from("SELECT ?, ?");
1906            let stmt = conn.prep(&*query).await?;
1907            conn.close(stmt).await?;
1908            {
1909                let mut conn = Conn::new(get_opts()).await?;
1910                let stmt = conn.prep(&*query).await?;
1911                conn.close(stmt).await?;
1912                conn.disconnect().await?;
1913            }
1914        }
1915
1916        conn.close(stmt).await?;
1917        conn.disconnect().await?;
1918
1919        Ok(())
1920    }
1921
1922    #[tokio::test]
1923    async fn should_execute_statement() -> super::Result<()> {
1924        let long_string = ::std::iter::repeat('A')
1925            .take(18 * 1024 * 1024)
1926            .collect::<String>();
1927        let mut conn = Conn::new(get_opts()).await?;
1928        let stmt = conn.prep(r"SELECT ?").await?;
1929        let result = conn.exec_iter(&stmt, (&long_string,)).await?;
1930        let mut mapped = result
1931            .map_and_drop(|row| from_row::<(String,)>(row))
1932            .await?;
1933        assert_eq!(mapped.len(), 1);
1934        assert_eq!(mapped.pop(), Some((long_string,)));
1935        let result = conn.exec_iter(&stmt, (42_u8,)).await?;
1936        let collected = result.collect_and_drop::<(u8,)>().await?;
1937        assert_eq!(collected, vec![(42u8,)]);
1938        let result = conn.exec_iter(&stmt, (8_u8,)).await?;
1939        let reduced = result
1940            .reduce_and_drop(2, |mut acc, row| {
1941                acc += from_row::<i32>(row);
1942                acc
1943            })
1944            .await?;
1945        conn.close(stmt).await?;
1946        conn.disconnect().await?;
1947        assert_eq!(reduced, 10);
1948
1949        let mut conn = Conn::new(get_opts()).await?;
1950        let stmt = conn.prep(r"SELECT :foo, :bar, :foo, 3").await?;
1951        let result = conn
1952            .exec_iter(&stmt, params! { "foo" => "quux", "bar" => "baz" })
1953            .await?;
1954        let mut mapped = result
1955            .map_and_drop(|row| from_row::<(String, String, String, u8)>(row))
1956            .await?;
1957        assert_eq!(mapped.len(), 1);
1958        assert_eq!(
1959            mapped.pop(),
1960            Some(("quux".into(), "baz".into(), "quux".into(), 3))
1961        );
1962        let result = conn
1963            .exec_iter(&stmt, params! { "foo" => 2, "bar" => 3 })
1964            .await?;
1965        let collected = result.collect_and_drop::<(u8, u8, u8, u8)>().await?;
1966        assert_eq!(collected, vec![(2, 3, 2, 3)]);
1967        let result = conn
1968            .exec_iter(&stmt, params! { "foo" => 2, "bar" => 3 })
1969            .await?;
1970        let reduced = result
1971            .reduce_and_drop(0, |acc, row| {
1972                let (a, b, c, d): (u8, u8, u8, u8) = from_row(row);
1973                acc + a + b + c + d
1974            })
1975            .await?;
1976        conn.close(stmt).await?;
1977        conn.disconnect().await?;
1978        assert_eq!(reduced, 10);
1979        Ok(())
1980    }
1981
1982    #[tokio::test]
1983    async fn should_prep_exec_statement() -> super::Result<()> {
1984        let mut conn = Conn::new(get_opts()).await?;
1985        let result = conn
1986            .exec_iter(r"SELECT :a, :b, :a", params! { "a" => 2, "b" => 3 })
1987            .await?;
1988        let output = result
1989            .map_and_drop(|row| {
1990                let (a, b, c): (u8, u8, u8) = from_row(row);
1991                a * b * c
1992            })
1993            .await?;
1994        conn.disconnect().await?;
1995        assert_eq!(output[0], 12u8);
1996        Ok(())
1997    }
1998
1999    #[tokio::test]
2000    async fn should_first_exec_statement() -> super::Result<()> {
2001        let mut conn = Conn::new(get_opts()).await?;
2002        let output = conn
2003            .exec_first(
2004                r"SELECT :a UNION ALL SELECT :b",
2005                params! { "a" => 2, "b" => 3 },
2006            )
2007            .await?;
2008        conn.disconnect().await?;
2009        assert_eq!(output, Some(2u8));
2010        Ok(())
2011    }
2012
2013    #[tokio::test]
2014    async fn issue_107() -> super::Result<()> {
2015        let mut conn = Conn::new(get_opts()).await?;
2016        conn.query_drop(
2017            r"CREATE TEMPORARY TABLE mysql.issue (
2018                    a BIGINT(20) UNSIGNED,
2019                    b VARBINARY(16),
2020                    c BINARY(32),
2021                    d BIGINT(20) UNSIGNED,
2022                    e BINARY(32)
2023                )",
2024        )
2025        .await?;
2026        conn.query_drop(
2027            r"INSERT INTO mysql.issue VALUES (
2028                    0,
2029                    0xC066F966B0860000,
2030                    0x7939DA98E524C5F969FC2DE8D905FD9501EBC6F20001B0A9C941E0BE6D50CF44,
2031                    0,
2032                    ''
2033                ), (
2034                    1,
2035                    '',
2036                    0x076311DF4D407B0854371BA13A5F3FB1A4555AC22B361375FD47B263F31822F2,
2037                    0,
2038                    ''
2039                )",
2040        )
2041        .await?;
2042
2043        let q = "SELECT b, c, d, e FROM mysql.issue";
2044        let result = conn.query_iter(q).await?;
2045
2046        let loaded_structs = result
2047            .map_and_drop(|row| crate::from_row::<(Vec<u8>, Vec<u8>, u64, Vec<u8>)>(row))
2048            .await?;
2049
2050        conn.disconnect().await?;
2051
2052        assert_eq!(loaded_structs.len(), 2);
2053
2054        Ok(())
2055    }
2056
2057    #[tokio::test]
2058    async fn should_run_transactions() -> super::Result<()> {
2059        let mut conn = Conn::new(get_opts()).await?;
2060        conn.query_drop("CREATE TEMPORARY TABLE tmp (id INT, name TEXT)")
2061            .await?;
2062        let mut transaction = conn.start_transaction(Default::default()).await?;
2063        transaction
2064            .query_drop("INSERT INTO tmp VALUES (1, 'foo'), (2, 'bar')")
2065            .await?;
2066        assert_eq!(transaction.last_insert_id(), None);
2067        assert_eq!(transaction.affected_rows(), 2);
2068        assert_eq!(transaction.get_warnings(), 0);
2069        assert_eq!(transaction.info(), "Records: 2  Duplicates: 0  Warnings: 0");
2070        transaction.commit().await?;
2071        let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
2072        assert_eq!(output_opt, Some((2u8,)));
2073        let mut transaction = conn.start_transaction(Default::default()).await?;
2074        transaction
2075            .query_drop("INSERT INTO tmp VALUES (3, 'baz'), (4, 'quux')")
2076            .await?;
2077        let output_opt = transaction
2078            .exec_first("SELECT COUNT(*) FROM tmp", ())
2079            .await?;
2080        assert_eq!(output_opt, Some((4u8,)));
2081        transaction.rollback().await?;
2082        let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
2083        assert_eq!(output_opt, Some((2u8,)));
2084
2085        let mut transaction = conn.start_transaction(Default::default()).await?;
2086        transaction
2087            .query_drop("INSERT INTO tmp VALUES (3, 'baz')")
2088            .await?;
2089        drop(transaction); // implicit rollback
2090        let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
2091        assert_eq!(output_opt, Some((2u8,)));
2092
2093        conn.disconnect().await?;
2094        Ok(())
2095    }
2096
2097    #[tokio::test]
2098    async fn should_handle_multiresult_set_with_error() -> super::Result<()> {
2099        const QUERY_FIRST: &str = "SELECT * FROM tmp; SELECT 1; SELECT 2;";
2100        const QUERY_MIDDLE: &str = "SELECT 1; SELECT * FROM tmp; SELECT 2";
2101        let mut conn = Conn::new(get_opts()).await.unwrap();
2102
2103        // if error is in the first result set, then query should return it immediately.
2104        let result = QUERY_FIRST.run(&mut conn).await;
2105        assert!(matches!(result, Err(Error::Server(_))));
2106
2107        let mut result = QUERY_MIDDLE.run(&mut conn).await.unwrap();
2108
2109        // first result set will contain one row
2110        let result_set: Vec<u8> = result.collect().await.unwrap();
2111        assert_eq!(result_set, vec![1]);
2112
2113        // second result set will contain an error.
2114        let result_set: super::Result<Vec<u8>> = result.collect().await;
2115        assert!(matches!(result_set, Err(Error::Server(_))));
2116
2117        // there will be no third result set
2118        assert!(result.is_empty());
2119
2120        conn.ping().await?;
2121        conn.disconnect().await?;
2122
2123        Ok(())
2124    }
2125
2126    #[tokio::test]
2127    async fn should_handle_binary_multiresult_set_with_error() -> super::Result<()> {
2128        const PROC_DEF_FIRST: &str =
2129            r#"CREATE PROCEDURE err_first() BEGIN SELECT * FROM tmp; SELECT 1; END"#;
2130        const PROC_DEF_MIDDLE: &str =
2131            r#"CREATE PROCEDURE err_middle() BEGIN SELECT 1; SELECT * FROM tmp; SELECT 2; END"#;
2132
2133        let mut conn = Conn::new(get_opts()).await.unwrap();
2134
2135        conn.query_drop("DROP PROCEDURE IF EXISTS err_first")
2136            .await?;
2137        conn.query_iter(PROC_DEF_FIRST).await?;
2138
2139        conn.query_drop("DROP PROCEDURE IF EXISTS err_middle")
2140            .await?;
2141        conn.query_iter(PROC_DEF_MIDDLE).await?;
2142
2143        // if error is in the first result set, then query should return it immediately.
2144        let result = conn.query_iter("CALL err_first()").await;
2145        assert!(matches!(result, Err(Error::Server(_))));
2146
2147        let mut result = conn.query_iter("CALL err_middle()").await?;
2148
2149        // first result set will contain one row
2150        let result_set: Vec<u8> = result.collect().await.unwrap();
2151        assert_eq!(result_set, vec![1]);
2152
2153        // second result set will contain an error.
2154        let result_set: super::Result<Vec<u8>> = result.collect().await;
2155        assert!(matches!(result_set, Err(Error::Server(_))));
2156
2157        // there will be no third result set
2158        assert!(result.is_empty());
2159
2160        conn.ping().await?;
2161        conn.disconnect().await?;
2162
2163        Ok(())
2164    }
2165
2166    #[tokio::test]
2167    async fn should_handle_multiresult_set_with_local_infile() -> super::Result<()> {
2168        use std::fs::write;
2169
2170        let file_path = tempfile::Builder::new().tempfile_in("").unwrap();
2171        let file_path = file_path.path();
2172        let file_name = file_path.file_name().unwrap();
2173
2174        write(file_name, b"AAAAAA\nBBBBBB\nCCCCCC\n")?;
2175
2176        let opts = get_opts().local_infile_handler(Some(WhiteListFsHandler::new(&[file_name][..])));
2177
2178        // LOCAL INFILE in the middle of a multi-result set should not break anything.
2179        let mut conn = Conn::new(opts).await.unwrap();
2180        "CREATE TEMPORARY TABLE tmp (a TEXT)".run(&mut conn).await?;
2181
2182        let query = format!(
2183            r#"SELECT * FROM tmp;
2184            LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;
2185            LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;
2186            SELECT * FROM tmp"#,
2187            file_name.to_str().unwrap(),
2188            file_name.to_str().unwrap(),
2189        );
2190
2191        let mut result = query.run(&mut conn).await?;
2192
2193        let result_set = result.collect::<String>().await?;
2194        assert_eq!(result_set.len(), 0);
2195
2196        let mut no_local_infile = false;
2197
2198        for _ in 0..2 {
2199            match result.collect::<String>().await {
2200                Ok(result_set) => {
2201                    assert_eq!(result.affected_rows(), 3);
2202                    assert!(result_set.is_empty())
2203                }
2204                Err(Error::Server(ref err)) if err.code == 1148 => {
2205                    // The used command is not allowed with this MySQL version
2206                    no_local_infile = true;
2207                    break;
2208                }
2209                Err(Error::Server(ref err)) if err.code == 3948 => {
2210                    // Loading local data is disabled;
2211                    // this must be enabled on both the client and server sides
2212                    no_local_infile = true;
2213                    break;
2214                }
2215                Err(err) => return Err(err),
2216            }
2217        }
2218
2219        if no_local_infile {
2220            assert!(result.is_empty());
2221            assert_eq!(result_set.len(), 0);
2222        } else {
2223            let result_set = result.collect::<String>().await?;
2224            assert_eq!(result_set.len(), 6);
2225            assert_eq!(result_set[0], "AAAAAA");
2226            assert_eq!(result_set[1], "BBBBBB");
2227            assert_eq!(result_set[2], "CCCCCC");
2228            assert_eq!(result_set[3], "AAAAAA");
2229            assert_eq!(result_set[4], "BBBBBB");
2230            assert_eq!(result_set[5], "CCCCCC");
2231        }
2232
2233        conn.ping().await?;
2234        conn.disconnect().await?;
2235
2236        Ok(())
2237    }
2238
2239    #[tokio::test]
2240    async fn should_provide_multiresult_set_metadata() -> super::Result<()> {
2241        let mut c = Conn::new(get_opts()).await?;
2242        c.query_drop("CREATE TEMPORARY TABLE tmp (id INT, foo TEXT)")
2243            .await?;
2244
2245        let mut result = c
2246            .query_iter("SELECT 1; SELECT id, foo FROM tmp WHERE 1 = 2; DO 42; SELECT 2;")
2247            .await?;
2248        assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 1);
2249
2250        result.for_each(drop).await?;
2251        assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 2);
2252
2253        result.for_each(drop).await?;
2254        assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 0);
2255
2256        result.for_each(drop).await?;
2257        assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 1);
2258
2259        c.disconnect().await?;
2260        Ok(())
2261    }
2262
2263    #[tokio::test]
2264    async fn should_expose_query_result_metadata() -> super::Result<()> {
2265        let pool = Pool::new(get_opts());
2266        let mut c = pool.get_conn().await?;
2267
2268        c.query_drop(
2269            r"
2270            CREATE TEMPORARY TABLE `foo`
2271                ( `id` SERIAL
2272                , `bar_id` varchar(36) NOT NULL
2273                , `baz_id` varchar(36) NOT NULL
2274                , `ctime` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP()
2275                , PRIMARY KEY (`id`)
2276                , KEY `bar_idx` (`bar_id`)
2277                , KEY `baz_idx` (`baz_id`)
2278            );",
2279        )
2280        .await?;
2281
2282        const QUERY: &str = "INSERT INTO foo (bar_id, baz_id) VALUES (?, ?)";
2283        let params = ("qwerty", "data.employee_id");
2284
2285        let query_result = c.exec_iter(QUERY, params).await?;
2286        assert_eq!(query_result.last_insert_id(), Some(1));
2287        query_result.drop_result().await?;
2288
2289        c.exec_drop(QUERY, params).await?;
2290        assert_eq!(c.last_insert_id(), Some(2));
2291
2292        let mut tx = c.start_transaction(Default::default()).await?;
2293
2294        tx.exec_drop(QUERY, params).await?;
2295        assert_eq!(tx.last_insert_id(), Some(3));
2296
2297        Ok(())
2298    }
2299
2300    #[tokio::test]
2301    async fn should_handle_local_infile_locally() -> super::Result<()> {
2302        let mut conn = Conn::new(get_opts()).await.unwrap();
2303        conn.query_drop("CREATE TEMPORARY TABLE tmp (a TEXT);")
2304            .await
2305            .unwrap();
2306
2307        conn.set_infile_handler(async move {
2308            Ok(
2309                stream::iter([Bytes::from("AAAAAA\n"), Bytes::from("BBBBBB\nCCCCCC\n")])
2310                    .map(Ok)
2311                    .boxed(),
2312            )
2313        });
2314
2315        match conn
2316            .query_drop(r#"LOAD DATA LOCAL INFILE "dummy" INTO TABLE tmp;"#)
2317            .await
2318        {
2319            Ok(_) => (),
2320            Err(super::Error::Server(ref err)) if err.code == 1148 => {
2321                // The used command is not allowed with this MySQL version
2322                return Ok(());
2323            }
2324            Err(super::Error::Server(ref err)) if err.code == 3948 => {
2325                // Loading local data is disabled;
2326                // this must be enabled on both the client and server sides
2327                return Ok(());
2328            }
2329            e @ Err(_) => e.unwrap(),
2330        };
2331
2332        let result: Vec<String> = conn.query("SELECT * FROM tmp").await?;
2333        assert_eq!(result.len(), 3);
2334        assert_eq!(result[0], "AAAAAA");
2335        assert_eq!(result[1], "BBBBBB");
2336        assert_eq!(result[2], "CCCCCC");
2337
2338        Ok(())
2339    }
2340
2341    #[tokio::test]
2342    async fn should_handle_local_infile_globally() -> super::Result<()> {
2343        use std::fs::write;
2344
2345        let file_path = tempfile::Builder::new().tempfile_in("").unwrap();
2346        let file_path = file_path.path();
2347        let file_name = file_path.file_name().unwrap();
2348
2349        write(file_name, b"AAAAAA\nBBBBBB\nCCCCCC\n")?;
2350
2351        let opts = get_opts().local_infile_handler(Some(WhiteListFsHandler::new(&[file_name][..])));
2352
2353        let mut conn = Conn::new(opts).await.unwrap();
2354        conn.query_drop("CREATE TEMPORARY TABLE tmp (a TEXT);")
2355            .await
2356            .unwrap();
2357
2358        match conn
2359            .query_drop(format!(
2360                r#"LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;"#,
2361                file_name.to_str().unwrap(),
2362            ))
2363            .await
2364        {
2365            Ok(_) => (),
2366            Err(super::Error::Server(ref err)) if err.code == 1148 => {
2367                // The used command is not allowed with this MySQL version
2368                return Ok(());
2369            }
2370            Err(super::Error::Server(ref err)) if err.code == 3948 => {
2371                // Loading local data is disabled;
2372                // this must be enabled on both the client and server sides
2373                return Ok(());
2374            }
2375            e @ Err(_) => e.unwrap(),
2376        };
2377
2378        let result: Vec<String> = conn.query("SELECT * FROM tmp").await?;
2379        assert_eq!(result.len(), 3);
2380        assert_eq!(result[0], "AAAAAA");
2381        assert_eq!(result[1], "BBBBBB");
2382        assert_eq!(result[2], "CCCCCC");
2383
2384        Ok(())
2385    }
2386
2387    #[cfg(feature = "nightly")]
2388    mod bench {
2389        use crate::{conn::Conn, queryable::Queryable, test_misc::get_opts};
2390
2391        #[bench]
2392        fn simple_exec(bencher: &mut test::Bencher) {
2393            let mut runtime = tokio::runtime::Runtime::new().unwrap();
2394            let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2395
2396            bencher.iter(|| {
2397                runtime.block_on(conn.query_drop("DO 1")).unwrap();
2398            });
2399
2400            runtime.block_on(conn.disconnect()).unwrap();
2401        }
2402
2403        #[bench]
2404        fn select_large_string(bencher: &mut test::Bencher) {
2405            let mut runtime = tokio::runtime::Runtime::new().unwrap();
2406            let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2407
2408            bencher.iter(|| {
2409                runtime
2410                    .block_on(conn.query_drop("SELECT REPEAT('A', 10000)"))
2411                    .unwrap();
2412            });
2413
2414            runtime.block_on(conn.disconnect()).unwrap();
2415        }
2416
2417        #[bench]
2418        fn prepared_exec(bencher: &mut test::Bencher) {
2419            let mut runtime = tokio::runtime::Runtime::new().unwrap();
2420            let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2421            let stmt = runtime.block_on(conn.prep("DO 1")).unwrap();
2422
2423            bencher.iter(|| {
2424                runtime.block_on(conn.exec_drop(&stmt, ())).unwrap();
2425            });
2426
2427            runtime.block_on(conn.close(stmt)).unwrap();
2428            runtime.block_on(conn.disconnect()).unwrap();
2429        }
2430
2431        #[bench]
2432        fn prepare_and_exec(bencher: &mut test::Bencher) {
2433            let mut runtime = tokio::runtime::Runtime::new().unwrap();
2434            let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2435
2436            bencher.iter(|| {
2437                runtime.block_on(conn.exec_drop("SELECT ?", (0,))).unwrap();
2438            });
2439
2440            runtime.block_on(conn.disconnect()).unwrap();
2441        }
2442    }
2443}