use std::collections::HashMap;
use pgwire::error::PgWireResult;
use crate::control::security::catalog::types::CheckConstraintDef;
use crate::control::state::SharedState;
use crate::types::TraceId;
/// Enforces every CHECK constraint in `constraints` against a row's `fields`.
///
/// Constraints flagged `has_subquery` are evaluated by planning a query and dispatching it
/// to the data plane; all others are evaluated in-process. Constraints are checked in the
/// given order and evaluation stops at the first failure.
///
/// # Errors
///
/// Propagates the first error produced by a per-constraint check (SQLSTATE `23514`).
pub async fn enforce_check_constraints(
    state: &SharedState,
    tenant_id: nodedb_types::TenantId,
    constraints: &[CheckConstraintDef],
    fields: &HashMap<String, nodedb_types::Value>,
) -> PgWireResult<()> {
    for constraint in constraints.iter() {
        if !constraint.has_subquery {
            // Cheap path: evaluate the expression directly against the field map.
            enforce_simple_check(constraint, fields)?;
            continue;
        }
        enforce_subquery_check(state, tenant_id, constraint, fields).await?;
    }
    Ok(())
}
/// Evaluates a CHECK constraint that contains no subquery, entirely in-process.
///
/// `NEW.` qualifiers are stripped from `constraint.check_sql`, and the remainder is parsed with
/// `parse_generated_expr`. The expression is then evaluated against `fields` wrapped as a
/// `Value::Object`. The results `Bool(true)`, `Null` and any non-zero `Integer` are accepted.
/// A `Null` result is accepted because an unknown CHECK result does not reject the row.
/// Every other result is a violation.
///
/// # Errors
///
/// SQLSTATE `23514` if the expression fails to parse or the constraint is violated.
fn enforce_simple_check(
    constraint: &CheckConstraintDef,
    fields: &HashMap<String, nodedb_types::Value>,
) -> PgWireResult<()> {
    let unqualified = strip_new_prefix(&constraint.check_sql);
    let (expr, _deps) = match nodedb_query::expr_parse::parse_generated_expr(&unqualified) {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(pgwire_err(
                "23514",
                &format!(
                    "CHECK constraint '{}' failed to parse: {}",
                    constraint.name, e
                ),
            ));
        }
    };

    let row = nodedb_types::Value::Object(fields.clone());
    let satisfied = match expr.eval(&row) {
        nodedb_types::Value::Bool(flag) => flag,
        nodedb_types::Value::Null => true,
        nodedb_types::Value::Integer(n) => n != 0,
        _ => false,
    };
    if satisfied {
        return Ok(());
    }
    Err(pgwire_err(
        "23514",
        &format!(
            "CHECK constraint '{}' violated: {}",
            constraint.name, constraint.check_sql
        ),
    ))
}
async fn enforce_subquery_check(
state: &SharedState,
tenant_id: nodedb_types::TenantId,
constraint: &CheckConstraintDef,
fields: &HashMap<String, nodedb_types::Value>,
) -> PgWireResult<()> {
let substituted = substitute_new_refs(&constraint.check_sql, fields);
let restructured = restructure_subquery_check(&substituted);
let query_ctx = crate::control::planner::context::QueryContext::for_state(state);
let tasks = match query_ctx
.plan_sql(
&restructured.sql,
tenant_id,
crate::types::DatabaseId::DEFAULT,
)
.await
{
Ok(t) => t,
Err(e) => {
return Err(pgwire_err(
"23514",
&format!(
"CHECK constraint '{}' failed to evaluate: {}",
constraint.name, e
),
));
}
};
let mut passed = false;
for task in tasks {
let resp = crate::control::server::dispatch_utils::dispatch_to_data_plane(
state,
tenant_id,
task.vshard_id,
task.plan,
TraceId::ZERO,
)
.await;
match resp {
Ok(response) => {
let json = crate::data::executor::response_codec::decode_payload_to_json(
&response.payload,
);
if !json.is_empty() && check_count_is_positive(&json) {
passed = true;
}
}
Err(e) => {
return Err(pgwire_err(
"23514",
&format!(
"CHECK constraint '{}' failed to evaluate: {}",
constraint.name, e
),
));
}
}
}
let constraint_ok = if restructured.negate { !passed } else { passed };
if !constraint_ok {
return Err(pgwire_err(
"23514",
&format!(
"CHECK constraint '{}' violated: {}",
constraint.name, constraint.check_sql
),
));
}
Ok(())
}
/// Decides whether a subquery-CHECK result payload means "satisfied".
///
/// `json` is either a single row object or an array of row objects. For an array, only
/// the first row is inspected. Within that row, the first field (in the map's iteration
/// order) holding a boolean or a number decides the result:
/// - Booleans are taken as-is. `SELECT (<expr>) AS _check` produces one; previously
///   booleans were ignored, so such checks could never pass.
/// - Numbers must be greater than zero, as for `SELECT COUNT(*) AS cnt ...`.
///
/// Unparseable JSON, an empty array, and rows without such a field all yield `false`.
fn check_count_is_positive(json: &str) -> bool {
    let parsed = match sonic_rs::from_str::<serde_json::Value>(json) {
        Ok(value) => value,
        Err(_) => return false,
    };
    let row = match parsed.as_array() {
        Some(rows) => rows.first().and_then(serde_json::Value::as_object),
        None => parsed.as_object(),
    };
    row.into_iter()
        .flat_map(|obj| obj.values())
        .find_map(|val| {
            val.as_bool()
                .or_else(|| val.as_i64().map(|n| n > 0))
                .or_else(|| val.as_f64().map(|n| n > 0.0))
        })
        .unwrap_or(false)
}
/// A subquery CHECK expression rewritten into a query the planner can execute directly.
struct RestructuredCheck {
    /// SQL to plan and run against the data plane.
    sql: String,
    /// When `true` the constraint holds iff the query does NOT report a positive count
    /// (produced for `NOT IN (SELECT ...)`).
    negate: bool,
}

/// Rewrites a CHECK expression whose `NEW.` references have already been substituted.
///
/// The form `<value> [NOT] IN (SELECT <col> FROM <table> [WHERE <cond>])` becomes
/// `SELECT COUNT(*) AS cnt FROM <table> WHERE <col> = <value> [AND (<cond>)]`. `negate` is
/// set for the `NOT IN` form. Anything else is wrapped as `SELECT (<expr>) AS _check`,
/// including shapes this textual matcher cannot split.
///
/// The matcher is purely textual and has these limits:
/// - Keywords are matched ASCII-case-insensitively with the spacing shown in the patterns.
/// - The subquery is assumed to run to the end of `expr`, closed by one `)`.
fn restructure_subquery_check(expr: &str) -> RestructuredCheck {
    rewrite_in_subquery(expr).unwrap_or_else(|| RestructuredCheck {
        sql: format!("SELECT ({expr}) AS _check"),
        negate: false,
    })
}

/// Attempts the `[NOT] IN (SELECT ...)` to `COUNT(*)` rewrite. Returns `None` when `expr`
/// does not have that shape.
fn rewrite_in_subquery(expr: &str) -> Option<RestructuredCheck> {
    // Every pattern ends with `SELECT `. The `NOT IN` variants are tried first because
    // ` IN (SELECT ` also occurs inside ` NOT IN (SELECT `.
    const PATTERNS: [(&str, bool); 4] = [
        (" NOT IN (SELECT ", true),
        (" NOT IN(SELECT ", true),
        (" IN (SELECT ", false),
        (" IN(SELECT ", false),
    ];
    const SELECT_KW: &str = "SELECT ";
    const FROM_KW: &str = " FROM ";
    const WHERE_KW: &str = " WHERE ";

    // ASCII-only case folding keeps byte offsets identical to `expr`, so positions found in
    // the folded copy are valid slice points in `expr`. Unicode `to_uppercase` can change
    // byte lengths (e.g. 'ı' -> 'I'), which misaligned the slices or panicked mid-char.
    let folded = expr.to_ascii_uppercase();
    let (pos, pattern, negate) = PATTERNS
        .iter()
        .find_map(|&(pat, neg)| folded.find(pat).map(|p| (p, pat, neg)))?;

    let value = expr[..pos].trim();
    // Start at the `SELECT` keyword itself, whichever spacing variant matched.
    let subquery = expr[pos + pattern.len() - SELECT_KW.len()..].trim();
    // Drop only the `)` that closes `IN (`. Parentheses inside the subquery must survive,
    // e.g. a trailing `lower('x'))`.
    let subquery = subquery.strip_suffix(')').unwrap_or(subquery).trim();

    // `None` when nothing follows `SELECT`, which falls back instead of panicking.
    let projection_and_rest = subquery.get(SELECT_KW.len()..)?;
    let from_pos = projection_and_rest.to_ascii_uppercase().find(FROM_KW)?;
    let column = projection_and_rest[..from_pos].trim();
    let source = &projection_and_rest[from_pos + FROM_KW.len()..];

    let sql = match source.to_ascii_uppercase().find(WHERE_KW) {
        Some(w) => {
            let table = source[..w].trim();
            let condition = source[w + WHERE_KW.len()..].trim();
            // Parenthesize the original condition so an `OR` inside it cannot bind
            // looser than the added membership test.
            format!(
                "SELECT COUNT(*) AS cnt FROM {table} WHERE {column} = {value} AND ({condition})"
            )
        }
        None => {
            let table = source.trim();
            format!("SELECT COUNT(*) AS cnt FROM {table} WHERE {column} = {value}")
        }
    };
    Some(RestructuredCheck { sql, negate })
}
/// Removes the `NEW.` row qualifier from a CHECK expression so it can be evaluated
/// directly against the bare field map.
///
/// `NEW.` is matched ASCII-case-insensitively and only at a word boundary. It is left
/// alone when preceded by an ASCII alphanumeric or `_`, so `RENEW.x` is untouched.
/// Single-quoted string literals are copied verbatim, including `''` escapes, so literal
/// text such as `'NEW.'` keeps its content.
fn strip_new_prefix(sql: &str) -> String {
    const PREFIX: &[u8] = b"NEW.";
    let bytes = sql.as_bytes();
    let mut out = String::with_capacity(sql.len());
    // Start of the span of `sql` that has been scanned but not yet copied to `out`.
    let mut copied_up_to = 0;
    let mut in_literal = false;
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'\'' {
            // A `''` escape toggles twice, so the literal state is preserved across it.
            in_literal = !in_literal;
        } else if !in_literal
            && matches!(bytes.get(i..i + PREFIX.len()), Some(w) if w.eq_ignore_ascii_case(PREFIX))
            && (i == 0 || !(bytes[i - 1].is_ascii_alphanumeric() || bytes[i - 1] == b'_'))
        {
            // `i` and `i + 4` sit on ASCII bytes, so both are valid char boundaries.
            out.push_str(&sql[copied_up_to..i]);
            i += PREFIX.len();
            copied_up_to = i;
            continue;
        }
        i += 1;
    }
    out.push_str(&sql[copied_up_to..]);
    out
}
/// Replaces every `NEW.<field>` reference in `sql` with a SQL literal for that field's value.
///
/// Matching rules:
/// - `NEW.` is matched ASCII-case-insensitively, and only at a word boundary. The preceding
///   character must not be an ASCII alphanumeric or `_`.
/// - A known field name matches case-insensitively and must not be followed by an ASCII
///   alphanumeric or `_`. The longest matching name wins, and an exact-case match is
///   preferred among equally long names.
/// - A reference to an unknown field becomes `NULL`. Such a reference is `NEW.` followed by
///   one or more ASCII alphanumerics or `_`.
///
/// The rewrite is a single left-to-right pass over `sql`. Single-quoted string literals in
/// `sql` are copied verbatim, and substituted literals are never scanned again. A field value
/// containing text like `NEW.other` therefore can neither be rewritten nor break out of its
/// quotes. The former repeated find-and-replace passes allowed both. They also sliced the
/// text using offsets from Unicode-uppercased copies, which could panic on non-ASCII data.
fn substitute_new_refs(sql: &str, fields: &HashMap<String, nodedb_types::Value>) -> String {
    const NEW_PREFIX: &str = "NEW.";

    // Longest names first, so `NEW.user_id` is not claimed by a field named `user`.
    // Ties are broken alphabetically to keep the choice deterministic.
    let mut names: Vec<&String> = fields.keys().collect();
    names.sort_by(|a, b| b.len().cmp(&a.len()).then_with(|| a.cmp(b)));

    let mut out = String::with_capacity(sql.len());
    // Invariant: `pos` is always a char boundary of `sql`.
    let mut pos = 0;
    while let Some(c) = sql[pos..].chars().next() {
        let rest = &sql[pos..];

        if c == '\'' {
            let len = new_ref_quoted_len(rest);
            out.push_str(&rest[..len]);
            pos += len;
            continue;
        }

        let at_boundary =
            !matches!(sql[..pos].chars().next_back(), Some(p) if new_ref_is_ident_char(p));
        let has_prefix = rest.len() >= NEW_PREFIX.len()
            && rest.as_bytes()[..NEW_PREFIX.len()].eq_ignore_ascii_case(NEW_PREFIX.as_bytes());
        if at_boundary && has_prefix {
            // The prefix is ASCII, so this slice starts on a char boundary.
            let after = &rest[NEW_PREFIX.len()..];
            if let Some(name) = new_ref_field_match(after, &names) {
                out.push_str(&value_to_sql_literal(&fields[name]));
                pos += NEW_PREFIX.len() + name.len();
                continue;
            }
            let ident_len = after
                .bytes()
                .take_while(|b| b.is_ascii_alphanumeric() || *b == b'_')
                .count();
            if ident_len > 0 {
                out.push_str("NULL");
                pos += NEW_PREFIX.len() + ident_len;
                continue;
            }
        }

        out.push(c);
        pos += c.len_utf8();
    }
    out
}

/// Returns `true` for characters that continue an identifier (ASCII alphanumeric or `_`).
fn new_ref_is_ident_char(c: char) -> bool {
    c.is_ascii_alphanumeric() || c == '_'
}

/// Byte length of the single-quoted SQL literal at the start of `s`, including both quotes.
///
/// `s` must begin with `'`. A `''` inside the literal is an escaped quote. An unterminated
/// literal extends to the end of `s`. The returned offset is a char boundary because `'`
/// is ASCII.
fn new_ref_quoted_len(s: &str) -> usize {
    let bytes = s.as_bytes();
    let mut i = 1;
    while i < bytes.len() {
        if bytes[i] == b'\'' {
            if bytes.get(i + 1) == Some(&b'\'') {
                i += 2;
                continue;
            }
            return i + 1;
        }
        i += 1;
    }
    bytes.len()
}

/// Finds the field named at the start of `after_prefix`, which is the text following `NEW.`.
///
/// `names` must be sorted longest-first. The longest case-insensitive match wins, and among
/// equally long matches an exact-case match is preferred. A match must not be followed by an
/// ASCII alphanumeric or `_`.
fn new_ref_field_match<'a>(after_prefix: &str, names: &[&'a String]) -> Option<&'a String> {
    let mut best: Option<&'a String> = None;
    for &name in names {
        if let Some(found) = best {
            if name.len() < found.len() {
                break;
            }
        }
        // `get` rejects lengths that would split a multi-byte character.
        let candidate = match after_prefix.get(..name.len()) {
            Some(candidate) => candidate,
            None => continue,
        };
        let ends_at_boundary = !matches!(
            after_prefix.as_bytes().get(name.len()),
            Some(b) if b.is_ascii_alphanumeric() || *b == b'_'
        );
        let same_ignoring_case = candidate
            .chars()
            .flat_map(char::to_uppercase)
            .eq(name.chars().flat_map(char::to_uppercase));
        if !ends_at_boundary || !same_ignoring_case {
            continue;
        }
        if candidate == name.as_str() {
            return Some(name);
        }
        if best.is_none() {
            best = Some(name);
        }
    }
    best
}
fn value_to_sql_literal(val: &nodedb_types::Value) -> String {
match val {
nodedb_types::Value::Null => "NULL".to_string(),
nodedb_types::Value::Bool(b) => if *b { "TRUE" } else { "FALSE" }.to_string(),
nodedb_types::Value::Integer(i) => i.to_string(),
nodedb_types::Value::Float(f) => format!("{f}"),
nodedb_types::Value::String(s) => {
let escaped = s.replace('\'', "''");
format!("'{escaped}'")
}
nodedb_types::Value::DateTime(dt) | nodedb_types::Value::NaiveDateTime(dt) => {
format!("'{dt}'")
}
_ => "NULL".to_string(),
}
}
fn pgwire_err(code: &str, msg: &str) -> pgwire::error::PgWireError {
pgwire::error::PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
"ERROR".to_owned(),
code.to_owned(),
msg.to_owned(),
)))
}
#[cfg(test)]
mod tests {
    use super::*;
    use nodedb_types::Value;

    /// Builds a field map from `(name, value)` pairs.
    fn fields_of(pairs: Vec<(&str, Value)>) -> HashMap<String, Value> {
        pairs
            .into_iter()
            .map(|(name, value)| (name.to_string(), value))
            .collect()
    }

    #[test]
    fn substitute_new_refs_basic() {
        let fields = fields_of(vec![
            ("email", Value::String("alice@example.com".into())),
            ("age", Value::Integer(25)),
        ]);
        assert_eq!(
            substitute_new_refs("NEW.email LIKE '%@%.%' AND NEW.age >= 18", &fields),
            "'alice@example.com' LIKE '%@%.%' AND 25 >= 18"
        );
    }

    #[test]
    fn substitute_new_refs_case_insensitive() {
        let fields = fields_of(vec![("name", Value::String("Bob".into()))]);
        assert_eq!(
            substitute_new_refs("new.name IS NOT NULL", &fields),
            "'Bob' IS NOT NULL"
        );
    }

    #[test]
    fn substitute_new_refs_missing_field() {
        let fields: HashMap<String, Value> = HashMap::new();
        assert_eq!(
            substitute_new_refs("NEW.unknown_field IS NOT NULL", &fields),
            "NULL IS NOT NULL"
        );
    }

    #[test]
    fn substitute_new_refs_with_subquery() {
        let fields = fields_of(vec![
            ("email", Value::String("test@x.com".into())),
            ("id", Value::String("u1".into())),
        ]);
        assert_eq!(
            substitute_new_refs(
                "NEW.email NOT IN (SELECT email FROM users WHERE id != NEW.id)",
                &fields
            ),
            "'test@x.com' NOT IN (SELECT email FROM users WHERE id != 'u1')"
        );
    }

    #[test]
    fn value_to_sql_literal_escapes_quotes() {
        let quoted = Value::String("it's a test".into());
        assert_eq!(value_to_sql_literal(&quoted), "'it''s a test'");
    }

    #[test]
    fn value_to_sql_literal_types() {
        let cases = [
            (Value::Null, "NULL"),
            (Value::Bool(true), "TRUE"),
            (Value::Integer(42), "42"),
            (Value::Float(3.5), "3.5"),
        ];
        for (value, expected) in &cases {
            assert_eq!(value_to_sql_literal(value), *expected);
        }
    }
}