1use std::collections::{BTreeMap, HashMap};
4
5use citadel::Database;
6
7use crate::encoding::{
8 decode_column_raw, decode_columns, decode_columns_into, decode_composite_key, decode_key_value,
9 decode_pk_integer, decode_pk_into, decode_row_into, encode_composite_key,
10 encode_composite_key_into, encode_row, encode_row_into, row_non_pk_count, RawColumn,
11};
12use crate::error::{Result, SqlError};
13use crate::eval::{eval_expr, is_truthy, referenced_columns, ColumnMap};
14use crate::parser::*;
15use crate::planner::{self, ScanPlan};
16use crate::schema::SchemaManager;
17use crate::types::*;
18
19fn encode_index_key(idx: &IndexDef, row: &[Value], pk_values: &[Value]) -> Vec<u8> {
22 let indexed_values: Vec<Value> = idx
23 .columns
24 .iter()
25 .map(|&col_idx| row[col_idx as usize].clone())
26 .collect();
27
28 if idx.unique {
29 let any_null = indexed_values.iter().any(|v| v.is_null());
30 if !any_null {
31 return encode_composite_key(&indexed_values);
32 }
33 }
34
35 let mut all_values = indexed_values;
36 all_values.extend_from_slice(pk_values);
37 encode_composite_key(&all_values)
38}
39
40fn encode_index_value(idx: &IndexDef, row: &[Value], pk_values: &[Value]) -> Vec<u8> {
41 if idx.unique {
42 let indexed_values: Vec<Value> = idx
43 .columns
44 .iter()
45 .map(|&col_idx| row[col_idx as usize].clone())
46 .collect();
47 let any_null = indexed_values.iter().any(|v| v.is_null());
48 if !any_null {
49 return encode_composite_key(pk_values);
50 }
51 }
52 vec![]
53}
54
55fn insert_index_entries(
56 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
57 table_schema: &TableSchema,
58 row: &[Value],
59 pk_values: &[Value],
60) -> Result<()> {
61 for idx in &table_schema.indices {
62 let idx_table = TableSchema::index_table_name(&table_schema.name, &idx.name);
63 let key = encode_index_key(idx, row, pk_values);
64 let value = encode_index_value(idx, row, pk_values);
65
66 let is_new = wtx
67 .table_insert(&idx_table, &key, &value)
68 .map_err(SqlError::Storage)?;
69
70 if idx.unique && !is_new {
71 let indexed_values: Vec<Value> = idx
72 .columns
73 .iter()
74 .map(|&col_idx| row[col_idx as usize].clone())
75 .collect();
76 let any_null = indexed_values.iter().any(|v| v.is_null());
77 if !any_null {
78 return Err(SqlError::UniqueViolation(idx.name.clone()));
79 }
80 }
81 }
82 Ok(())
83}
84
85fn delete_index_entries(
86 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
87 table_schema: &TableSchema,
88 row: &[Value],
89 pk_values: &[Value],
90) -> Result<()> {
91 for idx in &table_schema.indices {
92 let idx_table = TableSchema::index_table_name(&table_schema.name, &idx.name);
93 let key = encode_index_key(idx, row, pk_values);
94 wtx.table_delete(&idx_table, &key)
95 .map_err(SqlError::Storage)?;
96 }
97 Ok(())
98}
99
100fn index_columns_changed(idx: &IndexDef, old_row: &[Value], new_row: &[Value]) -> bool {
101 idx.columns
102 .iter()
103 .any(|&col_idx| old_row[col_idx as usize] != new_row[col_idx as usize])
104}
105
106pub fn execute(
108 db: &Database,
109 schema: &mut SchemaManager,
110 stmt: &Statement,
111 params: &[Value],
112) -> Result<ExecutionResult> {
113 match stmt {
114 Statement::CreateTable(ct) => exec_create_table(db, schema, ct),
115 Statement::DropTable(dt) => exec_drop_table(db, schema, dt),
116 Statement::CreateIndex(ci) => exec_create_index(db, schema, ci),
117 Statement::DropIndex(di) => exec_drop_index(db, schema, di),
118 Statement::AlterTable(at) => exec_alter_table(db, schema, at),
119 Statement::Insert(ins) => exec_insert(db, schema, ins, params),
120 Statement::Select(sq) => exec_select_query(db, schema, sq),
121 Statement::Update(upd) => exec_update(db, schema, upd),
122 Statement::Delete(del) => exec_delete(db, schema, del),
123 Statement::Explain(inner) => explain(schema, inner),
124 Statement::Begin | Statement::Commit | Statement::Rollback => Err(SqlError::Unsupported(
125 "transaction control in auto-commit mode".into(),
126 )),
127 }
128}
129
130pub fn execute_in_txn(
132 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
133 schema: &mut SchemaManager,
134 stmt: &Statement,
135 params: &[Value],
136) -> Result<ExecutionResult> {
137 match stmt {
138 Statement::CreateTable(ct) => exec_create_table_in_txn(wtx, schema, ct),
139 Statement::DropTable(dt) => exec_drop_table_in_txn(wtx, schema, dt),
140 Statement::CreateIndex(ci) => exec_create_index_in_txn(wtx, schema, ci),
141 Statement::DropIndex(di) => exec_drop_index_in_txn(wtx, schema, di),
142 Statement::AlterTable(at) => exec_alter_table_in_txn(wtx, schema, at),
143 Statement::Insert(ins) => {
144 let mut bufs = InsertBufs::new();
145 exec_insert_in_txn(wtx, schema, ins, params, &mut bufs)
146 }
147 Statement::Select(sq) => exec_select_query_in_txn(wtx, schema, sq),
148 Statement::Update(upd) => exec_update_in_txn(wtx, schema, upd),
149 Statement::Delete(del) => exec_delete_in_txn(wtx, schema, del),
150 Statement::Explain(inner) => explain(schema, inner),
151 Statement::Begin | Statement::Commit | Statement::Rollback => {
152 Err(SqlError::Unsupported("nested transaction control".into()))
153 }
154 }
155}
156
157pub fn explain(schema: &SchemaManager, stmt: &Statement) -> Result<ExecutionResult> {
160 let lines = match stmt {
161 Statement::Select(sq) => {
162 let mut lines = Vec::new();
163 let cte_names: Vec<&str> = sq.ctes.iter().map(|c| c.name.as_str()).collect();
164 for cte in &sq.ctes {
165 lines.push(format!("WITH {} AS", cte.name));
166 lines.extend(
167 explain_query_body_cte(schema, &cte.body, &cte_names)?
168 .into_iter()
169 .map(|l| format!(" {l}")),
170 );
171 }
172 lines.extend(explain_query_body_cte(schema, &sq.body, &cte_names)?);
173 lines
174 }
175 Statement::Insert(ins) => match &ins.source {
176 InsertSource::Values(rows) => {
177 vec![format!(
178 "INSERT INTO {} ({} rows)",
179 ins.table.to_ascii_lowercase(),
180 rows.len()
181 )]
182 }
183 InsertSource::Select(sq) => {
184 let mut lines = vec![format!(
185 "INSERT INTO {} ... SELECT",
186 ins.table.to_ascii_lowercase()
187 )];
188 let cte_names: Vec<&str> = sq.ctes.iter().map(|c| c.name.as_str()).collect();
189 for cte in &sq.ctes {
190 lines.push(format!(" WITH {} AS", cte.name));
191 lines.extend(
192 explain_query_body_cte(schema, &cte.body, &cte_names)?
193 .into_iter()
194 .map(|l| format!(" {l}")),
195 );
196 }
197 lines.extend(explain_query_body_cte(schema, &sq.body, &cte_names)?);
198 lines
199 }
200 },
201 Statement::Update(upd) => explain_dml(schema, &upd.table, &upd.where_clause, "UPDATE")?,
202 Statement::Delete(del) => {
203 explain_dml(schema, &del.table, &del.where_clause, "DELETE FROM")?
204 }
205 Statement::AlterTable(at) => {
206 let desc = match &at.op {
207 AlterTableOp::AddColumn { column, .. } => {
208 format!("ALTER TABLE {} ADD COLUMN {}", at.table, column.name)
209 }
210 AlterTableOp::DropColumn { name, .. } => {
211 format!("ALTER TABLE {} DROP COLUMN {}", at.table, name)
212 }
213 AlterTableOp::RenameColumn {
214 old_name, new_name, ..
215 } => {
216 format!(
217 "ALTER TABLE {} RENAME COLUMN {} TO {}",
218 at.table, old_name, new_name
219 )
220 }
221 AlterTableOp::RenameTable { new_name } => {
222 format!("ALTER TABLE {} RENAME TO {}", at.table, new_name)
223 }
224 };
225 vec![desc]
226 }
227 Statement::Explain(_) => {
228 return Err(SqlError::Unsupported("EXPLAIN EXPLAIN".into()));
229 }
230 _ => {
231 return Err(SqlError::Unsupported(
232 "EXPLAIN for this statement type".into(),
233 ));
234 }
235 };
236
237 let rows = lines
238 .into_iter()
239 .map(|line| vec![Value::Text(line.into())])
240 .collect();
241 Ok(ExecutionResult::Query(QueryResult {
242 columns: vec!["plan".into()],
243 rows,
244 }))
245}
246
247fn explain_dml(
248 schema: &SchemaManager,
249 table: &str,
250 where_clause: &Option<Expr>,
251 verb: &str,
252) -> Result<Vec<String>> {
253 let lower = table.to_ascii_lowercase();
254 let table_schema = schema
255 .get(&lower)
256 .ok_or_else(|| SqlError::TableNotFound(table.to_string()))?;
257 let plan = planner::plan_select(table_schema, where_clause);
258 let scan_line = format_scan_line(&lower, &None, &plan, table_schema);
259 Ok(vec![format!("{verb} {}", scan_line)])
260}
261
262fn explain_query_body_cte(
263 schema: &SchemaManager,
264 body: &QueryBody,
265 cte_names: &[&str],
266) -> Result<Vec<String>> {
267 match body {
268 QueryBody::Select(sel) => explain_select_cte(schema, sel, cte_names),
269 QueryBody::Compound(comp) => {
270 let op_name = match (&comp.op, comp.all) {
271 (SetOp::Union, true) => "UNION ALL",
272 (SetOp::Union, false) => "UNION",
273 (SetOp::Intersect, true) => "INTERSECT ALL",
274 (SetOp::Intersect, false) => "INTERSECT",
275 (SetOp::Except, true) => "EXCEPT ALL",
276 (SetOp::Except, false) => "EXCEPT",
277 };
278 let mut lines = vec![op_name.to_string()];
279 let left_lines = explain_query_body_cte(schema, &comp.left, cte_names)?;
280 for l in left_lines {
281 lines.push(format!(" {l}"));
282 }
283 let right_lines = explain_query_body_cte(schema, &comp.right, cte_names)?;
284 for l in right_lines {
285 lines.push(format!(" {l}"));
286 }
287 Ok(lines)
288 }
289 }
290}
291
292fn explain_select_cte(
293 schema: &SchemaManager,
294 stmt: &SelectStmt,
295 cte_names: &[&str],
296) -> Result<Vec<String>> {
297 let mut lines = Vec::new();
298
299 if stmt.from.is_empty() {
300 lines.push("CONSTANT ROW".into());
301 return Ok(lines);
302 }
303
304 let lower_from = stmt.from.to_ascii_lowercase();
305
306 if cte_names
307 .iter()
308 .any(|n| n.eq_ignore_ascii_case(&lower_from))
309 {
310 lines.push(format!("SCAN CTE {lower_from}"));
311 for join in &stmt.joins {
312 let jname = join.table.name.to_ascii_lowercase();
313 if cte_names.iter().any(|n| n.eq_ignore_ascii_case(&jname)) {
314 lines.push(format!("SCAN CTE {jname}"));
315 } else {
316 let js = schema
317 .get(&jname)
318 .ok_or_else(|| SqlError::TableNotFound(join.table.name.clone()))?;
319 let jp = planner::plan_select(js, &None);
320 lines.push(format_scan_line(&jname, &join.table.alias, &jp, js));
321 }
322 }
323 if !stmt.joins.is_empty() {
324 lines.push("NESTED LOOP".into());
325 }
326 if !stmt.group_by.is_empty() {
327 lines.push("GROUP BY".into());
328 }
329 if stmt.distinct {
330 lines.push("DISTINCT".into());
331 }
332 if !stmt.order_by.is_empty() {
333 lines.push("SORT".into());
334 }
335 if stmt.limit.is_some() {
336 lines.push("LIMIT".into());
337 }
338 return Ok(lines);
339 }
340
341 let from_schema = schema
342 .get(&lower_from)
343 .ok_or_else(|| SqlError::TableNotFound(stmt.from.clone()))?;
344
345 if stmt.joins.is_empty() {
346 let plan = planner::plan_select(from_schema, &stmt.where_clause);
347 lines.push(format_scan_line(
348 &lower_from,
349 &stmt.from_alias,
350 &plan,
351 from_schema,
352 ));
353 } else {
354 let from_plan = planner::plan_select(from_schema, &None);
355 lines.push(format_scan_line(
356 &lower_from,
357 &stmt.from_alias,
358 &from_plan,
359 from_schema,
360 ));
361
362 for join in &stmt.joins {
363 let inner_lower = join.table.name.to_ascii_lowercase();
364 if cte_names
365 .iter()
366 .any(|n| n.eq_ignore_ascii_case(&inner_lower))
367 {
368 lines.push(format!("SCAN CTE {inner_lower}"));
369 } else {
370 let inner_schema = schema
371 .get(&inner_lower)
372 .ok_or_else(|| SqlError::TableNotFound(join.table.name.clone()))?;
373 let inner_plan = planner::plan_select(inner_schema, &None);
374 lines.push(format_scan_line(
375 &inner_lower,
376 &join.table.alias,
377 &inner_plan,
378 inner_schema,
379 ));
380 }
381 }
382
383 let join_type_str = match stmt.joins.last().map(|j| j.join_type) {
384 Some(JoinType::Left) => "LEFT JOIN",
385 Some(JoinType::Right) => "RIGHT JOIN",
386 Some(JoinType::Cross) => "CROSS JOIN",
387 _ => "NESTED LOOP",
388 };
389 lines.push(join_type_str.into());
390 }
391
392 if stmt.where_clause.is_some() && stmt.joins.is_empty() {
393 let plan = planner::plan_select(from_schema, &stmt.where_clause);
394 if matches!(plan, ScanPlan::SeqScan) {
395 lines.push("FILTER".into());
396 }
397 }
398
399 if let Some(ref w) = stmt.where_clause {
400 if !stmt.joins.is_empty() && has_subquery(w) {
401 lines.push("SUBQUERY".into());
402 }
403 }
404
405 explain_subqueries(stmt, &mut lines);
406
407 if !stmt.group_by.is_empty() {
408 lines.push("GROUP BY".into());
409 }
410
411 if stmt.distinct {
412 lines.push("DISTINCT".into());
413 }
414
415 if !stmt.order_by.is_empty() {
416 lines.push("SORT".into());
417 }
418
419 if let Some(ref offset_expr) = stmt.offset {
420 if let Ok(n) = eval_const_int(offset_expr) {
421 lines.push(format!("OFFSET {n}"));
422 } else {
423 lines.push("OFFSET".into());
424 }
425 }
426
427 if let Some(ref limit_expr) = stmt.limit {
428 if let Ok(n) = eval_const_int(limit_expr) {
429 lines.push(format!("LIMIT {n}"));
430 } else {
431 lines.push("LIMIT".into());
432 }
433 }
434
435 Ok(lines)
436}
437
438fn explain_subqueries(stmt: &SelectStmt, lines: &mut Vec<String>) {
439 let mut count = 0;
440 if let Some(ref w) = stmt.where_clause {
441 count += count_subqueries(w);
442 }
443 if let Some(ref h) = stmt.having {
444 count += count_subqueries(h);
445 }
446 for col in &stmt.columns {
447 if let SelectColumn::Expr { expr, .. } = col {
448 count += count_subqueries(expr);
449 }
450 }
451 for _ in 0..count {
452 lines.push("SUBQUERY".into());
453 }
454}
455
456fn count_subqueries(expr: &Expr) -> usize {
457 match expr {
458 Expr::InSubquery { expr: e, .. } => 1 + count_subqueries(e),
459 Expr::ScalarSubquery(_) => 1,
460 Expr::Exists { .. } => 1,
461 Expr::BinaryOp { left, right, .. } => count_subqueries(left) + count_subqueries(right),
462 Expr::UnaryOp { expr: e, .. } => count_subqueries(e),
463 Expr::IsNull(e) | Expr::IsNotNull(e) => count_subqueries(e),
464 Expr::Function { args, .. } => args.iter().map(count_subqueries).sum(),
465 Expr::Between {
466 expr: e, low, high, ..
467 } => count_subqueries(e) + count_subqueries(low) + count_subqueries(high),
468 Expr::Like {
469 expr: e, pattern, ..
470 } => count_subqueries(e) + count_subqueries(pattern),
471 Expr::Case {
472 operand,
473 conditions,
474 else_result,
475 } => {
476 let mut n = 0;
477 if let Some(op) = operand {
478 n += count_subqueries(op);
479 }
480 for (c, r) in conditions {
481 n += count_subqueries(c) + count_subqueries(r);
482 }
483 if let Some(el) = else_result {
484 n += count_subqueries(el);
485 }
486 n
487 }
488 Expr::Coalesce(args) => args.iter().map(count_subqueries).sum(),
489 Expr::Cast { expr: e, .. } => count_subqueries(e),
490 Expr::InList { expr: e, list, .. } => {
491 count_subqueries(e) + list.iter().map(count_subqueries).sum::<usize>()
492 }
493 _ => 0,
494 }
495}
496
497fn format_scan_line(
498 table_name: &str,
499 alias: &Option<String>,
500 plan: &ScanPlan,
501 table_schema: &TableSchema,
502) -> String {
503 let alias_part = match alias {
504 Some(a) if !a.eq_ignore_ascii_case(table_name) => {
505 format!(" AS {}", a.to_ascii_lowercase())
506 }
507 _ => String::new(),
508 };
509
510 let desc = planner::describe_plan(plan, table_schema);
511
512 if desc.is_empty() {
513 format!("SCAN TABLE {table_name}{alias_part}")
514 } else {
515 format!("SEARCH TABLE {table_name}{alias_part} {desc}")
516 }
517}
518
519fn validate_foreign_keys(
524 schema: &SchemaManager,
525 table_schema: &TableSchema,
526 foreign_keys: &[ForeignKeySchemaEntry],
527) -> Result<()> {
528 for fk in foreign_keys {
529 let parent = if fk.foreign_table == table_schema.name {
531 table_schema
532 } else {
533 schema.get(&fk.foreign_table).ok_or_else(|| {
534 SqlError::Unsupported(format!(
535 "FOREIGN KEY references non-existent table '{}'",
536 fk.foreign_table
537 ))
538 })?
539 };
540
541 let ref_col_indices: Vec<u16> = fk
543 .referred_columns
544 .iter()
545 .map(|rc| {
546 parent
547 .column_index(rc)
548 .map(|i| i as u16)
549 .ok_or_else(|| SqlError::ColumnNotFound(rc.clone()))
550 })
551 .collect::<Result<_>>()?;
552
553 if fk.columns.len() != ref_col_indices.len() {
554 return Err(SqlError::Unsupported(format!(
555 "FOREIGN KEY on '{}': column count mismatch",
556 table_schema.name
557 )));
558 }
559
560 let is_pk = parent.primary_key_columns == ref_col_indices;
562
563 let has_unique = !is_pk
565 && parent
566 .indices
567 .iter()
568 .any(|idx| idx.unique && idx.columns == ref_col_indices);
569
570 if !is_pk && !has_unique {
571 return Err(SqlError::Unsupported(format!(
572 "FOREIGN KEY on '{}': referred columns in '{}' are not PRIMARY KEY or UNIQUE",
573 table_schema.name, fk.foreign_table
574 )));
575 }
576 }
577 Ok(())
578}
579
580fn create_fk_auto_indices(
582 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
583 mut table_schema: TableSchema,
584) -> Result<TableSchema> {
585 let fks: Vec<(Vec<u16>, String)> = table_schema
586 .foreign_keys
587 .iter()
588 .enumerate()
589 .map(|(i, fk)| {
590 let name = fk
591 .name
592 .as_ref()
593 .map(|n| format!("__fk_{}_{}", table_schema.name, n))
594 .unwrap_or_else(|| format!("__fk_{}_{}", table_schema.name, i));
595 (fk.columns.clone(), name)
596 })
597 .collect();
598
599 for (cols, idx_name) in fks {
600 let already_covered = table_schema.indices.iter().any(|idx| idx.columns == cols);
602 if already_covered {
603 continue;
604 }
605
606 let idx_def = IndexDef {
607 name: idx_name.clone(),
608 columns: cols,
609 unique: false,
610 };
611 let idx_table = TableSchema::index_table_name(&table_schema.name, &idx_name);
612 wtx.create_table(&idx_table).map_err(SqlError::Storage)?;
613 table_schema.indices.push(idx_def);
615 }
616 Ok(table_schema)
617}
618
619fn exec_create_table(
622 db: &Database,
623 schema: &mut SchemaManager,
624 stmt: &CreateTableStmt,
625) -> Result<ExecutionResult> {
626 let lower_name = stmt.name.to_ascii_lowercase();
627
628 if schema.contains(&lower_name) {
629 if stmt.if_not_exists {
630 return Ok(ExecutionResult::Ok);
631 }
632 return Err(SqlError::TableAlreadyExists(stmt.name.clone()));
633 }
634
635 if stmt.primary_key.is_empty() {
636 return Err(SqlError::PrimaryKeyRequired);
637 }
638
639 let mut seen = std::collections::HashSet::new();
640 for col in &stmt.columns {
641 let lower = col.name.to_ascii_lowercase();
642 if !seen.insert(lower.clone()) {
643 return Err(SqlError::DuplicateColumn(col.name.clone()));
644 }
645 }
646
647 let columns: Vec<ColumnDef> = stmt
648 .columns
649 .iter()
650 .enumerate()
651 .map(|(i, c)| ColumnDef {
652 name: c.name.to_ascii_lowercase(),
653 data_type: c.data_type,
654 nullable: c.nullable,
655 position: i as u16,
656 default_expr: c.default_expr.clone(),
657 default_sql: c.default_sql.clone(),
658 check_expr: c.check_expr.clone(),
659 check_sql: c.check_sql.clone(),
660 check_name: c.check_name.clone(),
661 })
662 .collect();
663
664 let primary_key_columns: Vec<u16> = stmt
665 .primary_key
666 .iter()
667 .map(|pk_name| {
668 let lower = pk_name.to_ascii_lowercase();
669 columns
670 .iter()
671 .position(|c| c.name == lower)
672 .map(|i| i as u16)
673 .ok_or_else(|| SqlError::ColumnNotFound(pk_name.clone()))
674 })
675 .collect::<Result<_>>()?;
676
677 let check_constraints: Vec<TableCheckDef> = stmt
678 .check_constraints
679 .iter()
680 .map(|tc| TableCheckDef {
681 name: tc.name.clone(),
682 expr: tc.expr.clone(),
683 sql: tc.sql.clone(),
684 })
685 .collect();
686
687 let foreign_keys: Vec<ForeignKeySchemaEntry> = stmt
688 .foreign_keys
689 .iter()
690 .map(|fk| {
691 let col_indices: Vec<u16> = fk
692 .columns
693 .iter()
694 .map(|cn| {
695 let lower = cn.to_ascii_lowercase();
696 columns
697 .iter()
698 .position(|c| c.name == lower)
699 .map(|i| i as u16)
700 .ok_or_else(|| SqlError::ColumnNotFound(cn.clone()))
701 })
702 .collect::<Result<_>>()?;
703 Ok(ForeignKeySchemaEntry {
704 name: fk.name.clone(),
705 columns: col_indices,
706 foreign_table: fk.foreign_table.to_ascii_lowercase(),
707 referred_columns: fk
708 .referred_columns
709 .iter()
710 .map(|s| s.to_ascii_lowercase())
711 .collect(),
712 })
713 })
714 .collect::<Result<_>>()?;
715
716 let table_schema = TableSchema::new(
717 lower_name.clone(),
718 columns,
719 primary_key_columns,
720 vec![],
721 check_constraints,
722 foreign_keys,
723 );
724
725 validate_foreign_keys(schema, &table_schema, &table_schema.foreign_keys)?;
727
728 let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
729 SchemaManager::ensure_schema_table(&mut wtx)?;
730 wtx.create_table(lower_name.as_bytes())
731 .map_err(SqlError::Storage)?;
732
733 let table_schema = create_fk_auto_indices(&mut wtx, table_schema)?;
735
736 SchemaManager::save_schema(&mut wtx, &table_schema)?;
737 wtx.commit().map_err(SqlError::Storage)?;
738
739 schema.register(table_schema);
740 Ok(ExecutionResult::Ok)
741}
742
743fn exec_drop_table(
744 db: &Database,
745 schema: &mut SchemaManager,
746 stmt: &DropTableStmt,
747) -> Result<ExecutionResult> {
748 let lower_name = stmt.name.to_ascii_lowercase();
749
750 if !schema.contains(&lower_name) {
751 if stmt.if_exists {
752 return Ok(ExecutionResult::Ok);
753 }
754 return Err(SqlError::TableNotFound(stmt.name.clone()));
755 }
756
757 for (child_table, _fk) in schema.child_fks_for(&lower_name) {
759 if child_table != lower_name {
760 return Err(SqlError::ForeignKeyViolation(format!(
761 "cannot drop table '{}': referenced by foreign key in '{}'",
762 lower_name, child_table
763 )));
764 }
765 }
766
767 let table_schema = schema.get(&lower_name).unwrap();
768 let idx_tables: Vec<Vec<u8>> = table_schema
769 .indices
770 .iter()
771 .map(|idx| TableSchema::index_table_name(&lower_name, &idx.name))
772 .collect();
773
774 let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
775 for idx_table in &idx_tables {
776 wtx.drop_table(idx_table).map_err(SqlError::Storage)?;
777 }
778 wtx.drop_table(lower_name.as_bytes())
779 .map_err(SqlError::Storage)?;
780 SchemaManager::delete_schema(&mut wtx, &lower_name)?;
781 wtx.commit().map_err(SqlError::Storage)?;
782
783 schema.remove(&lower_name);
784 Ok(ExecutionResult::Ok)
785}
786
787fn exec_create_table_in_txn(
788 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
789 schema: &mut SchemaManager,
790 stmt: &CreateTableStmt,
791) -> Result<ExecutionResult> {
792 let lower_name = stmt.name.to_ascii_lowercase();
793
794 if schema.contains(&lower_name) {
795 if stmt.if_not_exists {
796 return Ok(ExecutionResult::Ok);
797 }
798 return Err(SqlError::TableAlreadyExists(stmt.name.clone()));
799 }
800
801 if stmt.primary_key.is_empty() {
802 return Err(SqlError::PrimaryKeyRequired);
803 }
804
805 let mut seen = std::collections::HashSet::new();
806 for col in &stmt.columns {
807 let lower = col.name.to_ascii_lowercase();
808 if !seen.insert(lower.clone()) {
809 return Err(SqlError::DuplicateColumn(col.name.clone()));
810 }
811 }
812
813 let columns: Vec<ColumnDef> = stmt
814 .columns
815 .iter()
816 .enumerate()
817 .map(|(i, c)| ColumnDef {
818 name: c.name.to_ascii_lowercase(),
819 data_type: c.data_type,
820 nullable: c.nullable,
821 position: i as u16,
822 default_expr: c.default_expr.clone(),
823 default_sql: c.default_sql.clone(),
824 check_expr: c.check_expr.clone(),
825 check_sql: c.check_sql.clone(),
826 check_name: c.check_name.clone(),
827 })
828 .collect();
829
830 let primary_key_columns: Vec<u16> = stmt
831 .primary_key
832 .iter()
833 .map(|pk_name| {
834 let lower = pk_name.to_ascii_lowercase();
835 columns
836 .iter()
837 .position(|c| c.name == lower)
838 .map(|i| i as u16)
839 .ok_or_else(|| SqlError::ColumnNotFound(pk_name.clone()))
840 })
841 .collect::<Result<_>>()?;
842
843 let check_constraints: Vec<TableCheckDef> = stmt
844 .check_constraints
845 .iter()
846 .map(|tc| TableCheckDef {
847 name: tc.name.clone(),
848 expr: tc.expr.clone(),
849 sql: tc.sql.clone(),
850 })
851 .collect();
852
853 let foreign_keys: Vec<ForeignKeySchemaEntry> = stmt
854 .foreign_keys
855 .iter()
856 .map(|fk| {
857 let col_indices: Vec<u16> = fk
858 .columns
859 .iter()
860 .map(|cn| {
861 let lower = cn.to_ascii_lowercase();
862 columns
863 .iter()
864 .position(|c| c.name == lower)
865 .map(|i| i as u16)
866 .ok_or_else(|| SqlError::ColumnNotFound(cn.clone()))
867 })
868 .collect::<Result<_>>()?;
869 Ok(ForeignKeySchemaEntry {
870 name: fk.name.clone(),
871 columns: col_indices,
872 foreign_table: fk.foreign_table.to_ascii_lowercase(),
873 referred_columns: fk
874 .referred_columns
875 .iter()
876 .map(|s| s.to_ascii_lowercase())
877 .collect(),
878 })
879 })
880 .collect::<Result<_>>()?;
881
882 let table_schema = TableSchema::new(
883 lower_name.clone(),
884 columns,
885 primary_key_columns,
886 vec![],
887 check_constraints,
888 foreign_keys,
889 );
890
891 validate_foreign_keys(schema, &table_schema, &table_schema.foreign_keys)?;
892
893 SchemaManager::ensure_schema_table(wtx)?;
894 wtx.create_table(lower_name.as_bytes())
895 .map_err(SqlError::Storage)?;
896
897 let table_schema = create_fk_auto_indices(wtx, table_schema)?;
898
899 SchemaManager::save_schema(wtx, &table_schema)?;
900
901 schema.register(table_schema);
902 Ok(ExecutionResult::Ok)
903}
904
905fn exec_drop_table_in_txn(
906 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
907 schema: &mut SchemaManager,
908 stmt: &DropTableStmt,
909) -> Result<ExecutionResult> {
910 let lower_name = stmt.name.to_ascii_lowercase();
911
912 if !schema.contains(&lower_name) {
913 if stmt.if_exists {
914 return Ok(ExecutionResult::Ok);
915 }
916 return Err(SqlError::TableNotFound(stmt.name.clone()));
917 }
918
919 for (child_table, _fk) in schema.child_fks_for(&lower_name) {
921 if child_table != lower_name {
922 return Err(SqlError::ForeignKeyViolation(format!(
923 "cannot drop table '{}': referenced by foreign key in '{}'",
924 lower_name, child_table
925 )));
926 }
927 }
928
929 let table_schema = schema.get(&lower_name).unwrap();
930 let idx_tables: Vec<Vec<u8>> = table_schema
931 .indices
932 .iter()
933 .map(|idx| TableSchema::index_table_name(&lower_name, &idx.name))
934 .collect();
935
936 for idx_table in &idx_tables {
937 wtx.drop_table(idx_table).map_err(SqlError::Storage)?;
938 }
939 wtx.drop_table(lower_name.as_bytes())
940 .map_err(SqlError::Storage)?;
941 SchemaManager::delete_schema(wtx, &lower_name)?;
942
943 schema.remove(&lower_name);
944 Ok(ExecutionResult::Ok)
945}
946
947fn exec_create_index(
948 db: &Database,
949 schema: &mut SchemaManager,
950 stmt: &CreateIndexStmt,
951) -> Result<ExecutionResult> {
952 let lower_table = stmt.table_name.to_ascii_lowercase();
953 let lower_idx = stmt.index_name.to_ascii_lowercase();
954
955 let table_schema = schema
956 .get(&lower_table)
957 .ok_or_else(|| SqlError::TableNotFound(stmt.table_name.clone()))?;
958
959 if table_schema.index_by_name(&lower_idx).is_some() {
960 if stmt.if_not_exists {
961 return Ok(ExecutionResult::Ok);
962 }
963 return Err(SqlError::IndexAlreadyExists(stmt.index_name.clone()));
964 }
965
966 let col_indices: Vec<u16> = stmt
967 .columns
968 .iter()
969 .map(|col_name| {
970 let lower = col_name.to_ascii_lowercase();
971 table_schema
972 .column_index(&lower)
973 .map(|i| i as u16)
974 .ok_or_else(|| SqlError::ColumnNotFound(col_name.clone()))
975 })
976 .collect::<Result<_>>()?;
977
978 let idx_def = IndexDef {
979 name: lower_idx.clone(),
980 columns: col_indices,
981 unique: stmt.unique,
982 };
983
984 let idx_table = TableSchema::index_table_name(&lower_table, &lower_idx);
985
986 let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
987 SchemaManager::ensure_schema_table(&mut wtx)?;
988 wtx.create_table(&idx_table).map_err(SqlError::Storage)?;
989
990 let pk_indices = table_schema.pk_indices();
992 let mut rows: Vec<Vec<Value>> = Vec::new();
993 {
994 let mut scan_err: Option<SqlError> = None;
995 wtx.table_for_each(lower_table.as_bytes(), |key, value| {
996 match decode_full_row(table_schema, key, value) {
997 Ok(row) => rows.push(row),
998 Err(e) => scan_err = Some(e),
999 }
1000 Ok(())
1001 })
1002 .map_err(SqlError::Storage)?;
1003 if let Some(e) = scan_err {
1004 return Err(e);
1005 }
1006 }
1007
1008 for row in &rows {
1009 let pk_values: Vec<Value> = pk_indices.iter().map(|&i| row[i].clone()).collect();
1010 let key = encode_index_key(&idx_def, row, &pk_values);
1011 let value = encode_index_value(&idx_def, row, &pk_values);
1012 let is_new = wtx
1013 .table_insert(&idx_table, &key, &value)
1014 .map_err(SqlError::Storage)?;
1015 if idx_def.unique && !is_new {
1016 let indexed_values: Vec<Value> = idx_def
1017 .columns
1018 .iter()
1019 .map(|&col_idx| row[col_idx as usize].clone())
1020 .collect();
1021 let any_null = indexed_values.iter().any(|v| v.is_null());
1022 if !any_null {
1023 return Err(SqlError::UniqueViolation(stmt.index_name.clone()));
1024 }
1025 }
1026 }
1027
1028 let mut updated_schema = table_schema.clone();
1029 updated_schema.indices.push(idx_def);
1030 SchemaManager::save_schema(&mut wtx, &updated_schema)?;
1031 wtx.commit().map_err(SqlError::Storage)?;
1032
1033 schema.register(updated_schema);
1034 Ok(ExecutionResult::Ok)
1035}
1036
1037fn exec_drop_index(
1038 db: &Database,
1039 schema: &mut SchemaManager,
1040 stmt: &DropIndexStmt,
1041) -> Result<ExecutionResult> {
1042 let lower_idx = stmt.index_name.to_ascii_lowercase();
1043
1044 let (table_name, _idx_pos) = match find_index_in_schemas(schema, &lower_idx) {
1045 Some(found) => found,
1046 None => {
1047 if stmt.if_exists {
1048 return Ok(ExecutionResult::Ok);
1049 }
1050 return Err(SqlError::IndexNotFound(stmt.index_name.clone()));
1051 }
1052 };
1053
1054 let idx_table = TableSchema::index_table_name(&table_name, &lower_idx);
1055
1056 let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
1057 wtx.drop_table(&idx_table).map_err(SqlError::Storage)?;
1058
1059 let table_schema = schema.get(&table_name).unwrap();
1060 let mut updated_schema = table_schema.clone();
1061 updated_schema.indices.retain(|i| i.name != lower_idx);
1062 SchemaManager::save_schema(&mut wtx, &updated_schema)?;
1063 wtx.commit().map_err(SqlError::Storage)?;
1064
1065 schema.register(updated_schema);
1066 Ok(ExecutionResult::Ok)
1067}
1068
1069fn exec_create_index_in_txn(
1070 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1071 schema: &mut SchemaManager,
1072 stmt: &CreateIndexStmt,
1073) -> Result<ExecutionResult> {
1074 let lower_table = stmt.table_name.to_ascii_lowercase();
1075 let lower_idx = stmt.index_name.to_ascii_lowercase();
1076
1077 let table_schema = schema
1078 .get(&lower_table)
1079 .ok_or_else(|| SqlError::TableNotFound(stmt.table_name.clone()))?;
1080
1081 if table_schema.index_by_name(&lower_idx).is_some() {
1082 if stmt.if_not_exists {
1083 return Ok(ExecutionResult::Ok);
1084 }
1085 return Err(SqlError::IndexAlreadyExists(stmt.index_name.clone()));
1086 }
1087
1088 let col_indices: Vec<u16> = stmt
1089 .columns
1090 .iter()
1091 .map(|col_name| {
1092 let lower = col_name.to_ascii_lowercase();
1093 table_schema
1094 .column_index(&lower)
1095 .map(|i| i as u16)
1096 .ok_or_else(|| SqlError::ColumnNotFound(col_name.clone()))
1097 })
1098 .collect::<Result<_>>()?;
1099
1100 let idx_def = IndexDef {
1101 name: lower_idx.clone(),
1102 columns: col_indices,
1103 unique: stmt.unique,
1104 };
1105
1106 let idx_table = TableSchema::index_table_name(&lower_table, &lower_idx);
1107
1108 SchemaManager::ensure_schema_table(wtx)?;
1109 wtx.create_table(&idx_table).map_err(SqlError::Storage)?;
1110
1111 let pk_indices = table_schema.pk_indices();
1112 let mut rows: Vec<Vec<Value>> = Vec::new();
1113 {
1114 let mut scan_err: Option<SqlError> = None;
1115 wtx.table_for_each(lower_table.as_bytes(), |key, value| {
1116 match decode_full_row(table_schema, key, value) {
1117 Ok(row) => rows.push(row),
1118 Err(e) => scan_err = Some(e),
1119 }
1120 Ok(())
1121 })
1122 .map_err(SqlError::Storage)?;
1123 if let Some(e) = scan_err {
1124 return Err(e);
1125 }
1126 }
1127
1128 for row in &rows {
1129 let pk_values: Vec<Value> = pk_indices.iter().map(|&i| row[i].clone()).collect();
1130 let key = encode_index_key(&idx_def, row, &pk_values);
1131 let value = encode_index_value(&idx_def, row, &pk_values);
1132 let is_new = wtx
1133 .table_insert(&idx_table, &key, &value)
1134 .map_err(SqlError::Storage)?;
1135 if idx_def.unique && !is_new {
1136 let indexed_values: Vec<Value> = idx_def
1137 .columns
1138 .iter()
1139 .map(|&col_idx| row[col_idx as usize].clone())
1140 .collect();
1141 let any_null = indexed_values.iter().any(|v| v.is_null());
1142 if !any_null {
1143 return Err(SqlError::UniqueViolation(stmt.index_name.clone()));
1144 }
1145 }
1146 }
1147
1148 let mut updated_schema = table_schema.clone();
1149 updated_schema.indices.push(idx_def);
1150 SchemaManager::save_schema(wtx, &updated_schema)?;
1151
1152 schema.register(updated_schema);
1153 Ok(ExecutionResult::Ok)
1154}
1155
1156fn exec_drop_index_in_txn(
1157 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1158 schema: &mut SchemaManager,
1159 stmt: &DropIndexStmt,
1160) -> Result<ExecutionResult> {
1161 let lower_idx = stmt.index_name.to_ascii_lowercase();
1162
1163 let (table_name, _idx_pos) = match find_index_in_schemas(schema, &lower_idx) {
1164 Some(found) => found,
1165 None => {
1166 if stmt.if_exists {
1167 return Ok(ExecutionResult::Ok);
1168 }
1169 return Err(SqlError::IndexNotFound(stmt.index_name.clone()));
1170 }
1171 };
1172
1173 let idx_table = TableSchema::index_table_name(&table_name, &lower_idx);
1174 wtx.drop_table(&idx_table).map_err(SqlError::Storage)?;
1175
1176 let table_schema = schema.get(&table_name).unwrap();
1177 let mut updated_schema = table_schema.clone();
1178 updated_schema.indices.retain(|i| i.name != lower_idx);
1179 SchemaManager::save_schema(wtx, &updated_schema)?;
1180
1181 schema.register(updated_schema);
1182 Ok(ExecutionResult::Ok)
1183}
1184
1185fn exec_alter_table(
1188 db: &Database,
1189 schema: &mut SchemaManager,
1190 stmt: &AlterTableStmt,
1191) -> Result<ExecutionResult> {
1192 let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
1193 SchemaManager::ensure_schema_table(&mut wtx)?;
1194 alter_table_impl(&mut wtx, schema, stmt)?;
1195 wtx.commit().map_err(SqlError::Storage)?;
1196 Ok(ExecutionResult::Ok)
1197}
1198
1199fn exec_alter_table_in_txn(
1200 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1201 schema: &mut SchemaManager,
1202 stmt: &AlterTableStmt,
1203) -> Result<ExecutionResult> {
1204 SchemaManager::ensure_schema_table(wtx)?;
1205 alter_table_impl(wtx, schema, stmt)?;
1206 Ok(ExecutionResult::Ok)
1207}
1208
1209fn alter_table_impl(
1210 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1211 schema: &mut SchemaManager,
1212 stmt: &AlterTableStmt,
1213) -> Result<()> {
1214 let lower_name = stmt.table.to_ascii_lowercase();
1215 if lower_name == "_schema" {
1216 return Err(SqlError::Unsupported("cannot alter internal table".into()));
1217 }
1218 match &stmt.op {
1219 AlterTableOp::AddColumn {
1220 column,
1221 foreign_key,
1222 if_not_exists,
1223 } => alter_add_column(
1224 wtx,
1225 schema,
1226 &lower_name,
1227 column,
1228 foreign_key.as_ref(),
1229 *if_not_exists,
1230 ),
1231 AlterTableOp::DropColumn { name, if_exists } => {
1232 alter_drop_column(wtx, schema, &lower_name, name, *if_exists)
1233 }
1234 AlterTableOp::RenameColumn { old_name, new_name } => {
1235 alter_rename_column(wtx, schema, &lower_name, old_name, new_name)
1236 }
1237 AlterTableOp::RenameTable { new_name } => {
1238 alter_rename_table(wtx, schema, &lower_name, new_name)
1239 }
1240 }
1241}
1242
1243fn alter_add_column(
1244 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1245 schema: &mut SchemaManager,
1246 table_name: &str,
1247 col_spec: &ColumnSpec,
1248 fk_def: Option<&ForeignKeyDef>,
1249 if_not_exists: bool,
1250) -> Result<()> {
1251 let table_schema = schema
1252 .get(table_name)
1253 .ok_or_else(|| SqlError::TableNotFound(table_name.into()))?;
1254
1255 let col_lower = col_spec.name.to_ascii_lowercase();
1256
1257 if table_schema.column_index(&col_lower).is_some() {
1258 if if_not_exists {
1259 return Ok(());
1260 }
1261 return Err(SqlError::DuplicateColumn(col_spec.name.clone()));
1262 }
1263
1264 if col_spec.is_primary_key {
1265 return Err(SqlError::Unsupported(
1266 "cannot add PRIMARY KEY column via ALTER TABLE".into(),
1267 ));
1268 }
1269
1270 if !col_spec.nullable && col_spec.default_expr.is_none() {
1271 let count = wtx.table_entry_count(table_name.as_bytes()).unwrap_or(0);
1272 if count > 0 {
1273 return Err(SqlError::Unsupported(
1274 "cannot add NOT NULL column without DEFAULT to non-empty table".into(),
1275 ));
1276 }
1277 }
1278
1279 if let Some(ref check) = col_spec.check_expr {
1280 if has_subquery(check) {
1281 return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
1282 }
1283 }
1284
1285 let new_pos = table_schema.columns.len() as u16;
1286 let new_col = ColumnDef {
1287 name: col_lower.clone(),
1288 data_type: col_spec.data_type,
1289 nullable: col_spec.nullable,
1290 position: new_pos,
1291 default_expr: col_spec.default_expr.clone(),
1292 default_sql: col_spec.default_sql.clone(),
1293 check_expr: col_spec.check_expr.clone(),
1294 check_sql: col_spec.check_sql.clone(),
1295 check_name: col_spec.check_name.clone(),
1296 };
1297
1298 let mut new_schema = table_schema.clone();
1299 new_schema.columns.push(new_col);
1300
1301 if let Some(fk) = fk_def {
1302 let col_idx = new_pos;
1303 let fk_entry = ForeignKeySchemaEntry {
1304 name: fk.name.clone(),
1305 columns: vec![col_idx],
1306 foreign_table: fk.foreign_table.to_ascii_lowercase(),
1307 referred_columns: fk
1308 .referred_columns
1309 .iter()
1310 .map(|s| s.to_ascii_lowercase())
1311 .collect(),
1312 };
1313 new_schema.foreign_keys.push(fk_entry);
1314 }
1315
1316 new_schema = new_schema.rebuild();
1317
1318 if fk_def.is_some() {
1319 validate_foreign_keys(schema, &new_schema, &new_schema.foreign_keys)?;
1320 new_schema = create_fk_auto_indices(wtx, new_schema)?;
1321 }
1322
1323 SchemaManager::save_schema(wtx, &new_schema)?;
1324 schema.register(new_schema);
1325 Ok(())
1326}
1327
1328fn alter_drop_column(
1329 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1330 schema: &mut SchemaManager,
1331 table_name: &str,
1332 col_name: &str,
1333 if_exists: bool,
1334) -> Result<()> {
1335 let table_schema = schema
1336 .get(table_name)
1337 .ok_or_else(|| SqlError::TableNotFound(table_name.into()))?;
1338
1339 let col_lower = col_name.to_ascii_lowercase();
1340 let drop_pos = match table_schema.column_index(&col_lower) {
1341 Some(pos) => pos,
1342 None => {
1343 if if_exists {
1344 return Ok(());
1345 }
1346 return Err(SqlError::ColumnNotFound(col_name.into()));
1347 }
1348 };
1349 let drop_pos_u16 = drop_pos as u16;
1350
1351 if table_schema.primary_key_columns.contains(&drop_pos_u16) {
1352 return Err(SqlError::Unsupported(
1353 "cannot drop primary key column".into(),
1354 ));
1355 }
1356
1357 for idx in &table_schema.indices {
1358 if idx.columns.contains(&drop_pos_u16) {
1359 return Err(SqlError::Unsupported(format!(
1360 "column '{}' is indexed by '{}'; drop the index first",
1361 col_lower, idx.name
1362 )));
1363 }
1364 }
1365
1366 for fk in &table_schema.foreign_keys {
1367 if fk.columns.contains(&drop_pos_u16) {
1368 return Err(SqlError::Unsupported(format!(
1369 "column '{}' is part of a foreign key",
1370 col_lower
1371 )));
1372 }
1373 }
1374
1375 for (child_table, fk) in schema.child_fks_for(table_name) {
1376 if child_table == table_name {
1377 continue; }
1379 if fk.referred_columns.iter().any(|rc| rc == &col_lower) {
1380 return Err(SqlError::Unsupported(format!(
1381 "column '{}' is referenced by a foreign key in '{}'",
1382 col_lower, child_table
1383 )));
1384 }
1385 }
1386
1387 for tc in &table_schema.check_constraints {
1388 if tc.sql.to_ascii_lowercase().contains(&col_lower) {
1389 return Err(SqlError::Unsupported(format!(
1390 "column '{}' is used in CHECK constraint '{}'",
1391 col_lower,
1392 tc.name.as_deref().unwrap_or("<unnamed>")
1393 )));
1394 }
1395 }
1396
1397 let new_schema = table_schema.without_column(drop_pos);
1401
1402 SchemaManager::save_schema(wtx, &new_schema)?;
1403 schema.register(new_schema);
1404 Ok(())
1405}
1406
1407fn alter_rename_column(
1408 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1409 schema: &mut SchemaManager,
1410 table_name: &str,
1411 old_name: &str,
1412 new_name: &str,
1413) -> Result<()> {
1414 let table_schema = schema
1415 .get(table_name)
1416 .ok_or_else(|| SqlError::TableNotFound(table_name.into()))?;
1417
1418 let old_lower = old_name.to_ascii_lowercase();
1419 let new_lower = new_name.to_ascii_lowercase();
1420
1421 let col_pos = table_schema
1422 .column_index(&old_lower)
1423 .ok_or_else(|| SqlError::ColumnNotFound(old_name.into()))?;
1424
1425 if table_schema.column_index(&new_lower).is_some() {
1426 return Err(SqlError::DuplicateColumn(new_name.into()));
1427 }
1428
1429 let mut new_schema = table_schema.clone();
1430 new_schema.columns[col_pos].name = new_lower.clone();
1431
1432 for col in &mut new_schema.columns {
1434 if let Some(ref sql) = col.check_sql {
1435 if sql.to_ascii_lowercase().contains(&old_lower) {
1436 let updated = sql.replace(&old_lower, &new_lower);
1437 col.check_sql = Some(updated.clone());
1438 if let Ok(parsed) = crate::parser::parse_sql_expr(&updated) {
1440 col.check_expr = Some(parsed);
1441 }
1442 }
1443 }
1444 }
1445 for tc in &mut new_schema.check_constraints {
1446 if tc.sql.to_ascii_lowercase().contains(&old_lower) {
1447 tc.sql = tc.sql.replace(&old_lower, &new_lower);
1448 if let Ok(parsed) = crate::parser::parse_sql_expr(&tc.sql) {
1449 tc.expr = parsed;
1450 }
1451 }
1452 }
1453
1454 for fk in &mut new_schema.foreign_keys {
1458 if fk.foreign_table == table_name {
1459 for rc in &mut fk.referred_columns {
1460 if *rc == old_lower {
1461 *rc = new_lower.clone();
1462 }
1463 }
1464 }
1465 }
1466
1467 SchemaManager::save_schema(wtx, &new_schema)?;
1468 schema.register(new_schema);
1469 Ok(())
1470}
1471
1472fn alter_rename_table(
1473 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1474 schema: &mut SchemaManager,
1475 old_name: &str,
1476 new_name: &str,
1477) -> Result<()> {
1478 let new_lower = new_name.to_ascii_lowercase();
1479
1480 if new_lower == "_schema" {
1481 return Err(SqlError::Unsupported(
1482 "cannot rename to internal table name".into(),
1483 ));
1484 }
1485
1486 let table_schema = schema
1487 .get(old_name)
1488 .ok_or_else(|| SqlError::TableNotFound(old_name.into()))?
1489 .clone();
1490
1491 if schema.contains(&new_lower) {
1492 return Err(SqlError::TableAlreadyExists(new_name.into()));
1493 }
1494
1495 wtx.rename_table(old_name.as_bytes(), new_lower.as_bytes())
1496 .map_err(SqlError::Storage)?;
1497
1498 let idx_renames: Vec<(Vec<u8>, Vec<u8>)> = table_schema
1499 .indices
1500 .iter()
1501 .map(|idx| {
1502 let old_idx = TableSchema::index_table_name(old_name, &idx.name);
1503 let new_idx = TableSchema::index_table_name(&new_lower, &idx.name);
1504 (old_idx, new_idx)
1505 })
1506 .collect();
1507 for (old_idx, new_idx) in &idx_renames {
1508 wtx.rename_table(old_idx, new_idx)
1509 .map_err(SqlError::Storage)?;
1510 }
1511
1512 let child_tables: Vec<String> = schema
1513 .child_fks_for(old_name)
1514 .iter()
1515 .filter(|(child, _)| *child != old_name)
1516 .map(|(child, _)| child.to_string())
1517 .collect::<std::collections::HashSet<_>>()
1518 .into_iter()
1519 .collect();
1520
1521 for child_table in &child_tables {
1522 let mut updated_child = schema.get(child_table).unwrap().clone();
1523 for fk in &mut updated_child.foreign_keys {
1524 if fk.foreign_table == old_name {
1525 fk.foreign_table = new_lower.clone();
1526 }
1527 }
1528 SchemaManager::save_schema(wtx, &updated_child)?;
1529 schema.register(updated_child);
1530 }
1531
1532 SchemaManager::delete_schema(wtx, old_name)?;
1533 let mut new_schema = table_schema;
1534 new_schema.name = new_lower.clone();
1535
1536 for fk in &mut new_schema.foreign_keys {
1538 if fk.foreign_table == old_name {
1539 fk.foreign_table = new_lower.clone();
1540 }
1541 }
1542
1543 SchemaManager::save_schema(wtx, &new_schema)?;
1544 schema.remove(old_name);
1545 schema.register(new_schema);
1546 Ok(())
1547}
1548
1549fn find_index_in_schemas(schema: &SchemaManager, index_name: &str) -> Option<(String, usize)> {
1550 for table_name in schema.table_names() {
1551 if let Some(ts) = schema.get(table_name) {
1552 if let Some(pos) = ts.indices.iter().position(|i| i.name == index_name) {
1553 return Some((table_name.to_string(), pos));
1554 }
1555 }
1556 }
1557 None
1558}
1559
1560fn extract_pk_key(
1563 idx_key: &[u8],
1564 idx_value: &[u8],
1565 is_unique: bool,
1566 num_index_cols: usize,
1567 num_pk_cols: usize,
1568) -> Result<Vec<u8>> {
1569 if is_unique && !idx_value.is_empty() {
1570 Ok(idx_value.to_vec())
1571 } else {
1572 let total_cols = num_index_cols + num_pk_cols;
1573 let all_values = decode_composite_key(idx_key, total_cols)?;
1574 let pk_values = &all_values[num_index_cols..];
1575 Ok(encode_composite_key(pk_values))
1576 }
1577}
1578
1579fn check_range_conditions(
1580 idx_key: &[u8],
1581 num_prefix_cols: usize,
1582 range_conds: &[(BinOp, Value)],
1583 num_index_cols: usize,
1584) -> Result<RangeCheck> {
1585 if range_conds.is_empty() {
1586 return Ok(RangeCheck::Match);
1587 }
1588
1589 let num_to_decode = num_prefix_cols + 1;
1590 if num_to_decode > num_index_cols {
1591 return Ok(RangeCheck::Match);
1592 }
1593
1594 let mut pos = 0;
1596 for _ in 0..num_prefix_cols {
1597 let (_, n) = decode_key_value(&idx_key[pos..])?;
1598 pos += n;
1599 }
1600 let (range_val, _) = decode_key_value(&idx_key[pos..])?;
1601
1602 let mut exceeds_upper = false;
1603 let mut below_lower = false;
1604
1605 for (op, val) in range_conds {
1606 match op {
1607 BinOp::Lt => {
1608 if range_val >= *val {
1609 exceeds_upper = true;
1610 }
1611 }
1612 BinOp::LtEq => {
1613 if range_val > *val {
1614 exceeds_upper = true;
1615 }
1616 }
1617 BinOp::Gt => {
1618 if range_val <= *val {
1619 below_lower = true;
1620 }
1621 }
1622 BinOp::GtEq => {
1623 if range_val < *val {
1624 below_lower = true;
1625 }
1626 }
1627 _ => {}
1628 }
1629 }
1630
1631 if exceeds_upper {
1632 Ok(RangeCheck::ExceedsUpper)
1633 } else if below_lower {
1634 Ok(RangeCheck::BelowLower)
1635 } else {
1636 Ok(RangeCheck::Match)
1637 }
1638}
1639
/// Outcome of testing one index entry against the range predicates on the column that
/// follows the equality prefix (produced by `check_range_conditions`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RangeCheck {
    /// Every range bound is satisfied, or there was nothing to check.
    Match,
    /// The value violates a lower bound (`>` / `>=`); index scans skip it and continue.
    BelowLower,
    /// The value violates an upper bound (`<` / `<=`); index scans stop here.
    ExceedsUpper,
}
1645
1646fn collect_rows_read(
1648 db: &Database,
1649 table_schema: &TableSchema,
1650 where_clause: &Option<Expr>,
1651 limit: Option<usize>,
1652) -> Result<(Vec<Vec<Value>>, bool)> {
1653 let plan = planner::plan_select(table_schema, where_clause);
1654 let lower_name = &table_schema.name;
1655 let columns = &table_schema.columns;
1656
1657 match plan {
1658 ScanPlan::SeqScan => {
1659 let simple_pred = where_clause
1660 .as_ref()
1661 .and_then(|expr| try_simple_predicate(expr, table_schema));
1662
1663 if let Some(ref pred) = simple_pred {
1664 let mut rtx = db.begin_read();
1665 let entry_count =
1666 rtx.table_entry_count(lower_name.as_bytes()).unwrap_or(0) as usize;
1667 let mut rows = Vec::with_capacity(entry_count / 4);
1668 let mut scan_err: Option<SqlError> = None;
1669 rtx.table_scan_raw(lower_name.as_bytes(), |key, value| {
1670 match pred.matches_raw(key, value) {
1671 Ok(true) => match decode_full_row(table_schema, key, value) {
1672 Ok(row) => rows.push(row),
1673 Err(e) => {
1674 scan_err = Some(e);
1675 return false;
1676 }
1677 },
1678 Ok(false) => {}
1679 Err(e) => {
1680 scan_err = Some(e);
1681 return false;
1682 }
1683 }
1684 scan_err.is_none() && limit.map_or(true, |n| rows.len() < n)
1685 })
1686 .map_err(SqlError::Storage)?;
1687 if let Some(e) = scan_err {
1688 return Err(e);
1689 }
1690 return Ok((rows, true));
1691 }
1692
1693 let mut rtx = db.begin_read();
1694 let entry_count = rtx.table_entry_count(lower_name.as_bytes()).unwrap_or(0) as usize;
1695 let capacity = if where_clause.is_some() {
1696 entry_count / 4
1697 } else {
1698 entry_count
1699 };
1700 let mut rows = Vec::with_capacity(capacity);
1701 let mut scan_err: Option<SqlError> = None;
1702
1703 let col_map = ColumnMap::new(columns);
1704 let partial_ctx = where_clause.as_ref().and_then(|expr| {
1705 let needed = referenced_columns(expr, columns);
1706 if needed.len() < columns.len() {
1707 Some(PartialDecodeCtx::new(table_schema, &needed))
1708 } else {
1709 None
1710 }
1711 });
1712
1713 rtx.table_scan_from(lower_name.as_bytes(), b"", |key, value| {
1714 match (&where_clause, &partial_ctx) {
1715 (Some(expr), Some(ctx)) => match ctx.decode(key, value) {
1716 Ok(partial) => match eval_expr(expr, &col_map, &partial) {
1717 Ok(val) if is_truthy(&val) => match ctx.complete(partial, key, value) {
1718 Ok(row) => rows.push(row),
1719 Err(e) => scan_err = Some(e),
1720 },
1721 Err(e) => scan_err = Some(e),
1722 _ => {}
1723 },
1724 Err(e) => scan_err = Some(e),
1725 },
1726 (Some(expr), None) => match decode_full_row(table_schema, key, value) {
1727 Ok(row) => match eval_expr(expr, &col_map, &row) {
1728 Ok(val) if is_truthy(&val) => rows.push(row),
1729 Err(e) => scan_err = Some(e),
1730 _ => {}
1731 },
1732 Err(e) => scan_err = Some(e),
1733 },
1734 _ => match decode_full_row(table_schema, key, value) {
1735 Ok(row) => rows.push(row),
1736 Err(e) => scan_err = Some(e),
1737 },
1738 }
1739 let keep_going = scan_err.is_none() && limit.map_or(true, |n| rows.len() < n);
1740 Ok(keep_going)
1741 })
1742 .map_err(SqlError::Storage)?;
1743 if let Some(e) = scan_err {
1744 return Err(e);
1745 }
1746 Ok((rows, where_clause.is_some()))
1747 }
1748
1749 ScanPlan::PkLookup { pk_values } => {
1750 let key = encode_composite_key(&pk_values);
1751 let mut rtx = db.begin_read();
1752 match rtx
1753 .table_get(lower_name.as_bytes(), &key)
1754 .map_err(SqlError::Storage)?
1755 {
1756 Some(value) => {
1757 let row = decode_full_row(table_schema, &key, &value)?;
1758 if let Some(ref expr) = where_clause {
1759 let col_map = ColumnMap::new(columns);
1760 match eval_expr(expr, &col_map, &row) {
1761 Ok(val) if is_truthy(&val) => Ok((vec![row], true)),
1762 _ => Ok((vec![], true)),
1763 }
1764 } else {
1765 Ok((vec![row], false))
1766 }
1767 }
1768 None => Ok((vec![], true)),
1769 }
1770 }
1771
1772 ScanPlan::IndexScan {
1773 idx_table,
1774 prefix,
1775 num_prefix_cols,
1776 range_conds,
1777 is_unique,
1778 index_columns,
1779 ..
1780 } => {
1781 let num_pk_cols = table_schema.primary_key_columns.len();
1782 let num_index_cols = index_columns.len();
1783 let mut pk_keys: Vec<Vec<u8>> = Vec::new();
1784
1785 {
1786 let mut rtx = db.begin_read();
1787 let mut scan_err: Option<SqlError> = None;
1788 rtx.table_scan_from(&idx_table, &prefix, |key, value| {
1789 if !key.starts_with(&prefix) {
1790 return Ok(false);
1791 }
1792 match check_range_conditions(key, num_prefix_cols, &range_conds, num_index_cols)
1793 {
1794 Ok(RangeCheck::ExceedsUpper) => return Ok(false),
1795 Ok(RangeCheck::BelowLower) => return Ok(true),
1796 Ok(RangeCheck::Match) => {}
1797 Err(e) => {
1798 scan_err = Some(e);
1799 return Ok(false);
1800 }
1801 }
1802 match extract_pk_key(key, value, is_unique, num_index_cols, num_pk_cols) {
1803 Ok(pk) => pk_keys.push(pk),
1804 Err(e) => {
1805 scan_err = Some(e);
1806 return Ok(false);
1807 }
1808 }
1809 Ok(true)
1810 })
1811 .map_err(SqlError::Storage)?;
1812 if let Some(e) = scan_err {
1813 return Err(e);
1814 }
1815 }
1816
1817 let mut rows = Vec::new();
1818 let mut rtx = db.begin_read();
1819 let col_map = ColumnMap::new(columns);
1820 for pk_key in &pk_keys {
1821 if let Some(value) = rtx
1822 .table_get(lower_name.as_bytes(), pk_key)
1823 .map_err(SqlError::Storage)?
1824 {
1825 let row = decode_full_row(table_schema, pk_key, &value)?;
1826 if let Some(ref expr) = where_clause {
1827 match eval_expr(expr, &col_map, &row) {
1828 Ok(val) if is_truthy(&val) => rows.push(row),
1829 _ => {}
1830 }
1831 } else {
1832 rows.push(row);
1833 }
1834 }
1835 }
1836 Ok((rows, where_clause.is_some()))
1837 }
1838 }
1839}
1840
1841fn collect_rows_write(
1843 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1844 table_schema: &TableSchema,
1845 where_clause: &Option<Expr>,
1846 limit: Option<usize>,
1847) -> Result<(Vec<Vec<Value>>, bool)> {
1848 let plan = planner::plan_select(table_schema, where_clause);
1849 let lower_name = &table_schema.name;
1850 let columns = &table_schema.columns;
1851
1852 match plan {
1853 ScanPlan::SeqScan => {
1854 let simple_pred = where_clause
1855 .as_ref()
1856 .and_then(|expr| try_simple_predicate(expr, table_schema));
1857
1858 if let Some(ref pred) = simple_pred {
1859 let mut rows = Vec::new();
1860 let mut scan_err: Option<SqlError> = None;
1861 wtx.table_scan_from(lower_name.as_bytes(), b"", |key, value| {
1862 match pred.matches_raw(key, value) {
1863 Ok(true) => match decode_full_row(table_schema, key, value) {
1864 Ok(row) => rows.push(row),
1865 Err(e) => scan_err = Some(e),
1866 },
1867 Ok(false) => {}
1868 Err(e) => scan_err = Some(e),
1869 }
1870 let keep_going = scan_err.is_none() && limit.map_or(true, |n| rows.len() < n);
1871 Ok(keep_going)
1872 })
1873 .map_err(SqlError::Storage)?;
1874 if let Some(e) = scan_err {
1875 return Err(e);
1876 }
1877 return Ok((rows, true));
1878 }
1879
1880 let mut rows = Vec::new();
1881 let mut scan_err: Option<SqlError> = None;
1882
1883 let col_map = ColumnMap::new(columns);
1884 let partial_ctx = where_clause.as_ref().and_then(|expr| {
1885 let needed = referenced_columns(expr, columns);
1886 if needed.len() < columns.len() {
1887 Some(PartialDecodeCtx::new(table_schema, &needed))
1888 } else {
1889 None
1890 }
1891 });
1892
1893 wtx.table_scan_from(lower_name.as_bytes(), b"", |key, value| {
1894 match (&where_clause, &partial_ctx) {
1895 (Some(expr), Some(ctx)) => match ctx.decode(key, value) {
1896 Ok(partial) => match eval_expr(expr, &col_map, &partial) {
1897 Ok(val) if is_truthy(&val) => match ctx.complete(partial, key, value) {
1898 Ok(row) => rows.push(row),
1899 Err(e) => scan_err = Some(e),
1900 },
1901 Err(e) => scan_err = Some(e),
1902 _ => {}
1903 },
1904 Err(e) => scan_err = Some(e),
1905 },
1906 (Some(expr), None) => match decode_full_row(table_schema, key, value) {
1907 Ok(row) => match eval_expr(expr, &col_map, &row) {
1908 Ok(val) if is_truthy(&val) => rows.push(row),
1909 Err(e) => scan_err = Some(e),
1910 _ => {}
1911 },
1912 Err(e) => scan_err = Some(e),
1913 },
1914 _ => match decode_full_row(table_schema, key, value) {
1915 Ok(row) => rows.push(row),
1916 Err(e) => scan_err = Some(e),
1917 },
1918 }
1919 let keep_going = scan_err.is_none() && limit.map_or(true, |n| rows.len() < n);
1920 Ok(keep_going)
1921 })
1922 .map_err(SqlError::Storage)?;
1923 if let Some(e) = scan_err {
1924 return Err(e);
1925 }
1926 Ok((rows, where_clause.is_some()))
1927 }
1928
1929 ScanPlan::PkLookup { pk_values } => {
1930 let key = encode_composite_key(&pk_values);
1931 match wtx
1932 .table_get(lower_name.as_bytes(), &key)
1933 .map_err(SqlError::Storage)?
1934 {
1935 Some(value) => {
1936 let row = decode_full_row(table_schema, &key, &value)?;
1937 if let Some(ref expr) = where_clause {
1938 let col_map = ColumnMap::new(columns);
1939 match eval_expr(expr, &col_map, &row) {
1940 Ok(val) if is_truthy(&val) => Ok((vec![row], true)),
1941 _ => Ok((vec![], true)),
1942 }
1943 } else {
1944 Ok((vec![row], false))
1945 }
1946 }
1947 None => Ok((vec![], true)),
1948 }
1949 }
1950
1951 ScanPlan::IndexScan {
1952 idx_table,
1953 prefix,
1954 num_prefix_cols,
1955 range_conds,
1956 is_unique,
1957 index_columns,
1958 ..
1959 } => {
1960 let num_pk_cols = table_schema.primary_key_columns.len();
1961 let num_index_cols = index_columns.len();
1962 let mut pk_keys: Vec<Vec<u8>> = Vec::new();
1963
1964 {
1965 let mut scan_err: Option<SqlError> = None;
1966 wtx.table_scan_from(&idx_table, &prefix, |key, value| {
1967 if !key.starts_with(&prefix) {
1968 return Ok(false);
1969 }
1970 match check_range_conditions(key, num_prefix_cols, &range_conds, num_index_cols)
1971 {
1972 Ok(RangeCheck::ExceedsUpper) => return Ok(false),
1973 Ok(RangeCheck::BelowLower) => return Ok(true),
1974 Ok(RangeCheck::Match) => {}
1975 Err(e) => {
1976 scan_err = Some(e);
1977 return Ok(false);
1978 }
1979 }
1980 match extract_pk_key(key, value, is_unique, num_index_cols, num_pk_cols) {
1981 Ok(pk) => pk_keys.push(pk),
1982 Err(e) => {
1983 scan_err = Some(e);
1984 return Ok(false);
1985 }
1986 }
1987 Ok(true)
1988 })
1989 .map_err(SqlError::Storage)?;
1990 if let Some(e) = scan_err {
1991 return Err(e);
1992 }
1993 }
1994
1995 let mut rows = Vec::new();
1996 let col_map = ColumnMap::new(columns);
1997 for pk_key in &pk_keys {
1998 if let Some(value) = wtx
1999 .table_get(lower_name.as_bytes(), pk_key)
2000 .map_err(SqlError::Storage)?
2001 {
2002 let row = decode_full_row(table_schema, pk_key, &value)?;
2003 if let Some(ref expr) = where_clause {
2004 match eval_expr(expr, &col_map, &row) {
2005 Ok(val) if is_truthy(&val) => rows.push(row),
2006 _ => {}
2007 }
2008 } else {
2009 rows.push(row);
2010 }
2011 }
2012 }
2013 Ok((rows, where_clause.is_some()))
2014 }
2015 }
2016}
2017
2018fn collect_keyed_rows_read(
2021 db: &Database,
2022 table_schema: &TableSchema,
2023 where_clause: &Option<Expr>,
2024) -> Result<Vec<(Vec<u8>, Vec<Value>)>> {
2025 let plan = planner::plan_select(table_schema, where_clause);
2026 let lower_name = &table_schema.name;
2027
2028 match plan {
2029 ScanPlan::SeqScan => {
2030 let mut rows = Vec::new();
2031 let mut rtx = db.begin_read();
2032 let mut scan_err: Option<SqlError> = None;
2033 rtx.table_for_each(lower_name.as_bytes(), |key, value| {
2034 match decode_full_row(table_schema, key, value) {
2035 Ok(row) => rows.push((key.to_vec(), row)),
2036 Err(e) => scan_err = Some(e),
2037 }
2038 Ok(())
2039 })
2040 .map_err(SqlError::Storage)?;
2041 if let Some(e) = scan_err {
2042 return Err(e);
2043 }
2044 Ok(rows)
2045 }
2046
2047 ScanPlan::PkLookup { pk_values } => {
2048 let key = encode_composite_key(&pk_values);
2049 let mut rtx = db.begin_read();
2050 match rtx
2051 .table_get(lower_name.as_bytes(), &key)
2052 .map_err(SqlError::Storage)?
2053 {
2054 Some(value) => {
2055 let row = decode_full_row(table_schema, &key, &value)?;
2056 Ok(vec![(key, row)])
2057 }
2058 None => Ok(vec![]),
2059 }
2060 }
2061
2062 ScanPlan::IndexScan {
2063 idx_table,
2064 prefix,
2065 num_prefix_cols,
2066 range_conds,
2067 is_unique,
2068 index_columns,
2069 ..
2070 } => {
2071 let num_pk_cols = table_schema.primary_key_columns.len();
2072 let num_index_cols = index_columns.len();
2073 let mut pk_keys: Vec<Vec<u8>> = Vec::new();
2074
2075 {
2076 let mut rtx = db.begin_read();
2077 let mut scan_err: Option<SqlError> = None;
2078 rtx.table_scan_from(&idx_table, &prefix, |key, value| {
2079 if !key.starts_with(&prefix) {
2080 return Ok(false);
2081 }
2082 match check_range_conditions(key, num_prefix_cols, &range_conds, num_index_cols)
2083 {
2084 Ok(RangeCheck::ExceedsUpper) => return Ok(false),
2085 Ok(RangeCheck::BelowLower) => return Ok(true),
2086 Ok(RangeCheck::Match) => {}
2087 Err(e) => {
2088 scan_err = Some(e);
2089 return Ok(false);
2090 }
2091 }
2092 match extract_pk_key(key, value, is_unique, num_index_cols, num_pk_cols) {
2093 Ok(pk) => pk_keys.push(pk),
2094 Err(e) => {
2095 scan_err = Some(e);
2096 return Ok(false);
2097 }
2098 }
2099 Ok(true)
2100 })
2101 .map_err(SqlError::Storage)?;
2102 if let Some(e) = scan_err {
2103 return Err(e);
2104 }
2105 }
2106
2107 let mut rows = Vec::new();
2108 let mut rtx = db.begin_read();
2109 for pk_key in &pk_keys {
2110 if let Some(value) = rtx
2111 .table_get(lower_name.as_bytes(), pk_key)
2112 .map_err(SqlError::Storage)?
2113 {
2114 rows.push((
2115 pk_key.clone(),
2116 decode_full_row(table_schema, pk_key, &value)?,
2117 ));
2118 }
2119 }
2120 Ok(rows)
2121 }
2122 }
2123}
2124
2125fn collect_keyed_rows_write(
2127 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
2128 table_schema: &TableSchema,
2129 where_clause: &Option<Expr>,
2130) -> Result<Vec<(Vec<u8>, Vec<Value>)>> {
2131 let plan = planner::plan_select(table_schema, where_clause);
2132 let lower_name = &table_schema.name;
2133
2134 match plan {
2135 ScanPlan::SeqScan => {
2136 let mut rows = Vec::new();
2137 let mut scan_err: Option<SqlError> = None;
2138 wtx.table_for_each(lower_name.as_bytes(), |key, value| {
2139 match decode_full_row(table_schema, key, value) {
2140 Ok(row) => rows.push((key.to_vec(), row)),
2141 Err(e) => scan_err = Some(e),
2142 }
2143 Ok(())
2144 })
2145 .map_err(SqlError::Storage)?;
2146 if let Some(e) = scan_err {
2147 return Err(e);
2148 }
2149 Ok(rows)
2150 }
2151
2152 ScanPlan::PkLookup { pk_values } => {
2153 let key = encode_composite_key(&pk_values);
2154 match wtx
2155 .table_get(lower_name.as_bytes(), &key)
2156 .map_err(SqlError::Storage)?
2157 {
2158 Some(value) => {
2159 let row = decode_full_row(table_schema, &key, &value)?;
2160 Ok(vec![(key, row)])
2161 }
2162 None => Ok(vec![]),
2163 }
2164 }
2165
2166 ScanPlan::IndexScan {
2167 idx_table,
2168 prefix,
2169 num_prefix_cols,
2170 range_conds,
2171 is_unique,
2172 index_columns,
2173 ..
2174 } => {
2175 let num_pk_cols = table_schema.primary_key_columns.len();
2176 let num_index_cols = index_columns.len();
2177 let mut pk_keys: Vec<Vec<u8>> = Vec::new();
2178
2179 {
2180 let mut scan_err: Option<SqlError> = None;
2181 wtx.table_scan_from(&idx_table, &prefix, |key, value| {
2182 if !key.starts_with(&prefix) {
2183 return Ok(false);
2184 }
2185 match check_range_conditions(key, num_prefix_cols, &range_conds, num_index_cols)
2186 {
2187 Ok(RangeCheck::ExceedsUpper) => return Ok(false),
2188 Ok(RangeCheck::BelowLower) => return Ok(true),
2189 Ok(RangeCheck::Match) => {}
2190 Err(e) => {
2191 scan_err = Some(e);
2192 return Ok(false);
2193 }
2194 }
2195 match extract_pk_key(key, value, is_unique, num_index_cols, num_pk_cols) {
2196 Ok(pk) => pk_keys.push(pk),
2197 Err(e) => {
2198 scan_err = Some(e);
2199 return Ok(false);
2200 }
2201 }
2202 Ok(true)
2203 })
2204 .map_err(SqlError::Storage)?;
2205 if let Some(e) = scan_err {
2206 return Err(e);
2207 }
2208 }
2209
2210 let mut rows = Vec::new();
2211 for pk_key in &pk_keys {
2212 if let Some(value) = wtx
2213 .table_get(lower_name.as_bytes(), pk_key)
2214 .map_err(SqlError::Storage)?
2215 {
2216 rows.push((
2217 pk_key.clone(),
2218 decode_full_row(table_schema, pk_key, &value)?,
2219 ));
2220 }
2221 }
2222 Ok(rows)
2223 }
2224 }
2225}
2226
2227fn exec_insert(
2230 db: &Database,
2231 schema: &SchemaManager,
2232 stmt: &InsertStmt,
2233 params: &[Value],
2234) -> Result<ExecutionResult> {
2235 let empty_ctes = CteContext::new();
2236 let materialized;
2237 let stmt = if insert_has_subquery(stmt) {
2238 materialized = materialize_insert(stmt, &mut |sub| {
2239 exec_subquery_read(db, schema, sub, &empty_ctes)
2240 })?;
2241 &materialized
2242 } else {
2243 stmt
2244 };
2245
2246 let lower_name = stmt.table.to_ascii_lowercase();
2247 let table_schema = schema
2248 .get(&lower_name)
2249 .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
2250
2251 let insert_columns = if stmt.columns.is_empty() {
2252 table_schema
2253 .columns
2254 .iter()
2255 .map(|c| c.name.clone())
2256 .collect::<Vec<_>>()
2257 } else {
2258 stmt.columns
2259 .iter()
2260 .map(|c| c.to_ascii_lowercase())
2261 .collect()
2262 };
2263
2264 let col_indices: Vec<usize> = insert_columns
2265 .iter()
2266 .map(|name| {
2267 table_schema
2268 .column_index(name)
2269 .ok_or_else(|| SqlError::ColumnNotFound(name.clone()))
2270 })
2271 .collect::<Result<_>>()?;
2272
2273 let defaults: Vec<(usize, &Expr)> = table_schema
2275 .columns
2276 .iter()
2277 .filter(|c| c.default_expr.is_some() && !col_indices.contains(&(c.position as usize)))
2278 .map(|c| (c.position as usize, c.default_expr.as_ref().unwrap()))
2279 .collect();
2280
2281 let has_checks = table_schema.has_checks();
2283 let check_col_map = if has_checks {
2284 Some(ColumnMap::new(&table_schema.columns))
2285 } else {
2286 None
2287 };
2288
2289 let select_rows = match &stmt.source {
2290 InsertSource::Select(sq) => {
2291 let insert_ctes = materialize_all_ctes(&sq.ctes, sq.recursive, &mut |body, ctx| {
2292 exec_query_body_read(db, schema, body, ctx)
2293 })?;
2294 let qr = exec_query_body_read(db, schema, &sq.body, &insert_ctes)?;
2295 Some(qr.rows)
2296 }
2297 InsertSource::Values(_) => None,
2298 };
2299
2300 let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
2301 let mut count: u64 = 0;
2302
2303 let pk_indices = table_schema.pk_indices();
2304 let non_pk = table_schema.non_pk_indices();
2305 let enc_pos = table_schema.encoding_positions();
2306 let phys_count = table_schema.physical_non_pk_count();
2307 let mut row = vec![Value::Null; table_schema.columns.len()];
2308 let mut pk_values: Vec<Value> = vec![Value::Null; pk_indices.len()];
2309 let mut value_values: Vec<Value> = vec![Value::Null; phys_count];
2310 let mut key_buf: Vec<u8> = Vec::with_capacity(64);
2311 let mut value_buf: Vec<u8> = Vec::with_capacity(256);
2312 let mut fk_key_buf: Vec<u8> = Vec::with_capacity(64);
2313
2314 let values = match &stmt.source {
2315 InsertSource::Values(rows) => Some(rows.as_slice()),
2316 InsertSource::Select(_) => None,
2317 };
2318 let sel_rows = select_rows.as_deref();
2319
2320 let total = match (values, sel_rows) {
2321 (Some(rows), _) => rows.len(),
2322 (_, Some(rows)) => rows.len(),
2323 _ => 0,
2324 };
2325
2326 if let Some(sel) = sel_rows {
2327 if !sel.is_empty() && sel[0].len() != insert_columns.len() {
2328 return Err(SqlError::InvalidValue(format!(
2329 "INSERT ... SELECT column count mismatch: expected {}, got {}",
2330 insert_columns.len(),
2331 sel[0].len()
2332 )));
2333 }
2334 }
2335
2336 for idx in 0..total {
2337 for v in row.iter_mut() {
2338 *v = Value::Null;
2339 }
2340
2341 if let Some(value_rows) = values {
2342 let value_row = &value_rows[idx];
2343 if value_row.len() != insert_columns.len() {
2344 return Err(SqlError::InvalidValue(format!(
2345 "expected {} values, got {}",
2346 insert_columns.len(),
2347 value_row.len()
2348 )));
2349 }
2350 for (i, expr) in value_row.iter().enumerate() {
2351 let val = if let Expr::Parameter(n) = expr {
2352 params
2353 .get(n - 1)
2354 .cloned()
2355 .ok_or_else(|| SqlError::Parse(format!("unbound parameter ${n}")))?
2356 } else {
2357 eval_const_expr(expr)?
2358 };
2359 let col_idx = col_indices[i];
2360 let col = &table_schema.columns[col_idx];
2361 let got_type = val.data_type();
2362 row[col_idx] = if val.is_null() {
2363 Value::Null
2364 } else {
2365 val.coerce_into(col.data_type)
2366 .ok_or_else(|| SqlError::TypeMismatch {
2367 expected: col.data_type.to_string(),
2368 got: got_type.to_string(),
2369 })?
2370 };
2371 }
2372 } else if let Some(sel) = sel_rows {
2373 let sel_row = &sel[idx];
2374 for (i, val) in sel_row.iter().enumerate() {
2375 let col_idx = col_indices[i];
2376 let col = &table_schema.columns[col_idx];
2377 let got_type = val.data_type();
2378 row[col_idx] = if val.is_null() {
2379 Value::Null
2380 } else {
2381 val.clone().coerce_into(col.data_type).ok_or_else(|| {
2382 SqlError::TypeMismatch {
2383 expected: col.data_type.to_string(),
2384 got: got_type.to_string(),
2385 }
2386 })?
2387 };
2388 }
2389 }
2390
2391 for &(pos, def_expr) in &defaults {
2393 let val = eval_const_expr(def_expr)?;
2394 let col = &table_schema.columns[pos];
2395 if val.is_null() {
2396 } else {
2398 let got_type = val.data_type();
2399 row[pos] =
2400 val.coerce_into(col.data_type)
2401 .ok_or_else(|| SqlError::TypeMismatch {
2402 expected: col.data_type.to_string(),
2403 got: got_type.to_string(),
2404 })?;
2405 }
2406 }
2407
2408 for col in &table_schema.columns {
2409 if !col.nullable && row[col.position as usize].is_null() {
2410 return Err(SqlError::NotNullViolation(col.name.clone()));
2411 }
2412 }
2413
2414 if let Some(ref col_map) = check_col_map {
2416 for col in &table_schema.columns {
2417 if let Some(ref check) = col.check_expr {
2418 let result = eval_expr(check, col_map, &row)?;
2419 if !is_truthy(&result) && !result.is_null() {
2420 let name = col.check_name.as_deref().unwrap_or(&col.name);
2421 return Err(SqlError::CheckViolation(name.to_string()));
2422 }
2423 }
2424 }
2425 for tc in &table_schema.check_constraints {
2426 let result = eval_expr(&tc.expr, col_map, &row)?;
2427 if !is_truthy(&result) && !result.is_null() {
2428 let name = tc.name.as_deref().unwrap_or(&tc.sql);
2429 return Err(SqlError::CheckViolation(name.to_string()));
2430 }
2431 }
2432 }
2433
2434 for fk in &table_schema.foreign_keys {
2436 let any_null = fk.columns.iter().any(|&ci| row[ci as usize].is_null());
2437 if any_null {
2438 continue; }
2440 let fk_vals: Vec<Value> = fk
2441 .columns
2442 .iter()
2443 .map(|&ci| row[ci as usize].clone())
2444 .collect();
2445 fk_key_buf.clear();
2446 encode_composite_key_into(&fk_vals, &mut fk_key_buf);
2447 let found = wtx
2448 .table_get(fk.foreign_table.as_bytes(), &fk_key_buf)
2449 .map_err(SqlError::Storage)?;
2450 if found.is_none() {
2451 let name = fk.name.as_deref().unwrap_or(&fk.foreign_table);
2452 return Err(SqlError::ForeignKeyViolation(name.to_string()));
2453 }
2454 }
2455
2456 for (j, &i) in pk_indices.iter().enumerate() {
2457 pk_values[j] = std::mem::replace(&mut row[i], Value::Null);
2458 }
2459 encode_composite_key_into(&pk_values, &mut key_buf);
2460
2461 for (j, &i) in non_pk.iter().enumerate() {
2462 value_values[enc_pos[j] as usize] = std::mem::replace(&mut row[i], Value::Null);
2463 }
2464 encode_row_into(&value_values, &mut value_buf);
2465
2466 if key_buf.len() > citadel_core::MAX_KEY_SIZE {
2467 return Err(SqlError::KeyTooLarge {
2468 size: key_buf.len(),
2469 max: citadel_core::MAX_KEY_SIZE,
2470 });
2471 }
2472 if value_buf.len() > citadel_core::MAX_INLINE_VALUE_SIZE {
2473 return Err(SqlError::RowTooLarge {
2474 size: value_buf.len(),
2475 max: citadel_core::MAX_INLINE_VALUE_SIZE,
2476 });
2477 }
2478
2479 let is_new = wtx
2480 .table_insert(stmt.table.as_bytes(), &key_buf, &value_buf)
2481 .map_err(SqlError::Storage)?;
2482 if !is_new {
2483 return Err(SqlError::DuplicateKey);
2484 }
2485
2486 if !table_schema.indices.is_empty() {
2487 for (j, &i) in pk_indices.iter().enumerate() {
2488 row[i] = pk_values[j].clone();
2489 }
2490 for (j, &i) in non_pk.iter().enumerate() {
2491 row[i] = std::mem::replace(&mut value_values[enc_pos[j] as usize], Value::Null);
2492 }
2493 insert_index_entries(&mut wtx, table_schema, &row, &pk_values)?;
2494 }
2495 count += 1;
2496 }
2497
2498 wtx.commit().map_err(SqlError::Storage)?;
2499 Ok(ExecutionResult::RowsAffected(count))
2500}
2501
/// Reports whether `expr` contains a subquery, delegating to `crate::parser::has_subquery`.
///
/// As a local item this shadows the same name brought in by `use crate::parser::*`, and it is
/// passed around as a plain predicate (e.g. to `Option::is_some_and` and `Iterator::any`).
fn has_subquery(expr: &Expr) -> bool {
    crate::parser::has_subquery(expr)
}
2505
2506fn stmt_has_subquery(stmt: &SelectStmt) -> bool {
2507 if let Some(ref w) = stmt.where_clause {
2508 if has_subquery(w) {
2509 return true;
2510 }
2511 }
2512 if let Some(ref h) = stmt.having {
2513 if has_subquery(h) {
2514 return true;
2515 }
2516 }
2517 for col in &stmt.columns {
2518 if let SelectColumn::Expr { expr, .. } = col {
2519 if has_subquery(expr) {
2520 return true;
2521 }
2522 }
2523 }
2524 for ob in &stmt.order_by {
2525 if has_subquery(&ob.expr) {
2526 return true;
2527 }
2528 }
2529 for join in &stmt.joins {
2530 if let Some(ref on_expr) = join.on_clause {
2531 if has_subquery(on_expr) {
2532 return true;
2533 }
2534 }
2535 }
2536 false
2537}
2538
2539fn materialize_expr(
2540 expr: &Expr,
2541 exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
2542) -> Result<Expr> {
2543 match expr {
2544 Expr::InSubquery {
2545 expr: e,
2546 subquery,
2547 negated,
2548 } => {
2549 let inner = materialize_expr(e, exec_sub)?;
2550 let qr = exec_sub(subquery)?;
2551 if !qr.columns.is_empty() && qr.columns.len() != 1 {
2552 return Err(SqlError::SubqueryMultipleColumns);
2553 }
2554 let mut values = std::collections::HashSet::new();
2555 let mut has_null = false;
2556 for row in &qr.rows {
2557 if row[0].is_null() {
2558 has_null = true;
2559 } else {
2560 values.insert(row[0].clone());
2561 }
2562 }
2563 Ok(Expr::InSet {
2564 expr: Box::new(inner),
2565 values,
2566 has_null,
2567 negated: *negated,
2568 })
2569 }
2570 Expr::ScalarSubquery(subquery) => {
2571 let qr = exec_sub(subquery)?;
2572 if qr.rows.len() > 1 {
2573 return Err(SqlError::SubqueryMultipleRows);
2574 }
2575 let val = if qr.rows.is_empty() {
2576 Value::Null
2577 } else {
2578 qr.rows[0][0].clone()
2579 };
2580 Ok(Expr::Literal(val))
2581 }
2582 Expr::Exists { subquery, negated } => {
2583 let qr = exec_sub(subquery)?;
2584 let exists = !qr.rows.is_empty();
2585 let result = if *negated { !exists } else { exists };
2586 Ok(Expr::Literal(Value::Boolean(result)))
2587 }
2588 Expr::InList {
2589 expr: e,
2590 list,
2591 negated,
2592 } => {
2593 let inner = materialize_expr(e, exec_sub)?;
2594 let items = list
2595 .iter()
2596 .map(|item| materialize_expr(item, exec_sub))
2597 .collect::<Result<Vec<_>>>()?;
2598 Ok(Expr::InList {
2599 expr: Box::new(inner),
2600 list: items,
2601 negated: *negated,
2602 })
2603 }
2604 Expr::BinaryOp { left, op, right } => Ok(Expr::BinaryOp {
2605 left: Box::new(materialize_expr(left, exec_sub)?),
2606 op: *op,
2607 right: Box::new(materialize_expr(right, exec_sub)?),
2608 }),
2609 Expr::UnaryOp { op, expr: e } => Ok(Expr::UnaryOp {
2610 op: *op,
2611 expr: Box::new(materialize_expr(e, exec_sub)?),
2612 }),
2613 Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(materialize_expr(e, exec_sub)?))),
2614 Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(materialize_expr(e, exec_sub)?))),
2615 Expr::InSet {
2616 expr: e,
2617 values,
2618 has_null,
2619 negated,
2620 } => Ok(Expr::InSet {
2621 expr: Box::new(materialize_expr(e, exec_sub)?),
2622 values: values.clone(),
2623 has_null: *has_null,
2624 negated: *negated,
2625 }),
2626 Expr::Between {
2627 expr: e,
2628 low,
2629 high,
2630 negated,
2631 } => Ok(Expr::Between {
2632 expr: Box::new(materialize_expr(e, exec_sub)?),
2633 low: Box::new(materialize_expr(low, exec_sub)?),
2634 high: Box::new(materialize_expr(high, exec_sub)?),
2635 negated: *negated,
2636 }),
2637 Expr::Like {
2638 expr: e,
2639 pattern,
2640 escape,
2641 negated,
2642 } => {
2643 let esc = escape
2644 .as_ref()
2645 .map(|es| materialize_expr(es, exec_sub).map(Box::new))
2646 .transpose()?;
2647 Ok(Expr::Like {
2648 expr: Box::new(materialize_expr(e, exec_sub)?),
2649 pattern: Box::new(materialize_expr(pattern, exec_sub)?),
2650 escape: esc,
2651 negated: *negated,
2652 })
2653 }
2654 Expr::Case {
2655 operand,
2656 conditions,
2657 else_result,
2658 } => {
2659 let op = operand
2660 .as_ref()
2661 .map(|e| materialize_expr(e, exec_sub).map(Box::new))
2662 .transpose()?;
2663 let conds = conditions
2664 .iter()
2665 .map(|(c, r)| {
2666 Ok((
2667 materialize_expr(c, exec_sub)?,
2668 materialize_expr(r, exec_sub)?,
2669 ))
2670 })
2671 .collect::<Result<Vec<_>>>()?;
2672 let else_r = else_result
2673 .as_ref()
2674 .map(|e| materialize_expr(e, exec_sub).map(Box::new))
2675 .transpose()?;
2676 Ok(Expr::Case {
2677 operand: op,
2678 conditions: conds,
2679 else_result: else_r,
2680 })
2681 }
2682 Expr::Coalesce(args) => {
2683 let materialized = args
2684 .iter()
2685 .map(|a| materialize_expr(a, exec_sub))
2686 .collect::<Result<Vec<_>>>()?;
2687 Ok(Expr::Coalesce(materialized))
2688 }
2689 Expr::Cast { expr: e, data_type } => Ok(Expr::Cast {
2690 expr: Box::new(materialize_expr(e, exec_sub)?),
2691 data_type: *data_type,
2692 }),
2693 Expr::Function { name, args } => {
2694 let materialized = args
2695 .iter()
2696 .map(|a| materialize_expr(a, exec_sub))
2697 .collect::<Result<Vec<_>>>()?;
2698 Ok(Expr::Function {
2699 name: name.clone(),
2700 args: materialized,
2701 })
2702 }
2703 other => Ok(other.clone()),
2704 }
2705}
2706
2707fn materialize_stmt(
2708 stmt: &SelectStmt,
2709 exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
2710) -> Result<SelectStmt> {
2711 let where_clause = stmt
2712 .where_clause
2713 .as_ref()
2714 .map(|e| materialize_expr(e, exec_sub))
2715 .transpose()?;
2716 let having = stmt
2717 .having
2718 .as_ref()
2719 .map(|e| materialize_expr(e, exec_sub))
2720 .transpose()?;
2721 let columns = stmt
2722 .columns
2723 .iter()
2724 .map(|c| match c {
2725 SelectColumn::AllColumns => Ok(SelectColumn::AllColumns),
2726 SelectColumn::Expr { expr, alias } => Ok(SelectColumn::Expr {
2727 expr: materialize_expr(expr, exec_sub)?,
2728 alias: alias.clone(),
2729 }),
2730 })
2731 .collect::<Result<Vec<_>>>()?;
2732 let order_by = stmt
2733 .order_by
2734 .iter()
2735 .map(|ob| {
2736 Ok(OrderByItem {
2737 expr: materialize_expr(&ob.expr, exec_sub)?,
2738 descending: ob.descending,
2739 nulls_first: ob.nulls_first,
2740 })
2741 })
2742 .collect::<Result<Vec<_>>>()?;
2743 let joins = stmt
2744 .joins
2745 .iter()
2746 .map(|j| {
2747 let on_clause = j
2748 .on_clause
2749 .as_ref()
2750 .map(|e| materialize_expr(e, exec_sub))
2751 .transpose()?;
2752 Ok(JoinClause {
2753 join_type: j.join_type,
2754 table: j.table.clone(),
2755 on_clause,
2756 })
2757 })
2758 .collect::<Result<Vec<_>>>()?;
2759 let group_by = stmt
2760 .group_by
2761 .iter()
2762 .map(|e| materialize_expr(e, exec_sub))
2763 .collect::<Result<Vec<_>>>()?;
2764 Ok(SelectStmt {
2765 columns,
2766 from: stmt.from.clone(),
2767 from_alias: stmt.from_alias.clone(),
2768 joins,
2769 distinct: stmt.distinct,
2770 where_clause,
2771 order_by,
2772 limit: stmt.limit.clone(),
2773 offset: stmt.offset.clone(),
2774 group_by,
2775 having,
2776 })
2777}
2778
/// Materialized common table expressions visible to a query, keyed by CTE name.
type CteContext = HashMap<String, QueryResult>;
/// Callback that loads a base table by name, returning its schema and all of its rows
/// (`resolve_table_or_cte` calls it with the lowercased name).
type ScanTableFn<'a> = &'a mut dyn FnMut(&str) -> Result<(TableSchema, Vec<Vec<Value>>)>;
2781
2782fn exec_subquery_read(
2783 db: &Database,
2784 schema: &SchemaManager,
2785 stmt: &SelectStmt,
2786 ctes: &CteContext,
2787) -> Result<QueryResult> {
2788 match exec_select(db, schema, stmt, ctes)? {
2789 ExecutionResult::Query(qr) => Ok(qr),
2790 _ => Ok(QueryResult {
2791 columns: vec![],
2792 rows: vec![],
2793 }),
2794 }
2795}
2796
2797fn exec_subquery_write(
2798 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
2799 schema: &SchemaManager,
2800 stmt: &SelectStmt,
2801 ctes: &CteContext,
2802) -> Result<QueryResult> {
2803 match exec_select_in_txn(wtx, schema, stmt, ctes)? {
2804 ExecutionResult::Query(qr) => Ok(qr),
2805 _ => Ok(QueryResult {
2806 columns: vec![],
2807 rows: vec![],
2808 }),
2809 }
2810}
2811
2812fn update_has_subquery(stmt: &UpdateStmt) -> bool {
2813 stmt.where_clause.as_ref().is_some_and(has_subquery)
2814 || stmt.assignments.iter().any(|(_, e)| has_subquery(e))
2815}
2816
2817fn materialize_update(
2818 stmt: &UpdateStmt,
2819 exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
2820) -> Result<UpdateStmt> {
2821 let where_clause = stmt
2822 .where_clause
2823 .as_ref()
2824 .map(|e| materialize_expr(e, exec_sub))
2825 .transpose()?;
2826 let assignments = stmt
2827 .assignments
2828 .iter()
2829 .map(|(name, expr)| Ok((name.clone(), materialize_expr(expr, exec_sub)?)))
2830 .collect::<Result<Vec<_>>>()?;
2831 Ok(UpdateStmt {
2832 table: stmt.table.clone(),
2833 assignments,
2834 where_clause,
2835 })
2836}
2837
2838fn delete_has_subquery(stmt: &DeleteStmt) -> bool {
2839 stmt.where_clause.as_ref().is_some_and(has_subquery)
2840}
2841
2842fn materialize_delete(
2843 stmt: &DeleteStmt,
2844 exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
2845) -> Result<DeleteStmt> {
2846 let where_clause = stmt
2847 .where_clause
2848 .as_ref()
2849 .map(|e| materialize_expr(e, exec_sub))
2850 .transpose()?;
2851 Ok(DeleteStmt {
2852 table: stmt.table.clone(),
2853 where_clause,
2854 })
2855}
2856
2857fn insert_has_subquery(stmt: &InsertStmt) -> bool {
2858 match &stmt.source {
2859 InsertSource::Values(rows) => rows.iter().any(|row| row.iter().any(has_subquery)),
2860 InsertSource::Select(sq) => {
2861 sq.ctes.iter().any(|c| query_body_has_subquery(&c.body))
2862 || query_body_has_subquery(&sq.body)
2863 }
2864 }
2865}
2866
2867fn query_body_has_subquery(body: &QueryBody) -> bool {
2868 match body {
2869 QueryBody::Select(sel) => stmt_has_subquery(sel),
2870 QueryBody::Compound(comp) => {
2871 query_body_has_subquery(&comp.left) || query_body_has_subquery(&comp.right)
2872 }
2873 }
2874}
2875
2876fn materialize_insert(
2877 stmt: &InsertStmt,
2878 exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
2879) -> Result<InsertStmt> {
2880 let source = match &stmt.source {
2881 InsertSource::Values(rows) => {
2882 let mat = rows
2883 .iter()
2884 .map(|row| {
2885 row.iter()
2886 .map(|e| materialize_expr(e, exec_sub))
2887 .collect::<Result<Vec<_>>>()
2888 })
2889 .collect::<Result<Vec<_>>>()?;
2890 InsertSource::Values(mat)
2891 }
2892 InsertSource::Select(sq) => {
2893 let ctes = sq
2894 .ctes
2895 .iter()
2896 .map(|c| {
2897 Ok(CteDefinition {
2898 name: c.name.clone(),
2899 column_aliases: c.column_aliases.clone(),
2900 body: materialize_query_body(&c.body, exec_sub)?,
2901 })
2902 })
2903 .collect::<Result<Vec<_>>>()?;
2904 let body = materialize_query_body(&sq.body, exec_sub)?;
2905 InsertSource::Select(Box::new(SelectQuery {
2906 ctes,
2907 recursive: sq.recursive,
2908 body,
2909 }))
2910 }
2911 };
2912 Ok(InsertStmt {
2913 table: stmt.table.clone(),
2914 columns: stmt.columns.clone(),
2915 source,
2916 })
2917}
2918
2919fn materialize_query_body(
2920 body: &QueryBody,
2921 exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
2922) -> Result<QueryBody> {
2923 match body {
2924 QueryBody::Select(sel) => Ok(QueryBody::Select(Box::new(materialize_stmt(
2925 sel, exec_sub,
2926 )?))),
2927 QueryBody::Compound(comp) => Ok(QueryBody::Compound(CompoundSelect {
2928 op: comp.op.clone(),
2929 all: comp.all,
2930 left: Box::new(materialize_query_body(&comp.left, exec_sub)?),
2931 right: Box::new(materialize_query_body(&comp.right, exec_sub)?),
2932 order_by: comp.order_by.clone(),
2933 limit: comp.limit.clone(),
2934 offset: comp.offset.clone(),
2935 })),
2936 }
2937}
2938
2939fn exec_query_body(
2940 db: &Database,
2941 schema: &SchemaManager,
2942 body: &QueryBody,
2943 ctes: &CteContext,
2944) -> Result<ExecutionResult> {
2945 match body {
2946 QueryBody::Select(sel) => exec_select(db, schema, sel, ctes),
2947 QueryBody::Compound(comp) => exec_compound_select(db, schema, comp, ctes),
2948 }
2949}
2950
2951fn exec_query_body_in_txn(
2952 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
2953 schema: &SchemaManager,
2954 body: &QueryBody,
2955 ctes: &CteContext,
2956) -> Result<ExecutionResult> {
2957 match body {
2958 QueryBody::Select(sel) => exec_select_in_txn(wtx, schema, sel, ctes),
2959 QueryBody::Compound(comp) => exec_compound_select_in_txn(wtx, schema, comp, ctes),
2960 }
2961}
2962
2963fn exec_query_body_read(
2964 db: &Database,
2965 schema: &SchemaManager,
2966 body: &QueryBody,
2967 ctes: &CteContext,
2968) -> Result<QueryResult> {
2969 match exec_query_body(db, schema, body, ctes)? {
2970 ExecutionResult::Query(qr) => Ok(qr),
2971 _ => Ok(QueryResult {
2972 columns: vec![],
2973 rows: vec![],
2974 }),
2975 }
2976}
2977
2978fn exec_query_body_write(
2979 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
2980 schema: &SchemaManager,
2981 body: &QueryBody,
2982 ctes: &CteContext,
2983) -> Result<QueryResult> {
2984 match exec_query_body_in_txn(wtx, schema, body, ctes)? {
2985 ExecutionResult::Query(qr) => Ok(qr),
2986 _ => Ok(QueryResult {
2987 columns: vec![],
2988 rows: vec![],
2989 }),
2990 }
2991}
2992
2993fn exec_compound_select(
2994 db: &Database,
2995 schema: &SchemaManager,
2996 comp: &CompoundSelect,
2997 ctes: &CteContext,
2998) -> Result<ExecutionResult> {
2999 let left_qr = match exec_query_body(db, schema, &comp.left, ctes)? {
3000 ExecutionResult::Query(qr) => qr,
3001 _ => QueryResult {
3002 columns: vec![],
3003 rows: vec![],
3004 },
3005 };
3006 let right_qr = match exec_query_body(db, schema, &comp.right, ctes)? {
3007 ExecutionResult::Query(qr) => qr,
3008 _ => QueryResult {
3009 columns: vec![],
3010 rows: vec![],
3011 },
3012 };
3013 apply_set_operation(comp, left_qr, right_qr)
3014}
3015
3016fn exec_compound_select_in_txn(
3017 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
3018 schema: &SchemaManager,
3019 comp: &CompoundSelect,
3020 ctes: &CteContext,
3021) -> Result<ExecutionResult> {
3022 let left_qr = match exec_query_body_in_txn(wtx, schema, &comp.left, ctes)? {
3023 ExecutionResult::Query(qr) => qr,
3024 _ => QueryResult {
3025 columns: vec![],
3026 rows: vec![],
3027 },
3028 };
3029 let right_qr = match exec_query_body_in_txn(wtx, schema, &comp.right, ctes)? {
3030 ExecutionResult::Query(qr) => qr,
3031 _ => QueryResult {
3032 columns: vec![],
3033 rows: vec![],
3034 },
3035 };
3036 apply_set_operation(comp, left_qr, right_qr)
3037}
3038
3039fn exec_select_query(
3042 db: &Database,
3043 schema: &SchemaManager,
3044 sq: &SelectQuery,
3045) -> Result<ExecutionResult> {
3046 if let Some(fused) = try_fuse_cte(sq) {
3047 let empty = CteContext::new();
3048 return exec_query_body(db, schema, &fused, &empty);
3049 }
3050 let ctes = materialize_all_ctes(&sq.ctes, sq.recursive, &mut |body, ctx| {
3051 exec_query_body_read(db, schema, body, ctx)
3052 })?;
3053 exec_query_body(db, schema, &sq.body, &ctes)
3054}
3055
3056fn exec_select_query_in_txn(
3057 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
3058 schema: &SchemaManager,
3059 sq: &SelectQuery,
3060) -> Result<ExecutionResult> {
3061 if let Some(fused) = try_fuse_cte(sq) {
3062 let empty = CteContext::new();
3063 return exec_query_body_in_txn(wtx, schema, &fused, &empty);
3064 }
3065 let ctes = materialize_all_ctes(&sq.ctes, sq.recursive, &mut |body, ctx| {
3066 exec_query_body_write(wtx, schema, body, ctx)
3067 })?;
3068 exec_query_body_in_txn(wtx, schema, &sq.body, &ctes)
3069}
3070
3071fn try_fuse_cte(sq: &SelectQuery) -> Option<QueryBody> {
3073 if sq.ctes.len() != 1 || sq.recursive {
3074 return None;
3075 }
3076 let cte = &sq.ctes[0];
3077 if !cte.column_aliases.is_empty() {
3078 return None;
3079 }
3080
3081 let inner = match &cte.body {
3082 QueryBody::Select(s) => s.as_ref(),
3083 _ => return None,
3084 };
3085
3086 if !inner.joins.is_empty()
3087 || !inner.group_by.is_empty()
3088 || inner.distinct
3089 || inner.having.is_some()
3090 || inner.limit.is_some()
3091 || inner.offset.is_some()
3092 || !inner.order_by.is_empty()
3093 || stmt_has_subquery(inner)
3094 {
3095 return None;
3096 }
3097
3098 let all_simple_refs = inner.columns.iter().all(|c| match c {
3099 SelectColumn::AllColumns => true,
3100 SelectColumn::Expr { expr, alias } => alias.is_none() && matches!(expr, Expr::Column(_)),
3101 });
3102 if !all_simple_refs {
3103 return None;
3104 }
3105
3106 let outer = match &sq.body {
3107 QueryBody::Select(s) => s.as_ref(),
3108 _ => return None,
3109 };
3110 if !outer.from.eq_ignore_ascii_case(&cte.name) || !outer.joins.is_empty() {
3111 return None;
3112 }
3113
3114 let merged_where = match (&inner.where_clause, &outer.where_clause) {
3115 (Some(iw), Some(ow)) => Some(Expr::BinaryOp {
3116 left: Box::new(iw.clone()),
3117 op: BinOp::And,
3118 right: Box::new(ow.clone()),
3119 }),
3120 (Some(w), None) | (None, Some(w)) => Some(w.clone()),
3121 (None, None) => None,
3122 };
3123
3124 let fused = SelectStmt {
3125 columns: outer.columns.clone(),
3126 from: inner.from.clone(),
3127 from_alias: inner.from_alias.clone(),
3128 joins: vec![],
3129 distinct: outer.distinct,
3130 where_clause: merged_where,
3131 order_by: outer.order_by.clone(),
3132 limit: outer.limit.clone(),
3133 offset: outer.offset.clone(),
3134 group_by: outer.group_by.clone(),
3135 having: outer.having.clone(),
3136 };
3137
3138 Some(QueryBody::Select(Box::new(fused)))
3139}
3140
3141fn materialize_all_ctes(
3142 defs: &[CteDefinition],
3143 recursive: bool,
3144 exec_body: &mut dyn FnMut(&QueryBody, &CteContext) -> Result<QueryResult>,
3145) -> Result<CteContext> {
3146 let mut ctx = CteContext::new();
3147 for cte in defs {
3148 let qr = if recursive && cte_body_references_self(&cte.body, &cte.name) {
3149 materialize_recursive_cte(cte, &ctx, exec_body)?
3150 } else {
3151 materialize_cte(cte, &ctx, exec_body)?
3152 };
3153 ctx.insert(cte.name.clone(), qr);
3154 }
3155 Ok(ctx)
3156}
3157
3158fn materialize_cte(
3159 cte: &CteDefinition,
3160 ctx: &CteContext,
3161 exec_body: &mut dyn FnMut(&QueryBody, &CteContext) -> Result<QueryResult>,
3162) -> Result<QueryResult> {
3163 let mut qr = exec_body(&cte.body, ctx)?;
3164 if !cte.column_aliases.is_empty() {
3165 if cte.column_aliases.len() != qr.columns.len() {
3166 return Err(SqlError::CteColumnAliasMismatch {
3167 name: cte.name.clone(),
3168 expected: cte.column_aliases.len(),
3169 got: qr.columns.len(),
3170 });
3171 }
3172 qr.columns = cte.column_aliases.clone();
3173 }
3174 Ok(qr)
3175}
3176
/// Maximum number of recursive steps when materializing a recursive CTE. If rows are still
/// being produced at this limit, `SqlError::RecursiveCteMaxIterations` is returned instead
/// of looping indefinitely.
const MAX_RECURSIVE_ITERATIONS: usize = 10_000;
3178
3179fn materialize_recursive_cte(
3180 cte: &CteDefinition,
3181 ctx: &CteContext,
3182 exec_body: &mut dyn FnMut(&QueryBody, &CteContext) -> Result<QueryResult>,
3183) -> Result<QueryResult> {
3184 let (anchor_body, recursive_body, union_all) = match &cte.body {
3185 QueryBody::Compound(comp) if matches!(comp.op, SetOp::Union) => {
3186 (&*comp.left, &*comp.right, comp.all)
3187 }
3188 _ => return Err(SqlError::RecursiveCteNoUnion(cte.name.clone())),
3189 };
3190
3191 let anchor_qr = exec_body(anchor_body, ctx)?;
3192 let columns = if !cte.column_aliases.is_empty() {
3193 if cte.column_aliases.len() != anchor_qr.columns.len() {
3194 return Err(SqlError::CteColumnAliasMismatch {
3195 name: cte.name.clone(),
3196 expected: cte.column_aliases.len(),
3197 got: anchor_qr.columns.len(),
3198 });
3199 }
3200 cte.column_aliases.clone()
3201 } else {
3202 anchor_qr.columns
3203 };
3204
3205 let mut accumulated = anchor_qr.rows;
3206 let mut working_rows = accumulated.clone();
3207 let mut seen = if !union_all {
3208 let mut s = std::collections::HashSet::new();
3209 for row in &accumulated {
3210 s.insert(row.clone());
3211 }
3212 Some(s)
3213 } else {
3214 None
3215 };
3216
3217 let cte_key = cte.name.clone();
3218
3219 let fast_sel = match recursive_body {
3220 QueryBody::Select(sel)
3221 if sel.from.eq_ignore_ascii_case(&cte_key)
3222 && sel.joins.is_empty()
3223 && sel.group_by.is_empty()
3224 && !sel.distinct
3225 && sel.having.is_none()
3226 && sel.limit.is_none()
3227 && sel.offset.is_none()
3228 && sel.order_by.is_empty()
3229 && !stmt_has_subquery(sel) =>
3230 {
3231 Some(sel.as_ref())
3232 }
3233 _ => None,
3234 };
3235
3236 if let Some(sel) = fast_sel {
3237 let cte_cols: Vec<ColumnDef> = columns
3238 .iter()
3239 .enumerate()
3240 .map(|(i, name)| ColumnDef {
3241 name: name.clone(),
3242 data_type: DataType::Null,
3243 nullable: true,
3244 position: i as u16,
3245 default_expr: None,
3246 default_sql: None,
3247 check_expr: None,
3248 check_sql: None,
3249 check_name: None,
3250 })
3251 .collect();
3252 let col_map = ColumnMap::new(&cte_cols);
3253 let ncols = sel.columns.len();
3254
3255 for iteration in 0..MAX_RECURSIVE_ITERATIONS {
3256 if working_rows.is_empty() {
3257 break;
3258 }
3259
3260 let mut step_rows = Vec::with_capacity(working_rows.len());
3261 for row in &working_rows {
3262 if let Some(ref w) = sel.where_clause {
3263 match eval_expr(w, &col_map, row) {
3264 Ok(val) if is_truthy(&val) => {}
3265 Ok(_) => continue,
3266 Err(e) => return Err(e),
3267 }
3268 }
3269 let mut out = Vec::with_capacity(ncols);
3270 for col in &sel.columns {
3271 match col {
3272 SelectColumn::Expr { expr, .. } => {
3273 out.push(eval_expr(expr, &col_map, row)?);
3274 }
3275 SelectColumn::AllColumns => {
3276 out.extend_from_slice(row);
3277 }
3278 }
3279 }
3280 step_rows.push(out);
3281 }
3282
3283 if step_rows.is_empty() {
3284 break;
3285 }
3286
3287 let new_rows = if let Some(ref mut seen_set) = seen {
3288 step_rows
3289 .into_iter()
3290 .filter(|r| seen_set.insert(r.clone()))
3291 .collect::<Vec<_>>()
3292 } else {
3293 step_rows
3294 };
3295
3296 if new_rows.is_empty() {
3297 break;
3298 }
3299
3300 accumulated.extend_from_slice(&new_rows);
3301 working_rows = new_rows;
3302
3303 if iteration == MAX_RECURSIVE_ITERATIONS - 1 {
3304 return Err(SqlError::RecursiveCteMaxIterations(
3305 cte_key.clone(),
3306 MAX_RECURSIVE_ITERATIONS,
3307 ));
3308 }
3309 }
3310 } else {
3311 let mut iter_ctx = ctx.clone();
3312 iter_ctx.insert(
3313 cte_key.clone(),
3314 QueryResult {
3315 columns: columns.clone(),
3316 rows: working_rows,
3317 },
3318 );
3319
3320 for iteration in 0..MAX_RECURSIVE_ITERATIONS {
3321 if iter_ctx.get(&cte_key).unwrap().rows.is_empty() {
3322 break;
3323 }
3324
3325 let iter_qr = exec_body(recursive_body, &iter_ctx)?;
3326 if iter_qr.rows.is_empty() {
3327 break;
3328 }
3329
3330 let new_rows = if let Some(ref mut seen_set) = seen {
3331 iter_qr
3332 .rows
3333 .into_iter()
3334 .filter(|r| seen_set.insert(r.clone()))
3335 .collect::<Vec<_>>()
3336 } else {
3337 iter_qr.rows
3338 };
3339
3340 if new_rows.is_empty() {
3341 break;
3342 }
3343
3344 accumulated.extend_from_slice(&new_rows);
3345 iter_ctx.get_mut(&cte_key).unwrap().rows = new_rows;
3346
3347 if iteration == MAX_RECURSIVE_ITERATIONS - 1 {
3348 return Err(SqlError::RecursiveCteMaxIterations(
3349 cte_key.clone(),
3350 MAX_RECURSIVE_ITERATIONS,
3351 ));
3352 }
3353 }
3354
3355 iter_ctx.remove(&cte_key);
3356 }
3357
3358 Ok(QueryResult {
3359 columns,
3360 rows: accumulated,
3361 })
3362}
3363
3364fn cte_body_references_self(body: &QueryBody, name: &str) -> bool {
3365 match body {
3366 QueryBody::Select(sel) => {
3367 sel.from.eq_ignore_ascii_case(name)
3368 || sel
3369 .joins
3370 .iter()
3371 .any(|j| j.table.name.eq_ignore_ascii_case(name))
3372 }
3373 QueryBody::Compound(comp) => {
3374 cte_body_references_self(&comp.left, name)
3375 || cte_body_references_self(&comp.right, name)
3376 }
3377 }
3378}
3379
3380fn build_cte_schema(name: &str, qr: &QueryResult) -> TableSchema {
3381 let columns: Vec<ColumnDef> = qr
3382 .columns
3383 .iter()
3384 .enumerate()
3385 .map(|(i, col_name)| ColumnDef {
3386 name: col_name.clone(),
3387 data_type: DataType::Null,
3388 nullable: true,
3389 position: i as u16,
3390 default_expr: None,
3391 default_sql: None,
3392 check_expr: None,
3393 check_sql: None,
3394 check_name: None,
3395 })
3396 .collect();
3397 TableSchema::new(name.into(), columns, vec![], vec![], vec![], vec![])
3398}
3399
3400fn exec_select_from_cte(
3401 cte_result: &QueryResult,
3402 stmt: &SelectStmt,
3403 exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
3404) -> Result<ExecutionResult> {
3405 let cte_schema = build_cte_schema(&stmt.from, cte_result);
3406 let actual_stmt;
3407 let s = if stmt_has_subquery(stmt) {
3408 actual_stmt = materialize_stmt(stmt, exec_sub)?;
3409 &actual_stmt
3410 } else {
3411 stmt
3412 };
3413
3414 let has_aggregates = s.columns.iter().any(|c| match c {
3415 SelectColumn::Expr { expr, .. } => is_aggregate_expr(expr),
3416 _ => false,
3417 });
3418
3419 if has_aggregates || !s.group_by.is_empty() {
3420 if let Some(ref where_expr) = s.where_clause {
3421 let col_map = ColumnMap::new(&cte_schema.columns);
3422 let filtered: Vec<Vec<Value>> = cte_result
3423 .rows
3424 .iter()
3425 .filter(|row| match eval_expr(where_expr, &col_map, row) {
3426 Ok(val) => is_truthy(&val),
3427 _ => false,
3428 })
3429 .cloned()
3430 .collect();
3431 return exec_aggregate(&cte_schema.columns, &filtered, s);
3432 }
3433 return exec_aggregate(&cte_schema.columns, &cte_result.rows, s);
3434 }
3435
3436 process_select(&cte_schema.columns, cte_result.rows.clone(), s, false)
3437}
3438
/// Executes a `SELECT` with JOINs in which each relation is either a materialized CTE from
/// `ctes` or a base table loaded through `scan_table`.
///
/// All relations are loaded up front. The joins are then applied left to right, each step
/// combining the rows accumulated so far with the next relation via `exec_join_step`. The
/// final joined rows and combined column list are handed to `process_select`.
///
/// # Errors
/// Propagates errors from resolving or scanning any relation and from `process_select`.
fn exec_select_join_with_ctes(
    stmt: &SelectStmt,
    ctes: &CteContext,
    scan_table: ScanTableFn<'_>,
) -> Result<ExecutionResult> {
    let (from_schema, from_rows) = resolve_table_or_cte(&stmt.from, ctes, scan_table)?;
    let from_alias = table_alias_or_name(&stmt.from, &stmt.from_alias);

    // (alias, schema) for every relation in FROM/JOIN order. Rows are kept separately so the
    // schemas can be borrowed while the rows are consumed.
    let mut tables: Vec<(String, TableSchema)> = vec![(from_alias.clone(), from_schema)];
    let mut join_rows: Vec<Vec<Vec<Value>>> = Vec::new();

    for join in &stmt.joins {
        let jname = &join.table.name;
        let (js, jrows) = resolve_table_or_cte(jname, ctes, scan_table)?;
        let jalias = table_alias_or_name(jname, &join.table.alias);
        tables.push((jalias, js));
        join_rows.push(jrows);
    }

    // Rows produced so far and the relations they span.
    let mut outer_rows = from_rows;
    let mut cur_tables: Vec<(String, &TableSchema)> = vec![(from_alias.clone(), &tables[0].1)];

    for (ji, join) in stmt.joins.iter().enumerate() {
        let inner_schema = &tables[ji + 1].1;
        let inner_alias = &tables[ji + 1].0;
        let inner_rows = &join_rows[ji];

        // Column layout of the joined result after this step.
        let mut preview_tables = cur_tables.clone();
        preview_tables.push((inner_alias.clone(), inner_schema));
        let combined_cols = build_joined_columns(&preview_tables);

        // Width of the accumulated rows; derived from the schemas when there is no row to
        // measure.
        let outer_col_count = if outer_rows.is_empty() {
            cur_tables.iter().map(|(_, s)| s.columns.len()).sum()
        } else {
            outer_rows[0].len()
        };
        let inner_col_count = inner_schema.columns.len();

        // NOTE(review): the two trailing `None`s are optional `exec_join_step` inputs not
        // used by this path; confirm their meaning at its definition.
        outer_rows = exec_join_step(
            outer_rows,
            inner_rows,
            join,
            &combined_cols,
            outer_col_count,
            inner_col_count,
            None,
            None,
        );
        cur_tables.push((inner_alias.clone(), inner_schema));
    }

    let joined_cols = build_joined_columns(&cur_tables);
    process_select(&joined_cols, outer_rows, stmt, false)
}
3493
3494fn resolve_table_or_cte(
3495 name: &str,
3496 ctes: &CteContext,
3497 scan_table: ScanTableFn<'_>,
3498) -> Result<(TableSchema, Vec<Vec<Value>>)> {
3499 let lower = name.to_ascii_lowercase();
3500 if let Some(cte_result) = ctes.get(&lower) {
3501 let schema = build_cte_schema(&lower, cte_result);
3502 Ok((schema, cte_result.rows.clone()))
3503 } else {
3504 scan_table(&lower)
3505 }
3506}
3507
3508fn scan_table_read(
3509 db: &Database,
3510 schema: &SchemaManager,
3511 name: &str,
3512) -> Result<(TableSchema, Vec<Vec<Value>>)> {
3513 let table_schema = schema
3514 .get(name)
3515 .ok_or_else(|| SqlError::TableNotFound(name.to_string()))?;
3516 let (rows, _) = collect_rows_read(db, table_schema, &None, None)?;
3517 Ok((table_schema.clone(), rows))
3518}
3519
3520fn scan_table_write(
3521 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
3522 schema: &SchemaManager,
3523 name: &str,
3524) -> Result<(TableSchema, Vec<Vec<Value>>)> {
3525 let table_schema = schema
3526 .get(name)
3527 .ok_or_else(|| SqlError::TableNotFound(name.to_string()))?;
3528 let (rows, _) = collect_rows_write(wtx, table_schema, &None, None)?;
3529 Ok((table_schema.clone(), rows))
3530}
3531
3532fn apply_set_operation(
3533 comp: &CompoundSelect,
3534 left_qr: QueryResult,
3535 right_qr: QueryResult,
3536) -> Result<ExecutionResult> {
3537 if !left_qr.columns.is_empty()
3538 && !right_qr.columns.is_empty()
3539 && left_qr.columns.len() != right_qr.columns.len()
3540 {
3541 return Err(SqlError::CompoundColumnCountMismatch {
3542 left: left_qr.columns.len(),
3543 right: right_qr.columns.len(),
3544 });
3545 }
3546
3547 let columns = left_qr.columns;
3548
3549 let mut rows = match (&comp.op, comp.all) {
3550 (SetOp::Union, true) => {
3551 let mut rows = left_qr.rows;
3552 rows.extend(right_qr.rows);
3553 rows
3554 }
3555 (SetOp::Union, false) => {
3556 let mut seen = std::collections::HashSet::new();
3557 let mut rows = Vec::new();
3558 for row in left_qr.rows.into_iter().chain(right_qr.rows) {
3559 if seen.insert(row.clone()) {
3560 rows.push(row);
3561 }
3562 }
3563 rows
3564 }
3565 (SetOp::Intersect, true) => {
3566 let mut right_counts: std::collections::HashMap<Vec<Value>, usize> =
3567 std::collections::HashMap::new();
3568 for row in &right_qr.rows {
3569 *right_counts.entry(row.clone()).or_insert(0) += 1;
3570 }
3571 let mut rows = Vec::new();
3572 for row in left_qr.rows {
3573 if let Some(count) = right_counts.get_mut(&row) {
3574 if *count > 0 {
3575 *count -= 1;
3576 rows.push(row);
3577 }
3578 }
3579 }
3580 rows
3581 }
3582 (SetOp::Intersect, false) => {
3583 let right_set: std::collections::HashSet<Vec<Value>> =
3584 right_qr.rows.into_iter().collect();
3585 let mut seen = std::collections::HashSet::new();
3586 let mut rows = Vec::new();
3587 for row in left_qr.rows {
3588 if right_set.contains(&row) && seen.insert(row.clone()) {
3589 rows.push(row);
3590 }
3591 }
3592 rows
3593 }
3594 (SetOp::Except, true) => {
3595 let mut right_counts: std::collections::HashMap<Vec<Value>, usize> =
3596 std::collections::HashMap::new();
3597 for row in &right_qr.rows {
3598 *right_counts.entry(row.clone()).or_insert(0) += 1;
3599 }
3600 let mut rows = Vec::new();
3601 for row in left_qr.rows {
3602 if let Some(count) = right_counts.get_mut(&row) {
3603 if *count > 0 {
3604 *count -= 1;
3605 continue;
3606 }
3607 }
3608 rows.push(row);
3609 }
3610 rows
3611 }
3612 (SetOp::Except, false) => {
3613 let right_set: std::collections::HashSet<Vec<Value>> =
3614 right_qr.rows.into_iter().collect();
3615 let mut seen = std::collections::HashSet::new();
3616 let mut rows = Vec::new();
3617 for row in left_qr.rows {
3618 if !right_set.contains(&row) && seen.insert(row.clone()) {
3619 rows.push(row);
3620 }
3621 }
3622 rows
3623 }
3624 };
3625
3626 if !comp.order_by.is_empty() {
3627 let col_defs: Vec<crate::types::ColumnDef> = columns
3628 .iter()
3629 .enumerate()
3630 .map(|(i, name)| crate::types::ColumnDef {
3631 name: name.clone(),
3632 data_type: crate::types::DataType::Null,
3633 nullable: true,
3634 position: i as u16,
3635 default_expr: None,
3636 default_sql: None,
3637 check_expr: None,
3638 check_sql: None,
3639 check_name: None,
3640 })
3641 .collect();
3642 sort_rows(&mut rows, &comp.order_by, &col_defs)?;
3643 }
3644
3645 if let Some(ref offset_expr) = comp.offset {
3646 let offset = eval_const_int(offset_expr)?.max(0) as usize;
3647 if offset < rows.len() {
3648 rows = rows.split_off(offset);
3649 } else {
3650 rows.clear();
3651 }
3652 }
3653
3654 if let Some(ref limit_expr) = comp.limit {
3655 let limit = eval_const_int(limit_expr)?.max(0) as usize;
3656 rows.truncate(limit);
3657 }
3658
3659 Ok(ExecutionResult::Query(QueryResult { columns, rows }))
3660}
3661
/// Executes a SELECT statement in a read-only context.
///
/// Subqueries are first materialized through `exec_subquery_read`. The statement is then
/// routed to the first strategy that applies, in this order:
/// 1. no FROM clause (`exec_select_no_from`);
/// 2. FROM names a CTE (`exec_select_from_cte`, or `exec_select_join_with_ctes` with joins);
/// 3. a regular table joined with at least one CTE (`exec_select_join_with_ctes`);
/// 4. ordinary joins (`exec_select_join`);
/// 5. the `COUNT(*)` entry-count shortcut;
/// 6. the streaming aggregate, streaming GROUP BY and top-K scan plans;
/// 7. a row scan (bounded by `compute_scan_limit` when possible) fed to `process_select`.
///
/// Fails with `SqlError::TableNotFound` when FROM names neither a CTE nor a known table.
fn exec_select(
    db: &Database,
    schema: &SchemaManager,
    stmt: &SelectStmt,
    ctes: &CteContext,
) -> Result<ExecutionResult> {
    // Replace subqueries with their results; `materialized` owns the rewritten statement.
    let materialized;
    let stmt = if stmt_has_subquery(stmt) {
        materialized =
            materialize_stmt(stmt, &mut |sub| exec_subquery_read(db, schema, sub, ctes))?;
        &materialized
    } else {
        stmt
    };

    // No FROM clause.
    if stmt.from.is_empty() {
        return exec_select_no_from(stmt);
    }

    let lower_name = stmt.from.to_ascii_lowercase();

    // FROM names a CTE: serve it from the CTE's already-computed rows.
    if let Some(cte_result) = ctes.get(&lower_name) {
        if stmt.joins.is_empty() {
            return exec_select_from_cte(cte_result, stmt, &mut |sub| {
                exec_subquery_read(db, schema, sub, ctes)
            });
        } else {
            return exec_select_join_with_ctes(stmt, ctes, &mut |name| {
                scan_table_read(db, schema, name)
            });
        }
    }

    // FROM is a regular table but at least one joined relation is a CTE.
    if !ctes.is_empty()
        && stmt
            .joins
            .iter()
            .any(|j| ctes.contains_key(&j.table.name.to_ascii_lowercase()))
    {
        return exec_select_join_with_ctes(stmt, ctes, &mut |name| {
            scan_table_read(db, schema, name)
        });
    }

    let table_schema = schema
        .get(&lower_name)
        .ok_or_else(|| SqlError::TableNotFound(stmt.from.clone()))?;

    if !stmt.joins.is_empty() {
        return exec_select_join(db, schema, stmt);
    }

    // Answer a bare COUNT(*) from the stored entry count instead of scanning rows.
    if let Some(result) = try_count_star_shortcut(stmt, || {
        let mut rtx = db.begin_read();
        rtx.table_entry_count(lower_name.as_bytes())
            .map_err(SqlError::Storage)
    })? {
        return Ok(result);
    }

    // Single-row aggregates computed in one pass over the raw table.
    if let Some(plan) = StreamAggPlan::try_new(stmt, table_schema)? {
        let mut states: Vec<AggState> = plan.ops.iter().map(|(op, _)| AggState::new(op)).collect();
        // The scan callbacks return `false` after storing an error here.
        let mut scan_err: Option<SqlError> = None;
        let mut rtx = db.begin_read();
        if stmt.where_clause.is_none() {
            // No filter: aggregate straight from raw key/value bytes.
            rtx.table_scan_raw(lower_name.as_bytes(), |key, value| {
                plan.feed_row_raw(key, value, &mut states, &mut scan_err)
            })
            .map_err(SqlError::Storage)?;
        } else {
            // A WHERE clause needs decoded rows to evaluate the predicate.
            let col_map = ColumnMap::new(&table_schema.columns);
            rtx.table_scan_raw(lower_name.as_bytes(), |key, value| {
                plan.feed_row(
                    key,
                    value,
                    table_schema,
                    &col_map,
                    &stmt.where_clause,
                    &mut states,
                    &mut scan_err,
                )
            })
            .map_err(SqlError::Storage)?;
        }
        if let Some(e) = scan_err {
            return Err(e);
        }
        return Ok(plan.finish(states));
    }

    // Single-column GROUP BY aggregated in one raw pass.
    if let Some(plan) = StreamGroupByPlan::try_new(stmt, table_schema)? {
        let lower = lower_name.clone();
        let mut rtx = db.begin_read();
        return plan
            .execute_scan(|cb| rtx.table_scan_raw(lower.as_bytes(), |key, value| cb(key, value)));
    }

    // ORDER BY <one column> LIMIT n over the raw table.
    if let Some(plan) = TopKScanPlan::try_new(stmt, table_schema)? {
        let lower = lower_name.clone();
        let mut rtx = db.begin_read();
        return plan.execute_scan(table_schema, stmt, |cb| {
            rtx.table_scan_raw(lower.as_bytes(), |key, value| cb(key, value))
        });
    }

    // General path; `scan_limit` lets the scan stop early when LIMIT/OFFSET allow it.
    let scan_limit = compute_scan_limit(stmt);
    let (rows, predicate_applied) =
        collect_rows_read(db, table_schema, &stmt.where_clause, scan_limit)?;
    process_select(&table_schema.columns, rows, stmt, predicate_applied)
}
3772
3773fn compute_scan_limit(stmt: &SelectStmt) -> Option<usize> {
3774 if !stmt.order_by.is_empty()
3775 || !stmt.group_by.is_empty()
3776 || stmt.distinct
3777 || stmt.having.is_some()
3778 {
3779 return None;
3780 }
3781 let has_aggregates = stmt.columns.iter().any(|c| match c {
3782 SelectColumn::Expr { expr, .. } => is_aggregate_expr(expr),
3783 _ => false,
3784 });
3785 if has_aggregates {
3786 return None;
3787 }
3788 let limit = stmt.limit.as_ref()?;
3789 let limit_val = eval_const_int(limit).ok()?.max(0) as usize;
3790 let offset_val = stmt
3791 .offset
3792 .as_ref()
3793 .and_then(|e| eval_const_int(e).ok())
3794 .unwrap_or(0)
3795 .max(0) as usize;
3796 Some(limit_val.saturating_add(offset_val))
3797}
3798
3799fn try_count_star_shortcut(
3800 stmt: &SelectStmt,
3801 get_count: impl FnOnce() -> Result<u64>,
3802) -> Result<Option<ExecutionResult>> {
3803 if stmt.columns.len() != 1
3804 || stmt.where_clause.is_some()
3805 || !stmt.group_by.is_empty()
3806 || stmt.having.is_some()
3807 {
3808 return Ok(None);
3809 }
3810 let col = match &stmt.columns[0] {
3811 SelectColumn::Expr { expr, alias } => (expr, alias),
3812 _ => return Ok(None),
3813 };
3814 if !matches!(col.0, Expr::CountStar) {
3815 return Ok(None);
3816 }
3817 let count = get_count()? as i64;
3818 let col_name = col.1.as_deref().unwrap_or("COUNT(*)").to_string();
3819 Ok(Some(ExecutionResult::Query(QueryResult {
3820 columns: vec![col_name],
3821 rows: vec![vec![Value::Integer(count)]],
3822 })))
3823}
3824
/// One aggregate evaluated by the streaming plans. The `usize` payload is the schema index
/// of the aggregated column (as resolved through `ColumnMap` on the table's columns).
enum StreamAgg {
    /// `COUNT(*)`: counts every row.
    CountStar,
    /// `COUNT(col)`: counts non-NULL values.
    Count(usize),
    /// `SUM(col)`.
    Sum(usize),
    /// `AVG(col)`.
    Avg(usize),
    /// `MIN(col)`.
    Min(usize),
    /// `MAX(col)`.
    Max(usize),
}
3833
/// Where a streaming plan reads a value (an aggregate input, a GROUP BY key or a sort key)
/// from a raw key/value row.
enum RawAggTarget {
    /// `COUNT(*)`: no column is read.
    CountStar,
    /// Position among the primary-key columns; decoded from the row key.
    Pk(usize),
    /// Encoding position within the stored non-PK value, as passed to `decode_column_raw`.
    NonPk(usize),
}
3839
/// Running accumulator for one `StreamAgg`.
enum AggState {
    /// Number of rows fed so far.
    CountStar(i64),
    /// Number of non-NULL values fed so far.
    Count(i64),
    /// SUM keeps integer and real contributions apart so that an all-integer input yields
    /// an integer. `all_null` stays true until a non-NULL value arrives; the result is then
    /// NULL.
    Sum {
        int_sum: i64,
        real_sum: f64,
        has_real: bool,
        all_null: bool,
    },
    /// Running total and number of non-NULL values for AVG.
    Avg {
        sum: f64,
        count: i64,
    },
    /// Smallest non-NULL value seen so far (`None` until one arrives).
    Min(Option<Value>),
    /// Largest non-NULL value seen so far (`None` until one arrives).
    Max(Option<Value>),
}
3856
3857impl AggState {
3858 fn new(op: &StreamAgg) -> Self {
3859 match op {
3860 StreamAgg::CountStar => AggState::CountStar(0),
3861 StreamAgg::Count(_) => AggState::Count(0),
3862 StreamAgg::Sum(_) => AggState::Sum {
3863 int_sum: 0,
3864 real_sum: 0.0,
3865 has_real: false,
3866 all_null: true,
3867 },
3868 StreamAgg::Avg(_) => AggState::Avg { sum: 0.0, count: 0 },
3869 StreamAgg::Min(_) => AggState::Min(None),
3870 StreamAgg::Max(_) => AggState::Max(None),
3871 }
3872 }
3873
3874 fn feed_val(&mut self, val: &Value) -> Result<()> {
3875 match self {
3876 AggState::CountStar(c) => {
3877 *c += 1;
3878 }
3879 AggState::Count(c) => {
3880 if !val.is_null() {
3881 *c += 1;
3882 }
3883 }
3884 AggState::Sum {
3885 int_sum,
3886 real_sum,
3887 has_real,
3888 all_null,
3889 } => match val {
3890 Value::Integer(i) => {
3891 *int_sum += i;
3892 *all_null = false;
3893 }
3894 Value::Real(r) => {
3895 *real_sum += r;
3896 *has_real = true;
3897 *all_null = false;
3898 }
3899 Value::Null => {}
3900 _ => {
3901 return Err(SqlError::TypeMismatch {
3902 expected: "numeric".into(),
3903 got: val.data_type().to_string(),
3904 })
3905 }
3906 },
3907 AggState::Avg { sum, count } => match val {
3908 Value::Integer(i) => {
3909 *sum += *i as f64;
3910 *count += 1;
3911 }
3912 Value::Real(r) => {
3913 *sum += r;
3914 *count += 1;
3915 }
3916 Value::Null => {}
3917 _ => {
3918 return Err(SqlError::TypeMismatch {
3919 expected: "numeric".into(),
3920 got: val.data_type().to_string(),
3921 })
3922 }
3923 },
3924 AggState::Min(cur) => {
3925 if !val.is_null() {
3926 *cur = Some(match cur.take() {
3927 None => val.clone(),
3928 Some(m) => {
3929 if val < &m {
3930 val.clone()
3931 } else {
3932 m
3933 }
3934 }
3935 });
3936 }
3937 }
3938 AggState::Max(cur) => {
3939 if !val.is_null() {
3940 *cur = Some(match cur.take() {
3941 None => val.clone(),
3942 Some(m) => {
3943 if val > &m {
3944 val.clone()
3945 } else {
3946 m
3947 }
3948 }
3949 });
3950 }
3951 }
3952 }
3953 Ok(())
3954 }
3955
3956 fn feed_raw(&mut self, raw: &RawColumn) -> Result<()> {
3957 match self {
3958 AggState::CountStar(c) => {
3959 *c += 1;
3960 }
3961 AggState::Count(c) => {
3962 if !matches!(raw, RawColumn::Null) {
3963 *c += 1;
3964 }
3965 }
3966 AggState::Sum {
3967 int_sum,
3968 real_sum,
3969 has_real,
3970 all_null,
3971 } => match raw {
3972 RawColumn::Integer(i) => {
3973 *int_sum += i;
3974 *all_null = false;
3975 }
3976 RawColumn::Real(r) => {
3977 *real_sum += r;
3978 *has_real = true;
3979 *all_null = false;
3980 }
3981 RawColumn::Null => {}
3982 _ => {
3983 return Err(SqlError::TypeMismatch {
3984 expected: "numeric".into(),
3985 got: "non-numeric".into(),
3986 })
3987 }
3988 },
3989 AggState::Avg { sum, count } => match raw {
3990 RawColumn::Integer(i) => {
3991 *sum += *i as f64;
3992 *count += 1;
3993 }
3994 RawColumn::Real(r) => {
3995 *sum += r;
3996 *count += 1;
3997 }
3998 RawColumn::Null => {}
3999 _ => {
4000 return Err(SqlError::TypeMismatch {
4001 expected: "numeric".into(),
4002 got: "non-numeric".into(),
4003 })
4004 }
4005 },
4006 AggState::Min(cur) => {
4007 if !matches!(raw, RawColumn::Null) {
4008 let val = raw.to_value();
4009 *cur = Some(match cur.take() {
4010 None => val,
4011 Some(m) => {
4012 if val < m {
4013 val
4014 } else {
4015 m
4016 }
4017 }
4018 });
4019 }
4020 }
4021 AggState::Max(cur) => {
4022 if !matches!(raw, RawColumn::Null) {
4023 let val = raw.to_value();
4024 *cur = Some(match cur.take() {
4025 None => val,
4026 Some(m) => {
4027 if val > m {
4028 val
4029 } else {
4030 m
4031 }
4032 }
4033 });
4034 }
4035 }
4036 }
4037 Ok(())
4038 }
4039
4040 fn finish(self) -> Value {
4041 match self {
4042 AggState::CountStar(c) | AggState::Count(c) => Value::Integer(c),
4043 AggState::Sum {
4044 int_sum,
4045 real_sum,
4046 has_real,
4047 all_null,
4048 } => {
4049 if all_null {
4050 Value::Null
4051 } else if has_real {
4052 Value::Real(real_sum + int_sum as f64)
4053 } else {
4054 Value::Integer(int_sum)
4055 }
4056 }
4057 AggState::Avg { sum, count } => {
4058 if count == 0 {
4059 Value::Null
4060 } else {
4061 Value::Real(sum / count as f64)
4062 }
4063 }
4064 AggState::Min(v) | AggState::Max(v) => v.unwrap_or(Value::Null),
4065 }
4066 }
4067}
4068
/// Single-pass plan for an ungrouped SELECT whose items are all simple aggregates over one
/// table; it produces exactly one result row.
struct StreamAggPlan {
    /// Each aggregate together with its output column name.
    ops: Vec<(StreamAgg, String)>,
    /// Decoder restricted to the needed columns; `None` when every column is needed.
    partial_ctx: Option<PartialDecodeCtx>,
    /// Raw-row location of each aggregate's input, parallel to `ops`.
    raw_targets: Vec<RawAggTarget>,
    /// Number of primary-key columns in the table.
    num_pk_cols: usize,
    /// Per-aggregate constant DEFAULT, parallel to `raw_targets`. `feed_row_raw` uses it
    /// when a stored row's `row_non_pk_count` does not reach the target position
    /// (presumably rows written before the column was added; confirm).
    nonpk_agg_defaults: Vec<Option<Value>>,
}
4076
4077impl StreamAggPlan {
4078 fn try_new(stmt: &SelectStmt, table_schema: &TableSchema) -> Result<Option<Self>> {
4079 if !stmt.group_by.is_empty() || stmt.having.is_some() || !stmt.joins.is_empty() {
4080 return Ok(None);
4081 }
4082
4083 let col_map = ColumnMap::new(&table_schema.columns);
4084 let mut ops: Vec<(StreamAgg, String)> = Vec::new();
4085 for sel_col in &stmt.columns {
4086 let (expr, alias) = match sel_col {
4087 SelectColumn::Expr { expr, alias } => (expr, alias),
4088 _ => return Ok(None),
4089 };
4090 let name = alias
4091 .as_deref()
4092 .unwrap_or(&expr_display_name(expr))
4093 .to_string();
4094 match expr {
4095 Expr::CountStar => ops.push((StreamAgg::CountStar, name)),
4096 Expr::Function {
4097 name: func_name,
4098 args,
4099 } if args.len() == 1 => {
4100 let func = func_name.to_ascii_uppercase();
4101 let col_idx = match resolve_simple_col(&args[0], &col_map) {
4102 Some(idx) => idx,
4103 None => return Ok(None),
4104 };
4105 match func.as_str() {
4106 "COUNT" => ops.push((StreamAgg::Count(col_idx), name)),
4107 "SUM" => ops.push((StreamAgg::Sum(col_idx), name)),
4108 "AVG" => ops.push((StreamAgg::Avg(col_idx), name)),
4109 "MIN" => ops.push((StreamAgg::Min(col_idx), name)),
4110 "MAX" => ops.push((StreamAgg::Max(col_idx), name)),
4111 _ => return Ok(None),
4112 }
4113 }
4114 _ => return Ok(None),
4115 }
4116 }
4117
4118 let mut needed: Vec<usize> = ops
4119 .iter()
4120 .filter_map(|(op, _)| match op {
4121 StreamAgg::CountStar => None,
4122 StreamAgg::Count(i)
4123 | StreamAgg::Sum(i)
4124 | StreamAgg::Avg(i)
4125 | StreamAgg::Min(i)
4126 | StreamAgg::Max(i) => Some(*i),
4127 })
4128 .collect();
4129 if let Some(ref where_expr) = stmt.where_clause {
4130 needed.extend(referenced_columns(where_expr, &table_schema.columns));
4131 }
4132 needed.sort_unstable();
4133 needed.dedup();
4134
4135 let partial_ctx = if needed.len() < table_schema.columns.len() {
4136 Some(PartialDecodeCtx::new(table_schema, &needed))
4137 } else {
4138 None
4139 };
4140
4141 let non_pk = table_schema.non_pk_indices();
4142 let enc_pos = table_schema.encoding_positions();
4143 let raw_targets: Vec<RawAggTarget> = ops
4144 .iter()
4145 .map(|(op, _)| match op {
4146 StreamAgg::CountStar => RawAggTarget::CountStar,
4147 StreamAgg::Count(idx)
4148 | StreamAgg::Sum(idx)
4149 | StreamAgg::Avg(idx)
4150 | StreamAgg::Min(idx)
4151 | StreamAgg::Max(idx) => {
4152 if let Some(pk_pos) = table_schema
4153 .primary_key_columns
4154 .iter()
4155 .position(|&i| i as usize == *idx)
4156 {
4157 RawAggTarget::Pk(pk_pos)
4158 } else {
4159 let nonpk_order = non_pk.iter().position(|&i| i == *idx).unwrap();
4160 RawAggTarget::NonPk(enc_pos[nonpk_order] as usize)
4161 }
4162 }
4163 })
4164 .collect();
4165
4166 let num_pk_cols = table_schema.primary_key_columns.len();
4167
4168 let mapping = table_schema.decode_col_mapping();
4169 let nonpk_agg_defaults: Vec<Option<Value>> = raw_targets
4170 .iter()
4171 .map(|t| match t {
4172 RawAggTarget::NonPk(phys_idx) => {
4173 let schema_col = mapping[*phys_idx];
4174 if schema_col == usize::MAX {
4175 return None;
4176 }
4177 table_schema.columns[schema_col]
4178 .default_expr
4179 .as_ref()
4180 .and_then(|expr| eval_const_expr(expr).ok())
4181 }
4182 _ => None,
4183 })
4184 .collect();
4185
4186 Ok(Some(Self {
4187 ops,
4188 partial_ctx,
4189 raw_targets,
4190 num_pk_cols,
4191 nonpk_agg_defaults,
4192 }))
4193 }
4194
4195 #[allow(clippy::too_many_arguments)]
4196 fn feed_row(
4197 &self,
4198 key: &[u8],
4199 value: &[u8],
4200 table_schema: &TableSchema,
4201 col_map: &ColumnMap,
4202 where_clause: &Option<Expr>,
4203 states: &mut [AggState],
4204 scan_err: &mut Option<SqlError>,
4205 ) -> bool {
4206 let row = match &self.partial_ctx {
4207 Some(ctx) => match ctx.decode(key, value) {
4208 Ok(r) => r,
4209 Err(e) => {
4210 *scan_err = Some(e);
4211 return false;
4212 }
4213 },
4214 None => match decode_full_row(table_schema, key, value) {
4215 Ok(r) => r,
4216 Err(e) => {
4217 *scan_err = Some(e);
4218 return false;
4219 }
4220 },
4221 };
4222
4223 if let Some(expr) = where_clause {
4224 match eval_expr(expr, col_map, &row) {
4225 Ok(val) if !is_truthy(&val) => return true,
4226 Err(e) => {
4227 *scan_err = Some(e);
4228 return false;
4229 }
4230 _ => {}
4231 }
4232 }
4233
4234 for (i, (op, _)) in self.ops.iter().enumerate() {
4235 let val = match op {
4236 StreamAgg::CountStar => &Value::Null,
4237 StreamAgg::Count(idx)
4238 | StreamAgg::Sum(idx)
4239 | StreamAgg::Avg(idx)
4240 | StreamAgg::Min(idx)
4241 | StreamAgg::Max(idx) => &row[*idx],
4242 };
4243 if let Err(e) = states[i].feed_val(val) {
4244 *scan_err = Some(e);
4245 return false;
4246 }
4247 }
4248 true
4249 }
4250
4251 fn feed_row_raw(
4252 &self,
4253 key: &[u8],
4254 value: &[u8],
4255 states: &mut [AggState],
4256 scan_err: &mut Option<SqlError>,
4257 ) -> bool {
4258 for (i, target) in self.raw_targets.iter().enumerate() {
4259 let raw = match target {
4260 RawAggTarget::CountStar => {
4261 if let Err(e) = states[i].feed_raw(&RawColumn::Null) {
4262 *scan_err = Some(e);
4263 return false;
4264 }
4265 continue;
4266 }
4267 RawAggTarget::Pk(pk_pos) => {
4268 if self.num_pk_cols == 1 && *pk_pos == 0 {
4269 match decode_pk_integer(key) {
4270 Ok(v) => RawColumn::Integer(v),
4271 Err(e) => {
4272 *scan_err = Some(e);
4273 return false;
4274 }
4275 }
4276 } else {
4277 match decode_composite_key(key, self.num_pk_cols) {
4278 Ok(pk) => RawColumn::Integer(match &pk[*pk_pos] {
4279 Value::Integer(i) => *i,
4280 _ => {
4281 *scan_err =
4282 Some(SqlError::InvalidValue("PK not integer".into()));
4283 return false;
4284 }
4285 }),
4286 Err(e) => {
4287 *scan_err = Some(e);
4288 return false;
4289 }
4290 }
4291 }
4292 }
4293 RawAggTarget::NonPk(idx) => {
4294 let stored = row_non_pk_count(value);
4295 if *idx >= stored {
4296 if let Some(ref default) = self.nonpk_agg_defaults[i] {
4297 if let Err(e) = states[i].feed_val(default) {
4298 *scan_err = Some(e);
4299 return false;
4300 }
4301 } else if let Err(e) = states[i].feed_raw(&RawColumn::Null) {
4302 *scan_err = Some(e);
4303 return false;
4304 }
4305 continue;
4306 }
4307 match decode_column_raw(value, *idx) {
4308 Ok(v) => v,
4309 Err(e) => {
4310 *scan_err = Some(e);
4311 return false;
4312 }
4313 }
4314 }
4315 };
4316 if let Err(e) = states[i].feed_raw(&raw) {
4317 *scan_err = Some(e);
4318 return false;
4319 }
4320 }
4321 true
4322 }
4323
4324 fn finish(self, states: Vec<AggState>) -> ExecutionResult {
4325 let col_names: Vec<String> = self.ops.iter().map(|(_, name)| name.clone()).collect();
4326 let result_row: Vec<Value> = states.into_iter().map(|s| s.finish()).collect();
4327 ExecutionResult::Query(QueryResult {
4328 columns: col_names,
4329 rows: vec![result_row],
4330 })
4331 }
4332}
4333
4334fn resolve_simple_col(expr: &Expr, col_map: &ColumnMap) -> Option<usize> {
4335 match expr {
4336 Expr::Column(name) => col_map.resolve(name).ok(),
4337 Expr::QualifiedColumn { table, column } => col_map.resolve_qualified(table, column).ok(),
4338 _ => None,
4339 }
4340}
4341
/// What each output column of a `StreamGroupByPlan` holds.
enum GroupByOutputCol {
    /// The GROUP BY key itself.
    GroupKey,
    /// The finished value of the aggregate at this index in the plan's aggregate list.
    Agg(usize),
}
4346
4347struct StreamGroupByPlan {
4348 group_target: RawAggTarget,
4349 num_pk_cols: usize,
4350 agg_ops: Vec<StreamAgg>,
4351 raw_targets: Vec<RawAggTarget>,
4352 output: Vec<(GroupByOutputCol, String)>,
4353 where_pred: Option<SimplePredicate>,
4354}
4355
4356impl StreamGroupByPlan {
4357 fn try_new(stmt: &SelectStmt, schema: &TableSchema) -> Result<Option<Self>> {
4358 if stmt.group_by.len() != 1
4359 || stmt.having.is_some()
4360 || !stmt.joins.is_empty()
4361 || !stmt.order_by.is_empty()
4362 || stmt.limit.is_some()
4363 {
4364 return Ok(None);
4365 }
4366
4367 let where_pred = stmt
4368 .where_clause
4369 .as_ref()
4370 .map(|expr| try_simple_predicate(expr, schema));
4371 if stmt.where_clause.is_some() && where_pred.as_ref().unwrap().is_none() {
4373 return Ok(None);
4374 }
4375 let where_pred = where_pred.flatten();
4376
4377 let col_map = ColumnMap::new(&schema.columns);
4378
4379 let group_col_idx = match &stmt.group_by[0] {
4380 Expr::Column(name) => col_map.resolve(name).ok(),
4381 _ => None,
4382 };
4383 let group_col_idx = match group_col_idx {
4384 Some(idx) => idx,
4385 None => return Ok(None),
4386 };
4387
4388 if schema.columns[group_col_idx].data_type != DataType::Integer {
4389 return Ok(None);
4390 }
4391
4392 let non_pk = schema.non_pk_indices();
4393 let enc_pos = schema.encoding_positions();
4394 let group_target = if let Some(pk_pos) = schema
4395 .primary_key_columns
4396 .iter()
4397 .position(|&i| i as usize == group_col_idx)
4398 {
4399 RawAggTarget::Pk(pk_pos)
4400 } else {
4401 let nonpk_order = non_pk.iter().position(|&i| i == group_col_idx).unwrap();
4402 RawAggTarget::NonPk(enc_pos[nonpk_order] as usize)
4403 };
4404
4405 let mut agg_ops = Vec::new();
4406 let mut raw_targets = Vec::new();
4407 let mut output = Vec::new();
4408
4409 for sel_col in &stmt.columns {
4410 let (expr, alias) = match sel_col {
4411 SelectColumn::Expr { expr, alias } => (expr, alias),
4412 _ => return Ok(None),
4413 };
4414 let name = alias
4415 .as_deref()
4416 .unwrap_or(&expr_display_name(expr))
4417 .to_string();
4418
4419 if let Some(idx) = resolve_simple_col(expr, &col_map) {
4420 if idx == group_col_idx {
4421 output.push((GroupByOutputCol::GroupKey, name));
4422 continue;
4423 }
4424 }
4425
4426 match expr {
4427 Expr::CountStar => {
4428 let agg_idx = agg_ops.len();
4429 agg_ops.push(StreamAgg::CountStar);
4430 raw_targets.push(RawAggTarget::CountStar);
4431 output.push((GroupByOutputCol::Agg(agg_idx), name));
4432 }
4433 Expr::Function {
4434 name: func_name,
4435 args,
4436 } if args.len() == 1 => {
4437 let func = func_name.to_ascii_uppercase();
4438 let col_idx = match resolve_simple_col(&args[0], &col_map) {
4439 Some(idx) => idx,
4440 None => return Ok(None),
4441 };
4442 let target = if let Some(pk_pos) = schema
4443 .primary_key_columns
4444 .iter()
4445 .position(|&i| i as usize == col_idx)
4446 {
4447 RawAggTarget::Pk(pk_pos)
4448 } else {
4449 let nonpk_order = non_pk.iter().position(|&i| i == col_idx).unwrap();
4450 RawAggTarget::NonPk(enc_pos[nonpk_order] as usize)
4451 };
4452 let agg_idx = agg_ops.len();
4453 match func.as_str() {
4454 "COUNT" => agg_ops.push(StreamAgg::Count(col_idx)),
4455 "SUM" => agg_ops.push(StreamAgg::Sum(col_idx)),
4456 "AVG" => agg_ops.push(StreamAgg::Avg(col_idx)),
4457 "MIN" => agg_ops.push(StreamAgg::Min(col_idx)),
4458 "MAX" => agg_ops.push(StreamAgg::Max(col_idx)),
4459 _ => return Ok(None),
4460 }
4461 raw_targets.push(target);
4462 output.push((GroupByOutputCol::Agg(agg_idx), name));
4463 }
4464 _ => return Ok(None),
4465 }
4466 }
4467
4468 Ok(Some(Self {
4469 group_target,
4470 num_pk_cols: schema.primary_key_columns.len(),
4471 agg_ops,
4472 raw_targets,
4473 output,
4474 where_pred,
4475 }))
4476 }
4477
4478 fn execute_scan(
4479 &self,
4480 scan: impl FnOnce(
4481 &mut dyn FnMut(&[u8], &[u8]) -> bool,
4482 ) -> std::result::Result<(), citadel::Error>,
4483 ) -> Result<ExecutionResult> {
4484 let mut groups: HashMap<i64, Vec<AggState>> = HashMap::new();
4485 let mut null_group: Option<Vec<AggState>> = None;
4486 let mut scan_err: Option<SqlError> = None;
4487
4488 scan(&mut |key, value| {
4489 if let Some(ref pred) = self.where_pred {
4490 match pred.matches_raw(key, value) {
4491 Ok(true) => {}
4492 Ok(false) => return true,
4493 Err(e) => {
4494 scan_err = Some(e);
4495 return false;
4496 }
4497 }
4498 }
4499
4500 let group_key: Option<i64> = match &self.group_target {
4501 RawAggTarget::Pk(pk_pos) => {
4502 if self.num_pk_cols == 1 && *pk_pos == 0 {
4503 match decode_pk_integer(key) {
4504 Ok(v) => Some(v),
4505 Err(e) => {
4506 scan_err = Some(e);
4507 return false;
4508 }
4509 }
4510 } else {
4511 match decode_composite_key(key, self.num_pk_cols) {
4512 Ok(pk) => match &pk[*pk_pos] {
4513 Value::Integer(i) => Some(*i),
4514 Value::Null => None,
4515 _ => {
4516 scan_err = Some(SqlError::InvalidValue(
4517 "GROUP BY key not integer".into(),
4518 ));
4519 return false;
4520 }
4521 },
4522 Err(e) => {
4523 scan_err = Some(e);
4524 return false;
4525 }
4526 }
4527 }
4528 }
4529 RawAggTarget::NonPk(idx) => match decode_column_raw(value, *idx) {
4530 Ok(RawColumn::Integer(i)) => Some(i),
4531 Ok(RawColumn::Null) => None,
4532 Ok(_) => {
4533 scan_err = Some(SqlError::InvalidValue("GROUP BY key not integer".into()));
4534 return false;
4535 }
4536 Err(e) => {
4537 scan_err = Some(e);
4538 return false;
4539 }
4540 },
4541 RawAggTarget::CountStar => unreachable!(),
4542 };
4543
4544 let states = match group_key {
4545 Some(k) => groups
4546 .entry(k)
4547 .or_insert_with(|| self.agg_ops.iter().map(AggState::new).collect()),
4548 None => null_group
4549 .get_or_insert_with(|| self.agg_ops.iter().map(AggState::new).collect()),
4550 };
4551
4552 for (i, target) in self.raw_targets.iter().enumerate() {
4553 let raw = match target {
4554 RawAggTarget::CountStar => {
4555 if let Err(e) = states[i].feed_raw(&RawColumn::Null) {
4556 scan_err = Some(e);
4557 return false;
4558 }
4559 continue;
4560 }
4561 RawAggTarget::Pk(pk_pos) => {
4562 if self.num_pk_cols == 1 && *pk_pos == 0 {
4563 match decode_pk_integer(key) {
4564 Ok(v) => RawColumn::Integer(v),
4565 Err(e) => {
4566 scan_err = Some(e);
4567 return false;
4568 }
4569 }
4570 } else {
4571 match decode_composite_key(key, self.num_pk_cols) {
4572 Ok(pk) => match &pk[*pk_pos] {
4573 Value::Integer(i) => RawColumn::Integer(*i),
4574 _ => {
4575 scan_err = Some(SqlError::InvalidValue(
4576 "agg column not integer".into(),
4577 ));
4578 return false;
4579 }
4580 },
4581 Err(e) => {
4582 scan_err = Some(e);
4583 return false;
4584 }
4585 }
4586 }
4587 }
4588 RawAggTarget::NonPk(idx) => match decode_column_raw(value, *idx) {
4589 Ok(v) => v,
4590 Err(e) => {
4591 scan_err = Some(e);
4592 return false;
4593 }
4594 },
4595 };
4596 if let Err(e) = states[i].feed_raw(&raw) {
4597 scan_err = Some(e);
4598 return false;
4599 }
4600 }
4601 true
4602 })
4603 .map_err(SqlError::Storage)?;
4604
4605 if let Some(e) = scan_err {
4606 return Err(e);
4607 }
4608
4609 let col_names: Vec<String> = self.output.iter().map(|(_, name)| name.clone()).collect();
4610 let null_extra = if null_group.is_some() { 1 } else { 0 };
4611 let mut result_rows: Vec<Vec<Value>> = Vec::with_capacity(groups.len() + null_extra);
4612 if let Some(states) = null_group {
4613 let mut row = Vec::with_capacity(self.output.len());
4614 let finished: Vec<Value> = states.into_iter().map(|s| s.finish()).collect();
4615 for (col, _) in &self.output {
4616 match col {
4617 GroupByOutputCol::GroupKey => row.push(Value::Null),
4618 GroupByOutputCol::Agg(idx) => row.push(finished[*idx].clone()),
4619 }
4620 }
4621 result_rows.push(row);
4622 }
4623 for (group_key, states) in groups {
4624 let mut row = Vec::with_capacity(self.output.len());
4625 let finished: Vec<Value> = states.into_iter().map(|s| s.finish()).collect();
4626 for (col, _) in &self.output {
4627 match col {
4628 GroupByOutputCol::GroupKey => row.push(Value::Integer(group_key)),
4629 GroupByOutputCol::Agg(idx) => row.push(finished[*idx].clone()),
4630 }
4631 }
4632 result_rows.push(row);
4633 }
4634
4635 Ok(ExecutionResult::Query(QueryResult {
4636 columns: col_names,
4637 rows: result_rows,
4638 }))
4639 }
4640}
4641
/// Plan for `SELECT ... FROM t ORDER BY <single column> LIMIT n [OFFSET m]` that selects
/// the leading rows during a raw scan (presumably with a bounded heap of `keep` candidates;
/// confirm against `execute_scan`).
struct TopKScanPlan {
    /// Raw-row location of the ORDER BY column.
    sort_target: RawAggTarget,
    /// Number of primary-key columns in the table.
    num_pk_cols: usize,
    /// Whether the ORDER BY is descending.
    descending: bool,
    /// Whether NULLs sort before non-NULL values. When ORDER BY does not say, this is
    /// `!descending`.
    nulls_first: bool,
    /// LIMIT + OFFSET (saturating): how many leading rows the ordering must produce.
    keep: usize,
}
4649
4650impl TopKScanPlan {
4651 fn try_new(stmt: &SelectStmt, schema: &TableSchema) -> Result<Option<Self>> {
4652 if stmt.order_by.len() != 1
4653 || stmt.limit.is_none()
4654 || stmt.where_clause.is_some()
4655 || !stmt.group_by.is_empty()
4656 || stmt.having.is_some()
4657 || !stmt.joins.is_empty()
4658 || stmt.distinct
4659 {
4660 return Ok(None);
4661 }
4662
4663 let has_aggregates = stmt.columns.iter().any(|c| match c {
4664 SelectColumn::Expr { expr, .. } => is_aggregate_expr(expr),
4665 _ => false,
4666 });
4667 if has_aggregates {
4668 return Ok(None);
4669 }
4670
4671 let ob = &stmt.order_by[0];
4672 let col_map = ColumnMap::new(&schema.columns);
4673 let col_idx = match resolve_simple_col(&ob.expr, &col_map) {
4674 Some(idx) => idx,
4675 None => return Ok(None),
4676 };
4677
4678 let non_pk = schema.non_pk_indices();
4679 let enc_pos_arr = schema.encoding_positions();
4680 let sort_target = if let Some(pk_pos) = schema
4681 .primary_key_columns
4682 .iter()
4683 .position(|&i| i as usize == col_idx)
4684 {
4685 RawAggTarget::Pk(pk_pos)
4686 } else {
4687 let nonpk_order = non_pk.iter().position(|&i| i == col_idx).unwrap();
4688 RawAggTarget::NonPk(enc_pos_arr[nonpk_order] as usize)
4689 };
4690
4691 let limit = eval_const_int(stmt.limit.as_ref().unwrap())?.max(0) as usize;
4692 let offset = stmt
4693 .offset
4694 .as_ref()
4695 .map(eval_const_int)
4696 .transpose()?
4697 .unwrap_or(0)
4698 .max(0) as usize;
4699 let keep = limit.saturating_add(offset);
4700 if keep == 0 {
4701 return Ok(None);
4702 }
4703
4704 Ok(Some(Self {
4705 sort_target,
4706 num_pk_cols: schema.primary_key_columns.len(),
4707 descending: ob.descending,
4708 nulls_first: ob.nulls_first.unwrap_or(!ob.descending),
4709 keep,
4710 }))
4711 }
4712
4713 fn execute_scan(
4714 &self,
4715 schema: &TableSchema,
4716 stmt: &SelectStmt,
4717 scan: impl FnOnce(
4718 &mut dyn FnMut(&[u8], &[u8]) -> bool,
4719 ) -> std::result::Result<(), citadel::Error>,
4720 ) -> Result<ExecutionResult> {
4721 use std::cmp::Ordering;
4722 use std::collections::BinaryHeap;
4723
4724 struct Candidate {
4725 sort_key: Value,
4726 raw_key: Vec<u8>,
4727 raw_value: Vec<u8>,
4728 }
4729
4730 struct CandWrapper {
4731 c: Candidate,
4732 descending: bool,
4733 nulls_first: bool,
4734 }
4735
4736 impl PartialEq for CandWrapper {
4737 fn eq(&self, other: &Self) -> bool {
4738 self.cmp(other) == Ordering::Equal
4739 }
4740 }
4741 impl Eq for CandWrapper {}
4742
4743 impl PartialOrd for CandWrapper {
4744 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
4745 Some(self.cmp(other))
4746 }
4747 }
4748
4749 impl Ord for CandWrapper {
4751 fn cmp(&self, other: &Self) -> Ordering {
4752 let ord = match (self.c.sort_key.is_null(), other.c.sort_key.is_null()) {
4753 (true, true) => Ordering::Equal,
4754 (true, false) => {
4755 if self.nulls_first {
4756 Ordering::Less
4757 } else {
4758 Ordering::Greater
4759 }
4760 }
4761 (false, true) => {
4762 if self.nulls_first {
4763 Ordering::Greater
4764 } else {
4765 Ordering::Less
4766 }
4767 }
4768 (false, false) => self.c.sort_key.cmp(&other.c.sort_key),
4769 };
4770 if self.descending {
4771 ord.reverse()
4772 } else {
4773 ord
4774 }
4775 }
4776 }
4777
4778 let k = self.keep;
4779 let mut heap: BinaryHeap<CandWrapper> = BinaryHeap::with_capacity(k + 1);
4780 let mut scan_err: Option<SqlError> = None;
4781
4782 scan(&mut |key, value| {
4783 let sort_key: Value = match &self.sort_target {
4784 RawAggTarget::Pk(pk_pos) => {
4785 if self.num_pk_cols == 1 && *pk_pos == 0 {
4786 match decode_pk_integer(key) {
4787 Ok(v) => Value::Integer(v),
4788 Err(e) => {
4789 scan_err = Some(e);
4790 return false;
4791 }
4792 }
4793 } else {
4794 match decode_composite_key(key, self.num_pk_cols) {
4795 Ok(mut pk) => std::mem::replace(&mut pk[*pk_pos], Value::Null),
4796 Err(e) => {
4797 scan_err = Some(e);
4798 return false;
4799 }
4800 }
4801 }
4802 }
4803 RawAggTarget::NonPk(idx) => match decode_column_raw(value, *idx) {
4804 Ok(raw) => raw.to_value(),
4805 Err(e) => {
4806 scan_err = Some(e);
4807 return false;
4808 }
4809 },
4810 RawAggTarget::CountStar => unreachable!(),
4811 };
4812
4813 if heap.len() >= k {
4815 if let Some(top) = heap.peek() {
4816 let ord = match (sort_key.is_null(), top.c.sort_key.is_null()) {
4817 (true, true) => Ordering::Equal,
4818 (true, false) => {
4819 if self.nulls_first {
4820 Ordering::Less
4821 } else {
4822 Ordering::Greater
4823 }
4824 }
4825 (false, true) => {
4826 if self.nulls_first {
4827 Ordering::Greater
4828 } else {
4829 Ordering::Less
4830 }
4831 }
4832 (false, false) => sort_key.cmp(&top.c.sort_key),
4833 };
4834 let cmp = if self.descending { ord.reverse() } else { ord };
4835 if cmp != Ordering::Less {
4836 return true;
4837 }
4838 }
4839 }
4840
4841 let cand = CandWrapper {
4842 c: Candidate {
4843 sort_key,
4844 raw_key: key.to_vec(),
4845 raw_value: value.to_vec(),
4846 },
4847 descending: self.descending,
4848 nulls_first: self.nulls_first,
4849 };
4850
4851 if heap.len() < k {
4852 heap.push(cand);
4853 } else if let Some(mut top) = heap.peek_mut() {
4854 *top = cand;
4855 }
4856
4857 true
4858 })
4859 .map_err(SqlError::Storage)?;
4860
4861 if let Some(e) = scan_err {
4862 return Err(e);
4863 }
4864
4865 let mut winners: Vec<CandWrapper> = heap.into_vec();
4866 winners.sort();
4867
4868 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(winners.len());
4869 for w in &winners {
4870 rows.push(decode_full_row(schema, &w.c.raw_key, &w.c.raw_value)?);
4871 }
4872
4873 if let Some(ref offset_expr) = stmt.offset {
4874 let offset = eval_const_int(offset_expr)?.max(0) as usize;
4875 if offset < rows.len() {
4876 rows = rows.split_off(offset);
4877 } else {
4878 rows.clear();
4879 }
4880 }
4881 if let Some(ref limit_expr) = stmt.limit {
4882 let limit = eval_const_int(limit_expr)?.max(0) as usize;
4883 rows.truncate(limit);
4884 }
4885
4886 let (col_names, projected) = project_rows(&schema.columns, &stmt.columns, rows)?;
4887 Ok(ExecutionResult::Query(QueryResult {
4888 columns: col_names,
4889 rows: projected,
4890 }))
4891 }
4892}
4893
/// A `column <op> literal` filter evaluated directly on raw encoded key/value bytes
/// (built by `try_simple_predicate`).
struct SimplePredicate {
    // `true` when the filtered column is one of the primary-key columns.
    is_pk: bool,
    // Position of the column within the primary key (meaningful only when `is_pk`).
    pk_pos: usize,
    // Encoding position of the column inside the non-PK row payload (only when `!is_pk`).
    nonpk_idx: usize,
    // Comparison operator: one of Eq, NotEq, Lt, Gt, LtEq, GtEq.
    op: BinOp,
    // Right-hand side of the comparison.
    literal: Value,
    // Number of primary-key columns, used to decode composite keys.
    num_pk_cols: usize,
    // When `Some`, `matches_raw` takes the inline i64 fast path with this target.
    precomputed_int: Option<i64>,
    // Integer form of `default_val`, used by the inline path when the row payload is shorter
    // than `nonpk_idx`.
    default_int: Option<i64>,
    // Column DEFAULT value (if it evaluates as a constant), used when the row payload holds
    // fewer columns than `nonpk_idx` (presumably rows written before the column existed).
    default_val: Option<Value>,
}
4905
impl SimplePredicate {
    /// Evaluates the predicate against one raw table entry (`key` = encoded primary key,
    /// `value` = encoded non-PK row payload).
    ///
    /// # Errors
    /// Propagates key / column decoding errors.
    fn matches_raw(&self, key: &[u8], value: &[u8]) -> Result<bool> {
        // Fast path: integer literal on a non-PK column, compared without decoding a Value.
        if let Some(target) = self.precomputed_int {
            return Ok(self.match_nonpk_int_inline(value, target));
        }
        let raw = if self.is_pk {
            // NOTE(review): a single-column PK is always decoded as an integer here — confirm
            // callers never build a predicate over a non-integer single-column primary key.
            if self.num_pk_cols == 1 {
                RawColumn::Integer(decode_pk_integer(key)?)
            } else {
                let pk = decode_composite_key(key, self.num_pk_cols)?;
                match &pk[self.pk_pos] {
                    Value::Integer(i) => RawColumn::Integer(*i),
                    Value::Real(r) => RawColumn::Real(*r),
                    Value::Boolean(b) => RawColumn::Boolean(*b),
                    // Other key types are compared on the decoded Value directly.
                    _ => {
                        return Ok(raw_matches_op_value(
                            &pk[self.pk_pos],
                            self.op,
                            &self.literal,
                        ))
                    }
                }
            }
        } else if self.nonpk_idx >= row_non_pk_count(value) {
            // The stored row has no slot for this column: compare against its DEFAULT.
            return Ok(match &self.default_val {
                Some(d) => raw_matches_op_value(d, self.op, &self.literal),
                None => false,
            });
        } else {
            decode_column_raw(value, self.nonpk_idx)?
        };
        Ok(raw_matches_op(&raw, self.op, &self.literal))
    }

    /// Compares the non-PK column at `nonpk_idx` against `target` by walking the raw payload.
    ///
    /// The layout read here is:
    /// - a little-endian u16 column count;
    /// - a bitmap of `ceil(count / 8)` bytes in which a set bit marks a NULL column;
    /// - for each non-NULL column, 1 byte (presumably a type tag), a little-endian u32 length,
    ///   and then `length` payload bytes.
    ///
    /// The target column's payload is read as a little-endian i64, so this is only meaningful
    /// for integer-encoded columns. NULL columns never match.
    #[inline(always)]
    fn match_nonpk_int_inline(&self, data: &[u8], target: i64) -> bool {
        let col_count = u16::from_le_bytes(data[0..2].try_into().unwrap()) as usize;

        // Column absent from this row: fall back to the integer DEFAULT, if any.
        // NOTE(review): a non-integer DEFAULT yields `false` here, while the slow path in
        // `matches_raw` would compare `default_val` — confirm this asymmetry is intended.
        if self.nonpk_idx >= col_count {
            return match self.default_int {
                Some(v) => match self.op {
                    BinOp::Eq => v == target,
                    BinOp::NotEq => v != target,
                    BinOp::Lt => v < target,
                    BinOp::Gt => v > target,
                    BinOp::LtEq => v <= target,
                    BinOp::GtEq => v >= target,
                    _ => false,
                },
                None => false,
            };
        }

        let bm_bytes = col_count.div_ceil(8);

        // Target column is NULL: a comparison with NULL never matches.
        if data[2 + self.nonpk_idx / 8] & (1 << (self.nonpk_idx % 8)) != 0 {
            return false;
        }

        let mut pos = 2 + bm_bytes;

        // Skip every preceding non-NULL column (NULL columns occupy no payload bytes).
        for col in 0..self.nonpk_idx {
            if data[2 + col / 8] & (1 << (col % 8)) == 0 {
                let len = u32::from_le_bytes(data[pos + 1..pos + 5].try_into().unwrap()) as usize;
                pos += 5 + len;
            }
        }

        // Skip the 1-byte prefix and 4-byte length, then read 8 payload bytes as i64.
        let v = i64::from_le_bytes(data[pos + 5..pos + 13].try_into().unwrap());

        match self.op {
            BinOp::Eq => v == target,
            BinOp::NotEq => v != target,
            BinOp::Lt => v < target,
            BinOp::Gt => v > target,
            BinOp::LtEq => v <= target,
            BinOp::GtEq => v >= target,
            _ => false,
        }
    }
}
4990
4991fn try_simple_predicate(expr: &Expr, schema: &TableSchema) -> Option<SimplePredicate> {
4992 let (col_name, op, literal) = match expr {
4993 Expr::BinaryOp { left, op, right } => match (left.as_ref(), right.as_ref()) {
4994 (Expr::Column(name), Expr::Literal(lit)) => (name.as_str(), *op, lit),
4995 (Expr::Literal(lit), Expr::Column(name)) => (name.as_str(), flip_cmp_op(*op)?, lit),
4996 _ => return None,
4997 },
4998 _ => return None,
4999 };
5000
5001 if !matches!(
5002 op,
5003 BinOp::Eq | BinOp::NotEq | BinOp::Lt | BinOp::Gt | BinOp::LtEq | BinOp::GtEq
5004 ) {
5005 return None;
5006 }
5007
5008 let col_idx = schema.column_index(col_name)?;
5009 let non_pk = schema.non_pk_indices();
5010
5011 if let Some(pk_pos) = schema
5012 .primary_key_columns
5013 .iter()
5014 .position(|&i| i as usize == col_idx)
5015 {
5016 Some(SimplePredicate {
5017 is_pk: true,
5018 pk_pos,
5019 nonpk_idx: 0,
5020 op,
5021 literal: literal.clone(),
5022 num_pk_cols: schema.primary_key_columns.len(),
5023 precomputed_int: None,
5024 default_int: None,
5025 default_val: None,
5026 })
5027 } else {
5028 let nonpk_order = non_pk.iter().position(|&i| i == col_idx)?;
5029 let nonpk_idx = schema.encoding_positions()[nonpk_order] as usize;
5030 let precomputed_int = match literal {
5031 Value::Integer(i) => Some(*i),
5032 _ => None,
5033 };
5034 let default_val = schema.columns[col_idx]
5035 .default_expr
5036 .as_ref()
5037 .and_then(|expr| eval_const_expr(expr).ok());
5038 let default_int = default_val.as_ref().and_then(|v| match v {
5039 Value::Integer(i) => Some(*i),
5040 _ => None,
5041 });
5042 Some(SimplePredicate {
5043 is_pk: false,
5044 pk_pos: 0,
5045 nonpk_idx,
5046 op,
5047 literal: literal.clone(),
5048 num_pk_cols: schema.primary_key_columns.len(),
5049 precomputed_int,
5050 default_int,
5051 default_val,
5052 })
5053 }
5054}
5055
5056fn flip_cmp_op(op: BinOp) -> Option<BinOp> {
5057 match op {
5058 BinOp::Eq => Some(BinOp::Eq),
5059 BinOp::NotEq => Some(BinOp::NotEq),
5060 BinOp::Lt => Some(BinOp::Gt),
5061 BinOp::Gt => Some(BinOp::Lt),
5062 BinOp::LtEq => Some(BinOp::GtEq),
5063 BinOp::GtEq => Some(BinOp::LtEq),
5064 _ => None,
5065 }
5066}
5067
5068fn raw_matches_op(raw: &RawColumn, op: BinOp, literal: &Value) -> bool {
5069 if matches!(raw, RawColumn::Null) || literal.is_null() {
5071 return false;
5072 }
5073 match op {
5074 BinOp::Eq => raw.eq_value(literal),
5075 BinOp::NotEq => !raw.eq_value(literal),
5076 BinOp::Lt => raw.cmp_value(literal) == Some(std::cmp::Ordering::Less),
5077 BinOp::Gt => raw.cmp_value(literal) == Some(std::cmp::Ordering::Greater),
5078 BinOp::LtEq => raw
5079 .cmp_value(literal)
5080 .is_some_and(|o| o != std::cmp::Ordering::Greater),
5081 BinOp::GtEq => raw
5082 .cmp_value(literal)
5083 .is_some_and(|o| o != std::cmp::Ordering::Less),
5084 _ => false,
5085 }
5086}
5087
5088fn raw_matches_op_value(val: &Value, op: BinOp, literal: &Value) -> bool {
5089 match op {
5090 BinOp::Eq => val == literal,
5091 BinOp::NotEq => val != literal && !val.is_null(),
5092 BinOp::Lt => val < literal,
5093 BinOp::Gt => val > literal,
5094 BinOp::LtEq => val <= literal,
5095 BinOp::GtEq => val >= literal,
5096 _ => false,
5097 }
5098}
5099
5100fn exec_select_no_from(stmt: &SelectStmt) -> Result<ExecutionResult> {
5101 let empty_cols: Vec<ColumnDef> = vec![];
5102 let empty_row: Vec<Value> = vec![];
5103 let (col_names, projected) = project_rows(&empty_cols, &stmt.columns, vec![empty_row])?;
5104 Ok(ExecutionResult::Query(QueryResult {
5105 columns: col_names,
5106 rows: projected,
5107 }))
5108}
5109
5110fn process_select(
5111 columns: &[ColumnDef],
5112 mut rows: Vec<Vec<Value>>,
5113 stmt: &SelectStmt,
5114 predicate_applied: bool,
5115) -> Result<ExecutionResult> {
5116 if !predicate_applied {
5117 if let Some(ref where_expr) = stmt.where_clause {
5118 let col_map = ColumnMap::new(columns);
5119 rows.retain(|row| match eval_expr(where_expr, &col_map, row) {
5120 Ok(val) => is_truthy(&val),
5121 Err(_) => false,
5122 });
5123 }
5124 }
5125
5126 let has_aggregates = stmt.columns.iter().any(|c| match c {
5127 SelectColumn::Expr { expr, .. } => is_aggregate_expr(expr),
5128 _ => false,
5129 });
5130
5131 if has_aggregates || !stmt.group_by.is_empty() {
5132 return exec_aggregate(columns, &rows, stmt);
5133 }
5134
5135 if stmt.distinct {
5136 let (col_names, mut projected) = project_rows(columns, &stmt.columns, rows)?;
5137
5138 let mut seen = std::collections::HashSet::new();
5139 projected.retain(|row| seen.insert(row.clone()));
5140
5141 if !stmt.order_by.is_empty() {
5142 let output_cols = build_output_columns(&stmt.columns, columns);
5143 sort_rows(&mut projected, &stmt.order_by, &output_cols)?;
5144 }
5145
5146 if let Some(ref offset_expr) = stmt.offset {
5147 let offset = eval_const_int(offset_expr)?.max(0) as usize;
5148 if offset < projected.len() {
5149 projected = projected.split_off(offset);
5150 } else {
5151 projected.clear();
5152 }
5153 }
5154
5155 if let Some(ref limit_expr) = stmt.limit {
5156 let limit = eval_const_int(limit_expr)?.max(0) as usize;
5157 projected.truncate(limit);
5158 }
5159
5160 return Ok(ExecutionResult::Query(QueryResult {
5161 columns: col_names,
5162 rows: projected,
5163 }));
5164 }
5165
5166 if !stmt.order_by.is_empty() {
5167 if let Some(ref limit_expr) = stmt.limit {
5168 let limit = eval_const_int(limit_expr)?.max(0) as usize;
5169 let offset = match stmt.offset {
5170 Some(ref e) => eval_const_int(e)?.max(0) as usize,
5171 None => 0,
5172 };
5173 let keep = limit.saturating_add(offset);
5174 if keep == 0 {
5175 rows.clear();
5176 } else if keep < rows.len() {
5177 topk_rows(&mut rows, &stmt.order_by, columns, keep)?;
5178 rows.truncate(keep);
5179 } else {
5180 sort_rows(&mut rows, &stmt.order_by, columns)?;
5181 }
5182 } else {
5183 sort_rows(&mut rows, &stmt.order_by, columns)?;
5184 }
5185 }
5186
5187 if let Some(ref offset_expr) = stmt.offset {
5188 let offset = eval_const_int(offset_expr)?.max(0) as usize;
5189 if offset < rows.len() {
5190 rows = rows.split_off(offset);
5191 } else {
5192 rows.clear();
5193 }
5194 }
5195
5196 if let Some(ref limit_expr) = stmt.limit {
5197 let limit = eval_const_int(limit_expr)?.max(0) as usize;
5198 rows.truncate(limit);
5199 }
5200
5201 let (col_names, projected) = project_rows(columns, &stmt.columns, rows)?;
5202
5203 Ok(ExecutionResult::Query(QueryResult {
5204 columns: col_names,
5205 rows: projected,
5206 }))
5207}
5208
5209fn resolve_table_name<'a>(schema: &'a SchemaManager, name: &str) -> Result<&'a TableSchema> {
5210 schema
5211 .get(name)
5212 .ok_or_else(|| SqlError::TableNotFound(name.to_string()))
5213}
5214
5215fn build_joined_columns(tables: &[(String, &TableSchema)]) -> Vec<ColumnDef> {
5216 let mut result = Vec::new();
5217 let mut pos: u16 = 0;
5218
5219 for (alias, schema) in tables {
5220 for col in &schema.columns {
5221 result.push(ColumnDef {
5222 name: format!("{}.{}", alias.to_ascii_lowercase(), col.name),
5223 data_type: col.data_type,
5224 nullable: col.nullable,
5225 position: pos,
5226 default_expr: None,
5227 default_sql: None,
5228 check_expr: None,
5229 check_sql: None,
5230 check_name: None,
5231 });
5232 pos += 1;
5233 }
5234 }
5235
5236 result
5237}
5238
5239fn extract_equi_join_keys(
5240 on_expr: &Expr,
5241 combined_cols: &[ColumnDef],
5242 outer_col_count: usize,
5243) -> Vec<(usize, usize)> {
5244 let mut pairs = Vec::new();
5245
5246 fn flatten<'a>(e: &'a Expr, out: &mut Vec<&'a Expr>) {
5247 match e {
5248 Expr::BinaryOp {
5249 left,
5250 op: BinOp::And,
5251 right,
5252 } => {
5253 flatten(left, out);
5254 flatten(right, out);
5255 }
5256 _ => out.push(e),
5257 }
5258 }
5259 let mut conjuncts = Vec::new();
5260 flatten(on_expr, &mut conjuncts);
5261
5262 for expr in conjuncts {
5263 if let Expr::BinaryOp {
5264 left,
5265 op: BinOp::Eq,
5266 right,
5267 } = expr
5268 {
5269 if let (Some(l_idx), Some(r_idx)) = (
5270 resolve_col_idx(left, combined_cols),
5271 resolve_col_idx(right, combined_cols),
5272 ) {
5273 if l_idx < outer_col_count && r_idx >= outer_col_count {
5274 pairs.push((l_idx, r_idx - outer_col_count));
5275 } else if r_idx < outer_col_count && l_idx >= outer_col_count {
5276 pairs.push((r_idx, l_idx - outer_col_count));
5277 }
5278 }
5279 }
5280 }
5281
5282 pairs
5283}
5284
5285fn resolve_col_idx(expr: &Expr, columns: &[ColumnDef]) -> Option<usize> {
5286 match expr {
5287 Expr::Column(name) => {
5288 let matches: Vec<usize> = columns
5289 .iter()
5290 .enumerate()
5291 .filter(|(_, c)| {
5292 c.name == *name
5293 || (c.name.len() > name.len()
5294 && c.name.as_bytes()[c.name.len() - name.len() - 1] == b'.'
5295 && c.name.ends_with(name.as_str()))
5296 })
5297 .map(|(i, _)| i)
5298 .collect();
5299 if matches.len() == 1 {
5300 Some(matches[0])
5301 } else {
5302 None
5303 }
5304 }
5305 Expr::QualifiedColumn { table, column } => {
5306 let qualified = format!("{table}.{column}");
5307 columns.iter().position(|c| c.name == qualified)
5308 }
5309 _ => None,
5310 }
5311}
5312
5313fn hash_key(row: &[Value], col_indices: &[usize]) -> Vec<Value> {
5314 col_indices.iter().map(|&i| row[i].clone()).collect()
5315}
5316
5317fn count_conjuncts(expr: &Expr) -> usize {
5318 match expr {
5319 Expr::BinaryOp {
5320 op: BinOp::And,
5321 left,
5322 right,
5323 } => count_conjuncts(left) + count_conjuncts(right),
5324 _ => 1,
5325 }
5326}
5327
5328fn combine_row(outer: &[Value], inner: &[Value], cap: usize) -> Vec<Value> {
5329 let mut combined = Vec::with_capacity(cap);
5330 combined.extend(outer.iter().cloned());
5331 combined.extend(inner.iter().cloned());
5332 combined
5333}
5334
/// Column selection applied while stitching an outer row and an inner row together.
struct CombineProjection {
    // One entry per output column: `(index, from_inner)` reads `inner[index]` when
    // `from_inner` is true and `outer[index]` otherwise (see `combine_row_projected`).
    slots: Vec<(usize, bool)>,
}
5338
5339fn combine_row_projected(outer: &[Value], inner: &[Value], proj: &CombineProjection) -> Vec<Value> {
5340 proj.slots
5341 .iter()
5342 .map(|&(idx, is_inner)| {
5343 if is_inner {
5344 inner[idx].clone()
5345 } else {
5346 outer[idx].clone()
5347 }
5348 })
5349 .collect()
5350}
5351
5352fn build_combine_projection(
5353 needed_combined: &[usize],
5354 outer_col_count: usize,
5355) -> CombineProjection {
5356 CombineProjection {
5357 slots: needed_combined
5358 .iter()
5359 .map(|&ci| {
5360 if ci < outer_col_count {
5361 (ci, false)
5362 } else {
5363 (ci - outer_col_count, true)
5364 }
5365 })
5366 .collect(),
5367 }
5368}
5369
5370fn build_projected_columns(full_cols: &[ColumnDef], needed_combined: &[usize]) -> Vec<ColumnDef> {
5371 needed_combined
5372 .iter()
5373 .enumerate()
5374 .map(|(new_pos, &old_pos)| {
5375 let orig = &full_cols[old_pos];
5376 ColumnDef {
5377 name: orig.name.clone(),
5378 data_type: orig.data_type,
5379 nullable: orig.nullable,
5380 position: new_pos as u16,
5381 default_expr: None,
5382 default_sql: None,
5383 check_expr: None,
5384 check_sql: None,
5385 check_name: None,
5386 }
5387 })
5388 .collect()
5389}
5390
/// Specialized single-key join for integer join keys.
///
/// Returns `Ok(joined_rows)` on success. Returns `Err(outer_rows)`, handing the unconsumed
/// outer rows back, as soon as an inner key is neither `Integer` nor `Null`; the caller then
/// falls back to the generic hash join. NULL inner keys are never matched. Outer rows whose
/// key is not an `Integer` never match either (they are only padded for LEFT JOIN).
///
/// When `outer_is_sorted` and the join is INNER/CROSS, a merge join is used. It assumes the
/// outer rows are in ascending key order — TODO confirm callers guarantee ascending order.
/// Otherwise a hash map from key to inner-row indices is built.
///
/// With `projection = None`, output rows are `outer ++ inner`, and the outer row's allocation
/// is reused for its last match. With a projection, only the selected columns are emitted.
#[allow(clippy::too_many_arguments)]
fn try_integer_join(
    outer_rows: Vec<Vec<Value>>,
    inner_rows: &[Vec<Value>],
    join_type: &JoinType,
    outer_key_col: usize,
    inner_key_col: usize,
    outer_col_count: usize,
    inner_col_count: usize,
    outer_is_sorted: bool,
    projection: Option<&CombineProjection>,
) -> std::result::Result<Vec<Vec<Value>>, Vec<Vec<Value>>> {
    // Width of each produced row.
    let cap = projection.map_or(outer_col_count + inner_col_count, |p| p.slots.len());

    if outer_is_sorted && matches!(join_type, JoinType::Inner | JoinType::Cross) {
        // Merge join: (key, inner row index) pairs sorted by key; skip the sort when the
        // inner rows already arrive in non-decreasing key order.
        let mut sorted_inner: Vec<(i64, usize)> = Vec::with_capacity(inner_rows.len());
        let mut needs_sort = false;
        let mut prev = i64::MIN;
        for (i, r) in inner_rows.iter().enumerate() {
            match r[inner_key_col] {
                Value::Integer(k) => {
                    if k < prev {
                        needs_sort = true;
                    }
                    prev = k;
                    sorted_inner.push((k, i));
                }
                Value::Null => {}
                _ => return Err(outer_rows),
            }
        }
        if needs_sort {
            sorted_inner.sort_unstable_by_key(|&(k, _)| k);
        }

        let mut result = Vec::with_capacity(outer_rows.len());
        // `j` only moves forward, which relies on the outer keys being ascending.
        let mut j = 0;
        for mut outer in outer_rows {
            let ok = match outer[outer_key_col] {
                Value::Integer(i) => i,
                _ => continue,
            };
            while j < sorted_inner.len() && sorted_inner[j].0 < ok {
                j += 1;
            }
            // `j` is left at the first equal key so duplicate outer keys see the same run.
            let mut kk = j;
            while kk < sorted_inner.len() && sorted_inner[kk].0 == ok {
                let is_last = kk + 1 >= sorted_inner.len() || sorted_inner[kk + 1].0 != ok;
                let inner = &inner_rows[sorted_inner[kk].1];
                if let Some(proj) = projection {
                    if is_last {
                        // Last match for this outer row: move its values out instead of cloning.
                        // NOTE(review): if `proj.slots` ever repeats an outer index, the second
                        // `take` yields the default value — confirm slots are de-duplicated.
                        result.push(
                            proj.slots
                                .iter()
                                .map(|&(idx, is_inner)| {
                                    if is_inner {
                                        inner[idx].clone()
                                    } else {
                                        std::mem::take(&mut outer[idx])
                                    }
                                })
                                .collect(),
                        );
                    } else {
                        result.push(combine_row_projected(&outer, inner, proj));
                    }
                } else if is_last {
                    // Last match: reuse the outer row's allocation for the joined row.
                    outer.extend(inner.iter().cloned());
                    result.push(outer);
                    break;
                } else {
                    result.push(combine_row(&outer, inner, cap));
                }
                kk += 1;
            }
        }
        return Ok(result);
    }

    // Hash join: integer key -> indices of inner rows carrying that key.
    let mut inner_map: HashMap<i64, Vec<usize>> = HashMap::with_capacity(inner_rows.len());
    for (idx, inner) in inner_rows.iter().enumerate() {
        match &inner[inner_key_col] {
            Value::Integer(k) => inner_map.entry(*k).or_default().push(idx),
            Value::Null => {}
            _ => return Err(outer_rows),
        }
    }

    let mut result = Vec::with_capacity(inner_rows.len());

    match join_type {
        JoinType::Inner | JoinType::Cross => {
            for mut outer in outer_rows {
                if let Value::Integer(k) = outer[outer_key_col] {
                    if let Some(indices) = inner_map.get(&k) {
                        if let Some(proj) = projection {
                            for &idx in indices {
                                result.push(combine_row_projected(&outer, &inner_rows[idx], proj));
                            }
                        } else {
                            // Clone for all but the last match; the last one takes `outer`.
                            for &idx in &indices[..indices.len() - 1] {
                                result.push(combine_row(&outer, &inner_rows[idx], cap));
                            }
                            let last_idx = *indices.last().unwrap();
                            outer.extend(inner_rows[last_idx].iter().cloned());
                            result.push(outer);
                        }
                    }
                }
            }
        }
        JoinType::Left => {
            for mut outer in outer_rows {
                if let Value::Integer(k) = outer[outer_key_col] {
                    if let Some(indices) = inner_map.get(&k) {
                        if let Some(proj) = projection {
                            for &idx in indices {
                                result.push(combine_row_projected(&outer, &inner_rows[idx], proj));
                            }
                        } else {
                            for &idx in &indices[..indices.len() - 1] {
                                result.push(combine_row(&outer, &inner_rows[idx], cap));
                            }
                            let last_idx = *indices.last().unwrap();
                            outer.extend(inner_rows[last_idx].iter().cloned());
                            result.push(outer);
                        }
                        continue;
                    }
                }
                // No match (or non-integer outer key): pad the inner side with NULLs.
                if let Some(proj) = projection {
                    let null_inner = vec![Value::Null; inner_col_count];
                    result.push(combine_row_projected(&outer, &null_inner, proj));
                } else {
                    outer.resize(cap, Value::Null);
                    result.push(outer);
                }
            }
        }
        JoinType::Right => {
            // Track which inner rows matched so the rest can be emitted NULL-padded at the end.
            let mut inner_matched = vec![false; inner_rows.len()];
            for mut outer in outer_rows {
                if let Value::Integer(k) = outer[outer_key_col] {
                    if let Some(indices) = inner_map.get(&k) {
                        if let Some(proj) = projection {
                            for &idx in indices {
                                result.push(combine_row_projected(&outer, &inner_rows[idx], proj));
                                inner_matched[idx] = true;
                            }
                        } else {
                            for &idx in &indices[..indices.len() - 1] {
                                result.push(combine_row(&outer, &inner_rows[idx], cap));
                                inner_matched[idx] = true;
                            }
                            let last_idx = *indices.last().unwrap();
                            inner_matched[last_idx] = true;
                            outer.extend(inner_rows[last_idx].iter().cloned());
                            result.push(outer);
                        }
                    }
                }
            }
            for (j, inner) in inner_rows.iter().enumerate() {
                if !inner_matched[j] {
                    if let Some(proj) = projection {
                        let null_outer = vec![Value::Null; outer_col_count];
                        result.push(combine_row_projected(&null_outer, inner, proj));
                    } else {
                        let mut padded = Vec::with_capacity(cap);
                        padded.resize(outer_col_count, Value::Null);
                        padded.extend(inner.iter().cloned());
                        result.push(padded);
                    }
                }
            }
        }
    }

    Ok(result)
}
5571
5572#[allow(clippy::too_many_arguments)]
5573fn exec_join_step(
5574 mut outer_rows: Vec<Vec<Value>>,
5575 inner_rows: &[Vec<Value>],
5576 join: &JoinClause,
5577 combined_cols: &[ColumnDef],
5578 outer_col_count: usize,
5579 inner_col_count: usize,
5580 outer_pk_col: Option<usize>,
5581 projection: Option<&CombineProjection>,
5582) -> Vec<Vec<Value>> {
5583 let equi_pairs = join
5584 .on_clause
5585 .as_ref()
5586 .map(|on| extract_equi_join_keys(on, combined_cols, outer_col_count))
5587 .unwrap_or_default();
5588
5589 let is_pure_equi = join.on_clause.as_ref().map_or(true, |on| {
5590 !equi_pairs.is_empty() && count_conjuncts(on) == equi_pairs.len()
5591 });
5592
5593 let effective_proj = if is_pure_equi { projection } else { None };
5594
5595 if equi_pairs.len() == 1 && is_pure_equi {
5596 let (outer_key_col, inner_key_col) = equi_pairs[0];
5597 let outer_is_sorted = outer_pk_col == Some(outer_key_col);
5598 match try_integer_join(
5599 outer_rows,
5600 inner_rows,
5601 &join.join_type,
5602 outer_key_col,
5603 inner_key_col,
5604 outer_col_count,
5605 inner_col_count,
5606 outer_is_sorted,
5607 effective_proj,
5608 ) {
5609 Ok(result) => return result,
5610 Err(rows) => outer_rows = rows,
5611 }
5612 }
5613
5614 let outer_key_cols: Vec<usize> = equi_pairs.iter().map(|&(o, _)| o).collect();
5615 let inner_key_cols: Vec<usize> = equi_pairs.iter().map(|&(_, i)| i).collect();
5616
5617 let mut inner_map: HashMap<Vec<Value>, Vec<usize>> = HashMap::new();
5618 for (idx, inner) in inner_rows.iter().enumerate() {
5619 inner_map
5620 .entry(hash_key(inner, &inner_key_cols))
5621 .or_default()
5622 .push(idx);
5623 }
5624
5625 let cap = effective_proj.map_or(outer_col_count + inner_col_count, |p| p.slots.len());
5626 let mut result = Vec::new();
5627
5628 if is_pure_equi {
5629 match join.join_type {
5630 JoinType::Inner | JoinType::Cross => {
5631 for mut outer in outer_rows {
5632 let key = hash_key(&outer, &outer_key_cols);
5633 if let Some(indices) = inner_map.get(&key) {
5634 if let Some(proj) = effective_proj {
5635 for &idx in indices {
5636 result.push(combine_row_projected(&outer, &inner_rows[idx], proj));
5637 }
5638 } else {
5639 for &idx in &indices[..indices.len() - 1] {
5640 result.push(combine_row(&outer, &inner_rows[idx], cap));
5641 }
5642 let last_idx = *indices.last().unwrap();
5643 outer.extend(inner_rows[last_idx].iter().cloned());
5644 result.push(outer);
5645 }
5646 }
5647 }
5648 }
5649 JoinType::Left => {
5650 for mut outer in outer_rows {
5651 let key = hash_key(&outer, &outer_key_cols);
5652 if let Some(indices) = inner_map.get(&key) {
5653 if let Some(proj) = effective_proj {
5654 for &idx in indices {
5655 result.push(combine_row_projected(&outer, &inner_rows[idx], proj));
5656 }
5657 } else {
5658 for &idx in &indices[..indices.len() - 1] {
5659 result.push(combine_row(&outer, &inner_rows[idx], cap));
5660 }
5661 let last_idx = *indices.last().unwrap();
5662 outer.extend(inner_rows[last_idx].iter().cloned());
5663 result.push(outer);
5664 }
5665 } else if let Some(proj) = effective_proj {
5666 let null_inner = vec![Value::Null; inner_col_count];
5667 result.push(combine_row_projected(&outer, &null_inner, proj));
5668 } else {
5669 outer.resize(cap, Value::Null);
5670 result.push(outer);
5671 }
5672 }
5673 }
5674 JoinType::Right => {
5675 let mut inner_matched = vec![false; inner_rows.len()];
5676 for mut outer in outer_rows {
5677 let key = hash_key(&outer, &outer_key_cols);
5678 if let Some(indices) = inner_map.get(&key) {
5679 if let Some(proj) = effective_proj {
5680 for &idx in indices {
5681 result.push(combine_row_projected(&outer, &inner_rows[idx], proj));
5682 inner_matched[idx] = true;
5683 }
5684 } else {
5685 for &idx in &indices[..indices.len() - 1] {
5686 result.push(combine_row(&outer, &inner_rows[idx], cap));
5687 inner_matched[idx] = true;
5688 }
5689 let last_idx = *indices.last().unwrap();
5690 inner_matched[last_idx] = true;
5691 outer.extend(inner_rows[last_idx].iter().cloned());
5692 result.push(outer);
5693 }
5694 }
5695 }
5696 for (j, inner) in inner_rows.iter().enumerate() {
5697 if !inner_matched[j] {
5698 if let Some(proj) = effective_proj {
5699 let null_outer = vec![Value::Null; outer_col_count];
5700 result.push(combine_row_projected(&null_outer, inner, proj));
5701 } else {
5702 let mut padded = Vec::with_capacity(cap);
5703 padded.resize(outer_col_count, Value::Null);
5704 padded.extend(inner.iter().cloned());
5705 result.push(padded);
5706 }
5707 }
5708 }
5709 }
5710 }
5711 } else {
5712 let combined_map = ColumnMap::new(combined_cols);
5713 let on_matches = |combined: &[Value]| -> bool {
5714 match join.on_clause {
5715 Some(ref on_expr) => eval_expr(on_expr, &combined_map, combined)
5716 .map(|v| is_truthy(&v))
5717 .unwrap_or(false),
5718 None => true,
5719 }
5720 };
5721
5722 match join.join_type {
5723 JoinType::Inner | JoinType::Cross => {
5724 for outer in &outer_rows {
5725 let key = hash_key(outer, &outer_key_cols);
5726 if let Some(indices) = inner_map.get(&key) {
5727 for &idx in indices {
5728 let combined = combine_row(outer, &inner_rows[idx], cap);
5729 if on_matches(&combined) {
5730 result.push(combined);
5731 }
5732 }
5733 }
5734 }
5735 }
5736 JoinType::Left => {
5737 for outer in &outer_rows {
5738 let key = hash_key(outer, &outer_key_cols);
5739 let mut matched = false;
5740 if let Some(indices) = inner_map.get(&key) {
5741 for &idx in indices {
5742 let combined = combine_row(outer, &inner_rows[idx], cap);
5743 if on_matches(&combined) {
5744 result.push(combined);
5745 matched = true;
5746 }
5747 }
5748 }
5749 if !matched {
5750 let mut padded = Vec::with_capacity(cap);
5751 padded.extend(outer.iter().cloned());
5752 padded.resize(cap, Value::Null);
5753 result.push(padded);
5754 }
5755 }
5756 }
5757 JoinType::Right => {
5758 let mut inner_matched = vec![false; inner_rows.len()];
5759 for outer in &outer_rows {
5760 let key = hash_key(outer, &outer_key_cols);
5761 if let Some(indices) = inner_map.get(&key) {
5762 for &idx in indices {
5763 let combined = combine_row(outer, &inner_rows[idx], cap);
5764 if on_matches(&combined) {
5765 result.push(combined);
5766 inner_matched[idx] = true;
5767 }
5768 }
5769 }
5770 }
5771 for (j, inner) in inner_rows.iter().enumerate() {
5772 if !inner_matched[j] {
5773 let mut padded = Vec::with_capacity(cap);
5774 padded.resize(outer_col_count, Value::Null);
5775 padded.extend(inner.iter().cloned());
5776 result.push(padded);
5777 }
5778 }
5779 }
5780 }
5781 }
5782
5783 result
5784}
5785
/// Returns the name a table is referred to by in a query: its alias when one is given,
/// otherwise the table name, lowercased (ASCII only).
fn table_alias_or_name(name: &str, alias: &Option<String>) -> String {
    alias.as_deref().unwrap_or(name).to_ascii_lowercase()
}
5792
5793fn collect_all_rows_raw(
5794 rtx: &mut citadel_txn::read_txn::ReadTxn<'_>,
5795 table_schema: &TableSchema,
5796) -> Result<Vec<Vec<Value>>> {
5797 let lower_name = &table_schema.name;
5798 let entry_count = rtx.table_entry_count(lower_name.as_bytes()).unwrap_or(0) as usize;
5799 let mut rows = Vec::with_capacity(entry_count);
5800 let mut scan_err: Option<SqlError> = None;
5801 rtx.table_scan_raw(lower_name.as_bytes(), |key, value| {
5802 match decode_full_row(table_schema, key, value) {
5803 Ok(row) => rows.push(row),
5804 Err(e) => {
5805 scan_err = Some(e);
5806 return false;
5807 }
5808 }
5809 true
5810 })
5811 .map_err(SqlError::Storage)?;
5812 if let Some(e) = scan_err {
5813 return Err(e);
5814 }
5815 Ok(rows)
5816}
5817
5818fn collect_all_rows_write(
5819 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
5820 table_schema: &TableSchema,
5821) -> Result<Vec<Vec<Value>>> {
5822 collect_rows_write(wtx, table_schema, &None, None).map(|(rows, _)| rows)
5823}
5824
5825fn has_ambiguous_bare_ref(expr: &Expr, columns: &[ColumnDef]) -> bool {
5826 match expr {
5827 Expr::Column(name) => {
5828 let lower = name.to_ascii_lowercase();
5829 columns
5830 .iter()
5831 .filter(|c| c.name == lower || c.name.ends_with(&format!(".{lower}")))
5832 .count()
5833 > 1
5834 }
5835 Expr::BinaryOp { left, right, .. } => {
5836 has_ambiguous_bare_ref(left, columns) || has_ambiguous_bare_ref(right, columns)
5837 }
5838 Expr::UnaryOp { expr: inner, .. } | Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
5839 has_ambiguous_bare_ref(inner, columns)
5840 }
5841 Expr::Function { args, .. } | Expr::Coalesce(args) => {
5842 args.iter().any(|a| has_ambiguous_bare_ref(a, columns))
5843 }
5844 Expr::Between {
5845 expr: e, low, high, ..
5846 } => {
5847 has_ambiguous_bare_ref(e, columns)
5848 || has_ambiguous_bare_ref(low, columns)
5849 || has_ambiguous_bare_ref(high, columns)
5850 }
5851 Expr::InList { expr: e, list, .. } => {
5852 has_ambiguous_bare_ref(e, columns)
5853 || list.iter().any(|a| has_ambiguous_bare_ref(a, columns))
5854 }
5855 Expr::Like {
5856 expr: e,
5857 pattern,
5858 escape,
5859 ..
5860 } => {
5861 has_ambiguous_bare_ref(e, columns)
5862 || has_ambiguous_bare_ref(pattern, columns)
5863 || escape
5864 .as_ref()
5865 .is_some_and(|esc| has_ambiguous_bare_ref(esc, columns))
5866 }
5867 Expr::Cast { expr: inner, .. } => has_ambiguous_bare_ref(inner, columns),
5868 Expr::Case {
5869 operand,
5870 conditions,
5871 else_result,
5872 } => {
5873 operand
5874 .as_ref()
5875 .is_some_and(|o| has_ambiguous_bare_ref(o, columns))
5876 || conditions.iter().any(|(w, t)| {
5877 has_ambiguous_bare_ref(w, columns) || has_ambiguous_bare_ref(t, columns)
5878 })
5879 || else_result
5880 .as_ref()
5881 .is_some_and(|e| has_ambiguous_bare_ref(e, columns))
5882 }
5883 _ => false,
5884 }
5885}
5886
5887struct JoinColumnPlan {
5888 per_table: Vec<Vec<usize>>,
5889 output_combined: Vec<usize>,
5890}
5891
5892fn compute_join_needed_columns(
5893 stmt: &SelectStmt,
5894 tables: &[(String, &TableSchema)],
5895) -> Option<JoinColumnPlan> {
5896 for sel in &stmt.columns {
5897 if matches!(sel, SelectColumn::AllColumns) {
5898 return None;
5899 }
5900 }
5901
5902 let combined_cols = build_joined_columns(tables);
5903
5904 for sel in &stmt.columns {
5905 if let SelectColumn::Expr { expr, .. } = sel {
5906 if has_ambiguous_bare_ref(expr, &combined_cols) {
5907 return None;
5908 }
5909 }
5910 }
5911
5912 let mut output_combined: Vec<usize> = Vec::new();
5913 for sel in &stmt.columns {
5914 if let SelectColumn::Expr { expr, .. } = sel {
5915 output_combined.extend(referenced_columns(expr, &combined_cols));
5916 }
5917 }
5918 if let Some(w) = &stmt.where_clause {
5919 output_combined.extend(referenced_columns(w, &combined_cols));
5920 }
5921 for ob in &stmt.order_by {
5922 output_combined.extend(referenced_columns(&ob.expr, &combined_cols));
5923 }
5924 for gb in &stmt.group_by {
5925 output_combined.extend(referenced_columns(gb, &combined_cols));
5926 }
5927 if let Some(h) = &stmt.having {
5928 output_combined.extend(referenced_columns(h, &combined_cols));
5929 }
5930 output_combined.sort_unstable();
5931 output_combined.dedup();
5932
5933 let mut needed_combined = output_combined.clone();
5934 for join in &stmt.joins {
5935 if let Some(on_expr) = &join.on_clause {
5936 needed_combined.extend(referenced_columns(on_expr, &combined_cols));
5937 }
5938 }
5939 needed_combined.sort_unstable();
5940 needed_combined.dedup();
5941
5942 let mut offsets = Vec::with_capacity(tables.len() + 1);
5943 offsets.push(0usize);
5944 for (_, s) in tables {
5945 offsets.push(offsets.last().unwrap() + s.columns.len());
5946 }
5947
5948 let mut per_table: Vec<Vec<usize>> = tables.iter().map(|_| Vec::new()).collect();
5949 for &ci in &needed_combined {
5950 for (t, _) in tables.iter().enumerate() {
5951 let start = offsets[t];
5952 let end = offsets[t + 1];
5953 if ci >= start && ci < end {
5954 per_table[t].push(ci - start);
5955 break;
5956 }
5957 }
5958 }
5959
5960 Some(JoinColumnPlan {
5961 per_table,
5962 output_combined,
5963 })
5964}
5965
5966fn collect_rows_partial(
5967 rtx: &mut citadel_txn::read_txn::ReadTxn<'_>,
5968 table_schema: &TableSchema,
5969 needed: &[usize],
5970) -> Result<Vec<Vec<Value>>> {
5971 if needed.is_empty() || needed.len() == table_schema.columns.len() {
5972 return collect_all_rows_raw(rtx, table_schema);
5973 }
5974 let ctx = PartialDecodeCtx::new(table_schema, needed);
5975 let lower_name = &table_schema.name;
5976 let entry_count = rtx.table_entry_count(lower_name.as_bytes()).unwrap_or(0) as usize;
5977 let mut rows = Vec::with_capacity(entry_count);
5978 let mut scan_err: Option<SqlError> = None;
5979 rtx.table_scan_raw(lower_name.as_bytes(), |key, value| {
5980 match ctx.decode(key, value) {
5981 Ok(row) => rows.push(row),
5982 Err(e) => {
5983 scan_err = Some(e);
5984 return false;
5985 }
5986 }
5987 true
5988 })
5989 .map_err(SqlError::Storage)?;
5990 if let Some(e) = scan_err {
5991 return Err(e);
5992 }
5993 Ok(rows)
5994}
5995
5996fn collect_rows_partial_write(
5997 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
5998 table_schema: &TableSchema,
5999 needed: &[usize],
6000) -> Result<Vec<Vec<Value>>> {
6001 if needed.is_empty() || needed.len() == table_schema.columns.len() {
6002 return collect_all_rows_write(wtx, table_schema);
6003 }
6004 let ctx = PartialDecodeCtx::new(table_schema, needed);
6005 let lower_name = &table_schema.name;
6006 let entry_count = wtx.table_entry_count(lower_name.as_bytes()).unwrap_or(0) as usize;
6007 let mut rows = Vec::with_capacity(entry_count);
6008 let mut scan_err: Option<SqlError> = None;
6009 wtx.table_scan_from(lower_name.as_bytes(), b"", |key, value| {
6010 match ctx.decode(key, value) {
6011 Ok(row) => rows.push(row),
6012 Err(e) => {
6013 scan_err = Some(e);
6014 return Ok(false);
6015 }
6016 }
6017 Ok(true)
6018 })
6019 .map_err(SqlError::Storage)?;
6020 if let Some(e) = scan_err {
6021 return Err(e);
6022 }
6023 Ok(rows)
6024}
6025
6026fn exec_select_join(
6027 db: &Database,
6028 schema: &SchemaManager,
6029 stmt: &SelectStmt,
6030) -> Result<ExecutionResult> {
6031 let from_schema = resolve_table_name(schema, &stmt.from)?;
6032 let from_alias = table_alias_or_name(&stmt.from, &stmt.from_alias);
6033
6034 let mut all_tables: Vec<(String, &TableSchema)> = vec![(from_alias.clone(), from_schema)];
6035 for join in &stmt.joins {
6036 let inner_schema = resolve_table_name(schema, &join.table.name)?;
6037 let inner_alias = table_alias_or_name(&join.table.name, &join.table.alias);
6038 all_tables.push((inner_alias, inner_schema));
6039 }
6040 let (needed_per_table, output_combined) = match compute_join_needed_columns(stmt, &all_tables) {
6041 Some(plan) => (Some(plan.per_table), Some(plan.output_combined)),
6042 None => (None, None),
6043 };
6044
6045 let mut rtx = db.begin_read();
6046 let mut outer_rows = match &needed_per_table {
6047 Some(n) if !n.is_empty() => collect_rows_partial(&mut rtx, from_schema, &n[0])?,
6048 _ => collect_all_rows_raw(&mut rtx, from_schema)?,
6049 };
6050
6051 let mut tables: Vec<(String, &TableSchema)> = vec![(from_alias.clone(), from_schema)];
6052 let mut cur_outer_pk_col: Option<usize> = if from_schema.primary_key_columns.len() == 1 {
6053 Some(from_schema.primary_key_columns[0] as usize)
6054 } else {
6055 None
6056 };
6057
6058 let num_joins = stmt.joins.len();
6059 let mut last_combined_cols: Option<Vec<ColumnDef>> = None;
6060 for (ji, join) in stmt.joins.iter().enumerate() {
6061 let inner_schema = resolve_table_name(schema, &join.table.name)?;
6062 let inner_alias = table_alias_or_name(&join.table.name, &join.table.alias);
6063 let inner_rows = match &needed_per_table {
6064 Some(n) if ji + 1 < n.len() => {
6065 collect_rows_partial(&mut rtx, inner_schema, &n[ji + 1])?
6066 }
6067 _ => collect_all_rows_raw(&mut rtx, inner_schema)?,
6068 };
6069
6070 let mut preview_tables = tables.clone();
6071 preview_tables.push((inner_alias.clone(), inner_schema));
6072 let combined_cols = build_joined_columns(&preview_tables);
6073
6074 let outer_col_count = if outer_rows.is_empty() {
6075 tables.iter().map(|(_, s)| s.columns.len()).sum()
6076 } else {
6077 outer_rows[0].len()
6078 };
6079 let inner_col_count = inner_schema.columns.len();
6080
6081 let is_last = ji == num_joins - 1;
6082 let proj = if is_last {
6083 output_combined
6084 .as_ref()
6085 .map(|oc| build_combine_projection(oc, outer_col_count))
6086 } else {
6087 None
6088 };
6089
6090 outer_rows = exec_join_step(
6091 outer_rows,
6092 &inner_rows,
6093 join,
6094 &combined_cols,
6095 outer_col_count,
6096 inner_col_count,
6097 cur_outer_pk_col,
6098 proj.as_ref(),
6099 );
6100 last_combined_cols = Some(combined_cols);
6101 tables.push((inner_alias, inner_schema));
6102 cur_outer_pk_col = None;
6103 }
6104 drop(rtx);
6105
6106 let joined_cols = last_combined_cols.unwrap_or_else(|| build_joined_columns(&tables));
6107 if let Some(ref oc) = output_combined {
6108 let actual_width = outer_rows.first().map_or(0, |r| r.len());
6109 if actual_width == oc.len() {
6110 let projected_cols = build_projected_columns(&joined_cols, oc);
6111 return process_select(&projected_cols, outer_rows, stmt, false);
6112 }
6113 }
6114 process_select(&joined_cols, outer_rows, stmt, false)
6115}
6116
6117fn exec_select_join_in_txn(
6118 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
6119 schema: &SchemaManager,
6120 stmt: &SelectStmt,
6121) -> Result<ExecutionResult> {
6122 let from_schema = resolve_table_name(schema, &stmt.from)?;
6123 let from_alias = table_alias_or_name(&stmt.from, &stmt.from_alias);
6124
6125 let mut all_tables: Vec<(String, &TableSchema)> = vec![(from_alias.clone(), from_schema)];
6126 for join in &stmt.joins {
6127 let inner_schema = resolve_table_name(schema, &join.table.name)?;
6128 let inner_alias = table_alias_or_name(&join.table.name, &join.table.alias);
6129 all_tables.push((inner_alias, inner_schema));
6130 }
6131 let (needed_per_table, output_combined) = match compute_join_needed_columns(stmt, &all_tables) {
6132 Some(plan) => (Some(plan.per_table), Some(plan.output_combined)),
6133 None => (None, None),
6134 };
6135
6136 let mut outer_rows = match &needed_per_table {
6137 Some(n) if !n.is_empty() => collect_rows_partial_write(wtx, from_schema, &n[0])?,
6138 _ => collect_all_rows_write(wtx, from_schema)?,
6139 };
6140
6141 let mut tables: Vec<(String, &TableSchema)> = vec![(from_alias.clone(), from_schema)];
6142 let mut cur_outer_pk_col: Option<usize> = if from_schema.primary_key_columns.len() == 1 {
6143 Some(from_schema.primary_key_columns[0] as usize)
6144 } else {
6145 None
6146 };
6147
6148 let num_joins = stmt.joins.len();
6149 for (ji, join) in stmt.joins.iter().enumerate() {
6150 let inner_schema = resolve_table_name(schema, &join.table.name)?;
6151 let inner_alias = table_alias_or_name(&join.table.name, &join.table.alias);
6152 let inner_rows = match &needed_per_table {
6153 Some(n) if ji + 1 < n.len() => {
6154 collect_rows_partial_write(wtx, inner_schema, &n[ji + 1])?
6155 }
6156 _ => collect_all_rows_write(wtx, inner_schema)?,
6157 };
6158
6159 let mut preview_tables = tables.clone();
6160 preview_tables.push((inner_alias.clone(), inner_schema));
6161 let combined_cols = build_joined_columns(&preview_tables);
6162
6163 let outer_col_count = if outer_rows.is_empty() {
6164 tables.iter().map(|(_, s)| s.columns.len()).sum()
6165 } else {
6166 outer_rows[0].len()
6167 };
6168 let inner_col_count = inner_schema.columns.len();
6169
6170 let is_last = ji == num_joins - 1;
6171 let proj = if is_last {
6172 output_combined
6173 .as_ref()
6174 .map(|oc| build_combine_projection(oc, outer_col_count))
6175 } else {
6176 None
6177 };
6178
6179 outer_rows = exec_join_step(
6180 outer_rows,
6181 &inner_rows,
6182 join,
6183 &combined_cols,
6184 outer_col_count,
6185 inner_col_count,
6186 cur_outer_pk_col,
6187 proj.as_ref(),
6188 );
6189 tables.push((inner_alias, inner_schema));
6190 cur_outer_pk_col = None;
6191 }
6192
6193 let joined_cols = build_joined_columns(&tables);
6194 if let Some(ref oc) = output_combined {
6195 let actual_width = outer_rows.first().map_or(0, |r| r.len());
6196 if actual_width == oc.len() {
6197 let projected_cols = build_projected_columns(&joined_cols, oc);
6198 return process_select(&projected_cols, outer_rows, stmt, false);
6199 }
6200 }
6201 process_select(&joined_cols, outer_rows, stmt, false)
6202}
6203
6204fn exec_update(
6205 db: &Database,
6206 schema: &SchemaManager,
6207 stmt: &UpdateStmt,
6208) -> Result<ExecutionResult> {
6209 let materialized;
6210 let stmt = if update_has_subquery(stmt) {
6211 materialized = materialize_update(stmt, &mut |sub| {
6212 exec_subquery_read(db, schema, sub, &HashMap::new())
6213 })?;
6214 &materialized
6215 } else {
6216 stmt
6217 };
6218
6219 let lower_name = stmt.table.to_ascii_lowercase();
6220 let table_schema = schema
6221 .get(&lower_name)
6222 .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
6223
6224 let col_map = ColumnMap::new(&table_schema.columns);
6225 let all_candidates = collect_keyed_rows_read(db, table_schema, &stmt.where_clause)?;
6226 let matching_rows: Vec<(Vec<u8>, Vec<Value>)> = all_candidates
6227 .into_iter()
6228 .filter(|(_, row)| match &stmt.where_clause {
6229 Some(where_expr) => match eval_expr(where_expr, &col_map, row) {
6230 Ok(val) => is_truthy(&val),
6231 Err(_) => false,
6232 },
6233 None => true,
6234 })
6235 .collect();
6236
6237 if matching_rows.is_empty() {
6238 return Ok(ExecutionResult::RowsAffected(0));
6239 }
6240
6241 struct UpdateChange {
6242 old_key: Vec<u8>,
6243 new_key: Vec<u8>,
6244 new_value: Vec<u8>,
6245 pk_changed: bool,
6246 old_row: Vec<Value>,
6247 new_row: Vec<Value>,
6248 }
6249
6250 let pk_indices = table_schema.pk_indices();
6251 let mut changes: Vec<UpdateChange> = Vec::new();
6252
6253 for (old_key, row) in &matching_rows {
6254 let mut new_row = row.clone();
6255 let mut pk_changed = false;
6256
6257 let mut evaluated: Vec<(usize, Value)> = Vec::with_capacity(stmt.assignments.len());
6259 for (col_name, expr) in &stmt.assignments {
6260 let col_idx = table_schema
6261 .column_index(col_name)
6262 .ok_or_else(|| SqlError::ColumnNotFound(col_name.clone()))?;
6263 let new_val = eval_expr(expr, &col_map, row)?;
6264 let col = &table_schema.columns[col_idx];
6265
6266 let got_type = new_val.data_type();
6267 let coerced = if new_val.is_null() {
6268 if !col.nullable {
6269 return Err(SqlError::NotNullViolation(col.name.clone()));
6270 }
6271 Value::Null
6272 } else {
6273 new_val
6274 .coerce_into(col.data_type)
6275 .ok_or_else(|| SqlError::TypeMismatch {
6276 expected: col.data_type.to_string(),
6277 got: got_type.to_string(),
6278 })?
6279 };
6280
6281 evaluated.push((col_idx, coerced));
6282 }
6283
6284 for (col_idx, coerced) in evaluated {
6285 if table_schema.primary_key_columns.contains(&(col_idx as u16)) {
6286 pk_changed = true;
6287 }
6288 new_row[col_idx] = coerced;
6289 }
6290
6291 if table_schema.has_checks() {
6293 for col in &table_schema.columns {
6294 if let Some(ref check) = col.check_expr {
6295 let result = eval_expr(check, &col_map, &new_row)?;
6296 if !is_truthy(&result) && !result.is_null() {
6297 let name = col.check_name.as_deref().unwrap_or(&col.name);
6298 return Err(SqlError::CheckViolation(name.to_string()));
6299 }
6300 }
6301 }
6302 for tc in &table_schema.check_constraints {
6303 let result = eval_expr(&tc.expr, &col_map, &new_row)?;
6304 if !is_truthy(&result) && !result.is_null() {
6305 let name = tc.name.as_deref().unwrap_or(&tc.sql);
6306 return Err(SqlError::CheckViolation(name.to_string()));
6307 }
6308 }
6309 }
6310
6311 let pk_values: Vec<Value> = pk_indices.iter().map(|&i| new_row[i].clone()).collect();
6312 let new_key = encode_composite_key(&pk_values);
6313
6314 let non_pk = table_schema.non_pk_indices();
6315 let enc_pos = table_schema.encoding_positions();
6316 let phys_count = table_schema.physical_non_pk_count();
6317 let mut value_values = vec![Value::Null; phys_count];
6318 for (j, &i) in non_pk.iter().enumerate() {
6319 value_values[enc_pos[j] as usize] = new_row[i].clone();
6320 }
6321 let new_value = encode_row(&value_values);
6322
6323 changes.push(UpdateChange {
6324 old_key: old_key.clone(),
6325 new_key,
6326 new_value,
6327 pk_changed,
6328 old_row: row.clone(),
6329 new_row,
6330 });
6331 }
6332
6333 {
6334 use std::collections::HashSet;
6335 let mut new_keys: HashSet<Vec<u8>> = HashSet::new();
6336 for c in &changes {
6337 if c.pk_changed && c.new_key != c.old_key && !new_keys.insert(c.new_key.clone()) {
6338 return Err(SqlError::DuplicateKey);
6339 }
6340 }
6341 }
6342
6343 let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
6344
6345 if !table_schema.foreign_keys.is_empty() {
6347 for c in &changes {
6348 for fk in &table_schema.foreign_keys {
6349 let fk_changed = fk
6350 .columns
6351 .iter()
6352 .any(|&ci| c.old_row[ci as usize] != c.new_row[ci as usize]);
6353 if !fk_changed {
6354 continue;
6355 }
6356 let any_null = fk
6357 .columns
6358 .iter()
6359 .any(|&ci| c.new_row[ci as usize].is_null());
6360 if any_null {
6361 continue;
6362 }
6363 let fk_vals: Vec<Value> = fk
6364 .columns
6365 .iter()
6366 .map(|&ci| c.new_row[ci as usize].clone())
6367 .collect();
6368 let fk_key = encode_composite_key(&fk_vals);
6369 let found = wtx
6370 .table_get(fk.foreign_table.as_bytes(), &fk_key)
6371 .map_err(SqlError::Storage)?;
6372 if found.is_none() {
6373 let name = fk.name.as_deref().unwrap_or(&fk.foreign_table);
6374 return Err(SqlError::ForeignKeyViolation(name.to_string()));
6375 }
6376 }
6377 }
6378 }
6379
6380 let child_fks = schema.child_fks_for(&lower_name);
6382 if !child_fks.is_empty() {
6383 for c in &changes {
6384 if !c.pk_changed {
6385 continue;
6386 }
6387 let old_pk: Vec<Value> = pk_indices.iter().map(|&i| c.old_row[i].clone()).collect();
6388 let old_pk_key = encode_composite_key(&old_pk);
6389 for &(child_table, fk) in &child_fks {
6390 let child_schema = schema.get(child_table).unwrap();
6391 let fk_idx = child_schema
6392 .indices
6393 .iter()
6394 .find(|idx| idx.columns == fk.columns);
6395 if let Some(idx) = fk_idx {
6396 let idx_table = TableSchema::index_table_name(child_table, &idx.name);
6397 let mut has_child = false;
6398 wtx.table_scan_from(&idx_table, &old_pk_key, |key, _| {
6399 if key.starts_with(&old_pk_key) {
6400 has_child = true;
6401 Ok(false) } else {
6403 Ok(false) }
6405 })
6406 .map_err(SqlError::Storage)?;
6407 if has_child {
6408 return Err(SqlError::ForeignKeyViolation(format!(
6409 "cannot update PK in '{}': referenced by '{}'",
6410 lower_name, child_table
6411 )));
6412 }
6413 }
6414 }
6415 }
6416 }
6417
6418 for c in &changes {
6419 let old_pk: Vec<Value> = pk_indices.iter().map(|&i| c.old_row[i].clone()).collect();
6420
6421 for idx in &table_schema.indices {
6422 if index_columns_changed(idx, &c.old_row, &c.new_row) || c.pk_changed {
6423 let idx_table = TableSchema::index_table_name(&lower_name, &idx.name);
6424 let old_idx_key = encode_index_key(idx, &c.old_row, &old_pk);
6425 wtx.table_delete(&idx_table, &old_idx_key)
6426 .map_err(SqlError::Storage)?;
6427 }
6428 }
6429
6430 if c.pk_changed {
6431 wtx.table_delete(lower_name.as_bytes(), &c.old_key)
6432 .map_err(SqlError::Storage)?;
6433 }
6434 }
6435
6436 for c in &changes {
6437 let new_pk: Vec<Value> = pk_indices.iter().map(|&i| c.new_row[i].clone()).collect();
6438
6439 if c.pk_changed {
6440 let is_new = wtx
6441 .table_insert(lower_name.as_bytes(), &c.new_key, &c.new_value)
6442 .map_err(SqlError::Storage)?;
6443 if !is_new {
6444 return Err(SqlError::DuplicateKey);
6445 }
6446 } else {
6447 wtx.table_insert(lower_name.as_bytes(), &c.new_key, &c.new_value)
6448 .map_err(SqlError::Storage)?;
6449 }
6450
6451 for idx in &table_schema.indices {
6452 if index_columns_changed(idx, &c.old_row, &c.new_row) || c.pk_changed {
6453 let idx_table = TableSchema::index_table_name(&lower_name, &idx.name);
6454 let new_idx_key = encode_index_key(idx, &c.new_row, &new_pk);
6455 let new_idx_val = encode_index_value(idx, &c.new_row, &new_pk);
6456 let is_new = wtx
6457 .table_insert(&idx_table, &new_idx_key, &new_idx_val)
6458 .map_err(SqlError::Storage)?;
6459 if idx.unique && !is_new {
6460 let indexed_values: Vec<Value> = idx
6461 .columns
6462 .iter()
6463 .map(|&col_idx| c.new_row[col_idx as usize].clone())
6464 .collect();
6465 let any_null = indexed_values.iter().any(|v| v.is_null());
6466 if !any_null {
6467 return Err(SqlError::UniqueViolation(idx.name.clone()));
6468 }
6469 }
6470 }
6471 }
6472 }
6473
6474 let count = changes.len() as u64;
6475 wtx.commit().map_err(SqlError::Storage)?;
6476 Ok(ExecutionResult::RowsAffected(count))
6477}
6478
6479fn exec_delete(
6480 db: &Database,
6481 schema: &SchemaManager,
6482 stmt: &DeleteStmt,
6483) -> Result<ExecutionResult> {
6484 let materialized;
6485 let stmt = if delete_has_subquery(stmt) {
6486 materialized = materialize_delete(stmt, &mut |sub| {
6487 exec_subquery_read(db, schema, sub, &HashMap::new())
6488 })?;
6489 &materialized
6490 } else {
6491 stmt
6492 };
6493
6494 let lower_name = stmt.table.to_ascii_lowercase();
6495 let table_schema = schema
6496 .get(&lower_name)
6497 .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
6498
6499 let col_map = ColumnMap::new(&table_schema.columns);
6500 let all_candidates = collect_keyed_rows_read(db, table_schema, &stmt.where_clause)?;
6501 let rows_to_delete: Vec<(Vec<u8>, Vec<Value>)> = all_candidates
6502 .into_iter()
6503 .filter(|(_, row)| match &stmt.where_clause {
6504 Some(where_expr) => match eval_expr(where_expr, &col_map, row) {
6505 Ok(val) => is_truthy(&val),
6506 Err(_) => false,
6507 },
6508 None => true,
6509 })
6510 .collect();
6511
6512 if rows_to_delete.is_empty() {
6513 return Ok(ExecutionResult::RowsAffected(0));
6514 }
6515
6516 let pk_indices = table_schema.pk_indices();
6517 let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
6518
6519 let child_fks = schema.child_fks_for(&lower_name);
6521 if !child_fks.is_empty() {
6522 for (_key, row) in &rows_to_delete {
6523 let pk_values: Vec<Value> = pk_indices.iter().map(|&i| row[i].clone()).collect();
6524 let pk_key = encode_composite_key(&pk_values);
6525 for &(child_table, fk) in &child_fks {
6526 let child_schema = schema.get(child_table).unwrap();
6527 let fk_idx = child_schema
6528 .indices
6529 .iter()
6530 .find(|idx| idx.columns == fk.columns);
6531 if let Some(idx) = fk_idx {
6532 let idx_table = TableSchema::index_table_name(child_table, &idx.name);
6533 let mut has_child = false;
6534 wtx.table_scan_from(&idx_table, &pk_key, |key, _| {
6535 if key.starts_with(&pk_key) {
6536 has_child = true;
6537 Ok(false)
6538 } else {
6539 Ok(false)
6540 }
6541 })
6542 .map_err(SqlError::Storage)?;
6543 if has_child {
6544 return Err(SqlError::ForeignKeyViolation(format!(
6545 "cannot delete from '{}': referenced by '{}'",
6546 lower_name, child_table
6547 )));
6548 }
6549 }
6550 }
6551 }
6552 }
6553
6554 for (key, row) in &rows_to_delete {
6555 let pk_values: Vec<Value> = pk_indices.iter().map(|&i| row[i].clone()).collect();
6556 delete_index_entries(&mut wtx, table_schema, row, &pk_values)?;
6557 wtx.table_delete(lower_name.as_bytes(), key)
6558 .map_err(SqlError::Storage)?;
6559 }
6560 let count = rows_to_delete.len() as u64;
6561 wtx.commit().map_err(SqlError::Storage)?;
6562 Ok(ExecutionResult::RowsAffected(count))
6563}
6564
6565#[derive(Default)]
6566pub struct InsertBufs {
6567 row: Vec<Value>,
6568 pk_values: Vec<Value>,
6569 value_values: Vec<Value>,
6570 key_buf: Vec<u8>,
6571 value_buf: Vec<u8>,
6572 col_indices: Vec<usize>,
6573 fk_key_buf: Vec<u8>,
6574}
6575
6576impl InsertBufs {
6577 pub fn new() -> Self {
6578 Self {
6579 row: Vec::new(),
6580 pk_values: Vec::new(),
6581 value_values: Vec::new(),
6582 key_buf: Vec::with_capacity(64),
6583 value_buf: Vec::with_capacity(256),
6584 col_indices: Vec::new(),
6585 fk_key_buf: Vec::with_capacity(64),
6586 }
6587 }
6588}
6589
6590pub fn exec_insert_in_txn(
6591 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
6592 schema: &SchemaManager,
6593 stmt: &InsertStmt,
6594 params: &[Value],
6595 bufs: &mut InsertBufs,
6596) -> Result<ExecutionResult> {
6597 let empty_ctes = CteContext::new();
6598 let materialized;
6599 let stmt = if insert_has_subquery(stmt) {
6600 materialized = materialize_insert(stmt, &mut |sub| {
6601 exec_subquery_write(wtx, schema, sub, &empty_ctes)
6602 })?;
6603 &materialized
6604 } else {
6605 stmt
6606 };
6607
6608 let table_schema = schema
6609 .get(&stmt.table)
6610 .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
6611
6612 let default_columns;
6613 let insert_columns: &[String] = if stmt.columns.is_empty() {
6614 default_columns = table_schema
6615 .columns
6616 .iter()
6617 .map(|c| c.name.clone())
6618 .collect::<Vec<_>>();
6619 &default_columns
6620 } else {
6621 &stmt.columns
6622 };
6623
6624 bufs.col_indices.clear();
6625 for name in insert_columns {
6626 bufs.col_indices.push(
6627 table_schema
6628 .column_index(name)
6629 .ok_or_else(|| SqlError::ColumnNotFound(name.clone()))?,
6630 );
6631 }
6632
6633 let defaults: Vec<(usize, &Expr)> = table_schema
6635 .columns
6636 .iter()
6637 .filter(|c| c.default_expr.is_some() && !bufs.col_indices.contains(&(c.position as usize)))
6638 .map(|c| (c.position as usize, c.default_expr.as_ref().unwrap()))
6639 .collect();
6640
6641 let has_checks = table_schema.has_checks();
6642 let check_col_map = if has_checks {
6643 Some(ColumnMap::new(&table_schema.columns))
6644 } else {
6645 None
6646 };
6647
6648 let pk_indices = table_schema.pk_indices();
6649 let non_pk = table_schema.non_pk_indices();
6650 let enc_pos = table_schema.encoding_positions();
6651 let phys_count = table_schema.physical_non_pk_count();
6652 let dropped = table_schema.dropped_non_pk_slots();
6653
6654 bufs.row.resize(table_schema.columns.len(), Value::Null);
6655 bufs.pk_values.resize(pk_indices.len(), Value::Null);
6656 bufs.value_values.resize(phys_count, Value::Null);
6657
6658 let select_rows = match &stmt.source {
6659 InsertSource::Select(sq) => {
6660 let insert_ctes = materialize_all_ctes(&sq.ctes, sq.recursive, &mut |body, ctx| {
6661 exec_query_body_write(wtx, schema, body, ctx)
6662 })?;
6663 let qr = exec_query_body_write(wtx, schema, &sq.body, &insert_ctes)?;
6664 Some(qr.rows)
6665 }
6666 InsertSource::Values(_) => None,
6667 };
6668
6669 let mut count: u64 = 0;
6670
6671 let values = match &stmt.source {
6672 InsertSource::Values(rows) => Some(rows.as_slice()),
6673 InsertSource::Select(_) => None,
6674 };
6675 let sel_rows = select_rows.as_deref();
6676
6677 let total = match (values, sel_rows) {
6678 (Some(rows), _) => rows.len(),
6679 (_, Some(rows)) => rows.len(),
6680 _ => 0,
6681 };
6682
6683 if let Some(sel) = sel_rows {
6684 if !sel.is_empty() && sel[0].len() != insert_columns.len() {
6685 return Err(SqlError::InvalidValue(format!(
6686 "INSERT ... SELECT column count mismatch: expected {}, got {}",
6687 insert_columns.len(),
6688 sel[0].len()
6689 )));
6690 }
6691 }
6692
6693 for idx in 0..total {
6694 for v in bufs.row.iter_mut() {
6695 *v = Value::Null;
6696 }
6697
6698 if let Some(value_rows) = values {
6699 let value_row = &value_rows[idx];
6700 if value_row.len() != insert_columns.len() {
6701 return Err(SqlError::InvalidValue(format!(
6702 "expected {} values, got {}",
6703 insert_columns.len(),
6704 value_row.len()
6705 )));
6706 }
6707 for (i, expr) in value_row.iter().enumerate() {
6708 let val = if let Expr::Parameter(n) = expr {
6709 params
6710 .get(n - 1)
6711 .cloned()
6712 .ok_or_else(|| SqlError::Parse(format!("unbound parameter ${n}")))?
6713 } else {
6714 eval_const_expr(expr)?
6715 };
6716 let col_idx = bufs.col_indices[i];
6717 let col = &table_schema.columns[col_idx];
6718 let got_type = val.data_type();
6719 bufs.row[col_idx] = if val.is_null() {
6720 Value::Null
6721 } else {
6722 val.coerce_into(col.data_type)
6723 .ok_or_else(|| SqlError::TypeMismatch {
6724 expected: col.data_type.to_string(),
6725 got: got_type.to_string(),
6726 })?
6727 };
6728 }
6729 } else if let Some(sel) = sel_rows {
6730 let sel_row = &sel[idx];
6731 for (i, val) in sel_row.iter().enumerate() {
6732 let col_idx = bufs.col_indices[i];
6733 let col = &table_schema.columns[col_idx];
6734 let got_type = val.data_type();
6735 bufs.row[col_idx] = if val.is_null() {
6736 Value::Null
6737 } else {
6738 val.clone().coerce_into(col.data_type).ok_or_else(|| {
6739 SqlError::TypeMismatch {
6740 expected: col.data_type.to_string(),
6741 got: got_type.to_string(),
6742 }
6743 })?
6744 };
6745 }
6746 }
6747
6748 for &(pos, def_expr) in &defaults {
6750 let val = eval_const_expr(def_expr)?;
6751 let col = &table_schema.columns[pos];
6752 if val.is_null() {
6753 } else {
6755 let got_type = val.data_type();
6756 bufs.row[pos] =
6757 val.coerce_into(col.data_type)
6758 .ok_or_else(|| SqlError::TypeMismatch {
6759 expected: col.data_type.to_string(),
6760 got: got_type.to_string(),
6761 })?;
6762 }
6763 }
6764
6765 for col in &table_schema.columns {
6766 if !col.nullable && bufs.row[col.position as usize].is_null() {
6767 return Err(SqlError::NotNullViolation(col.name.clone()));
6768 }
6769 }
6770
6771 if let Some(ref col_map) = check_col_map {
6773 for col in &table_schema.columns {
6774 if let Some(ref check) = col.check_expr {
6775 let result = eval_expr(check, col_map, &bufs.row)?;
6776 if !is_truthy(&result) && !result.is_null() {
6777 let name = col.check_name.as_deref().unwrap_or(&col.name);
6778 return Err(SqlError::CheckViolation(name.to_string()));
6779 }
6780 }
6781 }
6782 for tc in &table_schema.check_constraints {
6783 let result = eval_expr(&tc.expr, col_map, &bufs.row)?;
6784 if !is_truthy(&result) && !result.is_null() {
6785 let name = tc.name.as_deref().unwrap_or(&tc.sql);
6786 return Err(SqlError::CheckViolation(name.to_string()));
6787 }
6788 }
6789 }
6790
6791 for fk in &table_schema.foreign_keys {
6793 let any_null = fk.columns.iter().any(|&ci| bufs.row[ci as usize].is_null());
6794 if any_null {
6795 continue;
6796 }
6797 let fk_vals: Vec<Value> = fk
6798 .columns
6799 .iter()
6800 .map(|&ci| bufs.row[ci as usize].clone())
6801 .collect();
6802 bufs.fk_key_buf.clear();
6803 encode_composite_key_into(&fk_vals, &mut bufs.fk_key_buf);
6804 let found = wtx
6805 .table_get(fk.foreign_table.as_bytes(), &bufs.fk_key_buf)
6806 .map_err(SqlError::Storage)?;
6807 if found.is_none() {
6808 let name = fk.name.as_deref().unwrap_or(&fk.foreign_table);
6809 return Err(SqlError::ForeignKeyViolation(name.to_string()));
6810 }
6811 }
6812
6813 for (j, &i) in pk_indices.iter().enumerate() {
6814 bufs.pk_values[j] = std::mem::replace(&mut bufs.row[i], Value::Null);
6815 }
6816 encode_composite_key_into(&bufs.pk_values, &mut bufs.key_buf);
6817
6818 for &slot in dropped {
6819 bufs.value_values[slot as usize] = Value::Null;
6820 }
6821 for (j, &i) in non_pk.iter().enumerate() {
6822 bufs.value_values[enc_pos[j] as usize] =
6823 std::mem::replace(&mut bufs.row[i], Value::Null);
6824 }
6825 encode_row_into(&bufs.value_values, &mut bufs.value_buf);
6826
6827 if bufs.key_buf.len() > citadel_core::MAX_KEY_SIZE {
6828 return Err(SqlError::KeyTooLarge {
6829 size: bufs.key_buf.len(),
6830 max: citadel_core::MAX_KEY_SIZE,
6831 });
6832 }
6833 if bufs.value_buf.len() > citadel_core::MAX_INLINE_VALUE_SIZE {
6834 return Err(SqlError::RowTooLarge {
6835 size: bufs.value_buf.len(),
6836 max: citadel_core::MAX_INLINE_VALUE_SIZE,
6837 });
6838 }
6839
6840 let is_new = wtx
6841 .table_insert(stmt.table.as_bytes(), &bufs.key_buf, &bufs.value_buf)
6842 .map_err(SqlError::Storage)?;
6843 if !is_new {
6844 return Err(SqlError::DuplicateKey);
6845 }
6846
6847 if !table_schema.indices.is_empty() {
6848 for (j, &i) in pk_indices.iter().enumerate() {
6849 bufs.row[i] = bufs.pk_values[j].clone();
6850 }
6851 for (j, &i) in non_pk.iter().enumerate() {
6852 bufs.row[i] =
6853 std::mem::replace(&mut bufs.value_values[enc_pos[j] as usize], Value::Null);
6854 }
6855 insert_index_entries(wtx, table_schema, &bufs.row, &bufs.pk_values)?;
6856 }
6857 count += 1;
6858 }
6859
6860 Ok(ExecutionResult::RowsAffected(count))
6861}
6862
6863fn exec_select_in_txn(
6864 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
6865 schema: &SchemaManager,
6866 stmt: &SelectStmt,
6867 ctes: &CteContext,
6868) -> Result<ExecutionResult> {
6869 let materialized;
6870 let stmt = if stmt_has_subquery(stmt) {
6871 materialized =
6872 materialize_stmt(stmt, &mut |sub| exec_subquery_write(wtx, schema, sub, ctes))?;
6873 &materialized
6874 } else {
6875 stmt
6876 };
6877
6878 if stmt.from.is_empty() {
6879 return exec_select_no_from(stmt);
6880 }
6881
6882 let lower_name = stmt.from.to_ascii_lowercase();
6883
6884 if let Some(cte_result) = ctes.get(&lower_name) {
6885 if stmt.joins.is_empty() {
6886 return exec_select_from_cte(cte_result, stmt, &mut |sub| {
6887 exec_subquery_write(wtx, schema, sub, ctes)
6888 });
6889 } else {
6890 return exec_select_join_with_ctes(stmt, ctes, &mut |name| {
6891 scan_table_write(wtx, schema, name)
6892 });
6893 }
6894 }
6895
6896 if !ctes.is_empty()
6897 && stmt
6898 .joins
6899 .iter()
6900 .any(|j| ctes.contains_key(&j.table.name.to_ascii_lowercase()))
6901 {
6902 return exec_select_join_with_ctes(stmt, ctes, &mut |name| {
6903 scan_table_write(wtx, schema, name)
6904 });
6905 }
6906
6907 if !stmt.joins.is_empty() {
6908 return exec_select_join_in_txn(wtx, schema, stmt);
6909 }
6910
6911 let lower_name = stmt.from.to_ascii_lowercase();
6912 let table_schema = schema
6913 .get(&lower_name)
6914 .ok_or_else(|| SqlError::TableNotFound(stmt.from.clone()))?;
6915
6916 if let Some(result) = try_count_star_shortcut(stmt, || {
6917 wtx.table_entry_count(lower_name.as_bytes())
6918 .map_err(SqlError::Storage)
6919 })? {
6920 return Ok(result);
6921 }
6922
6923 if let Some(plan) = StreamAggPlan::try_new(stmt, table_schema)? {
6924 let mut states: Vec<AggState> = plan.ops.iter().map(|(op, _)| AggState::new(op)).collect();
6925 let mut scan_err: Option<SqlError> = None;
6926 if stmt.where_clause.is_none() {
6927 wtx.table_scan_from(lower_name.as_bytes(), b"", |key, value| {
6928 Ok(plan.feed_row_raw(key, value, &mut states, &mut scan_err))
6929 })
6930 .map_err(SqlError::Storage)?;
6931 } else {
6932 let col_map = ColumnMap::new(&table_schema.columns);
6933 wtx.table_scan_from(lower_name.as_bytes(), b"", |key, value| {
6934 Ok(plan.feed_row(
6935 key,
6936 value,
6937 table_schema,
6938 &col_map,
6939 &stmt.where_clause,
6940 &mut states,
6941 &mut scan_err,
6942 ))
6943 })
6944 .map_err(SqlError::Storage)?;
6945 }
6946 if let Some(e) = scan_err {
6947 return Err(e);
6948 }
6949 return Ok(plan.finish(states));
6950 }
6951
6952 if let Some(plan) = StreamGroupByPlan::try_new(stmt, table_schema)? {
6953 let lower = lower_name.clone();
6954 return plan.execute_scan(|cb| {
6955 wtx.table_scan_from(lower.as_bytes(), b"", |key, value| Ok(cb(key, value)))
6956 });
6957 }
6958
6959 if let Some(plan) = TopKScanPlan::try_new(stmt, table_schema)? {
6960 let lower = lower_name.clone();
6961 return plan.execute_scan(table_schema, stmt, |cb| {
6962 wtx.table_scan_from(lower.as_bytes(), b"", |key, value| Ok(cb(key, value)))
6963 });
6964 }
6965
6966 let scan_limit = compute_scan_limit(stmt);
6967 let (rows, predicate_applied) =
6968 collect_rows_write(wtx, table_schema, &stmt.where_clause, scan_limit)?;
6969 process_select(&table_schema.columns, rows, stmt, predicate_applied)
6970}
6971
6972fn exec_update_in_txn(
6973 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
6974 schema: &SchemaManager,
6975 stmt: &UpdateStmt,
6976) -> Result<ExecutionResult> {
6977 let materialized;
6978 let stmt = if update_has_subquery(stmt) {
6979 materialized = materialize_update(stmt, &mut |sub| {
6980 exec_subquery_write(wtx, schema, sub, &HashMap::new())
6981 })?;
6982 &materialized
6983 } else {
6984 stmt
6985 };
6986
6987 let lower_name = stmt.table.to_ascii_lowercase();
6988 let table_schema = schema
6989 .get(&lower_name)
6990 .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
6991
6992 let col_map = ColumnMap::new(&table_schema.columns);
6993 let all_candidates = collect_keyed_rows_write(wtx, table_schema, &stmt.where_clause)?;
6994 let matching_rows: Vec<(Vec<u8>, Vec<Value>)> = all_candidates
6995 .into_iter()
6996 .filter(|(_, row)| match &stmt.where_clause {
6997 Some(where_expr) => match eval_expr(where_expr, &col_map, row) {
6998 Ok(val) => is_truthy(&val),
6999 Err(_) => false,
7000 },
7001 None => true,
7002 })
7003 .collect();
7004
7005 if matching_rows.is_empty() {
7006 return Ok(ExecutionResult::RowsAffected(0));
7007 }
7008
7009 struct UpdateChange {
7010 old_key: Vec<u8>,
7011 new_key: Vec<u8>,
7012 new_value: Vec<u8>,
7013 pk_changed: bool,
7014 old_row: Vec<Value>,
7015 new_row: Vec<Value>,
7016 }
7017
7018 let pk_indices = table_schema.pk_indices();
7019 let mut changes: Vec<UpdateChange> = Vec::new();
7020
7021 for (old_key, row) in &matching_rows {
7022 let mut new_row = row.clone();
7023 let mut pk_changed = false;
7024
7025 let mut evaluated: Vec<(usize, Value)> = Vec::with_capacity(stmt.assignments.len());
7027 for (col_name, expr) in &stmt.assignments {
7028 let col_idx = table_schema
7029 .column_index(col_name)
7030 .ok_or_else(|| SqlError::ColumnNotFound(col_name.clone()))?;
7031 let new_val = eval_expr(expr, &col_map, row)?;
7032 let col = &table_schema.columns[col_idx];
7033
7034 let got_type = new_val.data_type();
7035 let coerced = if new_val.is_null() {
7036 if !col.nullable {
7037 return Err(SqlError::NotNullViolation(col.name.clone()));
7038 }
7039 Value::Null
7040 } else {
7041 new_val
7042 .coerce_into(col.data_type)
7043 .ok_or_else(|| SqlError::TypeMismatch {
7044 expected: col.data_type.to_string(),
7045 got: got_type.to_string(),
7046 })?
7047 };
7048
7049 evaluated.push((col_idx, coerced));
7050 }
7051
7052 for (col_idx, coerced) in evaluated {
7053 if table_schema.primary_key_columns.contains(&(col_idx as u16)) {
7054 pk_changed = true;
7055 }
7056 new_row[col_idx] = coerced;
7057 }
7058
7059 if table_schema.has_checks() {
7061 for col in &table_schema.columns {
7062 if let Some(ref check) = col.check_expr {
7063 let result = eval_expr(check, &col_map, &new_row)?;
7064 if !is_truthy(&result) && !result.is_null() {
7065 let name = col.check_name.as_deref().unwrap_or(&col.name);
7066 return Err(SqlError::CheckViolation(name.to_string()));
7067 }
7068 }
7069 }
7070 for tc in &table_schema.check_constraints {
7071 let result = eval_expr(&tc.expr, &col_map, &new_row)?;
7072 if !is_truthy(&result) && !result.is_null() {
7073 let name = tc.name.as_deref().unwrap_or(&tc.sql);
7074 return Err(SqlError::CheckViolation(name.to_string()));
7075 }
7076 }
7077 }
7078
7079 let pk_values: Vec<Value> = pk_indices.iter().map(|&i| new_row[i].clone()).collect();
7080 let new_key = encode_composite_key(&pk_values);
7081
7082 let non_pk = table_schema.non_pk_indices();
7083 let enc_pos = table_schema.encoding_positions();
7084 let phys_count = table_schema.physical_non_pk_count();
7085 let mut value_values = vec![Value::Null; phys_count];
7086 for (j, &i) in non_pk.iter().enumerate() {
7087 value_values[enc_pos[j] as usize] = new_row[i].clone();
7088 }
7089 let new_value = encode_row(&value_values);
7090
7091 changes.push(UpdateChange {
7092 old_key: old_key.clone(),
7093 new_key,
7094 new_value,
7095 pk_changed,
7096 old_row: row.clone(),
7097 new_row,
7098 });
7099 }
7100
7101 {
7102 use std::collections::HashSet;
7103 let mut new_keys: HashSet<Vec<u8>> = HashSet::new();
7104 for c in &changes {
7105 if c.pk_changed && c.new_key != c.old_key && !new_keys.insert(c.new_key.clone()) {
7106 return Err(SqlError::DuplicateKey);
7107 }
7108 }
7109 }
7110
7111 if !table_schema.foreign_keys.is_empty() {
7113 for c in &changes {
7114 for fk in &table_schema.foreign_keys {
7115 let fk_changed = fk
7116 .columns
7117 .iter()
7118 .any(|&ci| c.old_row[ci as usize] != c.new_row[ci as usize]);
7119 if !fk_changed {
7120 continue;
7121 }
7122 let any_null = fk
7123 .columns
7124 .iter()
7125 .any(|&ci| c.new_row[ci as usize].is_null());
7126 if any_null {
7127 continue;
7128 }
7129 let fk_vals: Vec<Value> = fk
7130 .columns
7131 .iter()
7132 .map(|&ci| c.new_row[ci as usize].clone())
7133 .collect();
7134 let fk_key = encode_composite_key(&fk_vals);
7135 let found = wtx
7136 .table_get(fk.foreign_table.as_bytes(), &fk_key)
7137 .map_err(SqlError::Storage)?;
7138 if found.is_none() {
7139 let name = fk.name.as_deref().unwrap_or(&fk.foreign_table);
7140 return Err(SqlError::ForeignKeyViolation(name.to_string()));
7141 }
7142 }
7143 }
7144 }
7145
7146 let child_fks = schema.child_fks_for(&lower_name);
7148 if !child_fks.is_empty() {
7149 for c in &changes {
7150 if !c.pk_changed {
7151 continue;
7152 }
7153 let old_pk: Vec<Value> = pk_indices.iter().map(|&i| c.old_row[i].clone()).collect();
7154 let old_pk_key = encode_composite_key(&old_pk);
7155 for &(child_table, fk) in &child_fks {
7156 let child_schema = schema.get(child_table).unwrap();
7157 let fk_idx = child_schema
7158 .indices
7159 .iter()
7160 .find(|idx| idx.columns == fk.columns);
7161 if let Some(idx) = fk_idx {
7162 let idx_table = TableSchema::index_table_name(child_table, &idx.name);
7163 let mut has_child = false;
7164 wtx.table_scan_from(&idx_table, &old_pk_key, |key, _| {
7165 if key.starts_with(&old_pk_key) {
7166 has_child = true;
7167 Ok(false)
7168 } else {
7169 Ok(false)
7170 }
7171 })
7172 .map_err(SqlError::Storage)?;
7173 if has_child {
7174 return Err(SqlError::ForeignKeyViolation(format!(
7175 "cannot update PK in '{}': referenced by '{}'",
7176 lower_name, child_table
7177 )));
7178 }
7179 }
7180 }
7181 }
7182 }
7183
7184 for c in &changes {
7185 let old_pk: Vec<Value> = pk_indices.iter().map(|&i| c.old_row[i].clone()).collect();
7186
7187 for idx in &table_schema.indices {
7188 if index_columns_changed(idx, &c.old_row, &c.new_row) || c.pk_changed {
7189 let idx_table = TableSchema::index_table_name(&lower_name, &idx.name);
7190 let old_idx_key = encode_index_key(idx, &c.old_row, &old_pk);
7191 wtx.table_delete(&idx_table, &old_idx_key)
7192 .map_err(SqlError::Storage)?;
7193 }
7194 }
7195
7196 if c.pk_changed {
7197 wtx.table_delete(lower_name.as_bytes(), &c.old_key)
7198 .map_err(SqlError::Storage)?;
7199 }
7200 }
7201
7202 for c in &changes {
7203 let new_pk: Vec<Value> = pk_indices.iter().map(|&i| c.new_row[i].clone()).collect();
7204
7205 if c.pk_changed {
7206 let is_new = wtx
7207 .table_insert(lower_name.as_bytes(), &c.new_key, &c.new_value)
7208 .map_err(SqlError::Storage)?;
7209 if !is_new {
7210 return Err(SqlError::DuplicateKey);
7211 }
7212 } else {
7213 wtx.table_insert(lower_name.as_bytes(), &c.new_key, &c.new_value)
7214 .map_err(SqlError::Storage)?;
7215 }
7216
7217 for idx in &table_schema.indices {
7218 if index_columns_changed(idx, &c.old_row, &c.new_row) || c.pk_changed {
7219 let idx_table = TableSchema::index_table_name(&lower_name, &idx.name);
7220 let new_idx_key = encode_index_key(idx, &c.new_row, &new_pk);
7221 let new_idx_val = encode_index_value(idx, &c.new_row, &new_pk);
7222 let is_new = wtx
7223 .table_insert(&idx_table, &new_idx_key, &new_idx_val)
7224 .map_err(SqlError::Storage)?;
7225 if idx.unique && !is_new {
7226 let indexed_values: Vec<Value> = idx
7227 .columns
7228 .iter()
7229 .map(|&col_idx| c.new_row[col_idx as usize].clone())
7230 .collect();
7231 let any_null = indexed_values.iter().any(|v| v.is_null());
7232 if !any_null {
7233 return Err(SqlError::UniqueViolation(idx.name.clone()));
7234 }
7235 }
7236 }
7237 }
7238 }
7239
7240 let count = changes.len() as u64;
7241 Ok(ExecutionResult::RowsAffected(count))
7242}
7243
7244fn exec_delete_in_txn(
7245 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
7246 schema: &SchemaManager,
7247 stmt: &DeleteStmt,
7248) -> Result<ExecutionResult> {
7249 let materialized;
7250 let stmt = if delete_has_subquery(stmt) {
7251 materialized = materialize_delete(stmt, &mut |sub| {
7252 exec_subquery_write(wtx, schema, sub, &HashMap::new())
7253 })?;
7254 &materialized
7255 } else {
7256 stmt
7257 };
7258
7259 let lower_name = stmt.table.to_ascii_lowercase();
7260 let table_schema = schema
7261 .get(&lower_name)
7262 .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
7263
7264 let col_map = ColumnMap::new(&table_schema.columns);
7265 let all_candidates = collect_keyed_rows_write(wtx, table_schema, &stmt.where_clause)?;
7266 let rows_to_delete: Vec<(Vec<u8>, Vec<Value>)> = all_candidates
7267 .into_iter()
7268 .filter(|(_, row)| match &stmt.where_clause {
7269 Some(where_expr) => match eval_expr(where_expr, &col_map, row) {
7270 Ok(val) => is_truthy(&val),
7271 Err(_) => false,
7272 },
7273 None => true,
7274 })
7275 .collect();
7276
7277 if rows_to_delete.is_empty() {
7278 return Ok(ExecutionResult::RowsAffected(0));
7279 }
7280
7281 let pk_indices = table_schema.pk_indices();
7282
7283 let child_fks = schema.child_fks_for(&lower_name);
7285 if !child_fks.is_empty() {
7286 for (_key, row) in &rows_to_delete {
7287 let pk_values: Vec<Value> = pk_indices.iter().map(|&i| row[i].clone()).collect();
7288 let pk_key = encode_composite_key(&pk_values);
7289 for &(child_table, fk) in &child_fks {
7290 let child_schema = schema.get(child_table).unwrap();
7291 let fk_idx = child_schema
7292 .indices
7293 .iter()
7294 .find(|idx| idx.columns == fk.columns);
7295 if let Some(idx) = fk_idx {
7296 let idx_table = TableSchema::index_table_name(child_table, &idx.name);
7297 let mut has_child = false;
7298 wtx.table_scan_from(&idx_table, &pk_key, |key, _| {
7299 if key.starts_with(&pk_key) {
7300 has_child = true;
7301 Ok(false)
7302 } else {
7303 Ok(false)
7304 }
7305 })
7306 .map_err(SqlError::Storage)?;
7307 if has_child {
7308 return Err(SqlError::ForeignKeyViolation(format!(
7309 "cannot delete from '{}': referenced by '{}'",
7310 lower_name, child_table
7311 )));
7312 }
7313 }
7314 }
7315 }
7316 }
7317
7318 for (key, row) in &rows_to_delete {
7319 let pk_values: Vec<Value> = pk_indices.iter().map(|&i| row[i].clone()).collect();
7320 delete_index_entries(wtx, table_schema, row, &pk_values)?;
7321 wtx.table_delete(lower_name.as_bytes(), key)
7322 .map_err(SqlError::Storage)?;
7323 }
7324 let count = rows_to_delete.len() as u64;
7325 Ok(ExecutionResult::RowsAffected(count))
7326}
7327
7328fn exec_aggregate(
7331 columns: &[ColumnDef],
7332 rows: &[Vec<Value>],
7333 stmt: &SelectStmt,
7334) -> Result<ExecutionResult> {
7335 let col_map = ColumnMap::new(columns);
7336 let groups: BTreeMap<Vec<Value>, Vec<&Vec<Value>>> = if stmt.group_by.is_empty() {
7337 let mut m = BTreeMap::new();
7338 m.insert(vec![], rows.iter().collect());
7339 m
7340 } else {
7341 let mut m: BTreeMap<Vec<Value>, Vec<&Vec<Value>>> = BTreeMap::new();
7342 for row in rows {
7343 let group_key: Vec<Value> = stmt
7344 .group_by
7345 .iter()
7346 .map(|expr| eval_expr(expr, &col_map, row))
7347 .collect::<Result<_>>()?;
7348 m.entry(group_key).or_default().push(row);
7349 }
7350 m
7351 };
7352
7353 let mut result_rows = Vec::new();
7354 let output_cols = build_output_columns(&stmt.columns, columns);
7355
7356 for group_rows in groups.values() {
7357 let mut result_row = Vec::new();
7358
7359 for sel_col in &stmt.columns {
7360 match sel_col {
7361 SelectColumn::AllColumns => {
7362 return Err(SqlError::Unsupported("SELECT * with GROUP BY".into()));
7363 }
7364 SelectColumn::Expr { expr, .. } => {
7365 let val = eval_aggregate_expr(expr, &col_map, group_rows)?;
7366 result_row.push(val);
7367 }
7368 }
7369 }
7370
7371 if let Some(ref having) = stmt.having {
7372 let passes = match eval_aggregate_expr(having, &col_map, group_rows) {
7373 Ok(val) => is_truthy(&val),
7374 Err(SqlError::ColumnNotFound(_)) => {
7375 let output_map = ColumnMap::new(&output_cols);
7376 match eval_expr(having, &output_map, &result_row) {
7377 Ok(val) => is_truthy(&val),
7378 Err(_) => false,
7379 }
7380 }
7381 Err(e) => return Err(e),
7382 };
7383 if !passes {
7384 continue;
7385 }
7386 }
7387
7388 result_rows.push(result_row);
7389 }
7390
7391 if stmt.distinct {
7392 let mut seen = std::collections::HashSet::new();
7393 result_rows.retain(|row| seen.insert(row.clone()));
7394 }
7395
7396 if !stmt.order_by.is_empty() {
7397 let output_cols = build_output_columns(&stmt.columns, columns);
7398 sort_rows(&mut result_rows, &stmt.order_by, &output_cols)?;
7399 }
7400
7401 if let Some(ref offset_expr) = stmt.offset {
7402 let offset = eval_const_int(offset_expr)?.max(0) as usize;
7403 if offset < result_rows.len() {
7404 result_rows = result_rows.split_off(offset);
7405 } else {
7406 result_rows.clear();
7407 }
7408 }
7409 if let Some(ref limit_expr) = stmt.limit {
7410 let limit = eval_const_int(limit_expr)?.max(0) as usize;
7411 result_rows.truncate(limit);
7412 }
7413
7414 let col_names = stmt
7415 .columns
7416 .iter()
7417 .map(|c| match c {
7418 SelectColumn::AllColumns => "*".into(),
7419 SelectColumn::Expr { alias: Some(a), .. } => a.clone(),
7420 SelectColumn::Expr { expr, .. } => expr_display_name(expr),
7421 })
7422 .collect();
7423
7424 Ok(ExecutionResult::Query(QueryResult {
7425 columns: col_names,
7426 rows: result_rows,
7427 }))
7428}
7429
7430fn eval_aggregate_expr(
7431 expr: &Expr,
7432 col_map: &ColumnMap,
7433 group_rows: &[&Vec<Value>],
7434) -> Result<Value> {
7435 match expr {
7436 Expr::CountStar => Ok(Value::Integer(group_rows.len() as i64)),
7437
7438 Expr::Function { name, args } if is_aggregate_function(name, args.len()) => {
7439 let func = name.to_ascii_uppercase();
7440 if args.len() != 1 {
7441 return Err(SqlError::Unsupported(format!(
7442 "{func} with {} args",
7443 args.len()
7444 )));
7445 }
7446 let arg = &args[0];
7447 let values: Vec<Value> = group_rows
7448 .iter()
7449 .map(|row| eval_expr(arg, col_map, row))
7450 .collect::<Result<_>>()?;
7451
7452 match func.as_str() {
7453 "COUNT" => {
7454 let count = values.iter().filter(|v| !v.is_null()).count();
7455 Ok(Value::Integer(count as i64))
7456 }
7457 "SUM" => {
7458 let mut int_sum: i64 = 0;
7459 let mut real_sum: f64 = 0.0;
7460 let mut has_real = false;
7461 let mut all_null = true;
7462 for v in &values {
7463 match v {
7464 Value::Integer(i) => {
7465 int_sum += i;
7466 all_null = false;
7467 }
7468 Value::Real(r) => {
7469 real_sum += r;
7470 has_real = true;
7471 all_null = false;
7472 }
7473 Value::Null => {}
7474 _ => {
7475 return Err(SqlError::TypeMismatch {
7476 expected: "numeric".into(),
7477 got: v.data_type().to_string(),
7478 })
7479 }
7480 }
7481 }
7482 if all_null {
7483 return Ok(Value::Null);
7484 }
7485 if has_real {
7486 Ok(Value::Real(real_sum + int_sum as f64))
7487 } else {
7488 Ok(Value::Integer(int_sum))
7489 }
7490 }
7491 "AVG" => {
7492 let mut sum: f64 = 0.0;
7493 let mut count: i64 = 0;
7494 for v in &values {
7495 match v {
7496 Value::Integer(i) => {
7497 sum += *i as f64;
7498 count += 1;
7499 }
7500 Value::Real(r) => {
7501 sum += r;
7502 count += 1;
7503 }
7504 Value::Null => {}
7505 _ => {
7506 return Err(SqlError::TypeMismatch {
7507 expected: "numeric".into(),
7508 got: v.data_type().to_string(),
7509 })
7510 }
7511 }
7512 }
7513 if count == 0 {
7514 Ok(Value::Null)
7515 } else {
7516 Ok(Value::Real(sum / count as f64))
7517 }
7518 }
7519 "MIN" => {
7520 let mut min: Option<&Value> = None;
7521 for v in &values {
7522 if v.is_null() {
7523 continue;
7524 }
7525 min = Some(match min {
7526 None => v,
7527 Some(m) => {
7528 if v < m {
7529 v
7530 } else {
7531 m
7532 }
7533 }
7534 });
7535 }
7536 Ok(min.cloned().unwrap_or(Value::Null))
7537 }
7538 "MAX" => {
7539 let mut max: Option<&Value> = None;
7540 for v in &values {
7541 if v.is_null() {
7542 continue;
7543 }
7544 max = Some(match max {
7545 None => v,
7546 Some(m) => {
7547 if v > m {
7548 v
7549 } else {
7550 m
7551 }
7552 }
7553 });
7554 }
7555 Ok(max.cloned().unwrap_or(Value::Null))
7556 }
7557 _ => Err(SqlError::Unsupported(format!("aggregate function: {func}"))),
7558 }
7559 }
7560
7561 Expr::Column(_) | Expr::QualifiedColumn { .. } => {
7562 if let Some(first) = group_rows.first() {
7563 eval_expr(expr, col_map, first)
7564 } else {
7565 Ok(Value::Null)
7566 }
7567 }
7568
7569 Expr::Literal(v) => Ok(v.clone()),
7570
7571 Expr::BinaryOp { left, op, right } => {
7572 let l = eval_aggregate_expr(left, col_map, group_rows)?;
7573 let r = eval_aggregate_expr(right, col_map, group_rows)?;
7574 eval_expr(
7575 &Expr::BinaryOp {
7576 left: Box::new(Expr::Literal(l)),
7577 op: *op,
7578 right: Box::new(Expr::Literal(r)),
7579 },
7580 col_map,
7581 &[],
7582 )
7583 }
7584
7585 Expr::UnaryOp { op, expr: e } => {
7586 let v = eval_aggregate_expr(e, col_map, group_rows)?;
7587 eval_expr(
7588 &Expr::UnaryOp {
7589 op: *op,
7590 expr: Box::new(Expr::Literal(v)),
7591 },
7592 col_map,
7593 &[],
7594 )
7595 }
7596
7597 Expr::IsNull(e) => {
7598 let v = eval_aggregate_expr(e, col_map, group_rows)?;
7599 Ok(Value::Boolean(v.is_null()))
7600 }
7601
7602 Expr::IsNotNull(e) => {
7603 let v = eval_aggregate_expr(e, col_map, group_rows)?;
7604 Ok(Value::Boolean(!v.is_null()))
7605 }
7606
7607 Expr::Cast { expr: e, data_type } => {
7608 let v = eval_aggregate_expr(e, col_map, group_rows)?;
7609 eval_expr(
7610 &Expr::Cast {
7611 expr: Box::new(Expr::Literal(v)),
7612 data_type: *data_type,
7613 },
7614 col_map,
7615 &[],
7616 )
7617 }
7618
7619 Expr::Case {
7620 operand,
7621 conditions,
7622 else_result,
7623 } => {
7624 let op_val = operand
7625 .as_ref()
7626 .map(|e| eval_aggregate_expr(e, col_map, group_rows))
7627 .transpose()?;
7628 if let Some(ov) = &op_val {
7629 for (cond, result) in conditions {
7630 let cv = eval_aggregate_expr(cond, col_map, group_rows)?;
7631 if !ov.is_null() && !cv.is_null() && *ov == cv {
7632 return eval_aggregate_expr(result, col_map, group_rows);
7633 }
7634 }
7635 } else {
7636 for (cond, result) in conditions {
7637 let cv = eval_aggregate_expr(cond, col_map, group_rows)?;
7638 if is_truthy(&cv) {
7639 return eval_aggregate_expr(result, col_map, group_rows);
7640 }
7641 }
7642 }
7643 match else_result {
7644 Some(e) => eval_aggregate_expr(e, col_map, group_rows),
7645 None => Ok(Value::Null),
7646 }
7647 }
7648
7649 Expr::Coalesce(args) => {
7650 for arg in args {
7651 let v = eval_aggregate_expr(arg, col_map, group_rows)?;
7652 if !v.is_null() {
7653 return Ok(v);
7654 }
7655 }
7656 Ok(Value::Null)
7657 }
7658
7659 Expr::Between {
7660 expr: e,
7661 low,
7662 high,
7663 negated,
7664 } => {
7665 let v = eval_aggregate_expr(e, col_map, group_rows)?;
7666 let lo = eval_aggregate_expr(low, col_map, group_rows)?;
7667 let hi = eval_aggregate_expr(high, col_map, group_rows)?;
7668 eval_expr(
7669 &Expr::Between {
7670 expr: Box::new(Expr::Literal(v)),
7671 low: Box::new(Expr::Literal(lo)),
7672 high: Box::new(Expr::Literal(hi)),
7673 negated: *negated,
7674 },
7675 col_map,
7676 &[],
7677 )
7678 }
7679
7680 Expr::Like {
7681 expr: e,
7682 pattern,
7683 escape,
7684 negated,
7685 } => {
7686 let v = eval_aggregate_expr(e, col_map, group_rows)?;
7687 let p = eval_aggregate_expr(pattern, col_map, group_rows)?;
7688 let esc = escape
7689 .as_ref()
7690 .map(|es| eval_aggregate_expr(es, col_map, group_rows))
7691 .transpose()?;
7692 let esc_box = esc.map(|v| Box::new(Expr::Literal(v)));
7693 eval_expr(
7694 &Expr::Like {
7695 expr: Box::new(Expr::Literal(v)),
7696 pattern: Box::new(Expr::Literal(p)),
7697 escape: esc_box,
7698 negated: *negated,
7699 },
7700 col_map,
7701 &[],
7702 )
7703 }
7704
7705 Expr::Function { name, args } => {
7706 let evaluated: Vec<Value> = args
7707 .iter()
7708 .map(|a| eval_aggregate_expr(a, col_map, group_rows))
7709 .collect::<Result<_>>()?;
7710 let literal_args: Vec<Expr> = evaluated.into_iter().map(Expr::Literal).collect();
7711 eval_expr(
7712 &Expr::Function {
7713 name: name.clone(),
7714 args: literal_args,
7715 },
7716 col_map,
7717 &[],
7718 )
7719 }
7720
7721 _ => Err(SqlError::Unsupported(format!(
7722 "expression in aggregate: {expr:?}"
7723 ))),
7724 }
7725}
7726
/// Returns `true` if a call to `name` with `arg_count` arguments is an aggregate.
///
/// The name match is ASCII case-insensitive. `COUNT`, `SUM` and `AVG` are aggregates
/// regardless of arity. `MIN`/`MAX` are aggregates only with exactly one argument; with
/// more arguments they are ordinary scalar functions.
fn is_aggregate_function(name: &str, arg_count: usize) -> bool {
    let upper = name.to_ascii_uppercase();
    match upper.as_str() {
        "COUNT" | "SUM" | "AVG" => true,
        "MIN" | "MAX" => arg_count == 1,
        _ => false,
    }
}
7732
7733fn is_aggregate_expr(expr: &Expr) -> bool {
7734 match expr {
7735 Expr::CountStar => true,
7736 Expr::Function { name, args } => {
7737 is_aggregate_function(name, args.len()) || args.iter().any(is_aggregate_expr)
7738 }
7739 Expr::BinaryOp { left, right, .. } => is_aggregate_expr(left) || is_aggregate_expr(right),
7740 Expr::UnaryOp { expr, .. }
7741 | Expr::IsNull(expr)
7742 | Expr::IsNotNull(expr)
7743 | Expr::Cast { expr, .. } => is_aggregate_expr(expr),
7744 Expr::Case {
7745 operand,
7746 conditions,
7747 else_result,
7748 } => {
7749 operand.as_ref().is_some_and(|e| is_aggregate_expr(e))
7750 || conditions
7751 .iter()
7752 .any(|(c, r)| is_aggregate_expr(c) || is_aggregate_expr(r))
7753 || else_result.as_ref().is_some_and(|e| is_aggregate_expr(e))
7754 }
7755 Expr::Coalesce(args) => args.iter().any(is_aggregate_expr),
7756 Expr::Between {
7757 expr, low, high, ..
7758 } => is_aggregate_expr(expr) || is_aggregate_expr(low) || is_aggregate_expr(high),
7759 Expr::Like {
7760 expr,
7761 pattern,
7762 escape,
7763 ..
7764 } => {
7765 is_aggregate_expr(expr)
7766 || is_aggregate_expr(pattern)
7767 || escape.as_ref().is_some_and(|e| is_aggregate_expr(e))
7768 }
7769 _ => false,
7770 }
7771}
7772
/// Precomputed plan for decoding a stored row in two steps: `decode` materializes only
/// the "needed" columns, and `complete` fills in the rest. Built by
/// `PartialDecodeCtx::new` from a table schema and a list of needed column indices.
struct PartialDecodeCtx {
    /// `(position within the primary key, schema column index)` for each needed PK column.
    pk_positions: Vec<(usize, usize)>,
    /// Physical encoding positions (from `encoding_positions`) of the needed non-PK columns.
    nonpk_targets: Vec<usize>,
    /// Schema column index for each entry of `nonpk_targets`, in the same order.
    nonpk_schema: Vec<usize>,
    /// Total number of columns in the table schema (length of a decoded row).
    num_cols: usize,
    /// Number of columns in the primary key (needed to decode composite keys).
    num_pk_cols: usize,
    /// `(position within the primary key, schema column index)` for PK columns not needed.
    remaining_pk: Vec<(usize, usize)>,
    /// Physical encoding positions of the non-PK columns not in the needed set.
    remaining_nonpk_targets: Vec<usize>,
    /// Schema column index for each entry of `remaining_nonpk_targets`, in the same order.
    remaining_nonpk_schema: Vec<usize>,
    /// `(physical position, schema column, constant default)` for needed non-PK columns that
    /// have a default expression; applied when the stored row has fewer physical columns
    /// than the position (e.g. the column was presumably added after the row was written).
    nonpk_defaults: Vec<(usize, usize, Value)>,
    /// Same as `nonpk_defaults`, for the remaining (not needed) non-PK columns.
    remaining_defaults: Vec<(usize, usize, Value)>,
}
7787
7788impl PartialDecodeCtx {
7789 fn new(schema: &TableSchema, needed: &[usize]) -> Self {
7790 let non_pk = schema.non_pk_indices();
7791 let enc_pos = schema.encoding_positions();
7792 let mut pk_positions = Vec::new();
7793 let mut nonpk_targets = Vec::new();
7794 let mut nonpk_schema = Vec::new();
7795
7796 for &col in needed {
7797 if let Some(pk_pos) = schema
7798 .primary_key_columns
7799 .iter()
7800 .position(|&i| i as usize == col)
7801 {
7802 pk_positions.push((pk_pos, col));
7803 } else if let Some(nonpk_order) = non_pk.iter().position(|&i| i == col) {
7804 nonpk_targets.push(enc_pos[nonpk_order] as usize);
7805 nonpk_schema.push(col);
7806 }
7807 }
7808
7809 let needed_set: std::collections::HashSet<usize> = needed.iter().copied().collect();
7810 let mut remaining_pk = Vec::new();
7811 for (pk_pos, &pk_col) in schema.primary_key_columns.iter().enumerate() {
7812 if !needed_set.contains(&(pk_col as usize)) {
7813 remaining_pk.push((pk_pos, pk_col as usize));
7814 }
7815 }
7816 let mut remaining_nonpk_targets = Vec::new();
7817 let mut remaining_nonpk_schema = Vec::new();
7818 for (nonpk_order, &col) in non_pk.iter().enumerate() {
7819 if !needed_set.contains(&col) {
7820 remaining_nonpk_targets.push(enc_pos[nonpk_order] as usize);
7821 remaining_nonpk_schema.push(col);
7822 }
7823 }
7824
7825 let mut nonpk_defaults = Vec::new();
7826 for (&phys_pos, &schema_col) in nonpk_targets.iter().zip(nonpk_schema.iter()) {
7827 if let Some(ref expr) = schema.columns[schema_col].default_expr {
7828 if let Ok(val) = eval_const_expr(expr) {
7829 nonpk_defaults.push((phys_pos, schema_col, val));
7830 }
7831 }
7832 }
7833 let mut remaining_defaults = Vec::new();
7834 for (&phys_pos, &schema_col) in remaining_nonpk_targets
7835 .iter()
7836 .zip(remaining_nonpk_schema.iter())
7837 {
7838 if let Some(ref expr) = schema.columns[schema_col].default_expr {
7839 if let Ok(val) = eval_const_expr(expr) {
7840 remaining_defaults.push((phys_pos, schema_col, val));
7841 }
7842 }
7843 }
7844
7845 Self {
7846 pk_positions,
7847 nonpk_targets,
7848 nonpk_schema,
7849 num_cols: schema.columns.len(),
7850 num_pk_cols: schema.primary_key_columns.len(),
7851 remaining_pk,
7852 remaining_nonpk_targets,
7853 remaining_nonpk_schema,
7854 nonpk_defaults,
7855 remaining_defaults,
7856 }
7857 }
7858
7859 fn decode(&self, key: &[u8], value: &[u8]) -> Result<Vec<Value>> {
7860 let mut row = vec![Value::Null; self.num_cols];
7861
7862 if self.pk_positions.len() == 1 && self.num_pk_cols == 1 {
7863 let (_, schema_col) = self.pk_positions[0];
7864 let (v, _) = decode_key_value(key)?;
7865 row[schema_col] = v;
7866 } else if !self.pk_positions.is_empty() {
7867 let mut pk_values = decode_composite_key(key, self.num_pk_cols)?;
7868 for &(pk_pos, schema_col) in &self.pk_positions {
7869 row[schema_col] = std::mem::take(&mut pk_values[pk_pos]);
7870 }
7871 }
7872
7873 if !self.nonpk_targets.is_empty() {
7874 decode_columns_into(value, &self.nonpk_targets, &self.nonpk_schema, &mut row)?;
7875 }
7876
7877 if !self.nonpk_defaults.is_empty() {
7878 let stored = row_non_pk_count(value);
7879 for (nonpk_idx, schema_col, default) in &self.nonpk_defaults {
7880 if *nonpk_idx >= stored {
7881 row[*schema_col] = default.clone();
7882 }
7883 }
7884 }
7885
7886 Ok(row)
7887 }
7888
7889 fn complete(&self, mut row: Vec<Value>, key: &[u8], value: &[u8]) -> Result<Vec<Value>> {
7890 if !self.remaining_pk.is_empty() {
7891 let mut pk_values = decode_composite_key(key, self.num_pk_cols)?;
7892 for &(pk_pos, schema_col) in &self.remaining_pk {
7893 row[schema_col] = std::mem::take(&mut pk_values[pk_pos]);
7894 }
7895 }
7896 if !self.remaining_nonpk_targets.is_empty() {
7897 let mut values = decode_columns(value, &self.remaining_nonpk_targets)?;
7898 for (i, &schema_col) in self.remaining_nonpk_schema.iter().enumerate() {
7899 row[schema_col] = std::mem::take(&mut values[i]);
7900 }
7901 }
7902 if !self.remaining_defaults.is_empty() {
7903 let stored = row_non_pk_count(value);
7904 for (nonpk_idx, schema_col, default) in &self.remaining_defaults {
7905 if *nonpk_idx >= stored {
7906 row[*schema_col] = default.clone();
7907 }
7908 }
7909 }
7910 Ok(row)
7911 }
7912}
7913
7914fn decode_full_row(schema: &TableSchema, key: &[u8], value: &[u8]) -> Result<Vec<Value>> {
7915 let mut row = vec![Value::Null; schema.columns.len()];
7916 decode_pk_into(
7917 key,
7918 schema.primary_key_columns.len(),
7919 &mut row,
7920 schema.pk_indices(),
7921 )?;
7922 let mapping = schema.decode_col_mapping();
7923 let stored_count = row_non_pk_count(value);
7924 decode_row_into(value, &mut row, mapping)?;
7925 if stored_count < mapping.len() {
7928 for &logical_idx in mapping.iter().skip(stored_count) {
7929 if logical_idx != usize::MAX {
7930 if let Some(ref expr) = schema.columns[logical_idx].default_expr {
7931 row[logical_idx] = eval_const_expr(expr)?;
7932 }
7933 }
7934 }
7935 }
7936 Ok(row)
7937}
7938
7939fn eval_const_expr(expr: &Expr) -> Result<Value> {
7941 static EMPTY: std::sync::OnceLock<ColumnMap> = std::sync::OnceLock::new();
7942 let empty = EMPTY.get_or_init(|| ColumnMap::new(&[]));
7943 eval_expr(expr, empty, &[])
7944}
7945
7946fn eval_const_int(expr: &Expr) -> Result<i64> {
7947 match eval_const_expr(expr)? {
7948 Value::Integer(i) => Ok(i),
7949 other => Err(SqlError::TypeMismatch {
7950 expected: "INTEGER".into(),
7951 got: other.data_type().to_string(),
7952 }),
7953 }
7954}
7955
7956fn sort_rows(
7957 rows: &mut [Vec<Value>],
7958 order_by: &[OrderByItem],
7959 columns: &[ColumnDef],
7960) -> Result<()> {
7961 if rows.is_empty() {
7962 return Ok(());
7963 }
7964 let col_map = ColumnMap::new(columns);
7965 let mut indices: Vec<usize> = (0..rows.len()).collect();
7966
7967 if let Some(col_idx) = try_resolve_flat_sort_col(order_by, &col_map) {
7968 let desc = order_by[0].descending;
7969 let nulls_first = order_by[0].nulls_first.unwrap_or(!desc);
7970 indices.sort_by(|&a, &b| {
7971 compare_flat_key(&rows[a][col_idx], &rows[b][col_idx], desc, nulls_first)
7972 });
7973 } else {
7974 let keys = extract_sort_keys(rows, order_by, &col_map);
7975 indices.sort_by(|&a, &b| compare_sort_keys(&keys[a], &keys[b], order_by));
7976 }
7977
7978 let sorted: Vec<Vec<Value>> = indices
7979 .iter()
7980 .map(|&i| std::mem::take(&mut rows[i]))
7981 .collect();
7982 rows.iter_mut()
7983 .zip(sorted)
7984 .for_each(|(slot, row)| *slot = row);
7985 Ok(())
7986}
7987
7988fn topk_rows(
7989 rows: &mut [Vec<Value>],
7990 order_by: &[OrderByItem],
7991 columns: &[ColumnDef],
7992 k: usize,
7993) -> Result<()> {
7994 let col_map = ColumnMap::new(columns);
7995 let mut indices: Vec<usize> = (0..rows.len()).collect();
7996
7997 if let Some(col_idx) = try_resolve_flat_sort_col(order_by, &col_map) {
7998 let desc = order_by[0].descending;
7999 let nulls_first = order_by[0].nulls_first.unwrap_or(!desc);
8000 let cmp = |&a: &usize, &b: &usize| {
8001 compare_flat_key(&rows[a][col_idx], &rows[b][col_idx], desc, nulls_first)
8002 };
8003 indices.select_nth_unstable_by(k - 1, cmp);
8004 indices[..k].sort_by(cmp);
8005 } else {
8006 let keys = extract_sort_keys(rows, order_by, &col_map);
8007 let cmp = |&a: &usize, &b: &usize| compare_sort_keys(&keys[a], &keys[b], order_by);
8008 indices.select_nth_unstable_by(k - 1, cmp);
8009 indices[..k].sort_by(cmp);
8010 }
8011
8012 let sorted: Vec<Vec<Value>> = indices[..k]
8013 .iter()
8014 .map(|&i| std::mem::take(&mut rows[i]))
8015 .collect();
8016 rows[..k]
8017 .iter_mut()
8018 .zip(sorted)
8019 .for_each(|(slot, row)| *slot = row);
8020 Ok(())
8021}
8022
8023fn try_resolve_flat_sort_col(order_by: &[OrderByItem], col_map: &ColumnMap) -> Option<usize> {
8024 if order_by.len() != 1 {
8025 return None;
8026 }
8027 match &order_by[0].expr {
8028 Expr::Column(name) => col_map.resolve(&name.to_ascii_lowercase()).ok(),
8029 _ => None,
8030 }
8031}
8032
8033fn compare_flat_key(a: &Value, b: &Value, desc: bool, nulls_first: bool) -> std::cmp::Ordering {
8034 match (a.is_null(), b.is_null()) {
8035 (true, true) => std::cmp::Ordering::Equal,
8036 (true, false) => {
8037 if nulls_first {
8038 std::cmp::Ordering::Less
8039 } else {
8040 std::cmp::Ordering::Greater
8041 }
8042 }
8043 (false, true) => {
8044 if nulls_first {
8045 std::cmp::Ordering::Greater
8046 } else {
8047 std::cmp::Ordering::Less
8048 }
8049 }
8050 (false, false) => {
8051 let cmp = a.cmp(b);
8052 if desc {
8053 cmp.reverse()
8054 } else {
8055 cmp
8056 }
8057 }
8058 }
8059}
8060
8061fn extract_sort_keys(
8062 rows: &[Vec<Value>],
8063 order_by: &[OrderByItem],
8064 col_map: &ColumnMap,
8065) -> Vec<Vec<Value>> {
8066 rows.iter()
8067 .map(|row| {
8068 order_by
8069 .iter()
8070 .map(|item| eval_expr(&item.expr, col_map, row).unwrap_or(Value::Null))
8071 .collect()
8072 })
8073 .collect()
8074}
8075
8076fn compare_sort_keys(a: &[Value], b: &[Value], order_by: &[OrderByItem]) -> std::cmp::Ordering {
8077 for (i, item) in order_by.iter().enumerate() {
8078 let nulls_first = item.nulls_first.unwrap_or(!item.descending);
8079 let ord = match (a[i].is_null(), b[i].is_null()) {
8080 (true, true) => std::cmp::Ordering::Equal,
8081 (true, false) => {
8082 if nulls_first {
8083 std::cmp::Ordering::Less
8084 } else {
8085 std::cmp::Ordering::Greater
8086 }
8087 }
8088 (false, true) => {
8089 if nulls_first {
8090 std::cmp::Ordering::Greater
8091 } else {
8092 std::cmp::Ordering::Less
8093 }
8094 }
8095 (false, false) => {
8096 let cmp = a[i].cmp(&b[i]);
8097 if item.descending {
8098 cmp.reverse()
8099 } else {
8100 cmp
8101 }
8102 }
8103 };
8104 if ord != std::cmp::Ordering::Equal {
8105 return ord;
8106 }
8107 }
8108 std::cmp::Ordering::Equal
8109}
8110
8111fn try_build_index_map(
8112 select_cols: &[SelectColumn],
8113 columns: &[ColumnDef],
8114) -> Option<Vec<(String, usize)>> {
8115 let col_map = ColumnMap::new(columns);
8116 let mut map = Vec::new();
8117 let mut seen = std::collections::HashSet::new();
8118 for sel in select_cols {
8119 match sel {
8120 SelectColumn::AllColumns => {
8121 for col in columns {
8122 let idx = col.position as usize;
8123 if !seen.insert(idx) {
8124 return None;
8125 }
8126 map.push((col.name.clone(), idx));
8127 }
8128 }
8129 SelectColumn::Expr { expr, alias } => {
8130 let idx = match expr {
8131 Expr::Column(name) => col_map.resolve(name).ok()?,
8132 Expr::QualifiedColumn { table, column } => {
8133 col_map.resolve_qualified(table, column).ok()?
8134 }
8135 _ => return None,
8136 };
8137 if !seen.insert(idx) {
8138 return None;
8139 }
8140 let name = alias.clone().unwrap_or_else(|| expr_display_name(expr));
8141 map.push((name, idx));
8142 }
8143 }
8144 }
8145 Some(map)
8146}
8147
8148fn project_rows(
8149 columns: &[ColumnDef],
8150 select_cols: &[SelectColumn],
8151 mut rows: Vec<Vec<Value>>,
8152) -> Result<(Vec<String>, Vec<Vec<Value>>)> {
8153 if select_cols.len() == 1 && matches!(select_cols[0], SelectColumn::AllColumns) {
8155 let col_names = columns.iter().map(|c| c.name.clone()).collect();
8156 return Ok((col_names, rows));
8157 }
8158
8159 if let Some(map) = try_build_index_map(select_cols, columns) {
8161 let col_names: Vec<String> = map.iter().map(|(n, _)| n.clone()).collect();
8162 if map.len() == columns.len() && map.iter().enumerate().all(|(i, &(_, idx))| idx == i) {
8164 return Ok((col_names, rows));
8165 }
8166 let projected = rows
8167 .iter_mut()
8168 .map(|row| {
8169 map.iter()
8170 .map(|&(_, idx)| std::mem::take(&mut row[idx]))
8171 .collect()
8172 })
8173 .collect();
8174 return Ok((col_names, projected));
8175 }
8176
8177 let mut col_names = Vec::new();
8179 type Projector = Box<dyn Fn(&[Value]) -> Result<Value>>;
8180 let mut projectors: Vec<Projector> = Vec::new();
8181 let col_map = std::sync::Arc::new(ColumnMap::new(columns));
8182
8183 for sel_col in select_cols {
8184 match sel_col {
8185 SelectColumn::AllColumns => {
8186 for col in columns {
8187 let idx = col.position as usize;
8188 col_names.push(col.name.clone());
8189 projectors.push(Box::new(move |row: &[Value]| Ok(row[idx].clone())));
8190 }
8191 }
8192 SelectColumn::Expr { expr, alias } => {
8193 let name = alias.clone().unwrap_or_else(|| expr_display_name(expr));
8194 col_names.push(name);
8195 let expr = expr.clone();
8196 let map = col_map.clone();
8197 projectors.push(Box::new(move |row: &[Value]| eval_expr(&expr, &map, row)));
8198 }
8199 }
8200 }
8201
8202 let projected = rows
8203 .iter()
8204 .map(|row| {
8205 projectors
8206 .iter()
8207 .map(|p| p(row))
8208 .collect::<Result<Vec<_>>>()
8209 })
8210 .collect::<Result<Vec<_>>>()?;
8211
8212 Ok((col_names, projected))
8213}
8214
8215fn expr_display_name(expr: &Expr) -> String {
8216 match expr {
8217 Expr::Column(name) => name.clone(),
8218 Expr::QualifiedColumn { table, column } => format!("{table}.{column}"),
8219 Expr::Literal(v) => format!("{v}"),
8220 Expr::CountStar => "COUNT(*)".into(),
8221 Expr::Function { name, args } => {
8222 let arg_strs: Vec<String> = args.iter().map(expr_display_name).collect();
8223 format!("{name}({})", arg_strs.join(", "))
8224 }
8225 Expr::BinaryOp { left, op, right } => {
8226 format!(
8227 "{} {} {}",
8228 expr_display_name(left),
8229 op_symbol(op),
8230 expr_display_name(right)
8231 )
8232 }
8233 _ => "?".into(),
8234 }
8235}
8236
8237fn op_symbol(op: &BinOp) -> &'static str {
8238 match op {
8239 BinOp::Add => "+",
8240 BinOp::Sub => "-",
8241 BinOp::Mul => "*",
8242 BinOp::Div => "/",
8243 BinOp::Mod => "%",
8244 BinOp::Eq => "=",
8245 BinOp::NotEq => "<>",
8246 BinOp::Lt => "<",
8247 BinOp::Gt => ">",
8248 BinOp::LtEq => "<=",
8249 BinOp::GtEq => ">=",
8250 BinOp::And => "AND",
8251 BinOp::Or => "OR",
8252 BinOp::Concat => "||",
8253 }
8254}
8255
8256fn build_output_columns(select_cols: &[SelectColumn], columns: &[ColumnDef]) -> Vec<ColumnDef> {
8257 let mut out = Vec::new();
8258 for (i, col) in select_cols.iter().enumerate() {
8259 let (name, data_type) = match col {
8260 SelectColumn::AllColumns => (format!("col{i}"), DataType::Null),
8261 SelectColumn::Expr {
8262 alias: Some(a),
8263 expr,
8264 } => (a.clone(), infer_expr_type(expr, columns)),
8265 SelectColumn::Expr { expr, .. } => {
8266 (expr_display_name(expr), infer_expr_type(expr, columns))
8267 }
8268 };
8269 out.push(ColumnDef {
8270 name,
8271 data_type,
8272 nullable: true,
8273 position: i as u16,
8274 default_expr: None,
8275 default_sql: None,
8276 check_expr: None,
8277 check_sql: None,
8278 check_name: None,
8279 });
8280 }
8281 out
8282}
8283
8284fn infer_expr_type(expr: &Expr, columns: &[ColumnDef]) -> DataType {
8285 match expr {
8286 Expr::Column(name) => columns
8287 .iter()
8288 .find(|c| c.name == *name)
8289 .map(|c| c.data_type)
8290 .unwrap_or(DataType::Null),
8291 Expr::QualifiedColumn { table, column } => {
8292 let qualified = format!("{table}.{column}");
8293 columns
8294 .iter()
8295 .find(|c| c.name == qualified)
8296 .map(|c| c.data_type)
8297 .unwrap_or(DataType::Null)
8298 }
8299 Expr::Literal(v) => v.data_type(),
8300 Expr::CountStar => DataType::Integer,
8301 Expr::Function { name, .. } => match name.to_ascii_uppercase().as_str() {
8302 "COUNT" => DataType::Integer,
8303 "AVG" => DataType::Real,
8304 "SUM" | "MIN" | "MAX" => DataType::Null,
8305 _ => DataType::Null,
8306 },
8307 _ => DataType::Null,
8308 }
8309}