1use crate::ast::{
2 Action, Cage, CageKind, Condition, Expr, Join, LogicalOp, Operator, Qail, SortOrder, Value,
3};
4use std::fmt::{Result, Write};
5
6#[cfg(test)]
7mod tests;
8
/// Pretty-printer that renders a [`Qail`] command as indented, multi-line text.
///
/// Output accumulates in an internal buffer while the command is walked, and
/// is handed back to the caller by `format`.
#[derive(Debug)]
pub struct Formatter {
    /// Current nesting depth; `indent` emits one unit of indentation per level.
    indent_level: usize,
    /// Text rendered so far.
    buffer: String,
}
13
14impl Default for Formatter {
15 fn default() -> Self {
16 Self::new()
17 }
18}
19
20impl Formatter {
21 pub fn new() -> Self {
22 Self {
23 indent_level: 0,
24 buffer: String::new(),
25 }
26 }
27
28 pub fn format(mut self, cmd: &Qail) -> std::result::Result<String, std::fmt::Error> {
29 self.visit_cmd(cmd)?;
30 Ok(self.buffer)
31 }
32
33 fn indent(&mut self) -> Result {
34 for _ in 0..self.indent_level {
35 write!(self.buffer, " ")?;
36 }
37 Ok(())
38 }
39
40 fn visit_cmd(&mut self, cmd: &Qail) -> Result {
41 for cte in &cmd.ctes {
42 write!(self.buffer, "with {} = ", cte.name)?;
43 self.indent_level += 1;
44 writeln!(self.buffer)?;
45 self.indent()?;
46 self.visit_cmd(&cte.base_query)?;
47
48 if cte.recursive
49 && let Some(ref recursive_query) = cte.recursive_query
50 {
51 writeln!(self.buffer)?;
52 self.indent()?;
53 writeln!(self.buffer, "union all")?;
54 self.indent()?;
55 self.visit_cmd(recursive_query)?;
56 }
57
58 self.indent_level -= 1;
59 writeln!(self.buffer)?;
60 }
61
62 match cmd.action {
64 Action::Get => write!(self.buffer, "get {}", cmd.table)?,
65 Action::Set => write!(self.buffer, "set {}", cmd.table)?,
66 Action::Del => write!(self.buffer, "del {}", cmd.table)?,
67 Action::Add => write!(self.buffer, "add {}", cmd.table)?,
68 _ => write!(self.buffer, "{} {}", cmd.action, cmd.table)?, }
70 writeln!(self.buffer)?;
71
72 if !cmd.columns.is_empty() {
86 if !(cmd.columns.len() == 1 && matches!(cmd.columns[0], Expr::Star)) {
90 self.indent()?;
91 writeln!(self.buffer, "fields")?;
92 self.indent_level += 1;
93 for (i, col) in cmd.columns.iter().enumerate() {
94 self.indent()?;
95 self.format_column(col)?;
96 if i < cmd.columns.len() - 1 {
97 writeln!(self.buffer, ",")?;
98 } else {
99 writeln!(self.buffer)?;
100 }
101 }
102 self.indent_level -= 1;
103 }
104 }
105
106 for join in &cmd.joins {
108 self.indent()?;
109 self.format_join(join)?;
110 writeln!(self.buffer)?;
111 }
112
113 let filters: Vec<&Cage> = cmd
115 .cages
116 .iter()
117 .filter(|c| matches!(c.kind, CageKind::Filter))
118 .collect();
119 if !filters.is_empty() {
120 self.indent()?;
123 write!(self.buffer, "where ")?;
124 for (i, cage) in filters.iter().enumerate() {
125 if i > 0 {
126 write!(self.buffer, " and ")?; }
128 self.format_conditions(&cage.conditions, cage.logical_op)?;
129 }
130 writeln!(self.buffer)?;
131 }
132
133 let sorts: Vec<&Cage> = cmd
135 .cages
136 .iter()
137 .filter(|c| matches!(c.kind, CageKind::Sort(_)))
138 .collect();
139 if !sorts.is_empty() {
140 self.indent()?;
141 writeln!(self.buffer, "order by")?;
142 self.indent_level += 1;
143 for (i, cage) in sorts.iter().enumerate() {
144 if let CageKind::Sort(order) = cage.kind {
145 for (j, cond) in cage.conditions.iter().enumerate() {
146 self.indent()?;
147 write!(self.buffer, "{}", cond.left)?;
148 self.format_sort_order(order)?;
149 if i < sorts.len() - 1 || j < cage.conditions.len() - 1 {
150 writeln!(self.buffer, ",")?;
151 } else {
152 writeln!(self.buffer)?;
153 }
154 }
155 }
156 }
157 self.indent_level -= 1;
158 }
159
160 for cage in &cmd.cages {
161 match cage.kind {
162 CageKind::Limit(n) => {
163 self.indent()?;
164 writeln!(self.buffer, "limit {}", n)?;
165 }
166 CageKind::Offset(n) => {
167 self.indent()?;
168 writeln!(self.buffer, "offset {}", n)?;
169 }
170 _ => {}
171 }
172 }
173
174 Ok(())
176 }
177
178 fn format_column(&mut self, col: &Expr) -> Result {
179 match col {
180 Expr::Star => write!(self.buffer, "*")?,
181 Expr::Named(name) => write!(self.buffer, "{}", name)?,
182 Expr::Aliased { name, alias } => write!(self.buffer, "{} as {}", name, alias)?,
183 Expr::Aggregate {
184 col,
185 func,
186 distinct,
187 filter,
188 alias,
189 } => {
190 let func_name = match func {
191 crate::ast::AggregateFunc::Count => "count",
192 crate::ast::AggregateFunc::Sum => "sum",
193 crate::ast::AggregateFunc::Avg => "avg",
194 crate::ast::AggregateFunc::Min => "min",
195 crate::ast::AggregateFunc::Max => "max",
196 crate::ast::AggregateFunc::ArrayAgg => "array_agg",
197 crate::ast::AggregateFunc::StringAgg => "string_agg",
198 crate::ast::AggregateFunc::JsonAgg => "json_agg",
199 crate::ast::AggregateFunc::JsonbAgg => "jsonb_agg",
200 crate::ast::AggregateFunc::BoolAnd => "bool_and",
201 crate::ast::AggregateFunc::BoolOr => "bool_or",
202 };
203 if *distinct {
204 write!(self.buffer, "{}(distinct {})", func_name, col)?;
205 } else {
206 write!(self.buffer, "{}({})", func_name, col)?;
207 }
208 if let Some(conditions) = filter {
209 write!(
210 self.buffer,
211 " filter (where {})",
212 conditions
213 .iter()
214 .map(|c| c.to_string())
215 .collect::<Vec<_>>()
216 .join(" and ")
217 )?;
218 }
219 if let Some(a) = alias {
220 write!(self.buffer, " as {}", a)?;
221 }
222 }
223 Expr::FunctionCall { name, args, alias } => {
224 let args_str: Vec<String> = args.iter().map(|a| a.to_string()).collect();
225 write!(self.buffer, "{}({})", name, args_str.join(", "))?;
226 if let Some(a) = alias {
227 write!(self.buffer, " as {}", a)?;
228 }
229 }
230 Expr::Window { name, func, params, partition, .. } => {
231 let params_str: Vec<String> = params.iter().map(|p| p.to_string()).collect();
233 write!(self.buffer, "{}({})", func, params_str.join(", "))?;
234 if !partition.is_empty() {
235 write!(self.buffer, " over (partition by {})", partition.join(", "))?;
236 }
237 write!(self.buffer, " as {}", name)?;
238 }
239 Expr::Case { when_clauses, else_value, alias } => {
240 write!(self.buffer, "case")?;
241 for (cond, val) in when_clauses {
242 write!(self.buffer, " when {} then {}", cond.left, val)?;
243 }
244 if let Some(e) = else_value {
245 write!(self.buffer, " else {}", e)?;
246 }
247 write!(self.buffer, " end")?;
248 if let Some(a) = alias {
249 write!(self.buffer, " as {}", a)?;
250 }
251 }
252 Expr::JsonAccess { column, path_segments, alias } => {
253 write!(self.buffer, "{}", column)?;
254 for (path, as_text) in path_segments {
255 let op = if *as_text { "->>" } else { "->" };
256 if path.parse::<i64>().is_ok() {
257 write!(self.buffer, "{}{}", op, path)?;
258 } else {
259 write!(self.buffer, "{}'{}'", op, path)?;
260 }
261 }
262 if let Some(a) = alias {
263 write!(self.buffer, " as {}", a)?;
264 }
265 }
266 Expr::Cast { expr, target_type, alias } => {
267 write!(self.buffer, "{}::{}", expr, target_type)?;
268 if let Some(a) = alias {
269 write!(self.buffer, " as {}", a)?;
270 }
271 }
272 Expr::Binary { left, op, right, alias } => {
273 write!(self.buffer, "({} {} {})", left, op, right)?;
274 if let Some(a) = alias {
275 write!(self.buffer, " as {}", a)?;
276 }
277 }
278 Expr::SpecialFunction { name, args, alias } => {
279 write!(self.buffer, "{}(", name)?;
280 for (i, (keyword, expr)) in args.iter().enumerate() {
281 if i > 0 { write!(self.buffer, " ")?; }
282 if let Some(kw) = keyword {
283 write!(self.buffer, "{} ", kw)?;
284 }
285 write!(self.buffer, "{}", expr)?;
286 }
287 write!(self.buffer, ")")?;
288 if let Some(a) = alias {
289 write!(self.buffer, " as {}", a)?;
290 }
291 }
292 Expr::Literal(val) => self.format_value(val)?,
293 Expr::Def { name, data_type, constraints } => {
294 write!(self.buffer, "{}:{}", name, data_type)?;
295 for c in constraints {
296 write!(self.buffer, "^{}", c)?;
297 }
298 }
299 Expr::Mod { kind, col } => {
300 let prefix = match kind { crate::ast::ModKind::Add => "+", crate::ast::ModKind::Drop => "-" };
301 write!(self.buffer, "{}{}", prefix, col)?;
302 }
303 }
304 Ok(())
305 }
306
307 fn format_join(&mut self, join: &Join) -> Result {
308 match join.kind {
309 crate::ast::JoinKind::Inner => write!(self.buffer, "join {}", join.table)?,
310 crate::ast::JoinKind::Left => write!(self.buffer, "left join {}", join.table)?,
311 crate::ast::JoinKind::Right => write!(self.buffer, "right join {}", join.table)?,
312 crate::ast::JoinKind::Full => write!(self.buffer, "full join {}", join.table)?,
313 crate::ast::JoinKind::Cross => write!(self.buffer, "cross join {}", join.table)?,
314 crate::ast::JoinKind::Lateral => write!(self.buffer, "lateral join {}", join.table)?,
315 }
316
317 if let Some(conditions) = &join.on
318 && !conditions.is_empty()
319 {
320 writeln!(self.buffer)?;
321 self.indent_level += 1;
322 self.indent()?;
323 write!(self.buffer, "on ")?;
324 self.format_conditions(conditions, LogicalOp::And)?;
325 self.indent_level -= 1;
326 }
327 Ok(())
328 }
329
330 fn format_conditions(&mut self, conditions: &[Condition], logical_op: LogicalOp) -> Result {
331 for (i, cond) in conditions.iter().enumerate() {
332 if i > 0 {
333 match logical_op {
334 LogicalOp::And => write!(self.buffer, " and ")?,
335 LogicalOp::Or => write!(self.buffer, " or ")?,
336 }
337 }
338
339 write!(self.buffer, "{}", cond.left)?;
340
341 match cond.op {
342 Operator::Eq => write!(self.buffer, " = ")?,
343 Operator::Ne => write!(self.buffer, " != ")?,
344 Operator::Gt => write!(self.buffer, " > ")?,
345 Operator::Gte => write!(self.buffer, " >= ")?,
346 Operator::Lt => write!(self.buffer, " < ")?,
347 Operator::Lte => write!(self.buffer, " <= ")?,
348 Operator::Fuzzy => write!(self.buffer, " ~ ")?, Operator::In => write!(self.buffer, " in ")?,
350 Operator::NotIn => write!(self.buffer, " not in ")?,
351 Operator::IsNull => write!(self.buffer, " is null")?,
352 Operator::IsNotNull => write!(self.buffer, " is not null")?,
353 Operator::Contains => write!(self.buffer, " @> ")?,
354 Operator::KeyExists => write!(self.buffer, " ? ")?,
355 _ => write!(self.buffer, " {:?} ", cond.op)?,
356 }
357
358 if !matches!(cond.op, Operator::IsNull | Operator::IsNotNull) {
360 self.format_value(&cond.value)?;
361 }
362 }
363 Ok(())
364 }
365
366 fn format_value(&mut self, val: &Value) -> Result {
367 match val {
368 Value::Null => write!(self.buffer, "null")?,
369 Value::Bool(b) => write!(self.buffer, "{}", b)?,
370 Value::Int(n) => write!(self.buffer, "{}", n)?,
371 Value::Float(n) => write!(self.buffer, "{}", n)?,
372 Value::Param(n) => write!(self.buffer, "${}", n)?,
373 Value::Function(f) => write!(self.buffer, "{}", f)?,
374 Value::Column(c) => write!(self.buffer, "{}", c)?,
375 Value::String(s) => write!(self.buffer, "'{}'", s)?, Value::Array(arr) => {
380 write!(self.buffer, "[")?;
381 for (i, v) in arr.iter().enumerate() {
382 if i > 0 {
383 write!(self.buffer, ", ")?;
384 }
385 self.format_value(v)?;
386 }
387 write!(self.buffer, "]")?;
388 }
389 Value::NamedParam(name) => write!(self.buffer, ":{}", name)?,
390 Value::Uuid(u) => write!(self.buffer, "'{}'", u)?,
391 Value::NullUuid => write!(self.buffer, "null")?,
392 Value::Interval { amount, unit } => write!(self.buffer, "interval '{} {}'", amount, unit)?,
393 Value::Timestamp(ts) => write!(self.buffer, "'{}'", ts)?,
394 Value::Bytes(bytes) => {
395 write!(self.buffer, "'\\x")?;
396 for byte in bytes { write!(self.buffer, "{:02x}", byte)?; }
397 write!(self.buffer, "'")?;
398 }
399 Value::Subquery(cmd) => {
400 write!(self.buffer, "(")?;
401 self.visit_cmd(cmd)?;
402 write!(self.buffer, ")")?;
403 }
404 Value::Expr(expr) => write!(self.buffer, "{}", expr)?,
405 }
406 Ok(())
407 }
408
409 fn format_sort_order(&mut self, order: SortOrder) -> Result {
410 match order {
411 SortOrder::Asc => {}
412 SortOrder::Desc => write!(self.buffer, " desc")?,
413 SortOrder::AscNullsFirst => write!(self.buffer, " nulls first")?,
414 SortOrder::AscNullsLast => write!(self.buffer, " nulls last")?,
415 SortOrder::DescNullsFirst => write!(self.buffer, " desc nulls first")?,
416 SortOrder::DescNullsLast => write!(self.buffer, " desc nulls last")?,
417 }
418 Ok(())
419 }
420}