use super::intermediate_ast::{OrderBy, SetExpression, Slice, TableExpression};
use crate::{sql::SelectStatementParser, Identifier, ParseError, ParseResult, ResourceId};
use serde::{Deserialize, Serialize};
use std::{fmt, ops::Deref, str::FromStr};
/// A parsed SQL `SELECT` statement.
///
/// Field order is significant: it determines the serialized layout and the
/// order in which the custom `Debug` implementation prints the parts.
#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)]
pub struct SelectStatement {
    /// The query body: result expressions, `FROM` tables, `WHERE` expression
    /// and `GROUP BY` list.
    pub expr: Box<SetExpression>,
    /// The `ORDER BY` items of the statement.
    pub order_by: Vec<OrderBy>,
    /// An optional result slice (presumably `LIMIT`/`OFFSET` — confirm against
    /// the `Slice` definition).
    pub slice: Option<Slice>,
}
impl fmt::Debug for SelectStatement {
    /// Pretty-prints the statement as its query body, `ORDER BY` list and
    /// slice, each rendered with the alternate (`{:#?}`) debug format.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            expr,
            order_by,
            slice,
        } = self;
        write!(
            f,
            "SelectStatement \n[{:#?},\n{:#?},\n{:#?}\n]",
            expr, order_by, slice
        )
    }
}
impl SelectStatement {
    /// Returns the tables referenced in the `FROM` clause of this statement.
    ///
    /// Tables written without an explicit schema are attributed to
    /// `default_schema`; qualified tables keep their own schema.
    pub fn get_table_references(&self, default_schema: Identifier) -> Vec<ResourceId> {
        match &*self.expr {
            // Only the `FROM` list matters here; the other clauses are ignored.
            SetExpression::Query { from, .. } => {
                convert_table_expr_to_resource_id_vector(&from[..], default_schema)
            }
        }
    }
}
impl FromStr for SelectStatement {
    type Err = crate::ParseError;

    /// Parses `query` as a SQL `SELECT` statement.
    ///
    /// # Errors
    ///
    /// Returns `ParseError::QueryParseError` carrying the parser's message when
    /// `query` does not match the grammar.
    fn from_str(query: &str) -> ParseResult<Self> {
        let parser = SelectStatementParser::new();
        match parser.parse(query) {
            Ok(statement) => Ok(statement),
            Err(err) => Err(ParseError::QueryParseError(err.to_string())),
        }
    }
}
/// Resolves each `FROM`-clause table expression to a [`ResourceId`].
///
/// A table written with an explicit schema (`schema.table`) keeps that schema.
/// An unqualified table is placed in `default_schema`. The output has one entry
/// per input expression, in the same order.
///
/// # Panics
///
/// Panics if `ResourceId::try_new` rejects a schema/table pair. Both parts are
/// taken from already-parsed identifiers, so a failure is treated as an
/// invariant violation rather than a recoverable error. (This assumes an
/// identifier's string form always re-parses — confirm against `ResourceId`.)
fn convert_table_expr_to_resource_id_vector(
    table_expressions: &[Box<TableExpression>],
    default_schema: Identifier,
) -> Vec<ResourceId> {
    table_expressions
        .iter()
        .map(|table_expression| match table_expression.deref() {
            TableExpression::Named { table, schema } => {
                // Fall back to the caller-supplied schema for unqualified tables.
                let schema = schema
                    .as_ref()
                    .map(|schema| schema.as_str())
                    .unwrap_or_else(|| default_schema.name());
                ResourceId::try_new(schema, table.as_str())
                    .expect("schema and table names are already-validated identifiers")
            }
        })
        // `map` preserves the exact size hint, so `collect` allocates once.
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::SelectStatementParser;

    /// Parses `sql` and resolves its table references against the schema
    /// named by `default_schema`.
    fn table_refs_of(sql: &str, default_schema: &str) -> Vec<ResourceId> {
        let statement = SelectStatementParser::new().parse(sql).unwrap();
        let schema = Identifier::try_new(default_schema).unwrap();
        statement.get_table_references(schema)
    }

    #[test]
    fn we_can_get_the_correct_table_references_using_a_default_schema() {
        let refs = table_refs_of("SELECT A FROM TAB WHERE C = 3", "ETH");
        assert_eq!(refs, [ResourceId::try_new("eth", "tab").unwrap()]);
    }

    #[test]
    fn we_can_get_the_correct_table_references_in_case_the_default_schema_equals_the_original_schema(
    ) {
        let refs = table_refs_of("SELECT A FROM SCHEMA.TAB WHERE C = 3", "SCHEMA");
        assert_eq!(refs, [ResourceId::try_new("schema", "tab").unwrap()]);
    }

    #[test]
    fn we_can_get_the_correct_table_references_in_case_the_default_schema_differs_from_the_original_schema(
    ) {
        // The explicit `SCHEMA.` qualifier must win over the default schema.
        let refs = table_refs_of("SELECT A FROM SCHEMA.TAB WHERE C = 3", " ETH ");
        assert_eq!(refs, [ResourceId::try_new("schema", "tab").unwrap()]);
    }
}