use std::any::Any;
use chrono::{DateTime, Utc};
use rustc_hash::FxHashMap;
use super::{find_column_index, resolve_alias, Expression};
use crate::common::SmartString;
use crate::core::{DataType, Error, Operator, Result, Row, Schema, Value};
/// Pre-converted form of the literal on the right-hand side of a
/// [`ComparisonExpr`], built by `ComparisonValue::from_value`.
///
/// Extension values (JSON, vectors, and any other tagged payload) are
/// flattened into `Text` during conversion.
#[derive(Clone, Debug)]
pub enum ComparisonValue {
    /// SQL NULL. A comparison against a NULL literal never matches.
    Null,
    /// 64-bit signed integer literal.
    Integer(i64),
    /// 64-bit floating point literal.
    Float(f64),
    /// Owned text; also the target representation for extension values.
    Text(String),
    /// Boolean literal; only `=` and `<>` are supported against it.
    Boolean(bool),
    /// UTC timestamp literal.
    Timestamp(DateTime<Utc>),
}
impl ComparisonValue {
    /// Converts a stored [`Value`] into its comparison form.
    ///
    /// Scalar variants map one-to-one (`Text` is copied into an owned `String`).
    /// `Extension` values are flattened to `Text`: their first byte is a
    /// [`DataType`] tag and the remaining bytes are the payload.
    /// - JSON payloads are decoded as UTF-8; invalid UTF-8 becomes `""`.
    /// - Vector payloads are rendered with `format_vector_bytes`.
    /// - Any other tag is decoded lossily (invalid sequences become U+FFFD).
    ///
    /// An empty extension buffer (no tag byte at all) yields `Text("")`
    /// instead of panicking.
    pub fn from_value(value: &Value) -> Self {
        match value {
            Value::Null(_) => ComparisonValue::Null,
            Value::Integer(i) => ComparisonValue::Integer(*i),
            Value::Float(f) => ComparisonValue::Float(*f),
            Value::Text(s) => ComparisonValue::Text(s.to_string()),
            Value::Boolean(b) => ComparisonValue::Boolean(*b),
            Value::Timestamp(t) => ComparisonValue::Timestamp(*t),
            Value::Extension(data) => {
                let tag = data.first().copied();
                // `get(1..)` rather than `&data[1..]`: slicing an empty buffer
                // from index 1 panics, while `get` returns `None`.
                let payload: &[u8] = data.get(1..).unwrap_or_default();
                let text = if tag == Some(DataType::Json as u8) {
                    std::str::from_utf8(payload).unwrap_or("").to_string()
                } else if tag == Some(DataType::Vector as u8) {
                    crate::core::value::format_vector_bytes(payload)
                } else {
                    String::from_utf8_lossy(payload).into_owned()
                };
                ComparisonValue::Text(text)
            }
        }
    }

    /// Returns the [`DataType`] matching this variant.
    ///
    /// `Null` carries no type of its own and reports `Text`.
    pub fn data_type(&self) -> DataType {
        match self {
            ComparisonValue::Null => DataType::Text,
            ComparisonValue::Integer(_) => DataType::Integer,
            ComparisonValue::Float(_) => DataType::Float,
            ComparisonValue::Text(_) => DataType::Text,
            ComparisonValue::Boolean(_) => DataType::Boolean,
            ComparisonValue::Timestamp(_) => DataType::Timestamp,
        }
    }

    /// Returns `true` for the `Null` variant.
    pub fn is_null(&self) -> bool {
        matches!(self, ComparisonValue::Null)
    }

    /// Converts back into a [`Value`].
    ///
    /// `Null` becomes a `Text`-typed NULL. Extension values that `from_value`
    /// flattened come back as plain `Text`, not as the original extension.
    pub fn to_value(&self) -> Value {
        match self {
            ComparisonValue::Null => Value::Null(DataType::Text),
            ComparisonValue::Integer(i) => Value::Integer(*i),
            ComparisonValue::Float(f) => Value::Float(*f),
            ComparisonValue::Text(s) => Value::Text(SmartString::new(s)),
            ComparisonValue::Boolean(b) => Value::Boolean(*b),
            ComparisonValue::Timestamp(t) => Value::Timestamp(*t),
        }
    }
}
/// A single-column predicate of the form `column <operator> literal`,
/// including the `IS NULL` / `IS NOT NULL` tests.
///
/// The column is resolved to a positional index by `prepare_for_schema`;
/// until that happens every evaluation yields `false`.
#[derive(Clone, Debug)]
pub struct ComparisonExpr {
    /// Column the predicate reads (the resolved name when an alias was applied).
    column: String,
    /// Comparison or null-test operator.
    operator: Operator,
    /// The literal in the form used during evaluation.
    value: ComparisonValue,
    /// The literal exactly as passed to `new`; exposed by `get_comparison_info`.
    original_value: Value,
    /// Position of `column` within a row; `None` until prepared.
    col_index: Option<usize>,
    /// Alias map captured by `with_aliases`.
    aliases: FxHashMap<String, String>,
    /// Name as written before alias resolution; set by `with_aliases` only
    /// when the alias actually changed the column name.
    original_column: Option<String>,
}
impl ComparisonExpr {
pub fn new(column: impl Into<String>, operator: Operator, value: Value) -> Self {
Self {
column: column.into(),
operator,
value: ComparisonValue::from_value(&value),
original_value: value,
col_index: None,
aliases: FxHashMap::default(),
original_column: None,
}
}
pub fn eq(column: impl Into<String>, value: Value) -> Self {
Self::new(column, Operator::Eq, value)
}
pub fn ne(column: impl Into<String>, value: Value) -> Self {
Self::new(column, Operator::Ne, value)
}
pub fn gt(column: impl Into<String>, value: Value) -> Self {
Self::new(column, Operator::Gt, value)
}
pub fn gte(column: impl Into<String>, value: Value) -> Self {
Self::new(column, Operator::Gte, value)
}
pub fn lt(column: impl Into<String>, value: Value) -> Self {
Self::new(column, Operator::Lt, value)
}
pub fn lte(column: impl Into<String>, value: Value) -> Self {
Self::new(column, Operator::Lte, value)
}
pub fn column(&self) -> &str {
&self.column
}
pub fn operator(&self) -> Operator {
self.operator
}
pub fn value(&self) -> &ComparisonValue {
&self.value
}
pub fn integer_value(&self) -> Option<i64> {
match &self.value {
ComparisonValue::Integer(i) => Some(*i),
_ => None,
}
}
#[inline]
fn compare_integers(&self, col_val: i64, cmp_val: i64) -> bool {
match self.operator {
Operator::Eq => col_val == cmp_val,
Operator::Ne => col_val != cmp_val,
Operator::Gt => col_val > cmp_val,
Operator::Gte => col_val >= cmp_val,
Operator::Lt => col_val < cmp_val,
Operator::Lte => col_val <= cmp_val,
_ => false,
}
}
#[inline]
fn compare_floats(&self, col_val: f64, cmp_val: f64) -> bool {
match self.operator {
Operator::Eq => col_val == cmp_val,
Operator::Ne => col_val != cmp_val,
Operator::Gt => col_val > cmp_val,
Operator::Gte => col_val >= cmp_val,
Operator::Lt => col_val < cmp_val,
Operator::Lte => col_val <= cmp_val,
_ => false,
}
}
#[inline]
fn compare_strings(&self, col_val: &str, cmp_val: &str) -> bool {
match self.operator {
Operator::Eq => col_val == cmp_val,
Operator::Ne => col_val != cmp_val,
Operator::Gt => col_val > cmp_val,
Operator::Gte => col_val >= cmp_val,
Operator::Lt => col_val < cmp_val,
Operator::Lte => col_val <= cmp_val,
_ => false,
}
}
#[inline]
fn compare_booleans(&self, col_val: bool, cmp_val: bool) -> bool {
match self.operator {
Operator::Eq => col_val == cmp_val,
Operator::Ne => col_val != cmp_val,
_ => false, }
}
#[inline]
fn compare_timestamps(&self, col_val: DateTime<Utc>, cmp_val: DateTime<Utc>) -> bool {
match self.operator {
Operator::Eq => col_val == cmp_val,
Operator::Ne => col_val != cmp_val,
Operator::Gt => col_val > cmp_val,
Operator::Gte => col_val >= cmp_val,
Operator::Lt => col_val < cmp_val,
Operator::Lte => col_val <= cmp_val,
_ => false,
}
}
}
impl Expression for ComparisonExpr {
fn evaluate(&self, row: &Row) -> Result<bool> {
let col_idx = match self.col_index {
Some(idx) => idx,
None => return Ok(false),
};
if col_idx >= row.len() {
return Ok(false);
}
let col_value = &row[col_idx];
if col_value.is_null() {
return Ok(matches!(self.operator, Operator::IsNull));
}
match self.operator {
Operator::IsNull => return Ok(col_value.is_null()),
Operator::IsNotNull => return Ok(!col_value.is_null()),
_ => {}
}
if self.value.is_null() {
return Ok(false);
}
match (&self.value, col_value) {
(ComparisonValue::Integer(cmp_val), Value::Integer(col_val)) => {
Ok(self.compare_integers(*col_val, *cmp_val))
}
(ComparisonValue::Float(cmp_val), Value::Float(col_val)) => {
Ok(self.compare_floats(*col_val, *cmp_val))
}
(ComparisonValue::Text(cmp_val), Value::Text(col_val)) => {
Ok(self.compare_strings(col_val, cmp_val))
}
(ComparisonValue::Boolean(cmp_val), Value::Boolean(col_val)) => {
Ok(self.compare_booleans(*col_val, *cmp_val))
}
(ComparisonValue::Timestamp(cmp_val), Value::Timestamp(col_val)) => {
Ok(self.compare_timestamps(*col_val, *cmp_val))
}
(ComparisonValue::Integer(cmp_val), Value::Float(col_val)) => {
Ok(self.compare_floats(*col_val, *cmp_val as f64))
}
(ComparisonValue::Float(cmp_val), Value::Integer(col_val)) => {
Ok(self.compare_floats(*col_val as f64, *cmp_val))
}
_ => Err(Error::type_conversion(
format!("{:?}", col_value.data_type()),
format!("{:?}", self.value.data_type()),
)),
}
}
fn evaluate_fast(&self, row: &Row) -> bool {
let col_idx = match self.col_index {
Some(idx) if idx < row.len() => idx,
_ => return false,
};
let col_value = &row[col_idx];
if col_value.is_null() {
return matches!(self.operator, Operator::IsNull);
}
match self.operator {
Operator::IsNull => return col_value.is_null(),
Operator::IsNotNull => return !col_value.is_null(),
_ => {}
}
if self.value.is_null() {
return false;
}
match (&self.value, col_value) {
(ComparisonValue::Integer(cmp_val), Value::Integer(col_val)) => {
self.compare_integers(*col_val, *cmp_val)
}
(ComparisonValue::Float(cmp_val), Value::Float(col_val)) => {
self.compare_floats(*col_val, *cmp_val)
}
(ComparisonValue::Text(cmp_val), Value::Text(col_val)) => {
self.compare_strings(col_val, cmp_val)
}
(ComparisonValue::Boolean(cmp_val), Value::Boolean(col_val)) => {
self.compare_booleans(*col_val, *cmp_val)
}
(ComparisonValue::Timestamp(cmp_val), Value::Timestamp(col_val)) => {
self.compare_timestamps(*col_val, *cmp_val)
}
(ComparisonValue::Integer(cmp_val), Value::Float(col_val)) => {
self.compare_floats(*col_val, *cmp_val as f64)
}
(ComparisonValue::Float(cmp_val), Value::Integer(col_val)) => {
self.compare_floats(*col_val as f64, *cmp_val)
}
_ => false,
}
}
fn with_aliases(&self, aliases: &FxHashMap<String, String>) -> Box<dyn Expression> {
let resolved = resolve_alias(&self.column, aliases);
let mut expr = self.clone();
if resolved != self.column {
expr.original_column = Some(self.column.clone());
expr.column = resolved.to_string();
}
expr.aliases = aliases.clone();
expr.col_index = None; Box::new(expr)
}
fn prepare_for_schema(&mut self, schema: &Schema) {
if self.col_index.is_some() {
return; }
self.col_index = find_column_index(schema, &self.column);
}
fn is_prepared(&self) -> bool {
self.col_index.is_some()
}
fn get_column_name(&self) -> Option<&str> {
Some(&self.column)
}
fn collect_column_indices(&self, out: &mut Vec<usize>) -> bool {
if let Some(idx) = self.col_index {
out.push(idx);
true
} else {
false
}
}
fn can_use_index(&self) -> bool {
matches!(
self.operator,
Operator::Eq | Operator::Gt | Operator::Gte | Operator::Lt | Operator::Lte
)
}
fn get_comparison_info(&self) -> Option<(&str, Operator, &Value)> {
Some((&self.column, self.operator, &self.original_value))
}
fn is_conjunctive_simple(&self) -> bool {
true
}
fn clone_box(&self) -> Box<dyn Expression> {
Box::new(self.clone())
}
fn is_unknown_due_to_null(&self, row: &Row) -> bool {
let col_idx = match self.col_index {
Some(idx) if idx < row.len() => idx,
_ => return false,
};
let col_value = &row[col_idx];
if matches!(self.operator, Operator::IsNull | Operator::IsNotNull) {
return false;
}
if col_value.is_null() {
return true;
}
if self.value.is_null() {
return true;
}
false
}
fn as_any(&self) -> &dyn Any {
self
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::SchemaBuilder;

    /// Schema shared by most tests: `id` (Integer PK), `name` (Text),
    /// `score` (Float), `active` (Boolean).
    fn test_schema() -> Schema {
        SchemaBuilder::new("test")
            .add_primary_key("id", DataType::Integer)
            .add("name", DataType::Text)
            .add("score", DataType::Float)
            .add("active", DataType::Boolean)
            .build()
    }

    /// A fully populated row matching `test_schema`.
    fn test_row() -> Row {
        Row::from_values(vec![
            Value::integer(1),
            Value::text("Alice"),
            Value::float(95.5),
            Value::boolean(true),
        ])
    }

    /// Same shape as `test_row`, but with a NULL `name`.
    fn null_name_row() -> Row {
        Row::from_values(vec![
            Value::integer(1),
            Value::null(DataType::Text),
            Value::float(95.5),
            Value::boolean(true),
        ])
    }

    #[test]
    fn test_integer_equality() {
        let schema = test_schema();
        let row = test_row();
        let mut expr = ComparisonExpr::eq("id", Value::integer(1));
        expr.prepare_for_schema(&schema);
        assert!(expr.evaluate(&row).unwrap());
        assert!(expr.evaluate_fast(&row));
        let mut expr = ComparisonExpr::eq("id", Value::integer(2));
        expr.prepare_for_schema(&schema);
        assert!(!expr.evaluate(&row).unwrap());
        assert!(!expr.evaluate_fast(&row));
    }

    #[test]
    fn test_integer_comparison() {
        let schema = test_schema();
        let row = test_row();
        let mut expr = ComparisonExpr::gt("id", Value::integer(0));
        expr.prepare_for_schema(&schema);
        assert!(expr.evaluate(&row).unwrap());
        let mut expr = ComparisonExpr::gte("id", Value::integer(1));
        expr.prepare_for_schema(&schema);
        assert!(expr.evaluate(&row).unwrap());
        let mut expr = ComparisonExpr::lt("id", Value::integer(2));
        expr.prepare_for_schema(&schema);
        assert!(expr.evaluate(&row).unwrap());
        let mut expr = ComparisonExpr::lte("id", Value::integer(1));
        expr.prepare_for_schema(&schema);
        assert!(expr.evaluate(&row).unwrap());
        let mut expr = ComparisonExpr::ne("id", Value::integer(2));
        expr.prepare_for_schema(&schema);
        assert!(expr.evaluate(&row).unwrap());
    }

    #[test]
    fn test_string_equality() {
        let schema = test_schema();
        let row = test_row();
        let mut expr = ComparisonExpr::eq("name", Value::text("Alice"));
        expr.prepare_for_schema(&schema);
        assert!(expr.evaluate(&row).unwrap());
        let mut expr = ComparisonExpr::eq("name", Value::text("Bob"));
        expr.prepare_for_schema(&schema);
        assert!(!expr.evaluate(&row).unwrap());
    }

    #[test]
    fn test_float_comparison() {
        let schema = test_schema();
        let row = test_row();
        let mut expr = ComparisonExpr::gt("score", Value::float(90.0));
        expr.prepare_for_schema(&schema);
        assert!(expr.evaluate(&row).unwrap());
        let mut expr = ComparisonExpr::lt("score", Value::float(100.0));
        expr.prepare_for_schema(&schema);
        assert!(expr.evaluate(&row).unwrap());
    }

    #[test]
    fn test_boolean_comparison() {
        let schema = test_schema();
        let row = test_row();
        let mut expr = ComparisonExpr::eq("active", Value::boolean(true));
        expr.prepare_for_schema(&schema);
        assert!(expr.evaluate(&row).unwrap());
        let mut expr = ComparisonExpr::ne("active", Value::boolean(false));
        expr.prepare_for_schema(&schema);
        assert!(expr.evaluate(&row).unwrap());
    }

    #[test]
    fn test_null_handling() {
        let schema = SchemaBuilder::new("test")
            .add_nullable("value", DataType::Integer)
            .build();
        let null_row = Row::from_values(vec![Value::null(DataType::Integer)]);
        let mut expr = ComparisonExpr::eq("value", Value::integer(1));
        expr.prepare_for_schema(&schema);
        assert!(!expr.evaluate(&null_row).unwrap());
    }

    #[test]
    fn test_cross_type_numeric() {
        let schema = test_schema();
        let row = test_row();
        let mut expr = ComparisonExpr::eq("id", Value::float(1.0));
        expr.prepare_for_schema(&schema);
        assert!(expr.evaluate(&row).unwrap());
        let mut expr = ComparisonExpr::gt("score", Value::integer(90));
        expr.prepare_for_schema(&schema);
        assert!(expr.evaluate(&row).unwrap());
    }

    #[test]
    fn test_unprepared_expression() {
        let row = test_row();
        let expr = ComparisonExpr::eq("id", Value::integer(1));
        assert!(!expr.evaluate(&row).unwrap());
        assert!(!expr.evaluate_fast(&row));
    }

    #[test]
    fn test_with_aliases() {
        let schema = test_schema();
        let row = test_row();
        let mut aliases = FxHashMap::default();
        aliases.insert("n".to_string(), "name".to_string());
        let expr = ComparisonExpr::eq("n", Value::text("Alice"));
        let mut aliased = expr.with_aliases(&aliases);
        aliased.prepare_for_schema(&schema);
        assert!(aliased.evaluate(&row).unwrap());
    }

    #[test]
    fn test_can_use_index() {
        let expr = ComparisonExpr::eq("id", Value::integer(1));
        assert!(expr.can_use_index());
        let expr = ComparisonExpr::gt("id", Value::integer(1));
        assert!(expr.can_use_index());
        // `<>` cannot drive an index lookup.
        let expr = ComparisonExpr::ne("id", Value::integer(1));
        assert!(!expr.can_use_index());
    }

    #[test]
    fn test_type_mismatch() {
        let schema = test_schema();
        let row = test_row();
        // Integer literal against a Text column: `evaluate` reports an error,
        // while `evaluate_fast` treats it as a non-match.
        let mut expr = ComparisonExpr::eq("name", Value::integer(1));
        expr.prepare_for_schema(&schema);
        assert!(expr.evaluate(&row).is_err());
        assert!(!expr.evaluate_fast(&row));
    }

    #[test]
    fn test_is_null_operators() {
        let schema = test_schema();
        let row = test_row();
        let null_row = null_name_row();

        let mut is_null =
            ComparisonExpr::new("name", Operator::IsNull, Value::null(DataType::Text));
        is_null.prepare_for_schema(&schema);
        assert!(is_null.evaluate(&null_row).unwrap());
        assert!(is_null.evaluate_fast(&null_row));
        assert!(!is_null.evaluate(&row).unwrap());
        assert!(!is_null.evaluate_fast(&row));

        let mut is_not_null =
            ComparisonExpr::new("name", Operator::IsNotNull, Value::null(DataType::Text));
        is_not_null.prepare_for_schema(&schema);
        assert!(!is_not_null.evaluate(&null_row).unwrap());
        assert!(!is_not_null.evaluate_fast(&null_row));
        assert!(is_not_null.evaluate(&row).unwrap());
        assert!(is_not_null.evaluate_fast(&row));
    }

    #[test]
    fn test_is_unknown_due_to_null() {
        let schema = test_schema();
        let row = test_row();
        let null_row = null_name_row();

        // NULL column under a comparison is UNKNOWN, and evaluates to false.
        let mut expr = ComparisonExpr::eq("name", Value::text("Alice"));
        expr.prepare_for_schema(&schema);
        assert!(expr.is_unknown_due_to_null(&null_row));
        assert!(!expr.is_unknown_due_to_null(&row));
        assert!(!expr.evaluate(&null_row).unwrap());

        // NULL literal is UNKNOWN even for a non-NULL column.
        let mut null_literal = ComparisonExpr::eq("name", Value::null(DataType::Text));
        null_literal.prepare_for_schema(&schema);
        assert!(null_literal.is_unknown_due_to_null(&row));
        assert!(!null_literal.evaluate(&row).unwrap());

        // Null tests always give a definite answer.
        let mut is_null =
            ComparisonExpr::new("name", Operator::IsNull, Value::null(DataType::Text));
        is_null.prepare_for_schema(&schema);
        assert!(!is_null.is_unknown_due_to_null(&null_row));
    }
}