1use std::path::Path;
2
3use crate::error::Result;
4
5use crate::codegen::GeneratedFile;
6
/// Banner line placed at the very top of every file this module writes
/// (per-table files, `mod.rs`, and the single-file `models.rs`).
const COMMENT: &str = "// Auto-generated by sqlx-gen. Do not edit.";
/// Module-level inner attribute emitted after the banner in every written file.
/// NOTE(review): presumably silences `unused_attributes` warnings triggered by
/// attributes in the generated code — confirm against the codegen output.
const INNER_ATTR: &str = "#![allow(unused_attributes)]";
9
10pub fn write_files(
11 files: &[GeneratedFile],
12 output_dir: &Path,
13 single_file: bool,
14 dry_run: bool,
15) -> Result<()> {
16 if dry_run {
17 for f in files {
18 println!("{}", build_file_content(f));
19 println!();
20 }
21 return Ok(());
22 }
23
24 std::fs::create_dir_all(output_dir)?;
25
26 if single_file {
27 write_single_file(files, output_dir)?;
28 } else {
29 write_multi_files(files, output_dir)?;
30 }
31
32 Ok(())
33}
34
35fn build_file_content(f: &GeneratedFile) -> String {
36 let mut content = String::new();
37 content.push_str(COMMENT);
38 content.push('\n');
39 if let Some(origin) = &f.origin {
40 content.push_str(&format!("// {}\n", origin));
41 }
42 content.push('\n');
43 content.push_str(INNER_ATTR);
44 content.push_str("\n\n");
45 content.push_str(&f.code);
46 content
47}
48
49fn write_single_file(files: &[GeneratedFile], output_dir: &Path) -> Result<()> {
50 let mut content = format!("{}\n\n{}\n\n", COMMENT, INNER_ATTR);
51
52 for f in files {
53 if let Some(origin) = &f.origin {
54 content.push_str(&format!("// --- {} ---\n\n", origin));
55 }
56 content.push_str(&f.code);
57 content.push('\n');
58 }
59
60 let path = output_dir.join("models.rs");
61 std::fs::write(&path, &content)?;
62 log::info!("Wrote {}", path.display());
63
64 Ok(())
65}
66
67fn write_multi_files(files: &[GeneratedFile], output_dir: &Path) -> Result<()> {
68 let mut mod_entries = Vec::new();
69
70 for f in files {
71 let content = build_file_content(f);
72 let path = output_dir.join(&f.filename);
73 std::fs::write(&path, &content)?;
74 log::info!("Wrote {}", path.display());
75
76 let mod_name = f.filename.strip_suffix(".rs").unwrap_or(&f.filename);
77 mod_entries.push(mod_name.to_string());
78 }
79
80 let mut mod_content = format!("{}\n\n{}\n\n", COMMENT, INNER_ATTR);
82 for m in &mod_entries {
83 mod_content.push_str(&format!("pub mod {};\n", m));
84 }
85
86 let mod_path = output_dir.join("mod.rs");
87 std::fs::write(&mod_path, &mod_content)?;
88 log::info!("Wrote {}", mod_path.display());
89
90 Ok(())
91}
92
// Unit tests for the writer. Filesystem tests write into a fresh
// `tempfile::TempDir`, which is deleted when it drops at the end of each test.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::GeneratedFile;

    /// Builds a `GeneratedFile` from borrowed test data.
    fn make_file(filename: &str, code: &str, origin: Option<&str>) -> GeneratedFile {
        GeneratedFile {
            filename: filename.to_string(),
            origin: origin.map(|s| s.to_string()),
            code: code.to_string(),
        }
    }

    // --- build_file_content ---

    #[test]
    fn test_build_content_with_origin() {
        let f = make_file("users.rs", "pub struct Users {}", Some("Table: public.users"));
        let content = build_file_content(&f);
        assert!(content.contains(COMMENT));
        assert!(content.contains(INNER_ATTR));
        assert!(content.contains("// Table: public.users"));
        assert!(content.contains("pub struct Users {}"));
    }

    #[test]
    fn test_build_content_without_origin() {
        let f = make_file("types.rs", "pub enum Status {}", None);
        let content = build_file_content(&f);
        assert!(content.contains(COMMENT));
        assert!(content.contains(INNER_ATTR));
        assert!(!content.contains("// Table:"));
        assert!(content.contains("pub enum Status {}"));
    }

    #[test]
    fn test_build_content_header_value() {
        let f = make_file("x.rs", "", None);
        let content = build_file_content(&f);
        assert!(content.starts_with("// Auto-generated by sqlx-gen. Do not edit."));
    }

    #[test]
    fn test_build_content_preserves_code() {
        let code = "use chrono::NaiveDateTime;\n\npub struct Foo {\n    pub x: i32,\n}";
        let f = make_file("foo.rs", code, None);
        let content = build_file_content(&f);
        assert!(content.contains(code));
    }

    #[test]
    fn test_build_content_origin_format() {
        let f = make_file("x.rs", "code", Some("Table: public.users"));
        let content = build_file_content(&f);
        assert!(content.contains("// Table: public.users\n"));
    }

    #[test]
    fn test_build_content_empty_code() {
        let f = make_file("x.rs", "", Some("Table: public.x"));
        let content = build_file_content(&f);
        assert!(content.contains(COMMENT));
        assert!(content.contains(INNER_ATTR));
        assert!(content.contains("// Table: public.x"));
    }

    // --- dry run ---

    #[test]
    fn test_dry_run_returns_ok() {
        let files = vec![make_file("users.rs", "code", Some("origin"))];
        let dir = tempfile::tempdir().unwrap();
        let result = write_files(&files, dir.path(), false, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_dry_run_no_files_created() {
        let files = vec![make_file("users.rs", "code", Some("origin"))];
        let dir = tempfile::tempdir().unwrap();
        let sub = dir.path().join("output");
        // Check the result too: discarding it would let a failing dry run pass
        // as long as it happened not to create the directory.
        write_files(&files, &sub, false, true).unwrap();
        assert!(!sub.exists());
    }

    #[test]
    fn test_dry_run_single_file_no_files_created() {
        let files = vec![make_file("users.rs", "code", Some("origin"))];
        let dir = tempfile::tempdir().unwrap();
        let sub = dir.path().join("output");
        write_files(&files, &sub, true, true).unwrap();
        assert!(!sub.exists());
    }

    #[test]
    fn test_dry_run_empty_files() {
        let dir = tempfile::tempdir().unwrap();
        let result = write_files(&[], dir.path(), false, true);
        assert!(result.is_ok());
    }

    // --- multi-file mode ---

    #[test]
    fn test_multi_creates_files_and_mod() {
        let files = vec![
            make_file("users.rs", "pub struct Users {}", Some("Table: public.users")),
            make_file("posts.rs", "pub struct Posts {}", Some("Table: public.posts")),
        ];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), false, false).unwrap();

        assert!(dir.path().join("users.rs").exists());
        assert!(dir.path().join("posts.rs").exists());
        assert!(dir.path().join("mod.rs").exists());
    }

    #[test]
    fn test_multi_mod_rs_content() {
        let files = vec![
            make_file("users.rs", "code", Some("origin")),
            make_file("types.rs", "code", None),
        ];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), false, false).unwrap();

        let mod_content = std::fs::read_to_string(dir.path().join("mod.rs")).unwrap();
        assert!(mod_content.contains("pub mod users;"));
        assert!(mod_content.contains("pub mod types;"));
    }

    #[test]
    fn test_multi_mod_rs_preserves_input_order() {
        let files = vec![
            make_file("zeta.rs", "code", None),
            make_file("alpha.rs", "code", None),
        ];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), false, false).unwrap();

        let mod_content = std::fs::read_to_string(dir.path().join("mod.rs")).unwrap();
        assert!(mod_content.starts_with(COMMENT));
        let zeta = mod_content.find("pub mod zeta;").unwrap();
        let alpha = mod_content.find("pub mod alpha;").unwrap();
        assert!(zeta < alpha);
    }

    #[test]
    fn test_multi_file_has_header() {
        let files = vec![make_file("users.rs", "code", Some("Table: public.users"))];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), false, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("users.rs")).unwrap();
        assert!(content.starts_with(COMMENT));
    }

    #[test]
    fn test_multi_file_has_origin() {
        let files = vec![make_file("users.rs", "code", Some("Table: public.users"))];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), false, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("users.rs")).unwrap();
        assert!(content.contains("// Table: public.users"));
    }

    #[test]
    fn test_multi_creates_output_dir() {
        let dir = tempfile::tempdir().unwrap();
        let sub = dir.path().join("nested").join("output");
        let files = vec![make_file("users.rs", "code", Some("origin"))];
        write_files(&files, &sub, false, false).unwrap();
        assert!(sub.join("users.rs").exists());
    }

    #[test]
    fn test_multi_file_no_origin() {
        let files = vec![make_file("types.rs", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), false, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("types.rs")).unwrap();
        assert!(content.contains(COMMENT));
        assert!(content.contains(INNER_ATTR));
        assert!(!content.contains("// Table:"));
    }

    // --- single-file mode ---

    #[test]
    fn test_single_creates_models_rs() {
        let files = vec![make_file("users.rs", "pub struct Users {}", Some("Table: public.users"))];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), true, false).unwrap();
        assert!(dir.path().join("models.rs").exists());
    }

    #[test]
    fn test_single_starts_with_header() {
        let files = vec![make_file("users.rs", "code", Some("origin"))];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), true, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("models.rs")).unwrap();
        assert!(content.starts_with(COMMENT));
    }

    #[test]
    fn test_single_has_section_separator() {
        let files = vec![make_file("users.rs", "code", Some("Table: public.users"))];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), true, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("models.rs")).unwrap();
        assert!(content.contains("// --- Table: public.users ---"));
    }

    #[test]
    fn test_single_concatenates_all_code() {
        let files = vec![
            make_file("users.rs", "struct Users;", Some("Table: public.users")),
            make_file("posts.rs", "struct Posts;", Some("Table: public.posts")),
        ];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), true, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("models.rs")).unwrap();
        assert!(content.contains("struct Users;"));
        assert!(content.contains("struct Posts;"));
    }

    #[test]
    fn test_single_no_origin_no_separator() {
        let files = vec![make_file("types.rs", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), true, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("models.rs")).unwrap();
        assert!(!content.contains("// ---"));
    }

    #[test]
    fn test_single_empty_input_writes_header_only() {
        let dir = tempfile::tempdir().unwrap();
        write_files(&[], dir.path(), true, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("models.rs")).unwrap();
        assert_eq!(content, format!("{}\n\n{}\n\n", COMMENT, INNER_ATTR));
    }
}