1use fraiseql_error::{FraiseQLError, Result};
7use serde_json::Value;
8
9use crate::{WhereClause, WhereOperator};
10
/// Upper bound, in bytes, on any single string (values, JSON keys, column names) or
/// serialized JSONB value that the generator will inline into SQL text: 64 KiB.
const MAX_SQL_VALUE_BYTES: usize = 64 * 1024;
17
/// Renders `WhereClause` trees as PostgreSQL `WHERE`-clause SQL text.
///
/// This is a unit struct that carries no data. All functionality lives in
/// associated functions such as `WhereSqlGenerator::to_sql`.
#[doc(hidden)]
pub struct WhereSqlGenerator;
32
33impl WhereSqlGenerator {
34 pub fn to_sql(clause: &WhereClause) -> Result<String> {
58 match clause {
59 WhereClause::Field {
60 path,
61 operator,
62 value,
63 } => Self::generate_field_predicate(path, operator, value),
64 WhereClause::And(clauses) => {
65 if clauses.is_empty() {
66 return Ok("TRUE".to_string());
67 }
68 let parts: Result<Vec<_>> = clauses.iter().map(Self::to_sql).collect();
69 Ok(format!("({})", parts?.join(" AND ")))
70 },
71 WhereClause::Or(clauses) => {
72 if clauses.is_empty() {
73 return Ok("FALSE".to_string());
74 }
75 let parts: Result<Vec<_>> = clauses.iter().map(Self::to_sql).collect();
76 Ok(format!("({})", parts?.join(" OR ")))
77 },
78 WhereClause::Not(clause) => {
79 let inner = Self::to_sql(clause)?;
80 Ok(format!("NOT ({})", inner))
81 },
82 WhereClause::NativeField {
83 column,
84 operator,
85 value,
86 ..
87 } => {
88 let escaped_col = Self::escape_sql_string(column)?;
91 let col_expr = format!("\"{escaped_col}\"");
92 let sql_op = Self::operator_to_sql(operator)?;
93 let val_sql = Self::value_to_sql(value, operator)?;
94 Ok(format!("{col_expr} {sql_op} {val_sql}"))
95 },
96 }
97 }
98
99 fn generate_field_predicate(
100 path: &[String],
101 operator: &WhereOperator,
102 value: &Value,
103 ) -> Result<String> {
104 let json_path = Self::build_json_path(path)?;
105 let sql = if operator == &WhereOperator::IsNull {
106 let is_null = value.as_bool().unwrap_or(true);
107 if is_null {
108 format!("{json_path} IS NULL")
109 } else {
110 format!("{json_path} IS NOT NULL")
111 }
112 } else {
113 let sql_op = Self::operator_to_sql(operator)?;
114 let sql_value = Self::value_to_sql(value, operator)?;
115 format!("{json_path} {sql_op} {sql_value}")
116 };
117 Ok(sql)
118 }
119
120 fn build_json_path(path: &[String]) -> Result<String> {
121 if path.is_empty() {
122 return Ok("data".to_string());
123 }
124
125 if path.len() == 1 {
126 let escaped = Self::escape_sql_string(&path[0])?;
129 Ok(format!("data->>'{}'", escaped))
130 } else {
131 let nested = &path[..path.len() - 1];
134 let last = &path[path.len() - 1];
135
136 let escaped_nested: Vec<String> =
138 nested.iter().map(|n| Self::escape_sql_string(n)).collect::<Result<Vec<_>>>()?;
139 let nested_path = escaped_nested.join(",");
140 let escaped_last = Self::escape_sql_string(last)?;
141 Ok(format!("data#>'{{{}}}'->>'{}'", nested_path, escaped_last))
142 }
143 }
144
145 fn operator_to_sql(operator: &WhereOperator) -> Result<&'static str> {
146 Ok(match operator {
147 WhereOperator::Eq => "=",
149 WhereOperator::Neq => "!=",
150 WhereOperator::Gt => ">",
151 WhereOperator::Gte => ">=",
152 WhereOperator::Lt => "<",
153 WhereOperator::Lte => "<=",
154
155 WhereOperator::In => "= ANY",
157 WhereOperator::Nin => "!= ALL",
158
159 WhereOperator::Contains => "LIKE",
161 WhereOperator::Icontains => "ILIKE",
162 WhereOperator::Startswith => "LIKE",
163 WhereOperator::Istartswith => "ILIKE",
164 WhereOperator::Endswith => "LIKE",
165 WhereOperator::Iendswith => "ILIKE",
166 WhereOperator::Like => "LIKE",
167 WhereOperator::Ilike => "ILIKE",
168 WhereOperator::Nlike => "NOT LIKE",
169 WhereOperator::Nilike => "NOT ILIKE",
170 WhereOperator::Regex => "~",
171 WhereOperator::Iregex => "~*",
172 WhereOperator::Nregex => "!~",
173 WhereOperator::Niregex => "!~*",
174
175 WhereOperator::ArrayContains => "@>",
177 WhereOperator::ArrayContainedBy => "<@",
178 WhereOperator::ArrayOverlaps => "&&",
179
180 WhereOperator::IsNull => {
182 return Err(FraiseQLError::Internal {
183 message: "IsNull should be handled separately".to_string(),
184 source: None,
185 });
186 },
187 WhereOperator::LenEq
188 | WhereOperator::LenGt
189 | WhereOperator::LenLt
190 | WhereOperator::LenGte
191 | WhereOperator::LenLte
192 | WhereOperator::LenNeq => {
193 return Err(FraiseQLError::Internal {
194 message: format!(
195 "Array length operators not yet supported in fraiseql-wire: {operator:?}"
196 ),
197 source: None,
198 });
199 },
200
201 WhereOperator::L2Distance
203 | WhereOperator::CosineDistance
204 | WhereOperator::L1Distance
205 | WhereOperator::HammingDistance
206 | WhereOperator::InnerProduct
207 | WhereOperator::JaccardDistance => {
208 return Err(FraiseQLError::Internal {
209 message: format!(
210 "Vector operations not supported in fraiseql-wire: {operator:?}"
211 ),
212 source: None,
213 });
214 },
215
216 WhereOperator::Matches
218 | WhereOperator::PlainQuery
219 | WhereOperator::PhraseQuery
220 | WhereOperator::WebsearchQuery => {
221 return Err(FraiseQLError::Internal {
222 message: format!(
223 "Full-text search operators not yet supported in fraiseql-wire: {operator:?}"
224 ),
225 source: None,
226 });
227 },
228
229 WhereOperator::IsIPv4
231 | WhereOperator::IsIPv6
232 | WhereOperator::IsPrivate
233 | WhereOperator::IsPublic
234 | WhereOperator::IsLoopback
235 | WhereOperator::InSubnet
236 | WhereOperator::ContainsSubnet
237 | WhereOperator::ContainsIP
238 | WhereOperator::Overlaps
239 | WhereOperator::StrictlyContains
240 | WhereOperator::AncestorOf
241 | WhereOperator::DescendantOf
242 | WhereOperator::MatchesLquery
243 | WhereOperator::MatchesLtxtquery
244 | WhereOperator::MatchesAnyLquery
245 | WhereOperator::DepthEq
246 | WhereOperator::DepthNeq
247 | WhereOperator::DepthGt
248 | WhereOperator::DepthGte
249 | WhereOperator::DepthLt
250 | WhereOperator::DepthLte
251 | WhereOperator::Lca
252 | WhereOperator::Extended(_) => {
253 return Err(FraiseQLError::Internal {
254 message: format!(
255 "Advanced operators not yet supported in fraiseql-wire: {operator:?}"
256 ),
257 source: None,
258 });
259 },
260 })
261 }
262
263 fn value_to_sql(value: &Value, operator: &WhereOperator) -> Result<String> {
264 match (value, operator) {
265 (Value::Null, _) => Ok("NULL".to_string()),
266 (Value::Bool(b), _) => Ok(b.to_string()),
267 (Value::Number(n), _) => Ok(n.to_string()),
268
269 (Value::String(s), WhereOperator::Contains | WhereOperator::Icontains) => {
271 Ok(format!("'%{}%'", Self::escape_sql_string(s)?))
272 },
273 (Value::String(s), WhereOperator::Startswith | WhereOperator::Istartswith) => {
274 Ok(format!("'{}%'", Self::escape_sql_string(s)?))
275 },
276 (Value::String(s), WhereOperator::Endswith | WhereOperator::Iendswith) => {
277 Ok(format!("'%{}'", Self::escape_sql_string(s)?))
278 },
279
280 (Value::String(s), _) => Ok(format!("'{}'", Self::escape_sql_string(s)?)),
282
283 (Value::Array(arr), WhereOperator::In | WhereOperator::Nin) => {
285 let values: Result<Vec<_>> =
286 arr.iter().map(|v| Self::value_to_sql(v, &WhereOperator::Eq)).collect();
287 Ok(format!("ARRAY[{}]", values?.join(", ")))
288 },
289
290 (
292 Value::Array(_),
293 WhereOperator::ArrayContains
294 | WhereOperator::ArrayContainedBy
295 | WhereOperator::ArrayOverlaps,
296 ) => {
297 let json_str =
301 serde_json::to_string(value).map_err(|e| FraiseQLError::Internal {
302 message: format!("Failed to serialize JSON for array operator: {e}"),
303 source: None,
304 })?;
305 if json_str.len() > MAX_SQL_VALUE_BYTES {
306 return Err(FraiseQLError::Validation {
307 message: format!(
308 "JSONB value exceeds maximum allowed size for SQL embedding \
309 ({} bytes, limit is {} bytes)",
310 json_str.len(),
311 MAX_SQL_VALUE_BYTES
312 ),
313 path: None,
314 });
315 }
316 let escaped = json_str.replace('\'', "''");
317 Ok(format!("'{}'::jsonb", escaped))
318 },
319
320 _ => Err(FraiseQLError::Internal {
321 message: format!(
322 "Unsupported value type for operator: {value:?} with {operator:?}"
323 ),
324 source: None,
325 }),
326 }
327 }
328
329 fn escape_sql_string(s: &str) -> Result<String> {
330 if s.len() > MAX_SQL_VALUE_BYTES {
331 return Err(FraiseQLError::Validation {
332 message: format!(
333 "String value exceeds maximum allowed size for SQL embedding \
334 ({} bytes, limit is {} bytes)",
335 s.len(),
336 MAX_SQL_VALUE_BYTES
337 ),
338 path: None,
339 });
340 }
341 Ok(s.replace('\'', "''"))
342 }
343}
344
345#[cfg(test)]
346#[allow(clippy::unwrap_used)] mod tests {
348 use serde_json::json;
349
350 use super::*;
351
352 #[test]
353 fn test_simple_equality() {
354 let clause = WhereClause::Field {
355 path: vec!["status".to_string()],
356 operator: WhereOperator::Eq,
357 value: json!("active"),
358 };
359
360 let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
361 assert_eq!(sql, "data->>'status' = 'active'");
362 }
363
364 #[test]
365 fn test_nested_path() {
366 let clause = WhereClause::Field {
367 path: vec!["user".to_string(), "email".to_string()],
368 operator: WhereOperator::Eq,
369 value: json!("test@example.com"),
370 };
371
372 let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
373 assert_eq!(sql, "data#>'{user}'->>'email' = 'test@example.com'");
374 }
375
376 #[test]
377 fn test_icontains() {
378 let clause = WhereClause::Field {
379 path: vec!["name".to_string()],
380 operator: WhereOperator::Icontains,
381 value: json!("john"),
382 };
383
384 let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
385 assert_eq!(sql, "data->>'name' ILIKE '%john%'");
386 }
387
388 #[test]
389 fn test_startswith() {
390 let clause = WhereClause::Field {
391 path: vec!["email".to_string()],
392 operator: WhereOperator::Startswith,
393 value: json!("admin"),
394 };
395
396 let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
397 assert_eq!(sql, "data->>'email' LIKE 'admin%'");
398 }
399
400 #[test]
401 fn test_and_clause() {
402 let clause = WhereClause::And(vec![
403 WhereClause::Field {
404 path: vec!["status".to_string()],
405 operator: WhereOperator::Eq,
406 value: json!("active"),
407 },
408 WhereClause::Field {
409 path: vec!["age".to_string()],
410 operator: WhereOperator::Gte,
411 value: json!(18),
412 },
413 ]);
414
415 let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
416 assert_eq!(sql, "(data->>'status' = 'active' AND data->>'age' >= 18)");
417 }
418
419 #[test]
420 fn test_or_clause() {
421 let clause = WhereClause::Or(vec![
422 WhereClause::Field {
423 path: vec!["type".to_string()],
424 operator: WhereOperator::Eq,
425 value: json!("admin"),
426 },
427 WhereClause::Field {
428 path: vec!["type".to_string()],
429 operator: WhereOperator::Eq,
430 value: json!("moderator"),
431 },
432 ]);
433
434 let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
435 assert_eq!(sql, "(data->>'type' = 'admin' OR data->>'type' = 'moderator')");
436 }
437
438 #[test]
439 fn test_not_clause() {
440 let clause = WhereClause::Not(Box::new(WhereClause::Field {
441 path: vec!["deleted".to_string()],
442 operator: WhereOperator::Eq,
443 value: json!(true),
444 }));
445
446 let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
447 assert_eq!(sql, "NOT (data->>'deleted' = true)");
448 }
449
450 #[test]
451 fn test_is_null() {
452 let clause = WhereClause::Field {
453 path: vec!["deleted_at".to_string()],
454 operator: WhereOperator::IsNull,
455 value: json!(true),
456 };
457
458 let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
459 assert_eq!(sql, "data->>'deleted_at' IS NULL");
460 }
461
462 #[test]
463 fn test_is_not_null() {
464 let clause = WhereClause::Field {
465 path: vec!["updated_at".to_string()],
466 operator: WhereOperator::IsNull,
467 value: json!(false),
468 };
469
470 let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
471 assert_eq!(sql, "data->>'updated_at' IS NOT NULL");
472 }
473
474 #[test]
475 fn test_in_operator() {
476 let clause = WhereClause::Field {
477 path: vec!["status".to_string()],
478 operator: WhereOperator::In,
479 value: json!(["active", "pending", "approved"]),
480 };
481
482 let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
483 assert_eq!(sql, "data->>'status' = ANY ARRAY['active', 'pending', 'approved']");
484 }
485
486 #[test]
487 fn test_sql_injection_prevention() {
488 let clause = WhereClause::Field {
489 path: vec!["name".to_string()],
490 operator: WhereOperator::Eq,
491 value: json!("'; DROP TABLE users; --"),
492 };
493
494 let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
495 assert_eq!(sql, "data->>'name' = '''; DROP TABLE users; --'");
496 }
498
499 #[test]
500 fn test_numeric_comparison() {
501 let clause = WhereClause::Field {
502 path: vec!["price".to_string()],
503 operator: WhereOperator::Gt,
504 value: json!(99.99),
505 };
506
507 let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
508 assert_eq!(sql, "data->>'price' > 99.99");
509 }
510
511 #[test]
512 fn test_boolean_value() {
513 let clause = WhereClause::Field {
514 path: vec!["published".to_string()],
515 operator: WhereOperator::Eq,
516 value: json!(true),
517 };
518
519 let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
520 assert_eq!(sql, "data->>'published' = true");
521 }
522
523 #[test]
524 fn test_empty_and_clause() {
525 let clause = WhereClause::And(vec![]);
526 let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
527 assert_eq!(sql, "TRUE");
528 }
529
530 #[test]
531 fn test_empty_or_clause() {
532 let clause = WhereClause::Or(vec![]);
533 let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
534 assert_eq!(sql, "FALSE");
535 }
536
537 #[test]
538 fn test_complex_nested_condition() {
539 let clause = WhereClause::And(vec![
540 WhereClause::Field {
541 path: vec!["type".to_string()],
542 operator: WhereOperator::Eq,
543 value: json!("article"),
544 },
545 WhereClause::Or(vec![
546 WhereClause::Field {
547 path: vec!["status".to_string()],
548 operator: WhereOperator::Eq,
549 value: json!("published"),
550 },
551 WhereClause::And(vec![
552 WhereClause::Field {
553 path: vec!["status".to_string()],
554 operator: WhereOperator::Eq,
555 value: json!("draft"),
556 },
557 WhereClause::Field {
558 path: vec!["author".to_string(), "role".to_string()],
559 operator: WhereOperator::Eq,
560 value: json!("admin"),
561 },
562 ]),
563 ]),
564 ]);
565
566 let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
567 assert_eq!(
568 sql,
569 "(data->>'type' = 'article' AND (data->>'status' = 'published' OR (data->>'status' = 'draft' AND data#>'{author}'->>'role' = 'admin')))"
570 );
571 }
572
573 #[test]
574 fn test_sql_injection_in_field_name_simple() {
575 let clause = WhereClause::Field {
577 path: vec!["name'; DROP TABLE users; --".to_string()],
578 operator: WhereOperator::Eq,
579 value: json!("value"),
580 };
581
582 let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
583 assert!(sql.contains("''")); assert!(sql.contains("data->>'"));
590 assert!(sql.contains("= 'value'")); }
592
593 #[test]
594 fn test_sql_injection_prevention_in_array_operator() {
595 let clause = WhereClause::Field {
597 path: vec!["tags".to_string()],
598 operator: WhereOperator::ArrayContains,
599 value: json!(["normal", "'; DROP TABLE users; --"]),
600 };
601
602 let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
603 assert!(sql.contains("::jsonb"), "Must produce valid JSONB cast");
606 assert!(
611 sql.contains("''"),
612 "Single quotes inside JSON values must be doubled for SQL safety"
613 );
614 }
615
616 #[test]
617 fn test_sql_injection_in_nested_field_name() {
618 let clause = WhereClause::Field {
620 path: vec![
621 "user".to_string(),
622 "role'; DROP TABLE users; --".to_string(),
623 ],
624 operator: WhereOperator::Eq,
625 value: json!("admin"),
626 };
627
628 let sql = WhereSqlGenerator::to_sql(&clause).unwrap();
629 assert!(sql.contains("''")); assert!(sql.contains("data#>'{")); }
633
634 #[test]
635 fn escape_sql_string_rejects_oversized_input() {
636 let large = "a".repeat(MAX_SQL_VALUE_BYTES + 1);
637 let result = WhereSqlGenerator::escape_sql_string(&large);
638 assert!(matches!(result, Err(FraiseQLError::Validation { .. })));
639 }
640
641 #[test]
642 fn escape_sql_string_accepts_exactly_max_bytes() {
643 let at_limit = "a".repeat(MAX_SQL_VALUE_BYTES);
644 WhereSqlGenerator::escape_sql_string(&at_limit).unwrap_or_else(|e| {
645 panic!("expected Ok for string at exactly MAX_SQL_VALUE_BYTES: {e}")
646 });
647 }
648
649 #[test]
650 fn escape_sql_string_escapes_single_quotes() {
651 let result = WhereSqlGenerator::escape_sql_string("it's").unwrap();
652 assert_eq!(result, "it''s");
653 }
654
655 #[test]
656 fn value_to_sql_rejects_oversized_string_value() {
657 let large = "a".repeat(MAX_SQL_VALUE_BYTES + 1);
658 let clause = WhereClause::Field {
659 path: vec!["name".to_string()],
660 operator: WhereOperator::Eq,
661 value: json!(large),
662 };
663 assert!(matches!(
664 WhereSqlGenerator::to_sql(&clause),
665 Err(FraiseQLError::Validation { .. })
666 ));
667 }
668
669 #[test]
670 fn value_to_sql_rejects_oversized_jsonb_value() {
671 let large_element = "a".repeat(MAX_SQL_VALUE_BYTES);
673 let clause = WhereClause::Field {
674 path: vec!["tags".to_string()],
675 operator: WhereOperator::ArrayContains,
676 value: json!([large_element]),
677 };
678 assert!(matches!(
679 WhereSqlGenerator::to_sql(&clause),
680 Err(FraiseQLError::Validation { .. })
681 ));
682 }
683}