use bytes::{BufMut, BytesMut};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tracing::debug;
use crate::YamlBaseError;
use crate::database::Value;
use crate::protocol::catalog_router::CatalogRouter;
use crate::sql::executor::QueryResult;
use crate::sql::{QueryExecutor, parse_sql};
use crate::yaml::schema::SqlType;
use sqlparser::ast::{
Expr, FunctionArg, FunctionArgExpr, FunctionArguments, SelectItem, Statement, Value as SqlValue,
};
/// A statement registered by a Parse message and stored under its name in
/// `ExtendedProtocol::prepared_statements`.
#[derive(Debug, Clone)]
pub struct PreparedStatement {
/// Statement name taken from the Parse message.
pub name: String,
/// The query text exactly as received.
pub query: String,
/// One type per `$n` parameter: the OIDs sent by the client, or types inferred from the
/// query when the client sent none.
pub parameter_types: Vec<SqlType>,
/// Output of `parse_sql` for `query`; only the first statement is executed/described.
pub parsed_statements: Vec<sqlparser::ast::Statement>,
}
/// A prepared statement bound to concrete parameter values by a Bind message, stored
/// under its name in `ExtendedProtocol::portals`.
#[derive(Debug, Clone)]
pub struct Portal {
/// Portal name taken from the Bind message.
pub name: String,
/// Copy of the prepared statement as it was when the portal was bound.
pub statement: PreparedStatement,
/// Decoded parameter values; index 0 binds `$1`.
pub parameters: Vec<Value>,
/// Result-column format codes from the Bind message (1 = binary, otherwise text).
pub result_formats: Vec<u16>,
}
/// State for the PostgreSQL extended query protocol (Parse/Bind/Describe/Execute/Sync).
pub struct ExtendedProtocol {
/// Prepared statements keyed by statement name.
pub prepared_statements: HashMap<String, PreparedStatement>,
/// Bound portals keyed by portal name.
pub portals: HashMap<String, Portal>,
/// Optional router that Execute offers each statement to before the query executor.
pub catalog_router: Option<Arc<CatalogRouter>>,
}
impl ExtendedProtocol {
pub fn new() -> Self {
Self {
prepared_statements: HashMap::new(),
portals: HashMap::new(),
catalog_router: None,
}
}
pub fn with_catalog_router(catalog_router: Arc<CatalogRouter>) -> Self {
Self {
prepared_statements: HashMap::new(),
portals: HashMap::new(),
catalog_router: Some(catalog_router),
}
}
pub fn set_catalog_router(&mut self, catalog_router: Arc<CatalogRouter>) {
self.catalog_router = Some(catalog_router);
}
}
impl Default for ExtendedProtocol {
fn default() -> Self {
Self::new()
}
}
impl ExtendedProtocol {
/// Handles a Parse ('P') message and replies with ParseComplete ('1').
///
/// `data` is the message body (type byte and length prefix already removed):
/// a NUL-terminated statement name, a NUL-terminated query string, an `Int16` count of
/// parameter types, then that many `Int32` type OIDs. The query is parsed immediately and
/// stored under the statement name, replacing any statement with the same name. When the
/// client sends no parameter types and the first statement is a query, the types are
/// inferred from the query with `infer_parameter_types`.
///
/// # Errors
/// Returns `YamlBaseError::Protocol` for a truncated message, a string without its NUL
/// terminator, or invalid UTF-8; propagates SQL parse errors and socket write failures.
pub async fn handle_parse(&mut self, stream: &mut TcpStream, data: &[u8]) -> crate::Result<()> {
    /// Reads the NUL-terminated UTF-8 string at `*pos` and advances past the terminator.
    /// A missing terminator is an error, so later reads can never start past the end.
    fn read_cstr<'a>(
        data: &'a [u8],
        pos: &mut usize,
        invalid_utf8: &str,
    ) -> crate::Result<&'a str> {
        let rest = data.get(*pos..).unwrap_or_default();
        let len = rest
            .iter()
            .position(|&b| b == 0)
            .ok_or_else(|| YamlBaseError::Protocol("Incomplete parse message".to_string()))?;
        let text = std::str::from_utf8(&rest[..len])
            .map_err(|_| YamlBaseError::Protocol(invalid_utf8.to_string()))?;
        *pos += len + 1;
        Ok(text)
    }

    /// Reads a big-endian `u16` at `*pos` and advances by 2.
    fn read_u16(data: &[u8], pos: &mut usize, incomplete: &str) -> crate::Result<u16> {
        let bytes = data
            .get(*pos..*pos + 2)
            .ok_or_else(|| YamlBaseError::Protocol(incomplete.to_string()))?;
        *pos += 2;
        Ok(u16::from_be_bytes([bytes[0], bytes[1]]))
    }

    /// Reads a big-endian `u32` at `*pos` and advances by 4.
    fn read_u32(data: &[u8], pos: &mut usize, incomplete: &str) -> crate::Result<u32> {
        let bytes = data
            .get(*pos..*pos + 4)
            .ok_or_else(|| YamlBaseError::Protocol(incomplete.to_string()))?;
        *pos += 4;
        Ok(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
    }

    debug!("Handling Parse message");
    let mut pos = 0;
    let name = read_cstr(data, &mut pos, "Invalid UTF-8 in statement name")?.to_string();
    let query = read_cstr(data, &mut pos, "Invalid UTF-8 in query")?.to_string();

    let param_count = read_u16(data, &mut pos, "Incomplete parse message")? as usize;
    let mut parameter_types = Vec::with_capacity(param_count);
    for _ in 0..param_count {
        let oid = read_u32(data, &mut pos, "Incomplete parameter types")?;
        parameter_types.push(oid_to_sql_type(oid));
    }

    let parsed_statements = parse_sql(&query)?;
    // Clients that leave parameter types unspecified still need a ParameterDescription
    // with the right arity, so derive the types from the query itself.
    if parameter_types.is_empty() {
        if let Some(Statement::Query(query_ref)) = parsed_statements.first() {
            let inferred_types = infer_parameter_types(query_ref);
            debug!("Inferred {} parameters from query", inferred_types.len());
            parameter_types = inferred_types;
        }
    }
    debug!(
        "PreparedStatement '{}' has {} parameter types",
        name,
        parameter_types.len()
    );

    let stmt = PreparedStatement {
        name: name.clone(),
        query,
        parameter_types,
        parsed_statements,
    };
    self.prepared_statements.insert(name, stmt);

    // ParseComplete: type '1', length 4, no body.
    let mut buf = BytesMut::new();
    buf.put_u8(b'1');
    buf.put_u32(4);
    stream.write_all(&buf).await?;
    Ok(())
}
pub async fn handle_bind(&mut self, stream: &mut TcpStream, data: &[u8]) -> crate::Result<()> {
debug!("Handling Bind message");
let mut pos = 0;
let portal_name_end = data[pos..]
.iter()
.position(|&b| b == 0)
.unwrap_or(data.len() - pos);
let portal_name = std::str::from_utf8(&data[pos..pos + portal_name_end])
.map_err(|_| YamlBaseError::Protocol("Invalid UTF-8 in portal name".to_string()))?
.to_string();
pos += portal_name_end + 1;
let stmt_name_end = data[pos..]
.iter()
.position(|&b| b == 0)
.unwrap_or(data.len() - pos);
let stmt_name = std::str::from_utf8(&data[pos..pos + stmt_name_end])
.map_err(|_| YamlBaseError::Protocol("Invalid UTF-8 in statement name".to_string()))?
.to_string();
pos += stmt_name_end + 1;
let statement = self
.prepared_statements
.get(&stmt_name)
.ok_or_else(|| {
YamlBaseError::Protocol(format!("Unknown prepared statement: {}", stmt_name))
})?
.clone();
if pos + 2 > data.len() {
return Err(YamlBaseError::Protocol(
"Incomplete bind message".to_string(),
));
}
let format_code_count = u16::from_be_bytes([data[pos], data[pos + 1]]) as usize;
pos += 2;
let mut _format_codes = Vec::new();
for _ in 0..format_code_count {
if pos + 2 > data.len() {
return Err(YamlBaseError::Protocol(
"Incomplete format codes".to_string(),
));
}
let format = u16::from_be_bytes([data[pos], data[pos + 1]]);
_format_codes.push(format);
pos += 2;
}
if pos + 2 > data.len() {
return Err(YamlBaseError::Protocol(
"Incomplete parameter count".to_string(),
));
}
let param_value_count = u16::from_be_bytes([data[pos], data[pos + 1]]) as usize;
pos += 2;
let mut parameters = Vec::new();
for i in 0..param_value_count {
if pos + 4 > data.len() {
return Err(YamlBaseError::Protocol(
"Incomplete parameter value".to_string(),
));
}
let length =
i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
pos += 4;
if length == -1 {
parameters.push(Value::Null);
} else {
let length = length as usize;
if pos + length > data.len() {
return Err(YamlBaseError::Protocol(
"Incomplete parameter data".to_string(),
));
}
let value_data = &data[pos..pos + length];
pos += length;
let sql_type = statement.parameter_types.get(i).unwrap_or(&SqlType::Text);
let value = parse_parameter_value(value_data, sql_type)?;
parameters.push(value);
}
}
if pos + 2 > data.len() {
return Err(YamlBaseError::Protocol(
"Incomplete result format count".to_string(),
));
}
let result_format_count = u16::from_be_bytes([data[pos], data[pos + 1]]) as usize;
pos += 2;
let mut result_formats = Vec::new();
for _ in 0..result_format_count {
if pos + 2 > data.len() {
return Err(YamlBaseError::Protocol(
"Incomplete result format codes".to_string(),
));
}
let format = u16::from_be_bytes([data[pos], data[pos + 1]]);
result_formats.push(format);
pos += 2;
}
let portal = Portal {
name: portal_name.clone(),
statement,
parameters,
result_formats,
};
self.portals.insert(portal_name, portal);
let mut buf = BytesMut::new();
buf.put_u8(b'2');
buf.put_u32(4);
stream.write_all(&buf).await?;
Ok(())
}
pub async fn handle_describe(
&self,
stream: &mut TcpStream,
data: &[u8],
executor: &QueryExecutor,
) -> crate::Result<()> {
debug!("Handling Describe message with {} bytes", data.len());
if data.is_empty() {
return Err(YamlBaseError::Protocol(
"Empty describe message".to_string(),
));
}
let describe_type = data[0];
let name_end = data[1..]
.iter()
.position(|&b| b == 0)
.unwrap_or(data.len() - 1);
let name = std::str::from_utf8(&data[1..1 + name_end])
.map_err(|_| YamlBaseError::Protocol("Invalid UTF-8 in describe name".to_string()))?;
match describe_type {
b'S' => {
if let Some(stmt) = self.prepared_statements.get(name) {
let mut buf = BytesMut::new();
buf.put_u8(b't');
buf.put_u32(4 + 2 + stmt.parameter_types.len() as u32 * 4);
buf.put_u16(stmt.parameter_types.len() as u16);
for param_type in &stmt.parameter_types {
buf.put_u32(sql_type_to_oid(param_type));
}
stream.write_all(&buf).await?;
if !stmt.parsed_statements.is_empty() {
if let sqlparser::ast::Statement::Query(query) = &stmt.parsed_statements[0]
{
if let sqlparser::ast::SetExpr::Select(select) = &*query.body {
let (columns, types) =
extract_columns_and_types_from_select(select, executor);
send_row_description_for_columns_with_types(
stream, &columns, &types,
)
.await?;
} else {
buf.clear();
buf.put_u8(b'n');
buf.put_u32(4);
stream.write_all(&buf).await?;
}
} else {
buf.clear();
buf.put_u8(b'n');
buf.put_u32(4);
stream.write_all(&buf).await?;
}
}
} else {
return Err(YamlBaseError::Protocol(format!(
"Unknown statement: {}",
name
)));
}
}
b'P' => {
if let Some(portal) = self.portals.get(name) {
if !portal.statement.parsed_statements.is_empty() {
if let sqlparser::ast::Statement::Query(_) =
&portal.statement.parsed_statements[0]
{
match executor
.execute(&portal.statement.parsed_statements[0])
.await
{
Ok(result) => {
send_row_description(stream, &result).await?;
}
Err(_) => {
let mut buf = BytesMut::new();
buf.put_u8(b'n');
buf.put_u32(4);
stream.write_all(&buf).await?;
}
}
} else {
let mut buf = BytesMut::new();
buf.put_u8(b'n');
buf.put_u32(4);
stream.write_all(&buf).await?;
}
}
} else {
return Err(YamlBaseError::Protocol(format!("Unknown portal: {}", name)));
}
}
_ => {
return Err(YamlBaseError::Protocol(format!(
"Unknown describe type: {}",
describe_type as char
)));
}
}
Ok(())
}
/// Handles an Execute ('E') message: runs the named portal and streams its rows.
///
/// `data` holds the portal name (NUL-terminated) followed by an `Int32` row limit. The
/// row limit is read but not enforced: every row is sent and the portal is never
/// suspended. The portal's first statement is executed after parameter substitution.
/// It is offered to the catalog router first (when one is configured) and otherwise
/// run by `executor`. Rows are sent as DataRow messages followed by CommandComplete
/// with a `SELECT <n>` tag. A portal whose query text produced no statements gets
/// EmptyQueryResponse ('I') instead.
///
/// # Errors
/// Returns `YamlBaseError::Protocol` for a truncated message or unknown portal, and
/// propagates parameter-substitution, routing, execution, and socket errors.
pub async fn handle_execute(
    &self,
    stream: &mut TcpStream,
    data: &[u8],
    executor: &QueryExecutor,
) -> crate::Result<()> {
    debug!("Handling Execute message");
    let name_end = data.iter().position(|&b| b == 0).unwrap_or(data.len());
    let portal_name = std::str::from_utf8(&data[..name_end])
        .map_err(|_| YamlBaseError::Protocol("Invalid UTF-8 in portal name".to_string()))?;
    let pos = name_end + 1;
    let Some(limit_bytes) = data.get(pos..pos + 4) else {
        return Err(YamlBaseError::Protocol(
            "Incomplete execute message".to_string(),
        ));
    };
    let _row_limit =
        u32::from_be_bytes([limit_bytes[0], limit_bytes[1], limit_bytes[2], limit_bytes[3]]);

    let portal = self
        .portals
        .get(portal_name)
        .ok_or_else(|| YamlBaseError::Protocol(format!("Unknown portal: {}", portal_name)))?;
    debug!(
        "Executing prepared statement '{}' with query: {}",
        portal.statement.name, portal.statement.query
    );
    debug!("Parameters: {:?}", portal.parameters);

    let Some(first) = portal.statement.parsed_statements.first() else {
        // The protocol answers an empty query with EmptyQueryResponse, not CommandComplete.
        let mut buf = BytesMut::new();
        buf.put_u8(b'I');
        buf.put_u32(4);
        stream.write_all(&buf).await?;
        return Ok(());
    };

    let mut statement = first.clone();
    substitute_parameters(&mut statement, &portal.parameters)?;
    debug!("Query after parameter substitution: {:?}", statement);

    let result = if let Some(catalog_router) = &self.catalog_router {
        if let Some(catalog_result) = catalog_router.route_statement(&statement)? {
            debug!("Catalog query handled by catalog router");
            catalog_result
        } else {
            executor.execute(&statement).await?
        }
    } else {
        executor.execute(&statement).await?
    };
    debug!(
        "Execute result: {} rows, {} columns: {:?}",
        result.rows.len(),
        result.columns.len(),
        result.columns
    );
    if let Some(first_row) = result.rows.first() {
        debug!("First row: {:?}", first_row);
    }

    send_data_rows(stream, &result, &portal.result_formats).await?;

    // CommandComplete: type 'C', length, NUL-terminated command tag.
    let tag = format!("SELECT {}", result.rows.len());
    let mut buf = BytesMut::new();
    buf.put_u8(b'C');
    buf.put_u32(4 + tag.len() as u32 + 1);
    buf.put_slice(tag.as_bytes());
    buf.put_u8(0);
    stream.write_all(&buf).await?;
    Ok(())
}
pub async fn handle_sync(&self, stream: &mut TcpStream) -> crate::Result<()> {
debug!("Handling Sync message");
let mut buf = BytesMut::new();
buf.put_u8(b'Z');
buf.put_u32(5);
buf.put_u8(b'I'); stream.write_all(&buf).await?;
Ok(())
}
/// Closes the prepared statement `name` together with every portal bound from it.
///
/// The PostgreSQL protocol specifies that closing a prepared statement implicitly closes
/// the portals constructed from it, so those portals are dropped as well. Closing an
/// unknown name is not an error.
pub fn close_statement(&mut self, name: &str) {
    self.prepared_statements.remove(name);
    self.portals
        .retain(|_, portal| portal.statement.name != name);
}
/// Closes the portal `name`; closing a portal that does not exist is not an error.
pub fn close_portal(&mut self, name: &str) {
    let _closed = self.portals.remove(name);
}
}
async fn send_row_description(stream: &mut TcpStream, result: &QueryResult) -> crate::Result<()> {
let mut buf = BytesMut::new();
buf.put_u8(b'T');
let mut length = 6; for col in &result.columns {
length += col.len() + 1 + 18; }
buf.put_u32(length as u32);
buf.put_u16(result.columns.len() as u16);
for (i, col) in result.columns.iter().enumerate() {
buf.put_slice(col.as_bytes());
buf.put_u8(0); buf.put_u32(0); buf.put_u16(i as u16);
let type_oid = if i < result.column_types.len() {
sql_type_to_oid(&result.column_types[i])
} else {
25 };
buf.put_u32(type_oid);
buf.put_i16(-1); buf.put_i32(-1); buf.put_i16(0); }
stream.write_all(&buf).await?;
Ok(())
}
async fn send_row_description_for_columns_with_types(
stream: &mut TcpStream,
columns: &[String],
types: &[SqlType],
) -> crate::Result<()> {
let mut buf = BytesMut::new();
buf.put_u8(b'T');
let mut length = 6; for col in columns {
length += col.len() + 1 + 18; }
buf.put_u32(length as u32);
buf.put_u16(columns.len() as u16);
for (i, col) in columns.iter().enumerate() {
buf.put_slice(col.as_bytes());
buf.put_u8(0); buf.put_u32(0); buf.put_u16(i as u16);
let type_oid = if i < types.len() {
sql_type_to_oid(&types[i])
} else {
25 };
buf.put_u32(type_oid);
buf.put_i16(-1); buf.put_i32(-1); buf.put_i16(0); }
stream.write_all(&buf).await?;
Ok(())
}
fn extract_columns_and_types_from_select(
select: &sqlparser::ast::Select,
executor: &QueryExecutor,
) -> (Vec<String>, Vec<SqlType>) {
let mut columns = Vec::new();
let mut types = Vec::new();
for item in &select.projection {
match item {
sqlparser::ast::SelectItem::UnnamedExpr(expr) => {
match expr {
Expr::Identifier(ident) => {
columns.push(ident.value.clone());
types.push(infer_type_from_column_name(&ident.value));
}
Expr::Function(func) => {
let func_name = func
.name
.0
.first()
.map(|ident| ident.value.to_uppercase())
.unwrap_or_default();
match func_name.as_str() {
"COUNT" => {
columns.push(func_name.clone());
types.push(SqlType::BigInt); }
"SUM" => {
columns.push(func_name.clone());
types.push(SqlType::Double); }
"AVG" => {
columns.push(func_name.clone());
types.push(SqlType::Double);
}
_ => {
columns.push(func_name.clone());
types.push(SqlType::Text);
}
}
}
_ => {
columns.push(format!("column{}", columns.len()));
types.push(SqlType::Text);
}
}
}
sqlparser::ast::SelectItem::ExprWithAlias { expr, alias } => {
columns.push(alias.value.clone());
match expr {
Expr::Function(func) => {
let func_name = func
.name
.0
.first()
.map(|ident| ident.value.to_uppercase())
.unwrap_or_default();
match func_name.as_str() {
"COUNT" => types.push(SqlType::BigInt), "SUM" => types.push(SqlType::Double),
"AVG" => types.push(SqlType::Double),
_ => types.push(SqlType::Text),
}
}
_ => types.push(SqlType::Text),
}
}
sqlparser::ast::SelectItem::Wildcard(_) => {
if let Some(table) = select.from.first() {
if let Some(table_name) = get_table_name_from_relation(&table.relation) {
if let Ok(db) = executor.storage().database().try_read() {
if let Some(table) = db.get_table(&table_name) {
for col in &table.columns {
columns.push(col.name.clone());
types.push(col.sql_type.clone());
}
}
}
}
}
}
_ => {
columns.push(format!("column{}", columns.len()));
types.push(SqlType::Text);
}
}
}
(columns, types)
}
/// Guesses a column's SQL type from its name alone (case-insensitive).
///
/// Recognized names map to Integer, Double, Boolean, Timestamp or Date; every other name
/// is treated as Text.
fn infer_type_from_column_name(name: &str) -> SqlType {
    const INTEGER_NAMES: [&str; 4] = ["age", "id", "count", "quantity"];
    const DOUBLE_NAMES: [&str; 3] = ["price", "amount", "total"];
    const BOOLEAN_NAMES: [&str; 5] = ["active", "enabled", "deleted", "is_active", "in_stock"];
    const TIMESTAMP_NAMES: [&str; 2] = ["created_at", "updated_at"];

    let lowered = name.to_lowercase();
    let key = lowered.as_str();
    if INTEGER_NAMES.contains(&key) {
        SqlType::Integer
    } else if DOUBLE_NAMES.contains(&key) {
        SqlType::Double
    } else if BOOLEAN_NAMES.contains(&key) {
        SqlType::Boolean
    } else if TIMESTAMP_NAMES.contains(&key) {
        SqlType::Timestamp
    } else if key == "created_date" {
        SqlType::Date
    } else {
        SqlType::Text
    }
}
/// Returns the table name of a plain `FROM` table reference, or `None` for derived
/// tables, joins in parentheses, table functions and other relation kinds.
///
/// For a qualified name such as `public.users`, the table is the last component; the
/// leading components name the schema/catalog and are ignored.
fn get_table_name_from_relation(relation: &sqlparser::ast::TableFactor) -> Option<String> {
    match relation {
        sqlparser::ast::TableFactor::Table { name, .. } => {
            name.0.last().map(|ident| ident.value.clone())
        }
        _ => None,
    }
}
async fn send_data_rows(
stream: &mut TcpStream,
result: &QueryResult,
result_formats: &[u16],
) -> crate::Result<()> {
for row in &result.rows {
let mut buf = BytesMut::new();
buf.put_u8(b'D');
let mut row_length = 6; for (col_idx, val) in row.iter().enumerate() {
if matches!(val, Value::Null) {
row_length += 4; } else {
let format = if result_formats.is_empty() {
0 } else if result_formats.len() == 1 {
result_formats[0] } else if col_idx < result_formats.len() {
result_formats[col_idx] } else {
0 };
if format == 1 {
match val {
Value::Integer(_) => {
let col_type = result.column_types.get(col_idx);
match col_type {
Some(SqlType::BigInt) => row_length += 4 + 8, Some(SqlType::Integer) => row_length += 4 + 4, _ => row_length += 4 + 4, }
}
Value::Boolean(_) => row_length += 4 + 1, Value::Float(_) => row_length += 4 + 4, Value::Double(_) => row_length += 4 + 8, _ => {
let val_str = val.to_string();
row_length += 4 + val_str.len();
}
}
} else {
let val_str = val.to_string();
row_length += 4 + val_str.len();
}
}
}
buf.put_u32(row_length as u32);
buf.put_u16(row.len() as u16);
for (col_idx, val) in row.iter().enumerate() {
if matches!(val, Value::Null) {
buf.put_i32(-1); } else {
let format = if result_formats.is_empty() {
0 } else if result_formats.len() == 1 {
result_formats[0] } else if col_idx < result_formats.len() {
result_formats[col_idx] } else {
0 };
if format == 1 {
match val {
Value::Integer(i) => {
let col_type = result.column_types.get(col_idx);
match col_type {
Some(SqlType::BigInt) => {
buf.put_i32(8); buf.put_i64(*i); }
Some(SqlType::Integer) => {
buf.put_i32(4); buf.put_i32(*i as i32); }
_ => {
buf.put_i32(4); buf.put_i32(*i as i32); }
}
}
Value::Boolean(b) => {
buf.put_i32(1); buf.put_u8(if *b { 1 } else { 0 });
}
Value::Float(f) => {
buf.put_i32(4); buf.put_f32(*f);
}
Value::Double(d) => {
buf.put_i32(8); buf.put_f64(*d);
}
_ => {
let val_str = val.to_string();
buf.put_i32(val_str.len() as i32);
buf.put_slice(val_str.as_bytes());
}
}
} else {
let val_str = val.to_string();
buf.put_i32(val_str.len() as i32);
buf.put_slice(val_str.as_bytes());
}
}
}
stream.write_all(&buf).await?;
}
Ok(())
}
/// Maps a PostgreSQL type OID (as sent in a Parse message) to the closest `SqlType`.
///
/// Types that carry a length or precision get fixed defaults: bpchar → `Char(1)`,
/// varchar → `Varchar(255)`, numeric → `Decimal(38, 0)`. Unrecognized OIDs, including 0,
/// map to `Text`.
fn oid_to_sql_type(oid: u32) -> SqlType {
    match oid {
        16 => SqlType::Boolean,          // bool
        20 => SqlType::BigInt,           // int8
        21 | 23 => SqlType::Integer,     // int2, int4
        700 => SqlType::Float,           // float4
        701 => SqlType::Double,          // float8
        1042 => SqlType::Char(1),        // bpchar
        1043 => SqlType::Varchar(255),   // varchar
        1082 => SqlType::Date,           // date
        1083 => SqlType::Time,           // time
        1114 => SqlType::Timestamp,      // timestamp
        1700 => SqlType::Decimal(38, 0), // numeric
        2950 => SqlType::Uuid,           // uuid
        3802 => SqlType::Json,           // jsonb
        _ => SqlType::Text,              // text (25) and everything unrecognized
    }
}
/// Maps an `SqlType` to the PostgreSQL type OID reported to clients in
/// ParameterDescription and RowDescription messages.
///
/// Length and precision parameters (`Char(n)`, `Varchar(n)`, `Decimal(p, s)`) are not
/// encoded; only the base type is reported.
fn sql_type_to_oid(sql_type: &SqlType) -> u32 {
    // OIDs from PostgreSQL's pg_type catalog.
    const BOOL: u32 = 16;
    const INT8: u32 = 20;
    const INT4: u32 = 23;
    const TEXT: u32 = 25;
    const FLOAT4: u32 = 700;
    const FLOAT8: u32 = 701;
    const BPCHAR: u32 = 1042;
    const VARCHAR: u32 = 1043;
    const DATE: u32 = 1082;
    const TIME: u32 = 1083;
    const TIMESTAMP: u32 = 1114;
    const NUMERIC: u32 = 1700;
    const UUID: u32 = 2950;
    const JSONB: u32 = 3802;

    match sql_type {
        SqlType::Boolean => BOOL,
        SqlType::BigInt => INT8,
        SqlType::Integer => INT4,
        SqlType::Text => TEXT,
        SqlType::Float => FLOAT4,
        SqlType::Double => FLOAT8,
        SqlType::Char(_) => BPCHAR,
        SqlType::Varchar(_) => VARCHAR,
        SqlType::Date => DATE,
        SqlType::Time => TIME,
        SqlType::Timestamp => TIMESTAMP,
        SqlType::Decimal(_, _) => NUMERIC,
        SqlType::Uuid => UUID,
        SqlType::Json => JSONB,
    }
}
/// Replaces `$n` placeholders in `statement` with the bound `parameters`.
///
/// Queries are rewritten by `substitute_parameters_in_query`. Other statements have no
/// placeholder support; they pass through unchanged when no parameters are bound (e.g.
/// `BEGIN` sent via the extended protocol) and are rejected otherwise.
///
/// # Errors
/// Returns `YamlBaseError::Protocol` for a non-query statement with bound parameters,
/// and propagates placeholder errors from query substitution.
fn substitute_parameters(statement: &mut Statement, parameters: &[Value]) -> crate::Result<()> {
    match statement {
        Statement::Query(query) => substitute_parameters_in_query(query, parameters),
        _ if parameters.is_empty() => Ok(()),
        _ => Err(YamlBaseError::Protocol(
            "Parameter substitution only supported for queries".to_string(),
        )),
    }
}
fn substitute_parameters_in_query(
query: &mut sqlparser::ast::Query,
parameters: &[Value],
) -> crate::Result<()> {
if let sqlparser::ast::SetExpr::Select(select) = &mut *query.body {
if let Some(selection) = &mut select.selection {
substitute_parameters_in_expr(selection, parameters)?;
}
}
Ok(())
}
fn substitute_parameters_in_expr(expr: &mut Expr, parameters: &[Value]) -> crate::Result<()> {
match expr {
Expr::Value(SqlValue::Placeholder(s)) => {
if let Some(num_str) = s.strip_prefix('$') {
if let Ok(param_idx) = num_str.parse::<usize>() {
if param_idx > 0 && param_idx <= parameters.len() {
let param_value = ¶meters[param_idx - 1];
*expr = value_to_sql_expr(param_value);
} else {
return Err(YamlBaseError::Protocol(format!(
"Invalid parameter index: ${}",
param_idx
)));
}
} else {
return Err(YamlBaseError::Protocol(format!(
"Invalid placeholder format: {}",
s
)));
}
}
}
Expr::BinaryOp { left, right, .. } => {
substitute_parameters_in_expr(left, parameters)?;
substitute_parameters_in_expr(right, parameters)?;
}
Expr::UnaryOp { expr, .. } => {
substitute_parameters_in_expr(expr, parameters)?;
}
Expr::InList { expr, list, .. } => {
substitute_parameters_in_expr(expr, parameters)?;
for item in list {
substitute_parameters_in_expr(item, parameters)?;
}
}
Expr::Between {
expr, low, high, ..
} => {
substitute_parameters_in_expr(expr, parameters)?;
substitute_parameters_in_expr(low, parameters)?;
substitute_parameters_in_expr(high, parameters)?;
}
Expr::Case {
operand,
conditions,
results,
else_result,
} => {
if let Some(op) = operand {
substitute_parameters_in_expr(op, parameters)?;
}
for cond in conditions {
substitute_parameters_in_expr(cond, parameters)?;
}
for res in results {
substitute_parameters_in_expr(res, parameters)?;
}
if let Some(else_res) = else_result {
substitute_parameters_in_expr(else_res, parameters)?;
}
}
Expr::Nested(inner) => {
substitute_parameters_in_expr(inner, parameters)?;
}
Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
substitute_parameters_in_expr(inner, parameters)?;
}
Expr::Like { expr, pattern, .. } => {
substitute_parameters_in_expr(expr, parameters)?;
substitute_parameters_in_expr(pattern, parameters)?;
}
_ => {}
}
Ok(())
}
fn value_to_sql_expr(value: &Value) -> Expr {
match value {
Value::Null => Expr::Value(SqlValue::Null),
Value::Boolean(b) => Expr::Value(SqlValue::Boolean(*b)),
Value::Integer(i) => Expr::Value(SqlValue::Number(i.to_string(), false)),
Value::Float(f) => Expr::Value(SqlValue::Number(f.to_string(), false)),
Value::Double(d) => Expr::Value(SqlValue::Number(d.to_string(), false)),
Value::Text(s) => Expr::Value(SqlValue::SingleQuotedString(s.clone())),
Value::Date(d) => Expr::Value(SqlValue::SingleQuotedString(d.to_string())),
Value::Time(t) => Expr::Value(SqlValue::SingleQuotedString(t.to_string())),
Value::Timestamp(ts) => Expr::Value(SqlValue::SingleQuotedString(ts.to_string())),
Value::Uuid(u) => Expr::Value(SqlValue::SingleQuotedString(u.to_string())),
Value::Json(j) => Expr::Value(SqlValue::SingleQuotedString(j.to_string())),
Value::Decimal(d) => Expr::Value(SqlValue::Number(d.to_string(), false)),
}
}
fn infer_parameter_types(query: &sqlparser::ast::Query) -> Vec<SqlType> {
let mut parameter_types = std::collections::HashMap::new();
if let sqlparser::ast::SetExpr::Select(select) = &*query.body {
if let Some(selection) = &select.selection {
infer_types_in_expr(selection, &mut parameter_types);
}
for item in &select.projection {
match item {
SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } => {
infer_types_in_projection_expr(expr, &mut parameter_types);
}
_ => {}
}
}
}
let max_param = parameter_types.keys().max().copied().unwrap_or(0);
let mut result = Vec::new();
for i in 1..=max_param {
result.push(parameter_types.get(&i).cloned().unwrap_or(SqlType::Text));
}
result
}
fn infer_types_in_expr(
expr: &Expr,
parameter_types: &mut std::collections::HashMap<usize, SqlType>,
) {
match expr {
Expr::BinaryOp { left, op, right } => {
match op {
sqlparser::ast::BinaryOperator::Eq
| sqlparser::ast::BinaryOperator::NotEq
| sqlparser::ast::BinaryOperator::Lt
| sqlparser::ast::BinaryOperator::LtEq
| sqlparser::ast::BinaryOperator::Gt
| sqlparser::ast::BinaryOperator::GtEq => {
if let Expr::Value(SqlValue::Placeholder(s)) = &**left {
if let Some(num_str) = s.strip_prefix('$') {
if let Ok(param_num) = num_str.parse::<usize>() {
if let Some(inferred_type) = infer_type_from_expr(right) {
parameter_types.insert(param_num, inferred_type);
}
}
}
}
if let Expr::Value(SqlValue::Placeholder(s)) = &**right {
if let Some(num_str) = s.strip_prefix('$') {
if let Ok(param_num) = num_str.parse::<usize>() {
if let Some(inferred_type) = infer_type_from_expr(left) {
parameter_types.insert(param_num, inferred_type);
}
}
}
}
}
sqlparser::ast::BinaryOperator::And | sqlparser::ast::BinaryOperator::Or => {
infer_types_in_expr(left, parameter_types);
infer_types_in_expr(right, parameter_types);
}
_ => {}
}
}
Expr::UnaryOp { expr, .. } => {
infer_types_in_expr(expr, parameter_types);
}
Expr::InList { expr, list, .. } => {
infer_types_in_expr(expr, parameter_types);
for item in list {
infer_types_in_expr(item, parameter_types);
}
}
Expr::Between {
expr, low, high, ..
} => {
infer_types_in_expr(expr, parameter_types);
infer_types_in_expr(low, parameter_types);
infer_types_in_expr(high, parameter_types);
}
Expr::Case {
operand,
conditions,
results,
else_result,
} => {
if let Some(op) = operand {
infer_types_in_expr(op, parameter_types);
}
for cond in conditions {
infer_types_in_expr(cond, parameter_types);
}
for res in results {
infer_types_in_expr(res, parameter_types);
}
if let Some(else_res) = else_result {
infer_types_in_expr(else_res, parameter_types);
}
}
Expr::Nested(inner) => {
infer_types_in_expr(inner, parameter_types);
}
Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
infer_types_in_expr(inner, parameter_types);
}
Expr::Like { expr, pattern, .. } => {
infer_types_in_expr(expr, parameter_types);
if let Expr::Value(SqlValue::Placeholder(s)) = &**pattern {
if let Some(num_str) = s.strip_prefix('$')
&& let Ok(param_num) = num_str.parse::<usize>() {
parameter_types.insert(param_num, SqlType::Text);
}
} else {
infer_types_in_expr(pattern, parameter_types);
}
}
_ => {}
}
}
/// Guesses the SQL type of an expression that a placeholder is compared against.
///
/// * Column references are typed by (case-insensitive) name; for qualified references
///   such as `u.age`, the last component is used.
/// * Boolean and string literals map to Boolean and Text.
/// * Numeric literals map to Double when they contain a fraction or exponent, and to
///   Integer otherwise.
/// * Parenthesised and unary-operator expressions take the type of their operand.
///
/// Returns `None` when nothing can be inferred.
fn infer_type_from_expr(expr: &Expr) -> Option<SqlType> {
    /// Name-based guesses for well-known column names.
    fn from_column_name(name: &str) -> Option<SqlType> {
        match name.to_lowercase().as_str() {
            "age" | "id" | "count" | "quantity" | "value" => Some(SqlType::Integer),
            "price" | "amount" | "total" => Some(SqlType::Double),
            "active" | "enabled" | "deleted" | "is_active" | "in_stock" => {
                Some(SqlType::Boolean)
            }
            "name" | "username" | "email" | "description" | "status" | "customer_name" => {
                Some(SqlType::Text)
            }
            "created_at" | "updated_at" => Some(SqlType::Timestamp),
            "created_date" => Some(SqlType::Date),
            _ => None,
        }
    }

    match expr {
        Expr::Identifier(ident) => from_column_name(&ident.value),
        Expr::CompoundIdentifier(parts) => parts
            .last()
            .and_then(|ident| from_column_name(&ident.value)),
        Expr::Value(SqlValue::Boolean(_)) => Some(SqlType::Boolean),
        Expr::Value(SqlValue::Number(digits, _)) => {
            if digits.contains(|c: char| matches!(c, '.' | 'e' | 'E')) {
                Some(SqlType::Double)
            } else {
                Some(SqlType::Integer)
            }
        }
        Expr::Value(SqlValue::SingleQuotedString(_)) => Some(SqlType::Text),
        Expr::Nested(inner) | Expr::UnaryOp { expr: inner, .. } => infer_type_from_expr(inner),
        _ => None,
    }
}
/// Records placeholder types found in a projection item.
///
/// For a function call, only its positional (unnamed) argument expressions are
/// examined; any other expression is handed to `infer_types_in_expr` as a whole.
fn infer_types_in_projection_expr(
    expr: &Expr,
    parameter_types: &mut std::collections::HashMap<usize, SqlType>,
) {
    let Expr::Function(func) = expr else {
        infer_types_in_expr(expr, parameter_types);
        return;
    };
    let FunctionArguments::List(arg_list) = &func.args else {
        return;
    };
    let positional = arg_list.args.iter().filter_map(|arg| match arg {
        FunctionArg::Unnamed(FunctionArgExpr::Expr(arg_expr)) => Some(arg_expr),
        _ => None,
    });
    for arg_expr in positional {
        infer_types_in_expr(arg_expr, parameter_types);
    }
}
/// Decodes a Bind parameter value using PostgreSQL's big-endian binary encodings for
/// numeric and boolean types.
///
/// * `Integer` accepts 2-, 4- or 8-byte values (int2/int4/int8).
/// * `BigInt` requires 8 bytes, `Float` 4 and `Double` 8.
/// * `Boolean` requires 1 byte (non-zero is true).
/// * Every other type is taken as UTF-8 text.
///
/// # Errors
/// Returns `YamlBaseError::Protocol` when the byte length does not fit the type or the
/// text is not valid UTF-8.
fn parse_parameter_value(data: &[u8], sql_type: &SqlType) -> crate::Result<Value> {
    let protocol_error = |message: &str| YamlBaseError::Protocol(message.to_string());
    match (sql_type, data.len()) {
        (SqlType::Integer, 8) => {
            let raw: [u8; 8] = data
                .try_into()
                .map_err(|_| protocol_error("Failed to parse 8-byte integer"))?;
            Ok(Value::Integer(i64::from_be_bytes(raw)))
        }
        (SqlType::Integer, 4) => {
            let raw: [u8; 4] = data
                .try_into()
                .map_err(|_| protocol_error("Failed to parse 4-byte integer"))?;
            Ok(Value::Integer(i64::from(i32::from_be_bytes(raw))))
        }
        (SqlType::Integer, 2) => {
            let raw: [u8; 2] = data
                .try_into()
                .map_err(|_| protocol_error("Failed to parse 2-byte integer"))?;
            Ok(Value::Integer(i64::from(i16::from_be_bytes(raw))))
        }
        (SqlType::Integer, _) => Err(protocol_error("Invalid integer size")),
        (SqlType::BigInt, 8) => {
            let raw: [u8; 8] = data
                .try_into()
                .map_err(|_| protocol_error("Failed to parse 8-byte bigint"))?;
            Ok(Value::Integer(i64::from_be_bytes(raw)))
        }
        (SqlType::BigInt, _) => Err(protocol_error("Invalid bigint size")),
        (SqlType::Float, 4) => {
            let raw: [u8; 4] = data
                .try_into()
                .map_err(|_| protocol_error("Failed to parse 4-byte float"))?;
            Ok(Value::Float(f32::from_be_bytes(raw)))
        }
        (SqlType::Float, _) => Err(protocol_error("Invalid float size")),
        (SqlType::Double, 8) => {
            let raw: [u8; 8] = data
                .try_into()
                .map_err(|_| protocol_error("Failed to parse 8-byte double"))?;
            Ok(Value::Double(f64::from_be_bytes(raw)))
        }
        (SqlType::Double, _) => Err(protocol_error("Invalid double size")),
        (SqlType::Boolean, 1) => Ok(Value::Boolean(data[0] != 0)),
        (SqlType::Boolean, _) => Err(protocol_error("Invalid boolean size")),
        _ => {
            let text = std::str::from_utf8(data)
                .map_err(|_| protocol_error("Invalid UTF-8 in parameter"))?;
            Ok(Value::Text(text.to_owned()))
        }
    }
}