use std::any::Any;
use std::fmt::{self, Debug};
use rustc_hash::FxHashMap;
use std::sync::Arc;
use crate::core::{Operator, Result, Row, Schema, Value};
use crate::functions::ScalarFunction;
use super::{find_column_index, resolve_alias, Expression};
/// A scalar-function predicate of the form `function(arguments) <operator> compare_value`.
///
/// Column arguments are looked up by name in `prepare_for_schema` and afterwards read
/// positionally from each row.
pub struct FunctionExpr {
    /// The scalar function being called.
    function: Arc<dyn ScalarFunction>,
    /// Arguments passed to `function`, in call order.
    arguments: Vec<FunctionArg>,
    /// Comparison applied between the function result and `compare_value`.
    operator: Operator,
    /// Right-hand side of the comparison.
    compare_value: Value,
    /// Set to `true` by `prepare_for_schema`.
    prepared: bool,
    /// Row positions used to read column arguments; filled in by `prepare_for_schema`,
    /// with `None` wherever no position is available.
    arg_indices: Vec<Option<usize>>,
}
thread_local! {
static ARG_BUFFER: std::cell::RefCell<Vec<Value>> = std::cell::RefCell::new(Vec::with_capacity(4));
}
/// One argument of a [`FunctionExpr`] call.
#[derive(Clone)]
pub enum FunctionArg {
    /// A column of the input row, referenced by name.
    Column(String),
    /// A constant value.
    Literal(Value),
    /// A nested scalar function call whose result becomes this argument.
    Function {
        /// The nested function.
        function: Arc<dyn ScalarFunction>,
        /// Its arguments, all evaluated before the nested call is made.
        arguments: Vec<Self>,
    },
}
impl Debug for FunctionArg {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
FunctionArg::Column(col) => f.debug_tuple("Column").field(col).finish(),
FunctionArg::Literal(val) => f.debug_tuple("Literal").field(val).finish(),
FunctionArg::Function {
function,
arguments,
} => f
.debug_struct("Function")
.field("name", &function.name())
.field("arguments", arguments)
.finish(),
}
}
}
impl FunctionExpr {
pub fn new(
function: Arc<dyn ScalarFunction>,
arguments: Vec<FunctionArg>,
operator: Operator,
compare_value: Value,
) -> Self {
let arg_count = arguments.len();
Self {
function,
arguments,
operator,
compare_value,
arg_indices: vec![None; arg_count],
prepared: false,
}
}
pub fn eq(
function: Arc<dyn ScalarFunction>,
arguments: Vec<FunctionArg>,
compare_value: Value,
) -> Self {
Self::new(function, arguments, Operator::Eq, compare_value)
}
pub fn boolean(function: Arc<dyn ScalarFunction>, arguments: Vec<FunctionArg>) -> Self {
Self::new(function, arguments, Operator::Eq, Value::Boolean(true))
}
pub fn function_name(&self) -> &str {
self.function.name()
}
pub fn get_arguments(&self) -> &[FunctionArg] {
&self.arguments
}
pub fn get_operator(&self) -> Operator {
self.operator
}
pub fn get_compare_value(&self) -> &Value {
&self.compare_value
}
#[allow(clippy::only_used_in_recursion)]
fn evaluate_arg(
&self,
arg: &FunctionArg,
arg_index: Option<usize>,
row: &Row,
) -> Result<Value> {
match arg {
FunctionArg::Column(col_name) => {
if let Some(idx) = arg_index {
Ok(row.get(idx).cloned().unwrap_or_else(Value::null_unknown))
} else {
Err(crate::core::Error::ColumnNotFound(col_name.to_string()))
}
}
FunctionArg::Literal(value) => Ok(value.clone()),
FunctionArg::Function {
function,
arguments,
} => {
let args: Result<Vec<Value>> = arguments
.iter()
.map(|a| self.evaluate_arg(a, None, row))
.collect();
function.evaluate(&args?)
}
}
}
fn compare(&self, result: &Value, target: &Value) -> bool {
match self.operator {
Operator::Eq => result == target,
Operator::Ne => result != target,
Operator::Lt => result < target,
Operator::Lte => result <= target,
Operator::Gt => result > target,
Operator::Gte => result >= target,
_ => false,
}
}
}
impl Debug for FunctionExpr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FunctionExpr")
.field("function", &self.function.name())
.field("arguments", &self.arguments)
.field("operator", &self.operator)
.field("compare_value", &self.compare_value)
.field("prepared", &self.prepared)
.finish()
}
}
impl Expression for FunctionExpr {
fn evaluate(&self, row: &Row) -> Result<bool> {
if self.arguments.len() == 1 {
if let FunctionArg::Column(_) = &self.arguments[0] {
if let Some(idx) = self.arg_indices.first().copied().flatten() {
let value = row.get(idx).cloned().unwrap_or_else(Value::null_unknown);
let result = self.function.evaluate(std::slice::from_ref(&value))?;
return Ok(self.compare(&result, &self.compare_value));
}
}
}
ARG_BUFFER.with(|buf_cell| {
let mut arg_values = buf_cell.borrow_mut();
arg_values.clear();
for (i, arg) in self.arguments.iter().enumerate() {
let value =
self.evaluate_arg(arg, self.arg_indices.get(i).copied().flatten(), row)?;
arg_values.push(value);
}
let result = self.function.evaluate(&arg_values)?;
Ok(self.compare(&result, &self.compare_value))
})
}
fn evaluate_fast(&self, row: &Row) -> bool {
self.evaluate(row).unwrap_or(false)
}
fn with_aliases(&self, aliases: &FxHashMap<String, String>) -> Box<dyn Expression> {
let new_arguments: Vec<FunctionArg> = self
.arguments
.iter()
.map(|arg| match arg {
FunctionArg::Column(col) => {
FunctionArg::Column(resolve_alias(col, aliases).to_string())
}
FunctionArg::Literal(v) => FunctionArg::Literal(v.clone()),
FunctionArg::Function {
function,
arguments,
} => {
FunctionArg::Function {
function: Arc::clone(function),
arguments: arguments
.iter()
.map(|a| match a {
FunctionArg::Column(c) => {
FunctionArg::Column(resolve_alias(c, aliases).to_string())
}
other => other.clone(),
})
.collect(),
}
}
})
.collect();
Box::new(FunctionExpr {
function: Arc::clone(&self.function),
arguments: new_arguments,
operator: self.operator,
compare_value: self.compare_value.clone(),
arg_indices: vec![None; self.arguments.len()],
prepared: false,
})
}
fn prepare_for_schema(&mut self, schema: &Schema) {
self.arg_indices = self
.arguments
.iter()
.map(|arg| match arg {
FunctionArg::Column(col) => find_column_index(schema, col),
_ => None,
})
.collect();
self.prepared = true;
}
fn is_prepared(&self) -> bool {
self.prepared
}
fn can_use_index(&self) -> bool {
false
}
fn clone_box(&self) -> Box<dyn Expression> {
Box::new(FunctionExpr {
function: Arc::clone(&self.function),
arguments: self.arguments.clone(),
operator: self.operator,
compare_value: self.compare_value.clone(),
arg_indices: self.arg_indices.clone(),
prepared: self.prepared,
})
}
fn as_any(&self) -> &dyn Any {
self
}
}
pub struct EvalExpr {
eval_fn: Box<dyn Fn(&Row) -> bool + Send + Sync>,
}
impl EvalExpr {
pub fn new<F>(f: F) -> Self
where
F: Fn(&Row) -> bool + Send + Sync + 'static,
{
Self {
eval_fn: Box::new(f),
}
}
}
impl Debug for EvalExpr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("EvalExpr").finish()
}
}
impl Expression for EvalExpr {
fn evaluate(&self, row: &Row) -> Result<bool> {
Ok((self.eval_fn)(row))
}
fn evaluate_fast(&self, row: &Row) -> bool {
(self.eval_fn)(row)
}
fn with_aliases(&self, _aliases: &FxHashMap<String, String>) -> Box<dyn Expression> {
panic!("EvalExpr does not support alias resolution")
}
fn prepare_for_schema(&mut self, _schema: &Schema) {
}
fn is_prepared(&self) -> bool {
true }
fn clone_box(&self) -> Box<dyn Expression> {
panic!("EvalExpr cannot be cloned")
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{DataType, SchemaBuilder};
    use crate::functions::{
        FunctionDataType, FunctionInfo, FunctionSignature, FunctionType, UpperFunction,
    };

    /// Schema used by every test: `id` (primary key), `name`, `age`.
    fn test_schema() -> Schema {
        SchemaBuilder::new("test")
            .add_primary_key("id", DataType::Integer)
            .add("name", DataType::Text)
            .add("age", DataType::Integer)
            .build()
    }

    /// Builds an `(id, name, age)` row laid out like `test_schema`.
    fn person(id: i64, name: &'static str, age: i64) -> Row {
        Row::from_values(vec![
            Value::Integer(id),
            Value::text(name),
            Value::Integer(age),
        ])
    }

    /// Minimal LENGTH implementation: text -> byte length, NULL -> NULL, anything else -> 0.
    struct TestLengthFn;

    impl ScalarFunction for TestLengthFn {
        fn name(&self) -> &str {
            "LENGTH"
        }

        fn info(&self) -> FunctionInfo {
            let signature = FunctionSignature::new(
                FunctionDataType::Integer,
                vec![FunctionDataType::String],
                1,
                1,
            );
            FunctionInfo::new(
                "LENGTH",
                FunctionType::Scalar,
                "Returns the length of a string",
                signature,
            )
        }

        fn evaluate(&self, args: &[Value]) -> Result<Value> {
            let out = match args.first() {
                Some(Value::Text(s)) => Value::Integer(s.len() as i64),
                Some(Value::Null(_)) => Value::null_unknown(),
                _ => Value::Integer(0),
            };
            Ok(out)
        }

        fn clone_box(&self) -> Box<dyn ScalarFunction> {
            Box::new(TestLengthFn)
        }
    }

    #[test]
    fn test_function_expr_upper() {
        let mut expr = FunctionExpr::eq(
            Arc::new(UpperFunction),
            vec![FunctionArg::Column("name".to_string())],
            Value::text("ALICE"),
        );
        expr.prepare_for_schema(&test_schema());
        // UPPER(name) = 'ALICE' matches regardless of the stored casing.
        assert!(expr.evaluate(&person(1, "alice", 30)).unwrap());
        assert!(expr.evaluate(&person(2, "Alice", 25)).unwrap());
        assert!(!expr.evaluate(&person(3, "bob", 35)).unwrap());
    }

    #[test]
    fn test_function_expr_with_literal() {
        let mut expr = FunctionExpr::eq(
            Arc::new(UpperFunction),
            vec![FunctionArg::Literal(Value::text("hello"))],
            Value::text("HELLO"),
        );
        expr.prepare_for_schema(&test_schema());
        // The row content is irrelevant when the only argument is a literal.
        assert!(expr.evaluate(&person(1, "anything", 30)).unwrap());
    }

    #[test]
    fn test_function_expr_operators() {
        let mut expr = FunctionExpr::new(
            Arc::new(TestLengthFn),
            vec![FunctionArg::Column("name".to_string())],
            Operator::Gt,
            Value::Integer(3),
        );
        expr.prepare_for_schema(&test_schema());
        // LENGTH(name) > 3
        assert!(expr.evaluate(&person(1, "alice", 30)).unwrap());
        assert!(!expr.evaluate(&person(2, "bob", 25)).unwrap());
    }

    #[test]
    fn test_eval_expr() {
        let expr = EvalExpr::new(|row| matches!(row.get(0), Some(Value::Integer(n)) if *n > 5));
        assert!(expr
            .evaluate(&Row::from_values(vec![Value::Integer(10)]))
            .unwrap());
        assert!(!expr
            .evaluate(&Row::from_values(vec![Value::Integer(3)]))
            .unwrap());
    }

    #[test]
    fn test_function_expr_clone() {
        let expr = FunctionExpr::eq(
            Arc::new(UpperFunction),
            vec![FunctionArg::Column("name".to_string())],
            Value::text("ALICE"),
        );
        let rendered = format!("{:?}", expr.clone_box());
        assert!(rendered.contains("FunctionExpr"));
    }
}