use std::collections::HashMap;
use tracing::{info, warn};
use crate::control::planner::procedural::executor::bindings::RowBindings;
use crate::control::planner::procedural::executor::core::{MAX_CASCADE_DEPTH, StatementExecutor};
use crate::control::security::catalog::trigger_types::{StoredTrigger, TriggerSecurity};
use crate::control::security::identity::{AuthMethod, AuthenticatedIdentity, Role};
use crate::control::state::SharedState;
use crate::types::TenantId;
pub fn check_cascade_depth(cascade_depth: u32, collection: &str) -> crate::Result<()> {
if cascade_depth >= MAX_CASCADE_DEPTH {
return Err(crate::Error::BadRequest {
detail: format!(
"trigger cascade depth exceeded ({MAX_CASCADE_DEPTH}): \
possible infinite loop on collection '{collection}'"
),
});
}
Ok(())
}
pub async fn fire_triggers(
state: &SharedState,
identity: &AuthenticatedIdentity,
tenant_id: TenantId,
collection: &str,
triggers: &[StoredTrigger],
bindings: &RowBindings,
cascade_depth: u32,
) -> crate::Result<()> {
for trigger in triggers {
if let Some(ref when_cond) = trigger.when_condition {
let bound_cond = bindings.substitute(when_cond);
if !evaluate_simple_condition(&bound_cond) {
continue;
}
}
let block = match state.block_cache.get_or_parse(&trigger.body_sql) {
Ok(b) => b,
Err(e) => {
warn!(
trigger = %trigger.name,
error = %e,
"failed to parse trigger body, skipping"
);
continue;
}
};
let effective_identity = resolve_trigger_identity(trigger, identity, tenant_id);
info!(
trigger = %trigger.name,
collection = collection,
timing = trigger.timing.as_str(),
security = trigger.security.as_str(),
caller = %identity.username,
effective_user = %effective_identity.username,
cascade_depth = cascade_depth,
"trigger invoked"
);
let executor = StatementExecutor::with_source(
state,
effective_identity,
tenant_id,
cascade_depth + 1,
crate::event::EventSource::Trigger,
);
if let Err(e) = executor.execute_block(&block, bindings).await {
return Err(crate::Error::BadRequest {
detail: format!(
"trigger '{}' on '{}' failed: {}",
trigger.name, collection, e
),
});
}
}
Ok(())
}
/// Choose the identity a trigger body executes under.
///
/// * `TriggerSecurity::Invoker` — a clone of `caller`.
/// * `TriggerSecurity::Definer` — a synthesized identity named after
///   `trigger.owner`, scoped to `tenant_id`, with `user_id` 0, `Trust`
///   authentication, the `TenantAdmin` role, the superuser flag set, no
///   default database, and access to all databases.
pub(crate) fn resolve_trigger_identity(
    trigger: &StoredTrigger,
    caller: &AuthenticatedIdentity,
    tenant_id: TenantId,
) -> AuthenticatedIdentity {
    use crate::control::security::identity::DatabaseSet;

    match trigger.security {
        TriggerSecurity::Definer => AuthenticatedIdentity {
            user_id: 0,
            username: trigger.owner.clone(),
            tenant_id,
            auth_method: AuthMethod::Trust,
            roles: vec![Role::TenantAdmin],
            is_superuser: true,
            default_database: None,
            accessible_databases: DatabaseSet::All,
        },
        TriggerSecurity::Invoker => caller.clone(),
    }
}
/// Run BEFORE triggers in slice order, feeding each trigger's `NEW`-row
/// mutations into the bindings seen by the triggers after it.
///
/// Per trigger:
/// 1. The bindings are rebuilt from `bindings` plus the current `new_fields`
///    (or `bindings` is used as-is when `new_fields` is `None`).
/// 2. If the trigger has a `when_condition`, it is substituted with those
///    bindings and evaluated; the trigger is skipped when it is false.
/// 3. The body is fetched from `state.block_cache`; a parse failure is logged
///    at `warn` level and the trigger is skipped.
/// 4. The body is executed under the identity chosen by
///    [`resolve_trigger_identity`], at `cascade_depth + 1`.
/// 5. Mutations collected by the executor are merged into `new_fields`,
///    overwriting existing keys. When `new_fields` is `None` the mutations
///    are discarded.
///
/// Returns the (possibly updated) `new_fields`; a `None` input stays `None`.
///
/// # Errors
///
/// Returns `Error::BadRequest` (naming the trigger and `collection`) as soon
/// as one trigger body fails to execute; remaining triggers are not run.
// Same execution context as `fire_triggers` plus the mutable NEW row, which
// pushes the parameter count past clippy's default limit.
#[allow(clippy::too_many_arguments)]
pub async fn fire_before_triggers_with_mutation(
    state: &SharedState,
    identity: &AuthenticatedIdentity,
    tenant_id: TenantId,
    collection: &str,
    triggers: &[StoredTrigger],
    bindings: &RowBindings,
    cascade_depth: u32,
    mut new_fields: Option<HashMap<String, nodedb_types::Value>>,
) -> crate::Result<Option<HashMap<String, nodedb_types::Value>>> {
    for trigger in triggers {
        // Build this trigger's view of the row once. `new_fields` only changes
        // at the end of an iteration, so the WHEN check and the body execution
        // can share it. With no NEW row the caller's bindings are borrowed
        // directly instead of cloned.
        let rebuilt;
        let current_bindings: &RowBindings = match new_fields.as_ref() {
            Some(fields) => {
                rebuilt = rebuild_bindings_with_new(bindings, fields);
                &rebuilt
            }
            None => bindings,
        };

        if let Some(ref when_cond) = trigger.when_condition {
            let bound_cond = current_bindings.substitute(when_cond);
            if !evaluate_simple_condition(&bound_cond) {
                continue;
            }
        }

        let block = match state.block_cache.get_or_parse(&trigger.body_sql) {
            Ok(b) => b,
            Err(e) => {
                warn!(
                    trigger = %trigger.name,
                    error = %e,
                    "failed to parse BEFORE trigger body, skipping"
                );
                continue;
            }
        };

        let effective_identity = resolve_trigger_identity(trigger, identity, tenant_id);

        info!(
            trigger = %trigger.name,
            collection = collection,
            timing = "BEFORE",
            security = trigger.security.as_str(),
            caller = %identity.username,
            effective_user = %effective_identity.username,
            cascade_depth = cascade_depth,
            "BEFORE trigger invoked"
        );

        let executor = StatementExecutor::with_source(
            state,
            effective_identity,
            tenant_id,
            cascade_depth + 1,
            crate::event::EventSource::Trigger,
        );

        if let Err(e) = executor.execute_block(&block, current_bindings).await {
            return Err(crate::Error::BadRequest {
                detail: format!(
                    "BEFORE trigger '{}' on '{}' aborted DML: {}",
                    trigger.name, collection, e
                ),
            });
        }

        // Fold the trigger's NEW-row assignments into the running row so the
        // next trigger (and the caller) see them. Extending with an empty set
        // is a no-op.
        let mutations = executor.take_new_mutations();
        if let Some(fields) = new_fields.as_mut() {
            fields.extend(mutations);
        }
    }

    Ok(new_fields)
}
/// Derive bindings from `original` via `with_new_row`, passing it a clone of
/// `new_fields`.
fn rebuild_bindings_with_new(
    original: &RowBindings,
    new_fields: &HashMap<String, nodedb_types::Value>,
) -> RowBindings {
    // `with_new_row` takes the map by value, so hand it an owned copy.
    let new_row = new_fields.clone();
    original.with_new_row(new_row)
}
/// Evaluate an already-substituted trigger `WHEN` condition.
///
/// Delegates to the parent module's `try_eval_simple_condition`. When that
/// evaluator yields no verdict the condition is treated as `true`.
pub fn evaluate_simple_condition(condition: &str) -> bool {
    let verdict = super::try_eval_simple_condition(condition);
    // No verdict defaults to `true`.
    verdict.unwrap_or(true)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Literal truthy conditions evaluate to `true`.
    #[test]
    fn simple_condition_true() {
        for truthy in ["TRUE", "1"] {
            assert!(evaluate_simple_condition(truthy));
        }
    }

    /// Literal falsy conditions, including `NULL`, evaluate to `false`.
    #[test]
    fn simple_condition_false() {
        for falsy in ["FALSE", "0", "NULL"] {
            assert!(!evaluate_simple_condition(falsy));
        }
    }

    /// A non-literal `IS NOT NULL` expression evaluates to `true`.
    #[test]
    fn complex_condition_defaults_true() {
        let condition = "'ord-1' IS NOT NULL";
        assert!(evaluate_simple_condition(condition));
    }
}