use std::collections::HashMap;
use gitql_ast::expression::Expression;
use gitql_ast::expression::ExpressionKind;
use gitql_ast::expression::StringExpression;
use gitql_ast::expression::StringValueType;
use gitql_ast::operator::PrefixUnaryOperator;
use gitql_ast::statement::TableSelection;
use gitql_core::environment::Environment;
use gitql_core::signature::Signature;
use gitql_core::types::DataType;
use crate::diagnostic::Diagnostic;
use crate::format_checker::is_valid_date_format;
use crate::format_checker::is_valid_datetime_format;
use crate::format_checker::is_valid_time_format;
use crate::tokenizer::Location;
/// Outcome of comparing the types of two operands (produced by `are_types_equals`).
pub enum TypeCheckResult {
    /// Both operands already have equal types.
    Equals,
    /// The types differ and neither operand can be implicitly cast to the other's type.
    NotEqualAndCantImplicitCast,
    /// An implicit cast was attempted but is invalid; carries the diagnostic to report.
    Error(Box<Diagnostic>),
    /// The right operand can be implicitly cast; carries its replacement expression.
    RightSideCasted(Box<dyn Expression>),
    /// The left operand can be implicitly cast; carries its replacement expression.
    LeftSideCasted(Box<dyn Expression>),
}
/// Text literal values that `is_expression_type_equals` accepts for an implicit cast to
/// Boolean (exact, case-sensitive membership check).
const BOOLEANS_VALUES_LITERAL: [&str; 10] =
    ["t", "true", "y", "yes", "1", "f", "false", "n", "no", "0"];
/// Outcome of checking a single expression against an expected type
/// (produced by `is_expression_type_equals` and `type_check_prefix_unary`).
pub enum ExprTypeCheckResult {
    /// The expression already has the expected type.
    Equals,
    /// The types differ and the expression cannot be implicitly cast.
    NotEqualAndCantImplicitCast,
    /// Type checking failed; carries the diagnostic to report.
    Error(Box<Diagnostic>),
    /// The expression can be implicitly cast; carries its replacement expression.
    ImplicitCasted(Box<dyn Expression>),
}
#[allow(clippy::borrowed_box)]
pub fn is_expression_type_equals(
scope: &Environment,
expr: &Box<dyn Expression>,
data_type: &DataType,
) -> ExprTypeCheckResult {
let expr_type = expr.expr_type(scope);
if expr_type == *data_type {
return ExprTypeCheckResult::Equals;
}
if expr.kind() != ExpressionKind::String || !expr_type.is_text() {
return ExprTypeCheckResult::NotEqualAndCantImplicitCast;
}
if data_type.is_time() || data_type.is_variant_with(&DataType::Time) {
let literal = expr.as_any().downcast_ref::<StringExpression>().unwrap();
let string_literal_value = &literal.value;
if !is_valid_time_format(string_literal_value) {
return ExprTypeCheckResult::Error(
Diagnostic::error(&format!(
"Can't compare Time and Text `{}` because it can't be implicitly casted to Time",
string_literal_value
)).add_help("A valid Time format must match `HH:MM:SS` or `HH:MM:SS.SSS`")
.add_help("You can use `MAKETIME(hour, minute, second)` function to create date value")
.as_boxed(),
);
}
return ExprTypeCheckResult::ImplicitCasted(Box::new(StringExpression {
value: string_literal_value.to_owned(),
value_type: StringValueType::Time,
}));
}
if data_type.is_date() || data_type.is_variant_with(&DataType::Date) {
let literal = expr.as_any().downcast_ref::<StringExpression>().unwrap();
let string_literal_value = &literal.value;
if !is_valid_date_format(string_literal_value) {
return ExprTypeCheckResult::Error(
Diagnostic::error(&format!(
"Can't compare Date and Text `{}` because it can't be implicitly casted to Date",
string_literal_value
)).add_help("A valid Date format must match `YYYY-MM-DD`")
.add_help("You can use `MAKEDATE(year, dayOfYear)` function to a create date value")
.as_boxed(),
);
}
return ExprTypeCheckResult::ImplicitCasted(Box::new(StringExpression {
value: string_literal_value.to_owned(),
value_type: StringValueType::Date,
}));
}
if data_type.is_datetime() || data_type.is_variant_with(&DataType::DateTime) {
let literal = expr.as_any().downcast_ref::<StringExpression>().unwrap();
let string_literal_value = &literal.value;
if !is_valid_datetime_format(string_literal_value) {
return ExprTypeCheckResult::Error(
Diagnostic::error(&format!(
"Can't compare DateTime and Text `{}` because it can't be implicitly casted to DateTime",
string_literal_value
)).add_help("A valid DateTime format must match one of the values `YYYY-MM-DD HH:MM:SS` or `YYYY-MM-DD HH:MM:SS.SSS`")
.as_boxed(),
);
}
return ExprTypeCheckResult::ImplicitCasted(Box::new(StringExpression {
value: string_literal_value.to_owned(),
value_type: StringValueType::DateTime,
}));
}
if data_type.is_bool() || data_type.is_variant_with(&DataType::Boolean) {
let literal = expr.as_any().downcast_ref::<StringExpression>().unwrap();
let string_literal_value = &literal.value;
if !BOOLEANS_VALUES_LITERAL.contains(&string_literal_value.as_str()) {
return ExprTypeCheckResult::Error(
Diagnostic::error(&format!(
"Can't compare Boolean and Text `{}` because it can't be implicitly casted to Boolean",
string_literal_value
)).add_help("A valid Boolean value must match `t, true, y, yes, 1, f, false, n, no, 0`")
.as_boxed(),
);
}
return ExprTypeCheckResult::ImplicitCasted(Box::new(StringExpression {
value: string_literal_value.to_owned(),
value_type: StringValueType::Boolean,
}));
}
ExprTypeCheckResult::NotEqualAndCantImplicitCast
}
#[allow(clippy::borrowed_box)]
pub fn are_types_equals(
scope: &Environment,
lhs: &Box<dyn Expression>,
rhs: &Box<dyn Expression>,
) -> TypeCheckResult {
let lhs_type = lhs.expr_type(scope);
let rhs_type = rhs.expr_type(scope);
if lhs_type == rhs_type {
return TypeCheckResult::Equals;
}
match is_expression_type_equals(scope, rhs, &lhs_type) {
ExprTypeCheckResult::ImplicitCasted(expr) => {
return TypeCheckResult::RightSideCasted(expr);
}
ExprTypeCheckResult::Error(diagnostic) => {
return TypeCheckResult::Error(diagnostic);
}
_ => {}
}
match is_expression_type_equals(scope, lhs, &rhs_type) {
ExprTypeCheckResult::ImplicitCasted(expr) => {
return TypeCheckResult::LeftSideCasted(expr);
}
ExprTypeCheckResult::Error(diagnostic) => {
return TypeCheckResult::Error(diagnostic);
}
_ => {}
}
TypeCheckResult::NotEqualAndCantImplicitCast
}
/// Returns the type shared by every expression in `arguments`, or `None` as soon as one
/// argument's type differs from the first argument's type.
///
/// An empty argument list yields `Some(DataType::Any)`.
pub fn check_all_values_are_same_type(
    env: &mut Environment,
    arguments: &[Box<dyn Expression>],
) -> Option<DataType> {
    // Only read access is needed below.
    let env: &Environment = env;
    let (first, rest) = match arguments.split_first() {
        Some(parts) => parts,
        None => return Some(DataType::Any),
    };

    let common_type = first.expr_type(env);
    let all_match = rest
        .iter()
        .all(|argument| common_type == argument.expr_type(env));
    if all_match {
        Some(common_type)
    } else {
        None
    }
}
pub fn check_function_call_arguments(
env: &Environment,
arguments: &mut [Box<dyn Expression>],
parameters: &[DataType],
function_name: String,
location: Location,
) -> Result<(), Box<Diagnostic>> {
let parameters_count = parameters.len();
let arguments_count = arguments.len();
let mut has_varargs_parameter = false;
let mut optional_parameters_count = 0;
if parameters_count != 0 {
let last_parameter = parameters.last().unwrap();
has_varargs_parameter = last_parameter.is_varargs();
for parameter_type in parameters.iter().take(parameters_count) {
if parameter_type.is_optional() {
optional_parameters_count += 1;
}
}
}
let mut min_arguments_count = parameters_count - optional_parameters_count;
if has_varargs_parameter {
min_arguments_count -= 1;
}
if arguments_count < min_arguments_count {
return Err(Diagnostic::error(&format!(
"Function `{}` expects at least `{}` arguments but got `{}`",
function_name, min_arguments_count, arguments_count
))
.with_location(location)
.as_boxed());
}
if !has_varargs_parameter && arguments_count > parameters_count {
return Err(Diagnostic::error(&format!(
"Function `{}` expects `{}` arguments but got `{}`",
function_name, arguments_count, parameters_count
))
.with_location(location)
.as_boxed());
}
for index in 0..min_arguments_count {
let parameter_type = parameters.get(index).unwrap();
let argument = arguments.get(index).unwrap();
if argument.expr_type(env).is_undefined() {
return Err(Diagnostic::error(&format!(
"Function `{}` argument number {} has Undefined type",
function_name, index,
))
.add_help("Make sure you used a correct field name")
.add_help("Check column names for each table from docs website")
.with_location(location)
.as_boxed());
}
match is_expression_type_equals(env, argument, parameter_type) {
ExprTypeCheckResult::ImplicitCasted(new_expr) => {
arguments[index] = new_expr;
}
ExprTypeCheckResult::NotEqualAndCantImplicitCast => {
let argument_type = argument.expr_type(env);
return Err(Diagnostic::error(&format!(
"Function `{}` argument number {} with type `{}` don't match expected type `{}`",
function_name, index, argument_type, parameter_type
))
.with_location(location).as_boxed());
}
ExprTypeCheckResult::Error(error) => return Err(error),
ExprTypeCheckResult::Equals => {}
}
}
let last_optional_param_index = min_arguments_count + optional_parameters_count;
for index in min_arguments_count..last_optional_param_index {
if index >= arguments_count {
return Ok(());
}
let parameter_type = parameters.get(index).unwrap();
let argument = arguments.get(index).unwrap();
if argument.expr_type(env).is_undefined() {
return Err(Diagnostic::error(&format!(
"Function `{}` argument number {} has Undefined type",
function_name, index,
))
.add_help("Make sure you used a correct field name")
.add_help("Check column names for each table from docs website")
.with_location(location)
.as_boxed());
}
match is_expression_type_equals(env, argument, parameter_type) {
ExprTypeCheckResult::ImplicitCasted(new_expr) => {
arguments[index] = new_expr;
}
ExprTypeCheckResult::NotEqualAndCantImplicitCast => {
let argument_type = argument.expr_type(env);
return Err(Diagnostic::error(&format!(
"Function `{}` argument number {} with type `{}` don't match expected type `{}`",
function_name, index, argument_type, parameter_type
))
.with_location(location).as_boxed());
}
ExprTypeCheckResult::Error(error) => return Err(error),
ExprTypeCheckResult::Equals => {}
}
}
if has_varargs_parameter {
let varargs_type = parameters.last().unwrap();
for index in last_optional_param_index..arguments_count {
let argument = arguments.get(index).unwrap();
if argument.expr_type(env).is_undefined() {
return Err(Diagnostic::error(&format!(
"Function `{}` argument number {} has Undefined type",
function_name, index,
))
.add_help("Make sure you used a correct field name")
.add_help("Check column names for each table from docs website")
.with_location(location)
.as_boxed());
}
match is_expression_type_equals(env, argument, varargs_type) {
ExprTypeCheckResult::ImplicitCasted(new_expr) => {
arguments[index] = new_expr;
}
ExprTypeCheckResult::NotEqualAndCantImplicitCast => {
let argument_type = argument.expr_type(env);
return Err(Diagnostic::error(&format!(
"Function `{}` argument number {} with type `{}` don't match expected type `{}`",
function_name, index, argument_type, varargs_type
))
.with_location(location).as_boxed());
}
ExprTypeCheckResult::Error(error) => return Err(error),
ExprTypeCheckResult::Equals => {}
}
}
}
Ok(())
}
/// Assigns every selected column to the selected table(s) that provide it.
///
/// Returns one [`TableSelection`] per entry of `selected_tables`, in the same order. A
/// column provided by several selected tables is recorded under each of them.
///
/// A column that no selected table provides, but that `env.resolve_type` resolves to a
/// non-undefined type, is attached to the first selection. If there are no selected
/// tables, a new selection with an empty table name is created for it.
///
/// # Errors
///
/// Returns a diagnostic located at `location` for the first column that neither a selected
/// table nor the environment can resolve.
///
/// # Panics
///
/// Panics if a name in `selected_tables` has no entry in
/// `env.schema.tables_fields_names`.
pub fn type_check_and_classify_selected_fields(
    env: &mut Environment,
    selected_tables: &[String],
    selected_columns: &[String],
    location: Location,
) -> Result<Vec<TableSelection>, Box<Diagnostic>> {
    let mut table_selections: Vec<TableSelection> = Vec::with_capacity(selected_tables.len());
    let mut table_index: HashMap<String, usize> = HashMap::with_capacity(selected_tables.len());
    for (index, table) in selected_tables.iter().enumerate() {
        table_selections.push(TableSelection {
            table_name: table.to_string(),
            columns_names: vec![],
        });
        table_index.insert(table.to_string(), index);
    }

    for selected_column in selected_columns {
        let mut is_column_resolved = false;

        // Deliberately no early `break`: a column shared by several selected tables is
        // recorded under each of them.
        for table in selected_tables {
            let table_columns = env.schema.tables_fields_names.get(table.as_str()).unwrap();
            if table_columns.contains(&selected_column.as_str()) {
                is_column_resolved = true;
                let table_selection_index = *table_index.get(table).unwrap();
                table_selections[table_selection_index]
                    .columns_names
                    .push(selected_column.to_string());
            }
        }

        if is_column_resolved {
            continue;
        }

        // Not a table column: accept it if the environment resolves it to a defined type.
        if let Some(data_type) = env.resolve_type(selected_column) {
            if !data_type.is_undefined() {
                if table_selections.is_empty() {
                    table_selections.push(TableSelection {
                        table_name: selected_tables.first().cloned().unwrap_or_default(),
                        columns_names: vec![selected_column.to_string()],
                    });
                } else {
                    table_selections[0]
                        .columns_names
                        .push(selected_column.to_string());
                }
                continue;
            }
        }

        return Err(Diagnostic::error(&format!(
            "Column `{}` not exists in any of the selected tables",
            selected_column
        ))
        .add_help("Check the documentations to see available fields for each tables")
        .with_location(location)
        .as_boxed());
    }

    Ok(table_selections)
}
pub fn type_check_projection_symbols(
env: &mut Environment,
selected_tables: &[String],
projection_names: &[String],
projection_locations: &[Location],
) -> Result<(), Box<Diagnostic>> {
for (index, selected_column) in projection_names.iter().enumerate() {
let mut is_column_resolved = false;
for table in selected_tables {
let table_columns = env.schema.tables_fields_names.get(table.as_str()).unwrap();
if table_columns.contains(&selected_column.as_str()) {
is_column_resolved = true;
break;
}
}
if !is_column_resolved {
return Err(Diagnostic::error(&format!(
"Column `{}` not exists in any of the selected tables",
selected_column
))
.add_help("Check the documentations to see available fields for each tables")
.with_location(projection_locations[index])
.as_boxed());
}
}
Ok(())
}
#[allow(clippy::borrowed_box)]
pub fn type_check_prefix_unary(
env: &Environment,
right: &Box<dyn Expression>,
op: &PrefixUnaryOperator,
location: Location,
) -> ExprTypeCheckResult {
let right_type = right.expr_type(env);
let expected_type = prefix_unary_expected_type(op);
if *op == PrefixUnaryOperator::Bang {
return is_expression_type_equals(env, right, &expected_type);
}
if *op == PrefixUnaryOperator::Minus {
if !right_type.is_number() {
return ExprTypeCheckResult::Error(type_mismatch_error(
location,
expected_type,
right_type,
));
}
return ExprTypeCheckResult::Equals;
}
if *op == PrefixUnaryOperator::Not {
if !right_type.is_int() {
return ExprTypeCheckResult::Error(type_mismatch_error(
location,
expected_type,
right_type,
));
}
return ExprTypeCheckResult::Equals;
}
ExprTypeCheckResult::Equals
}
/// Returns the operand type that the prefix unary operator `op` expects.
///
/// `Bang` expects a Boolean, `Not` expects an Integer, and `Minus` expects a variant of
/// Integer or Float.
#[inline(always)]
pub fn prefix_unary_expected_type(op: &PrefixUnaryOperator) -> DataType {
    match *op {
        PrefixUnaryOperator::Bang => DataType::Boolean,
        PrefixUnaryOperator::Not => DataType::Integer,
        PrefixUnaryOperator::Minus => {
            let numeric_types = vec![DataType::Integer, DataType::Float];
            DataType::Variant(numeric_types)
        }
    }
}
/// Resolves the return type of a call to the function described by `signature`.
///
/// A non-dynamic return type is returned unchanged. For a `DataType::Dynamic` return type,
/// the signature's type calculator is first applied to the declared parameter types. If
/// that yields a variant and the call has arguments, the calculator is applied again to the
/// actual argument types, and that result is returned.
pub fn resolve_call_expression_return_type(
    env: &Environment,
    signature: &Signature,
    arguments: &[Box<dyn Expression>],
) -> DataType {
    let declared_type = signature.return_type.clone();
    let calculate_type = match declared_type {
        DataType::Dynamic(calculate_type) => calculate_type,
        static_type => return static_type,
    };

    let parameters_based_type = calculate_type(&signature.parameters);
    if arguments.is_empty() || !parameters_based_type.is_variant() {
        return parameters_based_type;
    }

    // Narrow the variant using the concrete types of the call's arguments.
    let arguments_types: Vec<DataType> = arguments
        .iter()
        .map(|argument| argument.expr_type(env))
        .collect();
    calculate_type(&arguments_types)
}
#[inline(always)]
pub fn type_mismatch_error(
location: Location,
expected: DataType,
actual: DataType,
) -> Box<Diagnostic> {
Diagnostic::error(&format!(
"Type mismatch expected `{}`, got `{}`",
expected, actual
))
.with_location(location)
.as_boxed()
}