use std::collections::HashMap;
use nodedb_types::Value;
/// Values that `RowBindings::substitute` splices into a SQL string: optional
/// `NEW`/`OLD` row images, trigger metadata, and free-form named variables.
#[derive(Debug, Clone)]
pub struct RowBindings {
/// Row whose fields replace `NEW.<field>` references; `None` when no new
/// row is bound (for example in the delete and statement constructors).
new_row: Option<HashMap<String, Value>>,
/// Row whose fields replace `OLD.<field>` references; `None` when no old
/// row is bound (for example in the insert and statement constructors).
old_row: Option<HashMap<String, Value>>,
/// Operation name substituted (single-quoted) for `TG_OP`. The row-level
/// constructors set it to "INSERT", "UPDATE" or "DELETE"; empty otherwise.
tg_op: String,
/// Collection name substituted (single-quoted) for `TG_TABLE_NAME`.
tg_table_name: String,
/// Timing substituted (single-quoted) for `TG_WHEN`. The constructors set
/// it to "BEFORE" or "AFTER"; empty for `empty()` / `with_params()`.
tg_when: String,
/// Name -> replacement text. `substitute` inserts the replacement verbatim
/// (no quoting is added), so values must already be valid SQL fragments.
variables: HashMap<String, String>,
}
impl RowBindings {
    /// Creates bindings with no rows, no variables, and empty trigger metadata.
    pub fn empty() -> Self {
        Self::with_params(HashMap::new())
    }

    /// Creates bindings that carry only `params` as variables. Both row
    /// images are unset and the trigger metadata strings are empty.
    pub fn with_params(params: HashMap<String, String>) -> Self {
        Self {
            variables: params,
            new_row: None,
            old_row: None,
            tg_op: String::new(),
            tg_table_name: String::new(),
            tg_when: String::new(),
        }
    }

    /// Returns a copy of `self` in which variable `name` maps to `value`
    /// (overwriting any previous mapping). `self` is left untouched.
    pub fn with_variable(&self, name: &str, value: &str) -> Self {
        let mut variables = self.variables.clone();
        variables.insert(name.to_owned(), value.to_owned());
        Self {
            new_row: self.new_row.clone(),
            old_row: self.old_row.clone(),
            tg_op: self.tg_op.clone(),
            tg_table_name: self.tg_table_name.clone(),
            tg_when: self.tg_when.clone(),
            variables,
        }
    }

    /// Returns a copy of `self` whose `NEW` row image is `new_row`; every
    /// other binding is carried over unchanged.
    pub fn with_new_row(&self, new_row: HashMap<String, Value>) -> Self {
        Self {
            new_row: Some(new_row),
            old_row: self.old_row.clone(),
            tg_op: self.tg_op.clone(),
            tg_table_name: self.tg_table_name.clone(),
            tg_when: self.tg_when.clone(),
            variables: self.variables.clone(),
        }
    }

    /// Shared constructor behind the trigger-specific entry points: records
    /// the timing, operation and collection, attaches whichever row images
    /// are supplied, and starts with an empty variable map.
    fn for_trigger_event(
        when: &str,
        op: &str,
        collection: &str,
        old_row: Option<HashMap<String, Value>>,
        new_row: Option<HashMap<String, Value>>,
    ) -> Self {
        Self {
            new_row,
            old_row,
            tg_op: op.to_owned(),
            tg_table_name: collection.to_owned(),
            tg_when: when.to_owned(),
            variables: HashMap::new(),
        }
    }

    /// Bindings for a `BEFORE` / `INSERT` trigger on `collection`: only `NEW` is bound.
    pub fn before_insert(collection: &str, new_row: HashMap<String, Value>) -> Self {
        Self::for_trigger_event("BEFORE", "INSERT", collection, None, Some(new_row))
    }

    /// Bindings for a `BEFORE` / `UPDATE` trigger on `collection`: both `OLD`
    /// and `NEW` are bound.
    pub fn before_update(
        collection: &str,
        old_row: HashMap<String, Value>,
        new_row: HashMap<String, Value>,
    ) -> Self {
        Self::for_trigger_event("BEFORE", "UPDATE", collection, Some(old_row), Some(new_row))
    }

    /// Bindings for a `BEFORE` / `DELETE` trigger on `collection`: only `OLD` is bound.
    pub fn before_delete(collection: &str, old_row: HashMap<String, Value>) -> Self {
        Self::for_trigger_event("BEFORE", "DELETE", collection, Some(old_row), None)
    }

    /// Bindings for an `AFTER` / `INSERT` trigger on `collection`: only `NEW` is bound.
    pub fn after_insert(collection: &str, new_row: HashMap<String, Value>) -> Self {
        Self::for_trigger_event("AFTER", "INSERT", collection, None, Some(new_row))
    }

    /// Bindings for an `AFTER` / `UPDATE` trigger on `collection`: both `OLD`
    /// and `NEW` are bound.
    pub fn after_update(
        collection: &str,
        old_row: HashMap<String, Value>,
        new_row: HashMap<String, Value>,
    ) -> Self {
        Self::for_trigger_event("AFTER", "UPDATE", collection, Some(old_row), Some(new_row))
    }

    /// Bindings with no row images: `tg_op` is taken from the caller and the
    /// timing is always recorded as `AFTER`.
    pub fn statement(collection: &str, tg_op: &str) -> Self {
        Self::for_trigger_event("AFTER", tg_op, collection, None, None)
    }

    /// Bindings for an `AFTER` / `DELETE` trigger on `collection`: only `OLD` is bound.
    pub fn after_delete(collection: &str, old_row: HashMap<String, Value>) -> Self {
        Self::for_trigger_event("AFTER", "DELETE", collection, Some(old_row), None)
    }

    /// Returns `sql` with every binding spliced in as text.
    ///
    /// The passes run in a fixed order, each over the output of the previous
    /// one:
    /// 1. `NEW.<field>` references, replaced by the field's SQL literal;
    /// 2. `OLD.<field>` references, likewise;
    /// 3. each variable name, replaced by its value verbatim;
    /// 4. `TG_OP`, `TG_TABLE_NAME`, `TG_WHEN`, replaced by the single-quoted
    ///    metadata string (an empty string yields `''`).
    ///
    /// NOTE(review): because later passes re-scan the whole string, text
    /// inserted by an earlier pass (e.g. a row value that happens to contain
    /// `TG_OP` or a variable name) is itself subject to replacement. Field
    /// and variable order within a pass follows `HashMap` iteration order.
    pub fn substitute(&self, sql: &str) -> String {
        let mut text = sql.to_owned();

        for (qualifier, row) in [("NEW", &self.new_row), ("OLD", &self.old_row)] {
            for (field, value) in row.iter().flatten() {
                let literal = value.to_sql_literal();
                text = replace_qualified_field_reference(&text, qualifier, field, &literal);
            }
        }

        for (name, value) in &self.variables {
            text = replace_case_insensitive(&text, name, value);
        }

        let metadata = [
            ("TG_OP", &self.tg_op),
            ("TG_TABLE_NAME", &self.tg_table_name),
            ("TG_WHEN", &self.tg_when),
        ];
        for (keyword, raw) in metadata {
            text = replace_case_insensitive(&text, keyword, &format!("'{}'", raw));
        }

        text
    }
}
/// Replaces every `<qualifier> . <field>` reference in `input` with
/// `replacement` and returns the rewritten string.
///
/// A reference matches when `qualifier` and `field` each match
/// ASCII-case-insensitively as whole identifiers (see
/// `matches_identifier_at`), separated by a single `.` that may be surrounded
/// by ASCII whitespace. Everything else is copied through unchanged.
///
/// Untouched text is copied as string slices rather than byte-by-byte, so
/// multi-byte UTF-8 characters in `input` survive intact.
fn replace_qualified_field_reference(
    input: &str,
    qualifier: &str,
    field: &str,
    replacement: &str,
) -> String {
    let bytes = input.as_bytes();

    // Returns the byte offset just past a reference starting at `start`, or
    // `None` if no reference starts there. `matches_identifier_at` already
    // enforces the identifier boundary after `field`, so no extra trailing
    // check is needed here.
    let reference_end = |start: usize| -> Option<usize> {
        if !matches_identifier_at(input, start, qualifier) {
            return None;
        }
        let dot = skip_ascii_whitespace(bytes, start + qualifier.len());
        if bytes.get(dot) != Some(&b'.') {
            return None;
        }
        let field_start = skip_ascii_whitespace(bytes, dot + 1);
        if !matches_identifier_at(input, field_start, field) {
            return None;
        }
        Some(field_start + field.len())
    };

    let mut result = String::with_capacity(input.len());
    // Start of the part of `input` not yet written to `result`. It only ever
    // moves to the end of a match, which is a char boundary because
    // `matches_identifier_at` slices `input` with the checked `str::get`.
    let mut copied_to = 0;
    let mut i = 0;
    while i < bytes.len() {
        match reference_end(i) {
            Some(end) => {
                // `i` is a char boundary for the same reason as `end`.
                result.push_str(&input[copied_to..i]);
                result.push_str(replacement);
                copied_to = end;
                i = end;
            }
            None => i += 1,
        }
    }
    result.push_str(&input[copied_to..]);
    result
}
/// Reports whether `ident` occurs in `input` at byte offset `start` as a
/// complete identifier.
///
/// The comparison ignores ASCII case. The match is rejected when the byte
/// before `start` or the byte after the candidate is an identifier character
/// (per `is_identifier_char`), and when `start..start + ident.len()` is out
/// of range or does not fall on UTF-8 character boundaries.
fn matches_identifier_at(input: &str, start: usize, ident: &str) -> bool {
    let end = start + ident.len();
    let raw = input.as_bytes();
    match input.get(start..end) {
        Some(candidate) if candidate.eq_ignore_ascii_case(ident) => {
            // Start-of-input counts as a boundary, hence the `start > 0` guard.
            let glued_before = start > 0 && is_identifier_char(raw.get(start - 1).copied());
            !glued_before && !is_identifier_char(raw.get(end).copied())
        }
        _ => false,
    }
}
use super::sql_bytes::{is_identifier_char, skip_ascii_whitespace};
/// Replaces every occurrence of `pattern` in `input` with `replacement`,
/// ignoring ASCII case, and returns the rewritten string.
///
/// An empty `pattern` returns `input` unchanged. Matches are plain
/// substrings found left to right without overlap; identifier boundaries and
/// quoting are not considered.
///
/// Case folding is ASCII-only on purpose: it leaves every byte offset
/// unchanged, so positions found in the folded copy can be used to slice the
/// original `input`. Full Unicode lowercasing can change a string's byte
/// length, which would make those positions wrong or out of range.
/// Non-ASCII characters therefore match only exactly.
fn replace_case_insensitive(input: &str, pattern: &str, replacement: &str) -> String {
    if pattern.is_empty() {
        return input.to_string();
    }

    let folded_input = input.to_ascii_lowercase();
    let folded_pattern = pattern.to_ascii_lowercase();

    let mut result = String::with_capacity(input.len());
    let mut search_from = 0;
    while let Some(offset) = folded_input[search_from..].find(folded_pattern.as_str()) {
        let match_start = search_from + offset;
        result.push_str(&input[search_from..match_start]);
        result.push_str(replacement);
        // Same length in both copies, so this stays on a char boundary.
        search_from = match_start + pattern.len();
    }
    result.push_str(&input[search_from..]);
    result
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects `(field, value)` pairs into the row-map shape the
    /// `RowBindings` constructors accept.
    fn row_of(fields: Vec<(&str, Value)>) -> HashMap<String, Value> {
        fields
            .into_iter()
            .map(|(name, value)| (name.to_string(), value))
            .collect()
    }

    /// `NEW.<field>` references are replaced by the bound row's literals.
    #[test]
    fn substitute_new_fields() {
        let row = row_of(vec![
            ("id", Value::String("ord-1".into())),
            ("total", Value::Float(99.99)),
        ]);
        let result = RowBindings::after_insert("orders", row)
            .substitute("INSERT INTO audit (id, amount) VALUES (NEW.id, NEW.total)");
        assert!(result.contains("'ord-1'"), "got: {result}");
        assert!(result.contains("99.99"), "got: {result}");
    }

    /// `TG_OP` and `TG_TABLE_NAME` expand to quoted metadata strings.
    #[test]
    fn substitute_tg_op() {
        let rendered = RowBindings::after_insert("orders", HashMap::new())
            .substitute("VALUES (TG_OP, TG_TABLE_NAME)");
        assert!(rendered.contains("'INSERT'"));
        assert!(rendered.contains("'orders'"));
    }

    /// A null field value is rendered as the `NULL` keyword.
    #[test]
    fn substitute_null_value() {
        let row = row_of(vec![("x", Value::Null)]);
        let rendered = RowBindings::after_insert("c", row).substitute("SELECT NEW.x");
        assert!(rendered.contains("NULL"));
    }

    /// Spot-checks the SQL literal produced for each simple value kind,
    /// including quote doubling inside strings.
    #[test]
    fn value_sql_literals() {
        let cases = [
            (Value::Null, "NULL"),
            (Value::Bool(true), "TRUE"),
            (Value::Integer(42), "42"),
            (Value::String("hello".into()), "'hello'"),
            (Value::String("it's".into()), "'it''s'"),
        ];
        for (value, expected) in &cases {
            assert_eq!(value.to_sql_literal(), *expected);
        }
    }

    /// Whitespace around the `.` in a qualified reference is tolerated.
    #[test]
    fn substitute_spaced_qualified_field_reference() {
        let row = row_of(vec![("id", Value::String("as1".into()))]);
        let rendered = RowBindings::after_insert("orders", row)
            .substitute("VALUES (NEW . id | | '_log', NEW . id)");
        assert_eq!(rendered, "VALUES ('as1' | | '_log', 'as1')");
    }
}