// poppy_sql/lib.rs

1use std::{
2    fs,
3    io,
4    path::Path,
5};
6
7use sqlformat::{format, Dialect, FormatOptions, QueryParams};
8use sqlparser::dialect::PostgreSqlDialect;
9use sqlparser::parser::Parser as SqlParser;
10
/// Opt-out marker for SQL formatting.
///
/// A `.sql` file containing this text is left untouched by `format_file`, and a
/// triple-quoted Python string containing it is neither formatted nor collected as a
/// query by `find_sql_in_python_file`.
pub const IGNORE_STRING: &str = "--poppy-ignore";

/// Outcome of scanning Python source for triple-quoted SQL strings.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct PythonSqlResult {
    /// The Python source after processing; SQL literals are reformatted only when the
    /// caller requested formatting, otherwise this reproduces the input text.
    pub content: String,
    /// Raw bodies of the triple-quoted literals collected as queries, in source order.
    pub queries: Vec<String>,
}
17
18pub fn process_path(path: &Path) -> io::Result<()> {
19    if path.is_dir() {
20        traverse_dirs(path)
21    } else {
22        let filename = path
23            .file_name()
24            .and_then(|s| s.to_str())
25            .unwrap_or("")
26            .to_string();
27
28        if !is_supported_file(&filename) {
29            println!("unsupported file format");
30            return Ok(());
31        }
32
33        format_file(&filename, path)
34    }
35}
36
37pub fn traverse_dirs(dir: &Path) -> io::Result<()> {
38    if dir.is_dir() {
39        for entry in fs::read_dir(dir)? {
40            let entry = entry?;
41            let path = entry.path();
42
43            if path.is_dir() {
44                traverse_dirs(&path)?;
45            } else {
46                let filename = entry.file_name().to_str().unwrap_or("").to_string();
47
48                if !is_supported_file(&filename) {
49                    continue;
50                }
51
52                format_file(&filename, &path)?;
53            }
54        }
55    }
56
57    Ok(())
58}
59
60pub fn format_file(filename: &str, path: &Path) -> io::Result<()> {
61    println!("{filename}");
62
63    if filename.ends_with(".sql") {
64        let contents = fs::read_to_string(path).unwrap_or_default();
65
66        if contents.contains(IGNORE_STRING) {
67            return Ok(());
68        }
69
70        let mut new_contents = format_sql(&contents);
71        new_contents.push('\n');
72
73        if new_contents != contents {
74            println!("Changes applied to: {filename}");
75            fs::write(path, new_contents)?;
76        }
77    }
78
79    if filename.ends_with(".py") {
80        let contents = fs::read_to_string(path).unwrap_or_default();
81        let result = find_sql_in_python_file(&contents, true);
82        let new_contents = result.content;
83
84        if new_contents != contents {
85            println!("Changes applied to: {filename}");
86            fs::write(path, new_contents)?;
87        }
88    }
89
90    Ok(())
91}
92
93pub fn find_sql_in_python_file(contents: &str, format_file_content: bool) -> PythonSqlResult {
94    let mut output = String::with_capacity(contents.len());
95    let mut queries = Vec::new();
96    let mut unprocessed_contents = contents;
97    let dialect = PostgreSqlDialect {};
98
99    while let Some(start) = unprocessed_contents.find(r#"""""#) {
100        let is_fstring =
101            start > 0 && matches!(unprocessed_contents.as_bytes()[start - 1], b'f' | b'F');
102
103        let (prefix, after_prefix) = unprocessed_contents.split_at(start);
104        output.push_str(prefix);
105
106        let indent: String = prefix
107            .lines()
108            .next_back()
109            .unwrap_or("")
110            .chars()
111            .take_while(|c| matches!(c, ' ' | '\t'))
112            .collect();
113
114        unprocessed_contents = &after_prefix[3..];
115
116        let Some(end_rel) = unprocessed_contents.find(r#"""""#) else {
117            output.push_str(r#"""""#);
118            output.push_str(unprocessed_contents);
119            return PythonSqlResult {
120                content: output,
121                queries,
122            };
123        };
124
125        let (raw_sql, after_sql) = unprocessed_contents.split_at(end_rel);
126
127        let is_valid_sql_query = !is_fstring
128            && !raw_sql.contains(IGNORE_STRING);
129
130        let do_format = format_file_content
131            && is_valid_sql_query
132            && SqlParser::parse_sql(&dialect, raw_sql).is_ok();
133
134        output.push_str(r#"""""#);
135
136        if is_valid_sql_query {
137            queries.push(raw_sql.to_string());
138        }
139
140        if do_format {
141            let formatted = format_sql(raw_sql);
142
143            output.push('\n');
144
145            for line in formatted.lines() {
146                output.push_str(&indent);
147                output.push_str(line);
148                output.push('\n');
149            }
150
151            output.push_str(&indent);
152        } else {
153            output.push_str(raw_sql);
154        }
155
156        output.push_str(r#"""""#);
157        unprocessed_contents = &after_sql[3..];
158    }
159
160    output.push_str(unprocessed_contents);
161
162    PythonSqlResult {
163        content: output,
164        queries,
165    }
166}
167
168pub fn format_sql(sql: &str) -> String {
169    format(
170        sql,
171        &QueryParams::None,
172        &FormatOptions {
173            indent: sqlformat::Indent::Spaces(4),
174            uppercase: Some(true),
175            joins_as_top_level: true,
176            dialect: Dialect::PostgreSql,
177            lines_between_queries: 2,
178            ..Default::default()
179        },
180    )
181}
182
/// Returns `true` for file names this tool can format, i.e. names ending in `.sql` or
/// `.py`.
///
/// The test is a case-sensitive suffix match on the name, so `query.SQL` is rejected.
pub fn is_supported_file(filename: &str) -> bool {
    const EXTENSIONS: [&str; 2] = [".sql", ".py"];
    EXTENSIONS.iter().any(|&ext| filename.ends_with(ext))
}