use fraiseql_error::{FraiseQLError, Result};
use serde_json::Value;
use crate::{WhereClause, WhereOperator};
const MAX_SQL_VALUE_BYTES: usize = 65_536;
#[doc(hidden)]
pub struct WhereSqlGenerator;
impl WhereSqlGenerator {
pub fn to_sql(clause: &WhereClause) -> Result<String> {
match clause {
WhereClause::Field {
path,
operator,
value,
} => Self::generate_field_predicate(path, operator, value),
WhereClause::And(clauses) => {
if clauses.is_empty() {
return Ok("TRUE".to_string());
}
let parts: Result<Vec<_>> = clauses.iter().map(Self::to_sql).collect();
Ok(format!("({})", parts?.join(" AND ")))
},
WhereClause::Or(clauses) => {
if clauses.is_empty() {
return Ok("FALSE".to_string());
}
let parts: Result<Vec<_>> = clauses.iter().map(Self::to_sql).collect();
Ok(format!("({})", parts?.join(" OR ")))
},
WhereClause::Not(clause) => {
let inner = Self::to_sql(clause)?;
Ok(format!("NOT ({})", inner))
},
WhereClause::NativeField {
column,
operator,
value,
..
} => {
let escaped_col = Self::escape_sql_string(column)?;
let col_expr = format!("\"{escaped_col}\"");
let sql_op = Self::operator_to_sql(operator)?;
let val_sql = Self::value_to_sql(value, operator)?;
Ok(format!("{col_expr} {sql_op} {val_sql}"))
},
}
}
fn generate_field_predicate(
path: &[String],
operator: &WhereOperator,
value: &Value,
) -> Result<String> {
let json_path = Self::build_json_path(path)?;
let sql = if operator == &WhereOperator::IsNull {
let is_null = value.as_bool().unwrap_or(true);
if is_null {
format!("{json_path} IS NULL")
} else {
format!("{json_path} IS NOT NULL")
}
} else {
let sql_op = Self::operator_to_sql(operator)?;
let sql_value = Self::value_to_sql(value, operator)?;
format!("{json_path} {sql_op} {sql_value}")
};
Ok(sql)
}
fn build_json_path(path: &[String]) -> Result<String> {
if path.is_empty() {
return Ok("data".to_string());
}
if path.len() == 1 {
let escaped = Self::escape_sql_string(&path[0])?;
Ok(format!("data->>'{}'", escaped))
} else {
let nested = &path[..path.len() - 1];
let last = &path[path.len() - 1];
let escaped_nested: Vec<String> =
nested.iter().map(|n| Self::escape_sql_string(n)).collect::<Result<Vec<_>>>()?;
let nested_path = escaped_nested.join(",");
let escaped_last = Self::escape_sql_string(last)?;
Ok(format!("data#>'{{{}}}'->>'{}'", nested_path, escaped_last))
}
}
fn operator_to_sql(operator: &WhereOperator) -> Result<&'static str> {
Ok(match operator {
WhereOperator::Eq => "=",
WhereOperator::Neq => "!=",
WhereOperator::Gt => ">",
WhereOperator::Gte => ">=",
WhereOperator::Lt => "<",
WhereOperator::Lte => "<=",
WhereOperator::In => "= ANY",
WhereOperator::Nin => "!= ALL",
WhereOperator::Contains => "LIKE",
WhereOperator::Icontains => "ILIKE",
WhereOperator::Startswith => "LIKE",
WhereOperator::Istartswith => "ILIKE",
WhereOperator::Endswith => "LIKE",
WhereOperator::Iendswith => "ILIKE",
WhereOperator::Like => "LIKE",
WhereOperator::Ilike => "ILIKE",
WhereOperator::Nlike => "NOT LIKE",
WhereOperator::Nilike => "NOT ILIKE",
WhereOperator::Regex => "~",
WhereOperator::Iregex => "~*",
WhereOperator::Nregex => "!~",
WhereOperator::Niregex => "!~*",
WhereOperator::ArrayContains => "@>",
WhereOperator::ArrayContainedBy => "<@",
WhereOperator::ArrayOverlaps => "&&",
WhereOperator::IsNull => {
return Err(FraiseQLError::Internal {
message: "IsNull should be handled separately".to_string(),
source: None,
});
},
WhereOperator::LenEq
| WhereOperator::LenGt
| WhereOperator::LenLt
| WhereOperator::LenGte
| WhereOperator::LenLte
| WhereOperator::LenNeq => {
return Err(FraiseQLError::Internal {
message: format!(
"Array length operators not yet supported in fraiseql-wire: {operator:?}"
),
source: None,
});
},
WhereOperator::L2Distance
| WhereOperator::CosineDistance
| WhereOperator::L1Distance
| WhereOperator::HammingDistance
| WhereOperator::InnerProduct
| WhereOperator::JaccardDistance => {
return Err(FraiseQLError::Internal {
message: format!(
"Vector operations not supported in fraiseql-wire: {operator:?}"
),
source: None,
});
},
WhereOperator::Matches
| WhereOperator::PlainQuery
| WhereOperator::PhraseQuery
| WhereOperator::WebsearchQuery => {
return Err(FraiseQLError::Internal {
message: format!(
"Full-text search operators not yet supported in fraiseql-wire: {operator:?}"
),
source: None,
});
},
WhereOperator::IsIPv4
| WhereOperator::IsIPv6
| WhereOperator::IsPrivate
| WhereOperator::IsPublic
| WhereOperator::IsLoopback
| WhereOperator::InSubnet
| WhereOperator::ContainsSubnet
| WhereOperator::ContainsIP
| WhereOperator::Overlaps
| WhereOperator::StrictlyContains
| WhereOperator::AncestorOf
| WhereOperator::DescendantOf
| WhereOperator::MatchesLquery
| WhereOperator::MatchesLtxtquery
| WhereOperator::MatchesAnyLquery
| WhereOperator::DepthEq
| WhereOperator::DepthNeq
| WhereOperator::DepthGt
| WhereOperator::DepthGte
| WhereOperator::DepthLt
| WhereOperator::DepthLte
| WhereOperator::Lca
| WhereOperator::Extended(_) => {
return Err(FraiseQLError::Internal {
message: format!(
"Advanced operators not yet supported in fraiseql-wire: {operator:?}"
),
source: None,
});
},
})
}
fn value_to_sql(value: &Value, operator: &WhereOperator) -> Result<String> {
match (value, operator) {
(Value::Null, _) => Ok("NULL".to_string()),
(Value::Bool(b), _) => Ok(b.to_string()),
(Value::Number(n), _) => Ok(n.to_string()),
(Value::String(s), WhereOperator::Contains | WhereOperator::Icontains) => {
Ok(format!("'%{}%'", Self::escape_sql_string(s)?))
},
(Value::String(s), WhereOperator::Startswith | WhereOperator::Istartswith) => {
Ok(format!("'{}%'", Self::escape_sql_string(s)?))
},
(Value::String(s), WhereOperator::Endswith | WhereOperator::Iendswith) => {
Ok(format!("'%{}'", Self::escape_sql_string(s)?))
},
(Value::String(s), _) => Ok(format!("'{}'", Self::escape_sql_string(s)?)),
(Value::Array(arr), WhereOperator::In | WhereOperator::Nin) => {
let values: Result<Vec<_>> =
arr.iter().map(|v| Self::value_to_sql(v, &WhereOperator::Eq)).collect();
Ok(format!("ARRAY[{}]", values?.join(", ")))
},
(
Value::Array(_),
WhereOperator::ArrayContains
| WhereOperator::ArrayContainedBy
| WhereOperator::ArrayOverlaps,
) => {
let json_str =
serde_json::to_string(value).map_err(|e| FraiseQLError::Internal {
message: format!("Failed to serialize JSON for array operator: {e}"),
source: None,
})?;
if json_str.len() > MAX_SQL_VALUE_BYTES {
return Err(FraiseQLError::Validation {
message: format!(
"JSONB value exceeds maximum allowed size for SQL embedding \
({} bytes, limit is {} bytes)",
json_str.len(),
MAX_SQL_VALUE_BYTES
),
path: None,
});
}
let escaped = json_str.replace('\'', "''");
Ok(format!("'{}'::jsonb", escaped))
},
_ => Err(FraiseQLError::Internal {
message: format!(
"Unsupported value type for operator: {value:?} with {operator:?}"
),
source: None,
}),
}
}
fn escape_sql_string(s: &str) -> Result<String> {
if s.len() > MAX_SQL_VALUE_BYTES {
return Err(FraiseQLError::Validation {
message: format!(
"String value exceeds maximum allowed size for SQL embedding \
({} bytes, limit is {} bytes)",
s.len(),
MAX_SQL_VALUE_BYTES
),
path: None,
});
}
Ok(s.replace('\'', "''"))
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)] mod tests {
use serde_json::json;
use super::*;
#[test]
fn test_simple_equality() {
let clause = WhereClause::Field {
path: vec!["status".to_string()],
operator: WhereOperator::Eq,
value: json!("active"),
};
let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
assert_eq!(sql, "data->>'status' = 'active'");
}
#[test]
fn test_nested_path() {
let clause = WhereClause::Field {
path: vec!["user".to_string(), "email".to_string()],
operator: WhereOperator::Eq,
value: json!("test@example.com"),
};
let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
assert_eq!(sql, "data#>'{user}'->>'email' = 'test@example.com'");
}
#[test]
fn test_icontains() {
let clause = WhereClause::Field {
path: vec!["name".to_string()],
operator: WhereOperator::Icontains,
value: json!("john"),
};
let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
assert_eq!(sql, "data->>'name' ILIKE '%john%'");
}
#[test]
fn test_startswith() {
let clause = WhereClause::Field {
path: vec!["email".to_string()],
operator: WhereOperator::Startswith,
value: json!("admin"),
};
let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
assert_eq!(sql, "data->>'email' LIKE 'admin%'");
}
#[test]
fn test_and_clause() {
let clause = WhereClause::And(vec![
WhereClause::Field {
path: vec!["status".to_string()],
operator: WhereOperator::Eq,
value: json!("active"),
},
WhereClause::Field {
path: vec!["age".to_string()],
operator: WhereOperator::Gte,
value: json!(18),
},
]);
let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
assert_eq!(sql, "(data->>'status' = 'active' AND data->>'age' >= 18)");
}
#[test]
fn test_or_clause() {
let clause = WhereClause::Or(vec![
WhereClause::Field {
path: vec!["type".to_string()],
operator: WhereOperator::Eq,
value: json!("admin"),
},
WhereClause::Field {
path: vec!["type".to_string()],
operator: WhereOperator::Eq,
value: json!("moderator"),
},
]);
let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
assert_eq!(sql, "(data->>'type' = 'admin' OR data->>'type' = 'moderator')");
}
#[test]
fn test_not_clause() {
let clause = WhereClause::Not(Box::new(WhereClause::Field {
path: vec!["deleted".to_string()],
operator: WhereOperator::Eq,
value: json!(true),
}));
let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
assert_eq!(sql, "NOT (data->>'deleted' = true)");
}
#[test]
fn test_is_null() {
let clause = WhereClause::Field {
path: vec!["deleted_at".to_string()],
operator: WhereOperator::IsNull,
value: json!(true),
};
let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
assert_eq!(sql, "data->>'deleted_at' IS NULL");
}
#[test]
fn test_is_not_null() {
let clause = WhereClause::Field {
path: vec!["updated_at".to_string()],
operator: WhereOperator::IsNull,
value: json!(false),
};
let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
assert_eq!(sql, "data->>'updated_at' IS NOT NULL");
}
#[test]
fn test_in_operator() {
let clause = WhereClause::Field {
path: vec!["status".to_string()],
operator: WhereOperator::In,
value: json!(["active", "pending", "approved"]),
};
let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
assert_eq!(sql, "data->>'status' = ANY ARRAY['active', 'pending', 'approved']");
}
#[test]
fn test_sql_injection_prevention() {
let clause = WhereClause::Field {
path: vec!["name".to_string()],
operator: WhereOperator::Eq,
value: json!("'; DROP TABLE users; --"),
};
let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
assert_eq!(sql, "data->>'name' = '''; DROP TABLE users; --'");
}
#[test]
fn test_numeric_comparison() {
let clause = WhereClause::Field {
path: vec!["price".to_string()],
operator: WhereOperator::Gt,
value: json!(99.99),
};
let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
assert_eq!(sql, "data->>'price' > 99.99");
}
#[test]
fn test_boolean_value() {
let clause = WhereClause::Field {
path: vec!["published".to_string()],
operator: WhereOperator::Eq,
value: json!(true),
};
let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
assert_eq!(sql, "data->>'published' = true");
}
#[test]
fn test_empty_and_clause() {
let clause = WhereClause::And(vec![]);
let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
assert_eq!(sql, "TRUE");
}
#[test]
fn test_empty_or_clause() {
let clause = WhereClause::Or(vec![]);
let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
assert_eq!(sql, "FALSE");
}
#[test]
fn test_complex_nested_condition() {
let clause = WhereClause::And(vec![
WhereClause::Field {
path: vec!["type".to_string()],
operator: WhereOperator::Eq,
value: json!("article"),
},
WhereClause::Or(vec![
WhereClause::Field {
path: vec!["status".to_string()],
operator: WhereOperator::Eq,
value: json!("published"),
},
WhereClause::And(vec![
WhereClause::Field {
path: vec!["status".to_string()],
operator: WhereOperator::Eq,
value: json!("draft"),
},
WhereClause::Field {
path: vec!["author".to_string(), "role".to_string()],
operator: WhereOperator::Eq,
value: json!("admin"),
},
]),
]),
]);
let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
assert_eq!(
sql,
"(data->>'type' = 'article' AND (data->>'status' = 'published' OR (data->>'status' = 'draft' AND data#>'{author}'->>'role' = 'admin')))"
);
}
#[test]
fn test_sql_injection_in_field_name_simple() {
let clause = WhereClause::Field {
path: vec!["name'; DROP TABLE users; --".to_string()],
operator: WhereOperator::Eq,
value: json!("value"),
};
let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
assert!(sql.contains("''")); assert!(sql.contains("data->>'"));
assert!(sql.contains("= 'value'")); }
#[test]
fn test_sql_injection_prevention_in_array_operator() {
let clause = WhereClause::Field {
path: vec!["tags".to_string()],
operator: WhereOperator::ArrayContains,
value: json!(["normal", "'; DROP TABLE users; --"]),
};
let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
assert!(sql.contains("::jsonb"), "Must produce valid JSONB cast");
assert!(
sql.contains("''"),
"Single quotes inside JSON values must be doubled for SQL safety"
);
}
#[test]
fn test_sql_injection_in_nested_field_name() {
let clause = WhereClause::Field {
path: vec![
"user".to_string(),
"role'; DROP TABLE users; --".to_string(),
],
operator: WhereOperator::Eq,
value: json!("admin"),
};
let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
assert!(sql.contains("''")); assert!(sql.contains("data#>'{")); }
#[test]
fn escape_sql_string_rejects_oversized_input() {
let large = "a".repeat(MAX_SQL_VALUE_BYTES + 1);
let result = WhereSqlGenerator::escape_sql_string(&large);
assert!(matches!(result, Err(FraiseQLError::Validation { .. })));
}
#[test]
fn escape_sql_string_accepts_exactly_max_bytes() {
let at_limit = "a".repeat(MAX_SQL_VALUE_BYTES);
WhereSqlGenerator::escape_sql_string(&at_limit).unwrap_or_else(|e| {
panic!("expected Ok for string at exactly MAX_SQL_VALUE_BYTES: {e}")
});
}
#[test]
fn escape_sql_string_escapes_single_quotes() {
let result = WhereSqlGenerator::escape_sql_string("it's").unwrap();
assert_eq!(result, "it''s");
}
#[test]
fn value_to_sql_rejects_oversized_string_value() {
let large = "a".repeat(MAX_SQL_VALUE_BYTES + 1);
let clause = WhereClause::Field {
path: vec!["name".to_string()],
operator: WhereOperator::Eq,
value: json!(large),
};
assert!(matches!(
WhereSqlGenerator::to_sql(&clause),
Err(FraiseQLError::Validation { .. })
));
}
#[test]
fn value_to_sql_rejects_oversized_jsonb_value() {
let large_element = "a".repeat(MAX_SQL_VALUE_BYTES);
let clause = WhereClause::Field {
path: vec!["tags".to_string()],
operator: WhereOperator::ArrayContains,
value: json!([large_element]),
};
assert!(matches!(
WhereSqlGenerator::to_sql(&clause),
Err(FraiseQLError::Validation { .. })
));
}
}