1#![allow(clippy::manual_async_fn)]
8#![allow(clippy::result_large_err)]
10
11use std::collections::HashMap;
12use std::future::Future;
13use std::io::{self, Read as StdRead, Write as StdWrite};
14use std::net::TcpStream as StdTcpStream;
15use std::sync::Arc;
16
17use asupersync::io::{AsyncRead, AsyncWrite, ReadBuf};
18use asupersync::net::TcpStream;
19use asupersync::sync::Mutex;
20use asupersync::{Cx, Outcome};
21
22use sqlmodel_core::connection::{Connection, IsolationLevel, PreparedStatement, TransactionOps};
23use sqlmodel_core::error::{
24 ConnectionError, ConnectionErrorKind, ProtocolError, QueryError, QueryErrorKind,
25};
26use sqlmodel_core::{Error, Row, Value};
27
28#[cfg(feature = "console")]
29use sqlmodel_console::{ConsoleAware, SqlModelConsole};
30
31use crate::auth;
32use crate::config::MySqlConfig;
33use crate::connection::{ConnectionState, ServerCapabilities};
34use crate::protocol::{
35 Command, ErrPacket, MAX_PACKET_SIZE, PacketHeader, PacketReader, PacketType, PacketWriter,
36 capabilities, charset, prepared,
37};
38use crate::types::{
39 ColumnDef, FieldType, decode_binary_value_with_len, decode_text_value, interpolate_params,
40};
41
/// Asynchronous MySQL client connection.
///
/// Owns the transport to the server together with the per-session protocol
/// state: handshake results, the packet sequence counter, and the bookkeeping
/// copied out of the most recent OK packet.
pub struct MySqlAsyncConnection {
    /// Transport to the server. `maybe_upgrade_tls_async` temporarily takes it
    /// out (leaving `None`) to wrap the TCP stream in TLS; it stays `None` if
    /// that upgrade fails.
    stream: Option<ConnectionStream>,
    /// Lifecycle state (`connect` moves it Connecting -> Authenticating -> Ready).
    state: ConnectionState,
    /// Capabilities parsed from the server's initial handshake packet; `None`
    /// until that packet has been read.
    server_caps: Option<ServerCapabilities>,
    /// Server-assigned connection id from the handshake (0 until connected).
    connection_id: u32,
    /// Server status flags from the most recent OK packet.
    status_flags: u16,
    /// Affected-row count from the most recent OK packet.
    affected_rows: u64,
    /// Last insert id as last recorded from an OK packet.
    last_insert_id: u64,
    /// Warning count from the most recent OK packet.
    warnings: u16,
    /// Configuration this connection was opened with.
    config: MySqlConfig,
    /// Sequence id to stamp on the next outgoing packet; re-synchronised from
    /// every packet header that is read.
    sequence_id: u8,
    /// Cached prepared-statement metadata keyed by a `u32` id (presumably the
    /// server statement id — confirm against the prepare path).
    prepared_stmts: HashMap<u32, PreparedStmtMeta>,
    /// Optional console handle (only with the `console` feature); `connect`
    /// leaves it `None`.
    #[cfg(feature = "console")]
    console: Option<Arc<SqlModelConsole>>,
}
73
/// Metadata cached for one prepared statement.
#[derive(Debug, Clone)]
struct PreparedStmtMeta {
    /// Statement id; not read anywhere yet, hence the `dead_code` allowance.
    #[allow(dead_code)]
    statement_id: u32,
    /// Column definitions describing the statement's parameters (presumably
    /// taken from the prepare response — confirm).
    params: Vec<ColumnDef>,
    /// Column definitions describing the statement's result set (presumably
    /// taken from the prepare response — confirm).
    columns: Vec<ColumnDef>,
}
88
/// Transport underlying a `MySqlAsyncConnection`.
///
/// NOTE(review): `Sync` is never constructed in the visible code, which is
/// presumably why `dead_code` is allowed here.
#[allow(dead_code)]
enum ConnectionStream {
    /// Blocking std TCP stream; reads/writes on it block inside the async methods.
    Sync(StdTcpStream),
    /// Non-blocking asupersync TCP stream (the variant `connect` creates).
    Async(TcpStream),
    /// TLS over async TCP, installed by the TLS upgrade path.
    #[cfg(feature = "tls")]
    Tls(AsyncTlsStream),
}
100
/// Client-side TLS session layered over an async TCP stream.
#[cfg(feature = "tls")]
struct AsyncTlsStream {
    /// Socket that carries the encrypted TLS records.
    tcp: TcpStream,
    /// rustls state machine that encrypts/decrypts data and drives the handshake.
    tls: rustls::ClientConnection,
}
110
#[cfg(feature = "tls")]
impl std::fmt::Debug for AsyncTlsStream {
    /// Shows the negotiated protocol version and handshake status only; the
    /// socket and session secrets are deliberately left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let version = self.tls.protocol_version();
        let handshaking = self.tls.is_handshaking();
        let mut out = f.debug_struct("AsyncTlsStream");
        out.field("protocol_version", &version);
        out.field("is_handshaking", &handshaking);
        out.finish_non_exhaustive()
    }
}
120
#[cfg(feature = "tls")]
impl AsyncTlsStream {
    /// Runs a client-side TLS handshake over an already-connected TCP stream.
    ///
    /// The SNI / verification name is `tls_config.server_name` when set and
    /// `host` otherwise. Records are shuttled between rustls and the socket
    /// until rustls reports the handshake finished. Any records still queued at
    /// that point (such as the final client flight) are flushed before the
    /// stream is returned.
    ///
    /// # Errors
    /// Returns a connection error in any of these cases:
    /// - the rustls client config cannot be built or the server name is invalid;
    /// - the socket fails or reaches EOF mid-handshake;
    /// - rustls rejects the peer's records;
    /// - the handshake stops making progress.
    async fn handshake(
        mut tcp: TcpStream,
        tls_config: &crate::config::TlsConfig,
        host: &str,
        ssl_mode: crate::config::SslMode,
    ) -> Result<Self, Error> {
        let config = crate::tls::build_client_config(tls_config, ssl_mode)?;

        let sni = tls_config.server_name.as_deref().unwrap_or(host);
        let server_name = sni
            .to_string()
            .try_into()
            .map_err(|e| connection_error(format!("Invalid server name '{sni}': {e}")))?;

        let mut tls = rustls::ClientConnection::new(std::sync::Arc::new(config), server_name)
            .map_err(|e| connection_error(format!("Failed to create TLS connection: {e}")))?;

        while tls.is_handshaking() {
            Self::send_pending_handshake_records(&mut tls, &mut tcp).await?;

            if !tls.wants_read() {
                // Nothing left to send and rustls does not want input either:
                // looping again would spin forever.
                return Err(connection_error(
                    "TLS handshake stalled: no data to send or receive",
                ));
            }

            let mut buf = [0u8; 8192];
            let n = read_some_async(&mut tcp, &mut buf).await.map_err(|e| {
                Error::Connection(ConnectionError {
                    kind: ConnectionErrorKind::Disconnected,
                    message: format!("TLS handshake read error: {e}"),
                    source: Some(Box::new(e)),
                })
            })?;
            if n == 0 {
                return Err(connection_error("Connection closed during TLS handshake"));
            }

            // `read_tls` may accept only part of the buffer per call. Keep feeding
            // and processing until every received byte is consumed. Otherwise the
            // tail of a large flight (e.g. a certificate chain) is silently lost.
            let mut cursor = std::io::Cursor::new(&buf[..n]);
            while cursor.position() < n as u64 {
                let accepted = tls.read_tls(&mut cursor).map_err(|e| {
                    connection_error(format!("TLS handshake read_tls error: {e}"))
                })?;
                if accepted == 0 {
                    return Err(connection_error(
                        "TLS handshake read_tls error: no data accepted",
                    ));
                }
                tls.process_new_packets()
                    .map_err(|e| connection_error(format!("TLS handshake error: {e}")))?;
            }
        }

        // Push out anything rustls queued while completing the handshake.
        Self::send_pending_handshake_records(&mut tls, &mut tcp).await?;

        Ok(Self { tcp, tls })
    }

    /// Writes every TLS record rustls currently has queued to `tcp`.
    async fn send_pending_handshake_records(
        tls: &mut rustls::ClientConnection,
        tcp: &mut TcpStream,
    ) -> Result<(), Error> {
        while tls.wants_write() {
            let mut out = Vec::new();
            tls.write_tls(&mut out)
                .map_err(|e| connection_error(format!("TLS handshake write_tls error: {e}")))?;
            if !out.is_empty() {
                write_all_async(tcp, &out).await.map_err(|e| {
                    Error::Connection(ConnectionError {
                        kind: ConnectionErrorKind::Disconnected,
                        message: format!("TLS handshake write error: {e}"),
                        source: Some(Box::new(e)),
                    })
                })?;
            }
        }
        Ok(())
    }

    /// Reads exactly `buf.len()` plaintext bytes.
    ///
    /// # Errors
    /// Returns `UnexpectedEof` if the session ends first, or any I/O / TLS error.
    async fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
        let mut read = 0;
        while read < buf.len() {
            let n = self.read_plain(&mut buf[read..]).await?;
            if n == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "connection closed",
                ));
            }
            read += n;
        }
        Ok(())
    }

    /// Reads some decrypted bytes into `out`, pulling ciphertext from the socket
    /// as needed. Returns `Ok(0)` once the TLS session or socket has closed.
    async fn read_plain(&mut self, out: &mut [u8]) -> io::Result<usize> {
        loop {
            match self.tls.reader().read(out) {
                Ok(n) if n > 0 => return Ok(n),
                Ok(_) => {}
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
                Err(e) => return Err(e),
            }

            if !self.tls.wants_read() {
                return Ok(0);
            }

            let mut enc = [0u8; 8192];
            let n = read_some_async(&mut self.tcp, &mut enc).await?;
            if n == 0 {
                return Ok(0);
            }

            // Feed the whole chunk: a single `read_tls` call may not take all of it.
            let mut cursor = std::io::Cursor::new(&enc[..n]);
            while cursor.position() < n as u64 {
                if self.tls.read_tls(&mut cursor)? == 0 {
                    return Err(io::Error::other("TLS layer accepted no data"));
                }
                self.tls
                    .process_new_packets()
                    .map_err(|e| io::Error::other(format!("TLS error: {e}")))?;
            }
        }
    }

    /// Encrypts and sends all of `buf`, flushing produced records as it goes.
    async fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
        let mut written = 0;
        while written < buf.len() {
            let n = self.tls.writer().write(&buf[written..])?;
            if n == 0 {
                return Err(io::Error::new(io::ErrorKind::WriteZero, "TLS write zero"));
            }
            written += n;
            self.flush().await?;
        }
        Ok(())
    }

    /// Writes every queued TLS record to the socket, then flushes the socket.
    async fn flush(&mut self) -> io::Result<()> {
        self.tls.writer().flush()?;
        while self.tls.wants_write() {
            let mut out = Vec::new();
            self.tls.write_tls(&mut out)?;
            if !out.is_empty() {
                write_all_async(&mut self.tcp, &out).await?;
            }
        }
        flush_async(&mut self.tcp).await
    }
}
248
#[cfg(feature = "tls")]
/// Performs one read on `stream` into `buf` and returns the byte count
/// (`0` means EOF).
async fn read_some_async(stream: &mut TcpStream, buf: &mut [u8]) -> io::Result<usize> {
    let mut window = ReadBuf::new(buf);
    let mut pinned = std::pin::Pin::new(stream);
    let polled = std::future::poll_fn(|cx| pinned.as_mut().poll_read(cx, &mut window)).await;
    polled.map(|()| window.filled().len())
}
256
#[cfg(feature = "tls")]
/// Writes all of `buf` to `stream`, retrying partial writes.
///
/// # Errors
/// Returns `WriteZero` if the stream stops accepting bytes, or the underlying
/// I/O error.
async fn write_all_async(stream: &mut TcpStream, buf: &[u8]) -> io::Result<()> {
    let mut remaining = buf;
    while !remaining.is_empty() {
        let chunk = remaining;
        let accepted = std::future::poll_fn(|cx| {
            std::pin::Pin::new(&mut *stream).poll_write(cx, chunk)
        })
        .await?;
        if accepted == 0 {
            return Err(io::Error::new(
                io::ErrorKind::WriteZero,
                "connection closed",
            ));
        }
        remaining = remaining.get(accepted..).unwrap_or(&[]);
    }
    Ok(())
}
275
#[cfg(feature = "tls")]
/// Flushes any data buffered inside `stream`.
async fn flush_async(stream: &mut TcpStream) -> io::Result<()> {
    let mut pinned = std::pin::Pin::new(stream);
    std::future::poll_fn(move |cx| pinned.as_mut().poll_flush(cx)).await
}
280
281impl std::fmt::Debug for MySqlAsyncConnection {
282 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283 f.debug_struct("MySqlAsyncConnection")
284 .field("state", &self.state)
285 .field("connection_id", &self.connection_id)
286 .field("host", &self.config.host)
287 .field("port", &self.config.port)
288 .field("database", &self.config.database)
289 .finish_non_exhaustive()
290 }
291}
292
293impl MySqlAsyncConnection {
294 pub async fn connect(_cx: &Cx, config: MySqlConfig) -> Outcome<Self, Error> {
302 let addr = config.socket_addr();
304 let socket_addr = match addr.parse() {
305 Ok(a) => a,
306 Err(e) => {
307 return Outcome::Err(Error::Connection(ConnectionError {
308 kind: ConnectionErrorKind::Connect,
309 message: format!("Invalid socket address: {}", e),
310 source: None,
311 }));
312 }
313 };
314 let stream = match TcpStream::connect_timeout(socket_addr, config.connect_timeout).await {
315 Ok(s) => s,
316 Err(e) => {
317 let kind = if e.kind() == io::ErrorKind::ConnectionRefused {
318 ConnectionErrorKind::Refused
319 } else {
320 ConnectionErrorKind::Connect
321 };
322 return Outcome::Err(Error::Connection(ConnectionError {
323 kind,
324 message: format!("Failed to connect to {}: {}", addr, e),
325 source: Some(Box::new(e)),
326 }));
327 }
328 };
329
330 stream.set_nodelay(true).ok();
332
333 let mut conn = Self {
334 stream: Some(ConnectionStream::Async(stream)),
335 state: ConnectionState::Connecting,
336 server_caps: None,
337 connection_id: 0,
338 status_flags: 0,
339 affected_rows: 0,
340 last_insert_id: 0,
341 warnings: 0,
342 config,
343 sequence_id: 0,
344 prepared_stmts: HashMap::new(),
345 #[cfg(feature = "console")]
346 console: None,
347 };
348
349 match conn.read_handshake_async().await {
351 Outcome::Ok(server_caps) => {
352 conn.connection_id = server_caps.connection_id;
353 conn.server_caps = Some(server_caps);
354 conn.state = ConnectionState::Authenticating;
355 }
356 Outcome::Err(e) => return Outcome::Err(e),
357 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
358 Outcome::Panicked(p) => return Outcome::Panicked(p),
359 }
360
361 if let Outcome::Err(e) = conn.send_handshake_response_async().await {
363 return Outcome::Err(e);
364 }
365
366 if let Outcome::Err(e) = conn.handle_auth_result_async().await {
368 return Outcome::Err(e);
369 }
370
371 conn.state = ConnectionState::Ready;
372 Outcome::Ok(conn)
373 }
374
375 pub fn state(&self) -> ConnectionState {
377 self.state
378 }
379
380 pub fn is_ready(&self) -> bool {
382 matches!(self.state, ConnectionState::Ready)
383 }
384
385 fn is_secure_transport(&self) -> bool {
386 #[cfg(feature = "tls")]
387 {
388 matches!(self.stream, Some(ConnectionStream::Tls(_)))
389 }
390 #[cfg(not(feature = "tls"))]
391 {
392 false
393 }
394 }
395
396 pub fn connection_id(&self) -> u32 {
398 self.connection_id
399 }
400
401 pub fn server_version(&self) -> Option<&str> {
403 self.server_caps
404 .as_ref()
405 .map(|caps| caps.server_version.as_str())
406 }
407
408 pub fn affected_rows(&self) -> u64 {
410 self.affected_rows
411 }
412
413 pub fn last_insert_id(&self) -> u64 {
415 self.last_insert_id
416 }
417
418 async fn read_packet_async(&mut self) -> Outcome<(Vec<u8>, u8), Error> {
422 let mut header_buf = [0u8; 4];
424
425 let Some(stream) = self.stream.as_mut() else {
426 return Outcome::Err(connection_error("Connection stream missing"));
427 };
428
429 match stream {
430 ConnectionStream::Async(stream) => {
431 let mut header_read = 0;
432 while header_read < 4 {
433 let mut read_buf = ReadBuf::new(&mut header_buf[header_read..]);
434 match std::future::poll_fn(|cx| {
435 std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf)
436 })
437 .await
438 {
439 Ok(()) => {
440 let n = read_buf.filled().len();
441 if n == 0 {
442 return Outcome::Err(Error::Connection(ConnectionError {
443 kind: ConnectionErrorKind::Disconnected,
444 message: "Connection closed while reading header".to_string(),
445 source: None,
446 }));
447 }
448 header_read += n;
449 }
450 Err(e) => {
451 return Outcome::Err(Error::Connection(ConnectionError {
452 kind: ConnectionErrorKind::Disconnected,
453 message: format!("Failed to read packet header: {}", e),
454 source: Some(Box::new(e)),
455 }));
456 }
457 }
458 }
459 }
460 ConnectionStream::Sync(stream) => {
461 if let Err(e) = stream.read_exact(&mut header_buf) {
462 return Outcome::Err(Error::Connection(ConnectionError {
463 kind: ConnectionErrorKind::Disconnected,
464 message: format!("Failed to read packet header: {}", e),
465 source: Some(Box::new(e)),
466 }));
467 }
468 }
469 #[cfg(feature = "tls")]
470 ConnectionStream::Tls(stream) => {
471 if let Err(e) = stream.read_exact(&mut header_buf).await {
472 return Outcome::Err(Error::Connection(ConnectionError {
473 kind: ConnectionErrorKind::Disconnected,
474 message: format!("Failed to read packet header: {e}"),
475 source: Some(Box::new(e)),
476 }));
477 }
478 }
479 }
480
481 let header = PacketHeader::from_bytes(&header_buf);
482 let payload_len = header.payload_length as usize;
483 self.sequence_id = header.sequence_id.wrapping_add(1);
484
485 let mut payload = vec![0u8; payload_len];
487 if payload_len > 0 {
488 let Some(stream) = self.stream.as_mut() else {
489 return Outcome::Err(connection_error("Connection stream missing"));
490 };
491 match stream {
492 ConnectionStream::Async(stream) => {
493 let mut total_read = 0;
494 while total_read < payload_len {
495 let mut read_buf = ReadBuf::new(&mut payload[total_read..]);
496 match std::future::poll_fn(|cx| {
497 std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf)
498 })
499 .await
500 {
501 Ok(()) => {
502 let n = read_buf.filled().len();
503 if n == 0 {
504 return Outcome::Err(Error::Connection(ConnectionError {
505 kind: ConnectionErrorKind::Disconnected,
506 message: "Connection closed while reading payload"
507 .to_string(),
508 source: None,
509 }));
510 }
511 total_read += n;
512 }
513 Err(e) => {
514 return Outcome::Err(Error::Connection(ConnectionError {
515 kind: ConnectionErrorKind::Disconnected,
516 message: format!("Failed to read packet payload: {}", e),
517 source: Some(Box::new(e)),
518 }));
519 }
520 }
521 }
522 }
523 ConnectionStream::Sync(stream) => {
524 if let Err(e) = stream.read_exact(&mut payload) {
525 return Outcome::Err(Error::Connection(ConnectionError {
526 kind: ConnectionErrorKind::Disconnected,
527 message: format!("Failed to read packet payload: {}", e),
528 source: Some(Box::new(e)),
529 }));
530 }
531 }
532 #[cfg(feature = "tls")]
533 ConnectionStream::Tls(stream) => {
534 if let Err(e) = stream.read_exact(&mut payload).await {
535 return Outcome::Err(Error::Connection(ConnectionError {
536 kind: ConnectionErrorKind::Disconnected,
537 message: format!("Failed to read packet payload: {e}"),
538 source: Some(Box::new(e)),
539 }));
540 }
541 }
542 }
543 }
544
545 if payload_len == MAX_PACKET_SIZE {
547 loop {
548 let mut header_buf = [0u8; 4];
550 let Some(stream) = self.stream.as_mut() else {
551 return Outcome::Err(connection_error("Connection stream missing"));
552 };
553 match stream {
554 ConnectionStream::Async(stream) => {
555 let mut header_read = 0;
556 while header_read < 4 {
557 let mut read_buf = ReadBuf::new(&mut header_buf[header_read..]);
558 match std::future::poll_fn(|cx| {
559 std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf)
560 })
561 .await
562 {
563 Ok(()) => {
564 let n = read_buf.filled().len();
565 if n == 0 {
566 return Outcome::Err(Error::Connection(ConnectionError {
567 kind: ConnectionErrorKind::Disconnected,
568 message: "Connection closed while reading continuation header".to_string(),
569 source: None,
570 }));
571 }
572 header_read += n;
573 }
574 Err(e) => {
575 return Outcome::Err(Error::Connection(ConnectionError {
576 kind: ConnectionErrorKind::Disconnected,
577 message: format!(
578 "Failed to read continuation header: {}",
579 e
580 ),
581 source: Some(Box::new(e)),
582 }));
583 }
584 }
585 }
586 }
587 ConnectionStream::Sync(stream) => {
588 if let Err(e) = stream.read_exact(&mut header_buf) {
589 return Outcome::Err(Error::Connection(ConnectionError {
590 kind: ConnectionErrorKind::Disconnected,
591 message: format!("Failed to read continuation header: {}", e),
592 source: Some(Box::new(e)),
593 }));
594 }
595 }
596 #[cfg(feature = "tls")]
597 ConnectionStream::Tls(stream) => {
598 if let Err(e) = stream.read_exact(&mut header_buf).await {
599 return Outcome::Err(Error::Connection(ConnectionError {
600 kind: ConnectionErrorKind::Disconnected,
601 message: format!("Failed to read continuation header: {e}"),
602 source: Some(Box::new(e)),
603 }));
604 }
605 }
606 }
607
608 let cont_header = PacketHeader::from_bytes(&header_buf);
609 let cont_len = cont_header.payload_length as usize;
610 self.sequence_id = cont_header.sequence_id.wrapping_add(1);
611
612 if cont_len > 0 {
613 let mut cont_payload = vec![0u8; cont_len];
614 let Some(stream) = self.stream.as_mut() else {
615 return Outcome::Err(connection_error("Connection stream missing"));
616 };
617 match stream {
618 ConnectionStream::Async(stream) => {
619 let mut total_read = 0;
620 while total_read < cont_len {
621 let mut read_buf = ReadBuf::new(&mut cont_payload[total_read..]);
622 match std::future::poll_fn(|cx| {
623 std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf)
624 })
625 .await
626 {
627 Ok(()) => {
628 let n = read_buf.filled().len();
629 if n == 0 {
630 return Outcome::Err(Error::Connection(ConnectionError {
631 kind: ConnectionErrorKind::Disconnected,
632 message: "Connection closed while reading continuation payload".to_string(),
633 source: None,
634 }));
635 }
636 total_read += n;
637 }
638 Err(e) => {
639 return Outcome::Err(Error::Connection(ConnectionError {
640 kind: ConnectionErrorKind::Disconnected,
641 message: format!(
642 "Failed to read continuation payload: {}",
643 e
644 ),
645 source: Some(Box::new(e)),
646 }));
647 }
648 }
649 }
650 }
651 ConnectionStream::Sync(stream) => {
652 if let Err(e) = stream.read_exact(&mut cont_payload) {
653 return Outcome::Err(Error::Connection(ConnectionError {
654 kind: ConnectionErrorKind::Disconnected,
655 message: format!("Failed to read continuation payload: {}", e),
656 source: Some(Box::new(e)),
657 }));
658 }
659 }
660 #[cfg(feature = "tls")]
661 ConnectionStream::Tls(stream) => {
662 if let Err(e) = stream.read_exact(&mut cont_payload).await {
663 return Outcome::Err(Error::Connection(ConnectionError {
664 kind: ConnectionErrorKind::Disconnected,
665 message: format!("Failed to read continuation payload: {e}"),
666 source: Some(Box::new(e)),
667 }));
668 }
669 }
670 }
671 payload.extend_from_slice(&cont_payload);
672 }
673
674 if cont_len < MAX_PACKET_SIZE {
675 break;
676 }
677 }
678 }
679
680 Outcome::Ok((payload, header.sequence_id))
681 }
682
683 async fn write_packet_async(&mut self, payload: &[u8]) -> Outcome<(), Error> {
685 let writer = PacketWriter::new();
686 let packet = writer.build_packet_from_payload(payload, self.sequence_id);
687 self.sequence_id = self.sequence_id.wrapping_add(1);
688
689 let Some(stream) = self.stream.as_mut() else {
690 return Outcome::Err(connection_error("Connection stream missing"));
691 };
692
693 match stream {
694 ConnectionStream::Async(stream) => {
695 let mut written = 0;
697 while written < packet.len() {
698 match std::future::poll_fn(|cx| {
699 std::pin::Pin::new(&mut *stream).poll_write(cx, &packet[written..])
700 })
701 .await
702 {
703 Ok(n) => {
704 if n == 0 {
705 return Outcome::Err(Error::Connection(ConnectionError {
706 kind: ConnectionErrorKind::Disconnected,
707 message: "Connection closed while writing packet".to_string(),
708 source: None,
709 }));
710 }
711 written += n;
712 }
713 Err(e) => {
714 return Outcome::Err(Error::Connection(ConnectionError {
715 kind: ConnectionErrorKind::Disconnected,
716 message: format!("Failed to write packet: {}", e),
717 source: Some(Box::new(e)),
718 }));
719 }
720 }
721 }
722
723 match std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_flush(cx))
724 .await
725 {
726 Ok(()) => {}
727 Err(e) => {
728 return Outcome::Err(Error::Connection(ConnectionError {
729 kind: ConnectionErrorKind::Disconnected,
730 message: format!("Failed to flush stream: {}", e),
731 source: Some(Box::new(e)),
732 }));
733 }
734 }
735 }
736 ConnectionStream::Sync(stream) => {
737 if let Err(e) = stream.write_all(&packet) {
738 return Outcome::Err(Error::Connection(ConnectionError {
739 kind: ConnectionErrorKind::Disconnected,
740 message: format!("Failed to write packet: {}", e),
741 source: Some(Box::new(e)),
742 }));
743 }
744 if let Err(e) = stream.flush() {
745 return Outcome::Err(Error::Connection(ConnectionError {
746 kind: ConnectionErrorKind::Disconnected,
747 message: format!("Failed to flush stream: {}", e),
748 source: Some(Box::new(e)),
749 }));
750 }
751 }
752 #[cfg(feature = "tls")]
753 ConnectionStream::Tls(stream) => {
754 if let Err(e) = stream.write_all(&packet).await {
755 return Outcome::Err(Error::Connection(ConnectionError {
756 kind: ConnectionErrorKind::Disconnected,
757 message: format!("Failed to write packet: {e}"),
758 source: Some(Box::new(e)),
759 }));
760 }
761 if let Err(e) = stream.flush().await {
762 return Outcome::Err(Error::Connection(ConnectionError {
763 kind: ConnectionErrorKind::Disconnected,
764 message: format!("Failed to flush stream: {e}"),
765 source: Some(Box::new(e)),
766 }));
767 }
768 }
769 }
770
771 Outcome::Ok(())
772 }
773
774 async fn read_handshake_async(&mut self) -> Outcome<ServerCapabilities, Error> {
778 let (payload, _) = match self.read_packet_async().await {
779 Outcome::Ok(p) => p,
780 Outcome::Err(e) => return Outcome::Err(e),
781 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
782 Outcome::Panicked(p) => return Outcome::Panicked(p),
783 };
784
785 let mut reader = PacketReader::new(&payload);
786
787 let Some(protocol_version) = reader.read_u8() else {
789 return Outcome::Err(protocol_error("Missing protocol version"));
790 };
791
792 if protocol_version != 10 {
793 return Outcome::Err(protocol_error(format!(
794 "Unsupported protocol version: {}",
795 protocol_version
796 )));
797 }
798
799 let Some(server_version) = reader.read_null_string() else {
801 return Outcome::Err(protocol_error("Missing server version"));
802 };
803
804 let Some(connection_id) = reader.read_u32_le() else {
806 return Outcome::Err(protocol_error("Missing connection ID"));
807 };
808
809 let Some(auth_data_1) = reader.read_bytes(8) else {
811 return Outcome::Err(protocol_error("Missing auth data"));
812 };
813
814 reader.skip(1);
816
817 let Some(caps_lower) = reader.read_u16_le() else {
819 return Outcome::Err(protocol_error("Missing capability flags"));
820 };
821
822 let charset_val = reader.read_u8().unwrap_or(charset::UTF8MB4_0900_AI_CI);
824
825 let status_flags = reader.read_u16_le().unwrap_or(0);
827
828 let caps_upper = reader.read_u16_le().unwrap_or(0);
830 let capabilities_val = u32::from(caps_lower) | (u32::from(caps_upper) << 16);
831
832 let auth_data_len = if capabilities_val & capabilities::CLIENT_PLUGIN_AUTH != 0 {
834 reader.read_u8().unwrap_or(0) as usize
835 } else {
836 0
837 };
838
839 reader.skip(10);
841
842 let mut auth_data = auth_data_1.to_vec();
844 if capabilities_val & capabilities::CLIENT_SECURE_CONNECTION != 0 {
845 let len2 = if auth_data_len > 8 {
846 auth_data_len - 8
847 } else {
848 13 };
850 if let Some(data2) = reader.read_bytes(len2) {
851 let data2_clean = if data2.last() == Some(&0) {
853 &data2[..data2.len() - 1]
854 } else {
855 data2
856 };
857 auth_data.extend_from_slice(data2_clean);
858 }
859 }
860
861 let auth_plugin = if capabilities_val & capabilities::CLIENT_PLUGIN_AUTH != 0 {
863 reader.read_null_string().unwrap_or_default()
864 } else {
865 auth::plugins::MYSQL_NATIVE_PASSWORD.to_string()
866 };
867
868 Outcome::Ok(ServerCapabilities {
869 capabilities: capabilities_val,
870 protocol_version,
871 server_version,
872 connection_id,
873 auth_plugin,
874 auth_data,
875 charset: charset_val,
876 status_flags,
877 })
878 }
879
880 async fn send_handshake_response_async(&mut self) -> Outcome<(), Error> {
882 let Some(server_caps) = self.server_caps.as_ref() else {
883 return Outcome::Err(protocol_error("No server handshake received"));
884 };
885
886 let server_caps_bits = server_caps.capabilities;
888 let auth_plugin = server_caps.auth_plugin.clone();
889 let auth_data = server_caps.auth_data.clone();
890
891 let mut client_caps = self.config.capability_flags() & server_caps_bits;
893 #[cfg(feature = "tls")]
894 if let Outcome::Err(e) = self
895 .maybe_upgrade_tls_async(server_caps_bits, &mut client_caps)
896 .await
897 {
898 return Outcome::Err(e);
899 }
900
901 #[cfg(not(feature = "tls"))]
902 if let Outcome::Err(e) = self.maybe_upgrade_tls(server_caps_bits, &mut client_caps) {
903 return Outcome::Err(e);
904 }
905
906 let auth_response = self.compute_auth_response(&auth_plugin, &auth_data);
908
909 let mut writer = PacketWriter::new();
910
911 writer.write_u32_le(client_caps);
913
914 writer.write_u32_le(self.config.max_packet_size);
916
917 writer.write_u8(self.config.charset);
919
920 writer.write_zeros(23);
922
923 writer.write_null_string(&self.config.user);
925
926 if client_caps & capabilities::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA != 0 {
928 writer.write_lenenc_bytes(&auth_response);
929 } else if client_caps & capabilities::CLIENT_SECURE_CONNECTION != 0 {
930 #[allow(clippy::cast_possible_truncation)]
931 writer.write_u8(auth_response.len() as u8);
932 writer.write_bytes(&auth_response);
933 } else {
934 writer.write_bytes(&auth_response);
935 writer.write_u8(0); }
937
938 if client_caps & capabilities::CLIENT_CONNECT_WITH_DB != 0 {
940 if let Some(ref db) = self.config.database {
941 writer.write_null_string(db);
942 } else {
943 writer.write_u8(0); }
945 }
946
947 if client_caps & capabilities::CLIENT_PLUGIN_AUTH != 0 {
949 writer.write_null_string(&auth_plugin);
950 }
951
952 if client_caps & capabilities::CLIENT_CONNECT_ATTRS != 0
954 && !self.config.attributes.is_empty()
955 {
956 let mut attrs_writer = PacketWriter::new();
957 for (key, value) in &self.config.attributes {
958 attrs_writer.write_lenenc_string(key);
959 attrs_writer.write_lenenc_string(value);
960 }
961 let attrs_data = attrs_writer.into_bytes();
962 writer.write_lenenc_bytes(&attrs_data);
963 }
964
965 self.write_packet_async(writer.as_bytes()).await
966 }
967
/// Negotiates TLS after the server handshake when the SSL mode calls for it.
///
/// If TLS will not be used, `CLIENT_SSL` is cleared from `client_caps`.
/// Otherwise the steps are:
/// 1. validate the TLS config;
/// 2. send an SSLRequest packet and advance the sequence id;
/// 3. replace the async TCP stream with an `AsyncTlsStream`.
///
/// # Errors
/// Propagates errors from SSL-mode/config validation, the SSLRequest write,
/// and the TLS handshake. Also fails if the current stream is missing or is
/// not the async TCP variant; the stream is left `None` in that case.
#[cfg(feature = "tls")]
async fn maybe_upgrade_tls_async(
    &mut self,
    server_caps: u32,
    client_caps: &mut u32,
) -> Outcome<(), Error> {
    let ssl_mode = self.config.ssl_mode;

    let wants_tls = if ssl_mode.should_try_ssl() {
        match crate::tls::validate_ssl_mode(ssl_mode, server_caps) {
            Ok(decision) => decision,
            Err(e) => return Outcome::Err(e),
        }
    } else {
        false
    };

    if !wants_tls {
        *client_caps &= !capabilities::CLIENT_SSL;
        return Outcome::Ok(());
    }

    if let Err(e) = crate::tls::validate_tls_config(ssl_mode, &self.config.tls_config) {
        return Outcome::Err(e);
    }

    let ssl_request = crate::tls::build_ssl_request_packet(
        *client_caps,
        self.config.max_packet_size,
        self.config.charset,
        self.sequence_id,
    );
    if let Outcome::Err(e) = self.write_packet_raw_async(&ssl_request).await {
        return Outcome::Err(e);
    }
    self.sequence_id = self.sequence_id.wrapping_add(1);

    let tcp = match self.stream.take() {
        Some(ConnectionStream::Async(tcp)) => tcp,
        Some(_) => {
            return Outcome::Err(connection_error("TLS upgrade requires async TCP stream"));
        }
        None => return Outcome::Err(connection_error("Connection stream missing")),
    };

    let handshake = AsyncTlsStream::handshake(
        tcp,
        &self.config.tls_config,
        &self.config.host,
        ssl_mode,
    )
    .await;

    match handshake {
        Ok(tls) => {
            self.stream = Some(ConnectionStream::Tls(tls));
            Outcome::Ok(())
        }
        Err(e) => Outcome::Err(e),
    }
}
1031
1032 #[cfg(not(feature = "tls"))]
1033 fn maybe_upgrade_tls(&mut self, server_caps: u32, client_caps: &mut u32) -> Outcome<(), Error> {
1034 let ssl_mode = self.config.ssl_mode;
1035
1036 if !ssl_mode.should_try_ssl() {
1037 *client_caps &= !capabilities::CLIENT_SSL;
1038 return Outcome::Ok(());
1039 }
1040
1041 let use_tls = match crate::tls::validate_ssl_mode(ssl_mode, server_caps) {
1042 Ok(v) => v,
1043 Err(e) => return Outcome::Err(e),
1044 };
1045
1046 if !use_tls {
1047 *client_caps &= !capabilities::CLIENT_SSL;
1049 return Outcome::Ok(());
1050 }
1051
1052 if ssl_mode == crate::config::SslMode::Preferred {
1054 *client_caps &= !capabilities::CLIENT_SSL;
1055 Outcome::Ok(())
1056 } else {
1057 Outcome::Err(connection_error(
1058 "TLS requested but 'sqlmodel-mysql' was built without feature 'tls'",
1059 ))
1060 }
1061 }
1062
1063 fn compute_auth_response(&self, plugin: &str, auth_data: &[u8]) -> Vec<u8> {
1065 let pw = self.config.password_str();
1066
1067 match plugin {
1068 auth::plugins::MYSQL_NATIVE_PASSWORD => {
1070 auth::mysql_native_password(pw, auth_data)
1071 }
1072 auth::plugins::CACHING_SHA2_PASSWORD => {
1073 auth::caching_sha2_password(pw, auth_data)
1074 }
1075 auth::plugins::MYSQL_CLEAR_PASSWORD => {
1076 let mut result = pw.as_bytes().to_vec();
1077 result.push(0);
1078 result
1079 }
1080 _ => auth::mysql_native_password(pw, auth_data),
1081 }
1082 }
1083
1084 async fn handle_auth_result_async(&mut self) -> Outcome<(), Error> {
1087 loop {
1089 let (payload, _) = match self.read_packet_async().await {
1090 Outcome::Ok(p) => p,
1091 Outcome::Err(e) => return Outcome::Err(e),
1092 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1093 Outcome::Panicked(p) => return Outcome::Panicked(p),
1094 };
1095
1096 if payload.is_empty() {
1097 return Outcome::Err(protocol_error("Empty authentication response"));
1098 }
1099
1100 #[allow(clippy::cast_possible_truncation)] match PacketType::from_first_byte(payload[0], payload.len() as u32) {
1102 PacketType::Ok => {
1103 let mut reader = PacketReader::new(&payload);
1104 if let Some(ok) = reader.parse_ok_packet() {
1105 self.status_flags = ok.status_flags;
1106 self.affected_rows = ok.affected_rows;
1107 }
1108 return Outcome::Ok(());
1109 }
1110 PacketType::Error => {
1111 let mut reader = PacketReader::new(&payload);
1112 let Some(err) = reader.parse_err_packet() else {
1113 return Outcome::Err(protocol_error("Invalid error packet"));
1114 };
1115 return Outcome::Err(auth_error(format!(
1116 "Authentication failed: {} ({})",
1117 err.error_message, err.error_code
1118 )));
1119 }
1120 PacketType::Eof => {
1121 let data = &payload[1..];
1123 let mut reader = PacketReader::new(data);
1124
1125 let Some(plugin) = reader.read_null_string() else {
1126 return Outcome::Err(protocol_error("Missing plugin name in auth switch"));
1127 };
1128
1129 let auth_data = reader.read_rest();
1130 let response = self.compute_auth_response(&plugin, auth_data);
1131
1132 if let Outcome::Err(e) = self.write_packet_async(&response).await {
1133 return Outcome::Err(e);
1134 }
1135 }
1137 _ => {
1138 match self.handle_additional_auth_async(&payload).await {
1140 Outcome::Ok(()) => continue,
1141 Outcome::Err(e) => return Outcome::Err(e),
1142 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1143 Outcome::Panicked(p) => return Outcome::Panicked(p),
1144 }
1145 }
1146 }
1147 }
1148 }
1149
1150 async fn handle_additional_auth_async(&mut self, data: &[u8]) -> Outcome<(), Error> {
1152 if data.is_empty() {
1153 return Outcome::Err(protocol_error("Empty additional auth data"));
1154 }
1155
1156 match data[0] {
1157 auth::caching_sha2::FAST_AUTH_SUCCESS => {
1158 Outcome::Ok(())
1160 }
1161 auth::caching_sha2::PERFORM_FULL_AUTH => {
1162 let Some(server_caps) = self.server_caps.as_ref() else {
1163 return Outcome::Err(protocol_error("Missing server capabilities during auth"));
1164 };
1165
1166 let pw = self.config.password_owned();
1167 let seed = server_caps.auth_data.clone();
1168 let server_version = server_caps.server_version.clone();
1169
1170 if self.is_secure_transport() {
1171 let mut clear = pw.as_bytes().to_vec();
1174 clear.push(0);
1175 if let Outcome::Err(e) = self.write_packet_async(&clear).await {
1176 return Outcome::Err(e);
1177 }
1178 Outcome::Ok(())
1179 } else {
1180 if let Outcome::Err(e) = self
1182 .write_packet_async(&[auth::caching_sha2::REQUEST_PUBLIC_KEY])
1183 .await
1184 {
1185 return Outcome::Err(e);
1186 }
1187
1188 let (payload, _) = match self.read_packet_async().await {
1189 Outcome::Ok(p) => p,
1190 Outcome::Err(e) => return Outcome::Err(e),
1191 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1192 Outcome::Panicked(p) => return Outcome::Panicked(p),
1193 };
1194 if payload.is_empty() {
1195 return Outcome::Err(protocol_error("Empty public key response"));
1196 }
1197
1198 let public_key = if payload[0] == 0x01 {
1200 &payload[1..]
1201 } else {
1202 &payload[..]
1203 };
1204
1205 let use_oaep = mysql_server_uses_oaep(&server_version);
1206 let encrypted =
1207 match auth::sha256_password_rsa(&pw, &seed, public_key, use_oaep) {
1208 Ok(v) => v,
1209 Err(e) => return Outcome::Err(auth_error(e)),
1210 };
1211
1212 if let Outcome::Err(e) = self.write_packet_async(&encrypted).await {
1213 return Outcome::Err(e);
1214 }
1215 Outcome::Ok(())
1216 }
1217 }
1218 _ => Outcome::Err(protocol_error(format!(
1219 "Unknown additional auth response: {:02X}",
1220 data[0]
1221 ))),
1222 }
1223 }
1224
1225 pub async fn query_async(
1227 &mut self,
1228 _cx: &Cx,
1229 sql: &str,
1230 params: &[Value],
1231 ) -> Outcome<Vec<Row>, Error> {
1232 let sql = interpolate_params(sql, params);
1233 if !self.is_ready() && self.state != ConnectionState::InTransaction {
1234 return Outcome::Err(connection_error("Connection not ready for queries"));
1235 }
1236
1237 self.state = ConnectionState::InQuery;
1238 self.sequence_id = 0;
1239
1240 let mut writer = PacketWriter::new();
1242 writer.write_u8(Command::Query as u8);
1243 writer.write_bytes(sql.as_bytes());
1244
1245 if let Outcome::Err(e) = self.write_packet_async(writer.as_bytes()).await {
1246 return Outcome::Err(e);
1247 }
1248
1249 let (payload, _) = match self.read_packet_async().await {
1251 Outcome::Ok(p) => p,
1252 Outcome::Err(e) => return Outcome::Err(e),
1253 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1254 Outcome::Panicked(p) => return Outcome::Panicked(p),
1255 };
1256
1257 if payload.is_empty() {
1258 self.state = ConnectionState::Ready;
1259 return Outcome::Err(protocol_error("Empty query response"));
1260 }
1261
1262 #[allow(clippy::cast_possible_truncation)] match PacketType::from_first_byte(payload[0], payload.len() as u32) {
1264 PacketType::Ok => {
1265 let mut reader = PacketReader::new(&payload);
1266 if let Some(ok) = reader.parse_ok_packet() {
1267 self.affected_rows = ok.affected_rows;
1268 self.last_insert_id = ok.last_insert_id;
1269 self.status_flags = ok.status_flags;
1270 self.warnings = ok.warnings;
1271 }
1272 self.state = if self.status_flags
1273 & crate::protocol::server_status::SERVER_STATUS_IN_TRANS
1274 != 0
1275 {
1276 ConnectionState::InTransaction
1277 } else {
1278 ConnectionState::Ready
1279 };
1280 Outcome::Ok(vec![])
1281 }
1282 PacketType::Error => {
1283 self.state = ConnectionState::Ready;
1284 let mut reader = PacketReader::new(&payload);
1285 let Some(err) = reader.parse_err_packet() else {
1286 return Outcome::Err(protocol_error("Invalid error packet"));
1287 };
1288 Outcome::Err(query_error(&err))
1289 }
1290 PacketType::LocalInfile => {
1291 self.state = ConnectionState::Ready;
1292 Outcome::Err(query_error_msg("LOCAL INFILE not supported"))
1293 }
1294 _ => self.read_result_set_async(&payload).await,
1295 }
1296 }
1297
1298 async fn read_result_set_async(&mut self, first_packet: &[u8]) -> Outcome<Vec<Row>, Error> {
1300 let mut reader = PacketReader::new(first_packet);
1301 #[allow(clippy::cast_possible_truncation)] let Some(column_count) = reader.read_lenenc_int().map(|c| c as usize) else {
1303 return Outcome::Err(protocol_error("Invalid column count"));
1304 };
1305
1306 let mut columns = Vec::with_capacity(column_count);
1308 for _ in 0..column_count {
1309 let (payload, _) = match self.read_packet_async().await {
1310 Outcome::Ok(p) => p,
1311 Outcome::Err(e) => return Outcome::Err(e),
1312 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1313 Outcome::Panicked(p) => return Outcome::Panicked(p),
1314 };
1315 match self.parse_column_def(&payload) {
1316 Ok(col) => columns.push(col),
1317 Err(e) => return Outcome::Err(e),
1318 }
1319 }
1320
1321 let server_caps = self.server_caps.as_ref().map_or(0, |c| c.capabilities);
1323 if server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
1324 let (payload, _) = match self.read_packet_async().await {
1325 Outcome::Ok(p) => p,
1326 Outcome::Err(e) => return Outcome::Err(e),
1327 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1328 Outcome::Panicked(p) => return Outcome::Panicked(p),
1329 };
1330 if payload.first() == Some(&0xFE) {
1331 }
1333 }
1334
1335 let mut rows = Vec::new();
1337 loop {
1338 let (payload, _) = match self.read_packet_async().await {
1339 Outcome::Ok(p) => p,
1340 Outcome::Err(e) => return Outcome::Err(e),
1341 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1342 Outcome::Panicked(p) => return Outcome::Panicked(p),
1343 };
1344
1345 if payload.is_empty() {
1346 break;
1347 }
1348
1349 #[allow(clippy::cast_possible_truncation)] match PacketType::from_first_byte(payload[0], payload.len() as u32) {
1351 PacketType::Eof | PacketType::Ok => {
1352 let mut reader = PacketReader::new(&payload);
1353 if payload[0] == 0x00 {
1354 if let Some(ok) = reader.parse_ok_packet() {
1355 self.status_flags = ok.status_flags;
1356 self.warnings = ok.warnings;
1357 }
1358 } else if payload[0] == 0xFE {
1359 if let Some(eof) = reader.parse_eof_packet() {
1360 self.status_flags = eof.status_flags;
1361 self.warnings = eof.warnings;
1362 }
1363 }
1364 break;
1365 }
1366 PacketType::Error => {
1367 let mut reader = PacketReader::new(&payload);
1368 let Some(err) = reader.parse_err_packet() else {
1369 return Outcome::Err(protocol_error("Invalid error packet"));
1370 };
1371 self.state = ConnectionState::Ready;
1372 return Outcome::Err(query_error(&err));
1373 }
1374 _ => {
1375 let row = self.parse_text_row(&payload, &columns);
1376 rows.push(row);
1377 }
1378 }
1379 }
1380
1381 self.state =
1382 if self.status_flags & crate::protocol::server_status::SERVER_STATUS_IN_TRANS != 0 {
1383 ConnectionState::InTransaction
1384 } else {
1385 ConnectionState::Ready
1386 };
1387
1388 Outcome::Ok(rows)
1389 }
1390
1391 fn parse_column_def(&self, data: &[u8]) -> Result<ColumnDef, Error> {
1393 let mut reader = PacketReader::new(data);
1394
1395 let catalog = reader
1396 .read_lenenc_string()
1397 .ok_or_else(|| protocol_error("Missing catalog"))?;
1398 let schema = reader
1399 .read_lenenc_string()
1400 .ok_or_else(|| protocol_error("Missing schema"))?;
1401 let table = reader
1402 .read_lenenc_string()
1403 .ok_or_else(|| protocol_error("Missing table"))?;
1404 let org_table = reader
1405 .read_lenenc_string()
1406 .ok_or_else(|| protocol_error("Missing org_table"))?;
1407 let name = reader
1408 .read_lenenc_string()
1409 .ok_or_else(|| protocol_error("Missing name"))?;
1410 let org_name = reader
1411 .read_lenenc_string()
1412 .ok_or_else(|| protocol_error("Missing org_name"))?;
1413
1414 let _fixed_len = reader.read_lenenc_int();
1415
1416 let charset_val = reader
1417 .read_u16_le()
1418 .ok_or_else(|| protocol_error("Missing charset"))?;
1419 let column_length = reader
1420 .read_u32_le()
1421 .ok_or_else(|| protocol_error("Missing column_length"))?;
1422 let column_type = FieldType::from_u8(
1423 reader
1424 .read_u8()
1425 .ok_or_else(|| protocol_error("Missing column_type"))?,
1426 );
1427 let flags = reader
1428 .read_u16_le()
1429 .ok_or_else(|| protocol_error("Missing flags"))?;
1430 let decimals = reader
1431 .read_u8()
1432 .ok_or_else(|| protocol_error("Missing decimals"))?;
1433
1434 Ok(ColumnDef {
1435 catalog,
1436 schema,
1437 table,
1438 org_table,
1439 name,
1440 org_name,
1441 charset: charset_val,
1442 column_length,
1443 column_type,
1444 flags,
1445 decimals,
1446 })
1447 }
1448
1449 fn parse_text_row(&self, data: &[u8], columns: &[ColumnDef]) -> Row {
1451 let mut reader = PacketReader::new(data);
1452 let mut values = Vec::with_capacity(columns.len());
1453
1454 for col in columns {
1455 if reader.peek() == Some(0xFB) {
1456 reader.skip(1);
1457 values.push(Value::Null);
1458 } else if let Some(data) = reader.read_lenenc_bytes() {
1459 let is_unsigned = col.is_unsigned();
1460 let value = decode_text_value(col.column_type, &data, is_unsigned);
1461 values.push(value);
1462 } else {
1463 values.push(Value::Null);
1464 }
1465 }
1466
1467 let column_names: Vec<String> = columns.iter().map(|c| c.name.clone()).collect();
1468 Row::new(column_names, values)
1469 }
1470
1471 pub async fn execute_async(
1476 &mut self,
1477 cx: &Cx,
1478 sql: &str,
1479 params: &[Value],
1480 ) -> Outcome<u64, Error> {
1481 match self.query_async(cx, sql, params).await {
1483 Outcome::Ok(_) => Outcome::Ok(self.affected_rows),
1484 Outcome::Err(e) => Outcome::Err(e),
1485 Outcome::Cancelled(c) => Outcome::Cancelled(c),
1486 Outcome::Panicked(p) => Outcome::Panicked(p),
1487 }
1488 }
1489
1490 pub async fn prepare_async(
1495 &mut self,
1496 _cx: &Cx,
1497 sql: &str,
1498 ) -> Outcome<PreparedStatement, Error> {
1499 if !self.is_ready() && self.state != ConnectionState::InTransaction {
1500 return Outcome::Err(connection_error("Connection not ready for prepare"));
1501 }
1502
1503 self.sequence_id = 0;
1504
1505 let packet = prepared::build_stmt_prepare_packet(sql, self.sequence_id);
1507 if let Outcome::Err(e) = self.write_packet_raw_async(&packet).await {
1508 return Outcome::Err(e);
1509 }
1510
1511 let (payload, _) = match self.read_packet_async().await {
1513 Outcome::Ok(p) => p,
1514 Outcome::Err(e) => return Outcome::Err(e),
1515 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1516 Outcome::Panicked(p) => return Outcome::Panicked(p),
1517 };
1518
1519 if payload.first() == Some(&0xFF) {
1521 let mut reader = PacketReader::new(&payload);
1522 let Some(err) = reader.parse_err_packet() else {
1523 return Outcome::Err(protocol_error("Invalid error packet"));
1524 };
1525 return Outcome::Err(query_error(&err));
1526 }
1527
1528 let Some(prep_ok) = prepared::parse_stmt_prepare_ok(&payload) else {
1530 return Outcome::Err(protocol_error("Invalid prepare OK response"));
1531 };
1532
1533 let mut param_defs = Vec::with_capacity(prep_ok.num_params as usize);
1535 for _ in 0..prep_ok.num_params {
1536 let (payload, _) = match self.read_packet_async().await {
1537 Outcome::Ok(p) => p,
1538 Outcome::Err(e) => return Outcome::Err(e),
1539 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1540 Outcome::Panicked(p) => return Outcome::Panicked(p),
1541 };
1542 match self.parse_column_def(&payload) {
1543 Ok(col) => param_defs.push(col),
1544 Err(e) => return Outcome::Err(e),
1545 }
1546 }
1547
1548 let server_caps = self.server_caps.as_ref().map_or(0, |c| c.capabilities);
1550 if prep_ok.num_params > 0 && server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
1551 let (payload, _) = match self.read_packet_async().await {
1552 Outcome::Ok(p) => p,
1553 Outcome::Err(e) => return Outcome::Err(e),
1554 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1555 Outcome::Panicked(p) => return Outcome::Panicked(p),
1556 };
1557 if payload.first() != Some(&0xFE) {
1558 return Outcome::Err(protocol_error("Expected EOF after param definitions"));
1559 }
1560 }
1561
1562 let mut column_defs = Vec::with_capacity(prep_ok.num_columns as usize);
1564 for _ in 0..prep_ok.num_columns {
1565 let (payload, _) = match self.read_packet_async().await {
1566 Outcome::Ok(p) => p,
1567 Outcome::Err(e) => return Outcome::Err(e),
1568 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1569 Outcome::Panicked(p) => return Outcome::Panicked(p),
1570 };
1571 match self.parse_column_def(&payload) {
1572 Ok(col) => column_defs.push(col),
1573 Err(e) => return Outcome::Err(e),
1574 }
1575 }
1576
1577 if prep_ok.num_columns > 0 && server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
1579 let (payload, _) = match self.read_packet_async().await {
1580 Outcome::Ok(p) => p,
1581 Outcome::Err(e) => return Outcome::Err(e),
1582 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1583 Outcome::Panicked(p) => return Outcome::Panicked(p),
1584 };
1585 if payload.first() != Some(&0xFE) {
1586 return Outcome::Err(protocol_error("Expected EOF after column definitions"));
1587 }
1588 }
1589
1590 let meta = PreparedStmtMeta {
1592 statement_id: prep_ok.statement_id,
1593 params: param_defs,
1594 columns: column_defs.clone(),
1595 };
1596 self.prepared_stmts.insert(prep_ok.statement_id, meta);
1597
1598 let column_names: Vec<String> = column_defs.iter().map(|c| c.name.clone()).collect();
1600 Outcome::Ok(PreparedStatement::with_columns(
1601 u64::from(prep_ok.statement_id),
1602 sql.to_string(),
1603 prep_ok.num_params as usize,
1604 column_names,
1605 ))
1606 }
1607
1608 pub async fn query_prepared_async(
1610 &mut self,
1611 _cx: &Cx,
1612 stmt: &PreparedStatement,
1613 params: &[Value],
1614 ) -> Outcome<Vec<Row>, Error> {
1615 #[allow(clippy::cast_possible_truncation)] let stmt_id = stmt.id() as u32;
1617
1618 let Some(meta) = self.prepared_stmts.get(&stmt_id).cloned() else {
1620 return Outcome::Err(connection_error("Unknown prepared statement"));
1621 };
1622
1623 if params.len() != meta.params.len() {
1625 return Outcome::Err(connection_error(format!(
1626 "Expected {} parameters, got {}",
1627 meta.params.len(),
1628 params.len()
1629 )));
1630 }
1631
1632 if !self.is_ready() && self.state != ConnectionState::InTransaction {
1633 return Outcome::Err(connection_error("Connection not ready for query"));
1634 }
1635
1636 self.state = ConnectionState::InQuery;
1637 self.sequence_id = 0;
1638
1639 let param_types: Vec<FieldType> = meta.params.iter().map(|c| c.column_type).collect();
1641 let packet = prepared::build_stmt_execute_packet(
1642 stmt_id,
1643 params,
1644 Some(¶m_types),
1645 self.sequence_id,
1646 );
1647 if let Outcome::Err(e) = self.write_packet_raw_async(&packet).await {
1648 return Outcome::Err(e);
1649 }
1650
1651 let (payload, _) = match self.read_packet_async().await {
1653 Outcome::Ok(p) => p,
1654 Outcome::Err(e) => return Outcome::Err(e),
1655 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1656 Outcome::Panicked(p) => return Outcome::Panicked(p),
1657 };
1658
1659 if payload.is_empty() {
1660 self.state = ConnectionState::Ready;
1661 return Outcome::Err(protocol_error("Empty execute response"));
1662 }
1663
1664 #[allow(clippy::cast_possible_truncation)] match PacketType::from_first_byte(payload[0], payload.len() as u32) {
1666 PacketType::Ok => {
1667 let mut reader = PacketReader::new(&payload);
1669 if let Some(ok) = reader.parse_ok_packet() {
1670 self.affected_rows = ok.affected_rows;
1671 self.last_insert_id = ok.last_insert_id;
1672 self.status_flags = ok.status_flags;
1673 self.warnings = ok.warnings;
1674 }
1675 self.state = ConnectionState::Ready;
1676 Outcome::Ok(vec![])
1677 }
1678 PacketType::Error => {
1679 self.state = ConnectionState::Ready;
1680 let mut reader = PacketReader::new(&payload);
1681 let Some(err) = reader.parse_err_packet() else {
1682 return Outcome::Err(protocol_error("Invalid error packet"));
1683 };
1684 Outcome::Err(query_error(&err))
1685 }
1686 _ => {
1687 self.read_binary_result_set_async(&payload, &meta.columns)
1689 .await
1690 }
1691 }
1692 }
1693
1694 pub async fn execute_prepared_async(
1696 &mut self,
1697 cx: &Cx,
1698 stmt: &PreparedStatement,
1699 params: &[Value],
1700 ) -> Outcome<u64, Error> {
1701 match self.query_prepared_async(cx, stmt, params).await {
1702 Outcome::Ok(_) => Outcome::Ok(self.affected_rows),
1703 Outcome::Err(e) => Outcome::Err(e),
1704 Outcome::Cancelled(c) => Outcome::Cancelled(c),
1705 Outcome::Panicked(p) => Outcome::Panicked(p),
1706 }
1707 }
1708
1709 pub async fn close_prepared_async(&mut self, stmt: &PreparedStatement) {
1711 #[allow(clippy::cast_possible_truncation)] let stmt_id = stmt.id() as u32;
1713 self.prepared_stmts.remove(&stmt_id);
1714
1715 self.sequence_id = 0;
1716 let packet = prepared::build_stmt_close_packet(stmt_id, self.sequence_id);
1717 let _ = self.write_packet_raw_async(&packet).await;
1719 }
1720
1721 async fn read_binary_result_set_async(
1723 &mut self,
1724 first_packet: &[u8],
1725 columns: &[ColumnDef],
1726 ) -> Outcome<Vec<Row>, Error> {
1727 let mut reader = PacketReader::new(first_packet);
1729 #[allow(clippy::cast_possible_truncation)] let Some(column_count) = reader.read_lenenc_int().map(|c| c as usize) else {
1731 return Outcome::Err(protocol_error("Invalid column count"));
1732 };
1733
1734 let mut result_columns = Vec::with_capacity(column_count);
1737 for _ in 0..column_count {
1738 let (payload, _) = match self.read_packet_async().await {
1739 Outcome::Ok(p) => p,
1740 Outcome::Err(e) => return Outcome::Err(e),
1741 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1742 Outcome::Panicked(p) => return Outcome::Panicked(p),
1743 };
1744 match self.parse_column_def(&payload) {
1745 Ok(col) => result_columns.push(col),
1746 Err(e) => return Outcome::Err(e),
1747 }
1748 }
1749
1750 let cols = if result_columns.len() == columns.len() {
1752 &result_columns
1753 } else {
1754 columns
1755 };
1756
1757 let server_caps = self.server_caps.as_ref().map_or(0, |c| c.capabilities);
1759 if server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
1760 let (payload, _) = match self.read_packet_async().await {
1761 Outcome::Ok(p) => p,
1762 Outcome::Err(e) => return Outcome::Err(e),
1763 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1764 Outcome::Panicked(p) => return Outcome::Panicked(p),
1765 };
1766 if payload.first() == Some(&0xFE) {
1767 }
1769 }
1770
1771 let mut rows = Vec::new();
1773 loop {
1774 let (payload, _) = match self.read_packet_async().await {
1775 Outcome::Ok(p) => p,
1776 Outcome::Err(e) => return Outcome::Err(e),
1777 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1778 Outcome::Panicked(p) => return Outcome::Panicked(p),
1779 };
1780
1781 if payload.is_empty() {
1782 break;
1783 }
1784
1785 #[allow(clippy::cast_possible_truncation)] match PacketType::from_first_byte(payload[0], payload.len() as u32) {
1787 PacketType::Eof | PacketType::Ok => {
1788 let mut reader = PacketReader::new(&payload);
1789 if payload[0] == 0x00 {
1790 if let Some(ok) = reader.parse_ok_packet() {
1791 self.status_flags = ok.status_flags;
1792 self.warnings = ok.warnings;
1793 }
1794 } else if payload[0] == 0xFE {
1795 if let Some(eof) = reader.parse_eof_packet() {
1796 self.status_flags = eof.status_flags;
1797 self.warnings = eof.warnings;
1798 }
1799 }
1800 break;
1801 }
1802 PacketType::Error => {
1803 let mut reader = PacketReader::new(&payload);
1804 let Some(err) = reader.parse_err_packet() else {
1805 return Outcome::Err(protocol_error("Invalid error packet"));
1806 };
1807 self.state = ConnectionState::Ready;
1808 return Outcome::Err(query_error(&err));
1809 }
1810 _ => {
1811 let row = self.parse_binary_row(&payload, cols);
1812 rows.push(row);
1813 }
1814 }
1815 }
1816
1817 self.state =
1818 if self.status_flags & crate::protocol::server_status::SERVER_STATUS_IN_TRANS != 0 {
1819 ConnectionState::InTransaction
1820 } else {
1821 ConnectionState::Ready
1822 };
1823
1824 Outcome::Ok(rows)
1825 }
1826
1827 fn parse_binary_row(&self, data: &[u8], columns: &[ColumnDef]) -> Row {
1829 let mut values = Vec::with_capacity(columns.len());
1835 let mut column_names = Vec::with_capacity(columns.len());
1836
1837 if data.is_empty() {
1838 return Row::new(column_names, values);
1839 }
1840
1841 let mut pos = 1;
1843
1844 let null_bitmap_len = (columns.len() + 7 + 2) / 8;
1847 if pos + null_bitmap_len > data.len() {
1848 return Row::new(column_names, values);
1849 }
1850 let null_bitmap = &data[pos..pos + null_bitmap_len];
1851 pos += null_bitmap_len;
1852
1853 for (i, col) in columns.iter().enumerate() {
1855 column_names.push(col.name.clone());
1856
1857 let bit_pos = i + 2;
1859 let is_null = (null_bitmap[bit_pos / 8] & (1 << (bit_pos % 8))) != 0;
1860
1861 if is_null {
1862 values.push(Value::Null);
1863 } else {
1864 let is_unsigned = col.flags & 0x20 != 0; let (value, consumed) =
1866 decode_binary_value_with_len(&data[pos..], col.column_type, is_unsigned);
1867 values.push(value);
1868 pos += consumed;
1869 }
1870 }
1871
1872 Row::new(column_names, values)
1873 }
1874
1875 async fn write_packet_raw_async(&mut self, packet: &[u8]) -> Outcome<(), Error> {
1877 let Some(stream) = self.stream.as_mut() else {
1878 return Outcome::Err(connection_error("Connection stream missing"));
1879 };
1880 match stream {
1881 ConnectionStream::Async(stream) => {
1882 let mut written = 0;
1883 while written < packet.len() {
1884 match std::future::poll_fn(|cx| {
1885 std::pin::Pin::new(&mut *stream).poll_write(cx, &packet[written..])
1886 })
1887 .await
1888 {
1889 Ok(n) => written += n,
1890 Err(e) => {
1891 return Outcome::Err(Error::Connection(ConnectionError {
1892 kind: ConnectionErrorKind::Disconnected,
1893 message: format!("Failed to write packet: {}", e),
1894 source: Some(Box::new(e)),
1895 }));
1896 }
1897 }
1898 }
1899 if let Err(e) =
1901 std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_flush(cx)).await
1902 {
1903 return Outcome::Err(Error::Connection(ConnectionError {
1904 kind: ConnectionErrorKind::Disconnected,
1905 message: format!("Failed to flush: {}", e),
1906 source: Some(Box::new(e)),
1907 }));
1908 }
1909 Outcome::Ok(())
1910 }
1911 ConnectionStream::Sync(stream) => {
1912 if let Err(e) = stream.write_all(packet) {
1913 return Outcome::Err(Error::Connection(ConnectionError {
1914 kind: ConnectionErrorKind::Disconnected,
1915 message: format!("Failed to write packet: {}", e),
1916 source: Some(Box::new(e)),
1917 }));
1918 }
1919 if let Err(e) = stream.flush() {
1920 return Outcome::Err(Error::Connection(ConnectionError {
1921 kind: ConnectionErrorKind::Disconnected,
1922 message: format!("Failed to flush: {}", e),
1923 source: Some(Box::new(e)),
1924 }));
1925 }
1926 Outcome::Ok(())
1927 }
1928 #[cfg(feature = "tls")]
1929 ConnectionStream::Tls(_) => Outcome::Err(connection_error(
1930 "write_packet_raw_async called after TLS upgrade (bug)",
1931 )),
1932 }
1933 }
1934
1935 pub async fn ping_async(&mut self, _cx: &Cx) -> Outcome<(), Error> {
1937 self.sequence_id = 0;
1938
1939 let mut writer = PacketWriter::new();
1940 writer.write_u8(Command::Ping as u8);
1941
1942 if let Outcome::Err(e) = self.write_packet_async(writer.as_bytes()).await {
1943 return Outcome::Err(e);
1944 }
1945
1946 let (payload, _) = match self.read_packet_async().await {
1947 Outcome::Ok(p) => p,
1948 Outcome::Err(e) => return Outcome::Err(e),
1949 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1950 Outcome::Panicked(p) => return Outcome::Panicked(p),
1951 };
1952
1953 if payload.first() == Some(&0x00) {
1954 Outcome::Ok(())
1955 } else {
1956 Outcome::Err(connection_error("Ping failed"))
1957 }
1958 }
1959
1960 pub async fn close_async(mut self, _cx: &Cx) -> Result<(), Error> {
1962 if self.state == ConnectionState::Closed {
1963 return Ok(());
1964 }
1965
1966 self.sequence_id = 0;
1967
1968 let mut writer = PacketWriter::new();
1969 writer.write_u8(Command::Quit as u8);
1970
1971 let _ = self.write_packet_async(writer.as_bytes()).await;
1973
1974 self.state = ConnectionState::Closed;
1975 Ok(())
1976 }
1977}
1978
#[cfg(feature = "console")]
impl ConsoleAware for MySqlAsyncConnection {
    /// Attach a console, or detach the current one by passing `None`.
    fn set_console(&mut self, console: Option<Arc<SqlModelConsole>>) {
        self.console = console;
    }

    /// Borrow the attached console, if any.
    fn console(&self) -> Option<&Arc<SqlModelConsole>> {
        Option::as_ref(&self.console)
    }
}
1991
1992fn protocol_error(msg: impl Into<String>) -> Error {
1995 Error::Protocol(ProtocolError {
1996 message: msg.into(),
1997 raw_data: None,
1998 source: None,
1999 })
2000}
2001
2002fn auth_error(msg: impl Into<String>) -> Error {
2003 Error::Connection(ConnectionError {
2004 kind: ConnectionErrorKind::Authentication,
2005 message: msg.into(),
2006 source: None,
2007 })
2008}
2009
2010fn connection_error(msg: impl Into<String>) -> Error {
2011 Error::Connection(ConnectionError {
2012 kind: ConnectionErrorKind::Connect,
2013 message: msg.into(),
2014 source: None,
2015 })
2016}
2017
/// Decide whether RSA password encryption should use OAEP padding for this server.
///
/// Only the leading run of ASCII digits and dots in `server_version` is considered,
/// so suffixes such as `-log` are ignored. Empty components are skipped, and a
/// missing or unparseable minor/patch component counts as 0.
///
/// Returns `true` for versions `>= 8.0.5`, and also when no major version can be
/// parsed.
fn mysql_server_uses_oaep(server_version: &str) -> bool {
    // Every character in the numeric prefix is ASCII, so this byte offset is a valid
    // slice boundary.
    let numeric_len = server_version
        .find(|c: char| !(c.is_ascii_digit() || c == '.'))
        .unwrap_or(server_version.len());
    let mut components = server_version[..numeric_len]
        .split('.')
        .filter(|part| !part.is_empty());
    let mut next_component = move || -> Option<u64> { components.next()?.parse().ok() };

    let Some(major) = next_component() else {
        // Unknown version: assume a modern server.
        return true;
    };
    let minor = next_component().unwrap_or(0);
    let patch = next_component().unwrap_or(0);

    (major, minor, patch) >= (8, 0, 5)
}
2035
2036fn query_error(err: &ErrPacket) -> Error {
2037 let kind = if err.is_duplicate_key() || err.is_foreign_key_violation() {
2038 QueryErrorKind::Constraint
2039 } else {
2040 QueryErrorKind::Syntax
2041 };
2042
2043 Error::Query(QueryError {
2044 kind,
2045 message: err.error_message.clone(),
2046 sqlstate: Some(err.sql_state.clone()),
2047 sql: None,
2048 detail: None,
2049 hint: None,
2050 position: None,
2051 source: None,
2052 })
2053}
2054
2055fn query_error_msg(msg: impl Into<String>) -> Error {
2056 Error::Query(QueryError {
2057 kind: QueryErrorKind::Syntax,
2058 message: msg.into(),
2059 sqlstate: None,
2060 sql: None,
2061 detail: None,
2062 hint: None,
2063 position: None,
2064 source: None,
2065 })
2066}
2067
2068fn validate_savepoint_name(name: &str) -> Result<(), Error> {
2076 if name.is_empty() {
2077 return Err(query_error_msg("Savepoint name cannot be empty"));
2078 }
2079 if name.len() > 64 {
2080 return Err(query_error_msg(
2081 "Savepoint name exceeds maximum length of 64 characters",
2082 ));
2083 }
2084 let mut chars = name.chars();
2085 let Some(first) = chars.next() else {
2086 return Err(query_error_msg("Savepoint name cannot be empty"));
2088 };
2089 if !first.is_ascii_alphabetic() && first != '_' {
2090 return Err(query_error_msg(
2091 "Savepoint name must start with a letter or underscore",
2092 ));
2093 }
2094 for c in chars {
2095 if !c.is_ascii_alphanumeric() && c != '_' && c != '$' {
2096 return Err(query_error_msg(format!(
2097 "Savepoint name contains invalid character: '{}'",
2098 c
2099 )));
2100 }
2101 }
2102 Ok(())
2103}
2104
2105pub struct SharedMySqlConnection {
2122 inner: Arc<Mutex<MySqlAsyncConnection>>,
2123}
2124
2125impl SharedMySqlConnection {
2126 pub fn new(conn: MySqlAsyncConnection) -> Self {
2128 Self {
2129 inner: Arc::new(Mutex::new(conn)),
2130 }
2131 }
2132
2133 pub async fn connect(cx: &Cx, config: MySqlConfig) -> Outcome<Self, Error> {
2135 match MySqlAsyncConnection::connect(cx, config).await {
2136 Outcome::Ok(conn) => Outcome::Ok(Self::new(conn)),
2137 Outcome::Err(e) => Outcome::Err(e),
2138 Outcome::Cancelled(c) => Outcome::Cancelled(c),
2139 Outcome::Panicked(p) => Outcome::Panicked(p),
2140 }
2141 }
2142
2143 pub fn inner(&self) -> &Arc<Mutex<MySqlAsyncConnection>> {
2145 &self.inner
2146 }
2147}
2148
2149impl Clone for SharedMySqlConnection {
2150 fn clone(&self) -> Self {
2151 Self {
2152 inner: Arc::clone(&self.inner),
2153 }
2154 }
2155}
2156
2157impl std::fmt::Debug for SharedMySqlConnection {
2158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2159 f.debug_struct("SharedMySqlConnection")
2160 .field("inner", &"Arc<Mutex<MySqlAsyncConnection>>")
2161 .finish()
2162 }
2163}
2164
2165pub struct SharedMySqlTransaction<'conn> {
2184 inner: Arc<Mutex<MySqlAsyncConnection>>,
2185 committed: bool,
2186 _marker: std::marker::PhantomData<&'conn ()>,
2187}
2188
2189impl SharedMySqlConnection {
2190 async fn begin_transaction_impl(
2192 &self,
2193 cx: &Cx,
2194 isolation: Option<IsolationLevel>,
2195 ) -> Outcome<SharedMySqlTransaction<'_>, Error> {
2196 let inner = Arc::clone(&self.inner);
2197
2198 let Ok(mut guard) = inner.lock(cx).await else {
2200 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2201 };
2202
2203 if let Some(level) = isolation {
2205 let isolation_sql = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_sql());
2206 match guard.execute_async(cx, &isolation_sql, &[]).await {
2207 Outcome::Ok(_) => {}
2208 Outcome::Err(e) => return Outcome::Err(e),
2209 Outcome::Cancelled(c) => return Outcome::Cancelled(c),
2210 Outcome::Panicked(p) => return Outcome::Panicked(p),
2211 }
2212 }
2213
2214 match guard.execute_async(cx, "BEGIN", &[]).await {
2216 Outcome::Ok(_) => {}
2217 Outcome::Err(e) => return Outcome::Err(e),
2218 Outcome::Cancelled(c) => return Outcome::Cancelled(c),
2219 Outcome::Panicked(p) => return Outcome::Panicked(p),
2220 }
2221
2222 drop(guard);
2223
2224 Outcome::Ok(SharedMySqlTransaction {
2225 inner,
2226 committed: false,
2227 _marker: std::marker::PhantomData,
2228 })
2229 }
2230}
2231
2232impl Connection for SharedMySqlConnection {
2233 type Tx<'conn>
2234 = SharedMySqlTransaction<'conn>
2235 where
2236 Self: 'conn;
2237
2238 fn dialect(&self) -> sqlmodel_core::Dialect {
2239 sqlmodel_core::Dialect::Mysql
2240 }
2241
2242 fn query(
2243 &self,
2244 cx: &Cx,
2245 sql: &str,
2246 params: &[Value],
2247 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
2248 let inner = Arc::clone(&self.inner);
2249 let sql = sql.to_string();
2250 let params = params.to_vec();
2251 async move {
2252 let Ok(mut guard) = inner.lock(cx).await else {
2253 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2254 };
2255 guard.query_async(cx, &sql, ¶ms).await
2256 }
2257 }
2258
2259 fn query_one(
2260 &self,
2261 cx: &Cx,
2262 sql: &str,
2263 params: &[Value],
2264 ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
2265 let inner = Arc::clone(&self.inner);
2266 let sql = sql.to_string();
2267 let params = params.to_vec();
2268 async move {
2269 let Ok(mut guard) = inner.lock(cx).await else {
2270 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2271 };
2272 let rows = match guard.query_async(cx, &sql, ¶ms).await {
2273 Outcome::Ok(r) => r,
2274 Outcome::Err(e) => return Outcome::Err(e),
2275 Outcome::Cancelled(c) => return Outcome::Cancelled(c),
2276 Outcome::Panicked(p) => return Outcome::Panicked(p),
2277 };
2278 Outcome::Ok(rows.into_iter().next())
2279 }
2280 }
2281
2282 fn execute(
2283 &self,
2284 cx: &Cx,
2285 sql: &str,
2286 params: &[Value],
2287 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
2288 let inner = Arc::clone(&self.inner);
2289 let sql = sql.to_string();
2290 let params = params.to_vec();
2291 async move {
2292 let Ok(mut guard) = inner.lock(cx).await else {
2293 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2294 };
2295 guard.execute_async(cx, &sql, ¶ms).await
2296 }
2297 }
2298
2299 fn insert(
2300 &self,
2301 cx: &Cx,
2302 sql: &str,
2303 params: &[Value],
2304 ) -> impl Future<Output = Outcome<i64, Error>> + Send {
2305 let inner = Arc::clone(&self.inner);
2306 let sql = sql.to_string();
2307 let params = params.to_vec();
2308 async move {
2309 let Ok(mut guard) = inner.lock(cx).await else {
2310 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2311 };
2312 match guard.execute_async(cx, &sql, ¶ms).await {
2313 Outcome::Ok(_) => Outcome::Ok(guard.last_insert_id() as i64),
2314 Outcome::Err(e) => Outcome::Err(e),
2315 Outcome::Cancelled(c) => Outcome::Cancelled(c),
2316 Outcome::Panicked(p) => Outcome::Panicked(p),
2317 }
2318 }
2319 }
2320
2321 fn batch(
2322 &self,
2323 cx: &Cx,
2324 statements: &[(String, Vec<Value>)],
2325 ) -> impl Future<Output = Outcome<Vec<u64>, Error>> + Send {
2326 let inner = Arc::clone(&self.inner);
2327 let statements = statements.to_vec();
2328 async move {
2329 let Ok(mut guard) = inner.lock(cx).await else {
2330 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2331 };
2332 let mut results = Vec::with_capacity(statements.len());
2333 for (sql, params) in &statements {
2334 match guard.execute_async(cx, sql, params).await {
2335 Outcome::Ok(n) => results.push(n),
2336 Outcome::Err(e) => return Outcome::Err(e),
2337 Outcome::Cancelled(c) => return Outcome::Cancelled(c),
2338 Outcome::Panicked(p) => return Outcome::Panicked(p),
2339 }
2340 }
2341 Outcome::Ok(results)
2342 }
2343 }
2344
2345 fn begin(&self, cx: &Cx) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
2346 self.begin_transaction_impl(cx, None)
2347 }
2348
2349 fn begin_with(
2350 &self,
2351 cx: &Cx,
2352 isolation: IsolationLevel,
2353 ) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
2354 self.begin_transaction_impl(cx, Some(isolation))
2355 }
2356
2357 fn prepare(
2358 &self,
2359 cx: &Cx,
2360 sql: &str,
2361 ) -> impl Future<Output = Outcome<PreparedStatement, Error>> + Send {
2362 let inner = Arc::clone(&self.inner);
2363 let sql = sql.to_string();
2364 async move {
2365 let Ok(mut guard) = inner.lock(cx).await else {
2366 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2367 };
2368 guard.prepare_async(cx, &sql).await
2369 }
2370 }
2371
2372 fn query_prepared(
2373 &self,
2374 cx: &Cx,
2375 stmt: &PreparedStatement,
2376 params: &[Value],
2377 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
2378 let inner = Arc::clone(&self.inner);
2379 let stmt = stmt.clone();
2380 let params = params.to_vec();
2381 async move {
2382 let Ok(mut guard) = inner.lock(cx).await else {
2383 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2384 };
2385 guard.query_prepared_async(cx, &stmt, ¶ms).await
2386 }
2387 }
2388
2389 fn execute_prepared(
2390 &self,
2391 cx: &Cx,
2392 stmt: &PreparedStatement,
2393 params: &[Value],
2394 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
2395 let inner = Arc::clone(&self.inner);
2396 let stmt = stmt.clone();
2397 let params = params.to_vec();
2398 async move {
2399 let Ok(mut guard) = inner.lock(cx).await else {
2400 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2401 };
2402 guard.execute_prepared_async(cx, &stmt, ¶ms).await
2403 }
2404 }
2405
2406 fn ping(&self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
2407 let inner = Arc::clone(&self.inner);
2408 async move {
2409 let Ok(mut guard) = inner.lock(cx).await else {
2410 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2411 };
2412 guard.ping_async(cx).await
2413 }
2414 }
2415
2416 fn close(self, cx: &Cx) -> impl Future<Output = Result<(), Error>> + Send {
2417 async move {
2418 match Arc::try_unwrap(self.inner) {
2420 Ok(mutex) => {
2421 let conn = mutex.into_inner();
2422 conn.close_async(cx).await
2423 }
2424 Err(_) => {
2425 Err(connection_error(
2427 "Cannot close: other references to connection exist",
2428 ))
2429 }
2430 }
2431 }
2432 }
2433}
2434
2435impl<'conn> TransactionOps for SharedMySqlTransaction<'conn> {
2436 fn query(
2437 &self,
2438 cx: &Cx,
2439 sql: &str,
2440 params: &[Value],
2441 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
2442 let inner = Arc::clone(&self.inner);
2443 let sql = sql.to_string();
2444 let params = params.to_vec();
2445 async move {
2446 let Ok(mut guard) = inner.lock(cx).await else {
2447 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2448 };
2449 guard.query_async(cx, &sql, ¶ms).await
2450 }
2451 }
2452
2453 fn query_one(
2454 &self,
2455 cx: &Cx,
2456 sql: &str,
2457 params: &[Value],
2458 ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
2459 let inner = Arc::clone(&self.inner);
2460 let sql = sql.to_string();
2461 let params = params.to_vec();
2462 async move {
2463 let Ok(mut guard) = inner.lock(cx).await else {
2464 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2465 };
2466 let rows = match guard.query_async(cx, &sql, ¶ms).await {
2467 Outcome::Ok(r) => r,
2468 Outcome::Err(e) => return Outcome::Err(e),
2469 Outcome::Cancelled(c) => return Outcome::Cancelled(c),
2470 Outcome::Panicked(p) => return Outcome::Panicked(p),
2471 };
2472 Outcome::Ok(rows.into_iter().next())
2473 }
2474 }
2475
2476 fn execute(
2477 &self,
2478 cx: &Cx,
2479 sql: &str,
2480 params: &[Value],
2481 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
2482 let inner = Arc::clone(&self.inner);
2483 let sql = sql.to_string();
2484 let params = params.to_vec();
2485 async move {
2486 let Ok(mut guard) = inner.lock(cx).await else {
2487 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2488 };
2489 guard.execute_async(cx, &sql, ¶ms).await
2490 }
2491 }
2492
2493 fn savepoint(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
2494 let inner = Arc::clone(&self.inner);
2495 let validation_result = validate_savepoint_name(name);
2497 let sql = format!("SAVEPOINT {}", name);
2498 async move {
2499 if let Err(e) = validation_result {
2501 return Outcome::Err(e);
2502 }
2503 let Ok(mut guard) = inner.lock(cx).await else {
2504 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2505 };
2506 match guard.execute_async(cx, &sql, &[]).await {
2507 Outcome::Ok(_) => Outcome::Ok(()),
2508 Outcome::Err(e) => Outcome::Err(e),
2509 Outcome::Cancelled(c) => Outcome::Cancelled(c),
2510 Outcome::Panicked(p) => Outcome::Panicked(p),
2511 }
2512 }
2513 }
2514
2515 fn rollback_to(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
2516 let inner = Arc::clone(&self.inner);
2517 let validation_result = validate_savepoint_name(name);
2519 let sql = format!("ROLLBACK TO SAVEPOINT {}", name);
2520 async move {
2521 if let Err(e) = validation_result {
2523 return Outcome::Err(e);
2524 }
2525 let Ok(mut guard) = inner.lock(cx).await else {
2526 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2527 };
2528 match guard.execute_async(cx, &sql, &[]).await {
2529 Outcome::Ok(_) => Outcome::Ok(()),
2530 Outcome::Err(e) => Outcome::Err(e),
2531 Outcome::Cancelled(c) => Outcome::Cancelled(c),
2532 Outcome::Panicked(p) => Outcome::Panicked(p),
2533 }
2534 }
2535 }
2536
2537 fn release(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
2538 let inner = Arc::clone(&self.inner);
2539 let validation_result = validate_savepoint_name(name);
2541 let sql = format!("RELEASE SAVEPOINT {}", name);
2542 async move {
2543 if let Err(e) = validation_result {
2545 return Outcome::Err(e);
2546 }
2547 let Ok(mut guard) = inner.lock(cx).await else {
2548 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2549 };
2550 match guard.execute_async(cx, &sql, &[]).await {
2551 Outcome::Ok(_) => Outcome::Ok(()),
2552 Outcome::Err(e) => Outcome::Err(e),
2553 Outcome::Cancelled(c) => Outcome::Cancelled(c),
2554 Outcome::Panicked(p) => Outcome::Panicked(p),
2555 }
2556 }
2557 }
2558
2559 #[allow(unused_assignments)]
2562 fn commit(mut self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
2563 async move {
2564 let Ok(mut guard) = self.inner.lock(cx).await else {
2565 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2566 };
2567 match guard.execute_async(cx, "COMMIT", &[]).await {
2568 Outcome::Ok(_) => {
2569 self.committed = true;
2570 Outcome::Ok(())
2571 }
2572 Outcome::Err(e) => Outcome::Err(e),
2573 Outcome::Cancelled(c) => Outcome::Cancelled(c),
2574 Outcome::Panicked(p) => Outcome::Panicked(p),
2575 }
2576 }
2577 }
2578
2579 fn rollback(self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
2580 async move {
2581 let Ok(mut guard) = self.inner.lock(cx).await else {
2582 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2583 };
2584 match guard.execute_async(cx, "ROLLBACK", &[]).await {
2585 Outcome::Ok(_) => Outcome::Ok(()),
2586 Outcome::Err(e) => Outcome::Err(e),
2587 Outcome::Cancelled(c) => Outcome::Cancelled(c),
2588 Outcome::Panicked(p) => Outcome::Panicked(p),
2589 }
2590 }
2591 }
2592}
2593
2594impl<'conn> Drop for SharedMySqlTransaction<'conn> {
2595 fn drop(&mut self) {
2596 if !self.committed {
2597 #[cfg(debug_assertions)]
2605 eprintln!(
2606 "WARNING: SharedMySqlTransaction dropped without commit/rollback. \
2607 The MySQL transaction may still be open."
2608 );
2609 }
2610 }
2611}
2612
/// Unit tests for helpers in this module that need no live MySQL server.
#[cfg(test)]
mod tests {
    use super::*;

    /// `ConnectionState` supports equality comparison.
    #[test]
    fn test_connection_state() {
        let state = ConnectionState::Disconnected;
        assert_eq!(state, ConnectionState::Disconnected);
    }

    /// Each error helper produces the expected `Error` variant.
    #[test]
    fn test_error_helpers() {
        assert!(matches!(protocol_error("test"), Error::Protocol(_)));
        assert!(matches!(auth_error("auth failed"), Error::Connection(_)));
        assert!(matches!(connection_error("conn failed"), Error::Connection(_)));
    }

    /// Identifiers made of letters, digits, `_` and `$` that do not start with
    /// a digit are accepted.
    #[test]
    fn test_validate_savepoint_name_valid() {
        let accepted = ["sp1", "_savepoint", "SavePoint_123", "sp$test", "a", "_"];
        for name in accepted {
            assert!(
                validate_savepoint_name(name).is_ok(),
                "expected {name:?} to be accepted"
            );
        }
    }

    /// Empty names, leading digits, punctuation/whitespace (including
    /// injection-looking input), and over-long names are rejected.
    #[test]
    fn test_validate_savepoint_name_invalid() {
        let long_name = "a".repeat(65);
        let rejected = [
            "",
            "1savepoint",
            "save-point",
            "save point",
            "save;drop table",
            "sp'--",
            long_name.as_str(),
        ];
        for name in rejected {
            assert!(
                validate_savepoint_name(name).is_err(),
                "expected {name:?} to be rejected"
            );
        }
    }
}