use crate::catalog;
use crate::catalog::id::DbObjectId;
use crate::config::types::{ObjectExclude, ObjectInclude, Objects, TrackingTable};
use glob::Pattern;
pub struct ObjectFilter {
include: ObjectInclude,
exclude: ObjectExclude,
tracking_table: TrackingTable,
}
impl ObjectFilter {
pub fn new(config: &Objects, tracking_table: &TrackingTable) -> Self {
Self {
include: config.include.clone(),
exclude: config.exclude.clone(),
tracking_table: tracking_table.clone(),
}
}
pub fn from_config(config: &crate::config::types::Config) -> Self {
Self::new(&config.objects, &config.migration.tracking_table)
}
pub fn should_include_schema(&self, schema_name: &str) -> bool {
if self.matches_patterns(&self.exclude.schemas, schema_name) {
return false;
}
if !self.include.schemas.is_empty() {
return self.matches_patterns(&self.include.schemas, schema_name);
}
true
}
pub fn should_include_table(&self, schema_name: &str, table_name: &str) -> bool {
if self.is_pgmt_internal_table(schema_name, table_name) {
return false;
}
if !self.should_include_schema(schema_name) {
return false;
}
if self.matches_patterns(&self.exclude.tables, table_name) {
return false;
}
if !self.include.tables.is_empty() {
return self.matches_patterns(&self.include.tables, table_name);
}
true
}
pub fn is_pgmt_internal_table(&self, schema_name: &str, table_name: &str) -> bool {
if schema_name != self.tracking_table.schema {
return false;
}
let internal_tables = [
self.tracking_table.name.as_str(), &format!("{}_sections", self.tracking_table.name), ];
internal_tables.contains(&table_name)
}
pub fn filter_catalog(&self, mut catalog: catalog::Catalog) -> catalog::Catalog {
catalog
.schemas
.retain(|schema| self.should_include_schema(&schema.name));
catalog
.tables
.retain(|table| self.should_include_table(&table.schema, &table.name));
catalog
.views
.retain(|view| self.should_include_table(&view.schema, &view.name));
catalog
.functions
.retain(|function| self.should_include_schema(&function.schema));
catalog
.types
.retain(|custom_type| self.should_include_schema(&custom_type.schema));
catalog
.sequences
.retain(|sequence| self.should_include_schema(&sequence.schema));
catalog
.indexes
.retain(|index| self.should_include_table(&index.schema, &index.table_name));
catalog.constraints.retain(|constraint| {
self.should_include_table(&constraint.schema, &constraint.table_name)
});
catalog
.triggers
.retain(|trigger| self.should_include_table(&trigger.schema, &trigger.table_name));
catalog.grants.retain(|grant| {
match &grant.target.object {
DbObjectId::Table { schema, name } | DbObjectId::View { schema, name } => {
self.should_include_table(schema, name)
}
_ => self.should_include_schema(&grant.target.schema()),
}
});
catalog
.extensions
.retain(|ext| self.should_include_schema(&ext.schema));
let stale: Vec<_> = catalog
.forward_deps
.keys()
.filter(|id| !catalog.contains_id(id))
.cloned()
.collect();
for id in stale {
catalog.forward_deps.remove(&id);
}
catalog.rebuild_reverse_deps();
catalog
}
fn matches_patterns(&self, patterns: &[String], name: &str) -> bool {
if patterns.is_empty() {
return false;
}
patterns.iter().any(|pattern| {
Pattern::new(pattern)
.map(|p| p.matches(name))
.unwrap_or(false)
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with include and exclude patterns for both schemas and tables.
    fn create_test_objects() -> Objects {
        Objects {
            include: ObjectInclude {
                schemas: vec!["public".to_string(), "app".to_string()],
                tables: vec!["users".to_string(), "posts".to_string()],
            },
            exclude: ObjectExclude {
                schemas: vec!["pg_*".to_string(), "information_schema".to_string()],
                tables: vec!["temp_*".to_string()],
            },
        }
    }

    /// Tracking table located at `public.pgmt_migrations`.
    fn create_test_tracking_table() -> TrackingTable {
        TrackingTable {
            schema: "public".to_string(),
            name: "pgmt_migrations".to_string(),
        }
    }

    #[test]
    fn test_schema_filtering() {
        let filter = ObjectFilter::new(&create_test_objects(), &create_test_tracking_table());
        assert!(filter.should_include_schema("public"));
        assert!(filter.should_include_schema("app"));
        assert!(!filter.should_include_schema("pg_catalog"));
        assert!(!filter.should_include_schema("information_schema"));
        assert!(!filter.should_include_schema("other"));
    }

    #[test]
    fn test_table_filtering() {
        let filter = ObjectFilter::new(&create_test_objects(), &create_test_tracking_table());
        assert!(filter.should_include_table("public", "users"));
        assert!(filter.should_include_table("app", "posts"));
        assert!(!filter.should_include_table("public", "temp_data"));
        assert!(!filter.should_include_table("public", "other_table"));
        assert!(!filter.should_include_table("pg_catalog", "pg_tables"));
        assert!(!filter.should_include_table("public", "pgmt_migrations"));
    }

    #[test]
    fn test_pgmt_internal_tables() {
        let filter = ObjectFilter::new(&create_test_objects(), &create_test_tracking_table());
        assert!(filter.is_pgmt_internal_table("public", "pgmt_migrations"));
        assert!(filter.is_pgmt_internal_table("public", "pgmt_migrations_sections"));
        assert!(!filter.is_pgmt_internal_table("other", "pgmt_migrations"));
        assert!(!filter.is_pgmt_internal_table("public", "users"));
    }

    #[test]
    fn test_empty_include_patterns() {
        let objects = Objects {
            include: ObjectInclude {
                schemas: vec![],
                tables: vec![],
            },
            exclude: ObjectExclude {
                schemas: vec!["pg_*".to_string()],
                tables: vec!["temp_*".to_string()],
            },
        };
        let filter = ObjectFilter::new(&objects, &create_test_tracking_table());
        assert!(filter.should_include_schema("public"));
        assert!(filter.should_include_schema("app"));
        assert!(!filter.should_include_schema("pg_catalog"));
    }

    #[test]
    fn test_exclude_takes_precedence_over_include() {
        let objects = Objects {
            include: ObjectInclude {
                schemas: vec!["public".to_string(), "app".to_string()],
                tables: vec!["users".to_string()],
            },
            exclude: ObjectExclude {
                schemas: vec!["pub*".to_string()],
                tables: vec!["user*".to_string()],
            },
        };
        let filter = ObjectFilter::new(&objects, &create_test_tracking_table());
        // Explicitly included, but also excluded: exclusion wins.
        assert!(!filter.should_include_schema("public"));
        assert!(filter.should_include_schema("app"));
        assert!(!filter.should_include_table("app", "users"));
    }

    #[test]
    fn test_invalid_patterns_never_match() {
        // "***" is rejected by the glob pattern parser.
        let invalid_include = Objects {
            include: ObjectInclude {
                schemas: vec!["***".to_string()],
                tables: vec![],
            },
            exclude: ObjectExclude {
                schemas: vec![],
                tables: vec![],
            },
        };
        let filter = ObjectFilter::new(&invalid_include, &create_test_tracking_table());
        // A configured include list with only invalid patterns admits nothing.
        assert!(!filter.should_include_schema("public"));

        let invalid_exclude = Objects {
            include: ObjectInclude {
                schemas: vec![],
                tables: vec![],
            },
            exclude: ObjectExclude {
                schemas: vec!["***".to_string()],
                tables: vec!["***".to_string()],
            },
        };
        let filter = ObjectFilter::new(&invalid_exclude, &create_test_tracking_table());
        // An invalid exclude pattern excludes nothing.
        assert!(filter.should_include_schema("public"));
        assert!(filter.should_include_table("public", "users"));
    }

    #[test]
    fn test_migration_table_handling() {
        let tracking_table = TrackingTable {
            schema: "internal".to_string(),
            name: "migration_history".to_string(),
        };
        let objects = Objects {
            include: ObjectInclude {
                schemas: vec!["public".to_string()],
                tables: vec!["users".to_string()],
            },
            exclude: ObjectExclude {
                schemas: vec![],
                tables: vec![],
            },
        };
        let filter = ObjectFilter::new(&objects, &tracking_table);
        assert!(!filter.should_include_table("internal", "migration_history"));
        assert!(filter.is_pgmt_internal_table("internal", "migration_history"));
        assert!(!filter.should_include_table("internal", "migration_history_sections"));
        assert!(filter.is_pgmt_internal_table("internal", "migration_history_sections"));
        assert!(!filter.should_include_table("internal", "other_table"));
        assert!(filter.should_include_table("public", "users"));
        assert!(!filter.should_include_table("public", "posts"));
    }

    #[test]
    fn test_grant_filtering() {
        use crate::catalog::Catalog;
        use crate::catalog::grant::{Grant, GranteeType};
        use crate::catalog::target::AttrTarget;

        let objects = Objects {
            include: ObjectInclude {
                schemas: vec![],
                tables: vec![],
            },
            exclude: ObjectExclude {
                schemas: vec!["excluded_schema".to_string()],
                tables: vec!["excluded_table".to_string()],
            },
        };
        let filter = ObjectFilter::new(&objects, &create_test_tracking_table());
        let make_grant = |target: AttrTarget| Grant {
            grantee: GranteeType::Public,
            target,
            privileges: vec!["EXECUTE".to_string()],
            with_grant_option: false,
            depends_on: vec![],
            object_owner: "postgres".to_string(),
            is_default_acl: false,
        };
        let mut catalog = Catalog::empty();
        catalog.grants = vec![
            make_grant(AttrTarget::object(DbObjectId::Function {
                schema: "public".into(),
                name: "my_func".into(),
                arguments: "".into(),
            })),
            make_grant(AttrTarget::object(DbObjectId::Function {
                schema: "excluded_schema".into(),
                name: "notify_watchers".into(),
                arguments: "".into(),
            })),
            make_grant(AttrTarget::object(DbObjectId::Table {
                schema: "public".into(),
                name: "excluded_table".into(),
            })),
            make_grant(AttrTarget::object(DbObjectId::Table {
                schema: "public".into(),
                name: "users".into(),
            })),
            make_grant(AttrTarget::object(DbObjectId::Schema {
                name: "excluded_schema".into(),
            })),
            make_grant(AttrTarget::object(DbObjectId::Schema {
                name: "public".into(),
            })),
        ];

        let filtered = filter.filter_catalog(catalog);
        assert_eq!(filtered.grants.len(), 3);
        let remaining_ids: Vec<String> = filtered.grants.iter().map(|g| g.id()).collect();
        assert!(
            remaining_ids
                .iter()
                .any(|id| id.contains("function:public.my_func"))
        );
        assert!(
            remaining_ids
                .iter()
                .any(|id| id.contains("table:public.users"))
        );
        assert!(remaining_ids.iter().any(|id| id.contains("schema:public")));
        assert!(
            !remaining_ids
                .iter()
                .any(|id| id.contains("excluded_schema"))
        );
        assert!(!remaining_ids.iter().any(|id| id.contains("excluded_table")));
    }
}
#[cfg(test)]
mod substrate_filter_tests {
    use super::*;
    use crate::catalog::Catalog;
    use crate::catalog::extension::Extension;
    use crate::catalog::id::DbObjectId;
    use crate::catalog::schema::Schema;

    /// Id of the `topology` schema used throughout this module.
    fn topology_schema_id() -> DbObjectId {
        DbObjectId::Schema {
            name: "topology".to_string(),
        }
    }

    /// Id of the extension named `name`.
    fn extension_id(name: &str) -> DbObjectId {
        DbObjectId::Extension {
            name: name.to_string(),
        }
    }

    /// A non-relocatable, uncommented version-3.5 extension.
    fn postgis_extension(name: &str, schema: &str, depends_on: Vec<DbObjectId>) -> Extension {
        Extension {
            name: name.to_string(),
            schema: schema.to_string(),
            version: "3.5".to_string(),
            relocatable: false,
            comment: None,
            depends_on,
        }
    }

    /// Catalog with `public` and `topology` schemas, `postgis` in `public`,
    /// and `postgis_topology` in `topology`. Forward deps record that
    /// `postgis_topology` depends on the `topology` schema, and reverse deps
    /// are derived from them.
    fn postgis_catalog() -> Catalog {
        let mut catalog = Catalog::empty();
        catalog.schemas = vec![
            Schema {
                name: "public".to_string(),
                comment: None,
            },
            Schema {
                name: "topology".to_string(),
                comment: Some("PostGIS Topology schema".to_string()),
            },
        ];
        catalog.extensions = vec![
            postgis_extension("postgis", "public", vec![]),
            postgis_extension("postgis_topology", "topology", vec![topology_schema_id()]),
        ];
        catalog
            .forward_deps
            .insert(extension_id("postgis_topology"), vec![topology_schema_id()]);
        catalog
            .forward_deps
            .insert(extension_id("postgis"), vec![]);
        catalog.rebuild_reverse_deps();
        catalog
    }

    #[test]
    fn test_excluded_schema_drops_extension_and_dependency_entries() {
        let exclude_topology = Objects {
            include: ObjectInclude::default(),
            exclude: ObjectExclude {
                schemas: vec!["topology".to_string()],
                tables: vec![],
            },
        };
        let filtered = ObjectFilter::new(&exclude_topology, &TrackingTable::default())
            .filter_catalog(postgis_catalog());

        let schema_names = filtered.schemas.iter().map(|s| &s.name).collect::<Vec<_>>();
        assert_eq!(
            schema_names,
            vec!["public"],
            "excluded schema must be dropped"
        );

        let extension_names = filtered
            .extensions
            .iter()
            .map(|e| &e.name)
            .collect::<Vec<_>>();
        assert_eq!(
            extension_names,
            vec!["postgis"],
            "extensions installed in excluded schemas must be dropped"
        );

        assert!(
            !filtered
                .forward_deps
                .contains_key(&extension_id("postgis_topology")),
            "dependency maps must not retain filtered objects"
        );
        assert!(
            !filtered.reverse_deps.contains_key(&topology_schema_id()),
            "reverse deps must be rebuilt after filtering"
        );
    }
}