use crate::{QError, QErrorKind, QSource};
use dibs_query_schema::{Meta, Span};
use std::sync::Arc;
/// Kind of value a filter function accepts at one argument position.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ArgSpec {
    /// Either a `$`-prefixed variable reference or a literal; no restriction.
    VariableOrLiteral,
    /// Must be a `$`-prefixed variable reference.
    Variable,
    /// Must be a literal (anything not starting with `$`).
    Literal,
}
/// One filter argument, classified by whether it starts with `$`.
#[derive(Clone, Debug)]
pub enum FilterArg {
    /// A variable reference; holds the text after the leading `$`.
    Variable(String),
    /// Any argument not starting with `$`, stored verbatim.
    Literal(String),
}
impl FilterArg {
    /// Classifies `s`. This never fails.
    ///
    /// If `s` starts with `$`, exactly that one `$` is removed and the rest becomes a
    /// [`FilterArg::Variable`], which may be empty for a lone `$`. Otherwise the whole of
    /// `s` becomes a [`FilterArg::Literal`].
    pub fn parse(s: &str) -> Self {
        match s.strip_prefix('$') {
            Some(name) => Self::Variable(name.to_owned()),
            None => Self::Literal(s.to_owned()),
        }
    }
    /// Returns `true` if this is a [`FilterArg::Variable`].
    pub fn is_variable(&self) -> bool {
        matches!(self, Self::Variable(..))
    }
    /// Returns `true` if this is a [`FilterArg::Literal`].
    pub fn is_literal(&self) -> bool {
        matches!(self, Self::Literal(..))
    }
    /// Returns the stored text: the variable name without `$`, or the literal as written.
    pub fn as_str(&self) -> &str {
        let (Self::Variable(text) | Self::Literal(text)) = self;
        text
    }
}
pub struct FunctionSpec {
pub name: &'static str,
pub args: &'static [ArgSpec],
}
impl FunctionSpec {
    /// Parses the raw arguments of one filter call and checks them against this spec.
    ///
    /// Each argument is classified with [`FilterArg::parse`] and then compared with the
    /// [`ArgSpec`] at the same position. The parsed arguments are returned in input order.
    ///
    /// # Errors
    ///
    /// - [`QErrorKind::InvalidFilterArgCount`] when `args.len()` differs from `self.args.len()`.
    /// - [`QErrorKind::InvalidFilterArgType`] for the first argument that is a literal where a
    ///   variable is required, or a variable where a literal is required. The index in the
    ///   message is 0-based.
    ///
    /// Every error is reported at `span`.
    pub fn parse_args(
        &self,
        source: Arc<QSource>,
        span: Span,
        args: &[Meta<String>],
    ) -> Result<Vec<FilterArg>, QError> {
        let expected = self.args.len();
        let actual = args.len();
        if actual != expected {
            return Err(QError {
                source,
                span,
                kind: QErrorKind::InvalidFilterArgCount {
                    filter: self.name.to_string(),
                    expected,
                    actual,
                },
            });
        }

        let mut out = Vec::with_capacity(actual);
        for (index, (raw, expected_kind)) in args.iter().zip(self.args).enumerate() {
            let arg = FilterArg::parse(raw.as_str());
            // `Some(reason)` only when the argument's kind contradicts the spec.
            let violation = match expected_kind {
                ArgSpec::Variable if !arg.is_variable() => Some(format!(
                    "argument {index} must be a variable reference (starting with $), got literal"
                )),
                ArgSpec::Literal if !arg.is_literal() => Some(format!(
                    "argument {index} must be a literal value, got variable reference"
                )),
                ArgSpec::Variable | ArgSpec::Literal | ArgSpec::VariableOrLiteral => None,
            };
            if let Some(reason) = violation {
                return Err(QError {
                    source: Arc::clone(&source),
                    span,
                    kind: QErrorKind::InvalidFilterArgType {
                        filter: self.name.to_string(),
                        reason,
                    },
                });
            }
            out.push(arg);
        }
        Ok(out)
    }
}
pub const EQ_SPEC: FunctionSpec = FunctionSpec {
name: "eq",
args: &[ArgSpec::VariableOrLiteral],
};
pub const NE_SPEC: FunctionSpec = FunctionSpec {
name: "ne",
args: &[ArgSpec::VariableOrLiteral],
};
pub const LT_SPEC: FunctionSpec = FunctionSpec {
name: "lt",
args: &[ArgSpec::VariableOrLiteral],
};
pub const LTE_SPEC: FunctionSpec = FunctionSpec {
name: "lte",
args: &[ArgSpec::VariableOrLiteral],
};
pub const GT_SPEC: FunctionSpec = FunctionSpec {
name: "gt",
args: &[ArgSpec::VariableOrLiteral],
};
pub const GTE_SPEC: FunctionSpec = FunctionSpec {
name: "gte",
args: &[ArgSpec::VariableOrLiteral],
};
pub const LIKE_SPEC: FunctionSpec = FunctionSpec {
name: "like",
args: &[ArgSpec::VariableOrLiteral],
};
pub const ILIKE_SPEC: FunctionSpec = FunctionSpec {
name: "ilike",
args: &[ArgSpec::VariableOrLiteral],
};
pub const IN_SPEC: FunctionSpec = FunctionSpec {
name: "in",
args: &[ArgSpec::VariableOrLiteral],
};
pub const JSON_GET_SPEC: FunctionSpec = FunctionSpec {
name: "json-get",
args: &[ArgSpec::VariableOrLiteral],
};
pub const JSON_GET_TEXT_SPEC: FunctionSpec = FunctionSpec {
name: "json-get-text",
args: &[ArgSpec::VariableOrLiteral],
};
pub const CONTAINS_SPEC: FunctionSpec = FunctionSpec {
name: "contains",
args: &[ArgSpec::VariableOrLiteral],
};
pub const KEY_EXISTS_SPEC: FunctionSpec = FunctionSpec {
name: "key-exists",
args: &[ArgSpec::VariableOrLiteral],
};
pub const EQ_BARE_SPEC: FunctionSpec = FunctionSpec {
name: "eq-bare",
args: &[ArgSpec::VariableOrLiteral],
};
use dibs_query_schema::FilterValue;
fn get_spec_and_args(
filter_value: &FilterValue,
) -> Option<(&'static FunctionSpec, &[Meta<String>])> {
match filter_value {
FilterValue::Null | FilterValue::NotNull => None,
FilterValue::Eq(args) => Some((&EQ_SPEC, args)),
FilterValue::Ne(args) => Some((&NE_SPEC, args)),
FilterValue::Lt(args) => Some((<_SPEC, args)),
FilterValue::Lte(args) => Some((<E_SPEC, args)),
FilterValue::Gt(args) => Some((>_SPEC, args)),
FilterValue::Gte(args) => Some((>E_SPEC, args)),
FilterValue::Like(args) => Some((&LIKE_SPEC, args)),
FilterValue::Ilike(args) => Some((&ILIKE_SPEC, args)),
FilterValue::In(args) => Some((&IN_SPEC, args)),
FilterValue::JsonGet(args) => Some((&JSON_GET_SPEC, args)),
FilterValue::JsonGetText(args) => Some((&JSON_GET_TEXT_SPEC, args)),
FilterValue::Contains(args) => Some((&CONTAINS_SPEC, args)),
FilterValue::KeyExists(args) => Some((&KEY_EXISTS_SPEC, args)),
FilterValue::EqBare(_) => None, }
}
/// Validates the arguments of one filter expression.
///
/// Returns `Ok(Some(args))` with the parsed arguments when the filter has a [`FunctionSpec`].
/// Returns `Ok(None)` for `FilterValue::Null`, `FilterValue::NotNull` and
/// `FilterValue::EqBare`, which are not checked against a spec.
///
/// # Errors
///
/// Propagates the first error from [`FunctionSpec::parse_args`]. Errors are reported at
/// `filter_span`.
pub fn validate_filter(
    source: Arc<QSource>,
    filter_span: Span,
    filter_value: &FilterValue,
) -> Result<Option<Vec<FilterArg>>, QError> {
    // There is intentionally no special case for a bare `EqBare` value. `FilterArg::parse`
    // cannot fail, so parsing the value and discarding the result validated nothing.
    get_spec_and_args(filter_value)
        .map(|(spec, args)| spec.parse_args(source, filter_span, args))
        .transpose()
}
use dibs_query_schema::Where;
use dibs_sql::ColumnName;
/// Validates every filter in a `where` clause, stopping at the first error.
///
/// Each filter is checked with [`validate_filter`], and errors point at the span of the
/// filter's column key. The parsed arguments are discarded.
pub fn validate_where(source: Arc<QSource>, where_clause: &Where) -> Result<(), QError> {
    where_clause.filters.iter().try_for_each(|(column, filter)| {
        validate_filter(Arc::clone(&source), column.span, filter).map(|_parsed| ())
    })
}
/// Validates every filter in a column-to-filter map, stopping at the first error.
///
/// This is the same check as [`validate_where`], but over a bare `IndexMap` instead of a
/// `Where`. Errors point at the span of the offending column key. The parsed arguments are
/// discarded.
pub fn validate_relation_where(
    source: Arc<QSource>,
    where_clause: &indexmap::IndexMap<Meta<ColumnName>, FilterValue>,
) -> Result<(), QError> {
    where_clause.iter().try_for_each(|(column, filter)| {
        validate_filter(Arc::clone(&source), column.span, filter).map(|_parsed| ())
    })
}
use dibs_query_schema::{Decl, FieldDef, QueryFile, SelectFields};
pub fn validate_query_file(source: Arc<QSource>, query_file: &QueryFile) -> Result<(), QError> {
for (_name_meta, decl) in &query_file.0 {
match decl {
Decl::Select(select) => {
if let Some(where_clause) = &select.where_clause {
validate_where(source.clone(), where_clause)?;
}
if let Some(fields) = &select.fields {
validate_select_fields(source.clone(), fields)?;
}
}
Decl::Update(update) => {
if let Some(where_clause) = &update.where_clause {
validate_where(source.clone(), where_clause)?;
}
}
Decl::Delete(delete) => {
if let Some(where_clause) = &delete.where_clause {
validate_where(source.clone(), where_clause)?;
}
}
Decl::Insert(_) | Decl::InsertMany(_) | Decl::Upsert(_) | Decl::UpsertMany(_) => {}
}
}
Ok(())
}
/// Recursively validates the `where` clauses of relation fields in a selection.
///
/// Only `FieldDef::Rel` entries are inspected. Each one's `where` clause is checked first,
/// then its nested fields are walked. Validation stops at the first error.
fn validate_select_fields(source: Arc<QSource>, fields: &SelectFields) -> Result<(), QError> {
    for (_, field_def) in &fields.fields {
        let relation = match field_def {
            Some(FieldDef::Rel(relation)) => relation,
            _ => continue,
        };
        if let Some(clause) = &relation.where_clause {
            validate_where(Arc::clone(&source), clause)?;
        }
        if let Some(nested) = &relation.fields {
            validate_select_fields(Arc::clone(&source), nested)?;
        }
    }
    Ok(())
}