Skip to main content

fraiseql_wire/operators/
where_operator.rs

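//! WHERE clause operators
//!
//! Type-safe operator definitions for WHERE clause generation.
//! Supports 50+ operators across 10 categories (comparison, array, array length, string,
//! null, pgvector distance, full-text search, network, JSONB, and `LTree`) with both JSONB
//! and direct column sources.
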
6use super::field::{Field, Value};
7
8/// WHERE clause operators
9///
10/// Supports type-safe, audit-friendly WHERE clause construction
11/// without raw SQL strings.
12///
13/// # Categories
14///
15/// - **Comparison**: Eq, Neq, Gt, Gte, Lt, Lte
16/// - **Array**: In, Nin, Contains, `ArrayContains`, `ArrayContainedBy`, `ArrayOverlaps`
17/// - **Array Length**: `LenEq`, `LenGt`, `LenGte`, `LenLt`, `LenLte`
18/// - **String**: Icontains, Startswith, Istartswith, Endswith, Iendswith, Like, Ilike
19/// - **Null**: `IsNull`
20/// - **Vector Distance**: `L2Distance`, `CosineDistance`, `InnerProduct`, `L1Distance`, `HammingDistance`, `JaccardDistance`
21/// - **Full-Text Search**: Matches, `PlainQuery`, `PhraseQuery`, `WebsearchQuery`
22/// - **Network**: `IsIPv4`, `IsIPv6`, `IsPrivate`, `IsPublic`, `IsLoopback`, `InSubnet`, `ContainsSubnet`, `ContainsIP`, `IPRangeOverlap`
23/// - **JSONB**: `StrictlyContains`
24/// - **`LTree`**: `AncestorOf`, `DescendantOf`, `MatchesLquery`, `MatchesLtxtquery`, `MatchesAnyLquery`,
25///   `DepthEq`, `DepthNeq`, `DepthGt`, `DepthGte`, `DepthLt`, `DepthLte`, Lca
26#[derive(Debug, Clone)]
27#[non_exhaustive]
28pub enum WhereOperator {
29    // ============ Comparison Operators ============
30    /// Equal: `field = value`
31    Eq(Field, Value),
32
33    /// Not equal: `field != value` or `field <> value`
34    Neq(Field, Value),
35
36    /// Greater than: `field > value`
37    Gt(Field, Value),
38
39    /// Greater than or equal: `field >= value`
40    Gte(Field, Value),
41
42    /// Less than: `field < value`
43    Lt(Field, Value),
44
45    /// Less than or equal: `field <= value`
46    Lte(Field, Value),
47
48    // ============ Array Operators ============
49    /// Array contains value: `field IN (...)`
50    In(Field, Vec<Value>),
51
52    /// Array does not contain value: `field NOT IN (...)`
53    Nin(Field, Vec<Value>),
54
55    /// String contains substring: `field LIKE '%substring%'`
56    ///
57    /// # Warning — LIKE wildcard injection
58    ///
59    /// The substring is embedded between `%` anchors. Characters `%` (any sequence) and
60    /// `_` (any single character) in the substring itself are **also** treated as LIKE
61    /// wildcards, causing broader matches than intended. Escape user-supplied values
62    /// (replace `%` → `\%` and `_` → `\_`) before constructing this variant.
63    Contains(Field, String),
64
65    /// Array contains element: PostgreSQL array operator `@>`
66    /// Generated SQL: `field @> array[value]`
67    ArrayContains(Field, Value),
68
69    /// Array is contained by: PostgreSQL array operator `<@`
70    /// Generated SQL: `field <@ array[value]`
71    ArrayContainedBy(Field, Value),
72
73    /// Arrays overlap: PostgreSQL array operator `&&`
74    /// Generated SQL: `field && array[value]`
75    ArrayOverlaps(Field, Vec<Value>),
76
77    // ============ Array Length Operators ============
78    /// Array length equals: `array_length(field, 1) = value`
79    LenEq(Field, usize),
80
81    /// Array length greater than: `array_length(field, 1) > value`
82    LenGt(Field, usize),
83
84    /// Array length greater than or equal: `array_length(field, 1) >= value`
85    LenGte(Field, usize),
86
87    /// Array length less than: `array_length(field, 1) < value`
88    LenLt(Field, usize),
89
90    /// Array length less than or equal: `array_length(field, 1) <= value`
91    LenLte(Field, usize),
92
93    // ============ String Operators ============
94    /// Case-insensitive contains: `field ILIKE '%substring%'`
95    ///
96    /// # Warning — LIKE wildcard injection
97    ///
98    /// Same escaping requirements as [`WhereOperator::Contains`]. ILIKE applies the same
99    /// wildcard semantics, so `%` and `_` in the substring expand unexpectedly.
100    Icontains(Field, String),
101
102    /// Starts with: `field LIKE 'prefix%'`
103    ///
104    /// # Warning — LIKE wildcard injection
105    ///
106    /// The prefix string is passed verbatim before the trailing `%`. Characters `%` and `_`
107    /// in the prefix are treated as wildcards; escape user-supplied values before use.
108    Startswith(Field, String),
109
110    /// Starts with (case-insensitive): `field ILIKE 'prefix%'`
111    ///
112    /// # Warning — LIKE wildcard injection
113    ///
114    /// Same escaping requirements as [`WhereOperator::Startswith`].
115    Istartswith(Field, String),
116
117    /// Ends with: `field LIKE '%suffix'`
118    ///
119    /// # Warning — LIKE wildcard injection
120    ///
121    /// The suffix string is passed verbatim after the leading `%`. Characters `%` and `_`
122    /// in the suffix are treated as wildcards; escape user-supplied values before use.
123    Endswith(Field, String),
124
125    /// Ends with (case-insensitive): `field ILIKE '%suffix'`
126    ///
127    /// # Warning — LIKE wildcard injection
128    ///
129    /// Same escaping requirements as [`WhereOperator::Endswith`].
130    Iendswith(Field, String),
131
132    /// LIKE pattern matching: `field LIKE pattern`
133    ///
134    /// # Warning — LIKE wildcard injection
135    ///
136    /// The pattern string is passed to the database as-is. Characters `%` (any
137    /// sequence) and `_` (any single character) are SQL LIKE wildcards. If the
138    /// pattern originates from user input, callers **must** escape these
139    /// characters (e.g. replace `%` → `\%` and `_` → `\_`) before constructing
140    /// this variant, and append the appropriate `ESCAPE '\'` clause.
141    Like(Field, String),
142
143    /// Case-insensitive LIKE: `field ILIKE pattern`
144    ///
145    /// # Warning — LIKE wildcard injection
146    ///
147    /// Same escaping requirements as [`WhereOperator::Like`]. The `%` and `_`
148    /// wildcards apply to ILIKE patterns as well.
149    Ilike(Field, String),
150
151    // ============ Null Operator ============
152    /// IS NULL: `field IS NULL` or `field IS NOT NULL`
153    ///
154    /// When the boolean is true, generates `IS NULL`
155    /// When false, generates `IS NOT NULL`
156    IsNull(Field, bool),
157
158    // ============ Vector Distance Operators (pgvector) ============
159    /// L2 (Euclidean) distance: `l2_distance(field, vector) < threshold`
160    ///
161    /// Requires pgvector extension.
162    L2Distance {
163        /// The vector field to compare against
164        field: Field,
165        /// The embedding vector for distance calculation
166        vector: Vec<f32>,
167        /// Distance threshold for comparison
168        threshold: f32,
169    },
170
171    /// Cosine distance: `cosine_distance(field, vector) < threshold`
172    ///
173    /// Requires pgvector extension.
174    CosineDistance {
175        /// The vector field to compare against
176        field: Field,
177        /// The embedding vector for distance calculation
178        vector: Vec<f32>,
179        /// Distance threshold for comparison
180        threshold: f32,
181    },
182
183    /// Inner product: `inner_product(field, vector) > threshold`
184    ///
185    /// Requires pgvector extension.
186    InnerProduct {
187        /// The vector field to compare against
188        field: Field,
189        /// The embedding vector for distance calculation
190        vector: Vec<f32>,
191        /// Distance threshold for comparison
192        threshold: f32,
193    },
194
195    /// L1 (Manhattan) distance: `l1_distance(field, vector) < threshold`
196    ///
197    /// Requires pgvector extension.
198    L1Distance {
199        /// The vector field to compare against
200        field: Field,
201        /// The embedding vector for distance calculation
202        vector: Vec<f32>,
203        /// Distance threshold for comparison
204        threshold: f32,
205    },
206
207    /// Hamming distance: `hamming_distance(field, vector) < threshold`
208    ///
209    /// Requires pgvector extension. Works with bit vectors.
210    HammingDistance {
211        /// The vector field to compare against
212        field: Field,
213        /// The embedding vector for distance calculation
214        vector: Vec<f32>,
215        /// Distance threshold for comparison
216        threshold: f32,
217    },
218
219    /// Jaccard distance: `jaccard_distance(field, set) < threshold`
220    ///
221    /// Works with text arrays, measures set overlap.
222    JaccardDistance {
223        /// The field to compare against
224        field: Field,
225        /// The set of values for comparison
226        set: Vec<String>,
227        /// Distance threshold for comparison
228        threshold: f32,
229    },
230
231    // ============ Full-Text Search Operators ============
232    /// Full-text search with language: `field @@ plainto_tsquery(language, query)`
233    ///
234    /// If language is None, defaults to 'english'
235    Matches {
236        /// The text field to search
237        field: Field,
238        /// The search query
239        query: String,
240        /// Optional language for text search (default: english)
241        language: Option<String>,
242    },
243
244    /// Plain text query: `field @@ plainto_tsquery(query)`
245    ///
246    /// Uses no language specification
247    PlainQuery {
248        /// The text field to search
249        field: Field,
250        /// The search query
251        query: String,
252    },
253
254    /// Phrase query with language: `field @@ phraseto_tsquery(language, query)`
255    ///
256    /// If language is None, defaults to 'english'
257    PhraseQuery {
258        /// The text field to search
259        field: Field,
260        /// The search query
261        query: String,
262        /// Optional language for text search (default: english)
263        language: Option<String>,
264    },
265
266    /// Web search query with language: `field @@ websearch_to_tsquery(language, query)`
267    ///
268    /// If language is None, defaults to 'english'
269    WebsearchQuery {
270        /// The text field to search
271        field: Field,
272        /// The search query
273        query: String,
274        /// Optional language for text search (default: english)
275        language: Option<String>,
276    },
277
278    // ============ Network/INET Operators ============
279    /// Check if IP is `IPv4`: `family(field) = 4`
280    IsIPv4(Field),
281
282    /// Check if IP is `IPv6`: `family(field) = 6`
283    IsIPv6(Field),
284
285    /// Check if IP is private (RFC1918): matches private ranges
286    IsPrivate(Field),
287
288    /// Check if IP is public (not private): opposite of `IsPrivate`
289    IsPublic(Field),
290
291    /// Check if IP is loopback: `IPv4` 127.0.0.0/8 or `IPv6` `::1/128`
292    IsLoopback(Field),
293
294    /// IP is in subnet: `field << subnet`
295    ///
296    /// The subnet should be in CIDR notation (e.g., "192.168.0.0/24")
297    InSubnet {
298        /// The IP field to check
299        field: Field,
300        /// The CIDR subnet (e.g., "192.168.0.0/24")
301        subnet: String,
302    },
303
304    /// Network contains subnet: `field >> subnet`
305    ///
306    /// The subnet should be in CIDR notation
307    ContainsSubnet {
308        /// The network field to check
309        field: Field,
310        /// The CIDR subnet to check for containment
311        subnet: String,
312    },
313
314    /// Network/range contains IP: `field >> ip`
315    ///
316    /// The IP should be a single address (e.g., "192.168.1.1")
317    ContainsIP {
318        /// The network field to check
319        field: Field,
320        /// The IP address to check for containment
321        ip: String,
322    },
323
324    /// IP ranges overlap: `field && range`
325    ///
326    /// The range should be in CIDR notation
327    IPRangeOverlap {
328        /// The IP range field to check
329        field: Field,
330        /// The IP range to check for overlap
331        range: String,
332    },
333
334    // ============ JSONB Operators ============
335    /// JSONB strictly contains: `field @> value`
336    ///
337    /// Checks if the JSONB field strictly contains the given value
338    StrictlyContains(Field, Value),
339
340    // ============ LTree Operators (Hierarchical) ============
341    /// Ancestor of: `field @> path`
342    ///
343    /// Checks if the ltree field is an ancestor of the given path
344    AncestorOf {
345        /// The ltree field to check
346        field: Field,
347        /// The path to check ancestry against
348        path: String,
349    },
350
351    /// Descendant of: `field <@ path`
352    ///
353    /// Checks if the ltree field is a descendant of the given path
354    DescendantOf {
355        /// The ltree field to check
356        field: Field,
357        /// The path to check descendancy against
358        path: String,
359    },
360
361    /// Matches lquery: `field ~ lquery`
362    ///
363    /// Checks if the ltree field matches the given lquery pattern
364    MatchesLquery {
365        /// The ltree field to check
366        field: Field,
367        /// The lquery pattern to match against
368        pattern: String,
369    },
370
371    /// Matches ltxtquery: `field @ ltxtquery`
372    ///
373    /// Checks if the ltree field matches the given ltxtquery pattern (Boolean query syntax)
374    MatchesLtxtquery {
375        /// The ltree field to check
376        field: Field,
377        /// The ltxtquery pattern to match against (e.g., "Science & !Deprecated")
378        query: String,
379    },
380
381    /// Matches any lquery: `field ? array[lqueries]`
382    ///
383    /// Checks if the ltree field matches any of the given lquery patterns
384    MatchesAnyLquery {
385        /// The ltree field to check
386        field: Field,
387        /// Array of lquery patterns to match against
388        patterns: Vec<String>,
389    },
390
391    /// `LTree` depth equals: `nlevel(field) = depth`
392    DepthEq {
393        /// The ltree field to check
394        field: Field,
395        /// The depth value to compare
396        depth: usize,
397    },
398
399    /// `LTree` depth not equals: `nlevel(field) != depth`
400    DepthNeq {
401        /// The ltree field to check
402        field: Field,
403        /// The depth value to compare
404        depth: usize,
405    },
406
407    /// `LTree` depth greater than: `nlevel(field) > depth`
408    DepthGt {
409        /// The ltree field to check
410        field: Field,
411        /// The depth value to compare
412        depth: usize,
413    },
414
415    /// `LTree` depth greater than or equal: `nlevel(field) >= depth`
416    DepthGte {
417        /// The ltree field to check
418        field: Field,
419        /// The depth value to compare
420        depth: usize,
421    },
422
423    /// `LTree` depth less than: `nlevel(field) < depth`
424    DepthLt {
425        /// The ltree field to check
426        field: Field,
427        /// The depth value to compare
428        depth: usize,
429    },
430
431    /// `LTree` depth less than or equal: `nlevel(field) <= depth`
432    DepthLte {
433        /// The ltree field to check
434        field: Field,
435        /// The depth value to compare
436        depth: usize,
437    },
438
439    /// Lowest common ancestor: `lca(field, paths)`
440    ///
441    /// Checks if the ltree field equals the lowest common ancestor of the given paths
442    Lca {
443        /// The ltree field to check
444        field: Field,
445        /// The paths to find LCA of
446        paths: Vec<String>,
447    },
448}
449
450impl WhereOperator {
451    /// Get a human-readable name for this operator
452    pub const fn name(&self) -> &'static str {
453        match self {
454            WhereOperator::Eq(_, _) => "Eq",
455            WhereOperator::Neq(_, _) => "Neq",
456            WhereOperator::Gt(_, _) => "Gt",
457            WhereOperator::Gte(_, _) => "Gte",
458            WhereOperator::Lt(_, _) => "Lt",
459            WhereOperator::Lte(_, _) => "Lte",
460            WhereOperator::In(_, _) => "In",
461            WhereOperator::Nin(_, _) => "Nin",
462            WhereOperator::Contains(_, _) => "Contains",
463            WhereOperator::ArrayContains(_, _) => "ArrayContains",
464            WhereOperator::ArrayContainedBy(_, _) => "ArrayContainedBy",
465            WhereOperator::ArrayOverlaps(_, _) => "ArrayOverlaps",
466            WhereOperator::LenEq(_, _) => "LenEq",
467            WhereOperator::LenGt(_, _) => "LenGt",
468            WhereOperator::LenGte(_, _) => "LenGte",
469            WhereOperator::LenLt(_, _) => "LenLt",
470            WhereOperator::LenLte(_, _) => "LenLte",
471            WhereOperator::Icontains(_, _) => "Icontains",
472            WhereOperator::Startswith(_, _) => "Startswith",
473            WhereOperator::Istartswith(_, _) => "Istartswith",
474            WhereOperator::Endswith(_, _) => "Endswith",
475            WhereOperator::Iendswith(_, _) => "Iendswith",
476            WhereOperator::Like(_, _) => "Like",
477            WhereOperator::Ilike(_, _) => "Ilike",
478            WhereOperator::IsNull(_, _) => "IsNull",
479            WhereOperator::L2Distance { .. } => "L2Distance",
480            WhereOperator::CosineDistance { .. } => "CosineDistance",
481            WhereOperator::InnerProduct { .. } => "InnerProduct",
482            WhereOperator::L1Distance { .. } => "L1Distance",
483            WhereOperator::HammingDistance { .. } => "HammingDistance",
484            WhereOperator::JaccardDistance { .. } => "JaccardDistance",
485            WhereOperator::Matches { .. } => "Matches",
486            WhereOperator::PlainQuery { .. } => "PlainQuery",
487            WhereOperator::PhraseQuery { .. } => "PhraseQuery",
488            WhereOperator::WebsearchQuery { .. } => "WebsearchQuery",
489            WhereOperator::IsIPv4(_) => "IsIPv4",
490            WhereOperator::IsIPv6(_) => "IsIPv6",
491            WhereOperator::IsPrivate(_) => "IsPrivate",
492            WhereOperator::IsPublic(_) => "IsPublic",
493            WhereOperator::IsLoopback(_) => "IsLoopback",
494            WhereOperator::InSubnet { .. } => "InSubnet",
495            WhereOperator::ContainsSubnet { .. } => "ContainsSubnet",
496            WhereOperator::ContainsIP { .. } => "ContainsIP",
497            WhereOperator::IPRangeOverlap { .. } => "IPRangeOverlap",
498            WhereOperator::StrictlyContains(_, _) => "StrictlyContains",
499            WhereOperator::AncestorOf { .. } => "AncestorOf",
500            WhereOperator::DescendantOf { .. } => "DescendantOf",
501            WhereOperator::MatchesLquery { .. } => "MatchesLquery",
502            WhereOperator::MatchesLtxtquery { .. } => "MatchesLtxtquery",
503            WhereOperator::MatchesAnyLquery { .. } => "MatchesAnyLquery",
504            WhereOperator::DepthEq { .. } => "DepthEq",
505            WhereOperator::DepthNeq { .. } => "DepthNeq",
506            WhereOperator::DepthGt { .. } => "DepthGt",
507            WhereOperator::DepthGte { .. } => "DepthGte",
508            WhereOperator::DepthLt { .. } => "DepthLt",
509            WhereOperator::DepthLte { .. } => "DepthLte",
510            WhereOperator::Lca { .. } => "Lca",
511        }
512    }
513
514    /// Validate operator for basic correctness
515    ///
516    /// # Errors
517    ///
518    /// Returns a descriptive error string if any field within the operator fails validation.
519    pub fn validate(&self) -> Result<(), String> {
520        match self {
521            WhereOperator::Eq(f, _)
522            | WhereOperator::Neq(f, _)
523            | WhereOperator::Gt(f, _)
524            | WhereOperator::Gte(f, _)
525            | WhereOperator::Lt(f, _)
526            | WhereOperator::Lte(f, _)
527            | WhereOperator::In(f, _)
528            | WhereOperator::Nin(f, _)
529            | WhereOperator::Contains(f, _)
530            | WhereOperator::ArrayContains(f, _)
531            | WhereOperator::ArrayContainedBy(f, _)
532            | WhereOperator::ArrayOverlaps(f, _)
533            | WhereOperator::LenEq(f, _)
534            | WhereOperator::LenGt(f, _)
535            | WhereOperator::LenGte(f, _)
536            | WhereOperator::LenLt(f, _)
537            | WhereOperator::LenLte(f, _)
538            | WhereOperator::Icontains(f, _)
539            | WhereOperator::Startswith(f, _)
540            | WhereOperator::Istartswith(f, _)
541            | WhereOperator::Endswith(f, _)
542            | WhereOperator::Iendswith(f, _)
543            | WhereOperator::Like(f, _)
544            | WhereOperator::Ilike(f, _)
545            | WhereOperator::IsNull(f, _)
546            | WhereOperator::StrictlyContains(f, _) => f.validate(),
547
548            WhereOperator::L2Distance {
549                field, threshold, ..
550            }
551            | WhereOperator::CosineDistance {
552                field, threshold, ..
553            }
554            | WhereOperator::InnerProduct {
555                field, threshold, ..
556            }
557            | WhereOperator::L1Distance {
558                field, threshold, ..
559            }
560            | WhereOperator::HammingDistance {
561                field, threshold, ..
562            }
563            | WhereOperator::JaccardDistance {
564                field, threshold, ..
565            } => {
566                if !threshold.is_finite() {
567                    return Err(format!(
568                        "Vector distance threshold must be a finite number, got {}",
569                        threshold
570                    ));
571                }
572                field.validate()
573            }
574
575            WhereOperator::Matches { field, .. }
576            | WhereOperator::PlainQuery { field, .. }
577            | WhereOperator::PhraseQuery { field, .. }
578            | WhereOperator::WebsearchQuery { field, .. }
579            | WhereOperator::IsIPv4(field)
580            | WhereOperator::IsIPv6(field)
581            | WhereOperator::IsPrivate(field)
582            | WhereOperator::IsPublic(field)
583            | WhereOperator::IsLoopback(field)
584            | WhereOperator::InSubnet { field, .. }
585            | WhereOperator::ContainsSubnet { field, .. }
586            | WhereOperator::ContainsIP { field, .. }
587            | WhereOperator::IPRangeOverlap { field, .. }
588            | WhereOperator::AncestorOf { field, .. }
589            | WhereOperator::DescendantOf { field, .. }
590            | WhereOperator::MatchesLquery { field, .. }
591            | WhereOperator::MatchesLtxtquery { field, .. }
592            | WhereOperator::MatchesAnyLquery { field, .. }
593            | WhereOperator::DepthEq { field, .. }
594            | WhereOperator::DepthNeq { field, .. }
595            | WhereOperator::DepthGt { field, .. }
596            | WhereOperator::DepthGte { field, .. }
597            | WhereOperator::DepthLt { field, .. }
598            | WhereOperator::DepthLte { field, .. }
599            | WhereOperator::Lca { field, .. } => field.validate(),
600        }
601    }
602}
603
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a JSONB field reference for test fixtures.
    fn jsonb(name: &str) -> Field {
        Field::JsonbField(name.to_string())
    }

    #[test]
    fn test_operator_names() {
        let op = WhereOperator::Eq(jsonb("id"), Value::Number(1.0));
        assert_eq!(op.name(), "Eq");

        let op = WhereOperator::LenGt(jsonb("tags"), 5);
        assert_eq!(op.name(), "LenGt");
    }

    #[test]
    fn test_operator_validation() {
        let op = WhereOperator::Eq(jsonb("name"), Value::String("John".to_string()));
        assert!(op.validate().is_ok());

        let op = WhereOperator::Eq(jsonb("bad-name"), Value::String("John".to_string()));
        assert!(op.validate().is_err());
    }

    #[test]
    fn test_vector_operator_creation() {
        let op = WhereOperator::L2Distance {
            field: jsonb("embedding"),
            vector: vec![0.1, 0.2, 0.3],
            threshold: 0.5,
        };
        assert_eq!(op.name(), "L2Distance");
    }

    #[test]
    fn test_distance_threshold_must_be_finite() {
        // A finite threshold with a well-formed vector is accepted.
        let ok = WhereOperator::CosineDistance {
            field: jsonb("embedding"),
            vector: vec![0.1, 0.2],
            threshold: 0.5,
        };
        assert!(ok.validate().is_ok());

        for &threshold in &[f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
            let vector_op = WhereOperator::CosineDistance {
                field: jsonb("embedding"),
                vector: vec![0.1, 0.2],
                threshold,
            };
            let err = vector_op.validate().unwrap_err();
            assert!(err.contains("finite"), "unexpected error: {}", err);

            let set_op = WhereOperator::JaccardDistance {
                field: jsonb("tags"),
                set: vec!["a".to_string()],
                threshold,
            };
            assert!(set_op.validate().is_err());
        }
    }

    #[test]
    fn test_network_operator_creation() {
        let op = WhereOperator::InSubnet {
            field: jsonb("ip"),
            subnet: "192.168.0.0/24".to_string(),
        };
        assert_eq!(op.name(), "InSubnet");
    }
}