use tokio::runtime::Runtime;
use tokio_postgres::Client;
use crate::dynamic::QueryVariant;
use crate::parse::ParsedQuery;
#[derive(Debug, Clone)]
pub struct ColumnInfo {
pub name: String,
#[allow(dead_code)] pub pg_oid: u32,
#[allow(dead_code)] pub pg_type_name: String,
#[allow(dead_code)] pub is_nullable: bool,
pub rust_type: String,
}
#[derive(Debug, Clone)]
pub struct ValidationResult {
pub columns: Vec<ColumnInfo>,
pub param_pg_oids: Vec<u32>,
pub param_is_pg_enum: Vec<bool>,
pub explain_plan: Option<String>,
}
pub fn validate_query(
parsed: &ParsedQuery,
rt: &Runtime,
client: &Client,
) -> Result<ValidationResult, String> {
rt.block_on(validate_async(parsed, client))
}
async fn validate_async(parsed: &ParsedQuery, client: &Client) -> Result<ValidationResult, String> {
let stmt = client
.prepare(&parsed.positional_sql)
.await
.map_err(|e| format_pg_error(&e, parsed))?;
let param_pg_oids: Vec<u32> = stmt.params().iter().map(|t| t.oid()).collect();
let param_is_pg_enum: Vec<bool> = stmt
.params()
.iter()
.map(|t| matches!(t.kind(), postgres_types::Kind::Enum(_)))
.collect();
let nullable_flags = resolve_nullability_batch(client, stmt.columns()).await;
let mut columns = Vec::with_capacity(stmt.columns().len());
for (i, col) in stmt.columns().iter().enumerate() {
let pg_oid = col.type_().oid();
let pg_type_name = col.type_().name().to_owned();
let name = col.name().to_owned();
let is_nullable = nullable_flags[i];
let is_pg_enum = matches!(col.type_().kind(), postgres_types::Kind::Enum(_));
let base_rust_type = if is_pg_enum {
"::bsql_core::types::EnumString"
} else {
crate::types::resolve_rust_type(pg_oid)
.map_err(|msg| format!("column \"{name}\": {msg}"))?
};
let rust_type = if is_nullable {
format!("Option<{base_rust_type}>")
} else {
base_rust_type.to_owned()
};
columns.push(ColumnInfo {
name,
pg_oid,
pg_type_name,
is_nullable,
rust_type,
});
}
let explain_plan = fetch_explain_plan(client, parsed).await;
Ok(ValidationResult {
columns,
param_pg_oids,
param_is_pg_enum,
explain_plan,
})
}
#[cfg(feature = "explain")]
async fn fetch_explain_plan(client: &Client, parsed: &ParsedQuery) -> Option<String> {
let explain_sql = format!("EXPLAIN (FORMAT TEXT, COSTS) {}", parsed.positional_sql);
match client.simple_query(&explain_sql).await {
Ok(messages) => {
let lines: Vec<String> = messages
.iter()
.filter_map(|msg| {
if let tokio_postgres::SimpleQueryMessage::Row(row) = msg {
row.get(0).map(String::from)
} else {
None
}
})
.collect();
if lines.is_empty() {
None
} else {
Some(lines.join("\n"))
}
}
Err(_) => None,
}
}
#[cfg(not(feature = "explain"))]
async fn fetch_explain_plan(_client: &Client, _parsed: &ParsedQuery) -> Option<String> {
None
}
async fn resolve_nullability_batch(
client: &Client,
columns: &[tokio_postgres::Column],
) -> Vec<bool> {
let col_count = columns.len();
let mut result = vec![true; col_count];
let mut table_oids: Vec<u32> = Vec::new();
let mut col_nums: Vec<i16> = Vec::new();
let mut col_indices: Vec<usize> = Vec::new();
for (i, col) in columns.iter().enumerate() {
match (col.table_oid(), col.column_id()) {
(Some(t), Some(c)) if t != 0 && c != 0 => {
table_oids.push(t);
col_nums.push(c);
col_indices.push(i);
}
_ => {} }
}
if table_oids.is_empty() {
return result;
}
let query = "\
SELECT a.attrelid, a.attnum, NOT a.attnotnull \
FROM pg_attribute a \
WHERE (a.attrelid, a.attnum) IN (\
SELECT unnest($1::oid[]), unnest($2::int2[])\
)";
if let Ok(rows) = client.query(query, &[&table_oids, &col_nums]).await {
for row in &rows {
let oid: u32 = row.get(0);
let num: i16 = row.get(1);
let is_nullable: bool = row.get(2);
for (idx, (&t, &c)) in table_oids.iter().zip(col_nums.iter()).enumerate() {
if t == oid && c == num {
result[col_indices[idx]] = is_nullable;
}
}
}
}
result
}
pub fn check_param_types(
parsed: &ParsedQuery,
validation: &ValidationResult,
) -> Result<(), String> {
if parsed.params.len() != validation.param_pg_oids.len() {
return Err(format!(
"parameter count mismatch: query has {} parameters but PostgreSQL \
expects {}. Check your $name: Type declarations.",
parsed.params.len(),
validation.param_pg_oids.len()
));
}
for (i, (param, &pg_oid)) in parsed
.params
.iter()
.zip(&validation.param_pg_oids)
.enumerate()
{
let is_pg_enum = validation.param_is_pg_enum.get(i).copied().unwrap_or(false);
if is_pg_enum {
if matches!(param.rust_type.as_str(), "&str" | "String") {
continue;
}
if crate::types::is_known_non_enum_type(¶m.rust_type) {
return Err(format!(
"type `{}` cannot be used for PostgreSQL enum parameter `${}`. \
Use `&str`, `String`, or a `#[bsql::pg_enum]` type.",
param.rust_type, param.name
));
}
continue;
}
if !crate::types::is_param_compatible_extended(¶m.rust_type, pg_oid) {
let pg_name = bsql_core::types::pg_name_for_oid(pg_oid).unwrap_or("unknown");
let extra_hint = match crate::types::resolve_rust_type(pg_oid) {
Ok(expected) => format!(" (expected `{expected}`)"),
Err(msg) => format!(" — {msg}"),
};
return Err(format!(
"type mismatch for parameter `${}`: declared `{}` but PostgreSQL \
expects `{}` (OID {}){extra_hint}",
param.name, param.rust_type, pg_name, pg_oid
));
}
}
Ok(())
}
pub fn validate_variants(
variants: &[QueryVariant],
parsed: &ParsedQuery,
rt: &Runtime,
client: &Client,
) -> Result<ValidationResult, String> {
if variants.len() <= 1 {
return validate_query(parsed, rt, client);
}
let mut canonical_result: Option<ValidationResult> = None;
for (i, variant) in variants.iter().enumerate() {
let result = rt.block_on(validate_variant_async(variant, client, parsed, i))?;
check_variant_param_types(variant, &result)?;
if let Some(ref canonical) = canonical_result {
if result.columns.len() != canonical.columns.len() {
return Err(format!(
"variant {} (mask {:#06b}) returns {} columns, but variant 0 \
returns {} columns. Optional clauses must not change the SELECT list.",
i,
variant.mask,
result.columns.len(),
canonical.columns.len()
));
}
} else {
canonical_result = Some(result);
}
}
canonical_result.ok_or_else(|| "no variants to validate (internal error)".to_owned())
}
async fn validate_variant_async(
variant: &QueryVariant,
client: &Client,
parsed: &ParsedQuery,
variant_index: usize,
) -> Result<ValidationResult, String> {
let stmt = client
.prepare(&variant.sql)
.await
.map_err(|e| format_variant_error(&e, variant, parsed, variant_index))?;
let param_pg_oids: Vec<u32> = stmt.params().iter().map(|t| t.oid()).collect();
let param_is_pg_enum: Vec<bool> = stmt
.params()
.iter()
.map(|t| matches!(t.kind(), postgres_types::Kind::Enum(_)))
.collect();
let nullable_flags = resolve_nullability_batch(client, stmt.columns()).await;
let mut columns = Vec::with_capacity(stmt.columns().len());
for (i, col) in stmt.columns().iter().enumerate() {
let pg_oid = col.type_().oid();
let pg_type_name = col.type_().name().to_owned();
let name = col.name().to_owned();
let is_nullable = nullable_flags[i];
let is_pg_enum = matches!(col.type_().kind(), postgres_types::Kind::Enum(_));
let base_rust_type = if is_pg_enum {
"::bsql_core::types::EnumString"
} else {
crate::types::resolve_rust_type(pg_oid)
.map_err(|msg| format!("column \"{name}\": {msg}"))?
};
let rust_type = if is_nullable {
format!("Option<{base_rust_type}>")
} else {
base_rust_type.to_owned()
};
columns.push(ColumnInfo {
name,
pg_oid,
pg_type_name,
is_nullable,
rust_type,
});
}
Ok(ValidationResult {
columns,
param_pg_oids,
param_is_pg_enum,
explain_plan: None, })
}
pub fn check_variant_param_types(
variant: &QueryVariant,
validation: &ValidationResult,
) -> Result<(), String> {
if variant.params.len() != validation.param_pg_oids.len() {
return Err(format!(
"parameter count mismatch in variant (mask {:#06b}): query has {} \
parameters but PostgreSQL expects {}.",
variant.mask,
variant.params.len(),
validation.param_pg_oids.len()
));
}
for (i, (param, &pg_oid)) in variant
.params
.iter()
.zip(&validation.param_pg_oids)
.enumerate()
{
let is_pg_enum = validation.param_is_pg_enum.get(i).copied().unwrap_or(false);
let check_type = strip_option(¶m.rust_type);
if is_pg_enum {
if matches!(check_type, "&str" | "String") {
continue;
}
if crate::types::is_known_non_enum_type(check_type) {
return Err(format!(
"type `{}` cannot be used for PostgreSQL enum parameter `${}`. \
Use `&str`, `String`, or a `#[bsql::pg_enum]` type.",
param.rust_type, param.name
));
}
continue;
}
if !crate::types::is_param_compatible_extended(check_type, pg_oid) {
let pg_name = bsql_core::types::pg_name_for_oid(pg_oid).unwrap_or("unknown");
let extra_hint = match crate::types::resolve_rust_type(pg_oid) {
Ok(expected) => format!(" (expected `{expected}`)"),
Err(msg) => format!(" — {msg}"),
};
return Err(format!(
"type mismatch for parameter `${}`: declared `{}` but PostgreSQL \
expects `{}` (OID {}){extra_hint}",
param.name, param.rust_type, pg_name, pg_oid
));
}
}
Ok(())
}
fn strip_option(ty: &str) -> &str {
if let Some(inner) = ty.strip_prefix("Option<") {
if let Some(inner) = inner.strip_suffix('>') {
return inner;
}
}
ty
}
fn format_variant_error(
e: &tokio_postgres::Error,
variant: &QueryVariant,
parsed: &ParsedQuery,
variant_index: usize,
) -> String {
let n = parsed.optional_clauses.len();
let included: Vec<usize> = (0..n).filter(|&i| (variant.mask & (1 << i)) != 0).collect();
let clause_desc = if included.is_empty() {
"no optional clauses included".to_owned()
} else {
let clause_strs: Vec<String> = included
.iter()
.map(|&i| {
format!(
"clause {} `[{}]`",
i, parsed.optional_clauses[i].sql_fragment
)
})
.collect();
format!("with {}", clause_strs.join(", "))
};
let base_msg = if let Some(db_err) = e.as_db_error() {
let detail = db_err.detail().unwrap_or("");
let hint = db_err.hint().unwrap_or("");
let mut out = format!("PostgreSQL error: {}", db_err.message());
if !detail.is_empty() {
out.push_str(&format!("\n detail: {detail}"));
}
if !hint.is_empty() {
out.push_str(&format!("\n hint: {hint}"));
}
out
} else {
format!("PostgreSQL error: {e}")
};
format!(
"optional clause variant {} ({clause_desc}) produces invalid SQL:\n \
{base_msg}\n SQL: {}",
variant_index, variant.sql
)
}
fn format_pg_error(e: &tokio_postgres::Error, parsed: &ParsedQuery) -> String {
let msg = e.to_string();
if let Some(db_err) = e.as_db_error() {
let detail = db_err.detail().unwrap_or("");
let hint = db_err.hint().unwrap_or("");
let position = db_err.position();
let mut out = format!("PostgreSQL error: {}", db_err.message());
if !detail.is_empty() {
out.push_str(&format!("\n detail: {detail}"));
}
if !hint.is_empty() {
out.push_str(&format!("\n hint: {hint}"));
}
if let Some(pos) = position {
out.push_str(&format!("\n position: {pos:?}"));
out.push_str(&format!("\n SQL: {}", parsed.positional_sql));
}
out
} else {
format!("PostgreSQL error: {msg}\n SQL: {}", parsed.positional_sql)
}
}