Skip to main content

fraiseql_wire/operators/
where_operator.rs

//! WHERE clause operators
//!
//! Type-safe operator definitions for WHERE clause generation.
//! Covers comparison, membership/array, array-length, string, null, vector-distance,
//! full-text-search, and network operators, with both JSONB and direct column sources.
5
6use super::field::{Field, Value};
7
8/// WHERE clause operators
9///
10/// Supports type-safe, audit-friendly WHERE clause construction
11/// without raw SQL strings.
12///
13/// # Categories
14///
15/// - **Comparison**: Eq, Neq, Gt, Gte, Lt, Lte
16/// - **Array**: In, Nin, Contains, ArrayContains, ArrayContainedBy, ArrayOverlaps
17/// - **Array Length**: LenEq, LenGt, LenGte, LenLt, LenLte
18/// - **String**: Icontains, Startswith, Endswith, Like, Ilike
19/// - **Null**: IsNull
20/// - **Vector Distance**: L2Distance, CosineDistance, InnerProduct, JaccardDistance
21/// - **Full-Text Search**: Matches, PlainQuery, PhraseQuery, WebsearchQuery
22/// - **Network**: IsIPv4, IsIPv6, IsPrivate, IsLoopback, InSubnet, ContainsSubnet, ContainsIP, IPRangeOverlap
23#[derive(Debug, Clone)]
24pub enum WhereOperator {
25    // ============ Comparison Operators ============
26    /// Equal: `field = value`
27    Eq(Field, Value),
28
29    /// Not equal: `field != value` or `field <> value`
30    Neq(Field, Value),
31
32    /// Greater than: `field > value`
33    Gt(Field, Value),
34
35    /// Greater than or equal: `field >= value`
36    Gte(Field, Value),
37
38    /// Less than: `field < value`
39    Lt(Field, Value),
40
41    /// Less than or equal: `field <= value`
42    Lte(Field, Value),
43
44    // ============ Array Operators ============
45    /// Array contains value: `field IN (...)`
46    In(Field, Vec<Value>),
47
48    /// Array does not contain value: `field NOT IN (...)`
49    Nin(Field, Vec<Value>),
50
51    /// String contains substring: `field LIKE '%substring%'`
52    Contains(Field, String),
53
54    /// Array contains element: PostgreSQL array operator `@>`
55    /// Generated SQL: `field @> array[value]`
56    ArrayContains(Field, Value),
57
58    /// Array is contained by: PostgreSQL array operator `<@`
59    /// Generated SQL: `field <@ array[value]`
60    ArrayContainedBy(Field, Value),
61
62    /// Arrays overlap: PostgreSQL array operator `&&`
63    /// Generated SQL: `field && array[value]`
64    ArrayOverlaps(Field, Vec<Value>),
65
66    // ============ Array Length Operators ============
67    /// Array length equals: `array_length(field, 1) = value`
68    LenEq(Field, usize),
69
70    /// Array length greater than: `array_length(field, 1) > value`
71    LenGt(Field, usize),
72
73    /// Array length greater than or equal: `array_length(field, 1) >= value`
74    LenGte(Field, usize),
75
76    /// Array length less than: `array_length(field, 1) < value`
77    LenLt(Field, usize),
78
79    /// Array length less than or equal: `array_length(field, 1) <= value`
80    LenLte(Field, usize),
81
82    // ============ String Operators ============
83    /// Case-insensitive contains: `field ILIKE '%substring%'`
84    Icontains(Field, String),
85
86    /// Starts with: `field LIKE 'prefix%'`
87    Startswith(Field, String),
88
89    /// Ends with: `field LIKE '%suffix'`
90    Endswith(Field, String),
91
92    /// LIKE pattern matching: `field LIKE pattern`
93    Like(Field, String),
94
95    /// Case-insensitive LIKE: `field ILIKE pattern`
96    Ilike(Field, String),
97
98    // ============ Null Operator ============
99    /// IS NULL: `field IS NULL` or `field IS NOT NULL`
100    ///
101    /// When the boolean is true, generates `IS NULL`
102    /// When false, generates `IS NOT NULL`
103    IsNull(Field, bool),
104
105    // ============ Vector Distance Operators (pgvector) ============
106    /// L2 (Euclidean) distance: `l2_distance(field, vector) < threshold`
107    ///
108    /// Requires pgvector extension.
109    L2Distance {
110        /// The vector field to compare against
111        field: Field,
112        /// The embedding vector for distance calculation
113        vector: Vec<f32>,
114        /// Distance threshold for comparison
115        threshold: f32,
116    },
117
118    /// Cosine distance: `cosine_distance(field, vector) < threshold`
119    ///
120    /// Requires pgvector extension.
121    CosineDistance {
122        /// The vector field to compare against
123        field: Field,
124        /// The embedding vector for distance calculation
125        vector: Vec<f32>,
126        /// Distance threshold for comparison
127        threshold: f32,
128    },
129
130    /// Inner product: `inner_product(field, vector) > threshold`
131    ///
132    /// Requires pgvector extension.
133    InnerProduct {
134        /// The vector field to compare against
135        field: Field,
136        /// The embedding vector for distance calculation
137        vector: Vec<f32>,
138        /// Distance threshold for comparison
139        threshold: f32,
140    },
141
142    /// Jaccard distance: `jaccard_distance(field, set) < threshold`
143    ///
144    /// Works with text arrays, measures set overlap.
145    JaccardDistance {
146        /// The field to compare against
147        field: Field,
148        /// The set of values for comparison
149        set: Vec<String>,
150        /// Distance threshold for comparison
151        threshold: f32,
152    },
153
154    // ============ Full-Text Search Operators ============
155    /// Full-text search with language: `field @@ plainto_tsquery(language, query)`
156    ///
157    /// If language is None, defaults to 'english'
158    Matches {
159        /// The text field to search
160        field: Field,
161        /// The search query
162        query: String,
163        /// Optional language for text search (default: english)
164        language: Option<String>,
165    },
166
167    /// Plain text query: `field @@ plainto_tsquery(query)`
168    ///
169    /// Uses no language specification
170    PlainQuery {
171        /// The text field to search
172        field: Field,
173        /// The search query
174        query: String,
175    },
176
177    /// Phrase query with language: `field @@ phraseto_tsquery(language, query)`
178    ///
179    /// If language is None, defaults to 'english'
180    PhraseQuery {
181        /// The text field to search
182        field: Field,
183        /// The search query
184        query: String,
185        /// Optional language for text search (default: english)
186        language: Option<String>,
187    },
188
189    /// Web search query with language: `field @@ websearch_to_tsquery(language, query)`
190    ///
191    /// If language is None, defaults to 'english'
192    WebsearchQuery {
193        /// The text field to search
194        field: Field,
195        /// The search query
196        query: String,
197        /// Optional language for text search (default: english)
198        language: Option<String>,
199    },
200
201    // ============ Network/INET Operators ============
202    /// Check if IP is IPv4: `family(field) = 4`
203    IsIPv4(Field),
204
205    /// Check if IP is IPv6: `family(field) = 6`
206    IsIPv6(Field),
207
208    /// Check if IP is private (RFC1918): matches private ranges
209    IsPrivate(Field),
210
211    /// Check if IP is loopback: IPv4 127.0.0.0/8 or IPv6 ::1/128
212    IsLoopback(Field),
213
214    /// IP is in subnet: `field << subnet`
215    ///
216    /// The subnet should be in CIDR notation (e.g., "192.168.0.0/24")
217    InSubnet {
218        /// The IP field to check
219        field: Field,
220        /// The CIDR subnet (e.g., "192.168.0.0/24")
221        subnet: String,
222    },
223
224    /// Network contains subnet: `field >> subnet`
225    ///
226    /// The subnet should be in CIDR notation
227    ContainsSubnet {
228        /// The network field to check
229        field: Field,
230        /// The CIDR subnet to check for containment
231        subnet: String,
232    },
233
234    /// Network/range contains IP: `field >> ip`
235    ///
236    /// The IP should be a single address (e.g., "192.168.1.1")
237    ContainsIP {
238        /// The network field to check
239        field: Field,
240        /// The IP address to check for containment
241        ip: String,
242    },
243
244    /// IP ranges overlap: `field && range`
245    ///
246    /// The range should be in CIDR notation
247    IPRangeOverlap {
248        /// The IP range field to check
249        field: Field,
250        /// The IP range to check for overlap
251        range: String,
252    },
253}
254
255impl WhereOperator {
256    /// Get a human-readable name for this operator
257    pub fn name(&self) -> &'static str {
258        match self {
259            WhereOperator::Eq(_, _) => "Eq",
260            WhereOperator::Neq(_, _) => "Neq",
261            WhereOperator::Gt(_, _) => "Gt",
262            WhereOperator::Gte(_, _) => "Gte",
263            WhereOperator::Lt(_, _) => "Lt",
264            WhereOperator::Lte(_, _) => "Lte",
265            WhereOperator::In(_, _) => "In",
266            WhereOperator::Nin(_, _) => "Nin",
267            WhereOperator::Contains(_, _) => "Contains",
268            WhereOperator::ArrayContains(_, _) => "ArrayContains",
269            WhereOperator::ArrayContainedBy(_, _) => "ArrayContainedBy",
270            WhereOperator::ArrayOverlaps(_, _) => "ArrayOverlaps",
271            WhereOperator::LenEq(_, _) => "LenEq",
272            WhereOperator::LenGt(_, _) => "LenGt",
273            WhereOperator::LenGte(_, _) => "LenGte",
274            WhereOperator::LenLt(_, _) => "LenLt",
275            WhereOperator::LenLte(_, _) => "LenLte",
276            WhereOperator::Icontains(_, _) => "Icontains",
277            WhereOperator::Startswith(_, _) => "Startswith",
278            WhereOperator::Endswith(_, _) => "Endswith",
279            WhereOperator::Like(_, _) => "Like",
280            WhereOperator::Ilike(_, _) => "Ilike",
281            WhereOperator::IsNull(_, _) => "IsNull",
282            WhereOperator::L2Distance { .. } => "L2Distance",
283            WhereOperator::CosineDistance { .. } => "CosineDistance",
284            WhereOperator::InnerProduct { .. } => "InnerProduct",
285            WhereOperator::JaccardDistance { .. } => "JaccardDistance",
286            WhereOperator::Matches { .. } => "Matches",
287            WhereOperator::PlainQuery { .. } => "PlainQuery",
288            WhereOperator::PhraseQuery { .. } => "PhraseQuery",
289            WhereOperator::WebsearchQuery { .. } => "WebsearchQuery",
290            WhereOperator::IsIPv4(_) => "IsIPv4",
291            WhereOperator::IsIPv6(_) => "IsIPv6",
292            WhereOperator::IsPrivate(_) => "IsPrivate",
293            WhereOperator::IsLoopback(_) => "IsLoopback",
294            WhereOperator::InSubnet { .. } => "InSubnet",
295            WhereOperator::ContainsSubnet { .. } => "ContainsSubnet",
296            WhereOperator::ContainsIP { .. } => "ContainsIP",
297            WhereOperator::IPRangeOverlap { .. } => "IPRangeOverlap",
298        }
299    }
300
301    /// Validate operator for basic correctness
302    pub fn validate(&self) -> Result<(), String> {
303        match self {
304            WhereOperator::Eq(f, _)
305            | WhereOperator::Neq(f, _)
306            | WhereOperator::Gt(f, _)
307            | WhereOperator::Gte(f, _)
308            | WhereOperator::Lt(f, _)
309            | WhereOperator::Lte(f, _)
310            | WhereOperator::In(f, _)
311            | WhereOperator::Nin(f, _)
312            | WhereOperator::Contains(f, _)
313            | WhereOperator::ArrayContains(f, _)
314            | WhereOperator::ArrayContainedBy(f, _)
315            | WhereOperator::ArrayOverlaps(f, _)
316            | WhereOperator::LenEq(f, _)
317            | WhereOperator::LenGt(f, _)
318            | WhereOperator::LenGte(f, _)
319            | WhereOperator::LenLt(f, _)
320            | WhereOperator::LenLte(f, _)
321            | WhereOperator::Icontains(f, _)
322            | WhereOperator::Startswith(f, _)
323            | WhereOperator::Endswith(f, _)
324            | WhereOperator::Like(f, _)
325            | WhereOperator::Ilike(f, _)
326            | WhereOperator::IsNull(f, _) => f.validate(),
327
328            WhereOperator::L2Distance { field, .. }
329            | WhereOperator::CosineDistance { field, .. }
330            | WhereOperator::InnerProduct { field, .. }
331            | WhereOperator::JaccardDistance { field, .. }
332            | WhereOperator::Matches { field, .. }
333            | WhereOperator::PlainQuery { field, .. }
334            | WhereOperator::PhraseQuery { field, .. }
335            | WhereOperator::WebsearchQuery { field, .. }
336            | WhereOperator::IsIPv4(field)
337            | WhereOperator::IsIPv6(field)
338            | WhereOperator::IsPrivate(field)
339            | WhereOperator::IsLoopback(field)
340            | WhereOperator::InSubnet { field, .. }
341            | WhereOperator::ContainsSubnet { field, .. }
342            | WhereOperator::ContainsIP { field, .. }
343            | WhereOperator::IPRangeOverlap { field, .. } => field.validate(),
344        }
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a JSONB-backed field with the given name.
    fn jsonb(name: &str) -> Field {
        Field::JsonbField(name.to_string())
    }

    #[test]
    fn test_operator_names() {
        let op = WhereOperator::Eq(jsonb("id"), Value::Number(1.0));
        assert_eq!(op.name(), "Eq");

        let op = WhereOperator::LenGt(jsonb("tags"), 5);
        assert_eq!(op.name(), "LenGt");
    }

    #[test]
    fn test_operator_names_across_categories() {
        let cases = [
            (
                WhereOperator::In(jsonb("id"), vec![Value::Number(1.0)]),
                "In",
            ),
            (
                WhereOperator::Contains(jsonb("name"), "oh".to_string()),
                "Contains",
            ),
            (WhereOperator::IsNull(jsonb("name"), true), "IsNull"),
            (
                WhereOperator::Matches {
                    field: jsonb("content"),
                    query: "rust".to_string(),
                    language: None,
                },
                "Matches",
            ),
            (WhereOperator::IsIPv4(jsonb("ip")), "IsIPv4"),
            (
                WhereOperator::ContainsIP {
                    field: jsonb("network"),
                    ip: "192.168.1.1".to_string(),
                },
                "ContainsIP",
            ),
        ];
        for (op, expected) in cases.iter() {
            assert_eq!(op.name(), *expected);
        }
    }

    #[test]
    fn test_operator_validation() {
        let op = WhereOperator::Eq(jsonb("name"), Value::String("John".to_string()));
        assert!(op.validate().is_ok());

        let op = WhereOperator::Eq(jsonb("bad-name"), Value::String("John".to_string()));
        assert!(op.validate().is_err());
    }

    #[test]
    fn test_validation_covers_every_variant_shape() {
        // Tuple variant carrying a list of values.
        let op = WhereOperator::In(jsonb("bad-name"), vec![Value::Number(1.0)]);
        assert!(op.validate().is_err());

        // Struct-like variants must validate their `field` as well.
        let op = WhereOperator::L2Distance {
            field: jsonb("embedding"),
            vector: vec![0.1, 0.2, 0.3],
            threshold: 0.5,
        };
        assert!(op.validate().is_ok());

        let op = WhereOperator::L2Distance {
            field: jsonb("bad-name"),
            vector: vec![0.1, 0.2, 0.3],
            threshold: 0.5,
        };
        assert!(op.validate().is_err());

        let op = WhereOperator::Matches {
            field: jsonb("bad-name"),
            query: "rust".to_string(),
            language: Some("english".to_string()),
        };
        assert!(op.validate().is_err());

        let op = WhereOperator::InSubnet {
            field: jsonb("bad-name"),
            subnet: "192.168.0.0/24".to_string(),
        };
        assert!(op.validate().is_err());

        // Unary network variants carry only a field.
        let op = WhereOperator::IsIPv4(jsonb("bad-name"));
        assert!(op.validate().is_err());
    }

    #[test]
    fn test_vector_operator_creation() {
        let op = WhereOperator::L2Distance {
            field: jsonb("embedding"),
            vector: vec![0.1, 0.2, 0.3],
            threshold: 0.5,
        };
        assert_eq!(op.name(), "L2Distance");
    }

    #[test]
    fn test_network_operator_creation() {
        let op = WhereOperator::InSubnet {
            field: jsonb("ip"),
            subnet: "192.168.0.0/24".to_string(),
        };
        assert_eq!(op.name(), "InSubnet");
    }
}