use anyhow::Result;
use std::path::Path;
use std::time::{SystemTime, UNIX_EPOCH};
use super::InitArgs;
use super::import::{ImportSource, import_schema};
use super::project::{create_project_structure, generate_config_file};
use super::prompts::{gather_init_options_with_args, prompt_baseline_creation};
use crate::baseline::operations::{
BaselineCreationRequest, create_baseline, display_baseline_summary,
};
use crate::catalog::Catalog;
use crate::config::load_config;
use crate::constants::CONFIG_FILENAME;
use crate::db::connection::mask_url_password;
use crate::migration_tracking;
use crate::prompts::ShadowDatabaseInput;
/// What `check_existing_config` decided to do about a `pgmt.yaml` already
/// present in the project directory.
#[derive(Debug)]
pub enum ExistingConfigResult {
    /// There is no config file in the project directory.
    NotFound,
    /// Modify the existing configuration; the loaded config is carried so its
    /// values can be shown as defaults.
    Update(Box<crate::config::types::ConfigInput>),
    /// Start over, either chosen interactively or forced by `--fresh`.
    Fresh,
    /// Keep the current configuration and stop.
    Cancelled,
}
/// Look for an existing config file in `project_dir` and decide how `init` proceeds.
///
/// * No file: returns [`ExistingConfigResult::NotFound`] without prompting.
/// * `force_fresh`: warns that the file will be overwritten and returns `Fresh`.
/// * Otherwise: loads the file, prints a short summary (dev URL with password
///   masked, directories, shadow Docker version) and asks the user to
///   update, start fresh, or cancel.
///
/// # Errors
/// Propagates failures from loading the config or from the interactive prompt.
pub fn check_existing_config(
    project_dir: &Path,
    force_fresh: bool,
) -> Result<ExistingConfigResult> {
    let config_path = project_dir.join(CONFIG_FILENAME);
    if !config_path.exists() {
        return Ok(ExistingConfigResult::NotFound);
    }
    if force_fresh {
        println!(
            "⚠️ Existing {} will be overwritten (--fresh flag)\n",
            CONFIG_FILENAME
        );
        return Ok(ExistingConfigResult::Fresh);
    }

    let path_text = config_path.to_string_lossy();
    let (config, _) = load_config(&path_text)?;

    println!("📋 Existing configuration found:\n");
    // Only settings actually present in the file are listed.
    let db = config.databases.as_ref();
    let dirs = config.directories.as_ref();
    if let Some(dev_url) = db.and_then(|d| d.dev_url.as_ref()) {
        println!(" Database: {}", mask_url_password(dev_url));
    }
    if let Some(dir) = dirs.and_then(|d| d.schema_dir.as_ref()) {
        println!(" Schema dir: {}", dir);
    }
    if let Some(dir) = dirs.and_then(|d| d.migrations_dir.as_ref()) {
        println!(" Migrations: {}", dir);
    }
    if let Some(dir) = dirs.and_then(|d| d.baselines_dir.as_ref()) {
        println!(" Baselines: {}", dir);
    }
    let shadow_version = db
        .and_then(|d| d.shadow.as_ref())
        .and_then(|shadow| shadow.docker.as_ref())
        .and_then(|docker| docker.version.as_ref());
    if let Some(version) = shadow_version {
        println!(" Shadow PG: {}", version);
    }
    println!();

    let menu = vec![
        "Update - modify existing configuration",
        "Fresh - start over with new configuration",
        "Cancel - keep current configuration",
    ];
    let picked = dialoguer::Select::new()
        .with_prompt("What would you like to do?")
        .items(&menu)
        .default(0)
        .interact()?;

    let outcome = match picked {
        0 => {
            println!("\n✏️ Update mode: existing values will be shown as defaults\n");
            ExistingConfigResult::Update(Box::new(config))
        }
        1 => {
            println!("\n🔄 Fresh mode: creating new configuration\n");
            ExistingConfigResult::Fresh
        }
        _ => {
            println!("\n❌ Keeping existing configuration");
            ExistingConfigResult::Cancelled
        }
    };
    Ok(outcome)
}
/// Everything `pgmt init` needs to scaffold a project, as produced by
/// `gather_init_options_with_args` and later adjusted by the import steps.
#[derive(Debug)]
pub struct InitOptions {
    /// Project root; `schema_dir`, `baselines_dir` and `roles_file` are joined onto it.
    pub project_dir: std::path::PathBuf,
    /// Development database URL.
    pub dev_database_url: String,
    /// How the shadow database should be provisioned (auto, Docker, or a manual URL).
    pub shadow_config: ShadowDatabaseInput,
    /// PostgreSQL version chosen for the shadow database; preferred over `detected_pg_version`.
    pub shadow_pg_version: Option<String>,
    /// PostgreSQL version detected from the dev database, used when none was chosen.
    pub detected_pg_version: Option<String>,
    /// Schema directory, relative to `project_dir`.
    pub schema_dir: std::path::PathBuf,
    /// Migrations directory.
    pub migrations_dir: String,
    /// Baselines directory, relative to `project_dir`.
    pub baselines_dir: String,
    /// Where to import an existing schema from, if anywhere.
    pub import_source: Option<ImportSource>,
    /// Which optional object kinds are written into generated schema files.
    pub object_config: ObjectManagementConfig,
    /// Whether/how to create an initial baseline.
    pub baseline_config: BaselineCreationConfig,
    /// Migration tracking table; passed to `ObjectFilter` when filtering an
    /// imported catalog, so the field is read and needs no `dead_code` allowance.
    pub tracking_table: crate::config::types::TrackingTable,
    /// Optional roles file, relative to `project_dir`.
    pub roles_file: Option<String>,
    /// Object include/exclude configuration; substrate exclusions are appended
    /// to `objects.exclude.schemas` during import.
    pub objects: crate::config::types::Objects,
    /// Schemas the user chose to exclude as pre-existing "substrate" during import.
    pub substrate_exclusions: Vec<String>,
}
/// Toggles for the optional object kinds written into generated schema files.
///
/// Each flag is forwarded to the matching `include_*` field of
/// `SchemaGeneratorConfig` when schema files are generated.
#[derive(Debug, Clone)]
pub struct ObjectManagementConfig {
    /// Forwarded as `include_comments`.
    pub comments: bool,
    /// Forwarded as `include_grants`.
    pub grants: bool,
    /// Forwarded as `include_triggers`.
    pub triggers: bool,
    /// Forwarded as `include_extensions`.
    pub extensions: bool,
}

impl Default for ObjectManagementConfig {
    /// Every optional object kind is managed unless the user opts out.
    fn default() -> Self {
        let enabled = true;
        Self {
            comments: enabled,
            grants: enabled,
            triggers: enabled,
            extensions: enabled,
        }
    }
}
/// Baseline-related choices for `pgmt init`.
#[derive(Debug, Clone, Default)]
pub struct BaselineCreationConfig {
    /// `Some(true)`/`Some(false)` force the decision; `None` means ask the user.
    pub create_baseline: Option<bool>,
    /// Baseline description; callers fall back to `"baseline"` when unset.
    pub description: Option<String>,
}
/// Entry point for `pgmt init`.
///
/// The steps run in this order:
/// 1. Resolve how to treat an existing config.
/// 2. Gather options and confirm them (skipped with `--defaults`).
/// 3. Create the project directories.
/// 4. Optionally import and filter an existing schema.
/// 5. Choose object management settings.
/// 6. Generate files and a baseline from the import.
/// 7. Write the config file and print a summary.
pub async fn cmd_init_with_args(args: &InitArgs) -> Result<()> {
    println!("🚀 Welcome to pgmt! Let's set up your PostgreSQL migration project.\n");
    let project_dir = std::env::current_dir()?;

    let prior_input = match check_existing_config(&project_dir, args.fresh)? {
        ExistingConfigResult::Update(boxed) => Some(*boxed),
        ExistingConfigResult::NotFound | ExistingConfigResult::Fresh => None,
        ExistingConfigResult::Cancelled => return Ok(()),
    };

    let mut options = gather_init_options_with_args(args, prior_input.as_ref()).await?;
    let interactive = !args.defaults;

    // The confirmation prompt only runs in interactive mode (short-circuit).
    if interactive && !super::prompts::prompt_project_confirmation(&options)? {
        println!("❌ Project initialization cancelled by user.");
        return Ok(());
    }

    println!("🏗️ Creating project structure...");
    create_project_structure(&options)?;
    println!("✅ Project directories created");

    let imported = match options.import_source.clone() {
        Some(source) => import_catalog_from_source(&source, &options).await?,
        None => None,
    };
    // Record the user's substrate exclusions, then filter the catalog with the
    // updated object settings.
    let catalog = imported.map(|(raw_catalog, exclusions)| {
        if !exclusions.is_empty() {
            options
                .objects
                .exclude
                .schemas
                .extend(exclusions.iter().cloned());
            options.substrate_exclusions = exclusions;
        }
        let filter =
            crate::config::filter::ObjectFilter::new(&options.objects, &options.tracking_table);
        filter.filter_catalog(raw_catalog)
    });

    match catalog.as_ref() {
        Some(cat) => {
            show_catalog_preview(cat);
            if interactive {
                options.object_config =
                    super::prompts::prompt_object_management_config_with_context(cat)?;
            }
        }
        None if interactive => {
            options.object_config = super::prompts::prompt_object_management_config()?;
        }
        None => {}
    }

    let baseline_result = match catalog.as_ref() {
        Some(cat) => process_imported_catalog(cat, &options).await?,
        None => BaselineResult::NotRequested,
    };

    println!("📝 Generating configuration file...");
    generate_config_file(&options, prior_input.as_ref(), &options.project_dir)?;
    println!("✅ pgmt.yaml created");
    print_success_summary(&options, &baseline_result);
    Ok(())
}
/// Translate the user's shadow-database choice into a config `ShadowDatabase`.
///
/// * `Manual(url)` becomes a `Url` shadow with the default reset mode.
/// * `Docker { image, platform }` becomes a Docker shadow with those settings.
/// * `Auto` becomes a Docker shadow pinned to the major version of
///   `shadow_pg_version` (or, failing that, `detected_pg_version`). With no
///   version known, it stays `ShadowDatabase::Auto`.
fn resolve_shadow_database(
    shadow_config: &ShadowDatabaseInput,
    shadow_pg_version: Option<&String>,
    detected_pg_version: Option<&String>,
) -> crate::config::types::ShadowDatabase {
    use crate::config::types::{ShadowDatabase, ShadowDockerConfig, ShadowResetMode};
    match shadow_config {
        ShadowDatabaseInput::Manual(url) => ShadowDatabase::Url {
            url: url.clone(),
            reset: ShadowResetMode::default(),
        },
        ShadowDatabaseInput::Docker { image, platform } => {
            ShadowDatabase::Docker(ShadowDockerConfig {
                image: image.clone(),
                platform: platform.clone(),
                ..Default::default()
            })
        }
        ShadowDatabaseInput::Auto => match shadow_pg_version.or(detected_pg_version) {
            Some(full_version) => ShadowDatabase::Docker(ShadowDockerConfig {
                version: Some(crate::prompts::extract_major_version(full_version)),
                ..Default::default()
            }),
            None => ShadowDatabase::Auto,
        },
    }
}
/// Import a schema from `import_source` into a [`Catalog`].
///
/// SQL-file and directory sources are loaded through the shadow database
/// resolved from `options`. When that shadow is a URL in `Clean` reset mode,
/// the user must first confirm the reset. For those sources, "substrate"
/// schemas already present in a non-`Clean` shadow are offered for
/// exclusion. Other sources are passed an empty shadow URL.
///
/// # Returns
/// * `Ok(Some((catalog, substrate_exclusions)))` on a successful import.
/// * `Ok(None)` when the import failed and the user chose to continue with an
///   empty project.
///
/// # Errors
/// * The user declines the shadow reset.
/// * The user chooses to exit after a failed import.
/// * A prompt or the shadow connection fails.
async fn import_catalog_from_source(
    import_source: &ImportSource,
    options: &InitOptions,
) -> Result<Option<(Catalog, Vec<String>)>> {
    use crate::config::types::{ShadowDatabase, ShadowResetMode};
    println!("📥 Importing existing schema...");
    println!(" Source: {}", import_source.description());
    let shadow_database = resolve_shadow_database(
        &options.shadow_config,
        options.shadow_pg_version.as_ref(),
        options.detected_pg_version.as_ref(),
    );
    let sql_source = matches!(
        import_source,
        ImportSource::SqlFile(_) | ImportSource::Directory(_)
    );
    // A user-supplied shadow in Clean mode gets its managed schemas dropped;
    // never do that without explicit consent.
    if let ShadowDatabase::Url {
        url,
        reset: ShadowResetMode::Clean,
    } = &shadow_database
        && sql_source
    {
        println!(
            "⚠️ The shadow database at {} will be reset: every schema pgmt manages will be dropped before the import.",
            crate::db::connection::mask_url_password(url)
        );
        let confirmed = dialoguer::Confirm::new()
            .with_prompt(" Reset this database and continue?")
            .default(false)
            .interact()?;
        if !confirmed {
            return Err(anyhow::anyhow!(
                "import cancelled — shadow database left untouched"
            ));
        }
    }
    let (shadow_url, substrate_exclusions) = if sql_source {
        let shadow_url = shadow_database.get_connection_string().await?;
        let branch_backed = !matches!(
            shadow_database,
            ShadowDatabase::Url {
                reset: ShadowResetMode::Clean,
                ..
            }
        );
        let exclusions = if branch_backed {
            let substrate = fetch_substrate_schemas(&shadow_url).await?;
            if substrate.is_empty() {
                Vec::new()
            } else {
                super::prompts::prompt_substrate_exclusions(&substrate)?
            }
        } else {
            Vec::new()
        };
        (shadow_url, exclusions)
    } else {
        (String::new(), Vec::new())
    };
    let roles_path = options
        .roles_file
        .as_ref()
        .map(|f| options.project_dir.join(f));
    match import_schema(
        import_source.clone(),
        &shadow_url,
        roles_path.as_deref(),
        &options.objects,
    )
    .await
    {
        Ok(catalog) => {
            println!("✅ Schema import completed");
            Ok(Some((catalog, substrate_exclusions)))
        }
        Err(e) => {
            eprintln!("\n⚠️ Schema import failed:\n{:#}", e);
            eprintln!("\n🔧 What would you like to do?");
            let recovery_options = vec![
                "Skip import and continue with empty project",
                "Exit setup (you can run 'pgmt init' again later)",
            ];
            let choice = dialoguer::Select::new()
                .with_prompt("Choose an option")
                .items(&recovery_options)
                .default(0)
                .interact()?;
            match choice {
                0 => {
                    println!(
                        "⚠️ Skipping schema import. You can add schema files manually later."
                    );
                    println!(
                        " 💡 Tip: You can also try importing again with 'pgmt apply' after setup."
                    );
                    eprintln!(" Continuing with empty project setup...");
                    Ok(None)
                }
                1 => {
                    println!("❌ Setup cancelled. Run 'pgmt init' again when ready.");
                    // Propagate instead of calling `std::process::exit`, so callers
                    // unwind normally (destructors run) — the same way the
                    // reset-declined path above reports a user cancellation.
                    Err(anyhow::anyhow!("setup cancelled by user"))
                }
                _ => Ok(None),
            }
        }
    }
}
/// List the non-system schemas already present in the shadow database.
///
/// `pg_catalog`, `information_schema`, `pg_toast`, `public` and the
/// per-session temp/toast-temp schemas are excluded. Results are sorted by name.
///
/// # Errors
/// Returns an error if connecting or querying fails. The pool is closed
/// before a query error is propagated, mirroring the success path.
pub async fn fetch_substrate_schemas(shadow_url: &str) -> Result<Vec<String>> {
    let pool = crate::db::connection::connect_with_retry_quiet(shadow_url).await?;
    let fetched = sqlx::query_as::<_, (String,)>(
        "SELECT nspname FROM pg_namespace
WHERE nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'public')
AND nspname NOT LIKE 'pg_temp_%'
AND nspname NOT LIKE 'pg_toast_temp_%'
ORDER BY nspname",
    )
    .fetch_all(&pool)
    .await;
    // Close unconditionally; previously a failed query skipped this.
    pool.close().await;
    Ok(fetched?.into_iter().map(|(name,)| name).collect())
}
/// How the post-import baseline step of `pgmt init` ended.
#[derive(Debug, Clone)]
pub enum BaselineResult {
    /// No baseline was attempted (no import, empty import, or the user declined).
    NotRequested,
    /// A baseline was written and recorded as applied.
    Created,
    /// Files were generated but need manual fixes (e.g. a circular dependency).
    NeedsAttention {
        /// Short, user-facing description of what needs attention.
        reason: String,
    },
    /// Generation, validation or baseline creation failed; carries the error text.
    Failed(String),
}
/// Turn an imported catalog into schema files, validate them, and optionally
/// create an initial baseline.
///
/// Failures are reported to the user and returned as a [`BaselineResult`],
/// not as `Err`. The stored error text uses anyhow's alternate (`{:#}`)
/// formatting, so the full cause chain reaches `print_success_summary`,
/// whose classification looks for phrases like "does not exist".
///
/// # Errors
/// Only prompt failures and a clock before the Unix epoch are returned as `Err`.
async fn process_imported_catalog(
    catalog: &Catalog,
    options: &InitOptions,
) -> Result<BaselineResult> {
    let total_objects = count_catalog_objects(catalog);
    if total_objects == 0 {
        println!("⚠️ No database objects found in the imported schema.");
        println!(" Continuing with empty schema directory...");
        return Ok(BaselineResult::NotRequested);
    }

    println!("\n📝 Generating schema files from your database...");
    let file_count = match generate_schema_files(catalog, options).await {
        Ok(count) => count,
        Err(e) => {
            let message = format!("{:#}", e);
            eprintln!("❌ Schema file generation failed: {}", message);
            return Ok(BaselineResult::Failed(message));
        }
    };
    println!("✅ Generated {} schema files", file_count);

    println!("\n🔍 Validating schema files...");
    let schema_dir = options.project_dir.join(&options.schema_dir);
    let roles_path = options
        .roles_file
        .as_ref()
        .map(|f| options.project_dir.join(f));
    if let Err(e) = validate_schema_files(
        &schema_dir,
        roles_path.as_deref(),
        &options.shadow_config,
        options.shadow_pg_version.as_ref(),
        options.detected_pg_version.as_ref(),
    )
    .await
    {
        return Ok(report_validation_failure(&format!("{:#}", e)));
    }
    println!("✅ Schema validation passed");

    let database_state = analyze_database_state(catalog);
    // An explicit CLI choice wins; otherwise ask.
    let should_create_baseline = match options.baseline_config.create_baseline {
        Some(explicit) => explicit,
        None => prompt_baseline_creation(&database_state)?,
    };
    if !should_create_baseline {
        return Ok(BaselineResult::NotRequested);
    }

    let version = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
    match create_baseline_with_migration_sync(catalog, options, version).await {
        Ok(_) => Ok(BaselineResult::Created),
        Err(e) => {
            handle_baseline_failure(&e);
            Ok(BaselineResult::Failed(format!("{:#}", e)))
        }
    }
}

/// Print guidance for a failed schema validation and classify it.
///
/// `error_str` is the full (`{:#}`) error text. Circular dependencies become
/// `NeedsAttention`; anything else is printed verbatim and becomes `Failed`
/// carrying the same full text (previously only the outer message was kept).
fn report_validation_failure(error_str: &str) -> BaselineResult {
    if error_str.contains("Circular dependency") {
        println!("\n📌 Circular dependency detected in schema files");
        if let Some(cycle_info) = extract_circular_dep_info(error_str) {
            println!(" {}", cycle_info);
        }
        println!();
        println!(" This is common in complex databases with bidirectional foreign keys.");
        println!(" To fix: move one foreign key to a separate file (e.g., constraints/)");
        println!(" so the tables can be created before the constraint is added.");
        return BaselineResult::NeedsAttention {
            reason: "Circular dependency detected".to_string(),
        };
    }
    println!("\n━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
    println!("⚠️ SCHEMA VALIDATION FAILED");
    println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
    println!("{}\n", error_str);
    println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
    println!("Next steps:");
    println!(" 1. Fix dependencies in schema files (add '-- require:' statements)");
    println!(" 2. Test with: pgmt apply --dry-run");
    println!(" 3. Repeat until validation passes");
    println!(" 4. Create baseline: pgmt migrate baseline\n");
    BaselineResult::Failed(error_str.to_string())
}
/// Total number of objects across the catalog categories that `init` reports on.
fn count_catalog_objects(catalog: &Catalog) -> usize {
    let per_kind = [
        catalog.tables.len(),
        catalog.views.len(),
        catalog.functions.len(),
        catalog.types.len(),
        catalog.sequences.len(),
        catalog.indexes.len(),
        catalog.constraints.len(),
        catalog.triggers.len(),
        catalog.extensions.len(),
        catalog.grants.len(),
    ];
    per_kind.into_iter().sum()
}
/// Coarse classification of an imported schema, passed to the baseline prompt.
#[derive(Debug)]
pub enum DatabaseState {
    /// At most one object was imported.
    Empty,
    /// More than one object was imported.
    Existing {
        /// Number of objects counted by `count_catalog_objects`.
        object_count: usize,
    },
}
/// Classify an imported catalog for the baseline prompt.
///
/// A catalog with zero or one object counts as `Empty`.
// NOTE(review): treating a single object as "empty" looks deliberate (e.g. a
// lone default extension) but is not explained here — confirm intent.
fn analyze_database_state(catalog: &Catalog) -> DatabaseState {
    match count_catalog_objects(catalog) {
        0 | 1 => DatabaseState::Empty,
        object_count => DatabaseState::Existing { object_count },
    }
}
/// Print a per-category count of the imported objects plus the grand total.
fn show_catalog_preview(catalog: &Catalog) {
    // (text before the count, count, text after the count)
    let rows = [
        (" 📋 ", catalog.tables.len(), " tables"),
        (" 👁 ", catalog.views.len(), " views"),
        (" ⚙️ ", catalog.functions.len(), " functions"),
        (" 🏷️ ", catalog.types.len(), " custom types"),
        (" 🔢 ", catalog.sequences.len(), " sequences"),
        (" 📇 ", catalog.indexes.len(), " indexes"),
        (" 🔗 ", catalog.constraints.len(), " constraints"),
        (" ⚡ ", catalog.triggers.len(), " triggers"),
        (" 🧩 ", catalog.extensions.len(), " extensions"),
        (" 🔑 ", catalog.grants.len(), " grants"),
    ];
    println!("\n📊 Schema Import Preview:");
    for (before, count, after) in rows {
        println!("{}{}{}", before, count, after);
    }
    println!(" ═══════════════════");
    println!(" 📦 {} total objects", count_catalog_objects(catalog));
}
/// Pull the cycle description out of a "Circular dependency detected: …" error.
///
/// Returns the trimmed remainder of the line that follows the first
/// occurrence of the marker. Returns `None` when the marker is absent or
/// nothing but whitespace follows it on that line.
fn extract_circular_dep_info(error_str: &str) -> Option<String> {
    const MARKER: &str = "Circular dependency detected:";
    let (_, tail) = error_str.split_once(MARKER)?;
    let first_line = tail.lines().next()?.trim();
    (!first_line.is_empty()).then(|| first_line.to_string())
}
/// Explain a baseline-creation failure and suggest next steps.
///
/// The error is rendered with anyhow's alternate formatting (`{:#}`), which
/// includes the whole context chain. A root cause like "relation … does not
/// exist" that sits inside a wrapped error is therefore both shown and
/// detected. `to_string()` only yields the outermost context.
fn handle_baseline_failure(error: &anyhow::Error) {
    let message = format!("{:#}", error);
    println!("\n❌ Baseline creation failed: {}", message);
    if message.contains("relation") && message.contains("does not exist") {
        println!("\n🔍 This error often indicates missing function dependencies.");
        println!(" Some functions may reference tables that haven't been loaded yet.");
        println!(" This is a known limitation - see README for details.");
        println!("\n💡 Common fixes:");
        println!(" • Add '-- require: tables/table_name.sql' to function files");
        println!(" • Check function bodies for table references");
        println!(" • Ensure proper loading order in your schema files");
    } else {
        println!("\n🔍 Baseline creation encountered an error.");
        println!(" This might be due to:");
        println!(" • Missing dependencies between schema objects");
        println!(" • Permission issues");
        println!(" • Database connection problems");
    }
    println!("\n⚠️ Skipping baseline creation due to errors.");
    println!("💡 After fixing the dependency issues, run: pgmt migrate baseline");
}
async fn create_baseline_with_migration_sync(
catalog: &Catalog,
options: &InitOptions,
version: u64,
) -> Result<(std::path::PathBuf, String)> {
println!("💾 Creating baseline from current database state...");
let request = BaselineCreationRequest {
catalog: catalog.clone(),
version,
description: options
.baseline_config
.description
.clone()
.unwrap_or_else(|| "baseline".to_string()),
baselines_dir: options.project_dir.join(&options.baselines_dir),
verbose: false, };
let result = create_baseline(request).await?;
println!(
"✅ Created baseline: {}",
result.path.file_name().unwrap().to_str().unwrap()
);
display_baseline_summary(&result);
println!("🔄 Marking baseline as applied in migration tracking...");
use sqlx::PgPool;
let dev_pool = PgPool::connect(&options.dev_database_url).await?;
let tracking_table = crate::config::types::TrackingTable {
schema: "public".to_string(),
name: "pgmt_migrations".to_string(),
};
let checksum = migration_tracking::calculate_checksum(&result.baseline_sql);
migration_tracking::record_baseline_as_applied(
&dev_pool,
&tracking_table,
version,
&options
.baseline_config
.description
.clone()
.unwrap_or_else(|| "baseline".to_string()),
&checksum,
)
.await?;
println!("✅ Baseline marked as applied in migration tracking");
println!("💡 Future migrations will only contain NEW changes");
Ok((result.path, result.baseline_sql))
}
pub fn print_success_summary(options: &InitOptions, baseline_result: &BaselineResult) {
match baseline_result {
BaselineResult::Created => {
println!("\n🎉 Project initialized successfully!");
println!("\n📝 Created:");
println!(" ✅ pgmt.yaml (configuration)");
println!(
" ✅ {} directory with modular files",
options.schema_dir.display()
);
println!(" ✅ migrations/ directory");
println!(" ✅ schema_baselines/ directory");
println!(" ✅ Initial baseline from existing database");
println!("\nNext steps:");
println!(" 🚀 Run 'pgmt migrate new \"description\"' to create new migrations");
println!(" 💡 Future migrations will only contain NEW changes");
}
BaselineResult::NeedsAttention { reason } => {
println!("\n🎉 Project initialized successfully!");
println!("\n📝 Created:");
println!(" ✅ pgmt.yaml (configuration)");
println!(
" ✅ {} directory with modular files",
options.schema_dir.display()
);
println!(" ✅ migrations/ directory");
println!(" ✅ schema_baselines/ directory");
println!("\n📌 {}", reason);
println!("\nNext steps:");
println!(
" 1. Move one foreign key from the cycle to a separate file (e.g., schema/constraints/)"
);
println!(" 2. Test with: pgmt apply --dry-run");
println!(" 3. Create baseline: pgmt migrate baseline");
println!(" 💻 Run 'pgmt apply' to sync your dev database");
println!(" 🚀 Run 'pgmt migrate new \"description\"' to create migrations");
}
BaselineResult::Failed(error) => {
if error.contains("relation") || error.contains("does not exist") {
println!("\n⚠️ Project initialized - schema validation failed\n");
println!("📝 Created:");
println!(" ✅ pgmt.yaml");
println!(
" ✅ {} (needs dependency fixes)",
options.schema_dir.display()
);
println!(" ✅ migrations/");
println!("\n🔧 Next steps:");
println!(" 1. Fix schema dependencies (see error above)");
println!(" 2. Test with: pgmt apply --dry-run");
println!(" 3. Repeat until validation passes");
println!(" 4. Create baseline: pgmt migrate baseline");
} else {
let was_explicit_request =
matches!(options.baseline_config.create_baseline, Some(true));
if was_explicit_request {
println!("\n⚠️ Project partially initialized - baseline creation failed!");
println!("\n📝 Created:");
println!(" ✅ pgmt.yaml (configuration)");
println!(
" ✅ {} directory with modular files",
options.schema_dir.display()
);
println!(" ✅ migrations/ directory");
println!(" ✅ schema_baselines/ directory");
println!(" ❌ Initial baseline creation failed: {}", error);
println!("\nNext steps:");
println!(" 🔧 Fix the baseline creation issue:");
println!(" • Check database connectivity and permissions");
println!(" • Review schema file dependencies");
println!(" • Consider running 'pgmt migrate baseline' manually");
println!(" 💻 Run 'pgmt apply' to sync your dev database");
println!(" 🚀 Run 'pgmt migrate new \"description\"' to create migrations");
} else {
println!("\n🎉 Project initialized successfully!");
println!("\n📝 Created:");
println!(" ✅ pgmt.yaml (configuration)");
println!(
" ✅ {} directory with modular files",
options.schema_dir.display()
);
println!(" ✅ migrations/ directory");
println!(" ✅ schema_baselines/ directory");
println!(" ⚠️ Baseline creation failed (see error above)");
println!("\nNext steps:");
println!(" 💡 Fix the issue and create baseline: pgmt migrate baseline");
println!(" 💻 Run 'pgmt apply' to sync your dev database");
println!(" 🚀 Run 'pgmt migrate new \"description\"' to create migrations");
}
}
}
BaselineResult::NotRequested => {
println!("\n🎉 Project initialized successfully!");
println!("\n📝 Created:");
println!(" ✅ pgmt.yaml (configuration)");
println!(
" ✅ {} directory with modular files",
options.schema_dir.display()
);
println!(" ✅ migrations/ directory");
println!(" ✅ schema_baselines/ directory");
println!("\nNext steps:");
println!(" 💻 Run 'pgmt apply' to sync your dev database");
println!(
" 📝 Add schema files to {} and customize as needed",
options.schema_dir.display()
);
println!(" 🚀 Run 'pgmt migrate new \"description\"' to create migrations");
}
}
println!(" 📚 Visit https://docs.pgmt.dev for more information");
}
/// Apply the generated schema directory to a freshly cleaned shadow database
/// to prove it loads.
///
/// Delegates all work to `validate_schema_files_impl`.
async fn validate_schema_files(
    schema_dir: &std::path::Path,
    roles_file: Option<&std::path::Path>,
    shadow_config: &ShadowDatabaseInput,
    shadow_pg_version: Option<&String>,
    detected_pg_version: Option<&String>,
) -> Result<()> {
    let validation = validate_schema_files_impl(
        schema_dir,
        roles_file,
        shadow_config,
        shadow_pg_version,
        detected_pg_version,
    );
    validation.await
}
async fn validate_schema_files_impl(
schema_dir: &std::path::Path,
roles_file: Option<&std::path::Path>,
shadow_config: &ShadowDatabaseInput,
shadow_pg_version: Option<&String>,
detected_pg_version: Option<&String>,
) -> Result<()> {
use crate::db::cleaner;
use crate::db::connection::connect_with_retry;
use crate::db::schema_processor::{SchemaProcessor, SchemaProcessorConfig};
let shadow_database =
resolve_shadow_database(shadow_config, shadow_pg_version, detected_pg_version);
let shadow_url = shadow_database.get_connection_string().await?;
let pool = connect_with_retry(&shadow_url).await?;
cleaner::clean_shadow_db(&pool, &crate::config::types::Objects::default()).await?;
if let Some(roles_path) = roles_file
&& roles_path.exists()
{
crate::schema_ops::apply_roles_file(&pool, roles_path).await?;
}
let config = SchemaProcessorConfig {
verbose: false, clean_before_apply: false, ..Default::default()
};
let processor = SchemaProcessor::new(pool.clone(), config);
processor.process_schema_directory(schema_dir).await?;
pool.close().await;
Ok(())
}
/// Write schema files for `catalog` under `project_dir/schema_dir` and
/// return how many `.sql` files that directory now contains.
///
/// The object-kind toggles in `options.object_config` control which optional
/// objects the generator emits.
async fn generate_schema_files(catalog: &Catalog, options: &InitOptions) -> Result<usize> {
    use crate::schema_generator::{SchemaGenerator, SchemaGeneratorConfig};
    let target_dir = options.project_dir.join(&options.schema_dir);
    std::fs::create_dir_all(&target_dir)?;

    let toggles = &options.object_config;
    let generator = SchemaGenerator::new(
        catalog.clone(),
        target_dir.clone(),
        SchemaGeneratorConfig {
            include_comments: toggles.comments,
            include_grants: toggles.grants,
            include_triggers: toggles.triggers,
            include_extensions: toggles.extensions,
        },
    );
    generator.generate_files()?;

    let written = count_generated_files(&target_dir)?;
    Ok(written)
}
/// Recursively count the `.sql` files under `schema_dir`.
///
/// A directory that does not exist counts as zero files. Entries are
/// classified with `Path::is_file` / `Path::is_dir`, which follow symlinks.
///
/// Takes `&Path` (any `&PathBuf` coerces) and returns `std::io::Result`.
/// Every error here is an I/O error, and `?` still converts it into
/// `anyhow::Error` at call sites.
///
/// # Errors
/// Returns the first I/O error from reading a directory or one of its entries.
fn count_generated_files(schema_dir: &std::path::Path) -> std::io::Result<usize> {
    if !schema_dir.exists() {
        return Ok(0);
    }
    let mut count = 0;
    for entry in std::fs::read_dir(schema_dir)? {
        let path = entry?.path();
        if path.is_file() && path.extension().and_then(|ext| ext.to_str()) == Some("sql") {
            count += 1;
        } else if path.is_dir() {
            count += count_generated_files(&path)?;
        }
    }
    Ok(count)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Create (or recreate) an empty scratch directory under the system temp dir.
    fn scratch_dir(name: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(name);
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();
        dir
    }

    #[test]
    fn test_object_management_config_default() {
        let config = ObjectManagementConfig::default();
        assert!(config.comments);
        assert!(config.grants);
        assert!(config.triggers);
        assert!(config.extensions);
    }

    #[test]
    fn test_baseline_creation_config_default() {
        let config = BaselineCreationConfig::default();
        assert!(config.create_baseline.is_none());
        assert!(config.description.is_none());
    }

    #[test]
    fn test_count_catalog_objects() {
        let catalog = Catalog::empty();
        assert_eq!(count_catalog_objects(&catalog), 0);
    }

    #[test]
    fn test_analyze_database_state_empty_catalog() {
        let catalog = Catalog::empty();
        assert!(matches!(
            analyze_database_state(&catalog),
            DatabaseState::Empty
        ));
    }

    #[test]
    fn test_count_generated_files() {
        let temp_dir = scratch_dir("pgmt_test_count_files");
        std::fs::write(temp_dir.join("test1.sql"), "SELECT 1;").unwrap();
        std::fs::write(temp_dir.join("test2.sql"), "SELECT 2;").unwrap();
        std::fs::write(temp_dir.join("readme.txt"), "Not SQL").unwrap();
        let count = count_generated_files(&temp_dir).unwrap();
        assert_eq!(count, 2);
        let _ = std::fs::remove_dir_all(&temp_dir);
    }

    #[test]
    fn test_count_generated_files_recurses_into_subdirectories() {
        let temp_dir = scratch_dir("pgmt_test_count_nested_files");
        std::fs::create_dir_all(temp_dir.join("tables").join("nested")).unwrap();
        std::fs::write(temp_dir.join("root.sql"), "SELECT 1;").unwrap();
        std::fs::write(temp_dir.join("tables").join("users.sql"), "SELECT 1;").unwrap();
        std::fs::write(
            temp_dir.join("tables").join("nested").join("orders.sql"),
            "SELECT 1;",
        )
        .unwrap();
        std::fs::write(temp_dir.join("tables").join("notes.md"), "not sql").unwrap();
        let count = count_generated_files(&temp_dir).unwrap();
        assert_eq!(count, 3);
        let _ = std::fs::remove_dir_all(&temp_dir);
    }

    #[test]
    fn test_count_generated_files_missing_dir() {
        let missing = std::env::temp_dir().join("pgmt_test_count_missing_dir");
        let _ = std::fs::remove_dir_all(&missing);
        assert_eq!(count_generated_files(&missing).unwrap(), 0);
    }

    #[test]
    fn test_extract_circular_dep_info() {
        assert_eq!(
            extract_circular_dep_info("x: Circular dependency detected: a -> b -> a\nmore"),
            Some("a -> b -> a".to_string())
        );
        assert_eq!(extract_circular_dep_info("unrelated error"), None);
        assert_eq!(
            extract_circular_dep_info("Circular dependency detected:   \nnext line"),
            None
        );
    }

    #[test]
    fn test_check_existing_config_not_found() {
        let temp_dir = scratch_dir("pgmt_test_no_config");
        let result = check_existing_config(&temp_dir, false).unwrap();
        assert!(matches!(result, ExistingConfigResult::NotFound));
        let _ = std::fs::remove_dir_all(&temp_dir);
    }

    #[test]
    fn test_check_existing_config_fresh_flag() {
        let temp_dir = scratch_dir("pgmt_test_fresh_flag");
        let config_content = r#"
databases:
 dev_url: postgres://localhost/test
"#;
        // Use the same constant the code under test joins onto the directory.
        std::fs::write(temp_dir.join(CONFIG_FILENAME), config_content).unwrap();
        let result = check_existing_config(&temp_dir, true).unwrap();
        assert!(matches!(result, ExistingConfigResult::Fresh));
        let _ = std::fs::remove_dir_all(&temp_dir);
    }
}