Skip to main content

qail_core/fmt/
mod.rs

1use crate::ast::{
2    Action, Cage, CageKind, Condition, Expr, Join, LogicalOp, Operator, Qail, SortOrder, Value,
3};
4use std::fmt::{Result, Write};
5
6#[cfg(test)]
7mod tests;
8
/// Pretty-printer for QAIL AST nodes.
///
/// Renders a `Qail` command back into human-readable QAIL syntax,
/// indenting nested clauses by two spaces per level.
///
/// A formatter is single-use: [`Formatter::format`] consumes it and
/// returns the accumulated text.
#[derive(Debug)]
pub struct Formatter {
    /// Current indentation depth; each level is rendered as two spaces.
    indent_level: usize,
    /// Output buffer that accumulates the rendered text.
    buffer: String,
}
19
20impl Default for Formatter {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl Formatter {
27    /// Create a new formatter with default settings.
28    pub fn new() -> Self {
29        Self {
30            indent_level: 0,
31            buffer: String::new(),
32        }
33    }
34
35    /// Format a query into a readable QAIL string.
36    pub fn format(mut self, cmd: &Qail) -> std::result::Result<String, std::fmt::Error> {
37        self.visit_cmd(cmd)?;
38        Ok(self.buffer)
39    }
40
41    fn indent(&mut self) -> Result {
42        for _ in 0..self.indent_level {
43            write!(self.buffer, "  ")?;
44        }
45        Ok(())
46    }
47
48    fn visit_cmd(&mut self, cmd: &Qail) -> Result {
49        for cte in &cmd.ctes {
50            write!(self.buffer, "with {} = ", cte.name)?;
51            self.indent_level += 1;
52            writeln!(self.buffer)?;
53            self.indent()?;
54            self.visit_cmd(&cte.base_query)?;
55
56            if cte.recursive
57                && let Some(ref recursive_query) = cte.recursive_query
58            {
59                writeln!(self.buffer)?;
60                self.indent()?;
61                writeln!(self.buffer, "union all")?;
62                self.indent()?;
63                self.visit_cmd(recursive_query)?;
64            }
65
66            self.indent_level -= 1;
67            writeln!(self.buffer)?;
68        }
69
70        // Action and Table
71        match cmd.action {
72            Action::Get => write!(self.buffer, "get {}", cmd.table)?,
73            Action::Set => write!(self.buffer, "set {}", cmd.table)?,
74            Action::Del => write!(self.buffer, "del {}", cmd.table)?,
75            Action::Add => write!(self.buffer, "add {}", cmd.table)?,
76            _ => write!(self.buffer, "{} {}", cmd.action, cmd.table)?, // Fallback for others
77        }
78        writeln!(self.buffer)?;
79
80        // self.indent_level += 1; // Removed: Clauses should act at same level as command
81
82        // Cages: Group By (if any "by" equivalent exists? No, "by" is usually implicit in AST or explicit in group_by_mode?)
83        // The proposal example shows "by phone_number".
84        // In AST `cmd.rs`, there isn't a direct "Group By" list, usually inferred or group_by_mode.
85        // Wait, where is `by phone_number` stored in AST?
86        // Checking `ast/cmd.rs`: `group_by_mode: GroupByMode`.
87        // Usually group by is inferred from aggregates or explicit.
88        // If the AST doesn't have explicit group by columns, we might need to derive it or it's in `cages`?
89        // Let's check `cages.rs` again. `CageKind` has `Filter`, `Sort`, `Limit`... no `GroupBy`.
90        // Maybe it's implied by non-aggregated columns in a `Get` with aggregates?
91        // For now, I will skip "by" unless I find it in AST.
92
93        if !cmd.columns.is_empty() {
94            // But proposal says "Canonical".
95            // "get table" implies "get table fields *" usually?
96            // If manual explicit columns:
97            if !(cmd.columns.len() == 1 && matches!(cmd.columns[0], Expr::Star)) {
98                self.indent()?;
99                writeln!(self.buffer, "fields")?;
100                self.indent_level += 1;
101                for (i, col) in cmd.columns.iter().enumerate() {
102                    self.indent()?;
103                    self.format_column(col)?;
104                    if i < cmd.columns.len() - 1 {
105                        writeln!(self.buffer, ",")?;
106                    } else {
107                        writeln!(self.buffer)?;
108                    }
109                }
110                self.indent_level -= 1;
111            }
112        }
113
114        // Joins
115        for join in &cmd.joins {
116            self.indent()?;
117            self.format_join(join)?;
118            writeln!(self.buffer)?;
119        }
120
121        // Where (Filter Cages)
122        let filters: Vec<&Cage> = cmd
123            .cages
124            .iter()
125            .filter(|c| matches!(c.kind, CageKind::Filter))
126            .collect();
127        if !filters.is_empty() {
128            // We need to merge them or print them?
129            // Proposal says: "where rn = 1"
130            self.indent()?;
131            write!(self.buffer, "where ")?;
132            for (i, cage) in filters.iter().enumerate() {
133                if i > 0 {
134                    write!(self.buffer, " and ")?; // Assuming AND between cages for now
135                }
136                self.format_conditions(&cage.conditions, cage.logical_op)?;
137            }
138            writeln!(self.buffer)?;
139        }
140
141        // Order By (Sort Cages)
142        let sorts: Vec<&Cage> = cmd
143            .cages
144            .iter()
145            .filter(|c| matches!(c.kind, CageKind::Sort(_)))
146            .collect();
147        if !sorts.is_empty() {
148            self.indent()?;
149            writeln!(self.buffer, "order by")?;
150            self.indent_level += 1;
151            for (i, cage) in sorts.iter().enumerate() {
152                if let CageKind::Sort(order) = cage.kind {
153                    for (j, cond) in cage.conditions.iter().enumerate() {
154                        self.indent()?;
155                        write!(self.buffer, "{}", cond.left)?;
156                        self.format_sort_order(order)?;
157                        if i < sorts.len() - 1 || j < cage.conditions.len() - 1 {
158                            writeln!(self.buffer, ",")?;
159                        } else {
160                            writeln!(self.buffer)?;
161                        }
162                    }
163                }
164            }
165            self.indent_level -= 1;
166        }
167
168        for cage in &cmd.cages {
169            match cage.kind {
170                CageKind::Limit(n) => {
171                    self.indent()?;
172                    writeln!(self.buffer, "limit {}", n)?;
173                }
174                CageKind::Offset(n) => {
175                    self.indent()?;
176                    writeln!(self.buffer, "offset {}", n)?;
177                }
178                _ => {}
179            }
180        }
181
182        // self.indent_level -= 1; // Removed matching decrement
183        Ok(())
184    }
185
186    fn format_column(&mut self, col: &Expr) -> Result {
187        match col {
188            Expr::Star => write!(self.buffer, "*")?,
189            Expr::Named(name) => write!(self.buffer, "{}", name)?,
190            Expr::Aliased { name, alias } => write!(self.buffer, "{} as {}", name, alias)?,
191            Expr::Aggregate {
192                col,
193                func,
194                distinct,
195                filter,
196                alias,
197            } => {
198                let func_name = match func {
199                    crate::ast::AggregateFunc::Count => "count",
200                    crate::ast::AggregateFunc::Sum => "sum",
201                    crate::ast::AggregateFunc::Avg => "avg",
202                    crate::ast::AggregateFunc::Min => "min",
203                    crate::ast::AggregateFunc::Max => "max",
204                    crate::ast::AggregateFunc::ArrayAgg => "array_agg",
205                    crate::ast::AggregateFunc::StringAgg => "string_agg",
206                    crate::ast::AggregateFunc::JsonAgg => "json_agg",
207                    crate::ast::AggregateFunc::JsonbAgg => "jsonb_agg",
208                    crate::ast::AggregateFunc::BoolAnd => "bool_and",
209                    crate::ast::AggregateFunc::BoolOr => "bool_or",
210                };
211                if *distinct {
212                    write!(self.buffer, "{}(distinct {})", func_name, col)?;
213                } else {
214                    write!(self.buffer, "{}({})", func_name, col)?;
215                }
216                if let Some(conditions) = filter {
217                    write!(
218                        self.buffer,
219                        " filter (where {})",
220                        conditions
221                            .iter()
222                            .map(|c| c.to_string())
223                            .collect::<Vec<_>>()
224                            .join(" and ")
225                    )?;
226                }
227                if let Some(a) = alias {
228                    write!(self.buffer, " as {}", a)?;
229                }
230            }
231            Expr::FunctionCall { name, args, alias } => {
232                let args_str: Vec<String> = args.iter().map(|a| a.to_string()).collect();
233                write!(self.buffer, "{}({})", name, args_str.join(", "))?;
234                if let Some(a) = alias {
235                    write!(self.buffer, " as {}", a)?;
236                }
237            }
238            Expr::Window {
239                name,
240                func,
241                params,
242                partition,
243                ..
244            } => {
245                // Use Window function format: func(params) OVER (PARTITION BY ...)
246                let params_str: Vec<String> = params.iter().map(|p| p.to_string()).collect();
247                write!(self.buffer, "{}({})", func, params_str.join(", "))?;
248                if !partition.is_empty() {
249                    write!(self.buffer, " over (partition by {})", partition.join(", "))?;
250                }
251                write!(self.buffer, " as {}", name)?;
252            }
253            Expr::Case {
254                when_clauses,
255                else_value,
256                alias,
257            } => {
258                write!(self.buffer, "case")?;
259                for (cond, val) in when_clauses {
260                    write!(self.buffer, " when {} then {}", cond.left, val)?;
261                }
262                if let Some(e) = else_value {
263                    write!(self.buffer, " else {}", e)?;
264                }
265                write!(self.buffer, " end")?;
266                if let Some(a) = alias {
267                    write!(self.buffer, " as {}", a)?;
268                }
269            }
270            Expr::JsonAccess {
271                column,
272                path_segments,
273                alias,
274            } => {
275                write!(self.buffer, "{}", column)?;
276                for (path, as_text) in path_segments {
277                    let op = if *as_text { "->>" } else { "->" };
278                    if path.parse::<i64>().is_ok() {
279                        write!(self.buffer, "{}{}", op, path)?;
280                    } else {
281                        write!(self.buffer, "{}'{}'", op, path)?;
282                    }
283                }
284                if let Some(a) = alias {
285                    write!(self.buffer, " as {}", a)?;
286                }
287            }
288            Expr::Cast {
289                expr,
290                target_type,
291                alias,
292            } => {
293                write!(self.buffer, "{}::{}", expr, target_type)?;
294                if let Some(a) = alias {
295                    write!(self.buffer, " as {}", a)?;
296                }
297            }
298            Expr::Binary {
299                left,
300                op,
301                right,
302                alias,
303            } => {
304                write!(self.buffer, "({} {} {})", left, op, right)?;
305                if let Some(a) = alias {
306                    write!(self.buffer, " as {}", a)?;
307                }
308            }
309            Expr::SpecialFunction { name, args, alias } => {
310                write!(self.buffer, "{}(", name)?;
311                for (i, (keyword, expr)) in args.iter().enumerate() {
312                    if i > 0 {
313                        write!(self.buffer, " ")?;
314                    }
315                    if let Some(kw) = keyword {
316                        write!(self.buffer, "{} ", kw)?;
317                    }
318                    write!(self.buffer, "{}", expr)?;
319                }
320                write!(self.buffer, ")")?;
321                if let Some(a) = alias {
322                    write!(self.buffer, " as {}", a)?;
323                }
324            }
325            Expr::Literal(val) => self.format_value(val)?,
326            Expr::Def {
327                name,
328                data_type,
329                constraints,
330            } => {
331                write!(self.buffer, "{}:{}", name, data_type)?;
332                for c in constraints {
333                    write!(self.buffer, "^{}", c)?;
334                }
335            }
336            Expr::Mod { kind, col } => {
337                let prefix = match kind {
338                    crate::ast::ModKind::Add => "+",
339                    crate::ast::ModKind::Drop => "-",
340                };
341                write!(self.buffer, "{}{}", prefix, col)?;
342            }
343            Expr::ArrayConstructor { elements, alias } => {
344                write!(self.buffer, "ARRAY[")?;
345                for (i, elem) in elements.iter().enumerate() {
346                    if i > 0 {
347                        write!(self.buffer, ", ")?;
348                    }
349                    self.format_column(elem)?;
350                }
351                write!(self.buffer, "]")?;
352                if let Some(a) = alias {
353                    write!(self.buffer, " as {}", a)?;
354                }
355            }
356            Expr::RowConstructor { elements, alias } => {
357                write!(self.buffer, "ROW(")?;
358                for (i, elem) in elements.iter().enumerate() {
359                    if i > 0 {
360                        write!(self.buffer, ", ")?;
361                    }
362                    self.format_column(elem)?;
363                }
364                write!(self.buffer, ")")?;
365                if let Some(a) = alias {
366                    write!(self.buffer, " as {}", a)?;
367                }
368            }
369            Expr::Subscript { expr, index, alias } => {
370                self.format_column(expr)?;
371                write!(self.buffer, "[")?;
372                self.format_column(index)?;
373                write!(self.buffer, "]")?;
374                if let Some(a) = alias {
375                    write!(self.buffer, " as {}", a)?;
376                }
377            }
378            Expr::Collate {
379                expr,
380                collation,
381                alias,
382            } => {
383                self.format_column(expr)?;
384                write!(self.buffer, " COLLATE \"{}\"", collation)?;
385                if let Some(a) = alias {
386                    write!(self.buffer, " as {}", a)?;
387                }
388            }
389            Expr::FieldAccess { expr, field, alias } => {
390                write!(self.buffer, "(")?;
391                self.format_column(expr)?;
392                write!(self.buffer, ").{}", field)?;
393                if let Some(a) = alias {
394                    write!(self.buffer, " as {}", a)?;
395                }
396            }
397            Expr::Subquery { query, alias } => {
398                write!(self.buffer, "(")?;
399                self.visit_cmd(query)?;
400                write!(self.buffer, ")")?;
401                if let Some(a) = alias {
402                    write!(self.buffer, " as {}", a)?;
403                }
404            }
405            Expr::Exists {
406                query,
407                negated,
408                alias,
409            } => {
410                if *negated {
411                    write!(self.buffer, "not ")?;
412                }
413                write!(self.buffer, "exists (")?;
414                self.visit_cmd(query)?;
415                write!(self.buffer, ")")?;
416                if let Some(a) = alias {
417                    write!(self.buffer, " as {}", a)?;
418                }
419            }
420        }
421        Ok(())
422    }
423
424    fn format_join(&mut self, join: &Join) -> Result {
425        match join.kind {
426            crate::ast::JoinKind::Inner => write!(self.buffer, "join {}", join.table)?,
427            crate::ast::JoinKind::Left => write!(self.buffer, "left join {}", join.table)?,
428            crate::ast::JoinKind::Right => write!(self.buffer, "right join {}", join.table)?,
429            crate::ast::JoinKind::Full => write!(self.buffer, "full join {}", join.table)?,
430            crate::ast::JoinKind::Cross => write!(self.buffer, "cross join {}", join.table)?,
431            crate::ast::JoinKind::Lateral => write!(self.buffer, "lateral join {}", join.table)?,
432        }
433
434        if let Some(conditions) = &join.on
435            && !conditions.is_empty()
436        {
437            writeln!(self.buffer)?;
438            self.indent_level += 1;
439            self.indent()?;
440            write!(self.buffer, "on ")?;
441            self.format_conditions(conditions, LogicalOp::And)?;
442            self.indent_level -= 1;
443        }
444        Ok(())
445    }
446
447    fn format_conditions(&mut self, conditions: &[Condition], logical_op: LogicalOp) -> Result {
448        for (i, cond) in conditions.iter().enumerate() {
449            if i > 0 {
450                match logical_op {
451                    LogicalOp::And => write!(self.buffer, " and ")?,
452                    LogicalOp::Or => write!(self.buffer, " or ")?,
453                }
454            }
455
456            write!(self.buffer, "{}", cond.left)?;
457
458            match cond.op {
459                Operator::Eq => write!(self.buffer, " = ")?,
460                Operator::Ne => write!(self.buffer, " != ")?,
461                Operator::Gt => write!(self.buffer, " > ")?,
462                Operator::Gte => write!(self.buffer, " >= ")?,
463                Operator::Lt => write!(self.buffer, " < ")?,
464                Operator::Lte => write!(self.buffer, " <= ")?,
465                Operator::Fuzzy => write!(self.buffer, " ~ ")?, // ILIKE
466                Operator::In => write!(self.buffer, " in ")?,
467                Operator::NotIn => write!(self.buffer, " not in ")?,
468                Operator::IsNull => write!(self.buffer, " is null")?,
469                Operator::IsNotNull => write!(self.buffer, " is not null")?,
470                Operator::Contains => write!(self.buffer, " @> ")?,
471                Operator::KeyExists => write!(self.buffer, " ? ")?,
472                _ => write!(self.buffer, " {:?} ", cond.op)?,
473            }
474
475            // Some operators like IsNull don't need a value printed
476            if !matches!(cond.op, Operator::IsNull | Operator::IsNotNull) {
477                self.format_value(&cond.value)?;
478            }
479        }
480        Ok(())
481    }
482
483    fn format_value(&mut self, val: &Value) -> Result {
484        match val {
485            Value::Null => write!(self.buffer, "null")?,
486            Value::Bool(b) => write!(self.buffer, "{}", b)?,
487            Value::Int(n) => write!(self.buffer, "{}", n)?,
488            Value::Float(n) => write!(self.buffer, "{}", n)?,
489            Value::Param(n) => write!(self.buffer, "${}", n)?,
490            Value::Function(f) => write!(self.buffer, "{}", f)?,
491            Value::Column(c) => write!(self.buffer, "{}", c)?,
492            Value::String(s) => write!(self.buffer, "'{}'", s)?, // Simple quoting, might need escaping
493            // Value::Date and Value::Interval are not in AST, likely Strings
494            // Value::Date(d) => write!(self.buffer, "'{}'", d)?,
495            // Value::Interval(i) => write!(self.buffer, "interval '{}'", i)?,
496            Value::Array(arr) => {
497                write!(self.buffer, "[")?;
498                for (i, v) in arr.iter().enumerate() {
499                    if i > 0 {
500                        write!(self.buffer, ", ")?;
501                    }
502                    self.format_value(v)?;
503                }
504                write!(self.buffer, "]")?;
505            }
506            Value::NamedParam(name) => write!(self.buffer, ":{}", name)?,
507            Value::Uuid(u) => write!(self.buffer, "'{}'", u)?,
508            Value::NullUuid => write!(self.buffer, "null")?,
509            Value::Interval { amount, unit } => {
510                write!(self.buffer, "interval '{} {}'", amount, unit)?
511            }
512            Value::Timestamp(ts) => write!(self.buffer, "'{}'", ts)?,
513            Value::Bytes(bytes) => {
514                write!(self.buffer, "'\\x")?;
515                for byte in bytes {
516                    write!(self.buffer, "{:02x}", byte)?;
517                }
518                write!(self.buffer, "'")?;
519            }
520            Value::Subquery(cmd) => {
521                write!(self.buffer, "(")?;
522                self.visit_cmd(cmd)?;
523                write!(self.buffer, ")")?;
524            }
525            Value::Expr(expr) => write!(self.buffer, "{}", expr)?,
526            Value::Vector(v) => {
527                write!(self.buffer, "[")?;
528                for (i, val) in v.iter().enumerate() {
529                    if i > 0 {
530                        write!(self.buffer, ", ")?;
531                    }
532                    write!(self.buffer, "{}", val)?;
533                }
534                write!(self.buffer, "]")?;
535            }
536            Value::Json(json) => write!(self.buffer, "'{}'::jsonb", json.replace('\'', "''"))?,
537        }
538        Ok(())
539    }
540
541    fn format_sort_order(&mut self, order: SortOrder) -> Result {
542        match order {
543            SortOrder::Asc => {}
544            SortOrder::Desc => write!(self.buffer, " desc")?,
545            SortOrder::AscNullsFirst => write!(self.buffer, " nulls first")?,
546            SortOrder::AscNullsLast => write!(self.buffer, " nulls last")?,
547            SortOrder::DescNullsFirst => write!(self.buffer, " desc nulls first")?,
548            SortOrder::DescNullsLast => write!(self.buffer, " desc nulls last")?,
549        }
550        Ok(())
551    }
552}