mysql/conn/
mod.rs

1// Copyright (c) 2020 rust-mysql-simple contributors
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8
9use bytes::{Buf, BufMut};
10use mysql_common::{
11    constants::UTF8MB4_GENERAL_CI,
12    crypto,
13    io::{ParseBuf, ReadMysqlExt},
14    misc::raw::Either,
15    named_params::parse_named_params,
16    packets::{
17        binlog_request::BinlogRequest, AuthPlugin, AuthSwitchRequest, Column, ComStmtClose,
18        ComStmtExecuteRequestBuilder, ComStmtSendLongData, CommonOkPacket, ErrPacket,
19        HandshakePacket, HandshakeResponse, OkPacket, OkPacketDeserializer, OkPacketKind,
20        OldAuthSwitchRequest, OldEofPacket, ResultSetTerminator, SessionStateInfo,
21    },
22    proto::{codec::Compression, sync_framed::MySyncFramed, MySerialize},
23};
24
25use mysql_common::{
26    constants::{DEFAULT_MAX_ALLOWED_PACKET, UTF8_GENERAL_CI},
27    packets::SslRequest,
28};
29
30#[cfg(not(target_os = "wasi"))]
31use std::process;
32use std::{
33    borrow::{Borrow, Cow},
34    cmp,
35    collections::HashMap,
36    convert::TryFrom,
37    io::{self, Write as _},
38    mem,
39    ops::{Deref, DerefMut},
40    sync::Arc,
41};
42
43#[cfg(unix)]
44use std::os::unix::io::{AsRawFd, RawFd};
45
46use crate::{
47    buffer_pool::{get_buffer, Buffer},
48    conn::{
49        local_infile::LocalInfile,
50        pool::{Pool, PooledConn},
51        query_result::{Binary, Or, Text},
52        stmt::{InnerStmt, Statement},
53        stmt_cache::StmtCache,
54        transaction::{AccessMode, TxOpts},
55    },
56    consts::{CapabilityFlags, Command, StatusFlags, MAX_PAYLOAD_LEN},
57    from_value, from_value_opt,
58    io::Stream,
59    prelude::*,
60    DriverError::{
61        MismatchedStmtParams, NamedParamsForPositionalQuery, OldMysqlPasswordDisabled,
62        Protocol41NotSet, ReadOnlyTransNotSupported, SetupError, UnexpectedPacket,
63        UnknownAuthPlugin, UnsupportedProtocol,
64    },
65    Error::{self, DriverError, MySqlError},
66    LocalInfileHandler, Opts, OptsBuilder, Params, QueryResult, Result, Transaction,
67    Value::{self, Bytes, NULL},
68};
69
70use crate::DriverError::TlsNotSupported;
71use crate::SslOpts;
72
73use self::binlog_stream::BinlogStream;
74
75pub mod binlog_stream;
76pub mod local_infile;
77pub mod opts;
78pub mod pool;
79pub mod query;
80pub mod query_result;
81pub mod queryable;
82pub mod stmt;
83mod stmt_cache;
84pub mod transaction;
85
/// Mutable connection.
///
/// Abstracts over the different ways mutable access to a [`Conn`] can be handed out:
/// a plain mutable borrow, a borrowed transaction, an owned connection, or a
/// connection checked out of a pool. Every variant dereferences to [`Conn`].
#[derive(Debug)]
pub enum ConnMut<'c, 't, 'tc> {
    /// Mutable borrow of a connection.
    Mut(&'c mut Conn),
    /// Mutable borrow of a transaction; the transaction's connection is used.
    TxMut(&'t mut Transaction<'tc>),
    /// Owned connection.
    Owned(Conn),
    /// Owned connection checked out of a pool.
    Pooled(PooledConn),
}
94
95impl From<Conn> for ConnMut<'static, 'static, 'static> {
96    fn from(conn: Conn) -> Self {
97        ConnMut::Owned(conn)
98    }
99}
100
101impl From<PooledConn> for ConnMut<'static, 'static, 'static> {
102    fn from(conn: PooledConn) -> Self {
103        ConnMut::Pooled(conn)
104    }
105}
106
107impl<'a> From<&'a mut Conn> for ConnMut<'a, 'static, 'static> {
108    fn from(conn: &'a mut Conn) -> Self {
109        ConnMut::Mut(conn)
110    }
111}
112
113impl<'a> From<&'a mut PooledConn> for ConnMut<'a, 'static, 'static> {
114    fn from(conn: &'a mut PooledConn) -> Self {
115        ConnMut::Mut(conn.as_mut())
116    }
117}
118
119impl<'t, 'tc> From<&'t mut Transaction<'tc>> for ConnMut<'static, 't, 'tc> {
120    fn from(tx: &'t mut Transaction<'tc>) -> Self {
121        ConnMut::TxMut(tx)
122    }
123}
124
125impl TryFrom<&Pool> for ConnMut<'static, 'static, 'static> {
126    type Error = Error;
127
128    fn try_from(pool: &Pool) -> Result<Self> {
129        pool.get_conn().map(From::from)
130    }
131}
132
133impl Deref for ConnMut<'_, '_, '_> {
134    type Target = Conn;
135
136    fn deref(&self) -> &Conn {
137        match self {
138            ConnMut::Mut(conn) => &**conn,
139            ConnMut::TxMut(tx) => &*tx.conn,
140            ConnMut::Owned(conn) => &conn,
141            ConnMut::Pooled(conn) => conn.as_ref(),
142        }
143    }
144}
145
146impl DerefMut for ConnMut<'_, '_, '_> {
147    fn deref_mut(&mut self) -> &mut Conn {
148        match self {
149            ConnMut::Mut(ref mut conn) => &mut **conn,
150            ConnMut::TxMut(tx) => &mut *tx.conn,
151            ConnMut::Owned(ref mut conn) => conn,
152            ConnMut::Pooled(ref mut conn) => conn.as_mut(),
153        }
154    }
155}
156
/// Connection internals.
///
/// Boxed inside [`Conn`]; holds the options, the framed transport and all state
/// negotiated with (or reported by) the server.
#[derive(Debug)]
struct ConnInner {
    /// Options this connection was created with.
    opts: Opts,
    /// Framed transport; `None` until a stream has been opened.
    stream: Option<MySyncFramed<Stream>>,
    /// Prepared statement cache, sized from the options.
    stmt_cache: StmtCache,

    // TODO: clean this up
    /// MySQL version parsed from the server handshake, if any.
    server_version: Option<(u16, u16, u16)>,
    /// MariaDB version parsed from the server handshake, if any.
    mariadb_server_version: Option<(u16, u16, u16)>,

    /// Last Ok packet, if any.
    ok_packet: Option<OkPacket<'static>>,
    /// Capabilities negotiated during the handshake (server flags & client flags).
    capability_flags: CapabilityFlags,
    /// Connection id reported in the handshake.
    connection_id: u32,
    /// Status flags from the last OK packet; cleared on errors.
    status_flags: StatusFlags,
    /// Default collation reported in the handshake.
    character_set: u8,
    /// First byte of the last command written to the server.
    last_command: u8,
    // NOTE(review): connection-state flag; only reset to `false` in the visible code — confirm where it is set.
    connected: bool,
    /// Cleared on errors; presumably tracks whether a result set is pending — confirm.
    has_results: bool,
    /// Optional user handler for `LOAD DATA LOCAL INFILE` requests (presumably; confirm in local_infile).
    local_infile_handler: Option<LocalInfileHandler>,
}
179
180impl ConnInner {
181    fn empty(opts: Opts) -> Self {
182        ConnInner {
183            stmt_cache: StmtCache::new(opts.get_stmt_cache_size()),
184            opts,
185            stream: None,
186            capability_flags: CapabilityFlags::empty(),
187            status_flags: StatusFlags::empty(),
188            connection_id: 0u32,
189            character_set: 0u8,
190            ok_packet: None,
191            last_command: 0u8,
192            connected: false,
193            has_results: false,
194            server_version: None,
195            mariadb_server_version: None,
196            local_infile_handler: None,
197        }
198    }
199}
200
/// Mysql connection.
///
/// All state lives in a boxed `ConnInner`, so moving a `Conn` only moves a pointer.
#[derive(Debug)]
pub struct Conn(Box<ConnInner>);
204
205impl Conn {
206    /// Must not be called before handle_handshake.
207    const fn has_capability(&self, flag: CapabilityFlags) -> bool {
208        self.0.capability_flags.contains(flag)
209    }
210
211    /// Returns version number reported by the server.
212    pub fn server_version(&self) -> (u16, u16, u16) {
213        self.0
214            .server_version
215            .or_else(|| self.0.mariadb_server_version)
216            .unwrap()
217    }
218
219    /// Returns connection identifier.
220    pub fn connection_id(&self) -> u32 {
221        self.0.connection_id
222    }
223
224    /// Returns number of rows affected by the last query.
225    pub fn affected_rows(&self) -> u64 {
226        self.0
227            .ok_packet
228            .as_ref()
229            .map(OkPacket::affected_rows)
230            .unwrap_or_default()
231    }
232
233    /// Returns last insert id of the last query.
234    ///
235    /// Returns zero if there was no last insert id.
236    pub fn last_insert_id(&self) -> u64 {
237        self.0
238            .ok_packet
239            .as_ref()
240            .and_then(OkPacket::last_insert_id)
241            .unwrap_or_default()
242    }
243
244    /// Returns number of warnings, reported by the server.
245    pub fn warnings(&self) -> u16 {
246        self.0
247            .ok_packet
248            .as_ref()
249            .map(OkPacket::warnings)
250            .unwrap_or_default()
251    }
252
253    /// [Info], reported by the server.
254    ///
255    /// Will be empty if not defined.
256    ///
257    /// [Info]: http://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
258    pub fn info_ref(&self) -> &[u8] {
259        self.0
260            .ok_packet
261            .as_ref()
262            .and_then(OkPacket::info_ref)
263            .unwrap_or_default()
264    }
265
266    /// [Info], reported by the server.
267    ///
268    /// Will be empty if not defined.
269    ///
270    /// [Info]: http://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
271    pub fn info_str(&self) -> Cow<str> {
272        self.0
273            .ok_packet
274            .as_ref()
275            .and_then(OkPacket::info_str)
276            .unwrap_or_default()
277    }
278
279    pub fn session_state_changes(&self) -> io::Result<Vec<SessionStateInfo<'_>>> {
280        self.0
281            .ok_packet
282            .as_ref()
283            .map(|ok| ok.session_state_info())
284            .transpose()
285            .map(Option::unwrap_or_default)
286    }
287
288    fn stream_ref(&self) -> &MySyncFramed<Stream> {
289        self.0.stream.as_ref().expect("incomplete connection")
290    }
291
292    fn stream_mut(&mut self) -> &mut MySyncFramed<Stream> {
293        self.0.stream.as_mut().expect("incomplete connection")
294    }
295
296    fn is_insecure(&self) -> bool {
297        self.stream_ref().get_ref().is_insecure()
298    }
299
300    fn is_socket(&self) -> bool {
301        self.stream_ref().get_ref().is_socket()
302    }
303
304    /// Check the connection can be improved.
305    #[allow(unused_assignments)]
306    fn can_improved(&mut self) -> Result<Option<Opts>> {
307        if self.0.opts.get_prefer_socket() && self.0.opts.addr_is_loopback() {
308            let mut socket = None;
309            #[cfg(test)]
310            {
311                socket = self.0.opts.0.injected_socket.clone();
312            }
313            if socket.is_none() {
314                socket = self.get_system_var("socket")?.map(from_value::<String>);
315            }
316            if let Some(socket) = socket {
317                if self.0.opts.get_socket().is_none() {
318                    let socket_opts = OptsBuilder::from_opts(self.0.opts.clone());
319                    if !socket.is_empty() {
320                        return Ok(Some(socket_opts.socket(Some(socket)).into()));
321                    }
322                }
323            }
324        }
325        Ok(None)
326    }
327
328    /// Creates new `Conn`.
329    pub fn new<T, E>(opts: T) -> Result<Conn>
330    where
331        Opts: TryFrom<T, Error = E>,
332        crate::Error: From<E>,
333    {
334        let opts = Opts::try_from(opts)?;
335        let mut conn = Conn(Box::new(ConnInner::empty(opts)));
336        conn.connect_stream()?;
337        conn.connect()?;
338        let mut conn = {
339            if let Some(new_opts) = conn.can_improved()? {
340                let mut improved_conn = Conn(Box::new(ConnInner::empty(new_opts)));
341                improved_conn
342                    .connect_stream()
343                    .and_then(|_| {
344                        improved_conn.connect()?;
345                        Ok(improved_conn)
346                    })
347                    .unwrap_or(conn)
348            } else {
349                conn
350            }
351        };
352        for cmd in conn.0.opts.get_init() {
353            conn.query_drop(cmd)?;
354        }
355        Ok(conn)
356    }
357
358    fn soft_reset(&mut self) -> Result<()> {
359        self.write_command(Command::COM_RESET_CONNECTION, &[])?;
360        let packet = self.read_packet()?;
361        self.handle_ok::<CommonOkPacket>(&packet)?;
362        self.0.last_command = 0;
363        self.0.stmt_cache.clear();
364        Ok(())
365    }
366
367    fn hard_reset(&mut self) -> Result<()> {
368        self.0.stmt_cache.clear();
369        self.0.capability_flags = CapabilityFlags::empty();
370        self.0.status_flags = StatusFlags::empty();
371        self.0.connection_id = 0;
372        self.0.character_set = 0;
373        self.0.ok_packet = None;
374        self.0.last_command = 0;
375        self.0.connected = false;
376        self.0.has_results = false;
377        self.connect_stream()?;
378        self.connect()
379    }
380
381    /// Resets `MyConn` (drops state then reconnects).
382    pub fn reset(&mut self) -> Result<()> {
383        match (self.0.server_version, self.0.mariadb_server_version) {
384            (Some(ref version), _) if *version > (5, 7, 3) => {
385                self.soft_reset().or_else(|_| self.hard_reset())
386            }
387            (_, Some(ref version)) if *version >= (10, 2, 7) => {
388                self.soft_reset().or_else(|_| self.hard_reset())
389            }
390            _ => self.hard_reset(),
391        }
392    }
393
394    fn switch_to_ssl(&mut self, ssl_opts: SslOpts) -> Result<()> {
395        let stream = self.0.stream.take().expect("incomplete conn");
396        let (in_buf, out_buf, codec, stream) = stream.destruct();
397        let stream = stream.make_secure(self.0.opts.get_host(), ssl_opts)?;
398        let stream = MySyncFramed::construct(in_buf, out_buf, codec, stream);
399        self.0.stream = Some(stream);
400        Ok(())
401    }
402
    /// Opens the transport described by the options and stores it, wrapped in a new
    /// `MySyncFramed`, as this connection's stream (replacing any previous one).
    ///
    /// Outside WASI a configured socket takes precedence; otherwise a TCP connection
    /// is made to the configured host and port. On WASI only TCP is attempted.
    ///
    /// # Errors
    ///
    /// Propagates errors from `Stream::connect_socket` / `Stream::connect_tcp`.
    fn connect_stream(&mut self) -> Result<()> {
        let opts = &self.0.opts;
        // Gather transport settings up front. The keepalive-probe and user-timeout
        // settings only exist on the platforms named in their `cfg` attributes.
        let read_timeout = opts.get_read_timeout().cloned();
        let write_timeout = opts.get_write_timeout().cloned();
        let tcp_keepalive_time = opts.get_tcp_keepalive_time_ms();
        #[cfg(any(target_os = "linux", target_os = "macos",))]
        let tcp_keepalive_probe_interval_secs = opts.get_tcp_keepalive_probe_interval_secs();
        #[cfg(any(target_os = "linux", target_os = "macos",))]
        let tcp_keepalive_probe_count = opts.get_tcp_keepalive_probe_count();
        #[cfg(target_os = "linux")]
        let tcp_user_timeout = opts.get_tcp_user_timeout_ms();
        let tcp_nodelay = opts.get_tcp_nodelay();
        let tcp_connect_timeout = opts.get_tcp_connect_timeout();
        let bind_address = opts.bind_address().cloned();
        #[cfg(not(target_os = "wasi"))]
        {
            // An explicitly configured socket wins over TCP.
            let stream = if let Some(socket) = opts.get_socket() {
                Stream::connect_socket(&*socket, read_timeout, write_timeout)?
            } else {
                let port = opts.get_tcp_port();
                // Render the host (domain or IP literal) as the string connect_tcp expects.
                let ip_or_hostname = match opts.get_host() {
                    url::Host::Domain(domain) => domain,
                    url::Host::Ipv4(ip) => ip.to_string(),
                    url::Host::Ipv6(ip) => ip.to_string(),
                };
                Stream::connect_tcp(
                    &*ip_or_hostname,
                    port,
                    read_timeout,
                    write_timeout,
                    tcp_keepalive_time,
                    #[cfg(any(target_os = "linux", target_os = "macos",))]
                    tcp_keepalive_probe_interval_secs,
                    #[cfg(any(target_os = "linux", target_os = "macos",))]
                    tcp_keepalive_probe_count,
                    #[cfg(target_os = "linux")]
                    tcp_user_timeout,
                    tcp_nodelay,
                    tcp_connect_timeout,
                    bind_address,
                )?
            };
            self.0.stream = Some(MySyncFramed::new(stream));
        }
        #[cfg(target_os = "wasi")]
        {
            // WASI build: no socket support here, always connect over TCP.
            let port = opts.get_tcp_port();
            let ip_or_hostname = match opts.get_host() {
                url::Host::Domain(domain) => domain,
                url::Host::Ipv4(ip) => ip.to_string(),
                url::Host::Ipv6(ip) => ip.to_string(),
            };
            let stream = Stream::connect_tcp(
                &*ip_or_hostname,
                port,
                read_timeout,
                write_timeout,
                tcp_keepalive_time,
                tcp_nodelay,
                tcp_connect_timeout,
                bind_address,
            )?;
            self.0.stream = Some(MySyncFramed::new(stream));
        }
        Ok(())
    }
469
470    fn raw_read_packet(&mut self, buffer: &mut Vec<u8>) -> Result<()> {
471        if !self.stream_mut().next_packet(buffer)? {
472            Err(Error::server_disconnected())
473        } else {
474            Ok(())
475        }
476    }
477
478    fn read_packet(&mut self) -> Result<Buffer> {
479        loop {
480            let mut buffer = get_buffer();
481            match self.raw_read_packet(buffer.as_mut()) {
482                Ok(()) if buffer.first() == Some(&0xff) => {
483                    match ParseBuf(&*buffer).parse(self.0.capability_flags)? {
484                        ErrPacket::Error(server_error) => {
485                            self.handle_err();
486                            return Err(MySqlError(From::from(server_error)));
487                        }
488                        ErrPacket::Progress(_progress_report) => {
489                            // TODO: Report progress
490                            continue;
491                        }
492                    }
493                }
494                Ok(()) => return Ok(buffer),
495                Err(e) => {
496                    self.handle_err();
497                    return Err(e);
498                }
499            }
500        }
501    }
502
503    fn drop_packet(&mut self) -> Result<()> {
504        self.read_packet().map(drop)
505    }
506
507    fn write_struct<T: MySerialize>(&mut self, s: &T) -> Result<()> {
508        let mut buf = get_buffer();
509        s.serialize(buf.as_mut());
510        self.write_packet(&mut &*buf)
511    }
512
513    fn write_packet<T: Buf>(&mut self, data: &mut T) -> Result<()> {
514        self.stream_mut().send(data)?;
515        Ok(())
516    }
517
518    fn handle_handshake(&mut self, hp: &HandshakePacket<'_>) {
519        self.0.capability_flags = hp.capabilities() & self.get_client_flags();
520        self.0.status_flags = hp.status_flags();
521        self.0.connection_id = hp.connection_id();
522        self.0.character_set = hp.default_collation();
523        self.0.server_version = hp.server_version_parsed();
524        self.0.mariadb_server_version = hp.maria_db_server_version_parsed();
525    }
526
527    fn handle_ok<'a, T: OkPacketKind>(
528        &mut self,
529        buffer: &'a Buffer,
530    ) -> crate::Result<OkPacket<'a>> {
531        let ok = ParseBuf(&**buffer)
532            .parse::<OkPacketDeserializer<T>>(self.0.capability_flags)?
533            .into_inner();
534        self.0.status_flags = ok.status_flags();
535        self.0.ok_packet = Some(ok.clone().into_owned());
536        Ok(ok)
537    }
538
539    fn handle_err(&mut self) {
540        self.0.status_flags = StatusFlags::empty();
541        self.0.has_results = false;
542        self.0.ok_packet = None;
543    }
544
545    fn more_results_exists(&self) -> bool {
546        self.0
547            .status_flags
548            .contains(StatusFlags::SERVER_MORE_RESULTS_EXISTS)
549    }
550
551    fn perform_auth_switch(&mut self, auth_switch_request: AuthSwitchRequest<'_>) -> Result<()> {
552        if matches!(
553            auth_switch_request.auth_plugin(),
554            AuthPlugin::MysqlOldPassword
555        ) {
556            if self.0.opts.get_secure_auth() {
557                return Err(DriverError(OldMysqlPasswordDisabled));
558            }
559        }
560
561        let nonce = auth_switch_request.plugin_data();
562        let plugin_data = auth_switch_request
563            .auth_plugin()
564            .gen_data(self.0.opts.get_pass(), nonce)
565            .map(Either::Left)
566            .unwrap_or_else(|| Either::Right([]));
567        self.write_struct(&plugin_data)?;
568        self.continue_auth(&auth_switch_request.auth_plugin(), nonce, true)
569    }
570
    /// Performs the initial handshake.
    ///
    /// Steps: read the server greeting, optionally upgrade to TLS, send the handshake
    /// response, run the authentication exchange, and enable compression if negotiated.
    ///
    /// # Errors
    ///
    /// - `UnsupportedProtocol` if the protocol version is not 10.
    /// - `Protocol41NotSet` if the server lacks `CLIENT_PROTOCOL_41`.
    /// - `TlsNotSupported` if TLS options were given on an insecure stream but
    ///   `CLIENT_SSL` was not negotiated.
    /// - `UnknownAuthPlugin` for an unsupported initial auth plugin.
    /// - I/O, TLS and authentication errors are propagated.
    fn do_handshake(&mut self) -> Result<()> {
        let payload = self.read_packet()?;
        let handshake = ParseBuf(&*payload).parse::<HandshakePacket>(())?;

        if handshake.protocol_version() != 10u8 {
            return Err(DriverError(UnsupportedProtocol(
                handshake.protocol_version(),
            )));
        }

        if !handshake
            .capabilities()
            .contains(CapabilityFlags::CLIENT_PROTOCOL_41)
        {
            return Err(DriverError(Protocol41NotSet));
        }

        // Record negotiated capabilities and server info; the TLS decision below
        // relies on the negotiated flags.
        self.handle_handshake(&handshake);

        // Upgrade to TLS only if the stream is insecure and TLS options were provided.
        if self.is_insecure() {
            if let Some(ssl_opts) = self.0.opts.get_ssl_opts().cloned() {
                if !self.has_capability(CapabilityFlags::CLIENT_SSL) {
                    return Err(DriverError(TlsNotSupported));
                } else {
                    self.do_ssl_request()?;
                    self.switch_to_ssl(ssl_opts)?;
                }
            }
        }

        // Handshake scramble is always 21 bytes length (20 + zero terminator)
        let nonce = {
            let mut nonce = Vec::from(handshake.scramble_1_ref());
            nonce.extend_from_slice(handshake.scramble_2_ref().unwrap_or(&[][..]));
            // Trim zero terminator. Fill with zeroes if nonce
            // is somehow smaller than 20 bytes (this matches the server behavior).
            nonce.resize(20, 0);
            nonce
        };

        // The server may omit the plugin name; mysql_native_password is assumed then.
        let auth_plugin = handshake
            .auth_plugin()
            .unwrap_or(AuthPlugin::MysqlNativePassword);
        if let AuthPlugin::Other(ref name) = auth_plugin {
            let plugin_name = String::from_utf8_lossy(name).into();
            return Err(DriverError(UnknownAuthPlugin(plugin_name)));
        }

        let auth_data = auth_plugin.gen_data(self.0.opts.get_pass(), &*nonce);
        self.write_handshake_response(&auth_plugin, auth_data.as_deref())?;
        self.continue_auth(&auth_plugin, &*nonce, false)?;

        // Compression starts only after authentication has completed.
        if self.has_capability(CapabilityFlags::CLIENT_COMPRESS) {
            self.switch_to_compressed();
        }

        Ok(())
    }
629
630    fn switch_to_compressed(&mut self) {
631        self.stream_mut()
632            .codec_mut()
633            .compress(Compression::default());
634    }
635
636    fn get_client_flags(&self) -> CapabilityFlags {
637        let mut client_flags = CapabilityFlags::CLIENT_PROTOCOL_41
638            | CapabilityFlags::CLIENT_SECURE_CONNECTION
639            | CapabilityFlags::CLIENT_LONG_PASSWORD
640            | CapabilityFlags::CLIENT_TRANSACTIONS
641            | CapabilityFlags::CLIENT_LOCAL_FILES
642            | CapabilityFlags::CLIENT_MULTI_STATEMENTS
643            | CapabilityFlags::CLIENT_MULTI_RESULTS
644            | CapabilityFlags::CLIENT_PS_MULTI_RESULTS
645            | CapabilityFlags::CLIENT_PLUGIN_AUTH
646            | CapabilityFlags::CLIENT_CONNECT_ATTRS
647            | (self.0.capability_flags & CapabilityFlags::CLIENT_LONG_FLAG);
648        if self.0.opts.get_compress().is_some() {
649            client_flags.insert(CapabilityFlags::CLIENT_COMPRESS);
650        }
651        if let Some(db_name) = self.0.opts.get_db_name() {
652            if !db_name.is_empty() {
653                client_flags.insert(CapabilityFlags::CLIENT_CONNECT_WITH_DB);
654            }
655        }
656        if self.is_insecure() && self.0.opts.get_ssl_opts().is_some() {
657            client_flags.insert(CapabilityFlags::CLIENT_SSL);
658        }
659        client_flags | self.0.opts.get_additional_capabilities()
660    }
661
662    fn connect_attrs(&self) -> HashMap<String, String> {
663        let program_name = match self.0.opts.get_connect_attrs().get("program_name") {
664            Some(program_name) => program_name.clone(),
665            None => {
666                let arg0 = std::env::args_os().next();
667                let arg0 = arg0.as_ref().map(|x| x.to_string_lossy());
668                arg0.unwrap_or_else(|| "".into()).to_owned().to_string()
669            }
670        };
671
672        let mut attrs = HashMap::new();
673
674        attrs.insert("_client_name".into(), "rust-mysql-simple".into());
675        attrs.insert("_client_version".into(), env!("CARGO_PKG_VERSION").into());
676        attrs.insert("_os".into(), env!("CARGO_CFG_TARGET_OS").into());
677        #[cfg(not(target_os = "wasi"))]
678        attrs.insert("_pid".into(), process::id().to_string());
679        #[cfg(target_os = "wasi")]
680        attrs.insert("_pid".into(), "66666".into());
681        attrs.insert("_platform".into(), env!("CARGO_CFG_TARGET_ARCH").into());
682        attrs.insert("program_name".into(), program_name);
683
684        for (name, value) in self.0.opts.get_connect_attrs().clone() {
685            attrs.insert(name, value);
686        }
687
688        attrs
689    }
690
691    fn do_ssl_request(&mut self) -> Result<()> {
692        let charset = if self.server_version() >= (5, 5, 3) {
693            UTF8MB4_GENERAL_CI
694        } else {
695            UTF8_GENERAL_CI
696        };
697
698        let ssl_request = SslRequest::new(
699            self.get_client_flags(),
700            DEFAULT_MAX_ALLOWED_PACKET as u32,
701            charset as u8,
702        );
703        self.write_struct(&ssl_request)
704    }
705
706    fn write_handshake_response(
707        &mut self,
708        auth_plugin: &AuthPlugin<'_>,
709        scramble_buf: Option<&[u8]>,
710    ) -> Result<()> {
711        let handshake_response = HandshakeResponse::new(
712            scramble_buf,
713            self.0.server_version.unwrap_or((0, 0, 0)),
714            self.0.opts.get_user().map(str::as_bytes),
715            self.0.opts.get_db_name().map(str::as_bytes),
716            Some(auth_plugin.clone()),
717            self.0.capability_flags,
718            Some(self.connect_attrs().clone()),
719        );
720
721        let mut buf = get_buffer();
722        handshake_response.serialize(buf.as_mut());
723        self.write_packet(&mut &*buf)
724    }
725
726    fn continue_auth(
727        &mut self,
728        auth_plugin: &AuthPlugin<'_>,
729        nonce: &[u8],
730        auth_switched: bool,
731    ) -> Result<()> {
732        match auth_plugin {
733            AuthPlugin::CachingSha2Password => {
734                self.continue_caching_sha2_password_auth(nonce, auth_switched)?;
735                Ok(())
736            }
737            AuthPlugin::MysqlNativePassword | AuthPlugin::MysqlOldPassword => {
738                self.continue_mysql_native_password_auth(nonce, auth_switched)?;
739                Ok(())
740            }
741            AuthPlugin::Other(ref name) => {
742                let plugin_name = String::from_utf8_lossy(name).into();
743                Err(DriverError(UnknownAuthPlugin(plugin_name)))
744            }
745        }
746    }
747
748    fn continue_mysql_native_password_auth(
749        &mut self,
750        nonce: &[u8],
751        auth_switched: bool,
752    ) -> Result<()> {
753        let payload = self.read_packet()?;
754
755        match payload[0] {
756            // auth ok
757            0x00 => self.handle_ok::<CommonOkPacket>(&payload).map(drop),
758            // auth switch
759            0xfe if !auth_switched => {
760                let auth_switch = if payload.len() > 1 {
761                    ParseBuf(&*payload).parse(())?
762                } else {
763                    let _ = ParseBuf(&*payload).parse::<OldAuthSwitchRequest>(())?;
764                    // we'll map OldAuthSwitchRequest to an AuthSwitchRequest with mysql_old_password plugin.
765                    AuthSwitchRequest::new("mysql_old_password".as_bytes(), nonce)
766                };
767                self.perform_auth_switch(auth_switch)
768            }
769            _ => Err(DriverError(UnexpectedPacket)),
770        }
771    }
772
773    fn continue_caching_sha2_password_auth(
774        &mut self,
775        nonce: &[u8],
776        auth_switched: bool,
777    ) -> Result<()> {
778        let payload = self.read_packet()?;
779
780        match payload[0] {
781            0x00 => {
782                // ok packet for empty password
783                Ok(())
784            }
785            0x01 => match payload[1] {
786                0x03 => {
787                    let payload = self.read_packet()?;
788                    self.handle_ok::<CommonOkPacket>(&payload).map(drop)
789                }
790                0x04 => {
791                    if !self.is_insecure() || self.is_socket() {
792                        let mut pass = self
793                            .0
794                            .opts
795                            .get_pass()
796                            .map(Vec::from)
797                            .unwrap_or_else(Vec::new);
798                        pass.push(0);
799                        self.write_packet(&mut pass.as_slice())?;
800                    } else {
801                        self.write_packet(&mut &[0x02][..])?;
802                        let payload = self.read_packet()?;
803                        let key = &payload[1..];
804                        let mut pass = self
805                            .0
806                            .opts
807                            .get_pass()
808                            .map(Vec::from)
809                            .unwrap_or_else(Vec::new);
810                        pass.push(0);
811                        for i in 0..pass.len() {
812                            pass[i] ^= nonce[i % nonce.len()];
813                        }
814                        let encrypted_pass = crypto::encrypt(&*pass, key);
815                        self.write_packet(&mut encrypted_pass.as_slice())?;
816                    }
817
818                    let payload = self.read_packet()?;
819                    self.handle_ok::<CommonOkPacket>(&payload).map(drop)
820                }
821                _ => Err(DriverError(UnexpectedPacket)),
822            },
823            0xfe if !auth_switched => {
824                let auth_switch_request = ParseBuf(&*payload).parse(())?;
825                self.perform_auth_switch(auth_switch_request)
826            }
827            _ => Err(DriverError(UnexpectedPacket)),
828        }
829    }
830
831    fn reset_seq_id(&mut self) {
832        self.stream_mut().codec_mut().reset_seq_id();
833    }
834
835    fn sync_seq_id(&mut self) {
836        self.stream_mut().codec_mut().sync_seq_id();
837    }
838
839    fn write_command_raw<T: MySerialize>(&mut self, cmd: &T) -> Result<()> {
840        let mut buf = get_buffer();
841        cmd.serialize(buf.as_mut());
842        self.reset_seq_id();
843        debug_assert!(buf.len() > 0);
844        self.0.last_command = buf[0];
845        self.write_packet(&mut &*buf)
846    }
847
848    fn write_command(&mut self, cmd: Command, data: &[u8]) -> Result<()> {
849        let mut buf = get_buffer();
850        buf.as_mut().put_u8(cmd as u8);
851        buf.as_mut().extend_from_slice(data);
852
853        self.reset_seq_id();
854        self.0.last_command = buf[0];
855        self.write_packet(&mut &*buf)
856    }
857
858    fn send_long_data(&mut self, stmt_id: u32, params: &[Value]) -> Result<()> {
859        for (i, value) in params.iter().enumerate() {
860            if let Bytes(bytes) = value {
861                let chunks = bytes.chunks(MAX_PAYLOAD_LEN - 6);
862                let chunks = chunks.chain(if bytes.is_empty() {
863                    Some(&[][..])
864                } else {
865                    None
866                });
867                for chunk in chunks {
868                    let cmd = ComStmtSendLongData::new(stmt_id, i as u16, Cow::Borrowed(chunk));
869                    self.write_command_raw(&cmd)?;
870                }
871            }
872        }
873
874        Ok(())
875    }
876
877    fn _execute(
878        &mut self,
879        stmt: &Statement,
880        params: Params,
881    ) -> Result<Or<Vec<Column>, OkPacket<'static>>> {
882        let exec_request = match &params {
883            Params::Empty => {
884                if stmt.num_params() != 0 {
885                    return Err(DriverError(MismatchedStmtParams(stmt.num_params(), 0)));
886                }
887
888                let (body, _) = ComStmtExecuteRequestBuilder::new(stmt.id()).build(&[]);
889                body
890            }
891            Params::Positional(params) => {
892                if stmt.num_params() != params.len() as u16 {
893                    return Err(DriverError(MismatchedStmtParams(
894                        stmt.num_params(),
895                        params.len(),
896                    )));
897                }
898
899                let (body, as_long_data) =
900                    ComStmtExecuteRequestBuilder::new(stmt.id()).build(&*params);
901
902                if as_long_data {
903                    self.send_long_data(stmt.id(), &*params)?;
904                }
905
906                body
907            }
908            Params::Named(_) => {
909                if let Some(named_params) = stmt.named_params.as_ref() {
910                    return self._execute(stmt, params.into_positional(named_params)?);
911                } else {
912                    return Err(DriverError(NamedParamsForPositionalQuery));
913                }
914            }
915        };
916        self.write_command_raw(&exec_request)?;
917        self.handle_result_set()
918    }
919
920    fn _start_transaction(&mut self, tx_opts: TxOpts) -> Result<()> {
921        if let Some(i_level) = tx_opts.isolation_level() {
922            self.query_drop(format!("SET TRANSACTION ISOLATION LEVEL {}", i_level))?;
923        }
924        if let Some(mode) = tx_opts.access_mode() {
925            let supported = match (self.0.server_version, self.0.mariadb_server_version) {
926                (Some(ref version), _) if *version >= (5, 6, 5) => true,
927                (_, Some(ref version)) if *version >= (10, 0, 0) => true,
928                _ => false,
929            };
930            if !supported {
931                return Err(DriverError(ReadOnlyTransNotSupported));
932            }
933            match mode {
934                AccessMode::ReadOnly => self.query_drop("SET TRANSACTION READ ONLY")?,
935                AccessMode::ReadWrite => self.query_drop("SET TRANSACTION READ WRITE")?,
936            }
937        }
938        if tx_opts.with_consistent_snapshot() {
939            self.query_drop("START TRANSACTION WITH CONSISTENT SNAPSHOT")
940                .unwrap();
941        } else {
942            self.query_drop("START TRANSACTION")?;
943        };
944        Ok(())
945    }
946
    /// Serves a `LOAD DATA LOCAL INFILE` request for `file_name` and reads the server's reply.
    ///
    /// The data comes from the connection's own local-infile handler. If none is set, the
    /// handler from `Opts` is used. The handler writes through a `LocalInfile` writer
    /// backed by a buffer of `min(MAX_PAYLOAD_LEN - 4, max_allowed_packet - 4)` bytes.
    /// When no handler is configured at all, nothing is written before the writer is
    /// flushed (presumably the server then receives an empty file).
    ///
    /// # Errors
    ///
    /// Propagates errors from:
    /// * locking or running the handler;
    /// * flushing the writer;
    /// * writing the terminating empty packet;
    /// * reading or parsing the final OK packet.
    fn send_local_infile(&mut self, file_name: &[u8]) -> Result<OkPacket<'static>> {
        // Scope the writer, which borrows `self`, so that the borrow ends before the
        // connection is used directly again below.
        {
            let buffer_size = cmp::min(
                MAX_PAYLOAD_LEN - 4,
                self.stream_ref().codec().max_allowed_packet - 4,
            );
            let chunk = vec![0u8; buffer_size].into_boxed_slice();
            // Clone the handler out of `self` before `self` is lent to `LocalInfile`.
            let maybe_handler = self
                .0
                .local_infile_handler
                .clone()
                .or_else(|| self.0.opts.get_local_infile_handler().cloned());
            let mut local_infile = LocalInfile::new(io::Cursor::new(chunk), self);
            if let Some(handler) = maybe_handler {
                // This connection does not contend for the lock. `self` is borrowed
                // exclusively, and `LocalInfile` does not expose the connection, so this
                // method cannot be re-entered while the lock is held. Any lock error is
                // propagated with `?` rather than unwrapped.
                let handler_fn = &mut *handler.0.lock()?;
                handler_fn(file_name, &mut local_infile)?;
            }
            local_infile.flush()?;
        }
        // An empty packet marks the end of the file data.
        self.write_packet(&mut &[][..])?;
        let payload = self.read_packet()?;
        let ok = self.handle_ok::<CommonOkPacket>(&payload)?;
        Ok(ok.into_owned())
    }
974
975    fn handle_result_set(&mut self) -> Result<Or<Vec<Column>, OkPacket<'static>>> {
976        if self.more_results_exists() {
977            self.sync_seq_id();
978        }
979
980        let pld = self.read_packet()?;
981        match pld[0] {
982            0x00 => {
983                let ok = self.handle_ok::<CommonOkPacket>(&pld)?;
984                Ok(Or::B(ok.into_owned()))
985            }
986            0xfb => match self.send_local_infile(&pld[1..]) {
987                Ok(ok) => Ok(Or::B(ok)),
988                Err(err) => Err(err),
989            },
990            _ => {
991                let mut reader = &pld[..];
992                let column_count = reader.read_lenenc_int()?;
993                let mut columns: Vec<Column> = Vec::with_capacity(column_count as usize);
994                for _ in 0..column_count {
995                    let pld = self.read_packet()?;
996                    let column = ParseBuf(&*pld).parse(())?;
997                    columns.push(column);
998                }
999                // skip eof packet
1000                self.drop_packet()?;
1001                self.0.has_results = column_count > 0;
1002                Ok(Or::A(columns))
1003            }
1004        }
1005    }
1006
1007    fn _query(&mut self, query: &str) -> Result<Or<Vec<Column>, OkPacket<'static>>> {
1008        self.write_command(Command::COM_QUERY, query.as_bytes())?;
1009        self.handle_result_set()
1010    }
1011
1012    /// Executes [`COM_PING`](http://dev.mysql.com/doc/internals/en/com-ping.html)
1013    /// on `Conn`. Return `true` on success or `false` on error.
1014    pub fn ping(&mut self) -> bool {
1015        match self.write_command(Command::COM_PING, &[]) {
1016            Ok(_) => self.drop_packet().is_ok(),
1017            _ => false,
1018        }
1019    }
1020
1021    /// Executes [`COM_INIT_DB`](https://dev.mysql.com/doc/internals/en/com-init-db.html)
1022    /// on `Conn`.
1023    pub fn select_db(&mut self, schema: &str) -> bool {
1024        match self.write_command(Command::COM_INIT_DB, schema.as_bytes()) {
1025            Ok(_) => self.drop_packet().is_ok(),
1026            _ => false,
1027        }
1028    }
1029
1030    /// Starts new transaction with provided options.
1031    /// `readonly` is only available since MySQL 5.6.5.
1032    pub fn start_transaction(&mut self, tx_opts: TxOpts) -> Result<Transaction> {
1033        self._start_transaction(tx_opts)?;
1034        Ok(Transaction::new(self.into()))
1035    }
1036
1037    fn _true_prepare(&mut self, query: &[u8]) -> Result<InnerStmt> {
1038        self.write_command(Command::COM_STMT_PREPARE, query)?;
1039        let pld = self.read_packet()?;
1040        let mut stmt = ParseBuf(&*pld).parse::<InnerStmt>(self.connection_id())?;
1041        if stmt.num_params() > 0 {
1042            let mut params: Vec<Column> = Vec::with_capacity(stmt.num_params() as usize);
1043            for _ in 0..stmt.num_params() {
1044                let pld = self.read_packet()?;
1045                params.push(ParseBuf(&*pld).parse(())?);
1046            }
1047            stmt = stmt.with_params(Some(params));
1048            self.drop_packet()?;
1049        }
1050        if stmt.num_columns() > 0 {
1051            let mut columns: Vec<Column> = Vec::with_capacity(stmt.num_columns() as usize);
1052            for _ in 0..stmt.num_columns() {
1053                let pld = self.read_packet()?;
1054                columns.push(ParseBuf(&*pld).parse(())?);
1055            }
1056            stmt = stmt.with_columns(Some(columns));
1057            self.drop_packet()?;
1058        }
1059        Ok(stmt)
1060    }
1061
1062    fn _prepare(&mut self, query: &[u8]) -> Result<Arc<InnerStmt>> {
1063        if let Some(entry) = self.0.stmt_cache.by_query(query) {
1064            return Ok(entry.stmt.clone());
1065        }
1066
1067        let inner_st = Arc::new(self._true_prepare(query)?);
1068
1069        if let Some(old_stmt) = self
1070            .0
1071            .stmt_cache
1072            .put(Arc::new(query.into()), inner_st.clone())
1073        {
1074            self.close(Statement::new(old_stmt, None))?;
1075        }
1076
1077        Ok(inner_st)
1078    }
1079
1080    fn connect(&mut self) -> Result<()> {
1081        if self.0.connected {
1082            return Ok(());
1083        }
1084        self.do_handshake()
1085            .and_then(|_| {
1086                Ok(from_value_opt::<usize>(
1087                    self.get_system_var("max_allowed_packet")?.unwrap_or(NULL),
1088                )
1089                .unwrap_or(0))
1090            })
1091            .and_then(|max_allowed_packet| {
1092                if max_allowed_packet == 0 {
1093                    Err(DriverError(SetupError))
1094                } else {
1095                    self.stream_mut().codec_mut().max_allowed_packet = max_allowed_packet;
1096                    self.0.connected = true;
1097                    Ok(())
1098                }
1099            })
1100    }
1101
1102    fn get_system_var(&mut self, name: &str) -> Result<Option<Value>> {
1103        self.query_first(format!("SELECT @@{}", name))
1104    }
1105
1106    fn next_row_packet(&mut self) -> Result<Option<Buffer>> {
1107        if !self.0.has_results {
1108            return Ok(None);
1109        }
1110
1111        let pld = self.read_packet()?;
1112
1113        if self.has_capability(CapabilityFlags::CLIENT_DEPRECATE_EOF) {
1114            if pld[0] == 0xfe && pld.len() < MAX_PAYLOAD_LEN {
1115                self.0.has_results = false;
1116                self.handle_ok::<ResultSetTerminator>(&pld)?;
1117                return Ok(None);
1118            }
1119        } else {
1120            if pld[0] == 0xfe && pld.len() < 8 {
1121                self.0.has_results = false;
1122                self.handle_ok::<OldEofPacket>(&pld)?;
1123                return Ok(None);
1124            }
1125        }
1126
1127        Ok(Some(pld))
1128    }
1129
1130    fn has_stmt(&self, query: &[u8]) -> bool {
1131        self.0.stmt_cache.contains_query(query)
1132    }
1133
1134    /// Sets a callback to handle requests for local files. These are
1135    /// caused by using `LOAD DATA LOCAL INFILE` queries. The
1136    /// callback is passed the filename, and a `Write`able object
1137    /// to receive the contents of that file.
1138    /// Specifying `None` will reset the handler to the one specified
1139    /// in the `Opts` for this connection.
1140    pub fn set_local_infile_handler(&mut self, handler: Option<LocalInfileHandler>) {
1141        self.0.local_infile_handler = handler;
1142    }
1143
1144    pub fn no_backslash_escape(&self) -> bool {
1145        self.0
1146            .status_flags
1147            .contains(StatusFlags::SERVER_STATUS_NO_BACKSLASH_ESCAPES)
1148    }
1149
1150    fn register_as_slave(&mut self, server_id: u32) -> Result<()> {
1151        use mysql_common::packets::ComRegisterSlave;
1152
1153        self.query_drop("SET @master_binlog_checksum='ALL'")?;
1154        self.write_command_raw(&ComRegisterSlave::new(server_id))?;
1155
1156        // Server will respond with OK.
1157        self.read_packet()?;
1158
1159        Ok(())
1160    }
1161
1162    fn request_binlog(&mut self, request: BinlogRequest<'_>) -> Result<()> {
1163        self.register_as_slave(request.server_id())?;
1164        self.write_command_raw(&request.as_cmd())?;
1165        Ok(())
1166    }
1167
1168    /// Turns this connection into a binlog stream.
1169    ///
1170    /// You can use `SHOW BINARY LOGS` to get the current logfile and position from the master.
1171    /// If the request's `filename` is empty, the server will send the binlog-stream
1172    /// of the first known binlog.
1173    pub fn get_binlog_stream(mut self, request: BinlogRequest<'_>) -> Result<BinlogStream> {
1174        self.request_binlog(request)?;
1175        Ok(BinlogStream::new(self))
1176    }
1177}
1178
1179#[cfg(unix)]
1180impl AsRawFd for Conn {
1181    fn as_raw_fd(&self) -> RawFd {
1182        self.stream_ref().get_ref().as_raw_fd()
1183    }
1184}
1185
1186impl Queryable for Conn {
1187    fn query_iter<T: AsRef<str>>(&mut self, query: T) -> Result<QueryResult<'_, '_, '_, Text>> {
1188        let meta = self._query(query.as_ref())?;
1189        Ok(QueryResult::new(ConnMut::Mut(self), meta))
1190    }
1191
1192    fn prep<T: AsRef<str>>(&mut self, query: T) -> Result<Statement> {
1193        let query = query.as_ref();
1194        let (named_params, real_query) = parse_named_params(query.as_bytes())?;
1195        self._prepare(real_query.borrow())
1196            .map(|inner| Statement::new(inner, named_params))
1197    }
1198
1199    fn close(&mut self, stmt: Statement) -> Result<()> {
1200        self.0.stmt_cache.remove(stmt.id());
1201        let cmd = ComStmtClose::new(stmt.id());
1202        self.write_command_raw(&cmd)
1203    }
1204
1205    fn exec_iter<S, P>(&mut self, stmt: S, params: P) -> Result<QueryResult<'_, '_, '_, Binary>>
1206    where
1207        S: AsStatement,
1208        P: Into<Params>,
1209    {
1210        let statement = stmt.as_statement(self)?;
1211        let meta = self._execute(&*statement, params.into())?;
1212        Ok(QueryResult::new(ConnMut::Mut(self), meta))
1213    }
1214}
1215
1216impl Drop for Conn {
1217    fn drop(&mut self) {
1218        let stmt_cache = mem::replace(&mut self.0.stmt_cache, StmtCache::new(0));
1219
1220        for (_, entry) in stmt_cache.into_iter() {
1221            let _ = self.close(Statement::new(entry.stmt, None));
1222        }
1223
1224        if self.0.stream.is_some() {
1225            let _ = self.write_command(Command::COM_QUIT, &[]);
1226        }
1227    }
1228}
1229
1230#[cfg(test)]
1231#[allow(non_snake_case)]
1232mod test {
1233    mod my_conn {
1234        use std::{
1235            collections::HashMap,
1236            io::Write,
1237            iter, process,
1238            sync::mpsc::{channel, sync_channel},
1239            thread::spawn,
1240            time::Duration,
1241        };
1242
1243        use mysql_common::{binlog::events::EventData, packets::binlog_request::BinlogRequest};
1244        use time::PrimitiveDateTime;
1245
1246        use crate::{
1247            from_row, from_value, params,
1248            prelude::*,
1249            test_misc::get_opts,
1250            Conn,
1251            DriverError::{MissingNamedParameter, NamedParamsForPositionalQuery},
1252            Error::DriverError,
1253            LocalInfileHandler, Opts, OptsBuilder, Pool, TxOpts,
1254            Value::{self, Bytes, Date, Float, Int, NULL},
1255        };
1256
1257        fn get_system_variable<T>(conn: &mut Conn, name: &str) -> T
1258        where
1259            T: FromValue,
1260        {
1261            conn.query_first::<(String, T), _>(format!("show variables like '{}'", name))
1262                .unwrap()
1263                .unwrap()
1264                .1
1265        }
1266
        /// Smoke test: connects with the default test options and pings the server.
        ///
        /// Also checks the global `sql_mode`. When compression or TLS is enabled for the
        /// test run, verifies that it is actually in use.
        #[test]
        fn should_connect() {
            let mut conn = Conn::new(get_opts()).unwrap();

            // The test server is expected to run with a TRADITIONAL-containing global sql_mode.
            let mode: String = conn
                .query_first("SELECT @@GLOBAL.sql_mode")
                .unwrap()
                .unwrap();
            assert!(mode.contains("TRADITIONAL"));
            assert!(conn.ping());

            if crate::test_misc::test_compression() {
                // NOTE(review): relies on the stream's Debug output naming the compression codec.
                assert!(format!("{:?}", conn.0.stream).contains("Compression"));
            }

            if crate::test_misc::test_ssl() {
                assert!(!conn.is_insecure());
            }
        }
1286
        /// Regression test ported from mysql_async issue #107.
        ///
        /// Reads BINARY/VARBINARY columns, including empty and binary-literal values,
        /// into `Vec<u8>` through the text protocol. Both rows must decode.
        #[test]
        fn mysql_async_issue_107() -> crate::Result<()> {
            let mut conn = Conn::new(get_opts())?;
            conn.query_drop(
                r"CREATE TEMPORARY TABLE mysql.issue (
                        a BIGINT(20) UNSIGNED,
                        b VARBINARY(16),
                        c BINARY(32),
                        d BIGINT(20) UNSIGNED,
                        e BINARY(32)
                    )",
            )?;
            conn.query_drop(
                r"INSERT INTO mysql.issue VALUES (
                        0,
                        0xC066F966B0860000,
                        0x7939DA98E524C5F969FC2DE8D905FD9501EBC6F20001B0A9C941E0BE6D50CF44,
                        0,
                        ''
                    ), (
                        1,
                        '',
                        0x076311DF4D407B0854371BA13A5F3FB1A4555AC22B361375FD47B263F31822F2,
                        0,
                        ''
                    )",
            )?;

            let q = "SELECT b, c, d, e FROM mysql.issue";
            let result = conn.query_iter(q)?;

            // `from_row` panics if a row cannot be converted, so reaching the length check
            // means every row decoded.
            let loaded_structs = result
                .map(|row| crate::from_row::<(Vec<u8>, Vec<u8>, u64, Vec<u8>)>(row.unwrap()))
                .collect::<Vec<_>>();

            assert_eq!(loaded_structs.len(), 2);

            Ok(())
        }
1326
        /// Exercises the query helper methods on string queries (`run`, `with`, `batch`,
        /// `first`, `fetch`, `map`, `fold`) in both text and binary form.
        ///
        /// The same script runs against four targets: a `Conn`, a transaction on it,
        /// a pooled connection, and a pool transaction.
        #[test]
        fn query_traits() -> Result<(), Box<dyn std::error::Error>> {
            // The same script is expanded for each connection-like target; `$conn` must be
            // a mutable reference to something the query helpers accept.
            macro_rules! test_query {
                ($conn : expr) => {
                    "CREATE TABLE tmplak (a INT)".run($conn)?;

                    "INSERT INTO tmplak (a) VALUES (?)".with((42,)).run($conn)?;

                    "INSERT INTO tmplak (a) VALUES (?)"
                        .with((43..=44).map(|x| (x,)))
                        .batch($conn)?;

                    let first: Option<u8> = "SELECT a FROM tmplak LIMIT 1".first($conn)?;
                    assert_eq!(first, Some(42), "first text");

                    let first: Option<u8> = "SELECT a FROM tmplak LIMIT 1".with(()).first($conn)?;
                    assert_eq!(first, Some(42), "first bin");

                    let count = "SELECT a FROM tmplak".run($conn)?.count();
                    assert_eq!(count, 3, "run text");

                    let count = "SELECT a FROM tmplak".with(()).run($conn)?.count();
                    assert_eq!(count, 3, "run bin");

                    let all: Vec<u8> = "SELECT a FROM tmplak".fetch($conn)?;
                    assert_eq!(all, vec![42, 43, 44], "fetch text");

                    let all: Vec<u8> = "SELECT a FROM tmplak".with(()).fetch($conn)?;
                    assert_eq!(all, vec![42, 43, 44], "fetch bin");

                    let mapped = "SELECT a FROM tmplak".map($conn, |x: u8| x + 1)?;
                    assert_eq!(mapped, vec![43, 44, 45], "map text");

                    let mapped = "SELECT a FROM tmplak".with(()).map($conn, |x: u8| x + 1)?;
                    assert_eq!(mapped, vec![43, 44, 45], "map bin");

                    let sum = "SELECT a FROM tmplak".fold($conn, 0_u8, |acc, x: u8| acc + x)?;
                    assert_eq!(sum, 42 + 43 + 44, "fold text");

                    let sum = "SELECT a FROM tmplak"
                        .with(())
                        .fold($conn, 0_u8, |acc, x: u8| acc + x)?;
                    assert_eq!(sum, 42 + 43 + 44, "fold bin");

                    "DROP TABLE tmplak".run($conn)?;
                };
            }

            let mut conn = Conn::new(get_opts())?;

            let mut tx = conn.start_transaction(TxOpts::default())?;
            test_query!(&mut tx);
            tx.rollback()?;

            test_query!(&mut conn);

            let pool = Pool::new(get_opts())?;
            let mut pooled_conn = pool.get_conn()?;

            let mut tx = pool.start_transaction(TxOpts::default())?;
            test_query!(&mut tx);
            tx.rollback()?;

            test_query!(&mut pooled_conn);

            Ok(())
        }
1394
        /// Connecting through a nonexistent socket path must fail.
        ///
        /// The `should_panic` message is matched against the error surfaced by `unwrap`.
        #[test]
        #[should_panic(expected = "Could not connect to address")]
        fn should_fail_on_wrong_socket_path() {
            let opts = OptsBuilder::from_opts(get_opts()).socket(Some("/foo/bar/baz"));
            let _ = Conn::new(opts).unwrap();
        }
1401
        /// An `injected_socket` path that cannot be opened must not prevent connecting.
        ///
        /// `Conn::new` is expected to fall back to the TCP connection.
        #[test]
        fn should_fallback_to_tcp_if_cant_switch_to_socket() {
            let mut opts = Opts::from(get_opts());
            opts.0.injected_socket = Some("/foo/bar/baz".into());
            let _ = Conn::new(opts).unwrap();
        }
1408
        /// A `db_name` given in the options becomes the session's default database.
        #[test]
        fn should_connect_with_database() {
            const DB_NAME: &str = "mysql";

            let opts = OptsBuilder::from_opts(get_opts()).db_name(Some(DB_NAME));

            let mut conn = Conn::new(opts).unwrap();

            let db_name: String = conn.query_first("SELECT DATABASE()").unwrap().unwrap();
            assert_eq!(db_name, DB_NAME);
        }
1420
        /// Connecting by host name (`localhost`) instead of an IP address works.
        /// Not run on WASI.
        #[cfg(not(target_os = "wasi"))]
        #[test]
        fn should_connect_by_hostname() {
            let opts = OptsBuilder::from_opts(get_opts()).ip_or_hostname(Some("localhost"));
            let mut conn = Conn::new(opts).unwrap();
            assert!(conn.ping());
        }
1428
        /// `Conn::select_db` switches the session's default database.
        ///
        /// The test creates a scratch database and drops it again afterwards.
        #[test]
        fn should_select_db() {
            const DB_NAME: &str = "t_select_db";

            let mut conn = Conn::new(get_opts()).unwrap();
            conn.query_drop(format!("CREATE DATABASE IF NOT EXISTS {}", DB_NAME))
                .unwrap();
            assert!(conn.select_db(DB_NAME));

            let db_name: String = conn.query_first("SELECT DATABASE()").unwrap().unwrap();
            assert_eq!(db_name, DB_NAME);

            conn.query_drop(format!("DROP DATABASE {}", DB_NAME))
                .unwrap();
        }
1444
        /// Runs text-protocol statements and checks the results.
        ///
        /// Verifies `affected_rows` / `last_insert_id` after each statement. Also checks
        /// that a failed query and an undrained `QueryResult` leave the connection usable,
        /// and that text rows decode into `String`s.
        #[test]
        fn should_execute_queryes_and_parse_results() {
            type TestRow = (String, String, String, String, String, String);

            const CREATE_QUERY: &str = r"CREATE TEMPORARY TABLE mysql.tbl
                (id SERIAL, a TEXT, b INT, c INT UNSIGNED, d DATE, e FLOAT)";
            const INSERT_QUERY_1: &str = r"INSERT
                INTO mysql.tbl(a, b, c, d, e)
                VALUES ('hello', -123, 123, '2014-05-05', 123.123)";
            const INSERT_QUERY_2: &str = r"INSERT
                INTO mysql.tbl(a, b, c, d, e)
                VALUES ('world', -321, 321, '2014-06-06', 321.321)";

            let mut conn = Conn::new(get_opts()).unwrap();

            // The assertions below depend on the exact order of these statements.
            conn.query_drop(CREATE_QUERY).unwrap();
            assert_eq!(conn.affected_rows(), 0);
            assert_eq!(conn.last_insert_id(), 0);

            conn.query_drop(INSERT_QUERY_1).unwrap();
            assert_eq!(conn.affected_rows(), 1);
            assert_eq!(conn.last_insert_id(), 1);

            conn.query_drop(INSERT_QUERY_2).unwrap();
            assert_eq!(conn.affected_rows(), 1);
            assert_eq!(conn.last_insert_id(), 2);

            conn.query_drop("SELECT * FROM unexisted").unwrap_err();
            conn.query_iter("SELECT * FROM mysql.tbl").unwrap(); // Drop::drop for QueryResult

            conn.query_drop("UPDATE mysql.tbl SET a = 'foo'").unwrap();
            assert_eq!(conn.affected_rows(), 2);
            assert_eq!(conn.last_insert_id(), 0);

            assert!(conn
                .query_first::<TestRow, _>("SELECT * FROM mysql.tbl WHERE a = 'bar'")
                .unwrap()
                .is_none());

            let rows: Vec<TestRow> = conn.query("SELECT * FROM mysql.tbl").unwrap();
            assert_eq!(
                rows,
                vec![
                    (
                        "1".into(),
                        "foo".into(),
                        "-123".into(),
                        "123".into(),
                        "2014-05-05".into(),
                        "123.123".into()
                    ),
                    (
                        "2".into(),
                        "foo".into(),
                        "-321".into(),
                        "321".into(),
                        "2014-06-06".into(),
                        "321.321".into()
                    )
                ]
            );
        }
1507
        /// A 20,000,000-byte value is read correctly through the text protocol.
        ///
        /// NOTE(review): the size is presumably chosen to exceed `MAX_PAYLOAD_LEN`, so the
        /// value spans several packets. Not run on WASI.
        #[cfg(not(target_os = "wasi"))]
        #[test]
        fn should_parse_large_text_result() {
            let mut conn = Conn::new(get_opts()).unwrap();
            let value: Value = conn
                .query_first("SELECT REPEAT('A', 20000000)")
                .unwrap()
                .unwrap();
            assert_eq!(value, Bytes(iter::repeat(b'A').take(20_000_000).collect()));
        }
1518
        /// Inserts two rows through a prepared statement and reads them back via the
        /// binary protocol.
        ///
        /// One row holds typed values and the other holds NULLs. The rows read back must
        /// equal the `Value`s that were inserted.
        #[test]
        fn should_execute_statements_and_parse_results() {
            const CREATE_QUERY: &str = r"CREATE TEMPORARY TABLE
                mysql.tbl (a TEXT, b INT, c INT UNSIGNED, d DATE, e FLOAT)";
            const INSERT_SMTM: &str = r"INSERT
                INTO mysql.tbl (a, b, c, d, e)
                VALUES (?, ?, ?, ?, ?)";

            type RowType = (Value, Value, Value, Value, Value);

            let row1 = (
                Bytes(b"hello".to_vec()),
                Int(-123_i64),
                Int(123_i64),
                Date(2014_u16, 5_u8, 5_u8, 0_u8, 0_u8, 0_u8, 0_u32),
                Float(123.123_f32),
            );
            let row2 = (Bytes(b"".to_vec()), NULL, NULL, NULL, Float(321.321_f32));

            let mut conn = Conn::new(get_opts()).unwrap();
            conn.query_drop(CREATE_QUERY).unwrap();

            let insert_stmt = conn.prep(INSERT_SMTM).unwrap();
            assert_eq!(insert_stmt.connection_id(), conn.connection_id());
            // Insert via native Rust types so parameter serialization is exercised.
            conn.exec_drop(
                &insert_stmt,
                (
                    from_value::<String>(row1.0.clone()),
                    from_value::<i32>(row1.1.clone()),
                    from_value::<u32>(row1.2.clone()),
                    from_value::<PrimitiveDateTime>(row1.3.clone()),
                    from_value::<f32>(row1.4.clone()),
                ),
            )
            .unwrap();
            conn.exec_drop(
                &insert_stmt,
                (
                    from_value::<String>(row2.0.clone()),
                    row2.1.clone(),
                    row2.2.clone(),
                    row2.3.clone(),
                    from_value::<f32>(row2.4.clone()),
                ),
            )
            .unwrap();

            let select_stmt = conn.prep("SELECT * from mysql.tbl").unwrap();
            let rows: Vec<RowType> = conn.exec(&select_stmt, ()).unwrap();

            assert_eq!(rows, vec![row1, row2]);
        }
1571
        /// A 20,000,000-byte value is read correctly through the binary protocol.
        /// This is the prepared-statement counterpart of the large text-result test.
        /// Not run on WASI.
        #[cfg(not(target_os = "wasi"))]
        #[test]
        fn should_parse_large_binary_result() {
            let mut conn = Conn::new(get_opts()).unwrap();
            let stmt = conn.prep("SELECT REPEAT('A', 20000000)").unwrap();
            let value: Value = conn.exec_first(&stmt, ()).unwrap().unwrap();
            assert_eq!(value, Bytes(iter::repeat(b'A').take(20_000_000).collect()));
        }
1580
        /// Manual `close` must not break later preparation when the statement cache
        /// holds a single entry.
        ///
        /// `close` removes the statement from the cache, so re-preparing the same query
        /// and preparing a different one must both succeed.
        #[test]
        fn manually_closed_stmt() {
            let opts = OptsBuilder::from(get_opts()).stmt_cache_size(1);
            let mut conn = Conn::new(opts).unwrap();
            let stmt = conn.prep("SELECT 1").unwrap();
            conn.exec_drop(&stmt, ()).unwrap();
            conn.close(stmt).unwrap();
            let stmt = conn.prep("SELECT 1").unwrap();
            conn.exec_drop(&stmt, ()).unwrap();
            conn.close(stmt).unwrap();
            let stmt = conn.prep("SELECT 2").unwrap();
            conn.exec_drop(&stmt, ()).unwrap();
        }
1594
        /// Transaction lifecycle. Each scenario checks the row count afterwards:
        /// * explicit commit persists the inserts;
        /// * a failed statement followed by drop rolls back;
        /// * explicit rollback discards the inserts;
        /// * a committed transaction with prepared inserts persists them;
        /// * dropping an uncommitted transaction discards its inserts.
        #[test]
        fn should_start_commit_and_rollback_transactions() {
            let mut conn = Conn::new(get_opts()).unwrap();
            conn.query_drop(
                "CREATE TEMPORARY TABLE mysql.tbl(id INT NOT NULL PRIMARY KEY AUTO_INCREMENT, a INT)",
            )
            .unwrap();
            // Commit: two rows persisted.
            let _ = conn
                .start_transaction(TxOpts::default())
                .and_then(|mut t| {
                    t.query_drop("INSERT INTO mysql.tbl(a) VALUES(1)").unwrap();
                    assert_eq!(t.last_insert_id(), Some(1));
                    assert_eq!(t.affected_rows(), 1);
                    t.query_drop("INSERT INTO mysql.tbl(a) VALUES(2)").unwrap();
                    t.commit().unwrap();
                    Ok(())
                })
                .unwrap();
            assert_eq!(
                conn.query_iter("SELECT COUNT(a) from mysql.tbl")
                    .unwrap()
                    .next()
                    .unwrap()
                    .unwrap()
                    .unwrap(),
                vec![Bytes(b"2".to_vec())]
            );
            // Failing statement, transaction dropped without commit: count unchanged.
            let _ = conn
                .start_transaction(TxOpts::default())
                .and_then(|mut t| {
                    t.query_drop("INSERT INTO tbl2(a) VALUES(1)").unwrap_err();
                    Ok(())
                    // implicit rollback
                })
                .unwrap();
            assert_eq!(
                conn.query_iter("SELECT COUNT(a) from mysql.tbl")
                    .unwrap()
                    .next()
                    .unwrap()
                    .unwrap()
                    .unwrap(),
                vec![Bytes(b"2".to_vec())]
            );
            // Explicit rollback: count unchanged.
            let _ = conn
                .start_transaction(TxOpts::default())
                .and_then(|mut t| {
                    t.query_drop("INSERT INTO mysql.tbl(a) VALUES(1)").unwrap();
                    t.query_drop("INSERT INTO mysql.tbl(a) VALUES(2)").unwrap();
                    t.rollback().unwrap();
                    Ok(())
                })
                .unwrap();
            assert_eq!(
                conn.query_iter("SELECT COUNT(a) from mysql.tbl")
                    .unwrap()
                    .next()
                    .unwrap()
                    .unwrap()
                    .unwrap(),
                vec![Bytes(b"2".to_vec())]
            );
            // Prepared inserts committed: count grows to 4.
            let mut tx = conn.start_transaction(TxOpts::default()).unwrap();
            tx.exec_drop("INSERT INTO mysql.tbl(a) VALUES(?)", (3,))
                .unwrap();
            tx.exec_drop("INSERT INTO mysql.tbl(a) VALUES(?)", (4,))
                .unwrap();
            tx.commit().unwrap();
            assert_eq!(
                conn.query_iter("SELECT COUNT(a) from mysql.tbl")
                    .unwrap()
                    .next()
                    .unwrap()
                    .unwrap()
                    .unwrap(),
                vec![Bytes(b"4".to_vec())]
            );
            // Prepared inserts, transaction dropped without commit: count stays at 4.
            let mut tx = conn.start_transaction(TxOpts::default()).unwrap();
            tx.exec_drop("INSERT INTO mysql.tbl(a) VALUES(?)", (5,))
                .unwrap();
            tx.exec_drop("INSERT INTO mysql.tbl(a) VALUES(?)", (6,))
                .unwrap();
            drop(tx);
            assert_eq!(
                conn.query_first("SELECT COUNT(a) from mysql.tbl").unwrap(),
                Some(4_usize),
            );
        }
1683        #[test]
1684        fn should_handle_LOCAL_INFILE_with_custom_handler() {
1685            let mut conn = Conn::new(get_opts()).unwrap();
1686            conn.query_drop("CREATE TEMPORARY TABLE mysql.tbl(a TEXT)")
1687                .unwrap();
1688            conn.set_local_infile_handler(Some(LocalInfileHandler::new(|_, stream| {
1689                let mut cell_data = vec![b'Z'; 65535];
1690                cell_data.push(b'\n');
1691                for _ in 0..1536 {
1692                    stream.write_all(&*cell_data)?;
1693                }
1694                Ok(())
1695            })));
1696            match conn.query_drop("LOAD DATA LOCAL INFILE 'file_name' INTO TABLE mysql.tbl") {
1697                Ok(_) => {}
1698                Err(ref err) if format!("{}", err).find("not allowed").is_some() => {
1699                    return;
1700                }
1701                Err(err) => panic!("ERROR {}", err),
1702            }
1703            let count = conn
1704                .query_iter("SELECT * FROM mysql.tbl")
1705                .unwrap()
1706                .map(|row| {
1707                    assert_eq!(from_row::<(Vec<u8>,)>(row.unwrap()).0.len(), 65535);
1708                    1
1709                })
1710                .sum::<usize>();
1711            assert_eq!(count, 1536);
1712        }
1713
1714        #[test]
1715        fn should_reset_connection() {
1716            let mut conn = Conn::new(get_opts()).unwrap();
1717            conn.query_drop(
1718                "CREATE TEMPORARY TABLE `mysql`.`test` \
1719                 (`test` VARCHAR(255) NULL);",
1720            )
1721            .unwrap();
1722            conn.query_drop("INSERT INTO `mysql`.`test` (`test`) VALUES ('foo');")
1723                .unwrap();
1724            assert_eq!(conn.affected_rows(), 1);
1725            conn.reset().unwrap();
1726            assert_eq!(conn.affected_rows(), 0);
1727            conn.query_drop("SELECT * FROM `mysql`.`test`;")
1728                .unwrap_err();
1729        }
1730
1731        #[test]
1732        fn prep_exec() {
1733            let mut conn = Conn::new(get_opts()).unwrap();
1734
1735            let stmt1 = conn.prep("SELECT :foo").unwrap();
1736            let stmt2 = conn.prep("SELECT :bar").unwrap();
1737            assert_eq!(
1738                conn.exec::<String, _, _>(&stmt1, params! { "foo" => "foo" })
1739                    .unwrap(),
1740                vec![String::from("foo")],
1741            );
1742            assert_eq!(
1743                conn.exec::<String, _, _>(&stmt2, params! { "bar" => "bar" })
1744                    .unwrap(),
1745                vec![String::from("bar")],
1746            );
1747        }
1748
1749        #[test]
1750        fn should_connect_via_socket_for_127_0_0_1() {
1751            let opts = OptsBuilder::from_opts(get_opts());
1752            let conn = Conn::new(opts).unwrap();
1753            if conn.is_insecure() {
1754                assert!(conn.is_socket());
1755            }
1756        }
1757
1758        #[test]
1759        fn should_connect_via_socket_localhost() {
1760            let opts = OptsBuilder::from_opts(get_opts()).ip_or_hostname(Some("localhost"));
1761            let conn = Conn::new(opts).unwrap();
1762            if conn.is_insecure() {
1763                assert!(conn.is_socket());
1764            }
1765        }
1766
1767        /// QueryResult::drop hangs on connectivity errors (see [blackbeam/rust-mysql-simple#306][1]).
1768        ///
1769        /// [1]: https://github.com/blackbeam/rust-mysql-simple/issues/306
1770        #[cfg(not(target_os = "wasi"))]
1771        #[test]
1772        fn issue_306() {
1773            let (tx, rx) = channel::<()>();
1774            let handle = spawn(move || {
1775                let mut c1 = Conn::new(get_opts()).unwrap();
1776                let c1_id = c1.connection_id();
1777                let mut c2 = Conn::new(get_opts()).unwrap();
1778                let query_result = c1.query_iter("DO 1; SELECT SLEEP(1); DO 2;").unwrap();
1779                c2.query_drop(format!("KILL {c1_id}")).unwrap();
1780                drop(c2);
1781                drop(query_result);
1782                tx.send(()).unwrap();
1783            });
1784            std::thread::sleep(Duration::from_secs(2));
1785            assert!(rx.try_recv().is_ok());
1786            handle.join().unwrap();
1787        }
1788
1789        #[test]
1790        fn reset_does_work() {
1791            let mut c = Conn::new(get_opts()).unwrap();
1792            let cid = c.connection_id();
1793            c.reset().unwrap();
1794            match (c.0.server_version, c.0.mariadb_server_version) {
1795                (Some(ref version), _) if *version > (5, 7, 3) => {
1796                    assert_eq!(cid, c.connection_id());
1797                }
1798                (_, Some(ref version)) if *version >= (10, 2, 7) => {
1799                    assert_eq!(cid, c.connection_id());
1800                }
1801                _ => assert_ne!(cid, c.connection_id()),
1802            }
1803        }
1804
1805        /// Library panics with "incomplete connection" in case of subsequent
1806        /// failed calls to `reset` when the server is down.
1807        /// (see [blackbeam/rust-mysql-simple#317][1]).
1808        ///
1809        /// [1]: https://github.com/blackbeam/rust-mysql-simple/issues/317
1810        #[test]
1811        fn issue_317() {
1812            let mut c = Conn::new(get_opts()).unwrap();
1813            c.0.opts = get_opts().tcp_port(55555).into();
1814            let version = std::mem::replace(&mut c.0.server_version, Some((0, 0, 0)));
1815            let mdbversion = std::mem::replace(&mut c.0.mariadb_server_version, Some((0, 0, 0)));
1816            c.reset().unwrap_err();
1817            c.0.server_version = version;
1818            c.0.mariadb_server_version = mdbversion;
1819            let _ = c.reset();
1820        }
1821
1822        #[test]
1823        fn should_drop_multi_result_set() {
1824            let opts = OptsBuilder::from_opts(get_opts()).db_name(Some("mysql"));
1825            let mut conn = Conn::new(opts).unwrap();
1826            conn.query_drop("CREATE TEMPORARY TABLE TEST_TABLE ( name varchar(255) )")
1827                .unwrap();
1828            conn.exec_drop("SELECT * FROM TEST_TABLE", ()).unwrap();
1829            conn.query_drop(
1830                r"
1831                INSERT INTO TEST_TABLE (name) VALUES ('one');
1832                INSERT INTO TEST_TABLE (name) VALUES ('two');
1833                INSERT INTO TEST_TABLE (name) VALUES ('three');",
1834            )
1835            .unwrap();
1836            conn.exec_drop("SELECT * FROM TEST_TABLE", ()).unwrap();
1837
1838            let mut query_result = conn
1839                .query_iter(
1840                    r"
1841                SELECT * FROM TEST_TABLE;
1842                INSERT INTO TEST_TABLE (name) VALUES ('one');
1843                DO 0;",
1844                )
1845                .unwrap();
1846
1847            while let Some(result) = query_result.iter() {
1848                result.affected_rows();
1849            }
1850        }
1851
1852        #[test]
1853        fn should_handle_multi_resultset() {
1854            let opts = OptsBuilder::from_opts(get_opts())
1855                .prefer_socket(false)
1856                .db_name(Some("mysql"));
1857            let mut conn = Conn::new(opts).unwrap();
1858            conn.query_drop("DROP PROCEDURE IF EXISTS multi").unwrap();
1859            conn.query_drop(
1860                r#"CREATE PROCEDURE multi() BEGIN
1861                        SELECT 1 UNION ALL SELECT 2;
1862                        DO 1;
1863                        SELECT 3 UNION ALL SELECT 4;
1864                        DO 1;
1865                        DO 1;
1866                        SELECT REPEAT('A', 17000000);
1867                        SELECT REPEAT('A', 17000000);
1868                    END"#,
1869            )
1870            .unwrap();
1871            {
1872                let mut query_result = conn.query_iter("CALL multi()").unwrap();
1873                let result_set = query_result
1874                    .by_ref()
1875                    .map(|row| row.unwrap().unwrap().pop().unwrap())
1876                    .collect::<Vec<crate::Value>>();
1877                assert_eq!(result_set, vec![Bytes(b"1".to_vec()), Bytes(b"2".to_vec())]);
1878                let result_set = query_result
1879                    .by_ref()
1880                    .map(|row| row.unwrap().unwrap().pop().unwrap())
1881                    .collect::<Vec<crate::Value>>();
1882                assert_eq!(result_set, vec![Bytes(b"3".to_vec()), Bytes(b"4".to_vec())]);
1883            }
1884            let mut result = conn.query_iter("SELECT 1; SELECT 2; SELECT 3;").unwrap();
1885            let mut i = 0;
1886            while let Some(result_set) = result.iter() {
1887                i += 1;
1888                for row in result_set {
1889                    match i {
1890                        1 => assert_eq!(row.unwrap().unwrap(), vec![Bytes(b"1".to_vec())]),
1891                        2 => assert_eq!(row.unwrap().unwrap(), vec![Bytes(b"2".to_vec())]),
1892                        3 => assert_eq!(row.unwrap().unwrap(), vec![Bytes(b"3".to_vec())]),
1893                        _ => unreachable!(),
1894                    }
1895                }
1896            }
1897            assert_eq!(i, 3);
1898        }
1899
1900        #[test]
1901        fn issue_273() {
1902            let opts = OptsBuilder::from_opts(get_opts()).prefer_socket(false);
1903            let mut conn = Conn::new(opts).unwrap();
1904
1905            "DROP FUNCTION IF EXISTS f1".run(&mut conn).unwrap();
1906            r"CREATE DEFINER=`root`@`localhost` FUNCTION `f1`(p_arg INT, p_arg2 INT) RETURNS int
1907            DETERMINISTIC
1908            BEGIN
1909                RETURN p_arg + p_arg2;
1910            END"
1911            .run(&mut conn)
1912            .unwrap();
1913
1914            "SELECT f1(?, ?)"
1915                .with((100u8, 100u8))
1916                .run(&mut conn)
1917                .unwrap();
1918        }
1919
1920        #[cfg(not(target_os = "wasi"))]
1921        #[test]
1922        fn issue_285() {
1923            let (tx, rx) = sync_channel::<()>(0);
1924
1925            let handle = std::thread::spawn(move || {
1926                let mut conn = Conn::new(get_opts()).unwrap();
1927                const INVALID_SQL: &str = r#"
1928                CREATE TEMPORARY TABLE IF NOT EXISTS `user_details` (
1929                    `user_id` int(11) NOT NULL AUTO_INCREMENT,
1930                    `username` varchar(255) DEFAULT NULL,
1931                    `first_name` varchar(50) DEFAULT NULL,
1932                    `last_name` varchar(50) DEFAULT NULL,
1933                    PRIMARY KEY (`user_id`)
1934                );
1935
1936                INSERT INTO `user_details` (`user_id`, `username`, `first_name`, `last_name`)
1937                VALUES (1, 'rogers63', 'david')
1938                "#;
1939
1940                conn.query_iter(INVALID_SQL).unwrap();
1941                tx.send(()).unwrap();
1942            });
1943
1944            match rx.recv_timeout(Duration::from_secs(100_000)) {
1945                Ok(_) => handle.join().unwrap(),
1946                Err(_) => panic!("test failed"),
1947            }
1948        }
1949
1950        #[test]
1951        fn should_work_with_named_params() {
1952            let mut conn = Conn::new(get_opts()).unwrap();
1953            {
1954                let stmt = conn.prep("SELECT :a, :b, :a, :c").unwrap();
1955                let result = conn
1956                    .exec_first(&stmt, params! {"a" => 1, "b" => 2, "c" => 3})
1957                    .unwrap()
1958                    .unwrap();
1959                assert_eq!((1_u8, 2_u8, 1_u8, 3_u8), result);
1960            }
1961
1962            let result = conn
1963                .exec_first(
1964                    "SELECT :a, :b, :a + :b, :c",
1965                    params! {
1966                        "a" => 1,
1967                        "b" => 2,
1968                        "c" => 3,
1969                    },
1970                )
1971                .unwrap()
1972                .unwrap();
1973            assert_eq!((1_u8, 2_u8, 3_u8, 3_u8), result);
1974        }
1975
1976        #[test]
1977        fn should_return_error_on_missing_named_parameter() {
1978            let mut conn = Conn::new(get_opts()).unwrap();
1979            let stmt = conn.prep("SELECT :a, :b, :a, :c, :d").unwrap();
1980            let result =
1981                conn.exec_first::<crate::Row, _, _>(&stmt, params! {"a" => 1, "b" => 2, "c" => 3,});
1982            match result {
1983                Err(DriverError(MissingNamedParameter(ref x))) if x == "d" => (),
1984                _ => assert!(false),
1985            }
1986        }
1987
1988        #[test]
1989        fn should_return_error_on_named_params_for_positional_statement() {
1990            let mut conn = Conn::new(get_opts()).unwrap();
1991            let stmt = conn.prep("SELECT ?, ?, ?, ?, ?").unwrap();
1992            let result = conn.exec_drop(&stmt, params! {"a" => 1, "b" => 2, "c" => 3,});
1993            match result {
1994                Err(DriverError(NamedParamsForPositionalQuery)) => (),
1995                _ => assert!(false),
1996            }
1997        }
1998
1999        #[cfg(not(target_os = "wasi"))]
2000        #[test]
2001        fn should_handle_tcp_connect_timeout() {
2002            use crate::error::{DriverError::ConnectTimeout, Error::DriverError};
2003
2004            let opts = OptsBuilder::from_opts(get_opts())
2005                .prefer_socket(false)
2006                .tcp_connect_timeout(Some(::std::time::Duration::from_millis(1000)));
2007            assert!(Conn::new(opts).unwrap().ping());
2008
2009            let opts = OptsBuilder::from_opts(get_opts())
2010                .prefer_socket(false)
2011                .tcp_connect_timeout(Some(::std::time::Duration::from_millis(1000)))
2012                .ip_or_hostname(Some("192.168.255.255"));
2013            match Conn::new(opts).unwrap_err() {
2014                DriverError(ConnectTimeout) => {}
2015                err => panic!("Unexpected error: {}", err),
2016            }
2017        }
2018
2019        #[test]
2020        fn should_set_additional_capabilities() {
2021            use crate::consts::CapabilityFlags;
2022
2023            let opts = OptsBuilder::from_opts(get_opts())
2024                .additional_capabilities(CapabilityFlags::CLIENT_FOUND_ROWS);
2025
2026            let mut conn = Conn::new(opts).unwrap();
2027            conn.query_drop("CREATE TEMPORARY TABLE mysql.tbl (a INT, b TEXT)")
2028                .unwrap();
2029            conn.query_drop("INSERT INTO mysql.tbl (a, b) VALUES (1, 'foo')")
2030                .unwrap();
2031            let result = conn
2032                .query_iter("UPDATE mysql.tbl SET b = 'foo' WHERE a = 1")
2033                .unwrap();
2034            assert_eq!(result.affected_rows(), 1);
2035        }
2036
2037        #[cfg(not(target_os = "wasi"))]
2038        #[test]
2039        fn should_bind_before_connect() {
2040            let port = 28000 + (rand::random::<u16>() % 2000);
2041            let opts = OptsBuilder::from_opts(get_opts())
2042                .prefer_socket(false)
2043                .ip_or_hostname(Some("localhost"))
2044                .bind_address(Some(([127, 0, 0, 1], port)));
2045            let conn = Conn::new(opts).unwrap();
2046            let debug_format: String = format!("{:?}", conn);
2047            let expected_1 = format!("addr: V4(127.0.0.1:{})", port);
2048            let expected_2 = format!("addr: 127.0.0.1:{}", port);
2049            assert!(
2050                debug_format.contains(&expected_1) || debug_format.contains(&expected_2),
2051                "debug_format: {}",
2052                debug_format
2053            );
2054        }
2055
2056        #[cfg(not(target_os = "wasi"))]
2057        #[test]
2058        fn should_bind_before_connect_with_timeout() {
2059            let port = 30000 + (rand::random::<u16>() % 2000);
2060            let opts = OptsBuilder::from_opts(get_opts())
2061                .prefer_socket(false)
2062                .ip_or_hostname(Some("localhost"))
2063                .bind_address(Some(([127, 0, 0, 1], port)))
2064                .tcp_connect_timeout(Some(::std::time::Duration::from_millis(1000)));
2065            let mut conn = Conn::new(opts).unwrap();
2066            assert!(conn.ping());
2067            let debug_format: String = format!("{:?}", conn);
2068            let expected_1 = format!("addr: V4(127.0.0.1:{})", port);
2069            let expected_2 = format!("addr: 127.0.0.1:{}", port);
2070            assert!(
2071                debug_format.contains(&expected_1) || debug_format.contains(&expected_2),
2072                "debug_format: {}",
2073                debug_format
2074            );
2075        }
2076
2077        #[test]
2078        fn should_not_cache_statements_if_stmt_cache_size_is_zero() {
2079            let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(0);
2080            let mut conn = Conn::new(opts).unwrap();
2081
2082            let stmt1 = conn.prep("DO 1").unwrap();
2083            let stmt2 = conn.prep("DO 2").unwrap();
2084            let stmt3 = conn.prep("DO 3").unwrap();
2085
2086            conn.close(stmt1).unwrap();
2087            conn.close(stmt2).unwrap();
2088            conn.close(stmt3).unwrap();
2089
2090            let status: (Value, u8) = conn
2091                .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
2092                .unwrap()
2093                .unwrap();
2094            assert_eq!(status.1, 3);
2095        }
2096
2097        #[test]
2098        fn should_hold_stmt_cache_size_bounds() {
2099            let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(3);
2100            let mut conn = Conn::new(opts).unwrap();
2101
2102            conn.prep("DO 1").unwrap();
2103            conn.prep("DO 2").unwrap();
2104            conn.prep("DO 3").unwrap();
2105            conn.prep("DO 1").unwrap();
2106            conn.prep("DO 4").unwrap();
2107            conn.prep("DO 3").unwrap();
2108            conn.prep("DO 5").unwrap();
2109            conn.prep("DO 6").unwrap();
2110
2111            let status: (String, usize) = conn
2112                .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close'")
2113                .unwrap()
2114                .unwrap();
2115
2116            assert_eq!(status.1, 3);
2117
2118            let mut order = conn
2119                .0
2120                .stmt_cache
2121                .iter()
2122                .map(|(_, entry)| &**entry.query.0.as_ref())
2123                .collect::<Vec<&[u8]>>();
2124            order.sort();
2125            assert_eq!(order, &[b"DO 3", b"DO 5", b"DO 6"]);
2126        }
2127
2128        #[test]
2129        fn should_handle_json_columns() {
2130            use crate::{Deserialized, Serialized};
2131            use serde_json::Value as Json;
2132            use std::str::FromStr;
2133
2134            #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
2135            pub struct DecTest {
2136                foo: String,
2137                quux: (u64, String),
2138            }
2139
2140            let decodable = DecTest {
2141                foo: "bar".into(),
2142                quux: (42, "hello".into()),
2143            };
2144
2145            let mut conn = Conn::new(get_opts()).unwrap();
2146            if conn
2147                .query_drop("CREATE TEMPORARY TABLE mysql.tbl(a VARCHAR(32), b JSON)")
2148                .is_err()
2149            {
2150                conn.query_drop("CREATE TEMPORARY TABLE mysql.tbl(a VARCHAR(32), b TEXT)")
2151                    .unwrap();
2152            }
2153            conn.exec_drop(
2154                r#"INSERT INTO mysql.tbl VALUES ('hello', ?)"#,
2155                (Serialized(&decodable),),
2156            )
2157            .unwrap();
2158
2159            let (a, b): (String, Json) = conn
2160                .query_first("SELECT a, b FROM mysql.tbl")
2161                .unwrap()
2162                .unwrap();
2163            assert_eq!(
2164                (a, b),
2165                (
2166                    "hello".into(),
2167                    Json::from_str(r#"{"foo": "bar", "quux": [42, "hello"]}"#).unwrap()
2168                )
2169            );
2170
2171            let row = conn
2172                .exec_first("SELECT a, b FROM mysql.tbl WHERE a = ?", ("hello",))
2173                .unwrap()
2174                .unwrap();
2175            let (a, Deserialized(b)) = from_row(row);
2176            assert_eq!((a, b), (String::from("hello"), decodable));
2177        }
2178
2179        #[test]
2180        fn should_set_connect_attrs() {
2181            let opts = OptsBuilder::from_opts(get_opts());
2182            let mut conn = Conn::new(opts).unwrap();
2183
2184            let support_connect_attrs = match (conn.0.server_version, conn.0.mariadb_server_version)
2185            {
2186                (Some(ref version), _) if *version >= (5, 6, 0) => true,
2187                (_, Some(ref version)) if *version >= (10, 0, 0) => true,
2188                _ => false,
2189            };
2190
2191            if support_connect_attrs {
2192                // MySQL >= 5.6 or MariaDB >= 10.0
2193
2194                if get_system_variable::<String>(&mut conn, "performance_schema") != "ON" {
2195                    panic!("The system variable `performance_schema` is off. Restart the MySQL server with `--performance_schema=on` to pass the test.");
2196                }
2197                let attrs_size: i32 =
2198                    get_system_variable(&mut conn, "performance_schema_session_connect_attrs_size");
2199                if attrs_size >= 0 && attrs_size <= 128 {
2200                    panic!("The system variable `performance_schema_session_connect_attrs_size` is {}. Restart the MySQL server with `--performance_schema_session_connect_attrs_size=-1` to pass the test.", attrs_size);
2201                }
2202
2203                fn assert_connect_attrs(conn: &mut Conn, expected_values: &[(&str, &str)]) {
2204                    let mut actual_values = HashMap::new();
2205                    for row in conn.query_iter("SELECT attr_name, attr_value FROM performance_schema.session_account_connect_attrs WHERE processlist_id = connection_id()").unwrap() {
2206                        let (name, value) = from_row::<(String, String)>(row.unwrap());
2207                        actual_values.insert(name, value);
2208                    }
2209
2210                    for (name, value) in expected_values {
2211                        assert_eq!(
2212                            actual_values.get(&name.to_string()),
2213                            Some(&value.to_string())
2214                        );
2215                    }
2216                }
2217                #[cfg(not(target_os = "wasi"))]
2218                let pid = process::id().to_string();
2219                #[cfg(target_os = "wasi")]
2220                let pid = "66666".to_string();
2221                let progname = std::env::args_os()
2222                    .next()
2223                    .unwrap()
2224                    .to_string_lossy()
2225                    .into_owned();
2226                let mut expected_values = vec![
2227                    ("_client_name", "rust-mysql-simple"),
2228                    ("_client_version", env!("CARGO_PKG_VERSION")),
2229                    ("_os", env!("CARGO_CFG_TARGET_OS")),
2230                    ("_pid", &pid),
2231                    ("_platform", env!("CARGO_CFG_TARGET_ARCH")),
2232                    ("program_name", &progname),
2233                ];
2234
2235                // No connect attributes are added.
2236                assert_connect_attrs(&mut conn, &expected_values);
2237
2238                // Connect attributes are added.
2239                let opts = OptsBuilder::from_opts(get_opts());
2240                let mut connect_attrs = HashMap::with_capacity(3);
2241                connect_attrs.insert("foo", "foo val");
2242                connect_attrs.insert("bar", "bar val");
2243                connect_attrs.insert("program_name", "my program name");
2244                let mut conn = Conn::new(opts.connect_attrs(connect_attrs)).unwrap();
2245                expected_values.pop(); // remove program_name at the last
2246                expected_values.push(("foo", "foo val"));
2247                expected_values.push(("bar", "bar val"));
2248                expected_values.push(("program_name", "my program name"));
2249                assert_connect_attrs(&mut conn, &expected_values);
2250            }
2251        }
2252
        /// Reads the binary log through `get_binlog_stream` in three modes:
        /// plain `COM_BINLOG_DUMP`, `COM_BINLOG_DUMP_GTID` (skipped on MariaDB),
        /// and `COM_BINLOG_DUMP` with `BINLOG_DUMP_NON_BLOCK`.
        ///
        /// Each mode starts from the file/position taken from the first row of
        /// `SHOW BINARY LOGS`, after dummy `customers` table activity has been
        /// generated. It asserts that at least one event is read and that every
        /// event type is known. In the first two modes it also decodes the rows
        /// of every rows event using the table map events seen so far.
        ///
        /// Requires binary logging to be enabled (`SHOW BINARY LOGS` must return a
        /// row). Panics if `@@GLOBAL.GTID_MODE` is readable but not `ON...`.
        #[cfg(not(target_os = "wasi"))]
        #[test]
        fn should_read_binlog() -> crate::Result<()> {
            use std::{
                collections::HashMap, sync::mpsc::sync_channel, thread::spawn, time::Duration,
            };

            /// Creates `customers`, inserts 100 rows one by one, then drops the
            /// table, so that the binlog receives table map and rows events.
            fn gen_dummy_data() -> crate::Result<()> {
                let mut conn = Conn::new(get_opts())?;

                "CREATE TABLE IF NOT EXISTS customers (customer_id int not null)".run(&mut conn)?;

                for i in 0_u8..100 {
                    "INSERT INTO customers(customer_id) VALUES (?)"
                        .with((i,))
                        .run(&mut conn)?;
                }

                "DROP TABLE customers".run(&mut conn)?;

                Ok(())
            }

            /// Opens a connection and captures a binlog starting point (columns 0
            /// and 1 of the first `SHOW BINARY LOGS` row, presumably log name and
            /// file size). It then generates dummy data and returns the connection
            /// together with that starting point.
            fn get_conn() -> crate::Result<(Conn, Vec<u8>, u64)> {
                let mut conn = Conn::new(get_opts())?;

                // Only enforced when the variable can be read (e.g. not on servers
                // lacking it); a readable non-`ON` value aborts the test.
                if let Ok(Some(gtid_mode)) =
                    "SELECT @@GLOBAL.GTID_MODE".first::<String, _>(&mut conn)
                {
                    if !gtid_mode.starts_with("ON") {
                        panic!(
                            "GTID_MODE is disabled \
                                (enable using --gtid_mode=ON --enforce_gtid_consistency=ON)"
                        );
                    }
                }

                let row: crate::Row = "SHOW BINARY LOGS".first(&mut conn)?.unwrap();
                let filename = row.get(0).unwrap();
                let position = row.get(1).unwrap();

                gen_dummy_data().unwrap();
                Ok((conn, filename, position))
            }

            // iterate using COM_BINLOG_DUMP
            let (conn, filename, pos) = get_conn().unwrap();
            let is_mariadb = conn.0.mariadb_server_version.is_some();

            // NOTE(review): the argument to `BinlogRequest::new` is presumably the
            // replica server id; each stream in this test uses a distinct value.
            let binlog_stream = conn
                .get_binlog_stream(BinlogRequest::new(12).with_filename(filename).with_pos(pos))
                .unwrap();

            let mut events_num = 0;
            // Without the non-blocking flag the stream waits for new events, so
            // it is drained on a helper thread (never joined), and reading stops
            // once no event has arrived for one second.
            let (tx, rx) = sync_channel(0);
            spawn(move || {
                for event in binlog_stream {
                    tx.send(event).unwrap();
                }
            });
            // Table map events seen so far, keyed by table id; rows events need
            // the matching entry to decode their rows.
            let mut tmes = HashMap::new();
            while let Ok(event) = rx.recv_timeout(Duration::from_secs(1)) {
                let event = event.unwrap();
                events_num += 1;

                // assert that event type is known
                event.header().event_type().unwrap();

                // iterate over rows of an event
                match event.read_data()?.unwrap() {
                    EventData::TableMapEvent(tme) => {
                        tmes.insert(tme.table_id(), tme.into_owned());
                    }
                    EventData::RowsEvent(re) => {
                        for row in re.rows(&tmes[&re.table_id()]) {
                            row.unwrap();
                        }
                    }
                    _ => (),
                }
            }
            assert!(events_num > 0);

            // The GTID-based dump is only exercised on non-MariaDB servers.
            if !is_mariadb {
                // iterate using COM_BINLOG_DUMP_GTID
                let (conn, filename, pos) = get_conn().unwrap();

                let binlog_stream = conn
                    .get_binlog_stream(
                        BinlogRequest::new(13)
                            .with_use_gtid(true)
                            .with_filename(filename)
                            .with_pos(pos),
                    )
                    .unwrap();

                let mut events_num = 0;
                // Same helper-thread drain with a one-second idle cutoff as above.
                let (tx, rx) = sync_channel(0);
                spawn(move || {
                    for event in binlog_stream {
                        tx.send(event).unwrap();
                    }
                });
                let mut tmes = HashMap::new();
                while let Ok(event) = rx.recv_timeout(Duration::from_secs(1)) {
                    let event = event.unwrap();
                    events_num += 1;

                    // assert that event type is known
                    event.header().event_type().unwrap();

                    // iterate over rows of an event
                    match event.read_data()?.unwrap() {
                        EventData::TableMapEvent(tme) => {
                            tmes.insert(tme.table_id(), tme.into_owned());
                        }
                        EventData::RowsEvent(re) => {
                            for row in re.rows(&tmes[&re.table_id()]) {
                                row.unwrap();
                            }
                        }
                        _ => (),
                    }
                }
                assert!(events_num > 0);
            }

            // iterate using COM_BINLOG_DUMP with BINLOG_DUMP_NON_BLOCK flag
            let (conn, filename, pos) = get_conn().unwrap();

            let mut binlog_stream = conn
                .get_binlog_stream(
                    BinlogRequest::new(14)
                        .with_filename(filename)
                        .with_pos(pos)
                        .with_flags(crate::BinlogDumpFlags::BINLOG_DUMP_NON_BLOCK),
                )
                .unwrap();

            // With the non-blocking flag the stream is iterated directly on this
            // thread until it yields `None`.
            events_num = 0;
            while let Some(event) = binlog_stream.next() {
                let event = event.unwrap();
                events_num += 1;
                event.header().event_type().unwrap();
                event.read_data()?;
            }
            assert!(events_num > 0);

            Ok(())
        }
2403    }
2404
2405    #[cfg(feature = "nightly")]
2406    mod bench {
2407        use test;
2408
2409        use crate::{params, prelude::*, test_misc::get_opts, Conn, Value::NULL};
2410
2411        #[bench]
2412        fn simple_exec(bencher: &mut test::Bencher) {
2413            let mut conn = Conn::new(get_opts()).unwrap();
2414            bencher.iter(|| {
2415                let _ = conn.query_drop("DO 1");
2416            })
2417        }
2418
2419        #[bench]
2420        fn prepared_exec(bencher: &mut test::Bencher) {
2421            let mut conn = Conn::new(get_opts()).unwrap();
2422            let stmt = conn.prep("DO 1").unwrap();
2423            bencher.iter(|| {
2424                let _ = conn.exec_drop(&stmt, ()).unwrap();
2425            })
2426        }
2427
2428        #[bench]
2429        fn prepare_and_exec(bencher: &mut test::Bencher) {
2430            let mut conn = Conn::new(get_opts()).unwrap();
2431            bencher.iter(|| {
2432                let stmt = conn.prep("SELECT ?").unwrap();
2433                let _ = conn.exec_drop(&stmt, (0,)).unwrap();
2434            })
2435        }
2436
2437        #[bench]
2438        fn simple_query_row(bencher: &mut test::Bencher) {
2439            let mut conn = Conn::new(get_opts()).unwrap();
2440            bencher.iter(|| {
2441                let _ = conn.query_drop("SELECT 1").unwrap();
2442            })
2443        }
2444
2445        #[bench]
2446        fn simple_prepared_query_row(bencher: &mut test::Bencher) {
2447            let mut conn = Conn::new(get_opts()).unwrap();
2448            let stmt = conn.prep("SELECT 1").unwrap();
2449            bencher.iter(|| {
2450                let _ = conn.exec_drop(&stmt, ()).unwrap();
2451            })
2452        }
2453
2454        #[bench]
2455        fn simple_prepared_query_row_with_param(bencher: &mut test::Bencher) {
2456            let mut conn = Conn::new(get_opts()).unwrap();
2457            let stmt = conn.prep("SELECT ?").unwrap();
2458            bencher.iter(|| {
2459                let _ = conn.exec_drop(&stmt, (0,)).unwrap();
2460            })
2461        }
2462
2463        #[bench]
2464        fn simple_prepared_query_row_with_named_param(bencher: &mut test::Bencher) {
2465            let mut conn = Conn::new(get_opts()).unwrap();
2466            let stmt = conn.prep("SELECT :a").unwrap();
2467            bencher.iter(|| {
2468                let _ = conn.exec_drop(&stmt, params! {"a" => 0}).unwrap();
2469            })
2470        }
2471
2472        #[bench]
2473        fn simple_prepared_query_row_with_5_params(bencher: &mut test::Bencher) {
2474            let mut conn = Conn::new(get_opts()).unwrap();
2475            let stmt = conn.prep("SELECT ?, ?, ?, ?, ?").unwrap();
2476            let params = (42i8, b"123456".to_vec(), 1.618f64, NULL, 1i8);
2477            bencher.iter(|| {
2478                let _ = conn.exec_drop(&stmt, &params).unwrap();
2479            })
2480        }
2481
2482        #[bench]
2483        fn simple_prepared_query_row_with_5_named_params(bencher: &mut test::Bencher) {
2484            let mut conn = Conn::new(get_opts()).unwrap();
2485            let stmt = conn
2486                .prep("SELECT :one, :two, :three, :four, :five")
2487                .unwrap();
2488            bencher.iter(|| {
2489                let _ = conn.exec_drop(
2490                    &stmt,
2491                    params! {
2492                        "one" => 42i8,
2493                        "two" => b"123456",
2494                        "three" => 1.618f64,
2495                        "four" => NULL,
2496                        "five" => 1i8,
2497                    },
2498                );
2499            })
2500        }
2501
2502        #[bench]
2503        fn select_large_string(bencher: &mut test::Bencher) {
2504            let mut conn = Conn::new(get_opts()).unwrap();
2505            bencher.iter(|| {
2506                let _ = conn.query_drop("SELECT REPEAT('A', 10000)").unwrap();
2507            })
2508        }
2509
2510        #[bench]
2511        fn select_prepared_large_string(bencher: &mut test::Bencher) {
2512            let mut conn = Conn::new(get_opts()).unwrap();
2513            let stmt = conn.prep("SELECT REPEAT('A', 10000)").unwrap();
2514            bencher.iter(|| {
2515                let _ = conn.exec_drop(&stmt, ()).unwrap();
2516            })
2517        }
2518
2519        #[bench]
2520        fn many_small_rows(bencher: &mut test::Bencher) {
2521            let mut conn = Conn::new(get_opts()).unwrap();
2522            conn.query_drop("CREATE TEMPORARY TABLE mysql.x (id INT)")
2523                .unwrap();
2524            for _ in 0..512 {
2525                conn.query_drop("INSERT INTO mysql.x VALUES (256)").unwrap();
2526            }
2527            let stmt = conn.prep("SELECT * FROM mysql.x").unwrap();
2528            bencher.iter(|| {
2529                let _ = conn.exec_drop(&stmt, ()).unwrap();
2530            });
2531        }
2532    }
2533}