use crate::{Result, Error, Value, Schema, Tuple};
use crate::sql::logical_plan::LogicalPlan;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// A prepared statement held by a [`PreparedStatementManager`], keyed by `name`.
#[derive(Debug, Clone)]
pub struct PreparedStatement {
/// Key under which the statement is stored. When
/// `PreparedStatementManager::store_statement` picks an eviction victim, the
/// empty name is passed over unless no other cached statement is available.
pub name: String,
/// SQL text; may contain `$N` placeholders (see `substitute_parameters`).
pub query: String,
/// Declared parameter type OIDs in placeholder order — presumably the values
/// later passed as `type_oid` to `decode_parameter`; TODO confirm with callers.
pub param_types: Vec<i32>,
/// Schema of the rows the statement produces, if known
/// (NOTE(review): presumably `None` for row-less statements — confirm).
pub result_schema: Option<Schema>,
/// Logical plan kept for reuse, if one has been built.
pub cached_plan: Option<LogicalPlan>,
}
/// A portal: a named binding of a prepared statement to concrete parameter
/// values, plus its execution progress.
#[derive(Debug, Clone)]
pub struct Portal {
/// Key under which the portal is stored in a `PreparedStatementManager`.
pub name: String,
/// Name of the prepared statement this portal executes
/// (presumably a key in the manager's statement cache — confirm).
pub statement_name: String,
/// Raw parameter values; `None` presumably encodes SQL NULL — confirm with
/// the code that builds portals.
pub params: Vec<Option<Vec<u8>>>,
/// Per-parameter format codes as consumed by `decode_parameter`
/// (`0` = text).
pub param_formats: Vec<i16>,
/// Requested format codes for result columns (presumably the same
/// text/binary convention — confirm).
pub result_formats: Vec<i16>,
/// Execution progress of this portal.
pub state: PortalState,
}
/// Execution progress of a [`Portal`].
#[derive(Debug, Clone, PartialEq)]
pub enum PortalState {
/// Not yet executed (presumably the state a freshly bound portal starts in).
Ready,
/// Execution paused part-way (NOTE(review): presumably after a row-limited
/// execute — confirm against the executor).
Suspended {
/// Number of rows already returned from this portal.
rows_returned: usize,
/// Result rows buffered for resumption, if any were kept.
cached_results: Option<Vec<Tuple>>,
},
/// Execution finished; no more rows remain.
Complete,
}
/// Registry of prepared statements and portals with size limits.
///
/// Each table lives behind its own `RwLock`. Methods that take more than one
/// lock acquire them in the order statements → portals → statement_order.
pub struct PreparedStatementManager {
/// Cached statements keyed by name.
statements: Arc<RwLock<HashMap<String, PreparedStatement>>>,
/// Open portals keyed by name.
portals: Arc<RwLock<HashMap<String, Portal>>>,
/// Statement count at which storing a statement under a new name first
/// evicts the oldest cached statement.
max_statements: usize,
/// Portal count at which storing a portal under a new name is rejected.
max_portals: usize,
/// Cached statement names, oldest first; used to choose eviction victims
/// and kept in sync with `statements`.
statement_order: Arc<RwLock<Vec<String>>>,
}
impl PreparedStatementManager {
pub fn new() -> Self {
Self::with_capacity(1000, 500) }
pub fn with_capacity(max_statements: usize, max_portals: usize) -> Self {
Self {
statements: Arc::new(RwLock::new(HashMap::new())),
portals: Arc::new(RwLock::new(HashMap::new())),
max_statements,
max_portals,
statement_order: Arc::new(RwLock::new(Vec::new())),
}
}
pub fn store_statement(&self, statement: PreparedStatement) -> Result<()> {
use crate::error::LockResultExt;
let mut statements = self.statements.write()
.map_lock_err("Failed to acquire write lock on statements")?;
let mut order = self.statement_order.write()
.map_lock_err("Failed to acquire write lock on statement order")?;
let is_new = !statements.contains_key(&statement.name);
if is_new && statements.len() >= self.max_statements {
let mut evicted = false;
for name in order.iter() {
if !name.is_empty() && statements.contains_key(name) {
statements.remove(name);
tracing::debug!("Evicted prepared statement '{}' due to capacity limit", name);
evicted = true;
break;
}
}
if !evicted && statements.len() >= self.max_statements {
if let Some(first) = order.first().cloned() {
statements.remove(&first);
tracing::debug!("Evicted prepared statement '{}' due to capacity limit", first);
}
}
order.retain(|n| statements.contains_key(n));
}
if !is_new {
order.retain(|n| n != &statement.name);
}
order.push(statement.name.clone());
statements.insert(statement.name.clone(), statement);
Ok(())
}
pub fn get_statement(&self, name: &str) -> Result<Option<PreparedStatement>> {
use crate::error::LockResultExt;
let statements = self.statements.read()
.map_lock_err("Failed to acquire read lock on statements")?;
Ok(statements.get(name).cloned())
}
pub fn remove_statement(&self, name: &str) -> Result<bool> {
use crate::error::LockResultExt;
let mut statements = self.statements.write()
.map_lock_err("Failed to acquire write lock on statements")?;
let mut order = self.statement_order.write()
.map_lock_err("Failed to acquire write lock on statement order")?;
let removed = statements.remove(name).is_some();
if removed {
order.retain(|n| n != name);
}
Ok(removed)
}
pub fn store_portal(&self, portal: Portal) -> Result<()> {
use crate::error::LockResultExt;
let mut portals = self.portals.write()
.map_lock_err("Failed to acquire write lock on portals")?;
if portals.len() >= self.max_portals && !portals.contains_key(&portal.name) {
return Err(Error::resource_limit(format!(
"Maximum number of portals ({}) reached",
self.max_portals
)));
}
portals.insert(portal.name.clone(), portal);
Ok(())
}
pub fn get_portal(&self, name: &str) -> Result<Option<Portal>> {
use crate::error::LockResultExt;
let portals = self.portals.read()
.map_lock_err("Failed to acquire read lock on portals")?;
Ok(portals.get(name).cloned())
}
pub fn update_portal_state(&self, name: &str, state: PortalState) -> Result<()> {
use crate::error::LockResultExt;
let mut portals = self.portals.write()
.map_lock_err("Failed to acquire write lock on portals")?;
if let Some(portal) = portals.get_mut(name) {
portal.state = state;
Ok(())
} else {
Err(Error::query_execution(format!("Portal '{}' not found", name)))
}
}
pub fn remove_portal(&self, name: &str) -> Result<bool> {
use crate::error::LockResultExt;
let mut portals = self.portals.write()
.map_lock_err("Failed to acquire write lock on portals")?;
Ok(portals.remove(name).is_some())
}
pub fn clear_all(&self) -> Result<()> {
use crate::error::LockResultExt;
let mut statements = self.statements.write()
.map_lock_err("Failed to acquire write lock on statements")?;
let mut portals = self.portals.write()
.map_lock_err("Failed to acquire write lock on portals")?;
let mut order = self.statement_order.write()
.map_lock_err("Failed to acquire write lock on statement order")?;
statements.clear();
portals.clear();
order.clear();
Ok(())
}
pub fn statement_count(&self) -> Result<usize> {
use crate::error::LockResultExt;
let statements = self.statements.read()
.map_lock_err("Failed to acquire read lock on statements")?;
Ok(statements.len())
}
pub fn portal_count(&self) -> Result<usize> {
use crate::error::LockResultExt;
let portals = self.portals.read()
.map_lock_err("Failed to acquire read lock on portals")?;
Ok(portals.len())
}
}
impl Default for PreparedStatementManager {
fn default() -> Self {
Self::new()
}
}
pub fn decode_parameter(
data: &[u8],
format: i16,
type_oid: i32,
) -> Result<Value> {
if format == 0 {
decode_text_parameter(data, type_oid)
} else {
decode_binary_parameter(data, type_oid)
}
}
fn decode_text_parameter(data: &[u8], type_oid: i32) -> Result<Value> {
let text = std::str::from_utf8(data)
.map_err(|e| Error::protocol(format!("Invalid UTF-8 in parameter: {}", e)))?;
match type_oid {
16 => {
let val = text == "t" || text == "true" || text == "1";
Ok(Value::Boolean(val))
}
21 => {
let val = text.parse::<i16>()
.map_err(|e| Error::protocol(format!("Invalid Int2 parameter: {}", e)))?;
Ok(Value::Int2(val))
}
23 => {
let val = text.parse::<i32>()
.map_err(|e| Error::protocol(format!("Invalid Int4 parameter: {}", e)))?;
Ok(Value::Int4(val))
}
20 => {
let val = text.parse::<i64>()
.map_err(|e| Error::protocol(format!("Invalid Int8 parameter: {}", e)))?;
Ok(Value::Int8(val))
}
700 => {
let val = text.parse::<f32>()
.map_err(|e| Error::protocol(format!("Invalid Float4 parameter: {}", e)))?;
Ok(Value::Float4(val))
}
701 => {
let val = text.parse::<f64>()
.map_err(|e| Error::protocol(format!("Invalid Float8 parameter: {}", e)))?;
Ok(Value::Float8(val))
}
25 | 1043 => {
Ok(Value::String(text.to_string()))
}
114 | 3802 => {
let _json: serde_json::Value = serde_json::from_str(text)
.map_err(|e| Error::protocol(format!("Invalid JSON parameter: {}", e)))?;
Ok(Value::Json(text.to_string()))
}
_ => {
Ok(Value::String(text.to_string()))
}
}
}
#[allow(clippy::indexing_slicing)]
fn decode_binary_parameter(data: &[u8], type_oid: i32) -> Result<Value> {
match type_oid {
16 => {
if data.is_empty() {
return Err(Error::protocol("Empty boolean parameter"));
}
Ok(Value::Boolean(data[0] != 0))
}
21 => {
if data.len() < 2 {
return Err(Error::protocol("Invalid Int2 parameter length"));
}
let val = i16::from_be_bytes([data[0], data[1]]);
Ok(Value::Int2(val))
}
23 => {
if data.len() < 4 {
return Err(Error::protocol("Invalid Int4 parameter length"));
}
let val = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
Ok(Value::Int4(val))
}
20 => {
if data.len() < 8 {
return Err(Error::protocol("Invalid Int8 parameter length"));
}
let bytes: [u8; 8] = data[0..8].try_into()
.map_err(|_| Error::protocol("Invalid Int8 parameter"))?;
let val = i64::from_be_bytes(bytes);
Ok(Value::Int8(val))
}
700 => {
if data.len() < 4 {
return Err(Error::protocol("Invalid Float4 parameter length"));
}
let bytes: [u8; 4] = data[0..4].try_into()
.map_err(|_| Error::protocol("Invalid Float4 parameter"))?;
let val = f32::from_be_bytes(bytes);
Ok(Value::Float4(val))
}
701 => {
if data.len() < 8 {
return Err(Error::protocol("Invalid Float8 parameter length"));
}
let bytes: [u8; 8] = data[0..8].try_into()
.map_err(|_| Error::protocol("Invalid Float8 parameter"))?;
let val = f64::from_be_bytes(bytes);
Ok(Value::Float8(val))
}
25 | 1043 => {
let text = std::str::from_utf8(data)
.map_err(|e| Error::protocol(format!("Invalid UTF-8 in text parameter: {}", e)))?;
Ok(Value::String(text.to_string()))
}
_ => {
Ok(Value::Bytes(data.to_vec()))
}
}
}
/// Inlines `$N` placeholders in `sql` with SQL literal renderings of `params`.
///
/// Placeholders are 1-based: `$1` is `params[0]`. A placeholder copied through
/// unchanged is one whose number is `0`, does not fit in `usize`, or exceeds
/// `params.len()`. `sql` is returned verbatim when `params` is empty.
///
/// Only placeholders in plain SQL text are replaced. `$N` is left untouched
/// inside:
/// * single-quoted strings, including `E'...'` escape strings;
/// * double-quoted identifiers;
/// * dollar-quoted strings (`$$...$$`, `$tag$...$tag$`);
/// * `--` line comments and nested `/* */` block comments;
/// * identifiers where `$` continues a name, such as `price$1`.
///
/// Plain single-quoted strings are scanned assuming
/// `standard_conforming_strings = on`, i.e. backslash escapes only inside
/// `E'...'`.
///
/// # Errors
///
/// Never fails at present; the `Result` is kept for API compatibility.
// Indices are bounds-checked by the loop condition or come from the scanning
// helpers below, which never return an offset greater than `sql.len()`.
#[allow(clippy::indexing_slicing)]
pub fn substitute_parameters(sql: &str, params: &[Value]) -> Result<String> {
    if params.is_empty() {
        return Ok(sql.to_string());
    }
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut result = String::with_capacity(sql.len() + params.len() * 8);
    // Start of the verbatim run not yet copied into `result`. Every split point
    // is an ASCII byte (a `$`, or the byte after a placeholder's last digit),
    // so slicing `sql` at these offsets always lands on a char boundary.
    let mut pending = 0;
    let mut i = 0;
    while i < len {
        match bytes[i] {
            b'\'' => {
                let backslash_escapes = is_escape_string_prefix(bytes, i);
                i = skip_quoted(bytes, i, b'\'', backslash_escapes);
            }
            b'"' => i = skip_quoted(bytes, i, b'"', false),
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                // Line comment: skip through the next newline, or to the end.
                i = bytes[i..]
                    .iter()
                    .position(|&b| b == b'\n')
                    .map_or(len, |offset| i + offset + 1);
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => i = skip_block_comment(bytes, i),
            // A `$` continuing an identifier (e.g. `price$1`) is part of the name.
            b'$' if i > 0 && continues_identifier(bytes[i - 1]) => i += 1,
            b'$' if matches!(bytes.get(i + 1), Some(d) if d.is_ascii_digit()) => {
                let start = i + 1;
                let end = start
                    + bytes[start..].iter().take_while(|b| b.is_ascii_digit()).count();
                let param = sql[start..end]
                    .parse::<usize>()
                    .ok()
                    .and_then(|n| n.checked_sub(1))
                    .and_then(|index| params.get(index));
                if let Some(value) = param {
                    result.push_str(&sql[pending..i]);
                    result.push_str(&value_to_sql_literal(value));
                    pending = end;
                }
                i = end;
            }
            b'$' => i = skip_dollar_quoted(bytes, i).unwrap_or(i + 1),
            _ => i += 1,
        }
    }
    result.push_str(&sql[pending..]);
    Ok(result)
}

/// Whether `b` may appear in an unquoted identifier or a dollar-quote tag.
fn is_identifier_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || b >= 0x80
}

/// Whether a `$` directly after `b` belongs to an identifier, rather than
/// starting a placeholder or a dollar quote.
fn continues_identifier(b: u8) -> bool {
    b.is_ascii_alphabetic() || b == b'_' || b >= 0x80
}

/// Whether the quote at offset `quote` opens an `E'...'` escape string, i.e.
/// it is directly preceded by a standalone `E` or `e`.
// Each index is guarded by the preceding `quote >= 1` / `quote < 2` test.
#[allow(clippy::indexing_slicing)]
fn is_escape_string_prefix(bytes: &[u8], quote: usize) -> bool {
    quote >= 1
        && matches!(bytes[quote - 1], b'E' | b'e')
        && (quote < 2 || !is_identifier_byte(bytes[quote - 2]))
}

/// Skips a quoted token opened at `open`, where a doubled `quote` is an
/// escaped quote. When `backslash_escapes` is set, `\` also escapes the
/// following byte. Returns the offset just past the closing quote, or
/// `bytes.len()` if the token is unterminated.
fn skip_quoted(bytes: &[u8], open: usize, quote: u8, backslash_escapes: bool) -> usize {
    let mut i = open + 1;
    while let Some(&b) = bytes.get(i) {
        if backslash_escapes && b == b'\\' {
            i += 2;
        } else if b == quote {
            if bytes.get(i + 1) != Some(&quote) {
                return i + 1;
            }
            i += 2;
        } else {
            i += 1;
        }
    }
    bytes.len()
}

/// Skips a (possibly nested) `/* ... */` comment opened at `open`. Returns the
/// offset just past its final `*/`, or `bytes.len()` if it is unterminated.
fn skip_block_comment(bytes: &[u8], open: usize) -> usize {
    let mut depth = 0usize;
    let mut i = open;
    loop {
        match (bytes.get(i), bytes.get(i + 1)) {
            (Some(&b'/'), Some(&b'*')) => {
                depth += 1;
                i += 2;
            }
            (Some(&b'*'), Some(&b'/')) => {
                depth = depth.saturating_sub(1);
                i += 2;
                if depth == 0 {
                    return i;
                }
            }
            (Some(_), _) => i += 1,
            (None, _) => return bytes.len(),
        }
    }
}

/// If the `$` at `open` begins a dollar-quote delimiter (`$$` or `$tag$`, where
/// `tag` does not start with a digit), returns the offset just past the
/// matching closing delimiter, or `bytes.len()` if it is never closed.
/// Returns `None` when no delimiter starts at `open`.
// Callers pass `open < bytes.len()`, and `close < bytes.len()` is checked
// before `bytes[open..=close]` and `bytes[close + 1..]` are taken.
#[allow(clippy::indexing_slicing)]
fn skip_dollar_quoted(bytes: &[u8], open: usize) -> Option<usize> {
    let tag_start = open + 1;
    if matches!(bytes.get(tag_start), Some(b) if b.is_ascii_digit()) {
        return None;
    }
    let tag_len = bytes[tag_start..]
        .iter()
        .take_while(|&&b| is_identifier_byte(b))
        .count();
    let close = tag_start + tag_len;
    if bytes.get(close) != Some(&b'$') {
        return None;
    }
    let delimiter = &bytes[open..=close];
    let body = close + 1;
    Some(
        bytes[body..]
            .windows(delimiter.len())
            .position(|window| window == delimiter)
            .map_or(bytes.len(), |offset| body + offset + delimiter.len()),
    )
}
/// Renders `value` as a SQL literal suitable for inlining into query text.
///
/// * Strings and JSON are single-quoted with embedded quotes doubled. This
///   assumes `standard_conforming_strings = on`.
/// * Negative numbers are wrapped in parentheses. Otherwise inlining after a
///   `-` (as in `10-$1` with `-5`) would form `--` and turn the rest of the
///   line into a comment.
/// * Non-finite floats become quoted casts (`'NaN'::float8`,
///   `'-Infinity'::float4`, ...), because bare `NaN`/`inf` are not valid SQL.
fn value_to_sql_literal(value: &Value) -> String {
    match value {
        Value::Null => "NULL".to_string(),
        Value::Boolean(b) => if *b { "TRUE" } else { "FALSE" }.to_string(),
        Value::Int2(i) => parenthesize_negative(i.to_string()),
        Value::Int4(i) => parenthesize_negative(i.to_string()),
        Value::Int8(i) => parenthesize_negative(i.to_string()),
        Value::Float4(f) if !f.is_finite() => {
            non_finite_float_literal(f.is_nan(), f.is_sign_negative(), "float4")
        }
        Value::Float4(f) => parenthesize_negative(f.to_string()),
        Value::Float8(f) if !f.is_finite() => {
            non_finite_float_literal(f.is_nan(), f.is_sign_negative(), "float8")
        }
        Value::Float8(f) => parenthesize_negative(f.to_string()),
        Value::String(s) => format!("'{}'", s.replace('\'', "''")),
        Value::Json(j) => format!("'{}'::jsonb", j.to_string().replace('\'', "''")),
        Value::Timestamp(ts) => format!("'{}'::timestamp", ts.to_rfc3339()),
        Value::Vector(v) => {
            // NOTE(review): if the elements are floats, NaN/infinite elements
            // render as `NaN`/`inf`, which are not valid SQL — confirm whether
            // such elements can occur.
            let arr_str = v.iter()
                .map(|x| x.to_string())
                .collect::<Vec<_>>()
                .join(",");
            format!("ARRAY[{}]", arr_str)
        }
        _ => format!("'{}'", value.to_string().replace('\'', "''")),
    }
}

/// Wraps numeric text that starts with `-` in parentheses, so it cannot merge
/// with a preceding `-` into a `--` line comment.
fn parenthesize_negative(text: String) -> String {
    if text.starts_with('-') {
        format!("({})", text)
    } else {
        text
    }
}

/// Quoted, cast literal for a non-finite float, e.g. `'-Infinity'::float8`.
fn non_finite_float_literal(is_nan: bool, is_negative: bool, type_name: &str) -> String {
    let text = if is_nan {
        "NaN"
    } else if is_negative {
        "-Infinity"
    } else {
        "Infinity"
    };
    format!("'{}'::{}", text, type_name)
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Builds a parameterless `SELECT 1` statement named `name`.
    fn select_one(name: &str) -> PreparedStatement {
        PreparedStatement {
            name: name.to_string(),
            query: "SELECT 1".to_string(),
            param_types: vec![],
            result_schema: None,
            cached_plan: None,
        }
    }

    /// Builds a `Ready` portal named `name` with no parameters.
    fn ready_portal(name: &str) -> Portal {
        Portal {
            name: name.to_string(),
            statement_name: "test_stmt".to_string(),
            params: vec![],
            param_formats: vec![],
            result_formats: vec![],
            state: PortalState::Ready,
        }
    }

    #[test]
    fn test_statement_manager() {
        let manager = PreparedStatementManager::new();
        let stmt = PreparedStatement {
            name: "test_stmt".to_string(),
            query: "SELECT * FROM users WHERE id = $1".to_string(),
            param_types: vec![23],
            result_schema: None,
            cached_plan: None,
        };
        manager.store_statement(stmt.clone()).unwrap();
        assert_eq!(manager.statement_count().unwrap(), 1);
        let retrieved = manager.get_statement("test_stmt").unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.as_ref().unwrap().query, stmt.query);
        manager.remove_statement("test_stmt").unwrap();
        assert_eq!(manager.statement_count().unwrap(), 0);
    }

    #[test]
    fn test_remove_missing_statement_returns_false() {
        let manager = PreparedStatementManager::new();
        assert!(!manager.remove_statement("missing").unwrap());
    }

    #[test]
    fn test_portal_manager() {
        let manager = PreparedStatementManager::new();
        let portal = Portal {
            name: "test_portal".to_string(),
            statement_name: "test_stmt".to_string(),
            params: vec![Some(b"123".to_vec())],
            param_formats: vec![0],
            result_formats: vec![0],
            state: PortalState::Ready,
        };
        manager.store_portal(portal.clone()).unwrap();
        assert_eq!(manager.portal_count().unwrap(), 1);
        let retrieved = manager.get_portal("test_portal").unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.as_ref().unwrap().statement_name, portal.statement_name);
    }

    #[test]
    fn test_portal_capacity_limit() {
        let manager = PreparedStatementManager::with_capacity(2, 1);
        manager.store_portal(ready_portal("p1")).unwrap();
        assert!(
            manager.store_portal(ready_portal("p2")).is_err(),
            "a new portal beyond the limit must be rejected"
        );
        assert!(
            manager.store_portal(ready_portal("p1")).is_ok(),
            "replacing an existing portal is allowed at the limit"
        );
        assert_eq!(manager.portal_count().unwrap(), 1);
    }

    #[test]
    fn test_update_portal_state() {
        let manager = PreparedStatementManager::new();
        manager.store_portal(ready_portal("p")).unwrap();
        manager.update_portal_state("p", PortalState::Complete).unwrap();
        assert_eq!(manager.get_portal("p").unwrap().unwrap().state, PortalState::Complete);
        assert!(manager.update_portal_state("missing", PortalState::Complete).is_err());
    }

    #[test]
    fn test_clear_all() {
        let manager = PreparedStatementManager::new();
        manager.store_statement(select_one("s")).unwrap();
        manager.store_portal(ready_portal("p")).unwrap();
        manager.clear_all().unwrap();
        assert_eq!(manager.statement_count().unwrap(), 0);
        assert_eq!(manager.portal_count().unwrap(), 0);
    }

    #[test]
    fn test_decode_text_parameter() {
        let val = decode_text_parameter(b"123", 23).unwrap();
        assert_eq!(val, Value::Int4(123));
        let val = decode_text_parameter(b"hello", 25).unwrap();
        assert_eq!(val, Value::String("hello".to_string()));
        let val = decode_text_parameter(b"t", 16).unwrap();
        assert_eq!(val, Value::Boolean(true));
    }

    #[test]
    fn test_decode_binary_parameter() {
        let data = 123i32.to_be_bytes();
        let val = decode_binary_parameter(&data, 23).unwrap();
        assert_eq!(val, Value::Int4(123));
        let val = decode_binary_parameter(&[1], 16).unwrap();
        assert_eq!(val, Value::Boolean(true));
    }

    #[test]
    fn test_substitute_parameters() {
        let sql = "SELECT * FROM users WHERE id = $1 AND name = $2";
        let params = vec![
            Value::Int4(123),
            Value::String("Alice".to_string()),
        ];
        let result = substitute_parameters(sql, &params).unwrap();
        assert_eq!(result, "SELECT * FROM users WHERE id = 123 AND name = 'Alice'");
    }

    #[test]
    fn test_substitute_parameters_10_plus() {
        let sql = "INSERT INTO t VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11)";
        let params: Vec<Value> = (1..=11).map(Value::Int4).collect();
        let result = substitute_parameters(sql, &params).unwrap();
        assert_eq!(result, "INSERT INTO t VALUES (1,2,3,4,5,6,7,8,9,10,11)");
    }

    #[test]
    fn test_substitute_parameters_with_cast() {
        let sql = "SELECT $1::text, $2::int";
        let params = vec![
            Value::String("hello".to_string()),
            Value::Int4(42),
        ];
        let result = substitute_parameters(sql, &params).unwrap();
        assert_eq!(result, "SELECT 'hello'::text, 42::int");
    }

    #[test]
    fn test_substitute_parameters_out_of_range_kept() {
        let params = vec![Value::Int4(7)];
        let result = substitute_parameters("SELECT $1, $2, $0", &params).unwrap();
        assert_eq!(result, "SELECT 7, $2, $0");
    }

    #[test]
    fn test_capacity_limits() {
        let manager = PreparedStatementManager::with_capacity(2, 2);
        for i in 0..2 {
            manager.store_statement(select_one(&format!("stmt{}", i))).unwrap();
        }
        let result = manager.store_statement(select_one("stmt3"));
        assert!(result.is_ok(), "oldest-first eviction should allow new statement");
        assert!(manager.get_statement("stmt0").unwrap().is_none(), "stmt0 should have been evicted");
        assert!(manager.get_statement("stmt1").unwrap().is_some(), "stmt1 should still exist");
        assert!(manager.get_statement("stmt3").unwrap().is_some(), "stmt3 should exist");
    }

    #[test]
    fn test_restoring_statement_refreshes_eviction_order() {
        let manager = PreparedStatementManager::with_capacity(2, 2);
        manager.store_statement(select_one("stmt0")).unwrap();
        manager.store_statement(select_one("stmt1")).unwrap();
        // Re-storing stmt0 makes it the newest, so stmt1 becomes the victim.
        manager.store_statement(select_one("stmt0")).unwrap();
        manager.store_statement(select_one("stmt2")).unwrap();
        assert!(manager.get_statement("stmt0").unwrap().is_some());
        assert!(manager.get_statement("stmt1").unwrap().is_none());
        assert!(manager.get_statement("stmt2").unwrap().is_some());
        assert_eq!(manager.statement_count().unwrap(), 2);
    }

    #[test]
    fn test_unnamed_statement_evicted_last() {
        let manager = PreparedStatementManager::with_capacity(2, 2);
        manager.store_statement(select_one("")).unwrap();
        manager.store_statement(select_one("named")).unwrap();
        manager.store_statement(select_one("other")).unwrap();
        assert!(manager.get_statement("").unwrap().is_some(), "unnamed statement is passed over");
        assert!(manager.get_statement("named").unwrap().is_none());
        assert!(manager.get_statement("other").unwrap().is_some());
    }
}