1use crate::{Transform, TransformError, TransformerCategory};
2
/// Stateless transformer that formats SQL text.
///
/// All behaviour comes from its `Transform` implementation, which delegates
/// the actual formatting to `format_sql`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SqlFormatter;
6
7impl Transform for SqlFormatter {
8 fn name(&self) -> &'static str {
9 "SQL Formatter"
10 }
11
12 fn id(&self) -> &'static str {
13 "sqlformatter"
14 }
15
16 fn description(&self) -> &'static str {
17 "Formats SQL queries with proper indentation and spacing"
18 }
19
20 fn category(&self) -> TransformerCategory {
21 TransformerCategory::Formatter
22 }
23
24 fn default_test_input(&self) -> &'static str {
25 "SELECT id, username, email FROM users WHERE status = 'active' AND created_at > '2023-01-01' ORDER BY created_at DESC LIMIT 10"
26 }
27
28 fn transform(&self, input: &str) -> Result<String, TransformError> {
29 if input.trim().is_empty() {
31 return Ok(String::new());
32 }
33
34 format_sql(input)
35 }
36}
37
/// Coarse class of the token most recently handled by `format_sql`.
///
/// `format_sql` records one of these after each token and consults it when
/// deciding whether a space is needed in front of a string literal.
enum SqlTokenType {
    /// A word recognised by `is_sql_keyword`.
    Keyword,
    /// Any other word (letters, digits, `_`, `@`, `#`, `$`).
    Identifier,
    /// A single- or double-quoted literal.
    String,
    /// A run of digits and dots.
    Number,
    /// One of the characters `+-*/=%<>!|&`.
    Operator,
    /// A comma, or any character not covered by the other classes.
    Punctuation,
    /// Whitespace in the input; also the initial state before any token.
    Whitespace,
    /// An opening or closing parenthesis.
    Parenthesis,
}
48
/// Keywords that `format_sql` moves onto a new line (unless the current line
/// is still empty).
///
/// NOTE(review): `format_sql` looks words up one at a time, and a word never
/// contains a space, so the multi-word entries here (`LEFT JOIN`, `GROUP BY`,
/// `ORDER BY`, `UNION ALL`, ...) cannot match as written. Only the single-word
/// entries take effect; e.g. `GROUP BY` and `ORDER BY` currently stay on the
/// preceding line, and `LEFT JOIN` breaks in front of `JOIN`, not `LEFT`.
const NEWLINE_KEYWORDS: [&str; 16] = [
    "FROM",
    "WHERE",
    "LEFT JOIN",
    "RIGHT JOIN",
    "INNER JOIN",
    "OUTER JOIN",
    "FULL JOIN",
    "CROSS JOIN",
    "JOIN",
    "GROUP BY",
    "HAVING",
    "ORDER BY",
    "LIMIT",
    "UNION",
    "UNION ALL",
    "INTERSECT",
];
68
/// Statement-level keywords. `format_sql` starts a new line in front of one of
/// these whenever the current output line already has content.
const MAJOR_KEYWORDS: [&str; 7] = [
    "SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP",
];
73
74fn format_sql(input: &str) -> Result<String, TransformError> {
76 let mut result = String::with_capacity(input.len() * 2);
77 let mut input_chars = input.chars().peekable();
78 let mut indent_level: usize = 0;
79 let mut at_beginning_of_line = true;
80 let mut previous_token_type = SqlTokenType::Whitespace;
81 let mut buffer = String::new();
82 let mut in_string = false;
83 let mut string_quote_char = '"';
84 let mut in_comment = false;
85 let mut in_multiline_comment = false;
86 let mut pending_whitespace = false;
87
88 while let Some(c) = input_chars.next() {
89 if (c == '\'' || c == '"') && !in_comment && !in_multiline_comment {
91 if !in_string {
92 in_string = true;
94 string_quote_char = c;
95
96 if !matches!(
98 previous_token_type,
99 SqlTokenType::Whitespace | SqlTokenType::Operator | SqlTokenType::Parenthesis
100 ) {
101 result.push(' ');
102 }
103
104 result.push(c);
105 } else if c == string_quote_char {
106 if input_chars.peek() == Some(&c) {
108 result.push(c);
110 input_chars.next(); result.push(c);
112 } else {
113 in_string = false;
115 result.push(c);
116 }
117 } else {
118 result.push(c);
120 }
121 previous_token_type = SqlTokenType::String;
122 continue;
123 }
124
125 if in_string {
127 result.push(c);
128 continue;
129 }
130
131 if c == '-' && input_chars.peek() == Some(&'-') && !in_multiline_comment {
133 in_comment = true;
134 if !at_beginning_of_line {
135 result.push(' ');
136 }
137 result.push(c);
138 continue;
139 }
140
141 if in_comment {
142 result.push(c);
143 if c == '\n' {
144 in_comment = false;
145 at_beginning_of_line = true;
146
147 result.push_str(&" ".repeat(indent_level));
149 }
150 continue;
151 }
152
153 if c == '/' && input_chars.peek() == Some(&'*') && !in_comment {
155 in_multiline_comment = true;
156 if !at_beginning_of_line {
157 result.push(' ');
158 }
159 result.push(c);
160 continue;
161 }
162
163 if in_multiline_comment {
164 result.push(c);
165 if c == '*' && input_chars.peek() == Some(&'/') {
166 input_chars.next(); result.push('/');
168 in_multiline_comment = false;
169 }
170 continue;
171 }
172
173 if c.is_whitespace() {
175 if at_beginning_of_line && c != '\n' {
176 continue;
178 }
179
180 if c == '\n' {
181 if !at_beginning_of_line {
183 result.push('\n');
184 at_beginning_of_line = true;
185
186 result.push_str(&" ".repeat(indent_level));
188 }
189 } else if !at_beginning_of_line {
190 pending_whitespace = true;
192 }
193
194 previous_token_type = SqlTokenType::Whitespace;
195 continue;
196 }
197
198 if c == '(' {
200 if pending_whitespace && !at_beginning_of_line {
201 result.push(' ');
202 }
203 pending_whitespace = false;
204
205 result.push(c);
206 indent_level += 1;
207
208 result.push('\n');
210
211 result.push_str(&" ".repeat(indent_level));
213
214 at_beginning_of_line = true;
216
217 previous_token_type = SqlTokenType::Parenthesis;
218 continue;
219 }
220
221 if c == ')' {
222 pending_whitespace = false;
223
224 if !at_beginning_of_line {
226 result.push('\n');
227 }
228
229 indent_level = indent_level.saturating_sub(1);
230
231 if at_beginning_of_line {
233 result.truncate(result.rfind('\n').map(|pos| pos + 1).unwrap_or(0));
235 }
236
237 result.push_str(&" ".repeat(indent_level));
238 result.push(c);
239
240 previous_token_type = SqlTokenType::Parenthesis;
241 at_beginning_of_line = false;
242 continue;
243 }
244
245 if c == ',' {
247 result.push(c);
248
249 result.push('\n');
251 at_beginning_of_line = true;
252
253 result.push_str(&" ".repeat(indent_level));
255
256 previous_token_type = SqlTokenType::Punctuation;
257 continue;
258 }
259
260 if "+-*/=%<>!|&".contains(c) {
261 if pending_whitespace {
262 result.push(' ');
263 }
264 pending_whitespace = false;
265
266 result.push(c);
267
268 if !matches!(input_chars.peek(), Some(&'=') | Some(&'>') | Some(&'<')) {
270 result.push(' ');
271 }
272
273 previous_token_type = SqlTokenType::Operator;
274 at_beginning_of_line = false;
275 continue;
276 }
277
278 if c.is_alphabetic() || c == '_' || c == '@' || c == '#' || c == '$' {
280 buffer.clear();
281 buffer.push(c);
282
283 while let Some(&next_c) = input_chars.peek() {
285 if next_c.is_alphanumeric()
286 || next_c == '_'
287 || next_c == '@'
288 || next_c == '#'
289 || next_c == '$'
290 {
291 buffer.push(next_c);
292 input_chars.next();
293 } else {
294 break;
295 }
296 }
297
298 let upper_buffer = buffer.to_uppercase();
300 let is_keyword = is_sql_keyword(&upper_buffer);
301
302 if is_keyword {
304 let needs_newline = NEWLINE_KEYWORDS.contains(&upper_buffer.as_str())
306 || (MAJOR_KEYWORDS.contains(&upper_buffer.as_str()) && !at_beginning_of_line);
307
308 if needs_newline && !at_beginning_of_line {
309 result.push('\n');
310
311 result.push_str(&" ".repeat(indent_level));
313 } else if pending_whitespace && !at_beginning_of_line {
314 result.push(' ');
315 }
316
317 pending_whitespace = false;
318
319 result.push_str(&upper_buffer);
321
322 result.push(' ');
324
325 previous_token_type = SqlTokenType::Keyword;
326 } else {
327 if pending_whitespace && !at_beginning_of_line {
329 result.push(' ');
330 }
331 pending_whitespace = false;
332
333 result.push_str(&buffer);
335
336 previous_token_type = SqlTokenType::Identifier;
337 }
338
339 at_beginning_of_line = false;
340 continue;
341 }
342
343 if c.is_numeric() || (c == '.' && input_chars.peek().is_some_and(|p| p.is_numeric())) {
345 if pending_whitespace && !at_beginning_of_line {
346 result.push(' ');
347 }
348 pending_whitespace = false;
349
350 result.push(c);
351
352 while let Some(&next_c) = input_chars.peek() {
354 if next_c.is_numeric() || next_c == '.' {
355 result.push(next_c);
356 input_chars.next();
357 } else {
358 break;
359 }
360 }
361
362 previous_token_type = SqlTokenType::Number;
363 at_beginning_of_line = false;
364 continue;
365 }
366
367 if pending_whitespace && !at_beginning_of_line {
369 result.push(' ');
370 }
371 pending_whitespace = false;
372
373 result.push(c);
374 at_beginning_of_line = false;
375
376 previous_token_type = SqlTokenType::Punctuation;
378 }
379
380 Ok(result)
381}
382
/// Reports whether `word` is one of the SQL keywords this formatter knows.
///
/// The comparison is exact and case-sensitive: only the upper-case spellings
/// listed below match, so callers are expected to upper-case the word first.
fn is_sql_keyword(word: &str) -> bool {
    matches!(
        word,
        "SELECT"
            | "FROM"
            | "WHERE"
            | "INSERT"
            | "UPDATE"
            | "DELETE"
            | "DROP"
            | "CREATE"
            | "ALTER"
            | "TABLE"
            | "VIEW"
            | "INDEX"
            | "TRIGGER"
            | "PROCEDURE"
            | "FUNCTION"
            | "DATABASE"
            | "SCHEMA"
            | "GRANT"
            | "REVOKE"
            | "JOIN"
            | "INNER"
            | "OUTER"
            | "LEFT"
            | "RIGHT"
            | "FULL"
            | "CROSS"
            | "NATURAL"
            | "GROUP"
            | "ORDER"
            | "BY"
            | "HAVING"
            | "UNION"
            | "ALL"
            | "INTERSECT"
            | "EXCEPT"
            | "INTO"
            | "VALUES"
            | "SET"
            | "AS"
            | "ON"
            | "AND"
            | "OR"
            | "NOT"
            | "NULL"
            | "IS"
            | "IN"
            | "BETWEEN"
            | "LIKE"
            | "EXISTS"
            | "CASE"
            | "WHEN"
            | "THEN"
            | "ELSE"
            | "END"
            | "ASC"
            | "DESC"
            | "LIMIT"
            | "OFFSET"
            | "WITH"
    )
}
450
#[cfg(test)]
mod tests {
    use super::*;

    // The expected strings below are the formatter's exact output, spacing
    // included. `format_sql` writes a space after every keyword and operator
    // and then writes the input's own whitespace again in front of the next
    // word, so those positions carry two spaces; indentation is one space per
    // open parenthesis. If that spacing is ever tightened, update the strings.

    #[test]
    fn test_sql_formatter_empty() {
        let transformer = SqlFormatter;
        assert_eq!(transformer.transform("").unwrap(), "");
        assert_eq!(transformer.transform(" ").unwrap(), "");
    }

    #[test]
    fn test_sql_formatter_simple_select() {
        let transformer = SqlFormatter;
        let input = "SELECT id, name, email FROM users WHERE active = true ORDER BY name";

        // `ORDER` and `BY` are looked up as separate words, so `ORDER BY`
        // stays on the `WHERE` line.
        let expected =
            "SELECT  id,\nname,\nemail\nFROM  users\nWHERE  active =  true ORDER  BY  name";
        assert_eq!(transformer.transform(input).unwrap(), expected);
    }

    #[test]
    fn test_sql_formatter_joins() {
        let transformer = SqlFormatter;
        let input = "SELECT u.id, u.name, o.order_date FROM users u JOIN orders o ON u.id = o.user_id WHERE o.total > 100";

        let expected = "SELECT  u.id,\nu.name,\no.order_date\nFROM  users u\nJOIN  orders o ON  u.id =  o.user_id\nWHERE  o.total >  100";
        assert_eq!(transformer.transform(input).unwrap(), expected);
    }

    #[test]
    fn test_sql_formatter_nested_queries() {
        let transformer = SqlFormatter;
        let input = "SELECT * FROM (SELECT id, COUNT(*) as count FROM orders GROUP BY id) AS subquery WHERE count > 5";

        // Every `(` opens a new line, including the one in `COUNT(*)`, whose
        // `*` therefore sits at indentation depth 2.
        let expected = "SELECT  * \nFROM  (\n SELECT  id,\n COUNT(\n  * \n ) AS  count\n FROM  orders GROUP  BY  id\n) AS  subquery\nWHERE  count >  5";
        assert_eq!(transformer.transform(input).unwrap(), expected);
    }

    #[test]
    fn test_sql_formatter_string_literals() {
        let transformer = SqlFormatter;
        let input = "SELECT * FROM users WHERE name = 'John''s' AND department = \"Sales\"";

        // Literals are copied verbatim, including the doubled-quote escape.
        let expected =
            "SELECT  * \nFROM  users\nWHERE  name = 'John''s' AND  department = \"Sales\"";
        assert_eq!(transformer.transform(input).unwrap(), expected);
    }
}