Skip to main content

fraiseql_wire/operators/
sql_gen.rs

//! SQL generation from operators
//!
//! Converts operator enums to PostgreSQL WHERE clause SQL strings.
//! Handles parameter binding, type casting, and operator-specific SQL generation
//! for both JSONB and direct column sources.
//!
//! # Type Casting Strategy
//!
//! For the comparison operators (`Eq`, `Neq`, `Gt`, `Gte`, `Lt`, `Lte`), a JSONB field
//! expression (as produced by `Field::to_sql`, e.g. `(data->'name')`) is given an explicit
//! cast chosen from the type of the bound value:
//!
//! - String comparisons: cast to text (`field::text = $1`)
//! - Numeric comparisons: cast to numeric, which covers both integers and floats
//!   (`field::numeric > $1`)
//! - Boolean comparisons: cast to boolean (`field::boolean = $1`)
//! - NULL, array, vector, and raw SQL values: no cast (array and vector operators handle
//!   their own typing)
//!
//! Direct columns use native types from the database schema and are never cast.
18
19use super::{Field, Value, WhereOperator};
20use crate::Result;
21use std::collections::HashMap;
22
23/// Infers the PostgreSQL type cast needed for a value
24///
25/// Returns the type cast suffix (e.g., "::integer", "::text") if needed
26fn infer_type_cast(value: &Value) -> &'static str {
27    match value {
28        Value::String(_) => "::text",
29        Value::Number(_) => "::numeric", // numeric handles both int and float
30        Value::Bool(_) => "::boolean",
31        Value::Null => "",          // no cast for NULL
32        Value::Array(_) => "",      // arrays handled by operators
33        Value::FloatArray(_) => "", // vector operators handle their own casting
34        Value::RawSql(_) => "",     // raw SQL is assumed correct
35    }
36}
37
38/// Generates SQL from a WHERE operator with parameter binding support
39///
40/// # Parameters
41///
42/// - `operator`: The WHERE operator to generate SQL for
43/// - `param_index`: Mutable reference to parameter counter (for $1, $2, etc.)
44/// - `params`: Mutable map to accumulate parameter values (for later binding)
45///
46/// # Returns
47///
48/// SQL string with parameter placeholders ($1, $2, etc.)
49///
50/// # Examples
51///
52/// ```ignore
53/// let mut param_index = 0;
54/// let mut params = HashMap::new();
55/// let op = WhereOperator::Eq(Field::JsonbField("name".to_string()), Value::String("John".to_string()));
56/// let sql = generate_where_operator_sql(&op, &mut param_index, &mut params)?;
57/// assert_eq!(sql, "(data->'name') = $1");
58/// assert_eq!(params[&1], Value::String("John".to_string()));
59/// ```
60pub fn generate_where_operator_sql(
61    operator: &WhereOperator,
62    param_index: &mut usize,
63    params: &mut HashMap<usize, Value>,
64) -> Result<String> {
65    operator.validate().map_err(crate::Error::InvalidSchema)?;
66
67    match operator {
68        // ============ Comparison Operators ============
69        // These operators work on both JSONB and direct columns.
70        // For JSONB text extraction, we apply type casting for proper comparison.
71        WhereOperator::Eq(field, value) => {
72            let field_sql = field.to_sql();
73            if value.is_null() {
74                Ok(format!("{} IS NULL", field_sql))
75            } else {
76                let param_num = *param_index + 1;
77                *param_index += 1;
78                params.insert(param_num, value.clone());
79                // JSONB fields need type cast for non-string comparisons
80                let cast = match field {
81                    Field::JsonbField(_) | Field::JsonbPath(_) => infer_type_cast(value),
82                    Field::DirectColumn(_) => "", // direct columns use native types
83                };
84                Ok(format!("{}{} = ${}", field_sql, cast, param_num))
85            }
86        }
87
88        WhereOperator::Neq(field, value) => {
89            let field_sql = field.to_sql();
90            if value.is_null() {
91                Ok(format!("{} IS NOT NULL", field_sql))
92            } else {
93                let param_num = *param_index + 1;
94                *param_index += 1;
95                params.insert(param_num, value.clone());
96                let cast = match field {
97                    Field::JsonbField(_) | Field::JsonbPath(_) => infer_type_cast(value),
98                    Field::DirectColumn(_) => "",
99                };
100                Ok(format!("{}{} != ${}", field_sql, cast, param_num))
101            }
102        }
103
104        WhereOperator::Gt(field, value) => {
105            let field_sql = field.to_sql();
106            let param_num = *param_index + 1;
107            *param_index += 1;
108            params.insert(param_num, value.clone());
109            let cast = match field {
110                Field::JsonbField(_) | Field::JsonbPath(_) => infer_type_cast(value),
111                Field::DirectColumn(_) => "",
112            };
113            Ok(format!("{}{} > ${}", field_sql, cast, param_num))
114        }
115
116        WhereOperator::Gte(field, value) => {
117            let field_sql = field.to_sql();
118            let param_num = *param_index + 1;
119            *param_index += 1;
120            params.insert(param_num, value.clone());
121            let cast = match field {
122                Field::JsonbField(_) | Field::JsonbPath(_) => infer_type_cast(value),
123                Field::DirectColumn(_) => "",
124            };
125            Ok(format!("{}{} >= ${}", field_sql, cast, param_num))
126        }
127
128        WhereOperator::Lt(field, value) => {
129            let field_sql = field.to_sql();
130            let param_num = *param_index + 1;
131            *param_index += 1;
132            params.insert(param_num, value.clone());
133            let cast = match field {
134                Field::JsonbField(_) | Field::JsonbPath(_) => infer_type_cast(value),
135                Field::DirectColumn(_) => "",
136            };
137            Ok(format!("{}{} < ${}", field_sql, cast, param_num))
138        }
139
140        WhereOperator::Lte(field, value) => {
141            let field_sql = field.to_sql();
142            let param_num = *param_index + 1;
143            *param_index += 1;
144            params.insert(param_num, value.clone());
145            let cast = match field {
146                Field::JsonbField(_) | Field::JsonbPath(_) => infer_type_cast(value),
147                Field::DirectColumn(_) => "",
148            };
149            Ok(format!("{}{} <= ${}", field_sql, cast, param_num))
150        }
151
152        // ============ Array Operators ============
153        WhereOperator::In(field, values) => {
154            let field_sql = field.to_sql();
155            let placeholders: Vec<String> = values
156                .iter()
157                .map(|v| {
158                    let param_num = *param_index + 1;
159                    *param_index += 1;
160                    params.insert(param_num, v.clone());
161                    format!("${}", param_num)
162                })
163                .collect();
164            Ok(format!("{} IN ({})", field_sql, placeholders.join(", ")))
165        }
166
167        WhereOperator::Nin(field, values) => {
168            let field_sql = field.to_sql();
169            let placeholders: Vec<String> = values
170                .iter()
171                .map(|v| {
172                    let param_num = *param_index + 1;
173                    *param_index += 1;
174                    params.insert(param_num, v.clone());
175                    format!("${}", param_num)
176                })
177                .collect();
178            Ok(format!(
179                "{} NOT IN ({})",
180                field_sql,
181                placeholders.join(", ")
182            ))
183        }
184
185        WhereOperator::Contains(field, substring) => {
186            let field_sql = field.to_sql();
187            let param_num = *param_index + 1;
188            *param_index += 1;
189            params.insert(param_num, Value::String(substring.clone()));
190            Ok(format!(
191                "{} LIKE '%' || ${}::text || '%'",
192                field_sql, param_num
193            ))
194        }
195
196        WhereOperator::ArrayContains(field, value) => {
197            let field_sql = field.to_sql();
198            let param_num = *param_index + 1;
199            *param_index += 1;
200            params.insert(param_num, value.clone());
201            Ok(format!("{} @> ARRAY[${}]", field_sql, param_num))
202        }
203
204        WhereOperator::ArrayContainedBy(field, value) => {
205            let field_sql = field.to_sql();
206            let param_num = *param_index + 1;
207            *param_index += 1;
208            params.insert(param_num, value.clone());
209            Ok(format!("{} <@ ARRAY[${}]", field_sql, param_num))
210        }
211
212        WhereOperator::ArrayOverlaps(field, values) => {
213            let field_sql = field.to_sql();
214            let placeholders: Vec<String> = values
215                .iter()
216                .map(|v| {
217                    let param_num = *param_index + 1;
218                    *param_index += 1;
219                    params.insert(param_num, v.clone());
220                    format!("${}", param_num)
221                })
222                .collect();
223            Ok(format!(
224                "{} && ARRAY[{}]",
225                field_sql,
226                placeholders.join(", ")
227            ))
228        }
229
230        // ============ Array Length Operators ============
231        WhereOperator::LenEq(field, len) => {
232            let field_sql = field.to_sql();
233            Ok(format!("array_length({}, 1) = {}", field_sql, len))
234        }
235
236        WhereOperator::LenGt(field, len) => {
237            let field_sql = field.to_sql();
238            Ok(format!("array_length({}, 1) > {}", field_sql, len))
239        }
240
241        WhereOperator::LenGte(field, len) => {
242            let field_sql = field.to_sql();
243            Ok(format!("array_length({}, 1) >= {}", field_sql, len))
244        }
245
246        WhereOperator::LenLt(field, len) => {
247            let field_sql = field.to_sql();
248            Ok(format!("array_length({}, 1) < {}", field_sql, len))
249        }
250
251        WhereOperator::LenLte(field, len) => {
252            let field_sql = field.to_sql();
253            Ok(format!("array_length({}, 1) <= {}", field_sql, len))
254        }
255
256        // ============ String Operators ============
257        WhereOperator::Icontains(field, substring) => {
258            let field_sql = field.to_sql();
259            let param_num = *param_index + 1;
260            *param_index += 1;
261            params.insert(param_num, Value::String(substring.clone()));
262            Ok(format!(
263                "{} ILIKE '%' || ${}::text || '%'",
264                field_sql, param_num
265            ))
266        }
267
268        WhereOperator::Startswith(field, prefix) => {
269            let field_sql = field.to_sql();
270            let param_num = *param_index + 1;
271            *param_index += 1;
272            params.insert(param_num, Value::String(format!("{}%", prefix)));
273            Ok(format!("{} LIKE ${}", field_sql, param_num))
274        }
275
276        WhereOperator::Endswith(field, suffix) => {
277            let field_sql = field.to_sql();
278            let param_num = *param_index + 1;
279            *param_index += 1;
280            params.insert(param_num, Value::String(format!("%{}", suffix)));
281            Ok(format!("{} LIKE ${}", field_sql, param_num))
282        }
283
284        WhereOperator::Like(field, pattern) => {
285            let field_sql = field.to_sql();
286            let param_num = *param_index + 1;
287            *param_index += 1;
288            params.insert(param_num, Value::String(pattern.clone()));
289            Ok(format!("{} LIKE ${}", field_sql, param_num))
290        }
291
292        WhereOperator::Ilike(field, pattern) => {
293            let field_sql = field.to_sql();
294            let param_num = *param_index + 1;
295            *param_index += 1;
296            params.insert(param_num, Value::String(pattern.clone()));
297            Ok(format!("{} ILIKE ${}", field_sql, param_num))
298        }
299
300        // ============ Null Operator ============
301        WhereOperator::IsNull(field, is_null) => {
302            let field_sql = field.to_sql();
303            if *is_null {
304                Ok(format!("{} IS NULL", field_sql))
305            } else {
306                Ok(format!("{} IS NOT NULL", field_sql))
307            }
308        }
309
310        // ============ Vector Distance Operators (pgvector) ============
311        WhereOperator::L2Distance {
312            field,
313            vector,
314            threshold,
315        } => {
316            let field_sql = field.to_sql();
317            let param_num = *param_index + 1;
318            *param_index += 1;
319            params.insert(param_num, Value::FloatArray(vector.clone()));
320            Ok(format!(
321                "l2_distance({}::vector, ${}::vector) < {}",
322                field_sql, param_num, threshold
323            ))
324        }
325
326        WhereOperator::CosineDistance {
327            field,
328            vector,
329            threshold,
330        } => {
331            let field_sql = field.to_sql();
332            let param_num = *param_index + 1;
333            *param_index += 1;
334            params.insert(param_num, Value::FloatArray(vector.clone()));
335            Ok(format!(
336                "cosine_distance({}::vector, ${}::vector) < {}",
337                field_sql, param_num, threshold
338            ))
339        }
340
341        WhereOperator::InnerProduct {
342            field,
343            vector,
344            threshold,
345        } => {
346            let field_sql = field.to_sql();
347            let param_num = *param_index + 1;
348            *param_index += 1;
349            params.insert(param_num, Value::FloatArray(vector.clone()));
350            Ok(format!(
351                "inner_product({}::vector, ${}::vector) > {}",
352                field_sql, param_num, threshold
353            ))
354        }
355
356        WhereOperator::JaccardDistance {
357            field,
358            set,
359            threshold,
360        } => {
361            let field_sql = field.to_sql();
362            let param_num = *param_index + 1;
363            *param_index += 1;
364            let value_array: Vec<Value> = set.iter().map(|s| Value::String(s.clone())).collect();
365            params.insert(param_num, Value::Array(value_array));
366            Ok(format!(
367                "jaccard_distance({}::text[], ${}::text[]) < {}",
368                field_sql, param_num, threshold
369            ))
370        }
371
372        // ============ Full-Text Search Operators ============
373        WhereOperator::Matches {
374            field,
375            query,
376            language,
377        } => {
378            let field_sql = field.to_sql();
379            let param_num = *param_index + 1;
380            *param_index += 1;
381            params.insert(param_num, Value::String(query.clone()));
382            let lang = language.as_deref().unwrap_or("english");
383            Ok(format!(
384                "{} @@ plainto_tsquery('{}', ${})",
385                field_sql, lang, param_num
386            ))
387        }
388
389        WhereOperator::PlainQuery { field, query } => {
390            let field_sql = field.to_sql();
391            let param_num = *param_index + 1;
392            *param_index += 1;
393            params.insert(param_num, Value::String(query.clone()));
394            Ok(format!(
395                "{} @@ plainto_tsquery(${})::tsvector",
396                field_sql, param_num
397            ))
398        }
399
400        WhereOperator::PhraseQuery {
401            field,
402            query,
403            language,
404        } => {
405            let field_sql = field.to_sql();
406            let param_num = *param_index + 1;
407            *param_index += 1;
408            params.insert(param_num, Value::String(query.clone()));
409            let lang = language.as_deref().unwrap_or("english");
410            Ok(format!(
411                "{} @@ phraseto_tsquery('{}', ${})",
412                field_sql, lang, param_num
413            ))
414        }
415
416        WhereOperator::WebsearchQuery {
417            field,
418            query,
419            language,
420        } => {
421            let field_sql = field.to_sql();
422            let param_num = *param_index + 1;
423            *param_index += 1;
424            params.insert(param_num, Value::String(query.clone()));
425            let lang = language.as_deref().unwrap_or("english");
426            Ok(format!(
427                "{} @@ websearch_to_tsquery('{}', ${})",
428                field_sql, lang, param_num
429            ))
430        }
431
432        // ============ Network/INET Operators ============
433        WhereOperator::IsIPv4(field) => {
434            let field_sql = field.to_sql();
435            Ok(format!("family({}::inet) = 4", field_sql))
436        }
437
438        WhereOperator::IsIPv6(field) => {
439            let field_sql = field.to_sql();
440            Ok(format!("family({}::inet) = 6", field_sql))
441        }
442
443        WhereOperator::IsPrivate(field) => {
444            let field_sql = field.to_sql();
445            // RFC1918 private ranges + link-local
446            Ok(format!(
447                "({}::inet << '10.0.0.0/8'::inet OR {}::inet << '172.16.0.0/12'::inet OR {}::inet << '192.168.0.0/16'::inet OR {}::inet << '169.254.0.0/16'::inet)",
448                field_sql, field_sql, field_sql, field_sql
449            ))
450        }
451
452        WhereOperator::IsLoopback(field) => {
453            let field_sql = field.to_sql();
454            Ok(format!(
455                "(family({}::inet) = 4 AND {}::inet << '127.0.0.0/8'::inet) OR (family({}::inet) = 6 AND {}::inet << '::1/128'::inet)",
456                field_sql, field_sql, field_sql, field_sql
457            ))
458        }
459
460        WhereOperator::InSubnet { field, subnet } => {
461            let field_sql = field.to_sql();
462            let param_num = *param_index + 1;
463            *param_index += 1;
464            params.insert(param_num, Value::String(subnet.clone()));
465            Ok(format!("{}::inet << ${}::inet", field_sql, param_num))
466        }
467
468        WhereOperator::ContainsSubnet { field, subnet } => {
469            let field_sql = field.to_sql();
470            let param_num = *param_index + 1;
471            *param_index += 1;
472            params.insert(param_num, Value::String(subnet.clone()));
473            Ok(format!("{}::inet >> ${}::inet", field_sql, param_num))
474        }
475
476        WhereOperator::ContainsIP { field, ip } => {
477            let field_sql = field.to_sql();
478            let param_num = *param_index + 1;
479            *param_index += 1;
480            params.insert(param_num, Value::String(ip.clone()));
481            Ok(format!("{}::inet >> ${}::inet", field_sql, param_num))
482        }
483
484        WhereOperator::IPRangeOverlap { field, range } => {
485            let field_sql = field.to_sql();
486            let param_num = *param_index + 1;
487            *param_index += 1;
488            params.insert(param_num, Value::String(range.clone()));
489            Ok(format!("{}::inet && ${}::inet", field_sql, param_num))
490        }
491    }
492}
493
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the generator on `op` with a fresh counter and parameter map, returning the
    /// generated SQL together with the counter's final value.
    fn render(op: &WhereOperator) -> (String, usize) {
        let mut counter: usize = 0;
        let mut bound: HashMap<usize, Value> = HashMap::new();
        let sql = generate_where_operator_sql(op, &mut counter, &mut bound).unwrap();
        (sql, counter)
    }

    #[test]
    fn test_eq_operator_jsonb_string() {
        let op = WhereOperator::Eq(
            Field::JsonbField("name".to_string()),
            Value::String("John".to_string()),
        );
        let (sql, used) = render(&op);
        // A JSONB field compared against a string value is cast with ::text
        assert_eq!(sql, "(data->'name')::text = $1");
        assert_eq!(used, 1);
    }

    #[test]
    fn test_eq_operator_direct_column() {
        let op = WhereOperator::Eq(
            Field::DirectColumn("status".to_string()),
            Value::String("active".to_string()),
        );
        let (sql, used) = render(&op);
        // Direct columns keep their native type, so no cast is emitted
        assert_eq!(sql, "status = $1");
        assert_eq!(used, 1);
    }

    #[test]
    fn test_len_eq_operator() {
        let (sql, used) = render(&WhereOperator::LenEq(Field::JsonbField("tags".to_string()), 5));
        assert_eq!(sql, "array_length((data->'tags'), 1) = 5");
        // The length is inlined, so no parameter is bound
        assert_eq!(used, 0);
    }

    #[test]
    fn test_is_ipv4_operator() {
        let (sql, _) = render(&WhereOperator::IsIPv4(Field::JsonbField("ip".to_string())));
        assert_eq!(sql, "family((data->'ip')::inet) = 4");
    }

    #[test]
    fn test_l2_distance_operator() {
        let (sql, used) = render(&WhereOperator::L2Distance {
            field: Field::JsonbField("embedding".to_string()),
            vector: vec![0.1, 0.2, 0.3],
            threshold: 0.5,
        });
        assert_eq!(sql, "l2_distance((data->'embedding')::vector, $1::vector) < 0.5");
        assert_eq!(used, 1);
    }

    #[test]
    fn test_in_operator() {
        let statuses = vec![
            Value::String("active".to_string()),
            Value::String("pending".to_string()),
        ];
        let op = WhereOperator::In(Field::JsonbField("status".to_string()), statuses);
        let (sql, used) = render(&op);
        assert_eq!(sql, "(data->'status') IN ($1, $2)");
        assert_eq!(used, 2);
    }
}