use std::io::Write;
use std::path::Path;
use crate::error::Result;
use crate::codegen::GeneratedFile;
const COMMENT: &str = "// Auto-generated by sqlx-gen. Do not edit.";
const INNER_ATTR: &str = "#![allow(unused_attributes)]";
pub(crate) fn write_atomic(path: &Path, content: &[u8]) -> Result<()> {
let parent = path.parent().ok_or_else(|| {
crate::error::Error::Config(format!(
"Cannot determine parent directory of {}",
path.display()
))
})?;
let mut tmp = tempfile::NamedTempFile::new_in(parent)?;
tmp.write_all(content)?;
tmp.flush()?;
tmp.persist(path).map_err(|e| e.error)?;
Ok(())
}
/// Validates the generated files and either prints or writes them.
///
/// Every filename is checked with `validate_safe_filename` before anything
/// else happens. When `dry_run` is set, the content of each file (as produced
/// by `build_file_content`) is printed to stdout followed by a blank line, and
/// nothing is written to disk. Otherwise `output_dir` is created, including
/// missing parents, and the files are written through `write_single_file`
/// (when `single_file` is set) or `write_multi_files`.
///
/// # Errors
///
/// Returns the first filename validation error, or any error raised while
/// creating the output directory or writing the files.
pub fn write_files(
    files: &[GeneratedFile],
    output_dir: &Path,
    single_file: bool,
    dry_run: bool,
) -> Result<()> {
    files
        .iter()
        .try_for_each(|file| validate_safe_filename(&file.filename))?;

    if dry_run {
        for file in files {
            println!("{}", build_file_content(file));
            println!();
        }
        return Ok(());
    }

    std::fs::create_dir_all(output_dir)?;

    if single_file {
        write_single_file(files, output_dir)
    } else {
        write_multi_files(files, output_dir)
    }
}
fn validate_safe_filename(filename: &str) -> Result<()> {
let p = Path::new(filename);
if filename.is_empty()
|| p.components().count() != 1
|| p.is_absolute()
|| filename.contains("..")
|| filename.contains('/')
|| filename.contains('\\')
|| !filename.ends_with(".rs")
{
return Err(crate::error::Error::Config(format!(
"Refusing to write generated file with unsafe name: {:?}",
filename
)));
}
Ok(())
}
/// Assembles the full text of one generated file.
///
/// Layout: the `COMMENT` header line, an optional `// <origin>` line, a blank
/// line, the `INNER_ATTR` line, another blank line, and finally the file's
/// code exactly as stored in `f.code`.
fn build_file_content(f: &GeneratedFile) -> String {
    // Empty when there is no origin, so the header is followed directly by the blank line.
    let origin_line = match &f.origin {
        Some(origin) => format!("// {}\n", origin),
        None => String::new(),
    };

    format!(
        "{}\n{}\n{}\n\n{}",
        COMMENT, origin_line, INNER_ATTR, f.code
    )
}
/// Concatenates all generated code into `models.rs` inside `output_dir`.
///
/// The file starts with the `COMMENT` header and `INNER_ATTR`, each followed
/// by a blank line. Every entry then contributes an optional
/// `// --- <origin> ---` separator followed by a blank line, then its code
/// and a trailing newline.
///
/// # Errors
///
/// Propagates any error from `write_atomic`.
fn write_single_file(files: &[GeneratedFile], output_dir: &Path) -> Result<()> {
    let mut combined = String::new();
    combined.push_str(COMMENT);
    combined.push_str("\n\n");
    combined.push_str(INNER_ATTR);
    combined.push_str("\n\n");

    for file in files {
        if let Some(origin) = file.origin.as_ref() {
            combined.push_str("// --- ");
            combined.push_str(origin);
            combined.push_str(" ---\n\n");
        }
        combined.push_str(&file.code);
        combined.push('\n');
    }

    let target = output_dir.join("models.rs");
    write_atomic(&target, combined.as_bytes())?;
    log::info!("Wrote {}", target.display());
    Ok(())
}
fn write_multi_files(files: &[GeneratedFile], output_dir: &Path) -> Result<()> {
let mut mod_entries = Vec::new();
for f in files {
let content = build_file_content(f);
let path = output_dir.join(&f.filename);
write_atomic(&path, content.as_bytes())?;
log::info!("Wrote {}", path.display());
let mod_name = f.filename.strip_suffix(".rs").unwrap_or(&f.filename);
mod_entries.push(mod_name.to_string());
}
let mut mod_content = format!("{}\n\n{}\n\n", COMMENT, INNER_ATTR);
for m in &mod_entries {
mod_content.push_str(&format!("pub mod {};\n", m));
}
let mod_path = output_dir.join("mod.rs");
write_atomic(&mod_path, mod_content.as_bytes())?;
log::info!("Wrote {}", mod_path.display());
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::GeneratedFile;

    /// Builds a `GeneratedFile` fixture from borrowed parts.
    fn make_file(filename: &str, code: &str, origin: Option<&str>) -> GeneratedFile {
        GeneratedFile {
            filename: filename.to_owned(),
            origin: origin.map(str::to_owned),
            code: code.to_owned(),
        }
    }

    /// Creates a fresh temporary directory for one test.
    fn scratch_dir() -> tempfile::TempDir {
        tempfile::tempdir().unwrap()
    }

    /// Reads `name` from `dir` as UTF-8, panicking if it is missing.
    fn read_output(dir: &std::path::Path, name: &str) -> String {
        std::fs::read_to_string(dir.join(name)).unwrap()
    }

    /// Runs a real multi-file write of a single file called `filename`.
    fn multi_write_result(filename: &str) -> Result<()> {
        let batch = vec![make_file(filename, "code", None)];
        let root = scratch_dir();
        write_files(&batch, root.path(), false, false)
    }

    // --- build_file_content ---

    #[test]
    fn test_build_content_with_origin() {
        let file = make_file(
            "users.rs",
            "pub struct Users {}",
            Some("Table: public.users"),
        );
        let text = build_file_content(&file);
        for expected in [
            COMMENT,
            INNER_ATTR,
            "// Table: public.users",
            "pub struct Users {}",
        ] {
            assert!(text.contains(expected));
        }
    }

    #[test]
    fn test_build_content_without_origin() {
        let file = make_file("types.rs", "pub enum Status {}", None);
        let text = build_file_content(&file);
        assert!(text.contains(COMMENT));
        assert!(text.contains(INNER_ATTR));
        assert!(!text.contains("// Table:"));
        assert!(text.contains("pub enum Status {}"));
    }

    #[test]
    fn test_build_content_header_value() {
        let text = build_file_content(&make_file("x.rs", "", None));
        assert!(text.starts_with("// Auto-generated by sqlx-gen. Do not edit."));
    }

    #[test]
    fn test_build_content_preserves_code() {
        let code = "use chrono::NaiveDateTime;\n\npub struct Foo {\n pub x: i32,\n}";
        let text = build_file_content(&make_file("foo.rs", code, None));
        assert!(text.contains(code));
    }

    #[test]
    fn test_build_content_origin_format() {
        let text = build_file_content(&make_file("x.rs", "code", Some("Table: public.users")));
        assert!(text.contains("// Table: public.users\n"));
    }

    #[test]
    fn test_build_content_empty_code() {
        let text = build_file_content(&make_file("x.rs", "", Some("Table: public.x")));
        assert!(text.contains(COMMENT));
        assert!(text.contains(INNER_ATTR));
        assert!(text.contains("// Table: public.x"));
    }

    // --- dry run ---

    #[test]
    fn test_dry_run_returns_ok() {
        let batch = vec![make_file("users.rs", "code", Some("origin"))];
        let root = scratch_dir();
        let outcome = write_files(&batch, root.path(), false, true);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_dry_run_no_files_created() {
        let batch = vec![make_file("users.rs", "code", Some("origin"))];
        let root = scratch_dir();
        let target = root.path().join("output");
        let _ = write_files(&batch, &target, false, true);
        assert!(!target.exists());
    }

    #[test]
    fn test_dry_run_empty_files() {
        let root = scratch_dir();
        let outcome = write_files(&[], root.path(), false, true);
        assert!(outcome.is_ok());
    }

    // --- multi-file mode ---

    #[test]
    fn test_multi_creates_files_and_mod() {
        let batch = vec![
            make_file(
                "users.rs",
                "pub struct Users {}",
                Some("Table: public.users"),
            ),
            make_file(
                "posts.rs",
                "pub struct Posts {}",
                Some("Table: public.posts"),
            ),
        ];
        let root = scratch_dir();
        write_files(&batch, root.path(), false, false).unwrap();
        for name in ["users.rs", "posts.rs", "mod.rs"] {
            assert!(root.path().join(name).exists());
        }
    }

    #[test]
    fn test_multi_mod_rs_content() {
        let batch = vec![
            make_file("users.rs", "code", Some("origin")),
            make_file("types.rs", "code", None),
        ];
        let root = scratch_dir();
        write_files(&batch, root.path(), false, false).unwrap();
        let index = read_output(root.path(), "mod.rs");
        assert!(index.contains("pub mod users;"));
        assert!(index.contains("pub mod types;"));
    }

    #[test]
    fn test_multi_file_has_header() {
        let batch = vec![make_file("users.rs", "code", Some("Table: public.users"))];
        let root = scratch_dir();
        write_files(&batch, root.path(), false, false).unwrap();
        assert!(read_output(root.path(), "users.rs").starts_with(COMMENT));
    }

    #[test]
    fn test_multi_file_has_origin() {
        let batch = vec![make_file("users.rs", "code", Some("Table: public.users"))];
        let root = scratch_dir();
        write_files(&batch, root.path(), false, false).unwrap();
        assert!(read_output(root.path(), "users.rs").contains("// Table: public.users"));
    }

    #[test]
    fn test_multi_creates_output_dir() {
        let root = scratch_dir();
        let nested = root.path().join("nested").join("output");
        let batch = vec![make_file("users.rs", "code", Some("origin"))];
        write_files(&batch, &nested, false, false).unwrap();
        assert!(nested.join("users.rs").exists());
    }

    #[test]
    fn test_multi_file_no_origin() {
        let batch = vec![make_file("types.rs", "code", None)];
        let root = scratch_dir();
        write_files(&batch, root.path(), false, false).unwrap();
        let text = read_output(root.path(), "types.rs");
        assert!(text.contains(COMMENT));
        assert!(text.contains(INNER_ATTR));
        assert!(!text.contains("// Table:"));
    }

    // --- single-file mode ---

    #[test]
    fn test_single_creates_models_rs() {
        let batch = vec![make_file(
            "users.rs",
            "pub struct Users {}",
            Some("Table: public.users"),
        )];
        let root = scratch_dir();
        write_files(&batch, root.path(), true, false).unwrap();
        assert!(root.path().join("models.rs").exists());
    }

    #[test]
    fn test_single_starts_with_header() {
        let batch = vec![make_file("users.rs", "code", Some("origin"))];
        let root = scratch_dir();
        write_files(&batch, root.path(), true, false).unwrap();
        assert!(read_output(root.path(), "models.rs").starts_with(COMMENT));
    }

    #[test]
    fn test_single_has_section_separator() {
        let batch = vec![make_file("users.rs", "code", Some("Table: public.users"))];
        let root = scratch_dir();
        write_files(&batch, root.path(), true, false).unwrap();
        assert!(read_output(root.path(), "models.rs").contains("// --- Table: public.users ---"));
    }

    #[test]
    fn test_single_concatenates_all_code() {
        let batch = vec![
            make_file("users.rs", "struct Users;", Some("Table: public.users")),
            make_file("posts.rs", "struct Posts;", Some("Table: public.posts")),
        ];
        let root = scratch_dir();
        write_files(&batch, root.path(), true, false).unwrap();
        let text = read_output(root.path(), "models.rs");
        assert!(text.contains("struct Users;"));
        assert!(text.contains("struct Posts;"));
    }

    #[test]
    fn test_single_no_origin_no_separator() {
        let batch = vec![make_file("types.rs", "code", None)];
        let root = scratch_dir();
        write_files(&batch, root.path(), true, false).unwrap();
        assert!(!read_output(root.path(), "models.rs").contains("// ---"));
    }

    // --- write_atomic ---

    #[test]
    fn test_atomic_creates_file_with_content() {
        let root = scratch_dir();
        let target = root.path().join("out.rs");
        write_atomic(&target, b"hello").unwrap();
        assert_eq!(std::fs::read_to_string(&target).unwrap(), "hello");
    }

    #[test]
    fn test_atomic_overwrites_existing_file() {
        let root = scratch_dir();
        let target = root.path().join("out.rs");
        std::fs::write(&target, "old").unwrap();
        write_atomic(&target, b"new").unwrap();
        assert_eq!(std::fs::read_to_string(&target).unwrap(), "new");
    }

    #[test]
    fn test_atomic_leaves_no_temp_artifacts_on_success() {
        let root = scratch_dir();
        let target = root.path().join("out.rs");
        write_atomic(&target, b"x").unwrap();

        let mut names = Vec::new();
        for entry in std::fs::read_dir(root.path()).unwrap() {
            names.push(entry.unwrap().file_name());
        }
        assert_eq!(names.len(), 1);
        assert_eq!(names[0].to_string_lossy(), "out.rs");
    }

    // --- filename validation through write_files ---

    #[test]
    fn test_rejects_dot_dot_in_filename() {
        assert!(multi_write_result("../escape.rs").is_err());
    }

    #[test]
    fn test_rejects_absolute_path_filename() {
        assert!(multi_write_result("/etc/passwd").is_err());
    }

    #[test]
    fn test_rejects_path_separator_in_filename() {
        assert!(multi_write_result("sub/dir/file.rs").is_err());
    }

    #[test]
    fn test_rejects_non_rs_extension() {
        assert!(multi_write_result("evil.sh").is_err());
    }

    #[test]
    fn test_rejects_empty_filename() {
        assert!(multi_write_result("").is_err());
    }

    #[test]
    fn test_accepts_normal_rs_filename() {
        assert!(multi_write_result("users.rs").is_ok());
    }
}