use std::collections::HashMap;
use std::sync::Arc;
use crate::eq::{Term, alpha_equivalent_equation};
use crate::error::GatError;
use crate::ident::{NameSite, SiteRename};
use crate::theory::Theory;
/// A named mapping from the sorts and operations of one theory to those of
/// another.
///
/// The struct only stores names; it holds no reference to the theories
/// themselves. Whether the mapping is actually well-formed with respect to a
/// concrete pair of theories is decided by [`check_morphism`].
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TheoryMorphism {
/// Name identifying this morphism.
pub name: Arc<str>,
/// Name of the theory the morphism maps from.
pub domain: Arc<str>,
/// Name of the theory the morphism maps into.
pub codomain: Arc<str>,
/// Sort renaming table: domain sort name -> codomain sort name.
pub sort_map: HashMap<Arc<str>, Arc<str>>,
/// Operation renaming table: domain operation name -> codomain operation name.
pub op_map: HashMap<Arc<str>, Arc<str>>,
}
impl TheoryMorphism {
    /// Builds a morphism from its name, the names of its endpoint theories,
    /// and the sort/operation renaming tables.
    ///
    /// The arguments are stored as given; no validation is performed here.
    /// Use [`check_morphism`] to validate against concrete theories.
    #[must_use]
    pub fn new(
        name: impl Into<Arc<str>>,
        domain: impl Into<Arc<str>>,
        codomain: impl Into<Arc<str>>,
        sort_map: HashMap<Arc<str>, Arc<str>>,
        op_map: HashMap<Arc<str>, Arc<str>>,
    ) -> Self {
        Self {
            name: name.into(),
            domain: domain.into(),
            codomain: codomain.into(),
            sort_map,
            op_map,
        }
    }

    /// Returns `term` with its operations renamed according to `op_map`.
    ///
    /// This delegates to `Term::rename_ops`; the sort map is not consulted.
    #[must_use]
    pub fn apply_to_term(&self, term: &Term) -> Term {
        term.rename_ops(&self.op_map)
    }

    /// Derives the schema-level renames implied by this morphism.
    ///
    /// Every sort-map entry whose target differs from its source becomes a
    /// [`NameSite::VertexKind`] rename, and every such op-map entry becomes a
    /// [`NameSite::EdgeKind`] rename. Identity entries are skipped.
    ///
    /// Sort renames come first, then operation renames. Within each group the
    /// entries are ordered by old name, so the result is deterministic rather
    /// than following the arbitrary iteration order of the underlying
    /// `HashMap`s.
    #[must_use]
    pub fn induce_schema_renames(&self) -> Vec<SiteRename> {
        /// Collects the non-identity entries of `map`, ordered by key.
        fn changed_entries(map: &HashMap<Arc<str>, Arc<str>>) -> Vec<(&Arc<str>, &Arc<str>)> {
            let mut entries: Vec<_> = map.iter().filter(|(old, new)| old != new).collect();
            // Keys are unique, so ordering the pairs orders by old name.
            entries.sort_unstable();
            entries
        }

        let sort_renames = changed_entries(&self.sort_map)
            .into_iter()
            .map(|(old, new)| {
                SiteRename::new(NameSite::VertexKind, Arc::clone(old), Arc::clone(new))
            });
        let op_renames = changed_entries(&self.op_map)
            .into_iter()
            .map(|(old, new)| {
                SiteRename::new(NameSite::EdgeKind, Arc::clone(old), Arc::clone(new))
            });

        sort_renames.chain(op_renames).collect()
    }
}
pub fn check_morphism(
m: &TheoryMorphism,
domain: &Theory,
codomain: &Theory,
) -> Result<(), GatError> {
for sort in &domain.sorts {
let target_name = m
.sort_map
.get(&sort.name)
.ok_or_else(|| GatError::MissingSortMapping(sort.name.to_string()))?;
let target_sort = codomain
.find_sort(target_name)
.ok_or_else(|| GatError::SortNotFound(target_name.to_string()))?;
if sort.arity() != target_sort.arity() {
return Err(GatError::SortArityMismatch {
sort: sort.name.to_string(),
expected: sort.arity(),
got: target_sort.arity(),
});
}
}
for op in &domain.ops {
let target_name = m
.op_map
.get(&op.name)
.ok_or_else(|| GatError::MissingOpMapping(op.name.to_string()))?;
let target_op = codomain
.find_op(target_name)
.ok_or_else(|| GatError::OpNotFound(target_name.to_string()))?;
if op.inputs.len() != target_op.inputs.len() {
return Err(GatError::OpTypeMismatch {
op: op.name.to_string(),
detail: format!(
"arity mismatch: domain has {} inputs, codomain has {}",
op.inputs.len(),
target_op.inputs.len()
),
});
}
for (i, (_, sort_name)) in op.inputs.iter().enumerate() {
let mapped_sort = m
.sort_map
.get(sort_name)
.ok_or_else(|| GatError::MissingSortMapping(sort_name.to_string()))?;
let (_, target_sort) = &target_op.inputs[i];
if mapped_sort != target_sort {
return Err(GatError::OpTypeMismatch {
op: op.name.to_string(),
detail: format!("input {i}: expected sort {mapped_sort}, got {target_sort}"),
});
}
}
let mapped_output = m
.sort_map
.get(&op.output)
.ok_or_else(|| GatError::MissingSortMapping(op.output.to_string()))?;
if mapped_output != &target_op.output {
return Err(GatError::OpTypeMismatch {
op: op.name.to_string(),
detail: format!(
"output: expected sort {mapped_output}, got {}",
target_op.output
),
});
}
}
for eq in &domain.eqs {
let mapped_lhs = m.apply_to_term(&eq.lhs);
let mapped_rhs = m.apply_to_term(&eq.rhs);
let preserved = codomain
.eqs
.iter()
.any(|ceq| alpha_equivalent_equation(&ceq.lhs, &ceq.rhs, &mapped_lhs, &mapped_rhs));
if !preserved {
return Err(GatError::EquationNotPreserved {
equation: eq.name.to_string(),
detail: "mapped equation not found in codomain".to_owned(),
});
}
}
for de in &domain.directed_eqs {
let mapped_lhs = m.apply_to_term(&de.lhs);
let mapped_rhs = m.apply_to_term(&de.rhs);
let preserved = codomain
.directed_eqs
.iter()
.any(|cde| alpha_equivalent_equation(&cde.lhs, &cde.rhs, &mapped_lhs, &mapped_rhs));
if !preserved {
return Err(GatError::DirectedEquationNotPreserved {
equation: de.name.to_string(),
detail: "mapped directed equation not found in codomain".to_owned(),
});
}
}
Ok(())
}
#[cfg(test)]
// Tests unwrap freely: a panic is the intended failure mode here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::eq::{DirectedEquation, Equation, Term};
    use crate::error::GatError;
    use crate::model::{Model, ModelValue, migrate_model};
    use crate::op::Operation;
    use crate::sort::{Sort, SortParam};
    use crate::theory::Theory;

    /// Builds a renaming table from `(old, new)` string pairs.
    fn name_map(pairs: &[(&str, &str)]) -> HashMap<Arc<str>, Arc<str>> {
        pairs
            .iter()
            .map(|&(from, to)| (Arc::from(from), Arc::from(to)))
            .collect()
    }

    /// Shared builder for the monoid fixtures: one sort `Carrier`, a binary
    /// operation, a nullary unit, and the associativity / identity laws.
    /// When `commutative` is set, a `comm` equation is appended.
    fn build_monoid(name: &str, mul_name: &str, unit_name: &str, commutative: bool) -> Theory {
        let carrier = Sort::simple("Carrier");
        let mul = Operation::new(
            mul_name,
            vec![
                ("a".into(), "Carrier".into()),
                ("b".into(), "Carrier".into()),
            ],
            "Carrier",
        );
        let unit = Operation::nullary(unit_name, "Carrier");
        let assoc = Equation::new(
            "assoc",
            Term::app(
                mul_name,
                vec![
                    Term::var("a"),
                    Term::app(mul_name, vec![Term::var("b"), Term::var("c")]),
                ],
            ),
            Term::app(
                mul_name,
                vec![
                    Term::app(mul_name, vec![Term::var("a"), Term::var("b")]),
                    Term::var("c"),
                ],
            ),
        );
        let left_id = Equation::new(
            "left_id",
            Term::app(mul_name, vec![Term::constant(unit_name), Term::var("a")]),
            Term::var("a"),
        );
        let right_id = Equation::new(
            "right_id",
            Term::app(mul_name, vec![Term::var("a"), Term::constant(unit_name)]),
            Term::var("a"),
        );
        let mut eqs = vec![assoc, left_id, right_id];
        if commutative {
            eqs.push(Equation::new(
                "comm",
                Term::app(mul_name, vec![Term::var("a"), Term::var("b")]),
                Term::app(mul_name, vec![Term::var("b"), Term::var("a")]),
            ));
        }
        Theory::new(name, vec![carrier], vec![mul, unit], eqs)
    }

    /// Monoid theory with the given operation names.
    fn monoid_theory(name: &str, mul_name: &str, unit_name: &str) -> Theory {
        build_monoid(name, mul_name, unit_name, false)
    }

    /// Monoid theory extended with a commutativity equation.
    fn commutative_monoid_theory(name: &str, mul_name: &str, unit_name: &str) -> Theory {
        build_monoid(name, mul_name, unit_name, true)
    }

    /// Theory with one simple sort, one unary endo-operation on it, no plain
    /// equations, and a single directed equation `lhs -> rhs`.
    fn single_rule_theory(
        name: &'static str,
        sort: &'static str,
        op: &'static str,
        var: &'static str,
        rule_name: &'static str,
        lhs: Term,
        rhs: Term,
    ) -> Theory {
        Theory::full(
            name,
            Vec::new(),
            vec![Sort::simple(sort)],
            vec![Operation::unary(op, var, sort, sort)],
            Vec::new(),
            vec![DirectedEquation::new(
                rule_name,
                lhs,
                rhs,
                panproto_expr::Expr::Var("_".into()),
            )],
            Vec::new(),
        )
    }

    /// The identity mapping on a monoid is accepted.
    #[test]
    fn identity_morphism_is_valid() {
        let t = monoid_theory("Monoid", "mul", "unit");
        let m = TheoryMorphism::new(
            "id",
            "Monoid",
            "Monoid",
            name_map(&[("Carrier", "Carrier")]),
            name_map(&[("mul", "mul"), ("unit", "unit")]),
        );
        assert!(check_morphism(&m, &t, &t).is_ok());
    }

    /// Renaming the operations consistently between two monoids is accepted.
    #[test]
    fn renaming_morphism_is_valid() {
        let domain = monoid_theory("M1", "mul", "unit");
        let codomain = monoid_theory("M2", "times", "one");
        let m = TheoryMorphism::new(
            "rename",
            "M1",
            "M2",
            name_map(&[("Carrier", "Carrier")]),
            name_map(&[("mul", "times"), ("unit", "one")]),
        );
        assert!(check_morphism(&m, &domain, &codomain).is_ok());
    }

    /// An empty sort map is rejected with `MissingSortMapping`.
    #[test]
    fn missing_sort_mapping_fails() {
        let t = monoid_theory("M", "mul", "unit");
        let sort_map = HashMap::new();
        let op_map = name_map(&[("mul", "mul"), ("unit", "unit")]);
        let m = TheoryMorphism::new("bad", "M", "M", sort_map, op_map);
        let result = check_morphism(&m, &t, &t);
        assert!(matches!(result, Err(GatError::MissingSortMapping(_))));
    }

    /// Leaving `unit` out of the op map is rejected with `MissingOpMapping`.
    #[test]
    fn missing_op_mapping_fails() {
        let t = monoid_theory("M", "mul", "unit");
        let m = TheoryMorphism::new(
            "bad",
            "M",
            "M",
            name_map(&[("Carrier", "Carrier")]),
            name_map(&[("mul", "mul")]),
        );
        let result = check_morphism(&m, &t, &t);
        assert!(matches!(result, Err(GatError::MissingOpMapping(_))));
    }

    /// Mapping a simple sort onto a dependent sort is an arity mismatch.
    #[test]
    fn sort_arity_mismatch_fails() {
        let domain = Theory::new("D", vec![Sort::simple("S")], Vec::new(), Vec::new());
        let codomain = Theory::new(
            "C",
            vec![Sort::dependent("T", vec![SortParam::new("x", "T")])],
            Vec::new(),
            Vec::new(),
        );
        let m = TheoryMorphism::new("bad", "D", "C", name_map(&[("S", "T")]), HashMap::new());
        let result = check_morphism(&m, &domain, &codomain);
        assert!(matches!(result, Err(GatError::SortArityMismatch { .. })));
    }

    /// `f: A -> B` cannot be mapped onto `f: B -> A` under the identity sort map.
    #[test]
    fn op_type_mismatch_fails() {
        let domain = Theory::new(
            "D",
            vec![Sort::simple("A"), Sort::simple("B")],
            vec![Operation::unary("f", "x", "A", "B")],
            Vec::new(),
        );
        let codomain = Theory::new(
            "C",
            vec![Sort::simple("A"), Sort::simple("B")],
            vec![Operation::unary("f", "x", "B", "A")],
            Vec::new(),
        );
        let m = TheoryMorphism::new(
            "bad",
            "D",
            "C",
            name_map(&[("A", "A"), ("B", "B")]),
            name_map(&[("f", "f")]),
        );
        let result = check_morphism(&m, &domain, &codomain);
        assert!(matches!(result, Err(GatError::OpTypeMismatch { .. })));
    }

    /// Equations that differ only in variable names are treated as preserved.
    #[test]
    fn morphism_with_renamed_equation_vars() {
        let domain = Theory::new(
            "D",
            vec![Sort::simple("S")],
            vec![Operation::new(
                "f",
                vec![("a".into(), "S".into()), ("b".into(), "S".into())],
                "S",
            )],
            vec![Equation::new(
                "comm",
                Term::app("f", vec![Term::var("a"), Term::var("b")]),
                Term::app("f", vec![Term::var("b"), Term::var("a")]),
            )],
        );
        let codomain = Theory::new(
            "C",
            vec![Sort::simple("S")],
            vec![Operation::new(
                "f",
                vec![("x".into(), "S".into()), ("y".into(), "S".into())],
                "S",
            )],
            vec![Equation::new(
                "comm",
                Term::app("f", vec![Term::var("x"), Term::var("y")]),
                Term::app("f", vec![Term::var("y"), Term::var("x")]),
            )],
        );
        let m = TheoryMorphism::new(
            "id",
            "D",
            "C",
            name_map(&[("S", "S")]),
            name_map(&[("f", "f")]),
        );
        assert!(
            check_morphism(&m, &domain, &codomain).is_ok(),
            "morphism should be valid: equations are α-equivalent"
        );
    }

    /// `f(x, x) = g(x)` must not match `f(a, b) = g(a)`: the variable
    /// occurrence pattern differs.
    #[test]
    fn morphism_equation_multiplicity_mismatch_fails() {
        let domain = Theory::new(
            "D",
            vec![Sort::simple("S")],
            vec![
                Operation::new(
                    "f",
                    vec![("a".into(), "S".into()), ("b".into(), "S".into())],
                    "S",
                ),
                Operation::unary("g", "x", "S", "S"),
            ],
            vec![Equation::new(
                "eq1",
                Term::app("f", vec![Term::var("x"), Term::var("x")]),
                Term::app("g", vec![Term::var("x")]),
            )],
        );
        let codomain = Theory::new(
            "C",
            vec![Sort::simple("S")],
            vec![
                Operation::new(
                    "f",
                    vec![("a".into(), "S".into()), ("b".into(), "S".into())],
                    "S",
                ),
                Operation::unary("g", "x", "S", "S"),
            ],
            vec![Equation::new(
                "eq1",
                Term::app("f", vec![Term::var("a"), Term::var("b")]),
                Term::app("g", vec![Term::var("a")]),
            )],
        );
        let m = TheoryMorphism::new(
            "bad",
            "D",
            "C",
            name_map(&[("S", "S")]),
            name_map(&[("f", "f"), ("g", "g")]),
        );
        assert!(
            check_morphism(&m, &domain, &codomain).is_err(),
            "morphism should fail: equations have different variable multiplicity"
        );
    }

    /// The identity morphism preserves a directed equation of the theory.
    #[test]
    fn morphism_preserves_directed_eqs() {
        let theory = single_rule_theory(
            "T",
            "A",
            "f",
            "x",
            "idem",
            Term::app("f", vec![Term::app("f", vec![Term::var("x")])]),
            Term::app("f", vec![Term::var("x")]),
        );
        let m = TheoryMorphism::new(
            "id",
            "T",
            "T",
            name_map(&[("A", "A")]),
            name_map(&[("f", "f")]),
        );
        assert!(check_morphism(&m, &theory, &theory).is_ok());
    }

    /// A directed equation is still found after renaming sort, op and variable.
    #[test]
    fn morphism_renaming_preserves_directed_eqs() {
        let domain = single_rule_theory(
            "D",
            "A",
            "f",
            "x",
            "rule",
            Term::app("f", vec![Term::var("x")]),
            Term::var("x"),
        );
        let codomain = single_rule_theory(
            "C",
            "B",
            "g",
            "y",
            "rule",
            Term::app("g", vec![Term::var("y")]),
            Term::var("y"),
        );
        let m = TheoryMorphism::new(
            "rename",
            "D",
            "C",
            name_map(&[("A", "B")]),
            name_map(&[("f", "g")]),
        );
        assert!(check_morphism(&m, &domain, &codomain).is_ok());
    }

    /// A codomain without the directed equation is rejected.
    #[test]
    fn morphism_missing_directed_eq_fails() {
        let domain = single_rule_theory(
            "D",
            "A",
            "f",
            "x",
            "rule",
            Term::app("f", vec![Term::var("x")]),
            Term::var("x"),
        );
        let codomain = Theory::new(
            "C",
            vec![Sort::simple("A")],
            vec![Operation::unary("f", "x", "A", "A")],
            Vec::new(),
        );
        let m = TheoryMorphism::new(
            "bad",
            "D",
            "C",
            name_map(&[("A", "A")]),
            name_map(&[("f", "f")]),
        );
        assert!(matches!(
            check_morphism(&m, &domain, &codomain),
            Err(GatError::DirectedEquationNotPreserved { .. })
        ));
    }

    /// A theory built without directed equations still validates.
    #[test]
    fn morphism_no_directed_eqs_still_valid() {
        let t = monoid_theory("M", "mul", "unit");
        let m = TheoryMorphism::new(
            "id",
            "M",
            "M",
            name_map(&[("Carrier", "Carrier")]),
            name_map(&[("mul", "mul"), ("unit", "unit")]),
        );
        assert!(check_morphism(&m, &t, &t).is_ok());
    }

    /// Validates a morphism on a commutative monoid and checks that migrating
    /// an integer-addition model along it leaves `mul` and `unit` unchanged.
    ///
    /// Note: despite the test and morphism names, the op map used here is the
    /// identity (`mul -> mul`, `unit -> unit`).
    #[test]
    fn reverse_mul_morphism_commutative_monoid() {
        let theory = commutative_monoid_theory("CMonoid", "mul", "unit");
        let m = TheoryMorphism::new(
            "swap",
            "CMonoid",
            "CMonoid",
            name_map(&[("Carrier", "Carrier")]),
            name_map(&[("mul", "mul"), ("unit", "unit")]),
        );
        assert!(check_morphism(&m, &theory, &theory).is_ok());

        let mut model = Model::new("CMonoid");
        model.add_sort("Carrier", (0..10).map(ModelValue::Int).collect());
        model.add_op("mul", |args: &[ModelValue]| match (&args[0], &args[1]) {
            (ModelValue::Int(a), ModelValue::Int(b)) => Ok(ModelValue::Int(a + b)),
            _ => Err(GatError::ModelError("expected Int".to_owned())),
        });
        model.add_op("unit", |_: &[ModelValue]| Ok(ModelValue::Int(0)));

        let migrated = migrate_model(&m, &model).unwrap();

        // The migrated model computes the same product as the original.
        let orig = model
            .eval("mul", &[ModelValue::Int(3), ModelValue::Int(5)])
            .unwrap();
        let mig = migrated
            .eval("mul", &[ModelValue::Int(3), ModelValue::Int(5)])
            .unwrap();
        assert_eq!(orig, mig);

        // Swapping the arguments gives the same result in the original model.
        let orig_swap = model
            .eval("mul", &[ModelValue::Int(5), ModelValue::Int(3)])
            .unwrap();
        assert_eq!(orig, orig_swap);

        // The unit is unchanged by migration.
        let orig_unit = model.eval("unit", &[]).unwrap();
        let mig_unit = migrated.eval("unit", &[]).unwrap();
        assert_eq!(orig_unit, mig_unit);
    }
}