#![allow(dead_code)]
use crate::{
db::{
DbSession, PersistedRow, QueryError,
executor::EntityAuthority,
session::sql::SqlCompiledCommandCacheKey,
sql::lowering::{
PreparedSqlParameterContract, PreparedSqlParameterTypeFamily, PreparedSqlStatement,
},
},
traits::{CanisterKind, EntityValue},
value::Value,
};
/// A SQL query that has been parsed and lowered once so it can be executed
/// repeatedly with different parameter bindings.
///
/// Values are built by `DbSession::prepare_sql_query` and consumed by
/// `DbSession::execute_prepared_sql_query`.
#[derive(Clone, Debug)]
pub(in crate::db) struct PreparedSqlQuery {
/// Copy of the SQL text passed to `prepare_sql_query`; also used to derive
/// the compiled-command cache key at execution time.
source_sql: String,
/// The lowered statement that bindings are applied to via `bind_literals`.
statement: PreparedSqlStatement,
/// Contracts that supplied bindings are validated against before execution;
/// each contract carries the index of the binding it constrains.
parameter_contracts: Vec<PreparedSqlParameterContract>,
}
impl PreparedSqlQuery {
    /// Returns the SQL text this query was prepared from.
    #[must_use]
    pub(in crate::db) fn source_sql(&self) -> &str {
        self.source_sql.as_str()
    }

    /// Returns the parameter contracts recorded for this query, as a slice.
    #[must_use]
    pub(in crate::db) const fn parameter_contracts(&self) -> &[PreparedSqlParameterContract] {
        let Self {
            parameter_contracts,
            ..
        } = self;
        parameter_contracts.as_slice()
    }

    /// Returns the number of parameter contracts recorded for this query.
    #[must_use]
    pub(in crate::db) const fn parameter_count(&self) -> usize {
        self.parameter_contracts().len()
    }
}
impl<C: CanisterKind> DbSession<C> {
    /// Parses and lowers `sql` for entity `E`, producing a reusable
    /// [`PreparedSqlQuery`].
    ///
    /// Steps, in order: parse the SQL text, run
    /// `ensure_sql_query_statement_supported` on the parsed statement, lower it
    /// against the `EntityAuthority` of `E`, then derive the parameter
    /// contracts from that authority's model.
    ///
    /// # Errors
    ///
    /// Returns a `QueryError` if parsing, the support check, lowering, or
    /// parameter-contract derivation fails.
    #[expect(
        clippy::unused_self,
        reason = "the session-owned SQL API keeps prepare/execute shaped consistently on DbSession"
    )]
    pub(in crate::db) fn prepare_sql_query<E>(
        &self,
        sql: &str,
    ) -> Result<PreparedSqlQuery, QueryError>
    where
        E: PersistedRow<Canister = C> + EntityValue,
    {
        // Only the first element of the parse result (the statement) is used.
        let (parsed, _) = crate::db::session::sql::parse_sql_statement_with_attribution(sql)?;
        Self::ensure_sql_query_statement_supported(&parsed)?;

        let entity_authority = EntityAuthority::for_type::<E>();
        let statement = Self::prepare_sql_statement_for_authority(&parsed, entity_authority)?;
        let parameter_contracts = statement
            .parameter_contracts(entity_authority.model())
            .map_err(QueryError::from_sql_lowering_error)?;

        Ok(PreparedSqlQuery {
            source_sql: sql.to_owned(),
            statement,
            parameter_contracts,
        })
    }

    /// Executes a previously prepared query for entity `E` with the given
    /// `bindings`.
    ///
    /// The bindings are first validated against the prepared parameter
    /// contracts, then bound into the prepared statement; the bound statement
    /// is compiled and finally executed on this session.
    ///
    /// # Errors
    ///
    /// Returns a `QueryError` if the bindings fail validation, or if binding,
    /// compilation, or execution fails.
    pub(in crate::db) fn execute_prepared_sql_query<E>(
        &self,
        prepared: &PreparedSqlQuery,
        bindings: &[Value],
    ) -> Result<crate::db::session::sql::SqlStatementResult, QueryError>
    where
        E: PersistedRow<Canister = C> + EntityValue,
    {
        validate_parameter_bindings(prepared.parameter_contracts(), bindings)?;
        let bound = prepared.statement.bind_literals(bindings)?;

        // NOTE(review): the cache key is derived from the source SQL text and
        // the entity type only, not from the bound values — confirm that the
        // compile step never reuses a compilation across different bindings.
        let compiled = Self::compile_sql_statement_for_authority(
            &bound,
            EntityAuthority::for_type::<E>(),
            SqlCompiledCommandCacheKey::query_for_entity::<E>(prepared.source_sql()),
        )?
        .0;

        self.execute_compiled_sql::<E>(&compiled)
    }
}
/// Checks that `bindings` satisfies every contract in `contracts`.
///
/// The number of bindings must equal the number of contracts, and the binding
/// found at each contract's index must be accepted by
/// `binding_matches_contract`.
///
/// # Errors
///
/// Returns `QueryError::unsupported_query` describing the first violation:
/// a count mismatch, a missing binding, or a binding that does not match its
/// contract.
fn validate_parameter_bindings(
    contracts: &[PreparedSqlParameterContract],
    bindings: &[Value],
) -> Result<(), QueryError> {
    if contracts.len() != bindings.len() {
        return Err(QueryError::unsupported_query(format!(
            "prepared SQL expected {} bindings, found {}",
            contracts.len(),
            bindings.len(),
        )));
    }

    // Stops at the first contract whose binding is missing or mismatched.
    contracts.iter().try_for_each(|contract| {
        let Some(binding) = bindings.get(contract.index()) else {
            return Err(QueryError::unsupported_query(format!(
                "missing prepared SQL binding at index={}",
                contract.index(),
            )));
        };

        if binding_matches_contract(binding, contract) {
            Ok(())
        } else {
            Err(QueryError::unsupported_query(format!(
                "prepared SQL binding at index={} does not match the required {:?} contract",
                contract.index(),
                contract.type_family(),
            )))
        }
    })
}
/// Returns whether `value` is an acceptable binding for `contract`.
///
/// `Value::Null` is accepted exactly when the contract reports
/// `null_allowed()`. Any other value must belong to the contract's type
/// family:
///
/// - `Numeric`: `Int`, `Int128`, `IntBig`, `Uint`, `Uint128`, `UintBig`,
///   `Float32`, `Float64`, `Decimal`, `Duration`, `Timestamp`
/// - `Text`: `Text`, `Enum`
/// - `Bool`: `Bool`
const fn binding_matches_contract(value: &Value, contract: &PreparedSqlParameterContract) -> bool {
    match value {
        // Null bypasses the type-family check entirely.
        Value::Null => contract.null_allowed(),
        non_null => match contract.type_family() {
            PreparedSqlParameterTypeFamily::Bool => matches!(non_null, Value::Bool(_)),
            PreparedSqlParameterTypeFamily::Text => {
                matches!(non_null, Value::Text(_) | Value::Enum(_))
            }
            PreparedSqlParameterTypeFamily::Numeric => matches!(
                non_null,
                Value::Int(_)
                    | Value::Int128(_)
                    | Value::IntBig(_)
                    | Value::Uint(_)
                    | Value::Uint128(_)
                    | Value::UintBig(_)
                    | Value::Float32(_)
                    | Value::Float64(_)
                    | Value::Decimal(_)
                    | Value::Duration(_)
                    | Value::Timestamp(_)
            ),
        },
    }
}