#![allow(clippy::manual_async_fn)]
#![allow(clippy::result_large_err)]
use std::collections::HashMap;
use std::future::Future;
use std::io::{self, Read as StdRead, Write as StdWrite};
use std::net::TcpStream as StdTcpStream;
use std::sync::Arc;
use asupersync::io::{AsyncRead, AsyncWrite, ReadBuf};
use asupersync::net::TcpStream;
use asupersync::sync::Mutex;
use asupersync::{Cx, Outcome};
use sqlmodel_core::connection::{Connection, IsolationLevel, PreparedStatement, TransactionOps};
use sqlmodel_core::error::{
ConnectionError, ConnectionErrorKind, ProtocolError, QueryError, QueryErrorKind,
};
use sqlmodel_core::{Error, Row, Value};
#[cfg(feature = "console")]
use sqlmodel_console::{ConsoleAware, SqlModelConsole};
use crate::auth;
use crate::config::MySqlConfig;
use crate::connection::{ConnectionState, ServerCapabilities};
use crate::protocol::{
Command, ErrPacket, MAX_PACKET_SIZE, PacketHeader, PacketReader, PacketType, PacketWriter,
capabilities, charset, prepared,
};
use crate::types::{
ColumnDef, FieldType, decode_binary_value_with_len, decode_text_value, interpolate_params,
};
pub struct MySqlAsyncConnection {
stream: Option<ConnectionStream>,
state: ConnectionState,
server_caps: Option<ServerCapabilities>,
connection_id: u32,
status_flags: u16,
affected_rows: u64,
last_insert_id: u64,
warnings: u16,
config: MySqlConfig,
sequence_id: u8,
prepared_stmts: HashMap<u32, PreparedStmtMeta>,
#[cfg(feature = "console")]
console: Option<Arc<SqlModelConsole>>,
}
#[derive(Debug, Clone)]
struct PreparedStmtMeta {
#[allow(dead_code)]
statement_id: u32,
params: Vec<ColumnDef>,
columns: Vec<ColumnDef>,
}
#[allow(dead_code)]
enum ConnectionStream {
Sync(StdTcpStream),
Async(TcpStream),
#[cfg(feature = "tls")]
Tls(AsyncTlsStream),
}
#[cfg(feature = "tls")]
struct AsyncTlsStream {
tcp: TcpStream,
tls: rustls::ClientConnection,
}
#[cfg(feature = "tls")]
impl std::fmt::Debug for AsyncTlsStream {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AsyncTlsStream")
.field("protocol_version", &self.tls.protocol_version())
.field("is_handshaking", &self.tls.is_handshaking())
.finish_non_exhaustive()
}
}
#[cfg(feature = "tls")]
impl AsyncTlsStream {
async fn handshake(
mut tcp: TcpStream,
tls_config: &crate::config::TlsConfig,
host: &str,
ssl_mode: crate::config::SslMode,
) -> Result<Self, Error> {
let config = crate::tls::build_client_config(tls_config, ssl_mode)?;
let sni = tls_config.server_name.as_deref().unwrap_or(host);
let server_name = sni
.to_string()
.try_into()
.map_err(|e| connection_error(format!("Invalid server name '{sni}': {e}")))?;
let mut tls = rustls::ClientConnection::new(std::sync::Arc::new(config), server_name)
.map_err(|e| connection_error(format!("Failed to create TLS connection: {e}")))?;
while tls.is_handshaking() {
while tls.wants_write() {
let mut out = Vec::new();
tls.write_tls(&mut out)
.map_err(|e| connection_error(format!("TLS handshake write_tls error: {e}")))?;
if !out.is_empty() {
write_all_async(&mut tcp, &out).await.map_err(|e| {
Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("TLS handshake write error: {e}"),
source: Some(Box::new(e)),
})
})?;
}
}
if tls.wants_read() {
let mut buf = [0u8; 8192];
let n = read_some_async(&mut tcp, &mut buf).await.map_err(|e| {
Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("TLS handshake read error: {e}"),
source: Some(Box::new(e)),
})
})?;
if n == 0 {
return Err(connection_error("Connection closed during TLS handshake"));
}
let mut cursor = std::io::Cursor::new(&buf[..n]);
tls.read_tls(&mut cursor)
.map_err(|e| connection_error(format!("TLS handshake read_tls error: {e}")))?;
tls.process_new_packets()
.map_err(|e| connection_error(format!("TLS handshake error: {e}")))?;
}
}
Ok(Self { tcp, tls })
}
async fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
let mut read = 0;
while read < buf.len() {
let n = self.read_plain(&mut buf[read..]).await?;
if n == 0 {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"connection closed",
));
}
read += n;
}
Ok(())
}
async fn read_plain(&mut self, out: &mut [u8]) -> io::Result<usize> {
loop {
match self.tls.reader().read(out) {
Ok(n) if n > 0 => return Ok(n),
Ok(_) => {}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
Err(e) => return Err(e),
}
if !self.tls.wants_read() {
return Ok(0);
}
let mut enc = [0u8; 8192];
let n = read_some_async(&mut self.tcp, &mut enc).await?;
if n == 0 {
return Ok(0);
}
let mut cursor = std::io::Cursor::new(&enc[..n]);
self.tls.read_tls(&mut cursor)?;
self.tls
.process_new_packets()
.map_err(|e| io::Error::other(format!("TLS error: {e}")))?;
}
}
async fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
let mut written = 0;
while written < buf.len() {
let n = self.tls.writer().write(&buf[written..])?;
if n == 0 {
return Err(io::Error::new(io::ErrorKind::WriteZero, "TLS write zero"));
}
written += n;
self.flush().await?;
}
Ok(())
}
async fn flush(&mut self) -> io::Result<()> {
self.tls.writer().flush()?;
while self.tls.wants_write() {
let mut out = Vec::new();
self.tls.write_tls(&mut out)?;
if !out.is_empty() {
write_all_async(&mut self.tcp, &out).await?;
}
}
flush_async(&mut self.tcp).await
}
}
#[cfg(feature = "tls")]
async fn read_some_async(stream: &mut TcpStream, buf: &mut [u8]) -> io::Result<usize> {
let mut read_buf = ReadBuf::new(buf);
std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf))
.await?;
Ok(read_buf.filled().len())
}
#[cfg(feature = "tls")]
async fn write_all_async(stream: &mut TcpStream, buf: &[u8]) -> io::Result<()> {
let mut written = 0;
while written < buf.len() {
let n = std::future::poll_fn(|cx| {
std::pin::Pin::new(&mut *stream).poll_write(cx, &buf[written..])
})
.await?;
if n == 0 {
return Err(io::Error::new(
io::ErrorKind::WriteZero,
"connection closed",
));
}
written += n;
}
Ok(())
}
#[cfg(feature = "tls")]
async fn flush_async(stream: &mut TcpStream) -> io::Result<()> {
std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_flush(cx)).await
}
impl std::fmt::Debug for MySqlAsyncConnection {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MySqlAsyncConnection")
.field("state", &self.state)
.field("connection_id", &self.connection_id)
.field("host", &self.config.host)
.field("port", &self.config.port)
.field("database", &self.config.database)
.finish_non_exhaustive()
}
}
impl MySqlAsyncConnection {
pub async fn connect(_cx: &Cx, config: MySqlConfig) -> Outcome<Self, Error> {
let addr = config.socket_addr();
let socket_addr = match addr.parse() {
Ok(a) => a,
Err(e) => {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Connect,
message: format!("Invalid socket address: {}", e),
source: None,
}));
}
};
let stream = match TcpStream::connect_timeout(socket_addr, config.connect_timeout).await {
Ok(s) => s,
Err(e) => {
let kind = if e.kind() == io::ErrorKind::ConnectionRefused {
ConnectionErrorKind::Refused
} else {
ConnectionErrorKind::Connect
};
return Outcome::Err(Error::Connection(ConnectionError {
kind,
message: format!("Failed to connect to {}: {}", addr, e),
source: Some(Box::new(e)),
}));
}
};
stream.set_nodelay(true).ok();
let mut conn = Self {
stream: Some(ConnectionStream::Async(stream)),
state: ConnectionState::Connecting,
server_caps: None,
connection_id: 0,
status_flags: 0,
affected_rows: 0,
last_insert_id: 0,
warnings: 0,
config,
sequence_id: 0,
prepared_stmts: HashMap::new(),
#[cfg(feature = "console")]
console: None,
};
match conn.read_handshake_async().await {
Outcome::Ok(server_caps) => {
conn.connection_id = server_caps.connection_id;
conn.server_caps = Some(server_caps);
conn.state = ConnectionState::Authenticating;
}
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(r) => return Outcome::Cancelled(r),
Outcome::Panicked(p) => return Outcome::Panicked(p),
}
if let Outcome::Err(e) = conn.send_handshake_response_async().await {
return Outcome::Err(e);
}
if let Outcome::Err(e) = conn.handle_auth_result_async().await {
return Outcome::Err(e);
}
conn.state = ConnectionState::Ready;
Outcome::Ok(conn)
}
pub fn state(&self) -> ConnectionState {
self.state
}
pub fn is_ready(&self) -> bool {
matches!(self.state, ConnectionState::Ready)
}
fn is_secure_transport(&self) -> bool {
#[cfg(feature = "tls")]
{
matches!(self.stream, Some(ConnectionStream::Tls(_)))
}
#[cfg(not(feature = "tls"))]
{
false
}
}
pub fn connection_id(&self) -> u32 {
self.connection_id
}
pub fn server_version(&self) -> Option<&str> {
self.server_caps
.as_ref()
.map(|caps| caps.server_version.as_str())
}
pub fn affected_rows(&self) -> u64 {
self.affected_rows
}
pub fn last_insert_id(&self) -> u64 {
self.last_insert_id
}
async fn read_packet_async(&mut self) -> Outcome<(Vec<u8>, u8), Error> {
let mut header_buf = [0u8; 4];
let Some(stream) = self.stream.as_mut() else {
return Outcome::Err(connection_error("Connection stream missing"));
};
match stream {
ConnectionStream::Async(stream) => {
let mut header_read = 0;
while header_read < 4 {
let mut read_buf = ReadBuf::new(&mut header_buf[header_read..]);
match std::future::poll_fn(|cx| {
std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf)
})
.await
{
Ok(()) => {
let n = read_buf.filled().len();
if n == 0 {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: "Connection closed while reading header".to_string(),
source: None,
}));
}
header_read += n;
}
Err(e) => {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to read packet header: {}", e),
source: Some(Box::new(e)),
}));
}
}
}
}
ConnectionStream::Sync(stream) => {
if let Err(e) = stream.read_exact(&mut header_buf) {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to read packet header: {}", e),
source: Some(Box::new(e)),
}));
}
}
#[cfg(feature = "tls")]
ConnectionStream::Tls(stream) => {
if let Err(e) = stream.read_exact(&mut header_buf).await {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to read packet header: {e}"),
source: Some(Box::new(e)),
}));
}
}
}
let header = PacketHeader::from_bytes(&header_buf);
let payload_len = header.payload_length as usize;
self.sequence_id = header.sequence_id.wrapping_add(1);
let mut payload = vec![0u8; payload_len];
if payload_len > 0 {
let Some(stream) = self.stream.as_mut() else {
return Outcome::Err(connection_error("Connection stream missing"));
};
match stream {
ConnectionStream::Async(stream) => {
let mut total_read = 0;
while total_read < payload_len {
let mut read_buf = ReadBuf::new(&mut payload[total_read..]);
match std::future::poll_fn(|cx| {
std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf)
})
.await
{
Ok(()) => {
let n = read_buf.filled().len();
if n == 0 {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: "Connection closed while reading payload"
.to_string(),
source: None,
}));
}
total_read += n;
}
Err(e) => {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to read packet payload: {}", e),
source: Some(Box::new(e)),
}));
}
}
}
}
ConnectionStream::Sync(stream) => {
if let Err(e) = stream.read_exact(&mut payload) {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to read packet payload: {}", e),
source: Some(Box::new(e)),
}));
}
}
#[cfg(feature = "tls")]
ConnectionStream::Tls(stream) => {
if let Err(e) = stream.read_exact(&mut payload).await {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to read packet payload: {e}"),
source: Some(Box::new(e)),
}));
}
}
}
}
if payload_len == MAX_PACKET_SIZE {
loop {
let mut header_buf = [0u8; 4];
let Some(stream) = self.stream.as_mut() else {
return Outcome::Err(connection_error("Connection stream missing"));
};
match stream {
ConnectionStream::Async(stream) => {
let mut header_read = 0;
while header_read < 4 {
let mut read_buf = ReadBuf::new(&mut header_buf[header_read..]);
match std::future::poll_fn(|cx| {
std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf)
})
.await
{
Ok(()) => {
let n = read_buf.filled().len();
if n == 0 {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: "Connection closed while reading continuation header".to_string(),
source: None,
}));
}
header_read += n;
}
Err(e) => {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!(
"Failed to read continuation header: {}",
e
),
source: Some(Box::new(e)),
}));
}
}
}
}
ConnectionStream::Sync(stream) => {
if let Err(e) = stream.read_exact(&mut header_buf) {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to read continuation header: {}", e),
source: Some(Box::new(e)),
}));
}
}
#[cfg(feature = "tls")]
ConnectionStream::Tls(stream) => {
if let Err(e) = stream.read_exact(&mut header_buf).await {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to read continuation header: {e}"),
source: Some(Box::new(e)),
}));
}
}
}
let cont_header = PacketHeader::from_bytes(&header_buf);
let cont_len = cont_header.payload_length as usize;
self.sequence_id = cont_header.sequence_id.wrapping_add(1);
if cont_len > 0 {
let mut cont_payload = vec![0u8; cont_len];
let Some(stream) = self.stream.as_mut() else {
return Outcome::Err(connection_error("Connection stream missing"));
};
match stream {
ConnectionStream::Async(stream) => {
let mut total_read = 0;
while total_read < cont_len {
let mut read_buf = ReadBuf::new(&mut cont_payload[total_read..]);
match std::future::poll_fn(|cx| {
std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf)
})
.await
{
Ok(()) => {
let n = read_buf.filled().len();
if n == 0 {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: "Connection closed while reading continuation payload".to_string(),
source: None,
}));
}
total_read += n;
}
Err(e) => {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!(
"Failed to read continuation payload: {}",
e
),
source: Some(Box::new(e)),
}));
}
}
}
}
ConnectionStream::Sync(stream) => {
if let Err(e) = stream.read_exact(&mut cont_payload) {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to read continuation payload: {}", e),
source: Some(Box::new(e)),
}));
}
}
#[cfg(feature = "tls")]
ConnectionStream::Tls(stream) => {
if let Err(e) = stream.read_exact(&mut cont_payload).await {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to read continuation payload: {e}"),
source: Some(Box::new(e)),
}));
}
}
}
payload.extend_from_slice(&cont_payload);
}
if cont_len < MAX_PACKET_SIZE {
break;
}
}
}
Outcome::Ok((payload, header.sequence_id))
}
async fn write_packet_async(&mut self, payload: &[u8]) -> Outcome<(), Error> {
let writer = PacketWriter::new();
let packet = writer.build_packet_from_payload(payload, self.sequence_id);
self.sequence_id = self.sequence_id.wrapping_add(1);
let Some(stream) = self.stream.as_mut() else {
return Outcome::Err(connection_error("Connection stream missing"));
};
match stream {
ConnectionStream::Async(stream) => {
let mut written = 0;
while written < packet.len() {
match std::future::poll_fn(|cx| {
std::pin::Pin::new(&mut *stream).poll_write(cx, &packet[written..])
})
.await
{
Ok(n) => {
if n == 0 {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: "Connection closed while writing packet".to_string(),
source: None,
}));
}
written += n;
}
Err(e) => {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to write packet: {}", e),
source: Some(Box::new(e)),
}));
}
}
}
match std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_flush(cx))
.await
{
Ok(()) => {}
Err(e) => {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to flush stream: {}", e),
source: Some(Box::new(e)),
}));
}
}
}
ConnectionStream::Sync(stream) => {
if let Err(e) = stream.write_all(&packet) {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to write packet: {}", e),
source: Some(Box::new(e)),
}));
}
if let Err(e) = stream.flush() {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to flush stream: {}", e),
source: Some(Box::new(e)),
}));
}
}
#[cfg(feature = "tls")]
ConnectionStream::Tls(stream) => {
if let Err(e) = stream.write_all(&packet).await {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to write packet: {e}"),
source: Some(Box::new(e)),
}));
}
if let Err(e) = stream.flush().await {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to flush stream: {e}"),
source: Some(Box::new(e)),
}));
}
}
}
Outcome::Ok(())
}
async fn read_handshake_async(&mut self) -> Outcome<ServerCapabilities, Error> {
let (payload, _) = match self.read_packet_async().await {
Outcome::Ok(p) => p,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(r) => return Outcome::Cancelled(r),
Outcome::Panicked(p) => return Outcome::Panicked(p),
};
let mut reader = PacketReader::new(&payload);
let Some(protocol_version) = reader.read_u8() else {
return Outcome::Err(protocol_error("Missing protocol version"));
};
if protocol_version != 10 {
return Outcome::Err(protocol_error(format!(
"Unsupported protocol version: {}",
protocol_version
)));
}
let Some(server_version) = reader.read_null_string() else {
return Outcome::Err(protocol_error("Missing server version"));
};
let Some(connection_id) = reader.read_u32_le() else {
return Outcome::Err(protocol_error("Missing connection ID"));
};
let Some(auth_data_1) = reader.read_bytes(8) else {
return Outcome::Err(protocol_error("Missing auth data"));
};
reader.skip(1);
let Some(caps_lower) = reader.read_u16_le() else {
return Outcome::Err(protocol_error("Missing capability flags"));
};
let charset_val = reader.read_u8().unwrap_or(charset::UTF8MB4_0900_AI_CI);
let status_flags = reader.read_u16_le().unwrap_or(0);
let caps_upper = reader.read_u16_le().unwrap_or(0);
let capabilities_val = u32::from(caps_lower) | (u32::from(caps_upper) << 16);
let auth_data_len = if capabilities_val & capabilities::CLIENT_PLUGIN_AUTH != 0 {
reader.read_u8().unwrap_or(0) as usize
} else {
0
};
reader.skip(10);
let mut auth_data = auth_data_1.to_vec();
if capabilities_val & capabilities::CLIENT_SECURE_CONNECTION != 0 {
let len2 = if auth_data_len > 8 {
auth_data_len - 8
} else {
13 };
if let Some(data2) = reader.read_bytes(len2) {
let data2_clean = if data2.last() == Some(&0) {
&data2[..data2.len() - 1]
} else {
data2
};
auth_data.extend_from_slice(data2_clean);
}
}
let auth_plugin = if capabilities_val & capabilities::CLIENT_PLUGIN_AUTH != 0 {
reader.read_null_string().unwrap_or_default()
} else {
auth::plugins::MYSQL_NATIVE_PASSWORD.to_string()
};
Outcome::Ok(ServerCapabilities {
capabilities: capabilities_val,
protocol_version,
server_version,
connection_id,
auth_plugin,
auth_data,
charset: charset_val,
status_flags,
})
}
async fn send_handshake_response_async(&mut self) -> Outcome<(), Error> {
let Some(server_caps) = self.server_caps.as_ref() else {
return Outcome::Err(protocol_error("No server handshake received"));
};
let server_caps_bits = server_caps.capabilities;
let auth_plugin = server_caps.auth_plugin.clone();
let auth_data = server_caps.auth_data.clone();
let mut client_caps = self.config.capability_flags() & server_caps_bits;
#[cfg(feature = "tls")]
if let Outcome::Err(e) = self
.maybe_upgrade_tls_async(server_caps_bits, &mut client_caps)
.await
{
return Outcome::Err(e);
}
#[cfg(not(feature = "tls"))]
if let Outcome::Err(e) = self.maybe_upgrade_tls(server_caps_bits, &mut client_caps) {
return Outcome::Err(e);
}
let auth_response = self.compute_auth_response(&auth_plugin, &auth_data);
let mut writer = PacketWriter::new();
writer.write_u32_le(client_caps);
writer.write_u32_le(self.config.max_packet_size);
writer.write_u8(self.config.charset);
writer.write_zeros(23);
writer.write_null_string(&self.config.user);
if client_caps & capabilities::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA != 0 {
writer.write_lenenc_bytes(&auth_response);
} else if client_caps & capabilities::CLIENT_SECURE_CONNECTION != 0 {
#[allow(clippy::cast_possible_truncation)]
writer.write_u8(auth_response.len() as u8);
writer.write_bytes(&auth_response);
} else {
writer.write_bytes(&auth_response);
writer.write_u8(0); }
if client_caps & capabilities::CLIENT_CONNECT_WITH_DB != 0 {
if let Some(ref db) = self.config.database {
writer.write_null_string(db);
} else {
writer.write_u8(0); }
}
if client_caps & capabilities::CLIENT_PLUGIN_AUTH != 0 {
writer.write_null_string(&auth_plugin);
}
if client_caps & capabilities::CLIENT_CONNECT_ATTRS != 0
&& !self.config.attributes.is_empty()
{
let mut attrs_writer = PacketWriter::new();
for (key, value) in &self.config.attributes {
attrs_writer.write_lenenc_string(key);
attrs_writer.write_lenenc_string(value);
}
let attrs_data = attrs_writer.into_bytes();
writer.write_lenenc_bytes(&attrs_data);
}
self.write_packet_async(writer.as_bytes()).await
}
#[cfg(feature = "tls")]
async fn maybe_upgrade_tls_async(
&mut self,
server_caps: u32,
client_caps: &mut u32,
) -> Outcome<(), Error> {
let ssl_mode = self.config.ssl_mode;
if !ssl_mode.should_try_ssl() {
*client_caps &= !capabilities::CLIENT_SSL;
return Outcome::Ok(());
}
let use_tls = match crate::tls::validate_ssl_mode(ssl_mode, server_caps) {
Ok(v) => v,
Err(e) => return Outcome::Err(e),
};
if !use_tls {
*client_caps &= !capabilities::CLIENT_SSL;
return Outcome::Ok(());
}
if let Err(e) = crate::tls::validate_tls_config(ssl_mode, &self.config.tls_config) {
return Outcome::Err(e);
}
let packet = crate::tls::build_ssl_request_packet(
*client_caps,
self.config.max_packet_size,
self.config.charset,
self.sequence_id,
);
if let Outcome::Err(e) = self.write_packet_raw_async(&packet).await {
return Outcome::Err(e);
}
self.sequence_id = self.sequence_id.wrapping_add(1);
let Some(stream) = self.stream.take() else {
return Outcome::Err(connection_error("Connection stream missing"));
};
let ConnectionStream::Async(tcp) = stream else {
return Outcome::Err(connection_error("TLS upgrade requires async TCP stream"));
};
let tls = match AsyncTlsStream::handshake(
tcp,
&self.config.tls_config,
&self.config.host,
ssl_mode,
)
.await
{
Ok(s) => s,
Err(e) => return Outcome::Err(e),
};
self.stream = Some(ConnectionStream::Tls(tls));
Outcome::Ok(())
}
#[cfg(not(feature = "tls"))]
fn maybe_upgrade_tls(&mut self, server_caps: u32, client_caps: &mut u32) -> Outcome<(), Error> {
let ssl_mode = self.config.ssl_mode;
if !ssl_mode.should_try_ssl() {
*client_caps &= !capabilities::CLIENT_SSL;
return Outcome::Ok(());
}
let use_tls = match crate::tls::validate_ssl_mode(ssl_mode, server_caps) {
Ok(v) => v,
Err(e) => return Outcome::Err(e),
};
if !use_tls {
*client_caps &= !capabilities::CLIENT_SSL;
return Outcome::Ok(());
}
if ssl_mode == crate::config::SslMode::Preferred {
*client_caps &= !capabilities::CLIENT_SSL;
Outcome::Ok(())
} else {
Outcome::Err(connection_error(
"TLS requested but 'sqlmodel-mysql' was built without feature 'tls'",
))
}
}
fn compute_auth_response(&self, plugin: &str, auth_data: &[u8]) -> Vec<u8> {
let pw = self.config.password_str();
match plugin {
auth::plugins::MYSQL_NATIVE_PASSWORD => {
auth::mysql_native_password(pw, auth_data)
}
auth::plugins::CACHING_SHA2_PASSWORD => {
auth::caching_sha2_password(pw, auth_data)
}
auth::plugins::MYSQL_CLEAR_PASSWORD => {
let mut result = pw.as_bytes().to_vec();
result.push(0);
result
}
_ => auth::mysql_native_password(pw, auth_data),
}
}
async fn handle_auth_result_async(&mut self) -> Outcome<(), Error> {
loop {
let (payload, _) = match self.read_packet_async().await {
Outcome::Ok(p) => p,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(r) => return Outcome::Cancelled(r),
Outcome::Panicked(p) => return Outcome::Panicked(p),
};
if payload.is_empty() {
return Outcome::Err(protocol_error("Empty authentication response"));
}
#[allow(clippy::cast_possible_truncation)] match PacketType::from_first_byte(payload[0], payload.len() as u32) {
PacketType::Ok => {
let mut reader = PacketReader::new(&payload);
if let Some(ok) = reader.parse_ok_packet() {
self.status_flags = ok.status_flags;
self.affected_rows = ok.affected_rows;
}
return Outcome::Ok(());
}
PacketType::Error => {
let mut reader = PacketReader::new(&payload);
let Some(err) = reader.parse_err_packet() else {
return Outcome::Err(protocol_error("Invalid error packet"));
};
return Outcome::Err(auth_error(format!(
"Authentication failed: {} ({})",
err.error_message, err.error_code
)));
}
PacketType::Eof => {
let data = &payload[1..];
let mut reader = PacketReader::new(data);
let Some(plugin) = reader.read_null_string() else {
return Outcome::Err(protocol_error("Missing plugin name in auth switch"));
};
let auth_data = reader.read_rest();
let response = self.compute_auth_response(&plugin, auth_data);
if let Outcome::Err(e) = self.write_packet_async(&response).await {
return Outcome::Err(e);
}
}
_ => {
match self.handle_additional_auth_async(&payload).await {
Outcome::Ok(()) => continue,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(r) => return Outcome::Cancelled(r),
Outcome::Panicked(p) => return Outcome::Panicked(p),
}
}
}
}
}
async fn handle_additional_auth_async(&mut self, data: &[u8]) -> Outcome<(), Error> {
if data.is_empty() {
return Outcome::Err(protocol_error("Empty additional auth data"));
}
match data[0] {
auth::caching_sha2::FAST_AUTH_SUCCESS => {
Outcome::Ok(())
}
auth::caching_sha2::PERFORM_FULL_AUTH => {
let Some(server_caps) = self.server_caps.as_ref() else {
return Outcome::Err(protocol_error("Missing server capabilities during auth"));
};
let pw = self.config.password_owned();
let seed = server_caps.auth_data.clone();
let server_version = server_caps.server_version.clone();
if self.is_secure_transport() {
let mut clear = pw.as_bytes().to_vec();
clear.push(0);
if let Outcome::Err(e) = self.write_packet_async(&clear).await {
return Outcome::Err(e);
}
Outcome::Ok(())
} else {
if let Outcome::Err(e) = self
.write_packet_async(&[auth::caching_sha2::REQUEST_PUBLIC_KEY])
.await
{
return Outcome::Err(e);
}
let (payload, _) = match self.read_packet_async().await {
Outcome::Ok(p) => p,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(r) => return Outcome::Cancelled(r),
Outcome::Panicked(p) => return Outcome::Panicked(p),
};
if payload.is_empty() {
return Outcome::Err(protocol_error("Empty public key response"));
}
let public_key = if payload[0] == 0x01 {
&payload[1..]
} else {
&payload[..]
};
let use_oaep = mysql_server_uses_oaep(&server_version);
let encrypted =
match auth::sha256_password_rsa(&pw, &seed, public_key, use_oaep) {
Ok(v) => v,
Err(e) => return Outcome::Err(auth_error(e)),
};
if let Outcome::Err(e) = self.write_packet_async(&encrypted).await {
return Outcome::Err(e);
}
Outcome::Ok(())
}
}
_ => Outcome::Err(protocol_error(format!(
"Unknown additional auth response: {:02X}",
data[0]
))),
}
}
pub async fn query_async(
&mut self,
_cx: &Cx,
sql: &str,
params: &[Value],
) -> Outcome<Vec<Row>, Error> {
let sql = interpolate_params(sql, params);
if !self.is_ready() && self.state != ConnectionState::InTransaction {
return Outcome::Err(connection_error("Connection not ready for queries"));
}
self.state = ConnectionState::InQuery;
self.sequence_id = 0;
let mut writer = PacketWriter::new();
writer.write_u8(Command::Query as u8);
writer.write_bytes(sql.as_bytes());
if let Outcome::Err(e) = self.write_packet_async(writer.as_bytes()).await {
return Outcome::Err(e);
}
let (payload, _) = match self.read_packet_async().await {
Outcome::Ok(p) => p,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(r) => return Outcome::Cancelled(r),
Outcome::Panicked(p) => return Outcome::Panicked(p),
};
if payload.is_empty() {
self.state = ConnectionState::Ready;
return Outcome::Err(protocol_error("Empty query response"));
}
#[allow(clippy::cast_possible_truncation)] match PacketType::from_first_byte(payload[0], payload.len() as u32) {
PacketType::Ok => {
let mut reader = PacketReader::new(&payload);
if let Some(ok) = reader.parse_ok_packet() {
self.affected_rows = ok.affected_rows;
self.last_insert_id = ok.last_insert_id;
self.status_flags = ok.status_flags;
self.warnings = ok.warnings;
}
self.state = if self.status_flags
& crate::protocol::server_status::SERVER_STATUS_IN_TRANS
!= 0
{
ConnectionState::InTransaction
} else {
ConnectionState::Ready
};
Outcome::Ok(vec![])
}
PacketType::Error => {
self.state = ConnectionState::Ready;
let mut reader = PacketReader::new(&payload);
let Some(err) = reader.parse_err_packet() else {
return Outcome::Err(protocol_error("Invalid error packet"));
};
Outcome::Err(query_error(&err))
}
PacketType::LocalInfile => {
self.state = ConnectionState::Ready;
Outcome::Err(query_error_msg("LOCAL INFILE not supported"))
}
_ => self.read_result_set_async(&payload).await,
}
}
async fn read_result_set_async(&mut self, first_packet: &[u8]) -> Outcome<Vec<Row>, Error> {
let mut reader = PacketReader::new(first_packet);
#[allow(clippy::cast_possible_truncation)] let Some(column_count) = reader.read_lenenc_int().map(|c| c as usize) else {
return Outcome::Err(protocol_error("Invalid column count"));
};
let mut columns = Vec::with_capacity(column_count);
for _ in 0..column_count {
let (payload, _) = match self.read_packet_async().await {
Outcome::Ok(p) => p,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(r) => return Outcome::Cancelled(r),
Outcome::Panicked(p) => return Outcome::Panicked(p),
};
match self.parse_column_def(&payload) {
Ok(col) => columns.push(col),
Err(e) => return Outcome::Err(e),
}
}
let server_caps = self.server_caps.as_ref().map_or(0, |c| c.capabilities);
if server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
let (payload, _) = match self.read_packet_async().await {
Outcome::Ok(p) => p,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(r) => return Outcome::Cancelled(r),
Outcome::Panicked(p) => return Outcome::Panicked(p),
};
if payload.first() == Some(&0xFE) {
}
}
let mut rows = Vec::new();
loop {
let (payload, _) = match self.read_packet_async().await {
Outcome::Ok(p) => p,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(r) => return Outcome::Cancelled(r),
Outcome::Panicked(p) => return Outcome::Panicked(p),
};
if payload.is_empty() {
break;
}
#[allow(clippy::cast_possible_truncation)] match PacketType::from_first_byte(payload[0], payload.len() as u32) {
PacketType::Eof | PacketType::Ok => {
let mut reader = PacketReader::new(&payload);
if payload[0] == 0x00 {
if let Some(ok) = reader.parse_ok_packet() {
self.status_flags = ok.status_flags;
self.warnings = ok.warnings;
}
} else if payload[0] == 0xFE {
if let Some(eof) = reader.parse_eof_packet() {
self.status_flags = eof.status_flags;
self.warnings = eof.warnings;
}
}
break;
}
PacketType::Error => {
let mut reader = PacketReader::new(&payload);
let Some(err) = reader.parse_err_packet() else {
return Outcome::Err(protocol_error("Invalid error packet"));
};
self.state = ConnectionState::Ready;
return Outcome::Err(query_error(&err));
}
_ => {
let row = self.parse_text_row(&payload, &columns);
rows.push(row);
}
}
}
self.state =
if self.status_flags & crate::protocol::server_status::SERVER_STATUS_IN_TRANS != 0 {
ConnectionState::InTransaction
} else {
ConnectionState::Ready
};
Outcome::Ok(rows)
}
fn parse_column_def(&self, data: &[u8]) -> Result<ColumnDef, Error> {
let mut reader = PacketReader::new(data);
let catalog = reader
.read_lenenc_string()
.ok_or_else(|| protocol_error("Missing catalog"))?;
let schema = reader
.read_lenenc_string()
.ok_or_else(|| protocol_error("Missing schema"))?;
let table = reader
.read_lenenc_string()
.ok_or_else(|| protocol_error("Missing table"))?;
let org_table = reader
.read_lenenc_string()
.ok_or_else(|| protocol_error("Missing org_table"))?;
let name = reader
.read_lenenc_string()
.ok_or_else(|| protocol_error("Missing name"))?;
let org_name = reader
.read_lenenc_string()
.ok_or_else(|| protocol_error("Missing org_name"))?;
let _fixed_len = reader.read_lenenc_int();
let charset_val = reader
.read_u16_le()
.ok_or_else(|| protocol_error("Missing charset"))?;
let column_length = reader
.read_u32_le()
.ok_or_else(|| protocol_error("Missing column_length"))?;
let column_type = FieldType::from_u8(
reader
.read_u8()
.ok_or_else(|| protocol_error("Missing column_type"))?,
);
let flags = reader
.read_u16_le()
.ok_or_else(|| protocol_error("Missing flags"))?;
let decimals = reader
.read_u8()
.ok_or_else(|| protocol_error("Missing decimals"))?;
Ok(ColumnDef {
catalog,
schema,
table,
org_table,
name,
org_name,
charset: charset_val,
column_length,
column_type,
flags,
decimals,
})
}
fn parse_text_row(&self, data: &[u8], columns: &[ColumnDef]) -> Row {
let mut reader = PacketReader::new(data);
let mut values = Vec::with_capacity(columns.len());
for col in columns {
if reader.peek() == Some(0xFB) {
reader.skip(1);
values.push(Value::Null);
} else if let Some(data) = reader.read_lenenc_bytes() {
let is_unsigned = col.is_unsigned();
let value = decode_text_value(col.column_type, &data, is_unsigned);
values.push(value);
} else {
values.push(Value::Null);
}
}
let column_names: Vec<String> = columns.iter().map(|c| c.name.clone()).collect();
Row::new(column_names, values)
}
pub async fn execute_async(
&mut self,
cx: &Cx,
sql: &str,
params: &[Value],
) -> Outcome<u64, Error> {
match self.query_async(cx, sql, params).await {
Outcome::Ok(_) => Outcome::Ok(self.affected_rows),
Outcome::Err(e) => Outcome::Err(e),
Outcome::Cancelled(c) => Outcome::Cancelled(c),
Outcome::Panicked(p) => Outcome::Panicked(p),
}
}
pub async fn prepare_async(
&mut self,
_cx: &Cx,
sql: &str,
) -> Outcome<PreparedStatement, Error> {
if !self.is_ready() && self.state != ConnectionState::InTransaction {
return Outcome::Err(connection_error("Connection not ready for prepare"));
}
self.sequence_id = 0;
let packet = prepared::build_stmt_prepare_packet(sql, self.sequence_id);
if let Outcome::Err(e) = self.write_packet_raw_async(&packet).await {
return Outcome::Err(e);
}
let (payload, _) = match self.read_packet_async().await {
Outcome::Ok(p) => p,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(r) => return Outcome::Cancelled(r),
Outcome::Panicked(p) => return Outcome::Panicked(p),
};
if payload.first() == Some(&0xFF) {
let mut reader = PacketReader::new(&payload);
let Some(err) = reader.parse_err_packet() else {
return Outcome::Err(protocol_error("Invalid error packet"));
};
return Outcome::Err(query_error(&err));
}
let Some(prep_ok) = prepared::parse_stmt_prepare_ok(&payload) else {
return Outcome::Err(protocol_error("Invalid prepare OK response"));
};
let mut param_defs = Vec::with_capacity(prep_ok.num_params as usize);
for _ in 0..prep_ok.num_params {
let (payload, _) = match self.read_packet_async().await {
Outcome::Ok(p) => p,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(r) => return Outcome::Cancelled(r),
Outcome::Panicked(p) => return Outcome::Panicked(p),
};
match self.parse_column_def(&payload) {
Ok(col) => param_defs.push(col),
Err(e) => return Outcome::Err(e),
}
}
let server_caps = self.server_caps.as_ref().map_or(0, |c| c.capabilities);
if prep_ok.num_params > 0 && server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
let (payload, _) = match self.read_packet_async().await {
Outcome::Ok(p) => p,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(r) => return Outcome::Cancelled(r),
Outcome::Panicked(p) => return Outcome::Panicked(p),
};
if payload.first() != Some(&0xFE) {
return Outcome::Err(protocol_error("Expected EOF after param definitions"));
}
}
let mut column_defs = Vec::with_capacity(prep_ok.num_columns as usize);
for _ in 0..prep_ok.num_columns {
let (payload, _) = match self.read_packet_async().await {
Outcome::Ok(p) => p,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(r) => return Outcome::Cancelled(r),
Outcome::Panicked(p) => return Outcome::Panicked(p),
};
match self.parse_column_def(&payload) {
Ok(col) => column_defs.push(col),
Err(e) => return Outcome::Err(e),
}
}
if prep_ok.num_columns > 0 && server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
let (payload, _) = match self.read_packet_async().await {
Outcome::Ok(p) => p,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(r) => return Outcome::Cancelled(r),
Outcome::Panicked(p) => return Outcome::Panicked(p),
};
if payload.first() != Some(&0xFE) {
return Outcome::Err(protocol_error("Expected EOF after column definitions"));
}
}
let meta = PreparedStmtMeta {
statement_id: prep_ok.statement_id,
params: param_defs,
columns: column_defs.clone(),
};
self.prepared_stmts.insert(prep_ok.statement_id, meta);
let column_names: Vec<String> = column_defs.iter().map(|c| c.name.clone()).collect();
Outcome::Ok(PreparedStatement::with_columns(
u64::from(prep_ok.statement_id),
sql.to_string(),
prep_ok.num_params as usize,
column_names,
))
}
pub async fn query_prepared_async(
&mut self,
_cx: &Cx,
stmt: &PreparedStatement,
params: &[Value],
) -> Outcome<Vec<Row>, Error> {
#[allow(clippy::cast_possible_truncation)] let stmt_id = stmt.id() as u32;
let Some(meta) = self.prepared_stmts.get(&stmt_id).cloned() else {
return Outcome::Err(connection_error("Unknown prepared statement"));
};
if params.len() != meta.params.len() {
return Outcome::Err(connection_error(format!(
"Expected {} parameters, got {}",
meta.params.len(),
params.len()
)));
}
if !self.is_ready() && self.state != ConnectionState::InTransaction {
return Outcome::Err(connection_error("Connection not ready for query"));
}
self.state = ConnectionState::InQuery;
self.sequence_id = 0;
let param_types: Vec<FieldType> = meta.params.iter().map(|c| c.column_type).collect();
let packet = prepared::build_stmt_execute_packet(
stmt_id,
params,
Some(¶m_types),
self.sequence_id,
);
if let Outcome::Err(e) = self.write_packet_raw_async(&packet).await {
return Outcome::Err(e);
}
let (payload, _) = match self.read_packet_async().await {
Outcome::Ok(p) => p,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(r) => return Outcome::Cancelled(r),
Outcome::Panicked(p) => return Outcome::Panicked(p),
};
if payload.is_empty() {
self.state = ConnectionState::Ready;
return Outcome::Err(protocol_error("Empty execute response"));
}
#[allow(clippy::cast_possible_truncation)] match PacketType::from_first_byte(payload[0], payload.len() as u32) {
PacketType::Ok => {
let mut reader = PacketReader::new(&payload);
if let Some(ok) = reader.parse_ok_packet() {
self.affected_rows = ok.affected_rows;
self.last_insert_id = ok.last_insert_id;
self.status_flags = ok.status_flags;
self.warnings = ok.warnings;
}
self.state = ConnectionState::Ready;
Outcome::Ok(vec![])
}
PacketType::Error => {
self.state = ConnectionState::Ready;
let mut reader = PacketReader::new(&payload);
let Some(err) = reader.parse_err_packet() else {
return Outcome::Err(protocol_error("Invalid error packet"));
};
Outcome::Err(query_error(&err))
}
_ => {
self.read_binary_result_set_async(&payload, &meta.columns)
.await
}
}
}
pub async fn execute_prepared_async(
&mut self,
cx: &Cx,
stmt: &PreparedStatement,
params: &[Value],
) -> Outcome<u64, Error> {
match self.query_prepared_async(cx, stmt, params).await {
Outcome::Ok(_) => Outcome::Ok(self.affected_rows),
Outcome::Err(e) => Outcome::Err(e),
Outcome::Cancelled(c) => Outcome::Cancelled(c),
Outcome::Panicked(p) => Outcome::Panicked(p),
}
}
pub async fn close_prepared_async(&mut self, stmt: &PreparedStatement) {
#[allow(clippy::cast_possible_truncation)] let stmt_id = stmt.id() as u32;
self.prepared_stmts.remove(&stmt_id);
self.sequence_id = 0;
let packet = prepared::build_stmt_close_packet(stmt_id, self.sequence_id);
let _ = self.write_packet_raw_async(&packet).await;
}
async fn read_binary_result_set_async(
&mut self,
first_packet: &[u8],
columns: &[ColumnDef],
) -> Outcome<Vec<Row>, Error> {
let mut reader = PacketReader::new(first_packet);
#[allow(clippy::cast_possible_truncation)] let Some(column_count) = reader.read_lenenc_int().map(|c| c as usize) else {
return Outcome::Err(protocol_error("Invalid column count"));
};
let mut result_columns = Vec::with_capacity(column_count);
for _ in 0..column_count {
let (payload, _) = match self.read_packet_async().await {
Outcome::Ok(p) => p,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(r) => return Outcome::Cancelled(r),
Outcome::Panicked(p) => return Outcome::Panicked(p),
};
match self.parse_column_def(&payload) {
Ok(col) => result_columns.push(col),
Err(e) => return Outcome::Err(e),
}
}
let cols = if result_columns.len() == columns.len() {
&result_columns
} else {
columns
};
let server_caps = self.server_caps.as_ref().map_or(0, |c| c.capabilities);
if server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
let (payload, _) = match self.read_packet_async().await {
Outcome::Ok(p) => p,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(r) => return Outcome::Cancelled(r),
Outcome::Panicked(p) => return Outcome::Panicked(p),
};
if payload.first() == Some(&0xFE) {
}
}
let mut rows = Vec::new();
loop {
let (payload, _) = match self.read_packet_async().await {
Outcome::Ok(p) => p,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(r) => return Outcome::Cancelled(r),
Outcome::Panicked(p) => return Outcome::Panicked(p),
};
if payload.is_empty() {
break;
}
#[allow(clippy::cast_possible_truncation)] match PacketType::from_first_byte(payload[0], payload.len() as u32) {
PacketType::Eof | PacketType::Ok => {
let mut reader = PacketReader::new(&payload);
if payload[0] == 0x00 {
if let Some(ok) = reader.parse_ok_packet() {
self.status_flags = ok.status_flags;
self.warnings = ok.warnings;
}
} else if payload[0] == 0xFE {
if let Some(eof) = reader.parse_eof_packet() {
self.status_flags = eof.status_flags;
self.warnings = eof.warnings;
}
}
break;
}
PacketType::Error => {
let mut reader = PacketReader::new(&payload);
let Some(err) = reader.parse_err_packet() else {
return Outcome::Err(protocol_error("Invalid error packet"));
};
self.state = ConnectionState::Ready;
return Outcome::Err(query_error(&err));
}
_ => {
let row = self.parse_binary_row(&payload, cols);
rows.push(row);
}
}
}
self.state =
if self.status_flags & crate::protocol::server_status::SERVER_STATUS_IN_TRANS != 0 {
ConnectionState::InTransaction
} else {
ConnectionState::Ready
};
Outcome::Ok(rows)
}
fn parse_binary_row(&self, data: &[u8], columns: &[ColumnDef]) -> Row {
let mut values = Vec::with_capacity(columns.len());
let mut column_names = Vec::with_capacity(columns.len());
if data.is_empty() {
return Row::new(column_names, values);
}
let mut pos = 1;
let null_bitmap_len = (columns.len() + 7 + 2) / 8;
if pos + null_bitmap_len > data.len() {
return Row::new(column_names, values);
}
let null_bitmap = &data[pos..pos + null_bitmap_len];
pos += null_bitmap_len;
for (i, col) in columns.iter().enumerate() {
column_names.push(col.name.clone());
let bit_pos = i + 2;
let is_null = (null_bitmap[bit_pos / 8] & (1 << (bit_pos % 8))) != 0;
if is_null {
values.push(Value::Null);
} else {
let is_unsigned = col.flags & 0x20 != 0; let (value, consumed) =
decode_binary_value_with_len(&data[pos..], col.column_type, is_unsigned);
values.push(value);
pos += consumed;
}
}
Row::new(column_names, values)
}
async fn write_packet_raw_async(&mut self, packet: &[u8]) -> Outcome<(), Error> {
let Some(stream) = self.stream.as_mut() else {
return Outcome::Err(connection_error("Connection stream missing"));
};
match stream {
ConnectionStream::Async(stream) => {
let mut written = 0;
while written < packet.len() {
match std::future::poll_fn(|cx| {
std::pin::Pin::new(&mut *stream).poll_write(cx, &packet[written..])
})
.await
{
Ok(n) => written += n,
Err(e) => {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to write packet: {}", e),
source: Some(Box::new(e)),
}));
}
}
}
if let Err(e) =
std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_flush(cx)).await
{
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to flush: {}", e),
source: Some(Box::new(e)),
}));
}
Outcome::Ok(())
}
ConnectionStream::Sync(stream) => {
if let Err(e) = stream.write_all(packet) {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to write packet: {}", e),
source: Some(Box::new(e)),
}));
}
if let Err(e) = stream.flush() {
return Outcome::Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to flush: {}", e),
source: Some(Box::new(e)),
}));
}
Outcome::Ok(())
}
#[cfg(feature = "tls")]
ConnectionStream::Tls(_) => Outcome::Err(connection_error(
"write_packet_raw_async called after TLS upgrade (bug)",
)),
}
}
pub async fn ping_async(&mut self, _cx: &Cx) -> Outcome<(), Error> {
self.sequence_id = 0;
let mut writer = PacketWriter::new();
writer.write_u8(Command::Ping as u8);
if let Outcome::Err(e) = self.write_packet_async(writer.as_bytes()).await {
return Outcome::Err(e);
}
let (payload, _) = match self.read_packet_async().await {
Outcome::Ok(p) => p,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(r) => return Outcome::Cancelled(r),
Outcome::Panicked(p) => return Outcome::Panicked(p),
};
if payload.first() == Some(&0x00) {
Outcome::Ok(())
} else {
Outcome::Err(connection_error("Ping failed"))
}
}
pub async fn close_async(mut self, _cx: &Cx) -> Result<(), Error> {
if self.state == ConnectionState::Closed {
return Ok(());
}
self.sequence_id = 0;
let mut writer = PacketWriter::new();
writer.write_u8(Command::Quit as u8);
let _ = self.write_packet_async(writer.as_bytes()).await;
self.state = ConnectionState::Closed;
Ok(())
}
}
#[cfg(feature = "console")]
impl ConsoleAware for MySqlAsyncConnection {
fn set_console(&mut self, console: Option<Arc<SqlModelConsole>>) {
self.console = console;
}
fn console(&self) -> Option<&Arc<SqlModelConsole>> {
self.console.as_ref()
}
}
fn protocol_error(msg: impl Into<String>) -> Error {
Error::Protocol(ProtocolError {
message: msg.into(),
raw_data: None,
source: None,
})
}
fn auth_error(msg: impl Into<String>) -> Error {
Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Authentication,
message: msg.into(),
source: None,
})
}
fn connection_error(msg: impl Into<String>) -> Error {
Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Connect,
message: msg.into(),
source: None,
})
}
fn mysql_server_uses_oaep(server_version: &str) -> bool {
let prefix: String = server_version
.chars()
.take_while(|c| c.is_ascii_digit() || *c == '.')
.collect();
let mut it = prefix.split('.').filter(|s| !s.is_empty());
let major: u64 = match it.next().and_then(|s| s.parse().ok()) {
Some(v) => v,
None => return true,
};
let minor: u64 = it.next().and_then(|s| s.parse().ok()).unwrap_or(0);
let patch: u64 = it.next().and_then(|s| s.parse().ok()).unwrap_or(0);
(major, minor, patch) >= (8, 0, 5)
}
fn query_error(err: &ErrPacket) -> Error {
let kind = if err.is_duplicate_key() || err.is_foreign_key_violation() {
QueryErrorKind::Constraint
} else {
QueryErrorKind::Syntax
};
Error::Query(QueryError {
kind,
message: err.error_message.clone(),
sqlstate: Some(err.sql_state.clone()),
sql: None,
detail: None,
hint: None,
position: None,
source: None,
})
}
fn query_error_msg(msg: impl Into<String>) -> Error {
Error::Query(QueryError {
kind: QueryErrorKind::Syntax,
message: msg.into(),
sqlstate: None,
sql: None,
detail: None,
hint: None,
position: None,
source: None,
})
}
fn validate_savepoint_name(name: &str) -> Result<(), Error> {
if name.is_empty() {
return Err(query_error_msg("Savepoint name cannot be empty"));
}
if name.len() > 64 {
return Err(query_error_msg(
"Savepoint name exceeds maximum length of 64 characters",
));
}
let mut chars = name.chars();
let Some(first) = chars.next() else {
return Err(query_error_msg("Savepoint name cannot be empty"));
};
if !first.is_ascii_alphabetic() && first != '_' {
return Err(query_error_msg(
"Savepoint name must start with a letter or underscore",
));
}
for c in chars {
if !c.is_ascii_alphanumeric() && c != '_' && c != '$' {
return Err(query_error_msg(format!(
"Savepoint name contains invalid character: '{}'",
c
)));
}
}
Ok(())
}
pub struct SharedMySqlConnection {
inner: Arc<Mutex<MySqlAsyncConnection>>,
}
impl SharedMySqlConnection {
pub fn new(conn: MySqlAsyncConnection) -> Self {
Self {
inner: Arc::new(Mutex::new(conn)),
}
}
pub async fn connect(cx: &Cx, config: MySqlConfig) -> Outcome<Self, Error> {
match MySqlAsyncConnection::connect(cx, config).await {
Outcome::Ok(conn) => Outcome::Ok(Self::new(conn)),
Outcome::Err(e) => Outcome::Err(e),
Outcome::Cancelled(c) => Outcome::Cancelled(c),
Outcome::Panicked(p) => Outcome::Panicked(p),
}
}
pub fn inner(&self) -> &Arc<Mutex<MySqlAsyncConnection>> {
&self.inner
}
}
impl Clone for SharedMySqlConnection {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
impl std::fmt::Debug for SharedMySqlConnection {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SharedMySqlConnection")
.field("inner", &"Arc<Mutex<MySqlAsyncConnection>>")
.finish()
}
}
pub struct SharedMySqlTransaction<'conn> {
inner: Arc<Mutex<MySqlAsyncConnection>>,
committed: bool,
_marker: std::marker::PhantomData<&'conn ()>,
}
impl SharedMySqlConnection {
async fn begin_transaction_impl(
&self,
cx: &Cx,
isolation: Option<IsolationLevel>,
) -> Outcome<SharedMySqlTransaction<'_>, Error> {
let inner = Arc::clone(&self.inner);
let Ok(mut guard) = inner.lock(cx).await else {
return Outcome::Err(connection_error("Failed to acquire connection lock"));
};
if let Some(level) = isolation {
let isolation_sql = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_sql());
match guard.execute_async(cx, &isolation_sql, &[]).await {
Outcome::Ok(_) => {}
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(c) => return Outcome::Cancelled(c),
Outcome::Panicked(p) => return Outcome::Panicked(p),
}
}
match guard.execute_async(cx, "BEGIN", &[]).await {
Outcome::Ok(_) => {}
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(c) => return Outcome::Cancelled(c),
Outcome::Panicked(p) => return Outcome::Panicked(p),
}
drop(guard);
Outcome::Ok(SharedMySqlTransaction {
inner,
committed: false,
_marker: std::marker::PhantomData,
})
}
}
impl Connection for SharedMySqlConnection {
type Tx<'conn>
= SharedMySqlTransaction<'conn>
where
Self: 'conn;
fn dialect(&self) -> sqlmodel_core::Dialect {
sqlmodel_core::Dialect::Mysql
}
fn query(
&self,
cx: &Cx,
sql: &str,
params: &[Value],
) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
let inner = Arc::clone(&self.inner);
let sql = sql.to_string();
let params = params.to_vec();
async move {
let Ok(mut guard) = inner.lock(cx).await else {
return Outcome::Err(connection_error("Failed to acquire connection lock"));
};
guard.query_async(cx, &sql, ¶ms).await
}
}
fn query_one(
&self,
cx: &Cx,
sql: &str,
params: &[Value],
) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
let inner = Arc::clone(&self.inner);
let sql = sql.to_string();
let params = params.to_vec();
async move {
let Ok(mut guard) = inner.lock(cx).await else {
return Outcome::Err(connection_error("Failed to acquire connection lock"));
};
let rows = match guard.query_async(cx, &sql, ¶ms).await {
Outcome::Ok(r) => r,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(c) => return Outcome::Cancelled(c),
Outcome::Panicked(p) => return Outcome::Panicked(p),
};
Outcome::Ok(rows.into_iter().next())
}
}
fn execute(
&self,
cx: &Cx,
sql: &str,
params: &[Value],
) -> impl Future<Output = Outcome<u64, Error>> + Send {
let inner = Arc::clone(&self.inner);
let sql = sql.to_string();
let params = params.to_vec();
async move {
let Ok(mut guard) = inner.lock(cx).await else {
return Outcome::Err(connection_error("Failed to acquire connection lock"));
};
guard.execute_async(cx, &sql, ¶ms).await
}
}
fn insert(
&self,
cx: &Cx,
sql: &str,
params: &[Value],
) -> impl Future<Output = Outcome<i64, Error>> + Send {
let inner = Arc::clone(&self.inner);
let sql = sql.to_string();
let params = params.to_vec();
async move {
let Ok(mut guard) = inner.lock(cx).await else {
return Outcome::Err(connection_error("Failed to acquire connection lock"));
};
match guard.execute_async(cx, &sql, ¶ms).await {
Outcome::Ok(_) => Outcome::Ok(guard.last_insert_id() as i64),
Outcome::Err(e) => Outcome::Err(e),
Outcome::Cancelled(c) => Outcome::Cancelled(c),
Outcome::Panicked(p) => Outcome::Panicked(p),
}
}
}
fn batch(
&self,
cx: &Cx,
statements: &[(String, Vec<Value>)],
) -> impl Future<Output = Outcome<Vec<u64>, Error>> + Send {
let inner = Arc::clone(&self.inner);
let statements = statements.to_vec();
async move {
let Ok(mut guard) = inner.lock(cx).await else {
return Outcome::Err(connection_error("Failed to acquire connection lock"));
};
let mut results = Vec::with_capacity(statements.len());
for (sql, params) in &statements {
match guard.execute_async(cx, sql, params).await {
Outcome::Ok(n) => results.push(n),
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(c) => return Outcome::Cancelled(c),
Outcome::Panicked(p) => return Outcome::Panicked(p),
}
}
Outcome::Ok(results)
}
}
fn begin(&self, cx: &Cx) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
self.begin_transaction_impl(cx, None)
}
fn begin_with(
&self,
cx: &Cx,
isolation: IsolationLevel,
) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
self.begin_transaction_impl(cx, Some(isolation))
}
fn prepare(
&self,
cx: &Cx,
sql: &str,
) -> impl Future<Output = Outcome<PreparedStatement, Error>> + Send {
let inner = Arc::clone(&self.inner);
let sql = sql.to_string();
async move {
let Ok(mut guard) = inner.lock(cx).await else {
return Outcome::Err(connection_error("Failed to acquire connection lock"));
};
guard.prepare_async(cx, &sql).await
}
}
fn query_prepared(
&self,
cx: &Cx,
stmt: &PreparedStatement,
params: &[Value],
) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
let inner = Arc::clone(&self.inner);
let stmt = stmt.clone();
let params = params.to_vec();
async move {
let Ok(mut guard) = inner.lock(cx).await else {
return Outcome::Err(connection_error("Failed to acquire connection lock"));
};
guard.query_prepared_async(cx, &stmt, ¶ms).await
}
}
fn execute_prepared(
&self,
cx: &Cx,
stmt: &PreparedStatement,
params: &[Value],
) -> impl Future<Output = Outcome<u64, Error>> + Send {
let inner = Arc::clone(&self.inner);
let stmt = stmt.clone();
let params = params.to_vec();
async move {
let Ok(mut guard) = inner.lock(cx).await else {
return Outcome::Err(connection_error("Failed to acquire connection lock"));
};
guard.execute_prepared_async(cx, &stmt, ¶ms).await
}
}
fn ping(&self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
let inner = Arc::clone(&self.inner);
async move {
let Ok(mut guard) = inner.lock(cx).await else {
return Outcome::Err(connection_error("Failed to acquire connection lock"));
};
guard.ping_async(cx).await
}
}
fn close(self, cx: &Cx) -> impl Future<Output = Result<(), Error>> + Send {
async move {
match Arc::try_unwrap(self.inner) {
Ok(mutex) => {
let conn = mutex.into_inner();
conn.close_async(cx).await
}
Err(_) => {
Err(connection_error(
"Cannot close: other references to connection exist",
))
}
}
}
}
}
impl<'conn> TransactionOps for SharedMySqlTransaction<'conn> {
fn query(
&self,
cx: &Cx,
sql: &str,
params: &[Value],
) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
let inner = Arc::clone(&self.inner);
let sql = sql.to_string();
let params = params.to_vec();
async move {
let Ok(mut guard) = inner.lock(cx).await else {
return Outcome::Err(connection_error("Failed to acquire connection lock"));
};
guard.query_async(cx, &sql, ¶ms).await
}
}
fn query_one(
&self,
cx: &Cx,
sql: &str,
params: &[Value],
) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
let inner = Arc::clone(&self.inner);
let sql = sql.to_string();
let params = params.to_vec();
async move {
let Ok(mut guard) = inner.lock(cx).await else {
return Outcome::Err(connection_error("Failed to acquire connection lock"));
};
let rows = match guard.query_async(cx, &sql, ¶ms).await {
Outcome::Ok(r) => r,
Outcome::Err(e) => return Outcome::Err(e),
Outcome::Cancelled(c) => return Outcome::Cancelled(c),
Outcome::Panicked(p) => return Outcome::Panicked(p),
};
Outcome::Ok(rows.into_iter().next())
}
}
fn execute(
&self,
cx: &Cx,
sql: &str,
params: &[Value],
) -> impl Future<Output = Outcome<u64, Error>> + Send {
let inner = Arc::clone(&self.inner);
let sql = sql.to_string();
let params = params.to_vec();
async move {
let Ok(mut guard) = inner.lock(cx).await else {
return Outcome::Err(connection_error("Failed to acquire connection lock"));
};
guard.execute_async(cx, &sql, ¶ms).await
}
}
fn savepoint(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
let inner = Arc::clone(&self.inner);
let validation_result = validate_savepoint_name(name);
let sql = format!("SAVEPOINT {}", name);
async move {
if let Err(e) = validation_result {
return Outcome::Err(e);
}
let Ok(mut guard) = inner.lock(cx).await else {
return Outcome::Err(connection_error("Failed to acquire connection lock"));
};
match guard.execute_async(cx, &sql, &[]).await {
Outcome::Ok(_) => Outcome::Ok(()),
Outcome::Err(e) => Outcome::Err(e),
Outcome::Cancelled(c) => Outcome::Cancelled(c),
Outcome::Panicked(p) => Outcome::Panicked(p),
}
}
}
fn rollback_to(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
let inner = Arc::clone(&self.inner);
let validation_result = validate_savepoint_name(name);
let sql = format!("ROLLBACK TO SAVEPOINT {}", name);
async move {
if let Err(e) = validation_result {
return Outcome::Err(e);
}
let Ok(mut guard) = inner.lock(cx).await else {
return Outcome::Err(connection_error("Failed to acquire connection lock"));
};
match guard.execute_async(cx, &sql, &[]).await {
Outcome::Ok(_) => Outcome::Ok(()),
Outcome::Err(e) => Outcome::Err(e),
Outcome::Cancelled(c) => Outcome::Cancelled(c),
Outcome::Panicked(p) => Outcome::Panicked(p),
}
}
}
fn release(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
let inner = Arc::clone(&self.inner);
let validation_result = validate_savepoint_name(name);
let sql = format!("RELEASE SAVEPOINT {}", name);
async move {
if let Err(e) = validation_result {
return Outcome::Err(e);
}
let Ok(mut guard) = inner.lock(cx).await else {
return Outcome::Err(connection_error("Failed to acquire connection lock"));
};
match guard.execute_async(cx, &sql, &[]).await {
Outcome::Ok(_) => Outcome::Ok(()),
Outcome::Err(e) => Outcome::Err(e),
Outcome::Cancelled(c) => Outcome::Cancelled(c),
Outcome::Panicked(p) => Outcome::Panicked(p),
}
}
}
#[allow(unused_assignments)]
fn commit(mut self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
async move {
let Ok(mut guard) = self.inner.lock(cx).await else {
return Outcome::Err(connection_error("Failed to acquire connection lock"));
};
match guard.execute_async(cx, "COMMIT", &[]).await {
Outcome::Ok(_) => {
self.committed = true;
Outcome::Ok(())
}
Outcome::Err(e) => Outcome::Err(e),
Outcome::Cancelled(c) => Outcome::Cancelled(c),
Outcome::Panicked(p) => Outcome::Panicked(p),
}
}
}
fn rollback(self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
async move {
let Ok(mut guard) = self.inner.lock(cx).await else {
return Outcome::Err(connection_error("Failed to acquire connection lock"));
};
match guard.execute_async(cx, "ROLLBACK", &[]).await {
Outcome::Ok(_) => Outcome::Ok(()),
Outcome::Err(e) => Outcome::Err(e),
Outcome::Cancelled(c) => Outcome::Cancelled(c),
Outcome::Panicked(p) => Outcome::Panicked(p),
}
}
}
}
impl<'conn> Drop for SharedMySqlTransaction<'conn> {
fn drop(&mut self) {
if !self.committed {
#[cfg(debug_assertions)]
eprintln!(
"WARNING: SharedMySqlTransaction dropped without commit/rollback. \
The MySQL transaction may still be open."
);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_connection_state() {
assert_eq!(ConnectionState::Disconnected, ConnectionState::Disconnected);
}
#[test]
fn test_error_helpers() {
let err = protocol_error("test");
assert!(matches!(err, Error::Protocol(_)));
let err = auth_error("auth failed");
assert!(matches!(err, Error::Connection(_)));
let err = connection_error("conn failed");
assert!(matches!(err, Error::Connection(_)));
}
#[test]
fn test_validate_savepoint_name_valid() {
assert!(validate_savepoint_name("sp1").is_ok());
assert!(validate_savepoint_name("_savepoint").is_ok());
assert!(validate_savepoint_name("SavePoint_123").is_ok());
assert!(validate_savepoint_name("sp$test").is_ok());
assert!(validate_savepoint_name("a").is_ok());
assert!(validate_savepoint_name("_").is_ok());
}
#[test]
fn test_validate_savepoint_name_invalid() {
assert!(validate_savepoint_name("").is_err());
assert!(validate_savepoint_name("1savepoint").is_err());
assert!(validate_savepoint_name("save-point").is_err());
assert!(validate_savepoint_name("save point").is_err());
assert!(validate_savepoint_name("save;drop table").is_err());
assert!(validate_savepoint_name("sp'--").is_err());
let long_name = "a".repeat(65);
assert!(validate_savepoint_name(&long_name).is_err());
}
}