Skip to main content

sqlx_gen/
writer.rs

1use std::path::Path;
2
3use crate::error::Result;
4
5use crate::codegen::GeneratedFile;
6
/// Banner comment placed at the very top of every file this module writes.
const COMMENT: &str = "// Auto-generated by sqlx-gen. Do not edit.";
// NOTE(review): presumably silences warnings for attributes the generated items
// do not end up needing — confirm against the code generator.
/// Inner attribute emitted after the banner and before any generated code.
const INNER_ATTR: &str = "#![allow(unused_attributes)]";
9
10pub fn write_files(
11    files: &[GeneratedFile],
12    output_dir: &Path,
13    single_file: bool,
14    dry_run: bool,
15) -> Result<()> {
16    if dry_run {
17        for f in files {
18            println!("{}", build_file_content(f));
19            println!();
20        }
21        return Ok(());
22    }
23
24    std::fs::create_dir_all(output_dir)?;
25
26    if single_file {
27        write_single_file(files, output_dir)?;
28    } else {
29        write_multi_files(files, output_dir)?;
30    }
31
32    Ok(())
33}
34
35fn build_file_content(f: &GeneratedFile) -> String {
36    let mut content = String::new();
37    content.push_str(COMMENT);
38    content.push('\n');
39    if let Some(origin) = &f.origin {
40        content.push_str(&format!("// {}\n", origin));
41    }
42    content.push('\n');
43    content.push_str(INNER_ATTR);
44    content.push_str("\n\n");
45    content.push_str(&f.code);
46    content
47}
48
49fn write_single_file(files: &[GeneratedFile], output_dir: &Path) -> Result<()> {
50    let mut content = format!("{}\n\n{}\n\n", COMMENT, INNER_ATTR);
51
52    for f in files {
53        if let Some(origin) = &f.origin {
54            content.push_str(&format!("// --- {} ---\n\n", origin));
55        }
56        content.push_str(&f.code);
57        content.push('\n');
58    }
59
60    let path = output_dir.join("models.rs");
61    std::fs::write(&path, &content)?;
62    log::info!("Wrote {}", path.display());
63
64    Ok(())
65}
66
67fn write_multi_files(files: &[GeneratedFile], output_dir: &Path) -> Result<()> {
68    let mut mod_entries = Vec::new();
69
70    for f in files {
71        let content = build_file_content(f);
72        let path = output_dir.join(&f.filename);
73        std::fs::write(&path, &content)?;
74        log::info!("Wrote {}", path.display());
75
76        let mod_name = f.filename.strip_suffix(".rs").unwrap_or(&f.filename);
77        mod_entries.push(mod_name.to_string());
78    }
79
80    // Generate mod.rs
81    let mut mod_content = format!("{}\n\n{}\n\n", COMMENT, INNER_ATTR);
82    for m in &mod_entries {
83        mod_content.push_str(&format!("pub mod {};\n", m));
84    }
85
86    let mod_path = output_dir.join("mod.rs");
87    std::fs::write(&mod_path, &mod_content)?;
88    log::info!("Wrote {}", mod_path.display());
89
90    Ok(())
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::GeneratedFile;

    /// Assembles a `GeneratedFile` fixture from borrowed string parts.
    fn make_file(filename: &str, code: &str, origin: Option<&str>) -> GeneratedFile {
        let filename = filename.to_owned();
        let origin = origin.map(str::to_owned);
        let code = code.to_owned();
        GeneratedFile {
            filename,
            origin,
            code,
        }
    }

    /// Reads back a file the writer produced under `dir`; panics if unreadable.
    fn read_output(dir: &std::path::Path, name: &str) -> String {
        std::fs::read_to_string(dir.join(name)).unwrap()
    }

    // ---------- build_file_content ----------

    #[test]
    fn test_build_content_with_origin() {
        let file = make_file("users.rs", "pub struct Users {}", Some("Table: public.users"));
        let rendered = build_file_content(&file);
        let expected = [
            COMMENT,
            INNER_ATTR,
            "// Table: public.users",
            "pub struct Users {}",
        ];
        for part in expected.iter() {
            assert!(rendered.contains(*part));
        }
    }

    #[test]
    fn test_build_content_without_origin() {
        let file = make_file("types.rs", "pub enum Status {}", None);
        let rendered = build_file_content(&file);
        for part in [COMMENT, INNER_ATTR, "pub enum Status {}"].iter() {
            assert!(rendered.contains(*part));
        }
        assert!(!rendered.contains("// Table:"));
    }

    #[test]
    fn test_build_content_header_value() {
        let rendered = build_file_content(&make_file("x.rs", "", None));
        assert!(rendered.starts_with("// Auto-generated by sqlx-gen. Do not edit."));
    }

    #[test]
    fn test_build_content_preserves_code() {
        let body = "use chrono::NaiveDateTime;\n\npub struct Foo {\n    pub x: i32,\n}";
        let rendered = build_file_content(&make_file("foo.rs", body, None));
        assert!(rendered.contains(body));
    }

    #[test]
    fn test_build_content_origin_format() {
        let rendered = build_file_content(&make_file("x.rs", "code", Some("Table: public.users")));
        assert!(rendered.contains("// Table: public.users\n"));
    }

    #[test]
    fn test_build_content_empty_code() {
        let rendered = build_file_content(&make_file("x.rs", "", Some("Table: public.x")));
        for part in [COMMENT, INNER_ATTR, "// Table: public.x"].iter() {
            assert!(rendered.contains(*part));
        }
    }

    // ---------- write_files (dry run) ----------

    #[test]
    fn test_dry_run_returns_ok() {
        let tmp = tempfile::tempdir().unwrap();
        let files = vec![make_file("users.rs", "code", Some("origin"))];
        assert!(write_files(&files, tmp.path(), false, true).is_ok());
    }

    #[test]
    fn test_dry_run_no_files_created() {
        let tmp = tempfile::tempdir().unwrap();
        let target = tmp.path().join("output");
        let files = vec![make_file("users.rs", "code", Some("origin"))];
        let _ = write_files(&files, &target, false, true);
        // A dry run must not even create the output directory.
        assert!(!target.exists());
    }

    #[test]
    fn test_dry_run_empty_files() {
        let tmp = tempfile::tempdir().unwrap();
        assert!(write_files(&[], tmp.path(), false, true).is_ok());
    }

    // ---------- write_multi_files ----------

    #[test]
    fn test_multi_creates_files_and_mod() {
        let tmp = tempfile::tempdir().unwrap();
        let files = vec![
            make_file("users.rs", "pub struct Users {}", Some("Table: public.users")),
            make_file("posts.rs", "pub struct Posts {}", Some("Table: public.posts")),
        ];
        write_files(&files, tmp.path(), false, false).unwrap();

        for name in ["users.rs", "posts.rs", "mod.rs"].iter() {
            assert!(tmp.path().join(name).exists());
        }
    }

    #[test]
    fn test_multi_mod_rs_content() {
        let tmp = tempfile::tempdir().unwrap();
        let files = vec![
            make_file("users.rs", "code", Some("origin")),
            make_file("types.rs", "code", None),
        ];
        write_files(&files, tmp.path(), false, false).unwrap();

        let index = read_output(tmp.path(), "mod.rs");
        assert!(index.contains("pub mod users;"));
        assert!(index.contains("pub mod types;"));
    }

    #[test]
    fn test_multi_file_has_header() {
        let tmp = tempfile::tempdir().unwrap();
        let files = vec![make_file("users.rs", "code", Some("Table: public.users"))];
        write_files(&files, tmp.path(), false, false).unwrap();

        assert!(read_output(tmp.path(), "users.rs").starts_with(COMMENT));
    }

    #[test]
    fn test_multi_file_has_origin() {
        let tmp = tempfile::tempdir().unwrap();
        let files = vec![make_file("users.rs", "code", Some("Table: public.users"))];
        write_files(&files, tmp.path(), false, false).unwrap();

        assert!(read_output(tmp.path(), "users.rs").contains("// Table: public.users"));
    }

    #[test]
    fn test_multi_creates_output_dir() {
        let tmp = tempfile::tempdir().unwrap();
        let nested = tmp.path().join("nested").join("output");
        let files = vec![make_file("users.rs", "code", Some("origin"))];
        write_files(&files, &nested, false, false).unwrap();
        assert!(nested.join("users.rs").exists());
    }

    #[test]
    fn test_multi_file_no_origin() {
        let tmp = tempfile::tempdir().unwrap();
        let files = vec![make_file("types.rs", "code", None)];
        write_files(&files, tmp.path(), false, false).unwrap();

        let written = read_output(tmp.path(), "types.rs");
        assert!(written.contains(COMMENT));
        assert!(written.contains(INNER_ATTR));
        // Without an origin there must be no origin comment line.
        assert!(!written.contains("// Table:"));
    }

    // ---------- write_single_file ----------

    #[test]
    fn test_single_creates_models_rs() {
        let tmp = tempfile::tempdir().unwrap();
        let files = vec![make_file(
            "users.rs",
            "pub struct Users {}",
            Some("Table: public.users"),
        )];
        write_files(&files, tmp.path(), true, false).unwrap();
        assert!(tmp.path().join("models.rs").exists());
    }

    #[test]
    fn test_single_starts_with_header() {
        let tmp = tempfile::tempdir().unwrap();
        let files = vec![make_file("users.rs", "code", Some("origin"))];
        write_files(&files, tmp.path(), true, false).unwrap();

        assert!(read_output(tmp.path(), "models.rs").starts_with(COMMENT));
    }

    #[test]
    fn test_single_has_section_separator() {
        let tmp = tempfile::tempdir().unwrap();
        let files = vec![make_file("users.rs", "code", Some("Table: public.users"))];
        write_files(&files, tmp.path(), true, false).unwrap();

        assert!(read_output(tmp.path(), "models.rs").contains("// --- Table: public.users ---"));
    }

    #[test]
    fn test_single_concatenates_all_code() {
        let tmp = tempfile::tempdir().unwrap();
        let files = vec![
            make_file("users.rs", "struct Users;", Some("Table: public.users")),
            make_file("posts.rs", "struct Posts;", Some("Table: public.posts")),
        ];
        write_files(&files, tmp.path(), true, false).unwrap();

        let combined = read_output(tmp.path(), "models.rs");
        assert!(combined.contains("struct Users;"));
        assert!(combined.contains("struct Posts;"));
    }

    #[test]
    fn test_single_no_origin_no_separator() {
        let tmp = tempfile::tempdir().unwrap();
        let files = vec![make_file("types.rs", "code", None)];
        write_files(&files, tmp.path(), true, false).unwrap();

        assert!(!read_output(tmp.path(), "models.rs").contains("// ---"));
    }
}