use crate::hyperql::{CompilerError, CompilerResult};
use audb::model::Query;
use hyperQL::ast::{Expression, SelectStatement, Statement};
/// Execution shape recognised for a HyperQL `SELECT` statement.
///
/// Every variant names the source `table` and carries only the clauses that shape has
/// room for.
#[derive(Clone, Debug, PartialEq)]
pub enum QueryPattern {
    /// Single-row lookup by `id`, keyed by the query parameter named `id_param`.
    PointQuery { table: String, id_param: String },
    /// Scan of `table` restricted by the conditions extracted from WHERE.
    FilterQuery {
        table: String,
        filters: Vec<FilterCondition>,
    },
    /// Scan returning only the listed `fields`, restricted by `filters`.
    ProjectionQuery {
        table: String,
        fields: Vec<String>,
        filters: Vec<FilterCondition>,
    },
    /// Filtered scan with ORDER BY keys and optional LIMIT / OFFSET.
    OrderedQuery {
        table: String,
        filters: Vec<FilterCondition>,
        order_by: Vec<OrderByClause>,
        limit: Option<u64>,
        offset: Option<u64>,
    },
    /// Traversal from `table` along the single relationship described by `traverse`.
    RelationshipQuery {
        table: String,
        traverse: TraverseInfo,
        filters: Vec<FilterCondition>,
    },
    /// Aggregate computation over `table`, optionally grouped by column names.
    AggregationQuery {
        table: String,
        aggregates: Vec<AggregateFunction>,
        filters: Vec<FilterCondition>,
        group_by: Vec<String>,
    },
    /// The whole `SELECT` AST, for shapes the specialised variants cannot represent.
    ComplexQuery { ast: Box<SelectStatement> },
}
/// One `field <operator> value` predicate taken from a WHERE clause.
#[derive(Clone, Debug, PartialEq)]
pub struct FilterCondition {
    /// Column the predicate tests.
    pub field: String,
    /// Comparison applied between the column and `value`.
    pub operator: FilterOperator,
    /// Right-hand operand: a literal constant or a named parameter.
    pub value: FilterValue,
}
/// Comparison operators a [`FilterCondition`] may use.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FilterOperator {
    /// `=`
    Equal,
    /// `!=` / `<>`
    NotEqual,
    /// `<`
    LessThan,
    /// `<=`
    LessThanOrEqual,
    /// `>`
    GreaterThan,
    /// `>=`
    GreaterThanOrEqual,
    /// `LIKE`
    Like,
    /// `IN`
    In,
}
/// Right-hand operand of a [`FilterCondition`].
#[derive(Clone, Debug, PartialEq)]
pub enum FilterValue {
    /// Reference to a query parameter, by name.
    Parameter(String),
    /// Constant written directly in the query text.
    Literal(LiteralValue),
}
/// Constant value appearing in a filter.
#[derive(Clone, Debug, PartialEq)]
pub enum LiteralValue {
    /// Text constant.
    String(String),
    /// Signed integer constant.
    Integer(i64),
    /// Floating-point constant.
    Float(f64),
    /// `TRUE` / `FALSE`.
    Boolean(bool),
    /// SQL `NULL`.
    Null,
}
/// One ORDER BY key.
#[derive(Clone, Debug, PartialEq)]
pub struct OrderByClause {
    /// Column to sort on.
    pub field: String,
    /// `true` for `DESC`, `false` otherwise.
    pub descending: bool,
}
/// Single relationship hop taken from a TRAVERSE clause.
#[derive(Clone, Debug, PartialEq)]
pub struct TraverseInfo {
    /// Relationship type followed.
    pub relationship: String,
    /// Label of the node reached at the end of the hop.
    pub target_table: String,
}
/// Aggregate computed by an aggregation query.
///
/// Every variant except `Count` names the column it aggregates.
#[derive(Clone, Debug, PartialEq)]
pub enum AggregateFunction {
    /// `COUNT(...)`.
    Count,
    /// `SUM(column)`.
    Sum(String),
    /// `AVG(column)`.
    Avg(String),
    /// `MIN(column)`.
    Min(String),
    /// `MAX(column)`.
    Max(String),
}
/// Stateless classifier that maps parsed HyperQL statements onto [`QueryPattern`]s.
#[derive(Debug, Clone, Copy)]
pub struct PatternMatcher;
impl PatternMatcher {
    /// Creates a matcher. It holds no state, so every instance classifies identically.
    pub fn new() -> Self {
        Self
    }

    /// Classifies a parsed statement into the [`QueryPattern`] used for code generation.
    ///
    /// `query` is consulted only to bind the key parameter of a point lookup.
    ///
    /// # Errors
    ///
    /// Returns [`CompilerError::UnsupportedPattern`] for any statement other than `SELECT`, and
    /// propagates errors raised while analysing the `SELECT` (for example a missing `FROM`, a
    /// subquery source, or an unsupported aggregate, ORDER BY or GROUP BY expression).
    pub fn analyze(&self, ast: &Statement, query: &Query) -> CompilerResult<QueryPattern> {
        match ast {
            Statement::Select(select) => self.analyze_select(select, query),
            _ => Err(CompilerError::UnsupportedPattern(
                "Only SELECT queries are currently supported".to_string(),
            )),
        }
    }

    /// Picks the most specific pattern for a `SELECT`.
    ///
    /// Each specialised [`QueryPattern`] variant only has slots for part of a statement. A
    /// statement using a clause its pattern cannot hold (joins, HAVING, DISTINCT, OR/NOT in
    /// WHERE, a projection combined with ORDER BY/LIMIT, ...) is returned as
    /// [`QueryPattern::ComplexQuery`], so that clause is never silently dropped.
    fn analyze_select(
        &self,
        select: &SelectStatement,
        query: &Query,
    ) -> CompilerResult<QueryPattern> {
        let table = self.extract_table_name(select)?;
        if let Some(pattern) = self.match_point_query(select, &table, query)? {
            return Ok(pattern);
        }
        if !self.fits_simple_pattern(select) {
            return Ok(QueryPattern::ComplexQuery {
                ast: Box::new(select.clone()),
            });
        }
        // `fits_simple_pattern` guarantees the branch taken below can hold every clause.
        if self.has_aggregates(select) {
            return self.match_aggregation_query(select, &table);
        }
        if select.traverse_clause.is_some() {
            return self.match_relationship_query(select, &table);
        }
        // OFFSET on its own counts too; otherwise it would fall through and be lost.
        if self.has_ordering_or_paging(select) {
            return self.match_ordered_query(select, &table);
        }
        if !self.is_select_star(select) {
            return self.match_projection_query(select, &table);
        }
        self.match_filter_query(select, &table)
    }

    /// Returns `true` when every clause of `select` has a slot in the specialised pattern
    /// `analyze_select` would dispatch it to.
    fn fits_simple_pattern(&self, select: &SelectStatement) -> bool {
        // No specialised pattern carries joins, HAVING or DISTINCT.
        if !select.joins.is_empty() || select.having.is_some() || select.distinct {
            return false;
        }
        // WHERE must flatten into an AND-list of `column <op> value` filters.
        if self.where_filters(select).is_err() {
            return false;
        }
        let paged = self.has_ordering_or_paging(select);
        if self.has_aggregates(select) {
            // AggregationQuery has no traversal, ordering or paging slots.
            return select.traverse_clause.is_none() && !paged;
        }
        // GROUP BY without an aggregate has no specialised representation.
        if !select.group_by.is_empty() {
            return false;
        }
        let star = self.is_select_star(select);
        match &select.traverse_clause {
            // RelationshipQuery holds a single hop and no field list, ordering or paging.
            Some(clause) => clause.patterns.len() <= 1 && star && !paged,
            // OrderedQuery has no field list, so ordering/paging only fits `SELECT *`.
            None if paged => star,
            // ProjectionQuery stores bare field names only.
            None => star || self.is_plain_column_list(select),
        }
    }

    /// Returns `true` when `select` has ORDER BY keys, a LIMIT or an OFFSET.
    fn has_ordering_or_paging(&self, select: &SelectStatement) -> bool {
        !select.order_by.is_empty() || select.limit.is_some() || select.offset.is_some()
    }

    /// Returns `true` when every SELECT item is an unaliased column reference.
    ///
    /// An alias (or a computed expression behind one) cannot be represented by a list of
    /// field names without losing the underlying column or expression.
    fn is_plain_column_list(&self, select: &SelectStatement) -> bool {
        select.select_list.iter().all(|item| {
            matches!(
                item,
                hyperQL::ast::SelectItem::Expression {
                    expr: Expression::Column(_),
                    alias: None,
                }
            )
        })
    }

    /// Returns the collection named in `FROM`.
    ///
    /// # Errors
    ///
    /// `UnsupportedPattern` for a subquery source or a missing `FROM` clause.
    fn extract_table_name(&self, select: &SelectStatement) -> CompilerResult<String> {
        match &select.from {
            Some(hyperQL::ast::FromClause::Table { collection, .. }) => Ok(collection.clone()),
            Some(hyperQL::ast::FromClause::Subquery { .. }) => {
                Err(CompilerError::UnsupportedPattern(
                    "Subqueries are not yet supported".to_string(),
                ))
            }
            None => Err(CompilerError::UnsupportedPattern(
                "SELECT without FROM clause not supported".to_string(),
            )),
        }
    }

    /// Recognises `SELECT * FROM t WHERE id = <param>` with no other clauses.
    ///
    /// Returns `Ok(None)` when the statement is not such a lookup. That includes a literal
    /// key (`id = 5`): it has no parameter to bind, so it is left to the filter path.
    fn match_point_query(
        &self,
        select: &SelectStatement,
        table: &str,
        query: &Query,
    ) -> CompilerResult<Option<QueryPattern>> {
        if !self.is_select_star(select) {
            return Ok(None);
        }
        let where_clause = match &select.where_clause {
            Some(expr) => expr,
            None => return Ok(None),
        };
        if self.has_ordering_or_paging(select)
            || !select.group_by.is_empty()
            || select.having.is_some()
            || !select.joins.is_empty()
            || select.traverse_clause.is_some()
        {
            return Ok(None);
        }
        let key = match self.id_equality_operand(where_clause) {
            Some(Expression::Literal(_)) | None => return Ok(None),
            Some(key) => key,
        };
        // Parameters appear by name on the right-hand side (the convention
        // `extract_filter_value` also follows). Prefer the parameter that is actually named,
        // and fall back to the first declared parameter.
        let param = match key {
            Expression::Column(col_ref) => query
                .params
                .iter()
                .find(|param| param.name == col_ref.name)
                .or_else(|| query.params.first()),
            _ => query.params.first(),
        };
        Ok(param.map(|param| QueryPattern::PointQuery {
            table: table.to_string(),
            id_param: param.name.clone(),
        }))
    }

    /// Returns the right-hand operand when `expr` is exactly `id = <operand>`.
    fn id_equality_operand<'e>(&self, expr: &'e Expression) -> Option<&'e Expression> {
        match expr {
            Expression::Binary {
                left,
                op: hyperQL::ast::BinaryOperator::Equal,
                right,
            } => match left.as_ref() {
                Expression::Column(col_ref) if col_ref.name == "id" => Some(right.as_ref()),
                _ => None,
            },
            _ => None,
        }
    }

    /// Extracts the WHERE clause of `select` as a filter list (empty when there is no WHERE).
    fn where_filters(&self, select: &SelectStatement) -> CompilerResult<Vec<FilterCondition>> {
        match &select.where_clause {
            Some(expr) => self.extract_filters(expr),
            None => Ok(Vec::new()),
        }
    }

    /// Builds a [`QueryPattern::FilterQuery`] for a `SELECT *` scan.
    fn match_filter_query(
        &self,
        select: &SelectStatement,
        table: &str,
    ) -> CompilerResult<QueryPattern> {
        let filters = self.where_filters(select)?;
        Ok(QueryPattern::FilterQuery {
            table: table.to_string(),
            filters,
        })
    }

    /// Builds a [`QueryPattern::ProjectionQuery`] from the SELECT list and WHERE clause.
    fn match_projection_query(
        &self,
        select: &SelectStatement,
        table: &str,
    ) -> CompilerResult<QueryPattern> {
        let fields = self.extract_projection_fields(select)?;
        let filters = self.where_filters(select)?;
        Ok(QueryPattern::ProjectionQuery {
            table: table.to_string(),
            fields,
            filters,
        })
    }

    /// Builds a [`QueryPattern::OrderedQuery`] carrying ORDER BY, LIMIT and OFFSET.
    fn match_ordered_query(
        &self,
        select: &SelectStatement,
        table: &str,
    ) -> CompilerResult<QueryPattern> {
        let filters = self.where_filters(select)?;
        let order_by = self.extract_order_by(&select.order_by)?;
        Ok(QueryPattern::OrderedQuery {
            table: table.to_string(),
            filters,
            order_by,
            limit: select.limit,
            offset: select.offset,
        })
    }

    /// Builds a [`QueryPattern::RelationshipQuery`] from the TRAVERSE and WHERE clauses.
    fn match_relationship_query(
        &self,
        select: &SelectStatement,
        table: &str,
    ) -> CompilerResult<QueryPattern> {
        let clause = select.traverse_clause.as_ref().ok_or_else(|| {
            CompilerError::CodeGenError("Expected TRAVERSE clause".to_string())
        })?;
        let traverse = self.extract_traverse_info(clause)?;
        let filters = self.where_filters(select)?;
        Ok(QueryPattern::RelationshipQuery {
            table: table.to_string(),
            traverse,
            filters,
        })
    }

    /// Builds a [`QueryPattern::AggregationQuery`] from the aggregates, WHERE and GROUP BY.
    fn match_aggregation_query(
        &self,
        select: &SelectStatement,
        table: &str,
    ) -> CompilerResult<QueryPattern> {
        let aggregates = self.extract_aggregates(select)?;
        let filters = self.where_filters(select)?;
        let group_by = self.extract_group_by(&select.group_by)?;
        Ok(QueryPattern::AggregationQuery {
            table: table.to_string(),
            aggregates,
            filters,
            group_by,
        })
    }

    /// Returns `true` when the SELECT list is exactly `*`.
    fn is_select_star(&self, select: &SelectStatement) -> bool {
        matches!(
            select.select_list.as_slice(),
            [hyperQL::ast::SelectItem::Wildcard]
        )
    }

    /// Returns `true` when any top-level SELECT item is an aggregate call.
    fn has_aggregates(&self, select: &SelectStatement) -> bool {
        select.select_list.iter().any(|item| {
            matches!(
                item,
                hyperQL::ast::SelectItem::Expression { expr, .. } if self.is_aggregate_expr(expr)
            )
        })
    }

    /// Returns `true` for a call to COUNT, SUM, AVG, MIN or MAX (case-insensitive).
    fn is_aggregate_expr(&self, expr: &Expression) -> bool {
        matches!(
            expr,
            Expression::Function { name, .. }
            if matches!(name.to_uppercase().as_str(), "COUNT" | "SUM" | "AVG" | "MIN" | "MAX")
        )
    }

    /// Flattens a WHERE expression into a list of conditions that must all hold.
    ///
    /// # Errors
    ///
    /// Fails for OR, unary operators (such as NOT), and any predicate that is not a
    /// `column <op> value` comparison. Those cannot be represented as an AND-list.
    fn extract_filters(&self, expr: &Expression) -> CompilerResult<Vec<FilterCondition>> {
        let mut filters = Vec::new();
        self.extract_filters_recursive(expr, &mut filters)?;
        Ok(filters)
    }

    /// Appends the conditions of `expr` to `filters`, descending through AND nodes.
    fn extract_filters_recursive(
        &self,
        expr: &Expression,
        filters: &mut Vec<FilterCondition>,
    ) -> CompilerResult<()> {
        match expr {
            Expression::Binary {
                left,
                op: hyperQL::ast::BinaryOperator::And,
                right,
            } => {
                self.extract_filters_recursive(left, filters)?;
                self.extract_filters_recursive(right, filters)
            }
            // A filter list is an implicit conjunction; flattening OR into it would turn
            // `a OR b` into `a AND b`.
            Expression::Binary {
                op: hyperQL::ast::BinaryOperator::Or,
                ..
            } => Err(CompilerError::UnsupportedPattern(
                "OR conditions cannot be expressed as a filter list".to_string(),
            )),
            Expression::Binary { left, op, right } => {
                filters.push(self.extract_single_filter(left, op, right)?);
                Ok(())
            }
            // Dropping NOT would invert the condition it wraps.
            Expression::Unary { .. } => Err(CompilerError::UnsupportedPattern(
                "Unary operators such as NOT cannot be expressed as a filter list".to_string(),
            )),
            // Skipping an unrecognised predicate would silently widen the result set.
            _ => Err(CompilerError::InvalidFilter(
                "Unsupported expression in WHERE clause".to_string(),
            )),
        }
    }

    /// Converts one `column <op> value` comparison into a [`FilterCondition`].
    fn extract_single_filter(
        &self,
        left: &Expression,
        op: &hyperQL::ast::BinaryOperator,
        right: &Expression,
    ) -> CompilerResult<FilterCondition> {
        let field = match left {
            Expression::Column(col_ref) => col_ref.name.clone(),
            _ => {
                return Err(CompilerError::InvalidFilter(
                    "Left side of comparison must be a column reference".to_string(),
                ));
            }
        };
        let operator = self.convert_binary_operator(op)?;
        let value = self.extract_filter_value(right)?;
        Ok(FilterCondition {
            field,
            operator,
            value,
        })
    }

    /// Maps an AST comparison operator onto a [`FilterOperator`].
    fn convert_binary_operator(
        &self,
        op: &hyperQL::ast::BinaryOperator,
    ) -> CompilerResult<FilterOperator> {
        match op {
            hyperQL::ast::BinaryOperator::Equal => Ok(FilterOperator::Equal),
            hyperQL::ast::BinaryOperator::NotEqual => Ok(FilterOperator::NotEqual),
            hyperQL::ast::BinaryOperator::LessThan => Ok(FilterOperator::LessThan),
            hyperQL::ast::BinaryOperator::LessThanOrEqual => Ok(FilterOperator::LessThanOrEqual),
            hyperQL::ast::BinaryOperator::GreaterThan => Ok(FilterOperator::GreaterThan),
            hyperQL::ast::BinaryOperator::GreaterThanOrEqual => {
                Ok(FilterOperator::GreaterThanOrEqual)
            }
            hyperQL::ast::BinaryOperator::Like => Ok(FilterOperator::Like),
            hyperQL::ast::BinaryOperator::In => Ok(FilterOperator::In),
            _ => Err(CompilerError::InvalidFilter(format!(
                "Unsupported operator in filter: {:?}",
                op
            ))),
        }
    }

    /// Converts the right-hand side of a comparison into a [`FilterValue`].
    ///
    /// Column references are treated as named parameters.
    fn extract_filter_value(&self, expr: &Expression) -> CompilerResult<FilterValue> {
        match expr {
            // `convert_literal` has no faithful form for entity ids (it substitutes a fixed
            // placeholder string), so they must not become filter operands.
            Expression::Literal(hyperQL::ast::Literal::EntityId(_)) => {
                Err(CompilerError::InvalidFilter(
                    "Entity ID literals are not yet supported in filters".to_string(),
                ))
            }
            Expression::Literal(lit) => Ok(FilterValue::Literal(self.convert_literal(lit))),
            Expression::Column(col_ref) => Ok(FilterValue::Parameter(col_ref.name.clone())),
            _ => Err(CompilerError::InvalidFilter(
                "Filter value must be a literal or parameter".to_string(),
            )),
        }
    }

    /// Maps an AST literal onto a [`LiteralValue`].
    fn convert_literal(&self, lit: &hyperQL::ast::Literal) -> LiteralValue {
        match lit {
            hyperQL::ast::Literal::Null => LiteralValue::Null,
            hyperQL::ast::Literal::Bool(b) => LiteralValue::Boolean(*b),
            hyperQL::ast::Literal::Int(i) => LiteralValue::Integer(*i),
            hyperQL::ast::Literal::Float(f) => LiteralValue::Float(*f),
            hyperQL::ast::Literal::String(s) => LiteralValue::String(s.clone()),
            // NOTE(review): placeholder only — the real id is discarded. `extract_filter_value`
            // rejects entity-id literals before reaching here.
            hyperQL::ast::Literal::EntityId(_) => {
                LiteralValue::String("entity_id".to_string())
            }
        }
    }

    /// Lists the output field names of a non-wildcard SELECT.
    ///
    /// An item's alias wins over its column name. Unaliased non-column expressions are
    /// rejected.
    fn extract_projection_fields(&self, select: &SelectStatement) -> CompilerResult<Vec<String>> {
        let mut fields = Vec::with_capacity(select.select_list.len());
        for item in &select.select_list {
            match item {
                hyperQL::ast::SelectItem::Wildcard => {
                    return Err(CompilerError::CodeGenError(
                        "Cannot mix wildcard with specific fields".to_string(),
                    ));
                }
                hyperQL::ast::SelectItem::Expression { expr, alias } => {
                    let field_name = match (alias, expr) {
                        (Some(name), _) => name.clone(),
                        (None, Expression::Column(col_ref)) => col_ref.name.clone(),
                        (None, _) => {
                            return Err(CompilerError::CodeGenError(
                                "Complex expressions in SELECT require aliases".to_string(),
                            ));
                        }
                    };
                    fields.push(field_name);
                }
            }
        }
        Ok(fields)
    }

    /// Converts ORDER BY items, which must be plain column references.
    fn extract_order_by(
        &self,
        order_items: &[hyperQL::ast::OrderByItem],
    ) -> CompilerResult<Vec<OrderByClause>> {
        order_items
            .iter()
            .map(|item| match &item.expr {
                Expression::Column(col_ref) => Ok(OrderByClause {
                    field: col_ref.name.clone(),
                    descending: matches!(item.direction, hyperQL::ast::OrderDirection::Desc),
                }),
                _ => Err(CompilerError::CodeGenError(
                    "Complex ORDER BY expressions not yet supported".to_string(),
                )),
            })
            .collect()
    }

    /// Reads the relationship type and target label of the first TRAVERSE pattern.
    fn extract_traverse_info(
        &self,
        clause: &hyperQL::ast::TraverseClause,
    ) -> CompilerResult<TraverseInfo> {
        let pattern = clause.patterns.first().ok_or_else(|| {
            CompilerError::UnsupportedPattern(
                "TRAVERSE clause must have at least one pattern".to_string(),
            )
        })?;
        let relationship = pattern.relationship.rel_type.clone().ok_or_else(|| {
            CompilerError::UnsupportedPattern("TRAVERSE relationship must have a type".to_string())
        })?;
        let target_table = pattern.end_node.label.clone().ok_or_else(|| {
            CompilerError::UnsupportedPattern("TRAVERSE target node must have a label".to_string())
        })?;
        Ok(TraverseInfo {
            relationship,
            target_table,
        })
    }

    /// Collects the aggregate calls in the SELECT list, skipping non-aggregate items.
    fn extract_aggregates(
        &self,
        select: &SelectStatement,
    ) -> CompilerResult<Vec<AggregateFunction>> {
        let mut aggregates = Vec::new();
        for item in &select.select_list {
            if let hyperQL::ast::SelectItem::Expression { expr, .. } = item {
                if let Some(agg) = self.extract_aggregate_function(expr)? {
                    aggregates.push(agg);
                }
            }
        }
        Ok(aggregates)
    }

    /// Converts a recognised aggregate call; returns `Ok(None)` for anything else.
    fn extract_aggregate_function(
        &self,
        expr: &Expression,
    ) -> CompilerResult<Option<AggregateFunction>> {
        let (name, args) = match expr {
            Expression::Function { name, args } => (name, args),
            _ => return Ok(None),
        };
        let aggregate = match name.to_uppercase().as_str() {
            // NOTE(review): COUNT's argument is not inspected, so COUNT(col) and COUNT(*)
            // both map to `Count` — confirm the NULL-skipping difference does not matter.
            "COUNT" => AggregateFunction::Count,
            "SUM" => AggregateFunction::Sum(self.extract_field_from_args(args)?),
            "AVG" => AggregateFunction::Avg(self.extract_field_from_args(args)?),
            "MIN" => AggregateFunction::Min(self.extract_field_from_args(args)?),
            "MAX" => AggregateFunction::Max(self.extract_field_from_args(args)?),
            _ => return Ok(None),
        };
        Ok(Some(aggregate))
    }

    /// Returns the column named by an aggregate's first argument.
    fn extract_field_from_args(&self, args: &[Expression]) -> CompilerResult<String> {
        match args.first() {
            None => Err(CompilerError::CodeGenError(
                "Aggregate function requires arguments".to_string(),
            )),
            Some(Expression::Column(col_ref)) => Ok(col_ref.name.clone()),
            Some(_) => Err(CompilerError::CodeGenError(
                "Aggregate function argument must be a column reference".to_string(),
            )),
        }
    }

    /// Converts GROUP BY expressions, which must be plain column references.
    fn extract_group_by(&self, group_exprs: &[Expression]) -> CompilerResult<Vec<String>> {
        group_exprs
            .iter()
            .map(|expr| match expr {
                Expression::Column(col_ref) => Ok(col_ref.name.clone()),
                _ => Err(CompilerError::CodeGenError(
                    "Complex GROUP BY expressions not yet supported".to_string(),
                )),
            })
            .collect()
    }
}
impl Default for PatternMatcher {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `SELECT <select_list>` with every other clause left empty.
    fn select_of(select_list: Vec<hyperQL::ast::SelectItem>) -> SelectStatement {
        SelectStatement {
            select_list,
            from: None,
            joins: Vec::new(),
            traverse_clause: None,
            where_clause: None,
            group_by: Vec::new(),
            having: None,
            order_by: Vec::new(),
            limit: None,
            offset: None,
            distinct: false,
        }
    }

    #[test]
    fn test_pattern_matcher_creation() {
        // `new` and `default` must both yield a working matcher.
        for matcher in &[PatternMatcher::new(), PatternMatcher::default()] {
            assert_eq!(
                matcher.convert_literal(&hyperQL::ast::Literal::Null),
                LiteralValue::Null
            );
        }
    }

    #[test]
    fn test_is_select_star() {
        let matcher = PatternMatcher::new();
        let select = select_of(vec![hyperQL::ast::SelectItem::Wildcard]);
        assert!(matcher.is_select_star(&select));
    }

    #[test]
    fn test_is_select_star_rejects_column_list() {
        use hyperQL::ast::{ColumnRef, Expression, SelectItem};
        let matcher = PatternMatcher::new();
        let select = select_of(vec![SelectItem::Expression {
            expr: Expression::Column(ColumnRef {
                table: None,
                name: "age".to_string(),
            }),
            alias: None,
        }]);
        assert!(!matcher.is_select_star(&select));
    }

    #[test]
    fn test_extract_single_filter() {
        use hyperQL::ast::{BinaryOperator, ColumnRef, Expression, Literal};
        let matcher = PatternMatcher::new();
        let left = Expression::Column(ColumnRef {
            table: None,
            name: "age".to_string(),
        });
        let op = BinaryOperator::GreaterThan;
        let right = Expression::Literal(Literal::Int(25));
        let result = matcher.extract_single_filter(&left, &op, &right);
        assert!(result.is_ok(), "Failed to extract filter: {:?}", result);
        let filter = result.unwrap();
        assert_eq!(filter.field, "age");
        assert_eq!(filter.operator, FilterOperator::GreaterThan);
        assert!(matches!(filter.value, FilterValue::Literal(_)));
    }

    #[test]
    fn test_extract_filters_from_and_expression() {
        use hyperQL::ast::{BinaryOperator, ColumnRef, Expression, Literal};
        let matcher = PatternMatcher::new();
        let age_filter = Expression::Binary {
            left: Box::new(Expression::Column(ColumnRef {
                table: None,
                name: "age".to_string(),
            })),
            op: BinaryOperator::GreaterThan,
            right: Box::new(Expression::Literal(Literal::Int(25))),
        };
        let active_filter = Expression::Binary {
            left: Box::new(Expression::Column(ColumnRef {
                table: None,
                name: "active".to_string(),
            })),
            op: BinaryOperator::Equal,
            right: Box::new(Expression::Literal(Literal::Bool(true))),
        };
        let combined = Expression::Binary {
            left: Box::new(age_filter),
            op: BinaryOperator::And,
            right: Box::new(active_filter),
        };
        let result = matcher.extract_filters(&combined);
        assert!(result.is_ok(), "Failed to extract filters: {:?}", result);
        let filters = result.unwrap();
        assert_eq!(filters.len(), 2, "Expected 2 filters");
        assert_eq!(filters[0].field, "age");
        assert_eq!(filters[1].field, "active");
    }

    #[test]
    fn test_convert_binary_operator() {
        use hyperQL::ast::BinaryOperator;
        let matcher = PatternMatcher::new();
        let test_cases = vec![
            (BinaryOperator::Equal, FilterOperator::Equal),
            (BinaryOperator::NotEqual, FilterOperator::NotEqual),
            (BinaryOperator::LessThan, FilterOperator::LessThan),
            (BinaryOperator::LessThanOrEqual, FilterOperator::LessThanOrEqual),
            (BinaryOperator::GreaterThan, FilterOperator::GreaterThan),
            (
                BinaryOperator::GreaterThanOrEqual,
                FilterOperator::GreaterThanOrEqual,
            ),
            (BinaryOperator::Like, FilterOperator::Like),
            (BinaryOperator::In, FilterOperator::In),
        ];
        for (input, expected) in test_cases {
            let result = matcher.convert_binary_operator(&input);
            assert!(result.is_ok(), "Failed to convert operator {:?}", input);
            assert_eq!(result.unwrap(), expected);
        }
    }

    #[test]
    fn test_convert_literal() {
        use hyperQL::ast::Literal;
        let matcher = PatternMatcher::new();
        // 2.5 rather than 3.14: clippy's deny-by-default `approx_constant` flags PI lookalikes.
        let test_cases = vec![
            (Literal::Null, LiteralValue::Null),
            (Literal::Bool(true), LiteralValue::Boolean(true)),
            (Literal::Int(42), LiteralValue::Integer(42)),
            (Literal::Float(2.5), LiteralValue::Float(2.5)),
            (
                Literal::String("test".to_string()),
                LiteralValue::String("test".to_string()),
            ),
        ];
        for (input, expected) in test_cases {
            assert_eq!(matcher.convert_literal(&input), expected);
        }
    }

    #[test]
    fn test_extract_filter_value_parameter() {
        use hyperQL::ast::{ColumnRef, Expression};
        let matcher = PatternMatcher::new();
        let expr = Expression::Column(ColumnRef {
            table: None,
            name: "min_age".to_string(),
        });
        let result = matcher.extract_filter_value(&expr);
        assert!(result.is_ok());
        let value = result.unwrap();
        assert!(matches!(value, FilterValue::Parameter(ref p) if p == "min_age"));
    }

    #[test]
    fn test_extract_filter_value_literal() {
        use hyperQL::ast::{Expression, Literal};
        let matcher = PatternMatcher::new();
        let expr = Expression::Literal(Literal::Int(100));
        let result = matcher.extract_filter_value(&expr);
        assert!(result.is_ok());
        let value = result.unwrap();
        assert!(matches!(
            value,
            FilterValue::Literal(LiteralValue::Integer(100))
        ));
    }

    #[test]
    fn test_extract_filter_value_rejects_expressions() {
        use hyperQL::ast::{BinaryOperator, ColumnRef, Expression, Literal};
        let matcher = PatternMatcher::new();
        let expr = Expression::Binary {
            left: Box::new(Expression::Column(ColumnRef {
                table: None,
                name: "a".to_string(),
            })),
            op: BinaryOperator::Equal,
            right: Box::new(Expression::Literal(Literal::Int(1))),
        };
        assert!(matcher.extract_filter_value(&expr).is_err());
    }
}