use std::path::PathBuf;
use std::sync::{Mutex, OnceLock};
use bitcode::{Decode, Encode};
static BUILD_STATE: OnceLock<Mutex<BuildState>> = OnceLock::new();
struct BuildState {
seen_hashes: std::collections::HashSet<u64>,
}
fn track_and_cleanup(dir: &std::path::Path, hash: u64) {
let state = BUILD_STATE.get_or_init(|| {
let gen_path = dir.join(".generation");
let my_gen = format!("{}_{}", std::process::id(), timestamp_nanos());
let prev_gen = std::fs::read_to_string(&gen_path).unwrap_or_default();
let prev_gen = prev_gen.trim();
if !prev_gen.is_empty() {
cleanup_stale_files(dir);
}
let _ = std::fs::write(&gen_path, &my_gen);
let _ = std::fs::write(dir.join(".manifest"), "");
Mutex::new(BuildState {
seen_hashes: std::collections::HashSet::new(),
})
});
if let Ok(mut s) = state.lock() {
if s.seen_hashes.insert(hash) {
let manifest = dir.join(".manifest");
let line = format!("{hash:016x}\n");
let _ = std::fs::OpenOptions::new()
.create(true)
.append(true)
.open(&manifest)
.and_then(|mut f| std::io::Write::write_all(&mut f, line.as_bytes()));
}
}
}
fn timestamp_nanos() -> u128 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos()
}
fn cleanup_stale_files(dir: &std::path::Path) {
let manifest_path = dir.join(".manifest");
let active: std::collections::HashSet<String> = std::fs::read_to_string(&manifest_path)
.unwrap_or_default()
.lines()
.filter(|l| !l.is_empty())
.map(|l| l.trim().to_owned())
.collect();
if active.is_empty() {
return;
}
if let Ok(entries) = std::fs::read_dir(dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.extension().is_some_and(|e| e == "bitcode") {
if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
if !active.contains(stem) {
let _ = std::fs::remove_file(&path);
}
}
}
}
}
}
use crate::parse::ParsedQuery;
use crate::validate::{ColumnInfo, ValidationResult};
const CACHE_FORMAT_VERSION: u8 = 4;
const BSQL_VERSION: &str = env!("CARGO_PKG_VERSION");
#[derive(Encode, Decode)]
struct CacheEnvelope {
version: u8,
data: Vec<u8>,
}
#[derive(Debug, Clone, Decode)]
struct CachedQueryV1 {
pub sql_hash: u64,
pub normalized_sql: String,
pub columns: Vec<CachedColumn>,
pub param_pg_oids: Vec<u32>,
pub param_is_pg_enum: Vec<bool>,
}
#[derive(Debug, Clone, Encode, Decode)]
struct CachedQueryV2 {
pub sql_hash: u64,
pub normalized_sql: String,
pub columns: Vec<CachedColumn>,
pub param_pg_oids: Vec<u32>,
pub param_is_pg_enum: Vec<bool>,
pub bsql_version: String,
}
#[derive(Debug, Clone, Encode, Decode)]
struct CachedQueryV3 {
pub sql_hash: u64,
pub normalized_sql: String,
pub columns: Vec<CachedColumn>,
pub param_pg_oids: Vec<u32>,
pub param_is_pg_enum: Vec<bool>,
pub bsql_version: String,
pub param_rust_types: Vec<String>,
}
#[derive(Debug, Clone, Encode, Decode)]
pub struct CachedQuery {
pub sql_hash: u64,
pub normalized_sql: String,
pub columns: Vec<CachedColumn>,
pub param_pg_oids: Vec<u32>,
pub param_is_pg_enum: Vec<bool>,
pub bsql_version: String,
pub param_rust_types: Vec<String>,
pub rewritten_sql: Option<String>,
}
#[derive(Debug, Clone, Encode, Decode)]
pub struct CachedColumn {
pub name: String,
pub pg_oid: u32,
pub pg_type_name: String,
pub is_nullable: bool,
pub rust_type: String,
}
static IS_OFFLINE: OnceLock<bool> = OnceLock::new();
fn compute_is_offline() -> bool {
if std::env::var("BSQL_OFFLINE")
.map(|v| v == "true" || v == "1")
.unwrap_or(false)
{
return true;
}
let has_url =
std::env::var("BSQL_DATABASE_URL").is_ok() || std::env::var("DATABASE_URL").is_ok();
if !has_url {
if let Ok(dir) = resolve_cache_dir() {
if dir.is_dir()
&& std::fs::read_dir(&dir)
.map(|mut d| d.next().is_some())
.unwrap_or(false)
{
return true;
}
}
}
false
}
pub fn is_offline() -> bool {
*IS_OFFLINE.get_or_init(compute_is_offline)
}
static CACHE_DIR: OnceLock<Result<PathBuf, String>> = OnceLock::new();
fn resolve_cache_dir() -> Result<PathBuf, String> {
let manifest_dir =
std::env::var("CARGO_MANIFEST_DIR").map_err(|_| "CARGO_MANIFEST_DIR not set".to_owned())?;
let dir = PathBuf::from(&manifest_dir);
let mut search = dir.clone();
loop {
let candidate = search.join(".bsql");
if candidate.is_dir() {
return Ok(candidate.join("queries"));
}
if !search.pop() {
break;
}
}
Ok(dir.join(".bsql").join("queries"))
}
fn cache_dir() -> Result<&'static PathBuf, String> {
CACHE_DIR
.get_or_init(resolve_cache_dir)
.as_ref()
.map_err(|e| e.clone())
}
pub fn sql_hash(normalized_sql: &str) -> u64 {
bsql_core::rapid_hash_str(normalized_sql)
}
pub fn lookup_cached_validation(parsed: &ParsedQuery) -> Result<ValidationResult, String> {
let hash = sql_hash(&parsed.normalized_sql);
let dir = cache_dir()?;
track_and_cleanup(dir, hash);
let path = dir.join(format!("{hash:016x}.bitcode"));
if !path.exists() {
return Err(format!(
"query not found in offline cache (hash {hash:016x}). \
The SQL may have changed since the cache was last built. \
Run `cargo build` with a live PostgreSQL connection to update \
the cache, then rebuild with BSQL_OFFLINE=true.\n \
SQL: {}",
parsed.normalized_sql
));
}
let bytes = std::fs::read(&path)
.map_err(|e| format!("failed to read offline cache file {}: {e}", path.display()))?;
let envelope: CacheEnvelope = bitcode::decode(&bytes).map_err(|e| {
format!(
"failed to decode offline cache file {} (file may be corrupted \
or from an incompatible bsql version — run `cargo build` with \
a live PostgreSQL connection to regenerate): {e}",
path.display()
)
})?;
let cached: CachedQuery = if envelope.version == 1 {
let v1: CachedQueryV1 = bitcode::decode(&envelope.data).map_err(|e| {
format!(
"failed to decode v1 cached query in {} (file may be corrupted \
— run `cargo build` with a live PostgreSQL connection to \
regenerate): {e}",
path.display()
)
})?;
CachedQuery {
sql_hash: v1.sql_hash,
normalized_sql: v1.normalized_sql,
columns: v1.columns,
param_pg_oids: v1.param_pg_oids,
param_is_pg_enum: v1.param_is_pg_enum,
bsql_version: BSQL_VERSION.to_owned(),
param_rust_types: vec![],
rewritten_sql: None,
}
} else if envelope.version == 2 {
let v2: CachedQueryV2 = bitcode::decode(&envelope.data).map_err(|e| {
format!(
"failed to decode v2 cached query in {} (file may be corrupted \
— run `cargo build` with a live PostgreSQL connection to \
regenerate): {e}",
path.display()
)
})?;
CachedQuery {
sql_hash: v2.sql_hash,
normalized_sql: v2.normalized_sql,
columns: v2.columns,
param_pg_oids: v2.param_pg_oids,
param_is_pg_enum: v2.param_is_pg_enum,
bsql_version: v2.bsql_version,
param_rust_types: vec![],
rewritten_sql: None,
}
} else if envelope.version == 3 {
let v3: CachedQueryV3 = bitcode::decode(&envelope.data).map_err(|e| {
format!(
"failed to decode v3 cached query in {} (file may be corrupted \
— run `cargo build` with a live PostgreSQL connection to \
regenerate): {e}",
path.display()
)
})?;
CachedQuery {
sql_hash: v3.sql_hash,
normalized_sql: v3.normalized_sql,
columns: v3.columns,
param_pg_oids: v3.param_pg_oids,
param_is_pg_enum: v3.param_is_pg_enum,
bsql_version: v3.bsql_version,
param_rust_types: v3.param_rust_types,
rewritten_sql: None,
}
} else if envelope.version == CACHE_FORMAT_VERSION {
let cached: CachedQuery = bitcode::decode(&envelope.data).map_err(|e| {
format!(
"failed to decode cached query in {} (file may be corrupted \
— run `cargo build` with a live PostgreSQL connection to \
regenerate): {e}",
path.display()
)
})?;
if cached.bsql_version != BSQL_VERSION {
return Err(format!(
"offline cache was generated by bsql {} but current version is {} \
— run `cargo build` with a live PostgreSQL connection to regenerate",
cached.bsql_version, BSQL_VERSION
));
}
cached
} else {
return Err(format!(
"offline cache was generated by a different bsql version \
(format v{}, expected v{}) — run `cargo build` with a live \
PostgreSQL connection to regenerate",
envelope.version, CACHE_FORMAT_VERSION
));
};
if !cached.param_rust_types.is_empty() {
for (i, cached_type) in cached.param_rust_types.iter().enumerate() {
if i < parsed.params.len() {
let current_type = &parsed.params[i].rust_type;
if current_type != cached_type {
return Err(format!(
"parameter type mismatch: ${} was '{}' when cache was built, \
now declared as '{}'. Rebuild with a live database connection \
to update the cache.",
parsed.params[i].name, cached_type, current_type
));
}
}
}
if parsed.params.len() != cached.param_rust_types.len() {
return Err(format!(
"parameter count changed: cache has {} params, query now has {}. \
Rebuild with a live database connection.",
cached.param_rust_types.len(),
parsed.params.len()
));
}
}
if cached.normalized_sql != parsed.normalized_sql {
return Err(format!(
"offline cache hash collision detected (hash {hash:016x}). \
Cached SQL does not match current SQL. Run `cargo build` \
with a live PostgreSQL connection to regenerate the cache.\n \
cached: {}\n current: {}",
cached.normalized_sql, parsed.normalized_sql
));
}
for col in &cached.columns {
validate_cached_type(&col.rust_type)?;
}
Ok(cached_to_validation(&cached))
}
fn cached_to_validation(cached: &CachedQuery) -> ValidationResult {
let columns = cached
.columns
.iter()
.map(|c| ColumnInfo {
name: c.name.clone(),
pg_oid: c.pg_oid,
pg_type_name: c.pg_type_name.clone(),
is_nullable: c.is_nullable,
rust_type: c.rust_type.clone(),
})
.collect();
ValidationResult {
columns,
param_pg_oids: cached.param_pg_oids.iter().copied().collect(),
param_is_pg_enum: cached.param_is_pg_enum.iter().copied().collect(),
rewritten_sql: cached.rewritten_sql.clone(),
#[cfg(feature = "explain")]
explain_plan: None,
}
}
pub fn write_cache(parsed: &ParsedQuery, validation: &ValidationResult) {
if let Err(e) = write_cache_inner(parsed, validation) {
log::warn!("bsql: failed to write offline cache: {e}");
}
}
fn write_cache_inner(parsed: &ParsedQuery, validation: &ValidationResult) -> Result<(), String> {
let dir = cache_dir()?;
std::fs::create_dir_all(dir).map_err(|e| {
format!(
"failed to create offline cache directory {}: {e}",
dir.display()
)
})?;
let hash = sql_hash(&parsed.normalized_sql);
track_and_cleanup(dir, hash);
let cached = validation_to_cached(hash, parsed, validation);
let inner_bytes = bitcode::encode(&cached);
let envelope = CacheEnvelope {
version: CACHE_FORMAT_VERSION,
data: inner_bytes,
};
let bytes = bitcode::encode(&envelope);
let path = dir.join(format!("{hash:016x}.bitcode"));
let tmp_path = dir.join(format!("{hash:016x}.{}.bitcode.tmp", std::process::id()));
std::fs::write(&tmp_path, &bytes).map_err(|e| {
format!(
"failed to write offline cache file {}: {e}",
tmp_path.display()
)
})?;
std::fs::rename(&tmp_path, &path).map_err(|e| {
format!(
"failed to rename offline cache file {} -> {}: {e}",
tmp_path.display(),
path.display()
)
})?;
Ok(())
}
fn validation_to_cached(
hash: u64,
parsed: &ParsedQuery,
validation: &ValidationResult,
) -> CachedQuery {
let columns = validation
.columns
.iter()
.map(|c| CachedColumn {
name: c.name.clone(),
pg_oid: c.pg_oid,
pg_type_name: c.pg_type_name.clone(),
is_nullable: c.is_nullable,
rust_type: c.rust_type.clone(),
})
.collect();
CachedQuery {
sql_hash: hash,
normalized_sql: parsed.normalized_sql.clone(),
columns,
param_pg_oids: validation.param_pg_oids.to_vec(),
param_is_pg_enum: validation.param_is_pg_enum.to_vec(),
bsql_version: BSQL_VERSION.to_owned(),
param_rust_types: parsed.params.iter().map(|p| p.rust_type.clone()).collect(),
rewritten_sql: validation.rewritten_sql.clone(),
}
}
fn validate_cached_type(rust_type: &str) -> Result<(), String> {
let inner = rust_type
.strip_prefix("Option<")
.and_then(|s| s.strip_suffix('>'))
.unwrap_or(rust_type);
let element = inner
.strip_prefix("Vec<")
.and_then(|s| s.strip_suffix('>'))
.unwrap_or(inner);
let known_base = bsql_core::types::BASE_TYPE_MAP
.iter()
.any(|m| m.rust_type == inner);
const KNOWN_PREFIXES: &[&str] = &["::time::", "::chrono::", "::uuid::", "::rust_decimal::"];
if known_base
|| KNOWN_PREFIXES.iter().any(|p| inner.starts_with(p))
|| KNOWN_PREFIXES.iter().any(|p| element.starts_with(p))
{
return Ok(());
}
if syn::parse_str::<syn::Type>(rust_type).is_err() {
return Err(format!(
"offline cache contains invalid type syntax: `{rust_type}` \
— run `cargo build` with a live PostgreSQL connection to regenerate"
));
}
Err(format!(
"offline cache contains unexpected type: `{rust_type}` \
— run `cargo build` with a live PostgreSQL connection to regenerate"
))
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::TempDir;
fn sample_cached_query() -> CachedQuery {
CachedQuery {
sql_hash: 0xDEAD_BEEF_CAFE_1234,
normalized_sql: "select id, name from users where id = $1".into(),
columns: vec![
CachedColumn {
name: "id".into(),
pg_oid: 23,
pg_type_name: "int4".into(),
is_nullable: false,
rust_type: "i32".into(),
},
CachedColumn {
name: "name".into(),
pg_oid: 25,
pg_type_name: "text".into(),
is_nullable: true,
rust_type: "Option<String>".into(),
},
],
param_pg_oids: vec![23],
param_is_pg_enum: vec![false],
bsql_version: BSQL_VERSION.to_owned(),
param_rust_types: vec!["i32".into()],
rewritten_sql: None,
}
}
fn encode_enveloped(cached: &CachedQuery) -> Vec<u8> {
let inner = bitcode::encode(cached);
let envelope = CacheEnvelope {
version: CACHE_FORMAT_VERSION,
data: inner,
};
bitcode::encode(&envelope)
}
fn decode_enveloped(bytes: &[u8]) -> Result<CachedQuery, String> {
let envelope: CacheEnvelope =
bitcode::decode(bytes).map_err(|e| format!("envelope: {e}"))?;
if envelope.version != CACHE_FORMAT_VERSION {
return Err(format!(
"version mismatch: got {}, expected {}",
envelope.version, CACHE_FORMAT_VERSION
));
}
bitcode::decode(&envelope.data).map_err(|e| format!("inner: {e}"))
}
#[test]
fn envelope_round_trip() {
let original = sample_cached_query();
let bytes = encode_enveloped(&original);
let decoded = decode_enveloped(&bytes).expect("decode failed");
assert_eq!(decoded.sql_hash, original.sql_hash);
assert_eq!(decoded.normalized_sql, original.normalized_sql);
assert_eq!(decoded.columns.len(), original.columns.len());
assert_eq!(decoded.param_pg_oids, original.param_pg_oids);
assert_eq!(decoded.param_is_pg_enum, original.param_is_pg_enum);
for (d, o) in decoded.columns.iter().zip(&original.columns) {
assert_eq!(d.name, o.name);
assert_eq!(d.pg_oid, o.pg_oid);
assert_eq!(d.pg_type_name, o.pg_type_name);
assert_eq!(d.is_nullable, o.is_nullable);
assert_eq!(d.rust_type, o.rust_type);
}
}
#[test]
fn format_version_mismatch_returns_clear_error() {
let cached = sample_cached_query();
let inner = bitcode::encode(&cached);
let envelope = CacheEnvelope {
version: 99, data: inner,
};
let bytes = bitcode::encode(&envelope);
let err = decode_enveloped(&bytes).unwrap_err();
assert!(
err.contains("version mismatch"),
"error should mention version: {err}"
);
}
#[test]
fn cached_to_validation_preserves_all_fields() {
let cached = sample_cached_query();
let validation = cached_to_validation(&cached);
assert_eq!(validation.columns.len(), 2);
assert_eq!(validation.columns[0].name, "id");
assert_eq!(validation.columns[0].pg_oid, 23);
assert!(!validation.columns[0].is_nullable);
assert_eq!(validation.columns[0].rust_type, "i32");
assert_eq!(validation.columns[1].name, "name");
assert!(validation.columns[1].is_nullable);
assert_eq!(validation.columns[1].rust_type, "Option<String>");
assert_eq!(validation.param_pg_oids.as_slice(), &[23u32]);
assert_eq!(validation.param_is_pg_enum.as_slice(), &[false]);
}
#[test]
fn validation_to_cached_preserves_all_fields() {
let validation = ValidationResult {
columns: vec![ColumnInfo {
name: "count".into(),
pg_oid: 20,
pg_type_name: "int8".into(),
is_nullable: false,
rust_type: "i64".into(),
}],
param_pg_oids: smallvec::smallvec![25, 23],
param_is_pg_enum: smallvec::smallvec![false, false],
rewritten_sql: None,
#[cfg(feature = "explain")]
explain_plan: None,
};
let parsed = crate::parse::parse_query(
"SELECT COUNT(*) AS count FROM users WHERE name = $name: &str AND id = $id: i32",
)
.expect("parse failed");
let hash = sql_hash(&parsed.normalized_sql);
let cached = validation_to_cached(hash, &parsed, &validation);
assert_eq!(cached.sql_hash, hash);
assert_eq!(cached.normalized_sql, parsed.normalized_sql);
assert_eq!(cached.columns.len(), 1);
assert_eq!(cached.columns[0].name, "count");
assert_eq!(cached.columns[0].pg_oid, 20);
assert_eq!(cached.columns[0].rust_type, "i64");
assert_eq!(cached.param_pg_oids, vec![25, 23]);
assert_eq!(cached.param_rust_types, vec!["&str", "i32"]);
}
#[test]
fn sql_hash_deterministic() {
let h1 = sql_hash("select id from users where id = $1");
let h2 = sql_hash("select id from users where id = $1");
assert_eq!(h1, h2);
}
#[test]
fn sql_hash_different_for_different_sql() {
let h1 = sql_hash("select id from users where id = $1");
let h2 = sql_hash("select name from users where id = $1");
assert_ne!(h1, h2);
}
#[test]
fn write_and_read_enveloped_cache_file() {
let tmp = TempDir::new().expect("tempdir");
let queries_dir = tmp.path().join("queries");
std::fs::create_dir_all(&queries_dir).expect("mkdir");
let cached = sample_cached_query();
let bytes = encode_enveloped(&cached);
let path = queries_dir.join(format!("{:016x}.bitcode", cached.sql_hash));
std::fs::write(&path, &bytes).expect("write");
let read_bytes = std::fs::read(&path).expect("read");
let decoded = decode_enveloped(&read_bytes).expect("decode");
assert_eq!(decoded.sql_hash, cached.sql_hash);
assert_eq!(decoded.normalized_sql, cached.normalized_sql);
}
#[test]
fn corrupted_cache_file_returns_error() {
let tmp = TempDir::new().expect("tempdir");
let queries_dir = tmp.path().join("queries");
std::fs::create_dir_all(&queries_dir).expect("mkdir");
let path = queries_dir.join("deadbeefcafe1234.bitcode");
let mut f = std::fs::File::create(&path).expect("create");
f.write_all(b"this is not bitcode").expect("write");
let read_bytes = std::fs::read(&path).expect("read");
let result = bitcode::decode::<CacheEnvelope>(&read_bytes);
assert!(result.is_err(), "corrupted file should fail to decode");
}
#[test]
fn empty_cache_file_returns_error() {
let tmp = TempDir::new().expect("tempdir");
let queries_dir = tmp.path().join("queries");
std::fs::create_dir_all(&queries_dir).expect("mkdir");
let path = queries_dir.join("0000000000000000.bitcode");
std::fs::write(&path, b"").expect("write");
let read_bytes = std::fs::read(&path).expect("read");
let result = bitcode::decode::<CacheEnvelope>(&read_bytes);
assert!(result.is_err(), "empty file should fail to decode");
}
#[test]
fn is_offline_default_false() {
let _ = is_offline();
}
#[test]
fn cached_query_with_no_columns_round_trips() {
let cached = CachedQuery {
sql_hash: 123,
normalized_sql: "delete from users where id = $1".into(),
columns: vec![],
param_pg_oids: vec![23],
param_is_pg_enum: vec![false],
bsql_version: BSQL_VERSION.to_owned(),
param_rust_types: vec!["i32".into()],
rewritten_sql: None,
};
let bytes = encode_enveloped(&cached);
let decoded = decode_enveloped(&bytes).expect("decode");
assert!(decoded.columns.is_empty());
assert_eq!(decoded.param_pg_oids, vec![23]);
}
#[test]
fn cached_query_with_pg_enum_round_trips() {
let cached = CachedQuery {
sql_hash: 456,
normalized_sql: "select status from tickets where status = $1".into(),
columns: vec![CachedColumn {
name: "status".into(),
pg_oid: 25, pg_type_name: "text".into(),
is_nullable: false,
rust_type: "String".into(),
}],
param_pg_oids: vec![99999],
param_is_pg_enum: vec![true],
bsql_version: BSQL_VERSION.to_owned(),
param_rust_types: vec!["String".into()],
rewritten_sql: None,
};
let bytes = encode_enveloped(&cached);
let decoded = decode_enveloped(&bytes).expect("decode");
assert_eq!(decoded.param_is_pg_enum, vec![true]);
assert_eq!(decoded.columns[0].pg_type_name, "text");
}
#[test]
fn raw_cached_query_without_envelope_fails() {
let cached = sample_cached_query();
let raw_bytes = bitcode::encode(&cached);
match bitcode::decode::<CacheEnvelope>(&raw_bytes) {
Err(_) => {} Ok(env) => assert_ne!(
env.version, CACHE_FORMAT_VERSION,
"raw CachedQuery bytes must not accidentally pass version check"
),
}
}
#[test]
fn validate_known_base_types() {
assert!(validate_cached_type("i32").is_ok());
assert!(validate_cached_type("i64").is_ok());
assert!(validate_cached_type("String").is_ok());
assert!(validate_cached_type("bool").is_ok());
assert!(validate_cached_type("f64").is_ok());
assert!(validate_cached_type("u32").is_ok());
assert!(validate_cached_type("()").is_ok());
assert!(validate_cached_type("Vec<u8>").is_ok());
assert!(validate_cached_type("Vec<String>").is_ok());
assert!(validate_cached_type("Vec<Vec<u8>>").is_ok());
}
#[test]
fn validate_option_wrapped_types() {
assert!(validate_cached_type("Option<i32>").is_ok());
assert!(validate_cached_type("Option<String>").is_ok());
assert!(validate_cached_type("Option<Vec<u8>>").is_ok());
}
#[test]
fn validate_rejects_removed_enum_string() {
assert!(validate_cached_type("::bsql_core::types::EnumString").is_err());
assert!(validate_cached_type("Option<::bsql_core::types::EnumString>").is_err());
}
#[test]
fn validate_feature_gated_types() {
assert!(validate_cached_type("::time::OffsetDateTime").is_ok());
assert!(validate_cached_type("::time::PrimitiveDateTime").is_ok());
assert!(validate_cached_type("::time::Date").is_ok());
assert!(validate_cached_type("::time::Time").is_ok());
assert!(validate_cached_type("::chrono::DateTime<::chrono::Utc>").is_ok());
assert!(validate_cached_type("::chrono::NaiveDateTime").is_ok());
assert!(validate_cached_type("::uuid::Uuid").is_ok());
assert!(validate_cached_type("::rust_decimal::Decimal").is_ok());
}
#[test]
fn validate_feature_gated_option_types() {
assert!(validate_cached_type("Option<::time::OffsetDateTime>").is_ok());
assert!(validate_cached_type("Option<::uuid::Uuid>").is_ok());
assert!(validate_cached_type("Option<::rust_decimal::Decimal>").is_ok());
}
#[test]
fn validate_feature_gated_vec_types() {
assert!(validate_cached_type("Vec<::time::OffsetDateTime>").is_ok());
assert!(validate_cached_type("Vec<::chrono::NaiveDate>").is_ok());
assert!(validate_cached_type("Vec<::uuid::Uuid>").is_ok());
assert!(validate_cached_type("Vec<::rust_decimal::Decimal>").is_ok());
}
#[test]
fn validate_option_vec_feature_gated_types() {
assert!(validate_cached_type("Option<Vec<::time::Date>>").is_ok());
assert!(validate_cached_type("Option<Vec<::uuid::Uuid>>").is_ok());
}
#[test]
fn validate_rejects_unknown_types() {
let err = validate_cached_type("std::process::Command").unwrap_err();
assert!(err.contains("unexpected type"), "error: {err}");
let err = validate_cached_type("SomeUserType").unwrap_err();
assert!(err.contains("unexpected type"), "error: {err}");
}
#[test]
fn validate_rejects_invalid_syntax() {
let err = validate_cached_type("not a type!!").unwrap_err();
assert!(err.contains("invalid type syntax"), "error: {err}");
}
#[test]
fn validate_rejects_injected_code() {
let err = validate_cached_type("std::fs::File").unwrap_err();
assert!(err.contains("unexpected type"), "error: {err}");
}
#[test]
fn temp_filename_includes_pid() {
let pid = std::process::id();
let hash: u64 = 0xCAFE;
let tmp_name = format!("{hash:016x}.{pid}.bitcode.tmp");
assert!(
tmp_name.contains(&pid.to_string()),
"temp filename should include PID: {tmp_name}"
);
assert!(tmp_name.ends_with(".bitcode.tmp"));
assert!(tmp_name.starts_with("000000000000cafe."));
}
#[test]
fn walk_up_finds_existing_bsql_dir() {
let tmp = TempDir::new().expect("tempdir");
let bsql_dir = tmp.path().join(".bsql");
std::fs::create_dir_all(&bsql_dir).expect("mkdir");
let sub_crate = tmp.path().join("crates").join("mylib");
std::fs::create_dir_all(&sub_crate).expect("mkdir");
let mut search = sub_crate.clone();
let mut found = None;
loop {
let candidate = search.join(".bsql");
if candidate.is_dir() {
found = Some(candidate.join("queries"));
break;
}
if !search.pop() {
break;
}
}
assert_eq!(
found,
Some(bsql_dir.join("queries")),
"walk should find .bsql at workspace root"
);
}
#[test]
fn walk_up_falls_back_to_start_dir() {
let tmp = TempDir::new().expect("tempdir");
let start = tmp.path().join("projects").join("mylib");
std::fs::create_dir_all(&start).expect("mkdir");
let mut search = start.clone();
let mut found = None;
loop {
let candidate = search.join(".bsql");
if candidate.is_dir() {
found = Some(candidate.join("queries"));
break;
}
if !search.pop() {
break;
}
}
assert!(found.is_none(), "no .bsql/ should be found");
let fallback = start.join(".bsql").join("queries");
assert!(fallback.to_str().unwrap().contains("mylib"));
}
#[test]
fn write_cache_and_lookup_roundtrip() {
let tmp = TempDir::new().expect("tempdir");
let queries_dir = tmp.path().join(".bsql").join("queries");
std::fs::create_dir_all(&queries_dir).expect("mkdir");
let cached = sample_cached_query();
let bytes = encode_enveloped(&cached);
let path = queries_dir.join(format!("{:016x}.bitcode", cached.sql_hash));
std::fs::write(&path, &bytes).expect("write");
let read_bytes = std::fs::read(&path).expect("read");
let envelope: CacheEnvelope = bitcode::decode(&read_bytes).expect("envelope decode");
assert_eq!(envelope.version, CACHE_FORMAT_VERSION);
let decoded: CachedQuery = bitcode::decode(&envelope.data).expect("inner decode");
assert_eq!(decoded.sql_hash, cached.sql_hash);
assert_eq!(decoded.normalized_sql, cached.normalized_sql);
assert_eq!(decoded.columns.len(), cached.columns.len());
assert_eq!(decoded.param_pg_oids, cached.param_pg_oids);
for col in &decoded.columns {
validate_cached_type(&col.rust_type).expect("type validation failed");
}
}
#[test]
fn sql_collision_guard_rejects_mismatched_sql() {
let cached = CachedQuery {
sql_hash: 999,
normalized_sql: "select a from t".into(),
columns: vec![],
param_pg_oids: vec![],
param_is_pg_enum: vec![],
bsql_version: BSQL_VERSION.to_owned(),
param_rust_types: vec![],
rewritten_sql: None,
};
let other_sql = "select b from t";
assert_ne!(cached.normalized_sql, other_sql);
}
#[test]
fn cache_v3_roundtrip_with_param_types() {
let query = CachedQuery {
sql_hash: 42,
normalized_sql: "SELECT id FROM users WHERE id = $1".to_owned(),
columns: vec![],
param_pg_oids: vec![23],
param_is_pg_enum: vec![false],
bsql_version: BSQL_VERSION.to_owned(),
param_rust_types: vec!["i32".to_owned()],
rewritten_sql: None,
};
let bytes = encode_enveloped(&query);
let decoded = decode_enveloped(&bytes).unwrap();
assert_eq!(decoded.param_rust_types, vec!["i32"]);
}
#[test]
fn cache_v2_migration_has_empty_param_types() {
let v2 = CachedQueryV2 {
sql_hash: 77,
normalized_sql: "SELECT 1".to_owned(),
columns: vec![],
param_pg_oids: vec![23],
param_is_pg_enum: vec![false],
bsql_version: BSQL_VERSION.to_owned(),
};
let inner_bytes = bitcode::encode(&v2);
let envelope = CacheEnvelope {
version: 2,
data: inner_bytes,
};
let bytes = bitcode::encode(&envelope);
let env: CacheEnvelope = bitcode::decode(&bytes).unwrap();
assert_eq!(env.version, 2);
let decoded_v2: CachedQueryV2 = bitcode::decode(&env.data).unwrap();
let migrated = CachedQuery {
sql_hash: decoded_v2.sql_hash,
normalized_sql: decoded_v2.normalized_sql,
columns: decoded_v2.columns,
param_pg_oids: decoded_v2.param_pg_oids,
param_is_pg_enum: decoded_v2.param_is_pg_enum,
bsql_version: decoded_v2.bsql_version,
param_rust_types: vec![],
rewritten_sql: None,
};
assert!(migrated.param_rust_types.is_empty());
assert_eq!(migrated.sql_hash, 77);
}
#[test]
fn cache_v3_multiple_param_types_roundtrip() {
let query = CachedQuery {
sql_hash: 100,
normalized_sql: "SELECT 1 FROM t WHERE a = $1 AND b = $2".to_owned(),
columns: vec![],
param_pg_oids: vec![23, 25],
param_is_pg_enum: vec![false, false],
bsql_version: BSQL_VERSION.to_owned(),
param_rust_types: vec!["i32".to_owned(), "&str".to_owned()],
rewritten_sql: None,
};
let bytes = encode_enveloped(&query);
let decoded = decode_enveloped(&bytes).unwrap();
assert_eq!(decoded.param_rust_types, vec!["i32", "&str"]);
}
#[test]
fn cache_v3_empty_param_types_roundtrip() {
let query = CachedQuery {
sql_hash: 200,
normalized_sql: "SELECT 1".to_owned(),
columns: vec![],
param_pg_oids: vec![],
param_is_pg_enum: vec![],
bsql_version: BSQL_VERSION.to_owned(),
param_rust_types: vec![],
rewritten_sql: None,
};
let bytes = encode_enveloped(&query);
let decoded = decode_enveloped(&bytes).unwrap();
assert!(decoded.param_rust_types.is_empty());
}
#[test]
fn param_type_mismatch_detected() {
let tmp = TempDir::new().expect("tempdir");
let queries_dir = tmp.path().join(".bsql").join("queries");
std::fs::create_dir_all(&queries_dir).expect("mkdir");
let normalized_sql = "SELECT id FROM users WHERE id = $1";
let hash = sql_hash(normalized_sql);
let cached = CachedQuery {
sql_hash: hash,
normalized_sql: normalized_sql.to_owned(),
columns: vec![CachedColumn {
name: "id".into(),
pg_oid: 23,
pg_type_name: "int4".into(),
is_nullable: false,
rust_type: "i32".into(),
}],
param_pg_oids: vec![23],
param_is_pg_enum: vec![false],
bsql_version: BSQL_VERSION.to_owned(),
param_rust_types: vec!["i32".to_owned()],
rewritten_sql: None,
};
let bytes = encode_enveloped(&cached);
let path = queries_dir.join(format!("{hash:016x}.bitcode"));
std::fs::write(&path, &bytes).expect("write");
let read_bytes = std::fs::read(&path).expect("read");
let envelope: CacheEnvelope = bitcode::decode(&read_bytes).expect("envelope decode");
let decoded: CachedQuery = bitcode::decode(&envelope.data).expect("inner decode");
let current_type = "&str";
assert_ne!(decoded.param_rust_types[0], current_type);
}
#[test]
fn param_count_mismatch_detected() {
let cached = CachedQuery {
sql_hash: 300,
normalized_sql: "SELECT 1 FROM t WHERE a = $1 AND b = $2".to_owned(),
columns: vec![],
param_pg_oids: vec![23, 25],
param_is_pg_enum: vec![false, false],
bsql_version: BSQL_VERSION.to_owned(),
param_rust_types: vec!["i32".to_owned(), "&str".to_owned()],
rewritten_sql: None,
};
let current_param_count = 3;
assert_ne!(
cached.param_rust_types.len(),
current_param_count,
"param count should differ: cache has {}, current has {}",
cached.param_rust_types.len(),
current_param_count
);
}
#[test]
fn v1_to_v3_migration_preserves_data() {
let v1_sql_hash: u64 = 555;
let v1_normalized_sql = "SELECT name FROM t WHERE id = $1".to_owned();
let v1_columns = vec![CachedColumn {
name: "name".into(),
pg_oid: 25,
pg_type_name: "text".into(),
is_nullable: true,
rust_type: "Option<String>".into(),
}];
let v1_param_pg_oids = vec![23u32];
let v1_param_is_pg_enum = vec![false];
let migrated = CachedQuery {
sql_hash: v1_sql_hash,
normalized_sql: v1_normalized_sql,
columns: v1_columns,
param_pg_oids: v1_param_pg_oids,
param_is_pg_enum: v1_param_is_pg_enum,
bsql_version: BSQL_VERSION.to_owned(),
param_rust_types: vec![],
rewritten_sql: None,
};
assert_eq!(migrated.sql_hash, 555);
assert_eq!(migrated.normalized_sql, "SELECT name FROM t WHERE id = $1");
assert_eq!(migrated.columns.len(), 1);
assert_eq!(migrated.columns[0].name, "name");
assert_eq!(migrated.columns[0].rust_type, "Option<String>");
assert_eq!(migrated.param_pg_oids, vec![23]);
assert_eq!(migrated.param_is_pg_enum, vec![false]);
assert!(migrated.param_rust_types.is_empty());
let bytes = encode_enveloped(&migrated);
let decoded = decode_enveloped(&bytes).unwrap();
assert_eq!(decoded.sql_hash, 555);
assert!(decoded.param_rust_types.is_empty());
}
#[test]
fn v2_migrated_cache_skips_type_check() {
let v2 = CachedQueryV2 {
sql_hash: 888,
normalized_sql: "SELECT 1 WHERE $1 = 1".to_owned(),
columns: vec![],
param_pg_oids: vec![23],
param_is_pg_enum: vec![false],
bsql_version: BSQL_VERSION.to_owned(),
};
let inner_bytes = bitcode::encode(&v2);
let envelope = CacheEnvelope {
version: 2,
data: inner_bytes,
};
let bytes = bitcode::encode(&envelope);
let env: CacheEnvelope = bitcode::decode(&bytes).unwrap();
assert_eq!(env.version, 2);
let decoded_v2: CachedQueryV2 = bitcode::decode(&env.data).unwrap();
let migrated = CachedQuery {
sql_hash: decoded_v2.sql_hash,
normalized_sql: decoded_v2.normalized_sql,
columns: decoded_v2.columns,
param_pg_oids: decoded_v2.param_pg_oids,
param_is_pg_enum: decoded_v2.param_is_pg_enum,
bsql_version: decoded_v2.bsql_version,
param_rust_types: vec![],
rewritten_sql: None,
};
assert!(migrated.param_rust_types.is_empty());
}
#[test]
fn future_version_envelope_rejected() {
let cached = sample_cached_query();
let inner = bitcode::encode(&cached);
let envelope = CacheEnvelope {
version: CACHE_FORMAT_VERSION + 1, data: inner,
};
let bytes = bitcode::encode(&envelope);
let err = decode_enveloped(&bytes).unwrap_err();
assert!(
err.contains("version mismatch"),
"future version should be rejected: {err}"
);
}
#[test]
fn version_zero_envelope_rejected() {
let inner = bitcode::encode(&sample_cached_query());
let envelope = CacheEnvelope {
version: 0,
data: inner,
};
let bytes = bitcode::encode(&envelope);
let err = decode_enveloped(&bytes).unwrap_err();
assert!(
err.contains("version mismatch"),
"version 0 should be rejected: {err}"
);
}
#[test]
fn empty_data_in_envelope_fails() {
let envelope = CacheEnvelope {
version: CACHE_FORMAT_VERSION,
data: vec![],
};
let bytes = bitcode::encode(&envelope);
let err = decode_enveloped(&bytes).unwrap_err();
assert!(err.contains("inner"), "empty data should fail: {err}");
}
#[test]
fn truncated_data_in_envelope_fails() {
let cached = sample_cached_query();
let inner = bitcode::encode(&cached);
let truncated = &inner[..inner.len() / 2]; let envelope = CacheEnvelope {
version: CACHE_FORMAT_VERSION,
data: truncated.to_vec(),
};
let bytes = bitcode::encode(&envelope);
let err = decode_enveloped(&bytes).unwrap_err();
assert!(!err.is_empty(), "truncated data should fail: {err}");
}
#[test]
fn cached_query_many_columns_round_trips() {
let columns: Vec<CachedColumn> = (0..50)
.map(|i| CachedColumn {
name: format!("col_{i}"),
pg_oid: 23,
pg_type_name: "int4".into(),
is_nullable: i % 2 == 0,
rust_type: if i % 2 == 0 {
"Option<i32>".into()
} else {
"i32".into()
},
})
.collect();
let cached = CachedQuery {
sql_hash: 12345,
normalized_sql: "SELECT many columns...".into(),
columns,
param_pg_oids: vec![23, 25],
param_is_pg_enum: vec![false, false],
bsql_version: BSQL_VERSION.to_owned(),
param_rust_types: vec!["i32".into(), "&str".into()],
rewritten_sql: None,
};
let bytes = encode_enveloped(&cached);
let decoded = decode_enveloped(&bytes).unwrap();
assert_eq!(decoded.columns.len(), 50);
assert_eq!(decoded.columns[0].name, "col_0");
assert!(decoded.columns[0].is_nullable);
assert_eq!(decoded.columns[49].name, "col_49");
assert!(!decoded.columns[49].is_nullable);
}
#[test]
fn cached_query_empty_sql_round_trips() {
let cached = CachedQuery {
sql_hash: 0,
normalized_sql: String::new(),
columns: vec![],
param_pg_oids: vec![],
param_is_pg_enum: vec![],
bsql_version: BSQL_VERSION.to_owned(),
param_rust_types: vec![],
rewritten_sql: None,
};
let bytes = encode_enveloped(&cached);
let decoded = decode_enveloped(&bytes).unwrap();
assert!(decoded.normalized_sql.is_empty());
assert_eq!(decoded.sql_hash, 0);
}
#[test]
fn validate_cached_type_i16() {
assert!(validate_cached_type("i16").is_ok());
}
#[test]
fn validate_cached_type_f32() {
assert!(validate_cached_type("f32").is_ok());
}
#[test]
fn validate_cached_type_option_i16() {
assert!(validate_cached_type("Option<i16>").is_ok());
}
#[test]
fn validate_cached_type_option_f32() {
assert!(validate_cached_type("Option<f32>").is_ok());
}
#[test]
fn validate_cached_type_vec_i32() {
assert!(validate_cached_type("Vec<i32>").is_ok());
}
#[test]
fn validate_cached_type_vec_i64() {
assert!(validate_cached_type("Vec<i64>").is_ok());
}
#[test]
fn validate_cached_type_vec_bool() {
assert!(validate_cached_type("Vec<bool>").is_ok());
}
#[test]
fn validate_cached_type_vec_f32() {
assert!(validate_cached_type("Vec<f32>").is_ok());
}
#[test]
fn validate_cached_type_vec_f64() {
assert!(validate_cached_type("Vec<f64>").is_ok());
}
#[test]
fn validate_cached_type_vec_i16() {
assert!(validate_cached_type("Vec<i16>").is_ok());
}
#[test]
fn validate_cached_type_empty_string() {
let result = validate_cached_type("");
assert!(result.is_err(), "empty type string should be rejected");
}
#[test]
fn sql_hash_empty_string_deterministic() {
let h1 = sql_hash("");
let h2 = sql_hash("");
assert_eq!(h1, h2);
}
#[test]
fn sql_hash_whitespace_matters() {
let h1 = sql_hash("SELECT 1");
let h2 = sql_hash("SELECT 1");
assert_ne!(h1, h2, "whitespace should produce different hashes");
}
}