use std::collections::HashSet;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use manifoldb_query::ast::Statement;
use manifoldb_query::plan::{LogicalPlan, PhysicalPlan, PhysicalPlanner, PlanBuilder};
use manifoldb_query::ExtendedParser;
use crate::error::{Error, Result};
/// A SQL statement that has been parsed and planned once, bundled with
/// metadata derived from its logical plan.
///
/// Instances are built by `PreparedStatement::new`; all fields are private and
/// exposed through read-only accessors.
#[derive(Debug)]
pub struct PreparedStatement {
/// SQL text exactly as it was passed to `PreparedStatement::new`.
sql: String,
/// Statement produced by `ExtendedParser::parse_single` for `sql`.
ast: Statement,
/// Logical plan built from `ast` by `PlanBuilder::build_statement`.
logical_plan: LogicalPlan,
/// Physical plan produced from `logical_plan` by `PhysicalPlanner::plan`.
physical_plan: PhysicalPlan,
/// Schema version supplied at preparation time; `is_valid` compares it
/// against the caller's current version.
schema_version: u64,
/// Table names collected from `logical_plan` when the statement was prepared.
accessed_tables: HashSet<String>,
/// `true` when the root of `logical_plan` is an `Insert`, `Update` or `Delete`.
is_dml: bool,
/// `true` when the root of `logical_plan` is a `CreateTable`, `DropTable`,
/// `CreateIndex` or `DropIndex`.
is_ddl: bool,
}
impl PreparedStatement {
pub fn new(sql: &str, schema_version: u64) -> Result<Self> {
let ast = ExtendedParser::parse_single(sql)?;
let mut builder = PlanBuilder::new();
let logical_plan =
builder.build_statement(&ast).map_err(|e| Error::Parse(e.to_string()))?;
let planner = PhysicalPlanner::new();
let physical_plan = planner.plan(&logical_plan);
let accessed_tables = Self::extract_tables(&logical_plan);
let (is_dml, is_ddl) = Self::classify_statement(&logical_plan);
Ok(Self {
sql: sql.to_string(),
ast,
logical_plan,
physical_plan,
schema_version,
accessed_tables,
is_dml,
is_ddl,
})
}
#[must_use]
pub fn sql(&self) -> &str {
&self.sql
}
#[must_use]
pub fn schema_version(&self) -> u64 {
self.schema_version
}
#[must_use]
pub fn ast(&self) -> &Statement {
&self.ast
}
#[must_use]
pub fn logical_plan(&self) -> &LogicalPlan {
&self.logical_plan
}
#[must_use]
pub fn physical_plan(&self) -> &PhysicalPlan {
&self.physical_plan
}
#[must_use]
pub fn accessed_tables(&self) -> &HashSet<String> {
&self.accessed_tables
}
#[must_use]
pub fn is_dml(&self) -> bool {
self.is_dml
}
#[must_use]
pub fn is_ddl(&self) -> bool {
self.is_ddl
}
#[must_use]
pub fn is_query(&self) -> bool {
!self.is_dml && !self.is_ddl
}
#[must_use]
pub fn is_valid(&self, current_schema_version: u64) -> bool {
self.schema_version == current_schema_version
}
fn extract_tables(plan: &LogicalPlan) -> HashSet<String> {
let mut tables = HashSet::new();
Self::extract_tables_recursive(plan, &mut tables);
tables
}
fn extract_tables_recursive(plan: &LogicalPlan, tables: &mut HashSet<String>) {
match plan {
LogicalPlan::Scan(scan_node) => {
tables.insert(scan_node.table_name.clone());
}
LogicalPlan::Insert { table, input, .. } => {
tables.insert(table.clone());
Self::extract_tables_recursive(input, tables);
}
LogicalPlan::Update { table, .. } => {
tables.insert(table.clone());
}
LogicalPlan::Delete { table, .. } => {
tables.insert(table.clone());
}
LogicalPlan::CreateTable(node) => {
tables.insert(node.name.clone());
}
LogicalPlan::DropTable(node) => {
for name in &node.names {
tables.insert(name.clone());
}
}
LogicalPlan::CreateIndex(node) => {
tables.insert(node.table.clone());
}
LogicalPlan::DropIndex(_) => {
}
_ => {
for child in plan.children() {
Self::extract_tables_recursive(child, tables);
}
}
}
}
fn classify_statement(plan: &LogicalPlan) -> (bool, bool) {
match plan {
LogicalPlan::Insert { .. }
| LogicalPlan::Update { .. }
| LogicalPlan::Delete { .. } => (true, false),
LogicalPlan::CreateTable(_)
| LogicalPlan::DropTable(_)
| LogicalPlan::CreateIndex(_)
| LogicalPlan::DropIndex(_) => (false, true),
_ => (false, false),
}
}
}
/// A cache of prepared statements keyed by SQL text, with hit/miss counters
/// and a schema version used to detect stale entries.
#[derive(Debug)]
pub struct PreparedStatementCache {
/// Cached statements, keyed by the SQL text they were prepared from.
statements: RwLock<std::collections::HashMap<String, Arc<PreparedStatement>>>,
/// Entry count at which `get_or_prepare` evicts entries before inserting.
max_size: usize,
/// Current schema version; new statements are stamped with it and cached
/// ones are only returned when their version matches.
schema_version: AtomicU64,
/// Number of `get_or_prepare` calls answered from the cache.
hits: AtomicU64,
/// Number of `get_or_prepare` calls that had to prepare a statement.
misses: AtomicU64,
/// Number of `set_schema_version` calls that actually changed the version.
invalidations: AtomicU64,
}
/// Map from SQL text to its shared prepared statement, as stored by the cache.
type StatementMap = std::collections::HashMap<String, Arc<PreparedStatement>>;

impl PreparedStatementCache {
    /// Creates an empty cache that holds at most `max_size` statements.
    ///
    /// The schema version starts at 0 and all counters start at 0. A
    /// `max_size` of 0 disables caching: statements are still prepared by
    /// [`get_or_prepare`](Self::get_or_prepare) but never stored.
    #[must_use]
    pub fn new(max_size: usize) -> Self {
        Self {
            statements: RwLock::new(std::collections::HashMap::new()),
            max_size,
            schema_version: AtomicU64::new(0),
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
            invalidations: AtomicU64::new(0),
        }
    }

    /// Creates a cache with room for 1000 statements.
    #[must_use]
    pub fn default_cache() -> Self {
        Self::new(1000)
    }

    /// Returns the cached statement for `sql`, preparing and caching it on a miss.
    ///
    /// A cached entry is only returned when it was prepared under the current
    /// schema version; otherwise it is re-prepared and replaced. Parsing and
    /// planning happen without holding any lock.
    ///
    /// # Errors
    ///
    /// Returns an error when `sql` cannot be prepared or when the cache lock
    /// is poisoned.
    pub fn get_or_prepare(&self, sql: &str) -> Result<Arc<PreparedStatement>> {
        let schema_version = self.schema_version.load(Ordering::Acquire);

        // Fast path: shared lock only.
        {
            let cache = self.read_guard()?;
            if let Some(cached) = cache.get(sql) {
                if cached.is_valid(schema_version) {
                    self.hits.fetch_add(1, Ordering::Relaxed);
                    return Ok(Arc::clone(cached));
                }
            }
        }

        self.misses.fetch_add(1, Ordering::Relaxed);
        let prepared = Arc::new(PreparedStatement::new(sql, schema_version)?);

        // A zero-capacity cache can never store anything without exceeding
        // its bound, so hand the statement back uncached.
        if self.max_size == 0 {
            return Ok(prepared);
        }

        let mut cache = self.write_guard()?;

        // Another thread may have prepared the same SQL while this one was
        // planning. Reuse its entry so all callers share one `Arc`.
        let replaces_existing = match cache.get(sql) {
            Some(existing) if existing.is_valid(schema_version) => {
                return Ok(Arc::clone(existing));
            }
            Some(_) => true,
            None => false,
        };

        // Overwriting a stale entry under the same key does not grow the map,
        // so eviction is only needed when a new key is added.
        if !replaces_existing && cache.len() >= self.max_size {
            self.make_room(&mut cache);
        }

        cache.insert(sql.to_string(), Arc::clone(&prepared));
        Ok(prepared)
    }

    /// Prepares `sql` under the current schema version without consulting or
    /// updating the cache or its counters.
    ///
    /// # Errors
    ///
    /// Returns an error when `sql` cannot be prepared.
    pub fn prepare(&self, sql: &str) -> Result<Arc<PreparedStatement>> {
        let schema_version = self.schema_version.load(Ordering::Acquire);
        Ok(Arc::new(PreparedStatement::new(sql, schema_version)?))
    }

    /// Sets the current schema version.
    ///
    /// The invalidation counter is incremented only when `version` differs
    /// from the previous value. Cached entries are not removed here; entries
    /// prepared under another version are rejected on lookup.
    pub fn set_schema_version(&self, version: u64) {
        let old_version = self.schema_version.swap(version, Ordering::Release);
        if old_version != version {
            self.invalidations.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Returns the current schema version.
    #[must_use]
    pub fn schema_version(&self) -> u64 {
        self.schema_version.load(Ordering::Acquire)
    }

    /// Removes every cached statement; counters and schema version are untouched.
    ///
    /// # Errors
    ///
    /// Returns an error when the cache lock is poisoned.
    pub fn clear(&self) -> Result<()> {
        self.write_guard()?.clear();
        Ok(())
    }

    /// Removes every cached statement that accesses at least one of `tables`.
    ///
    /// An empty `tables` slice returns immediately without taking the lock.
    ///
    /// # Errors
    ///
    /// Returns an error when the cache lock is poisoned.
    pub fn invalidate_tables(&self, tables: &[String]) -> Result<()> {
        if tables.is_empty() {
            return Ok(());
        }
        let targets: HashSet<&str> = tables.iter().map(String::as_str).collect();
        let mut cache = self.write_guard()?;
        cache.retain(|_, stmt| {
            !stmt.accessed_tables().iter().any(|table| targets.contains(table.as_str()))
        });
        Ok(())
    }

    /// Returns the number of cached statements.
    ///
    /// # Errors
    ///
    /// Returns an error when the cache lock is poisoned.
    pub fn len(&self) -> Result<usize> {
        Ok(self.read_guard()?.len())
    }

    /// Returns `true` when no statement is cached.
    ///
    /// # Errors
    ///
    /// Returns an error when the cache lock is poisoned.
    pub fn is_empty(&self) -> Result<bool> {
        Ok(self.read_guard()?.is_empty())
    }

    /// Returns how many lookups were answered from the cache.
    #[must_use]
    pub fn hits(&self) -> u64 {
        self.hits.load(Ordering::Relaxed)
    }

    /// Returns how many lookups had to prepare a statement.
    #[must_use]
    pub fn misses(&self) -> u64 {
        self.misses.load(Ordering::Relaxed)
    }

    /// Returns how many times the schema version was changed.
    #[must_use]
    pub fn invalidations(&self) -> u64 {
        self.invalidations.load(Ordering::Relaxed)
    }

    /// Returns the hit rate as a percentage in `0.0..=100.0`, or `None` when
    /// no lookup has been recorded yet.
    #[must_use]
    pub fn hit_rate(&self) -> Option<f64> {
        let hits = self.hits.load(Ordering::Relaxed);
        let misses = self.misses.load(Ordering::Relaxed);
        // Saturate rather than overflow if the counters ever get that large.
        let total = hits.saturating_add(misses);
        if total == 0 {
            None
        } else {
            Some((hits as f64 / total as f64) * 100.0)
        }
    }

    /// Resets the hit, miss and invalidation counters to 0.
    pub fn reset_metrics(&self) {
        self.hits.store(0, Ordering::Relaxed);
        self.misses.store(0, Ordering::Relaxed);
        self.invalidations.store(0, Ordering::Relaxed);
    }

    /// Acquires the shared lock on the statement map, mapping poisoning to an error.
    fn read_guard(&self) -> Result<std::sync::RwLockReadGuard<'_, StatementMap>> {
        let guard = self
            .statements
            .read()
            .map_err(|_| Error::lock_poisoned("prepared statement cache read lock"))?;
        Ok(guard)
    }

    /// Acquires the exclusive lock on the statement map, mapping poisoning to an error.
    fn write_guard(&self) -> Result<std::sync::RwLockWriteGuard<'_, StatementMap>> {
        let guard = self
            .statements
            .write()
            .map_err(|_| Error::lock_poisoned("prepared statement cache write lock"))?;
        Ok(guard)
    }

    /// Frees at least one slot in a full `cache`.
    ///
    /// Entries prepared under a schema version other than the current one are
    /// dropped first, since lookups reject them anyway. If the map is still
    /// full, roughly 10% of `max_size` (at least one entry) is removed in the
    /// map's iteration order.
    fn make_room(&self, cache: &mut StatementMap) {
        let current_version = self.schema_version.load(Ordering::Acquire);
        cache.retain(|_, cached| cached.is_valid(current_version));
        if cache.len() < self.max_size {
            return;
        }

        let to_remove = (self.max_size / 10).max(1);
        let victims: Vec<String> = cache.keys().take(to_remove).cloned().collect();
        for key in &victims {
            cache.remove(key);
        }
    }
}
impl Default for PreparedStatementCache {
fn default() -> Self {
Self::default_cache()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Query against `users` shared by most cache tests.
    const USERS_QUERY: &str = "SELECT * FROM users";
    /// Query against `orders` used where a second cache entry is needed.
    const ORDERS_QUERY: &str = "SELECT * FROM orders";

    /// Builds the 100-entry cache most tests start from.
    fn roomy_cache() -> PreparedStatementCache {
        PreparedStatementCache::new(100)
    }

    /// Reads the cache size, failing the test if the lock is unusable.
    fn cached_count(cache: &PreparedStatementCache) -> usize {
        cache.len().expect("should get cache len")
    }

    /// A SELECT is classified as a query and records its table.
    #[test]
    fn test_prepare_select() {
        let prepared = PreparedStatement::new("SELECT * FROM users WHERE id = $1", 0)
            .expect("should parse SELECT statement");

        assert!(!prepared.is_dml());
        assert!(!prepared.is_ddl());
        assert!(prepared.is_query());
        assert!(prepared.accessed_tables().contains("users"));
    }

    /// An INSERT is classified as DML and records its target table.
    #[test]
    fn test_prepare_insert() {
        let prepared = PreparedStatement::new("INSERT INTO users (name) VALUES ($1)", 0)
            .expect("should parse INSERT statement");

        assert!(prepared.is_dml());
        assert!(!prepared.is_ddl());
        assert!(!prepared.is_query());
        assert!(prepared.accessed_tables().contains("users"));
    }

    /// An UPDATE is classified as DML and records its target table.
    #[test]
    fn test_prepare_update() {
        let prepared = PreparedStatement::new("UPDATE users SET name = $1 WHERE id = $2", 0)
            .expect("should parse UPDATE statement");

        assert!(prepared.is_dml());
        assert!(!prepared.is_ddl());
        assert!(prepared.accessed_tables().contains("users"));
    }

    /// A DELETE is classified as DML and records its target table.
    #[test]
    fn test_prepare_delete() {
        let prepared = PreparedStatement::new("DELETE FROM users WHERE id = $1", 0)
            .expect("should parse DELETE statement");

        assert!(prepared.is_dml());
        assert!(!prepared.is_ddl());
        assert!(prepared.accessed_tables().contains("users"));
    }

    /// A statement is valid only for the exact version it was prepared under.
    #[test]
    fn test_schema_version_validity() {
        let prepared = PreparedStatement::new(USERS_QUERY, 5).expect("should parse SELECT");

        assert!(prepared.is_valid(5));
        assert!(!prepared.is_valid(6));
        assert!(!prepared.is_valid(4));
    }

    /// The second lookup of the same SQL is a hit and returns the same `Arc`.
    #[test]
    fn test_cache_basic() {
        let cache = roomy_cache();

        let first = cache.get_or_prepare(USERS_QUERY).expect("should prepare statement");
        assert_eq!(cache.hits(), 0);
        assert_eq!(cache.misses(), 1);

        let second = cache.get_or_prepare(USERS_QUERY).expect("should get cached statement");
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 1);
        assert!(Arc::ptr_eq(&first, &second));
    }

    /// Bumping the schema version forces the statement to be prepared again.
    #[test]
    fn test_cache_invalidation_on_schema_change() {
        let cache = roomy_cache();

        let before = cache.get_or_prepare(USERS_QUERY).expect("should prepare statement");
        assert_eq!(before.schema_version(), 0);

        cache.set_schema_version(1);

        let after =
            cache.get_or_prepare(USERS_QUERY).expect("should re-prepare after schema change");
        assert_eq!(after.schema_version(), 1);
        assert!(!Arc::ptr_eq(&before, &after));
    }

    /// Invalidating one table drops only the statements that access it.
    #[test]
    fn test_cache_table_invalidation() {
        let cache = roomy_cache();
        cache.get_or_prepare(USERS_QUERY).expect("should prepare users query");
        cache.get_or_prepare(ORDERS_QUERY).expect("should prepare orders query");
        assert_eq!(cached_count(&cache), 2);

        let invalidated = ["users".to_string()];
        cache.invalidate_tables(&invalidated).expect("should invalidate tables");
        assert_eq!(cached_count(&cache), 1);

        cache.get_or_prepare(ORDERS_QUERY).expect("should get cached orders query");
        assert_eq!(cache.hits(), 1);
    }

    /// `clear` leaves the cache empty.
    #[test]
    fn test_cache_clear() {
        let cache = roomy_cache();
        cache.get_or_prepare(USERS_QUERY).expect("should prepare users query");
        cache.get_or_prepare(ORDERS_QUERY).expect("should prepare orders query");
        assert_eq!(cached_count(&cache), 2);

        cache.clear().expect("should clear cache");

        assert!(cache.is_empty().expect("should check if empty"));
    }

    /// Inserting more statements than the capacity never exceeds the capacity.
    #[test]
    fn test_cache_eviction() {
        let cache = PreparedStatementCache::new(5);

        for sql in (0..10).map(|i| format!("SELECT * FROM table{i}")) {
            cache.get_or_prepare(&sql).expect("should prepare statement");
        }

        assert!(cached_count(&cache) <= 5);
    }

    /// A version change made on another thread is observed by later lookups.
    #[test]
    fn test_concurrent_schema_version_change() {
        let cache = Arc::new(roomy_cache());
        cache.get_or_prepare(USERS_QUERY).expect("should prepare statement");

        let writer = {
            let shared = Arc::clone(&cache);
            std::thread::spawn(move || shared.set_schema_version(1))
        };
        writer.join().expect("thread should complete");

        let prepared = cache
            .get_or_prepare(USERS_QUERY)
            .expect("should re-prepare after concurrent schema change");
        assert_eq!(prepared.schema_version(), 1);
    }

    /// Invalidating an empty table list leaves the cache untouched.
    #[test]
    fn test_invalidate_empty_tables_is_noop() {
        let cache = roomy_cache();
        cache.get_or_prepare(USERS_QUERY).expect("should prepare statement");
        assert_eq!(cached_count(&cache), 1);

        cache.invalidate_tables(&[]).expect("should handle empty table list");

        assert_eq!(cached_count(&cache), 1);
    }

    /// Unparseable SQL surfaces as an error instead of a statement.
    #[test]
    fn test_parse_error_returns_error() {
        let outcome = PreparedStatement::new("INVALID SQL SYNTAX HERE", 0);
        assert!(outcome.is_err(), "invalid SQL should return error");
    }
}