use indexmap::IndexMap;
use std::time::Instant;
use crate::statement::Statement;
/// One cache slot: a prepared statement plus the bookkeeping used for
/// checkout tracking and least-recently-used eviction.
#[derive(Debug)]
struct CachedStatement {
    /// The statement held by the cache.
    statement: Statement,
    /// `true` while a copy of this entry is handed out and not yet given back.
    in_use: bool,
    /// Timestamp of the most recent insert, update, or lookup of this entry.
    last_used: Instant,
}
impl CachedStatement {
fn new(statement: Statement) -> Self {
Self {
statement,
in_use: false,
last_used: Instant::now(),
}
}
fn touch(&mut self) {
self.last_used = Instant::now();
}
}
/// Cache of prepared statements keyed by their SQL text, using
/// least-recently-used eviction once `max_size` entries are held.
#[derive(Debug)]
pub struct StatementCache {
    /// Entries keyed by the exact SQL string they were cached under.
    cache: IndexMap<String, CachedStatement>,
    /// Target number of entries; `0` disables caching entirely.
    max_size: usize,
}
impl StatementCache {
    /// Creates a cache that holds at most `max_size` statements.
    ///
    /// A `max_size` of `0` disables caching: [`get`](Self::get) always returns
    /// `None` and [`put`](Self::put) does nothing.
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: IndexMap::with_capacity(max_size),
            max_size,
        }
    }

    /// Looks up the cached statement for `sql`.
    ///
    /// On a hit whose entry is not already checked out, the entry is marked in use
    /// and a copy made with `Statement::clone_for_reuse` is returned. The caller
    /// should later release it with [`return_statement`](Self::return_statement)
    /// or replace it with [`put`](Self::put).
    ///
    /// Returns `None` in three cases:
    /// - caching is disabled;
    /// - `sql` is not cached;
    /// - the entry is currently in use (the caller should prepare a fresh statement).
    ///
    /// Any hit, including an in-use one, refreshes the entry's LRU timestamp.
    pub fn get(&mut self, sql: &str) -> Option<Statement> {
        if self.max_size == 0 {
            return None;
        }
        if let Some(cached) = self.cache.get_mut(sql) {
            // Refresh recency even when busy: the statement is still being requested.
            cached.touch();
            if cached.in_use {
                tracing::trace!(sql = sql, "Statement cache hit but in use, returning fresh");
                return None;
            }
            cached.in_use = true;
            tracing::trace!(
                sql = sql,
                cursor_id = cached.statement.cursor_id(),
                "Statement cache hit"
            );
            return Some(cached.statement.clone_for_reuse());
        }
        tracing::trace!(sql = sql, "Statement cache miss");
        None
    }

    /// Stores `statement` under `sql`, or replaces the existing entry for `sql`.
    ///
    /// Nothing is stored in any of these cases:
    /// - caching is disabled;
    /// - the statement is DDL;
    /// - the statement has no cursor (`cursor_id() == 0`);
    /// - the cache is full and every entry is in use, so nothing can be evicted.
    ///
    /// Replacing an existing entry clears its in-use flag and refreshes its
    /// timestamp. Inserting into a full cache first evicts the least recently
    /// used idle entry.
    pub fn put(&mut self, sql: String, statement: Statement) {
        if self.max_size == 0 {
            return;
        }
        if statement.is_ddl() {
            tracing::trace!(sql = sql, "Not caching DDL statement");
            return;
        }
        if statement.cursor_id() == 0 {
            tracing::trace!(sql = sql, "Not caching statement without cursor_id");
            return;
        }
        if let Some(cached) = self.cache.get_mut(&sql) {
            cached.statement = statement;
            cached.in_use = false;
            cached.touch();
            tracing::trace!(sql = sql, "Updated existing cache entry");
            return;
        }
        // When every entry is checked out nothing can be evicted; inserting anyway
        // would grow the cache past `max_size`, so the new statement is not cached.
        if self.cache.len() >= self.max_size && !self.evict_lru() {
            tracing::trace!(
                sql = sql,
                "Not caching statement, cache full and all statements in use"
            );
            return;
        }
        tracing::trace!(
            sql = sql,
            cursor_id = statement.cursor_id(),
            "Adding statement to cache"
        );
        self.cache.insert(sql, CachedStatement::new(statement));
    }

    /// Marks the entry for `sql` as no longer in use, so the next
    /// [`get`](Self::get) can hand it out again. Does nothing if `sql` is not cached.
    pub fn return_statement(&mut self, sql: &str) {
        if let Some(cached) = self.cache.get_mut(sql) {
            cached.in_use = false;
            tracing::trace!(sql = sql, "Statement returned to cache");
        }
    }

    /// Records that the cursor behind the cached statement for `sql` was closed.
    ///
    /// If the entry exists and has a nonzero cursor id, it does two things:
    /// - resets the cursor id to `0`;
    /// - clears the executed flag.
    ///
    /// The entry stays cached and its in-use flag is left unchanged.
    pub fn mark_cursor_closed(&mut self, sql: &str) {
        if let Some(cached) = self.cache.get_mut(sql) {
            if cached.statement.cursor_id() != 0 {
                cached.statement.set_cursor_id(0);
                cached.statement.set_executed(false);
                tracing::trace!(sql = sql, "Cursor closed, reset cursor_id to 0");
            }
        }
    }

    /// Removes every entry, including those currently in use.
    pub fn clear(&mut self) {
        self.cache.clear();
        tracing::debug!("Statement cache cleared");
    }

    /// Number of entries currently cached.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Returns `true` if no entries are cached.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// The configured maximum number of entries.
    pub fn max_size(&self) -> usize {
        self.max_size
    }

    /// Removes the least recently used entry that is not currently in use.
    ///
    /// Returns `true` if an entry was evicted. Returns `false` if every entry is
    /// in use or the cache is empty.
    fn evict_lru(&mut self) -> bool {
        // Select by position so the victim can be removed without cloning its key
        // or hashing it a second time. `min_by_key` keeps the first of equal
        // timestamps, so ties resolve in map iteration order.
        let lru_index = self
            .cache
            .values()
            .enumerate()
            .filter(|(_, cached)| !cached.in_use)
            .min_by_key(|(_, cached)| cached.last_used)
            .map(|(index, _)| index);
        match lru_index.and_then(|index| self.cache.swap_remove_index(index)) {
            Some((key, cached)) => {
                tracing::trace!(
                    sql = key,
                    cursor_id = cached.statement.cursor_id(),
                    "Evicted LRU statement from cache"
                );
                true
            }
            None => {
                tracing::warn!("Statement cache full and all statements in use");
                false
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Builds an executed statement with the given cursor id, which `put` accepts.
    fn make_test_statement(sql: &str, cursor_id: u16) -> Statement {
        let mut stmt = Statement::new(sql);
        stmt.set_cursor_id(cursor_id);
        stmt.set_executed(true);
        stmt
    }

    /// Sleeps briefly so consecutive `Instant::now()` readings are strictly ordered.
    ///
    /// LRU eviction compares `last_used` timestamps. With a coarse monotonic clock,
    /// back-to-back operations could share a timestamp, and eviction order would
    /// then depend on map iteration order rather than recency.
    fn tick() {
        thread::sleep(Duration::from_millis(2));
    }

    #[test]
    fn test_cache_basic() {
        let mut cache = StatementCache::new(5);
        let stmt = make_test_statement("SELECT 1 FROM DUAL", 100);
        cache.put("SELECT 1 FROM DUAL".to_string(), stmt);
        assert_eq!(cache.len(), 1);
        let cached = cache.get("SELECT 1 FROM DUAL").expect("Should be cached");
        assert_eq!(cached.cursor_id(), 100);
        cache.return_statement("SELECT 1 FROM DUAL");
    }

    #[test]
    fn test_cache_miss() {
        let mut cache = StatementCache::new(5);
        assert!(cache.get("SELECT 1 FROM DUAL").is_none());
    }

    #[test]
    fn test_cache_disabled() {
        let mut cache = StatementCache::new(0);
        let stmt = make_test_statement("SELECT 1 FROM DUAL", 100);
        cache.put("SELECT 1 FROM DUAL".to_string(), stmt);
        assert_eq!(cache.len(), 0);
        assert!(cache.get("SELECT 1 FROM DUAL").is_none());
    }

    #[test]
    fn test_ddl_not_cached() {
        let mut cache = StatementCache::new(5);
        let mut stmt = Statement::new("CREATE TABLE test (id NUMBER)");
        stmt.set_cursor_id(100);
        cache.put("CREATE TABLE test (id NUMBER)".to_string(), stmt);
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_no_cursor_not_cached() {
        let mut cache = StatementCache::new(5);
        let stmt = Statement::new("SELECT 1 FROM DUAL");
        cache.put("SELECT 1 FROM DUAL".to_string(), stmt);
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_lru_eviction() {
        let mut cache = StatementCache::new(3);
        cache.put(
            "SELECT 1 FROM DUAL".to_string(),
            make_test_statement("SELECT 1 FROM DUAL", 1),
        );
        tick();
        cache.put(
            "SELECT 2 FROM DUAL".to_string(),
            make_test_statement("SELECT 2 FROM DUAL", 2),
        );
        tick();
        cache.put(
            "SELECT 3 FROM DUAL".to_string(),
            make_test_statement("SELECT 3 FROM DUAL", 3),
        );
        assert_eq!(cache.len(), 3);
        tick();
        // Touching SELECT 1 makes SELECT 2 the least recently used entry.
        assert!(cache.get("SELECT 1 FROM DUAL").is_some());
        cache.return_statement("SELECT 1 FROM DUAL");
        tick();
        cache.put(
            "SELECT 4 FROM DUAL".to_string(),
            make_test_statement("SELECT 4 FROM DUAL", 4),
        );
        assert_eq!(cache.len(), 3);
        assert!(cache.get("SELECT 2 FROM DUAL").is_none());
        assert!(cache.get("SELECT 1 FROM DUAL").is_some());
    }

    #[test]
    fn test_in_use_not_returned() {
        let mut cache = StatementCache::new(5);
        cache.put(
            "SELECT 1 FROM DUAL".to_string(),
            make_test_statement("SELECT 1 FROM DUAL", 100),
        );
        let _ = cache.get("SELECT 1 FROM DUAL");
        assert!(cache.get("SELECT 1 FROM DUAL").is_none());
        cache.return_statement("SELECT 1 FROM DUAL");
        assert!(cache.get("SELECT 1 FROM DUAL").is_some());
    }

    #[test]
    fn test_clear() {
        let mut cache = StatementCache::new(5);
        cache.put(
            "SELECT 1 FROM DUAL".to_string(),
            make_test_statement("SELECT 1 FROM DUAL", 1),
        );
        cache.put(
            "SELECT 2 FROM DUAL".to_string(),
            make_test_statement("SELECT 2 FROM DUAL", 2),
        );
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_update_existing() {
        let mut cache = StatementCache::new(5);
        cache.put(
            "SELECT 1 FROM DUAL".to_string(),
            make_test_statement("SELECT 1 FROM DUAL", 100),
        );
        cache.put(
            "SELECT 1 FROM DUAL".to_string(),
            make_test_statement("SELECT 1 FROM DUAL", 200),
        );
        assert_eq!(cache.len(), 1);
        let cached = cache.get("SELECT 1 FROM DUAL").unwrap();
        assert_eq!(cached.cursor_id(), 200);
    }
}