1use crate::ast::{
2 Action, Cage, CageKind, Condition, Expr, Join, LogicalOp, Operator, QailCmd, SortOrder, Value,
3};
4use std::fmt::{Result, Write};
5
6#[cfg(test)]
7mod tests;
8
/// Pretty-printer that renders a [`QailCmd`] back into its text form.
///
/// Output is accumulated in an internal `String`; nested sections are indented
/// according to `indent_level`.
#[derive(Debug)]
pub struct Formatter {
    /// Current nesting depth; `indent()` writes one indent unit per level.
    indent_level: usize,
    /// Output accumulated so far; returned by `format`.
    buffer: String,
}
13
14impl Default for Formatter {
15 fn default() -> Self {
16 Self::new()
17 }
18}
19
20impl Formatter {
21 pub fn new() -> Self {
22 Self {
23 indent_level: 0,
24 buffer: String::new(),
25 }
26 }
27
28 pub fn format(mut self, cmd: &QailCmd) -> std::result::Result<String, std::fmt::Error> {
29 self.visit_cmd(cmd)?;
30 Ok(self.buffer)
31 }
32
33 fn indent(&mut self) -> Result {
34 for _ in 0..self.indent_level {
35 write!(self.buffer, " ")?;
36 }
37 Ok(())
38 }
39
40 fn visit_cmd(&mut self, cmd: &QailCmd) -> Result {
41 for cte in &cmd.ctes {
43 write!(self.buffer, "with {} = ", cte.name)?;
44 self.indent_level += 1;
45 writeln!(self.buffer)?;
46 self.indent()?;
47 self.visit_cmd(&cte.base_query)?;
48
49 if cte.recursive
51 && let Some(ref recursive_query) = cte.recursive_query
52 {
53 writeln!(self.buffer)?;
54 self.indent()?;
55 writeln!(self.buffer, "union all")?;
56 self.indent()?;
57 self.visit_cmd(recursive_query)?;
58 }
59
60 self.indent_level -= 1;
61 writeln!(self.buffer)?;
62 }
63
64 match cmd.action {
66 Action::Get => write!(self.buffer, "get {}", cmd.table)?,
67 Action::Set => write!(self.buffer, "set {}", cmd.table)?,
68 Action::Del => write!(self.buffer, "del {}", cmd.table)?,
69 Action::Add => write!(self.buffer, "add {}", cmd.table)?,
70 _ => write!(self.buffer, "{} {}", cmd.action, cmd.table)?, }
72 writeln!(self.buffer)?;
73
74 if !cmd.columns.is_empty() {
89 if !(cmd.columns.len() == 1 && matches!(cmd.columns[0], Expr::Star)) {
94 self.indent()?;
95 writeln!(self.buffer, "fields")?;
96 self.indent_level += 1;
97 for (i, col) in cmd.columns.iter().enumerate() {
98 self.indent()?;
99 self.format_column(col)?;
100 if i < cmd.columns.len() - 1 {
101 writeln!(self.buffer, ",")?;
102 } else {
103 writeln!(self.buffer)?;
104 }
105 }
106 self.indent_level -= 1;
107 }
108 }
109
110 for join in &cmd.joins {
112 self.indent()?;
113 self.format_join(join)?;
114 writeln!(self.buffer)?;
115 }
116
117 let filters: Vec<&Cage> = cmd
119 .cages
120 .iter()
121 .filter(|c| matches!(c.kind, CageKind::Filter))
122 .collect();
123 if !filters.is_empty() {
124 self.indent()?;
127 write!(self.buffer, "where ")?;
128 for (i, cage) in filters.iter().enumerate() {
129 if i > 0 {
130 write!(self.buffer, " and ")?; }
132 self.format_conditions(&cage.conditions, cage.logical_op)?;
133 }
134 writeln!(self.buffer)?;
135 }
136
137 let sorts: Vec<&Cage> = cmd
139 .cages
140 .iter()
141 .filter(|c| matches!(c.kind, CageKind::Sort(_)))
142 .collect();
143 if !sorts.is_empty() {
144 self.indent()?;
145 writeln!(self.buffer, "order by")?;
146 self.indent_level += 1;
147 for (i, cage) in sorts.iter().enumerate() {
148 if let CageKind::Sort(order) = cage.kind {
149 for (j, cond) in cage.conditions.iter().enumerate() {
150 self.indent()?;
151 write!(self.buffer, "{}", cond.left)?;
152 self.format_sort_order(order)?;
153 if i < sorts.len() - 1 || j < cage.conditions.len() - 1 {
154 writeln!(self.buffer, ",")?;
155 } else {
156 writeln!(self.buffer)?;
157 }
158 }
159 }
160 }
161 self.indent_level -= 1;
162 }
163
164 for cage in &cmd.cages {
166 match cage.kind {
167 CageKind::Limit(n) => {
168 self.indent()?;
169 writeln!(self.buffer, "limit {}", n)?;
170 }
171 CageKind::Offset(n) => {
172 self.indent()?;
173 writeln!(self.buffer, "offset {}", n)?;
174 }
175 _ => {}
176 }
177 }
178
179 Ok(())
181 }
182
183 fn format_column(&mut self, col: &Expr) -> Result {
184 match col {
185 Expr::Star => write!(self.buffer, "*")?,
186 Expr::Named(name) => write!(self.buffer, "{}", name)?,
187 Expr::Aliased { name, alias } => write!(self.buffer, "{} as {}", name, alias)?,
188 Expr::Aggregate {
189 col,
190 func,
191 distinct,
192 filter,
193 alias,
194 } => {
195 let func_name = match func {
196 crate::ast::AggregateFunc::Count => "count",
197 crate::ast::AggregateFunc::Sum => "sum",
198 crate::ast::AggregateFunc::Avg => "avg",
199 crate::ast::AggregateFunc::Min => "min",
200 crate::ast::AggregateFunc::Max => "max",
201 crate::ast::AggregateFunc::ArrayAgg => "array_agg",
202 crate::ast::AggregateFunc::StringAgg => "string_agg",
203 crate::ast::AggregateFunc::JsonAgg => "json_agg",
204 crate::ast::AggregateFunc::JsonbAgg => "jsonb_agg",
205 crate::ast::AggregateFunc::BoolAnd => "bool_and",
206 crate::ast::AggregateFunc::BoolOr => "bool_or",
207 };
208 if *distinct {
209 write!(self.buffer, "{}(distinct {})", func_name, col)?;
210 } else {
211 write!(self.buffer, "{}({})", func_name, col)?;
212 }
213 if let Some(conditions) = filter {
214 write!(
215 self.buffer,
216 " filter (where {})",
217 conditions
218 .iter()
219 .map(|c| c.to_string())
220 .collect::<Vec<_>>()
221 .join(" and ")
222 )?;
223 }
224 if let Some(a) = alias {
225 write!(self.buffer, " as {}", a)?;
226 }
227 }
228 Expr::FunctionCall { name, args, alias } => {
229 let args_str: Vec<String> = args.iter().map(|a| a.to_string()).collect();
230 write!(self.buffer, "{}({})", name, args_str.join(", "))?;
231 if let Some(a) = alias {
232 write!(self.buffer, " as {}", a)?;
233 }
234 }
235 _ => write!(self.buffer, "/* TODO: {:?} */", col)?,
237 }
238 Ok(())
239 }
240
241 fn format_join(&mut self, join: &Join) -> Result {
242 match join.kind {
243 crate::ast::JoinKind::Inner => write!(self.buffer, "join {}", join.table)?,
244 crate::ast::JoinKind::Left => write!(self.buffer, "left join {}", join.table)?,
245 crate::ast::JoinKind::Right => write!(self.buffer, "right join {}", join.table)?,
246 crate::ast::JoinKind::Full => write!(self.buffer, "full join {}", join.table)?,
247 crate::ast::JoinKind::Cross => write!(self.buffer, "cross join {}", join.table)?,
248 crate::ast::JoinKind::Lateral => write!(self.buffer, "lateral join {}", join.table)?,
249 }
250
251 if let Some(conditions) = &join.on
252 && !conditions.is_empty()
253 {
254 writeln!(self.buffer)?;
255 self.indent_level += 1;
256 self.indent()?;
257 write!(self.buffer, "on ")?;
258 self.format_conditions(conditions, LogicalOp::And)?;
259 self.indent_level -= 1;
260 }
261 Ok(())
262 }
263
264 fn format_conditions(&mut self, conditions: &[Condition], logical_op: LogicalOp) -> Result {
265 for (i, cond) in conditions.iter().enumerate() {
266 if i > 0 {
267 match logical_op {
268 LogicalOp::And => write!(self.buffer, " and ")?,
269 LogicalOp::Or => write!(self.buffer, " or ")?,
270 }
271 }
272
273 write!(self.buffer, "{}", cond.left)?;
274
275 match cond.op {
276 Operator::Eq => write!(self.buffer, " = ")?,
277 Operator::Ne => write!(self.buffer, " != ")?,
278 Operator::Gt => write!(self.buffer, " > ")?,
279 Operator::Gte => write!(self.buffer, " >= ")?,
280 Operator::Lt => write!(self.buffer, " < ")?,
281 Operator::Lte => write!(self.buffer, " <= ")?,
282 Operator::Fuzzy => write!(self.buffer, " ~ ")?, Operator::In => write!(self.buffer, " in ")?,
284 Operator::NotIn => write!(self.buffer, " not in ")?,
285 Operator::IsNull => write!(self.buffer, " is null")?,
286 Operator::IsNotNull => write!(self.buffer, " is not null")?,
287 Operator::Contains => write!(self.buffer, " @> ")?,
288 Operator::KeyExists => write!(self.buffer, " ? ")?,
289 _ => write!(self.buffer, " {:?} ", cond.op)?,
290 }
291
292 if !matches!(cond.op, Operator::IsNull | Operator::IsNotNull) {
294 self.format_value(&cond.value)?;
295 }
296 }
297 Ok(())
298 }
299
300 fn format_value(&mut self, val: &Value) -> Result {
301 match val {
302 Value::Null => write!(self.buffer, "null")?,
303 Value::Bool(b) => write!(self.buffer, "{}", b)?,
304 Value::Int(n) => write!(self.buffer, "{}", n)?,
305 Value::Float(n) => write!(self.buffer, "{}", n)?,
306 Value::Param(n) => write!(self.buffer, "${}", n)?,
307 Value::Function(f) => write!(self.buffer, "{}", f)?,
308 Value::Column(c) => write!(self.buffer, "{}", c)?,
309 Value::String(s) => write!(self.buffer, "'{}'", s)?, Value::Array(arr) => {
314 write!(self.buffer, "[")?;
315 for (i, v) in arr.iter().enumerate() {
316 if i > 0 {
317 write!(self.buffer, ", ")?;
318 }
319 self.format_value(v)?;
320 }
321 write!(self.buffer, "]")?;
322 }
323 _ => write!(self.buffer, "{:?}", val)?,
325 }
326 Ok(())
327 }
328
329 fn format_sort_order(&mut self, order: SortOrder) -> Result {
330 match order {
331 SortOrder::Asc => {}
332 SortOrder::Desc => write!(self.buffer, " desc")?,
333 SortOrder::AscNullsFirst => write!(self.buffer, " nulls first")?,
334 SortOrder::AscNullsLast => write!(self.buffer, " nulls last")?,
335 SortOrder::DescNullsFirst => write!(self.buffer, " desc nulls first")?,
336 SortOrder::DescNullsLast => write!(self.buffer, " desc nulls last")?,
337 }
338 Ok(())
339 }
340}