use std::path::Path;
use crate::error::Result;
use crate::codegen::GeneratedFile;
/// Banner line placed at the very top of every file this module writes (and of every
/// dry-run preview).
const COMMENT: &str = "// Auto-generated by sqlx-gen. Do not edit.";
/// Inner attribute emitted after the banner in every generated file.
/// NOTE(review): presumably here to silence `unused_attributes` warnings triggered by
/// attributes in the generated code — confirm against the codegen output.
const INNER_ATTR: &str = "#![allow(unused_attributes)]";
/// File name used for the combined output in single-file mode.
const SINGLE_FILE_NAME: &str = "models.rs";
/// File name of the module index written next to the per-file outputs in multi-file mode.
const MOD_FILE_NAME: &str = "mod.rs";

/// Writes generated code to `output_dir`, or previews it on stdout when `dry_run` is set.
///
/// * Multi-file mode (`single_file == false`): each [`GeneratedFile`] is written to
///   `output_dir/<filename>`, followed by a `mod.rs` declaring one `pub mod` per file.
/// * Single-file mode: all code is concatenated into `output_dir/models.rs`.
///
/// With `dry_run`, nothing on disk is touched (not even `output_dir` is created). Instead,
/// the contents of every file that *would* be written are printed to stdout, in write order
/// and each followed by a blank line. Both modes render through the same builders as the
/// real write, so the preview always matches the real output.
///
/// # Errors
///
/// Returns an error if `output_dir` cannot be created or any file cannot be written.
/// Files written before the failing one are left in place.
pub fn write_files(
    files: &[GeneratedFile],
    output_dir: &Path,
    single_file: bool,
    dry_run: bool,
) -> Result<()> {
    let outputs = render_outputs(files, single_file);

    if dry_run {
        for (_, content) in &outputs {
            println!("{}", content);
            println!();
        }
        return Ok(());
    }

    std::fs::create_dir_all(output_dir)?;
    for (name, content) in &outputs {
        let path = output_dir.join(name);
        std::fs::write(&path, content)?;
        log::info!("Wrote {}", path.display());
    }
    Ok(())
}

/// Renders every output file as `(file name, contents)` pairs, in the order they are written.
fn render_outputs(files: &[GeneratedFile], single_file: bool) -> Vec<(String, String)> {
    if single_file {
        return vec![(SINGLE_FILE_NAME.to_string(), build_single_content(files))];
    }
    let mut outputs = Vec::with_capacity(files.len() + 1);
    outputs.extend(
        files
            .iter()
            .map(|f| (f.filename.clone(), build_file_content(f))),
    );
    // The module index comes last, so a write failure part-way through never leaves a
    // `mod.rs` declaring modules whose files were not written.
    outputs.push((MOD_FILE_NAME.to_string(), build_mod_content(files)));
    outputs
}

/// Renders one standalone generated file for multi-file mode.
///
/// Layout: banner, optional `// <origin>` line, blank line, [`INNER_ATTR`], blank line,
/// then `f.code` verbatim (no trailing newline is added).
fn build_file_content(f: &GeneratedFile) -> String {
    let mut content = String::new();
    content.push_str(COMMENT);
    content.push('\n');
    if let Some(origin) = &f.origin {
        content.push_str(&format!("// {}\n", origin));
    }
    content.push('\n');
    content.push_str(INNER_ATTR);
    content.push_str("\n\n");
    content.push_str(&f.code);
    content
}

/// Renders the combined `models.rs` for single-file mode.
///
/// A single banner and [`INNER_ATTR`] header is followed by each file's code, one after
/// another. Each piece is preceded by a `// --- <origin> ---` separator when it has an
/// origin, and followed by a newline.
fn build_single_content(files: &[GeneratedFile]) -> String {
    let mut content = format!("{}\n\n{}\n\n", COMMENT, INNER_ATTR);
    for f in files {
        if let Some(origin) = &f.origin {
            content.push_str(&format!("// --- {} ---\n\n", origin));
        }
        content.push_str(&f.code);
        content.push('\n');
    }
    content
}

/// Renders the multi-file `mod.rs`: the header, then one `pub mod <name>;` per file.
///
/// The module name is the file name with a trailing `.rs` removed; any other file name is
/// used as-is.
fn build_mod_content(files: &[GeneratedFile]) -> String {
    let mut content = format!("{}\n\n{}\n\n", COMMENT, INNER_ATTR);
    for f in files {
        let module = f.filename.strip_suffix(".rs").unwrap_or(&f.filename);
        content.push_str(&format!("pub mod {};\n", module));
    }
    content
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::GeneratedFile;

    /// Builds a `GeneratedFile` fixture from borrowed parts.
    fn file(name: &str, body: &str, origin: Option<&str>) -> GeneratedFile {
        GeneratedFile {
            filename: name.to_owned(),
            origin: origin.map(str::to_owned),
            code: body.to_owned(),
        }
    }

    /// Performs a real (non-dry-run) write into a fresh temporary directory.
    ///
    /// Keep the returned handle alive while inspecting the directory; `tempfile` removes
    /// the directory when the handle is dropped.
    fn write_to_tempdir(files: &[GeneratedFile], single_file: bool) -> tempfile::TempDir {
        let dir = tempfile::tempdir().unwrap();
        write_files(files, dir.path(), single_file, false).unwrap();
        dir
    }

    /// Reads the file `name` inside `dir` into a string.
    fn read(dir: &tempfile::TempDir, name: &str) -> String {
        std::fs::read_to_string(dir.path().join(name)).unwrap()
    }

    // ---- build_file_content -------------------------------------------------

    #[test]
    fn test_build_content_with_origin() {
        let rendered = build_file_content(&file(
            "users.rs",
            "pub struct Users {}",
            Some("Table: public.users"),
        ));
        for needle in [COMMENT, INNER_ATTR, "// Table: public.users", "pub struct Users {}"] {
            assert!(rendered.contains(needle), "expected {:?} in output", needle);
        }
    }

    #[test]
    fn test_build_content_without_origin() {
        let rendered = build_file_content(&file("types.rs", "pub enum Status {}", None));
        for needle in [COMMENT, INNER_ATTR, "pub enum Status {}"] {
            assert!(rendered.contains(needle), "expected {:?} in output", needle);
        }
        assert!(!rendered.contains("// Table:"));
    }

    #[test]
    fn test_build_content_header_value() {
        let rendered = build_file_content(&file("x.rs", "", None));
        assert!(rendered.starts_with("// Auto-generated by sqlx-gen. Do not edit."));
    }

    #[test]
    fn test_build_content_preserves_code() {
        let source = "use chrono::NaiveDateTime;\n\npub struct Foo {\n pub x: i32,\n}";
        assert!(build_file_content(&file("foo.rs", source, None)).contains(source));
    }

    #[test]
    fn test_build_content_origin_format() {
        let rendered = build_file_content(&file("x.rs", "code", Some("Table: public.users")));
        assert!(rendered.contains("// Table: public.users\n"));
    }

    #[test]
    fn test_build_content_empty_code() {
        let rendered = build_file_content(&file("x.rs", "", Some("Table: public.x")));
        for needle in [COMMENT, INNER_ATTR, "// Table: public.x"] {
            assert!(rendered.contains(needle), "expected {:?} in output", needle);
        }
    }

    // ---- dry run ------------------------------------------------------------

    #[test]
    fn test_dry_run_returns_ok() {
        let dir = tempfile::tempdir().unwrap();
        let outcome = write_files(
            &[file("users.rs", "code", Some("origin"))],
            dir.path(),
            false,
            true,
        );
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_dry_run_no_files_created() {
        let dir = tempfile::tempdir().unwrap();
        let target = dir.path().join("output");
        let _ = write_files(&[file("users.rs", "code", Some("origin"))], &target, false, true);
        assert!(!target.exists());
    }

    #[test]
    fn test_dry_run_empty_files() {
        let dir = tempfile::tempdir().unwrap();
        assert!(write_files(&[], dir.path(), false, true).is_ok());
    }

    // ---- multi-file mode ----------------------------------------------------

    #[test]
    fn test_multi_creates_files_and_mod() {
        let dir = write_to_tempdir(
            &[
                file("users.rs", "pub struct Users {}", Some("Table: public.users")),
                file("posts.rs", "pub struct Posts {}", Some("Table: public.posts")),
            ],
            false,
        );
        for name in ["users.rs", "posts.rs", "mod.rs"] {
            assert!(dir.path().join(name).exists(), "{} was not written", name);
        }
    }

    #[test]
    fn test_multi_mod_rs_content() {
        let dir = write_to_tempdir(
            &[
                file("users.rs", "code", Some("origin")),
                file("types.rs", "code", None),
            ],
            false,
        );
        let index = read(&dir, "mod.rs");
        assert!(index.contains("pub mod users;"));
        assert!(index.contains("pub mod types;"));
    }

    #[test]
    fn test_multi_file_has_header() {
        let dir = write_to_tempdir(&[file("users.rs", "code", Some("Table: public.users"))], false);
        assert!(read(&dir, "users.rs").starts_with(COMMENT));
    }

    #[test]
    fn test_multi_file_has_origin() {
        let dir = write_to_tempdir(&[file("users.rs", "code", Some("Table: public.users"))], false);
        assert!(read(&dir, "users.rs").contains("// Table: public.users"));
    }

    #[test]
    fn test_multi_creates_output_dir() {
        let root = tempfile::tempdir().unwrap();
        let nested = root.path().join("nested").join("output");
        write_files(&[file("users.rs", "code", Some("origin"))], &nested, false, false).unwrap();
        assert!(nested.join("users.rs").exists());
    }

    #[test]
    fn test_multi_file_no_origin() {
        let dir = write_to_tempdir(&[file("types.rs", "code", None)], false);
        let written = read(&dir, "types.rs");
        assert!(written.contains(COMMENT));
        assert!(written.contains(INNER_ATTR));
        assert!(!written.contains("// Table:"));
    }

    // ---- single-file mode ---------------------------------------------------

    #[test]
    fn test_single_creates_models_rs() {
        let dir = write_to_tempdir(
            &[file("users.rs", "pub struct Users {}", Some("Table: public.users"))],
            true,
        );
        assert!(dir.path().join("models.rs").exists());
    }

    #[test]
    fn test_single_starts_with_header() {
        let dir = write_to_tempdir(&[file("users.rs", "code", Some("origin"))], true);
        assert!(read(&dir, "models.rs").starts_with(COMMENT));
    }

    #[test]
    fn test_single_has_section_separator() {
        let dir = write_to_tempdir(&[file("users.rs", "code", Some("Table: public.users"))], true);
        assert!(read(&dir, "models.rs").contains("// --- Table: public.users ---"));
    }

    #[test]
    fn test_single_concatenates_all_code() {
        let dir = write_to_tempdir(
            &[
                file("users.rs", "struct Users;", Some("Table: public.users")),
                file("posts.rs", "struct Posts;", Some("Table: public.posts")),
            ],
            true,
        );
        let combined = read(&dir, "models.rs");
        assert!(combined.contains("struct Users;"));
        assert!(combined.contains("struct Posts;"));
    }

    #[test]
    fn test_single_no_origin_no_separator() {
        let dir = write_to_tempdir(&[file("types.rs", "code", None)], true);
        assert!(!read(&dir, "models.rs").contains("// ---"));
    }
}