1use futures_util::FutureExt;
10pub use mysql_common::named_params;
11
12use mysql_common::{
13 constants::{DEFAULT_MAX_ALLOWED_PACKET, UTF8MB4_GENERAL_CI, UTF8_GENERAL_CI},
14 crypto,
15 io::ParseBuf,
16 packets::{
17 binlog_request::BinlogRequest, AuthPlugin, AuthSwitchRequest, CommonOkPacket, ErrPacket,
18 HandshakePacket, HandshakeResponse, OkPacket, OkPacketDeserializer, OldAuthSwitchRequest,
19 OldEofPacket, ResultSetTerminator, SslRequest,
20 },
21 proto::MySerialize,
22 row::Row,
23};
24
25use std::{
26 borrow::Cow,
27 fmt,
28 future::Future,
29 mem::{self, replace},
30 pin::Pin,
31 str::FromStr,
32 sync::Arc,
33 time::{Duration, Instant},
34};
35
36use crate::{
37 buffer_pool::PooledBuf,
38 conn::{pool::Pool, stmt_cache::StmtCache},
39 consts::{CapabilityFlags, Command, StatusFlags},
40 error::*,
41 io::Stream,
42 opts::Opts,
43 queryable::{
44 query_result::{QueryResult, ResultSetMeta},
45 transaction::TxStatus,
46 BinaryProtocol, Queryable, TextProtocol,
47 },
48 BinlogStream, ChangeUserOpts, InfileData, OptsBuilder,
49};
50
51use self::routines::Routine;
52
53pub mod binlog_stream;
54pub mod pool;
55pub mod routines;
56pub mod stmt_cache;
57
/// Fallback `wait_timeout`, in seconds, used by `read_settings` when `@@wait_timeout`
/// could not be loaded from the server.
const DEFAULT_WAIT_TIMEOUT: usize = 28_800;
59
60fn disconnect(mut conn: Conn) {
62 let disconnected = conn.inner.disconnected;
63
64 conn.inner.disconnected = true;
66
67 if !disconnected {
68 if std::thread::panicking() {
70 return;
71 }
72
73 if let Ok(handle) = tokio::runtime::Handle::try_current() {
76 handle.spawn(async move {
77 if let Ok(conn) = conn.cleanup_for_pool().await {
78 let _ = conn.disconnect().await;
79 }
80 });
81 }
82 }
83}
84
85#[derive(Debug, Clone)]
87pub(crate) enum PendingResult {
88 Pending(ResultSetMeta),
90 Taken(Arc<ResultSetMeta>),
92}
93
94struct ConnInner {
96 stream: Option<Stream>,
97 id: u32,
98 is_mariadb: bool,
99 version: (u16, u16, u16),
100 socket: Option<String>,
101 capabilities: CapabilityFlags,
102 status: StatusFlags,
103 last_ok_packet: Option<OkPacket<'static>>,
104 last_err_packet: Option<mysql_common::packets::ServerError<'static>>,
105 pool: Option<Pool>,
106 pending_result: std::result::Result<Option<PendingResult>, ServerError>,
107 tx_status: TxStatus,
108 reset_upon_returning_to_a_pool: bool,
109 opts: Opts,
110 last_io: Instant,
111 wait_timeout: Duration,
112 stmt_cache: StmtCache,
113 nonce: Vec<u8>,
114 auth_plugin: AuthPlugin<'static>,
115 auth_switched: bool,
116 server_key: Option<Vec<u8>>,
117 pub(crate) disconnected: bool,
119 infile_handler:
121 Option<Pin<Box<dyn Future<Output = crate::Result<InfileData>> + Send + Sync + 'static>>>,
122}
123
124impl fmt::Debug for ConnInner {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 f.debug_struct("Conn")
127 .field("connection id", &self.id)
128 .field("server version", &self.version)
129 .field("pool", &self.pool)
130 .field("pending_result", &self.pending_result)
131 .field("tx_status", &self.tx_status)
132 .field("stream", &self.stream)
133 .field("options", &self.opts)
134 .field("server_key", &self.server_key)
135 .field("auth_plugin", &self.auth_plugin)
136 .finish()
137 }
138}
139
140impl ConnInner {
141 fn empty(opts: Opts) -> ConnInner {
143 ConnInner {
144 capabilities: opts.get_capabilities(),
145 status: StatusFlags::empty(),
146 last_ok_packet: None,
147 last_err_packet: None,
148 stream: None,
149 is_mariadb: false,
150 version: (0, 0, 0),
151 id: 0,
152 pending_result: Ok(None),
153 pool: None,
154 tx_status: TxStatus::None,
155 last_io: Instant::now(),
156 wait_timeout: Duration::from_secs(0),
157 stmt_cache: StmtCache::new(opts.stmt_cache_size()),
158 socket: opts.socket().map(Into::into),
159 opts,
160 nonce: Vec::default(),
161 auth_plugin: AuthPlugin::MysqlNativePassword,
162 auth_switched: false,
163 disconnected: false,
164 server_key: None,
165 infile_handler: None,
166 reset_upon_returning_to_a_pool: false,
167 }
168 }
169
170 fn stream_mut(&mut self) -> Result<&mut Stream> {
174 self.stream
175 .as_mut()
176 .ok_or_else(|| DriverError::ConnectionClosed.into())
177 }
178}
179
180#[derive(Debug)]
182pub struct Conn {
183 inner: Box<ConnInner>,
184}
185
186impl Conn {
187 pub fn id(&self) -> u32 {
189 self.inner.id
190 }
191
192 pub fn last_insert_id(&self) -> Option<u64> {
196 self.inner
197 .last_ok_packet
198 .as_ref()
199 .and_then(|ok| ok.last_insert_id())
200 }
201
202 pub fn affected_rows(&self) -> u64 {
205 self.inner
206 .last_ok_packet
207 .as_ref()
208 .map(|ok| ok.affected_rows())
209 .unwrap_or_default()
210 }
211
212 pub fn info(&self) -> Cow<'_, str> {
214 self.inner
215 .last_ok_packet
216 .as_ref()
217 .and_then(|ok| ok.info_str())
218 .unwrap_or_else(|| "".into())
219 }
220
221 pub fn get_warnings(&self) -> u16 {
223 self.inner
224 .last_ok_packet
225 .as_ref()
226 .map(|ok| ok.warnings())
227 .unwrap_or_default()
228 }
229
230 pub fn last_ok_packet(&self) -> Option<&OkPacket<'static>> {
232 self.inner.last_ok_packet.as_ref()
233 }
234
235 pub fn reset_connection(&mut self, reset_connection: bool) {
239 self.inner.reset_upon_returning_to_a_pool = reset_connection;
240 }
241
242 pub(crate) fn stream_mut(&mut self) -> Result<&mut Stream> {
243 self.inner.stream_mut()
244 }
245
246 pub(crate) fn capabilities(&self) -> CapabilityFlags {
247 self.inner.capabilities
248 }
249
250 pub(crate) fn touch(&mut self) {
252 self.inner.last_io = Instant::now();
253 }
254
255 pub(crate) fn reset_seq_id(&mut self) {
257 if let Some(stream) = self.inner.stream.as_mut() {
258 stream.reset_seq_id();
259 }
260 }
261
262 pub(crate) fn sync_seq_id(&mut self) {
264 if let Some(stream) = self.inner.stream.as_mut() {
265 stream.sync_seq_id();
266 }
267 }
268
269 pub(crate) fn handle_ok(&mut self, ok_packet: OkPacket<'static>) {
271 self.inner.status = ok_packet.status_flags();
272 self.inner.last_err_packet = None;
273 self.inner.last_ok_packet = Some(ok_packet);
274 }
275
276 pub(crate) fn handle_err(&mut self, err_packet: ErrPacket<'_>) -> Result<()> {
278 match err_packet {
279 ErrPacket::Error(err) => {
280 self.inner.status = StatusFlags::empty();
281 self.inner.last_ok_packet = None;
282 self.inner.last_err_packet = Some(err.clone().into_owned());
283 Err(Error::from(err))
284 }
285 ErrPacket::Progress(_) => Ok(()),
286 }
287 }
288
289 pub(crate) fn get_tx_status(&self) -> TxStatus {
291 self.inner.tx_status
292 }
293
294 pub(crate) fn set_tx_status(&mut self, tx_status: TxStatus) {
296 self.inner.tx_status = tx_status;
297 }
298
299 pub(crate) fn use_pending_result(
303 &mut self,
304 ) -> std::result::Result<Option<&PendingResult>, ServerError> {
305 if let Err(ref e) = self.inner.pending_result {
306 let e = e.clone();
307 self.inner.pending_result = Ok(None);
308 return Err(e);
309 } else {
310 Ok(self.inner.pending_result.as_ref().unwrap().as_ref())
311 }
312 }
313
314 pub(crate) fn get_pending_result(
315 &self,
316 ) -> std::result::Result<Option<&PendingResult>, &ServerError> {
317 self.inner.pending_result.as_ref().map(|x| x.as_ref())
318 }
319
320 pub(crate) fn has_pending_result(&self) -> bool {
321 matches!(self.inner.pending_result, Err(_))
322 || matches!(self.inner.pending_result, Ok(Some(_)))
323 }
324
325 pub(crate) fn set_pending_result(
327 &mut self,
328 meta: Option<ResultSetMeta>,
329 ) -> std::result::Result<Option<PendingResult>, ServerError> {
330 replace(
331 &mut self.inner.pending_result,
332 Ok(meta.map(PendingResult::Pending)),
333 )
334 }
335
336 pub(crate) fn set_pending_result_error(
337 &mut self,
338 error: ServerError,
339 ) -> std::result::Result<Option<PendingResult>, ServerError> {
340 replace(&mut self.inner.pending_result, Err(error))
341 }
342
343 pub(crate) fn take_pending_result(
345 &mut self,
346 ) -> std::result::Result<Option<Arc<ResultSetMeta>>, ServerError> {
347 let mut output = None;
348
349 self.inner.pending_result = match replace(&mut self.inner.pending_result, Ok(None))? {
350 Some(PendingResult::Pending(x)) => {
351 let meta = Arc::new(x);
352 output = Some(meta.clone());
353 Ok(Some(PendingResult::Taken(meta)))
354 }
355 x => Ok(x),
356 };
357
358 Ok(output)
359 }
360
361 pub(crate) fn status(&self) -> StatusFlags {
363 self.inner.status
364 }
365
366 pub(crate) async fn routine<'a, F, T>(&mut self, mut f: F) -> crate::Result<T>
367 where
368 F: Routine<T> + 'a,
369 {
370 self.inner.disconnected = true;
371 let result = f.call(&mut *self).await;
372 match result {
373 result @ Ok(_) | result @ Err(crate::Error::Server(_)) => {
374 self.inner.disconnected = false;
376 result
377 }
378 Err(err) => {
379 if self.inner.stream.is_some() {
380 self.take_stream().close().await?;
381 }
382 Err(err)
383 }
384 }
385 }
386
387 pub fn server_version(&self) -> (u16, u16, u16) {
389 self.inner.version
390 }
391
392 pub fn opts(&self) -> &Opts {
394 &self.inner.opts
395 }
396
397 pub fn set_infile_handler<T>(&mut self, handler: T)
404 where
405 T: Future<Output = crate::Result<InfileData>>,
406 T: Send + Sync + 'static,
407 {
408 self.inner.infile_handler = Some(Box::pin(handler));
409 }
410
411 fn take_stream(&mut self) -> Stream {
412 self.inner.stream.take().unwrap()
413 }
414
415 pub async fn disconnect(mut self) -> Result<()> {
417 if !self.inner.disconnected {
418 self.inner.disconnected = true;
419 self.write_command_data(Command::COM_QUIT, &[]).await?;
420 let stream = self.take_stream();
421 stream.close().await?;
422 }
423 Ok(())
424 }
425
426 async fn close_conn(mut self) -> Result<()> {
428 self = self.cleanup_for_pool().await?;
429 self.disconnect().await
430 }
431
432 fn is_secure(&self) -> bool {
434 #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
435 {
436 self.inner
437 .stream
438 .as_ref()
439 .map(|x| x.is_secure())
440 .unwrap_or_default()
441 }
442
443 #[cfg(not(any(feature = "native-tls-tls", feature = "rustls-tls")))]
444 false
445 }
446
447 fn is_socket(&self) -> bool {
449 #[cfg(unix)]
450 {
451 self.inner
452 .stream
453 .as_ref()
454 .map(|x| x.is_socket())
455 .unwrap_or_default()
456 }
457
458 #[cfg(not(unix))]
459 false
460 }
461
462 fn take(&mut self) -> Conn {
464 mem::replace(self, Conn::empty(Default::default()))
465 }
466
467 fn empty(opts: Opts) -> Self {
468 Self {
469 inner: Box::new(ConnInner::empty(opts)),
470 }
471 }
472
473 fn setup_stream(&mut self) -> Result<()> {
477 debug_assert!(self.inner.stream.is_some());
478 if let Some(stream) = self.inner.stream.as_mut() {
479 stream.set_tcp_nodelay(self.inner.opts.tcp_nodelay())?;
480 }
481 Ok(())
482 }
483
484 async fn handle_handshake(&mut self) -> Result<()> {
485 let packet = self.read_packet().await?;
486 let handshake = ParseBuf(&*packet).parse::<HandshakePacket>(())?;
487
488 self.inner.nonce = {
490 let mut nonce = Vec::from(handshake.scramble_1_ref());
491 nonce.extend_from_slice(handshake.scramble_2_ref().unwrap_or(&[][..]));
492 nonce.resize(20, 0);
495 nonce
496 };
497
498 self.inner.capabilities = handshake.capabilities() & self.inner.opts.get_capabilities();
499 self.inner.version = handshake
500 .maria_db_server_version_parsed()
501 .map(|version| {
502 self.inner.is_mariadb = true;
503 version
504 })
505 .or_else(|| handshake.server_version_parsed())
506 .unwrap_or((0, 0, 0));
507 self.inner.id = handshake.connection_id();
508 self.inner.status = handshake.status_flags();
509
510 self.inner.auth_plugin = match handshake.auth_plugin() {
514 Some(AuthPlugin::CachingSha2Password) => AuthPlugin::CachingSha2Password,
515 _ => AuthPlugin::MysqlNativePassword,
516 };
517
518 Ok(())
519 }
520
521 async fn switch_to_ssl_if_needed(&mut self) -> Result<()> {
522 if self
523 .inner
524 .opts
525 .get_capabilities()
526 .contains(CapabilityFlags::CLIENT_SSL)
527 {
528 if !self
529 .inner
530 .capabilities
531 .contains(CapabilityFlags::CLIENT_SSL)
532 {
533 return Err(DriverError::NoClientSslFlagFromServer.into());
534 }
535
536 let collation = if self.inner.version >= (5, 5, 3) {
537 UTF8MB4_GENERAL_CI
538 } else {
539 UTF8_GENERAL_CI
540 };
541
542 let ssl_request = SslRequest::new(
543 self.inner.capabilities,
544 DEFAULT_MAX_ALLOWED_PACKET as u32,
545 collation as u8,
546 );
547 self.write_struct(&ssl_request).await?;
548 let conn = self;
549 let ssl_opts = conn.opts().ssl_opts().cloned().expect("unreachable");
550 let domain = conn.opts().ip_or_hostname().into();
551 conn.stream_mut()?.make_secure(domain, ssl_opts).await?;
552 Ok(())
553 } else {
554 Ok(())
555 }
556 }
557
558 async fn do_handshake_response(&mut self) -> Result<()> {
559 let auth_data = self
560 .inner
561 .auth_plugin
562 .gen_data(self.inner.opts.pass(), &*self.inner.nonce);
563
564 let handshake_response = HandshakeResponse::new(
565 auth_data.as_deref(),
566 self.inner.version,
567 self.inner.opts.user().map(|x| x.as_bytes()),
568 self.inner.opts.db_name().map(|x| x.as_bytes()),
569 Some(self.inner.auth_plugin.borrow()),
570 self.capabilities(),
571 Default::default(), );
573
574 let mut buf = crate::BUFFER_POOL.get();
576 handshake_response.serialize(buf.as_mut());
577
578 self.write_packet(buf).await?;
579 Ok(())
580 }
581
582 async fn perform_auth_switch(
583 &mut self,
584 auth_switch_request: AuthSwitchRequest<'_>,
585 ) -> Result<()> {
586 if !self.inner.auth_switched {
587 self.inner.auth_switched = true;
588 self.inner.nonce = auth_switch_request.plugin_data().to_vec();
589
590 if matches!(
591 auth_switch_request.auth_plugin(),
592 AuthPlugin::MysqlOldPassword
593 ) {
594 if self.inner.opts.secure_auth() {
595 return Err(DriverError::MysqlOldPasswordDisabled.into());
596 }
597 }
598
599 self.inner.auth_plugin = auth_switch_request.auth_plugin().clone().into_owned();
600
601 let plugin_data = match &self.inner.auth_plugin {
602 x @ AuthPlugin::CachingSha2Password => {
603 x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
604 }
605 x @ AuthPlugin::MysqlNativePassword => {
606 x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
607 }
608 x @ AuthPlugin::MysqlOldPassword => {
609 if self.inner.opts.secure_auth() {
610 return Err(DriverError::MysqlOldPasswordDisabled.into());
611 } else {
612 x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
613 }
614 }
615 x @ AuthPlugin::MysqlClearPassword => {
616 if self.inner.opts.enable_cleartext_plugin() {
617 x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
618 } else {
619 return Err(DriverError::CleartextPluginDisabled.into());
620 }
621 }
622 x @ AuthPlugin::Other(_) => x.gen_data(self.inner.opts.pass(), &self.inner.nonce),
623 };
624
625 if let Some(plugin_data) = plugin_data {
626 self.write_struct(&plugin_data.into_owned()).await?;
627 } else {
628 self.write_packet(crate::BUFFER_POOL.get()).await?;
629 }
630
631 self.continue_auth().await?;
632
633 Ok(())
634 } else {
635 unreachable!("auth_switched flag should be checked by caller")
636 }
637 }
638
639 fn continue_auth(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
640 Box::pin(async move {
643 match self.inner.auth_plugin {
644 AuthPlugin::MysqlNativePassword | AuthPlugin::MysqlOldPassword => {
645 self.continue_mysql_native_password_auth().await?;
646 Ok(())
647 }
648 AuthPlugin::CachingSha2Password => {
649 self.continue_caching_sha2_password_auth().await?;
650 Ok(())
651 }
652 AuthPlugin::MysqlClearPassword => {
653 if self.inner.opts.enable_cleartext_plugin() {
654 self.continue_mysql_native_password_auth().await?;
655 Ok(())
656 } else {
657 Err(DriverError::CleartextPluginDisabled.into())
658 }
659 }
660 AuthPlugin::Other(ref name) => Err(DriverError::UnknownAuthPlugin {
661 name: String::from_utf8_lossy(name.as_ref()).to_string(),
662 }
663 .into()),
664 }
665 })
666 }
667
668 fn switch_to_compression(&mut self) -> Result<()> {
669 if self
670 .capabilities()
671 .contains(CapabilityFlags::CLIENT_COMPRESS)
672 {
673 if let Some(compression) = self.inner.opts.compression() {
674 if let Some(stream) = self.inner.stream.as_mut() {
675 stream.compress(compression);
676 }
677 }
678 }
679 Ok(())
680 }
681
682 async fn continue_caching_sha2_password_auth(&mut self) -> Result<()> {
683 let packet = self.read_packet().await?;
684 match packet.get(0) {
685 Some(0x00) => {
686 Ok(())
688 }
689 Some(0x01) => match packet.get(1) {
690 Some(0x03) => {
691 self.drop_packet().await
693 }
694 Some(0x04) => {
695 let pass = self.inner.opts.pass().unwrap_or_default();
696 let mut pass = crate::BUFFER_POOL.get_with(pass.as_bytes());
697 pass.as_mut().push(0);
698
699 if self.is_secure() || self.is_socket() {
700 self.write_packet(pass).await?;
701 } else {
702 if self.inner.server_key.is_none() {
703 self.write_bytes(&[0x02][..]).await?;
704 let packet = self.read_packet().await?;
705 self.inner.server_key = Some(packet[1..].to_vec());
706 }
707 for (i, byte) in pass.as_mut().iter_mut().enumerate() {
708 *byte ^= self.inner.nonce[i % self.inner.nonce.len()];
709 }
710 let encrypted_pass = crypto::encrypt(
711 &*pass,
712 self.inner.server_key.as_deref().expect("unreachable"),
713 );
714 self.write_bytes(&*encrypted_pass).await?;
715 };
716 self.drop_packet().await?;
717 Ok(())
718 }
719 _ => Err(DriverError::UnexpectedPacket {
720 payload: packet.to_vec(),
721 }
722 .into()),
723 },
724 Some(0xfe) if !self.inner.auth_switched => {
725 let auth_switch_request = ParseBuf(&*packet).parse::<AuthSwitchRequest>(())?;
726 self.perform_auth_switch(auth_switch_request).await?;
727 Ok(())
728 }
729 _ => Err(DriverError::UnexpectedPacket {
730 payload: packet.to_vec(),
731 }
732 .into()),
733 }
734 }
735
736 async fn continue_mysql_native_password_auth(&mut self) -> Result<()> {
737 let packet = self.read_packet().await?;
738 match packet.get(0) {
739 Some(0x00) => Ok(()),
740 Some(0xfe) if !self.inner.auth_switched => {
741 let auth_switch = if packet.len() > 1 {
742 ParseBuf(&*packet).parse(())?
743 } else {
744 let _ = ParseBuf(&*packet).parse::<OldAuthSwitchRequest>(())?;
745 AuthSwitchRequest::new(
747 "mysql_old_password".as_bytes(),
748 self.inner.nonce.clone(),
749 )
750 };
751 self.perform_auth_switch(auth_switch).await
752 }
753 _ => Err(DriverError::UnexpectedPacket {
754 payload: packet.to_vec(),
755 }
756 .into()),
757 }
758 }
759
760 fn handle_packet(&mut self, packet: &PooledBuf) -> Result<bool> {
762 let ok_packet = if self.has_pending_result() {
763 if self
764 .capabilities()
765 .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF)
766 {
767 ParseBuf(&*packet)
768 .parse::<OkPacketDeserializer<ResultSetTerminator>>(self.capabilities())
769 .map(|x| x.into_inner())
770 } else {
771 ParseBuf(&*packet)
772 .parse::<OkPacketDeserializer<OldEofPacket>>(self.capabilities())
773 .map(|x| x.into_inner())
774 }
775 } else {
776 ParseBuf(&*packet)
777 .parse::<OkPacketDeserializer<CommonOkPacket>>(self.capabilities())
778 .map(|x| x.into_inner())
779 };
780
781 if let Ok(ok_packet) = ok_packet {
782 self.handle_ok(ok_packet.into_owned());
783 } else {
784 let err_packet = ParseBuf(&*packet).parse::<ErrPacket>(self.capabilities());
785 if let Ok(err_packet) = err_packet {
786 self.handle_err(err_packet)?;
787 return Ok(true);
788 }
789 }
790
791 Ok(false)
792 }
793
794 pub(crate) async fn read_packet(&mut self) -> Result<PooledBuf> {
795 loop {
796 let packet = crate::io::ReadPacket::new(&mut *self)
797 .await
798 .map_err(|io_err| {
799 self.inner.stream.take();
800 self.inner.disconnected = true;
801 Error::from(io_err)
802 })?;
803 if self.handle_packet(&packet)? {
804 continue;
806 } else {
807 return Ok(packet);
808 }
809 }
810 }
811
812 pub(crate) async fn read_packets(&mut self, n: usize) -> Result<Vec<PooledBuf>> {
814 let mut packets = Vec::with_capacity(n);
815 for _ in 0..n {
816 packets.push(self.read_packet().await?);
817 }
818 Ok(packets)
819 }
820
821 pub(crate) async fn write_packet(&mut self, data: PooledBuf) -> Result<()> {
822 crate::io::WritePacket::new(&mut *self, data)
823 .await
824 .map_err(|io_err| {
825 self.inner.stream.take();
826 self.inner.disconnected = true;
827 From::from(io_err)
828 })
829 }
830
831 pub(crate) async fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> {
833 let buf = crate::BUFFER_POOL.get_with(bytes);
834 self.write_packet(buf).await
835 }
836
837 pub(crate) async fn write_struct<T: MySerialize>(&mut self, x: &T) -> Result<()> {
839 let mut buf = crate::BUFFER_POOL.get();
840 x.serialize(buf.as_mut());
841 self.write_packet(buf).await
842 }
843
844 pub(crate) async fn write_command<T: MySerialize>(&mut self, cmd: &T) -> Result<()> {
846 self.clean_dirty().await?;
847 self.reset_seq_id();
848 self.write_struct(cmd).await
849 }
850
851 pub(crate) async fn write_command_raw(&mut self, body: PooledBuf) -> Result<()> {
853 debug_assert!(!body.is_empty());
854 self.clean_dirty().await?;
855 self.reset_seq_id();
856 self.write_packet(body).await
857 }
858
859 pub(crate) async fn write_command_data<T>(&mut self, cmd: Command, cmd_data: T) -> Result<()>
861 where
862 T: AsRef<[u8]>,
863 {
864 let cmd_data = cmd_data.as_ref();
865 let mut buf = crate::BUFFER_POOL.get();
866 let body = buf.as_mut();
867 body.push(cmd as u8);
868 body.extend_from_slice(cmd_data);
869 self.write_command_raw(buf).await
870 }
871
872 async fn drop_packet(&mut self) -> Result<()> {
873 self.read_packet().await?;
874 Ok(())
875 }
876
877 async fn run_init_commands(&mut self) -> Result<()> {
878 let mut init = self.inner.opts.init().to_vec();
879
880 while let Some(query) = init.pop() {
881 self.query_drop(query).await?;
882 }
883
884 Ok(())
885 }
886
887 async fn run_setup_commands(&mut self) -> Result<()> {
888 let mut setup = self.inner.opts.setup().to_vec();
889
890 while let Some(query) = setup.pop() {
891 self.query_drop(query).await?;
892 }
893
894 Ok(())
895 }
896
897 pub fn new<T: Into<Opts>>(opts: T) -> crate::BoxFuture<'static, Conn> {
899 let opts = opts.into();
900 async move {
901 let mut conn = Conn::empty(opts.clone());
902
903 let stream = if let Some(_path) = opts.socket() {
904 #[cfg(unix)]
905 {
906 Stream::connect_socket(_path.to_owned()).await?
907 }
908 #[cfg(target_os = "windows")]
909 return Err(crate::DriverError::NamedPipesDisabled.into());
910 } else {
911 let keepalive = opts
912 .tcp_keepalive()
913 .map(|x| std::time::Duration::from_millis(x.into()));
914 Stream::connect_tcp(opts.hostport_or_url(), keepalive).await?
915 };
916
917 conn.inner.stream = Some(stream);
918 conn.setup_stream()?;
919 conn.handle_handshake().await?;
920 conn.switch_to_ssl_if_needed().await?;
921 conn.do_handshake_response().await?;
922 conn.continue_auth().await?;
923 conn.switch_to_compression()?;
924 conn.read_settings().await?;
925 conn.reconnect_via_socket_if_needed().await?;
926 conn.run_init_commands().await?;
927 conn.run_setup_commands().await?;
928
929 Ok(conn)
930 }
931 .boxed()
932 }
933
934 pub async fn from_url<T: AsRef<str>>(url: T) -> Result<Conn> {
936 Conn::new(Opts::from_str(url.as_ref())?).await
937 }
938
939 async fn reconnect_via_socket_if_needed(&mut self) -> Result<()> {
943 if let Some(socket) = self.inner.socket.as_ref() {
944 let opts = self.inner.opts.clone();
945 if opts.socket().is_none() {
946 let opts = OptsBuilder::from_opts(opts).socket(Some(&**socket));
947 if let Ok(conn) = Conn::new(opts).await {
948 let old_conn = std::mem::replace(self, conn);
949 old_conn.close_conn().await?;
951 }
952 }
953 }
954 Ok(())
955 }
956
957 async fn read_settings(&mut self) -> Result<()> {
967 enum Action {
968 Load(Cfg),
969 Apply(CfgData),
970 }
971
972 enum CfgData {
973 MaxAllowedPacket(usize),
974 WaitTimeout(usize),
975 }
976
977 impl CfgData {
978 fn apply(&self, conn: &mut Conn) {
979 match self {
980 Self::MaxAllowedPacket(value) => {
981 if let Some(stream) = conn.inner.stream.as_mut() {
982 stream.set_max_allowed_packet(*value);
983 }
984 }
985 Self::WaitTimeout(value) => {
986 conn.inner.wait_timeout = Duration::from_secs(*value as u64);
987 }
988 }
989 }
990 }
991
992 enum Cfg {
993 Socket,
994 MaxAllowedPacket,
995 WaitTimeout,
996 }
997
998 impl Cfg {
999 const fn name(&self) -> &'static str {
1000 match self {
1001 Self::Socket => "@@socket",
1002 Self::MaxAllowedPacket => "@@max_allowed_packet",
1003 Self::WaitTimeout => "@@wait_timeout",
1004 }
1005 }
1006
1007 fn apply(&self, conn: &mut Conn, value: Option<crate::Value>) {
1008 match self {
1009 Cfg::Socket => {
1010 conn.inner.socket = value.map(crate::from_value).flatten();
1011 }
1012 Cfg::MaxAllowedPacket => {
1013 if let Some(stream) = conn.inner.stream.as_mut() {
1014 stream.set_max_allowed_packet(
1015 value
1016 .map(crate::from_value)
1017 .flatten()
1018 .unwrap_or(DEFAULT_MAX_ALLOWED_PACKET),
1019 );
1020 }
1021 }
1022 Cfg::WaitTimeout => {
1023 conn.inner.wait_timeout = Duration::from_secs(
1024 value
1025 .map(crate::from_value)
1026 .flatten()
1027 .unwrap_or(DEFAULT_WAIT_TIMEOUT) as u64,
1028 );
1029 }
1030 }
1031 }
1032 }
1033
1034 let mut actions = vec![
1035 if let Some(x) = self.opts().max_allowed_packet() {
1036 Action::Apply(CfgData::MaxAllowedPacket(x))
1037 } else {
1038 Action::Load(Cfg::MaxAllowedPacket)
1039 },
1040 if let Some(x) = self.opts().wait_timeout() {
1041 Action::Apply(CfgData::WaitTimeout(x))
1042 } else {
1043 Action::Load(Cfg::WaitTimeout)
1044 },
1045 ];
1046
1047 if self.inner.opts.prefer_socket() && self.inner.socket.is_none() {
1048 actions.push(Action::Load(Cfg::Socket))
1049 }
1050
1051 let loads = actions
1052 .iter()
1053 .filter_map(|x| match x {
1054 Action::Load(x) => Some(x),
1055 Action::Apply(_) => None,
1056 })
1057 .collect::<Vec<_>>();
1058
1059 let loaded = if !loads.is_empty() {
1060 let query = loads
1061 .iter()
1062 .zip(std::iter::once(' ').chain(std::iter::repeat(',')))
1063 .fold("SELECT".to_owned(), |mut acc, (cfg, prefix)| {
1064 acc.push(prefix);
1065 acc.push_str(cfg.name());
1066 acc
1067 });
1068
1069 self.query_internal::<Row, String>(query)
1070 .await?
1071 .map(|row| row.unwrap())
1072 .unwrap_or_else(|| vec![crate::Value::NULL; loads.len()])
1073 } else {
1074 vec![]
1075 };
1076 let mut loaded = loaded.into_iter();
1077
1078 for action in actions {
1079 match action {
1080 Action::Load(cfg) => cfg.apply(self, loaded.next()),
1081 Action::Apply(cfg) => cfg.apply(self),
1082 }
1083 }
1084
1085 Ok(())
1086 }
1087
1088 fn expired(&self) -> bool {
1091 let ttl = self
1092 .inner
1093 .opts
1094 .conn_ttl()
1095 .unwrap_or(self.inner.wait_timeout);
1096 !ttl.is_zero() && self.idling() > ttl
1097 }
1098
1099 fn idling(&self) -> Duration {
1101 self.inner.last_io.elapsed()
1102 }
1103
1104 pub async fn reset(&mut self) -> Result<bool> {
1111 let supports_com_reset_connection = if self.inner.is_mariadb {
1112 self.inner.version >= (10, 2, 4)
1113 } else {
1114 self.inner.version > (5, 7, 2)
1116 };
1117
1118 if supports_com_reset_connection {
1119 self.routine(routines::ResetRoutine).await?;
1120 self.inner.stmt_cache.clear();
1121 self.inner.infile_handler = None;
1122 self.run_setup_commands().await?;
1123 }
1124
1125 Ok(supports_com_reset_connection)
1126 }
1127
1128 pub async fn change_user(&mut self, opts: ChangeUserOpts) -> Result<()> {
1140 if opts != ChangeUserOpts::default() {
1142 let mut opts_changed = false;
1143 if let Some(user) = opts.user() {
1144 opts_changed |= user != self.opts().user()
1145 };
1146 if let Some(pass) = opts.pass() {
1147 opts_changed |= pass != self.opts().pass()
1148 };
1149 if let Some(db_name) = opts.db_name() {
1150 opts_changed |= db_name != self.opts().db_name()
1151 };
1152 if opts_changed {
1153 if let Some(pool) = self.inner.pool.take() {
1154 pool.cancel_connection();
1155 }
1156 }
1157 }
1158
1159 let conn_opts = &mut self.inner.opts;
1160 opts.update_opts(conn_opts);
1161 self.routine(routines::ChangeUser).await?;
1162 self.inner.stmt_cache.clear();
1163 self.inner.infile_handler = None;
1164 self.run_setup_commands().await?;
1165 Ok(())
1166 }
1167
1168 async fn reset_for_pool(mut self) -> Result<Self> {
1172 if !self.reset().await? {
1173 self.change_user(Default::default()).await?;
1174 }
1175 Ok(self)
1176 }
1177
1178 async fn rollback_transaction(&mut self) -> Result<()> {
1180 debug_assert_ne!(self.inner.tx_status, TxStatus::None);
1181 self.inner.tx_status = TxStatus::None;
1182 self.query_drop("ROLLBACK").await
1183 }
1184
1185 pub(crate) fn more_results_exists(&self) -> bool {
1188 self.status()
1189 .contains(StatusFlags::SERVER_MORE_RESULTS_EXISTS)
1190 }
1191
1192 pub(crate) async fn drop_result(&mut self) -> Result<()> {
1197 let meta = match self.set_pending_result(None)? {
1199 Some(PendingResult::Pending(meta)) => Some(meta),
1200 Some(PendingResult::Taken(meta)) => {
1201 Some(Arc::try_unwrap(meta).expect("Conn::drop_result call on a pending result that may still be droped by someone else"))
1204 }
1205 None => None,
1206 };
1207
1208 let _ = self.set_pending_result(meta);
1209
1210 match self.use_pending_result() {
1211 Ok(Some(PendingResult::Pending(ResultSetMeta::Text(_)))) => {
1212 QueryResult::<'_, '_, TextProtocol>::new(self)
1213 .drop_result()
1214 .await
1215 }
1216 Ok(Some(PendingResult::Pending(ResultSetMeta::Binary(_)))) => {
1217 QueryResult::<'_, '_, BinaryProtocol>::new(self)
1218 .drop_result()
1219 .await
1220 }
1221 Ok(None) => Ok(()),
1222 Ok(Some(PendingResult::Taken(_))) | Err(_) => {
1223 unreachable!("this case must be handled earlier in this function")
1224 }
1225 }
1226 }
1227
1228 async fn cleanup_for_pool(mut self) -> Result<Self> {
1232 loop {
1233 let result = if self.has_pending_result() {
1234 self.drop_result().await
1235 } else if self.inner.tx_status != TxStatus::None {
1236 self.rollback_transaction().await
1237 } else {
1238 break;
1239 };
1240
1241 if let Err(err) = result {
1245 if err.is_fatal() {
1246 return Err(err);
1249 }
1250 }
1251 }
1252 Ok(self)
1253 }
1254
1255 async fn register_as_slave(&mut self, server_id: u32) -> Result<()> {
1256 use mysql_common::packets::ComRegisterSlave;
1257
1258 self.query_drop("SET @master_binlog_checksum='ALL'").await?;
1259 self.write_command(&ComRegisterSlave::new(server_id))
1260 .await?;
1261
1262 self.read_packet().await?;
1264
1265 Ok(())
1266 }
1267
1268 async fn request_binlog(&mut self, request: BinlogRequest<'_>) -> Result<()> {
1269 self.register_as_slave(request.server_id()).await?;
1270 self.write_command(&request.as_cmd()).await?;
1271 Ok(())
1272 }
1273
1274 pub async fn get_binlog_stream(mut self, request: BinlogRequest<'_>) -> Result<BinlogStream> {
1275 self.request_binlog(request).await?;
1276
1277 Ok(BinlogStream::new(self))
1278 }
1279}
1280
1281#[cfg(test)]
1282mod test {
1283 use bytes::Bytes;
1284 use futures_util::stream::{self, StreamExt};
1285 use mysql_common::{binlog::events::EventData, constants::MAX_PAYLOAD_LEN};
1286 use rand::Fill;
1287 use tokio::time::timeout;
1288
1289 use std::time::Duration;
1290
1291 use crate::{
1292 from_row, params, prelude::*, test_misc::get_opts, BinlogDumpFlags, BinlogRequest,
1293 ChangeUserOpts, Conn, Error, OptsBuilder, Pool, Value, WhiteListFsHandler,
1294 };
1295
1296 async fn gen_dummy_data() -> super::Result<()> {
1297 let mut conn = Conn::new(get_opts()).await?;
1298
1299 "CREATE TABLE IF NOT EXISTS customers (customer_id int not null)"
1300 .ignore(&mut conn)
1301 .await?;
1302
1303 for i in 0_u8..100 {
1304 "INSERT INTO customers(customer_id) VALUES (?)"
1305 .with((i,))
1306 .ignore(&mut conn)
1307 .await?;
1308 }
1309
1310 "DROP TABLE customers".ignore(&mut conn).await?;
1311
1312 Ok(())
1313 }
1314
1315 async fn create_binlog_stream_conn(pool: Option<&Pool>) -> super::Result<(Conn, Vec<u8>, u64)> {
1316 let mut conn = match pool {
1317 None => Conn::new(get_opts()).await.unwrap(),
1318 Some(pool) => pool.get_conn().await.unwrap(),
1319 };
1320
1321 if let Ok(Some(gtid_mode)) = "SELECT @@GLOBAL.GTID_MODE"
1322 .first::<String, _>(&mut conn)
1323 .await
1324 {
1325 if !gtid_mode.starts_with("ON") {
1326 panic!(
1327 "GTID_MODE is disabled \
1328 (enable using --gtid_mode=ON --enforce_gtid_consistency=ON)"
1329 );
1330 }
1331 }
1332
1333 let row: crate::Row = "SHOW BINARY LOGS".first(&mut conn).await.unwrap().unwrap();
1334 let filename = row.get(0).unwrap();
1335 let position = row.get(1).unwrap();
1336
1337 gen_dummy_data().await.unwrap();
1338 Ok((conn, filename, position))
1339 }
1340
1341 #[tokio::test]
1342 async fn should_read_binlog() -> super::Result<()> {
1343 read_binlog_streams_and_close_their_connections(None, (12, 13, 14))
1344 .await
1345 .unwrap();
1346
1347 let pool = Pool::new(get_opts());
1348 read_binlog_streams_and_close_their_connections(Some(&pool), (15, 16, 17))
1349 .await
1350 .unwrap();
1351
1352 timeout(Duration::from_secs(10), pool.disconnect())
1355 .await
1356 .unwrap()
1357 .unwrap();
1358
1359 Ok(())
1360 }
1361
    /// With `client_found_rows(true)`, an `UPDATE` that matches a row without changing it
    /// must still report `affected_rows() == 1` (the matched row is counted).
    #[tokio::test]
    async fn should_return_found_rows_if_flag_is_set() -> super::Result<()> {
        let opts = get_opts().client_found_rows(true);
        let mut conn = Conn::new(opts).await.unwrap();

        "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)"
            .ignore(&mut conn)
            .await?;

        "INSERT INTO mysql.found_rows (val) VALUES (1)"
            .ignore(&mut conn)
            .await?;

        // A real insert of one row.
        assert_eq!(conn.affected_rows(), 1);

        // Sets `val` to the value it already has: the row matches but nothing changes.
        "UPDATE mysql.found_rows SET val = 1 WHERE val = 1"
            .ignore(&mut conn)
            .await?;

        // The flag is set, so the matched row is reported.
        assert_eq!(conn.affected_rows(), 1);

        Ok(())
    }
1388
    /// Without `client_found_rows`, an `UPDATE` that matches a row without changing it
    /// must report `affected_rows() == 0`.
    #[tokio::test]
    async fn should_not_return_found_rows_if_flag_is_not_set() -> super::Result<()> {
        let mut conn = Conn::new(get_opts()).await.unwrap();

        "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)"
            .ignore(&mut conn)
            .await?;

        "INSERT INTO mysql.found_rows (val) VALUES (1)"
            .ignore(&mut conn)
            .await?;

        // A real insert of one row.
        assert_eq!(conn.affected_rows(), 1);

        // Sets `val` to the value it already has: the row matches but nothing changes.
        "UPDATE mysql.found_rows SET val = 1 WHERE val = 1"
            .ignore(&mut conn)
            .await?;

        // No flag, so only changed rows are reported.
        assert_eq!(conn.affected_rows(), 0);

        Ok(())
    }
1413
1414 async fn read_binlog_streams_and_close_their_connections(
1415 pool: Option<&Pool>,
1416 binlog_server_ids: (u32, u32, u32),
1417 ) -> super::Result<()> {
1418 let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap();
1420 let is_mariadb = conn.inner.is_mariadb;
1421
1422 let mut binlog_stream = conn
1423 .get_binlog_stream(
1424 BinlogRequest::new(binlog_server_ids.0)
1425 .with_filename(filename)
1426 .with_pos(pos),
1427 )
1428 .await
1429 .unwrap();
1430
1431 let mut events_num = 0;
1432 while let Ok(Some(event)) = timeout(Duration::from_secs(10), binlog_stream.next()).await {
1433 let event = event.unwrap();
1434 events_num += 1;
1435
1436 event.header().event_type().unwrap();
1438
1439 match event.read_data()?.unwrap() {
1441 EventData::RowsEvent(re) => {
1442 let tme = binlog_stream.get_tme(re.table_id());
1443 for row in re.rows(tme.unwrap()) {
1444 row.unwrap();
1445 }
1446 }
1447 _ => (),
1448 }
1449 }
1450 assert!(events_num > 0);
1451 timeout(Duration::from_secs(10), binlog_stream.close())
1452 .await
1453 .unwrap()
1454 .unwrap();
1455
1456 if !is_mariadb {
1457 let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap();
1459
1460 let mut binlog_stream = conn
1461 .get_binlog_stream(
1462 BinlogRequest::new(binlog_server_ids.1)
1463 .with_use_gtid(true)
1464 .with_filename(filename)
1465 .with_pos(pos),
1466 )
1467 .await
1468 .unwrap();
1469
1470 events_num = 0;
1471 while let Ok(Some(event)) = timeout(Duration::from_secs(10), binlog_stream.next()).await
1472 {
1473 let event = event.unwrap();
1474 events_num += 1;
1475
1476 event.header().event_type().unwrap();
1478
1479 match event.read_data()?.unwrap() {
1481 EventData::RowsEvent(re) => {
1482 let tme = binlog_stream.get_tme(re.table_id());
1483 for row in re.rows(tme.unwrap()) {
1484 row.unwrap();
1485 }
1486 }
1487 _ => (),
1488 }
1489 }
1490 assert!(events_num > 0);
1491 timeout(Duration::from_secs(10), binlog_stream.close())
1492 .await
1493 .unwrap()
1494 .unwrap();
1495 }
1496
1497 let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap();
1499
1500 let mut binlog_stream = conn
1501 .get_binlog_stream(
1502 BinlogRequest::new(binlog_server_ids.2)
1503 .with_filename(filename)
1504 .with_pos(pos)
1505 .with_flags(BinlogDumpFlags::BINLOG_DUMP_NON_BLOCK),
1506 )
1507 .await
1508 .unwrap();
1509
1510 events_num = 0;
1511 while let Some(event) = binlog_stream.next().await {
1512 let event = event.unwrap();
1513 events_num += 1;
1514 event.header().event_type().unwrap();
1515 event.read_data().unwrap();
1516 }
1517 assert!(events_num > 0);
1518 timeout(Duration::from_secs(10), binlog_stream.close())
1519 .await
1520 .unwrap()
1521 .unwrap();
1522
1523 Ok(())
1524 }
1525
1526 #[test]
1527 fn opts_should_satisfy_send_and_sync() {
1528 struct A<T: Sync + Send>(T);
1529 A(get_opts());
1530 }
1531
    /// Connecting must work both without a database (`db_name(None)`) and with an
    /// empty database name (`db_name(Some(""))`).
    #[tokio::test]
    async fn should_connect_without_database() -> super::Result<()> {
        // No database at all.
        let mut conn: Conn = Conn::new(get_opts().db_name(None::<String>)).await?;
        conn.ping().await?;
        conn.disconnect().await?;

        // Empty database name.
        let mut conn: Conn = Conn::new(get_opts().db_name(Some(""))).await?;
        conn.ping().await?;
        conn.disconnect().await?;

        Ok(())
    }
1546
    /// Dropping a result or a transaction wrapper without consuming or committing it
    /// must leave the connection usable, and the dropped transaction's insert must be
    /// discarded.
    #[tokio::test]
    async fn should_clean_state_if_wrapper_is_dropeed() -> super::Result<()> {
        let mut conn: Conn = Conn::new(get_opts()).await?;

        conn.query_drop("CREATE TEMPORARY TABLE mysql.foo (id SERIAL)")
            .await?;

        // Result set is dropped unread; the next command must still succeed.
        conn.query_iter("SELECT 1").await?;
        conn.ping().await?;

        // Transaction dropped with an unread result and no commit.
        let mut tx = conn.start_transaction(Default::default()).await?;
        tx.query_drop("INSERT INTO mysql.foo (id) VALUES (42)")
            .await?;
        tx.exec_iter("SELECT COUNT(*) FROM mysql.foo", ()).await?;
        drop(tx);
        conn.ping().await?;

        // The insert made inside the dropped transaction must not be visible.
        let count: u8 = conn
            .query_first("SELECT COUNT(*) FROM mysql.foo")
            .await?
            .unwrap_or_default();

        assert_eq!(count, 0);

        Ok(())
    }
1575
1576 #[tokio::test]
1577 async fn should_connect() -> super::Result<()> {
1578 let mut conn: Conn = Conn::new(get_opts()).await?;
1579 conn.ping().await?;
1580 let plugins: Vec<String> = conn
1581 .query_map("SHOW PLUGINS", |mut row: crate::Row| {
1582 row.take("Name").unwrap()
1583 })
1584 .await?;
1585
1586 let variants = vec![
1588 ("caching_sha2_password", 2_u8, "non-empty"),
1589 ("caching_sha2_password", 2_u8, ""),
1590 ("mysql_native_password", 0_u8, "non-empty"),
1591 ("mysql_native_password", 0_u8, ""),
1592 ]
1593 .into_iter()
1594 .filter(|variant| plugins.iter().any(|p| p == variant.0));
1595
1596 for (plug, val, pass) in variants {
1597 let _ = conn.query_drop("DROP USER 'test_user'@'%'").await;
1598
1599 let query = format!("CREATE USER 'test_user'@'%' IDENTIFIED WITH {}", plug);
1600 conn.query_drop(query).await.unwrap();
1601
1602 if (8, 0, 11) <= conn.inner.version && conn.inner.version <= (9, 0, 0) {
1603 conn.query_drop(format!("SET PASSWORD FOR 'test_user'@'%' = '{}'", pass))
1604 .await
1605 .unwrap();
1606 } else {
1607 conn.query_drop(format!("SET old_passwords = {}", val))
1608 .await
1609 .unwrap();
1610 conn.query_drop(format!(
1611 "SET PASSWORD FOR 'test_user'@'%' = PASSWORD('{}')",
1612 pass
1613 ))
1614 .await
1615 .unwrap();
1616 };
1617
1618 let opts = get_opts()
1619 .user(Some("test_user"))
1620 .pass(Some(pass))
1621 .db_name(None::<String>);
1622 let result = Conn::new(opts).await;
1623
1624 conn.query_drop("DROP USER 'test_user'@'%'").await.unwrap();
1625
1626 result?.disconnect().await?;
1627 }
1628
1629 if crate::test_misc::test_compression() {
1630 assert!(format!("{:?}", conn).contains("Compression"));
1631 }
1632
1633 if crate::test_misc::test_ssl() {
1634 assert!(format!("{:?}", conn).contains("Tls"));
1635 }
1636
1637 conn.disconnect().await?;
1638 Ok(())
1639 }
1640
1641 #[test]
1642 fn should_not_panic_if_dropped_without_tokio_runtime() {
1643 let fut = Conn::new(get_opts());
1644 let runtime = tokio::runtime::Runtime::new().unwrap();
1645 runtime.block_on(async {
1646 fut.await.unwrap();
1647 });
1648 }
1650
    /// Queries passed to `init` must have run by the time `Conn::new` returns: the
    /// user variables they set are visible on the new connection.
    #[tokio::test]
    async fn should_execute_init_queries_on_new_connection() -> super::Result<()> {
        let opts = OptsBuilder::from_opts(get_opts()).init(vec!["SET @a = 42", "SET @b = 'foo'"]);
        let mut conn = Conn::new(opts).await?;
        let result: Vec<(u8, String)> = conn.query("SELECT @a, @b").await?;
        conn.disconnect().await?;
        assert_eq!(result, vec![(42, "foo".into())]);
        Ok(())
    }
1660
    /// Queries passed to `setup` must take effect in three cases: on a new connection,
    /// again after `reset()` (when it reports that a reset happened), and after
    /// `change_user`.
    #[tokio::test]
    async fn should_execute_setup_queries_on_reset() -> super::Result<()> {
        let opts = OptsBuilder::from_opts(get_opts()).setup(vec!["SET @a = 42", "SET @b = 'foo'"]);
        let mut conn = Conn::new(opts).await?;

        // Fresh connection.
        let mut result: Vec<(u8, String)> = conn.query("SELECT @a, @b").await?;
        assert_eq!(result, vec![(42, "foo".into())]);

        // After a reset (only checked when `reset` reports `true`).
        if conn.reset().await? {
            result = conn.query("SELECT @a, @b").await?;
            assert_eq!(result, vec![(42, "foo".into())]);
        }

        // After re-authenticating with default options.
        conn.change_user(Default::default()).await?;
        result = conn.query("SELECT @a, @b").await?;
        assert_eq!(result, vec![(42, "foo".into())]);

        conn.disconnect().await?;
        Ok(())
    }
1684
    /// A user variable set on the session must be cleared by `reset()` when it reports
    /// `true`, and must survive when it reports `false`.
    #[tokio::test]
    async fn should_reset_the_connection() -> super::Result<()> {
        let mut conn = Conn::new(get_opts()).await?;

        // Unset on a fresh session.
        assert_eq!(
            conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
            Value::NULL
        );

        conn.query_drop("SET @foo = 'foo'").await?;

        assert_eq!(
            conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
            "foo",
        );

        if conn.reset().await? {
            // Session state was reset.
            assert_eq!(
                conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
                Value::NULL
            );
        } else {
            // No reset happened; the variable must be untouched.
            assert_eq!(
                conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
                "foo",
            );
        }

        conn.disconnect().await?;
        Ok(())
    }
1716
    /// `change_user` with default options must clear session state. `change_user` to a
    /// freshly created `__mats` user must also work for each authentication plugin the
    /// server version supports, and leave the session with no default database.
    #[tokio::test]
    async fn should_change_user() -> super::Result<()> {
        let mut conn = Conn::new(get_opts()).await?;
        assert_eq!(
            conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
            Value::NULL
        );

        conn.query_drop("SET @foo = 'foo'").await?;

        assert_eq!(
            conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
            "foo",
        );

        // Re-authenticating as the same user must drop user variables.
        conn.change_user(Default::default()).await?;
        assert_eq!(
            conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
            Value::NULL
        );

        // There is no MySQL 5.8 release, so `>= (5, 8, 0)` effectively selects 8.0+.
        let plugins: &[&str] = if !conn.inner.is_mariadb && conn.server_version() >= (5, 8, 0) {
            &["mysql_native_password", "caching_sha2_password"]
        } else {
            &["mysql_native_password"]
        };

        for plugin in plugins {
            // 10 random lowercase ASCII letters: each byte is mapped into 97..=122.
            let mut rng = rand::thread_rng();
            let mut pass = [0u8; 10];
            pass.try_fill(&mut rng).unwrap();
            let pass: String = IntoIterator::into_iter(pass)
                .map(|x| ((x % (123 - 97)) + 97) as char)
                .collect();

            // Remove any leftover user from a previous iteration or run.
            conn.query_drop("DELETE FROM mysql.user WHERE user = '__mats'")
                .await
                .unwrap();
            conn.query_drop("FLUSH PRIVILEGES").await.unwrap();

            if conn.inner.is_mariadb || conn.server_version() < (5, 7, 0) {
                if matches!(conn.server_version(), (5, 6, _)) {
                    // On 5.6 the user is created with `mysql_old_password`, whatever
                    // `plugin` is.
                    conn.query_drop("CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_old_password")
                        .await
                        .unwrap();
                    conn.query_drop(format!(
                        "SET PASSWORD FOR '__mats'@'%' = OLD_PASSWORD({})",
                        Value::from(pass.clone()).as_sql(false)
                    ))
                    .await
                    .unwrap();
                } else {
                    conn.query_drop("CREATE USER '__mats'@'%'").await.unwrap();
                    conn.query_drop(format!(
                        "SET PASSWORD FOR '__mats'@'%' = PASSWORD({})",
                        Value::from(pass.clone()).as_sql(false)
                    ))
                    .await
                    .unwrap();
                }
            } else {
                conn.query_drop(format!(
                    "CREATE USER '__mats'@'%' IDENTIFIED WITH {} BY {}",
                    plugin,
                    Value::from(pass.clone()).as_sql(false)
                ))
                .await
                .unwrap();
            };

            conn.query_drop("FLUSH PRIVILEGES").await.unwrap();

            // NOTE(review): `secure_auth(false)` is presumably required so the
            // `mysql_old_password` user created on 5.6 can log in — confirm.
            let mut conn2 = Conn::new(get_opts().secure_auth(false)).await.unwrap();
            conn2
                .change_user(
                    ChangeUserOpts::default()
                        .with_db_name(None)
                        .with_user(Some("__mats".into()))
                        .with_pass(Some(pass)),
                )
                .await
                .unwrap();
            let (db, user) = conn2
                .query_first::<(Option<String>, String), _>("SELECT DATABASE(), USER();")
                .await
                .unwrap()
                .unwrap();
            assert_eq!(db, None);
            assert!(user.starts_with("__mats"));

            conn2.disconnect().await.unwrap();
        }

        conn.disconnect().await?;
        Ok(())
    }
1813
    /// With `stmt_cache_size(0)`, no statement may remain in the connection's statement
    /// cache after prepared statements are used in several ways: ad-hoc `exec_drop`,
    /// explicit `prep`/`close`, `exec_batch` and `exec_first`. The session's
    /// `Com_stmt_close` counter is expected to be exactly 1.
    #[tokio::test]
    async fn should_not_cache_statements_if_stmt_cache_size_is_zero() -> super::Result<()> {
        let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(0);

        let mut conn = Conn::new(opts).await?;
        conn.exec_drop("DO ?", (1_u8,)).await?;

        // Explicitly prepared statement, executed twice and closed by hand.
        let stmt = conn.prep("DO 2").await?;
        conn.exec_drop(&stmt, ()).await?;
        conn.exec_drop(&stmt, ()).await?;
        conn.close(stmt).await?;

        conn.exec_drop("DO 3", ()).await?;
        conn.exec_batch("DO 4", vec![(), ()]).await?;
        conn.exec_first::<u8, _, _>("DO 5", ()).await?;
        let row: Option<(crate::Value, usize)> = conn
            .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
            .await?;

        assert_eq!(row.unwrap().1, 1);
        assert_eq!(conn.inner.stmt_cache.len(), 0);

        conn.disconnect().await?;

        Ok(())
    }
1840
1841 #[tokio::test]
1842 async fn should_hold_stmt_cache_size_bound() -> super::Result<()> {
1843 let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(3);
1844 let mut conn = Conn::new(opts).await?;
1845 conn.exec_drop("DO 1", ()).await?;
1846 conn.exec_drop("DO 2", ()).await?;
1847 conn.exec_drop("DO 3", ()).await?;
1848 conn.exec_drop("DO 1", ()).await?;
1849 conn.exec_drop("DO 4", ()).await?;
1850 conn.exec_drop("DO 3", ()).await?;
1851 conn.exec_drop("DO 5", ()).await?;
1852 conn.exec_drop("DO 6", ()).await?;
1853 let row_opt = conn
1854 .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
1855 .await?;
1856 let (_, count): (String, usize) = row_opt.unwrap();
1857 assert_eq!(count, 3);
1858 let order = conn
1859 .stmt_cache_ref()
1860 .iter()
1861 .map(|item| item.1.query.0.as_ref())
1862 .collect::<Vec<&[u8]>>();
1863 assert_eq!(order, &[b"DO 6", b"DO 5", b"DO 3"]);
1864 conn.disconnect().await?;
1865 Ok(())
1866 }
1867
1868 #[tokio::test]
1869 async fn should_perform_queries() -> super::Result<()> {
1870 let mut conn = Conn::new(get_opts()).await?;
1871 for x in (MAX_PAYLOAD_LEN - 2)..=(MAX_PAYLOAD_LEN + 2) {
1872 let long_string = ::std::iter::repeat('A').take(x).collect::<String>();
1873 let result: Vec<(String, u8)> = conn
1874 .query(format!(r"SELECT '{}', 231", long_string))
1875 .await?;
1876 assert_eq!((long_string, 231_u8), result[0]);
1877 }
1878 conn.disconnect().await?;
1879 Ok(())
1880 }
1881
1882 #[tokio::test]
1883 async fn should_query_drop() -> super::Result<()> {
1884 let mut conn = Conn::new(get_opts()).await?;
1885 conn.query_drop("CREATE TEMPORARY TABLE tmp (id int DEFAULT 10, name text)")
1886 .await?;
1887 conn.query_drop("INSERT INTO tmp VALUES (1, 'foo')").await?;
1888 let result: Option<u8> = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1889 conn.disconnect().await?;
1890 assert_eq!(result, Some(1_u8));
1891 Ok(())
1892 }
1893
    /// Prepares and closes statements with positional and named parameters. The same
    /// query is also prepared on a second, independent connection while the first one
    /// still holds an open statement.
    #[tokio::test]
    async fn should_prepare_statement() -> super::Result<()> {
        let mut conn = Conn::new(get_opts()).await?;
        let stmt = conn.prep(r"SELECT ?").await?;
        conn.close(stmt).await?;
        conn.disconnect().await?;

        // Named parameter; this statement stays open until the end of the test.
        let mut conn = Conn::new(get_opts()).await?;
        let stmt = conn.prep(r"SELECT :foo").await?;

        {
            // Prepare from a borrowed `String`.
            let query = String::from("SELECT ?, ?");
            let stmt = conn.prep(&*query).await?;
            conn.close(stmt).await?;
            {
                // Same query on a separate connection.
                let mut conn = Conn::new(get_opts()).await?;
                let stmt = conn.prep(&*query).await?;
                conn.close(stmt).await?;
                conn.disconnect().await?;
            }
        }

        conn.close(stmt).await?;
        conn.disconnect().await?;

        Ok(())
    }
1921
    /// Executes prepared statements through `exec_iter` and consumes the results with
    /// `map_and_drop`, `collect_and_drop` and `reduce_and_drop`. This is done once with
    /// a positional parameter (including an 18 MiB string argument, larger than the
    /// 16 MiB maximum packet payload) and once with repeated named parameters.
    #[tokio::test]
    async fn should_execute_statement() -> super::Result<()> {
        let long_string = ::std::iter::repeat('A')
            .take(18 * 1024 * 1024)
            .collect::<String>();
        let mut conn = Conn::new(get_opts()).await?;
        let stmt = conn.prep(r"SELECT ?").await?;
        let result = conn.exec_iter(&stmt, (&long_string,)).await?;
        let mut mapped = result
            .map_and_drop(|row| from_row::<(String,)>(row))
            .await?;
        assert_eq!(mapped.len(), 1);
        assert_eq!(mapped.pop(), Some((long_string,)));
        let result = conn.exec_iter(&stmt, (42_u8,)).await?;
        let collected = result.collect_and_drop::<(u8,)>().await?;
        assert_eq!(collected, vec![(42u8,)]);
        // Initial accumulator 2 plus the single row value 8.
        let result = conn.exec_iter(&stmt, (8_u8,)).await?;
        let reduced = result
            .reduce_and_drop(2, |mut acc, row| {
                acc += from_row::<i32>(row);
                acc
            })
            .await?;
        conn.close(stmt).await?;
        conn.disconnect().await?;
        assert_eq!(reduced, 10);

        // `:foo` appears twice and must be bound to the same value both times.
        let mut conn = Conn::new(get_opts()).await?;
        let stmt = conn.prep(r"SELECT :foo, :bar, :foo, 3").await?;
        let result = conn
            .exec_iter(&stmt, params! { "foo" => "quux", "bar" => "baz" })
            .await?;
        let mut mapped = result
            .map_and_drop(|row| from_row::<(String, String, String, u8)>(row))
            .await?;
        assert_eq!(mapped.len(), 1);
        assert_eq!(
            mapped.pop(),
            Some(("quux".into(), "baz".into(), "quux".into(), 3))
        );
        let result = conn
            .exec_iter(&stmt, params! { "foo" => 2, "bar" => 3 })
            .await?;
        let collected = result.collect_and_drop::<(u8, u8, u8, u8)>().await?;
        assert_eq!(collected, vec![(2, 3, 2, 3)]);
        // 0 + 2 + 3 + 2 + 3.
        let result = conn
            .exec_iter(&stmt, params! { "foo" => 2, "bar" => 3 })
            .await?;
        let reduced = result
            .reduce_and_drop(0, |acc, row| {
                let (a, b, c, d): (u8, u8, u8, u8) = from_row(row);
                acc + a + b + c + d
            })
            .await?;
        conn.close(stmt).await?;
        conn.disconnect().await?;
        assert_eq!(reduced, 10);
        Ok(())
    }
1981
1982 #[tokio::test]
1983 async fn should_prep_exec_statement() -> super::Result<()> {
1984 let mut conn = Conn::new(get_opts()).await?;
1985 let result = conn
1986 .exec_iter(r"SELECT :a, :b, :a", params! { "a" => 2, "b" => 3 })
1987 .await?;
1988 let output = result
1989 .map_and_drop(|row| {
1990 let (a, b, c): (u8, u8, u8) = from_row(row);
1991 a * b * c
1992 })
1993 .await?;
1994 conn.disconnect().await?;
1995 assert_eq!(output[0], 12u8);
1996 Ok(())
1997 }
1998
1999 #[tokio::test]
2000 async fn should_first_exec_statement() -> super::Result<()> {
2001 let mut conn = Conn::new(get_opts()).await?;
2002 let output = conn
2003 .exec_first(
2004 r"SELECT :a UNION ALL SELECT :b",
2005 params! { "a" => 2, "b" => 3 },
2006 )
2007 .await?;
2008 conn.disconnect().await?;
2009 assert_eq!(output, Some(2u8));
2010 Ok(())
2011 }
2012
    /// Regression test named after issue 107. Rows that mix `VARBINARY`/`BINARY` columns
    /// (including empty values) with `BIGINT UNSIGNED` must map into
    /// `(Vec<u8>, Vec<u8>, u64, Vec<u8>)`.
    #[tokio::test]
    async fn issue_107() -> super::Result<()> {
        let mut conn = Conn::new(get_opts()).await?;
        conn.query_drop(
            r"CREATE TEMPORARY TABLE mysql.issue (
                    a BIGINT(20) UNSIGNED,
                    b VARBINARY(16),
                    c BINARY(32),
                    d BIGINT(20) UNSIGNED,
                    e BINARY(32)
                )",
        )
        .await?;
        conn.query_drop(
            r"INSERT INTO mysql.issue VALUES (
                    0,
                    0xC066F966B0860000,
                    0x7939DA98E524C5F969FC2DE8D905FD9501EBC6F20001B0A9C941E0BE6D50CF44,
                    0,
                    ''
                ), (
                    1,
                    '',
                    0x076311DF4D407B0854371BA13A5F3FB1A4555AC22B361375FD47B263F31822F2,
                    0,
                    ''
                )",
        )
        .await?;

        let q = "SELECT b, c, d, e FROM mysql.issue";
        let result = conn.query_iter(q).await?;

        // Mapping must succeed for both rows.
        let loaded_structs = result
            .map_and_drop(|row| crate::from_row::<(Vec<u8>, Vec<u8>, u64, Vec<u8>)>(row))
            .await?;

        conn.disconnect().await?;

        assert_eq!(loaded_structs.len(), 2);

        Ok(())
    }
2056
2057 #[tokio::test]
2058 async fn should_run_transactions() -> super::Result<()> {
2059 let mut conn = Conn::new(get_opts()).await?;
2060 conn.query_drop("CREATE TEMPORARY TABLE tmp (id INT, name TEXT)")
2061 .await?;
2062 let mut transaction = conn.start_transaction(Default::default()).await?;
2063 transaction
2064 .query_drop("INSERT INTO tmp VALUES (1, 'foo'), (2, 'bar')")
2065 .await?;
2066 assert_eq!(transaction.last_insert_id(), None);
2067 assert_eq!(transaction.affected_rows(), 2);
2068 assert_eq!(transaction.get_warnings(), 0);
2069 assert_eq!(transaction.info(), "Records: 2 Duplicates: 0 Warnings: 0");
2070 transaction.commit().await?;
2071 let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
2072 assert_eq!(output_opt, Some((2u8,)));
2073 let mut transaction = conn.start_transaction(Default::default()).await?;
2074 transaction
2075 .query_drop("INSERT INTO tmp VALUES (3, 'baz'), (4, 'quux')")
2076 .await?;
2077 let output_opt = transaction
2078 .exec_first("SELECT COUNT(*) FROM tmp", ())
2079 .await?;
2080 assert_eq!(output_opt, Some((4u8,)));
2081 transaction.rollback().await?;
2082 let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
2083 assert_eq!(output_opt, Some((2u8,)));
2084
2085 let mut transaction = conn.start_transaction(Default::default()).await?;
2086 transaction
2087 .query_drop("INSERT INTO tmp VALUES (3, 'baz')")
2088 .await?;
2089 drop(transaction); let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
2091 assert_eq!(output_opt, Some((2u8,)));
2092
2093 conn.disconnect().await?;
2094 Ok(())
2095 }
2096
    /// Text-protocol multi-statement queries that fail on a missing table `tmp`. This
    /// covers a failure in the first result set (the error surfaces from `run` itself)
    /// and a failure in a middle result set (the error surfaces when that set is
    /// collected). In both cases the connection must remain usable.
    #[tokio::test]
    async fn should_handle_multiresult_set_with_error() -> super::Result<()> {
        const QUERY_FIRST: &str = "SELECT * FROM tmp; SELECT 1; SELECT 2;";
        const QUERY_MIDDLE: &str = "SELECT 1; SELECT * FROM tmp; SELECT 2";
        let mut conn = Conn::new(get_opts()).await.unwrap();

        // Error in the very first statement.
        let result = QUERY_FIRST.run(&mut conn).await;
        assert!(matches!(result, Err(Error::Server(_))));

        let mut result = QUERY_MIDDLE.run(&mut conn).await.unwrap();

        // First result set is fine.
        let result_set: Vec<u8> = result.collect().await.unwrap();
        assert_eq!(result_set, vec![1]);

        // Second statement fails.
        let result_set: super::Result<Vec<u8>> = result.collect().await;
        assert!(matches!(result_set, Err(Error::Server(_))));

        // Statements after the error are not executed.
        assert!(result.is_empty());

        conn.ping().await?;
        conn.disconnect().await?;

        Ok(())
    }
2125
    /// Same scenario as the text-protocol multi-result error test, but driven through
    /// stored procedures that return several result sets. One procedure fails in its
    /// first result set, the other in a middle one.
    #[tokio::test]
    async fn should_handle_binary_multiresult_set_with_error() -> super::Result<()> {
        const PROC_DEF_FIRST: &str =
            r#"CREATE PROCEDURE err_first() BEGIN SELECT * FROM tmp; SELECT 1; END"#;
        const PROC_DEF_MIDDLE: &str =
            r#"CREATE PROCEDURE err_middle() BEGIN SELECT 1; SELECT * FROM tmp; SELECT 2; END"#;

        let mut conn = Conn::new(get_opts()).await.unwrap();

        // (Re)create both procedures.
        conn.query_drop("DROP PROCEDURE IF EXISTS err_first")
            .await?;
        conn.query_iter(PROC_DEF_FIRST).await?;

        conn.query_drop("DROP PROCEDURE IF EXISTS err_middle")
            .await?;
        conn.query_iter(PROC_DEF_MIDDLE).await?;

        // Error in the first result set.
        let result = conn.query_iter("CALL err_first()").await;
        assert!(matches!(result, Err(Error::Server(_))));

        let mut result = conn.query_iter("CALL err_middle()").await?;

        let result_set: Vec<u8> = result.collect().await.unwrap();
        assert_eq!(result_set, vec![1]);

        // Error in the second result set.
        let result_set: super::Result<Vec<u8>> = result.collect().await;
        assert!(matches!(result_set, Err(Error::Server(_))));

        assert!(result.is_empty());

        conn.ping().await?;
        conn.disconnect().await?;

        Ok(())
    }
2165
    /// Multi-statement query that interleaves `SELECT`s with two `LOAD DATA LOCAL INFILE`
    /// statements. Each load uses a whitelisted three-line temp file, so after both
    /// loads the final `SELECT` returns six rows. If the server rejects local infile
    /// (error 1148 or 3948), the remaining result sets must be gone.
    #[tokio::test]
    async fn should_handle_multiresult_set_with_local_infile() -> super::Result<()> {
        use std::fs::write;

        // The temp file is created in the current directory, so its bare file name is a
        // valid relative path. `file_path` (the guard) is shadowed but stays alive until
        // the end of the test.
        let file_path = tempfile::Builder::new().tempfile_in("").unwrap();
        let file_path = file_path.path();
        let file_name = file_path.file_name().unwrap();

        write(file_name, b"AAAAAA\nBBBBBB\nCCCCCC\n")?;

        let opts = get_opts().local_infile_handler(Some(WhiteListFsHandler::new(&[file_name][..])));

        let mut conn = Conn::new(opts).await.unwrap();
        "CREATE TEMPORARY TABLE tmp (a TEXT)".run(&mut conn).await?;

        let query = format!(
            r#"SELECT * FROM tmp;
            LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;
            LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;
            SELECT * FROM tmp"#,
            file_name.to_str().unwrap(),
            file_name.to_str().unwrap(),
        );

        let mut result = query.run(&mut conn).await?;

        // The initial SELECT sees an empty table.
        let result_set = result.collect::<String>().await?;
        assert_eq!(result_set.len(), 0);

        let mut no_local_infile = false;

        // The two LOAD DATA statements: 3 affected rows each, no rows returned.
        for _ in 0..2 {
            match result.collect::<String>().await {
                Ok(result_set) => {
                    assert_eq!(result.affected_rows(), 3);
                    assert!(result_set.is_empty())
                }
                // ER_NOT_ALLOWED_COMMAND: the server does not allow LOAD DATA LOCAL.
                Err(Error::Server(ref err)) if err.code == 1148 => {
                    no_local_infile = true;
                    break;
                }
                // ER_CLIENT_LOCAL_FILES_DISABLED: local infile is disabled.
                Err(Error::Server(ref err)) if err.code == 3948 => {
                    no_local_infile = true;
                    break;
                }
                Err(err) => return Err(err),
            }
        }

        if no_local_infile {
            // The error ends the multi-result; `result_set` is the initial empty one.
            assert!(result.is_empty());
            assert_eq!(result_set.len(), 0);
        } else {
            let result_set = result.collect::<String>().await?;
            assert_eq!(result_set.len(), 6);
            assert_eq!(result_set[0], "AAAAAA");
            assert_eq!(result_set[1], "BBBBBB");
            assert_eq!(result_set[2], "CCCCCC");
            assert_eq!(result_set[3], "AAAAAA");
            assert_eq!(result_set[4], "BBBBBB");
            assert_eq!(result_set[5], "CCCCCC");
        }

        conn.ping().await?;
        conn.disconnect().await?;

        Ok(())
    }
2238
2239 #[tokio::test]
2240 async fn should_provide_multiresult_set_metadata() -> super::Result<()> {
2241 let mut c = Conn::new(get_opts()).await?;
2242 c.query_drop("CREATE TEMPORARY TABLE tmp (id INT, foo TEXT)")
2243 .await?;
2244
2245 let mut result = c
2246 .query_iter("SELECT 1; SELECT id, foo FROM tmp WHERE 1 = 2; DO 42; SELECT 2;")
2247 .await?;
2248 assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 1);
2249
2250 result.for_each(drop).await?;
2251 assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 2);
2252
2253 result.for_each(drop).await?;
2254 assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 0);
2255
2256 result.for_each(drop).await?;
2257 assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 1);
2258
2259 c.disconnect().await?;
2260 Ok(())
2261 }
2262
2263 #[tokio::test]
2264 async fn should_expose_query_result_metadata() -> super::Result<()> {
2265 let pool = Pool::new(get_opts());
2266 let mut c = pool.get_conn().await?;
2267
2268 c.query_drop(
2269 r"
2270 CREATE TEMPORARY TABLE `foo`
2271 ( `id` SERIAL
2272 , `bar_id` varchar(36) NOT NULL
2273 , `baz_id` varchar(36) NOT NULL
2274 , `ctime` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP()
2275 , PRIMARY KEY (`id`)
2276 , KEY `bar_idx` (`bar_id`)
2277 , KEY `baz_idx` (`baz_id`)
2278 );",
2279 )
2280 .await?;
2281
2282 const QUERY: &str = "INSERT INTO foo (bar_id, baz_id) VALUES (?, ?)";
2283 let params = ("qwerty", "data.employee_id");
2284
2285 let query_result = c.exec_iter(QUERY, params).await?;
2286 assert_eq!(query_result.last_insert_id(), Some(1));
2287 query_result.drop_result().await?;
2288
2289 c.exec_drop(QUERY, params).await?;
2290 assert_eq!(c.last_insert_id(), Some(2));
2291
2292 let mut tx = c.start_transaction(Default::default()).await?;
2293
2294 tx.exec_drop(QUERY, params).await?;
2295 assert_eq!(tx.last_insert_id(), Some(3));
2296
2297 Ok(())
2298 }
2299
    /// A per-connection infile handler, set with `set_infile_handler`, must supply the
    /// data for `LOAD DATA LOCAL INFILE`. The file name `"dummy"` is never opened. The
    /// handler streams two chunks that together form three lines. The test passes
    /// early if the server rejects local infile (error 1148 or 3948).
    #[tokio::test]
    async fn should_handle_local_infile_locally() -> super::Result<()> {
        let mut conn = Conn::new(get_opts()).await.unwrap();
        conn.query_drop("CREATE TEMPORARY TABLE tmp (a TEXT);")
            .await
            .unwrap();

        // The chunk boundary does not coincide with a line boundary.
        conn.set_infile_handler(async move {
            Ok(
                stream::iter([Bytes::from("AAAAAA\n"), Bytes::from("BBBBBB\nCCCCCC\n")])
                    .map(Ok)
                    .boxed(),
            )
        });

        match conn
            .query_drop(r#"LOAD DATA LOCAL INFILE "dummy" INTO TABLE tmp;"#)
            .await
        {
            Ok(_) => (),
            // ER_NOT_ALLOWED_COMMAND: LOAD DATA LOCAL not allowed by the server.
            Err(super::Error::Server(ref err)) if err.code == 1148 => {
                return Ok(());
            }
            // ER_CLIENT_LOCAL_FILES_DISABLED: local infile disabled.
            Err(super::Error::Server(ref err)) if err.code == 3948 => {
                return Ok(());
            }
            e @ Err(_) => e.unwrap(),
        };

        let result: Vec<String> = conn.query("SELECT * FROM tmp").await?;
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], "AAAAAA");
        assert_eq!(result[1], "BBBBBB");
        assert_eq!(result[2], "CCCCCC");

        Ok(())
    }
2340
    /// The global `local_infile_handler` option, with a whitelisted temp file, must
    /// serve `LOAD DATA LOCAL INFILE`. The test passes early if the server rejects local
    /// infile (error 1148 or 3948).
    #[tokio::test]
    async fn should_handle_local_infile_globally() -> super::Result<()> {
        use std::fs::write;

        // Created in the current directory; the shadowed guard keeps the file alive.
        let file_path = tempfile::Builder::new().tempfile_in("").unwrap();
        let file_path = file_path.path();
        let file_name = file_path.file_name().unwrap();

        write(file_name, b"AAAAAA\nBBBBBB\nCCCCCC\n")?;

        let opts = get_opts().local_infile_handler(Some(WhiteListFsHandler::new(&[file_name][..])));

        let mut conn = Conn::new(opts).await.unwrap();
        conn.query_drop("CREATE TEMPORARY TABLE tmp (a TEXT);")
            .await
            .unwrap();

        match conn
            .query_drop(format!(
                r#"LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;"#,
                file_name.to_str().unwrap(),
            ))
            .await
        {
            Ok(_) => (),
            // ER_NOT_ALLOWED_COMMAND: LOAD DATA LOCAL not allowed by the server.
            Err(super::Error::Server(ref err)) if err.code == 1148 => {
                return Ok(());
            }
            // ER_CLIENT_LOCAL_FILES_DISABLED: local infile disabled.
            Err(super::Error::Server(ref err)) if err.code == 3948 => {
                return Ok(());
            }
            e @ Err(_) => e.unwrap(),
        };

        let result: Vec<String> = conn.query("SELECT * FROM tmp").await?;
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], "AAAAAA");
        assert_eq!(result[1], "BBBBBB");
        assert_eq!(result[2], "CCCCCC");

        Ok(())
    }
2386
2387 #[cfg(feature = "nightly")]
2388 mod bench {
2389 use crate::{conn::Conn, queryable::Queryable, test_misc::get_opts};
2390
2391 #[bench]
2392 fn simple_exec(bencher: &mut test::Bencher) {
2393 let mut runtime = tokio::runtime::Runtime::new().unwrap();
2394 let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2395
2396 bencher.iter(|| {
2397 runtime.block_on(conn.query_drop("DO 1")).unwrap();
2398 });
2399
2400 runtime.block_on(conn.disconnect()).unwrap();
2401 }
2402
2403 #[bench]
2404 fn select_large_string(bencher: &mut test::Bencher) {
2405 let mut runtime = tokio::runtime::Runtime::new().unwrap();
2406 let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2407
2408 bencher.iter(|| {
2409 runtime
2410 .block_on(conn.query_drop("SELECT REPEAT('A', 10000)"))
2411 .unwrap();
2412 });
2413
2414 runtime.block_on(conn.disconnect()).unwrap();
2415 }
2416
2417 #[bench]
2418 fn prepared_exec(bencher: &mut test::Bencher) {
2419 let mut runtime = tokio::runtime::Runtime::new().unwrap();
2420 let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2421 let stmt = runtime.block_on(conn.prep("DO 1")).unwrap();
2422
2423 bencher.iter(|| {
2424 runtime.block_on(conn.exec_drop(&stmt, ())).unwrap();
2425 });
2426
2427 runtime.block_on(conn.close(stmt)).unwrap();
2428 runtime.block_on(conn.disconnect()).unwrap();
2429 }
2430
2431 #[bench]
2432 fn prepare_and_exec(bencher: &mut test::Bencher) {
2433 let mut runtime = tokio::runtime::Runtime::new().unwrap();
2434 let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2435
2436 bencher.iter(|| {
2437 runtime.block_on(conn.exec_drop("SELECT ?", (0,))).unwrap();
2438 });
2439
2440 runtime.block_on(conn.disconnect()).unwrap();
2441 }
2442 }
2443}