use smallvec::SmallVec;
use bsql_driver_postgres::{ColumnDesc, Connection, DriverError};
use crate::dynamic::QueryVariant;
use crate::parse::ParsedQuery;
#[derive(Debug, Clone)]
pub struct ColumnInfo {
pub name: String,
pub pg_oid: u32,
pub pg_type_name: String,
pub is_nullable: bool,
pub rust_type: String,
}
#[derive(Debug, Clone)]
pub struct ValidationResult {
pub columns: Vec<ColumnInfo>,
pub param_pg_oids: SmallVec<[u32; 8]>,
pub param_is_pg_enum: SmallVec<[bool; 8]>,
#[cfg(feature = "explain")]
pub explain_plan: Option<String>,
}
pub fn validate_query(
parsed: &ParsedQuery,
conn: &mut Connection,
) -> Result<ValidationResult, String> {
let result = conn
.prepare_describe(&parsed.positional_sql)
.map_err(|e| format_driver_error(&e, parsed))?;
let param_pg_oids: SmallVec<[u32; 8]> = result.param_oids.iter().copied().collect();
let param_is_pg_enum = detect_pg_enums(conn, &result.param_oids);
let columns = build_columns(conn, &result.columns)?;
Ok(ValidationResult {
columns,
param_pg_oids,
param_is_pg_enum,
#[cfg(feature = "explain")]
explain_plan: fetch_explain_plan(conn, parsed),
})
}
fn build_columns(
conn: &mut Connection,
pg_columns: &[ColumnDesc],
) -> Result<Vec<ColumnInfo>, String> {
let nullable_flags = resolve_nullability_batch(conn, pg_columns);
let enum_flags = detect_column_enums(conn, pg_columns);
let mut columns = Vec::with_capacity(pg_columns.len());
for (i, col) in pg_columns.iter().enumerate() {
let pg_oid = col.type_oid;
let pg_type_name = bsql_core::types::pg_name_for_oid(pg_oid)
.unwrap_or("unknown")
.to_owned();
let name = col.name.to_string();
let is_nullable = nullable_flags[i];
if enum_flags[i] {
return Err(format!(
"column \"{name}\" is PostgreSQL enum type `{pg_type_name}`. \
Define a Rust enum with #[bsql::pg_enum] or cast to text: {name}::text"
));
}
let base_rust_type = crate::types::resolve_rust_type(pg_oid)
.map_err(|msg| format!("column \"{name}\": {msg}"))?;
let rust_type = if is_nullable {
format!("Option<{base_rust_type}>")
} else {
base_rust_type.to_owned()
};
columns.push(ColumnInfo {
name,
pg_oid,
pg_type_name,
is_nullable,
rust_type,
});
}
Ok(columns)
}
#[cfg(feature = "explain")]
fn fetch_explain_plan(conn: &mut Connection, parsed: &ParsedQuery) -> Option<String> {
let explain_sql = format!("EXPLAIN (FORMAT TEXT, COSTS) {}", parsed.positional_sql);
match conn.simple_query_rows(&explain_sql) {
Ok(rows) => {
let lines: Vec<String> = rows
.into_iter()
.filter_map(|row| row.into_iter().next().flatten())
.collect();
if lines.is_empty() {
None
} else {
let plan_text = lines.join("\n");
let threshold = crate::explain::explain_threshold();
let warnings = crate::explain::analyze_plan(&plan_text, threshold);
for warning in &warnings {
eprintln!("warning: [bsql] {}", warning.message);
}
Some(plan_text)
}
}
Err(_) => None,
}
}
fn resolve_nullability_batch(conn: &mut Connection, columns: &[ColumnDesc]) -> Vec<bool> {
let col_count = columns.len();
let mut result = vec![true; col_count];
let mut table_oids: Vec<u32> = Vec::new();
let mut col_nums: Vec<i16> = Vec::new();
let mut col_indices: Vec<usize> = Vec::new();
for (i, col) in columns.iter().enumerate() {
if col.table_oid != 0 && col.column_id != 0 {
table_oids.push(col.table_oid);
col_nums.push(col.column_id);
col_indices.push(i);
}
}
if table_oids.is_empty() {
return result;
}
let oid_array = format!(
"ARRAY[{}]::oid[]",
table_oids
.iter()
.map(|o| o.to_string())
.collect::<Vec<_>>()
.join(",")
);
let num_array = format!(
"ARRAY[{}]::int2[]",
col_nums
.iter()
.map(|n| n.to_string())
.collect::<Vec<_>>()
.join(",")
);
let query = format!(
"SELECT a.attrelid, a.attnum, NOT a.attnotnull \
FROM pg_attribute a \
WHERE (a.attrelid, a.attnum) IN (\
SELECT unnest({oid_array}), unnest({num_array})\
)"
);
if let Ok(rows) = conn.simple_query_rows(&query) {
let mut lookup: std::collections::HashMap<(u32, i16), Vec<usize>> =
std::collections::HashMap::with_capacity(table_oids.len());
for (idx, (&t, &c)) in table_oids.iter().zip(col_nums.iter()).enumerate() {
lookup.entry((t, c)).or_default().push(col_indices[idx]);
}
for row in &rows {
let oid: u32 = row
.first()
.and_then(|v| v.as_deref())
.and_then(|s| s.parse().ok())
.unwrap_or(0);
let num: i16 = row
.get(1)
.and_then(|v| v.as_deref())
.and_then(|s| s.parse().ok())
.unwrap_or(0);
let is_nullable: bool = row
.get(2)
.and_then(|v| v.as_deref())
.map(|s| s == "t" || s == "true")
.unwrap_or(true);
if let Some(indices) = lookup.get(&(oid, num)) {
for &idx in indices {
result[idx] = is_nullable;
}
}
}
}
result
}
fn detect_pg_enums(conn: &mut Connection, oids: &[u32]) -> SmallVec<[bool; 8]> {
if oids.is_empty() {
return SmallVec::new();
}
let oid_list = oids
.iter()
.map(|o| o.to_string())
.collect::<Vec<_>>()
.join(",");
let query = format!("SELECT oid, typtype FROM pg_type WHERE oid IN ({oid_list})");
let mut enum_map: std::collections::HashMap<u32, bool> =
std::collections::HashMap::with_capacity(oids.len());
if let Ok(rows) = conn.simple_query_rows(&query) {
for row in &rows {
let oid: u32 = row
.first()
.and_then(|v| v.as_deref())
.and_then(|s| s.parse().ok())
.unwrap_or(0);
let typtype: &str = row.get(1).and_then(|v| v.as_deref()).unwrap_or("b");
enum_map.insert(oid, typtype == "e");
}
}
oids.iter()
.map(|oid| enum_map.get(oid).copied().unwrap_or(false))
.collect()
}
fn detect_column_enums(conn: &mut Connection, columns: &[ColumnDesc]) -> Vec<bool> {
let mut result = vec![false; columns.len()];
let custom_oids: Vec<(usize, u32)> = columns
.iter()
.enumerate()
.filter(|(_, c)| c.type_oid >= 10000)
.map(|(i, c)| (i, c.type_oid))
.collect();
if custom_oids.is_empty() {
return result;
}
let oid_list = custom_oids
.iter()
.map(|(_, o)| o.to_string())
.collect::<Vec<_>>()
.join(",");
let query = format!("SELECT oid, typtype FROM pg_type WHERE oid IN ({oid_list})");
if let Ok(rows) = conn.simple_query_rows(&query) {
let mut enum_set: std::collections::HashSet<u32> = std::collections::HashSet::new();
for row in &rows {
let oid: u32 = row
.first()
.and_then(|v| v.as_deref())
.and_then(|s| s.parse().ok())
.unwrap_or(0);
let typtype: &str = row.get(1).and_then(|v| v.as_deref()).unwrap_or("b");
if typtype == "e" {
enum_set.insert(oid);
}
}
for &(idx, oid) in &custom_oids {
if enum_set.contains(&oid) {
result[idx] = true;
}
}
}
result
}
pub fn check_param_types(
parsed: &ParsedQuery,
validation: &ValidationResult,
) -> Result<(), String> {
check_params_against_pg(
&parsed.params,
&validation.param_pg_oids,
&validation.param_is_pg_enum,
false,
"",
)
}
pub fn validate_variants(
variants: &[QueryVariant],
parsed: &ParsedQuery,
conn: &mut Connection,
) -> Result<ValidationResult, String> {
if variants.len() <= 1 {
return validate_query(parsed, conn);
}
let mut canonical_result: Option<ValidationResult> = None;
for (i, variant) in variants.iter().enumerate() {
let result = validate_variant(variant, conn, parsed, i)?;
check_variant_param_types(variant, &result)?;
if let Some(ref canonical) = canonical_result {
if result.columns.len() != canonical.columns.len() {
return Err(format!(
"variant {} (mask {:#06b}) returns {} columns, but variant 0 \
returns {} columns. Optional clauses must not change the SELECT list.",
i,
variant.mask,
result.columns.len(),
canonical.columns.len()
));
}
} else {
canonical_result = Some(result);
}
}
canonical_result.ok_or_else(|| "no variants to validate (internal error)".to_owned())
}
fn validate_variant(
variant: &QueryVariant,
conn: &mut Connection,
parsed: &ParsedQuery,
variant_index: usize,
) -> Result<ValidationResult, String> {
let result = conn
.prepare_describe(&variant.sql)
.map_err(|e| format_variant_driver_error(&e, variant, parsed, variant_index))?;
let param_pg_oids: SmallVec<[u32; 8]> = result.param_oids.iter().copied().collect();
let param_is_pg_enum = detect_pg_enums(conn, &result.param_oids);
let columns = build_columns(conn, &result.columns)?;
Ok(ValidationResult {
columns,
param_pg_oids,
param_is_pg_enum,
#[cfg(feature = "explain")]
explain_plan: None,
})
}
pub fn check_variant_param_types(
variant: &QueryVariant,
validation: &ValidationResult,
) -> Result<(), String> {
check_params_against_pg(
&variant.params,
&validation.param_pg_oids,
&validation.param_is_pg_enum,
true,
&format!("variant (mask {:#06b})", variant.mask),
)
}
fn check_params_against_pg(
params: &[crate::parse::Param],
pg_oids: &[u32],
pg_enum_flags: &[bool],
strip_option_wrapper: bool,
context: &str,
) -> Result<(), String> {
if params.len() != pg_oids.len() {
let ctx = if context.is_empty() {
String::new()
} else {
format!(" in {context}")
};
return Err(format!(
"parameter count mismatch{ctx}: query has {} parameters but PostgreSQL \
expects {}. Check your $name: Type declarations.",
params.len(),
pg_oids.len()
));
}
for (i, (param, &pg_oid)) in params.iter().zip(pg_oids).enumerate() {
let is_pg_enum = pg_enum_flags.get(i).copied().unwrap_or(false);
let check_type = if strip_option_wrapper {
strip_option(¶m.rust_type)
} else {
¶m.rust_type
};
if is_pg_enum {
if matches!(check_type, "&str" | "String") {
continue;
}
if crate::types::is_known_non_enum_type(check_type) {
return Err(format!(
"type `{}` cannot be used for PostgreSQL enum parameter `${}`. \
Use `&str`, `String`, or a `#[bsql::pg_enum]` type.",
param.rust_type, param.name
));
}
continue;
}
if !crate::types::is_param_compatible_extended(check_type, pg_oid) {
let pg_name = bsql_core::types::pg_name_for_oid(pg_oid).unwrap_or("unknown");
let extra_hint = match crate::types::resolve_rust_type(pg_oid) {
Ok(expected) => format!(" (expected `{expected}`)"),
Err(msg) => format!(" — {msg}"),
};
return Err(format!(
"type mismatch for parameter `${}`: declared `{}` but PostgreSQL \
expects `{}` (OID {}){extra_hint}",
param.name, param.rust_type, pg_name, pg_oid
));
}
}
Ok(())
}
fn strip_option(ty: &str) -> &str {
if let Some(inner) = ty.strip_prefix("Option<") {
if let Some(inner) = inner.strip_suffix('>') {
return inner;
}
}
ty
}
fn format_driver_error_base(e: &DriverError) -> String {
match e {
DriverError::Server {
message,
detail,
hint,
position,
..
} => {
let mut out = format!("PostgreSQL error: {message}");
if let Some(pos) = position {
out.push_str(&format!(" (at position {pos})"));
}
if let Some(d) = detail {
out.push_str(&format!("\n detail: {d}"));
}
if let Some(h) = hint {
out.push_str(&format!("\n hint: {h}"));
}
out
}
other => format!("PostgreSQL error: {other}"),
}
}
fn format_variant_driver_error(
e: &DriverError,
variant: &QueryVariant,
parsed: &ParsedQuery,
variant_index: usize,
) -> String {
let n = parsed.optional_clauses.len();
let included: Vec<usize> = (0..n).filter(|&i| (variant.mask & (1 << i)) != 0).collect();
let clause_desc = if included.is_empty() {
"no optional clauses included".to_owned()
} else {
let clause_strs: Vec<String> = included
.iter()
.map(|&i| {
format!(
"clause {} `[{}]`",
i, parsed.optional_clauses[i].sql_fragment
)
})
.collect();
format!("with {}", clause_strs.join(", "))
};
let base_msg = format_driver_error_base(e);
format!(
"optional clause variant {} ({clause_desc}) produces invalid SQL:\n \
{base_msg}\n SQL: {}",
variant_index, variant.sql
)
}
fn format_driver_error(e: &DriverError, parsed: &ParsedQuery) -> String {
let mut out = format_driver_error_base(e);
out.push_str(&format!("\n SQL: {}", parsed.positional_sql));
if let DriverError::Server {
position: Some(pos),
..
} = e
{
let col = (*pos as usize).saturating_sub(1); let prefix_len = " SQL: ".len();
let marker = format!("\n{}{}", " ".repeat(prefix_len + col), "^");
out.push_str(&marker);
}
out
}
pub fn validate_query_with_suggestions(
parsed: &ParsedQuery,
conn: &mut Connection,
) -> Result<ValidationResult, String> {
match validate_query(parsed, conn) {
Ok(result) => Ok(result),
Err(base_error) => {
if let Some(suggestion) = crate::suggest::enhance_error(&base_error, conn) {
Err(format!("{base_error}{suggestion}"))
} else {
Err(base_error)
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::parse::Param;
#[test]
fn strip_option_wraps_i32() {
assert_eq!(strip_option("Option<i32>"), "i32");
}
#[test]
fn strip_option_no_change_plain_type() {
assert_eq!(strip_option("i32"), "i32");
}
#[test]
fn strip_option_nested() {
assert_eq!(strip_option("Option<Option<i32>>"), "Option<i32>");
}
#[test]
fn strip_option_with_str() {
assert_eq!(strip_option("Option<&str>"), "&str");
}
#[test]
fn strip_option_with_string() {
assert_eq!(strip_option("Option<String>"), "String");
}
#[test]
fn strip_option_with_whitespace_strips_outer() {
assert_eq!(strip_option("Option< i32 >"), " i32 ");
}
#[test]
fn strip_option_empty_string() {
assert_eq!(strip_option(""), "");
}
#[test]
fn strip_option_prefix_only() {
assert_eq!(strip_option("Option<i32"), "Option<i32");
}
#[test]
fn format_server_error_basic() {
let err = DriverError::Server {
code: *b"42P01",
message: "relation \"users\" does not exist".into(),
detail: None,
hint: None,
position: None,
};
let msg = format_driver_error_base(&err);
assert!(msg.contains("relation \"users\" does not exist"));
assert!(msg.starts_with("PostgreSQL error:"));
}
#[test]
fn format_server_error_with_detail_and_hint() {
let err = DriverError::Server {
code: *b"42P01",
message: "something went wrong".into(),
detail: Some("extra detail here".into()),
hint: Some("try this instead".into()),
position: None,
};
let msg = format_driver_error_base(&err);
assert!(msg.contains("something went wrong"));
assert!(msg.contains("detail: extra detail here"));
assert!(msg.contains("hint: try this instead"));
}
#[test]
fn format_server_error_with_position() {
let err = DriverError::Server {
code: *b"42601",
message: "syntax error".into(),
detail: None,
hint: None,
position: Some(15),
};
let msg = format_driver_error_base(&err);
assert!(msg.contains("at position 15"));
}
#[test]
fn format_non_server_error() {
let err = DriverError::Pool("connection lost".into());
let msg = format_driver_error_base(&err);
assert!(msg.contains("PostgreSQL error:"));
assert!(msg.contains("connection lost"));
}
#[test]
fn format_driver_error_includes_sql() {
let err = DriverError::Server {
code: *b"42P01",
message: "relation does not exist".into(),
detail: None,
hint: None,
position: None,
};
let parsed = crate::parse::parse_query("SELECT id FROM users WHERE id = $id: i32").unwrap();
let msg = format_driver_error(&err, &parsed);
assert!(msg.contains("SQL:"), "should include SQL in error: {msg}");
assert!(msg.contains("$1"), "should include positional SQL: {msg}");
}
#[test]
fn format_driver_error_includes_position_marker() {
let err = DriverError::Server {
code: *b"42601",
message: "syntax error".into(),
detail: None,
hint: None,
position: Some(8),
};
let parsed = crate::parse::parse_query("SELECT id FROM users WHERE id = $id: i32").unwrap();
let msg = format_driver_error(&err, &parsed);
assert!(msg.contains('^'), "should include position marker: {msg}");
}
#[test]
fn check_params_count_mismatch() {
let params = vec![Param {
name: "id".into(),
rust_type: "i32".into(),
position: 1,
}];
let pg_oids = [23u32, 25u32]; let pg_enum = [false, false];
let result = check_params_against_pg(¶ms, &pg_oids, &pg_enum, false, "");
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.contains("parameter count mismatch"), "error: {err}");
}
#[test]
fn check_params_count_mismatch_with_context() {
let params = vec![];
let pg_oids = [23u32];
let pg_enum = [false];
let result =
check_params_against_pg(¶ms, &pg_oids, &pg_enum, false, "variant (mask 0b0011)");
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.contains("variant (mask 0b0011)"),
"should include context: {err}"
);
}
#[test]
fn check_params_type_mismatch() {
let params = vec![Param {
name: "id".into(),
rust_type: "&str".into(), position: 1,
}];
let pg_oids = [23u32]; let pg_enum = [false];
let result = check_params_against_pg(¶ms, &pg_oids, &pg_enum, false, "");
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.contains("type mismatch"),
"should mention type mismatch: {err}"
);
}
#[test]
fn check_params_matching_types_ok() {
let params = vec![Param {
name: "id".into(),
rust_type: "i32".into(),
position: 1,
}];
let pg_oids = [23u32]; let pg_enum = [false];
let result = check_params_against_pg(¶ms, &pg_oids, &pg_enum, false, "");
assert!(result.is_ok());
}
#[test]
fn check_params_empty_ok() {
let params: Vec<Param> = vec![];
let pg_oids: [u32; 0] = [];
let pg_enum: [bool; 0] = [];
let result = check_params_against_pg(¶ms, &pg_oids, &pg_enum, false, "");
assert!(result.is_ok());
}
#[test]
fn check_params_enum_with_str_ok() {
let params = vec![Param {
name: "status".into(),
rust_type: "&str".into(),
position: 1,
}];
let pg_oids = [99999u32]; let pg_enum = [true];
let result = check_params_against_pg(¶ms, &pg_oids, &pg_enum, false, "");
assert!(result.is_ok(), "enum param with &str should be accepted");
}
#[test]
fn check_params_enum_with_string_ok() {
let params = vec![Param {
name: "status".into(),
rust_type: "String".into(),
position: 1,
}];
let pg_oids = [99999u32];
let pg_enum = [true];
let result = check_params_against_pg(¶ms, &pg_oids, &pg_enum, false, "");
assert!(result.is_ok(), "enum param with String should be accepted");
}
#[test]
fn check_params_enum_with_i32_error() {
let params = vec![Param {
name: "status".into(),
rust_type: "i32".into(),
position: 1,
}];
let pg_oids = [99999u32];
let pg_enum = [true];
let result = check_params_against_pg(¶ms, &pg_oids, &pg_enum, false, "");
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.contains("cannot be used for PostgreSQL enum"),
"should reject i32 for enum: {err}"
);
}
#[test]
fn check_params_enum_with_custom_type_ok() {
let params = vec![Param {
name: "status".into(),
rust_type: "MyStatusEnum".into(),
position: 1,
}];
let pg_oids = [99999u32];
let pg_enum = [true];
let result = check_params_against_pg(¶ms, &pg_oids, &pg_enum, false, "");
assert!(result.is_ok(), "custom enum type should be accepted");
}
#[test]
fn check_params_strip_option_in_variant_mode() {
let params = vec![Param {
name: "id".into(),
rust_type: "Option<i32>".into(),
position: 1,
}];
let pg_oids = [23u32]; let pg_enum = [false];
let result = check_params_against_pg(¶ms, &pg_oids, &pg_enum, true, "variant");
assert!(
result.is_ok(),
"Option<i32> stripped to i32 should match int4"
);
}
#[test]
fn check_params_strip_option_mismatch() {
let params = vec![Param {
name: "id".into(),
rust_type: "Option<&str>".into(),
position: 1,
}];
let pg_oids = [23u32]; let pg_enum = [false];
let result = check_params_against_pg(¶ms, &pg_oids, &pg_enum, true, "variant");
assert!(
result.is_err(),
"Option<&str> stripped to &str should not match int4"
);
}
}