1use std::{
2 fs,
3 io,
4 path::Path,
5};
6
7use sqlformat::{format, Dialect, FormatOptions, QueryParams};
8use sqlparser::dialect::PostgreSqlDialect;
9use sqlparser::parser::Parser as SqlParser;
10
/// Opt-out marker (written as a SQL line comment). A `.sql` file containing
/// it is left untouched. A `"""` string in a `.py` file containing it is
/// neither formatted nor reported as a query.
pub const IGNORE_STRING: &str = "--poppy-ignore";

/// Outcome of scanning Python source for triple-quoted SQL strings.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PythonSqlResult {
    /// The complete Python source after any SQL rewriting.
    pub content: String,
    /// Raw text (between the `"""` delimiters) of every string that was
    /// considered a SQL candidate, in source order.
    pub queries: Vec<String>,
}
17
18pub fn process_path(path: &Path) -> io::Result<()> {
19 if path.is_dir() {
20 traverse_dirs(path)
21 } else {
22 let filename = path
23 .file_name()
24 .and_then(|s| s.to_str())
25 .unwrap_or("")
26 .to_string();
27
28 if !is_supported_file(&filename) {
29 println!("unsupported file format");
30 return Ok(());
31 }
32
33 format_file(&filename, path)
34 }
35}
36
37pub fn traverse_dirs(dir: &Path) -> io::Result<()> {
38 if dir.is_dir() {
39 for entry in fs::read_dir(dir)? {
40 let entry = entry?;
41 let path = entry.path();
42
43 if path.is_dir() {
44 traverse_dirs(&path)?;
45 } else {
46 let filename = entry.file_name().to_str().unwrap_or("").to_string();
47
48 if !is_supported_file(&filename) {
49 continue;
50 }
51
52 format_file(&filename, &path)?;
53 }
54 }
55 }
56
57 Ok(())
58}
59
60pub fn format_file(filename: &str, path: &Path) -> io::Result<()> {
61 println!("{filename}");
62
63 if filename.ends_with(".sql") {
64 let contents = fs::read_to_string(path).unwrap_or_default();
65
66 if contents.contains(IGNORE_STRING) {
67 return Ok(());
68 }
69
70 let mut new_contents = format_sql(&contents);
71 new_contents.push('\n');
72
73 if new_contents != contents {
74 println!("Changes applied to: {filename}");
75 fs::write(path, new_contents)?;
76 }
77 }
78
79 if filename.ends_with(".py") {
80 let contents = fs::read_to_string(path).unwrap_or_default();
81 let result = find_sql_in_python_file(&contents, true);
82 let new_contents = result.content;
83
84 if new_contents != contents {
85 println!("Changes applied to: {filename}");
86 fs::write(path, new_contents)?;
87 }
88 }
89
90 Ok(())
91}
92
93pub fn find_sql_in_python_file(contents: &str, format_file_content: bool) -> PythonSqlResult {
94 let mut output = String::with_capacity(contents.len());
95 let mut queries = Vec::new();
96 let mut unprocessed_contents = contents;
97 let dialect = PostgreSqlDialect {};
98
99 while let Some(start) = unprocessed_contents.find(r#"""""#) {
100 let is_fstring =
101 start > 0 && matches!(unprocessed_contents.as_bytes()[start - 1], b'f' | b'F');
102
103 let (prefix, after_prefix) = unprocessed_contents.split_at(start);
104 output.push_str(prefix);
105
106 let indent: String = prefix
107 .lines()
108 .next_back()
109 .unwrap_or("")
110 .chars()
111 .take_while(|c| matches!(c, ' ' | '\t'))
112 .collect();
113
114 unprocessed_contents = &after_prefix[3..];
115
116 let Some(end_rel) = unprocessed_contents.find(r#"""""#) else {
117 output.push_str(r#"""""#);
118 output.push_str(unprocessed_contents);
119 return PythonSqlResult {
120 content: output,
121 queries,
122 };
123 };
124
125 let (raw_sql, after_sql) = unprocessed_contents.split_at(end_rel);
126
127 let is_valid_sql_query = !is_fstring
128 && !raw_sql.contains(IGNORE_STRING);
129
130 let do_format = format_file_content
131 && is_valid_sql_query
132 && SqlParser::parse_sql(&dialect, raw_sql).is_ok();
133
134 output.push_str(r#"""""#);
135
136 if is_valid_sql_query {
137 queries.push(raw_sql.to_string());
138 }
139
140 if do_format {
141 let formatted = format_sql(raw_sql);
142
143 output.push('\n');
144
145 for line in formatted.lines() {
146 output.push_str(&indent);
147 output.push_str(line);
148 output.push('\n');
149 }
150
151 output.push_str(&indent);
152 } else {
153 output.push_str(raw_sql);
154 }
155
156 output.push_str(r#"""""#);
157 unprocessed_contents = &after_sql[3..];
158 }
159
160 output.push_str(unprocessed_contents);
161
162 PythonSqlResult {
163 content: output,
164 queries,
165 }
166}
167
168pub fn format_sql(sql: &str) -> String {
169 format(
170 sql,
171 &QueryParams::None,
172 &FormatOptions {
173 indent: sqlformat::Indent::Spaces(4),
174 uppercase: Some(true),
175 joins_as_top_level: true,
176 dialect: Dialect::PostgreSql,
177 lines_between_queries: 2,
178 ..Default::default()
179 },
180 )
181}
182
/// Reports whether `filename` ends with an extension this tool rewrites:
/// `.sql` or `.py`. The match is case-sensitive.
pub fn is_supported_file(filename: &str) -> bool {
    const EXTENSIONS: [&str; 2] = [".sql", ".py"];
    EXTENSIONS.iter().any(|ext| filename.ends_with(*ext))
}