Skip to main content

contextdb_engine/
schema_enforcer.rs

1use crate::database::Database;
2use contextdb_core::{Error, Result, Value};
3use contextdb_parser::ast::{Expr, Literal};
4use contextdb_planner::PhysicalPlan;
5use std::collections::{HashMap, HashSet};
6
7pub fn validate_dml(
8    plan: &PhysicalPlan,
9    db: &Database,
10    params: &HashMap<String, Value>,
11) -> Result<()> {
12    match plan {
13        PhysicalPlan::Update(p) if db.relational_store().is_immutable(&p.table) => {
14            Err(Error::ImmutableTable(p.table.clone()))
15        }
16        PhysicalPlan::Delete(p) if db.relational_store().is_immutable(&p.table) => {
17            Err(Error::ImmutableTable(p.table.clone()))
18        }
19        PhysicalPlan::Insert(p) => {
20            let table_meta = db
21                .table_meta(&p.table)
22                .ok_or_else(|| Error::TableNotFound(p.table.clone()))?;
23
24            // When no column list is provided, infer from table metadata.
25            let columns: Vec<String> = if p.columns.is_empty() {
26                table_meta.columns.iter().map(|c| c.name.clone()).collect()
27            } else {
28                p.columns.clone()
29            };
30
31            // Validate that all INSERT columns exist in the table schema
32            for col_name in &columns {
33                if !table_meta.columns.iter().any(|c| c.name == *col_name) {
34                    return Err(Error::Other(format!(
35                        "column '{}' does not exist in table '{}'",
36                        col_name, p.table
37                    )));
38                }
39            }
40
41            if p.on_conflict.is_none()
42                && let Some(id_idx) = columns.iter().position(|c| c == "id")
43            {
44                let mut pending_ids = HashSet::new();
45                for row in &p.values {
46                    let id_value = row
47                        .get(id_idx)
48                        .map(|e| resolve_expr(e, params))
49                        .transpose()?;
50                    if let Some(v) = id_value {
51                        if !pending_ids.insert(cache_key_for_value(&v)) {
52                            return Err(Error::UniqueViolation {
53                                table: p.table.clone(),
54                                column: "id".to_string(),
55                            });
56                        }
57                        let lookup = db.point_lookup(&p.table, "id", &v, db.snapshot())?;
58                        if lookup.is_some() {
59                            return Err(Error::UniqueViolation {
60                                table: p.table.clone(),
61                                column: "id".to_string(),
62                            });
63                        }
64                    }
65                }
66            }
67
68            for row in &p.values {
69                for column in table_meta
70                    .columns
71                    .iter()
72                    .filter(|column| !column.nullable && !column.primary_key)
73                {
74                    if let Some(index) = columns.iter().position(|name| name == &column.name) {
75                        let value = row
76                            .get(index)
77                            .ok_or_else(|| {
78                                Error::PlanError("column/value count mismatch".to_string())
79                            })
80                            .and_then(|expr| resolve_expr(expr, params))?;
81                        if value == Value::Null {
82                            return Err(Error::Other(format!(
83                                "NOT NULL constraint violated: {}.{}",
84                                p.table, column.name
85                            )));
86                        }
87                    } else if column.default.is_none() {
88                        return Err(Error::Other(format!(
89                            "NOT NULL constraint violated: {}.{}",
90                            p.table, column.name
91                        )));
92                    }
93                }
94            }
95
96            let unique_columns: Vec<_> = table_meta
97                .columns
98                .iter()
99                .filter(|column| column.unique && !column.primary_key)
100                .collect();
101            if !unique_columns.is_empty() {
102                let existing_rows = db.scan(&p.table, db.snapshot())?;
103                for row in &p.values {
104                    for column in &unique_columns {
105                        let Some(index) = columns.iter().position(|name| name == &column.name)
106                        else {
107                            continue;
108                        };
109                        let value = row
110                            .get(index)
111                            .ok_or_else(|| {
112                                Error::PlanError("column/value count mismatch".to_string())
113                            })
114                            .and_then(|expr| resolve_expr(expr, params))?;
115                        if value == Value::Null {
116                            continue;
117                        }
118                        if existing_rows
119                            .iter()
120                            .any(|existing| existing.values.get(&column.name) == Some(&value))
121                        {
122                            return Err(Error::UniqueViolation {
123                                table: p.table.clone(),
124                                column: column.name.clone(),
125                            });
126                        }
127                    }
128                }
129            }
130            Ok(())
131        }
132        _ => Ok(()),
133    }
134}
135
136fn resolve_expr(expr: &Expr, params: &HashMap<String, Value>) -> Result<Value> {
137    match expr {
138        Expr::Literal(l) => Ok(match l {
139            Literal::Null => Value::Null,
140            Literal::Bool(v) => Value::Bool(*v),
141            Literal::Integer(v) => Value::Int64(*v),
142            Literal::Real(v) => Value::Float64(*v),
143            Literal::Text(v) => {
144                if let Ok(id) = uuid::Uuid::parse_str(v) {
145                    Value::Uuid(id)
146                } else {
147                    Value::Text(v.clone())
148                }
149            }
150            Literal::Vector(v) => Value::Vector(v.clone()),
151        }),
152        Expr::Parameter(p) => params
153            .get(p)
154            .cloned()
155            .ok_or_else(|| Error::NotFound(format!("missing parameter: {}", p))),
156        Expr::Column(c) => Ok(Value::Text(c.column.clone())),
157        _ => Err(Error::PlanError(
158            "unsupported expression in schema enforcer".to_string(),
159        )),
160    }
161}
162
163fn cache_key_for_value(value: &Value) -> Vec<u8> {
164    bincode::serde::encode_to_vec(value, bincode::config::standard())
165        .expect("Value should serialize for uniqueness cache key generation")
166}