use crate::{DomainInfo, PredicateInfo, SymbolTable};
use anyhow::{bail, Result};
use std::collections::HashSet;
/// How [`SchemaMerger`] resolves a name that is defined in both the base and
/// the incoming [`SymbolTable`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergeStrategy {
    /// Resolve in favor of the base table's definition.
    KeepFirst,
    /// Resolve in favor of the incoming table's definition.
    KeepSecond,
    /// Abort the merge with an error rather than resolve a conflict.
    FailOnConflict,
    /// Domains: keep the definition with the larger cardinality (the base one
    /// on a tie). Predicates and variables: the two definitions must have the
    /// same argument domains / the same domain, otherwise the merge fails.
    Union,
    /// Domains: keep the definition with the smaller cardinality (the base one
    /// on a tie). Predicates and variables: the two definitions must have the
    /// same argument domains / the same domain, otherwise the merge fails.
    Intersection,
}
/// Output of [`SchemaMerger::merge`]: the combined table plus a report of how
/// each name was handled.
#[derive(Debug, Clone)]
pub struct MergeResult {
    /// The newly built table containing entries from both inputs.
    pub merged: SymbolTable,
    /// Which names came from where and which ones conflicted.
    pub report: MergeReport,
}
/// Record of what a [`SchemaMerger::merge`] call did.
///
/// Names defined in both tables whose resolution is
/// [`MergeConflictResolution::Merged`] are not listed in any field.
#[derive(Debug, Clone)]
pub struct MergeReport {
    /// Names of domains that existed only in the base table.
    pub base_domains: Vec<String>,
    /// Names of domains that existed only in the incoming table.
    pub incoming_domains: Vec<String>,
    /// Domains defined in both tables whose resolution was not `Merged`.
    pub conflicting_domains: Vec<DomainConflict>,
    /// Names of predicates that existed only in the base table.
    pub base_predicates: Vec<String>,
    /// Names of predicates that existed only in the incoming table.
    pub incoming_predicates: Vec<String>,
    /// Predicates defined in both tables whose resolution was not `Merged`.
    pub conflicting_predicates: Vec<PredicateConflict>,
    /// Variables that were bound in only one of the two tables (either side).
    pub merged_variables: Vec<String>,
    /// Variables bound in both tables whose resolution was not `Merged`.
    pub conflicting_variables: Vec<VariableConflict>,
    /// Strategy the merger was configured with.
    pub strategy: MergeStrategy,
}
impl MergeReport {
    /// Builds an empty report tagged with the `strategy` used for the merge.
    pub fn new(strategy: MergeStrategy) -> Self {
        Self {
            strategy,
            base_domains: Vec::new(),
            incoming_domains: Vec::new(),
            conflicting_domains: Vec::new(),
            base_predicates: Vec::new(),
            incoming_predicates: Vec::new(),
            conflicting_predicates: Vec::new(),
            merged_variables: Vec::new(),
            conflicting_variables: Vec::new(),
        }
    }

    /// Returns `true` when at least one domain, predicate, or variable
    /// conflict has been recorded.
    pub fn has_conflicts(&self) -> bool {
        let nothing_recorded = self.conflicting_domains.is_empty()
            && self.conflicting_predicates.is_empty()
            && self.conflicting_variables.is_empty();
        !nothing_recorded
    }

    /// Total number of recorded conflicts across domains, predicates, and
    /// variables.
    pub fn conflict_count(&self) -> usize {
        let per_kind = [
            self.conflicting_domains.len(),
            self.conflicting_predicates.len(),
            self.conflicting_variables.len(),
        ];
        per_kind.iter().sum()
    }

    /// Sum of the base-only and incoming-only domain and predicate lists plus
    /// `merged_variables`. Conflicts and names resolved as `Merged` are not
    /// counted.
    pub fn merged_count(&self) -> usize {
        let per_list = [
            self.base_domains.len(),
            self.incoming_domains.len(),
            self.base_predicates.len(),
            self.incoming_predicates.len(),
            self.merged_variables.len(),
        ];
        per_list.iter().sum()
    }
}
/// A domain name defined in both tables, with both original definitions and
/// how the merger resolved it.
#[derive(Debug, Clone)]
pub struct DomainConflict {
    /// The shared domain name.
    pub name: String,
    /// Definition from the base table.
    pub base: DomainInfo,
    /// Definition from the incoming table.
    pub incoming: DomainInfo,
    /// Which definition ended up in the merged table.
    pub resolution: MergeConflictResolution,
}
/// A predicate name defined in both tables, with both original definitions and
/// how the merger resolved it.
#[derive(Debug, Clone)]
pub struct PredicateConflict {
    /// The shared predicate name.
    pub name: String,
    /// Definition from the base table.
    pub base: PredicateInfo,
    /// Definition from the incoming table.
    pub incoming: PredicateInfo,
    /// Which definition ended up in the merged table.
    pub resolution: MergeConflictResolution,
}
/// A variable bound in both tables, with the domain name each table bound it
/// to and how the merger resolved it.
#[derive(Debug, Clone)]
pub struct VariableConflict {
    /// The shared variable name.
    pub name: String,
    /// Domain name the base table bound the variable to.
    pub base_domain: String,
    /// Domain name the incoming table bound the variable to.
    pub incoming_domain: String,
    /// Which binding ended up in the merged table.
    pub resolution: MergeConflictResolution,
}
/// Outcome of resolving a name defined in both merged tables.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergeConflictResolution {
    /// The base table's definition was kept.
    KeptBase,
    /// The incoming table's definition was kept.
    KeptIncoming,
    /// Not produced by [`SchemaMerger`], which returns an error from `merge`
    /// instead. NOTE(review): presumably reserved for other callers — confirm.
    Failed,
    /// The definitions were compatible and a single one was kept; entries with
    /// this resolution are not recorded in the report's conflict lists.
    Merged,
}
/// Combines two [`SymbolTable`]s into a new one, resolving names defined in
/// both according to a fixed [`MergeStrategy`].
pub struct SchemaMerger {
    /// Strategy applied to every domain, predicate, and variable defined in both tables.
    strategy: MergeStrategy,
}
impl SchemaMerger {
    /// Creates a merger that resolves names defined in both tables with `strategy`.
    pub fn new(strategy: MergeStrategy) -> Self {
        Self { strategy }
    }

    /// Merges `base` and `incoming` into a new [`SymbolTable`] and reports how
    /// every name was handled.
    ///
    /// Domains are merged first, then predicates, then variables (presumably so
    /// that the domains referenced by predicates and variables already exist in
    /// the merged table when those are added). A name present in only one table
    /// is copied unchanged. A name present in both is resolved with the
    /// configured [`MergeStrategy`]. The exception is two equivalent definitions
    /// (equal domain cardinality, equal predicate argument domains, equal
    /// variable domain). These are never a conflict under any strategy: one copy
    /// is kept and the name is not listed as a conflict.
    ///
    /// Within each group, names are visited in sorted order, so the merged table
    /// is built and the report lists are filled deterministically.
    ///
    /// # Errors
    ///
    /// Fails when the strategy cannot resolve a real conflict. That is any
    /// differing definition under [`MergeStrategy::FailOnConflict`], or
    /// differing predicate signatures or variable domains under
    /// [`MergeStrategy::Union`] / [`MergeStrategy::Intersection`]. Also fails
    /// when the merged table rejects a domain, predicate, or variable binding.
    pub fn merge(&self, base: &SymbolTable, incoming: &SymbolTable) -> Result<MergeResult> {
        let mut merged = SymbolTable::new();
        let mut report = MergeReport::new(self.strategy);
        self.merge_domains(base, incoming, &mut merged, &mut report)?;
        self.merge_predicates(base, incoming, &mut merged, &mut report)?;
        self.merge_variables(base, incoming, &mut merged, &mut report)?;
        Ok(MergeResult { merged, report })
    }

    /// Splits the names of two tables into `(only in base, only in incoming, in both)`.
    ///
    /// Each list is sorted. Iterating the `HashSet`s directly would make the
    /// merge order, and with it the order of every list in the
    /// [`MergeReport`], change from run to run.
    fn partition_keys<'a>(
        base: impl Iterator<Item = &'a String>,
        incoming: impl Iterator<Item = &'a String>,
    ) -> (Vec<&'a String>, Vec<&'a String>, Vec<&'a String>) {
        let base: HashSet<&'a String> = base.collect();
        let incoming: HashSet<&'a String> = incoming.collect();
        let sorted = |mut names: Vec<&'a String>| {
            // Names are unique within a set, so an unstable sort is enough.
            names.sort_unstable();
            names
        };
        (
            sorted(base.difference(&incoming).copied().collect()),
            sorted(incoming.difference(&base).copied().collect()),
            sorted(base.intersection(&incoming).copied().collect()),
        )
    }

    /// Returns the copy to keep for two equivalent definitions: the incoming one
    /// under `KeepSecond`, the base one under every other strategy.
    fn preferred<'t, T>(&self, base: &'t T, incoming: &'t T) -> &'t T {
        match self.strategy {
            MergeStrategy::KeepSecond => incoming,
            MergeStrategy::KeepFirst
            | MergeStrategy::FailOnConflict
            | MergeStrategy::Union
            | MergeStrategy::Intersection => base,
        }
    }

    /// Copies every domain of `base` and `incoming` into `merged`, resolving
    /// names defined in both and recording the outcome in `report`.
    fn merge_domains(
        &self,
        base: &SymbolTable,
        incoming: &SymbolTable,
        merged: &mut SymbolTable,
        report: &mut MergeReport,
    ) -> Result<()> {
        let (base_only, incoming_only, shared) =
            Self::partition_keys(base.domains.keys(), incoming.domains.keys());
        for name in base_only {
            let domain = base
                .domains
                .get(name)
                .expect("key from HashMap iteration is always present");
            merged.add_domain(domain.clone())?;
            report.base_domains.push(name.clone());
        }
        for name in incoming_only {
            let domain = incoming
                .domains
                .get(name)
                .expect("key from HashMap iteration is always present");
            merged.add_domain(domain.clone())?;
            report.incoming_domains.push(name.clone());
        }
        for name in shared {
            let base_domain = base
                .domains
                .get(name)
                .expect("key from HashMap iteration is always present");
            let incoming_domain = incoming
                .domains
                .get(name)
                .expect("key from HashMap iteration is always present");
            let (domain, resolution) =
                self.resolve_domain_conflict(base_domain, incoming_domain)?;
            merged.add_domain(domain)?;
            // Equivalent definitions are not conflicts and are not listed.
            if resolution != MergeConflictResolution::Merged {
                report.conflicting_domains.push(DomainConflict {
                    name: name.clone(),
                    base: base_domain.clone(),
                    incoming: incoming_domain.clone(),
                    resolution,
                });
            }
        }
        Ok(())
    }

    /// Copies every predicate of `base` and `incoming` into `merged`, resolving
    /// names defined in both and recording the outcome in `report`.
    fn merge_predicates(
        &self,
        base: &SymbolTable,
        incoming: &SymbolTable,
        merged: &mut SymbolTable,
        report: &mut MergeReport,
    ) -> Result<()> {
        let (base_only, incoming_only, shared) =
            Self::partition_keys(base.predicates.keys(), incoming.predicates.keys());
        for name in base_only {
            let predicate = base
                .predicates
                .get(name)
                .expect("key from HashMap iteration is always present");
            merged.add_predicate(predicate.clone())?;
            report.base_predicates.push(name.clone());
        }
        for name in incoming_only {
            let predicate = incoming
                .predicates
                .get(name)
                .expect("key from HashMap iteration is always present");
            merged.add_predicate(predicate.clone())?;
            report.incoming_predicates.push(name.clone());
        }
        for name in shared {
            let base_pred = base
                .predicates
                .get(name)
                .expect("key from HashMap iteration is always present");
            let incoming_pred = incoming
                .predicates
                .get(name)
                .expect("key from HashMap iteration is always present");
            let (predicate, resolution) =
                self.resolve_predicate_conflict(base_pred, incoming_pred)?;
            merged.add_predicate(predicate)?;
            // Equivalent definitions are not conflicts and are not listed.
            if resolution != MergeConflictResolution::Merged {
                report.conflicting_predicates.push(PredicateConflict {
                    name: name.clone(),
                    base: base_pred.clone(),
                    incoming: incoming_pred.clone(),
                    resolution,
                });
            }
        }
        Ok(())
    }

    /// Copies every variable binding of `base` and `incoming` into `merged`,
    /// resolving variables bound in both and recording the outcome in `report`.
    fn merge_variables(
        &self,
        base: &SymbolTable,
        incoming: &SymbolTable,
        merged: &mut SymbolTable,
        report: &mut MergeReport,
    ) -> Result<()> {
        let (base_only, incoming_only, shared) =
            Self::partition_keys(base.variables.keys(), incoming.variables.keys());
        for name in base_only {
            let domain = base
                .variables
                .get(name)
                .expect("key from HashMap iteration is always present");
            merged.bind_variable(name.clone(), domain.clone())?;
            report.merged_variables.push(name.clone());
        }
        for name in incoming_only {
            let domain = incoming
                .variables
                .get(name)
                .expect("key from HashMap iteration is always present");
            merged.bind_variable(name.clone(), domain.clone())?;
            report.merged_variables.push(name.clone());
        }
        for name in shared {
            let base_domain = base
                .variables
                .get(name)
                .expect("key from HashMap iteration is always present");
            let incoming_domain = incoming
                .variables
                .get(name)
                .expect("key from HashMap iteration is always present");
            let (domain, resolution) =
                self.resolve_variable_conflict(name, base_domain, incoming_domain)?;
            merged.bind_variable(name.clone(), domain)?;
            // Identical bindings are not conflicts and are not listed.
            if resolution != MergeConflictResolution::Merged {
                report.conflicting_variables.push(VariableConflict {
                    name: name.clone(),
                    base_domain: base_domain.clone(),
                    incoming_domain: incoming_domain.clone(),
                    resolution,
                });
            }
        }
        Ok(())
    }

    /// Picks the domain to keep when both tables define a domain with the same name.
    ///
    /// Two definitions with the same cardinality are treated as equivalent. They
    /// are returned as `Merged` under every strategy: the incoming copy under
    /// `KeepSecond`, the base copy otherwise. Cardinality is the only attribute
    /// compared, matching how `Union`/`Intersection` choose between definitions.
    /// NOTE(review): other `DomainInfo` attributes, if any, are not compared.
    ///
    /// # Errors
    ///
    /// Under `FailOnConflict`, errors when the cardinalities differ.
    fn resolve_domain_conflict(
        &self,
        base: &DomainInfo,
        incoming: &DomainInfo,
    ) -> Result<(DomainInfo, MergeConflictResolution)> {
        if base.cardinality == incoming.cardinality {
            return Ok((
                self.preferred(base, incoming).clone(),
                MergeConflictResolution::Merged,
            ));
        }
        match self.strategy {
            MergeStrategy::KeepFirst => Ok((base.clone(), MergeConflictResolution::KeptBase)),
            MergeStrategy::KeepSecond => {
                Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
            }
            MergeStrategy::FailOnConflict => {
                bail!(
                    "Domain conflict for '{}': cardinality {} vs {}",
                    base.name,
                    base.cardinality,
                    incoming.cardinality
                )
            }
            // The cardinalities differ here, so exactly one side is larger.
            MergeStrategy::Union => {
                if base.cardinality > incoming.cardinality {
                    Ok((base.clone(), MergeConflictResolution::KeptBase))
                } else {
                    Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
                }
            }
            MergeStrategy::Intersection => {
                if base.cardinality < incoming.cardinality {
                    Ok((base.clone(), MergeConflictResolution::KeptBase))
                } else {
                    Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
                }
            }
        }
    }

    /// Picks the predicate to keep when both tables define a predicate with the
    /// same name.
    ///
    /// Predicates with identical `arg_domains` are equivalent. They are returned
    /// as `Merged` under every strategy: the incoming copy under `KeepSecond`,
    /// the base copy otherwise.
    ///
    /// # Errors
    ///
    /// Errors on differing signatures under `FailOnConflict`, `Union`, and
    /// `Intersection`.
    fn resolve_predicate_conflict(
        &self,
        base: &PredicateInfo,
        incoming: &PredicateInfo,
    ) -> Result<(PredicateInfo, MergeConflictResolution)> {
        if base.arg_domains == incoming.arg_domains {
            return Ok((
                self.preferred(base, incoming).clone(),
                MergeConflictResolution::Merged,
            ));
        }
        match self.strategy {
            MergeStrategy::KeepFirst => Ok((base.clone(), MergeConflictResolution::KeptBase)),
            MergeStrategy::KeepSecond => {
                Ok((incoming.clone(), MergeConflictResolution::KeptIncoming))
            }
            MergeStrategy::FailOnConflict => {
                bail!(
                    "Predicate conflict for '{}': {:?} vs {:?}",
                    base.name,
                    base.arg_domains,
                    incoming.arg_domains
                )
            }
            MergeStrategy::Union | MergeStrategy::Intersection => {
                bail!(
                    "Incompatible predicate signatures for '{}': {:?} vs {:?}",
                    base.name,
                    base.arg_domains,
                    incoming.arg_domains
                )
            }
        }
    }

    /// Picks the domain to bind variable `name` to when both tables bind it.
    ///
    /// Identical bindings are returned as `Merged` under every strategy.
    ///
    /// # Errors
    ///
    /// Errors on differing domains under `FailOnConflict`, `Union`, and
    /// `Intersection`. The message names the variable.
    fn resolve_variable_conflict(
        &self,
        name: &str,
        base_domain: &str,
        incoming_domain: &str,
    ) -> Result<(String, MergeConflictResolution)> {
        if base_domain == incoming_domain {
            return Ok((base_domain.to_string(), MergeConflictResolution::Merged));
        }
        match self.strategy {
            MergeStrategy::KeepFirst => {
                Ok((base_domain.to_string(), MergeConflictResolution::KeptBase))
            }
            MergeStrategy::KeepSecond => Ok((
                incoming_domain.to_string(),
                MergeConflictResolution::KeptIncoming,
            )),
            MergeStrategy::FailOnConflict => {
                bail!(
                    "Variable domain conflict for '{}': '{}' vs '{}'",
                    name,
                    base_domain,
                    incoming_domain
                )
            }
            MergeStrategy::Union | MergeStrategy::Intersection => {
                bail!(
                    "Incompatible variable domains for '{}': '{}' vs '{}'",
                    name,
                    base_domain,
                    incoming_domain
                )
            }
        }
    }
}
impl Default for SchemaMerger {
fn default() -> Self {
Self::new(MergeStrategy::Union)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Base fixture: domain `Person` (100), predicate `knows(Person, Person)`,
    /// variable `x: Person`.
    fn create_base_table() -> SymbolTable {
        let mut table = SymbolTable::new();
        table
            .add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        table
            .add_predicate(PredicateInfo::new(
                "knows",
                vec!["Person".to_string(), "Person".to_string()],
            ))
            .expect("unwrap");
        table.bind_variable("x", "Person").expect("unwrap");
        table
    }

    /// Incoming fixture: `Person` with a different cardinality (150), a new
    /// `Organization` domain, and a new predicate `age(Person)`; no variables.
    fn create_incoming_table() -> SymbolTable {
        let mut table = SymbolTable::new();
        table
            .add_domain(DomainInfo::new("Person", 150))
            .expect("unwrap");
        table
            .add_domain(DomainInfo::new("Organization", 50))
            .expect("unwrap");
        table
            .add_predicate(PredicateInfo::new("age", vec!["Person".to_string()]))
            .expect("unwrap");
        table
    }

    // `Person` differs in cardinality between the fixtures, so Union reports
    // that one conflict while still merging everything else.
    #[test]
    fn test_merge_union_resolves_domain_conflict() {
        let base = create_base_table();
        let incoming = create_incoming_table();
        let merger = SchemaMerger::new(MergeStrategy::Union);
        let result = merger.merge(&base, &incoming).expect("unwrap");
        assert_eq!(result.merged.domains.len(), 2);
        assert_eq!(result.merged.predicates.len(), 2);
        assert!(result.report.has_conflicts());
        // Union keeps the larger cardinality.
        assert_eq!(
            result
                .merged
                .domains
                .get("Person")
                .expect("unwrap")
                .cardinality,
            150
        );
    }

    #[test]
    fn test_merge_with_domain_conflict() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 200))
            .expect("unwrap");
        let merger = SchemaMerger::new(MergeStrategy::KeepFirst);
        let result = merger.merge(&base, &incoming).expect("unwrap");
        assert_eq!(result.merged.domains.len(), 1);
        assert_eq!(
            result
                .merged
                .domains
                .get("Person")
                .expect("unwrap")
                .cardinality,
            100
        );
        assert!(result.report.has_conflicts());
    }

    #[test]
    fn test_merge_keep_second() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 200))
            .expect("unwrap");
        let merger = SchemaMerger::new(MergeStrategy::KeepSecond);
        let result = merger.merge(&base, &incoming).expect("unwrap");
        assert_eq!(
            result
                .merged
                .domains
                .get("Person")
                .expect("unwrap")
                .cardinality,
            200
        );
    }

    #[test]
    fn test_merge_fail_on_conflict() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 200))
            .expect("unwrap");
        let merger = SchemaMerger::new(MergeStrategy::FailOnConflict);
        let result = merger.merge(&base, &incoming);
        assert!(result.is_err());
    }

    #[test]
    fn test_merge_intersection_keeps_smaller_domain() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 200))
            .expect("unwrap");
        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        let merger = SchemaMerger::new(MergeStrategy::Intersection);
        let result = merger.merge(&base, &incoming).expect("unwrap");
        assert_eq!(
            result
                .merged
                .domains
                .get("Person")
                .expect("unwrap")
                .cardinality,
            100
        );
        assert_eq!(result.report.conflicting_domains.len(), 1);
        assert_eq!(
            result.report.conflicting_domains[0].resolution,
            MergeConflictResolution::KeptIncoming
        );
    }

    #[test]
    fn test_merge_report() {
        let base = create_base_table();
        let incoming = create_incoming_table();
        let merger = SchemaMerger::new(MergeStrategy::Union);
        let result = merger.merge(&base, &incoming).expect("unwrap");
        let report = &result.report;
        assert_eq!(report.base_domains.len(), 0);
        assert_eq!(report.incoming_domains.len(), 1);
        // Organization + knows + age + x.
        assert_eq!(report.merged_count(), 4);
        // Only `Person` differs between the two tables.
        assert_eq!(report.conflict_count(), 1);
    }

    #[test]
    fn test_predicate_conflict_compatible() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        base.add_predicate(PredicateInfo::new("age", vec!["Person".to_string()]))
            .expect("unwrap");
        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        incoming
            .add_predicate(PredicateInfo::new("age", vec!["Person".to_string()]))
            .expect("unwrap");
        let merger = SchemaMerger::new(MergeStrategy::Union);
        let result = merger.merge(&base, &incoming).expect("unwrap");
        assert_eq!(result.merged.predicates.len(), 1);
        assert_eq!(result.report.conflicting_predicates.len(), 0);
    }

    #[test]
    fn test_merge_union_rejects_incompatible_predicates() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        base.add_domain(DomainInfo::new("Agent", 50))
            .expect("unwrap");
        base.add_predicate(PredicateInfo::new("acts", vec!["Person".to_string()]))
            .expect("unwrap");
        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        incoming
            .add_domain(DomainInfo::new("Agent", 50))
            .expect("unwrap");
        incoming
            .add_predicate(PredicateInfo::new("acts", vec!["Agent".to_string()]))
            .expect("unwrap");
        let merger = SchemaMerger::new(MergeStrategy::Union);
        assert!(merger.merge(&base, &incoming).is_err());
    }

    #[test]
    fn test_variable_conflict() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        base.add_domain(DomainInfo::new("Agent", 50))
            .expect("unwrap");
        base.bind_variable("x", "Person").expect("unwrap");
        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        incoming
            .add_domain(DomainInfo::new("Agent", 50))
            .expect("unwrap");
        incoming.bind_variable("x", "Agent").expect("unwrap");
        let merger = SchemaMerger::new(MergeStrategy::KeepFirst);
        let result = merger.merge(&base, &incoming).expect("unwrap");
        assert_eq!(result.merged.variables.get("x").expect("unwrap"), "Person");
        assert_eq!(result.report.conflicting_variables.len(), 1);
    }

    #[test]
    fn test_variable_conflict_keep_second() {
        let mut base = SymbolTable::new();
        base.add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        base.add_domain(DomainInfo::new("Agent", 50))
            .expect("unwrap");
        base.bind_variable("x", "Person").expect("unwrap");
        let mut incoming = SymbolTable::new();
        incoming
            .add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        incoming
            .add_domain(DomainInfo::new("Agent", 50))
            .expect("unwrap");
        incoming.bind_variable("x", "Agent").expect("unwrap");
        let merger = SchemaMerger::new(MergeStrategy::KeepSecond);
        let result = merger.merge(&base, &incoming).expect("unwrap");
        assert_eq!(result.merged.variables.get("x").expect("unwrap"), "Agent");
        assert_eq!(result.report.conflicting_variables.len(), 1);
    }
}