// File: citadel_sql/executor/dml.rs

1use std::cell::RefCell;
2
3use citadel::Database;
4use citadel_txn::write_txn::WriteTxn;
5use rustc_hash::FxHashMap;
6
7use crate::encoding::{encode_composite_key_into, encode_row_into};
8use crate::error::{Result, SqlError};
9use crate::eval::{eval_expr, is_truthy, ColumnMap, EvalCtx};
10use crate::parser::*;
11use crate::types::*;
12
13use crate::schema::SchemaManager;
14
15use super::compile::CompiledPlan;
16use super::helpers::*;
17use super::CteContext;
18
19pub(super) fn exec_insert(
20    db: &Database,
21    schema: &SchemaManager,
22    stmt: &InsertStmt,
23    params: &[Value],
24) -> Result<ExecutionResult> {
25    let empty_ctes = CteContext::default();
26    let materialized;
27    let stmt = if insert_has_subquery(stmt) {
28        materialized = materialize_insert(stmt, &mut |sub| {
29            exec_subquery_read(db, schema, sub, &empty_ctes)
30        })?;
31        &materialized
32    } else {
33        stmt
34    };
35
36    let lower_name = stmt.table.to_ascii_lowercase();
37    if schema.get_view(&lower_name).is_some() {
38        return Err(SqlError::CannotModifyView(stmt.table.clone()));
39    }
40    let table_schema = schema
41        .get(&lower_name)
42        .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
43
44    let insert_columns = if stmt.columns.is_empty() {
45        table_schema
46            .columns
47            .iter()
48            .map(|c| c.name.clone())
49            .collect::<Vec<_>>()
50    } else {
51        stmt.columns
52            .iter()
53            .map(|c| c.to_ascii_lowercase())
54            .collect()
55    };
56
57    let col_indices: Vec<usize> = insert_columns
58        .iter()
59        .map(|name| {
60            table_schema
61                .column_index(name)
62                .ok_or_else(|| SqlError::ColumnNotFound(name.clone()))
63        })
64        .collect::<Result<_>>()?;
65
66    let defaults: Vec<(usize, &Expr)> = table_schema
67        .columns
68        .iter()
69        .filter(|c| c.default_expr.is_some() && !col_indices.contains(&(c.position as usize)))
70        .map(|c| (c.position as usize, c.default_expr.as_ref().unwrap()))
71        .collect();
72
73    // ColumnMap for CHECK evaluation
74    let has_checks = table_schema.has_checks();
75    let check_col_map = if has_checks {
76        Some(ColumnMap::new(&table_schema.columns))
77    } else {
78        None
79    };
80
81    let select_rows = match &stmt.source {
82        InsertSource::Select(sq) => {
83            let insert_ctes =
84                super::materialize_all_ctes(&sq.ctes, sq.recursive, &mut |body, ctx| {
85                    exec_query_body_read(db, schema, body, ctx)
86                })?;
87            let qr = exec_query_body_read(db, schema, &sq.body, &insert_ctes)?;
88            Some(qr.rows)
89        }
90        InsertSource::Values(_) => None,
91    };
92
93    let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
94    let mut count: u64 = 0;
95
96    let pk_indices = table_schema.pk_indices();
97    let non_pk = table_schema.non_pk_indices();
98    let enc_pos = table_schema.encoding_positions();
99    let phys_count = table_schema.physical_non_pk_count();
100    let mut row = vec![Value::Null; table_schema.columns.len()];
101    let mut pk_values: Vec<Value> = vec![Value::Null; pk_indices.len()];
102    let mut value_values: Vec<Value> = vec![Value::Null; phys_count];
103    let mut key_buf: Vec<u8> = Vec::with_capacity(64);
104    let mut value_buf: Vec<u8> = Vec::with_capacity(256);
105    let mut fk_key_buf: Vec<u8> = Vec::with_capacity(64);
106
107    let values = match &stmt.source {
108        InsertSource::Values(rows) => Some(rows.as_slice()),
109        InsertSource::Select(_) => None,
110    };
111    let sel_rows = select_rows.as_deref();
112
113    let total = match (values, sel_rows) {
114        (Some(rows), _) => rows.len(),
115        (_, Some(rows)) => rows.len(),
116        _ => 0,
117    };
118
119    if let Some(sel) = sel_rows {
120        if !sel.is_empty() && sel[0].len() != insert_columns.len() {
121            return Err(SqlError::InvalidValue(format!(
122                "INSERT ... SELECT column count mismatch: expected {}, got {}",
123                insert_columns.len(),
124                sel[0].len()
125            )));
126        }
127    }
128
129    for idx in 0..total {
130        for v in row.iter_mut() {
131            *v = Value::Null;
132        }
133
134        if let Some(value_rows) = values {
135            let value_row = &value_rows[idx];
136            if value_row.len() != insert_columns.len() {
137                return Err(SqlError::InvalidValue(format!(
138                    "expected {} values, got {}",
139                    insert_columns.len(),
140                    value_row.len()
141                )));
142            }
143            for (i, expr) in value_row.iter().enumerate() {
144                let val = if let Expr::Parameter(n) = expr {
145                    params
146                        .get(n - 1)
147                        .cloned()
148                        .ok_or_else(|| SqlError::Parse(format!("unbound parameter ${n}")))?
149                } else {
150                    eval_const_expr(expr)?
151                };
152                let col_idx = col_indices[i];
153                let col = &table_schema.columns[col_idx];
154                let got_type = val.data_type();
155                row[col_idx] = if val.is_null() {
156                    Value::Null
157                } else {
158                    val.coerce_into(col.data_type)
159                        .ok_or_else(|| SqlError::TypeMismatch {
160                            expected: col.data_type.to_string(),
161                            got: got_type.to_string(),
162                        })?
163                };
164            }
165        } else if let Some(sel) = sel_rows {
166            let sel_row = &sel[idx];
167            for (i, val) in sel_row.iter().enumerate() {
168                let col_idx = col_indices[i];
169                let col = &table_schema.columns[col_idx];
170                let got_type = val.data_type();
171                row[col_idx] = if val.is_null() {
172                    Value::Null
173                } else {
174                    val.clone().coerce_into(col.data_type).ok_or_else(|| {
175                        SqlError::TypeMismatch {
176                            expected: col.data_type.to_string(),
177                            got: got_type.to_string(),
178                        }
179                    })?
180                };
181            }
182        }
183
184        for &(pos, def_expr) in &defaults {
185            let val = eval_const_expr(def_expr)?;
186            let col = &table_schema.columns[pos];
187            if val.is_null() {
188                // row[pos] already Null from init
189            } else {
190                let got_type = val.data_type();
191                row[pos] =
192                    val.coerce_into(col.data_type)
193                        .ok_or_else(|| SqlError::TypeMismatch {
194                            expected: col.data_type.to_string(),
195                            got: got_type.to_string(),
196                        })?;
197            }
198        }
199
200        for col in &table_schema.columns {
201            if !col.nullable && row[col.position as usize].is_null() {
202                return Err(SqlError::NotNullViolation(col.name.clone()));
203            }
204        }
205
206        if let Some(ref col_map) = check_col_map {
207            for col in &table_schema.columns {
208                if let Some(ref check) = col.check_expr {
209                    let result = eval_expr(check, &EvalCtx::new(col_map, &row))?;
210                    if !is_truthy(&result) && !result.is_null() {
211                        let name = col.check_name.as_deref().unwrap_or(&col.name);
212                        return Err(SqlError::CheckViolation(name.to_string()));
213                    }
214                }
215            }
216            for tc in &table_schema.check_constraints {
217                let result = eval_expr(&tc.expr, &EvalCtx::new(col_map, &row))?;
218                if !is_truthy(&result) && !result.is_null() {
219                    let name = tc.name.as_deref().unwrap_or(&tc.sql);
220                    return Err(SqlError::CheckViolation(name.to_string()));
221                }
222            }
223        }
224
225        for fk in &table_schema.foreign_keys {
226            let any_null = fk.columns.iter().any(|&ci| row[ci as usize].is_null());
227            if any_null {
228                continue; // MATCH SIMPLE: skip if any FK col is NULL
229            }
230            let fk_vals: Vec<Value> = fk
231                .columns
232                .iter()
233                .map(|&ci| row[ci as usize].clone())
234                .collect();
235            fk_key_buf.clear();
236            encode_composite_key_into(&fk_vals, &mut fk_key_buf);
237            let found = wtx
238                .table_get(fk.foreign_table.as_bytes(), &fk_key_buf)
239                .map_err(SqlError::Storage)?;
240            if found.is_none() {
241                let name = fk.name.as_deref().unwrap_or(&fk.foreign_table);
242                return Err(SqlError::ForeignKeyViolation(name.to_string()));
243            }
244        }
245
246        for (j, &i) in pk_indices.iter().enumerate() {
247            pk_values[j] = std::mem::replace(&mut row[i], Value::Null);
248        }
249        encode_composite_key_into(&pk_values, &mut key_buf);
250
251        for (j, &i) in non_pk.iter().enumerate() {
252            value_values[enc_pos[j] as usize] = std::mem::replace(&mut row[i], Value::Null);
253        }
254        encode_row_into(&value_values, &mut value_buf);
255
256        if key_buf.len() > citadel_core::MAX_KEY_SIZE {
257            return Err(SqlError::KeyTooLarge {
258                size: key_buf.len(),
259                max: citadel_core::MAX_KEY_SIZE,
260            });
261        }
262        if value_buf.len() > citadel_core::MAX_INLINE_VALUE_SIZE {
263            return Err(SqlError::RowTooLarge {
264                size: value_buf.len(),
265                max: citadel_core::MAX_INLINE_VALUE_SIZE,
266            });
267        }
268
269        let is_new = wtx
270            .table_insert(stmt.table.as_bytes(), &key_buf, &value_buf)
271            .map_err(SqlError::Storage)?;
272        if !is_new {
273            return Err(SqlError::DuplicateKey);
274        }
275
276        if !table_schema.indices.is_empty() {
277            for (j, &i) in pk_indices.iter().enumerate() {
278                row[i] = pk_values[j].clone();
279            }
280            for (j, &i) in non_pk.iter().enumerate() {
281                row[i] = std::mem::replace(&mut value_values[enc_pos[j] as usize], Value::Null);
282            }
283            insert_index_entries(&mut wtx, table_schema, &row, &pk_values)?;
284        }
285        count += 1;
286    }
287
288    wtx.commit().map_err(SqlError::Storage)?;
289    Ok(ExecutionResult::RowsAffected(count))
290}
291
292pub(super) fn has_subquery(expr: &Expr) -> bool {
293    crate::parser::has_subquery(expr)
294}
295
296pub(super) fn stmt_has_subquery(stmt: &SelectStmt) -> bool {
297    if let Some(ref w) = stmt.where_clause {
298        if has_subquery(w) {
299            return true;
300        }
301    }
302    if let Some(ref h) = stmt.having {
303        if has_subquery(h) {
304            return true;
305        }
306    }
307    for col in &stmt.columns {
308        if let SelectColumn::Expr { expr, .. } = col {
309            if has_subquery(expr) {
310                return true;
311            }
312        }
313    }
314    for ob in &stmt.order_by {
315        if has_subquery(&ob.expr) {
316            return true;
317        }
318    }
319    for join in &stmt.joins {
320        if let Some(ref on_expr) = join.on_clause {
321            if has_subquery(on_expr) {
322                return true;
323            }
324        }
325    }
326    false
327}
328
329pub(super) fn materialize_expr(
330    expr: &Expr,
331    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
332) -> Result<Expr> {
333    match expr {
334        Expr::InSubquery {
335            expr: e,
336            subquery,
337            negated,
338        } => {
339            let inner = materialize_expr(e, exec_sub)?;
340            let qr = exec_sub(subquery)?;
341            if !qr.columns.is_empty() && qr.columns.len() != 1 {
342                return Err(SqlError::SubqueryMultipleColumns);
343            }
344            let mut values = std::collections::HashSet::new();
345            let mut has_null = false;
346            for row in &qr.rows {
347                if row[0].is_null() {
348                    has_null = true;
349                } else {
350                    values.insert(row[0].clone());
351                }
352            }
353            Ok(Expr::InSet {
354                expr: Box::new(inner),
355                values,
356                has_null,
357                negated: *negated,
358            })
359        }
360        Expr::ScalarSubquery(subquery) => {
361            let qr = exec_sub(subquery)?;
362            if qr.rows.len() > 1 {
363                return Err(SqlError::SubqueryMultipleRows);
364            }
365            let val = if qr.rows.is_empty() {
366                Value::Null
367            } else {
368                qr.rows[0][0].clone()
369            };
370            Ok(Expr::Literal(val))
371        }
372        Expr::Exists { subquery, negated } => {
373            let qr = exec_sub(subquery)?;
374            let exists = !qr.rows.is_empty();
375            let result = if *negated { !exists } else { exists };
376            Ok(Expr::Literal(Value::Boolean(result)))
377        }
378        Expr::InList {
379            expr: e,
380            list,
381            negated,
382        } => {
383            let inner = materialize_expr(e, exec_sub)?;
384            let items = list
385                .iter()
386                .map(|item| materialize_expr(item, exec_sub))
387                .collect::<Result<Vec<_>>>()?;
388            Ok(Expr::InList {
389                expr: Box::new(inner),
390                list: items,
391                negated: *negated,
392            })
393        }
394        Expr::BinaryOp { left, op, right } => Ok(Expr::BinaryOp {
395            left: Box::new(materialize_expr(left, exec_sub)?),
396            op: *op,
397            right: Box::new(materialize_expr(right, exec_sub)?),
398        }),
399        Expr::UnaryOp { op, expr: e } => Ok(Expr::UnaryOp {
400            op: *op,
401            expr: Box::new(materialize_expr(e, exec_sub)?),
402        }),
403        Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(materialize_expr(e, exec_sub)?))),
404        Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(materialize_expr(e, exec_sub)?))),
405        Expr::InSet {
406            expr: e,
407            values,
408            has_null,
409            negated,
410        } => Ok(Expr::InSet {
411            expr: Box::new(materialize_expr(e, exec_sub)?),
412            values: values.clone(),
413            has_null: *has_null,
414            negated: *negated,
415        }),
416        Expr::Between {
417            expr: e,
418            low,
419            high,
420            negated,
421        } => Ok(Expr::Between {
422            expr: Box::new(materialize_expr(e, exec_sub)?),
423            low: Box::new(materialize_expr(low, exec_sub)?),
424            high: Box::new(materialize_expr(high, exec_sub)?),
425            negated: *negated,
426        }),
427        Expr::Like {
428            expr: e,
429            pattern,
430            escape,
431            negated,
432        } => {
433            let esc = escape
434                .as_ref()
435                .map(|es| materialize_expr(es, exec_sub).map(Box::new))
436                .transpose()?;
437            Ok(Expr::Like {
438                expr: Box::new(materialize_expr(e, exec_sub)?),
439                pattern: Box::new(materialize_expr(pattern, exec_sub)?),
440                escape: esc,
441                negated: *negated,
442            })
443        }
444        Expr::Case {
445            operand,
446            conditions,
447            else_result,
448        } => {
449            let op = operand
450                .as_ref()
451                .map(|e| materialize_expr(e, exec_sub).map(Box::new))
452                .transpose()?;
453            let conds = conditions
454                .iter()
455                .map(|(c, r)| {
456                    Ok((
457                        materialize_expr(c, exec_sub)?,
458                        materialize_expr(r, exec_sub)?,
459                    ))
460                })
461                .collect::<Result<Vec<_>>>()?;
462            let else_r = else_result
463                .as_ref()
464                .map(|e| materialize_expr(e, exec_sub).map(Box::new))
465                .transpose()?;
466            Ok(Expr::Case {
467                operand: op,
468                conditions: conds,
469                else_result: else_r,
470            })
471        }
472        Expr::Coalesce(args) => {
473            let materialized = args
474                .iter()
475                .map(|a| materialize_expr(a, exec_sub))
476                .collect::<Result<Vec<_>>>()?;
477            Ok(Expr::Coalesce(materialized))
478        }
479        Expr::Cast { expr: e, data_type } => Ok(Expr::Cast {
480            expr: Box::new(materialize_expr(e, exec_sub)?),
481            data_type: *data_type,
482        }),
483        Expr::Function {
484            name,
485            args,
486            distinct,
487        } => {
488            let materialized = args
489                .iter()
490                .map(|a| materialize_expr(a, exec_sub))
491                .collect::<Result<Vec<_>>>()?;
492            Ok(Expr::Function {
493                name: name.clone(),
494                args: materialized,
495                distinct: *distinct,
496            })
497        }
498        other => Ok(other.clone()),
499    }
500}
501
502pub(super) fn materialize_stmt(
503    stmt: &SelectStmt,
504    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
505) -> Result<SelectStmt> {
506    let where_clause = stmt
507        .where_clause
508        .as_ref()
509        .map(|e| materialize_expr(e, exec_sub))
510        .transpose()?;
511    let having = stmt
512        .having
513        .as_ref()
514        .map(|e| materialize_expr(e, exec_sub))
515        .transpose()?;
516    let columns = stmt
517        .columns
518        .iter()
519        .map(|c| match c {
520            SelectColumn::AllColumns => Ok(SelectColumn::AllColumns),
521            SelectColumn::Expr { expr, alias } => Ok(SelectColumn::Expr {
522                expr: materialize_expr(expr, exec_sub)?,
523                alias: alias.clone(),
524            }),
525        })
526        .collect::<Result<Vec<_>>>()?;
527    let order_by = stmt
528        .order_by
529        .iter()
530        .map(|ob| {
531            Ok(OrderByItem {
532                expr: materialize_expr(&ob.expr, exec_sub)?,
533                descending: ob.descending,
534                nulls_first: ob.nulls_first,
535            })
536        })
537        .collect::<Result<Vec<_>>>()?;
538    let joins = stmt
539        .joins
540        .iter()
541        .map(|j| {
542            let on_clause = j
543                .on_clause
544                .as_ref()
545                .map(|e| materialize_expr(e, exec_sub))
546                .transpose()?;
547            Ok(JoinClause {
548                join_type: j.join_type,
549                table: j.table.clone(),
550                on_clause,
551            })
552        })
553        .collect::<Result<Vec<_>>>()?;
554    let group_by = stmt
555        .group_by
556        .iter()
557        .map(|e| materialize_expr(e, exec_sub))
558        .collect::<Result<Vec<_>>>()?;
559    Ok(SelectStmt {
560        columns,
561        from: stmt.from.clone(),
562        from_alias: stmt.from_alias.clone(),
563        joins,
564        distinct: stmt.distinct,
565        where_clause,
566        order_by,
567        limit: stmt.limit.clone(),
568        offset: stmt.offset.clone(),
569        group_by,
570        having,
571    })
572}
573
574pub(super) fn exec_subquery_read(
575    db: &Database,
576    schema: &SchemaManager,
577    stmt: &SelectStmt,
578    ctes: &CteContext,
579) -> Result<QueryResult> {
580    match super::exec_select(db, schema, stmt, ctes)? {
581        ExecutionResult::Query(qr) => Ok(qr),
582        _ => Ok(QueryResult {
583            columns: vec![],
584            rows: vec![],
585        }),
586    }
587}
588
589pub(super) fn exec_subquery_write(
590    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
591    schema: &SchemaManager,
592    stmt: &SelectStmt,
593    ctes: &CteContext,
594) -> Result<QueryResult> {
595    match super::exec_select_in_txn(wtx, schema, stmt, ctes)? {
596        ExecutionResult::Query(qr) => Ok(qr),
597        _ => Ok(QueryResult {
598            columns: vec![],
599            rows: vec![],
600        }),
601    }
602}
603
604pub(super) fn update_has_subquery(stmt: &UpdateStmt) -> bool {
605    stmt.where_clause.as_ref().is_some_and(has_subquery)
606        || stmt.assignments.iter().any(|(_, e)| has_subquery(e))
607}
608
609pub(super) fn materialize_update(
610    stmt: &UpdateStmt,
611    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
612) -> Result<UpdateStmt> {
613    let where_clause = stmt
614        .where_clause
615        .as_ref()
616        .map(|e| materialize_expr(e, exec_sub))
617        .transpose()?;
618    let assignments = stmt
619        .assignments
620        .iter()
621        .map(|(name, expr)| Ok((name.clone(), materialize_expr(expr, exec_sub)?)))
622        .collect::<Result<Vec<_>>>()?;
623    Ok(UpdateStmt {
624        table: stmt.table.clone(),
625        assignments,
626        where_clause,
627    })
628}
629
630pub(super) fn delete_has_subquery(stmt: &DeleteStmt) -> bool {
631    stmt.where_clause.as_ref().is_some_and(has_subquery)
632}
633
634pub(super) fn materialize_delete(
635    stmt: &DeleteStmt,
636    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
637) -> Result<DeleteStmt> {
638    let where_clause = stmt
639        .where_clause
640        .as_ref()
641        .map(|e| materialize_expr(e, exec_sub))
642        .transpose()?;
643    Ok(DeleteStmt {
644        table: stmt.table.clone(),
645        where_clause,
646    })
647}
648
649pub(super) fn insert_has_subquery(stmt: &InsertStmt) -> bool {
650    match &stmt.source {
651        InsertSource::Values(rows) => rows.iter().any(|row| row.iter().any(has_subquery)),
652        // SELECT source subqueries are handled by exec_select's correlated/non-correlated paths
653        InsertSource::Select(_) => false,
654    }
655}
656
657pub(super) fn materialize_insert(
658    stmt: &InsertStmt,
659    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
660) -> Result<InsertStmt> {
661    let source = match &stmt.source {
662        InsertSource::Values(rows) => {
663            let mat = rows
664                .iter()
665                .map(|row| {
666                    row.iter()
667                        .map(|e| materialize_expr(e, exec_sub))
668                        .collect::<Result<Vec<_>>>()
669                })
670                .collect::<Result<Vec<_>>>()?;
671            InsertSource::Values(mat)
672        }
673        InsertSource::Select(sq) => {
674            let ctes = sq
675                .ctes
676                .iter()
677                .map(|c| {
678                    Ok(CteDefinition {
679                        name: c.name.clone(),
680                        column_aliases: c.column_aliases.clone(),
681                        body: materialize_query_body(&c.body, exec_sub)?,
682                    })
683                })
684                .collect::<Result<Vec<_>>>()?;
685            let body = materialize_query_body(&sq.body, exec_sub)?;
686            InsertSource::Select(Box::new(SelectQuery {
687                ctes,
688                recursive: sq.recursive,
689                body,
690            }))
691        }
692    };
693    Ok(InsertStmt {
694        table: stmt.table.clone(),
695        columns: stmt.columns.clone(),
696        source,
697    })
698}
699
700pub(super) fn materialize_query_body(
701    body: &QueryBody,
702    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
703) -> Result<QueryBody> {
704    match body {
705        QueryBody::Select(sel) => Ok(QueryBody::Select(Box::new(materialize_stmt(
706            sel, exec_sub,
707        )?))),
708        QueryBody::Compound(comp) => Ok(QueryBody::Compound(Box::new(CompoundSelect {
709            op: comp.op.clone(),
710            all: comp.all,
711            left: Box::new(materialize_query_body(&comp.left, exec_sub)?),
712            right: Box::new(materialize_query_body(&comp.right, exec_sub)?),
713            order_by: comp.order_by.clone(),
714            limit: comp.limit.clone(),
715            offset: comp.offset.clone(),
716        }))),
717    }
718}
719
720pub(super) fn exec_query_body(
721    db: &Database,
722    schema: &SchemaManager,
723    body: &QueryBody,
724    ctes: &CteContext,
725) -> Result<ExecutionResult> {
726    match body {
727        QueryBody::Select(sel) => super::exec_select(db, schema, sel, ctes),
728        QueryBody::Compound(comp) => exec_compound_select(db, schema, comp, ctes),
729    }
730}
731
732pub(super) fn exec_query_body_in_txn(
733    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
734    schema: &SchemaManager,
735    body: &QueryBody,
736    ctes: &CteContext,
737) -> Result<ExecutionResult> {
738    match body {
739        QueryBody::Select(sel) => super::exec_select_in_txn(wtx, schema, sel, ctes),
740        QueryBody::Compound(comp) => exec_compound_select_in_txn(wtx, schema, comp, ctes),
741    }
742}
743
744pub(super) fn exec_query_body_read(
745    db: &Database,
746    schema: &SchemaManager,
747    body: &QueryBody,
748    ctes: &CteContext,
749) -> Result<QueryResult> {
750    match exec_query_body(db, schema, body, ctes)? {
751        ExecutionResult::Query(qr) => Ok(qr),
752        _ => Ok(QueryResult {
753            columns: vec![],
754            rows: vec![],
755        }),
756    }
757}
758
759pub(super) fn exec_query_body_write(
760    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
761    schema: &SchemaManager,
762    body: &QueryBody,
763    ctes: &CteContext,
764) -> Result<QueryResult> {
765    match exec_query_body_in_txn(wtx, schema, body, ctes)? {
766        ExecutionResult::Query(qr) => Ok(qr),
767        _ => Ok(QueryResult {
768            columns: vec![],
769            rows: vec![],
770        }),
771    }
772}
773
774pub(super) fn exec_compound_select(
775    db: &Database,
776    schema: &SchemaManager,
777    comp: &CompoundSelect,
778    ctes: &CteContext,
779) -> Result<ExecutionResult> {
780    let left_qr = match exec_query_body(db, schema, &comp.left, ctes)? {
781        ExecutionResult::Query(qr) => qr,
782        _ => QueryResult {
783            columns: vec![],
784            rows: vec![],
785        },
786    };
787    let right_qr = match exec_query_body(db, schema, &comp.right, ctes)? {
788        ExecutionResult::Query(qr) => qr,
789        _ => QueryResult {
790            columns: vec![],
791            rows: vec![],
792        },
793    };
794    apply_set_operation(comp, left_qr, right_qr)
795}
796
797pub(super) fn exec_compound_select_in_txn(
798    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
799    schema: &SchemaManager,
800    comp: &CompoundSelect,
801    ctes: &CteContext,
802) -> Result<ExecutionResult> {
803    let left_qr = match exec_query_body_in_txn(wtx, schema, &comp.left, ctes)? {
804        ExecutionResult::Query(qr) => qr,
805        _ => QueryResult {
806            columns: vec![],
807            rows: vec![],
808        },
809    };
810    let right_qr = match exec_query_body_in_txn(wtx, schema, &comp.right, ctes)? {
811        ExecutionResult::Query(qr) => qr,
812        _ => QueryResult {
813            columns: vec![],
814            rows: vec![],
815        },
816    };
817    apply_set_operation(comp, left_qr, right_qr)
818}
819
820pub(super) fn apply_set_operation(
821    comp: &CompoundSelect,
822    left_qr: QueryResult,
823    right_qr: QueryResult,
824) -> Result<ExecutionResult> {
825    if !left_qr.columns.is_empty()
826        && !right_qr.columns.is_empty()
827        && left_qr.columns.len() != right_qr.columns.len()
828    {
829        return Err(SqlError::CompoundColumnCountMismatch {
830            left: left_qr.columns.len(),
831            right: right_qr.columns.len(),
832        });
833    }
834
835    let columns = left_qr.columns;
836
837    let mut rows = match (&comp.op, comp.all) {
838        (SetOp::Union, true) => {
839            let mut rows = left_qr.rows;
840            rows.extend(right_qr.rows);
841            rows
842        }
843        (SetOp::Union, false) => {
844            let mut seen: std::collections::HashSet<Vec<Value>> = std::collections::HashSet::new();
845            let mut rows = Vec::new();
846            for row in left_qr.rows.into_iter().chain(right_qr.rows) {
847                if !seen.contains(&row) {
848                    seen.insert(row.clone());
849                    rows.push(row);
850                }
851            }
852            rows
853        }
854        (SetOp::Intersect, true) => {
855            let mut right_counts: FxHashMap<Vec<Value>, usize> = FxHashMap::default();
856            for row in &right_qr.rows {
857                *right_counts.entry(row.clone()).or_insert(0) += 1;
858            }
859            let mut rows = Vec::new();
860            for row in left_qr.rows {
861                if let Some(count) = right_counts.get_mut(&row) {
862                    if *count > 0 {
863                        *count -= 1;
864                        rows.push(row);
865                    }
866                }
867            }
868            rows
869        }
870        (SetOp::Intersect, false) => {
871            let right_set: std::collections::HashSet<Vec<Value>> =
872                right_qr.rows.into_iter().collect();
873            let mut seen: std::collections::HashSet<Vec<Value>> = std::collections::HashSet::new();
874            let mut rows = Vec::new();
875            for row in left_qr.rows {
876                if right_set.contains(&row) && !seen.contains(&row) {
877                    seen.insert(row.clone());
878                    rows.push(row);
879                }
880            }
881            rows
882        }
883        (SetOp::Except, true) => {
884            let mut right_counts: FxHashMap<Vec<Value>, usize> = FxHashMap::default();
885            for row in &right_qr.rows {
886                *right_counts.entry(row.clone()).or_insert(0) += 1;
887            }
888            let mut rows = Vec::new();
889            for row in left_qr.rows {
890                if let Some(count) = right_counts.get_mut(&row) {
891                    if *count > 0 {
892                        *count -= 1;
893                        continue;
894                    }
895                }
896                rows.push(row);
897            }
898            rows
899        }
900        (SetOp::Except, false) => {
901            let right_set: std::collections::HashSet<Vec<Value>> =
902                right_qr.rows.into_iter().collect();
903            let mut seen: std::collections::HashSet<Vec<Value>> = std::collections::HashSet::new();
904            let mut rows = Vec::new();
905            for row in left_qr.rows {
906                if !right_set.contains(&row) && !seen.contains(&row) {
907                    seen.insert(row.clone());
908                    rows.push(row);
909                }
910            }
911            rows
912        }
913    };
914
915    if !comp.order_by.is_empty() {
916        let col_defs: Vec<crate::types::ColumnDef> = columns
917            .iter()
918            .enumerate()
919            .map(|(i, name)| crate::types::ColumnDef {
920                name: name.clone(),
921                data_type: crate::types::DataType::Null,
922                nullable: true,
923                position: i as u16,
924                default_expr: None,
925                default_sql: None,
926                check_expr: None,
927                check_sql: None,
928                check_name: None,
929                is_with_timezone: false,
930            })
931            .collect();
932        sort_rows(&mut rows, &comp.order_by, &col_defs)?;
933    }
934
935    if let Some(ref offset_expr) = comp.offset {
936        let offset = eval_const_int(offset_expr)?.max(0) as usize;
937        if offset < rows.len() {
938            rows = rows.split_off(offset);
939        } else {
940            rows.clear();
941        }
942    }
943
944    if let Some(ref limit_expr) = comp.limit {
945        let limit = eval_const_int(limit_expr)?.max(0) as usize;
946        rows.truncate(limit);
947    }
948
949    Ok(ExecutionResult::Query(QueryResult { columns, rows }))
950}
951
952struct InsertBufs {
953    row: Vec<Value>,
954    pk_values: Vec<Value>,
955    value_values: Vec<Value>,
956    key_buf: Vec<u8>,
957    value_buf: Vec<u8>,
958    col_indices: Vec<usize>,
959    fk_key_buf: Vec<u8>,
960}
961
962impl InsertBufs {
963    fn new() -> Self {
964        Self {
965            row: Vec::new(),
966            pk_values: Vec::new(),
967            value_values: Vec::new(),
968            key_buf: Vec::with_capacity(64),
969            value_buf: Vec::with_capacity(256),
970            col_indices: Vec::new(),
971            fk_key_buf: Vec::with_capacity(64),
972        }
973    }
974}
975
976thread_local! {
977    static INSERT_SCRATCH: RefCell<InsertBufs> = RefCell::new(InsertBufs::new());
978}
979
980fn with_insert_scratch<R>(f: impl FnOnce(&mut InsertBufs) -> R) -> R {
981    INSERT_SCRATCH.with(|slot| f(&mut slot.borrow_mut()))
982}
983
984pub fn exec_insert_in_txn(
985    wtx: &mut WriteTxn<'_>,
986    schema: &SchemaManager,
987    stmt: &InsertStmt,
988    params: &[Value],
989) -> Result<ExecutionResult> {
990    with_insert_scratch(|bufs| exec_insert_in_txn_impl(wtx, schema, stmt, params, bufs, None))
991}
992
993fn exec_insert_in_txn_cached(
994    wtx: &mut WriteTxn<'_>,
995    schema: &SchemaManager,
996    stmt: &InsertStmt,
997    params: &[Value],
998    cache: &InsertCache,
999) -> Result<ExecutionResult> {
1000    with_insert_scratch(|bufs| {
1001        exec_insert_in_txn_impl(wtx, schema, stmt, params, bufs, Some(cache))
1002    })
1003}
1004
1005fn exec_insert_in_txn_impl(
1006    wtx: &mut WriteTxn<'_>,
1007    schema: &SchemaManager,
1008    stmt: &InsertStmt,
1009    params: &[Value],
1010    bufs: &mut InsertBufs,
1011    cache: Option<&InsertCache>,
1012) -> Result<ExecutionResult> {
1013    let empty_ctes = CteContext::default();
1014    let materialized;
1015    let has_sub = match cache {
1016        Some(c) => c.has_subquery,
1017        None => insert_has_subquery(stmt),
1018    };
1019    let stmt = if has_sub {
1020        materialized = materialize_insert(stmt, &mut |sub| {
1021            exec_subquery_write(wtx, schema, sub, &empty_ctes)
1022        })?;
1023        &materialized
1024    } else {
1025        stmt
1026    };
1027
1028    let table_schema = schema
1029        .get(&stmt.table)
1030        .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
1031
1032    let default_columns;
1033    let insert_columns: &[String] = if stmt.columns.is_empty() {
1034        default_columns = table_schema
1035            .columns
1036            .iter()
1037            .map(|c| c.name.clone())
1038            .collect::<Vec<_>>();
1039        &default_columns
1040    } else {
1041        &stmt.columns
1042    };
1043
1044    bufs.col_indices.clear();
1045    if let Some(c) = cache {
1046        bufs.col_indices.extend_from_slice(&c.col_indices);
1047    } else {
1048        for name in insert_columns {
1049            bufs.col_indices.push(
1050                table_schema
1051                    .column_index(name)
1052                    .ok_or_else(|| SqlError::ColumnNotFound(name.clone()))?,
1053            );
1054        }
1055    }
1056
1057    let any_defaults = match cache {
1058        Some(c) => c.any_defaults,
1059        None => table_schema
1060            .columns
1061            .iter()
1062            .any(|c| c.default_expr.is_some()),
1063    };
1064    let defaults: Vec<(usize, &Expr)> = if any_defaults {
1065        table_schema
1066            .columns
1067            .iter()
1068            .filter(|c| {
1069                c.default_expr.is_some() && !bufs.col_indices.contains(&(c.position as usize))
1070            })
1071            .map(|c| (c.position as usize, c.default_expr.as_ref().unwrap()))
1072            .collect()
1073    } else {
1074        Vec::new()
1075    };
1076
1077    let has_checks = match cache {
1078        Some(c) => c.has_checks,
1079        None => table_schema.has_checks(),
1080    };
1081    let check_col_map = if has_checks {
1082        Some(ColumnMap::new(&table_schema.columns))
1083    } else {
1084        None
1085    };
1086
1087    let pk_indices = table_schema.pk_indices();
1088    let non_pk = table_schema.non_pk_indices();
1089    let enc_pos = table_schema.encoding_positions();
1090    let phys_count = table_schema.physical_non_pk_count();
1091    let dropped = table_schema.dropped_non_pk_slots();
1092
1093    bufs.row.resize(table_schema.columns.len(), Value::Null);
1094    bufs.pk_values.resize(pk_indices.len(), Value::Null);
1095    bufs.value_values.resize(phys_count, Value::Null);
1096
1097    let table_bytes = stmt.table.as_bytes();
1098    let has_fks = !table_schema.foreign_keys.is_empty();
1099    let has_indices = !table_schema.indices.is_empty();
1100    let has_defaults = !defaults.is_empty();
1101
1102    let select_rows = match &stmt.source {
1103        InsertSource::Select(sq) => {
1104            let insert_ctes =
1105                super::materialize_all_ctes(&sq.ctes, sq.recursive, &mut |body, ctx| {
1106                    exec_query_body_write(wtx, schema, body, ctx)
1107                })?;
1108            let qr = exec_query_body_write(wtx, schema, &sq.body, &insert_ctes)?;
1109            Some(qr.rows)
1110        }
1111        InsertSource::Values(_) => None,
1112    };
1113
1114    let mut count: u64 = 0;
1115
1116    let values = match &stmt.source {
1117        InsertSource::Values(rows) => Some(rows.as_slice()),
1118        InsertSource::Select(_) => None,
1119    };
1120    let sel_rows = select_rows.as_deref();
1121
1122    let total = match (values, sel_rows) {
1123        (Some(rows), _) => rows.len(),
1124        (_, Some(rows)) => rows.len(),
1125        _ => 0,
1126    };
1127
1128    if let Some(sel) = sel_rows {
1129        if !sel.is_empty() && sel[0].len() != insert_columns.len() {
1130            return Err(SqlError::InvalidValue(format!(
1131                "INSERT ... SELECT column count mismatch: expected {}, got {}",
1132                insert_columns.len(),
1133                sel[0].len()
1134            )));
1135        }
1136    }
1137
1138    for idx in 0..total {
1139        for v in bufs.row.iter_mut() {
1140            *v = Value::Null;
1141        }
1142
1143        if let Some(value_rows) = values {
1144            let value_row = &value_rows[idx];
1145            if value_row.len() != insert_columns.len() {
1146                return Err(SqlError::InvalidValue(format!(
1147                    "expected {} values, got {}",
1148                    insert_columns.len(),
1149                    value_row.len()
1150                )));
1151            }
1152            for (i, expr) in value_row.iter().enumerate() {
1153                let val = match expr {
1154                    Expr::Parameter(n) => params
1155                        .get(n - 1)
1156                        .cloned()
1157                        .ok_or_else(|| SqlError::Parse(format!("unbound parameter ${n}")))?,
1158                    Expr::Literal(v) => v.clone(),
1159                    _ => eval_const_expr(expr)?,
1160                };
1161                let col_idx = bufs.col_indices[i];
1162                let col = &table_schema.columns[col_idx];
1163                let got_type = val.data_type();
1164                bufs.row[col_idx] = if val.is_null() {
1165                    Value::Null
1166                } else {
1167                    val.coerce_into(col.data_type)
1168                        .ok_or_else(|| SqlError::TypeMismatch {
1169                            expected: col.data_type.to_string(),
1170                            got: got_type.to_string(),
1171                        })?
1172                };
1173            }
1174        } else if let Some(sel) = sel_rows {
1175            let sel_row = &sel[idx];
1176            for (i, val) in sel_row.iter().enumerate() {
1177                let col_idx = bufs.col_indices[i];
1178                let col = &table_schema.columns[col_idx];
1179                let got_type = val.data_type();
1180                bufs.row[col_idx] = if val.is_null() {
1181                    Value::Null
1182                } else {
1183                    val.clone().coerce_into(col.data_type).ok_or_else(|| {
1184                        SqlError::TypeMismatch {
1185                            expected: col.data_type.to_string(),
1186                            got: got_type.to_string(),
1187                        }
1188                    })?
1189                };
1190            }
1191        }
1192
1193        if has_defaults {
1194            for &(pos, def_expr) in &defaults {
1195                let val = eval_const_expr(def_expr)?;
1196                let col = &table_schema.columns[pos];
1197                if !val.is_null() {
1198                    let got_type = val.data_type();
1199                    bufs.row[pos] =
1200                        val.coerce_into(col.data_type)
1201                            .ok_or_else(|| SqlError::TypeMismatch {
1202                                expected: col.data_type.to_string(),
1203                                got: got_type.to_string(),
1204                            })?;
1205                }
1206            }
1207        }
1208
1209        for col in &table_schema.columns {
1210            if !col.nullable && bufs.row[col.position as usize].is_null() {
1211                return Err(SqlError::NotNullViolation(col.name.clone()));
1212            }
1213        }
1214
1215        if let Some(ref col_map) = check_col_map {
1216            for col in &table_schema.columns {
1217                if let Some(ref check) = col.check_expr {
1218                    let result = eval_expr(check, &EvalCtx::new(col_map, &bufs.row))?;
1219                    if !is_truthy(&result) && !result.is_null() {
1220                        let name = col.check_name.as_deref().unwrap_or(&col.name);
1221                        return Err(SqlError::CheckViolation(name.to_string()));
1222                    }
1223                }
1224            }
1225            for tc in &table_schema.check_constraints {
1226                let result = eval_expr(&tc.expr, &EvalCtx::new(col_map, &bufs.row))?;
1227                if !is_truthy(&result) && !result.is_null() {
1228                    let name = tc.name.as_deref().unwrap_or(&tc.sql);
1229                    return Err(SqlError::CheckViolation(name.to_string()));
1230                }
1231            }
1232        }
1233
1234        if has_fks {
1235            for fk in &table_schema.foreign_keys {
1236                let any_null = fk.columns.iter().any(|&ci| bufs.row[ci as usize].is_null());
1237                if any_null {
1238                    continue;
1239                }
1240                let fk_vals: Vec<Value> = fk
1241                    .columns
1242                    .iter()
1243                    .map(|&ci| bufs.row[ci as usize].clone())
1244                    .collect();
1245                bufs.fk_key_buf.clear();
1246                encode_composite_key_into(&fk_vals, &mut bufs.fk_key_buf);
1247                let found = wtx
1248                    .table_get(fk.foreign_table.as_bytes(), &bufs.fk_key_buf)
1249                    .map_err(SqlError::Storage)?;
1250                if found.is_none() {
1251                    let name = fk.name.as_deref().unwrap_or(&fk.foreign_table);
1252                    return Err(SqlError::ForeignKeyViolation(name.to_string()));
1253                }
1254            }
1255        }
1256
1257        for (j, &i) in pk_indices.iter().enumerate() {
1258            bufs.pk_values[j] = std::mem::replace(&mut bufs.row[i], Value::Null);
1259        }
1260        encode_composite_key_into(&bufs.pk_values, &mut bufs.key_buf);
1261
1262        for &slot in dropped {
1263            bufs.value_values[slot as usize] = Value::Null;
1264        }
1265        for (j, &i) in non_pk.iter().enumerate() {
1266            bufs.value_values[enc_pos[j] as usize] =
1267                std::mem::replace(&mut bufs.row[i], Value::Null);
1268        }
1269        encode_row_into(&bufs.value_values, &mut bufs.value_buf);
1270
1271        if bufs.key_buf.len() > citadel_core::MAX_KEY_SIZE {
1272            return Err(SqlError::KeyTooLarge {
1273                size: bufs.key_buf.len(),
1274                max: citadel_core::MAX_KEY_SIZE,
1275            });
1276        }
1277        if bufs.value_buf.len() > citadel_core::MAX_INLINE_VALUE_SIZE {
1278            return Err(SqlError::RowTooLarge {
1279                size: bufs.value_buf.len(),
1280                max: citadel_core::MAX_INLINE_VALUE_SIZE,
1281            });
1282        }
1283
1284        let is_new = wtx
1285            .table_insert(table_bytes, &bufs.key_buf, &bufs.value_buf)
1286            .map_err(SqlError::Storage)?;
1287        if !is_new {
1288            return Err(SqlError::DuplicateKey);
1289        }
1290
1291        if has_indices {
1292            for (j, &i) in pk_indices.iter().enumerate() {
1293                bufs.row[i] = bufs.pk_values[j].clone();
1294            }
1295            for (j, &i) in non_pk.iter().enumerate() {
1296                bufs.row[i] =
1297                    std::mem::replace(&mut bufs.value_values[enc_pos[j] as usize], Value::Null);
1298            }
1299            insert_index_entries(wtx, table_schema, &bufs.row, &bufs.pk_values)?;
1300        }
1301        count += 1;
1302    }
1303
1304    Ok(ExecutionResult::RowsAffected(count))
1305}
1306
1307pub struct CompiledInsert {
1308    table_lower: String,
1309    cached: Option<InsertCache>,
1310}
1311
/// Per-statement facts computed once at compile time so the in-transaction
/// insert path does not have to derive them again on every execution.
struct InsertCache {
    /// Schema column index for each column the statement supplies.
    col_indices: Vec<usize>,
    /// Result of `insert_has_subquery` for the statement.
    has_subquery: bool,
    /// Whether any column of the table declares a default expression.
    any_defaults: bool,
    /// Result of the table schema's `has_checks()`.
    has_checks: bool,
}
1318
1319impl CompiledInsert {
1320    pub fn try_compile(schema: &SchemaManager, stmt: &InsertStmt) -> Option<Self> {
1321        let lower = stmt.table.to_ascii_lowercase();
1322        let cached = if let Some(ts) = schema.get(&lower) {
1323            let insert_columns: Vec<&str> = if stmt.columns.is_empty() {
1324                ts.columns.iter().map(|c| c.name.as_str()).collect()
1325            } else {
1326                stmt.columns.iter().map(|s| s.as_str()).collect()
1327            };
1328            let mut col_indices = Vec::with_capacity(insert_columns.len());
1329            for name in &insert_columns {
1330                col_indices.push(ts.column_index(name)?);
1331            }
1332            Some(InsertCache {
1333                col_indices,
1334                has_subquery: insert_has_subquery(stmt),
1335                any_defaults: ts.columns.iter().any(|c| c.default_expr.is_some()),
1336                has_checks: ts.has_checks(),
1337            })
1338        } else if schema.get_view(&lower).is_some() {
1339            None
1340        } else {
1341            return None;
1342        };
1343        Some(Self {
1344            table_lower: lower,
1345            cached,
1346        })
1347    }
1348}
1349
1350impl CompiledPlan for CompiledInsert {
1351    fn execute(
1352        &self,
1353        db: &Database,
1354        schema: &SchemaManager,
1355        stmt: &Statement,
1356        params: &[Value],
1357        wtx: Option<&mut WriteTxn<'_>>,
1358    ) -> Result<ExecutionResult> {
1359        let ins = match stmt {
1360            Statement::Insert(i) => i,
1361            _ => {
1362                return Err(SqlError::Unsupported(
1363                    "CompiledInsert received non-INSERT statement".into(),
1364                ))
1365            }
1366        };
1367        let _ = &self.table_lower;
1368        match wtx {
1369            None => exec_insert(db, schema, ins, params),
1370            Some(outer) => match self.cached.as_ref() {
1371                Some(c) => exec_insert_in_txn_cached(outer, schema, ins, params, c),
1372                None => exec_insert_in_txn(outer, schema, ins, params),
1373            },
1374        }
1375    }
1376}
1377
/// Precompiled plan for a DELETE statement.
pub struct CompiledDelete {
    /// ASCII-lowercased name of the target table.
    table_lower: String,
}
1381
1382impl CompiledDelete {
1383    pub fn try_compile(schema: &SchemaManager, stmt: &DeleteStmt) -> Option<Self> {
1384        let lower = stmt.table.to_ascii_lowercase();
1385        schema.get(&lower)?;
1386        Some(Self { table_lower: lower })
1387    }
1388}
1389
1390impl CompiledPlan for CompiledDelete {
1391    fn execute(
1392        &self,
1393        db: &Database,
1394        schema: &SchemaManager,
1395        stmt: &Statement,
1396        _params: &[Value],
1397        wtx: Option<&mut WriteTxn<'_>>,
1398    ) -> Result<ExecutionResult> {
1399        let del = match stmt {
1400            Statement::Delete(d) => d,
1401            _ => {
1402                return Err(SqlError::Unsupported(
1403                    "CompiledDelete received non-DELETE statement".into(),
1404                ))
1405            }
1406        };
1407        let _ = &self.table_lower;
1408        match wtx {
1409            None => super::write::exec_delete(db, schema, del),
1410            Some(outer) => super::write::exec_delete_in_txn(outer, schema, del),
1411        }
1412    }
1413}