1use futures_util::FutureExt;
10pub use mysql_common::named_params;
11
12use mysql_common::{
13 constants::DEFAULT_MAX_ALLOWED_PACKET,
14 crypto,
15 io::ParseBuf,
16 packets::{
17 AuthPlugin, AuthSwitchRequest, CommonOkPacket, ErrPacket, HandshakePacket,
18 HandshakeResponse, OkPacket, OkPacketDeserializer, OldAuthSwitchRequest, OldEofPacket,
19 ResultSetTerminator,
20 },
21 proto::MySerialize,
22 row::Row,
23};
24
25use std::{
26 borrow::Cow,
27 fmt,
28 future::Future,
29 mem::{self, replace},
30 pin::Pin,
31 str::FromStr,
32 sync::Arc,
33 time::{Duration, Instant},
34};
35
36use crate::{
37 buffer_pool::PooledBuf,
38 conn::{pool::Pool, stmt_cache::StmtCache},
39 consts::{CapabilityFlags, Command, StatusFlags},
40 error::*,
41 io::Stream,
42 opts::Opts,
43 queryable::{
44 query_result::{QueryResult, ResultSetMeta},
45 transaction::TxStatus,
46 BinaryProtocol, Queryable, TextProtocol,
47 },
48 ChangeUserOpts, InfileData, OptsBuilder,
49};
50
51use self::routines::Routine;
52
53#[cfg(feature = "binlog")]
54pub mod binlog_stream;
55pub mod pool;
56pub mod routines;
57pub mod stmt_cache;
58
/// Fallback for the connection's `wait_timeout`, in seconds (8 hours).
///
/// Used when `@@wait_timeout` is loaded from the server but comes back `NULL` or missing.
const DEFAULT_WAIT_TIMEOUT: usize = 8 * 60 * 60;
60
61fn disconnect(mut conn: Conn) {
63 let disconnected = conn.inner.disconnected;
64
65 conn.inner.disconnected = true;
67
68 if !disconnected {
69 if std::thread::panicking() {
71 return;
72 }
73
74 if let Ok(handle) = tokio::runtime::Handle::try_current() {
77 handle.spawn(async move {
78 if let Ok(conn) = conn.cleanup_for_pool().await {
79 let _ = conn.disconnect().await;
80 }
81 });
82 }
83 }
84}
85
86#[derive(Debug, Clone)]
88pub(crate) enum PendingResult {
89 Pending(ResultSetMeta),
91 Taken(Arc<ResultSetMeta>),
93}
94
/// Mutable state of a connection; `Conn` keeps it behind a `Box`.
struct ConnInner {
    /// Underlying transport; `None` once it has been taken (disconnect, I/O error, ...).
    stream: Option<Stream>,
    /// Connection id reported by the server handshake.
    id: u32,
    /// Whether the handshake reported a MariaDB server version.
    is_mariadb: bool,
    /// Parsed server version, `(0, 0, 0)` if unknown or unparsable.
    version: (u16, u16, u16),
    /// Unix socket path, taken from the options or discovered via `@@socket`.
    socket: Option<String>,
    /// Capability flags: the client's requested flags, narrowed by the server handshake.
    capabilities: CapabilityFlags,
    /// Server status flags from the latest handshake or OK packet (cleared on error).
    status: StatusFlags,
    /// Most recent OK packet; cleared when an error packet is handled.
    last_ok_packet: Option<OkPacket<'static>>,
    /// Most recent server error; cleared when an OK packet is handled.
    last_err_packet: Option<mysql_common::packets::ServerError<'static>>,
    /// Pool this connection belongs to, if any.
    pool: Option<Pool>,
    /// Result set currently being read, or a server error deferred until it is observed.
    pending_result: std::result::Result<Option<PendingResult>, ServerError>,
    /// Transaction status tracked by the driver.
    tx_status: TxStatus,
    /// Set via `Conn::reset_connection`; presumably consulted by the pool — confirm there.
    reset_upon_returning_to_a_pool: bool,
    /// Options this connection was created with (updated by `change_user`).
    opts: Opts,
    /// Absolute instant after which the connection counts as expired, if configured.
    ttl_deadline: Option<Instant>,
    /// Time of the last I/O; idle time is measured from here.
    last_io: Instant,
    /// Server `wait_timeout` (or the configured override); zero until settings are read.
    wait_timeout: Duration,
    /// Prepared-statement cache sized by `opts.stmt_cache_size()`.
    stmt_cache: StmtCache,
    /// Authentication scramble from the handshake or the latest auth switch request.
    nonce: Vec<u8>,
    /// Authentication plugin currently in use.
    auth_plugin: AuthPlugin<'static>,
    /// Set once an auth switch was performed; further switch requests are rejected.
    auth_switched: bool,
    /// Server public key for `caching_sha2_password`, fetched on first need.
    server_key: Option<Vec<u8>>,
    /// Set while the connection must not be treated as healthy: after disconnect or an I/O
    /// error, and for the duration of `Conn::routine`.
    pub(crate) disconnected: bool,
    /// Future supplying `InfileData`, installed via `Conn::set_infile_handler`.
    // NOTE(review): presumably consumed when the server requests LOCAL INFILE data — confirm.
    infile_handler:
        Option<Pin<Box<dyn Future<Output = crate::Result<InfileData>> + Send + Sync + 'static>>>,
}
125
126impl fmt::Debug for ConnInner {
127 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128 f.debug_struct("Conn")
129 .field("connection id", &self.id)
130 .field("server version", &self.version)
131 .field("pool", &self.pool)
132 .field("pending_result", &self.pending_result)
133 .field("tx_status", &self.tx_status)
134 .field("stream", &self.stream)
135 .field("options", &self.opts)
136 .field("server_key", &self.server_key)
137 .field("auth_plugin", &self.auth_plugin)
138 .finish()
139 }
140}
141
142impl ConnInner {
143 fn empty(opts: Opts) -> ConnInner {
145 let ttl_deadline = opts.pool_opts().new_connection_ttl_deadline();
146 ConnInner {
147 capabilities: opts.get_capabilities(),
148 status: StatusFlags::empty(),
149 last_ok_packet: None,
150 last_err_packet: None,
151 stream: None,
152 is_mariadb: false,
153 version: (0, 0, 0),
154 id: 0,
155 pending_result: Ok(None),
156 pool: None,
157 tx_status: TxStatus::None,
158 last_io: Instant::now(),
159 wait_timeout: Duration::from_secs(0),
160 stmt_cache: StmtCache::new(opts.stmt_cache_size()),
161 socket: opts.socket().map(Into::into),
162 opts,
163 ttl_deadline,
164 nonce: Vec::default(),
165 auth_plugin: AuthPlugin::MysqlNativePassword,
166 auth_switched: false,
167 disconnected: false,
168 server_key: None,
169 infile_handler: None,
170 reset_upon_returning_to_a_pool: false,
171 }
172 }
173
174 fn stream_mut(&mut self) -> Result<&mut Stream> {
178 self.stream
179 .as_mut()
180 .ok_or_else(|| DriverError::ConnectionClosed.into())
181 }
182}
183
/// Asynchronous connection to a MySQL or MariaDB server.
///
/// All state lives in a boxed `ConnInner`, so `Conn` itself is a single pointer.
#[derive(Debug)]
pub struct Conn {
    inner: Box<ConnInner>,
}
189
190impl Conn {
191 pub fn id(&self) -> u32 {
193 self.inner.id
194 }
195
196 pub fn last_insert_id(&self) -> Option<u64> {
200 self.inner
201 .last_ok_packet
202 .as_ref()
203 .and_then(|ok| ok.last_insert_id())
204 }
205
206 pub fn affected_rows(&self) -> u64 {
209 self.inner
210 .last_ok_packet
211 .as_ref()
212 .map(|ok| ok.affected_rows())
213 .unwrap_or_default()
214 }
215
216 pub fn info(&self) -> Cow<'_, str> {
218 self.inner
219 .last_ok_packet
220 .as_ref()
221 .and_then(|ok| ok.info_str())
222 .unwrap_or_else(|| "".into())
223 }
224
225 pub fn get_warnings(&self) -> u16 {
227 self.inner
228 .last_ok_packet
229 .as_ref()
230 .map(|ok| ok.warnings())
231 .unwrap_or_default()
232 }
233
234 pub fn last_ok_packet(&self) -> Option<&OkPacket<'static>> {
236 self.inner.last_ok_packet.as_ref()
237 }
238
239 pub fn reset_connection(&mut self, reset_connection: bool) {
243 self.inner.reset_upon_returning_to_a_pool = reset_connection;
244 }
245
246 pub(crate) fn stream_mut(&mut self) -> Result<&mut Stream> {
247 self.inner.stream_mut()
248 }
249
250 pub(crate) fn capabilities(&self) -> CapabilityFlags {
251 self.inner.capabilities
252 }
253
254 pub(crate) fn touch(&mut self) {
256 self.inner.last_io = Instant::now();
257 }
258
259 pub(crate) fn reset_seq_id(&mut self) {
261 if let Some(stream) = self.inner.stream.as_mut() {
262 stream.reset_seq_id();
263 }
264 }
265
266 pub(crate) fn sync_seq_id(&mut self) {
268 if let Some(stream) = self.inner.stream.as_mut() {
269 stream.sync_seq_id();
270 }
271 }
272
273 pub(crate) fn handle_ok(&mut self, ok_packet: OkPacket<'static>) {
275 self.inner.status = ok_packet.status_flags();
276 self.inner.last_err_packet = None;
277 self.inner.last_ok_packet = Some(ok_packet);
278 }
279
280 pub(crate) fn handle_err(&mut self, err_packet: ErrPacket<'_>) -> Result<()> {
282 match err_packet {
283 ErrPacket::Error(err) => {
284 self.inner.status = StatusFlags::empty();
285 self.inner.last_ok_packet = None;
286 self.inner.last_err_packet = Some(err.clone().into_owned());
287 Err(Error::from(err))
288 }
289 ErrPacket::Progress(_) => Ok(()),
290 }
291 }
292
293 pub(crate) fn get_tx_status(&self) -> TxStatus {
295 self.inner.tx_status
296 }
297
298 pub(crate) fn set_tx_status(&mut self, tx_status: TxStatus) {
300 self.inner.tx_status = tx_status;
301 }
302
303 pub(crate) fn use_pending_result(
307 &mut self,
308 ) -> std::result::Result<Option<&PendingResult>, ServerError> {
309 if let Err(ref e) = self.inner.pending_result {
310 let e = e.clone();
311 self.inner.pending_result = Ok(None);
312 Err(e)
313 } else {
314 Ok(self.inner.pending_result.as_ref().unwrap().as_ref())
315 }
316 }
317
318 pub(crate) fn get_pending_result(
319 &self,
320 ) -> std::result::Result<Option<&PendingResult>, &ServerError> {
321 self.inner.pending_result.as_ref().map(|x| x.as_ref())
322 }
323
324 pub(crate) fn has_pending_result(&self) -> bool {
325 self.inner.pending_result.is_err() || matches!(self.inner.pending_result, Ok(Some(_)))
326 }
327
328 pub(crate) fn set_pending_result(
330 &mut self,
331 meta: Option<ResultSetMeta>,
332 ) -> std::result::Result<Option<PendingResult>, ServerError> {
333 replace(
334 &mut self.inner.pending_result,
335 Ok(meta.map(PendingResult::Pending)),
336 )
337 }
338
339 pub(crate) fn set_pending_result_error(
340 &mut self,
341 error: ServerError,
342 ) -> std::result::Result<Option<PendingResult>, ServerError> {
343 replace(&mut self.inner.pending_result, Err(error))
344 }
345
346 pub(crate) fn take_pending_result(
348 &mut self,
349 ) -> std::result::Result<Option<Arc<ResultSetMeta>>, ServerError> {
350 let mut output = None;
351
352 self.inner.pending_result = match replace(&mut self.inner.pending_result, Ok(None))? {
353 Some(PendingResult::Pending(x)) => {
354 let meta = Arc::new(x);
355 output = Some(meta.clone());
356 Ok(Some(PendingResult::Taken(meta)))
357 }
358 x => Ok(x),
359 };
360
361 Ok(output)
362 }
363
364 pub(crate) fn status(&self) -> StatusFlags {
366 self.inner.status
367 }
368
369 pub(crate) async fn routine<'a, F, T>(&mut self, mut f: F) -> crate::Result<T>
370 where
371 F: Routine<T> + 'a,
372 {
373 self.inner.disconnected = true;
374 let result = f.call(&mut *self).await;
375 match result {
376 result @ Ok(_) | result @ Err(crate::Error::Server(_)) => {
377 self.inner.disconnected = false;
379 result
380 }
381 Err(err) => {
382 if self.inner.stream.is_some() {
383 self.take_stream().close().await?;
384 }
385 Err(err)
386 }
387 }
388 }
389
390 pub fn server_version(&self) -> (u16, u16, u16) {
392 self.inner.version
393 }
394
395 pub fn opts(&self) -> &Opts {
397 &self.inner.opts
398 }
399
400 pub fn set_infile_handler<T>(&mut self, handler: T)
407 where
408 T: Future<Output = crate::Result<InfileData>>,
409 T: Send + Sync + 'static,
410 {
411 self.inner.infile_handler = Some(Box::pin(handler));
412 }
413
414 fn take_stream(&mut self) -> Stream {
415 self.inner.stream.take().unwrap()
416 }
417
418 pub async fn disconnect(mut self) -> Result<()> {
420 if !self.inner.disconnected {
421 self.inner.disconnected = true;
422 self.write_command_data(Command::COM_QUIT, &[]).await?;
423 let stream = self.take_stream();
424 stream.close().await?;
425 }
426 Ok(())
427 }
428
429 async fn close_conn(mut self) -> Result<()> {
431 self = self.cleanup_for_pool().await?;
432 self.disconnect().await
433 }
434
435 fn is_secure(&self) -> bool {
437 #[cfg(any(
438 feature = "native-tls-tls",
439 feature = "rustls-tls",
440 feature = "wasmedge-tls"
441 ))]
442 {
443 self.inner
444 .stream
445 .as_ref()
446 .map(|x| x.is_secure())
447 .unwrap_or_default()
448 }
449
450 #[cfg(not(any(
451 feature = "native-tls-tls",
452 feature = "rustls-tls",
453 feature = "wasmedge-tls"
454 )))]
455 false
456 }
457
458 fn is_socket(&self) -> bool {
460 #[cfg(unix)]
461 {
462 self.inner
463 .stream
464 .as_ref()
465 .map(|x| x.is_socket())
466 .unwrap_or_default()
467 }
468
469 #[cfg(not(unix))]
470 false
471 }
472
473 fn take(&mut self) -> Conn {
475 mem::replace(self, Conn::empty(Default::default()))
476 }
477
478 fn empty(opts: Opts) -> Self {
479 Self {
480 inner: Box::new(ConnInner::empty(opts)),
481 }
482 }
483
484 fn setup_stream(&mut self) -> Result<()> {
488 debug_assert!(self.inner.stream.is_some());
489 if let Some(stream) = self.inner.stream.as_mut() {
490 stream.set_tcp_nodelay(self.inner.opts.tcp_nodelay())?;
491 }
492 Ok(())
493 }
494
495 async fn handle_handshake(&mut self) -> Result<()> {
496 let packet = self.read_packet().await?;
497 let handshake = ParseBuf(&packet).parse::<HandshakePacket>(())?;
498
499 self.inner.nonce = {
501 let mut nonce = Vec::from(handshake.scramble_1_ref());
502 nonce.extend_from_slice(handshake.scramble_2_ref().unwrap_or(&[][..]));
503 nonce.resize(20, 0);
506 nonce
507 };
508
509 self.inner.capabilities = handshake.capabilities() & self.inner.opts.get_capabilities();
510 self.inner.version = handshake
511 .maria_db_server_version_parsed()
512 .map(|version| {
513 self.inner.is_mariadb = true;
514 version
515 })
516 .or_else(|| handshake.server_version_parsed())
517 .unwrap_or((0, 0, 0));
518 self.inner.id = handshake.connection_id();
519 self.inner.status = handshake.status_flags();
520
521 self.inner.auth_plugin = match handshake.auth_plugin() {
525 Some(AuthPlugin::CachingSha2Password) => AuthPlugin::CachingSha2Password,
526 _ => AuthPlugin::MysqlNativePassword,
527 };
528
529 Ok(())
530 }
531
532 async fn switch_to_ssl_if_needed(&mut self) -> Result<()> {
533 if self
534 .inner
535 .opts
536 .get_capabilities()
537 .contains(CapabilityFlags::CLIENT_SSL)
538 {
539 if !self
540 .inner
541 .capabilities
542 .contains(CapabilityFlags::CLIENT_SSL)
543 {
544 return Err(DriverError::NoClientSslFlagFromServer.into());
545 }
546
547 let collation = if self.inner.version >= (5, 5, 3) {
548 mysql_common::constants::UTF8MB4_GENERAL_CI
549 } else {
550 crate::consts::UTF8MB4_GENERAL_CI
551 };
552
553 let ssl_request = mysql_common::packets::SslRequest::new(
554 self.inner.capabilities,
555 DEFAULT_MAX_ALLOWED_PACKET as u32,
556 collation as u8,
557 );
558 self.write_struct(&ssl_request).await?;
559 let conn = self;
560 let ssl_opts = conn.opts().ssl_opts().cloned().expect("unreachable");
561 let domain = conn.opts().ip_or_hostname().into();
562 conn.stream_mut()?.make_secure(domain, ssl_opts).await?;
563 Ok(())
564 } else {
565 Ok(())
566 }
567 }
568
569 async fn do_handshake_response(&mut self) -> Result<()> {
570 let auth_data = self
571 .inner
572 .auth_plugin
573 .gen_data(self.inner.opts.pass(), &self.inner.nonce);
574
575 let handshake_response = HandshakeResponse::new(
576 auth_data.as_deref(),
577 self.inner.version,
578 self.inner.opts.user().map(|x| x.as_bytes()),
579 self.inner.opts.db_name().map(|x| x.as_bytes()),
580 Some(self.inner.auth_plugin.borrow()),
581 self.capabilities(),
582 Default::default(), );
584
585 let mut buf = crate::BUFFER_POOL.get();
587 handshake_response.serialize(buf.as_mut());
588
589 self.write_packet(buf).await?;
590 Ok(())
591 }
592
593 async fn perform_auth_switch(
594 &mut self,
595 auth_switch_request: AuthSwitchRequest<'_>,
596 ) -> Result<()> {
597 if !self.inner.auth_switched {
598 self.inner.auth_switched = true;
599 self.inner.nonce = auth_switch_request.plugin_data().to_vec();
600
601 if matches!(
602 auth_switch_request.auth_plugin(),
603 AuthPlugin::MysqlOldPassword
604 ) && self.inner.opts.secure_auth()
605 {
606 return Err(DriverError::MysqlOldPasswordDisabled.into());
607 }
608
609 self.inner.auth_plugin = auth_switch_request.auth_plugin().clone().into_owned();
610
611 let plugin_data = match &self.inner.auth_plugin {
612 x @ AuthPlugin::CachingSha2Password => {
613 x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
614 }
615 x @ AuthPlugin::MysqlNativePassword => {
616 x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
617 }
618 x @ AuthPlugin::MysqlOldPassword => {
619 if self.inner.opts.secure_auth() {
620 return Err(DriverError::MysqlOldPasswordDisabled.into());
621 } else {
622 x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
623 }
624 }
625 x @ AuthPlugin::MysqlClearPassword => {
626 if self.inner.opts.enable_cleartext_plugin() {
627 x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
628 } else {
629 return Err(DriverError::CleartextPluginDisabled.into());
630 }
631 }
632 x @ AuthPlugin::Other(_) => x.gen_data(self.inner.opts.pass(), &self.inner.nonce),
633 };
634
635 if let Some(plugin_data) = plugin_data {
636 self.write_struct(&plugin_data.into_owned()).await?;
637 } else {
638 self.write_packet(crate::BUFFER_POOL.get()).await?;
639 }
640
641 self.continue_auth().await?;
642
643 Ok(())
644 } else {
645 unreachable!("auth_switched flag should be checked by caller")
646 }
647 }
648
649 fn continue_auth(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
650 Box::pin(async move {
653 match self.inner.auth_plugin {
654 AuthPlugin::MysqlNativePassword | AuthPlugin::MysqlOldPassword => {
655 self.continue_mysql_native_password_auth().await?;
656 Ok(())
657 }
658 AuthPlugin::CachingSha2Password => {
659 self.continue_caching_sha2_password_auth().await?;
660 Ok(())
661 }
662 AuthPlugin::MysqlClearPassword => {
663 if self.inner.opts.enable_cleartext_plugin() {
664 self.continue_mysql_native_password_auth().await?;
665 Ok(())
666 } else {
667 Err(DriverError::CleartextPluginDisabled.into())
668 }
669 }
670 AuthPlugin::Other(ref name) => Err(DriverError::UnknownAuthPlugin {
671 name: String::from_utf8_lossy(name.as_ref()).to_string(),
672 }
673 .into()),
674 }
675 })
676 }
677
678 fn switch_to_compression(&mut self) -> Result<()> {
679 if self
680 .capabilities()
681 .contains(CapabilityFlags::CLIENT_COMPRESS)
682 {
683 if let Some(compression) = self.inner.opts.compression() {
684 if let Some(stream) = self.inner.stream.as_mut() {
685 stream.compress(compression);
686 }
687 }
688 }
689 Ok(())
690 }
691
692 async fn continue_caching_sha2_password_auth(&mut self) -> Result<()> {
693 let packet = self.read_packet().await?;
694 match packet.first() {
695 Some(0x00) => {
696 Ok(())
698 }
699 Some(0x01) => match packet.get(1) {
700 Some(0x03) => {
701 self.drop_packet().await
703 }
704 Some(0x04) => {
705 let pass = self.inner.opts.pass().unwrap_or_default();
706 let mut pass = crate::BUFFER_POOL.get_with(pass.as_bytes());
707 pass.as_mut().push(0);
708
709 if self.is_secure() || self.is_socket() {
710 self.write_packet(pass).await?;
711 } else {
712 if self.inner.server_key.is_none() {
713 self.write_bytes(&[0x02][..]).await?;
714 let packet = self.read_packet().await?;
715 self.inner.server_key = Some(packet[1..].to_vec());
716 }
717 for (i, byte) in pass.as_mut().iter_mut().enumerate() {
718 *byte ^= self.inner.nonce[i % self.inner.nonce.len()];
719 }
720 let encrypted_pass = crypto::encrypt(
721 &pass,
722 self.inner.server_key.as_deref().expect("unreachable"),
723 );
724 self.write_bytes(&encrypted_pass).await?;
725 };
726 self.drop_packet().await?;
727 Ok(())
728 }
729 _ => Err(DriverError::UnexpectedPacket {
730 payload: packet.to_vec(),
731 }
732 .into()),
733 },
734 Some(0xfe) if !self.inner.auth_switched => {
735 let auth_switch_request = ParseBuf(&packet).parse::<AuthSwitchRequest>(())?;
736 self.perform_auth_switch(auth_switch_request).await?;
737 Ok(())
738 }
739 _ => Err(DriverError::UnexpectedPacket {
740 payload: packet.to_vec(),
741 }
742 .into()),
743 }
744 }
745
746 async fn continue_mysql_native_password_auth(&mut self) -> Result<()> {
747 let packet = self.read_packet().await?;
748 match packet.first() {
749 Some(0x00) => Ok(()),
750 Some(0xfe) if !self.inner.auth_switched => {
751 let auth_switch = if packet.len() > 1 {
752 ParseBuf(&packet).parse(())?
753 } else {
754 let _ = ParseBuf(&packet).parse::<OldAuthSwitchRequest>(())?;
755 AuthSwitchRequest::new(
757 "mysql_old_password".as_bytes(),
758 self.inner.nonce.clone(),
759 )
760 };
761 self.perform_auth_switch(auth_switch).await
762 }
763 _ => Err(DriverError::UnexpectedPacket {
764 payload: packet.to_vec(),
765 }
766 .into()),
767 }
768 }
769
770 fn handle_packet(&mut self, packet: &PooledBuf) -> Result<bool> {
772 let ok_packet = if self.has_pending_result() {
773 if self
774 .capabilities()
775 .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF)
776 {
777 ParseBuf(packet)
778 .parse::<OkPacketDeserializer<ResultSetTerminator>>(self.capabilities())
779 .map(|x| x.into_inner())
780 } else {
781 ParseBuf(packet)
782 .parse::<OkPacketDeserializer<OldEofPacket>>(self.capabilities())
783 .map(|x| x.into_inner())
784 }
785 } else {
786 ParseBuf(packet)
787 .parse::<OkPacketDeserializer<CommonOkPacket>>(self.capabilities())
788 .map(|x| x.into_inner())
789 };
790
791 if let Ok(ok_packet) = ok_packet {
792 self.handle_ok(ok_packet.into_owned());
793 } else {
794 let err_packet = ParseBuf(packet).parse::<ErrPacket>(self.capabilities());
795 if let Ok(err_packet) = err_packet {
796 self.handle_err(err_packet)?;
797 return Ok(true);
798 }
799 }
800
801 Ok(false)
802 }
803
804 pub(crate) async fn read_packet(&mut self) -> Result<PooledBuf> {
805 loop {
806 let packet = crate::io::ReadPacket::new(&mut *self)
807 .await
808 .map_err(|io_err| {
809 self.inner.stream.take();
810 self.inner.disconnected = true;
811 Error::from(io_err)
812 })?;
813 if self.handle_packet(&packet)? {
814 continue;
816 } else {
817 return Ok(packet);
818 }
819 }
820 }
821
822 pub(crate) async fn read_packets(&mut self, n: usize) -> Result<Vec<PooledBuf>> {
824 let mut packets = Vec::with_capacity(n);
825 for _ in 0..n {
826 packets.push(self.read_packet().await?);
827 }
828 Ok(packets)
829 }
830
831 pub(crate) async fn write_packet(&mut self, data: PooledBuf) -> Result<()> {
832 crate::io::WritePacket::new(&mut *self, data)
833 .await
834 .map_err(|io_err| {
835 self.inner.stream.take();
836 self.inner.disconnected = true;
837 From::from(io_err)
838 })
839 }
840
841 pub(crate) async fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> {
843 let buf = crate::BUFFER_POOL.get_with(bytes);
844 self.write_packet(buf).await
845 }
846
847 pub(crate) async fn write_struct<T: MySerialize>(&mut self, x: &T) -> Result<()> {
849 let mut buf = crate::BUFFER_POOL.get();
850 x.serialize(buf.as_mut());
851 self.write_packet(buf).await
852 }
853
854 pub(crate) async fn write_command<T: MySerialize>(&mut self, cmd: &T) -> Result<()> {
856 self.clean_dirty().await?;
857 self.reset_seq_id();
858 self.write_struct(cmd).await
859 }
860
861 pub(crate) async fn write_command_raw(&mut self, body: PooledBuf) -> Result<()> {
863 debug_assert!(!body.is_empty());
864 self.clean_dirty().await?;
865 self.reset_seq_id();
866 self.write_packet(body).await
867 }
868
869 pub(crate) async fn write_command_data<T>(&mut self, cmd: Command, cmd_data: T) -> Result<()>
871 where
872 T: AsRef<[u8]>,
873 {
874 let cmd_data = cmd_data.as_ref();
875 let mut buf = crate::BUFFER_POOL.get();
876 let body = buf.as_mut();
877 body.push(cmd as u8);
878 body.extend_from_slice(cmd_data);
879 self.write_command_raw(buf).await
880 }
881
882 async fn drop_packet(&mut self) -> Result<()> {
883 self.read_packet().await?;
884 Ok(())
885 }
886
887 async fn run_init_commands(&mut self) -> Result<()> {
888 let mut init = self.inner.opts.init().to_vec();
889
890 while let Some(query) = init.pop() {
891 self.query_drop(query).await?;
892 }
893
894 Ok(())
895 }
896
897 async fn run_setup_commands(&mut self) -> Result<()> {
898 let mut setup = self.inner.opts.setup().to_vec();
899
900 while let Some(query) = setup.pop() {
901 self.query_drop(query).await?;
902 }
903
904 Ok(())
905 }
906
907 pub fn new<T: Into<Opts>>(opts: T) -> crate::BoxFuture<'static, Conn> {
909 let opts = opts.into();
910 async move {
911 let mut conn = Conn::empty(opts.clone());
912
913 let stream = if let Some(_path) = opts.socket() {
914 #[cfg(unix)]
915 {
916 Stream::connect_socket(_path.to_owned()).await?
917 }
918 #[cfg(target_os = "windows")]
919 return Err(crate::DriverError::NamedPipesDisabled.into());
920 #[cfg(target_os = "wasi")]
921 return Err(crate::DriverError::NamedPipesDisabled.into());
922 } else {
923 let keepalive = opts
924 .tcp_keepalive()
925 .map(|x| std::time::Duration::from_millis(x.into()));
926 Stream::connect_tcp(opts.hostport_or_url(), keepalive).await?
927 };
928
929 conn.inner.stream = Some(stream);
930 conn.setup_stream()?;
931 conn.handle_handshake().await?;
932 conn.switch_to_ssl_if_needed().await?;
933 conn.do_handshake_response().await?;
934 conn.continue_auth().await?;
935 conn.switch_to_compression()?;
936 conn.read_settings().await?;
937 conn.reconnect_via_socket_if_needed().await?;
938 conn.run_init_commands().await?;
939 conn.run_setup_commands().await?;
940
941 Ok(conn)
942 }
943 .boxed()
944 }
945
946 pub async fn from_url<T: AsRef<str>>(url: T) -> Result<Conn> {
948 Conn::new(Opts::from_str(url.as_ref())?).await
949 }
950
951 async fn reconnect_via_socket_if_needed(&mut self) -> Result<()> {
955 if let Some(socket) = self.inner.socket.as_ref() {
956 let opts = self.inner.opts.clone();
957 if opts.socket().is_none() {
958 let opts = OptsBuilder::from_opts(opts).socket(Some(&**socket));
959 if let Ok(conn) = Conn::new(opts).await {
960 let old_conn = std::mem::replace(self, conn);
961 old_conn.close_conn().await?;
963 }
964 }
965 }
966 Ok(())
967 }
968
969 async fn read_settings(&mut self) -> Result<()> {
979 enum Action {
980 Load(Cfg),
981 Apply(CfgData),
982 }
983
984 enum CfgData {
985 MaxAllowedPacket(usize),
986 WaitTimeout(usize),
987 }
988
989 impl CfgData {
990 fn apply(&self, conn: &mut Conn) {
991 match self {
992 Self::MaxAllowedPacket(value) => {
993 if let Some(stream) = conn.inner.stream.as_mut() {
994 stream.set_max_allowed_packet(*value);
995 }
996 }
997 Self::WaitTimeout(value) => {
998 conn.inner.wait_timeout = Duration::from_secs(*value as u64);
999 }
1000 }
1001 }
1002 }
1003
1004 enum Cfg {
1005 Socket,
1006 MaxAllowedPacket,
1007 WaitTimeout,
1008 }
1009
1010 impl Cfg {
1011 const fn name(&self) -> &'static str {
1012 match self {
1013 Self::Socket => "@@socket",
1014 Self::MaxAllowedPacket => "@@max_allowed_packet",
1015 Self::WaitTimeout => "@@wait_timeout",
1016 }
1017 }
1018
1019 fn apply(&self, conn: &mut Conn, value: Option<crate::Value>) {
1020 match self {
1021 Cfg::Socket => {
1022 conn.inner.socket = value.and_then(crate::from_value);
1023 }
1024 Cfg::MaxAllowedPacket => {
1025 if let Some(stream) = conn.inner.stream.as_mut() {
1026 stream.set_max_allowed_packet(
1027 value
1028 .and_then(crate::from_value)
1029 .unwrap_or(DEFAULT_MAX_ALLOWED_PACKET),
1030 );
1031 }
1032 }
1033 Cfg::WaitTimeout => {
1034 conn.inner.wait_timeout = Duration::from_secs(
1035 value
1036 .and_then(crate::from_value)
1037 .unwrap_or(DEFAULT_WAIT_TIMEOUT) as u64,
1038 );
1039 }
1040 }
1041 }
1042 }
1043
1044 let mut actions = vec![
1045 if let Some(x) = self.opts().max_allowed_packet() {
1046 Action::Apply(CfgData::MaxAllowedPacket(x))
1047 } else {
1048 Action::Load(Cfg::MaxAllowedPacket)
1049 },
1050 if let Some(x) = self.opts().wait_timeout() {
1051 Action::Apply(CfgData::WaitTimeout(x))
1052 } else {
1053 Action::Load(Cfg::WaitTimeout)
1054 },
1055 ];
1056
1057 if self.inner.opts.prefer_socket() && self.inner.socket.is_none() {
1058 actions.push(Action::Load(Cfg::Socket))
1059 }
1060
1061 let loads = actions
1062 .iter()
1063 .filter_map(|x| match x {
1064 Action::Load(x) => Some(x),
1065 Action::Apply(_) => None,
1066 })
1067 .collect::<Vec<_>>();
1068
1069 let loaded = if !loads.is_empty() {
1070 let query = loads
1071 .iter()
1072 .zip(std::iter::once(' ').chain(std::iter::repeat(',')))
1073 .fold("SELECT".to_owned(), |mut acc, (cfg, prefix)| {
1074 acc.push(prefix);
1075 acc.push_str(cfg.name());
1076 acc
1077 });
1078
1079 self.query_internal::<Row, String>(query)
1080 .await?
1081 .map(|row| row.unwrap())
1082 .unwrap_or_else(|| vec![crate::Value::NULL; loads.len()])
1083 } else {
1084 vec![]
1085 };
1086 let mut loaded = loaded.into_iter();
1087
1088 for action in actions {
1089 match action {
1090 Action::Load(cfg) => cfg.apply(self, loaded.next()),
1091 Action::Apply(cfg) => cfg.apply(self),
1092 }
1093 }
1094
1095 Ok(())
1096 }
1097
1098 fn expired(&self) -> bool {
1101 if let Some(deadline) = self.inner.ttl_deadline {
1102 if Instant::now() > deadline {
1103 return true;
1104 }
1105 }
1106 let ttl = self
1107 .inner
1108 .opts
1109 .conn_ttl()
1110 .unwrap_or(self.inner.wait_timeout);
1111 !ttl.is_zero() && self.idling() > ttl
1112 }
1113
1114 fn idling(&self) -> Duration {
1116 self.inner.last_io.elapsed()
1117 }
1118
1119 pub async fn reset(&mut self) -> Result<bool> {
1126 let supports_com_reset_connection = if self.inner.is_mariadb {
1127 self.inner.version >= (10, 2, 4)
1128 } else {
1129 self.inner.version > (5, 7, 2)
1131 };
1132
1133 if supports_com_reset_connection {
1134 self.routine(routines::ResetRoutine).await?;
1135 self.inner.stmt_cache.clear();
1136 self.inner.infile_handler = None;
1137 self.run_setup_commands().await?;
1138 }
1139
1140 Ok(supports_com_reset_connection)
1141 }
1142
1143 pub async fn change_user(&mut self, opts: ChangeUserOpts) -> Result<()> {
1155 if opts != ChangeUserOpts::default() {
1157 let mut opts_changed = false;
1158 if let Some(user) = opts.user() {
1159 opts_changed |= user != self.opts().user()
1160 };
1161 if let Some(pass) = opts.pass() {
1162 opts_changed |= pass != self.opts().pass()
1163 };
1164 if let Some(db_name) = opts.db_name() {
1165 opts_changed |= db_name != self.opts().db_name()
1166 };
1167 if opts_changed {
1168 if let Some(pool) = self.inner.pool.take() {
1169 pool.cancel_connection();
1170 }
1171 }
1172 }
1173
1174 let conn_opts = &mut self.inner.opts;
1175 opts.update_opts(conn_opts);
1176 self.routine(routines::ChangeUser).await?;
1177 self.inner.stmt_cache.clear();
1178 self.inner.infile_handler = None;
1179 self.run_setup_commands().await?;
1180 Ok(())
1181 }
1182
1183 async fn reset_for_pool(mut self) -> Result<Self> {
1187 if !self.reset().await? {
1188 self.change_user(Default::default()).await?;
1189 }
1190 Ok(self)
1191 }
1192
1193 async fn rollback_transaction(&mut self) -> Result<()> {
1195 debug_assert_ne!(self.inner.tx_status, TxStatus::None);
1196 self.inner.tx_status = TxStatus::None;
1197 self.query_drop("ROLLBACK").await
1198 }
1199
1200 pub(crate) fn more_results_exists(&self) -> bool {
1203 self.status()
1204 .contains(StatusFlags::SERVER_MORE_RESULTS_EXISTS)
1205 }
1206
1207 pub(crate) async fn drop_result(&mut self) -> Result<()> {
1212 let meta = match self.set_pending_result(None)? {
1214 Some(PendingResult::Pending(meta)) => Some(meta),
1215 Some(PendingResult::Taken(meta)) => {
1216 Some(Arc::try_unwrap(meta).expect("Conn::drop_result call on a pending result that may still be droped by someone else"))
1219 }
1220 None => None,
1221 };
1222
1223 let _ = self.set_pending_result(meta);
1224
1225 match self.use_pending_result() {
1226 Ok(Some(PendingResult::Pending(ResultSetMeta::Text(_)))) => {
1227 QueryResult::<'_, '_, TextProtocol>::new(self)
1228 .drop_result()
1229 .await
1230 }
1231 Ok(Some(PendingResult::Pending(ResultSetMeta::Binary(_)))) => {
1232 QueryResult::<'_, '_, BinaryProtocol>::new(self)
1233 .drop_result()
1234 .await
1235 }
1236 Ok(None) => Ok(()),
1237 Ok(Some(PendingResult::Taken(_))) | Err(_) => {
1238 unreachable!("this case must be handled earlier in this function")
1239 }
1240 }
1241 }
1242
1243 async fn cleanup_for_pool(mut self) -> Result<Self> {
1247 loop {
1248 let result = if self.has_pending_result() {
1249 self.drop_result().await
1250 } else if self.inner.tx_status != TxStatus::None {
1251 self.rollback_transaction().await
1252 } else {
1253 break;
1254 };
1255
1256 if let Err(err) = result {
1260 if err.is_fatal() {
1261 return Err(err);
1264 }
1265 }
1266 }
1267 Ok(self)
1268 }
1269}
1270
1271#[cfg(test)]
1272mod test {
1273 use bytes::Bytes;
1274 use futures_util::stream::{self, StreamExt};
1275 use mysql_common::constants::MAX_PAYLOAD_LEN;
1276 use rand::Fill;
1277
1278 use crate::{
1279 from_row, params, prelude::*, test_misc::get_opts, ChangeUserOpts, Conn, Error,
1280 OptsBuilder, Pool, Value, WhiteListFsHandler,
1281 };
1282
1283 #[tokio::test]
1284 async fn should_return_found_rows_if_flag_is_set() -> super::Result<()> {
1285 let opts = get_opts().client_found_rows(true);
1286 let mut conn = Conn::new(opts).await.unwrap();
1287
1288 "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)"
1289 .ignore(&mut conn)
1290 .await?;
1291
1292 "INSERT INTO mysql.found_rows (val) VALUES (1)"
1293 .ignore(&mut conn)
1294 .await?;
1295
1296 assert_eq!(conn.affected_rows(), 1);
1298
1299 "UPDATE mysql.found_rows SET val = 1 WHERE val = 1"
1300 .ignore(&mut conn)
1301 .await?;
1302
1303 assert_eq!(conn.affected_rows(), 1);
1306
1307 Ok(())
1308 }
1309
1310 #[tokio::test]
1311 async fn should_not_return_found_rows_if_flag_is_not_set() -> super::Result<()> {
1312 let mut conn = Conn::new(get_opts()).await.unwrap();
1313
1314 "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)"
1315 .ignore(&mut conn)
1316 .await?;
1317
1318 "INSERT INTO mysql.found_rows (val) VALUES (1)"
1319 .ignore(&mut conn)
1320 .await?;
1321
1322 assert_eq!(conn.affected_rows(), 1);
1324
1325 "UPDATE mysql.found_rows SET val = 1 WHERE val = 1"
1326 .ignore(&mut conn)
1327 .await?;
1328
1329 assert_eq!(conn.affected_rows(), 0);
1331
1332 Ok(())
1333 }
1334
1335 #[test]
1336 fn opts_should_satisfy_send_and_sync() {
1337 struct A<T: Sync + Send>(T);
1338 #[allow(clippy::unnecessary_operation)]
1339 A(get_opts());
1340 }
1341
1342 #[tokio::test]
1343 async fn should_connect_without_database() -> super::Result<()> {
1344 let mut conn: Conn = Conn::new(get_opts().db_name(None::<String>)).await?;
1346 conn.ping().await?;
1347 conn.disconnect().await?;
1348
1349 let mut conn: Conn = Conn::new(get_opts().db_name(Some(""))).await?;
1351 conn.ping().await?;
1352 conn.disconnect().await?;
1353
1354 Ok(())
1355 }
1356
1357 #[tokio::test]
1358 async fn should_clean_state_if_wrapper_is_dropeed() -> super::Result<()> {
1359 let mut conn: Conn = Conn::new(get_opts()).await?;
1360
1361 conn.query_drop("CREATE TEMPORARY TABLE mysql.foo (id SERIAL)")
1362 .await?;
1363
1364 conn.query_iter("SELECT 1").await?;
1366 conn.ping().await?;
1367
1368 let mut tx = conn.start_transaction(Default::default()).await?;
1370 tx.query_drop("INSERT INTO mysql.foo (id) VALUES (42)")
1371 .await?;
1372 tx.exec_iter("SELECT COUNT(*) FROM mysql.foo", ()).await?;
1373 drop(tx);
1374 conn.ping().await?;
1375
1376 let count: u8 = conn
1377 .query_first("SELECT COUNT(*) FROM mysql.foo")
1378 .await?
1379 .unwrap_or_default();
1380
1381 assert_eq!(count, 0);
1382
1383 Ok(())
1384 }
1385
1386 #[tokio::test]
1387 async fn should_connect() -> super::Result<()> {
1388 let mut conn: Conn = Conn::new(get_opts()).await?;
1389 conn.ping().await?;
1390 let plugins: Vec<String> = conn
1391 .query_map("SHOW PLUGINS", |mut row: crate::Row| {
1392 row.take("Name").unwrap()
1393 })
1394 .await?;
1395
1396 let variants = vec![
1398 ("caching_sha2_password", 2_u8, "non-empty"),
1399 ("caching_sha2_password", 2_u8, ""),
1400 ("mysql_native_password", 0_u8, "non-empty"),
1401 ("mysql_native_password", 0_u8, ""),
1402 ]
1403 .into_iter()
1404 .filter(|variant| plugins.iter().any(|p| p == variant.0));
1405
1406 for (plug, val, pass) in variants {
1407 let _ = conn.query_drop("DROP USER 'test_user'@'%'").await;
1408
1409 let query = format!("CREATE USER 'test_user'@'%' IDENTIFIED WITH {}", plug);
1410 conn.query_drop(query).await.unwrap();
1411
1412 if (8, 0, 11) <= conn.inner.version && conn.inner.version <= (9, 0, 0) {
1413 conn.query_drop(format!("SET PASSWORD FOR 'test_user'@'%' = '{}'", pass))
1414 .await
1415 .unwrap();
1416 } else {
1417 conn.query_drop(format!("SET old_passwords = {}", val))
1418 .await
1419 .unwrap();
1420 conn.query_drop(format!(
1421 "SET PASSWORD FOR 'test_user'@'%' = PASSWORD('{}')",
1422 pass
1423 ))
1424 .await
1425 .unwrap();
1426 };
1427
1428 let opts = get_opts()
1429 .user(Some("test_user"))
1430 .pass(Some(pass))
1431 .db_name(None::<String>);
1432 let result = Conn::new(opts).await;
1433
1434 conn.query_drop("DROP USER 'test_user'@'%'").await.unwrap();
1435
1436 result?.disconnect().await?;
1437 }
1438
1439 if crate::test_misc::test_compression() {
1440 assert!(format!("{:?}", conn).contains("Compression"));
1441 }
1442
1443 if crate::test_misc::test_ssl() {
1444 assert!(format!("{:?}", conn).contains("Tls"));
1445 }
1446
1447 conn.disconnect().await?;
1448 Ok(())
1449 }
1450
/// Dropping a connected `Conn` when no tokio runtime is available must not panic.
///
/// The module-level `disconnect` helper only spawns its async cleanup when
/// `tokio::runtime::Handle::try_current()` succeeds. If the connection were
/// dropped inside `block_on`, a runtime context would still exist and the
/// no-runtime branch would never run. So the connection is moved out of
/// `block_on`, the runtime is dropped, and only then is the connection dropped.
#[test]
fn should_not_panic_if_dropped_without_tokio_runtime() {
    let fut = Conn::new(get_opts());
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();
    let conn = runtime.block_on(fut).unwrap();

    drop(runtime);
    // No runtime context remains on this thread; this must not panic.
    drop(conn);
}
1463
1464 #[tokio::test]
1465 async fn should_execute_init_queries_on_new_connection() -> super::Result<()> {
1466 let opts = OptsBuilder::from_opts(get_opts()).init(vec!["SET @a = 42", "SET @b = 'foo'"]);
1467 let mut conn = Conn::new(opts).await?;
1468 let result: Vec<(u8, String)> = conn.query("SELECT @a, @b").await?;
1469 conn.disconnect().await?;
1470 assert_eq!(result, vec![(42, "foo".into())]);
1471 Ok(())
1472 }
1473
1474 #[tokio::test]
1475 async fn should_execute_setup_queries_on_reset() -> super::Result<()> {
1476 let opts = OptsBuilder::from_opts(get_opts()).setup(vec!["SET @a = 42", "SET @b = 'foo'"]);
1477 let mut conn = Conn::new(opts).await?;
1478
1479 let mut result: Vec<(u8, String)> = conn.query("SELECT @a, @b").await?;
1481 assert_eq!(result, vec![(42, "foo".into())]);
1482
1483 if conn.reset().await? {
1485 result = conn.query("SELECT @a, @b").await?;
1486 assert_eq!(result, vec![(42, "foo".into())]);
1487 }
1488
1489 conn.change_user(Default::default()).await?;
1491 result = conn.query("SELECT @a, @b").await?;
1492 assert_eq!(result, vec![(42, "foo".into())]);
1493
1494 conn.disconnect().await?;
1495 Ok(())
1496 }
1497
1498 #[tokio::test]
1499 async fn should_reset_the_connection() -> super::Result<()> {
1500 let mut conn = Conn::new(get_opts()).await?;
1501
1502 assert_eq!(
1503 conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
1504 Value::NULL
1505 );
1506
1507 conn.query_drop("SET @foo = 'foo'").await?;
1508
1509 assert_eq!(
1510 conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
1511 "foo",
1512 );
1513
1514 if conn.reset().await? {
1515 assert_eq!(
1516 conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
1517 Value::NULL
1518 );
1519 } else {
1520 assert_eq!(
1521 conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
1522 "foo",
1523 );
1524 }
1525
1526 conn.disconnect().await?;
1527 Ok(())
1528 }
1529
1530 #[tokio::test]
1531 async fn should_change_user() -> super::Result<()> {
1532 let mut conn = Conn::new(get_opts()).await?;
1533 assert_eq!(
1534 conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
1535 Value::NULL
1536 );
1537
1538 conn.query_drop("SET @foo = 'foo'").await?;
1539
1540 assert_eq!(
1541 conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
1542 "foo",
1543 );
1544
1545 conn.change_user(Default::default()).await?;
1546 assert_eq!(
1547 conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
1548 Value::NULL
1549 );
1550
1551 let plugins: &[&str] = if !conn.inner.is_mariadb && conn.server_version() >= (5, 8, 0) {
1552 &["mysql_native_password", "caching_sha2_password"]
1553 } else {
1554 &["mysql_native_password"]
1555 };
1556
1557 for plugin in plugins {
1558 let mut rng = rand::thread_rng();
1559 let mut pass = [0u8; 10];
1560 pass.try_fill(&mut rng).unwrap();
1561 let pass: String = IntoIterator::into_iter(pass)
1562 .map(|x| ((x % (123 - 97)) + 97) as char)
1563 .collect();
1564
1565 conn.query_drop("DELETE FROM mysql.user WHERE user = '__mats'")
1566 .await
1567 .unwrap();
1568 conn.query_drop("FLUSH PRIVILEGES").await.unwrap();
1569
1570 if conn.inner.is_mariadb || conn.server_version() < (5, 7, 0) {
1571 if matches!(conn.server_version(), (5, 6, _)) {
1572 conn.query_drop("CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_old_password")
1573 .await
1574 .unwrap();
1575 conn.query_drop(format!(
1576 "SET PASSWORD FOR '__mats'@'%' = OLD_PASSWORD({})",
1577 Value::from(pass.clone()).as_sql(false)
1578 ))
1579 .await
1580 .unwrap();
1581 } else {
1582 conn.query_drop("CREATE USER '__mats'@'%'").await.unwrap();
1583 conn.query_drop(format!(
1584 "SET PASSWORD FOR '__mats'@'%' = PASSWORD({})",
1585 Value::from(pass.clone()).as_sql(false)
1586 ))
1587 .await
1588 .unwrap();
1589 }
1590 } else {
1591 conn.query_drop(format!(
1592 "CREATE USER '__mats'@'%' IDENTIFIED WITH {} BY {}",
1593 plugin,
1594 Value::from(pass.clone()).as_sql(false)
1595 ))
1596 .await
1597 .unwrap();
1598 };
1599
1600 conn.query_drop("FLUSH PRIVILEGES").await.unwrap();
1601
1602 let mut conn2 = Conn::new(get_opts().secure_auth(false)).await.unwrap();
1603 conn2
1604 .change_user(
1605 ChangeUserOpts::default()
1606 .with_db_name(None)
1607 .with_user(Some("__mats".into()))
1608 .with_pass(Some(pass)),
1609 )
1610 .await
1611 .unwrap();
1612 let (db, user) = conn2
1613 .query_first::<(Option<String>, String), _>("SELECT DATABASE(), USER();")
1614 .await
1615 .unwrap()
1616 .unwrap();
1617 assert_eq!(db, None);
1618 assert!(user.starts_with("__mats"));
1619
1620 conn2.disconnect().await.unwrap();
1621 }
1622
1623 conn.disconnect().await?;
1624 Ok(())
1625 }
1626
1627 #[tokio::test]
1628 async fn should_not_cache_statements_if_stmt_cache_size_is_zero() -> super::Result<()> {
1629 let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(0);
1630
1631 let mut conn = Conn::new(opts).await?;
1632 conn.exec_drop("DO ?", (1_u8,)).await?;
1633
1634 let stmt = conn.prep("DO 2").await?;
1635 conn.exec_drop(&stmt, ()).await?;
1636 conn.exec_drop(&stmt, ()).await?;
1637 conn.close(stmt).await?;
1638
1639 conn.exec_drop("DO 3", ()).await?;
1640 conn.exec_batch("DO 4", vec![(), ()]).await?;
1641 conn.exec_first::<u8, _, _>("DO 5", ()).await?;
1642 let row: Option<(crate::Value, usize)> = conn
1643 .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
1644 .await?;
1645
1646 assert_eq!(row.unwrap().1, 1);
1647 assert_eq!(conn.inner.stmt_cache.len(), 0);
1648
1649 conn.disconnect().await?;
1650
1651 Ok(())
1652 }
1653
1654 #[tokio::test]
1655 async fn should_hold_stmt_cache_size_bound() -> super::Result<()> {
1656 let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(3);
1657 let mut conn = Conn::new(opts).await?;
1658 conn.exec_drop("DO 1", ()).await?;
1659 conn.exec_drop("DO 2", ()).await?;
1660 conn.exec_drop("DO 3", ()).await?;
1661 conn.exec_drop("DO 1", ()).await?;
1662 conn.exec_drop("DO 4", ()).await?;
1663 conn.exec_drop("DO 3", ()).await?;
1664 conn.exec_drop("DO 5", ()).await?;
1665 conn.exec_drop("DO 6", ()).await?;
1666 let row_opt = conn
1667 .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
1668 .await?;
1669 let (_, count): (String, usize) = row_opt.unwrap();
1670 assert_eq!(count, 3);
1671 let order = conn
1672 .stmt_cache_ref()
1673 .iter()
1674 .map(|item| item.1.query.0.as_ref())
1675 .collect::<Vec<&[u8]>>();
1676 assert_eq!(order, &[b"DO 6", b"DO 5", b"DO 3"]);
1677 conn.disconnect().await?;
1678 Ok(())
1679 }
1680
1681 #[tokio::test]
1682 async fn should_perform_queries() -> super::Result<()> {
1683 let mut conn = Conn::new(get_opts()).await?;
1684 for x in (MAX_PAYLOAD_LEN - 2)..=(MAX_PAYLOAD_LEN + 2) {
1685 let long_string = "A".repeat(x);
1686 let result: Vec<(String, u8)> = conn
1687 .query(format!(r"SELECT '{}', 231", long_string))
1688 .await?;
1689 assert_eq!((long_string, 231_u8), result[0]);
1690 }
1691 conn.disconnect().await?;
1692 Ok(())
1693 }
1694
1695 #[tokio::test]
1696 async fn should_query_drop() -> super::Result<()> {
1697 let mut conn = Conn::new(get_opts()).await?;
1698 conn.query_drop("CREATE TEMPORARY TABLE tmp (id int DEFAULT 10, name text)")
1699 .await?;
1700 conn.query_drop("INSERT INTO tmp VALUES (1, 'foo')").await?;
1701 let result: Option<u8> = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1702 conn.disconnect().await?;
1703 assert_eq!(result, Some(1_u8));
1704 Ok(())
1705 }
1706
1707 #[tokio::test]
1708 async fn should_prepare_statement() -> super::Result<()> {
1709 let mut conn = Conn::new(get_opts()).await?;
1710 let stmt = conn.prep(r"SELECT ?").await?;
1711 conn.close(stmt).await?;
1712 conn.disconnect().await?;
1713
1714 let mut conn = Conn::new(get_opts()).await?;
1715 let stmt = conn.prep(r"SELECT :foo").await?;
1716
1717 {
1718 let query = String::from("SELECT ?, ?");
1719 let stmt = conn.prep(&*query).await?;
1720 conn.close(stmt).await?;
1721 {
1722 let mut conn = Conn::new(get_opts()).await?;
1723 let stmt = conn.prep(&*query).await?;
1724 conn.close(stmt).await?;
1725 conn.disconnect().await?;
1726 }
1727 }
1728
1729 conn.close(stmt).await?;
1730 conn.disconnect().await?;
1731
1732 Ok(())
1733 }
1734
1735 #[tokio::test]
1736 async fn should_execute_statement() -> super::Result<()> {
1737 let long_string = "A".repeat(18 * 1024 * 1024);
1738 let mut conn = Conn::new(get_opts()).await?;
1739 let stmt = conn.prep(r"SELECT ?").await?;
1740 let result = conn.exec_iter(&stmt, (&long_string,)).await?;
1741 let mut mapped = result.map_and_drop(from_row::<(String,)>).await?;
1742 assert_eq!(mapped.len(), 1);
1743 assert_eq!(mapped.pop(), Some((long_string,)));
1744 let result = conn.exec_iter(&stmt, (42_u8,)).await?;
1745 let collected = result.collect_and_drop::<(u8,)>().await?;
1746 assert_eq!(collected, vec![(42u8,)]);
1747 let result = conn.exec_iter(&stmt, (8_u8,)).await?;
1748 let reduced = result
1749 .reduce_and_drop(2, |mut acc, row| {
1750 acc += from_row::<i32>(row);
1751 acc
1752 })
1753 .await?;
1754 conn.close(stmt).await?;
1755 conn.disconnect().await?;
1756 assert_eq!(reduced, 10);
1757
1758 let mut conn = Conn::new(get_opts()).await?;
1759 let stmt = conn.prep(r"SELECT :foo, :bar, :foo, 3").await?;
1760 let result = conn
1761 .exec_iter(&stmt, params! { "foo" => "quux", "bar" => "baz" })
1762 .await?;
1763 let mut mapped = result
1764 .map_and_drop(from_row::<(String, String, String, u8)>)
1765 .await?;
1766 assert_eq!(mapped.len(), 1);
1767 assert_eq!(
1768 mapped.pop(),
1769 Some(("quux".into(), "baz".into(), "quux".into(), 3))
1770 );
1771 let result = conn
1772 .exec_iter(&stmt, params! { "foo" => 2, "bar" => 3 })
1773 .await?;
1774 let collected = result.collect_and_drop::<(u8, u8, u8, u8)>().await?;
1775 assert_eq!(collected, vec![(2, 3, 2, 3)]);
1776 let result = conn
1777 .exec_iter(&stmt, params! { "foo" => 2, "bar" => 3 })
1778 .await?;
1779 let reduced = result
1780 .reduce_and_drop(0, |acc, row| {
1781 let (a, b, c, d): (u8, u8, u8, u8) = from_row(row);
1782 acc + a + b + c + d
1783 })
1784 .await?;
1785 conn.close(stmt).await?;
1786 conn.disconnect().await?;
1787 assert_eq!(reduced, 10);
1788 Ok(())
1789 }
1790
1791 #[tokio::test]
1792 async fn should_prep_exec_statement() -> super::Result<()> {
1793 let mut conn = Conn::new(get_opts()).await?;
1794 let result = conn
1795 .exec_iter(r"SELECT :a, :b, :a", params! { "a" => 2, "b" => 3 })
1796 .await?;
1797 let output = result
1798 .map_and_drop(|row| {
1799 let (a, b, c): (u8, u8, u8) = from_row(row);
1800 a * b * c
1801 })
1802 .await?;
1803 conn.disconnect().await?;
1804 assert_eq!(output[0], 12u8);
1805 Ok(())
1806 }
1807
1808 #[tokio::test]
1809 async fn should_first_exec_statement() -> super::Result<()> {
1810 let mut conn = Conn::new(get_opts()).await?;
1811 let output = conn
1812 .exec_first(
1813 r"SELECT :a UNION ALL SELECT :b",
1814 params! { "a" => 2, "b" => 3 },
1815 )
1816 .await?;
1817 conn.disconnect().await?;
1818 assert_eq!(output, Some(2u8));
1819 Ok(())
1820 }
1821
1822 #[tokio::test]
1823 async fn issue_107() -> super::Result<()> {
1824 let mut conn = Conn::new(get_opts()).await?;
1825 conn.query_drop(
1826 r"CREATE TEMPORARY TABLE mysql.issue (
1827 a BIGINT(20) UNSIGNED,
1828 b VARBINARY(16),
1829 c BINARY(32),
1830 d BIGINT(20) UNSIGNED,
1831 e BINARY(32)
1832 )",
1833 )
1834 .await?;
1835 conn.query_drop(
1836 r"INSERT INTO mysql.issue VALUES (
1837 0,
1838 0xC066F966B0860000,
1839 0x7939DA98E524C5F969FC2DE8D905FD9501EBC6F20001B0A9C941E0BE6D50CF44,
1840 0,
1841 ''
1842 ), (
1843 1,
1844 '',
1845 0x076311DF4D407B0854371BA13A5F3FB1A4555AC22B361375FD47B263F31822F2,
1846 0,
1847 ''
1848 )",
1849 )
1850 .await?;
1851
1852 let q = "SELECT b, c, d, e FROM mysql.issue";
1853 let result = conn.query_iter(q).await?;
1854
1855 let loaded_structs = result
1856 .map_and_drop(crate::from_row::<(Vec<u8>, Vec<u8>, u64, Vec<u8>)>)
1857 .await?;
1858
1859 conn.disconnect().await?;
1860
1861 assert_eq!(loaded_structs.len(), 2);
1862
1863 Ok(())
1864 }
1865
1866 #[tokio::test]
1867 async fn should_run_transactions() -> super::Result<()> {
1868 let mut conn = Conn::new(get_opts()).await?;
1869 conn.query_drop("CREATE TEMPORARY TABLE tmp (id INT, name TEXT)")
1870 .await?;
1871 let mut transaction = conn.start_transaction(Default::default()).await?;
1872 transaction
1873 .query_drop("INSERT INTO tmp VALUES (1, 'foo'), (2, 'bar')")
1874 .await?;
1875 assert_eq!(transaction.last_insert_id(), None);
1876 assert_eq!(transaction.affected_rows(), 2);
1877 assert_eq!(transaction.get_warnings(), 0);
1878 assert_eq!(transaction.info(), "Records: 2 Duplicates: 0 Warnings: 0");
1879 transaction.commit().await?;
1880 let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1881 assert_eq!(output_opt, Some((2u8,)));
1882 let mut transaction = conn.start_transaction(Default::default()).await?;
1883 transaction
1884 .query_drop("INSERT INTO tmp VALUES (3, 'baz'), (4, 'quux')")
1885 .await?;
1886 let output_opt = transaction
1887 .exec_first("SELECT COUNT(*) FROM tmp", ())
1888 .await?;
1889 assert_eq!(output_opt, Some((4u8,)));
1890 transaction.rollback().await?;
1891 let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1892 assert_eq!(output_opt, Some((2u8,)));
1893
1894 let mut transaction = conn.start_transaction(Default::default()).await?;
1895 transaction
1896 .query_drop("INSERT INTO tmp VALUES (3, 'baz')")
1897 .await?;
1898 drop(transaction); let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1900 assert_eq!(output_opt, Some((2u8,)));
1901
1902 conn.disconnect().await?;
1903 Ok(())
1904 }
1905
1906 #[tokio::test]
1907 async fn should_handle_multiresult_set_with_error() -> super::Result<()> {
1908 const QUERY_FIRST: &str = "SELECT * FROM tmp; SELECT 1; SELECT 2;";
1909 const QUERY_MIDDLE: &str = "SELECT 1; SELECT * FROM tmp; SELECT 2";
1910 let mut conn = Conn::new(get_opts()).await.unwrap();
1911
1912 let result = QUERY_FIRST.run(&mut conn).await;
1914 assert!(matches!(result, Err(Error::Server(_))));
1915
1916 let mut result = QUERY_MIDDLE.run(&mut conn).await.unwrap();
1917
1918 let result_set: Vec<u8> = result.collect().await.unwrap();
1920 assert_eq!(result_set, vec![1]);
1921
1922 let result_set: super::Result<Vec<u8>> = result.collect().await;
1924 assert!(matches!(result_set, Err(Error::Server(_))));
1925
1926 assert!(result.is_empty());
1928
1929 conn.ping().await?;
1930 conn.disconnect().await?;
1931
1932 Ok(())
1933 }
1934
1935 #[tokio::test]
1936 async fn should_handle_binary_multiresult_set_with_error() -> super::Result<()> {
1937 const PROC_DEF_FIRST: &str =
1938 r#"CREATE PROCEDURE err_first() BEGIN SELECT * FROM tmp; SELECT 1; END"#;
1939 const PROC_DEF_MIDDLE: &str =
1940 r#"CREATE PROCEDURE err_middle() BEGIN SELECT 1; SELECT * FROM tmp; SELECT 2; END"#;
1941
1942 let mut conn = Conn::new(get_opts()).await.unwrap();
1943
1944 conn.query_drop("DROP PROCEDURE IF EXISTS err_first")
1945 .await?;
1946 conn.query_iter(PROC_DEF_FIRST).await?;
1947
1948 conn.query_drop("DROP PROCEDURE IF EXISTS err_middle")
1949 .await?;
1950 conn.query_iter(PROC_DEF_MIDDLE).await?;
1951
1952 let result = conn.query_iter("CALL err_first()").await;
1954 assert!(matches!(result, Err(Error::Server(_))));
1955
1956 let mut result = conn.query_iter("CALL err_middle()").await?;
1957
1958 let result_set: Vec<u8> = result.collect().await.unwrap();
1960 assert_eq!(result_set, vec![1]);
1961
1962 let result_set: super::Result<Vec<u8>> = result.collect().await;
1964 assert!(matches!(result_set, Err(Error::Server(_))));
1965
1966 assert!(result.is_empty());
1968
1969 conn.ping().await?;
1970 conn.disconnect().await?;
1971
1972 Ok(())
1973 }
1974
1975 #[tokio::test]
1976 async fn should_handle_multiresult_set_with_local_infile() -> super::Result<()> {
1977 use std::fs::write;
1978
1979 let file_path = tempfile::Builder::new().tempfile_in("").unwrap();
1980 let file_path = file_path.path();
1981 let file_name = file_path.file_name().unwrap();
1982
1983 write(file_name, b"AAAAAA\nBBBBBB\nCCCCCC\n")?;
1984
1985 let opts = get_opts().local_infile_handler(Some(WhiteListFsHandler::new(&[file_name][..])));
1986
1987 let mut conn = Conn::new(opts).await.unwrap();
1989 "CREATE TEMPORARY TABLE tmp (a TEXT)".run(&mut conn).await?;
1990
1991 let query = format!(
1992 r#"SELECT * FROM tmp;
1993 LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;
1994 LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;
1995 SELECT * FROM tmp"#,
1996 file_name.to_str().unwrap(),
1997 file_name.to_str().unwrap(),
1998 );
1999
2000 let mut result = query.run(&mut conn).await?;
2001
2002 let result_set = result.collect::<String>().await?;
2003 assert_eq!(result_set.len(), 0);
2004
2005 let mut no_local_infile = false;
2006
2007 for _ in 0..2 {
2008 match result.collect::<String>().await {
2009 Ok(result_set) => {
2010 assert_eq!(result.affected_rows(), 3);
2011 assert!(result_set.is_empty())
2012 }
2013 Err(Error::Server(ref err)) if err.code == 1148 => {
2014 no_local_infile = true;
2016 break;
2017 }
2018 Err(Error::Server(ref err)) if err.code == 3948 => {
2019 no_local_infile = true;
2022 break;
2023 }
2024 Err(err) => return Err(err),
2025 }
2026 }
2027
2028 if no_local_infile {
2029 assert!(result.is_empty());
2030 assert_eq!(result_set.len(), 0);
2031 } else {
2032 let result_set = result.collect::<String>().await?;
2033 assert_eq!(result_set.len(), 6);
2034 assert_eq!(result_set[0], "AAAAAA");
2035 assert_eq!(result_set[1], "BBBBBB");
2036 assert_eq!(result_set[2], "CCCCCC");
2037 assert_eq!(result_set[3], "AAAAAA");
2038 assert_eq!(result_set[4], "BBBBBB");
2039 assert_eq!(result_set[5], "CCCCCC");
2040 }
2041
2042 conn.ping().await?;
2043 conn.disconnect().await?;
2044
2045 Ok(())
2046 }
2047
2048 #[tokio::test]
2049 async fn should_provide_multiresult_set_metadata() -> super::Result<()> {
2050 let mut c = Conn::new(get_opts()).await?;
2051 c.query_drop("CREATE TEMPORARY TABLE tmp (id INT, foo TEXT)")
2052 .await?;
2053
2054 let mut result = c
2055 .query_iter("SELECT 1; SELECT id, foo FROM tmp WHERE 1 = 2; DO 42; SELECT 2;")
2056 .await?;
2057 assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 1);
2058
2059 result.for_each(drop).await?;
2060 assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 2);
2061
2062 result.for_each(drop).await?;
2063 assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 0);
2064
2065 result.for_each(drop).await?;
2066 assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 1);
2067
2068 c.disconnect().await?;
2069 Ok(())
2070 }
2071
2072 #[tokio::test]
2073 async fn should_expose_query_result_metadata() -> super::Result<()> {
2074 let pool = Pool::new(get_opts());
2075 let mut c = pool.get_conn().await?;
2076
2077 c.query_drop(
2078 r"
2079 CREATE TEMPORARY TABLE `foo`
2080 ( `id` SERIAL
2081 , `bar_id` varchar(36) NOT NULL
2082 , `baz_id` varchar(36) NOT NULL
2083 , `ctime` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP()
2084 , PRIMARY KEY (`id`)
2085 , KEY `bar_idx` (`bar_id`)
2086 , KEY `baz_idx` (`baz_id`)
2087 );",
2088 )
2089 .await?;
2090
2091 const QUERY: &str = "INSERT INTO foo (bar_id, baz_id) VALUES (?, ?)";
2092 let params = ("qwerty", "data.employee_id");
2093
2094 let query_result = c.exec_iter(QUERY, params).await?;
2095 assert_eq!(query_result.last_insert_id(), Some(1));
2096 query_result.drop_result().await?;
2097
2098 c.exec_drop(QUERY, params).await?;
2099 assert_eq!(c.last_insert_id(), Some(2));
2100
2101 let mut tx = c.start_transaction(Default::default()).await?;
2102
2103 tx.exec_drop(QUERY, params).await?;
2104 assert_eq!(tx.last_insert_id(), Some(3));
2105
2106 Ok(())
2107 }
2108
2109 #[tokio::test]
2110 async fn should_handle_local_infile_locally() -> super::Result<()> {
2111 let mut conn = Conn::new(get_opts()).await.unwrap();
2112 conn.query_drop("CREATE TEMPORARY TABLE tmp (a TEXT);")
2113 .await
2114 .unwrap();
2115
2116 conn.set_infile_handler(async move {
2117 Ok(
2118 stream::iter([Bytes::from("AAAAAA\n"), Bytes::from("BBBBBB\nCCCCCC\n")])
2119 .map(Ok)
2120 .boxed(),
2121 )
2122 });
2123
2124 match conn
2125 .query_drop(r#"LOAD DATA LOCAL INFILE "dummy" INTO TABLE tmp;"#)
2126 .await
2127 {
2128 Ok(_) => (),
2129 Err(super::Error::Server(ref err)) if err.code == 1148 => {
2130 return Ok(());
2132 }
2133 Err(super::Error::Server(ref err)) if err.code == 3948 => {
2134 return Ok(());
2137 }
2138 e @ Err(_) => e.unwrap(),
2139 };
2140
2141 let result: Vec<String> = conn.query("SELECT * FROM tmp").await?;
2142 assert_eq!(result.len(), 3);
2143 assert_eq!(result[0], "AAAAAA");
2144 assert_eq!(result[1], "BBBBBB");
2145 assert_eq!(result[2], "CCCCCC");
2146
2147 Ok(())
2148 }
2149
2150 #[tokio::test]
2151 async fn should_handle_local_infile_globally() -> super::Result<()> {
2152 use std::fs::write;
2153
2154 let file_path = tempfile::Builder::new().tempfile_in("").unwrap();
2155 let file_path = file_path.path();
2156 let file_name = file_path.file_name().unwrap();
2157
2158 write(file_name, b"AAAAAA\nBBBBBB\nCCCCCC\n")?;
2159
2160 let opts = get_opts().local_infile_handler(Some(WhiteListFsHandler::new(&[file_name][..])));
2161
2162 let mut conn = Conn::new(opts).await.unwrap();
2163 conn.query_drop("CREATE TEMPORARY TABLE tmp (a TEXT);")
2164 .await
2165 .unwrap();
2166
2167 match conn
2168 .query_drop(format!(
2169 r#"LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;"#,
2170 file_name.to_str().unwrap(),
2171 ))
2172 .await
2173 {
2174 Ok(_) => (),
2175 Err(super::Error::Server(ref err)) if err.code == 1148 => {
2176 return Ok(());
2178 }
2179 Err(super::Error::Server(ref err)) if err.code == 3948 => {
2180 return Ok(());
2183 }
2184 e @ Err(_) => e.unwrap(),
2185 };
2186
2187 let result: Vec<String> = conn.query("SELECT * FROM tmp").await?;
2188 assert_eq!(result.len(), 3);
2189 assert_eq!(result[0], "AAAAAA");
2190 assert_eq!(result[1], "BBBBBB");
2191 assert_eq!(result[2], "CCCCCC");
2192
2193 Ok(())
2194 }
2195
#[cfg(feature = "nightly")]
mod bench {
    //! Nightly-only micro-benchmarks (they use the unstable `test::Bencher`).
    use crate::{conn::Conn, queryable::Queryable, test_misc::get_opts};

    /// Starts a tokio runtime and opens a connection on it.
    ///
    /// `Runtime::block_on` takes `&self`, so the runtime is never bound mutably.
    fn runtime_and_conn() -> (tokio::runtime::Runtime, Conn) {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let conn = runtime.block_on(Conn::new(get_opts())).unwrap();
        (runtime, conn)
    }

    /// Round-trip cost of a trivial text-protocol statement.
    #[bench]
    fn simple_exec(bencher: &mut test::Bencher) {
        let (runtime, mut conn) = runtime_and_conn();

        bencher.iter(|| {
            runtime.block_on(conn.query_drop("DO 1")).unwrap();
        });

        runtime.block_on(conn.disconnect()).unwrap();
    }

    /// Cost of receiving and discarding a 10 000-character string.
    #[bench]
    fn select_large_string(bencher: &mut test::Bencher) {
        let (runtime, mut conn) = runtime_and_conn();

        bencher.iter(|| {
            runtime
                .block_on(conn.query_drop("SELECT REPEAT('A', 10000)"))
                .unwrap();
        });

        runtime.block_on(conn.disconnect()).unwrap();
    }

    /// Executing an already-prepared statement.
    #[bench]
    fn prepared_exec(bencher: &mut test::Bencher) {
        let (runtime, mut conn) = runtime_and_conn();
        let stmt = runtime.block_on(conn.prep("DO 1")).unwrap();

        bencher.iter(|| {
            runtime.block_on(conn.exec_drop(&stmt, ())).unwrap();
        });

        runtime.block_on(conn.close(stmt)).unwrap();
        runtime.block_on(conn.disconnect()).unwrap();
    }

    /// Executing from a query string, which goes through the prepare path.
    #[bench]
    fn prepare_and_exec(bencher: &mut test::Bencher) {
        let (runtime, mut conn) = runtime_and_conn();

        bencher.iter(|| {
            runtime.block_on(conn.exec_drop("SELECT ?", (0,))).unwrap();
        });

        runtime.block_on(conn.disconnect()).unwrap();
    }
}
2252}