use fraiseql_error::{FraiseQLError, Result};
use serde::{Deserialize, Serialize};
use crate::utils::to_snake_case;
/// Filter expression tree for a query's where clause.
///
/// Leaves are `Field` / `NativeField` comparisons; `And`, `Or` and `Not`
/// combine other clauses. `WhereClause::from_graphql_json` in this file builds
/// the `Field`, `And`, `Or` and `Not` variants; it never produces
/// `NativeField`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum WhereClause {
/// Comparison of one field, addressed by a path, against a JSON value.
Field {
/// Path segments identifying the field, outermost first. The JSON parser
/// in this file passes each segment through `to_snake_case`.
path: Vec<String>,
/// Operator applied between the field and `value`.
operator: WhereOperator,
/// Operand of the comparison, kept as raw JSON.
value: serde_json::Value,
},
/// Group of sub-clauses combined with AND. Also used by the JSON parser to
/// wrap the conditions of an object that yields zero or several conditions.
And(Vec<WhereClause>),
/// Group of sub-clauses combined with OR.
Or(Vec<WhereClause>),
/// Negation of the wrapped clause.
Not(Box<WhereClause>),
/// Comparison against a column named directly instead of through a path.
NativeField {
/// Column name; this is the value reported by `native_column_names`.
column: String,
/// NOTE(review): presumably the PostgreSQL type the operand is cast to in
/// generated SQL — it is not read anywhere in this file; confirm against
/// the SQL generator.
pg_cast: String,
/// Operator applied between the column and `value`.
operator: WhereOperator,
/// Operand of the comparison, kept as raw JSON.
value: serde_json::Value,
},
}
impl WhereClause {
    /// Reports whether this clause is an `And` / `Or` group without children.
    ///
    /// Only the group's own child list is inspected; `Not`, `Field` and
    /// `NativeField` are never considered empty.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        match self {
            Self::Field { .. } | Self::NativeField { .. } | Self::Not(_) => false,
            Self::And(children) | Self::Or(children) => children.is_empty(),
        }
    }

    /// Returns the `column` of every `NativeField` in this tree.
    ///
    /// Names appear in depth-first, left-to-right order and are not
    /// de-duplicated.
    #[must_use]
    pub fn native_column_names(&self) -> Vec<&str> {
        let mut found = Vec::new();
        self.collect_native_column_names(&mut found);
        found
    }

    /// Recursive worker for [`Self::native_column_names`]: appends to `out`.
    fn collect_native_column_names<'a>(&'a self, out: &mut Vec<&'a str>) {
        match self {
            Self::Field { .. } => {},
            Self::NativeField { column, .. } => out.push(column.as_str()),
            Self::Not(negated) => negated.collect_native_column_names(out),
            Self::And(children) | Self::Or(children) => {
                for child in children.iter() {
                    child.collect_native_column_names(out);
                }
            },
        }
    }

    /// Builds a clause from the JSON form of a GraphQL `where` argument.
    ///
    /// The value must be an object. The keys `_and` / `_or` take an array of
    /// nested where objects, `_not` takes a single nested where object, and
    /// any other key is a field name mapping to `{operator: value}` pairs.
    /// An entry whose key is not a known operator but whose value is an
    /// object is parsed as a nested field beneath the current one.
    ///
    /// An object producing exactly one condition returns that condition
    /// directly; otherwise the conditions are wrapped in `And`.
    ///
    /// # Errors
    ///
    /// Returns a validation error when a where clause or field value is not
    /// an object, when `_and` / `_or` is not an array, or when an operator
    /// name is unknown and its value is not an object.
    pub fn from_graphql_json(value: &serde_json::Value) -> Result<Self> {
        Self::parse_where_object(value, &[])
    }

    /// Parses one where object; `path_prefix` holds the (already snake_cased)
    /// names of the enclosing fields.
    fn parse_where_object(value: &serde_json::Value, path_prefix: &[String]) -> Result<Self> {
        let entries = value.as_object().ok_or_else(|| FraiseQLError::Validation {
            message: "where clause must be a JSON object".to_string(),
            path: None,
        })?;

        // Shared by `_and` and `_or`: the value must be an array whose items
        // are each parsed as a where object under the same prefix.
        let parse_group = |raw: &serde_json::Value, not_array: &str| -> Result<Vec<Self>> {
            let Some(items) = raw.as_array() else {
                return Err(FraiseQLError::Validation {
                    message: not_array.to_string(),
                    path: None,
                });
            };
            items
                .iter()
                .map(|item| Self::parse_where_object(item, path_prefix))
                .collect()
        };

        let mut parsed: Vec<Self> = Vec::new();
        for (key, val) in entries {
            match key.as_str() {
                "_and" => parsed.push(Self::And(parse_group(val, "_and must be an array")?)),
                "_or" => parsed.push(Self::Or(parse_group(val, "_or must be an array")?)),
                "_not" => {
                    let negated = Self::parse_where_object(val, path_prefix)?;
                    parsed.push(Self::Not(Box::new(negated)));
                },
                field_name => {
                    let Some(operators) = val.as_object() else {
                        return Err(FraiseQLError::Validation {
                            message: format!(
                                "where field '{field_name}' must be an object of {{operator: value}}"
                            ),
                            path: None,
                        });
                    };

                    let field_path: Vec<String> = path_prefix
                        .iter()
                        .cloned()
                        .chain(std::iter::once(to_snake_case(field_name)))
                        .collect();

                    for (op_name, op_value) in operators {
                        let clause = match WhereOperator::from_str(op_name) {
                            Ok(operator) => Self::Field {
                                path: field_path.clone(),
                                operator,
                                value: op_value.clone(),
                            },
                            // Not an operator, but an object: treat the key as
                            // a nested field name under the current field.
                            Err(_) if op_value.is_object() => {
                                let wrapped = serde_json::json!({ op_name: op_value });
                                Self::parse_where_object(&wrapped, &field_path)?
                            },
                            Err(unknown) => return Err(unknown),
                        };
                        parsed.push(clause);
                    }
                },
            }
        }

        // A lone condition is returned as-is instead of a one-element `And`.
        if parsed.len() != 1 {
            return Ok(Self::And(parsed));
        }
        Ok(parsed.pop().expect("checked len == 1"))
    }
}
/// Operator used in a `WhereClause` or `HavingClause` comparison.
///
/// The spellings noted on each variant are the ones `WhereOperator::from_str`
/// accepts. Notes such as "string operator" refer to the predicate methods
/// defined on this type (`expects_array`, `is_string_operator`,
/// `is_case_insensitive`).
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum WhereOperator {
/// Parsed from `"eq"`.
Eq,
/// Parsed from `"neq"`.
Neq,
/// Parsed from `"gt"`.
Gt,
/// Parsed from `"gte"`.
Gte,
/// Parsed from `"lt"`.
Lt,
/// Parsed from `"lte"`.
Lte,
/// Parsed from `"in"`. Reported by `expects_array`.
In,
/// Parsed from `"nin"` or `"notin"`. Reported by `expects_array`.
Nin,
/// Parsed from `"contains"`. String operator.
Contains,
/// Parsed from `"icontains"`. String operator, case-insensitive.
Icontains,
/// Parsed from `"startswith"`. String operator.
Startswith,
/// Parsed from `"istartswith"`. String operator, case-insensitive.
Istartswith,
/// Parsed from `"endswith"`. String operator.
Endswith,
/// Parsed from `"iendswith"`. String operator, case-insensitive.
Iendswith,
/// Parsed from `"like"`. String operator.
Like,
/// Parsed from `"ilike"`. String operator, case-insensitive.
Ilike,
/// Parsed from `"nlike"`. String operator.
Nlike,
/// Parsed from `"nilike"`. String operator, case-insensitive.
Nilike,
/// Parsed from `"regex"`. String operator.
Regex,
/// Parsed from `"iregex"` or `"imatches"`. String operator, case-insensitive.
Iregex,
/// Parsed from `"nregex"` or `"not_matches"`. String operator.
Nregex,
/// Parsed from `"niregex"`. String operator, case-insensitive.
Niregex,
/// Parsed from `"isnull"`.
IsNull,
/// Parsed from `"array_contains"`.
ArrayContains,
/// Parsed from `"array_contained_by"`.
ArrayContainedBy,
/// Parsed from `"array_overlaps"`.
ArrayOverlaps,
/// Parsed from `"len_eq"`.
LenEq,
/// Parsed from `"len_gt"`.
LenGt,
/// Parsed from `"len_lt"`.
LenLt,
/// Parsed from `"len_gte"`.
LenGte,
/// Parsed from `"len_lte"`.
LenLte,
/// Parsed from `"len_neq"`.
LenNeq,
/// Parsed from `"cosine_distance"`.
CosineDistance,
/// Parsed from `"l2_distance"`.
L2Distance,
/// Parsed from `"l1_distance"`.
L1Distance,
/// Parsed from `"hamming_distance"`.
HammingDistance,
/// Parsed from `"inner_product"`.
InnerProduct,
/// Parsed from `"jaccard_distance"`.
JaccardDistance,
/// Parsed from `"matches"`.
Matches,
/// Parsed from `"plain_query"`.
PlainQuery,
/// Parsed from `"phrase_query"`.
PhraseQuery,
/// Parsed from `"websearch_query"`.
WebsearchQuery,
/// Parsed from `"is_ipv4"`.
IsIPv4,
/// Parsed from `"is_ipv6"`.
IsIPv6,
/// Parsed from `"is_private"`.
IsPrivate,
/// Parsed from `"is_loopback"`.
IsLoopback,
/// Parsed from `"is_multicast"`.
IsMulticast,
/// Parsed from `"is_link_local"`.
IsLinkLocal,
/// Parsed from `"is_documentation"`.
IsDocumentation,
/// Parsed from `"is_carrier_grade"`.
IsCarrierGrade,
/// Parsed from `"in_subnet"` or `"inrange"`.
InSubnet,
/// Parsed from `"contains_subnet"`.
ContainsSubnet,
/// Parsed from `"contains_ip"`.
ContainsIP,
/// Parsed from `"overlaps"`.
Overlaps,
/// Parsed from `"strictly_contains"`.
StrictlyContains,
/// Parsed from `"ancestor_of"`.
AncestorOf,
/// Parsed from `"descendant_of"`.
DescendantOf,
/// Parsed from `"matches_lquery"`.
MatchesLquery,
/// Parsed from `"matches_ltxtquery"`.
MatchesLtxtquery,
/// Parsed from `"matches_any_lquery"`.
MatchesAnyLquery,
/// Parsed from `"depth_eq"`.
DepthEq,
/// Parsed from `"depth_neq"`.
DepthNeq,
/// Parsed from `"depth_gt"`.
DepthGt,
/// Parsed from `"depth_gte"`.
DepthGte,
/// Parsed from `"depth_lt"`.
DepthLt,
/// Parsed from `"depth_lte"`.
DepthLte,
/// Parsed from `"lca"`.
Lca,
/// Parsed from `"descendant_of_id"`.
DescendantOfId,
/// Parsed from `"ancestor_of_id"`.
AncestorOfId,
/// Operator defined by `crate::filters::ExtendedOperator`.
///
/// Marked `#[serde(skip)]`, so serde never serializes or deserializes this
/// variant. It has no string spelling in `match_exact`, so `from_str` never
/// returns it.
#[serde(skip)]
Extended(crate::filters::ExtendedOperator),
}
impl WhereOperator {
    /// Parses an operator from its textual name.
    ///
    /// The name is first looked up verbatim. If that fails and the name has
    /// no underscore but at least one uppercase character, it is converted
    /// with `to_snake_case` and looked up once more, so camelCase spellings
    /// can resolve to their snake_case operator.
    ///
    /// # Errors
    ///
    /// Returns a validation error with the message
    /// `Unknown WHERE operator: <name>` when neither lookup succeeds.
    // An inherent `from_str` returning the crate's `Result` is used here
    // instead of the `std::str::FromStr` trait, hence the lint allowance.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Result<Self> {
        let resolved = Self::match_exact(s).or_else(|| {
            let camel_like = !s.contains('_') && s.chars().any(char::is_uppercase);
            camel_like
                .then(|| crate::utils::to_snake_case(s))
                .and_then(|snake| Self::match_exact(&snake))
        });
        resolved.ok_or_else(|| FraiseQLError::validation(format!("Unknown WHERE operator: {s}")))
    }

    /// Looks up `s` in the table of accepted spellings, with no case or
    /// format normalization. Returns `None` for anything not listed.
    fn match_exact(s: &str) -> Option<Self> {
        let operator = match s {
            "eq" => Self::Eq,
            "neq" => Self::Neq,
            "gt" => Self::Gt,
            "gte" => Self::Gte,
            "lt" => Self::Lt,
            "lte" => Self::Lte,
            "in" => Self::In,
            "nin" | "notin" => Self::Nin,
            "contains" => Self::Contains,
            "icontains" => Self::Icontains,
            "startswith" => Self::Startswith,
            "istartswith" => Self::Istartswith,
            "endswith" => Self::Endswith,
            "iendswith" => Self::Iendswith,
            "like" => Self::Like,
            "ilike" => Self::Ilike,
            "nlike" => Self::Nlike,
            "nilike" => Self::Nilike,
            "regex" => Self::Regex,
            "iregex" | "imatches" => Self::Iregex,
            "nregex" | "not_matches" => Self::Nregex,
            "niregex" => Self::Niregex,
            "isnull" => Self::IsNull,
            "array_contains" => Self::ArrayContains,
            "array_contained_by" => Self::ArrayContainedBy,
            "array_overlaps" => Self::ArrayOverlaps,
            "len_eq" => Self::LenEq,
            "len_gt" => Self::LenGt,
            "len_lt" => Self::LenLt,
            "len_gte" => Self::LenGte,
            "len_lte" => Self::LenLte,
            "len_neq" => Self::LenNeq,
            "cosine_distance" => Self::CosineDistance,
            "l2_distance" => Self::L2Distance,
            "l1_distance" => Self::L1Distance,
            "hamming_distance" => Self::HammingDistance,
            "inner_product" => Self::InnerProduct,
            "jaccard_distance" => Self::JaccardDistance,
            "matches" => Self::Matches,
            "plain_query" => Self::PlainQuery,
            "phrase_query" => Self::PhraseQuery,
            "websearch_query" => Self::WebsearchQuery,
            "is_ipv4" => Self::IsIPv4,
            "is_ipv6" => Self::IsIPv6,
            "is_private" => Self::IsPrivate,
            "is_loopback" => Self::IsLoopback,
            "is_multicast" => Self::IsMulticast,
            "is_link_local" => Self::IsLinkLocal,
            "is_documentation" => Self::IsDocumentation,
            "is_carrier_grade" => Self::IsCarrierGrade,
            "in_subnet" | "inrange" => Self::InSubnet,
            "contains_subnet" => Self::ContainsSubnet,
            "contains_ip" => Self::ContainsIP,
            "overlaps" => Self::Overlaps,
            "strictly_contains" => Self::StrictlyContains,
            "ancestor_of" => Self::AncestorOf,
            "descendant_of" => Self::DescendantOf,
            "matches_lquery" => Self::MatchesLquery,
            "matches_ltxtquery" => Self::MatchesLtxtquery,
            "matches_any_lquery" => Self::MatchesAnyLquery,
            "depth_eq" => Self::DepthEq,
            "depth_neq" => Self::DepthNeq,
            "depth_gt" => Self::DepthGt,
            "depth_gte" => Self::DepthGte,
            "depth_lt" => Self::DepthLt,
            "depth_lte" => Self::DepthLte,
            "lca" => Self::Lca,
            "descendant_of_id" => Self::DescendantOfId,
            "ancestor_of_id" => Self::AncestorOfId,
            _ => return None,
        };
        Some(operator)
    }

    /// Returns `true` for `In` and `Nin`, `false` for every other operator.
    #[must_use]
    pub const fn expects_array(&self) -> bool {
        match self {
            Self::In | Self::Nin => true,
            _ => false,
        }
    }

    /// Returns `true` for the operators classified as case-insensitive:
    /// `Icontains`, `Istartswith`, `Iendswith`, `Ilike`, `Nilike`, `Iregex`
    /// and `Niregex`.
    #[must_use]
    pub const fn is_case_insensitive(&self) -> bool {
        match self {
            Self::Icontains
            | Self::Istartswith
            | Self::Iendswith
            | Self::Ilike
            | Self::Nilike
            | Self::Iregex
            | Self::Niregex => true,
            _ => false,
        }
    }

    /// Returns `true` for the operators classified as string operators: the
    /// contains / startswith / endswith / like / regex families, including
    /// their `I…`, `N…` and `Ni…` variants.
    #[must_use]
    pub const fn is_string_operator(&self) -> bool {
        match self {
            Self::Contains
            | Self::Icontains
            | Self::Startswith
            | Self::Istartswith
            | Self::Endswith
            | Self::Iendswith
            | Self::Like
            | Self::Ilike
            | Self::Nlike
            | Self::Nilike
            | Self::Regex
            | Self::Iregex
            | Self::Nregex
            | Self::Niregex => true,
            _ => false,
        }
    }
}
/// Filter expression tree whose leaves compare an aggregate against a JSON
/// value; `And`, `Or` and `Not` combine other having clauses.
///
/// NOTE(review): presumably rendered as a SQL `HAVING` condition — nothing in
/// this file consumes it; confirm against the SQL generator.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum HavingClause {
/// Comparison of one aggregate against a JSON value.
Aggregate {
/// Identifies the aggregate being compared.
/// NOTE(review): the exact format (alias vs. expression) is not visible
/// here — verify against the caller.
aggregate: String,
/// Operator applied between the aggregate and `value`.
operator: WhereOperator,
/// Operand of the comparison, kept as raw JSON.
value: serde_json::Value,
},
/// Group of sub-clauses combined with AND.
And(Vec<HavingClause>),
/// Group of sub-clauses combined with OR.
Or(Vec<HavingClause>),
/// Negation of the wrapped clause.
Not(Box<HavingClause>),
}
impl HavingClause {
    /// Reports whether this clause is an `And` / `Or` group without children.
    ///
    /// Only the group's own child list is inspected; `Not` and `Aggregate`
    /// are never considered empty.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        match self {
            Self::Aggregate { .. } | Self::Not(_) => false,
            Self::And(children) | Self::Or(children) => children.is_empty(),
        }
    }
}
#[cfg(test)]
mod tests;