1use std::{
3 borrow::Cow,
4 fmt::Display,
5 marker::PhantomData,
6 mem::ManuallyDrop,
7 net::{IpAddr, Ipv4Addr, SocketAddr, ToSocketAddrs},
8 ops::Range,
9};
10
11use bytes::{Buf, BufMut, BytesMut};
12use thiserror::Error;
13use tokio::{
14 io::{AsyncReadExt, AsyncWriteExt},
15 net::{TcpSocket, UnixStream},
16};
17
18use crate::{
19 args::Args,
20 auth::{AuthPlugin, compute_auth},
21 bind::{Bind, BindError},
22 constants::{client, com},
23 decode::Column,
24 lru::{Entry, LRUCache},
25 package_parser::{DecodeError, DecodeResult, PackageParser},
26 row::{FromRow, Row},
27};
28
29#[derive(Error, Debug)]
31#[non_exhaustive]
32pub enum ConnectionErrorContent {
33 #[error("mysql error {code}: {message}")]
35 Mysql {
36 code: u16,
38 status: [u8; 5],
40 message: String,
42 },
43 #[error(transparent)]
45 Io(#[from] tokio::io::Error),
46 #[error("error reading {0}: {1}")]
48 Decode(&'static str, DecodeError),
49 #[error("error binding paramater {0}: {1}")]
51 Bind(u16, BindError),
52 #[error("protocol error {0}")]
54 ProtocolError(String),
55 #[error("fetch return no columns")]
57 ExpectedRows,
58 #[error("rows return for execute")]
60 UnexpectedRows,
61 #[cfg(feature = "cancel_testing")]
62 #[doc(hidden)]
64 #[error("await threshold reached")]
65 TestCancelled,
66 #[error("await threshold reached")]
68 TooFewListArguments,
69 #[error("await threshold reached")]
71 TooManyListArguments,
72 #[error("Invalid url")]
74 InvalidUrl,
75}
76
77pub struct ConnectionError(Box<ConnectionErrorContent>);
82
83const _: () = {
84 assert!(size_of::<ConnectionError>() == size_of::<usize>());
85};
86
87impl ConnectionError {
88 pub fn content(&self) -> &ConnectionErrorContent {
90 &self.0
91 }
92}
93
94impl std::ops::Deref for ConnectionError {
95 type Target = ConnectionErrorContent;
96
97 fn deref(&self) -> &Self::Target {
98 &self.0
99 }
100}
101
102impl<E: Into<ConnectionErrorContent>> From<E> for ConnectionError {
103 fn from(value: E) -> Self {
104 ConnectionError(Box::new(value.into()))
105 }
106}
107
108impl std::fmt::Debug for ConnectionError {
109 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110 std::fmt::Debug::fmt(&self.0, f)
111 }
112}
113
114impl std::fmt::Display for ConnectionError {
115 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116 std::fmt::Display::fmt(&self.0, f)
117 }
118}
119
120impl std::error::Error for ConnectionError {}
121
122pub type ConnectionResult<T> = std::result::Result<T, ConnectionError>;
124pub trait WithLoc<T> {
126 fn loc(self, loc: &'static str) -> ConnectionResult<T>;
128}
129
130impl<T> WithLoc<T> for DecodeResult<T> {
131 fn loc(self, loc: &'static str) -> ConnectionResult<T> {
132 self.map_err(|v| ConnectionErrorContent::Decode(loc, v).into())
133 }
134}
135
136trait Except {
138 type Value;
140
141 fn ev(self, loc: &'static str, expected: Self::Value) -> ConnectionResult<()>;
143}
144
145impl<T: Eq + Display> Except for DecodeResult<T> {
146 type Value = T;
147
148 fn ev(self, loc: &'static str, expected: T) -> ConnectionResult<()> {
149 let v = self.loc(loc)?;
150 if v != expected {
151 Err(ConnectionErrorContent::ProtocolError(format!(
152 "Expected {expected} for {loc} got {v}"
153 ))
154 .into())
155 } else {
156 Ok(())
157 }
158 }
159}
160
161pub trait RowMap<'a> {
163 type E: From<ConnectionError>;
165
166 type T;
168
169 fn map(row: Row<'a>) -> Result<Self::T, Self::E>;
171}
172
173struct FromRowMapper<T>(PhantomData<T>);
175
176impl<'a, T: FromRow<'a>> RowMap<'a> for FromRowMapper<T> {
177 type E = ConnectionError;
178
179 type T = T;
180
181 fn map(row: Row<'a>) -> Result<Self::T, Self::E> {
182 T::from_row(&row).loc("row")
183 }
184}
185
186enum OwnedReadHalf {
188 Tcp(tokio::net::tcp::OwnedReadHalf),
190 Unix(tokio::net::unix::OwnedReadHalf),
192}
193
194enum OwnedWriteHalf {
196 Tcp(tokio::net::tcp::OwnedWriteHalf),
198 Unix(tokio::net::unix::OwnedWriteHalf),
200}
201
202struct Reader {
204 buff: BytesMut,
206 read: OwnedReadHalf,
208 skip_on_read: usize,
210 buffer_packages: bool,
212}
213
214impl Reader {
215 fn new(read: OwnedReadHalf) -> Self {
217 Self {
218 read,
219 buff: BytesMut::with_capacity(1234),
220 skip_on_read: 0,
221 buffer_packages: false,
222 }
223 }
224
225 async fn read_raw(&mut self) -> ConnectionResult<Range<usize>> {
229 if !self.buffer_packages {
230 self.buff.advance(self.skip_on_read);
231 self.skip_on_read = 0;
232 }
233
234 while self.buff.remaining() < 4 + self.skip_on_read {
235 match &mut self.read {
236 OwnedReadHalf::Tcp(r) => r.read_buf(&mut self.buff).await?,
237 OwnedReadHalf::Unix(r) => r.read_buf(&mut self.buff).await?,
238 };
239 }
240 let y: u32 = u32::from_le_bytes(
241 self.buff[self.skip_on_read..self.skip_on_read + 4]
242 .try_into()
243 .unwrap(),
244 );
245 let len: usize = (y & 0xFFFFFF).try_into().unwrap();
246 let _s = (y >> 24) as u8;
247 if len == 0xFFFFFF {
248 return Err(ConnectionErrorContent::ProtocolError(
249 "Extended packages not supported".to_string(),
250 )
251 .into());
252 }
253 while self.buff.remaining() < self.skip_on_read + 4 + len {
254 match &mut self.read {
255 OwnedReadHalf::Tcp(r) => r.read_buf(&mut self.buff).await?,
256 OwnedReadHalf::Unix(r) => r.read_buf(&mut self.buff).await?,
257 };
258 }
259 let r = self.skip_on_read + 4..self.skip_on_read + 4 + len;
260 self.skip_on_read += 4 + len;
261 Ok(r)
262 }
263
264 #[inline]
268 async fn read(&mut self) -> ConnectionResult<&[u8]> {
269 let r = self.read_raw().await?;
270 Ok(self.bytes(r))
271 }
272
273 #[inline]
281 fn bytes(&self, r: Range<usize>) -> &[u8] {
282 &self.buff[r]
283 }
284}
285
286struct Writer {
288 write: OwnedWriteHalf,
290 buff: BytesMut,
292 seq: u8,
294}
295
296impl Writer {
297 fn new(write: OwnedWriteHalf) -> Self {
299 Writer {
300 write,
301 buff: BytesMut::with_capacity(1234),
302 seq: 1,
303 }
304 }
305
306 fn compose(&mut self) -> Composer<'_> {
308 self.buff.clear();
309 self.buff.put_u32(0);
310 Composer { writer: self }
311 }
312
313 async fn send(&mut self) -> ConnectionResult<()> {
315 match &mut self.write {
316 OwnedWriteHalf::Tcp(r) => Ok(r.write_all_buf(&mut self.buff).await?),
317 OwnedWriteHalf::Unix(r) => Ok(r.write_all_buf(&mut self.buff).await?),
318 }
319 }
320}
321
322struct Composer<'a> {
324 writer: &'a mut Writer,
326}
327
328impl<'a> Composer<'a> {
329 fn put_u32(&mut self, v: u32) {
331 self.writer.buff.put_u32_le(v)
332 }
333
334 fn put_u16(&mut self, v: u16) {
336 self.writer.buff.put_u16_le(v)
337 }
338
339 fn put_u8(&mut self, v: u8) {
341 self.writer.buff.put_u8(v)
342 }
343
344 fn put_str_null(&mut self, s: &str) {
346 self.writer.buff.put(s.as_bytes());
347 self.writer.buff.put_u8(0);
348 }
349
350 fn put_bytes(&mut self, s: &[u8]) {
352 self.writer.buff.put(s);
353 }
354
355 fn finalize(self) {
357 let len = self.writer.buff.len();
358 let mut x = &mut self.writer.buff[..4];
359 x.put_u32_le((len - 4) as u32 | ((self.writer.seq as u32) << 24));
360 self.writer.seq = self.writer.seq.wrapping_add(1);
361 }
362}
363
/// Settings used to open a connection.
///
/// Build with [`ConnectionOptions::new`] plus the builder methods, or with
/// [`ConnectionOptions::from_url`].
pub struct ConnectionOptions<'a> {
    /// Unix socket path; when set, it is used instead of `address`.
    unix_socket: Option<Cow<'a, std::path::Path>>,
    /// TCP address of the server.
    address: SocketAddr,
    /// User name sent in the handshake.
    user: Cow<'a, str>,
    /// Password used to compute the auth response.
    password: Cow<'a, str>,
    /// Default database to select at connect time, if any.
    database: Option<Cow<'a, str>>,
    /// NOTE(review): presumably the capacity of the prepared-statement (LRU)
    /// cache, with "case" a typo for "cache". Confirm against its use.
    statement_case_size: usize,
}
379
380impl<'a> ConnectionOptions<'a> {
381 pub fn new() -> ConnectionOptions<'a> {
383 Default::default()
384 }
385
386 pub fn into_owned(self) -> ConnectionOptions<'static> {
388 ConnectionOptions {
389 address: self.address,
390 user: self.user.into_owned().into(),
391 password: self.password.into_owned().into(),
392 database: self.database.map(|v| v.into_owned().into()),
393 statement_case_size: self.statement_case_size,
394 unix_socket: self.unix_socket.map(|v| v.into_owned().into()),
395 }
396 }
397
398 pub fn from_url(url: &'a str) -> Result<Self, ConnectionError> {
408 let Some(v) = url.strip_prefix("mysql://") else {
409 return Err(ConnectionErrorContent::InvalidUrl.into());
410 };
411 let (authority, path) = v
412 .split_once('/')
413 .map(|(a, b)| (a, Some(b)))
414 .unwrap_or((v, None));
415 let (user_info, address) = authority
416 .split_once('@')
417 .map(|(a, b)| (Some(a), b))
418 .unwrap_or((None, authority));
419 let (user, password) = user_info
420 .map(|v| {
421 v.split_once(':')
422 .map(|(a, b)| (Some(a), Some(b)))
423 .unwrap_or((Some(v), None))
424 })
425 .unwrap_or_default();
426 let (host, port) = address
427 .rsplit_once(':')
428 .map(|(a, b)| (a, Some(b)))
429 .unwrap_or((address, None));
430 let port: u16 = match port {
431 Some(v) => v.parse().map_err(|_| ConnectionErrorContent::InvalidUrl)?,
432 None => 3306,
433 };
434 let (db, unix_socket) = path
435 .map(|v| {
436 v.split_once("?socket=")
437 .map(|(a, b)| (Some(a), Some(std::path::Path::new(b).into())))
438 .unwrap_or((Some(v), None))
439 })
440 .unwrap_or_default();
441
442 let mut addrs = (host, port).to_socket_addrs()?;
443 let Some(address) = addrs.next() else {
444 return Err(ConnectionErrorContent::InvalidUrl.into());
445 };
446
447 Ok(ConnectionOptions {
448 address,
449 user: user.unwrap_or("root").into(),
450 password: password.unwrap_or("password").into(),
451 database: db.map(|v| v.into()),
452 unix_socket,
453 ..Default::default()
454 })
455 }
456
457 pub fn user(self, user: impl Into<Cow<'a, str>>) -> Self {
459 Self {
460 user: user.into(),
461 ..self
462 }
463 }
464
465 pub fn password(self, password: impl Into<Cow<'a, str>>) -> Self {
467 Self {
468 password: password.into(),
469 ..self
470 }
471 }
472
473 pub fn database(self, database: impl Into<Cow<'a, str>>) -> Self {
475 Self {
476 database: Some(database.into()),
477 ..self
478 }
479 }
480
481 pub fn address(self, address: impl std::net::ToSocketAddrs) -> Result<Self, std::io::Error> {
483 match address.to_socket_addrs()?.next() {
484 Some(v) => Ok(Self { address: v, ..self }),
485 None => Err(std::io::Error::new(
486 std::io::ErrorKind::NotFound,
487 "No host resolved",
488 )),
489 }
490 }
491
492 pub fn unix_socket(self, path: impl Into<Cow<'a, std::path::Path>>) -> Self {
494 Self {
495 unix_socket: Some(path.into()),
496 ..self
497 }
498 }
499
500 pub fn statment_case_size(self, size: usize) -> Self {
502 Self {
503 statement_case_size: size,
504 ..self
505 }
506 }
507}
508
509impl<'a> Default for ConnectionOptions<'a> {
510 fn default() -> Self {
511 Self {
512 address: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 3306),
513 user: Cow::Borrowed("root"),
514 password: Cow::Borrowed("password"),
515 database: None,
516 statement_case_size: 1024,
517 unix_socket: None,
518 }
519 }
520}
521
/// Decoded column-definition packet describing one parameter or result column.
///
/// All string fields borrow from the packet bytes.
#[derive(Debug)]
#[non_exhaustive]
pub struct ColumnDefinition<'a> {
    /// Schema (database) name.
    pub schema: &'a str,
    /// Table name as written in the query (wire field `table`).
    pub table_alias: &'a str,
    /// Underlying table name (wire field `org_table`).
    pub table: &'a str,
    /// Column name as written in the query (wire field `name`).
    pub column_alias: &'a str,
    /// Underlying column name (wire field `org_name`).
    pub column: &'a str,
    /// Extended metadata; `ColumnDefinition::new` always sets this to `None`.
    pub extended_matadata: Option<&'a str>,
    /// Character set number from the packet.
    pub character_set_number: u16,
    /// Maximum column size from the packet.
    pub max_column_size: u32,
    /// Field type byte from the packet.
    pub field_types: u8,
    /// Field flags from the packet.
    pub field_detail_flags: u16,
    /// Decimals byte from the packet.
    pub decimals: u8,
}
549
550impl<'a> ColumnDefinition<'a> {
551 fn new(data: &'a [u8]) -> DecodeResult<Self> {
553 let mut p = PackageParser::new(data);
554 p.skip_lenenc_str()?; let schema = p.get_lenenc_str()?;
556 let table_alias = p.get_lenenc_str()?;
557 let table = p.get_lenenc_str()?;
558 let column_alias = p.get_lenenc_str()?;
559 let column = p.get_lenenc_str()?;
560 let extended_matadata = None;
561 p.get_lenenc()?;
562 let character_set_number = p.get_u16()?;
563 let max_column_size = p.get_u32()?;
564 let field_types = p.get_u8()?;
565 let field_detail_flags = p.get_u16()?;
566 let decimals = p.get_u8()?;
567 Ok(ColumnDefinition {
568 schema,
569 table_alias,
570 table,
571 column_alias,
572 column,
573 extended_matadata,
574 character_set_number,
575 max_column_size,
576 field_types,
577 field_detail_flags,
578 decimals,
579 })
580 }
581}
582
583pub struct ColumnsInformation<'a> {
585 data: &'a [u8],
587 ranges: &'a [Range<usize>],
589}
590
591impl<'a> ColumnsInformation<'a> {
592 pub fn get(&self, idx: usize) -> Option<DecodeResult<ColumnDefinition<'a>>> {
594 self.ranges
595 .get(idx)
596 .map(|v| ColumnDefinition::new(&self.data[v.clone()]))
597 }
598}
599
600impl<'a> Iterator for ColumnsInformation<'a> {
601 type Item = DecodeResult<ColumnDefinition<'a>>;
602
603 fn next(&mut self) -> Option<Self::Item> {
604 match self.ranges.split_off_first() {
605 Some(v) => Some(ColumnDefinition::new(&self.data[v.clone()])),
606 None => None,
607 }
608 }
609
610 fn size_hint(&self) -> (usize, Option<usize>) {
611 (self.ranges.len(), Some(self.ranges.len()))
612 }
613
614 fn nth(&mut self, n: usize) -> Option<Self::Item> {
615 self.get(n)
616 }
617}
618
619impl<'a> ExactSizeIterator for ColumnsInformation<'a> {
620 fn len(&self) -> usize {
621 self.ranges.len()
622 }
623}
624
625impl StatementInformation {
626 pub fn columns(&self) -> ColumnsInformation<'_> {
628 ColumnsInformation {
629 data: &self.info,
630 ranges: &self.ranges[self.num_params as usize..],
631 }
632 }
633
634 pub fn parameters(&self) -> ColumnsInformation<'_> {
636 ColumnsInformation {
637 data: &self.info,
638 ranges: &self.ranges[..self.num_params as usize],
639 }
640 }
641}
642
/// Parameter and column definition packets captured while preparing a statement.
pub struct StatementInformation {
    /// Raw bytes of all captured packets, concatenated.
    info: Vec<u8>,
    /// One byte range into `info` per packet: parameters first, then columns.
    ranges: Vec<Range<usize>>,
    /// Number of leading `ranges` entries that describe parameters.
    num_params: u16,
}

/// A statement prepared on the server.
struct Statement {
    /// Server-assigned statement id.
    stmt_id: u32,
    /// Number of parameters the statement takes.
    num_params: u16,
    /// Captured definitions, present only if requested when preparing.
    information: Option<StatementInformation>,
}
662
663pub struct QueryIterator<'a> {
665 connection: &'a mut RawConnection,
667}
668
669impl<'a> QueryIterator<'a> {
670 pub async fn next(&mut self) -> ConnectionResult<Option<Row<'_>>> {
674 match self.connection.state {
675 ConnectionState::Clean => return Ok(None),
676 ConnectionState::QueryReadRows => (),
677 _ => panic!("Logic error"),
678 }
679 self.connection.test_cancel()?;
681 let start_instant = self.connection.stats.get_instant();
682 let package = self.connection.reader.read().await?;
683 self.connection.stats.add_fetch(start_instant);
684 let mut pp = PackageParser::new(package);
685 match pp.get_u8().loc("Row first byte")? {
686 0x00 => Ok(Some(Row::new(&self.connection.columns, package))),
687 0xFE => {
688 self.connection.state = ConnectionState::Clean;
690 Ok(None)
691 }
692 0xFF => {
693 handle_mysql_error(&mut pp)?;
694 unreachable!()
695 }
696 v => Err(ConnectionErrorContent::ProtocolError(format!(
697 "Unexpected response type {v} to row package"
698 ))
699 .into()),
700 }
701 }
702}
703
704pub struct MapQueryIterator<'a, M>
706where
707 for<'b> M: RowMap<'b>,
708{
709 connection: &'a mut RawConnection,
711 _phantom: PhantomData<M>,
713}
714
715impl<'a, M> MapQueryIterator<'a, M>
716where
717 for<'b> M: RowMap<'b>,
718{
719 pub async fn next<'b>(
723 &'b mut self,
724 ) -> Result<Option<<M as RowMap<'b>>::T>, <M as RowMap<'b>>::E> {
725 match self.connection.state {
726 ConnectionState::Clean => return Ok(None),
727 ConnectionState::QueryReadRows => (),
728 _ => panic!("Logic error"),
729 }
730 self.connection.test_cancel()?;
732 let start_instant = self.connection.stats.get_instant();
733 let package = self.connection.reader.read().await?;
734 self.connection.stats.add_fetch(start_instant);
735 let mut pp = PackageParser::new(package);
736 match pp.get_u8().loc("Row first byte")? {
737 0x00 => Ok(Some(M::map(Row::new(&self.connection.columns, package))?)),
738 0xFE => {
739 self.connection.state = ConnectionState::Clean;
741 Ok(None)
742 }
743 0xFF => {
744 handle_mysql_error(&mut pp)?;
745 unreachable!()
746 }
747 v => Err(
748 ConnectionError::from(ConnectionErrorContent::ProtocolError(format!(
749 "Unexpected response type {v} to row package"
750 )))
751 .into(),
752 ),
753 }
754 }
755}
756
/// Outcome of a statement that returns no rows.
///
/// NOTE(review): presumably filled from the server's OK packet; the
/// construction site is not visible here.
pub struct ExecuteResult {
    last_insert_id: u64,
    affected_rows: u64,
}

impl ExecuteResult {
    /// Number of rows the statement affected.
    pub fn affected_rows(&self) -> u64 {
        let Self { affected_rows, .. } = self;
        *affected_rows
    }

    /// Last auto-increment id reported for the statement.
    pub fn last_insert_id(&self) -> u64 {
        let Self { last_insert_id, .. } = self;
        *last_insert_id
    }
}
776
777enum QueryResult {
779 WithColumns,
781 ExecuteResult(ExecuteResult),
783}
784
/// Where the connection currently stands in the request/response protocol.
///
/// Operations update this before each await point so that an interrupted
/// exchange can be completed by `RawConnection::cleanup`.
#[derive(Clone, Copy, Debug)]
enum ConnectionState {
    /// No exchange in progress.
    Clean,
    /// A COM_STMT_PREPARE packet is composed but not yet sent.
    PrepareStatementSend,
    /// Waiting for the prepare response header.
    PrepareStatementReadHead,
    /// Remaining parameter and column definition packets to read for `stmt_id`.
    PrepareStatementReadParams {
        params: u16,
        columns: u16,
        stmt_id: u32,
    },
    /// A COM_STMT_CLOSE packet is composed but not yet sent.
    ClosePreparedStatement,
    /// A query packet is composed but not yet sent.
    QuerySend,
    /// Waiting for the query response header (OK, ERR or column count).
    QueryReadHead,
    /// Number of column definition packets still to read.
    QueryReadColumns(u64),
    /// Reading row packets until the end-of-rows packet.
    QueryReadRows,
    /// A text command is composed but not yet sent.
    UnpreparedSend,
    /// Waiting for the OK/ERR reply to a text command.
    UnpreparedRecv,
    /// A ping packet is composed but not yet sent.
    PingSend,
    /// Waiting for the ping reply.
    PingRecv,
    /// The protocol stream is in an unknown state; the connection is unusable.
    Broken,
}
824
/// Timing hooks.
///
/// `Stats` records measurements when the `stats` feature is enabled;
/// `NoStats` does nothing otherwise.
trait IStats {
    /// Opaque start marker produced by `get_instant` and passed back to `add_*`.
    type Instant: Sized;

    /// Captures a start marker.
    fn get_instant(&self) -> Self::Instant;

    /// Records a row fetch that began at `start_instant`.
    fn add_fetch(&mut self, start_instant: Self::Instant);

    /// Records an execution that began at `start_instant`.
    fn add_execute(&mut self, start_instant: Self::Instant);

    /// Records a prepare that began at `start_instant`.
    fn add_prepare(&mut self, start_instant: Self::Instant);
}
838
839#[allow(unused)]
841#[derive(Default, Debug)]
842pub struct Stats {
843 pub prepare_counts: usize,
845 pub prepare_time: std::time::Duration,
847 pub execute_counts: usize,
849 pub execute_time: std::time::Duration,
851 pub fetch_time: std::time::Duration,
853}
854
855impl IStats for Stats {
856 type Instant = std::time::Instant;
857
858 fn get_instant(&self) -> Self::Instant {
859 std::time::Instant::now()
860 }
861
862 fn add_prepare(&mut self, start_instant: Self::Instant) {
863 self.prepare_counts += 1;
864 self.prepare_time += start_instant.elapsed();
865 }
866
867 fn add_execute(&mut self, start_instant: Self::Instant) {
868 self.execute_counts += 1;
869 self.execute_time += start_instant.elapsed()
870 }
871
872 fn add_fetch(&mut self, start_instant: Self::Instant) {
873 self.fetch_time += start_instant.elapsed();
874 }
875}
876#[allow(unused)]
878#[derive(Default)]
879struct NoStats;
880
881impl IStats for NoStats {
882 type Instant = NoStats;
883
884 fn get_instant(&self) -> Self::Instant {
885 NoStats
886 }
887
888 fn add_prepare(&mut self, _: Self::Instant) {}
889
890 fn add_execute(&mut self, _: Self::Instant) {}
891
892 fn add_fetch(&mut self, _: Self::Instant) {}
893}
894
895struct RawConnection {
898 reader: Reader,
900 writer: Writer,
902 state: ConnectionState,
904 columns: Vec<Column>,
906 ranges: Vec<Range<usize>>,
908 #[cfg(feature = "cancel_testing")]
909 cancel_count: Option<usize>,
911 #[cfg(feature = "stats")]
912 stats: Stats,
914 #[cfg(not(feature = "stats"))]
915 stats: NoStats,
917 #[cfg(feature = "list_hack")]
918 list_lengths: Vec<usize>,
920}
921
922fn parse_column_definition(p: &mut PackageParser) -> ConnectionResult<Column> {
924 p.skip_lenenc_str().loc("catalog")?;
925 p.skip_lenenc_str().loc("schema")?;
926 p.skip_lenenc_str().loc("table")?;
927 p.skip_lenenc_str().loc("org_table")?;
928 p.skip_lenenc_str().loc("name")?;
929 p.skip_lenenc_str().loc("org_name")?;
930 p.get_lenenc().loc("length of fixed length fields")?;
931 let character_set = p.get_u16().loc("character_set")?;
932 p.get_u32().loc("column_length")?;
933 let r#type = p.get_u8().loc("type")?;
934 let flags = p.get_u16().loc("flags")?;
935 p.get_u8().loc("decimals")?;
936 p.get_u16().loc("res")?;
937 Ok(Column {
938 r#type,
939 flags,
940 character_set,
941 })
942}
943
944fn handle_mysql_error(pp: &mut PackageParser) -> ConnectionResult<std::convert::Infallible> {
946 let code = pp.get_u16().loc("code")?;
948 pp.get_u8().ev("sharp", b'#')?;
949 let a = pp.get_u8().loc("status0")?;
950 let b = pp.get_u8().loc("status1")?;
951 let c = pp.get_u8().loc("status2")?;
952 let d = pp.get_u8().loc("status3")?;
953 let e = pp.get_u8().loc("status4")?;
954 let msg = pp.get_eof_str().loc("message")?;
955 Err(ConnectionErrorContent::Mysql {
956 code,
957 status: [a, b, c, d, e],
958 message: msg.to_string(),
959 }
960 .into())
961}
962
/// SQL that opens a transaction at nesting `depth`.
///
/// Depth 0 returns `BEGIN`; deeper levels return a numbered savepoint. Depths
/// 0-3 use static strings and avoid an allocation.
fn begin_transaction_query(depth: usize) -> Cow<'static, str> {
    static SHALLOW: [&str; 4] = [
        "BEGIN",
        "SAVEPOINT _sqly_savepoint_1",
        "SAVEPOINT _sqly_savepoint_2",
        "SAVEPOINT _sqly_savepoint_3",
    ];
    match SHALLOW.get(depth) {
        Some(query) => Cow::Borrowed(*query),
        None => Cow::Owned(format!("SAVEPOINT _sqly_savepoint_{}", depth)),
    }
}
973
/// SQL that commits the transaction at nesting `depth`.
///
/// Depth 0 returns `COMMIT`; deeper levels release the matching savepoint.
fn commit_transaction_query(depth: usize) -> Cow<'static, str> {
    if depth == 0 {
        return Cow::Borrowed("COMMIT");
    }
    static NESTED: [&str; 3] = [
        "RELEASE SAVEPOINT _sqly_savepoint_1",
        "RELEASE SAVEPOINT _sqly_savepoint_2",
        "RELEASE SAVEPOINT _sqly_savepoint_3",
    ];
    match NESTED.get(depth - 1) {
        Some(query) => Cow::Borrowed(*query),
        None => Cow::Owned(format!("RELEASE SAVEPOINT _sqly_savepoint_{}", depth)),
    }
}
984
/// SQL that rolls back the transaction at nesting `depth`.
///
/// Depth 0 returns `ROLLBACK`; deeper levels roll back to the matching
/// savepoint. Depths 0-3 use static strings.
///
/// Depths of 4 and above previously produced the invalid statement
/// `RELEASE TO SAVEPOINT …`. They now use `ROLLBACK TO SAVEPOINT …` like the
/// shallow levels.
fn rollback_transaction_query(depth: usize) -> Cow<'static, str> {
    match depth {
        0 => "ROLLBACK".into(),
        1 => "ROLLBACK TO SAVEPOINT _sqly_savepoint_1".into(),
        2 => "ROLLBACK TO SAVEPOINT _sqly_savepoint_2".into(),
        3 => "ROLLBACK TO SAVEPOINT _sqly_savepoint_3".into(),
        v => format!("ROLLBACK TO SAVEPOINT _sqly_savepoint_{}", v).into(),
    }
}
995
996impl RawConnection {
997 async fn connect(options: &ConnectionOptions<'_>) -> ConnectionResult<Self> {
1001 let (read, write) = if let Some(path) = &options.unix_socket {
1005 let socket = UnixStream::connect(path).await?;
1006 let (read, write) = socket.into_split();
1007 (OwnedReadHalf::Unix(read), OwnedWriteHalf::Unix(write))
1008 } else {
1009 let stream = if options.address.is_ipv4() {
1010 let socket = TcpSocket::new_v4()?;
1011 socket.connect(options.address).await?
1012 } else {
1013 let socket = TcpSocket::new_v6()?;
1014 socket.connect(options.address).await?
1015 };
1016 let (read, write) = stream.into_split();
1017 (OwnedReadHalf::Tcp(read), OwnedWriteHalf::Tcp(write))
1018 };
1019
1020 let mut reader = Reader::new(read);
1021 let mut writer = Writer::new(write);
1022
1023 let package = reader.read().await?;
1025 let mut p = PackageParser::new(package);
1026 p.get_u8().ev("protocol version", 10)?;
1027 p.skip_null_str().loc("status")?;
1028 let _wthread_id = p.get_u32().loc("thread_id")?;
1029 let nonce1 = p.get_bytes(8).loc("nonce1")?;
1030 p.get_u8().ev("nonce1_end", 0)?;
1031 let capability_flags_1 = p.get_u16().loc("capability_flags_1")?;
1032 let _character_set = p.get_u8().loc("character_set")?;
1033 p.get_u16().loc("status_flags")?;
1034 let capability_flags_2 = p.get_u16().loc("capability_flags_2")?;
1035 let auth_plugin_data_len = p.get_u8().loc("auth_plugin_data_len")?;
1036 let _capability_flags = capability_flags_1 as u32 | (capability_flags_2 as u32) << 16;
1037 p.get_bytes(10).loc("reserved")?;
1038 let nonce2 = p
1039 .get_bytes(auth_plugin_data_len as usize - 9)
1040 .loc("nonce2")?;
1041 p.get_u8().ev("nonce2_end", 0)?;
1042
1043 let auth_plugin = p.get_null_str().loc("auth_plugin")?;
1044 let auth_method = match auth_plugin {
1045 "mysql_native_password" => AuthPlugin::NativePassword,
1046 #[cfg(feature = "sha2_auth")]
1047 "caching_sha2_password" => AuthPlugin::CachingSha2Password,
1048 v => {
1049 return Err(ConnectionErrorContent::ProtocolError(format!(
1050 "Unhandled auth plugin {v}"
1051 ))
1052 .into());
1053 }
1054 };
1055
1056 let mut p = writer.compose();
1058 let mut opts = client::LONG_PASSWORD
1059 | client::LONG_FLAG
1060 | client::LOCAL_FILES
1061 | client::PROTOCOL_41
1062 | client::DEPRECATE_EOF
1063 | client::TRANSACTIONS
1064 | client::SECURE_CONNECTION
1065 | client::MULTI_STATEMENTS
1066 | client::MULTI_RESULTS
1067 | client::PS_MULTI_RESULTS
1068 | client::PLUGIN_AUTH;
1069 if options.database.is_some() {
1070 opts |= client::CONNECT_WITH_DB
1071 }
1072 p.put_u32(opts);
1073 p.put_u32(0x1000000); p.put_u16(45); for _ in 0..22 {
1076 p.put_u8(0);
1077 }
1078 p.put_str_null(&options.user);
1079
1080 let mut nonce = Vec::with_capacity(nonce1.len() + nonce2.len());
1081 nonce.extend_from_slice(nonce1);
1082 nonce.extend_from_slice(nonce2);
1083
1084 let auth = compute_auth(&options.password, &nonce, auth_method);
1085 let auth = auth.as_slice();
1086 p.put_u8(auth.len() as u8);
1087 for v in auth {
1088 p.put_u8(*v);
1089 }
1090 if let Some(database) = &options.database {
1091 p.put_str_null(database);
1092 }
1093 p.put_str_null(auth_plugin);
1095 p.finalize();
1096
1097 writer.send().await?;
1098
1099 loop {
1100 let p = reader.read().await?;
1101 let mut pp = PackageParser::new(p);
1102 match pp.get_u8().loc("response type")? {
1103 0xFF => {
1104 handle_mysql_error(&mut pp)?;
1105 }
1106 0x00 => {
1107 let _rows = pp.get_lenenc().loc("rows")?;
1108 let _last_inserted_id = pp.get_lenenc().loc("last_inserted_id")?;
1109 break;
1110 }
1111 0xFE => {
1112 return Err(ConnectionErrorContent::ProtocolError(
1113 "Unexpected auth switch".into(),
1114 )
1115 .into());
1116 }
1117 #[cfg(feature = "sha2_auth")]
1118 0x01 if matches!(auth_method, AuthPlugin::CachingSha2Password) => {
1119 match pp.get_u8().loc("auth_status")? {
1120 0x03 => break,
1122 0x04 => {
1124 writer.seq = 3;
1126 let mut p = writer.compose();
1127 p.put_u8(0x02);
1128 p.finalize();
1129 writer.send().await?;
1130
1131 let p = reader.read().await?;
1132 let mut pp = PackageParser::new(p);
1133 pp.get_u8().ev("first", 1)?;
1134 let pem = pp.get_eof_str().loc("pem")?;
1135
1136 let pwd = crate::auth::encrypt_rsa(pem, &options.password, &nonce)?;
1137
1138 writer.seq = 5;
1139 let mut p = writer.compose();
1140 p.put_bytes(&pwd);
1141 p.finalize();
1142 writer.send().await?;
1143 }
1144 v => {
1145 return Err(ConnectionErrorContent::ProtocolError(format!(
1146 "Unexpected auth status {v} to handshake response"
1147 ))
1148 .into());
1149 }
1150 }
1151 }
1152 v => {
1153 return Err(ConnectionErrorContent::ProtocolError(format!(
1154 "Unexpected response type {v} to handshake response"
1155 ))
1156 .into());
1157 }
1158 }
1159 }
1160 writer.seq = 0;
1161 Ok(RawConnection {
1162 reader,
1163 writer,
1164 state: ConnectionState::Clean,
1165 columns: Vec::new(),
1166 ranges: Vec::new(),
1167 #[cfg(feature = "cancel_testing")]
1168 cancel_count: None,
1169 stats: Default::default(),
1170 #[cfg(feature = "list_hack")]
1171 list_lengths: Vec::new(),
1172 })
1173 }
1174
1175 #[inline]
1181 fn test_cancel(&mut self) -> ConnectionResult<()> {
1182 #[cfg(feature = "cancel_testing")]
1183 if let Some(v) = &mut self.cancel_count {
1184 if *v == 0 {
1185 return Err(ConnectionErrorContent::TestCancelled.into());
1186 }
1187 *v -= 1;
1188 }
1189 Ok(())
1190 }
1191
    /// Finishes whatever exchange `self.state` records, returning the connection to
    /// `ConnectionState::Clean`.
    ///
    /// Each loop iteration performs one step for the current state:
    /// * sends the composed packet;
    /// * reads and discards one packet;
    /// * or queues a COM_STMT_CLOSE.
    ///
    /// It then records the next state, so that if this future is itself
    /// interrupted, a later call resumes where it stopped.
    ///
    /// # Errors
    /// I/O and decode errors. Server errors or unexpected packets in the row,
    /// unprepared and ping states, which also mark the connection `Broken`.
    /// A protocol error if the connection is already `Broken`.
    async fn cleanup(&mut self) -> ConnectionResult<()> {
        loop {
            match self.state {
                ConnectionState::Clean => break,
                ConnectionState::PrepareStatementSend => {
                    self.test_cancel()?;
                    self.writer.send().await?;
                    self.state = ConnectionState::PrepareStatementReadHead;
                    continue;
                }
                ConnectionState::PrepareStatementReadHead => {
                    self.test_cancel()?;
                    let package = self.reader.read().await?;
                    let mut p = PackageParser::new(package);
                    match p.get_u8().loc("response type")? {
                        // Prepare OK: the definition packets still have to be drained.
                        0 => {
                            let stmt_id = p.get_u32().loc("stmt_id")?;
                            let columns = p.get_u16().loc("num_columns")?;
                            let params = p.get_u16().loc("num_params")?;
                            self.state = ConnectionState::PrepareStatementReadParams {
                                params,
                                columns,
                                stmt_id,
                            };
                            continue;
                        }
                        // ERR: nothing else follows, so the connection is clean again.
                        255 => {
                            self.state = ConnectionState::Clean;
                        }
                        v => {
                            self.state = ConnectionState::Broken;
                            return Err(ConnectionErrorContent::ProtocolError(format!(
                                "Unexpected response type {v} to prepare statement"
                            ))
                            .into());
                        }
                    }
                }
                // All definition packets consumed. The statement exists on the
                // server but the prepare was abandoned (presumably), so compose a
                // COM_STMT_CLOSE for it; the next state sends it.
                ConnectionState::PrepareStatementReadParams {
                    params: 0,
                    columns: 0,
                    stmt_id,
                } => {
                    self.writer.seq = 0;
                    let mut p = self.writer.compose();
                    p.put_u8(com::STMT_CLOSE);
                    p.put_u32(stmt_id);
                    p.finalize();
                    self.state = ConnectionState::ClosePreparedStatement;
                }
                // Parameters done: discard one column definition.
                ConnectionState::PrepareStatementReadParams {
                    params: 0,
                    columns,
                    stmt_id,
                } => {
                    self.test_cancel()?;
                    self.reader.read().await?;
                    self.state = ConnectionState::PrepareStatementReadParams {
                        params: 0,
                        columns: columns - 1,
                        stmt_id,
                    };
                }
                // Discard one parameter definition.
                ConnectionState::PrepareStatementReadParams {
                    params,
                    columns,
                    stmt_id,
                } => {
                    self.test_cancel()?;
                    self.reader.read().await?;
                    self.state = ConnectionState::PrepareStatementReadParams {
                        params: params - 1,
                        columns,
                        stmt_id,
                    };
                }
                ConnectionState::ClosePreparedStatement => {
                    self.test_cancel()?;
                    self.writer.send().await?;
                    self.state = ConnectionState::Clean;
                }
                ConnectionState::QuerySend => {
                    self.test_cancel()?;
                    self.writer.send().await?;
                    self.state = ConnectionState::QueryReadHead;
                }
                ConnectionState::QueryReadHead => {
                    self.test_cancel()?;
                    let package = self.reader.read().await?;
                    {
                        let mut pp = PackageParser::new(package);
                        // OK or ERR ends the response; anything else starts a result set.
                        match pp.get_u8().loc("first_byte")? {
                            255 | 0 => {
                                self.state = ConnectionState::Clean;
                                continue;
                            }
                            _ => (),
                        }
                    }
                    // Re-parse from the start: the first byte belongs to the
                    // length-encoded column count.
                    let column_count = PackageParser::new(package)
                        .get_lenenc()
                        .loc("column_count")?;
                    self.state = ConnectionState::QueryReadColumns(column_count)
                }
                ConnectionState::QueryReadColumns(0) => {
                    self.state = ConnectionState::QueryReadRows;
                }
                ConnectionState::QueryReadColumns(cnt) => {
                    self.test_cancel()?;
                    self.reader.read().await?;
                    self.state = ConnectionState::QueryReadColumns(cnt - 1);
                }
                ConnectionState::QueryReadRows => {
                    self.test_cancel()?;
                    let package = self.reader.read().await?;
                    let mut pp = PackageParser::new(package);
                    match pp.get_u8().loc("Row first byte")? {
                        // A row: discard it and keep reading.
                        0x00 => (),
                        // End of rows.
                        0xFE => {
                            self.state = ConnectionState::Clean;
                        }
                        0xFF => {
                            self.state = ConnectionState::Broken;
                            handle_mysql_error(&mut pp)?;
                            unreachable!()
                        }
                        v => {
                            self.state = ConnectionState::Broken;
                            return Err(ConnectionErrorContent::ProtocolError(format!(
                                "Unexpected response type {v} to row package"
                            ))
                            .into());
                        }
                    }
                }
                // NOTE(review): this moves to QueryReadHead rather than UnpreparedRecv
                // (compare PingSend -> PingRecv). QueryReadHead also accepts an OK/ERR
                // reply, but it treats ERR as Clean where UnpreparedRecv marks the
                // connection Broken. Confirm this is intended.
                ConnectionState::UnpreparedSend => {
                    self.test_cancel()?;
                    self.writer.send().await?;
                    self.state = ConnectionState::QueryReadHead;
                }
                ConnectionState::UnpreparedRecv => {
                    self.test_cancel()?;
                    let package = self.reader.read().await?;
                    let mut pp = PackageParser::new(package);
                    match pp.get_u8().loc("first_byte")? {
                        255 => {
                            self.state = ConnectionState::Broken;
                            handle_mysql_error(&mut pp)?;
                            unreachable!()
                        }
                        0 => {
                            self.state = ConnectionState::Clean;
                            return Ok(());
                        }
                        v => {
                            self.state = ConnectionState::Broken;
                            return Err(ConnectionErrorContent::ProtocolError(format!(
                                "Unexpected response type {v} to row package"
                            ))
                            .into());
                        }
                    }
                }
                ConnectionState::PingSend => {
                    self.test_cancel()?;
                    self.writer.send().await?;
                    self.state = ConnectionState::PingRecv;
                }
                ConnectionState::PingRecv => {
                    self.test_cancel()?;
                    let package = self.reader.read().await?;
                    let mut pp = PackageParser::new(package);
                    match pp.get_u8().loc("first_byte")? {
                        255 => {
                            self.state = ConnectionState::Broken;
                            handle_mysql_error(&mut pp)?;
                            unreachable!()
                        }
                        0 => {
                            self.state = ConnectionState::Clean;
                            return Ok(());
                        }
                        v => {
                            self.state = ConnectionState::Broken;
                            return Err(ConnectionErrorContent::ProtocolError(format!(
                                "Unexpected response type {v} to ping"
                            ))
                            .into());
                        }
                    }
                }
                ConnectionState::Broken => {
                    return Err(ConnectionErrorContent::ProtocolError(
                        "Previous protocol error reported".to_string(),
                    )
                    .into());
                }
            }
        }
        Ok(())
    }
1395
1396 async fn prepare_query(&mut self, stmt: &str, with_info: bool) -> ConnectionResult<Statement> {
1400 assert!(matches!(self.state, ConnectionState::Clean));
1401 self.writer.seq = 0;
1402 let mut p = self.writer.compose();
1403 p.put_u8(com::STMT_PREPARE);
1404 p.put_bytes(stmt.as_bytes());
1405 p.finalize();
1406
1407 let start_instant = self.stats.get_instant();
1408 self.state = ConnectionState::PrepareStatementSend;
1409 self.test_cancel()?;
1411 self.writer.send().await?;
1412
1413 self.state = ConnectionState::PrepareStatementReadHead;
1414 self.test_cancel()?;
1416 let package = self.reader.read().await?;
1417
1418 let mut p = PackageParser::new(package);
1419 match p.get_u8().loc("response type")? {
1420 0 => {
1421 let stmt_id = p.get_u32().loc("stmt_id")?;
1422 let num_columns = p.get_u16().loc("num_columns")?;
1423 let num_params = p.get_u16().loc("num_params")?;
1424 let mut info_bytes: Vec<_> = Vec::new();
1427 let mut info_ranges = Vec::new();
1428
1429 for p in 0..num_params {
1431 self.state = ConnectionState::PrepareStatementReadParams {
1432 params: num_params - p,
1433 columns: num_columns,
1434 stmt_id,
1435 };
1436 self.test_cancel()?;
1438 let pkg = self.reader.read().await?;
1439 if with_info {
1440 let start = info_bytes.len();
1441 info_bytes.extend(pkg);
1442 info_ranges.push(start..info_bytes.len())
1443 }
1444 }
1445
1446 for c in 0..num_columns {
1448 self.state = ConnectionState::PrepareStatementReadParams {
1449 params: 0,
1450 columns: num_columns - c,
1451 stmt_id,
1452 };
1453 self.test_cancel()?;
1455 let pkg = self.reader.read().await?;
1456 if with_info {
1457 let start = info_bytes.len();
1458 info_bytes.extend(pkg);
1459 info_ranges.push(start..info_bytes.len())
1460 }
1461 }
1462
1463 let information = if with_info {
1464 Some(StatementInformation {
1465 num_params,
1466 info: info_bytes,
1467 ranges: info_ranges,
1468 })
1469 } else {
1470 None
1471 };
1472
1473 self.state = ConnectionState::Clean;
1474 self.stats.add_prepare(start_instant);
1475 Ok(Statement {
1476 stmt_id,
1477 num_params,
1478 information,
1479 })
1480 }
1481 255 => {
1482 handle_mysql_error(&mut p)?;
1483 unreachable!()
1484 }
1485 v => {
1486 self.state = ConnectionState::Broken;
1487 Err(ConnectionErrorContent::ProtocolError(format!(
1488 "Unexpected response type {v} to prepare statement"
1489 ))
1490 .into())
1491 }
1492 }
1493 }
1494
    /// Start composing a `COM_STMT_EXECUTE` package for `statement`.
    ///
    /// The package is written into `self.writer` but NOT finalized or sent;
    /// `query_send` finalizes it once all parameters have been bound through
    /// the returned [`Query`].
    ///
    /// # Panics
    /// If the connection is not in the `Clean` state.
    fn query<'a>(&'a mut self, statement: &'a Statement) -> Query<'a> {
        assert!(matches!(self.state, ConnectionState::Clean));

        self.writer.seq = 0;
        let mut p = self.writer.compose();
        p.put_u8(com::STMT_EXECUTE);
        p.put_u32(statement.stmt_id);
        // A 0 byte and the u32 value 1, presumably the flags byte and the
        // iteration count of COM_STMT_EXECUTE (verify against the protocol docs).
        // The NULL bitmap starts right after them.
        p.put_u8(0); p.put_u32(1); let null_offset = p.writer.buff.len();
        let mut type_offset = null_offset;
        if statement.num_params != 0 {
            // Zeroed NULL bitmap with one bit per parameter; `Query::bind`
            // sets the bits of parameters that bind as NULL.
            let null_bytes = statement.num_params.div_ceil(8);
            for _ in 0..null_bytes {
                p.put_u8(0);
            }
            // A literal 1 byte follows the bitmap; the two-byte-per-parameter
            // type slots start after it.
            p.put_u8(1); type_offset = p.writer.buff.len();
            // Placeholder type slots, patched by `Query::bind` (type code in
            // the first byte, 128 in the second for unsigned types).
            for _ in 0..statement.num_params {
                p.put_u16(0);
            }
        }

        Query {
            connection: self,
            statement,
            cur_param: 0,
            null_offset,
            type_offset,
        }
    }
1531
1532 async fn query_send(&mut self) -> ConnectionResult<QueryResult> {
1534 let p = Composer {
1535 writer: &mut self.writer,
1536 };
1537 p.finalize();
1538
1539 let start_instant = self.stats.get_instant();
1540 self.state = ConnectionState::QuerySend;
1541 self.test_cancel()?;
1543 self.writer.send().await?;
1544
1545 self.state = ConnectionState::QueryReadHead;
1546 self.test_cancel()?;
1548 let package = self.reader.read().await?;
1549 {
1550 let mut pp = PackageParser::new(package);
1551 match pp.get_u8().loc("first_byte")? {
1552 255 => {
1553 handle_mysql_error(&mut pp)?;
1554 }
1555 0 => {
1556 self.stats.add_execute(start_instant);
1557 self.state = ConnectionState::Clean;
1558 let affected_rows = pp.get_lenenc().loc("affected_rows")?;
1559 let last_insert_id = pp.get_lenenc().loc("last_insert_id")?;
1560 return Ok(QueryResult::ExecuteResult(ExecuteResult {
1561 affected_rows,
1562 last_insert_id,
1563 }));
1564 }
1565 _ => (),
1566 }
1567 }
1568
1569 let column_count = PackageParser::new(package)
1570 .get_lenenc()
1571 .loc("column_count")?;
1572
1573 self.columns.clear();
1574
1575 for c in 0..column_count {
1577 self.state = ConnectionState::QueryReadColumns(column_count - c);
1578 self.test_cancel()?;
1580 let package = self.reader.read().await?;
1581 let mut p = PackageParser::new(package);
1582 self.columns.push(parse_column_definition(&mut p)?);
1583 }
1584 self.stats.add_execute(start_instant);
1585
1586 self.state = ConnectionState::QueryReadRows;
1587 Ok(QueryResult::WithColumns)
1588 }
1589
1590 fn execute_unprepared(
1595 &mut self,
1596 escaped_sql: Cow<'_, str>,
1597 ) -> impl Future<Output = ConnectionResult<ExecuteResult>> + Send {
1598 assert!(matches!(self.state, ConnectionState::Clean));
1599 self.writer.seq = 0;
1600 let mut p = self.writer.compose();
1601 p.put_u8(com::QUERY);
1602 p.put_bytes(escaped_sql.as_bytes());
1603 p.finalize();
1604
1605 self.state = ConnectionState::UnpreparedSend;
1606
1607 async move {
1608 let start_time = self.stats.get_instant();
1609 self.test_cancel()?;
1610 self.writer.send().await?;
1611
1612 self.state = ConnectionState::UnpreparedRecv;
1613 self.test_cancel()?;
1614 let package = self.reader.read().await?;
1615 {
1616 let mut pp = PackageParser::new(package);
1617 match pp.get_u8().loc("first_byte")? {
1618 255 => {
1619 handle_mysql_error(&mut pp)?;
1620 unreachable!()
1621 }
1622 0 => {
1623 self.stats.add_execute(start_time);
1624 self.state = ConnectionState::Clean;
1625 let affected_rows = pp.get_lenenc().loc("affected_rows")?;
1626 let last_insert_id = pp.get_lenenc().loc("last_insert_id")?;
1627 Ok(ExecuteResult {
1628 affected_rows,
1629 last_insert_id,
1630 })
1631 }
1632 v => {
1633 self.state = ConnectionState::Broken;
1634 Err(ConnectionErrorContent::ProtocolError(format!(
1635 "Unexpected response type {v} to row package"
1636 ))
1637 .into())
1638 }
1639 }
1640 }
1641 }
1642 }
1643
1644 fn close_prepared_statement(
1646 &mut self,
1647 id: u32,
1648 ) -> impl Future<Output = ConnectionResult<()>> + Send {
1649 assert!(matches!(self.state, ConnectionState::Clean));
1650 self.writer.seq = 0;
1651 let mut p = self.writer.compose();
1652 p.put_u8(com::STMT_CLOSE);
1653 p.put_u32(id);
1654 p.finalize();
1655 self.state = ConnectionState::ClosePreparedStatement;
1656 async move {
1657 let start_time = self.stats.get_instant();
1658 self.test_cancel()?;
1659 self.writer.send().await?;
1660 self.stats.add_prepare(start_time);
1661 self.state = ConnectionState::Clean;
1662 Ok(())
1663 }
1664 }
1665
1666 fn ping(&mut self) -> impl Future<Output = ConnectionResult<()>> + Send {
1668 assert!(matches!(self.state, ConnectionState::Clean));
1669 self.writer.seq = 0;
1670 let mut p = self.writer.compose();
1671 p.put_u8(com::PING);
1672 p.finalize();
1673 self.state = ConnectionState::PingSend;
1674
1675 async move {
1676 self.test_cancel()?;
1677 self.writer.send().await?;
1678
1679 self.state = ConnectionState::PingRecv;
1680 self.test_cancel()?;
1681 let package = self.reader.read().await?;
1682 {
1683 let mut pp = PackageParser::new(package);
1684 match pp.get_u8().loc("first_byte")? {
1685 255 => {
1686 handle_mysql_error(&mut pp)?;
1687 unreachable!()
1688 }
1689 0 => {
1690 self.state = ConnectionState::Clean;
1691 Ok(())
1692 }
1693 v => {
1694 self.state = ConnectionState::Broken;
1695 Err(ConnectionErrorContent::ProtocolError(format!(
1696 "Unexpected response type {v} to ping"
1697 ))
1698 .into())
1699 }
1700 }
1701 }
1702 }
1703 }
1704}
1705
/// A MySQL connection with a prepared statement cache and transaction
/// bookkeeping on top of the raw protocol connection.
pub struct Connection {
    /// Cached prepared statements, keyed by statement text (see `query_inner`).
    prepared_statements: LRUCache<Statement>,

    /// An uncached statement prepared for the current query; closed by the
    /// next `cleanup`.
    prepared_statement: Option<Statement>,
    /// The underlying protocol connection.
    raw: RawConnection,
    /// Number of currently open (possibly nested) transactions.
    transaction_depth: usize,
    /// Number of dropped transactions whose rollback `cleanup` must still issue.
    cleanup_rollbacks: usize,
}
1722
/// A prepared statement execution being composed; bind parameters with
/// [`Query::bind`], then run it with one of the fetch/execute methods.
pub struct Query<'a> {
    /// Connection whose writer holds the pending `COM_STMT_EXECUTE` package.
    connection: &'a mut RawConnection,
    /// The statement being executed.
    statement: &'a Statement,
    /// Number of parameters bound so far.
    cur_param: u16,
    /// Offset of the NULL bitmap inside the writer buffer.
    null_offset: usize,
    /// Offset of the first parameter type slot inside the writer buffer.
    type_offset: usize,
}
1736
1737impl<'a> Query<'a> {
1738 #[inline]
1740 pub fn information(&self) -> Option<&StatementInformation> {
1741 self.statement.information.as_ref()
1742 }
1743
1744 #[inline]
1746 pub fn bind<T: Bind + ?Sized>(mut self, v: &T) -> ConnectionResult<Self> {
1747 if self.cur_param == self.statement.num_params {
1748 return Err(ConnectionErrorContent::Bind(
1749 self.cur_param,
1750 BindError::TooManyArgumentsBound,
1751 )
1752 .into());
1753 }
1754 let mut w = crate::bind::Writer::new(&mut self.connection.writer.buff);
1755 if !v
1756 .bind(&mut w)
1757 .map_err(|e| ConnectionErrorContent::Bind(self.cur_param, e))?
1758 {
1759 let w = self.cur_param / 8;
1760 let b = self.cur_param % 8;
1761 self.connection.writer.buff[self.null_offset + w as usize] |= 1 << b;
1762 }
1763
1764 self.connection.writer.buff[self.type_offset + (self.cur_param * 2) as usize] = T::TYPE;
1765 if T::UNSIGNED {
1766 self.connection.writer.buff[self.type_offset + (self.cur_param * 2) as usize + 1] = 128;
1767 }
1768 self.cur_param += 1;
1769 Ok(self)
1770 }
1771
1772 pub async fn fetch_optional_map<M: RowMap<'a>>(self) -> Result<Option<M::T>, M::E> {
1778 if self.cur_param != self.statement.num_params {
1779 return Err(ConnectionError::from(ConnectionErrorContent::Bind(
1780 self.cur_param,
1781 BindError::TooFewArgumentsBound,
1782 ))
1783 .into());
1784 }
1785 match self.connection.query_send().await? {
1786 QueryResult::WithColumns => (),
1787 QueryResult::ExecuteResult(_) => {
1788 return Err(ConnectionError::from(ConnectionErrorContent::ExpectedRows).into());
1789 }
1790 }
1791
1792 let start_instant = self.connection.stats.get_instant();
1793 self.connection.test_cancel()?;
1795 let p1 = self.connection.reader.read_raw().await?;
1796 {
1797 let mut pp = PackageParser::new(self.connection.reader.bytes(p1.clone()));
1798 match pp.get_u8().loc("Row first byte")? {
1799 0x00 => (),
1800 0xFE => {
1801 self.connection.state = ConnectionState::Clean;
1803 return Ok(None);
1804 }
1805 0xFF => {
1806 handle_mysql_error(&mut pp)?;
1807 unreachable!()
1808 }
1809 v => {
1810 return Err(ConnectionError::from(ConnectionErrorContent::ProtocolError(
1811 format!("Unexpected response type {v} to row package"),
1812 ))
1813 .into());
1814 }
1815 }
1816 }
1817
1818 self.connection.reader.buffer_packages = true;
1820
1821 self.connection.test_cancel()?;
1823 let p2 = self.connection.reader.read_raw().await?;
1824 {
1825 let mut pp = PackageParser::new(self.connection.reader.bytes(p2));
1826 match pp.get_u8().loc("Row first byte")? {
1827 0x00 => {
1828 return Err(
1829 ConnectionError::from(ConnectionErrorContent::UnexpectedRows).into(),
1830 );
1831 }
1832 0xFE => {
1833 self.connection.state = ConnectionState::Clean;
1834 }
1835 0xFF => {
1836 handle_mysql_error(&mut pp)?;
1837 unreachable!()
1838 }
1839 v => {
1840 return Err(ConnectionError::from(ConnectionErrorContent::ProtocolError(
1841 format!("Unexpected response type {v} to row package"),
1842 ))
1843 .into());
1844 }
1845 }
1846 }
1847
1848 self.connection.stats.add_fetch(start_instant);
1849 let row = Row::new(&self.connection.columns, self.connection.reader.bytes(p1));
1850 Ok(Some(M::map(row)?))
1851 }
1852
1853 pub fn fetch_optional<T: FromRow<'a>>(
1859 self,
1860 ) -> impl Future<Output = ConnectionResult<Option<T>>> + Send {
1861 self.fetch_optional_map::<FromRowMapper<T>>()
1862 }
1863
1864 #[inline]
1870 pub async fn fetch_one<T: FromRow<'a>>(self) -> ConnectionResult<T> {
1871 match self.fetch_optional().await? {
1872 Some(v) => Ok(v),
1873 None => Err(ConnectionErrorContent::ExpectedRows.into()),
1874 }
1875 }
1876
1877 pub async fn fetch_all_map<M: RowMap<'a>>(self) -> Result<Vec<M::T>, M::E> {
1881 if self.cur_param != self.statement.num_params {
1882 return Err(ConnectionError::from(ConnectionErrorContent::Bind(
1883 self.cur_param,
1884 BindError::TooFewArgumentsBound,
1885 ))
1886 .into());
1887 }
1888 let start_instant = self.connection.stats.get_instant();
1889 match self.connection.query_send().await? {
1890 QueryResult::WithColumns => (),
1891 QueryResult::ExecuteResult(_) => {
1892 return Err(ConnectionError::from(ConnectionErrorContent::ExpectedRows).into());
1893 }
1894 };
1895
1896 self.connection.ranges.clear();
1897 loop {
1898 self.connection.test_cancel()?;
1900 let p = self.connection.reader.read_raw().await?;
1901 {
1902 let mut pp = PackageParser::new(self.connection.reader.bytes(p.clone()));
1903 match pp.get_u8().loc("Row first byte")? {
1904 0x00 => self.connection.ranges.push(p),
1905 0xFE => {
1906 self.connection.state = ConnectionState::Clean;
1908 break;
1909 }
1910 0xFF => {
1911 handle_mysql_error(&mut pp)?;
1912 unreachable!()
1913 }
1914 v => {
1915 return Err(ConnectionError::from(ConnectionErrorContent::ProtocolError(
1916 format!("Unexpected response type {v} to row package"),
1917 ))
1918 .into());
1919 }
1920 }
1921 }
1922
1923 self.connection.reader.buffer_packages = true;
1925 }
1926
1927 self.connection.stats.add_fetch(start_instant);
1928 let mut ans = Vec::with_capacity(self.connection.ranges.len());
1929 for p in &self.connection.ranges {
1930 let row = Row::new(
1931 &self.connection.columns,
1932 self.connection.reader.bytes(p.clone()),
1933 );
1934 ans.push(M::map(row)?);
1935 }
1936 Ok(ans)
1937 }
1938
1939 pub fn fetch_all<T: FromRow<'a>>(
1943 self,
1944 ) -> impl Future<Output = ConnectionResult<Vec<T>>> + Send {
1945 self.fetch_all_map::<FromRowMapper<T>>()
1946 }
1947
1948 pub async fn fetch(self) -> ConnectionResult<QueryIterator<'a>> {
1950 if self.cur_param != self.statement.num_params {
1951 return Err(ConnectionErrorContent::Bind(
1952 self.cur_param,
1953 BindError::TooFewArgumentsBound,
1954 )
1955 .into());
1956 }
1957 match self.connection.query_send().await? {
1958 QueryResult::ExecuteResult(_) => Err(ConnectionErrorContent::ExpectedRows.into()),
1959 QueryResult::WithColumns => Ok(QueryIterator {
1960 connection: self.connection,
1961 }),
1962 }
1963 }
1964
1965 pub async fn fetch_map<M>(self) -> ConnectionResult<MapQueryIterator<'a, M>>
1967 where
1968 for<'b> M: RowMap<'b>,
1969 {
1970 if self.cur_param != self.statement.num_params {
1971 return Err(ConnectionErrorContent::Bind(
1972 self.cur_param,
1973 BindError::TooFewArgumentsBound,
1974 )
1975 .into());
1976 }
1977 match self.connection.query_send().await? {
1978 QueryResult::ExecuteResult(_) => Err(ConnectionErrorContent::ExpectedRows.into()),
1979 QueryResult::WithColumns => Ok(MapQueryIterator {
1980 connection: self.connection,
1981 _phantom: Default::default(),
1982 }),
1983 }
1984 }
1985
1986 pub async fn execute(self) -> ConnectionResult<ExecuteResult> {
1988 if self.cur_param != self.statement.num_params {
1989 return Err(ConnectionErrorContent::Bind(
1990 self.cur_param,
1991 BindError::TooFewArgumentsBound,
1992 )
1993 .into());
1994 }
1995 match self.connection.query_send().await? {
1996 QueryResult::WithColumns => Err(ConnectionErrorContent::UnexpectedRows.into()),
1997 QueryResult::ExecuteResult(v) => Ok(v),
1998 }
1999 }
2000}
2001
/// An open transaction on a [`Connection`].
///
/// Finish it with [`Transaction::commit`] or [`Transaction::rollback`].
/// Dropping it instead schedules a rollback that is issued by the
/// connection's next `cleanup`.
pub struct Transaction<'a> {
    /// The connection the transaction was started on.
    connection: &'a mut Connection,
}
2012
impl<'a> Transaction<'a> {
    /// Commit the transaction.
    ///
    /// # Errors
    /// Errors from cleaning up the connection or from sending the commit.
    pub async fn commit(self) -> ConnectionResult<()> {
        // If cleanup fails, `self` is dropped normally, so `Drop` schedules a
        // rollback for this transaction.
        self.connection.cleanup().await?;
        // From here on the commit is responsible for closing the transaction;
        // suppress `Drop` so it does not also schedule a rollback.
        let mut this = ManuallyDrop::new(self);
        this.connection.commit_impl().await?;
        Ok(())
    }

    /// Roll the transaction back.
    ///
    /// # Errors
    /// Errors from cleaning up the connection or from sending the rollback.
    pub async fn rollback(self) -> ConnectionResult<()> {
        self.connection.cleanup().await?;
        // Suppress `Drop`: the explicit rollback below replaces the deferred one.
        let mut this = ManuallyDrop::new(self);
        this.connection.rollback_impl().await?;
        Ok(())
    }
}
2037
2038impl<'a> Executor for Transaction<'a> {
2039 #[inline]
2040 fn query_raw(
2041 &mut self,
2042 stmt: Cow<'static, str>,
2043 options: QueryOptions,
2044 ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send {
2045 self.connection.query_inner(stmt, options)
2046 }
2047
2048 #[inline]
2049 fn begin(&mut self) -> impl Future<Output = ConnectionResult<Transaction<'_>>> + Send {
2050 self.connection.begin_impl()
2051 }
2052
2053 #[inline]
2054 fn query_with_args_raw(
2055 &mut self,
2056 stmt: Cow<'static, str>,
2057 options: QueryOptions,
2058 args: impl Args + Send,
2059 ) -> impl Future<Output = ConnectionResult<Query<'_>>> {
2060 self.connection.query_with_args_raw(stmt, options, args)
2061 }
2062
2063 #[inline]
2064 fn execute_unprepared(
2065 &mut self,
2066 stmt: &str,
2067 ) -> impl Future<Output = ConnectionResult<ExecuteResult>> + Send {
2068 self.connection.execute_unprepared(stmt)
2069 }
2070
2071 #[inline]
2072 fn ping(&mut self) -> impl Future<Output = ConnectionResult<()>> + Send {
2073 self.connection.ping()
2074 }
2075}
2076
2077impl<'a> Drop for Transaction<'a> {
2078 fn drop(&mut self) {
2079 self.connection.cleanup_rollbacks += 1;
2081 }
2082}
2083
/// Something statements can be run on, implemented by [`Connection`] and
/// [`Transaction`].
pub trait Executor: Sized + Send {
    /// Prepare `stmt` (possibly reusing a cached prepared statement, depending
    /// on `options`) and return a [`Query`] ready for parameters to be bound.
    fn query_raw(
        &mut self,
        stmt: Cow<'static, str>,
        options: QueryOptions,
    ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send;

    /// Like [`Executor::query_raw`], additionally binding `args` to the query.
    fn query_with_args_raw(
        &mut self,
        stmt: Cow<'static, str>,
        options: QueryOptions,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send;

    /// Send `stmt` as a plain text query without preparing it. The SQL is
    /// sent verbatim; nothing is escaped.
    fn execute_unprepared(
        &mut self,
        stmt: &str,
    ) -> impl Future<Output = ConnectionResult<ExecuteResult>> + Send;

    /// Begin a transaction, nested inside any transaction already open.
    fn begin(&mut self) -> impl Future<Output = ConnectionResult<Transaction<'_>>> + Send;

    /// Ping the server.
    fn ping(&mut self) -> impl Future<Output = ConnectionResult<()>> + Send;
}
2134
/// Convenience methods available on every [`Executor`].
///
/// Methods taking `args` prepare the statement with default
/// [`QueryOptions`], bind `args`, and run it.
pub trait ExecutorExt {
    /// Prepare `stmt` with default options.
    fn query(
        &mut self,
        stmt: impl Into<Cow<'static, str>>,
    ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send;

    /// Prepare `stmt` with the given `options`.
    fn query_with_options(
        &mut self,
        stmt: impl Into<Cow<'static, str>>,
        options: QueryOptions,
    ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send;

    /// Prepare `stmt` with default options and bind `args`.
    fn query_with_args(
        &mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send;

    /// Run `stmt` with `args` and decode every row as `T`.
    fn fetch_all<'a, T: FromRow<'a>>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<Vec<T>>> + Send;

    /// Run `stmt` with `args` and map every row with `M`.
    fn fetch_all_map<'a, M: RowMap<'a>>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = Result<Vec<M::T>, M::E>> + Send;

    /// Run `stmt` with `args` and decode exactly one row as `T`.
    fn fetch_one<'a, T: FromRow<'a>>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<T>> + Send;

    /// Run `stmt` with `args` and map exactly one row with `M`.
    fn fetch_one_map<'a, M: RowMap<'a>>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = Result<M::T, M::E>> + Send;

    /// Run `stmt` with `args` and decode the row as `T`, if there is one.
    fn fetch_optional<'a, T: FromRow<'a>>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<Option<T>>> + Send;

    /// Run `stmt` with `args` and map the row with `M`, if there is one.
    fn fetch_optional_map<'a, M: RowMap<'a>>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = Result<Option<M::T>, M::E>> + Send;

    /// Run `stmt` with `args`; the statement must not return rows.
    fn execute(
        &mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<ExecuteResult>> + Send;

    /// Run `stmt` with `args` and return an iterator over the rows.
    fn fetch(
        &mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<QueryIterator<'_>>> + Send;

    /// Run `stmt` with `args` and return an iterator mapping rows with `M`.
    fn fetch_map<'a, M>(
        &'a mut self,
        stmt: impl Into<Cow<'static, str>>,
        args: impl Args + Send,
    ) -> impl Future<Output = ConnectionResult<MapQueryIterator<'a, M>>> + Send
    where
        for<'b> M: RowMap<'b>;
}
2316
2317async fn fetch_all_impl<'a, E: Executor + Sized + Send, T: FromRow<'a>>(
2319 e: &'a mut E,
2320 stmt: Cow<'static, str>,
2321 args: impl Args + Send,
2322) -> ConnectionResult<Vec<T>> {
2323 let q = e
2324 .query_with_args_raw(stmt, QueryOptions::new(), args)
2325 .await?;
2326 q.fetch_all().await
2327}
2328
2329async fn fetch_all_map_impl<'a, E: Executor + Sized + Send, M: RowMap<'a>>(
2331 e: &'a mut E,
2332 stmt: Cow<'static, str>,
2333 args: impl Args + Send,
2334) -> Result<Vec<M::T>, M::E> {
2335 let q = e
2336 .query_with_args_raw(stmt, QueryOptions::new(), args)
2337 .await?;
2338 q.fetch_all_map::<M>().await
2339}
2340
2341async fn fetch_one_impl<'a, E: Executor + Sized + Send, T: FromRow<'a>>(
2343 e: &'a mut E,
2344 stmt: Cow<'static, str>,
2345 args: impl Args + Send,
2346) -> ConnectionResult<T> {
2347 let q = e
2348 .query_with_args_raw(stmt, QueryOptions::new(), args)
2349 .await?;
2350 match q.fetch_optional().await? {
2351 Some(v) => Ok(v),
2352 None => Err(ConnectionErrorContent::ExpectedRows.into()),
2353 }
2354}
2355
2356async fn fetch_one_map_impl<'a, E: Executor + Sized + Send, M: RowMap<'a>>(
2358 e: &'a mut E,
2359 stmt: Cow<'static, str>,
2360 args: impl Args + Send,
2361) -> Result<M::T, M::E> {
2362 let q = e
2363 .query_with_args_raw(stmt, QueryOptions::new(), args)
2364 .await
2365 .map_err(M::E::from)?;
2366 match q.fetch_optional_map::<M>().await? {
2367 Some(v) => Ok(v),
2368 None => Err(ConnectionError::from(ConnectionErrorContent::ExpectedRows).into()),
2369 }
2370}
2371
2372async fn fetch_optional_impl<'a, E: Executor + Sized + Send, T: FromRow<'a>>(
2374 e: &'a mut E,
2375 stmt: Cow<'static, str>,
2376 args: impl Args + Send,
2377) -> ConnectionResult<Option<T>> {
2378 let q = e
2379 .query_with_args_raw(stmt, QueryOptions::new(), args)
2380 .await?;
2381 q.fetch_optional().await
2382}
2383
2384async fn fetch_optional_map_impl<'a, E: Executor + Sized + Send, M: RowMap<'a>>(
2386 e: &'a mut E,
2387 stmt: Cow<'static, str>,
2388 args: impl Args + Send,
2389) -> Result<Option<M::T>, M::E> {
2390 let q = e
2391 .query_with_args_raw(stmt, QueryOptions::new(), args)
2392 .await?;
2393 q.fetch_optional_map::<M>().await
2394}
2395
2396async fn execute_impl<E: Executor + Sized + Send>(
2398 e: &mut E,
2399 stmt: Cow<'static, str>,
2400 args: impl Args + Send,
2401) -> ConnectionResult<ExecuteResult> {
2402 let q = e
2403 .query_with_args_raw(stmt, QueryOptions::new(), args)
2404 .await?;
2405 q.execute().await
2406}
2407
2408async fn fetch_impl<'a, E: Executor + Sized + Send>(
2410 e: &'a mut E,
2411 stmt: Cow<'static, str>,
2412 args: impl Args + Send,
2413) -> ConnectionResult<QueryIterator<'a>> {
2414 let q = e
2415 .query_with_args_raw(stmt, QueryOptions::new(), args)
2416 .await?;
2417 q.fetch().await
2418}
2419
2420async fn fetch_map_impl<'a, E: Executor + Sized + Send, M>(
2422 e: &'a mut E,
2423 stmt: Cow<'static, str>,
2424 args: impl Args + Send,
2425) -> ConnectionResult<MapQueryIterator<'a, M>>
2426where
2427 for<'b> M: RowMap<'b>,
2428{
2429 let q = e
2430 .query_with_args_raw(stmt, QueryOptions::new(), args)
2431 .await?;
2432 q.fetch_map::<M>().await
2433}
2434
2435impl<E: Executor + Sized + Send> ExecutorExt for E {
2436 #[inline]
2437 fn query(
2438 &mut self,
2439 stmt: impl Into<Cow<'static, str>>,
2440 ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send {
2441 self.query_raw(stmt.into(), QueryOptions::new())
2442 }
2443
2444 #[inline]
2445 fn query_with_options(
2446 &mut self,
2447 stmt: impl Into<Cow<'static, str>>,
2448 options: QueryOptions,
2449 ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send {
2450 self.query_raw(stmt.into(), options)
2451 }
2452
2453 #[inline]
2454 fn query_with_args(
2455 &mut self,
2456 stmt: impl Into<Cow<'static, str>>,
2457 args: impl Args + Send,
2458 ) -> impl Future<Output = ConnectionResult<Query<'_>>> {
2459 self.query_with_args_raw(stmt.into(), QueryOptions::new(), args)
2460 }
2461
2462 #[inline]
2463 fn fetch_all<'a, T: FromRow<'a>>(
2464 &'a mut self,
2465 stmt: impl Into<Cow<'static, str>>,
2466 args: impl Args + Send,
2467 ) -> impl Future<Output = ConnectionResult<Vec<T>>> + Send {
2468 fetch_all_impl(self, stmt.into(), args)
2469 }
2470
2471 #[inline]
2472 fn fetch_all_map<'a, M: RowMap<'a>>(
2473 &'a mut self,
2474 stmt: impl Into<Cow<'static, str>>,
2475 args: impl Args + Send,
2476 ) -> impl Future<Output = Result<Vec<M::T>, M::E>> + Send {
2477 fetch_all_map_impl::<E, M>(self, stmt.into(), args)
2478 }
2479
2480 #[inline]
2481 fn fetch_one<'a, T: FromRow<'a>>(
2482 &'a mut self,
2483 stmt: impl Into<Cow<'static, str>>,
2484 args: impl Args + Send,
2485 ) -> impl Future<Output = ConnectionResult<T>> + Send {
2486 fetch_one_impl(self, stmt.into(), args)
2487 }
2488
2489 #[inline]
2490 fn fetch_one_map<'a, M: RowMap<'a>>(
2491 &'a mut self,
2492 stmt: impl Into<Cow<'static, str>>,
2493 args: impl Args + Send,
2494 ) -> impl Future<Output = Result<M::T, M::E>> + Send {
2495 fetch_one_map_impl::<E, M>(self, stmt.into(), args)
2496 }
2497
2498 #[inline]
2499 fn fetch_optional<'a, T: FromRow<'a>>(
2500 &'a mut self,
2501 stmt: impl Into<Cow<'static, str>>,
2502 args: impl Args + Send,
2503 ) -> impl Future<Output = ConnectionResult<Option<T>>> + Send {
2504 fetch_optional_impl(self, stmt.into(), args)
2505 }
2506
2507 #[inline]
2508 fn fetch_optional_map<'a, M: RowMap<'a>>(
2509 &'a mut self,
2510 stmt: impl Into<Cow<'static, str>>,
2511 args: impl Args + Send,
2512 ) -> impl Future<Output = Result<Option<M::T>, M::E>> + Send {
2513 fetch_optional_map_impl::<E, M>(self, stmt.into(), args)
2514 }
2515
2516 #[inline]
2517 fn execute(
2518 &mut self,
2519 stmt: impl Into<Cow<'static, str>>,
2520 args: impl Args + Send,
2521 ) -> impl Future<Output = ConnectionResult<ExecuteResult>> + Send {
2522 execute_impl(self, stmt.into(), args)
2523 }
2524
2525 #[inline]
2526 fn fetch(
2527 &mut self,
2528 stmt: impl Into<Cow<'static, str>>,
2529 args: impl Args + Send,
2530 ) -> impl Future<Output = ConnectionResult<QueryIterator<'_>>> + Send {
2531 fetch_impl(self, stmt.into(), args)
2532 }
2533
2534 #[inline]
2535 fn fetch_map<'a, M>(
2536 &'a mut self,
2537 stmt: impl Into<Cow<'static, str>>,
2538 args: impl Args + Send,
2539 ) -> impl Future<Output = ConnectionResult<MapQueryIterator<'a, M>>> + Send
2540 where
2541 for<'b> M: RowMap<'b>,
2542 {
2543 fetch_map_impl::<E, M>(self, stmt.into(), args)
2544 }
2545}
2546
/// Options controlling how a statement is prepared.
///
/// The defaults are `cache = true` and `information = false`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct QueryOptions {
    /// Keep the prepared statement in the connection's statement cache.
    cache: bool,
    /// Keep the parameter/column definition packages returned by the prepare.
    information: bool,
}

impl QueryOptions {
    /// Options with their default values (same as [`QueryOptions::default`]).
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Enable or disable caching of the prepared statement.
    #[must_use]
    pub fn cache(self, enable: bool) -> Self {
        QueryOptions {
            cache: enable,
            ..self
        }
    }

    /// Enable or disable retaining statement information (parameter and
    /// column definitions).
    #[must_use]
    pub fn information(self, enable: bool) -> Self {
        QueryOptions {
            information: enable,
            ..self
        }
    }
}

impl Default for QueryOptions {
    fn default() -> Self {
        Self {
            cache: true,
            information: false,
        }
    }
}
2586
impl Connection {
    /// Open a new connection described by `options`.
    ///
    /// The prepared statement cache is sized by `options.statement_case_size`.
    pub async fn connect(options: &ConnectionOptions<'_>) -> ConnectionResult<Self> {
        let raw = RawConnection::connect(options).await?;
        Ok(Connection {
            raw,
            prepared_statements: LRUCache::new(options.statement_case_size),
            transaction_depth: 0,
            cleanup_rollbacks: 0,
            prepared_statement: None,
        })
    }

    /// Returns true when the raw connection's state is `ConnectionState::Clean`.
    pub fn is_clean(&self) -> bool {
        matches!(self.raw.state, ConnectionState::Clean)
    }

    /// Bring the connection back to a usable state.
    ///
    /// This finishes whatever the raw connection had in progress. It then rolls
    /// back transactions whose [`Transaction`] was dropped (counted in
    /// `cleanup_rollbacks`) and closes the pending uncached prepared
    /// statement, if there is one.
    ///
    /// # Errors
    /// Errors from the raw cleanup, the rollback, or closing the statement.
    pub async fn cleanup(&mut self) -> ConnectionResult<()> {
        self.raw.cleanup().await?;

        assert!(self.cleanup_rollbacks <= self.transaction_depth);
        if self.cleanup_rollbacks != 0 {
            // Roll back to the depth of the innermost transaction still alive.
            let statement =
                rollback_transaction_query(self.transaction_depth - self.cleanup_rollbacks);
            // Counters are updated before awaiting. If this future is dropped,
            // the rollback is not scheduled a second time; the half-sent
            // statement is presumably finished by `raw.cleanup` via its state.
            self.transaction_depth -= self.cleanup_rollbacks;
            self.cleanup_rollbacks = 0;
            self.raw.execute_unprepared(statement).await?;
        }

        // `take` before awaiting so the statement is not closed twice.
        if let Some(v) = self.prepared_statement.take() {
            self.raw.close_prepared_statement(v.stmt_id).await?
        }
        Ok(())
    }

    /// Clean up, then prepare `stmt` and start a [`Query`] for it.
    ///
    /// With `options.cache` disabled, a fresh statement is prepared and kept
    /// in `prepared_statement` until the next `cleanup` closes it. Otherwise
    /// the statement cache keyed by `stmt` is used:
    /// - a cached entry is re-prepared when information is requested but the
    ///   cached statement lacks it;
    /// - a statement evicted by an insert is closed on the server.
    async fn query_inner(
        &mut self,
        stmt: Cow<'static, str>,
        options: QueryOptions,
    ) -> ConnectionResult<Query<'_>> {
        self.cleanup().await?;
        if !options.cache {
            let r = self.raw.prepare_query(&stmt, options.information).await?;
            self.prepared_statement = Some(r);
            Ok(self.raw.query(self.prepared_statement.as_ref().unwrap()))
        } else {
            let statement = match self.prepared_statements.entry(stmt) {
                Entry::Occupied(mut e) => {
                    if e.get().information.is_none() && options.information {
                        let r = self.raw.prepare_query(e.key(), options.information).await?;
                        // Replace the cached statement, then close the old one.
                        let old = e.insert(r);
                        self.raw.close_prepared_statement(old.stmt_id).await?
                    }
                    e.bump();
                    e.into_mut()
                }
                Entry::Vacant(e) => {
                    let r = self.raw.prepare_query(e.key(), options.information).await?;
                    let (r, old) = e.insert(r);
                    // Inserting may evict another statement; close it on the server.
                    if let Some((_, old)) = old {
                        self.raw.close_prepared_statement(old.stmt_id).await?
                    }
                    r
                }
            };
            Ok(self.raw.query(statement))
        }
    }

    /// Clean up, then begin a transaction at the current nesting depth.
    async fn begin_impl(&mut self) -> ConnectionResult<Transaction<'_>> {
        self.cleanup().await?;

        // Depth and a pending rollback are recorded before the statement is
        // awaited. If this future is dropped or the statement fails, the next
        // `cleanup` rolls back the possibly-started transaction.
        assert_eq!(self.cleanup_rollbacks, 0); let q = begin_transaction_query(self.transaction_depth);
        self.transaction_depth += 1;
        self.cleanup_rollbacks = 1;
        self.raw.execute_unprepared(q).await?;

        // Success: ownership of the rollback passes to the returned
        // `Transaction` (its `Drop` re-schedules it).
        assert_eq!(self.cleanup_rollbacks, 1);
        self.cleanup_rollbacks = 0;
        Ok(Transaction { connection: self })
    }

    /// Clean up, then ping the server.
    async fn ping_impl(&mut self) -> ConnectionResult<()> {
        self.cleanup().await?;
        self.raw.ping().await?;
        Ok(())
    }

    /// Roll back the innermost transaction.
    ///
    /// The depth is decremented eagerly, before the returned future runs.
    ///
    /// # Panics
    /// If the connection is not clean, a deferred rollback is pending, or no
    /// transaction is open.
    fn rollback_impl(&mut self) -> impl Future<Output = ConnectionResult<ExecuteResult>> + Send {
        assert!(matches!(self.raw.state, ConnectionState::Clean));
        assert_eq!(self.cleanup_rollbacks, 0);
        assert_ne!(self.transaction_depth, 0);
        self.transaction_depth -= 1;

        self.raw
            .execute_unprepared(rollback_transaction_query(self.transaction_depth))
    }

    /// Commit the innermost transaction.
    ///
    /// The depth is decremented eagerly, before the returned future runs.
    ///
    /// # Panics
    /// If the connection is not clean, a deferred rollback is pending, or no
    /// transaction is open.
    fn commit_impl(&mut self) -> impl Future<Output = ConnectionResult<ExecuteResult>> + Send {
        assert!(matches!(self.raw.state, ConnectionState::Clean));
        assert_eq!(self.cleanup_rollbacks, 0);
        assert_ne!(self.transaction_depth, 0);

        self.transaction_depth -= 1;

        self.raw
            .execute_unprepared(commit_transaction_query(self.transaction_depth))
    }

    /// Testing hook: set the raw connection's cancellation counter.
    #[cfg(feature = "cancel_testing")]
    #[doc(hidden)]
    pub fn set_cancel_count(&mut self, cnt: Option<usize>) {
        self.raw.cancel_count = cnt;
    }

    /// Timing statistics collected by the raw connection.
    #[cfg(feature = "stats")]
    pub fn stats(&self) -> &Stats {
        &self.raw.stats
    }

    /// Reset the collected statistics to their defaults.
    #[cfg(feature = "stats")]
    pub fn clear_stats(&mut self) {
        self.raw.stats = Default::default()
    }
}
2738
/// Expand every `_LIST_` marker in `stmt` into a list of `?` placeholders.
///
/// Markers are matched to `lengths` in order: a length of `n > 0` becomes
/// `?, ?, …` (`n` placeholders) and a length of 0 becomes `NULL`. A statement
/// without markers is returned unchanged.
///
/// # Errors
/// `TooFewListArguments` if there are more markers than lengths, and
/// `TooManyListArguments` if there are more lengths than markers.
#[cfg(feature = "list_hack")]
fn convert_list_query(
    stmt: Cow<'static, str>,
    lengths: &[usize],
) -> ConnectionResult<Cow<'static, str>> {
    let (head, tail) = match stmt.split_once("_LIST_") {
        Some(parts) => parts,
        None if lengths.is_empty() => return Ok(stmt),
        None => return Err(ConnectionErrorContent::TooManyListArguments.into()),
    };

    let mut out = String::with_capacity(stmt.len() + 2 * lengths.iter().sum::<usize>());
    out.push_str(head);
    let mut remaining = lengths.iter().copied();
    for part in tail.split("_LIST_") {
        let count = match remaining.next() {
            Some(count) => count,
            None => return Err(ConnectionErrorContent::TooFewListArguments.into()),
        };
        match count {
            0 => out.push_str("NULL"),
            n => {
                out.push('?');
                for _ in 1..n {
                    out.push_str(", ?");
                }
            }
        }
        out.push_str(part);
    }
    match remaining.next() {
        Some(_) => Err(ConnectionErrorContent::TooManyListArguments.into()),
        None => Ok(out.into()),
    }
}
2778
2779impl Executor for Connection {
2780 #[inline]
2781 fn query_raw(
2782 &mut self,
2783 stmt: Cow<'static, str>,
2784 options: QueryOptions,
2785 ) -> impl Future<Output = ConnectionResult<Query<'_>>> + Send {
2786 self.query_inner(stmt, options)
2787 }
2788
2789 #[inline]
2790 fn begin(&mut self) -> impl Future<Output = ConnectionResult<Transaction<'_>>> + Send {
2791 self.begin_impl()
2792 }
2793
2794 #[inline]
2795 async fn query_with_args_raw(
2796 &mut self,
2797 stmt: Cow<'static, str>,
2798 options: QueryOptions,
2799 args: impl Args + Send,
2800 ) -> ConnectionResult<Query<'_>> {
2801 #[cfg(feature = "list_hack")]
2802 let stmt = {
2803 self.raw.list_lengths.clear();
2804 args.list_lengths(&mut self.raw.list_lengths);
2805 convert_list_query(stmt, &self.raw.list_lengths)?
2806 };
2807 self.cleanup().await?;
2808 let query = self.query_inner(stmt, options).await?;
2809 args.bind_args(query)
2810 }
2811
2812 async fn execute_unprepared(&mut self, stmt: &str) -> ConnectionResult<ExecuteResult> {
2813 self.cleanup().await?;
2814 self.raw.execute_unprepared(stmt.into()).await
2815 }
2816
2817 #[inline]
2818 fn ping(&mut self) -> impl Future<Output = ConnectionResult<()>> + Send {
2819 self.ping_impl()
2820 }
2821}