// nautilus_core/expr.rs
//! Expression AST for building WHERE clauses and filters.
2
3use crate::select::Select;
4use crate::value::Value;
5
// Marker function names for pgvector distance operators. `Expr::vector_distance`
// builds a function call under one of these names; the renderer emits the
// corresponding pgvector infix operator instead of a real function call.

/// Function-name marker rendered as the pgvector L2-distance operator `<->`.
pub const VECTOR_L2_DISTANCE_FUNCTION: &str = "__nautilus_vector_l2_distance";
/// Function-name marker rendered as the pgvector inner-product operator `<#>`.
pub const VECTOR_INNER_PRODUCT_FUNCTION: &str = "__nautilus_vector_inner_product";
/// Function-name marker rendered as the pgvector cosine-distance operator `<=>`.
pub const VECTOR_COSINE_DISTANCE_FUNCTION: &str = "__nautilus_vector_cosine_distance";
12
/// Quantifier applied by a relation predicate; produced by the generated
/// relation helpers (`some` / `none` / `every`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RelationFilterOp {
    /// Satisfied when one or more related rows match the child filter.
    Some,
    /// Satisfied when not a single related row matches the child filter.
    None,
    /// Satisfied when all related rows match the child filter.
    Every,
}
23
24/// Metadata needed to render or serialize a relation predicate.
25#[derive(Debug, Clone, PartialEq)]
26pub struct RelationFilter {
27    /// Logical relation field name on the parent model.
28    pub field: String,
29    /// Database table name of the parent model.
30    pub parent_table: String,
31    /// Database table name of the related child model.
32    pub target_table: String,
33    /// Child-side foreign-key column name.
34    pub fk_db: String,
35    /// Parent-side referenced key column name.
36    pub pk_db: String,
37    /// Child filter to apply inside the relation predicate.
38    pub filter: Box<Expr>,
39}
40
/// Binary operators for expressions.
///
/// This is a plain fieldless enum, so it is `Copy` (like [`RelationFilterOp`])
/// and `Hash`, allowing operators to be passed by value and used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOp {
    /// Equality (`=`).
    Eq,
    /// Not equal (`!=`).
    Ne,
    /// Less than (`<`).
    Lt,
    /// Less than or equal (`<=`).
    Le,
    /// Greater than (`>`).
    Gt,
    /// Greater than or equal (`>=`).
    Ge,
    /// Logical AND.
    And,
    /// Logical OR.
    Or,
    /// LIKE pattern matching.
    Like,
    /// Array contains (`@>` in PostgreSQL).
    ArrayContains,
    /// Array is contained by (`<@` in PostgreSQL).
    ArrayContainedBy,
    /// Array overlaps (`&&` in PostgreSQL).
    ArrayOverlaps,
    /// IN list membership.
    In,
    /// NOT IN list membership.
    NotIn,
}
73
/// A SQL fragment emitted verbatim into the query text, **not** bound as a
/// parameter (e.g. literal key names in `json_build_object`).
///
/// Because the contained text bypasses parameter binding, it must never carry
/// untrusted user input. This newtype makes that contract explicit: a value can
/// only be created through [`LiteralSql::from_static`] or [`LiteralSql::trusted`]
/// (a deliberately-named, greppable assertion that the caller vetted the
/// string). The inner `String` is private, so an `Expr::Literal` can never be
/// built directly from a bare runtime string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LiteralSql(String);

impl LiteralSql {
    /// Build from a `&'static str`, normally a string literal compiled into the
    /// binary.
    ///
    /// Note that `'static` alone does not prove the text is trusted: runtime
    /// data can be turned into a `&'static str` via `Box::leak`. Only pass
    /// string literals (or equally trusted text) here; never leak user input
    /// to satisfy this signature.
    #[must_use]
    pub fn from_static(text: &'static str) -> Self {
        Self(text.to_owned())
    }

    /// Build from a runtime string the caller asserts is trusted (e.g. a schema
    /// column name), never raw user input. Prefer [`Self::from_static`] when the
    /// value is known at compile time.
    #[must_use]
    pub fn trusted(text: impl Into<String>) -> Self {
        Self(text.into())
    }

    /// The underlying SQL text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}
108
109/// Expression node for WHERE clauses and filters.
110#[derive(Debug, Clone, PartialEq)]
111pub enum Expr {
112    /// Column reference.
113    Column(String),
114    /// Parameter placeholder.
115    Param(Value),
116    /// Binary operation.
117    Binary {
118        /// Left operand.
119        left: Box<Expr>,
120        /// Operator.
121        op: BinaryOp,
122        /// Right operand.
123        right: Box<Expr>,
124    },
125    /// Logical NOT.
126    Not(Box<Expr>),
127    /// Function call (e.g., json_agg, COALESCE).
128    FunctionCall {
129        /// Function name.
130        name: String,
131        /// Function arguments.
132        args: Vec<Expr>,
133    },
134    /// SQL FILTER clause for aggregate functions (PostgreSQL).
135    Filter {
136        /// The aggregation expression.
137        expr: Box<Expr>,
138        /// The filter predicate.
139        predicate: Box<Expr>,
140    },
141    /// EXISTS subquery predicate — compiles to `EXISTS (SELECT ...)`.
142    Exists(Box<Select>),
143    /// NOT EXISTS subquery predicate — compiles to `NOT EXISTS (SELECT ...)`.
144    NotExists(Box<Select>),
145    /// Relation predicate (`some` / `none` / `every`) with explicit relation metadata.
146    Relation {
147        /// Which relation operator to apply.
148        op: RelationFilterOp,
149        /// Relation metadata and nested child filter.
150        relation: Box<RelationFilter>,
151    },
152    /// Scalar subquery — compiles to `(SELECT ...)`.
153    ///
154    /// The inner SELECT must return exactly one row and one column.  Used for
155    /// correlated aggregate sub‑queries (e.g. relation includes) that must not
156    /// produce a cartesian product when two or more relations are joined.
157    ScalarSubquery(Box<Select>),
158    /// IS NULL check — compiles to `expr IS NULL`.
159    IsNull(Box<Expr>),
160    /// IS NOT NULL check — compiles to `expr IS NOT NULL`.
161    IsNotNull(Box<Expr>),
162    /// A raw SQL string literal emitted verbatim (no parameter binding).
163    ///
164    /// Use this sparingly — only for values that must appear as SQL literals
165    /// rather than positional parameters (e.g. keys in `json_build_object`).
166    /// The [`LiteralSql`] newtype enforces that the text is trusted, never raw
167    /// user input.
168    Literal(LiteralSql),
169    /// An ordered list of expressions for use in IN / NOT IN clauses.
170    ///
171    /// Rendered as a comma-separated sequence; the surrounding parentheses are
172    /// added by the IN/NOT IN rendering path in each dialect.
173    List(Vec<Expr>),
174    /// CASE WHEN … THEN … ELSE NULL END.
175    CaseWhen {
176        /// The condition.
177        condition: Box<Expr>,
178        /// The THEN result.
179        then: Box<Expr>,
180    },
181    /// SQL wildcard `*` — used inside aggregate functions like `COUNT(*)`.
182    Star,
183}
184
185impl Expr {
186    /// Creates a column reference.
187    pub fn column(name: impl Into<String>) -> Self {
188        Expr::Column(name.into())
189    }
190
191    /// Creates a parameter placeholder.
192    pub fn param(value: impl Into<Value>) -> Self {
193        Expr::Param(value.into())
194    }
195
196    /// Creates an equality comparison (`=`).
197    #[must_use]
198    pub fn eq(self, other: Expr) -> Self {
199        Expr::Binary {
200            left: Box::new(self),
201            op: BinaryOp::Eq,
202            right: Box::new(other),
203        }
204    }
205
206    /// Creates a not-equal comparison (`!=`).
207    #[must_use]
208    pub fn ne(self, other: Expr) -> Self {
209        Expr::Binary {
210            left: Box::new(self),
211            op: BinaryOp::Ne,
212            right: Box::new(other),
213        }
214    }
215
216    /// Creates a less-than comparison (`<`).
217    #[must_use]
218    pub fn lt(self, other: Expr) -> Self {
219        Expr::Binary {
220            left: Box::new(self),
221            op: BinaryOp::Lt,
222            right: Box::new(other),
223        }
224    }
225
226    /// Creates a less-than-or-equal comparison (`<=`).
227    #[must_use]
228    pub fn le(self, other: Expr) -> Self {
229        Expr::Binary {
230            left: Box::new(self),
231            op: BinaryOp::Le,
232            right: Box::new(other),
233        }
234    }
235
236    /// Creates a greater-than comparison (`>`).
237    #[must_use]
238    pub fn gt(self, other: Expr) -> Self {
239        Expr::Binary {
240            left: Box::new(self),
241            op: BinaryOp::Gt,
242            right: Box::new(other),
243        }
244    }
245
246    /// Creates a greater-than-or-equal comparison (`>=`).
247    #[must_use]
248    pub fn ge(self, other: Expr) -> Self {
249        Expr::Binary {
250            left: Box::new(self),
251            op: BinaryOp::Ge,
252            right: Box::new(other),
253        }
254    }
255
256    /// Creates a logical AND.
257    #[must_use]
258    pub fn and(self, other: Expr) -> Self {
259        Expr::Binary {
260            left: Box::new(self),
261            op: BinaryOp::And,
262            right: Box::new(other),
263        }
264    }
265
266    /// Creates a logical OR.
267    #[must_use]
268    pub fn or(self, other: Expr) -> Self {
269        Expr::Binary {
270            left: Box::new(self),
271            op: BinaryOp::Or,
272            right: Box::new(other),
273        }
274    }
275
276    /// Creates a LIKE pattern match.
277    #[must_use]
278    pub fn like(self, pattern: Expr) -> Self {
279        Expr::Binary {
280            left: Box::new(self),
281            op: BinaryOp::Like,
282            right: Box::new(pattern),
283        }
284    }
285
286    /// Creates an IN list membership check.
287    #[must_use]
288    pub fn in_list(self, exprs: Vec<Expr>) -> Self {
289        Expr::Binary {
290            left: Box::new(self),
291            op: BinaryOp::In,
292            right: Box::new(Expr::List(exprs)),
293        }
294    }
295
296    /// Creates a NOT IN list membership check.
297    #[must_use]
298    pub fn not_in_list(self, exprs: Vec<Expr>) -> Self {
299        Expr::Binary {
300            left: Box::new(self),
301            op: BinaryOp::NotIn,
302            right: Box::new(Expr::List(exprs)),
303        }
304    }
305
306    /// Creates a function call expression.
307    pub fn function_call(name: impl Into<String>, args: Vec<Expr>) -> Self {
308        Expr::FunctionCall {
309            name: name.into(),
310            args,
311        }
312    }
313
314    /// Creates an internal vector-distance expression for pgvector ordering.
315    pub fn vector_distance(metric: crate::args::VectorMetric, left: Expr, right: Expr) -> Self {
316        let function = match metric {
317            crate::args::VectorMetric::L2 => VECTOR_L2_DISTANCE_FUNCTION,
318            crate::args::VectorMetric::InnerProduct => VECTOR_INNER_PRODUCT_FUNCTION,
319            crate::args::VectorMetric::Cosine => VECTOR_COSINE_DISTANCE_FUNCTION,
320        };
321        Expr::function_call(function, vec![left, right])
322    }
323
324    /// Creates a json_agg() aggregate function.
325    pub fn json_agg(expr: Expr) -> Self {
326        Expr::FunctionCall {
327            name: "json_agg".to_string(),
328            args: vec![expr],
329        }
330    }
331
332    /// Creates a json_build_object() function with key-value pairs.
333    ///
334    /// Keys are emitted as SQL string literals (not bound parameters) because
335    /// `json_build_object` requires literal key names in all supported dialects.
336    ///
337    /// # Safety
338    ///
339    /// Keys must be static/compile-time strings. Never pass untrusted input as
340    /// a key — use [`Expr::param`] for that and handle it in application logic.
341    pub fn json_build_object(pairs: Vec<(String, Expr)>) -> Self {
342        let args: Vec<Expr> = pairs
343            .into_iter()
344            .flat_map(|(key, value)| vec![Expr::Literal(LiteralSql::trusted(key)), value])
345            .collect();
346
347        Expr::FunctionCall {
348            name: "json_build_object".to_string(),
349            args,
350        }
351    }
352
353    /// Creates a COALESCE() function to return first non-NULL value.
354    pub fn coalesce(exprs: Vec<Expr>) -> Self {
355        Expr::FunctionCall {
356            name: "COALESCE".to_string(),
357            args: exprs,
358        }
359    }
360
361    /// Creates an IS NOT NULL check — compiles to `expr IS NOT NULL`.
362    #[must_use]
363    pub fn is_not_null(self) -> Self {
364        Expr::IsNotNull(Box::new(self))
365    }
366
367    /// Creates an IS NULL check — compiles to `expr IS NULL`.
368    #[must_use]
369    pub fn is_null(self) -> Self {
370        Expr::IsNull(Box::new(self))
371    }
372
373    /// Adds a FILTER clause to an aggregate expression (PostgreSQL).
374    #[must_use]
375    pub fn filter(self, predicate: Expr) -> Self {
376        Expr::Filter {
377            expr: Box::new(self),
378            predicate: Box::new(predicate),
379        }
380    }
381
382    /// Creates an EXISTS subquery predicate.
383    pub fn exists(subquery: Select) -> Self {
384        Expr::Exists(Box::new(subquery))
385    }
386
387    /// Creates a NOT EXISTS subquery predicate.
388    pub fn not_exists(subquery: Select) -> Self {
389        Expr::NotExists(Box::new(subquery))
390    }
391
392    /// Creates a relation `some` predicate.
393    pub fn relation_some(
394        field: impl Into<String>,
395        parent_table: impl Into<String>,
396        target_table: impl Into<String>,
397        fk_db: impl Into<String>,
398        pk_db: impl Into<String>,
399        filter: Expr,
400    ) -> Self {
401        Expr::Relation {
402            op: RelationFilterOp::Some,
403            relation: Box::new(RelationFilter {
404                field: field.into(),
405                parent_table: parent_table.into(),
406                target_table: target_table.into(),
407                fk_db: fk_db.into(),
408                pk_db: pk_db.into(),
409                filter: Box::new(filter),
410            }),
411        }
412    }
413
414    /// Creates a relation `none` predicate.
415    pub fn relation_none(
416        field: impl Into<String>,
417        parent_table: impl Into<String>,
418        target_table: impl Into<String>,
419        fk_db: impl Into<String>,
420        pk_db: impl Into<String>,
421        filter: Expr,
422    ) -> Self {
423        Expr::Relation {
424            op: RelationFilterOp::None,
425            relation: Box::new(RelationFilter {
426                field: field.into(),
427                parent_table: parent_table.into(),
428                target_table: target_table.into(),
429                fk_db: fk_db.into(),
430                pk_db: pk_db.into(),
431                filter: Box::new(filter),
432            }),
433        }
434    }
435
436    /// Creates a relation `every` predicate.
437    pub fn relation_every(
438        field: impl Into<String>,
439        parent_table: impl Into<String>,
440        target_table: impl Into<String>,
441        fk_db: impl Into<String>,
442        pk_db: impl Into<String>,
443        filter: Expr,
444    ) -> Self {
445        Expr::Relation {
446            op: RelationFilterOp::Every,
447            relation: Box::new(RelationFilter {
448                field: field.into(),
449                parent_table: parent_table.into(),
450                target_table: target_table.into(),
451                fk_db: fk_db.into(),
452                pk_db: pk_db.into(),
453                filter: Box::new(filter),
454            }),
455        }
456    }
457
458    /// Creates a scalar subquery expression `(SELECT ...)`.
459    ///
460    /// The inner SELECT must return exactly one column and at most one row.
461    pub fn scalar_subquery(subquery: Select) -> Self {
462        Expr::ScalarSubquery(Box::new(subquery))
463    }
464
465    /// Creates a `CASE WHEN condition THEN result ELSE NULL END` expression.
466    pub fn case_when(condition: Expr, then: Expr) -> Self {
467        Expr::CaseWhen {
468            condition: Box::new(condition),
469            then: Box::new(then),
470        }
471    }
472
473    /// Creates the SQL wildcard `*` (for use in `COUNT(*)`).
474    pub fn star() -> Self {
475        Expr::Star
476    }
477}
478
479/// Implements the `!` operator for expressions, producing a SQL `NOT` clause.
480impl std::ops::Not for Expr {
481    type Output = Self;
482
483    fn not(self) -> Self::Output {
484        Expr::Not(Box::new(self))
485    }
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_column_expr() {
        let expr = Expr::column("email");
        match expr {
            Expr::Column(name) => assert_eq!(name, "email"),
            _ => panic!("Expected Column variant"),
        }
    }

    #[test]
    fn test_param_expr() {
        let expr = Expr::param(42i64);
        match expr {
            Expr::Param(Value::I64(42)) => {}
            _ => panic!("Expected Param with I64(42)"),
        }
    }

    #[test]
    fn test_binary_ops() {
        let col = Expr::column("age");
        let val = Expr::param(18i64);

        let expr = col.ge(val);
        match expr {
            Expr::Binary { op, .. } => assert_eq!(op, BinaryOp::Ge),
            _ => panic!("Expected Binary expression"),
        }
    }

    #[test]
    fn test_binary_operands_preserved() {
        // Operand order matters: `a < b` must not become `b < a`.
        let expr = Expr::column("age").lt(Expr::param(30i64));
        assert_eq!(
            expr,
            Expr::Binary {
                left: Box::new(Expr::column("age")),
                op: BinaryOp::Lt,
                right: Box::new(Expr::param(30i64)),
            }
        );
    }

    #[test]
    fn test_complex_expr() {
        let expr = Expr::column("age")
            .ge(Expr::param(18i64))
            .and(Expr::column("email").like(Expr::param("%@gmail.com")));

        match expr {
            Expr::Binary { op, .. } => assert_eq!(op, BinaryOp::And),
            _ => panic!("Expected Binary AND expression"),
        }
    }

    #[test]
    fn test_not_expr() {
        let expr = !Expr::column("active").eq(Expr::param(true));
        match expr {
            Expr::Not(_) => {}
            _ => panic!("Expected Not expression"),
        }
    }

    #[test]
    fn test_in_list() {
        let items = vec![
            Expr::param(Value::String("active".to_string())),
            Expr::param(Value::String("pending".to_string())),
        ];
        let expr = Expr::column("status").in_list(items.clone());
        match expr {
            Expr::Binary { op, right, .. } => {
                assert_eq!(op, BinaryOp::In);
                assert_eq!(*right, Expr::List(items));
            }
            _ => panic!("Expected Binary IN expression"),
        }
    }

    #[test]
    fn test_not_in_list() {
        let expr = Expr::column("role").not_in_list(vec![
            Expr::param(Value::String("admin".to_string())),
            Expr::param(Value::String("superuser".to_string())),
        ]);
        match expr {
            Expr::Binary { op, .. } => assert_eq!(op, BinaryOp::NotIn),
            _ => panic!("Expected Binary NOT IN expression"),
        }
    }

    #[test]
    fn test_relation_predicate() {
        let expr = Expr::relation_some(
            "posts",
            "users",
            "posts",
            "author_id",
            "id",
            Expr::column("posts__published").eq(Expr::param(true)),
        );
        match expr {
            Expr::Relation { op, relation } => {
                assert_eq!(op, RelationFilterOp::Some);
                assert_eq!(relation.field, "posts");
                assert_eq!(relation.parent_table, "users");
                assert_eq!(relation.target_table, "posts");
                assert_eq!(relation.fk_db, "author_id");
                assert_eq!(relation.pk_db, "id");
                assert_eq!(
                    *relation.filter,
                    Expr::column("posts__published").eq(Expr::param(true))
                );
            }
            _ => panic!("Expected relation predicate"),
        }
    }

    #[test]
    fn test_relation_none_and_every() {
        let child = || Expr::column("posts__published").eq(Expr::param(true));

        let none = Expr::relation_none("posts", "users", "posts", "author_id", "id", child());
        assert!(matches!(
            none,
            Expr::Relation {
                op: RelationFilterOp::None,
                ..
            }
        ));

        let every = Expr::relation_every("posts", "users", "posts", "author_id", "id", child());
        match every {
            Expr::Relation { op, relation } => {
                assert_eq!(op, RelationFilterOp::Every);
                assert_eq!(relation.fk_db, "author_id");
                assert_eq!(*relation.filter, child());
            }
            _ => panic!("Expected relation predicate"),
        }
    }

    #[test]
    fn test_json_build_object_interleaves_keys_and_values() {
        let expr = Expr::json_build_object(vec![
            ("id".to_string(), Expr::column("id")),
            ("name".to_string(), Expr::column("name")),
        ]);
        assert_eq!(
            expr,
            Expr::FunctionCall {
                name: "json_build_object".to_string(),
                args: vec![
                    Expr::Literal(LiteralSql::from_static("id")),
                    Expr::column("id"),
                    Expr::Literal(LiteralSql::from_static("name")),
                    Expr::column("name"),
                ],
            }
        );
        assert_eq!(
            Expr::json_build_object(Vec::new()),
            Expr::function_call("json_build_object", Vec::new())
        );
    }

    #[test]
    fn test_function_helpers() {
        assert_eq!(
            Expr::coalesce(vec![Expr::column("a"), Expr::param(0i64)]),
            Expr::function_call("COALESCE", vec![Expr::column("a"), Expr::param(0i64)])
        );
        assert_eq!(
            Expr::json_agg(Expr::column("a")),
            Expr::function_call("json_agg", vec![Expr::column("a")])
        );
        assert_eq!(
            Expr::vector_distance(
                crate::args::VectorMetric::Cosine,
                Expr::column("embedding"),
                Expr::column("query"),
            ),
            Expr::function_call(
                VECTOR_COSINE_DISTANCE_FUNCTION,
                vec![Expr::column("embedding"), Expr::column("query")]
            )
        );
    }

    #[test]
    fn test_null_checks_case_when_and_filter() {
        assert_eq!(
            Expr::column("a").is_null(),
            Expr::IsNull(Box::new(Expr::column("a")))
        );
        assert_eq!(
            Expr::column("a").is_not_null(),
            Expr::IsNotNull(Box::new(Expr::column("a")))
        );
        assert_eq!(
            Expr::case_when(Expr::column("flag"), Expr::star()),
            Expr::CaseWhen {
                condition: Box::new(Expr::column("flag")),
                then: Box::new(Expr::Star),
            }
        );
        let count = Expr::function_call("COUNT", vec![Expr::star()]);
        assert_eq!(
            count.clone().filter(Expr::column("flag")),
            Expr::Filter {
                expr: Box::new(count),
                predicate: Box::new(Expr::column("flag")),
            }
        );
    }
}