#![allow(clippy::cast_possible_truncation)]
use std::io::{Read, Write};
use std::net::TcpStream;
#[cfg(feature = "console")]
use std::sync::Arc;
use sqlmodel_core::Error;
use sqlmodel_core::error::{
ConnectionError, ConnectionErrorKind, ProtocolError, QueryError, QueryErrorKind,
};
use sqlmodel_core::{Row, Value};
#[cfg(feature = "console")]
use sqlmodel_console::{ConsoleAware, SqlModelConsole};
use crate::auth;
use crate::config::MySqlConfig;
use crate::protocol::{
Command, ErrPacket, MAX_PACKET_SIZE, PacketHeader, PacketReader, PacketType, PacketWriter,
capabilities, charset,
};
use crate::types::{ColumnDef, FieldType, decode_text_value, interpolate_params};
/// Lifecycle state of a [`MySqlConnection`].
///
/// In this module `connect` moves through `Connecting` and `Authenticating`
/// to `Ready`; queries run in `InQuery` and end in `Ready` or `InTransaction`
/// depending on the server's status flags; `close` sets `Closed`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// Not connected (not assigned by the connection code in this module).
    Disconnected,
    /// TCP stream opened; the server handshake has not been processed yet.
    Connecting,
    /// Handshake parsed; the authentication exchange is in progress.
    Authenticating,
    /// Idle and accepting commands; last status flags reported no open transaction.
    Ready,
    /// A `COM_QUERY` has been sent and its response is being read.
    InQuery,
    /// Idle, but the last status flags had `SERVER_STATUS_IN_TRANS` set.
    InTransaction,
    /// Error state; presumably the connection is unusable afterwards (TODO confirm usage).
    Error,
    /// `COM_QUIT` has been sent via `close`.
    Closed,
}
/// Information parsed from the server's initial (protocol v10) handshake packet.
#[derive(Debug, Clone)]
pub struct ServerCapabilities {
    /// Capability flags: lower and upper 16-bit halves combined into one word.
    pub capabilities: u32,
    /// Handshake protocol version; the parser only accepts 10.
    pub protocol_version: u8,
    /// Server version string as sent in the handshake.
    pub server_version: String,
    /// Connection id assigned by the server.
    pub connection_id: u32,
    /// Authentication plugin named by the server; `mysql_native_password`
    /// when the server did not set `CLIENT_PLUGIN_AUTH`.
    pub auth_plugin: String,
    /// Authentication scramble (both handshake parts, trailing NUL removed).
    pub auth_data: Vec<u8>,
    /// Character set id from the handshake (`UTF8MB4_0900_AI_CI` if absent).
    pub charset: u8,
    /// Server status flags at handshake time (0 if absent).
    pub status_flags: u16,
}
/// A synchronous MySQL client connection over a `TcpStream`.
///
/// Created by [`MySqlConnection::connect`], which performs the handshake and
/// authentication; statements are then issued with `query_sync` and friends.
pub struct MySqlConnection {
    /// Socket every protocol packet is read from and written to.
    stream: TcpStream,
    /// Current lifecycle state.
    state: ConnectionState,
    /// Parsed server handshake; `None` until the handshake has been read.
    server_caps: Option<ServerCapabilities>,
    /// Connection id announced by the server.
    connection_id: u32,
    /// Server status flags from the most recent OK/EOF packet.
    status_flags: u16,
    /// Affected-row count from the most recent OK packet.
    affected_rows: u64,
    /// Last insert id from the most recent query OK packet.
    last_insert_id: u64,
    /// Warning count from the most recent query OK/EOF packet.
    warnings: u16,
    /// Connection configuration (address, credentials, capability flags, ...).
    config: MySqlConfig,
    /// Sequence id to stamp on the next outgoing packet; reset to 0 when a
    /// command starts and advanced by `read_packet` / `write_packet`.
    sequence_id: u8,
    /// Optional console used for timing and diagnostic output.
    #[cfg(feature = "console")]
    console: Option<Arc<SqlModelConsole>>,
}
impl std::fmt::Debug for MySqlConnection {
    /// Shows connection identity (state, id, endpoint, database) while leaving
    /// out the socket and the rest of `config`, which holds the password.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cfg = &self.config;
        let mut out = f.debug_struct("MySqlConnection");
        out.field("state", &self.state);
        out.field("connection_id", &self.connection_id);
        out.field("host", &cfg.host);
        out.field("port", &cfg.port);
        out.field("database", &cfg.database);
        out.finish_non_exhaustive()
    }
}
impl MySqlConnection {
#[allow(clippy::result_large_err)]
pub fn connect(config: MySqlConfig) -> Result<Self, Error> {
let stream = TcpStream::connect_timeout(
&config.socket_addr().parse().map_err(|e| {
Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Connect,
message: format!("Invalid socket address: {}", e),
source: None,
})
})?,
config.connect_timeout,
)
.map_err(|e| {
let kind = if e.kind() == std::io::ErrorKind::ConnectionRefused {
ConnectionErrorKind::Refused
} else {
ConnectionErrorKind::Connect
};
Error::Connection(ConnectionError {
kind,
message: format!("Failed to connect to {}: {}", config.socket_addr(), e),
source: Some(Box::new(e)),
})
})?;
stream.set_nodelay(true).ok();
stream.set_read_timeout(Some(config.connect_timeout)).ok();
stream.set_write_timeout(Some(config.connect_timeout)).ok();
let mut conn = Self {
stream,
state: ConnectionState::Connecting,
server_caps: None,
connection_id: 0,
status_flags: 0,
affected_rows: 0,
last_insert_id: 0,
warnings: 0,
config,
sequence_id: 0,
#[cfg(feature = "console")]
console: None,
};
let server_caps = conn.read_handshake()?;
conn.connection_id = server_caps.connection_id;
conn.server_caps = Some(server_caps);
conn.state = ConnectionState::Authenticating;
conn.send_handshake_response()?;
conn.handle_auth_result()?;
conn.state = ConnectionState::Ready;
Ok(conn)
}
/// Current lifecycle state of the connection.
pub fn state(&self) -> ConnectionState {
    self.state
}

/// `true` only in `ConnectionState::Ready` (idle, no open transaction reported).
pub fn is_ready(&self) -> bool {
    self.state == ConnectionState::Ready
}

/// Connection id announced by the server in its handshake.
pub fn connection_id(&self) -> u32 {
    self.connection_id
}

/// Server version string from the handshake, or `None` before it was read.
pub fn server_version(&self) -> Option<&str> {
    let caps = self.server_caps.as_ref()?;
    Some(caps.server_version.as_str())
}

/// Affected-row count from the most recent OK packet.
pub fn affected_rows(&self) -> u64 {
    self.affected_rows
}

/// Last insert id reported by the most recent query OK packet.
pub fn last_insert_id(&self) -> u64 {
    self.last_insert_id
}

/// Warning count from the most recent query OK/EOF packet.
pub fn warnings(&self) -> u16 {
    self.warnings
}
/// Reads and parses the server's initial Handshake V10 packet.
///
/// Wire layout: protocol version (1) | server version (NUL string) |
/// connection id (4) | auth-plugin-data part 1 (8) | filler (1) |
/// capability flags low (2) | charset (1) | status flags (2) |
/// capability flags high (2) | auth-plugin-data length, or 0x00 (1) |
/// reserved (10) | auth-plugin-data part 2 (MAX(13, len - 8)) |
/// auth plugin name (NUL string).
///
/// # Errors
/// A protocol error if a mandatory leading field is missing or the protocol
/// version is not 10.
#[allow(clippy::result_large_err)]
fn read_handshake(&mut self) -> Result<ServerCapabilities, Error> {
    let (payload, _) = self.read_packet()?;
    let mut reader = PacketReader::new(&payload);
    let protocol_version = reader
        .read_u8()
        .ok_or_else(|| protocol_error("Missing protocol version"))?;
    if protocol_version != 10 {
        return Err(protocol_error(format!(
            "Unsupported protocol version: {}",
            protocol_version
        )));
    }
    let server_version = reader
        .read_null_string()
        .ok_or_else(|| protocol_error("Missing server version"))?;
    let connection_id = reader
        .read_u32_le()
        .ok_or_else(|| protocol_error("Missing connection ID"))?;
    let auth_data_1 = reader
        .read_bytes(8)
        .ok_or_else(|| protocol_error("Missing auth data"))?;
    // Filler byte.
    reader.skip(1);
    let caps_lower = reader
        .read_u16_le()
        .ok_or_else(|| protocol_error("Missing capability flags"))?;
    let charset = reader.read_u8().unwrap_or(charset::UTF8MB4_0900_AI_CI);
    let status_flags = reader.read_u16_le().unwrap_or(0);
    let caps_upper = reader.read_u16_le().unwrap_or(0);
    let capabilities = u32::from(caps_lower) | (u32::from(caps_upper) << 16);
    // This byte is always present: the auth-data length when
    // CLIENT_PLUGIN_AUTH is set, otherwise a 0x00 filler. It must be consumed
    // either way so that the 10 reserved bytes are skipped from the right offset.
    let auth_data_len_byte = reader.read_u8().unwrap_or(0);
    let auth_data_len = if capabilities & capabilities::CLIENT_PLUGIN_AUTH != 0 {
        usize::from(auth_data_len_byte)
    } else {
        0
    };
    reader.skip(10);
    let mut auth_data = auth_data_1.to_vec();
    if capabilities & capabilities::CLIENT_SECURE_CONNECTION != 0 {
        // The protocol defines part 2's length as MAX(13, auth_data_len - 8).
        let len2 = auth_data_len.saturating_sub(8).max(13);
        if let Some(data2) = reader.read_bytes(len2) {
            // Drop the NUL terminator so `auth_data` holds only the scramble.
            let data2_clean = if data2.last() == Some(&0) {
                &data2[..data2.len() - 1]
            } else {
                data2
            };
            auth_data.extend_from_slice(data2_clean);
        }
    }
    let auth_plugin = if capabilities & capabilities::CLIENT_PLUGIN_AUTH != 0 {
        reader.read_null_string().unwrap_or_default()
    } else {
        auth::plugins::MYSQL_NATIVE_PASSWORD.to_string()
    };
    Ok(ServerCapabilities {
        capabilities,
        protocol_version,
        server_version,
        connection_id,
        auth_plugin,
        auth_data,
        charset,
        status_flags,
    })
}
/// Builds and sends the handshake response answering the server handshake.
///
/// The capability flags sent are the configured client flags intersected
/// with the server's. Packet contents in order: capability flags (u32), max
/// packet size (u32), charset (u8), 23 zero bytes, NUL-terminated user name,
/// auth response (encoding chosen by the negotiated flags), optional database
/// name, optional auth plugin name, optional connection attributes.
///
/// # Errors
/// A protocol error if no handshake has been read yet, or a connection error
/// if writing the packet fails.
#[allow(clippy::result_large_err)]
fn send_handshake_response(&mut self) -> Result<(), Error> {
    let server_caps = self
        .server_caps
        .as_ref()
        .ok_or_else(|| protocol_error("No server handshake received"))?;
    // Only claim capabilities that both sides support.
    let client_caps = self.config.capability_flags() & server_caps.capabilities;
    let auth_response =
        self.compute_auth_response(&server_caps.auth_plugin, &server_caps.auth_data);
    let mut writer = PacketWriter::new();
    writer.write_u32_le(client_caps);
    writer.write_u32_le(self.config.max_packet_size);
    writer.write_u8(self.config.charset);
    // Reserved filler.
    writer.write_zeros(23);
    writer.write_null_string(&self.config.user);
    // Auth response encoding: length-encoded, 1-byte length prefix, or
    // NUL-terminated, depending on negotiated capabilities.
    if client_caps & capabilities::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA != 0 {
        writer.write_lenenc_bytes(&auth_response);
    } else if client_caps & capabilities::CLIENT_SECURE_CONNECTION != 0 {
        // NOTE(review): `as u8` truncates the length if the response exceeds 255 bytes.
        #[allow(clippy::cast_possible_truncation)]
        writer.write_u8(auth_response.len() as u8);
        writer.write_bytes(&auth_response);
    } else {
        writer.write_bytes(&auth_response);
        // NUL terminator.
        writer.write_u8(0);
    }
    if client_caps & capabilities::CLIENT_CONNECT_WITH_DB != 0 {
        if let Some(ref db) = self.config.database {
            writer.write_null_string(db);
        } else {
            // Empty database name.
            writer.write_u8(0);
        }
    }
    if client_caps & capabilities::CLIENT_PLUGIN_AUTH != 0 {
        writer.write_null_string(&server_caps.auth_plugin);
    }
    // Connection attributes: key/value pairs as length-encoded strings, with
    // the whole block itself length-encoded.
    if client_caps & capabilities::CLIENT_CONNECT_ATTRS != 0
        && !self.config.attributes.is_empty()
    {
        let mut attrs_writer = PacketWriter::new();
        for (key, value) in &self.config.attributes {
            attrs_writer.write_lenenc_string(key);
            attrs_writer.write_lenenc_string(value);
        }
        let attrs_data = attrs_writer.into_bytes();
        writer.write_lenenc_bytes(&attrs_data);
    }
    self.write_packet(writer.as_bytes())?;
    Ok(())
}
/// Computes the client authentication response for `plugin`, using the
/// configured password (empty when none is set) and the server scramble
/// `auth_data`.
///
/// `mysql_clear_password` sends the password followed by a NUL byte. Any
/// unrecognised plugin falls back to the `mysql_native_password` scheme.
fn compute_auth_response(&self, plugin: &str, auth_data: &[u8]) -> Vec<u8> {
    let secret: &str = match self.config.password.as_deref() {
        Some(p) => p,
        None => "",
    };
    if plugin == auth::plugins::CACHING_SHA2_PASSWORD {
        return auth::caching_sha2_password(secret, auth_data);
    }
    if plugin == auth::plugins::MYSQL_CLEAR_PASSWORD {
        let mut cleartext = Vec::with_capacity(secret.len() + 1);
        cleartext.extend_from_slice(secret.as_bytes());
        cleartext.push(0);
        return cleartext;
    }
    // mysql_native_password, and the fallback for any other plugin name.
    auth::mysql_native_password(secret, auth_data)
}
/// Reads the server's reply to an authentication response and dispatches
/// on its first byte:
///
/// - `0x00` OK: authentication finished; status flags are recorded.
/// - `0xFF` ERR: authentication failed.
/// - `0xFE` AuthSwitchRequest: the server asks for a different plugin.
/// - `0x01` AuthMoreData: plugin-specific data follows the marker byte (for
///   example the `caching_sha2_password` fast-auth / full-auth codes). The
///   marker is stripped before delegating.
/// - Anything else is handed to `handle_additional_auth` unchanged.
///
/// During authentication a `0xFE` header always means an auth switch,
/// whatever the packet length, so the raw byte is matched directly instead
/// of using generic packet classification.
///
/// # Errors
/// An authentication error for an ERR reply, protocol errors for malformed
/// packets, and I/O errors from the transport.
#[allow(clippy::result_large_err)]
fn handle_auth_result(&mut self) -> Result<(), Error> {
    let (payload, _) = self.read_packet()?;
    let Some(&marker) = payload.first() else {
        return Err(protocol_error("Empty authentication response"));
    };
    match marker {
        0x00 => {
            let mut reader = PacketReader::new(&payload);
            if let Some(ok) = reader.parse_ok_packet() {
                self.status_flags = ok.status_flags;
                self.affected_rows = ok.affected_rows;
            }
            Ok(())
        }
        0xFF => {
            let mut reader = PacketReader::new(&payload);
            let err = reader
                .parse_err_packet()
                .ok_or_else(|| protocol_error("Invalid error packet"))?;
            Err(auth_error(format!(
                "Authentication failed: {} ({})",
                err.error_message, err.error_code
            )))
        }
        0xFE => self.handle_auth_switch(&payload[1..]),
        0x01 if payload.len() > 1 => self.handle_additional_auth(&payload[1..]),
        _ => self.handle_additional_auth(&payload),
    }
}
/// Handles an AuthSwitchRequest. `data` is the packet after its `0xFE`
/// header: a NUL-terminated plugin name followed by a fresh scramble.
///
/// A trailing NUL after the scramble is dropped, matching how the initial
/// handshake scramble is cleaned. The new plugin name and scramble are stored
/// in `server_caps`, so a later `caching_sha2_password` full-authentication
/// round encrypts against the seed the server is actually using. The
/// response is then sent and the server's next reply is processed.
///
/// # Errors
/// A protocol error if the plugin name is missing, plus any error from
/// sending the response or handling the follow-up reply.
#[allow(clippy::result_large_err)]
fn handle_auth_switch(&mut self, data: &[u8]) -> Result<(), Error> {
    let mut reader = PacketReader::new(data);
    let plugin = reader
        .read_null_string()
        .ok_or_else(|| protocol_error("Missing plugin name in auth switch"))?;
    let rest: &[u8] = reader.read_rest();
    let auth_data = rest.strip_suffix(&[0u8]).unwrap_or(rest).to_vec();
    if let Some(caps) = self.server_caps.as_mut() {
        caps.auth_plugin.clone_from(&plugin);
        caps.auth_data.clone_from(&auth_data);
    }
    let response = self.compute_auth_response(&plugin, &auth_data);
    self.write_packet(&response)?;
    self.handle_auth_result()
}
#[allow(clippy::result_large_err)]
fn handle_additional_auth(&mut self, data: &[u8]) -> Result<(), Error> {
if data.is_empty() {
return Err(protocol_error("Empty additional auth data"));
}
match data[0] {
auth::caching_sha2::FAST_AUTH_SUCCESS => {
let (payload, _) = self.read_packet()?;
let mut reader = PacketReader::new(&payload);
if let Some(ok) = reader.parse_ok_packet() {
self.status_flags = ok.status_flags;
}
Ok(())
}
auth::caching_sha2::PERFORM_FULL_AUTH => {
let Some(server_caps) = self.server_caps.as_ref() else {
return Err(protocol_error("Missing server capabilities during auth"));
};
let password = self.config.password.clone().unwrap_or_default();
let seed = server_caps.auth_data.clone();
let server_version = server_caps.server_version.clone();
self.write_packet(&[auth::caching_sha2::REQUEST_PUBLIC_KEY])?;
let (payload, _) = self.read_packet()?;
if payload.is_empty() {
return Err(protocol_error("Empty public key response"));
}
let public_key = if payload[0] == 0x01 {
&payload[1..]
} else {
&payload[..]
};
let use_oaep = mysql_server_uses_oaep(&server_version);
let encrypted = auth::sha256_password_rsa(&password, &seed, public_key, use_oaep)
.map_err(auth_error)?;
self.write_packet(&encrypted)?;
self.handle_auth_result()
}
_ => {
let mut reader = PacketReader::new(data);
if let Some(ok) = reader.parse_ok_packet() {
self.status_flags = ok.status_flags;
Ok(())
} else {
Err(protocol_error(format!(
"Unknown auth response: {:02X}",
data[0]
)))
}
}
}
}
#[allow(clippy::result_large_err)]
pub fn query_sync(&mut self, sql: &str, params: &[Value]) -> Result<Vec<Row>, Error> {
#[cfg(feature = "console")]
let start = std::time::Instant::now();
let sql = interpolate_params(sql, params);
if !self.is_ready() && self.state != ConnectionState::InTransaction {
return Err(connection_error("Connection not ready for queries"));
}
self.state = ConnectionState::InQuery;
self.sequence_id = 0;
let mut writer = PacketWriter::new();
writer.write_u8(Command::Query as u8);
writer.write_bytes(sql.as_bytes());
self.write_packet(writer.as_bytes())?;
let (payload, _) = self.read_packet()?;
if payload.is_empty() {
self.state = ConnectionState::Ready;
return Err(protocol_error("Empty query response"));
}
match PacketType::from_first_byte(payload[0], payload.len() as u32) {
PacketType::Ok => {
let mut reader = PacketReader::new(&payload);
if let Some(ok) = reader.parse_ok_packet() {
self.affected_rows = ok.affected_rows;
self.last_insert_id = ok.last_insert_id;
self.status_flags = ok.status_flags;
self.warnings = ok.warnings;
}
self.state = if self.status_flags
& crate::protocol::server_status::SERVER_STATUS_IN_TRANS
!= 0
{
ConnectionState::InTransaction
} else {
ConnectionState::Ready
};
#[cfg(feature = "console")]
{
let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
self.emit_execute_timing(&sql, elapsed_ms, self.affected_rows);
self.emit_warnings(self.warnings);
}
Ok(vec![])
}
PacketType::Error => {
self.state = ConnectionState::Ready;
let mut reader = PacketReader::new(&payload);
let err = reader
.parse_err_packet()
.ok_or_else(|| protocol_error("Invalid error packet"))?;
Err(query_error(&err))
}
PacketType::LocalInfile => {
self.state = ConnectionState::Ready;
Err(query_error_msg("LOCAL INFILE not supported"))
}
_ => {
#[cfg(feature = "console")]
let result = self.read_result_set_with_timing(&sql, &payload, start);
#[cfg(not(feature = "console"))]
let result = self.read_result_set(&payload);
result
}
}
}
#[allow(dead_code)] #[allow(clippy::result_large_err)]
fn read_result_set(&mut self, first_packet: &[u8]) -> Result<Vec<Row>, Error> {
let mut reader = PacketReader::new(first_packet);
let column_count = reader
.read_lenenc_int()
.ok_or_else(|| protocol_error("Invalid column count"))?
as usize;
let mut columns = Vec::with_capacity(column_count);
for _ in 0..column_count {
let (payload, _) = self.read_packet()?;
columns.push(self.parse_column_def(&payload)?);
}
let server_caps = self.server_caps.as_ref().map_or(0, |c| c.capabilities);
if server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
let (payload, _) = self.read_packet()?;
if payload.first() == Some(&0xFE) {
}
}
let mut rows = Vec::new();
loop {
let (payload, _) = self.read_packet()?;
if payload.is_empty() {
break;
}
match PacketType::from_first_byte(payload[0], payload.len() as u32) {
PacketType::Eof | PacketType::Ok => {
let mut reader = PacketReader::new(&payload);
if payload[0] == 0x00 {
if let Some(ok) = reader.parse_ok_packet() {
self.status_flags = ok.status_flags;
self.warnings = ok.warnings;
}
} else if payload[0] == 0xFE {
if let Some(eof) = reader.parse_eof_packet() {
self.status_flags = eof.status_flags;
self.warnings = eof.warnings;
}
}
break;
}
PacketType::Error => {
let mut reader = PacketReader::new(&payload);
let err = reader
.parse_err_packet()
.ok_or_else(|| protocol_error("Invalid error packet"))?;
self.state = ConnectionState::Ready;
return Err(query_error(&err));
}
_ => {
let row = self.parse_text_row(&payload, &columns);
rows.push(row);
}
}
}
self.state =
if self.status_flags & crate::protocol::server_status::SERVER_STATUS_IN_TRANS != 0 {
ConnectionState::InTransaction
} else {
ConnectionState::Ready
};
Ok(rows)
}
/// Console-enabled variant of result-set reading: reads a text-protocol
/// result set (first packet already received as `first_packet`), then reports
/// timing measured from `start`.
///
/// Framing and terminator rules:
/// - The column-definition EOF is skipped only when `CLIENT_DEPRECATE_EOF`
///   was negotiated (client ∩ server).
/// - Rows end at a `0xFE`-headed packet shorter than `MAX_PACKET_SIZE`.
/// - A leading `0x00` is an empty first column, not a terminator.
///
/// `SHOW` statements are rendered as a table; other statements get a
/// timing line. Server warnings are reported too.
///
/// # Errors
/// Protocol errors for malformed column data, a query error for an ERR
/// packet in the row stream, and I/O errors from the transport.
#[cfg(feature = "console")]
#[allow(clippy::result_large_err)]
fn read_result_set_with_timing(
    &mut self,
    sql: &str,
    first_packet: &[u8],
    start: std::time::Instant,
) -> Result<Vec<Row>, Error> {
    let mut reader = PacketReader::new(first_packet);
    let column_count = reader
        .read_lenenc_int()
        .ok_or_else(|| protocol_error("Invalid column count"))?
        as usize;
    let mut columns = Vec::with_capacity(column_count);
    let mut col_names = Vec::with_capacity(column_count);
    for _ in 0..column_count {
        let (payload, _) = self.read_packet()?;
        let col = self.parse_column_def(&payload)?;
        col_names.push(col.name.clone());
        columns.push(col);
    }
    // Use the negotiated capabilities, not just the server's advertisement.
    let server_caps = self.server_caps.as_ref().map_or(0, |c| c.capabilities);
    let negotiated_caps = self.config.capability_flags() & server_caps;
    if negotiated_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
        let _columns_eof = self.read_packet()?;
    }
    let mut rows = Vec::new();
    loop {
        let (payload, _) = self.read_packet()?;
        let Some(&first) = payload.first() else {
            break;
        };
        if first == 0xFE && payload.len() < MAX_PACKET_SIZE {
            let mut reader = PacketReader::new(&payload);
            if let Some(eof) = reader.parse_eof_packet() {
                self.status_flags = eof.status_flags;
                self.warnings = eof.warnings;
            }
            break;
        }
        if first == 0xFF {
            let mut reader = PacketReader::new(&payload);
            let err = reader
                .parse_err_packet()
                .ok_or_else(|| protocol_error("Invalid error packet"))?;
            // ERR carries no status flags; keep the last known transaction state.
            self.state = if self.status_flags
                & crate::protocol::server_status::SERVER_STATUS_IN_TRANS
                != 0
            {
                ConnectionState::InTransaction
            } else {
                ConnectionState::Ready
            };
            return Err(query_error(&err));
        }
        rows.push(self.parse_text_row(&payload, &columns));
    }
    self.state =
        if self.status_flags & crate::protocol::server_status::SERVER_STATUS_IN_TRANS != 0 {
            ConnectionState::InTransaction
        } else {
            ConnectionState::Ready
        };
    let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
    let sql_upper = sql.trim().to_uppercase();
    if sql_upper.starts_with("SHOW") {
        self.emit_show_results(sql, &col_names, &rows, elapsed_ms);
    } else {
        self.emit_query_timing(sql, elapsed_ms, rows.len());
    }
    self.emit_warnings(self.warnings);
    Ok(rows)
}
/// Parses a column-definition packet into a [`ColumnDef`].
///
/// The packet holds six length-encoded strings (catalog, schema, table,
/// org_table, name, org_name), then a length-encoded integer that is skipped,
/// then charset (u16), column length (u32), field type (u8), flags (u16) and
/// decimals (u8).
///
/// # Errors
/// A protocol error naming the first field that is missing.
#[allow(clippy::result_large_err)]
fn parse_column_def(&self, data: &[u8]) -> Result<ColumnDef, Error> {
    let mut reader = PacketReader::new(data);
    let [catalog, schema, table, org_table, name, org_name] = {
        let mut next_text = |field: &str| {
            reader
                .read_lenenc_string()
                .ok_or_else(|| protocol_error(format!("Missing {}", field)))
        };
        [
            next_text("catalog")?,
            next_text("schema")?,
            next_text("table")?,
            next_text("org_table")?,
            next_text("name")?,
            next_text("org_name")?,
        ]
    };
    // Length of the fixed-size trailer; its value is not needed.
    let _ = reader.read_lenenc_int();
    let charset = reader
        .read_u16_le()
        .ok_or_else(|| protocol_error("Missing charset"))?;
    let column_length = reader
        .read_u32_le()
        .ok_or_else(|| protocol_error("Missing column_length"))?;
    let raw_type = reader
        .read_u8()
        .ok_or_else(|| protocol_error("Missing column_type"))?;
    let flags = reader
        .read_u16_le()
        .ok_or_else(|| protocol_error("Missing flags"))?;
    let decimals = reader
        .read_u8()
        .ok_or_else(|| protocol_error("Missing decimals"))?;
    Ok(ColumnDef {
        catalog,
        schema,
        table,
        org_table,
        name,
        org_name,
        charset,
        column_length,
        column_type: FieldType::from_u8(raw_type),
        flags,
        decimals,
    })
}
fn parse_text_row(&self, data: &[u8], columns: &[ColumnDef]) -> Row {
let mut reader = PacketReader::new(data);
let mut values = Vec::with_capacity(columns.len());
for col in columns {
if reader.peek() == Some(0xFB) {
reader.skip(1);
values.push(Value::Null);
} else if let Some(data) = reader.read_lenenc_bytes() {
let is_unsigned = col.is_unsigned();
let value = decode_text_value(col.column_type, &data, is_unsigned);
values.push(value);
} else {
values.push(Value::Null);
}
}
let column_names: Vec<String> = columns.iter().map(|c| c.name.clone()).collect();
Row::new(column_names, values)
}
/// Runs `sql` and returns only its first row, if any.
#[allow(clippy::result_large_err)]
pub fn query_one_sync(&mut self, sql: &str, params: &[Value]) -> Result<Option<Row>, Error> {
    self.query_sync(sql, params)
        .map(|rows| rows.into_iter().next())
}

/// Runs a statement and returns the affected-row count from its OK packet.
#[allow(clippy::result_large_err)]
pub fn execute_sync(&mut self, sql: &str, params: &[Value]) -> Result<u64, Error> {
    self.query_sync(sql, params)
        .map(|_| self.affected_rows)
}

/// Runs an INSERT and returns the server-reported last insert id.
///
/// The `u64` id is reinterpreted with `as`, so ids above `i64::MAX` wrap.
#[allow(clippy::result_large_err)]
pub fn insert_sync(&mut self, sql: &str, params: &[Value]) -> Result<i64, Error> {
    self.query_sync(sql, params)
        .map(|_| self.last_insert_id as i64)
}
/// Sends `COM_PING` and waits for the reply.
///
/// # Errors
/// I/O failures, or a connection error if the reply is not an OK packet
/// (first byte `0x00`).
#[allow(clippy::result_large_err)]
pub fn ping(&mut self) -> Result<(), Error> {
    let mut request = PacketWriter::new();
    request.write_u8(Command::Ping as u8);
    // Each command starts a fresh packet sequence.
    self.sequence_id = 0;
    self.write_packet(request.as_bytes())?;
    let (reply, _) = self.read_packet()?;
    match reply.first() {
        Some(&0x00) => Ok(()),
        _ => Err(connection_error("Ping failed")),
    }
}
/// Sends `COM_QUIT` and marks the connection closed, consuming it.
///
/// Sending is best effort: a write failure is ignored, since the socket is
/// released when `self` is dropped at the end of this call. Always returns
/// `Ok(())`.
#[allow(clippy::result_large_err)]
pub fn close(mut self) -> Result<(), Error> {
    if self.state != ConnectionState::Closed {
        let mut quit = PacketWriter::new();
        quit.write_u8(Command::Quit as u8);
        self.sequence_id = 0;
        self.write_packet(quit.as_bytes()).ok();
        self.state = ConnectionState::Closed;
    }
    Ok(())
}
/// Reads one logical packet, reassembling payloads split across several
/// physical packets.
///
/// The first physical payload continues only if it is exactly
/// `MAX_PACKET_SIZE` bytes. Continuation chunks are read until one is
/// shorter than that. Afterwards `self.sequence_id` holds the sequence id to
/// use for the next outgoing packet.
///
/// Returns the payload and the sequence id of the first physical packet.
///
/// # Errors
/// `ConnectionErrorKind::Disconnected` on any read failure.
#[allow(clippy::result_large_err)]
fn read_packet(&mut self) -> Result<(Vec<u8>, u8), Error> {
    let disconnected = |what: &str, e: std::io::Error| {
        Error::Connection(ConnectionError {
            kind: ConnectionErrorKind::Disconnected,
            message: format!("{}: {}", what, e),
            source: Some(Box::new(e)),
        })
    };
    let mut header_bytes = [0u8; 4];
    self.stream
        .read_exact(&mut header_bytes)
        .map_err(|e| disconnected("Failed to read packet header", e))?;
    let first = PacketHeader::from_bytes(&header_bytes);
    let first_len = first.payload_length as usize;
    self.sequence_id = first.sequence_id.wrapping_add(1);
    let mut payload = vec![0u8; first_len];
    if first_len > 0 {
        self.stream
            .read_exact(&mut payload)
            .map_err(|e| disconnected("Failed to read packet payload", e))?;
    }
    let mut more = first_len == MAX_PACKET_SIZE;
    while more {
        self.stream
            .read_exact(&mut header_bytes)
            .map_err(|e| disconnected("Failed to read continuation header", e))?;
        let next = PacketHeader::from_bytes(&header_bytes);
        let chunk_len = next.payload_length as usize;
        self.sequence_id = next.sequence_id.wrapping_add(1);
        if chunk_len > 0 {
            let offset = payload.len();
            payload.resize(offset + chunk_len, 0);
            self.stream
                .read_exact(&mut payload[offset..])
                .map_err(|e| disconnected("Failed to read continuation payload", e))?;
        }
        more = chunk_len >= MAX_PACKET_SIZE;
    }
    Ok((payload, first.sequence_id))
}
/// Frames `payload` with the current sequence id, advances the sequence id,
/// then writes and flushes the packet.
///
/// # Errors
/// `ConnectionErrorKind::Disconnected` if writing or flushing fails.
#[allow(clippy::result_large_err)]
fn write_packet(&mut self, payload: &[u8]) -> Result<(), Error> {
    let seq = self.sequence_id;
    self.sequence_id = seq.wrapping_add(1);
    let writer = PacketWriter::new();
    let frame = writer.build_packet_from_payload(payload, seq);
    let disconnected = |what: &str, e: std::io::Error| {
        Error::Connection(ConnectionError {
            kind: ConnectionErrorKind::Disconnected,
            message: format!("{}: {}", what, e),
            source: Some(Box::new(e)),
        })
    };
    self.stream
        .write_all(&frame)
        .map_err(|e| disconnected("Failed to write packet", e))?;
    self.stream
        .flush()
        .map_err(|e| disconnected("Failed to flush stream", e))
}
}
#[cfg(feature = "console")]
impl ConsoleAware for MySqlConnection {
    /// Attaches a console for diagnostics, or detaches it when `None`.
    fn set_console(&mut self, console: Option<Arc<SqlModelConsole>>) {
        self.console = console;
    }
    /// Returns the attached console, if any.
    fn console(&self) -> Option<&Arc<SqlModelConsole>> {
        self.console.as_ref()
    }
}
#[cfg(feature = "console")]
impl MySqlConnection {
/// Reports a connection-setup stage on the console.
///
/// Plain mode prints only final stages. Rich mode prefixes an icon chosen
/// from the status text (success, failure or in-progress). JSON mode prints
/// nothing.
#[allow(dead_code)]
fn emit_connection_progress(&self, stage: &str, status: &str, is_final: bool) {
    let Some(console) = &self.console else {
        return;
    };
    match console.mode() {
        sqlmodel_console::OutputMode::Plain if is_final => {
            console.status(&format!("[MySQL] {}: {}", stage, status));
        }
        sqlmodel_console::OutputMode::Plain | sqlmodel_console::OutputMode::Json => {}
        sqlmodel_console::OutputMode::Rich => {
            let succeeded = status.starts_with("OK") || status.starts_with("Connected");
            let failed = status.starts_with("Error") || status.starts_with("Failed");
            let status_icon = match (succeeded, failed) {
                (true, _) => "✓",
                (false, true) => "✗",
                (false, false) => "…",
            };
            console.status(&format!(" {} {}: {}", status_icon, stage, status));
        }
    }
}
/// Reports a row-returning query's duration and row count on the console.
///
/// The SQL is cut to its first 60 characters, and `...` is appended only when
/// characters were actually cut off. Rich mode colours the timing with ANSI
/// codes: green below 10 ms, yellow below 100 ms, red otherwise. JSON mode
/// prints nothing.
fn emit_query_timing(&self, sql: &str, elapsed_ms: f64, row_count: usize) {
    if let Some(console) = &self.console {
        let mode = console.mode();
        // Count characters, not bytes. Comparing `sql.len()` would flag short
        // non-ASCII statements as truncated.
        let mut remaining = sql.chars();
        let sql_preview: String = remaining.by_ref().take(60).collect();
        let sql_display = if remaining.next().is_some() {
            format!("{}...", sql_preview)
        } else {
            sql_preview
        };
        match mode {
            sqlmodel_console::OutputMode::Plain => {
                console.status(&format!(
                    "[MySQL] Query: {:.2}ms, {} rows | {}",
                    elapsed_ms, row_count, sql_display
                ));
            }
            sqlmodel_console::OutputMode::Rich => {
                let time_color = if elapsed_ms < 10.0 {
                    "\x1b[32m"
                } else if elapsed_ms < 100.0 {
                    "\x1b[33m"
                } else {
                    "\x1b[31m"
                };
                console.status(&format!(
                    " ⏱ {}{:.2}ms\x1b[0m ({} rows) {}",
                    time_color, elapsed_ms, row_count, sql_display
                ));
            }
            sqlmodel_console::OutputMode::Json => {}
        }
    }
}
/// Reports a non-row-returning statement's duration and affected-row count on
/// the console.
///
/// The SQL is cut to its first 60 characters, and `...` is appended only when
/// characters were actually cut off. Rich mode colours the timing with ANSI
/// codes: green below 10 ms, yellow below 100 ms, red otherwise. JSON mode
/// prints nothing.
fn emit_execute_timing(&self, sql: &str, elapsed_ms: f64, affected_rows: u64) {
    if let Some(console) = &self.console {
        let mode = console.mode();
        // Count characters, not bytes, so multi-byte SQL is labelled correctly.
        let mut remaining = sql.chars();
        let sql_preview: String = remaining.by_ref().take(60).collect();
        let sql_display = if remaining.next().is_some() {
            format!("{}...", sql_preview)
        } else {
            sql_preview
        };
        match mode {
            sqlmodel_console::OutputMode::Plain => {
                console.status(&format!(
                    "[MySQL] Execute: {:.2}ms, {} affected | {}",
                    elapsed_ms, affected_rows, sql_display
                ));
            }
            sqlmodel_console::OutputMode::Rich => {
                let time_color = if elapsed_ms < 10.0 {
                    "\x1b[32m"
                } else if elapsed_ms < 100.0 {
                    "\x1b[33m"
                } else {
                    "\x1b[31m"
                };
                console.status(&format!(
                    " ⏱ {}{:.2}ms\x1b[0m ({} affected) {}",
                    time_color, elapsed_ms, affected_rows, sql_display
                ));
            }
            sqlmodel_console::OutputMode::Json => {}
        }
    }
}
/// Reports a non-zero server warning count on the console (nothing in JSON
/// mode or when no console is attached).
fn emit_warnings(&self, warning_count: u16) {
    let console = match (&self.console, warning_count) {
        (_, 0) | (None, _) => return,
        (Some(console), _) => console,
    };
    match console.mode() {
        sqlmodel_console::OutputMode::Plain => {
            console.warning(&format!("[MySQL] {} warning(s)", warning_count));
        }
        sqlmodel_console::OutputMode::Rich => {
            console.warning(&format!("{} warning(s)", warning_count));
        }
        sqlmodel_console::OutputMode::Json => {}
    }
}
/// Renders the result of a `SHOW` statement as an aligned text table on the
/// console, followed by a row-count and timing footer.
///
/// Non-`SHOW` statements fall back to a timing line. JSON mode prints
/// nothing. Column widths are measured in characters, because `{:width$}`
/// pads by character count; byte lengths would misalign non-ASCII cells.
/// Values beyond the number of named columns are not shown.
fn emit_show_results(&self, sql: &str, col_names: &[String], rows: &[Row], elapsed_ms: f64) {
    let Some(console) = &self.console else {
        return;
    };
    let mode = console.mode();
    if !sql.trim().to_uppercase().starts_with("SHOW") {
        self.emit_query_timing(sql, elapsed_ms, rows.len());
        return;
    }
    match mode {
        sqlmodel_console::OutputMode::Plain | sqlmodel_console::OutputMode::Rich => {
            // Format every cell once; reused for sizing and for output.
            let rendered: Vec<Vec<String>> = rows
                .iter()
                .map(|row| row.values().map(|v| format_value(v)).collect())
                .collect();
            let mut widths: Vec<usize> = col_names.iter().map(|n| n.chars().count()).collect();
            for cells in &rendered {
                for (width, cell) in widths.iter_mut().zip(cells) {
                    *width = (*width).max(cell.chars().count());
                }
            }
            let header: String = col_names
                .iter()
                .zip(&widths)
                .map(|(name, width)| format!("{:width$}", name, width = *width))
                .collect::<Vec<_>>()
                .join(" | ");
            let separator: String = widths
                .iter()
                .map(|w| "-".repeat(*w))
                .collect::<Vec<_>>()
                .join("-+-");
            console.status(&header);
            console.status(&separator);
            for cells in &rendered {
                let row_str: String = cells
                    .iter()
                    .zip(&widths)
                    .map(|(cell, width)| format!("{:width$}", cell, width = *width))
                    .collect::<Vec<_>>()
                    .join(" | ");
                console.status(&row_str);
            }
            console.status(&format!("({} rows, {:.2}ms)\n", rows.len(), elapsed_ms));
        }
        sqlmodel_console::OutputMode::Json => {}
    }
}
}
/// Renders a [`Value`] as short human-readable text for console tables.
///
/// Floats use six decimal places. Byte strings and arrays are summarised by
/// length. Temporal values are prefixed with their kind. UUIDs use the
/// lowercase 8-4-4-4-12 hex form.
#[cfg(feature = "console")]
fn format_value(value: &Value) -> String {
    match value {
        Value::Null => String::from("NULL"),
        Value::Bool(flag) => flag.to_string(),
        Value::TinyInt(n) => n.to_string(),
        Value::SmallInt(n) => n.to_string(),
        Value::Int(n) => n.to_string(),
        Value::BigInt(n) => n.to_string(),
        Value::Float(x) => format!("{:.6}", x),
        Value::Double(x) => format!("{:.6}", x),
        Value::Decimal(text) | Value::Text(text) => text.clone(),
        Value::Bytes(bytes) => format!("<{} bytes>", bytes.len()),
        Value::Date(d) => format!("date:{}", d),
        Value::Time(t) => format!("time:{}", t),
        Value::Timestamp(ts) => format!("ts:{}", ts),
        Value::TimestampTz(ts) => format!("tstz:{}", ts),
        Value::Uuid(bytes) => {
            let mut text = String::with_capacity(36);
            for (idx, byte) in bytes.iter().enumerate() {
                // Group boundaries of the 8-4-4-4-12 layout.
                if matches!(idx, 4 | 6 | 8 | 10) {
                    text.push('-');
                }
                text.push_str(&format!("{:02x}", byte));
            }
            text
        }
        Value::Json(json) => json.to_string(),
        Value::Array(items) => format!("[{} items]", items.len()),
        Value::Default => String::from("DEFAULT"),
    }
}
fn protocol_error(msg: impl Into<String>) -> Error {
Error::Protocol(ProtocolError {
message: msg.into(),
raw_data: None,
source: None,
})
}
fn auth_error(msg: impl Into<String>) -> Error {
Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Authentication,
message: msg.into(),
source: None,
})
}
fn connection_error(msg: impl Into<String>) -> Error {
Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Connect,
message: msg.into(),
source: None,
})
}
/// Decides whether RSA password encryption should use OAEP padding for a
/// server reporting `server_version`.
///
/// Only the leading run of digits and dots is considered, so suffixes such as
/// `-log` are ignored. Missing minor/patch components count as 0. Returns
/// `true` for 8.0.5 and later, and also when the major version cannot be
/// parsed.
fn mysql_server_uses_oaep(server_version: &str) -> bool {
    let numeric_len = server_version
        .find(|c: char| !(c.is_ascii_digit() || c == '.'))
        .unwrap_or(server_version.len());
    let mut components = server_version[..numeric_len]
        .split('.')
        .filter(|part| !part.is_empty())
        .map(|part| part.parse::<u64>().ok());
    let Some(Some(major)) = components.next() else {
        return true;
    };
    let minor = components.next().flatten().unwrap_or(0);
    let patch = components.next().flatten().unwrap_or(0);
    (major, minor, patch) >= (8, 0, 5)
}
fn query_error(err: &ErrPacket) -> Error {
let kind = if err.is_duplicate_key() || err.is_foreign_key_violation() {
QueryErrorKind::Constraint
} else {
QueryErrorKind::Syntax
};
Error::Query(QueryError {
kind,
message: err.error_message.clone(),
sqlstate: Some(err.sql_state.clone()),
sql: None,
detail: None,
hint: None,
position: None,
source: None,
})
}
fn query_error_msg(msg: impl Into<String>) -> Error {
Error::Query(QueryError {
kind: QueryErrorKind::Syntax,
message: msg.into(),
sqlstate: None,
sql: None,
detail: None,
hint: None,
position: None,
source: None,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_state_default() {
        assert_eq!(ConnectionState::Disconnected, ConnectionState::Disconnected);
    }

    #[test]
    fn test_error_helpers() {
        let err = protocol_error("test error");
        assert!(matches!(err, Error::Protocol(_)));
        let err = auth_error("auth failed");
        assert!(matches!(err, Error::Connection(_)));
        let err = connection_error("connection failed");
        assert!(matches!(err, Error::Connection(_)));
    }

    #[test]
    fn test_query_error_duplicate_key() {
        let err_packet = ErrPacket {
            error_code: 1062,
            sql_state: "23000".to_string(),
            error_message: "Duplicate entry".to_string(),
        };
        let err = query_error(&err_packet);
        assert!(matches!(err, Error::Query(_)), "Expected query error");
        let Error::Query(q) = err else { return };
        assert_eq!(q.kind, QueryErrorKind::Constraint);
    }

    /// Non-constraint server errors map to `Syntax` and keep message and SQLSTATE.
    #[test]
    fn test_query_error_other_codes_are_syntax() {
        let err_packet = ErrPacket {
            error_code: 1064,
            sql_state: "42000".to_string(),
            error_message: "You have an error in your SQL syntax".to_string(),
        };
        let Error::Query(q) = query_error(&err_packet) else {
            panic!("Expected query error");
        };
        assert_eq!(q.kind, QueryErrorKind::Syntax);
        assert_eq!(q.sqlstate.as_deref(), Some("42000"));
        assert_eq!(q.message, "You have an error in your SQL syntax");
    }

    /// Client-side query errors carry no SQLSTATE.
    #[test]
    fn test_query_error_msg_has_no_sqlstate() {
        let Error::Query(q) = query_error_msg("LOCAL INFILE not supported") else {
            panic!("Expected query error");
        };
        assert_eq!(q.kind, QueryErrorKind::Syntax);
        assert!(q.sqlstate.is_none());
    }

    /// OAEP padding applies from MySQL 8.0.5 onwards, and when the version is unparseable.
    #[test]
    fn test_mysql_server_uses_oaep_versions() {
        assert!(mysql_server_uses_oaep("8.0.5"));
        assert!(mysql_server_uses_oaep("8.0.36-log"));
        assert!(mysql_server_uses_oaep("9.1.0"));
        assert!(!mysql_server_uses_oaep("8.0.4"));
        assert!(!mysql_server_uses_oaep("5.7.44"));
        assert!(!mysql_server_uses_oaep("5.5.5-10.6.12-MariaDB"));
        assert!(mysql_server_uses_oaep(""));
        assert!(mysql_server_uses_oaep("unknown"));
    }

    #[cfg(feature = "console")]
    mod console_tests {
        use super::*;
        use sqlmodel_console::{ConsoleAware, OutputMode, SqlModelConsole};

        fn assert_console_aware<T: ConsoleAware>() {}

        #[test]
        fn test_console_aware_trait_impl() {
            let config = MySqlConfig::new()
                .host("localhost")
                .port(13306)
                .user("test")
                .password("test");
            assert_console_aware::<MySqlConnection>();
            assert_eq!(config.host, "localhost");
            assert_eq!(config.port, 13306);
        }

        #[test]
        fn test_format_value_all_types() {
            assert_eq!(format_value(&Value::Null), "NULL");
            assert_eq!(format_value(&Value::Bool(true)), "true");
            assert_eq!(format_value(&Value::Bool(false)), "false");
            assert_eq!(format_value(&Value::TinyInt(42)), "42");
            assert_eq!(format_value(&Value::SmallInt(1000)), "1000");
            assert_eq!(format_value(&Value::Int(123_456)), "123456");
            assert_eq!(format_value(&Value::BigInt(9_999_999_999)), "9999999999");
            assert!(format_value(&Value::Float(1.5)).starts_with("1.5"));
            assert!(format_value(&Value::Double(1.234_567_890)).starts_with("1.23456"));
            assert_eq!(
                format_value(&Value::Decimal("123.45".to_string())),
                "123.45"
            );
            assert_eq!(format_value(&Value::Text("hello".to_string())), "hello");
            assert_eq!(format_value(&Value::Bytes(vec![1, 2, 3])), "<3 bytes>");
            assert!(format_value(&Value::Date(19000)).contains("date:"));
            assert!(format_value(&Value::Time(43_200_000_000)).contains("time:"));
            assert!(format_value(&Value::Timestamp(1_700_000_000_000_000)).contains("ts:"));
            assert!(format_value(&Value::TimestampTz(1_700_000_000_000_000)).contains("tstz:"));
            let uuid = [0u8; 16];
            let uuid_str = format_value(&Value::Uuid(uuid));
            assert_eq!(uuid_str, "00000000-0000-0000-0000-000000000000");
            let json = serde_json::json!({"key": "value"});
            let json_str = format_value(&Value::Json(json));
            assert!(json_str.contains("key"));
            let arr = vec![Value::Int(1), Value::Int(2)];
            assert_eq!(format_value(&Value::Array(arr)), "[2 items]");
        }

        #[test]
        fn test_plain_mode_output_format() {
            let plain_console = SqlModelConsole::with_mode(OutputMode::Plain);
            assert!(plain_console.is_plain());
            let rich_console = SqlModelConsole::with_mode(OutputMode::Rich);
            assert!(rich_console.is_rich());
            let json_console = SqlModelConsole::with_mode(OutputMode::Json);
            assert!(json_console.is_json());
        }

        #[test]
        fn test_console_mode_detection() {
            let console = SqlModelConsole::with_mode(OutputMode::Plain);
            assert!(console.is_plain());
            assert!(!console.is_rich());
            assert!(!console.is_json());
            assert_eq!(console.mode(), OutputMode::Plain);
        }

        #[test]
        fn test_format_value_uuid() {
            let uuid: [u8; 16] = [
                0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc,
                0xde, 0xf0,
            ];
            let result = format_value(&Value::Uuid(uuid));
            assert_eq!(result, "12345678-9abc-def0-1234-56789abcdef0");
        }

        #[test]
        fn test_format_value_nested_json() {
            let json = serde_json::json!({
                "users": [
                    {"name": "Alice", "age": 30},
                    {"name": "Bob", "age": 25}
                ]
            });
            let result = format_value(&Value::Json(json));
            assert!(result.contains("users"));
            assert!(result.contains("Alice"));
        }
    }
}