use std::fmt::{Display, Formatter};
use fnv::FnvHashSet;
use serde::{Deserialize, Serialize};
use crate::expr::accessor::{StructAccessor, StructAccessorRef};
use crate::expr::{
BinaryExpression, Bind, Predicate, PredicateOperator, SetExpression, UnaryExpression,
};
use crate::spec::{Datum, NestedField, NestedFieldRef, SchemaRef};
use crate::{Error, ErrorKind};
/// Alias for [`Reference`], the unbound term type used when building predicates.
pub type Term = Reference;
/// A reference to a column by name that has not yet been resolved against a schema.
///
/// Use [`Bind::bind`] to resolve it into a [`BoundReference`].
// `Eq` is derived because the only field is a `String`, so equality is total.
// This matches `BoundReference`, which already derives `Eq`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Reference {
    // Column name exactly as supplied to `Reference::new`.
    name: String,
}
impl Reference {
pub fn new(name: impl Into<String>) -> Self {
Self { name: name.into() }
}
pub fn name(&self) -> &str {
&self.name
}
}
impl Reference {
pub fn less_than(self, datum: Datum) -> Predicate {
Predicate::Binary(BinaryExpression::new(
PredicateOperator::LessThan,
self,
datum,
))
}
pub fn less_than_or_equal_to(self, datum: Datum) -> Predicate {
Predicate::Binary(BinaryExpression::new(
PredicateOperator::LessThanOrEq,
self,
datum,
))
}
pub fn greater_than(self, datum: Datum) -> Predicate {
Predicate::Binary(BinaryExpression::new(
PredicateOperator::GreaterThan,
self,
datum,
))
}
pub fn greater_than_or_equal_to(self, datum: Datum) -> Predicate {
Predicate::Binary(BinaryExpression::new(
PredicateOperator::GreaterThanOrEq,
self,
datum,
))
}
pub fn equal_to(self, datum: Datum) -> Predicate {
Predicate::Binary(BinaryExpression::new(PredicateOperator::Eq, self, datum))
}
pub fn not_equal_to(self, datum: Datum) -> Predicate {
Predicate::Binary(BinaryExpression::new(PredicateOperator::NotEq, self, datum))
}
pub fn starts_with(self, datum: Datum) -> Predicate {
Predicate::Binary(BinaryExpression::new(
PredicateOperator::StartsWith,
self,
datum,
))
}
pub fn not_starts_with(self, datum: Datum) -> Predicate {
Predicate::Binary(BinaryExpression::new(
PredicateOperator::NotStartsWith,
self,
datum,
))
}
pub fn is_nan(self) -> Predicate {
Predicate::Unary(UnaryExpression::new(PredicateOperator::IsNan, self))
}
pub fn is_not_nan(self) -> Predicate {
Predicate::Unary(UnaryExpression::new(PredicateOperator::NotNan, self))
}
pub fn is_null(self) -> Predicate {
Predicate::Unary(UnaryExpression::new(PredicateOperator::IsNull, self))
}
pub fn is_not_null(self) -> Predicate {
Predicate::Unary(UnaryExpression::new(PredicateOperator::NotNull, self))
}
pub fn is_in(self, literals: impl IntoIterator<Item = Datum>) -> Predicate {
Predicate::Set(SetExpression::new(
PredicateOperator::In,
self,
FnvHashSet::from_iter(literals),
))
}
pub fn is_not_in(self, literals: impl IntoIterator<Item = Datum>) -> Predicate {
Predicate::Set(SetExpression::new(
PredicateOperator::NotIn,
self,
FnvHashSet::from_iter(literals),
))
}
}
impl Display for Reference {
    /// Renders the reference as its bare column name.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.name)
    }
}
impl Bind for Reference {
type Bound = BoundReference;
fn bind(&self, schema: SchemaRef, case_sensitive: bool) -> crate::Result<Self::Bound> {
let field = if case_sensitive {
schema.field_by_name(&self.name)
} else {
schema.field_by_name_case_insensitive(&self.name)
};
let field = field.ok_or_else(|| {
Error::new(
ErrorKind::DataInvalid,
format!("Field {} not found in schema", self.name),
)
})?;
let accessor = schema.accessor_by_field_id(field.id).ok_or_else(|| {
Error::new(
ErrorKind::DataInvalid,
format!("Accessor for Field {} not found", self.name),
)
})?;
Ok(BoundReference::new(
self.name.clone(),
field.clone(),
accessor.clone(),
))
}
}
/// A column reference that has been resolved against a schema.
///
/// Typically produced by calling [`Bind::bind`] on a [`Reference`]. It holds the column name
/// as written, the matched schema field, and the accessor looked up for that field's id.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BoundReference {
    // Name as supplied by the caller; after a case-insensitive bind it may differ in case
    // from the schema field's own name.
    column_name: String,
    // Schema field this reference resolved to.
    field: NestedFieldRef,
    // Accessor for `field`, obtained from the schema by field id during binding.
    accessor: StructAccessorRef,
}
impl BoundReference {
pub fn new(
name: impl Into<String>,
field: NestedFieldRef,
accessor: StructAccessorRef,
) -> Self {
Self {
column_name: name.into(),
field,
accessor,
}
}
pub fn field(&self) -> &NestedField {
&self.field
}
pub fn accessor(&self) -> &StructAccessor {
&self.accessor
}
}
impl Display for BoundReference {
    /// Renders the bound reference using the column name it was created with.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.column_name.as_str())
    }
}
/// Alias for [`BoundReference`], the bound counterpart of [`Term`].
pub type BoundTerm = BoundReference;
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use crate::expr::accessor::StructAccessor;
    use crate::expr::{Bind, BoundReference, Reference};
    use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type};

    /// Schema with three top-level columns: optional `foo` (string, id 1),
    /// required `bar` (int, id 2) and optional `baz` (boolean, id 3).
    fn table_schema_simple() -> SchemaRef {
        Arc::new(
            Schema::builder()
                .with_schema_id(1)
                .with_identifier_field_ids(vec![2])
                .with_fields(vec![
                    NestedField::optional(1, "foo", Type::Primitive(PrimitiveType::String)).into(),
                    NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(),
                    NestedField::optional(3, "baz", Type::Primitive(PrimitiveType::Boolean)).into(),
                ])
                .build()
                .unwrap(),
        )
    }

    #[test]
    fn test_reference_name_and_display() {
        let reference = Reference::new("foo");
        assert_eq!(reference.name(), "foo");
        assert_eq!(reference.to_string(), "foo");
    }

    #[test]
    fn test_bind_reference() {
        let schema = table_schema_simple();
        let reference = Reference::new("bar").bind(schema, true).unwrap();
        let accessor_ref = Arc::new(StructAccessor::new(1, PrimitiveType::Int));
        let expected_ref = BoundReference::new(
            "bar",
            NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(),
            accessor_ref,
        );
        assert_eq!(expected_ref, reference);
        assert_eq!(reference.field().id, 2);
        assert_eq!(reference.to_string(), "bar");
    }

    #[test]
    fn test_bind_reference_case_insensitive() {
        let schema = table_schema_simple();
        let reference = Reference::new("BAR").bind(schema, false).unwrap();
        let accessor_ref = Arc::new(StructAccessor::new(1, PrimitiveType::Int));
        let expected_ref = BoundReference::new(
            "BAR",
            NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(),
            accessor_ref,
        );
        assert_eq!(expected_ref, reference);
        // The bound reference keeps the caller's spelling, not the schema's.
        assert_eq!(reference.to_string(), "BAR");
    }

    #[test]
    fn test_bind_reference_case_sensitive_rejects_other_case() {
        let schema = table_schema_simple();
        let result = Reference::new("BAR").bind(schema, true);
        assert!(result.is_err());
    }

    #[test]
    fn test_bind_reference_failure() {
        let schema = table_schema_simple();
        let result = Reference::new("bar_not_exist").bind(schema, true);
        assert!(result.is_err());
    }

    #[test]
    fn test_bind_reference_case_insensitive_failure() {
        let schema = table_schema_simple();
        let result = Reference::new("bar_not_exist").bind(schema, false);
        assert!(result.is_err());
    }
}