use csw_core::{CategorySpec, DiagonalSpec, TerminalSpec};
use crate::{
EquivalenceRule, ReductionRule, StructuralRules, TermConstructor, TermConstructorKind,
TypeConstructor, TypeConstructorKind, TypeSystem, TypingRule,
};
pub struct Deriver;
impl Deriver {
pub fn derive(spec: &CategorySpec) -> TypeSystem {
let mut ts = TypeSystem {
name: spec.name.clone(),
structural: StructuralRules {
weakening: matches!(spec.structure.terminal_morphism, TerminalSpec::Universal),
contraction: matches!(spec.structure.diagonal, DiagonalSpec::Universal),
exchange: spec.structure.symmetry,
},
..Default::default()
};
Self::derive_variable(&mut ts);
for obj in &spec.base_objects {
ts.type_constructors.push(TypeConstructor {
name: obj.name.clone(),
symbol: obj.name.clone(),
arity: 0,
kind: TypeConstructorKind::Base,
});
}
if spec.structure.terminal {
Self::derive_terminal(&mut ts);
}
if spec.structure.initial {
Self::derive_initial(&mut ts);
}
if spec.structure.products {
Self::derive_products(&mut ts, spec);
}
if spec.structure.coproducts {
Self::derive_coproducts(&mut ts);
}
if spec.structure.exponentials {
Self::derive_exponentials(&mut ts);
}
if spec.structure.tensor.is_some() {
Self::derive_tensor(&mut ts, spec);
}
if spec.structure.linear_hom {
Self::derive_linear_hom(&mut ts);
}
ts
}
fn derive_variable(ts: &mut TypeSystem) {
ts.term_constructors.push(TermConstructor {
name: "var".to_string(),
symbol: "x".to_string(),
kind: TermConstructorKind::Variable,
});
ts.typing_rules.push(TypingRule {
name: "var".to_string(),
premises: vec![],
conclusion: "Γ, x:A ⊢ x : A".to_string(),
side_conditions: vec![],
});
}
fn derive_terminal(ts: &mut TypeSystem) {
ts.type_constructors.push(TypeConstructor {
name: "Unit".to_string(),
symbol: "1".to_string(),
arity: 0,
kind: TypeConstructorKind::Unit,
});
ts.term_constructors.push(TermConstructor {
name: "unit".to_string(),
symbol: "()".to_string(),
kind: TermConstructorKind::UnitIntro,
});
ts.typing_rules.push(TypingRule {
name: "unit-intro".to_string(),
premises: vec![],
conclusion: "Γ ⊢ () : 1".to_string(),
side_conditions: vec![],
});
ts.equivalence_rules.push(EquivalenceRule {
name: "unit-eta".to_string(),
lhs: "e".to_string(),
rhs: "()".to_string(),
condition: Some("e : 1".to_string()),
});
}
fn derive_initial(ts: &mut TypeSystem) {
ts.type_constructors.push(TypeConstructor {
name: "Empty".to_string(),
symbol: "0".to_string(),
arity: 0,
kind: TypeConstructorKind::Empty,
});
ts.term_constructors.push(TermConstructor {
name: "absurd".to_string(),
symbol: "absurd".to_string(),
kind: TermConstructorKind::Absurd,
});
ts.typing_rules.push(TypingRule {
name: "absurd-elim".to_string(),
premises: vec!["Γ ⊢ e : 0".to_string()],
conclusion: "Γ ⊢ absurd e : A".to_string(),
side_conditions: vec![],
});
}
fn derive_products(ts: &mut TypeSystem, spec: &CategorySpec) {
ts.type_constructors.push(TypeConstructor {
name: "Product".to_string(),
symbol: "×".to_string(),
arity: 2,
kind: TypeConstructorKind::Product,
});
ts.term_constructors.push(TermConstructor {
name: "pair".to_string(),
symbol: "(_, _)".to_string(),
kind: TermConstructorKind::PairIntro,
});
ts.term_constructors.push(TermConstructor {
name: "fst".to_string(),
symbol: "fst".to_string(),
kind: TermConstructorKind::PairElimFst,
});
ts.term_constructors.push(TermConstructor {
name: "snd".to_string(),
symbol: "snd".to_string(),
kind: TermConstructorKind::PairElimSnd,
});
ts.typing_rules.push(TypingRule {
name: "pair-intro".to_string(),
premises: vec!["Γ ⊢ a : A".to_string(), "Γ ⊢ b : B".to_string()],
conclusion: "Γ ⊢ (a, b) : A × B".to_string(),
side_conditions: vec![],
});
ts.typing_rules.push(TypingRule {
name: "fst-elim".to_string(),
premises: vec!["Γ ⊢ p : A × B".to_string()],
conclusion: "Γ ⊢ fst p : A".to_string(),
side_conditions: vec![],
});
ts.typing_rules.push(TypingRule {
name: "snd-elim".to_string(),
premises: vec!["Γ ⊢ p : A × B".to_string()],
conclusion: "Γ ⊢ snd p : B".to_string(),
side_conditions: vec![],
});
ts.reduction_rules.push(ReductionRule {
name: "fst-beta".to_string(),
lhs: "fst (a, b)".to_string(),
rhs: "a".to_string(),
});
ts.reduction_rules.push(ReductionRule {
name: "snd-beta".to_string(),
lhs: "snd (a, b)".to_string(),
rhs: "b".to_string(),
});
if matches!(spec.structure.diagonal, DiagonalSpec::Universal) {
ts.equivalence_rules.push(EquivalenceRule {
name: "pair-eta".to_string(),
lhs: "p".to_string(),
rhs: "(fst p, snd p)".to_string(),
condition: Some("p : A × B".to_string()),
});
}
}
fn derive_coproducts(ts: &mut TypeSystem) {
ts.type_constructors.push(TypeConstructor {
name: "Coproduct".to_string(),
symbol: "+".to_string(),
arity: 2,
kind: TypeConstructorKind::Coproduct,
});
ts.term_constructors.push(TermConstructor {
name: "inl".to_string(),
symbol: "inl".to_string(),
kind: TermConstructorKind::InjLeft,
});
ts.term_constructors.push(TermConstructor {
name: "inr".to_string(),
symbol: "inr".to_string(),
kind: TermConstructorKind::InjRight,
});
ts.term_constructors.push(TermConstructor {
name: "case".to_string(),
symbol: "case _ of inl x ⇒ _ | inr y ⇒ _".to_string(),
kind: TermConstructorKind::Case,
});
ts.typing_rules.push(TypingRule {
name: "inl-intro".to_string(),
premises: vec!["Γ ⊢ a : A".to_string()],
conclusion: "Γ ⊢ inl a : A + B".to_string(),
side_conditions: vec![],
});
ts.typing_rules.push(TypingRule {
name: "inr-intro".to_string(),
premises: vec!["Γ ⊢ b : B".to_string()],
conclusion: "Γ ⊢ inr b : A + B".to_string(),
side_conditions: vec![],
});
ts.typing_rules.push(TypingRule {
name: "case-elim".to_string(),
premises: vec![
"Γ ⊢ e : A + B".to_string(),
"Γ, x:A ⊢ e₁ : C".to_string(),
"Γ, y:B ⊢ e₂ : C".to_string(),
],
conclusion: "Γ ⊢ case e of inl x ⇒ e₁ | inr y ⇒ e₂ : C".to_string(),
side_conditions: vec![],
});
ts.reduction_rules.push(ReductionRule {
name: "case-inl-beta".to_string(),
lhs: "case (inl a) of inl x ⇒ e₁ | inr y ⇒ e₂".to_string(),
rhs: "e₁[a/x]".to_string(),
});
ts.reduction_rules.push(ReductionRule {
name: "case-inr-beta".to_string(),
lhs: "case (inr b) of inl x ⇒ e₁ | inr y ⇒ e₂".to_string(),
rhs: "e₂[b/y]".to_string(),
});
ts.equivalence_rules.push(EquivalenceRule {
name: "sum-eta".to_string(),
lhs: "e".to_string(),
rhs: "case e of inl x ⇒ inl x | inr y ⇒ inr y".to_string(),
condition: Some("e : A + B".to_string()),
});
}
fn derive_exponentials(ts: &mut TypeSystem) {
ts.type_constructors.push(TypeConstructor {
name: "Arrow".to_string(),
symbol: "→".to_string(),
arity: 2,
kind: TypeConstructorKind::Exponential,
});
ts.term_constructors.push(TermConstructor {
name: "abs".to_string(),
symbol: "λ_._ ".to_string(),
kind: TermConstructorKind::Abstraction,
});
ts.term_constructors.push(TermConstructor {
name: "app".to_string(),
symbol: "_ _".to_string(),
kind: TermConstructorKind::Application,
});
ts.typing_rules.push(TypingRule {
name: "abs-intro".to_string(),
premises: vec!["Γ, x:A ⊢ e : B".to_string()],
conclusion: "Γ ⊢ λx.e : A → B".to_string(),
side_conditions: vec![],
});
ts.typing_rules.push(TypingRule {
name: "app-elim".to_string(),
premises: vec!["Γ ⊢ f : A → B".to_string(), "Γ ⊢ a : A".to_string()],
conclusion: "Γ ⊢ f a : B".to_string(),
side_conditions: vec![],
});
ts.reduction_rules.push(ReductionRule {
name: "beta".to_string(),
lhs: "(λx.e) a".to_string(),
rhs: "e[a/x]".to_string(),
});
ts.equivalence_rules.push(EquivalenceRule {
name: "eta".to_string(),
lhs: "f".to_string(),
rhs: "λx.f x".to_string(),
condition: Some("f : A → B, x fresh".to_string()),
});
}
fn derive_tensor(ts: &mut TypeSystem, spec: &CategorySpec) {
let tensor_spec = spec.structure.tensor.as_ref().unwrap();
ts.type_constructors.push(TypeConstructor {
name: "Tensor".to_string(),
symbol: tensor_spec.symbol.clone(),
arity: 2,
kind: TypeConstructorKind::Tensor,
});
ts.type_constructors.push(TypeConstructor {
name: "TensorUnit".to_string(),
symbol: tensor_spec.unit_symbol.clone(),
arity: 0,
kind: TypeConstructorKind::Unit,
});
ts.term_constructors.push(TermConstructor {
name: "tensor-pair".to_string(),
symbol: format!("(_ {} _)", tensor_spec.symbol),
kind: TermConstructorKind::PairIntro,
});
ts.term_constructors.push(TermConstructor {
name: "let-tensor".to_string(),
symbol: "let (x ⊗ y) = _ in _".to_string(),
kind: TermConstructorKind::LetPair,
});
ts.typing_rules.push(TypingRule {
name: "tensor-intro".to_string(),
premises: vec!["Γ₁ ⊢ a : A".to_string(), "Γ₂ ⊢ b : B".to_string()],
conclusion: format!("Γ₁, Γ₂ ⊢ (a {} b) : A {} B", tensor_spec.symbol, tensor_spec.symbol),
side_conditions: vec!["Γ₁ ∩ Γ₂ = ∅".to_string()],
});
ts.typing_rules.push(TypingRule {
name: "tensor-elim".to_string(),
premises: vec![
format!("Γ₁ ⊢ p : A {} B", tensor_spec.symbol),
"Γ₂, x:A, y:B ⊢ e : C".to_string(),
],
conclusion: "Γ₁, Γ₂ ⊢ let (x ⊗ y) = p in e : C".to_string(),
side_conditions: vec![],
});
ts.reduction_rules.push(ReductionRule {
name: "tensor-beta".to_string(),
lhs: "let (x ⊗ y) = (a ⊗ b) in e".to_string(),
rhs: "e[a/x, b/y]".to_string(),
});
}
fn derive_linear_hom(ts: &mut TypeSystem) {
ts.type_constructors.push(TypeConstructor {
name: "Lollipop".to_string(),
symbol: "⊸".to_string(),
arity: 2,
kind: TypeConstructorKind::LinearHom,
});
ts.term_constructors.push(TermConstructor {
name: "linear-abs".to_string(),
symbol: "λ°_._ ".to_string(),
kind: TermConstructorKind::Abstraction,
});
ts.typing_rules.push(TypingRule {
name: "linear-abs-intro".to_string(),
premises: vec!["Γ, x:A ⊢ e : B".to_string()],
conclusion: "Γ ⊢ λx.e : A ⊸ B".to_string(),
side_conditions: vec!["x appears exactly once in e".to_string()],
});
ts.typing_rules.push(TypingRule {
name: "linear-app-elim".to_string(),
premises: vec!["Γ₁ ⊢ f : A ⊸ B".to_string(), "Γ₂ ⊢ a : A".to_string()],
conclusion: "Γ₁, Γ₂ ⊢ f a : B".to_string(),
side_conditions: vec!["Γ₁ ∩ Γ₂ = ∅".to_string()],
});
ts.reduction_rules.push(ReductionRule {
name: "linear-beta".to_string(),
lhs: "(λx.e) a".to_string(),
rhs: "e[a/x]".to_string(),
});
ts.equivalence_rules.push(EquivalenceRule {
name: "linear-eta".to_string(),
lhs: "f".to_string(),
rhs: "λx.f x".to_string(),
condition: Some("f : A ⊸ B, x fresh".to_string()),
});
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use csw_core::CategoryBuilder;

    /// Names of every type constructor in `ts`, in the order they were derived.
    fn type_names(ts: &TypeSystem) -> Vec<&str> {
        ts.type_constructors
            .iter()
            .map(|tc| tc.name.as_str())
            .collect()
    }

    /// A cartesian closed category with one base type yields the simply-typed
    /// lambda calculus. It has all three structural rules and the cartesian
    /// type formers.
    #[test]
    fn test_derive_stlc() {
        let spec = CategoryBuilder::new("STLC")
            .with_base("Int")
            .with_terminal()
            .with_products()
            .with_exponentials()
            .cartesian()
            .build()
            .unwrap();
        let derived = Deriver::derive(&spec);

        assert_eq!(derived.name, "STLC");

        let rules = &derived.structural;
        assert!(rules.weakening);
        assert!(rules.contraction);
        assert!(rules.exchange);

        let names = type_names(&derived);
        for expected in &["Int", "Unit", "Product", "Arrow"] {
            assert!(names.contains(expected));
        }
    }

    /// A symmetric monoidal closed category that is linear keeps exchange. It
    /// drops weakening and contraction and exposes `⊗` and `⊸`.
    #[test]
    fn test_derive_linear() {
        let spec = CategoryBuilder::new("Linear")
            .with_base("Int")
            .with_tensor()
            .with_linear_hom()
            .with_symmetry()
            .linear()
            .build()
            .unwrap();
        let derived = Deriver::derive(&spec);

        assert_eq!(derived.name, "Linear");

        let rules = &derived.structural;
        assert!(!rules.weakening);
        assert!(!rules.contraction);
        assert!(rules.exchange);

        let names = type_names(&derived);
        for expected in &["Tensor", "Lollipop"] {
            assert!(names.contains(expected));
        }
    }

    /// The `Display` rendering names the system and marks weakening as enabled.
    #[test]
    fn test_print_type_system() {
        let spec = CategoryBuilder::new("STLC")
            .with_terminal()
            .with_products()
            .with_exponentials()
            .cartesian()
            .build()
            .unwrap();
        let rendered = Deriver::derive(&spec).to_string();

        assert!(rendered.contains("STLC Type System"));
        assert!(rendered.contains("Weakening: ✓"));
    }
}