Skip to main content

qail_core/ast/
expr.rs

1use crate::ast::{AggregateFunc, Cage, Condition, ModKind, Value};
2use serde::{Deserialize, Serialize};
3
/// Binary operators for expressions.
///
/// The `Display` impl renders each variant as its SQL spelling
/// (e.g. `Concat` → `||`, `Ne` → `<>`, `IsNotNull` → `IS NOT NULL`).
///
/// NOTE(review): variant order is part of the serialized form for
/// index-based serde formats; append new variants rather than reordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BinaryOp {
    // Arithmetic (plus string concatenation)
    /// String concatenation `||`.
    Concat,
    /// Addition `+`.
    Add,
    /// Subtraction `-`.
    Sub,
    /// Multiplication `*`.
    Mul,
    /// Division `/`.
    Div,
    /// Modulo `%`.
    Rem,
    // Logical
    /// Logical `AND`.
    And,
    /// Logical `OR`.
    Or,
    // Comparison
    /// Equals `=`.
    Eq,
    /// Not equals `<>`.
    Ne,
    /// Greater than `>`.
    Gt,
    /// Greater than or equal `>=`.
    Gte,
    /// Less than `<`.
    Lt,
    /// Less than or equal `<=`.
    Lte,
    // Null checks: unary in SQL, but represented as a binary operator whose
    // right operand is a null placeholder.
    /// `IS NULL`.
    IsNull,
    /// `IS NOT NULL`.
    IsNotNull,
}
44
45impl std::fmt::Display for BinaryOp {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        match self {
48            BinaryOp::Concat => write!(f, "||"),
49            BinaryOp::Add => write!(f, "+"),
50            BinaryOp::Sub => write!(f, "-"),
51            BinaryOp::Mul => write!(f, "*"),
52            BinaryOp::Div => write!(f, "/"),
53            BinaryOp::Rem => write!(f, "%"),
54            BinaryOp::And => write!(f, "AND"),
55            BinaryOp::Or => write!(f, "OR"),
56            BinaryOp::Eq => write!(f, "="),
57            BinaryOp::Ne => write!(f, "<>"),
58            BinaryOp::Gt => write!(f, ">"),
59            BinaryOp::Gte => write!(f, ">="),
60            BinaryOp::Lt => write!(f, "<"),
61            BinaryOp::Lte => write!(f, "<="),
62            BinaryOp::IsNull => write!(f, "IS NULL"),
63            BinaryOp::IsNotNull => write!(f, "IS NOT NULL"),
64        }
65    }
66}
67
/// An expression node in the AST.
///
/// Most variants carry an optional `alias`, which the `Display` impl appends
/// verbatim as ` AS alias`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Expr {
    /// All columns (`*`).
    Star,
    /// A named column or identifier, rendered verbatim.
    Named(String),
    /// A named column or identifier with an alias (`name AS alias`).
    Aliased {
        /// Column or identifier name.
        name: String,
        /// Alias.
        alias: String,
    },
    /// An aggregate function call such as `COUNT(col)`, with optional
    /// `DISTINCT` and PostgreSQL `FILTER (WHERE ...)`.
    Aggregate {
        /// Column to aggregate.
        col: String,
        /// Aggregate function.
        func: AggregateFunc,
        /// Whether `DISTINCT` is applied to the argument.
        distinct: bool,
        /// PostgreSQL `FILTER (WHERE ...)` conditions; rendered joined with `AND`.
        filter: Option<Vec<Condition>>,
        /// Optional alias.
        alias: Option<String>,
    },
    /// Type cast expression (`expr::type`).
    Cast {
        /// Expression to cast.
        expr: Box<Expr>,
        /// Target SQL type, written verbatim after `::`.
        target_type: String,
        /// Optional alias.
        alias: Option<String>,
    },
    /// Column definition (name, type, constraints).
    Def {
        /// Column name.
        name: String,
        /// SQL data type.
        data_type: String,
        /// Column constraints.
        constraints: Vec<Constraint>,
    },
    /// ALTER TABLE column modification (ADD or DROP, selected by `kind`).
    Mod {
        /// Modification kind.
        kind: ModKind,
        /// Column expression being added or dropped.
        col: Box<Expr>,
    },
    /// Window function definition.
    Window {
        /// Window name/alias.
        name: String,
        /// Window function name.
        func: String,
        /// Function arguments as expressions (e.g., for `SUM(amount)`, use `Expr::Named("amount")`).
        params: Vec<Expr>,
        /// PARTITION BY columns.
        partition: Vec<String>,
        /// ORDER BY clauses.
        order: Vec<Cage>,
        /// Optional frame specification (`ROWS`/`RANGE BETWEEN ...`).
        frame: Option<WindowFrame>,
    },
    /// `CASE WHEN ... THEN ... [ELSE ...] END` expression.
    Case {
        /// `WHEN condition THEN result` pairs; the result is a full `Expr`,
        /// so it may be a function call, literal, or identifier.
        when_clauses: Vec<(Condition, Box<Expr>)>,
        /// Optional `ELSE` result.
        else_value: Option<Box<Expr>>,
        /// Optional alias.
        alias: Option<String>,
    },
    /// JSON accessor (`data->>'key'`, `data->'key'`, or chained `data->'a'->0->>'b'`).
    JsonAccess {
        /// Base column name.
        column: String,
        /// JSON path segments as `(key, as_text)` pairs.
        /// `as_text` is true for `->>` (extract as text) and false for `->` (extract as JSON).
        /// For chained access like `x->'a'->0->>'b'` this is
        /// `[("a", false), ("0", false), ("b", true)]`; segments that parse as
        /// integers are rendered unquoted (array index), others quoted (object key).
        path_segments: Vec<(String, bool)>,
        /// Optional alias.
        alias: Option<String>,
    },
    /// Function call expression (`COALESCE`, `NULLIF`, etc.).
    FunctionCall {
        /// Function name (e.g. `coalesce`, `nullif`); rendered upper-cased.
        name: String,
        /// Arguments; each may be any nested expression.
        args: Vec<Expr>,
        /// Optional alias.
        alias: Option<String>,
    },
    /// SQL function whose arguments use keyword syntax (`SUBSTRING`, `EXTRACT`, `TRIM`, ...),
    /// e.g. `SUBSTRING(expr FROM pos [FOR len])`, `EXTRACT(YEAR FROM date)`.
    SpecialFunction {
        /// Function name (`SUBSTRING`, `EXTRACT`, `TRIM`, ...); rendered upper-cased.
        name: String,
        /// Arguments as `(optional_keyword, expr)` pairs, rendered space-separated,
        /// e.g. `[(None, col), (Some("FROM"), 2), (Some("FOR"), 5)]`.
        args: Vec<(Option<String>, Box<Expr>)>,
        /// Optional alias.
        alias: Option<String>,
    },
    /// Binary expression (`left op right`).
    Binary {
        /// Left operand.
        left: Box<Expr>,
        /// Binary operator.
        op: BinaryOp,
        /// Right operand (a null placeholder for `IsNull` / `IsNotNull`).
        right: Box<Expr>,
        /// Optional alias.
        alias: Option<String>,
    },
    /// Literal value (string, number) for use in expressions,
    /// e.g. `'62'`, `0`, `'active'`.
    Literal(Value),
    /// Array constructor: `ARRAY[expr1, expr2, ...]`.
    ArrayConstructor {
        /// Array elements.
        elements: Vec<Expr>,
        /// Optional alias.
        alias: Option<String>,
    },
    /// Row constructor: `ROW(expr1, expr2, ...)` or `(expr1, expr2, ...)`.
    RowConstructor {
        /// Row elements.
        elements: Vec<Expr>,
        /// Optional alias.
        alias: Option<String>,
    },
    /// Array/string subscript: `arr[index]`.
    Subscript {
        /// Base expression.
        expr: Box<Expr>,
        /// Index expression.
        index: Box<Expr>,
        /// Optional alias.
        alias: Option<String>,
    },
    /// Collation: `expr COLLATE "collation_name"`.
    Collate {
        /// Expression.
        expr: Box<Expr>,
        /// Collation name (rendered inside double quotes).
        collation: String,
        /// Optional alias.
        alias: Option<String>,
    },
    /// Field selection from a composite value: `(row).field`.
    FieldAccess {
        /// Composite expression.
        expr: Box<Expr>,
        /// Field name.
        field: String,
        /// Optional alias.
        alias: Option<String>,
    },
    /// Scalar subquery: `(SELECT ...)`, e.g. inside `COALESCE` or comparisons.
    /// NOTE(review): nothing here enforces a single-row result (e.g. `LIMIT 1`).
    Subquery {
        /// Inner query.
        query: Box<super::Qail>,
        /// Optional alias.
        alias: Option<String>,
    },
    /// `EXISTS (SELECT ...)` or, when negated, `NOT EXISTS (SELECT ...)`.
    Exists {
        /// Inner query.
        query: Box<super::Qail>,
        /// Whether this is `NOT EXISTS`.
        negated: bool,
        /// Optional alias.
        alias: Option<String>,
    },
    /// Raw SQL expression: an escape hatch for expressions that cannot be
    /// reverse-parsed into typed AST nodes (e.g. from pg_policies introspection).
    /// Rendered verbatim; prefer typed variants wherever possible.
    Raw(String),
}
252
253impl std::fmt::Display for Expr {
254    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255        match self {
256            Expr::Star => write!(f, "*"),
257            Expr::Named(name) => write!(f, "{}", name),
258            Expr::Aliased { name, alias } => write!(f, "{} AS {}", name, alias),
259            Expr::Aggregate {
260                col,
261                func,
262                distinct,
263                filter,
264                alias,
265            } => {
266                if *distinct {
267                    write!(f, "{}(DISTINCT {})", func, col)?;
268                } else {
269                    write!(f, "{}({})", func, col)?;
270                }
271                if let Some(conditions) = filter {
272                    write!(
273                        f,
274                        " FILTER (WHERE {})",
275                        conditions
276                            .iter()
277                            .map(|c| c.to_string())
278                            .collect::<Vec<_>>()
279                            .join(" AND ")
280                    )?;
281                }
282                if let Some(a) = alias {
283                    write!(f, " AS {}", a)?;
284                }
285                Ok(())
286            }
287            Expr::Cast {
288                expr,
289                target_type,
290                alias,
291            } => {
292                write!(f, "{}::{}", expr, target_type)?;
293                if let Some(a) = alias {
294                    write!(f, " AS {}", a)?;
295                }
296                Ok(())
297            }
298            Expr::Def {
299                name,
300                data_type,
301                constraints,
302            } => {
303                write!(f, "{}:{}", name, data_type)?;
304                for c in constraints {
305                    write!(f, "^{}", c)?;
306                }
307                Ok(())
308            }
309            Expr::Mod { kind, col } => match kind {
310                ModKind::Add => write!(f, "+{}", col),
311                ModKind::Drop => write!(f, "-{}", col),
312            },
313            Expr::Window {
314                name,
315                func,
316                params,
317                partition,
318                order,
319                frame,
320            } => {
321                write!(f, "{}:{}(", name, func)?;
322                for (i, p) in params.iter().enumerate() {
323                    if i > 0 {
324                        write!(f, ", ")?;
325                    }
326                    write!(f, "{}", p)?;
327                }
328                write!(f, ")")?;
329
330                // Print partitions if any
331                if !partition.is_empty() {
332                    write!(f, "{{Part=")?;
333                    for (i, p) in partition.iter().enumerate() {
334                        if i > 0 {
335                            write!(f, ",")?;
336                        }
337                        write!(f, "{}", p)?;
338                    }
339                    if let Some(fr) = frame {
340                        write!(f, ", Frame={:?}", fr)?; // Debug format for now
341                    }
342                    write!(f, "}}")?;
343                } else if let Some(fr) = frame {
344                    write!(f, "{{Frame={:?}}}", fr)?;
345                }
346
347                // Print order cages
348                for _cage in order {
349                    // Order cages are sort cages - display format TBD
350                }
351                Ok(())
352            }
353            Expr::Case {
354                when_clauses,
355                else_value,
356                alias,
357            } => {
358                write!(f, "CASE")?;
359                for (cond, val) in when_clauses {
360                    write!(f, " WHEN {} THEN {}", cond.left, val)?;
361                }
362                if let Some(e) = else_value {
363                    write!(f, " ELSE {}", e)?;
364                }
365                write!(f, " END")?;
366                if let Some(a) = alias {
367                    write!(f, " AS {}", a)?;
368                }
369                Ok(())
370            }
371            Expr::JsonAccess {
372                column,
373                path_segments,
374                alias,
375            } => {
376                write!(f, "{}", column)?;
377                for (path, as_text) in path_segments {
378                    let op = if *as_text { "->>" } else { "->" };
379                    // Integer indices should NOT be quoted (array access)
380                    // String keys should be quoted (object access)
381                    if path.parse::<i64>().is_ok() {
382                        write!(f, "{}{}", op, path)?;
383                    } else {
384                        write!(f, "{}'{}'", op, path)?;
385                    }
386                }
387                if let Some(a) = alias {
388                    write!(f, " AS {}", a)?;
389                }
390                Ok(())
391            }
392            Expr::FunctionCall { name, args, alias } => {
393                let args_str: Vec<String> = args.iter().map(|a| a.to_string()).collect();
394                write!(f, "{}({})", name.to_uppercase(), args_str.join(", "))?;
395                if let Some(a) = alias {
396                    write!(f, " AS {}", a)?;
397                }
398                Ok(())
399            }
400            Expr::SpecialFunction { name, args, alias } => {
401                write!(f, "{}(", name.to_uppercase())?;
402                for (i, (keyword, expr)) in args.iter().enumerate() {
403                    if i > 0 {
404                        write!(f, " ")?;
405                    }
406                    if let Some(kw) = keyword {
407                        write!(f, "{} ", kw)?;
408                    }
409                    write!(f, "{}", expr)?;
410                }
411                write!(f, ")")?;
412                if let Some(a) = alias {
413                    write!(f, " AS {}", a)?;
414                }
415                Ok(())
416            }
417            Expr::Binary {
418                left,
419                op,
420                right,
421                alias,
422            } => {
423                write!(f, "({} {} {})", left, op, right)?;
424                if let Some(a) = alias {
425                    write!(f, " AS {}", a)?;
426                }
427                Ok(())
428            }
429            Expr::Literal(value) => write!(f, "{}", value),
430            Expr::ArrayConstructor { elements, alias } => {
431                write!(f, "ARRAY[")?;
432                for (i, elem) in elements.iter().enumerate() {
433                    if i > 0 {
434                        write!(f, ", ")?;
435                    }
436                    write!(f, "{}", elem)?;
437                }
438                write!(f, "]")?;
439                if let Some(a) = alias {
440                    write!(f, " AS {}", a)?;
441                }
442                Ok(())
443            }
444            Expr::RowConstructor { elements, alias } => {
445                write!(f, "ROW(")?;
446                for (i, elem) in elements.iter().enumerate() {
447                    if i > 0 {
448                        write!(f, ", ")?;
449                    }
450                    write!(f, "{}", elem)?;
451                }
452                write!(f, ")")?;
453                if let Some(a) = alias {
454                    write!(f, " AS {}", a)?;
455                }
456                Ok(())
457            }
458            Expr::Subscript { expr, index, alias } => {
459                write!(f, "{}[{}]", expr, index)?;
460                if let Some(a) = alias {
461                    write!(f, " AS {}", a)?;
462                }
463                Ok(())
464            }
465            Expr::Collate {
466                expr,
467                collation,
468                alias,
469            } => {
470                write!(f, "{} COLLATE \"{}\"", expr, collation)?;
471                if let Some(a) = alias {
472                    write!(f, " AS {}", a)?;
473                }
474                Ok(())
475            }
476            Expr::FieldAccess { expr, field, alias } => {
477                write!(f, "({}).{}", expr, field)?;
478                if let Some(a) = alias {
479                    write!(f, " AS {}", a)?;
480                }
481                Ok(())
482            }
483            Expr::Subquery { query, alias } => {
484                write!(f, "({})", query)?;
485                if let Some(a) = alias {
486                    write!(f, " AS {}", a)?;
487                }
488                Ok(())
489            }
490            Expr::Exists {
491                query,
492                negated,
493                alias,
494            } => {
495                if *negated {
496                    write!(f, "NOT ")?;
497                }
498                write!(f, "EXISTS ({})", query)?;
499                if let Some(a) = alias {
500                    write!(f, " AS {}", a)?;
501                }
502                Ok(())
503            }
504            Expr::Raw(sql) => write!(f, "{}", sql),
505        }
506    }
507}
508
/// Column constraint.
///
/// The `Display` impl renders these in QAIL's compact notation
/// (e.g. `pk`, `uniq`, `?`, `=default`, `ref(target)`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Constraint {
    /// `PRIMARY KEY`.
    PrimaryKey,
    /// `UNIQUE`.
    Unique,
    /// Column is nullable (`NULL` allowed).
    Nullable,
    /// `DEFAULT` value, kept as text.
    Default(String),
    /// `CHECK` constraint; `Display` joins the entries with `,`.
    /// NOTE(review): entries look like allowed values or expressions — confirm.
    Check(Vec<String>),
    /// `COMMENT ON COLUMN` text.
    Comment(String),
    /// `REFERENCES` foreign-key target, kept as text.
    References(String),
    /// Generated column (stored or virtual).
    Generated(ColumnGeneration),
}
529
/// How a generated column's value is produced (`STORED` or `VIRTUAL`).
///
/// Both variants hold the generation expression as text.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ColumnGeneration {
    /// `GENERATED ALWAYS AS (expr) STORED`: computed and stored.
    Stored(String),
    /// `GENERATED ALWAYS AS (expr)`: computed at query time (the default kind in PostgreSQL 18+).
    Virtual(String),
}
538
/// Window frame definition for window functions.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum WindowFrame {
    /// `ROWS BETWEEN start AND end` (frame measured in physical rows).
    Rows {
        /// Frame start bound.
        start: FrameBound,
        /// Frame end bound.
        end: FrameBound,
    },
    /// `RANGE BETWEEN start AND end` (frame measured by ORDER BY values).
    Range {
        /// Frame start bound.
        start: FrameBound,
        /// Frame end bound.
        end: FrameBound,
    },
}
557
/// A single window-frame boundary, used as `start` or `end` of a [`WindowFrame`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FrameBound {
    /// `UNBOUNDED PRECEDING`.
    UnboundedPreceding,
    /// `n PRECEDING`, where the payload is the offset `n`.
    Preceding(i32),
    /// `CURRENT ROW`.
    CurrentRow,
    /// `n FOLLOWING`, where the payload is the offset `n`.
    Following(i32),
    /// `UNBOUNDED FOLLOWING`.
    UnboundedFollowing,
}
572
573impl std::fmt::Display for Constraint {
574    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
575        match self {
576            Constraint::PrimaryKey => write!(f, "pk"),
577            Constraint::Unique => write!(f, "uniq"),
578            Constraint::Nullable => write!(f, "?"),
579            Constraint::Default(val) => write!(f, "={}", val),
580            Constraint::Check(vals) => write!(f, "check({})", vals.join(",")),
581            Constraint::Comment(text) => write!(f, "comment(\"{}\")", text),
582            Constraint::References(target) => write!(f, "ref({})", target),
583            Constraint::Generated(generation) => match generation {
584                ColumnGeneration::Stored(expr) => write!(f, "gen({})", expr),
585                ColumnGeneration::Virtual(expr) => write!(f, "vgen({})", expr),
586            },
587        }
588    }
589}
590
/// Index definition for `CREATE INDEX`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct IndexDef {
    /// Index name.
    pub name: String,
    /// Target table.
    pub table: String,
    /// Columns to index, in key order.
    pub columns: Vec<String>,
    /// Whether the index is unique.
    pub unique: bool,
    /// Index type (e.g., "keyword", "integer", "float", "geo", "text").
    /// `#[serde(default)]` lets serialized input omit this field (it becomes `None`).
    /// NOTE(review): these example kinds are not PostgreSQL access methods
    /// (btree, gin, ...); confirm which backend consumes this field.
    #[serde(default)]
    pub index_type: Option<String>,
}
606
/// Table-level constraints over one or more columns (composite keys).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TableConstraint {
    /// Composite `UNIQUE (col, ...)` over the listed columns.
    Unique(Vec<String>),
    /// Composite `PRIMARY KEY (col, ...)` over the listed columns.
    PrimaryKey(Vec<String>),
}
615
616// ==================== From Implementations for Ergonomic API ====================
617
618impl From<&str> for Expr {
619    /// Convert a string reference to a Named expression.
620    /// Enables: `.select(["id", "name"])` instead of `.select([col("id"), col("name")])`
621    fn from(s: &str) -> Self {
622        Expr::Named(s.to_string())
623    }
624}
625
626impl From<String> for Expr {
627    fn from(s: String) -> Self {
628        Expr::Named(s)
629    }
630}
631
632impl From<&String> for Expr {
633    fn from(s: &String) -> Self {
634        Expr::Named(s.clone())
635    }
636}
637
638// ==================== Function and Trigger Definitions ====================
639
/// PostgreSQL function definition (`CREATE FUNCTION`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionDef {
    /// Function name.
    pub name: String,
    /// Return type (e.g., "trigger", "integer", "void").
    pub returns: String,
    /// Function body (e.g. PL/pgSQL code), kept as text.
    pub body: String,
    /// Language; `None` is described as defaulting to plpgsql.
    /// NOTE(review): the default is applied by the consumer, not here — confirm.
    pub language: Option<String>,
}
652
/// Trigger timing (`BEFORE`, `AFTER`, or `INSTEAD OF`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TriggerTiming {
    /// `BEFORE`.
    Before,
    /// `AFTER`.
    After,
    /// `INSTEAD OF`.
    InsteadOf,
}
663
/// Event types that can fire a trigger.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TriggerEvent {
    /// `INSERT`.
    Insert,
    /// `UPDATE`.
    Update,
    /// `DELETE`.
    Delete,
    /// `TRUNCATE`.
    Truncate,
}
676
/// PostgreSQL trigger definition (`CREATE TRIGGER`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TriggerDef {
    /// Trigger name.
    pub name: String,
    /// Target table.
    pub table: String,
    /// Timing (`BEFORE`, `AFTER`, `INSTEAD OF`).
    pub timing: TriggerTiming,
    /// Events that fire the trigger.
    pub events: Vec<TriggerEvent>,
    /// Whether the trigger fires `FOR EACH ROW`
    /// (presumably `false` means `FOR EACH STATEMENT` — confirm in the generator).
    pub for_each_row: bool,
    /// Name of the function to execute (presumably a [`FunctionDef`] name — verify).
    pub execute_function: String,
}