use crate::expressions::{BinaryOp, Expression};
use crate::optimizer::simplify::Simplifier;
use thiserror::Error;
/// Default limit for the normalization distance.
///
/// NOTE(review): nothing in this file reads this constant; presumably callers
/// pass it as the `max_distance` argument of `normalize` — confirm at call sites.
pub const DEFAULT_MAX_DISTANCE: i64 = 128;
/// Errors reported while normalizing a boolean expression.
#[derive(Debug, Error, Clone)]
pub enum NormalizeError {
/// The estimated normalization distance is greater than the allowed maximum.
///
/// `distance` is the estimate that was computed and `max` is the limit it
/// was compared against.
#[error("Normalization distance {distance} exceeds max {max}")]
DistanceExceeded { distance: i64, max: i64 },
}
/// Convenience alias for results whose error type is [`NormalizeError`].
pub type NormalizeResult<T> = Result<T, NormalizeError>;
/// Rewrites a boolean expression into conjunctive or disjunctive normal form.
///
/// * `expression` - the predicate to normalize.
/// * `dnf` - `true` targets disjunctive normal form, `false` conjunctive normal form.
/// * `max_distance` - upper bound on the estimated normalization distance.
///
/// The input is handed back unchanged when it is already normalized or when
/// its estimated distance is above `max_distance`.
///
/// # Errors
///
/// Propagates [`NormalizeError::DistanceExceeded`] when an intermediate form
/// produced during rewriting exceeds `max_distance`.
pub fn normalize(
    expression: Expression,
    dnf: bool,
    max_distance: i64,
) -> NormalizeResult<Expression> {
    // A fresh simplifier built with `None` is used for the whole rewrite.
    normalize_with_simplifier(expression, dnf, max_distance, &Simplifier::new(None))
}
/// Normalizes `expression` using the supplied simplifier.
///
/// Returns the expression untouched when it is already in the requested form
/// or when its estimated normalization distance is above `max_distance`;
/// otherwise delegates to `apply_distributive_law` and returns its result,
/// including any error it reports.
fn normalize_with_simplifier(
    expression: Expression,
    dnf: bool,
    max_distance: i64,
    simplifier: &Simplifier,
) -> NormalizeResult<Expression> {
    // `||` short-circuits, so the distance is only estimated for expressions
    // that are not yet normalized.
    let leave_untouched = normalized(&expression, dnf)
        || normalization_distance(&expression, dnf, max_distance) > max_distance;

    if leave_untouched {
        Ok(expression)
    } else {
        apply_distributive_law(&expression, dnf, max_distance, simplifier)
    }
}
pub fn normalized(expression: &Expression, dnf: bool) -> bool {
if dnf {
!has_and_with_or_descendant(expression)
} else {
!has_or_with_and_descendant(expression)
}
}
/// Returns `true` when some `OR` node in the connector tree has an `AND`
/// beneath it, i.e. the expression is not in conjunctive normal form.
///
/// Only `And`, `Or` and `Paren` nodes are traversed; every other node is
/// treated as a leaf.
fn has_or_with_and_descendant(expression: &Expression) -> bool {
    match expression {
        Expression::Paren(paren) => has_or_with_and_descendant(&paren.this),
        Expression::And(bin) => [&bin.left, &bin.right]
            .iter()
            .any(|operand| has_or_with_and_descendant(operand)),
        Expression::Or(bin) => {
            let operands = [&bin.left, &bin.right];
            // Either this OR directly sits above an AND, or a violation is
            // nested further down one of its operands.
            operands.iter().any(|operand| contains_and(operand))
                || operands
                    .iter()
                    .any(|operand| has_or_with_and_descendant(operand))
        }
        _ => false,
    }
}
/// Returns `true` when some `AND` node in the connector tree has an `OR`
/// beneath it, i.e. the expression is not in disjunctive normal form.
///
/// Only `And`, `Or` and `Paren` nodes are traversed; every other node is
/// treated as a leaf.
fn has_and_with_or_descendant(expression: &Expression) -> bool {
    match expression {
        Expression::Paren(paren) => has_and_with_or_descendant(&paren.this),
        Expression::Or(bin) => [&bin.left, &bin.right]
            .iter()
            .any(|operand| has_and_with_or_descendant(operand)),
        Expression::And(bin) => {
            let operands = [&bin.left, &bin.right];
            // Either this AND directly sits above an OR, or a violation is
            // nested further down one of its operands.
            operands.iter().any(|operand| contains_or(operand))
                || operands
                    .iter()
                    .any(|operand| has_and_with_or_descendant(operand))
        }
        _ => false,
    }
}
/// Returns `true` if an `AND` node is reachable from `expression` by walking
/// through `Or` and `Paren` nodes only.
fn contains_and(expression: &Expression) -> bool {
    // Explicit work list instead of recursion.
    let mut pending = vec![expression];
    while let Some(node) = pending.pop() {
        match node {
            Expression::And(_) => return true,
            Expression::Or(bin) => {
                pending.push(&bin.right);
                pending.push(&bin.left);
            }
            Expression::Paren(paren) => pending.push(&paren.this),
            _ => {}
        }
    }
    false
}
/// Returns `true` if an `OR` node is reachable from `expression` by walking
/// through `And` and `Paren` nodes only.
fn contains_or(expression: &Expression) -> bool {
    // Explicit work list instead of recursion.
    let mut pending = vec![expression];
    while let Some(node) = pending.pop() {
        match node {
            Expression::Or(_) => return true,
            Expression::And(bin) => {
                pending.push(&bin.right);
                pending.push(&bin.left);
            }
            Expression::Paren(paren) => pending.push(&paren.this),
            _ => {}
        }
    }
    false
}
/// Estimates how much `expression` would grow when converted to normal form.
///
/// The estimate starts at `-(number of AND/OR connectors + 1)` and adds the
/// length of every clause that distribution would produce. As soon as the
/// running total exceeds `max_distance` it is returned immediately, so a
/// result greater than `max_distance` is a lower bound rather than an exact
/// value.
///
/// * `dnf` - `true` estimates for disjunctive normal form, `false` for
///   conjunctive normal form.
/// * `max_distance` - threshold at which the estimation stops early.
///
/// Clause lengths are streamed rather than collected, so the early exit also
/// bounds the work done: the cartesian product of clause lengths is never
/// materialized.
pub fn normalization_distance(expression: &Expression, dnf: bool, max_distance: i64) -> i64 {
    let connector_count = count_connectors(expression);
    let mut total: i64 = -(connector_count as i64 + 1);
    for_each_predicate_length(expression, dnf, max_distance, 0, &mut |length: i64| {
        total += length;
        // Returning `false` stops the traversal once the limit is exceeded.
        total <= max_distance
    });
    total
}

/// Streams the clause lengths of `expression` to `emit`, in the same order in
/// which `predicate_lengths` lists them.
///
/// `emit` returns `true` to keep going and `false` to stop; the function
/// returns `false` as soon as `emit` asked to stop, `true` otherwise.
///
/// * A node deeper than `max_distance` reports its depth as its length.
/// * The connector that must be distributed (`OR` for CNF, `AND` for DNF)
///   reports the sum of every left/right pair of lengths.
/// * The other connector reports the left lengths followed by the right ones.
/// * Anything else (after removing parentheses) is a single predicate of
///   length 1.
fn for_each_predicate_length(
    expression: &Expression,
    dnf: bool,
    max_distance: i64,
    depth: i64,
    emit: &mut dyn FnMut(i64) -> bool,
) -> bool {
    if depth > max_distance {
        return emit(depth);
    }
    let (bin, distributes) = match unwrap_paren(expression) {
        Expression::And(bin) => (bin, dnf),
        Expression::Or(bin) => (bin, !dnf),
        _ => return emit(1),
    };
    let child_depth = depth + 1;
    if distributes {
        // Pair every left length with every right length. The right operand
        // is re-walked per left length, which keeps memory use constant.
        for_each_predicate_length(&bin.left, dnf, max_distance, child_depth, &mut |a: i64| {
            for_each_predicate_length(&bin.right, dnf, max_distance, child_depth, &mut |b: i64| {
                emit(a + b)
            })
        })
    } else {
        for_each_predicate_length(&bin.left, dnf, max_distance, child_depth, emit)
            && for_each_predicate_length(&bin.right, dnf, max_distance, child_depth, emit)
    }
}

/// Collects every clause length of `expression` into a vector.
///
/// This is the eager counterpart of `for_each_predicate_length` and can grow
/// exponentially with the number of distributed connectors, which is why
/// `normalization_distance` does not use it.
// Within this file only the unit tests call this function, so it would be
// reported as dead code in non-test builds.
#[cfg_attr(not(test), allow(dead_code))]
fn predicate_lengths(
    expression: &Expression,
    dnf: bool,
    max_distance: i64,
    depth: i64,
) -> Vec<i64> {
    let mut lengths = Vec::new();
    for_each_predicate_length(expression, dnf, max_distance, depth, &mut |length: i64| {
        lengths.push(length);
        true
    });
    lengths
}
/// Applies distribution passes until `expression` is in the requested form.
///
/// Returns a clone of the input when it is already normalized. Otherwise each
/// round first estimates the normalization distance of the current form and
/// then runs one `distribute_dnf` / `distribute_cnf` pass over it.
///
/// # Errors
///
/// Returns [`NormalizeError::DistanceExceeded`] when the estimated distance of
/// the input or of any intermediate form is greater than `max_distance`.
fn apply_distributive_law(
    expression: &Expression,
    dnf: bool,
    max_distance: i64,
    simplifier: &Simplifier,
) -> NormalizeResult<Expression> {
    // One rewrite round: check the budget, then distribute once.
    let rewrite_once = |node: &Expression| -> NormalizeResult<Expression> {
        let distance = normalization_distance(node, dnf, max_distance);
        if distance > max_distance {
            return Err(NormalizeError::DistanceExceeded {
                distance,
                max: max_distance,
            });
        }
        Ok(if dnf {
            distribute_dnf(node, simplifier)
        } else {
            distribute_cnf(node, simplifier)
        })
    };

    if normalized(expression, dnf) {
        return Ok(expression.clone());
    }
    let mut current = rewrite_once(expression)?;
    while !normalized(&current, dnf) {
        current = rewrite_once(&current)?;
    }
    Ok(current)
}
/// Runs one bottom-up pass that pushes `OR` below `AND`.
///
/// After both operands of an `OR` have been processed:
/// * `x OR (y AND z)` becomes `(x OR y) AND (x OR z)`;
/// * otherwise `(y AND z) OR x` becomes `(y OR x) AND (z OR x)`;
/// * otherwise the `OR` is rebuilt from the processed operands.
///
/// `AND` nodes are rebuilt from their processed operands, `Paren` wrappers are
/// replaced by their processed content, and any other node is cloned as is.
/// A single pass may leave the result not yet normalized. `simplifier` is only
/// forwarded to the recursive calls.
fn distribute_cnf(expression: &Expression, simplifier: &Simplifier) -> Expression {
    let operands = match expression {
        Expression::Paren(paren) => return distribute_cnf(&paren.this, simplifier),
        Expression::And(bin) => {
            return make_and(
                distribute_cnf(&bin.left, simplifier),
                distribute_cnf(&bin.right, simplifier),
            );
        }
        Expression::Or(bin) => (
            distribute_cnf(&bin.left, simplifier),
            distribute_cnf(&bin.right, simplifier),
        ),
        _ => return expression.clone(),
    };

    match operands {
        // The right operand is checked first, so it wins when both are ANDs.
        (x, Expression::And(conjunction)) => {
            let conjunction = *conjunction;
            make_and(
                make_or(x.clone(), conjunction.left),
                make_or(x, conjunction.right),
            )
        }
        (Expression::And(conjunction), x) => {
            let conjunction = *conjunction;
            make_and(
                make_or(conjunction.left, x.clone()),
                make_or(conjunction.right, x),
            )
        }
        (lhs, rhs) => make_or(lhs, rhs),
    }
}
/// Runs one bottom-up pass that pushes `AND` below `OR`.
///
/// After both operands of an `AND` have been processed:
/// * `x AND (y OR z)` becomes `(x AND y) OR (x AND z)`;
/// * otherwise `(y OR z) AND x` becomes `(y AND x) OR (z AND x)`;
/// * otherwise the `AND` is rebuilt from the processed operands.
///
/// `OR` nodes are rebuilt from their processed operands, `Paren` wrappers are
/// replaced by their processed content, and any other node is cloned as is.
/// A single pass may leave the result not yet normalized. `simplifier` is only
/// forwarded to the recursive calls.
fn distribute_dnf(expression: &Expression, simplifier: &Simplifier) -> Expression {
    let operands = match expression {
        Expression::Paren(paren) => return distribute_dnf(&paren.this, simplifier),
        Expression::Or(bin) => {
            return make_or(
                distribute_dnf(&bin.left, simplifier),
                distribute_dnf(&bin.right, simplifier),
            );
        }
        Expression::And(bin) => (
            distribute_dnf(&bin.left, simplifier),
            distribute_dnf(&bin.right, simplifier),
        ),
        _ => return expression.clone(),
    };

    match operands {
        // The right operand is checked first, so it wins when both are ORs.
        (x, Expression::Or(disjunction)) => {
            let disjunction = *disjunction;
            make_or(
                make_and(x.clone(), disjunction.left),
                make_and(x, disjunction.right),
            )
        }
        (Expression::Or(disjunction), x) => {
            let disjunction = *disjunction;
            make_or(
                make_and(disjunction.left, x.clone()),
                make_and(disjunction.right, x),
            )
        }
        (lhs, rhs) => make_and(lhs, rhs),
    }
}
/// Counts the `AND` / `OR` nodes reachable from `expression`.
///
/// The walk descends through `And`, `Or` and `Paren` nodes only; connectors
/// nested inside any other kind of node are not counted.
fn count_connectors(expression: &Expression) -> usize {
    let mut connectors = 0;
    let mut pending = vec![expression];
    while let Some(node) = pending.pop() {
        match node {
            Expression::And(bin) | Expression::Or(bin) => {
                connectors += 1;
                pending.push(&bin.left);
                pending.push(&bin.right);
            }
            Expression::Paren(paren) => pending.push(&paren.this),
            _ => {}
        }
    }
    connectors
}
/// Strips any number of enclosing `Paren` wrappers and returns the first
/// inner expression that is not a `Paren`.
fn unwrap_paren(expression: &Expression) -> &Expression {
    let mut inner = expression;
    while let Expression::Paren(paren) = inner {
        inner = &paren.this;
    }
    inner
}
/// Builds an `AND` node joining `left` and `right`.
///
/// The new node carries no comments and no inferred type.
fn make_and(left: Expression, right: Expression) -> Expression {
    let node = BinaryOp {
        left,
        right,
        left_comments: Vec::new(),
        operator_comments: Vec::new(),
        trailing_comments: Vec::new(),
        inferred_type: None,
    };
    Expression::And(Box::new(node))
}
/// Builds an `OR` node joining `left` and `right`.
///
/// The new node carries no comments and no inferred type.
fn make_or(left: Expression, right: Expression) -> Expression {
    let node = BinaryOp {
        left,
        right,
        left_comments: Vec::new(),
        operator_comments: Vec::new(),
        trailing_comments: Vec::new(),
        inferred_type: None,
    };
    Expression::Or(Box::new(node))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` and returns a clone of the first parsed statement.
    fn parse(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Parses `sql` as the `WHERE` predicate of a dummy `SELECT` and returns it.
    fn parse_predicate(sql: &str) -> Expression {
        let statement = parse(&format!("SELECT 1 WHERE {}", sql));
        let predicate = match statement {
            Expression::Select(select) => select.where_clause.map(|clause| clause.this),
            _ => None,
        };
        predicate.unwrap_or_else(|| panic!("Failed to extract predicate from: {}", sql))
    }

    // --- normalized ---------------------------------------------------------

    #[test]
    fn test_normalized_cnf() {
        let cnf = parse_predicate("(a OR b) AND (c OR d)");
        assert!(normalized(&cnf, false));
    }

    #[test]
    fn test_normalized_dnf() {
        let dnf = parse_predicate("(a AND b) OR (c AND d)");
        assert!(normalized(&dnf, true));
    }

    #[test]
    fn test_not_normalized_cnf() {
        let not_cnf = parse_predicate("(a AND b) OR c");
        assert!(!normalized(&not_cnf, false));
    }

    #[test]
    fn test_not_normalized_dnf() {
        let not_dnf = parse_predicate("(a OR b) AND c");
        assert!(!normalized(&not_dnf, true));
    }

    #[test]
    fn test_simple_literal_is_normalized() {
        // A predicate without connectors satisfies both normal forms.
        let literal = parse_predicate("a = 1");
        assert!(normalized(&literal, false));
        assert!(normalized(&literal, true));
    }

    // --- normalization_distance ---------------------------------------------

    #[test]
    fn test_normalization_distance_simple() {
        let literal = parse_predicate("a = 1");
        assert!(normalization_distance(&literal, false, 128) <= 0);
    }

    #[test]
    fn test_normalization_distance_complex() {
        let nested = parse_predicate("(a AND b) OR (c AND d)");
        assert!(normalization_distance(&nested, false, 128) > 0);
    }

    // --- normalize ----------------------------------------------------------

    #[test]
    fn test_normalize_to_cnf() {
        let input = parse_predicate("(x AND y) OR z");
        let output = normalize(input, false, 128).unwrap();
        assert!(normalized(&output, false));
    }

    #[test]
    fn test_normalize_to_dnf() {
        let input = parse_predicate("(x OR y) AND z");
        let output = normalize(input, true, 128).unwrap();
        assert!(normalized(&output, true));
    }

    // --- helpers ------------------------------------------------------------

    #[test]
    fn test_count_connectors() {
        let chain = parse_predicate("a AND b AND c");
        assert_eq!(count_connectors(&chain), 2);
    }

    #[test]
    fn test_predicate_lengths() {
        let literal = parse_predicate("a = 1");
        assert_eq!(predicate_lengths(&literal, false, 128, 0), vec![1]);
    }
}