use bytes::{BufMut, BytesMut};
use sha1::{Digest, Sha1};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::{debug, info, warn};
use crate::YamlBaseError;
use crate::config::Config;
use crate::database::{Storage, Value};
use crate::protocol::mysql_binary::MySqlBinaryProtocol;
use crate::protocol::mysql_caching_sha2::{CACHING_SHA2_PLUGIN_NAME, CachingSha2Auth};
use crate::protocol::mysql_information_schema::MySqlInformationSchema;
use crate::protocol::mysql_system::MySqlSystemVariables;
use crate::sql::{QueryExecutor, parse_sql};
// Handshake identity advertised to clients.
const PROTOCOL_VERSION: u8 = 10;
const SERVER_VERSION: &str = "8.0.35-yamlbase";
const AUTH_PLUGIN_NAME: &str = "mysql_native_password";

// Command bytes (first byte of a client command packet).
const COM_QUIT: u8 = 1;
const COM_INIT_DB: u8 = 2;
const COM_QUERY: u8 = 3;
const COM_PING: u8 = 14;
const COM_STMT_PREPARE: u8 = 22;
const COM_STMT_EXECUTE: u8 = 23;
const COM_STMT_CLOSE: u8 = 25;
const COM_STMT_RESET: u8 = 26;

// Capability flag bits exchanged during the handshake.
const CLIENT_LONG_PASSWORD: u32 = 1 << 0;
const CLIENT_FOUND_ROWS: u32 = 1 << 1;
const CLIENT_LONG_FLAG: u32 = 1 << 2;
const CLIENT_CONNECT_WITH_DB: u32 = 1 << 3;
const CLIENT_PROTOCOL_41: u32 = 1 << 9;
const CLIENT_SECURE_CONNECTION: u32 = 1 << 15;
const CLIENT_PLUGIN_AUTH: u32 = 1 << 19;
const CLIENT_DEPRECATE_EOF: u32 = 1 << 24;

// Column type used for every column in text result sets.
const MYSQL_TYPE_VAR_STRING: u8 = 0xfd;

// Server status bit reported in OK/EOF packets.
const SERVER_STATUS_AUTOCOMMIT: u16 = 1 << 1;
/// MySQL wire-protocol front end.
///
/// `handle_connection` performs the handshake and authentication, then
/// answers text-protocol queries (with special handling for SHOW, DESCRIBE,
/// system-variable and information_schema queries) and forwards
/// prepared-statement commands to the binary protocol handler.
pub struct MySqlProtocol {
// Server configuration; the client's username/password are checked against it.
config: Arc<Config>,
// Executes parsed SQL statements; also gives access to the backing storage.
executor: QueryExecutor,
// Initialized to an empty string in `new`; not read by the code shown here.
_database_name: String,
// Answers `@@variable`, `SET ...` and `SHOW VARIABLES` queries.
system_variables: MySqlSystemVariables,
// Handles COM_STMT_PREPARE / EXECUTE / CLOSE / RESET commands.
binary_protocol: MySqlBinaryProtocol,
// information_schema views, populated in `new` from the storage's tables.
information_schema: MySqlInformationSchema,
}
/// Mutable protocol state for a single client connection.
struct ConnectionState {
// Sequence id for the next packet written: `read_packet` sets it to the received id + 1, `write_packet` uses it and then increments it.
sequence_id: u8,
// Capability flags for the session; `send_handshake` initializes them to the set the server advertises, and result-set framing checks CLIENT_DEPRECATE_EOF here.
capabilities: u32,
// 20-byte scramble from `generate_auth_data`; sent in the handshake (bytes 0..8 and 8..20) and used to verify mysql_native_password responses.
auth_data: Vec<u8>,
// Auth plugin name reported in the client's handshake response, if any.
client_auth_plugin: Option<String>,
// Set once a COM_STMT_* command has been handled; not read by the code shown here.
using_binary_protocol: bool,
}
impl Default for ConnectionState {
fn default() -> Self {
Self {
sequence_id: 0,
capabilities: 0,
auth_data: generate_auth_data(),
client_auth_plugin: None,
using_binary_protocol: false,
}
}
}
impl MySqlProtocol {
/// Builds a protocol handler over `storage`.
///
/// Creates the query executor and registers every user table currently in
/// the database with the information_schema emulation.
///
/// # Errors
/// Propagates any error from `QueryExecutor::new`.
pub async fn new(config: Arc<Config>, storage: Arc<Storage>) -> crate::Result<Self> {
    let executor = QueryExecutor::new(Arc::clone(&storage)).await?;
    let mut information_schema = MySqlInformationSchema::new(Arc::clone(&storage));

    // Hold the read lock only while copying the table definitions.
    {
        let database = storage.database();
        let snapshot = database.read().await;
        for (name, table) in &snapshot.tables {
            information_schema.add_user_table(name, &table.columns);
        }
    }

    Ok(Self {
        config,
        executor,
        _database_name: String::new(),
        system_variables: MySqlSystemVariables::new(),
        binary_protocol: MySqlBinaryProtocol::new(),
        information_schema,
    })
}
/// Serves one client connection from handshake to disconnect.
///
/// Sends the server greeting and parses the client's handshake response.
/// It then authenticates against the configured credentials, using
/// mysql_native_password or a caching_sha2_password auth switch. Finally it
/// processes command packets until the client sends `COM_QUIT` or the
/// connection can no longer be read.
///
/// # Errors
/// Returns an error if the handshake cannot be exchanged or a response cannot
/// be written. A failed authentication is reported to the client as an ERR
/// packet and yields `Ok(())`.
pub async fn handle_connection(&mut self, mut stream: TcpStream) -> crate::Result<()> {
    info!("New MySQL connection");
    let mut state = ConnectionState::default();
    self.send_handshake(&mut stream, &mut state).await?;

    let response_packet = self.read_packet(&mut stream, &mut state).await?;
    let (username, auth_response, _database, client_plugin) =
        self.parse_handshake_response(&response_packet)?;
    state.client_auth_plugin = client_plugin;

    // `send_handshake` recorded the capabilities the *server* offers. The
    // session may only use features the client also requested (its flags are
    // the first four bytes of the handshake response, little-endian).
    // Otherwise e.g. CLIENT_DEPRECATE_EOF framing would be used with clients
    // that expect classic EOF packets.
    if let Some(&[b0, b1, b2, b3]) = response_packet.get(..4) {
        state.capabilities &= u32::from_le_bytes([b0, b1, b2, b3]);
        debug!("Negotiated capabilities: 0x{:08x}", state.capabilities);
    }

    debug!(
        "Authentication check - username: {}, expected: {}",
        username, self.config.username
    );
    if username != self.config.username {
        debug!("Username mismatch");
        self.send_error(&mut stream, &mut state, 1045, "28000", "Access denied")
            .await?;
        return Ok(());
    }

    let expected = compute_auth_response(&self.config.password, &state.auth_data);
    // Never log the configured password or password-derived hashes; lengths
    // are enough to diagnose a mismatch.
    debug!(
        "Password check - auth_response len: {}, expected len: {}",
        auth_response.len(),
        expected.len()
    );

    let client_wants_caching =
        state.client_auth_plugin.as_deref() == Some(CACHING_SHA2_PLUGIN_NAME);
    if client_wants_caching || auth_response.is_empty() {
        debug!("Client requested caching_sha2_password or sent empty auth");
        let caching_auth = CachingSha2Auth::new(generate_auth_data());
        caching_auth
            .send_auth_switch_request(&mut stream, &mut state.sequence_id)
            .await?;
        let auth_switch_response = self.read_packet(&mut stream, &mut state).await?;
        let auth_success = caching_auth
            .authenticate(
                &mut stream,
                &mut state.sequence_id,
                &username,
                "",
                &self.config.username,
                &self.config.password,
                auth_switch_response,
            )
            .await?;
        if !auth_success {
            self.send_error(&mut stream, &mut state, 1045, "28000", "Access denied")
                .await?;
            return Ok(());
        }
    } else if auth_response != expected {
        debug!(
            "Password mismatch (auth_response len: {}, expected len: {})",
            auth_response.len(),
            expected.len()
        );
        self.send_error(&mut stream, &mut state, 1045, "28000", "Access denied")
            .await?;
        return Ok(());
    }

    self.send_ok(&mut stream, &mut state, 0, 0).await?;
    info!("MySQL authentication successful, entering command loop");

    loop {
        let packet = match self.read_packet(&mut stream, &mut state).await {
            Ok(p) => p,
            Err(e) => {
                debug!("Error reading packet: {}", e);
                break;
            }
        };
        if !self.handle_command(&mut stream, &mut state, &packet).await? {
            break;
        }
    }
    Ok(())
}

/// Dispatches a single command packet received after authentication.
///
/// Returns `Ok(false)` when the client sent `COM_QUIT` (the caller should stop
/// reading) and `Ok(true)` otherwise. An error is returned when writing a
/// response fails. The exception is the error report after a failed query,
/// which is best-effort.
async fn handle_command(
    &mut self,
    stream: &mut TcpStream,
    state: &mut ConnectionState,
    packet: &[u8],
) -> crate::Result<bool> {
    let Some((&command, payload)) = packet.split_first() else {
        // An empty packet carries no command; ignore it.
        return Ok(true);
    };
    debug!("Received command: 0x{:02x}", command);

    match command {
        COM_QUERY => match std::str::from_utf8(payload) {
            Ok(query) => {
                if let Err(e) = self.handle_query(stream, state, query).await {
                    warn!("Error handling query '{}': {}", query, e);
                    let _ = self
                        .send_error(stream, state, 1146, "42S02", &e.to_string())
                        .await;
                }
            }
            // Reject this one command instead of tearing down the session.
            Err(_) => {
                self.send_error(stream, state, 1300, "HY000", "Invalid UTF-8 in query")
                    .await?;
            }
        },
        COM_QUIT => {
            info!("Client disconnected");
            return Ok(false);
        }
        COM_PING => {
            self.send_ok(stream, state, 0, 0).await?;
        }
        COM_INIT_DB => {
            if std::str::from_utf8(payload).is_ok() {
                self.send_ok(stream, state, 0, 0).await?;
            } else {
                self.send_error(
                    stream,
                    state,
                    1300,
                    "HY000",
                    "Invalid UTF-8 in database name",
                )
                .await?;
            }
        }
        COM_STMT_PREPARE | COM_STMT_EXECUTE | COM_STMT_CLOSE | COM_STMT_RESET => {
            match self
                .binary_protocol
                .handle_binary_command(command, payload, &mut *stream, &mut state.sequence_id)
                .await
            {
                Ok(true) => {
                    state.using_binary_protocol = true;
                }
                Ok(false) => {
                    debug!(
                        "Binary protocol handler declined command: 0x{:02x}",
                        command
                    );
                    self.send_error(stream, state, 1047, "08S01", "Unknown command")
                        .await?;
                }
                Err(e) => {
                    warn!("Binary protocol error: {}", e);
                    self.send_error(stream, state, 1047, "08S01", &e.to_string())
                        .await?;
                }
            }
        }
        _ => {
            debug!("Unhandled command: 0x{:02x}", command);
            self.send_error(stream, state, 1047, "08S01", "Unknown command")
                .await?;
        }
    }
    Ok(true)
}
/// Sends the initial handshake (protocol version 10) greeting.
///
/// Also records the server's advertised capability flags in
/// `state.capabilities`.
async fn send_handshake(
    &self,
    stream: &mut TcpStream,
    state: &mut ConnectionState,
) -> crate::Result<()> {
    state.capabilities = CLIENT_LONG_PASSWORD
        | CLIENT_FOUND_ROWS
        | CLIENT_LONG_FLAG
        | CLIENT_CONNECT_WITH_DB
        | CLIENT_PROTOCOL_41
        | CLIENT_SECURE_CONNECTION
        | CLIENT_PLUGIN_AUTH
        | CLIENT_DEPRECATE_EOF;
    // Little-endian bytes: [0..2] is the lower 16-bit half, [2..4] the upper.
    let flag_bytes = state.capabilities.to_le_bytes();

    let mut greeting = BytesMut::with_capacity(96);
    greeting.put_u8(PROTOCOL_VERSION);
    greeting.put_slice(SERVER_VERSION.as_bytes());
    greeting.put_u8(0); // version string terminator
    greeting.put_u32_le(1); // connection id
    greeting.put_slice(&state.auth_data[..8]); // scramble, first 8 bytes
    greeting.put_u8(0); // filler
    greeting.put_slice(&flag_bytes[..2]); // capability flags, lower half
    greeting.put_u8(33); // character set / collation id
    greeting.put_u16_le(SERVER_STATUS_AUTOCOMMIT);
    greeting.put_slice(&flag_bytes[2..]); // capability flags, upper half
    greeting.put_u8(21); // length of auth-plugin data (20-byte scramble + NUL)
    greeting.put_slice(&[0u8; 10]); // reserved
    greeting.put_slice(&state.auth_data[8..20]); // scramble, remaining 12 bytes
    greeting.put_u8(0);
    greeting.put_slice(AUTH_PLUGIN_NAME.as_bytes());
    greeting.put_u8(0);

    self.write_packet(stream, state, &greeting).await
}
#[allow(clippy::type_complexity)]
fn parse_handshake_response(
&self,
packet: &[u8],
) -> crate::Result<(String, Vec<u8>, Option<String>, Option<String>)> {
debug!("Parsing handshake response, packet len: {}", packet.len());
if packet.len() < 32 {
return Err(YamlBaseError::Protocol(
"Handshake response too short".to_string(),
));
}
let mut pos = 0;
let client_flags = u32::from_le_bytes([
packet[pos],
packet[pos + 1],
packet[pos + 2],
packet[pos + 3],
]);
debug!("Client capabilities: 0x{:08x}", client_flags);
pos += 4;
pos += 4;
pos += 1;
pos += 23;
let username_end = packet[pos..]
.iter()
.position(|&b| b == 0)
.ok_or_else(|| YamlBaseError::Protocol("Invalid handshake response".to_string()))?;
let username = std::str::from_utf8(&packet[pos..pos + username_end])
.map_err(|_| YamlBaseError::Protocol("Invalid UTF-8 in username".to_string()))?
.to_string();
debug!("Username: {}", username);
pos += username_end + 1;
if pos >= packet.len() {
return Ok((username, Vec::new(), None, None));
}
let auth_len = packet[pos] as usize;
debug!("Auth response length: {}", auth_len);
pos += 1;
let auth_response = if auth_len > 0 && pos + auth_len <= packet.len() {
packet[pos..pos + auth_len].to_vec()
} else {
debug!("Auth response empty or invalid length");
Vec::new()
};
pos += auth_len;
let database = if pos < packet.len() {
let db_end = packet[pos..]
.iter()
.position(|&b| b == 0)
.unwrap_or(packet.len() - pos);
if db_end > 0 {
Some(
std::str::from_utf8(&packet[pos..pos + db_end])
.map_err(|_| {
YamlBaseError::Protocol("Invalid UTF-8 in database".to_string())
})?
.to_string(),
)
} else {
None
}
} else {
None
};
if let Some(ref db) = database {
pos += db.len() + 1;
}
let auth_plugin = if pos < packet.len() {
let plugin_end = packet[pos..]
.iter()
.position(|&b| b == 0)
.unwrap_or(packet.len() - pos);
if plugin_end > 0 {
Some(
std::str::from_utf8(&packet[pos..pos + plugin_end])
.map_err(|_| {
YamlBaseError::Protocol("Invalid UTF-8 in auth plugin".to_string())
})?
.to_string(),
)
} else {
None
}
} else {
None
};
debug!("Client auth plugin: {:?}", auth_plugin);
Ok((username, auth_response, database, auth_plugin))
}
/// Answers one COM_QUERY.
///
/// System-variable, SET, SHOW, DESCRIBE and information_schema queries are
/// routed to dedicated handlers. Everything else is rewritten from the MySQL
/// dialect, parsed, and executed statement by statement, with one response
/// per statement.
async fn handle_query(
    &mut self,
    stream: &mut TcpStream,
    state: &mut ConnectionState,
    query: &str,
) -> crate::Result<()> {
    use sqlparser::ast::Statement;

    let sql = query.trim();
    debug!("Handling query: {}", sql);
    if sql.is_empty() {
        debug!("Empty query received");
        return self
            .send_error(stream, state, 1064, "42000", "Syntax error: Empty query")
            .await;
    }

    if self.system_variables.is_system_variable_query(sql) {
        return self.handle_system_variable_query(stream, state, sql).await;
    }
    if matches!(self.system_variables.handle_set_command(sql), Ok(true)) {
        return self.send_ok(stream, state, 0, 0).await;
    }

    let upper = sql.to_uppercase();
    if upper.starts_with("SHOW ") {
        return self.handle_show_command(stream, state, sql).await;
    }
    if upper.starts_with("DESCRIBE ") || upper.starts_with("DESC ") {
        return self.handle_describe_command(stream, state, sql).await;
    }
    if upper.contains("INFORMATION_SCHEMA") {
        return self
            .handle_information_schema_query(stream, state, sql)
            .await;
    }

    let rewritten = self.preprocess_mysql_query(sql);
    let statements = match parse_sql(&rewritten) {
        Ok(parsed) => parsed,
        Err(e) => {
            return self
                .send_error(
                    stream,
                    state,
                    1064,
                    "42000",
                    &format!("Syntax error: {}", e),
                )
                .await;
        }
    };

    for statement in statements {
        debug!("Executing statement: {:?}", statement);
        let is_transaction_command = matches!(
            statement,
            Statement::StartTransaction { .. }
                | Statement::Commit { .. }
                | Statement::Rollback { .. }
        );
        match self.executor.execute(&statement).await {
            Err(e) => {
                debug!("Query execution error: {}", e);
                self.send_error(stream, state, 1146, "42S02", &e.to_string())
                    .await?;
            }
            Ok(result) => {
                debug!(
                    "Query executed successfully. Result: {} columns, {} rows",
                    result.columns.len(),
                    result.rows.len()
                );
                let nothing_to_return = result.columns.is_empty() && result.rows.is_empty();
                if is_transaction_command || nothing_to_return {
                    debug!("Sending OK packet for transaction command or empty result");
                    self.send_ok(stream, state, 0, 0).await?;
                } else {
                    self.send_query_result(stream, state, &result).await?;
                }
            }
        }
    }
    Ok(())
}
fn preprocess_mysql_query(&self, query: &str) -> String {
use once_cell::sync::Lazy;
use regex::Regex;
let mut result = query.to_string();
if result.contains('`') {
result = result.replace('`', "");
debug!("Removed backticks: {}", result);
}
static MYSQL_FUNCTION_RE: Lazy<Result<Regex, regex::Error>> =
Lazy::new(|| Regex::new(r"(?i)\bIFNULL\s*\(\s*([^,]+)\s*,\s*([^)]+)\s*\)"));
if let Ok(ref re) = *MYSQL_FUNCTION_RE {
result = re.replace_all(&result, "COALESCE($1, $2)").to_string();
}
static LIMIT_RE: Lazy<Result<Regex, regex::Error>> =
Lazy::new(|| Regex::new(r"(?i)\bLIMIT\s+(\d+)\s*,\s*(\d+)\b"));
if let Ok(ref re) = *LIMIT_RE {
result = re.replace_all(&result, "LIMIT $2 OFFSET $1").to_string();
}
debug!("Preprocessed query: {} -> {}", query, result);
result
}
async fn handle_system_variable_query(
&self,
stream: &mut TcpStream,
state: &mut ConnectionState,
query: &str,
) -> crate::Result<()> {
let query_upper = query.to_uppercase();
if query_upper.starts_with("SHOW VARIABLES")
|| query_upper.starts_with("SHOW SESSION VARIABLES")
|| query_upper.starts_with("SHOW GLOBAL VARIABLES")
{
let pattern = if let Some(like_pos) = query_upper.find(" LIKE ") {
let pattern_part = &query[like_pos + 6..].trim();
Some(pattern_part.trim_matches('\'').trim_matches('"'))
} else {
None
};
let result = self.system_variables.handle_show_variables(pattern);
self.send_query_result(stream, state, &result).await
} else if query.contains("@@") {
use regex::Regex;
let re = Regex::new(r"@@(\w+)").unwrap();
let variables: Vec<&str> = re.find_iter(query).map(|m| m.as_str()).collect();
if variables.len() == 1 {
let var_name = variables[0].trim_start_matches("@@");
let result = self.system_variables.handle_variable_query(var_name);
self.send_query_result(stream, state, &result).await
} else if variables.len() > 1 {
let var_names: Vec<&str> = variables
.iter()
.map(|v| v.trim_start_matches("@@"))
.collect();
let result = self.system_variables.handle_multiple_variables(&var_names);
self.send_query_result(stream, state, &result).await
} else {
self.send_error(stream, state, 1064, "42000", "Invalid variable query")
.await
}
} else {
self.send_error(
stream,
state,
1064,
"42000",
"Unknown system variable query",
)
.await
}
}
/// Routes a `SHOW ...` statement to its handler by (case-insensitive) prefix.
///
/// Unrecognised SHOW statements get a 1064 syntax error.
async fn handle_show_command(
    &mut self,
    stream: &mut TcpStream,
    state: &mut ConnectionState,
    query: &str,
) -> crate::Result<()> {
    let upper = query.to_uppercase();
    let starts = |prefix: &str| upper.starts_with(prefix);

    if starts("SHOW FULL TABLES") {
        return self.handle_show_full_tables(stream, state, query).await;
    }
    if starts("SHOW TABLES") {
        return self.handle_show_tables(stream, state, query).await;
    }
    if starts("SHOW CREATE TABLE") {
        return self.handle_show_create_table(stream, state, query).await;
    }
    if starts("SHOW DATABASES") || starts("SHOW SCHEMAS") {
        return self.handle_show_databases(stream, state).await;
    }
    if starts("SHOW COLUMNS") || starts("SHOW FIELDS") {
        return self.handle_show_columns(stream, state, query).await;
    }
    // "SHOW INDEX" also covers "SHOW INDEXES".
    if starts("SHOW INDEX") || starts("SHOW INDEXES") || starts("SHOW KEYS") {
        return self.handle_show_indexes(stream, state, query).await;
    }
    if starts("SHOW STATUS") {
        return self.handle_show_status(stream, state).await;
    }
    if starts("SHOW ENGINES") {
        return self.handle_show_engines(stream, state).await;
    }
    if starts("SHOW COLLATION") {
        return self.handle_show_collation(stream, state).await;
    }
    if starts("SHOW CHARACTER SET") {
        return self.handle_show_character_set(stream, state).await;
    }

    debug!("Unhandled SHOW command: {}", query);
    self.send_error(stream, state, 1064, "42000", "Unknown SHOW command")
        .await
}
/// Answers `SHOW FULL TABLES` with one `(name, "BASE TABLE")` row per table.
///
/// The table list and the database name are read under a single lock
/// acquisition, so both come from the same snapshot. This matches
/// `handle_show_tables`, where the original re-locked the database just to
/// fetch the name.
async fn handle_show_full_tables(
    &mut self,
    stream: &mut TcpStream,
    state: &mut ConnectionState,
    _query: &str,
) -> crate::Result<()> {
    debug!("Handling SHOW FULL TABLES");
    let storage = self.executor.storage();
    let db = storage.database();
    let db_guard = db.read().await;
    let mut rows = Vec::new();
    for (table_name, _table) in &db_guard.tables {
        rows.push(vec![
            Value::Text(table_name.clone()),
            Value::Text("BASE TABLE".to_string()),
        ]);
    }
    let db_name = db_guard.name.clone();
    drop(db_guard);

    let result = crate::sql::executor::QueryResult {
        columns: vec![format!("Tables_in_{}", db_name), "Table_type".to_string()],
        column_types: vec![
            crate::yaml::schema::SqlType::Text,
            crate::yaml::schema::SqlType::Text,
        ],
        rows,
    };
    self.send_query_result(stream, state, &result).await
}
async fn handle_show_tables(
&mut self,
stream: &mut TcpStream,
state: &mut ConnectionState,
_query: &str,
) -> crate::Result<()> {
debug!("Handling SHOW TABLES");
let storage = self.executor.storage();
let db = storage.database();
let db_guard = db.read().await;
let mut rows = Vec::new();
for (table_name, _table) in &db_guard.tables {
rows.push(vec![Value::Text(table_name.clone())]);
}
let db_name = db_guard.name.clone();
drop(db_guard);
let result = crate::sql::executor::QueryResult {
columns: vec![format!("Tables_in_{}", db_name)],
column_types: vec![crate::yaml::schema::SqlType::Text],
rows,
};
self.send_query_result(stream, state, &result).await
}
/// Answers `SHOW CREATE TABLE <name>` with a synthesized MySQL DDL statement.
///
/// The table name is the fourth whitespace-separated token, with a trailing
/// `;` and surrounding quote characters removed. DECIMAL columns report
/// their declared precision and scale. An unknown table yields error 1146.
async fn handle_show_create_table(
    &mut self,
    stream: &mut TcpStream,
    state: &mut ConnectionState,
    query: &str,
) -> crate::Result<()> {
    use crate::yaml::schema::SqlType;

    debug!("Handling SHOW CREATE TABLE");
    let parts: Vec<&str> = query.split_whitespace().collect();
    if parts.len() < 4 {
        return self
            .send_error(
                stream,
                state,
                1064,
                "42000",
                "Invalid SHOW CREATE TABLE syntax",
            )
            .await;
    }
    // A trailing statement terminator (`users;`) is not part of the name.
    let table_name = parts[3]
        .trim_end_matches(';')
        .trim_matches('`')
        .trim_matches('"')
        .trim_matches('\'');

    let storage = self.executor.storage();
    let db = storage.database();
    let db_guard = db.read().await;
    if let Some(table) = db_guard.tables.get(table_name) {
        let mut create_sql = format!("CREATE TABLE `{}` (\n", table_name);
        for (i, column) in table.columns.iter().enumerate() {
            if i > 0 {
                create_sql.push_str(",\n");
            }
            let mysql_type = match &column.sql_type {
                SqlType::Integer => "int(11)".to_string(),
                SqlType::BigInt => "bigint(20)".to_string(),
                SqlType::Boolean => "tinyint(1)".to_string(),
                SqlType::Float => "float".to_string(),
                SqlType::Double => "double".to_string(),
                // Report the declared precision/scale, not a fixed (10,2).
                SqlType::Decimal(precision, scale) => {
                    format!("decimal({},{})", precision, scale)
                }
                SqlType::Date => "date".to_string(),
                SqlType::Time => "time".to_string(),
                SqlType::Timestamp => "timestamp".to_string(),
                SqlType::Text => "text".to_string(),
                SqlType::Varchar(len) => format!("varchar({})", len),
                SqlType::Char(len) => format!("char({})", len),
                SqlType::Json => "json".to_string(),
                SqlType::Uuid => "char(36)".to_string(),
            };
            create_sql.push_str(&format!(
                "  `{}` {} {}",
                column.name,
                mysql_type,
                if column.nullable {
                    "DEFAULT NULL"
                } else {
                    "NOT NULL"
                }
            ));
        }
        create_sql.push_str("\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4");

        let result = crate::sql::executor::QueryResult {
            columns: vec!["Table".to_string(), "Create Table".to_string()],
            column_types: vec![SqlType::Text, SqlType::Text],
            rows: vec![vec![
                Value::Text(table_name.to_string()),
                Value::Text(create_sql),
            ]],
        };
        drop(db_guard);
        self.send_query_result(stream, state, &result).await
    } else {
        drop(db_guard);
        self.send_error(
            stream,
            state,
            1146,
            "42S02",
            &format!("Table '{}' doesn't exist", table_name),
        )
        .await
    }
}
async fn handle_show_databases(
&mut self,
stream: &mut TcpStream,
state: &mut ConnectionState,
) -> crate::Result<()> {
debug!("Handling SHOW DATABASES");
let storage = self.executor.storage();
let db = storage.database();
let db_guard = db.read().await;
let db_name = db_guard.name.clone();
drop(db_guard);
let result = crate::sql::executor::QueryResult {
columns: vec!["Database".to_string()],
column_types: vec![crate::yaml::schema::SqlType::Text],
rows: vec![
vec![Value::Text("information_schema".to_string())],
vec![Value::Text(db_name)],
],
};
self.send_query_result(stream, state, &result).await
}
/// Answers `SHOW COLUMNS|FIELDS FROM|IN <table>` (and, via
/// `handle_describe_command`, `DESCRIBE <table>`).
///
/// Produces one row per column: Field, Type, Null, Key, Default, Extra. The
/// table name is the first token after `FROM`/`IN`, with a trailing `;` and
/// quote characters removed. An unknown table yields error 1146.
async fn handle_show_columns(
    &mut self,
    stream: &mut TcpStream,
    state: &mut ConnectionState,
    query: &str,
) -> crate::Result<()> {
    use crate::yaml::schema::SqlType;

    debug!("Handling SHOW COLUMNS");
    // `to_ascii_uppercase` preserves byte offsets, so a position found in
    // `upper` is also a valid index into `query`. Full Unicode upper-casing
    // can change lengths and make the slice below panic.
    let upper = query.to_ascii_uppercase();
    // MySQL accepts both FROM and IN here; use whichever keyword comes first.
    let name_start = [" FROM ", " IN "]
        .iter()
        .filter_map(|&keyword| upper.find(keyword).map(|pos| (pos, pos + keyword.len())))
        .min()
        .map(|(_, start)| start);
    let Some(name_start) = name_start else {
        return self
            .send_error(stream, state, 1064, "42000", "Invalid SHOW COLUMNS syntax")
            .await;
    };
    let table_name = query[name_start..]
        .split_whitespace()
        .next()
        .unwrap_or("")
        .trim_end_matches(';')
        .trim_matches('`')
        .trim_matches('"')
        .trim_matches('\'');

    let storage = self.executor.storage();
    let db = storage.database();
    let db_guard = db.read().await;
    if let Some(table) = db_guard.tables.get(table_name) {
        let mut rows = Vec::new();
        for column in &table.columns {
            let mysql_type = match &column.sql_type {
                SqlType::Integer => "int(11)".to_string(),
                SqlType::BigInt => "bigint(20)".to_string(),
                SqlType::Boolean => "tinyint(1)".to_string(),
                SqlType::Float => "float".to_string(),
                SqlType::Double => "double".to_string(),
                // Report the declared precision/scale, not a fixed (10,2).
                SqlType::Decimal(precision, scale) => {
                    format!("decimal({},{})", precision, scale)
                }
                SqlType::Date => "date".to_string(),
                SqlType::Time => "time".to_string(),
                SqlType::Timestamp => "timestamp".to_string(),
                SqlType::Text => "text".to_string(),
                SqlType::Varchar(len) => format!("varchar({})", len),
                SqlType::Char(len) => format!("char({})", len),
                SqlType::Json => "json".to_string(),
                SqlType::Uuid => "char(36)".to_string(),
            };
            rows.push(vec![
                Value::Text(column.name.clone()),
                Value::Text(mysql_type),
                Value::Text(if column.nullable { "YES" } else { "NO" }.to_string()),
                Value::Text(String::new()), // Key
                Value::Text(if column.nullable { "NULL" } else { "" }.to_string()), // Default
                Value::Text(String::new()), // Extra
            ]);
        }
        let result = crate::sql::executor::QueryResult {
            columns: vec![
                "Field".to_string(),
                "Type".to_string(),
                "Null".to_string(),
                "Key".to_string(),
                "Default".to_string(),
                "Extra".to_string(),
            ],
            column_types: vec![SqlType::Text; 6],
            rows,
        };
        drop(db_guard);
        self.send_query_result(stream, state, &result).await
    } else {
        drop(db_guard);
        self.send_error(
            stream,
            state,
            1146,
            "42S02",
            &format!("Table '{}' doesn't exist", table_name),
        )
        .await
    }
}
/// Answers `SHOW INDEX|INDEXES|KEYS` with MySQL's 13 column headers and no
/// rows (no index metadata is reported).
async fn handle_show_indexes(
    &mut self,
    stream: &mut TcpStream,
    state: &mut ConnectionState,
    _query: &str,
) -> crate::Result<()> {
    const HEADERS: [&str; 13] = [
        "Table",
        "Non_unique",
        "Key_name",
        "Seq_in_index",
        "Column_name",
        "Collation",
        "Cardinality",
        "Sub_part",
        "Packed",
        "Null",
        "Index_type",
        "Comment",
        "Index_comment",
    ];

    debug!("Handling SHOW INDEXES");
    let result = crate::sql::executor::QueryResult {
        columns: HEADERS.iter().map(|header| header.to_string()).collect(),
        column_types: vec![crate::yaml::schema::SqlType::Text; HEADERS.len()],
        rows: Vec::new(),
    };
    self.send_query_result(stream, state, &result).await
}
async fn handle_show_status(
&mut self,
stream: &mut TcpStream,
state: &mut ConnectionState,
) -> crate::Result<()> {
debug!("Handling SHOW STATUS");
let rows = vec![
vec![
Value::Text("Connections".to_string()),
Value::Text("1".to_string()),
],
vec![
Value::Text("Uptime".to_string()),
Value::Text("3600".to_string()),
],
vec![
Value::Text("Threads_connected".to_string()),
Value::Text("1".to_string()),
],
];
let result = crate::sql::executor::QueryResult {
columns: vec!["Variable_name".to_string(), "Value".to_string()],
column_types: vec![
crate::yaml::schema::SqlType::Text,
crate::yaml::schema::SqlType::Text,
],
rows,
};
self.send_query_result(stream, state, &result).await
}
async fn handle_show_engines(
&mut self,
stream: &mut TcpStream,
state: &mut ConnectionState,
) -> crate::Result<()> {
debug!("Handling SHOW ENGINES");
let rows = vec![
vec![
Value::Text("InnoDB".to_string()),
Value::Text("DEFAULT".to_string()),
Value::Text(
"Supports transactions, row-level locking, and foreign keys".to_string(),
),
Value::Text("YES".to_string()),
Value::Text("YES".to_string()),
Value::Text("YES".to_string()),
],
vec![
Value::Text("MyISAM".to_string()),
Value::Text("YES".to_string()),
Value::Text("MyISAM storage engine".to_string()),
Value::Text("NO".to_string()),
Value::Text("NO".to_string()),
Value::Text("NO".to_string()),
],
];
let result = crate::sql::executor::QueryResult {
columns: vec![
"Engine".to_string(),
"Support".to_string(),
"Comment".to_string(),
"Transactions".to_string(),
"XA".to_string(),
"Savepoints".to_string(),
],
column_types: vec![crate::yaml::schema::SqlType::Text; 6],
rows,
};
self.send_query_result(stream, state, &result).await
}
async fn handle_show_collation(
&mut self,
stream: &mut TcpStream,
state: &mut ConnectionState,
) -> crate::Result<()> {
debug!("Handling SHOW COLLATION");
let rows = vec![
vec![
Value::Text("utf8mb4_0900_ai_ci".to_string()),
Value::Text("utf8mb4".to_string()),
Value::Text("255".to_string()),
Value::Text("Yes".to_string()),
Value::Text("Yes".to_string()),
Value::Text("8".to_string()),
],
vec![
Value::Text("utf8mb4_general_ci".to_string()),
Value::Text("utf8mb4".to_string()),
Value::Text("45".to_string()),
Value::Text("".to_string()),
Value::Text("Yes".to_string()),
Value::Text("1".to_string()),
],
];
let result = crate::sql::executor::QueryResult {
columns: vec![
"Collation".to_string(),
"Charset".to_string(),
"Id".to_string(),
"Default".to_string(),
"Compiled".to_string(),
"Sortlen".to_string(),
],
column_types: vec![crate::yaml::schema::SqlType::Text; 6],
rows,
};
self.send_query_result(stream, state, &result).await
}
async fn handle_show_character_set(
&mut self,
stream: &mut TcpStream,
state: &mut ConnectionState,
) -> crate::Result<()> {
debug!("Handling SHOW CHARACTER SET");
let rows = vec![
vec![
Value::Text("utf8mb4".to_string()),
Value::Text("UTF-8 Unicode".to_string()),
Value::Text("utf8mb4_0900_ai_ci".to_string()),
Value::Text("4".to_string()),
],
vec![
Value::Text("utf8mb3".to_string()),
Value::Text("UTF-8 Unicode".to_string()),
Value::Text("utf8mb3_general_ci".to_string()),
Value::Text("3".to_string()),
],
];
let result = crate::sql::executor::QueryResult {
columns: vec![
"Charset".to_string(),
"Description".to_string(),
"Default collation".to_string(),
"Maxlen".to_string(),
],
column_types: vec![crate::yaml::schema::SqlType::Text; 4],
rows,
};
self.send_query_result(stream, state, &result).await
}
/// Answers `DESCRIBE <table>` / `DESC <table>` by delegating to SHOW COLUMNS.
///
/// `handle_query` routes here when the query starts (case-insensitively)
/// with `DESCRIBE ` or `DESC `. The leading keyword is dropped by position
/// rather than by text replacement. This makes lowercase `describe t` work
/// and leaves table names that contain "DESC" (e.g. `DESCRIPTIONS`) intact.
async fn handle_describe_command(
    &mut self,
    stream: &mut TcpStream,
    state: &mut ConnectionState,
    query: &str,
) -> crate::Result<()> {
    let target = query
        .trim_start()
        .split_once(char::is_whitespace)
        .map(|(_, rest)| rest.trim_start())
        .unwrap_or("");
    let show_query = format!("SHOW COLUMNS FROM {}", target);
    self.handle_show_columns(stream, state, &show_query).await
}
/// Answers queries that mention information_schema.
///
/// TABLES, COLUMNS, SCHEMATA and KEY_COLUMN_USAGE are delegated to the
/// information_schema emulation. Any other information_schema table yields
/// an empty single-column result.
async fn handle_information_schema_query(
    &mut self,
    stream: &mut TcpStream,
    state: &mut ConnectionState,
    query: &str,
) -> crate::Result<()> {
    debug!("Handling information_schema query: {}", query);
    let upper = query.to_uppercase();
    let mentions = |table: &str| upper.contains(table);

    let result = if mentions("INFORMATION_SCHEMA.TABLES") {
        self.information_schema.query_tables(Some(query))
    } else if mentions("INFORMATION_SCHEMA.COLUMNS") {
        self.information_schema.query_columns(Some(query))
    } else if mentions("INFORMATION_SCHEMA.SCHEMATA") {
        self.information_schema.query_schemata(Some(query))
    } else if mentions("INFORMATION_SCHEMA.KEY_COLUMN_USAGE") {
        self.information_schema.query_key_column_usage(Some(query))
    } else {
        crate::sql::executor::QueryResult {
            columns: vec!["result".to_string()],
            column_types: vec![crate::yaml::schema::SqlType::Text],
            rows: Vec::new(),
        }
    };
    self.send_query_result(stream, state, &result).await
}
/// Sends `result` as a text-protocol result set.
///
/// Every value is rendered with its `Display` form before being sent.
async fn send_query_result(
    &self,
    stream: &mut TcpStream,
    state: &mut ConnectionState,
    result: &crate::sql::executor::QueryResult,
) -> crate::Result<()> {
    debug!(
        "Sending query result with {} columns and {} rows",
        result.columns.len(),
        result.rows.len()
    );
    let column_names: Vec<&str> = result.columns.iter().map(String::as_str).collect();
    let rendered: Vec<Vec<String>> = result
        .rows
        .iter()
        .map(|row| row.iter().map(ToString::to_string).collect())
        .collect();
    let borrowed: Vec<Vec<&str>> = rendered
        .iter()
        .map(|row| row.iter().map(String::as_str).collect())
        .collect();
    self.send_simple_result_set(stream, state, &column_names, &borrowed)
        .await
}
async fn send_simple_result_set(
&self,
stream: &mut TcpStream,
state: &mut ConnectionState,
columns: &[&str],
rows: &[Vec<&str>],
) -> crate::Result<()> {
debug!(
"send_simple_result_set: {} columns, {} rows",
columns.len(),
rows.len()
);
let mut packet = BytesMut::new();
packet.put_u8(columns.len() as u8);
self.write_packet(stream, state, &packet).await?;
for (idx, column) in columns.iter().enumerate() {
debug!("Writing column definition {}: {}", idx, column);
let mut col_packet = BytesMut::new();
col_packet.put_u8(3);
col_packet.put_slice(b"def");
col_packet.put_u8(0);
col_packet.put_u8(0);
col_packet.put_u8(0);
col_packet.put_u8(column.len() as u8);
col_packet.put_slice(column.as_bytes());
col_packet.put_u8(column.len() as u8);
col_packet.put_slice(column.as_bytes());
col_packet.put_u8(0x0c);
col_packet.put_u16_le(33);
col_packet.put_u32_le(255);
col_packet.put_u8(MYSQL_TYPE_VAR_STRING);
col_packet.put_u16_le(0);
col_packet.put_u8(0);
col_packet.put_u16_le(0);
self.write_packet(stream, state, &col_packet).await?;
}
if (state.capabilities & CLIENT_DEPRECATE_EOF) == 0 {
let mut eof_packet = BytesMut::new();
eof_packet.put_u8(0xfe); eof_packet.put_u16_le(0); eof_packet.put_u16_le(SERVER_STATUS_AUTOCOMMIT); self.write_packet(stream, state, &eof_packet).await?
}
for row in rows {
let mut row_packet = BytesMut::new();
for value in row {
if *value == "NULL" {
row_packet.put_u8(0xfb); } else {
let bytes = value.as_bytes();
if bytes.len() < 251 {
row_packet.put_u8(bytes.len() as u8);
} else if bytes.len() < 65536 {
row_packet.put_u8(0xfc);
row_packet.put_u16_le(bytes.len() as u16);
} else if bytes.len() < 16777216 {
row_packet.put_u8(0xfd);
row_packet.put_u8((bytes.len() & 0xff) as u8);
row_packet.put_u8(((bytes.len() >> 8) & 0xff) as u8);
row_packet.put_u8(((bytes.len() >> 16) & 0xff) as u8);
} else {
row_packet.put_u8(0xfe);
row_packet.put_u64_le(bytes.len() as u64);
}
row_packet.put_slice(bytes);
}
}
self.write_packet(stream, state, &row_packet).await?;
}
if (state.capabilities & CLIENT_DEPRECATE_EOF) != 0 {
let mut ok_packet = BytesMut::new();
ok_packet.put_u8(0x00); ok_packet.put_u8(0x00); ok_packet.put_u8(0x00); ok_packet.put_u16_le(SERVER_STATUS_AUTOCOMMIT); ok_packet.put_u16_le(0); self.write_packet(stream, state, &ok_packet).await
} else {
let mut eof_packet = BytesMut::new();
eof_packet.put_u8(0xfe); eof_packet.put_u16_le(0); eof_packet.put_u16_le(SERVER_STATUS_AUTOCOMMIT); self.write_packet(stream, state, &eof_packet).await
}
}
/// Sends an OK packet reporting `affected_rows`.
///
/// The last insert id and warning count are 0, and the status is
/// autocommit. `_info` is currently ignored.
async fn send_ok(
    &self,
    stream: &mut TcpStream,
    state: &mut ConnectionState,
    affected_rows: u64,
    _info: u64,
) -> crate::Result<()> {
    let mut ok = BytesMut::with_capacity(16);
    ok.put_u8(0x00); // OK header
    put_lenenc_int(&mut ok, affected_rows);
    put_lenenc_int(&mut ok, 0); // last insert id
    ok.put_slice(&SERVER_STATUS_AUTOCOMMIT.to_le_bytes());
    ok.put_slice(&[0, 0]); // warning count
    self.write_packet(stream, state, &ok).await
}
async fn send_error(
&self,
stream: &mut TcpStream,
state: &mut ConnectionState,
error_code: u16,
sql_state: &str,
message: &str,
) -> crate::Result<()> {
let mut packet = BytesMut::new();
packet.put_u8(0xff);
packet.put_u16_le(error_code);
packet.put_u8(b'#');
packet.put_slice(sql_state.as_bytes());
packet.put_slice(message.as_bytes());
self.write_packet(stream, state, &packet).await
}
/// Frames `payload` into MySQL packets and writes them to `stream`.
///
/// Each frame carries a 3-byte little-endian length and the current sequence
/// id, which is incremented per frame. Payloads are split into frames of at
/// most 0xFFFFFF bytes. A full-size frame always means "more follows", so
/// when the payload length is an exact multiple of 0xFFFFFF (including
/// exactly 0xFFFFFF) a final empty frame is sent. An empty payload is sent
/// as a single empty frame. The stream is flushed once at the end.
///
/// # Errors
/// Propagates I/O errors from writing or flushing the stream.
async fn write_packet(
    &self,
    stream: &mut TcpStream,
    state: &mut ConnectionState,
    payload: &[u8],
) -> crate::Result<()> {
    const MAX_PACKET_SIZE: usize = 0xff_ffff;

    let mut offset = 0;
    loop {
        let chunk_len = (payload.len() - offset).min(MAX_PACKET_SIZE);
        let mut frame = BytesMut::with_capacity(4 + chunk_len);
        // chunk_len <= 0xffffff, so the low three little-endian bytes hold it exactly.
        frame.put_slice(&(chunk_len as u32).to_le_bytes()[..3]);
        frame.put_u8(state.sequence_id);
        state.sequence_id = state.sequence_id.wrapping_add(1);
        frame.put_slice(&payload[offset..offset + chunk_len]);
        stream.write_all(&frame).await?;
        offset += chunk_len;
        // A shorter-than-maximum frame terminates the payload.
        if chunk_len < MAX_PACKET_SIZE {
            break;
        }
    }
    stream.flush().await?;
    Ok(())
}
/// Reads one logical MySQL packet from `stream`.
///
/// A frame whose length is exactly 0xFFFFFF signals that the payload
/// continues in the next frame. Such frames are concatenated until a shorter
/// frame arrives. `state.sequence_id` is left at the last received sequence
/// id + 1, ready for the reply.
///
/// # Errors
/// Returns `YamlBaseError::Io` if the stream fails or closes mid-packet.
/// Returns `YamlBaseError::Protocol` if the reassembled payload would exceed
/// 64 MiB (MySQL's default `max_allowed_packet`). That cap keeps a peer from
/// making the server buffer without limit.
async fn read_packet(
    &self,
    stream: &mut TcpStream,
    state: &mut ConnectionState,
) -> crate::Result<Vec<u8>> {
    const MAX_FRAME_LEN: usize = 0xff_ffff;
    const MAX_PAYLOAD_LEN: usize = 64 * 1024 * 1024;

    let mut payload = Vec::new();
    loop {
        let mut header = [0u8; 4];
        if let Err(e) = stream.read_exact(&mut header).await {
            debug!("Error reading packet header: {}", e);
            return Err(YamlBaseError::Io(e));
        }
        let len = usize::from(header[0])
            | (usize::from(header[1]) << 8)
            | (usize::from(header[2]) << 16);
        state.sequence_id = header[3].wrapping_add(1);

        if payload.len() + len > MAX_PAYLOAD_LEN {
            return Err(YamlBaseError::Protocol(
                "Packet exceeds maximum allowed size".to_string(),
            ));
        }
        if len > 0 {
            let start = payload.len();
            payload.resize(start + len, 0);
            if let Err(e) = stream.read_exact(&mut payload[start..]).await {
                debug!("Error reading packet payload: {}", e);
                return Err(YamlBaseError::Io(e));
            }
        }
        if len < MAX_FRAME_LEN {
            return Ok(payload);
        }
    }
}
}
/// Generates a fresh 20-byte authentication scramble.
///
/// This mirrors MySQL's own salt generation: every byte is a 7-bit value and
/// is never NUL or `$`. As a result, clients that treat the scramble (or its
/// two handshake halves) as a NUL-terminated string still see all 20 bytes.
fn generate_auth_data() -> Vec<u8> {
    use rand::Rng;
    let mut rng = rand::thread_rng();
    let mut auth_data = vec![0u8; 20];
    rng.fill(&mut auth_data[..]);
    for byte in &mut auth_data {
        *byte &= 0x7f;
        if *byte == 0 || *byte == b'$' {
            *byte += 1;
        }
    }
    auth_data
}
/// Computes the mysql_native_password response expected for `password`.
///
/// The formula is `SHA1(password) XOR SHA1(auth_data || SHA1(SHA1(password)))`.
/// An empty password maps to an empty response.
fn compute_auth_response(password: &str, auth_data: &[u8]) -> Vec<u8> {
    if password.is_empty() {
        return Vec::new();
    }
    let stage1 = Sha1::digest(password.as_bytes());
    let stage2 = Sha1::digest(&stage1[..]);

    let mut mixer = Sha1::new();
    mixer.update(auth_data);
    mixer.update(&stage2[..]);
    let mask = mixer.finalize();

    stage1
        .iter()
        .zip(mask.iter())
        .map(|(p, m)| p ^ m)
        .collect()
}
/// Appends `value` as a MySQL length-encoded integer.
///
/// The encoding depends on the value's size:
/// - up to 250: a single byte;
/// - up to 0xFFFF: `0xfc` followed by 2 bytes;
/// - up to 0xFFFFFF: `0xfd` followed by 3 bytes;
/// - otherwise: `0xfe` followed by 8 bytes.
///
/// All multi-byte forms are little-endian.
fn put_lenenc_int(buf: &mut BytesMut, value: u64) {
    match value {
        0..=250 => buf.put_u8(value as u8),
        251..=0xffff => {
            buf.put_u8(0xfc);
            buf.put_u16_le(value as u16);
        }
        0x1_0000..=0xff_ffff => {
            buf.put_u8(0xfd);
            buf.put_slice(&value.to_le_bytes()[..3]);
        }
        _ => {
            buf.put_u8(0xfe);
            buf.put_u64_le(value);
        }
    }
}