Skip to main content

citadel_sql/
executor.rs

1//! SQL executor: DDL and DML operations.
2
3use std::collections::{BTreeMap, HashMap};
4
5use citadel::Database;
6
7use crate::encoding::{
8    decode_column_raw, decode_columns, decode_columns_into, decode_composite_key, decode_key_value,
9    decode_pk_integer, decode_pk_into, decode_row_into, encode_composite_key,
10    encode_composite_key_into, encode_row, encode_row_into, RawColumn,
11};
12use crate::error::{Result, SqlError};
13use crate::eval::{eval_expr, is_truthy, referenced_columns, ColumnMap};
14use crate::parser::*;
15use crate::planner::{self, ScanPlan};
16use crate::schema::SchemaManager;
17use crate::types::*;
18
19// ── Index helpers ────────────────────────────────────────────────────
20
21fn encode_index_key(idx: &IndexDef, row: &[Value], pk_values: &[Value]) -> Vec<u8> {
22    let indexed_values: Vec<Value> = idx
23        .columns
24        .iter()
25        .map(|&col_idx| row[col_idx as usize].clone())
26        .collect();
27
28    if idx.unique {
29        let any_null = indexed_values.iter().any(|v| v.is_null());
30        if !any_null {
31            return encode_composite_key(&indexed_values);
32        }
33    }
34
35    let mut all_values = indexed_values;
36    all_values.extend_from_slice(pk_values);
37    encode_composite_key(&all_values)
38}
39
40fn encode_index_value(idx: &IndexDef, row: &[Value], pk_values: &[Value]) -> Vec<u8> {
41    if idx.unique {
42        let indexed_values: Vec<Value> = idx
43            .columns
44            .iter()
45            .map(|&col_idx| row[col_idx as usize].clone())
46            .collect();
47        let any_null = indexed_values.iter().any(|v| v.is_null());
48        if !any_null {
49            return encode_composite_key(pk_values);
50        }
51    }
52    vec![]
53}
54
55fn insert_index_entries(
56    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
57    table_schema: &TableSchema,
58    row: &[Value],
59    pk_values: &[Value],
60) -> Result<()> {
61    for idx in &table_schema.indices {
62        let idx_table = TableSchema::index_table_name(&table_schema.name, &idx.name);
63        let key = encode_index_key(idx, row, pk_values);
64        let value = encode_index_value(idx, row, pk_values);
65
66        let is_new = wtx
67            .table_insert(&idx_table, &key, &value)
68            .map_err(SqlError::Storage)?;
69
70        if idx.unique && !is_new {
71            let indexed_values: Vec<Value> = idx
72                .columns
73                .iter()
74                .map(|&col_idx| row[col_idx as usize].clone())
75                .collect();
76            let any_null = indexed_values.iter().any(|v| v.is_null());
77            if !any_null {
78                return Err(SqlError::UniqueViolation(idx.name.clone()));
79            }
80        }
81    }
82    Ok(())
83}
84
85fn delete_index_entries(
86    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
87    table_schema: &TableSchema,
88    row: &[Value],
89    pk_values: &[Value],
90) -> Result<()> {
91    for idx in &table_schema.indices {
92        let idx_table = TableSchema::index_table_name(&table_schema.name, &idx.name);
93        let key = encode_index_key(idx, row, pk_values);
94        wtx.table_delete(&idx_table, &key)
95            .map_err(SqlError::Storage)?;
96    }
97    Ok(())
98}
99
100fn index_columns_changed(idx: &IndexDef, old_row: &[Value], new_row: &[Value]) -> bool {
101    idx.columns
102        .iter()
103        .any(|&col_idx| old_row[col_idx as usize] != new_row[col_idx as usize])
104}
105
106/// Execute a parsed SQL statement in auto-commit mode.
107pub fn execute(
108    db: &Database,
109    schema: &mut SchemaManager,
110    stmt: &Statement,
111    params: &[Value],
112) -> Result<ExecutionResult> {
113    match stmt {
114        Statement::CreateTable(ct) => exec_create_table(db, schema, ct),
115        Statement::DropTable(dt) => exec_drop_table(db, schema, dt),
116        Statement::CreateIndex(ci) => exec_create_index(db, schema, ci),
117        Statement::DropIndex(di) => exec_drop_index(db, schema, di),
118        Statement::Insert(ins) => exec_insert(db, schema, ins, params),
119        Statement::Select(sel) => exec_select(db, schema, sel),
120        Statement::Update(upd) => exec_update(db, schema, upd),
121        Statement::Delete(del) => exec_delete(db, schema, del),
122        Statement::Explain(inner) => explain(schema, inner),
123        Statement::Begin | Statement::Commit | Statement::Rollback => Err(SqlError::Unsupported(
124            "transaction control in auto-commit mode".into(),
125        )),
126    }
127}
128
129/// Execute a parsed SQL statement within an existing write transaction.
130pub fn execute_in_txn(
131    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
132    schema: &mut SchemaManager,
133    stmt: &Statement,
134    params: &[Value],
135) -> Result<ExecutionResult> {
136    match stmt {
137        Statement::CreateTable(ct) => exec_create_table_in_txn(wtx, schema, ct),
138        Statement::DropTable(dt) => exec_drop_table_in_txn(wtx, schema, dt),
139        Statement::CreateIndex(ci) => exec_create_index_in_txn(wtx, schema, ci),
140        Statement::DropIndex(di) => exec_drop_index_in_txn(wtx, schema, di),
141        Statement::Insert(ins) => {
142            let mut bufs = InsertBufs::new();
143            exec_insert_in_txn(wtx, schema, ins, params, &mut bufs)
144        }
145        Statement::Select(sel) => exec_select_in_txn(wtx, schema, sel),
146        Statement::Update(upd) => exec_update_in_txn(wtx, schema, upd),
147        Statement::Delete(del) => exec_delete_in_txn(wtx, schema, del),
148        Statement::Explain(inner) => explain(schema, inner),
149        Statement::Begin | Statement::Commit | Statement::Rollback => {
150            Err(SqlError::Unsupported("nested transaction control".into()))
151        }
152    }
153}
154
155// ── EXPLAIN ──────────────────────────────────────────────────────────
156
157pub fn explain(schema: &SchemaManager, stmt: &Statement) -> Result<ExecutionResult> {
158    let lines = match stmt {
159        Statement::Select(sel) => explain_select(schema, sel)?,
160        Statement::Insert(ins) => {
161            vec![format!("INSERT INTO {}", ins.table.to_ascii_lowercase())]
162        }
163        Statement::Update(upd) => explain_dml(schema, &upd.table, &upd.where_clause, "UPDATE")?,
164        Statement::Delete(del) => {
165            explain_dml(schema, &del.table, &del.where_clause, "DELETE FROM")?
166        }
167        Statement::Explain(_) => {
168            return Err(SqlError::Unsupported("EXPLAIN EXPLAIN".into()));
169        }
170        _ => {
171            return Err(SqlError::Unsupported(
172                "EXPLAIN for this statement type".into(),
173            ));
174        }
175    };
176
177    let rows = lines
178        .into_iter()
179        .map(|line| vec![Value::Text(line.into())])
180        .collect();
181    Ok(ExecutionResult::Query(QueryResult {
182        columns: vec!["plan".into()],
183        rows,
184    }))
185}
186
187fn explain_dml(
188    schema: &SchemaManager,
189    table: &str,
190    where_clause: &Option<Expr>,
191    verb: &str,
192) -> Result<Vec<String>> {
193    let lower = table.to_ascii_lowercase();
194    let table_schema = schema
195        .get(&lower)
196        .ok_or_else(|| SqlError::TableNotFound(table.to_string()))?;
197    let plan = planner::plan_select(table_schema, where_clause);
198    let scan_line = format_scan_line(&lower, &None, &plan, table_schema);
199    Ok(vec![format!("{verb} {}", scan_line)])
200}
201
202fn explain_select(schema: &SchemaManager, stmt: &SelectStmt) -> Result<Vec<String>> {
203    let mut lines = Vec::new();
204
205    if stmt.from.is_empty() {
206        lines.push("CONSTANT ROW".into());
207        return Ok(lines);
208    }
209
210    let lower_from = stmt.from.to_ascii_lowercase();
211    let from_schema = schema
212        .get(&lower_from)
213        .ok_or_else(|| SqlError::TableNotFound(stmt.from.clone()))?;
214
215    if stmt.joins.is_empty() {
216        let plan = planner::plan_select(from_schema, &stmt.where_clause);
217        lines.push(format_scan_line(
218            &lower_from,
219            &stmt.from_alias,
220            &plan,
221            from_schema,
222        ));
223    } else {
224        let from_plan = planner::plan_select(from_schema, &None);
225        lines.push(format_scan_line(
226            &lower_from,
227            &stmt.from_alias,
228            &from_plan,
229            from_schema,
230        ));
231
232        for join in &stmt.joins {
233            let inner_lower = join.table.name.to_ascii_lowercase();
234            let inner_schema = schema
235                .get(&inner_lower)
236                .ok_or_else(|| SqlError::TableNotFound(join.table.name.clone()))?;
237            let inner_plan = planner::plan_select(inner_schema, &None);
238            lines.push(format_scan_line(
239                &inner_lower,
240                &join.table.alias,
241                &inner_plan,
242                inner_schema,
243            ));
244        }
245
246        let join_type_str = match stmt.joins.last().map(|j| j.join_type) {
247            Some(JoinType::Left) => "LEFT JOIN",
248            Some(JoinType::Right) => "RIGHT JOIN",
249            Some(JoinType::Cross) => "CROSS JOIN",
250            _ => "NESTED LOOP",
251        };
252        lines.push(join_type_str.into());
253    }
254
255    if stmt.where_clause.is_some() && stmt.joins.is_empty() {
256        let plan = planner::plan_select(from_schema, &stmt.where_clause);
257        if matches!(plan, ScanPlan::SeqScan) {
258            lines.push("FILTER".into());
259        }
260    }
261
262    if let Some(ref w) = stmt.where_clause {
263        if !stmt.joins.is_empty() && has_subquery(w) {
264            lines.push("SUBQUERY".into());
265        }
266    }
267
268    explain_subqueries(stmt, &mut lines);
269
270    if !stmt.group_by.is_empty() {
271        lines.push("GROUP BY".into());
272    }
273
274    if stmt.distinct {
275        lines.push("DISTINCT".into());
276    }
277
278    if !stmt.order_by.is_empty() {
279        lines.push("SORT".into());
280    }
281
282    if let Some(ref offset_expr) = stmt.offset {
283        if let Ok(n) = eval_const_int(offset_expr) {
284            lines.push(format!("OFFSET {n}"));
285        } else {
286            lines.push("OFFSET".into());
287        }
288    }
289
290    if let Some(ref limit_expr) = stmt.limit {
291        if let Ok(n) = eval_const_int(limit_expr) {
292            lines.push(format!("LIMIT {n}"));
293        } else {
294            lines.push("LIMIT".into());
295        }
296    }
297
298    Ok(lines)
299}
300
301fn explain_subqueries(stmt: &SelectStmt, lines: &mut Vec<String>) {
302    let mut count = 0;
303    if let Some(ref w) = stmt.where_clause {
304        count += count_subqueries(w);
305    }
306    if let Some(ref h) = stmt.having {
307        count += count_subqueries(h);
308    }
309    for col in &stmt.columns {
310        if let SelectColumn::Expr { expr, .. } = col {
311            count += count_subqueries(expr);
312        }
313    }
314    for _ in 0..count {
315        lines.push("SUBQUERY".into());
316    }
317}
318
319fn count_subqueries(expr: &Expr) -> usize {
320    match expr {
321        Expr::InSubquery { expr: e, .. } => 1 + count_subqueries(e),
322        Expr::ScalarSubquery(_) => 1,
323        Expr::Exists { .. } => 1,
324        Expr::BinaryOp { left, right, .. } => count_subqueries(left) + count_subqueries(right),
325        Expr::UnaryOp { expr: e, .. } => count_subqueries(e),
326        Expr::IsNull(e) | Expr::IsNotNull(e) => count_subqueries(e),
327        Expr::Function { args, .. } => args.iter().map(count_subqueries).sum(),
328        Expr::Between {
329            expr: e, low, high, ..
330        } => count_subqueries(e) + count_subqueries(low) + count_subqueries(high),
331        Expr::Like {
332            expr: e, pattern, ..
333        } => count_subqueries(e) + count_subqueries(pattern),
334        Expr::Case {
335            operand,
336            conditions,
337            else_result,
338        } => {
339            let mut n = 0;
340            if let Some(op) = operand {
341                n += count_subqueries(op);
342            }
343            for (c, r) in conditions {
344                n += count_subqueries(c) + count_subqueries(r);
345            }
346            if let Some(el) = else_result {
347                n += count_subqueries(el);
348            }
349            n
350        }
351        Expr::Coalesce(args) => args.iter().map(count_subqueries).sum(),
352        Expr::Cast { expr: e, .. } => count_subqueries(e),
353        Expr::InList { expr: e, list, .. } => {
354            count_subqueries(e) + list.iter().map(count_subqueries).sum::<usize>()
355        }
356        _ => 0,
357    }
358}
359
360fn format_scan_line(
361    table_name: &str,
362    alias: &Option<String>,
363    plan: &ScanPlan,
364    table_schema: &TableSchema,
365) -> String {
366    let alias_part = match alias {
367        Some(a) if !a.eq_ignore_ascii_case(table_name) => {
368            format!(" AS {}", a.to_ascii_lowercase())
369        }
370        _ => String::new(),
371    };
372
373    let desc = planner::describe_plan(plan, table_schema);
374
375    if desc.is_empty() {
376        format!("SCAN TABLE {table_name}{alias_part}")
377    } else {
378        format!("SEARCH TABLE {table_name}{alias_part} {desc}")
379    }
380}
381
382// ── DDL ─────────────────────────────────────────────────────────────
383
384fn exec_create_table(
385    db: &Database,
386    schema: &mut SchemaManager,
387    stmt: &CreateTableStmt,
388) -> Result<ExecutionResult> {
389    let lower_name = stmt.name.to_ascii_lowercase();
390
391    if schema.contains(&lower_name) {
392        if stmt.if_not_exists {
393            return Ok(ExecutionResult::Ok);
394        }
395        return Err(SqlError::TableAlreadyExists(stmt.name.clone()));
396    }
397
398    if stmt.primary_key.is_empty() {
399        return Err(SqlError::PrimaryKeyRequired);
400    }
401
402    let mut seen = std::collections::HashSet::new();
403    for col in &stmt.columns {
404        let lower = col.name.to_ascii_lowercase();
405        if !seen.insert(lower.clone()) {
406            return Err(SqlError::DuplicateColumn(col.name.clone()));
407        }
408    }
409
410    let columns: Vec<ColumnDef> = stmt
411        .columns
412        .iter()
413        .enumerate()
414        .map(|(i, c)| ColumnDef {
415            name: c.name.to_ascii_lowercase(),
416            data_type: c.data_type,
417            nullable: c.nullable,
418            position: i as u16,
419        })
420        .collect();
421
422    let primary_key_columns: Vec<u16> = stmt
423        .primary_key
424        .iter()
425        .map(|pk_name| {
426            let lower = pk_name.to_ascii_lowercase();
427            columns
428                .iter()
429                .position(|c| c.name == lower)
430                .map(|i| i as u16)
431                .ok_or_else(|| SqlError::ColumnNotFound(pk_name.clone()))
432        })
433        .collect::<Result<_>>()?;
434
435    let table_schema = TableSchema::new(lower_name.clone(), columns, primary_key_columns, vec![]);
436
437    let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
438    SchemaManager::ensure_schema_table(&mut wtx)?;
439    wtx.create_table(lower_name.as_bytes())
440        .map_err(SqlError::Storage)?;
441    SchemaManager::save_schema(&mut wtx, &table_schema)?;
442    wtx.commit().map_err(SqlError::Storage)?;
443
444    schema.register(table_schema);
445    Ok(ExecutionResult::Ok)
446}
447
448fn exec_drop_table(
449    db: &Database,
450    schema: &mut SchemaManager,
451    stmt: &DropTableStmt,
452) -> Result<ExecutionResult> {
453    let lower_name = stmt.name.to_ascii_lowercase();
454
455    if !schema.contains(&lower_name) {
456        if stmt.if_exists {
457            return Ok(ExecutionResult::Ok);
458        }
459        return Err(SqlError::TableNotFound(stmt.name.clone()));
460    }
461
462    let table_schema = schema.get(&lower_name).unwrap();
463    let idx_tables: Vec<Vec<u8>> = table_schema
464        .indices
465        .iter()
466        .map(|idx| TableSchema::index_table_name(&lower_name, &idx.name))
467        .collect();
468
469    let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
470    for idx_table in &idx_tables {
471        wtx.drop_table(idx_table).map_err(SqlError::Storage)?;
472    }
473    wtx.drop_table(lower_name.as_bytes())
474        .map_err(SqlError::Storage)?;
475    SchemaManager::delete_schema(&mut wtx, &lower_name)?;
476    wtx.commit().map_err(SqlError::Storage)?;
477
478    schema.remove(&lower_name);
479    Ok(ExecutionResult::Ok)
480}
481
482fn exec_create_table_in_txn(
483    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
484    schema: &mut SchemaManager,
485    stmt: &CreateTableStmt,
486) -> Result<ExecutionResult> {
487    let lower_name = stmt.name.to_ascii_lowercase();
488
489    if schema.contains(&lower_name) {
490        if stmt.if_not_exists {
491            return Ok(ExecutionResult::Ok);
492        }
493        return Err(SqlError::TableAlreadyExists(stmt.name.clone()));
494    }
495
496    if stmt.primary_key.is_empty() {
497        return Err(SqlError::PrimaryKeyRequired);
498    }
499
500    let mut seen = std::collections::HashSet::new();
501    for col in &stmt.columns {
502        let lower = col.name.to_ascii_lowercase();
503        if !seen.insert(lower.clone()) {
504            return Err(SqlError::DuplicateColumn(col.name.clone()));
505        }
506    }
507
508    let columns: Vec<ColumnDef> = stmt
509        .columns
510        .iter()
511        .enumerate()
512        .map(|(i, c)| ColumnDef {
513            name: c.name.to_ascii_lowercase(),
514            data_type: c.data_type,
515            nullable: c.nullable,
516            position: i as u16,
517        })
518        .collect();
519
520    let primary_key_columns: Vec<u16> = stmt
521        .primary_key
522        .iter()
523        .map(|pk_name| {
524            let lower = pk_name.to_ascii_lowercase();
525            columns
526                .iter()
527                .position(|c| c.name == lower)
528                .map(|i| i as u16)
529                .ok_or_else(|| SqlError::ColumnNotFound(pk_name.clone()))
530        })
531        .collect::<Result<_>>()?;
532
533    let table_schema = TableSchema::new(lower_name.clone(), columns, primary_key_columns, vec![]);
534
535    SchemaManager::ensure_schema_table(wtx)?;
536    wtx.create_table(lower_name.as_bytes())
537        .map_err(SqlError::Storage)?;
538    SchemaManager::save_schema(wtx, &table_schema)?;
539
540    schema.register(table_schema);
541    Ok(ExecutionResult::Ok)
542}
543
544fn exec_drop_table_in_txn(
545    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
546    schema: &mut SchemaManager,
547    stmt: &DropTableStmt,
548) -> Result<ExecutionResult> {
549    let lower_name = stmt.name.to_ascii_lowercase();
550
551    if !schema.contains(&lower_name) {
552        if stmt.if_exists {
553            return Ok(ExecutionResult::Ok);
554        }
555        return Err(SqlError::TableNotFound(stmt.name.clone()));
556    }
557
558    let table_schema = schema.get(&lower_name).unwrap();
559    let idx_tables: Vec<Vec<u8>> = table_schema
560        .indices
561        .iter()
562        .map(|idx| TableSchema::index_table_name(&lower_name, &idx.name))
563        .collect();
564
565    for idx_table in &idx_tables {
566        wtx.drop_table(idx_table).map_err(SqlError::Storage)?;
567    }
568    wtx.drop_table(lower_name.as_bytes())
569        .map_err(SqlError::Storage)?;
570    SchemaManager::delete_schema(wtx, &lower_name)?;
571
572    schema.remove(&lower_name);
573    Ok(ExecutionResult::Ok)
574}
575
576fn exec_create_index(
577    db: &Database,
578    schema: &mut SchemaManager,
579    stmt: &CreateIndexStmt,
580) -> Result<ExecutionResult> {
581    let lower_table = stmt.table_name.to_ascii_lowercase();
582    let lower_idx = stmt.index_name.to_ascii_lowercase();
583
584    let table_schema = schema
585        .get(&lower_table)
586        .ok_or_else(|| SqlError::TableNotFound(stmt.table_name.clone()))?;
587
588    if table_schema.index_by_name(&lower_idx).is_some() {
589        if stmt.if_not_exists {
590            return Ok(ExecutionResult::Ok);
591        }
592        return Err(SqlError::IndexAlreadyExists(stmt.index_name.clone()));
593    }
594
595    let col_indices: Vec<u16> = stmt
596        .columns
597        .iter()
598        .map(|col_name| {
599            let lower = col_name.to_ascii_lowercase();
600            table_schema
601                .column_index(&lower)
602                .map(|i| i as u16)
603                .ok_or_else(|| SqlError::ColumnNotFound(col_name.clone()))
604        })
605        .collect::<Result<_>>()?;
606
607    let idx_def = IndexDef {
608        name: lower_idx.clone(),
609        columns: col_indices,
610        unique: stmt.unique,
611    };
612
613    let idx_table = TableSchema::index_table_name(&lower_table, &lower_idx);
614
615    let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
616    SchemaManager::ensure_schema_table(&mut wtx)?;
617    wtx.create_table(&idx_table).map_err(SqlError::Storage)?;
618
619    // Populate index from existing rows
620    let pk_indices = table_schema.pk_indices();
621    let mut rows: Vec<Vec<Value>> = Vec::new();
622    {
623        let mut scan_err: Option<SqlError> = None;
624        wtx.table_for_each(lower_table.as_bytes(), |key, value| {
625            match decode_full_row(table_schema, key, value) {
626                Ok(row) => rows.push(row),
627                Err(e) => scan_err = Some(e),
628            }
629            Ok(())
630        })
631        .map_err(SqlError::Storage)?;
632        if let Some(e) = scan_err {
633            return Err(e);
634        }
635    }
636
637    for row in &rows {
638        let pk_values: Vec<Value> = pk_indices.iter().map(|&i| row[i].clone()).collect();
639        let key = encode_index_key(&idx_def, row, &pk_values);
640        let value = encode_index_value(&idx_def, row, &pk_values);
641        let is_new = wtx
642            .table_insert(&idx_table, &key, &value)
643            .map_err(SqlError::Storage)?;
644        if idx_def.unique && !is_new {
645            let indexed_values: Vec<Value> = idx_def
646                .columns
647                .iter()
648                .map(|&col_idx| row[col_idx as usize].clone())
649                .collect();
650            let any_null = indexed_values.iter().any(|v| v.is_null());
651            if !any_null {
652                return Err(SqlError::UniqueViolation(stmt.index_name.clone()));
653            }
654        }
655    }
656
657    let mut updated_schema = table_schema.clone();
658    updated_schema.indices.push(idx_def);
659    SchemaManager::save_schema(&mut wtx, &updated_schema)?;
660    wtx.commit().map_err(SqlError::Storage)?;
661
662    schema.register(updated_schema);
663    Ok(ExecutionResult::Ok)
664}
665
666fn exec_drop_index(
667    db: &Database,
668    schema: &mut SchemaManager,
669    stmt: &DropIndexStmt,
670) -> Result<ExecutionResult> {
671    let lower_idx = stmt.index_name.to_ascii_lowercase();
672
673    let (table_name, _idx_pos) = match find_index_in_schemas(schema, &lower_idx) {
674        Some(found) => found,
675        None => {
676            if stmt.if_exists {
677                return Ok(ExecutionResult::Ok);
678            }
679            return Err(SqlError::IndexNotFound(stmt.index_name.clone()));
680        }
681    };
682
683    let idx_table = TableSchema::index_table_name(&table_name, &lower_idx);
684
685    let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
686    wtx.drop_table(&idx_table).map_err(SqlError::Storage)?;
687
688    let table_schema = schema.get(&table_name).unwrap();
689    let mut updated_schema = table_schema.clone();
690    updated_schema.indices.retain(|i| i.name != lower_idx);
691    SchemaManager::save_schema(&mut wtx, &updated_schema)?;
692    wtx.commit().map_err(SqlError::Storage)?;
693
694    schema.register(updated_schema);
695    Ok(ExecutionResult::Ok)
696}
697
698fn exec_create_index_in_txn(
699    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
700    schema: &mut SchemaManager,
701    stmt: &CreateIndexStmt,
702) -> Result<ExecutionResult> {
703    let lower_table = stmt.table_name.to_ascii_lowercase();
704    let lower_idx = stmt.index_name.to_ascii_lowercase();
705
706    let table_schema = schema
707        .get(&lower_table)
708        .ok_or_else(|| SqlError::TableNotFound(stmt.table_name.clone()))?;
709
710    if table_schema.index_by_name(&lower_idx).is_some() {
711        if stmt.if_not_exists {
712            return Ok(ExecutionResult::Ok);
713        }
714        return Err(SqlError::IndexAlreadyExists(stmt.index_name.clone()));
715    }
716
717    let col_indices: Vec<u16> = stmt
718        .columns
719        .iter()
720        .map(|col_name| {
721            let lower = col_name.to_ascii_lowercase();
722            table_schema
723                .column_index(&lower)
724                .map(|i| i as u16)
725                .ok_or_else(|| SqlError::ColumnNotFound(col_name.clone()))
726        })
727        .collect::<Result<_>>()?;
728
729    let idx_def = IndexDef {
730        name: lower_idx.clone(),
731        columns: col_indices,
732        unique: stmt.unique,
733    };
734
735    let idx_table = TableSchema::index_table_name(&lower_table, &lower_idx);
736
737    SchemaManager::ensure_schema_table(wtx)?;
738    wtx.create_table(&idx_table).map_err(SqlError::Storage)?;
739
740    let pk_indices = table_schema.pk_indices();
741    let mut rows: Vec<Vec<Value>> = Vec::new();
742    {
743        let mut scan_err: Option<SqlError> = None;
744        wtx.table_for_each(lower_table.as_bytes(), |key, value| {
745            match decode_full_row(table_schema, key, value) {
746                Ok(row) => rows.push(row),
747                Err(e) => scan_err = Some(e),
748            }
749            Ok(())
750        })
751        .map_err(SqlError::Storage)?;
752        if let Some(e) = scan_err {
753            return Err(e);
754        }
755    }
756
757    for row in &rows {
758        let pk_values: Vec<Value> = pk_indices.iter().map(|&i| row[i].clone()).collect();
759        let key = encode_index_key(&idx_def, row, &pk_values);
760        let value = encode_index_value(&idx_def, row, &pk_values);
761        let is_new = wtx
762            .table_insert(&idx_table, &key, &value)
763            .map_err(SqlError::Storage)?;
764        if idx_def.unique && !is_new {
765            let indexed_values: Vec<Value> = idx_def
766                .columns
767                .iter()
768                .map(|&col_idx| row[col_idx as usize].clone())
769                .collect();
770            let any_null = indexed_values.iter().any(|v| v.is_null());
771            if !any_null {
772                return Err(SqlError::UniqueViolation(stmt.index_name.clone()));
773            }
774        }
775    }
776
777    let mut updated_schema = table_schema.clone();
778    updated_schema.indices.push(idx_def);
779    SchemaManager::save_schema(wtx, &updated_schema)?;
780
781    schema.register(updated_schema);
782    Ok(ExecutionResult::Ok)
783}
784
785fn exec_drop_index_in_txn(
786    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
787    schema: &mut SchemaManager,
788    stmt: &DropIndexStmt,
789) -> Result<ExecutionResult> {
790    let lower_idx = stmt.index_name.to_ascii_lowercase();
791
792    let (table_name, _idx_pos) = match find_index_in_schemas(schema, &lower_idx) {
793        Some(found) => found,
794        None => {
795            if stmt.if_exists {
796                return Ok(ExecutionResult::Ok);
797            }
798            return Err(SqlError::IndexNotFound(stmt.index_name.clone()));
799        }
800    };
801
802    let idx_table = TableSchema::index_table_name(&table_name, &lower_idx);
803    wtx.drop_table(&idx_table).map_err(SqlError::Storage)?;
804
805    let table_schema = schema.get(&table_name).unwrap();
806    let mut updated_schema = table_schema.clone();
807    updated_schema.indices.retain(|i| i.name != lower_idx);
808    SchemaManager::save_schema(wtx, &updated_schema)?;
809
810    schema.register(updated_schema);
811    Ok(ExecutionResult::Ok)
812}
813
814fn find_index_in_schemas(schema: &SchemaManager, index_name: &str) -> Option<(String, usize)> {
815    for table_name in schema.table_names() {
816        if let Some(ts) = schema.get(table_name) {
817            if let Some(pos) = ts.indices.iter().position(|i| i.name == index_name) {
818                return Some((table_name.to_string(), pos));
819            }
820        }
821    }
822    None
823}
824
825// ── Index scan helpers ───────────────────────────────────────────────
826
827fn extract_pk_key(
828    idx_key: &[u8],
829    idx_value: &[u8],
830    is_unique: bool,
831    num_index_cols: usize,
832    num_pk_cols: usize,
833) -> Result<Vec<u8>> {
834    if is_unique && !idx_value.is_empty() {
835        Ok(idx_value.to_vec())
836    } else {
837        let total_cols = num_index_cols + num_pk_cols;
838        let all_values = decode_composite_key(idx_key, total_cols)?;
839        let pk_values = &all_values[num_index_cols..];
840        Ok(encode_composite_key(pk_values))
841    }
842}
843
844fn check_range_conditions(
845    idx_key: &[u8],
846    num_prefix_cols: usize,
847    range_conds: &[(BinOp, Value)],
848    num_index_cols: usize,
849) -> Result<RangeCheck> {
850    if range_conds.is_empty() {
851        return Ok(RangeCheck::Match);
852    }
853
854    let num_to_decode = num_prefix_cols + 1;
855    if num_to_decode > num_index_cols {
856        return Ok(RangeCheck::Match);
857    }
858
859    // Decode just enough columns to check the range column
860    let mut pos = 0;
861    for _ in 0..num_prefix_cols {
862        let (_, n) = decode_key_value(&idx_key[pos..])?;
863        pos += n;
864    }
865    let (range_val, _) = decode_key_value(&idx_key[pos..])?;
866
867    let mut exceeds_upper = false;
868    let mut below_lower = false;
869
870    for (op, val) in range_conds {
871        match op {
872            BinOp::Lt => {
873                if range_val >= *val {
874                    exceeds_upper = true;
875                }
876            }
877            BinOp::LtEq => {
878                if range_val > *val {
879                    exceeds_upper = true;
880                }
881            }
882            BinOp::Gt => {
883                if range_val <= *val {
884                    below_lower = true;
885                }
886            }
887            BinOp::GtEq => {
888                if range_val < *val {
889                    below_lower = true;
890                }
891            }
892            _ => {}
893        }
894    }
895
896    if exceeds_upper {
897        Ok(RangeCheck::ExceedsUpper)
898    } else if below_lower {
899        Ok(RangeCheck::BelowLower)
900    } else {
901        Ok(RangeCheck::Match)
902    }
903}
904
/// Outcome of `check_range_conditions` for one index key.
///
/// The standard traits are derived so results can be compared with `==`,
/// copied freely and printed in diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RangeCheck {
    /// Every bound is satisfied, or there was nothing to check.
    Match,
    /// A lower bound (`>` / `>=`) is violated.
    BelowLower,
    /// An upper bound (`<` / `<=`) is violated. This is reported in
    /// preference to `BelowLower`.
    ExceedsUpper,
}
910
911/// Collect rows via ReadTxn using the scan plan.
912fn collect_rows_read(
913    db: &Database,
914    table_schema: &TableSchema,
915    where_clause: &Option<Expr>,
916    limit: Option<usize>,
917) -> Result<(Vec<Vec<Value>>, bool)> {
918    let plan = planner::plan_select(table_schema, where_clause);
919    let lower_name = &table_schema.name;
920    let columns = &table_schema.columns;
921
922    match plan {
923        ScanPlan::SeqScan => {
924            let simple_pred = where_clause
925                .as_ref()
926                .and_then(|expr| try_simple_predicate(expr, table_schema));
927
928            if let Some(ref pred) = simple_pred {
929                let mut rtx = db.begin_read();
930                let entry_count =
931                    rtx.table_entry_count(lower_name.as_bytes()).unwrap_or(0) as usize;
932                let mut rows = Vec::with_capacity(entry_count / 4);
933                let mut scan_err: Option<SqlError> = None;
934                rtx.table_scan_raw(lower_name.as_bytes(), |key, value| {
935                    match pred.matches_raw(key, value) {
936                        Ok(true) => match decode_full_row(table_schema, key, value) {
937                            Ok(row) => rows.push(row),
938                            Err(e) => {
939                                scan_err = Some(e);
940                                return false;
941                            }
942                        },
943                        Ok(false) => {}
944                        Err(e) => {
945                            scan_err = Some(e);
946                            return false;
947                        }
948                    }
949                    scan_err.is_none() && limit.map_or(true, |n| rows.len() < n)
950                })
951                .map_err(SqlError::Storage)?;
952                if let Some(e) = scan_err {
953                    return Err(e);
954                }
955                return Ok((rows, true));
956            }
957
958            let mut rtx = db.begin_read();
959            let entry_count = rtx.table_entry_count(lower_name.as_bytes()).unwrap_or(0) as usize;
960            let capacity = if where_clause.is_some() {
961                entry_count / 4
962            } else {
963                entry_count
964            };
965            let mut rows = Vec::with_capacity(capacity);
966            let mut scan_err: Option<SqlError> = None;
967
968            let col_map = ColumnMap::new(columns);
969            let partial_ctx = where_clause.as_ref().and_then(|expr| {
970                let needed = referenced_columns(expr, columns);
971                if needed.len() < columns.len() {
972                    Some(PartialDecodeCtx::new(table_schema, &needed))
973                } else {
974                    None
975                }
976            });
977
978            rtx.table_scan_from(lower_name.as_bytes(), b"", |key, value| {
979                match (&where_clause, &partial_ctx) {
980                    (Some(expr), Some(ctx)) => match ctx.decode(key, value) {
981                        Ok(partial) => match eval_expr(expr, &col_map, &partial) {
982                            Ok(val) if is_truthy(&val) => match ctx.complete(partial, key, value) {
983                                Ok(row) => rows.push(row),
984                                Err(e) => scan_err = Some(e),
985                            },
986                            Err(e) => scan_err = Some(e),
987                            _ => {}
988                        },
989                        Err(e) => scan_err = Some(e),
990                    },
991                    (Some(expr), None) => match decode_full_row(table_schema, key, value) {
992                        Ok(row) => match eval_expr(expr, &col_map, &row) {
993                            Ok(val) if is_truthy(&val) => rows.push(row),
994                            Err(e) => scan_err = Some(e),
995                            _ => {}
996                        },
997                        Err(e) => scan_err = Some(e),
998                    },
999                    _ => match decode_full_row(table_schema, key, value) {
1000                        Ok(row) => rows.push(row),
1001                        Err(e) => scan_err = Some(e),
1002                    },
1003                }
1004                let keep_going = scan_err.is_none() && limit.map_or(true, |n| rows.len() < n);
1005                Ok(keep_going)
1006            })
1007            .map_err(SqlError::Storage)?;
1008            if let Some(e) = scan_err {
1009                return Err(e);
1010            }
1011            Ok((rows, where_clause.is_some()))
1012        }
1013
1014        ScanPlan::PkLookup { pk_values } => {
1015            let key = encode_composite_key(&pk_values);
1016            let mut rtx = db.begin_read();
1017            match rtx
1018                .table_get(lower_name.as_bytes(), &key)
1019                .map_err(SqlError::Storage)?
1020            {
1021                Some(value) => {
1022                    let row = decode_full_row(table_schema, &key, &value)?;
1023                    if let Some(ref expr) = where_clause {
1024                        let col_map = ColumnMap::new(columns);
1025                        match eval_expr(expr, &col_map, &row) {
1026                            Ok(val) if is_truthy(&val) => Ok((vec![row], true)),
1027                            _ => Ok((vec![], true)),
1028                        }
1029                    } else {
1030                        Ok((vec![row], false))
1031                    }
1032                }
1033                None => Ok((vec![], true)),
1034            }
1035        }
1036
1037        ScanPlan::IndexScan {
1038            idx_table,
1039            prefix,
1040            num_prefix_cols,
1041            range_conds,
1042            is_unique,
1043            index_columns,
1044            ..
1045        } => {
1046            let num_pk_cols = table_schema.primary_key_columns.len();
1047            let num_index_cols = index_columns.len();
1048            let mut pk_keys: Vec<Vec<u8>> = Vec::new();
1049
1050            {
1051                let mut rtx = db.begin_read();
1052                let mut scan_err: Option<SqlError> = None;
1053                rtx.table_scan_from(&idx_table, &prefix, |key, value| {
1054                    if !key.starts_with(&prefix) {
1055                        return Ok(false);
1056                    }
1057                    match check_range_conditions(key, num_prefix_cols, &range_conds, num_index_cols)
1058                    {
1059                        Ok(RangeCheck::ExceedsUpper) => return Ok(false),
1060                        Ok(RangeCheck::BelowLower) => return Ok(true),
1061                        Ok(RangeCheck::Match) => {}
1062                        Err(e) => {
1063                            scan_err = Some(e);
1064                            return Ok(false);
1065                        }
1066                    }
1067                    match extract_pk_key(key, value, is_unique, num_index_cols, num_pk_cols) {
1068                        Ok(pk) => pk_keys.push(pk),
1069                        Err(e) => {
1070                            scan_err = Some(e);
1071                            return Ok(false);
1072                        }
1073                    }
1074                    Ok(true)
1075                })
1076                .map_err(SqlError::Storage)?;
1077                if let Some(e) = scan_err {
1078                    return Err(e);
1079                }
1080            }
1081
1082            let mut rows = Vec::new();
1083            let mut rtx = db.begin_read();
1084            let col_map = ColumnMap::new(columns);
1085            for pk_key in &pk_keys {
1086                if let Some(value) = rtx
1087                    .table_get(lower_name.as_bytes(), pk_key)
1088                    .map_err(SqlError::Storage)?
1089                {
1090                    let row = decode_full_row(table_schema, pk_key, &value)?;
1091                    if let Some(ref expr) = where_clause {
1092                        match eval_expr(expr, &col_map, &row) {
1093                            Ok(val) if is_truthy(&val) => rows.push(row),
1094                            _ => {}
1095                        }
1096                    } else {
1097                        rows.push(row);
1098                    }
1099                }
1100            }
1101            Ok((rows, where_clause.is_some()))
1102        }
1103    }
1104}
1105
1106/// Collect rows via WriteTxn using the scan plan.
1107fn collect_rows_write(
1108    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1109    table_schema: &TableSchema,
1110    where_clause: &Option<Expr>,
1111    limit: Option<usize>,
1112) -> Result<(Vec<Vec<Value>>, bool)> {
1113    let plan = planner::plan_select(table_schema, where_clause);
1114    let lower_name = &table_schema.name;
1115    let columns = &table_schema.columns;
1116
1117    match plan {
1118        ScanPlan::SeqScan => {
1119            let simple_pred = where_clause
1120                .as_ref()
1121                .and_then(|expr| try_simple_predicate(expr, table_schema));
1122
1123            if let Some(ref pred) = simple_pred {
1124                let mut rows = Vec::new();
1125                let mut scan_err: Option<SqlError> = None;
1126                wtx.table_scan_from(lower_name.as_bytes(), b"", |key, value| {
1127                    match pred.matches_raw(key, value) {
1128                        Ok(true) => match decode_full_row(table_schema, key, value) {
1129                            Ok(row) => rows.push(row),
1130                            Err(e) => scan_err = Some(e),
1131                        },
1132                        Ok(false) => {}
1133                        Err(e) => scan_err = Some(e),
1134                    }
1135                    let keep_going = scan_err.is_none() && limit.map_or(true, |n| rows.len() < n);
1136                    Ok(keep_going)
1137                })
1138                .map_err(SqlError::Storage)?;
1139                if let Some(e) = scan_err {
1140                    return Err(e);
1141                }
1142                return Ok((rows, true));
1143            }
1144
1145            let mut rows = Vec::new();
1146            let mut scan_err: Option<SqlError> = None;
1147
1148            let col_map = ColumnMap::new(columns);
1149            let partial_ctx = where_clause.as_ref().and_then(|expr| {
1150                let needed = referenced_columns(expr, columns);
1151                if needed.len() < columns.len() {
1152                    Some(PartialDecodeCtx::new(table_schema, &needed))
1153                } else {
1154                    None
1155                }
1156            });
1157
1158            wtx.table_scan_from(lower_name.as_bytes(), b"", |key, value| {
1159                match (&where_clause, &partial_ctx) {
1160                    (Some(expr), Some(ctx)) => match ctx.decode(key, value) {
1161                        Ok(partial) => match eval_expr(expr, &col_map, &partial) {
1162                            Ok(val) if is_truthy(&val) => match ctx.complete(partial, key, value) {
1163                                Ok(row) => rows.push(row),
1164                                Err(e) => scan_err = Some(e),
1165                            },
1166                            Err(e) => scan_err = Some(e),
1167                            _ => {}
1168                        },
1169                        Err(e) => scan_err = Some(e),
1170                    },
1171                    (Some(expr), None) => match decode_full_row(table_schema, key, value) {
1172                        Ok(row) => match eval_expr(expr, &col_map, &row) {
1173                            Ok(val) if is_truthy(&val) => rows.push(row),
1174                            Err(e) => scan_err = Some(e),
1175                            _ => {}
1176                        },
1177                        Err(e) => scan_err = Some(e),
1178                    },
1179                    _ => match decode_full_row(table_schema, key, value) {
1180                        Ok(row) => rows.push(row),
1181                        Err(e) => scan_err = Some(e),
1182                    },
1183                }
1184                let keep_going = scan_err.is_none() && limit.map_or(true, |n| rows.len() < n);
1185                Ok(keep_going)
1186            })
1187            .map_err(SqlError::Storage)?;
1188            if let Some(e) = scan_err {
1189                return Err(e);
1190            }
1191            Ok((rows, where_clause.is_some()))
1192        }
1193
1194        ScanPlan::PkLookup { pk_values } => {
1195            let key = encode_composite_key(&pk_values);
1196            match wtx
1197                .table_get(lower_name.as_bytes(), &key)
1198                .map_err(SqlError::Storage)?
1199            {
1200                Some(value) => {
1201                    let row = decode_full_row(table_schema, &key, &value)?;
1202                    if let Some(ref expr) = where_clause {
1203                        let col_map = ColumnMap::new(columns);
1204                        match eval_expr(expr, &col_map, &row) {
1205                            Ok(val) if is_truthy(&val) => Ok((vec![row], true)),
1206                            _ => Ok((vec![], true)),
1207                        }
1208                    } else {
1209                        Ok((vec![row], false))
1210                    }
1211                }
1212                None => Ok((vec![], true)),
1213            }
1214        }
1215
1216        ScanPlan::IndexScan {
1217            idx_table,
1218            prefix,
1219            num_prefix_cols,
1220            range_conds,
1221            is_unique,
1222            index_columns,
1223            ..
1224        } => {
1225            let num_pk_cols = table_schema.primary_key_columns.len();
1226            let num_index_cols = index_columns.len();
1227            let mut pk_keys: Vec<Vec<u8>> = Vec::new();
1228
1229            {
1230                let mut scan_err: Option<SqlError> = None;
1231                wtx.table_scan_from(&idx_table, &prefix, |key, value| {
1232                    if !key.starts_with(&prefix) {
1233                        return Ok(false);
1234                    }
1235                    match check_range_conditions(key, num_prefix_cols, &range_conds, num_index_cols)
1236                    {
1237                        Ok(RangeCheck::ExceedsUpper) => return Ok(false),
1238                        Ok(RangeCheck::BelowLower) => return Ok(true),
1239                        Ok(RangeCheck::Match) => {}
1240                        Err(e) => {
1241                            scan_err = Some(e);
1242                            return Ok(false);
1243                        }
1244                    }
1245                    match extract_pk_key(key, value, is_unique, num_index_cols, num_pk_cols) {
1246                        Ok(pk) => pk_keys.push(pk),
1247                        Err(e) => {
1248                            scan_err = Some(e);
1249                            return Ok(false);
1250                        }
1251                    }
1252                    Ok(true)
1253                })
1254                .map_err(SqlError::Storage)?;
1255                if let Some(e) = scan_err {
1256                    return Err(e);
1257                }
1258            }
1259
1260            let mut rows = Vec::new();
1261            let col_map = ColumnMap::new(columns);
1262            for pk_key in &pk_keys {
1263                if let Some(value) = wtx
1264                    .table_get(lower_name.as_bytes(), pk_key)
1265                    .map_err(SqlError::Storage)?
1266                {
1267                    let row = decode_full_row(table_schema, pk_key, &value)?;
1268                    if let Some(ref expr) = where_clause {
1269                        match eval_expr(expr, &col_map, &row) {
1270                            Ok(val) if is_truthy(&val) => rows.push(row),
1271                            _ => {}
1272                        }
1273                    } else {
1274                        rows.push(row);
1275                    }
1276                }
1277            }
1278            Ok((rows, where_clause.is_some()))
1279        }
1280    }
1281}
1282
1283/// Collect (encoded_key, full_row) pairs via ReadTxn using the scan plan.
1284/// Used by DELETE and UPDATE which need the encoded PK key.
1285fn collect_keyed_rows_read(
1286    db: &Database,
1287    table_schema: &TableSchema,
1288    where_clause: &Option<Expr>,
1289) -> Result<Vec<(Vec<u8>, Vec<Value>)>> {
1290    let plan = planner::plan_select(table_schema, where_clause);
1291    let lower_name = &table_schema.name;
1292
1293    match plan {
1294        ScanPlan::SeqScan => {
1295            let mut rows = Vec::new();
1296            let mut rtx = db.begin_read();
1297            let mut scan_err: Option<SqlError> = None;
1298            rtx.table_for_each(lower_name.as_bytes(), |key, value| {
1299                match decode_full_row(table_schema, key, value) {
1300                    Ok(row) => rows.push((key.to_vec(), row)),
1301                    Err(e) => scan_err = Some(e),
1302                }
1303                Ok(())
1304            })
1305            .map_err(SqlError::Storage)?;
1306            if let Some(e) = scan_err {
1307                return Err(e);
1308            }
1309            Ok(rows)
1310        }
1311
1312        ScanPlan::PkLookup { pk_values } => {
1313            let key = encode_composite_key(&pk_values);
1314            let mut rtx = db.begin_read();
1315            match rtx
1316                .table_get(lower_name.as_bytes(), &key)
1317                .map_err(SqlError::Storage)?
1318            {
1319                Some(value) => {
1320                    let row = decode_full_row(table_schema, &key, &value)?;
1321                    Ok(vec![(key, row)])
1322                }
1323                None => Ok(vec![]),
1324            }
1325        }
1326
1327        ScanPlan::IndexScan {
1328            idx_table,
1329            prefix,
1330            num_prefix_cols,
1331            range_conds,
1332            is_unique,
1333            index_columns,
1334            ..
1335        } => {
1336            let num_pk_cols = table_schema.primary_key_columns.len();
1337            let num_index_cols = index_columns.len();
1338            let mut pk_keys: Vec<Vec<u8>> = Vec::new();
1339
1340            {
1341                let mut rtx = db.begin_read();
1342                let mut scan_err: Option<SqlError> = None;
1343                rtx.table_scan_from(&idx_table, &prefix, |key, value| {
1344                    if !key.starts_with(&prefix) {
1345                        return Ok(false);
1346                    }
1347                    match check_range_conditions(key, num_prefix_cols, &range_conds, num_index_cols)
1348                    {
1349                        Ok(RangeCheck::ExceedsUpper) => return Ok(false),
1350                        Ok(RangeCheck::BelowLower) => return Ok(true),
1351                        Ok(RangeCheck::Match) => {}
1352                        Err(e) => {
1353                            scan_err = Some(e);
1354                            return Ok(false);
1355                        }
1356                    }
1357                    match extract_pk_key(key, value, is_unique, num_index_cols, num_pk_cols) {
1358                        Ok(pk) => pk_keys.push(pk),
1359                        Err(e) => {
1360                            scan_err = Some(e);
1361                            return Ok(false);
1362                        }
1363                    }
1364                    Ok(true)
1365                })
1366                .map_err(SqlError::Storage)?;
1367                if let Some(e) = scan_err {
1368                    return Err(e);
1369                }
1370            }
1371
1372            let mut rows = Vec::new();
1373            let mut rtx = db.begin_read();
1374            for pk_key in &pk_keys {
1375                if let Some(value) = rtx
1376                    .table_get(lower_name.as_bytes(), pk_key)
1377                    .map_err(SqlError::Storage)?
1378                {
1379                    rows.push((
1380                        pk_key.clone(),
1381                        decode_full_row(table_schema, pk_key, &value)?,
1382                    ));
1383                }
1384            }
1385            Ok(rows)
1386        }
1387    }
1388}
1389
1390/// Collect (encoded_key, full_row) pairs via WriteTxn using the scan plan.
1391fn collect_keyed_rows_write(
1392    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1393    table_schema: &TableSchema,
1394    where_clause: &Option<Expr>,
1395) -> Result<Vec<(Vec<u8>, Vec<Value>)>> {
1396    let plan = planner::plan_select(table_schema, where_clause);
1397    let lower_name = &table_schema.name;
1398
1399    match plan {
1400        ScanPlan::SeqScan => {
1401            let mut rows = Vec::new();
1402            let mut scan_err: Option<SqlError> = None;
1403            wtx.table_for_each(lower_name.as_bytes(), |key, value| {
1404                match decode_full_row(table_schema, key, value) {
1405                    Ok(row) => rows.push((key.to_vec(), row)),
1406                    Err(e) => scan_err = Some(e),
1407                }
1408                Ok(())
1409            })
1410            .map_err(SqlError::Storage)?;
1411            if let Some(e) = scan_err {
1412                return Err(e);
1413            }
1414            Ok(rows)
1415        }
1416
1417        ScanPlan::PkLookup { pk_values } => {
1418            let key = encode_composite_key(&pk_values);
1419            match wtx
1420                .table_get(lower_name.as_bytes(), &key)
1421                .map_err(SqlError::Storage)?
1422            {
1423                Some(value) => {
1424                    let row = decode_full_row(table_schema, &key, &value)?;
1425                    Ok(vec![(key, row)])
1426                }
1427                None => Ok(vec![]),
1428            }
1429        }
1430
1431        ScanPlan::IndexScan {
1432            idx_table,
1433            prefix,
1434            num_prefix_cols,
1435            range_conds,
1436            is_unique,
1437            index_columns,
1438            ..
1439        } => {
1440            let num_pk_cols = table_schema.primary_key_columns.len();
1441            let num_index_cols = index_columns.len();
1442            let mut pk_keys: Vec<Vec<u8>> = Vec::new();
1443
1444            {
1445                let mut scan_err: Option<SqlError> = None;
1446                wtx.table_scan_from(&idx_table, &prefix, |key, value| {
1447                    if !key.starts_with(&prefix) {
1448                        return Ok(false);
1449                    }
1450                    match check_range_conditions(key, num_prefix_cols, &range_conds, num_index_cols)
1451                    {
1452                        Ok(RangeCheck::ExceedsUpper) => return Ok(false),
1453                        Ok(RangeCheck::BelowLower) => return Ok(true),
1454                        Ok(RangeCheck::Match) => {}
1455                        Err(e) => {
1456                            scan_err = Some(e);
1457                            return Ok(false);
1458                        }
1459                    }
1460                    match extract_pk_key(key, value, is_unique, num_index_cols, num_pk_cols) {
1461                        Ok(pk) => pk_keys.push(pk),
1462                        Err(e) => {
1463                            scan_err = Some(e);
1464                            return Ok(false);
1465                        }
1466                    }
1467                    Ok(true)
1468                })
1469                .map_err(SqlError::Storage)?;
1470                if let Some(e) = scan_err {
1471                    return Err(e);
1472                }
1473            }
1474
1475            let mut rows = Vec::new();
1476            for pk_key in &pk_keys {
1477                if let Some(value) = wtx
1478                    .table_get(lower_name.as_bytes(), pk_key)
1479                    .map_err(SqlError::Storage)?
1480                {
1481                    rows.push((
1482                        pk_key.clone(),
1483                        decode_full_row(table_schema, pk_key, &value)?,
1484                    ));
1485                }
1486            }
1487            Ok(rows)
1488        }
1489    }
1490}
1491
1492// ── DML ─────────────────────────────────────────────────────────────
1493
1494fn exec_insert(
1495    db: &Database,
1496    schema: &SchemaManager,
1497    stmt: &InsertStmt,
1498    params: &[Value],
1499) -> Result<ExecutionResult> {
1500    let materialized;
1501    let stmt = if insert_has_subquery(stmt) {
1502        materialized = materialize_insert(stmt, &mut |sub| exec_subquery_read(db, schema, sub))?;
1503        &materialized
1504    } else {
1505        stmt
1506    };
1507
1508    let lower_name = stmt.table.to_ascii_lowercase();
1509    let table_schema = schema
1510        .get(&lower_name)
1511        .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
1512
1513    let insert_columns = if stmt.columns.is_empty() {
1514        table_schema
1515            .columns
1516            .iter()
1517            .map(|c| c.name.clone())
1518            .collect::<Vec<_>>()
1519    } else {
1520        stmt.columns
1521            .iter()
1522            .map(|c| c.to_ascii_lowercase())
1523            .collect()
1524    };
1525
1526    let col_indices: Vec<usize> = insert_columns
1527        .iter()
1528        .map(|name| {
1529            table_schema
1530                .column_index(name)
1531                .ok_or_else(|| SqlError::ColumnNotFound(name.clone()))
1532        })
1533        .collect::<Result<_>>()?;
1534
1535    let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
1536    let mut count: u64 = 0;
1537
1538    let pk_indices = table_schema.pk_indices();
1539    let non_pk = table_schema.non_pk_indices();
1540    let mut row = vec![Value::Null; table_schema.columns.len()];
1541    let mut pk_values: Vec<Value> = vec![Value::Null; pk_indices.len()];
1542    let mut value_values: Vec<Value> = vec![Value::Null; non_pk.len()];
1543    let mut key_buf: Vec<u8> = Vec::with_capacity(64);
1544    let mut value_buf: Vec<u8> = Vec::with_capacity(256);
1545
1546    for value_row in &stmt.values {
1547        if value_row.len() != insert_columns.len() {
1548            return Err(SqlError::InvalidValue(format!(
1549                "expected {} values, got {}",
1550                insert_columns.len(),
1551                value_row.len()
1552            )));
1553        }
1554
1555        for v in row.iter_mut() {
1556            *v = Value::Null;
1557        }
1558
1559        for (i, expr) in value_row.iter().enumerate() {
1560            let val = if let Expr::Parameter(n) = expr {
1561                params
1562                    .get(n - 1)
1563                    .cloned()
1564                    .ok_or_else(|| SqlError::Parse(format!("unbound parameter ${n}")))?
1565            } else {
1566                eval_const_expr(expr)?
1567            };
1568            let col_idx = col_indices[i];
1569            let col = &table_schema.columns[col_idx];
1570
1571            let got_type = val.data_type();
1572            row[col_idx] = if val.is_null() {
1573                Value::Null
1574            } else {
1575                val.coerce_into(col.data_type)
1576                    .ok_or_else(|| SqlError::TypeMismatch {
1577                        expected: col.data_type.to_string(),
1578                        got: got_type.to_string(),
1579                    })?
1580            };
1581        }
1582
1583        for col in &table_schema.columns {
1584            if !col.nullable && row[col.position as usize].is_null() {
1585                return Err(SqlError::NotNullViolation(col.name.clone()));
1586            }
1587        }
1588
1589        for (j, &i) in pk_indices.iter().enumerate() {
1590            pk_values[j] = std::mem::replace(&mut row[i], Value::Null);
1591        }
1592        encode_composite_key_into(&pk_values, &mut key_buf);
1593
1594        for (j, &i) in non_pk.iter().enumerate() {
1595            value_values[j] = std::mem::replace(&mut row[i], Value::Null);
1596        }
1597        encode_row_into(&value_values, &mut value_buf);
1598
1599        if key_buf.len() > citadel_core::MAX_KEY_SIZE {
1600            return Err(SqlError::KeyTooLarge {
1601                size: key_buf.len(),
1602                max: citadel_core::MAX_KEY_SIZE,
1603            });
1604        }
1605        if value_buf.len() > citadel_core::MAX_INLINE_VALUE_SIZE {
1606            return Err(SqlError::RowTooLarge {
1607                size: value_buf.len(),
1608                max: citadel_core::MAX_INLINE_VALUE_SIZE,
1609            });
1610        }
1611
1612        let is_new = wtx
1613            .table_insert(stmt.table.as_bytes(), &key_buf, &value_buf)
1614            .map_err(SqlError::Storage)?;
1615        if !is_new {
1616            return Err(SqlError::DuplicateKey);
1617        }
1618
1619        if !table_schema.indices.is_empty() {
1620            for (j, &i) in pk_indices.iter().enumerate() {
1621                row[i] = pk_values[j].clone();
1622            }
1623            for (j, &i) in non_pk.iter().enumerate() {
1624                row[i] = std::mem::replace(&mut value_values[j], Value::Null);
1625            }
1626            insert_index_entries(&mut wtx, table_schema, &row, &pk_values)?;
1627        }
1628        count += 1;
1629    }
1630
1631    wtx.commit().map_err(SqlError::Storage)?;
1632    Ok(ExecutionResult::RowsAffected(count))
1633}
1634
1635fn has_subquery(expr: &Expr) -> bool {
1636    match expr {
1637        Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::ScalarSubquery(_) => true,
1638        Expr::BinaryOp { left, right, .. } => has_subquery(left) || has_subquery(right),
1639        Expr::UnaryOp { expr, .. } => has_subquery(expr),
1640        Expr::IsNull(e) | Expr::IsNotNull(e) => has_subquery(e),
1641        Expr::InList { expr, list, .. } => has_subquery(expr) || list.iter().any(has_subquery),
1642        Expr::InSet { expr, .. } => has_subquery(expr),
1643        Expr::Between {
1644            expr, low, high, ..
1645        } => has_subquery(expr) || has_subquery(low) || has_subquery(high),
1646        Expr::Like {
1647            expr,
1648            pattern,
1649            escape,
1650            ..
1651        } => {
1652            has_subquery(expr)
1653                || has_subquery(pattern)
1654                || escape.as_ref().is_some_and(|e| has_subquery(e))
1655        }
1656        Expr::Case {
1657            operand,
1658            conditions,
1659            else_result,
1660        } => {
1661            operand.as_ref().is_some_and(|e| has_subquery(e))
1662                || conditions
1663                    .iter()
1664                    .any(|(c, r)| has_subquery(c) || has_subquery(r))
1665                || else_result.as_ref().is_some_and(|e| has_subquery(e))
1666        }
1667        Expr::Coalesce(args) => args.iter().any(has_subquery),
1668        Expr::Cast { expr, .. } => has_subquery(expr),
1669        Expr::Function { args, .. } => args.iter().any(has_subquery),
1670        _ => false,
1671    }
1672}
1673
1674fn stmt_has_subquery(stmt: &SelectStmt) -> bool {
1675    if let Some(ref w) = stmt.where_clause {
1676        if has_subquery(w) {
1677            return true;
1678        }
1679    }
1680    if let Some(ref h) = stmt.having {
1681        if has_subquery(h) {
1682            return true;
1683        }
1684    }
1685    for col in &stmt.columns {
1686        if let SelectColumn::Expr { expr, .. } = col {
1687            if has_subquery(expr) {
1688                return true;
1689            }
1690        }
1691    }
1692    for ob in &stmt.order_by {
1693        if has_subquery(&ob.expr) {
1694            return true;
1695        }
1696    }
1697    for join in &stmt.joins {
1698        if let Some(ref on_expr) = join.on_clause {
1699            if has_subquery(on_expr) {
1700                return true;
1701            }
1702        }
1703    }
1704    false
1705}
1706
1707fn materialize_expr(
1708    expr: &Expr,
1709    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
1710) -> Result<Expr> {
1711    match expr {
1712        Expr::InSubquery {
1713            expr: e,
1714            subquery,
1715            negated,
1716        } => {
1717            let inner = materialize_expr(e, exec_sub)?;
1718            let qr = exec_sub(subquery)?;
1719            if !qr.columns.is_empty() && qr.columns.len() != 1 {
1720                return Err(SqlError::SubqueryMultipleColumns);
1721            }
1722            let mut values = std::collections::HashSet::new();
1723            let mut has_null = false;
1724            for row in &qr.rows {
1725                if row[0].is_null() {
1726                    has_null = true;
1727                } else {
1728                    values.insert(row[0].clone());
1729                }
1730            }
1731            Ok(Expr::InSet {
1732                expr: Box::new(inner),
1733                values,
1734                has_null,
1735                negated: *negated,
1736            })
1737        }
1738        Expr::ScalarSubquery(subquery) => {
1739            let qr = exec_sub(subquery)?;
1740            if qr.rows.len() > 1 {
1741                return Err(SqlError::SubqueryMultipleRows);
1742            }
1743            let val = if qr.rows.is_empty() {
1744                Value::Null
1745            } else {
1746                qr.rows[0][0].clone()
1747            };
1748            Ok(Expr::Literal(val))
1749        }
1750        Expr::Exists { subquery, negated } => {
1751            let qr = exec_sub(subquery)?;
1752            let exists = !qr.rows.is_empty();
1753            let result = if *negated { !exists } else { exists };
1754            Ok(Expr::Literal(Value::Boolean(result)))
1755        }
1756        Expr::InList {
1757            expr: e,
1758            list,
1759            negated,
1760        } => {
1761            let inner = materialize_expr(e, exec_sub)?;
1762            let items = list
1763                .iter()
1764                .map(|item| materialize_expr(item, exec_sub))
1765                .collect::<Result<Vec<_>>>()?;
1766            Ok(Expr::InList {
1767                expr: Box::new(inner),
1768                list: items,
1769                negated: *negated,
1770            })
1771        }
1772        Expr::BinaryOp { left, op, right } => Ok(Expr::BinaryOp {
1773            left: Box::new(materialize_expr(left, exec_sub)?),
1774            op: *op,
1775            right: Box::new(materialize_expr(right, exec_sub)?),
1776        }),
1777        Expr::UnaryOp { op, expr: e } => Ok(Expr::UnaryOp {
1778            op: *op,
1779            expr: Box::new(materialize_expr(e, exec_sub)?),
1780        }),
1781        Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(materialize_expr(e, exec_sub)?))),
1782        Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Box::new(materialize_expr(e, exec_sub)?))),
1783        Expr::InSet {
1784            expr: e,
1785            values,
1786            has_null,
1787            negated,
1788        } => Ok(Expr::InSet {
1789            expr: Box::new(materialize_expr(e, exec_sub)?),
1790            values: values.clone(),
1791            has_null: *has_null,
1792            negated: *negated,
1793        }),
1794        Expr::Between {
1795            expr: e,
1796            low,
1797            high,
1798            negated,
1799        } => Ok(Expr::Between {
1800            expr: Box::new(materialize_expr(e, exec_sub)?),
1801            low: Box::new(materialize_expr(low, exec_sub)?),
1802            high: Box::new(materialize_expr(high, exec_sub)?),
1803            negated: *negated,
1804        }),
1805        Expr::Like {
1806            expr: e,
1807            pattern,
1808            escape,
1809            negated,
1810        } => {
1811            let esc = escape
1812                .as_ref()
1813                .map(|es| materialize_expr(es, exec_sub).map(Box::new))
1814                .transpose()?;
1815            Ok(Expr::Like {
1816                expr: Box::new(materialize_expr(e, exec_sub)?),
1817                pattern: Box::new(materialize_expr(pattern, exec_sub)?),
1818                escape: esc,
1819                negated: *negated,
1820            })
1821        }
1822        Expr::Case {
1823            operand,
1824            conditions,
1825            else_result,
1826        } => {
1827            let op = operand
1828                .as_ref()
1829                .map(|e| materialize_expr(e, exec_sub).map(Box::new))
1830                .transpose()?;
1831            let conds = conditions
1832                .iter()
1833                .map(|(c, r)| {
1834                    Ok((
1835                        materialize_expr(c, exec_sub)?,
1836                        materialize_expr(r, exec_sub)?,
1837                    ))
1838                })
1839                .collect::<Result<Vec<_>>>()?;
1840            let else_r = else_result
1841                .as_ref()
1842                .map(|e| materialize_expr(e, exec_sub).map(Box::new))
1843                .transpose()?;
1844            Ok(Expr::Case {
1845                operand: op,
1846                conditions: conds,
1847                else_result: else_r,
1848            })
1849        }
1850        Expr::Coalesce(args) => {
1851            let materialized = args
1852                .iter()
1853                .map(|a| materialize_expr(a, exec_sub))
1854                .collect::<Result<Vec<_>>>()?;
1855            Ok(Expr::Coalesce(materialized))
1856        }
1857        Expr::Cast { expr: e, data_type } => Ok(Expr::Cast {
1858            expr: Box::new(materialize_expr(e, exec_sub)?),
1859            data_type: *data_type,
1860        }),
1861        Expr::Function { name, args } => {
1862            let materialized = args
1863                .iter()
1864                .map(|a| materialize_expr(a, exec_sub))
1865                .collect::<Result<Vec<_>>>()?;
1866            Ok(Expr::Function {
1867                name: name.clone(),
1868                args: materialized,
1869            })
1870        }
1871        other => Ok(other.clone()),
1872    }
1873}
1874
1875fn materialize_stmt(
1876    stmt: &SelectStmt,
1877    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
1878) -> Result<SelectStmt> {
1879    let where_clause = stmt
1880        .where_clause
1881        .as_ref()
1882        .map(|e| materialize_expr(e, exec_sub))
1883        .transpose()?;
1884    let having = stmt
1885        .having
1886        .as_ref()
1887        .map(|e| materialize_expr(e, exec_sub))
1888        .transpose()?;
1889    let columns = stmt
1890        .columns
1891        .iter()
1892        .map(|c| match c {
1893            SelectColumn::AllColumns => Ok(SelectColumn::AllColumns),
1894            SelectColumn::Expr { expr, alias } => Ok(SelectColumn::Expr {
1895                expr: materialize_expr(expr, exec_sub)?,
1896                alias: alias.clone(),
1897            }),
1898        })
1899        .collect::<Result<Vec<_>>>()?;
1900    let order_by = stmt
1901        .order_by
1902        .iter()
1903        .map(|ob| {
1904            Ok(OrderByItem {
1905                expr: materialize_expr(&ob.expr, exec_sub)?,
1906                descending: ob.descending,
1907                nulls_first: ob.nulls_first,
1908            })
1909        })
1910        .collect::<Result<Vec<_>>>()?;
1911    let joins = stmt
1912        .joins
1913        .iter()
1914        .map(|j| {
1915            let on_clause = j
1916                .on_clause
1917                .as_ref()
1918                .map(|e| materialize_expr(e, exec_sub))
1919                .transpose()?;
1920            Ok(JoinClause {
1921                join_type: j.join_type,
1922                table: j.table.clone(),
1923                on_clause,
1924            })
1925        })
1926        .collect::<Result<Vec<_>>>()?;
1927    let group_by = stmt
1928        .group_by
1929        .iter()
1930        .map(|e| materialize_expr(e, exec_sub))
1931        .collect::<Result<Vec<_>>>()?;
1932    Ok(SelectStmt {
1933        columns,
1934        from: stmt.from.clone(),
1935        from_alias: stmt.from_alias.clone(),
1936        joins,
1937        distinct: stmt.distinct,
1938        where_clause,
1939        order_by,
1940        limit: stmt.limit.clone(),
1941        offset: stmt.offset.clone(),
1942        group_by,
1943        having,
1944    })
1945}
1946
1947fn exec_subquery_read(
1948    db: &Database,
1949    schema: &SchemaManager,
1950    stmt: &SelectStmt,
1951) -> Result<QueryResult> {
1952    match exec_select(db, schema, stmt)? {
1953        ExecutionResult::Query(qr) => Ok(qr),
1954        _ => Ok(QueryResult {
1955            columns: vec![],
1956            rows: vec![],
1957        }),
1958    }
1959}
1960
1961fn exec_subquery_write(
1962    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
1963    schema: &SchemaManager,
1964    stmt: &SelectStmt,
1965) -> Result<QueryResult> {
1966    match exec_select_in_txn(wtx, schema, stmt)? {
1967        ExecutionResult::Query(qr) => Ok(qr),
1968        _ => Ok(QueryResult {
1969            columns: vec![],
1970            rows: vec![],
1971        }),
1972    }
1973}
1974
1975fn update_has_subquery(stmt: &UpdateStmt) -> bool {
1976    stmt.where_clause.as_ref().is_some_and(has_subquery)
1977        || stmt.assignments.iter().any(|(_, e)| has_subquery(e))
1978}
1979
1980fn materialize_update(
1981    stmt: &UpdateStmt,
1982    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
1983) -> Result<UpdateStmt> {
1984    let where_clause = stmt
1985        .where_clause
1986        .as_ref()
1987        .map(|e| materialize_expr(e, exec_sub))
1988        .transpose()?;
1989    let assignments = stmt
1990        .assignments
1991        .iter()
1992        .map(|(name, expr)| Ok((name.clone(), materialize_expr(expr, exec_sub)?)))
1993        .collect::<Result<Vec<_>>>()?;
1994    Ok(UpdateStmt {
1995        table: stmt.table.clone(),
1996        assignments,
1997        where_clause,
1998    })
1999}
2000
2001fn delete_has_subquery(stmt: &DeleteStmt) -> bool {
2002    stmt.where_clause.as_ref().is_some_and(has_subquery)
2003}
2004
2005fn materialize_delete(
2006    stmt: &DeleteStmt,
2007    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
2008) -> Result<DeleteStmt> {
2009    let where_clause = stmt
2010        .where_clause
2011        .as_ref()
2012        .map(|e| materialize_expr(e, exec_sub))
2013        .transpose()?;
2014    Ok(DeleteStmt {
2015        table: stmt.table.clone(),
2016        where_clause,
2017    })
2018}
2019
2020fn insert_has_subquery(stmt: &InsertStmt) -> bool {
2021    stmt.values.iter().any(|row| row.iter().any(has_subquery))
2022}
2023
2024fn materialize_insert(
2025    stmt: &InsertStmt,
2026    exec_sub: &mut dyn FnMut(&SelectStmt) -> Result<QueryResult>,
2027) -> Result<InsertStmt> {
2028    let values = stmt
2029        .values
2030        .iter()
2031        .map(|row| {
2032            row.iter()
2033                .map(|e| materialize_expr(e, exec_sub))
2034                .collect::<Result<Vec<_>>>()
2035        })
2036        .collect::<Result<Vec<_>>>()?;
2037    Ok(InsertStmt {
2038        table: stmt.table.clone(),
2039        columns: stmt.columns.clone(),
2040        values,
2041    })
2042}
2043
/// Executes a `SELECT` against read transactions opened on `db`.
///
/// Subqueries are materialized first (each runs through `exec_subquery_read`).
/// Execution then tries increasingly general strategies and returns from the
/// first one that applies:
///
/// 1. no `FROM` table -> `exec_select_no_from`;
/// 2. any `JOIN` -> `exec_select_join`;
/// 3. bare `COUNT(*)` -> `try_count_star_shortcut`, answered from the table's
///    entry count;
/// 4. `StreamAggPlan` (streaming aggregates);
/// 5. `StreamGroupByPlan`;
/// 6. `TopKScanPlan`;
/// 7. otherwise a row scan (bounded by `compute_scan_limit` when possible)
///    followed by `process_select`.
///
/// # Errors
/// `SqlError::TableNotFound` when the `FROM` table is not in `schema` (checked
/// before the join dispatch). Storage errors from scans are wrapped in
/// `SqlError::Storage`. Any error raised by the chosen strategy is also
/// returned.
fn exec_select(
    db: &Database,
    schema: &SchemaManager,
    stmt: &SelectStmt,
) -> Result<ExecutionResult> {
    // Declared outside the `if` so the rewritten statement outlives the borrow
    // taken by the shadowing `stmt` below.
    let materialized;
    let stmt = if stmt_has_subquery(stmt) {
        materialized = materialize_stmt(stmt, &mut |sub| exec_subquery_read(db, schema, sub))?;
        &materialized
    } else {
        stmt
    };

    if stmt.from.is_empty() {
        return exec_select_no_from(stmt);
    }

    // Schema and storage lookups use the lowercased table name.
    let lower_name = stmt.from.to_ascii_lowercase();
    let table_schema = schema
        .get(&lower_name)
        .ok_or_else(|| SqlError::TableNotFound(stmt.from.clone()))?;

    if !stmt.joins.is_empty() {
        return exec_select_join(db, schema, stmt);
    }

    // The closure (and its read transaction) only runs if the shortcut applies.
    if let Some(result) = try_count_star_shortcut(stmt, || {
        let mut rtx = db.begin_read();
        rtx.table_entry_count(lower_name.as_bytes())
            .map_err(SqlError::Storage)
    })? {
        return Ok(result);
    }

    if let Some(plan) = StreamAggPlan::try_new(stmt, table_schema)? {
        let mut states: Vec<AggState> = plan.ops.iter().map(|(op, _)| AggState::new(op)).collect();
        // Scan callbacks return `bool`, so they cannot return a `Result`. They
        // record their error here and return `false` (presumably stopping the
        // scan); the error is re-raised after the scan.
        let mut scan_err: Option<SqlError> = None;
        let mut rtx = db.begin_read();
        if stmt.where_clause.is_none() {
            // No predicate: aggregate directly from raw key/value bytes.
            rtx.table_scan_raw(lower_name.as_bytes(), |key, value| {
                plan.feed_row_raw(key, value, &mut states, &mut scan_err)
            })
            .map_err(SqlError::Storage)?;
        } else {
            // Predicate present: rows are decoded so WHERE can be evaluated.
            let col_map = ColumnMap::new(&table_schema.columns);
            rtx.table_scan_raw(lower_name.as_bytes(), |key, value| {
                plan.feed_row(
                    key,
                    value,
                    table_schema,
                    &col_map,
                    &stmt.where_clause,
                    &mut states,
                    &mut scan_err,
                )
            })
            .map_err(SqlError::Storage)?;
        }
        if let Some(e) = scan_err {
            return Err(e);
        }
        return Ok(plan.finish(states));
    }

    if let Some(plan) = StreamGroupByPlan::try_new(stmt, table_schema)? {
        let lower = lower_name.clone();
        let mut rtx = db.begin_read();
        return plan
            .execute_scan(|cb| rtx.table_scan_raw(lower.as_bytes(), |key, value| cb(key, value)));
    }

    if let Some(plan) = TopKScanPlan::try_new(stmt, table_schema)? {
        let lower = lower_name.clone();
        let mut rtx = db.begin_read();
        return plan.execute_scan(table_schema, stmt, |cb| {
            rtx.table_scan_raw(lower.as_bytes(), |key, value| cb(key, value))
        });
    }

    // General path. `predicate_applied` presumably tells `process_select`
    // whether `collect_rows_read` already filtered by WHERE — confirm there.
    let scan_limit = compute_scan_limit(stmt);
    let (rows, predicate_applied) =
        collect_rows_read(db, table_schema, &stmt.where_clause, scan_limit)?;
    process_select(&table_schema.columns, rows, stmt, predicate_applied)
}
2128
2129fn compute_scan_limit(stmt: &SelectStmt) -> Option<usize> {
2130    if !stmt.order_by.is_empty()
2131        || !stmt.group_by.is_empty()
2132        || stmt.distinct
2133        || stmt.having.is_some()
2134    {
2135        return None;
2136    }
2137    let has_aggregates = stmt.columns.iter().any(|c| match c {
2138        SelectColumn::Expr { expr, .. } => is_aggregate_expr(expr),
2139        _ => false,
2140    });
2141    if has_aggregates {
2142        return None;
2143    }
2144    let limit = stmt.limit.as_ref()?;
2145    let limit_val = eval_const_int(limit).ok()?.max(0) as usize;
2146    let offset_val = stmt
2147        .offset
2148        .as_ref()
2149        .and_then(|e| eval_const_int(e).ok())
2150        .unwrap_or(0)
2151        .max(0) as usize;
2152    Some(limit_val.saturating_add(offset_val))
2153}
2154
2155fn try_count_star_shortcut(
2156    stmt: &SelectStmt,
2157    get_count: impl FnOnce() -> Result<u64>,
2158) -> Result<Option<ExecutionResult>> {
2159    if stmt.columns.len() != 1
2160        || stmt.where_clause.is_some()
2161        || !stmt.group_by.is_empty()
2162        || stmt.having.is_some()
2163    {
2164        return Ok(None);
2165    }
2166    let col = match &stmt.columns[0] {
2167        SelectColumn::Expr { expr, alias } => (expr, alias),
2168        _ => return Ok(None),
2169    };
2170    if !matches!(col.0, Expr::CountStar) {
2171        return Ok(None);
2172    }
2173    let count = get_count()? as i64;
2174    let col_name = col.1.as_deref().unwrap_or("COUNT(*)").to_string();
2175    Ok(Some(ExecutionResult::Query(QueryResult {
2176        columns: vec![col_name],
2177        rows: vec![vec![Value::Integer(count)]],
2178    })))
2179}
2180
/// One aggregate output column of a `StreamAggPlan`.
///
/// The `usize` payload is the aggregated column's index into the table's
/// column list, as resolved through `ColumnMap`. It is used directly to index
/// a decoded row.
enum StreamAgg {
    /// `COUNT(*)`: counts every row fed to it.
    CountStar,
    /// `COUNT(col)`: counts non-NULL values.
    Count(usize),
    /// `SUM(col)`.
    Sum(usize),
    /// `AVG(col)`.
    Avg(usize),
    /// `MIN(col)`.
    Min(usize),
    /// `MAX(col)`.
    Max(usize),
}
2189
/// Where the raw (no-WHERE) aggregation path reads an aggregate's input from.
enum RawAggTarget {
    /// `COUNT(*)`: no column is read.
    CountStar,
    /// Position of the column within the table's primary-key columns; the
    /// value is decoded from the row key.
    Pk(usize),
    /// Position of the column within `non_pk_indices()`; the value is decoded
    /// from the stored row value with `decode_column_raw`.
    NonPk(usize),
}
2195
/// Running accumulator for one streaming aggregate.
enum AggState {
    /// Number of rows seen.
    CountStar(i64),
    /// Number of non-NULL values seen.
    Count(i64),
    /// SUM keeps integer and real inputs apart. The result is NULL while
    /// `all_null` is set, Real once any Real input was seen (`has_real`), and
    /// Integer otherwise.
    Sum {
        int_sum: i64,
        real_sum: f64,
        has_real: bool,
        all_null: bool,
    },
    /// AVG accumulates in `f64`; the result is NULL when `count` is 0.
    Avg {
        sum: f64,
        count: i64,
    },
    /// Smallest non-NULL value seen so far, if any.
    Min(Option<Value>),
    /// Largest non-NULL value seen so far, if any.
    Max(Option<Value>),
}
2212
2213impl AggState {
2214    fn new(op: &StreamAgg) -> Self {
2215        match op {
2216            StreamAgg::CountStar => AggState::CountStar(0),
2217            StreamAgg::Count(_) => AggState::Count(0),
2218            StreamAgg::Sum(_) => AggState::Sum {
2219                int_sum: 0,
2220                real_sum: 0.0,
2221                has_real: false,
2222                all_null: true,
2223            },
2224            StreamAgg::Avg(_) => AggState::Avg { sum: 0.0, count: 0 },
2225            StreamAgg::Min(_) => AggState::Min(None),
2226            StreamAgg::Max(_) => AggState::Max(None),
2227        }
2228    }
2229
2230    fn feed_val(&mut self, val: &Value) -> Result<()> {
2231        match self {
2232            AggState::CountStar(c) => {
2233                *c += 1;
2234            }
2235            AggState::Count(c) => {
2236                if !val.is_null() {
2237                    *c += 1;
2238                }
2239            }
2240            AggState::Sum {
2241                int_sum,
2242                real_sum,
2243                has_real,
2244                all_null,
2245            } => match val {
2246                Value::Integer(i) => {
2247                    *int_sum += i;
2248                    *all_null = false;
2249                }
2250                Value::Real(r) => {
2251                    *real_sum += r;
2252                    *has_real = true;
2253                    *all_null = false;
2254                }
2255                Value::Null => {}
2256                _ => {
2257                    return Err(SqlError::TypeMismatch {
2258                        expected: "numeric".into(),
2259                        got: val.data_type().to_string(),
2260                    })
2261                }
2262            },
2263            AggState::Avg { sum, count } => match val {
2264                Value::Integer(i) => {
2265                    *sum += *i as f64;
2266                    *count += 1;
2267                }
2268                Value::Real(r) => {
2269                    *sum += r;
2270                    *count += 1;
2271                }
2272                Value::Null => {}
2273                _ => {
2274                    return Err(SqlError::TypeMismatch {
2275                        expected: "numeric".into(),
2276                        got: val.data_type().to_string(),
2277                    })
2278                }
2279            },
2280            AggState::Min(cur) => {
2281                if !val.is_null() {
2282                    *cur = Some(match cur.take() {
2283                        None => val.clone(),
2284                        Some(m) => {
2285                            if val < &m {
2286                                val.clone()
2287                            } else {
2288                                m
2289                            }
2290                        }
2291                    });
2292                }
2293            }
2294            AggState::Max(cur) => {
2295                if !val.is_null() {
2296                    *cur = Some(match cur.take() {
2297                        None => val.clone(),
2298                        Some(m) => {
2299                            if val > &m {
2300                                val.clone()
2301                            } else {
2302                                m
2303                            }
2304                        }
2305                    });
2306                }
2307            }
2308        }
2309        Ok(())
2310    }
2311
2312    fn feed_raw(&mut self, raw: &RawColumn) -> Result<()> {
2313        match self {
2314            AggState::CountStar(c) => {
2315                *c += 1;
2316            }
2317            AggState::Count(c) => {
2318                if !matches!(raw, RawColumn::Null) {
2319                    *c += 1;
2320                }
2321            }
2322            AggState::Sum {
2323                int_sum,
2324                real_sum,
2325                has_real,
2326                all_null,
2327            } => match raw {
2328                RawColumn::Integer(i) => {
2329                    *int_sum += i;
2330                    *all_null = false;
2331                }
2332                RawColumn::Real(r) => {
2333                    *real_sum += r;
2334                    *has_real = true;
2335                    *all_null = false;
2336                }
2337                RawColumn::Null => {}
2338                _ => {
2339                    return Err(SqlError::TypeMismatch {
2340                        expected: "numeric".into(),
2341                        got: "non-numeric".into(),
2342                    })
2343                }
2344            },
2345            AggState::Avg { sum, count } => match raw {
2346                RawColumn::Integer(i) => {
2347                    *sum += *i as f64;
2348                    *count += 1;
2349                }
2350                RawColumn::Real(r) => {
2351                    *sum += r;
2352                    *count += 1;
2353                }
2354                RawColumn::Null => {}
2355                _ => {
2356                    return Err(SqlError::TypeMismatch {
2357                        expected: "numeric".into(),
2358                        got: "non-numeric".into(),
2359                    })
2360                }
2361            },
2362            AggState::Min(cur) => {
2363                if !matches!(raw, RawColumn::Null) {
2364                    let val = raw.to_value();
2365                    *cur = Some(match cur.take() {
2366                        None => val,
2367                        Some(m) => {
2368                            if val < m {
2369                                val
2370                            } else {
2371                                m
2372                            }
2373                        }
2374                    });
2375                }
2376            }
2377            AggState::Max(cur) => {
2378                if !matches!(raw, RawColumn::Null) {
2379                    let val = raw.to_value();
2380                    *cur = Some(match cur.take() {
2381                        None => val,
2382                        Some(m) => {
2383                            if val > m {
2384                                val
2385                            } else {
2386                                m
2387                            }
2388                        }
2389                    });
2390                }
2391            }
2392        }
2393        Ok(())
2394    }
2395
2396    fn finish(self) -> Value {
2397        match self {
2398            AggState::CountStar(c) | AggState::Count(c) => Value::Integer(c),
2399            AggState::Sum {
2400                int_sum,
2401                real_sum,
2402                has_real,
2403                all_null,
2404            } => {
2405                if all_null {
2406                    Value::Null
2407                } else if has_real {
2408                    Value::Real(real_sum + int_sum as f64)
2409                } else {
2410                    Value::Integer(int_sum)
2411                }
2412            }
2413            AggState::Avg { sum, count } => {
2414                if count == 0 {
2415                    Value::Null
2416                } else {
2417                    Value::Real(sum / count as f64)
2418                }
2419            }
2420            AggState::Min(v) | AggState::Max(v) => v.unwrap_or(Value::Null),
2421        }
2422    }
2423}
2424
/// Plan for computing ungrouped aggregates in one pass over a table scan,
/// without materializing rows.
struct StreamAggPlan {
    /// Aggregates in output order, each paired with its output column name.
    ops: Vec<(StreamAgg, String)>,
    /// Partial row decoder, built only when the aggregates plus the WHERE
    /// clause need fewer columns than the table has. Used by the WHERE path.
    partial_ctx: Option<PartialDecodeCtx>,
    /// Per-op read location for the raw (no-WHERE) path; parallel to `ops`.
    raw_targets: Vec<RawAggTarget>,
    /// Number of primary-key columns in the table.
    num_pk_cols: usize,
}
2431
2432impl StreamAggPlan {
2433    fn try_new(stmt: &SelectStmt, table_schema: &TableSchema) -> Result<Option<Self>> {
2434        if !stmt.group_by.is_empty() || stmt.having.is_some() || !stmt.joins.is_empty() {
2435            return Ok(None);
2436        }
2437
2438        let col_map = ColumnMap::new(&table_schema.columns);
2439        let mut ops: Vec<(StreamAgg, String)> = Vec::new();
2440        for sel_col in &stmt.columns {
2441            let (expr, alias) = match sel_col {
2442                SelectColumn::Expr { expr, alias } => (expr, alias),
2443                _ => return Ok(None),
2444            };
2445            let name = alias
2446                .as_deref()
2447                .unwrap_or(&expr_display_name(expr))
2448                .to_string();
2449            match expr {
2450                Expr::CountStar => ops.push((StreamAgg::CountStar, name)),
2451                Expr::Function {
2452                    name: func_name,
2453                    args,
2454                } if args.len() == 1 => {
2455                    let func = func_name.to_ascii_uppercase();
2456                    let col_idx = match resolve_simple_col(&args[0], &col_map) {
2457                        Some(idx) => idx,
2458                        None => return Ok(None),
2459                    };
2460                    match func.as_str() {
2461                        "COUNT" => ops.push((StreamAgg::Count(col_idx), name)),
2462                        "SUM" => ops.push((StreamAgg::Sum(col_idx), name)),
2463                        "AVG" => ops.push((StreamAgg::Avg(col_idx), name)),
2464                        "MIN" => ops.push((StreamAgg::Min(col_idx), name)),
2465                        "MAX" => ops.push((StreamAgg::Max(col_idx), name)),
2466                        _ => return Ok(None),
2467                    }
2468                }
2469                _ => return Ok(None),
2470            }
2471        }
2472
2473        let mut needed: Vec<usize> = ops
2474            .iter()
2475            .filter_map(|(op, _)| match op {
2476                StreamAgg::CountStar => None,
2477                StreamAgg::Count(i)
2478                | StreamAgg::Sum(i)
2479                | StreamAgg::Avg(i)
2480                | StreamAgg::Min(i)
2481                | StreamAgg::Max(i) => Some(*i),
2482            })
2483            .collect();
2484        if let Some(ref where_expr) = stmt.where_clause {
2485            needed.extend(referenced_columns(where_expr, &table_schema.columns));
2486        }
2487        needed.sort_unstable();
2488        needed.dedup();
2489
2490        let partial_ctx = if needed.len() < table_schema.columns.len() {
2491            Some(PartialDecodeCtx::new(table_schema, &needed))
2492        } else {
2493            None
2494        };
2495
2496        let non_pk = table_schema.non_pk_indices();
2497        let raw_targets: Vec<RawAggTarget> = ops
2498            .iter()
2499            .map(|(op, _)| match op {
2500                StreamAgg::CountStar => RawAggTarget::CountStar,
2501                StreamAgg::Count(idx)
2502                | StreamAgg::Sum(idx)
2503                | StreamAgg::Avg(idx)
2504                | StreamAgg::Min(idx)
2505                | StreamAgg::Max(idx) => {
2506                    if let Some(pk_pos) = table_schema
2507                        .primary_key_columns
2508                        .iter()
2509                        .position(|&i| i as usize == *idx)
2510                    {
2511                        RawAggTarget::Pk(pk_pos)
2512                    } else {
2513                        let nonpk_idx = non_pk.iter().position(|&i| i == *idx).unwrap();
2514                        RawAggTarget::NonPk(nonpk_idx)
2515                    }
2516                }
2517            })
2518            .collect();
2519
2520        let num_pk_cols = table_schema.primary_key_columns.len();
2521
2522        Ok(Some(Self {
2523            ops,
2524            partial_ctx,
2525            raw_targets,
2526            num_pk_cols,
2527        }))
2528    }
2529
2530    #[allow(clippy::too_many_arguments)]
2531    fn feed_row(
2532        &self,
2533        key: &[u8],
2534        value: &[u8],
2535        table_schema: &TableSchema,
2536        col_map: &ColumnMap,
2537        where_clause: &Option<Expr>,
2538        states: &mut [AggState],
2539        scan_err: &mut Option<SqlError>,
2540    ) -> bool {
2541        let row = match &self.partial_ctx {
2542            Some(ctx) => match ctx.decode(key, value) {
2543                Ok(r) => r,
2544                Err(e) => {
2545                    *scan_err = Some(e);
2546                    return false;
2547                }
2548            },
2549            None => match decode_full_row(table_schema, key, value) {
2550                Ok(r) => r,
2551                Err(e) => {
2552                    *scan_err = Some(e);
2553                    return false;
2554                }
2555            },
2556        };
2557
2558        if let Some(expr) = where_clause {
2559            match eval_expr(expr, col_map, &row) {
2560                Ok(val) if !is_truthy(&val) => return true,
2561                Err(e) => {
2562                    *scan_err = Some(e);
2563                    return false;
2564                }
2565                _ => {}
2566            }
2567        }
2568
2569        for (i, (op, _)) in self.ops.iter().enumerate() {
2570            let val = match op {
2571                StreamAgg::CountStar => &Value::Null,
2572                StreamAgg::Count(idx)
2573                | StreamAgg::Sum(idx)
2574                | StreamAgg::Avg(idx)
2575                | StreamAgg::Min(idx)
2576                | StreamAgg::Max(idx) => &row[*idx],
2577            };
2578            if let Err(e) = states[i].feed_val(val) {
2579                *scan_err = Some(e);
2580                return false;
2581            }
2582        }
2583        true
2584    }
2585
2586    fn feed_row_raw(
2587        &self,
2588        key: &[u8],
2589        value: &[u8],
2590        states: &mut [AggState],
2591        scan_err: &mut Option<SqlError>,
2592    ) -> bool {
2593        for (i, target) in self.raw_targets.iter().enumerate() {
2594            let raw = match target {
2595                RawAggTarget::CountStar => {
2596                    if let Err(e) = states[i].feed_raw(&RawColumn::Null) {
2597                        *scan_err = Some(e);
2598                        return false;
2599                    }
2600                    continue;
2601                }
2602                RawAggTarget::Pk(pk_pos) => {
2603                    if self.num_pk_cols == 1 && *pk_pos == 0 {
2604                        match decode_pk_integer(key) {
2605                            Ok(v) => RawColumn::Integer(v),
2606                            Err(e) => {
2607                                *scan_err = Some(e);
2608                                return false;
2609                            }
2610                        }
2611                    } else {
2612                        match decode_composite_key(key, self.num_pk_cols) {
2613                            Ok(pk) => RawColumn::Integer(match &pk[*pk_pos] {
2614                                Value::Integer(i) => *i,
2615                                _ => {
2616                                    *scan_err =
2617                                        Some(SqlError::InvalidValue("PK not integer".into()));
2618                                    return false;
2619                                }
2620                            }),
2621                            Err(e) => {
2622                                *scan_err = Some(e);
2623                                return false;
2624                            }
2625                        }
2626                    }
2627                }
2628                RawAggTarget::NonPk(idx) => match decode_column_raw(value, *idx) {
2629                    Ok(v) => v,
2630                    Err(e) => {
2631                        *scan_err = Some(e);
2632                        return false;
2633                    }
2634                },
2635            };
2636            if let Err(e) = states[i].feed_raw(&raw) {
2637                *scan_err = Some(e);
2638                return false;
2639            }
2640        }
2641        true
2642    }
2643
2644    fn finish(self, states: Vec<AggState>) -> ExecutionResult {
2645        let col_names: Vec<String> = self.ops.iter().map(|(_, name)| name.clone()).collect();
2646        let result_row: Vec<Value> = states.into_iter().map(|s| s.finish()).collect();
2647        ExecutionResult::Query(QueryResult {
2648            columns: col_names,
2649            rows: vec![result_row],
2650        })
2651    }
2652}
2653
2654fn resolve_simple_col(expr: &Expr, col_map: &ColumnMap) -> Option<usize> {
2655    match expr {
2656        Expr::Column(name) => col_map.resolve(name).ok(),
2657        Expr::QualifiedColumn { table, column } => col_map.resolve_qualified(table, column).ok(),
2658        _ => None,
2659    }
2660}
2661
/// Source of one output column in a streamed `GROUP BY` result row.
enum GroupByOutputCol {
    /// The group key value itself (NULL for the NULL group).
    GroupKey,
    /// The finished value of the aggregate at this index in the plan's
    /// aggregate list.
    Agg(usize),
}
2666
/// Plan for evaluating a single-column, INTEGER-keyed `GROUP BY` query in one
/// pass over the raw table bytes, without materialising full rows.
struct StreamGroupByPlan {
    /// Where the grouping column is decoded from (primary key or non-PK column).
    group_target: RawAggTarget,
    /// Number of primary-key columns in the table.
    num_pk_cols: usize,
    /// One aggregate per aggregate expression in the select list, in order.
    agg_ops: Vec<StreamAgg>,
    /// Raw decode location for each entry of `agg_ops` (parallel vector).
    raw_targets: Vec<RawAggTarget>,
    /// Output columns in select-list order, paired with their display names.
    output: Vec<(GroupByOutputCol, String)>,
}
2674
2675impl StreamGroupByPlan {
2676    fn try_new(stmt: &SelectStmt, schema: &TableSchema) -> Result<Option<Self>> {
2677        if stmt.group_by.len() != 1
2678            || stmt.having.is_some()
2679            || !stmt.joins.is_empty()
2680            || stmt.where_clause.is_some()
2681            || !stmt.order_by.is_empty()
2682            || stmt.limit.is_some()
2683        {
2684            return Ok(None);
2685        }
2686
2687        let col_map = ColumnMap::new(&schema.columns);
2688
2689        let group_col_idx = match &stmt.group_by[0] {
2690            Expr::Column(name) => col_map.resolve(name).ok(),
2691            _ => None,
2692        };
2693        let group_col_idx = match group_col_idx {
2694            Some(idx) => idx,
2695            None => return Ok(None),
2696        };
2697
2698        if schema.columns[group_col_idx].data_type != DataType::Integer {
2699            return Ok(None);
2700        }
2701
2702        let non_pk = schema.non_pk_indices();
2703        let group_target = if let Some(pk_pos) = schema
2704            .primary_key_columns
2705            .iter()
2706            .position(|&i| i as usize == group_col_idx)
2707        {
2708            RawAggTarget::Pk(pk_pos)
2709        } else {
2710            let nonpk_idx = non_pk.iter().position(|&i| i == group_col_idx).unwrap();
2711            RawAggTarget::NonPk(nonpk_idx)
2712        };
2713
2714        let mut agg_ops = Vec::new();
2715        let mut raw_targets = Vec::new();
2716        let mut output = Vec::new();
2717
2718        for sel_col in &stmt.columns {
2719            let (expr, alias) = match sel_col {
2720                SelectColumn::Expr { expr, alias } => (expr, alias),
2721                _ => return Ok(None),
2722            };
2723            let name = alias
2724                .as_deref()
2725                .unwrap_or(&expr_display_name(expr))
2726                .to_string();
2727
2728            if let Some(idx) = resolve_simple_col(expr, &col_map) {
2729                if idx == group_col_idx {
2730                    output.push((GroupByOutputCol::GroupKey, name));
2731                    continue;
2732                }
2733            }
2734
2735            match expr {
2736                Expr::CountStar => {
2737                    let agg_idx = agg_ops.len();
2738                    agg_ops.push(StreamAgg::CountStar);
2739                    raw_targets.push(RawAggTarget::CountStar);
2740                    output.push((GroupByOutputCol::Agg(agg_idx), name));
2741                }
2742                Expr::Function {
2743                    name: func_name,
2744                    args,
2745                } if args.len() == 1 => {
2746                    let func = func_name.to_ascii_uppercase();
2747                    let col_idx = match resolve_simple_col(&args[0], &col_map) {
2748                        Some(idx) => idx,
2749                        None => return Ok(None),
2750                    };
2751                    let target = if let Some(pk_pos) = schema
2752                        .primary_key_columns
2753                        .iter()
2754                        .position(|&i| i as usize == col_idx)
2755                    {
2756                        RawAggTarget::Pk(pk_pos)
2757                    } else {
2758                        let nonpk_idx = non_pk.iter().position(|&i| i == col_idx).unwrap();
2759                        RawAggTarget::NonPk(nonpk_idx)
2760                    };
2761                    let agg_idx = agg_ops.len();
2762                    match func.as_str() {
2763                        "COUNT" => agg_ops.push(StreamAgg::Count(col_idx)),
2764                        "SUM" => agg_ops.push(StreamAgg::Sum(col_idx)),
2765                        "AVG" => agg_ops.push(StreamAgg::Avg(col_idx)),
2766                        "MIN" => agg_ops.push(StreamAgg::Min(col_idx)),
2767                        "MAX" => agg_ops.push(StreamAgg::Max(col_idx)),
2768                        _ => return Ok(None),
2769                    }
2770                    raw_targets.push(target);
2771                    output.push((GroupByOutputCol::Agg(agg_idx), name));
2772                }
2773                _ => return Ok(None),
2774            }
2775        }
2776
2777        Ok(Some(Self {
2778            group_target,
2779            num_pk_cols: schema.primary_key_columns.len(),
2780            agg_ops,
2781            raw_targets,
2782            output,
2783        }))
2784    }
2785
2786    fn execute_scan(
2787        &self,
2788        scan: impl FnOnce(
2789            &mut dyn FnMut(&[u8], &[u8]) -> bool,
2790        ) -> std::result::Result<(), citadel::Error>,
2791    ) -> Result<ExecutionResult> {
2792        let mut groups: HashMap<i64, Vec<AggState>> = HashMap::new();
2793        let mut null_group: Option<Vec<AggState>> = None;
2794        let mut scan_err: Option<SqlError> = None;
2795
2796        scan(&mut |key, value| {
2797            let group_key: Option<i64> = match &self.group_target {
2798                RawAggTarget::Pk(pk_pos) => {
2799                    if self.num_pk_cols == 1 && *pk_pos == 0 {
2800                        match decode_pk_integer(key) {
2801                            Ok(v) => Some(v),
2802                            Err(e) => {
2803                                scan_err = Some(e);
2804                                return false;
2805                            }
2806                        }
2807                    } else {
2808                        match decode_composite_key(key, self.num_pk_cols) {
2809                            Ok(pk) => match &pk[*pk_pos] {
2810                                Value::Integer(i) => Some(*i),
2811                                Value::Null => None,
2812                                _ => {
2813                                    scan_err = Some(SqlError::InvalidValue(
2814                                        "GROUP BY key not integer".into(),
2815                                    ));
2816                                    return false;
2817                                }
2818                            },
2819                            Err(e) => {
2820                                scan_err = Some(e);
2821                                return false;
2822                            }
2823                        }
2824                    }
2825                }
2826                RawAggTarget::NonPk(idx) => match decode_column_raw(value, *idx) {
2827                    Ok(RawColumn::Integer(i)) => Some(i),
2828                    Ok(RawColumn::Null) => None,
2829                    Ok(_) => {
2830                        scan_err = Some(SqlError::InvalidValue("GROUP BY key not integer".into()));
2831                        return false;
2832                    }
2833                    Err(e) => {
2834                        scan_err = Some(e);
2835                        return false;
2836                    }
2837                },
2838                RawAggTarget::CountStar => unreachable!(),
2839            };
2840
2841            let states = match group_key {
2842                Some(k) => groups
2843                    .entry(k)
2844                    .or_insert_with(|| self.agg_ops.iter().map(AggState::new).collect()),
2845                None => null_group
2846                    .get_or_insert_with(|| self.agg_ops.iter().map(AggState::new).collect()),
2847            };
2848
2849            for (i, target) in self.raw_targets.iter().enumerate() {
2850                let raw = match target {
2851                    RawAggTarget::CountStar => {
2852                        if let Err(e) = states[i].feed_raw(&RawColumn::Null) {
2853                            scan_err = Some(e);
2854                            return false;
2855                        }
2856                        continue;
2857                    }
2858                    RawAggTarget::Pk(pk_pos) => {
2859                        if self.num_pk_cols == 1 && *pk_pos == 0 {
2860                            match decode_pk_integer(key) {
2861                                Ok(v) => RawColumn::Integer(v),
2862                                Err(e) => {
2863                                    scan_err = Some(e);
2864                                    return false;
2865                                }
2866                            }
2867                        } else {
2868                            match decode_composite_key(key, self.num_pk_cols) {
2869                                Ok(pk) => match &pk[*pk_pos] {
2870                                    Value::Integer(i) => RawColumn::Integer(*i),
2871                                    _ => {
2872                                        scan_err = Some(SqlError::InvalidValue(
2873                                            "agg column not integer".into(),
2874                                        ));
2875                                        return false;
2876                                    }
2877                                },
2878                                Err(e) => {
2879                                    scan_err = Some(e);
2880                                    return false;
2881                                }
2882                            }
2883                        }
2884                    }
2885                    RawAggTarget::NonPk(idx) => match decode_column_raw(value, *idx) {
2886                        Ok(v) => v,
2887                        Err(e) => {
2888                            scan_err = Some(e);
2889                            return false;
2890                        }
2891                    },
2892                };
2893                if let Err(e) = states[i].feed_raw(&raw) {
2894                    scan_err = Some(e);
2895                    return false;
2896                }
2897            }
2898            true
2899        })
2900        .map_err(SqlError::Storage)?;
2901
2902        if let Some(e) = scan_err {
2903            return Err(e);
2904        }
2905
2906        let col_names: Vec<String> = self.output.iter().map(|(_, name)| name.clone()).collect();
2907        let null_extra = if null_group.is_some() { 1 } else { 0 };
2908        let mut result_rows: Vec<Vec<Value>> = Vec::with_capacity(groups.len() + null_extra);
2909        if let Some(states) = null_group {
2910            let mut row = Vec::with_capacity(self.output.len());
2911            let finished: Vec<Value> = states.into_iter().map(|s| s.finish()).collect();
2912            for (col, _) in &self.output {
2913                match col {
2914                    GroupByOutputCol::GroupKey => row.push(Value::Null),
2915                    GroupByOutputCol::Agg(idx) => row.push(finished[*idx].clone()),
2916                }
2917            }
2918            result_rows.push(row);
2919        }
2920        for (group_key, states) in groups {
2921            let mut row = Vec::with_capacity(self.output.len());
2922            let finished: Vec<Value> = states.into_iter().map(|s| s.finish()).collect();
2923            for (col, _) in &self.output {
2924                match col {
2925                    GroupByOutputCol::GroupKey => row.push(Value::Integer(group_key)),
2926                    GroupByOutputCol::Agg(idx) => row.push(finished[*idx].clone()),
2927                }
2928            }
2929            result_rows.push(row);
2930        }
2931
2932        Ok(ExecutionResult::Query(QueryResult {
2933            columns: col_names,
2934            rows: result_rows,
2935        }))
2936    }
2937}
2938
/// Plan for `SELECT ... ORDER BY <single column> LIMIT n [OFFSET m]` over one
/// table: a bounded heap keeps only the best `keep` rows while scanning raw
/// key/value bytes.
struct TopKScanPlan {
    /// Where the sort column is decoded from (primary key or non-PK column).
    sort_target: RawAggTarget,
    /// Number of primary-key columns in the table.
    num_pk_cols: usize,
    /// `true` for `ORDER BY ... DESC`.
    descending: bool,
    /// Resolved `NULLS FIRST`/`NULLS LAST` setting; when the query specifies
    /// none it defaults to `!descending`.
    nulls_first: bool,
    /// Rows retained during the scan: `LIMIT + OFFSET` (saturating), never 0.
    keep: usize,
}
2946
2947impl TopKScanPlan {
2948    fn try_new(stmt: &SelectStmt, schema: &TableSchema) -> Result<Option<Self>> {
2949        if stmt.order_by.len() != 1
2950            || stmt.limit.is_none()
2951            || stmt.where_clause.is_some()
2952            || !stmt.group_by.is_empty()
2953            || stmt.having.is_some()
2954            || !stmt.joins.is_empty()
2955            || stmt.distinct
2956        {
2957            return Ok(None);
2958        }
2959
2960        let has_aggregates = stmt.columns.iter().any(|c| match c {
2961            SelectColumn::Expr { expr, .. } => is_aggregate_expr(expr),
2962            _ => false,
2963        });
2964        if has_aggregates {
2965            return Ok(None);
2966        }
2967
2968        let ob = &stmt.order_by[0];
2969        let col_map = ColumnMap::new(&schema.columns);
2970        let col_idx = match resolve_simple_col(&ob.expr, &col_map) {
2971            Some(idx) => idx,
2972            None => return Ok(None),
2973        };
2974
2975        let non_pk = schema.non_pk_indices();
2976        let sort_target = if let Some(pk_pos) = schema
2977            .primary_key_columns
2978            .iter()
2979            .position(|&i| i as usize == col_idx)
2980        {
2981            RawAggTarget::Pk(pk_pos)
2982        } else {
2983            let nonpk_idx = non_pk.iter().position(|&i| i == col_idx).unwrap();
2984            RawAggTarget::NonPk(nonpk_idx)
2985        };
2986
2987        let limit = eval_const_int(stmt.limit.as_ref().unwrap())?.max(0) as usize;
2988        let offset = stmt
2989            .offset
2990            .as_ref()
2991            .map(eval_const_int)
2992            .transpose()?
2993            .unwrap_or(0)
2994            .max(0) as usize;
2995        let keep = limit.saturating_add(offset);
2996        if keep == 0 {
2997            return Ok(None);
2998        }
2999
3000        Ok(Some(Self {
3001            sort_target,
3002            num_pk_cols: schema.primary_key_columns.len(),
3003            descending: ob.descending,
3004            nulls_first: ob.nulls_first.unwrap_or(!ob.descending),
3005            keep,
3006        }))
3007    }
3008
3009    fn execute_scan(
3010        &self,
3011        schema: &TableSchema,
3012        stmt: &SelectStmt,
3013        scan: impl FnOnce(
3014            &mut dyn FnMut(&[u8], &[u8]) -> bool,
3015        ) -> std::result::Result<(), citadel::Error>,
3016    ) -> Result<ExecutionResult> {
3017        use std::cmp::Ordering;
3018        use std::collections::BinaryHeap;
3019
3020        struct Candidate {
3021            sort_key: Value,
3022            raw_key: Vec<u8>,
3023            raw_value: Vec<u8>,
3024        }
3025
3026        struct CandWrapper {
3027            c: Candidate,
3028            descending: bool,
3029            nulls_first: bool,
3030        }
3031
3032        impl PartialEq for CandWrapper {
3033            fn eq(&self, other: &Self) -> bool {
3034                self.cmp(other) == Ordering::Equal
3035            }
3036        }
3037        impl Eq for CandWrapper {}
3038
3039        impl PartialOrd for CandWrapper {
3040            fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
3041                Some(self.cmp(other))
3042            }
3043        }
3044
3045        // Max-heap: worst candidate on top for eviction.
3046        impl Ord for CandWrapper {
3047            fn cmp(&self, other: &Self) -> Ordering {
3048                let ord = match (self.c.sort_key.is_null(), other.c.sort_key.is_null()) {
3049                    (true, true) => Ordering::Equal,
3050                    (true, false) => {
3051                        if self.nulls_first {
3052                            Ordering::Less
3053                        } else {
3054                            Ordering::Greater
3055                        }
3056                    }
3057                    (false, true) => {
3058                        if self.nulls_first {
3059                            Ordering::Greater
3060                        } else {
3061                            Ordering::Less
3062                        }
3063                    }
3064                    (false, false) => self.c.sort_key.cmp(&other.c.sort_key),
3065                };
3066                if self.descending {
3067                    ord.reverse()
3068                } else {
3069                    ord
3070                }
3071            }
3072        }
3073
3074        let k = self.keep;
3075        let mut heap: BinaryHeap<CandWrapper> = BinaryHeap::with_capacity(k + 1);
3076        let mut scan_err: Option<SqlError> = None;
3077
3078        scan(&mut |key, value| {
3079            let sort_key: Value = match &self.sort_target {
3080                RawAggTarget::Pk(pk_pos) => {
3081                    if self.num_pk_cols == 1 && *pk_pos == 0 {
3082                        match decode_pk_integer(key) {
3083                            Ok(v) => Value::Integer(v),
3084                            Err(e) => {
3085                                scan_err = Some(e);
3086                                return false;
3087                            }
3088                        }
3089                    } else {
3090                        match decode_composite_key(key, self.num_pk_cols) {
3091                            Ok(mut pk) => std::mem::replace(&mut pk[*pk_pos], Value::Null),
3092                            Err(e) => {
3093                                scan_err = Some(e);
3094                                return false;
3095                            }
3096                        }
3097                    }
3098                }
3099                RawAggTarget::NonPk(idx) => match decode_column_raw(value, *idx) {
3100                    Ok(raw) => raw.to_value(),
3101                    Err(e) => {
3102                        scan_err = Some(e);
3103                        return false;
3104                    }
3105                },
3106                RawAggTarget::CountStar => unreachable!(),
3107            };
3108
3109            // Heap full and can't beat worst — skip
3110            if heap.len() >= k {
3111                if let Some(top) = heap.peek() {
3112                    let ord = match (sort_key.is_null(), top.c.sort_key.is_null()) {
3113                        (true, true) => Ordering::Equal,
3114                        (true, false) => {
3115                            if self.nulls_first {
3116                                Ordering::Less
3117                            } else {
3118                                Ordering::Greater
3119                            }
3120                        }
3121                        (false, true) => {
3122                            if self.nulls_first {
3123                                Ordering::Greater
3124                            } else {
3125                                Ordering::Less
3126                            }
3127                        }
3128                        (false, false) => sort_key.cmp(&top.c.sort_key),
3129                    };
3130                    let cmp = if self.descending { ord.reverse() } else { ord };
3131                    if cmp != Ordering::Less {
3132                        return true;
3133                    }
3134                }
3135            }
3136
3137            let cand = CandWrapper {
3138                c: Candidate {
3139                    sort_key,
3140                    raw_key: key.to_vec(),
3141                    raw_value: value.to_vec(),
3142                },
3143                descending: self.descending,
3144                nulls_first: self.nulls_first,
3145            };
3146
3147            if heap.len() < k {
3148                heap.push(cand);
3149            } else if let Some(mut top) = heap.peek_mut() {
3150                *top = cand;
3151            }
3152
3153            true
3154        })
3155        .map_err(SqlError::Storage)?;
3156
3157        if let Some(e) = scan_err {
3158            return Err(e);
3159        }
3160
3161        let mut winners: Vec<CandWrapper> = heap.into_vec();
3162        winners.sort();
3163
3164        let mut rows: Vec<Vec<Value>> = Vec::with_capacity(winners.len());
3165        for w in &winners {
3166            rows.push(decode_full_row(schema, &w.c.raw_key, &w.c.raw_value)?);
3167        }
3168
3169        if let Some(ref offset_expr) = stmt.offset {
3170            let offset = eval_const_int(offset_expr)?.max(0) as usize;
3171            if offset < rows.len() {
3172                rows = rows.split_off(offset);
3173            } else {
3174                rows.clear();
3175            }
3176        }
3177        if let Some(ref limit_expr) = stmt.limit {
3178            let limit = eval_const_int(limit_expr)?.max(0) as usize;
3179            rows.truncate(limit);
3180        }
3181
3182        let (col_names, projected) = project_rows(&schema.columns, &stmt.columns, rows)?;
3183        Ok(ExecutionResult::Query(QueryResult {
3184            columns: col_names,
3185            rows: projected,
3186        }))
3187    }
3188}
3189
/// A `column <op> literal` comparison that can be evaluated directly against a
/// row's raw key/value bytes, without decoding the whole row.
struct SimplePredicate {
    /// `true` when the compared column is part of the primary key.
    is_pk: bool,
    /// Position within the primary key (used only when `is_pk`).
    pk_pos: usize,
    /// Index among the non-PK columns of the row value (used only when `!is_pk`).
    nonpk_idx: usize,
    /// Comparison operator, oriented so the column is the left operand.
    op: BinOp,
    /// Literal right-hand operand.
    literal: Value,
    /// Number of primary-key columns in the table.
    num_pk_cols: usize,
    /// When set, `matches_raw` uses an inlined integer comparison against
    /// this value instead of decoding the column generically.
    precomputed_int: Option<i64>,
}
3199
3200impl SimplePredicate {
3201    fn matches_raw(&self, key: &[u8], value: &[u8]) -> Result<bool> {
3202        if let Some(target) = self.precomputed_int {
3203            return Ok(self.match_nonpk_int_inline(value, target));
3204        }
3205        let raw = if self.is_pk {
3206            if self.num_pk_cols == 1 {
3207                RawColumn::Integer(decode_pk_integer(key)?)
3208            } else {
3209                let pk = decode_composite_key(key, self.num_pk_cols)?;
3210                match &pk[self.pk_pos] {
3211                    Value::Integer(i) => RawColumn::Integer(*i),
3212                    Value::Real(r) => RawColumn::Real(*r),
3213                    Value::Boolean(b) => RawColumn::Boolean(*b),
3214                    _ => {
3215                        return Ok(raw_matches_op_value(
3216                            &pk[self.pk_pos],
3217                            self.op,
3218                            &self.literal,
3219                        ))
3220                    }
3221                }
3222            }
3223        } else {
3224            decode_column_raw(value, self.nonpk_idx)?
3225        };
3226        Ok(raw_matches_op(&raw, self.op, &self.literal))
3227    }
3228
3229    #[inline(always)]
3230    fn match_nonpk_int_inline(&self, data: &[u8], target: i64) -> bool {
3231        let col_count = u16::from_le_bytes(data[0..2].try_into().unwrap()) as usize;
3232        let bm_bytes = col_count.div_ceil(8);
3233
3234        // NULL → false (SQL NULL semantics)
3235        if data[2 + self.nonpk_idx / 8] & (1 << (self.nonpk_idx % 8)) != 0 {
3236            return false;
3237        }
3238
3239        let mut pos = 2 + bm_bytes;
3240
3241        // Skip preceding non-null columns by reading their length
3242        for col in 0..self.nonpk_idx {
3243            if data[2 + col / 8] & (1 << (col % 8)) == 0 {
3244                let len = u32::from_le_bytes(data[pos + 1..pos + 5].try_into().unwrap()) as usize;
3245                pos += 5 + len;
3246            }
3247        }
3248
3249        // Read i64 directly: skip type_tag(1) + len(4), read 8 bytes
3250        let v = i64::from_le_bytes(data[pos + 5..pos + 13].try_into().unwrap());
3251
3252        match self.op {
3253            BinOp::Eq => v == target,
3254            BinOp::NotEq => v != target,
3255            BinOp::Lt => v < target,
3256            BinOp::Gt => v > target,
3257            BinOp::LtEq => v <= target,
3258            BinOp::GtEq => v >= target,
3259            _ => false,
3260        }
3261    }
3262}
3263
3264fn try_simple_predicate(expr: &Expr, schema: &TableSchema) -> Option<SimplePredicate> {
3265    let (col_name, op, literal) = match expr {
3266        Expr::BinaryOp { left, op, right } => match (left.as_ref(), right.as_ref()) {
3267            (Expr::Column(name), Expr::Literal(lit)) => (name.as_str(), *op, lit),
3268            (Expr::Literal(lit), Expr::Column(name)) => (name.as_str(), flip_cmp_op(*op)?, lit),
3269            _ => return None,
3270        },
3271        _ => return None,
3272    };
3273
3274    if !matches!(
3275        op,
3276        BinOp::Eq | BinOp::NotEq | BinOp::Lt | BinOp::Gt | BinOp::LtEq | BinOp::GtEq
3277    ) {
3278        return None;
3279    }
3280
3281    let col_idx = schema.column_index(col_name)?;
3282    let non_pk = schema.non_pk_indices();
3283
3284    if let Some(pk_pos) = schema
3285        .primary_key_columns
3286        .iter()
3287        .position(|&i| i as usize == col_idx)
3288    {
3289        Some(SimplePredicate {
3290            is_pk: true,
3291            pk_pos,
3292            nonpk_idx: 0,
3293            op,
3294            literal: literal.clone(),
3295            num_pk_cols: schema.primary_key_columns.len(),
3296            precomputed_int: None,
3297        })
3298    } else {
3299        let nonpk_idx = non_pk.iter().position(|&i| i == col_idx)?;
3300        let precomputed_int = match literal {
3301            Value::Integer(i) => Some(*i),
3302            _ => None,
3303        };
3304        Some(SimplePredicate {
3305            is_pk: false,
3306            pk_pos: 0,
3307            nonpk_idx,
3308            op,
3309            literal: literal.clone(),
3310            num_pk_cols: schema.primary_key_columns.len(),
3311            precomputed_int,
3312        })
3313    }
3314}
3315
3316fn flip_cmp_op(op: BinOp) -> Option<BinOp> {
3317    match op {
3318        BinOp::Eq => Some(BinOp::Eq),
3319        BinOp::NotEq => Some(BinOp::NotEq),
3320        BinOp::Lt => Some(BinOp::Gt),
3321        BinOp::Gt => Some(BinOp::Lt),
3322        BinOp::LtEq => Some(BinOp::GtEq),
3323        BinOp::GtEq => Some(BinOp::LtEq),
3324        _ => None,
3325    }
3326}
3327
3328fn raw_matches_op(raw: &RawColumn, op: BinOp, literal: &Value) -> bool {
3329    // SQL NULL semantics: any comparison involving NULL yields NULL (falsy)
3330    if matches!(raw, RawColumn::Null) || literal.is_null() {
3331        return false;
3332    }
3333    match op {
3334        BinOp::Eq => raw.eq_value(literal),
3335        BinOp::NotEq => !raw.eq_value(literal),
3336        BinOp::Lt => raw.cmp_value(literal) == Some(std::cmp::Ordering::Less),
3337        BinOp::Gt => raw.cmp_value(literal) == Some(std::cmp::Ordering::Greater),
3338        BinOp::LtEq => raw
3339            .cmp_value(literal)
3340            .is_some_and(|o| o != std::cmp::Ordering::Greater),
3341        BinOp::GtEq => raw
3342            .cmp_value(literal)
3343            .is_some_and(|o| o != std::cmp::Ordering::Less),
3344        _ => false,
3345    }
3346}
3347
3348fn raw_matches_op_value(val: &Value, op: BinOp, literal: &Value) -> bool {
3349    match op {
3350        BinOp::Eq => val == literal,
3351        BinOp::NotEq => val != literal && !val.is_null(),
3352        BinOp::Lt => val < literal,
3353        BinOp::Gt => val > literal,
3354        BinOp::LtEq => val <= literal,
3355        BinOp::GtEq => val >= literal,
3356        _ => false,
3357    }
3358}
3359
3360fn exec_select_no_from(stmt: &SelectStmt) -> Result<ExecutionResult> {
3361    let empty_cols: Vec<ColumnDef> = vec![];
3362    let empty_row: Vec<Value> = vec![];
3363    let (col_names, projected) = project_rows(&empty_cols, &stmt.columns, vec![empty_row])?;
3364    Ok(ExecutionResult::Query(QueryResult {
3365        columns: col_names,
3366        rows: projected,
3367    }))
3368}
3369
3370fn process_select(
3371    columns: &[ColumnDef],
3372    mut rows: Vec<Vec<Value>>,
3373    stmt: &SelectStmt,
3374    predicate_applied: bool,
3375) -> Result<ExecutionResult> {
3376    if !predicate_applied {
3377        if let Some(ref where_expr) = stmt.where_clause {
3378            let col_map = ColumnMap::new(columns);
3379            rows.retain(|row| match eval_expr(where_expr, &col_map, row) {
3380                Ok(val) => is_truthy(&val),
3381                Err(_) => false,
3382            });
3383        }
3384    }
3385
3386    let has_aggregates = stmt.columns.iter().any(|c| match c {
3387        SelectColumn::Expr { expr, .. } => is_aggregate_expr(expr),
3388        _ => false,
3389    });
3390
3391    if has_aggregates || !stmt.group_by.is_empty() {
3392        return exec_aggregate(columns, &rows, stmt);
3393    }
3394
3395    if stmt.distinct {
3396        let (col_names, mut projected) = project_rows(columns, &stmt.columns, rows)?;
3397
3398        let mut seen = std::collections::HashSet::new();
3399        projected.retain(|row| seen.insert(row.clone()));
3400
3401        if !stmt.order_by.is_empty() {
3402            let output_cols = build_output_columns(&stmt.columns, columns);
3403            sort_rows(&mut projected, &stmt.order_by, &output_cols)?;
3404        }
3405
3406        if let Some(ref offset_expr) = stmt.offset {
3407            let offset = eval_const_int(offset_expr)?.max(0) as usize;
3408            if offset < projected.len() {
3409                projected = projected.split_off(offset);
3410            } else {
3411                projected.clear();
3412            }
3413        }
3414
3415        if let Some(ref limit_expr) = stmt.limit {
3416            let limit = eval_const_int(limit_expr)?.max(0) as usize;
3417            projected.truncate(limit);
3418        }
3419
3420        return Ok(ExecutionResult::Query(QueryResult {
3421            columns: col_names,
3422            rows: projected,
3423        }));
3424    }
3425
3426    if !stmt.order_by.is_empty() {
3427        if let Some(ref limit_expr) = stmt.limit {
3428            let limit = eval_const_int(limit_expr)?.max(0) as usize;
3429            let offset = match stmt.offset {
3430                Some(ref e) => eval_const_int(e)?.max(0) as usize,
3431                None => 0,
3432            };
3433            let keep = limit.saturating_add(offset);
3434            if keep == 0 {
3435                rows.clear();
3436            } else if keep < rows.len() {
3437                topk_rows(&mut rows, &stmt.order_by, columns, keep)?;
3438                rows.truncate(keep);
3439            } else {
3440                sort_rows(&mut rows, &stmt.order_by, columns)?;
3441            }
3442        } else {
3443            sort_rows(&mut rows, &stmt.order_by, columns)?;
3444        }
3445    }
3446
3447    if let Some(ref offset_expr) = stmt.offset {
3448        let offset = eval_const_int(offset_expr)?.max(0) as usize;
3449        if offset < rows.len() {
3450            rows = rows.split_off(offset);
3451        } else {
3452            rows.clear();
3453        }
3454    }
3455
3456    if let Some(ref limit_expr) = stmt.limit {
3457        let limit = eval_const_int(limit_expr)?.max(0) as usize;
3458        rows.truncate(limit);
3459    }
3460
3461    let (col_names, projected) = project_rows(columns, &stmt.columns, rows)?;
3462
3463    Ok(ExecutionResult::Query(QueryResult {
3464        columns: col_names,
3465        rows: projected,
3466    }))
3467}
3468
3469fn resolve_table_name<'a>(schema: &'a SchemaManager, name: &str) -> Result<&'a TableSchema> {
3470    schema
3471        .get(name)
3472        .ok_or_else(|| SqlError::TableNotFound(name.to_string()))
3473}
3474
3475fn build_joined_columns(tables: &[(String, &TableSchema)]) -> Vec<ColumnDef> {
3476    let mut result = Vec::new();
3477    let mut pos: u16 = 0;
3478
3479    for (alias, schema) in tables {
3480        for col in &schema.columns {
3481            result.push(ColumnDef {
3482                name: format!("{}.{}", alias.to_ascii_lowercase(), col.name),
3483                data_type: col.data_type,
3484                nullable: col.nullable,
3485                position: pos,
3486            });
3487            pos += 1;
3488        }
3489    }
3490
3491    result
3492}
3493
3494fn extract_equi_join_keys(
3495    on_expr: &Expr,
3496    combined_cols: &[ColumnDef],
3497    outer_col_count: usize,
3498) -> Vec<(usize, usize)> {
3499    let mut pairs = Vec::new();
3500
3501    fn flatten<'a>(e: &'a Expr, out: &mut Vec<&'a Expr>) {
3502        match e {
3503            Expr::BinaryOp {
3504                left,
3505                op: BinOp::And,
3506                right,
3507            } => {
3508                flatten(left, out);
3509                flatten(right, out);
3510            }
3511            _ => out.push(e),
3512        }
3513    }
3514    let mut conjuncts = Vec::new();
3515    flatten(on_expr, &mut conjuncts);
3516
3517    for expr in conjuncts {
3518        if let Expr::BinaryOp {
3519            left,
3520            op: BinOp::Eq,
3521            right,
3522        } = expr
3523        {
3524            if let (Some(l_idx), Some(r_idx)) = (
3525                resolve_col_idx(left, combined_cols),
3526                resolve_col_idx(right, combined_cols),
3527            ) {
3528                if l_idx < outer_col_count && r_idx >= outer_col_count {
3529                    pairs.push((l_idx, r_idx - outer_col_count));
3530                } else if r_idx < outer_col_count && l_idx >= outer_col_count {
3531                    pairs.push((r_idx, l_idx - outer_col_count));
3532                }
3533            }
3534        }
3535    }
3536
3537    pairs
3538}
3539
3540fn resolve_col_idx(expr: &Expr, columns: &[ColumnDef]) -> Option<usize> {
3541    match expr {
3542        Expr::Column(name) => {
3543            let matches: Vec<usize> = columns
3544                .iter()
3545                .enumerate()
3546                .filter(|(_, c)| {
3547                    c.name == *name
3548                        || (c.name.len() > name.len()
3549                            && c.name.as_bytes()[c.name.len() - name.len() - 1] == b'.'
3550                            && c.name.ends_with(name.as_str()))
3551                })
3552                .map(|(i, _)| i)
3553                .collect();
3554            if matches.len() == 1 {
3555                Some(matches[0])
3556            } else {
3557                None
3558            }
3559        }
3560        Expr::QualifiedColumn { table, column } => {
3561            let qualified = format!("{table}.{column}");
3562            columns.iter().position(|c| c.name == qualified)
3563        }
3564        _ => None,
3565    }
3566}
3567
3568fn hash_key(row: &[Value], col_indices: &[usize]) -> Vec<Value> {
3569    col_indices.iter().map(|&i| row[i].clone()).collect()
3570}
3571
3572fn count_conjuncts(expr: &Expr) -> usize {
3573    match expr {
3574        Expr::BinaryOp {
3575            op: BinOp::And,
3576            left,
3577            right,
3578        } => count_conjuncts(left) + count_conjuncts(right),
3579        _ => 1,
3580    }
3581}
3582
3583fn combine_row(outer: &[Value], inner: &[Value], cap: usize) -> Vec<Value> {
3584    let mut combined = Vec::with_capacity(cap);
3585    combined.extend(outer.iter().cloned());
3586    combined.extend(inner.iter().cloned());
3587    combined
3588}
3589
/// Precomputed column selection applied while combining an outer and an inner row.
///
/// Each slot `(idx, is_inner)` produces one output value: `inner[idx]` when `is_inner`
/// is true, otherwise `outer[idx]`.
struct CombineProjection {
    // Output slots, in output order: (source column index, taken-from-inner flag).
    slots: Vec<(usize, bool)>,
}
3593
3594fn combine_row_projected(outer: &[Value], inner: &[Value], proj: &CombineProjection) -> Vec<Value> {
3595    proj.slots
3596        .iter()
3597        .map(|&(idx, is_inner)| {
3598            if is_inner {
3599                inner[idx].clone()
3600            } else {
3601                outer[idx].clone()
3602            }
3603        })
3604        .collect()
3605}
3606
3607fn build_combine_projection(
3608    needed_combined: &[usize],
3609    outer_col_count: usize,
3610) -> CombineProjection {
3611    CombineProjection {
3612        slots: needed_combined
3613            .iter()
3614            .map(|&ci| {
3615                if ci < outer_col_count {
3616                    (ci, false)
3617                } else {
3618                    (ci - outer_col_count, true)
3619                }
3620            })
3621            .collect(),
3622    }
3623}
3624
3625fn build_projected_columns(full_cols: &[ColumnDef], needed_combined: &[usize]) -> Vec<ColumnDef> {
3626    needed_combined
3627        .iter()
3628        .enumerate()
3629        .map(|(new_pos, &old_pos)| {
3630            let orig = &full_cols[old_pos];
3631            ColumnDef {
3632                name: orig.name.clone(),
3633                data_type: orig.data_type,
3634                nullable: orig.nullable,
3635                position: new_pos as u16,
3636            }
3637        })
3638        .collect()
3639}
3640
/// Fast-path equi-join on a single integer key column.
///
/// Joins the owned `outer_rows` with `inner_rows` on
/// `outer[outer_key_col] == inner[inner_key_col]`, comparing `i64` keys directly
/// rather than hashing whole `Value`s.
///
/// - Inner rows with a `Value::Null` key are never matched. Outer rows whose key is
///   not an integer find no partner.
/// - Unmatched rows are NULL-padded for LEFT (outer side) and RIGHT (inner side) joins.
/// - If `outer_is_sorted` is true and the join is INNER/CROSS, a merge join replaces the
///   hash table. This requires the outer rows to arrive in ascending key order.
///   NOTE(review): the caller derives `outer_is_sorted` from the key being the outer PK
///   column; confirm that scans yield rows in integer-PK order.
/// - With `projection`, each output row contains only the projected slots. Otherwise it
///   is the full outer row followed by the full inner row. `cap` is the expected output
///   row width.
///
/// Returns `Err(outer_rows)`, handing the rows back before any are consumed, if some
/// inner key is neither an integer nor NULL. The caller can then fall back to a generic
/// join.
#[allow(clippy::too_many_arguments)]
fn try_integer_join(
    outer_rows: Vec<Vec<Value>>,
    inner_rows: &[Vec<Value>],
    join_type: &JoinType,
    outer_key_col: usize,
    inner_key_col: usize,
    outer_col_count: usize,
    inner_col_count: usize,
    outer_is_sorted: bool,
    projection: Option<&CombineProjection>,
) -> std::result::Result<Vec<Vec<Value>>, Vec<Vec<Value>>> {
    let cap = projection.map_or(outer_col_count + inner_col_count, |p| p.slots.len());

    // Merge-join path (INNER/CROSS only, outer rows assumed ascending by key).
    if outer_is_sorted && matches!(join_type, JoinType::Inner | JoinType::Cross) {
        // (key, inner row index) for every non-NULL inner key. Sorting is skipped when
        // the keys were already observed in non-decreasing order.
        let mut sorted_inner: Vec<(i64, usize)> = Vec::with_capacity(inner_rows.len());
        let mut needs_sort = false;
        let mut prev = i64::MIN;
        for (i, r) in inner_rows.iter().enumerate() {
            match r[inner_key_col] {
                Value::Integer(k) => {
                    if k < prev {
                        needs_sort = true;
                    }
                    prev = k;
                    sorted_inner.push((k, i));
                }
                Value::Null => {}
                _ => return Err(outer_rows),
            }
        }
        if needs_sort {
            sorted_inner.sort_unstable_by_key(|&(k, _)| k);
        }

        let mut result = Vec::with_capacity(outer_rows.len());
        // `j` only moves forward. It points at the first inner entry whose key is >= the
        // current outer key, which is why the outer rows must be ascending.
        let mut j = 0;
        for mut outer in outer_rows {
            let ok = match outer[outer_key_col] {
                Value::Integer(i) => i,
                _ => continue,
            };
            while j < sorted_inner.len() && sorted_inner[j].0 < ok {
                j += 1;
            }
            // `j` is not advanced past the run of equal keys, so a following outer row
            // with the same key can match the same inner rows again.
            let mut kk = j;
            while kk < sorted_inner.len() && sorted_inner[kk].0 == ok {
                // On the last inner match for this outer row, values are moved out of
                // `outer` instead of cloned.
                let is_last = kk + 1 >= sorted_inner.len() || sorted_inner[kk + 1].0 != ok;
                let inner = &inner_rows[sorted_inner[kk].1];
                if let Some(proj) = projection {
                    if is_last {
                        result.push(
                            proj.slots
                                .iter()
                                .map(|&(idx, is_inner)| {
                                    if is_inner {
                                        inner[idx].clone()
                                    } else {
                                        std::mem::take(&mut outer[idx])
                                    }
                                })
                                .collect(),
                        );
                    } else {
                        result.push(combine_row_projected(&outer, inner, proj));
                    }
                } else if is_last {
                    outer.extend(inner.iter().cloned());
                    result.push(outer);
                    break;
                } else {
                    result.push(combine_row(&outer, inner, cap));
                }
                kk += 1;
            }
        }
        return Ok(result);
    }

    // Hash-join path: key -> indices of inner rows with that key (NULL keys omitted).
    let mut inner_map: HashMap<i64, Vec<usize>> = HashMap::with_capacity(inner_rows.len());
    for (idx, inner) in inner_rows.iter().enumerate() {
        match &inner[inner_key_col] {
            Value::Integer(k) => inner_map.entry(*k).or_default().push(idx),
            Value::Null => {}
            _ => return Err(outer_rows),
        }
    }

    let mut result = Vec::with_capacity(inner_rows.len());

    match join_type {
        JoinType::Inner | JoinType::Cross => {
            for mut outer in outer_rows {
                if let Value::Integer(k) = outer[outer_key_col] {
                    if let Some(indices) = inner_map.get(&k) {
                        if let Some(proj) = projection {
                            for &idx in indices {
                                result.push(combine_row_projected(&outer, &inner_rows[idx], proj));
                            }
                        } else {
                            // Clone `outer` for every match but the last, then move it.
                            for &idx in &indices[..indices.len() - 1] {
                                result.push(combine_row(&outer, &inner_rows[idx], cap));
                            }
                            let last_idx = *indices.last().unwrap();
                            outer.extend(inner_rows[last_idx].iter().cloned());
                            result.push(outer);
                        }
                    }
                }
            }
        }
        JoinType::Left => {
            for mut outer in outer_rows {
                if let Value::Integer(k) = outer[outer_key_col] {
                    if let Some(indices) = inner_map.get(&k) {
                        if let Some(proj) = projection {
                            for &idx in indices {
                                result.push(combine_row_projected(&outer, &inner_rows[idx], proj));
                            }
                        } else {
                            for &idx in &indices[..indices.len() - 1] {
                                result.push(combine_row(&outer, &inner_rows[idx], cap));
                            }
                            let last_idx = *indices.last().unwrap();
                            outer.extend(inner_rows[last_idx].iter().cloned());
                            result.push(outer);
                        }
                        continue;
                    }
                }
                // No match (or a non-integer outer key): emit the outer row NULL-padded.
                if let Some(proj) = projection {
                    let null_inner = vec![Value::Null; inner_col_count];
                    result.push(combine_row_projected(&outer, &null_inner, proj));
                } else {
                    outer.resize(cap, Value::Null);
                    result.push(outer);
                }
            }
        }
        JoinType::Right => {
            // Track which inner rows were matched so the rest can be emitted afterwards.
            let mut inner_matched = vec![false; inner_rows.len()];
            for mut outer in outer_rows {
                if let Value::Integer(k) = outer[outer_key_col] {
                    if let Some(indices) = inner_map.get(&k) {
                        if let Some(proj) = projection {
                            for &idx in indices {
                                result.push(combine_row_projected(&outer, &inner_rows[idx], proj));
                                inner_matched[idx] = true;
                            }
                        } else {
                            for &idx in &indices[..indices.len() - 1] {
                                result.push(combine_row(&outer, &inner_rows[idx], cap));
                                inner_matched[idx] = true;
                            }
                            let last_idx = *indices.last().unwrap();
                            inner_matched[last_idx] = true;
                            outer.extend(inner_rows[last_idx].iter().cloned());
                            result.push(outer);
                        }
                    }
                }
            }
            // Unmatched inner rows (including NULL-key ones) get a NULL-padded outer side.
            for (j, inner) in inner_rows.iter().enumerate() {
                if !inner_matched[j] {
                    if let Some(proj) = projection {
                        let null_outer = vec![Value::Null; outer_col_count];
                        result.push(combine_row_projected(&null_outer, inner, proj));
                    } else {
                        let mut padded = Vec::with_capacity(cap);
                        padded.resize(outer_col_count, Value::Null);
                        padded.extend(inner.iter().cloned());
                        result.push(padded);
                    }
                }
            }
        }
    }

    Ok(result)
}
3821
3822#[allow(clippy::too_many_arguments)]
3823fn exec_join_step(
3824    mut outer_rows: Vec<Vec<Value>>,
3825    inner_rows: &[Vec<Value>],
3826    join: &JoinClause,
3827    combined_cols: &[ColumnDef],
3828    outer_col_count: usize,
3829    inner_col_count: usize,
3830    outer_pk_col: Option<usize>,
3831    projection: Option<&CombineProjection>,
3832) -> Vec<Vec<Value>> {
3833    let equi_pairs = join
3834        .on_clause
3835        .as_ref()
3836        .map(|on| extract_equi_join_keys(on, combined_cols, outer_col_count))
3837        .unwrap_or_default();
3838
3839    let is_pure_equi = join.on_clause.as_ref().map_or(true, |on| {
3840        !equi_pairs.is_empty() && count_conjuncts(on) == equi_pairs.len()
3841    });
3842
3843    let effective_proj = if is_pure_equi { projection } else { None };
3844
3845    if equi_pairs.len() == 1 && is_pure_equi {
3846        let (outer_key_col, inner_key_col) = equi_pairs[0];
3847        let outer_is_sorted = outer_pk_col == Some(outer_key_col);
3848        match try_integer_join(
3849            outer_rows,
3850            inner_rows,
3851            &join.join_type,
3852            outer_key_col,
3853            inner_key_col,
3854            outer_col_count,
3855            inner_col_count,
3856            outer_is_sorted,
3857            effective_proj,
3858        ) {
3859            Ok(result) => return result,
3860            Err(rows) => outer_rows = rows,
3861        }
3862    }
3863
3864    let outer_key_cols: Vec<usize> = equi_pairs.iter().map(|&(o, _)| o).collect();
3865    let inner_key_cols: Vec<usize> = equi_pairs.iter().map(|&(_, i)| i).collect();
3866
3867    let mut inner_map: HashMap<Vec<Value>, Vec<usize>> = HashMap::new();
3868    for (idx, inner) in inner_rows.iter().enumerate() {
3869        inner_map
3870            .entry(hash_key(inner, &inner_key_cols))
3871            .or_default()
3872            .push(idx);
3873    }
3874
3875    let cap = effective_proj.map_or(outer_col_count + inner_col_count, |p| p.slots.len());
3876    let mut result = Vec::new();
3877
3878    if is_pure_equi {
3879        match join.join_type {
3880            JoinType::Inner | JoinType::Cross => {
3881                for mut outer in outer_rows {
3882                    let key = hash_key(&outer, &outer_key_cols);
3883                    if let Some(indices) = inner_map.get(&key) {
3884                        if let Some(proj) = effective_proj {
3885                            for &idx in indices {
3886                                result.push(combine_row_projected(&outer, &inner_rows[idx], proj));
3887                            }
3888                        } else {
3889                            for &idx in &indices[..indices.len() - 1] {
3890                                result.push(combine_row(&outer, &inner_rows[idx], cap));
3891                            }
3892                            let last_idx = *indices.last().unwrap();
3893                            outer.extend(inner_rows[last_idx].iter().cloned());
3894                            result.push(outer);
3895                        }
3896                    }
3897                }
3898            }
3899            JoinType::Left => {
3900                for mut outer in outer_rows {
3901                    let key = hash_key(&outer, &outer_key_cols);
3902                    if let Some(indices) = inner_map.get(&key) {
3903                        if let Some(proj) = effective_proj {
3904                            for &idx in indices {
3905                                result.push(combine_row_projected(&outer, &inner_rows[idx], proj));
3906                            }
3907                        } else {
3908                            for &idx in &indices[..indices.len() - 1] {
3909                                result.push(combine_row(&outer, &inner_rows[idx], cap));
3910                            }
3911                            let last_idx = *indices.last().unwrap();
3912                            outer.extend(inner_rows[last_idx].iter().cloned());
3913                            result.push(outer);
3914                        }
3915                    } else if let Some(proj) = effective_proj {
3916                        let null_inner = vec![Value::Null; inner_col_count];
3917                        result.push(combine_row_projected(&outer, &null_inner, proj));
3918                    } else {
3919                        outer.resize(cap, Value::Null);
3920                        result.push(outer);
3921                    }
3922                }
3923            }
3924            JoinType::Right => {
3925                let mut inner_matched = vec![false; inner_rows.len()];
3926                for mut outer in outer_rows {
3927                    let key = hash_key(&outer, &outer_key_cols);
3928                    if let Some(indices) = inner_map.get(&key) {
3929                        if let Some(proj) = effective_proj {
3930                            for &idx in indices {
3931                                result.push(combine_row_projected(&outer, &inner_rows[idx], proj));
3932                                inner_matched[idx] = true;
3933                            }
3934                        } else {
3935                            for &idx in &indices[..indices.len() - 1] {
3936                                result.push(combine_row(&outer, &inner_rows[idx], cap));
3937                                inner_matched[idx] = true;
3938                            }
3939                            let last_idx = *indices.last().unwrap();
3940                            inner_matched[last_idx] = true;
3941                            outer.extend(inner_rows[last_idx].iter().cloned());
3942                            result.push(outer);
3943                        }
3944                    }
3945                }
3946                for (j, inner) in inner_rows.iter().enumerate() {
3947                    if !inner_matched[j] {
3948                        if let Some(proj) = effective_proj {
3949                            let null_outer = vec![Value::Null; outer_col_count];
3950                            result.push(combine_row_projected(&null_outer, inner, proj));
3951                        } else {
3952                            let mut padded = Vec::with_capacity(cap);
3953                            padded.resize(outer_col_count, Value::Null);
3954                            padded.extend(inner.iter().cloned());
3955                            result.push(padded);
3956                        }
3957                    }
3958                }
3959            }
3960        }
3961    } else {
3962        let combined_map = ColumnMap::new(combined_cols);
3963        let on_matches = |combined: &[Value]| -> bool {
3964            match join.on_clause {
3965                Some(ref on_expr) => eval_expr(on_expr, &combined_map, combined)
3966                    .map(|v| is_truthy(&v))
3967                    .unwrap_or(false),
3968                None => true,
3969            }
3970        };
3971
3972        match join.join_type {
3973            JoinType::Inner | JoinType::Cross => {
3974                for outer in &outer_rows {
3975                    let key = hash_key(outer, &outer_key_cols);
3976                    if let Some(indices) = inner_map.get(&key) {
3977                        for &idx in indices {
3978                            let combined = combine_row(outer, &inner_rows[idx], cap);
3979                            if on_matches(&combined) {
3980                                result.push(combined);
3981                            }
3982                        }
3983                    }
3984                }
3985            }
3986            JoinType::Left => {
3987                for outer in &outer_rows {
3988                    let key = hash_key(outer, &outer_key_cols);
3989                    let mut matched = false;
3990                    if let Some(indices) = inner_map.get(&key) {
3991                        for &idx in indices {
3992                            let combined = combine_row(outer, &inner_rows[idx], cap);
3993                            if on_matches(&combined) {
3994                                result.push(combined);
3995                                matched = true;
3996                            }
3997                        }
3998                    }
3999                    if !matched {
4000                        let mut padded = Vec::with_capacity(cap);
4001                        padded.extend(outer.iter().cloned());
4002                        padded.resize(cap, Value::Null);
4003                        result.push(padded);
4004                    }
4005                }
4006            }
4007            JoinType::Right => {
4008                let mut inner_matched = vec![false; inner_rows.len()];
4009                for outer in &outer_rows {
4010                    let key = hash_key(outer, &outer_key_cols);
4011                    if let Some(indices) = inner_map.get(&key) {
4012                        for &idx in indices {
4013                            let combined = combine_row(outer, &inner_rows[idx], cap);
4014                            if on_matches(&combined) {
4015                                result.push(combined);
4016                                inner_matched[idx] = true;
4017                            }
4018                        }
4019                    }
4020                }
4021                for (j, inner) in inner_rows.iter().enumerate() {
4022                    if !inner_matched[j] {
4023                        let mut padded = Vec::with_capacity(cap);
4024                        padded.resize(outer_col_count, Value::Null);
4025                        padded.extend(inner.iter().cloned());
4026                        result.push(padded);
4027                    }
4028                }
4029            }
4030        }
4031    }
4032
4033    result
4034}
4035
/// Returns the name a table is referred to by in a query: the alias if one was given,
/// otherwise the table name. Either way the result is ASCII-lowercased.
fn table_alias_or_name(name: &str, alias: &Option<String>) -> String {
    alias.as_deref().unwrap_or(name).to_ascii_lowercase()
}
4042
4043fn collect_all_rows_raw(
4044    rtx: &mut citadel_txn::read_txn::ReadTxn<'_>,
4045    table_schema: &TableSchema,
4046) -> Result<Vec<Vec<Value>>> {
4047    let lower_name = &table_schema.name;
4048    let entry_count = rtx.table_entry_count(lower_name.as_bytes()).unwrap_or(0) as usize;
4049    let mut rows = Vec::with_capacity(entry_count);
4050    let mut scan_err: Option<SqlError> = None;
4051    rtx.table_scan_raw(lower_name.as_bytes(), |key, value| {
4052        match decode_full_row(table_schema, key, value) {
4053            Ok(row) => rows.push(row),
4054            Err(e) => {
4055                scan_err = Some(e);
4056                return false;
4057            }
4058        }
4059        true
4060    })
4061    .map_err(SqlError::Storage)?;
4062    if let Some(e) = scan_err {
4063        return Err(e);
4064    }
4065    Ok(rows)
4066}
4067
4068fn collect_all_rows_write(
4069    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
4070    table_schema: &TableSchema,
4071) -> Result<Vec<Vec<Value>>> {
4072    collect_rows_write(wtx, table_schema, &None, None).map(|(rows, _)| rows)
4073}
4074
4075fn has_ambiguous_bare_ref(expr: &Expr, columns: &[ColumnDef]) -> bool {
4076    match expr {
4077        Expr::Column(name) => {
4078            let lower = name.to_ascii_lowercase();
4079            columns
4080                .iter()
4081                .filter(|c| c.name == lower || c.name.ends_with(&format!(".{lower}")))
4082                .count()
4083                > 1
4084        }
4085        Expr::BinaryOp { left, right, .. } => {
4086            has_ambiguous_bare_ref(left, columns) || has_ambiguous_bare_ref(right, columns)
4087        }
4088        Expr::UnaryOp { expr: inner, .. } | Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
4089            has_ambiguous_bare_ref(inner, columns)
4090        }
4091        Expr::Function { args, .. } | Expr::Coalesce(args) => {
4092            args.iter().any(|a| has_ambiguous_bare_ref(a, columns))
4093        }
4094        Expr::Between {
4095            expr: e, low, high, ..
4096        } => {
4097            has_ambiguous_bare_ref(e, columns)
4098                || has_ambiguous_bare_ref(low, columns)
4099                || has_ambiguous_bare_ref(high, columns)
4100        }
4101        Expr::InList { expr: e, list, .. } => {
4102            has_ambiguous_bare_ref(e, columns)
4103                || list.iter().any(|a| has_ambiguous_bare_ref(a, columns))
4104        }
4105        Expr::Like {
4106            expr: e,
4107            pattern,
4108            escape,
4109            ..
4110        } => {
4111            has_ambiguous_bare_ref(e, columns)
4112                || has_ambiguous_bare_ref(pattern, columns)
4113                || escape
4114                    .as_ref()
4115                    .is_some_and(|esc| has_ambiguous_bare_ref(esc, columns))
4116        }
4117        Expr::Cast { expr: inner, .. } => has_ambiguous_bare_ref(inner, columns),
4118        Expr::Case {
4119            operand,
4120            conditions,
4121            else_result,
4122        } => {
4123            operand
4124                .as_ref()
4125                .is_some_and(|o| has_ambiguous_bare_ref(o, columns))
4126                || conditions.iter().any(|(w, t)| {
4127                    has_ambiguous_bare_ref(w, columns) || has_ambiguous_bare_ref(t, columns)
4128                })
4129                || else_result
4130                    .as_ref()
4131                    .is_some_and(|e| has_ambiguous_bare_ref(e, columns))
4132        }
4133        _ => false,
4134    }
4135}
4136
/// Column-pruning plan for a multi-table SELECT, produced by
/// `compute_join_needed_columns`.
struct JoinColumnPlan {
    // For each joined table (FROM table first, then JOINs in order), the ascending
    // table-local indices of the columns that must be decoded. These are the columns
    // referenced by the query's output expressions or by any JOIN ... ON clause.
    per_table: Vec<Vec<usize>>,
    // Sorted, deduplicated combined-layout indices referenced by the SELECT list, WHERE,
    // ORDER BY, GROUP BY and HAVING. Columns used only in ON clauses are excluded.
    output_combined: Vec<usize>,
}
4141
4142fn compute_join_needed_columns(
4143    stmt: &SelectStmt,
4144    tables: &[(String, &TableSchema)],
4145) -> Option<JoinColumnPlan> {
4146    for sel in &stmt.columns {
4147        if matches!(sel, SelectColumn::AllColumns) {
4148            return None;
4149        }
4150    }
4151
4152    let combined_cols = build_joined_columns(tables);
4153
4154    for sel in &stmt.columns {
4155        if let SelectColumn::Expr { expr, .. } = sel {
4156            if has_ambiguous_bare_ref(expr, &combined_cols) {
4157                return None;
4158            }
4159        }
4160    }
4161
4162    let mut output_combined: Vec<usize> = Vec::new();
4163    for sel in &stmt.columns {
4164        if let SelectColumn::Expr { expr, .. } = sel {
4165            output_combined.extend(referenced_columns(expr, &combined_cols));
4166        }
4167    }
4168    if let Some(w) = &stmt.where_clause {
4169        output_combined.extend(referenced_columns(w, &combined_cols));
4170    }
4171    for ob in &stmt.order_by {
4172        output_combined.extend(referenced_columns(&ob.expr, &combined_cols));
4173    }
4174    for gb in &stmt.group_by {
4175        output_combined.extend(referenced_columns(gb, &combined_cols));
4176    }
4177    if let Some(h) = &stmt.having {
4178        output_combined.extend(referenced_columns(h, &combined_cols));
4179    }
4180    output_combined.sort_unstable();
4181    output_combined.dedup();
4182
4183    let mut needed_combined = output_combined.clone();
4184    for join in &stmt.joins {
4185        if let Some(on_expr) = &join.on_clause {
4186            needed_combined.extend(referenced_columns(on_expr, &combined_cols));
4187        }
4188    }
4189    needed_combined.sort_unstable();
4190    needed_combined.dedup();
4191
4192    let mut offsets = Vec::with_capacity(tables.len() + 1);
4193    offsets.push(0usize);
4194    for (_, s) in tables {
4195        offsets.push(offsets.last().unwrap() + s.columns.len());
4196    }
4197
4198    let mut per_table: Vec<Vec<usize>> = tables.iter().map(|_| Vec::new()).collect();
4199    for &ci in &needed_combined {
4200        for (t, _) in tables.iter().enumerate() {
4201            let start = offsets[t];
4202            let end = offsets[t + 1];
4203            if ci >= start && ci < end {
4204                per_table[t].push(ci - start);
4205                break;
4206            }
4207        }
4208    }
4209
4210    Some(JoinColumnPlan {
4211        per_table,
4212        output_combined,
4213    })
4214}
4215
4216fn collect_rows_partial(
4217    rtx: &mut citadel_txn::read_txn::ReadTxn<'_>,
4218    table_schema: &TableSchema,
4219    needed: &[usize],
4220) -> Result<Vec<Vec<Value>>> {
4221    if needed.is_empty() || needed.len() == table_schema.columns.len() {
4222        return collect_all_rows_raw(rtx, table_schema);
4223    }
4224    let ctx = PartialDecodeCtx::new(table_schema, needed);
4225    let lower_name = &table_schema.name;
4226    let entry_count = rtx.table_entry_count(lower_name.as_bytes()).unwrap_or(0) as usize;
4227    let mut rows = Vec::with_capacity(entry_count);
4228    let mut scan_err: Option<SqlError> = None;
4229    rtx.table_scan_raw(lower_name.as_bytes(), |key, value| {
4230        match ctx.decode(key, value) {
4231            Ok(row) => rows.push(row),
4232            Err(e) => {
4233                scan_err = Some(e);
4234                return false;
4235            }
4236        }
4237        true
4238    })
4239    .map_err(SqlError::Storage)?;
4240    if let Some(e) = scan_err {
4241        return Err(e);
4242    }
4243    Ok(rows)
4244}
4245
4246fn collect_rows_partial_write(
4247    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
4248    table_schema: &TableSchema,
4249    needed: &[usize],
4250) -> Result<Vec<Vec<Value>>> {
4251    if needed.is_empty() || needed.len() == table_schema.columns.len() {
4252        return collect_all_rows_write(wtx, table_schema);
4253    }
4254    let ctx = PartialDecodeCtx::new(table_schema, needed);
4255    let lower_name = &table_schema.name;
4256    let entry_count = wtx.table_entry_count(lower_name.as_bytes()).unwrap_or(0) as usize;
4257    let mut rows = Vec::with_capacity(entry_count);
4258    let mut scan_err: Option<SqlError> = None;
4259    wtx.table_scan_from(lower_name.as_bytes(), b"", |key, value| {
4260        match ctx.decode(key, value) {
4261            Ok(row) => rows.push(row),
4262            Err(e) => {
4263                scan_err = Some(e);
4264                return Ok(false);
4265            }
4266        }
4267        Ok(true)
4268    })
4269    .map_err(SqlError::Storage)?;
4270    if let Some(e) = scan_err {
4271        return Err(e);
4272    }
4273    Ok(rows)
4274}
4275
4276fn exec_select_join(
4277    db: &Database,
4278    schema: &SchemaManager,
4279    stmt: &SelectStmt,
4280) -> Result<ExecutionResult> {
4281    let from_schema = resolve_table_name(schema, &stmt.from)?;
4282    let from_alias = table_alias_or_name(&stmt.from, &stmt.from_alias);
4283
4284    let mut all_tables: Vec<(String, &TableSchema)> = vec![(from_alias.clone(), from_schema)];
4285    for join in &stmt.joins {
4286        let inner_schema = resolve_table_name(schema, &join.table.name)?;
4287        let inner_alias = table_alias_or_name(&join.table.name, &join.table.alias);
4288        all_tables.push((inner_alias, inner_schema));
4289    }
4290    let (needed_per_table, output_combined) = match compute_join_needed_columns(stmt, &all_tables) {
4291        Some(plan) => (Some(plan.per_table), Some(plan.output_combined)),
4292        None => (None, None),
4293    };
4294
4295    let mut rtx = db.begin_read();
4296    let mut outer_rows = match &needed_per_table {
4297        Some(n) if !n.is_empty() => collect_rows_partial(&mut rtx, from_schema, &n[0])?,
4298        _ => collect_all_rows_raw(&mut rtx, from_schema)?,
4299    };
4300
4301    let mut tables: Vec<(String, &TableSchema)> = vec![(from_alias.clone(), from_schema)];
4302    let mut cur_outer_pk_col: Option<usize> = if from_schema.primary_key_columns.len() == 1 {
4303        Some(from_schema.primary_key_columns[0] as usize)
4304    } else {
4305        None
4306    };
4307
4308    let num_joins = stmt.joins.len();
4309    let mut last_combined_cols: Option<Vec<ColumnDef>> = None;
4310    for (ji, join) in stmt.joins.iter().enumerate() {
4311        let inner_schema = resolve_table_name(schema, &join.table.name)?;
4312        let inner_alias = table_alias_or_name(&join.table.name, &join.table.alias);
4313        let inner_rows = match &needed_per_table {
4314            Some(n) if ji + 1 < n.len() => {
4315                collect_rows_partial(&mut rtx, inner_schema, &n[ji + 1])?
4316            }
4317            _ => collect_all_rows_raw(&mut rtx, inner_schema)?,
4318        };
4319
4320        let mut preview_tables = tables.clone();
4321        preview_tables.push((inner_alias.clone(), inner_schema));
4322        let combined_cols = build_joined_columns(&preview_tables);
4323
4324        let outer_col_count = if outer_rows.is_empty() {
4325            tables.iter().map(|(_, s)| s.columns.len()).sum()
4326        } else {
4327            outer_rows[0].len()
4328        };
4329        let inner_col_count = inner_schema.columns.len();
4330
4331        let is_last = ji == num_joins - 1;
4332        let proj = if is_last {
4333            output_combined
4334                .as_ref()
4335                .map(|oc| build_combine_projection(oc, outer_col_count))
4336        } else {
4337            None
4338        };
4339
4340        outer_rows = exec_join_step(
4341            outer_rows,
4342            &inner_rows,
4343            join,
4344            &combined_cols,
4345            outer_col_count,
4346            inner_col_count,
4347            cur_outer_pk_col,
4348            proj.as_ref(),
4349        );
4350        last_combined_cols = Some(combined_cols);
4351        tables.push((inner_alias, inner_schema));
4352        cur_outer_pk_col = None;
4353    }
4354    drop(rtx);
4355
4356    let joined_cols = last_combined_cols.unwrap_or_else(|| build_joined_columns(&tables));
4357    if let Some(ref oc) = output_combined {
4358        let actual_width = outer_rows.first().map_or(0, |r| r.len());
4359        if actual_width == oc.len() {
4360            let projected_cols = build_projected_columns(&joined_cols, oc);
4361            return process_select(&projected_cols, outer_rows, stmt, false);
4362        }
4363    }
4364    process_select(&joined_cols, outer_rows, stmt, false)
4365}
4366
4367fn exec_select_join_in_txn(
4368    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
4369    schema: &SchemaManager,
4370    stmt: &SelectStmt,
4371) -> Result<ExecutionResult> {
4372    let from_schema = resolve_table_name(schema, &stmt.from)?;
4373    let from_alias = table_alias_or_name(&stmt.from, &stmt.from_alias);
4374
4375    let mut all_tables: Vec<(String, &TableSchema)> = vec![(from_alias.clone(), from_schema)];
4376    for join in &stmt.joins {
4377        let inner_schema = resolve_table_name(schema, &join.table.name)?;
4378        let inner_alias = table_alias_or_name(&join.table.name, &join.table.alias);
4379        all_tables.push((inner_alias, inner_schema));
4380    }
4381    let (needed_per_table, output_combined) = match compute_join_needed_columns(stmt, &all_tables) {
4382        Some(plan) => (Some(plan.per_table), Some(plan.output_combined)),
4383        None => (None, None),
4384    };
4385
4386    let mut outer_rows = match &needed_per_table {
4387        Some(n) if !n.is_empty() => collect_rows_partial_write(wtx, from_schema, &n[0])?,
4388        _ => collect_all_rows_write(wtx, from_schema)?,
4389    };
4390
4391    let mut tables: Vec<(String, &TableSchema)> = vec![(from_alias.clone(), from_schema)];
4392    let mut cur_outer_pk_col: Option<usize> = if from_schema.primary_key_columns.len() == 1 {
4393        Some(from_schema.primary_key_columns[0] as usize)
4394    } else {
4395        None
4396    };
4397
4398    let num_joins = stmt.joins.len();
4399    for (ji, join) in stmt.joins.iter().enumerate() {
4400        let inner_schema = resolve_table_name(schema, &join.table.name)?;
4401        let inner_alias = table_alias_or_name(&join.table.name, &join.table.alias);
4402        let inner_rows = match &needed_per_table {
4403            Some(n) if ji + 1 < n.len() => {
4404                collect_rows_partial_write(wtx, inner_schema, &n[ji + 1])?
4405            }
4406            _ => collect_all_rows_write(wtx, inner_schema)?,
4407        };
4408
4409        let mut preview_tables = tables.clone();
4410        preview_tables.push((inner_alias.clone(), inner_schema));
4411        let combined_cols = build_joined_columns(&preview_tables);
4412
4413        let outer_col_count = if outer_rows.is_empty() {
4414            tables.iter().map(|(_, s)| s.columns.len()).sum()
4415        } else {
4416            outer_rows[0].len()
4417        };
4418        let inner_col_count = inner_schema.columns.len();
4419
4420        let is_last = ji == num_joins - 1;
4421        let proj = if is_last {
4422            output_combined
4423                .as_ref()
4424                .map(|oc| build_combine_projection(oc, outer_col_count))
4425        } else {
4426            None
4427        };
4428
4429        outer_rows = exec_join_step(
4430            outer_rows,
4431            &inner_rows,
4432            join,
4433            &combined_cols,
4434            outer_col_count,
4435            inner_col_count,
4436            cur_outer_pk_col,
4437            proj.as_ref(),
4438        );
4439        tables.push((inner_alias, inner_schema));
4440        cur_outer_pk_col = None;
4441    }
4442
4443    let joined_cols = build_joined_columns(&tables);
4444    if let Some(ref oc) = output_combined {
4445        let actual_width = outer_rows.first().map_or(0, |r| r.len());
4446        if actual_width == oc.len() {
4447            let projected_cols = build_projected_columns(&joined_cols, oc);
4448            return process_select(&projected_cols, outer_rows, stmt, false);
4449        }
4450    }
4451    process_select(&joined_cols, outer_rows, stmt, false)
4452}
4453
4454fn exec_update(
4455    db: &Database,
4456    schema: &SchemaManager,
4457    stmt: &UpdateStmt,
4458) -> Result<ExecutionResult> {
4459    let materialized;
4460    let stmt = if update_has_subquery(stmt) {
4461        materialized = materialize_update(stmt, &mut |sub| exec_subquery_read(db, schema, sub))?;
4462        &materialized
4463    } else {
4464        stmt
4465    };
4466
4467    let lower_name = stmt.table.to_ascii_lowercase();
4468    let table_schema = schema
4469        .get(&lower_name)
4470        .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
4471
4472    let col_map = ColumnMap::new(&table_schema.columns);
4473    let all_candidates = collect_keyed_rows_read(db, table_schema, &stmt.where_clause)?;
4474    let matching_rows: Vec<(Vec<u8>, Vec<Value>)> = all_candidates
4475        .into_iter()
4476        .filter(|(_, row)| match &stmt.where_clause {
4477            Some(where_expr) => match eval_expr(where_expr, &col_map, row) {
4478                Ok(val) => is_truthy(&val),
4479                Err(_) => false,
4480            },
4481            None => true,
4482        })
4483        .collect();
4484
4485    if matching_rows.is_empty() {
4486        return Ok(ExecutionResult::RowsAffected(0));
4487    }
4488
4489    struct UpdateChange {
4490        old_key: Vec<u8>,
4491        new_key: Vec<u8>,
4492        new_value: Vec<u8>,
4493        pk_changed: bool,
4494        old_row: Vec<Value>,
4495        new_row: Vec<Value>,
4496    }
4497
4498    let pk_indices = table_schema.pk_indices();
4499    let mut changes: Vec<UpdateChange> = Vec::new();
4500
4501    for (old_key, row) in &matching_rows {
4502        let mut new_row = row.clone();
4503        let mut pk_changed = false;
4504
4505        // Evaluate all SET expressions against the original row (SQL standard).
4506        let mut evaluated: Vec<(usize, Value)> = Vec::with_capacity(stmt.assignments.len());
4507        for (col_name, expr) in &stmt.assignments {
4508            let col_idx = table_schema
4509                .column_index(col_name)
4510                .ok_or_else(|| SqlError::ColumnNotFound(col_name.clone()))?;
4511            let new_val = eval_expr(expr, &col_map, row)?;
4512            let col = &table_schema.columns[col_idx];
4513
4514            let got_type = new_val.data_type();
4515            let coerced = if new_val.is_null() {
4516                if !col.nullable {
4517                    return Err(SqlError::NotNullViolation(col.name.clone()));
4518                }
4519                Value::Null
4520            } else {
4521                new_val
4522                    .coerce_into(col.data_type)
4523                    .ok_or_else(|| SqlError::TypeMismatch {
4524                        expected: col.data_type.to_string(),
4525                        got: got_type.to_string(),
4526                    })?
4527            };
4528
4529            evaluated.push((col_idx, coerced));
4530        }
4531
4532        for (col_idx, coerced) in evaluated {
4533            if table_schema.primary_key_columns.contains(&(col_idx as u16)) {
4534                pk_changed = true;
4535            }
4536            new_row[col_idx] = coerced;
4537        }
4538
4539        let pk_values: Vec<Value> = pk_indices.iter().map(|&i| new_row[i].clone()).collect();
4540        let new_key = encode_composite_key(&pk_values);
4541
4542        let non_pk = table_schema.non_pk_indices();
4543        let value_values: Vec<Value> = non_pk.iter().map(|&i| new_row[i].clone()).collect();
4544        let new_value = encode_row(&value_values);
4545
4546        changes.push(UpdateChange {
4547            old_key: old_key.clone(),
4548            new_key,
4549            new_value,
4550            pk_changed,
4551            old_row: row.clone(),
4552            new_row,
4553        });
4554    }
4555
4556    {
4557        use std::collections::HashSet;
4558        let mut new_keys: HashSet<Vec<u8>> = HashSet::new();
4559        for c in &changes {
4560            if c.pk_changed && c.new_key != c.old_key && !new_keys.insert(c.new_key.clone()) {
4561                return Err(SqlError::DuplicateKey);
4562            }
4563        }
4564    }
4565
4566    let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
4567
4568    for c in &changes {
4569        let old_pk: Vec<Value> = pk_indices.iter().map(|&i| c.old_row[i].clone()).collect();
4570
4571        for idx in &table_schema.indices {
4572            if index_columns_changed(idx, &c.old_row, &c.new_row) || c.pk_changed {
4573                let idx_table = TableSchema::index_table_name(&lower_name, &idx.name);
4574                let old_idx_key = encode_index_key(idx, &c.old_row, &old_pk);
4575                wtx.table_delete(&idx_table, &old_idx_key)
4576                    .map_err(SqlError::Storage)?;
4577            }
4578        }
4579
4580        if c.pk_changed {
4581            wtx.table_delete(lower_name.as_bytes(), &c.old_key)
4582                .map_err(SqlError::Storage)?;
4583        }
4584    }
4585
4586    for c in &changes {
4587        let new_pk: Vec<Value> = pk_indices.iter().map(|&i| c.new_row[i].clone()).collect();
4588
4589        if c.pk_changed {
4590            let is_new = wtx
4591                .table_insert(lower_name.as_bytes(), &c.new_key, &c.new_value)
4592                .map_err(SqlError::Storage)?;
4593            if !is_new {
4594                return Err(SqlError::DuplicateKey);
4595            }
4596        } else {
4597            wtx.table_insert(lower_name.as_bytes(), &c.new_key, &c.new_value)
4598                .map_err(SqlError::Storage)?;
4599        }
4600
4601        for idx in &table_schema.indices {
4602            if index_columns_changed(idx, &c.old_row, &c.new_row) || c.pk_changed {
4603                let idx_table = TableSchema::index_table_name(&lower_name, &idx.name);
4604                let new_idx_key = encode_index_key(idx, &c.new_row, &new_pk);
4605                let new_idx_val = encode_index_value(idx, &c.new_row, &new_pk);
4606                let is_new = wtx
4607                    .table_insert(&idx_table, &new_idx_key, &new_idx_val)
4608                    .map_err(SqlError::Storage)?;
4609                if idx.unique && !is_new {
4610                    let indexed_values: Vec<Value> = idx
4611                        .columns
4612                        .iter()
4613                        .map(|&col_idx| c.new_row[col_idx as usize].clone())
4614                        .collect();
4615                    let any_null = indexed_values.iter().any(|v| v.is_null());
4616                    if !any_null {
4617                        return Err(SqlError::UniqueViolation(idx.name.clone()));
4618                    }
4619                }
4620            }
4621        }
4622    }
4623
4624    let count = changes.len() as u64;
4625    wtx.commit().map_err(SqlError::Storage)?;
4626    Ok(ExecutionResult::RowsAffected(count))
4627}
4628
4629fn exec_delete(
4630    db: &Database,
4631    schema: &SchemaManager,
4632    stmt: &DeleteStmt,
4633) -> Result<ExecutionResult> {
4634    let materialized;
4635    let stmt = if delete_has_subquery(stmt) {
4636        materialized = materialize_delete(stmt, &mut |sub| exec_subquery_read(db, schema, sub))?;
4637        &materialized
4638    } else {
4639        stmt
4640    };
4641
4642    let lower_name = stmt.table.to_ascii_lowercase();
4643    let table_schema = schema
4644        .get(&lower_name)
4645        .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
4646
4647    let col_map = ColumnMap::new(&table_schema.columns);
4648    let all_candidates = collect_keyed_rows_read(db, table_schema, &stmt.where_clause)?;
4649    let rows_to_delete: Vec<(Vec<u8>, Vec<Value>)> = all_candidates
4650        .into_iter()
4651        .filter(|(_, row)| match &stmt.where_clause {
4652            Some(where_expr) => match eval_expr(where_expr, &col_map, row) {
4653                Ok(val) => is_truthy(&val),
4654                Err(_) => false,
4655            },
4656            None => true,
4657        })
4658        .collect();
4659
4660    if rows_to_delete.is_empty() {
4661        return Ok(ExecutionResult::RowsAffected(0));
4662    }
4663
4664    let pk_indices = table_schema.pk_indices();
4665    let mut wtx = db.begin_write().map_err(SqlError::Storage)?;
4666    for (key, row) in &rows_to_delete {
4667        let pk_values: Vec<Value> = pk_indices.iter().map(|&i| row[i].clone()).collect();
4668        delete_index_entries(&mut wtx, table_schema, row, &pk_values)?;
4669        wtx.table_delete(lower_name.as_bytes(), key)
4670            .map_err(SqlError::Storage)?;
4671    }
4672    let count = rows_to_delete.len() as u64;
4673    wtx.commit().map_err(SqlError::Storage)?;
4674    Ok(ExecutionResult::RowsAffected(count))
4675}
4676
4677#[derive(Default)]
4678pub struct InsertBufs {
4679    row: Vec<Value>,
4680    pk_values: Vec<Value>,
4681    value_values: Vec<Value>,
4682    key_buf: Vec<u8>,
4683    value_buf: Vec<u8>,
4684    col_indices: Vec<usize>,
4685}
4686
4687impl InsertBufs {
4688    pub fn new() -> Self {
4689        Self {
4690            row: Vec::new(),
4691            pk_values: Vec::new(),
4692            value_values: Vec::new(),
4693            key_buf: Vec::with_capacity(64),
4694            value_buf: Vec::with_capacity(256),
4695            col_indices: Vec::new(),
4696        }
4697    }
4698}
4699
4700pub fn exec_insert_in_txn(
4701    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
4702    schema: &SchemaManager,
4703    stmt: &InsertStmt,
4704    params: &[Value],
4705    bufs: &mut InsertBufs,
4706) -> Result<ExecutionResult> {
4707    let materialized;
4708    let stmt = if insert_has_subquery(stmt) {
4709        materialized = materialize_insert(stmt, &mut |sub| exec_subquery_write(wtx, schema, sub))?;
4710        &materialized
4711    } else {
4712        stmt
4713    };
4714
4715    let table_schema = schema
4716        .get(&stmt.table)
4717        .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
4718
4719    let default_columns;
4720    let insert_columns: &[String] = if stmt.columns.is_empty() {
4721        default_columns = table_schema
4722            .columns
4723            .iter()
4724            .map(|c| c.name.clone())
4725            .collect::<Vec<_>>();
4726        &default_columns
4727    } else {
4728        &stmt.columns
4729    };
4730
4731    bufs.col_indices.clear();
4732    for name in insert_columns {
4733        bufs.col_indices.push(
4734            table_schema
4735                .column_index(name)
4736                .ok_or_else(|| SqlError::ColumnNotFound(name.clone()))?,
4737        );
4738    }
4739
4740    let pk_indices = table_schema.pk_indices();
4741    let non_pk = table_schema.non_pk_indices();
4742
4743    bufs.row.resize(table_schema.columns.len(), Value::Null);
4744    bufs.pk_values.resize(pk_indices.len(), Value::Null);
4745    bufs.value_values.resize(non_pk.len(), Value::Null);
4746
4747    let mut count: u64 = 0;
4748
4749    for value_row in &stmt.values {
4750        if value_row.len() != insert_columns.len() {
4751            return Err(SqlError::InvalidValue(format!(
4752                "expected {} values, got {}",
4753                insert_columns.len(),
4754                value_row.len()
4755            )));
4756        }
4757
4758        for v in bufs.row.iter_mut() {
4759            *v = Value::Null;
4760        }
4761
4762        for (i, expr) in value_row.iter().enumerate() {
4763            let val = if let Expr::Parameter(n) = expr {
4764                params
4765                    .get(n - 1)
4766                    .cloned()
4767                    .ok_or_else(|| SqlError::Parse(format!("unbound parameter ${n}")))?
4768            } else {
4769                eval_const_expr(expr)?
4770            };
4771            let col_idx = bufs.col_indices[i];
4772            let col = &table_schema.columns[col_idx];
4773
4774            let got_type = val.data_type();
4775            bufs.row[col_idx] = if val.is_null() {
4776                Value::Null
4777            } else {
4778                val.coerce_into(col.data_type)
4779                    .ok_or_else(|| SqlError::TypeMismatch {
4780                        expected: col.data_type.to_string(),
4781                        got: got_type.to_string(),
4782                    })?
4783            };
4784        }
4785
4786        for col in &table_schema.columns {
4787            if !col.nullable && bufs.row[col.position as usize].is_null() {
4788                return Err(SqlError::NotNullViolation(col.name.clone()));
4789            }
4790        }
4791
4792        for (j, &i) in pk_indices.iter().enumerate() {
4793            bufs.pk_values[j] = std::mem::replace(&mut bufs.row[i], Value::Null);
4794        }
4795        encode_composite_key_into(&bufs.pk_values, &mut bufs.key_buf);
4796
4797        for (j, &i) in non_pk.iter().enumerate() {
4798            bufs.value_values[j] = std::mem::replace(&mut bufs.row[i], Value::Null);
4799        }
4800        encode_row_into(&bufs.value_values, &mut bufs.value_buf);
4801
4802        if bufs.key_buf.len() > citadel_core::MAX_KEY_SIZE {
4803            return Err(SqlError::KeyTooLarge {
4804                size: bufs.key_buf.len(),
4805                max: citadel_core::MAX_KEY_SIZE,
4806            });
4807        }
4808        if bufs.value_buf.len() > citadel_core::MAX_INLINE_VALUE_SIZE {
4809            return Err(SqlError::RowTooLarge {
4810                size: bufs.value_buf.len(),
4811                max: citadel_core::MAX_INLINE_VALUE_SIZE,
4812            });
4813        }
4814
4815        let is_new = wtx
4816            .table_insert(stmt.table.as_bytes(), &bufs.key_buf, &bufs.value_buf)
4817            .map_err(SqlError::Storage)?;
4818        if !is_new {
4819            return Err(SqlError::DuplicateKey);
4820        }
4821
4822        if !table_schema.indices.is_empty() {
4823            for (j, &i) in pk_indices.iter().enumerate() {
4824                bufs.row[i] = bufs.pk_values[j].clone();
4825            }
4826            for (j, &i) in non_pk.iter().enumerate() {
4827                bufs.row[i] = std::mem::replace(&mut bufs.value_values[j], Value::Null);
4828            }
4829            insert_index_entries(wtx, table_schema, &bufs.row, &bufs.pk_values)?;
4830        }
4831        count += 1;
4832    }
4833
4834    Ok(ExecutionResult::RowsAffected(count))
4835}
4836
4837fn exec_select_in_txn(
4838    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
4839    schema: &SchemaManager,
4840    stmt: &SelectStmt,
4841) -> Result<ExecutionResult> {
4842    let materialized;
4843    let stmt = if stmt_has_subquery(stmt) {
4844        materialized = materialize_stmt(stmt, &mut |sub| exec_subquery_write(wtx, schema, sub))?;
4845        &materialized
4846    } else {
4847        stmt
4848    };
4849
4850    if stmt.from.is_empty() {
4851        return exec_select_no_from(stmt);
4852    }
4853
4854    if !stmt.joins.is_empty() {
4855        return exec_select_join_in_txn(wtx, schema, stmt);
4856    }
4857
4858    let lower_name = stmt.from.to_ascii_lowercase();
4859    let table_schema = schema
4860        .get(&lower_name)
4861        .ok_or_else(|| SqlError::TableNotFound(stmt.from.clone()))?;
4862
4863    if let Some(result) = try_count_star_shortcut(stmt, || {
4864        wtx.table_entry_count(lower_name.as_bytes())
4865            .map_err(SqlError::Storage)
4866    })? {
4867        return Ok(result);
4868    }
4869
4870    if let Some(plan) = StreamAggPlan::try_new(stmt, table_schema)? {
4871        let mut states: Vec<AggState> = plan.ops.iter().map(|(op, _)| AggState::new(op)).collect();
4872        let mut scan_err: Option<SqlError> = None;
4873        if stmt.where_clause.is_none() {
4874            wtx.table_scan_from(lower_name.as_bytes(), b"", |key, value| {
4875                Ok(plan.feed_row_raw(key, value, &mut states, &mut scan_err))
4876            })
4877            .map_err(SqlError::Storage)?;
4878        } else {
4879            let col_map = ColumnMap::new(&table_schema.columns);
4880            wtx.table_scan_from(lower_name.as_bytes(), b"", |key, value| {
4881                Ok(plan.feed_row(
4882                    key,
4883                    value,
4884                    table_schema,
4885                    &col_map,
4886                    &stmt.where_clause,
4887                    &mut states,
4888                    &mut scan_err,
4889                ))
4890            })
4891            .map_err(SqlError::Storage)?;
4892        }
4893        if let Some(e) = scan_err {
4894            return Err(e);
4895        }
4896        return Ok(plan.finish(states));
4897    }
4898
4899    if let Some(plan) = StreamGroupByPlan::try_new(stmt, table_schema)? {
4900        let lower = lower_name.clone();
4901        return plan.execute_scan(|cb| {
4902            wtx.table_scan_from(lower.as_bytes(), b"", |key, value| Ok(cb(key, value)))
4903        });
4904    }
4905
4906    if let Some(plan) = TopKScanPlan::try_new(stmt, table_schema)? {
4907        let lower = lower_name.clone();
4908        return plan.execute_scan(table_schema, stmt, |cb| {
4909            wtx.table_scan_from(lower.as_bytes(), b"", |key, value| Ok(cb(key, value)))
4910        });
4911    }
4912
4913    let scan_limit = compute_scan_limit(stmt);
4914    let (rows, predicate_applied) =
4915        collect_rows_write(wtx, table_schema, &stmt.where_clause, scan_limit)?;
4916    process_select(&table_schema.columns, rows, stmt, predicate_applied)
4917}
4918
4919fn exec_update_in_txn(
4920    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
4921    schema: &SchemaManager,
4922    stmt: &UpdateStmt,
4923) -> Result<ExecutionResult> {
4924    let materialized;
4925    let stmt = if update_has_subquery(stmt) {
4926        materialized = materialize_update(stmt, &mut |sub| exec_subquery_write(wtx, schema, sub))?;
4927        &materialized
4928    } else {
4929        stmt
4930    };
4931
4932    let lower_name = stmt.table.to_ascii_lowercase();
4933    let table_schema = schema
4934        .get(&lower_name)
4935        .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
4936
4937    let col_map = ColumnMap::new(&table_schema.columns);
4938    let all_candidates = collect_keyed_rows_write(wtx, table_schema, &stmt.where_clause)?;
4939    let matching_rows: Vec<(Vec<u8>, Vec<Value>)> = all_candidates
4940        .into_iter()
4941        .filter(|(_, row)| match &stmt.where_clause {
4942            Some(where_expr) => match eval_expr(where_expr, &col_map, row) {
4943                Ok(val) => is_truthy(&val),
4944                Err(_) => false,
4945            },
4946            None => true,
4947        })
4948        .collect();
4949
4950    if matching_rows.is_empty() {
4951        return Ok(ExecutionResult::RowsAffected(0));
4952    }
4953
4954    struct UpdateChange {
4955        old_key: Vec<u8>,
4956        new_key: Vec<u8>,
4957        new_value: Vec<u8>,
4958        pk_changed: bool,
4959        old_row: Vec<Value>,
4960        new_row: Vec<Value>,
4961    }
4962
4963    let pk_indices = table_schema.pk_indices();
4964    let mut changes: Vec<UpdateChange> = Vec::new();
4965
4966    for (old_key, row) in &matching_rows {
4967        let mut new_row = row.clone();
4968        let mut pk_changed = false;
4969
4970        // Evaluate all SET expressions against the original row (SQL standard).
4971        let mut evaluated: Vec<(usize, Value)> = Vec::with_capacity(stmt.assignments.len());
4972        for (col_name, expr) in &stmt.assignments {
4973            let col_idx = table_schema
4974                .column_index(col_name)
4975                .ok_or_else(|| SqlError::ColumnNotFound(col_name.clone()))?;
4976            let new_val = eval_expr(expr, &col_map, row)?;
4977            let col = &table_schema.columns[col_idx];
4978
4979            let got_type = new_val.data_type();
4980            let coerced = if new_val.is_null() {
4981                if !col.nullable {
4982                    return Err(SqlError::NotNullViolation(col.name.clone()));
4983                }
4984                Value::Null
4985            } else {
4986                new_val
4987                    .coerce_into(col.data_type)
4988                    .ok_or_else(|| SqlError::TypeMismatch {
4989                        expected: col.data_type.to_string(),
4990                        got: got_type.to_string(),
4991                    })?
4992            };
4993
4994            evaluated.push((col_idx, coerced));
4995        }
4996
4997        for (col_idx, coerced) in evaluated {
4998            if table_schema.primary_key_columns.contains(&(col_idx as u16)) {
4999                pk_changed = true;
5000            }
5001            new_row[col_idx] = coerced;
5002        }
5003
5004        let pk_values: Vec<Value> = pk_indices.iter().map(|&i| new_row[i].clone()).collect();
5005        let new_key = encode_composite_key(&pk_values);
5006
5007        let non_pk = table_schema.non_pk_indices();
5008        let value_values: Vec<Value> = non_pk.iter().map(|&i| new_row[i].clone()).collect();
5009        let new_value = encode_row(&value_values);
5010
5011        changes.push(UpdateChange {
5012            old_key: old_key.clone(),
5013            new_key,
5014            new_value,
5015            pk_changed,
5016            old_row: row.clone(),
5017            new_row,
5018        });
5019    }
5020
5021    {
5022        use std::collections::HashSet;
5023        let mut new_keys: HashSet<Vec<u8>> = HashSet::new();
5024        for c in &changes {
5025            if c.pk_changed && c.new_key != c.old_key && !new_keys.insert(c.new_key.clone()) {
5026                return Err(SqlError::DuplicateKey);
5027            }
5028        }
5029    }
5030
5031    for c in &changes {
5032        let old_pk: Vec<Value> = pk_indices.iter().map(|&i| c.old_row[i].clone()).collect();
5033
5034        for idx in &table_schema.indices {
5035            if index_columns_changed(idx, &c.old_row, &c.new_row) || c.pk_changed {
5036                let idx_table = TableSchema::index_table_name(&lower_name, &idx.name);
5037                let old_idx_key = encode_index_key(idx, &c.old_row, &old_pk);
5038                wtx.table_delete(&idx_table, &old_idx_key)
5039                    .map_err(SqlError::Storage)?;
5040            }
5041        }
5042
5043        if c.pk_changed {
5044            wtx.table_delete(lower_name.as_bytes(), &c.old_key)
5045                .map_err(SqlError::Storage)?;
5046        }
5047    }
5048
5049    for c in &changes {
5050        let new_pk: Vec<Value> = pk_indices.iter().map(|&i| c.new_row[i].clone()).collect();
5051
5052        if c.pk_changed {
5053            let is_new = wtx
5054                .table_insert(lower_name.as_bytes(), &c.new_key, &c.new_value)
5055                .map_err(SqlError::Storage)?;
5056            if !is_new {
5057                return Err(SqlError::DuplicateKey);
5058            }
5059        } else {
5060            wtx.table_insert(lower_name.as_bytes(), &c.new_key, &c.new_value)
5061                .map_err(SqlError::Storage)?;
5062        }
5063
5064        for idx in &table_schema.indices {
5065            if index_columns_changed(idx, &c.old_row, &c.new_row) || c.pk_changed {
5066                let idx_table = TableSchema::index_table_name(&lower_name, &idx.name);
5067                let new_idx_key = encode_index_key(idx, &c.new_row, &new_pk);
5068                let new_idx_val = encode_index_value(idx, &c.new_row, &new_pk);
5069                let is_new = wtx
5070                    .table_insert(&idx_table, &new_idx_key, &new_idx_val)
5071                    .map_err(SqlError::Storage)?;
5072                if idx.unique && !is_new {
5073                    let indexed_values: Vec<Value> = idx
5074                        .columns
5075                        .iter()
5076                        .map(|&col_idx| c.new_row[col_idx as usize].clone())
5077                        .collect();
5078                    let any_null = indexed_values.iter().any(|v| v.is_null());
5079                    if !any_null {
5080                        return Err(SqlError::UniqueViolation(idx.name.clone()));
5081                    }
5082                }
5083            }
5084        }
5085    }
5086
5087    let count = changes.len() as u64;
5088    Ok(ExecutionResult::RowsAffected(count))
5089}
5090
5091fn exec_delete_in_txn(
5092    wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
5093    schema: &SchemaManager,
5094    stmt: &DeleteStmt,
5095) -> Result<ExecutionResult> {
5096    let materialized;
5097    let stmt = if delete_has_subquery(stmt) {
5098        materialized = materialize_delete(stmt, &mut |sub| exec_subquery_write(wtx, schema, sub))?;
5099        &materialized
5100    } else {
5101        stmt
5102    };
5103
5104    let lower_name = stmt.table.to_ascii_lowercase();
5105    let table_schema = schema
5106        .get(&lower_name)
5107        .ok_or_else(|| SqlError::TableNotFound(stmt.table.clone()))?;
5108
5109    let col_map = ColumnMap::new(&table_schema.columns);
5110    let all_candidates = collect_keyed_rows_write(wtx, table_schema, &stmt.where_clause)?;
5111    let rows_to_delete: Vec<(Vec<u8>, Vec<Value>)> = all_candidates
5112        .into_iter()
5113        .filter(|(_, row)| match &stmt.where_clause {
5114            Some(where_expr) => match eval_expr(where_expr, &col_map, row) {
5115                Ok(val) => is_truthy(&val),
5116                Err(_) => false,
5117            },
5118            None => true,
5119        })
5120        .collect();
5121
5122    if rows_to_delete.is_empty() {
5123        return Ok(ExecutionResult::RowsAffected(0));
5124    }
5125
5126    let pk_indices = table_schema.pk_indices();
5127    for (key, row) in &rows_to_delete {
5128        let pk_values: Vec<Value> = pk_indices.iter().map(|&i| row[i].clone()).collect();
5129        delete_index_entries(wtx, table_schema, row, &pk_values)?;
5130        wtx.table_delete(lower_name.as_bytes(), key)
5131            .map_err(SqlError::Storage)?;
5132    }
5133    let count = rows_to_delete.len() as u64;
5134    Ok(ExecutionResult::RowsAffected(count))
5135}
5136
5137// ── Aggregation ─────────────────────────────────────────────────────
5138
5139fn exec_aggregate(
5140    columns: &[ColumnDef],
5141    rows: &[Vec<Value>],
5142    stmt: &SelectStmt,
5143) -> Result<ExecutionResult> {
5144    let col_map = ColumnMap::new(columns);
5145    let groups: BTreeMap<Vec<Value>, Vec<&Vec<Value>>> = if stmt.group_by.is_empty() {
5146        let mut m = BTreeMap::new();
5147        m.insert(vec![], rows.iter().collect());
5148        m
5149    } else {
5150        let mut m: BTreeMap<Vec<Value>, Vec<&Vec<Value>>> = BTreeMap::new();
5151        for row in rows {
5152            let group_key: Vec<Value> = stmt
5153                .group_by
5154                .iter()
5155                .map(|expr| eval_expr(expr, &col_map, row))
5156                .collect::<Result<_>>()?;
5157            m.entry(group_key).or_default().push(row);
5158        }
5159        m
5160    };
5161
5162    let mut result_rows = Vec::new();
5163    let output_cols = build_output_columns(&stmt.columns, columns);
5164
5165    for group_rows in groups.values() {
5166        let mut result_row = Vec::new();
5167
5168        for sel_col in &stmt.columns {
5169            match sel_col {
5170                SelectColumn::AllColumns => {
5171                    return Err(SqlError::Unsupported("SELECT * with GROUP BY".into()));
5172                }
5173                SelectColumn::Expr { expr, .. } => {
5174                    let val = eval_aggregate_expr(expr, &col_map, group_rows)?;
5175                    result_row.push(val);
5176                }
5177            }
5178        }
5179
5180        if let Some(ref having) = stmt.having {
5181            let passes = match eval_aggregate_expr(having, &col_map, group_rows) {
5182                Ok(val) => is_truthy(&val),
5183                Err(SqlError::ColumnNotFound(_)) => {
5184                    let output_map = ColumnMap::new(&output_cols);
5185                    match eval_expr(having, &output_map, &result_row) {
5186                        Ok(val) => is_truthy(&val),
5187                        Err(_) => false,
5188                    }
5189                }
5190                Err(e) => return Err(e),
5191            };
5192            if !passes {
5193                continue;
5194            }
5195        }
5196
5197        result_rows.push(result_row);
5198    }
5199
5200    if stmt.distinct {
5201        let mut seen = std::collections::HashSet::new();
5202        result_rows.retain(|row| seen.insert(row.clone()));
5203    }
5204
5205    if !stmt.order_by.is_empty() {
5206        let output_cols = build_output_columns(&stmt.columns, columns);
5207        sort_rows(&mut result_rows, &stmt.order_by, &output_cols)?;
5208    }
5209
5210    if let Some(ref offset_expr) = stmt.offset {
5211        let offset = eval_const_int(offset_expr)?.max(0) as usize;
5212        if offset < result_rows.len() {
5213            result_rows = result_rows.split_off(offset);
5214        } else {
5215            result_rows.clear();
5216        }
5217    }
5218    if let Some(ref limit_expr) = stmt.limit {
5219        let limit = eval_const_int(limit_expr)?.max(0) as usize;
5220        result_rows.truncate(limit);
5221    }
5222
5223    let col_names = stmt
5224        .columns
5225        .iter()
5226        .map(|c| match c {
5227            SelectColumn::AllColumns => "*".into(),
5228            SelectColumn::Expr { alias: Some(a), .. } => a.clone(),
5229            SelectColumn::Expr { expr, .. } => expr_display_name(expr),
5230        })
5231        .collect();
5232
5233    Ok(ExecutionResult::Query(QueryResult {
5234        columns: col_names,
5235        rows: result_rows,
5236    }))
5237}
5238
5239fn eval_aggregate_expr(
5240    expr: &Expr,
5241    col_map: &ColumnMap,
5242    group_rows: &[&Vec<Value>],
5243) -> Result<Value> {
5244    match expr {
5245        Expr::CountStar => Ok(Value::Integer(group_rows.len() as i64)),
5246
5247        Expr::Function { name, args } if is_aggregate_function(name, args.len()) => {
5248            let func = name.to_ascii_uppercase();
5249            if args.len() != 1 {
5250                return Err(SqlError::Unsupported(format!(
5251                    "{func} with {} args",
5252                    args.len()
5253                )));
5254            }
5255            let arg = &args[0];
5256            let values: Vec<Value> = group_rows
5257                .iter()
5258                .map(|row| eval_expr(arg, col_map, row))
5259                .collect::<Result<_>>()?;
5260
5261            match func.as_str() {
5262                "COUNT" => {
5263                    let count = values.iter().filter(|v| !v.is_null()).count();
5264                    Ok(Value::Integer(count as i64))
5265                }
5266                "SUM" => {
5267                    let mut int_sum: i64 = 0;
5268                    let mut real_sum: f64 = 0.0;
5269                    let mut has_real = false;
5270                    let mut all_null = true;
5271                    for v in &values {
5272                        match v {
5273                            Value::Integer(i) => {
5274                                int_sum += i;
5275                                all_null = false;
5276                            }
5277                            Value::Real(r) => {
5278                                real_sum += r;
5279                                has_real = true;
5280                                all_null = false;
5281                            }
5282                            Value::Null => {}
5283                            _ => {
5284                                return Err(SqlError::TypeMismatch {
5285                                    expected: "numeric".into(),
5286                                    got: v.data_type().to_string(),
5287                                })
5288                            }
5289                        }
5290                    }
5291                    if all_null {
5292                        return Ok(Value::Null);
5293                    }
5294                    if has_real {
5295                        Ok(Value::Real(real_sum + int_sum as f64))
5296                    } else {
5297                        Ok(Value::Integer(int_sum))
5298                    }
5299                }
5300                "AVG" => {
5301                    let mut sum: f64 = 0.0;
5302                    let mut count: i64 = 0;
5303                    for v in &values {
5304                        match v {
5305                            Value::Integer(i) => {
5306                                sum += *i as f64;
5307                                count += 1;
5308                            }
5309                            Value::Real(r) => {
5310                                sum += r;
5311                                count += 1;
5312                            }
5313                            Value::Null => {}
5314                            _ => {
5315                                return Err(SqlError::TypeMismatch {
5316                                    expected: "numeric".into(),
5317                                    got: v.data_type().to_string(),
5318                                })
5319                            }
5320                        }
5321                    }
5322                    if count == 0 {
5323                        Ok(Value::Null)
5324                    } else {
5325                        Ok(Value::Real(sum / count as f64))
5326                    }
5327                }
5328                "MIN" => {
5329                    let mut min: Option<&Value> = None;
5330                    for v in &values {
5331                        if v.is_null() {
5332                            continue;
5333                        }
5334                        min = Some(match min {
5335                            None => v,
5336                            Some(m) => {
5337                                if v < m {
5338                                    v
5339                                } else {
5340                                    m
5341                                }
5342                            }
5343                        });
5344                    }
5345                    Ok(min.cloned().unwrap_or(Value::Null))
5346                }
5347                "MAX" => {
5348                    let mut max: Option<&Value> = None;
5349                    for v in &values {
5350                        if v.is_null() {
5351                            continue;
5352                        }
5353                        max = Some(match max {
5354                            None => v,
5355                            Some(m) => {
5356                                if v > m {
5357                                    v
5358                                } else {
5359                                    m
5360                                }
5361                            }
5362                        });
5363                    }
5364                    Ok(max.cloned().unwrap_or(Value::Null))
5365                }
5366                _ => Err(SqlError::Unsupported(format!("aggregate function: {func}"))),
5367            }
5368        }
5369
5370        Expr::Column(_) | Expr::QualifiedColumn { .. } => {
5371            if let Some(first) = group_rows.first() {
5372                eval_expr(expr, col_map, first)
5373            } else {
5374                Ok(Value::Null)
5375            }
5376        }
5377
5378        Expr::Literal(v) => Ok(v.clone()),
5379
5380        Expr::BinaryOp { left, op, right } => {
5381            let l = eval_aggregate_expr(left, col_map, group_rows)?;
5382            let r = eval_aggregate_expr(right, col_map, group_rows)?;
5383            eval_expr(
5384                &Expr::BinaryOp {
5385                    left: Box::new(Expr::Literal(l)),
5386                    op: *op,
5387                    right: Box::new(Expr::Literal(r)),
5388                },
5389                col_map,
5390                &[],
5391            )
5392        }
5393
5394        Expr::UnaryOp { op, expr: e } => {
5395            let v = eval_aggregate_expr(e, col_map, group_rows)?;
5396            eval_expr(
5397                &Expr::UnaryOp {
5398                    op: *op,
5399                    expr: Box::new(Expr::Literal(v)),
5400                },
5401                col_map,
5402                &[],
5403            )
5404        }
5405
5406        Expr::IsNull(e) => {
5407            let v = eval_aggregate_expr(e, col_map, group_rows)?;
5408            Ok(Value::Boolean(v.is_null()))
5409        }
5410
5411        Expr::IsNotNull(e) => {
5412            let v = eval_aggregate_expr(e, col_map, group_rows)?;
5413            Ok(Value::Boolean(!v.is_null()))
5414        }
5415
5416        Expr::Cast { expr: e, data_type } => {
5417            let v = eval_aggregate_expr(e, col_map, group_rows)?;
5418            eval_expr(
5419                &Expr::Cast {
5420                    expr: Box::new(Expr::Literal(v)),
5421                    data_type: *data_type,
5422                },
5423                col_map,
5424                &[],
5425            )
5426        }
5427
5428        Expr::Case {
5429            operand,
5430            conditions,
5431            else_result,
5432        } => {
5433            let op_val = operand
5434                .as_ref()
5435                .map(|e| eval_aggregate_expr(e, col_map, group_rows))
5436                .transpose()?;
5437            if let Some(ov) = &op_val {
5438                for (cond, result) in conditions {
5439                    let cv = eval_aggregate_expr(cond, col_map, group_rows)?;
5440                    if !ov.is_null() && !cv.is_null() && *ov == cv {
5441                        return eval_aggregate_expr(result, col_map, group_rows);
5442                    }
5443                }
5444            } else {
5445                for (cond, result) in conditions {
5446                    let cv = eval_aggregate_expr(cond, col_map, group_rows)?;
5447                    if is_truthy(&cv) {
5448                        return eval_aggregate_expr(result, col_map, group_rows);
5449                    }
5450                }
5451            }
5452            match else_result {
5453                Some(e) => eval_aggregate_expr(e, col_map, group_rows),
5454                None => Ok(Value::Null),
5455            }
5456        }
5457
5458        Expr::Coalesce(args) => {
5459            for arg in args {
5460                let v = eval_aggregate_expr(arg, col_map, group_rows)?;
5461                if !v.is_null() {
5462                    return Ok(v);
5463                }
5464            }
5465            Ok(Value::Null)
5466        }
5467
5468        Expr::Between {
5469            expr: e,
5470            low,
5471            high,
5472            negated,
5473        } => {
5474            let v = eval_aggregate_expr(e, col_map, group_rows)?;
5475            let lo = eval_aggregate_expr(low, col_map, group_rows)?;
5476            let hi = eval_aggregate_expr(high, col_map, group_rows)?;
5477            eval_expr(
5478                &Expr::Between {
5479                    expr: Box::new(Expr::Literal(v)),
5480                    low: Box::new(Expr::Literal(lo)),
5481                    high: Box::new(Expr::Literal(hi)),
5482                    negated: *negated,
5483                },
5484                col_map,
5485                &[],
5486            )
5487        }
5488
5489        Expr::Like {
5490            expr: e,
5491            pattern,
5492            escape,
5493            negated,
5494        } => {
5495            let v = eval_aggregate_expr(e, col_map, group_rows)?;
5496            let p = eval_aggregate_expr(pattern, col_map, group_rows)?;
5497            let esc = escape
5498                .as_ref()
5499                .map(|es| eval_aggregate_expr(es, col_map, group_rows))
5500                .transpose()?;
5501            let esc_box = esc.map(|v| Box::new(Expr::Literal(v)));
5502            eval_expr(
5503                &Expr::Like {
5504                    expr: Box::new(Expr::Literal(v)),
5505                    pattern: Box::new(Expr::Literal(p)),
5506                    escape: esc_box,
5507                    negated: *negated,
5508                },
5509                col_map,
5510                &[],
5511            )
5512        }
5513
5514        Expr::Function { name, args } => {
5515            let evaluated: Vec<Value> = args
5516                .iter()
5517                .map(|a| eval_aggregate_expr(a, col_map, group_rows))
5518                .collect::<Result<_>>()?;
5519            let literal_args: Vec<Expr> = evaluated.into_iter().map(Expr::Literal).collect();
5520            eval_expr(
5521                &Expr::Function {
5522                    name: name.clone(),
5523                    args: literal_args,
5524                },
5525                col_map,
5526                &[],
5527            )
5528        }
5529
5530        _ => Err(SqlError::Unsupported(format!(
5531            "expression in aggregate: {expr:?}"
5532        ))),
5533    }
5534}
5535
/// Returns `true` when a call to `name` with `arg_count` arguments is an
/// aggregate function call. The name is matched ASCII case-insensitively.
///
/// `COUNT`, `SUM` and `AVG` are aggregates at any arity. `MIN` and `MAX` are
/// aggregates only with exactly one argument; other arities are not
/// classified as aggregates.
fn is_aggregate_function(name: &str, arg_count: usize) -> bool {
    let named = |candidate: &str| name.eq_ignore_ascii_case(candidate);
    if named("COUNT") || named("SUM") || named("AVG") {
        return true;
    }
    arg_count == 1 && (named("MIN") || named("MAX"))
}
5541
5542fn is_aggregate_expr(expr: &Expr) -> bool {
5543    match expr {
5544        Expr::CountStar => true,
5545        Expr::Function { name, args } => {
5546            is_aggregate_function(name, args.len()) || args.iter().any(is_aggregate_expr)
5547        }
5548        Expr::BinaryOp { left, right, .. } => is_aggregate_expr(left) || is_aggregate_expr(right),
5549        Expr::UnaryOp { expr, .. }
5550        | Expr::IsNull(expr)
5551        | Expr::IsNotNull(expr)
5552        | Expr::Cast { expr, .. } => is_aggregate_expr(expr),
5553        Expr::Case {
5554            operand,
5555            conditions,
5556            else_result,
5557        } => {
5558            operand.as_ref().is_some_and(|e| is_aggregate_expr(e))
5559                || conditions
5560                    .iter()
5561                    .any(|(c, r)| is_aggregate_expr(c) || is_aggregate_expr(r))
5562                || else_result.as_ref().is_some_and(|e| is_aggregate_expr(e))
5563        }
5564        Expr::Coalesce(args) => args.iter().any(is_aggregate_expr),
5565        Expr::Between {
5566            expr, low, high, ..
5567        } => is_aggregate_expr(expr) || is_aggregate_expr(low) || is_aggregate_expr(high),
5568        Expr::Like {
5569            expr,
5570            pattern,
5571            escape,
5572            ..
5573        } => {
5574            is_aggregate_expr(expr)
5575                || is_aggregate_expr(pattern)
5576                || escape.as_ref().is_some_and(|e| is_aggregate_expr(e))
5577        }
5578        _ => false,
5579    }
5580}
5581
5582// ── Helpers ─────────────────────────────────────────────────────────
5583
/// Precomputed column plan for decoding a stored row in two phases.
///
/// `decode` materializes only a requested subset of columns. `complete` later
/// fills in every column that subset left out.
struct PartialDecodeCtx {
    /// Total number of columns in the table schema (row width).
    num_cols: usize,
    /// Number of primary-key columns encoded in the row key.
    num_pk_cols: usize,

    // Requested columns, used by `decode`.
    /// `(position within the primary key, schema column)` for requested PK columns.
    pk_positions: Vec<(usize, usize)>,
    /// Positions within `non_pk_indices()` for requested non-PK columns.
    nonpk_targets: Vec<usize>,
    /// Schema column for each entry of `nonpk_targets`, in the same order.
    nonpk_schema: Vec<usize>,

    // Columns not requested, used by `complete`.
    /// `(position within the primary key, schema column)` for the other PK columns.
    remaining_pk: Vec<(usize, usize)>,
    /// Positions within `non_pk_indices()` for the other non-PK columns.
    remaining_nonpk_targets: Vec<usize>,
    /// Schema column for each entry of `remaining_nonpk_targets`, in the same order.
    remaining_nonpk_schema: Vec<usize>,
}
5594
5595impl PartialDecodeCtx {
5596    fn new(schema: &TableSchema, needed: &[usize]) -> Self {
5597        let non_pk = schema.non_pk_indices();
5598        let mut pk_positions = Vec::new();
5599        let mut nonpk_targets = Vec::new();
5600        let mut nonpk_schema = Vec::new();
5601
5602        for &col in needed {
5603            if let Some(pk_pos) = schema
5604                .primary_key_columns
5605                .iter()
5606                .position(|&i| i as usize == col)
5607            {
5608                pk_positions.push((pk_pos, col));
5609            } else if let Some(nonpk_idx) = non_pk.iter().position(|&i| i == col) {
5610                nonpk_targets.push(nonpk_idx);
5611                nonpk_schema.push(col);
5612            }
5613        }
5614
5615        let needed_set: std::collections::HashSet<usize> = needed.iter().copied().collect();
5616        let mut remaining_pk = Vec::new();
5617        for (pk_pos, &pk_col) in schema.primary_key_columns.iter().enumerate() {
5618            if !needed_set.contains(&(pk_col as usize)) {
5619                remaining_pk.push((pk_pos, pk_col as usize));
5620            }
5621        }
5622        let mut remaining_nonpk_targets = Vec::new();
5623        let mut remaining_nonpk_schema = Vec::new();
5624        for (nonpk_idx, &col) in non_pk.iter().enumerate() {
5625            if !needed_set.contains(&col) {
5626                remaining_nonpk_targets.push(nonpk_idx);
5627                remaining_nonpk_schema.push(col);
5628            }
5629        }
5630
5631        Self {
5632            pk_positions,
5633            nonpk_targets,
5634            nonpk_schema,
5635            num_cols: schema.columns.len(),
5636            num_pk_cols: schema.primary_key_columns.len(),
5637            remaining_pk,
5638            remaining_nonpk_targets,
5639            remaining_nonpk_schema,
5640        }
5641    }
5642
5643    fn decode(&self, key: &[u8], value: &[u8]) -> Result<Vec<Value>> {
5644        let mut row = vec![Value::Null; self.num_cols];
5645
5646        if self.pk_positions.len() == 1 && self.num_pk_cols == 1 {
5647            let (_, schema_col) = self.pk_positions[0];
5648            let (v, _) = decode_key_value(key)?;
5649            row[schema_col] = v;
5650        } else if !self.pk_positions.is_empty() {
5651            let mut pk_values = decode_composite_key(key, self.num_pk_cols)?;
5652            for &(pk_pos, schema_col) in &self.pk_positions {
5653                row[schema_col] = std::mem::take(&mut pk_values[pk_pos]);
5654            }
5655        }
5656
5657        if !self.nonpk_targets.is_empty() {
5658            decode_columns_into(value, &self.nonpk_targets, &self.nonpk_schema, &mut row)?;
5659        }
5660
5661        Ok(row)
5662    }
5663
5664    fn complete(&self, mut row: Vec<Value>, key: &[u8], value: &[u8]) -> Result<Vec<Value>> {
5665        if !self.remaining_pk.is_empty() {
5666            let mut pk_values = decode_composite_key(key, self.num_pk_cols)?;
5667            for &(pk_pos, schema_col) in &self.remaining_pk {
5668                row[schema_col] = std::mem::take(&mut pk_values[pk_pos]);
5669            }
5670        }
5671        if !self.remaining_nonpk_targets.is_empty() {
5672            let mut values = decode_columns(value, &self.remaining_nonpk_targets)?;
5673            for (i, &schema_col) in self.remaining_nonpk_schema.iter().enumerate() {
5674                row[schema_col] = std::mem::take(&mut values[i]);
5675            }
5676        }
5677        Ok(row)
5678    }
5679}
5680
5681fn decode_full_row(schema: &TableSchema, key: &[u8], value: &[u8]) -> Result<Vec<Value>> {
5682    let mut row = vec![Value::Null; schema.columns.len()];
5683    decode_pk_into(
5684        key,
5685        schema.primary_key_columns.len(),
5686        &mut row,
5687        schema.pk_indices(),
5688    )?;
5689    decode_row_into(value, &mut row, schema.non_pk_indices())?;
5690    Ok(row)
5691}
5692
5693/// Evaluate a constant expression (no column references).
5694fn eval_const_expr(expr: &Expr) -> Result<Value> {
5695    static EMPTY: std::sync::OnceLock<ColumnMap> = std::sync::OnceLock::new();
5696    let empty = EMPTY.get_or_init(|| ColumnMap::new(&[]));
5697    eval_expr(expr, empty, &[])
5698}
5699
5700fn eval_const_int(expr: &Expr) -> Result<i64> {
5701    match eval_const_expr(expr)? {
5702        Value::Integer(i) => Ok(i),
5703        other => Err(SqlError::TypeMismatch {
5704            expected: "INTEGER".into(),
5705            got: other.data_type().to_string(),
5706        }),
5707    }
5708}
5709
5710fn sort_rows(
5711    rows: &mut [Vec<Value>],
5712    order_by: &[OrderByItem],
5713    columns: &[ColumnDef],
5714) -> Result<()> {
5715    if rows.is_empty() {
5716        return Ok(());
5717    }
5718    let col_map = ColumnMap::new(columns);
5719    let mut indices: Vec<usize> = (0..rows.len()).collect();
5720
5721    if let Some(col_idx) = try_resolve_flat_sort_col(order_by, &col_map) {
5722        let desc = order_by[0].descending;
5723        let nulls_first = order_by[0].nulls_first.unwrap_or(!desc);
5724        indices.sort_by(|&a, &b| {
5725            compare_flat_key(&rows[a][col_idx], &rows[b][col_idx], desc, nulls_first)
5726        });
5727    } else {
5728        let keys = extract_sort_keys(rows, order_by, &col_map);
5729        indices.sort_by(|&a, &b| compare_sort_keys(&keys[a], &keys[b], order_by));
5730    }
5731
5732    let sorted: Vec<Vec<Value>> = indices
5733        .iter()
5734        .map(|&i| std::mem::take(&mut rows[i]))
5735        .collect();
5736    rows.iter_mut()
5737        .zip(sorted)
5738        .for_each(|(slot, row)| *slot = row);
5739    Ok(())
5740}
5741
5742fn topk_rows(
5743    rows: &mut [Vec<Value>],
5744    order_by: &[OrderByItem],
5745    columns: &[ColumnDef],
5746    k: usize,
5747) -> Result<()> {
5748    let col_map = ColumnMap::new(columns);
5749    let mut indices: Vec<usize> = (0..rows.len()).collect();
5750
5751    if let Some(col_idx) = try_resolve_flat_sort_col(order_by, &col_map) {
5752        let desc = order_by[0].descending;
5753        let nulls_first = order_by[0].nulls_first.unwrap_or(!desc);
5754        let cmp = |&a: &usize, &b: &usize| {
5755            compare_flat_key(&rows[a][col_idx], &rows[b][col_idx], desc, nulls_first)
5756        };
5757        indices.select_nth_unstable_by(k - 1, cmp);
5758        indices[..k].sort_by(cmp);
5759    } else {
5760        let keys = extract_sort_keys(rows, order_by, &col_map);
5761        let cmp = |&a: &usize, &b: &usize| compare_sort_keys(&keys[a], &keys[b], order_by);
5762        indices.select_nth_unstable_by(k - 1, cmp);
5763        indices[..k].sort_by(cmp);
5764    }
5765
5766    let sorted: Vec<Vec<Value>> = indices[..k]
5767        .iter()
5768        .map(|&i| std::mem::take(&mut rows[i]))
5769        .collect();
5770    rows[..k]
5771        .iter_mut()
5772        .zip(sorted)
5773        .for_each(|(slot, row)| *slot = row);
5774    Ok(())
5775}
5776
5777fn try_resolve_flat_sort_col(order_by: &[OrderByItem], col_map: &ColumnMap) -> Option<usize> {
5778    if order_by.len() != 1 {
5779        return None;
5780    }
5781    match &order_by[0].expr {
5782        Expr::Column(name) => col_map.resolve(&name.to_ascii_lowercase()).ok(),
5783        _ => None,
5784    }
5785}
5786
5787fn compare_flat_key(a: &Value, b: &Value, desc: bool, nulls_first: bool) -> std::cmp::Ordering {
5788    match (a.is_null(), b.is_null()) {
5789        (true, true) => std::cmp::Ordering::Equal,
5790        (true, false) => {
5791            if nulls_first {
5792                std::cmp::Ordering::Less
5793            } else {
5794                std::cmp::Ordering::Greater
5795            }
5796        }
5797        (false, true) => {
5798            if nulls_first {
5799                std::cmp::Ordering::Greater
5800            } else {
5801                std::cmp::Ordering::Less
5802            }
5803        }
5804        (false, false) => {
5805            let cmp = a.cmp(b);
5806            if desc {
5807                cmp.reverse()
5808            } else {
5809                cmp
5810            }
5811        }
5812    }
5813}
5814
5815fn extract_sort_keys(
5816    rows: &[Vec<Value>],
5817    order_by: &[OrderByItem],
5818    col_map: &ColumnMap,
5819) -> Vec<Vec<Value>> {
5820    rows.iter()
5821        .map(|row| {
5822            order_by
5823                .iter()
5824                .map(|item| eval_expr(&item.expr, col_map, row).unwrap_or(Value::Null))
5825                .collect()
5826        })
5827        .collect()
5828}
5829
5830fn compare_sort_keys(a: &[Value], b: &[Value], order_by: &[OrderByItem]) -> std::cmp::Ordering {
5831    for (i, item) in order_by.iter().enumerate() {
5832        let nulls_first = item.nulls_first.unwrap_or(!item.descending);
5833        let ord = match (a[i].is_null(), b[i].is_null()) {
5834            (true, true) => std::cmp::Ordering::Equal,
5835            (true, false) => {
5836                if nulls_first {
5837                    std::cmp::Ordering::Less
5838                } else {
5839                    std::cmp::Ordering::Greater
5840                }
5841            }
5842            (false, true) => {
5843                if nulls_first {
5844                    std::cmp::Ordering::Greater
5845                } else {
5846                    std::cmp::Ordering::Less
5847                }
5848            }
5849            (false, false) => {
5850                let cmp = a[i].cmp(&b[i]);
5851                if item.descending {
5852                    cmp.reverse()
5853                } else {
5854                    cmp
5855                }
5856            }
5857        };
5858        if ord != std::cmp::Ordering::Equal {
5859            return ord;
5860        }
5861    }
5862    std::cmp::Ordering::Equal
5863}
5864
5865fn try_build_index_map(
5866    select_cols: &[SelectColumn],
5867    columns: &[ColumnDef],
5868) -> Option<Vec<(String, usize)>> {
5869    let col_map = ColumnMap::new(columns);
5870    let mut map = Vec::new();
5871    let mut seen = std::collections::HashSet::new();
5872    for sel in select_cols {
5873        match sel {
5874            SelectColumn::AllColumns => {
5875                for col in columns {
5876                    let idx = col.position as usize;
5877                    if !seen.insert(idx) {
5878                        return None;
5879                    }
5880                    map.push((col.name.clone(), idx));
5881                }
5882            }
5883            SelectColumn::Expr { expr, alias } => {
5884                let idx = match expr {
5885                    Expr::Column(name) => col_map.resolve(name).ok()?,
5886                    Expr::QualifiedColumn { table, column } => {
5887                        col_map.resolve_qualified(table, column).ok()?
5888                    }
5889                    _ => return None,
5890                };
5891                if !seen.insert(idx) {
5892                    return None;
5893                }
5894                let name = alias.clone().unwrap_or_else(|| expr_display_name(expr));
5895                map.push((name, idx));
5896            }
5897        }
5898    }
5899    Some(map)
5900}
5901
5902fn project_rows(
5903    columns: &[ColumnDef],
5904    select_cols: &[SelectColumn],
5905    mut rows: Vec<Vec<Value>>,
5906) -> Result<(Vec<String>, Vec<Vec<Value>>)> {
5907    // Fast path: SELECT * — zero clones
5908    if select_cols.len() == 1 && matches!(select_cols[0], SelectColumn::AllColumns) {
5909        let col_names = columns.iter().map(|c| c.name.clone()).collect();
5910        return Ok((col_names, rows));
5911    }
5912
5913    // Fast path: all simple column refs — use mem::take, zero clones
5914    if let Some(map) = try_build_index_map(select_cols, columns) {
5915        let col_names: Vec<String> = map.iter().map(|(n, _)| n.clone()).collect();
5916        // Identity: columns already in the right order — return as-is
5917        if map.len() == columns.len() && map.iter().enumerate().all(|(i, &(_, idx))| idx == i) {
5918            return Ok((col_names, rows));
5919        }
5920        let projected = rows
5921            .iter_mut()
5922            .map(|row| {
5923                map.iter()
5924                    .map(|&(_, idx)| std::mem::take(&mut row[idx]))
5925                    .collect()
5926            })
5927            .collect();
5928        return Ok((col_names, projected));
5929    }
5930
5931    // Fallback: expression evaluation (requires cloning)
5932    let mut col_names = Vec::new();
5933    type Projector = Box<dyn Fn(&[Value]) -> Result<Value>>;
5934    let mut projectors: Vec<Projector> = Vec::new();
5935    let col_map = std::sync::Arc::new(ColumnMap::new(columns));
5936
5937    for sel_col in select_cols {
5938        match sel_col {
5939            SelectColumn::AllColumns => {
5940                for col in columns {
5941                    let idx = col.position as usize;
5942                    col_names.push(col.name.clone());
5943                    projectors.push(Box::new(move |row: &[Value]| Ok(row[idx].clone())));
5944                }
5945            }
5946            SelectColumn::Expr { expr, alias } => {
5947                let name = alias.clone().unwrap_or_else(|| expr_display_name(expr));
5948                col_names.push(name);
5949                let expr = expr.clone();
5950                let map = col_map.clone();
5951                projectors.push(Box::new(move |row: &[Value]| eval_expr(&expr, &map, row)));
5952            }
5953        }
5954    }
5955
5956    let projected = rows
5957        .iter()
5958        .map(|row| {
5959            projectors
5960                .iter()
5961                .map(|p| p(row))
5962                .collect::<Result<Vec<_>>>()
5963        })
5964        .collect::<Result<Vec<_>>>()?;
5965
5966    Ok((col_names, projected))
5967}
5968
5969fn expr_display_name(expr: &Expr) -> String {
5970    match expr {
5971        Expr::Column(name) => name.clone(),
5972        Expr::QualifiedColumn { table, column } => format!("{table}.{column}"),
5973        Expr::Literal(v) => format!("{v}"),
5974        Expr::CountStar => "COUNT(*)".into(),
5975        Expr::Function { name, args } => {
5976            let arg_strs: Vec<String> = args.iter().map(expr_display_name).collect();
5977            format!("{name}({})", arg_strs.join(", "))
5978        }
5979        Expr::BinaryOp { left, op, right } => {
5980            format!(
5981                "{} {} {}",
5982                expr_display_name(left),
5983                op_symbol(op),
5984                expr_display_name(right)
5985            )
5986        }
5987        _ => "?".into(),
5988    }
5989}
5990
5991fn op_symbol(op: &BinOp) -> &'static str {
5992    match op {
5993        BinOp::Add => "+",
5994        BinOp::Sub => "-",
5995        BinOp::Mul => "*",
5996        BinOp::Div => "/",
5997        BinOp::Mod => "%",
5998        BinOp::Eq => "=",
5999        BinOp::NotEq => "<>",
6000        BinOp::Lt => "<",
6001        BinOp::Gt => ">",
6002        BinOp::LtEq => "<=",
6003        BinOp::GtEq => ">=",
6004        BinOp::And => "AND",
6005        BinOp::Or => "OR",
6006        BinOp::Concat => "||",
6007    }
6008}
6009
6010fn build_output_columns(select_cols: &[SelectColumn], columns: &[ColumnDef]) -> Vec<ColumnDef> {
6011    let mut out = Vec::new();
6012    for (i, col) in select_cols.iter().enumerate() {
6013        let (name, data_type) = match col {
6014            SelectColumn::AllColumns => (format!("col{i}"), DataType::Null),
6015            SelectColumn::Expr {
6016                alias: Some(a),
6017                expr,
6018            } => (a.clone(), infer_expr_type(expr, columns)),
6019            SelectColumn::Expr { expr, .. } => {
6020                (expr_display_name(expr), infer_expr_type(expr, columns))
6021            }
6022        };
6023        out.push(ColumnDef {
6024            name,
6025            data_type,
6026            nullable: true,
6027            position: i as u16,
6028        });
6029    }
6030    out
6031}
6032
6033fn infer_expr_type(expr: &Expr, columns: &[ColumnDef]) -> DataType {
6034    match expr {
6035        Expr::Column(name) => columns
6036            .iter()
6037            .find(|c| c.name == *name)
6038            .map(|c| c.data_type)
6039            .unwrap_or(DataType::Null),
6040        Expr::QualifiedColumn { table, column } => {
6041            let qualified = format!("{table}.{column}");
6042            columns
6043                .iter()
6044                .find(|c| c.name == qualified)
6045                .map(|c| c.data_type)
6046                .unwrap_or(DataType::Null)
6047        }
6048        Expr::Literal(v) => v.data_type(),
6049        Expr::CountStar => DataType::Integer,
6050        Expr::Function { name, .. } => match name.to_ascii_uppercase().as_str() {
6051            "COUNT" => DataType::Integer,
6052            "AVG" => DataType::Real,
6053            "SUM" | "MIN" | "MAX" => DataType::Null,
6054            _ => DataType::Null,
6055        },
6056        _ => DataType::Null,
6057    }
6058}