use std::collections::HashMap;
use std::sync::Arc;
use rustc_hash::FxHashMap;
use crate::eq::{Term, alpha_equivalent, normalize};
use crate::error::GatError;
use crate::morphism::TheoryMorphism;
use crate::theory::Theory;
/// A natural transformation between two theory morphisms, given by one term per sort.
///
/// `source` and `target` are the names of the morphisms the transformation goes between.
/// `vertical_compose` compares these names to decide composability; they are not checked
/// against the morphisms handed to `check_natural_transformation`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NaturalTransformation {
    /// Name of this transformation.
    pub name: Arc<str>,
    /// Name of the source morphism.
    pub source: Arc<str>,
    /// Name of the target morphism.
    pub target: Arc<str>,
    /// Component term for each domain sort, keyed by sort name; each term is written in
    /// the single variable `x`.
    pub components: HashMap<Arc<str>, Term>,
}
/// Checks that `nt` is a natural transformation from the morphism `f` to the morphism `g`.
///
/// `f` and `g` must have the same domain and codomain theory names. `nt` must provide a
/// component for every sort of `domain`. Each component is a term in the variable `x`, and
/// every operation it mentions must exist in `codomain`. Then, for every operation `op` of
/// `domain` with input sorts `s_0, ..., s_n` and output sort `t`, the naturality square
///
/// ```text
/// alpha_t[x := F(op)(x0, ..., xn)]  ==  G(op)(alpha_{s_0}[x := x0], ..., alpha_{s_n}[x := xn])
/// ```
///
/// is checked. Both sides are normalized with `codomain.directed_eqs`, and the results are
/// compared up to alpha-equivalence. Unary operations use `x` as their argument variable
/// instead of `x0`. An operation absent from a morphism's `op_map` is mapped to itself.
///
/// # Errors
///
/// - [`GatError::NatTransDomainMismatch`] if `f` and `g` differ in domain or codomain.
/// - [`GatError::MissingNatTransComponent`] if a sort of `domain`, or an input/output sort
///   of one of its operations, has no component.
/// - [`GatError::NatTransComponentInvalid`] if a component mentions an operation missing
///   from `codomain`. Components are checked in ascending sort-name order, so the reported
///   sort is deterministic.
/// - [`GatError::NaturalityViolation`] for the first operation, in `domain.ops` order,
///   whose square does not commute. `lhs` and `rhs` carry the normalized terms.
pub fn check_natural_transformation(
    nt: &NaturalTransformation,
    f: &TheoryMorphism,
    g: &TheoryMorphism,
    domain: &Theory,
    codomain: &Theory,
) -> Result<(), GatError> {
    if f.domain != g.domain || f.codomain != g.codomain {
        return Err(GatError::NatTransDomainMismatch {
            source_morphism: f.name.to_string(),
            target_morphism: g.name.to_string(),
        });
    }
    for sort in &domain.sorts {
        if !nt.components.contains_key(&sort.name) {
            return Err(GatError::MissingNatTransComponent(sort.name.to_string()));
        }
    }
    // `HashMap` iteration order is unspecified (and randomized with the default hasher).
    // Validate in sorted key order so that, with several invalid components, the same sort
    // is reported on every run.
    let mut sorted_components: Vec<_> = nt.components.iter().collect();
    sorted_components.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
    for (sort_name, term) in sorted_components {
        validate_term_ops(term, codomain).map_err(|detail| GatError::NatTransComponentInvalid {
            sort: sort_name.to_string(),
            detail,
        })?;
    }
    // Step budget passed to `normalize` for each side of every naturality square.
    let max_steps = 1000;
    for op in &domain.ops {
        let output_sort = &op.output;
        let alpha_output = nt
            .components
            .get(output_sort)
            .ok_or_else(|| GatError::MissingNatTransComponent(output_sort.to_string()))?;
        // Operations not mentioned in a morphism's op_map are treated as mapped to themselves.
        let f_op = f
            .op_map
            .get(&op.name)
            .cloned()
            .unwrap_or_else(|| Arc::clone(&op.name));
        let g_op = g
            .op_map
            .get(&op.name)
            .cloned()
            .unwrap_or_else(|| Arc::clone(&op.name));
        // One fresh argument variable per input: `x` for unary ops, `x0, x1, ...` otherwise.
        let var_names: Vec<Arc<str>> = if op.inputs.len() == 1 {
            vec![Arc::from("x")]
        } else {
            (0..op.inputs.len())
                .map(|i| Arc::from(format!("x{i}")))
                .collect()
        };
        // Left side: alpha_out(F(op)(vars)).
        let f_op_applied = Term::app(
            Arc::clone(&f_op),
            var_names.iter().map(|v| Term::var(Arc::clone(v))).collect(),
        );
        let mut subst_lhs = FxHashMap::default();
        subst_lhs.insert(Arc::from("x"), f_op_applied);
        let lhs = alpha_output.substitute(&subst_lhs);
        // Right side: G(op)(alpha_in_i(var_i), ...).
        let mut rhs_args = Vec::with_capacity(op.inputs.len());
        for ((_, input_sort), var) in op.inputs.iter().zip(&var_names) {
            let alpha_input = nt
                .components
                .get(input_sort)
                .ok_or_else(|| GatError::MissingNatTransComponent(input_sort.to_string()))?;
            let mut subst_arg = FxHashMap::default();
            subst_arg.insert(Arc::from("x"), Term::var(Arc::clone(var)));
            rhs_args.push(alpha_input.substitute(&subst_arg));
        }
        let rhs = Term::app(g_op, rhs_args);
        let lhs_norm = normalize(&lhs, &codomain.directed_eqs, max_steps);
        let rhs_norm = normalize(&rhs, &codomain.directed_eqs, max_steps);
        if !alpha_equivalent(&lhs_norm, &rhs_norm) {
            return Err(GatError::NaturalityViolation {
                op: op.name.to_string(),
                lhs: format!("{lhs_norm:?}"),
                rhs: format!("{rhs_norm:?}"),
            });
        }
    }
    Ok(())
}
/// Recursively checks that every operation symbol in `term` is declared in `codomain`.
///
/// Variables are always accepted. An application's own symbol is checked before its
/// arguments, and arguments are visited left to right. The first undeclared symbol found
/// is reported.
///
/// # Errors
///
/// Returns a human-readable message naming the first operation that `codomain.has_op`
/// rejects.
fn validate_term_ops(term: &Term, codomain: &Theory) -> Result<(), String> {
    match term {
        Term::Var(_) => Ok(()),
        Term::App { op, args } if codomain.has_op(op) => args
            .iter()
            .try_for_each(|arg| validate_term_ops(arg, codomain)),
        Term::App { op, .. } => Err(format!("operation {op} not found in codomain")),
    }
}
pub fn vertical_compose(
alpha: &NaturalTransformation,
beta: &NaturalTransformation,
domain: &Theory,
) -> Result<NaturalTransformation, GatError> {
if alpha.target != beta.source {
return Err(GatError::NatTransComposeMismatch {
alpha_target: alpha.target.to_string(),
beta_source: beta.source.to_string(),
});
}
let mut components = HashMap::new();
for sort in &domain.sorts {
let alpha_s = alpha
.components
.get(&sort.name)
.ok_or_else(|| GatError::MissingNatTransComponent(sort.name.to_string()))?;
let beta_s = beta
.components
.get(&sort.name)
.ok_or_else(|| GatError::MissingNatTransComponent(sort.name.to_string()))?;
let mut subst = FxHashMap::default();
subst.insert(Arc::from("x"), alpha_s.clone());
let composed = beta_s.substitute(&subst);
components.insert(Arc::clone(&sort.name), composed);
}
Ok(NaturalTransformation {
name: Arc::from(format!("{}.{}", beta.name, alpha.name)),
source: Arc::clone(&alpha.source),
target: Arc::clone(&beta.target),
components,
})
}
pub fn horizontal_compose(
alpha: &NaturalTransformation,
beta: &NaturalTransformation,
_f: &TheoryMorphism,
g: &TheoryMorphism,
h: &TheoryMorphism,
domain: &Theory,
) -> Result<NaturalTransformation, GatError> {
let mut components = HashMap::new();
for sort in &domain.sorts {
let g_s = g
.sort_map
.get(&sort.name)
.ok_or_else(|| GatError::MissingSortMapping(sort.name.to_string()))?;
let beta_gs = beta
.components
.get(g_s)
.ok_or_else(|| GatError::MissingNatTransComponent(g_s.to_string()))?;
let alpha_s = alpha
.components
.get(&sort.name)
.ok_or_else(|| GatError::MissingNatTransComponent(sort.name.to_string()))?;
let h_alpha_s = h.apply_to_term(alpha_s);
let mut subst = FxHashMap::default();
subst.insert(Arc::from("x"), h_alpha_s);
let composed = beta_gs.substitute(&subst);
components.insert(Arc::clone(&sort.name), composed);
}
Ok(NaturalTransformation {
name: Arc::from(format!("{}*{}", beta.name, alpha.name)),
source: Arc::from(format!("{}.{}", beta.source, alpha.source)),
target: Arc::from(format!("{}.{}", beta.target, alpha.target)),
components,
})
}
/// Checks the interchange law for four natural transformations, sort by sort over `domain`.
///
/// The first pair consists of `alpha: F => G` and `beta: G => H`, where `f`, `g`, `h` are
/// `F`, `G`, `H`. The second pair consists of `alpha_prime: F' => G'` and
/// `beta_prime: G' => H'`, whose components are indexed by the sorts of `middle` and where
/// `f_prime`, `g_prime` are `F'`, `G'`. The law being verified is
///
/// ```text
/// (beta' . alpha') * (beta . alpha) == (beta' * beta) . (alpha' * alpha)
/// ```
///
/// where `.` is [`vertical_compose`] and `*` is [`horizontal_compose`]. `_h_prime` is
/// accepted but not used.
///
/// Components are compared with `alpha_equivalent` only. No normalization against any
/// theory's equations is performed, so sides that agree only modulo equations are
/// reported as a violation.
///
/// # Errors
///
/// - Any error from the intermediate compositions, in evaluation order.
/// - `MissingNatTransComponent` if a composite lacks a component for a sort of `domain`.
/// - `NaturalityViolation` (with `op` set to `"interchange at sort <name>"`) for the first
///   sort, in `domain.sorts` order, whose two sides differ.
// Every morphism and transformation of the composition diagram is its own parameter, so
// the argument count is inherent to the law being checked.
#[allow(clippy::too_many_arguments)]
pub fn check_interchange(
    alpha: &NaturalTransformation,
    beta: &NaturalTransformation,
    alpha_prime: &NaturalTransformation,
    beta_prime: &NaturalTransformation,
    f: &TheoryMorphism,
    g: &TheoryMorphism,
    h: &TheoryMorphism,
    f_prime: &TheoryMorphism,
    g_prime: &TheoryMorphism,
    _h_prime: &TheoryMorphism,
    domain: &Theory,
    middle: &Theory,
) -> Result<(), GatError> {
    // Left side: compose each pair vertically, then compose the results horizontally.
    let inner_vertical = vertical_compose(alpha, beta, domain)?;
    let outer_vertical = vertical_compose(alpha_prime, beta_prime, middle)?;
    let vertical_then_horizontal =
        horizontal_compose(&inner_vertical, &outer_vertical, f, h, f_prime, domain)?;
    // Right side: compose horizontally first, then compose the results vertically.
    let alpha_horizontal = horizontal_compose(alpha, alpha_prime, f, g, f_prime, domain)?;
    let beta_horizontal = horizontal_compose(beta, beta_prime, g, h, g_prime, domain)?;
    let horizontal_then_vertical = vertical_compose(&alpha_horizontal, &beta_horizontal, domain)?;

    domain.sorts.iter().try_for_each(|sort| -> Result<(), GatError> {
        let missing = || GatError::MissingNatTransComponent(sort.name.to_string());
        let left = vertical_then_horizontal
            .components
            .get(&sort.name)
            .ok_or_else(missing)?;
        let right = horizontal_then_vertical
            .components
            .get(&sort.name)
            .ok_or_else(missing)?;
        if alpha_equivalent(left, right) {
            return Ok(());
        }
        Err(GatError::NaturalityViolation {
            op: format!("interchange at sort {}", sort.name),
            lhs: format!("{left:?}"),
            rhs: format!("{right:?}"),
        })
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::morphism::TheoryMorphism;
    use crate::op::Operation;
    use crate::sort::Sort;
    use crate::theory::Theory;

    /// Theory `T` with sorts `A`, `B` and one unary operation `f: A -> B`.
    fn two_sort_theory() -> Theory {
        Theory::new(
            "T",
            vec![Sort::simple("A"), Sort::simple("B")],
            vec![Operation::unary("f", "x", "A", "B")],
            Vec::new(),
        )
    }

    /// Identity morphism `name: theory -> theory` sending every sort and operation to itself.
    fn identity_morphism(theory: &Theory, name: &str) -> TheoryMorphism {
        let sort_map: HashMap<Arc<str>, Arc<str>> = theory
            .sorts
            .iter()
            .map(|s| (Arc::clone(&s.name), Arc::clone(&s.name)))
            .collect();
        let op_map: HashMap<Arc<str>, Arc<str>> = theory
            .ops
            .iter()
            .map(|o| (Arc::clone(&o.name), Arc::clone(&o.name)))
            .collect();
        TheoryMorphism::new(name, &*theory.name, &*theory.name, sort_map, op_map)
    }

    /// Transformation `name: source => target` whose component at every sort of `theory`
    /// is the bare variable `x`.
    fn identity_nat_trans(
        theory: &Theory,
        source: &str,
        target: &str,
        name: &str,
    ) -> NaturalTransformation {
        let components: HashMap<Arc<str>, Term> = theory
            .sorts
            .iter()
            .map(|s| (Arc::clone(&s.name), Term::var("x")))
            .collect();
        NaturalTransformation {
            name: Arc::from(name),
            source: Arc::from(source),
            target: Arc::from(target),
            components,
        }
    }

    #[test]
    fn identity_nat_trans_validates() {
        let theory = two_sort_theory();
        let morph = identity_morphism(&theory, "id");
        let nt = identity_nat_trans(&theory, "id", "id", "id_nt");
        let result = check_natural_transformation(&nt, &morph, &morph, &theory, &theory);
        assert!(result.is_ok(), "expected Ok, got {result:?}");
    }

    #[test]
    fn vertical_compose_identities_yields_identity() -> Result<(), Box<dyn std::error::Error>> {
        let theory = two_sort_theory();
        let alpha = identity_nat_trans(&theory, "id", "id", "alpha");
        let beta = identity_nat_trans(&theory, "id", "id", "beta");
        let composed = vertical_compose(&alpha, &beta, &theory)?;
        for sort in &theory.sorts {
            let comp = composed
                .components
                .get(&sort.name)
                .ok_or("missing component for sort")?;
            assert_eq!(comp, &Term::var("x"));
        }
        Ok(())
    }

    #[test]
    fn naturality_violation_detected() {
        // The domain has only `f: A -> B`; the codomain adds a constant `bad: B`. Sending
        // sort `B` to that constant means the square for `f` cannot commute: one side is
        // the constant `bad`, the other applies `f` to `x`.
        let domain = two_sort_theory();
        let codomain = Theory::new(
            "T",
            vec![Sort::simple("A"), Sort::simple("B")],
            vec![
                Operation::unary("f", "x", "A", "B"),
                Operation::nullary("bad", "B"),
            ],
            Vec::new(),
        );
        let morph = identity_morphism(&codomain, "id");
        let mut components = HashMap::new();
        components.insert(Arc::from("A"), Term::var("x"));
        components.insert(Arc::from("B"), Term::constant("bad"));
        let nt = NaturalTransformation {
            name: Arc::from("bad_nt"),
            source: Arc::from("id"),
            target: Arc::from("id"),
            components,
        };
        let result = check_natural_transformation(&nt, &morph, &morph, &domain, &codomain);
        assert!(
            matches!(result, Err(GatError::NaturalityViolation { .. })),
            "expected NaturalityViolation, got {result:?}"
        );
    }

    #[test]
    fn missing_component_detected() {
        // Only sort `A` gets a component; sort `B` is left out.
        let theory = two_sort_theory();
        let morph = identity_morphism(&theory, "id");
        let mut components = HashMap::new();
        components.insert(Arc::from("A"), Term::var("x"));
        let nt = NaturalTransformation {
            name: Arc::from("partial"),
            source: Arc::from("id"),
            target: Arc::from("id"),
            components,
        };
        let result = check_natural_transformation(&nt, &morph, &morph, &theory, &theory);
        assert!(
            matches!(result, Err(GatError::MissingNatTransComponent(_))),
            "expected MissingNatTransComponent, got {result:?}"
        );
    }

    #[test]
    fn domain_mismatch_detected() {
        // `f` is a morphism over `T1` while `g` is a morphism over `T2`.
        let t1 = Theory::new("T1", vec![Sort::simple("A")], Vec::new(), Vec::new());
        let f = TheoryMorphism::new(
            "f",
            "T1",
            "T1",
            HashMap::from([(Arc::from("A"), Arc::from("A"))]),
            HashMap::new(),
        );
        let g = TheoryMorphism::new(
            "g",
            "T2",
            "T2",
            HashMap::from([(Arc::from("B"), Arc::from("B"))]),
            HashMap::new(),
        );
        let nt = NaturalTransformation {
            name: Arc::from("bad"),
            source: Arc::from("f"),
            target: Arc::from("g"),
            components: HashMap::new(),
        };
        let result = check_natural_transformation(&nt, &f, &g, &t1, &t1);
        assert!(
            matches!(result, Err(GatError::NatTransDomainMismatch { .. })),
            "expected NatTransDomainMismatch, got {result:?}"
        );
    }

    #[test]
    fn naturality_passes_after_normalization() {
        // Both morphisms send `f` to `h`, and the component is `h(h(x))`. The codomain's
        // `idem` rule rewrites nested `h` applications, so both sides of the square reach
        // a common normal form.
        use crate::eq::DirectedEquation;
        let domain = Theory::new(
            "D",
            vec![Sort::simple("A")],
            vec![Operation::unary("f", "x", "A", "A")],
            Vec::new(),
        );
        let codomain = Theory::full(
            "C",
            Vec::new(),
            vec![Sort::simple("A")],
            vec![Operation::unary("h", "x", "A", "A")],
            Vec::new(),
            vec![DirectedEquation::new(
                "idem",
                Term::app("h", vec![Term::app("h", vec![Term::var("y")])]),
                Term::app("h", vec![Term::var("y")]),
                panproto_expr::Expr::Var("_".into()),
            )],
            Vec::new(),
        );
        let f_morph = TheoryMorphism::new(
            "F",
            "D",
            "C",
            HashMap::from([(Arc::from("A"), Arc::from("A"))]),
            HashMap::from([(Arc::from("f"), Arc::from("h"))]),
        );
        let g_morph = TheoryMorphism::new(
            "G",
            "D",
            "C",
            HashMap::from([(Arc::from("A"), Arc::from("A"))]),
            HashMap::from([(Arc::from("f"), Arc::from("h"))]),
        );
        let mut components = HashMap::new();
        components.insert(
            Arc::from("A"),
            Term::app("h", vec![Term::app("h", vec![Term::var("x")])]),
        );
        let nt = NaturalTransformation {
            name: Arc::from("alpha"),
            source: Arc::from("F"),
            target: Arc::from("G"),
            components,
        };
        let result = check_natural_transformation(&nt, &f_morph, &g_morph, &domain, &codomain);
        assert!(
            result.is_ok(),
            "naturality should pass after normalization: {result:?}"
        );
    }

    #[test]
    fn composition_mismatch_detected() {
        // alpha ends at morphism "g" but beta starts at "h", so they cannot compose.
        let theory = two_sort_theory();
        let alpha = identity_nat_trans(&theory, "f", "g", "alpha");
        let beta = identity_nat_trans(&theory, "h", "k", "beta");
        let result = vertical_compose(&alpha, &beta, &theory);
        assert!(
            matches!(result, Err(GatError::NatTransComposeMismatch { .. })),
            "expected NatTransComposeMismatch, got {result:?}"
        );
    }

    #[test]
    fn interchange_law_identities() -> Result<(), Box<dyn std::error::Error>> {
        let theory = two_sort_theory();
        let id_morph = identity_morphism(&theory, "id");
        let alpha = identity_nat_trans(&theory, "id", "id", "alpha");
        let beta = identity_nat_trans(&theory, "id", "id", "beta");
        let alpha_prime = identity_nat_trans(&theory, "id", "id", "alpha_prime");
        let beta_prime = identity_nat_trans(&theory, "id", "id", "beta_prime");
        check_interchange(
            &alpha,
            &beta,
            &alpha_prime,
            &beta_prime,
            &id_morph,
            &id_morph,
            &id_morph,
            &id_morph,
            &id_morph,
            &id_morph,
            &theory,
            &theory,
        )?;
        Ok(())
    }
}