// otter_sql/vm.rs

1//! The executor virtual machine, its registers and errors.
2//!
3//! See [`VirtualMachine`] and [`Register`].
4use hashbrown::HashMap;
5use permutation::permutation;
6use sqlparser::ast::DataType;
7use std::error::Error;
8use std::fmt::Display;
9
10use sqlparser::parser::ParserError;
11
12use crate::codegen::{codegen_ast, CodegenError};
13use crate::column::Column;
14use crate::expr::eval::ExprExecError;
15use crate::expr::Expr;
16use crate::ic::{Instruction, IntermediateCode};
17use crate::identifier::{ColumnRef, TableRef};
18use crate::parser::parse;
19use crate::schema::Schema;
20use crate::table::{Row, RowShared, Table};
21use crate::value::Value;
22use crate::{BoundedString, Database};
23
/// Name given to the database owned by a [`VirtualMachine`] built through `Default`.
const DEFAULT_DATABASE_NAME: &str = "default";
25
/// Identifies a single register of the [`VirtualMachine`].
///
/// Displayed as `%N`, where `N` is the wrapped number.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct RegisterIndex(usize);

impl RegisterIndex {
    /// Returns the index that immediately follows this one.
    pub fn next_index(&self) -> RegisterIndex {
        let Self(current) = *self;
        Self(current + 1)
    }
}

impl Display for RegisterIndex {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(raw) = self;
        write!(f, "%{}", raw)
    }
}
42
/// Identifies a table held by the [`VirtualMachine`].
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct TableIndex(usize);

impl TableIndex {
    /// Returns the index that immediately follows this one.
    pub fn next_index(&self) -> Self {
        let Self(current) = *self;
        Self(current + 1)
    }
}
53
/// Executor of an SQL query.
///
/// Holds the database being operated on together with the registers and tables that
/// instructions read and write while they execute.
pub struct VirtualMachine {
    /// The database whose schemas are looked up and modified by instructions.
    database: Database,
    /// Current contents of every register that has been written, keyed by register index.
    registers: HashMap<RegisterIndex, Register>,
    /// Every table currently held by the VM, keyed by table index.
    tables: HashMap<TableIndex, Table>,
    /// The most recently allocated table index; new tables take the index after this one.
    last_table_index: TableIndex,
}
61
62impl VirtualMachine {
63    pub fn new(name: BoundedString) -> Self {
64        Self {
65            database: Database::new(name),
66            registers: Default::default(),
67            tables: Default::default(),
68            last_table_index: Default::default(),
69        }
70    }
71
72    /// Inserts a value for the register at the given index.
73    fn insert_register(&mut self, index: RegisterIndex, reg: Register) {
74        self.registers.insert(index.clone(), reg);
75    }
76
77    /// Gets the value for the register at the given index.
78    fn get_register(&mut self, index: &RegisterIndex) -> Option<&Register> {
79        self.registers.get(index)
80    }
81
82    /// Creates a new table with a temp name and returns its index.
83    fn new_temp_table(&mut self) -> TableIndex {
84        let index = self.last_table_index.next_index();
85        self.tables.insert(index, Table::new_temp(index.0));
86        self.last_table_index = index;
87        index
88    }
89
90    /// Creates a new empty table from another table (with the same schema)
91    fn new_table_from(&mut self, table: &TableIndex) -> TableIndex {
92        let table = self.tables.get(table).unwrap();
93        let index = self.last_table_index.next_index();
94        self.tables.insert(index, Table::new_from(table));
95        self.last_table_index = index;
96        index
97    }
98
99    /// Get a reference to an existing table at the given index.
100    fn table(&self, index: &TableIndex) -> Option<&Table> {
101        self.tables.get(index)
102    }
103
104    /// Drop an existing table from the VM.
105    ///
106    /// Note: does NOT remove the table from the schema (if it was added to a schema).
107    // TODO: ensure that IC gen calls this when a temp table is created.
108    fn drop_table(&mut self, index: &TableIndex) {
109        self.tables.remove(index);
110    }
111
112    /// Executes the given SQL.
113    pub fn execute(&mut self, code: &str) -> Result<Option<Table>, ExecutionError> {
114        let ast = parse(code)?;
115        let mut ret = None;
116        for stmt in ast {
117            let ic = codegen_ast(&stmt)?;
118            ret = self.execute_ic(&ic)?;
119        }
120        Ok(ret)
121    }
122
123    /// Executes the given intermediate code.
124    pub fn execute_ic(&mut self, ic: &IntermediateCode) -> Result<Option<Table>, RuntimeError> {
125        let mut ret = None;
126        for instr in &ic.instrs {
127            ret = self.execute_instr(instr)?;
128        }
129        Ok(ret)
130    }
131
132    /// Executes the given instruction.
133    fn execute_instr(&mut self, instr: &Instruction) -> Result<Option<Table>, RuntimeError> {
134        let _ = &self.database;
135        match instr {
136            Instruction::Value { index, value } => {
137                self.registers
138                    .insert(*index, Register::Value(value.clone()));
139            }
140            Instruction::Expr { index, expr } => {
141                self.registers.insert(*index, Register::Expr(expr.clone()));
142            }
143            Instruction::Source { index, name } => match name {
144                TableRef {
145                    schema_name: None,
146                    table_name: _,
147                } => {
148                    let table_index = self.find_table(self.database.default_schema(), name)?;
149                    self.registers
150                        .insert(*index, Register::TableRef(table_index));
151                }
152                TableRef {
153                    schema_name: Some(schema_name),
154                    table_name: _,
155                } => {
156                    let schema = if let Some(schema) = self.database.schema_by_name(schema_name) {
157                        schema
158                    } else {
159                        return Err(RuntimeError::SchemaNotFound(*schema_name));
160                    };
161
162                    let table_index = self.find_table(schema, name)?;
163                    self.registers
164                        .insert(*index, Register::TableRef(table_index));
165                }
166            },
167            Instruction::Empty { index } => {
168                let table_index = self.new_temp_table();
169                self.registers
170                    .insert(*index, Register::TableRef(table_index));
171            }
172            Instruction::NonExistent { index } => {
173                self.registers.insert(*index, Register::NonExistentTable);
174            }
175            Instruction::Return { index } => match self.registers.remove(index) {
176                None => return Err(RuntimeError::EmptyRegister(*index)),
177                Some(Register::TableRef(t)) => return Ok(Some(self.tables[&t].clone())),
178                Some(Register::Value(v)) => {
179                    let mut table = Table::new_temp(self.last_table_index.next_index().0);
180                    self.last_table_index = self.last_table_index.next_index();
181                    table.add_column(Column::new("?column?".into(), v.data_type(), vec![], false));
182                    table.new_row(vec![v]);
183                    return Ok(Some(table));
184                }
185                Some(register) => return Err(RuntimeError::CannotReturn(register.clone())),
186            },
187            Instruction::Filter { index, expr } => match self.registers.get(index) {
188                None => return Err(RuntimeError::EmptyRegister(*index)),
189                Some(Register::TableRef(table_index)) => {
190                    let table_index = *table_index;
191                    // TODO: should be safe to unwrap, but make it an error anyway?
192                    let table = self.tables.get(&table_index).unwrap();
193                    let filtered_data = table
194                        .raw_data
195                        .iter()
196                        .filter_map(|row| {
197                            match Expr::execute(expr, table, RowShared::from_raw(row, &table)) {
198                                Ok(val) => match val {
199                                    Value::Bool(b) => {
200                                        if b {
201                                            Some(Ok(row.clone()))
202                                        } else {
203                                            None
204                                        }
205                                    }
206                                    _ => Some(Err(RuntimeError::FilterWithNonBoolean(
207                                        expr.clone(),
208                                        val.clone(),
209                                    ))),
210                                },
211                                Err(e) => Some(Err(e.into())),
212                            }
213                        })
214                        .collect::<Result<_, _>>()?;
215                    let new_table_index = self.new_table_from(&table_index);
216                    self.tables.get_mut(&new_table_index).unwrap().raw_data = filtered_data;
217                    self.insert_register(*index, Register::TableRef(new_table_index));
218                }
219                Some(reg) => return Err(RuntimeError::RegisterNotATable("filter", reg.clone())),
220            },
221            Instruction::Project {
222                input,
223                output,
224                expr,
225                alias,
226            } => match (self.registers.get(input), self.registers.get(output)) {
227                (None, _) => return Err(RuntimeError::EmptyRegister(*input)),
228                (_, None) => return Err(RuntimeError::EmptyRegister(*output)),
229                (Some(Register::NonExistentTable), Some(Register::TableRef(out_table_index))) => {
230                    let out_table = self.tables.get_mut(out_table_index).unwrap();
231                    // we assume out table is empty at this point, so use it like an input table
232                    // because why not.
233                    let val =
234                        Expr::execute(expr, &out_table, out_table.sentinel_row()?.to_shared())?;
235                    let data_type = val.data_type();
236                    out_table.new_row(vec![val]);
237
238                    // TODO: provide a unique name here
239                    let new_col = Column::new(
240                        alias.unwrap_or("PLACEHOLDER".into()),
241                        data_type,
242                        vec![],
243                        false,
244                    );
245
246                    out_table.add_column(new_col);
247                }
248                (
249                    Some(Register::TableRef(inp_table_index)),
250                    Some(Register::TableRef(out_table_index)),
251                ) => {
252                    let [inp_table, out_table] = self
253                        .tables
254                        .get_many_mut([inp_table_index, out_table_index])
255                        .unwrap();
256
257                    if !out_table.is_empty()
258                        && (inp_table.raw_data.len() != out_table.raw_data.len())
259                    {
260                        return Err(RuntimeError::ProjectTableSizeMismatch {
261                            inp_table_name: inp_table.name().to_owned(),
262                            inp_table_len: inp_table.raw_data.len(),
263                            out_table_name: out_table.name().to_owned(),
264                            out_table_len: out_table.raw_data.len(),
265                        });
266                    }
267
268                    if let Expr::Wildcard = expr {
269                        // TODO: this could be optimized.
270                        for col in inp_table.columns() {
271                            out_table.add_column(col.clone());
272                            out_table.add_column_data(
273                                col.name(),
274                                inp_table.get_column_data(col.name())?,
275                            )?;
276                        }
277                    } else {
278                        if inp_table.raw_data.len() == out_table.raw_data.len() {
279                            for (inp_row, out_row) in
280                                inp_table.raw_data.iter().zip(out_table.raw_data.iter_mut())
281                            {
282                                let val = Expr::execute(
283                                    expr,
284                                    inp_table,
285                                    RowShared::from_raw(&inp_row, &inp_table),
286                                )?;
287                                out_row.raw_data.push(val);
288                            }
289                        } else {
290                            for inp_row in inp_table.raw_data.iter() {
291                                let val = Expr::execute(
292                                    expr,
293                                    inp_table,
294                                    RowShared::from_raw(&inp_row, &inp_table),
295                                )?;
296                                out_table.new_row(vec![val]);
297                            }
298                        }
299
300                        let data_type = if !out_table.raw_data.is_empty() {
301                            let newly_added =
302                                out_table.raw_data.first().unwrap().raw_data.last().unwrap();
303                            newly_added.data_type()
304                        } else {
305                            let sentinel = inp_table.sentinel_row()?;
306                            let output_val = Expr::execute(expr, inp_table, sentinel.to_shared())?;
307                            output_val.data_type()
308                        };
309
310                        // TODO: provide a unique name here
311                        let new_col = Column::new(
312                            alias.unwrap_or("PLACEHOLDER".into()),
313                            data_type,
314                            vec![],
315                            false,
316                        );
317
318                        out_table.add_column(new_col);
319                    }
320                }
321                (Some(reg), Some(Register::TableRef(_))) => {
322                    return Err(RuntimeError::RegisterNotATable("project", reg.clone()))
323                }
324                (Some(Register::TableRef(_)), Some(reg)) => {
325                    return Err(RuntimeError::RegisterNotATable("project", reg.clone()))
326                }
327                (Some(reg), Some(_)) => {
328                    return Err(RuntimeError::RegisterNotATable("project", reg.clone()))
329                }
330            },
331            Instruction::GroupBy { index: _, expr: _ } => todo!("group by is not implemented yet"),
332            Instruction::Order {
333                index,
334                expr,
335                ascending,
336            } => {
337                let table_index = match self.registers.get(index) {
338                    None => return Err(RuntimeError::EmptyRegister(*index)),
339                    Some(Register::TableRef(table_index)) => table_index,
340                    Some(register) => {
341                        return Err(RuntimeError::RegisterNotATable(
342                            "order by",
343                            register.clone(),
344                        ))
345                    }
346                };
347                let table = self.tables.get_mut(table_index).unwrap();
348
349                let expr_values = table
350                    .raw_data
351                    .iter()
352                    .map(|row| Expr::execute(expr, table, RowShared::from_raw(&row, &table)))
353                    .collect::<Result<Vec<_>, _>>()?;
354                let mut perm = permutation::sort(expr_values);
355                perm.apply_slice_in_place(&mut table.raw_data);
356
357                if !ascending {
358                    table.raw_data.reverse();
359                }
360            }
361            Instruction::Limit { index, limit } => {
362                let table_index = match self.registers.get(index) {
363                    None => return Err(RuntimeError::EmptyRegister(*index)),
364                    Some(Register::TableRef(table_index)) => table_index,
365                    Some(register) => {
366                        return Err(RuntimeError::RegisterNotATable("limit", register.clone()))
367                    }
368                };
369                let table = self.tables.get_mut(table_index).unwrap();
370
371                table.raw_data.truncate(*limit as usize);
372            }
373            Instruction::NewSchema {
374                schema_name,
375                exists_ok,
376            } => {
377                let name = schema_name.0;
378                if let None = self.database.schema_by_name(&name) {
379                    self.database.add_schema(Schema::new(name));
380                } else if !*exists_ok {
381                    return Err(RuntimeError::SchemaExists(name));
382                }
383            }
384            Instruction::ColumnDef {
385                index,
386                name,
387                data_type,
388            } => {
389                self.registers.insert(
390                    *index,
391                    Register::Column(Column::new(*name, data_type.clone(), vec![], false)),
392                );
393            }
394            Instruction::AddColumnOption { index, option } => {
395                let column = match self.registers.get_mut(index) {
396                    Some(Register::Column(column)) => column,
397                    Some(register) => {
398                        return Err(RuntimeError::RegisterNotAColumn(
399                            "add column option",
400                            register.clone(),
401                        ))
402                    }
403                    None => return Err(RuntimeError::EmptyRegister(*index)),
404                };
405                column.add_column_option(option.clone());
406            }
407            Instruction::AddColumn {
408                table_reg_index,
409                col_index,
410            } => {
411                let table_index = match self.registers.get(table_reg_index) {
412                    None => return Err(RuntimeError::EmptyRegister(*table_reg_index)),
413                    Some(Register::TableRef(table_index)) => table_index,
414                    Some(register) => {
415                        return Err(RuntimeError::RegisterNotATable(
416                            "add column",
417                            register.clone(),
418                        ))
419                    }
420                };
421                let table = self.tables.get_mut(table_index).unwrap();
422
423                let column = match self.registers.get(col_index) {
424                    Some(Register::Column(column)) => column,
425                    Some(register) => {
426                        return Err(RuntimeError::RegisterNotAColumn(
427                            "add column",
428                            register.clone(),
429                        ))
430                    }
431                    None => return Err(RuntimeError::EmptyRegister(*col_index)),
432                };
433
434                table.add_column(column.clone());
435            }
436            Instruction::NewTable {
437                index,
438                name,
439                exists_ok,
440            } => {
441                let table_index = *match self.registers.get(index) {
442                    None => return Err(RuntimeError::EmptyRegister(*index)),
443                    Some(Register::TableRef(table_index)) => table_index,
444                    Some(register) => {
445                        return Err(RuntimeError::RegisterNotATable(
446                            "new table",
447                            register.clone(),
448                        ))
449                    }
450                };
451
452                let table = self.tables.get_mut(&table_index).unwrap();
453                table.rename(name.table_name);
454
455                let schema = self.find_schema(name.schema_name)?;
456
457                match self.find_table(schema, name) {
458                    Ok(_) => {
459                        if !exists_ok {
460                            return Err(RuntimeError::TableExists(*name));
461                        }
462                    }
463                    Err(RuntimeError::TableNotFound(_)) => {
464                        self.find_schema_mut(name.schema_name)?
465                            .add_table(table_index);
466                    }
467                    Err(e) => return Err(e),
468                }
469            }
470            Instruction::DropTable { index: _ } => todo!("drop table is not implemented yet"),
471            Instruction::RemoveColumn {
472                index: _,
473                col_name: _,
474            } => todo!("remove column is not implemented yet"),
475            Instruction::RenameColumn {
476                index: _,
477                old_name: _,
478                new_name: _,
479            } => todo!("rename column is not implemented yet"),
480            Instruction::InsertDef {
481                table_reg_index,
482                index,
483            } => {
484                let table_index = *match self.registers.get(table_reg_index) {
485                    None => return Err(RuntimeError::EmptyRegister(*table_reg_index)),
486                    Some(Register::TableRef(table_index)) => table_index,
487                    Some(register) => {
488                        return Err(RuntimeError::RegisterNotATable(
489                            "insert def",
490                            register.clone(),
491                        ))
492                    }
493                };
494
495                self.registers
496                    .insert(*index, Register::InsertDef(InsertDef::new(table_index)));
497            }
498            Instruction::ColumnInsertDef {
499                insert_index,
500                col_name,
501            } => {
502                let insert = match self.registers.get_mut(insert_index) {
503                    Some(Register::InsertDef(insert)) => insert,
504                    Some(register) => {
505                        return Err(RuntimeError::RegisterNotAInsert(
506                            "column insert def",
507                            register.clone(),
508                        ))
509                    }
510                    None => return Err(RuntimeError::EmptyRegister(*insert_index)),
511                };
512
513                let table = self.tables.get(&insert.table).unwrap();
514
515                let col_info = table.get_column(col_name)?;
516
517                insert.columns.push((col_info.0, col_info.1.to_owned()));
518            }
519            Instruction::RowDef {
520                insert_index,
521                row_index: row_reg_index,
522            } => {
523                let insert = match self.registers.get_mut(insert_index) {
524                    Some(Register::InsertDef(insert)) => insert,
525                    Some(register) => {
526                        return Err(RuntimeError::RegisterNotAInsert(
527                            "row def",
528                            register.clone(),
529                        ))
530                    }
531                    None => return Err(RuntimeError::EmptyRegister(*insert_index)),
532                };
533
534                insert.rows.push(vec![]);
535                let row_index = insert.rows.len() - 1;
536
537                self.registers.insert(
538                    *row_reg_index,
539                    Register::InsertRow(InsertRow {
540                        def: *insert_index,
541                        row_index,
542                    }),
543                );
544            }
545            Instruction::AddValue {
546                row_index: row_reg_index,
547                expr,
548            } => {
549                let &InsertRow {
550                    def: insert_reg_index,
551                    row_index,
552                } = match self.registers.get(row_reg_index) {
553                    Some(Register::InsertRow(insert_row)) => insert_row,
554                    Some(register) => {
555                        return Err(RuntimeError::RegisterNotAInsertRow(
556                            "add value",
557                            register.clone(),
558                        ))
559                    }
560                    None => return Err(RuntimeError::EmptyRegister(*row_reg_index)),
561                };
562
563                let insert = match self.registers.get_mut(&insert_reg_index) {
564                    Some(Register::InsertDef(insert)) => insert,
565                    Some(register) => {
566                        return Err(RuntimeError::RegisterNotAInsert(
567                            "row def",
568                            register.clone(),
569                        ))
570                    }
571                    None => return Err(RuntimeError::EmptyRegister(insert_reg_index)),
572                };
573
574                let table = self.tables.get(&insert.table).unwrap();
575
576                let value = Expr::execute(expr, table, table.sentinel_row()?.to_shared())?;
577
578                if insert.rows[row_index].len() + 1 > table.num_columns() {
579                    return Err(RuntimeError::TooManyValuesToInsert(
580                        *table.name(),
581                        insert.rows[row_index].len() + 1,
582                        table.num_columns(),
583                    ));
584                }
585
586                insert.rows[row_index].push(value);
587            }
588            Instruction::Insert {
589                index: insert_index,
590            } => {
591                let insert = match self.registers.remove(insert_index) {
592                    Some(Register::InsertDef(insert)) => insert,
593                    Some(register) => {
594                        return Err(RuntimeError::RegisterNotAInsert("insert", register.clone()))
595                    }
596                    None => return Err(RuntimeError::EmptyRegister(*insert_index)),
597                };
598
599                let table = self.tables.get_mut(&insert.table).unwrap();
600
601                if !insert.columns.is_empty() && (insert.columns.len() != table.num_columns()) {
602                    return Err(RuntimeError::Unsupported(concat!(
603                        "Default values are not supported yet. ",
604                        "Some columns were missing in INSERT."
605                    )));
606                }
607
608                for row in insert.rows {
609                    if table.num_columns() != row.len() {
610                        return Err(RuntimeError::NotEnoughValuesToInsert(
611                            *table.name(),
612                            row.len(),
613                            table.num_columns(),
614                        ));
615                    }
616                    table.new_row(row);
617                }
618            }
619            Instruction::Update {
620                index: _,
621                col: _,
622                expr: _,
623            } => todo!("update is not implemented yet"),
624            Instruction::Union {
625                input1: _,
626                input2: _,
627                output: _,
628            } => todo!("union is not implemented yet"),
629            Instruction::CrossJoin {
630                input1: _,
631                input2: _,
632                output: _,
633            } => todo!("joins are not implemented yet"),
634            Instruction::NaturalJoin {
635                input1: _,
636                input2: _,
637                output: _,
638            } => todo!("joins are not implemented yet"),
639        }
640        Ok(None)
641    }
642
643    /// Find [`TableIndex`] given the schema and its name.
644    fn find_table(&self, schema: &Schema, table: &TableRef) -> Result<TableIndex, RuntimeError> {
645        if let Some(table_index) = schema
646            .tables()
647            .iter()
648            .find(|table_index| self.tables[table_index].name() == &table.table_name)
649        {
650            Ok(*table_index)
651        } else {
652            Err(RuntimeError::TableNotFound(table.clone()))
653        }
654    }
655
656    /// A reference to the given schema, or default schema if it's `None`.
657    fn find_schema(&self, name: Option<BoundedString>) -> Result<&Schema, RuntimeError> {
658        if let Some(schema_name) = name {
659            match self.database.schema_by_name(&schema_name) {
660                Some(schema) => Ok(schema),
661                None => return Err(RuntimeError::SchemaNotFound(schema_name)),
662            }
663        } else {
664            Ok(self.database.default_schema())
665        }
666    }
667
668    /// A mutable reference to the given schema, or default schema if it's `None`.
669    fn find_schema_mut(
670        &mut self,
671        name: Option<BoundedString>,
672    ) -> Result<&mut Schema, RuntimeError> {
673        if let Some(schema_name) = name {
674            match self.database.schema_by_name_mut(&schema_name) {
675                Some(schema) => Ok(schema),
676                None => return Err(RuntimeError::SchemaNotFound(schema_name)),
677            }
678        } else {
679            Ok(self.database.default_schema_mut())
680        }
681    }
682}
683
684impl Default for VirtualMachine {
685    fn default() -> Self {
686        Self::new(DEFAULT_DATABASE_NAME.into())
687    }
688}
689
/// A register in the executor VM.
///
/// Each register holds exactly one of the kinds of value listed below; instructions check the
/// kind they find and report a [`RuntimeError`] when it is not the one they need.
#[derive(Debug, Clone, PartialEq)]
pub enum Register {
    /// A reference to a table, by its index in the VM.
    TableRef(TableIndex),
    /// A reference to a non-existent table.
    NonExistentTable,
    /// A grouped table.
    GroupedTable {
        /// The column the rows are grouped by.
        grouped_col: Column,
        /// The remaining columns.
        other_cols: Vec<Column>,
        /// The group, a mapping of grouped col value -> rows in that group.
        data: Vec<(Value, Vec<Row>)>,
    },
    /// A table definition.
    TableDef(TableDef),
    /// A column definition.
    Column(Column),
    /// An insert statement that is being built up.
    InsertDef(InsertDef),
    /// A row to insert, belonging to an insert statement held in another register.
    InsertRow(InsertRow),
    /// A value.
    Value(Value),
    /// An expression.
    Expr(Expr),
    // TODO: an error value?
}
718
/// An abstract definition of a create table statement.
#[derive(Debug, Clone, PartialEq)]
pub struct TableDef {
    /// Name of the table in the definition.
    pub name: BoundedString,
    /// Columns of the table in the definition.
    pub columns: Vec<Column>,
}
725
726#[derive(Debug, Clone, PartialEq)]
727/// An abstract definition of an insert statement.
728pub struct InsertDef {
729    /// The view to insert into
730    pub table: TableIndex,
731    /// The columns to insert into.
732    ///
733    /// Empty means all columns.
734    pub columns: Vec<(usize, Column)>,
735    /// The values to insert.
736    pub rows: Vec<Vec<Value>>,
737}
738
739impl InsertDef {
740    pub fn new(table: TableIndex) -> Self {
741        Self {
742            table,
743            columns: Vec::new(),
744            rows: Vec::new(),
745        }
746    }
747}
748
/// A row of values to insert.
///
/// The values themselves live in the [`InsertDef`] held by register `def`; this only records
/// which of that definition's rows is meant.
#[derive(Debug, Clone, PartialEq)]
pub struct InsertRow {
    /// The register holding the insert definition which this row belongs to.
    pub def: RegisterIndex,
    /// Which row of the insert definition this refers to.
    pub row_index: usize,
}
757
#[derive(Debug)]
/// An error from any of the stages wrapped by the variants below: parsing,
/// code generation or execution.
pub enum ExecutionError {
    /// An error reported by the SQL parser.
    ParseError(ParserError),
    /// An error reported by code generation.
    CodegenError(CodegenError),
    /// An error raised during execution.
    RuntimeError(RuntimeError),
}
764
765impl From<ParserError> for ExecutionError {
766    fn from(err: ParserError) -> Self {
767        ExecutionError::ParseError(err)
768    }
769}
770
771impl From<CodegenError> for ExecutionError {
772    fn from(err: CodegenError) -> Self {
773        ExecutionError::CodegenError(err)
774    }
775}
776
777impl From<RuntimeError> for ExecutionError {
778    fn from(err: RuntimeError) -> Self {
779        ExecutionError::RuntimeError(err)
780    }
781}
782
783impl Display for ExecutionError {
784    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
785        match self {
786            Self::ParseError(e) => write!(f, "{}", e),
787            Self::CodegenError(e) => write!(f, "{}", e),
788            Self::RuntimeError(e) => write!(f, "{}", e),
789        }
790    }
791}
792
// No `Error` methods are overridden: the trait defaults are used together with
// the derived `Debug` and the hand-written `Display` of `ExecutionError`.
impl Error for ExecutionError {}
794
/// All possible errors handled during execution.
///
/// This includes constraint violations, errors in expression evaluation, unsupported features as
/// well as internal errors that are explicitly caught.
#[derive(Debug, PartialEq)]
pub enum RuntimeError {
    /// The referenced column was not found.
    ColumnNotFound(ColumnRef),
    /// The referenced table was not found.
    TableNotFound(TableRef),
    /// A table with this reference already exists.
    TableExists(TableRef),
    /// No schema with this name was found.
    SchemaNotFound(BoundedString),
    /// A schema with this name already exists.
    SchemaExists(BoundedString),
    /// The register is not initialized. Reported as a critical (internal) error.
    EmptyRegister(RegisterIndex),
    /// The register is not a table. Holds the attempted operation and the offending register.
    RegisterNotATable(&'static str, Register),
    /// The register is not a column. Holds the attempted operation and the offending register.
    RegisterNotAColumn(&'static str, Register),
    /// The register is not an insert def. Holds the attempted operation and the offending
    /// register.
    RegisterNotAInsert(&'static str, Register),
    /// The register is not an insert row. Holds the attempted operation and the offending
    /// register.
    RegisterNotAInsertRow(&'static str, Register),
    /// The register's value cannot be returned. Reported as a critical (internal) error.
    CannotReturn(Register),
    /// A `WHERE` clause evaluated to a non-boolean value. Holds the expression and the value it
    /// evaluated to.
    FilterWithNonBoolean(Expr, Value),
    /// Tried projecting onto the named table, which was not empty.
    ProjectOnNonEmptyTable(BoundedString),
    /// The input and output tables of a projection had a different number of rows.
    ProjectTableSizeMismatch {
        /// Name of the projection's input table.
        inp_table_name: BoundedString,
        /// Length of the input table.
        inp_table_len: usize,
        /// Name of the projection's output table.
        out_table_name: BoundedString,
        /// Length of the output table.
        out_table_len: usize,
    },
    /// The data size of a new column did not match the size of the table.
    TableNewColumnSizeMismatch {
        /// Name of the table.
        table_name: BoundedString,
        /// Length of the table.
        table_len: usize,
        /// Name of the new column.
        col_name: BoundedString,
        /// Length of the new column's data.
        col_len: usize,
    },
    /// The data type is not supported.
    UnsupportedType(DataType),
    /// An error from expression evaluation.
    ExprExecError(ExprExecError),
    /// Too many values were given for an insert.
    ///
    /// Holds the table name, the number of values received and the number of columns of the
    /// table.
    TooManyValuesToInsert(BoundedString, usize, usize),
    /// Not enough values were given for an insert.
    ///
    /// Holds the table name, the number of values received and the number of columns expected.
    NotEnoughValuesToInsert(BoundedString, usize, usize),
    /// An unsupported feature was used. Holds the message to display.
    Unsupported(&'static str),
}
832
833impl From<ExprExecError> for RuntimeError {
834    fn from(e: ExprExecError) -> Self {
835        Self::ExprExecError(e)
836    }
837}
838
839impl Display for RuntimeError {
840    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
841        match self {
842            Self::ColumnNotFound(c) => write!(f, "Column not found: '{}'", c),
843            Self::TableNotFound(t) => write!(f, "Table not found: '{}'", t),
844            Self::TableExists(s) => write!(f, "Table already exists: '{}'", s),
845            Self::SchemaNotFound(s) => write!(f, "Schema not found: '{}'", s),
846            Self::SchemaExists(s) => write!(f, "Schema already exists: '{}'", s),
847            Self::EmptyRegister(r) => write!(
848                f,
849                "Register is not initialized: '{}' (critical error. Please file an issue.)",
850                r
851            ),
852            Self::RegisterNotATable(operation, reg) => write!(
853                f,
854                "Register is not a table. Cannot perform '{}' on '{:?}'",
855                operation, reg
856            ),
857            Self::RegisterNotAColumn(operation, reg) => write!(
858                f,
859                "Register is not a column. Cannot perform '{}' on '{:?}'",
860                operation, reg
861            ),
862            Self::RegisterNotAInsert(operation, reg) => write!(
863                f,
864                "Register is not an insert def. Cannot perform '{}' on '{:?}'",
865                operation, reg
866            ),
867            Self::RegisterNotAInsertRow(operation, reg) => write!(
868                f,
869                "Register is not an insert row. Cannot perform '{}' on '{:?}'",
870                operation, reg
871            ),
872            Self::CannotReturn(r) => write!(
873                f,
874                "Register value cannot be returned: '{:?}' \
875                 (critical error. Please file an issue)",
876                r
877            ),
878            Self::FilterWithNonBoolean(e, v) => write!(
879                f,
880                "WHERE clause used with a non-boolean value. \
881                 Expression: '{}' evaluated to value: '{}'",
882                e, v
883            ),
884            Self::ProjectOnNonEmptyTable(table_name) => write!(
885                f,
886                "Projecting on a non-empty table is not supported. \
887                 Tried projecting onto table: '{}'",
888                table_name
889            ),
890            Self::ProjectTableSizeMismatch {
891                inp_table_name,
892                inp_table_len,
893                out_table_name,
894                out_table_len,
895            } => write!(
896                f,
897                "Projection input and output table had different number of rows. \
898                 Input: '{}' with length {}, Output: '{}' with length {}",
899                inp_table_name, inp_table_len, out_table_name, out_table_len
900            ),
901            Self::TableNewColumnSizeMismatch {
902                table_name,
903                table_len,
904                col_name,
905                col_len,
906            } => write!(
907                f,
908                "New column data size does not match table size. \
909                 Table: '{}' with length {}, New column: '{}' with length {}",
910                table_name, table_len, col_name, col_len,
911            ),
912            Self::UnsupportedType(d) => write!(f, "Unsupported type: {}", d),
913            Self::ExprExecError(e) => write!(f, "{}", e),
914            Self::TooManyValuesToInsert(table_name, got_num, expected_num) => write!(
915                f,
916                concat!(
917                    "Too many values to insert into table '{}'. ",
918                    "Got at least {} values while the table has {} columns."
919                ),
920                table_name, got_num, expected_num
921            ),
922            Self::NotEnoughValuesToInsert(table_name, got_num, expected_num) => write!(
923                f,
924                concat!(
925                    "Not enough values to insert into table '{}'. ",
926                    "Got at {} values while {} columns were expected."
927                ),
928                table_name, got_num, expected_num
929            ),
930            Self::Unsupported(err) => write!(f, "{}", err,),
931        }
932    }
933}
934
935#[cfg(test)]
936mod tests {
937    use sqlparser::ast::{ColumnOption, ColumnOptionDef, DataType};
938
939    use crate::{
940        codegen::codegen_ast,
941        column::Column,
942        expr::{eval::ExprExecError, BinOp, Expr},
943        identifier::{ColumnRef, TableRef},
944        parser::parse,
945        table::{Row, Table},
946        value::Value,
947    };
948
949    use super::{RuntimeError, VirtualMachine};
950
951    #[test]
952    fn create_vm() {
953        let _ = VirtualMachine::default();
954    }
955
956    fn check_single_statement(
957        query: &str,
958        vm: &mut VirtualMachine,
959    ) -> Result<Option<Table>, RuntimeError> {
960        let parsed = parse(query).unwrap();
961        assert_eq!(parsed.len(), 1);
962
963        let statement = &parsed[0];
964        let ic = codegen_ast(&statement).unwrap();
965
966        println!("ic: {ic:#?}");
967
968        vm.execute_ic(&ic)
969    }
970
971    #[test]
972    fn create_schema() {
973        let mut vm = VirtualMachine::default();
974
975        let _res = check_single_statement("CREATE SCHEMA abc", &mut vm).unwrap();
976        let schema = vm.database.schema_by_name(&"abc".into()).unwrap();
977        assert_eq!(schema.name(), "abc");
978        assert_eq!(schema.tables(), &vec![]);
979
980        let res = check_single_statement("CREATE SCHEMA abc", &mut vm).unwrap_err();
981        assert_eq!(res, RuntimeError::SchemaExists("abc".into()));
982
983        let _res = check_single_statement("CREATE SCHEMA IF NOT EXISTS abc", &mut vm).unwrap();
984        let schema = vm.database.schema_by_name(&"abc".into()).unwrap();
985        assert_eq!(schema.name(), "abc");
986        assert_eq!(schema.tables(), &vec![]);
987    }
988
989    #[test]
990    fn create_table() {
991        let mut vm = VirtualMachine::default();
992        let _res = check_single_statement(
993            "CREATE TABLE
994             IF NOT EXISTS table1
995             (
996                 col1 INTEGER PRIMARY KEY NOT NULL,
997                 col2 STRING NOT NULL,
998                 col3 INTEGER UNIQUE
999             )",
1000            &mut vm,
1001        )
1002        .unwrap();
1003
1004        let table_index = vm
1005            .find_table(
1006                vm.database.default_schema(),
1007                &TableRef {
1008                    schema_name: None,
1009                    table_name: "table1".into(),
1010                },
1011            )
1012            .unwrap();
1013
1014        let table = vm.table(&table_index).unwrap();
1015        assert_eq!(
1016            table.columns().cloned().collect::<Vec<_>>(),
1017            vec![
1018                Column::new(
1019                    "col1".into(),
1020                    DataType::Int(None),
1021                    vec![
1022                        ColumnOptionDef {
1023                            name: None,
1024                            option: ColumnOption::Unique { is_primary: true },
1025                        },
1026                        ColumnOptionDef {
1027                            name: None,
1028                            option: ColumnOption::NotNull
1029                        }
1030                    ],
1031                    false
1032                ),
1033                Column::new(
1034                    "col2".into(),
1035                    DataType::String,
1036                    vec![ColumnOptionDef {
1037                        name: None,
1038                        option: ColumnOption::NotNull
1039                    }],
1040                    false
1041                ),
1042                Column::new(
1043                    "col3".into(),
1044                    DataType::Int(None),
1045                    vec![ColumnOptionDef {
1046                        name: None,
1047                        option: ColumnOption::Unique { is_primary: false },
1048                    }],
1049                    false
1050                ),
1051            ]
1052        )
1053    }
1054
1055    #[test]
1056    fn insert_values() {
1057        let mut vm = VirtualMachine::default();
1058
1059        check_single_statement(
1060            "
1061            CREATE TABLE table1
1062            (
1063                col1 INTEGER PRIMARY KEY NOT NULL,
1064                col2 STRING NOT NULL
1065            )
1066            ",
1067            &mut vm,
1068        )
1069        .unwrap();
1070
1071        let _res = check_single_statement(
1072            "
1073            INSERT INTO table1 VALUES
1074                (2, 'bar'),
1075                (3, 'aaa')
1076            ",
1077            &mut vm,
1078        )
1079        .unwrap();
1080
1081        let table_index = vm
1082            .find_table(
1083                vm.database.default_schema(),
1084                &TableRef {
1085                    schema_name: None,
1086                    table_name: "table1".into(),
1087                },
1088            )
1089            .unwrap();
1090
1091        let table = vm.table(&table_index).unwrap();
1092
1093        assert_eq!(
1094            table.all_data(),
1095            vec![
1096                Row::new(vec![Value::Int64(2), Value::String("bar".to_owned())]),
1097                Row::new(vec![Value::Int64(3), Value::String("aaa".to_owned())])
1098            ]
1099        );
1100
1101        let res = check_single_statement(
1102            "
1103            INSERT INTO table1 VALUES
1104                (2, 'bar', 1.9)
1105            ",
1106            &mut vm,
1107        )
1108        .unwrap_err();
1109
1110        assert_eq!(
1111            res,
1112            RuntimeError::TooManyValuesToInsert("table1".into(), 3, 2)
1113        );
1114
1115        let res = check_single_statement(
1116            "
1117            INSERT INTO table1 VALUES
1118                ('bar')
1119            ",
1120            &mut vm,
1121        );
1122
1123        assert_eq!(
1124            res.unwrap_err(),
1125            RuntimeError::NotEnoughValuesToInsert("table1".into(), 1, 2)
1126        );
1127
1128        let _res = check_single_statement(
1129            "
1130            INSERT INTO table1 (col1, col2) VALUES
1131                (4, 'car'),
1132                (5, 'yak')
1133            ",
1134            &mut vm,
1135        )
1136        .unwrap();
1137
1138        let table_index = vm
1139            .find_table(
1140                vm.database.default_schema(),
1141                &TableRef {
1142                    schema_name: None,
1143                    table_name: "table1".into(),
1144                },
1145            )
1146            .unwrap();
1147
1148        let table = vm.table(&table_index).unwrap();
1149
1150        assert_eq!(
1151            table.all_data(),
1152            vec![
1153                Row::new(vec![Value::Int64(2), Value::String("bar".to_owned())]),
1154                Row::new(vec![Value::Int64(3), Value::String("aaa".to_owned())]),
1155                Row::new(vec![Value::Int64(4), Value::String("car".to_owned())]),
1156                Row::new(vec![Value::Int64(5), Value::String("yak".to_owned())]),
1157            ]
1158        );
1159
1160        let res = check_single_statement(
1161            "
1162            INSERT INTO table1 (col2) VALUES
1163                ('bar')
1164            ",
1165            &mut vm,
1166        );
1167        matches!(res.unwrap_err(), RuntimeError::Unsupported(_));
1168    }
1169
1170    #[test]
1171    fn select() {
1172        let mut vm = VirtualMachine::default();
1173
1174        let res = check_single_statement("SELECT 1", &mut vm)
1175            .unwrap()
1176            .unwrap();
1177
1178        assert_eq!(
1179            res.columns().collect::<Vec<_>>(),
1180            vec![&Column::new(
1181                "PLACEHOLDER".into(),
1182                DataType::Int(None),
1183                vec![],
1184                false
1185            ),]
1186        );
1187
1188        assert_eq!(res.all_data(), vec![Row::new(vec![Value::Int64(1)])]);
1189        assert_eq!(
1190            check_single_statement("SELECT 10 * 20 + 5", &mut vm)
1191                .unwrap()
1192                .unwrap()
1193                .all_data(),
1194            vec![Row::new(vec![Value::Int64(205)])]
1195        );
1196        assert_eq!(
1197            check_single_statement("SELECT 'a'", &mut vm)
1198                .unwrap()
1199                .unwrap()
1200                .all_data(),
1201            vec![Row::new(vec![Value::String("a".to_owned())])]
1202        );
1203
1204        check_single_statement(
1205            "
1206            CREATE TABLE table1
1207            (
1208                col1 INTEGER PRIMARY KEY NOT NULL,
1209                col2 STRING NOT NULL
1210            )
1211            ",
1212            &mut vm,
1213        )
1214        .unwrap();
1215
1216        check_single_statement(
1217            "
1218            INSERT INTO table1 VALUES
1219                (2, 'bar'),
1220                (3, 'aaa')
1221            ",
1222            &mut vm,
1223        )
1224        .unwrap();
1225
1226        let res = check_single_statement("SELECT * FROM table1", &mut vm)
1227            .unwrap()
1228            .unwrap();
1229
1230        assert_eq!(
1231            res.columns().cloned().collect::<Vec<_>>(),
1232            vec![
1233                Column::new(
1234                    "col1".into(),
1235                    DataType::Int(None),
1236                    vec![
1237                        ColumnOptionDef {
1238                            name: None,
1239                            option: ColumnOption::Unique { is_primary: true },
1240                        },
1241                        ColumnOptionDef {
1242                            name: None,
1243                            option: ColumnOption::NotNull
1244                        }
1245                    ],
1246                    false
1247                ),
1248                Column::new(
1249                    "col2".into(),
1250                    DataType::String,
1251                    vec![ColumnOptionDef {
1252                        name: None,
1253                        option: ColumnOption::NotNull
1254                    }],
1255                    false
1256                ),
1257            ]
1258        );
1259
1260        assert_eq!(
1261            res.all_data(),
1262            vec![
1263                Row::new(vec![Value::Int64(2), Value::String("bar".to_owned())]),
1264                Row::new(vec![Value::Int64(3), Value::String("aaa".to_owned())])
1265            ]
1266        );
1267
1268        let res = check_single_statement("SELECT * FROM table1 WHERE col1 = 2", &mut vm)
1269            .unwrap()
1270            .unwrap();
1271        assert_eq!(
1272            res.all_data(),
1273            vec![Row::new(vec![
1274                Value::Int64(2),
1275                Value::String("bar".to_owned())
1276            ])]
1277        );
1278
1279        let res = check_single_statement("SELECT * FROM table1 WHERE col1 = 1", &mut vm)
1280            .unwrap()
1281            .unwrap();
1282        assert_eq!(res.all_data(), vec![]);
1283
1284        let res = check_single_statement("SELECT * FROM table1 WHERE col1 = 2", &mut vm)
1285            .unwrap()
1286            .unwrap();
1287        assert_eq!(
1288            res.all_data(),
1289            vec![Row::new(vec![
1290                Value::Int64(2),
1291                Value::String("bar".to_owned())
1292            ])]
1293        );
1294
1295        let res =
1296            check_single_statement("SELECT * FROM table1 WHERE col1 = 2 or col1 = 3", &mut vm)
1297                .unwrap()
1298                .unwrap();
1299        assert_eq!(
1300            res.all_data(),
1301            vec![
1302                Row::new(vec![Value::Int64(2), Value::String("bar".to_owned())]),
1303                Row::new(vec![Value::Int64(3), Value::String("aaa".to_owned())])
1304            ]
1305        );
1306
1307        let res = check_single_statement("SELECT col1 FROM table1", &mut vm)
1308            .unwrap()
1309            .unwrap();
1310        assert_eq!(
1311            res.all_data(),
1312            vec![
1313                Row::new(vec![Value::Int64(2)]),
1314                Row::new(vec![Value::Int64(3)])
1315            ]
1316        );
1317
1318        let res = check_single_statement("SELECT col1, col2 FROM table1", &mut vm)
1319            .unwrap()
1320            .unwrap();
1321        assert_eq!(
1322            res.all_data(),
1323            vec![
1324                Row::new(vec![Value::Int64(2), Value::String("bar".to_owned())]),
1325                Row::new(vec![Value::Int64(3), Value::String("aaa".to_owned())])
1326            ]
1327        );
1328
1329        let res = check_single_statement("SELECT * FROM table1 ORDER BY col1", &mut vm)
1330            .unwrap()
1331            .unwrap();
1332        assert_eq!(
1333            res.all_data(),
1334            vec![
1335                Row::new(vec![Value::Int64(2), Value::String("bar".to_owned())]),
1336                Row::new(vec![Value::Int64(3), Value::String("aaa".to_owned())])
1337            ]
1338        );
1339
1340        let res = check_single_statement("SELECT * FROM table1 ORDER BY col2", &mut vm)
1341            .unwrap()
1342            .unwrap();
1343        assert_eq!(
1344            res.all_data(),
1345            vec![
1346                Row::new(vec![Value::Int64(3), Value::String("aaa".to_owned())]),
1347                Row::new(vec![Value::Int64(2), Value::String("bar".to_owned())]),
1348            ]
1349        );
1350
1351        let res = check_single_statement("SELECT * FROM table1 ORDER BY col1 DESC", &mut vm)
1352            .unwrap()
1353            .unwrap();
1354        assert_eq!(
1355            res.all_data(),
1356            vec![
1357                Row::new(vec![Value::Int64(3), Value::String("aaa".to_owned())]),
1358                Row::new(vec![Value::Int64(2), Value::String("bar".to_owned())]),
1359            ]
1360        );
1361
1362        let res = check_single_statement("SELECT * FROM table1 ORDER BY col2 DESC", &mut vm)
1363            .unwrap()
1364            .unwrap();
1365        assert_eq!(
1366            res.all_data(),
1367            vec![
1368                Row::new(vec![Value::Int64(2), Value::String("bar".to_owned())]),
1369                Row::new(vec![Value::Int64(3), Value::String("aaa".to_owned())]),
1370            ]
1371        );
1372
1373        let res = check_single_statement("SELECT * FROM table1 ORDER BY col1 LIMIT 1", &mut vm)
1374            .unwrap()
1375            .unwrap();
1376        assert_eq!(
1377            res.all_data(),
1378            vec![Row::new(vec![
1379                Value::Int64(2),
1380                Value::String("bar".to_owned())
1381            ]),]
1382        );
1383
1384        let res = check_single_statement("SELECT col3 FROM table1", &mut vm);
1385        assert_eq!(
1386            res.unwrap_err(),
1387            RuntimeError::ExprExecError(ExprExecError::NoSuchColumn("col3".into()))
1388        );
1389
1390        let res = check_single_statement("SELECT col1 FROM table2", &mut vm);
1391        assert_eq!(
1392            res.unwrap_err(),
1393            RuntimeError::TableNotFound(TableRef {
1394                schema_name: None,
1395                table_name: "table2".into()
1396            })
1397        );
1398
1399        let res = check_single_statement("SELECT col1 FROM table1 ORDER BY col3", &mut vm);
1400        assert_eq!(
1401            res.unwrap_err(),
1402            RuntimeError::ExprExecError(ExprExecError::NoSuchColumn("col3".into()))
1403        );
1404
1405        let res = check_single_statement("SELECT col1 FROM table1 WHERE col3 = 1", &mut vm);
1406        assert_eq!(
1407            res.unwrap_err(),
1408            RuntimeError::ExprExecError(ExprExecError::NoSuchColumn("col3".into()))
1409        );
1410
1411        let res = check_single_statement("SELECT col1 FROM table1 WHERE col1 + 1", &mut vm);
1412        assert_eq!(
1413            res.unwrap_err(),
1414            RuntimeError::FilterWithNonBoolean(
1415                Expr::Binary {
1416                    left: Box::new(Expr::ColumnRef(ColumnRef {
1417                        schema_name: None,
1418                        table_name: None,
1419                        col_name: "col1".into()
1420                    })),
1421                    op: BinOp::Plus,
1422                    right: Box::new(Expr::Value(Value::Int64(1)))
1423                },
1424                Value::Int64(3)
1425            )
1426        );
1427    }
1428}