buup/transformers/
sql_formatter.rs

1use crate::{Transform, TransformError, TransformerCategory};
2
/// Transformer that pretty-prints SQL queries.
///
/// The type carries no data; all behaviour lives in its `Transform`
/// implementation, so a value is simply written as `SqlFormatter`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct SqlFormatter;
6
7impl Transform for SqlFormatter {
8    fn name(&self) -> &'static str {
9        "SQL Formatter"
10    }
11
12    fn id(&self) -> &'static str {
13        "sqlformatter"
14    }
15
16    fn description(&self) -> &'static str {
17        "Formats SQL queries with proper indentation and spacing"
18    }
19
20    fn category(&self) -> TransformerCategory {
21        TransformerCategory::Formatter
22    }
23
24    fn default_test_input(&self) -> &'static str {
25        "SELECT id, username, email FROM users WHERE status = 'active' AND created_at > '2023-01-01' ORDER BY created_at DESC LIMIT 10"
26    }
27
28    fn transform(&self, input: &str) -> Result<String, TransformError> {
29        // Skip empty input
30        if input.trim().is_empty() {
31            return Ok(String::new());
32        }
33
34        format_sql(input)
35    }
36}
37
/// Coarse classification of the most recently handled token.
///
/// `format_sql` records this after each token and consults it when deciding
/// whether a quoted literal needs a separating space.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SqlTokenType {
    /// A word accepted by `is_sql_keyword` (emitted upper-cased).
    Keyword,
    /// Any other word, e.g. a table or column name.
    Identifier,
    /// A single- or double-quoted literal.
    String,
    /// A numeric literal.
    Number,
    /// One of the operator characters `+-*/=%<>!|&`.
    Operator,
    /// A comma or any character not otherwise classified.
    Punctuation,
    /// Whitespace from the input (also the initial state).
    Whitespace,
    /// `(` or `)`.
    Parenthesis,
}
48
/// Keywords that are moved onto a new line when they appear mid-line.
///
/// NOTE(review): the formatter looks keywords up one scanned word at a time,
/// so the two-word entries ("LEFT JOIN", "GROUP BY", "ORDER BY", "UNION ALL",
/// ...) never match and e.g. `ORDER BY` currently stays on the previous line.
const NEWLINE_KEYWORDS: [&str; 16] = [
    // Clauses
    "FROM", "WHERE",
    // Joins
    "LEFT JOIN", "RIGHT JOIN", "INNER JOIN", "OUTER JOIN", "FULL JOIN", "CROSS JOIN", "JOIN",
    // Grouping, filtering, ordering and paging
    "GROUP BY", "HAVING", "ORDER BY", "LIMIT",
    // Set operations
    "UNION", "UNION ALL", "INTERSECT",
];
68
/// Statement-leading keywords; each starts a new line unless it is already
/// the first token on the current line.
const MAJOR_KEYWORDS: [&str; 7] = [
    "SELECT",
    "INSERT",
    "UPDATE",
    "DELETE",
    "CREATE",
    "ALTER",
    "DROP",
];
73
74// Format SQL query with proper indentation and spacing
75fn format_sql(input: &str) -> Result<String, TransformError> {
76    let mut result = String::with_capacity(input.len() * 2);
77    let mut input_chars = input.chars().peekable();
78    let mut indent_level: usize = 0;
79    let mut at_beginning_of_line = true;
80    let mut previous_token_type = SqlTokenType::Whitespace;
81    let mut buffer = String::new();
82    let mut in_string = false;
83    let mut string_quote_char = '"';
84    let mut in_comment = false;
85    let mut in_multiline_comment = false;
86    let mut pending_whitespace = false;
87
88    while let Some(c) = input_chars.next() {
89        // Handle strings (quoted literals)
90        if (c == '\'' || c == '"') && !in_comment && !in_multiline_comment {
91            if !in_string {
92                // Starting a string
93                in_string = true;
94                string_quote_char = c;
95
96                // Add a space before string if needed
97                if !matches!(
98                    previous_token_type,
99                    SqlTokenType::Whitespace | SqlTokenType::Operator | SqlTokenType::Parenthesis
100                ) {
101                    result.push(' ');
102                }
103
104                result.push(c);
105            } else if c == string_quote_char {
106                // Check for escaped quotes
107                if input_chars.peek() == Some(&c) {
108                    // This is an escaped quote within the string
109                    result.push(c);
110                    input_chars.next(); // Consume the second quote
111                    result.push(c);
112                } else {
113                    // End of string
114                    in_string = false;
115                    result.push(c);
116                }
117            } else {
118                // Just a quote character inside a string delimited by a different quote
119                result.push(c);
120            }
121            previous_token_type = SqlTokenType::String;
122            continue;
123        }
124
125        // Inside a string - add all characters as-is
126        if in_string {
127            result.push(c);
128            continue;
129        }
130
131        // Handle single-line comments
132        if c == '-' && input_chars.peek() == Some(&'-') && !in_multiline_comment {
133            in_comment = true;
134            if !at_beginning_of_line {
135                result.push(' ');
136            }
137            result.push(c);
138            continue;
139        }
140
141        if in_comment {
142            result.push(c);
143            if c == '\n' {
144                in_comment = false;
145                at_beginning_of_line = true;
146
147                // Apply indentation at beginning of line
148                result.push_str(&"    ".repeat(indent_level));
149            }
150            continue;
151        }
152
153        // Handle multi-line comments
154        if c == '/' && input_chars.peek() == Some(&'*') && !in_comment {
155            in_multiline_comment = true;
156            if !at_beginning_of_line {
157                result.push(' ');
158            }
159            result.push(c);
160            continue;
161        }
162
163        if in_multiline_comment {
164            result.push(c);
165            if c == '*' && input_chars.peek() == Some(&'/') {
166                input_chars.next(); // Consume the '/'
167                result.push('/');
168                in_multiline_comment = false;
169            }
170            continue;
171        }
172
173        // Handle whitespace
174        if c.is_whitespace() {
175            if at_beginning_of_line && c != '\n' {
176                // Skip leading whitespace
177                continue;
178            }
179
180            if c == '\n' {
181                // Handle newlines
182                if !at_beginning_of_line {
183                    result.push('\n');
184                    at_beginning_of_line = true;
185
186                    // Apply indentation at beginning of new line
187                    result.push_str(&"    ".repeat(indent_level));
188                }
189            } else if !at_beginning_of_line {
190                // Collapse multiple spaces into one
191                pending_whitespace = true;
192            }
193
194            previous_token_type = SqlTokenType::Whitespace;
195            continue;
196        }
197
198        // Handle parentheses
199        if c == '(' {
200            if pending_whitespace && !at_beginning_of_line {
201                result.push(' ');
202            }
203            pending_whitespace = false;
204
205            result.push(c);
206            indent_level += 1;
207
208            // Add newline after opening parenthesis
209            result.push('\n');
210
211            // Apply indentation for the next line
212            result.push_str(&"    ".repeat(indent_level));
213
214            // We are now at the beginning of a line
215            at_beginning_of_line = true;
216
217            previous_token_type = SqlTokenType::Parenthesis;
218            continue;
219        }
220
221        if c == ')' {
222            pending_whitespace = false;
223
224            // Add newline before closing parenthesis if not at the beginning of a line
225            if !at_beginning_of_line {
226                result.push('\n');
227            }
228
229            indent_level = indent_level.saturating_sub(1);
230
231            // Apply indentation for the closing parenthesis
232            if at_beginning_of_line {
233                // Remove previous indentation and apply the updated one
234                result.truncate(result.rfind('\n').map(|pos| pos + 1).unwrap_or(0));
235            }
236
237            result.push_str(&"    ".repeat(indent_level));
238            result.push(c);
239
240            previous_token_type = SqlTokenType::Parenthesis;
241            at_beginning_of_line = false;
242            continue;
243        }
244
245        // Handle punctuation and operators
246        if c == ',' {
247            result.push(c);
248
249            // For SELECT statements, add newline after comma
250            result.push('\n');
251            at_beginning_of_line = true;
252
253            // Apply indentation for the next line
254            result.push_str(&"    ".repeat(indent_level));
255
256            previous_token_type = SqlTokenType::Punctuation;
257            continue;
258        }
259
260        if "+-*/=%<>!|&".contains(c) {
261            if pending_whitespace {
262                result.push(' ');
263            }
264            pending_whitespace = false;
265
266            result.push(c);
267
268            // Add space after operator (but not before checking for multi-char operators)
269            if !matches!(input_chars.peek(), Some(&'=') | Some(&'>') | Some(&'<')) {
270                result.push(' ');
271            }
272
273            previous_token_type = SqlTokenType::Operator;
274            at_beginning_of_line = false;
275            continue;
276        }
277
278        // Handle keywords and identifiers
279        if c.is_alphabetic() || c == '_' || c == '@' || c == '#' || c == '$' {
280            buffer.clear();
281            buffer.push(c);
282
283            // Collect the entire identifier or keyword
284            while let Some(&next_c) = input_chars.peek() {
285                if next_c.is_alphanumeric()
286                    || next_c == '_'
287                    || next_c == '@'
288                    || next_c == '#'
289                    || next_c == '$'
290                {
291                    buffer.push(next_c);
292                    input_chars.next();
293                } else {
294                    break;
295                }
296            }
297
298            // Check if it's a keyword
299            let upper_buffer = buffer.to_uppercase();
300            let is_keyword = is_sql_keyword(&upper_buffer);
301
302            // Handle keyword formatting
303            if is_keyword {
304                // Determine if we need a newline before this keyword
305                let needs_newline = NEWLINE_KEYWORDS.contains(&upper_buffer.as_str())
306                    || (MAJOR_KEYWORDS.contains(&upper_buffer.as_str()) && !at_beginning_of_line);
307
308                if needs_newline && !at_beginning_of_line {
309                    result.push('\n');
310
311                    // Apply indentation for this line
312                    result.push_str(&"    ".repeat(indent_level));
313                } else if pending_whitespace && !at_beginning_of_line {
314                    result.push(' ');
315                }
316
317                pending_whitespace = false;
318
319                // Add the keyword in uppercase
320                result.push_str(&upper_buffer);
321
322                // Make sure there's a space after keywords
323                result.push(' ');
324
325                previous_token_type = SqlTokenType::Keyword;
326            } else {
327                // It's an identifier
328                if pending_whitespace && !at_beginning_of_line {
329                    result.push(' ');
330                }
331                pending_whitespace = false;
332
333                // Add the identifier as-is
334                result.push_str(&buffer);
335
336                previous_token_type = SqlTokenType::Identifier;
337            }
338
339            at_beginning_of_line = false;
340            continue;
341        }
342
343        // Handle numbers
344        if c.is_numeric() || (c == '.' && input_chars.peek().is_some_and(|p| p.is_numeric())) {
345            if pending_whitespace && !at_beginning_of_line {
346                result.push(' ');
347            }
348            pending_whitespace = false;
349
350            result.push(c);
351
352            // Collect the rest of the number
353            while let Some(&next_c) = input_chars.peek() {
354                if next_c.is_numeric() || next_c == '.' {
355                    result.push(next_c);
356                    input_chars.next();
357                } else {
358                    break;
359                }
360            }
361
362            previous_token_type = SqlTokenType::Number;
363            at_beginning_of_line = false;
364            continue;
365        }
366
367        // Handle any other characters
368        if pending_whitespace && !at_beginning_of_line {
369            result.push(' ');
370        }
371        pending_whitespace = false;
372
373        result.push(c);
374        at_beginning_of_line = false;
375
376        // Most likely punctuation
377        previous_token_type = SqlTokenType::Punctuation;
378    }
379
380    Ok(result)
381}
382
/// Returns `true` when `word` is one of the recognised SQL keywords.
///
/// The comparison is exact and case-sensitive, so callers are expected to
/// upper-case the word first (as `format_sql` does).
fn is_sql_keyword(word: &str) -> bool {
    matches!(
        word,
        // Statements and schema objects
        "SELECT" | "FROM" | "WHERE" | "INSERT" | "UPDATE" | "DELETE" | "DROP" | "CREATE"
            | "ALTER" | "TABLE" | "VIEW" | "INDEX" | "TRIGGER" | "PROCEDURE" | "FUNCTION"
            | "DATABASE" | "SCHEMA" | "GRANT" | "REVOKE"
            // Joins
            | "JOIN" | "INNER" | "OUTER" | "LEFT" | "RIGHT" | "FULL" | "CROSS" | "NATURAL"
            // Grouping, ordering and set operations
            | "GROUP" | "ORDER" | "BY" | "HAVING" | "UNION" | "ALL" | "INTERSECT" | "EXCEPT"
            // Data manipulation clauses
            | "INTO" | "VALUES" | "SET" | "AS" | "ON"
            // Predicates and logic
            | "AND" | "OR" | "NOT" | "NULL" | "IS" | "IN" | "BETWEEN" | "LIKE" | "EXISTS"
            // Conditional expressions
            | "CASE" | "WHEN" | "THEN" | "ELSE" | "END"
            // Sorting, paging and CTEs
            | "ASC" | "DESC" | "LIMIT" | "OFFSET" | "WITH"
    )
}
450
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `SqlFormatter` on `input`, panicking if it reports an error.
    fn formatted(input: &str) -> String {
        SqlFormatter.transform(input).unwrap()
    }

    #[test]
    fn test_sql_formatter_empty() {
        for blank in ["", "  "] {
            assert_eq!(formatted(blank), "");
        }
    }

    #[test]
    fn test_sql_formatter_simple_select() {
        let output =
            formatted("SELECT id, name, email FROM users WHERE active = true ORDER BY name");

        // Pins the exact current output, including its double spacing.
        assert_eq!(
            output,
            "SELECT  id,\nname,\nemail\nFROM  users\nWHERE  active =  true ORDER  BY  name"
        );
    }

    #[test]
    fn test_sql_formatter_joins() {
        let output = formatted(
            "SELECT u.id, u.name, o.order_date FROM users u JOIN orders o ON u.id = o.user_id WHERE o.total > 100",
        );

        assert_eq!(
            output,
            "SELECT  u.id,\nu.name,\no.order_date\nFROM  users u\nJOIN  orders o ON  u.id =  o.user_id\nWHERE  o.total >  100"
        );
    }

    #[test]
    fn test_sql_formatter_nested_queries() {
        let output = formatted(
            "SELECT * FROM (SELECT id, COUNT(*) as count FROM orders GROUP BY id) AS subquery WHERE count > 5",
        );

        assert_eq!(
            output,
            "SELECT  * \nFROM  (\n    SELECT  id,\n    COUNT(\n        * \n    ) AS  count\n    FROM  orders GROUP  BY  id\n) AS  subquery\nWHERE  count >  5"
        );
    }

    #[test]
    fn test_sql_formatter_string_literals() {
        let output =
            formatted("SELECT * FROM users WHERE name = 'John''s' AND department = \"Sales\"");

        assert_eq!(
            output,
            "SELECT  * \nFROM  users\nWHERE  name = 'John''s' AND  department = \"Sales\""
        );
    }
}