use anyhow::Result;
use std::collections::{BTreeMap, BTreeSet};
use std::fs;
use std::path::PathBuf;
use crate::catalog::Catalog;
use crate::catalog::id::DbObjectId;
use crate::diff::operations::{MigrationStep, SqlRenderer};
use crate::diff::plan;
/// Switches selecting which optional kinds of migration steps are written by
/// [`SchemaGenerator`].
///
/// A flag set to `false` makes the generator drop the matching steps before any
/// file is produced.
#[derive(Clone, Debug)]
pub struct SchemaGeneratorConfig {
    /// Keep comment steps.
    pub include_comments: bool,
    /// Keep grant steps.
    pub include_grants: bool,
    /// Keep trigger steps.
    pub include_triggers: bool,
    /// Keep extension steps.
    pub include_extensions: bool,
}
impl Default for SchemaGeneratorConfig {
fn default() -> Self {
Self {
include_comments: true,
include_grants: true,
include_triggers: true,
include_extensions: true,
}
}
}
/// One output file, assembled from all migration steps routed to it.
#[derive(Clone, Debug)]
struct FileContent {
    /// Destination on disk: the generator's `output_dir` joined with the file's relative key.
    path: PathBuf,
    /// Relative keys of other generated files this one needs; written as
    /// `-- require:` header lines.
    dependencies: Vec<String>,
    /// Rendered SQL statements, in the order their steps were routed to this file.
    sql_statements: Vec<String>,
}
/// Writes a catalog out as a directory tree of `.sql` files, one file per object.
pub struct SchemaGenerator {
    /// The schema objects to render.
    catalog: Catalog,
    /// Root directory that every generated file is written below.
    output_dir: PathBuf,
    /// Which optional kinds of steps to keep.
    config: SchemaGeneratorConfig,
}
impl SchemaGenerator {
pub fn new(catalog: Catalog, output_dir: PathBuf, config: SchemaGeneratorConfig) -> Self {
Self {
catalog,
output_dir,
config,
}
}
fn has_multiple_schemas(&self) -> bool {
self.catalog.schemas.iter().any(|s| s.name != "public")
}
fn schema_path_prefix(&self, schema: &str) -> String {
if self.has_multiple_schemas() {
format!("{}/", schema)
} else {
String::new()
}
}
pub fn generate_files(&self) -> Result<()> {
self.create_directory_structure()?;
let empty_catalog = Catalog::empty();
let ordered_steps = plan(&empty_catalog, &self.catalog)?;
let filtered_steps = self.filter_steps_by_config(ordered_steps);
let organized_files = self.organize_steps_into_files(filtered_steps)?;
self.write_organized_files(organized_files)?;
Ok(())
}
fn create_directory_structure(&self) -> Result<()> {
fs::create_dir_all(&self.output_dir)?;
if self.has_multiple_schemas() {
for schema in &self.catalog.schemas {
let schema_dir = self.output_dir.join(&schema.name);
fs::create_dir_all(schema_dir.join("tables"))?;
fs::create_dir_all(schema_dir.join("views"))?;
fs::create_dir_all(schema_dir.join("functions"))?;
fs::create_dir_all(schema_dir.join("types"))?;
fs::create_dir_all(schema_dir.join("aggregates"))?;
fs::create_dir_all(schema_dir.join("sequences"))?;
}
} else {
fs::create_dir_all(self.output_dir.join("tables"))?;
fs::create_dir_all(self.output_dir.join("views"))?;
fs::create_dir_all(self.output_dir.join("functions"))?;
fs::create_dir_all(self.output_dir.join("types"))?;
fs::create_dir_all(self.output_dir.join("aggregates"))?;
fs::create_dir_all(self.output_dir.join("sequences"))?;
}
Ok(())
}
fn filter_steps_by_config(&self, steps: Vec<MigrationStep>) -> Vec<MigrationStep> {
steps
.into_iter()
.filter(|step| match step {
MigrationStep::Grant(_) => self.config.include_grants,
MigrationStep::Trigger(_) => self.config.include_triggers,
MigrationStep::Extension(_) => self.config.include_extensions,
_ => {
if let DbObjectId::Comment { .. } = step.id() {
self.config.include_comments
} else {
true
}
}
})
.collect()
}
fn organize_steps_into_files(
&self,
steps: Vec<MigrationStep>,
) -> Result<BTreeMap<String, FileContent>> {
let mut steps_by_file: BTreeMap<String, Vec<MigrationStep>> = BTreeMap::new();
let mut object_to_file: BTreeMap<DbObjectId, String> = BTreeMap::new();
for step in steps {
let file_key = self.determine_file_for_step(&step);
let object_id = step.id();
object_to_file.insert(object_id, file_key.clone());
steps_by_file.entry(file_key).or_default().push(step);
}
let mut files: BTreeMap<String, FileContent> = BTreeMap::new();
for (file_key, file_steps) in steps_by_file {
let file_content = self.create_file_content(file_key, file_steps, &object_to_file)?;
files.insert(
file_content.path.to_string_lossy().to_string(),
file_content,
);
}
Ok(files)
}
fn determine_file_for_step(&self, step: &MigrationStep) -> String {
match step {
MigrationStep::Schema(_) => "schemas.sql".to_string(),
MigrationStep::Extension(_) => "extensions.sql".to_string(),
MigrationStep::Type(op) => {
let (schema, name) = self.extract_type_info_from_operation(op);
let prefix = self.schema_path_prefix(&schema);
format!("{}types/{}.sql", prefix, name)
}
MigrationStep::Domain(op) => {
let (schema, name) = self.extract_domain_info_from_operation(op);
let prefix = self.schema_path_prefix(&schema);
format!("{}domains/{}.sql", prefix, name)
}
MigrationStep::Table(op) => {
let (schema, name) = self.extract_table_info_from_operation(op);
let prefix = self.schema_path_prefix(&schema);
format!("{}tables/{}.sql", prefix, name)
}
MigrationStep::View(op) => {
let (schema, name) = self.extract_view_info_from_operation(op);
let prefix = self.schema_path_prefix(&schema);
format!("{}views/{}.sql", prefix, name)
}
MigrationStep::Function(op) => {
let (schema, name) = self.extract_function_info_from_operation(op);
let prefix = self.schema_path_prefix(&schema);
format!("{}functions/{}.sql", prefix, name)
}
MigrationStep::Aggregate(op) => {
let (schema, name) = self.extract_aggregate_info_from_operation(op);
let prefix = self.schema_path_prefix(&schema);
format!("{}aggregates/{}.sql", prefix, name)
}
MigrationStep::Operator(op) => {
let schema = op.db_object_id().schema().unwrap_or("public").to_string();
let prefix = self.schema_path_prefix(&schema);
format!("{}operators.sql", prefix)
}
MigrationStep::Cast(_) => {
"casts.sql".to_string()
}
MigrationStep::Sequence(op) => {
let (schema, name) = self.extract_sequence_info_from_operation(op);
if let Some((table_schema, table_name)) =
self.find_owning_table_for_sequence(&schema, &name)
{
let prefix = self.schema_path_prefix(&table_schema);
format!("{}tables/{}.sql", prefix, table_name)
} else {
let prefix = self.schema_path_prefix(&schema);
format!("{}sequences/{}.sql", prefix, name)
}
}
MigrationStep::Index(op) => {
let (schema, table_name) = self.extract_table_info_from_index_operation(op);
let prefix = self.schema_path_prefix(&schema);
format!("{}tables/{}.sql", prefix, table_name)
}
MigrationStep::Constraint(op) => {
let (schema, table_name) = self.extract_table_info_from_constraint_operation(op);
let prefix = self.schema_path_prefix(&schema);
format!("{}tables/{}.sql", prefix, table_name)
}
MigrationStep::Trigger(op) => {
let (schema, table_name) = self.extract_table_info_from_trigger_operation(op);
let prefix = self.schema_path_prefix(&schema);
format!("{}tables/{}.sql", prefix, table_name)
}
MigrationStep::Policy(op) => {
let (schema, table_name) = self.extract_table_info_from_policy_operation(op);
let prefix = self.schema_path_prefix(&schema);
format!("{}tables/{}.sql", prefix, table_name)
}
MigrationStep::Comment(op) => self.determine_file_for_object_id(&op.target().object),
MigrationStep::Grant(op) => match self.extract_grant_target(op) {
GrantTarget::Table { schema, name } => {
let prefix = self.schema_path_prefix(&schema);
format!("{}tables/{}.sql", prefix, name)
}
GrantTarget::View { schema, name } => {
let prefix = self.schema_path_prefix(&schema);
format!("{}views/{}.sql", prefix, name)
}
GrantTarget::Function { schema, name } => {
let prefix = self.schema_path_prefix(&schema);
format!("{}functions/{}.sql", prefix, name)
}
GrantTarget::Procedure { schema, name } => {
let prefix = self.schema_path_prefix(&schema);
format!("{}functions/{}.sql", prefix, name)
}
GrantTarget::Aggregate { schema, name } => {
let prefix = self.schema_path_prefix(&schema);
format!("{}aggregates/{}.sql", prefix, name)
}
GrantTarget::Schema => "schemas.sql".to_string(),
GrantTarget::Type { schema } => {
let prefix = self.schema_path_prefix(&schema);
format!("{}types.sql", prefix)
}
GrantTarget::Domain { schema } => {
let prefix = self.schema_path_prefix(&schema);
format!("{}domains.sql", prefix)
}
GrantTarget::Sequence { schema, name } => {
if let Some((table_schema, table_name)) =
self.find_owning_table_for_sequence(&schema, &name)
{
let prefix = self.schema_path_prefix(&table_schema);
format!("{}tables/{}.sql", prefix, table_name)
} else {
let prefix = self.schema_path_prefix(&schema);
format!("{}sequences/{}.sql", prefix, name)
}
}
},
}
}
fn determine_file_for_object_id(&self, id: &DbObjectId) -> String {
match id {
DbObjectId::Schema { .. } => "schemas.sql".to_string(),
DbObjectId::Extension { .. } => "extensions.sql".to_string(),
DbObjectId::Type { schema, name } => {
format!("{}types/{}.sql", self.schema_path_prefix(schema), name)
}
DbObjectId::Domain { schema, name } => {
format!("{}domains/{}.sql", self.schema_path_prefix(schema), name)
}
DbObjectId::Table { schema, name } => {
format!("{}tables/{}.sql", self.schema_path_prefix(schema), name)
}
DbObjectId::View { schema, name } => {
format!("{}views/{}.sql", self.schema_path_prefix(schema), name)
}
DbObjectId::Function { schema, name, .. }
| DbObjectId::Procedure { schema, name, .. } => {
format!("{}functions/{}.sql", self.schema_path_prefix(schema), name)
}
DbObjectId::Aggregate { schema, name, .. } => {
format!("{}aggregates/{}.sql", self.schema_path_prefix(schema), name)
}
DbObjectId::Operator { schema, .. } => {
format!("{}operators.sql", self.schema_path_prefix(schema))
}
DbObjectId::Cast { .. } => "casts.sql".to_string(),
DbObjectId::Sequence { schema, name } => {
match self.find_owning_table_for_sequence(schema, name) {
Some((table_schema, table_name)) => {
format!(
"{}tables/{}.sql",
self.schema_path_prefix(&table_schema),
table_name
)
}
None => format!("{}sequences/{}.sql", self.schema_path_prefix(schema), name),
}
}
DbObjectId::Constraint { schema, table, .. }
| DbObjectId::Trigger { schema, table, .. }
| DbObjectId::Policy { schema, table, .. }
| DbObjectId::Column { schema, table, .. } => {
format!("{}tables/{}.sql", self.schema_path_prefix(schema), table)
}
DbObjectId::Index { schema, name } => {
let (table_schema, table_name) = self
.catalog
.indexes
.iter()
.find(|i| i.schema == *schema && i.name == *name)
.map(|i| (i.table_schema.clone(), i.table_name.clone()))
.unwrap_or_else(|| (schema.clone(), "unknown".to_string()));
format!(
"{}tables/{}.sql",
self.schema_path_prefix(&table_schema),
table_name
)
}
DbObjectId::Grant { .. } | DbObjectId::Comment { .. } => {
unreachable!("a comment/grant id is not a routable object: {id:?}")
}
}
}
fn create_file_content(
&self,
file_key: String,
steps: Vec<MigrationStep>,
object_to_file: &BTreeMap<DbObjectId, String>,
) -> Result<FileContent> {
let file_path = self.output_dir.join(&file_key);
let dependencies = self.calculate_file_dependencies(&steps, object_to_file);
let mut sql_statements = Vec::new();
for step in steps {
let rendered_sqls = step.to_sql();
for rendered_sql in rendered_sqls {
sql_statements.push(rendered_sql.sql);
}
}
Ok(FileContent {
path: file_path,
dependencies,
sql_statements,
})
}
fn calculate_file_dependencies(
&self,
steps: &[MigrationStep],
object_to_file: &BTreeMap<DbObjectId, String>,
) -> Vec<String> {
let mut dependencies = BTreeSet::new();
let current_file_path = if let Some(first_step) = steps.first() {
self.determine_file_for_step(first_step)
} else {
return vec![];
};
for step in steps {
let step_deps = self.get_step_dependencies(step);
for dep in step_deps {
if let Some(file_path) = object_to_file.get(&dep)
&& *file_path != current_file_path
{
dependencies.insert(file_path.clone());
}
}
}
dependencies.into_iter().collect()
}
fn get_step_dependencies(&self, step: &MigrationStep) -> Vec<DbObjectId> {
let step_id = step.id();
self.catalog
.forward_deps
.get(&step_id)
.cloned()
.unwrap_or_default()
}
fn write_organized_files(&self, files: BTreeMap<String, FileContent>) -> Result<()> {
for (_, file_content) in files {
let mut content = String::new();
if !file_content.dependencies.is_empty() {
for dep in &file_content.dependencies {
content.push_str(&format!("-- require: {}\n", dep));
}
content.push('\n');
}
for (i, sql) in file_content.sql_statements.iter().enumerate() {
if i > 0 {
content.push('\n');
}
content.push_str(sql);
if !sql.ends_with(';') {
content.push(';');
}
content.push('\n');
}
if !content.trim().is_empty() {
if let Some(parent) = file_content.path.parent() {
fs::create_dir_all(parent)?;
}
fs::write(&file_content.path, content)?;
}
}
Ok(())
}
fn extract_table_info_from_operation(
&self,
op: &crate::diff::operations::TableOperation,
) -> (String, String) {
use crate::diff::operations::TableOperation;
match op {
TableOperation::Create { schema, name, .. } => (schema.clone(), name.clone()),
TableOperation::Drop { schema, name } => (schema.clone(), name.clone()),
TableOperation::Alter { schema, name, .. } => (schema.clone(), name.clone()),
}
}
fn extract_view_info_from_operation(
&self,
op: &crate::diff::operations::ViewOperation,
) -> (String, String) {
use crate::diff::operations::ViewOperation;
match op {
ViewOperation::Create { schema, name, .. } => (schema.clone(), name.clone()),
ViewOperation::Drop { schema, name } => (schema.clone(), name.clone()),
ViewOperation::Replace { schema, name, .. } => (schema.clone(), name.clone()),
ViewOperation::SetOption { schema, name, .. } => (schema.clone(), name.clone()),
}
}
fn extract_function_info_from_operation(
&self,
op: &crate::diff::operations::FunctionOperation,
) -> (String, String) {
use crate::diff::operations::FunctionOperation;
match op {
FunctionOperation::Create { schema, name, .. } => (schema.clone(), name.clone()),
FunctionOperation::Drop { schema, name, .. } => (schema.clone(), name.clone()),
FunctionOperation::Replace { schema, name, .. } => (schema.clone(), name.clone()),
}
}
fn extract_aggregate_info_from_operation(
&self,
op: &crate::diff::operations::AggregateOperation,
) -> (String, String) {
use crate::diff::operations::AggregateOperation;
match op {
AggregateOperation::Create { aggregate, .. } => {
(aggregate.schema.clone(), aggregate.name.clone())
}
AggregateOperation::Drop { identifier, .. } => {
(identifier.schema.clone(), identifier.name.clone())
}
AggregateOperation::Replace { new_aggregate, .. } => {
(new_aggregate.schema.clone(), new_aggregate.name.clone())
}
}
}
fn extract_sequence_info_from_operation(
&self,
op: &crate::diff::operations::SequenceOperation,
) -> (String, String) {
use crate::diff::operations::SequenceOperation;
match op {
SequenceOperation::Create { schema, name, .. } => (schema.clone(), name.clone()),
SequenceOperation::Drop { schema, name } => (schema.clone(), name.clone()),
SequenceOperation::AlterOwnership { schema, name, .. } => {
(schema.clone(), name.clone())
}
}
}
fn extract_type_info_from_operation(
&self,
op: &crate::diff::operations::TypeOperation,
) -> (String, String) {
use crate::diff::operations::TypeOperation;
match op {
TypeOperation::Create { schema, name, .. } => (schema.clone(), name.clone()),
TypeOperation::Drop { schema, name } => (schema.clone(), name.clone()),
TypeOperation::Alter { schema, name, .. } => (schema.clone(), name.clone()),
}
}
fn extract_domain_info_from_operation(
&self,
op: &crate::diff::operations::DomainOperation,
) -> (String, String) {
use crate::diff::operations::DomainOperation;
match op {
DomainOperation::Create { schema, name, .. }
| DomainOperation::Drop { schema, name }
| DomainOperation::AlterSetNotNull { schema, name }
| DomainOperation::AlterDropNotNull { schema, name }
| DomainOperation::AlterSetDefault { schema, name, .. }
| DomainOperation::AlterDropDefault { schema, name }
| DomainOperation::AddConstraint { schema, name, .. }
| DomainOperation::DropConstraint { schema, name, .. } => {
(schema.clone(), name.clone())
}
}
}
fn extract_table_info_from_index_operation(
&self,
op: &crate::diff::operations::IndexOperation,
) -> (String, String) {
use crate::diff::operations::IndexOperation;
match op {
IndexOperation::Create(index) => (index.table_schema.clone(), index.table_name.clone()),
IndexOperation::Drop { schema, name, .. } => {
for index in &self.catalog.indexes {
if index.schema == *schema && index.name == *name {
return (index.table_schema.clone(), index.table_name.clone());
}
}
(schema.clone(), "unknown".to_string())
}
IndexOperation::Cluster {
table_schema,
table_name,
..
} => (table_schema.clone(), table_name.clone()),
IndexOperation::SetWithoutCluster { schema, name, .. } => {
for index in &self.catalog.indexes {
if index.schema == *schema && index.name == *name {
return (index.table_schema.clone(), index.table_name.clone());
}
}
(schema.clone(), name.clone())
}
IndexOperation::Reindex { schema, name, .. } => {
for index in &self.catalog.indexes {
if index.schema == *schema && index.name == *name {
return (index.table_schema.clone(), index.table_name.clone());
}
}
(schema.clone(), "unknown".to_string())
}
}
}
fn extract_table_info_from_constraint_operation(
&self,
op: &crate::diff::operations::ConstraintOperation,
) -> (String, String) {
use crate::diff::operations::ConstraintOperation;
match op {
ConstraintOperation::Create(constraint) => {
(constraint.schema.clone(), constraint.table_name.clone())
}
ConstraintOperation::Drop(constraint_id) => (
constraint_id.schema.clone(),
constraint_id.table_name.clone(),
),
}
}
fn extract_table_info_from_trigger_operation(
&self,
op: &crate::diff::operations::TriggerOperation,
) -> (String, String) {
use crate::diff::operations::TriggerOperation;
match op {
TriggerOperation::Create { trigger } => {
(trigger.schema.clone(), trigger.table_name.clone())
}
TriggerOperation::Drop { identifier } => {
(identifier.schema.clone(), identifier.table.clone())
}
TriggerOperation::Replace { new_trigger, .. } => {
(new_trigger.schema.clone(), new_trigger.table_name.clone())
}
}
}
fn extract_table_info_from_policy_operation(
&self,
op: &crate::diff::operations::PolicyOperation,
) -> (String, String) {
use crate::diff::operations::PolicyOperation;
match op {
PolicyOperation::Create { policy } => {
(policy.schema.clone(), policy.table_name.clone())
}
PolicyOperation::Drop { identifier } => {
(identifier.schema.clone(), identifier.table.clone())
}
PolicyOperation::Alter { identifier, .. } => {
(identifier.schema.clone(), identifier.table.clone())
}
PolicyOperation::Replace { new_policy, .. } => {
(new_policy.schema.clone(), new_policy.table_name.clone())
}
}
}
fn extract_grant_target(&self, op: &crate::diff::operations::GrantOperation) -> GrantTarget {
use crate::catalog::id::DbObjectId;
use crate::diff::operations::GrantOperation;
let target = match op {
GrantOperation::Grant { grant } => &grant.target,
GrantOperation::Revoke { grant } => &grant.target,
GrantOperation::GrantColumns(cg) | GrantOperation::RevokeColumns(cg) => {
let (schema, name) = cg.relation_schema_and_name();
return GrantTarget::Table { schema, name };
}
};
if target.column_name().is_some() {
let (schema, name) = target.schema_and_name();
return GrantTarget::Table { schema, name };
}
match &target.object {
DbObjectId::Table { schema, name } => GrantTarget::Table {
schema: schema.clone(),
name: name.clone(),
},
DbObjectId::View { schema, name } => GrantTarget::View {
schema: schema.clone(),
name: name.clone(),
},
DbObjectId::Function { schema, name, .. } => GrantTarget::Function {
schema: schema.clone(),
name: name.clone(),
},
DbObjectId::Procedure { schema, name, .. } => GrantTarget::Procedure {
schema: schema.clone(),
name: name.clone(),
},
DbObjectId::Aggregate { schema, name, .. } => GrantTarget::Aggregate {
schema: schema.clone(),
name: name.clone(),
},
DbObjectId::Schema { .. } => GrantTarget::Schema,
DbObjectId::Type { schema, .. } => GrantTarget::Type {
schema: schema.clone(),
},
DbObjectId::Domain { schema, .. } => GrantTarget::Domain {
schema: schema.clone(),
},
DbObjectId::Sequence { schema, name } => GrantTarget::Sequence {
schema: schema.clone(),
name: name.clone(),
},
DbObjectId::Index { .. }
| DbObjectId::Constraint { .. }
| DbObjectId::Trigger { .. }
| DbObjectId::Policy { .. }
| DbObjectId::Extension { .. }
| DbObjectId::Operator { .. }
| DbObjectId::Cast { .. }
| DbObjectId::Grant { .. }
| DbObjectId::Comment { .. }
| DbObjectId::Column { .. } => {
unreachable!("not a grantable object kind: {:?}", target.object)
}
}
}
fn find_owning_table_for_sequence(
&self,
seq_schema: &str,
seq_name: &str,
) -> Option<(String, String)> {
for sequence in &self.catalog.sequences {
if sequence.schema == seq_schema && sequence.name == seq_name {
if let Some(ref owned_by) = sequence.owned_by {
let parts: Vec<&str> = owned_by.split('.').collect();
if parts.len() >= 3 {
return Some((parts[0].to_string(), parts[1].to_string()));
}
}
break;
}
}
None
}
}
#[derive(Debug, Clone)]
enum GrantTarget {
Table { schema: String, name: String },
View { schema: String, name: String },
Function { schema: String, name: String },
Procedure { schema: String, name: String },
Aggregate { schema: String, name: String },
Schema,
Type { schema: String },
Domain { schema: String },
Sequence { schema: String, name: String },
}