use std::hash::{Hash, Hasher};
use std::path::PathBuf;
use std::sync::LazyLock;
use bitcode::{Decode, Encode};
use crate::parse::ParsedQuery;
use crate::validate::{ColumnInfo, ValidationResult};
const CACHE_FORMAT_VERSION: u8 = 1;
#[derive(Encode, Decode)]
struct CacheEnvelope {
version: u8,
data: Vec<u8>,
}
#[derive(Debug, Clone, Encode, Decode)]
pub struct CachedQuery {
pub sql_hash: u64,
pub normalized_sql: String,
pub columns: Vec<CachedColumn>,
pub param_pg_oids: Vec<u32>,
pub param_is_pg_enum: Vec<bool>,
}
#[derive(Debug, Clone, Encode, Decode)]
pub struct CachedColumn {
pub name: String,
pub pg_oid: u32,
pub pg_type_name: String,
pub is_nullable: bool,
pub rust_type: String,
}
static IS_OFFLINE: LazyLock<bool> = LazyLock::new(|| {
std::env::var("BSQL_OFFLINE")
.map(|v| v == "true" || v == "1")
.unwrap_or(false)
});
pub fn is_offline() -> bool {
*IS_OFFLINE
}
static CACHE_DIR: LazyLock<Result<PathBuf, String>> = LazyLock::new(resolve_cache_dir);
fn resolve_cache_dir() -> Result<PathBuf, String> {
let manifest_dir =
std::env::var("CARGO_MANIFEST_DIR").map_err(|_| "CARGO_MANIFEST_DIR not set".to_owned())?;
let dir = PathBuf::from(&manifest_dir);
let mut search = dir.clone();
loop {
let candidate = search.join(".bsql");
if candidate.is_dir() {
return Ok(candidate.join("queries"));
}
if !search.pop() {
break;
}
}
Ok(dir.join(".bsql").join("queries"))
}
fn cache_dir() -> Result<&'static PathBuf, String> {
CACHE_DIR.as_ref().map_err(|e| e.clone())
}
pub fn sql_hash(normalized_sql: &str) -> u64 {
let mut hasher = rapidhash::quality::RapidHasher::default();
normalized_sql.hash(&mut hasher);
hasher.finish()
}
pub fn lookup_cached_validation(parsed: &ParsedQuery) -> Result<ValidationResult, String> {
let hash = sql_hash(&parsed.normalized_sql);
let dir = cache_dir()?;
let path = dir.join(format!("{hash:016x}.bitcode"));
if !path.exists() {
return Err(format!(
"query not found in offline cache (hash {hash:016x}). \
The SQL may have changed since the cache was last built. \
Run `cargo build` with a live PostgreSQL connection to update \
the cache, then rebuild with BSQL_OFFLINE=true.\n \
SQL: {}",
parsed.normalized_sql
));
}
let bytes = std::fs::read(&path)
.map_err(|e| format!("failed to read offline cache file {}: {e}", path.display()))?;
let envelope: CacheEnvelope = bitcode::decode(&bytes).map_err(|e| {
format!(
"failed to decode offline cache file {} (file may be corrupted \
or from an incompatible bsql version — run `cargo build` with \
a live PostgreSQL connection to regenerate): {e}",
path.display()
)
})?;
if envelope.version != CACHE_FORMAT_VERSION {
return Err(format!(
"offline cache was generated by a different bsql version \
(format v{}, expected v{}) — run `cargo build` with a live \
PostgreSQL connection to regenerate",
envelope.version, CACHE_FORMAT_VERSION
));
}
let cached: CachedQuery = bitcode::decode(&envelope.data).map_err(|e| {
format!(
"failed to decode cached query in {} (file may be corrupted \
— run `cargo build` with a live PostgreSQL connection to \
regenerate): {e}",
path.display()
)
})?;
if cached.normalized_sql != parsed.normalized_sql {
return Err(format!(
"offline cache hash collision detected (hash {hash:016x}). \
Cached SQL does not match current SQL. Run `cargo build` \
with a live PostgreSQL connection to regenerate the cache.\n \
cached: {}\n current: {}",
cached.normalized_sql, parsed.normalized_sql
));
}
for col in &cached.columns {
validate_cached_type(&col.rust_type)?;
}
Ok(cached_to_validation(&cached))
}
fn cached_to_validation(cached: &CachedQuery) -> ValidationResult {
let columns = cached
.columns
.iter()
.map(|c| ColumnInfo {
name: c.name.clone(),
pg_oid: c.pg_oid,
pg_type_name: c.pg_type_name.clone(),
is_nullable: c.is_nullable,
rust_type: c.rust_type.clone(),
})
.collect();
ValidationResult {
columns,
param_pg_oids: cached.param_pg_oids.clone(),
param_is_pg_enum: cached.param_is_pg_enum.clone(),
explain_plan: None, }
}
pub fn write_cache(parsed: &ParsedQuery, validation: &ValidationResult) {
if let Err(e) = write_cache_inner(parsed, validation) {
eprintln!("bsql: warning: failed to write offline cache: {e}");
}
}
fn write_cache_inner(parsed: &ParsedQuery, validation: &ValidationResult) -> Result<(), String> {
let dir = cache_dir()?;
std::fs::create_dir_all(dir).map_err(|e| {
format!(
"failed to create offline cache directory {}: {e}",
dir.display()
)
})?;
let hash = sql_hash(&parsed.normalized_sql);
let cached = validation_to_cached(hash, parsed, validation);
let inner_bytes = bitcode::encode(&cached);
let envelope = CacheEnvelope {
version: CACHE_FORMAT_VERSION,
data: inner_bytes,
};
let bytes = bitcode::encode(&envelope);
let path = dir.join(format!("{hash:016x}.bitcode"));
let tmp_path = dir.join(format!("{hash:016x}.{}.bitcode.tmp", std::process::id()));
std::fs::write(&tmp_path, &bytes).map_err(|e| {
format!(
"failed to write offline cache file {}: {e}",
tmp_path.display()
)
})?;
std::fs::rename(&tmp_path, &path).map_err(|e| {
format!(
"failed to rename offline cache file {} -> {}: {e}",
tmp_path.display(),
path.display()
)
})?;
Ok(())
}
fn validation_to_cached(
hash: u64,
parsed: &ParsedQuery,
validation: &ValidationResult,
) -> CachedQuery {
let columns = validation
.columns
.iter()
.map(|c| CachedColumn {
name: c.name.clone(),
pg_oid: c.pg_oid,
pg_type_name: c.pg_type_name.clone(),
is_nullable: c.is_nullable,
rust_type: c.rust_type.clone(),
})
.collect();
CachedQuery {
sql_hash: hash,
normalized_sql: parsed.normalized_sql.clone(),
columns,
param_pg_oids: validation.param_pg_oids.clone(),
param_is_pg_enum: validation.param_is_pg_enum.clone(),
}
}
fn validate_cached_type(rust_type: &str) -> Result<(), String> {
if syn::parse_str::<syn::Type>(rust_type).is_err() {
return Err(format!(
"offline cache contains invalid type syntax: `{rust_type}` \
— run `cargo build` with a live PostgreSQL connection to regenerate"
));
}
let inner = rust_type
.strip_prefix("Option<")
.and_then(|s| s.strip_suffix('>'))
.unwrap_or(rust_type);
let element = inner
.strip_prefix("Vec<")
.and_then(|s| s.strip_suffix('>'))
.unwrap_or(inner);
const KNOWN_TYPES: &[&str] = &[
"bool",
"i16",
"i32",
"i64",
"f32",
"f64",
"u32",
"()",
"String",
"Vec<u8>",
"Vec<bool>",
"Vec<i16>",
"Vec<i32>",
"Vec<i64>",
"Vec<f32>",
"Vec<f64>",
"Vec<String>",
"Vec<Vec<u8>>",
"::bsql_core::types::EnumString",
];
const KNOWN_PREFIXES: &[&str] = &["::time::", "::chrono::", "::uuid::", "::rust_decimal::"];
if KNOWN_TYPES.contains(&inner)
|| KNOWN_PREFIXES.iter().any(|p| inner.starts_with(p))
|| KNOWN_PREFIXES.iter().any(|p| element.starts_with(p))
{
return Ok(());
}
Err(format!(
"offline cache contains unexpected type: `{rust_type}` \
— run `cargo build` with a live PostgreSQL connection to regenerate"
))
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::TempDir;
fn sample_cached_query() -> CachedQuery {
CachedQuery {
sql_hash: 0xDEAD_BEEF_CAFE_1234,
normalized_sql: "select id, name from users where id = $1".into(),
columns: vec![
CachedColumn {
name: "id".into(),
pg_oid: 23,
pg_type_name: "int4".into(),
is_nullable: false,
rust_type: "i32".into(),
},
CachedColumn {
name: "name".into(),
pg_oid: 25,
pg_type_name: "text".into(),
is_nullable: true,
rust_type: "Option<String>".into(),
},
],
param_pg_oids: vec![23],
param_is_pg_enum: vec![false],
}
}
fn encode_enveloped(cached: &CachedQuery) -> Vec<u8> {
let inner = bitcode::encode(cached);
let envelope = CacheEnvelope {
version: CACHE_FORMAT_VERSION,
data: inner,
};
bitcode::encode(&envelope)
}
fn decode_enveloped(bytes: &[u8]) -> Result<CachedQuery, String> {
let envelope: CacheEnvelope =
bitcode::decode(bytes).map_err(|e| format!("envelope: {e}"))?;
if envelope.version != CACHE_FORMAT_VERSION {
return Err(format!(
"version mismatch: got {}, expected {}",
envelope.version, CACHE_FORMAT_VERSION
));
}
bitcode::decode(&envelope.data).map_err(|e| format!("inner: {e}"))
}
#[test]
fn envelope_round_trip() {
let original = sample_cached_query();
let bytes = encode_enveloped(&original);
let decoded = decode_enveloped(&bytes).expect("decode failed");
assert_eq!(decoded.sql_hash, original.sql_hash);
assert_eq!(decoded.normalized_sql, original.normalized_sql);
assert_eq!(decoded.columns.len(), original.columns.len());
assert_eq!(decoded.param_pg_oids, original.param_pg_oids);
assert_eq!(decoded.param_is_pg_enum, original.param_is_pg_enum);
for (d, o) in decoded.columns.iter().zip(&original.columns) {
assert_eq!(d.name, o.name);
assert_eq!(d.pg_oid, o.pg_oid);
assert_eq!(d.pg_type_name, o.pg_type_name);
assert_eq!(d.is_nullable, o.is_nullable);
assert_eq!(d.rust_type, o.rust_type);
}
}
#[test]
fn format_version_mismatch_returns_clear_error() {
let cached = sample_cached_query();
let inner = bitcode::encode(&cached);
let envelope = CacheEnvelope {
version: 99, data: inner,
};
let bytes = bitcode::encode(&envelope);
let err = decode_enveloped(&bytes).unwrap_err();
assert!(
err.contains("version mismatch"),
"error should mention version: {err}"
);
}
#[test]
fn cached_to_validation_preserves_all_fields() {
let cached = sample_cached_query();
let validation = cached_to_validation(&cached);
assert_eq!(validation.columns.len(), 2);
assert_eq!(validation.columns[0].name, "id");
assert_eq!(validation.columns[0].pg_oid, 23);
assert!(!validation.columns[0].is_nullable);
assert_eq!(validation.columns[0].rust_type, "i32");
assert_eq!(validation.columns[1].name, "name");
assert!(validation.columns[1].is_nullable);
assert_eq!(validation.columns[1].rust_type, "Option<String>");
assert_eq!(validation.param_pg_oids, vec![23]);
assert_eq!(validation.param_is_pg_enum, vec![false]);
}
#[test]
fn validation_to_cached_preserves_all_fields() {
let validation = ValidationResult {
columns: vec![ColumnInfo {
name: "count".into(),
pg_oid: 20,
pg_type_name: "int8".into(),
is_nullable: false,
rust_type: "i64".into(),
}],
param_pg_oids: vec![25, 23],
param_is_pg_enum: vec![false, false],
explain_plan: None,
};
let parsed = crate::parse::parse_query(
"SELECT COUNT(*) AS count FROM users WHERE name = $name: &str AND id = $id: i32",
)
.expect("parse failed");
let hash = sql_hash(&parsed.normalized_sql);
let cached = validation_to_cached(hash, &parsed, &validation);
assert_eq!(cached.sql_hash, hash);
assert_eq!(cached.normalized_sql, parsed.normalized_sql);
assert_eq!(cached.columns.len(), 1);
assert_eq!(cached.columns[0].name, "count");
assert_eq!(cached.columns[0].pg_oid, 20);
assert_eq!(cached.columns[0].rust_type, "i64");
assert_eq!(cached.param_pg_oids, vec![25, 23]);
}
#[test]
fn sql_hash_deterministic() {
let h1 = sql_hash("select id from users where id = $1");
let h2 = sql_hash("select id from users where id = $1");
assert_eq!(h1, h2);
}
#[test]
fn sql_hash_different_for_different_sql() {
let h1 = sql_hash("select id from users where id = $1");
let h2 = sql_hash("select name from users where id = $1");
assert_ne!(h1, h2);
}
#[test]
fn write_and_read_enveloped_cache_file() {
let tmp = TempDir::new().expect("tempdir");
let queries_dir = tmp.path().join("queries");
std::fs::create_dir_all(&queries_dir).expect("mkdir");
let cached = sample_cached_query();
let bytes = encode_enveloped(&cached);
let path = queries_dir.join(format!("{:016x}.bitcode", cached.sql_hash));
std::fs::write(&path, &bytes).expect("write");
let read_bytes = std::fs::read(&path).expect("read");
let decoded = decode_enveloped(&read_bytes).expect("decode");
assert_eq!(decoded.sql_hash, cached.sql_hash);
assert_eq!(decoded.normalized_sql, cached.normalized_sql);
}
#[test]
fn corrupted_cache_file_returns_error() {
let tmp = TempDir::new().expect("tempdir");
let queries_dir = tmp.path().join("queries");
std::fs::create_dir_all(&queries_dir).expect("mkdir");
let path = queries_dir.join("deadbeefcafe1234.bitcode");
let mut f = std::fs::File::create(&path).expect("create");
f.write_all(b"this is not bitcode").expect("write");
let read_bytes = std::fs::read(&path).expect("read");
let result = bitcode::decode::<CacheEnvelope>(&read_bytes);
assert!(result.is_err(), "corrupted file should fail to decode");
}
#[test]
fn empty_cache_file_returns_error() {
let tmp = TempDir::new().expect("tempdir");
let queries_dir = tmp.path().join("queries");
std::fs::create_dir_all(&queries_dir).expect("mkdir");
let path = queries_dir.join("0000000000000000.bitcode");
std::fs::write(&path, b"").expect("write");
let read_bytes = std::fs::read(&path).expect("read");
let result = bitcode::decode::<CacheEnvelope>(&read_bytes);
assert!(result.is_err(), "empty file should fail to decode");
}
#[test]
fn is_offline_default_false() {
let _ = is_offline();
}
#[test]
fn cached_query_with_no_columns_round_trips() {
let cached = CachedQuery {
sql_hash: 123,
normalized_sql: "delete from users where id = $1".into(),
columns: vec![],
param_pg_oids: vec![23],
param_is_pg_enum: vec![false],
};
let bytes = encode_enveloped(&cached);
let decoded = decode_enveloped(&bytes).expect("decode");
assert!(decoded.columns.is_empty());
assert_eq!(decoded.param_pg_oids, vec![23]);
}
#[test]
fn cached_query_with_pg_enum_round_trips() {
let cached = CachedQuery {
sql_hash: 456,
normalized_sql: "select status from tickets where status = $1".into(),
columns: vec![CachedColumn {
name: "status".into(),
pg_oid: 99999, pg_type_name: "ticket_status".into(),
is_nullable: false,
rust_type: "::bsql_core::types::EnumString".into(),
}],
param_pg_oids: vec![99999],
param_is_pg_enum: vec![true],
};
let bytes = encode_enveloped(&cached);
let decoded = decode_enveloped(&bytes).expect("decode");
assert_eq!(decoded.param_is_pg_enum, vec![true]);
assert_eq!(decoded.columns[0].pg_type_name, "ticket_status");
}
#[test]
fn raw_cached_query_without_envelope_fails() {
let cached = sample_cached_query();
let raw_bytes = bitcode::encode(&cached);
match bitcode::decode::<CacheEnvelope>(&raw_bytes) {
Err(_) => {} Ok(env) => assert_ne!(
env.version, CACHE_FORMAT_VERSION,
"raw CachedQuery bytes must not accidentally pass version check"
),
}
}
#[test]
fn validate_known_base_types() {
assert!(validate_cached_type("i32").is_ok());
assert!(validate_cached_type("i64").is_ok());
assert!(validate_cached_type("String").is_ok());
assert!(validate_cached_type("bool").is_ok());
assert!(validate_cached_type("f64").is_ok());
assert!(validate_cached_type("u32").is_ok());
assert!(validate_cached_type("()").is_ok());
assert!(validate_cached_type("Vec<u8>").is_ok());
assert!(validate_cached_type("Vec<String>").is_ok());
assert!(validate_cached_type("Vec<Vec<u8>>").is_ok());
}
#[test]
fn validate_option_wrapped_types() {
assert!(validate_cached_type("Option<i32>").is_ok());
assert!(validate_cached_type("Option<String>").is_ok());
assert!(validate_cached_type("Option<Vec<u8>>").is_ok());
}
#[test]
fn validate_enum_string() {
assert!(validate_cached_type("::bsql_core::types::EnumString").is_ok());
assert!(validate_cached_type("Option<::bsql_core::types::EnumString>").is_ok());
}
#[test]
fn validate_feature_gated_types() {
assert!(validate_cached_type("::time::OffsetDateTime").is_ok());
assert!(validate_cached_type("::time::PrimitiveDateTime").is_ok());
assert!(validate_cached_type("::time::Date").is_ok());
assert!(validate_cached_type("::time::Time").is_ok());
assert!(validate_cached_type("::chrono::DateTime<::chrono::Utc>").is_ok());
assert!(validate_cached_type("::chrono::NaiveDateTime").is_ok());
assert!(validate_cached_type("::uuid::Uuid").is_ok());
assert!(validate_cached_type("::rust_decimal::Decimal").is_ok());
}
#[test]
fn validate_feature_gated_option_types() {
assert!(validate_cached_type("Option<::time::OffsetDateTime>").is_ok());
assert!(validate_cached_type("Option<::uuid::Uuid>").is_ok());
assert!(validate_cached_type("Option<::rust_decimal::Decimal>").is_ok());
}
#[test]
fn validate_feature_gated_vec_types() {
assert!(validate_cached_type("Vec<::time::OffsetDateTime>").is_ok());
assert!(validate_cached_type("Vec<::chrono::NaiveDate>").is_ok());
assert!(validate_cached_type("Vec<::uuid::Uuid>").is_ok());
assert!(validate_cached_type("Vec<::rust_decimal::Decimal>").is_ok());
}
#[test]
fn validate_option_vec_feature_gated_types() {
assert!(validate_cached_type("Option<Vec<::time::Date>>").is_ok());
assert!(validate_cached_type("Option<Vec<::uuid::Uuid>>").is_ok());
}
#[test]
fn validate_rejects_unknown_types() {
let err = validate_cached_type("std::process::Command").unwrap_err();
assert!(err.contains("unexpected type"), "error: {err}");
let err = validate_cached_type("SomeUserType").unwrap_err();
assert!(err.contains("unexpected type"), "error: {err}");
}
#[test]
fn validate_rejects_invalid_syntax() {
let err = validate_cached_type("not a type!!").unwrap_err();
assert!(err.contains("invalid type syntax"), "error: {err}");
}
#[test]
fn validate_rejects_injected_code() {
let err = validate_cached_type("std::fs::File").unwrap_err();
assert!(err.contains("unexpected type"), "error: {err}");
}
#[test]
fn temp_filename_includes_pid() {
let pid = std::process::id();
let hash: u64 = 0xCAFE;
let tmp_name = format!("{hash:016x}.{pid}.bitcode.tmp");
assert!(
tmp_name.contains(&pid.to_string()),
"temp filename should include PID: {tmp_name}"
);
assert!(tmp_name.ends_with(".bitcode.tmp"));
assert!(tmp_name.starts_with("000000000000cafe."));
}
#[test]
fn walk_up_finds_existing_bsql_dir() {
let tmp = TempDir::new().expect("tempdir");
let bsql_dir = tmp.path().join(".bsql");
std::fs::create_dir_all(&bsql_dir).expect("mkdir");
let sub_crate = tmp.path().join("crates").join("mylib");
std::fs::create_dir_all(&sub_crate).expect("mkdir");
let mut search = sub_crate.clone();
let mut found = None;
loop {
let candidate = search.join(".bsql");
if candidate.is_dir() {
found = Some(candidate.join("queries"));
break;
}
if !search.pop() {
break;
}
}
assert_eq!(
found,
Some(bsql_dir.join("queries")),
"walk should find .bsql at workspace root"
);
}
#[test]
fn walk_up_falls_back_to_start_dir() {
let tmp = TempDir::new().expect("tempdir");
let start = tmp.path().join("projects").join("mylib");
std::fs::create_dir_all(&start).expect("mkdir");
let mut search = start.clone();
let mut found = None;
loop {
let candidate = search.join(".bsql");
if candidate.is_dir() {
found = Some(candidate.join("queries"));
break;
}
if !search.pop() {
break;
}
}
assert!(found.is_none(), "no .bsql/ should be found");
let fallback = start.join(".bsql").join("queries");
assert!(fallback.to_str().unwrap().contains("mylib"));
}
#[test]
fn write_cache_and_lookup_roundtrip() {
let tmp = TempDir::new().expect("tempdir");
let queries_dir = tmp.path().join(".bsql").join("queries");
std::fs::create_dir_all(&queries_dir).expect("mkdir");
let cached = sample_cached_query();
let bytes = encode_enveloped(&cached);
let path = queries_dir.join(format!("{:016x}.bitcode", cached.sql_hash));
std::fs::write(&path, &bytes).expect("write");
let read_bytes = std::fs::read(&path).expect("read");
let envelope: CacheEnvelope = bitcode::decode(&read_bytes).expect("envelope decode");
assert_eq!(envelope.version, CACHE_FORMAT_VERSION);
let decoded: CachedQuery = bitcode::decode(&envelope.data).expect("inner decode");
assert_eq!(decoded.sql_hash, cached.sql_hash);
assert_eq!(decoded.normalized_sql, cached.normalized_sql);
assert_eq!(decoded.columns.len(), cached.columns.len());
assert_eq!(decoded.param_pg_oids, cached.param_pg_oids);
for col in &decoded.columns {
validate_cached_type(&col.rust_type).expect("type validation failed");
}
}
#[test]
fn sql_collision_guard_rejects_mismatched_sql() {
let cached = CachedQuery {
sql_hash: 999,
normalized_sql: "select a from t".into(),
columns: vec![],
param_pg_oids: vec![],
param_is_pg_enum: vec![],
};
let other_sql = "select b from t";
assert_ne!(cached.normalized_sql, other_sql);
}
}