use crate::db::get_or_init;
use crate::diagnostic::ValidationError;
use crate::dry_run::dry_run;
use crate::error_extract::{classify, ErrorClass};
use crate::registry::{self, Registry};
/// Validates `sql` against the registered struct `struct_name`.
///
/// The struct is looked up in the registry first; the query is then
/// dry-run against the shared compile-time database and the resulting
/// schema is checked to contain every field recorded for the struct.
///
/// # Errors
///
/// * [`ValidationError::StructNotRegistered`] when the registry has no entry
///   for `struct_name` (returned before the database is touched).
/// * Any error produced by the dry-run / seeding loop.
/// * [`ValidationError::MissingColumns`] when the query result lacks one or
///   more of the struct's fields.
pub fn validate_query_as(struct_name: &str, sql: &str) -> Result<(), ValidationError> {
    let entry = match registry::get_by_struct(struct_name) {
        Some((_table_name, entry)) => entry,
        None => {
            return Err(ValidationError::StructNotRegistered {
                struct_name: struct_name.to_owned(),
            })
        }
    };

    // Hold the database guard only for the dry-run itself; it is released
    // before the (database-free) name check runs.
    let mut guard = get_or_init().lock();
    let outcome = run_dry_run_with_seed(sql, &mut guard);
    drop(guard);

    let schema = outcome?;
    finish_name_check(struct_name, &entry.fields, &schema)
}
/// Validates that `sql` dry-runs successfully and projects exactly one column.
///
/// # Errors
///
/// * Any error produced by the dry-run / seeding loop.
/// * [`ValidationError::HyperError`] when the projected column count is not 1.
pub fn validate_scalar_sql(sql: &str) -> Result<(), ValidationError> {
    // The guard is released as soon as the dry-run finishes, whether it
    // succeeded or not.
    let mut guard = get_or_init().lock();
    let outcome = run_dry_run_with_seed(sql, &mut guard);
    drop(guard);

    let schema = outcome?;
    let col_count = schema.column_count();
    if col_count == 1 {
        return Ok(());
    }

    Err(ValidationError::HyperError {
        message: format!(
            "query_scalar! requires exactly one projected column, but the query projects {col_count}"
        ),
    })
}
/// Dry-runs `sql` against `db`, seeding missing tables from the registry and
/// retrying, for at most `MAX_SEED_ROUNDS` attempts.
///
/// On each attempt a dry-run failure is classified:
///
/// * missing table — the registry is asked to seed it; if it did, the query
///   is retried, otherwise `TablesNotRegistered` is returned (a seeding
///   failure becomes `HyperError`);
/// * syntax error — `SqlSyntaxError`;
/// * missing column — `UnknownColumn`;
/// * anything else — `HyperError`.
///
/// # Errors
///
/// Returns one of the errors above, or a `HyperError` explaining that the
/// retry budget was exhausted.
fn run_dry_run_with_seed(
    sql: &str,
    db: &mut crate::db::CompileTimeDb,
) -> Result<hyperdb_api::ResultSchema, ValidationError> {
    const MAX_SEED_ROUNDS: usize = 8;

    for _round in 0..MAX_SEED_ROUNDS {
        let failure = match dry_run(db, sql) {
            Ok(schema) => return Ok(schema),
            Err(e) => e,
        };

        // Only a missing table is recoverable; every other class ends the loop.
        let missing_table = match classify(&failure) {
            ErrorClass::MissingTable(t) => t,
            ErrorClass::SyntaxError(message) => {
                return Err(ValidationError::SqlSyntaxError { message })
            }
            ErrorClass::MissingColumn(column) => {
                return Err(ValidationError::UnknownColumn { column })
            }
            ErrorClass::Other(message) => return Err(ValidationError::HyperError { message }),
        };

        let seeded = Registry::seed_if_known(&missing_table, db).map_err(|seed_err| {
            ValidationError::HyperError {
                message: seed_err.to_string(),
            }
        })?;
        if !seeded {
            return Err(ValidationError::TablesNotRegistered {
                tables: vec![missing_table],
            });
        }
        // Table was seeded: fall through to the next round and retry.
    }

    Err(ValidationError::HyperError {
        message: format!(
            "compile-time validation exceeded {MAX_SEED_ROUNDS} seed-and-retry rounds; \
             ensure all tables referenced by this query are registered via \
             `#[derive(Table)] #[hyperdb(register)]`"
        ),
    })
}
/// Checks that every name in `struct_fields` appears among the column names
/// of `schema`.
///
/// Columns present in the schema but absent from `struct_fields` are ignored.
///
/// # Errors
///
/// Returns [`ValidationError::MissingColumns`] carrying `struct_name` and the
/// absent field names, in the order they occur in `struct_fields`.
fn finish_name_check(
    struct_name: &str,
    struct_fields: &[String],
    schema: &hyperdb_api::ResultSchema,
) -> Result<(), ValidationError> {
    // Build the lookup set once so each field is a single hash probe.
    let projected: std::collections::HashSet<&str> = schema
        .columns()
        .iter()
        .map(|column| column.name())
        .collect();

    let mut missing: Vec<String> = Vec::new();
    for field in struct_fields {
        if !projected.contains(field.as_str()) {
            missing.push(field.clone());
        }
    }

    if !missing.is_empty() {
        return Err(ValidationError::MissingColumns {
            struct_name: struct_name.to_owned(),
            missing,
        });
    }
    Ok(())
}
// Unit tests for the validation entry points. All but the first are marked
// `#[ignore]` with the reason "requires HYPERD_PATH; run manually".
// NOTE(review): the tests register entries through `registry::register`,
// which looks like process-global state shared between tests — confirm.
#[cfg(test)]
mod tests {
use super::*;
// Registers the struct name "User" with table "users" (columns id, name,
// email) and the field list id, name, email.
fn setup_users() {
registry::register(
"User",
"users",
"CREATE TABLE IF NOT EXISTS users (id BIGINT, name TEXT, email TEXT)",
vec!["id".into(), "name".into(), "email".into()],
);
}
// "Ghost" is not registered anywhere in this module, so the lookup in
// `validate_query_as` must fail with `StructNotRegistered`. That check
// happens before `get_or_init()` is called, which is why this test is not
// ignored.
#[test]
fn struct_not_registered_error() {
let err = validate_query_as("Ghost", "SELECT 1").unwrap_err();
assert!(
matches!(err, ValidationError::StructNotRegistered { .. }),
"expected StructNotRegistered, got: {err}"
);
}
// A query projecting exactly the registered fields validates successfully.
#[test]
#[ignore = "requires HYPERD_PATH; run manually"]
fn valid_query_passes() {
setup_users();
validate_query_as("User", "SELECT id, name, email FROM users").unwrap();
}
// The table has an `extra` column that is not in the registered field list;
// `SELECT *` therefore returns more columns than fields, which must still pass.
#[test]
#[ignore = "requires HYPERD_PATH; run manually"]
fn extra_column_in_result_is_ok() {
registry::register(
"SlimUser",
"slim_users",
"CREATE TABLE IF NOT EXISTS slim_users (id BIGINT, name TEXT, extra TEXT)",
vec!["id".into(), "name".into()],
);
validate_query_as("SlimUser", "SELECT * FROM slim_users").unwrap();
}
// The query omits `email`, so validation must fail with `MissingColumns`
// and the rendered diagnostic must mention the missing field by name.
#[test]
#[ignore = "requires HYPERD_PATH; run manually"]
fn missing_column_error() {
setup_users();
let err = validate_query_as("User", "SELECT id, name FROM users").unwrap_err();
assert!(
matches!(err, ValidationError::MissingColumns { .. }),
"expected MissingColumns, got: {err}"
);
let msg = err.to_diagnostic();
assert!(
msg.contains("email"),
"missing column name in message: {msg}"
);
}
// Validates the same query twice.
// NOTE(review): presumably the first call exercises the seed-and-retry path
// and the second runs against the already-seeded table — this depends on
// the shared database state and on test ordering; confirm.
#[test]
#[ignore = "requires HYPERD_PATH; run manually"]
fn seed_and_retry_on_missing_table() {
registry::register(
"Order",
"orders",
"CREATE TABLE IF NOT EXISTS orders (id BIGINT, total DOUBLE PRECISION)",
vec!["id".into(), "total".into()],
);
validate_query_as("Order", "SELECT id, total FROM orders").unwrap();
validate_query_as("Order", "SELECT id, total FROM orders").unwrap();
}
// The struct is registered, but the SQL references `nonexistent_xyz`, a
// table no test in this module registers; expects `TablesNotRegistered`.
#[test]
#[ignore = "requires HYPERD_PATH; run manually"]
fn unregistered_table_in_sql_error() {
registry::register(
"Known",
"known",
"CREATE TABLE IF NOT EXISTS known (id BIGINT)",
vec!["id".into()],
);
let err = validate_query_as("Known", "SELECT * FROM nonexistent_xyz").unwrap_err();
assert!(
matches!(err, ValidationError::TablesNotRegistered { .. }),
"expected TablesNotRegistered, got: {err}"
);
}
}