1use crate::error::{HematiteError, Result};
26use crate::parser::ast::{
27 ColumnDefinition, Condition, CreateStatement, CreateViewStatement, Expression, InsertSource,
28 InsertStatement, SelectIntoStatement, SelectStatement, Statement, TableReference, TriggerEvent,
29 WhereClause,
30};
31use crate::parser::{Lexer, Parser, SqlTypeName};
32use crate::query::lowering::raise_literal_value;
33use crate::query::metadata as query_metadata;
34use crate::query::validation::{projected_column_names, source_column_names, validate_statement};
35use crate::query::{
36 Catalog, CatalogEngine, ExecutionContext, JournalMode, MutationEvent, QueryCatalogSnapshot,
37 QueryExecutor, QueryPlanner, QueryResult, Schema, Value,
38};
39use crate::sql::result::ExecutedStatement;
40use crate::sql::script::{split_script_tokens, ScriptIter};
41use std::collections::{HashMap, HashSet};
42use std::sync::{Arc, Mutex, MutexGuard};
43
/// Per-connection transaction state, stored in `Connection::transaction`
/// while a connection-level transaction is open.
#[derive(Debug, Clone)]
struct ConnectionTransaction {
    /// Catalog snapshot held for this transaction (presumably the state to
    /// restore on ROLLBACK — confirm against the transaction-control code).
    snapshot: QueryCatalogSnapshot,
    /// Savepoints opened inside this transaction, in creation order.
    savepoints: Vec<SavepointState>,
}
49
/// A named savepoint inside a [`ConnectionTransaction`].
#[derive(Debug, Clone)]
struct SavepointState {
    /// Savepoint name as given by the user (case handling is decided by callers).
    name: String,
    /// Catalog snapshot associated with this savepoint (presumably taken when
    /// the savepoint was created — TODO confirm).
    snapshot: QueryCatalogSnapshot,
}
55
/// Guard for a single mutating statement run with implicit transaction
/// semantics.
///
/// `snapshot` is `Some` only when `ImplicitMutation::begin` opened its own
/// catalog transaction (no connection-level transaction was active); in that
/// case `commit`/`rollback` finish or undo it. When `None`, an enclosing
/// connection transaction owns the work and both are no-ops.
///
/// NOTE(review): no `Drop` impl is visible here, so callers presumably must
/// always finish the guard with `commit` or `rollback`.
#[derive(Debug)]
struct ImplicitMutation {
    snapshot: Option<QueryCatalogSnapshot>,
}
60
61impl ImplicitMutation {
62 fn begin(connection: &mut Connection) -> Result<Self> {
63 if connection.transaction.is_some() {
64 return Ok(Self { snapshot: None });
65 }
66
67 let mut catalog_guard = connection.lock_catalog()?;
68 let snapshot = catalog_guard.snapshot()?;
69 catalog_guard.begin_transaction()?;
70 Ok(Self {
71 snapshot: Some(snapshot),
72 })
73 }
74
75 fn rollback(mut self, connection: &mut Connection) -> Result<()> {
76 if let Some(snapshot) = self.snapshot.take() {
77 let mut catalog_guard = connection.lock_catalog()?;
78 let _ = catalog_guard.rollback_transaction();
79 catalog_guard.restore_snapshot(snapshot)?;
80 }
81 Ok(())
82 }
83
84 fn commit(mut self, connection: &mut Connection) -> Result<()> {
85 let Some(snapshot) = self.snapshot.take() else {
86 return Ok(());
87 };
88
89 let mut catalog_guard = connection.lock_catalog()?;
90 match catalog_guard.commit_transaction() {
91 Ok(()) => Ok(()),
92 Err(err) => {
93 let _ = catalog_guard.rollback_transaction();
94 catalog_guard.restore_snapshot(snapshot)?;
95 Err(err)
96 }
97 }
98 }
99}
100
/// A SQL connection over a catalog shared behind `Arc<Mutex<..>>`.
#[derive(Debug)]
pub struct Connection {
    /// Shared catalog; always accessed through `lock_catalog`, which maps a
    /// poisoned mutex to `HematiteError::InternalError`.
    catalog: Arc<Mutex<Catalog>>,
    /// `Some` while a connection-level transaction is open; statement-level
    /// `ImplicitMutation`s skip opening their own catalog transaction then.
    transaction: Option<ConnectionTransaction>,
    /// Trigger nesting depth (presumably incremented while trigger bodies run,
    /// to bound recursion — confirm against the trigger executor).
    trigger_depth: usize,
}
107
108impl Connection {
/// Base name for the synthetic key column used by SELECT INTO.
/// `select_into_column_names` renames projected columns away from it, and
/// `select_into_synthetic_pk_name` appends `_2`, `_3`, ... on a collision.
const SELECT_INTO_ROWID_COLUMN: &'static str = "__hematite_select_into_rowid";
110
111 fn empty_result() -> QueryResult {
112 QueryResult {
113 affected_rows: 0,
114 columns: Vec::new(),
115 rows: Vec::new(),
116 }
117 }
118
119 fn mutation_result(affected_rows: usize) -> QueryResult {
120 QueryResult {
121 affected_rows,
122 columns: Vec::new(),
123 rows: Vec::new(),
124 }
125 }
126
127 fn select_into_synthetic_pk_name(column_names: &[String]) -> String {
128 let mut candidate = Self::SELECT_INTO_ROWID_COLUMN.to_string();
129 let used = column_names
130 .iter()
131 .map(|name| name.to_ascii_lowercase())
132 .collect::<HashSet<_>>();
133 let mut suffix = 2usize;
134 while used.contains(&candidate.to_ascii_lowercase()) {
135 candidate = format!("{}_{}", Self::SELECT_INTO_ROWID_COLUMN, suffix);
136 suffix += 1;
137 }
138 candidate
139 }
140
141 fn select_into_column_names(result: &QueryResult) -> Vec<String> {
142 let mut used = HashSet::new();
143 let mut names = Vec::with_capacity(result.columns.len());
144 for (index, name) in result.columns.iter().enumerate() {
145 let mut candidate = if name.trim().is_empty() {
146 format!("column{}", index + 1)
147 } else {
148 name.clone()
149 };
150 let base = candidate.clone();
151 let mut suffix = 2usize;
152 while used.contains(&candidate.to_ascii_lowercase())
153 || candidate.eq_ignore_ascii_case(Self::SELECT_INTO_ROWID_COLUMN)
154 {
155 candidate = format!("{base}_{suffix}");
156 suffix += 1;
157 }
158 used.insert(candidate.to_ascii_lowercase());
159 names.push(candidate);
160 }
161 names
162 }
163
164 fn infer_select_into_type(
165 column_name: &str,
166 values: &[Vec<Value>],
167 index: usize,
168 ) -> Result<SqlTypeName> {
169 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
170 enum NumericKind {
171 Int,
172 Int64,
173 Int128,
174 UInt,
175 UInt64,
176 UInt128,
177 Float32,
178 Float,
179 Decimal,
180 }
181
182 #[derive(Debug, Clone)]
183 enum InferredKind {
184 Numeric(NumericKind),
185 String { saw_enum: bool, values: Vec<String> },
186 Boolean,
187 Blob,
188 Date,
189 Time,
190 DateTime,
191 TimeWithTimeZone,
192 }
193
194 impl InferredKind {
195 fn absorb(self, value: &Value, column_name: &str) -> Result<Self> {
196 use InferredKind::*;
197 use NumericKind::*;
198 match (self, value) {
199 (kind, Value::Null) => Ok(kind),
200 (_, Value::IntervalYearMonth(_)) | (_, Value::IntervalDaySecond(_)) => {
201 Err(HematiteError::ParseError(format!(
202 "SELECT INTO cannot infer a stored column type for interval-valued column '{}'",
203 column_name
204 )))
205 }
206 (Numeric(Int), Value::Integer(_)) => Ok(Numeric(Int)),
207 (Numeric(Int), Value::BigInt(_))
208 | (Numeric(Int64), Value::Integer(_))
209 | (Numeric(Int64), Value::BigInt(_)) => Ok(Numeric(Int64)),
210 (Numeric(Int), Value::Int128(_))
211 | (Numeric(Int64), Value::Int128(_))
212 | (Numeric(Int128), Value::Integer(_))
213 | (Numeric(Int128), Value::BigInt(_))
214 | (Numeric(Int128), Value::Int128(_)) => Ok(Numeric(Int128)),
215 (Numeric(UInt), Value::UInteger(_)) => Ok(Numeric(UInt)),
216 (Numeric(UInt), Value::UBigInt(_))
217 | (Numeric(UInt64), Value::UInteger(_))
218 | (Numeric(UInt64), Value::UBigInt(_)) => Ok(Numeric(UInt64)),
219 (Numeric(UInt), Value::UInt128(_))
220 | (Numeric(UInt64), Value::UInt128(_))
221 | (Numeric(UInt128), Value::UInteger(_))
222 | (Numeric(UInt128), Value::UBigInt(_))
223 | (Numeric(UInt128), Value::UInt128(_)) => Ok(Numeric(UInt128)),
224 (Numeric(Int), Value::UInteger(_))
225 | (Numeric(Int64), Value::UInteger(_))
226 | (Numeric(Int128), Value::UInteger(_))
227 | (Numeric(UInt), Value::Integer(_))
228 | (Numeric(UInt), Value::BigInt(_))
229 | (Numeric(UInt), Value::Int128(_))
230 | (Numeric(UInt64), Value::Integer(_))
231 | (Numeric(UInt64), Value::BigInt(_))
232 | (Numeric(UInt64), Value::Int128(_))
233 | (Numeric(UInt128), Value::Integer(_))
234 | (Numeric(UInt128), Value::BigInt(_))
235 | (Numeric(UInt128), Value::Int128(_))
236 | (Numeric(Int64), Value::UBigInt(_))
237 | (Numeric(Int128), Value::UBigInt(_))
238 | (Numeric(Int128), Value::UInt128(_))
239 => Ok(Numeric(Decimal)),
240 (Numeric(Int), Value::Float32(_))
241 | (Numeric(Int64), Value::Float32(_))
242 | (Numeric(Int128), Value::Float32(_))
243 | (Numeric(UInt), Value::Float32(_))
244 | (Numeric(UInt64), Value::Float32(_))
245 | (Numeric(UInt128), Value::Float32(_))
246 | (Numeric(Float32), Value::Integer(_))
247 | (Numeric(Float32), Value::BigInt(_))
248 | (Numeric(Float32), Value::Int128(_))
249 | (Numeric(Float32), Value::UInteger(_))
250 | (Numeric(Float32), Value::UBigInt(_))
251 | (Numeric(Float32), Value::UInt128(_))
252 | (Numeric(Float32), Value::Float32(_)) => Ok(Numeric(Float32)),
253 (Numeric(Int), Value::Float(_))
254 | (Numeric(Int64), Value::Float(_))
255 | (Numeric(Int128), Value::Float(_))
256 | (Numeric(UInt), Value::Float(_))
257 | (Numeric(UInt64), Value::Float(_))
258 | (Numeric(UInt128), Value::Float(_))
259 | (Numeric(Float32), Value::Float(_))
260 | (Numeric(Float), Value::Integer(_))
261 | (Numeric(Float), Value::BigInt(_))
262 | (Numeric(Float), Value::Int128(_))
263 | (Numeric(Float), Value::UInteger(_))
264 | (Numeric(Float), Value::UBigInt(_))
265 | (Numeric(Float), Value::UInt128(_))
266 | (Numeric(Float), Value::Float32(_))
267 | (Numeric(Float), Value::Float(_)) => Ok(Numeric(Float)),
268 (Numeric(Int), Value::Decimal(_))
269 | (Numeric(Int64), Value::Decimal(_))
270 | (Numeric(Int128), Value::Decimal(_))
271 | (Numeric(UInt), Value::Decimal(_))
272 | (Numeric(UInt64), Value::Decimal(_))
273 | (Numeric(UInt128), Value::Decimal(_))
274 | (Numeric(Float32), Value::Decimal(_))
275 | (Numeric(Float), Value::Decimal(_))
276 | (Numeric(Decimal), Value::Integer(_))
277 | (Numeric(Decimal), Value::BigInt(_))
278 | (Numeric(Decimal), Value::Int128(_))
279 | (Numeric(Decimal), Value::UInteger(_))
280 | (Numeric(Decimal), Value::UBigInt(_))
281 | (Numeric(Decimal), Value::UInt128(_))
282 | (Numeric(Decimal), Value::Float32(_))
283 | (Numeric(Decimal), Value::Float(_))
284 | (Numeric(Decimal), Value::Decimal(_)) => Ok(Numeric(Decimal)),
285 (
286 String {
287 saw_enum,
288 mut values,
289 },
290 Value::Text(text),
291 ) => {
292 if !values.iter().any(|candidate| candidate == text) {
293 values.push(text.clone());
294 }
295 Ok(String { saw_enum, values })
296 }
297 (
298 String {
299 saw_enum: _,
300 mut values,
301 },
302 Value::Enum(text),
303 ) => {
304 if !values.iter().any(|candidate| candidate == text) {
305 values.push(text.clone());
306 }
307 Ok(String {
308 saw_enum: true,
309 values,
310 })
311 }
312 (Blob, Value::Blob(_)) => Ok(Blob),
313 (Blob, Value::Text(_)) => Ok(Blob),
314 (Date, Value::Date(_)) => Ok(Date),
315 (Time, Value::Time(_)) => Ok(Time),
316 (DateTime, Value::DateTime(_)) => Ok(DateTime),
317 (TimeWithTimeZone, Value::TimeWithTimeZone(_)) => Ok(TimeWithTimeZone),
318 (Boolean, Value::Boolean(_)) => Ok(Boolean),
319 (left, right) => Err(HematiteError::ParseError(format!(
320 "SELECT INTO cannot infer a stable column type for '{}': {:?} cannot be combined with {:?}",
321 column_name, left, right
322 ))),
323 }
324 }
325
326 fn from_value(value: &Value, column_name: &str) -> Result<Option<Self>> {
327 use InferredKind::*;
328 use NumericKind::*;
329 let inferred = match value {
330 Value::Null => return Ok(None),
331 Value::Integer(_) => Numeric(Int),
332 Value::BigInt(_) => Numeric(Int64),
333 Value::Int128(_) => Numeric(Int128),
334 Value::UInteger(_) => Numeric(UInt),
335 Value::UBigInt(_) => Numeric(UInt64),
336 Value::UInt128(_) => Numeric(UInt128),
337 Value::Float32(_) => Numeric(Float32),
338 Value::Float(_) => Numeric(Float),
339 Value::Decimal(_) => Numeric(Decimal),
340 Value::Text(text) => String {
341 saw_enum: false,
342 values: vec![text.clone()],
343 },
344 Value::Enum(text) => String {
345 saw_enum: true,
346 values: vec![text.clone()],
347 },
348 Value::Boolean(_) => Boolean,
349 Value::Blob(_) => Blob,
350 Value::Date(_) => Date,
351 Value::Time(_) => Time,
352 Value::DateTime(_) => DateTime,
353 Value::TimeWithTimeZone(_) => TimeWithTimeZone,
354 Value::IntervalYearMonth(_) | Value::IntervalDaySecond(_) => {
355 return Err(HematiteError::ParseError(format!(
356 "SELECT INTO cannot infer a stored column type for interval-valued column '{}'",
357 column_name
358 )))
359 }
360 };
361 Ok(Some(inferred))
362 }
363
364 fn into_sql_type(self) -> SqlTypeName {
365 match self {
366 InferredKind::Numeric(NumericKind::Int) => SqlTypeName::Int,
367 InferredKind::Numeric(NumericKind::Int64) => SqlTypeName::Int64,
368 InferredKind::Numeric(NumericKind::Int128) => SqlTypeName::Int128,
369 InferredKind::Numeric(NumericKind::UInt) => SqlTypeName::UInt,
370 InferredKind::Numeric(NumericKind::UInt64) => SqlTypeName::UInt64,
371 InferredKind::Numeric(NumericKind::UInt128) => SqlTypeName::UInt128,
372 InferredKind::Numeric(NumericKind::Float32) => SqlTypeName::Float32,
373 InferredKind::Numeric(NumericKind::Float) => SqlTypeName::Float,
374 InferredKind::Numeric(NumericKind::Decimal) => SqlTypeName::Decimal {
375 precision: None,
376 scale: None,
377 },
378 InferredKind::String {
379 saw_enum: true,
380 values,
381 } => SqlTypeName::Enum(values),
382 InferredKind::String { .. } => SqlTypeName::Text,
383 InferredKind::Boolean => SqlTypeName::Boolean,
384 InferredKind::Blob => SqlTypeName::Blob,
385 InferredKind::Date => SqlTypeName::Date,
386 InferredKind::Time => SqlTypeName::Time,
387 InferredKind::DateTime => SqlTypeName::DateTime,
388 InferredKind::TimeWithTimeZone => SqlTypeName::TimeWithTimeZone,
389 }
390 }
391 }
392
393 let mut inferred = None;
394 for row in values {
395 let Some(value) = row.get(index) else {
396 return Err(HematiteError::InternalError(format!(
397 "SELECT INTO result row is missing projected column {}",
398 index
399 )));
400 };
401
402 inferred = match (inferred, InferredKind::from_value(value, column_name)?) {
403 (None, None) => None,
404 (None, Some(kind)) => Some(kind),
405 (Some(kind), None) => Some(kind),
406 (Some(kind), Some(_)) => Some(kind.absorb(value, column_name)?),
407 };
408 }
409
410 Ok(inferred
411 .map(InferredKind::into_sql_type)
412 .unwrap_or(SqlTypeName::Text))
413 }
414
415 fn infer_select_into_columns(result: &QueryResult) -> Result<Vec<ColumnDefinition>> {
416 let column_names = Self::select_into_column_names(result);
417 column_names
418 .iter()
419 .enumerate()
420 .map(|(index, name)| {
421 Ok(ColumnDefinition {
422 name: name.clone(),
423 data_type: Self::infer_select_into_type(name, &result.rows, index)?,
424 character_set: None,
425 collation: None,
426 nullable: true,
427 primary_key: false,
428 auto_increment: false,
429 unique: false,
430 default_value: None,
431 check_constraint: None,
432 references: None,
433 })
434 })
435 .collect()
436 }
437
438 fn lock_catalog(&self) -> Result<MutexGuard<'_, Catalog>> {
439 self.catalog.lock().map_err(|_| {
440 HematiteError::InternalError("SQL connection catalog mutex is poisoned".to_string())
441 })
442 }
443
444 pub fn new(database_path: &str) -> Result<Self> {
445 let catalog = Catalog::open_or_create(database_path)?;
446 Ok(Self {
447 catalog: Arc::new(Mutex::new(catalog)),
448 transaction: None,
449 trigger_depth: 0,
450 })
451 }
452
453 pub fn new_in_memory() -> Result<Self> {
454 let catalog = Catalog::open_in_memory()?;
455 Ok(Self {
456 catalog: Arc::new(Mutex::new(catalog)),
457 transaction: None,
458 trigger_depth: 0,
459 })
460 }
461
462 fn parse_statement(sql: &str) -> Result<crate::parser::ast::Statement> {
463 let mut lexer = Lexer::new(sql.to_string());
464 lexer.tokenize()?;
465
466 let mut parser = Parser::new(lexer.get_tokens().to_vec());
467 parser.parse()
468 }
469
470 fn parse_select_sql(sql: &str) -> Result<SelectStatement> {
471 match Self::parse_statement(&format!("{sql};"))? {
472 Statement::Select(select) => Ok(select),
473 other => Err(HematiteError::ParseError(format!(
474 "Expected stored view query to be SELECT, found {:?}",
475 other
476 ))),
477 }
478 }
479
480 fn expand_views_in_statement(statement: Statement, schema: &Schema) -> Result<Statement> {
481 match statement {
482 Statement::Explain(explain) => {
483 Ok(Statement::Explain(crate::parser::ast::ExplainStatement {
484 statement: Box::new(Self::expand_views_in_statement(
485 *explain.statement,
486 schema,
487 )?),
488 }))
489 }
490 Statement::Select(select) => Ok(Statement::Select(Self::expand_views_in_select(
491 select, schema,
492 )?)),
493 Statement::Insert(mut insert) => {
494 if let InsertSource::Select(select) = insert.source {
495 insert.source = InsertSource::Select(Box::new(Self::expand_views_in_select(
496 *select, schema,
497 )?));
498 }
499 Ok(Statement::Insert(insert))
500 }
501 Statement::CreateView(mut create_view) => {
502 create_view.query = Self::expand_views_in_select(create_view.query, schema)?;
503 Ok(Statement::CreateView(create_view))
504 }
505 other => Ok(other),
506 }
507 }
508
509 fn expand_views_in_select(
510 mut select: SelectStatement,
511 schema: &Schema,
512 ) -> Result<SelectStatement> {
513 for cte in &mut select.with_clause {
514 cte.query = Box::new(Self::expand_views_in_select((*cte.query).clone(), schema)?);
515 }
516 let original_from = select.from.clone();
517 let select_context = select.clone();
518 select.from =
519 Self::expand_views_in_table_reference(original_from, &select_context, schema)?;
520 if let Some(where_clause) = &mut select.where_clause {
521 Self::expand_views_in_where_clause(where_clause, schema)?;
522 }
523 for expr in &mut select.group_by {
524 Self::expand_views_in_expression(expr, schema)?;
525 }
526 if let Some(having_clause) = &mut select.having_clause {
527 Self::expand_views_in_where_clause(having_clause, schema)?;
528 }
529 if let Some(set_operation) = &mut select.set_operation {
530 set_operation.right = Box::new(Self::expand_views_in_select(
531 (*set_operation.right).clone(),
532 schema,
533 )?);
534 }
535 for item in &mut select.columns {
536 if let crate::parser::ast::SelectItem::Expression(expr) = item {
537 Self::expand_views_in_expression(expr, schema)?;
538 }
539 }
540 Ok(select)
541 }
542
543 fn expand_views_in_table_reference(
544 from: TableReference,
545 select: &SelectStatement,
546 schema: &Schema,
547 ) -> Result<TableReference> {
548 match from {
549 TableReference::Table(table_name, alias) => {
550 if select.lookup_cte(&table_name).is_some()
551 || schema.get_table_by_name(&table_name).is_some()
552 {
553 Ok(TableReference::Table(table_name, alias))
554 } else if let Some(view) = schema.view(&table_name) {
555 let subquery = Self::expand_views_in_select(
556 Self::parse_select_sql(&view.query_sql)?,
557 schema,
558 )?;
559 Ok(TableReference::Derived {
560 subquery: Box::new(subquery),
561 alias: alias.unwrap_or(table_name),
562 })
563 } else {
564 Ok(TableReference::Table(table_name, alias))
565 }
566 }
567 TableReference::Derived { subquery, alias } => Ok(TableReference::Derived {
568 subquery: Box::new(Self::expand_views_in_select(*subquery, schema)?),
569 alias,
570 }),
571 TableReference::CrossJoin(left, right) => Ok(TableReference::CrossJoin(
572 Box::new(Self::expand_views_in_table_reference(
573 *left, select, schema,
574 )?),
575 Box::new(Self::expand_views_in_table_reference(
576 *right, select, schema,
577 )?),
578 )),
579 TableReference::InnerJoin {
580 left,
581 right,
582 mut on,
583 } => {
584 Self::expand_views_in_condition(&mut on, schema)?;
585 Ok(TableReference::InnerJoin {
586 left: Box::new(Self::expand_views_in_table_reference(
587 *left, select, schema,
588 )?),
589 right: Box::new(Self::expand_views_in_table_reference(
590 *right, select, schema,
591 )?),
592 on,
593 })
594 }
595 TableReference::LeftJoin {
596 left,
597 right,
598 mut on,
599 } => {
600 Self::expand_views_in_condition(&mut on, schema)?;
601 Ok(TableReference::LeftJoin {
602 left: Box::new(Self::expand_views_in_table_reference(
603 *left, select, schema,
604 )?),
605 right: Box::new(Self::expand_views_in_table_reference(
606 *right, select, schema,
607 )?),
608 on,
609 })
610 }
611 TableReference::RightJoin {
612 left,
613 right,
614 mut on,
615 } => {
616 Self::expand_views_in_condition(&mut on, schema)?;
617 Ok(TableReference::RightJoin {
618 left: Box::new(Self::expand_views_in_table_reference(
619 *left, select, schema,
620 )?),
621 right: Box::new(Self::expand_views_in_table_reference(
622 *right, select, schema,
623 )?),
624 on,
625 })
626 }
627 TableReference::FullOuterJoin {
628 left,
629 right,
630 mut on,
631 } => {
632 Self::expand_views_in_condition(&mut on, schema)?;
633 Ok(TableReference::FullOuterJoin {
634 left: Box::new(Self::expand_views_in_table_reference(
635 *left, select, schema,
636 )?),
637 right: Box::new(Self::expand_views_in_table_reference(
638 *right, select, schema,
639 )?),
640 on,
641 })
642 }
643 }
644 }
645
646 fn expand_views_in_where_clause(where_clause: &mut WhereClause, schema: &Schema) -> Result<()> {
647 for condition in &mut where_clause.conditions {
648 Self::expand_views_in_condition(condition, schema)?;
649 }
650 Ok(())
651 }
652
653 fn expand_views_in_condition(condition: &mut Condition, schema: &Schema) -> Result<()> {
654 let mut expand = |subquery: &mut SelectStatement| -> Result<()> {
655 *subquery = Self::expand_views_in_select(subquery.clone(), schema)?;
656 Ok(())
657 };
658 Self::rewrite_nested_subqueries_in_condition(condition, &mut expand)
659 }
660
661 fn expand_views_in_expression(expr: &mut Expression, schema: &Schema) -> Result<()> {
662 let mut expand = |subquery: &mut SelectStatement| -> Result<()> {
663 *subquery = Self::expand_views_in_select(subquery.clone(), schema)?;
664 Ok(())
665 };
666 Self::rewrite_nested_subqueries_in_expression(expr, &mut expand)
667 }
668
669 fn normalize_statement(statement: Statement, schema: &Schema) -> Result<Statement> {
670 let mut statement = Self::expand_views_in_statement(statement, schema)?;
671 Self::rewrite_select_aliases_in_statement(&mut statement, schema)?;
672 Ok(statement)
673 }
674
675 fn rewrite_select_aliases_in_statement(
676 statement: &mut Statement,
677 schema: &Schema,
678 ) -> Result<()> {
679 match statement {
680 Statement::Explain(explain) => {
681 Self::rewrite_select_aliases_in_statement(&mut explain.statement, schema)
682 }
683 Statement::Select(select) => Self::rewrite_select_aliases_in_select(select, schema),
684 Statement::Insert(insert) => {
685 if let InsertSource::Select(select) = &mut insert.source {
686 Self::rewrite_select_aliases_in_select(select, schema)?;
687 }
688 Ok(())
689 }
690 Statement::CreateView(create_view) => {
691 Self::rewrite_select_aliases_in_select(&mut create_view.query, schema)
692 }
693 _ => Ok(()),
694 }
695 }
696
697 fn rewrite_select_aliases_in_select(
698 select: &mut SelectStatement,
699 schema: &Schema,
700 ) -> Result<()> {
701 for cte in &mut select.with_clause {
702 if !cte.recursive {
703 Self::rewrite_select_aliases_in_select(&mut cte.query, schema)?;
704 }
705 }
706
707 Self::rewrite_select_aliases_in_table_reference(&mut select.from, schema)?;
708
709 for item in &mut select.columns {
710 match item {
711 crate::parser::ast::SelectItem::Expression(expr) => {
712 Self::rewrite_nested_select_aliases_in_expression(expr, schema)?;
713 }
714 crate::parser::ast::SelectItem::Window { window, .. } => {
715 for expr in &mut window.partition_by {
716 Self::rewrite_nested_select_aliases_in_expression(expr, schema)?;
717 }
718 }
719 crate::parser::ast::SelectItem::Wildcard
720 | crate::parser::ast::SelectItem::Column(_)
721 | crate::parser::ast::SelectItem::CountAll
722 | crate::parser::ast::SelectItem::Aggregate { .. } => {}
723 }
724 }
725
726 let alias_map = Self::where_alias_map(select);
727 let source_columns = source_column_names(select, schema)?
728 .into_iter()
729 .collect::<HashSet<_>>();
730
731 if let Some(where_clause) = &mut select.where_clause {
732 for condition in &mut where_clause.conditions {
733 Self::rewrite_where_aliases_in_condition(
734 condition,
735 &alias_map,
736 &source_columns,
737 &mut HashSet::new(),
738 )?;
739 }
740 }
741
742 for expr in &mut select.group_by {
743 Self::rewrite_nested_select_aliases_in_expression(expr, schema)?;
744 }
745
746 if let Some(having_clause) = &mut select.having_clause {
747 for condition in &mut having_clause.conditions {
748 Self::rewrite_nested_select_aliases_in_condition(condition, schema)?;
749 }
750 }
751
752 if let Some(set_operation) = &mut select.set_operation {
753 Self::rewrite_select_aliases_in_select(&mut set_operation.right, schema)?;
754 }
755
756 Ok(())
757 }
758
759 fn rewrite_select_aliases_in_table_reference(
760 from: &mut TableReference,
761 schema: &Schema,
762 ) -> Result<()> {
763 match from {
764 TableReference::Derived { subquery, .. } => {
765 Self::rewrite_select_aliases_in_select(subquery, schema)
766 }
767 TableReference::CrossJoin(left, right) => {
768 Self::rewrite_select_aliases_in_table_reference(left, schema)?;
769 Self::rewrite_select_aliases_in_table_reference(right, schema)
770 }
771 TableReference::InnerJoin { left, right, on }
772 | TableReference::LeftJoin { left, right, on }
773 | TableReference::RightJoin { left, right, on }
774 | TableReference::FullOuterJoin { left, right, on } => {
775 Self::rewrite_select_aliases_in_table_reference(left, schema)?;
776 Self::rewrite_select_aliases_in_table_reference(right, schema)?;
777 Self::rewrite_nested_select_aliases_in_condition(on, schema)
778 }
779 TableReference::Table(_, _) => Ok(()),
780 }
781 }
782
783 fn rewrite_nested_select_aliases_in_condition(
784 condition: &mut Condition,
785 schema: &Schema,
786 ) -> Result<()> {
787 let mut rewrite = |subquery: &mut SelectStatement| {
788 Self::rewrite_select_aliases_in_select(subquery, schema)
789 };
790 Self::rewrite_nested_subqueries_in_condition(condition, &mut rewrite)
791 }
792
793 fn rewrite_nested_select_aliases_in_expression(
794 expr: &mut Expression,
795 schema: &Schema,
796 ) -> Result<()> {
797 let mut rewrite = |subquery: &mut SelectStatement| {
798 Self::rewrite_select_aliases_in_select(subquery, schema)
799 };
800 Self::rewrite_nested_subqueries_in_expression(expr, &mut rewrite)
801 }
802
803 fn rewrite_nested_subqueries_in_condition<F>(
804 condition: &mut Condition,
805 on_subquery: &mut F,
806 ) -> Result<()>
807 where
808 F: FnMut(&mut SelectStatement) -> Result<()>,
809 {
810 match condition {
811 Condition::Comparison { left, right, .. } => {
812 Self::rewrite_nested_subqueries_in_expression(left, on_subquery)?;
813 Self::rewrite_nested_subqueries_in_expression(right, on_subquery)?;
814 }
815 Condition::InList { expr, values, .. } => {
816 Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)?;
817 for value in values {
818 Self::rewrite_nested_subqueries_in_expression(value, on_subquery)?;
819 }
820 }
821 Condition::InSubquery { expr, subquery, .. } => {
822 Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)?;
823 on_subquery(subquery)?;
824 }
825 Condition::Between {
826 expr, lower, upper, ..
827 } => {
828 Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)?;
829 Self::rewrite_nested_subqueries_in_expression(lower, on_subquery)?;
830 Self::rewrite_nested_subqueries_in_expression(upper, on_subquery)?;
831 }
832 Condition::Like { expr, pattern, .. } => {
833 Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)?;
834 Self::rewrite_nested_subqueries_in_expression(pattern, on_subquery)?;
835 }
836 Condition::Exists { subquery, .. } => on_subquery(subquery)?,
837 Condition::NullCheck { expr, .. } => {
838 Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)?;
839 }
840 Condition::Not(inner) => {
841 Self::rewrite_nested_subqueries_in_condition(inner, on_subquery)?
842 }
843 Condition::Logical { left, right, .. } => {
844 Self::rewrite_nested_subqueries_in_condition(left, on_subquery)?;
845 Self::rewrite_nested_subqueries_in_condition(right, on_subquery)?;
846 }
847 }
848 Ok(())
849 }
850
851 fn rewrite_nested_subqueries_in_expression<F>(
852 expr: &mut Expression,
853 on_subquery: &mut F,
854 ) -> Result<()>
855 where
856 F: FnMut(&mut SelectStatement) -> Result<()>,
857 {
858 match expr {
859 Expression::ScalarSubquery(subquery) => on_subquery(subquery),
860 Expression::Cast { expr, .. }
861 | Expression::UnaryMinus(expr)
862 | Expression::UnaryNot(expr)
863 | Expression::NullCheck { expr, .. } => {
864 Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)
865 }
866 Expression::Case {
867 branches,
868 else_expr,
869 } => {
870 for branch in branches {
871 Self::rewrite_nested_subqueries_in_expression(
872 &mut branch.condition,
873 on_subquery,
874 )?;
875 Self::rewrite_nested_subqueries_in_expression(&mut branch.result, on_subquery)?;
876 }
877 if let Some(else_expr) = else_expr {
878 Self::rewrite_nested_subqueries_in_expression(else_expr, on_subquery)?;
879 }
880 Ok(())
881 }
882 Expression::ScalarFunctionCall { args, .. } => {
883 for arg in args {
884 Self::rewrite_nested_subqueries_in_expression(arg, on_subquery)?;
885 }
886 Ok(())
887 }
888 Expression::Binary { left, right, .. }
889 | Expression::Comparison { left, right, .. }
890 | Expression::Logical { left, right, .. } => {
891 Self::rewrite_nested_subqueries_in_expression(left, on_subquery)?;
892 Self::rewrite_nested_subqueries_in_expression(right, on_subquery)
893 }
894 Expression::InList { expr, values, .. } => {
895 Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)?;
896 for value in values {
897 Self::rewrite_nested_subqueries_in_expression(value, on_subquery)?;
898 }
899 Ok(())
900 }
901 Expression::InSubquery { expr, subquery, .. } => {
902 Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)?;
903 on_subquery(subquery)
904 }
905 Expression::Between {
906 expr, lower, upper, ..
907 } => {
908 Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)?;
909 Self::rewrite_nested_subqueries_in_expression(lower, on_subquery)?;
910 Self::rewrite_nested_subqueries_in_expression(upper, on_subquery)
911 }
912 Expression::Like { expr, pattern, .. } => {
913 Self::rewrite_nested_subqueries_in_expression(expr, on_subquery)?;
914 Self::rewrite_nested_subqueries_in_expression(pattern, on_subquery)
915 }
916 Expression::Exists { subquery, .. } => on_subquery(subquery),
917 Expression::AggregateCall { .. }
918 | Expression::Column(_)
919 | Expression::Literal(_)
920 | Expression::IntervalLiteral { .. }
921 | Expression::Parameter(_) => Ok(()),
922 }
923 }
924
925 fn where_alias_map(select: &SelectStatement) -> HashMap<String, Expression> {
926 let mut aliases = HashMap::new();
927 for (index, alias) in select.column_aliases.iter().enumerate() {
928 let Some(alias) = alias.as_ref() else {
929 continue;
930 };
931 let Some(item) = select.columns.get(index) else {
932 continue;
933 };
934
935 let replacement = match item {
936 crate::parser::ast::SelectItem::Column(name) => Expression::Column(name.clone()),
937 crate::parser::ast::SelectItem::Expression(expr) => expr.clone(),
938 _ => continue,
939 };
940 aliases.insert(alias.clone(), replacement);
941 }
942 aliases
943 }
944
945 fn rewrite_where_aliases_in_condition(
946 condition: &mut Condition,
947 aliases: &HashMap<String, Expression>,
948 source_columns: &HashSet<String>,
949 active_aliases: &mut HashSet<String>,
950 ) -> Result<()> {
951 match condition {
952 Condition::Comparison { left, right, .. } => {
953 Self::rewrite_where_aliases_in_expression(
954 left,
955 aliases,
956 source_columns,
957 active_aliases,
958 )?;
959 Self::rewrite_where_aliases_in_expression(
960 right,
961 aliases,
962 source_columns,
963 active_aliases,
964 )?;
965 }
966 Condition::InList { expr, values, .. } => {
967 Self::rewrite_where_aliases_in_expression(
968 expr,
969 aliases,
970 source_columns,
971 active_aliases,
972 )?;
973 for value in values {
974 Self::rewrite_where_aliases_in_expression(
975 value,
976 aliases,
977 source_columns,
978 active_aliases,
979 )?;
980 }
981 }
982 Condition::InSubquery { expr, .. } => {
983 Self::rewrite_where_aliases_in_expression(
984 expr,
985 aliases,
986 source_columns,
987 active_aliases,
988 )?;
989 }
990 Condition::Between {
991 expr, lower, upper, ..
992 } => {
993 Self::rewrite_where_aliases_in_expression(
994 expr,
995 aliases,
996 source_columns,
997 active_aliases,
998 )?;
999 Self::rewrite_where_aliases_in_expression(
1000 lower,
1001 aliases,
1002 source_columns,
1003 active_aliases,
1004 )?;
1005 Self::rewrite_where_aliases_in_expression(
1006 upper,
1007 aliases,
1008 source_columns,
1009 active_aliases,
1010 )?;
1011 }
1012 Condition::Like { expr, pattern, .. } => {
1013 Self::rewrite_where_aliases_in_expression(
1014 expr,
1015 aliases,
1016 source_columns,
1017 active_aliases,
1018 )?;
1019 Self::rewrite_where_aliases_in_expression(
1020 pattern,
1021 aliases,
1022 source_columns,
1023 active_aliases,
1024 )?;
1025 }
1026 Condition::Exists { .. } => {}
1027 Condition::NullCheck { expr, .. } => {
1028 Self::rewrite_where_aliases_in_expression(
1029 expr,
1030 aliases,
1031 source_columns,
1032 active_aliases,
1033 )?;
1034 }
1035 Condition::Not(inner) => {
1036 Self::rewrite_where_aliases_in_condition(
1037 inner,
1038 aliases,
1039 source_columns,
1040 active_aliases,
1041 )?;
1042 }
1043 Condition::Logical { left, right, .. } => {
1044 Self::rewrite_where_aliases_in_condition(
1045 left,
1046 aliases,
1047 source_columns,
1048 active_aliases,
1049 )?;
1050 Self::rewrite_where_aliases_in_condition(
1051 right,
1052 aliases,
1053 source_columns,
1054 active_aliases,
1055 )?;
1056 }
1057 }
1058 Ok(())
1059 }
1060
1061 fn rewrite_where_aliases_in_expression(
1062 expr: &mut Expression,
1063 aliases: &HashMap<String, Expression>,
1064 source_columns: &HashSet<String>,
1065 active_aliases: &mut HashSet<String>,
1066 ) -> Result<()> {
1067 match expr {
1068 Expression::Column(name) => {
1069 if SelectStatement::split_column_reference(name).0.is_some()
1070 || source_columns.contains(name)
1071 {
1072 return Ok(());
1073 }
1074
1075 let Some(replacement) = aliases.get(name).cloned() else {
1076 return Ok(());
1077 };
1078
1079 if !active_aliases.insert(name.clone()) {
1080 return Err(HematiteError::ParseError(format!(
1081 "Select alias '{}' is recursively defined",
1082 name
1083 )));
1084 }
1085
1086 let mut replacement = replacement;
1087 Self::rewrite_where_aliases_in_expression(
1088 &mut replacement,
1089 aliases,
1090 source_columns,
1091 active_aliases,
1092 )?;
1093 active_aliases.remove(name);
1094 *expr = replacement;
1095 }
1096 Expression::Cast { expr, .. }
1097 | Expression::UnaryMinus(expr)
1098 | Expression::UnaryNot(expr)
1099 | Expression::NullCheck { expr, .. } => {
1100 Self::rewrite_where_aliases_in_expression(
1101 expr,
1102 aliases,
1103 source_columns,
1104 active_aliases,
1105 )?;
1106 }
1107 Expression::Case {
1108 branches,
1109 else_expr,
1110 } => {
1111 for branch in branches {
1112 Self::rewrite_where_aliases_in_expression(
1113 &mut branch.condition,
1114 aliases,
1115 source_columns,
1116 active_aliases,
1117 )?;
1118 Self::rewrite_where_aliases_in_expression(
1119 &mut branch.result,
1120 aliases,
1121 source_columns,
1122 active_aliases,
1123 )?;
1124 }
1125 if let Some(else_expr) = else_expr {
1126 Self::rewrite_where_aliases_in_expression(
1127 else_expr,
1128 aliases,
1129 source_columns,
1130 active_aliases,
1131 )?;
1132 }
1133 }
1134 Expression::ScalarFunctionCall { args, .. } => {
1135 for arg in args {
1136 Self::rewrite_where_aliases_in_expression(
1137 arg,
1138 aliases,
1139 source_columns,
1140 active_aliases,
1141 )?;
1142 }
1143 }
1144 Expression::Binary { left, right, .. }
1145 | Expression::Comparison { left, right, .. }
1146 | Expression::Logical { left, right, .. } => {
1147 Self::rewrite_where_aliases_in_expression(
1148 left,
1149 aliases,
1150 source_columns,
1151 active_aliases,
1152 )?;
1153 Self::rewrite_where_aliases_in_expression(
1154 right,
1155 aliases,
1156 source_columns,
1157 active_aliases,
1158 )?;
1159 }
1160 Expression::InList { expr, values, .. } => {
1161 Self::rewrite_where_aliases_in_expression(
1162 expr,
1163 aliases,
1164 source_columns,
1165 active_aliases,
1166 )?;
1167 for value in values {
1168 Self::rewrite_where_aliases_in_expression(
1169 value,
1170 aliases,
1171 source_columns,
1172 active_aliases,
1173 )?;
1174 }
1175 }
1176 Expression::Between {
1177 expr, lower, upper, ..
1178 } => {
1179 Self::rewrite_where_aliases_in_expression(
1180 expr,
1181 aliases,
1182 source_columns,
1183 active_aliases,
1184 )?;
1185 Self::rewrite_where_aliases_in_expression(
1186 lower,
1187 aliases,
1188 source_columns,
1189 active_aliases,
1190 )?;
1191 Self::rewrite_where_aliases_in_expression(
1192 upper,
1193 aliases,
1194 source_columns,
1195 active_aliases,
1196 )?;
1197 }
1198 Expression::Like { expr, pattern, .. } => {
1199 Self::rewrite_where_aliases_in_expression(
1200 expr,
1201 aliases,
1202 source_columns,
1203 active_aliases,
1204 )?;
1205 Self::rewrite_where_aliases_in_expression(
1206 pattern,
1207 aliases,
1208 source_columns,
1209 active_aliases,
1210 )?;
1211 }
1212 Expression::AggregateCall { .. }
1213 | Expression::ScalarSubquery(_)
1214 | Expression::InSubquery { .. }
1215 | Expression::Exists { .. }
1216 | Expression::Literal(_)
1217 | Expression::IntervalLiteral { .. }
1218 | Expression::Parameter(_) => {}
1219 }
1220 Ok(())
1221 }
1222
1223 pub(crate) fn execute_statement(
1224 &mut self,
1225 statement: crate::parser::ast::Statement,
1226 ) -> Result<QueryResult> {
1227 match statement {
1228 crate::parser::ast::Statement::Begin => {
1229 self.begin_active_transaction()?;
1230 return Ok(Self::empty_result());
1231 }
1232 crate::parser::ast::Statement::Commit => {
1233 self.commit_active_transaction()?;
1234 return Ok(Self::empty_result());
1235 }
1236 crate::parser::ast::Statement::Rollback => {
1237 self.rollback_active_transaction()?;
1238 return Ok(Self::empty_result());
1239 }
1240 crate::parser::ast::Statement::Savepoint(name) => {
1241 self.create_savepoint(&name)?;
1242 return Ok(Self::empty_result());
1243 }
1244 crate::parser::ast::Statement::RollbackToSavepoint(name) => {
1245 self.rollback_to_savepoint(&name)?;
1246 return Ok(Self::empty_result());
1247 }
1248 crate::parser::ast::Statement::ReleaseSavepoint(name) => {
1249 self.release_savepoint(&name)?;
1250 return Ok(Self::empty_result());
1251 }
1252 crate::parser::ast::Statement::Explain(explain) => {
1253 return self.execute_explain_statement(*explain.statement);
1254 }
1255 crate::parser::ast::Statement::Describe(describe) => {
1256 return self.execute_describe_statement(&describe.table);
1257 }
1258 crate::parser::ast::Statement::ShowTables => {
1259 return self.execute_show_tables_statement();
1260 }
1261 crate::parser::ast::Statement::ShowViews => {
1262 return self.execute_show_views_statement();
1263 }
1264 crate::parser::ast::Statement::ShowIndexes(table_name) => {
1265 return self.execute_show_indexes_statement(table_name.as_deref());
1266 }
1267 crate::parser::ast::Statement::ShowTriggers(table_name) => {
1268 return self.execute_show_triggers_statement(table_name.as_deref());
1269 }
1270 crate::parser::ast::Statement::ShowCreateTable(table_name) => {
1271 return self.execute_show_create_table_statement(&table_name);
1272 }
1273 crate::parser::ast::Statement::ShowCreateView(view_name) => {
1274 return self.execute_show_create_view_statement(&view_name);
1275 }
1276 crate::parser::ast::Statement::SelectInto(select_into) => {
1277 return self.execute_select_into_statement(select_into);
1278 }
1279 crate::parser::ast::Statement::CreateView(create_view) => {
1280 return self.execute_create_view_statement(create_view);
1281 }
1282 crate::parser::ast::Statement::DropView(drop_view) => {
1283 return self.execute_drop_view_statement(&drop_view.view, drop_view.if_exists);
1284 }
1285 crate::parser::ast::Statement::CreateTrigger(create_trigger) => {
1286 return self.execute_create_trigger_statement(create_trigger);
1287 }
1288 crate::parser::ast::Statement::DropTrigger(drop_trigger) => {
1289 return self
1290 .execute_drop_trigger_statement(&drop_trigger.trigger, drop_trigger.if_exists);
1291 }
1292 _ => {}
1293 }
1294
1295 if statement.is_read_only() {
1296 return self.execute_read_statement(statement);
1297 }
1298
1299 self.execute_mutating_statement(statement)
1300 }
1301
1302 fn execute_explain_statement(
1303 &mut self,
1304 statement: crate::parser::ast::Statement,
1305 ) -> Result<QueryResult> {
1306 let statement = match statement {
1307 Statement::SelectInto(select_into) => Statement::Select(select_into.query),
1308 other => other,
1309 };
1310 let (schema, table_row_counts) = self.read_planning_state()?;
1311 let statement = Self::expand_views_in_statement(statement, &schema)?;
1312 let planner = QueryPlanner::new(schema).with_table_row_counts(table_row_counts);
1313 let plan = planner.plan(statement)?;
1314 Ok(QueryResult {
1315 affected_rows: 0,
1316 columns: vec!["kind".to_string(), "detail".to_string()],
1317 rows: vec![
1318 vec![
1319 Value::Text("node".to_string()),
1320 Value::Text(format!("{:?}", plan.node)),
1321 ],
1322 vec![
1323 Value::Text("estimated_cost".to_string()),
1324 Value::Text(format!("{:.2}", plan.estimated_cost)),
1325 ],
1326 ],
1327 })
1328 }
1329
1330 fn execute_describe_statement(&mut self, table_name: &str) -> Result<QueryResult> {
1331 let catalog_guard = self.lock_catalog()?;
1332 query_metadata::describe_table(&catalog_guard, table_name)
1333 }
1334
1335 fn execute_show_tables_statement(&mut self) -> Result<QueryResult> {
1336 let catalog_guard = self.lock_catalog()?;
1337 query_metadata::show_tables(&catalog_guard)
1338 }
1339
1340 fn execute_show_views_statement(&mut self) -> Result<QueryResult> {
1341 let catalog_guard = self.lock_catalog()?;
1342 query_metadata::show_views(&catalog_guard)
1343 }
1344
1345 fn execute_show_indexes_statement(&mut self, table_name: Option<&str>) -> Result<QueryResult> {
1346 let catalog_guard = self.lock_catalog()?;
1347 query_metadata::show_indexes(&catalog_guard, table_name)
1348 }
1349
1350 fn execute_show_triggers_statement(&mut self, table_name: Option<&str>) -> Result<QueryResult> {
1351 let catalog_guard = self.lock_catalog()?;
1352 query_metadata::show_triggers(&catalog_guard, table_name)
1353 }
1354
1355 fn execute_show_create_table_statement(&mut self, table_name: &str) -> Result<QueryResult> {
1356 let catalog_guard = self.lock_catalog()?;
1357 query_metadata::show_create_table(&catalog_guard, table_name)
1358 }
1359
1360 fn execute_show_create_view_statement(&mut self, view_name: &str) -> Result<QueryResult> {
1361 let catalog_guard = self.lock_catalog()?;
1362 query_metadata::show_create_view(&catalog_guard, view_name)
1363 }
1364
1365 fn execute_select_into_statement(
1366 &mut self,
1367 statement: SelectIntoStatement,
1368 ) -> Result<QueryResult> {
1369 let (schema, _) = self.read_planning_state()?;
1370 if schema.get_table_by_name(&statement.table).is_some()
1371 || schema.view(&statement.table).is_some()
1372 {
1373 return Err(HematiteError::ParseError(format!(
1374 "Table '{}' already exists",
1375 statement.table
1376 )));
1377 }
1378
1379 let normalized_query =
1380 match Self::normalize_statement(Statement::Select(statement.query.clone()), &schema)? {
1381 Statement::Select(select) => select,
1382 _ => unreachable!("normalized SELECT INTO query should remain a select"),
1383 };
1384 validate_statement(&Statement::Select(normalized_query), &schema)?;
1385
1386 let query_result =
1387 self.execute_read_statement(Statement::Select(statement.query.clone()))?;
1388 let projected_columns = Self::infer_select_into_columns(&query_result)?;
1389 let insert_columns = projected_columns
1390 .iter()
1391 .map(|column| column.name.clone())
1392 .collect::<Vec<_>>();
1393 let synthetic_pk = Self::select_into_synthetic_pk_name(&insert_columns);
1394
1395 let mut create_columns = Vec::with_capacity(projected_columns.len() + 1);
1396 create_columns.push(ColumnDefinition {
1397 name: synthetic_pk,
1398 data_type: SqlTypeName::Int,
1399 character_set: None,
1400 collation: None,
1401 nullable: false,
1402 primary_key: true,
1403 auto_increment: true,
1404 unique: false,
1405 default_value: None,
1406 check_constraint: None,
1407 references: None,
1408 });
1409 create_columns.extend(projected_columns);
1410
1411 let mut implicit_mutation = Some(ImplicitMutation::begin(self)?);
1412 let result: Result<QueryResult> = (|| {
1413 self.execute_mutating_statement_in_scope(
1414 Statement::Create(CreateStatement {
1415 table: statement.table.clone(),
1416 columns: create_columns,
1417 constraints: Vec::new(),
1418 if_not_exists: false,
1419 }),
1420 false,
1421 )?;
1422
1423 let insert_result = self.execute_mutating_statement_in_scope(
1424 Statement::Insert(InsertStatement {
1425 table: statement.table.clone(),
1426 columns: insert_columns,
1427 source: InsertSource::Select(Box::new(statement.query)),
1428 on_duplicate: None,
1429 }),
1430 false,
1431 )?;
1432
1433 Ok(Self::mutation_result(insert_result.affected_rows))
1434 })();
1435
1436 match result {
1437 Ok(result) => {
1438 implicit_mutation
1439 .take()
1440 .expect("implicit mutation should be present")
1441 .commit(self)?;
1442 Ok(result)
1443 }
1444 Err(err) => {
1445 implicit_mutation
1446 .take()
1447 .expect("implicit mutation should be present")
1448 .rollback(self)?;
1449 Err(err)
1450 }
1451 }
1452 }
1453
1454 fn execute_create_view_statement(
1455 &mut self,
1456 statement: crate::parser::ast::CreateViewStatement,
1457 ) -> Result<QueryResult> {
1458 let mut implicit_mutation = Some(ImplicitMutation::begin(self)?);
1459 let result: Result<QueryResult> = (|| {
1460 let mut catalog_guard = self.lock_catalog()?;
1461 let schema = catalog_guard.clone_schema();
1462 let dependencies = statement.query.dependency_names();
1463 if dependencies
1464 .iter()
1465 .any(|dependency| dependency.eq_ignore_ascii_case(&statement.view))
1466 {
1467 return Err(HematiteError::ParseError(format!(
1468 "View '{}' cannot depend on itself",
1469 statement.view
1470 )));
1471 }
1472 let normalized_query = match Self::normalize_statement(
1473 Statement::Select(statement.query.clone()),
1474 &schema,
1475 )? {
1476 Statement::Select(select) => select,
1477 _ => unreachable!("normalized create view query should remain a select"),
1478 };
1479 validate_statement(
1480 &crate::parser::ast::Statement::CreateView(CreateViewStatement {
1481 view: statement.view.clone(),
1482 if_not_exists: statement.if_not_exists,
1483 query: normalized_query.clone(),
1484 }),
1485 &schema,
1486 )?;
1487
1488 if statement.if_not_exists && catalog_guard.get_view(&statement.view)?.is_some() {
1489 Ok(Self::mutation_result(0))
1490 } else {
1491 let column_names = projected_column_names(&normalized_query, &schema)?;
1492
1493 catalog_guard.create_view(crate::catalog::View {
1494 name: statement.view.clone(),
1495 query_sql: statement.query.to_sql(),
1496 column_names,
1497 dependencies,
1498 })?;
1499 Ok(Self::mutation_result(0))
1500 }
1501 })();
1502
1503 match result {
1504 Ok(result) => {
1505 implicit_mutation
1506 .take()
1507 .expect("implicit mutation should be present")
1508 .commit(self)?;
1509 Ok(result)
1510 }
1511 Err(err) => {
1512 implicit_mutation
1513 .take()
1514 .expect("implicit mutation should be present")
1515 .rollback(self)?;
1516 Err(err)
1517 }
1518 }
1519 }
1520
1521 fn execute_drop_view_statement(
1522 &mut self,
1523 view_name: &str,
1524 if_exists: bool,
1525 ) -> Result<QueryResult> {
1526 let mut implicit_mutation = Some(ImplicitMutation::begin(self)?);
1527 let result: Result<QueryResult> = (|| {
1528 let mut catalog_guard = self.lock_catalog()?;
1529 if if_exists && catalog_guard.get_view(view_name)?.is_none() {
1530 Ok(Self::mutation_result(0))
1531 } else {
1532 catalog_guard.drop_view(view_name)?;
1533 Ok(Self::mutation_result(0))
1534 }
1535 })();
1536
1537 match result {
1538 Ok(result) => {
1539 implicit_mutation
1540 .take()
1541 .expect("implicit mutation should be present")
1542 .commit(self)?;
1543 Ok(result)
1544 }
1545 Err(err) => {
1546 implicit_mutation
1547 .take()
1548 .expect("implicit mutation should be present")
1549 .rollback(self)?;
1550 Err(err)
1551 }
1552 }
1553 }
1554
1555 fn execute_create_trigger_statement(
1556 &mut self,
1557 statement: crate::parser::ast::CreateTriggerStatement,
1558 ) -> Result<QueryResult> {
1559 let mut implicit_mutation = Some(ImplicitMutation::begin(self)?);
1560 let result: Result<QueryResult> = (|| {
1561 let mut catalog_guard = self.lock_catalog()?;
1562 let schema = catalog_guard.clone_schema();
1563 validate_statement(
1564 &crate::parser::ast::Statement::CreateTrigger(statement.clone()),
1565 &schema,
1566 )?;
1567
1568 catalog_guard.create_trigger(crate::catalog::Trigger {
1569 name: statement.trigger.clone(),
1570 table_name: statement.table.clone(),
1571 event: match statement.event {
1572 TriggerEvent::Insert => crate::catalog::TriggerEvent::Insert,
1573 TriggerEvent::Update => crate::catalog::TriggerEvent::Update,
1574 TriggerEvent::Delete => crate::catalog::TriggerEvent::Delete,
1575 },
1576 body_sql: statement.body.to_sql(),
1577 old_alias: match statement.event {
1578 TriggerEvent::Insert => None,
1579 TriggerEvent::Update | TriggerEvent::Delete => Some("OLD".to_string()),
1580 },
1581 new_alias: match statement.event {
1582 TriggerEvent::Delete => None,
1583 TriggerEvent::Insert | TriggerEvent::Update => Some("NEW".to_string()),
1584 },
1585 })?;
1586 Ok(Self::mutation_result(0))
1587 })();
1588
1589 match result {
1590 Ok(result) => {
1591 implicit_mutation
1592 .take()
1593 .expect("implicit mutation should be present")
1594 .commit(self)?;
1595 Ok(result)
1596 }
1597 Err(err) => {
1598 implicit_mutation
1599 .take()
1600 .expect("implicit mutation should be present")
1601 .rollback(self)?;
1602 Err(err)
1603 }
1604 }
1605 }
1606
1607 fn execute_drop_trigger_statement(
1608 &mut self,
1609 trigger_name: &str,
1610 if_exists: bool,
1611 ) -> Result<QueryResult> {
1612 let mut implicit_mutation = Some(ImplicitMutation::begin(self)?);
1613 let result: Result<QueryResult> = (|| {
1614 let mut catalog_guard = self.lock_catalog()?;
1615 if if_exists && catalog_guard.get_trigger(trigger_name)?.is_none() {
1616 Ok(Self::mutation_result(0))
1617 } else {
1618 catalog_guard.drop_trigger(trigger_name)?;
1619 Ok(Self::mutation_result(0))
1620 }
1621 })();
1622
1623 match result {
1624 Ok(result) => {
1625 implicit_mutation
1626 .take()
1627 .expect("implicit mutation should be present")
1628 .commit(self)?;
1629 Ok(result)
1630 }
1631 Err(err) => {
1632 implicit_mutation
1633 .take()
1634 .expect("implicit mutation should be present")
1635 .rollback(self)?;
1636 Err(err)
1637 }
1638 }
1639 }
1640
1641 pub(crate) fn execute_statement_result(
1642 &mut self,
1643 statement: crate::parser::ast::Statement,
1644 ) -> Result<ExecutedStatement> {
1645 self.execute_statement(statement)
1646 .map(ExecutedStatement::from_query_result)
1647 }
1648
1649 fn execute_read_statement(
1650 &mut self,
1651 statement: crate::parser::ast::Statement,
1652 ) -> Result<QueryResult> {
1653 let (schema, mut executor) = self.plan_executor(statement)?;
1654
1655 let result = {
1656 let mut catalog_guard = self.lock_catalog()?;
1657 catalog_guard.with_read_engine(|engine| {
1658 let mut ctx = ExecutionContext::for_read(&schema, engine);
1659 executor.execute(&mut ctx)
1660 })?
1661 };
1662
1663 Ok(result)
1664 }
1665
1666 fn execute_mutating_statement(
1667 &mut self,
1668 statement: crate::parser::ast::Statement,
1669 ) -> Result<QueryResult> {
1670 self.execute_mutating_statement_in_scope(statement, true)
1671 }
1672
1673 fn execute_mutating_statement_in_scope(
1674 &mut self,
1675 statement: crate::parser::ast::Statement,
1676 use_implicit_mutation: bool,
1677 ) -> Result<QueryResult> {
1678 let persists_schema = statement.mutates_schema();
1679 let (schema, mut executor) = self.plan_executor(statement)?;
1680 let mut implicit_mutation = if use_implicit_mutation {
1681 Some(ImplicitMutation::begin(self)?)
1682 } else {
1683 None
1684 };
1685
1686 let execution_result = {
1687 let mut catalog_guard = self.lock_catalog()?;
1688 catalog_guard.with_engine(|engine| {
1689 let mut ctx = ExecutionContext::for_mutation(&schema, engine);
1690 let result = executor.execute(&mut ctx)?;
1691 Ok((result, ctx.catalog, ctx.mutation_events))
1692 })
1693 };
1694
1695 match execution_result {
1696 Ok((result, updated_schema, mutation_events)) => {
1697 if persists_schema {
1698 let mut catalog_guard = self.lock_catalog()?;
1699 if let Err(err) = catalog_guard.replace_schema(updated_schema) {
1700 drop(catalog_guard);
1701 if let Some(implicit_mutation) = implicit_mutation.take() {
1702 implicit_mutation.rollback(self)?;
1703 }
1704 return Err(err);
1705 }
1706 }
1707
1708 if let Err(err) = self.fire_triggers(mutation_events) {
1709 if let Some(implicit_mutation) = implicit_mutation.take() {
1710 implicit_mutation.rollback(self)?;
1711 }
1712 return Err(err);
1713 }
1714
1715 if let Some(implicit_mutation) = implicit_mutation.take() {
1716 implicit_mutation.commit(self)?;
1717 }
1718
1719 Ok(result)
1720 }
1721 Err(err) => {
1722 if let Some(implicit_mutation) = implicit_mutation.take() {
1723 implicit_mutation.rollback(self)?;
1724 }
1725 Err(err)
1726 }
1727 }
1728 }
1729
1730 fn plan_executor(
1731 &self,
1732 statement: crate::parser::ast::Statement,
1733 ) -> Result<(Schema, Box<dyn QueryExecutor>)> {
1734 let (schema, table_row_counts) = self.read_planning_state()?;
1735 let statement = Self::normalize_statement(statement, &schema)?;
1736 let planner = QueryPlanner::new(schema.clone()).with_table_row_counts(table_row_counts);
1737 let plan = planner.plan(statement)?;
1738 Ok((schema, plan.into_executor()))
1739 }
1740
1741 fn read_planning_state(&self) -> Result<(Schema, HashMap<String, usize>)> {
1742 let mut catalog_guard = self.lock_catalog()?;
1743 let schema = catalog_guard.clone_schema();
1744 let table_row_counts =
1745 catalog_guard.with_engine(|engine| Ok(Self::collect_table_row_counts(engine)))?;
1746 Ok((schema, table_row_counts))
1747 }
1748
1749 fn collect_table_row_counts(engine: &CatalogEngine) -> HashMap<String, usize> {
1750 engine
1751 .get_table_metadata()
1752 .iter()
1753 .map(|(name, metadata)| (name.clone(), metadata.row_count as usize))
1754 .collect()
1755 }
1756
1757 fn fire_triggers(&mut self, mutation_events: Vec<MutationEvent>) -> Result<()> {
1758 if mutation_events.is_empty() {
1759 return Ok(());
1760 }
1761
1762 if self.trigger_depth >= 32 {
1763 return Err(HematiteError::ParseError(
1764 "Trigger recursion limit exceeded".to_string(),
1765 ));
1766 }
1767
1768 self.trigger_depth += 1;
1769 let result = (|| {
1770 for event in mutation_events {
1771 let (table_name, event_kind, old_row, new_row) = match event {
1772 MutationEvent::Insert {
1773 table_name,
1774 new_row,
1775 } => (
1776 table_name,
1777 crate::catalog::TriggerEvent::Insert,
1778 None,
1779 Some(new_row),
1780 ),
1781 MutationEvent::Update {
1782 table_name,
1783 old_row,
1784 new_row,
1785 } => (
1786 table_name,
1787 crate::catalog::TriggerEvent::Update,
1788 Some(old_row),
1789 Some(new_row),
1790 ),
1791 MutationEvent::Delete {
1792 table_name,
1793 old_row,
1794 } => (
1795 table_name,
1796 crate::catalog::TriggerEvent::Delete,
1797 Some(old_row),
1798 None,
1799 ),
1800 };
1801
1802 let (table, triggers) = {
1803 let catalog_guard = self.lock_catalog()?;
1804 let table = catalog_guard
1805 .get_table_by_name(&table_name)?
1806 .ok_or_else(|| {
1807 HematiteError::InternalError(format!(
1808 "Table '{}' disappeared while firing triggers",
1809 table_name
1810 ))
1811 })?;
1812 let mut triggers = catalog_guard
1813 .list_triggers()?
1814 .into_iter()
1815 .filter_map(|name| catalog_guard.get_trigger(&name).ok().flatten())
1816 .filter(|trigger| {
1817 trigger.table_name == table_name && trigger.event == event_kind
1818 })
1819 .collect::<Vec<_>>();
1820 triggers.sort_by(|left, right| left.name.cmp(&right.name));
1821 (table, triggers)
1822 };
1823
1824 for trigger in triggers {
1825 let trigger_statement =
1826 Self::parse_statement(&format!("{};", trigger.body_sql))?;
1827 let trigger_statement = substitute_trigger_statement(
1828 trigger_statement,
1829 &table,
1830 old_row.as_ref(),
1831 new_row.as_ref(),
1832 );
1833 if trigger_statement.is_read_only() {
1834 let _ = self.execute_read_statement(trigger_statement)?;
1835 } else {
1836 let _ =
1837 self.execute_mutating_statement_in_scope(trigger_statement, false)?;
1838 }
1839 }
1840 }
1841 Ok(())
1842 })();
1843 self.trigger_depth -= 1;
1844 result
1845 }
1846
1847 pub fn close(&mut self) -> Result<()> {
1848 if self.transaction.is_some() {
1849 return Err(HematiteError::InternalError(
1850 "Cannot close connection with an active transaction".to_string(),
1851 ));
1852 }
1853 let mut catalog_guard = self.lock_catalog()?;
1854 catalog_guard.flush()
1855 }
1856
1857 pub fn journal_mode(&self) -> Result<JournalMode> {
1858 let catalog_guard = self.lock_catalog()?;
1859 catalog_guard.journal_mode()
1860 }
1861
1862 pub fn set_journal_mode(&mut self, journal_mode: JournalMode) -> Result<()> {
1863 let mut catalog_guard = self.lock_catalog()?;
1864 catalog_guard.set_journal_mode(journal_mode)
1865 }
1866
1867 pub fn checkpoint_wal(&mut self) -> Result<()> {
1868 let mut catalog_guard = self.lock_catalog()?;
1869 catalog_guard.checkpoint_wal()
1870 }
1871
1872 pub fn execute(&mut self, sql: &str) -> Result<QueryResult> {
1873 self.execute_statement(Self::parse_statement(sql)?)
1874 }
1875
1876 pub fn execute_result(&mut self, sql: &str) -> Result<ExecutedStatement> {
1877 self.execute(sql).map(ExecutedStatement::from_query_result)
1878 }
1879
1880 pub fn iter_script<'a>(&'a mut self, sql: &str) -> Result<ScriptIter<'a>> {
1881 Ok(ScriptIter::new(self, split_script_tokens(sql)?))
1882 }
1883
1884 pub fn execute_batch(&mut self, sql: &str) -> Result<()> {
1885 for result in self.iter_script(sql)? {
1886 result?;
1887 }
1888 Ok(())
1889 }
1890
1891 pub fn execute_query(&mut self, sql: &str) -> Result<QueryResult> {
1892 self.execute(sql)
1893 }
1894
1895 pub fn prepare(&self, sql: &str) -> Result<PreparedStatement> {
1896 let statement = Self::parse_statement(sql)?;
1897 let parameter_count = statement.parameter_count();
1898
1899 Ok(PreparedStatement {
1900 statement,
1901 parameters: vec![None; parameter_count],
1902 })
1903 }
1904
1905 pub fn begin_transaction(&'_ mut self) -> Result<Transaction<'_>> {
1906 self.begin_active_transaction()?;
1907 Ok(Transaction {
1908 connection: self,
1909 completed: false,
1910 })
1911 }
1912
1913 fn begin_active_transaction(&mut self) -> Result<()> {
1914 if self.transaction.is_some() {
1915 return Err(HematiteError::InternalError(
1916 "Transaction is already active".to_string(),
1917 ));
1918 }
1919
1920 let mut catalog_guard = self.lock_catalog()?;
1921 let snapshot = catalog_guard.snapshot()?;
1922 catalog_guard.begin_transaction()?;
1923 drop(catalog_guard);
1924 self.transaction = Some(ConnectionTransaction {
1925 snapshot,
1926 savepoints: Vec::new(),
1927 });
1928 Ok(())
1929 }
1930
/// Test-only helper returning a clone of the current catalog schema.
#[cfg(test)]
pub(crate) fn schema_snapshot(&self) -> Result<Schema> {
    self.lock_catalog().map(|guard| guard.clone_schema())
}
1936
1937 fn active_transaction_mut(&mut self, action: &str) -> Result<&mut ConnectionTransaction> {
1938 self.transaction.as_mut().ok_or_else(|| {
1939 HematiteError::ParseError(format!("{} requires an active transaction", action))
1940 })
1941 }
1942
1943 fn create_savepoint(&mut self, name: &str) -> Result<()> {
1944 {
1945 let transaction = self.active_transaction_mut("SAVEPOINT")?;
1946 if transaction
1947 .savepoints
1948 .iter()
1949 .any(|savepoint| savepoint.name.eq_ignore_ascii_case(name))
1950 {
1951 return Err(HematiteError::ParseError(format!(
1952 "Savepoint '{}' already exists",
1953 name
1954 )));
1955 }
1956 }
1957
1958 let snapshot = {
1959 let catalog_guard = self.lock_catalog()?;
1960 catalog_guard.snapshot()
1961 }?;
1962
1963 let transaction = self.active_transaction_mut("SAVEPOINT")?;
1964 transaction.savepoints.push(SavepointState {
1965 name: name.to_string(),
1966 snapshot,
1967 });
1968 Ok(())
1969 }
1970
1971 fn rollback_to_savepoint(&mut self, name: &str) -> Result<()> {
1972 let position = {
1973 let transaction = self.active_transaction_mut("ROLLBACK TO SAVEPOINT")?;
1974 transaction
1975 .savepoints
1976 .iter()
1977 .position(|savepoint| savepoint.name.eq_ignore_ascii_case(name))
1978 .ok_or_else(|| {
1979 HematiteError::ParseError(format!("Savepoint '{}' does not exist", name))
1980 })?
1981 };
1982
1983 let snapshot = {
1984 let transaction = self.active_transaction_mut("ROLLBACK TO SAVEPOINT")?;
1985 transaction.savepoints[position].snapshot.clone()
1986 };
1987
1988 {
1989 let mut catalog_guard = self.lock_catalog()?;
1990 catalog_guard.restore_snapshot(snapshot)?;
1991 }
1992
1993 let transaction = self.active_transaction_mut("ROLLBACK TO SAVEPOINT")?;
1994 transaction.savepoints.truncate(position + 1);
1995 Ok(())
1996 }
1997
1998 fn release_savepoint(&mut self, name: &str) -> Result<()> {
1999 let transaction = self.active_transaction_mut("RELEASE SAVEPOINT")?;
2000 let position = transaction
2001 .savepoints
2002 .iter()
2003 .position(|savepoint| savepoint.name.eq_ignore_ascii_case(name))
2004 .ok_or_else(|| {
2005 HematiteError::ParseError(format!("Savepoint '{}' does not exist", name))
2006 })?;
2007 transaction.savepoints.remove(position);
2008 Ok(())
2009 }
2010}
2011
2012fn substitute_trigger_statement(
2013 statement: Statement,
2014 table: &crate::catalog::Table,
2015 old_row: Option<&crate::catalog::StoredRow>,
2016 new_row: Option<&crate::catalog::StoredRow>,
2017) -> Statement {
2018 let mut bindings = HashMap::new();
2019 if let Some(old_row) = old_row {
2020 for (column, value) in table.columns.iter().zip(old_row.values.iter()) {
2021 bindings.insert(format!("OLD.{}", column.name), raise_literal_value(value));
2022 }
2023 }
2024 if let Some(new_row) = new_row {
2025 for (column, value) in table.columns.iter().zip(new_row.values.iter()) {
2026 bindings.insert(format!("NEW.{}", column.name), raise_literal_value(value));
2027 }
2028 }
2029
2030 substitute_statement_bindings(statement, &bindings)
2031}
2032
2033fn substitute_statement_bindings(
2034 statement: Statement,
2035 bindings: &HashMap<String, crate::parser::types::LiteralValue>,
2036) -> Statement {
2037 match statement {
2038 Statement::Select(select) => {
2039 Statement::Select(substitute_select_bindings(select, bindings))
2040 }
2041 Statement::Insert(insert) => Statement::Insert(crate::parser::ast::InsertStatement {
2042 table: insert.table,
2043 columns: insert.columns,
2044 source: match insert.source {
2045 InsertSource::Values(rows) => InsertSource::Values(
2046 rows.into_iter()
2047 .map(|row| {
2048 row.into_iter()
2049 .map(|expr| substitute_expression_bindings(expr, bindings))
2050 .collect()
2051 })
2052 .collect(),
2053 ),
2054 InsertSource::Select(select) => {
2055 InsertSource::Select(Box::new(substitute_select_bindings(*select, bindings)))
2056 }
2057 },
2058 on_duplicate: insert.on_duplicate.map(|assignments| {
2059 assignments
2060 .into_iter()
2061 .map(|assignment| crate::parser::ast::UpdateAssignment {
2062 column: assignment.column,
2063 value: substitute_expression_bindings(assignment.value, bindings),
2064 })
2065 .collect()
2066 }),
2067 }),
2068 Statement::Update(update) => Statement::Update(crate::parser::ast::UpdateStatement {
2069 table: update.table,
2070 target_binding: update.target_binding,
2071 source: update.source,
2072 assignments: update
2073 .assignments
2074 .into_iter()
2075 .map(|assignment| crate::parser::ast::UpdateAssignment {
2076 column: assignment.column,
2077 value: substitute_expression_bindings(assignment.value, bindings),
2078 })
2079 .collect(),
2080 where_clause: update
2081 .where_clause
2082 .map(|where_clause| substitute_where_clause_bindings(where_clause, bindings)),
2083 }),
2084 Statement::Delete(delete) => Statement::Delete(crate::parser::ast::DeleteStatement {
2085 table: delete.table,
2086 target_binding: delete.target_binding,
2087 source: delete.source,
2088 where_clause: delete
2089 .where_clause
2090 .map(|where_clause| substitute_where_clause_bindings(where_clause, bindings)),
2091 }),
2092 other => other,
2093 }
2094}
2095
2096fn substitute_select_bindings(
2097 select: SelectStatement,
2098 bindings: &HashMap<String, crate::parser::types::LiteralValue>,
2099) -> SelectStatement {
2100 SelectStatement {
2101 with_clause: select
2102 .with_clause
2103 .into_iter()
2104 .map(|cte| crate::parser::ast::CommonTableExpression {
2105 name: cte.name,
2106 recursive: cte.recursive,
2107 query: Box::new(substitute_select_bindings(*cte.query, bindings)),
2108 })
2109 .collect(),
2110 distinct: select.distinct,
2111 columns: select
2112 .columns
2113 .into_iter()
2114 .map(|item| match item {
2115 crate::parser::ast::SelectItem::Expression(expr) => {
2116 crate::parser::ast::SelectItem::Expression(substitute_expression_bindings(
2117 expr, bindings,
2118 ))
2119 }
2120 crate::parser::ast::SelectItem::Column(name) => bindings
2121 .get(&name)
2122 .cloned()
2123 .map(crate::parser::ast::Expression::Literal)
2124 .map(crate::parser::ast::SelectItem::Expression)
2125 .unwrap_or(crate::parser::ast::SelectItem::Column(name)),
2126 other => other,
2127 })
2128 .collect(),
2129 column_aliases: select.column_aliases,
2130 from: substitute_table_reference_bindings(select.from, bindings),
2131 where_clause: select
2132 .where_clause
2133 .map(|where_clause| substitute_where_clause_bindings(where_clause, bindings)),
2134 group_by: select
2135 .group_by
2136 .into_iter()
2137 .map(|expr| substitute_expression_bindings(expr, bindings))
2138 .collect(),
2139 having_clause: select
2140 .having_clause
2141 .map(|where_clause| substitute_where_clause_bindings(where_clause, bindings)),
2142 order_by: select.order_by,
2143 limit: select.limit,
2144 offset: select.offset,
2145 set_operation: select
2146 .set_operation
2147 .map(|set_operation| crate::parser::ast::SetOperation {
2148 operator: set_operation.operator,
2149 right: Box::new(substitute_select_bindings(*set_operation.right, bindings)),
2150 }),
2151 }
2152}
2153
2154fn substitute_table_reference_bindings(
2155 table_reference: TableReference,
2156 bindings: &HashMap<String, crate::parser::types::LiteralValue>,
2157) -> TableReference {
2158 match table_reference {
2159 TableReference::Table(name, alias) => TableReference::Table(name, alias),
2160 TableReference::Derived { subquery, alias } => TableReference::Derived {
2161 subquery: Box::new(substitute_select_bindings(*subquery, bindings)),
2162 alias,
2163 },
2164 TableReference::CrossJoin(left, right) => TableReference::CrossJoin(
2165 Box::new(substitute_table_reference_bindings(*left, bindings)),
2166 Box::new(substitute_table_reference_bindings(*right, bindings)),
2167 ),
2168 TableReference::InnerJoin { left, right, on } => TableReference::InnerJoin {
2169 left: Box::new(substitute_table_reference_bindings(*left, bindings)),
2170 right: Box::new(substitute_table_reference_bindings(*right, bindings)),
2171 on: substitute_condition_bindings(on, bindings),
2172 },
2173 TableReference::LeftJoin { left, right, on } => TableReference::LeftJoin {
2174 left: Box::new(substitute_table_reference_bindings(*left, bindings)),
2175 right: Box::new(substitute_table_reference_bindings(*right, bindings)),
2176 on: substitute_condition_bindings(on, bindings),
2177 },
2178 TableReference::RightJoin { left, right, on } => TableReference::RightJoin {
2179 left: Box::new(substitute_table_reference_bindings(*left, bindings)),
2180 right: Box::new(substitute_table_reference_bindings(*right, bindings)),
2181 on: substitute_condition_bindings(on, bindings),
2182 },
2183 TableReference::FullOuterJoin { left, right, on } => TableReference::FullOuterJoin {
2184 left: Box::new(substitute_table_reference_bindings(*left, bindings)),
2185 right: Box::new(substitute_table_reference_bindings(*right, bindings)),
2186 on: substitute_condition_bindings(on, bindings),
2187 },
2188 }
2189}
2190
2191fn substitute_where_clause_bindings(
2192 where_clause: WhereClause,
2193 bindings: &HashMap<String, crate::parser::types::LiteralValue>,
2194) -> WhereClause {
2195 WhereClause {
2196 conditions: where_clause
2197 .conditions
2198 .into_iter()
2199 .map(|condition| substitute_condition_bindings(condition, bindings))
2200 .collect(),
2201 }
2202}
2203
2204fn substitute_condition_bindings(
2205 condition: Condition,
2206 bindings: &HashMap<String, crate::parser::types::LiteralValue>,
2207) -> Condition {
2208 match condition {
2209 Condition::Comparison {
2210 left,
2211 operator,
2212 right,
2213 } => Condition::Comparison {
2214 left: substitute_expression_bindings(left, bindings),
2215 operator,
2216 right: substitute_expression_bindings(right, bindings),
2217 },
2218 Condition::InList {
2219 expr,
2220 values,
2221 is_not,
2222 } => Condition::InList {
2223 expr: substitute_expression_bindings(expr, bindings),
2224 values: values
2225 .into_iter()
2226 .map(|expr| substitute_expression_bindings(expr, bindings))
2227 .collect(),
2228 is_not,
2229 },
2230 Condition::InSubquery {
2231 expr,
2232 subquery,
2233 is_not,
2234 } => Condition::InSubquery {
2235 expr: substitute_expression_bindings(expr, bindings),
2236 subquery: Box::new(substitute_select_bindings(*subquery, bindings)),
2237 is_not,
2238 },
2239 Condition::Between {
2240 expr,
2241 lower,
2242 upper,
2243 is_not,
2244 } => Condition::Between {
2245 expr: substitute_expression_bindings(expr, bindings),
2246 lower: substitute_expression_bindings(lower, bindings),
2247 upper: substitute_expression_bindings(upper, bindings),
2248 is_not,
2249 },
2250 Condition::Like {
2251 expr,
2252 pattern,
2253 is_not,
2254 } => Condition::Like {
2255 expr: substitute_expression_bindings(expr, bindings),
2256 pattern: substitute_expression_bindings(pattern, bindings),
2257 is_not,
2258 },
2259 Condition::Exists { subquery, is_not } => Condition::Exists {
2260 subquery: Box::new(substitute_select_bindings(*subquery, bindings)),
2261 is_not,
2262 },
2263 Condition::NullCheck { expr, is_not } => Condition::NullCheck {
2264 expr: substitute_expression_bindings(expr, bindings),
2265 is_not,
2266 },
2267 Condition::Not(condition) => Condition::Not(Box::new(substitute_condition_bindings(
2268 *condition, bindings,
2269 ))),
2270 Condition::Logical {
2271 left,
2272 operator,
2273 right,
2274 } => Condition::Logical {
2275 left: Box::new(substitute_condition_bindings(*left, bindings)),
2276 operator,
2277 right: Box::new(substitute_condition_bindings(*right, bindings)),
2278 },
2279 }
2280}
2281
2282fn substitute_expression_bindings(
2283 expression: Expression,
2284 bindings: &HashMap<String, crate::parser::types::LiteralValue>,
2285) -> Expression {
2286 match expression {
2287 Expression::Column(name) => bindings
2288 .get(&name)
2289 .cloned()
2290 .map(Expression::Literal)
2291 .unwrap_or(Expression::Column(name)),
2292 Expression::Literal(_) | Expression::IntervalLiteral { .. } | Expression::Parameter(_) => {
2293 expression
2294 }
2295 Expression::ScalarSubquery(subquery) => {
2296 Expression::ScalarSubquery(Box::new(substitute_select_bindings(*subquery, bindings)))
2297 }
2298 Expression::Cast { expr, target_type } => Expression::Cast {
2299 expr: Box::new(substitute_expression_bindings(*expr, bindings)),
2300 target_type,
2301 },
2302 Expression::Case {
2303 branches,
2304 else_expr,
2305 } => Expression::Case {
2306 branches: branches
2307 .into_iter()
2308 .map(|branch| crate::parser::ast::CaseWhenClause {
2309 condition: substitute_expression_bindings(branch.condition, bindings),
2310 result: substitute_expression_bindings(branch.result, bindings),
2311 })
2312 .collect(),
2313 else_expr: else_expr
2314 .map(|expr| Box::new(substitute_expression_bindings(*expr, bindings))),
2315 },
2316 Expression::ScalarFunctionCall { function, args } => Expression::ScalarFunctionCall {
2317 function,
2318 args: args
2319 .into_iter()
2320 .map(|expr| substitute_expression_bindings(expr, bindings))
2321 .collect(),
2322 },
2323 Expression::AggregateCall { function, target } => {
2324 Expression::AggregateCall { function, target }
2325 }
2326 Expression::UnaryMinus(expr) => {
2327 Expression::UnaryMinus(Box::new(substitute_expression_bindings(*expr, bindings)))
2328 }
2329 Expression::UnaryNot(expr) => {
2330 Expression::UnaryNot(Box::new(substitute_expression_bindings(*expr, bindings)))
2331 }
2332 Expression::Binary {
2333 left,
2334 operator,
2335 right,
2336 } => Expression::Binary {
2337 left: Box::new(substitute_expression_bindings(*left, bindings)),
2338 operator,
2339 right: Box::new(substitute_expression_bindings(*right, bindings)),
2340 },
2341 Expression::Comparison {
2342 left,
2343 operator,
2344 right,
2345 } => Expression::Comparison {
2346 left: Box::new(substitute_expression_bindings(*left, bindings)),
2347 operator,
2348 right: Box::new(substitute_expression_bindings(*right, bindings)),
2349 },
2350 Expression::InList {
2351 expr,
2352 values,
2353 is_not,
2354 } => Expression::InList {
2355 expr: Box::new(substitute_expression_bindings(*expr, bindings)),
2356 values: values
2357 .into_iter()
2358 .map(|expr| substitute_expression_bindings(expr, bindings))
2359 .collect(),
2360 is_not,
2361 },
2362 Expression::InSubquery {
2363 expr,
2364 subquery,
2365 is_not,
2366 } => Expression::InSubquery {
2367 expr: Box::new(substitute_expression_bindings(*expr, bindings)),
2368 subquery: Box::new(substitute_select_bindings(*subquery, bindings)),
2369 is_not,
2370 },
2371 Expression::Between {
2372 expr,
2373 lower,
2374 upper,
2375 is_not,
2376 } => Expression::Between {
2377 expr: Box::new(substitute_expression_bindings(*expr, bindings)),
2378 lower: Box::new(substitute_expression_bindings(*lower, bindings)),
2379 upper: Box::new(substitute_expression_bindings(*upper, bindings)),
2380 is_not,
2381 },
2382 Expression::Like {
2383 expr,
2384 pattern,
2385 is_not,
2386 } => Expression::Like {
2387 expr: Box::new(substitute_expression_bindings(*expr, bindings)),
2388 pattern: Box::new(substitute_expression_bindings(*pattern, bindings)),
2389 is_not,
2390 },
2391 Expression::Exists { subquery, is_not } => Expression::Exists {
2392 subquery: Box::new(substitute_select_bindings(*subquery, bindings)),
2393 is_not,
2394 },
2395 Expression::NullCheck { expr, is_not } => Expression::NullCheck {
2396 expr: Box::new(substitute_expression_bindings(*expr, bindings)),
2397 is_not,
2398 },
2399 Expression::Logical {
2400 left,
2401 operator,
2402 right,
2403 } => Expression::Logical {
2404 left: Box::new(substitute_expression_bindings(*left, bindings)),
2405 operator,
2406 right: Box::new(substitute_expression_bindings(*right, bindings)),
2407 },
2408 }
2409}
2410
2411#[derive(Debug, Clone)]
2412pub struct PreparedStatement {
2413 statement: crate::parser::ast::Statement,
2414 parameters: Vec<Option<Value>>,
2415}
2416
2417impl PreparedStatement {
2418 pub fn bind(&mut self, index: usize, value: Value) -> Result<()> {
2419 if index == 0 || index > self.parameters.len() {
2420 return Err(HematiteError::ParseError(format!(
2421 "Parameter index {} is out of range",
2422 index
2423 )));
2424 }
2425
2426 self.parameters[index - 1] = Some(value);
2427 Ok(())
2428 }
2429
2430 pub fn bind_all(&mut self, values: Vec<Value>) -> Result<()> {
2431 if values.len() != self.parameters.len() {
2432 return Err(HematiteError::ParseError(format!(
2433 "Expected {} parameters, got {}",
2434 self.parameters.len(),
2435 values.len()
2436 )));
2437 }
2438
2439 self.parameters = values.into_iter().map(Some).collect();
2440 Ok(())
2441 }
2442
2443 pub fn clear_bindings(&mut self) {
2444 self.parameters.fill(None);
2445 }
2446
2447 pub fn parameter_count(&self) -> usize {
2448 self.parameters.len()
2449 }
2450
2451 pub fn execute(&mut self, connection: &mut Connection) -> Result<QueryResult> {
2452 let statement = self.bound_statement()?;
2453 connection.execute_statement(statement)
2454 }
2455
2456 pub fn query(&mut self, connection: &mut Connection) -> Result<QueryResult> {
2457 self.execute(connection)
2458 }
2459
2460 fn bound_statement(&self) -> Result<crate::parser::ast::Statement> {
2461 let bound_values = self
2462 .parameters
2463 .iter()
2464 .enumerate()
2465 .map(|(index, value)| {
2466 value.clone().ok_or_else(|| {
2467 HematiteError::ParseError(format!("Parameter {} has not been bound", index + 1))
2468 })
2469 })
2470 .collect::<Result<Vec<_>>>()?;
2471 let bound_literals = bound_values
2472 .iter()
2473 .map(raise_literal_value)
2474 .collect::<Vec<_>>();
2475
2476 self.statement.bind_parameters(&bound_literals)
2477 }
2478}
2479
2480#[derive(Debug)]
2481pub struct Transaction<'a> {
2482 connection: &'a mut Connection,
2483 completed: bool,
2484}
2485
2486impl<'a> Transaction<'a> {
2487 pub fn execute(&mut self, sql: &str) -> Result<QueryResult> {
2488 self.connection.execute(sql)
2489 }
2490
2491 pub fn commit(&mut self) -> Result<()> {
2492 if self.completed {
2493 return Err(HematiteError::InternalError(
2494 "Transaction is already completed".to_string(),
2495 ));
2496 }
2497 self.connection.commit_active_transaction()?;
2498 self.completed = true;
2499 Ok(())
2500 }
2501
2502 pub fn rollback(&mut self) -> Result<()> {
2503 if self.completed {
2504 return Err(HematiteError::InternalError(
2505 "Transaction is already completed".to_string(),
2506 ));
2507 }
2508 self.connection.rollback_active_transaction()?;
2509 self.completed = true;
2510 Ok(())
2511 }
2512}
2513
2514impl<'a> Drop for Transaction<'a> {
2515 fn drop(&mut self) {
2516 if !self.completed {
2517 let _ = self.connection.rollback_active_transaction();
2518 }
2519 }
2520}
2521
2522#[derive(Debug, Clone)]
2523pub struct Database;
2524
2525impl Database {
2526 pub fn new() -> Self {
2527 Self
2528 }
2529
2530 pub fn open(database_path: &str) -> Result<Connection> {
2531 Connection::new(database_path)
2532 }
2533
2534 pub fn open_in_memory() -> Result<Connection> {
2535 Connection::new_in_memory()
2536 }
2537
2538 pub fn connect(&mut self, database_path: &str) -> Result<Connection> {
2539 Connection::new(database_path)
2540 }
2541}
2542
2543impl Default for Database {
2544 fn default() -> Self {
2545 Self::new()
2546 }
2547}
2548
2549impl Connection {
2550 fn take_active_transaction(&mut self, action: &str) -> Result<ConnectionTransaction> {
2551 self.transaction.take().ok_or_else(|| {
2552 HematiteError::InternalError(format!("No active transaction to {}", action))
2553 })
2554 }
2555
2556 fn commit_active_transaction(&mut self) -> Result<()> {
2557 let state = self.take_active_transaction("commit")?;
2558 let mut catalog_guard = self.lock_catalog()?;
2559 match catalog_guard.commit_transaction() {
2560 Ok(()) => Ok(()),
2561 Err(err) => {
2562 let _ = catalog_guard.rollback_transaction();
2563 catalog_guard.restore_snapshot(state.snapshot)?;
2564 Err(err)
2565 }
2566 }
2567 }
2568
2569 fn rollback_active_transaction(&mut self) -> Result<()> {
2570 let state = self.take_active_transaction("roll back")?;
2571 let mut catalog_guard = self.lock_catalog()?;
2572 catalog_guard.rollback_transaction()?;
2573 catalog_guard.restore_snapshot(state.snapshot)?;
2574 Ok(())
2575 }
2576}