use alloc::format;
use alloc::vec::Vec;
use spg_sql::ast::{BinOp, ColumnName, Expr, Literal, UnOp};
use spg_storage::{Row, Value};
use super::{
EvalContext, EvalError, apply_binary, apply_unary, column_collation, composite_eq, eval_expr,
like_match_inner, literal_to_value,
};
/// One instruction of a compiled expression program.
///
/// Steps run in order against a value stack (see `eval_compiled_ref`): a step
/// pops whatever operands it needs and pushes exactly one result.
pub(crate) enum Step {
/// Push the row value at this column index; NULL when the row has no value there.
Column(usize),
/// Push a constant value.
Lit(Value),
/// Pop the right operand, then the left, and push `apply_binary(op, l, r)`.
Binary(BinOp),
/// Same as `Binary`, except text operands are ASCII-lowercased first.
/// Emitted for comparison operators when either side's column collation is
/// case-insensitive.
BinaryCi(BinOp),
/// Pop one operand and push `apply_unary(op, v)`.
Unary(UnOp),
/// Pop one operand and push a boolean: whether it is NULL (or is not, when negated).
IsNull {
/// `true` for `IS NOT NULL`.
negated: bool,
},
/// Pop one operand (the needle) and test membership in a prebuilt literal set.
InSet {
/// Literal set built by `crate::build_in_list_set` at compile time.
set: crate::memoize::InListSet,
/// When set, a needle that is not found yields NULL instead of FALSE.
has_null: bool,
/// `true` for `NOT IN`.
negated: bool,
/// The original `IN` expression, evaluated through `eval_expr` when the
/// needle's type does not match the set's kind.
fallback: Expr,
},
/// Pop one operand and match it against a literal LIKE pattern.
Like {
/// Pattern characters; already lowercased when `case_insensitive` is set.
pattern: alloc::vec::Vec<char>,
/// `true` for `NOT LIKE`.
negated: bool,
/// When set, the operand text is lowercased before matching.
case_insensitive: bool,
},
/// Replace the top `n_args` stack values with the result of a scalar function.
Function {
/// Function name, ASCII-lowercased at compile time.
name_lower: alloc::string::String,
/// Number of arguments taken from the top of the stack.
n_args: usize,
},
/// Push the length of the column at `pos`, read straight from the row:
/// character count for text, byte count for bytes, NULL for NULL.
ColumnLength {
/// Column index in the row.
pos: usize,
},
/// Push the byte length of the text or bytes column at `pos`, read straight
/// from the row; NULL for NULL.
ColumnOctetLength {
/// Column index in the row.
pos: usize,
},
/// Pop one operand and push it converted with `cast_value`.
Cast {
/// Target type of the cast.
target: spg_sql::ast::CastTarget,
},
/// Push the result of evaluating this expression with `eval_expr`; used for
/// anything that has no dedicated step.
Subtree(Expr),
}
/// An expression flattened into a linear program of `Step`s.
///
/// Built by `compile_expr` and executed by `eval_compiled` /
/// `eval_compiled_ref`.
pub(crate) struct CompiledExpr {
/// Steps in execution order; operands are emitted before the step that
/// consumes them.
steps: Vec<Step>,
}
impl CompiledExpr {
    /// Returns the column index when the whole program is exactly one
    /// `Step::ColumnLength` step, and `None` for every other program.
    pub(crate) fn as_single_column_length(&self) -> Option<usize> {
        match self.steps.as_slice() {
            [Step::ColumnLength { pos }] => Some(*pos),
            _ => None,
        }
    }
}
/// Resolves a column reference to its index in `ctx.columns` at compile time.
///
/// Resolution order:
/// 1. With a qualifier: an exact `qualifier.name` match (via `composite_eq`).
///    Failing that, the reference is only allowed to continue when no column
///    is stored under a `qualifier.` prefix and the qualifier equals
///    `ctx.table_alias`.
/// 2. A column whose stored name equals the bare name.
/// 3. A column whose stored name ends in `.name`, provided exactly one
///    column does.
///
/// Returns `None` when the reference cannot be resolved, or is ambiguous.
fn compile_column_pos(c: &ColumnName, ctx: &EvalContext<'_>) -> Option<usize> {
    if let Some(q) = &c.qualifier {
        let exact = ctx
            .columns
            .iter()
            .position(|s| composite_eq(&s.name, q, &c.name));
        if exact.is_some() {
            return exact;
        }

        // Some column is stored as `q.<something>`: the qualifier is in use,
        // it just does not carry the requested column.
        let qualifier_in_use = ctx.columns.iter().any(|s| {
            s.name
                .strip_prefix(q.as_str())
                .is_some_and(|rest| rest.starts_with('.'))
        });
        if qualifier_in_use {
            return None;
        }

        // Otherwise the qualifier must be the alias of the table in scope.
        if !matches!(ctx.table_alias, Some(a) if a == q) {
            return None;
        }
    }

    let bare = ctx.columns.iter().position(|s| s.name == c.name);
    if bare.is_some() {
        return bare;
    }

    // Fall back to columns stored as `<prefix>.<name>`; accept only a unique hit.
    let mut suffix_hits = ctx.columns.iter().enumerate().filter_map(|(idx, s)| {
        s.name
            .strip_suffix(c.name.as_str())
            .filter(|head| head.ends_with('.'))
            .map(|_| idx)
    });
    let only = suffix_hits.next()?;
    if suffix_hits.next().is_some() {
        None
    } else {
        Some(only)
    }
}
/// Appends the steps that evaluate `e` to `steps`.
///
/// Operands are emitted before the step that consumes them, and every call
/// emits steps whose net run-time effect is one value pushed on the stack.
/// Expressions without a dedicated step are wrapped in `Step::Subtree` and
/// left to `eval_expr` at run time.
fn compile_into(e: &Expr, ctx: &EvalContext<'_>, steps: &mut Vec<Step>) {
    match e {
        Expr::Literal(l) => steps.push(Step::Lit(literal_to_value(l))),
        Expr::Column(c) => match compile_column_pos(c, ctx) {
            Some(pos) => steps.push(Step::Column(pos)),
            // Not resolvable (or ambiguous) here: defer to the interpreter.
            None => steps.push(Step::Subtree(e.clone())),
        },
        Expr::Binary { lhs, op, rhs } => {
            compile_into(lhs, ctx, steps);
            compile_into(rhs, ctx, steps);
            let cmp = matches!(
                op,
                BinOp::Eq | BinOp::NotEq | BinOp::Lt | BinOp::LtEq | BinOp::Gt | BinOp::GtEq
            );
            // A comparison is case-insensitive as soon as either side is a
            // column with a case-insensitive collation.
            let ci = cmp
                && (matches!(
                    column_collation(lhs, ctx),
                    Some(spg_storage::Collation::CaseInsensitive)
                ) || matches!(
                    column_collation(rhs, ctx),
                    Some(spg_storage::Collation::CaseInsensitive)
                ));
            steps.push(if ci {
                Step::BinaryCi(*op)
            } else {
                Step::Binary(*op)
            });
        }
        Expr::Unary { op, expr } => {
            compile_into(expr, ctx, steps);
            steps.push(Step::Unary(*op));
        }
        Expr::IsNull { expr, negated } => {
            compile_into(expr, ctx, steps);
            steps.push(Step::IsNull { negated: *negated });
        }
        Expr::InList {
            expr,
            list,
            negated,
        } => {
            match crate::build_in_list_set(list) {
                Some(entry) if fully_compilable(expr) => {
                    compile_into(expr, ctx, steps);
                    steps.push(Step::InSet {
                        set: entry.set,
                        has_null: entry.has_null,
                        negated: *negated,
                        // Kept for needles whose type does not match the set.
                        fallback: e.clone(),
                    });
                }
                _ => steps.push(Step::Subtree(e.clone())),
            }
        }
        Expr::Like {
            expr,
            pattern,
            negated,
            case_insensitive,
        } => match literal_text_pattern(pattern) {
            Some(pat) if fully_compilable(expr) => {
                // All-wildcard patterns such as '%' are deliberately NOT
                // rewritten into an IS [NOT] NULL test. `x NOT LIKE '%'` must
                // be NULL (never TRUE) for a NULL operand, and `x LIKE '%'`
                // must be NULL rather than FALSE. `Step::Like` already yields
                // NULL for NULL input and rejects non-text operands, so every
                // literal pattern takes that path.
                compile_into(expr, ctx, steps);
                let chars: Vec<char> = if *case_insensitive {
                    // Fold the pattern once here; the operand is folded per row.
                    pat.to_lowercase().chars().collect()
                } else {
                    pat.chars().collect()
                };
                steps.push(Step::Like {
                    pattern: chars,
                    negated: *negated,
                    case_insensitive: *case_insensitive,
                });
            }
            _ => steps.push(Step::Subtree(e.clone())),
        },
        Expr::FunctionCall { name, args } if is_pure_scalar_function(name) => {
            let lower = name.to_ascii_lowercase();
            // Length of a bare column gets a dedicated step that reads the
            // row directly instead of going through the stack.
            if args.len() == 1 {
                if let Expr::Column(c) = &args[0]
                    && let Some(pos) = compile_column_pos(c, ctx)
                {
                    match lower.as_str() {
                        "length" | "char_length" | "character_length" => {
                            steps.push(Step::ColumnLength { pos });
                            return;
                        }
                        "octet_length" => {
                            steps.push(Step::ColumnOctetLength { pos });
                            return;
                        }
                        _ => {}
                    }
                }
            }
            for a in args {
                compile_into(a, ctx, steps);
            }
            steps.push(Step::Function {
                name_lower: lower,
                n_args: args.len(),
            });
        }
        Expr::Cast { expr, target } => {
            compile_into(expr, ctx, steps);
            steps.push(Step::Cast { target: *target });
        }
        other => steps.push(Step::Subtree(other.clone())),
    }
}
/// Returns the pattern text when `pattern` is a string literal, `None` for
/// any other expression.
fn literal_text_pattern(pattern: &Expr) -> Option<&str> {
    if let Expr::Literal(Literal::String(text)) = pattern {
        Some(text.as_str())
    } else {
        None
    }
}
/// Reports whether `e` is built only from node kinds this module compiles
/// directly: literals, columns, binary/unary operators, `IS NULL`, casts,
/// `IN` lists accepted by `crate::build_in_list_set`, `LIKE` with a
/// string-literal pattern, and calls to functions on the pure-scalar list.
pub(crate) fn fully_compilable(e: &Expr) -> bool {
    match e {
        Expr::Literal(_) | Expr::Column(_) => true,
        Expr::Unary { expr, .. } | Expr::IsNull { expr, .. } => fully_compilable(expr),
        Expr::Cast { expr, .. } => fully_compilable(expr),
        Expr::Binary { lhs, rhs, .. } => {
            if !fully_compilable(lhs) {
                return false;
            }
            fully_compilable(rhs)
        }
        Expr::FunctionCall { name, args } => {
            if !is_pure_scalar_function(name) {
                return false;
            }
            args.iter().all(fully_compilable)
        }
        Expr::InList { expr, list, .. } => {
            if !fully_compilable(expr) {
                return false;
            }
            crate::build_in_list_set(list).is_some()
        }
        Expr::Like { expr, pattern, .. } => {
            if !fully_compilable(expr) {
                return false;
            }
            literal_text_pattern(pattern).is_some()
        }
        _ => false,
    }
}
/// Reports whether `name` (compared ASCII-case-insensitively) is one of the
/// scalar functions this module compiles into a `Step::Function`.
fn is_pure_scalar_function(name: &str) -> bool {
    // All entries are lowercase ASCII, so an ASCII-case-insensitive comparison
    // is the same as lowercasing `name` and comparing exactly.
    const PURE_SCALARS: &[&str] = &[
        // text
        "length",
        "char_length",
        "character_length",
        "octet_length",
        "upper",
        "lower",
        "trim",
        "ltrim",
        "rtrim",
        "btrim",
        "left",
        "right",
        "substring",
        "substr",
        "replace",
        "position",
        "strpos",
        "concat",
        "concat_ws",
        "reverse",
        "repeat",
        "lpad",
        "rpad",
        "split_part",
        // null handling / selection
        "coalesce",
        "nullif",
        "greatest",
        "least",
        "ifnull",
        "isnull",
        "nvl",
        // numeric
        "abs",
        "ceil",
        "ceiling",
        "floor",
        "round",
        "trunc",
        "sqrt",
        "power",
        "pow",
        "mod",
        "sign",
        "log",
        "log10",
        "exp",
        "ln",
        // conversion
        "cast",
    ];
    PURE_SCALARS
        .iter()
        .any(|&known| name.eq_ignore_ascii_case(known))
}
pub(crate) fn compile_expr(e: &Expr, ctx: &EvalContext<'_>) -> CompiledExpr {
let mut steps = Vec::new();
compile_into(e, ctx, &mut steps);
CompiledExpr { steps }
}
/// Evaluates a compiled expression against a plain `Row`.
///
/// Wraps `row` in `RowRef::Owned` and delegates to `eval_compiled_ref`;
/// `stack` is the scratch operand stack passed through to it.
///
/// # Errors
/// Returns whatever error `eval_compiled_ref` reports.
pub(crate) fn eval_compiled(
    c: &CompiledExpr,
    row: &Row,
    ctx: &EvalContext<'_>,
    stack: &mut Vec<Value>,
) -> Result<Value, EvalError> {
    let row_ref = crate::join::RowRef::Owned(row);
    eval_compiled_ref(c, &row_ref, ctx, stack)
}
/// Runs a compiled expression against `row` and returns its value.
///
/// `stack` is caller-provided scratch space: it is cleared on entry and used
/// as the operand stack. The result is whatever is on top after the last
/// step, or NULL when the stack ends up empty. A step that finds too few
/// operands treats the missing ones as NULL.
///
/// # Errors
/// Propagates errors from `apply_binary`, `apply_unary`,
/// `apply_function_lower`, `cast_value` and `eval_expr`, and returns
/// `EvalError::TypeMismatch` when a LIKE operand is neither text nor NULL or
/// when a length step reads a column that is neither text, bytes nor NULL.
pub(crate) fn eval_compiled_ref(
    c: &CompiledExpr,
    row: &crate::join::RowRef<'_>,
    ctx: &EvalContext<'_>,
    stack: &mut Vec<Value>,
) -> Result<Value, EvalError> {
    /// Pops the top operand, treating an empty stack as NULL.
    fn pop_or_null(stack: &mut Vec<Value>) -> Value {
        stack.pop().unwrap_or(Value::Null)
    }

    /// Wraps a length as an integer value, saturating at `i32::MAX`.
    fn clamped_len(n: usize) -> Value {
        Value::Int(i32::try_from(n).unwrap_or(i32::MAX))
    }

    /// ASCII-lowercases text values; every other value passes through.
    fn ascii_fold(v: Value) -> Value {
        match v {
            Value::Text(s) => Value::Text(s.to_ascii_lowercase()),
            other => other,
        }
    }

    stack.clear();
    for step in &c.steps {
        match step {
            Step::Column(pos) => {
                let cell = row.get(*pos).cloned().unwrap_or(Value::Null);
                stack.push(cell);
            }
            Step::Lit(v) => stack.push(v.clone()),
            Step::Binary(op) => {
                // Operands were pushed left-then-right, so the right one is on top.
                let rhs = pop_or_null(stack);
                let lhs = pop_or_null(stack);
                stack.push(apply_binary(*op, lhs, rhs)?);
            }
            Step::BinaryCi(op) => {
                let rhs = ascii_fold(pop_or_null(stack));
                let lhs = ascii_fold(pop_or_null(stack));
                stack.push(apply_binary(*op, lhs, rhs)?);
            }
            Step::Unary(op) => {
                let operand = pop_or_null(stack);
                stack.push(apply_unary(*op, operand)?);
            }
            Step::IsNull { negated } => {
                let operand = pop_or_null(stack);
                let is_null = matches!(operand, Value::Null);
                stack.push(Value::Bool(is_null != *negated));
            }
            Step::InSet {
                set,
                has_null,
                negated,
                fallback,
            } => {
                let needle = pop_or_null(stack);
                // Three-valued verdict: a hit is definite; a miss is NULL when
                // the list contained NULL; negation flips only definite answers.
                let verdict = |found: bool| -> Value {
                    if found {
                        Value::Bool(!*negated)
                    } else if *has_null {
                        Value::Null
                    } else {
                        Value::Bool(*negated)
                    }
                };
                let outcome = match (&needle, set) {
                    (Value::Null, _) => Value::Null,
                    (Value::SmallInt(n), crate::memoize::InListSet::Int(s)) => {
                        verdict(s.contains(&i64::from(*n)))
                    }
                    (Value::Int(n), crate::memoize::InListSet::Int(s)) => {
                        verdict(s.contains(&i64::from(*n)))
                    }
                    (Value::BigInt(n), crate::memoize::InListSet::Int(s)) => {
                        verdict(s.contains(n))
                    }
                    (Value::Text(t), crate::memoize::InListSet::Text(s)) => {
                        verdict(s.contains(t.as_str()))
                    }
                    // Needle type does not match the set: evaluate the original
                    // IN expression instead.
                    _ => eval_expr(fallback, &row.as_row(), ctx)?,
                };
                stack.push(outcome);
            }
            Step::Like {
                pattern,
                negated,
                case_insensitive,
            } => match pop_or_null(stack) {
                Value::Null => stack.push(Value::Null),
                Value::Text(t) => {
                    // The pattern was folded at compile time; fold the operand
                    // the same way here.
                    let haystack: Vec<char> = if *case_insensitive {
                        t.to_lowercase().chars().collect()
                    } else {
                        t.chars().collect()
                    };
                    let matched = like_match_inner(&haystack, 0, pattern, 0);
                    stack.push(Value::Bool(matched != *negated));
                }
                other => {
                    return Err(EvalError::TypeMismatch {
                        detail: format!(
                            "LIKE requires text operands, got {:?}",
                            other.data_type()
                        ),
                    });
                }
            },
            Step::ColumnLength { pos } => {
                let cell = row.get(*pos).unwrap_or(&Value::Null);
                let measured = match cell {
                    Value::Null => Value::Null,
                    // For ASCII text the byte length equals the character count.
                    Value::Text(s) if s.is_ascii() => clamped_len(s.len()),
                    Value::Text(s) => clamped_len(s.chars().count()),
                    Value::Bytes(b) => clamped_len(b.len()),
                    other => {
                        return Err(EvalError::TypeMismatch {
                            detail: format!(
                                "length() needs text or bytea, got {:?}",
                                other.data_type()
                            ),
                        });
                    }
                };
                stack.push(measured);
            }
            Step::ColumnOctetLength { pos } => {
                let cell = row.get(*pos).unwrap_or(&Value::Null);
                let measured = match cell {
                    Value::Null => Value::Null,
                    Value::Text(s) => clamped_len(s.len()),
                    Value::Bytes(b) => clamped_len(b.len()),
                    other => {
                        return Err(EvalError::TypeMismatch {
                            detail: format!(
                                "octet_length() needs text or bytea, got {:?}",
                                other.data_type()
                            ),
                        });
                    }
                };
                stack.push(measured);
            }
            Step::Function { name_lower, n_args } => {
                // Arguments are the top `n_args` values, in call order.
                let base = stack.len().saturating_sub(*n_args);
                let value =
                    super::functions::apply_function_lower(name_lower, &stack[base..], ctx)?;
                stack.truncate(base);
                stack.push(value);
            }
            Step::Cast { target } => {
                let operand = pop_or_null(stack);
                stack.push(super::cast::cast_value(operand, *target)?);
            }
            Step::Subtree(e) => {
                let value = eval_expr(e, &row.as_row(), ctx)?;
                stack.push(value);
            }
        }
    }
    Ok(pop_or_null(stack))
}