use std::collections::HashMap;
/// A prepared statement recorded by [`PreparedStatementTracker`].
#[derive(Debug, Clone)]
pub struct PreparedStatement {
/// Statement name; the tracker also stores the statement under this name.
pub name: String,
/// SQL text of the statement body (the part after `AS`).
pub query: String,
/// Parameter type OIDs, rendered to SQL type names by `generate_prepare_sql`.
/// An empty list means no explicit parameter types were recorded.
pub param_types: Vec<u32>,
/// Wall-clock time at which `register` recorded the statement (`chrono::Utc::now()`).
pub prepared_at: chrono::DateTime<chrono::Utc>,
/// Number of times `record_execution` has been called for this statement; starts at 0.
pub execution_count: u64,
}
/// Tracks the named prepared statements that are currently active, with a
/// bounded capacity and lifetime counters.
///
/// When the tracker is full, registering a new name evicts the statement that
/// was registered earliest. Eviction is first-in, first-out: executing a
/// statement does not refresh its position.
#[derive(Debug)]
pub struct PreparedStatementTracker {
    /// Active statements keyed by name. Each is tagged with the sequence number
    /// of its registration. The sequence orders eviction and replay. It is used
    /// instead of `prepared_at` because wall-clock timestamps can tie or move
    /// backwards between registrations.
    statements: HashMap<String, (u64, PreparedStatement)>,
    /// Number of statements kept before the oldest is evicted.
    max_statements: usize,
    /// Sequence number assigned to the next registration.
    next_seq: u64,
    /// Registrations accepted so far, including ones that replaced a name.
    total_prepared: u64,
    /// Statements removed so far by eviction, replacement, `unregister`, or `clear`.
    total_deallocated: u64,
}

impl Default for PreparedStatementTracker {
    /// Equivalent to [`PreparedStatementTracker::new`].
    ///
    /// A derived `Default` would set the limit to 0, which degenerates into a
    /// tracker that only ever keeps the most recent statement.
    fn default() -> Self {
        Self::new()
    }
}

impl PreparedStatementTracker {
    /// Statement limit used by [`PreparedStatementTracker::new`].
    const DEFAULT_MAX_STATEMENTS: usize = 1000;
    /// Cap on the map capacity reserved up front, so a large limit does not
    /// allocate eagerly.
    const INITIAL_RESERVE: usize = 100;

    /// Creates a tracker that keeps at most 1000 statements.
    pub fn new() -> Self {
        Self::with_capacity(Self::DEFAULT_MAX_STATEMENTS)
    }

    /// Creates a tracker that keeps at most `max_statements` statements.
    ///
    /// A limit of 0 behaves like 1: the most recently registered statement is
    /// always kept.
    pub fn with_capacity(max_statements: usize) -> Self {
        Self {
            statements: HashMap::with_capacity(max_statements.min(Self::INITIAL_RESERVE)),
            max_statements,
            next_seq: 0,
            total_prepared: 0,
            total_deallocated: 0,
        }
    }

    /// Records a newly prepared statement.
    ///
    /// Unnamed statements (empty `name`) are ignored. Re-registering an
    /// existing name replaces that entry, and the replaced statement counts as
    /// deallocated. If the tracker is full and `name` is new, the earliest
    /// registered statement is evicted first (also counted as deallocated).
    pub fn register(&mut self, name: String, query: String, param_types: Vec<u32>) {
        if name.is_empty() {
            return;
        }
        // Replacing an existing name does not grow the map, so only evict when
        // a genuinely new name would exceed the limit.
        if !self.statements.contains_key(&name) && self.statements.len() >= self.max_statements {
            self.evict_oldest();
        }

        let seq = self.next_seq;
        self.next_seq += 1;
        let stmt = PreparedStatement {
            name: name.clone(),
            query,
            param_types,
            prepared_at: chrono::Utc::now(),
            execution_count: 0,
        };
        if self.statements.insert(name, (seq, stmt)).is_some() {
            // The previous statement with this name is gone; keep the counters
            // consistent with the number of active statements.
            self.total_deallocated += 1;
        }
        self.total_prepared += 1;
    }

    /// Removes the statement with the lowest registration sequence number, if any.
    fn evict_oldest(&mut self) {
        let oldest = self
            .statements
            .iter()
            .min_by_key(|(_, entry)| entry.0)
            .map(|(name, _)| name.clone());
        if let Some(name) = oldest {
            self.statements.remove(&name);
            self.total_deallocated += 1;
        }
    }

    /// Removes and returns the statement called `name`, if it is tracked.
    pub fn unregister(&mut self, name: &str) -> Option<PreparedStatement> {
        let (_, stmt) = self.statements.remove(name)?;
        self.total_deallocated += 1;
        Some(stmt)
    }

    /// Removes every tracked statement, counting each one as deallocated.
    pub fn clear(&mut self) {
        self.total_deallocated += self.statements.len() as u64;
        self.statements.clear();
    }

    /// Returns the statement called `name`, if it is tracked.
    pub fn get(&self, name: &str) -> Option<&PreparedStatement> {
        self.statements.get(name).map(|(_, stmt)| stmt)
    }

    /// Increments the execution count of `name`; unknown names are ignored.
    pub fn record_execution(&mut self, name: &str) {
        if let Some((_, stmt)) = self.statements.get_mut(name) {
            stmt.execution_count += 1;
        }
    }

    /// Returns `true` if a statement called `name` is tracked.
    pub fn contains(&self, name: &str) -> bool {
        self.statements.contains_key(name)
    }

    /// Iterates over all tracked statements in arbitrary order.
    pub fn all_statements(&self) -> impl Iterator<Item = &PreparedStatement> {
        self.statements.values().map(|(_, stmt)| stmt)
    }

    /// Number of tracked statements.
    pub fn len(&self) -> usize {
        self.statements.len()
    }

    /// Returns `true` if no statements are tracked.
    pub fn is_empty(&self) -> bool {
        self.statements.is_empty()
    }

    /// Renders every tracked statement as a `PREPARE` command, in the order the
    /// statements were registered.
    ///
    /// The statement name and query are inserted verbatim. Names that are not
    /// valid bare identifiers must already carry their own quoting.
    pub fn generate_prepare_sql(&self) -> Vec<String> {
        let mut entries: Vec<&(u64, PreparedStatement)> = self.statements.values().collect();
        entries.sort_unstable_by_key(|entry| entry.0);
        entries
            .into_iter()
            .map(|(_, stmt)| Self::render_prepare(stmt))
            .collect()
    }

    /// Renders a single statement as `PREPARE name [(types)] AS query`.
    fn render_prepare(stmt: &PreparedStatement) -> String {
        if stmt.param_types.is_empty() {
            format!("PREPARE {} AS {}", stmt.name, stmt.query)
        } else {
            let types: Vec<String> = stmt
                .param_types
                .iter()
                .map(|&oid| oid_to_type_name(oid))
                .collect();
            format!(
                "PREPARE {} ({}) AS {}",
                stmt.name,
                types.join(", "),
                stmt.query
            )
        }
    }

    /// Returns a snapshot of the tracker's counters.
    ///
    /// `active_statements` always equals `total_prepared - total_deallocated`.
    pub fn stats(&self) -> TrackerStats {
        TrackerStats {
            active_statements: self.statements.len(),
            total_prepared: self.total_prepared,
            total_deallocated: self.total_deallocated,
            max_capacity: self.max_statements,
        }
    }
}
/// Point-in-time counters reported by `PreparedStatementTracker::stats`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrackerStats {
    /// Number of statements currently tracked.
    pub active_statements: usize,
    /// Total number of statements registered over the tracker's lifetime.
    pub total_prepared: u64,
    /// Total number of statements removed over the tracker's lifetime.
    pub total_deallocated: u64,
    /// Configured maximum number of tracked statements.
    pub max_capacity: usize,
}
/// Maps a PostgreSQL type OID to a type name usable in a `PREPARE` parameter list.
///
/// OID 0 (unspecified) and OIDs not in the table map to `unknown`. PostgreSQL
/// accepts `unknown` in `PREPARE` and infers the type from the statement,
/// whereas emitting the raw OID would produce invalid SQL.
///
/// OID 18 is the single-byte internal type and must be written quoted
/// (`"char"`); an unquoted `char` names `character` (OID 1042) instead.
fn oid_to_type_name(oid: u32) -> String {
    let name = match oid {
        16 => "boolean",
        17 => "bytea",
        18 => "\"char\"",
        19 => "name",
        20 => "bigint",
        21 => "smallint",
        23 => "integer",
        25 => "text",
        26 => "oid",
        114 => "json",
        142 => "xml",
        650 => "cidr",
        700 => "real",
        701 => "double precision",
        790 => "money",
        829 => "macaddr",
        869 => "inet",
        1000 => "boolean[]",
        1001 => "bytea[]",
        1005 => "smallint[]",
        1007 => "integer[]",
        1009 => "text[]",
        1015 => "varchar[]",
        1016 => "bigint[]",
        1021 => "real[]",
        1022 => "double precision[]",
        1042 => "char",
        1043 => "varchar",
        1082 => "date",
        1083 => "time",
        1114 => "timestamp",
        1184 => "timestamptz",
        1186 => "interval",
        1266 => "timetz",
        1700 => "numeric",
        2950 => "uuid",
        2951 => "uuid[]",
        3802 => "jsonb",
        3807 => "jsonb[]",
        // Includes 0 ("unspecified") and 705 (`unknown` itself): let the server infer.
        _ => "unknown",
    };
    name.to_string()
}
/// Parses a SQL-level `PREPARE name [ ( type [, ...] ) ] AS statement` command.
///
/// Keywords are matched case-insensitively and may be separated by any
/// whitespace, including newlines. The statement name is returned verbatim;
/// a double-quoted name keeps its quotes, which matches what
/// [`parse_deallocate_statement`] returns for the same name. Parameter types
/// are returned as written and are split only on top-level commas, so types
/// such as `numeric(10,2)` stay intact. Trailing semicolons are stripped from
/// the query.
///
/// Returns `None` in these cases:
/// - `sql` is not a well-formed `PREPARE ... AS ...` command (including
///   `PREPARE TRANSACTION`);
/// - the name or the query is empty;
/// - the parameter list is unbalanced.
pub fn parse_prepare_statement(sql: &str) -> Option<(String, Vec<String>, String)> {
    let rest = strip_keyword(sql.trim(), "PREPARE")?;
    let (name, rest) = split_statement_name(rest)?;
    let (param_types, rest) = match rest.strip_prefix('(') {
        Some(inner) => {
            let (types, after) = split_param_list(inner)?;
            (types, after.trim_start())
        }
        None => (Vec::new(), rest),
    };
    let query = strip_keyword(rest, "AS")?
        .trim_end()
        .trim_end_matches(';')
        .trim_end();
    if query.is_empty() {
        return None;
    }
    Some((name.to_string(), param_types, query.to_string()))
}

/// Parses a SQL-level `DEALLOCATE [ PREPARE ] { name | ALL }` command.
///
/// Returns:
/// - `None` if `sql` is not a `DEALLOCATE` command or names nothing;
/// - `Some(None)` for `DEALLOCATE [PREPARE] ALL`;
/// - `Some(Some(name))` for a single statement.
///
/// The name is returned verbatim (quotes included) with trailing semicolons
/// removed.
pub fn parse_deallocate_statement(sql: &str) -> Option<Option<String>> {
    let rest = strip_keyword(sql.trim(), "DEALLOCATE")?;
    let rest = rest.trim_end_matches(';').trim_end();
    // `PREPARE` is an optional noise word. A bare `PREPARE` with nothing after
    // it is kept as the statement name, since it is a valid identifier.
    let target = strip_keyword(rest, "PREPARE").unwrap_or(rest);
    let first_word = target
        .split(|c: char| c.is_whitespace() || c == ';')
        .next()
        .unwrap_or("");
    if first_word.eq_ignore_ascii_case("ALL") {
        return Some(None);
    }
    if target.is_empty() {
        return None;
    }
    Some(Some(target.to_string()))
}

/// If `s` starts with `keyword` (ASCII, case-insensitive) followed by
/// whitespace, returns the text after it with leading whitespace removed.
fn strip_keyword<'a>(s: &'a str, keyword: &str) -> Option<&'a str> {
    // `get` fails cleanly if the keyword length splits a multi-byte character.
    let head = s.get(..keyword.len())?;
    let tail = &s[keyword.len()..];
    if head.eq_ignore_ascii_case(keyword) && tail.starts_with(char::is_whitespace) {
        Some(tail.trim_start())
    } else {
        None
    }
}

/// Splits a leading statement name off `s`.
///
/// The name is either a bare word ending at whitespace or `(`, or a
/// double-quoted identifier with `""` escapes. Returns the name and the
/// remainder with leading whitespace removed, or `None` for an empty or
/// unterminated name.
fn split_statement_name(s: &str) -> Option<(&str, &str)> {
    let end = if s.starts_with('"') {
        let bytes = s.as_bytes();
        let mut i = 1;
        loop {
            match bytes.get(i).copied() {
                None => return None,
                Some(b'"') if bytes.get(i + 1).copied() == Some(b'"') => i += 2,
                Some(b'"') => break i + 1,
                Some(_) => i += 1,
            }
        }
    } else {
        s.find(|c: char| c.is_whitespace() || c == '(')
            .unwrap_or(s.len())
    };
    if end == 0 {
        return None;
    }
    Some((&s[..end], s[end..].trim_start()))
}

/// Parses a parameter-type list.
///
/// `s` is the text just after the opening `(`. Splits on commas that are not
/// nested inside parentheses. Returns the trimmed, non-empty type names and
/// the text after the matching `)`, or `None` if the list is never closed.
fn split_param_list(s: &str) -> Option<(Vec<String>, &str)> {
    let mut depth = 0usize;
    let mut start = 0;
    let mut pieces: Vec<&str> = Vec::new();
    for (i, c) in s.char_indices() {
        match c {
            '(' => depth += 1,
            ')' if depth > 0 => depth -= 1,
            ')' => {
                pieces.push(&s[start..i]);
                let types = pieces
                    .into_iter()
                    .map(str::trim)
                    .filter(|t| !t.is_empty())
                    .map(str::to_string)
                    .collect();
                return Some((types, &s[i + 1..]));
            }
            ',' if depth == 0 => {
                pieces.push(&s[start..i]);
                start = i + 1;
            }
            _ => {}
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers a parameterless statement called `name` whose query is `sql`.
    fn register_plain(tracker: &mut PreparedStatementTracker, name: &str, sql: &str) {
        tracker.register(name.to_owned(), sql.to_owned(), Vec::new());
    }

    #[test]
    fn test_register_and_get() {
        const QUERY: &str = "SELECT * FROM users WHERE id = $1";
        let mut tracker = PreparedStatementTracker::new();
        tracker.register(String::from("stmt1"), String::from(QUERY), vec![23]);

        assert!(tracker.contains("stmt1"));
        let stmt = tracker.get("stmt1").expect("stmt1 should be registered");
        assert_eq!(stmt.query, QUERY);
        assert_eq!(stmt.param_types, vec![23]);
    }

    #[test]
    fn test_unregister() {
        let mut tracker = PreparedStatementTracker::new();
        register_plain(&mut tracker, "stmt1", "SELECT 1");
        assert!(tracker.contains("stmt1"));

        let _ = tracker.unregister("stmt1");
        assert!(!tracker.contains("stmt1"));
    }

    #[test]
    fn test_clear() {
        let mut tracker = PreparedStatementTracker::new();
        for (name, sql) in [("stmt1", "SELECT 1"), ("stmt2", "SELECT 2")].iter() {
            register_plain(&mut tracker, name, sql);
        }
        assert_eq!(tracker.len(), 2);

        tracker.clear();
        assert!(tracker.is_empty());
    }

    #[test]
    fn test_capacity_limit() {
        let mut tracker = PreparedStatementTracker::with_capacity(3);
        for n in 1..=4 {
            register_plain(&mut tracker, &format!("stmt{}", n), &format!("SELECT {}", n));
        }
        assert_eq!(tracker.len(), 3);
        assert!(tracker.contains("stmt4"));
    }

    #[test]
    fn test_generate_prepare_sql() {
        let mut tracker = PreparedStatementTracker::new();
        tracker.register(
            "get_user".into(),
            "SELECT * FROM users WHERE id = $1".into(),
            vec![23],
        );

        let sqls = tracker.generate_prepare_sql();
        assert_eq!(sqls.len(), 1);
        let only = &sqls[0];
        assert!(only.contains("PREPARE get_user"));
        assert!(only.contains("integer"));
    }

    #[test]
    fn test_parse_prepare_statement() {
        let (name, params, query) = parse_prepare_statement("PREPARE stmt1 AS SELECT 1")
            .expect("simple PREPARE should parse");
        assert_eq!(name, "stmt1");
        assert!(params.is_empty());
        assert_eq!(query, "SELECT 1");

        let (name, params, query) = parse_prepare_statement(
            "PREPARE stmt2 (integer, text) AS SELECT * FROM t WHERE id = $1 AND name = $2",
        )
        .expect("PREPARE with parameter types should parse");
        assert_eq!(name, "stmt2");
        assert_eq!(params, vec!["integer", "text"]);
        assert!(query.starts_with("SELECT"));
    }

    #[test]
    fn test_parse_deallocate_statement() {
        let cases: [(&str, Option<Option<&str>>); 4] = [
            ("DEALLOCATE ALL", Some(None)),
            ("DEALLOCATE stmt1", Some(Some("stmt1"))),
            ("DEALLOCATE PREPARE stmt2", Some(Some("stmt2"))),
            ("SELECT 1", None),
        ];
        for &(sql, expected) in cases.iter() {
            let expected = expected.map(|target| target.map(String::from));
            assert_eq!(parse_deallocate_statement(sql), expected, "input: {}", sql);
        }
    }

    #[test]
    fn test_execution_tracking() {
        let mut tracker = PreparedStatementTracker::new();
        register_plain(&mut tracker, "stmt1", "SELECT 1");
        for _ in 0..2 {
            tracker.record_execution("stmt1");
        }

        let stmt = tracker.get("stmt1").expect("stmt1 should be registered");
        assert_eq!(stmt.execution_count, 2);
    }

    #[test]
    fn test_unnamed_statements_ignored() {
        let mut tracker = PreparedStatementTracker::new();
        register_plain(&mut tracker, "", "SELECT 1");
        assert!(tracker.is_empty());
    }
}