Skip to main content

fraiseql_wire/operators/
where_operator.rs

//! WHERE clause operators
//!
//! Type-safe operator definitions for WHERE clause generation.
//! Supports 50+ operators across 10 categories with both JSONB and direct column sources.

6use super::field::{Field, Value};
7
8/// WHERE clause operators
9///
10/// Supports type-safe, audit-friendly WHERE clause construction
11/// without raw SQL strings.
12///
13/// # Categories
14///
15/// - **Comparison**: Eq, Neq, Gt, Gte, Lt, Lte
16/// - **Array**: In, Nin, Contains, ArrayContains, ArrayContainedBy, ArrayOverlaps
17/// - **Array Length**: LenEq, LenGt, LenGte, LenLt, LenLte
18/// - **String**: Icontains, Startswith, Istartswith, Endswith, Iendswith, Like, Ilike
19/// - **Null**: IsNull
20/// - **Vector Distance**: L2Distance, CosineDistance, InnerProduct, L1Distance, HammingDistance, JaccardDistance
21/// - **Full-Text Search**: Matches, PlainQuery, PhraseQuery, WebsearchQuery
22/// - **Network**: IsIPv4, IsIPv6, IsPrivate, IsPublic, IsLoopback, InSubnet, ContainsSubnet, ContainsIP, IPRangeOverlap
23/// - **JSONB**: StrictlyContains
24/// - **LTree**: AncestorOf, DescendantOf, MatchesLquery, MatchesLtxtquery, MatchesAnyLquery,
25///   DepthEq, DepthNeq, DepthGt, DepthGte, DepthLt, DepthLte, Lca
26#[derive(Debug, Clone)]
27pub enum WhereOperator {
28    // ============ Comparison Operators ============
29    /// Equal: `field = value`
30    Eq(Field, Value),
31
32    /// Not equal: `field != value` or `field <> value`
33    Neq(Field, Value),
34
35    /// Greater than: `field > value`
36    Gt(Field, Value),
37
38    /// Greater than or equal: `field >= value`
39    Gte(Field, Value),
40
41    /// Less than: `field < value`
42    Lt(Field, Value),
43
44    /// Less than or equal: `field <= value`
45    Lte(Field, Value),
46
47    // ============ Array Operators ============
48    /// Array contains value: `field IN (...)`
49    In(Field, Vec<Value>),
50
51    /// Array does not contain value: `field NOT IN (...)`
52    Nin(Field, Vec<Value>),
53
54    /// String contains substring: `field LIKE '%substring%'`
55    Contains(Field, String),
56
57    /// Array contains element: PostgreSQL array operator `@>`
58    /// Generated SQL: `field @> array[value]`
59    ArrayContains(Field, Value),
60
61    /// Array is contained by: PostgreSQL array operator `<@`
62    /// Generated SQL: `field <@ array[value]`
63    ArrayContainedBy(Field, Value),
64
65    /// Arrays overlap: PostgreSQL array operator `&&`
66    /// Generated SQL: `field && array[value]`
67    ArrayOverlaps(Field, Vec<Value>),
68
69    // ============ Array Length Operators ============
70    /// Array length equals: `array_length(field, 1) = value`
71    LenEq(Field, usize),
72
73    /// Array length greater than: `array_length(field, 1) > value`
74    LenGt(Field, usize),
75
76    /// Array length greater than or equal: `array_length(field, 1) >= value`
77    LenGte(Field, usize),
78
79    /// Array length less than: `array_length(field, 1) < value`
80    LenLt(Field, usize),
81
82    /// Array length less than or equal: `array_length(field, 1) <= value`
83    LenLte(Field, usize),
84
85    // ============ String Operators ============
86    /// Case-insensitive contains: `field ILIKE '%substring%'`
87    Icontains(Field, String),
88
89    /// Starts with: `field LIKE 'prefix%'`
90    Startswith(Field, String),
91
92    /// Starts with (case-insensitive): `field ILIKE 'prefix%'`
93    Istartswith(Field, String),
94
95    /// Ends with: `field LIKE '%suffix'`
96    Endswith(Field, String),
97
98    /// Ends with (case-insensitive): `field ILIKE '%suffix'`
99    Iendswith(Field, String),
100
101    /// LIKE pattern matching: `field LIKE pattern`
102    Like(Field, String),
103
104    /// Case-insensitive LIKE: `field ILIKE pattern`
105    Ilike(Field, String),
106
107    // ============ Null Operator ============
108    /// IS NULL: `field IS NULL` or `field IS NOT NULL`
109    ///
110    /// When the boolean is true, generates `IS NULL`
111    /// When false, generates `IS NOT NULL`
112    IsNull(Field, bool),
113
114    // ============ Vector Distance Operators (pgvector) ============
115    /// L2 (Euclidean) distance: `l2_distance(field, vector) < threshold`
116    ///
117    /// Requires pgvector extension.
118    L2Distance {
119        /// The vector field to compare against
120        field: Field,
121        /// The embedding vector for distance calculation
122        vector: Vec<f32>,
123        /// Distance threshold for comparison
124        threshold: f32,
125    },
126
127    /// Cosine distance: `cosine_distance(field, vector) < threshold`
128    ///
129    /// Requires pgvector extension.
130    CosineDistance {
131        /// The vector field to compare against
132        field: Field,
133        /// The embedding vector for distance calculation
134        vector: Vec<f32>,
135        /// Distance threshold for comparison
136        threshold: f32,
137    },
138
139    /// Inner product: `inner_product(field, vector) > threshold`
140    ///
141    /// Requires pgvector extension.
142    InnerProduct {
143        /// The vector field to compare against
144        field: Field,
145        /// The embedding vector for distance calculation
146        vector: Vec<f32>,
147        /// Distance threshold for comparison
148        threshold: f32,
149    },
150
151    /// L1 (Manhattan) distance: `l1_distance(field, vector) < threshold`
152    ///
153    /// Requires pgvector extension.
154    L1Distance {
155        /// The vector field to compare against
156        field: Field,
157        /// The embedding vector for distance calculation
158        vector: Vec<f32>,
159        /// Distance threshold for comparison
160        threshold: f32,
161    },
162
163    /// Hamming distance: `hamming_distance(field, vector) < threshold`
164    ///
165    /// Requires pgvector extension. Works with bit vectors.
166    HammingDistance {
167        /// The vector field to compare against
168        field: Field,
169        /// The embedding vector for distance calculation
170        vector: Vec<f32>,
171        /// Distance threshold for comparison
172        threshold: f32,
173    },
174
175    /// Jaccard distance: `jaccard_distance(field, set) < threshold`
176    ///
177    /// Works with text arrays, measures set overlap.
178    JaccardDistance {
179        /// The field to compare against
180        field: Field,
181        /// The set of values for comparison
182        set: Vec<String>,
183        /// Distance threshold for comparison
184        threshold: f32,
185    },
186
187    // ============ Full-Text Search Operators ============
188    /// Full-text search with language: `field @@ plainto_tsquery(language, query)`
189    ///
190    /// If language is None, defaults to 'english'
191    Matches {
192        /// The text field to search
193        field: Field,
194        /// The search query
195        query: String,
196        /// Optional language for text search (default: english)
197        language: Option<String>,
198    },
199
200    /// Plain text query: `field @@ plainto_tsquery(query)`
201    ///
202    /// Uses no language specification
203    PlainQuery {
204        /// The text field to search
205        field: Field,
206        /// The search query
207        query: String,
208    },
209
210    /// Phrase query with language: `field @@ phraseto_tsquery(language, query)`
211    ///
212    /// If language is None, defaults to 'english'
213    PhraseQuery {
214        /// The text field to search
215        field: Field,
216        /// The search query
217        query: String,
218        /// Optional language for text search (default: english)
219        language: Option<String>,
220    },
221
222    /// Web search query with language: `field @@ websearch_to_tsquery(language, query)`
223    ///
224    /// If language is None, defaults to 'english'
225    WebsearchQuery {
226        /// The text field to search
227        field: Field,
228        /// The search query
229        query: String,
230        /// Optional language for text search (default: english)
231        language: Option<String>,
232    },
233
234    // ============ Network/INET Operators ============
235    /// Check if IP is IPv4: `family(field) = 4`
236    IsIPv4(Field),
237
238    /// Check if IP is IPv6: `family(field) = 6`
239    IsIPv6(Field),
240
241    /// Check if IP is private (RFC1918): matches private ranges
242    IsPrivate(Field),
243
244    /// Check if IP is public (not private): opposite of IsPrivate
245    IsPublic(Field),
246
247    /// Check if IP is loopback: IPv4 127.0.0.0/8 or IPv6 ::1/128
248    IsLoopback(Field),
249
250    /// IP is in subnet: `field << subnet`
251    ///
252    /// The subnet should be in CIDR notation (e.g., "192.168.0.0/24")
253    InSubnet {
254        /// The IP field to check
255        field: Field,
256        /// The CIDR subnet (e.g., "192.168.0.0/24")
257        subnet: String,
258    },
259
260    /// Network contains subnet: `field >> subnet`
261    ///
262    /// The subnet should be in CIDR notation
263    ContainsSubnet {
264        /// The network field to check
265        field: Field,
266        /// The CIDR subnet to check for containment
267        subnet: String,
268    },
269
270    /// Network/range contains IP: `field >> ip`
271    ///
272    /// The IP should be a single address (e.g., "192.168.1.1")
273    ContainsIP {
274        /// The network field to check
275        field: Field,
276        /// The IP address to check for containment
277        ip: String,
278    },
279
280    /// IP ranges overlap: `field && range`
281    ///
282    /// The range should be in CIDR notation
283    IPRangeOverlap {
284        /// The IP range field to check
285        field: Field,
286        /// The IP range to check for overlap
287        range: String,
288    },
289
290    // ============ JSONB Operators ============
291    /// JSONB strictly contains: `field @> value`
292    ///
293    /// Checks if the JSONB field strictly contains the given value
294    StrictlyContains(Field, Value),
295
296    // ============ LTree Operators (Hierarchical) ============
297    /// Ancestor of: `field @> path`
298    ///
299    /// Checks if the ltree field is an ancestor of the given path
300    AncestorOf {
301        /// The ltree field to check
302        field: Field,
303        /// The path to check ancestry against
304        path: String,
305    },
306
307    /// Descendant of: `field <@ path`
308    ///
309    /// Checks if the ltree field is a descendant of the given path
310    DescendantOf {
311        /// The ltree field to check
312        field: Field,
313        /// The path to check descendancy against
314        path: String,
315    },
316
317    /// Matches lquery: `field ~ lquery`
318    ///
319    /// Checks if the ltree field matches the given lquery pattern
320    MatchesLquery {
321        /// The ltree field to check
322        field: Field,
323        /// The lquery pattern to match against
324        pattern: String,
325    },
326
327    /// Matches ltxtquery: `field @ ltxtquery`
328    ///
329    /// Checks if the ltree field matches the given ltxtquery pattern (Boolean query syntax)
330    MatchesLtxtquery {
331        /// The ltree field to check
332        field: Field,
333        /// The ltxtquery pattern to match against (e.g., "Science & !Deprecated")
334        query: String,
335    },
336
337    /// Matches any lquery: `field ? array[lqueries]`
338    ///
339    /// Checks if the ltree field matches any of the given lquery patterns
340    MatchesAnyLquery {
341        /// The ltree field to check
342        field: Field,
343        /// Array of lquery patterns to match against
344        patterns: Vec<String>,
345    },
346
347    /// LTree depth equals: `nlevel(field) = depth`
348    DepthEq {
349        /// The ltree field to check
350        field: Field,
351        /// The depth value to compare
352        depth: usize,
353    },
354
355    /// LTree depth not equals: `nlevel(field) != depth`
356    DepthNeq {
357        /// The ltree field to check
358        field: Field,
359        /// The depth value to compare
360        depth: usize,
361    },
362
363    /// LTree depth greater than: `nlevel(field) > depth`
364    DepthGt {
365        /// The ltree field to check
366        field: Field,
367        /// The depth value to compare
368        depth: usize,
369    },
370
371    /// LTree depth greater than or equal: `nlevel(field) >= depth`
372    DepthGte {
373        /// The ltree field to check
374        field: Field,
375        /// The depth value to compare
376        depth: usize,
377    },
378
379    /// LTree depth less than: `nlevel(field) < depth`
380    DepthLt {
381        /// The ltree field to check
382        field: Field,
383        /// The depth value to compare
384        depth: usize,
385    },
386
387    /// LTree depth less than or equal: `nlevel(field) <= depth`
388    DepthLte {
389        /// The ltree field to check
390        field: Field,
391        /// The depth value to compare
392        depth: usize,
393    },
394
395    /// Lowest common ancestor: `lca(field, paths)`
396    ///
397    /// Checks if the ltree field equals the lowest common ancestor of the given paths
398    Lca {
399        /// The ltree field to check
400        field: Field,
401        /// The paths to find LCA of
402        paths: Vec<String>,
403    },
404}
405
406impl WhereOperator {
407    /// Get a human-readable name for this operator
408    pub fn name(&self) -> &'static str {
409        match self {
410            WhereOperator::Eq(_, _) => "Eq",
411            WhereOperator::Neq(_, _) => "Neq",
412            WhereOperator::Gt(_, _) => "Gt",
413            WhereOperator::Gte(_, _) => "Gte",
414            WhereOperator::Lt(_, _) => "Lt",
415            WhereOperator::Lte(_, _) => "Lte",
416            WhereOperator::In(_, _) => "In",
417            WhereOperator::Nin(_, _) => "Nin",
418            WhereOperator::Contains(_, _) => "Contains",
419            WhereOperator::ArrayContains(_, _) => "ArrayContains",
420            WhereOperator::ArrayContainedBy(_, _) => "ArrayContainedBy",
421            WhereOperator::ArrayOverlaps(_, _) => "ArrayOverlaps",
422            WhereOperator::LenEq(_, _) => "LenEq",
423            WhereOperator::LenGt(_, _) => "LenGt",
424            WhereOperator::LenGte(_, _) => "LenGte",
425            WhereOperator::LenLt(_, _) => "LenLt",
426            WhereOperator::LenLte(_, _) => "LenLte",
427            WhereOperator::Icontains(_, _) => "Icontains",
428            WhereOperator::Startswith(_, _) => "Startswith",
429            WhereOperator::Istartswith(_, _) => "Istartswith",
430            WhereOperator::Endswith(_, _) => "Endswith",
431            WhereOperator::Iendswith(_, _) => "Iendswith",
432            WhereOperator::Like(_, _) => "Like",
433            WhereOperator::Ilike(_, _) => "Ilike",
434            WhereOperator::IsNull(_, _) => "IsNull",
435            WhereOperator::L2Distance { .. } => "L2Distance",
436            WhereOperator::CosineDistance { .. } => "CosineDistance",
437            WhereOperator::InnerProduct { .. } => "InnerProduct",
438            WhereOperator::L1Distance { .. } => "L1Distance",
439            WhereOperator::HammingDistance { .. } => "HammingDistance",
440            WhereOperator::JaccardDistance { .. } => "JaccardDistance",
441            WhereOperator::Matches { .. } => "Matches",
442            WhereOperator::PlainQuery { .. } => "PlainQuery",
443            WhereOperator::PhraseQuery { .. } => "PhraseQuery",
444            WhereOperator::WebsearchQuery { .. } => "WebsearchQuery",
445            WhereOperator::IsIPv4(_) => "IsIPv4",
446            WhereOperator::IsIPv6(_) => "IsIPv6",
447            WhereOperator::IsPrivate(_) => "IsPrivate",
448            WhereOperator::IsPublic(_) => "IsPublic",
449            WhereOperator::IsLoopback(_) => "IsLoopback",
450            WhereOperator::InSubnet { .. } => "InSubnet",
451            WhereOperator::ContainsSubnet { .. } => "ContainsSubnet",
452            WhereOperator::ContainsIP { .. } => "ContainsIP",
453            WhereOperator::IPRangeOverlap { .. } => "IPRangeOverlap",
454            WhereOperator::StrictlyContains(_, _) => "StrictlyContains",
455            WhereOperator::AncestorOf { .. } => "AncestorOf",
456            WhereOperator::DescendantOf { .. } => "DescendantOf",
457            WhereOperator::MatchesLquery { .. } => "MatchesLquery",
458            WhereOperator::MatchesLtxtquery { .. } => "MatchesLtxtquery",
459            WhereOperator::MatchesAnyLquery { .. } => "MatchesAnyLquery",
460            WhereOperator::DepthEq { .. } => "DepthEq",
461            WhereOperator::DepthNeq { .. } => "DepthNeq",
462            WhereOperator::DepthGt { .. } => "DepthGt",
463            WhereOperator::DepthGte { .. } => "DepthGte",
464            WhereOperator::DepthLt { .. } => "DepthLt",
465            WhereOperator::DepthLte { .. } => "DepthLte",
466            WhereOperator::Lca { .. } => "Lca",
467        }
468    }
469
470    /// Validate operator for basic correctness
471    pub fn validate(&self) -> Result<(), String> {
472        match self {
473            WhereOperator::Eq(f, _)
474            | WhereOperator::Neq(f, _)
475            | WhereOperator::Gt(f, _)
476            | WhereOperator::Gte(f, _)
477            | WhereOperator::Lt(f, _)
478            | WhereOperator::Lte(f, _)
479            | WhereOperator::In(f, _)
480            | WhereOperator::Nin(f, _)
481            | WhereOperator::Contains(f, _)
482            | WhereOperator::ArrayContains(f, _)
483            | WhereOperator::ArrayContainedBy(f, _)
484            | WhereOperator::ArrayOverlaps(f, _)
485            | WhereOperator::LenEq(f, _)
486            | WhereOperator::LenGt(f, _)
487            | WhereOperator::LenGte(f, _)
488            | WhereOperator::LenLt(f, _)
489            | WhereOperator::LenLte(f, _)
490            | WhereOperator::Icontains(f, _)
491            | WhereOperator::Startswith(f, _)
492            | WhereOperator::Istartswith(f, _)
493            | WhereOperator::Endswith(f, _)
494            | WhereOperator::Iendswith(f, _)
495            | WhereOperator::Like(f, _)
496            | WhereOperator::Ilike(f, _)
497            | WhereOperator::IsNull(f, _)
498            | WhereOperator::StrictlyContains(f, _) => f.validate(),
499
500            WhereOperator::L2Distance { field, .. }
501            | WhereOperator::CosineDistance { field, .. }
502            | WhereOperator::InnerProduct { field, .. }
503            | WhereOperator::L1Distance { field, .. }
504            | WhereOperator::HammingDistance { field, .. }
505            | WhereOperator::JaccardDistance { field, .. }
506            | WhereOperator::Matches { field, .. }
507            | WhereOperator::PlainQuery { field, .. }
508            | WhereOperator::PhraseQuery { field, .. }
509            | WhereOperator::WebsearchQuery { field, .. }
510            | WhereOperator::IsIPv4(field)
511            | WhereOperator::IsIPv6(field)
512            | WhereOperator::IsPrivate(field)
513            | WhereOperator::IsPublic(field)
514            | WhereOperator::IsLoopback(field)
515            | WhereOperator::InSubnet { field, .. }
516            | WhereOperator::ContainsSubnet { field, .. }
517            | WhereOperator::ContainsIP { field, .. }
518            | WhereOperator::IPRangeOverlap { field, .. }
519            | WhereOperator::AncestorOf { field, .. }
520            | WhereOperator::DescendantOf { field, .. }
521            | WhereOperator::MatchesLquery { field, .. }
522            | WhereOperator::MatchesLtxtquery { field, .. }
523            | WhereOperator::MatchesAnyLquery { field, .. }
524            | WhereOperator::DepthEq { field, .. }
525            | WhereOperator::DepthNeq { field, .. }
526            | WhereOperator::DepthGt { field, .. }
527            | WhereOperator::DepthGte { field, .. }
528            | WhereOperator::DepthLt { field, .. }
529            | WhereOperator::DepthLte { field, .. }
530            | WhereOperator::Lca { field, .. } => field.validate(),
531        }
532    }
533}
534
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a JSONB field reference.
    fn jsonb(name: &str) -> Field {
        Field::JsonbField(name.to_string())
    }

    #[test]
    fn test_operator_names() {
        let op = WhereOperator::Eq(jsonb("id"), Value::Number(1.0));
        assert_eq!(op.name(), "Eq");

        let op = WhereOperator::LenGt(jsonb("tags"), 5);
        assert_eq!(op.name(), "LenGt");
    }

    #[test]
    fn test_operator_names_across_categories() {
        let cases = [
            (WhereOperator::IsNull(jsonb("deleted_at"), true), "IsNull"),
            (WhereOperator::IsIPv4(jsonb("ip")), "IsIPv4"),
            (
                WhereOperator::Matches {
                    field: jsonb("body"),
                    query: "rust".to_string(),
                    language: None,
                },
                "Matches",
            ),
            (
                WhereOperator::DepthGte {
                    field: jsonb("path"),
                    depth: 2,
                },
                "DepthGte",
            ),
            (
                WhereOperator::Lca {
                    field: jsonb("path"),
                    paths: vec!["a.b".to_string(), "a.c".to_string()],
                },
                "Lca",
            ),
        ];
        for (op, expected) in &cases {
            assert_eq!(op.name(), *expected);
        }
    }

    #[test]
    fn test_operator_validation() {
        let op = WhereOperator::Eq(jsonb("name"), Value::String("John".to_string()));
        assert!(op.validate().is_ok());

        let op = WhereOperator::Eq(jsonb("bad-name"), Value::String("John".to_string()));
        assert!(op.validate().is_err());
    }

    #[test]
    fn test_struct_and_single_field_variants_validate_field() {
        let op = WhereOperator::InSubnet {
            field: jsonb("name"),
            subnet: "10.0.0.0/8".to_string(),
        };
        assert!(op.validate().is_ok());

        let op = WhereOperator::InSubnet {
            field: jsonb("bad-name"),
            subnet: "10.0.0.0/8".to_string(),
        };
        assert!(op.validate().is_err());

        let op = WhereOperator::IsLoopback(jsonb("bad-name"));
        assert!(op.validate().is_err());
    }

    #[test]
    fn test_vector_operator_creation() {
        let op = WhereOperator::L2Distance {
            field: jsonb("embedding"),
            vector: vec![0.1, 0.2, 0.3],
            threshold: 0.5,
        };
        assert_eq!(op.name(), "L2Distance");
    }

    #[test]
    fn test_network_operator_creation() {
        let op = WhereOperator::InSubnet {
            field: jsonb("ip"),
            subnet: "192.168.0.0/24".to_string(),
        };
        assert_eq!(op.name(), "InSubnet");
    }
}