1use std::io::Write;
2use std::path::Path;
3
4use crate::error::Result;
5
6use crate::codegen::GeneratedFile;
7
/// Header comment emitted as the very first line of every generated file
/// (per-table files, the single `models.rs`, and the multi-file `mod.rs`).
const COMMENT: &str = "// Auto-generated by sqlx-gen. Do not edit.";
/// Inner attribute emitted after the header in every generated file.
const INNER_ATTR: &str = "#![allow(unused_attributes)]";
10
11pub(crate) fn write_atomic(path: &Path, content: &[u8]) -> Result<()> {
14 let parent = path.parent().ok_or_else(|| {
15 crate::error::Error::Config(format!(
16 "Cannot determine parent directory of {}",
17 path.display()
18 ))
19 })?;
20 let mut tmp = tempfile::NamedTempFile::new_in(parent)?;
21 tmp.write_all(content)?;
22 tmp.flush()?;
23 tmp.persist(path).map_err(|e| e.error)?;
24 Ok(())
25}
26
27pub fn write_files(
28 files: &[GeneratedFile],
29 output_dir: &Path,
30 single_file: bool,
31 dry_run: bool,
32) -> Result<()> {
33 for f in files {
34 validate_safe_filename(&f.filename)?;
35 }
36
37 if dry_run {
38 for f in files {
39 println!("{}", build_file_content(f));
40 println!();
41 }
42 return Ok(());
43 }
44
45 std::fs::create_dir_all(output_dir)?;
46
47 if single_file {
48 write_single_file(files, output_dir)?;
49 } else {
50 write_multi_files(files, output_dir)?;
51 }
52
53 Ok(())
54}
55
56fn validate_safe_filename(filename: &str) -> Result<()> {
61 let p = Path::new(filename);
62 if filename.is_empty()
63 || p.components().count() != 1
64 || p.is_absolute()
65 || filename.contains("..")
66 || filename.contains('/')
67 || filename.contains('\\')
68 || !filename.ends_with(".rs")
69 {
70 return Err(crate::error::Error::Config(format!(
71 "Refusing to write generated file with unsafe name: {:?}",
72 filename
73 )));
74 }
75 Ok(())
76}
77
78fn build_file_content(f: &GeneratedFile) -> String {
79 let mut content = String::new();
80 content.push_str(COMMENT);
81 content.push('\n');
82 if let Some(origin) = &f.origin {
83 content.push_str(&format!("// {}\n", origin));
84 }
85 content.push('\n');
86 content.push_str(INNER_ATTR);
87 content.push_str("\n\n");
88 content.push_str(&f.code);
89 content
90}
91
92fn write_single_file(files: &[GeneratedFile], output_dir: &Path) -> Result<()> {
93 let mut content = format!("{}\n\n{}\n\n", COMMENT, INNER_ATTR);
94
95 for f in files {
96 if let Some(origin) = &f.origin {
97 content.push_str(&format!("// --- {} ---\n\n", origin));
98 }
99 content.push_str(&f.code);
100 content.push('\n');
101 }
102
103 let path = output_dir.join("models.rs");
104 write_atomic(&path, content.as_bytes())?;
105 log::info!("Wrote {}", path.display());
106
107 Ok(())
108}
109
110fn write_multi_files(files: &[GeneratedFile], output_dir: &Path) -> Result<()> {
111 let mut mod_entries = Vec::new();
112
113 for f in files {
114 let content = build_file_content(f);
115 let path = output_dir.join(&f.filename);
116 write_atomic(&path, content.as_bytes())?;
117 log::info!("Wrote {}", path.display());
118
119 let mod_name = f.filename.strip_suffix(".rs").unwrap_or(&f.filename);
120 mod_entries.push(mod_name.to_string());
121 }
122
123 let mut mod_content = format!("{}\n\n{}\n\n", COMMENT, INNER_ATTR);
125 for m in &mod_entries {
126 mod_content.push_str(&format!("pub mod {};\n", m));
127 }
128
129 let mod_path = output_dir.join("mod.rs");
130 write_atomic(&mod_path, mod_content.as_bytes())?;
131 log::info!("Wrote {}", mod_path.display());
132
133 Ok(())
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::GeneratedFile;

    /// Builds a `GeneratedFile` from borrowed parts to keep test setup short.
    fn make_file(filename: &str, code: &str, origin: Option<&str>) -> GeneratedFile {
        GeneratedFile {
            filename: filename.to_string(),
            origin: origin.map(|s| s.to_string()),
            code: code.to_string(),
        }
    }

    // ---- build_file_content: header, optional origin line, code body ----

    #[test]
    fn test_build_content_with_origin() {
        let f = make_file(
            "users.rs",
            "pub struct Users {}",
            Some("Table: public.users"),
        );
        let content = build_file_content(&f);
        assert!(content.contains(COMMENT));
        assert!(content.contains(INNER_ATTR));
        assert!(content.contains("// Table: public.users"));
        assert!(content.contains("pub struct Users {}"));
    }

    #[test]
    fn test_build_content_without_origin() {
        let f = make_file("types.rs", "pub enum Status {}", None);
        let content = build_file_content(&f);
        assert!(content.contains(COMMENT));
        assert!(content.contains(INNER_ATTR));
        assert!(!content.contains("// Table:"));
        assert!(content.contains("pub enum Status {}"));
    }

    #[test]
    fn test_build_content_header_value() {
        let f = make_file("x.rs", "", None);
        let content = build_file_content(&f);
        assert!(content.starts_with("// Auto-generated by sqlx-gen. Do not edit."));
    }

    #[test]
    fn test_build_content_preserves_code() {
        let code = "use chrono::NaiveDateTime;\n\npub struct Foo {\n    pub x: i32,\n}";
        let f = make_file("foo.rs", code, None);
        let content = build_file_content(&f);
        assert!(content.contains(code));
    }

    #[test]
    fn test_build_content_origin_format() {
        let f = make_file("x.rs", "code", Some("Table: public.users"));
        let content = build_file_content(&f);
        assert!(content.contains("// Table: public.users\n"));
    }

    #[test]
    fn test_build_content_empty_code() {
        let f = make_file("x.rs", "", Some("Table: public.x"));
        let content = build_file_content(&f);
        assert!(content.contains(COMMENT));
        assert!(content.contains(INNER_ATTR));
        assert!(content.contains("// Table: public.x"));
    }

    // ---- write_files with dry_run = true: prints only, creates nothing ----

    #[test]
    fn test_dry_run_returns_ok() {
        let files = vec![make_file("users.rs", "code", Some("origin"))];
        let dir = tempfile::tempdir().unwrap();
        let result = write_files(&files, dir.path(), false, true);
        assert!(result.is_ok());
    }

    #[test]
    fn test_dry_run_no_files_created() {
        let files = vec![make_file("users.rs", "code", Some("origin"))];
        let dir = tempfile::tempdir().unwrap();
        let sub = dir.path().join("output");
        let _ = write_files(&files, &sub, false, true);
        // Dry run returns before create_dir_all, so the output dir must not exist.
        assert!(!sub.exists());
    }

    #[test]
    fn test_dry_run_empty_files() {
        let dir = tempfile::tempdir().unwrap();
        let result = write_files(&[], dir.path(), false, true);
        assert!(result.is_ok());
    }

    // ---- multi-file mode: one file per entry plus mod.rs ----

    #[test]
    fn test_multi_creates_files_and_mod() {
        let files = vec![
            make_file(
                "users.rs",
                "pub struct Users {}",
                Some("Table: public.users"),
            ),
            make_file(
                "posts.rs",
                "pub struct Posts {}",
                Some("Table: public.posts"),
            ),
        ];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), false, false).unwrap();

        assert!(dir.path().join("users.rs").exists());
        assert!(dir.path().join("posts.rs").exists());
        assert!(dir.path().join("mod.rs").exists());
    }

    #[test]
    fn test_multi_mod_rs_content() {
        let files = vec![
            make_file("users.rs", "code", Some("origin")),
            make_file("types.rs", "code", None),
        ];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), false, false).unwrap();

        let mod_content = std::fs::read_to_string(dir.path().join("mod.rs")).unwrap();
        assert!(mod_content.contains("pub mod users;"));
        assert!(mod_content.contains("pub mod types;"));
    }

    #[test]
    fn test_multi_file_has_header() {
        let files = vec![make_file("users.rs", "code", Some("Table: public.users"))];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), false, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("users.rs")).unwrap();
        assert!(content.starts_with(COMMENT));
    }

    #[test]
    fn test_multi_file_has_origin() {
        let files = vec![make_file("users.rs", "code", Some("Table: public.users"))];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), false, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("users.rs")).unwrap();
        assert!(content.contains("// Table: public.users"));
    }

    #[test]
    fn test_multi_creates_output_dir() {
        let dir = tempfile::tempdir().unwrap();
        let sub = dir.path().join("nested").join("output");
        let files = vec![make_file("users.rs", "code", Some("origin"))];
        write_files(&files, &sub, false, false).unwrap();
        assert!(sub.join("users.rs").exists());
    }

    #[test]
    fn test_multi_file_no_origin() {
        let files = vec![make_file("types.rs", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), false, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("types.rs")).unwrap();
        assert!(content.contains(COMMENT));
        assert!(content.contains(INNER_ATTR));
        assert!(!content.contains("// Table:"));
    }

    // ---- single-file mode: everything concatenated into models.rs ----

    #[test]
    fn test_single_creates_models_rs() {
        let files = vec![make_file(
            "users.rs",
            "pub struct Users {}",
            Some("Table: public.users"),
        )];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), true, false).unwrap();
        assert!(dir.path().join("models.rs").exists());
    }

    #[test]
    fn test_single_starts_with_header() {
        let files = vec![make_file("users.rs", "code", Some("origin"))];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), true, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("models.rs")).unwrap();
        assert!(content.starts_with(COMMENT));
    }

    #[test]
    fn test_single_has_section_separator() {
        let files = vec![make_file("users.rs", "code", Some("Table: public.users"))];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), true, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("models.rs")).unwrap();
        assert!(content.contains("// --- Table: public.users ---"));
    }

    #[test]
    fn test_single_concatenates_all_code() {
        let files = vec![
            make_file("users.rs", "struct Users;", Some("Table: public.users")),
            make_file("posts.rs", "struct Posts;", Some("Table: public.posts")),
        ];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), true, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("models.rs")).unwrap();
        assert!(content.contains("struct Users;"));
        assert!(content.contains("struct Posts;"));
    }

    #[test]
    fn test_single_no_origin_no_separator() {
        let files = vec![make_file("types.rs", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        write_files(&files, dir.path(), true, false).unwrap();

        let content = std::fs::read_to_string(dir.path().join("models.rs")).unwrap();
        assert!(!content.contains("// ---"));
    }

    // ---- write_atomic: temp-file + rename replacement ----

    #[test]
    fn test_atomic_creates_file_with_content() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("out.rs");
        write_atomic(&path, b"hello").unwrap();
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "hello");
    }

    #[test]
    fn test_atomic_overwrites_existing_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("out.rs");
        std::fs::write(&path, "old").unwrap();
        write_atomic(&path, b"new").unwrap();
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "new");
    }

    #[test]
    fn test_atomic_leaves_no_temp_artifacts_on_success() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("out.rs");
        write_atomic(&path, b"x").unwrap();
        // Only the destination should remain; the temp file was renamed onto it.
        let entries: Vec<_> = std::fs::read_dir(dir.path())
            .unwrap()
            .map(|e| e.unwrap().file_name())
            .collect();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].to_string_lossy(), "out.rs");
    }

    // ---- filename validation: checked by write_files before any write ----

    #[test]
    fn test_rejects_dot_dot_in_filename() {
        let files = vec![make_file("../escape.rs", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        assert!(write_files(&files, dir.path(), false, false).is_err());
    }

    #[test]
    fn test_rejects_absolute_path_filename() {
        let files = vec![make_file("/etc/passwd", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        assert!(write_files(&files, dir.path(), false, false).is_err());
    }

    #[test]
    fn test_rejects_path_separator_in_filename() {
        let files = vec![make_file("sub/dir/file.rs", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        assert!(write_files(&files, dir.path(), false, false).is_err());
    }

    #[test]
    fn test_rejects_non_rs_extension() {
        let files = vec![make_file("evil.sh", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        assert!(write_files(&files, dir.path(), false, false).is_err());
    }

    #[test]
    fn test_rejects_empty_filename() {
        let files = vec![make_file("", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        assert!(write_files(&files, dir.path(), false, false).is_err());
    }

    #[test]
    fn test_accepts_normal_rs_filename() {
        let files = vec![make_file("users.rs", "code", None)];
        let dir = tempfile::tempdir().unwrap();
        assert!(write_files(&files, dir.path(), false, false).is_ok());
    }
}