1use bytes::{Buf, BufMut};
10use mysql_common::{
11 constants::UTF8MB4_GENERAL_CI,
12 crypto,
13 io::{ParseBuf, ReadMysqlExt},
14 misc::raw::Either,
15 named_params::parse_named_params,
16 packets::{
17 binlog_request::BinlogRequest, AuthPlugin, AuthSwitchRequest, Column, ComStmtClose,
18 ComStmtExecuteRequestBuilder, ComStmtSendLongData, CommonOkPacket, ErrPacket,
19 HandshakePacket, HandshakeResponse, OkPacket, OkPacketDeserializer, OkPacketKind,
20 OldAuthSwitchRequest, OldEofPacket, ResultSetTerminator, SessionStateInfo,
21 },
22 proto::{codec::Compression, sync_framed::MySyncFramed, MySerialize},
23};
24
25use mysql_common::{
26 constants::{DEFAULT_MAX_ALLOWED_PACKET, UTF8_GENERAL_CI},
27 packets::SslRequest,
28};
29
30#[cfg(not(target_os = "wasi"))]
31use std::process;
32use std::{
33 borrow::{Borrow, Cow},
34 cmp,
35 collections::HashMap,
36 convert::TryFrom,
37 io::{self, Write as _},
38 mem,
39 ops::{Deref, DerefMut},
40 sync::Arc,
41};
42
43#[cfg(unix)]
44use std::os::unix::io::{AsRawFd, RawFd};
45
46use crate::{
47 buffer_pool::{get_buffer, Buffer},
48 conn::{
49 local_infile::LocalInfile,
50 pool::{Pool, PooledConn},
51 query_result::{Binary, Or, Text},
52 stmt::{InnerStmt, Statement},
53 stmt_cache::StmtCache,
54 transaction::{AccessMode, TxOpts},
55 },
56 consts::{CapabilityFlags, Command, StatusFlags, MAX_PAYLOAD_LEN},
57 from_value, from_value_opt,
58 io::Stream,
59 prelude::*,
60 DriverError::{
61 MismatchedStmtParams, NamedParamsForPositionalQuery, OldMysqlPasswordDisabled,
62 Protocol41NotSet, ReadOnlyTransNotSupported, SetupError, UnexpectedPacket,
63 UnknownAuthPlugin, UnsupportedProtocol,
64 },
65 Error::{self, DriverError, MySqlError},
66 LocalInfileHandler, Opts, OptsBuilder, Params, QueryResult, Result, Transaction,
67 Value::{self, Bytes, NULL},
68};
69
70use crate::DriverError::TlsNotSupported;
71use crate::SslOpts;
72
73use self::binlog_stream::BinlogStream;
74
75pub mod binlog_stream;
76pub mod local_infile;
77pub mod opts;
78pub mod pool;
79pub mod query;
80pub mod query_result;
81pub mod queryable;
82pub mod stmt;
83mod stmt_cache;
84pub mod transaction;
85
86#[derive(Debug)]
88pub enum ConnMut<'c, 't, 'tc> {
89 Mut(&'c mut Conn),
90 TxMut(&'t mut Transaction<'tc>),
91 Owned(Conn),
92 Pooled(PooledConn),
93}
94
95impl From<Conn> for ConnMut<'static, 'static, 'static> {
96 fn from(conn: Conn) -> Self {
97 ConnMut::Owned(conn)
98 }
99}
100
101impl From<PooledConn> for ConnMut<'static, 'static, 'static> {
102 fn from(conn: PooledConn) -> Self {
103 ConnMut::Pooled(conn)
104 }
105}
106
107impl<'a> From<&'a mut Conn> for ConnMut<'a, 'static, 'static> {
108 fn from(conn: &'a mut Conn) -> Self {
109 ConnMut::Mut(conn)
110 }
111}
112
113impl<'a> From<&'a mut PooledConn> for ConnMut<'a, 'static, 'static> {
114 fn from(conn: &'a mut PooledConn) -> Self {
115 ConnMut::Mut(conn.as_mut())
116 }
117}
118
119impl<'t, 'tc> From<&'t mut Transaction<'tc>> for ConnMut<'static, 't, 'tc> {
120 fn from(tx: &'t mut Transaction<'tc>) -> Self {
121 ConnMut::TxMut(tx)
122 }
123}
124
125impl TryFrom<&Pool> for ConnMut<'static, 'static, 'static> {
126 type Error = Error;
127
128 fn try_from(pool: &Pool) -> Result<Self> {
129 pool.get_conn().map(From::from)
130 }
131}
132
133impl Deref for ConnMut<'_, '_, '_> {
134 type Target = Conn;
135
136 fn deref(&self) -> &Conn {
137 match self {
138 ConnMut::Mut(conn) => &**conn,
139 ConnMut::TxMut(tx) => &*tx.conn,
140 ConnMut::Owned(conn) => &conn,
141 ConnMut::Pooled(conn) => conn.as_ref(),
142 }
143 }
144}
145
146impl DerefMut for ConnMut<'_, '_, '_> {
147 fn deref_mut(&mut self) -> &mut Conn {
148 match self {
149 ConnMut::Mut(ref mut conn) => &mut **conn,
150 ConnMut::TxMut(tx) => &mut *tx.conn,
151 ConnMut::Owned(ref mut conn) => conn,
152 ConnMut::Pooled(ref mut conn) => conn.as_mut(),
153 }
154 }
155}
156
157#[derive(Debug)]
159struct ConnInner {
160 opts: Opts,
161 stream: Option<MySyncFramed<Stream>>,
162 stmt_cache: StmtCache,
163
164 server_version: Option<(u16, u16, u16)>,
166 mariadb_server_version: Option<(u16, u16, u16)>,
167
168 ok_packet: Option<OkPacket<'static>>,
170 capability_flags: CapabilityFlags,
171 connection_id: u32,
172 status_flags: StatusFlags,
173 character_set: u8,
174 last_command: u8,
175 connected: bool,
176 has_results: bool,
177 local_infile_handler: Option<LocalInfileHandler>,
178}
179
180impl ConnInner {
181 fn empty(opts: Opts) -> Self {
182 ConnInner {
183 stmt_cache: StmtCache::new(opts.get_stmt_cache_size()),
184 opts,
185 stream: None,
186 capability_flags: CapabilityFlags::empty(),
187 status_flags: StatusFlags::empty(),
188 connection_id: 0u32,
189 character_set: 0u8,
190 ok_packet: None,
191 last_command: 0u8,
192 connected: false,
193 has_results: false,
194 server_version: None,
195 mariadb_server_version: None,
196 local_infile_handler: None,
197 }
198 }
199}
200
201#[derive(Debug)]
203pub struct Conn(Box<ConnInner>);
204
205impl Conn {
206 const fn has_capability(&self, flag: CapabilityFlags) -> bool {
208 self.0.capability_flags.contains(flag)
209 }
210
211 pub fn server_version(&self) -> (u16, u16, u16) {
213 self.0
214 .server_version
215 .or_else(|| self.0.mariadb_server_version)
216 .unwrap()
217 }
218
219 pub fn connection_id(&self) -> u32 {
221 self.0.connection_id
222 }
223
224 pub fn affected_rows(&self) -> u64 {
226 self.0
227 .ok_packet
228 .as_ref()
229 .map(OkPacket::affected_rows)
230 .unwrap_or_default()
231 }
232
233 pub fn last_insert_id(&self) -> u64 {
237 self.0
238 .ok_packet
239 .as_ref()
240 .and_then(OkPacket::last_insert_id)
241 .unwrap_or_default()
242 }
243
244 pub fn warnings(&self) -> u16 {
246 self.0
247 .ok_packet
248 .as_ref()
249 .map(OkPacket::warnings)
250 .unwrap_or_default()
251 }
252
253 pub fn info_ref(&self) -> &[u8] {
259 self.0
260 .ok_packet
261 .as_ref()
262 .and_then(OkPacket::info_ref)
263 .unwrap_or_default()
264 }
265
266 pub fn info_str(&self) -> Cow<str> {
272 self.0
273 .ok_packet
274 .as_ref()
275 .and_then(OkPacket::info_str)
276 .unwrap_or_default()
277 }
278
279 pub fn session_state_changes(&self) -> io::Result<Vec<SessionStateInfo<'_>>> {
280 self.0
281 .ok_packet
282 .as_ref()
283 .map(|ok| ok.session_state_info())
284 .transpose()
285 .map(Option::unwrap_or_default)
286 }
287
288 fn stream_ref(&self) -> &MySyncFramed<Stream> {
289 self.0.stream.as_ref().expect("incomplete connection")
290 }
291
292 fn stream_mut(&mut self) -> &mut MySyncFramed<Stream> {
293 self.0.stream.as_mut().expect("incomplete connection")
294 }
295
296 fn is_insecure(&self) -> bool {
297 self.stream_ref().get_ref().is_insecure()
298 }
299
300 fn is_socket(&self) -> bool {
301 self.stream_ref().get_ref().is_socket()
302 }
303
304 #[allow(unused_assignments)]
306 fn can_improved(&mut self) -> Result<Option<Opts>> {
307 if self.0.opts.get_prefer_socket() && self.0.opts.addr_is_loopback() {
308 let mut socket = None;
309 #[cfg(test)]
310 {
311 socket = self.0.opts.0.injected_socket.clone();
312 }
313 if socket.is_none() {
314 socket = self.get_system_var("socket")?.map(from_value::<String>);
315 }
316 if let Some(socket) = socket {
317 if self.0.opts.get_socket().is_none() {
318 let socket_opts = OptsBuilder::from_opts(self.0.opts.clone());
319 if !socket.is_empty() {
320 return Ok(Some(socket_opts.socket(Some(socket)).into()));
321 }
322 }
323 }
324 }
325 Ok(None)
326 }
327
328 pub fn new<T, E>(opts: T) -> Result<Conn>
330 where
331 Opts: TryFrom<T, Error = E>,
332 crate::Error: From<E>,
333 {
334 let opts = Opts::try_from(opts)?;
335 let mut conn = Conn(Box::new(ConnInner::empty(opts)));
336 conn.connect_stream()?;
337 conn.connect()?;
338 let mut conn = {
339 if let Some(new_opts) = conn.can_improved()? {
340 let mut improved_conn = Conn(Box::new(ConnInner::empty(new_opts)));
341 improved_conn
342 .connect_stream()
343 .and_then(|_| {
344 improved_conn.connect()?;
345 Ok(improved_conn)
346 })
347 .unwrap_or(conn)
348 } else {
349 conn
350 }
351 };
352 for cmd in conn.0.opts.get_init() {
353 conn.query_drop(cmd)?;
354 }
355 Ok(conn)
356 }
357
358 fn soft_reset(&mut self) -> Result<()> {
359 self.write_command(Command::COM_RESET_CONNECTION, &[])?;
360 let packet = self.read_packet()?;
361 self.handle_ok::<CommonOkPacket>(&packet)?;
362 self.0.last_command = 0;
363 self.0.stmt_cache.clear();
364 Ok(())
365 }
366
367 fn hard_reset(&mut self) -> Result<()> {
368 self.0.stmt_cache.clear();
369 self.0.capability_flags = CapabilityFlags::empty();
370 self.0.status_flags = StatusFlags::empty();
371 self.0.connection_id = 0;
372 self.0.character_set = 0;
373 self.0.ok_packet = None;
374 self.0.last_command = 0;
375 self.0.connected = false;
376 self.0.has_results = false;
377 self.connect_stream()?;
378 self.connect()
379 }
380
381 pub fn reset(&mut self) -> Result<()> {
383 match (self.0.server_version, self.0.mariadb_server_version) {
384 (Some(ref version), _) if *version > (5, 7, 3) => {
385 self.soft_reset().or_else(|_| self.hard_reset())
386 }
387 (_, Some(ref version)) if *version >= (10, 2, 7) => {
388 self.soft_reset().or_else(|_| self.hard_reset())
389 }
390 _ => self.hard_reset(),
391 }
392 }
393
394 fn switch_to_ssl(&mut self, ssl_opts: SslOpts) -> Result<()> {
395 let stream = self.0.stream.take().expect("incomplete conn");
396 let (in_buf, out_buf, codec, stream) = stream.destruct();
397 let stream = stream.make_secure(self.0.opts.get_host(), ssl_opts)?;
398 let stream = MySyncFramed::construct(in_buf, out_buf, codec, stream);
399 self.0.stream = Some(stream);
400 Ok(())
401 }
402
403 fn connect_stream(&mut self) -> Result<()> {
404 let opts = &self.0.opts;
405 let read_timeout = opts.get_read_timeout().cloned();
406 let write_timeout = opts.get_write_timeout().cloned();
407 let tcp_keepalive_time = opts.get_tcp_keepalive_time_ms();
408 #[cfg(any(target_os = "linux", target_os = "macos",))]
409 let tcp_keepalive_probe_interval_secs = opts.get_tcp_keepalive_probe_interval_secs();
410 #[cfg(any(target_os = "linux", target_os = "macos",))]
411 let tcp_keepalive_probe_count = opts.get_tcp_keepalive_probe_count();
412 #[cfg(target_os = "linux")]
413 let tcp_user_timeout = opts.get_tcp_user_timeout_ms();
414 let tcp_nodelay = opts.get_tcp_nodelay();
415 let tcp_connect_timeout = opts.get_tcp_connect_timeout();
416 let bind_address = opts.bind_address().cloned();
417 #[cfg(not(target_os = "wasi"))]
418 {
419 let stream = if let Some(socket) = opts.get_socket() {
420 Stream::connect_socket(&*socket, read_timeout, write_timeout)?
421 } else {
422 let port = opts.get_tcp_port();
423 let ip_or_hostname = match opts.get_host() {
424 url::Host::Domain(domain) => domain,
425 url::Host::Ipv4(ip) => ip.to_string(),
426 url::Host::Ipv6(ip) => ip.to_string(),
427 };
428 Stream::connect_tcp(
429 &*ip_or_hostname,
430 port,
431 read_timeout,
432 write_timeout,
433 tcp_keepalive_time,
434 #[cfg(any(target_os = "linux", target_os = "macos",))]
435 tcp_keepalive_probe_interval_secs,
436 #[cfg(any(target_os = "linux", target_os = "macos",))]
437 tcp_keepalive_probe_count,
438 #[cfg(target_os = "linux")]
439 tcp_user_timeout,
440 tcp_nodelay,
441 tcp_connect_timeout,
442 bind_address,
443 )?
444 };
445 self.0.stream = Some(MySyncFramed::new(stream));
446 }
447 #[cfg(target_os = "wasi")]
448 {
449 let port = opts.get_tcp_port();
450 let ip_or_hostname = match opts.get_host() {
451 url::Host::Domain(domain) => domain,
452 url::Host::Ipv4(ip) => ip.to_string(),
453 url::Host::Ipv6(ip) => ip.to_string(),
454 };
455 let stream = Stream::connect_tcp(
456 &*ip_or_hostname,
457 port,
458 read_timeout,
459 write_timeout,
460 tcp_keepalive_time,
461 tcp_nodelay,
462 tcp_connect_timeout,
463 bind_address,
464 )?;
465 self.0.stream = Some(MySyncFramed::new(stream));
466 }
467 Ok(())
468 }
469
470 fn raw_read_packet(&mut self, buffer: &mut Vec<u8>) -> Result<()> {
471 if !self.stream_mut().next_packet(buffer)? {
472 Err(Error::server_disconnected())
473 } else {
474 Ok(())
475 }
476 }
477
478 fn read_packet(&mut self) -> Result<Buffer> {
479 loop {
480 let mut buffer = get_buffer();
481 match self.raw_read_packet(buffer.as_mut()) {
482 Ok(()) if buffer.first() == Some(&0xff) => {
483 match ParseBuf(&*buffer).parse(self.0.capability_flags)? {
484 ErrPacket::Error(server_error) => {
485 self.handle_err();
486 return Err(MySqlError(From::from(server_error)));
487 }
488 ErrPacket::Progress(_progress_report) => {
489 continue;
491 }
492 }
493 }
494 Ok(()) => return Ok(buffer),
495 Err(e) => {
496 self.handle_err();
497 return Err(e);
498 }
499 }
500 }
501 }
502
503 fn drop_packet(&mut self) -> Result<()> {
504 self.read_packet().map(drop)
505 }
506
507 fn write_struct<T: MySerialize>(&mut self, s: &T) -> Result<()> {
508 let mut buf = get_buffer();
509 s.serialize(buf.as_mut());
510 self.write_packet(&mut &*buf)
511 }
512
513 fn write_packet<T: Buf>(&mut self, data: &mut T) -> Result<()> {
514 self.stream_mut().send(data)?;
515 Ok(())
516 }
517
518 fn handle_handshake(&mut self, hp: &HandshakePacket<'_>) {
519 self.0.capability_flags = hp.capabilities() & self.get_client_flags();
520 self.0.status_flags = hp.status_flags();
521 self.0.connection_id = hp.connection_id();
522 self.0.character_set = hp.default_collation();
523 self.0.server_version = hp.server_version_parsed();
524 self.0.mariadb_server_version = hp.maria_db_server_version_parsed();
525 }
526
527 fn handle_ok<'a, T: OkPacketKind>(
528 &mut self,
529 buffer: &'a Buffer,
530 ) -> crate::Result<OkPacket<'a>> {
531 let ok = ParseBuf(&**buffer)
532 .parse::<OkPacketDeserializer<T>>(self.0.capability_flags)?
533 .into_inner();
534 self.0.status_flags = ok.status_flags();
535 self.0.ok_packet = Some(ok.clone().into_owned());
536 Ok(ok)
537 }
538
539 fn handle_err(&mut self) {
540 self.0.status_flags = StatusFlags::empty();
541 self.0.has_results = false;
542 self.0.ok_packet = None;
543 }
544
545 fn more_results_exists(&self) -> bool {
546 self.0
547 .status_flags
548 .contains(StatusFlags::SERVER_MORE_RESULTS_EXISTS)
549 }
550
551 fn perform_auth_switch(&mut self, auth_switch_request: AuthSwitchRequest<'_>) -> Result<()> {
552 if matches!(
553 auth_switch_request.auth_plugin(),
554 AuthPlugin::MysqlOldPassword
555 ) {
556 if self.0.opts.get_secure_auth() {
557 return Err(DriverError(OldMysqlPasswordDisabled));
558 }
559 }
560
561 let nonce = auth_switch_request.plugin_data();
562 let plugin_data = auth_switch_request
563 .auth_plugin()
564 .gen_data(self.0.opts.get_pass(), nonce)
565 .map(Either::Left)
566 .unwrap_or_else(|| Either::Right([]));
567 self.write_struct(&plugin_data)?;
568 self.continue_auth(&auth_switch_request.auth_plugin(), nonce, true)
569 }
570
571 fn do_handshake(&mut self) -> Result<()> {
572 let payload = self.read_packet()?;
573 let handshake = ParseBuf(&*payload).parse::<HandshakePacket>(())?;
574
575 if handshake.protocol_version() != 10u8 {
576 return Err(DriverError(UnsupportedProtocol(
577 handshake.protocol_version(),
578 )));
579 }
580
581 if !handshake
582 .capabilities()
583 .contains(CapabilityFlags::CLIENT_PROTOCOL_41)
584 {
585 return Err(DriverError(Protocol41NotSet));
586 }
587
588 self.handle_handshake(&handshake);
589
590 if self.is_insecure() {
591 if let Some(ssl_opts) = self.0.opts.get_ssl_opts().cloned() {
592 if !self.has_capability(CapabilityFlags::CLIENT_SSL) {
593 return Err(DriverError(TlsNotSupported));
594 } else {
595 self.do_ssl_request()?;
596 self.switch_to_ssl(ssl_opts)?;
597 }
598 }
599 }
600
601 let nonce = {
603 let mut nonce = Vec::from(handshake.scramble_1_ref());
604 nonce.extend_from_slice(handshake.scramble_2_ref().unwrap_or(&[][..]));
605 nonce.resize(20, 0);
608 nonce
609 };
610
611 let auth_plugin = handshake
612 .auth_plugin()
613 .unwrap_or(AuthPlugin::MysqlNativePassword);
614 if let AuthPlugin::Other(ref name) = auth_plugin {
615 let plugin_name = String::from_utf8_lossy(name).into();
616 return Err(DriverError(UnknownAuthPlugin(plugin_name)));
617 }
618
619 let auth_data = auth_plugin.gen_data(self.0.opts.get_pass(), &*nonce);
620 self.write_handshake_response(&auth_plugin, auth_data.as_deref())?;
621 self.continue_auth(&auth_plugin, &*nonce, false)?;
622
623 if self.has_capability(CapabilityFlags::CLIENT_COMPRESS) {
624 self.switch_to_compressed();
625 }
626
627 Ok(())
628 }
629
630 fn switch_to_compressed(&mut self) {
631 self.stream_mut()
632 .codec_mut()
633 .compress(Compression::default());
634 }
635
636 fn get_client_flags(&self) -> CapabilityFlags {
637 let mut client_flags = CapabilityFlags::CLIENT_PROTOCOL_41
638 | CapabilityFlags::CLIENT_SECURE_CONNECTION
639 | CapabilityFlags::CLIENT_LONG_PASSWORD
640 | CapabilityFlags::CLIENT_TRANSACTIONS
641 | CapabilityFlags::CLIENT_LOCAL_FILES
642 | CapabilityFlags::CLIENT_MULTI_STATEMENTS
643 | CapabilityFlags::CLIENT_MULTI_RESULTS
644 | CapabilityFlags::CLIENT_PS_MULTI_RESULTS
645 | CapabilityFlags::CLIENT_PLUGIN_AUTH
646 | CapabilityFlags::CLIENT_CONNECT_ATTRS
647 | (self.0.capability_flags & CapabilityFlags::CLIENT_LONG_FLAG);
648 if self.0.opts.get_compress().is_some() {
649 client_flags.insert(CapabilityFlags::CLIENT_COMPRESS);
650 }
651 if let Some(db_name) = self.0.opts.get_db_name() {
652 if !db_name.is_empty() {
653 client_flags.insert(CapabilityFlags::CLIENT_CONNECT_WITH_DB);
654 }
655 }
656 if self.is_insecure() && self.0.opts.get_ssl_opts().is_some() {
657 client_flags.insert(CapabilityFlags::CLIENT_SSL);
658 }
659 client_flags | self.0.opts.get_additional_capabilities()
660 }
661
662 fn connect_attrs(&self) -> HashMap<String, String> {
663 let program_name = match self.0.opts.get_connect_attrs().get("program_name") {
664 Some(program_name) => program_name.clone(),
665 None => {
666 let arg0 = std::env::args_os().next();
667 let arg0 = arg0.as_ref().map(|x| x.to_string_lossy());
668 arg0.unwrap_or_else(|| "".into()).to_owned().to_string()
669 }
670 };
671
672 let mut attrs = HashMap::new();
673
674 attrs.insert("_client_name".into(), "rust-mysql-simple".into());
675 attrs.insert("_client_version".into(), env!("CARGO_PKG_VERSION").into());
676 attrs.insert("_os".into(), env!("CARGO_CFG_TARGET_OS").into());
677 #[cfg(not(target_os = "wasi"))]
678 attrs.insert("_pid".into(), process::id().to_string());
679 #[cfg(target_os = "wasi")]
680 attrs.insert("_pid".into(), "66666".into());
681 attrs.insert("_platform".into(), env!("CARGO_CFG_TARGET_ARCH").into());
682 attrs.insert("program_name".into(), program_name);
683
684 for (name, value) in self.0.opts.get_connect_attrs().clone() {
685 attrs.insert(name, value);
686 }
687
688 attrs
689 }
690
691 fn do_ssl_request(&mut self) -> Result<()> {
692 let charset = if self.server_version() >= (5, 5, 3) {
693 UTF8MB4_GENERAL_CI
694 } else {
695 UTF8_GENERAL_CI
696 };
697
698 let ssl_request = SslRequest::new(
699 self.get_client_flags(),
700 DEFAULT_MAX_ALLOWED_PACKET as u32,
701 charset as u8,
702 );
703 self.write_struct(&ssl_request)
704 }
705
706 fn write_handshake_response(
707 &mut self,
708 auth_plugin: &AuthPlugin<'_>,
709 scramble_buf: Option<&[u8]>,
710 ) -> Result<()> {
711 let handshake_response = HandshakeResponse::new(
712 scramble_buf,
713 self.0.server_version.unwrap_or((0, 0, 0)),
714 self.0.opts.get_user().map(str::as_bytes),
715 self.0.opts.get_db_name().map(str::as_bytes),
716 Some(auth_plugin.clone()),
717 self.0.capability_flags,
718 Some(self.connect_attrs().clone()),
719 );
720
721 let mut buf = get_buffer();
722 handshake_response.serialize(buf.as_mut());
723 self.write_packet(&mut &*buf)
724 }
725
726 fn continue_auth(
727 &mut self,
728 auth_plugin: &AuthPlugin<'_>,
729 nonce: &[u8],
730 auth_switched: bool,
731 ) -> Result<()> {
732 match auth_plugin {
733 AuthPlugin::CachingSha2Password => {
734 self.continue_caching_sha2_password_auth(nonce, auth_switched)?;
735 Ok(())
736 }
737 AuthPlugin::MysqlNativePassword | AuthPlugin::MysqlOldPassword => {
738 self.continue_mysql_native_password_auth(nonce, auth_switched)?;
739 Ok(())
740 }
741 AuthPlugin::Other(ref name) => {
742 let plugin_name = String::from_utf8_lossy(name).into();
743 Err(DriverError(UnknownAuthPlugin(plugin_name)))
744 }
745 }
746 }
747
748 fn continue_mysql_native_password_auth(
749 &mut self,
750 nonce: &[u8],
751 auth_switched: bool,
752 ) -> Result<()> {
753 let payload = self.read_packet()?;
754
755 match payload[0] {
756 0x00 => self.handle_ok::<CommonOkPacket>(&payload).map(drop),
758 0xfe if !auth_switched => {
760 let auth_switch = if payload.len() > 1 {
761 ParseBuf(&*payload).parse(())?
762 } else {
763 let _ = ParseBuf(&*payload).parse::<OldAuthSwitchRequest>(())?;
764 AuthSwitchRequest::new("mysql_old_password".as_bytes(), nonce)
766 };
767 self.perform_auth_switch(auth_switch)
768 }
769 _ => Err(DriverError(UnexpectedPacket)),
770 }
771 }
772
773 fn continue_caching_sha2_password_auth(
774 &mut self,
775 nonce: &[u8],
776 auth_switched: bool,
777 ) -> Result<()> {
778 let payload = self.read_packet()?;
779
780 match payload[0] {
781 0x00 => {
782 Ok(())
784 }
785 0x01 => match payload[1] {
786 0x03 => {
787 let payload = self.read_packet()?;
788 self.handle_ok::<CommonOkPacket>(&payload).map(drop)
789 }
790 0x04 => {
791 if !self.is_insecure() || self.is_socket() {
792 let mut pass = self
793 .0
794 .opts
795 .get_pass()
796 .map(Vec::from)
797 .unwrap_or_else(Vec::new);
798 pass.push(0);
799 self.write_packet(&mut pass.as_slice())?;
800 } else {
801 self.write_packet(&mut &[0x02][..])?;
802 let payload = self.read_packet()?;
803 let key = &payload[1..];
804 let mut pass = self
805 .0
806 .opts
807 .get_pass()
808 .map(Vec::from)
809 .unwrap_or_else(Vec::new);
810 pass.push(0);
811 for i in 0..pass.len() {
812 pass[i] ^= nonce[i % nonce.len()];
813 }
814 let encrypted_pass = crypto::encrypt(&*pass, key);
815 self.write_packet(&mut encrypted_pass.as_slice())?;
816 }
817
818 let payload = self.read_packet()?;
819 self.handle_ok::<CommonOkPacket>(&payload).map(drop)
820 }
821 _ => Err(DriverError(UnexpectedPacket)),
822 },
823 0xfe if !auth_switched => {
824 let auth_switch_request = ParseBuf(&*payload).parse(())?;
825 self.perform_auth_switch(auth_switch_request)
826 }
827 _ => Err(DriverError(UnexpectedPacket)),
828 }
829 }
830
831 fn reset_seq_id(&mut self) {
832 self.stream_mut().codec_mut().reset_seq_id();
833 }
834
835 fn sync_seq_id(&mut self) {
836 self.stream_mut().codec_mut().sync_seq_id();
837 }
838
839 fn write_command_raw<T: MySerialize>(&mut self, cmd: &T) -> Result<()> {
840 let mut buf = get_buffer();
841 cmd.serialize(buf.as_mut());
842 self.reset_seq_id();
843 debug_assert!(buf.len() > 0);
844 self.0.last_command = buf[0];
845 self.write_packet(&mut &*buf)
846 }
847
848 fn write_command(&mut self, cmd: Command, data: &[u8]) -> Result<()> {
849 let mut buf = get_buffer();
850 buf.as_mut().put_u8(cmd as u8);
851 buf.as_mut().extend_from_slice(data);
852
853 self.reset_seq_id();
854 self.0.last_command = buf[0];
855 self.write_packet(&mut &*buf)
856 }
857
858 fn send_long_data(&mut self, stmt_id: u32, params: &[Value]) -> Result<()> {
859 for (i, value) in params.iter().enumerate() {
860 if let Bytes(bytes) = value {
861 let chunks = bytes.chunks(MAX_PAYLOAD_LEN - 6);
862 let chunks = chunks.chain(if bytes.is_empty() {
863 Some(&[][..])
864 } else {
865 None
866 });
867 for chunk in chunks {
868 let cmd = ComStmtSendLongData::new(stmt_id, i as u16, Cow::Borrowed(chunk));
869 self.write_command_raw(&cmd)?;
870 }
871 }
872 }
873
874 Ok(())
875 }
876
877 fn _execute(
878 &mut self,
879 stmt: &Statement,
880 params: Params,
881 ) -> Result<Or<Vec<Column>, OkPacket<'static>>> {
882 let exec_request = match ¶ms {
883 Params::Empty => {
884 if stmt.num_params() != 0 {
885 return Err(DriverError(MismatchedStmtParams(stmt.num_params(), 0)));
886 }
887
888 let (body, _) = ComStmtExecuteRequestBuilder::new(stmt.id()).build(&[]);
889 body
890 }
891 Params::Positional(params) => {
892 if stmt.num_params() != params.len() as u16 {
893 return Err(DriverError(MismatchedStmtParams(
894 stmt.num_params(),
895 params.len(),
896 )));
897 }
898
899 let (body, as_long_data) =
900 ComStmtExecuteRequestBuilder::new(stmt.id()).build(&*params);
901
902 if as_long_data {
903 self.send_long_data(stmt.id(), &*params)?;
904 }
905
906 body
907 }
908 Params::Named(_) => {
909 if let Some(named_params) = stmt.named_params.as_ref() {
910 return self._execute(stmt, params.into_positional(named_params)?);
911 } else {
912 return Err(DriverError(NamedParamsForPositionalQuery));
913 }
914 }
915 };
916 self.write_command_raw(&exec_request)?;
917 self.handle_result_set()
918 }
919
920 fn _start_transaction(&mut self, tx_opts: TxOpts) -> Result<()> {
921 if let Some(i_level) = tx_opts.isolation_level() {
922 self.query_drop(format!("SET TRANSACTION ISOLATION LEVEL {}", i_level))?;
923 }
924 if let Some(mode) = tx_opts.access_mode() {
925 let supported = match (self.0.server_version, self.0.mariadb_server_version) {
926 (Some(ref version), _) if *version >= (5, 6, 5) => true,
927 (_, Some(ref version)) if *version >= (10, 0, 0) => true,
928 _ => false,
929 };
930 if !supported {
931 return Err(DriverError(ReadOnlyTransNotSupported));
932 }
933 match mode {
934 AccessMode::ReadOnly => self.query_drop("SET TRANSACTION READ ONLY")?,
935 AccessMode::ReadWrite => self.query_drop("SET TRANSACTION READ WRITE")?,
936 }
937 }
938 if tx_opts.with_consistent_snapshot() {
939 self.query_drop("START TRANSACTION WITH CONSISTENT SNAPSHOT")
940 .unwrap();
941 } else {
942 self.query_drop("START TRANSACTION")?;
943 };
944 Ok(())
945 }
946
947 fn send_local_infile(&mut self, file_name: &[u8]) -> Result<OkPacket<'static>> {
948 {
949 let buffer_size = cmp::min(
950 MAX_PAYLOAD_LEN - 4,
951 self.stream_ref().codec().max_allowed_packet - 4,
952 );
953 let chunk = vec![0u8; buffer_size].into_boxed_slice();
954 let maybe_handler = self
955 .0
956 .local_infile_handler
957 .clone()
958 .or_else(|| self.0.opts.get_local_infile_handler().cloned());
959 let mut local_infile = LocalInfile::new(io::Cursor::new(chunk), self);
960 if let Some(handler) = maybe_handler {
961 let handler_fn = &mut *handler.0.lock()?;
965 handler_fn(file_name, &mut local_infile)?;
966 }
967 local_infile.flush()?;
968 }
969 self.write_packet(&mut &[][..])?;
970 let payload = self.read_packet()?;
971 let ok = self.handle_ok::<CommonOkPacket>(&payload)?;
972 Ok(ok.into_owned())
973 }
974
975 fn handle_result_set(&mut self) -> Result<Or<Vec<Column>, OkPacket<'static>>> {
976 if self.more_results_exists() {
977 self.sync_seq_id();
978 }
979
980 let pld = self.read_packet()?;
981 match pld[0] {
982 0x00 => {
983 let ok = self.handle_ok::<CommonOkPacket>(&pld)?;
984 Ok(Or::B(ok.into_owned()))
985 }
986 0xfb => match self.send_local_infile(&pld[1..]) {
987 Ok(ok) => Ok(Or::B(ok)),
988 Err(err) => Err(err),
989 },
990 _ => {
991 let mut reader = &pld[..];
992 let column_count = reader.read_lenenc_int()?;
993 let mut columns: Vec<Column> = Vec::with_capacity(column_count as usize);
994 for _ in 0..column_count {
995 let pld = self.read_packet()?;
996 let column = ParseBuf(&*pld).parse(())?;
997 columns.push(column);
998 }
999 self.drop_packet()?;
1001 self.0.has_results = column_count > 0;
1002 Ok(Or::A(columns))
1003 }
1004 }
1005 }
1006
1007 fn _query(&mut self, query: &str) -> Result<Or<Vec<Column>, OkPacket<'static>>> {
1008 self.write_command(Command::COM_QUERY, query.as_bytes())?;
1009 self.handle_result_set()
1010 }
1011
1012 pub fn ping(&mut self) -> bool {
1015 match self.write_command(Command::COM_PING, &[]) {
1016 Ok(_) => self.drop_packet().is_ok(),
1017 _ => false,
1018 }
1019 }
1020
1021 pub fn select_db(&mut self, schema: &str) -> bool {
1024 match self.write_command(Command::COM_INIT_DB, schema.as_bytes()) {
1025 Ok(_) => self.drop_packet().is_ok(),
1026 _ => false,
1027 }
1028 }
1029
1030 pub fn start_transaction(&mut self, tx_opts: TxOpts) -> Result<Transaction> {
1033 self._start_transaction(tx_opts)?;
1034 Ok(Transaction::new(self.into()))
1035 }
1036
1037 fn _true_prepare(&mut self, query: &[u8]) -> Result<InnerStmt> {
1038 self.write_command(Command::COM_STMT_PREPARE, query)?;
1039 let pld = self.read_packet()?;
1040 let mut stmt = ParseBuf(&*pld).parse::<InnerStmt>(self.connection_id())?;
1041 if stmt.num_params() > 0 {
1042 let mut params: Vec<Column> = Vec::with_capacity(stmt.num_params() as usize);
1043 for _ in 0..stmt.num_params() {
1044 let pld = self.read_packet()?;
1045 params.push(ParseBuf(&*pld).parse(())?);
1046 }
1047 stmt = stmt.with_params(Some(params));
1048 self.drop_packet()?;
1049 }
1050 if stmt.num_columns() > 0 {
1051 let mut columns: Vec<Column> = Vec::with_capacity(stmt.num_columns() as usize);
1052 for _ in 0..stmt.num_columns() {
1053 let pld = self.read_packet()?;
1054 columns.push(ParseBuf(&*pld).parse(())?);
1055 }
1056 stmt = stmt.with_columns(Some(columns));
1057 self.drop_packet()?;
1058 }
1059 Ok(stmt)
1060 }
1061
1062 fn _prepare(&mut self, query: &[u8]) -> Result<Arc<InnerStmt>> {
1063 if let Some(entry) = self.0.stmt_cache.by_query(query) {
1064 return Ok(entry.stmt.clone());
1065 }
1066
1067 let inner_st = Arc::new(self._true_prepare(query)?);
1068
1069 if let Some(old_stmt) = self
1070 .0
1071 .stmt_cache
1072 .put(Arc::new(query.into()), inner_st.clone())
1073 {
1074 self.close(Statement::new(old_stmt, None))?;
1075 }
1076
1077 Ok(inner_st)
1078 }
1079
1080 fn connect(&mut self) -> Result<()> {
1081 if self.0.connected {
1082 return Ok(());
1083 }
1084 self.do_handshake()
1085 .and_then(|_| {
1086 Ok(from_value_opt::<usize>(
1087 self.get_system_var("max_allowed_packet")?.unwrap_or(NULL),
1088 )
1089 .unwrap_or(0))
1090 })
1091 .and_then(|max_allowed_packet| {
1092 if max_allowed_packet == 0 {
1093 Err(DriverError(SetupError))
1094 } else {
1095 self.stream_mut().codec_mut().max_allowed_packet = max_allowed_packet;
1096 self.0.connected = true;
1097 Ok(())
1098 }
1099 })
1100 }
1101
1102 fn get_system_var(&mut self, name: &str) -> Result<Option<Value>> {
1103 self.query_first(format!("SELECT @@{}", name))
1104 }
1105
1106 fn next_row_packet(&mut self) -> Result<Option<Buffer>> {
1107 if !self.0.has_results {
1108 return Ok(None);
1109 }
1110
1111 let pld = self.read_packet()?;
1112
1113 if self.has_capability(CapabilityFlags::CLIENT_DEPRECATE_EOF) {
1114 if pld[0] == 0xfe && pld.len() < MAX_PAYLOAD_LEN {
1115 self.0.has_results = false;
1116 self.handle_ok::<ResultSetTerminator>(&pld)?;
1117 return Ok(None);
1118 }
1119 } else {
1120 if pld[0] == 0xfe && pld.len() < 8 {
1121 self.0.has_results = false;
1122 self.handle_ok::<OldEofPacket>(&pld)?;
1123 return Ok(None);
1124 }
1125 }
1126
1127 Ok(Some(pld))
1128 }
1129
1130 fn has_stmt(&self, query: &[u8]) -> bool {
1131 self.0.stmt_cache.contains_query(query)
1132 }
1133
1134 pub fn set_local_infile_handler(&mut self, handler: Option<LocalInfileHandler>) {
1141 self.0.local_infile_handler = handler;
1142 }
1143
1144 pub fn no_backslash_escape(&self) -> bool {
1145 self.0
1146 .status_flags
1147 .contains(StatusFlags::SERVER_STATUS_NO_BACKSLASH_ESCAPES)
1148 }
1149
1150 fn register_as_slave(&mut self, server_id: u32) -> Result<()> {
1151 use mysql_common::packets::ComRegisterSlave;
1152
1153 self.query_drop("SET @master_binlog_checksum='ALL'")?;
1154 self.write_command_raw(&ComRegisterSlave::new(server_id))?;
1155
1156 self.read_packet()?;
1158
1159 Ok(())
1160 }
1161
1162 fn request_binlog(&mut self, request: BinlogRequest<'_>) -> Result<()> {
1163 self.register_as_slave(request.server_id())?;
1164 self.write_command_raw(&request.as_cmd())?;
1165 Ok(())
1166 }
1167
1168 pub fn get_binlog_stream(mut self, request: BinlogRequest<'_>) -> Result<BinlogStream> {
1174 self.request_binlog(request)?;
1175 Ok(BinlogStream::new(self))
1176 }
1177}
1178
1179#[cfg(unix)]
1180impl AsRawFd for Conn {
1181 fn as_raw_fd(&self) -> RawFd {
1182 self.stream_ref().get_ref().as_raw_fd()
1183 }
1184}
1185
1186impl Queryable for Conn {
1187 fn query_iter<T: AsRef<str>>(&mut self, query: T) -> Result<QueryResult<'_, '_, '_, Text>> {
1188 let meta = self._query(query.as_ref())?;
1189 Ok(QueryResult::new(ConnMut::Mut(self), meta))
1190 }
1191
1192 fn prep<T: AsRef<str>>(&mut self, query: T) -> Result<Statement> {
1193 let query = query.as_ref();
1194 let (named_params, real_query) = parse_named_params(query.as_bytes())?;
1195 self._prepare(real_query.borrow())
1196 .map(|inner| Statement::new(inner, named_params))
1197 }
1198
1199 fn close(&mut self, stmt: Statement) -> Result<()> {
1200 self.0.stmt_cache.remove(stmt.id());
1201 let cmd = ComStmtClose::new(stmt.id());
1202 self.write_command_raw(&cmd)
1203 }
1204
1205 fn exec_iter<S, P>(&mut self, stmt: S, params: P) -> Result<QueryResult<'_, '_, '_, Binary>>
1206 where
1207 S: AsStatement,
1208 P: Into<Params>,
1209 {
1210 let statement = stmt.as_statement(self)?;
1211 let meta = self._execute(&*statement, params.into())?;
1212 Ok(QueryResult::new(ConnMut::Mut(self), meta))
1213 }
1214}
1215
1216impl Drop for Conn {
1217 fn drop(&mut self) {
1218 let stmt_cache = mem::replace(&mut self.0.stmt_cache, StmtCache::new(0));
1219
1220 for (_, entry) in stmt_cache.into_iter() {
1221 let _ = self.close(Statement::new(entry.stmt, None));
1222 }
1223
1224 if self.0.stream.is_some() {
1225 let _ = self.write_command(Command::COM_QUIT, &[]);
1226 }
1227 }
1228}
1229
1230#[cfg(test)]
1231#[allow(non_snake_case)]
1232mod test {
1233 mod my_conn {
1234 use std::{
1235 collections::HashMap,
1236 io::Write,
1237 iter, process,
1238 sync::mpsc::{channel, sync_channel},
1239 thread::spawn,
1240 time::Duration,
1241 };
1242
1243 use mysql_common::{binlog::events::EventData, packets::binlog_request::BinlogRequest};
1244 use time::PrimitiveDateTime;
1245
1246 use crate::{
1247 from_row, from_value, params,
1248 prelude::*,
1249 test_misc::get_opts,
1250 Conn,
1251 DriverError::{MissingNamedParameter, NamedParamsForPositionalQuery},
1252 Error::DriverError,
1253 LocalInfileHandler, Opts, OptsBuilder, Pool, TxOpts,
1254 Value::{self, Bytes, Date, Float, Int, NULL},
1255 };
1256
1257 fn get_system_variable<T>(conn: &mut Conn, name: &str) -> T
1258 where
1259 T: FromValue,
1260 {
1261 conn.query_first::<(String, T), _>(format!("show variables like '{}'", name))
1262 .unwrap()
1263 .unwrap()
1264 .1
1265 }
1266
/// Basic connectivity: checks the server's global `sql_mode`, pings, and verifies
/// compression/TLS when the test environment requests them.
#[test]
fn should_connect() {
    let mut conn = Conn::new(get_opts()).unwrap();

    // The test server is expected to run with a TRADITIONAL sql_mode.
    let global_mode = conn
        .query_first::<String, _>("SELECT @@GLOBAL.sql_mode")
        .unwrap()
        .unwrap();
    assert!(global_mode.contains("TRADITIONAL"));
    assert!(conn.ping());

    if crate::test_misc::test_compression() {
        let stream_debug = format!("{:?}", conn.0.stream);
        assert!(stream_debug.contains("Compression"));
    }

    if crate::test_misc::test_ssl() {
        assert!(!conn.is_insecure());
    }
}
1286
/// Regression test named after mysql_async issue #107: rows mixing VARBINARY,
/// BINARY and BIGINT UNSIGNED columns must decode into byte vectors and `u64`.
#[test]
fn mysql_async_issue_107() -> crate::Result<()> {
    let mut conn = Conn::new(get_opts())?;

    let ddl = r"CREATE TEMPORARY TABLE mysql.issue (
            a BIGINT(20) UNSIGNED,
            b VARBINARY(16),
            c BINARY(32),
            d BIGINT(20) UNSIGNED,
            e BINARY(32)
        )";
    let dml = r"INSERT INTO mysql.issue VALUES (
            0,
            0xC066F966B0860000,
            0x7939DA98E524C5F969FC2DE8D905FD9501EBC6F20001B0A9C941E0BE6D50CF44,
            0,
            ''
        ), (
            1,
            '',
            0x076311DF4D407B0854371BA13A5F3FB1A4555AC22B361375FD47B263F31822F2,
            0,
            ''
        )";
    conn.query_drop(ddl)?;
    conn.query_drop(dml)?;

    let mut decoded: Vec<(Vec<u8>, Vec<u8>, u64, Vec<u8>)> = Vec::new();
    for row in conn.query_iter("SELECT b, c, d, e FROM mysql.issue")? {
        // `from_row` panics if a column cannot be converted to the requested type.
        decoded.push(crate::from_row(row.unwrap()));
    }

    assert_eq!(decoded.len(), 2);
    Ok(())
}
1326
/// Exercises the query helpers called on `&str` (`run`, `first`, `fetch`, `map`,
/// `fold`, and — after `.with(params)` — their parameterized counterparts plus
/// `batch`) against four `Queryable` targets: a transaction on a `Conn`, a plain
/// `Conn`, a transaction obtained from a `Pool`, and a `PooledConn`.
#[test]
fn query_traits() -> Result<(), Box<dyn std::error::Error>> {
    // Runs the whole scenario against `$conn`. The assertion labels distinguish the
    // plain-string form ("text") from the `.with(..)` form ("bin").
    macro_rules! test_query {
        ($conn : expr) => {
            // NOTE(review): `tmplak` is a regular (non-temporary) table, so a failed
            // assertion below leaves it behind and the next run's CREATE fails.
            "CREATE TABLE tmplak (a INT)".run($conn)?;

            "INSERT INTO tmplak (a) VALUES (?)".with((42,)).run($conn)?;

            // `batch` executes the statement once per parameter tuple (43 and 44).
            "INSERT INTO tmplak (a) VALUES (?)"
                .with((43..=44).map(|x| (x,)))
                .batch($conn)?;

            let first: Option<u8> = "SELECT a FROM tmplak LIMIT 1".first($conn)?;
            assert_eq!(first, Some(42), "first text");

            let first: Option<u8> = "SELECT a FROM tmplak LIMIT 1".with(()).first($conn)?;
            assert_eq!(first, Some(42), "first bin");

            let count = "SELECT a FROM tmplak".run($conn)?.count();
            assert_eq!(count, 3, "run text");

            let count = "SELECT a FROM tmplak".with(()).run($conn)?.count();
            assert_eq!(count, 3, "run bin");

            let all: Vec<u8> = "SELECT a FROM tmplak".fetch($conn)?;
            assert_eq!(all, vec![42, 43, 44], "fetch text");

            let all: Vec<u8> = "SELECT a FROM tmplak".with(()).fetch($conn)?;
            assert_eq!(all, vec![42, 43, 44], "fetch bin");

            let mapped = "SELECT a FROM tmplak".map($conn, |x: u8| x + 1)?;
            assert_eq!(mapped, vec![43, 44, 45], "map text");

            let mapped = "SELECT a FROM tmplak".with(()).map($conn, |x: u8| x + 1)?;
            assert_eq!(mapped, vec![43, 44, 45], "map bin");

            // 42 + 43 + 44 = 129 still fits in the `u8` accumulator.
            let sum = "SELECT a FROM tmplak".fold($conn, 0_u8, |acc, x: u8| acc + x)?;
            assert_eq!(sum, 42 + 43 + 44, "fold text");

            let sum = "SELECT a FROM tmplak"
                .with(())
                .fold($conn, 0_u8, |acc, x: u8| acc + x)?;
            assert_eq!(sum, 42 + 43 + 44, "fold bin");

            "DROP TABLE tmplak".run($conn)?;
        };
    }

    let mut conn = Conn::new(get_opts())?;

    // NOTE(review): in MySQL, CREATE/DROP TABLE commit implicitly, so the rollback
    // below presumably has nothing left to undo — confirm that this is intended.
    let mut tx = conn.start_transaction(TxOpts::default())?;
    test_query!(&mut tx);
    tx.rollback()?;

    test_query!(&mut conn);

    let pool = Pool::new(get_opts())?;
    let mut pooled_conn = pool.get_conn()?;

    let mut tx = pool.start_transaction(TxOpts::default())?;
    test_query!(&mut tx);
    tx.rollback()?;

    test_query!(&mut pooled_conn);

    Ok(())
}
1394
/// An explicitly configured socket path that does not exist must make
/// `Conn::new` fail.
#[test]
#[should_panic(expected = "Could not connect to address")]
fn should_fail_on_wrong_socket_path() {
    let bogus_socket = OptsBuilder::from_opts(get_opts()).socket(Some("/foo/bar/baz"));
    Conn::new(bogus_socket).unwrap();
}
1401
/// Injecting a socket path that does not exist must not prevent connecting;
/// presumably the driver keeps its TCP connection in that case.
#[test]
fn should_fallback_to_tcp_if_cant_switch_to_socket() {
    let mut opts: Opts = get_opts().into();
    opts.0.injected_socket = Some("/foo/bar/baz".into());
    Conn::new(opts).unwrap();
}
1408
/// The database given in the options must be the session's current database.
#[test]
fn should_connect_with_database() {
    const DB_NAME: &str = "mysql";

    let mut conn = Conn::new(OptsBuilder::from_opts(get_opts()).db_name(Some(DB_NAME))).unwrap();

    let current: Option<String> = conn.query_first("SELECT DATABASE()").unwrap();
    assert_eq!(current.unwrap(), DB_NAME);
}
1420
/// Connecting by host name (`localhost`) instead of an IP address works.
#[cfg(not(target_os = "wasi"))]
#[test]
fn should_connect_by_hostname() {
    let by_name = OptsBuilder::from_opts(get_opts()).ip_or_hostname(Some("localhost"));
    let mut conn = Conn::new(by_name).unwrap();
    assert!(conn.ping());
}
1428
/// `select_db` switches the session's current database.
#[test]
fn should_select_db() {
    const DB_NAME: &str = "t_select_db";

    let mut conn = Conn::new(get_opts()).unwrap();

    let create_db = format!("CREATE DATABASE IF NOT EXISTS {}", DB_NAME);
    conn.query_drop(create_db).unwrap();
    assert!(conn.select_db(DB_NAME));

    let selected: String = conn.query_first("SELECT DATABASE()").unwrap().unwrap();
    assert_eq!(selected, DB_NAME);

    // Remove the database created above.
    let drop_db = format!("DROP DATABASE {}", DB_NAME);
    conn.query_drop(drop_db).unwrap();
}
1444
/// Runs DDL/DML over the text protocol, checking `affected_rows` and
/// `last_insert_id` after each statement, then reads the rows back as strings.
#[test]
fn should_execute_queryes_and_parse_results() {
    type TestRow = (String, String, String, String, String, String);

    const CREATE_QUERY: &str = r"CREATE TEMPORARY TABLE mysql.tbl
        (id SERIAL, a TEXT, b INT, c INT UNSIGNED, d DATE, e FLOAT)";
    const INSERT_QUERY_1: &str = r"INSERT
        INTO mysql.tbl(a, b, c, d, e)
        VALUES ('hello', -123, 123, '2014-05-05', 123.123)";
    const INSERT_QUERY_2: &str = r"INSERT
        INTO mysql.tbl(a, b, c, d, e)
        VALUES ('world', -321, 321, '2014-06-06', 321.321)";

    let mut conn = Conn::new(get_opts()).unwrap();

    conn.query_drop(CREATE_QUERY).unwrap();
    assert_eq!(conn.affected_rows(), 0);
    assert_eq!(conn.last_insert_id(), 0);

    conn.query_drop(INSERT_QUERY_1).unwrap();
    assert_eq!(conn.affected_rows(), 1);
    assert_eq!(conn.last_insert_id(), 1);

    conn.query_drop(INSERT_QUERY_2).unwrap();
    assert_eq!(conn.affected_rows(), 1);
    assert_eq!(conn.last_insert_id(), 2);

    // A failing query must leave the connection usable...
    conn.query_drop("SELECT * FROM unexisted").unwrap_err();
    // ...and so must dropping a result that was never read.
    conn.query_iter("SELECT * FROM mysql.tbl").unwrap();

    conn.query_drop("UPDATE mysql.tbl SET a = 'foo'").unwrap();
    assert_eq!(conn.affected_rows(), 2);
    assert_eq!(conn.last_insert_id(), 0);

    let no_match = conn
        .query_first::<TestRow, _>("SELECT * FROM mysql.tbl WHERE a = 'bar'")
        .unwrap();
    assert!(no_match.is_none());

    // Text-protocol values decode into the strings the server formatted.
    let expected_row = |id: &str, b: &str, c: &str, d: &str, e: &str| -> TestRow {
        (
            id.into(),
            "foo".into(),
            b.into(),
            c.into(),
            d.into(),
            e.into(),
        )
    };
    let rows: Vec<TestRow> = conn.query("SELECT * FROM mysql.tbl").unwrap();
    assert_eq!(
        rows,
        vec![
            expected_row("1", "-123", "123", "2014-05-05", "123.123"),
            expected_row("2", "-321", "321", "2014-06-06", "321.321"),
        ]
    );
}
1507
/// A single 20 MB text-protocol value is read back intact. The value is larger
/// than one protocol packet, so it presumably exercises multi-packet payloads.
#[cfg(not(target_os = "wasi"))]
#[test]
fn should_parse_large_text_result() {
    const LEN: usize = 20_000_000;

    let mut conn = Conn::new(get_opts()).unwrap();
    let fetched: Option<Value> = conn.query_first("SELECT REPEAT('A', 20000000)").unwrap();
    let expected: Vec<u8> = iter::repeat(b'A').take(LEN).collect();
    assert_eq!(fetched.unwrap(), Bytes(expected));
}
1518
/// Inserts rows through a prepared statement (typed Rust values and raw `NULL`s)
/// and reads them back through the binary protocol as raw `Value`s.
#[test]
fn should_execute_statements_and_parse_results() {
    const CREATE_QUERY: &str = r"CREATE TEMPORARY TABLE
        mysql.tbl (a TEXT, b INT, c INT UNSIGNED, d DATE, e FLOAT)";
    const INSERT_SMTM: &str = r"INSERT
        INTO mysql.tbl (a, b, c, d, e)
        VALUES (?, ?, ?, ?, ?)";

    type RowType = (Value, Value, Value, Value, Value);

    // The values the binary protocol is expected to return.
    let row1: RowType = (
        Bytes(b"hello".to_vec()),
        Int(-123_i64),
        Int(123_i64),
        Date(2014_u16, 5_u8, 5_u8, 0_u8, 0_u8, 0_u8, 0_u32),
        Float(123.123_f32),
    );
    let row2: RowType = (Bytes(b"".to_vec()), NULL, NULL, NULL, Float(321.321_f32));

    let mut conn = Conn::new(get_opts()).unwrap();
    conn.query_drop(CREATE_QUERY).unwrap();

    let insert_stmt = conn.prep(INSERT_SMTM).unwrap();
    // A statement records the id of the connection that prepared it.
    assert_eq!(insert_stmt.connection_id(), conn.connection_id());

    // The first row is sent as native Rust types...
    let typed_params = (
        from_value::<String>(row1.0.clone()),
        from_value::<i32>(row1.1.clone()),
        from_value::<u32>(row1.2.clone()),
        from_value::<PrimitiveDateTime>(row1.3.clone()),
        from_value::<f32>(row1.4.clone()),
    );
    conn.exec_drop(&insert_stmt, typed_params).unwrap();

    // ...the second mixes native types with raw `NULL` values.
    let mixed_params = (
        from_value::<String>(row2.0.clone()),
        row2.1.clone(),
        row2.2.clone(),
        row2.3.clone(),
        from_value::<f32>(row2.4.clone()),
    );
    conn.exec_drop(&insert_stmt, mixed_params).unwrap();

    let select_stmt = conn.prep("SELECT * from mysql.tbl").unwrap();
    let fetched: Vec<RowType> = conn.exec(&select_stmt, ()).unwrap();
    assert_eq!(fetched, vec![row1, row2]);
}
1571
/// A single 20 MB value fetched through a prepared statement (binary protocol)
/// is read back intact.
#[cfg(not(target_os = "wasi"))]
#[test]
fn should_parse_large_binary_result() {
    const LEN: usize = 20_000_000;

    let mut conn = Conn::new(get_opts()).unwrap();
    let stmt = conn.prep("SELECT REPEAT('A', 20000000)").unwrap();
    let fetched: Option<Value> = conn.exec_first(&stmt, ()).unwrap();
    let expected: Vec<u8> = iter::repeat(b'A').take(LEN).collect();
    assert_eq!(fetched.unwrap(), Bytes(expected));
}
1580
/// With a one-entry statement cache, explicitly closing a cached statement and
/// then re-preparing the same query (and later a different one) keeps working.
#[test]
fn manually_closed_stmt() {
    let mut conn = Conn::new(OptsBuilder::from(get_opts()).stmt_cache_size(1)).unwrap();

    for _ in 0..2 {
        let stmt = conn.prep("SELECT 1").unwrap();
        conn.exec_drop(&stmt, ()).unwrap();
        conn.close(stmt).unwrap();
    }

    let stmt = conn.prep("SELECT 2").unwrap();
    conn.exec_drop(&stmt, ()).unwrap();
}
1594
/// Commit, explicit rollback, a failed statement, and implicit rollback on drop:
/// after each transaction the row count in `mysql.tbl` must reflect exactly the
/// committed inserts.
#[test]
fn should_start_commit_and_rollback_transactions() {
    /// Runs `SELECT COUNT(a)` on the test table and returns its single raw row.
    fn count_a(conn: &mut Conn) -> Vec<Value> {
        let mut result = conn.query_iter("SELECT COUNT(a) from mysql.tbl").unwrap();
        result.next().unwrap().unwrap().unwrap()
    }

    let mut conn = Conn::new(get_opts()).unwrap();
    conn.query_drop(
        "CREATE TEMPORARY TABLE mysql.tbl(id INT NOT NULL PRIMARY KEY AUTO_INCREMENT, a INT)",
    )
    .unwrap();

    // Committed transaction: both inserts become visible.
    conn.start_transaction(TxOpts::default())
        .and_then(|mut t| {
            t.query_drop("INSERT INTO mysql.tbl(a) VALUES(1)").unwrap();
            assert_eq!(t.last_insert_id(), Some(1));
            assert_eq!(t.affected_rows(), 1);
            t.query_drop("INSERT INTO mysql.tbl(a) VALUES(2)").unwrap();
            t.commit().unwrap();
            Ok(())
        })
        .unwrap();
    assert_eq!(count_a(&mut conn), vec![Bytes(b"2".to_vec())]);

    // A failing statement, then the transaction is dropped without commit
    // (implicit rollback): nothing changes.
    conn.start_transaction(TxOpts::default())
        .and_then(|mut t| {
            t.query_drop("INSERT INTO tbl2(a) VALUES(1)").unwrap_err();
            Ok(())
        })
        .unwrap();
    assert_eq!(count_a(&mut conn), vec![Bytes(b"2".to_vec())]);

    // An explicit rollback discards both inserts.
    conn.start_transaction(TxOpts::default())
        .and_then(|mut t| {
            t.query_drop("INSERT INTO mysql.tbl(a) VALUES(1)").unwrap();
            t.query_drop("INSERT INTO mysql.tbl(a) VALUES(2)").unwrap();
            t.rollback().unwrap();
            Ok(())
        })
        .unwrap();
    assert_eq!(count_a(&mut conn), vec![Bytes(b"2".to_vec())]);

    // Prepared-statement inserts inside a committed transaction.
    let mut tx = conn.start_transaction(TxOpts::default()).unwrap();
    tx.exec_drop("INSERT INTO mysql.tbl(a) VALUES(?)", (3,))
        .unwrap();
    tx.exec_drop("INSERT INTO mysql.tbl(a) VALUES(?)", (4,))
        .unwrap();
    tx.commit().unwrap();
    assert_eq!(count_a(&mut conn), vec![Bytes(b"4".to_vec())]);

    // Dropping an uncommitted transaction rolls it back.
    let mut tx = conn.start_transaction(TxOpts::default()).unwrap();
    tx.exec_drop("INSERT INTO mysql.tbl(a) VALUES(?)", (5,))
        .unwrap();
    tx.exec_drop("INSERT INTO mysql.tbl(a) VALUES(?)", (6,))
        .unwrap();
    drop(tx);
    assert_eq!(
        conn.query_first("SELECT COUNT(a) from mysql.tbl").unwrap(),
        Some(4_usize),
    );
}
/// `LOAD DATA LOCAL INFILE` served by a custom handler. The handler streams
/// `LINE_COUNT` lines of `LINE_LEN` `Z` bytes, and each line must come back as a row
/// of that length. The test returns early when the server rejects local infile.
#[test]
fn should_handle_LOCAL_INFILE_with_custom_handler() {
    const LINE_LEN: usize = 65535;
    const LINE_COUNT: usize = 1536;

    let mut conn = Conn::new(get_opts()).unwrap();
    conn.query_drop("CREATE TEMPORARY TABLE mysql.tbl(a TEXT)")
        .unwrap();
    conn.set_local_infile_handler(Some(LocalInfileHandler::new(|_, stream| {
        // One line = `LINE_LEN` payload bytes plus the '\n' row terminator.
        let mut cell_data = vec![b'Z'; LINE_LEN];
        cell_data.push(b'\n');
        for _ in 0..LINE_COUNT {
            stream.write_all(&cell_data)?;
        }
        Ok(())
    })));
    match conn.query_drop("LOAD DATA LOCAL INFILE 'file_name' INTO TABLE mysql.tbl") {
        Ok(()) => {}
        // The server refused LOCAL INFILE (error text contains "not allowed"): nothing to test.
        Err(ref err) if err.to_string().contains("not allowed") => return,
        Err(err) => panic!("ERROR {}", err),
    }

    let mut count: usize = 0;
    for row in conn.query_iter("SELECT * FROM mysql.tbl").unwrap() {
        let (cell,) = from_row::<(Vec<u8>,)>(row.unwrap());
        assert_eq!(cell.len(), LINE_LEN);
        count += 1;
    }
    assert_eq!(count, LINE_COUNT);
}
1713
/// `reset` clears session state: the affected-row counter returns to 0 and the
/// temporary table created before the reset is gone.
#[test]
fn should_reset_connection() {
    let mut conn = Conn::new(get_opts()).unwrap();

    conn.query_drop("CREATE TEMPORARY TABLE `mysql`.`test` (`test` VARCHAR(255) NULL);")
        .unwrap();
    conn.query_drop("INSERT INTO `mysql`.`test` (`test`) VALUES ('foo');")
        .unwrap();
    assert_eq!(conn.affected_rows(), 1);

    conn.reset().unwrap();

    assert_eq!(conn.affected_rows(), 0);
    conn.query_drop("SELECT * FROM `mysql`.`test`;")
        .unwrap_err();
}
1730
/// Two statements with different named parameters, prepared side by side, each
/// bind only their own parameter.
#[test]
fn prep_exec() {
    let mut conn = Conn::new(get_opts()).unwrap();

    let foo_stmt = conn.prep("SELECT :foo").unwrap();
    let bar_stmt = conn.prep("SELECT :bar").unwrap();

    let foo_rows: Vec<String> = conn
        .exec(&foo_stmt, params! { "foo" => "foo" })
        .unwrap();
    assert_eq!(foo_rows, vec![String::from("foo")]);

    let bar_rows: Vec<String> = conn
        .exec(&bar_stmt, params! { "bar" => "bar" })
        .unwrap();
    assert_eq!(bar_rows, vec![String::from("bar")]);
}
1748
/// With the default test options, a connection that is not TLS-protected must
/// have ended up on a local socket.
#[test]
fn should_connect_via_socket_for_127_0_0_1() {
    let conn = Conn::new(OptsBuilder::from_opts(get_opts())).unwrap();
    if conn.is_insecure() {
        assert!(conn.is_socket());
    }
}
1757
/// Connecting to `localhost` without TLS must end up on a local socket.
#[test]
fn should_connect_via_socket_localhost() {
    let localhost = OptsBuilder::from_opts(get_opts()).ip_or_hostname(Some("localhost"));
    let conn = Conn::new(localhost).unwrap();
    if conn.is_insecure() {
        assert!(conn.is_socket());
    }
}
1766
/// Regression test for issue #306. While `c1` still has unread results of a
/// multi-statement query, its connection is KILLed from `c2`. Dropping the pending
/// `QueryResult` must then finish within two seconds instead of hanging.
#[cfg(not(target_os = "wasi"))]
#[test]
fn issue_306() {
    let (tx, rx) = channel::<()>();
    let handle = spawn(move || {
        let mut c1 = Conn::new(get_opts()).unwrap();
        let c1_id = c1.connection_id();
        let mut c2 = Conn::new(get_opts()).unwrap();
        let query_result = c1.query_iter("DO 1; SELECT SLEEP(1); DO 2;").unwrap();
        c2.query_drop(format!("KILL {c1_id}")).unwrap();
        drop(c2);
        drop(query_result);
        tx.send(()).unwrap();
    });
    // Same two-second deadline as before, but returns as soon as the worker reports
    // back instead of always sleeping for the full period.
    assert!(
        rx.recv_timeout(Duration::from_secs(2)).is_ok(),
        "dropping the QueryResult of a killed connection did not finish in time (issue #306)"
    );
    handle.join().unwrap();
}
1788
/// After `reset`, servers recent enough to reset a session in place keep the same
/// connection id; for older servers a new connection (new id) is expected.
#[test]
fn reset_does_work() {
    let mut c = Conn::new(get_opts()).unwrap();
    let cid = c.connection_id();
    c.reset().unwrap();

    // NOTE(review): the MySQL bound is strict (`>`) while MariaDB's is inclusive;
    // presumably these mirror the thresholds used by `Conn::reset` — confirm there.
    let keeps_connection = matches!(c.0.server_version, Some(v) if v > (5, 7, 3))
        || matches!(c.0.mariadb_server_version, Some(v) if v >= (10, 2, 7));

    if keeps_connection {
        assert_eq!(cid, c.connection_id());
    } else {
        assert_ne!(cid, c.connection_id());
    }
}
1804
/// Regression test for issue #317. The stored options are pointed at an unused
/// port and the server versions faked as `(0, 0, 0)`, so `reset` must fail cleanly.
/// The real versions are then restored and another `reset` runs, with its outcome
/// deliberately ignored.
#[test]
fn issue_317() {
    let mut c = Conn::new(get_opts()).unwrap();
    c.0.opts = get_opts().tcp_port(55555).into();

    let real_version = c.0.server_version.replace((0, 0, 0));
    let real_mariadb_version = c.0.mariadb_server_version.replace((0, 0, 0));
    c.reset().unwrap_err();

    c.0.server_version = real_version;
    c.0.mariadb_server_version = real_mariadb_version;
    let _ = c.reset();
}
1821
/// Prepared SELECTs keep working around multi-statement text queries. Walking
/// every result set of a mixed multi-statement query, without reading rows, also
/// completes.
#[test]
fn should_drop_multi_result_set() {
    let mut conn = Conn::new(OptsBuilder::from_opts(get_opts()).db_name(Some("mysql"))).unwrap();
    conn.query_drop("CREATE TEMPORARY TABLE TEST_TABLE ( name varchar(255) )")
        .unwrap();

    conn.exec_drop("SELECT * FROM TEST_TABLE", ()).unwrap();
    conn.query_drop(
        r"
        INSERT INTO TEST_TABLE (name) VALUES ('one');
        INSERT INTO TEST_TABLE (name) VALUES ('two');
        INSERT INTO TEST_TABLE (name) VALUES ('three');",
    )
    .unwrap();
    conn.exec_drop("SELECT * FROM TEST_TABLE", ()).unwrap();

    let mut mixed = conn
        .query_iter(
            r"
        SELECT * FROM TEST_TABLE;
        INSERT INTO TEST_TABLE (name) VALUES ('one');
        DO 0;",
        )
        .unwrap();

    // Touch only each set's metadata; the rows are never read.
    while let Some(set) = mixed.iter() {
        set.affected_rows();
    }
}
1851
/// Multiple result sets from a stored procedure (read partially) and from a
/// multi-statement text query (read completely) are returned in order.
#[test]
fn should_handle_multi_resultset() {
    let opts = OptsBuilder::from_opts(get_opts())
        .prefer_socket(false)
        .db_name(Some("mysql"));
    let mut conn = Conn::new(opts).unwrap();
    conn.query_drop("DROP PROCEDURE IF EXISTS multi").unwrap();
    conn.query_drop(
        r#"CREATE PROCEDURE multi() BEGIN
            SELECT 1 UNION ALL SELECT 2;
            DO 1;
            SELECT 3 UNION ALL SELECT 4;
            DO 1;
            DO 1;
            SELECT REPEAT('A', 17000000);
            SELECT REPEAT('A', 17000000);
        END"#,
    )
    .unwrap();

    {
        // Only the first two result sets are read here; everything after them
        // (including the two large REPEAT rows) is left unread when `call_result` drops.
        let mut call_result = conn.query_iter("CALL multi()").unwrap();
        let mut next_set = || {
            call_result
                .by_ref()
                .map(|row| row.unwrap().unwrap().pop().unwrap())
                .collect::<Vec<crate::Value>>()
        };
        assert_eq!(next_set(), vec![Bytes(b"1".to_vec()), Bytes(b"2".to_vec())]);
        assert_eq!(next_set(), vec![Bytes(b"3".to_vec()), Bytes(b"4".to_vec())]);
    }

    let mut result = conn.query_iter("SELECT 1; SELECT 2; SELECT 3;").unwrap();
    let mut set_no = 0;
    while let Some(result_set) = result.iter() {
        set_no += 1;
        for row in result_set {
            let expected: &[u8] = match set_no {
                1 => b"1",
                2 => b"2",
                3 => b"3",
                _ => unreachable!(),
            };
            assert_eq!(row.unwrap().unwrap(), vec![Bytes(expected.to_vec())]);
        }
    }
    assert_eq!(set_no, 3);
}
1899
/// Regression test for issue #273: a stored function can be called through a
/// parameterized (`.with(..)`) query.
#[test]
fn issue_273() {
    let mut conn = Conn::new(OptsBuilder::from_opts(get_opts()).prefer_socket(false)).unwrap();

    "DROP FUNCTION IF EXISTS f1".run(&mut conn).unwrap();

    let create_fn = r"CREATE DEFINER=`root`@`localhost` FUNCTION `f1`(p_arg INT, p_arg2 INT) RETURNS int
    DETERMINISTIC
    BEGIN
        RETURN p_arg + p_arg2;
    END";
    create_fn.run(&mut conn).unwrap();

    "SELECT f1(?, ?)".with((100u8, 100u8)).run(&mut conn).unwrap();
}
1919
/// Regression test for issue #285. A multi-statement query whose second statement
/// is invalid (the INSERT supplies 3 values for 4 columns) must not hang the worker
/// thread.
///
/// The worker reports back over a rendezvous channel. The test fails when either:
/// - the worker does not report back within `TIMEOUT`, or
/// - the worker panics, in which case its panic is re-raised here.
#[cfg(not(target_os = "wasi"))]
#[test]
fn issue_285() {
    use std::sync::mpsc::RecvTimeoutError;

    // Generous upper bound; the previous 100_000 s (~28 h) would let a regression
    // hang CI instead of failing the test.
    const TIMEOUT: Duration = Duration::from_secs(30);

    let (tx, rx) = sync_channel::<()>(0);

    let handle = std::thread::spawn(move || {
        let mut conn = Conn::new(get_opts()).unwrap();
        const INVALID_SQL: &str = r#"
        CREATE TEMPORARY TABLE IF NOT EXISTS `user_details` (
            `user_id` int(11) NOT NULL AUTO_INCREMENT,
            `username` varchar(255) DEFAULT NULL,
            `first_name` varchar(50) DEFAULT NULL,
            `last_name` varchar(50) DEFAULT NULL,
            PRIMARY KEY (`user_id`)
        );

        INSERT INTO `user_details` (`user_id`, `username`, `first_name`, `last_name`)
        VALUES (1, 'rogers63', 'david')
        "#;

        conn.query_iter(INVALID_SQL).unwrap();
        tx.send(()).unwrap();
    });

    match rx.recv_timeout(TIMEOUT) {
        Ok(()) => handle.join().unwrap(),
        // The sender was dropped without sending, i.e. the worker panicked first.
        Err(RecvTimeoutError::Disconnected) => match handle.join() {
            Err(payload) => std::panic::resume_unwind(payload),
            Ok(()) => panic!("worker exited without signalling"),
        },
        Err(RecvTimeoutError::Timeout) => panic!(
            "query_iter on a multi-statement query with an invalid statement hung (issue #285)"
        ),
    }
}
1949
/// Named parameters may repeat and may appear inside expressions, both with an
/// explicitly prepared statement and with a query string passed to `exec_first`.
#[test]
fn should_work_with_named_params() {
    let mut conn = Conn::new(get_opts()).unwrap();

    {
        let stmt = conn.prep("SELECT :a, :b, :a, :c").unwrap();
        let repeated: (u8, u8, u8, u8) = conn
            .exec_first(&stmt, params! {"a" => 1, "b" => 2, "c" => 3})
            .unwrap()
            .unwrap();
        assert_eq!((1_u8, 2_u8, 1_u8, 3_u8), repeated);
    }

    let params = params! {
        "a" => 1,
        "b" => 2,
        "c" => 3,
    };
    let summed: (u8, u8, u8, u8) = conn
        .exec_first("SELECT :a, :b, :a + :b, :c", params)
        .unwrap()
        .unwrap();
    assert_eq!((1_u8, 2_u8, 3_u8, 3_u8), summed);
}
1975
/// Executing a statement that references `:d` without supplying `d` must fail
/// with `MissingNamedParameter("d")`.
#[test]
fn should_return_error_on_missing_named_parameter() {
    let mut conn = Conn::new(get_opts()).unwrap();
    let stmt = conn.prep("SELECT :a, :b, :a, :c, :d").unwrap();
    let result =
        conn.exec_first::<crate::Row, _, _>(&stmt, params! {"a" => 1, "b" => 2, "c" => 3,});
    match result {
        Err(DriverError(MissingNamedParameter(ref x))) if x == "d" => (),
        // Report what actually came back instead of a bare `assert!(false)`.
        other => panic!("expected MissingNamedParameter(\"d\"), got {:?}", other),
    }
}
1987
/// Supplying named parameters to a statement that uses positional `?`
/// placeholders must fail with `NamedParamsForPositionalQuery`.
#[test]
fn should_return_error_on_named_params_for_positional_statement() {
    let mut conn = Conn::new(get_opts()).unwrap();
    let stmt = conn.prep("SELECT ?, ?, ?, ?, ?").unwrap();
    let result = conn.exec_drop(&stmt, params! {"a" => 1, "b" => 2, "c" => 3,});
    match result {
        Err(DriverError(NamedParamsForPositionalQuery)) => (),
        // Report what actually came back instead of a bare `assert!(false)`.
        other => panic!("expected NamedParamsForPositionalQuery, got {:?}", other),
    }
}
1998
/// A reachable server connects within the TCP connect timeout. An address that
/// presumably never answers (`192.168.255.255`) must fail with `ConnectTimeout`.
#[cfg(not(target_os = "wasi"))]
#[test]
fn should_handle_tcp_connect_timeout() {
    use crate::error::{DriverError::ConnectTimeout, Error::DriverError};

    let timeout = Some(::std::time::Duration::from_millis(1000));

    let reachable = OptsBuilder::from_opts(get_opts())
        .prefer_socket(false)
        .tcp_connect_timeout(timeout);
    assert!(Conn::new(reachable).unwrap().ping());

    let unreachable = OptsBuilder::from_opts(get_opts())
        .prefer_socket(false)
        .tcp_connect_timeout(timeout)
        .ip_or_hostname(Some("192.168.255.255"));
    let err = Conn::new(unreachable).unwrap_err();
    if !matches!(err, DriverError(ConnectTimeout)) {
        panic!("Unexpected error: {}", err);
    }
}
2018
/// With `CLIENT_FOUND_ROWS` requested, an UPDATE that sets a column to the value
/// it already has still reports one affected (matched) row.
#[test]
fn should_set_additional_capabilities() {
    use crate::consts::CapabilityFlags;

    let opts = OptsBuilder::from_opts(get_opts())
        .additional_capabilities(CapabilityFlags::CLIENT_FOUND_ROWS);
    let mut conn = Conn::new(opts).unwrap();

    for setup in [
        "CREATE TEMPORARY TABLE mysql.tbl (a INT, b TEXT)",
        "INSERT INTO mysql.tbl (a, b) VALUES (1, 'foo')",
    ] {
        conn.query_drop(setup).unwrap();
    }

    let no_op_update = conn
        .query_iter("UPDATE mysql.tbl SET b = 'foo' WHERE a = 1")
        .unwrap();
    assert_eq!(no_op_update.affected_rows(), 1);
}
2036
/// Binding the client socket to `127.0.0.1:<random port>` before connecting must
/// be reflected in the connection's `Debug` output.
#[cfg(not(target_os = "wasi"))]
#[test]
fn should_bind_before_connect() {
    // Random local port in 28000..30000 to make collisions between runs unlikely.
    let port = 28000 + (rand::random::<u16>() % 2000);
    let conn = Conn::new(
        OptsBuilder::from_opts(get_opts())
            .prefer_socket(false)
            .ip_or_hostname(Some("localhost"))
            .bind_address(Some(([127, 0, 0, 1], port))),
    )
    .unwrap();

    // Either of two address renderings is accepted.
    let debug_format: String = format!("{:?}", conn);
    let found = [
        format!("addr: V4(127.0.0.1:{})", port),
        format!("addr: 127.0.0.1:{}", port),
    ]
    .iter()
    .any(|needle| debug_format.contains(needle.as_str()));
    assert!(found, "debug_format: {}", debug_format);
}
2055
/// Binding the client socket before connecting also works when a TCP connect
/// timeout is configured. The bound address must appear in the `Debug` output.
#[cfg(not(target_os = "wasi"))]
#[test]
fn should_bind_before_connect_with_timeout() {
    // Random local port in 30000..32000.
    let port = 30000 + (rand::random::<u16>() % 2000);
    let opts = OptsBuilder::from_opts(get_opts())
        .prefer_socket(false)
        .ip_or_hostname(Some("localhost"))
        .bind_address(Some(([127, 0, 0, 1], port)))
        .tcp_connect_timeout(Some(::std::time::Duration::from_millis(1000)));
    let mut conn = Conn::new(opts).unwrap();
    assert!(conn.ping());

    // Either of two address renderings is accepted.
    let debug_format: String = format!("{:?}", conn);
    let found = [
        format!("addr: V4(127.0.0.1:{})", port),
        format!("addr: 127.0.0.1:{}", port),
    ]
    .iter()
    .any(|needle| debug_format.contains(needle.as_str()));
    assert!(found, "debug_format: {}", debug_format);
}
2076
/// With the statement cache disabled, each explicit `close` must reach the server.
/// The check is that the session's `Com_stmt_close` counter equals the number of
/// statements closed.
#[test]
fn should_not_cache_statements_if_stmt_cache_size_is_zero() {
    let mut conn = Conn::new(OptsBuilder::from_opts(get_opts()).stmt_cache_size(0)).unwrap();

    let prepared: Vec<_> = ["DO 1", "DO 2", "DO 3"]
        .iter()
        .map(|&query| conn.prep(query).unwrap())
        .collect();

    for stmt in prepared {
        conn.close(stmt).unwrap();
    }

    let (_name, closes): (Value, u8) = conn
        .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
        .unwrap()
        .unwrap();
    assert_eq!(closes, 3);
}
2096
/// A statement cache of size 3 never holds more than three entries. Evicted
/// statements are closed on the server (three evictions here). The surviving
/// entries match least-recently-used eviction.
#[test]
fn should_hold_stmt_cache_size_bounds() {
    let mut conn = Conn::new(OptsBuilder::from_opts(get_opts()).stmt_cache_size(3)).unwrap();

    // Re-preparing "DO 1" and "DO 3" refreshes them. "DO 2", "DO 1" and "DO 4" are
    // evicted in turn, leaving "DO 3", "DO 5" and "DO 6".
    for query in ["DO 1", "DO 2", "DO 3", "DO 1", "DO 4", "DO 3", "DO 5", "DO 6"] {
        conn.prep(query).unwrap();
    }

    let (_name, closes): (String, usize) = conn
        .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close'")
        .unwrap()
        .unwrap();
    assert_eq!(closes, 3);

    let mut cached_queries = conn
        .0
        .stmt_cache
        .iter()
        .map(|(_, entry)| &**entry.query.0.as_ref())
        .collect::<Vec<&[u8]>>();
    cached_queries.sort();
    assert_eq!(cached_queries, &[b"DO 3", b"DO 5", b"DO 6"]);
}
2127
/// A struct stored via `Serialized` round-trips through a JSON column, falling back
/// to TEXT when the server cannot create a JSON column.
/// - The text protocol reads it back as a `serde_json::Value`.
/// - The binary protocol reads it back as the original struct via `Deserialized`.
#[test]
fn should_handle_json_columns() {
    use crate::{Deserialized, Serialized};
    use serde_json::Value as Json;
    use std::str::FromStr;

    #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
    pub struct DecTest {
        foo: String,
        quux: (u64, String),
    }

    let decodable = DecTest {
        foo: "bar".into(),
        quux: (42, "hello".into()),
    };

    let mut conn = Conn::new(get_opts()).unwrap();
    let has_json_column = conn
        .query_drop("CREATE TEMPORARY TABLE mysql.tbl(a VARCHAR(32), b JSON)")
        .is_ok();
    if !has_json_column {
        conn.query_drop("CREATE TEMPORARY TABLE mysql.tbl(a VARCHAR(32), b TEXT)")
            .unwrap();
    }

    conn.exec_drop(
        r#"INSERT INTO mysql.tbl VALUES ('hello', ?)"#,
        (Serialized(&decodable),),
    )
    .unwrap();

    let text_row: (String, Json) = conn
        .query_first("SELECT a, b FROM mysql.tbl")
        .unwrap()
        .unwrap();
    let expected_json = Json::from_str(r#"{"foo": "bar", "quux": [42, "hello"]}"#).unwrap();
    assert_eq!(text_row, ("hello".into(), expected_json));

    let row = conn
        .exec_first("SELECT a, b FROM mysql.tbl WHERE a = ?", ("hello",))
        .unwrap()
        .unwrap();
    let (a, Deserialized(b)) = from_row(row);
    assert_eq!((a, b), (String::from("hello"), decodable));
}
2178
/// Checks the connection attributes reported by `performance_schema`: first the
/// driver's defaults, then defaults extended by custom attributes (where a custom
/// `program_name` replaces the default one).
///
/// Panics with setup instructions when `performance_schema` is off or its
/// per-session attribute buffer is configured too small.
#[test]
fn should_set_connect_attrs() {
    let opts = OptsBuilder::from_opts(get_opts());
    let mut conn = Conn::new(opts).unwrap();

    // Connection attributes require MySQL >= 5.6 or MariaDB >= 10.0.
    let support_connect_attrs = match (conn.0.server_version, conn.0.mariadb_server_version) {
        (Some(ref version), _) if *version >= (5, 6, 0) => true,
        (_, Some(ref version)) if *version >= (10, 0, 0) => true,
        _ => false,
    };

    if support_connect_attrs {
        if get_system_variable::<String>(&mut conn, "performance_schema") != "ON" {
            panic!("The system variable `performance_schema` is off. Restart the MySQL server with `--performance_schema=on` to pass the test.");
        }
        let attrs_size: i32 =
            get_system_variable(&mut conn, "performance_schema_session_connect_attrs_size");
        if (0..=128).contains(&attrs_size) {
            panic!("The system variable `performance_schema_session_connect_attrs_size` is {}. Restart the MySQL server with `--performance_schema_session_connect_attrs_size=-1` to pass the test.", attrs_size);
        }

        /// Asserts that every `(name, value)` pair in `expected_values` is reported by
        /// `performance_schema` for the current connection.
        fn assert_connect_attrs(conn: &mut Conn, expected_values: &[(&str, &str)]) {
            let mut actual_values = HashMap::new();
            for row in conn.query_iter("SELECT attr_name, attr_value FROM performance_schema.session_account_connect_attrs WHERE processlist_id = connection_id()").unwrap() {
                let (name, value) = from_row::<(String, String)>(row.unwrap());
                actual_values.insert(name, value);
            }

            for (name, value) in expected_values {
                // Look up by `&str` directly instead of allocating a `String` per check.
                assert_eq!(
                    actual_values.get(*name).map(String::as_str),
                    Some(*value),
                    "connect attribute `{}`",
                    name
                );
            }
        }

        #[cfg(not(target_os = "wasi"))]
        let pid = process::id().to_string();
        #[cfg(target_os = "wasi")]
        let pid = "66666".to_string();
        let progname = std::env::args_os()
            .next()
            .unwrap()
            .to_string_lossy()
            .into_owned();
        let mut expected_values = vec![
            ("_client_name", "rust-mysql-simple"),
            ("_client_version", env!("CARGO_PKG_VERSION")),
            ("_os", env!("CARGO_CFG_TARGET_OS")),
            ("_pid", &pid),
            ("_platform", env!("CARGO_CFG_TARGET_ARCH")),
            ("program_name", &progname),
        ];

        assert_connect_attrs(&mut conn, &expected_values);

        // Custom attributes are added to the defaults; `program_name` is overridden.
        let opts = OptsBuilder::from_opts(get_opts());
        let mut connect_attrs = HashMap::with_capacity(3);
        connect_attrs.insert("foo", "foo val");
        connect_attrs.insert("bar", "bar val");
        connect_attrs.insert("program_name", "my program name");
        let mut conn = Conn::new(opts.connect_attrs(connect_attrs)).unwrap();
        expected_values.pop(); // drop the default `program_name` entry
        expected_values.push(("foo", "foo val"));
        expected_values.push(("bar", "bar val"));
        expected_values.push(("program_name", "my program name"));
        assert_connect_attrs(&mut conn, &expected_values);
    }
}
2252
/// Reads the binlog in three ways, each of which must yield at least one event:
///
/// 1. a blocking stream starting at a file/position,
/// 2. the same with GTIDs enabled (skipped when the server reports a MariaDB version),
/// 3. a stream requested with `BINLOG_DUMP_NON_BLOCK`, drained on the current thread.
#[cfg(not(target_os = "wasi"))]
#[test]
fn should_read_binlog() -> crate::Result<()> {
    use std::{
        collections::HashMap, sync::mpsc::sync_channel, thread::spawn, time::Duration,
    };

    /// Creates a scratch `customers` table, inserts 100 rows into it and drops it again,
    /// so that binlog events are produced.
    fn gen_dummy_data() -> crate::Result<()> {
        let mut conn = Conn::new(get_opts())?;

        "CREATE TABLE IF NOT EXISTS customers (customer_id int not null)".run(&mut conn)?;

        for i in 0_u8..100 {
            "INSERT INTO customers(customer_id) VALUES (?)"
                .with((i,))
                .run(&mut conn)?;
        }

        "DROP TABLE customers".run(&mut conn)?;

        Ok(())
    }

    /// Opens a connection and returns it together with the first two columns of the first
    /// row of `SHOW BINARY LOGS`. Dummy data is generated after those values are read.
    ///
    /// Panics if `@@GLOBAL.GTID_MODE` can be read and does not start with `ON`.
    ///
    /// NOTE(review): the first row of `SHOW BINARY LOGS` is presumably the oldest log rather
    /// than the current one — confirm this starting point is intended.
    fn get_conn() -> crate::Result<(Conn, Vec<u8>, u64)> {
        let mut conn = Conn::new(get_opts())?;

        if let Ok(Some(gtid_mode)) =
            "SELECT @@GLOBAL.GTID_MODE".first::<String, _>(&mut conn)
        {
            if !gtid_mode.starts_with("ON") {
                panic!(
                    "GTID_MODE is disabled \
                    (enable using --gtid_mode=ON --enforce_gtid_consistency=ON)"
                );
            }
        }

        let row: crate::Row = "SHOW BINARY LOGS".first(&mut conn)?.unwrap();
        let filename = row.get(0).unwrap();
        let position = row.get(1).unwrap();

        gen_dummy_data().unwrap();
        Ok((conn, filename, position))
    }

    /// Consumes a blocking binlog stream on a background thread until no event arrives for
    /// one second, and returns the number of events received.
    ///
    /// Each event's type and data are decoded. Each rows event is decoded against the
    /// table-map event previously received for the same table id. Panics if that table-map
    /// event was not received.
    fn read_events(
        binlog_stream: crate::conn::binlog_stream::BinlogStream,
    ) -> crate::Result<usize> {
        let (tx, rx) = sync_channel(0);
        spawn(move || {
            for event in binlog_stream {
                // The consumer stops listening after its timeout and drops `rx`; exit quietly
                // instead of panicking on the failed send.
                if tx.send(event).is_err() {
                    break;
                }
            }
        });

        let mut events_num = 0;
        let mut tmes = HashMap::new();
        while let Ok(event) = rx.recv_timeout(Duration::from_secs(1)) {
            let event = event.unwrap();
            events_num += 1;

            event.header().event_type().unwrap();

            match event.read_data()?.unwrap() {
                EventData::TableMapEvent(tme) => {
                    tmes.insert(tme.table_id(), tme.into_owned());
                }
                EventData::RowsEvent(re) => {
                    for row in re.rows(&tmes[&re.table_id()]) {
                        row.unwrap();
                    }
                }
                _ => (),
            }
        }
        Ok(events_num)
    }

    let (conn, filename, pos) = get_conn().unwrap();
    let is_mariadb = conn.0.mariadb_server_version.is_some();

    let binlog_stream = conn
        .get_binlog_stream(BinlogRequest::new(12).with_filename(filename).with_pos(pos))
        .unwrap();
    assert!(read_events(binlog_stream)? > 0);

    // The GTID-based request is only exercised when the server is not MariaDB.
    if !is_mariadb {
        let (conn, filename, pos) = get_conn().unwrap();

        let binlog_stream = conn
            .get_binlog_stream(
                BinlogRequest::new(13)
                    .with_use_gtid(true)
                    .with_filename(filename)
                    .with_pos(pos),
            )
            .unwrap();
        assert!(read_events(binlog_stream)? > 0);
    }

    let (conn, filename, pos) = get_conn().unwrap();

    // With BINLOG_DUMP_NON_BLOCK the stream is expected to end on its own, so it is drained
    // directly on this thread without a timeout.
    let binlog_stream = conn
        .get_binlog_stream(
            BinlogRequest::new(14)
                .with_filename(filename)
                .with_pos(pos)
                .with_flags(crate::BinlogDumpFlags::BINLOG_DUMP_NON_BLOCK),
        )
        .unwrap();

    let mut events_num = 0;
    for event in binlog_stream {
        let event = event.unwrap();
        events_num += 1;
        event.header().event_type().unwrap();
        event.read_data()?;
    }
    assert!(events_num > 0);

    Ok(())
}
2403 }
2404
/// Micro-benchmarks for common query paths. Requires the `nightly` feature and a reachable
/// server configured via `get_opts()`.
///
/// Every benchmarked call unwraps its `Result`, so a failing query aborts the benchmark
/// instead of silently timing the error path.
#[cfg(feature = "nightly")]
mod bench {
    use test;

    use crate::{params, prelude::*, test_misc::get_opts, Conn, Value::NULL};

    /// Text-protocol `DO 1`.
    #[bench]
    fn simple_exec(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        bencher.iter(|| {
            conn.query_drop("DO 1").unwrap();
        })
    }

    /// Re-executes a statement prepared once outside the loop.
    #[bench]
    fn prepared_exec(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        let stmt = conn.prep("DO 1").unwrap();
        bencher.iter(|| {
            conn.exec_drop(&stmt, ()).unwrap();
        })
    }

    /// Prepares and executes a statement on every iteration.
    #[bench]
    fn prepare_and_exec(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        bencher.iter(|| {
            let stmt = conn.prep("SELECT ?").unwrap();
            conn.exec_drop(&stmt, (0,)).unwrap();
        })
    }

    /// Text-protocol single-row select.
    #[bench]
    fn simple_query_row(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        bencher.iter(|| {
            conn.query_drop("SELECT 1").unwrap();
        })
    }

    /// Prepared single-row select without parameters.
    #[bench]
    fn simple_prepared_query_row(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        let stmt = conn.prep("SELECT 1").unwrap();
        bencher.iter(|| {
            conn.exec_drop(&stmt, ()).unwrap();
        })
    }

    /// Prepared select with one positional parameter.
    #[bench]
    fn simple_prepared_query_row_with_param(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        let stmt = conn.prep("SELECT ?").unwrap();
        bencher.iter(|| {
            conn.exec_drop(&stmt, (0,)).unwrap();
        })
    }

    /// Prepared select with one named parameter.
    #[bench]
    fn simple_prepared_query_row_with_named_param(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        let stmt = conn.prep("SELECT :a").unwrap();
        bencher.iter(|| {
            conn.exec_drop(&stmt, params! {"a" => 0}).unwrap();
        })
    }

    /// Prepared select with five positional parameters of mixed types.
    #[bench]
    fn simple_prepared_query_row_with_5_params(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        let stmt = conn.prep("SELECT ?, ?, ?, ?, ?").unwrap();
        let params = (42i8, b"123456".to_vec(), 1.618f64, NULL, 1i8);
        bencher.iter(|| {
            conn.exec_drop(&stmt, &params).unwrap();
        })
    }

    /// Prepared select with five named parameters of mixed types.
    #[bench]
    fn simple_prepared_query_row_with_5_named_params(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        let stmt = conn
            .prep("SELECT :one, :two, :three, :four, :five")
            .unwrap();
        bencher.iter(|| {
            conn.exec_drop(
                &stmt,
                params! {
                    "one" => 42i8,
                    "two" => b"123456",
                    "three" => 1.618f64,
                    "four" => NULL,
                    "five" => 1i8,
                },
            )
            .unwrap();
        })
    }

    /// Text-protocol select of a 10 000-character string.
    #[bench]
    fn select_large_string(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        bencher.iter(|| {
            conn.query_drop("SELECT REPEAT('A', 10000)").unwrap();
        })
    }

    /// Prepared select of a 10 000-character string.
    #[bench]
    fn select_prepared_large_string(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        let stmt = conn.prep("SELECT REPEAT('A', 10000)").unwrap();
        bencher.iter(|| {
            conn.exec_drop(&stmt, ()).unwrap();
        })
    }

    /// Prepared select returning 512 single-column rows.
    ///
    /// NOTE(review): the temporary table is created in the `mysql` system schema; confirm the
    /// target server permits this.
    #[bench]
    fn many_small_rows(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        conn.query_drop("CREATE TEMPORARY TABLE mysql.x (id INT)")
            .unwrap();
        for _ in 0..512 {
            conn.query_drop("INSERT INTO mysql.x VALUES (256)").unwrap();
        }
        let stmt = conn.prep("SELECT * FROM mysql.x").unwrap();
        bencher.iter(|| {
            conn.exec_drop(&stmt, ()).unwrap();
        });
    }
}
2533}