use crate::DbValue;
use crate::QueryId;
use crate::graph_search::SearchControl;
/// Logical operator stored in the `logic` field of a [`QueryCondition`].
///
/// NOTE(review): how the operator combines a condition with the other
/// conditions of a query is decided by the condition evaluator, which is not
/// part of this file. The variant notes below are inferred from the names
/// only; confirm against the evaluator.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "derive", derive(agdb::DbSerialize))]
#[cfg_attr(feature = "api", derive(agdb::TypeDefImpl))]
pub enum QueryConditionLogic {
/// Presumably logical AND.
And,
/// Presumably logical OR.
Or,
}
/// Modifier stored in the `modifier` field of a [`QueryCondition`].
///
/// NOTE(review): the effect of each modifier on condition evaluation and on
/// the graph search is implemented outside this file. The variant notes below
/// are inferred from the names only; confirm against the evaluator.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "derive", derive(agdb::DbSerialize))]
#[cfg_attr(feature = "api", derive(agdb::TypeDefImpl))]
pub enum QueryConditionModifier {
/// Presumably "no modifier": the condition is used as written.
None,
/// Presumably affects whether the search continues past an element - TODO confirm.
Beyond,
/// Presumably negates the condition - TODO confirm.
Not,
/// Presumably the negated form of `Beyond` - TODO confirm.
NotBeyond,
}
/// Payload of a [`QueryCondition`]: the kind of condition together with the
/// data it needs.
///
/// NOTE(review): the evaluation of each variant happens outside this file.
/// The payload types below are taken from the declarations; the described
/// intent is inferred from the variant names and should be confirmed.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "derive", derive(agdb::DbSerialize))]
#[cfg_attr(feature = "api", derive(agdb::TypeDefImpl))]
pub enum QueryConditionData {
/// Carries a [`CountComparison`]; presumably evaluated against the search
/// distance through `CountComparison::compare_distance` - TODO confirm.
Distance(CountComparison),
/// Unit variant; presumably matches only edges - TODO confirm.
Edge,
/// Carries a [`CountComparison`]; presumably tested against an element's
/// total edge count - TODO confirm.
EdgeCount(CountComparison),
/// Carries a [`CountComparison`]; presumably tested against the number of
/// outgoing edges - TODO confirm.
EdgeCountFrom(CountComparison),
/// Carries a [`CountComparison`]; presumably tested against the number of
/// incoming edges - TODO confirm.
EdgeCountTo(CountComparison),
/// Carries a list of `QueryId`s; presumably the element must be one of
/// them - TODO confirm.
Ids(Vec<QueryId>),
/// Carries a [`KeyValueComparison`] (a key plus a [`Comparison`]).
KeyValue(KeyValueComparison),
/// Carries a list of `DbValue` keys; presumably the element must have
/// them - TODO confirm.
Keys(Vec<DbValue>),
/// Unit variant; presumably matches only nodes - TODO confirm.
Node,
/// Carries a nested list of [`QueryCondition`]s, which makes this type
/// recursive.
Where(Vec<QueryCondition>),
}
/// A single query condition: a logic operator, a modifier and the condition
/// data itself.
///
/// The type is recursive: `data` may be `QueryConditionData::Where`, which
/// holds further `QueryCondition` values.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "derive", derive(agdb::DbSerialize))]
#[cfg_attr(feature = "api", derive(agdb::TypeDefImpl))]
pub struct QueryCondition {
/// Logical operator attached to this condition.
pub logic: QueryConditionLogic,
/// Modifier attached to this condition.
pub modifier: QueryConditionModifier,
/// Kind of condition and its payload.
///
/// NOTE(review): `schema(no_recursion)` is presumably required because of
/// the recursion through `QueryConditionData::Where` - confirm against the
/// utoipa documentation.
#[cfg_attr(feature = "openapi", schema(no_recursion))]
pub data: QueryConditionData,
}
/// Comparison of an unsigned count (`u64`) against the value stored in the
/// variant.
///
/// `CountComparison::compare(count)` returns `true` exactly when `count`
/// relates to the stored value as the variant name says; the stored value is
/// always the right-hand operand. A bare `u64` converts into `Equal` through
/// the `From<u64>` implementation.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "derive", derive(agdb::DbSerialize))]
#[cfg_attr(feature = "api", derive(agdb::TypeDefImpl))]
pub enum CountComparison {
/// `count == value`
Equal(u64),
/// `count > value`
GreaterThan(u64),
/// `count >= value`
GreaterThanOrEqual(u64),
/// `count < value`
LessThan(u64),
/// `count <= value`
LessThanOrEqual(u64),
/// `count != value`
NotEqual(u64),
}
/// Comparison of a `DbValue` against the value stored in the variant.
///
/// `Comparison::compare(left)` treats the stored value as the right-hand
/// operand. Any type convertible into `DbValue` converts into `Equal` through
/// the blanket `From` implementation.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "derive", derive(agdb::DbSerialize))]
#[cfg_attr(feature = "api", derive(agdb::TypeDefImpl))]
pub enum Comparison {
/// `left == value`
Equal(DbValue),
/// `left > value`
GreaterThan(DbValue),
/// `left >= value`
GreaterThanOrEqual(DbValue),
/// `left < value`
LessThan(DbValue),
/// `left <= value`
LessThanOrEqual(DbValue),
/// `left != value`
NotEqual(DbValue),
/// `left` contains the value. Supported pairs: a string containing a
/// substring (or every string of a list), and a vector containing an
/// element (or every element of a vector) of the same element type. Any
/// other pairing compares as `false`.
Contains(DbValue),
/// `left` starts with the value. Supported pairs: a string with a string
/// prefix (a list of strings is concatenated first), and a vector with a
/// leading element or leading sub-vector of the same element type. Any
/// other pairing compares as `false`.
StartsWith(DbValue),
/// `left` ends with the value. Supported pairs mirror `StartsWith`, applied
/// to the end of the string or vector. Any other pairing compares as
/// `false`.
EndsWith(DbValue),
}
/// A key paired with a [`Comparison`]; this is the payload of
/// `QueryConditionData::KeyValue`.
///
/// NOTE(review): presumably the comparison is applied to the value stored
/// under `key` on the inspected element - the evaluation is not in this file,
/// confirm there.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "derive", derive(agdb::DbSerialize))]
#[cfg_attr(feature = "api", derive(agdb::TypeDefImpl))]
pub struct KeyValueComparison {
/// Key the comparison refers to.
pub key: DbValue,
/// Comparison (operator and right-hand operand) for that key.
pub value: Comparison,
}
impl CountComparison {
    /// Tests the distance `right` against this comparison and returns the
    /// outcome as a [`SearchControl`].
    ///
    /// The boolean carried by the result is `true` exactly when `right`
    /// satisfies the comparison (the same answer `compare(right)` gives).
    /// `Stop` is returned in the cases where no larger distance could satisfy
    /// the comparison any more:
    ///
    /// * `Equal`: once `right` has reached or passed the bound,
    /// * `LessThan`: once `right` has reached the bound,
    /// * `LessThanOrEqual`: once `right` has passed the bound.
    ///
    /// Every other case returns `Continue`.
    pub(crate) fn compare_distance(&self, right: u64) -> SearchControl {
        use std::cmp::Ordering;

        // Every variant wraps exactly one bound, so it can be pulled out up
        // front and the decision expressed as a (variant, ordering) table.
        let (CountComparison::Equal(bound)
        | CountComparison::GreaterThan(bound)
        | CountComparison::GreaterThanOrEqual(bound)
        | CountComparison::LessThan(bound)
        | CountComparison::LessThanOrEqual(bound)
        | CountComparison::NotEqual(bound)) = self;

        match (self, right.cmp(bound)) {
            (CountComparison::Equal(_), Ordering::Less) => SearchControl::Continue(false),
            (CountComparison::Equal(_), Ordering::Equal) => SearchControl::Stop(true),
            (CountComparison::Equal(_), Ordering::Greater) => SearchControl::Stop(false),

            (CountComparison::GreaterThan(_), Ordering::Greater) => SearchControl::Continue(true),
            (CountComparison::GreaterThan(_), _) => SearchControl::Continue(false),

            (CountComparison::GreaterThanOrEqual(_), Ordering::Less) => {
                SearchControl::Continue(false)
            }
            (CountComparison::GreaterThanOrEqual(_), _) => SearchControl::Continue(true),

            (CountComparison::LessThan(_), Ordering::Less) => SearchControl::Continue(true),
            (CountComparison::LessThan(_), _) => SearchControl::Stop(false),

            (CountComparison::LessThanOrEqual(_), Ordering::Greater) => SearchControl::Stop(false),
            (CountComparison::LessThanOrEqual(_), _) => SearchControl::Continue(true),

            (CountComparison::NotEqual(_), Ordering::Equal) => SearchControl::Continue(false),
            (CountComparison::NotEqual(_), _) => SearchControl::Continue(true),
        }
    }

    /// Returns `true` when `left` relates to the stored value as the variant
    /// name says; the stored value is the right-hand operand.
    pub(crate) fn compare(&self, left: u64) -> bool {
        match *self {
            CountComparison::Equal(bound) => left == bound,
            CountComparison::GreaterThan(bound) => left > bound,
            CountComparison::GreaterThanOrEqual(bound) => left >= bound,
            CountComparison::LessThan(bound) => left < bound,
            CountComparison::LessThanOrEqual(bound) => left <= bound,
            CountComparison::NotEqual(bound) => left != bound,
        }
    }
}
impl Comparison {
pub(crate) fn compare(&self, left: &DbValue) -> bool {
match self {
Comparison::Equal(right) => left == right,
Comparison::GreaterThan(right) => left > right,
Comparison::GreaterThanOrEqual(right) => left >= right,
Comparison::LessThan(right) => left < right,
Comparison::LessThanOrEqual(right) => left <= right,
Comparison::NotEqual(right) => left != right,
Comparison::Contains(right) => match (left, right) {
(DbValue::String(left), DbValue::String(right)) => left.contains(right),
(DbValue::String(left), DbValue::VecString(right)) => {
right.iter().all(|x| left.contains(x))
}
(DbValue::VecI64(left), DbValue::I64(right)) => left.contains(right),
(DbValue::VecI64(left), DbValue::VecI64(right)) => {
right.iter().all(|x| left.contains(x))
}
(DbValue::VecU64(left), DbValue::U64(right)) => left.contains(right),
(DbValue::VecU64(left), DbValue::VecU64(right)) => {
right.iter().all(|x| left.contains(x))
}
(DbValue::VecF64(left), DbValue::F64(right)) => left.contains(right),
(DbValue::VecF64(left), DbValue::VecF64(right)) => {
right.iter().all(|x| left.contains(x))
}
(DbValue::VecString(left), DbValue::String(right)) => left.contains(right),
(DbValue::VecString(left), DbValue::VecString(right)) => {
right.iter().all(|x| left.contains(x))
}
_ => false,
},
Comparison::StartsWith(right) => match (left, right) {
(DbValue::String(left), DbValue::String(right)) => left.starts_with(right),
(DbValue::String(left), DbValue::VecString(right)) => {
left.starts_with(&right.concat())
}
(DbValue::VecI64(left), DbValue::I64(right)) => left.starts_with(&[*right]),
(DbValue::VecI64(left), DbValue::VecI64(right)) => left.starts_with(right),
(DbValue::VecU64(left), DbValue::U64(right)) => left.starts_with(&[*right]),
(DbValue::VecU64(left), DbValue::VecU64(right)) => left.starts_with(right),
(DbValue::VecF64(left), DbValue::F64(right)) => left.starts_with(&[*right]),
(DbValue::VecF64(left), DbValue::VecF64(right)) => left.starts_with(right),
(DbValue::VecString(left), DbValue::String(right)) => left.first() == Some(right),
(DbValue::VecString(left), DbValue::VecString(right)) => left.starts_with(right),
_ => false,
},
Comparison::EndsWith(right) => match (left, right) {
(DbValue::String(left), DbValue::String(right)) => left.ends_with(right),
(DbValue::String(left), DbValue::VecString(right)) => {
left.ends_with(&right.concat())
}
(DbValue::VecI64(left), DbValue::I64(right)) => left.ends_with(&[*right]),
(DbValue::VecI64(left), DbValue::VecI64(right)) => left.ends_with(right),
(DbValue::VecU64(left), DbValue::U64(right)) => left.ends_with(&[*right]),
(DbValue::VecU64(left), DbValue::VecU64(right)) => left.ends_with(right),
(DbValue::VecF64(left), DbValue::F64(right)) => left.ends_with(&[*right]),
(DbValue::VecF64(left), DbValue::VecF64(right)) => left.ends_with(right),
(DbValue::VecString(left), DbValue::String(right)) => left.last() == Some(right),
(DbValue::VecString(left), DbValue::VecString(right)) => left.ends_with(right),
_ => false,
},
}
}
pub(crate) fn value(&self) -> &DbValue {
match self {
Comparison::Equal(value)
| Comparison::GreaterThan(value)
| Comparison::GreaterThanOrEqual(value)
| Comparison::LessThan(value)
| Comparison::LessThanOrEqual(value)
| Comparison::NotEqual(value)
| Comparison::Contains(value)
| Comparison::StartsWith(value)
| Comparison::EndsWith(value) => value,
}
}
}
impl From<u64> for CountComparison {
fn from(value: u64) -> Self {
CountComparison::Equal(value)
}
}
/// Anything convertible into a `DbValue` converts into an equality comparison
/// against that value.
impl<T> From<T> for Comparison
where
    T: Into<DbValue>,
{
    fn from(value: T) -> Self {
        Self::Equal(value.into())
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the condition types: derived traits, the count and value
    // comparisons, and the `From` conversions.
    use super::*;

    #[test]
    fn derived_from_debug() {
        let _ = format!(
            "{:?}",
            QueryCondition {
                logic: QueryConditionLogic::And,
                modifier: QueryConditionModifier::None,
                data: QueryConditionData::Edge,
            }
        );
        let _ = format!("{:?}", Comparison::Equal(DbValue::I64(0)));
        let _ = format!("{:?}", CountComparison::Equal(0));
    }

    // The clones are the point of this test, so the lint is silenced on purpose.
    #[test]
    #[allow(clippy::redundant_clone)]
    fn derived_from_clone() {
        let left = QueryCondition {
            logic: QueryConditionLogic::And,
            modifier: QueryConditionModifier::None,
            data: QueryConditionData::Edge,
        };
        let right = left.clone();
        assert_eq!(left, right);

        let left = Comparison::Equal(DbValue::I64(0));
        let right = left.clone();
        assert_eq!(left, right);

        let left = CountComparison::Equal(0);
        let right = left.clone();
        assert_eq!(left, right);
    }

    #[test]
    fn derived_from_partial_eq() {
        assert_eq!(
            QueryCondition {
                logic: QueryConditionLogic::And,
                modifier: QueryConditionModifier::None,
                data: QueryConditionData::Edge,
            },
            QueryCondition {
                logic: QueryConditionLogic::And,
                modifier: QueryConditionModifier::None,
                data: QueryConditionData::Edge,
            }
        );
        assert_eq!(
            Comparison::Equal(DbValue::I64(0)),
            Comparison::Equal(DbValue::I64(0))
        );
        assert_eq!(CountComparison::Equal(0), CountComparison::Equal(0));
    }

    // Full decision table of `compare_distance`: every variant is checked
    // above, at and below the bound.
    #[test]
    fn count_comparison() {
        use CountComparison::Equal;
        use CountComparison::GreaterThan;
        use CountComparison::GreaterThanOrEqual;
        use CountComparison::LessThan;
        use CountComparison::LessThanOrEqual;
        use CountComparison::NotEqual;
        use SearchControl::Continue;
        use SearchControl::Stop;

        assert_eq!(Equal(2).compare_distance(3), Stop(false));
        assert_eq!(Equal(2).compare_distance(2), Stop(true));
        assert_eq!(Equal(2).compare_distance(1), Continue(false));
        assert_eq!(NotEqual(2).compare_distance(3), Continue(true));
        assert_eq!(NotEqual(2).compare_distance(2), Continue(false));
        assert_eq!(NotEqual(2).compare_distance(1), Continue(true));
        assert_eq!(GreaterThan(2).compare_distance(3), Continue(true));
        assert_eq!(GreaterThan(2).compare_distance(2), Continue(false));
        assert_eq!(GreaterThan(2).compare_distance(1), Continue(false));
        assert_eq!(GreaterThanOrEqual(2).compare_distance(3), Continue(true));
        assert_eq!(GreaterThanOrEqual(2).compare_distance(2), Continue(true));
        assert_eq!(GreaterThanOrEqual(2).compare_distance(1), Continue(false));
        assert_eq!(LessThan(2).compare_distance(3), Stop(false));
        assert_eq!(LessThan(2).compare_distance(2), Stop(false));
        assert_eq!(LessThan(2).compare_distance(1), Continue(true));
        assert_eq!(LessThanOrEqual(2).compare_distance(3), Stop(false));
        assert_eq!(LessThanOrEqual(2).compare_distance(2), Continue(true));
        assert_eq!(LessThanOrEqual(2).compare_distance(1), Continue(true));
    }

    // Plain boolean count comparison: one matching and one non-matching input
    // per variant.
    #[test]
    fn count_compare() {
        assert!(CountComparison::Equal(2).compare(2));
        assert!(!CountComparison::Equal(2).compare(3));
        assert!(CountComparison::NotEqual(2).compare(3));
        assert!(!CountComparison::NotEqual(2).compare(2));
        assert!(CountComparison::GreaterThan(2).compare(3));
        assert!(!CountComparison::GreaterThan(2).compare(2));
        assert!(CountComparison::GreaterThanOrEqual(2).compare(2));
        assert!(!CountComparison::GreaterThanOrEqual(2).compare(1));
        assert!(CountComparison::LessThan(2).compare(1));
        assert!(!CountComparison::LessThan(2).compare(2));
        assert!(CountComparison::LessThanOrEqual(2).compare(2));
        assert!(!CountComparison::LessThanOrEqual(2).compare(3));
    }

    // Both `From` conversions produce the `Equal` variant.
    #[test]
    fn from_conversions() {
        assert_eq!(CountComparison::from(3_u64), CountComparison::Equal(3));
        assert_eq!(
            Comparison::from(DbValue::I64(3)),
            Comparison::Equal(DbValue::I64(3))
        );
    }

    #[test]
    fn contains() {
        let condition = Comparison::Contains("abc".into());
        assert!(condition.compare(&"0abc123".into()));
        assert!(!condition.compare(&"0bc123".into()));

        let condition = Comparison::Contains(vec!["ab".to_string(), "23".to_string()].into());
        assert!(condition.compare(&"0abc123".into()));
        assert!(!condition.compare(&"0abc12".into()));

        assert!(Comparison::Contains(1.into()).compare(&vec![2, 1, 3].into()));
        assert!(!Comparison::Contains(4.into()).compare(&vec![2, 1, 3].into()));

        let condition = Comparison::Contains(vec![2, 3].into());
        assert!(condition.compare(&vec![2, 3].into()));
        assert!(!condition.compare(&vec![1, 3].into()));

        let condition = Comparison::Contains(1_u64.into());
        assert!(condition.compare(&vec![2_u64, 1_u64, 3_u64].into()));
        assert!(!condition.compare(&vec![2_u64, 3_u64].into()));

        let condition = Comparison::Contains(vec![2_u64, 3_u64].into());
        assert!(condition.compare(&vec![2_u64, 1_u64, 3_u64].into()));
        assert!(!condition.compare(&vec![1_u64, 3_u64].into()));

        let condition = Comparison::Contains(1.1.into());
        assert!(condition.compare(&vec![2.1, 1.1, 3.3].into()));
        assert!(!condition.compare(&vec![2.1, 3.3].into()));

        let condition = Comparison::Contains(vec![2.2, 3.3].into());
        assert!(condition.compare(&vec![2.2, 1.1, 3.3].into()));
        assert!(!condition.compare(&vec![1.1, 3.3].into()));

        let condition = Comparison::Contains("abc".into());
        assert!(condition.compare(&vec!["abc".to_string(), "123".to_string()].into()));
        assert!(!condition.compare(&vec!["0".to_string(), "123".to_string()].into()));

        let condition = Comparison::Contains(vec!["abc".to_string(), "123".to_string()].into());
        assert!(condition.compare(&vec!["abc".to_string(), "123".to_string()].into()));
        assert!(!condition.compare(&vec!["123".to_string()].into()));

        // Mismatched operand kinds never match.
        assert!(!Comparison::Contains("abc".into()).compare(&1.into()));
    }

    // `value` must hand back the operand for every variant, including the
    // `StartsWith` and `EndsWith` variants.
    #[test]
    fn value() {
        assert_eq!(Comparison::Equal(DbValue::I64(0)).value(), &DbValue::I64(0));
        assert_eq!(
            Comparison::GreaterThan(DbValue::I64(0)).value(),
            &DbValue::I64(0)
        );
        assert_eq!(
            Comparison::GreaterThanOrEqual(DbValue::I64(0)).value(),
            &DbValue::I64(0)
        );
        assert_eq!(
            Comparison::LessThan(DbValue::I64(0)).value(),
            &DbValue::I64(0)
        );
        assert_eq!(
            Comparison::LessThanOrEqual(DbValue::I64(0)).value(),
            &DbValue::I64(0)
        );
        assert_eq!(
            Comparison::NotEqual(DbValue::I64(0)).value(),
            &DbValue::I64(0)
        );
        assert_eq!(
            Comparison::Contains(DbValue::I64(0)).value(),
            &DbValue::I64(0)
        );
        assert_eq!(
            Comparison::StartsWith(DbValue::I64(0)).value(),
            &DbValue::I64(0)
        );
        assert_eq!(
            Comparison::EndsWith(DbValue::I64(0)).value(),
            &DbValue::I64(0)
        );
    }

    #[test]
    fn starts_with() {
        let condition = Comparison::StartsWith("a".into());
        assert!(condition.compare(&"abc".into()));
        assert!(!condition.compare(&"bca".into()));

        let condition = Comparison::StartsWith(vec!["ab".to_string(), "23".to_string()].into());
        assert!(condition.compare(&"ab23".into()));
        assert!(!condition.compare(&"ab2".into()));

        assert!(Comparison::StartsWith(1.into()).compare(&vec![1, 2, 3].into()));
        assert!(!Comparison::StartsWith(1.into()).compare(&vec![2, 1, 3].into()));

        let condition = Comparison::StartsWith(vec![2, 3].into());
        assert!(condition.compare(&vec![2, 3].into()));
        assert!(!condition.compare(&vec![1, 2, 3].into()));

        let condition = Comparison::StartsWith(1_u64.into());
        assert!(condition.compare(&vec![1_u64, 2_u64, 3_u64].into()));
        assert!(!condition.compare(&vec![2_u64, 1_u64].into()));

        let condition = Comparison::StartsWith(vec![2_u64, 3_u64].into());
        assert!(condition.compare(&vec![2_u64, 3_u64, 1_u64].into()));
        assert!(!condition.compare(&vec![1_u64, 2_u64, 3_u64].into()));

        let condition = Comparison::StartsWith(1.1.into());
        assert!(condition.compare(&vec![1.1, 2.1, 3.3].into()));
        assert!(!condition.compare(&vec![2.1, 3.3, 1.1].into()));

        let condition = Comparison::StartsWith(vec![2.2, 3.3].into());
        assert!(condition.compare(&vec![2.2, 3.3, 3.3].into()));
        assert!(!condition.compare(&vec![1.1, 3.3, 2.2].into()));

        let condition = Comparison::StartsWith("abc".into());
        assert!(condition.compare(&vec!["abc".to_string(), "123".to_string()].into()));
        assert!(!condition.compare(&vec!["0".to_string(), "abc".to_string()].into()));

        let condition = Comparison::StartsWith(vec!["abc".to_string(), "123".to_string()].into());
        assert!(condition.compare(&vec!["abc".to_string(), "123".to_string()].into()));
        assert!(!condition.compare(&vec!["123".to_string(), "abc".to_string()].into()));

        // Mismatched operand kinds never match.
        assert!(!Comparison::StartsWith("abc".into()).compare(&1.into()));
    }

    #[test]
    fn ends_with() {
        let condition = Comparison::EndsWith("a".into());
        assert!(condition.compare(&"bca".into()));
        assert!(!condition.compare(&"abc".into()));

        let condition = Comparison::EndsWith(vec!["ab".to_string(), "23".to_string()].into());
        assert!(condition.compare(&"ffeeggab23".into()));
        assert!(!condition.compare(&"ab23ff".into()));

        assert!(Comparison::EndsWith(1.into()).compare(&vec![1, 2, 1].into()));
        assert!(!Comparison::EndsWith(1.into()).compare(&vec![1, 1, 3].into()));

        let condition = Comparison::EndsWith(vec![2, 3].into());
        assert!(condition.compare(&vec![4, 5, 2, 3].into()));
        assert!(!condition.compare(&vec![1, 2, 3, 4, 5].into()));

        let condition = Comparison::EndsWith(1_u64.into());
        assert!(condition.compare(&vec![1_u64, 2_u64, 1_u64].into()));
        assert!(!condition.compare(&vec![2_u64, 1_u64, 3_u64].into()));

        let condition = Comparison::EndsWith(vec![2_u64, 3_u64].into());
        assert!(condition.compare(&vec![1_u64, 2_u64, 3_u64].into()));
        assert!(!condition.compare(&vec![2_u64, 3_u64, 1_u64].into()));

        let condition = Comparison::EndsWith(1.1.into());
        assert!(condition.compare(&vec![2.1, 3.3, 1.1].into()));
        assert!(!condition.compare(&vec![2.1, 3.3, 1.1, 4.4].into()));

        let condition = Comparison::EndsWith(vec![2.2, 3.3].into());
        assert!(condition.compare(&vec![3.3, 4.4, 2.2, 3.3].into()));
        assert!(!condition.compare(&vec![1.1, 3.3, 2.2].into()));

        let condition = Comparison::EndsWith("abc".into());
        assert!(condition.compare(&vec!["123".to_string(), "abc".to_string()].into()));
        assert!(!condition.compare(&vec!["0".to_string(), "abcdef".to_string()].into()));

        let condition = Comparison::EndsWith(vec!["abc".to_string(), "123".to_string()].into());
        assert!(condition.compare(&vec!["abc".to_string(), "123".to_string()].into()));
        assert!(!condition.compare(&vec!["123".to_string(), "abc".to_string()].into()));

        // Mismatched operand kinds never match.
        assert!(!Comparison::EndsWith("abc".into()).compare(&1.into()));
    }
}