use std::sync::Arc;
use crate::proto;
use crate::types::ColumnDesc;
pub(crate) struct StmtCache {
entries: Vec<(u64, StmtInfo)>,
}
impl Default for StmtCache {
fn default() -> Self {
Self {
entries: Vec::with_capacity(16),
}
}
}
impl StmtCache {
#[inline]
pub(crate) fn get_mut(&mut self, hash: &u64, sql: &str) -> Option<&mut StmtInfo> {
self.entries
.iter_mut()
.find(|(h, info)| h == hash && info.sql.len() == sql.len() && &*info.sql == sql)
.map(|(_, info)| info)
}
#[inline]
pub(crate) fn get(&self, hash: &u64, sql: &str) -> Option<&StmtInfo> {
self.entries
.iter()
.find(|(h, info)| h == hash && info.sql.len() == sql.len() && &*info.sql == sql)
.map(|(_, info)| info)
}
#[inline]
pub(crate) fn contains_key(&self, hash: &u64, sql: &str) -> bool {
self.entries
.iter()
.any(|(h, info)| h == hash && info.sql.len() == sql.len() && &*info.sql == sql)
}
#[inline]
pub(crate) fn insert(&mut self, hash: u64, info: StmtInfo) {
if let Some(entry) = self.entries.iter_mut().find(|(h, existing)| {
*h == hash && existing.sql.len() == info.sql.len() && existing.sql == info.sql
}) {
entry.1 = info;
} else {
self.entries.push((hash, info));
}
}
#[inline]
pub(crate) fn remove(&mut self, hash: &u64) -> Option<StmtInfo> {
if let Some(pos) = self.entries.iter().position(|(h, _)| h == hash) {
Some(self.entries.swap_remove(pos).1)
} else {
None
}
}
#[inline]
pub(crate) fn len(&self) -> usize {
self.entries.len()
}
pub(crate) fn evict_lru(&mut self) -> Option<(u64, StmtInfo)> {
if self.entries.is_empty() {
return None;
}
let min_idx = self
.entries
.iter()
.enumerate()
.min_by_key(|(_, (_, info))| info.last_used)
.map(|(i, _)| i)?;
Some(self.entries.swap_remove(min_idx))
}
}
#[inline]
pub(crate) fn make_stmt_name(hash: u64) -> [u8; 18] {
const HEX: &[u8; 16] = b"0123456789abcdef";
let mut buf = [0u8; 18];
buf[0] = b's';
buf[1] = b'_';
let bytes = hash.to_be_bytes();
for (i, &b) in bytes.iter().enumerate() {
buf[2 + i * 2] = HEX[(b >> 4) as usize];
buf[2 + i * 2 + 1] = HEX[(b & 0x0f) as usize];
}
buf
}
pub(crate) struct StmtInfo {
pub(crate) name: [u8; 18],
pub(crate) sql: Box<str>,
pub(crate) columns: Arc<[ColumnDesc]>,
pub(crate) last_used: u64,
pub(crate) bind_template: Option<BindTemplate>,
}
impl StmtInfo {
#[cfg(test)]
#[inline]
pub(crate) fn name_str(&self) -> &str {
std::str::from_utf8(&self.name).expect("stmt name is ASCII")
}
}
pub(crate) struct BindTemplate {
pub(crate) bytes: Vec<u8>,
pub(crate) bind_end: usize,
pub(crate) param_slots: smallvec::SmallVec<[(usize, i32); 8]>,
}
pub(crate) fn build_bind_template(write_buf: &[u8], param_count: usize) -> Option<BindTemplate> {
if write_buf.is_empty() || write_buf[0] != b'B' {
return None;
}
if write_buf.len() < 5 {
return None;
}
let mut pos = 5;
while pos < write_buf.len() && write_buf[pos] != 0 {
pos += 1;
}
pos += 1;
while pos < write_buf.len() && write_buf[pos] != 0 {
pos += 1;
}
pos += 1;
if pos + 2 > write_buf.len() {
return None;
}
let num_fmt_codes = i16::from_be_bytes([write_buf[pos], write_buf[pos + 1]]);
pos += 2;
pos += num_fmt_codes.max(0) as usize * 2;
if pos + 2 > write_buf.len() {
return None;
}
let wire_param_count = i16::from_be_bytes([write_buf[pos], write_buf[pos + 1]]) as usize;
pos += 2;
if wire_param_count != param_count {
return None;
}
let mut param_slots = smallvec::SmallVec::with_capacity(param_count);
for _ in 0..param_count {
if pos + 4 > write_buf.len() {
return None;
}
let data_len = i32::from_be_bytes([
write_buf[pos],
write_buf[pos + 1],
write_buf[pos + 2],
write_buf[pos + 3],
]);
pos += 4;
if data_len < 0 {
param_slots.push((pos, -1));
} else {
param_slots.push((pos, data_len));
pos += data_len as usize;
}
}
let bind_end = write_buf.len();
let mut bytes = Vec::with_capacity(bind_end + proto::EXECUTE_SYNC.len());
bytes.extend_from_slice(write_buf);
bytes.extend_from_slice(proto::EXECUTE_SYNC);
Some(BindTemplate {
bytes,
bind_end,
param_slots,
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::codec::Encode;
use crate::types::hash_sql;
#[test]
fn stmt_cache_basic_ops() {
let mut cache = StmtCache::default();
assert_eq!(cache.len(), 0);
assert!(!cache.contains_key(&42, "SELECT 1"));
assert!(cache.get(&42, "SELECT 1").is_none());
assert!(cache.get_mut(&42, "SELECT 1").is_none());
assert!(cache.remove(&42).is_none());
}
#[test]
fn stmt_cache_insert_get_remove() {
let mut cache = StmtCache::default();
let info = StmtInfo {
name: *b"s_test\0\0\0\0\0\0\0\0\0\0\0\0",
sql: "SELECT 1".into(),
columns: Arc::from(Vec::new()),
last_used: 1,
bind_template: None,
};
cache.insert(42, info);
assert_eq!(cache.len(), 1);
assert!(cache.contains_key(&42, "SELECT 1"));
assert!(cache.get(&42, "SELECT 1").is_some());
assert!(cache.get_mut(&42, "SELECT 1").is_some());
let removed = cache.remove(&42);
assert!(removed.is_some());
assert_eq!(cache.len(), 0);
assert!(!cache.contains_key(&42, "SELECT 1"));
}
#[test]
fn stmt_cache_evict_lru() {
let mut cache = StmtCache::default();
let sqls = ["SELECT 0", "SELECT 1", "SELECT 2"];
for i in 0..3u64 {
cache.insert(
i,
StmtInfo {
name: make_stmt_name(i),
sql: sqls[i as usize].into(),
columns: Arc::from(Vec::new()),
last_used: i + 1,
bind_template: None,
},
);
}
assert_eq!(cache.len(), 3);
let evicted = cache.evict_lru().unwrap();
assert_eq!(evicted.0, 0); assert_eq!(cache.len(), 2);
}
#[test]
fn stmt_cache_insert_overwrite() {
let mut cache = StmtCache::default();
let info1 = StmtInfo {
name: *b"s_aaaaaaaaaaaaaaaa",
sql: "SELECT 1".into(),
columns: Arc::from(Vec::new()),
last_used: 1,
bind_template: None,
};
let info2 = StmtInfo {
name: *b"s_bbbbbbbbbbbbbbbb",
sql: "SELECT 1".into(),
columns: Arc::from(Vec::new()),
last_used: 2,
bind_template: None,
};
cache.insert(42, info1);
cache.insert(42, info2);
assert_eq!(cache.len(), 1);
assert_eq!(
cache.get(&42, "SELECT 1").unwrap().name_str(),
"s_bbbbbbbbbbbbbbbb"
);
}
#[test]
fn stmt_cache_hash_collision_different_sql() {
let mut cache = StmtCache::default();
let info = StmtInfo {
name: *b"s_test\0\0\0\0\0\0\0\0\0\0\0\0",
sql: "SELECT 1".into(),
columns: Arc::from(Vec::new()),
last_used: 1,
bind_template: None,
};
cache.insert(42, info);
assert!(cache.get(&42, "SELECT 2").is_none());
assert!(cache.get_mut(&42, "SELECT 2").is_none());
assert!(!cache.contains_key(&42, "SELECT 2"));
assert!(cache.get(&42, "SELECT 1").is_some());
}
fn name_str(name: &[u8; 18]) -> &str {
std::str::from_utf8(name).expect("ASCII")
}
#[test]
fn stmt_name_format() {
let name = make_stmt_name(0);
assert_eq!(name_str(&name), "s_0000000000000000");
let name = make_stmt_name(0xDEADBEEF12345678);
assert_eq!(name_str(&name), "s_deadbeef12345678");
let name = make_stmt_name(u64::MAX);
assert_eq!(name_str(&name), "s_ffffffffffffffff");
}
#[test]
fn stmt_name_format_verification() {
let name = make_stmt_name(0xDEADBEEFCAFEBABE);
let s = name_str(&name);
assert!(s.starts_with("s_"), "must start with s_");
assert_eq!(s.len(), 18, "s_ (2) + 16 hex = 18");
assert!(
s[2..].chars().all(|c| c.is_ascii_hexdigit()),
"remaining chars must be hex: {s}",
);
}
#[test]
fn stmt_name_zero() {
let name = make_stmt_name(0);
assert_eq!(name_str(&name), "s_0000000000000000");
}
#[test]
fn stmt_name_max() {
let name = make_stmt_name(u64::MAX);
assert_eq!(name_str(&name), "s_ffffffffffffffff");
}
#[test]
fn stmt_name_one() {
let name = make_stmt_name(1);
assert_eq!(name_str(&name), "s_0000000000000001");
}
#[test]
fn stmt_name_powers_of_two() {
let name = make_stmt_name(256);
assert_eq!(name_str(&name), "s_0000000000000100");
}
#[test]
fn stmt_name_format_always_18_chars() {
for val in [0u64, 1, 0xFF, 0xFFFF, 0xFFFF_FFFF, u64::MAX] {
let name = make_stmt_name(val);
let s = name_str(&name);
assert_eq!(s.len(), 18, "name len for {val:x}");
assert!(s.starts_with("s_"));
assert!(s[2..].chars().all(|c| c.is_ascii_hexdigit()));
}
}
#[test]
fn stmt_info_has_last_used_counter() {
let info = StmtInfo {
name: *b"s_test\0\0\0\0\0\0\0\0\0\0\0\0",
sql: "SELECT 1".into(),
columns: Arc::from(Vec::new()),
last_used: 42,
bind_template: None,
};
assert_eq!(info.last_used, 42);
}
#[test]
fn hash_sql_consistency() {
let h = hash_sql("SELECT 1");
assert_eq!(h, hash_sql("SELECT 1"));
assert_ne!(h, hash_sql("SELECT 2"));
}
#[test]
fn build_bind_template_basic() {
let mut buf = Vec::new();
let val: i32 = 42;
proto::write_bind_params(&mut buf, b"", b"s_test", &[&val as &(dyn Encode + Sync)]);
let tmpl = build_bind_template(&buf, 1);
assert!(tmpl.is_some());
let tmpl = tmpl.unwrap();
assert_eq!(tmpl.param_slots.len(), 1);
assert_eq!(tmpl.param_slots[0].1, 4);
}
#[test]
fn build_bind_template_null_param() {
let mut buf = Vec::new();
let val: Option<i32> = None;
proto::write_bind_params(&mut buf, b"", b"s_test", &[&val as &(dyn Encode + Sync)]);
let tmpl = build_bind_template(&buf, 1);
assert!(tmpl.is_some());
let tmpl = tmpl.unwrap();
assert_eq!(tmpl.param_slots.len(), 1);
assert_eq!(tmpl.param_slots[0].1, -1); }
#[test]
fn build_bind_template_multiple_params() {
let mut buf = Vec::new();
let id: i32 = 1;
let name: &str = "alice";
proto::write_bind_params(
&mut buf,
b"",
b"s_test",
&[&id as &(dyn Encode + Sync), &name as &(dyn Encode + Sync)],
);
let tmpl = build_bind_template(&buf, 2);
assert!(tmpl.is_some());
let tmpl = tmpl.unwrap();
assert_eq!(tmpl.param_slots.len(), 2);
assert_eq!(tmpl.param_slots[0].1, 4); assert_eq!(tmpl.param_slots[1].1, 5); }
#[test]
fn build_bind_template_empty_buf() {
let tmpl = build_bind_template(&[], 0);
assert!(tmpl.is_none());
}
#[test]
fn build_bind_template_wrong_type() {
let tmpl = build_bind_template(&[b'E', 0, 0, 0, 4], 0);
assert!(tmpl.is_none());
}
#[test]
fn build_bind_template_param_count_mismatch() {
let mut buf = Vec::new();
let val: i32 = 42;
proto::write_bind_params(&mut buf, b"", b"s_test", &[&val as &(dyn Encode + Sync)]);
let tmpl = build_bind_template(&buf, 2);
assert!(tmpl.is_none());
}
#[test]
fn build_bind_template_too_short_buf() {
let tmpl = build_bind_template(&[b'B', 0, 0], 0);
assert!(tmpl.is_none());
}
#[test]
fn build_bind_template_zero_params() {
let mut buf = Vec::new();
proto::write_bind_params(&mut buf, b"", b"s_test", &[]);
let tmpl = build_bind_template(&buf, 0);
assert!(tmpl.is_some());
let tmpl = tmpl.unwrap();
assert_eq!(tmpl.param_slots.len(), 0);
}
#[test]
fn build_bind_template_bool_param() {
let mut buf = Vec::new();
let val = true;
proto::write_bind_params(&mut buf, b"", b"s_test", &[&val as &(dyn Encode + Sync)]);
let tmpl = build_bind_template(&buf, 1);
assert!(tmpl.is_some());
let tmpl = tmpl.unwrap();
assert_eq!(tmpl.param_slots.len(), 1);
assert_eq!(tmpl.param_slots[0].1, 1); }
#[test]
fn build_bind_template_i64_param() {
let mut buf = Vec::new();
let val: i64 = 123456789;
proto::write_bind_params(&mut buf, b"", b"s_test", &[&val as &(dyn Encode + Sync)]);
let tmpl = build_bind_template(&buf, 1);
assert!(tmpl.is_some());
let tmpl = tmpl.unwrap();
assert_eq!(tmpl.param_slots[0].1, 8); }
#[test]
#[allow(clippy::approx_constant)]
fn build_bind_template_f64_param() {
let mut buf = Vec::new();
let val: f64 = 3.14;
proto::write_bind_params(&mut buf, b"", b"s_test", &[&val as &(dyn Encode + Sync)]);
let tmpl = build_bind_template(&buf, 1);
assert!(tmpl.is_some());
let tmpl = tmpl.unwrap();
assert_eq!(tmpl.param_slots[0].1, 8); }
#[test]
fn build_bind_template_str_param() {
let mut buf = Vec::new();
let val: &str = "hello world";
proto::write_bind_params(&mut buf, b"", b"s_test", &[&val as &(dyn Encode + Sync)]);
let tmpl = build_bind_template(&buf, 1);
assert!(tmpl.is_some());
let tmpl = tmpl.unwrap();
assert_eq!(tmpl.param_slots[0].1, 11); }
#[test]
#[allow(clippy::approx_constant)]
fn build_bind_template_mixed_params_with_null() {
let mut buf = Vec::new();
let id: i32 = 1;
let name: Option<i32> = None;
let score: f64 = 9.9;
proto::write_bind_params(
&mut buf,
b"",
b"s_test",
&[
&id as &(dyn Encode + Sync),
&name as &(dyn Encode + Sync),
&score as &(dyn Encode + Sync),
],
);
let tmpl = build_bind_template(&buf, 3);
assert!(tmpl.is_some());
let tmpl = tmpl.unwrap();
assert_eq!(tmpl.param_slots.len(), 3);
assert_eq!(tmpl.param_slots[0].1, 4); assert_eq!(tmpl.param_slots[1].1, -1); assert_eq!(tmpl.param_slots[2].1, 8); }
#[test]
fn build_bind_template_preserves_bytes() {
let mut buf = Vec::new();
let val: i32 = 42;
proto::write_bind_params(&mut buf, b"", b"s_test", &[&val as &(dyn Encode + Sync)]);
let bind_len = buf.len();
let tmpl = build_bind_template(&buf, 1).unwrap();
assert_eq!(
&tmpl.bytes[..bind_len],
&buf[..],
"template must start with original Bind message"
);
assert_eq!(
&tmpl.bytes[bind_len..],
proto::EXECUTE_SYNC,
"template must end with EXECUTE_SYNC"
);
assert_eq!(tmpl.bind_end, bind_len);
}
}