Skip to main content

citadel_sql/executor/
dml.rs

1use std::cell::RefCell;
2use std::sync::Arc;
3
4use citadel::Database;
5use citadel_buffer::btree::{UpsertAction, UpsertOutcome};
6use citadel_txn::write_txn::WriteTxn;
7use rustc_hash::FxHashMap;
8
9use crate::encoding::{encode_composite_key_into, encode_row_into};
10use crate::error::{Result, SqlError};
11use crate::eval::{eval_expr, is_truthy, ColumnMap, EvalCtx};
12use crate::parser::*;
13use crate::types::*;
14
15use crate::schema::SchemaManager;
16
17use super::compile::CompiledPlan;
18use super::helpers::*;
19use super::CteContext;
20
21pub(super) fn exec_insert(
22    db: &Database,
23    schema: &SchemaManager,
24    stmt: &InsertStmt,
25    params: &[Value],
26) -> Result<ExecutionResult> {
27    let empty_ctes = CteContext::default();
28    let materialized;
29    let stmt = if insert_has_subquery(stmt) {
30        materialized = materialize_insert(stmt, &mut |sub| {
31            exec_subquery_read(db, schema, sub, &empty_ctes)
32        })?;
33        &materialized
34    } else {
35        stmt
36    };
37
38    let lower_name = stmt.table.to_ascii_lowercase();
39    if schema.get_view(&lower_name).is_some() {
40        return Err(SqlError::CannotModifyView(stmt.table.clone()));
41    }
42    let table_schema = schema
43        .get(&lower_name)
44        .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
45
46    let insert_columns = if stmt.columns.is_empty() {
47        table_schema
48            .columns
49            .iter()
50            .map(|c| c.name.clone())
51            .collect::<Vec<_>>()
52    } else {
53        stmt.columns
54            .iter()
55            .map(|c| c.to_ascii_lowercase())
56            .collect()
57    };
58
59    let col_indices: Vec<usize> = insert_columns
60        .iter()
61        .map(|name| {
62            table_schema
63                .column_index(name)
64                .ok_or_else(|| SqlError::ColumnNotFound(name.clone()))
65        })
66        .collect::<Result<_>>()?;
67
68    for &ci in &col_indices {
69        if table_schema.columns[ci].generated_kind.is_some() {
70            return Err(SqlError::CannotInsertIntoGeneratedColumn(
71                table_schema.columns[ci].name.clone(),
72            ));
73        }
74    }
75
76    let defaults: Vec<(usize, &Expr)> = table_schema
77        .columns
78        .iter()
79        .filter(|c| c.default_expr.is_some() && !col_indices.contains(&(c.position as usize)))
80        .map(|c| (c.position as usize, c.default_expr.as_ref().unwrap()))
81        .collect();
82
83    let generated_cols: Vec<(usize, &Expr)> = table_schema
84        .columns
85        .iter()
86        .filter(|c| matches!(c.generated_kind, Some(crate::parser::GeneratedKind::Stored)))
87        .map(|c| (c.position as usize, c.generated_expr.as_ref().unwrap()))
88        .collect();
89
90    let has_checks = table_schema.has_checks();
91    let strict = table_schema.is_strict();
92    let row_col_map_for_gen = if !generated_cols.is_empty() {
93        Some(ColumnMap::new(&table_schema.columns))
94    } else {
95        None
96    };
97    let check_col_map = if has_checks {
98        Some(ColumnMap::new(&table_schema.columns))
99    } else {
100        None
101    };
102
103    let select_rows = match &stmt.source {
104        InsertSource::Select(sq) => {
105            let insert_ctes =
106                super::materialize_all_ctes(&sq.ctes, sq.recursive, &mut |body, ctx| {
107                    exec_query_body_read(db, schema, body, ctx)
108                })?;
109            let qr = exec_query_body_read(db, schema, &sq.body, &insert_ctes)?;
110            Some(qr.rows)
111        }
112        InsertSource::Values(_) => None,
113    };
114
115    let compiled_conflict: Option<Arc<CompiledOnConflict>> = stmt
116        .on_conflict
117        .as_ref()
118        .map(|oc| compile_on_conflict(oc, table_schema).map(Arc::new))
119        .transpose()?;
120
121    let row_col_map = compiled_conflict
122        .as_ref()
123        .map(|_| ColumnMap::new(&table_schema.columns));
124
125    let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
126    let mut count: u64 = 0;
127    let mut returning_rows: Option<Vec<super::helpers::ReturningRow>> =
128        stmt.returning.as_ref().map(|_| Vec::new());
129
130    let pk_indices = table_schema.pk_indices();
131    let non_pk = table_schema.non_pk_indices();
132    let enc_pos = table_schema.encoding_positions();
133    let phys_count = table_schema.physical_non_pk_count();
134    let mut row = vec![Value::Null; table_schema.columns.len()];
135    let mut pk_values: Vec<Value> = vec![Value::Null; pk_indices.len()];
136    let mut value_values: Vec<Value> = vec![Value::Null; phys_count];
137    let mut key_buf: Vec<u8> = Vec::with_capacity(64);
138    let mut value_buf: Vec<u8> = Vec::with_capacity(256);
139    let mut fk_key_buf: Vec<u8> = Vec::with_capacity(64);
140
141    let values = match &stmt.source {
142        InsertSource::Values(rows) => Some(rows.as_slice()),
143        InsertSource::Select(_) => None,
144    };
145    let sel_rows = select_rows.as_deref();
146
147    let total = match (values, sel_rows) {
148        (Some(rows), _) => rows.len(),
149        (_, Some(rows)) => rows.len(),
150        _ => 0,
151    };
152
153    if let Some(sel) = sel_rows {
154        if !sel.is_empty() && sel[0].len() != insert_columns.len() {
155            return Err(SqlError::InvalidValue(format!(
156                "INSERT ... SELECT column count mismatch: expected {}, got {}",
157                insert_columns.len(),
158                sel[0].len()
159            )));
160        }
161    }
162
163    for idx in 0..total {
164        for v in row.iter_mut() {
165            *v = Value::Null;
166        }
167
168        if let Some(value_rows) = values {
169            let value_row = &value_rows[idx];
170            if value_row.len() != insert_columns.len() {
171                return Err(SqlError::InvalidValue(format!(
172                    "expected {} values, got {}",
173                    insert_columns.len(),
174                    value_row.len()
175                )));
176            }
177            for (i, expr) in value_row.iter().enumerate() {
178                let val = if let Expr::Parameter(n) = expr {
179                    params
180                        .get(n - 1)
181                        .cloned()
182                        .ok_or_else(|| SqlError::Parse(format!("unbound parameter ${n}")))?
183                } else {
184                    eval_const_expr(expr)?
185                };
186                let col_idx = col_indices[i];
187                let col = &table_schema.columns[col_idx];
188                row[col_idx] = if val.is_null() {
189                    Value::Null
190                } else {
191                    coerce_for_column(val, col, strict)?
192                };
193            }
194        } else if let Some(sel) = sel_rows {
195            let sel_row = &sel[idx];
196            for (i, val) in sel_row.iter().enumerate() {
197                let col_idx = col_indices[i];
198                let col = &table_schema.columns[col_idx];
199                row[col_idx] = if val.is_null() {
200                    Value::Null
201                } else {
202                    coerce_for_column(val.clone(), col, strict)?
203                };
204            }
205        }
206
207        for &(pos, def_expr) in &defaults {
208            let val = eval_const_expr(def_expr)?;
209            let col = &table_schema.columns[pos];
210            if !val.is_null() {
211                row[pos] = coerce_for_column(val, col, strict)?;
212            }
213        }
214
215        if let Some(ref gen_map) = row_col_map_for_gen {
216            for &(pos, gen_expr) in &generated_cols {
217                let val = eval_expr(gen_expr, &EvalCtx::new(gen_map, &row))?;
218                let col = &table_schema.columns[pos];
219                row[pos] = if val.is_null() {
220                    Value::Null
221                } else {
222                    coerce_for_column(val, col, strict)?
223                };
224            }
225        }
226
227        for col in &table_schema.columns {
228            if !col.nullable && row[col.position as usize].is_null() {
229                return Err(SqlError::NotNullViolation(col.name.clone()));
230            }
231        }
232
233        if let Some(ref col_map) = check_col_map {
234            for col in &table_schema.columns {
235                if let Some(ref check) = col.check_expr {
236                    let result = eval_expr(check, &EvalCtx::new(col_map, &row))?;
237                    if !is_truthy(&result) && !result.is_null() {
238                        let name = col.check_name.as_deref().unwrap_or(&col.name);
239                        return Err(SqlError::CheckViolation(name.to_string()));
240                    }
241                }
242            }
243            for tc in &table_schema.check_constraints {
244                let result = eval_expr(&tc.expr, &EvalCtx::new(col_map, &row))?;
245                if !is_truthy(&result) && !result.is_null() {
246                    let name = tc.name.as_deref().unwrap_or(&tc.sql);
247                    return Err(SqlError::CheckViolation(name.to_string()));
248                }
249            }
250        }
251
252        for fk in &table_schema.foreign_keys {
253            let any_null = fk.columns.iter().any(|&ci| row[ci as usize].is_null());
254            if any_null {
255                continue; // MATCH SIMPLE: skip if any FK col is NULL
256            }
257            let fk_vals: Vec<Value> = fk
258                .columns
259                .iter()
260                .map(|&ci| row[ci as usize].clone())
261                .collect();
262            fk_key_buf.clear();
263            encode_composite_key_into(&fk_vals, &mut fk_key_buf);
264            let found = wtx
265                .table_get(fk.foreign_table.as_bytes(), &fk_key_buf)
266                .map_err(SqlError::Storage)?;
267            if found.is_none() {
268                let name = fk.name.as_deref().unwrap_or(&fk.foreign_table);
269                return Err(SqlError::ForeignKeyViolation(name.to_string()));
270            }
271        }
272
273        let proposed_row_for_returning: Option<Vec<Value>> =
274            returning_rows.as_ref().map(|_| row.clone());
275
276        for (j, &i) in pk_indices.iter().enumerate() {
277            pk_values[j] = std::mem::replace(&mut row[i], Value::Null);
278        }
279        encode_composite_key_into(&pk_values, &mut key_buf);
280
281        for (j, &i) in non_pk.iter().enumerate() {
282            let col = &table_schema.columns[i];
283            if matches!(
284                col.generated_kind,
285                Some(crate::parser::GeneratedKind::Virtual)
286            ) {
287                value_values[enc_pos[j] as usize] = Value::Null;
288                row[i] = Value::Null;
289            } else {
290                value_values[enc_pos[j] as usize] = std::mem::replace(&mut row[i], Value::Null);
291            }
292        }
293        encode_row_into(&value_values, &mut value_buf);
294
295        if key_buf.len() > citadel_core::MAX_KEY_SIZE {
296            return Err(SqlError::KeyTooLarge {
297                size: key_buf.len(),
298                max: citadel_core::MAX_KEY_SIZE,
299            });
300        }
301        if value_buf.len() > citadel_core::MAX_INLINE_VALUE_SIZE {
302            return Err(SqlError::RowTooLarge {
303                size: value_buf.len(),
304                max: citadel_core::MAX_INLINE_VALUE_SIZE,
305            });
306        }
307
308        match compiled_conflict.as_ref() {
309            None => {
310                let is_new = wtx
311                    .table_insert(stmt.table.as_bytes(), &key_buf, &value_buf)
312                    .map_err(SqlError::Storage)?;
313                if !is_new {
314                    return Err(SqlError::DuplicateKey);
315                }
316                if !table_schema.indices.is_empty() {
317                    for (j, &i) in pk_indices.iter().enumerate() {
318                        row[i] = pk_values[j].clone();
319                    }
320                    for (j, &i) in non_pk.iter().enumerate() {
321                        row[i] =
322                            std::mem::replace(&mut value_values[enc_pos[j] as usize], Value::Null);
323                    }
324                    insert_index_entries(&mut wtx, table_schema, &row, &pk_values)?;
325                }
326                count += 1;
327                if let Some(buf) = returning_rows.as_mut() {
328                    buf.push((None, proposed_row_for_returning));
329                }
330            }
331            Some(oc) => {
332                let oc_ref: &CompiledOnConflict = oc;
333                let needs_row = upsert_needs_row(oc_ref, table_schema);
334                if needs_row {
335                    for (j, &i) in pk_indices.iter().enumerate() {
336                        row[i] = pk_values[j].clone();
337                    }
338                    for (j, &i) in non_pk.iter().enumerate() {
339                        row[i] =
340                            std::mem::replace(&mut value_values[enc_pos[j] as usize], Value::Null);
341                    }
342                }
343                let outcome = apply_insert_with_conflict(
344                    &mut wtx,
345                    table_schema,
346                    &key_buf,
347                    &value_buf,
348                    &row,
349                    &pk_values,
350                    oc_ref,
351                    row_col_map.as_ref().unwrap(),
352                    stmt.returning.is_some(),
353                )?;
354                match outcome {
355                    InsertRowOutcome::Inserted => {
356                        count += 1;
357                        if let Some(buf) = returning_rows.as_mut() {
358                            buf.push((None, proposed_row_for_returning));
359                        }
360                    }
361                    InsertRowOutcome::Updated { old, new } => {
362                        count += 1;
363                        if let Some(buf) = returning_rows.as_mut() {
364                            buf.push((Some(old), Some(new)));
365                        }
366                    }
367                    InsertRowOutcome::Skipped => {}
368                }
369            }
370        }
371    }
372
373    if let (Some(returning_cols), Some(rows)) = (stmt.returning.as_ref(), returning_rows) {
374        let qr = super::helpers::project_returning(table_schema, returning_cols, &rows)?;
375        wtx.commit().map_err(SqlError::Storage)?;
376        return Ok(ExecutionResult::Query(qr));
377    }
378
379    wtx.commit().map_err(SqlError::Storage)?;
380    Ok(ExecutionResult::RowsAffected(count))
381}
382
383pub(super) fn has_subquery(expr: &Expr) -> bool {
384    crate::parser::has_subquery(expr)
385}
386
387pub(super) fn stmt_has_subquery(stmt: &SelectStmt) -> bool {
388    if let Some(ref w) = stmt.where_clause {
389        if has_subquery(w) {
390            return true;
391        }
392    }
393    if let Some(ref h) = stmt.having {
394        if has_subquery(h) {
395            return true;
396        }
397    }
398    for col in &stmt.columns {
399        if let SelectColumn::Expr { expr, .. } = col {
400            if has_subquery(expr) {
401                return true;
402            }
403        }
404    }
405    for ob in &stmt.order_by {
406        if has_subquery(&ob.expr) {
407            return true;
408        }
409    }
410    for join in &stmt.joins {
411        if let Some(ref on_expr) = join.on_clause {
412            if has_subquery(on_expr) {
413                return true;
414            }
415        }
416    }
417    false
418}
419
420pub(super) fn materialize_expr(
421    expr: &Expr,
422    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
423) -> Result<Expr> {
424    match expr {
425        Expr::InSubquery {
426            expr: e,
427            subquery,
428            negated,
429        } => {
430            let inner = materialize_expr(e, exec_sub)?;
431            let qr = exec_sub(subquery)?;
432            if !qr.columns.is_empty() && qr.columns.len() != 1 {
433                return Err(SqlError::SubqueryMultipleColumns);
434            }
435            let mut values = rustc_hash::FxHashSet::default();
436            let mut has_null = false;
437            for row in &qr.rows {
438                if row[0].is_null() {
439                    has_null = true;
440                } else {
441                    values.insert(row[0].clone());
442                }
443            }
444            Ok(Expr::InSet {
445                expr: Box::new(inner),
446                values,
447                has_null,
448                negated: *negated,
449            })
450        }
451        Expr::ScalarSubquery(subquery) => {
452            let qr = exec_sub(subquery)?;
453            if qr.rows.len() > 1 {
454                return Err(SqlError::SubqueryMultipleRows);
455            }
456            let val = if qr.rows.is_empty() {
457                Value::Null
458            } else {
459                qr.rows[0][0].clone()
460            };
461            Ok(Expr::Literal(val))
462        }
463        Expr::Exists { subquery, negated } => {
464            let qr = exec_sub(subquery)?;
465            let exists = !qr.rows.is_empty();
466            let result = if *negated { !exists } else { exists };
467            Ok(Expr::Literal(Value::Boolean(result)))
468        }
469        Expr::InList {
470            expr: e,
471            list,
472            negated,
473        } => {
474            let inner = materialize_expr(e, exec_sub)?;
475            let items = list
476                .iter()
477                .map(|item| materialize_expr(item, exec_sub))
478                .collect::<Result<Vec<_>>>()?;
479            Ok(Expr::InList {
480                expr: Box::new(inner),
481                list: items,
482                negated: *negated,
483            })
484        }
485        Expr::BinaryOp { left, op, right } => Ok(Expr::BinaryOp {
486            left: Box::new(materialize_expr(left, exec_sub)?),
487            op: *op,
488            right: Box::new(materialize_expr(right, exec_sub)?),
489        }),
490        Expr::UnaryOp { op, expr: e } => Ok(Expr::UnaryOp {
491            op: *op,
492            expr: Box::new(materialize_expr(e, exec_sub)?),
493        }),
494        Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(materialize_expr(e, exec_sub)?))),
495        Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(materialize_expr(e, exec_sub)?))),
496        Expr::InSet {
497            expr: e,
498            values,
499            has_null,
500            negated,
501        } => Ok(Expr::InSet {
502            expr: Box::new(materialize_expr(e, exec_sub)?),
503            values: values.clone(),
504            has_null: *has_null,
505            negated: *negated,
506        }),
507        Expr::Between {
508            expr: e,
509            low,
510            high,
511            negated,
512        } => Ok(Expr::Between {
513            expr: Box::new(materialize_expr(e, exec_sub)?),
514            low: Box::new(materialize_expr(low, exec_sub)?),
515            high: Box::new(materialize_expr(high, exec_sub)?),
516            negated: *negated,
517        }),
518        Expr::Like {
519            expr: e,
520            pattern,
521            escape,
522            negated,
523        } => {
524            let esc = escape
525                .as_ref()
526                .map(|es| materialize_expr(es, exec_sub).map(Box::new))
527                .transpose()?;
528            Ok(Expr::Like {
529                expr: Box::new(materialize_expr(e, exec_sub)?),
530                pattern: Box::new(materialize_expr(pattern, exec_sub)?),
531                escape: esc,
532                negated: *negated,
533            })
534        }
535        Expr::Case {
536            operand,
537            conditions,
538            else_result,
539        } => {
540            let op = operand
541                .as_ref()
542                .map(|e| materialize_expr(e, exec_sub).map(Box::new))
543                .transpose()?;
544            let conds = conditions
545                .iter()
546                .map(|(c, r)| {
547                    Ok((
548                        materialize_expr(c, exec_sub)?,
549                        materialize_expr(r, exec_sub)?,
550                    ))
551                })
552                .collect::<Result<Vec<_>>>()?;
553            let else_r = else_result
554                .as_ref()
555                .map(|e| materialize_expr(e, exec_sub).map(Box::new))
556                .transpose()?;
557            Ok(Expr::Case {
558                operand: op,
559                conditions: conds,
560                else_result: else_r,
561            })
562        }
563        Expr::Coalesce(args) => {
564            let materialized = args
565                .iter()
566                .map(|a| materialize_expr(a, exec_sub))
567                .collect::<Result<Vec<_>>>()?;
568            Ok(Expr::Coalesce(materialized))
569        }
570        Expr::Cast { expr: e, data_type } => Ok(Expr::Cast {
571            expr: Box::new(materialize_expr(e, exec_sub)?),
572            data_type: *data_type,
573        }),
574        Expr::Function {
575            name,
576            args,
577            distinct,
578        } => {
579            let materialized = args
580                .iter()
581                .map(|a| materialize_expr(a, exec_sub))
582                .collect::<Result<Vec<_>>>()?;
583            Ok(Expr::Function {
584                name: name.clone(),
585                args: materialized,
586                distinct: *distinct,
587            })
588        }
589        other => Ok(other.clone()),
590    }
591}
592
593pub(super) fn materialize_stmt(
594    stmt: &SelectStmt,
595    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
596) -> Result<SelectStmt> {
597    let where_clause = stmt
598        .where_clause
599        .as_ref()
600        .map(|e| materialize_expr(e, exec_sub))
601        .transpose()?;
602    let having = stmt
603        .having
604        .as_ref()
605        .map(|e| materialize_expr(e, exec_sub))
606        .transpose()?;
607    let columns = stmt
608        .columns
609        .iter()
610        .map(|c| match c {
611            SelectColumn::AllColumns => Ok(SelectColumn::AllColumns),
612            SelectColumn::AllFromOld => Ok(SelectColumn::AllFromOld),
613            SelectColumn::AllFromNew => Ok(SelectColumn::AllFromNew),
614            SelectColumn::Expr { expr, alias } => Ok(SelectColumn::Expr {
615                expr: materialize_expr(expr, exec_sub)?,
616                alias: alias.clone(),
617            }),
618        })
619        .collect::<Result<Vec<_>>>()?;
620    let order_by = stmt
621        .order_by
622        .iter()
623        .map(|ob| {
624            Ok(OrderByItem {
625                expr: materialize_expr(&ob.expr, exec_sub)?,
626                descending: ob.descending,
627                nulls_first: ob.nulls_first,
628            })
629        })
630        .collect::<Result<Vec<_>>>()?;
631    let joins = stmt
632        .joins
633        .iter()
634        .map(|j| {
635            let on_clause = j
636                .on_clause
637                .as_ref()
638                .map(|e| materialize_expr(e, exec_sub))
639                .transpose()?;
640            Ok(JoinClause {
641                join_type: j.join_type,
642                table: j.table.clone(),
643                subquery: j.subquery.clone(),
644                on_clause,
645            })
646        })
647        .collect::<Result<Vec<_>>>()?;
648    let group_by = stmt
649        .group_by
650        .iter()
651        .map(|e| materialize_expr(e, exec_sub))
652        .collect::<Result<Vec<_>>>()?;
653    Ok(SelectStmt {
654        columns,
655        from: stmt.from.clone(),
656        from_alias: stmt.from_alias.clone(),
657        from_subquery: stmt.from_subquery.clone(),
658        joins,
659        distinct: stmt.distinct,
660        where_clause,
661        order_by,
662        limit: stmt.limit.clone(),
663        offset: stmt.offset.clone(),
664        group_by,
665        having,
666    })
667}
668
669pub(super) fn exec_subquery_read(
670    db: &Database,
671    schema: &SchemaManager,
672    stmt: &SelectStmt,
673    ctes: &CteContext,
674) -> Result<QueryResult> {
675    match super::exec_select(db, schema, stmt, ctes)? {
676        ExecutionResult::Query(qr) => Ok(qr),
677        _ => Ok(QueryResult {
678            columns: vec![],
679            rows: vec![],
680        }),
681    }
682}
683
684pub(super) fn exec_subquery_write(
685    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
686    schema: &SchemaManager,
687    stmt: &SelectStmt,
688    ctes: &CteContext,
689) -> Result<QueryResult> {
690    match super::exec_select_in_txn(wtx, schema, stmt, ctes)? {
691        ExecutionResult::Query(qr) => Ok(qr),
692        _ => Ok(QueryResult {
693            columns: vec![],
694            rows: vec![],
695        }),
696    }
697}
698
699pub(super) fn update_has_subquery(stmt: &UpdateStmt) -> bool {
700    stmt.where_clause.as_ref().is_some_and(has_subquery)
701        || stmt.assignments.iter().any(|(_, e)| has_subquery(e))
702}
703
704pub(super) fn materialize_update(
705    stmt: &UpdateStmt,
706    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
707) -> Result<UpdateStmt> {
708    let where_clause = stmt
709        .where_clause
710        .as_ref()
711        .map(|e| materialize_expr(e, exec_sub))
712        .transpose()?;
713    let assignments = stmt
714        .assignments
715        .iter()
716        .map(|(name, expr)| Ok((name.clone(), materialize_expr(expr, exec_sub)?)))
717        .collect::<Result<Vec<_>>>()?;
718    Ok(UpdateStmt {
719        table: stmt.table.clone(),
720        assignments,
721        where_clause,
722        returning: stmt.returning.clone(),
723    })
724}
725
726pub(super) fn delete_has_subquery(stmt: &DeleteStmt) -> bool {
727    stmt.where_clause.as_ref().is_some_and(has_subquery)
728}
729
730pub(super) fn materialize_delete(
731    stmt: &DeleteStmt,
732    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
733) -> Result<DeleteStmt> {
734    let where_clause = stmt
735        .where_clause
736        .as_ref()
737        .map(|e| materialize_expr(e, exec_sub))
738        .transpose()?;
739    Ok(DeleteStmt {
740        table: stmt.table.clone(),
741        where_clause,
742        returning: stmt.returning.clone(),
743    })
744}
745
746pub(super) fn insert_has_subquery(stmt: &InsertStmt) -> bool {
747    match &stmt.source {
748        InsertSource::Values(rows) => rows.iter().any(|row| row.iter().any(has_subquery)),
749        // SELECT source subqueries are handled by exec_select's correlated/non-correlated paths
750        InsertSource::Select(_) => false,
751    }
752}
753
754pub(super) fn materialize_insert(
755    stmt: &InsertStmt,
756    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
757) -> Result<InsertStmt> {
758    let source = match &stmt.source {
759        InsertSource::Values(rows) => {
760            let mat = rows
761                .iter()
762                .map(|row| {
763                    row.iter()
764                        .map(|e| materialize_expr(e, exec_sub))
765                        .collect::<Result<Vec<_>>>()
766                })
767                .collect::<Result<Vec<_>>>()?;
768            InsertSource::Values(mat)
769        }
770        InsertSource::Select(sq) => {
771            let ctes = sq
772                .ctes
773                .iter()
774                .map(|c| {
775                    Ok(CteDefinition {
776                        name: c.name.clone(),
777                        column_aliases: c.column_aliases.clone(),
778                        body: materialize_query_body(&c.body, exec_sub)?,
779                    })
780                })
781                .collect::<Result<Vec<_>>>()?;
782            let body = materialize_query_body(&sq.body, exec_sub)?;
783            InsertSource::Select(Box::new(SelectQuery {
784                ctes,
785                recursive: sq.recursive,
786                body,
787            }))
788        }
789    };
790    Ok(InsertStmt {
791        table: stmt.table.clone(),
792        columns: stmt.columns.clone(),
793        source,
794        on_conflict: stmt.on_conflict.clone(),
795        returning: stmt.returning.clone(),
796    })
797}
798
799pub(super) fn materialize_query_body(
800    body: &QueryBody,
801    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
802) -> Result<QueryBody> {
803    match body {
804        QueryBody::Select(sel) => Ok(QueryBody::Select(Box::new(materialize_stmt(
805            sel, exec_sub,
806        )?))),
807        QueryBody::Compound(comp) => Ok(QueryBody::Compound(Box::new(CompoundSelect {
808            op: comp.op.clone(),
809            all: comp.all,
810            left: Box::new(materialize_query_body(&comp.left, exec_sub)?),
811            right: Box::new(materialize_query_body(&comp.right, exec_sub)?),
812            order_by: comp.order_by.clone(),
813            limit: comp.limit.clone(),
814            offset: comp.offset.clone(),
815        }))),
816        QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => Ok(body.clone()),
817    }
818}
819
820pub(super) fn exec_query_body(
821    db: &Database,
822    schema: &SchemaManager,
823    body: &QueryBody,
824    ctes: &CteContext,
825) -> Result<ExecutionResult> {
826    match body {
827        QueryBody::Select(sel) => super::exec_select(db, schema, sel, ctes),
828        QueryBody::Compound(comp) => exec_compound_select(db, schema, comp, ctes),
829        QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => Err(
830            SqlError::Unsupported("DML CTE bodies require an active write transaction".into()),
831        ),
832    }
833}
834
835pub(super) fn exec_query_body_in_txn(
836    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
837    schema: &SchemaManager,
838    body: &QueryBody,
839    ctes: &CteContext,
840) -> Result<ExecutionResult> {
841    match body {
842        QueryBody::Select(sel) => super::exec_select_in_txn(wtx, schema, sel, ctes),
843        QueryBody::Compound(comp) => exec_compound_select_in_txn(wtx, schema, comp, ctes),
844        QueryBody::Insert(ins) => exec_insert_in_txn_with_ctes(wtx, schema, ins, &[], ctes),
845        QueryBody::Update(upd) => super::exec_update_in_txn(wtx, schema, upd),
846        QueryBody::Delete(del) => super::exec_delete_in_txn(wtx, schema, del),
847    }
848}
849
850pub(super) fn exec_query_body_read(
851    db: &Database,
852    schema: &SchemaManager,
853    body: &QueryBody,
854    ctes: &CteContext,
855) -> Result<QueryResult> {
856    match exec_query_body(db, schema, body, ctes)? {
857        ExecutionResult::Query(qr) => Ok(qr),
858        _ => Ok(QueryResult {
859            columns: vec![],
860            rows: vec![],
861        }),
862    }
863}
864
865pub(super) fn exec_query_body_write(
866    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
867    schema: &SchemaManager,
868    body: &QueryBody,
869    ctes: &CteContext,
870) -> Result<QueryResult> {
871    match exec_query_body_in_txn(wtx, schema, body, ctes)? {
872        ExecutionResult::Query(qr) => Ok(qr),
873        _ => Ok(QueryResult {
874            columns: vec![],
875            rows: vec![],
876        }),
877    }
878}
879
880pub(super) fn exec_compound_select(
881    db: &Database,
882    schema: &SchemaManager,
883    comp: &CompoundSelect,
884    ctes: &CteContext,
885) -> Result<ExecutionResult> {
886    let left_qr = match exec_query_body(db, schema, &comp.left, ctes)? {
887        ExecutionResult::Query(qr) => qr,
888        _ => QueryResult {
889            columns: vec![],
890            rows: vec![],
891        },
892    };
893    let right_qr = match exec_query_body(db, schema, &comp.right, ctes)? {
894        ExecutionResult::Query(qr) => qr,
895        _ => QueryResult {
896            columns: vec![],
897            rows: vec![],
898        },
899    };
900    apply_set_operation(comp, left_qr, right_qr)
901}
902
903pub(super) fn exec_compound_select_in_txn(
904    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
905    schema: &SchemaManager,
906    comp: &CompoundSelect,
907    ctes: &CteContext,
908) -> Result<ExecutionResult> {
909    let left_qr = match exec_query_body_in_txn(wtx, schema, &comp.left, ctes)? {
910        ExecutionResult::Query(qr) => qr,
911        _ => QueryResult {
912            columns: vec![],
913            rows: vec![],
914        },
915    };
916    let right_qr = match exec_query_body_in_txn(wtx, schema, &comp.right, ctes)? {
917        ExecutionResult::Query(qr) => qr,
918        _ => QueryResult {
919            columns: vec![],
920            rows: vec![],
921        },
922    };
923    apply_set_operation(comp, left_qr, right_qr)
924}
925
926pub(super) fn apply_set_operation(
927    comp: &CompoundSelect,
928    left_qr: QueryResult,
929    right_qr: QueryResult,
930) -> Result<ExecutionResult> {
931    if !left_qr.columns.is_empty()
932        && !right_qr.columns.is_empty()
933        && left_qr.columns.len() != right_qr.columns.len()
934    {
935        return Err(SqlError::CompoundColumnCountMismatch {
936            left: left_qr.columns.len(),
937            right: right_qr.columns.len(),
938        });
939    }
940
941    let columns = left_qr.columns;
942
943    let mut rows = match (&comp.op, comp.all) {
944        (SetOp::Union, true) => {
945            let mut rows = left_qr.rows;
946            rows.extend(right_qr.rows);
947            rows
948        }
949        (SetOp::Union, false) => {
950            let mut seen: rustc_hash::FxHashSet<Vec<Value>> = rustc_hash::FxHashSet::default();
951            let mut rows = Vec::new();
952            for row in left_qr.rows.into_iter().chain(right_qr.rows) {
953                if !seen.contains(&row) {
954                    seen.insert(row.clone());
955                    rows.push(row);
956                }
957            }
958            rows
959        }
960        (SetOp::Intersect, true) => {
961            let mut right_counts: FxHashMap<Vec<Value>, usize> = FxHashMap::default();
962            for row in &right_qr.rows {
963                *right_counts.entry(row.clone()).or_insert(0) += 1;
964            }
965            let mut rows = Vec::new();
966            for row in left_qr.rows {
967                if let Some(count) = right_counts.get_mut(&row) {
968                    if *count > 0 {
969                        *count -= 1;
970                        rows.push(row);
971                    }
972                }
973            }
974            rows
975        }
976        (SetOp::Intersect, false) => {
977            let right_set: rustc_hash::FxHashSet<Vec<Value>> = right_qr.rows.into_iter().collect();
978            let mut seen: rustc_hash::FxHashSet<Vec<Value>> = rustc_hash::FxHashSet::default();
979            let mut rows = Vec::new();
980            for row in left_qr.rows {
981                if right_set.contains(&row) && !seen.contains(&row) {
982                    seen.insert(row.clone());
983                    rows.push(row);
984                }
985            }
986            rows
987        }
988        (SetOp::Except, true) => {
989            let mut right_counts: FxHashMap<Vec<Value>, usize> = FxHashMap::default();
990            for row in &right_qr.rows {
991                *right_counts.entry(row.clone()).or_insert(0) += 1;
992            }
993            let mut rows = Vec::new();
994            for row in left_qr.rows {
995                if let Some(count) = right_counts.get_mut(&row) {
996                    if *count > 0 {
997                        *count -= 1;
998                        continue;
999                    }
1000                }
1001                rows.push(row);
1002            }
1003            rows
1004        }
1005        (SetOp::Except, false) => {
1006            let right_set: rustc_hash::FxHashSet<Vec<Value>> = right_qr.rows.into_iter().collect();
1007            let mut seen: rustc_hash::FxHashSet<Vec<Value>> = rustc_hash::FxHashSet::default();
1008            let mut rows = Vec::new();
1009            for row in left_qr.rows {
1010                if !right_set.contains(&row) && !seen.contains(&row) {
1011                    seen.insert(row.clone());
1012                    rows.push(row);
1013                }
1014            }
1015            rows
1016        }
1017    };
1018
1019    if !comp.order_by.is_empty() {
1020        let col_defs: Vec<crate::types::ColumnDef> = columns
1021            .iter()
1022            .enumerate()
1023            .map(|(i, name)| crate::types::ColumnDef {
1024                name: name.clone(),
1025                data_type: crate::types::DataType::Null,
1026                nullable: true,
1027                position: i as u16,
1028                default_expr: None,
1029                default_sql: None,
1030                check_expr: None,
1031                check_sql: None,
1032                check_name: None,
1033                is_with_timezone: false,
1034                generated_expr: None,
1035                generated_sql: None,
1036                generated_kind: None,
1037                collation: crate::types::Collation::Binary,
1038            })
1039            .collect();
1040        sort_rows(&mut rows, &comp.order_by, &col_defs)?;
1041    }
1042
1043    if let Some(ref offset_expr) = comp.offset {
1044        let offset = eval_const_int(offset_expr)?.max(0) as usize;
1045        if offset < rows.len() {
1046            rows = rows.split_off(offset);
1047        } else {
1048            rows.clear();
1049        }
1050    }
1051
1052    if let Some(ref limit_expr) = comp.limit {
1053        let limit = eval_const_int(limit_expr)?.max(0) as usize;
1054        rows.truncate(limit);
1055    }
1056
1057    Ok(ExecutionResult::Query(QueryResult { columns, rows }))
1058}
1059
/// Scratch buffers reused across INSERT executions (held per thread in
/// `INSERT_SCRATCH`) so that per-row work does not allocate fresh vectors.
struct InsertBufs {
    /// Full logical row, indexed by column position (resized to the table's
    /// column count for each statement).
    row: Vec<Value>,
    /// Primary-key values, in `pk_indices` order.
    pk_values: Vec<Value>,
    /// Non-key values laid out by physical encoding slot.
    value_values: Vec<Value>,
    /// Encoded primary key of the row being written.
    key_buf: Vec<u8>,
    /// Encoded non-key payload of the row being written.
    value_buf: Vec<u8>,
    /// Table column position for each INSERT target column, in target order.
    col_indices: Vec<usize>,
    /// Encoded key used to probe a foreign-key parent table.
    fk_key_buf: Vec<u8>,
}
1069
1070impl InsertBufs {
1071    fn new() -> Self {
1072        Self {
1073            row: Vec::new(),
1074            pk_values: Vec::new(),
1075            value_values: Vec::new(),
1076            key_buf: Vec::with_capacity(64),
1077            value_buf: Vec::with_capacity(256),
1078            col_indices: Vec::new(),
1079            fk_key_buf: Vec::with_capacity(64),
1080        }
1081    }
1082}
1083
// Per-thread scratch space so INSERT / upsert execution can reuse buffer
// allocations between statements instead of allocating on every call.
thread_local! {
    // Accessed through `with_insert_scratch`.
    static INSERT_SCRATCH: RefCell<InsertBufs> = RefCell::new(InsertBufs::new());
    // NOTE(review): accessor is outside this view; presumably used by the
    // ON CONFLICT DO UPDATE path — confirm.
    static UPSERT_SCRATCH: RefCell<UpsertBufs> = RefCell::new(UpsertBufs::new());
}
1088
1089fn with_insert_scratch<R>(f: impl FnOnce(&mut InsertBufs) -> R) -> R {
1090    INSERT_SCRATCH.with(|slot| f(&mut slot.borrow_mut()))
1091}
1092
/// Scratch buffers held per thread in `UPSERT_SCRATCH`.
///
/// NOTE(review): the code that fills these is outside this view; the field
/// descriptions below are inferred from the names and presumably relate to
/// the ON CONFLICT DO UPDATE path — confirm.
pub(super) struct UpsertBufs {
    /// Presumably the decoded existing (conflicting) row.
    old_row: Vec<Value>,
    /// Presumably the row after applying the DO UPDATE assignments.
    new_row: Vec<Value>,
    /// Presumably non-key values by physical slot, ready for encoding.
    value_values: Vec<Value>,
    /// Presumably the encoded payload of `new_row`.
    new_value_buf: Vec<u8>,
}
1099
1100impl UpsertBufs {
1101    pub(super) fn new() -> Self {
1102        Self {
1103            old_row: Vec::new(),
1104            new_row: Vec::new(),
1105            value_values: Vec::new(),
1106            new_value_buf: Vec::with_capacity(256),
1107        }
1108    }
1109}
1110
1111pub fn exec_insert_in_txn(
1112    wtx: &mut WriteTxn<'_>,
1113    schema: &SchemaManager,
1114    stmt: &InsertStmt,
1115    params: &[Value],
1116) -> Result<ExecutionResult> {
1117    with_insert_scratch(|bufs| {
1118        exec_insert_in_txn_impl(
1119            wtx,
1120            schema,
1121            stmt,
1122            params,
1123            bufs,
1124            None,
1125            &CteContext::default(),
1126        )
1127    })
1128}
1129
1130pub(super) fn exec_insert_in_txn_with_ctes(
1131    wtx: &mut WriteTxn<'_>,
1132    schema: &SchemaManager,
1133    stmt: &InsertStmt,
1134    params: &[Value],
1135    outer_ctes: &CteContext,
1136) -> Result<ExecutionResult> {
1137    with_insert_scratch(|bufs| {
1138        exec_insert_in_txn_impl(wtx, schema, stmt, params, bufs, None, outer_ctes)
1139    })
1140}
1141
1142fn exec_insert_in_txn_cached(
1143    wtx: &mut WriteTxn<'_>,
1144    schema: &SchemaManager,
1145    stmt: &InsertStmt,
1146    params: &[Value],
1147    cache: &InsertCache,
1148) -> Result<ExecutionResult> {
1149    with_insert_scratch(|bufs| {
1150        exec_insert_in_txn_impl(
1151            wtx,
1152            schema,
1153            stmt,
1154            params,
1155            bufs,
1156            Some(cache),
1157            &CteContext::default(),
1158        )
1159    })
1160}
1161
/// Shared implementation behind the in-transaction INSERT entry points
/// (`exec_insert_in_txn`, `exec_insert_in_txn_with_ctes`,
/// `exec_insert_in_txn_cached`).
///
/// For every source row (a `VALUES` tuple, or a row produced by the
/// `INSERT ... SELECT` query), it does the following:
/// - fills `bufs.row`;
/// - applies column defaults and stored generated columns;
/// - checks NOT NULL, CHECK and foreign-key constraints;
/// - encodes the primary key and the remaining columns;
/// - writes the row, going through the compiled `ON CONFLICT` clause when
///   there is one.
///
/// * `bufs`: reusable scratch buffers, cleared or resized here as needed.
/// * `cache`: precomputed statement metadata. When `Some`, it replaces the
///   per-call derivations from `table_schema`. In that case the
///   generated-column target check is skipped, and so is the VALUES arity
///   check when a bind plan exists. Presumably both were done when the cache
///   was built (TODO confirm).
/// * `outer_ctes`: CTEs visible to the `SELECT` source of `INSERT ... SELECT`.
///
/// Returns `RowsAffected`, counting inserted rows plus rows updated through
/// `ON CONFLICT`; skipped rows are not counted. When the statement has a
/// `RETURNING` clause, it returns the projected rows instead.
///
/// # Errors
/// - `TableNotFound`, `ColumnNotFound`, `CannotInsertIntoGeneratedColumn`
/// - `InvalidValue` (arity mismatch), `Parse` (unbound parameter),
///   `TypeMismatch`
/// - `NotNullViolation`, `CheckViolation`, `ForeignKeyViolation`
/// - `KeyTooLarge`, `RowTooLarge`, `DuplicateKey`, `Storage`
/// - anything raised while evaluating subqueries, defaults, generated columns
///   or the `ON CONFLICT` clause.
fn exec_insert_in_txn_impl(
    wtx: &mut WriteTxn<'_>,
    schema: &SchemaManager,
    stmt: &InsertStmt,
    params: &[Value],
    bufs: &mut InsertBufs,
    cache: Option<&InsertCache>,
    outer_ctes: &CteContext,
) -> Result<ExecutionResult> {
    // Subqueries inside the statement are evaluated up front by
    // `materialize_insert`, which yields a rewritten statement used below.
    // NOTE(review): they run with an empty CTE context rather than
    // `outer_ctes`; confirm this is intended.
    let empty_ctes = CteContext::default();
    let materialized;
    let has_sub = match cache {
        Some(c) => c.has_subquery,
        None => insert_has_subquery(stmt),
    };
    let stmt = if has_sub {
        materialized = materialize_insert(stmt, &mut |sub| {
            exec_subquery_write(wtx, schema, sub, &empty_ctes)
        })?;
        &materialized
    } else {
        stmt
    };

    // NOTE(review): `stmt.table` is used as-is. Unlike `exec_insert`, there is
    // no lowercasing and no view rejection here; presumably callers
    // normalize the name. Confirm.
    let table_schema = schema
        .get(&stmt.table)
        .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;

    // With no explicit column list, every table column is a target, in schema order.
    let default_columns;
    let insert_columns: &[String] = if stmt.columns.is_empty() {
        default_columns = table_schema
            .columns
            .iter()
            .map(|c| c.name.clone())
            .collect::<Vec<_>>();
        &default_columns
    } else {
        &stmt.columns
    };

    // Map each target column name to its table column position.
    bufs.col_indices.clear();
    if let Some(c) = cache {
        bufs.col_indices.extend_from_slice(&c.col_indices);
    } else {
        for name in insert_columns {
            bufs.col_indices.push(
                table_schema
                    .column_index(name)
                    .ok_or_else(|| SqlError::ColumnNotFound(name.clone()))?,
            );
        }
    }

    // Generated columns may not be assigned explicitly (checked only on the
    // uncached path).
    if cache.is_none() {
        for &ci in &bufs.col_indices {
            if table_schema.columns[ci].generated_kind.is_some() {
                return Err(SqlError::CannotInsertIntoGeneratedColumn(
                    table_schema.columns[ci].name.clone(),
                ));
            }
        }
    }

    // Stored generated columns to compute per row: taken from the cache, or
    // derived from the schema (with a fast-path evaluator where detected).
    let generated_cols_uncached: Vec<(usize, &Expr, FastGenEval)>;
    let cached_gen_positions: &[usize];
    let cached_gen_fast_evals: &[FastGenEval];
    if let Some(c) = cache {
        cached_gen_positions = &c.generated_col_positions;
        cached_gen_fast_evals = &c.generated_fast_evals;
        generated_cols_uncached = Vec::new();
    } else {
        cached_gen_positions = &[];
        cached_gen_fast_evals = &[];
        generated_cols_uncached = table_schema
            .columns
            .iter()
            .filter(|c| matches!(c.generated_kind, Some(crate::parser::GeneratedKind::Stored)))
            .map(|c| {
                let expr = c.generated_expr.as_ref().unwrap();
                let fe = detect_fast_gen_eval(expr, table_schema);
                (c.position as usize, expr, fe)
            })
            .collect();
    }
    // Column map for generated-column evaluation: borrowed from the cache
    // when present, otherwise built here (only if there are generated columns).
    let has_gen_cols = !cached_gen_positions.is_empty() || !generated_cols_uncached.is_empty();
    let row_col_map_for_gen_owned: Option<ColumnMap> = if !has_gen_cols || cache.is_some() {
        None
    } else {
        Some(ColumnMap::new(&table_schema.columns))
    };
    let row_col_map_for_gen: Option<&ColumnMap> = if !has_gen_cols {
        None
    } else if let Some(c) = cache {
        c.row_col_map.as_ref()
    } else {
        row_col_map_for_gen_owned.as_ref()
    };

    // Defaults apply only to columns that have one and are not insert targets.
    let any_defaults = match cache {
        Some(c) => c.any_defaults,
        None => table_schema
            .columns
            .iter()
            .any(|c| c.default_expr.is_some()),
    };
    let defaults: Vec<(usize, &Expr)> = if any_defaults {
        table_schema
            .columns
            .iter()
            .filter(|c| {
                c.default_expr.is_some() && !bufs.col_indices.contains(&(c.position as usize))
            })
            .map(|c| (c.position as usize, c.default_expr.as_ref().unwrap()))
            .collect()
    } else {
        Vec::new()
    };

    let has_checks = match cache {
        Some(c) => c.has_checks,
        None => table_schema.has_checks(),
    };
    let check_col_map = if has_checks {
        Some(ColumnMap::new(&table_schema.columns))
    } else {
        None
    };

    // Physical layout: which columns form the key, and where each non-key
    // column lands in the encoded value (`enc_pos`), plus slots of dropped
    // columns that must be written as NULL.
    let (pk_indices, non_pk, enc_pos, phys_count, dropped): (
        &[usize],
        &[usize],
        &[u16],
        usize,
        &[u16],
    ) = if let Some(c) = cache {
        (
            &c.pk_indices,
            &c.non_pk_indices,
            &c.encoding_positions,
            c.phys_count,
            &c.dropped_non_pk_slots,
        )
    } else {
        (
            table_schema.pk_indices(),
            table_schema.non_pk_indices(),
            table_schema.encoding_positions(),
            table_schema.physical_non_pk_count(),
            table_schema.dropped_non_pk_slots(),
        )
    };

    bufs.row.resize(table_schema.columns.len(), Value::Null);
    bufs.pk_values.resize(pk_indices.len(), Value::Null);
    bufs.value_values.resize(phys_count, Value::Null);

    let table_bytes = stmt.table.as_bytes();
    let has_fks = !table_schema.foreign_keys.is_empty();
    let has_indices = !table_schema.indices.is_empty();
    let has_defaults = !defaults.is_empty();

    // Reuse the cached compiled ON CONFLICT clause if available, otherwise compile it now.
    let compiled_conflict: Option<Arc<CompiledOnConflict>> = match (cache, &stmt.on_conflict) {
        (Some(c), Some(_)) if c.on_conflict.is_some() => c.on_conflict.clone(),
        (_, Some(oc)) => Some(Arc::new(compile_on_conflict(oc, table_schema)?)),
        (_, None) => None,
    };

    // ON CONFLICT handling needs a column map; build one only if the cache lacks it.
    let row_col_map_owned: Option<ColumnMap> =
        if compiled_conflict.is_some() && cache.and_then(|c| c.row_col_map.as_ref()).is_none() {
            Some(ColumnMap::new(&table_schema.columns))
        } else {
            None
        };
    let row_col_map: Option<&ColumnMap> = cache
        .and_then(|c| c.row_col_map.as_ref())
        .or(row_col_map_owned.as_ref());

    // For INSERT ... SELECT, run the source query to completion before writing.
    // Its CTE bodies may themselves be INSERTs (see `exec_query_body_in_txn`),
    // which re-enter the INSERT path while `bufs` is still borrowed by our caller.
    let select_rows = match &stmt.source {
        InsertSource::Select(sq) => {
            let insert_ctes = super::materialize_all_ctes_with_outer(
                &sq.ctes,
                sq.recursive,
                outer_ctes,
                &mut |body, ctx| exec_query_body_write(wtx, schema, body, ctx),
            )?;
            let qr = exec_query_body_write(wtx, schema, &sq.body, &insert_ctes)?;
            Some(qr.rows)
        }
        InsertSource::Values(_) => None,
    };

    let mut count: u64 = 0;
    let mut returning_rows: Option<Vec<super::helpers::ReturningRow>> =
        stmt.returning.as_ref().map(|_| Vec::new());

    let values = match &stmt.source {
        InsertSource::Values(rows) => Some(rows.as_slice()),
        InsertSource::Select(_) => None,
    };
    let sel_rows = select_rows.as_deref();

    let total = match (values, sel_rows) {
        (Some(rows), _) => rows.len(),
        (_, Some(rows)) => rows.len(),
        _ => 0,
    };

    // Only the first SELECT row is checked; rows of one result set are
    // presumably all the same width.
    if let Some(sel) = sel_rows {
        if !sel.is_empty() && sel[0].len() != insert_columns.len() {
            return Err(SqlError::InvalidValue(format!(
                "INSERT ... SELECT column count mismatch: expected {}, got {}",
                insert_columns.len(),
                sel[0].len()
            )));
        }
    }

    let skip_row_clear = cache.is_some_and(|c| c.row_fully_overwritten);
    for idx in 0..total {
        // Reset the row unless the cache says every slot is rewritten anyway.
        if !skip_row_clear {
            for v in bufs.row.iter_mut() {
                *v = Value::Null;
            }
        }

        // Step 1: populate the target columns, coercing to each column's type.
        if let Some(value_rows) = values {
            // NOTE(review): the bind-plan path does not use `idx`; presumably the
            // cache only carries a bind plan for single-row VALUES. Confirm.
            if let Some(plan) = cache.and_then(|c| c.bind_plan.as_ref()) {
                for action in plan {
                    match action {
                        BindAction::Param {
                            param_idx,
                            col_idx,
                            target,
                        } => {
                            let v = &params[*param_idx];
                            bufs.row[*col_idx] = if v.is_null() {
                                Value::Null
                            } else if v.data_type() == *target {
                                v.clone()
                            } else {
                                let got = v.data_type();
                                v.clone().coerce_into(*target).ok_or_else(|| {
                                    SqlError::TypeMismatch {
                                        expected: target.to_string(),
                                        got: got.to_string(),
                                    }
                                })?
                            };
                        }
                        BindAction::Literal { value, col_idx } => {
                            bufs.row[*col_idx] = value.clone();
                        }
                    }
                }
            } else {
                let value_row = &value_rows[idx];
                if value_row.len() != insert_columns.len() {
                    return Err(SqlError::InvalidValue(format!(
                        "expected {} values, got {}",
                        insert_columns.len(),
                        value_row.len()
                    )));
                }
                for (i, expr) in value_row.iter().enumerate() {
                    // Parameters are 1-based (`$1` is `params[0]`).
                    let val = match expr {
                        Expr::Parameter(n) => params
                            .get(n - 1)
                            .cloned()
                            .ok_or_else(|| SqlError::Parse(format!("unbound parameter ${n}")))?,
                        Expr::Literal(v) => v.clone(),
                        _ => eval_const_expr(expr)?,
                    };
                    let col_idx = bufs.col_indices[i];
                    let col = &table_schema.columns[col_idx];
                    let got_type = val.data_type();
                    bufs.row[col_idx] = if val.is_null() {
                        Value::Null
                    } else {
                        val.coerce_into(col.data_type)
                            .ok_or_else(|| SqlError::TypeMismatch {
                                expected: col.data_type.to_string(),
                                got: got_type.to_string(),
                            })?
                    };
                }
            }
        } else if let Some(sel) = sel_rows {
            let sel_row = &sel[idx];
            for (i, val) in sel_row.iter().enumerate() {
                let col_idx = bufs.col_indices[i];
                let col = &table_schema.columns[col_idx];
                let got_type = val.data_type();
                bufs.row[col_idx] = if val.is_null() {
                    Value::Null
                } else {
                    val.clone().coerce_into(col.data_type).ok_or_else(|| {
                        SqlError::TypeMismatch {
                            expected: col.data_type.to_string(),
                            got: got_type.to_string(),
                        }
                    })?
                };
            }
        }

        // Step 2: defaults for non-target columns, re-evaluated per row.
        // A NULL default leaves the slot as is.
        if has_defaults {
            for &(pos, def_expr) in &defaults {
                let val = eval_const_expr(def_expr)?;
                let col = &table_schema.columns[pos];
                if !val.is_null() {
                    let got_type = val.data_type();
                    bufs.row[pos] =
                        val.coerce_into(col.data_type)
                            .ok_or_else(|| SqlError::TypeMismatch {
                                expected: col.data_type.to_string(),
                                got: got_type.to_string(),
                            })?;
                }
            }
        }

        // Step 3: compute stored generated columns from the row built so far.
        if let Some(gen_map) = row_col_map_for_gen {
            if cache.is_some() {
                for (pos, fast) in cached_gen_positions
                    .iter()
                    .copied()
                    .zip(cached_gen_fast_evals.iter())
                {
                    let gen_expr = table_schema.columns[pos].generated_expr.as_ref().unwrap();
                    let val = eval_fast_gen(fast, gen_expr, &bufs.row, gen_map)?;
                    let col = &table_schema.columns[pos];
                    bufs.row[pos] = if val.is_null() {
                        Value::Null
                    } else {
                        let got_type = val.data_type();
                        val.coerce_into(col.data_type)
                            .ok_or_else(|| SqlError::TypeMismatch {
                                expected: col.data_type.to_string(),
                                got: got_type.to_string(),
                            })?
                    };
                }
            } else {
                for (pos, gen_expr, fast) in &generated_cols_uncached {
                    let val = eval_fast_gen(fast, gen_expr, &bufs.row, gen_map)?;
                    let col = &table_schema.columns[*pos];
                    bufs.row[*pos] = if val.is_null() {
                        Value::Null
                    } else {
                        let got_type = val.data_type();
                        val.coerce_into(col.data_type)
                            .ok_or_else(|| SqlError::TypeMismatch {
                                expected: col.data_type.to_string(),
                                got: got_type.to_string(),
                            })?
                    };
                }
            }
        }

        // Step 4: NOT NULL constraints.
        if let Some(c) = cache {
            for &pos in &c.not_null_indices {
                if bufs.row[pos as usize].is_null() {
                    return Err(SqlError::NotNullViolation(
                        table_schema.columns[pos as usize].name.clone(),
                    ));
                }
            }
        } else {
            for col in &table_schema.columns {
                if !col.nullable && bufs.row[col.position as usize].is_null() {
                    return Err(SqlError::NotNullViolation(col.name.clone()));
                }
            }
        }

        // Step 5: CHECK constraints (column-level, then table-level).
        // A NULL result passes; only a non-NULL falsy result is a violation.
        if let Some(ref col_map) = check_col_map {
            for col in &table_schema.columns {
                if let Some(ref check) = col.check_expr {
                    let result = eval_expr(check, &EvalCtx::new(col_map, &bufs.row))?;
                    if !is_truthy(&result) && !result.is_null() {
                        let name = col.check_name.as_deref().unwrap_or(&col.name);
                        return Err(SqlError::CheckViolation(name.to_string()));
                    }
                }
            }
            for tc in &table_schema.check_constraints {
                let result = eval_expr(&tc.expr, &EvalCtx::new(col_map, &bufs.row))?;
                if !is_truthy(&result) && !result.is_null() {
                    let name = tc.name.as_deref().unwrap_or(&tc.sql);
                    return Err(SqlError::CheckViolation(name.to_string()));
                }
            }
        }

        // Step 6: foreign keys. A key with any NULL component is not checked;
        // otherwise the referenced key must exist in the parent table.
        if has_fks {
            for fk in &table_schema.foreign_keys {
                let any_null = fk.columns.iter().any(|&ci| bufs.row[ci as usize].is_null());
                if any_null {
                    continue;
                }
                crate::encoding::encode_composite_key_from_indices(
                    &fk.columns,
                    &bufs.row,
                    &mut bufs.fk_key_buf,
                );
                let found = wtx
                    .table_get(fk.foreign_table.as_bytes(), &bufs.fk_key_buf)
                    .map_err(SqlError::Storage)?;
                if found.is_none() {
                    let name = fk.name.as_deref().unwrap_or(&fk.foreign_table);
                    return Err(SqlError::ForeignKeyViolation(name.to_string()));
                }
            }
        }

        // Snapshot the full row for RETURNING before values are moved out of `bufs.row`.
        let proposed_row_for_returning: Option<Vec<Value>> =
            returning_rows.as_ref().map(|_| bufs.row.clone());

        // Step 7: encode the key. PK values are moved out of `bufs.row`
        // (leaving NULL). Single-integer keys use the int encoding when the
        // cache allows it.
        for (j, &i) in pk_indices.iter().enumerate() {
            bufs.pk_values[j] = std::mem::replace(&mut bufs.row[i], Value::Null);
        }
        match cache.map(|c| c.single_int_pk).unwrap_or(false) {
            true => match bufs.pk_values[0] {
                Value::Integer(v) => crate::encoding::encode_int_key_into(v, &mut bufs.key_buf),
                _ => encode_composite_key_into(&bufs.pk_values, &mut bufs.key_buf),
            },
            false => encode_composite_key_into(&bufs.pk_values, &mut bufs.key_buf),
        }

        // Step 8: lay out non-key values by physical slot and encode them.
        // Dropped slots and VIRTUAL generated columns are stored as NULL.
        for &slot in dropped {
            bufs.value_values[slot as usize] = Value::Null;
        }
        for (j, &i) in non_pk.iter().enumerate() {
            let col = &table_schema.columns[i];
            if matches!(
                col.generated_kind,
                Some(crate::parser::GeneratedKind::Virtual)
            ) {
                bufs.value_values[enc_pos[j] as usize] = Value::Null;
                bufs.row[i] = Value::Null;
            } else {
                bufs.value_values[enc_pos[j] as usize] =
                    std::mem::replace(&mut bufs.row[i], Value::Null);
            }
        }
        match cache.and_then(|c| c.row_encoder.as_ref()) {
            Some(tmpl) => crate::encoding::encode_int_row_with_template(
                tmpl,
                &bufs.value_values,
                &mut bufs.value_buf,
            )?,
            None => encode_row_into(&bufs.value_values, &mut bufs.value_buf),
        }

        if bufs.key_buf.len() > citadel_core::MAX_KEY_SIZE {
            return Err(SqlError::KeyTooLarge {
                size: bufs.key_buf.len(),
                max: citadel_core::MAX_KEY_SIZE,
            });
        }
        if bufs.value_buf.len() > citadel_core::MAX_INLINE_VALUE_SIZE {
            return Err(SqlError::RowTooLarge {
                size: bufs.value_buf.len(),
                max: citadel_core::MAX_INLINE_VALUE_SIZE,
            });
        }

        // Step 9: write the row.
        match compiled_conflict.as_ref() {
            None => {
                // Plain INSERT: an existing key is a hard error.
                let is_new = wtx
                    .table_insert(table_bytes, &bufs.key_buf, &bufs.value_buf)
                    .map_err(SqlError::Storage)?;
                if !is_new {
                    return Err(SqlError::DuplicateKey);
                }
                if has_indices {
                    // Reassemble the logical row (its values were moved out
                    // above) so secondary index entries can be built.
                    for (j, &i) in pk_indices.iter().enumerate() {
                        bufs.row[i] = bufs.pk_values[j].clone();
                    }
                    for (j, &i) in non_pk.iter().enumerate() {
                        bufs.row[i] = std::mem::replace(
                            &mut bufs.value_values[enc_pos[j] as usize],
                            Value::Null,
                        );
                    }
                    insert_index_entries(wtx, table_schema, &bufs.row, &bufs.pk_values)?;
                }
                count += 1;
                if let Some(buf) = returning_rows.as_mut() {
                    buf.push((None, proposed_row_for_returning));
                }
            }
            Some(oc) => {
                let oc_ref: &CompiledOnConflict = oc;
                // Reassemble the logical row only if the conflict handling needs it.
                let needs_row = upsert_needs_row(oc_ref, table_schema);
                if needs_row {
                    for (j, &i) in pk_indices.iter().enumerate() {
                        bufs.row[i] = bufs.pk_values[j].clone();
                    }
                    for (j, &i) in non_pk.iter().enumerate() {
                        bufs.row[i] = std::mem::replace(
                            &mut bufs.value_values[enc_pos[j] as usize],
                            Value::Null,
                        );
                    }
                }
                // `row_col_map` is always `Some` here: it is built above
                // whenever `compiled_conflict` is set and the cache lacks one.
                let outcome = apply_insert_with_conflict(
                    wtx,
                    table_schema,
                    &bufs.key_buf,
                    &bufs.value_buf,
                    &bufs.row,
                    &bufs.pk_values,
                    oc_ref,
                    row_col_map.unwrap(),
                    stmt.returning.is_some(),
                )?;
                match outcome {
                    InsertRowOutcome::Inserted => {
                        count += 1;
                        if let Some(buf) = returning_rows.as_mut() {
                            buf.push((None, proposed_row_for_returning));
                        }
                    }
                    InsertRowOutcome::Updated { old, new } => {
                        count += 1;
                        if let Some(buf) = returning_rows.as_mut() {
                            buf.push((Some(old), Some(new)));
                        }
                    }
                    InsertRowOutcome::Skipped => {}
                }
            }
        }
    }

    if let (Some(returning_cols), Some(rows)) = (stmt.returning.as_ref(), returning_rows) {
        return Ok(ExecutionResult::Query(super::helpers::project_returning(
            table_schema,
            returning_cols,
            &rows,
        )?));
    }

    Ok(ExecutionResult::RowsAffected(count))
}
1709
/// Reusable compiled form of an INSERT statement.
///
/// NOTE(review): `table_lower` is presumably the target table name in ASCII
/// lowercase, and `cached` the schema-derived plan once it has been built;
/// when the cache is filled or invalidated is decided outside this chunk.
pub struct CompiledInsert {
    table_lower: String,
    cached: Option<InsertCache>,
}
1714
/// Schema-derived data cached for a [`CompiledInsert`].
///
/// NOTE(review): the per-field notes below are inferred from field names and
/// from how similarly named data is consumed in this file (e.g. by
/// `build_trivial_fast_program`); confirm against the code that fills the cache.
struct InsertCache {
    // Presumably the schema column index of each INSERT target column.
    col_indices: Vec<usize>,
    has_subquery: bool,
    any_defaults: bool,
    has_checks: bool,
    // Compiled ON CONFLICT clause, if the statement has one.
    on_conflict: Option<Arc<CompiledOnConflict>>,
    // Column map over the full row, used when evaluating expressions against it.
    row_col_map: Option<ColumnMap>,
    // Generated columns and, in the same order, their fast evaluators
    // (cf. the `generated_fast_evals[i]` / `generated_col_positions[i]` pairing
    // in `build_trivial_fast_program`).
    generated_col_positions: Vec<usize>,
    generated_fast_evals: Vec<FastGenEval>,
    // Key / value column layout of the table.
    pk_indices: Vec<usize>,
    non_pk_indices: Vec<usize>,
    encoding_positions: Vec<u16>,
    dropped_non_pk_slots: Vec<u16>,
    phys_count: usize,
    single_int_pk: bool,
    not_null_indices: Vec<u16>,
    // Per-row binding of parameters / literals to columns, when available.
    bind_plan: Option<Vec<BindAction>>,
    row_fully_overwritten: bool,
    row_encoder: Option<crate::encoding::IntRowTemplate>,
    // Fast path for all-integer inserts; see `TrivialFastProgram`.
    is_trivial_fast: bool,
    trivial_fast_program: Option<TrivialFastProgram>,
}
1737
/// One step of an INSERT bind plan: where the value for column `col_idx`
/// comes from.
#[derive(Clone)]
enum BindAction {
    /// Value comes from statement parameter `param_idx`; `target` is the
    /// expected data type (`build_trivial_fast_program` only accepts
    /// `DataType::Integer`).
    Param {
        param_idx: usize,
        col_idx: usize,
        target: DataType,
    },
    /// Value is the constant `value`.
    Literal {
        value: Value,
        col_idx: usize,
    },
}
1750
/// Precompiled row-building program for the simplest INSERT shape, produced
/// by `build_trivial_fast_program`: a row template plus integer writes at
/// fixed byte offsets.
#[derive(Clone)]
struct TrivialFastProgram {
    /// Copy of the encoder's row template that `ops` write into.
    template: Vec<u8>,
    /// Integer writes into `template`, in the order the builder emitted them.
    ops: Vec<WriteOp>,
    /// Parameter index that supplies the primary-key value.
    pk_param: u8,
    /// Parameter indices bound to NOT NULL non-PK columns.
    /// NOTE(review): presumably checked for NULL at execution time — confirm.
    not_null_param_indices: Vec<u8>,
}
1758
/// A single write performed by a [`TrivialFastProgram`]. All `off` fields are
/// byte offsets into the program's template.
///
/// For the generated-column variants, `bitmap_byte_off` / `bitmap_bit_mask`
/// locate the column's bit in a per-row bitmap (the builder computes them as
/// byte `2 + slot / 8`, bit `1 << (slot % 8)`).
/// NOTE(review): presumably the null bitmap — confirm against the encoder.
#[derive(Clone)]
enum WriteOp {
    /// Write parameter `param_idx` as an i64.
    ParamI64 {
        param_idx: u8,
        off: u32,
    },
    /// Write the constant `value`.
    LiteralI64 {
        value: i64,
        off: u32,
    },
    /// Generated column computed as `param[a_param] + param[b_param]`.
    GenAddParamsI64 {
        a_param: u8,
        b_param: u8,
        off: u32,
        bitmap_byte_off: u32,
        bitmap_bit_mask: u8,
    },
    /// Generated column computed as `param[param_idx] * mul + add`.
    GenMulAddParamI64 {
        param_idx: u8,
        mul: i64,
        add: i64,
        off: u32,
        bitmap_byte_off: u32,
        bitmap_bit_mask: u8,
    },
}
1785
1786fn build_trivial_fast_program(
1787    bind_plan: &[BindAction],
1788    row_encoder: &crate::encoding::IntRowTemplate,
1789    non_virtual_pairs: &[(usize, usize)],
1790    generated_col_positions: &[usize],
1791    generated_fast_evals: &[FastGenEval],
1792    pk_indices: &[usize],
1793    columns: &[crate::types::ColumnDef],
1794) -> Option<TrivialFastProgram> {
1795    let pk_col = pk_indices[0];
1796    let col_to_slot: rustc_hash::FxHashMap<usize, usize> =
1797        non_virtual_pairs.iter().copied().collect();
1798    let slot_to_off: rustc_hash::FxHashMap<usize, usize> =
1799        row_encoder.slot_offsets.iter().copied().collect();
1800
1801    let mut col_to_param: rustc_hash::FxHashMap<usize, u8> = Default::default();
1802    let mut col_to_lit_int: rustc_hash::FxHashMap<usize, i64> = Default::default();
1803    let mut pk_param: Option<u8> = None;
1804    let mut ops: Vec<WriteOp> = Vec::with_capacity(bind_plan.len() + generated_col_positions.len());
1805    let mut not_null_param_indices: Vec<u8> = Vec::new();
1806
1807    for action in bind_plan {
1808        match action {
1809            BindAction::Param {
1810                param_idx,
1811                col_idx,
1812                target,
1813            } => {
1814                if *target != DataType::Integer {
1815                    return None;
1816                }
1817                let pi: u8 = u8::try_from(*param_idx).ok()?;
1818                col_to_param.insert(*col_idx, pi);
1819                if *col_idx == pk_col {
1820                    pk_param = Some(pi);
1821                } else {
1822                    let slot = *col_to_slot.get(col_idx)?;
1823                    let off = u32::try_from(*slot_to_off.get(&slot)?).ok()?;
1824                    ops.push(WriteOp::ParamI64 { param_idx: pi, off });
1825                    if !columns[*col_idx].nullable {
1826                        not_null_param_indices.push(pi);
1827                    }
1828                }
1829            }
1830            BindAction::Literal { value, col_idx } => match value {
1831                Value::Integer(v) => {
1832                    col_to_lit_int.insert(*col_idx, *v);
1833                    if *col_idx == pk_col {
1834                        return None;
1835                    }
1836                    let slot = *col_to_slot.get(col_idx)?;
1837                    let off = u32::try_from(*slot_to_off.get(&slot)?).ok()?;
1838                    ops.push(WriteOp::LiteralI64 { value: *v, off });
1839                }
1840                _ => return None,
1841            },
1842        }
1843    }
1844
1845    let pk_param = pk_param?;
1846
1847    for (i, &gen_pos) in generated_col_positions.iter().enumerate() {
1848        let gen_slot = *col_to_slot.get(&gen_pos)?;
1849        let gen_off = u32::try_from(*slot_to_off.get(&gen_slot)?).ok()?;
1850        let bitmap_byte_off = u32::try_from(2 + gen_slot / 8).ok()?;
1851        let bitmap_bit_mask: u8 = 1u8 << (gen_slot % 8);
1852        let gen_col_nullable = columns[gen_pos].nullable;
1853
1854        match &generated_fast_evals[i] {
1855            FastGenEval::IntColAddCol {
1856                left_idx,
1857                right_idx,
1858            } => {
1859                let a_param = col_to_param.get(left_idx).copied();
1860                let b_param = col_to_param.get(right_idx).copied();
1861                match (a_param, b_param) {
1862                    (Some(ap), Some(bp)) => {
1863                        let deps_safe = gen_col_nullable
1864                            || (not_null_param_indices.contains(&ap)
1865                                && not_null_param_indices.contains(&bp));
1866                        if !deps_safe {
1867                            return None;
1868                        }
1869                        ops.push(WriteOp::GenAddParamsI64 {
1870                            a_param: ap,
1871                            b_param: bp,
1872                            off: gen_off,
1873                            bitmap_byte_off,
1874                            bitmap_bit_mask,
1875                        });
1876                    }
1877                    (Some(p), None) => {
1878                        let lit = col_to_lit_int.get(right_idx).copied()?;
1879                        if !gen_col_nullable && !not_null_param_indices.contains(&p) {
1880                            return None;
1881                        }
1882                        ops.push(WriteOp::GenMulAddParamI64 {
1883                            param_idx: p,
1884                            mul: 1,
1885                            add: lit,
1886                            off: gen_off,
1887                            bitmap_byte_off,
1888                            bitmap_bit_mask,
1889                        });
1890                    }
1891                    (None, Some(p)) => {
1892                        let lit = col_to_lit_int.get(left_idx).copied()?;
1893                        if !gen_col_nullable && !not_null_param_indices.contains(&p) {
1894                            return None;
1895                        }
1896                        ops.push(WriteOp::GenMulAddParamI64 {
1897                            param_idx: p,
1898                            mul: 1,
1899                            add: lit,
1900                            off: gen_off,
1901                            bitmap_byte_off,
1902                            bitmap_bit_mask,
1903                        });
1904                    }
1905                    (None, None) => {
1906                        let la = col_to_lit_int.get(left_idx).copied()?;
1907                        let lb = col_to_lit_int.get(right_idx).copied()?;
1908                        ops.push(WriteOp::LiteralI64 {
1909                            value: la.wrapping_add(lb),
1910                            off: gen_off,
1911                        });
1912                    }
1913                }
1914            }
1915            FastGenEval::IntColMulAdd {
1916                col_schema_idx,
1917                mul,
1918                add,
1919            } => {
1920                if let Some(p) = col_to_param.get(col_schema_idx).copied() {
1921                    if !gen_col_nullable && !not_null_param_indices.contains(&p) {
1922                        return None;
1923                    }
1924                    ops.push(WriteOp::GenMulAddParamI64 {
1925                        param_idx: p,
1926                        mul: *mul,
1927                        add: *add,
1928                        off: gen_off,
1929                        bitmap_byte_off,
1930                        bitmap_bit_mask,
1931                    });
1932                } else if let Some(lit) = col_to_lit_int.get(col_schema_idx).copied() {
1933                    ops.push(WriteOp::LiteralI64 {
1934                        value: lit.wrapping_mul(*mul).wrapping_add(*add),
1935                        off: gen_off,
1936                    });
1937                } else {
1938                    return None;
1939                }
1940            }
1941            FastGenEval::None => return None,
1942        }
1943    }
1944
1945    Some(TrivialFastProgram {
1946        template: row_encoder.template.clone(),
1947        ops,
1948        pk_param,
1949        not_null_param_indices,
1950    })
1951}
1952
/// Planner output for an INSERT's `ON CONFLICT` clause.
#[derive(Clone)]
pub(super) enum CompiledOnConflict {
    /// `DO NOTHING`; `target: None` means "any conflicting constraint".
    DoNothing {
        target: Option<ConflictKind>,
    },
    /// `DO UPDATE SET ... [WHERE ...]` against the constraint `target`.
    DoUpdate {
        target: ConflictKind,
        /// `(schema column index, expression)` pairs from the SET list.
        assignments: Vec<(usize, Expr)>,
        where_clause: Option<Expr>,
        /// Optional byte-patch plan; `apply_do_update_fused` uses it when the
        /// table has no CHECK constraints.
        fast_paths: Option<Vec<DoUpdateFastPath>>,
    },
}
1965
/// A DO UPDATE assignment that can be applied directly to the encoded row
/// bytes (see `apply_fast_path_patch`).
#[derive(Clone, Copy)]
pub(super) enum DoUpdateFastPath {
    /// Add `delta` (wrapping) to the INTEGER at encoded column `phys_idx`;
    /// NULL stays NULL.
    IntAddConst { phys_idx: usize, delta: i64 },
}
1970
/// The unique constraint an `ON CONFLICT` target resolved to.
#[derive(Clone, Debug)]
pub(super) enum ConflictKind {
    /// The table's primary key.
    PrimaryKey,
    /// The unique index at `TableSchema::indices[index_idx]`.
    UniqueIndex { index_idx: usize },
}
1976
1977fn resolve_conflict_target(target: &ConflictTarget, ts: &TableSchema) -> Result<ConflictKind> {
1978    match target {
1979        ConflictTarget::Columns(cols) => {
1980            let col_idx_set: Vec<u16> = cols
1981                .iter()
1982                .map(|name| {
1983                    ts.column_index(name)
1984                        .map(|i| i as u16)
1985                        .ok_or_else(|| SqlError::ColumnNotFound(name.clone()))
1986                })
1987                .collect::<Result<_>>()?;
1988            let pk_set = ts.primary_key_columns.clone();
1989            if set_equal(&col_idx_set, &pk_set) {
1990                return Ok(ConflictKind::PrimaryKey);
1991            }
1992            for (index_idx, idx) in ts.indices.iter().enumerate() {
1993                if idx.unique && set_equal(&col_idx_set, &idx.columns) {
1994                    return Ok(ConflictKind::UniqueIndex { index_idx });
1995                }
1996            }
1997            Err(SqlError::Plan(
1998                "ON CONFLICT target does not match any unique constraint".into(),
1999            ))
2000        }
2001        ConflictTarget::Constraint(name) => {
2002            let lower = name.to_ascii_lowercase();
2003            for (index_idx, idx) in ts.indices.iter().enumerate() {
2004                if idx.name.eq_ignore_ascii_case(&lower) {
2005                    if idx.unique {
2006                        return Ok(ConflictKind::UniqueIndex { index_idx });
2007                    }
2008                    return Err(SqlError::Plan(format!(
2009                        "ON CONFLICT ON CONSTRAINT '{name}' requires a unique index"
2010                    )));
2011                }
2012            }
2013            Err(SqlError::Plan(format!(
2014                "unknown constraint '{name}'; primary keys cannot be referenced by name, use ON CONFLICT (col_list)"
2015            )))
2016        }
2017    }
2018}
2019
/// Returns `true` when `a` and `b` hold the same column indices with the same
/// multiplicities, ignoring order (multiset equality).
fn set_equal(a: &[u16], b: &[u16]) -> bool {
    let sorted = |cols: &[u16]| {
        let mut owned = cols.to_vec();
        owned.sort_unstable();
        owned
    };
    a.len() == b.len() && sorted(a) == sorted(b)
}
2030
/// Result of inserting one row under an `ON CONFLICT` clause.
pub(super) enum InsertRowOutcome {
    /// A new row was written. `apply_do_update_fused` also reports an applied
    /// update this way when RETURNING rows were not requested.
    Inserted,
    /// An existing row was updated; full old and new rows, for RETURNING.
    Updated { old: Vec<Value>, new: Vec<Value> },
    /// The conflict was resolved by doing nothing (DO NOTHING, or a DO UPDATE
    /// whose WHERE clause rejected the row).
    Skipped,
}
2036
/// Inserts one already-encoded row (`key_buf` / `value_buf`) and resolves any
/// conflict according to `on_conflict`.
///
/// `row` is the proposed row in schema column order. It is used to build index
/// entries, to look up a colliding unique-index entry, and as `EXCLUDED` in
/// DO UPDATE expressions. `pk_values` holds its primary-key values.
///
/// Dispatch order:
/// 1. `DO NOTHING` on the primary key (or with no target), when the table has
///    no secondary indexes and no foreign keys, becomes a single
///    `table_insert_if_absent`.
/// 2. `DO UPDATE` on the primary key that [`can_fuse_do_update`] accepts is
///    delegated to [`apply_do_update_fused`].
/// 3. Otherwise the row goes through `table_insert_or_fetch`:
///    * On a fresh insert, index entries are added. If a unique index
///      collides, the partial insert is undone. The conflict is then resolved
///      only when it matches the declared target (`DO NOTHING` with no target
///      matches any index); otherwise the result is
///      `SqlError::UniqueViolation`.
///    * On an existing primary key, only a primary-key target (or `DO NOTHING`
///      with no target) is accepted; otherwise the result is
///      `SqlError::DuplicateKey`.
///
/// `capture_returning` asks the update paths to return old/new rows.
#[allow(clippy::too_many_arguments)]
#[inline]
pub(super) fn apply_insert_with_conflict(
    wtx: &mut WriteTxn<'_>,
    table_schema: &TableSchema,
    key_buf: &[u8],
    value_buf: &[u8],
    row: &[Value],
    pk_values: &[Value],
    on_conflict: &CompiledOnConflict,
    col_map: &ColumnMap,
    capture_returning: bool,
) -> Result<InsertRowOutcome> {
    let table_bytes = table_schema.name.as_bytes();

    // Cheapest case: with no index or FK bookkeeping, an insert-if-absent on
    // the primary table settles everything.
    if let CompiledOnConflict::DoNothing { target } = on_conflict {
        let pk_target = matches!(target, None | Some(ConflictKind::PrimaryKey));
        if pk_target && table_schema.indices.is_empty() && table_schema.foreign_keys.is_empty() {
            let inserted = wtx
                .table_insert_if_absent(table_bytes, key_buf, value_buf)
                .map_err(SqlError::Storage)?;
            return Ok(if inserted {
                InsertRowOutcome::Inserted
            } else {
                InsertRowOutcome::Skipped
            });
        }
    }

    // PK-targeted DO UPDATE on a simple enough table: one fused storage call.
    if let CompiledOnConflict::DoUpdate {
        target: ConflictKind::PrimaryKey,
        assignments,
        where_clause,
        fast_paths,
    } = on_conflict
    {
        if can_fuse_do_update(table_schema, assignments) {
            return apply_do_update_fused(
                wtx,
                table_schema,
                table_bytes,
                key_buf,
                value_buf,
                row,
                assignments,
                where_clause.as_ref(),
                col_map,
                fast_paths.as_deref(),
                capture_returning,
            );
        }
    }

    // General path: write the primary row, or get back the bytes already stored under the key.
    let primary_outcome = wtx
        .table_insert_or_fetch(table_bytes, key_buf, value_buf)
        .map_err(SqlError::Storage)?;

    match primary_outcome {
        citadel_txn::write_txn::InsertOutcome::Inserted => {
            // No secondary indexes: nothing else can conflict.
            if table_schema.indices.is_empty() {
                return Ok(InsertRowOutcome::Inserted);
            }
            let mut inserted_keys: Vec<(usize, Vec<u8>)> = Vec::new();
            match insert_index_entries_or_fetch(
                wtx,
                table_schema,
                row,
                pk_values,
                &mut inserted_keys,
            )? {
                None => Ok(InsertRowOutcome::Inserted),
                Some(conflicting_idx) => {
                    // A unique index collided; only an untargeted DO NOTHING or a target
                    // naming this exact index may absorb it.
                    let matches_target =
                        matches!(on_conflict, CompiledOnConflict::DoNothing { target: None })
                            || matches!(
                                on_conflict,
                                CompiledOnConflict::DoNothing {
                                    target: Some(ConflictKind::UniqueIndex { index_idx }),
                                } | CompiledOnConflict::DoUpdate {
                                    target: ConflictKind::UniqueIndex { index_idx },
                                    ..
                                } if *index_idx == conflicting_idx
                            );
                    // Roll back the primary row and the index entries written so far
                    // before either resolving the conflict or reporting it.
                    undo_partial_insert(wtx, table_schema, key_buf, &inserted_keys)?;
                    if !matches_target {
                        return Err(SqlError::UniqueViolation(
                            table_schema.indices[conflicting_idx].name.clone(),
                        ));
                    }
                    match on_conflict {
                        CompiledOnConflict::DoNothing { .. } => Ok(InsertRowOutcome::Skipped),
                        CompiledOnConflict::DoUpdate {
                            assignments,
                            where_clause,
                            ..
                        } => {
                            // Locate the existing row via the primary key stored in the
                            // colliding unique-index entry.
                            let existing_pk =
                                fetch_unique_index_pk(wtx, table_schema, conflicting_idx, row)?;
                            apply_do_update(
                                wtx,
                                table_schema,
                                &existing_pk,
                                row,
                                assignments,
                                where_clause.as_ref(),
                                col_map,
                                capture_returning,
                            )
                        }
                    }
                }
            }
        }
        citadel_txn::write_txn::InsertOutcome::Existed(old_bytes) => {
            // Primary-key collision: only a PK target (or untargeted DO NOTHING) covers it.
            let matches_target = matches!(
                on_conflict,
                CompiledOnConflict::DoNothing { target: None }
                    | CompiledOnConflict::DoNothing {
                        target: Some(ConflictKind::PrimaryKey),
                    }
                    | CompiledOnConflict::DoUpdate {
                        target: ConflictKind::PrimaryKey,
                        ..
                    }
            );
            if !matches_target {
                return Err(SqlError::DuplicateKey);
            }
            match on_conflict {
                CompiledOnConflict::DoNothing { .. } => Ok(InsertRowOutcome::Skipped),
                CompiledOnConflict::DoUpdate {
                    assignments,
                    where_clause,
                    ..
                } => {
                    // The fetch already returned the stored bytes; decode them instead
                    // of re-reading the row.
                    let old_row = decode_full_row(table_schema, key_buf, &old_bytes)?;
                    apply_do_update_with_old_row(
                        wtx,
                        table_schema,
                        key_buf,
                        &old_row,
                        row,
                        assignments,
                        where_clause.as_ref(),
                        col_map,
                        capture_returning,
                    )
                }
            }
        }
    }
}
2189
2190#[inline]
2191fn apply_fast_path_patch(
2192    old_bytes: &[u8],
2193    fast_paths: &[DoUpdateFastPath],
2194) -> Result<UpsertAction> {
2195    UPSERT_SCRATCH.with(|slot| {
2196        let mut bufs = slot.borrow_mut();
2197        bufs.new_value_buf.clear();
2198        bufs.new_value_buf.extend_from_slice(old_bytes);
2199
2200        let mut patch_scratch: Vec<u8> = Vec::new();
2201
2202        for fp in fast_paths {
2203            match fp {
2204                DoUpdateFastPath::IntAddConst { phys_idx, delta } => {
2205                    let decoded =
2206                        crate::encoding::decode_columns(&bufs.new_value_buf, &[*phys_idx])?;
2207                    let old_val = &decoded[0];
2208                    let new_val = match old_val {
2209                        Value::Integer(i) => Value::Integer(i.wrapping_add(*delta)),
2210                        Value::Null => Value::Null,
2211                        _ => {
2212                            return Err(SqlError::TypeMismatch {
2213                                expected: "INTEGER".into(),
2214                                got: old_val.data_type().to_string(),
2215                            });
2216                        }
2217                    };
2218                    if !crate::encoding::patch_column_in_place(
2219                        &mut bufs.new_value_buf,
2220                        *phys_idx,
2221                        &new_val,
2222                    )? {
2223                        patch_scratch.clear();
2224                        crate::encoding::patch_row_column(
2225                            &bufs.new_value_buf,
2226                            *phys_idx,
2227                            &new_val,
2228                            &mut patch_scratch,
2229                        )?;
2230                        std::mem::swap(&mut bufs.new_value_buf, &mut patch_scratch);
2231                    }
2232                }
2233            }
2234        }
2235
2236        if bufs.new_value_buf.len() > citadel_core::MAX_INLINE_VALUE_SIZE {
2237            return Err(SqlError::RowTooLarge {
2238                size: bufs.new_value_buf.len(),
2239                max: citadel_core::MAX_INLINE_VALUE_SIZE,
2240            });
2241        }
2242
2243        Ok(UpsertAction::Replace(bufs.new_value_buf.clone()))
2244    })
2245}
2246
2247fn upsert_needs_row(oc: &CompiledOnConflict, ts: &TableSchema) -> bool {
2248    if !ts.indices.is_empty() {
2249        return true;
2250    }
2251    match oc {
2252        CompiledOnConflict::DoNothing { .. } => false,
2253        CompiledOnConflict::DoUpdate { fast_paths, .. } => fast_paths.is_none() || ts.has_checks(),
2254    }
2255}
2256
2257fn can_fuse_do_update(ts: &TableSchema, assignments: &[(usize, Expr)]) -> bool {
2258    if !ts.indices.is_empty() {
2259        return false;
2260    }
2261    if !ts.foreign_keys.is_empty() {
2262        return false;
2263    }
2264    if ts.columns.iter().any(|c| c.generated_kind.is_some()) {
2265        return false;
2266    }
2267    let pk = ts.pk_indices();
2268    !assignments.iter().any(|(ci, _)| pk.contains(ci))
2269}
2270
2271#[allow(clippy::too_many_arguments)]
2272#[inline]
2273fn apply_do_update_fused(
2274    wtx: &mut WriteTxn<'_>,
2275    table_schema: &TableSchema,
2276    table_bytes: &[u8],
2277    key_buf: &[u8],
2278    value_buf: &[u8],
2279    proposed_row: &[Value],
2280    assignments: &[(usize, Expr)],
2281    where_clause: Option<&Expr>,
2282    col_map: &ColumnMap,
2283    fast_paths: Option<&[DoUpdateFastPath]>,
2284    capture_returning: bool,
2285) -> Result<InsertRowOutcome> {
2286    let non_pk = table_schema.non_pk_indices();
2287    let enc_pos = table_schema.encoding_positions();
2288    let phys_count = table_schema.physical_non_pk_count();
2289    let dropped = table_schema.dropped_non_pk_slots();
2290    let has_checks = table_schema.has_checks();
2291    let has_fks = !table_schema.foreign_keys.is_empty();
2292
2293    let captured: std::cell::RefCell<Option<(Vec<Value>, Vec<Value>)>> =
2294        std::cell::RefCell::new(None);
2295
2296    let outcome =
2297        wtx.table_upsert_with::<_, SqlError>(table_bytes, key_buf, value_buf, |old_bytes| {
2298            if let Some(fps) = fast_paths {
2299                if !has_checks {
2300                    let action = apply_fast_path_patch(old_bytes, fps)?;
2301                    if capture_returning {
2302                        if let UpsertAction::Replace(ref new_bytes) = action {
2303                            let old_row = decode_full_row(table_schema, key_buf, old_bytes)?;
2304                            let new_row = decode_full_row(table_schema, key_buf, new_bytes)?;
2305                            *captured.borrow_mut() = Some((old_row, new_row));
2306                        }
2307                    }
2308                    return Ok(action);
2309                }
2310            }
2311            UPSERT_SCRATCH.with(|slot| {
2312                let mut bufs = slot.borrow_mut();
2313                let UpsertBufs {
2314                    old_row,
2315                    new_row,
2316                    value_values,
2317                    new_value_buf,
2318                } = &mut *bufs;
2319
2320                old_row.clear();
2321                old_row.resize(table_schema.columns.len(), Value::Null);
2322                decode_full_row_into(table_schema, key_buf, old_bytes, old_row)?;
2323
2324                if let Some(w) = where_clause {
2325                    let ctx = EvalCtx::with_excluded(col_map, old_row, col_map, proposed_row);
2326                    let result = eval_expr(w, &ctx)?;
2327                    if result.is_null() || !is_truthy(&result) {
2328                        return Ok(UpsertAction::Skip);
2329                    }
2330                }
2331
2332                new_row.clear();
2333                new_row.extend_from_slice(old_row);
2334                for (col_idx, expr) in assignments {
2335                    let ctx = EvalCtx::with_excluded(col_map, old_row, col_map, proposed_row);
2336                    let val = eval_expr(expr, &ctx)?;
2337                    let col = &table_schema.columns[*col_idx];
2338                    new_row[*col_idx] = if val.is_null() {
2339                        Value::Null
2340                    } else {
2341                        let got = val.data_type();
2342                        val.coerce_into(col.data_type)
2343                            .ok_or_else(|| SqlError::TypeMismatch {
2344                                expected: col.data_type.to_string(),
2345                                got: got.to_string(),
2346                            })?
2347                    };
2348                }
2349
2350                for (assigned_idx, _) in assignments {
2351                    let col = &table_schema.columns[*assigned_idx];
2352                    if !col.nullable && new_row[col.position as usize].is_null() {
2353                        return Err(SqlError::NotNullViolation(col.name.clone()));
2354                    }
2355                }
2356                if has_checks {
2357                    for col in &table_schema.columns {
2358                        if let Some(ref check) = col.check_expr {
2359                            let ctx = EvalCtx::new(col_map, new_row);
2360                            let result = eval_expr(check, &ctx)?;
2361                            if !is_truthy(&result) && !result.is_null() {
2362                                let name = col.check_name.as_deref().unwrap_or(&col.name);
2363                                return Err(SqlError::CheckViolation(name.to_string()));
2364                            }
2365                        }
2366                    }
2367                    for tc in &table_schema.check_constraints {
2368                        let ctx = EvalCtx::new(col_map, new_row);
2369                        let result = eval_expr(&tc.expr, &ctx)?;
2370                        if !is_truthy(&result) && !result.is_null() {
2371                            let name = tc.name.as_deref().unwrap_or(&tc.sql);
2372                            return Err(SqlError::CheckViolation(name.to_string()));
2373                        }
2374                    }
2375                }
2376                let _ = has_fks;
2377
2378                value_values.clear();
2379                value_values.resize(phys_count, Value::Null);
2380                for &slot in dropped {
2381                    value_values[slot as usize] = Value::Null;
2382                }
2383                for (j, &i) in non_pk.iter().enumerate() {
2384                    value_values[enc_pos[j] as usize] = new_row[i].clone();
2385                }
2386                new_value_buf.clear();
2387                crate::encoding::encode_row_into(value_values, new_value_buf);
2388
2389                if new_value_buf.len() > citadel_core::MAX_INLINE_VALUE_SIZE {
2390                    return Err(SqlError::RowTooLarge {
2391                        size: new_value_buf.len(),
2392                        max: citadel_core::MAX_INLINE_VALUE_SIZE,
2393                    });
2394                }
2395
2396                if capture_returning {
2397                    *captured.borrow_mut() = Some((old_row.clone(), new_row.clone()));
2398                }
2399                Ok(UpsertAction::Replace(new_value_buf.clone()))
2400            })
2401        })?;
2402
2403    match outcome {
2404        UpsertOutcome::Inserted => Ok(InsertRowOutcome::Inserted),
2405        UpsertOutcome::Updated => {
2406            if capture_returning {
2407                let (old, new) = captured.into_inner().ok_or_else(|| {
2408                    SqlError::InvalidValue("DO UPDATE produced no captured rows".into())
2409                })?;
2410                Ok(InsertRowOutcome::Updated { old, new })
2411            } else {
2412                Ok(InsertRowOutcome::Inserted)
2413            }
2414        }
2415        UpsertOutcome::Skipped => Ok(InsertRowOutcome::Skipped),
2416    }
2417}
2418
2419fn fetch_unique_index_pk(
2420    wtx: &mut WriteTxn<'_>,
2421    table_schema: &TableSchema,
2422    index_idx: usize,
2423    row: &[Value],
2424) -> Result<Vec<u8>> {
2425    let idx = &table_schema.indices[index_idx];
2426    let idx_table = TableSchema::index_table_name(&table_schema.name, &idx.name);
2427    let indexed: Vec<Value> = idx
2428        .columns
2429        .iter()
2430        .map(|&col_idx| row[col_idx as usize].clone())
2431        .collect();
2432    let key = crate::encoding::encode_composite_key(&indexed);
2433    let value = wtx
2434        .table_get(&idx_table, &key)
2435        .map_err(SqlError::Storage)?
2436        .ok_or_else(|| {
2437            SqlError::InvalidValue("unique index missing expected collision entry".into())
2438        })?;
2439    Ok(value)
2440}
2441
2442#[allow(clippy::too_many_arguments)]
2443fn apply_do_update(
2444    wtx: &mut WriteTxn<'_>,
2445    table_schema: &TableSchema,
2446    pk_key: &[u8],
2447    proposed_row: &[Value],
2448    assignments: &[(usize, Expr)],
2449    where_clause: Option<&Expr>,
2450    col_map: &ColumnMap,
2451    capture_returning: bool,
2452) -> Result<InsertRowOutcome> {
2453    let old_value = wtx
2454        .table_get(table_schema.name.as_bytes(), pk_key)
2455        .map_err(SqlError::Storage)?
2456        .ok_or_else(|| SqlError::InvalidValue("primary row missing for DO UPDATE target".into()))?;
2457    let old_row = decode_full_row(table_schema, pk_key, &old_value)?;
2458    apply_do_update_with_old_row(
2459        wtx,
2460        table_schema,
2461        pk_key,
2462        &old_row,
2463        proposed_row,
2464        assignments,
2465        where_clause,
2466        col_map,
2467        capture_returning,
2468    )
2469}
2470
/// Applies the `DO UPDATE` action of an `INSERT ... ON CONFLICT` to an existing
/// row whose decoded values are already known.
///
/// * `old_pk_key` - encoded primary key under which `old_row` is stored.
/// * `old_row` - current values of the conflicting row, in column order.
/// * `proposed_row` - the row the INSERT tried to write. It is passed to
///   `EvalCtx::with_excluded` so the WHERE clause and SET expressions can read it.
/// * `assignments` - `(column index, expression)` pairs from the `SET` list.
/// * `where_clause` - optional `DO UPDATE ... WHERE` filter.
/// * `capture_returning` - when true, the old and new rows are returned.
///
/// Returns `Skipped` when the WHERE clause evaluates to NULL or false. Otherwise it
/// returns `Updated { old, new }` if `capture_returning` is set, and `Inserted` if not.
///
/// Errors on:
/// * assignment or generated-value type mismatches;
/// * NOT NULL, CHECK, FOREIGN KEY or UNIQUE violations;
/// * an encoded row larger than `MAX_INLINE_VALUE_SIZE`;
/// * a primary-key change that lands on an existing key (`DuplicateKey`);
/// * storage failures.
#[allow(clippy::too_many_arguments)]
fn apply_do_update_with_old_row(
    wtx: &mut WriteTxn<'_>,
    table_schema: &TableSchema,
    old_pk_key: &[u8],
    old_row: &[Value],
    proposed_row: &[Value],
    assignments: &[(usize, Expr)],
    where_clause: Option<&Expr>,
    col_map: &ColumnMap,
    capture_returning: bool,
) -> Result<InsertRowOutcome> {
    // DO UPDATE ... WHERE: a NULL or falsy result leaves the existing row untouched.
    if let Some(w) = where_clause {
        let ctx = EvalCtx::with_excluded(col_map, old_row, col_map, proposed_row);
        let result = eval_expr(w, &ctx)?;
        if result.is_null() || !is_truthy(&result) {
            return Ok(InsertRowOutcome::Skipped);
        }
    }

    // Every SET expression is evaluated against the original `old_row`, not against
    // values assigned earlier in the same SET list.
    let mut new_row = old_row.to_vec();
    for (col_idx, expr) in assignments {
        let ctx = EvalCtx::with_excluded(col_map, old_row, col_map, proposed_row);
        let val = eval_expr(expr, &ctx)?;
        let col = &table_schema.columns[*col_idx];
        // NULL skips coercion; NOT NULL is enforced further down for assigned columns.
        new_row[*col_idx] = if val.is_null() {
            Value::Null
        } else {
            let got = val.data_type();
            val.coerce_into(col.data_type)
                .ok_or_else(|| SqlError::TypeMismatch {
                    expected: col.data_type.to_string(),
                    got: got.to_string(),
                })?
        };
    }

    // Recompute STORED generated columns from the updated row, enforcing their own
    // NOT NULL and type. VIRTUAL generated columns are not recomputed here.
    for col in &table_schema.columns {
        if matches!(
            col.generated_kind,
            Some(crate::parser::GeneratedKind::Stored)
        ) {
            let val = eval_expr(
                col.generated_expr.as_ref().unwrap(),
                &EvalCtx::new(col_map, &new_row),
            )?;
            let pos = col.position as usize;
            new_row[pos] = if val.is_null() {
                if !col.nullable {
                    return Err(SqlError::NotNullViolation(col.name.clone()));
                }
                Value::Null
            } else {
                let got = val.data_type();
                val.coerce_into(col.data_type)
                    .ok_or_else(|| SqlError::TypeMismatch {
                        expected: col.data_type.to_string(),
                        got: got.to_string(),
                    })?
            };
        }
    }

    // The row only has to move to a new key if a PK column was assigned AND its
    // value actually differs from the old one.
    let pk_indices = table_schema.pk_indices();
    let assigned_pk = assignments.iter().any(|(ci, _)| pk_indices.contains(ci));
    let pk_changed = assigned_pk && pk_indices.iter().any(|&i| old_row[i] != new_row[i]);

    // NOT NULL is only re-checked for columns the SET list touched.
    for (assigned_idx, _) in assignments {
        let col = &table_schema.columns[*assigned_idx];
        if !col.nullable && new_row[col.position as usize].is_null() {
            return Err(SqlError::NotNullViolation(col.name.clone()));
        }
    }
    // CHECK constraints (column-level, then table-level): a NULL result passes.
    if table_schema.has_checks() {
        for col in &table_schema.columns {
            if let Some(ref check) = col.check_expr {
                let ctx = EvalCtx::new(col_map, &new_row);
                let result = eval_expr(check, &ctx)?;
                if !is_truthy(&result) && !result.is_null() {
                    let name = col.check_name.as_deref().unwrap_or(&col.name);
                    return Err(SqlError::CheckViolation(name.to_string()));
                }
            }
        }
        for tc in &table_schema.check_constraints {
            let ctx = EvalCtx::new(col_map, &new_row);
            let result = eval_expr(&tc.expr, &ctx)?;
            if !is_truthy(&result) && !result.is_null() {
                let name = tc.name.as_deref().unwrap_or(&tc.sql);
                return Err(SqlError::CheckViolation(name.to_string()));
            }
        }
    }
    // Outgoing foreign keys are re-validated only when one of their columns changed.
    // A NULL in any FK column skips the check.
    for fk in &table_schema.foreign_keys {
        let changed = fk
            .columns
            .iter()
            .any(|&ci| old_row[ci as usize] != new_row[ci as usize]);
        if !changed {
            continue;
        }
        let any_null = fk.columns.iter().any(|&ci| new_row[ci as usize].is_null());
        if any_null {
            continue;
        }
        // NOTE(review): the referenced row is looked up with the encoded FK values as
        // its table key, which presumably assumes FKs target the parent's primary key.
        let fk_vals: Vec<Value> = fk
            .columns
            .iter()
            .map(|&ci| new_row[ci as usize].clone())
            .collect();
        let fk_key = crate::encoding::encode_composite_key(&fk_vals);
        let found = wtx
            .table_get(fk.foreign_table.as_bytes(), &fk_key)
            .map_err(SqlError::Storage)?;
        if found.is_none() {
            let name = fk.name.as_deref().unwrap_or(&fk.foreign_table);
            return Err(SqlError::ForeignKeyViolation(name.to_string()));
        }
    }

    // PK value vectors are only needed to build index keys or to move the row.
    let has_indices = !table_schema.indices.is_empty();
    let old_pk_values: Vec<Value> = if has_indices || pk_changed {
        pk_indices.iter().map(|&i| old_row[i].clone()).collect()
    } else {
        Vec::new()
    };
    let new_pk_values: Vec<Value> = if has_indices || pk_changed {
        pk_indices.iter().map(|&i| new_row[i].clone()).collect()
    } else {
        Vec::new()
    };

    // Build the physical non-PK value tuple. Dropped-column slots and VIRTUAL
    // generated columns are encoded as NULL. `vec![Value::Null; ..]` already
    // leaves dropped slots NULL; the loop below only re-asserts that.
    let non_pk = table_schema.non_pk_indices();
    let enc_pos = table_schema.encoding_positions();
    let phys_count = table_schema.physical_non_pk_count();
    let dropped = table_schema.dropped_non_pk_slots();
    let mut value_values: Vec<Value> = vec![Value::Null; phys_count];
    for &slot in dropped {
        value_values[slot as usize] = Value::Null;
    }
    for (j, &i) in non_pk.iter().enumerate() {
        let col = &table_schema.columns[i];
        value_values[enc_pos[j] as usize] = if matches!(
            col.generated_kind,
            Some(crate::parser::GeneratedKind::Virtual)
        ) {
            Value::Null
        } else {
            new_row[i].clone()
        };
    }
    let mut new_value_buf = Vec::with_capacity(256);
    crate::encoding::encode_row_into(&value_values, &mut new_value_buf);

    if new_value_buf.len() > citadel_core::MAX_INLINE_VALUE_SIZE {
        return Err(SqlError::RowTooLarge {
            size: new_value_buf.len(),
            max: citadel_core::MAX_INLINE_VALUE_SIZE,
        });
    }

    if pk_changed {
        // Move the row: insert under the new key first (an occupied key is a
        // DuplicateKey error), then remove the old key.
        let new_pk_key = crate::encoding::encode_composite_key(&new_pk_values);
        let inserted = wtx
            .table_insert(table_schema.name.as_bytes(), &new_pk_key, &new_value_buf)
            .map_err(SqlError::Storage)?;
        if !inserted {
            return Err(SqlError::DuplicateKey);
        }
        wtx.table_delete(table_schema.name.as_bytes(), old_pk_key)
            .map_err(SqlError::Storage)?;
        // Every index entry is rewritten because the PK participates in index keys/values.
        // NOTE(review): unlike the in-place branch below, this does not consult
        // `partial_idx_update_actions`, so a partial index may receive an entry for
        // a row outside its predicate. Confirm whether that is intended.
        for idx in &table_schema.indices {
            let idx_table = TableSchema::index_table_name(&table_schema.name, &idx.name);
            let old_idx_key = encode_index_key(idx, old_row, &old_pk_values);
            wtx.table_delete(&idx_table, &old_idx_key)
                .map_err(SqlError::Storage)?;
            let new_idx_key = encode_index_key(idx, &new_row, &new_pk_values);
            let new_idx_val = encode_index_value(idx, &new_row, &new_pk_values);
            let is_new = wtx
                .table_insert(&idx_table, &new_idx_key, &new_idx_val)
                .map_err(SqlError::Storage)?;
            // A pre-existing unique key is a violation unless one of the indexed
            // values is NULL.
            if idx.unique && !is_new {
                let any_null = idx.columns.iter().any(|&c| new_row[c as usize].is_null());
                if !any_null {
                    return Err(SqlError::UniqueViolation(idx.name.clone()));
                }
            }
        }
    } else {
        // Same key: overwrite the stored value in place.
        wtx.table_update_sorted(
            table_schema.name.as_bytes(),
            &[(old_pk_key, new_value_buf.as_slice())],
        )
        .map_err(SqlError::Storage)?;
        // A column map is only built when some index is partial (for predicate evaluation).
        let col_map_partial =
            any_partial_index(table_schema).then(|| ColumnMap::new(&table_schema.columns));
        for idx in &table_schema.indices {
            let cols_changed = index_columns_changed(idx, old_row, &new_row);
            // `del`/`ins` say whether the old entry must be removed and a new one added.
            // NOTE(review): the literal `false` presumably means "primary key unchanged";
            // confirm against `partial_idx_update_actions`.
            let (del, ins) = partial_idx_update_actions(
                idx,
                old_row,
                &new_row,
                cols_changed,
                false,
                col_map_partial.as_ref(),
            );
            let idx_table = TableSchema::index_table_name(&table_schema.name, &idx.name);
            if del {
                let old_idx_key = encode_index_key(idx, old_row, &old_pk_values);
                wtx.table_delete(&idx_table, &old_idx_key)
                    .map_err(SqlError::Storage)?;
            }
            if ins {
                let new_idx_key = encode_index_key(idx, &new_row, &new_pk_values);
                let new_idx_val = encode_index_value(idx, &new_row, &new_pk_values);
                let is_new = wtx
                    .table_insert(&idx_table, &new_idx_key, &new_idx_val)
                    .map_err(SqlError::Storage)?;
                if idx.unique && !is_new {
                    let any_null = idx.columns.iter().any(|&c| new_row[c as usize].is_null());
                    if !any_null {
                        return Err(SqlError::UniqueViolation(idx.name.clone()));
                    }
                }
            }
        }
    }

    // NOTE(review): without RETURNING the update is reported as `Inserted`,
    // presumably so the caller just counts one affected row. Confirm at the call site.
    if capture_returning {
        Ok(InsertRowOutcome::Updated {
            old: old_row.to_vec(),
            new: new_row,
        })
    } else {
        Ok(InsertRowOutcome::Inserted)
    }
}
2708
2709fn detect_fast_paths(
2710    ts: &TableSchema,
2711    assignments: &[(usize, Expr)],
2712) -> Option<Vec<DoUpdateFastPath>> {
2713    let non_pk = ts.non_pk_indices();
2714    let enc_pos = ts.encoding_positions();
2715    let mut out = Vec::with_capacity(assignments.len());
2716    for (col_idx, expr) in assignments {
2717        let col = &ts.columns[*col_idx];
2718        if col.data_type != DataType::Integer {
2719            return None;
2720        }
2721        let nonpk_order = non_pk.iter().position(|&i| i == *col_idx)?;
2722        let phys_idx = enc_pos[nonpk_order] as usize;
2723
2724        if let Expr::BinaryOp { left, op, right } = expr {
2725            if !matches!(op, BinOp::Add | BinOp::Sub) {
2726                return None;
2727            }
2728            let reads_target =
2729                matches!(left.as_ref(), Expr::Column(n) if n.eq_ignore_ascii_case(&col.name));
2730            if !reads_target {
2731                return None;
2732            }
2733            if let Expr::Literal(Value::Integer(n)) = right.as_ref() {
2734                let delta = if matches!(op, BinOp::Sub) { -n } else { *n };
2735                let _ = col_idx;
2736                out.push(DoUpdateFastPath::IntAddConst { phys_idx, delta });
2737                continue;
2738            }
2739            return None;
2740        }
2741        return None;
2742    }
2743    Some(out)
2744}
2745
2746fn compile_on_conflict(oc: &OnConflictClause, ts: &TableSchema) -> Result<CompiledOnConflict> {
2747    let target = oc
2748        .target
2749        .as_ref()
2750        .map(|t| resolve_conflict_target(t, ts))
2751        .transpose()?;
2752    match &oc.action {
2753        OnConflictAction::DoNothing => Ok(CompiledOnConflict::DoNothing { target }),
2754        OnConflictAction::DoUpdate {
2755            assignments,
2756            where_clause,
2757        } => {
2758            let target = target.ok_or_else(|| {
2759                SqlError::Plan("ON CONFLICT without target requires DO NOTHING".into())
2760            })?;
2761            let compiled_assignments: Vec<(usize, Expr)> = assignments
2762                .iter()
2763                .map(|(name, expr)| {
2764                    let col_idx = ts
2765                        .column_index(name)
2766                        .ok_or_else(|| SqlError::ColumnNotFound(name.clone()))?;
2767                    Ok((col_idx, expr.clone()))
2768                })
2769                .collect::<Result<_>>()?;
2770            let fast_paths = if where_clause.is_none() {
2771                detect_fast_paths(ts, &compiled_assignments)
2772            } else {
2773                None
2774            };
2775            Ok(CompiledOnConflict::DoUpdate {
2776                target,
2777                assignments: compiled_assignments,
2778                where_clause: where_clause.clone(),
2779                fast_paths,
2780            })
2781        }
2782    }
2783}
2784
2785/// Caller MUST check `cache.is_trivial_fast` first.
2786fn exec_insert_trivial_fast(
2787    wtx: &mut WriteTxn<'_>,
2788    table_lower: &str,
2789    cache: &InsertCache,
2790    bufs: &mut InsertBufs,
2791    params: &[Value],
2792) -> Result<ExecutionResult> {
2793    let prog = cache
2794        .trivial_fast_program
2795        .as_ref()
2796        .expect("trivial fast: program");
2797
2798    for &p in &prog.not_null_param_indices {
2799        if params[p as usize].is_null() {
2800            return Err(SqlError::NotNullViolation(format!("param@{p}")));
2801        }
2802    }
2803
2804    match &params[prog.pk_param as usize] {
2805        Value::Integer(v) => crate::encoding::encode_int_key_into(*v, &mut bufs.key_buf),
2806        _ => return Err(SqlError::InvalidValue("non-integer PK in fast path".into())),
2807    }
2808
2809    bufs.value_buf.clear();
2810    bufs.value_buf.extend_from_slice(&prog.template);
2811
2812    for op in &prog.ops {
2813        match op {
2814            WriteOp::ParamI64 { param_idx, off } => match &params[*param_idx as usize] {
2815                Value::Integer(v) => {
2816                    let off = *off as usize;
2817                    bufs.value_buf[off..off + 8].copy_from_slice(&v.to_le_bytes());
2818                }
2819                other => {
2820                    return Err(SqlError::TypeMismatch {
2821                        expected: "Integer".into(),
2822                        got: other.data_type().to_string(),
2823                    });
2824                }
2825            },
2826            WriteOp::LiteralI64 { value, off } => {
2827                let off = *off as usize;
2828                bufs.value_buf[off..off + 8].copy_from_slice(&value.to_le_bytes());
2829            }
2830            WriteOp::GenAddParamsI64 {
2831                a_param,
2832                b_param,
2833                off,
2834                bitmap_byte_off,
2835                bitmap_bit_mask,
2836            } => match (&params[*a_param as usize], &params[*b_param as usize]) {
2837                (Value::Integer(a), Value::Integer(b)) => {
2838                    let off = *off as usize;
2839                    bufs.value_buf[off..off + 8].copy_from_slice(&a.wrapping_add(*b).to_le_bytes());
2840                }
2841                _ => {
2842                    bufs.value_buf[*bitmap_byte_off as usize] |= *bitmap_bit_mask;
2843                }
2844            },
2845            WriteOp::GenMulAddParamI64 {
2846                param_idx,
2847                mul,
2848                add,
2849                off,
2850                bitmap_byte_off,
2851                bitmap_bit_mask,
2852            } => match &params[*param_idx as usize] {
2853                Value::Integer(v) => {
2854                    let r = v.wrapping_mul(*mul).wrapping_add(*add);
2855                    let off = *off as usize;
2856                    bufs.value_buf[off..off + 8].copy_from_slice(&r.to_le_bytes());
2857                }
2858                _ => {
2859                    bufs.value_buf[*bitmap_byte_off as usize] |= *bitmap_bit_mask;
2860                }
2861            },
2862        }
2863    }
2864
2865    let is_new = wtx
2866        .table_insert(table_lower.as_bytes(), &bufs.key_buf, &bufs.value_buf)
2867        .map_err(SqlError::Storage)?;
2868    if !is_new {
2869        return Err(SqlError::DuplicateKey);
2870    }
2871    Ok(ExecutionResult::RowsAffected(1))
2872}
2873
2874fn build_bind_plan(
2875    stmt: &InsertStmt,
2876    col_indices: &[usize],
2877    col_data_types: &[DataType],
2878) -> Option<Vec<BindAction>> {
2879    let rows = match &stmt.source {
2880        InsertSource::Values(rows) => rows,
2881        _ => return None,
2882    };
2883    if rows.len() != 1 {
2884        return None;
2885    }
2886    let value_row = &rows[0];
2887    if value_row.len() != col_indices.len() {
2888        return None;
2889    }
2890    let mut plan = Vec::with_capacity(value_row.len());
2891    for (i, expr) in value_row.iter().enumerate() {
2892        let col_idx = col_indices[i];
2893        let target = col_data_types[col_idx];
2894        match expr {
2895            Expr::Parameter(n) => {
2896                if *n == 0 {
2897                    return None;
2898                }
2899                plan.push(BindAction::Param {
2900                    param_idx: n - 1,
2901                    col_idx,
2902                    target,
2903                });
2904            }
2905            Expr::Literal(v) => plan.push(BindAction::Literal {
2906                value: v.clone(),
2907                col_idx,
2908            }),
2909            _ => return None,
2910        }
2911    }
2912    Some(plan)
2913}
2914
2915impl CompiledInsert {
2916    pub fn try_compile(schema: &SchemaManager, stmt: &InsertStmt) -> Option<Self> {
2917        let lower = stmt.table.to_ascii_lowercase();
2918        let cached = if let Some(ts) = schema.get(&lower) {
2919            let insert_columns: Vec<&str> = if stmt.columns.is_empty() {
2920                ts.columns.iter().map(|c| c.name.as_str()).collect()
2921            } else {
2922                stmt.columns.iter().map(|s| s.as_str()).collect()
2923            };
2924            let mut col_indices = Vec::with_capacity(insert_columns.len());
2925            for name in &insert_columns {
2926                col_indices.push(ts.column_index(name)?);
2927            }
2928            if col_indices
2929                .iter()
2930                .any(|&ci| ts.columns[ci].generated_kind.is_some())
2931            {
2932                return None;
2933            }
2934            let on_conflict = stmt
2935                .on_conflict
2936                .as_ref()
2937                .map(|oc| compile_on_conflict(oc, ts))
2938                .transpose()
2939                .ok()
2940                .flatten()
2941                .map(Arc::new);
2942            let generated_col_positions: Vec<usize> = ts
2943                .columns
2944                .iter()
2945                .enumerate()
2946                .filter_map(|(i, c)| {
2947                    matches!(c.generated_kind, Some(crate::parser::GeneratedKind::Stored))
2948                        .then_some(i)
2949                })
2950                .collect();
2951            let generated_fast_evals: Vec<FastGenEval> = generated_col_positions
2952                .iter()
2953                .map(|&pos| {
2954                    detect_fast_gen_eval(ts.columns[pos].generated_expr.as_ref().unwrap(), ts)
2955                })
2956                .collect();
2957            let row_col_map = if on_conflict.is_some() || !generated_col_positions.is_empty() {
2958                Some(ColumnMap::new(&ts.columns))
2959            } else {
2960                None
2961            };
2962            let pk_indices: Vec<usize> = ts.pk_indices().to_vec();
2963            let non_pk_indices: Vec<usize> = ts.non_pk_indices().to_vec();
2964            let encoding_positions: Vec<u16> = ts.encoding_positions().to_vec();
2965            let dropped_non_pk_slots: Vec<u16> = ts.dropped_non_pk_slots().to_vec();
2966            let phys_count = ts.physical_non_pk_count();
2967            let col_data_types: Vec<DataType> = ts.columns.iter().map(|c| c.data_type).collect();
2968            let single_int_pk =
2969                pk_indices.len() == 1 && ts.columns[pk_indices[0]].data_type == DataType::Integer;
2970            let not_null_indices: Vec<u16> = ts
2971                .columns
2972                .iter()
2973                .filter(|c| !c.nullable)
2974                .map(|c| c.position)
2975                .collect();
2976            let bind_plan = build_bind_plan(stmt, &col_indices, &col_data_types);
2977            let any_defaults_flag = ts.columns.iter().any(|c| c.default_expr.is_some());
2978            let row_fully_overwritten = if any_defaults_flag {
2979                false
2980            } else {
2981                let mut covered: rustc_hash::FxHashSet<usize> =
2982                    col_indices.iter().copied().collect();
2983                covered.extend(generated_col_positions.iter().copied());
2984                for (j, &i) in non_pk_indices.iter().enumerate() {
2985                    let _ = j;
2986                    if matches!(
2987                        ts.columns[i].generated_kind,
2988                        Some(crate::parser::GeneratedKind::Virtual)
2989                    ) {
2990                        covered.insert(i);
2991                    }
2992                }
2993                bind_plan.is_some() && covered.len() == ts.columns.len()
2994            };
2995            let has_fks = !ts.foreign_keys.is_empty();
2996            let has_indices = !ts.indices.is_empty();
2997            let mut non_virtual_pairs: Vec<(usize, usize)> = Vec::new();
2998            let mut null_value_slots: Vec<usize> =
2999                dropped_non_pk_slots.iter().map(|&s| s as usize).collect();
3000            for (j, &i) in non_pk_indices.iter().enumerate() {
3001                let slot = encoding_positions[j] as usize;
3002                if matches!(
3003                    ts.columns[i].generated_kind,
3004                    Some(crate::parser::GeneratedKind::Virtual)
3005                ) {
3006                    null_value_slots.push(slot);
3007                } else {
3008                    non_virtual_pairs.push((i, slot));
3009                }
3010            }
3011            let row_encoder = {
3012                let all_int_or_null = non_pk_indices.iter().enumerate().all(|(j, &i)| {
3013                    let col = &ts.columns[i];
3014                    if matches!(
3015                        col.generated_kind,
3016                        Some(crate::parser::GeneratedKind::Virtual)
3017                    ) {
3018                        true
3019                    } else {
3020                        col.data_type == DataType::Integer && encoding_positions[j] != u16::MAX
3021                    }
3022                });
3023                if all_int_or_null {
3024                    let mut null_slots: Vec<usize> =
3025                        dropped_non_pk_slots.iter().map(|&s| s as usize).collect();
3026                    for (j, &i) in non_pk_indices.iter().enumerate() {
3027                        if matches!(
3028                            ts.columns[i].generated_kind,
3029                            Some(crate::parser::GeneratedKind::Virtual)
3030                        ) {
3031                            null_slots.push(encoding_positions[j] as usize);
3032                        }
3033                    }
3034                    Some(crate::encoding::build_int_row_template(
3035                        phys_count,
3036                        &null_slots,
3037                    ))
3038                } else {
3039                    None
3040                }
3041            };
3042            let is_trivial_fast_eligible = !insert_has_subquery(stmt)
3043                && !ts.columns.iter().any(|c| c.default_expr.is_some())
3044                && !ts.has_checks()
3045                && !has_fks
3046                && !has_indices
3047                && stmt.on_conflict.is_none()
3048                && stmt.returning.is_none()
3049                && bind_plan.is_some()
3050                && row_encoder.is_some()
3051                && row_fully_overwritten
3052                && single_int_pk
3053                && generated_fast_evals
3054                    .iter()
3055                    .all(|fe| !matches!(fe, FastGenEval::None));
3056            let trivial_fast_program = if is_trivial_fast_eligible {
3057                build_trivial_fast_program(
3058                    bind_plan.as_ref().unwrap(),
3059                    row_encoder.as_ref().unwrap(),
3060                    &non_virtual_pairs,
3061                    &generated_col_positions,
3062                    &generated_fast_evals,
3063                    &pk_indices,
3064                    &ts.columns,
3065                )
3066            } else {
3067                None
3068            };
3069            let is_trivial_fast = trivial_fast_program.is_some();
3070            Some(InsertCache {
3071                col_indices,
3072                has_subquery: insert_has_subquery(stmt),
3073                any_defaults: ts.columns.iter().any(|c| c.default_expr.is_some()),
3074                has_checks: ts.has_checks(),
3075                on_conflict,
3076                row_col_map,
3077                generated_col_positions,
3078                generated_fast_evals,
3079                pk_indices,
3080                non_pk_indices,
3081                encoding_positions,
3082                dropped_non_pk_slots,
3083                phys_count,
3084                single_int_pk,
3085                not_null_indices,
3086                bind_plan,
3087                row_fully_overwritten,
3088                row_encoder,
3089                is_trivial_fast,
3090                trivial_fast_program,
3091            })
3092        } else if schema.get_view(&lower).is_some() {
3093            None
3094        } else {
3095            return None;
3096        };
3097        Some(Self {
3098            table_lower: lower,
3099            cached,
3100        })
3101    }
3102}
3103
3104impl CompiledPlan for CompiledInsert {
3105    fn execute(
3106        &self,
3107        db: &Database,
3108        schema: &SchemaManager,
3109        stmt: &Statement,
3110        params: &[Value],
3111        wtx: Option<&mut WriteTxn<'_>>,
3112    ) -> Result<ExecutionResult> {
3113        let ins = match stmt {
3114            Statement::Insert(i) => i,
3115            _ => {
3116                return Err(SqlError::Unsupported(
3117                    "CompiledInsert received non-INSERT statement".into(),
3118                ))
3119            }
3120        };
3121        match wtx {
3122            None => exec_insert(db, schema, ins, params),
3123            Some(outer) => match self.cached.as_ref() {
3124                Some(c) if c.is_trivial_fast => with_insert_scratch(|bufs| {
3125                    exec_insert_trivial_fast(outer, &self.table_lower, c, bufs, params)
3126                }),
3127                Some(c) => exec_insert_in_txn_cached(outer, schema, ins, params, c),
3128                None => exec_insert_in_txn(outer, schema, ins, params),
3129            },
3130        }
3131    }
3132
3133    fn uses_scoped_params(&self) -> bool {
3134        match self.cached.as_ref() {
3135            Some(c) => !c.is_trivial_fast,
3136            None => true,
3137        }
3138    }
3139}
3140
/// Compiled plan for a `DELETE` statement.
///
/// Compilation only checks that the target table exists. Execution delegates to
/// `super::write::exec_delete` / `exec_delete_in_txn` with the statement itself.
pub struct CompiledDelete {
    /// Lowercased name of the target table (verified to exist at compile time).
    table_lower: String,
}
3144
3145impl CompiledDelete {
3146    pub fn try_compile(schema: &SchemaManager, stmt: &DeleteStmt) -> Option<Self> {
3147        let lower = stmt.table.to_ascii_lowercase();
3148        schema.get(&lower)?;
3149        Some(Self { table_lower: lower })
3150    }
3151}
3152
3153impl CompiledPlan for CompiledDelete {
3154    fn execute(
3155        &self,
3156        db: &Database,
3157        schema: &SchemaManager,
3158        stmt: &Statement,
3159        _params: &[Value],
3160        wtx: Option<&mut WriteTxn<'_>>,
3161    ) -> Result<ExecutionResult> {
3162        let del = match stmt {
3163            Statement::Delete(d) => d,
3164            _ => {
3165                return Err(SqlError::Unsupported(
3166                    "CompiledDelete received non-DELETE statement".into(),
3167                ))
3168            }
3169        };
3170        let _ = &self.table_lower;
3171        match wtx {
3172            None => super::write::exec_delete(db, schema, del),
3173            Some(outer) => super::write::exec_delete_in_txn(outer, schema, del),
3174        }
3175    }
3176}