use bytes::{BufMut, BytesMut};
use chrono::{Datelike, Timelike};
use std::collections::HashMap;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tracing::debug;
use crate::YamlBaseError;
use crate::database::Value;
use crate::sql::executor::QueryResult;
// ---------------------------------------------------------------------------
// Prepared-statement command bytes dispatched by `handle_binary_command`.
// ---------------------------------------------------------------------------

/// Command byte routed to the prepare handler.
const COM_STMT_PREPARE: u8 = 0x16;
/// Command byte routed to the execute handler.
const COM_STMT_EXECUTE: u8 = 0x17;
/// Command byte routed to the close handler (no response is written for it).
const COM_STMT_CLOSE: u8 = 0x19;
/// Command byte routed to the reset handler.
const COM_STMT_RESET: u8 = 0x1A;

/// Status byte that opens the response to a prepare command.
const STMT_OK: u8 = 0x00;

// ---------------------------------------------------------------------------
// MySQL column type codes written into column definitions.
// Names with a leading underscore are not referenced by the visible code.
// ---------------------------------------------------------------------------

const _MYSQL_TYPE_DECIMAL: u8 = 0x00;
const MYSQL_TYPE_TINY: u8 = 0x01;
const _MYSQL_TYPE_SHORT: u8 = 0x02;
const MYSQL_TYPE_LONG: u8 = 0x03;
const MYSQL_TYPE_FLOAT: u8 = 0x04;
const MYSQL_TYPE_DOUBLE: u8 = 0x05;
const _MYSQL_TYPE_NULL: u8 = 0x06;
const MYSQL_TYPE_TIMESTAMP: u8 = 0x07;
const MYSQL_TYPE_LONGLONG: u8 = 0x08;
const _MYSQL_TYPE_INT24: u8 = 0x09;
const MYSQL_TYPE_DATE: u8 = 0x0A;
const MYSQL_TYPE_TIME: u8 = 0x0B;
const _MYSQL_TYPE_DATETIME: u8 = 0x0C;
const _MYSQL_TYPE_YEAR: u8 = 0x0D;
const _MYSQL_TYPE_NEWDATE: u8 = 0x0E;
const MYSQL_TYPE_VARCHAR: u8 = 0x0F;
const _MYSQL_TYPE_BIT: u8 = 0x10;

// High-numbered codes.
const MYSQL_TYPE_NEWDECIMAL: u8 = 0xF6;
const _MYSQL_TYPE_ENUM: u8 = 0xF7;
const _MYSQL_TYPE_SET: u8 = 0xF8;
const _MYSQL_TYPE_TINY_BLOB: u8 = 0xF9;
const _MYSQL_TYPE_MEDIUM_BLOB: u8 = 0xFA;
const _MYSQL_TYPE_LONG_BLOB: u8 = 0xFB;
const _MYSQL_TYPE_BLOB: u8 = 0xFC;
const MYSQL_TYPE_VAR_STRING: u8 = 0xFD;
const MYSQL_TYPE_STRING: u8 = 0xFE;
const _MYSQL_TYPE_GEOMETRY: u8 = 0xFF;
/// A statement registered by `COM_STMT_PREPARE` and looked up by `id` on later
/// `COM_STMT_*` commands.
#[derive(Clone, Debug)]
pub struct PreparedStatement {
    /// Server-assigned statement id that was returned to the client.
    pub id: u32,
    /// SQL text exactly as received in the prepare payload.
    pub query: String,
    /// Number of `?` placeholders announced to the client.
    pub param_count: u16,
    /// Result column names (the prepare handler in this file leaves this empty).
    pub columns: Vec<String>,
    /// Result column types, presumably parallel to `columns` — confirm with callers.
    pub column_types: Vec<crate::yaml::schema::SqlType>,
}
/// State for MySQL prepared-statement (`COM_STMT_*`) handling.
///
/// NOTE(review): presumably one instance per client connection, since statement ids are
/// only meaningful to the client that prepared them — confirm at the construction site.
#[derive(Debug)]
pub struct MySqlBinaryProtocol {
    /// Open prepared statements, keyed by the id sent to the client.
    statements: HashMap<u32, PreparedStatement>,
    /// Id to hand out on the next successful prepare.
    next_stmt_id: u32,
}
impl MySqlBinaryProtocol {
pub fn new() -> Self {
Self {
statements: HashMap::new(),
next_stmt_id: 1,
}
}
pub async fn handle_binary_command(
&mut self,
command: u8,
payload: &[u8],
stream: &mut TcpStream,
sequence_id: &mut u8,
) -> crate::Result<bool> {
match command {
COM_STMT_PREPARE => {
self.handle_stmt_prepare(payload, stream, sequence_id)
.await?;
Ok(true)
}
COM_STMT_EXECUTE => {
self.handle_stmt_execute(payload, stream, sequence_id)
.await?;
Ok(true)
}
COM_STMT_CLOSE => {
self.handle_stmt_close(payload)?;
Ok(true)
}
COM_STMT_RESET => {
self.handle_stmt_reset(payload, stream, sequence_id).await?;
Ok(true)
}
_ => Ok(false), }
}
async fn handle_stmt_prepare(
&mut self,
payload: &[u8],
stream: &mut TcpStream,
sequence_id: &mut u8,
) -> crate::Result<()> {
let query = std::str::from_utf8(payload).map_err(|_| {
YamlBaseError::Protocol("Invalid UTF-8 in prepared statement".to_string())
})?;
debug!("Preparing statement: {}", query);
let stmt_id = self.next_stmt_id;
self.next_stmt_id += 1;
let param_count = query.matches('?').count() as u16;
let stmt = PreparedStatement {
id: stmt_id,
query: query.to_string(),
param_count,
columns: Vec::new(), column_types: Vec::new(),
};
self.statements.insert(stmt_id, stmt);
let mut packet = BytesMut::new();
packet.put_u8(STMT_OK);
packet.put_u32_le(stmt_id);
packet.put_u16_le(0);
packet.put_u16_le(param_count);
packet.put_u8(0);
packet.put_u16_le(0);
self.write_packet(stream, sequence_id, &packet).await?;
Ok(())
}
async fn handle_stmt_execute(
&mut self,
payload: &[u8],
stream: &mut TcpStream,
sequence_id: &mut u8,
) -> crate::Result<()> {
if payload.len() < 4 {
return Err(YamlBaseError::Protocol(
"Invalid STMT_EXECUTE packet".to_string(),
));
}
let stmt_id = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]);
debug!("Executing prepared statement ID: {}", stmt_id);
let _stmt = self
.statements
.get(&stmt_id)
.ok_or_else(|| YamlBaseError::Protocol("Unknown statement ID".to_string()))?;
let mut packet = BytesMut::new();
packet.put_u8(0x00); packet.put_u8(0x00); packet.put_u8(0x00); packet.put_u16_le(0x0002); packet.put_u16_le(0);
self.write_packet(stream, sequence_id, &packet).await?;
Ok(())
}
fn handle_stmt_close(&mut self, payload: &[u8]) -> crate::Result<()> {
if payload.len() < 4 {
return Err(YamlBaseError::Protocol(
"Invalid STMT_CLOSE packet".to_string(),
));
}
let stmt_id = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]);
debug!("Closing prepared statement ID: {}", stmt_id);
self.statements.remove(&stmt_id);
Ok(())
}
async fn handle_stmt_reset(
&mut self,
payload: &[u8],
stream: &mut TcpStream,
sequence_id: &mut u8,
) -> crate::Result<()> {
if payload.len() < 4 {
return Err(YamlBaseError::Protocol(
"Invalid STMT_RESET packet".to_string(),
));
}
let stmt_id = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]);
debug!("Resetting prepared statement ID: {}", stmt_id);
if !self.statements.contains_key(&stmt_id) {
return Err(YamlBaseError::Protocol("Unknown statement ID".to_string()));
}
let mut packet = BytesMut::new();
packet.put_u8(0x00); packet.put_u8(0x00); packet.put_u8(0x00); packet.put_u16_le(0x0002); packet.put_u16_le(0);
self.write_packet(stream, sequence_id, &packet).await?;
Ok(())
}
pub async fn send_binary_result_set(
&self,
stream: &mut TcpStream,
sequence_id: &mut u8,
result: &QueryResult,
) -> crate::Result<()> {
debug!(
"Sending binary result set with {} columns and {} rows",
result.columns.len(),
result.rows.len()
);
let mut packet = BytesMut::new();
self.put_lenenc_int(&mut packet, result.columns.len() as u64);
self.write_packet(stream, sequence_id, &packet).await?;
for (i, column_name) in result.columns.iter().enumerate() {
let column_type = result
.column_types
.get(i)
.unwrap_or(&crate::yaml::schema::SqlType::Text);
let mut col_packet = BytesMut::new();
self.put_lenenc_string(&mut col_packet, "def");
self.put_lenenc_string(&mut col_packet, "");
self.put_lenenc_string(&mut col_packet, "");
self.put_lenenc_string(&mut col_packet, "");
self.put_lenenc_string(&mut col_packet, column_name);
self.put_lenenc_string(&mut col_packet, column_name);
col_packet.put_u8(0x0c);
col_packet.put_u16_le(33);
col_packet.put_u32_le(255);
col_packet.put_u8(self.sql_type_to_mysql_type(column_type));
col_packet.put_u16_le(0);
col_packet.put_u8(0);
col_packet.put_u16_le(0);
self.write_packet(stream, sequence_id, &col_packet).await?;
}
let mut eof_packet = BytesMut::new();
eof_packet.put_u8(0xfe); eof_packet.put_u16_le(0); eof_packet.put_u16_le(0x0002); self.write_packet(stream, sequence_id, &eof_packet).await?;
for row in &result.rows {
let mut row_packet = BytesMut::new();
row_packet.put_u8(0x00);
let null_bitmap_len = (result.columns.len() + 7 + 2) / 8;
let mut null_bitmap = vec![0u8; null_bitmap_len];
for (i, value) in row.iter().enumerate() {
if matches!(value, Value::Null) {
let byte_index = (i + 2) / 8;
let bit_index = (i + 2) % 8;
null_bitmap[byte_index] |= 1 << bit_index;
}
}
row_packet.put_slice(&null_bitmap);
for (i, value) in row.iter().enumerate() {
if !matches!(value, Value::Null) {
let column_type = result
.column_types
.get(i)
.unwrap_or(&crate::yaml::schema::SqlType::Text);
self.encode_binary_value(&mut row_packet, value, column_type);
}
}
self.write_packet(stream, sequence_id, &row_packet).await?;
}
let mut eof_packet = BytesMut::new();
eof_packet.put_u8(0xfe); eof_packet.put_u16_le(0); eof_packet.put_u16_le(0x0002); self.write_packet(stream, sequence_id, &eof_packet).await?;
Ok(())
}
fn sql_type_to_mysql_type(&self, sql_type: &crate::yaml::schema::SqlType) -> u8 {
match sql_type {
crate::yaml::schema::SqlType::Integer => MYSQL_TYPE_LONG,
crate::yaml::schema::SqlType::BigInt => MYSQL_TYPE_LONGLONG,
crate::yaml::schema::SqlType::Boolean => MYSQL_TYPE_TINY,
crate::yaml::schema::SqlType::Float => MYSQL_TYPE_FLOAT,
crate::yaml::schema::SqlType::Double => MYSQL_TYPE_DOUBLE,
crate::yaml::schema::SqlType::Decimal(_, _) => MYSQL_TYPE_NEWDECIMAL,
crate::yaml::schema::SqlType::Date => MYSQL_TYPE_DATE,
crate::yaml::schema::SqlType::Time => MYSQL_TYPE_TIME,
crate::yaml::schema::SqlType::Timestamp => MYSQL_TYPE_TIMESTAMP,
crate::yaml::schema::SqlType::Text => MYSQL_TYPE_VAR_STRING,
crate::yaml::schema::SqlType::Varchar(_) => MYSQL_TYPE_VARCHAR,
crate::yaml::schema::SqlType::Char(_) => MYSQL_TYPE_STRING,
crate::yaml::schema::SqlType::Json => MYSQL_TYPE_VAR_STRING,
crate::yaml::schema::SqlType::Uuid => MYSQL_TYPE_STRING,
}
}
fn encode_binary_value(
&self,
packet: &mut BytesMut,
value: &Value,
sql_type: &crate::yaml::schema::SqlType,
) {
match value {
Value::Null => {
}
Value::Integer(i) => match sql_type {
crate::yaml::schema::SqlType::Boolean => {
packet.put_u8(*i as u8);
}
crate::yaml::schema::SqlType::Integer => {
packet.put_u32_le(*i as u32);
}
_ => {
packet.put_u64_le(*i as u64);
}
},
Value::Float(f) => {
packet.put_f32_le(*f);
}
Value::Double(f) => {
packet.put_f64_le(*f);
}
Value::Text(s) => {
self.put_lenenc_string(packet, s);
}
Value::Boolean(b) => {
packet.put_u8(if *b { 1 } else { 0 });
}
Value::Date(d) => {
packet.put_u8(4); packet.put_u16_le(d.year() as u16);
packet.put_u8(d.month() as u8);
packet.put_u8(d.day() as u8);
}
Value::Timestamp(dt) => {
let has_time =
dt.hour() != 0 || dt.minute() != 0 || dt.second() != 0 || dt.nanosecond() != 0;
let has_microseconds = dt.nanosecond() != 0;
if !has_time {
packet.put_u8(4); packet.put_u16_le(dt.year() as u16);
packet.put_u8(dt.month() as u8);
packet.put_u8(dt.day() as u8);
} else if !has_microseconds {
packet.put_u8(7); packet.put_u16_le(dt.year() as u16);
packet.put_u8(dt.month() as u8);
packet.put_u8(dt.day() as u8);
packet.put_u8(dt.hour() as u8);
packet.put_u8(dt.minute() as u8);
packet.put_u8(dt.second() as u8);
} else {
packet.put_u8(11); packet.put_u16_le(dt.year() as u16);
packet.put_u8(dt.month() as u8);
packet.put_u8(dt.day() as u8);
packet.put_u8(dt.hour() as u8);
packet.put_u8(dt.minute() as u8);
packet.put_u8(dt.second() as u8);
packet.put_u32_le(dt.nanosecond() / 1000); }
}
Value::Time(_t) => {
packet.put_u8(0); }
Value::Decimal(d) => {
self.put_lenenc_string(packet, &d.to_string());
}
Value::Uuid(u) => {
self.put_lenenc_string(packet, &u.to_string());
}
Value::Json(j) => {
self.put_lenenc_string(packet, &j.to_string());
}
}
}
fn put_lenenc_int(&self, packet: &mut BytesMut, value: u64) {
if value < 251 {
packet.put_u8(value as u8);
} else if value < 65536 {
packet.put_u8(0xfc);
packet.put_u16_le(value as u16);
} else if value < 16777216 {
packet.put_u8(0xfd);
packet.put_u8((value & 0xff) as u8);
packet.put_u8(((value >> 8) & 0xff) as u8);
packet.put_u8(((value >> 16) & 0xff) as u8);
} else {
packet.put_u8(0xfe);
packet.put_u64_le(value);
}
}
fn put_lenenc_string(&self, packet: &mut BytesMut, s: &str) {
let bytes = s.as_bytes();
self.put_lenenc_int(packet, bytes.len() as u64);
packet.put_slice(bytes);
}
async fn write_packet(
&self,
stream: &mut TcpStream,
sequence_id: &mut u8,
payload: &[u8],
) -> crate::Result<()> {
let mut packet = BytesMut::with_capacity(4 + payload.len());
packet.put_u8((payload.len() & 0xff) as u8);
packet.put_u8(((payload.len() >> 8) & 0xff) as u8);
packet.put_u8(((payload.len() >> 16) & 0xff) as u8);
packet.put_u8(*sequence_id);
*sequence_id = sequence_id.wrapping_add(1);
packet.put_slice(payload);
stream.write_all(&packet).await?;
stream.flush().await?;
Ok(())
}
}
impl Default for MySqlBinaryProtocol {
fn default() -> Self {
Self::new()
}
}