use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::vec::Vec;
use crate::catalog::{
ColumnDesc, ColumnId, RootCatalog, TableRefId, DEFAULT_DATABASE_NAME, DEFAULT_SCHEMA_NAME,
};
use crate::parser::{Ident, ObjectName, Statement};
use crate::types::{DataTypeKind, DataValue};
mod expr_visitor;
mod expression;
pub(crate) mod statement;
mod table_ref;
pub use self::expr_visitor::*;
pub use self::expression::*;
pub use self::statement::*;
pub use self::table_ref::*;
/// A statement after binding, as returned by `Binder::bind`.
///
/// Each variant wraps the bound form of one kind of parsed statement.
#[derive(Debug, PartialEq, Clone)]
pub enum BoundStatement {
/// Bound form of a `Statement::CreateTable`.
CreateTable(BoundCreateTable),
/// Bound form of a `Statement::Drop`.
Drop(BoundDrop),
/// Bound form of a `Statement::Insert`.
Insert(BoundInsert),
/// Bound form of a `Statement::Copy`.
Copy(BoundCopy),
/// Bound form of a `Statement::Query`, stored boxed.
Select(Box<BoundSelect>),
/// Bound form of the statement nested inside a `Statement::Explain`.
Explain(Box<BoundStatement>),
/// Bound form of a `Statement::Delete`, stored boxed.
Delete(Box<BoundDelete>),
}
/// Errors reported while binding a statement.
///
/// The `#[error]` attribute on each variant defines its display message; the
/// payload, where present, is interpolated into that message.
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum BindError {
/// A database name was rejected; the payload is shown in the message.
#[error("invalid database {0}")]
InvalidDatabase(String),
/// A schema name was rejected; the payload is shown in the message.
#[error("invalid schema {0}")]
InvalidSchema(String),
/// A table name was rejected; the payload is shown in the message.
#[error("invalid table {0}")]
InvalidTable(String),
/// A column name was rejected; the payload is shown in the message.
#[error("invalid column {0}")]
InvalidColumn(String),
/// A table was given more than once; the payload is shown in the message.
#[error("duplicated table {0}")]
DuplicatedTable(String),
/// A column was given more than once; the payload is shown in the message.
#[error("duplicated column {0}")]
DuplicatedColumn(String),
/// An expression could not be bound; the payload describes it.
#[error("invalid expression: {0}")]
InvalidExpression(String),
/// A column that is not nullable was involved; the payload is shown in the message.
#[error("not nullable column: {0}")]
NotNullableColumn(String),
/// The two sides of a binary operator disagree.
// NOTE(review): the two strings are presumably the operand type names — confirm at the raise site.
#[error("binary operator types mismatch: {0} != {1}")]
BinaryOpTypeMismatch(String, String),
/// A column reference was ambiguous.
#[error("ambiguous column")]
AmbiguousColumn,
/// Returned by `split_name` when an object name does not consist of one,
/// two or three parts; carries a copy of the offending identifiers.
#[error("invalid table name: {0:?}")]
InvalidTableName(Vec<Ident>),
/// Returned by `Binder::bind` for `ShowVariable`, `ShowCreate` and
/// `ShowColumns` statements.
#[error("SQL not supported")]
NotSupportedTSQL,
/// Returned by `Binder::bind` for any statement kind it has no arm for.
#[error("invalid SQL")]
InvalidSQL,
/// A value could not be cast to the given type kind.
#[error("cannot cast {0:?} to {1:?}")]
CastError(DataValue, DataTypeKind),
/// A function-binding failure; the payload is the entire display message.
#[error("{0}")]
BindFunctionError(String),
}
/// State a `Binder` keeps for the context it is currently working in.
///
/// `Default` yields an empty context. `Binder::push_context` relies on that:
/// it moves the current context out with `std::mem::take`, which leaves a
/// default one in its place.
#[derive(Debug, Default)]
struct BinderContext {
// NOTE(review): the code that fills these fields is not in this file. The
// string keys are presumably table names (or table aliases) — confirm in the
// binder submodules before relying on that.
/// Table references by name.
regular_tables: HashMap<String, TableRefId>,
/// A set of column names per key.
column_names: HashMap<String, HashSet<String>>,
/// A list of column ids per key.
column_ids: HashMap<String, Vec<ColumnId>>,
/// A list of column descriptors per key.
column_descs: HashMap<String, Vec<ColumnDesc>>,
// NOTE(review): `aliases` and `aliases_expressions` look like parallel
// vectors (alias name at index i names the expression at index i) — confirm.
/// Alias names.
aliases: Vec<String>,
/// Bound expressions associated with aliases.
aliases_expressions: Vec<BoundExpr>,
}
/// Turns parsed `Statement`s into `BoundStatement`s (see `Binder::bind`).
pub struct Binder {
/// Shared catalog handle supplied to `Binder::new`.
catalog: Arc<RootCatalog>,
/// The context currently in use.
context: BinderContext,
/// Contexts saved by `push_context`, most recent last; `pop_context`
/// restores from the end.
upper_contexts: Vec<BinderContext>,
// NOTE(review): not used in this file — presumably names of the base tables
// referenced so far; confirm in the binder submodules.
base_table_refs: Vec<String>,
}
impl Binder {
    /// Creates a binder over `catalog` with an empty current context, no saved
    /// outer contexts and no recorded base table references.
    pub fn new(catalog: Arc<RootCatalog>) -> Self {
        Binder {
            catalog,
            upper_contexts: Vec::new(),
            context: BinderContext::default(),
            base_table_refs: Vec::new(),
        }
    }

    /// Saves the current context on the `upper_contexts` stack and leaves a
    /// fresh default context in its place.
    fn push_context(&mut self) {
        let saved = std::mem::take(&mut self.context);
        self.upper_contexts.push(saved);
    }

    /// Replaces the current context with the most recently saved one.
    ///
    /// # Panics
    ///
    /// Panics if no context has been saved, i.e. the call is not paired with
    /// an earlier `push_context`. That is a programming error in the binder,
    /// so the panic message names the broken invariant.
    fn pop_context(&mut self) {
        self.context = self
            .upper_contexts
            .pop()
            .expect("pop_context called without a matching push_context");
    }

    /// Binds a parsed statement, dispatching on its kind to the matching
    /// `bind_*` method and wrapping the result in a `BoundStatement`.
    ///
    /// `Explain` binds its inner statement recursively.
    ///
    /// # Errors
    ///
    /// * Any error produced by the per-statement `bind_*` method is propagated.
    /// * `BindError::NotSupportedTSQL` for `ShowVariable`, `ShowCreate` and
    ///   `ShowColumns` statements.
    /// * `BindError::InvalidSQL` for every other statement kind.
    pub fn bind(&mut self, stmt: &Statement) -> Result<BoundStatement, BindError> {
        match stmt {
            Statement::CreateTable { .. } => {
                Ok(BoundStatement::CreateTable(self.bind_create_table(stmt)?))
            }
            Statement::Drop { .. } => Ok(BoundStatement::Drop(self.bind_drop(stmt)?)),
            Statement::Insert { .. } => Ok(BoundStatement::Insert(self.bind_insert(stmt)?)),
            Statement::Delete { .. } => Ok(BoundStatement::Delete(self.bind_delete(stmt)?)),
            Statement::Copy { .. } => Ok(BoundStatement::Copy(self.bind_copy(stmt)?)),
            Statement::Query(query) => Ok(BoundStatement::Select(self.bind_select(query)?)),
            Statement::Explain { statement, .. } => {
                let inner = self.bind(statement)?;
                Ok(BoundStatement::Explain(Box::new(inner)))
            }
            Statement::ShowVariable { .. }
            | Statement::ShowCreate { .. }
            | Statement::ShowColumns { .. } => Err(BindError::NotSupportedTSQL),
            _ => Err(BindError::InvalidSQL),
        }
    }
}
/// Splits an object name into `(database, schema, table)` string slices.
///
/// A one-part name is the table, a two-part name is `schema.table`, and a
/// three-part name is `database.schema.table`. Missing leading parts fall back
/// to `DEFAULT_DATABASE_NAME` and `DEFAULT_SCHEMA_NAME`.
///
/// # Errors
///
/// Returns `BindError::InvalidTableName` carrying a clone of the name's parts
/// when the name has zero parts or more than three.
fn split_name(name: &ObjectName) -> Result<(&str, &str, &str), BindError> {
    match name.0.as_slice() {
        [table] => Ok((
            DEFAULT_DATABASE_NAME,
            DEFAULT_SCHEMA_NAME,
            table.value.as_str(),
        )),
        [schema, table] => Ok((
            DEFAULT_DATABASE_NAME,
            schema.value.as_str(),
            table.value.as_str(),
        )),
        [database, schema, table] => Ok((
            database.value.as_str(),
            schema.value.as_str(),
            table.value.as_str(),
        )),
        _ => Err(BindError::InvalidTableName(name.0.clone())),
    }
}
/// Returns a new object name whose parts are the lowercased parts of `name`.
///
/// Every part is rebuilt with `Ident::new` from the lowercased `value` alone;
/// nothing else is read from the original identifier.
fn lower_case_name(name: &ObjectName) -> ObjectName {
    let mut lowered = Vec::with_capacity(name.0.len());
    for part in &name.0 {
        lowered.push(Ident::new(part.value.to_lowercase()));
    }
    ObjectName(lowered)
}