use std::{collections::HashSet, sync::Arc};
use fraiseql_error::{FraiseQLError, Result};
use super::counter::ParamCounter;
use crate::{
dialect::SqlDialect,
where_clause::{WhereClause, WhereOperator},
};
/// Escapes the characters that are special inside a SQL `LIKE` pattern.
///
/// Every `\`, `%` and `_` in `s` is prefixed with a backslash so the text
/// matches literally; all other characters are copied through unchanged.
pub(crate) fn escape_like_literal(s: &str) -> String {
    let mut literal = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, '\\' | '%' | '_') {
            literal.push('\\');
        }
        literal.push(ch);
    }
    literal
}
/// Upper bound, in bytes, on the length of a regex pattern accepted by
/// [`validate_regex_pattern`].
const MAX_REGEX_PATTERN_LEN: usize = 1_000;

/// Screens a regex pattern for obvious catastrophic-backtracking shapes.
///
/// This is a byte-level heuristic, not a regex parser. It rejects patterns
/// longer than [`MAX_REGEX_PATTERN_LEN`] bytes and patterns in which a group
/// that directly contains a quantifier (`+`, `*`, `?`, `{`) is itself followed
/// by a quantifier, e.g. `(a+)+` or `(a*)*`.
///
/// Details of the scan:
/// * A backslash escapes exactly the next byte, so in `\\(` the parenthesis
///   is a real group opener while in `\(` it is a literal.
/// * A `?` immediately after an unescaped `(` is a group prefix
///   (`(?:`, `(?=`, ...) rather than a quantifier and is ignored.
/// * Quantifiers seen in a nested group are not propagated to the enclosing
///   group, and character classes are not treated specially.
///
/// # Errors
///
/// Returns `FraiseQLError::Validation` when the pattern is too long or a
/// nested quantifier is detected.
fn validate_regex_pattern(pattern: &str) -> Result<()> {
    if pattern.len() > MAX_REGEX_PATTERN_LEN {
        return Err(FraiseQLError::Validation {
            message: format!(
                "Regex pattern exceeds maximum length of {MAX_REGEX_PATTERN_LEN} bytes"
            ),
            path: None,
        });
    }

    let bytes = pattern.as_bytes();
    // One entry per currently open group; `true` once a quantifier has been
    // seen directly inside that group. The stack itself is the nesting depth,
    // so an unmatched `)` cannot push a separate counter out of sync.
    let mut open_groups: Vec<bool> = Vec::new();
    // True when the previous byte was an unescaped backslash.
    let mut escaped = false;
    // True when the previous byte was an unescaped `(`.
    let mut after_group_open = false;

    for (i, &b) in bytes.iter().enumerate() {
        if escaped {
            // This byte is a literal (or the tail of an escape sequence).
            escaped = false;
            after_group_open = false;
            continue;
        }
        let follows_open = after_group_open;
        after_group_open = false;

        match b {
            b'\\' => escaped = true,
            b'(' => {
                open_groups.push(false);
                after_group_open = true;
            },
            b')' => {
                let had_quantifier = open_groups.pop().unwrap_or(false);
                if had_quantifier {
                    let next = bytes.get(i + 1).copied();
                    if matches!(next, Some(b'+' | b'*' | b'?' | b'{')) {
                        return Err(FraiseQLError::Validation {
                            message: "Regex pattern contains nested quantifiers (potential \
                                      ReDoS). Simplify the pattern to avoid `(…+)+`, \
                                      `(…*)*`, or similar constructs."
                                .to_string(),
                            path: None,
                        });
                    }
                }
            },
            // `(?` introduces a group modifier, not a quantifier.
            b'?' if follows_open => {},
            b'+' | b'*' | b'?' | b'{' => {
                if let Some(flag) = open_groups.last_mut() {
                    *flag = true;
                }
            },
            _ => {},
        }
    }
    Ok(())
}
/// Generates parameterised SQL `WHERE` fragments from a `WhereClause` tree,
/// delegating all dialect-specific syntax to `D`.
pub struct GenericWhereGenerator<D: SqlDialect> {
/// Supplies dialect-specific SQL fragments (placeholders, casts, operators).
dialect: D,
/// Hands out placeholder indexes; reset at the start of each `generate*`
/// call and advanced through `&self`.
counter: ParamCounter,
/// Optional set of column names; a field path whose segments joined with
/// `__` appears here is referenced as that column instead of being extracted
/// from the `data` JSON column.
indexed_columns: Option<Arc<HashSet<String>>>,
}
impl<D: SqlDialect> GenericWhereGenerator<D> {
/// Creates a generator for `dialect` with a fresh parameter counter and no
/// indexed columns configured.
pub const fn new(dialect: D) -> Self {
Self {
dialect,
counter: ParamCounter::new(),
indexed_columns: None,
}
}
#[must_use]
pub fn with_indexed_columns(mut self, cols: Arc<HashSet<String>>) -> Self {
self.indexed_columns = Some(cols);
self
}
/// Generates the SQL fragment and bound values for `clause`, using a
/// placeholder offset of zero.
///
/// # Errors
///
/// Propagates any error returned by `generate_with_param_offset`.
pub fn generate(&self, clause: &WhereClause) -> Result<(String, Vec<serde_json::Value>)> {
self.generate_with_param_offset(clause, 0)
}
/// Generates SQL for `clause` with a hierarchy context available to the
/// id-based ancestor/descendant operators.
///
/// The parameter counter is reset to zero before rendering. On success the
/// SQL fragment is returned together with the bound values in placeholder
/// order.
///
/// # Errors
///
/// Propagates any error produced while visiting the clause tree.
pub fn generate_with_hierarchy(
    &self,
    clause: &WhereClause,
    hierarchy_ctx: &super::HierarchyContext,
) -> Result<(String, Vec<serde_json::Value>)> {
    self.counter.reset_to(0);
    let mut bound_values: Vec<serde_json::Value> = Vec::new();
    let rendered = self.visit_impl(clause, &mut bound_values, Some(hierarchy_ctx));
    rendered.map(|sql| (sql, bound_values))
}
/// Generates SQL for `clause` after resetting the parameter counter to
/// `offset`.
///
/// On success the SQL fragment is returned together with the bound values in
/// placeholder order. No hierarchy context is supplied.
///
/// # Errors
///
/// Propagates any error produced while visiting the clause tree.
pub fn generate_with_param_offset(
    &self,
    clause: &WhereClause,
    offset: usize,
) -> Result<(String, Vec<serde_json::Value>)> {
    self.counter.reset_to(offset);
    let mut bound_values: Vec<serde_json::Value> = Vec::new();
    let rendered = self.visit(clause, &mut bound_values);
    rendered.map(|sql| (sql, bound_values))
}
/// Renders `clause` without a hierarchy context; bound values are appended
/// to `params`.
fn visit(&self, clause: &WhereClause, params: &mut Vec<serde_json::Value>) -> Result<String> {
self.visit_impl(clause, params, None)
}
/// Recursively renders a `WhereClause` tree into SQL.
///
/// * `And` / `Or` join their children with ` AND ` / ` OR ` inside
///   parentheses; an empty list yields the dialect's always-true /
///   always-false expression respectively.
/// * `Not` wraps the rendered child in `NOT (...)`.
/// * `Field` and `NativeField` are delegated to `visit_field` and
///   `visit_native_field`.
///
/// Children are rendered in order and the first error aborts the walk.
fn visit_impl(
    &self,
    clause: &WhereClause,
    params: &mut Vec<serde_json::Value>,
    hierarchy_ctx: Option<&super::HierarchyContext>,
) -> Result<String> {
    match clause {
        WhereClause::And(children) if children.is_empty() => {
            Ok(self.dialect.always_true().to_string())
        },
        WhereClause::And(children) => {
            let mut rendered = Vec::with_capacity(children.len());
            for child in children.iter() {
                rendered.push(self.visit_impl(child, params, hierarchy_ctx)?);
            }
            Ok(format!("({})", rendered.join(" AND ")))
        },
        WhereClause::Or(children) if children.is_empty() => {
            Ok(self.dialect.always_false().to_string())
        },
        WhereClause::Or(children) => {
            let mut rendered = Vec::with_capacity(children.len());
            for child in children.iter() {
                rendered.push(self.visit_impl(child, params, hierarchy_ctx)?);
            }
            Ok(format!("({})", rendered.join(" OR ")))
        },
        WhereClause::Not(inner) => {
            let negated = self.visit_impl(inner, params, hierarchy_ctx)?;
            Ok(format!("NOT ({negated})"))
        },
        WhereClause::Field {
            path,
            operator,
            value,
        } => self.visit_field(path, operator, value, params, hierarchy_ctx),
        WhereClause::NativeField {
            column,
            pg_cast,
            operator,
            value,
        } => self.visit_native_field(column, pg_cast, operator, value, params),
    }
}
/// Renders a condition on a native (non-JSON) column.
///
/// The column name is quoted by the dialect and `value` is bound as a
/// parameter; when `pg_cast` is non-empty the placeholder is wrapped with the
/// dialect's native cast. Only `Eq` and `Neq` are supported.
///
/// # Errors
///
/// Returns a validation error for any other operator. The value has already
/// been appended to `params` by the time the operator is checked.
fn visit_native_field(
    &self,
    column: &str,
    pg_cast: &str,
    operator: &WhereOperator,
    value: &serde_json::Value,
    params: &mut Vec<serde_json::Value>,
) -> Result<String> {
    let col_expr = self.dialect.quote_identifier(column);
    let placeholder = self.push_param(params, value.clone());
    let rhs = match pg_cast {
        "" => placeholder,
        cast => self.dialect.cast_native_param(&placeholder, cast),
    };
    let comparator = match operator {
        WhereOperator::Eq => "=",
        WhereOperator::Neq => self.dialect.neq_operator(),
        _ => {
            return Err(FraiseQLError::validation(format!(
                "Operator {operator:?} is not supported for native column conditions"
            )));
        },
    };
    Ok(format!("{col_expr} {comparator} {rhs}"))
}
/// Chooses the SQL expression used to read the field at `path`.
///
/// When indexed columns are configured and the path segments joined with
/// `__` name one of them, the quoted column is used. Otherwise the value is
/// extracted from the `data` column via the dialect's JSON scalar extraction.
fn resolve_field_expr(&self, path: &[String]) -> String {
    let indexed_hit = self
        .indexed_columns
        .as_ref()
        .map(|columns| (columns, path.join("__")))
        .filter(|(columns, name)| columns.contains(name));
    match indexed_hit {
        Some((_, name)) => self.dialect.quote_identifier(&name),
        None => self.dialect.json_extract_scalar("data", path),
    }
}
/// Appends `v` to `params` and returns the dialect's placeholder text for the
/// next index handed out by the counter.
fn push_param(&self, params: &mut Vec<serde_json::Value>, v: serde_json::Value) -> String {
    params.push(v);
    let slot = self.counter.next();
    self.dialect.placeholder(slot)
}
/// Renders one field condition (`path`, `operator`, `value`) as a SQL fragment.
///
/// `path` is turned into a column or JSON-extraction expression by
/// `resolve_field_expr`. Values are bound through `push_param`, which appends
/// them to `params` and returns the placeholder text spliced into the SQL.
///
/// # Errors
///
/// Returns a validation error when `value` has the wrong shape for the
/// operator (non-array for the list operators, non-string for the substring
/// operators), when a string regex pattern fails `validate_regex_pattern`,
/// when `DescendantOfId`/`AncestorOfId` is used without `hierarchy_ctx`, when
/// a dialect helper reports an error, or when the operator reaches the final
/// catch-all arm.
fn visit_field(
&self,
path: &[String],
operator: &WhereOperator,
value: &serde_json::Value,
params: &mut Vec<serde_json::Value>,
hierarchy_ctx: Option<&super::HierarchyContext>,
) -> Result<String> {
let field_expr = self.resolve_field_expr(path);
match operator {
// Equality: numeric and boolean values compare against a cast of the
// field expression; any other value compares the raw expression.
WhereOperator::Eq => {
let p = self.push_param(params, value.clone());
if value.is_number() {
let cast = self.dialect.cast_to_numeric(&field_expr);
let rhs = self.dialect.cast_param_numeric(&p);
Ok(format!("{cast} = {rhs}"))
} else if value.is_boolean() {
let cast = self.dialect.cast_to_boolean(&field_expr);
Ok(format!("{cast} = {p}"))
} else {
Ok(format!("{field_expr} = {p}"))
}
},
// Inequality mirrors `Eq`, using the dialect's not-equal operator.
WhereOperator::Neq => {
let p = self.push_param(params, value.clone());
let neq = self.dialect.neq_operator();
if value.is_number() {
let cast = self.dialect.cast_to_numeric(&field_expr);
let rhs = self.dialect.cast_param_numeric(&p);
Ok(format!("{cast} {neq} {rhs}"))
} else if value.is_boolean() {
let cast = self.dialect.cast_to_boolean(&field_expr);
Ok(format!("{cast} {neq} {p}"))
} else {
Ok(format!("{field_expr} {neq} {p}"))
}
},
// Ordering comparisons always apply the numeric casts to both sides,
// regardless of the JSON type of `value`.
WhereOperator::Gt | WhereOperator::Gte | WhereOperator::Lt | WhereOperator::Lte => {
let op = match operator {
WhereOperator::Gt => ">",
WhereOperator::Gte => ">=",
WhereOperator::Lt => "<",
_ => "<=",
};
let cast = self.dialect.cast_to_numeric(&field_expr);
let p = self.push_param(params, value.clone());
let rhs = self.dialect.cast_param_numeric(&p);
Ok(format!("{cast} {op} {rhs}"))
},
// List membership. Note the error text says "IN" for both `In` and `Nin`.
WhereOperator::In | WhereOperator::Nin => {
let arr = value.as_array().ok_or_else(|| {
FraiseQLError::validation("IN operator requires an array value".to_string())
})?;
// An empty list binds nothing: `In` can never match, `Nin` always does.
if arr.is_empty() {
return Ok(if matches!(operator, WhereOperator::In) {
self.dialect.always_false().to_string()
} else {
self.dialect.always_true().to_string()
});
}
// Each element gets its own placeholder.
let placeholders: Vec<_> =
arr.iter().map(|v| self.push_param(params, v.clone())).collect();
let in_list = placeholders.join(", ");
let sql = format!("{field_expr} IN ({in_list})");
Ok(if matches!(operator, WhereOperator::Nin) {
format!("NOT ({sql})")
} else {
sql
})
},
// Binds no parameter. A non-boolean `value` is treated as `true` (IS NULL).
WhereOperator::IsNull => {
let is_null = value.as_bool().unwrap_or(true);
let null_op = if is_null { "IS NULL" } else { "IS NOT NULL" };
Ok(format!("{field_expr} {null_op}"))
},
// Substring operators: LIKE metacharacters in the caller's text are escaped
// before binding, and the `%` wildcards are added on the SQL side by
// concatenating around the placeholder.
WhereOperator::Contains => {
let val_str = self.require_str(value, "Contains")?;
let escaped = escape_like_literal(val_str);
let p = self.push_param(params, serde_json::Value::String(escaped));
let pattern = self.dialect.concat_sql(&["'%'", &p, "'%'"]);
Ok(self.dialect.like_sql(&field_expr, &pattern))
},
WhereOperator::Icontains => {
let val_str = self.require_str(value, "Icontains")?;
let escaped = escape_like_literal(val_str);
let p = self.push_param(params, serde_json::Value::String(escaped));
let pattern = self.dialect.concat_sql(&["'%'", &p, "'%'"]);
Ok(self.dialect.ilike_sql(&field_expr, &pattern))
},
WhereOperator::Startswith => {
let val_str = self.require_str(value, "Startswith")?;
let escaped = escape_like_literal(val_str);
let p = self.push_param(params, serde_json::Value::String(escaped));
let pattern = self.dialect.concat_sql(&[&p, "'%'"]);
Ok(self.dialect.like_sql(&field_expr, &pattern))
},
WhereOperator::Istartswith => {
let val_str = self.require_str(value, "Istartswith")?;
let escaped = escape_like_literal(val_str);
let p = self.push_param(params, serde_json::Value::String(escaped));
let pattern = self.dialect.concat_sql(&[&p, "'%'"]);
Ok(self.dialect.ilike_sql(&field_expr, &pattern))
},
WhereOperator::Endswith => {
let val_str = self.require_str(value, "Endswith")?;
let escaped = escape_like_literal(val_str);
let p = self.push_param(params, serde_json::Value::String(escaped));
let pattern = self.dialect.concat_sql(&["'%'", &p]);
Ok(self.dialect.like_sql(&field_expr, &pattern))
},
WhereOperator::Iendswith => {
let val_str = self.require_str(value, "Iendswith")?;
let escaped = escape_like_literal(val_str);
let p = self.push_param(params, serde_json::Value::String(escaped));
let pattern = self.dialect.concat_sql(&["'%'", &p]);
Ok(self.dialect.ilike_sql(&field_expr, &pattern))
},
// Raw LIKE family: the caller's pattern is bound as given, with no escaping
// and no check that it is a string.
WhereOperator::Like => {
let p = self.push_param(params, value.clone());
Ok(self.dialect.like_sql(&field_expr, &p))
},
WhereOperator::Ilike => {
let p = self.push_param(params, value.clone());
Ok(self.dialect.ilike_sql(&field_expr, &p))
},
WhereOperator::Nlike => {
let p = self.push_param(params, value.clone());
Ok(format!("NOT ({})", self.dialect.like_sql(&field_expr, &p)))
},
WhereOperator::Nilike => {
let p = self.push_param(params, value.clone());
Ok(format!("NOT ({})", self.dialect.ilike_sql(&field_expr, &p)))
},
// Regex family: string patterns are screened by `validate_regex_pattern`
// first; a non-string value skips the screen and is bound unchecked.
// NOTE(review): the two booleans passed to `regex_sql` look like
// (case_insensitive, negate) judging by the operator names — confirm against
// the `SqlDialect::regex_sql` signature.
WhereOperator::Regex => {
if let Some(s) = value.as_str() {
validate_regex_pattern(s)?;
}
let p = self.push_param(params, value.clone());
self.dialect
.regex_sql(&field_expr, &p, false, false)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::Iregex => {
if let Some(s) = value.as_str() {
validate_regex_pattern(s)?;
}
let p = self.push_param(params, value.clone());
self.dialect
.regex_sql(&field_expr, &p, true, false)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::Nregex => {
if let Some(s) = value.as_str() {
validate_regex_pattern(s)?;
}
let p = self.push_param(params, value.clone());
self.dialect
.regex_sql(&field_expr, &p, false, true)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::Niregex => {
if let Some(s) = value.as_str() {
validate_regex_pattern(s)?;
}
let p = self.push_param(params, value.clone());
self.dialect
.regex_sql(&field_expr, &p, true, true)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
// Array-length comparisons: the dialect's JSON array length of the field is
// compared against the bound value.
WhereOperator::LenEq
| WhereOperator::LenNeq
| WhereOperator::LenGt
| WhereOperator::LenGte
| WhereOperator::LenLt
| WhereOperator::LenLte => {
let op = match operator {
WhereOperator::LenEq => "=",
WhereOperator::LenNeq => self.dialect.neq_operator(),
WhereOperator::LenGt => ">",
WhereOperator::LenGte => ">=",
WhereOperator::LenLt => "<",
_ => "<=",
};
let len_expr = self.dialect.json_array_length(&field_expr);
let p = self.push_param(params, value.clone());
Ok(format!("{len_expr} {op} {p}"))
},
// Array set operators. `ArrayContains` and `StrictlyContains` produce the
// same SQL here. Dialect errors are re-wrapped as validation errors.
WhereOperator::ArrayContains | WhereOperator::StrictlyContains => {
let p = self.push_param(params, value.clone());
self.dialect
.array_contains_sql(&field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::ArrayContainedBy => {
let p = self.push_param(params, value.clone());
self.dialect
.array_contained_by_sql(&field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::ArrayOverlaps => {
let p = self.push_param(params, value.clone());
self.dialect
.array_overlaps_sql(&field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
// Full-text search operators, each mapped to its own dialect helper.
WhereOperator::Matches => {
let p = self.push_param(params, value.clone());
self.dialect
.fts_matches_sql(&field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::PlainQuery => {
let p = self.push_param(params, value.clone());
self.dialect
.fts_plain_query_sql(&field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::PhraseQuery => {
let p = self.push_param(params, value.clone());
self.dialect
.fts_phrase_query_sql(&field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::WebsearchQuery => {
let p = self.push_param(params, value.clone());
self.dialect
.fts_websearch_query_sql(&field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
// Vector distance operators: the operator symbol is passed to the dialect,
// which builds the expression.
WhereOperator::CosineDistance => {
let p = self.push_param(params, value.clone());
self.dialect
.vector_distance_sql("<=>", &field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::L2Distance => {
let p = self.push_param(params, value.clone());
self.dialect
.vector_distance_sql("<->", &field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::L1Distance => {
let p = self.push_param(params, value.clone());
self.dialect
.vector_distance_sql("<+>", &field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::HammingDistance => {
let p = self.push_param(params, value.clone());
self.dialect
.vector_distance_sql("<~>", &field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::InnerProduct => {
let p = self.push_param(params, value.clone());
self.dialect
.vector_distance_sql("<#>", &field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::JaccardDistance => {
let p = self.push_param(params, value.clone());
self.dialect
.jaccard_distance_sql(&field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
// Network address classification checks. None of these bind a parameter.
// `IsIPv4` / `IsIPv6` ignore `value` entirely.
WhereOperator::IsIPv4 => self
.dialect
.inet_check_sql(&field_expr, "IsIPv4")
.map_err(|e| FraiseQLError::validation(e.to_string())),
WhereOperator::IsIPv6 => self
.dialect
.inet_check_sql(&field_expr, "IsIPv6")
.map_err(|e| FraiseQLError::validation(e.to_string())),
// For the remaining checks only an explicit boolean `false` selects the
// negated check name; `true` or any non-boolean value keeps the positive one.
WhereOperator::IsPrivate => {
let negate = value.as_bool().is_some_and(|v| !v);
let check_name = if negate { "IsPublic" } else { "IsPrivate" };
self.dialect
.inet_check_sql(&field_expr, check_name)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::IsLoopback => {
let negate = value.as_bool().is_some_and(|v| !v);
let check_name = if negate {
"IsNotLoopback"
} else {
"IsLoopback"
};
self.dialect
.inet_check_sql(&field_expr, check_name)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::IsMulticast => {
let negate = value.as_bool().is_some_and(|v| !v);
let check_name = if negate {
"IsNotMulticast"
} else {
"IsMulticast"
};
self.dialect
.inet_check_sql(&field_expr, check_name)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::IsLinkLocal => {
let negate = value.as_bool().is_some_and(|v| !v);
let check_name = if negate {
"IsNotLinkLocal"
} else {
"IsLinkLocal"
};
self.dialect
.inet_check_sql(&field_expr, check_name)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::IsDocumentation => {
let negate = value.as_bool().is_some_and(|v| !v);
let check_name = if negate {
"IsNotDocumentation"
} else {
"IsDocumentation"
};
self.dialect
.inet_check_sql(&field_expr, check_name)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::IsCarrierGrade => {
let negate = value.as_bool().is_some_and(|v| !v);
let check_name = if negate {
"IsNotCarrierGrade"
} else {
"IsCarrierGrade"
};
self.dialect
.inet_check_sql(&field_expr, check_name)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
// Binary network operators: the operator symbol and bound value are handed
// to the dialect. `ContainsSubnet` and `ContainsIP` share the `>>` operator.
WhereOperator::InSubnet => {
let p = self.push_param(params, value.clone());
self.dialect
.inet_binary_sql("<<", &field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::ContainsSubnet | WhereOperator::ContainsIP => {
let p = self.push_param(params, value.clone());
self.dialect
.inet_binary_sql(">>", &field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::Overlaps => {
let p = self.push_param(params, value.clone());
self.dialect
.inet_binary_sql("&&", &field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
// ltree binary operators: the last argument names the type the dialect
// helper receives for the right-hand side (`ltree`, `lquery`, `ltxtquery`).
WhereOperator::AncestorOf => {
let p = self.push_param(params, value.clone());
self.dialect
.ltree_binary_sql("@>", &field_expr, &p, "ltree")
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::DescendantOf => {
let p = self.push_param(params, value.clone());
self.dialect
.ltree_binary_sql("<@", &field_expr, &p, "ltree")
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::MatchesLquery => {
let p = self.push_param(params, value.clone());
self.dialect
.ltree_binary_sql("~", &field_expr, &p, "lquery")
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::MatchesLtxtquery => {
let p = self.push_param(params, value.clone());
self.dialect
.ltree_binary_sql("@", &field_expr, &p, "ltxtquery")
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
// Requires a non-empty array; each element is bound separately and its
// placeholder is suffixed with a `::lquery` cast before reaching the dialect.
WhereOperator::MatchesAnyLquery => {
let arr = value.as_array().ok_or_else(|| {
FraiseQLError::validation(
"matches_any_lquery operator requires an array value".to_string(),
)
})?;
if arr.is_empty() {
return Err(FraiseQLError::validation(
"matches_any_lquery requires at least one lquery".to_string(),
));
}
let placeholders: Vec<_> = arr
.iter()
.map(|v| format!("{}::lquery", self.push_param(params, v.clone())))
.collect();
self.dialect
.ltree_any_lquery_sql(&field_expr, &placeholders)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
// ltree depth comparisons: one arm per comparison symbol. `DepthNeq` passes
// a literal `!=` rather than the dialect's `neq_operator()`.
WhereOperator::DepthEq => {
let p = self.push_param(params, value.clone());
self.dialect
.ltree_depth_sql("=", &field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::DepthNeq => {
let p = self.push_param(params, value.clone());
self.dialect
.ltree_depth_sql("!=", &field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::DepthGt => {
let p = self.push_param(params, value.clone());
self.dialect
.ltree_depth_sql(">", &field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::DepthGte => {
let p = self.push_param(params, value.clone());
self.dialect
.ltree_depth_sql(">=", &field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::DepthLt => {
let p = self.push_param(params, value.clone());
self.dialect
.ltree_depth_sql("<", &field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
WhereOperator::DepthLte => {
let p = self.push_param(params, value.clone());
self.dialect
.ltree_depth_sql("<=", &field_expr, &p)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
// Requires a non-empty array of paths; each is bound separately and its
// placeholder is suffixed with a `::ltree` cast.
WhereOperator::Lca => {
let arr = value.as_array().ok_or_else(|| {
FraiseQLError::validation("lca operator requires an array value".to_string())
})?;
if arr.is_empty() {
return Err(FraiseQLError::validation(
"lca operator requires at least one path".to_string(),
));
}
let placeholders: Vec<_> = arr
.iter()
.map(|v| format!("{}::ltree", self.push_param(params, v.clone())))
.collect();
self.dialect
.ltree_lca_sql(&field_expr, &placeholders)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
// Id-based hierarchy operators need the table and column names carried by
// the hierarchy context; without it the condition is rejected before any
// parameter is bound.
WhereOperator::DescendantOfId | WhereOperator::AncestorOfId => {
let ctx = hierarchy_ctx.ok_or_else(|| {
FraiseQLError::validation(
"descendantOfId/ancestorOfId requires HierarchyContext — \
configure [hierarchies] in fraiseql.toml"
.to_string(),
)
})?;
let pg_op = if matches!(operator, WhereOperator::DescendantOfId) {
"<@"
} else {
"@>"
};
let p = self.push_param(params, value.clone());
self.dialect
.ltree_id_subquery_sql(
pg_op,
&field_expr,
&ctx.table,
&ctx.path_column,
ctx.fk_column.as_deref(),
&p,
)
.map_err(|e| FraiseQLError::validation(e.to_string()))
},
// Extended operators are rendered entirely by the dialect, which receives
// `params` directly rather than going through `push_param`.
// NOTE(review): `self.counter` is not advanced on this path — confirm the
// dialect keeps placeholder numbering consistent with earlier parameters.
WhereOperator::Extended(op) => {
self.dialect.generate_extended_sql(op, &field_expr, params)
},
// Catch-all for any operator not handled above.
// NOTE(review): the `allow` suggests this arm is unreachable with the
// current `WhereOperator` definition and is kept for future or
// feature-gated variants — confirm.
#[allow(unreachable_patterns)]
_ => Err(FraiseQLError::Validation {
message: format!(
"Operator {operator:?} is not supported by the {} dialect",
self.dialect.name()
),
path: None,
}),
}
}
/// Borrows `value` as a string slice.
///
/// # Errors
///
/// Returns a validation error naming `op` when `value` is not a JSON string.
fn require_str<'a>(&self, value: &'a serde_json::Value, op: &'static str) -> Result<&'a str> {
    match value.as_str() {
        Some(text) => Ok(text),
        None => Err(FraiseQLError::validation(format!(
            "{op} operator requires a string value"
        ))),
    }
}
}
/// A default generator wraps `D::default()` and is otherwise identical to one
/// built with `GenericWhereGenerator::new`.
impl<D: SqlDialect + Default> Default for GenericWhereGenerator<D> {
fn default() -> Self {
Self::new(D::default())
}
}
/// Extended operators are not rendered by the generator itself; the call is
/// forwarded unchanged to the wrapped dialect.
impl<D: SqlDialect> crate::filters::ExtendedOperatorHandler for GenericWhereGenerator<D> {
/// Forwards `operator`, `field_sql` and `params` to the dialect's
/// `generate_extended_sql` and returns its result.
fn generate_extended_sql(
&self,
operator: &crate::filters::ExtendedOperator,
field_sql: &str,
params: &mut Vec<serde_json::Value>,
) -> Result<String> {
self.dialect.generate_extended_sql(operator, field_sql, params)
}
}
#[cfg(test)]
mod tests;