1use std::collections::{BTreeMap, HashMap};
4
5use citadel::Database;
6
7use crate::encoding::{
8 decode_column_raw, decode_columns, decode_columns_into, decode_composite_key, decode_key_value,
9 decode_pk_integer, decode_pk_into, decode_row_into, encode_composite_key,
10 encode_composite_key_into, encode_row, encode_row_into, row_non_pk_count, RawColumn,
11};
12use crate::error::{Result, SqlError};
13use crate::eval::{eval_expr, is_truthy, referenced_columns, ColumnMap};
14use crate::parser::*;
15use crate::planner::{self, ScanPlan};
16use crate::schema::SchemaManager;
17use crate::types::*;
18
19fn encode_index_key(idx: &IndexDef, row: &[Value], pk_values: &[Value]) -> Vec<u8> {
22 let indexed_values: Vec<Value> = idx
23 .columns
24 .iter()
25 .map(|&col_idx| row[col_idx as usize].clone())
26 .collect();
27
28 if idx.unique {
29 let any_null = indexed_values.iter().any(|v| v.is_null());
30 if !any_null {
31 return encode_composite_key(&indexed_values);
32 }
33 }
34
35 let mut all_values = indexed_values;
36 all_values.extend_from_slice(pk_values);
37 encode_composite_key(&all_values)
38}
39
40fn encode_index_value(idx: &IndexDef, row: &[Value], pk_values: &[Value]) -> Vec<u8> {
41 if idx.unique {
42 let indexed_values: Vec<Value> = idx
43 .columns
44 .iter()
45 .map(|&col_idx| row[col_idx as usize].clone())
46 .collect();
47 let any_null = indexed_values.iter().any(|v| v.is_null());
48 if !any_null {
49 return encode_composite_key(pk_values);
50 }
51 }
52 vec![]
53}
54
55fn insert_index_entries(
56 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
57 table_schema: &TableSchema,
58 row: &[Value],
59 pk_values: &[Value],
60) -> Result<()> {
61 for idx in &table_schema.indices {
62 let idx_table = TableSchema::index_table_name(&table_schema.name, &idx.name);
63 let key = encode_index_key(idx, row, pk_values);
64 let value = encode_index_value(idx, row, pk_values);
65
66 let is_new = wtx
67 .table_insert(&idx_table, &key, &value)
68 .map_err(SqlError::Storage)?;
69
70 if idx.unique && !is_new {
71 let indexed_values: Vec<Value> = idx
72 .columns
73 .iter()
74 .map(|&col_idx| row[col_idx as usize].clone())
75 .collect();
76 let any_null = indexed_values.iter().any(|v| v.is_null());
77 if !any_null {
78 return Err(SqlError::UniqueViolation(idx.name.clone()));
79 }
80 }
81 }
82 Ok(())
83}
84
85fn delete_index_entries(
86 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
87 table_schema: &TableSchema,
88 row: &[Value],
89 pk_values: &[Value],
90) -> Result<()> {
91 for idx in &table_schema.indices {
92 let idx_table = TableSchema::index_table_name(&table_schema.name, &idx.name);
93 let key = encode_index_key(idx, row, pk_values);
94 wtx.table_delete(&idx_table, &key)
95 .map_err(SqlError::Storage)?;
96 }
97 Ok(())
98}
99
100fn index_columns_changed(idx: &IndexDef, old_row: &[Value], new_row: &[Value]) -> bool {
101 idx.columns
102 .iter()
103 .any(|&col_idx| old_row[col_idx as usize] != new_row[col_idx as usize])
104}
105
106pub fn execute(
108 db: &Database,
109 schema: &mut SchemaManager,
110 stmt: &Statement,
111 params: &[Value],
112) -> Result<ExecutionResult> {
113 match stmt {
114 Statement::CreateTable(ct) => exec_create_table(db, schema, ct),
115 Statement::DropTable(dt) => exec_drop_table(db, schema, dt),
116 Statement::CreateIndex(ci) => exec_create_index(db, schema, ci),
117 Statement::DropIndex(di) => exec_drop_index(db, schema, di),
118 Statement::AlterTable(at) => exec_alter_table(db, schema, at),
119 Statement::Insert(ins) => exec_insert(db, schema, ins, params),
120 Statement::Select(sq) => exec_select_query(db, schema, sq),
121 Statement::Update(upd) => exec_update(db, schema, upd),
122 Statement::Delete(del) => exec_delete(db, schema, del),
123 Statement::Explain(inner) => explain(schema, inner),
124 Statement::Begin | Statement::Commit | Statement::Rollback => Err(SqlError::Unsupported(
125 "transaction control in auto-commit mode".into(),
126 )),
127 }
128}
129
130pub fn execute_in_txn(
132 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
133 schema: &mut SchemaManager,
134 stmt: &Statement,
135 params: &[Value],
136) -> Result<ExecutionResult> {
137 match stmt {
138 Statement::CreateTable(ct) => exec_create_table_in_txn(wtx, schema, ct),
139 Statement::DropTable(dt) => exec_drop_table_in_txn(wtx, schema, dt),
140 Statement::CreateIndex(ci) => exec_create_index_in_txn(wtx, schema, ci),
141 Statement::DropIndex(di) => exec_drop_index_in_txn(wtx, schema, di),
142 Statement::AlterTable(at) => exec_alter_table_in_txn(wtx, schema, at),
143 Statement::Insert(ins) => {
144 let mut bufs = InsertBufs::new();
145 exec_insert_in_txn(wtx, schema, ins, params, &mut bufs)
146 }
147 Statement::Select(sq) => exec_select_query_in_txn(wtx, schema, sq),
148 Statement::Update(upd) => exec_update_in_txn(wtx, schema, upd),
149 Statement::Delete(del) => exec_delete_in_txn(wtx, schema, del),
150 Statement::Explain(inner) => explain(schema, inner),
151 Statement::Begin | Statement::Commit | Statement::Rollback => {
152 Err(SqlError::Unsupported("nested transaction control".into()))
153 }
154 }
155}
156
157pub fn explain(schema: &SchemaManager, stmt: &Statement) -> Result<ExecutionResult> {
160 let lines = match stmt {
161 Statement::Select(sq) => {
162 let mut lines = Vec::new();
163 let cte_names: Vec<&str> = sq.ctes.iter().map(|c| c.name.as_str()).collect();
164 for cte in &sq.ctes {
165 lines.push(format!("WITH {} AS", cte.name));
166 lines.extend(
167 explain_query_body_cte(schema, &cte.body, &cte_names)?
168 .into_iter()
169 .map(|l| format!(" {l}")),
170 );
171 }
172 lines.extend(explain_query_body_cte(schema, &sq.body, &cte_names)?);
173 lines
174 }
175 Statement::Insert(ins) => match &ins.source {
176 InsertSource::Values(rows) => {
177 vec![format!(
178 "INSERT INTO {} ({} rows)",
179 ins.table.to_ascii_lowercase(),
180 rows.len()
181 )]
182 }
183 InsertSource::Select(sq) => {
184 let mut lines = vec![format!(
185 "INSERT INTO {} ... SELECT",
186 ins.table.to_ascii_lowercase()
187 )];
188 let cte_names: Vec<&str> = sq.ctes.iter().map(|c| c.name.as_str()).collect();
189 for cte in &sq.ctes {
190 lines.push(format!(" WITH {} AS", cte.name));
191 lines.extend(
192 explain_query_body_cte(schema, &cte.body, &cte_names)?
193 .into_iter()
194 .map(|l| format!(" {l}")),
195 );
196 }
197 lines.extend(explain_query_body_cte(schema, &sq.body, &cte_names)?);
198 lines
199 }
200 },
201 Statement::Update(upd) => explain_dml(schema, &upd.table, &upd.where_clause, "UPDATE")?,
202 Statement::Delete(del) => {
203 explain_dml(schema, &del.table, &del.where_clause, "DELETE FROM")?
204 }
205 Statement::AlterTable(at) => {
206 let desc = match &at.op {
207 AlterTableOp::AddColumn { column, .. } => {
208 format!("ALTER TABLE {} ADD COLUMN {}", at.table, column.name)
209 }
210 AlterTableOp::DropColumn { name, .. } => {
211 format!("ALTER TABLE {} DROP COLUMN {}", at.table, name)
212 }
213 AlterTableOp::RenameColumn {
214 old_name, new_name, ..
215 } => {
216 format!(
217 "ALTER TABLE {} RENAME COLUMN {} TO {}",
218 at.table, old_name, new_name
219 )
220 }
221 AlterTableOp::RenameTable { new_name } => {
222 format!("ALTER TABLE {} RENAME TO {}", at.table, new_name)
223 }
224 };
225 vec![desc]
226 }
227 Statement::Explain(_) => {
228 return Err(SqlError::Unsupported("EXPLAIN EXPLAIN".into()));
229 }
230 _ => {
231 return Err(SqlError::Unsupported(
232 "EXPLAIN for this statement type".into(),
233 ));
234 }
235 };
236
237 let rows = lines
238 .into_iter()
239 .map(|line| vec![Value::Text(line.into())])
240 .collect();
241 Ok(ExecutionResult::Query(QueryResult {
242 columns: vec!["plan".into()],
243 rows,
244 }))
245}
246
247fn explain_dml(
248 schema: &SchemaManager,
249 table: &str,
250 where_clause: &Option<Expr>,
251 verb: &str,
252) -> Result<Vec<String>> {
253 let lower = table.to_ascii_lowercase();
254 let table_schema = schema
255 .get(&lower)
256 .ok_or_else(|| SqlError::TableNotFound(table.to_string()))?;
257 let plan = planner::plan_select(table_schema, where_clause);
258 let scan_line = format_scan_line(&lower, &None, &plan, table_schema);
259 Ok(vec![format!("{verb} {}", scan_line)])
260}
261
262fn explain_query_body_cte(
263 schema: &SchemaManager,
264 body: &QueryBody,
265 cte_names: &[&str],
266) -> Result<Vec<String>> {
267 match body {
268 QueryBody::Select(sel) => explain_select_cte(schema, sel, cte_names),
269 QueryBody::Compound(comp) => {
270 let op_name = match (&comp.op, comp.all) {
271 (SetOp::Union, true) => "UNION ALL",
272 (SetOp::Union, false) => "UNION",
273 (SetOp::Intersect, true) => "INTERSECT ALL",
274 (SetOp::Intersect, false) => "INTERSECT",
275 (SetOp::Except, true) => "EXCEPT ALL",
276 (SetOp::Except, false) => "EXCEPT",
277 };
278 let mut lines = vec![op_name.to_string()];
279 let left_lines = explain_query_body_cte(schema, &comp.left, cte_names)?;
280 for l in left_lines {
281 lines.push(format!(" {l}"));
282 }
283 let right_lines = explain_query_body_cte(schema, &comp.right, cte_names)?;
284 for l in right_lines {
285 lines.push(format!(" {l}"));
286 }
287 Ok(lines)
288 }
289 }
290}
291
292fn explain_select_cte(
293 schema: &SchemaManager,
294 stmt: &SelectStmt,
295 cte_names: &[&str],
296) -> Result<Vec<String>> {
297 let mut lines = Vec::new();
298
299 if stmt.from.is_empty() {
300 lines.push("CONSTANT ROW".into());
301 return Ok(lines);
302 }
303
304 let lower_from = stmt.from.to_ascii_lowercase();
305
306 if cte_names
307 .iter()
308 .any(|n| n.eq_ignore_ascii_case(&lower_from))
309 {
310 lines.push(format!("SCAN CTE {lower_from}"));
311 for join in &stmt.joins {
312 let jname = join.table.name.to_ascii_lowercase();
313 if cte_names.iter().any(|n| n.eq_ignore_ascii_case(&jname)) {
314 lines.push(format!("SCAN CTE {jname}"));
315 } else {
316 let js = schema
317 .get(&jname)
318 .ok_or_else(|| SqlError::TableNotFound(join.table.name.clone()))?;
319 let jp = planner::plan_select(js, &None);
320 lines.push(format_scan_line(&jname, &join.table.alias, &jp, js));
321 }
322 }
323 if !stmt.joins.is_empty() {
324 lines.push("NESTED LOOP".into());
325 }
326 if has_any_window_function(stmt) {
327 lines.push("WINDOW".into());
328 }
329 if !stmt.group_by.is_empty() {
330 lines.push("GROUP BY".into());
331 }
332 if stmt.distinct {
333 lines.push("DISTINCT".into());
334 }
335 if !stmt.order_by.is_empty() {
336 lines.push("SORT".into());
337 }
338 if stmt.limit.is_some() {
339 lines.push("LIMIT".into());
340 }
341 return Ok(lines);
342 }
343
344 let from_schema = schema
345 .get(&lower_from)
346 .ok_or_else(|| SqlError::TableNotFound(stmt.from.clone()))?;
347
348 if stmt.joins.is_empty() {
349 let plan = planner::plan_select(from_schema, &stmt.where_clause);
350 lines.push(format_scan_line(
351 &lower_from,
352 &stmt.from_alias,
353 &plan,
354 from_schema,
355 ));
356 } else {
357 let from_plan = planner::plan_select(from_schema, &None);
358 lines.push(format_scan_line(
359 &lower_from,
360 &stmt.from_alias,
361 &from_plan,
362 from_schema,
363 ));
364
365 for join in &stmt.joins {
366 let inner_lower = join.table.name.to_ascii_lowercase();
367 if cte_names
368 .iter()
369 .any(|n| n.eq_ignore_ascii_case(&inner_lower))
370 {
371 lines.push(format!("SCAN CTE {inner_lower}"));
372 } else {
373 let inner_schema = schema
374 .get(&inner_lower)
375 .ok_or_else(|| SqlError::TableNotFound(join.table.name.clone()))?;
376 let inner_plan = planner::plan_select(inner_schema, &None);
377 lines.push(format_scan_line(
378 &inner_lower,
379 &join.table.alias,
380 &inner_plan,
381 inner_schema,
382 ));
383 }
384 }
385
386 let join_type_str = match stmt.joins.last().map(|j| j.join_type) {
387 Some(JoinType::Left) => "LEFT JOIN",
388 Some(JoinType::Right) => "RIGHT JOIN",
389 Some(JoinType::Cross) => "CROSS JOIN",
390 _ => "NESTED LOOP",
391 };
392 lines.push(join_type_str.into());
393 }
394
395 if stmt.where_clause.is_some() && stmt.joins.is_empty() {
396 let plan = planner::plan_select(from_schema, &stmt.where_clause);
397 if matches!(plan, ScanPlan::SeqScan) {
398 lines.push("FILTER".into());
399 }
400 }
401
402 if let Some(ref w) = stmt.where_clause {
403 if !stmt.joins.is_empty() && has_subquery(w) {
404 lines.push("SUBQUERY".into());
405 }
406 }
407
408 explain_subqueries(stmt, &mut lines);
409
410 if has_any_window_function(stmt) {
411 lines.push("WINDOW".into());
412 }
413
414 if !stmt.group_by.is_empty() {
415 lines.push("GROUP BY".into());
416 }
417
418 if stmt.distinct {
419 lines.push("DISTINCT".into());
420 }
421
422 if !stmt.order_by.is_empty() {
423 lines.push("SORT".into());
424 }
425
426 if let Some(ref offset_expr) = stmt.offset {
427 if let Ok(n) = eval_const_int(offset_expr) {
428 lines.push(format!("OFFSET {n}"));
429 } else {
430 lines.push("OFFSET".into());
431 }
432 }
433
434 if let Some(ref limit_expr) = stmt.limit {
435 if let Ok(n) = eval_const_int(limit_expr) {
436 lines.push(format!("LIMIT {n}"));
437 } else {
438 lines.push("LIMIT".into());
439 }
440 }
441
442 Ok(lines)
443}
444
445fn explain_subqueries(stmt: &SelectStmt, lines: &mut Vec<String>) {
446 let mut count = 0;
447 if let Some(ref w) = stmt.where_clause {
448 count += count_subqueries(w);
449 }
450 if let Some(ref h) = stmt.having {
451 count += count_subqueries(h);
452 }
453 for col in &stmt.columns {
454 if let SelectColumn::Expr { expr, .. } = col {
455 count += count_subqueries(expr);
456 }
457 }
458 for _ in 0..count {
459 lines.push("SUBQUERY".into());
460 }
461}
462
463fn count_subqueries(expr: &Expr) -> usize {
464 match expr {
465 Expr::InSubquery { expr: e, .. } => 1 + count_subqueries(e),
466 Expr::ScalarSubquery(_) => 1,
467 Expr::Exists { .. } => 1,
468 Expr::BinaryOp { left, right, .. } => count_subqueries(left) + count_subqueries(right),
469 Expr::UnaryOp { expr: e, .. } => count_subqueries(e),
470 Expr::IsNull(e) | Expr::IsNotNull(e) => count_subqueries(e),
471 Expr::Function { args, .. } => args.iter().map(count_subqueries).sum(),
472 Expr::Between {
473 expr: e, low, high, ..
474 } => count_subqueries(e) + count_subqueries(low) + count_subqueries(high),
475 Expr::Like {
476 expr: e, pattern, ..
477 } => count_subqueries(e) + count_subqueries(pattern),
478 Expr::Case {
479 operand,
480 conditions,
481 else_result,
482 } => {
483 let mut n = 0;
484 if let Some(op) = operand {
485 n += count_subqueries(op);
486 }
487 for (c, r) in conditions {
488 n += count_subqueries(c) + count_subqueries(r);
489 }
490 if let Some(el) = else_result {
491 n += count_subqueries(el);
492 }
493 n
494 }
495 Expr::Coalesce(args) => args.iter().map(count_subqueries).sum(),
496 Expr::Cast { expr: e, .. } => count_subqueries(e),
497 Expr::InList { expr: e, list, .. } => {
498 count_subqueries(e) + list.iter().map(count_subqueries).sum::<usize>()
499 }
500 _ => 0,
501 }
502}
503
504fn format_scan_line(
505 table_name: &str,
506 alias: &Option<String>,
507 plan: &ScanPlan,
508 table_schema: &TableSchema,
509) -> String {
510 let alias_part = match alias {
511 Some(a) if !a.eq_ignore_ascii_case(table_name) => {
512 format!(" AS {}", a.to_ascii_lowercase())
513 }
514 _ => String::new(),
515 };
516
517 let desc = planner::describe_plan(plan, table_schema);
518
519 if desc.is_empty() {
520 format!("SCAN TABLE {table_name}{alias_part}")
521 } else {
522 format!("SEARCH TABLE {table_name}{alias_part} {desc}")
523 }
524}
525
526fn validate_foreign_keys(
531 schema: &SchemaManager,
532 table_schema: &TableSchema,
533 foreign_keys: &[ForeignKeySchemaEntry],
534) -> Result<()> {
535 for fk in foreign_keys {
536 let parent = if fk.foreign_table == table_schema.name {
538 table_schema
539 } else {
540 schema.get(&fk.foreign_table).ok_or_else(|| {
541 SqlError::Unsupported(format!(
542 "FOREIGN KEY references non-existent table '{}'",
543 fk.foreign_table
544 ))
545 })?
546 };
547
548 let ref_col_indices: Vec<u16> = fk
550 .referred_columns
551 .iter()
552 .map(|rc| {
553 parent
554 .column_index(rc)
555 .map(|i| i as u16)
556 .ok_or_else(|| SqlError::ColumnNotFound(rc.clone()))
557 })
558 .collect::<Result<_>>()?;
559
560 if fk.columns.len() != ref_col_indices.len() {
561 return Err(SqlError::Unsupported(format!(
562 "FOREIGN KEY on '{}': column count mismatch",
563 table_schema.name
564 )));
565 }
566
567 let is_pk = parent.primary_key_columns == ref_col_indices;
569
570 let has_unique = !is_pk
572 && parent
573 .indices
574 .iter()
575 .any(|idx| idx.unique && idx.columns == ref_col_indices);
576
577 if !is_pk && !has_unique {
578 return Err(SqlError::Unsupported(format!(
579 "FOREIGN KEY on '{}': referred columns in '{}' are not PRIMARY KEY or UNIQUE",
580 table_schema.name, fk.foreign_table
581 )));
582 }
583 }
584 Ok(())
585}
586
587fn create_fk_auto_indices(
589 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
590 mut table_schema: TableSchema,
591) -> Result<TableSchema> {
592 let fks: Vec<(Vec<u16>, String)> = table_schema
593 .foreign_keys
594 .iter()
595 .enumerate()
596 .map(|(i, fk)| {
597 let name = fk
598 .name
599 .as_ref()
600 .map(|n| format!("__fk_{}_{}", table_schema.name, n))
601 .unwrap_or_else(|| format!("__fk_{}_{}", table_schema.name, i));
602 (fk.columns.clone(), name)
603 })
604 .collect();
605
606 for (cols, idx_name) in fks {
607 let already_covered = table_schema.indices.iter().any(|idx| idx.columns == cols);
609 if already_covered {
610 continue;
611 }
612
613 let idx_def = IndexDef {
614 name: idx_name.clone(),
615 columns: cols,
616 unique: false,
617 };
618 let idx_table = TableSchema::index_table_name(&table_schema.name, &idx_name);
619 wtx.create_table(&idx_table).map_err(SqlError::Storage)?;
620 table_schema.indices.push(idx_def);
622 }
623 Ok(table_schema)
624}
625
626fn exec_create_table(
629 db: &Database,
630 schema: &mut SchemaManager,
631 stmt: &CreateTableStmt,
632) -> Result<ExecutionResult> {
633 let lower_name = stmt.name.to_ascii_lowercase();
634
635 if schema.contains(&lower_name) {
636 if stmt.if_not_exists {
637 return Ok(ExecutionResult::Ok);
638 }
639 return Err(SqlError::TableAlreadyExists(stmt.name.clone()));
640 }
641
642 if stmt.primary_key.is_empty() {
643 return Err(SqlError::PrimaryKeyRequired);
644 }
645
646 let mut seen = std::collections::HashSet::new();
647 for col in &stmt.columns {
648 let lower = col.name.to_ascii_lowercase();
649 if !seen.insert(lower.clone()) {
650 return Err(SqlError::DuplicateColumn(col.name.clone()));
651 }
652 }
653
654 let columns: Vec<ColumnDef> = stmt
655 .columns
656 .iter()
657 .enumerate()
658 .map(|(i, c)| ColumnDef {
659 name: c.name.to_ascii_lowercase(),
660 data_type: c.data_type,
661 nullable: c.nullable,
662 position: i as u16,
663 default_expr: c.default_expr.clone(),
664 default_sql: c.default_sql.clone(),
665 check_expr: c.check_expr.clone(),
666 check_sql: c.check_sql.clone(),
667 check_name: c.check_name.clone(),
668 })
669 .collect();
670
671 let primary_key_columns: Vec<u16> = stmt
672 .primary_key
673 .iter()
674 .map(|pk_name| {
675 let lower = pk_name.to_ascii_lowercase();
676 columns
677 .iter()
678 .position(|c| c.name == lower)
679 .map(|i| i as u16)
680 .ok_or_else(|| SqlError::ColumnNotFound(pk_name.clone()))
681 })
682 .collect::<Result<_>>()?;
683
684 let check_constraints: Vec<TableCheckDef> = stmt
685 .check_constraints
686 .iter()
687 .map(|tc| TableCheckDef {
688 name: tc.name.clone(),
689 expr: tc.expr.clone(),
690 sql: tc.sql.clone(),
691 })
692 .collect();
693
694 let foreign_keys: Vec<ForeignKeySchemaEntry> = stmt
695 .foreign_keys
696 .iter()
697 .map(|fk| {
698 let col_indices: Vec<u16> = fk
699 .columns
700 .iter()
701 .map(|cn| {
702 let lower = cn.to_ascii_lowercase();
703 columns
704 .iter()
705 .position(|c| c.name == lower)
706 .map(|i| i as u16)
707 .ok_or_else(|| SqlError::ColumnNotFound(cn.clone()))
708 })
709 .collect::<Result<_>>()?;
710 Ok(ForeignKeySchemaEntry {
711 name: fk.name.clone(),
712 columns: col_indices,
713 foreign_table: fk.foreign_table.to_ascii_lowercase(),
714 referred_columns: fk
715 .referred_columns
716 .iter()
717 .map(|s| s.to_ascii_lowercase())
718 .collect(),
719 })
720 })
721 .collect::<Result<_>>()?;
722
723 let table_schema = TableSchema::new(
724 lower_name.clone(),
725 columns,
726 primary_key_columns,
727 vec![],
728 check_constraints,
729 foreign_keys,
730 );
731
732 validate_foreign_keys(schema, &table_schema, &table_schema.foreign_keys)?;
734
735 let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
736 SchemaManager::ensure_schema_table(&mut wtx)?;
737 wtx.create_table(lower_name.as_bytes())
738 .map_err(SqlError::Storage)?;
739
740 let table_schema = create_fk_auto_indices(&mut wtx, table_schema)?;
742
743 SchemaManager::save_schema(&mut wtx, &table_schema)?;
744 wtx.commit().map_err(SqlError::Storage)?;
745
746 schema.register(table_schema);
747 Ok(ExecutionResult::Ok)
748}
749
750fn exec_drop_table(
751 db: &Database,
752 schema: &mut SchemaManager,
753 stmt: &DropTableStmt,
754) -> Result<ExecutionResult> {
755 let lower_name = stmt.name.to_ascii_lowercase();
756
757 if !schema.contains(&lower_name) {
758 if stmt.if_exists {
759 return Ok(ExecutionResult::Ok);
760 }
761 return Err(SqlError::TableNotFound(stmt.name.clone()));
762 }
763
764 for (child_table, _fk) in schema.child_fks_for(&lower_name) {
766 if child_table != lower_name {
767 return Err(SqlError::ForeignKeyViolation(format!(
768 "cannot drop table '{}': referenced by foreign key in '{}'",
769 lower_name, child_table
770 )));
771 }
772 }
773
774 let table_schema = schema.get(&lower_name).unwrap();
775 let idx_tables: Vec<Vec<u8>> = table_schema
776 .indices
777 .iter()
778 .map(|idx| TableSchema::index_table_name(&lower_name, &idx.name))
779 .collect();
780
781 let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
782 for idx_table in &idx_tables {
783 wtx.drop_table(idx_table).map_err(SqlError::Storage)?;
784 }
785 wtx.drop_table(lower_name.as_bytes())
786 .map_err(SqlError::Storage)?;
787 SchemaManager::delete_schema(&mut wtx, &lower_name)?;
788 wtx.commit().map_err(SqlError::Storage)?;
789
790 schema.remove(&lower_name);
791 Ok(ExecutionResult::Ok)
792}
793
794fn exec_create_table_in_txn(
795 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
796 schema: &mut SchemaManager,
797 stmt: &CreateTableStmt,
798) -> Result<ExecutionResult> {
799 let lower_name = stmt.name.to_ascii_lowercase();
800
801 if schema.contains(&lower_name) {
802 if stmt.if_not_exists {
803 return Ok(ExecutionResult::Ok);
804 }
805 return Err(SqlError::TableAlreadyExists(stmt.name.clone()));
806 }
807
808 if stmt.primary_key.is_empty() {
809 return Err(SqlError::PrimaryKeyRequired);
810 }
811
812 let mut seen = std::collections::HashSet::new();
813 for col in &stmt.columns {
814 let lower = col.name.to_ascii_lowercase();
815 if !seen.insert(lower.clone()) {
816 return Err(SqlError::DuplicateColumn(col.name.clone()));
817 }
818 }
819
820 let columns: Vec<ColumnDef> = stmt
821 .columns
822 .iter()
823 .enumerate()
824 .map(|(i, c)| ColumnDef {
825 name: c.name.to_ascii_lowercase(),
826 data_type: c.data_type,
827 nullable: c.nullable,
828 position: i as u16,
829 default_expr: c.default_expr.clone(),
830 default_sql: c.default_sql.clone(),
831 check_expr: c.check_expr.clone(),
832 check_sql: c.check_sql.clone(),
833 check_name: c.check_name.clone(),
834 })
835 .collect();
836
837 let primary_key_columns: Vec<u16> = stmt
838 .primary_key
839 .iter()
840 .map(|pk_name| {
841 let lower = pk_name.to_ascii_lowercase();
842 columns
843 .iter()
844 .position(|c| c.name == lower)
845 .map(|i| i as u16)
846 .ok_or_else(|| SqlError::ColumnNotFound(pk_name.clone()))
847 })
848 .collect::<Result<_>>()?;
849
850 let check_constraints: Vec<TableCheckDef> = stmt
851 .check_constraints
852 .iter()
853 .map(|tc| TableCheckDef {
854 name: tc.name.clone(),
855 expr: tc.expr.clone(),
856 sql: tc.sql.clone(),
857 })
858 .collect();
859
860 let foreign_keys: Vec<ForeignKeySchemaEntry> = stmt
861 .foreign_keys
862 .iter()
863 .map(|fk| {
864 let col_indices: Vec<u16> = fk
865 .columns
866 .iter()
867 .map(|cn| {
868 let lower = cn.to_ascii_lowercase();
869 columns
870 .iter()
871 .position(|c| c.name == lower)
872 .map(|i| i as u16)
873 .ok_or_else(|| SqlError::ColumnNotFound(cn.clone()))
874 })
875 .collect::<Result<_>>()?;
876 Ok(ForeignKeySchemaEntry {
877 name: fk.name.clone(),
878 columns: col_indices,
879 foreign_table: fk.foreign_table.to_ascii_lowercase(),
880 referred_columns: fk
881 .referred_columns
882 .iter()
883 .map(|s| s.to_ascii_lowercase())
884 .collect(),
885 })
886 })
887 .collect::<Result<_>>()?;
888
889 let table_schema = TableSchema::new(
890 lower_name.clone(),
891 columns,
892 primary_key_columns,
893 vec![],
894 check_constraints,
895 foreign_keys,
896 );
897
898 validate_foreign_keys(schema, &table_schema, &table_schema.foreign_keys)?;
899
900 SchemaManager::ensure_schema_table(wtx)?;
901 wtx.create_table(lower_name.as_bytes())
902 .map_err(SqlError::Storage)?;
903
904 let table_schema = create_fk_auto_indices(wtx, table_schema)?;
905
906 SchemaManager::save_schema(wtx, &table_schema)?;
907
908 schema.register(table_schema);
909 Ok(ExecutionResult::Ok)
910}
911
912fn exec_drop_table_in_txn(
913 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
914 schema: &mut SchemaManager,
915 stmt: &DropTableStmt,
916) -> Result<ExecutionResult> {
917 let lower_name = stmt.name.to_ascii_lowercase();
918
919 if !schema.contains(&lower_name) {
920 if stmt.if_exists {
921 return Ok(ExecutionResult::Ok);
922 }
923 return Err(SqlError::TableNotFound(stmt.name.clone()));
924 }
925
926 for (child_table, _fk) in schema.child_fks_for(&lower_name) {
928 if child_table != lower_name {
929 return Err(SqlError::ForeignKeyViolation(format!(
930 "cannot drop table '{}': referenced by foreign key in '{}'",
931 lower_name, child_table
932 )));
933 }
934 }
935
936 let table_schema = schema.get(&lower_name).unwrap();
937 let idx_tables: Vec<Vec<u8>> = table_schema
938 .indices
939 .iter()
940 .map(|idx| TableSchema::index_table_name(&lower_name, &idx.name))
941 .collect();
942
943 for idx_table in &idx_tables {
944 wtx.drop_table(idx_table).map_err(SqlError::Storage)?;
945 }
946 wtx.drop_table(lower_name.as_bytes())
947 .map_err(SqlError::Storage)?;
948 SchemaManager::delete_schema(wtx, &lower_name)?;
949
950 schema.remove(&lower_name);
951 Ok(ExecutionResult::Ok)
952}
953
954fn exec_create_index(
955 db: &Database,
956 schema: &mut SchemaManager,
957 stmt: &CreateIndexStmt,
958) -> Result<ExecutionResult> {
959 let lower_table = stmt.table_name.to_ascii_lowercase();
960 let lower_idx = stmt.index_name.to_ascii_lowercase();
961
962 let table_schema = schema
963 .get(&lower_table)
964 .ok_or_else(|| SqlError::TableNotFound(stmt.table_name.clone()))?;
965
966 if table_schema.index_by_name(&lower_idx).is_some() {
967 if stmt.if_not_exists {
968 return Ok(ExecutionResult::Ok);
969 }
970 return Err(SqlError::IndexAlreadyExists(stmt.index_name.clone()));
971 }
972
973 let col_indices: Vec<u16> = stmt
974 .columns
975 .iter()
976 .map(|col_name| {
977 let lower = col_name.to_ascii_lowercase();
978 table_schema
979 .column_index(&lower)
980 .map(|i| i as u16)
981 .ok_or_else(|| SqlError::ColumnNotFound(col_name.clone()))
982 })
983 .collect::<Result<_>>()?;
984
985 let idx_def = IndexDef {
986 name: lower_idx.clone(),
987 columns: col_indices,
988 unique: stmt.unique,
989 };
990
991 let idx_table = TableSchema::index_table_name(&lower_table, &lower_idx);
992
993 let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
994 SchemaManager::ensure_schema_table(&mut wtx)?;
995 wtx.create_table(&idx_table).map_err(SqlError::Storage)?;
996
997 let pk_indices = table_schema.pk_indices();
999 let mut rows: Vec<Vec<Value>> = Vec::new();
1000 {
1001 let mut scan_err: Option<SqlError> = None;
1002 wtx.table_for_each(lower_table.as_bytes(), |key, value| {
1003 match decode_full_row(table_schema, key, value) {
1004 Ok(row) => rows.push(row),
1005 Err(e) => scan_err = Some(e),
1006 }
1007 Ok(())
1008 })
1009 .map_err(SqlError::Storage)?;
1010 if let Some(e) = scan_err {
1011 return Err(e);
1012 }
1013 }
1014
1015 for row in &rows {
1016 let pk_values: Vec<Value> = pk_indices.iter().map(|&i| row[i].clone()).collect();
1017 let key = encode_index_key(&idx_def, row, &pk_values);
1018 let value = encode_index_value(&idx_def, row, &pk_values);
1019 let is_new = wtx
1020 .table_insert(&idx_table, &key, &value)
1021 .map_err(SqlError::Storage)?;
1022 if idx_def.unique && !is_new {
1023 let indexed_values: Vec<Value> = idx_def
1024 .columns
1025 .iter()
1026 .map(|&col_idx| row[col_idx as usize].clone())
1027 .collect();
1028 let any_null = indexed_values.iter().any(|v| v.is_null());
1029 if !any_null {
1030 return Err(SqlError::UniqueViolation(stmt.index_name.clone()));
1031 }
1032 }
1033 }
1034
1035 let mut updated_schema = table_schema.clone();
1036 updated_schema.indices.push(idx_def);
1037 SchemaManager::save_schema(&mut wtx, &updated_schema)?;
1038 wtx.commit().map_err(SqlError::Storage)?;
1039
1040 schema.register(updated_schema);
1041 Ok(ExecutionResult::Ok)
1042}
1043
1044fn exec_drop_index(
1045 db: &Database,
1046 schema: &mut SchemaManager,
1047 stmt: &DropIndexStmt,
1048) -> Result<ExecutionResult> {
1049 let lower_idx = stmt.index_name.to_ascii_lowercase();
1050
1051 let (table_name, _idx_pos) = match find_index_in_schemas(schema, &lower_idx) {
1052 Some(found) => found,
1053 None => {
1054 if stmt.if_exists {
1055 return Ok(ExecutionResult::Ok);
1056 }
1057 return Err(SqlError::IndexNotFound(stmt.index_name.clone()));
1058 }
1059 };
1060
1061 let idx_table = TableSchema::index_table_name(&table_name, &lower_idx);
1062
1063 let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
1064 wtx.drop_table(&idx_table).map_err(SqlError::Storage)?;
1065
1066 let table_schema = schema.get(&table_name).unwrap();
1067 let mut updated_schema = table_schema.clone();
1068 updated_schema.indices.retain(|i| i.name != lower_idx);
1069 SchemaManager::save_schema(&mut wtx, &updated_schema)?;
1070 wtx.commit().map_err(SqlError::Storage)?;
1071
1072 schema.register(updated_schema);
1073 Ok(ExecutionResult::Ok)
1074}
1075
1076fn exec_create_index_in_txn(
1077 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1078 schema: &mut SchemaManager,
1079 stmt: &CreateIndexStmt,
1080) -> Result<ExecutionResult> {
1081 let lower_table = stmt.table_name.to_ascii_lowercase();
1082 let lower_idx = stmt.index_name.to_ascii_lowercase();
1083
1084 let table_schema = schema
1085 .get(&lower_table)
1086 .ok_or_else(|| SqlError::TableNotFound(stmt.table_name.clone()))?;
1087
1088 if table_schema.index_by_name(&lower_idx).is_some() {
1089 if stmt.if_not_exists {
1090 return Ok(ExecutionResult::Ok);
1091 }
1092 return Err(SqlError::IndexAlreadyExists(stmt.index_name.clone()));
1093 }
1094
1095 let col_indices: Vec<u16> = stmt
1096 .columns
1097 .iter()
1098 .map(|col_name| {
1099 let lower = col_name.to_ascii_lowercase();
1100 table_schema
1101 .column_index(&lower)
1102 .map(|i| i as u16)
1103 .ok_or_else(|| SqlError::ColumnNotFound(col_name.clone()))
1104 })
1105 .collect::<Result<_>>()?;
1106
1107 let idx_def = IndexDef {
1108 name: lower_idx.clone(),
1109 columns: col_indices,
1110 unique: stmt.unique,
1111 };
1112
1113 let idx_table = TableSchema::index_table_name(&lower_table, &lower_idx);
1114
1115 SchemaManager::ensure_schema_table(wtx)?;
1116 wtx.create_table(&idx_table).map_err(SqlError::Storage)?;
1117
1118 let pk_indices = table_schema.pk_indices();
1119 let mut rows: Vec<Vec<Value>> = Vec::new();
1120 {
1121 let mut scan_err: Option<SqlError> = None;
1122 wtx.table_for_each(lower_table.as_bytes(), |key, value| {
1123 match decode_full_row(table_schema, key, value) {
1124 Ok(row) => rows.push(row),
1125 Err(e) => scan_err = Some(e),
1126 }
1127 Ok(())
1128 })
1129 .map_err(SqlError::Storage)?;
1130 if let Some(e) = scan_err {
1131 return Err(e);
1132 }
1133 }
1134
1135 for row in &rows {
1136 let pk_values: Vec<Value> = pk_indices.iter().map(|&i| row[i].clone()).collect();
1137 let key = encode_index_key(&idx_def, row, &pk_values);
1138 let value = encode_index_value(&idx_def, row, &pk_values);
1139 let is_new = wtx
1140 .table_insert(&idx_table, &key, &value)
1141 .map_err(SqlError::Storage)?;
1142 if idx_def.unique && !is_new {
1143 let indexed_values: Vec<Value> = idx_def
1144 .columns
1145 .iter()
1146 .map(|&col_idx| row[col_idx as usize].clone())
1147 .collect();
1148 let any_null = indexed_values.iter().any(|v| v.is_null());
1149 if !any_null {
1150 return Err(SqlError::UniqueViolation(stmt.index_name.clone()));
1151 }
1152 }
1153 }
1154
1155 let mut updated_schema = table_schema.clone();
1156 updated_schema.indices.push(idx_def);
1157 SchemaManager::save_schema(wtx, &updated_schema)?;
1158
1159 schema.register(updated_schema);
1160 Ok(ExecutionResult::Ok)
1161}
1162
1163fn exec_drop_index_in_txn(
1164 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1165 schema: &mut SchemaManager,
1166 stmt: &DropIndexStmt,
1167) -> Result<ExecutionResult> {
1168 let lower_idx = stmt.index_name.to_ascii_lowercase();
1169
1170 let (table_name, _idx_pos) = match find_index_in_schemas(schema, &lower_idx) {
1171 Some(found) => found,
1172 None => {
1173 if stmt.if_exists {
1174 return Ok(ExecutionResult::Ok);
1175 }
1176 return Err(SqlError::IndexNotFound(stmt.index_name.clone()));
1177 }
1178 };
1179
1180 let idx_table = TableSchema::index_table_name(&table_name, &lower_idx);
1181 wtx.drop_table(&idx_table).map_err(SqlError::Storage)?;
1182
1183 let table_schema = schema.get(&table_name).unwrap();
1184 let mut updated_schema = table_schema.clone();
1185 updated_schema.indices.retain(|i| i.name != lower_idx);
1186 SchemaManager::save_schema(wtx, &updated_schema)?;
1187
1188 schema.register(updated_schema);
1189 Ok(ExecutionResult::Ok)
1190}
1191
1192fn exec_alter_table(
1195 db: &Database,
1196 schema: &mut SchemaManager,
1197 stmt: &AlterTableStmt,
1198) -> Result<ExecutionResult> {
1199 let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
1200 SchemaManager::ensure_schema_table(&mut wtx)?;
1201 alter_table_impl(&mut wtx, schema, stmt)?;
1202 wtx.commit().map_err(SqlError::Storage)?;
1203 Ok(ExecutionResult::Ok)
1204}
1205
1206fn exec_alter_table_in_txn(
1207 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1208 schema: &mut SchemaManager,
1209 stmt: &AlterTableStmt,
1210) -> Result<ExecutionResult> {
1211 SchemaManager::ensure_schema_table(wtx)?;
1212 alter_table_impl(wtx, schema, stmt)?;
1213 Ok(ExecutionResult::Ok)
1214}
1215
1216fn alter_table_impl(
1217 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1218 schema: &mut SchemaManager,
1219 stmt: &AlterTableStmt,
1220) -> Result<()> {
1221 let lower_name = stmt.table.to_ascii_lowercase();
1222 if lower_name == "_schema" {
1223 return Err(SqlError::Unsupported("cannot alter internal table".into()));
1224 }
1225 match &stmt.op {
1226 AlterTableOp::AddColumn {
1227 column,
1228 foreign_key,
1229 if_not_exists,
1230 } => alter_add_column(
1231 wtx,
1232 schema,
1233 &lower_name,
1234 column,
1235 foreign_key.as_ref(),
1236 *if_not_exists,
1237 ),
1238 AlterTableOp::DropColumn { name, if_exists } => {
1239 alter_drop_column(wtx, schema, &lower_name, name, *if_exists)
1240 }
1241 AlterTableOp::RenameColumn { old_name, new_name } => {
1242 alter_rename_column(wtx, schema, &lower_name, old_name, new_name)
1243 }
1244 AlterTableOp::RenameTable { new_name } => {
1245 alter_rename_table(wtx, schema, &lower_name, new_name)
1246 }
1247 }
1248}
1249
1250fn alter_add_column(
1251 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1252 schema: &mut SchemaManager,
1253 table_name: &str,
1254 col_spec: &ColumnSpec,
1255 fk_def: Option<&ForeignKeyDef>,
1256 if_not_exists: bool,
1257) -> Result<()> {
1258 let table_schema = schema
1259 .get(table_name)
1260 .ok_or_else(|| SqlError::TableNotFound(table_name.into()))?;
1261
1262 let col_lower = col_spec.name.to_ascii_lowercase();
1263
1264 if table_schema.column_index(&col_lower).is_some() {
1265 if if_not_exists {
1266 return Ok(());
1267 }
1268 return Err(SqlError::DuplicateColumn(col_spec.name.clone()));
1269 }
1270
1271 if col_spec.is_primary_key {
1272 return Err(SqlError::Unsupported(
1273 "cannot add PRIMARY KEY column via ALTER TABLE".into(),
1274 ));
1275 }
1276
1277 if !col_spec.nullable && col_spec.default_expr.is_none() {
1278 let count = wtx.table_entry_count(table_name.as_bytes()).unwrap_or(0);
1279 if count > 0 {
1280 return Err(SqlError::Unsupported(
1281 "cannot add NOT NULL column without DEFAULT to non-empty table".into(),
1282 ));
1283 }
1284 }
1285
1286 if let Some(ref check) = col_spec.check_expr {
1287 if has_subquery(check) {
1288 return Err(SqlError::Unsupported("subquery in CHECK constraint".into()));
1289 }
1290 }
1291
1292 let new_pos = table_schema.columns.len() as u16;
1293 let new_col = ColumnDef {
1294 name: col_lower.clone(),
1295 data_type: col_spec.data_type,
1296 nullable: col_spec.nullable,
1297 position: new_pos,
1298 default_expr: col_spec.default_expr.clone(),
1299 default_sql: col_spec.default_sql.clone(),
1300 check_expr: col_spec.check_expr.clone(),
1301 check_sql: col_spec.check_sql.clone(),
1302 check_name: col_spec.check_name.clone(),
1303 };
1304
1305 let mut new_schema = table_schema.clone();
1306 new_schema.columns.push(new_col);
1307
1308 if let Some(fk) = fk_def {
1309 let col_idx = new_pos;
1310 let fk_entry = ForeignKeySchemaEntry {
1311 name: fk.name.clone(),
1312 columns: vec![col_idx],
1313 foreign_table: fk.foreign_table.to_ascii_lowercase(),
1314 referred_columns: fk
1315 .referred_columns
1316 .iter()
1317 .map(|s| s.to_ascii_lowercase())
1318 .collect(),
1319 };
1320 new_schema.foreign_keys.push(fk_entry);
1321 }
1322
1323 new_schema = new_schema.rebuild();
1324
1325 if fk_def.is_some() {
1326 validate_foreign_keys(schema, &new_schema, &new_schema.foreign_keys)?;
1327 new_schema = create_fk_auto_indices(wtx, new_schema)?;
1328 }
1329
1330 SchemaManager::save_schema(wtx, &new_schema)?;
1331 schema.register(new_schema);
1332 Ok(())
1333}
1334
1335fn alter_drop_column(
1336 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1337 schema: &mut SchemaManager,
1338 table_name: &str,
1339 col_name: &str,
1340 if_exists: bool,
1341) -> Result<()> {
1342 let table_schema = schema
1343 .get(table_name)
1344 .ok_or_else(|| SqlError::TableNotFound(table_name.into()))?;
1345
1346 let col_lower = col_name.to_ascii_lowercase();
1347 let drop_pos = match table_schema.column_index(&col_lower) {
1348 Some(pos) => pos,
1349 None => {
1350 if if_exists {
1351 return Ok(());
1352 }
1353 return Err(SqlError::ColumnNotFound(col_name.into()));
1354 }
1355 };
1356 let drop_pos_u16 = drop_pos as u16;
1357
1358 if table_schema.primary_key_columns.contains(&drop_pos_u16) {
1359 return Err(SqlError::Unsupported(
1360 "cannot drop primary key column".into(),
1361 ));
1362 }
1363
1364 for idx in &table_schema.indices {
1365 if idx.columns.contains(&drop_pos_u16) {
1366 return Err(SqlError::Unsupported(format!(
1367 "column '{}' is indexed by '{}'; drop the index first",
1368 col_lower, idx.name
1369 )));
1370 }
1371 }
1372
1373 for fk in &table_schema.foreign_keys {
1374 if fk.columns.contains(&drop_pos_u16) {
1375 return Err(SqlError::Unsupported(format!(
1376 "column '{}' is part of a foreign key",
1377 col_lower
1378 )));
1379 }
1380 }
1381
1382 for (child_table, fk) in schema.child_fks_for(table_name) {
1383 if child_table == table_name {
1384 continue; }
1386 if fk.referred_columns.iter().any(|rc| rc == &col_lower) {
1387 return Err(SqlError::Unsupported(format!(
1388 "column '{}' is referenced by a foreign key in '{}'",
1389 col_lower, child_table
1390 )));
1391 }
1392 }
1393
1394 for tc in &table_schema.check_constraints {
1395 if tc.sql.to_ascii_lowercase().contains(&col_lower) {
1396 return Err(SqlError::Unsupported(format!(
1397 "column '{}' is used in CHECK constraint '{}'",
1398 col_lower,
1399 tc.name.as_deref().unwrap_or("<unnamed>")
1400 )));
1401 }
1402 }
1403
1404 let new_schema = table_schema.without_column(drop_pos);
1408
1409 SchemaManager::save_schema(wtx, &new_schema)?;
1410 schema.register(new_schema);
1411 Ok(())
1412}
1413
1414fn alter_rename_column(
1415 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1416 schema: &mut SchemaManager,
1417 table_name: &str,
1418 old_name: &str,
1419 new_name: &str,
1420) -> Result<()> {
1421 let table_schema = schema
1422 .get(table_name)
1423 .ok_or_else(|| SqlError::TableNotFound(table_name.into()))?;
1424
1425 let old_lower = old_name.to_ascii_lowercase();
1426 let new_lower = new_name.to_ascii_lowercase();
1427
1428 let col_pos = table_schema
1429 .column_index(&old_lower)
1430 .ok_or_else(|| SqlError::ColumnNotFound(old_name.into()))?;
1431
1432 if table_schema.column_index(&new_lower).is_some() {
1433 return Err(SqlError::DuplicateColumn(new_name.into()));
1434 }
1435
1436 let mut new_schema = table_schema.clone();
1437 new_schema.columns[col_pos].name = new_lower.clone();
1438
1439 for col in &mut new_schema.columns {
1441 if let Some(ref sql) = col.check_sql {
1442 if sql.to_ascii_lowercase().contains(&old_lower) {
1443 let updated = sql.replace(&old_lower, &new_lower);
1444 col.check_sql = Some(updated.clone());
1445 if let Ok(parsed) = crate::parser::parse_sql_expr(&updated) {
1447 col.check_expr = Some(parsed);
1448 }
1449 }
1450 }
1451 }
1452 for tc in &mut new_schema.check_constraints {
1453 if tc.sql.to_ascii_lowercase().contains(&old_lower) {
1454 tc.sql = tc.sql.replace(&old_lower, &new_lower);
1455 if let Ok(parsed) = crate::parser::parse_sql_expr(&tc.sql) {
1456 tc.expr = parsed;
1457 }
1458 }
1459 }
1460
1461 for fk in &mut new_schema.foreign_keys {
1465 if fk.foreign_table == table_name {
1466 for rc in &mut fk.referred_columns {
1467 if *rc == old_lower {
1468 *rc = new_lower.clone();
1469 }
1470 }
1471 }
1472 }
1473
1474 SchemaManager::save_schema(wtx, &new_schema)?;
1475 schema.register(new_schema);
1476 Ok(())
1477}
1478
1479fn alter_rename_table(
1480 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1481 schema: &mut SchemaManager,
1482 old_name: &str,
1483 new_name: &str,
1484) -> Result<()> {
1485 let new_lower = new_name.to_ascii_lowercase();
1486
1487 if new_lower == "_schema" {
1488 return Err(SqlError::Unsupported(
1489 "cannot rename to internal table name".into(),
1490 ));
1491 }
1492
1493 let table_schema = schema
1494 .get(old_name)
1495 .ok_or_else(|| SqlError::TableNotFound(old_name.into()))?
1496 .clone();
1497
1498 if schema.contains(&new_lower) {
1499 return Err(SqlError::TableAlreadyExists(new_name.into()));
1500 }
1501
1502 wtx.rename_table(old_name.as_bytes(), new_lower.as_bytes())
1503 .map_err(SqlError::Storage)?;
1504
1505 let idx_renames: Vec<(Vec<u8>, Vec<u8>)> = table_schema
1506 .indices
1507 .iter()
1508 .map(|idx| {
1509 let old_idx = TableSchema::index_table_name(old_name, &idx.name);
1510 let new_idx = TableSchema::index_table_name(&new_lower, &idx.name);
1511 (old_idx, new_idx)
1512 })
1513 .collect();
1514 for (old_idx, new_idx) in &idx_renames {
1515 wtx.rename_table(old_idx, new_idx)
1516 .map_err(SqlError::Storage)?;
1517 }
1518
1519 let child_tables: Vec<String> = schema
1520 .child_fks_for(old_name)
1521 .iter()
1522 .filter(|(child, _)| *child != old_name)
1523 .map(|(child, _)| child.to_string())
1524 .collect::<std::collections::HashSet<_>>()
1525 .into_iter()
1526 .collect();
1527
1528 for child_table in &child_tables {
1529 let mut updated_child = schema.get(child_table).unwrap().clone();
1530 for fk in &mut updated_child.foreign_keys {
1531 if fk.foreign_table == old_name {
1532 fk.foreign_table = new_lower.clone();
1533 }
1534 }
1535 SchemaManager::save_schema(wtx, &updated_child)?;
1536 schema.register(updated_child);
1537 }
1538
1539 SchemaManager::delete_schema(wtx, old_name)?;
1540 let mut new_schema = table_schema;
1541 new_schema.name = new_lower.clone();
1542
1543 for fk in &mut new_schema.foreign_keys {
1545 if fk.foreign_table == old_name {
1546 fk.foreign_table = new_lower.clone();
1547 }
1548 }
1549
1550 SchemaManager::save_schema(wtx, &new_schema)?;
1551 schema.remove(old_name);
1552 schema.register(new_schema);
1553 Ok(())
1554}
1555
1556fn find_index_in_schemas(schema: &SchemaManager, index_name: &str) -> Option<(String, usize)> {
1557 for table_name in schema.table_names() {
1558 if let Some(ts) = schema.get(table_name) {
1559 if let Some(pos) = ts.indices.iter().position(|i| i.name == index_name) {
1560 return Some((table_name.to_string(), pos));
1561 }
1562 }
1563 }
1564 None
1565}
1566
1567fn extract_pk_key(
1570 idx_key: &[u8],
1571 idx_value: &[u8],
1572 is_unique: bool,
1573 num_index_cols: usize,
1574 num_pk_cols: usize,
1575) -> Result<Vec<u8>> {
1576 if is_unique && !idx_value.is_empty() {
1577 Ok(idx_value.to_vec())
1578 } else {
1579 let total_cols = num_index_cols + num_pk_cols;
1580 let all_values = decode_composite_key(idx_key, total_cols)?;
1581 let pk_values = &all_values[num_index_cols..];
1582 Ok(encode_composite_key(pk_values))
1583 }
1584}
1585
1586fn check_range_conditions(
1587 idx_key: &[u8],
1588 num_prefix_cols: usize,
1589 range_conds: &[(BinOp, Value)],
1590 num_index_cols: usize,
1591) -> Result<RangeCheck> {
1592 if range_conds.is_empty() {
1593 return Ok(RangeCheck::Match);
1594 }
1595
1596 let num_to_decode = num_prefix_cols + 1;
1597 if num_to_decode > num_index_cols {
1598 return Ok(RangeCheck::Match);
1599 }
1600
1601 let mut pos = 0;
1603 for _ in 0..num_prefix_cols {
1604 let (_, n) = decode_key_value(&idx_key[pos..])?;
1605 pos += n;
1606 }
1607 let (range_val, _) = decode_key_value(&idx_key[pos..])?;
1608
1609 let mut exceeds_upper = false;
1610 let mut below_lower = false;
1611
1612 for (op, val) in range_conds {
1613 match op {
1614 BinOp::Lt => {
1615 if range_val >= *val {
1616 exceeds_upper = true;
1617 }
1618 }
1619 BinOp::LtEq => {
1620 if range_val > *val {
1621 exceeds_upper = true;
1622 }
1623 }
1624 BinOp::Gt => {
1625 if range_val <= *val {
1626 below_lower = true;
1627 }
1628 }
1629 BinOp::GtEq => {
1630 if range_val < *val {
1631 below_lower = true;
1632 }
1633 }
1634 _ => {}
1635 }
1636 }
1637
1638 if exceeds_upper {
1639 Ok(RangeCheck::ExceedsUpper)
1640 } else if below_lower {
1641 Ok(RangeCheck::BelowLower)
1642 } else {
1643 Ok(RangeCheck::Match)
1644 }
1645}
1646
/// Outcome of testing one index key against an index scan's range bounds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RangeCheck {
    /// The key satisfies every bound; its row should be fetched.
    Match,
    /// The key lies below a lower bound; callers skip it and keep scanning.
    BelowLower,
    /// The key lies past an upper bound; callers stop the scan.
    ExceedsUpper,
}
1652
1653fn collect_rows_read(
1655 db: &Database,
1656 table_schema: &TableSchema,
1657 where_clause: &Option<Expr>,
1658 limit: Option<usize>,
1659) -> Result<(Vec<Vec<Value>>, bool)> {
1660 let plan = planner::plan_select(table_schema, where_clause);
1661 let lower_name = &table_schema.name;
1662 let columns = &table_schema.columns;
1663
1664 match plan {
1665 ScanPlan::SeqScan => {
1666 let simple_pred = where_clause
1667 .as_ref()
1668 .and_then(|expr| try_simple_predicate(expr, table_schema));
1669
1670 if let Some(ref pred) = simple_pred {
1671 let mut rtx = db.begin_read();
1672 let entry_count =
1673 rtx.table_entry_count(lower_name.as_bytes()).unwrap_or(0) as usize;
1674 let mut rows = Vec::with_capacity(entry_count / 4);
1675 let mut scan_err: Option<SqlError> = None;
1676 rtx.table_scan_raw(lower_name.as_bytes(), |key, value| {
1677 match pred.matches_raw(key, value) {
1678 Ok(true) => match decode_full_row(table_schema, key, value) {
1679 Ok(row) => rows.push(row),
1680 Err(e) => {
1681 scan_err = Some(e);
1682 return false;
1683 }
1684 },
1685 Ok(false) => {}
1686 Err(e) => {
1687 scan_err = Some(e);
1688 return false;
1689 }
1690 }
1691 scan_err.is_none() && limit.map_or(true, |n| rows.len() < n)
1692 })
1693 .map_err(SqlError::Storage)?;
1694 if let Some(e) = scan_err {
1695 return Err(e);
1696 }
1697 return Ok((rows, true));
1698 }
1699
1700 let mut rtx = db.begin_read();
1701 let entry_count = rtx.table_entry_count(lower_name.as_bytes()).unwrap_or(0) as usize;
1702 let capacity = if where_clause.is_some() {
1703 entry_count / 4
1704 } else {
1705 entry_count
1706 };
1707 let mut rows = Vec::with_capacity(capacity);
1708 let mut scan_err: Option<SqlError> = None;
1709
1710 let col_map = ColumnMap::new(columns);
1711 let partial_ctx = where_clause.as_ref().and_then(|expr| {
1712 let needed = referenced_columns(expr, columns);
1713 if needed.len() < columns.len() {
1714 Some(PartialDecodeCtx::new(table_schema, &needed))
1715 } else {
1716 None
1717 }
1718 });
1719
1720 rtx.table_scan_from(lower_name.as_bytes(), b"", |key, value| {
1721 match (&where_clause, &partial_ctx) {
1722 (Some(expr), Some(ctx)) => match ctx.decode(key, value) {
1723 Ok(partial) => match eval_expr(expr, &col_map, &partial) {
1724 Ok(val) if is_truthy(&val) => match ctx.complete(partial, key, value) {
1725 Ok(row) => rows.push(row),
1726 Err(e) => scan_err = Some(e),
1727 },
1728 Err(e) => scan_err = Some(e),
1729 _ => {}
1730 },
1731 Err(e) => scan_err = Some(e),
1732 },
1733 (Some(expr), None) => match decode_full_row(table_schema, key, value) {
1734 Ok(row) => match eval_expr(expr, &col_map, &row) {
1735 Ok(val) if is_truthy(&val) => rows.push(row),
1736 Err(e) => scan_err = Some(e),
1737 _ => {}
1738 },
1739 Err(e) => scan_err = Some(e),
1740 },
1741 _ => match decode_full_row(table_schema, key, value) {
1742 Ok(row) => rows.push(row),
1743 Err(e) => scan_err = Some(e),
1744 },
1745 }
1746 let keep_going = scan_err.is_none() && limit.map_or(true, |n| rows.len() < n);
1747 Ok(keep_going)
1748 })
1749 .map_err(SqlError::Storage)?;
1750 if let Some(e) = scan_err {
1751 return Err(e);
1752 }
1753 Ok((rows, where_clause.is_some()))
1754 }
1755
1756 ScanPlan::PkLookup { pk_values } => {
1757 let key = encode_composite_key(&pk_values);
1758 let mut rtx = db.begin_read();
1759 match rtx
1760 .table_get(lower_name.as_bytes(), &key)
1761 .map_err(SqlError::Storage)?
1762 {
1763 Some(value) => {
1764 let row = decode_full_row(table_schema, &key, &value)?;
1765 if let Some(ref expr) = where_clause {
1766 let col_map = ColumnMap::new(columns);
1767 match eval_expr(expr, &col_map, &row) {
1768 Ok(val) if is_truthy(&val) => Ok((vec![row], true)),
1769 _ => Ok((vec![], true)),
1770 }
1771 } else {
1772 Ok((vec![row], false))
1773 }
1774 }
1775 None => Ok((vec![], true)),
1776 }
1777 }
1778
1779 ScanPlan::IndexScan {
1780 idx_table,
1781 prefix,
1782 num_prefix_cols,
1783 range_conds,
1784 is_unique,
1785 index_columns,
1786 ..
1787 } => {
1788 let num_pk_cols = table_schema.primary_key_columns.len();
1789 let num_index_cols = index_columns.len();
1790 let mut pk_keys: Vec<Vec<u8>> = Vec::new();
1791
1792 {
1793 let mut rtx = db.begin_read();
1794 let mut scan_err: Option<SqlError> = None;
1795 rtx.table_scan_from(&idx_table, &prefix, |key, value| {
1796 if !key.starts_with(&prefix) {
1797 return Ok(false);
1798 }
1799 match check_range_conditions(key, num_prefix_cols, &range_conds, num_index_cols)
1800 {
1801 Ok(RangeCheck::ExceedsUpper) => return Ok(false),
1802 Ok(RangeCheck::BelowLower) => return Ok(true),
1803 Ok(RangeCheck::Match) => {}
1804 Err(e) => {
1805 scan_err = Some(e);
1806 return Ok(false);
1807 }
1808 }
1809 match extract_pk_key(key, value, is_unique, num_index_cols, num_pk_cols) {
1810 Ok(pk) => pk_keys.push(pk),
1811 Err(e) => {
1812 scan_err = Some(e);
1813 return Ok(false);
1814 }
1815 }
1816 Ok(true)
1817 })
1818 .map_err(SqlError::Storage)?;
1819 if let Some(e) = scan_err {
1820 return Err(e);
1821 }
1822 }
1823
1824 let mut rows = Vec::new();
1825 let mut rtx = db.begin_read();
1826 let col_map = ColumnMap::new(columns);
1827 for pk_key in &pk_keys {
1828 if let Some(value) = rtx
1829 .table_get(lower_name.as_bytes(), pk_key)
1830 .map_err(SqlError::Storage)?
1831 {
1832 let row = decode_full_row(table_schema, pk_key, &value)?;
1833 if let Some(ref expr) = where_clause {
1834 match eval_expr(expr, &col_map, &row) {
1835 Ok(val) if is_truthy(&val) => rows.push(row),
1836 _ => {}
1837 }
1838 } else {
1839 rows.push(row);
1840 }
1841 }
1842 }
1843 Ok((rows, where_clause.is_some()))
1844 }
1845 }
1846}
1847
1848fn collect_rows_write(
1850 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1851 table_schema: &TableSchema,
1852 where_clause: &Option<Expr>,
1853 limit: Option<usize>,
1854) -> Result<(Vec<Vec<Value>>, bool)> {
1855 let plan = planner::plan_select(table_schema, where_clause);
1856 let lower_name = &table_schema.name;
1857 let columns = &table_schema.columns;
1858
1859 match plan {
1860 ScanPlan::SeqScan => {
1861 let simple_pred = where_clause
1862 .as_ref()
1863 .and_then(|expr| try_simple_predicate(expr, table_schema));
1864
1865 if let Some(ref pred) = simple_pred {
1866 let mut rows = Vec::new();
1867 let mut scan_err: Option<SqlError> = None;
1868 wtx.table_scan_from(lower_name.as_bytes(), b"", |key, value| {
1869 match pred.matches_raw(key, value) {
1870 Ok(true) => match decode_full_row(table_schema, key, value) {
1871 Ok(row) => rows.push(row),
1872 Err(e) => scan_err = Some(e),
1873 },
1874 Ok(false) => {}
1875 Err(e) => scan_err = Some(e),
1876 }
1877 let keep_going = scan_err.is_none() && limit.map_or(true, |n| rows.len() < n);
1878 Ok(keep_going)
1879 })
1880 .map_err(SqlError::Storage)?;
1881 if let Some(e) = scan_err {
1882 return Err(e);
1883 }
1884 return Ok((rows, true));
1885 }
1886
1887 let mut rows = Vec::new();
1888 let mut scan_err: Option<SqlError> = None;
1889
1890 let col_map = ColumnMap::new(columns);
1891 let partial_ctx = where_clause.as_ref().and_then(|expr| {
1892 let needed = referenced_columns(expr, columns);
1893 if needed.len() < columns.len() {
1894 Some(PartialDecodeCtx::new(table_schema, &needed))
1895 } else {
1896 None
1897 }
1898 });
1899
1900 wtx.table_scan_from(lower_name.as_bytes(), b"", |key, value| {
1901 match (&where_clause, &partial_ctx) {
1902 (Some(expr), Some(ctx)) => match ctx.decode(key, value) {
1903 Ok(partial) => match eval_expr(expr, &col_map, &partial) {
1904 Ok(val) if is_truthy(&val) => match ctx.complete(partial, key, value) {
1905 Ok(row) => rows.push(row),
1906 Err(e) => scan_err = Some(e),
1907 },
1908 Err(e) => scan_err = Some(e),
1909 _ => {}
1910 },
1911 Err(e) => scan_err = Some(e),
1912 },
1913 (Some(expr), None) => match decode_full_row(table_schema, key, value) {
1914 Ok(row) => match eval_expr(expr, &col_map, &row) {
1915 Ok(val) if is_truthy(&val) => rows.push(row),
1916 Err(e) => scan_err = Some(e),
1917 _ => {}
1918 },
1919 Err(e) => scan_err = Some(e),
1920 },
1921 _ => match decode_full_row(table_schema, key, value) {
1922 Ok(row) => rows.push(row),
1923 Err(e) => scan_err = Some(e),
1924 },
1925 }
1926 let keep_going = scan_err.is_none() && limit.map_or(true, |n| rows.len() < n);
1927 Ok(keep_going)
1928 })
1929 .map_err(SqlError::Storage)?;
1930 if let Some(e) = scan_err {
1931 return Err(e);
1932 }
1933 Ok((rows, where_clause.is_some()))
1934 }
1935
1936 ScanPlan::PkLookup { pk_values } => {
1937 let key = encode_composite_key(&pk_values);
1938 match wtx
1939 .table_get(lower_name.as_bytes(), &key)
1940 .map_err(SqlError::Storage)?
1941 {
1942 Some(value) => {
1943 let row = decode_full_row(table_schema, &key, &value)?;
1944 if let Some(ref expr) = where_clause {
1945 let col_map = ColumnMap::new(columns);
1946 match eval_expr(expr, &col_map, &row) {
1947 Ok(val) if is_truthy(&val) => Ok((vec![row], true)),
1948 _ => Ok((vec![], true)),
1949 }
1950 } else {
1951 Ok((vec![row], false))
1952 }
1953 }
1954 None => Ok((vec![], true)),
1955 }
1956 }
1957
1958 ScanPlan::IndexScan {
1959 idx_table,
1960 prefix,
1961 num_prefix_cols,
1962 range_conds,
1963 is_unique,
1964 index_columns,
1965 ..
1966 } => {
1967 let num_pk_cols = table_schema.primary_key_columns.len();
1968 let num_index_cols = index_columns.len();
1969 let mut pk_keys: Vec<Vec<u8>> = Vec::new();
1970
1971 {
1972 let mut scan_err: Option<SqlError> = None;
1973 wtx.table_scan_from(&idx_table, &prefix, |key, value| {
1974 if !key.starts_with(&prefix) {
1975 return Ok(false);
1976 }
1977 match check_range_conditions(key, num_prefix_cols, &range_conds, num_index_cols)
1978 {
1979 Ok(RangeCheck::ExceedsUpper) => return Ok(false),
1980 Ok(RangeCheck::BelowLower) => return Ok(true),
1981 Ok(RangeCheck::Match) => {}
1982 Err(e) => {
1983 scan_err = Some(e);
1984 return Ok(false);
1985 }
1986 }
1987 match extract_pk_key(key, value, is_unique, num_index_cols, num_pk_cols) {
1988 Ok(pk) => pk_keys.push(pk),
1989 Err(e) => {
1990 scan_err = Some(e);
1991 return Ok(false);
1992 }
1993 }
1994 Ok(true)
1995 })
1996 .map_err(SqlError::Storage)?;
1997 if let Some(e) = scan_err {
1998 return Err(e);
1999 }
2000 }
2001
2002 let mut rows = Vec::new();
2003 let col_map = ColumnMap::new(columns);
2004 for pk_key in &pk_keys {
2005 if let Some(value) = wtx
2006 .table_get(lower_name.as_bytes(), pk_key)
2007 .map_err(SqlError::Storage)?
2008 {
2009 let row = decode_full_row(table_schema, pk_key, &value)?;
2010 if let Some(ref expr) = where_clause {
2011 match eval_expr(expr, &col_map, &row) {
2012 Ok(val) if is_truthy(&val) => rows.push(row),
2013 _ => {}
2014 }
2015 } else {
2016 rows.push(row);
2017 }
2018 }
2019 }
2020 Ok((rows, where_clause.is_some()))
2021 }
2022 }
2023}
2024
2025fn collect_keyed_rows_read(
2028 db: &Database,
2029 table_schema: &TableSchema,
2030 where_clause: &Option<Expr>,
2031) -> Result<Vec<(Vec<u8>, Vec<Value>)>> {
2032 let plan = planner::plan_select(table_schema, where_clause);
2033 let lower_name = &table_schema.name;
2034
2035 match plan {
2036 ScanPlan::SeqScan => {
2037 let mut rows = Vec::new();
2038 let mut rtx = db.begin_read();
2039 let mut scan_err: Option<SqlError> = None;
2040 rtx.table_for_each(lower_name.as_bytes(), |key, value| {
2041 match decode_full_row(table_schema, key, value) {
2042 Ok(row) => rows.push((key.to_vec(), row)),
2043 Err(e) => scan_err = Some(e),
2044 }
2045 Ok(())
2046 })
2047 .map_err(SqlError::Storage)?;
2048 if let Some(e) = scan_err {
2049 return Err(e);
2050 }
2051 Ok(rows)
2052 }
2053
2054 ScanPlan::PkLookup { pk_values } => {
2055 let key = encode_composite_key(&pk_values);
2056 let mut rtx = db.begin_read();
2057 match rtx
2058 .table_get(lower_name.as_bytes(), &key)
2059 .map_err(SqlError::Storage)?
2060 {
2061 Some(value) => {
2062 let row = decode_full_row(table_schema, &key, &value)?;
2063 Ok(vec![(key, row)])
2064 }
2065 None => Ok(vec![]),
2066 }
2067 }
2068
2069 ScanPlan::IndexScan {
2070 idx_table,
2071 prefix,
2072 num_prefix_cols,
2073 range_conds,
2074 is_unique,
2075 index_columns,
2076 ..
2077 } => {
2078 let num_pk_cols = table_schema.primary_key_columns.len();
2079 let num_index_cols = index_columns.len();
2080 let mut pk_keys: Vec<Vec<u8>> = Vec::new();
2081
2082 {
2083 let mut rtx = db.begin_read();
2084 let mut scan_err: Option<SqlError> = None;
2085 rtx.table_scan_from(&idx_table, &prefix, |key, value| {
2086 if !key.starts_with(&prefix) {
2087 return Ok(false);
2088 }
2089 match check_range_conditions(key, num_prefix_cols, &range_conds, num_index_cols)
2090 {
2091 Ok(RangeCheck::ExceedsUpper) => return Ok(false),
2092 Ok(RangeCheck::BelowLower) => return Ok(true),
2093 Ok(RangeCheck::Match) => {}
2094 Err(e) => {
2095 scan_err = Some(e);
2096 return Ok(false);
2097 }
2098 }
2099 match extract_pk_key(key, value, is_unique, num_index_cols, num_pk_cols) {
2100 Ok(pk) => pk_keys.push(pk),
2101 Err(e) => {
2102 scan_err = Some(e);
2103 return Ok(false);
2104 }
2105 }
2106 Ok(true)
2107 })
2108 .map_err(SqlError::Storage)?;
2109 if let Some(e) = scan_err {
2110 return Err(e);
2111 }
2112 }
2113
2114 let mut rows = Vec::new();
2115 let mut rtx = db.begin_read();
2116 for pk_key in &pk_keys {
2117 if let Some(value) = rtx
2118 .table_get(lower_name.as_bytes(), pk_key)
2119 .map_err(SqlError::Storage)?
2120 {
2121 rows.push((
2122 pk_key.clone(),
2123 decode_full_row(table_schema, pk_key, &value)?,
2124 ));
2125 }
2126 }
2127 Ok(rows)
2128 }
2129 }
2130}
2131
2132fn collect_keyed_rows_write(
2134 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
2135 table_schema: &TableSchema,
2136 where_clause: &Option<Expr>,
2137) -> Result<Vec<(Vec<u8>, Vec<Value>)>> {
2138 let plan = planner::plan_select(table_schema, where_clause);
2139 let lower_name = &table_schema.name;
2140
2141 match plan {
2142 ScanPlan::SeqScan => {
2143 let mut rows = Vec::new();
2144 let mut scan_err: Option<SqlError> = None;
2145 wtx.table_for_each(lower_name.as_bytes(), |key, value| {
2146 match decode_full_row(table_schema, key, value) {
2147 Ok(row) => rows.push((key.to_vec(), row)),
2148 Err(e) => scan_err = Some(e),
2149 }
2150 Ok(())
2151 })
2152 .map_err(SqlError::Storage)?;
2153 if let Some(e) = scan_err {
2154 return Err(e);
2155 }
2156 Ok(rows)
2157 }
2158
2159 ScanPlan::PkLookup { pk_values } => {
2160 let key = encode_composite_key(&pk_values);
2161 match wtx
2162 .table_get(lower_name.as_bytes(), &key)
2163 .map_err(SqlError::Storage)?
2164 {
2165 Some(value) => {
2166 let row = decode_full_row(table_schema, &key, &value)?;
2167 Ok(vec![(key, row)])
2168 }
2169 None => Ok(vec![]),
2170 }
2171 }
2172
2173 ScanPlan::IndexScan {
2174 idx_table,
2175 prefix,
2176 num_prefix_cols,
2177 range_conds,
2178 is_unique,
2179 index_columns,
2180 ..
2181 } => {
2182 let num_pk_cols = table_schema.primary_key_columns.len();
2183 let num_index_cols = index_columns.len();
2184 let mut pk_keys: Vec<Vec<u8>> = Vec::new();
2185
2186 {
2187 let mut scan_err: Option<SqlError> = None;
2188 wtx.table_scan_from(&idx_table, &prefix, |key, value| {
2189 if !key.starts_with(&prefix) {
2190 return Ok(false);
2191 }
2192 match check_range_conditions(key, num_prefix_cols, &range_conds, num_index_cols)
2193 {
2194 Ok(RangeCheck::ExceedsUpper) => return Ok(false),
2195 Ok(RangeCheck::BelowLower) => return Ok(true),
2196 Ok(RangeCheck::Match) => {}
2197 Err(e) => {
2198 scan_err = Some(e);
2199 return Ok(false);
2200 }
2201 }
2202 match extract_pk_key(key, value, is_unique, num_index_cols, num_pk_cols) {
2203 Ok(pk) => pk_keys.push(pk),
2204 Err(e) => {
2205 scan_err = Some(e);
2206 return Ok(false);
2207 }
2208 }
2209 Ok(true)
2210 })
2211 .map_err(SqlError::Storage)?;
2212 if let Some(e) = scan_err {
2213 return Err(e);
2214 }
2215 }
2216
2217 let mut rows = Vec::new();
2218 for pk_key in &pk_keys {
2219 if let Some(value) = wtx
2220 .table_get(lower_name.as_bytes(), pk_key)
2221 .map_err(SqlError::Storage)?
2222 {
2223 rows.push((
2224 pk_key.clone(),
2225 decode_full_row(table_schema, pk_key, &value)?,
2226 ));
2227 }
2228 }
2229 Ok(rows)
2230 }
2231 }
2232}
2233
2234fn exec_insert(
2237 db: &Database,
2238 schema: &SchemaManager,
2239 stmt: &InsertStmt,
2240 params: &[Value],
2241) -> Result<ExecutionResult> {
2242 let empty_ctes = CteContext::new();
2243 let materialized;
2244 let stmt = if insert_has_subquery(stmt) {
2245 materialized = materialize_insert(stmt, &mut |sub| {
2246 exec_subquery_read(db, schema, sub, &empty_ctes)
2247 })?;
2248 &materialized
2249 } else {
2250 stmt
2251 };
2252
2253 let lower_name = stmt.table.to_ascii_lowercase();
2254 let table_schema = schema
2255 .get(&lower_name)
2256 .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
2257
2258 let insert_columns = if stmt.columns.is_empty() {
2259 table_schema
2260 .columns
2261 .iter()
2262 .map(|c| c.name.clone())
2263 .collect::<Vec<_>>()
2264 } else {
2265 stmt.columns
2266 .iter()
2267 .map(|c| c.to_ascii_lowercase())
2268 .collect()
2269 };
2270
2271 let col_indices: Vec<usize> = insert_columns
2272 .iter()
2273 .map(|name| {
2274 table_schema
2275 .column_index(name)
2276 .ok_or_else(|| SqlError::ColumnNotFound(name.clone()))
2277 })
2278 .collect::<Result<_>>()?;
2279
2280 let defaults: Vec<(usize, &Expr)> = table_schema
2282 .columns
2283 .iter()
2284 .filter(|c| c.default_expr.is_some() && !col_indices.contains(&(c.position as usize)))
2285 .map(|c| (c.position as usize, c.default_expr.as_ref().unwrap()))
2286 .collect();
2287
2288 let has_checks = table_schema.has_checks();
2290 let check_col_map = if has_checks {
2291 Some(ColumnMap::new(&table_schema.columns))
2292 } else {
2293 None
2294 };
2295
2296 let select_rows = match &stmt.source {
2297 InsertSource::Select(sq) => {
2298 let insert_ctes = materialize_all_ctes(&sq.ctes, sq.recursive, &mut |body, ctx| {
2299 exec_query_body_read(db, schema, body, ctx)
2300 })?;
2301 let qr = exec_query_body_read(db, schema, &sq.body, &insert_ctes)?;
2302 Some(qr.rows)
2303 }
2304 InsertSource::Values(_) => None,
2305 };
2306
2307 let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
2308 let mut count: u64 = 0;
2309
2310 let pk_indices = table_schema.pk_indices();
2311 let non_pk = table_schema.non_pk_indices();
2312 let enc_pos = table_schema.encoding_positions();
2313 let phys_count = table_schema.physical_non_pk_count();
2314 let mut row = vec![Value::Null; table_schema.columns.len()];
2315 let mut pk_values: Vec<Value> = vec![Value::Null; pk_indices.len()];
2316 let mut value_values: Vec<Value> = vec![Value::Null; phys_count];
2317 let mut key_buf: Vec<u8> = Vec::with_capacity(64);
2318 let mut value_buf: Vec<u8> = Vec::with_capacity(256);
2319 let mut fk_key_buf: Vec<u8> = Vec::with_capacity(64);
2320
2321 let values = match &stmt.source {
2322 InsertSource::Values(rows) => Some(rows.as_slice()),
2323 InsertSource::Select(_) => None,
2324 };
2325 let sel_rows = select_rows.as_deref();
2326
2327 let total = match (values, sel_rows) {
2328 (Some(rows), _) => rows.len(),
2329 (_, Some(rows)) => rows.len(),
2330 _ => 0,
2331 };
2332
2333 if let Some(sel) = sel_rows {
2334 if !sel.is_empty() && sel[0].len() != insert_columns.len() {
2335 return Err(SqlError::InvalidValue(format!(
2336 "INSERT ... SELECT column count mismatch: expected {}, got {}",
2337 insert_columns.len(),
2338 sel[0].len()
2339 )));
2340 }
2341 }
2342
2343 for idx in 0..total {
2344 for v in row.iter_mut() {
2345 *v = Value::Null;
2346 }
2347
2348 if let Some(value_rows) = values {
2349 let value_row = &value_rows[idx];
2350 if value_row.len() != insert_columns.len() {
2351 return Err(SqlError::InvalidValue(format!(
2352 "expected {} values, got {}",
2353 insert_columns.len(),
2354 value_row.len()
2355 )));
2356 }
2357 for (i, expr) in value_row.iter().enumerate() {
2358 let val = if let Expr::Parameter(n) = expr {
2359 params
2360 .get(n - 1)
2361 .cloned()
2362 .ok_or_else(|| SqlError::Parse(format!("unbound parameter ${n}")))?
2363 } else {
2364 eval_const_expr(expr)?
2365 };
2366 let col_idx = col_indices[i];
2367 let col = &table_schema.columns[col_idx];
2368 let got_type = val.data_type();
2369 row[col_idx] = if val.is_null() {
2370 Value::Null
2371 } else {
2372 val.coerce_into(col.data_type)
2373 .ok_or_else(|| SqlError::TypeMismatch {
2374 expected: col.data_type.to_string(),
2375 got: got_type.to_string(),
2376 })?
2377 };
2378 }
2379 } else if let Some(sel) = sel_rows {
2380 let sel_row = &sel[idx];
2381 for (i, val) in sel_row.iter().enumerate() {
2382 let col_idx = col_indices[i];
2383 let col = &table_schema.columns[col_idx];
2384 let got_type = val.data_type();
2385 row[col_idx] = if val.is_null() {
2386 Value::Null
2387 } else {
2388 val.clone().coerce_into(col.data_type).ok_or_else(|| {
2389 SqlError::TypeMismatch {
2390 expected: col.data_type.to_string(),
2391 got: got_type.to_string(),
2392 }
2393 })?
2394 };
2395 }
2396 }
2397
2398 for &(pos, def_expr) in &defaults {
2400 let val = eval_const_expr(def_expr)?;
2401 let col = &table_schema.columns[pos];
2402 if val.is_null() {
2403 } else {
2405 let got_type = val.data_type();
2406 row[pos] =
2407 val.coerce_into(col.data_type)
2408 .ok_or_else(|| SqlError::TypeMismatch {
2409 expected: col.data_type.to_string(),
2410 got: got_type.to_string(),
2411 })?;
2412 }
2413 }
2414
2415 for col in &table_schema.columns {
2416 if !col.nullable && row[col.position as usize].is_null() {
2417 return Err(SqlError::NotNullViolation(col.name.clone()));
2418 }
2419 }
2420
2421 if let Some(ref col_map) = check_col_map {
2423 for col in &table_schema.columns {
2424 if let Some(ref check) = col.check_expr {
2425 let result = eval_expr(check, col_map, &row)?;
2426 if !is_truthy(&result) && !result.is_null() {
2427 let name = col.check_name.as_deref().unwrap_or(&col.name);
2428 return Err(SqlError::CheckViolation(name.to_string()));
2429 }
2430 }
2431 }
2432 for tc in &table_schema.check_constraints {
2433 let result = eval_expr(&tc.expr, col_map, &row)?;
2434 if !is_truthy(&result) && !result.is_null() {
2435 let name = tc.name.as_deref().unwrap_or(&tc.sql);
2436 return Err(SqlError::CheckViolation(name.to_string()));
2437 }
2438 }
2439 }
2440
2441 for fk in &table_schema.foreign_keys {
2443 let any_null = fk.columns.iter().any(|&ci| row[ci as usize].is_null());
2444 if any_null {
2445 continue; }
2447 let fk_vals: Vec<Value> = fk
2448 .columns
2449 .iter()
2450 .map(|&ci| row[ci as usize].clone())
2451 .collect();
2452 fk_key_buf.clear();
2453 encode_composite_key_into(&fk_vals, &mut fk_key_buf);
2454 let found = wtx
2455 .table_get(fk.foreign_table.as_bytes(), &fk_key_buf)
2456 .map_err(SqlError::Storage)?;
2457 if found.is_none() {
2458 let name = fk.name.as_deref().unwrap_or(&fk.foreign_table);
2459 return Err(SqlError::ForeignKeyViolation(name.to_string()));
2460 }
2461 }
2462
2463 for (j, &i) in pk_indices.iter().enumerate() {
2464 pk_values[j] = std::mem::replace(&mut row[i], Value::Null);
2465 }
2466 encode_composite_key_into(&pk_values, &mut key_buf);
2467
2468 for (j, &i) in non_pk.iter().enumerate() {
2469 value_values[enc_pos[j] as usize] = std::mem::replace(&mut row[i], Value::Null);
2470 }
2471 encode_row_into(&value_values, &mut value_buf);
2472
2473 if key_buf.len() > citadel_core::MAX_KEY_SIZE {
2474 return Err(SqlError::KeyTooLarge {
2475 size: key_buf.len(),
2476 max: citadel_core::MAX_KEY_SIZE,
2477 });
2478 }
2479 if value_buf.len() > citadel_core::MAX_INLINE_VALUE_SIZE {
2480 return Err(SqlError::RowTooLarge {
2481 size: value_buf.len(),
2482 max: citadel_core::MAX_INLINE_VALUE_SIZE,
2483 });
2484 }
2485
2486 let is_new = wtx
2487 .table_insert(stmt.table.as_bytes(), &key_buf, &value_buf)
2488 .map_err(SqlError::Storage)?;
2489 if !is_new {
2490 return Err(SqlError::DuplicateKey);
2491 }
2492
2493 if !table_schema.indices.is_empty() {
2494 for (j, &i) in pk_indices.iter().enumerate() {
2495 row[i] = pk_values[j].clone();
2496 }
2497 for (j, &i) in non_pk.iter().enumerate() {
2498 row[i] = std::mem::replace(&mut value_values[enc_pos[j] as usize], Value::Null);
2499 }
2500 insert_index_entries(&mut wtx, table_schema, &row, &pk_values)?;
2501 }
2502 count += 1;
2503 }
2504
2505 wtx.commit().map_err(SqlError::Storage)?;
2506 Ok(ExecutionResult::RowsAffected(count))
2507}
2508
2509fn has_subquery(expr: &Expr) -> bool {
2510 crate::parser::has_subquery(expr)
2511}
2512
2513fn stmt_has_subquery(stmt: &SelectStmt) -> bool {
2514 if let Some(ref w) = stmt.where_clause {
2515 if has_subquery(w) {
2516 return true;
2517 }
2518 }
2519 if let Some(ref h) = stmt.having {
2520 if has_subquery(h) {
2521 return true;
2522 }
2523 }
2524 for col in &stmt.columns {
2525 if let SelectColumn::Expr { expr, .. } = col {
2526 if has_subquery(expr) {
2527 return true;
2528 }
2529 }
2530 }
2531 for ob in &stmt.order_by {
2532 if has_subquery(&ob.expr) {
2533 return true;
2534 }
2535 }
2536 for join in &stmt.joins {
2537 if let Some(ref on_expr) = join.on_clause {
2538 if has_subquery(on_expr) {
2539 return true;
2540 }
2541 }
2542 }
2543 false
2544}
2545
2546fn materialize_expr(
2547 expr: &Expr,
2548 exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
2549) -> Result<Expr> {
2550 match expr {
2551 Expr::InSubquery {
2552 expr: e,
2553 subquery,
2554 negated,
2555 } => {
2556 let inner = materialize_expr(e, exec_sub)?;
2557 let qr = exec_sub(subquery)?;
2558 if !qr.columns.is_empty() && qr.columns.len() != 1 {
2559 return Err(SqlError::SubqueryMultipleColumns);
2560 }
2561 let mut values = std::collections::HashSet::new();
2562 let mut has_null = false;
2563 for row in &qr.rows {
2564 if row[0].is_null() {
2565 has_null = true;
2566 } else {
2567 values.insert(row[0].clone());
2568 }
2569 }
2570 Ok(Expr::InSet {
2571 expr: Box::new(inner),
2572 values,
2573 has_null,
2574 negated: *negated,
2575 })
2576 }
2577 Expr::ScalarSubquery(subquery) => {
2578 let qr = exec_sub(subquery)?;
2579 if qr.rows.len() > 1 {
2580 return Err(SqlError::SubqueryMultipleRows);
2581 }
2582 let val = if qr.rows.is_empty() {
2583 Value::Null
2584 } else {
2585 qr.rows[0][0].clone()
2586 };
2587 Ok(Expr::Literal(val))
2588 }
2589 Expr::Exists { subquery, negated } => {
2590 let qr = exec_sub(subquery)?;
2591 let exists = !qr.rows.is_empty();
2592 let result = if *negated { !exists } else { exists };
2593 Ok(Expr::Literal(Value::Boolean(result)))
2594 }
2595 Expr::InList {
2596 expr: e,
2597 list,
2598 negated,
2599 } => {
2600 let inner = materialize_expr(e, exec_sub)?;
2601 let items = list
2602 .iter()
2603 .map(|item| materialize_expr(item, exec_sub))
2604 .collect::<Result<Vec<_>>>()?;
2605 Ok(Expr::InList {
2606 expr: Box::new(inner),
2607 list: items,
2608 negated: *negated,
2609 })
2610 }
2611 Expr::BinaryOp { left, op, right } => Ok(Expr::BinaryOp {
2612 left: Box::new(materialize_expr(left, exec_sub)?),
2613 op: *op,
2614 right: Box::new(materialize_expr(right, exec_sub)?),
2615 }),
2616 Expr::UnaryOp { op, expr: e } => Ok(Expr::UnaryOp {
2617 op: *op,
2618 expr: Box::new(materialize_expr(e, exec_sub)?),
2619 }),
2620 Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(materialize_expr(e, exec_sub)?))),
2621 Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(materialize_expr(e, exec_sub)?))),
2622 Expr::InSet {
2623 expr: e,
2624 values,
2625 has_null,
2626 negated,
2627 } => Ok(Expr::InSet {
2628 expr: Box::new(materialize_expr(e, exec_sub)?),
2629 values: values.clone(),
2630 has_null: *has_null,
2631 negated: *negated,
2632 }),
2633 Expr::Between {
2634 expr: e,
2635 low,
2636 high,
2637 negated,
2638 } => Ok(Expr::Between {
2639 expr: Box::new(materialize_expr(e, exec_sub)?),
2640 low: Box::new(materialize_expr(low, exec_sub)?),
2641 high: Box::new(materialize_expr(high, exec_sub)?),
2642 negated: *negated,
2643 }),
2644 Expr::Like {
2645 expr: e,
2646 pattern,
2647 escape,
2648 negated,
2649 } => {
2650 let esc = escape
2651 .as_ref()
2652 .map(|es| materialize_expr(es, exec_sub).map(Box::new))
2653 .transpose()?;
2654 Ok(Expr::Like {
2655 expr: Box::new(materialize_expr(e, exec_sub)?),
2656 pattern: Box::new(materialize_expr(pattern, exec_sub)?),
2657 escape: esc,
2658 negated: *negated,
2659 })
2660 }
2661 Expr::Case {
2662 operand,
2663 conditions,
2664 else_result,
2665 } => {
2666 let op = operand
2667 .as_ref()
2668 .map(|e| materialize_expr(e, exec_sub).map(Box::new))
2669 .transpose()?;
2670 let conds = conditions
2671 .iter()
2672 .map(|(c, r)| {
2673 Ok((
2674 materialize_expr(c, exec_sub)?,
2675 materialize_expr(r, exec_sub)?,
2676 ))
2677 })
2678 .collect::<Result<Vec<_>>>()?;
2679 let else_r = else_result
2680 .as_ref()
2681 .map(|e| materialize_expr(e, exec_sub).map(Box::new))
2682 .transpose()?;
2683 Ok(Expr::Case {
2684 operand: op,
2685 conditions: conds,
2686 else_result: else_r,
2687 })
2688 }
2689 Expr::Coalesce(args) => {
2690 let materialized = args
2691 .iter()
2692 .map(|a| materialize_expr(a, exec_sub))
2693 .collect::<Result<Vec<_>>>()?;
2694 Ok(Expr::Coalesce(materialized))
2695 }
2696 Expr::Cast { expr: e, data_type } => Ok(Expr::Cast {
2697 expr: Box::new(materialize_expr(e, exec_sub)?),
2698 data_type: *data_type,
2699 }),
2700 Expr::Function { name, args } => {
2701 let materialized = args
2702 .iter()
2703 .map(|a| materialize_expr(a, exec_sub))
2704 .collect::<Result<Vec<_>>>()?;
2705 Ok(Expr::Function {
2706 name: name.clone(),
2707 args: materialized,
2708 })
2709 }
2710 other => Ok(other.clone()),
2711 }
2712}
2713
2714fn materialize_stmt(
2715 stmt: &SelectStmt,
2716 exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
2717) -> Result<SelectStmt> {
2718 let where_clause = stmt
2719 .where_clause
2720 .as_ref()
2721 .map(|e| materialize_expr(e, exec_sub))
2722 .transpose()?;
2723 let having = stmt
2724 .having
2725 .as_ref()
2726 .map(|e| materialize_expr(e, exec_sub))
2727 .transpose()?;
2728 let columns = stmt
2729 .columns
2730 .iter()
2731 .map(|c| match c {
2732 SelectColumn::AllColumns => Ok(SelectColumn::AllColumns),
2733 SelectColumn::Expr { expr, alias } => Ok(SelectColumn::Expr {
2734 expr: materialize_expr(expr, exec_sub)?,
2735 alias: alias.clone(),
2736 }),
2737 })
2738 .collect::<Result<Vec<_>>>()?;
2739 let order_by = stmt
2740 .order_by
2741 .iter()
2742 .map(|ob| {
2743 Ok(OrderByItem {
2744 expr: materialize_expr(&ob.expr, exec_sub)?,
2745 descending: ob.descending,
2746 nulls_first: ob.nulls_first,
2747 })
2748 })
2749 .collect::<Result<Vec<_>>>()?;
2750 let joins = stmt
2751 .joins
2752 .iter()
2753 .map(|j| {
2754 let on_clause = j
2755 .on_clause
2756 .as_ref()
2757 .map(|e| materialize_expr(e, exec_sub))
2758 .transpose()?;
2759 Ok(JoinClause {
2760 join_type: j.join_type,
2761 table: j.table.clone(),
2762 on_clause,
2763 })
2764 })
2765 .collect::<Result<Vec<_>>>()?;
2766 let group_by = stmt
2767 .group_by
2768 .iter()
2769 .map(|e| materialize_expr(e, exec_sub))
2770 .collect::<Result<Vec<_>>>()?;
2771 Ok(SelectStmt {
2772 columns,
2773 from: stmt.from.clone(),
2774 from_alias: stmt.from_alias.clone(),
2775 joins,
2776 distinct: stmt.distinct,
2777 where_clause,
2778 order_by,
2779 limit: stmt.limit.clone(),
2780 offset: stmt.offset.clone(),
2781 group_by,
2782 having,
2783 })
2784}
2785
/// Materialized common table expressions visible to a query, keyed by CTE name.
type CteContext = HashMap<String, QueryResult>;
/// Callback used by `resolve_table_or_cte` to load a non-CTE table by its
/// (lower-cased) name, returning the table's schema and rows.
type ScanTableFn<'a> = &'a mut dyn FnMut(&str) -> Result<(TableSchema, Vec<Vec<Value>>)>;
2788
2789fn exec_subquery_read(
2790 db: &Database,
2791 schema: &SchemaManager,
2792 stmt: &SelectStmt,
2793 ctes: &CteContext,
2794) -> Result<QueryResult> {
2795 match exec_select(db, schema, stmt, ctes)? {
2796 ExecutionResult::Query(qr) => Ok(qr),
2797 _ => Ok(QueryResult {
2798 columns: vec![],
2799 rows: vec![],
2800 }),
2801 }
2802}
2803
2804fn exec_subquery_write(
2805 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
2806 schema: &SchemaManager,
2807 stmt: &SelectStmt,
2808 ctes: &CteContext,
2809) -> Result<QueryResult> {
2810 match exec_select_in_txn(wtx, schema, stmt, ctes)? {
2811 ExecutionResult::Query(qr) => Ok(qr),
2812 _ => Ok(QueryResult {
2813 columns: vec![],
2814 rows: vec![],
2815 }),
2816 }
2817}
2818
2819fn update_has_subquery(stmt: &UpdateStmt) -> bool {
2820 stmt.where_clause.as_ref().is_some_and(has_subquery)
2821 || stmt.assignments.iter().any(|(_, e)| has_subquery(e))
2822}
2823
2824fn materialize_update(
2825 stmt: &UpdateStmt,
2826 exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
2827) -> Result<UpdateStmt> {
2828 let where_clause = stmt
2829 .where_clause
2830 .as_ref()
2831 .map(|e| materialize_expr(e, exec_sub))
2832 .transpose()?;
2833 let assignments = stmt
2834 .assignments
2835 .iter()
2836 .map(|(name, expr)| Ok((name.clone(), materialize_expr(expr, exec_sub)?)))
2837 .collect::<Result<Vec<_>>>()?;
2838 Ok(UpdateStmt {
2839 table: stmt.table.clone(),
2840 assignments,
2841 where_clause,
2842 })
2843}
2844
2845fn delete_has_subquery(stmt: &DeleteStmt) -> bool {
2846 stmt.where_clause.as_ref().is_some_and(has_subquery)
2847}
2848
2849fn materialize_delete(
2850 stmt: &DeleteStmt,
2851 exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
2852) -> Result<DeleteStmt> {
2853 let where_clause = stmt
2854 .where_clause
2855 .as_ref()
2856 .map(|e| materialize_expr(e, exec_sub))
2857 .transpose()?;
2858 Ok(DeleteStmt {
2859 table: stmt.table.clone(),
2860 where_clause,
2861 })
2862}
2863
2864fn insert_has_subquery(stmt: &InsertStmt) -> bool {
2865 match &stmt.source {
2866 InsertSource::Values(rows) => rows.iter().any(|row| row.iter().any(has_subquery)),
2867 InsertSource::Select(sq) => {
2868 sq.ctes.iter().any(|c| query_body_has_subquery(&c.body))
2869 || query_body_has_subquery(&sq.body)
2870 }
2871 }
2872}
2873
2874fn query_body_has_subquery(body: &QueryBody) -> bool {
2875 match body {
2876 QueryBody::Select(sel) => stmt_has_subquery(sel),
2877 QueryBody::Compound(comp) => {
2878 query_body_has_subquery(&comp.left) || query_body_has_subquery(&comp.right)
2879 }
2880 }
2881}
2882
2883fn materialize_insert(
2884 stmt: &InsertStmt,
2885 exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
2886) -> Result<InsertStmt> {
2887 let source = match &stmt.source {
2888 InsertSource::Values(rows) => {
2889 let mat = rows
2890 .iter()
2891 .map(|row| {
2892 row.iter()
2893 .map(|e| materialize_expr(e, exec_sub))
2894 .collect::<Result<Vec<_>>>()
2895 })
2896 .collect::<Result<Vec<_>>>()?;
2897 InsertSource::Values(mat)
2898 }
2899 InsertSource::Select(sq) => {
2900 let ctes = sq
2901 .ctes
2902 .iter()
2903 .map(|c| {
2904 Ok(CteDefinition {
2905 name: c.name.clone(),
2906 column_aliases: c.column_aliases.clone(),
2907 body: materialize_query_body(&c.body, exec_sub)?,
2908 })
2909 })
2910 .collect::<Result<Vec<_>>>()?;
2911 let body = materialize_query_body(&sq.body, exec_sub)?;
2912 InsertSource::Select(Box::new(SelectQuery {
2913 ctes,
2914 recursive: sq.recursive,
2915 body,
2916 }))
2917 }
2918 };
2919 Ok(InsertStmt {
2920 table: stmt.table.clone(),
2921 columns: stmt.columns.clone(),
2922 source,
2923 })
2924}
2925
2926fn materialize_query_body(
2927 body: &QueryBody,
2928 exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
2929) -> Result<QueryBody> {
2930 match body {
2931 QueryBody::Select(sel) => Ok(QueryBody::Select(Box::new(materialize_stmt(
2932 sel, exec_sub,
2933 )?))),
2934 QueryBody::Compound(comp) => Ok(QueryBody::Compound(Box::new(CompoundSelect {
2935 op: comp.op.clone(),
2936 all: comp.all,
2937 left: Box::new(materialize_query_body(&comp.left, exec_sub)?),
2938 right: Box::new(materialize_query_body(&comp.right, exec_sub)?),
2939 order_by: comp.order_by.clone(),
2940 limit: comp.limit.clone(),
2941 offset: comp.offset.clone(),
2942 }))),
2943 }
2944}
2945
2946fn exec_query_body(
2947 db: &Database,
2948 schema: &SchemaManager,
2949 body: &QueryBody,
2950 ctes: &CteContext,
2951) -> Result<ExecutionResult> {
2952 match body {
2953 QueryBody::Select(sel) => exec_select(db, schema, sel, ctes),
2954 QueryBody::Compound(comp) => exec_compound_select(db, schema, comp, ctes),
2955 }
2956}
2957
2958fn exec_query_body_in_txn(
2959 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
2960 schema: &SchemaManager,
2961 body: &QueryBody,
2962 ctes: &CteContext,
2963) -> Result<ExecutionResult> {
2964 match body {
2965 QueryBody::Select(sel) => exec_select_in_txn(wtx, schema, sel, ctes),
2966 QueryBody::Compound(comp) => exec_compound_select_in_txn(wtx, schema, comp, ctes),
2967 }
2968}
2969
2970fn exec_query_body_read(
2971 db: &Database,
2972 schema: &SchemaManager,
2973 body: &QueryBody,
2974 ctes: &CteContext,
2975) -> Result<QueryResult> {
2976 match exec_query_body(db, schema, body, ctes)? {
2977 ExecutionResult::Query(qr) => Ok(qr),
2978 _ => Ok(QueryResult {
2979 columns: vec![],
2980 rows: vec![],
2981 }),
2982 }
2983}
2984
2985fn exec_query_body_write(
2986 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
2987 schema: &SchemaManager,
2988 body: &QueryBody,
2989 ctes: &CteContext,
2990) -> Result<QueryResult> {
2991 match exec_query_body_in_txn(wtx, schema, body, ctes)? {
2992 ExecutionResult::Query(qr) => Ok(qr),
2993 _ => Ok(QueryResult {
2994 columns: vec![],
2995 rows: vec![],
2996 }),
2997 }
2998}
2999
3000fn exec_compound_select(
3001 db: &Database,
3002 schema: &SchemaManager,
3003 comp: &CompoundSelect,
3004 ctes: &CteContext,
3005) -> Result<ExecutionResult> {
3006 let left_qr = match exec_query_body(db, schema, &comp.left, ctes)? {
3007 ExecutionResult::Query(qr) => qr,
3008 _ => QueryResult {
3009 columns: vec![],
3010 rows: vec![],
3011 },
3012 };
3013 let right_qr = match exec_query_body(db, schema, &comp.right, ctes)? {
3014 ExecutionResult::Query(qr) => qr,
3015 _ => QueryResult {
3016 columns: vec![],
3017 rows: vec![],
3018 },
3019 };
3020 apply_set_operation(comp, left_qr, right_qr)
3021}
3022
3023fn exec_compound_select_in_txn(
3024 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
3025 schema: &SchemaManager,
3026 comp: &CompoundSelect,
3027 ctes: &CteContext,
3028) -> Result<ExecutionResult> {
3029 let left_qr = match exec_query_body_in_txn(wtx, schema, &comp.left, ctes)? {
3030 ExecutionResult::Query(qr) => qr,
3031 _ => QueryResult {
3032 columns: vec![],
3033 rows: vec![],
3034 },
3035 };
3036 let right_qr = match exec_query_body_in_txn(wtx, schema, &comp.right, ctes)? {
3037 ExecutionResult::Query(qr) => qr,
3038 _ => QueryResult {
3039 columns: vec![],
3040 rows: vec![],
3041 },
3042 };
3043 apply_set_operation(comp, left_qr, right_qr)
3044}
3045
3046fn exec_select_query(
3049 db: &Database,
3050 schema: &SchemaManager,
3051 sq: &SelectQuery,
3052) -> Result<ExecutionResult> {
3053 if let Some(fused) = try_fuse_cte(sq) {
3054 let empty = CteContext::new();
3055 return exec_query_body(db, schema, &fused, &empty);
3056 }
3057 let ctes = materialize_all_ctes(&sq.ctes, sq.recursive, &mut |body, ctx| {
3058 exec_query_body_read(db, schema, body, ctx)
3059 })?;
3060 exec_query_body(db, schema, &sq.body, &ctes)
3061}
3062
3063fn exec_select_query_in_txn(
3064 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
3065 schema: &SchemaManager,
3066 sq: &SelectQuery,
3067) -> Result<ExecutionResult> {
3068 if let Some(fused) = try_fuse_cte(sq) {
3069 let empty = CteContext::new();
3070 return exec_query_body_in_txn(wtx, schema, &fused, &empty);
3071 }
3072 let ctes = materialize_all_ctes(&sq.ctes, sq.recursive, &mut |body, ctx| {
3073 exec_query_body_write(wtx, schema, body, ctx)
3074 })?;
3075 exec_query_body_in_txn(wtx, schema, &sq.body, &ctes)
3076}
3077
3078fn try_fuse_cte(sq: &SelectQuery) -> Option<QueryBody> {
3080 if sq.ctes.len() != 1 || sq.recursive {
3081 return None;
3082 }
3083 let cte = &sq.ctes[0];
3084 if !cte.column_aliases.is_empty() {
3085 return None;
3086 }
3087
3088 let inner = match &cte.body {
3089 QueryBody::Select(s) => s.as_ref(),
3090 _ => return None,
3091 };
3092
3093 if !inner.joins.is_empty()
3094 || !inner.group_by.is_empty()
3095 || inner.distinct
3096 || inner.having.is_some()
3097 || inner.limit.is_some()
3098 || inner.offset.is_some()
3099 || !inner.order_by.is_empty()
3100 || stmt_has_subquery(inner)
3101 {
3102 return None;
3103 }
3104
3105 let all_simple_refs = inner.columns.iter().all(|c| match c {
3106 SelectColumn::AllColumns => true,
3107 SelectColumn::Expr { expr, alias } => alias.is_none() && matches!(expr, Expr::Column(_)),
3108 });
3109 if !all_simple_refs {
3110 return None;
3111 }
3112
3113 let outer = match &sq.body {
3114 QueryBody::Select(s) => s.as_ref(),
3115 _ => return None,
3116 };
3117 if !outer.from.eq_ignore_ascii_case(&cte.name) || !outer.joins.is_empty() {
3118 return None;
3119 }
3120
3121 let merged_where = match (&inner.where_clause, &outer.where_clause) {
3122 (Some(iw), Some(ow)) => Some(Expr::BinaryOp {
3123 left: Box::new(iw.clone()),
3124 op: BinOp::And,
3125 right: Box::new(ow.clone()),
3126 }),
3127 (Some(w), None) | (None, Some(w)) => Some(w.clone()),
3128 (None, None) => None,
3129 };
3130
3131 let fused = SelectStmt {
3132 columns: outer.columns.clone(),
3133 from: inner.from.clone(),
3134 from_alias: inner.from_alias.clone(),
3135 joins: vec![],
3136 distinct: outer.distinct,
3137 where_clause: merged_where,
3138 order_by: outer.order_by.clone(),
3139 limit: outer.limit.clone(),
3140 offset: outer.offset.clone(),
3141 group_by: outer.group_by.clone(),
3142 having: outer.having.clone(),
3143 };
3144
3145 Some(QueryBody::Select(Box::new(fused)))
3146}
3147
3148fn materialize_all_ctes(
3149 defs: &[CteDefinition],
3150 recursive: bool,
3151 exec_body: &mut dyn FnMut(&QueryBody, &CteContext) -> Result<QueryResult>,
3152) -> Result<CteContext> {
3153 let mut ctx = CteContext::new();
3154 for cte in defs {
3155 let qr = if recursive && cte_body_references_self(&cte.body, &cte.name) {
3156 materialize_recursive_cte(cte, &ctx, exec_body)?
3157 } else {
3158 materialize_cte(cte, &ctx, exec_body)?
3159 };
3160 ctx.insert(cte.name.clone(), qr);
3161 }
3162 Ok(ctx)
3163}
3164
3165fn materialize_cte(
3166 cte: &CteDefinition,
3167 ctx: &CteContext,
3168 exec_body: &mut dyn FnMut(&QueryBody, &CteContext) -> Result<QueryResult>,
3169) -> Result<QueryResult> {
3170 let mut qr = exec_body(&cte.body, ctx)?;
3171 if !cte.column_aliases.is_empty() {
3172 if cte.column_aliases.len() != qr.columns.len() {
3173 return Err(SqlError::CteColumnAliasMismatch {
3174 name: cte.name.clone(),
3175 expected: cte.column_aliases.len(),
3176 got: qr.columns.len(),
3177 });
3178 }
3179 qr.columns = cte.column_aliases.clone();
3180 }
3181 Ok(qr)
3182}
3183
3184const MAX_RECURSIVE_ITERATIONS: usize = 10_000;
3185
3186fn materialize_recursive_cte(
3187 cte: &CteDefinition,
3188 ctx: &CteContext,
3189 exec_body: &mut dyn FnMut(&QueryBody, &CteContext) -> Result<QueryResult>,
3190) -> Result<QueryResult> {
3191 let (anchor_body, recursive_body, union_all) = match &cte.body {
3192 QueryBody::Compound(comp) if matches!(comp.op, SetOp::Union) => {
3193 (&*comp.left, &*comp.right, comp.all)
3194 }
3195 _ => return Err(SqlError::RecursiveCteNoUnion(cte.name.clone())),
3196 };
3197
3198 let anchor_qr = exec_body(anchor_body, ctx)?;
3199 let columns = if !cte.column_aliases.is_empty() {
3200 if cte.column_aliases.len() != anchor_qr.columns.len() {
3201 return Err(SqlError::CteColumnAliasMismatch {
3202 name: cte.name.clone(),
3203 expected: cte.column_aliases.len(),
3204 got: anchor_qr.columns.len(),
3205 });
3206 }
3207 cte.column_aliases.clone()
3208 } else {
3209 anchor_qr.columns
3210 };
3211
3212 let mut accumulated = anchor_qr.rows;
3213 let mut working_rows = accumulated.clone();
3214 let mut seen = if !union_all {
3215 let mut s = std::collections::HashSet::new();
3216 for row in &accumulated {
3217 s.insert(row.clone());
3218 }
3219 Some(s)
3220 } else {
3221 None
3222 };
3223
3224 let cte_key = cte.name.clone();
3225
3226 let fast_sel = match recursive_body {
3227 QueryBody::Select(sel)
3228 if sel.from.eq_ignore_ascii_case(&cte_key)
3229 && sel.joins.is_empty()
3230 && sel.group_by.is_empty()
3231 && !sel.distinct
3232 && sel.having.is_none()
3233 && sel.limit.is_none()
3234 && sel.offset.is_none()
3235 && sel.order_by.is_empty()
3236 && !stmt_has_subquery(sel) =>
3237 {
3238 Some(sel.as_ref())
3239 }
3240 _ => None,
3241 };
3242
3243 if let Some(sel) = fast_sel {
3244 let cte_cols: Vec<ColumnDef> = columns
3245 .iter()
3246 .enumerate()
3247 .map(|(i, name)| ColumnDef {
3248 name: name.clone(),
3249 data_type: DataType::Null,
3250 nullable: true,
3251 position: i as u16,
3252 default_expr: None,
3253 default_sql: None,
3254 check_expr: None,
3255 check_sql: None,
3256 check_name: None,
3257 })
3258 .collect();
3259 let col_map = ColumnMap::new(&cte_cols);
3260 let ncols = sel.columns.len();
3261
3262 for iteration in 0..MAX_RECURSIVE_ITERATIONS {
3263 if working_rows.is_empty() {
3264 break;
3265 }
3266
3267 let mut step_rows = Vec::with_capacity(working_rows.len());
3268 for row in &working_rows {
3269 if let Some(ref w) = sel.where_clause {
3270 match eval_expr(w, &col_map, row) {
3271 Ok(val) if is_truthy(&val) => {}
3272 Ok(_) => continue,
3273 Err(e) => return Err(e),
3274 }
3275 }
3276 let mut out = Vec::with_capacity(ncols);
3277 for col in &sel.columns {
3278 match col {
3279 SelectColumn::Expr { expr, .. } => {
3280 out.push(eval_expr(expr, &col_map, row)?);
3281 }
3282 SelectColumn::AllColumns => {
3283 out.extend_from_slice(row);
3284 }
3285 }
3286 }
3287 step_rows.push(out);
3288 }
3289
3290 if step_rows.is_empty() {
3291 break;
3292 }
3293
3294 let new_rows = if let Some(ref mut seen_set) = seen {
3295 step_rows
3296 .into_iter()
3297 .filter(|r| seen_set.insert(r.clone()))
3298 .collect::<Vec<_>>()
3299 } else {
3300 step_rows
3301 };
3302
3303 if new_rows.is_empty() {
3304 break;
3305 }
3306
3307 accumulated.extend_from_slice(&new_rows);
3308 working_rows = new_rows;
3309
3310 if iteration == MAX_RECURSIVE_ITERATIONS - 1 {
3311 return Err(SqlError::RecursiveCteMaxIterations(
3312 cte_key.clone(),
3313 MAX_RECURSIVE_ITERATIONS,
3314 ));
3315 }
3316 }
3317 } else {
3318 let mut iter_ctx = ctx.clone();
3319 iter_ctx.insert(
3320 cte_key.clone(),
3321 QueryResult {
3322 columns: columns.clone(),
3323 rows: working_rows,
3324 },
3325 );
3326
3327 for iteration in 0..MAX_RECURSIVE_ITERATIONS {
3328 if iter_ctx.get(&cte_key).unwrap().rows.is_empty() {
3329 break;
3330 }
3331
3332 let iter_qr = exec_body(recursive_body, &iter_ctx)?;
3333 if iter_qr.rows.is_empty() {
3334 break;
3335 }
3336
3337 let new_rows = if let Some(ref mut seen_set) = seen {
3338 iter_qr
3339 .rows
3340 .into_iter()
3341 .filter(|r| seen_set.insert(r.clone()))
3342 .collect::<Vec<_>>()
3343 } else {
3344 iter_qr.rows
3345 };
3346
3347 if new_rows.is_empty() {
3348 break;
3349 }
3350
3351 accumulated.extend_from_slice(&new_rows);
3352 iter_ctx.get_mut(&cte_key).unwrap().rows = new_rows;
3353
3354 if iteration == MAX_RECURSIVE_ITERATIONS - 1 {
3355 return Err(SqlError::RecursiveCteMaxIterations(
3356 cte_key.clone(),
3357 MAX_RECURSIVE_ITERATIONS,
3358 ));
3359 }
3360 }
3361
3362 iter_ctx.remove(&cte_key);
3363 }
3364
3365 Ok(QueryResult {
3366 columns,
3367 rows: accumulated,
3368 })
3369}
3370
3371fn cte_body_references_self(body: &QueryBody, name: &str) -> bool {
3372 match body {
3373 QueryBody::Select(sel) => {
3374 sel.from.eq_ignore_ascii_case(name)
3375 || sel
3376 .joins
3377 .iter()
3378 .any(|j| j.table.name.eq_ignore_ascii_case(name))
3379 }
3380 QueryBody::Compound(comp) => {
3381 cte_body_references_self(&comp.left, name)
3382 || cte_body_references_self(&comp.right, name)
3383 }
3384 }
3385}
3386
3387fn build_cte_schema(name: &str, qr: &QueryResult) -> TableSchema {
3388 let columns: Vec<ColumnDef> = qr
3389 .columns
3390 .iter()
3391 .enumerate()
3392 .map(|(i, col_name)| ColumnDef {
3393 name: col_name.clone(),
3394 data_type: DataType::Null,
3395 nullable: true,
3396 position: i as u16,
3397 default_expr: None,
3398 default_sql: None,
3399 check_expr: None,
3400 check_sql: None,
3401 check_name: None,
3402 })
3403 .collect();
3404 TableSchema::new(name.into(), columns, vec![], vec![], vec![], vec![])
3405}
3406
3407fn exec_select_from_cte(
3408 cte_result: &QueryResult,
3409 stmt: &SelectStmt,
3410 exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
3411) -> Result<ExecutionResult> {
3412 let cte_schema = build_cte_schema(&stmt.from, cte_result);
3413 let actual_stmt;
3414 let s = if stmt_has_subquery(stmt) {
3415 actual_stmt = materialize_stmt(stmt, exec_sub)?;
3416 &actual_stmt
3417 } else {
3418 stmt
3419 };
3420
3421 let has_aggregates = s.columns.iter().any(|c| match c {
3422 SelectColumn::Expr { expr, .. } => is_aggregate_expr(expr),
3423 _ => false,
3424 });
3425
3426 if has_aggregates || !s.group_by.is_empty() {
3427 if let Some(ref where_expr) = s.where_clause {
3428 let col_map = ColumnMap::new(&cte_schema.columns);
3429 let filtered: Vec<Vec<Value>> = cte_result
3430 .rows
3431 .iter()
3432 .filter(|row| match eval_expr(where_expr, &col_map, row) {
3433 Ok(val) => is_truthy(&val),
3434 _ => false,
3435 })
3436 .cloned()
3437 .collect();
3438 return exec_aggregate(&cte_schema.columns, &filtered, s);
3439 }
3440 return exec_aggregate(&cte_schema.columns, &cte_result.rows, s);
3441 }
3442
3443 process_select(&cte_schema.columns, cte_result.rows.clone(), s, false)
3444}
3445
3446fn exec_select_join_with_ctes(
3447 stmt: &SelectStmt,
3448 ctes: &CteContext,
3449 scan_table: ScanTableFn<'_>,
3450) -> Result<ExecutionResult> {
3451 let (from_schema, from_rows) = resolve_table_or_cte(&stmt.from, ctes, scan_table)?;
3452 let from_alias = table_alias_or_name(&stmt.from, &stmt.from_alias);
3453
3454 let mut tables: Vec<(String, TableSchema)> = vec![(from_alias.clone(), from_schema)];
3455 let mut join_rows: Vec<Vec<Vec<Value>>> = Vec::new();
3456
3457 for join in &stmt.joins {
3458 let jname = &join.table.name;
3459 let (js, jrows) = resolve_table_or_cte(jname, ctes, scan_table)?;
3460 let jalias = table_alias_or_name(jname, &join.table.alias);
3461 tables.push((jalias, js));
3462 join_rows.push(jrows);
3463 }
3464
3465 let mut outer_rows = from_rows;
3466 let mut cur_tables: Vec<(String, &TableSchema)> = vec![(from_alias.clone(), &tables[0].1)];
3467
3468 for (ji, join) in stmt.joins.iter().enumerate() {
3469 let inner_schema = &tables[ji + 1].1;
3470 let inner_alias = &tables[ji + 1].0;
3471 let inner_rows = &join_rows[ji];
3472
3473 let mut preview_tables = cur_tables.clone();
3474 preview_tables.push((inner_alias.clone(), inner_schema));
3475 let combined_cols = build_joined_columns(&preview_tables);
3476
3477 let outer_col_count = if outer_rows.is_empty() {
3478 cur_tables.iter().map(|(_, s)| s.columns.len()).sum()
3479 } else {
3480 outer_rows[0].len()
3481 };
3482 let inner_col_count = inner_schema.columns.len();
3483
3484 outer_rows = exec_join_step(
3485 outer_rows,
3486 inner_rows,
3487 join,
3488 &combined_cols,
3489 outer_col_count,
3490 inner_col_count,
3491 None,
3492 None,
3493 );
3494 cur_tables.push((inner_alias.clone(), inner_schema));
3495 }
3496
3497 let joined_cols = build_joined_columns(&cur_tables);
3498 process_select(&joined_cols, outer_rows, stmt, false)
3499}
3500
3501fn resolve_table_or_cte(
3502 name: &str,
3503 ctes: &CteContext,
3504 scan_table: ScanTableFn<'_>,
3505) -> Result<(TableSchema, Vec<Vec<Value>>)> {
3506 let lower = name.to_ascii_lowercase();
3507 if let Some(cte_result) = ctes.get(&lower) {
3508 let schema = build_cte_schema(&lower, cte_result);
3509 Ok((schema, cte_result.rows.clone()))
3510 } else {
3511 scan_table(&lower)
3512 }
3513}
3514
3515fn scan_table_read(
3516 db: &Database,
3517 schema: &SchemaManager,
3518 name: &str,
3519) -> Result<(TableSchema, Vec<Vec<Value>>)> {
3520 let table_schema = schema
3521 .get(name)
3522 .ok_or_else(|| SqlError::TableNotFound(name.to_string()))?;
3523 let (rows, _) = collect_rows_read(db, table_schema, &None, None)?;
3524 Ok((table_schema.clone(), rows))
3525}
3526
3527fn scan_table_write(
3528 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
3529 schema: &SchemaManager,
3530 name: &str,
3531) -> Result<(TableSchema, Vec<Vec<Value>>)> {
3532 let table_schema = schema
3533 .get(name)
3534 .ok_or_else(|| SqlError::TableNotFound(name.to_string()))?;
3535 let (rows, _) = collect_rows_write(wtx, table_schema, &None, None)?;
3536 Ok((table_schema.clone(), rows))
3537}
3538
3539fn apply_set_operation(
3540 comp: &CompoundSelect,
3541 left_qr: QueryResult,
3542 right_qr: QueryResult,
3543) -> Result<ExecutionResult> {
3544 if !left_qr.columns.is_empty()
3545 && !right_qr.columns.is_empty()
3546 && left_qr.columns.len() != right_qr.columns.len()
3547 {
3548 return Err(SqlError::CompoundColumnCountMismatch {
3549 left: left_qr.columns.len(),
3550 right: right_qr.columns.len(),
3551 });
3552 }
3553
3554 let columns = left_qr.columns;
3555
3556 let mut rows = match (&comp.op, comp.all) {
3557 (SetOp::Union, true) => {
3558 let mut rows = left_qr.rows;
3559 rows.extend(right_qr.rows);
3560 rows
3561 }
3562 (SetOp::Union, false) => {
3563 let mut seen = std::collections::HashSet::new();
3564 let mut rows = Vec::new();
3565 for row in left_qr.rows.into_iter().chain(right_qr.rows) {
3566 if seen.insert(row.clone()) {
3567 rows.push(row);
3568 }
3569 }
3570 rows
3571 }
3572 (SetOp::Intersect, true) => {
3573 let mut right_counts: std::collections::HashMap<Vec<Value>, usize> =
3574 std::collections::HashMap::new();
3575 for row in &right_qr.rows {
3576 *right_counts.entry(row.clone()).or_insert(0) += 1;
3577 }
3578 let mut rows = Vec::new();
3579 for row in left_qr.rows {
3580 if let Some(count) = right_counts.get_mut(&row) {
3581 if *count > 0 {
3582 *count -= 1;
3583 rows.push(row);
3584 }
3585 }
3586 }
3587 rows
3588 }
3589 (SetOp::Intersect, false) => {
3590 let right_set: std::collections::HashSet<Vec<Value>> =
3591 right_qr.rows.into_iter().collect();
3592 let mut seen = std::collections::HashSet::new();
3593 let mut rows = Vec::new();
3594 for row in left_qr.rows {
3595 if right_set.contains(&row) && seen.insert(row.clone()) {
3596 rows.push(row);
3597 }
3598 }
3599 rows
3600 }
3601 (SetOp::Except, true) => {
3602 let mut right_counts: std::collections::HashMap<Vec<Value>, usize> =
3603 std::collections::HashMap::new();
3604 for row in &right_qr.rows {
3605 *right_counts.entry(row.clone()).or_insert(0) += 1;
3606 }
3607 let mut rows = Vec::new();
3608 for row in left_qr.rows {
3609 if let Some(count) = right_counts.get_mut(&row) {
3610 if *count > 0 {
3611 *count -= 1;
3612 continue;
3613 }
3614 }
3615 rows.push(row);
3616 }
3617 rows
3618 }
3619 (SetOp::Except, false) => {
3620 let right_set: std::collections::HashSet<Vec<Value>> =
3621 right_qr.rows.into_iter().collect();
3622 let mut seen = std::collections::HashSet::new();
3623 let mut rows = Vec::new();
3624 for row in left_qr.rows {
3625 if !right_set.contains(&row) && seen.insert(row.clone()) {
3626 rows.push(row);
3627 }
3628 }
3629 rows
3630 }
3631 };
3632
3633 if !comp.order_by.is_empty() {
3634 let col_defs: Vec<crate::types::ColumnDef> = columns
3635 .iter()
3636 .enumerate()
3637 .map(|(i, name)| crate::types::ColumnDef {
3638 name: name.clone(),
3639 data_type: crate::types::DataType::Null,
3640 nullable: true,
3641 position: i as u16,
3642 default_expr: None,
3643 default_sql: None,
3644 check_expr: None,
3645 check_sql: None,
3646 check_name: None,
3647 })
3648 .collect();
3649 sort_rows(&mut rows, &comp.order_by, &col_defs)?;
3650 }
3651
3652 if let Some(ref offset_expr) = comp.offset {
3653 let offset = eval_const_int(offset_expr)?.max(0) as usize;
3654 if offset < rows.len() {
3655 rows = rows.split_off(offset);
3656 } else {
3657 rows.clear();
3658 }
3659 }
3660
3661 if let Some(ref limit_expr) = comp.limit {
3662 let limit = eval_const_int(limit_expr)?.max(0) as usize;
3663 rows.truncate(limit);
3664 }
3665
3666 Ok(ExecutionResult::Query(QueryResult { columns, rows }))
3667}
3668
/// Executes a `SELECT` against `db` using read transactions.
///
/// If the statement contains subqueries (`stmt_has_subquery`), they are first
/// materialized through `exec_subquery_read`. Dispatch then proceeds in this order:
/// - a SELECT with no FROM;
/// - a FROM that names a CTE (with or without joins);
/// - joins that reference a CTE;
/// - plain joins.
///
/// For a single base table, the fast paths are tried in order:
/// - `COUNT(*)` from the table's entry count;
/// - streaming aggregates;
/// - streaming GROUP BY;
/// - top-K scan.
///
/// Otherwise the rows are collected (possibly capped by `compute_scan_limit`) and
/// passed to `process_select`.
///
/// # Errors
/// `SqlError::TableNotFound` (carrying the FROM name as written) when the table is
/// unknown. Storage errors are mapped to `SqlError::Storage`. Errors from the
/// delegated executors are propagated.
fn exec_select(
    db: &Database,
    schema: &SchemaManager,
    stmt: &SelectStmt,
    ctes: &CteContext,
) -> Result<ExecutionResult> {
    // Replace subqueries with their results; `stmt` is shadowed by the rewritten statement.
    let materialized;
    let stmt = if stmt_has_subquery(stmt) {
        materialized =
            materialize_stmt(stmt, &mut |sub| exec_subquery_read(db, schema, sub, ctes))?;
        &materialized
    } else {
        stmt
    };

    if stmt.from.is_empty() {
        return exec_select_no_from(stmt);
    }

    // Table and CTE names are looked up lowercased.
    let lower_name = stmt.from.to_ascii_lowercase();

    // FROM names a CTE.
    if let Some(cte_result) = ctes.get(&lower_name) {
        if stmt.joins.is_empty() {
            return exec_select_from_cte(cte_result, stmt, &mut |sub| {
                exec_subquery_read(db, schema, sub, ctes)
            });
        } else {
            return exec_select_join_with_ctes(stmt, ctes, &mut |name| {
                scan_table_read(db, schema, name)
            });
        }
    }

    // FROM is a base table, but some join target is a CTE.
    if !ctes.is_empty()
        && stmt
            .joins
            .iter()
            .any(|j| ctes.contains_key(&j.table.name.to_ascii_lowercase()))
    {
        return exec_select_join_with_ctes(stmt, ctes, &mut |name| {
            scan_table_read(db, schema, name)
        });
    }

    let table_schema = schema
        .get(&lower_name)
        .ok_or_else(|| SqlError::TableNotFound(stmt.from.clone()))?;

    if !stmt.joins.is_empty() {
        return exec_select_join(db, schema, stmt);
    }

    // Fast path: bare `SELECT COUNT(*) FROM t` answered from the table's entry count.
    if let Some(result) = try_count_star_shortcut(stmt, || {
        let mut rtx = db.begin_read();
        rtx.table_entry_count(lower_name.as_bytes())
            .map_err(SqlError::Storage)
    })? {
        return Ok(result);
    }

    // Fast path: ungrouped aggregates accumulated during a single table scan.
    if let Some(plan) = StreamAggPlan::try_new(stmt, table_schema)? {
        let mut states: Vec<AggState> = plan.ops.iter().map(|(op, _)| AggState::new(op)).collect();
        // The scan callback can only return a bool; the first error is stashed here.
        let mut scan_err: Option<SqlError> = None;
        let mut rtx = db.begin_read();
        if stmt.where_clause.is_none() {
            // No filter: read aggregate inputs straight from the raw key/value bytes.
            rtx.table_scan_raw(lower_name.as_bytes(), |key, value| {
                plan.feed_row_raw(key, value, &mut states, &mut scan_err)
            })
            .map_err(SqlError::Storage)?;
        } else {
            // With a filter: decode each row (fully or partially) to evaluate WHERE.
            let col_map = ColumnMap::new(&table_schema.columns);
            rtx.table_scan_raw(lower_name.as_bytes(), |key, value| {
                plan.feed_row(
                    key,
                    value,
                    table_schema,
                    &col_map,
                    &stmt.where_clause,
                    &mut states,
                    &mut scan_err,
                )
            })
            .map_err(SqlError::Storage)?;
        }
        if let Some(e) = scan_err {
            return Err(e);
        }
        return Ok(plan.finish(states));
    }

    // Fast path: single-column GROUP BY hashed during the scan.
    if let Some(plan) = StreamGroupByPlan::try_new(stmt, table_schema)? {
        let lower = lower_name.clone();
        let mut rtx = db.begin_read();
        return plan
            .execute_scan(|cb| rtx.table_scan_raw(lower.as_bytes(), |key, value| cb(key, value)));
    }

    // Fast path: single-column ORDER BY with LIMIT.
    if let Some(plan) = TopKScanPlan::try_new(stmt, table_schema)? {
        let lower = lower_name.clone();
        let mut rtx = db.begin_read();
        return plan.execute_scan(table_schema, stmt, |cb| {
            rtx.table_scan_raw(lower.as_bytes(), |key, value| cb(key, value))
        });
    }

    // General path: collect rows (optionally capped) and let `process_select` finish.
    let scan_limit = compute_scan_limit(stmt);
    let (rows, predicate_applied) =
        collect_rows_read(db, table_schema, &stmt.where_clause, scan_limit)?;
    process_select(&table_schema.columns, rows, stmt, predicate_applied)
}
3779
3780fn compute_scan_limit(stmt: &SelectStmt) -> Option<usize> {
3781 if !stmt.order_by.is_empty()
3782 || !stmt.group_by.is_empty()
3783 || stmt.distinct
3784 || stmt.having.is_some()
3785 {
3786 return None;
3787 }
3788 if has_any_window_function(stmt) {
3789 return None;
3790 }
3791 let has_aggregates = stmt.columns.iter().any(|c| match c {
3792 SelectColumn::Expr { expr, .. } => is_aggregate_expr(expr),
3793 _ => false,
3794 });
3795 if has_aggregates {
3796 return None;
3797 }
3798 let limit = stmt.limit.as_ref()?;
3799 let limit_val = eval_const_int(limit).ok()?.max(0) as usize;
3800 let offset_val = stmt
3801 .offset
3802 .as_ref()
3803 .and_then(|e| eval_const_int(e).ok())
3804 .unwrap_or(0)
3805 .max(0) as usize;
3806 Some(limit_val.saturating_add(offset_val))
3807}
3808
3809fn try_count_star_shortcut(
3810 stmt: &SelectStmt,
3811 get_count: impl FnOnce() -> Result<u64>,
3812) -> Result<Option<ExecutionResult>> {
3813 if stmt.columns.len() != 1
3814 || stmt.where_clause.is_some()
3815 || !stmt.group_by.is_empty()
3816 || stmt.having.is_some()
3817 {
3818 return Ok(None);
3819 }
3820 let col = match &stmt.columns[0] {
3821 SelectColumn::Expr { expr, alias } => (expr, alias),
3822 _ => return Ok(None),
3823 };
3824 if !matches!(col.0, Expr::CountStar) {
3825 return Ok(None);
3826 }
3827 let count = get_count()? as i64;
3828 let col_name = col.1.as_deref().unwrap_or("COUNT(*)").to_string();
3829 Ok(Some(ExecutionResult::Query(QueryResult {
3830 columns: vec![col_name],
3831 rows: vec![vec![Value::Integer(count)]],
3832 })))
3833}
3834
/// Aggregate function recognized by the streaming aggregate plans.
///
/// The `usize` payload is the schema column index the aggregate reads (resolved through
/// `ColumnMap`). `CountStar` reads no column.
enum StreamAgg {
    /// `COUNT(*)`: counts every row fed to it.
    CountStar,
    /// `COUNT(col)`: counts non-NULL values.
    Count(usize),
    /// `SUM(col)`.
    Sum(usize),
    /// `AVG(col)`.
    Avg(usize),
    /// `MIN(col)`, ignoring NULLs.
    Min(usize),
    /// `MAX(col)`, ignoring NULLs.
    Max(usize),
}
3843
/// Where a column's value lives in a stored row, used by scans that read raw key/value
/// bytes without decoding the whole row.
enum RawAggTarget {
    /// No column is read (`COUNT(*)`).
    CountStar,
    /// Position within the table's primary-key columns; decoded from the row key.
    Pk(usize),
    /// Non-PK encoding position; passed to `decode_column_raw` on the row value.
    NonPk(usize),
}
3849
/// Running accumulator for one `StreamAgg`, created by `AggState::new`.
enum AggState {
    /// Number of rows seen.
    CountStar(i64),
    /// Number of non-NULL values seen.
    Count(i64),
    /// Integer and real inputs are summed separately. `finish` yields NULL while
    /// `all_null` is set (no non-NULL input), a `Real` once any real input was seen
    /// (`has_real`), and an `Integer` otherwise.
    Sum {
        int_sum: i64,
        real_sum: f64,
        has_real: bool,
        all_null: bool,
    },
    /// Sum of non-NULL inputs as `f64` and their count. `finish` yields NULL when
    /// `count == 0`.
    Avg {
        sum: f64,
        count: i64,
    },
    /// Smallest non-NULL value seen so far, if any.
    Min(Option<Value>),
    /// Largest non-NULL value seen so far, if any.
    Max(Option<Value>),
}
3866
3867impl AggState {
3868 fn new(op: &StreamAgg) -> Self {
3869 match op {
3870 StreamAgg::CountStar => AggState::CountStar(0),
3871 StreamAgg::Count(_) => AggState::Count(0),
3872 StreamAgg::Sum(_) => AggState::Sum {
3873 int_sum: 0,
3874 real_sum: 0.0,
3875 has_real: false,
3876 all_null: true,
3877 },
3878 StreamAgg::Avg(_) => AggState::Avg { sum: 0.0, count: 0 },
3879 StreamAgg::Min(_) => AggState::Min(None),
3880 StreamAgg::Max(_) => AggState::Max(None),
3881 }
3882 }
3883
3884 fn feed_val(&mut self, val: &Value) -> Result<()> {
3885 match self {
3886 AggState::CountStar(c) => {
3887 *c += 1;
3888 }
3889 AggState::Count(c) => {
3890 if !val.is_null() {
3891 *c += 1;
3892 }
3893 }
3894 AggState::Sum {
3895 int_sum,
3896 real_sum,
3897 has_real,
3898 all_null,
3899 } => match val {
3900 Value::Integer(i) => {
3901 *int_sum += i;
3902 *all_null = false;
3903 }
3904 Value::Real(r) => {
3905 *real_sum += r;
3906 *has_real = true;
3907 *all_null = false;
3908 }
3909 Value::Null => {}
3910 _ => {
3911 return Err(SqlError::TypeMismatch {
3912 expected: "numeric".into(),
3913 got: val.data_type().to_string(),
3914 })
3915 }
3916 },
3917 AggState::Avg { sum, count } => match val {
3918 Value::Integer(i) => {
3919 *sum += *i as f64;
3920 *count += 1;
3921 }
3922 Value::Real(r) => {
3923 *sum += r;
3924 *count += 1;
3925 }
3926 Value::Null => {}
3927 _ => {
3928 return Err(SqlError::TypeMismatch {
3929 expected: "numeric".into(),
3930 got: val.data_type().to_string(),
3931 })
3932 }
3933 },
3934 AggState::Min(cur) => {
3935 if !val.is_null() {
3936 *cur = Some(match cur.take() {
3937 None => val.clone(),
3938 Some(m) => {
3939 if val < &m {
3940 val.clone()
3941 } else {
3942 m
3943 }
3944 }
3945 });
3946 }
3947 }
3948 AggState::Max(cur) => {
3949 if !val.is_null() {
3950 *cur = Some(match cur.take() {
3951 None => val.clone(),
3952 Some(m) => {
3953 if val > &m {
3954 val.clone()
3955 } else {
3956 m
3957 }
3958 }
3959 });
3960 }
3961 }
3962 }
3963 Ok(())
3964 }
3965
3966 fn feed_raw(&mut self, raw: &RawColumn) -> Result<()> {
3967 match self {
3968 AggState::CountStar(c) => {
3969 *c += 1;
3970 }
3971 AggState::Count(c) => {
3972 if !matches!(raw, RawColumn::Null) {
3973 *c += 1;
3974 }
3975 }
3976 AggState::Sum {
3977 int_sum,
3978 real_sum,
3979 has_real,
3980 all_null,
3981 } => match raw {
3982 RawColumn::Integer(i) => {
3983 *int_sum += i;
3984 *all_null = false;
3985 }
3986 RawColumn::Real(r) => {
3987 *real_sum += r;
3988 *has_real = true;
3989 *all_null = false;
3990 }
3991 RawColumn::Null => {}
3992 _ => {
3993 return Err(SqlError::TypeMismatch {
3994 expected: "numeric".into(),
3995 got: "non-numeric".into(),
3996 })
3997 }
3998 },
3999 AggState::Avg { sum, count } => match raw {
4000 RawColumn::Integer(i) => {
4001 *sum += *i as f64;
4002 *count += 1;
4003 }
4004 RawColumn::Real(r) => {
4005 *sum += r;
4006 *count += 1;
4007 }
4008 RawColumn::Null => {}
4009 _ => {
4010 return Err(SqlError::TypeMismatch {
4011 expected: "numeric".into(),
4012 got: "non-numeric".into(),
4013 })
4014 }
4015 },
4016 AggState::Min(cur) => {
4017 if !matches!(raw, RawColumn::Null) {
4018 let val = raw.to_value();
4019 *cur = Some(match cur.take() {
4020 None => val,
4021 Some(m) => {
4022 if val < m {
4023 val
4024 } else {
4025 m
4026 }
4027 }
4028 });
4029 }
4030 }
4031 AggState::Max(cur) => {
4032 if !matches!(raw, RawColumn::Null) {
4033 let val = raw.to_value();
4034 *cur = Some(match cur.take() {
4035 None => val,
4036 Some(m) => {
4037 if val > m {
4038 val
4039 } else {
4040 m
4041 }
4042 }
4043 });
4044 }
4045 }
4046 }
4047 Ok(())
4048 }
4049
4050 fn finish(self) -> Value {
4051 match self {
4052 AggState::CountStar(c) | AggState::Count(c) => Value::Integer(c),
4053 AggState::Sum {
4054 int_sum,
4055 real_sum,
4056 has_real,
4057 all_null,
4058 } => {
4059 if all_null {
4060 Value::Null
4061 } else if has_real {
4062 Value::Real(real_sum + int_sum as f64)
4063 } else {
4064 Value::Integer(int_sum)
4065 }
4066 }
4067 AggState::Avg { sum, count } => {
4068 if count == 0 {
4069 Value::Null
4070 } else {
4071 Value::Real(sum / count as f64)
4072 }
4073 }
4074 AggState::Min(v) | AggState::Max(v) => v.unwrap_or(Value::Null),
4075 }
4076 }
4077}
4078
/// Single-pass plan for an ungrouped, aggregate-only query over one table; built by
/// `StreamAggPlan::try_new`.
struct StreamAggPlan {
    /// Aggregates in select-list order, each paired with its output column name.
    ops: Vec<(StreamAgg, String)>,
    /// Decoder restricted to the columns the aggregates and WHERE clause reference.
    /// `None` when every column is needed and rows are decoded in full.
    partial_ctx: Option<PartialDecodeCtx>,
    /// Raw location of each aggregate's input, parallel to `ops`.
    raw_targets: Vec<RawAggTarget>,
    /// Number of primary-key columns in the table.
    num_pk_cols: usize,
    /// One entry per `raw_targets` entry. For a `NonPk` target, this is the value fed
    /// when a stored row holds fewer non-PK columns than the target position: the
    /// column's DEFAULT, if it folds to a constant.
    nonpk_agg_defaults: Vec<Option<Value>>,
}
4086
4087impl StreamAggPlan {
4088 fn try_new(stmt: &SelectStmt, table_schema: &TableSchema) -> Result<Option<Self>> {
4089 if !stmt.group_by.is_empty() || stmt.having.is_some() || !stmt.joins.is_empty() {
4090 return Ok(None);
4091 }
4092
4093 let col_map = ColumnMap::new(&table_schema.columns);
4094 let mut ops: Vec<(StreamAgg, String)> = Vec::new();
4095 for sel_col in &stmt.columns {
4096 let (expr, alias) = match sel_col {
4097 SelectColumn::Expr { expr, alias } => (expr, alias),
4098 _ => return Ok(None),
4099 };
4100 let name = alias
4101 .as_deref()
4102 .unwrap_or(&expr_display_name(expr))
4103 .to_string();
4104 match expr {
4105 Expr::CountStar => ops.push((StreamAgg::CountStar, name)),
4106 Expr::Function {
4107 name: func_name,
4108 args,
4109 } if args.len() == 1 => {
4110 let func = func_name.to_ascii_uppercase();
4111 let col_idx = match resolve_simple_col(&args[0], &col_map) {
4112 Some(idx) => idx,
4113 None => return Ok(None),
4114 };
4115 match func.as_str() {
4116 "COUNT" => ops.push((StreamAgg::Count(col_idx), name)),
4117 "SUM" => ops.push((StreamAgg::Sum(col_idx), name)),
4118 "AVG" => ops.push((StreamAgg::Avg(col_idx), name)),
4119 "MIN" => ops.push((StreamAgg::Min(col_idx), name)),
4120 "MAX" => ops.push((StreamAgg::Max(col_idx), name)),
4121 _ => return Ok(None),
4122 }
4123 }
4124 _ => return Ok(None),
4125 }
4126 }
4127
4128 let mut needed: Vec<usize> = ops
4129 .iter()
4130 .filter_map(|(op, _)| match op {
4131 StreamAgg::CountStar => None,
4132 StreamAgg::Count(i)
4133 | StreamAgg::Sum(i)
4134 | StreamAgg::Avg(i)
4135 | StreamAgg::Min(i)
4136 | StreamAgg::Max(i) => Some(*i),
4137 })
4138 .collect();
4139 if let Some(ref where_expr) = stmt.where_clause {
4140 needed.extend(referenced_columns(where_expr, &table_schema.columns));
4141 }
4142 needed.sort_unstable();
4143 needed.dedup();
4144
4145 let partial_ctx = if needed.len() < table_schema.columns.len() {
4146 Some(PartialDecodeCtx::new(table_schema, &needed))
4147 } else {
4148 None
4149 };
4150
4151 let non_pk = table_schema.non_pk_indices();
4152 let enc_pos = table_schema.encoding_positions();
4153 let raw_targets: Vec<RawAggTarget> = ops
4154 .iter()
4155 .map(|(op, _)| match op {
4156 StreamAgg::CountStar => RawAggTarget::CountStar,
4157 StreamAgg::Count(idx)
4158 | StreamAgg::Sum(idx)
4159 | StreamAgg::Avg(idx)
4160 | StreamAgg::Min(idx)
4161 | StreamAgg::Max(idx) => {
4162 if let Some(pk_pos) = table_schema
4163 .primary_key_columns
4164 .iter()
4165 .position(|&i| i as usize == *idx)
4166 {
4167 RawAggTarget::Pk(pk_pos)
4168 } else {
4169 let nonpk_order = non_pk.iter().position(|&i| i == *idx).unwrap();
4170 RawAggTarget::NonPk(enc_pos[nonpk_order] as usize)
4171 }
4172 }
4173 })
4174 .collect();
4175
4176 let num_pk_cols = table_schema.primary_key_columns.len();
4177
4178 let mapping = table_schema.decode_col_mapping();
4179 let nonpk_agg_defaults: Vec<Option<Value>> = raw_targets
4180 .iter()
4181 .map(|t| match t {
4182 RawAggTarget::NonPk(phys_idx) => {
4183 let schema_col = mapping[*phys_idx];
4184 if schema_col == usize::MAX {
4185 return None;
4186 }
4187 table_schema.columns[schema_col]
4188 .default_expr
4189 .as_ref()
4190 .and_then(|expr| eval_const_expr(expr).ok())
4191 }
4192 _ => None,
4193 })
4194 .collect();
4195
4196 Ok(Some(Self {
4197 ops,
4198 partial_ctx,
4199 raw_targets,
4200 num_pk_cols,
4201 nonpk_agg_defaults,
4202 }))
4203 }
4204
4205 #[allow(clippy::too_many_arguments)]
4206 fn feed_row(
4207 &self,
4208 key: &[u8],
4209 value: &[u8],
4210 table_schema: &TableSchema,
4211 col_map: &ColumnMap,
4212 where_clause: &Option<Expr>,
4213 states: &mut [AggState],
4214 scan_err: &mut Option<SqlError>,
4215 ) -> bool {
4216 let row = match &self.partial_ctx {
4217 Some(ctx) => match ctx.decode(key, value) {
4218 Ok(r) => r,
4219 Err(e) => {
4220 *scan_err = Some(e);
4221 return false;
4222 }
4223 },
4224 None => match decode_full_row(table_schema, key, value) {
4225 Ok(r) => r,
4226 Err(e) => {
4227 *scan_err = Some(e);
4228 return false;
4229 }
4230 },
4231 };
4232
4233 if let Some(expr) = where_clause {
4234 match eval_expr(expr, col_map, &row) {
4235 Ok(val) if !is_truthy(&val) => return true,
4236 Err(e) => {
4237 *scan_err = Some(e);
4238 return false;
4239 }
4240 _ => {}
4241 }
4242 }
4243
4244 for (i, (op, _)) in self.ops.iter().enumerate() {
4245 let val = match op {
4246 StreamAgg::CountStar => &Value::Null,
4247 StreamAgg::Count(idx)
4248 | StreamAgg::Sum(idx)
4249 | StreamAgg::Avg(idx)
4250 | StreamAgg::Min(idx)
4251 | StreamAgg::Max(idx) => &row[*idx],
4252 };
4253 if let Err(e) = states[i].feed_val(val) {
4254 *scan_err = Some(e);
4255 return false;
4256 }
4257 }
4258 true
4259 }
4260
4261 fn feed_row_raw(
4262 &self,
4263 key: &[u8],
4264 value: &[u8],
4265 states: &mut [AggState],
4266 scan_err: &mut Option<SqlError>,
4267 ) -> bool {
4268 for (i, target) in self.raw_targets.iter().enumerate() {
4269 let raw = match target {
4270 RawAggTarget::CountStar => {
4271 if let Err(e) = states[i].feed_raw(&RawColumn::Null) {
4272 *scan_err = Some(e);
4273 return false;
4274 }
4275 continue;
4276 }
4277 RawAggTarget::Pk(pk_pos) => {
4278 if self.num_pk_cols == 1 && *pk_pos == 0 {
4279 match decode_pk_integer(key) {
4280 Ok(v) => RawColumn::Integer(v),
4281 Err(e) => {
4282 *scan_err = Some(e);
4283 return false;
4284 }
4285 }
4286 } else {
4287 match decode_composite_key(key, self.num_pk_cols) {
4288 Ok(pk) => RawColumn::Integer(match &pk[*pk_pos] {
4289 Value::Integer(i) => *i,
4290 _ => {
4291 *scan_err =
4292 Some(SqlError::InvalidValue("PK not integer".into()));
4293 return false;
4294 }
4295 }),
4296 Err(e) => {
4297 *scan_err = Some(e);
4298 return false;
4299 }
4300 }
4301 }
4302 }
4303 RawAggTarget::NonPk(idx) => {
4304 let stored = row_non_pk_count(value);
4305 if *idx >= stored {
4306 if let Some(ref default) = self.nonpk_agg_defaults[i] {
4307 if let Err(e) = states[i].feed_val(default) {
4308 *scan_err = Some(e);
4309 return false;
4310 }
4311 } else if let Err(e) = states[i].feed_raw(&RawColumn::Null) {
4312 *scan_err = Some(e);
4313 return false;
4314 }
4315 continue;
4316 }
4317 match decode_column_raw(value, *idx) {
4318 Ok(v) => v,
4319 Err(e) => {
4320 *scan_err = Some(e);
4321 return false;
4322 }
4323 }
4324 }
4325 };
4326 if let Err(e) = states[i].feed_raw(&raw) {
4327 *scan_err = Some(e);
4328 return false;
4329 }
4330 }
4331 true
4332 }
4333
4334 fn finish(self, states: Vec<AggState>) -> ExecutionResult {
4335 let col_names: Vec<String> = self.ops.iter().map(|(_, name)| name.clone()).collect();
4336 let result_row: Vec<Value> = states.into_iter().map(|s| s.finish()).collect();
4337 ExecutionResult::Query(QueryResult {
4338 columns: col_names,
4339 rows: vec![result_row],
4340 })
4341 }
4342}
4343
4344fn resolve_simple_col(expr: &Expr, col_map: &ColumnMap) -> Option<usize> {
4345 match expr {
4346 Expr::Column(name) => col_map.resolve(name).ok(),
4347 Expr::QualifiedColumn { table, column } => col_map.resolve_qualified(table, column).ok(),
4348 _ => None,
4349 }
4350}
4351
/// Source of one output column of a `StreamGroupByPlan`.
enum GroupByOutputCol {
    /// The GROUP BY key itself (NULL for the NULL group).
    GroupKey,
    /// Index into the plan's `agg_ops` and each group's `AggState` vector.
    Agg(usize),
}
4356
/// Single-pass hash-aggregation plan for a one-column, integer-typed GROUP BY over one
/// table; built by `StreamGroupByPlan::try_new`.
struct StreamGroupByPlan {
    /// Raw location of the GROUP BY column.
    group_target: RawAggTarget,
    /// Number of primary-key columns in the table.
    num_pk_cols: usize,
    /// Aggregates referenced by the select list.
    agg_ops: Vec<StreamAgg>,
    /// Raw location of each aggregate's input, parallel to `agg_ops`.
    raw_targets: Vec<RawAggTarget>,
    /// Output columns in select-list order, with their names.
    output: Vec<(GroupByOutputCol, String)>,
    /// Compiled WHERE predicate evaluated on raw rows, when the query has a WHERE clause.
    where_pred: Option<SimplePredicate>,
}
4365
4366impl StreamGroupByPlan {
4367 fn try_new(stmt: &SelectStmt, schema: &TableSchema) -> Result<Option<Self>> {
4368 if stmt.group_by.len() != 1
4369 || stmt.having.is_some()
4370 || !stmt.joins.is_empty()
4371 || !stmt.order_by.is_empty()
4372 || stmt.limit.is_some()
4373 {
4374 return Ok(None);
4375 }
4376
4377 let where_pred = stmt
4378 .where_clause
4379 .as_ref()
4380 .map(|expr| try_simple_predicate(expr, schema));
4381 if stmt.where_clause.is_some() && where_pred.as_ref().unwrap().is_none() {
4383 return Ok(None);
4384 }
4385 let where_pred = where_pred.flatten();
4386
4387 let col_map = ColumnMap::new(&schema.columns);
4388
4389 let group_col_idx = match &stmt.group_by[0] {
4390 Expr::Column(name) => col_map.resolve(name).ok(),
4391 _ => None,
4392 };
4393 let group_col_idx = match group_col_idx {
4394 Some(idx) => idx,
4395 None => return Ok(None),
4396 };
4397
4398 if schema.columns[group_col_idx].data_type != DataType::Integer {
4399 return Ok(None);
4400 }
4401
4402 let non_pk = schema.non_pk_indices();
4403 let enc_pos = schema.encoding_positions();
4404 let group_target = if let Some(pk_pos) = schema
4405 .primary_key_columns
4406 .iter()
4407 .position(|&i| i as usize == group_col_idx)
4408 {
4409 RawAggTarget::Pk(pk_pos)
4410 } else {
4411 let nonpk_order = non_pk.iter().position(|&i| i == group_col_idx).unwrap();
4412 RawAggTarget::NonPk(enc_pos[nonpk_order] as usize)
4413 };
4414
4415 let mut agg_ops = Vec::new();
4416 let mut raw_targets = Vec::new();
4417 let mut output = Vec::new();
4418
4419 for sel_col in &stmt.columns {
4420 let (expr, alias) = match sel_col {
4421 SelectColumn::Expr { expr, alias } => (expr, alias),
4422 _ => return Ok(None),
4423 };
4424 let name = alias
4425 .as_deref()
4426 .unwrap_or(&expr_display_name(expr))
4427 .to_string();
4428
4429 if let Some(idx) = resolve_simple_col(expr, &col_map) {
4430 if idx == group_col_idx {
4431 output.push((GroupByOutputCol::GroupKey, name));
4432 continue;
4433 }
4434 }
4435
4436 match expr {
4437 Expr::CountStar => {
4438 let agg_idx = agg_ops.len();
4439 agg_ops.push(StreamAgg::CountStar);
4440 raw_targets.push(RawAggTarget::CountStar);
4441 output.push((GroupByOutputCol::Agg(agg_idx), name));
4442 }
4443 Expr::Function {
4444 name: func_name,
4445 args,
4446 } if args.len() == 1 => {
4447 let func = func_name.to_ascii_uppercase();
4448 let col_idx = match resolve_simple_col(&args[0], &col_map) {
4449 Some(idx) => idx,
4450 None => return Ok(None),
4451 };
4452 let target = if let Some(pk_pos) = schema
4453 .primary_key_columns
4454 .iter()
4455 .position(|&i| i as usize == col_idx)
4456 {
4457 RawAggTarget::Pk(pk_pos)
4458 } else {
4459 let nonpk_order = non_pk.iter().position(|&i| i == col_idx).unwrap();
4460 RawAggTarget::NonPk(enc_pos[nonpk_order] as usize)
4461 };
4462 let agg_idx = agg_ops.len();
4463 match func.as_str() {
4464 "COUNT" => agg_ops.push(StreamAgg::Count(col_idx)),
4465 "SUM" => agg_ops.push(StreamAgg::Sum(col_idx)),
4466 "AVG" => agg_ops.push(StreamAgg::Avg(col_idx)),
4467 "MIN" => agg_ops.push(StreamAgg::Min(col_idx)),
4468 "MAX" => agg_ops.push(StreamAgg::Max(col_idx)),
4469 _ => return Ok(None),
4470 }
4471 raw_targets.push(target);
4472 output.push((GroupByOutputCol::Agg(agg_idx), name));
4473 }
4474 _ => return Ok(None),
4475 }
4476 }
4477
4478 Ok(Some(Self {
4479 group_target,
4480 num_pk_cols: schema.primary_key_columns.len(),
4481 agg_ops,
4482 raw_targets,
4483 output,
4484 where_pred,
4485 }))
4486 }
4487
4488 fn execute_scan(
4489 &self,
4490 scan: impl FnOnce(
4491 &mut dyn FnMut(&[u8], &[u8]) -> bool,
4492 ) -> std::result::Result<(), citadel::Error>,
4493 ) -> Result<ExecutionResult> {
4494 let mut groups: HashMap<i64, Vec<AggState>> = HashMap::new();
4495 let mut null_group: Option<Vec<AggState>> = None;
4496 let mut scan_err: Option<SqlError> = None;
4497
4498 scan(&mut |key, value| {
4499 if let Some(ref pred) = self.where_pred {
4500 match pred.matches_raw(key, value) {
4501 Ok(true) => {}
4502 Ok(false) => return true,
4503 Err(e) => {
4504 scan_err = Some(e);
4505 return false;
4506 }
4507 }
4508 }
4509
4510 let group_key: Option<i64> = match &self.group_target {
4511 RawAggTarget::Pk(pk_pos) => {
4512 if self.num_pk_cols == 1 && *pk_pos == 0 {
4513 match decode_pk_integer(key) {
4514 Ok(v) => Some(v),
4515 Err(e) => {
4516 scan_err = Some(e);
4517 return false;
4518 }
4519 }
4520 } else {
4521 match decode_composite_key(key, self.num_pk_cols) {
4522 Ok(pk) => match &pk[*pk_pos] {
4523 Value::Integer(i) => Some(*i),
4524 Value::Null => None,
4525 _ => {
4526 scan_err = Some(SqlError::InvalidValue(
4527 "GROUP BY key not integer".into(),
4528 ));
4529 return false;
4530 }
4531 },
4532 Err(e) => {
4533 scan_err = Some(e);
4534 return false;
4535 }
4536 }
4537 }
4538 }
4539 RawAggTarget::NonPk(idx) => match decode_column_raw(value, *idx) {
4540 Ok(RawColumn::Integer(i)) => Some(i),
4541 Ok(RawColumn::Null) => None,
4542 Ok(_) => {
4543 scan_err = Some(SqlError::InvalidValue("GROUP BY key not integer".into()));
4544 return false;
4545 }
4546 Err(e) => {
4547 scan_err = Some(e);
4548 return false;
4549 }
4550 },
4551 RawAggTarget::CountStar => unreachable!(),
4552 };
4553
4554 let states = match group_key {
4555 Some(k) => groups
4556 .entry(k)
4557 .or_insert_with(|| self.agg_ops.iter().map(AggState::new).collect()),
4558 None => null_group
4559 .get_or_insert_with(|| self.agg_ops.iter().map(AggState::new).collect()),
4560 };
4561
4562 for (i, target) in self.raw_targets.iter().enumerate() {
4563 let raw = match target {
4564 RawAggTarget::CountStar => {
4565 if let Err(e) = states[i].feed_raw(&RawColumn::Null) {
4566 scan_err = Some(e);
4567 return false;
4568 }
4569 continue;
4570 }
4571 RawAggTarget::Pk(pk_pos) => {
4572 if self.num_pk_cols == 1 && *pk_pos == 0 {
4573 match decode_pk_integer(key) {
4574 Ok(v) => RawColumn::Integer(v),
4575 Err(e) => {
4576 scan_err = Some(e);
4577 return false;
4578 }
4579 }
4580 } else {
4581 match decode_composite_key(key, self.num_pk_cols) {
4582 Ok(pk) => match &pk[*pk_pos] {
4583 Value::Integer(i) => RawColumn::Integer(*i),
4584 _ => {
4585 scan_err = Some(SqlError::InvalidValue(
4586 "agg column not integer".into(),
4587 ));
4588 return false;
4589 }
4590 },
4591 Err(e) => {
4592 scan_err = Some(e);
4593 return false;
4594 }
4595 }
4596 }
4597 }
4598 RawAggTarget::NonPk(idx) => match decode_column_raw(value, *idx) {
4599 Ok(v) => v,
4600 Err(e) => {
4601 scan_err = Some(e);
4602 return false;
4603 }
4604 },
4605 };
4606 if let Err(e) = states[i].feed_raw(&raw) {
4607 scan_err = Some(e);
4608 return false;
4609 }
4610 }
4611 true
4612 })
4613 .map_err(SqlError::Storage)?;
4614
4615 if let Some(e) = scan_err {
4616 return Err(e);
4617 }
4618
4619 let col_names: Vec<String> = self.output.iter().map(|(_, name)| name.clone()).collect();
4620 let null_extra = if null_group.is_some() { 1 } else { 0 };
4621 let mut result_rows: Vec<Vec<Value>> = Vec::with_capacity(groups.len() + null_extra);
4622 if let Some(states) = null_group {
4623 let mut row = Vec::with_capacity(self.output.len());
4624 let finished: Vec<Value> = states.into_iter().map(|s| s.finish()).collect();
4625 for (col, _) in &self.output {
4626 match col {
4627 GroupByOutputCol::GroupKey => row.push(Value::Null),
4628 GroupByOutputCol::Agg(idx) => row.push(finished[*idx].clone()),
4629 }
4630 }
4631 result_rows.push(row);
4632 }
4633 for (group_key, states) in groups {
4634 let mut row = Vec::with_capacity(self.output.len());
4635 let finished: Vec<Value> = states.into_iter().map(|s| s.finish()).collect();
4636 for (col, _) in &self.output {
4637 match col {
4638 GroupByOutputCol::GroupKey => row.push(Value::Integer(group_key)),
4639 GroupByOutputCol::Agg(idx) => row.push(finished[*idx].clone()),
4640 }
4641 }
4642 result_rows.push(row);
4643 }
4644
4645 Ok(ExecutionResult::Query(QueryResult {
4646 columns: col_names,
4647 rows: result_rows,
4648 }))
4649 }
4650}
4651
/// Plan for a single-table `SELECT … ORDER BY <one column> LIMIT n [OFFSET m]` without
/// WHERE, grouping, DISTINCT, joins, aggregates or window functions; built by
/// `TopKScanPlan::try_new`.
struct TopKScanPlan {
    /// Raw location of the ORDER BY column.
    sort_target: RawAggTarget,
    /// Number of primary-key columns in the table.
    num_pk_cols: usize,
    /// Whether the ORDER BY direction is DESC.
    descending: bool,
    /// Whether NULLs sort first. Defaults to `!descending` when the query does not
    /// specify it.
    nulls_first: bool,
    /// LIMIT + OFFSET (saturating, negatives clamped to 0): the number of leading sorted
    /// rows the scan must produce. Never 0 (`try_new` declines that case).
    keep: usize,
}
4659
4660impl TopKScanPlan {
4661 fn try_new(stmt: &SelectStmt, schema: &TableSchema) -> Result<Option<Self>> {
4662 if stmt.order_by.len() != 1
4663 || stmt.limit.is_none()
4664 || stmt.where_clause.is_some()
4665 || !stmt.group_by.is_empty()
4666 || stmt.having.is_some()
4667 || !stmt.joins.is_empty()
4668 || stmt.distinct
4669 {
4670 return Ok(None);
4671 }
4672
4673 if has_any_window_function(stmt) {
4674 return Ok(None);
4675 }
4676
4677 let has_aggregates = stmt.columns.iter().any(|c| match c {
4678 SelectColumn::Expr { expr, .. } => is_aggregate_expr(expr),
4679 _ => false,
4680 });
4681 if has_aggregates {
4682 return Ok(None);
4683 }
4684
4685 let ob = &stmt.order_by[0];
4686 let col_map = ColumnMap::new(&schema.columns);
4687 let col_idx = match resolve_simple_col(&ob.expr, &col_map) {
4688 Some(idx) => idx,
4689 None => return Ok(None),
4690 };
4691
4692 let non_pk = schema.non_pk_indices();
4693 let enc_pos_arr = schema.encoding_positions();
4694 let sort_target = if let Some(pk_pos) = schema
4695 .primary_key_columns
4696 .iter()
4697 .position(|&i| i as usize == col_idx)
4698 {
4699 RawAggTarget::Pk(pk_pos)
4700 } else {
4701 let nonpk_order = non_pk.iter().position(|&i| i == col_idx).unwrap();
4702 RawAggTarget::NonPk(enc_pos_arr[nonpk_order] as usize)
4703 };
4704
4705 let limit = eval_const_int(stmt.limit.as_ref().unwrap())?.max(0) as usize;
4706 let offset = stmt
4707 .offset
4708 .as_ref()
4709 .map(eval_const_int)
4710 .transpose()?
4711 .unwrap_or(0)
4712 .max(0) as usize;
4713 let keep = limit.saturating_add(offset);
4714 if keep == 0 {
4715 return Ok(None);
4716 }
4717
4718 Ok(Some(Self {
4719 sort_target,
4720 num_pk_cols: schema.primary_key_columns.len(),
4721 descending: ob.descending,
4722 nulls_first: ob.nulls_first.unwrap_or(!ob.descending),
4723 keep,
4724 }))
4725 }
4726
4727 fn execute_scan(
4728 &self,
4729 schema: &TableSchema,
4730 stmt: &SelectStmt,
4731 scan: impl FnOnce(
4732 &mut dyn FnMut(&[u8], &[u8]) -> bool,
4733 ) -> std::result::Result<(), citadel::Error>,
4734 ) -> Result<ExecutionResult> {
4735 use std::cmp::Ordering;
4736 use std::collections::BinaryHeap;
4737
4738 struct Candidate {
4739 sort_key: Value,
4740 raw_key: Vec<u8>,
4741 raw_value: Vec<u8>,
4742 }
4743
4744 struct CandWrapper {
4745 c: Candidate,
4746 descending: bool,
4747 nulls_first: bool,
4748 }
4749
4750 impl PartialEq for CandWrapper {
4751 fn eq(&self, other: &Self) -> bool {
4752 self.cmp(other) == Ordering::Equal
4753 }
4754 }
4755 impl Eq for CandWrapper {}
4756
4757 impl PartialOrd for CandWrapper {
4758 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
4759 Some(self.cmp(other))
4760 }
4761 }
4762
4763 impl Ord for CandWrapper {
4765 fn cmp(&self, other: &Self) -> Ordering {
4766 let ord = match (self.c.sort_key.is_null(), other.c.sort_key.is_null()) {
4767 (true, true) => Ordering::Equal,
4768 (true, false) => {
4769 if self.nulls_first {
4770 Ordering::Less
4771 } else {
4772 Ordering::Greater
4773 }
4774 }
4775 (false, true) => {
4776 if self.nulls_first {
4777 Ordering::Greater
4778 } else {
4779 Ordering::Less
4780 }
4781 }
4782 (false, false) => self.c.sort_key.cmp(&other.c.sort_key),
4783 };
4784 if self.descending {
4785 ord.reverse()
4786 } else {
4787 ord
4788 }
4789 }
4790 }
4791
4792 let k = self.keep;
4793 let mut heap: BinaryHeap<CandWrapper> = BinaryHeap::with_capacity(k + 1);
4794 let mut scan_err: Option<SqlError> = None;
4795
4796 scan(&mut |key, value| {
4797 let sort_key: Value = match &self.sort_target {
4798 RawAggTarget::Pk(pk_pos) => {
4799 if self.num_pk_cols == 1 && *pk_pos == 0 {
4800 match decode_pk_integer(key) {
4801 Ok(v) => Value::Integer(v),
4802 Err(e) => {
4803 scan_err = Some(e);
4804 return false;
4805 }
4806 }
4807 } else {
4808 match decode_composite_key(key, self.num_pk_cols) {
4809 Ok(mut pk) => std::mem::replace(&mut pk[*pk_pos], Value::Null),
4810 Err(e) => {
4811 scan_err = Some(e);
4812 return false;
4813 }
4814 }
4815 }
4816 }
4817 RawAggTarget::NonPk(idx) => match decode_column_raw(value, *idx) {
4818 Ok(raw) => raw.to_value(),
4819 Err(e) => {
4820 scan_err = Some(e);
4821 return false;
4822 }
4823 },
4824 RawAggTarget::CountStar => unreachable!(),
4825 };
4826
4827 if heap.len() >= k {
4829 if let Some(top) = heap.peek() {
4830 let ord = match (sort_key.is_null(), top.c.sort_key.is_null()) {
4831 (true, true) => Ordering::Equal,
4832 (true, false) => {
4833 if self.nulls_first {
4834 Ordering::Less
4835 } else {
4836 Ordering::Greater
4837 }
4838 }
4839 (false, true) => {
4840 if self.nulls_first {
4841 Ordering::Greater
4842 } else {
4843 Ordering::Less
4844 }
4845 }
4846 (false, false) => sort_key.cmp(&top.c.sort_key),
4847 };
4848 let cmp = if self.descending { ord.reverse() } else { ord };
4849 if cmp != Ordering::Less {
4850 return true;
4851 }
4852 }
4853 }
4854
4855 let cand = CandWrapper {
4856 c: Candidate {
4857 sort_key,
4858 raw_key: key.to_vec(),
4859 raw_value: value.to_vec(),
4860 },
4861 descending: self.descending,
4862 nulls_first: self.nulls_first,
4863 };
4864
4865 if heap.len() < k {
4866 heap.push(cand);
4867 } else if let Some(mut top) = heap.peek_mut() {
4868 *top = cand;
4869 }
4870
4871 true
4872 })
4873 .map_err(SqlError::Storage)?;
4874
4875 if let Some(e) = scan_err {
4876 return Err(e);
4877 }
4878
4879 let mut winners: Vec<CandWrapper> = heap.into_vec();
4880 winners.sort();
4881
4882 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(winners.len());
4883 for w in &winners {
4884 rows.push(decode_full_row(schema, &w.c.raw_key, &w.c.raw_value)?);
4885 }
4886
4887 if let Some(ref offset_expr) = stmt.offset {
4888 let offset = eval_const_int(offset_expr)?.max(0) as usize;
4889 if offset < rows.len() {
4890 rows = rows.split_off(offset);
4891 } else {
4892 rows.clear();
4893 }
4894 }
4895 if let Some(ref limit_expr) = stmt.limit {
4896 let limit = eval_const_int(limit_expr)?.max(0) as usize;
4897 rows.truncate(limit);
4898 }
4899
4900 let (col_names, projected) = project_rows(&schema.columns, &stmt.columns, rows)?;
4901 Ok(ExecutionResult::Query(QueryResult {
4902 columns: col_names,
4903 rows: projected,
4904 }))
4905 }
4906}
4907
4908struct SimplePredicate {
4909 is_pk: bool,
4910 pk_pos: usize,
4911 nonpk_idx: usize,
4912 op: BinOp,
4913 literal: Value,
4914 num_pk_cols: usize,
4915 precomputed_int: Option<i64>,
4916 default_int: Option<i64>,
4917 default_val: Option<Value>,
4918}
4919
4920impl SimplePredicate {
4921 fn matches_raw(&self, key: &[u8], value: &[u8]) -> Result<bool> {
4922 if let Some(target) = self.precomputed_int {
4923 return Ok(self.match_nonpk_int_inline(value, target));
4924 }
4925 let raw = if self.is_pk {
4926 if self.num_pk_cols == 1 {
4927 RawColumn::Integer(decode_pk_integer(key)?)
4928 } else {
4929 let pk = decode_composite_key(key, self.num_pk_cols)?;
4930 match &pk[self.pk_pos] {
4931 Value::Integer(i) => RawColumn::Integer(*i),
4932 Value::Real(r) => RawColumn::Real(*r),
4933 Value::Boolean(b) => RawColumn::Boolean(*b),
4934 _ => {
4935 return Ok(raw_matches_op_value(
4936 &pk[self.pk_pos],
4937 self.op,
4938 &self.literal,
4939 ))
4940 }
4941 }
4942 }
4943 } else if self.nonpk_idx >= row_non_pk_count(value) {
4944 return Ok(match &self.default_val {
4945 Some(d) => raw_matches_op_value(d, self.op, &self.literal),
4946 None => false,
4947 });
4948 } else {
4949 decode_column_raw(value, self.nonpk_idx)?
4950 };
4951 Ok(raw_matches_op(&raw, self.op, &self.literal))
4952 }
4953
4954 #[inline(always)]
4955 fn match_nonpk_int_inline(&self, data: &[u8], target: i64) -> bool {
4956 let col_count = u16::from_le_bytes(data[0..2].try_into().unwrap()) as usize;
4957
4958 if self.nonpk_idx >= col_count {
4959 return match self.default_int {
4960 Some(v) => match self.op {
4961 BinOp::Eq => v == target,
4962 BinOp::NotEq => v != target,
4963 BinOp::Lt => v < target,
4964 BinOp::Gt => v > target,
4965 BinOp::LtEq => v <= target,
4966 BinOp::GtEq => v >= target,
4967 _ => false,
4968 },
4969 None => false,
4970 };
4971 }
4972
4973 let bm_bytes = col_count.div_ceil(8);
4974
4975 if data[2 + self.nonpk_idx / 8] & (1 << (self.nonpk_idx % 8)) != 0 {
4977 return false;
4978 }
4979
4980 let mut pos = 2 + bm_bytes;
4981
4982 for col in 0..self.nonpk_idx {
4984 if data[2 + col / 8] & (1 << (col % 8)) == 0 {
4985 let len = u32::from_le_bytes(data[pos + 1..pos + 5].try_into().unwrap()) as usize;
4986 pos += 5 + len;
4987 }
4988 }
4989
4990 let v = i64::from_le_bytes(data[pos + 5..pos + 13].try_into().unwrap());
4992
4993 match self.op {
4994 BinOp::Eq => v == target,
4995 BinOp::NotEq => v != target,
4996 BinOp::Lt => v < target,
4997 BinOp::Gt => v > target,
4998 BinOp::LtEq => v <= target,
4999 BinOp::GtEq => v >= target,
5000 _ => false,
5001 }
5002 }
5003}
5004
5005fn try_simple_predicate(expr: &Expr, schema: &TableSchema) -> Option<SimplePredicate> {
5006 let (col_name, op, literal) = match expr {
5007 Expr::BinaryOp { left, op, right } => match (left.as_ref(), right.as_ref()) {
5008 (Expr::Column(name), Expr::Literal(lit)) => (name.as_str(), *op, lit),
5009 (Expr::Literal(lit), Expr::Column(name)) => (name.as_str(), flip_cmp_op(*op)?, lit),
5010 _ => return None,
5011 },
5012 _ => return None,
5013 };
5014
5015 if !matches!(
5016 op,
5017 BinOp::Eq | BinOp::NotEq | BinOp::Lt | BinOp::Gt | BinOp::LtEq | BinOp::GtEq
5018 ) {
5019 return None;
5020 }
5021
5022 let col_idx = schema.column_index(col_name)?;
5023 let non_pk = schema.non_pk_indices();
5024
5025 if let Some(pk_pos) = schema
5026 .primary_key_columns
5027 .iter()
5028 .position(|&i| i as usize == col_idx)
5029 {
5030 Some(SimplePredicate {
5031 is_pk: true,
5032 pk_pos,
5033 nonpk_idx: 0,
5034 op,
5035 literal: literal.clone(),
5036 num_pk_cols: schema.primary_key_columns.len(),
5037 precomputed_int: None,
5038 default_int: None,
5039 default_val: None,
5040 })
5041 } else {
5042 let nonpk_order = non_pk.iter().position(|&i| i == col_idx)?;
5043 let nonpk_idx = schema.encoding_positions()[nonpk_order] as usize;
5044 let precomputed_int = match literal {
5045 Value::Integer(i) => Some(*i),
5046 _ => None,
5047 };
5048 let default_val = schema.columns[col_idx]
5049 .default_expr
5050 .as_ref()
5051 .and_then(|expr| eval_const_expr(expr).ok());
5052 let default_int = default_val.as_ref().and_then(|v| match v {
5053 Value::Integer(i) => Some(*i),
5054 _ => None,
5055 });
5056 Some(SimplePredicate {
5057 is_pk: false,
5058 pk_pos: 0,
5059 nonpk_idx,
5060 op,
5061 literal: literal.clone(),
5062 num_pk_cols: schema.primary_key_columns.len(),
5063 precomputed_int,
5064 default_int,
5065 default_val,
5066 })
5067 }
5068}
5069
5070fn flip_cmp_op(op: BinOp) -> Option<BinOp> {
5071 match op {
5072 BinOp::Eq => Some(BinOp::Eq),
5073 BinOp::NotEq => Some(BinOp::NotEq),
5074 BinOp::Lt => Some(BinOp::Gt),
5075 BinOp::Gt => Some(BinOp::Lt),
5076 BinOp::LtEq => Some(BinOp::GtEq),
5077 BinOp::GtEq => Some(BinOp::LtEq),
5078 _ => None,
5079 }
5080}
5081
5082fn raw_matches_op(raw: &RawColumn, op: BinOp, literal: &Value) -> bool {
5083 if matches!(raw, RawColumn::Null) || literal.is_null() {
5085 return false;
5086 }
5087 match op {
5088 BinOp::Eq => raw.eq_value(literal),
5089 BinOp::NotEq => !raw.eq_value(literal),
5090 BinOp::Lt => raw.cmp_value(literal) == Some(std::cmp::Ordering::Less),
5091 BinOp::Gt => raw.cmp_value(literal) == Some(std::cmp::Ordering::Greater),
5092 BinOp::LtEq => raw
5093 .cmp_value(literal)
5094 .is_some_and(|o| o != std::cmp::Ordering::Greater),
5095 BinOp::GtEq => raw
5096 .cmp_value(literal)
5097 .is_some_and(|o| o != std::cmp::Ordering::Less),
5098 _ => false,
5099 }
5100}
5101
5102fn raw_matches_op_value(val: &Value, op: BinOp, literal: &Value) -> bool {
5103 match op {
5104 BinOp::Eq => val == literal,
5105 BinOp::NotEq => val != literal && !val.is_null(),
5106 BinOp::Lt => val < literal,
5107 BinOp::Gt => val > literal,
5108 BinOp::LtEq => val <= literal,
5109 BinOp::GtEq => val >= literal,
5110 _ => false,
5111 }
5112}
5113
5114fn exec_select_no_from(stmt: &SelectStmt) -> Result<ExecutionResult> {
5115 let empty_cols: Vec<ColumnDef> = vec![];
5116 let empty_row: Vec<Value> = vec![];
5117 let (col_names, projected) = project_rows(&empty_cols, &stmt.columns, vec![empty_row])?;
5118 Ok(ExecutionResult::Query(QueryResult {
5119 columns: col_names,
5120 rows: projected,
5121 }))
5122}
5123
5124fn process_select(
5125 columns: &[ColumnDef],
5126 mut rows: Vec<Vec<Value>>,
5127 stmt: &SelectStmt,
5128 predicate_applied: bool,
5129) -> Result<ExecutionResult> {
5130 if !predicate_applied {
5131 if let Some(ref where_expr) = stmt.where_clause {
5132 let col_map = ColumnMap::new(columns);
5133 rows.retain(|row| match eval_expr(where_expr, &col_map, row) {
5134 Ok(val) => is_truthy(&val),
5135 Err(_) => false,
5136 });
5137 }
5138 }
5139
5140 if has_any_window_function(stmt) {
5141 return eval_window_select(columns, rows, stmt);
5142 }
5143
5144 let has_aggregates = stmt.columns.iter().any(|c| match c {
5145 SelectColumn::Expr { expr, .. } => is_aggregate_expr(expr),
5146 _ => false,
5147 });
5148
5149 if has_aggregates || !stmt.group_by.is_empty() {
5150 return exec_aggregate(columns, &rows, stmt);
5151 }
5152
5153 if stmt.distinct {
5154 let (col_names, mut projected) = project_rows(columns, &stmt.columns, rows)?;
5155
5156 let mut seen = std::collections::HashSet::new();
5157 projected.retain(|row| seen.insert(row.clone()));
5158
5159 if !stmt.order_by.is_empty() {
5160 let output_cols = build_output_columns(&stmt.columns, columns);
5161 sort_rows(&mut projected, &stmt.order_by, &output_cols)?;
5162 }
5163
5164 if let Some(ref offset_expr) = stmt.offset {
5165 let offset = eval_const_int(offset_expr)?.max(0) as usize;
5166 if offset < projected.len() {
5167 projected = projected.split_off(offset);
5168 } else {
5169 projected.clear();
5170 }
5171 }
5172
5173 if let Some(ref limit_expr) = stmt.limit {
5174 let limit = eval_const_int(limit_expr)?.max(0) as usize;
5175 projected.truncate(limit);
5176 }
5177
5178 return Ok(ExecutionResult::Query(QueryResult {
5179 columns: col_names,
5180 rows: projected,
5181 }));
5182 }
5183
5184 if !stmt.order_by.is_empty() {
5185 if let Some(ref limit_expr) = stmt.limit {
5186 let limit = eval_const_int(limit_expr)?.max(0) as usize;
5187 let offset = match stmt.offset {
5188 Some(ref e) => eval_const_int(e)?.max(0) as usize,
5189 None => 0,
5190 };
5191 let keep = limit.saturating_add(offset);
5192 if keep == 0 {
5193 rows.clear();
5194 } else if keep < rows.len() {
5195 topk_rows(&mut rows, &stmt.order_by, columns, keep)?;
5196 rows.truncate(keep);
5197 } else {
5198 sort_rows(&mut rows, &stmt.order_by, columns)?;
5199 }
5200 } else {
5201 sort_rows(&mut rows, &stmt.order_by, columns)?;
5202 }
5203 }
5204
5205 if let Some(ref offset_expr) = stmt.offset {
5206 let offset = eval_const_int(offset_expr)?.max(0) as usize;
5207 if offset < rows.len() {
5208 rows = rows.split_off(offset);
5209 } else {
5210 rows.clear();
5211 }
5212 }
5213
5214 if let Some(ref limit_expr) = stmt.limit {
5215 let limit = eval_const_int(limit_expr)?.max(0) as usize;
5216 rows.truncate(limit);
5217 }
5218
5219 let (col_names, projected) = project_rows(columns, &stmt.columns, rows)?;
5220
5221 Ok(ExecutionResult::Query(QueryResult {
5222 columns: col_names,
5223 rows: projected,
5224 }))
5225}
5226
5227fn resolve_table_name<'a>(schema: &'a SchemaManager, name: &str) -> Result<&'a TableSchema> {
5228 schema
5229 .get(name)
5230 .ok_or_else(|| SqlError::TableNotFound(name.to_string()))
5231}
5232
5233fn build_joined_columns(tables: &[(String, &TableSchema)]) -> Vec<ColumnDef> {
5234 let mut result = Vec::new();
5235 let mut pos: u16 = 0;
5236
5237 for (alias, schema) in tables {
5238 for col in &schema.columns {
5239 result.push(ColumnDef {
5240 name: format!("{}.{}", alias.to_ascii_lowercase(), col.name),
5241 data_type: col.data_type,
5242 nullable: col.nullable,
5243 position: pos,
5244 default_expr: None,
5245 default_sql: None,
5246 check_expr: None,
5247 check_sql: None,
5248 check_name: None,
5249 });
5250 pos += 1;
5251 }
5252 }
5253
5254 result
5255}
5256
5257fn extract_equi_join_keys(
5258 on_expr: &Expr,
5259 combined_cols: &[ColumnDef],
5260 outer_col_count: usize,
5261) -> Vec<(usize, usize)> {
5262 let mut pairs = Vec::new();
5263
5264 fn flatten<'a>(e: &'a Expr, out: &mut Vec<&'a Expr>) {
5265 match e {
5266 Expr::BinaryOp {
5267 left,
5268 op: BinOp::And,
5269 right,
5270 } => {
5271 flatten(left, out);
5272 flatten(right, out);
5273 }
5274 _ => out.push(e),
5275 }
5276 }
5277 let mut conjuncts = Vec::new();
5278 flatten(on_expr, &mut conjuncts);
5279
5280 for expr in conjuncts {
5281 if let Expr::BinaryOp {
5282 left,
5283 op: BinOp::Eq,
5284 right,
5285 } = expr
5286 {
5287 if let (Some(l_idx), Some(r_idx)) = (
5288 resolve_col_idx(left, combined_cols),
5289 resolve_col_idx(right, combined_cols),
5290 ) {
5291 if l_idx < outer_col_count && r_idx >= outer_col_count {
5292 pairs.push((l_idx, r_idx - outer_col_count));
5293 } else if r_idx < outer_col_count && l_idx >= outer_col_count {
5294 pairs.push((r_idx, l_idx - outer_col_count));
5295 }
5296 }
5297 }
5298 }
5299
5300 pairs
5301}
5302
5303fn resolve_col_idx(expr: &Expr, columns: &[ColumnDef]) -> Option<usize> {
5304 match expr {
5305 Expr::Column(name) => {
5306 let matches: Vec<usize> = columns
5307 .iter()
5308 .enumerate()
5309 .filter(|(_, c)| {
5310 c.name == *name
5311 || (c.name.len() > name.len()
5312 && c.name.as_bytes()[c.name.len() - name.len() - 1] == b'.'
5313 && c.name.ends_with(name.as_str()))
5314 })
5315 .map(|(i, _)| i)
5316 .collect();
5317 if matches.len() == 1 {
5318 Some(matches[0])
5319 } else {
5320 None
5321 }
5322 }
5323 Expr::QualifiedColumn { table, column } => {
5324 let qualified = format!("{table}.{column}");
5325 columns.iter().position(|c| c.name == qualified)
5326 }
5327 _ => None,
5328 }
5329}
5330
5331fn hash_key(row: &[Value], col_indices: &[usize]) -> Vec<Value> {
5332 col_indices.iter().map(|&i| row[i].clone()).collect()
5333}
5334
5335fn count_conjuncts(expr: &Expr) -> usize {
5336 match expr {
5337 Expr::BinaryOp {
5338 op: BinOp::And,
5339 left,
5340 right,
5341 } => count_conjuncts(left) + count_conjuncts(right),
5342 _ => 1,
5343 }
5344}
5345
5346fn combine_row(outer: &[Value], inner: &[Value], cap: usize) -> Vec<Value> {
5347 let mut combined = Vec::with_capacity(cap);
5348 combined.extend(outer.iter().cloned());
5349 combined.extend(inner.iter().cloned());
5350 combined
5351}
5352
/// A precomputed column selection for building a joined output row directly from an outer
/// row and an inner row, without materializing the full concatenation.
struct CombineProjection {
    /// One entry per output column: `(index, is_inner)`. `index` addresses the inner row
    /// when `is_inner` is true and the outer row otherwise.
    slots: Vec<(usize, bool)>,
}
5356
5357fn combine_row_projected(outer: &[Value], inner: &[Value], proj: &CombineProjection) -> Vec<Value> {
5358 proj.slots
5359 .iter()
5360 .map(|&(idx, is_inner)| {
5361 if is_inner {
5362 inner[idx].clone()
5363 } else {
5364 outer[idx].clone()
5365 }
5366 })
5367 .collect()
5368}
5369
5370fn build_combine_projection(
5371 needed_combined: &[usize],
5372 outer_col_count: usize,
5373) -> CombineProjection {
5374 CombineProjection {
5375 slots: needed_combined
5376 .iter()
5377 .map(|&ci| {
5378 if ci < outer_col_count {
5379 (ci, false)
5380 } else {
5381 (ci - outer_col_count, true)
5382 }
5383 })
5384 .collect(),
5385 }
5386}
5387
5388fn build_projected_columns(full_cols: &[ColumnDef], needed_combined: &[usize]) -> Vec<ColumnDef> {
5389 needed_combined
5390 .iter()
5391 .enumerate()
5392 .map(|(new_pos, &old_pos)| {
5393 let orig = &full_cols[old_pos];
5394 ColumnDef {
5395 name: orig.name.clone(),
5396 data_type: orig.data_type,
5397 nullable: orig.nullable,
5398 position: new_pos as u16,
5399 default_expr: None,
5400 default_sql: None,
5401 check_expr: None,
5402 check_sql: None,
5403 check_name: None,
5404 }
5405 })
5406 .collect()
5407}
5408
/// Fast path for a join on a single integer key column.
///
/// Returns `Err(outer_rows)`, handing the outer rows back untouched, as soon as an inner
/// key is neither an integer nor NULL, so the caller can fall back to a generic join.
/// Keys that are NULL (or, on the outer side, not integers) never match.
///
/// * `outer_is_sorted`: when true and the join is INNER/CROSS, a merge join is used. It
///   assumes the outer rows arrive in ascending key order.
///   NOTE(review): the merge only ever advances, so correctness depends on that ordering;
///   confirm callers guarantee it.
/// * `projection`: when given, output rows contain only the projected slots instead of the
///   full `outer ++ inner` concatenation.
///
/// Where possible, the last match for an outer row reuses (moves out of) that outer row's
/// storage instead of cloning it.
#[allow(clippy::too_many_arguments)]
fn try_integer_join(
    outer_rows: Vec<Vec<Value>>,
    inner_rows: &[Vec<Value>],
    join_type: &JoinType,
    outer_key_col: usize,
    inner_key_col: usize,
    outer_col_count: usize,
    inner_col_count: usize,
    outer_is_sorted: bool,
    projection: Option<&CombineProjection>,
) -> std::result::Result<Vec<Vec<Value>>, Vec<Vec<Value>>> {
    // Width of an output row: projected width, or the full concatenation.
    let cap = projection.map_or(outer_col_count + inner_col_count, |p| p.slots.len());

    if outer_is_sorted && matches!(join_type, JoinType::Inner | JoinType::Cross) {
        // (key, inner row index) pairs; NULL keys are omitted since they never match.
        let mut sorted_inner: Vec<(i64, usize)> = Vec::with_capacity(inner_rows.len());
        let mut needs_sort = false;
        let mut prev = i64::MIN;
        for (i, r) in inner_rows.iter().enumerate() {
            match r[inner_key_col] {
                Value::Integer(k) => {
                    if k < prev {
                        needs_sort = true;
                    }
                    prev = k;
                    sorted_inner.push((k, i));
                }
                Value::Null => {}
                _ => return Err(outer_rows),
            }
        }
        // Skip the sort when the inner keys already arrived in ascending order.
        if needs_sort {
            sorted_inner.sort_unstable_by_key(|&(k, _)| k);
        }

        let mut result = Vec::with_capacity(outer_rows.len());
        // Merge cursor into `sorted_inner`; it only moves forward.
        let mut j = 0;
        for mut outer in outer_rows {
            let ok = match outer[outer_key_col] {
                Value::Integer(i) => i,
                _ => continue,
            };
            while j < sorted_inner.len() && sorted_inner[j].0 < ok {
                j += 1;
            }
            // `j` is not advanced past the equal run, so a following outer row with the same
            // key sees the same inner matches.
            let mut kk = j;
            while kk < sorted_inner.len() && sorted_inner[kk].0 == ok {
                let is_last = kk + 1 >= sorted_inner.len() || sorted_inner[kk + 1].0 != ok;
                let inner = &inner_rows[sorted_inner[kk].1];
                if let Some(proj) = projection {
                    if is_last {
                        // Last use of `outer`: move its values out instead of cloning.
                        // NOTE(review): assumes no outer index appears twice in
                        // `proj.slots`; a repeat would read the placeholder left by
                        // `std::mem::take`.
                        result.push(
                            proj.slots
                                .iter()
                                .map(|&(idx, is_inner)| {
                                    if is_inner {
                                        inner[idx].clone()
                                    } else {
                                        std::mem::take(&mut outer[idx])
                                    }
                                })
                                .collect(),
                        );
                    } else {
                        result.push(combine_row_projected(&outer, inner, proj));
                    }
                } else if is_last {
                    // Last use of `outer`: extend it in place and move it into the result.
                    outer.extend(inner.iter().cloned());
                    result.push(outer);
                    break;
                } else {
                    result.push(combine_row(&outer, inner, cap));
                }
                kk += 1;
            }
        }
        return Ok(result);
    }

    // Hash join: integer key -> indices of inner rows with that key (NULL keys omitted).
    let mut inner_map: HashMap<i64, Vec<usize>> = HashMap::with_capacity(inner_rows.len());
    for (idx, inner) in inner_rows.iter().enumerate() {
        match &inner[inner_key_col] {
            Value::Integer(k) => inner_map.entry(*k).or_default().push(idx),
            Value::Null => {}
            _ => return Err(outer_rows),
        }
    }

    let mut result = Vec::with_capacity(inner_rows.len());

    match join_type {
        JoinType::Inner | JoinType::Cross => {
            for mut outer in outer_rows {
                if let Value::Integer(k) = outer[outer_key_col] {
                    if let Some(indices) = inner_map.get(&k) {
                        if let Some(proj) = projection {
                            for &idx in indices {
                                result.push(combine_row_projected(&outer, &inner_rows[idx], proj));
                            }
                        } else {
                            // `indices` is never empty; the last match reuses `outer`.
                            for &idx in &indices[..indices.len() - 1] {
                                result.push(combine_row(&outer, &inner_rows[idx], cap));
                            }
                            let last_idx = *indices.last().unwrap();
                            outer.extend(inner_rows[last_idx].iter().cloned());
                            result.push(outer);
                        }
                    }
                }
            }
        }
        JoinType::Left => {
            for mut outer in outer_rows {
                if let Value::Integer(k) = outer[outer_key_col] {
                    if let Some(indices) = inner_map.get(&k) {
                        if let Some(proj) = projection {
                            for &idx in indices {
                                result.push(combine_row_projected(&outer, &inner_rows[idx], proj));
                            }
                        } else {
                            for &idx in &indices[..indices.len() - 1] {
                                result.push(combine_row(&outer, &inner_rows[idx], cap));
                            }
                            let last_idx = *indices.last().unwrap();
                            outer.extend(inner_rows[last_idx].iter().cloned());
                            result.push(outer);
                        }
                        continue;
                    }
                }
                // No match (or a non-integer outer key): pad the inner side with NULLs.
                if let Some(proj) = projection {
                    let null_inner = vec![Value::Null; inner_col_count];
                    result.push(combine_row_projected(&outer, &null_inner, proj));
                } else {
                    outer.resize(cap, Value::Null);
                    result.push(outer);
                }
            }
        }
        JoinType::Right => {
            let mut inner_matched = vec![false; inner_rows.len()];
            for mut outer in outer_rows {
                if let Value::Integer(k) = outer[outer_key_col] {
                    if let Some(indices) = inner_map.get(&k) {
                        if let Some(proj) = projection {
                            for &idx in indices {
                                result.push(combine_row_projected(&outer, &inner_rows[idx], proj));
                                inner_matched[idx] = true;
                            }
                        } else {
                            for &idx in &indices[..indices.len() - 1] {
                                result.push(combine_row(&outer, &inner_rows[idx], cap));
                                inner_matched[idx] = true;
                            }
                            let last_idx = *indices.last().unwrap();
                            inner_matched[last_idx] = true;
                            outer.extend(inner_rows[last_idx].iter().cloned());
                            result.push(outer);
                        }
                    }
                }
            }
            // Unmatched inner rows are emitted after all matches, with the outer side padded
            // with NULLs.
            for (j, inner) in inner_rows.iter().enumerate() {
                if !inner_matched[j] {
                    if let Some(proj) = projection {
                        let null_outer = vec![Value::Null; outer_col_count];
                        result.push(combine_row_projected(&null_outer, inner, proj));
                    } else {
                        let mut padded = Vec::with_capacity(cap);
                        padded.resize(outer_col_count, Value::Null);
                        padded.extend(inner.iter().cloned());
                        result.push(padded);
                    }
                }
            }
        }
    }

    Ok(result)
}
5589
5590#[allow(clippy::too_many_arguments)]
5591fn exec_join_step(
5592 mut outer_rows: Vec<Vec<Value>>,
5593 inner_rows: &[Vec<Value>],
5594 join: &JoinClause,
5595 combined_cols: &[ColumnDef],
5596 outer_col_count: usize,
5597 inner_col_count: usize,
5598 outer_pk_col: Option<usize>,
5599 projection: Option<&CombineProjection>,
5600) -> Vec<Vec<Value>> {
5601 let equi_pairs = join
5602 .on_clause
5603 .as_ref()
5604 .map(|on| extract_equi_join_keys(on, combined_cols, outer_col_count))
5605 .unwrap_or_default();
5606
5607 let is_pure_equi = join.on_clause.as_ref().map_or(true, |on| {
5608 !equi_pairs.is_empty() && count_conjuncts(on) == equi_pairs.len()
5609 });
5610
5611 let effective_proj = if is_pure_equi { projection } else { None };
5612
5613 if equi_pairs.len() == 1 && is_pure_equi {
5614 let (outer_key_col, inner_key_col) = equi_pairs[0];
5615 let outer_is_sorted = outer_pk_col == Some(outer_key_col);
5616 match try_integer_join(
5617 outer_rows,
5618 inner_rows,
5619 &join.join_type,
5620 outer_key_col,
5621 inner_key_col,
5622 outer_col_count,
5623 inner_col_count,
5624 outer_is_sorted,
5625 effective_proj,
5626 ) {
5627 Ok(result) => return result,
5628 Err(rows) => outer_rows = rows,
5629 }
5630 }
5631
5632 let outer_key_cols: Vec<usize> = equi_pairs.iter().map(|&(o, _)| o).collect();
5633 let inner_key_cols: Vec<usize> = equi_pairs.iter().map(|&(_, i)| i).collect();
5634
5635 let mut inner_map: HashMap<Vec<Value>, Vec<usize>> = HashMap::new();
5636 for (idx, inner) in inner_rows.iter().enumerate() {
5637 inner_map
5638 .entry(hash_key(inner, &inner_key_cols))
5639 .or_default()
5640 .push(idx);
5641 }
5642
5643 let cap = effective_proj.map_or(outer_col_count + inner_col_count, |p| p.slots.len());
5644 let mut result = Vec::new();
5645
5646 if is_pure_equi {
5647 match join.join_type {
5648 JoinType::Inner | JoinType::Cross => {
5649 for mut outer in outer_rows {
5650 let key = hash_key(&outer, &outer_key_cols);
5651 if let Some(indices) = inner_map.get(&key) {
5652 if let Some(proj) = effective_proj {
5653 for &idx in indices {
5654 result.push(combine_row_projected(&outer, &inner_rows[idx], proj));
5655 }
5656 } else {
5657 for &idx in &indices[..indices.len() - 1] {
5658 result.push(combine_row(&outer, &inner_rows[idx], cap));
5659 }
5660 let last_idx = *indices.last().unwrap();
5661 outer.extend(inner_rows[last_idx].iter().cloned());
5662 result.push(outer);
5663 }
5664 }
5665 }
5666 }
5667 JoinType::Left => {
5668 for mut outer in outer_rows {
5669 let key = hash_key(&outer, &outer_key_cols);
5670 if let Some(indices) = inner_map.get(&key) {
5671 if let Some(proj) = effective_proj {
5672 for &idx in indices {
5673 result.push(combine_row_projected(&outer, &inner_rows[idx], proj));
5674 }
5675 } else {
5676 for &idx in &indices[..indices.len() - 1] {
5677 result.push(combine_row(&outer, &inner_rows[idx], cap));
5678 }
5679 let last_idx = *indices.last().unwrap();
5680 outer.extend(inner_rows[last_idx].iter().cloned());
5681 result.push(outer);
5682 }
5683 } else if let Some(proj) = effective_proj {
5684 let null_inner = vec![Value::Null; inner_col_count];
5685 result.push(combine_row_projected(&outer, &null_inner, proj));
5686 } else {
5687 outer.resize(cap, Value::Null);
5688 result.push(outer);
5689 }
5690 }
5691 }
5692 JoinType::Right => {
5693 let mut inner_matched = vec![false; inner_rows.len()];
5694 for mut outer in outer_rows {
5695 let key = hash_key(&outer, &outer_key_cols);
5696 if let Some(indices) = inner_map.get(&key) {
5697 if let Some(proj) = effective_proj {
5698 for &idx in indices {
5699 result.push(combine_row_projected(&outer, &inner_rows[idx], proj));
5700 inner_matched[idx] = true;
5701 }
5702 } else {
5703 for &idx in &indices[..indices.len() - 1] {
5704 result.push(combine_row(&outer, &inner_rows[idx], cap));
5705 inner_matched[idx] = true;
5706 }
5707 let last_idx = *indices.last().unwrap();
5708 inner_matched[last_idx] = true;
5709 outer.extend(inner_rows[last_idx].iter().cloned());
5710 result.push(outer);
5711 }
5712 }
5713 }
5714 for (j, inner) in inner_rows.iter().enumerate() {
5715 if !inner_matched[j] {
5716 if let Some(proj) = effective_proj {
5717 let null_outer = vec![Value::Null; outer_col_count];
5718 result.push(combine_row_projected(&null_outer, inner, proj));
5719 } else {
5720 let mut padded = Vec::with_capacity(cap);
5721 padded.resize(outer_col_count, Value::Null);
5722 padded.extend(inner.iter().cloned());
5723 result.push(padded);
5724 }
5725 }
5726 }
5727 }
5728 }
5729 } else {
5730 let combined_map = ColumnMap::new(combined_cols);
5731 let on_matches = |combined: &[Value]| -> bool {
5732 match join.on_clause {
5733 Some(ref on_expr) => eval_expr(on_expr, &combined_map, combined)
5734 .map(|v| is_truthy(&v))
5735 .unwrap_or(false),
5736 None => true,
5737 }
5738 };
5739
5740 match join.join_type {
5741 JoinType::Inner | JoinType::Cross => {
5742 for outer in &outer_rows {
5743 let key = hash_key(outer, &outer_key_cols);
5744 if let Some(indices) = inner_map.get(&key) {
5745 for &idx in indices {
5746 let combined = combine_row(outer, &inner_rows[idx], cap);
5747 if on_matches(&combined) {
5748 result.push(combined);
5749 }
5750 }
5751 }
5752 }
5753 }
5754 JoinType::Left => {
5755 for outer in &outer_rows {
5756 let key = hash_key(outer, &outer_key_cols);
5757 let mut matched = false;
5758 if let Some(indices) = inner_map.get(&key) {
5759 for &idx in indices {
5760 let combined = combine_row(outer, &inner_rows[idx], cap);
5761 if on_matches(&combined) {
5762 result.push(combined);
5763 matched = true;
5764 }
5765 }
5766 }
5767 if !matched {
5768 let mut padded = Vec::with_capacity(cap);
5769 padded.extend(outer.iter().cloned());
5770 padded.resize(cap, Value::Null);
5771 result.push(padded);
5772 }
5773 }
5774 }
5775 JoinType::Right => {
5776 let mut inner_matched = vec![false; inner_rows.len()];
5777 for outer in &outer_rows {
5778 let key = hash_key(outer, &outer_key_cols);
5779 if let Some(indices) = inner_map.get(&key) {
5780 for &idx in indices {
5781 let combined = combine_row(outer, &inner_rows[idx], cap);
5782 if on_matches(&combined) {
5783 result.push(combined);
5784 inner_matched[idx] = true;
5785 }
5786 }
5787 }
5788 }
5789 for (j, inner) in inner_rows.iter().enumerate() {
5790 if !inner_matched[j] {
5791 let mut padded = Vec::with_capacity(cap);
5792 padded.resize(outer_col_count, Value::Null);
5793 padded.extend(inner.iter().cloned());
5794 result.push(padded);
5795 }
5796 }
5797 }
5798 }
5799 }
5800
5801 result
5802}
5803
/// The name a table is referred to by in a query: its alias if it has one, otherwise its
/// own name, ASCII-lowercased in either case.
fn table_alias_or_name(name: &str, alias: &Option<String>) -> String {
    alias.as_deref().unwrap_or(name).to_ascii_lowercase()
}
5810
5811fn collect_all_rows_raw(
5812 rtx: &mut citadel_txn::read_txn::ReadTxn<'_>,
5813 table_schema: &TableSchema,
5814) -> Result<Vec<Vec<Value>>> {
5815 let lower_name = &table_schema.name;
5816 let entry_count = rtx.table_entry_count(lower_name.as_bytes()).unwrap_or(0) as usize;
5817 let mut rows = Vec::with_capacity(entry_count);
5818 let mut scan_err: Option<SqlError> = None;
5819 rtx.table_scan_raw(lower_name.as_bytes(), |key, value| {
5820 match decode_full_row(table_schema, key, value) {
5821 Ok(row) => rows.push(row),
5822 Err(e) => {
5823 scan_err = Some(e);
5824 return false;
5825 }
5826 }
5827 true
5828 })
5829 .map_err(SqlError::Storage)?;
5830 if let Some(e) = scan_err {
5831 return Err(e);
5832 }
5833 Ok(rows)
5834}
5835
5836fn collect_all_rows_write(
5837 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
5838 table_schema: &TableSchema,
5839) -> Result<Vec<Vec<Value>>> {
5840 collect_rows_write(wtx, table_schema, &None, None).map(|(rows, _)| rows)
5841}
5842
5843fn has_ambiguous_bare_ref(expr: &Expr, columns: &[ColumnDef]) -> bool {
5844 match expr {
5845 Expr::Column(name) => {
5846 let lower = name.to_ascii_lowercase();
5847 columns
5848 .iter()
5849 .filter(|c| c.name == lower || c.name.ends_with(&format!(".{lower}")))
5850 .count()
5851 > 1
5852 }
5853 Expr::BinaryOp { left, right, .. } => {
5854 has_ambiguous_bare_ref(left, columns) || has_ambiguous_bare_ref(right, columns)
5855 }
5856 Expr::UnaryOp { expr: inner, .. } | Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
5857 has_ambiguous_bare_ref(inner, columns)
5858 }
5859 Expr::Function { args, .. } | Expr::Coalesce(args) => {
5860 args.iter().any(|a| has_ambiguous_bare_ref(a, columns))
5861 }
5862 Expr::Between {
5863 expr: e, low, high, ..
5864 } => {
5865 has_ambiguous_bare_ref(e, columns)
5866 || has_ambiguous_bare_ref(low, columns)
5867 || has_ambiguous_bare_ref(high, columns)
5868 }
5869 Expr::InList { expr: e, list, .. } => {
5870 has_ambiguous_bare_ref(e, columns)
5871 || list.iter().any(|a| has_ambiguous_bare_ref(a, columns))
5872 }
5873 Expr::Like {
5874 expr: e,
5875 pattern,
5876 escape,
5877 ..
5878 } => {
5879 has_ambiguous_bare_ref(e, columns)
5880 || has_ambiguous_bare_ref(pattern, columns)
5881 || escape
5882 .as_ref()
5883 .is_some_and(|esc| has_ambiguous_bare_ref(esc, columns))
5884 }
5885 Expr::Cast { expr: inner, .. } => has_ambiguous_bare_ref(inner, columns),
5886 Expr::Case {
5887 operand,
5888 conditions,
5889 else_result,
5890 } => {
5891 operand
5892 .as_ref()
5893 .is_some_and(|o| has_ambiguous_bare_ref(o, columns))
5894 || conditions.iter().any(|(w, t)| {
5895 has_ambiguous_bare_ref(w, columns) || has_ambiguous_bare_ref(t, columns)
5896 })
5897 || else_result
5898 .as_ref()
5899 .is_some_and(|e| has_ambiguous_bare_ref(e, columns))
5900 }
5901 _ => false,
5902 }
5903}
5904
/// Column-pruning plan for a multi-table SELECT, built by
/// `compute_join_needed_columns`.
struct JoinColumnPlan {
    /// One entry per table in join order (FROM table first): the table-local
    /// column indices that must be decoded from storage. The row collectors treat
    /// an empty list as "decode every column".
    per_table: Vec<Vec<usize>>,
    /// Sorted, deduplicated indices into the combined joined row (all tables'
    /// columns concatenated in join order) that the post-join clauses reference.
    /// The last join step builds its output projection from these.
    output_combined: Vec<usize>,
}
5909
5910fn compute_join_needed_columns(
5911 stmt: &SelectStmt,
5912 tables: &[(String, &TableSchema)],
5913) -> Option<JoinColumnPlan> {
5914 for sel in &stmt.columns {
5915 if matches!(sel, SelectColumn::AllColumns) {
5916 return None;
5917 }
5918 }
5919
5920 let combined_cols = build_joined_columns(tables);
5921
5922 for sel in &stmt.columns {
5923 if let SelectColumn::Expr { expr, .. } = sel {
5924 if has_ambiguous_bare_ref(expr, &combined_cols) {
5925 return None;
5926 }
5927 }
5928 }
5929
5930 let mut output_combined: Vec<usize> = Vec::new();
5931 for sel in &stmt.columns {
5932 if let SelectColumn::Expr { expr, .. } = sel {
5933 output_combined.extend(referenced_columns(expr, &combined_cols));
5934 }
5935 }
5936 if let Some(w) = &stmt.where_clause {
5937 output_combined.extend(referenced_columns(w, &combined_cols));
5938 }
5939 for ob in &stmt.order_by {
5940 output_combined.extend(referenced_columns(&ob.expr, &combined_cols));
5941 }
5942 for gb in &stmt.group_by {
5943 output_combined.extend(referenced_columns(gb, &combined_cols));
5944 }
5945 if let Some(h) = &stmt.having {
5946 output_combined.extend(referenced_columns(h, &combined_cols));
5947 }
5948 output_combined.sort_unstable();
5949 output_combined.dedup();
5950
5951 let mut needed_combined = output_combined.clone();
5952 for join in &stmt.joins {
5953 if let Some(on_expr) = &join.on_clause {
5954 needed_combined.extend(referenced_columns(on_expr, &combined_cols));
5955 }
5956 }
5957 needed_combined.sort_unstable();
5958 needed_combined.dedup();
5959
5960 let mut offsets = Vec::with_capacity(tables.len() + 1);
5961 offsets.push(0usize);
5962 for (_, s) in tables {
5963 offsets.push(offsets.last().unwrap() + s.columns.len());
5964 }
5965
5966 let mut per_table: Vec<Vec<usize>> = tables.iter().map(|_| Vec::new()).collect();
5967 for &ci in &needed_combined {
5968 for (t, _) in tables.iter().enumerate() {
5969 let start = offsets[t];
5970 let end = offsets[t + 1];
5971 if ci >= start && ci < end {
5972 per_table[t].push(ci - start);
5973 break;
5974 }
5975 }
5976 }
5977
5978 Some(JoinColumnPlan {
5979 per_table,
5980 output_combined,
5981 })
5982}
5983
5984fn collect_rows_partial(
5985 rtx: &mut citadel_txn::read_txn::ReadTxn<'_>,
5986 table_schema: &TableSchema,
5987 needed: &[usize],
5988) -> Result<Vec<Vec<Value>>> {
5989 if needed.is_empty() || needed.len() == table_schema.columns.len() {
5990 return collect_all_rows_raw(rtx, table_schema);
5991 }
5992 let ctx = PartialDecodeCtx::new(table_schema, needed);
5993 let lower_name = &table_schema.name;
5994 let entry_count = rtx.table_entry_count(lower_name.as_bytes()).unwrap_or(0) as usize;
5995 let mut rows = Vec::with_capacity(entry_count);
5996 let mut scan_err: Option<SqlError> = None;
5997 rtx.table_scan_raw(lower_name.as_bytes(), |key, value| {
5998 match ctx.decode(key, value) {
5999 Ok(row) => rows.push(row),
6000 Err(e) => {
6001 scan_err = Some(e);
6002 return false;
6003 }
6004 }
6005 true
6006 })
6007 .map_err(SqlError::Storage)?;
6008 if let Some(e) = scan_err {
6009 return Err(e);
6010 }
6011 Ok(rows)
6012}
6013
6014fn collect_rows_partial_write(
6015 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
6016 table_schema: &TableSchema,
6017 needed: &[usize],
6018) -> Result<Vec<Vec<Value>>> {
6019 if needed.is_empty() || needed.len() == table_schema.columns.len() {
6020 return collect_all_rows_write(wtx, table_schema);
6021 }
6022 let ctx = PartialDecodeCtx::new(table_schema, needed);
6023 let lower_name = &table_schema.name;
6024 let entry_count = wtx.table_entry_count(lower_name.as_bytes()).unwrap_or(0) as usize;
6025 let mut rows = Vec::with_capacity(entry_count);
6026 let mut scan_err: Option<SqlError> = None;
6027 wtx.table_scan_from(lower_name.as_bytes(), b"", |key, value| {
6028 match ctx.decode(key, value) {
6029 Ok(row) => rows.push(row),
6030 Err(e) => {
6031 scan_err = Some(e);
6032 return Ok(false);
6033 }
6034 }
6035 Ok(true)
6036 })
6037 .map_err(SqlError::Storage)?;
6038 if let Some(e) = scan_err {
6039 return Err(e);
6040 }
6041 Ok(rows)
6042}
6043
6044fn exec_select_join(
6045 db: &Database,
6046 schema: &SchemaManager,
6047 stmt: &SelectStmt,
6048) -> Result<ExecutionResult> {
6049 let from_schema = resolve_table_name(schema, &stmt.from)?;
6050 let from_alias = table_alias_or_name(&stmt.from, &stmt.from_alias);
6051
6052 let mut all_tables: Vec<(String, &TableSchema)> = vec![(from_alias.clone(), from_schema)];
6053 for join in &stmt.joins {
6054 let inner_schema = resolve_table_name(schema, &join.table.name)?;
6055 let inner_alias = table_alias_or_name(&join.table.name, &join.table.alias);
6056 all_tables.push((inner_alias, inner_schema));
6057 }
6058 let (needed_per_table, output_combined) = match compute_join_needed_columns(stmt, &all_tables) {
6059 Some(plan) => (Some(plan.per_table), Some(plan.output_combined)),
6060 None => (None, None),
6061 };
6062
6063 let mut rtx = db.begin_read();
6064 let mut outer_rows = match &needed_per_table {
6065 Some(n) if !n.is_empty() => collect_rows_partial(&mut rtx, from_schema, &n[0])?,
6066 _ => collect_all_rows_raw(&mut rtx, from_schema)?,
6067 };
6068
6069 let mut tables: Vec<(String, &TableSchema)> = vec![(from_alias.clone(), from_schema)];
6070 let mut cur_outer_pk_col: Option<usize> = if from_schema.primary_key_columns.len() == 1 {
6071 Some(from_schema.primary_key_columns[0] as usize)
6072 } else {
6073 None
6074 };
6075
6076 let num_joins = stmt.joins.len();
6077 let mut last_combined_cols: Option<Vec<ColumnDef>> = None;
6078 for (ji, join) in stmt.joins.iter().enumerate() {
6079 let inner_schema = resolve_table_name(schema, &join.table.name)?;
6080 let inner_alias = table_alias_or_name(&join.table.name, &join.table.alias);
6081 let inner_rows = match &needed_per_table {
6082 Some(n) if ji + 1 < n.len() => {
6083 collect_rows_partial(&mut rtx, inner_schema, &n[ji + 1])?
6084 }
6085 _ => collect_all_rows_raw(&mut rtx, inner_schema)?,
6086 };
6087
6088 let mut preview_tables = tables.clone();
6089 preview_tables.push((inner_alias.clone(), inner_schema));
6090 let combined_cols = build_joined_columns(&preview_tables);
6091
6092 let outer_col_count = if outer_rows.is_empty() {
6093 tables.iter().map(|(_, s)| s.columns.len()).sum()
6094 } else {
6095 outer_rows[0].len()
6096 };
6097 let inner_col_count = inner_schema.columns.len();
6098
6099 let is_last = ji == num_joins - 1;
6100 let proj = if is_last {
6101 output_combined
6102 .as_ref()
6103 .map(|oc| build_combine_projection(oc, outer_col_count))
6104 } else {
6105 None
6106 };
6107
6108 outer_rows = exec_join_step(
6109 outer_rows,
6110 &inner_rows,
6111 join,
6112 &combined_cols,
6113 outer_col_count,
6114 inner_col_count,
6115 cur_outer_pk_col,
6116 proj.as_ref(),
6117 );
6118 last_combined_cols = Some(combined_cols);
6119 tables.push((inner_alias, inner_schema));
6120 cur_outer_pk_col = None;
6121 }
6122 drop(rtx);
6123
6124 let joined_cols = last_combined_cols.unwrap_or_else(|| build_joined_columns(&tables));
6125 if let Some(ref oc) = output_combined {
6126 let actual_width = outer_rows.first().map_or(0, |r| r.len());
6127 if actual_width == oc.len() {
6128 let projected_cols = build_projected_columns(&joined_cols, oc);
6129 return process_select(&projected_cols, outer_rows, stmt, false);
6130 }
6131 }
6132 process_select(&joined_cols, outer_rows, stmt, false)
6133}
6134
6135fn exec_select_join_in_txn(
6136 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
6137 schema: &SchemaManager,
6138 stmt: &SelectStmt,
6139) -> Result<ExecutionResult> {
6140 let from_schema = resolve_table_name(schema, &stmt.from)?;
6141 let from_alias = table_alias_or_name(&stmt.from, &stmt.from_alias);
6142
6143 let mut all_tables: Vec<(String, &TableSchema)> = vec![(from_alias.clone(), from_schema)];
6144 for join in &stmt.joins {
6145 let inner_schema = resolve_table_name(schema, &join.table.name)?;
6146 let inner_alias = table_alias_or_name(&join.table.name, &join.table.alias);
6147 all_tables.push((inner_alias, inner_schema));
6148 }
6149 let (needed_per_table, output_combined) = match compute_join_needed_columns(stmt, &all_tables) {
6150 Some(plan) => (Some(plan.per_table), Some(plan.output_combined)),
6151 None => (None, None),
6152 };
6153
6154 let mut outer_rows = match &needed_per_table {
6155 Some(n) if !n.is_empty() => collect_rows_partial_write(wtx, from_schema, &n[0])?,
6156 _ => collect_all_rows_write(wtx, from_schema)?,
6157 };
6158
6159 let mut tables: Vec<(String, &TableSchema)> = vec![(from_alias.clone(), from_schema)];
6160 let mut cur_outer_pk_col: Option<usize> = if from_schema.primary_key_columns.len() == 1 {
6161 Some(from_schema.primary_key_columns[0] as usize)
6162 } else {
6163 None
6164 };
6165
6166 let num_joins = stmt.joins.len();
6167 for (ji, join) in stmt.joins.iter().enumerate() {
6168 let inner_schema = resolve_table_name(schema, &join.table.name)?;
6169 let inner_alias = table_alias_or_name(&join.table.name, &join.table.alias);
6170 let inner_rows = match &needed_per_table {
6171 Some(n) if ji + 1 < n.len() => {
6172 collect_rows_partial_write(wtx, inner_schema, &n[ji + 1])?
6173 }
6174 _ => collect_all_rows_write(wtx, inner_schema)?,
6175 };
6176
6177 let mut preview_tables = tables.clone();
6178 preview_tables.push((inner_alias.clone(), inner_schema));
6179 let combined_cols = build_joined_columns(&preview_tables);
6180
6181 let outer_col_count = if outer_rows.is_empty() {
6182 tables.iter().map(|(_, s)| s.columns.len()).sum()
6183 } else {
6184 outer_rows[0].len()
6185 };
6186 let inner_col_count = inner_schema.columns.len();
6187
6188 let is_last = ji == num_joins - 1;
6189 let proj = if is_last {
6190 output_combined
6191 .as_ref()
6192 .map(|oc| build_combine_projection(oc, outer_col_count))
6193 } else {
6194 None
6195 };
6196
6197 outer_rows = exec_join_step(
6198 outer_rows,
6199 &inner_rows,
6200 join,
6201 &combined_cols,
6202 outer_col_count,
6203 inner_col_count,
6204 cur_outer_pk_col,
6205 proj.as_ref(),
6206 );
6207 tables.push((inner_alias, inner_schema));
6208 cur_outer_pk_col = None;
6209 }
6210
6211 let joined_cols = build_joined_columns(&tables);
6212 if let Some(ref oc) = output_combined {
6213 let actual_width = outer_rows.first().map_or(0, |r| r.len());
6214 if actual_width == oc.len() {
6215 let projected_cols = build_projected_columns(&joined_cols, oc);
6216 return process_select(&projected_cols, outer_rows, stmt, false);
6217 }
6218 }
6219 process_select(&joined_cols, outer_rows, stmt, false)
6220}
6221
6222fn exec_update(
6223 db: &Database,
6224 schema: &SchemaManager,
6225 stmt: &UpdateStmt,
6226) -> Result<ExecutionResult> {
6227 let materialized;
6228 let stmt = if update_has_subquery(stmt) {
6229 materialized = materialize_update(stmt, &mut |sub| {
6230 exec_subquery_read(db, schema, sub, &HashMap::new())
6231 })?;
6232 &materialized
6233 } else {
6234 stmt
6235 };
6236
6237 let lower_name = stmt.table.to_ascii_lowercase();
6238 let table_schema = schema
6239 .get(&lower_name)
6240 .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
6241
6242 let col_map = ColumnMap::new(&table_schema.columns);
6243 let all_candidates = collect_keyed_rows_read(db, table_schema, &stmt.where_clause)?;
6244 let matching_rows: Vec<(Vec<u8>, Vec<Value>)> = all_candidates
6245 .into_iter()
6246 .filter(|(_, row)| match &stmt.where_clause {
6247 Some(where_expr) => match eval_expr(where_expr, &col_map, row) {
6248 Ok(val) => is_truthy(&val),
6249 Err(_) => false,
6250 },
6251 None => true,
6252 })
6253 .collect();
6254
6255 if matching_rows.is_empty() {
6256 return Ok(ExecutionResult::RowsAffected(0));
6257 }
6258
6259 struct UpdateChange {
6260 old_key: Vec<u8>,
6261 new_key: Vec<u8>,
6262 new_value: Vec<u8>,
6263 pk_changed: bool,
6264 old_row: Vec<Value>,
6265 new_row: Vec<Value>,
6266 }
6267
6268 let pk_indices = table_schema.pk_indices();
6269 let mut changes: Vec<UpdateChange> = Vec::new();
6270
6271 for (old_key, row) in &matching_rows {
6272 let mut new_row = row.clone();
6273 let mut pk_changed = false;
6274
6275 let mut evaluated: Vec<(usize, Value)> = Vec::with_capacity(stmt.assignments.len());
6277 for (col_name, expr) in &stmt.assignments {
6278 let col_idx = table_schema
6279 .column_index(col_name)
6280 .ok_or_else(|| SqlError::ColumnNotFound(col_name.clone()))?;
6281 let new_val = eval_expr(expr, &col_map, row)?;
6282 let col = &table_schema.columns[col_idx];
6283
6284 let got_type = new_val.data_type();
6285 let coerced = if new_val.is_null() {
6286 if !col.nullable {
6287 return Err(SqlError::NotNullViolation(col.name.clone()));
6288 }
6289 Value::Null
6290 } else {
6291 new_val
6292 .coerce_into(col.data_type)
6293 .ok_or_else(|| SqlError::TypeMismatch {
6294 expected: col.data_type.to_string(),
6295 got: got_type.to_string(),
6296 })?
6297 };
6298
6299 evaluated.push((col_idx, coerced));
6300 }
6301
6302 for (col_idx, coerced) in evaluated {
6303 if table_schema.primary_key_columns.contains(&(col_idx as u16)) {
6304 pk_changed = true;
6305 }
6306 new_row[col_idx] = coerced;
6307 }
6308
6309 if table_schema.has_checks() {
6311 for col in &table_schema.columns {
6312 if let Some(ref check) = col.check_expr {
6313 let result = eval_expr(check, &col_map, &new_row)?;
6314 if !is_truthy(&result) && !result.is_null() {
6315 let name = col.check_name.as_deref().unwrap_or(&col.name);
6316 return Err(SqlError::CheckViolation(name.to_string()));
6317 }
6318 }
6319 }
6320 for tc in &table_schema.check_constraints {
6321 let result = eval_expr(&tc.expr, &col_map, &new_row)?;
6322 if !is_truthy(&result) && !result.is_null() {
6323 let name = tc.name.as_deref().unwrap_or(&tc.sql);
6324 return Err(SqlError::CheckViolation(name.to_string()));
6325 }
6326 }
6327 }
6328
6329 let pk_values: Vec<Value> = pk_indices.iter().map(|&i| new_row[i].clone()).collect();
6330 let new_key = encode_composite_key(&pk_values);
6331
6332 let non_pk = table_schema.non_pk_indices();
6333 let enc_pos = table_schema.encoding_positions();
6334 let phys_count = table_schema.physical_non_pk_count();
6335 let mut value_values = vec![Value::Null; phys_count];
6336 for (j, &i) in non_pk.iter().enumerate() {
6337 value_values[enc_pos[j] as usize] = new_row[i].clone();
6338 }
6339 let new_value = encode_row(&value_values);
6340
6341 changes.push(UpdateChange {
6342 old_key: old_key.clone(),
6343 new_key,
6344 new_value,
6345 pk_changed,
6346 old_row: row.clone(),
6347 new_row,
6348 });
6349 }
6350
6351 {
6352 use std::collections::HashSet;
6353 let mut new_keys: HashSet<Vec<u8>> = HashSet::new();
6354 for c in &changes {
6355 if c.pk_changed && c.new_key != c.old_key && !new_keys.insert(c.new_key.clone()) {
6356 return Err(SqlError::DuplicateKey);
6357 }
6358 }
6359 }
6360
6361 let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
6362
6363 if !table_schema.foreign_keys.is_empty() {
6365 for c in &changes {
6366 for fk in &table_schema.foreign_keys {
6367 let fk_changed = fk
6368 .columns
6369 .iter()
6370 .any(|&ci| c.old_row[ci as usize] != c.new_row[ci as usize]);
6371 if !fk_changed {
6372 continue;
6373 }
6374 let any_null = fk
6375 .columns
6376 .iter()
6377 .any(|&ci| c.new_row[ci as usize].is_null());
6378 if any_null {
6379 continue;
6380 }
6381 let fk_vals: Vec<Value> = fk
6382 .columns
6383 .iter()
6384 .map(|&ci| c.new_row[ci as usize].clone())
6385 .collect();
6386 let fk_key = encode_composite_key(&fk_vals);
6387 let found = wtx
6388 .table_get(fk.foreign_table.as_bytes(), &fk_key)
6389 .map_err(SqlError::Storage)?;
6390 if found.is_none() {
6391 let name = fk.name.as_deref().unwrap_or(&fk.foreign_table);
6392 return Err(SqlError::ForeignKeyViolation(name.to_string()));
6393 }
6394 }
6395 }
6396 }
6397
6398 let child_fks = schema.child_fks_for(&lower_name);
6400 if !child_fks.is_empty() {
6401 for c in &changes {
6402 if !c.pk_changed {
6403 continue;
6404 }
6405 let old_pk: Vec<Value> = pk_indices.iter().map(|&i| c.old_row[i].clone()).collect();
6406 let old_pk_key = encode_composite_key(&old_pk);
6407 for &(child_table, fk) in &child_fks {
6408 let child_schema = schema.get(child_table).unwrap();
6409 let fk_idx = child_schema
6410 .indices
6411 .iter()
6412 .find(|idx| idx.columns == fk.columns);
6413 if let Some(idx) = fk_idx {
6414 let idx_table = TableSchema::index_table_name(child_table, &idx.name);
6415 let mut has_child = false;
6416 wtx.table_scan_from(&idx_table, &old_pk_key, |key, _| {
6417 if key.starts_with(&old_pk_key) {
6418 has_child = true;
6419 Ok(false) } else {
6421 Ok(false) }
6423 })
6424 .map_err(SqlError::Storage)?;
6425 if has_child {
6426 return Err(SqlError::ForeignKeyViolation(format!(
6427 "cannot update PK in '{}': referenced by '{}'",
6428 lower_name, child_table
6429 )));
6430 }
6431 }
6432 }
6433 }
6434 }
6435
6436 for c in &changes {
6437 let old_pk: Vec<Value> = pk_indices.iter().map(|&i| c.old_row[i].clone()).collect();
6438
6439 for idx in &table_schema.indices {
6440 if index_columns_changed(idx, &c.old_row, &c.new_row) || c.pk_changed {
6441 let idx_table = TableSchema::index_table_name(&lower_name, &idx.name);
6442 let old_idx_key = encode_index_key(idx, &c.old_row, &old_pk);
6443 wtx.table_delete(&idx_table, &old_idx_key)
6444 .map_err(SqlError::Storage)?;
6445 }
6446 }
6447
6448 if c.pk_changed {
6449 wtx.table_delete(lower_name.as_bytes(), &c.old_key)
6450 .map_err(SqlError::Storage)?;
6451 }
6452 }
6453
6454 for c in &changes {
6455 let new_pk: Vec<Value> = pk_indices.iter().map(|&i| c.new_row[i].clone()).collect();
6456
6457 if c.pk_changed {
6458 let is_new = wtx
6459 .table_insert(lower_name.as_bytes(), &c.new_key, &c.new_value)
6460 .map_err(SqlError::Storage)?;
6461 if !is_new {
6462 return Err(SqlError::DuplicateKey);
6463 }
6464 } else {
6465 wtx.table_insert(lower_name.as_bytes(), &c.new_key, &c.new_value)
6466 .map_err(SqlError::Storage)?;
6467 }
6468
6469 for idx in &table_schema.indices {
6470 if index_columns_changed(idx, &c.old_row, &c.new_row) || c.pk_changed {
6471 let idx_table = TableSchema::index_table_name(&lower_name, &idx.name);
6472 let new_idx_key = encode_index_key(idx, &c.new_row, &new_pk);
6473 let new_idx_val = encode_index_value(idx, &c.new_row, &new_pk);
6474 let is_new = wtx
6475 .table_insert(&idx_table, &new_idx_key, &new_idx_val)
6476 .map_err(SqlError::Storage)?;
6477 if idx.unique && !is_new {
6478 let indexed_values: Vec<Value> = idx
6479 .columns
6480 .iter()
6481 .map(|&col_idx| c.new_row[col_idx as usize].clone())
6482 .collect();
6483 let any_null = indexed_values.iter().any(|v| v.is_null());
6484 if !any_null {
6485 return Err(SqlError::UniqueViolation(idx.name.clone()));
6486 }
6487 }
6488 }
6489 }
6490 }
6491
6492 let count = changes.len() as u64;
6493 wtx.commit().map_err(SqlError::Storage)?;
6494 Ok(ExecutionResult::RowsAffected(count))
6495}
6496
6497fn exec_delete(
6498 db: &Database,
6499 schema: &SchemaManager,
6500 stmt: &DeleteStmt,
6501) -> Result<ExecutionResult> {
6502 let materialized;
6503 let stmt = if delete_has_subquery(stmt) {
6504 materialized = materialize_delete(stmt, &mut |sub| {
6505 exec_subquery_read(db, schema, sub, &HashMap::new())
6506 })?;
6507 &materialized
6508 } else {
6509 stmt
6510 };
6511
6512 let lower_name = stmt.table.to_ascii_lowercase();
6513 let table_schema = schema
6514 .get(&lower_name)
6515 .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
6516
6517 let col_map = ColumnMap::new(&table_schema.columns);
6518 let all_candidates = collect_keyed_rows_read(db, table_schema, &stmt.where_clause)?;
6519 let rows_to_delete: Vec<(Vec<u8>, Vec<Value>)> = all_candidates
6520 .into_iter()
6521 .filter(|(_, row)| match &stmt.where_clause {
6522 Some(where_expr) => match eval_expr(where_expr, &col_map, row) {
6523 Ok(val) => is_truthy(&val),
6524 Err(_) => false,
6525 },
6526 None => true,
6527 })
6528 .collect();
6529
6530 if rows_to_delete.is_empty() {
6531 return Ok(ExecutionResult::RowsAffected(0));
6532 }
6533
6534 let pk_indices = table_schema.pk_indices();
6535 let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
6536
6537 let child_fks = schema.child_fks_for(&lower_name);
6539 if !child_fks.is_empty() {
6540 for (_key, row) in &rows_to_delete {
6541 let pk_values: Vec<Value> = pk_indices.iter().map(|&i| row[i].clone()).collect();
6542 let pk_key = encode_composite_key(&pk_values);
6543 for &(child_table, fk) in &child_fks {
6544 let child_schema = schema.get(child_table).unwrap();
6545 let fk_idx = child_schema
6546 .indices
6547 .iter()
6548 .find(|idx| idx.columns == fk.columns);
6549 if let Some(idx) = fk_idx {
6550 let idx_table = TableSchema::index_table_name(child_table, &idx.name);
6551 let mut has_child = false;
6552 wtx.table_scan_from(&idx_table, &pk_key, |key, _| {
6553 if key.starts_with(&pk_key) {
6554 has_child = true;
6555 Ok(false)
6556 } else {
6557 Ok(false)
6558 }
6559 })
6560 .map_err(SqlError::Storage)?;
6561 if has_child {
6562 return Err(SqlError::ForeignKeyViolation(format!(
6563 "cannot delete from '{}': referenced by '{}'",
6564 lower_name, child_table
6565 )));
6566 }
6567 }
6568 }
6569 }
6570 }
6571
6572 for (key, row) in &rows_to_delete {
6573 let pk_values: Vec<Value> = pk_indices.iter().map(|&i| row[i].clone()).collect();
6574 delete_index_entries(&mut wtx, table_schema, row, &pk_values)?;
6575 wtx.table_delete(lower_name.as_bytes(), key)
6576 .map_err(SqlError::Storage)?;
6577 }
6578 let count = rows_to_delete.len() as u64;
6579 wtx.commit().map_err(SqlError::Storage)?;
6580 Ok(ExecutionResult::RowsAffected(count))
6581}
6582
6583#[derive(Default)]
6584pub struct InsertBufs {
6585 row: Vec<Value>,
6586 pk_values: Vec<Value>,
6587 value_values: Vec<Value>,
6588 key_buf: Vec<u8>,
6589 value_buf: Vec<u8>,
6590 col_indices: Vec<usize>,
6591 fk_key_buf: Vec<u8>,
6592}
6593
6594impl InsertBufs {
6595 pub fn new() -> Self {
6596 Self {
6597 row: Vec::new(),
6598 pk_values: Vec::new(),
6599 value_values: Vec::new(),
6600 key_buf: Vec::with_capacity(64),
6601 value_buf: Vec::with_capacity(256),
6602 col_indices: Vec::new(),
6603 fk_key_buf: Vec::with_capacity(64),
6604 }
6605 }
6606}
6607
6608pub fn exec_insert_in_txn(
6609 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
6610 schema: &SchemaManager,
6611 stmt: &InsertStmt,
6612 params: &[Value],
6613 bufs: &mut InsertBufs,
6614) -> Result<ExecutionResult> {
6615 let empty_ctes = CteContext::new();
6616 let materialized;
6617 let stmt = if insert_has_subquery(stmt) {
6618 materialized = materialize_insert(stmt, &mut |sub| {
6619 exec_subquery_write(wtx, schema, sub, &empty_ctes)
6620 })?;
6621 &materialized
6622 } else {
6623 stmt
6624 };
6625
6626 let table_schema = schema
6627 .get(&stmt.table)
6628 .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
6629
6630 let default_columns;
6631 let insert_columns: &[String] = if stmt.columns.is_empty() {
6632 default_columns = table_schema
6633 .columns
6634 .iter()
6635 .map(|c| c.name.clone())
6636 .collect::<Vec<_>>();
6637 &default_columns
6638 } else {
6639 &stmt.columns
6640 };
6641
6642 bufs.col_indices.clear();
6643 for name in insert_columns {
6644 bufs.col_indices.push(
6645 table_schema
6646 .column_index(name)
6647 .ok_or_else(|| SqlError::ColumnNotFound(name.clone()))?,
6648 );
6649 }
6650
6651 let defaults: Vec<(usize, &Expr)> = table_schema
6653 .columns
6654 .iter()
6655 .filter(|c| c.default_expr.is_some() && !bufs.col_indices.contains(&(c.position as usize)))
6656 .map(|c| (c.position as usize, c.default_expr.as_ref().unwrap()))
6657 .collect();
6658
6659 let has_checks = table_schema.has_checks();
6660 let check_col_map = if has_checks {
6661 Some(ColumnMap::new(&table_schema.columns))
6662 } else {
6663 None
6664 };
6665
6666 let pk_indices = table_schema.pk_indices();
6667 let non_pk = table_schema.non_pk_indices();
6668 let enc_pos = table_schema.encoding_positions();
6669 let phys_count = table_schema.physical_non_pk_count();
6670 let dropped = table_schema.dropped_non_pk_slots();
6671
6672 bufs.row.resize(table_schema.columns.len(), Value::Null);
6673 bufs.pk_values.resize(pk_indices.len(), Value::Null);
6674 bufs.value_values.resize(phys_count, Value::Null);
6675
6676 let select_rows = match &stmt.source {
6677 InsertSource::Select(sq) => {
6678 let insert_ctes = materialize_all_ctes(&sq.ctes, sq.recursive, &mut |body, ctx| {
6679 exec_query_body_write(wtx, schema, body, ctx)
6680 })?;
6681 let qr = exec_query_body_write(wtx, schema, &sq.body, &insert_ctes)?;
6682 Some(qr.rows)
6683 }
6684 InsertSource::Values(_) => None,
6685 };
6686
6687 let mut count: u64 = 0;
6688
6689 let values = match &stmt.source {
6690 InsertSource::Values(rows) => Some(rows.as_slice()),
6691 InsertSource::Select(_) => None,
6692 };
6693 let sel_rows = select_rows.as_deref();
6694
6695 let total = match (values, sel_rows) {
6696 (Some(rows), _) => rows.len(),
6697 (_, Some(rows)) => rows.len(),
6698 _ => 0,
6699 };
6700
6701 if let Some(sel) = sel_rows {
6702 if !sel.is_empty() && sel[0].len() != insert_columns.len() {
6703 return Err(SqlError::InvalidValue(format!(
6704 "INSERT ... SELECT column count mismatch: expected {}, got {}",
6705 insert_columns.len(),
6706 sel[0].len()
6707 )));
6708 }
6709 }
6710
6711 for idx in 0..total {
6712 for v in bufs.row.iter_mut() {
6713 *v = Value::Null;
6714 }
6715
6716 if let Some(value_rows) = values {
6717 let value_row = &value_rows[idx];
6718 if value_row.len() != insert_columns.len() {
6719 return Err(SqlError::InvalidValue(format!(
6720 "expected {} values, got {}",
6721 insert_columns.len(),
6722 value_row.len()
6723 )));
6724 }
6725 for (i, expr) in value_row.iter().enumerate() {
6726 let val = if let Expr::Parameter(n) = expr {
6727 params
6728 .get(n - 1)
6729 .cloned()
6730 .ok_or_else(|| SqlError::Parse(format!("unbound parameter ${n}")))?
6731 } else {
6732 eval_const_expr(expr)?
6733 };
6734 let col_idx = bufs.col_indices[i];
6735 let col = &table_schema.columns[col_idx];
6736 let got_type = val.data_type();
6737 bufs.row[col_idx] = if val.is_null() {
6738 Value::Null
6739 } else {
6740 val.coerce_into(col.data_type)
6741 .ok_or_else(|| SqlError::TypeMismatch {
6742 expected: col.data_type.to_string(),
6743 got: got_type.to_string(),
6744 })?
6745 };
6746 }
6747 } else if let Some(sel) = sel_rows {
6748 let sel_row = &sel[idx];
6749 for (i, val) in sel_row.iter().enumerate() {
6750 let col_idx = bufs.col_indices[i];
6751 let col = &table_schema.columns[col_idx];
6752 let got_type = val.data_type();
6753 bufs.row[col_idx] = if val.is_null() {
6754 Value::Null
6755 } else {
6756 val.clone().coerce_into(col.data_type).ok_or_else(|| {
6757 SqlError::TypeMismatch {
6758 expected: col.data_type.to_string(),
6759 got: got_type.to_string(),
6760 }
6761 })?
6762 };
6763 }
6764 }
6765
6766 for &(pos, def_expr) in &defaults {
6768 let val = eval_const_expr(def_expr)?;
6769 let col = &table_schema.columns[pos];
6770 if val.is_null() {
6771 } else {
6773 let got_type = val.data_type();
6774 bufs.row[pos] =
6775 val.coerce_into(col.data_type)
6776 .ok_or_else(|| SqlError::TypeMismatch {
6777 expected: col.data_type.to_string(),
6778 got: got_type.to_string(),
6779 })?;
6780 }
6781 }
6782
6783 for col in &table_schema.columns {
6784 if !col.nullable && bufs.row[col.position as usize].is_null() {
6785 return Err(SqlError::NotNullViolation(col.name.clone()));
6786 }
6787 }
6788
6789 if let Some(ref col_map) = check_col_map {
6791 for col in &table_schema.columns {
6792 if let Some(ref check) = col.check_expr {
6793 let result = eval_expr(check, col_map, &bufs.row)?;
6794 if !is_truthy(&result) && !result.is_null() {
6795 let name = col.check_name.as_deref().unwrap_or(&col.name);
6796 return Err(SqlError::CheckViolation(name.to_string()));
6797 }
6798 }
6799 }
6800 for tc in &table_schema.check_constraints {
6801 let result = eval_expr(&tc.expr, col_map, &bufs.row)?;
6802 if !is_truthy(&result) && !result.is_null() {
6803 let name = tc.name.as_deref().unwrap_or(&tc.sql);
6804 return Err(SqlError::CheckViolation(name.to_string()));
6805 }
6806 }
6807 }
6808
6809 for fk in &table_schema.foreign_keys {
6811 let any_null = fk.columns.iter().any(|&ci| bufs.row[ci as usize].is_null());
6812 if any_null {
6813 continue;
6814 }
6815 let fk_vals: Vec<Value> = fk
6816 .columns
6817 .iter()
6818 .map(|&ci| bufs.row[ci as usize].clone())
6819 .collect();
6820 bufs.fk_key_buf.clear();
6821 encode_composite_key_into(&fk_vals, &mut bufs.fk_key_buf);
6822 let found = wtx
6823 .table_get(fk.foreign_table.as_bytes(), &bufs.fk_key_buf)
6824 .map_err(SqlError::Storage)?;
6825 if found.is_none() {
6826 let name = fk.name.as_deref().unwrap_or(&fk.foreign_table);
6827 return Err(SqlError::ForeignKeyViolation(name.to_string()));
6828 }
6829 }
6830
6831 for (j, &i) in pk_indices.iter().enumerate() {
6832 bufs.pk_values[j] = std::mem::replace(&mut bufs.row[i], Value::Null);
6833 }
6834 encode_composite_key_into(&bufs.pk_values, &mut bufs.key_buf);
6835
6836 for &slot in dropped {
6837 bufs.value_values[slot as usize] = Value::Null;
6838 }
6839 for (j, &i) in non_pk.iter().enumerate() {
6840 bufs.value_values[enc_pos[j] as usize] =
6841 std::mem::replace(&mut bufs.row[i], Value::Null);
6842 }
6843 encode_row_into(&bufs.value_values, &mut bufs.value_buf);
6844
6845 if bufs.key_buf.len() > citadel_core::MAX_KEY_SIZE {
6846 return Err(SqlError::KeyTooLarge {
6847 size: bufs.key_buf.len(),
6848 max: citadel_core::MAX_KEY_SIZE,
6849 });
6850 }
6851 if bufs.value_buf.len() > citadel_core::MAX_INLINE_VALUE_SIZE {
6852 return Err(SqlError::RowTooLarge {
6853 size: bufs.value_buf.len(),
6854 max: citadel_core::MAX_INLINE_VALUE_SIZE,
6855 });
6856 }
6857
6858 let is_new = wtx
6859 .table_insert(stmt.table.as_bytes(), &bufs.key_buf, &bufs.value_buf)
6860 .map_err(SqlError::Storage)?;
6861 if !is_new {
6862 return Err(SqlError::DuplicateKey);
6863 }
6864
6865 if !table_schema.indices.is_empty() {
6866 for (j, &i) in pk_indices.iter().enumerate() {
6867 bufs.row[i] = bufs.pk_values[j].clone();
6868 }
6869 for (j, &i) in non_pk.iter().enumerate() {
6870 bufs.row[i] =
6871 std::mem::replace(&mut bufs.value_values[enc_pos[j] as usize], Value::Null);
6872 }
6873 insert_index_entries(wtx, table_schema, &bufs.row, &bufs.pk_values)?;
6874 }
6875 count += 1;
6876 }
6877
6878 Ok(ExecutionResult::RowsAffected(count))
6879}
6880
6881fn exec_select_in_txn(
6882 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
6883 schema: &SchemaManager,
6884 stmt: &SelectStmt,
6885 ctes: &CteContext,
6886) -> Result<ExecutionResult> {
6887 let materialized;
6888 let stmt = if stmt_has_subquery(stmt) {
6889 materialized =
6890 materialize_stmt(stmt, &mut |sub| exec_subquery_write(wtx, schema, sub, ctes))?;
6891 &materialized
6892 } else {
6893 stmt
6894 };
6895
6896 if stmt.from.is_empty() {
6897 return exec_select_no_from(stmt);
6898 }
6899
6900 let lower_name = stmt.from.to_ascii_lowercase();
6901
6902 if let Some(cte_result) = ctes.get(&lower_name) {
6903 if stmt.joins.is_empty() {
6904 return exec_select_from_cte(cte_result, stmt, &mut |sub| {
6905 exec_subquery_write(wtx, schema, sub, ctes)
6906 });
6907 } else {
6908 return exec_select_join_with_ctes(stmt, ctes, &mut |name| {
6909 scan_table_write(wtx, schema, name)
6910 });
6911 }
6912 }
6913
6914 if !ctes.is_empty()
6915 && stmt
6916 .joins
6917 .iter()
6918 .any(|j| ctes.contains_key(&j.table.name.to_ascii_lowercase()))
6919 {
6920 return exec_select_join_with_ctes(stmt, ctes, &mut |name| {
6921 scan_table_write(wtx, schema, name)
6922 });
6923 }
6924
6925 if !stmt.joins.is_empty() {
6926 return exec_select_join_in_txn(wtx, schema, stmt);
6927 }
6928
6929 let lower_name = stmt.from.to_ascii_lowercase();
6930 let table_schema = schema
6931 .get(&lower_name)
6932 .ok_or_else(|| SqlError::TableNotFound(stmt.from.clone()))?;
6933
6934 if let Some(result) = try_count_star_shortcut(stmt, || {
6935 wtx.table_entry_count(lower_name.as_bytes())
6936 .map_err(SqlError::Storage)
6937 })? {
6938 return Ok(result);
6939 }
6940
6941 if let Some(plan) = StreamAggPlan::try_new(stmt, table_schema)? {
6942 let mut states: Vec<AggState> = plan.ops.iter().map(|(op, _)| AggState::new(op)).collect();
6943 let mut scan_err: Option<SqlError> = None;
6944 if stmt.where_clause.is_none() {
6945 wtx.table_scan_from(lower_name.as_bytes(), b"", |key, value| {
6946 Ok(plan.feed_row_raw(key, value, &mut states, &mut scan_err))
6947 })
6948 .map_err(SqlError::Storage)?;
6949 } else {
6950 let col_map = ColumnMap::new(&table_schema.columns);
6951 wtx.table_scan_from(lower_name.as_bytes(), b"", |key, value| {
6952 Ok(plan.feed_row(
6953 key,
6954 value,
6955 table_schema,
6956 &col_map,
6957 &stmt.where_clause,
6958 &mut states,
6959 &mut scan_err,
6960 ))
6961 })
6962 .map_err(SqlError::Storage)?;
6963 }
6964 if let Some(e) = scan_err {
6965 return Err(e);
6966 }
6967 return Ok(plan.finish(states));
6968 }
6969
6970 if let Some(plan) = StreamGroupByPlan::try_new(stmt, table_schema)? {
6971 let lower = lower_name.clone();
6972 return plan.execute_scan(|cb| {
6973 wtx.table_scan_from(lower.as_bytes(), b"", |key, value| Ok(cb(key, value)))
6974 });
6975 }
6976
6977 if let Some(plan) = TopKScanPlan::try_new(stmt, table_schema)? {
6978 let lower = lower_name.clone();
6979 return plan.execute_scan(table_schema, stmt, |cb| {
6980 wtx.table_scan_from(lower.as_bytes(), b"", |key, value| Ok(cb(key, value)))
6981 });
6982 }
6983
6984 let scan_limit = compute_scan_limit(stmt);
6985 let (rows, predicate_applied) =
6986 collect_rows_write(wtx, table_schema, &stmt.where_clause, scan_limit)?;
6987 process_select(&table_schema.columns, rows, stmt, predicate_applied)
6988}
6989
6990fn exec_update_in_txn(
6991 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
6992 schema: &SchemaManager,
6993 stmt: &UpdateStmt,
6994) -> Result<ExecutionResult> {
6995 let materialized;
6996 let stmt = if update_has_subquery(stmt) {
6997 materialized = materialize_update(stmt, &mut |sub| {
6998 exec_subquery_write(wtx, schema, sub, &HashMap::new())
6999 })?;
7000 &materialized
7001 } else {
7002 stmt
7003 };
7004
7005 let lower_name = stmt.table.to_ascii_lowercase();
7006 let table_schema = schema
7007 .get(&lower_name)
7008 .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
7009
7010 let col_map = ColumnMap::new(&table_schema.columns);
7011 let all_candidates = collect_keyed_rows_write(wtx, table_schema, &stmt.where_clause)?;
7012 let matching_rows: Vec<(Vec<u8>, Vec<Value>)> = all_candidates
7013 .into_iter()
7014 .filter(|(_, row)| match &stmt.where_clause {
7015 Some(where_expr) => match eval_expr(where_expr, &col_map, row) {
7016 Ok(val) => is_truthy(&val),
7017 Err(_) => false,
7018 },
7019 None => true,
7020 })
7021 .collect();
7022
7023 if matching_rows.is_empty() {
7024 return Ok(ExecutionResult::RowsAffected(0));
7025 }
7026
7027 struct UpdateChange {
7028 old_key: Vec<u8>,
7029 new_key: Vec<u8>,
7030 new_value: Vec<u8>,
7031 pk_changed: bool,
7032 old_row: Vec<Value>,
7033 new_row: Vec<Value>,
7034 }
7035
7036 let pk_indices = table_schema.pk_indices();
7037 let mut changes: Vec<UpdateChange> = Vec::new();
7038
7039 for (old_key, row) in &matching_rows {
7040 let mut new_row = row.clone();
7041 let mut pk_changed = false;
7042
7043 let mut evaluated: Vec<(usize, Value)> = Vec::with_capacity(stmt.assignments.len());
7045 for (col_name, expr) in &stmt.assignments {
7046 let col_idx = table_schema
7047 .column_index(col_name)
7048 .ok_or_else(|| SqlError::ColumnNotFound(col_name.clone()))?;
7049 let new_val = eval_expr(expr, &col_map, row)?;
7050 let col = &table_schema.columns[col_idx];
7051
7052 let got_type = new_val.data_type();
7053 let coerced = if new_val.is_null() {
7054 if !col.nullable {
7055 return Err(SqlError::NotNullViolation(col.name.clone()));
7056 }
7057 Value::Null
7058 } else {
7059 new_val
7060 .coerce_into(col.data_type)
7061 .ok_or_else(|| SqlError::TypeMismatch {
7062 expected: col.data_type.to_string(),
7063 got: got_type.to_string(),
7064 })?
7065 };
7066
7067 evaluated.push((col_idx, coerced));
7068 }
7069
7070 for (col_idx, coerced) in evaluated {
7071 if table_schema.primary_key_columns.contains(&(col_idx as u16)) {
7072 pk_changed = true;
7073 }
7074 new_row[col_idx] = coerced;
7075 }
7076
7077 if table_schema.has_checks() {
7079 for col in &table_schema.columns {
7080 if let Some(ref check) = col.check_expr {
7081 let result = eval_expr(check, &col_map, &new_row)?;
7082 if !is_truthy(&result) && !result.is_null() {
7083 let name = col.check_name.as_deref().unwrap_or(&col.name);
7084 return Err(SqlError::CheckViolation(name.to_string()));
7085 }
7086 }
7087 }
7088 for tc in &table_schema.check_constraints {
7089 let result = eval_expr(&tc.expr, &col_map, &new_row)?;
7090 if !is_truthy(&result) && !result.is_null() {
7091 let name = tc.name.as_deref().unwrap_or(&tc.sql);
7092 return Err(SqlError::CheckViolation(name.to_string()));
7093 }
7094 }
7095 }
7096
7097 let pk_values: Vec<Value> = pk_indices.iter().map(|&i| new_row[i].clone()).collect();
7098 let new_key = encode_composite_key(&pk_values);
7099
7100 let non_pk = table_schema.non_pk_indices();
7101 let enc_pos = table_schema.encoding_positions();
7102 let phys_count = table_schema.physical_non_pk_count();
7103 let mut value_values = vec![Value::Null; phys_count];
7104 for (j, &i) in non_pk.iter().enumerate() {
7105 value_values[enc_pos[j] as usize] = new_row[i].clone();
7106 }
7107 let new_value = encode_row(&value_values);
7108
7109 changes.push(UpdateChange {
7110 old_key: old_key.clone(),
7111 new_key,
7112 new_value,
7113 pk_changed,
7114 old_row: row.clone(),
7115 new_row,
7116 });
7117 }
7118
7119 {
7120 use std::collections::HashSet;
7121 let mut new_keys: HashSet<Vec<u8>> = HashSet::new();
7122 for c in &changes {
7123 if c.pk_changed && c.new_key != c.old_key && !new_keys.insert(c.new_key.clone()) {
7124 return Err(SqlError::DuplicateKey);
7125 }
7126 }
7127 }
7128
7129 if !table_schema.foreign_keys.is_empty() {
7131 for c in &changes {
7132 for fk in &table_schema.foreign_keys {
7133 let fk_changed = fk
7134 .columns
7135 .iter()
7136 .any(|&ci| c.old_row[ci as usize] != c.new_row[ci as usize]);
7137 if !fk_changed {
7138 continue;
7139 }
7140 let any_null = fk
7141 .columns
7142 .iter()
7143 .any(|&ci| c.new_row[ci as usize].is_null());
7144 if any_null {
7145 continue;
7146 }
7147 let fk_vals: Vec<Value> = fk
7148 .columns
7149 .iter()
7150 .map(|&ci| c.new_row[ci as usize].clone())
7151 .collect();
7152 let fk_key = encode_composite_key(&fk_vals);
7153 let found = wtx
7154 .table_get(fk.foreign_table.as_bytes(), &fk_key)
7155 .map_err(SqlError::Storage)?;
7156 if found.is_none() {
7157 let name = fk.name.as_deref().unwrap_or(&fk.foreign_table);
7158 return Err(SqlError::ForeignKeyViolation(name.to_string()));
7159 }
7160 }
7161 }
7162 }
7163
7164 let child_fks = schema.child_fks_for(&lower_name);
7166 if !child_fks.is_empty() {
7167 for c in &changes {
7168 if !c.pk_changed {
7169 continue;
7170 }
7171 let old_pk: Vec<Value> = pk_indices.iter().map(|&i| c.old_row[i].clone()).collect();
7172 let old_pk_key = encode_composite_key(&old_pk);
7173 for &(child_table, fk) in &child_fks {
7174 let child_schema = schema.get(child_table).unwrap();
7175 let fk_idx = child_schema
7176 .indices
7177 .iter()
7178 .find(|idx| idx.columns == fk.columns);
7179 if let Some(idx) = fk_idx {
7180 let idx_table = TableSchema::index_table_name(child_table, &idx.name);
7181 let mut has_child = false;
7182 wtx.table_scan_from(&idx_table, &old_pk_key, |key, _| {
7183 if key.starts_with(&old_pk_key) {
7184 has_child = true;
7185 Ok(false)
7186 } else {
7187 Ok(false)
7188 }
7189 })
7190 .map_err(SqlError::Storage)?;
7191 if has_child {
7192 return Err(SqlError::ForeignKeyViolation(format!(
7193 "cannot update PK in '{}': referenced by '{}'",
7194 lower_name, child_table
7195 )));
7196 }
7197 }
7198 }
7199 }
7200 }
7201
7202 for c in &changes {
7203 let old_pk: Vec<Value> = pk_indices.iter().map(|&i| c.old_row[i].clone()).collect();
7204
7205 for idx in &table_schema.indices {
7206 if index_columns_changed(idx, &c.old_row, &c.new_row) || c.pk_changed {
7207 let idx_table = TableSchema::index_table_name(&lower_name, &idx.name);
7208 let old_idx_key = encode_index_key(idx, &c.old_row, &old_pk);
7209 wtx.table_delete(&idx_table, &old_idx_key)
7210 .map_err(SqlError::Storage)?;
7211 }
7212 }
7213
7214 if c.pk_changed {
7215 wtx.table_delete(lower_name.as_bytes(), &c.old_key)
7216 .map_err(SqlError::Storage)?;
7217 }
7218 }
7219
7220 for c in &changes {
7221 let new_pk: Vec<Value> = pk_indices.iter().map(|&i| c.new_row[i].clone()).collect();
7222
7223 if c.pk_changed {
7224 let is_new = wtx
7225 .table_insert(lower_name.as_bytes(), &c.new_key, &c.new_value)
7226 .map_err(SqlError::Storage)?;
7227 if !is_new {
7228 return Err(SqlError::DuplicateKey);
7229 }
7230 } else {
7231 wtx.table_insert(lower_name.as_bytes(), &c.new_key, &c.new_value)
7232 .map_err(SqlError::Storage)?;
7233 }
7234
7235 for idx in &table_schema.indices {
7236 if index_columns_changed(idx, &c.old_row, &c.new_row) || c.pk_changed {
7237 let idx_table = TableSchema::index_table_name(&lower_name, &idx.name);
7238 let new_idx_key = encode_index_key(idx, &c.new_row, &new_pk);
7239 let new_idx_val = encode_index_value(idx, &c.new_row, &new_pk);
7240 let is_new = wtx
7241 .table_insert(&idx_table, &new_idx_key, &new_idx_val)
7242 .map_err(SqlError::Storage)?;
7243 if idx.unique && !is_new {
7244 let indexed_values: Vec<Value> = idx
7245 .columns
7246 .iter()
7247 .map(|&col_idx| c.new_row[col_idx as usize].clone())
7248 .collect();
7249 let any_null = indexed_values.iter().any(|v| v.is_null());
7250 if !any_null {
7251 return Err(SqlError::UniqueViolation(idx.name.clone()));
7252 }
7253 }
7254 }
7255 }
7256 }
7257
7258 let count = changes.len() as u64;
7259 Ok(ExecutionResult::RowsAffected(count))
7260}
7261
7262fn exec_delete_in_txn(
7263 wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
7264 schema: &SchemaManager,
7265 stmt: &DeleteStmt,
7266) -> Result<ExecutionResult> {
7267 let materialized;
7268 let stmt = if delete_has_subquery(stmt) {
7269 materialized = materialize_delete(stmt, &mut |sub| {
7270 exec_subquery_write(wtx, schema, sub, &HashMap::new())
7271 })?;
7272 &materialized
7273 } else {
7274 stmt
7275 };
7276
7277 let lower_name = stmt.table.to_ascii_lowercase();
7278 let table_schema = schema
7279 .get(&lower_name)
7280 .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
7281
7282 let col_map = ColumnMap::new(&table_schema.columns);
7283 let all_candidates = collect_keyed_rows_write(wtx, table_schema, &stmt.where_clause)?;
7284 let rows_to_delete: Vec<(Vec<u8>, Vec<Value>)> = all_candidates
7285 .into_iter()
7286 .filter(|(_, row)| match &stmt.where_clause {
7287 Some(where_expr) => match eval_expr(where_expr, &col_map, row) {
7288 Ok(val) => is_truthy(&val),
7289 Err(_) => false,
7290 },
7291 None => true,
7292 })
7293 .collect();
7294
7295 if rows_to_delete.is_empty() {
7296 return Ok(ExecutionResult::RowsAffected(0));
7297 }
7298
7299 let pk_indices = table_schema.pk_indices();
7300
7301 let child_fks = schema.child_fks_for(&lower_name);
7303 if !child_fks.is_empty() {
7304 for (_key, row) in &rows_to_delete {
7305 let pk_values: Vec<Value> = pk_indices.iter().map(|&i| row[i].clone()).collect();
7306 let pk_key = encode_composite_key(&pk_values);
7307 for &(child_table, fk) in &child_fks {
7308 let child_schema = schema.get(child_table).unwrap();
7309 let fk_idx = child_schema
7310 .indices
7311 .iter()
7312 .find(|idx| idx.columns == fk.columns);
7313 if let Some(idx) = fk_idx {
7314 let idx_table = TableSchema::index_table_name(child_table, &idx.name);
7315 let mut has_child = false;
7316 wtx.table_scan_from(&idx_table, &pk_key, |key, _| {
7317 if key.starts_with(&pk_key) {
7318 has_child = true;
7319 Ok(false)
7320 } else {
7321 Ok(false)
7322 }
7323 })
7324 .map_err(SqlError::Storage)?;
7325 if has_child {
7326 return Err(SqlError::ForeignKeyViolation(format!(
7327 "cannot delete from '{}': referenced by '{}'",
7328 lower_name, child_table
7329 )));
7330 }
7331 }
7332 }
7333 }
7334 }
7335
7336 for (key, row) in &rows_to_delete {
7337 let pk_values: Vec<Value> = pk_indices.iter().map(|&i| row[i].clone()).collect();
7338 delete_index_entries(wtx, table_schema, row, &pk_values)?;
7339 wtx.table_delete(lower_name.as_bytes(), key)
7340 .map_err(SqlError::Storage)?;
7341 }
7342 let count = rows_to_delete.len() as u64;
7343 Ok(ExecutionResult::RowsAffected(count))
7344}
7345
7346fn exec_aggregate(
7349 columns: &[ColumnDef],
7350 rows: &[Vec<Value>],
7351 stmt: &SelectStmt,
7352) -> Result<ExecutionResult> {
7353 let col_map = ColumnMap::new(columns);
7354 let groups: BTreeMap<Vec<Value>, Vec<&Vec<Value>>> = if stmt.group_by.is_empty() {
7355 let mut m = BTreeMap::new();
7356 m.insert(vec![], rows.iter().collect());
7357 m
7358 } else {
7359 let mut m: BTreeMap<Vec<Value>, Vec<&Vec<Value>>> = BTreeMap::new();
7360 for row in rows {
7361 let group_key: Vec<Value> = stmt
7362 .group_by
7363 .iter()
7364 .map(|expr| eval_expr(expr, &col_map, row))
7365 .collect::<Result<_>>()?;
7366 m.entry(group_key).or_default().push(row);
7367 }
7368 m
7369 };
7370
7371 let mut result_rows = Vec::new();
7372 let output_cols = build_output_columns(&stmt.columns, columns);
7373
7374 for group_rows in groups.values() {
7375 let mut result_row = Vec::new();
7376
7377 for sel_col in &stmt.columns {
7378 match sel_col {
7379 SelectColumn::AllColumns => {
7380 return Err(SqlError::Unsupported("SELECT * with GROUP BY".into()));
7381 }
7382 SelectColumn::Expr { expr, .. } => {
7383 let val = eval_aggregate_expr(expr, &col_map, group_rows)?;
7384 result_row.push(val);
7385 }
7386 }
7387 }
7388
7389 if let Some(ref having) = stmt.having {
7390 let passes = match eval_aggregate_expr(having, &col_map, group_rows) {
7391 Ok(val) => is_truthy(&val),
7392 Err(SqlError::ColumnNotFound(_)) => {
7393 let output_map = ColumnMap::new(&output_cols);
7394 match eval_expr(having, &output_map, &result_row) {
7395 Ok(val) => is_truthy(&val),
7396 Err(_) => false,
7397 }
7398 }
7399 Err(e) => return Err(e),
7400 };
7401 if !passes {
7402 continue;
7403 }
7404 }
7405
7406 result_rows.push(result_row);
7407 }
7408
7409 if stmt.distinct {
7410 let mut seen = std::collections::HashSet::new();
7411 result_rows.retain(|row| seen.insert(row.clone()));
7412 }
7413
7414 if !stmt.order_by.is_empty() {
7415 let output_cols = build_output_columns(&stmt.columns, columns);
7416 sort_rows(&mut result_rows, &stmt.order_by, &output_cols)?;
7417 }
7418
7419 if let Some(ref offset_expr) = stmt.offset {
7420 let offset = eval_const_int(offset_expr)?.max(0) as usize;
7421 if offset < result_rows.len() {
7422 result_rows = result_rows.split_off(offset);
7423 } else {
7424 result_rows.clear();
7425 }
7426 }
7427 if let Some(ref limit_expr) = stmt.limit {
7428 let limit = eval_const_int(limit_expr)?.max(0) as usize;
7429 result_rows.truncate(limit);
7430 }
7431
7432 let col_names = stmt
7433 .columns
7434 .iter()
7435 .map(|c| match c {
7436 SelectColumn::AllColumns => "*".into(),
7437 SelectColumn::Expr { alias: Some(a), .. } => a.clone(),
7438 SelectColumn::Expr { expr, .. } => expr_display_name(expr),
7439 })
7440 .collect();
7441
7442 Ok(ExecutionResult::Query(QueryResult {
7443 columns: col_names,
7444 rows: result_rows,
7445 }))
7446}
7447
7448fn eval_aggregate_expr(
7449 expr: &Expr,
7450 col_map: &ColumnMap,
7451 group_rows: &[&Vec<Value>],
7452) -> Result<Value> {
7453 match expr {
7454 Expr::CountStar => Ok(Value::Integer(group_rows.len() as i64)),
7455
7456 Expr::Function { name, args } if is_aggregate_function(name, args.len()) => {
7457 let func = name.to_ascii_uppercase();
7458 if args.len() != 1 {
7459 return Err(SqlError::Unsupported(format!(
7460 "{func} with {} args",
7461 args.len()
7462 )));
7463 }
7464 let arg = &args[0];
7465 let values: Vec<Value> = group_rows
7466 .iter()
7467 .map(|row| eval_expr(arg, col_map, row))
7468 .collect::<Result<_>>()?;
7469
7470 match func.as_str() {
7471 "COUNT" => {
7472 let count = values.iter().filter(|v| !v.is_null()).count();
7473 Ok(Value::Integer(count as i64))
7474 }
7475 "SUM" => {
7476 let mut int_sum: i64 = 0;
7477 let mut real_sum: f64 = 0.0;
7478 let mut has_real = false;
7479 let mut all_null = true;
7480 for v in &values {
7481 match v {
7482 Value::Integer(i) => {
7483 int_sum += i;
7484 all_null = false;
7485 }
7486 Value::Real(r) => {
7487 real_sum += r;
7488 has_real = true;
7489 all_null = false;
7490 }
7491 Value::Null => {}
7492 _ => {
7493 return Err(SqlError::TypeMismatch {
7494 expected: "numeric".into(),
7495 got: v.data_type().to_string(),
7496 })
7497 }
7498 }
7499 }
7500 if all_null {
7501 return Ok(Value::Null);
7502 }
7503 if has_real {
7504 Ok(Value::Real(real_sum + int_sum as f64))
7505 } else {
7506 Ok(Value::Integer(int_sum))
7507 }
7508 }
7509 "AVG" => {
7510 let mut sum: f64 = 0.0;
7511 let mut count: i64 = 0;
7512 for v in &values {
7513 match v {
7514 Value::Integer(i) => {
7515 sum += *i as f64;
7516 count += 1;
7517 }
7518 Value::Real(r) => {
7519 sum += r;
7520 count += 1;
7521 }
7522 Value::Null => {}
7523 _ => {
7524 return Err(SqlError::TypeMismatch {
7525 expected: "numeric".into(),
7526 got: v.data_type().to_string(),
7527 })
7528 }
7529 }
7530 }
7531 if count == 0 {
7532 Ok(Value::Null)
7533 } else {
7534 Ok(Value::Real(sum / count as f64))
7535 }
7536 }
7537 "MIN" => {
7538 let mut min: Option<&Value> = None;
7539 for v in &values {
7540 if v.is_null() {
7541 continue;
7542 }
7543 min = Some(match min {
7544 None => v,
7545 Some(m) => {
7546 if v < m {
7547 v
7548 } else {
7549 m
7550 }
7551 }
7552 });
7553 }
7554 Ok(min.cloned().unwrap_or(Value::Null))
7555 }
7556 "MAX" => {
7557 let mut max: Option<&Value> = None;
7558 for v in &values {
7559 if v.is_null() {
7560 continue;
7561 }
7562 max = Some(match max {
7563 None => v,
7564 Some(m) => {
7565 if v > m {
7566 v
7567 } else {
7568 m
7569 }
7570 }
7571 });
7572 }
7573 Ok(max.cloned().unwrap_or(Value::Null))
7574 }
7575 _ => Err(SqlError::Unsupported(format!("aggregate function: {func}"))),
7576 }
7577 }
7578
7579 Expr::Column(_) | Expr::QualifiedColumn { .. } => {
7580 if let Some(first) = group_rows.first() {
7581 eval_expr(expr, col_map, first)
7582 } else {
7583 Ok(Value::Null)
7584 }
7585 }
7586
7587 Expr::Literal(v) => Ok(v.clone()),
7588
7589 Expr::BinaryOp { left, op, right } => {
7590 let l = eval_aggregate_expr(left, col_map, group_rows)?;
7591 let r = eval_aggregate_expr(right, col_map, group_rows)?;
7592 eval_expr(
7593 &Expr::BinaryOp {
7594 left: Box::new(Expr::Literal(l)),
7595 op: *op,
7596 right: Box::new(Expr::Literal(r)),
7597 },
7598 col_map,
7599 &[],
7600 )
7601 }
7602
7603 Expr::UnaryOp { op, expr: e } => {
7604 let v = eval_aggregate_expr(e, col_map, group_rows)?;
7605 eval_expr(
7606 &Expr::UnaryOp {
7607 op: *op,
7608 expr: Box::new(Expr::Literal(v)),
7609 },
7610 col_map,
7611 &[],
7612 )
7613 }
7614
7615 Expr::IsNull(e) => {
7616 let v = eval_aggregate_expr(e, col_map, group_rows)?;
7617 Ok(Value::Boolean(v.is_null()))
7618 }
7619
7620 Expr::IsNotNull(e) => {
7621 let v = eval_aggregate_expr(e, col_map, group_rows)?;
7622 Ok(Value::Boolean(!v.is_null()))
7623 }
7624
7625 Expr::Cast { expr: e, data_type } => {
7626 let v = eval_aggregate_expr(e, col_map, group_rows)?;
7627 eval_expr(
7628 &Expr::Cast {
7629 expr: Box::new(Expr::Literal(v)),
7630 data_type: *data_type,
7631 },
7632 col_map,
7633 &[],
7634 )
7635 }
7636
7637 Expr::Case {
7638 operand,
7639 conditions,
7640 else_result,
7641 } => {
7642 let op_val = operand
7643 .as_ref()
7644 .map(|e| eval_aggregate_expr(e, col_map, group_rows))
7645 .transpose()?;
7646 if let Some(ov) = &op_val {
7647 for (cond, result) in conditions {
7648 let cv = eval_aggregate_expr(cond, col_map, group_rows)?;
7649 if !ov.is_null() && !cv.is_null() && *ov == cv {
7650 return eval_aggregate_expr(result, col_map, group_rows);
7651 }
7652 }
7653 } else {
7654 for (cond, result) in conditions {
7655 let cv = eval_aggregate_expr(cond, col_map, group_rows)?;
7656 if is_truthy(&cv) {
7657 return eval_aggregate_expr(result, col_map, group_rows);
7658 }
7659 }
7660 }
7661 match else_result {
7662 Some(e) => eval_aggregate_expr(e, col_map, group_rows),
7663 None => Ok(Value::Null),
7664 }
7665 }
7666
7667 Expr::Coalesce(args) => {
7668 for arg in args {
7669 let v = eval_aggregate_expr(arg, col_map, group_rows)?;
7670 if !v.is_null() {
7671 return Ok(v);
7672 }
7673 }
7674 Ok(Value::Null)
7675 }
7676
7677 Expr::Between {
7678 expr: e,
7679 low,
7680 high,
7681 negated,
7682 } => {
7683 let v = eval_aggregate_expr(e, col_map, group_rows)?;
7684 let lo = eval_aggregate_expr(low, col_map, group_rows)?;
7685 let hi = eval_aggregate_expr(high, col_map, group_rows)?;
7686 eval_expr(
7687 &Expr::Between {
7688 expr: Box::new(Expr::Literal(v)),
7689 low: Box::new(Expr::Literal(lo)),
7690 high: Box::new(Expr::Literal(hi)),
7691 negated: *negated,
7692 },
7693 col_map,
7694 &[],
7695 )
7696 }
7697
7698 Expr::Like {
7699 expr: e,
7700 pattern,
7701 escape,
7702 negated,
7703 } => {
7704 let v = eval_aggregate_expr(e, col_map, group_rows)?;
7705 let p = eval_aggregate_expr(pattern, col_map, group_rows)?;
7706 let esc = escape
7707 .as_ref()
7708 .map(|es| eval_aggregate_expr(es, col_map, group_rows))
7709 .transpose()?;
7710 let esc_box = esc.map(|v| Box::new(Expr::Literal(v)));
7711 eval_expr(
7712 &Expr::Like {
7713 expr: Box::new(Expr::Literal(v)),
7714 pattern: Box::new(Expr::Literal(p)),
7715 escape: esc_box,
7716 negated: *negated,
7717 },
7718 col_map,
7719 &[],
7720 )
7721 }
7722
7723 Expr::Function { name, args } => {
7724 let evaluated: Vec<Value> = args
7725 .iter()
7726 .map(|a| eval_aggregate_expr(a, col_map, group_rows))
7727 .collect::<Result<_>>()?;
7728 let literal_args: Vec<Expr> = evaluated.into_iter().map(Expr::Literal).collect();
7729 eval_expr(
7730 &Expr::Function {
7731 name: name.clone(),
7732 args: literal_args,
7733 },
7734 col_map,
7735 &[],
7736 )
7737 }
7738
7739 _ => Err(SqlError::Unsupported(format!(
7740 "expression in aggregate: {expr:?}"
7741 ))),
7742 }
7743}
7744
/// Returns `true` when `name` (compared ASCII-case-insensitively) is an
/// aggregate function for the given number of arguments.
///
/// `COUNT`, `SUM` and `AVG` qualify with any argument count. `MIN` and `MAX`
/// qualify only with exactly one argument; with any other count they are not
/// treated as aggregates here.
fn is_aggregate_function(name: &str, arg_count: usize) -> bool {
    const ANY_ARITY: [&str; 3] = ["COUNT", "SUM", "AVG"];
    const UNARY_ONLY: [&str; 2] = ["MIN", "MAX"];
    let named = |set: &[&str]| set.iter().any(|f| f.eq_ignore_ascii_case(name));
    named(&ANY_ARITY) || (arg_count == 1 && named(&UNARY_ONLY))
}
7750
7751fn is_aggregate_expr(expr: &Expr) -> bool {
7752 match expr {
7753 Expr::CountStar => true,
7754 Expr::Function { name, args } => {
7755 is_aggregate_function(name, args.len()) || args.iter().any(is_aggregate_expr)
7756 }
7757 Expr::BinaryOp { left, right, .. } => is_aggregate_expr(left) || is_aggregate_expr(right),
7758 Expr::UnaryOp { expr, .. }
7759 | Expr::IsNull(expr)
7760 | Expr::IsNotNull(expr)
7761 | Expr::Cast { expr, .. } => is_aggregate_expr(expr),
7762 Expr::Case {
7763 operand,
7764 conditions,
7765 else_result,
7766 } => {
7767 operand.as_ref().is_some_and(|e| is_aggregate_expr(e))
7768 || conditions
7769 .iter()
7770 .any(|(c, r)| is_aggregate_expr(c) || is_aggregate_expr(r))
7771 || else_result.as_ref().is_some_and(|e| is_aggregate_expr(e))
7772 }
7773 Expr::Coalesce(args) => args.iter().any(is_aggregate_expr),
7774 Expr::Between {
7775 expr, low, high, ..
7776 } => is_aggregate_expr(expr) || is_aggregate_expr(low) || is_aggregate_expr(high),
7777 Expr::Like {
7778 expr,
7779 pattern,
7780 escape,
7781 ..
7782 } => {
7783 is_aggregate_expr(expr)
7784 || is_aggregate_expr(pattern)
7785 || escape.as_ref().is_some_and(|e| is_aggregate_expr(e))
7786 }
7787 Expr::WindowFunction { .. } => false,
7788 _ => false,
7789 }
7790}
7791
7792use std::collections::VecDeque;
7795
7796fn has_window_function(expr: &Expr) -> bool {
7797 match expr {
7798 Expr::WindowFunction { .. } => true,
7799 Expr::BinaryOp { left, right, .. } => {
7800 has_window_function(left) || has_window_function(right)
7801 }
7802 Expr::UnaryOp { expr: e, .. }
7803 | Expr::IsNull(e)
7804 | Expr::IsNotNull(e)
7805 | Expr::Cast { expr: e, .. } => has_window_function(e),
7806 Expr::Function { args, .. } | Expr::Coalesce(args) => args.iter().any(has_window_function),
7807 Expr::Case {
7808 operand,
7809 conditions,
7810 else_result,
7811 } => {
7812 operand.as_ref().is_some_and(|e| has_window_function(e))
7813 || conditions
7814 .iter()
7815 .any(|(c, r)| has_window_function(c) || has_window_function(r))
7816 || else_result.as_ref().is_some_and(|e| has_window_function(e))
7817 }
7818 _ => false,
7819 }
7820}
7821
7822fn has_any_window_function(stmt: &SelectStmt) -> bool {
7823 stmt.columns.iter().any(|c| match c {
7824 SelectColumn::Expr { expr, .. } => has_window_function(expr),
7825 _ => false,
7826 })
7827}
7828
7829fn extract_window_fns(
7832 expr: &Expr,
7833 slot_counter: &mut usize,
7834 extracted: &mut Vec<(String, String, Vec<Expr>, WindowSpec)>,
7835) -> Expr {
7836 match expr {
7837 Expr::WindowFunction { name, args, spec } => {
7838 let slot_name = format!("__win_{}", *slot_counter);
7839 *slot_counter += 1;
7840 extracted.push((slot_name.clone(), name.clone(), args.clone(), spec.clone()));
7841 Expr::Column(slot_name)
7842 }
7843 Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
7844 left: Box::new(extract_window_fns(left, slot_counter, extracted)),
7845 op: *op,
7846 right: Box::new(extract_window_fns(right, slot_counter, extracted)),
7847 },
7848 Expr::UnaryOp { op, expr: e } => Expr::UnaryOp {
7849 op: *op,
7850 expr: Box::new(extract_window_fns(e, slot_counter, extracted)),
7851 },
7852 Expr::IsNull(e) => Expr::IsNull(Box::new(extract_window_fns(e, slot_counter, extracted))),
7853 Expr::IsNotNull(e) => {
7854 Expr::IsNotNull(Box::new(extract_window_fns(e, slot_counter, extracted)))
7855 }
7856 Expr::Cast { expr: e, data_type } => Expr::Cast {
7857 expr: Box::new(extract_window_fns(e, slot_counter, extracted)),
7858 data_type: *data_type,
7859 },
7860 Expr::Function { name, args } => Expr::Function {
7861 name: name.clone(),
7862 args: args
7863 .iter()
7864 .map(|a| extract_window_fns(a, slot_counter, extracted))
7865 .collect(),
7866 },
7867 Expr::Coalesce(args) => Expr::Coalesce(
7868 args.iter()
7869 .map(|a| extract_window_fns(a, slot_counter, extracted))
7870 .collect(),
7871 ),
7872 Expr::Case {
7873 operand,
7874 conditions,
7875 else_result,
7876 } => Expr::Case {
7877 operand: operand
7878 .as_ref()
7879 .map(|e| Box::new(extract_window_fns(e, slot_counter, extracted))),
7880 conditions: conditions
7881 .iter()
7882 .map(|(c, r)| {
7883 (
7884 extract_window_fns(c, slot_counter, extracted),
7885 extract_window_fns(r, slot_counter, extracted),
7886 )
7887 })
7888 .collect(),
7889 else_result: else_result
7890 .as_ref()
7891 .map(|e| Box::new(extract_window_fns(e, slot_counter, extracted))),
7892 },
7893 other => other.clone(),
7894 }
7895}
7896
7897fn resolve_frame(spec: &WindowSpec) -> WindowFrame {
7899 if let Some(ref frame) = spec.frame {
7900 return frame.clone();
7901 }
7902 if spec.order_by.is_empty() {
7903 WindowFrame {
7904 units: WindowFrameUnits::Range,
7905 start: WindowFrameBound::UnboundedPreceding,
7906 end: WindowFrameBound::UnboundedFollowing,
7907 }
7908 } else {
7909 WindowFrame {
7910 units: WindowFrameUnits::Range,
7911 start: WindowFrameBound::UnboundedPreceding,
7912 end: WindowFrameBound::CurrentRow,
7913 }
7914 }
7915}
7916
7917fn rows_frame_indices(frame: &WindowFrame, i: usize, n: usize) -> Result<(usize, usize)> {
7919 let start = match &frame.start {
7920 WindowFrameBound::UnboundedPreceding => 0,
7921 WindowFrameBound::Preceding(e) => {
7922 let k = eval_const_int(e)? as usize;
7923 i.saturating_sub(k)
7924 }
7925 WindowFrameBound::CurrentRow => i,
7926 WindowFrameBound::Following(e) => {
7927 let k = eval_const_int(e)? as usize;
7928 (i + k).min(n - 1)
7929 }
7930 WindowFrameBound::UnboundedFollowing => n - 1,
7931 };
7932 let end = match &frame.end {
7933 WindowFrameBound::UnboundedPreceding => 0,
7934 WindowFrameBound::Preceding(e) => {
7935 let k = eval_const_int(e)? as usize;
7936 i.saturating_sub(k)
7937 }
7938 WindowFrameBound::CurrentRow => i,
7939 WindowFrameBound::Following(e) => {
7940 let k = eval_const_int(e)? as usize;
7941 (i + k).min(n - 1)
7942 }
7943 WindowFrameBound::UnboundedFollowing => n - 1,
7944 };
7945 Ok((start, end.min(n - 1)))
7946}
7947
7948fn find_peer_range(
7950 rows: &[Vec<Value>],
7951 order_by: &[OrderByItem],
7952 col_map: &ColumnMap,
7953 i: usize,
7954) -> (usize, usize) {
7955 let key: Vec<Value> = order_by
7956 .iter()
7957 .map(|o| eval_expr(&o.expr, col_map, &rows[i]).unwrap_or(Value::Null))
7958 .collect();
7959 let mut start = i;
7960 while start > 0 {
7961 let prev_key: Vec<Value> = order_by
7962 .iter()
7963 .map(|o| eval_expr(&o.expr, col_map, &rows[start - 1]).unwrap_or(Value::Null))
7964 .collect();
7965 if prev_key != key {
7966 break;
7967 }
7968 start -= 1;
7969 }
7970 let mut end = i;
7971 while end + 1 < rows.len() {
7972 let next_key: Vec<Value> = order_by
7973 .iter()
7974 .map(|o| eval_expr(&o.expr, col_map, &rows[end + 1]).unwrap_or(Value::Null))
7975 .collect();
7976 if next_key != key {
7977 break;
7978 }
7979 end += 1;
7980 }
7981 (start, end)
7982}
7983
7984fn frame_indices(
7986 frame: &WindowFrame,
7987 i: usize,
7988 n: usize,
7989 rows: &[Vec<Value>],
7990 order_by: &[OrderByItem],
7991 col_map: &ColumnMap,
7992) -> Result<(usize, usize)> {
7993 match frame.units {
7994 WindowFrameUnits::Rows => rows_frame_indices(frame, i, n),
7995 WindowFrameUnits::Range => {
7996 let start = match &frame.start {
7998 WindowFrameBound::UnboundedPreceding => 0,
7999 WindowFrameBound::CurrentRow => find_peer_range(rows, order_by, col_map, i).0,
8000 _ => return Err(SqlError::Unsupported("RANGE with numeric offset".into())),
8001 };
8002 let end = match &frame.end {
8003 WindowFrameBound::UnboundedFollowing => n - 1,
8004 WindowFrameBound::CurrentRow => find_peer_range(rows, order_by, col_map, i).1,
8005 _ => return Err(SqlError::Unsupported("RANGE with numeric offset".into())),
8006 };
8007 Ok((start, end))
8008 }
8009 WindowFrameUnits::Groups => Err(SqlError::Unsupported("GROUPS window frame".into())),
8010 }
8011}
8012
8013struct MonoDeque {
8015 deque: VecDeque<(usize, Value)>,
8016 is_min: bool,
8017}
8018
8019impl MonoDeque {
8020 fn new(is_min: bool) -> Self {
8021 Self {
8022 deque: VecDeque::new(),
8023 is_min,
8024 }
8025 }
8026
8027 fn push(&mut self, idx: usize, val: Value) {
8028 if val.is_null() {
8029 return;
8030 }
8031 while let Some(back) = self.deque.back() {
8032 let evict = if self.is_min {
8033 val <= back.1
8034 } else {
8035 val >= back.1
8036 };
8037 if evict {
8038 self.deque.pop_back();
8039 } else {
8040 break;
8041 }
8042 }
8043 self.deque.push_back((idx, val));
8044 }
8045
8046 fn pop_expired(&mut self, frame_start: usize) {
8047 while let Some(front) = self.deque.front() {
8048 if front.0 < frame_start {
8049 self.deque.pop_front();
8050 } else {
8051 break;
8052 }
8053 }
8054 }
8055
8056 fn current(&self) -> Value {
8057 self.deque
8058 .front()
8059 .map(|(_, v)| v.clone())
8060 .unwrap_or(Value::Null)
8061 }
8062}
8063
8064struct SlidingSum {
8066 int_sum: i64,
8067 real_sum: f64,
8068 has_real: bool,
8069 count: i64,
8070}
8071
8072impl SlidingSum {
8073 fn new() -> Self {
8074 Self {
8075 int_sum: 0,
8076 real_sum: 0.0,
8077 has_real: false,
8078 count: 0,
8079 }
8080 }
8081
8082 fn add(&mut self, val: &Value) {
8083 match val {
8084 Value::Integer(i) => {
8085 self.int_sum += i;
8086 self.count += 1;
8087 }
8088 Value::Real(r) => {
8089 self.real_sum += r;
8090 self.has_real = true;
8091 self.count += 1;
8092 }
8093 _ => {}
8094 }
8095 }
8096
8097 fn remove(&mut self, val: &Value) {
8098 match val {
8099 Value::Integer(i) => {
8100 self.int_sum -= i;
8101 self.count -= 1;
8102 }
8103 Value::Real(r) => {
8104 self.real_sum -= r;
8105 self.count -= 1;
8106 }
8107 _ => {}
8108 }
8109 }
8110
8111 fn result_sum(&self) -> Value {
8112 if self.count == 0 && !self.has_real {
8113 Value::Null
8114 } else if self.has_real {
8115 Value::Real(self.real_sum + self.int_sum as f64)
8116 } else {
8117 Value::Integer(self.int_sum)
8118 }
8119 }
8120
8121 fn result_count(&self) -> Value {
8122 Value::Integer(self.count)
8123 }
8124
8125 fn result_avg(&self) -> Value {
8126 if self.count == 0 {
8127 Value::Null
8128 } else {
8129 let total = if self.has_real {
8130 self.real_sum + self.int_sum as f64
8131 } else {
8132 self.int_sum as f64
8133 };
8134 Value::Real(total / self.count as f64)
8135 }
8136 }
8137}
8138
8139fn eval_window_select(
8140 columns: &[ColumnDef],
8141 mut rows: Vec<Vec<Value>>,
8142 stmt: &SelectStmt,
8143) -> Result<ExecutionResult> {
8144 if rows.is_empty() {
8145 let col_names = stmt
8146 .columns
8147 .iter()
8148 .map(|c| match c {
8149 SelectColumn::AllColumns => "*".into(),
8150 SelectColumn::Expr { alias: Some(a), .. } => a.clone(),
8151 SelectColumn::Expr { expr, .. } => expr_display_name(expr),
8152 })
8153 .collect();
8154 return Ok(ExecutionResult::Query(QueryResult {
8155 columns: col_names,
8156 rows: vec![],
8157 }));
8158 }
8159
8160 let mut slot_counter = 0usize;
8162 let mut all_extracted: Vec<(String, String, Vec<Expr>, WindowSpec)> = Vec::new();
8163 let mut rewritten_columns: Vec<SelectColumn> = Vec::new();
8164
8165 for col in &stmt.columns {
8166 match col {
8167 SelectColumn::AllColumns => rewritten_columns.push(SelectColumn::AllColumns),
8168 SelectColumn::Expr { expr, alias } => {
8169 let new_expr = extract_window_fns(expr, &mut slot_counter, &mut all_extracted);
8170 rewritten_columns.push(SelectColumn::Expr {
8171 expr: new_expr,
8172 alias: alias.clone(),
8173 });
8174 }
8175 }
8176 }
8177
8178 if all_extracted.is_empty() {
8179 return process_select(columns, rows, stmt, false);
8180 }
8181
8182 let col_map = ColumnMap::new(columns);
8184 let num_win = all_extracted.len();
8185 let mut arg_values: Vec<Vec<Vec<Value>>> = Vec::with_capacity(num_win);
8186 for (_, _, args, _) in &all_extracted {
8187 let mut per_row = Vec::with_capacity(rows.len());
8188 for row in &rows {
8189 let vals: Vec<Value> = args
8190 .iter()
8191 .map(|a| eval_expr(a, &col_map, row).unwrap_or(Value::Null))
8192 .collect();
8193 per_row.push(vals);
8194 }
8195 arg_values.push(per_row);
8196 }
8197
8198 let n = rows.len();
8202 let mut row_results: Vec<Vec<Value>> = (0..n).map(|_| vec![Value::Null; num_win]).collect();
8203
8204 for (win_idx, (_, fn_name, _, spec)) in all_extracted.iter().enumerate() {
8205 let mut sort_keys: Vec<OrderByItem> = Vec::new();
8207 for pb in &spec.partition_by {
8208 sort_keys.push(OrderByItem {
8209 expr: pb.clone(),
8210 descending: false,
8211 nulls_first: Some(true),
8212 });
8213 }
8214 sort_keys.extend(spec.order_by.clone());
8215
8216 let mut indices: Vec<usize> = (0..n).collect();
8218 if !sort_keys.is_empty() {
8219 let keys: Vec<Vec<Value>> = indices
8220 .iter()
8221 .map(|&i| {
8222 sort_keys
8223 .iter()
8224 .map(|o| eval_expr(&o.expr, &col_map, &rows[i]).unwrap_or(Value::Null))
8225 .collect()
8226 })
8227 .collect();
8228 indices.sort_by(|&a, &b| compare_sort_keys(&keys[a], &keys[b], &sort_keys));
8229 }
8230
8231 let part_count = spec.partition_by.len();
8233 let mut partitions: Vec<(usize, usize)> = Vec::new();
8234 let mut part_start = 0;
8235 for pos in 1..n {
8236 let mut same = true;
8237 if part_count > 0 {
8238 for p in 0..part_count {
8239 let prev = eval_expr(&spec.partition_by[p], &col_map, &rows[indices[pos - 1]])
8240 .unwrap_or(Value::Null);
8241 let cur = eval_expr(&spec.partition_by[p], &col_map, &rows[indices[pos]])
8242 .unwrap_or(Value::Null);
8243 if prev != cur {
8244 same = false;
8245 break;
8246 }
8247 }
8248 }
8249 if !same {
8250 partitions.push((part_start, pos));
8251 part_start = pos;
8252 }
8253 }
8254 partitions.push((part_start, n));
8255
8256 let frame = resolve_frame(spec);
8257 let upper_name = fn_name.to_ascii_uppercase();
8258
8259 for &(ps, pe) in &partitions {
8261 let part_len = pe - ps;
8262 let part_indices = &indices[ps..pe];
8263
8264 match upper_name.as_str() {
8267 "ROW_NUMBER" => {
8268 for (rank, &orig_idx) in part_indices.iter().enumerate() {
8269 row_results[orig_idx][win_idx] = Value::Integer(rank as i64 + 1);
8270 }
8271 }
8272 "RANK" => {
8273 if spec.order_by.is_empty() {
8274 return Err(SqlError::WindowFunctionRequiresOrderBy("RANK".into()));
8275 }
8276 let mut rank = 1i64;
8277 let mut prev_key: Option<Vec<Value>> = None;
8278 for (pos, &orig_idx) in part_indices.iter().enumerate() {
8279 let key: Vec<Value> = spec
8280 .order_by
8281 .iter()
8282 .map(|o| {
8283 eval_expr(&o.expr, &col_map, &rows[orig_idx]).unwrap_or(Value::Null)
8284 })
8285 .collect();
8286 if let Some(ref pk) = prev_key {
8287 if &key != pk {
8288 rank = pos as i64 + 1;
8289 }
8290 }
8291 row_results[orig_idx][win_idx] = Value::Integer(rank);
8292 prev_key = Some(key);
8293 }
8294 }
8295 "DENSE_RANK" => {
8296 if spec.order_by.is_empty() {
8297 return Err(SqlError::WindowFunctionRequiresOrderBy("DENSE_RANK".into()));
8298 }
8299 let mut rank = 1i64;
8300 let mut prev_key: Option<Vec<Value>> = None;
8301 for &orig_idx in part_indices {
8302 let key: Vec<Value> = spec
8303 .order_by
8304 .iter()
8305 .map(|o| {
8306 eval_expr(&o.expr, &col_map, &rows[orig_idx]).unwrap_or(Value::Null)
8307 })
8308 .collect();
8309 if let Some(ref pk) = prev_key {
8310 if &key != pk {
8311 rank += 1;
8312 }
8313 }
8314 row_results[orig_idx][win_idx] = Value::Integer(rank);
8315 prev_key = Some(key);
8316 }
8317 }
8318 "NTILE" => {
8319 let ntile_n = if arg_values[win_idx][0].is_empty() {
8320 return Err(SqlError::Parse("NTILE requires one argument".into()));
8321 } else {
8322 match &arg_values[win_idx][part_indices[0]][0] {
8323 Value::Integer(n) if *n > 0 => *n as usize,
8324 _ => {
8325 return Err(SqlError::InvalidValue(
8326 "NTILE argument must be a positive integer".into(),
8327 ))
8328 }
8329 }
8330 };
8331 let base = part_len / ntile_n;
8332 let remainder = part_len % ntile_n;
8333 let mut bucket = 1usize;
8334 let mut count_in_bucket = 0usize;
8335 let bucket_size = |b: usize| -> usize {
8336 if b <= remainder {
8337 base + 1
8338 } else {
8339 base
8340 }
8341 };
8342 for &orig_idx in part_indices {
8343 row_results[orig_idx][win_idx] = Value::Integer(bucket as i64);
8344 count_in_bucket += 1;
8345 if count_in_bucket >= bucket_size(bucket) && bucket < ntile_n {
8346 bucket += 1;
8347 count_in_bucket = 0;
8348 }
8349 }
8350 }
8351 "LAG" | "LEAD" => {
8352 let offset = if arg_values[win_idx][0].len() >= 2 {
8353 match &arg_values[win_idx][0][1] {
8354 Value::Integer(n) => *n as usize,
8355 _ => 1,
8356 }
8357 } else {
8358 1
8359 };
8360 let default_val = if arg_values[win_idx][0].len() >= 3 {
8361 arg_values[win_idx][0][2].clone()
8362 } else {
8363 Value::Null
8364 };
8365 let is_lag = upper_name == "LAG";
8366 for (pos, &orig_idx) in part_indices.iter().enumerate() {
8367 let target_pos = if is_lag {
8368 if pos >= offset {
8369 Some(pos - offset)
8370 } else {
8371 None
8372 }
8373 } else if pos + offset < part_len {
8374 Some(pos + offset)
8375 } else {
8376 None
8377 };
8378 let val = match target_pos {
8379 Some(tp) => arg_values[win_idx][part_indices[tp]][0].clone(),
8380 None => default_val.clone(),
8381 };
8382 row_results[orig_idx][win_idx] = val;
8383 }
8384 }
8385 "FIRST_VALUE" => {
8386 for (pos, &orig_idx) in part_indices.iter().enumerate() {
8387 let (fs, _) = frame_indices(
8388 &frame,
8389 pos,
8390 part_len,
8391 &part_indices
8392 .iter()
8393 .map(|&i| rows[i].clone())
8394 .collect::<Vec<_>>(),
8395 &spec.order_by,
8396 &col_map,
8397 )?;
8398 let source_idx = part_indices[fs];
8399 row_results[orig_idx][win_idx] = arg_values[win_idx][source_idx][0].clone();
8400 }
8401 }
8402 "LAST_VALUE" => {
8403 for (pos, &orig_idx) in part_indices.iter().enumerate() {
8404 let (_, fe) = frame_indices(
8405 &frame,
8406 pos,
8407 part_len,
8408 &part_indices
8409 .iter()
8410 .map(|&i| rows[i].clone())
8411 .collect::<Vec<_>>(),
8412 &spec.order_by,
8413 &col_map,
8414 )?;
8415 let source_idx = part_indices[fe];
8416 row_results[orig_idx][win_idx] = arg_values[win_idx][source_idx][0].clone();
8417 }
8418 }
8419 "SUM" | "COUNT" | "AVG" => {
8420 let is_count_star = upper_name == "COUNT" && arg_values[win_idx][0].is_empty();
8421 if matches!(frame.units, WindowFrameUnits::Rows)
8423 && matches!(
8424 frame.start,
8425 WindowFrameBound::UnboundedPreceding | WindowFrameBound::Preceding(_)
8426 )
8427 && matches!(
8428 frame.end,
8429 WindowFrameBound::CurrentRow | WindowFrameBound::Following(_)
8430 )
8431 {
8432 let mut acc = SlidingSum::new();
8434 let mut prev_start = 0usize;
8435 for (pos, &orig_idx) in part_indices.iter().enumerate() {
8436 let (fs, fe) = rows_frame_indices(&frame, pos, part_len)?;
8437 while prev_start < fs {
8439 if is_count_star {
8440 acc.count -= 1;
8441 } else {
8442 acc.remove(&arg_values[win_idx][part_indices[prev_start]][0]);
8443 }
8444 prev_start += 1;
8445 }
8446 let add_from = if pos == 0 {
8448 fs
8449 } else {
8450 let (_, prev_fe) = rows_frame_indices(&frame, pos - 1, part_len)?;
8451 prev_fe + 1
8452 };
8453 for add_pos in add_from..=fe {
8454 if is_count_star {
8455 acc.count += 1;
8456 } else {
8457 acc.add(&arg_values[win_idx][part_indices[add_pos]][0]);
8458 }
8459 }
8460 row_results[orig_idx][win_idx] = match upper_name.as_str() {
8461 "SUM" => acc.result_sum(),
8462 "COUNT" => acc.result_count(),
8463 "AVG" => acc.result_avg(),
8464 _ => unreachable!(),
8465 };
8466 }
8467 } else {
8468 for (pos, &orig_idx) in part_indices.iter().enumerate() {
8470 let part_rows: Vec<Vec<Value>> =
8471 part_indices.iter().map(|&i| rows[i].clone()).collect();
8472 let (fs, fe) = frame_indices(
8473 &frame,
8474 pos,
8475 part_len,
8476 &part_rows,
8477 &spec.order_by,
8478 &col_map,
8479 )?;
8480 let mut acc = SlidingSum::new();
8481 for fpos in fs..=fe {
8482 if is_count_star {
8483 acc.count += 1;
8484 } else {
8485 acc.add(&arg_values[win_idx][part_indices[fpos]][0]);
8486 }
8487 }
8488 row_results[orig_idx][win_idx] = match upper_name.as_str() {
8489 "SUM" => acc.result_sum(),
8490 "COUNT" => acc.result_count(),
8491 "AVG" => acc.result_avg(),
8492 _ => unreachable!(),
8493 };
8494 }
8495 }
8496 }
8497 "MIN" | "MAX" => {
8498 let is_min = upper_name == "MIN";
8499 if matches!(frame.units, WindowFrameUnits::Rows)
8500 && matches!(
8501 frame.start,
8502 WindowFrameBound::UnboundedPreceding | WindowFrameBound::Preceding(_)
8503 )
8504 && matches!(
8505 frame.end,
8506 WindowFrameBound::CurrentRow | WindowFrameBound::Following(_)
8507 )
8508 {
8509 let mut deque = MonoDeque::new(is_min);
8511 let mut prev_end: Option<usize> = None;
8512 for (pos, &orig_idx) in part_indices.iter().enumerate() {
8513 let (fs, fe) = rows_frame_indices(&frame, pos, part_len)?;
8514 let add_from = prev_end.map(|pe| pe + 1).unwrap_or(fs);
8516 for add_pos in add_from..=fe {
8517 deque.push(
8518 add_pos,
8519 arg_values[win_idx][part_indices[add_pos]][0].clone(),
8520 );
8521 }
8522 deque.pop_expired(fs);
8523 row_results[orig_idx][win_idx] = deque.current();
8524 prev_end = Some(fe);
8525 }
8526 } else {
8527 for (pos, &orig_idx) in part_indices.iter().enumerate() {
8529 let part_rows: Vec<Vec<Value>> =
8530 part_indices.iter().map(|&i| rows[i].clone()).collect();
8531 let (fs, fe) = frame_indices(
8532 &frame,
8533 pos,
8534 part_len,
8535 &part_rows,
8536 &spec.order_by,
8537 &col_map,
8538 )?;
8539 let mut result = Value::Null;
8540 for fpos in fs..=fe {
8541 let v = &arg_values[win_idx][part_indices[fpos]][0];
8542 if !v.is_null() {
8543 result = match result {
8544 Value::Null => v.clone(),
8545 ref cur => {
8546 if (is_min && v < cur) || (!is_min && v > cur) {
8547 v.clone()
8548 } else {
8549 cur.clone()
8550 }
8551 }
8552 };
8553 }
8554 }
8555 row_results[orig_idx][win_idx] = result;
8556 }
8557 }
8558 }
8559 other => {
8560 return Err(SqlError::Unsupported(format!("window function: {other}")));
8561 }
8562 }
8563 }
8564 }
8565
8566 let base_col_count = columns.len();
8568 let mut extended_columns: Vec<ColumnDef> = columns.to_vec();
8569 for (i, (slot_name, _, _, _)) in all_extracted.iter().enumerate() {
8570 extended_columns.push(ColumnDef {
8571 name: slot_name.clone(),
8572 data_type: DataType::Null,
8573 nullable: true,
8574 position: (base_col_count + i) as u16,
8575 default_expr: None,
8576 default_sql: None,
8577 check_expr: None,
8578 check_sql: None,
8579 check_name: None,
8580 });
8581 }
8582
8583 for (row_idx, row) in rows.iter_mut().enumerate() {
8584 row.extend_from_slice(&row_results[row_idx]);
8585 }
8586
8587 let rewritten_stmt = SelectStmt {
8589 columns: rewritten_columns,
8590 from: stmt.from.clone(),
8591 from_alias: stmt.from_alias.clone(),
8592 joins: stmt.joins.clone(),
8593 distinct: stmt.distinct,
8594 where_clause: None, order_by: stmt.order_by.clone(),
8596 limit: stmt.limit.clone(),
8597 offset: stmt.offset.clone(),
8598 group_by: vec![],
8599 having: None,
8600 };
8601
8602 process_select(&extended_columns, rows, &rewritten_stmt, true)
8603}
8604
8605struct PartialDecodeCtx {
8608 pk_positions: Vec<(usize, usize)>,
8609 nonpk_targets: Vec<usize>,
8610 nonpk_schema: Vec<usize>,
8611 num_cols: usize,
8612 num_pk_cols: usize,
8613 remaining_pk: Vec<(usize, usize)>,
8614 remaining_nonpk_targets: Vec<usize>,
8615 remaining_nonpk_schema: Vec<usize>,
8616 nonpk_defaults: Vec<(usize, usize, Value)>,
8617 remaining_defaults: Vec<(usize, usize, Value)>,
8618}
8619
8620impl PartialDecodeCtx {
8621 fn new(schema: &TableSchema, needed: &[usize]) -> Self {
8622 let non_pk = schema.non_pk_indices();
8623 let enc_pos = schema.encoding_positions();
8624 let mut pk_positions = Vec::new();
8625 let mut nonpk_targets = Vec::new();
8626 let mut nonpk_schema = Vec::new();
8627
8628 for &col in needed {
8629 if let Some(pk_pos) = schema
8630 .primary_key_columns
8631 .iter()
8632 .position(|&i| i as usize == col)
8633 {
8634 pk_positions.push((pk_pos, col));
8635 } else if let Some(nonpk_order) = non_pk.iter().position(|&i| i == col) {
8636 nonpk_targets.push(enc_pos[nonpk_order] as usize);
8637 nonpk_schema.push(col);
8638 }
8639 }
8640
8641 let needed_set: std::collections::HashSet<usize> = needed.iter().copied().collect();
8642 let mut remaining_pk = Vec::new();
8643 for (pk_pos, &pk_col) in schema.primary_key_columns.iter().enumerate() {
8644 if !needed_set.contains(&(pk_col as usize)) {
8645 remaining_pk.push((pk_pos, pk_col as usize));
8646 }
8647 }
8648 let mut remaining_nonpk_targets = Vec::new();
8649 let mut remaining_nonpk_schema = Vec::new();
8650 for (nonpk_order, &col) in non_pk.iter().enumerate() {
8651 if !needed_set.contains(&col) {
8652 remaining_nonpk_targets.push(enc_pos[nonpk_order] as usize);
8653 remaining_nonpk_schema.push(col);
8654 }
8655 }
8656
8657 let mut nonpk_defaults = Vec::new();
8658 for (&phys_pos, &schema_col) in nonpk_targets.iter().zip(nonpk_schema.iter()) {
8659 if let Some(ref expr) = schema.columns[schema_col].default_expr {
8660 if let Ok(val) = eval_const_expr(expr) {
8661 nonpk_defaults.push((phys_pos, schema_col, val));
8662 }
8663 }
8664 }
8665 let mut remaining_defaults = Vec::new();
8666 for (&phys_pos, &schema_col) in remaining_nonpk_targets
8667 .iter()
8668 .zip(remaining_nonpk_schema.iter())
8669 {
8670 if let Some(ref expr) = schema.columns[schema_col].default_expr {
8671 if let Ok(val) = eval_const_expr(expr) {
8672 remaining_defaults.push((phys_pos, schema_col, val));
8673 }
8674 }
8675 }
8676
8677 Self {
8678 pk_positions,
8679 nonpk_targets,
8680 nonpk_schema,
8681 num_cols: schema.columns.len(),
8682 num_pk_cols: schema.primary_key_columns.len(),
8683 remaining_pk,
8684 remaining_nonpk_targets,
8685 remaining_nonpk_schema,
8686 nonpk_defaults,
8687 remaining_defaults,
8688 }
8689 }
8690
8691 fn decode(&self, key: &[u8], value: &[u8]) -> Result<Vec<Value>> {
8692 let mut row = vec![Value::Null; self.num_cols];
8693
8694 if self.pk_positions.len() == 1 && self.num_pk_cols == 1 {
8695 let (_, schema_col) = self.pk_positions[0];
8696 let (v, _) = decode_key_value(key)?;
8697 row[schema_col] = v;
8698 } else if !self.pk_positions.is_empty() {
8699 let mut pk_values = decode_composite_key(key, self.num_pk_cols)?;
8700 for &(pk_pos, schema_col) in &self.pk_positions {
8701 row[schema_col] = std::mem::take(&mut pk_values[pk_pos]);
8702 }
8703 }
8704
8705 if !self.nonpk_targets.is_empty() {
8706 decode_columns_into(value, &self.nonpk_targets, &self.nonpk_schema, &mut row)?;
8707 }
8708
8709 if !self.nonpk_defaults.is_empty() {
8710 let stored = row_non_pk_count(value);
8711 for (nonpk_idx, schema_col, default) in &self.nonpk_defaults {
8712 if *nonpk_idx >= stored {
8713 row[*schema_col] = default.clone();
8714 }
8715 }
8716 }
8717
8718 Ok(row)
8719 }
8720
8721 fn complete(&self, mut row: Vec<Value>, key: &[u8], value: &[u8]) -> Result<Vec<Value>> {
8722 if !self.remaining_pk.is_empty() {
8723 let mut pk_values = decode_composite_key(key, self.num_pk_cols)?;
8724 for &(pk_pos, schema_col) in &self.remaining_pk {
8725 row[schema_col] = std::mem::take(&mut pk_values[pk_pos]);
8726 }
8727 }
8728 if !self.remaining_nonpk_targets.is_empty() {
8729 let mut values = decode_columns(value, &self.remaining_nonpk_targets)?;
8730 for (i, &schema_col) in self.remaining_nonpk_schema.iter().enumerate() {
8731 row[schema_col] = std::mem::take(&mut values[i]);
8732 }
8733 }
8734 if !self.remaining_defaults.is_empty() {
8735 let stored = row_non_pk_count(value);
8736 for (nonpk_idx, schema_col, default) in &self.remaining_defaults {
8737 if *nonpk_idx >= stored {
8738 row[*schema_col] = default.clone();
8739 }
8740 }
8741 }
8742 Ok(row)
8743 }
8744}
8745
8746fn decode_full_row(schema: &TableSchema, key: &[u8], value: &[u8]) -> Result<Vec<Value>> {
8747 let mut row = vec![Value::Null; schema.columns.len()];
8748 decode_pk_into(
8749 key,
8750 schema.primary_key_columns.len(),
8751 &mut row,
8752 schema.pk_indices(),
8753 )?;
8754 let mapping = schema.decode_col_mapping();
8755 let stored_count = row_non_pk_count(value);
8756 decode_row_into(value, &mut row, mapping)?;
8757 if stored_count < mapping.len() {
8760 for &logical_idx in mapping.iter().skip(stored_count) {
8761 if logical_idx != usize::MAX {
8762 if let Some(ref expr) = schema.columns[logical_idx].default_expr {
8763 row[logical_idx] = eval_const_expr(expr)?;
8764 }
8765 }
8766 }
8767 }
8768 Ok(row)
8769}
8770
8771fn eval_const_expr(expr: &Expr) -> Result<Value> {
8773 static EMPTY: std::sync::OnceLock<ColumnMap> = std::sync::OnceLock::new();
8774 let empty = EMPTY.get_or_init(|| ColumnMap::new(&[]));
8775 eval_expr(expr, empty, &[])
8776}
8777
8778fn eval_const_int(expr: &Expr) -> Result<i64> {
8779 match eval_const_expr(expr)? {
8780 Value::Integer(i) => Ok(i),
8781 other => Err(SqlError::TypeMismatch {
8782 expected: "INTEGER".into(),
8783 got: other.data_type().to_string(),
8784 }),
8785 }
8786}
8787
8788fn sort_rows(
8789 rows: &mut [Vec<Value>],
8790 order_by: &[OrderByItem],
8791 columns: &[ColumnDef],
8792) -> Result<()> {
8793 if rows.is_empty() {
8794 return Ok(());
8795 }
8796 let col_map = ColumnMap::new(columns);
8797 let mut indices: Vec<usize> = (0..rows.len()).collect();
8798
8799 if let Some(col_idx) = try_resolve_flat_sort_col(order_by, &col_map) {
8800 let desc = order_by[0].descending;
8801 let nulls_first = order_by[0].nulls_first.unwrap_or(!desc);
8802 indices.sort_by(|&a, &b| {
8803 compare_flat_key(&rows[a][col_idx], &rows[b][col_idx], desc, nulls_first)
8804 });
8805 } else {
8806 let keys = extract_sort_keys(rows, order_by, &col_map);
8807 indices.sort_by(|&a, &b| compare_sort_keys(&keys[a], &keys[b], order_by));
8808 }
8809
8810 let sorted: Vec<Vec<Value>> = indices
8811 .iter()
8812 .map(|&i| std::mem::take(&mut rows[i]))
8813 .collect();
8814 rows.iter_mut()
8815 .zip(sorted)
8816 .for_each(|(slot, row)| *slot = row);
8817 Ok(())
8818}
8819
8820fn topk_rows(
8821 rows: &mut [Vec<Value>],
8822 order_by: &[OrderByItem],
8823 columns: &[ColumnDef],
8824 k: usize,
8825) -> Result<()> {
8826 let col_map = ColumnMap::new(columns);
8827 let mut indices: Vec<usize> = (0..rows.len()).collect();
8828
8829 if let Some(col_idx) = try_resolve_flat_sort_col(order_by, &col_map) {
8830 let desc = order_by[0].descending;
8831 let nulls_first = order_by[0].nulls_first.unwrap_or(!desc);
8832 let cmp = |&a: &usize, &b: &usize| {
8833 compare_flat_key(&rows[a][col_idx], &rows[b][col_idx], desc, nulls_first)
8834 };
8835 indices.select_nth_unstable_by(k - 1, cmp);
8836 indices[..k].sort_by(cmp);
8837 } else {
8838 let keys = extract_sort_keys(rows, order_by, &col_map);
8839 let cmp = |&a: &usize, &b: &usize| compare_sort_keys(&keys[a], &keys[b], order_by);
8840 indices.select_nth_unstable_by(k - 1, cmp);
8841 indices[..k].sort_by(cmp);
8842 }
8843
8844 let sorted: Vec<Vec<Value>> = indices[..k]
8845 .iter()
8846 .map(|&i| std::mem::take(&mut rows[i]))
8847 .collect();
8848 rows[..k]
8849 .iter_mut()
8850 .zip(sorted)
8851 .for_each(|(slot, row)| *slot = row);
8852 Ok(())
8853}
8854
8855fn try_resolve_flat_sort_col(order_by: &[OrderByItem], col_map: &ColumnMap) -> Option<usize> {
8856 if order_by.len() != 1 {
8857 return None;
8858 }
8859 match &order_by[0].expr {
8860 Expr::Column(name) => col_map.resolve(&name.to_ascii_lowercase()).ok(),
8861 _ => None,
8862 }
8863}
8864
8865fn compare_flat_key(a: &Value, b: &Value, desc: bool, nulls_first: bool) -> std::cmp::Ordering {
8866 match (a.is_null(), b.is_null()) {
8867 (true, true) => std::cmp::Ordering::Equal,
8868 (true, false) => {
8869 if nulls_first {
8870 std::cmp::Ordering::Less
8871 } else {
8872 std::cmp::Ordering::Greater
8873 }
8874 }
8875 (false, true) => {
8876 if nulls_first {
8877 std::cmp::Ordering::Greater
8878 } else {
8879 std::cmp::Ordering::Less
8880 }
8881 }
8882 (false, false) => {
8883 let cmp = a.cmp(b);
8884 if desc {
8885 cmp.reverse()
8886 } else {
8887 cmp
8888 }
8889 }
8890 }
8891}
8892
8893fn extract_sort_keys(
8894 rows: &[Vec<Value>],
8895 order_by: &[OrderByItem],
8896 col_map: &ColumnMap,
8897) -> Vec<Vec<Value>> {
8898 rows.iter()
8899 .map(|row| {
8900 order_by
8901 .iter()
8902 .map(|item| eval_expr(&item.expr, col_map, row).unwrap_or(Value::Null))
8903 .collect()
8904 })
8905 .collect()
8906}
8907
8908fn compare_sort_keys(a: &[Value], b: &[Value], order_by: &[OrderByItem]) -> std::cmp::Ordering {
8909 for (i, item) in order_by.iter().enumerate() {
8910 let nulls_first = item.nulls_first.unwrap_or(!item.descending);
8911 let ord = match (a[i].is_null(), b[i].is_null()) {
8912 (true, true) => std::cmp::Ordering::Equal,
8913 (true, false) => {
8914 if nulls_first {
8915 std::cmp::Ordering::Less
8916 } else {
8917 std::cmp::Ordering::Greater
8918 }
8919 }
8920 (false, true) => {
8921 if nulls_first {
8922 std::cmp::Ordering::Greater
8923 } else {
8924 std::cmp::Ordering::Less
8925 }
8926 }
8927 (false, false) => {
8928 let cmp = a[i].cmp(&b[i]);
8929 if item.descending {
8930 cmp.reverse()
8931 } else {
8932 cmp
8933 }
8934 }
8935 };
8936 if ord != std::cmp::Ordering::Equal {
8937 return ord;
8938 }
8939 }
8940 std::cmp::Ordering::Equal
8941}
8942
8943fn try_build_index_map(
8944 select_cols: &[SelectColumn],
8945 columns: &[ColumnDef],
8946) -> Option<Vec<(String, usize)>> {
8947 let col_map = ColumnMap::new(columns);
8948 let mut map = Vec::new();
8949 let mut seen = std::collections::HashSet::new();
8950 for sel in select_cols {
8951 match sel {
8952 SelectColumn::AllColumns => {
8953 for col in columns {
8954 let idx = col.position as usize;
8955 if !seen.insert(idx) {
8956 return None;
8957 }
8958 map.push((col.name.clone(), idx));
8959 }
8960 }
8961 SelectColumn::Expr { expr, alias } => {
8962 let idx = match expr {
8963 Expr::Column(name) => col_map.resolve(name).ok()?,
8964 Expr::QualifiedColumn { table, column } => {
8965 col_map.resolve_qualified(table, column).ok()?
8966 }
8967 _ => return None,
8968 };
8969 if !seen.insert(idx) {
8970 return None;
8971 }
8972 let name = alias.clone().unwrap_or_else(|| expr_display_name(expr));
8973 map.push((name, idx));
8974 }
8975 }
8976 }
8977 Some(map)
8978}
8979
8980fn project_rows(
8981 columns: &[ColumnDef],
8982 select_cols: &[SelectColumn],
8983 mut rows: Vec<Vec<Value>>,
8984) -> Result<(Vec<String>, Vec<Vec<Value>>)> {
8985 if select_cols.len() == 1 && matches!(select_cols[0], SelectColumn::AllColumns) {
8987 let col_names = columns.iter().map(|c| c.name.clone()).collect();
8988 return Ok((col_names, rows));
8989 }
8990
8991 if let Some(map) = try_build_index_map(select_cols, columns) {
8993 let col_names: Vec<String> = map.iter().map(|(n, _)| n.clone()).collect();
8994 if map.len() == columns.len() && map.iter().enumerate().all(|(i, &(_, idx))| idx == i) {
8996 return Ok((col_names, rows));
8997 }
8998 let projected = rows
8999 .iter_mut()
9000 .map(|row| {
9001 map.iter()
9002 .map(|&(_, idx)| std::mem::take(&mut row[idx]))
9003 .collect()
9004 })
9005 .collect();
9006 return Ok((col_names, projected));
9007 }
9008
9009 let mut col_names = Vec::new();
9011 type Projector = Box<dyn Fn(&[Value]) -> Result<Value>>;
9012 let mut projectors: Vec<Projector> = Vec::new();
9013 let col_map = std::sync::Arc::new(ColumnMap::new(columns));
9014
9015 for sel_col in select_cols {
9016 match sel_col {
9017 SelectColumn::AllColumns => {
9018 for col in columns {
9019 let idx = col.position as usize;
9020 col_names.push(col.name.clone());
9021 projectors.push(Box::new(move |row: &[Value]| Ok(row[idx].clone())));
9022 }
9023 }
9024 SelectColumn::Expr { expr, alias } => {
9025 let name = alias.clone().unwrap_or_else(|| expr_display_name(expr));
9026 col_names.push(name);
9027 let expr = expr.clone();
9028 let map = col_map.clone();
9029 projectors.push(Box::new(move |row: &[Value]| eval_expr(&expr, &map, row)));
9030 }
9031 }
9032 }
9033
9034 let projected = rows
9035 .iter()
9036 .map(|row| {
9037 projectors
9038 .iter()
9039 .map(|p| p(row))
9040 .collect::<Result<Vec<_>>>()
9041 })
9042 .collect::<Result<Vec<_>>>()?;
9043
9044 Ok((col_names, projected))
9045}
9046
9047fn expr_display_name(expr: &Expr) -> String {
9048 match expr {
9049 Expr::Column(name) => name.clone(),
9050 Expr::QualifiedColumn { table, column } => format!("{table}.{column}"),
9051 Expr::Literal(v) => format!("{v}"),
9052 Expr::CountStar => "COUNT(*)".into(),
9053 Expr::Function { name, args } => {
9054 let arg_strs: Vec<String> = args.iter().map(expr_display_name).collect();
9055 format!("{name}({})", arg_strs.join(", "))
9056 }
9057 Expr::BinaryOp { left, op, right } => {
9058 format!(
9059 "{} {} {}",
9060 expr_display_name(left),
9061 op_symbol(op),
9062 expr_display_name(right)
9063 )
9064 }
9065 Expr::WindowFunction { name, args, .. } => {
9066 if args.is_empty() {
9067 format!("{name}()")
9068 } else {
9069 let arg_strs: Vec<String> = args.iter().map(expr_display_name).collect();
9070 format!("{name}({})", arg_strs.join(", "))
9071 }
9072 }
9073 _ => "?".into(),
9074 }
9075}
9076
9077fn op_symbol(op: &BinOp) -> &'static str {
9078 match op {
9079 BinOp::Add => "+",
9080 BinOp::Sub => "-",
9081 BinOp::Mul => "*",
9082 BinOp::Div => "/",
9083 BinOp::Mod => "%",
9084 BinOp::Eq => "=",
9085 BinOp::NotEq => "<>",
9086 BinOp::Lt => "<",
9087 BinOp::Gt => ">",
9088 BinOp::LtEq => "<=",
9089 BinOp::GtEq => ">=",
9090 BinOp::And => "AND",
9091 BinOp::Or => "OR",
9092 BinOp::Concat => "||",
9093 }
9094}
9095
9096fn build_output_columns(select_cols: &[SelectColumn], columns: &[ColumnDef]) -> Vec<ColumnDef> {
9097 let mut out = Vec::new();
9098 for (i, col) in select_cols.iter().enumerate() {
9099 let (name, data_type) = match col {
9100 SelectColumn::AllColumns => (format!("col{i}"), DataType::Null),
9101 SelectColumn::Expr {
9102 alias: Some(a),
9103 expr,
9104 } => (a.clone(), infer_expr_type(expr, columns)),
9105 SelectColumn::Expr { expr, .. } => {
9106 (expr_display_name(expr), infer_expr_type(expr, columns))
9107 }
9108 };
9109 out.push(ColumnDef {
9110 name,
9111 data_type,
9112 nullable: true,
9113 position: i as u16,
9114 default_expr: None,
9115 default_sql: None,
9116 check_expr: None,
9117 check_sql: None,
9118 check_name: None,
9119 });
9120 }
9121 out
9122}
9123
9124fn infer_expr_type(expr: &Expr, columns: &[ColumnDef]) -> DataType {
9125 match expr {
9126 Expr::Column(name) => columns
9127 .iter()
9128 .find(|c| c.name == *name)
9129 .map(|c| c.data_type)
9130 .unwrap_or(DataType::Null),
9131 Expr::QualifiedColumn { table, column } => {
9132 let qualified = format!("{table}.{column}");
9133 columns
9134 .iter()
9135 .find(|c| c.name == qualified)
9136 .map(|c| c.data_type)
9137 .unwrap_or(DataType::Null)
9138 }
9139 Expr::Literal(v) => v.data_type(),
9140 Expr::CountStar => DataType::Integer,
9141 Expr::Function { name, .. } => match name.to_ascii_uppercase().as_str() {
9142 "COUNT" => DataType::Integer,
9143 "AVG" => DataType::Real,
9144 "SUM" | "MIN" | "MAX" => DataType::Null,
9145 _ => DataType::Null,
9146 },
9147 _ => DataType::Null,
9148 }
9149}