Skip to main content

hematite/sql/
connection.rs

1//! SQL connection boundary.
2//!
3//! A connection owns a catalog instance plus statement-level transaction behavior.
4//!
5//! ```text
6//! SQL text / prepared statement
7//!            |
8//!            v
9//!         parser
10//!            |
11//!            v
12//!    planner + executor
13//!            |
14//!            v
15//!         catalog
16//!            |
17//!            v
18//!      btree + pager
19//! ```
20//!
21//! This is where autocommit, explicit transactions, journal mode changes, and user-facing SQL
22//! errors are coordinated. The connection should not need to understand row encoding or page
23//! structure; it only sequences higher-level components.
24
25use crate::error::{HematiteError, Result};
26use crate::parser::ast::{
27    ColumnDefinition, Condition, CreateStatement, CreateViewStatement, Expression, InsertSource,
28    InsertStatement, SelectIntoStatement, SelectStatement, Statement, TableReference, TriggerEvent,
29    WhereClause,
30};
31use crate::parser::{Lexer, Parser, SqlTypeName};
32use crate::query::lowering::raise_literal_value;
33use crate::query::metadata as query_metadata;
34use crate::query::validation::{projected_column_names, source_column_names, validate_statement};
35use crate::query::{
36    Catalog, CatalogEngine, ExecutionContext, JournalMode, MutationEvent, QueryCatalogSnapshot,
37    QueryExecutor, QueryPlanner, QueryResult, Schema, Value,
38};
39use crate::sql::result::ExecutedStatement;
40use crate::sql::script::{split_script_tokens, ScriptIter};
41use std::collections::{HashMap, HashSet};
42use std::sync::{Arc, Mutex, MutexGuard};
43
/// State for an explicit transaction on a [`Connection`].
///
/// `Connection::transaction` holds this while such a transaction is active;
/// `ImplicitMutation::begin` skips opening a per-statement transaction in that case.
#[derive(Debug, Clone)]
struct ConnectionTransaction {
    /// Catalog snapshot tied to this transaction.
    /// NOTE(review): presumably captured at BEGIN and restored on ROLLBACK — the code that
    /// creates and consumes it is outside this view; confirm.
    snapshot: QueryCatalogSnapshot,
    /// Savepoints created inside this transaction.
    /// NOTE(review): presumably kept in creation order (innermost last) — confirm.
    savepoints: Vec<SavepointState>,
}
49
/// A named savepoint inside a `ConnectionTransaction`, paired with a catalog snapshot.
///
/// NOTE(review): `snapshot` is presumably the catalog state when the savepoint was created,
/// restored by `ROLLBACK TO` — that handling is outside this view; confirm.
#[derive(Debug, Clone)]
struct SavepointState {
    /// Savepoint name.
    /// NOTE(review): the name-matching rules (e.g. case sensitivity) are not visible here.
    name: String,
    /// Catalog snapshot associated with this savepoint.
    snapshot: QueryCatalogSnapshot,
}
55
/// Per-statement autocommit guard.
///
/// `snapshot` is `Some` only when `ImplicitMutation::begin` opened a catalog transaction
/// itself, i.e. no explicit transaction was active on the connection. In that case
/// `commit` or `rollback` finalizes that transaction. When it is `None`, both are no-ops and
/// the enclosing explicit transaction is left untouched.
///
/// NOTE(review): no `Drop` impl is visible here, so a guard dropped without calling
/// `commit`/`rollback` would leave the catalog transaction open — confirm callers always
/// finish it.
#[derive(Debug)]
struct ImplicitMutation {
    /// Catalog state captured just before the implicit transaction began. It is restored if
    /// the statement is rolled back or its commit fails.
    snapshot: Option<QueryCatalogSnapshot>,
}
60
61impl ImplicitMutation {
62    fn begin(connection: &mut Connection) -> Result<Self> {
63        if connection.transaction.is_some() {
64            return Ok(Self { snapshot: None });
65        }
66
67        let mut catalog_guard = connection.lock_catalog()?;
68        let snapshot = catalog_guard.snapshot()?;
69        catalog_guard.begin_transaction()?;
70        Ok(Self {
71            snapshot: Some(snapshot),
72        })
73    }
74
75    fn rollback(mut self, connection: &mut Connection) -> Result<()> {
76        if let Some(snapshot) = self.snapshot.take() {
77            let mut catalog_guard = connection.lock_catalog()?;
78            let _ = catalog_guard.rollback_transaction();
79            catalog_guard.restore_snapshot(snapshot)?;
80        }
81        Ok(())
82    }
83
84    fn commit(mut self, connection: &mut Connection) -> Result<()> {
85        let Some(snapshot) = self.snapshot.take() else {
86            return Ok(());
87        };
88
89        let mut catalog_guard = connection.lock_catalog()?;
90        match catalog_guard.commit_transaction() {
91            Ok(()) => Ok(()),
92            Err(err) => {
93                let _ = catalog_guard.rollback_transaction();
94                catalog_guard.restore_snapshot(snapshot)?;
95                Err(err)
96            }
97        }
98    }
99}
100
/// A SQL connection: a handle to the catalog plus this connection's transaction state.
///
/// Construct with [`Connection::new`] (file-backed, via `Catalog::open_or_create`) or
/// [`Connection::new_in_memory`].
#[derive(Debug)]
pub struct Connection {
    /// Catalog behind a mutex. Private helpers acquire it through `lock_catalog`, which maps a
    /// poisoned mutex to `HematiteError::InternalError`.
    catalog: Arc<Mutex<Catalog>>,
    /// `Some` while an explicit transaction is open. `ImplicitMutation::begin` checks this and
    /// does not open an autocommit transaction when it is set.
    transaction: Option<ConnectionTransaction>,
    /// Initialized to 0 by both constructors.
    /// NOTE(review): presumably the current trigger-nesting depth — its uses are outside this
    /// view; confirm.
    trigger_depth: usize,
}
107
108impl Connection {
/// Base name for the synthetic column generated for `SELECT INTO` target tables.
///
/// `select_into_synthetic_pk_name` starts from this name and appends `_2`, `_3`, ... on
/// collision. `select_into_column_names` never assigns this exact name (ASCII
/// case-insensitively) to a projected column.
/// NOTE(review): presumably used as the created table's primary key — confirm at the call site.
const SELECT_INTO_ROWID_COLUMN: &'static str = "__hematite_select_into_rowid";
110
111    fn empty_result() -> QueryResult {
112        QueryResult {
113            affected_rows: 0,
114            columns: Vec::new(),
115            rows: Vec::new(),
116        }
117    }
118
119    fn mutation_result(affected_rows: usize) -> QueryResult {
120        QueryResult {
121            affected_rows,
122            columns: Vec::new(),
123            rows: Vec::new(),
124        }
125    }
126
127    fn select_into_synthetic_pk_name(column_names: &[String]) -> String {
128        let mut candidate = Self::SELECT_INTO_ROWID_COLUMN.to_string();
129        let used = column_names
130            .iter()
131            .map(|name| name.to_ascii_lowercase())
132            .collect::<HashSet<_>>();
133        let mut suffix = 2usize;
134        while used.contains(&candidate.to_ascii_lowercase()) {
135            candidate = format!("{}_{}", Self::SELECT_INTO_ROWID_COLUMN, suffix);
136            suffix += 1;
137        }
138        candidate
139    }
140
141    fn select_into_column_names(result: &QueryResult) -> Vec<String> {
142        let mut used = HashSet::new();
143        let mut names = Vec::with_capacity(result.columns.len());
144        for (index, name) in result.columns.iter().enumerate() {
145            let mut candidate = if name.trim().is_empty() {
146                format!("column{}", index + 1)
147            } else {
148                name.clone()
149            };
150            let base = candidate.clone();
151            let mut suffix = 2usize;
152            while used.contains(&candidate.to_ascii_lowercase())
153                || candidate.eq_ignore_ascii_case(Self::SELECT_INTO_ROWID_COLUMN)
154            {
155                candidate = format!("{base}_{suffix}");
156                suffix += 1;
157            }
158            used.insert(candidate.to_ascii_lowercase());
159            names.push(candidate);
160        }
161        names
162    }
163
164    fn infer_select_into_type(
165        column_name: &str,
166        values: &[Vec<Value>],
167        index: usize,
168    ) -> Result<SqlTypeName> {
169        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
170        enum NumericKind {
171            Int,
172            Int64,
173            Int128,
174            UInt,
175            UInt64,
176            UInt128,
177            Float32,
178            Float,
179            Decimal,
180        }
181
182        #[derive(Debug, Clone)]
183        enum InferredKind {
184            Numeric(NumericKind),
185            String { saw_enum: bool, values: Vec<String> },
186            Boolean,
187            Blob,
188            Date,
189            Time,
190            DateTime,
191            TimeWithTimeZone,
192        }
193
194        impl InferredKind {
195            fn absorb(self, value: &Value, column_name: &str) -> Result<Self> {
196                use InferredKind::*;
197                use NumericKind::*;
198                match (self, value) {
199                    (kind, Value::Null) => Ok(kind),
200                    (_, Value::IntervalYearMonth(_)) | (_, Value::IntervalDaySecond(_)) => {
201                        Err(HematiteError::ParseError(format!(
202                            "SELECT INTO cannot infer a stored column type for interval-valued column '{}'",
203                            column_name
204                        )))
205                    }
206                    (Numeric(Int), Value::Integer(_)) => Ok(Numeric(Int)),
207                    (Numeric(Int), Value::BigInt(_))
208                    | (Numeric(Int64), Value::Integer(_))
209                    | (Numeric(Int64), Value::BigInt(_)) => Ok(Numeric(Int64)),
210                    (Numeric(Int), Value::Int128(_))
211                    | (Numeric(Int64), Value::Int128(_))
212                    | (Numeric(Int128), Value::Integer(_))
213                    | (Numeric(Int128), Value::BigInt(_))
214                    | (Numeric(Int128), Value::Int128(_)) => Ok(Numeric(Int128)),
215                    (Numeric(UInt), Value::UInteger(_)) => Ok(Numeric(UInt)),
216                    (Numeric(UInt), Value::UBigInt(_))
217                    | (Numeric(UInt64), Value::UInteger(_))
218                    | (Numeric(UInt64), Value::UBigInt(_)) => Ok(Numeric(UInt64)),
219                    (Numeric(UInt), Value::UInt128(_))
220                    | (Numeric(UInt64), Value::UInt128(_))
221                    | (Numeric(UInt128), Value::UInteger(_))
222                    | (Numeric(UInt128), Value::UBigInt(_))
223                    | (Numeric(UInt128), Value::UInt128(_)) => Ok(Numeric(UInt128)),
224                    (Numeric(Int), Value::UInteger(_))
225                    | (Numeric(Int64), Value::UInteger(_))
226                    | (Numeric(Int128), Value::UInteger(_))
227                    | (Numeric(UInt), Value::Integer(_))
228                    | (Numeric(UInt), Value::BigInt(_))
229                    | (Numeric(UInt), Value::Int128(_))
230                    | (Numeric(UInt64), Value::Integer(_))
231                    | (Numeric(UInt64), Value::BigInt(_))
232                    | (Numeric(UInt64), Value::Int128(_))
233                    | (Numeric(UInt128), Value::Integer(_))
234                    | (Numeric(UInt128), Value::BigInt(_))
235                    | (Numeric(UInt128), Value::Int128(_))
236                    | (Numeric(Int64), Value::UBigInt(_))
237                    | (Numeric(Int128), Value::UBigInt(_))
238                    | (Numeric(Int128), Value::UInt128(_))
239                    => Ok(Numeric(Decimal)),
240                    (Numeric(Int), Value::Float32(_))
241                    | (Numeric(Int64), Value::Float32(_))
242                    | (Numeric(Int128), Value::Float32(_))
243                    | (Numeric(UInt), Value::Float32(_))
244                    | (Numeric(UInt64), Value::Float32(_))
245                    | (Numeric(UInt128), Value::Float32(_))
246                    | (Numeric(Float32), Value::Integer(_))
247                    | (Numeric(Float32), Value::BigInt(_))
248                    | (Numeric(Float32), Value::Int128(_))
249                    | (Numeric(Float32), Value::UInteger(_))
250                    | (Numeric(Float32), Value::UBigInt(_))
251                    | (Numeric(Float32), Value::UInt128(_))
252                    | (Numeric(Float32), Value::Float32(_)) => Ok(Numeric(Float32)),
253                    (Numeric(Int), Value::Float(_))
254                    | (Numeric(Int64), Value::Float(_))
255                    | (Numeric(Int128), Value::Float(_))
256                    | (Numeric(UInt), Value::Float(_))
257                    | (Numeric(UInt64), Value::Float(_))
258                    | (Numeric(UInt128), Value::Float(_))
259                    | (Numeric(Float32), Value::Float(_))
260                    | (Numeric(Float), Value::Integer(_))
261                    | (Numeric(Float), Value::BigInt(_))
262                    | (Numeric(Float), Value::Int128(_))
263                    | (Numeric(Float), Value::UInteger(_))
264                    | (Numeric(Float), Value::UBigInt(_))
265                    | (Numeric(Float), Value::UInt128(_))
266                    | (Numeric(Float), Value::Float32(_))
267                    | (Numeric(Float), Value::Float(_)) => Ok(Numeric(Float)),
268                    (Numeric(Int), Value::Decimal(_))
269                    | (Numeric(Int64), Value::Decimal(_))
270                    | (Numeric(Int128), Value::Decimal(_))
271                    | (Numeric(UInt), Value::Decimal(_))
272                    | (Numeric(UInt64), Value::Decimal(_))
273                    | (Numeric(UInt128), Value::Decimal(_))
274                    | (Numeric(Float32), Value::Decimal(_))
275                    | (Numeric(Float), Value::Decimal(_))
276                    | (Numeric(Decimal), Value::Integer(_))
277                    | (Numeric(Decimal), Value::BigInt(_))
278                    | (Numeric(Decimal), Value::Int128(_))
279                    | (Numeric(Decimal), Value::UInteger(_))
280                    | (Numeric(Decimal), Value::UBigInt(_))
281                    | (Numeric(Decimal), Value::UInt128(_))
282                    | (Numeric(Decimal), Value::Float32(_))
283                    | (Numeric(Decimal), Value::Float(_))
284                    | (Numeric(Decimal), Value::Decimal(_)) => Ok(Numeric(Decimal)),
285                    (
286                        String {
287                            saw_enum,
288                            mut values,
289                        },
290                        Value::Text(text),
291                    ) => {
292                        if !values.iter().any(|candidate| candidate == text) {
293                            values.push(text.clone());
294                        }
295                        Ok(String { saw_enum, values })
296                    }
297                    (
298                        String {
299                            saw_enum: _,
300                            mut values,
301                        },
302                        Value::Enum(text),
303                    ) => {
304                        if !values.iter().any(|candidate| candidate == text) {
305                            values.push(text.clone());
306                        }
307                        Ok(String {
308                            saw_enum: true,
309                            values,
310                        })
311                    }
312                    (Blob, Value::Blob(_)) => Ok(Blob),
313                    (Blob, Value::Text(_)) => Ok(Blob),
314                    (Date, Value::Date(_)) => Ok(Date),
315                    (Time, Value::Time(_)) => Ok(Time),
316                    (DateTime, Value::DateTime(_)) => Ok(DateTime),
317                    (TimeWithTimeZone, Value::TimeWithTimeZone(_)) => Ok(TimeWithTimeZone),
318                    (Boolean, Value::Boolean(_)) => Ok(Boolean),
319                    (left, right) => Err(HematiteError::ParseError(format!(
320                        "SELECT INTO cannot infer a stable column type for '{}': {:?} cannot be combined with {:?}",
321                        column_name, left, right
322                    ))),
323                }
324            }
325
326            fn from_value(value: &Value, column_name: &str) -> Result<Option<Self>> {
327                use InferredKind::*;
328                use NumericKind::*;
329                let inferred = match value {
330                    Value::Null => return Ok(None),
331                    Value::Integer(_) => Numeric(Int),
332                    Value::BigInt(_) => Numeric(Int64),
333                    Value::Int128(_) => Numeric(Int128),
334                    Value::UInteger(_) => Numeric(UInt),
335                    Value::UBigInt(_) => Numeric(UInt64),
336                    Value::UInt128(_) => Numeric(UInt128),
337                    Value::Float32(_) => Numeric(Float32),
338                    Value::Float(_) => Numeric(Float),
339                    Value::Decimal(_) => Numeric(Decimal),
340                    Value::Text(text) => String {
341                        saw_enum: false,
342                        values: vec![text.clone()],
343                    },
344                    Value::Enum(text) => String {
345                        saw_enum: true,
346                        values: vec![text.clone()],
347                    },
348                    Value::Boolean(_) => Boolean,
349                    Value::Blob(_) => Blob,
350                    Value::Date(_) => Date,
351                    Value::Time(_) => Time,
352                    Value::DateTime(_) => DateTime,
353                    Value::TimeWithTimeZone(_) => TimeWithTimeZone,
354                    Value::IntervalYearMonth(_) | Value::IntervalDaySecond(_) => {
355                        return Err(HematiteError::ParseError(format!(
356                            "SELECT INTO cannot infer a stored column type for interval-valued column '{}'",
357                            column_name
358                        )))
359                    }
360                };
361                Ok(Some(inferred))
362            }
363
364            fn into_sql_type(self) -> SqlTypeName {
365                match self {
366                    InferredKind::Numeric(NumericKind::Int) => SqlTypeName::Int,
367                    InferredKind::Numeric(NumericKind::Int64) => SqlTypeName::Int64,
368                    InferredKind::Numeric(NumericKind::Int128) => SqlTypeName::Int128,
369                    InferredKind::Numeric(NumericKind::UInt) => SqlTypeName::UInt,
370                    InferredKind::Numeric(NumericKind::UInt64) => SqlTypeName::UInt64,
371                    InferredKind::Numeric(NumericKind::UInt128) => SqlTypeName::UInt128,
372                    InferredKind::Numeric(NumericKind::Float32) => SqlTypeName::Float32,
373                    InferredKind::Numeric(NumericKind::Float) => SqlTypeName::Float,
374                    InferredKind::Numeric(NumericKind::Decimal) => SqlTypeName::Decimal {
375                        precision: None,
376                        scale: None,
377                    },
378                    InferredKind::String {
379                        saw_enum: true,
380                        values,
381                    } => SqlTypeName::Enum(values),
382                    InferredKind::String { .. } => SqlTypeName::Text,
383                    InferredKind::Boolean => SqlTypeName::Boolean,
384                    InferredKind::Blob => SqlTypeName::Blob,
385                    InferredKind::Date => SqlTypeName::Date,
386                    InferredKind::Time => SqlTypeName::Time,
387                    InferredKind::DateTime => SqlTypeName::DateTime,
388                    InferredKind::TimeWithTimeZone => SqlTypeName::TimeWithTimeZone,
389                }
390            }
391        }
392
393        let mut inferred = None;
394        for row in values {
395            let Some(value) = row.get(index) else {
396                return Err(HematiteError::InternalError(format!(
397                    "SELECT INTO result row is missing projected column {}",
398                    index
399                )));
400            };
401
402            inferred = match (inferred, InferredKind::from_value(value, column_name)?) {
403                (None, None) => None,
404                (None, Some(kind)) => Some(kind),
405                (Some(kind), None) => Some(kind),
406                (Some(kind), Some(_)) => Some(kind.absorb(value, column_name)?),
407            };
408        }
409
410        Ok(inferred
411            .map(InferredKind::into_sql_type)
412            .unwrap_or(SqlTypeName::Text))
413    }
414
415    fn infer_select_into_columns(result: &QueryResult) -> Result<Vec<ColumnDefinition>> {
416        let column_names = Self::select_into_column_names(result);
417        column_names
418            .iter()
419            .enumerate()
420            .map(|(index, name)| {
421                Ok(ColumnDefinition {
422                    name: name.clone(),
423                    data_type: Self::infer_select_into_type(name, &result.rows, index)?,
424                    character_set: None,
425                    collation: None,
426                    nullable: true,
427                    primary_key: false,
428                    auto_increment: false,
429                    unique: false,
430                    default_value: None,
431                    check_constraint: None,
432                    references: None,
433                })
434            })
435            .collect()
436    }
437
438    fn lock_catalog(&self) -> Result<MutexGuard<'_, Catalog>> {
439        self.catalog.lock().map_err(|_| {
440            HematiteError::InternalError("SQL connection catalog mutex is poisoned".to_string())
441        })
442    }
443
444    pub fn new(database_path: &str) -> Result<Self> {
445        let catalog = Catalog::open_or_create(database_path)?;
446        Ok(Self {
447            catalog: Arc::new(Mutex::new(catalog)),
448            transaction: None,
449            trigger_depth: 0,
450        })
451    }
452
453    pub fn new_in_memory() -> Result<Self> {
454        let catalog = Catalog::open_in_memory()?;
455        Ok(Self {
456            catalog: Arc::new(Mutex::new(catalog)),
457            transaction: None,
458            trigger_depth: 0,
459        })
460    }
461
462    fn parse_statement(sql: &str) -> Result<crate::parser::ast::Statement> {
463        let mut lexer = Lexer::new(sql.to_string());
464        lexer.tokenize()?;
465
466        let mut parser = Parser::new(lexer.get_tokens().to_vec());
467        parser.parse()
468    }
469
470    fn parse_select_sql(sql: &str) -> Result<SelectStatement> {
471        match Self::parse_statement(&format!("{sql};"))? {
472            Statement::Select(select) => Ok(select),
473            other => Err(HematiteError::ParseError(format!(
474                "Expected stored view query to be SELECT, found {:?}",
475                other
476            ))),
477        }
478    }
479
480    fn expand_views_in_statement(statement: Statement, schema: &Schema) -> Result<Statement> {
481        match statement {
482            Statement::Explain(explain) => {
483                Ok(Statement::Explain(crate::parser::ast::ExplainStatement {
484                    statement: Box::new(Self::expand_views_in_statement(
485                        *explain.statement,
486                        schema,
487                    )?),
488                }))
489            }
490            Statement::Select(select) => Ok(Statement::Select(Self::expand_views_in_select(
491                select, schema,
492            )?)),
493            Statement::Insert(mut insert) => {
494                if let InsertSource::Select(select) = insert.source {
495                    insert.source = InsertSource::Select(Box::new(Self::expand_views_in_select(
496                        *select, schema,
497                    )?));
498                }
499                Ok(Statement::Insert(insert))
500            }
501            Statement::CreateView(mut create_view) => {
502                create_view.query = Self::expand_views_in_select(create_view.query, schema)?;
503                Ok(Statement::CreateView(create_view))
504            }
505            other => Ok(other),
506        }
507    }
508
509    fn expand_views_in_select(
510        mut select: SelectStatement,
511        schema: &Schema,
512    ) -> Result<SelectStatement> {
513        for cte in &mut select.with_clause {
514            cte.query = Box::new(Self::expand_views_in_select((*cte.query).clone(), schema)?);
515        }
516        let original_from = select.from.clone();
517        let select_context = select.clone();
518        select.from =
519            Self::expand_views_in_table_reference(original_from, &select_context, schema)?;
520        if let Some(where_clause) = &mut select.where_clause {
521            Self::expand_views_in_where_clause(where_clause, schema)?;
522        }
523        for expr in &mut select.group_by {
524            Self::expand_views_in_expression(expr, schema)?;
525        }
526        if let Some(having_clause) = &mut select.having_clause {
527            Self::expand_views_in_where_clause(having_clause, schema)?;
528        }
529        if let Some(set_operation) = &mut select.set_operation {
530            set_operation.right = Box::new(Self::expand_views_in_select(
531                (*set_operation.right).clone(),
532                schema,
533            )?);
534        }
535        for item in &mut select.columns {
536            if let crate::parser::ast::SelectItem::Expression(expr) = item {
537                Self::expand_views_in_expression(expr, schema)?;
538            }
539        }
540        Ok(select)
541    }
542
543    fn expand_views_in_table_reference(
544        from: TableReference,
545        select: &SelectStatement,
546        schema: &Schema,
547    ) -> Result<TableReference> {
548        match from {
549            TableReference::Table(table_name, alias) => {
550                if select.lookup_cte(&table_name).is_some()
551                    || schema.get_table_by_name(&table_name).is_some()
552                {
553                    Ok(TableReference::Table(table_name, alias))
554                } else if let Some(view) = schema.view(&table_name) {
555                    let subquery = Self::expand_views_in_select(
556                        Self::parse_select_sql(&view.query_sql)?,
557                        schema,
558                    )?;
559                    Ok(TableReference::Derived {
560                        subquery: Box::new(subquery),
561                        alias: alias.unwrap_or(table_name),
562                    })
563                } else {
564                    Ok(TableReference::Table(table_name, alias))
565                }
566            }
567            TableReference::Derived { subquery, alias } => Ok(TableReference::Derived {
568                subquery: Box::new(Self::expand_views_in_select(*subquery, schema)?),
569                alias,
570            }),
571            TableReference::CrossJoin(left, right) => Ok(TableReference::CrossJoin(
572                Box::new(Self::expand_views_in_table_reference(
573                    *left, select, schema,
574                )?),
575                Box::new(Self::expand_views_in_table_reference(
576                    *right, select, schema,
577                )?),
578            )),
579            TableReference::InnerJoin {
580                left,
581                right,
582                mut on,
583            } => {
584                Self::expand_views_in_condition(&mut on, schema)?;
585                Ok(TableReference::InnerJoin {
586                    left: Box::new(Self::expand_views_in_table_reference(
587                        *left, select, schema,
588                    )?),
589                    right: Box::new(Self::expand_views_in_table_reference(
590                        *right, select, schema,
591                    )?),
592                    on,
593                })
594            }
595            TableReference::LeftJoin {
596                left,
597                right,
598                mut on,
599            } => {
600                Self::expand_views_in_condition(&mut on, schema)?;
601                Ok(TableReference::LeftJoin {
602                    left: Box::new(Self::expand_views_in_table_reference(
603                        *left, select, schema,
604                    )?),
605                    right: Box::new(Self::expand_views_in_table_reference(
606                        *right, select, schema,
607                    )?),
608                    on,
609                })
610            }
611            TableReference::RightJoin {
612                left,
613                right,
614                mut on,
615            } => {
616                Self::expand_views_in_condition(&mut on, schema)?;
617                Ok(TableReference::RightJoin {
618                    left: Box::new(Self::expand_views_in_table_reference(
619                        *left, select, schema,
620                    )?),
621                    right: Box::new(Self::expand_views_in_table_reference(
622                        *right, select, schema,
623                    )?),
624                    on,
625                })
626            }
627            TableReference::FullOuterJoin {
628                left,
629                right,
630                mut on,
631            } => {
632                Self::expand_views_in_condition(&mut on, schema)?;
633                Ok(TableReference::FullOuterJoin {
634                    left: Box::new(Self::expand_views_in_table_reference(
635                        *left, select, schema,
636                    )?),
637                    right: Box::new(Self::expand_views_in_table_reference(
638                        *right, select, schema,
639                    )?),
640                    on,
641                })
642            }
643        }
644    }
645
646    fn expand_views_in_where_clause(where_clause: &mut WhereClause, schema: &Schema) -> Result<()> {
647        for condition in &mut where_clause.conditions {
648            Self::expand_views_in_condition(condition, schema)?;
649        }
650        Ok(())
651    }
652
653    fn expand_views_in_condition(condition: &mut Condition, schema: &Schema) -> Result<()> {
654        let mut expand = |subquery: &mut SelectStatement| -> Result<()> {
655            *subquery = Self::expand_views_in_select(subquery.clone(), schema)?;
656            Ok(())
657        };
658        Self::rewrite_nested_subqueries_in_condition(condition, &mut expand)
659    }
660
661    fn expand_views_in_expression(expr: &mut Expression, schema: &Schema) -> Result<()> {
662        let mut expand = |subquery: &mut SelectStatement| -> Result<()> {
663            *subquery = Self::expand_views_in_select(subquery.clone(), schema)?;
664            Ok(())
665        };
666        Self::rewrite_nested_subqueries_in_expression(expr, &mut expand)
667    }
668
669    fn normalize_statement(statement: Statement, schema: &Schema) -> Result<Statement> {
670        let mut statement = Self::expand_views_in_statement(statement, schema)?;
671        Self::rewrite_select_aliases_in_statement(&mut statement, schema)?;
672        Ok(statement)
673    }
674
675    fn rewrite_select_aliases_in_statement(
676        statement: &mut Statement,
677        schema: &Schema,
678    ) -> Result<()> {
679        match statement {
680            Statement::Explain(explain) => {
681                Self::rewrite_select_aliases_in_statement(&mut explain.statement, schema)
682            }
683            Statement::Select(select) => Self::rewrite_select_aliases_in_select(select, schema),
684            Statement::Insert(insert) => {
685                if let InsertSource::Select(select) = &mut insert.source {
686                    Self::rewrite_select_aliases_in_select(select, schema)?;
687                }
688                Ok(())
689            }
690            Statement::CreateView(create_view) => {
691                Self::rewrite_select_aliases_in_select(&mut create_view.query, schema)
692            }
693            _ => Ok(()),
694        }
695    }
696
697    fn rewrite_select_aliases_in_select(
698        select: &mut SelectStatement,
699        schema: &Schema,
700    ) -> Result<()> {
701        for cte in &mut select.with_clause {
702            if !cte.recursive {
703                Self::rewrite_select_aliases_in_select(&mut cte.query, schema)?;
704            }
705        }
706
707        Self::rewrite_select_aliases_in_table_reference(&mut select.from, schema)?;
708
709        for item in &mut select.columns {
710            match item {
711                crate::parser::ast::SelectItem::Expression(expr) => {
712                    Self::rewrite_nested_select_aliases_in_expression(expr, schema)?;
713                }
714                crate::parser::ast::SelectItem::Window { window, .. } => {
715                    for expr in &mut window.partition_by {
716                        Self::rewrite_nested_select_aliases_in_expression(expr, schema)?;
717                    }
718                }
719                crate::parser::ast::SelectItem::Wildcard
720                | crate::parser::ast::SelectItem::Column(_)
721                | crate::parser::ast::SelectItem::CountAll
722                | crate::parser::ast::SelectItem::Aggregate { .. } => {}
723            }
724        }
725
726        let alias_map = Self::where_alias_map(select);
727        let source_columns = source_column_names(select, schema)?
728            .into_iter()
729            .collect::<HashSet<_>>();
730
731        if let Some(where_clause) = &mut select.where_clause {
732            for condition in &mut where_clause.conditions {
733                Self::rewrite_where_aliases_in_condition(
734                    condition,
735                    &alias_map,
736                    &source_columns,
737                    &mut HashSet::new(),
738                )?;
739            }
740        }
741
742        for expr in &mut select.group_by {
743            Self::rewrite_nested_select_aliases_in_expression(expr, schema)?;
744        }
745
746        if let Some(having_clause) = &mut select.having_clause {
747            for condition in &mut having_clause.conditions {
748                Self::rewrite_nested_select_aliases_in_condition(condition, schema)?;
749            }
750        }
751
752        if let Some(set_operation) = &mut select.set_operation {
753            Self::rewrite_select_aliases_in_select(&mut set_operation.right, schema)?;
754        }
755
756        Ok(())
757    }
758
759    fn rewrite_select_aliases_in_table_reference(
760        from: &mut TableReference,
761        schema: &Schema,
762    ) -> Result<()> {
763        match from {
764            TableReference::Derived { subquery, .. } => {
765                Self::rewrite_select_aliases_in_select(subquery, schema)
766            }
767            TableReference::CrossJoin(left, right) => {
768                Self::rewrite_select_aliases_in_table_reference(left, schema)?;
769                Self::rewrite_select_aliases_in_table_reference(right, schema)
770            }
771            TableReference::InnerJoin { left, right, on }
772            | TableReference::LeftJoin { left, right, on }
773            | TableReference::RightJoin { left, right, on }
774            | TableReference::FullOuterJoin { left, right, on } => {
775                Self::rewrite_select_aliases_in_table_reference(left, schema)?;
776                Self::rewrite_select_aliases_in_table_reference(right, schema)?;
777                Self::rewrite_nested_select_aliases_in_condition(on, schema)
778            }
779            TableReference::Table(_, _) => Ok(()),
780        }
781    }
782
783    fn rewrite_nested_select_aliases_in_condition(
784        condition: &mut Condition,
785        schema: &Schema,
786    ) -> Result<()> {
787        let mut rewrite = |subquery: &mut SelectStatement| {
788            Self::rewrite_select_aliases_in_select(subquery, schema)
789        };
790        Self::rewrite_nested_subqueries_in_condition(condition, &mut rewrite)
791    }
792
793    fn rewrite_nested_select_aliases_in_expression(
794        expr: &mut Expression,
795        schema: &Schema,
796    ) -> Result<()> {
797        let mut rewrite = |subquery: &mut SelectStatement| {
798            Self::rewrite_select_aliases_in_select(subquery, schema)
799        };
800        Self::rewrite_nested_subqueries_in_expression(expr, &mut rewrite)
801    }
802
803    fn rewrite_nested_subqueries_in_condition<F>(
804        condition: &mut Condition,
805        on_subquery: &mut F,
806    ) -> Result<()>
807    where
808        F: FnMut(&mut SelectStatement) -> Result<()>,
809    {
810        match condition {
811            Condition::Comparison { left, right, .. } => {
812                Self::rewrite_nested_subqueries_in_expression(left, on_subquery)?;
813                Self::rewrite_nested_subqueries_in_expression(right, on_subquery)?;
814            }
815            Condition::InList { expr, values, .. } => {
816                Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)?;
817                for value in values {
818                    Self::rewrite_nested_subqueries_in_expression(value, on_subquery)?;
819                }
820            }
821            Condition::InSubquery { expr, subquery, .. } => {
822                Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)?;
823                on_subquery(subquery)?;
824            }
825            Condition::Between {
826                expr, lower, upper, ..
827            } => {
828                Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)?;
829                Self::rewrite_nested_subqueries_in_expression(lower, on_subquery)?;
830                Self::rewrite_nested_subqueries_in_expression(upper, on_subquery)?;
831            }
832            Condition::Like { expr, pattern, .. } => {
833                Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)?;
834                Self::rewrite_nested_subqueries_in_expression(pattern, on_subquery)?;
835            }
836            Condition::Exists { subquery, .. } => on_subquery(subquery)?,
837            Condition::NullCheck { expr, .. } => {
838                Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)?;
839            }
840            Condition::Not(inner) => {
841                Self::rewrite_nested_subqueries_in_condition(inner, on_subquery)?
842            }
843            Condition::Logical { left, right, .. } => {
844                Self::rewrite_nested_subqueries_in_condition(left, on_subquery)?;
845                Self::rewrite_nested_subqueries_in_condition(right, on_subquery)?;
846            }
847        }
848        Ok(())
849    }
850
851    fn rewrite_nested_subqueries_in_expression<F>(
852        expr: &mut Expression,
853        on_subquery: &mut F,
854    ) -> Result<()>
855    where
856        F: FnMut(&mut SelectStatement) -> Result<()>,
857    {
858        match expr {
859            Expression::ScalarSubquery(subquery) => on_subquery(subquery),
860            Expression::Cast { expr, .. }
861            | Expression::UnaryMinus(expr)
862            | Expression::UnaryNot(expr)
863            | Expression::NullCheck { expr, .. } => {
864                Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)
865            }
866            Expression::Case {
867                branches,
868                else_expr,
869            } => {
870                for branch in branches {
871                    Self::rewrite_nested_subqueries_in_expression(
872                        &mut branch.condition,
873                        on_subquery,
874                    )?;
875                    Self::rewrite_nested_subqueries_in_expression(&mut branch.result, on_subquery)?;
876                }
877                if let Some(else_expr) = else_expr {
878                    Self::rewrite_nested_subqueries_in_expression(else_expr, on_subquery)?;
879                }
880                Ok(())
881            }
882            Expression::ScalarFunctionCall { args, .. } => {
883                for arg in args {
884                    Self::rewrite_nested_subqueries_in_expression(arg, on_subquery)?;
885                }
886                Ok(())
887            }
888            Expression::Binary { left, right, .. }
889            | Expression::Comparison { left, right, .. }
890            | Expression::Logical { left, right, .. } => {
891                Self::rewrite_nested_subqueries_in_expression(left, on_subquery)?;
892                Self::rewrite_nested_subqueries_in_expression(right, on_subquery)
893            }
894            Expression::InList { expr, values, .. } => {
895                Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)?;
896                for value in values {
897                    Self::rewrite_nested_subqueries_in_expression(value, on_subquery)?;
898                }
899                Ok(())
900            }
901            Expression::InSubquery { expr, subquery, .. } => {
902                Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)?;
903                on_subquery(subquery)
904            }
905            Expression::Between {
906                expr, lower, upper, ..
907            } => {
908                Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)?;
909                Self::rewrite_nested_subqueries_in_expression(lower, on_subquery)?;
910                Self::rewrite_nested_subqueries_in_expression(upper, on_subquery)
911            }
912            Expression::Like { expr, pattern, .. } => {
913                Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)?;
914                Self::rewrite_nested_subqueries_in_expression(pattern, on_subquery)
915            }
916            Expression::Exists { subquery, .. } => on_subquery(subquery),
917            Expression::AggregateCall { .. }
918            | Expression::Column(_)
919            | Expression::Literal(_)
920            | Expression::IntervalLiteral { .. }
921            | Expression::Parameter(_) => Ok(()),
922        }
923    }
924
925    fn where_alias_map(select: &SelectStatement) -> HashMap<String, Expression> {
926        let mut aliases = HashMap::new();
927        for (index, alias) in select.column_aliases.iter().enumerate() {
928            let Some(alias) = alias.as_ref() else {
929                continue;
930            };
931            let Some(item) = select.columns.get(index) else {
932                continue;
933            };
934
935            let replacement = match item {
936                crate::parser::ast::SelectItem::Column(name) => Expression::Column(name.clone()),
937                crate::parser::ast::SelectItem::Expression(expr) => expr.clone(),
938                _ => continue,
939            };
940            aliases.insert(alias.clone(), replacement);
941        }
942        aliases
943    }
944
945    fn rewrite_where_aliases_in_condition(
946        condition: &mut Condition,
947        aliases: &HashMap<String, Expression>,
948        source_columns: &HashSet<String>,
949        active_aliases: &mut HashSet<String>,
950    ) -> Result<()> {
951        match condition {
952            Condition::Comparison { left, right, .. } => {
953                Self::rewrite_where_aliases_in_expression(
954                    left,
955                    aliases,
956                    source_columns,
957                    active_aliases,
958                )?;
959                Self::rewrite_where_aliases_in_expression(
960                    right,
961                    aliases,
962                    source_columns,
963                    active_aliases,
964                )?;
965            }
966            Condition::InList { expr, values, .. } => {
967                Self::rewrite_where_aliases_in_expression(
968                    expr,
969                    aliases,
970                    source_columns,
971                    active_aliases,
972                )?;
973                for value in values {
974                    Self::rewrite_where_aliases_in_expression(
975                        value,
976                        aliases,
977                        source_columns,
978                        active_aliases,
979                    )?;
980                }
981            }
982            Condition::InSubquery { expr, .. } => {
983                Self::rewrite_where_aliases_in_expression(
984                    expr,
985                    aliases,
986                    source_columns,
987                    active_aliases,
988                )?;
989            }
990            Condition::Between {
991                expr, lower, upper, ..
992            } => {
993                Self::rewrite_where_aliases_in_expression(
994                    expr,
995                    aliases,
996                    source_columns,
997                    active_aliases,
998                )?;
999                Self::rewrite_where_aliases_in_expression(
1000                    lower,
1001                    aliases,
1002                    source_columns,
1003                    active_aliases,
1004                )?;
1005                Self::rewrite_where_aliases_in_expression(
1006                    upper,
1007                    aliases,
1008                    source_columns,
1009                    active_aliases,
1010                )?;
1011            }
1012            Condition::Like { expr, pattern, .. } => {
1013                Self::rewrite_where_aliases_in_expression(
1014                    expr,
1015                    aliases,
1016                    source_columns,
1017                    active_aliases,
1018                )?;
1019                Self::rewrite_where_aliases_in_expression(
1020                    pattern,
1021                    aliases,
1022                    source_columns,
1023                    active_aliases,
1024                )?;
1025            }
1026            Condition::Exists { .. } => {}
1027            Condition::NullCheck { expr, .. } => {
1028                Self::rewrite_where_aliases_in_expression(
1029                    expr,
1030                    aliases,
1031                    source_columns,
1032                    active_aliases,
1033                )?;
1034            }
1035            Condition::Not(inner) => {
1036                Self::rewrite_where_aliases_in_condition(
1037                    inner,
1038                    aliases,
1039                    source_columns,
1040                    active_aliases,
1041                )?;
1042            }
1043            Condition::Logical { left, right, .. } => {
1044                Self::rewrite_where_aliases_in_condition(
1045                    left,
1046                    aliases,
1047                    source_columns,
1048                    active_aliases,
1049                )?;
1050                Self::rewrite_where_aliases_in_condition(
1051                    right,
1052                    aliases,
1053                    source_columns,
1054                    active_aliases,
1055                )?;
1056            }
1057        }
1058        Ok(())
1059    }
1060
1061    fn rewrite_where_aliases_in_expression(
1062        expr: &mut Expression,
1063        aliases: &HashMap<String, Expression>,
1064        source_columns: &HashSet<String>,
1065        active_aliases: &mut HashSet<String>,
1066    ) -> Result<()> {
1067        match expr {
1068            Expression::Column(name) => {
1069                if SelectStatement::split_column_reference(name).0.is_some()
1070                    || source_columns.contains(name)
1071                {
1072                    return Ok(());
1073                }
1074
1075                let Some(replacement) = aliases.get(name).cloned() else {
1076                    return Ok(());
1077                };
1078
1079                if !active_aliases.insert(name.clone()) {
1080                    return Err(HematiteError::ParseError(format!(
1081                        "Select alias '{}' is recursively defined",
1082                        name
1083                    )));
1084                }
1085
1086                let mut replacement = replacement;
1087                Self::rewrite_where_aliases_in_expression(
1088                    &mut replacement,
1089                    aliases,
1090                    source_columns,
1091                    active_aliases,
1092                )?;
1093                active_aliases.remove(name);
1094                *expr = replacement;
1095            }
1096            Expression::Cast { expr, .. }
1097            | Expression::UnaryMinus(expr)
1098            | Expression::UnaryNot(expr)
1099            | Expression::NullCheck { expr, .. } => {
1100                Self::rewrite_where_aliases_in_expression(
1101                    expr,
1102                    aliases,
1103                    source_columns,
1104                    active_aliases,
1105                )?;
1106            }
1107            Expression::Case {
1108                branches,
1109                else_expr,
1110            } => {
1111                for branch in branches {
1112                    Self::rewrite_where_aliases_in_expression(
1113                        &mut branch.condition,
1114                        aliases,
1115                        source_columns,
1116                        active_aliases,
1117                    )?;
1118                    Self::rewrite_where_aliases_in_expression(
1119                        &mut branch.result,
1120                        aliases,
1121                        source_columns,
1122                        active_aliases,
1123                    )?;
1124                }
1125                if let Some(else_expr) = else_expr {
1126                    Self::rewrite_where_aliases_in_expression(
1127                        else_expr,
1128                        aliases,
1129                        source_columns,
1130                        active_aliases,
1131                    )?;
1132                }
1133            }
1134            Expression::ScalarFunctionCall { args, .. } => {
1135                for arg in args {
1136                    Self::rewrite_where_aliases_in_expression(
1137                        arg,
1138                        aliases,
1139                        source_columns,
1140                        active_aliases,
1141                    )?;
1142                }
1143            }
1144            Expression::Binary { left, right, .. }
1145            | Expression::Comparison { left, right, .. }
1146            | Expression::Logical { left, right, .. } => {
1147                Self::rewrite_where_aliases_in_expression(
1148                    left,
1149                    aliases,
1150                    source_columns,
1151                    active_aliases,
1152                )?;
1153                Self::rewrite_where_aliases_in_expression(
1154                    right,
1155                    aliases,
1156                    source_columns,
1157                    active_aliases,
1158                )?;
1159            }
1160            Expression::InList { expr, values, .. } => {
1161                Self::rewrite_where_aliases_in_expression(
1162                    expr,
1163                    aliases,
1164                    source_columns,
1165                    active_aliases,
1166                )?;
1167                for value in values {
1168                    Self::rewrite_where_aliases_in_expression(
1169                        value,
1170                        aliases,
1171                        source_columns,
1172                        active_aliases,
1173                    )?;
1174                }
1175            }
1176            Expression::Between {
1177                expr, lower, upper, ..
1178            } => {
1179                Self::rewrite_where_aliases_in_expression(
1180                    expr,
1181                    aliases,
1182                    source_columns,
1183                    active_aliases,
1184                )?;
1185                Self::rewrite_where_aliases_in_expression(
1186                    lower,
1187                    aliases,
1188                    source_columns,
1189                    active_aliases,
1190                )?;
1191                Self::rewrite_where_aliases_in_expression(
1192                    upper,
1193                    aliases,
1194                    source_columns,
1195                    active_aliases,
1196                )?;
1197            }
1198            Expression::Like { expr, pattern, .. } => {
1199                Self::rewrite_where_aliases_in_expression(
1200                    expr,
1201                    aliases,
1202                    source_columns,
1203                    active_aliases,
1204                )?;
1205                Self::rewrite_where_aliases_in_expression(
1206                    pattern,
1207                    aliases,
1208                    source_columns,
1209                    active_aliases,
1210                )?;
1211            }
1212            Expression::AggregateCall { .. }
1213            | Expression::ScalarSubquery(_)
1214            | Expression::InSubquery { .. }
1215            | Expression::Exists { .. }
1216            | Expression::Literal(_)
1217            | Expression::IntervalLiteral { .. }
1218            | Expression::Parameter(_) => {}
1219        }
1220        Ok(())
1221    }
1222
1223    pub(crate) fn execute_statement(
1224        &mut self,
1225        statement: crate::parser::ast::Statement,
1226    ) -> Result<QueryResult> {
1227        match statement {
1228            crate::parser::ast::Statement::Begin => {
1229                self.begin_active_transaction()?;
1230                return Ok(Self::empty_result());
1231            }
1232            crate::parser::ast::Statement::Commit => {
1233                self.commit_active_transaction()?;
1234                return Ok(Self::empty_result());
1235            }
1236            crate::parser::ast::Statement::Rollback => {
1237                self.rollback_active_transaction()?;
1238                return Ok(Self::empty_result());
1239            }
1240            crate::parser::ast::Statement::Savepoint(name) => {
1241                self.create_savepoint(&name)?;
1242                return Ok(Self::empty_result());
1243            }
1244            crate::parser::ast::Statement::RollbackToSavepoint(name) => {
1245                self.rollback_to_savepoint(&name)?;
1246                return Ok(Self::empty_result());
1247            }
1248            crate::parser::ast::Statement::ReleaseSavepoint(name) => {
1249                self.release_savepoint(&name)?;
1250                return Ok(Self::empty_result());
1251            }
1252            crate::parser::ast::Statement::Explain(explain) => {
1253                return self.execute_explain_statement(*explain.statement);
1254            }
1255            crate::parser::ast::Statement::Describe(describe) => {
1256                return self.execute_describe_statement(&describe.table);
1257            }
1258            crate::parser::ast::Statement::ShowTables => {
1259                return self.execute_show_tables_statement();
1260            }
1261            crate::parser::ast::Statement::ShowViews => {
1262                return self.execute_show_views_statement();
1263            }
1264            crate::parser::ast::Statement::ShowIndexes(table_name) => {
1265                return self.execute_show_indexes_statement(table_name.as_deref());
1266            }
1267            crate::parser::ast::Statement::ShowTriggers(table_name) => {
1268                return self.execute_show_triggers_statement(table_name.as_deref());
1269            }
1270            crate::parser::ast::Statement::ShowCreateTable(table_name) => {
1271                return self.execute_show_create_table_statement(&table_name);
1272            }
1273            crate::parser::ast::Statement::ShowCreateView(view_name) => {
1274                return self.execute_show_create_view_statement(&view_name);
1275            }
1276            crate::parser::ast::Statement::SelectInto(select_into) => {
1277                return self.execute_select_into_statement(select_into);
1278            }
1279            crate::parser::ast::Statement::CreateView(create_view) => {
1280                return self.execute_create_view_statement(create_view);
1281            }
1282            crate::parser::ast::Statement::DropView(drop_view) => {
1283                return self.execute_drop_view_statement(&drop_view.view, drop_view.if_exists);
1284            }
1285            crate::parser::ast::Statement::CreateTrigger(create_trigger) => {
1286                return self.execute_create_trigger_statement(create_trigger);
1287            }
1288            crate::parser::ast::Statement::DropTrigger(drop_trigger) => {
1289                return self
1290                    .execute_drop_trigger_statement(&drop_trigger.trigger, drop_trigger.if_exists);
1291            }
1292            _ => {}
1293        }
1294
1295        if statement.is_read_only() {
1296            return self.execute_read_statement(statement);
1297        }
1298
1299        self.execute_mutating_statement(statement)
1300    }
1301
1302    fn execute_explain_statement(
1303        &mut self,
1304        statement: crate::parser::ast::Statement,
1305    ) -> Result<QueryResult> {
1306        let statement = match statement {
1307            Statement::SelectInto(select_into) => Statement::Select(select_into.query),
1308            other => other,
1309        };
1310        let (schema, table_row_counts) = self.read_planning_state()?;
1311        let statement = Self::expand_views_in_statement(statement, &schema)?;
1312        let planner = QueryPlanner::new(schema).with_table_row_counts(table_row_counts);
1313        let plan = planner.plan(statement)?;
1314        Ok(QueryResult {
1315            affected_rows: 0,
1316            columns: vec!["kind".to_string(), "detail".to_string()],
1317            rows: vec![
1318                vec![
1319                    Value::Text("node".to_string()),
1320                    Value::Text(format!("{:?}", plan.node)),
1321                ],
1322                vec![
1323                    Value::Text("estimated_cost".to_string()),
1324                    Value::Text(format!("{:.2}", plan.estimated_cost)),
1325                ],
1326            ],
1327        })
1328    }
1329
1330    fn execute_describe_statement(&mut self, table_name: &str) -> Result<QueryResult> {
1331        let catalog_guard = self.lock_catalog()?;
1332        query_metadata::describe_table(&catalog_guard, table_name)
1333    }
1334
1335    fn execute_show_tables_statement(&mut self) -> Result<QueryResult> {
1336        let catalog_guard = self.lock_catalog()?;
1337        query_metadata::show_tables(&catalog_guard)
1338    }
1339
1340    fn execute_show_views_statement(&mut self) -> Result<QueryResult> {
1341        let catalog_guard = self.lock_catalog()?;
1342        query_metadata::show_views(&catalog_guard)
1343    }
1344
1345    fn execute_show_indexes_statement(&mut self, table_name: Option<&str>) -> Result<QueryResult> {
1346        let catalog_guard = self.lock_catalog()?;
1347        query_metadata::show_indexes(&catalog_guard, table_name)
1348    }
1349
1350    fn execute_show_triggers_statement(&mut self, table_name: Option<&str>) -> Result<QueryResult> {
1351        let catalog_guard = self.lock_catalog()?;
1352        query_metadata::show_triggers(&catalog_guard, table_name)
1353    }
1354
1355    fn execute_show_create_table_statement(&mut self, table_name: &str) -> Result<QueryResult> {
1356        let catalog_guard = self.lock_catalog()?;
1357        query_metadata::show_create_table(&catalog_guard, table_name)
1358    }
1359
1360    fn execute_show_create_view_statement(&mut self, view_name: &str) -> Result<QueryResult> {
1361        let catalog_guard = self.lock_catalog()?;
1362        query_metadata::show_create_view(&catalog_guard, view_name)
1363    }
1364
1365    fn execute_select_into_statement(
1366        &mut self,
1367        statement: SelectIntoStatement,
1368    ) -> Result<QueryResult> {
1369        let (schema, _) = self.read_planning_state()?;
1370        if schema.get_table_by_name(&statement.table).is_some()
1371            || schema.view(&statement.table).is_some()
1372        {
1373            return Err(HematiteError::ParseError(format!(
1374                "Table '{}' already exists",
1375                statement.table
1376            )));
1377        }
1378
1379        let normalized_query =
1380            match Self::normalize_statement(Statement::Select(statement.query.clone()), &schema)? {
1381                Statement::Select(select) => select,
1382                _ => unreachable!("normalized SELECT INTO query should remain a select"),
1383            };
1384        validate_statement(&Statement::Select(normalized_query), &schema)?;
1385
1386        let query_result =
1387            self.execute_read_statement(Statement::Select(statement.query.clone()))?;
1388        let projected_columns = Self::infer_select_into_columns(&query_result)?;
1389        let insert_columns = projected_columns
1390            .iter()
1391            .map(|column| column.name.clone())
1392            .collect::<Vec<_>>();
1393        let synthetic_pk = Self::select_into_synthetic_pk_name(&insert_columns);
1394
1395        let mut create_columns = Vec::with_capacity(projected_columns.len() + 1);
1396        create_columns.push(ColumnDefinition {
1397            name: synthetic_pk,
1398            data_type: SqlTypeName::Int,
1399            character_set: None,
1400            collation: None,
1401            nullable: false,
1402            primary_key: true,
1403            auto_increment: true,
1404            unique: false,
1405            default_value: None,
1406            check_constraint: None,
1407            references: None,
1408        });
1409        create_columns.extend(projected_columns);
1410
1411        let mut implicit_mutation = Some(ImplicitMutation::begin(self)?);
1412        let result: Result<QueryResult> = (|| {
1413            self.execute_mutating_statement_in_scope(
1414                Statement::Create(CreateStatement {
1415                    table: statement.table.clone(),
1416                    columns: create_columns,
1417                    constraints: Vec::new(),
1418                    if_not_exists: false,
1419                }),
1420                false,
1421            )?;
1422
1423            let insert_result = self.execute_mutating_statement_in_scope(
1424                Statement::Insert(InsertStatement {
1425                    table: statement.table.clone(),
1426                    columns: insert_columns,
1427                    source: InsertSource::Select(Box::new(statement.query)),
1428                    on_duplicate: None,
1429                }),
1430                false,
1431            )?;
1432
1433            Ok(Self::mutation_result(insert_result.affected_rows))
1434        })();
1435
1436        match result {
1437            Ok(result) => {
1438                implicit_mutation
1439                    .take()
1440                    .expect("implicit mutation should be present")
1441                    .commit(self)?;
1442                Ok(result)
1443            }
1444            Err(err) => {
1445                implicit_mutation
1446                    .take()
1447                    .expect("implicit mutation should be present")
1448                    .rollback(self)?;
1449                Err(err)
1450            }
1451        }
1452    }
1453
1454    fn execute_create_view_statement(
1455        &mut self,
1456        statement: crate::parser::ast::CreateViewStatement,
1457    ) -> Result<QueryResult> {
1458        let mut implicit_mutation = Some(ImplicitMutation::begin(self)?);
1459        let result: Result<QueryResult> = (|| {
1460            let mut catalog_guard = self.lock_catalog()?;
1461            let schema = catalog_guard.clone_schema();
1462            let dependencies = statement.query.dependency_names();
1463            if dependencies
1464                .iter()
1465                .any(|dependency| dependency.eq_ignore_ascii_case(&statement.view))
1466            {
1467                return Err(HematiteError::ParseError(format!(
1468                    "View '{}' cannot depend on itself",
1469                    statement.view
1470                )));
1471            }
1472            let normalized_query = match Self::normalize_statement(
1473                Statement::Select(statement.query.clone()),
1474                &schema,
1475            )? {
1476                Statement::Select(select) => select,
1477                _ => unreachable!("normalized create view query should remain a select"),
1478            };
1479            validate_statement(
1480                &crate::parser::ast::Statement::CreateView(CreateViewStatement {
1481                    view: statement.view.clone(),
1482                    if_not_exists: statement.if_not_exists,
1483                    query: normalized_query.clone(),
1484                }),
1485                &schema,
1486            )?;
1487
1488            if statement.if_not_exists && catalog_guard.get_view(&statement.view)?.is_some() {
1489                Ok(Self::mutation_result(0))
1490            } else {
1491                let column_names = projected_column_names(&normalized_query, &schema)?;
1492
1493                catalog_guard.create_view(crate::catalog::View {
1494                    name: statement.view.clone(),
1495                    query_sql: statement.query.to_sql(),
1496                    column_names,
1497                    dependencies,
1498                })?;
1499                Ok(Self::mutation_result(0))
1500            }
1501        })();
1502
1503        match result {
1504            Ok(result) => {
1505                implicit_mutation
1506                    .take()
1507                    .expect("implicit mutation should be present")
1508                    .commit(self)?;
1509                Ok(result)
1510            }
1511            Err(err) => {
1512                implicit_mutation
1513                    .take()
1514                    .expect("implicit mutation should be present")
1515                    .rollback(self)?;
1516                Err(err)
1517            }
1518        }
1519    }
1520
1521    fn execute_drop_view_statement(
1522        &mut self,
1523        view_name: &str,
1524        if_exists: bool,
1525    ) -> Result<QueryResult> {
1526        let mut implicit_mutation = Some(ImplicitMutation::begin(self)?);
1527        let result: Result<QueryResult> = (|| {
1528            let mut catalog_guard = self.lock_catalog()?;
1529            if if_exists && catalog_guard.get_view(view_name)?.is_none() {
1530                Ok(Self::mutation_result(0))
1531            } else {
1532                catalog_guard.drop_view(view_name)?;
1533                Ok(Self::mutation_result(0))
1534            }
1535        })();
1536
1537        match result {
1538            Ok(result) => {
1539                implicit_mutation
1540                    .take()
1541                    .expect("implicit mutation should be present")
1542                    .commit(self)?;
1543                Ok(result)
1544            }
1545            Err(err) => {
1546                implicit_mutation
1547                    .take()
1548                    .expect("implicit mutation should be present")
1549                    .rollback(self)?;
1550                Err(err)
1551            }
1552        }
1553    }
1554
1555    fn execute_create_trigger_statement(
1556        &mut self,
1557        statement: crate::parser::ast::CreateTriggerStatement,
1558    ) -> Result<QueryResult> {
1559        let mut implicit_mutation = Some(ImplicitMutation::begin(self)?);
1560        let result: Result<QueryResult> = (|| {
1561            let mut catalog_guard = self.lock_catalog()?;
1562            let schema = catalog_guard.clone_schema();
1563            validate_statement(
1564                &crate::parser::ast::Statement::CreateTrigger(statement.clone()),
1565                &schema,
1566            )?;
1567
1568            catalog_guard.create_trigger(crate::catalog::Trigger {
1569                name: statement.trigger.clone(),
1570                table_name: statement.table.clone(),
1571                event: match statement.event {
1572                    TriggerEvent::Insert => crate::catalog::TriggerEvent::Insert,
1573                    TriggerEvent::Update => crate::catalog::TriggerEvent::Update,
1574                    TriggerEvent::Delete => crate::catalog::TriggerEvent::Delete,
1575                },
1576                body_sql: statement.body.to_sql(),
1577                old_alias: match statement.event {
1578                    TriggerEvent::Insert => None,
1579                    TriggerEvent::Update | TriggerEvent::Delete => Some("OLD".to_string()),
1580                },
1581                new_alias: match statement.event {
1582                    TriggerEvent::Delete => None,
1583                    TriggerEvent::Insert | TriggerEvent::Update => Some("NEW".to_string()),
1584                },
1585            })?;
1586            Ok(Self::mutation_result(0))
1587        })();
1588
1589        match result {
1590            Ok(result) => {
1591                implicit_mutation
1592                    .take()
1593                    .expect("implicit mutation should be present")
1594                    .commit(self)?;
1595                Ok(result)
1596            }
1597            Err(err) => {
1598                implicit_mutation
1599                    .take()
1600                    .expect("implicit mutation should be present")
1601                    .rollback(self)?;
1602                Err(err)
1603            }
1604        }
1605    }
1606
1607    fn execute_drop_trigger_statement(
1608        &mut self,
1609        trigger_name: &str,
1610        if_exists: bool,
1611    ) -> Result<QueryResult> {
1612        let mut implicit_mutation = Some(ImplicitMutation::begin(self)?);
1613        let result: Result<QueryResult> = (|| {
1614            let mut catalog_guard = self.lock_catalog()?;
1615            if if_exists && catalog_guard.get_trigger(trigger_name)?.is_none() {
1616                Ok(Self::mutation_result(0))
1617            } else {
1618                catalog_guard.drop_trigger(trigger_name)?;
1619                Ok(Self::mutation_result(0))
1620            }
1621        })();
1622
1623        match result {
1624            Ok(result) => {
1625                implicit_mutation
1626                    .take()
1627                    .expect("implicit mutation should be present")
1628                    .commit(self)?;
1629                Ok(result)
1630            }
1631            Err(err) => {
1632                implicit_mutation
1633                    .take()
1634                    .expect("implicit mutation should be present")
1635                    .rollback(self)?;
1636                Err(err)
1637            }
1638        }
1639    }
1640
1641    pub(crate) fn execute_statement_result(
1642        &mut self,
1643        statement: crate::parser::ast::Statement,
1644    ) -> Result<ExecutedStatement> {
1645        self.execute_statement(statement)
1646            .map(ExecutedStatement::from_query_result)
1647    }
1648
1649    fn execute_read_statement(
1650        &mut self,
1651        statement: crate::parser::ast::Statement,
1652    ) -> Result<QueryResult> {
1653        let (schema, mut executor) = self.plan_executor(statement)?;
1654
1655        let result = {
1656            let mut catalog_guard = self.lock_catalog()?;
1657            catalog_guard.with_read_engine(|engine| {
1658                let mut ctx = ExecutionContext::for_read(&schema, engine);
1659                executor.execute(&mut ctx)
1660            })?
1661        };
1662
1663        Ok(result)
1664    }
1665
1666    fn execute_mutating_statement(
1667        &mut self,
1668        statement: crate::parser::ast::Statement,
1669    ) -> Result<QueryResult> {
1670        self.execute_mutating_statement_in_scope(statement, true)
1671    }
1672
1673    fn execute_mutating_statement_in_scope(
1674        &mut self,
1675        statement: crate::parser::ast::Statement,
1676        use_implicit_mutation: bool,
1677    ) -> Result<QueryResult> {
1678        let persists_schema = statement.mutates_schema();
1679        let (schema, mut executor) = self.plan_executor(statement)?;
1680        let mut implicit_mutation = if use_implicit_mutation {
1681            Some(ImplicitMutation::begin(self)?)
1682        } else {
1683            None
1684        };
1685
1686        let execution_result = {
1687            let mut catalog_guard = self.lock_catalog()?;
1688            catalog_guard.with_engine(|engine| {
1689                let mut ctx = ExecutionContext::for_mutation(&schema, engine);
1690                let result = executor.execute(&mut ctx)?;
1691                Ok((result, ctx.catalog, ctx.mutation_events))
1692            })
1693        };
1694
1695        match execution_result {
1696            Ok((result, updated_schema, mutation_events)) => {
1697                if persists_schema {
1698                    let mut catalog_guard = self.lock_catalog()?;
1699                    if let Err(err) = catalog_guard.replace_schema(updated_schema) {
1700                        drop(catalog_guard);
1701                        if let Some(implicit_mutation) = implicit_mutation.take() {
1702                            implicit_mutation.rollback(self)?;
1703                        }
1704                        return Err(err);
1705                    }
1706                }
1707
1708                if let Err(err) = self.fire_triggers(mutation_events) {
1709                    if let Some(implicit_mutation) = implicit_mutation.take() {
1710                        implicit_mutation.rollback(self)?;
1711                    }
1712                    return Err(err);
1713                }
1714
1715                if let Some(implicit_mutation) = implicit_mutation.take() {
1716                    implicit_mutation.commit(self)?;
1717                }
1718
1719                Ok(result)
1720            }
1721            Err(err) => {
1722                if let Some(implicit_mutation) = implicit_mutation.take() {
1723                    implicit_mutation.rollback(self)?;
1724                }
1725                Err(err)
1726            }
1727        }
1728    }
1729
1730    fn plan_executor(
1731        &self,
1732        statement: crate::parser::ast::Statement,
1733    ) -> Result<(Schema, Box<dyn QueryExecutor>)> {
1734        let (schema, table_row_counts) = self.read_planning_state()?;
1735        let statement = Self::normalize_statement(statement, &schema)?;
1736        let planner = QueryPlanner::new(schema.clone()).with_table_row_counts(table_row_counts);
1737        let plan = planner.plan(statement)?;
1738        Ok((schema, plan.into_executor()))
1739    }
1740
1741    fn read_planning_state(&self) -> Result<(Schema, HashMap<String, usize>)> {
1742        let mut catalog_guard = self.lock_catalog()?;
1743        let schema = catalog_guard.clone_schema();
1744        let table_row_counts =
1745            catalog_guard.with_engine(|engine| Ok(Self::collect_table_row_counts(engine)))?;
1746        Ok((schema, table_row_counts))
1747    }
1748
1749    fn collect_table_row_counts(engine: &CatalogEngine) -> HashMap<String, usize> {
1750        engine
1751            .get_table_metadata()
1752            .iter()
1753            .map(|(name, metadata)| (name.clone(), metadata.row_count as usize))
1754            .collect()
1755    }
1756
1757    fn fire_triggers(&mut self, mutation_events: Vec<MutationEvent>) -> Result<()> {
1758        if mutation_events.is_empty() {
1759            return Ok(());
1760        }
1761
1762        if self.trigger_depth >= 32 {
1763            return Err(HematiteError::ParseError(
1764                "Trigger recursion limit exceeded".to_string(),
1765            ));
1766        }
1767
1768        self.trigger_depth += 1;
1769        let result = (|| {
1770            for event in mutation_events {
1771                let (table_name, event_kind, old_row, new_row) = match event {
1772                    MutationEvent::Insert {
1773                        table_name,
1774                        new_row,
1775                    } => (
1776                        table_name,
1777                        crate::catalog::TriggerEvent::Insert,
1778                        None,
1779                        Some(new_row),
1780                    ),
1781                    MutationEvent::Update {
1782                        table_name,
1783                        old_row,
1784                        new_row,
1785                    } => (
1786                        table_name,
1787                        crate::catalog::TriggerEvent::Update,
1788                        Some(old_row),
1789                        Some(new_row),
1790                    ),
1791                    MutationEvent::Delete {
1792                        table_name,
1793                        old_row,
1794                    } => (
1795                        table_name,
1796                        crate::catalog::TriggerEvent::Delete,
1797                        Some(old_row),
1798                        None,
1799                    ),
1800                };
1801
1802                let (table, triggers) = {
1803                    let catalog_guard = self.lock_catalog()?;
1804                    let table = catalog_guard
1805                        .get_table_by_name(&table_name)?
1806                        .ok_or_else(|| {
1807                            HematiteError::InternalError(format!(
1808                                "Table '{}' disappeared while firing triggers",
1809                                table_name
1810                            ))
1811                        })?;
1812                    let mut triggers = catalog_guard
1813                        .list_triggers()?
1814                        .into_iter()
1815                        .filter_map(|name| catalog_guard.get_trigger(&name).ok().flatten())
1816                        .filter(|trigger| {
1817                            trigger.table_name == table_name && trigger.event == event_kind
1818                        })
1819                        .collect::<Vec<_>>();
1820                    triggers.sort_by(|left, right| left.name.cmp(&right.name));
1821                    (table, triggers)
1822                };
1823
1824                for trigger in triggers {
1825                    let trigger_statement =
1826                        Self::parse_statement(&format!("{};", trigger.body_sql))?;
1827                    let trigger_statement = substitute_trigger_statement(
1828                        trigger_statement,
1829                        &table,
1830                        old_row.as_ref(),
1831                        new_row.as_ref(),
1832                    );
1833                    if trigger_statement.is_read_only() {
1834                        let _ = self.execute_read_statement(trigger_statement)?;
1835                    } else {
1836                        let _ =
1837                            self.execute_mutating_statement_in_scope(trigger_statement, false)?;
1838                    }
1839                }
1840            }
1841            Ok(())
1842        })();
1843        self.trigger_depth -= 1;
1844        result
1845    }
1846
1847    pub fn close(&mut self) -> Result<()> {
1848        if self.transaction.is_some() {
1849            return Err(HematiteError::InternalError(
1850                "Cannot close connection with an active transaction".to_string(),
1851            ));
1852        }
1853        let mut catalog_guard = self.lock_catalog()?;
1854        catalog_guard.flush()
1855    }
1856
1857    pub fn journal_mode(&self) -> Result<JournalMode> {
1858        let catalog_guard = self.lock_catalog()?;
1859        catalog_guard.journal_mode()
1860    }
1861
1862    pub fn set_journal_mode(&mut self, journal_mode: JournalMode) -> Result<()> {
1863        let mut catalog_guard = self.lock_catalog()?;
1864        catalog_guard.set_journal_mode(journal_mode)
1865    }
1866
1867    pub fn checkpoint_wal(&mut self) -> Result<()> {
1868        let mut catalog_guard = self.lock_catalog()?;
1869        catalog_guard.checkpoint_wal()
1870    }
1871
1872    pub fn execute(&mut self, sql: &str) -> Result<QueryResult> {
1873        self.execute_statement(Self::parse_statement(sql)?)
1874    }
1875
1876    pub fn execute_result(&mut self, sql: &str) -> Result<ExecutedStatement> {
1877        self.execute(sql).map(ExecutedStatement::from_query_result)
1878    }
1879
1880    pub fn iter_script<'a>(&'a mut self, sql: &str) -> Result<ScriptIter<'a>> {
1881        Ok(ScriptIter::new(self, split_script_tokens(sql)?))
1882    }
1883
1884    pub fn execute_batch(&mut self, sql: &str) -> Result<()> {
1885        for result in self.iter_script(sql)? {
1886            result?;
1887        }
1888        Ok(())
1889    }
1890
1891    pub fn execute_query(&mut self, sql: &str) -> Result<QueryResult> {
1892        self.execute(sql)
1893    }
1894
1895    pub fn prepare(&self, sql: &str) -> Result<PreparedStatement> {
1896        let statement = Self::parse_statement(sql)?;
1897        let parameter_count = statement.parameter_count();
1898
1899        Ok(PreparedStatement {
1900            statement,
1901            parameters: vec![None; parameter_count],
1902        })
1903    }
1904
1905    pub fn begin_transaction(&'_ mut self) -> Result<Transaction<'_>> {
1906        self.begin_active_transaction()?;
1907        Ok(Transaction {
1908            connection: self,
1909            completed: false,
1910        })
1911    }
1912
1913    fn begin_active_transaction(&mut self) -> Result<()> {
1914        if self.transaction.is_some() {
1915            return Err(HematiteError::InternalError(
1916                "Transaction is already active".to_string(),
1917            ));
1918        }
1919
1920        let mut catalog_guard = self.lock_catalog()?;
1921        let snapshot = catalog_guard.snapshot()?;
1922        catalog_guard.begin_transaction()?;
1923        drop(catalog_guard);
1924        self.transaction = Some(ConnectionTransaction {
1925            snapshot,
1926            savepoints: Vec::new(),
1927        });
1928        Ok(())
1929    }
1930
1931    #[cfg(test)]
1932    pub(crate) fn schema_snapshot(&self) -> Result<Schema> {
1933        let catalog_guard = self.lock_catalog()?;
1934        Ok(catalog_guard.clone_schema())
1935    }
1936
1937    fn active_transaction_mut(&mut self, action: &str) -> Result<&mut ConnectionTransaction> {
1938        self.transaction.as_mut().ok_or_else(|| {
1939            HematiteError::ParseError(format!("{} requires an active transaction", action))
1940        })
1941    }
1942
1943    fn create_savepoint(&mut self, name: &str) -> Result<()> {
1944        {
1945            let transaction = self.active_transaction_mut("SAVEPOINT")?;
1946            if transaction
1947                .savepoints
1948                .iter()
1949                .any(|savepoint| savepoint.name.eq_ignore_ascii_case(name))
1950            {
1951                return Err(HematiteError::ParseError(format!(
1952                    "Savepoint '{}' already exists",
1953                    name
1954                )));
1955            }
1956        }
1957
1958        let snapshot = {
1959            let catalog_guard = self.lock_catalog()?;
1960            catalog_guard.snapshot()
1961        }?;
1962
1963        let transaction = self.active_transaction_mut("SAVEPOINT")?;
1964        transaction.savepoints.push(SavepointState {
1965            name: name.to_string(),
1966            snapshot,
1967        });
1968        Ok(())
1969    }
1970
1971    fn rollback_to_savepoint(&mut self, name: &str) -> Result<()> {
1972        let position = {
1973            let transaction = self.active_transaction_mut("ROLLBACK TO SAVEPOINT")?;
1974            transaction
1975                .savepoints
1976                .iter()
1977                .position(|savepoint| savepoint.name.eq_ignore_ascii_case(name))
1978                .ok_or_else(|| {
1979                    HematiteError::ParseError(format!("Savepoint '{}' does not exist", name))
1980                })?
1981        };
1982
1983        let snapshot = {
1984            let transaction = self.active_transaction_mut("ROLLBACK TO SAVEPOINT")?;
1985            transaction.savepoints[position].snapshot.clone()
1986        };
1987
1988        {
1989            let mut catalog_guard = self.lock_catalog()?;
1990            catalog_guard.restore_snapshot(snapshot)?;
1991        }
1992
1993        let transaction = self.active_transaction_mut("ROLLBACK TO SAVEPOINT")?;
1994        transaction.savepoints.truncate(position + 1);
1995        Ok(())
1996    }
1997
1998    fn release_savepoint(&mut self, name: &str) -> Result<()> {
1999        let transaction = self.active_transaction_mut("RELEASE SAVEPOINT")?;
2000        let position = transaction
2001            .savepoints
2002            .iter()
2003            .position(|savepoint| savepoint.name.eq_ignore_ascii_case(name))
2004            .ok_or_else(|| {
2005                HematiteError::ParseError(format!("Savepoint '{}' does not exist", name))
2006            })?;
2007        transaction.savepoints.remove(position);
2008        Ok(())
2009    }
2010}
2011
2012fn substitute_trigger_statement(
2013    statement: Statement,
2014    table: &crate::catalog::Table,
2015    old_row: Option<&crate::catalog::StoredRow>,
2016    new_row: Option<&crate::catalog::StoredRow>,
2017) -> Statement {
2018    let mut bindings = HashMap::new();
2019    if let Some(old_row) = old_row {
2020        for (column, value) in table.columns.iter().zip(old_row.values.iter()) {
2021            bindings.insert(format!("OLD.{}", column.name), raise_literal_value(value));
2022        }
2023    }
2024    if let Some(new_row) = new_row {
2025        for (column, value) in table.columns.iter().zip(new_row.values.iter()) {
2026            bindings.insert(format!("NEW.{}", column.name), raise_literal_value(value));
2027        }
2028    }
2029
2030    substitute_statement_bindings(statement, &bindings)
2031}
2032
2033fn substitute_statement_bindings(
2034    statement: Statement,
2035    bindings: &HashMap<String, crate::parser::types::LiteralValue>,
2036) -> Statement {
2037    match statement {
2038        Statement::Select(select) => {
2039            Statement::Select(substitute_select_bindings(select, bindings))
2040        }
2041        Statement::Insert(insert) => Statement::Insert(crate::parser::ast::InsertStatement {
2042            table: insert.table,
2043            columns: insert.columns,
2044            source: match insert.source {
2045                InsertSource::Values(rows) => InsertSource::Values(
2046                    rows.into_iter()
2047                        .map(|row| {
2048                            row.into_iter()
2049                                .map(|expr| substitute_expression_bindings(expr, bindings))
2050                                .collect()
2051                        })
2052                        .collect(),
2053                ),
2054                InsertSource::Select(select) => {
2055                    InsertSource::Select(Box::new(substitute_select_bindings(*select, bindings)))
2056                }
2057            },
2058            on_duplicate: insert.on_duplicate.map(|assignments| {
2059                assignments
2060                    .into_iter()
2061                    .map(|assignment| crate::parser::ast::UpdateAssignment {
2062                        column: assignment.column,
2063                        value: substitute_expression_bindings(assignment.value, bindings),
2064                    })
2065                    .collect()
2066            }),
2067        }),
2068        Statement::Update(update) => Statement::Update(crate::parser::ast::UpdateStatement {
2069            table: update.table,
2070            target_binding: update.target_binding,
2071            source: update.source,
2072            assignments: update
2073                .assignments
2074                .into_iter()
2075                .map(|assignment| crate::parser::ast::UpdateAssignment {
2076                    column: assignment.column,
2077                    value: substitute_expression_bindings(assignment.value, bindings),
2078                })
2079                .collect(),
2080            where_clause: update
2081                .where_clause
2082                .map(|where_clause| substitute_where_clause_bindings(where_clause, bindings)),
2083        }),
2084        Statement::Delete(delete) => Statement::Delete(crate::parser::ast::DeleteStatement {
2085            table: delete.table,
2086            target_binding: delete.target_binding,
2087            source: delete.source,
2088            where_clause: delete
2089                .where_clause
2090                .map(|where_clause| substitute_where_clause_bindings(where_clause, bindings)),
2091        }),
2092        other => other,
2093    }
2094}
2095
2096fn substitute_select_bindings(
2097    select: SelectStatement,
2098    bindings: &HashMap<String, crate::parser::types::LiteralValue>,
2099) -> SelectStatement {
2100    SelectStatement {
2101        with_clause: select
2102            .with_clause
2103            .into_iter()
2104            .map(|cte| crate::parser::ast::CommonTableExpression {
2105                name: cte.name,
2106                recursive: cte.recursive,
2107                query: Box::new(substitute_select_bindings(*cte.query, bindings)),
2108            })
2109            .collect(),
2110        distinct: select.distinct,
2111        columns: select
2112            .columns
2113            .into_iter()
2114            .map(|item| match item {
2115                crate::parser::ast::SelectItem::Expression(expr) => {
2116                    crate::parser::ast::SelectItem::Expression(substitute_expression_bindings(
2117                        expr, bindings,
2118                    ))
2119                }
2120                crate::parser::ast::SelectItem::Column(name) => bindings
2121                    .get(&name)
2122                    .cloned()
2123                    .map(crate::parser::ast::Expression::Literal)
2124                    .map(crate::parser::ast::SelectItem::Expression)
2125                    .unwrap_or(crate::parser::ast::SelectItem::Column(name)),
2126                other => other,
2127            })
2128            .collect(),
2129        column_aliases: select.column_aliases,
2130        from: substitute_table_reference_bindings(select.from, bindings),
2131        where_clause: select
2132            .where_clause
2133            .map(|where_clause| substitute_where_clause_bindings(where_clause, bindings)),
2134        group_by: select
2135            .group_by
2136            .into_iter()
2137            .map(|expr| substitute_expression_bindings(expr, bindings))
2138            .collect(),
2139        having_clause: select
2140            .having_clause
2141            .map(|where_clause| substitute_where_clause_bindings(where_clause, bindings)),
2142        order_by: select.order_by,
2143        limit: select.limit,
2144        offset: select.offset,
2145        set_operation: select
2146            .set_operation
2147            .map(|set_operation| crate::parser::ast::SetOperation {
2148                operator: set_operation.operator,
2149                right: Box::new(substitute_select_bindings(*set_operation.right, bindings)),
2150            }),
2151    }
2152}
2153
2154fn substitute_table_reference_bindings(
2155    table_reference: TableReference,
2156    bindings: &HashMap<String, crate::parser::types::LiteralValue>,
2157) -> TableReference {
2158    match table_reference {
2159        TableReference::Table(name, alias) => TableReference::Table(name, alias),
2160        TableReference::Derived { subquery, alias } => TableReference::Derived {
2161            subquery: Box::new(substitute_select_bindings(*subquery, bindings)),
2162            alias,
2163        },
2164        TableReference::CrossJoin(left, right) => TableReference::CrossJoin(
2165            Box::new(substitute_table_reference_bindings(*left, bindings)),
2166            Box::new(substitute_table_reference_bindings(*right, bindings)),
2167        ),
2168        TableReference::InnerJoin { left, right, on } => TableReference::InnerJoin {
2169            left: Box::new(substitute_table_reference_bindings(*left, bindings)),
2170            right: Box::new(substitute_table_reference_bindings(*right, bindings)),
2171            on: substitute_condition_bindings(on, bindings),
2172        },
2173        TableReference::LeftJoin { left, right, on } => TableReference::LeftJoin {
2174            left: Box::new(substitute_table_reference_bindings(*left, bindings)),
2175            right: Box::new(substitute_table_reference_bindings(*right, bindings)),
2176            on: substitute_condition_bindings(on, bindings),
2177        },
2178        TableReference::RightJoin { left, right, on } => TableReference::RightJoin {
2179            left: Box::new(substitute_table_reference_bindings(*left, bindings)),
2180            right: Box::new(substitute_table_reference_bindings(*right, bindings)),
2181            on: substitute_condition_bindings(on, bindings),
2182        },
2183        TableReference::FullOuterJoin { left, right, on } => TableReference::FullOuterJoin {
2184            left: Box::new(substitute_table_reference_bindings(*left, bindings)),
2185            right: Box::new(substitute_table_reference_bindings(*right, bindings)),
2186            on: substitute_condition_bindings(on, bindings),
2187        },
2188    }
2189}
2190
2191fn substitute_where_clause_bindings(
2192    where_clause: WhereClause,
2193    bindings: &HashMap<String, crate::parser::types::LiteralValue>,
2194) -> WhereClause {
2195    WhereClause {
2196        conditions: where_clause
2197            .conditions
2198            .into_iter()
2199            .map(|condition| substitute_condition_bindings(condition, bindings))
2200            .collect(),
2201    }
2202}
2203
2204fn substitute_condition_bindings(
2205    condition: Condition,
2206    bindings: &HashMap<String, crate::parser::types::LiteralValue>,
2207) -> Condition {
2208    match condition {
2209        Condition::Comparison {
2210            left,
2211            operator,
2212            right,
2213        } => Condition::Comparison {
2214            left: substitute_expression_bindings(left, bindings),
2215            operator,
2216            right: substitute_expression_bindings(right, bindings),
2217        },
2218        Condition::InList {
2219            expr,
2220            values,
2221            is_not,
2222        } => Condition::InList {
2223            expr: substitute_expression_bindings(expr, bindings),
2224            values: values
2225                .into_iter()
2226                .map(|expr| substitute_expression_bindings(expr, bindings))
2227                .collect(),
2228            is_not,
2229        },
2230        Condition::InSubquery {
2231            expr,
2232            subquery,
2233            is_not,
2234        } => Condition::InSubquery {
2235            expr: substitute_expression_bindings(expr, bindings),
2236            subquery: Box::new(substitute_select_bindings(*subquery, bindings)),
2237            is_not,
2238        },
2239        Condition::Between {
2240            expr,
2241            lower,
2242            upper,
2243            is_not,
2244        } => Condition::Between {
2245            expr: substitute_expression_bindings(expr, bindings),
2246            lower: substitute_expression_bindings(lower, bindings),
2247            upper: substitute_expression_bindings(upper, bindings),
2248            is_not,
2249        },
2250        Condition::Like {
2251            expr,
2252            pattern,
2253            is_not,
2254        } => Condition::Like {
2255            expr: substitute_expression_bindings(expr, bindings),
2256            pattern: substitute_expression_bindings(pattern, bindings),
2257            is_not,
2258        },
2259        Condition::Exists { subquery, is_not } => Condition::Exists {
2260            subquery: Box::new(substitute_select_bindings(*subquery, bindings)),
2261            is_not,
2262        },
2263        Condition::NullCheck { expr, is_not } => Condition::NullCheck {
2264            expr: substitute_expression_bindings(expr, bindings),
2265            is_not,
2266        },
2267        Condition::Not(condition) => Condition::Not(Box::new(substitute_condition_bindings(
2268            *condition, bindings,
2269        ))),
2270        Condition::Logical {
2271            left,
2272            operator,
2273            right,
2274        } => Condition::Logical {
2275            left: Box::new(substitute_condition_bindings(*left, bindings)),
2276            operator,
2277            right: Box::new(substitute_condition_bindings(*right, bindings)),
2278        },
2279    }
2280}
2281
2282fn substitute_expression_bindings(
2283    expression: Expression,
2284    bindings: &HashMap<String, crate::parser::types::LiteralValue>,
2285) -> Expression {
2286    match expression {
2287        Expression::Column(name) => bindings
2288            .get(&name)
2289            .cloned()
2290            .map(Expression::Literal)
2291            .unwrap_or(Expression::Column(name)),
2292        Expression::Literal(_) | Expression::IntervalLiteral { .. } | Expression::Parameter(_) => {
2293            expression
2294        }
2295        Expression::ScalarSubquery(subquery) => {
2296            Expression::ScalarSubquery(Box::new(substitute_select_bindings(*subquery, bindings)))
2297        }
2298        Expression::Cast { expr, target_type } => Expression::Cast {
2299            expr: Box::new(substitute_expression_bindings(*expr, bindings)),
2300            target_type,
2301        },
2302        Expression::Case {
2303            branches,
2304            else_expr,
2305        } => Expression::Case {
2306            branches: branches
2307                .into_iter()
2308                .map(|branch| crate::parser::ast::CaseWhenClause {
2309                    condition: substitute_expression_bindings(branch.condition, bindings),
2310                    result: substitute_expression_bindings(branch.result, bindings),
2311                })
2312                .collect(),
2313            else_expr: else_expr
2314                .map(|expr| Box::new(substitute_expression_bindings(*expr, bindings))),
2315        },
2316        Expression::ScalarFunctionCall { function, args } => Expression::ScalarFunctionCall {
2317            function,
2318            args: args
2319                .into_iter()
2320                .map(|expr| substitute_expression_bindings(expr, bindings))
2321                .collect(),
2322        },
2323        Expression::AggregateCall { function, target } => {
2324            Expression::AggregateCall { function, target }
2325        }
2326        Expression::UnaryMinus(expr) => {
2327            Expression::UnaryMinus(Box::new(substitute_expression_bindings(*expr, bindings)))
2328        }
2329        Expression::UnaryNot(expr) => {
2330            Expression::UnaryNot(Box::new(substitute_expression_bindings(*expr, bindings)))
2331        }
2332        Expression::Binary {
2333            left,
2334            operator,
2335            right,
2336        } => Expression::Binary {
2337            left: Box::new(substitute_expression_bindings(*left, bindings)),
2338            operator,
2339            right: Box::new(substitute_expression_bindings(*right, bindings)),
2340        },
2341        Expression::Comparison {
2342            left,
2343            operator,
2344            right,
2345        } => Expression::Comparison {
2346            left: Box::new(substitute_expression_bindings(*left, bindings)),
2347            operator,
2348            right: Box::new(substitute_expression_bindings(*right, bindings)),
2349        },
2350        Expression::InList {
2351            expr,
2352            values,
2353            is_not,
2354        } => Expression::InList {
2355            expr: Box::new(substitute_expression_bindings(*expr, bindings)),
2356            values: values
2357                .into_iter()
2358                .map(|expr| substitute_expression_bindings(expr, bindings))
2359                .collect(),
2360            is_not,
2361        },
2362        Expression::InSubquery {
2363            expr,
2364            subquery,
2365            is_not,
2366        } => Expression::InSubquery {
2367            expr: Box::new(substitute_expression_bindings(*expr, bindings)),
2368            subquery: Box::new(substitute_select_bindings(*subquery, bindings)),
2369            is_not,
2370        },
2371        Expression::Between {
2372            expr,
2373            lower,
2374            upper,
2375            is_not,
2376        } => Expression::Between {
2377            expr: Box::new(substitute_expression_bindings(*expr, bindings)),
2378            lower: Box::new(substitute_expression_bindings(*lower, bindings)),
2379            upper: Box::new(substitute_expression_bindings(*upper, bindings)),
2380            is_not,
2381        },
2382        Expression::Like {
2383            expr,
2384            pattern,
2385            is_not,
2386        } => Expression::Like {
2387            expr: Box::new(substitute_expression_bindings(*expr, bindings)),
2388            pattern: Box::new(substitute_expression_bindings(*pattern, bindings)),
2389            is_not,
2390        },
2391        Expression::Exists { subquery, is_not } => Expression::Exists {
2392            subquery: Box::new(substitute_select_bindings(*subquery, bindings)),
2393            is_not,
2394        },
2395        Expression::NullCheck { expr, is_not } => Expression::NullCheck {
2396            expr: Box::new(substitute_expression_bindings(*expr, bindings)),
2397            is_not,
2398        },
2399        Expression::Logical {
2400            left,
2401            operator,
2402            right,
2403        } => Expression::Logical {
2404            left: Box::new(substitute_expression_bindings(*left, bindings)),
2405            operator,
2406            right: Box::new(substitute_expression_bindings(*right, bindings)),
2407        },
2408    }
2409}
2410
2411#[derive(Debug, Clone)]
2412pub struct PreparedStatement {
2413    statement: crate::parser::ast::Statement,
2414    parameters: Vec<Option<Value>>,
2415}
2416
2417impl PreparedStatement {
2418    pub fn bind(&mut self, index: usize, value: Value) -> Result<()> {
2419        if index == 0 || index > self.parameters.len() {
2420            return Err(HematiteError::ParseError(format!(
2421                "Parameter index {} is out of range",
2422                index
2423            )));
2424        }
2425
2426        self.parameters[index - 1] = Some(value);
2427        Ok(())
2428    }
2429
2430    pub fn bind_all(&mut self, values: Vec<Value>) -> Result<()> {
2431        if values.len() != self.parameters.len() {
2432            return Err(HematiteError::ParseError(format!(
2433                "Expected {} parameters, got {}",
2434                self.parameters.len(),
2435                values.len()
2436            )));
2437        }
2438
2439        self.parameters = values.into_iter().map(Some).collect();
2440        Ok(())
2441    }
2442
2443    pub fn clear_bindings(&mut self) {
2444        self.parameters.fill(None);
2445    }
2446
2447    pub fn parameter_count(&self) -> usize {
2448        self.parameters.len()
2449    }
2450
2451    pub fn execute(&mut self, connection: &mut Connection) -> Result<QueryResult> {
2452        let statement = self.bound_statement()?;
2453        connection.execute_statement(statement)
2454    }
2455
2456    pub fn query(&mut self, connection: &mut Connection) -> Result<QueryResult> {
2457        self.execute(connection)
2458    }
2459
2460    fn bound_statement(&self) -> Result<crate::parser::ast::Statement> {
2461        let bound_values = self
2462            .parameters
2463            .iter()
2464            .enumerate()
2465            .map(|(index, value)| {
2466                value.clone().ok_or_else(|| {
2467                    HematiteError::ParseError(format!("Parameter {} has not been bound", index + 1))
2468                })
2469            })
2470            .collect::<Result<Vec<_>>>()?;
2471        let bound_literals = bound_values
2472            .iter()
2473            .map(raise_literal_value)
2474            .collect::<Vec<_>>();
2475
2476        self.statement.bind_parameters(&bound_literals)
2477    }
2478}
2479
2480#[derive(Debug)]
2481pub struct Transaction<'a> {
2482    connection: &'a mut Connection,
2483    completed: bool,
2484}
2485
2486impl<'a> Transaction<'a> {
2487    pub fn execute(&mut self, sql: &str) -> Result<QueryResult> {
2488        self.connection.execute(sql)
2489    }
2490
2491    pub fn commit(&mut self) -> Result<()> {
2492        if self.completed {
2493            return Err(HematiteError::InternalError(
2494                "Transaction is already completed".to_string(),
2495            ));
2496        }
2497        self.connection.commit_active_transaction()?;
2498        self.completed = true;
2499        Ok(())
2500    }
2501
2502    pub fn rollback(&mut self) -> Result<()> {
2503        if self.completed {
2504            return Err(HematiteError::InternalError(
2505                "Transaction is already completed".to_string(),
2506            ));
2507        }
2508        self.connection.rollback_active_transaction()?;
2509        self.completed = true;
2510        Ok(())
2511    }
2512}
2513
2514impl<'a> Drop for Transaction<'a> {
2515    fn drop(&mut self) {
2516        if !self.completed {
2517            let _ = self.connection.rollback_active_transaction();
2518        }
2519    }
2520}
2521
2522#[derive(Debug, Clone)]
2523pub struct Database;
2524
2525impl Database {
2526    pub fn new() -> Self {
2527        Self
2528    }
2529
2530    pub fn open(database_path: &str) -> Result<Connection> {
2531        Connection::new(database_path)
2532    }
2533
2534    pub fn open_in_memory() -> Result<Connection> {
2535        Connection::new_in_memory()
2536    }
2537
2538    pub fn connect(&mut self, database_path: &str) -> Result<Connection> {
2539        Connection::new(database_path)
2540    }
2541}
2542
2543impl Default for Database {
2544    fn default() -> Self {
2545        Self::new()
2546    }
2547}
2548
2549impl Connection {
2550    fn take_active_transaction(&mut self, action: &str) -> Result<ConnectionTransaction> {
2551        self.transaction.take().ok_or_else(|| {
2552            HematiteError::InternalError(format!("No active transaction to {}", action))
2553        })
2554    }
2555
2556    fn commit_active_transaction(&mut self) -> Result<()> {
2557        let state = self.take_active_transaction("commit")?;
2558        let mut catalog_guard = self.lock_catalog()?;
2559        match catalog_guard.commit_transaction() {
2560            Ok(()) => Ok(()),
2561            Err(err) => {
2562                let _ = catalog_guard.rollback_transaction();
2563                catalog_guard.restore_snapshot(state.snapshot)?;
2564                Err(err)
2565            }
2566        }
2567    }
2568
2569    fn rollback_active_transaction(&mut self) -> Result<()> {
2570        let state = self.take_active_transaction("roll back")?;
2571        let mut catalog_guard = self.lock_catalog()?;
2572        catalog_guard.rollback_transaction()?;
2573        catalog_guard.restore_snapshot(state.snapshot)?;
2574        Ok(())
2575    }
2576}