contextdb_engine/
schema_enforcer.rs1use crate::database::Database;
2use contextdb_core::{Error, Result, Value};
3use contextdb_parser::ast::{Expr, Literal};
4use contextdb_planner::PhysicalPlan;
5use std::collections::{HashMap, HashSet};
6
7pub fn validate_dml(
8 plan: &PhysicalPlan,
9 db: &Database,
10 params: &HashMap<String, Value>,
11) -> Result<()> {
12 match plan {
13 PhysicalPlan::Update(p) if db.relational_store().is_immutable(&p.table) => {
14 Err(Error::ImmutableTable(p.table.clone()))
15 }
16 PhysicalPlan::Delete(p) if db.relational_store().is_immutable(&p.table) => {
17 Err(Error::ImmutableTable(p.table.clone()))
18 }
19 PhysicalPlan::Insert(p) => {
20 let table_meta = db
21 .table_meta(&p.table)
22 .ok_or_else(|| Error::TableNotFound(p.table.clone()))?;
23
24 let columns: Vec<String> = if p.columns.is_empty() {
26 table_meta.columns.iter().map(|c| c.name.clone()).collect()
27 } else {
28 p.columns.clone()
29 };
30
31 for col_name in &columns {
33 if !table_meta.columns.iter().any(|c| c.name == *col_name) {
34 return Err(Error::Other(format!(
35 "column '{}' does not exist in table '{}'",
36 col_name, p.table
37 )));
38 }
39 }
40
41 if p.on_conflict.is_none()
42 && let Some(id_idx) = columns.iter().position(|c| c == "id")
43 {
44 let mut pending_ids = HashSet::new();
45 for row in &p.values {
46 let id_value = row
47 .get(id_idx)
48 .map(|e| resolve_expr(e, params))
49 .transpose()?;
50 if let Some(v) = id_value {
51 if !pending_ids.insert(cache_key_for_value(&v)) {
52 return Err(Error::UniqueViolation {
53 table: p.table.clone(),
54 column: "id".to_string(),
55 });
56 }
57 let lookup = db.point_lookup(&p.table, "id", &v, db.snapshot())?;
58 if lookup.is_some() {
59 return Err(Error::UniqueViolation {
60 table: p.table.clone(),
61 column: "id".to_string(),
62 });
63 }
64 }
65 }
66 }
67
68 for row in &p.values {
69 for column in table_meta
70 .columns
71 .iter()
72 .filter(|column| !column.nullable && !column.primary_key)
73 {
74 if let Some(index) = columns.iter().position(|name| name == &column.name) {
75 let value = row
76 .get(index)
77 .ok_or_else(|| {
78 Error::PlanError("column/value count mismatch".to_string())
79 })
80 .and_then(|expr| resolve_expr(expr, params))?;
81 if value == Value::Null {
82 return Err(Error::Other(format!(
83 "NOT NULL constraint violated: {}.{}",
84 p.table, column.name
85 )));
86 }
87 } else if column.default.is_none() {
88 return Err(Error::Other(format!(
89 "NOT NULL constraint violated: {}.{}",
90 p.table, column.name
91 )));
92 }
93 }
94 }
95
96 let unique_columns: Vec<_> = table_meta
97 .columns
98 .iter()
99 .filter(|column| column.unique && !column.primary_key)
100 .collect();
101 if !unique_columns.is_empty() {
102 let existing_rows = db.scan(&p.table, db.snapshot())?;
103 for row in &p.values {
104 for column in &unique_columns {
105 let Some(index) = columns.iter().position(|name| name == &column.name)
106 else {
107 continue;
108 };
109 let value = row
110 .get(index)
111 .ok_or_else(|| {
112 Error::PlanError("column/value count mismatch".to_string())
113 })
114 .and_then(|expr| resolve_expr(expr, params))?;
115 if value == Value::Null {
116 continue;
117 }
118 if existing_rows
119 .iter()
120 .any(|existing| existing.values.get(&column.name) == Some(&value))
121 {
122 return Err(Error::UniqueViolation {
123 table: p.table.clone(),
124 column: column.name.clone(),
125 });
126 }
127 }
128 }
129 }
130 Ok(())
131 }
132 _ => Ok(()),
133 }
134}
135
136fn resolve_expr(expr: &Expr, params: &HashMap<String, Value>) -> Result<Value> {
137 match expr {
138 Expr::Literal(l) => Ok(match l {
139 Literal::Null => Value::Null,
140 Literal::Bool(v) => Value::Bool(*v),
141 Literal::Integer(v) => Value::Int64(*v),
142 Literal::Real(v) => Value::Float64(*v),
143 Literal::Text(v) => {
144 if let Ok(id) = uuid::Uuid::parse_str(v) {
145 Value::Uuid(id)
146 } else {
147 Value::Text(v.clone())
148 }
149 }
150 Literal::Vector(v) => Value::Vector(v.clone()),
151 }),
152 Expr::Parameter(p) => params
153 .get(p)
154 .cloned()
155 .ok_or_else(|| Error::NotFound(format!("missing parameter: {}", p))),
156 Expr::Column(c) => Ok(Value::Text(c.column.clone())),
157 _ => Err(Error::PlanError(
158 "unsupported expression in schema enforcer".to_string(),
159 )),
160 }
161}
162
163fn cache_key_for_value(value: &Value) -> Vec<u8> {
164 bincode::serde::encode_to_vec(value, bincode::config::standard())
165 .expect("Value should serialize for uniqueness cache key generation")
166}