#![allow(clippy::cast_possible_truncation)]
use super::{Command, PacketWriter};
use crate::types::{ColumnDef, FieldType};
use sqlmodel_core::Value;
/// Fixed header fields of a `COM_STMT_PREPARE` OK response, as decoded by
/// `parse_stmt_prepare_ok`.
#[derive(Debug, Clone)]
pub struct StmtPrepareOk {
/// Server-assigned statement id; it is echoed back in the execute, close and reset
/// packets built by this module.
pub statement_id: u32,
/// Number of result-set columns the server reported for the statement.
pub num_columns: u16,
/// Number of parameters the server reported for the statement.
pub num_params: u16,
/// Warning count reported by the server.
pub warnings: u16,
}
/// A statement prepared on the server, together with its SQL text and the
/// parameter / result-column definitions.
#[derive(Debug, Clone)]
pub struct PreparedStatement {
/// Server-assigned statement id (see `StmtPrepareOk::statement_id`).
pub statement_id: u32,
/// SQL text that was sent in the prepare packet.
pub sql: String,
/// Parameter definitions. NOTE(review): presumably one per placeholder, read from the
/// definition packets that follow the prepare OK; not populated in this file.
pub params: Vec<ColumnDef>,
/// Result-column definitions. NOTE(review): presumably read from the column
/// definition packets that follow the prepare OK; not populated in this file.
pub columns: Vec<ColumnDef>,
}
impl PreparedStatement {
    /// Assembles a prepared-statement handle from the server-assigned id, the
    /// original SQL text and the parameter / result-column definitions.
    pub fn new(
        statement_id: u32,
        sql: String,
        params: Vec<ColumnDef>,
        columns: Vec<ColumnDef>,
    ) -> Self {
        Self {
            sql,
            statement_id,
            columns,
            params,
        }
    }

    /// Number of parameter definitions held by this statement.
    #[must_use]
    pub fn param_count(&self) -> usize {
        let Self { params, .. } = self;
        params.len()
    }

    /// Number of result-column definitions held by this statement.
    #[must_use]
    pub fn column_count(&self) -> usize {
        let Self { columns, .. } = self;
        columns.len()
    }
}
/// Builds a `COM_STMT_PREPARE` packet: the command byte followed by the raw SQL text
/// (not length-prefixed), framed with `sequence_id`.
pub fn build_stmt_prepare_packet(sql: &str, sequence_id: u8) -> Vec<u8> {
    let query = sql.as_bytes();
    // One byte for the command, the rest is the query text.
    let mut packet = PacketWriter::with_capacity(query.len() + 1);
    packet.write_u8(Command::StmtPrepare as u8);
    packet.write_bytes(query);
    packet.build_packet(sequence_id)
}
pub fn build_stmt_execute_packet(
statement_id: u32,
params: &[Value],
param_types: Option<&[FieldType]>,
sequence_id: u8,
) -> Vec<u8> {
let mut writer = PacketWriter::with_capacity(64 + params.len() * 16);
writer.write_u8(Command::StmtExecute as u8);
writer.write_u32_le(statement_id);
writer.write_u8(0x00);
writer.write_u32_le(1);
if !params.is_empty() {
let null_bitmap_len = params.len().div_ceil(8);
let mut null_bitmap = vec![0u8; null_bitmap_len];
for (i, param) in params.iter().enumerate() {
if matches!(param, Value::Null) {
null_bitmap[i / 8] |= 1 << (i % 8);
}
}
writer.write_bytes(&null_bitmap);
writer.write_u8(1);
for (i, param) in params.iter().enumerate() {
let field_type = if let Some(types) = param_types {
if i < types.len() {
types[i]
} else {
value_to_field_type(param)
}
} else {
value_to_field_type(param)
};
writer.write_u8(field_type as u8);
let flags = if is_unsigned_value(param) { 0x80 } else { 0x00 };
writer.write_u8(flags);
}
for param in params {
if !matches!(param, Value::Null) {
encode_binary_param(&mut writer, param);
}
}
}
writer.build_packet(sequence_id)
}
/// Builds a `COM_STMT_CLOSE` packet for `statement_id`, framed with `sequence_id`.
pub fn build_stmt_close_packet(statement_id: u32, sequence_id: u8) -> Vec<u8> {
    // Payload: 1 command byte + 4-byte little-endian statement id.
    const PAYLOAD_LEN: usize = 1 + 4;
    let mut packet = PacketWriter::with_capacity(PAYLOAD_LEN);
    packet.write_u8(Command::StmtClose as u8);
    packet.write_u32_le(statement_id);
    packet.build_packet(sequence_id)
}
/// Builds a `COM_STMT_RESET` packet for `statement_id`, framed with `sequence_id`.
pub fn build_stmt_reset_packet(statement_id: u32, sequence_id: u8) -> Vec<u8> {
    // Payload: 1 command byte + 4-byte little-endian statement id.
    const PAYLOAD_LEN: usize = 1 + 4;
    let mut packet = PacketWriter::with_capacity(PAYLOAD_LEN);
    packet.write_u8(Command::StmtReset as u8);
    packet.write_u32_le(statement_id);
    packet.build_packet(sequence_id)
}
/// Parses the payload of a `COM_STMT_PREPARE` OK response.
///
/// Layout: status `0x00` (1), statement id (u32 LE), column count (u16 LE), parameter
/// count (u16 LE), a reserved filler byte, then an optional warning count (u16 LE).
///
/// Returns `None` if the payload is shorter than the 9 mandatory bytes or the status
/// byte is not `0x00` (e.g. an error packet). When the warning count is absent,
/// `warnings` is 0.
pub fn parse_stmt_prepare_ok(data: &[u8]) -> Option<StmtPrepareOk> {
    // status + statement_id + num_columns + num_params. The filler byte carries no
    // information and the warning count is optional, so neither is required.
    const MIN_LEN: usize = 9;
    if data.len() < MIN_LEN || data[0] != 0x00 {
        return None;
    }
    let statement_id = u32::from_le_bytes([data[1], data[2], data[3], data[4]]);
    let num_columns = u16::from_le_bytes([data[5], data[6]]);
    let num_params = u16::from_le_bytes([data[7], data[8]]);
    // Byte 9 is the reserved filler; the warning count follows it when present.
    let warnings = match data.get(10..12) {
        Some(&[lo, hi]) => u16::from_le_bytes([lo, hi]),
        _ => 0,
    };
    Some(StmtPrepareOk {
        statement_id,
        num_columns,
        num_params,
        warnings,
    })
}
/// Chooses the binary-protocol field type advertised for a parameter value when the
/// caller supplies no explicit type.
fn value_to_field_type(value: &Value) -> FieldType {
    match value {
        // Neither carries a payload.
        Value::Null | Value::Default => FieldType::Null,
        Value::Bool(_) | Value::TinyInt(_) => FieldType::Tiny,
        Value::SmallInt(_) => FieldType::Short,
        Value::Int(_) => FieldType::Long,
        Value::BigInt(_) => FieldType::LongLong,
        Value::Float(_) => FieldType::Float,
        Value::Double(_) => FieldType::Double,
        Value::Decimal(_) => FieldType::NewDecimal,
        Value::Text(_) => FieldType::VarString,
        // Raw bytes, including UUIDs, are sent as blobs.
        Value::Bytes(_) | Value::Uuid(_) => FieldType::Blob,
        // Arrays are serialized to JSON text by `encode_binary_param`.
        Value::Json(_) | Value::Array(_) => FieldType::Json,
        Value::Date(_) => FieldType::Date,
        Value::Time(_) => FieldType::Time,
        Value::Timestamp(_) | Value::TimestampTz(_) => FieldType::DateTime,
    }
}
/// Decides whether the unsigned flag (0x80) is set for a parameter: only `BigInt`
/// values strictly greater than `i64::MAX / 2` qualify.
///
/// NOTE(review): values in that range are also non-negative `i64`s, so the flag looks
/// like a heuristic rather than a requirement; confirm the intent.
fn is_unsigned_value(value: &Value) -> bool {
    const THRESHOLD: i64 = i64::MAX / 2;
    if let Value::BigInt(n) = value {
        *n > THRESHOLD
    } else {
        false
    }
}
/// Appends the binary-protocol payload for a single parameter value.
///
/// Integers are written little-endian at their natural width (signed values are
/// reinterpreted bit-for-bit). Floats are written as IEEE-754 little-endian bytes.
/// Text-like and byte-like values, including JSON and arrays serialized to JSON, are
/// written length-encoded. Temporal values use the compact date/time encoders.
/// `Null` and `Default` write nothing.
fn encode_binary_param(writer: &mut PacketWriter, value: &Value) {
    match value {
        Value::Null | Value::Default => {}
        Value::Bool(flag) => {
            writer.write_u8(u8::from(*flag));
        }
        Value::TinyInt(n) => {
            writer.write_u8(*n as u8);
        }
        Value::SmallInt(n) => {
            writer.write_u16_le(*n as u16);
        }
        Value::Int(n) => {
            writer.write_u32_le(*n as u32);
        }
        Value::BigInt(n) => {
            writer.write_u64_le(*n as u64);
        }
        Value::Float(x) => {
            writer.write_bytes(&x.to_le_bytes());
        }
        Value::Double(x) => {
            writer.write_bytes(&x.to_le_bytes());
        }
        Value::Decimal(digits) => {
            write_length_encoded_string(writer, digits);
        }
        Value::Text(text) => {
            write_length_encoded_string(writer, text);
        }
        Value::Bytes(raw) => {
            write_length_encoded_bytes(writer, raw);
        }
        Value::Uuid(raw) => {
            write_length_encoded_bytes(writer, raw);
        }
        Value::Json(doc) => {
            let rendered = doc.to_string();
            write_length_encoded_string(writer, &rendered);
        }
        Value::Array(items) => {
            // Serialization failure degrades to an empty string.
            let rendered = serde_json::to_string(items).unwrap_or_default();
            write_length_encoded_string(writer, &rendered);
        }
        Value::Date(days) => {
            encode_binary_date(writer, *days);
        }
        Value::Time(micros) => {
            encode_binary_time(writer, *micros);
        }
        Value::Timestamp(micros) | Value::TimestampTz(micros) => {
            encode_binary_datetime(writer, *micros);
        }
    }
}
/// Writes `s` as a length-encoded string (its UTF-8 bytes with a length prefix).
fn write_length_encoded_string(writer: &mut PacketWriter, s: &str) {
    let utf8 = s.as_bytes();
    write_length_encoded_bytes(writer, utf8);
}
/// Writes `data` prefixed by its length as a length-encoded integer.
///
/// The prefix is 1 byte for lengths below 251, `0xFC` + u16 LE below 2^16, `0xFD` +
/// 3-byte LE below 2^24, and `0xFE` + u64 LE otherwise.
fn write_length_encoded_bytes(writer: &mut PacketWriter, data: &[u8]) {
    let len = data.len();
    match len {
        0..=250 => {
            writer.write_u8(len as u8);
        }
        251..=0xFFFF => {
            writer.write_u8(0xFC);
            writer.write_u16_le(len as u16);
        }
        0x1_0000..=0xFF_FFFF => {
            writer.write_u8(0xFD);
            // 3-byte little-endian length: the low three bytes of the u32 form.
            writer.write_bytes(&(len as u32).to_le_bytes()[..3]);
        }
        _ => {
            writer.write_u8(0xFE);
            writer.write_u64_le(len as u64);
        }
    }
    writer.write_bytes(data);
}
/// Encodes a date given as days since 1970-01-01 as a binary-protocol DATE.
///
/// The output is the length byte 4 followed by year (u16 LE), month and day. A
/// `(0, 0, 0)` date would be written as a bare 0 length byte, but `days_to_ymd` always
/// yields a month in 1..=12, so that branch is not reached in practice.
fn encode_binary_date(writer: &mut PacketWriter, days: i32) {
    let ymd = days_to_ymd(days);
    if ymd == (0, 0, 0) {
        writer.write_u8(0);
        return;
    }
    let (year, month, day) = ymd;
    writer.write_u8(4);
    writer.write_u16_le(year as u16);
    writer.write_u8(month as u8);
    writer.write_u8(day as u8);
}
/// Encodes a signed duration in microseconds as a binary-protocol TIME.
///
/// Layout: length byte (0, 8 or 12), then sign byte (1 = negative), days (u32 LE),
/// hours, minutes, seconds, and the microsecond part (u32 LE) only when it is non-zero.
/// A zero duration is written as just the length byte 0.
fn encode_binary_time(writer: &mut PacketWriter, micros: i64) {
    let negative = micros < 0;
    let magnitude = micros.unsigned_abs();
    if magnitude == 0 {
        writer.write_u8(0);
        return;
    }

    let fraction = (magnitude % 1_000_000) as u32;
    let whole_seconds = magnitude / 1_000_000;
    let days = whole_seconds / 86_400;
    let second_of_day = whole_seconds % 86_400;
    let (hh, mm, ss) = (
        second_of_day / 3600,
        (second_of_day % 3600) / 60,
        second_of_day % 60,
    );

    writer.write_u8(if fraction == 0 { 8 } else { 12 });
    writer.write_u8(u8::from(negative));
    writer.write_u32_le(days as u32);
    writer.write_u8(hh as u8);
    writer.write_u8(mm as u8);
    writer.write_u8(ss as u8);
    if fraction != 0 {
        writer.write_u32_le(fraction);
    }
}
/// Encodes an instant given as microseconds since 1970-01-01 00:00:00 as a
/// binary-protocol DATETIME.
///
/// Layout: a length byte of 4 (date only), 7 (date + time) or 11 (date + time +
/// microseconds), then year (u16 LE), month, day, and optionally hour, minute, second
/// and the microsecond part (u32 LE).
///
/// No time-zone conversion is done here. NOTE(review): confirm that callers pass
/// the wall-clock value they expect the server to store.
fn encode_binary_datetime(writer: &mut PacketWriter, micros: i64) {
    const MICROS_PER_SECOND: i64 = 1_000_000;
    const MICROS_PER_DAY: i64 = 86_400 * MICROS_PER_SECOND;

    // Floor division rather than truncation, so pre-epoch instants land on the
    // previous calendar day with a non-negative time of day. For example,
    // -1 µs => 1969-12-31 23:59:59.999999, not 1970-01-01 00:00:00.000001.
    // The quotient fits i32 for every i64 input (|days| < 1.1e8).
    let days = micros.div_euclid(MICROS_PER_DAY) as i32;
    let micros_of_day = micros.rem_euclid(MICROS_PER_DAY);
    let seconds_of_day = micros_of_day / MICROS_PER_SECOND;
    let microseconds = (micros_of_day % MICROS_PER_SECOND) as u32;

    let (year, month, day) = days_to_ymd(days);
    let hour = (seconds_of_day / 3600) as u8;
    let minute = ((seconds_of_day % 3600) / 60) as u8;
    let second = (seconds_of_day % 60) as u8;

    // days_to_ymd always yields a month in 1..=12, so an all-zero (length 0) datetime
    // can never be produced here; choose the shortest form that holds the value.
    let has_time = hour != 0 || minute != 0 || second != 0 || microseconds != 0;
    let length: u8 = if microseconds != 0 {
        11
    } else if has_time {
        7
    } else {
        4
    };

    writer.write_u8(length);
    writer.write_u16_le(year as u16);
    writer.write_u8(month as u8);
    writer.write_u8(day as u8);
    if length >= 7 {
        writer.write_u8(hour);
        writer.write_u8(minute);
        writer.write_u8(second);
    }
    if length == 11 {
        writer.write_u32_le(microseconds);
    }
}
/// Converts days since 1970-01-01 to a proleptic-Gregorian `(year, month, day)`.
///
/// Month is always in 1..=12 and day in 1..=31. The arithmetic is done in `i64` so
/// that every `i32` input is handled without overflow. The previous `i32` form
/// overflowed on `days + 719_468` near `i32::MAX`.
fn days_to_ymd(days: i32) -> (i32, i32, i32) {
    const DAYS_PER_ERA: i64 = 146_097; // days in a 400-year Gregorian cycle

    // Re-base the day count on 0000-03-01, so that leap days fall at the end of each
    // (March-based) year.
    let z = i64::from(days) + 719_468;
    let era = z.div_euclid(DAYS_PER_ERA);
    let doe = z.rem_euclid(DAYS_PER_ERA); // day of era, [0, 146096]
    let yoe = (doe - doe / 1460 + doe / 36_524 - doe / 146_096) / 365; // [0, 399]
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year, [0, 365]
    let mp = (5 * doy + 2) / 153; // March-based month, [0, 11]
    let day = doy - (153 * mp + 2) / 5 + 1; // [1, 31]
    let month = if mp < 10 { mp + 3 } else { mp - 9 }; // [1, 12]
    // January and February belong to the following civil year.
    let year = yoe + era * 400 + i64::from(month <= 2);
    // |year| stays below ~6 million for any i32 input, so these casts are lossless.
    (year as i32, month as i32, day as i32)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_build_stmt_prepare_packet() {
        let packet = build_stmt_prepare_packet("SELECT * FROM users WHERE id = ?", 0);
        assert_eq!(packet[3], 0);
        assert_eq!(packet[4], Command::StmtPrepare as u8);
        assert_eq!(&packet[5..], b"SELECT * FROM users WHERE id = ?");
    }

    #[test]
    fn test_build_stmt_close_packet() {
        let packet = build_stmt_close_packet(42, 0);
        assert_eq!(packet.len(), 9);
        assert_eq!(packet[4], Command::StmtClose as u8);
        let stmt_id = u32::from_le_bytes([packet[5], packet[6], packet[7], packet[8]]);
        assert_eq!(stmt_id, 42);
    }

    #[test]
    fn test_build_stmt_reset_packet() {
        let packet = build_stmt_reset_packet(7, 0);
        assert_eq!(packet.len(), 9);
        assert_eq!(packet[4], Command::StmtReset as u8);
        let stmt_id = u32::from_le_bytes([packet[5], packet[6], packet[7], packet[8]]);
        assert_eq!(stmt_id, 7);
    }

    #[test]
    fn test_parse_stmt_prepare_ok() {
        let data = [
            0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00,
        ];
        let result = parse_stmt_prepare_ok(&data).unwrap();
        assert_eq!(result.statement_id, 1);
        assert_eq!(result.num_columns, 3);
        assert_eq!(result.num_params, 2);
        assert_eq!(result.warnings, 0);
    }

    #[test]
    fn test_parse_stmt_prepare_ok_invalid() {
        assert!(parse_stmt_prepare_ok(&[0x00, 0x01]).is_none());
        let data = [
            0xFF, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00,
        ];
        assert!(parse_stmt_prepare_ok(&data).is_none());
    }

    #[test]
    fn test_build_stmt_execute_no_params() {
        let packet = build_stmt_execute_packet(1, &[], None, 0);
        assert_eq!(packet[4], Command::StmtExecute as u8);
        let stmt_id = u32::from_le_bytes([packet[5], packet[6], packet[7], packet[8]]);
        assert_eq!(stmt_id, 1);
        assert_eq!(packet[9], 0x00);
        let iter_count = u32::from_le_bytes([packet[10], packet[11], packet[12], packet[13]]);
        assert_eq!(iter_count, 1);
    }

    #[test]
    fn test_build_stmt_execute_with_params() {
        let params = vec![Value::Int(42), Value::Text("hello".to_string())];
        let packet = build_stmt_execute_packet(1, &params, None, 0);
        assert_eq!(packet[4], Command::StmtExecute as u8);
        let stmt_id = u32::from_le_bytes([packet[5], packet[6], packet[7], packet[8]]);
        assert_eq!(stmt_id, 1);
        assert_eq!(packet[9], 0x00);
        let iter_count = u32::from_le_bytes([packet[10], packet[11], packet[12], packet[13]]);
        assert_eq!(iter_count, 1);
        // Null bitmap: no NULLs.
        assert_eq!(packet[14], 0x00);
        // new-params-bound flag.
        assert_eq!(packet[15], 0x01);
        assert_eq!(packet[16], FieldType::Long as u8);
        assert_eq!(packet[17], 0x00);
        assert_eq!(packet[18], FieldType::VarString as u8);
        assert_eq!(packet[19], 0x00);
    }

    #[test]
    fn test_build_stmt_execute_with_null() {
        let params = vec![Value::Null, Value::Int(42)];
        let packet = build_stmt_execute_packet(1, &params, None, 0);
        // Bit 0 of the null bitmap marks the first parameter as NULL.
        assert_eq!(packet[14], 0x01);
    }

    #[test]
    fn test_length_encoded_bytes_two_byte_prefix() {
        let mut writer = PacketWriter::with_capacity(0);
        write_length_encoded_bytes(&mut writer, &[0xAB; 300]);
        let packet = writer.build_packet(0);
        // 4-byte header, then 0xFC + u16 LE length, then the data.
        assert_eq!(packet.len(), 4 + 3 + 300);
        assert_eq!(packet[4], 0xFC);
        assert_eq!(u16::from_le_bytes([packet[5], packet[6]]), 300);
    }

    #[test]
    fn test_value_to_field_type() {
        assert_eq!(value_to_field_type(&Value::Null), FieldType::Null);
        assert_eq!(value_to_field_type(&Value::Bool(true)), FieldType::Tiny);
        assert_eq!(value_to_field_type(&Value::TinyInt(1)), FieldType::Tiny);
        assert_eq!(value_to_field_type(&Value::SmallInt(1)), FieldType::Short);
        assert_eq!(value_to_field_type(&Value::Int(1)), FieldType::Long);
        assert_eq!(value_to_field_type(&Value::BigInt(1)), FieldType::LongLong);
        assert_eq!(value_to_field_type(&Value::Float(1.0)), FieldType::Float);
        assert_eq!(value_to_field_type(&Value::Double(1.0)), FieldType::Double);
        assert_eq!(
            value_to_field_type(&Value::Text(String::new())),
            FieldType::VarString
        );
        assert_eq!(value_to_field_type(&Value::Bytes(vec![])), FieldType::Blob);
    }

    #[test]
    fn test_days_to_ymd() {
        assert_eq!(days_to_ymd(0), (1970, 1, 1));
        assert_eq!(days_to_ymd(10957), (2000, 1, 1));
        assert_eq!(days_to_ymd(19782), (2024, 2, 29));
        assert_eq!(days_to_ymd(-1), (1969, 12, 31));
    }
}