Skip to main content

citadel_sql/executor/
dml.rs

1use citadel::Database;
2
3use crate::encoding::{encode_composite_key_into, encode_row_into};
4use crate::error::{Result, SqlError};
5use crate::eval::{eval_expr, is_truthy, ColumnMap};
6use crate::parser::*;
7use crate::types::*;
8
9use crate::schema::SchemaManager;
10
11use super::helpers::*;
12use super::CteContext;
13
14// ── DML + materialization ───────────────────────────────────────────
15
16pub(super) fn exec_insert(
17    db: &Database,
18    schema: &SchemaManager,
19    stmt: &InsertStmt,
20    params: &[Value],
21) -> Result<ExecutionResult> {
22    let empty_ctes = CteContext::new();
23    let materialized;
24    let stmt = if insert_has_subquery(stmt) {
25        materialized = materialize_insert(stmt, &mut |sub| {
26            exec_subquery_read(db, schema, sub, &empty_ctes)
27        })?;
28        &materialized
29    } else {
30        stmt
31    };
32
33    let lower_name = stmt.table.to_ascii_lowercase();
34    if schema.get_view(&lower_name).is_some() {
35        return Err(SqlError::CannotModifyView(stmt.table.clone()));
36    }
37    let table_schema = schema
38        .get(&lower_name)
39        .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
40
41    let insert_columns = if stmt.columns.is_empty() {
42        table_schema
43            .columns
44            .iter()
45            .map(|c| c.name.clone())
46            .collect::<Vec<_>>()
47    } else {
48        stmt.columns
49            .iter()
50            .map(|c| c.to_ascii_lowercase())
51            .collect()
52    };
53
54    let col_indices: Vec<usize> = insert_columns
55        .iter()
56        .map(|name| {
57            table_schema
58                .column_index(name)
59                .ok_or_else(|| SqlError::ColumnNotFound(name.clone()))
60        })
61        .collect::<Result<_>>()?;
62
63    let defaults: Vec<(usize, &Expr)> = table_schema
64        .columns
65        .iter()
66        .filter(|c| c.default_expr.is_some() && !col_indices.contains(&(c.position as usize)))
67        .map(|c| (c.position as usize, c.default_expr.as_ref().unwrap()))
68        .collect();
69
70    // ColumnMap for CHECK evaluation
71    let has_checks = table_schema.has_checks();
72    let check_col_map = if has_checks {
73        Some(ColumnMap::new(&table_schema.columns))
74    } else {
75        None
76    };
77
78    let select_rows = match &stmt.source {
79        InsertSource::Select(sq) => {
80            let insert_ctes =
81                super::materialize_all_ctes(&sq.ctes, sq.recursive, &mut |body, ctx| {
82                    exec_query_body_read(db, schema, body, ctx)
83                })?;
84            let qr = exec_query_body_read(db, schema, &sq.body, &insert_ctes)?;
85            Some(qr.rows)
86        }
87        InsertSource::Values(_) => None,
88    };
89
90    let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
91    let mut count: u64 = 0;
92
93    let pk_indices = table_schema.pk_indices();
94    let non_pk = table_schema.non_pk_indices();
95    let enc_pos = table_schema.encoding_positions();
96    let phys_count = table_schema.physical_non_pk_count();
97    let mut row = vec![Value::Null; table_schema.columns.len()];
98    let mut pk_values: Vec<Value> = vec![Value::Null; pk_indices.len()];
99    let mut value_values: Vec<Value> = vec![Value::Null; phys_count];
100    let mut key_buf: Vec<u8> = Vec::with_capacity(64);
101    let mut value_buf: Vec<u8> = Vec::with_capacity(256);
102    let mut fk_key_buf: Vec<u8> = Vec::with_capacity(64);
103
104    let values = match &stmt.source {
105        InsertSource::Values(rows) => Some(rows.as_slice()),
106        InsertSource::Select(_) => None,
107    };
108    let sel_rows = select_rows.as_deref();
109
110    let total = match (values, sel_rows) {
111        (Some(rows), _) => rows.len(),
112        (_, Some(rows)) => rows.len(),
113        _ => 0,
114    };
115
116    if let Some(sel) = sel_rows {
117        if !sel.is_empty() && sel[0].len() != insert_columns.len() {
118            return Err(SqlError::InvalidValue(format!(
119                "INSERT ... SELECT column count mismatch: expected {}, got {}",
120                insert_columns.len(),
121                sel[0].len()
122            )));
123        }
124    }
125
126    for idx in 0..total {
127        for v in row.iter_mut() {
128            *v = Value::Null;
129        }
130
131        if let Some(value_rows) = values {
132            let value_row = &value_rows[idx];
133            if value_row.len() != insert_columns.len() {
134                return Err(SqlError::InvalidValue(format!(
135                    "expected {} values, got {}",
136                    insert_columns.len(),
137                    value_row.len()
138                )));
139            }
140            for (i, expr) in value_row.iter().enumerate() {
141                let val = if let Expr::Parameter(n) = expr {
142                    params
143                        .get(n - 1)
144                        .cloned()
145                        .ok_or_else(|| SqlError::Parse(format!("unbound parameter ${n}")))?
146                } else {
147                    eval_const_expr(expr)?
148                };
149                let col_idx = col_indices[i];
150                let col = &table_schema.columns[col_idx];
151                let got_type = val.data_type();
152                row[col_idx] = if val.is_null() {
153                    Value::Null
154                } else {
155                    val.coerce_into(col.data_type)
156                        .ok_or_else(|| SqlError::TypeMismatch {
157                            expected: col.data_type.to_string(),
158                            got: got_type.to_string(),
159                        })?
160                };
161            }
162        } else if let Some(sel) = sel_rows {
163            let sel_row = &sel[idx];
164            for (i, val) in sel_row.iter().enumerate() {
165                let col_idx = col_indices[i];
166                let col = &table_schema.columns[col_idx];
167                let got_type = val.data_type();
168                row[col_idx] = if val.is_null() {
169                    Value::Null
170                } else {
171                    val.clone().coerce_into(col.data_type).ok_or_else(|| {
172                        SqlError::TypeMismatch {
173                            expected: col.data_type.to_string(),
174                            got: got_type.to_string(),
175                        }
176                    })?
177                };
178            }
179        }
180
181        for &(pos, def_expr) in &defaults {
182            let val = eval_const_expr(def_expr)?;
183            let col = &table_schema.columns[pos];
184            if val.is_null() {
185                // row[pos] already Null from init
186            } else {
187                let got_type = val.data_type();
188                row[pos] =
189                    val.coerce_into(col.data_type)
190                        .ok_or_else(|| SqlError::TypeMismatch {
191                            expected: col.data_type.to_string(),
192                            got: got_type.to_string(),
193                        })?;
194            }
195        }
196
197        for col in &table_schema.columns {
198            if !col.nullable && row[col.position as usize].is_null() {
199                return Err(SqlError::NotNullViolation(col.name.clone()));
200            }
201        }
202
203        if let Some(ref col_map) = check_col_map {
204            for col in &table_schema.columns {
205                if let Some(ref check) = col.check_expr {
206                    let result = eval_expr(check, col_map, &row)?;
207                    if !is_truthy(&result) && !result.is_null() {
208                        let name = col.check_name.as_deref().unwrap_or(&col.name);
209                        return Err(SqlError::CheckViolation(name.to_string()));
210                    }
211                }
212            }
213            for tc in &table_schema.check_constraints {
214                let result = eval_expr(&tc.expr, col_map, &row)?;
215                if !is_truthy(&result) && !result.is_null() {
216                    let name = tc.name.as_deref().unwrap_or(&tc.sql);
217                    return Err(SqlError::CheckViolation(name.to_string()));
218                }
219            }
220        }
221
222        for fk in &table_schema.foreign_keys {
223            let any_null = fk.columns.iter().any(|&ci| row[ci as usize].is_null());
224            if any_null {
225                continue; // MATCH SIMPLE: skip if any FK col is NULL
226            }
227            let fk_vals: Vec<Value> = fk
228                .columns
229                .iter()
230                .map(|&ci| row[ci as usize].clone())
231                .collect();
232            fk_key_buf.clear();
233            encode_composite_key_into(&fk_vals, &mut fk_key_buf);
234            let found = wtx
235                .table_get(fk.foreign_table.as_bytes(), &fk_key_buf)
236                .map_err(SqlError::Storage)?;
237            if found.is_none() {
238                let name = fk.name.as_deref().unwrap_or(&fk.foreign_table);
239                return Err(SqlError::ForeignKeyViolation(name.to_string()));
240            }
241        }
242
243        for (j, &i) in pk_indices.iter().enumerate() {
244            pk_values[j] = std::mem::replace(&mut row[i], Value::Null);
245        }
246        encode_composite_key_into(&pk_values, &mut key_buf);
247
248        for (j, &i) in non_pk.iter().enumerate() {
249            value_values[enc_pos[j] as usize] = std::mem::replace(&mut row[i], Value::Null);
250        }
251        encode_row_into(&value_values, &mut value_buf);
252
253        if key_buf.len() > citadel_core::MAX_KEY_SIZE {
254            return Err(SqlError::KeyTooLarge {
255                size: key_buf.len(),
256                max: citadel_core::MAX_KEY_SIZE,
257            });
258        }
259        if value_buf.len() > citadel_core::MAX_INLINE_VALUE_SIZE {
260            return Err(SqlError::RowTooLarge {
261                size: value_buf.len(),
262                max: citadel_core::MAX_INLINE_VALUE_SIZE,
263            });
264        }
265
266        let is_new = wtx
267            .table_insert(stmt.table.as_bytes(), &key_buf, &value_buf)
268            .map_err(SqlError::Storage)?;
269        if !is_new {
270            return Err(SqlError::DuplicateKey);
271        }
272
273        if !table_schema.indices.is_empty() {
274            for (j, &i) in pk_indices.iter().enumerate() {
275                row[i] = pk_values[j].clone();
276            }
277            for (j, &i) in non_pk.iter().enumerate() {
278                row[i] = std::mem::replace(&mut value_values[enc_pos[j] as usize], Value::Null);
279            }
280            insert_index_entries(&mut wtx, table_schema, &row, &pk_values)?;
281        }
282        count += 1;
283    }
284
285    wtx.commit().map_err(SqlError::Storage)?;
286    Ok(ExecutionResult::RowsAffected(count))
287}
288
289pub(super) fn has_subquery(expr: &Expr) -> bool {
290    crate::parser::has_subquery(expr)
291}
292
293pub(super) fn stmt_has_subquery(stmt: &SelectStmt) -> bool {
294    if let Some(ref w) = stmt.where_clause {
295        if has_subquery(w) {
296            return true;
297        }
298    }
299    if let Some(ref h) = stmt.having {
300        if has_subquery(h) {
301            return true;
302        }
303    }
304    for col in &stmt.columns {
305        if let SelectColumn::Expr { expr, .. } = col {
306            if has_subquery(expr) {
307                return true;
308            }
309        }
310    }
311    for ob in &stmt.order_by {
312        if has_subquery(&ob.expr) {
313            return true;
314        }
315    }
316    for join in &stmt.joins {
317        if let Some(ref on_expr) = join.on_clause {
318            if has_subquery(on_expr) {
319                return true;
320            }
321        }
322    }
323    false
324}
325
326pub(super) fn materialize_expr(
327    expr: &Expr,
328    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
329) -> Result<Expr> {
330    match expr {
331        Expr::InSubquery {
332            expr: e,
333            subquery,
334            negated,
335        } => {
336            let inner = materialize_expr(e, exec_sub)?;
337            let qr = exec_sub(subquery)?;
338            if !qr.columns.is_empty() && qr.columns.len() != 1 {
339                return Err(SqlError::SubqueryMultipleColumns);
340            }
341            let mut values = std::collections::HashSet::new();
342            let mut has_null = false;
343            for row in &qr.rows {
344                if row[0].is_null() {
345                    has_null = true;
346                } else {
347                    values.insert(row[0].clone());
348                }
349            }
350            Ok(Expr::InSet {
351                expr: Box::new(inner),
352                values,
353                has_null,
354                negated: *negated,
355            })
356        }
357        Expr::ScalarSubquery(subquery) => {
358            let qr = exec_sub(subquery)?;
359            if qr.rows.len() > 1 {
360                return Err(SqlError::SubqueryMultipleRows);
361            }
362            let val = if qr.rows.is_empty() {
363                Value::Null
364            } else {
365                qr.rows[0][0].clone()
366            };
367            Ok(Expr::Literal(val))
368        }
369        Expr::Exists { subquery, negated } => {
370            let qr = exec_sub(subquery)?;
371            let exists = !qr.rows.is_empty();
372            let result = if *negated { !exists } else { exists };
373            Ok(Expr::Literal(Value::Boolean(result)))
374        }
375        Expr::InList {
376            expr: e,
377            list,
378            negated,
379        } => {
380            let inner = materialize_expr(e, exec_sub)?;
381            let items = list
382                .iter()
383                .map(|item| materialize_expr(item, exec_sub))
384                .collect::<Result<Vec<_>>>()?;
385            Ok(Expr::InList {
386                expr: Box::new(inner),
387                list: items,
388                negated: *negated,
389            })
390        }
391        Expr::BinaryOp { left, op, right } => Ok(Expr::BinaryOp {
392            left: Box::new(materialize_expr(left, exec_sub)?),
393            op: *op,
394            right: Box::new(materialize_expr(right, exec_sub)?),
395        }),
396        Expr::UnaryOp { op, expr: e } => Ok(Expr::UnaryOp {
397            op: *op,
398            expr: Box::new(materialize_expr(e, exec_sub)?),
399        }),
400        Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(materialize_expr(e, exec_sub)?))),
401        Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(materialize_expr(e, exec_sub)?))),
402        Expr::InSet {
403            expr: e,
404            values,
405            has_null,
406            negated,
407        } => Ok(Expr::InSet {
408            expr: Box::new(materialize_expr(e, exec_sub)?),
409            values: values.clone(),
410            has_null: *has_null,
411            negated: *negated,
412        }),
413        Expr::Between {
414            expr: e,
415            low,
416            high,
417            negated,
418        } => Ok(Expr::Between {
419            expr: Box::new(materialize_expr(e, exec_sub)?),
420            low: Box::new(materialize_expr(low, exec_sub)?),
421            high: Box::new(materialize_expr(high, exec_sub)?),
422            negated: *negated,
423        }),
424        Expr::Like {
425            expr: e,
426            pattern,
427            escape,
428            negated,
429        } => {
430            let esc = escape
431                .as_ref()
432                .map(|es| materialize_expr(es, exec_sub).map(Box::new))
433                .transpose()?;
434            Ok(Expr::Like {
435                expr: Box::new(materialize_expr(e, exec_sub)?),
436                pattern: Box::new(materialize_expr(pattern, exec_sub)?),
437                escape: esc,
438                negated: *negated,
439            })
440        }
441        Expr::Case {
442            operand,
443            conditions,
444            else_result,
445        } => {
446            let op = operand
447                .as_ref()
448                .map(|e| materialize_expr(e, exec_sub).map(Box::new))
449                .transpose()?;
450            let conds = conditions
451                .iter()
452                .map(|(c, r)| {
453                    Ok((
454                        materialize_expr(c, exec_sub)?,
455                        materialize_expr(r, exec_sub)?,
456                    ))
457                })
458                .collect::<Result<Vec<_>>>()?;
459            let else_r = else_result
460                .as_ref()
461                .map(|e| materialize_expr(e, exec_sub).map(Box::new))
462                .transpose()?;
463            Ok(Expr::Case {
464                operand: op,
465                conditions: conds,
466                else_result: else_r,
467            })
468        }
469        Expr::Coalesce(args) => {
470            let materialized = args
471                .iter()
472                .map(|a| materialize_expr(a, exec_sub))
473                .collect::<Result<Vec<_>>>()?;
474            Ok(Expr::Coalesce(materialized))
475        }
476        Expr::Cast { expr: e, data_type } => Ok(Expr::Cast {
477            expr: Box::new(materialize_expr(e, exec_sub)?),
478            data_type: *data_type,
479        }),
480        Expr::Function { name, args } => {
481            let materialized = args
482                .iter()
483                .map(|a| materialize_expr(a, exec_sub))
484                .collect::<Result<Vec<_>>>()?;
485            Ok(Expr::Function {
486                name: name.clone(),
487                args: materialized,
488            })
489        }
490        other => Ok(other.clone()),
491    }
492}
493
494pub(super) fn materialize_stmt(
495    stmt: &SelectStmt,
496    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
497) -> Result<SelectStmt> {
498    let where_clause = stmt
499        .where_clause
500        .as_ref()
501        .map(|e| materialize_expr(e, exec_sub))
502        .transpose()?;
503    let having = stmt
504        .having
505        .as_ref()
506        .map(|e| materialize_expr(e, exec_sub))
507        .transpose()?;
508    let columns = stmt
509        .columns
510        .iter()
511        .map(|c| match c {
512            SelectColumn::AllColumns => Ok(SelectColumn::AllColumns),
513            SelectColumn::Expr { expr, alias } => Ok(SelectColumn::Expr {
514                expr: materialize_expr(expr, exec_sub)?,
515                alias: alias.clone(),
516            }),
517        })
518        .collect::<Result<Vec<_>>>()?;
519    let order_by = stmt
520        .order_by
521        .iter()
522        .map(|ob| {
523            Ok(OrderByItem {
524                expr: materialize_expr(&ob.expr, exec_sub)?,
525                descending: ob.descending,
526                nulls_first: ob.nulls_first,
527            })
528        })
529        .collect::<Result<Vec<_>>>()?;
530    let joins = stmt
531        .joins
532        .iter()
533        .map(|j| {
534            let on_clause = j
535                .on_clause
536                .as_ref()
537                .map(|e| materialize_expr(e, exec_sub))
538                .transpose()?;
539            Ok(JoinClause {
540                join_type: j.join_type,
541                table: j.table.clone(),
542                on_clause,
543            })
544        })
545        .collect::<Result<Vec<_>>>()?;
546    let group_by = stmt
547        .group_by
548        .iter()
549        .map(|e| materialize_expr(e, exec_sub))
550        .collect::<Result<Vec<_>>>()?;
551    Ok(SelectStmt {
552        columns,
553        from: stmt.from.clone(),
554        from_alias: stmt.from_alias.clone(),
555        joins,
556        distinct: stmt.distinct,
557        where_clause,
558        order_by,
559        limit: stmt.limit.clone(),
560        offset: stmt.offset.clone(),
561        group_by,
562        having,
563    })
564}
565
566pub(super) fn exec_subquery_read(
567    db: &Database,
568    schema: &SchemaManager,
569    stmt: &SelectStmt,
570    ctes: &CteContext,
571) -> Result<QueryResult> {
572    match super::exec_select(db, schema, stmt, ctes)? {
573        ExecutionResult::Query(qr) => Ok(qr),
574        _ => Ok(QueryResult {
575            columns: vec![],
576            rows: vec![],
577        }),
578    }
579}
580
581pub(super) fn exec_subquery_write(
582    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
583    schema: &SchemaManager,
584    stmt: &SelectStmt,
585    ctes: &CteContext,
586) -> Result<QueryResult> {
587    match super::exec_select_in_txn(wtx, schema, stmt, ctes)? {
588        ExecutionResult::Query(qr) => Ok(qr),
589        _ => Ok(QueryResult {
590            columns: vec![],
591            rows: vec![],
592        }),
593    }
594}
595
596pub(super) fn update_has_subquery(stmt: &UpdateStmt) -> bool {
597    stmt.where_clause.as_ref().is_some_and(has_subquery)
598        || stmt.assignments.iter().any(|(_, e)| has_subquery(e))
599}
600
601pub(super) fn materialize_update(
602    stmt: &UpdateStmt,
603    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
604) -> Result<UpdateStmt> {
605    let where_clause = stmt
606        .where_clause
607        .as_ref()
608        .map(|e| materialize_expr(e, exec_sub))
609        .transpose()?;
610    let assignments = stmt
611        .assignments
612        .iter()
613        .map(|(name, expr)| Ok((name.clone(), materialize_expr(expr, exec_sub)?)))
614        .collect::<Result<Vec<_>>>()?;
615    Ok(UpdateStmt {
616        table: stmt.table.clone(),
617        assignments,
618        where_clause,
619    })
620}
621
622pub(super) fn delete_has_subquery(stmt: &DeleteStmt) -> bool {
623    stmt.where_clause.as_ref().is_some_and(has_subquery)
624}
625
626pub(super) fn materialize_delete(
627    stmt: &DeleteStmt,
628    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
629) -> Result<DeleteStmt> {
630    let where_clause = stmt
631        .where_clause
632        .as_ref()
633        .map(|e| materialize_expr(e, exec_sub))
634        .transpose()?;
635    Ok(DeleteStmt {
636        table: stmt.table.clone(),
637        where_clause,
638    })
639}
640
641pub(super) fn insert_has_subquery(stmt: &InsertStmt) -> bool {
642    match &stmt.source {
643        InsertSource::Values(rows) => rows.iter().any(|row| row.iter().any(has_subquery)),
644        // SELECT source subqueries are handled by exec_select's correlated/non-correlated paths
645        InsertSource::Select(_) => false,
646    }
647}
648
649pub(super) fn materialize_insert(
650    stmt: &InsertStmt,
651    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
652) -> Result<InsertStmt> {
653    let source = match &stmt.source {
654        InsertSource::Values(rows) => {
655            let mat = rows
656                .iter()
657                .map(|row| {
658                    row.iter()
659                        .map(|e| materialize_expr(e, exec_sub))
660                        .collect::<Result<Vec<_>>>()
661                })
662                .collect::<Result<Vec<_>>>()?;
663            InsertSource::Values(mat)
664        }
665        InsertSource::Select(sq) => {
666            let ctes = sq
667                .ctes
668                .iter()
669                .map(|c| {
670                    Ok(CteDefinition {
671                        name: c.name.clone(),
672                        column_aliases: c.column_aliases.clone(),
673                        body: materialize_query_body(&c.body, exec_sub)?,
674                    })
675                })
676                .collect::<Result<Vec<_>>>()?;
677            let body = materialize_query_body(&sq.body, exec_sub)?;
678            InsertSource::Select(Box::new(SelectQuery {
679                ctes,
680                recursive: sq.recursive,
681                body,
682            }))
683        }
684    };
685    Ok(InsertStmt {
686        table: stmt.table.clone(),
687        columns: stmt.columns.clone(),
688        source,
689    })
690}
691
692pub(super) fn materialize_query_body(
693    body: &QueryBody,
694    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
695) -> Result<QueryBody> {
696    match body {
697        QueryBody::Select(sel) => Ok(QueryBody::Select(Box::new(materialize_stmt(
698            sel, exec_sub,
699        )?))),
700        QueryBody::Compound(comp) => Ok(QueryBody::Compound(Box::new(CompoundSelect {
701            op: comp.op.clone(),
702            all: comp.all,
703            left: Box::new(materialize_query_body(&comp.left, exec_sub)?),
704            right: Box::new(materialize_query_body(&comp.right, exec_sub)?),
705            order_by: comp.order_by.clone(),
706            limit: comp.limit.clone(),
707            offset: comp.offset.clone(),
708        }))),
709    }
710}
711
712pub(super) fn exec_query_body(
713    db: &Database,
714    schema: &SchemaManager,
715    body: &QueryBody,
716    ctes: &CteContext,
717) -> Result<ExecutionResult> {
718    match body {
719        QueryBody::Select(sel) => super::exec_select(db, schema, sel, ctes),
720        QueryBody::Compound(comp) => exec_compound_select(db, schema, comp, ctes),
721    }
722}
723
724pub(super) fn exec_query_body_in_txn(
725    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
726    schema: &SchemaManager,
727    body: &QueryBody,
728    ctes: &CteContext,
729) -> Result<ExecutionResult> {
730    match body {
731        QueryBody::Select(sel) => super::exec_select_in_txn(wtx, schema, sel, ctes),
732        QueryBody::Compound(comp) => exec_compound_select_in_txn(wtx, schema, comp, ctes),
733    }
734}
735
736pub(super) fn exec_query_body_read(
737    db: &Database,
738    schema: &SchemaManager,
739    body: &QueryBody,
740    ctes: &CteContext,
741) -> Result<QueryResult> {
742    match exec_query_body(db, schema, body, ctes)? {
743        ExecutionResult::Query(qr) => Ok(qr),
744        _ => Ok(QueryResult {
745            columns: vec![],
746            rows: vec![],
747        }),
748    }
749}
750
751pub(super) fn exec_query_body_write(
752    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
753    schema: &SchemaManager,
754    body: &QueryBody,
755    ctes: &CteContext,
756) -> Result<QueryResult> {
757    match exec_query_body_in_txn(wtx, schema, body, ctes)? {
758        ExecutionResult::Query(qr) => Ok(qr),
759        _ => Ok(QueryResult {
760            columns: vec![],
761            rows: vec![],
762        }),
763    }
764}
765
766pub(super) fn exec_compound_select(
767    db: &Database,
768    schema: &SchemaManager,
769    comp: &CompoundSelect,
770    ctes: &CteContext,
771) -> Result<ExecutionResult> {
772    let left_qr = match exec_query_body(db, schema, &comp.left, ctes)? {
773        ExecutionResult::Query(qr) => qr,
774        _ => QueryResult {
775            columns: vec![],
776            rows: vec![],
777        },
778    };
779    let right_qr = match exec_query_body(db, schema, &comp.right, ctes)? {
780        ExecutionResult::Query(qr) => qr,
781        _ => QueryResult {
782            columns: vec![],
783            rows: vec![],
784        },
785    };
786    apply_set_operation(comp, left_qr, right_qr)
787}
788
789pub(super) fn exec_compound_select_in_txn(
790    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
791    schema: &SchemaManager,
792    comp: &CompoundSelect,
793    ctes: &CteContext,
794) -> Result<ExecutionResult> {
795    let left_qr = match exec_query_body_in_txn(wtx, schema, &comp.left, ctes)? {
796        ExecutionResult::Query(qr) => qr,
797        _ => QueryResult {
798            columns: vec![],
799            rows: vec![],
800        },
801    };
802    let right_qr = match exec_query_body_in_txn(wtx, schema, &comp.right, ctes)? {
803        ExecutionResult::Query(qr) => qr,
804        _ => QueryResult {
805            columns: vec![],
806            rows: vec![],
807        },
808    };
809    apply_set_operation(comp, left_qr, right_qr)
810}
811
812pub(super) fn apply_set_operation(
813    comp: &CompoundSelect,
814    left_qr: QueryResult,
815    right_qr: QueryResult,
816) -> Result<ExecutionResult> {
817    if !left_qr.columns.is_empty()
818        && !right_qr.columns.is_empty()
819        && left_qr.columns.len() != right_qr.columns.len()
820    {
821        return Err(SqlError::CompoundColumnCountMismatch {
822            left: left_qr.columns.len(),
823            right: right_qr.columns.len(),
824        });
825    }
826
827    let columns = left_qr.columns;
828
829    let mut rows = match (&comp.op, comp.all) {
830        (SetOp::Union, true) => {
831            let mut rows = left_qr.rows;
832            rows.extend(right_qr.rows);
833            rows
834        }
835        (SetOp::Union, false) => {
836            let mut seen = std::collections::HashSet::new();
837            let mut rows = Vec::new();
838            for row in left_qr.rows.into_iter().chain(right_qr.rows) {
839                if seen.insert(row.clone()) {
840                    rows.push(row);
841                }
842            }
843            rows
844        }
845        (SetOp::Intersect, true) => {
846            let mut right_counts: std::collections::HashMap<Vec<Value>, usize> =
847                std::collections::HashMap::new();
848            for row in &right_qr.rows {
849                *right_counts.entry(row.clone()).or_insert(0) += 1;
850            }
851            let mut rows = Vec::new();
852            for row in left_qr.rows {
853                if let Some(count) = right_counts.get_mut(&row) {
854                    if *count > 0 {
855                        *count -= 1;
856                        rows.push(row);
857                    }
858                }
859            }
860            rows
861        }
862        (SetOp::Intersect, false) => {
863            let right_set: std::collections::HashSet<Vec<Value>> =
864                right_qr.rows.into_iter().collect();
865            let mut seen = std::collections::HashSet::new();
866            let mut rows = Vec::new();
867            for row in left_qr.rows {
868                if right_set.contains(&row) && seen.insert(row.clone()) {
869                    rows.push(row);
870                }
871            }
872            rows
873        }
874        (SetOp::Except, true) => {
875            let mut right_counts: std::collections::HashMap<Vec<Value>, usize> =
876                std::collections::HashMap::new();
877            for row in &right_qr.rows {
878                *right_counts.entry(row.clone()).or_insert(0) += 1;
879            }
880            let mut rows = Vec::new();
881            for row in left_qr.rows {
882                if let Some(count) = right_counts.get_mut(&row) {
883                    if *count > 0 {
884                        *count -= 1;
885                        continue;
886                    }
887                }
888                rows.push(row);
889            }
890            rows
891        }
892        (SetOp::Except, false) => {
893            let right_set: std::collections::HashSet<Vec<Value>> =
894                right_qr.rows.into_iter().collect();
895            let mut seen = std::collections::HashSet::new();
896            let mut rows = Vec::new();
897            for row in left_qr.rows {
898                if !right_set.contains(&row) && seen.insert(row.clone()) {
899                    rows.push(row);
900                }
901            }
902            rows
903        }
904    };
905
906    if !comp.order_by.is_empty() {
907        let col_defs: Vec<crate::types::ColumnDef> = columns
908            .iter()
909            .enumerate()
910            .map(|(i, name)| crate::types::ColumnDef {
911                name: name.clone(),
912                data_type: crate::types::DataType::Null,
913                nullable: true,
914                position: i as u16,
915                default_expr: None,
916                default_sql: None,
917                check_expr: None,
918                check_sql: None,
919                check_name: None,
920            })
921            .collect();
922        sort_rows(&mut rows, &comp.order_by, &col_defs)?;
923    }
924
925    if let Some(ref offset_expr) = comp.offset {
926        let offset = eval_const_int(offset_expr)?.max(0) as usize;
927        if offset < rows.len() {
928            rows = rows.split_off(offset);
929        } else {
930            rows.clear();
931        }
932    }
933
934    if let Some(ref limit_expr) = comp.limit {
935        let limit = eval_const_int(limit_expr)?.max(0) as usize;
936        rows.truncate(limit);
937    }
938
939    Ok(ExecutionResult::Query(QueryResult { columns, rows }))
940}
941
942#[derive(Default)]
943pub struct InsertBufs {
944    row: Vec<Value>,
945    pk_values: Vec<Value>,
946    value_values: Vec<Value>,
947    key_buf: Vec<u8>,
948    value_buf: Vec<u8>,
949    col_indices: Vec<usize>,
950    fk_key_buf: Vec<u8>,
951}
952
953impl InsertBufs {
954    pub fn new() -> Self {
955        Self {
956            row: Vec::new(),
957            pk_values: Vec::new(),
958            value_values: Vec::new(),
959            key_buf: Vec::with_capacity(64),
960            value_buf: Vec::with_capacity(256),
961            col_indices: Vec::new(),
962            fk_key_buf: Vec::with_capacity(64),
963        }
964    }
965}
966
967pub fn exec_insert_in_txn(
968    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
969    schema: &SchemaManager,
970    stmt: &InsertStmt,
971    params: &[Value],
972    bufs: &mut InsertBufs,
973) -> Result<ExecutionResult> {
974    let empty_ctes = CteContext::new();
975    let materialized;
976    let stmt = if insert_has_subquery(stmt) {
977        materialized = materialize_insert(stmt, &mut |sub| {
978            exec_subquery_write(wtx, schema, sub, &empty_ctes)
979        })?;
980        &materialized
981    } else {
982        stmt
983    };
984
985    let table_schema = schema
986        .get(&stmt.table)
987        .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
988
989    let default_columns;
990    let insert_columns: &[String] = if stmt.columns.is_empty() {
991        default_columns = table_schema
992            .columns
993            .iter()
994            .map(|c| c.name.clone())
995            .collect::<Vec<_>>();
996        &default_columns
997    } else {
998        &stmt.columns
999    };
1000
1001    bufs.col_indices.clear();
1002    for name in insert_columns {
1003        bufs.col_indices.push(
1004            table_schema
1005                .column_index(name)
1006                .ok_or_else(|| SqlError::ColumnNotFound(name.clone()))?,
1007        );
1008    }
1009
1010    let defaults: Vec<(usize, &Expr)> = table_schema
1011        .columns
1012        .iter()
1013        .filter(|c| c.default_expr.is_some() && !bufs.col_indices.contains(&(c.position as usize)))
1014        .map(|c| (c.position as usize, c.default_expr.as_ref().unwrap()))
1015        .collect();
1016
1017    let has_checks = table_schema.has_checks();
1018    let check_col_map = if has_checks {
1019        Some(ColumnMap::new(&table_schema.columns))
1020    } else {
1021        None
1022    };
1023
1024    let pk_indices = table_schema.pk_indices();
1025    let non_pk = table_schema.non_pk_indices();
1026    let enc_pos = table_schema.encoding_positions();
1027    let phys_count = table_schema.physical_non_pk_count();
1028    let dropped = table_schema.dropped_non_pk_slots();
1029
1030    bufs.row.resize(table_schema.columns.len(), Value::Null);
1031    bufs.pk_values.resize(pk_indices.len(), Value::Null);
1032    bufs.value_values.resize(phys_count, Value::Null);
1033
1034    let select_rows = match &stmt.source {
1035        InsertSource::Select(sq) => {
1036            let insert_ctes =
1037                super::materialize_all_ctes(&sq.ctes, sq.recursive, &mut |body, ctx| {
1038                    exec_query_body_write(wtx, schema, body, ctx)
1039                })?;
1040            let qr = exec_query_body_write(wtx, schema, &sq.body, &insert_ctes)?;
1041            Some(qr.rows)
1042        }
1043        InsertSource::Values(_) => None,
1044    };
1045
1046    let mut count: u64 = 0;
1047
1048    let values = match &stmt.source {
1049        InsertSource::Values(rows) => Some(rows.as_slice()),
1050        InsertSource::Select(_) => None,
1051    };
1052    let sel_rows = select_rows.as_deref();
1053
1054    let total = match (values, sel_rows) {
1055        (Some(rows), _) => rows.len(),
1056        (_, Some(rows)) => rows.len(),
1057        _ => 0,
1058    };
1059
1060    if let Some(sel) = sel_rows {
1061        if !sel.is_empty() && sel[0].len() != insert_columns.len() {
1062            return Err(SqlError::InvalidValue(format!(
1063                "INSERT ... SELECT column count mismatch: expected {}, got {}",
1064                insert_columns.len(),
1065                sel[0].len()
1066            )));
1067        }
1068    }
1069
1070    for idx in 0..total {
1071        for v in bufs.row.iter_mut() {
1072            *v = Value::Null;
1073        }
1074
1075        if let Some(value_rows) = values {
1076            let value_row = &value_rows[idx];
1077            if value_row.len() != insert_columns.len() {
1078                return Err(SqlError::InvalidValue(format!(
1079                    "expected {} values, got {}",
1080                    insert_columns.len(),
1081                    value_row.len()
1082                )));
1083            }
1084            for (i, expr) in value_row.iter().enumerate() {
1085                let val = if let Expr::Parameter(n) = expr {
1086                    params
1087                        .get(n - 1)
1088                        .cloned()
1089                        .ok_or_else(|| SqlError::Parse(format!("unbound parameter ${n}")))?
1090                } else {
1091                    eval_const_expr(expr)?
1092                };
1093                let col_idx = bufs.col_indices[i];
1094                let col = &table_schema.columns[col_idx];
1095                let got_type = val.data_type();
1096                bufs.row[col_idx] = if val.is_null() {
1097                    Value::Null
1098                } else {
1099                    val.coerce_into(col.data_type)
1100                        .ok_or_else(|| SqlError::TypeMismatch {
1101                            expected: col.data_type.to_string(),
1102                            got: got_type.to_string(),
1103                        })?
1104                };
1105            }
1106        } else if let Some(sel) = sel_rows {
1107            let sel_row = &sel[idx];
1108            for (i, val) in sel_row.iter().enumerate() {
1109                let col_idx = bufs.col_indices[i];
1110                let col = &table_schema.columns[col_idx];
1111                let got_type = val.data_type();
1112                bufs.row[col_idx] = if val.is_null() {
1113                    Value::Null
1114                } else {
1115                    val.clone().coerce_into(col.data_type).ok_or_else(|| {
1116                        SqlError::TypeMismatch {
1117                            expected: col.data_type.to_string(),
1118                            got: got_type.to_string(),
1119                        }
1120                    })?
1121                };
1122            }
1123        }
1124
1125        for &(pos, def_expr) in &defaults {
1126            let val = eval_const_expr(def_expr)?;
1127            let col = &table_schema.columns[pos];
1128            if val.is_null() {
1129                // bufs.row[pos] already Null from init
1130            } else {
1131                let got_type = val.data_type();
1132                bufs.row[pos] =
1133                    val.coerce_into(col.data_type)
1134                        .ok_or_else(|| SqlError::TypeMismatch {
1135                            expected: col.data_type.to_string(),
1136                            got: got_type.to_string(),
1137                        })?;
1138            }
1139        }
1140
1141        for col in &table_schema.columns {
1142            if !col.nullable && bufs.row[col.position as usize].is_null() {
1143                return Err(SqlError::NotNullViolation(col.name.clone()));
1144            }
1145        }
1146
1147        if let Some(ref col_map) = check_col_map {
1148            for col in &table_schema.columns {
1149                if let Some(ref check) = col.check_expr {
1150                    let result = eval_expr(check, col_map, &bufs.row)?;
1151                    if !is_truthy(&result) && !result.is_null() {
1152                        let name = col.check_name.as_deref().unwrap_or(&col.name);
1153                        return Err(SqlError::CheckViolation(name.to_string()));
1154                    }
1155                }
1156            }
1157            for tc in &table_schema.check_constraints {
1158                let result = eval_expr(&tc.expr, col_map, &bufs.row)?;
1159                if !is_truthy(&result) && !result.is_null() {
1160                    let name = tc.name.as_deref().unwrap_or(&tc.sql);
1161                    return Err(SqlError::CheckViolation(name.to_string()));
1162                }
1163            }
1164        }
1165
1166        for fk in &table_schema.foreign_keys {
1167            let any_null = fk.columns.iter().any(|&ci| bufs.row[ci as usize].is_null());
1168            if any_null {
1169                continue;
1170            }
1171            let fk_vals: Vec<Value> = fk
1172                .columns
1173                .iter()
1174                .map(|&ci| bufs.row[ci as usize].clone())
1175                .collect();
1176            bufs.fk_key_buf.clear();
1177            encode_composite_key_into(&fk_vals, &mut bufs.fk_key_buf);
1178            let found = wtx
1179                .table_get(fk.foreign_table.as_bytes(), &bufs.fk_key_buf)
1180                .map_err(SqlError::Storage)?;
1181            if found.is_none() {
1182                let name = fk.name.as_deref().unwrap_or(&fk.foreign_table);
1183                return Err(SqlError::ForeignKeyViolation(name.to_string()));
1184            }
1185        }
1186
1187        for (j, &i) in pk_indices.iter().enumerate() {
1188            bufs.pk_values[j] = std::mem::replace(&mut bufs.row[i], Value::Null);
1189        }
1190        encode_composite_key_into(&bufs.pk_values, &mut bufs.key_buf);
1191
1192        for &slot in dropped {
1193            bufs.value_values[slot as usize] = Value::Null;
1194        }
1195        for (j, &i) in non_pk.iter().enumerate() {
1196            bufs.value_values[enc_pos[j] as usize] =
1197                std::mem::replace(&mut bufs.row[i], Value::Null);
1198        }
1199        encode_row_into(&bufs.value_values, &mut bufs.value_buf);
1200
1201        if bufs.key_buf.len() > citadel_core::MAX_KEY_SIZE {
1202            return Err(SqlError::KeyTooLarge {
1203                size: bufs.key_buf.len(),
1204                max: citadel_core::MAX_KEY_SIZE,
1205            });
1206        }
1207        if bufs.value_buf.len() > citadel_core::MAX_INLINE_VALUE_SIZE {
1208            return Err(SqlError::RowTooLarge {
1209                size: bufs.value_buf.len(),
1210                max: citadel_core::MAX_INLINE_VALUE_SIZE,
1211            });
1212        }
1213
1214        let is_new = wtx
1215            .table_insert(stmt.table.as_bytes(), &bufs.key_buf, &bufs.value_buf)
1216            .map_err(SqlError::Storage)?;
1217        if !is_new {
1218            return Err(SqlError::DuplicateKey);
1219        }
1220
1221        if !table_schema.indices.is_empty() {
1222            for (j, &i) in pk_indices.iter().enumerate() {
1223                bufs.row[i] = bufs.pk_values[j].clone();
1224            }
1225            for (j, &i) in non_pk.iter().enumerate() {
1226                bufs.row[i] =
1227                    std::mem::replace(&mut bufs.value_values[enc_pos[j] as usize], Value::Null);
1228            }
1229            insert_index_entries(wtx, table_schema, &bufs.row, &bufs.pk_values)?;
1230        }
1231        count += 1;
1232    }
1233
1234    Ok(ExecutionResult::RowsAffected(count))
1235}