use std::sync::Arc;
use rustc_hash::FxHashMap;
use crate::eq::Term;
/// A sort expression: either a bare sort name or a named sort applied to
/// argument terms.
///
/// Equality and hashing are implemented by hand in this file in terms of
/// [`SortExpr::head`] and [`SortExpr::args`], so `Name(n)` and
/// `App { name: n, args: vec![] }` compare equal and hash identically.
/// [`SortExpr::app`] and [`SortExpr::normalize`] produce the `Name` spelling
/// for the zero-argument case.
///
/// Serialization is derived with `#[serde(untagged)]`, so no variant tag is
/// written. Deserialization is a hand-written impl that routes the struct
/// form through [`SortExpr::app`].
#[derive(Debug, Clone, Eq, serde::Serialize)]
#[serde(untagged)]
pub enum SortExpr {
/// A sort referenced by name only, with no arguments.
Name(Arc<str>),
/// A named sort applied to argument terms.
App {
/// Head name of the sort being applied.
name: Arc<str>,
/// Argument terms, in positional order.
args: Vec<Term>,
},
}
impl SortExpr {
    /// Smart constructor: builds a sort expression from a head name and its
    /// arguments.
    ///
    /// An empty `args` yields [`SortExpr::Name`]; anything else yields
    /// [`SortExpr::App`].
    #[must_use]
    pub fn app(name: impl Into<Arc<str>>, args: Vec<Term>) -> Self {
        let name = name.into();
        if args.is_empty() {
            return Self::Name(name);
        }
        Self::App { name, args }
    }

    /// Rewrites an `App` with no arguments into the equivalent `Name`.
    ///
    /// Every other value is returned unchanged.
    #[must_use]
    pub fn normalize(self) -> Self {
        match self {
            Self::App { name, args } if args.is_empty() => Self::Name(name),
            already_normal => already_normal,
        }
    }

    /// Returns the head name, regardless of which variant holds it.
    #[must_use]
    pub const fn head(&self) -> &Arc<str> {
        match self {
            Self::Name(head) => head,
            Self::App { name, .. } => name,
        }
    }

    /// Returns the argument terms; a `Name` has the empty slice.
    #[must_use]
    pub fn args(&self) -> &[Term] {
        if let Self::App { args, .. } = self {
            args.as_slice()
        } else {
            &[]
        }
    }

    /// Applies `mapping` to every argument term via `Term::substitute`.
    ///
    /// The head is never touched. A `Name` has no arguments and is returned
    /// as a clone of itself.
    #[must_use]
    pub fn subst(&self, mapping: &FxHashMap<Arc<str>, Term>) -> Self {
        match self {
            Self::Name(_) => self.clone(),
            Self::App { name, args } => {
                let substituted: Vec<Term> =
                    args.iter().map(|arg| arg.substitute(mapping)).collect();
                Self::app(Arc::clone(name), substituted)
            }
        }
    }

    /// Returns `true` when both expressions have the same head and the same
    /// argument list (compared with `Term`'s `==`).
    #[must_use]
    pub fn alpha_eq(&self, other: &Self) -> bool {
        if self.head() != other.head() {
            return false;
        }
        self.args() == other.args()
    }

    /// Like [`SortExpr::alpha_eq`], but each argument is first passed through
    /// `crate::eq::normalize` with the given `rules` and `step_limit`.
    ///
    /// Heads and argument counts are compared before any normalization runs.
    #[must_use]
    pub fn alpha_eq_modulo_rewrites(
        &self,
        other: &Self,
        rules: &[crate::eq::DirectedEquation],
        step_limit: usize,
    ) -> bool {
        let (lhs_args, rhs_args) = (self.args(), other.args());
        if self.head() != other.head() || lhs_args.len() != rhs_args.len() {
            return false;
        }
        let reduce = |terms: &[Term]| {
            terms
                .iter()
                .map(|term| crate::eq::normalize(term, rules, step_limit))
                .collect::<Vec<Term>>()
        };
        // Both sides are fully normalized (left first) before comparing.
        let lhs_normal = reduce(lhs_args);
        let rhs_normal = reduce(rhs_args);
        lhs_normal == rhs_normal
    }

    /// Replaces the head with its image under `sort_map`, if it has one.
    ///
    /// Arguments are cloned unchanged; the result is built through
    /// [`SortExpr::app`], so an empty-argument `App` comes back as a `Name`.
    #[must_use]
    pub fn rename_head(&self, sort_map: &std::collections::HashMap<Arc<str>, Arc<str>>) -> Self {
        let head = self.head();
        let renamed = sort_map
            .get(head)
            .map_or_else(|| Arc::clone(head), Arc::clone);
        match self {
            Self::Name(_) => Self::Name(renamed),
            Self::App { args, .. } => Self::app(renamed, args.clone()),
        }
    }

    /// Renames the head through `sort_map` and every argument through
    /// `Term::rename_ops` with `op_map`.
    ///
    /// Names missing from `sort_map` are kept as they are.
    #[must_use]
    pub fn apply_maps(
        &self,
        sort_map: &std::collections::HashMap<Arc<str>, Arc<str>>,
        op_map: &std::collections::HashMap<Arc<str>, Arc<str>>,
    ) -> Self {
        let head = self.head();
        let renamed = sort_map
            .get(head)
            .map_or_else(|| Arc::clone(head), Arc::clone);
        match self {
            Self::Name(_) => Self::Name(renamed),
            Self::App { args, .. } => {
                let mapped_args: Vec<Term> =
                    args.iter().map(|arg| arg.rename_ops(op_map)).collect();
                Self::app(renamed, mapped_args)
            }
        }
    }
}
/// Builds a variable renaming that sends each domain parameter to the target
/// parameter at the same position.
///
/// Pairs whose names already agree are left out, so an identity renaming is
/// the empty map. The two inputs are zipped, so surplus entries on the longer
/// side are ignored. If a domain name occurs more than once, the last
/// differing pair wins.
#[must_use]
pub fn positional_param_rename<I, J>(
    domain_params: I,
    target_params: J,
) -> FxHashMap<Arc<str>, Term>
where
    I: IntoIterator<Item = Arc<str>>,
    J: IntoIterator<Item = Arc<str>>,
{
    domain_params
        .into_iter()
        .zip(target_params)
        .filter(|(from, to)| from != to)
        .map(|(from, to)| (from, Term::Var(to)))
        .collect()
}
/// Decides whether two operation signatures agree up to a positional renaming
/// of their input parameters.
///
/// The signatures must have the same number of inputs. Left parameter names
/// are renamed to the right names at the same position (see
/// [`positional_param_rename`]). After that renaming, each pair of inputs must
/// carry equal implicitness flags and alpha-equal sorts, and the output sorts
/// must be alpha-equal as well.
#[must_use]
pub fn signatures_equivalent_modulo_param_rename(
    lhs_inputs: &[(Arc<str>, SortExpr, crate::op::Implicit)],
    lhs_output: &SortExpr,
    rhs_inputs: &[(Arc<str>, SortExpr, crate::op::Implicit)],
    rhs_output: &SortExpr,
) -> bool {
    if lhs_inputs.len() != rhs_inputs.len() {
        return false;
    }
    let rename = positional_param_rename(
        lhs_inputs.iter().map(|(name, ..)| Arc::clone(name)),
        rhs_inputs.iter().map(|(name, ..)| Arc::clone(name)),
    );
    // Implicitness is compared before the (more expensive) sort check.
    let inputs_agree = lhs_inputs.iter().zip(rhs_inputs).all(
        |((_, left_sort, left_imp), (_, right_sort, right_imp))| {
            left_imp == right_imp && left_sort.subst(&rename).alpha_eq(right_sort)
        },
    );
    inputs_agree && lhs_output.subst(&rename).alpha_eq(rhs_output)
}
/// Decides whether two sort-parameter lists agree up to a positional renaming
/// of the parameter names.
///
/// The lists must have equal length. Left names are renamed to the right
/// names at the same position (see [`positional_param_rename`]), and each
/// left parameter sort must then be alpha-equal to its right counterpart.
#[must_use]
pub fn sort_params_equivalent_modulo_rename(lhs: &[SortParam], rhs: &[SortParam]) -> bool {
    if lhs.len() != rhs.len() {
        return false;
    }
    let rename = positional_param_rename(
        lhs.iter().map(|param| Arc::clone(&param.name)),
        rhs.iter().map(|param| Arc::clone(&param.name)),
    );
    for (left, right) in lhs.iter().zip(rhs) {
        if !left.sort.subst(&rename).alpha_eq(&right.sort) {
            return false;
        }
    }
    true
}
/// Equality is defined on head and arguments rather than on the variant, so
/// `Name(n)` equals `App { name: n, args: vec![] }`.
impl PartialEq for SortExpr {
    fn eq(&self, other: &Self) -> bool {
        if self.head() != other.head() {
            return false;
        }
        self.args() == other.args()
    }
}
/// Hashes the head followed by the argument slice, mirroring the hand-written
/// `PartialEq` so that equal values (including `Name` versus an empty `App`)
/// hash identically.
impl std::hash::Hash for SortExpr {
    fn hash<H>(&self, state: &mut H)
    where
        H: std::hash::Hasher,
    {
        std::hash::Hash::hash(self.head(), state);
        std::hash::Hash::hash(self.args(), state);
    }
}
/// Deserializes either a bare name or a `{ name, args }` struct.
///
/// The struct form is routed through [`SortExpr::app`], so an input with an
/// empty `args` list comes back as [`SortExpr::Name`].
impl<'de> serde::Deserialize<'de> for SortExpr {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Mirror of the public enum that takes the derived, untagged
        // deserializer; the conversion below applies the normalization.
        #[derive(serde::Deserialize)]
        #[serde(untagged)]
        enum Raw {
            Name(Arc<str>),
            App { name: Arc<str>, args: Vec<Term> },
        }
        Raw::deserialize(deserializer).map(|raw| match raw {
            Raw::Name(name) => Self::Name(name),
            Raw::App { name, args } => Self::app(name, args),
        })
    }
}
/// Renders a sort expression as `Head` or `Head(arg1, arg2, ...)`.
///
/// An `App` with no arguments is rendered as the bare head, exactly like the
/// `Name` it compares equal to. Values that are `==` (and hash identically)
/// therefore never print differently.
impl std::fmt::Display for SortExpr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Name(n) => f.write_str(n),
            // Un-normalized spelling of a plain name: print it like `Name`
            // instead of emitting a stray `()`.
            Self::App { name, args } if args.is_empty() => f.write_str(name),
            Self::App { name, args } => {
                f.write_str(name)?;
                f.write_str("(")?;
                for (i, a) in args.iter().enumerate() {
                    if i > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{a}")?;
                }
                f.write_str(")")
            }
        }
    }
}
/// Human-readable rendering of a term.
///
/// * variable: its name
/// * application: `op(arg1, arg2)`; with no arguments, `op()`
/// * hole: `?name`, or `?` when unnamed
/// * let: `let name = bound in body`
/// * case: `case scrutinee of C1(b1, b2) => body1 | C2() => body2 end`
impl std::fmt::Display for Term {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Var(var) => f.write_str(var),
            Self::App { op, args } => {
                // An empty argument list falls out naturally as `op()`.
                write!(f, "{op}(")?;
                let mut sep = "";
                for arg in args.iter() {
                    write!(f, "{sep}{arg}")?;
                    sep = ", ";
                }
                f.write_str(")")
            }
            Self::Hole { name: Some(label) } => write!(f, "?{label}"),
            Self::Hole { name: None } => f.write_str("?"),
            Self::Let { name, bound, body } => {
                write!(f, "let {name} = {bound} in {body}")
            }
            Self::Case {
                scrutinee,
                branches,
            } => {
                write!(f, "case {scrutinee} of ")?;
                let mut branch_sep = "";
                for branch in branches.iter() {
                    f.write_str(branch_sep)?;
                    branch_sep = " | ";
                    write!(f, "{}(", branch.constructor)?;
                    let mut binder_sep = "";
                    for binder in branch.binders.iter() {
                        f.write_str(binder_sep)?;
                        f.write_str(binder)?;
                        binder_sep = ", ";
                    }
                    write!(f, ") => {}", branch.body)?;
                }
                f.write_str(" end")
            }
        }
    }
}
/// A string slice becomes a bare sort name.
impl From<&str> for SortExpr {
    fn from(s: &str) -> Self {
        Self::Name(s.into())
    }
}

/// An owned string becomes a bare sort name.
impl From<String> for SortExpr {
    fn from(s: String) -> Self {
        Self::Name(s.into())
    }
}

/// A shared string is used directly as the sort name.
impl From<Arc<str>> for SortExpr {
    fn from(s: Arc<str>) -> Self {
        Self::Name(s)
    }
}

/// A borrowed shared string is reference-count cloned into a sort name.
impl From<&Arc<str>> for SortExpr {
    fn from(s: &Arc<str>) -> Self {
        Self::from(Arc::clone(s))
    }
}
/// Classification of a coercion; classes are combined with
/// [`CoercionClass::compose`].
///
/// `Iso` is the `Default` and the identity for composition; `Opaque` absorbs
/// every other class. The hand-written `Ord` impl ranks the variants
/// `Iso < Retraction < Projection < Opaque`.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, Default, serde::Serialize, serde::Deserialize,
)]
#[non_exhaustive]
pub enum CoercionClass {
/// The only class reported lossless by `is_lossless`; needs no complement
/// storage.
#[default]
Iso,
/// Needs complement storage; composing two of these stays `Retraction`.
Retraction,
/// Does not need complement storage; composing two of these stays
/// `Projection`, while composing with `Retraction` gives `Opaque`.
Projection,
/// Needs complement storage; composing it with anything gives `Opaque`.
Opaque,
}
impl CoercionClass {
    /// Combines the classes of two coercions into the class of their
    /// composite.
    ///
    /// `Iso` is the identity, `Opaque` absorbs everything, `Retraction` and
    /// `Projection` are each closed under self-composition, and mixing the
    /// two yields `Opaque`. The operation is symmetric in its arguments.
    #[must_use]
    pub const fn compose(self, other: Self) -> Self {
        match (self, other) {
            (Self::Iso, rhs) => rhs,
            (lhs, Self::Iso) => lhs,
            (Self::Retraction, Self::Retraction) => Self::Retraction,
            (Self::Projection, Self::Projection) => Self::Projection,
            (Self::Opaque, _)
            | (_, Self::Opaque)
            | (Self::Retraction, Self::Projection)
            | (Self::Projection, Self::Retraction) => Self::Opaque,
        }
    }

    /// Returns `true` only for `Iso`.
    #[must_use]
    pub const fn is_lossless(self) -> bool {
        match self {
            Self::Iso => true,
            Self::Retraction | Self::Projection | Self::Opaque => false,
        }
    }

    /// Returns `true` for `Retraction` and `Opaque`, `false` for `Iso` and
    /// `Projection`.
    #[must_use]
    pub const fn needs_complement_storage(self) -> bool {
        match self {
            Self::Retraction | Self::Opaque => true,
            Self::Iso | Self::Projection => false,
        }
    }

    /// Returns every variant, in declaration order.
    #[must_use]
    pub const fn all() -> &'static [Self] {
        // Never called: this match stops compiling when a variant is added,
        // which is the reminder to extend the list below.
        const fn _covers_every_variant(class: CoercionClass) {
            match class {
                CoercionClass::Iso => {}
                CoercionClass::Retraction => {}
                CoercionClass::Projection => {}
                CoercionClass::Opaque => {}
            }
        }
        const EVERY_CLASS: &[CoercionClass] = &[
            CoercionClass::Iso,
            CoercionClass::Retraction,
            CoercionClass::Projection,
            CoercionClass::Opaque,
        ];
        EVERY_CLASS
    }
}
/// Partial order for [`CoercionClass`]: always `Some`, delegating to the total
/// order defined by the `Ord` impl so the two can never disagree.
impl PartialOrd for CoercionClass {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for CoercionClass {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
const fn rank(c: CoercionClass) -> u8 {
match c {
CoercionClass::Iso => 0,
CoercionClass::Retraction => 1,
CoercionClass::Projection => 2,
CoercionClass::Opaque => 3,
}
}
rank(*self).cmp(&rank(*other))
}
}
/// Classifies a builtin operation as a coercion between value kinds.
///
/// Returns `(from, to, class)` for the six conversion builtins handled below
/// and `None` for every other operation.
#[must_use]
pub const fn classify_builtin_coercion(
    op: panproto_expr::BuiltinOp,
) -> Option<(ValueKind, ValueKind, CoercionClass)> {
    use panproto_expr::BuiltinOp;
    let (from, to, class) = match op {
        BuiltinOp::IntToFloat => (ValueKind::Int, ValueKind::Float, CoercionClass::Retraction),
        BuiltinOp::FloatToInt => (ValueKind::Float, ValueKind::Int, CoercionClass::Opaque),
        BuiltinOp::IntToStr => (ValueKind::Int, ValueKind::Str, CoercionClass::Retraction),
        BuiltinOp::FloatToStr => (ValueKind::Float, ValueKind::Str, CoercionClass::Opaque),
        BuiltinOp::StrToInt => (ValueKind::Str, ValueKind::Int, CoercionClass::Opaque),
        BuiltinOp::StrToFloat => (ValueKind::Str, ValueKind::Float, CoercionClass::Opaque),
        // Anything else is not a coercion.
        _ => return None,
    };
    Some((from, to, class))
}
/// The kind of a primitive value.
///
/// Each variant has a fixed string label, returned by `ValueKind::as_str`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum ValueKind {
/// Labelled `"boolean"`.
Bool,
/// Labelled `"integer"`.
Int,
/// Labelled `"number"`.
Float,
/// Labelled `"string"`.
Str,
/// Labelled `"bytes"`.
Bytes,
/// Labelled `"token"`.
Token,
/// Labelled `"null"`.
Null,
/// Labelled `"any"`.
Any,
}
impl ValueKind {
    /// Returns the fixed string label of this kind (for example `"integer"`
    /// for `Int` and `"number"` for `Float`).
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            ValueKind::Bool => "boolean",
            ValueKind::Int => "integer",
            ValueKind::Float => "number",
            ValueKind::Str => "string",
            ValueKind::Bytes => "bytes",
            ValueKind::Token => "token",
            ValueKind::Null => "null",
            ValueKind::Any => "any",
        }
    }

    /// Returns every variant, in declaration order.
    #[must_use]
    pub const fn all() -> &'static [Self] {
        // Never called: this match stops compiling when a variant is added,
        // which is the reminder to extend the list below.
        const fn _covers_every_variant(kind: ValueKind) {
            match kind {
                ValueKind::Bool => {}
                ValueKind::Int => {}
                ValueKind::Float => {}
                ValueKind::Str => {}
                ValueKind::Bytes => {}
                ValueKind::Token => {}
                ValueKind::Null => {}
                ValueKind::Any => {}
            }
        }
        const EVERY_KIND: &[ValueKind] = &[
            ValueKind::Bool,
            ValueKind::Int,
            ValueKind::Float,
            ValueKind::Str,
            ValueKind::Bytes,
            ValueKind::Token,
            ValueKind::Null,
            ValueKind::Any,
        ];
        EVERY_KIND
    }
}
/// The kind of a [`Sort`]; `Structural` is the default.
///
/// `Sort::default_vertex_kind` treats `Val` specially (it returns the value
/// kind's string label) and falls back to the sort's own name for every other
/// variant.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum SortKind {
/// Default kind, used by `Sort::simple`, `Sort::dependent` and
/// `Sort::closed`.
#[default]
Structural,
/// A sort tied to a primitive value kind.
Val(ValueKind),
/// A sort carrying a coercion description.
Coercion {
/// Value kind on the source side (presumably the kind converted from).
from: ValueKind,
/// Value kind on the target side (presumably the kind converted to).
to: ValueKind,
/// Classification of the coercion.
class: CoercionClass,
},
/// A merger sort parameterized by a value kind.
/// NOTE(review): merge semantics are not visible in this file.
Merger(ValueKind),
}
/// A named parameter of a dependent [`Sort`], together with the sort
/// expression it ranges over.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct SortParam {
/// Parameter name; later parameter sorts may mention it as a variable.
pub name: Arc<str>,
/// Sort of the parameter.
pub sort: SortExpr,
}
/// Whether a [`Sort`] carries an explicit constructor list; `Open` is the
/// default.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum SortClosure {
/// No constructor list is recorded.
#[default]
Open,
/// The constructor names supplied to `Sort::closed`, in the order given.
Closed(Vec<Arc<str>>),
}
/// A sort declaration: a name, an optional list of parameters, a kind and a
/// closure marker.
///
/// `kind` and `closure` fall back to their `Default` values when absent from
/// serialized input.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct Sort {
/// Name of the sort.
pub name: Arc<str>,
/// Parameters of the sort; empty for a simple (non-dependent) sort.
pub params: Vec<SortParam>,
/// Kind of the sort; defaults to `SortKind::Structural` when missing.
#[serde(default)]
pub kind: SortKind,
/// Closure marker; defaults to `SortClosure::Open` when missing.
#[serde(default)]
pub closure: SortClosure,
}
impl Sort {
    /// Creates a parameterless, open sort of the default kind.
    #[must_use]
    pub fn simple(name: impl Into<Arc<str>>) -> Self {
        Self::dependent(name, Vec::new())
    }

    /// Creates an open sort of the default kind with the given parameters.
    #[must_use]
    pub fn dependent(name: impl Into<Arc<str>>, params: Vec<SortParam>) -> Self {
        let name = name.into();
        let (kind, closure) = (SortKind::default(), SortClosure::Open);
        Self {
            name,
            params,
            kind,
            closure,
        }
    }

    /// Creates a parameterless, open sort of the given kind.
    #[must_use]
    pub fn with_kind(name: impl Into<Arc<str>>, kind: SortKind) -> Self {
        let mut sort = Self::simple(name);
        sort.kind = kind;
        sort
    }

    /// Creates a sort of the default kind whose closure records the given
    /// constructor names, in iteration order.
    #[must_use]
    pub fn closed<I, S>(name: impl Into<Arc<str>>, params: Vec<SortParam>, constructors: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<Arc<str>>,
    {
        let mut sort = Self::dependent(name, params);
        let names: Vec<Arc<str>> = constructors.into_iter().map(Into::into).collect();
        sort.closure = SortClosure::Closed(names);
        sort
    }

    /// Returns the value kind's string label for a `Val` sort, and the sort's
    /// own name for every other kind.
    #[must_use]
    pub fn default_vertex_kind(&self) -> Arc<str> {
        // Every variant is listed so a new `SortKind` forces a decision here.
        match &self.kind {
            SortKind::Structural | SortKind::Coercion { .. } | SortKind::Merger(_) => {
                Arc::clone(&self.name)
            }
            SortKind::Val(value_kind) => Arc::from(value_kind.as_str()),
        }
    }

    /// Returns `true` when the sort has no parameters.
    #[must_use]
    pub fn is_simple(&self) -> bool {
        self.arity() == 0
    }

    /// Returns the number of parameters.
    #[must_use]
    pub fn arity(&self) -> usize {
        self.params.len()
    }
}
impl SortParam {
    /// Creates a parameter from anything convertible into a name and a sort
    /// expression (for example two string literals).
    #[must_use]
    pub fn new(name: impl Into<Arc<str>>, sort: impl Into<SortExpr>) -> Self {
        let name = name.into();
        let sort = sort.into();
        Self { name, sort }
    }
}
// Unit and property tests for sort expressions, coercion classes and the
// signature-equivalence helpers defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// A simple sort has no parameters and arity zero.
#[test]
fn simple_sort() {
let s = Sort::simple("Vertex");
assert!(s.is_simple());
assert_eq!(s.arity(), 0);
assert_eq!(&*s.name, "Vertex");
}
// A dependent sort reports its parameter count as its arity.
#[test]
fn dependent_sort() {
let s = Sort::dependent(
"Hom",
vec![SortParam::new("a", "Ob"), SortParam::new("b", "Ob")],
);
assert!(!s.is_simple());
assert_eq!(s.arity(), 2);
}
// `&str` converts into a bare `Name`.
#[test]
fn sort_expr_from_str() {
let e: SortExpr = "Ob".into();
assert_eq!(e, SortExpr::Name(Arc::from("Ob")));
assert_eq!(&**e.head(), "Ob");
}
// `head` and `args` read through the `App` variant.
#[test]
fn sort_expr_app_head() {
let e = SortExpr::App {
name: Arc::from("Hom"),
args: vec![Term::var("x"), Term::var("y")],
};
assert_eq!(&**e.head(), "Hom");
assert_eq!(e.args().len(), 2);
}
// `alpha_eq` identifies a `Name` with an argument-less `App`, in both directions.
#[test]
fn sort_expr_alpha_eq_name_vs_empty_app() {
let a = SortExpr::Name(Arc::from("Ob"));
let b = SortExpr::App {
name: Arc::from("Ob"),
args: Vec::new(),
};
assert!(a.alpha_eq(&b));
assert!(b.alpha_eq(&a));
}
// `alpha_eq` is structural: argument order matters.
#[test]
fn sort_expr_alpha_eq_structural() {
let a = SortExpr::App {
name: Arc::from("Hom"),
args: vec![Term::var("x"), Term::var("y")],
};
let b = SortExpr::App {
name: Arc::from("Hom"),
args: vec![Term::var("x"), Term::var("y")],
};
let c = SortExpr::App {
name: Arc::from("Hom"),
args: vec![Term::var("y"), Term::var("x")],
};
assert!(a.alpha_eq(&b));
assert!(!a.alpha_eq(&c));
}
// Substitution replaces variables inside the arguments and keeps the head.
#[test]
fn sort_expr_subst() {
let e = SortExpr::App {
name: Arc::from("Hom"),
args: vec![Term::var("x"), Term::var("y")],
};
let mut mapping: FxHashMap<Arc<str>, Term> = FxHashMap::default();
mapping.insert(Arc::from("x"), Term::constant("a"));
mapping.insert(Arc::from("y"), Term::constant("b"));
let result = e.subst(&mapping);
assert_eq!(
result,
SortExpr::App {
name: Arc::from("Hom"),
args: vec![Term::constant("a"), Term::constant("b")],
}
);
}
// A bare name has nothing to substitute into.
#[test]
fn sort_expr_subst_name_unchanged() {
let e = SortExpr::Name(Arc::from("Ob"));
let mut mapping: FxHashMap<Arc<str>, Term> = FxHashMap::default();
mapping.insert(Arc::from("x"), Term::constant("a"));
assert_eq!(e.subst(&mapping), e);
}
// Wire format: a `Name` is a bare JSON string and round-trips.
#[test]
fn sort_expr_serde_name_is_bare_string() -> Result<(), Box<dyn std::error::Error>> {
let e = SortExpr::Name(Arc::from("Ob"));
let s = serde_json::to_string(&e)?;
assert_eq!(s, "\"Ob\"");
let back: SortExpr = serde_json::from_str(&s)?;
assert_eq!(back, e);
Ok(())
}
// Wire format: an `App` round-trips through JSON.
#[test]
fn sort_expr_serde_app_is_struct() -> Result<(), Box<dyn std::error::Error>> {
let e = SortExpr::App {
name: Arc::from("Hom"),
args: vec![Term::var("x"), Term::var("y")],
};
let s = serde_json::to_string(&e)?;
let back: SortExpr = serde_json::from_str(&s)?;
assert_eq!(back, e);
Ok(())
}
// `Display` output for an application and for a bare name.
#[test]
fn sort_expr_display() {
let e = SortExpr::App {
name: Arc::from("Hom"),
args: vec![Term::var("x"), Term::var("y")],
};
assert_eq!(format!("{e}"), "Hom(x, y)");
let n = SortExpr::Name(Arc::from("Ob"));
assert_eq!(format!("{n}"), "Ob");
}
// Every coercion class, used to exhaustively check the algebraic laws below.
const ALL_CLASSES: [CoercionClass; 4] = [
CoercionClass::Iso,
CoercionClass::Retraction,
CoercionClass::Projection,
CoercionClass::Opaque,
];
// `Iso` is a two-sided identity for `compose`.
#[test]
fn coercion_class_identity() {
for &x in &ALL_CLASSES {
assert_eq!(CoercionClass::Iso.compose(x), x, "Iso is left identity");
assert_eq!(x.compose(CoercionClass::Iso), x, "Iso is right identity");
}
}
// `Opaque` is a two-sided absorbing element for `compose`.
#[test]
fn coercion_class_absorption() {
for &x in &ALL_CLASSES {
assert_eq!(
CoercionClass::Opaque.compose(x),
CoercionClass::Opaque,
"Opaque absorbs on left"
);
assert_eq!(
x.compose(CoercionClass::Opaque),
CoercionClass::Opaque,
"Opaque absorbs on right"
);
}
}
// `compose` is associative over all 64 triples.
#[test]
fn coercion_class_associativity() {
for &a in &ALL_CLASSES {
for &b in &ALL_CLASSES {
for &c in &ALL_CLASSES {
assert_eq!(
a.compose(b).compose(c),
a.compose(b.compose(c)),
"associativity: ({a:?} . {b:?}) . {c:?} == {a:?} . ({b:?} . {c:?})"
);
}
}
}
}
// `compose` is commutative over all 16 pairs.
#[test]
fn coercion_class_commutativity() {
for &a in &ALL_CLASSES {
for &b in &ALL_CLASSES {
assert_eq!(
a.compose(b),
b.compose(a),
"commutativity: {a:?} . {b:?} == {b:?} . {a:?}"
);
}
}
}
// The `Ord` ranking never places a composite below either of its operands.
#[test]
fn coercion_class_ordering_consistent_with_compose() {
for &a in &ALL_CLASSES {
for &b in &ALL_CLASSES {
let composed = a.compose(b);
assert!(
composed >= a,
"compose({a:?}, {b:?}) = {composed:?} should be >= {a:?}"
);
assert!(
composed >= b,
"compose({a:?}, {b:?}) = {composed:?} should be >= {b:?}"
);
}
}
}
// The six conversion builtins are classified; arithmetic/string ops are not.
#[test]
fn classify_builtin_coercion_coverage() {
use panproto_expr::BuiltinOp;
let coercion_ops = [
BuiltinOp::IntToFloat,
BuiltinOp::FloatToInt,
BuiltinOp::IntToStr,
BuiltinOp::FloatToStr,
BuiltinOp::StrToInt,
BuiltinOp::StrToFloat,
];
for op in coercion_ops {
assert!(
classify_builtin_coercion(op).is_some(),
"{op:?} should be classified"
);
}
assert!(classify_builtin_coercion(BuiltinOp::Add).is_none());
assert!(classify_builtin_coercion(BuiltinOp::Concat).is_none());
}
// No builtin conversion is classified as `Iso`.
#[test]
fn no_builtin_classified_as_iso() {
use panproto_expr::BuiltinOp;
let coercion_ops = [
BuiltinOp::IntToFloat,
BuiltinOp::FloatToInt,
BuiltinOp::IntToStr,
BuiltinOp::FloatToStr,
BuiltinOp::StrToInt,
BuiltinOp::StrToFloat,
];
for op in coercion_ops {
if let Some((_, _, class)) = classify_builtin_coercion(op) {
assert_ne!(
class,
CoercionClass::Iso,
"{op:?} should not be classified as Iso"
);
}
}
}
// Pins which classes report that they need complement storage.
#[test]
fn needs_complement_storage_consistent_with_lattice() {
assert!(
!CoercionClass::Iso.needs_complement_storage(),
"Iso: lossless, no storage"
);
assert!(
CoercionClass::Retraction.needs_complement_storage(),
"Retraction: stores residual"
);
assert!(
!CoercionClass::Projection.needs_complement_storage(),
"Projection: derived value re-computed, no storage"
);
assert!(
CoercionClass::Opaque.needs_complement_storage(),
"Opaque: stores entire original"
);
}
// Spot checks of `compose` on the pairs that involve `Projection`.
#[test]
fn projection_compose_laws() {
assert_eq!(
CoercionClass::Projection.compose(CoercionClass::Projection),
CoercionClass::Projection,
"Projection . Projection = Projection (projections compose)"
);
assert_eq!(
CoercionClass::Retraction.compose(CoercionClass::Projection),
CoercionClass::Opaque,
"Retraction . Projection = Opaque (diamond lattice meet)"
);
assert_eq!(
CoercionClass::Projection.compose(CoercionClass::Retraction),
CoercionClass::Opaque,
"Projection . Retraction = Opaque (commutativity of meet)"
);
assert_eq!(
CoercionClass::Iso.compose(CoercionClass::Projection),
CoercionClass::Projection,
);
assert_eq!(
CoercionClass::Opaque.compose(CoercionClass::Projection),
CoercionClass::Opaque,
);
}
// `normalize` turns an argument-less `App` into a `Name` and is idempotent.
#[test]
fn empty_args_app_normalizes_to_name() {
let raw = SortExpr::App {
name: Arc::from("Ob"),
args: Vec::new(),
};
let n = raw.normalize();
assert!(matches!(n, SortExpr::Name(ref s) if &**s == "Ob"));
assert_eq!(n.clone().normalize(), n);
}
// `SortExpr::app` with no arguments builds a `Name`.
#[test]
fn smart_constructor_collapses_empty_args() {
let v = SortExpr::app("Ob", Vec::new());
assert!(matches!(v, SortExpr::Name(_)));
}
// `SortExpr::app` with arguments builds an `App`.
#[test]
fn smart_constructor_preserves_nonempty() {
let v = SortExpr::app("Hom", vec![Term::var("x"), Term::var("y")]);
assert!(matches!(v, SortExpr::App { .. }));
}
// `==` identifies a `Name` with an argument-less `App`.
#[test]
fn eq_treats_name_and_empty_app_equal() {
let a = SortExpr::Name(Arc::from("Ob"));
let b = SortExpr::App {
name: Arc::from("Ob"),
args: Vec::new(),
};
assert_eq!(a, b);
}
// Hashing agrees with `==` across the two spellings.
#[test]
fn hash_agrees_with_eq_across_spellings() {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let a = SortExpr::Name(Arc::from("Ob"));
let b = SortExpr::App {
name: Arc::from("Ob"),
args: Vec::new(),
};
let hash = |v: &SortExpr| {
let mut h = DefaultHasher::new();
v.hash(&mut h);
h.finish()
};
assert_eq!(hash(&a), hash(&b));
}
// A map keyed by one spelling can be queried with the other.
#[test]
fn hashmap_lookup_crosses_spellings() {
let mut m: FxHashMap<SortExpr, usize> = FxHashMap::default();
m.insert(SortExpr::Name(Arc::from("Ob")), 1);
let key = SortExpr::App {
name: Arc::from("Ob"),
args: Vec::new(),
};
assert_eq!(m.get(&key).copied(), Some(1));
}
// Substituting into a non-empty `App` still yields an `App`.
#[test]
fn subst_produces_normalized_output() {
let e = SortExpr::App {
name: Arc::from("S"),
args: vec![Term::var("x")],
};
let mut mapping: FxHashMap<Arc<str>, Term> = FxHashMap::default();
mapping.insert(Arc::from("x"), Term::constant("c"));
let r = e.subst(&mapping);
assert!(matches!(r, SortExpr::App { .. }));
}
// `rename_head` on an argument-less `App` returns the renamed `Name`.
#[test]
fn rename_head_normalizes_empty_app() {
let e = SortExpr::App {
name: Arc::from("Ob"),
args: Vec::new(),
};
let mut sm: std::collections::HashMap<Arc<str>, Arc<str>> =
std::collections::HashMap::new();
sm.insert(Arc::from("Ob"), Arc::from("Obj"));
let r = e.rename_head(&sm);
assert!(matches!(r, SortExpr::Name(ref n) if &**n == "Obj"));
}
// Deserializing the struct form with empty `args` yields a `Name`.
#[test]
fn deserialize_empty_args_app_normalizes() -> Result<(), Box<dyn std::error::Error>> {
let json = r#"{"name":"Ob","args":[]}"#;
let v: SortExpr = serde_json::from_str(json)?;
assert!(matches!(v, SortExpr::Name(ref n) if &**n == "Ob"));
Ok(())
}
// Identical parameter lists produce an empty renaming.
#[test]
fn positional_rename_identity_is_empty() {
let r = positional_param_rename(
[Arc::from("a"), Arc::from("b")],
[Arc::from("a"), Arc::from("b")],
);
assert!(r.is_empty(), "identity rename should be empty");
}
// Only positions whose names differ appear in the renaming.
#[test]
fn positional_rename_maps_differing_names_only() {
let r = positional_param_rename(
[Arc::from("a"), Arc::from("y"), Arc::from("c")],
[Arc::from("x"), Arc::from("y"), Arc::from("z")],
);
assert_eq!(r.len(), 2);
assert_eq!(r.get(&Arc::from("a")), Some(&Term::var("x")));
assert_eq!(r.get(&Arc::from("c")), Some(&Term::var("z")));
assert!(!r.contains_key(&Arc::from("y")));
}
// Signatures that differ only in the parameter name are equivalent.
#[test]
fn signature_equivalence_accepts_alpha_variant() {
use crate::op::Implicit;
let lhs_inputs = vec![(Arc::from("a"), SortExpr::from("Ob"), Implicit::No)];
let lhs_output = SortExpr::App {
name: Arc::from("Hom"),
args: vec![Term::var("a"), Term::var("a")],
};
let rhs_inputs = vec![(Arc::from("x"), SortExpr::from("Ob"), Implicit::No)];
let rhs_output = SortExpr::App {
name: Arc::from("Hom"),
args: vec![Term::var("x"), Term::var("x")],
};
assert!(signatures_equivalent_modulo_param_rename(
&lhs_inputs,
&lhs_output,
&rhs_inputs,
&rhs_output,
));
}
// Same inputs but swapped output arguments are not equivalent.
#[test]
fn signature_equivalence_rejects_swap() {
use crate::op::Implicit;
let hom = |a: &str, b: &str| SortExpr::App {
name: Arc::from("Hom"),
args: vec![Term::var(a), Term::var(b)],
};
let lhs_inputs = vec![
(Arc::from("x"), SortExpr::from("Ob"), Implicit::No),
(Arc::from("y"), SortExpr::from("Ob"), Implicit::No),
];
let rhs_inputs = lhs_inputs.clone();
assert!(!signatures_equivalent_modulo_param_rename(
&lhs_inputs,
&hom("x", "y"),
&rhs_inputs,
&hom("y", "x"),
));
}
// Different input counts are never equivalent.
#[test]
fn signature_equivalence_rejects_arity_mismatch() {
use crate::op::Implicit;
let lhs_inputs = vec![(Arc::from("x"), SortExpr::from("Ob"), Implicit::No)];
let rhs_inputs: Vec<(Arc<str>, SortExpr, Implicit)> = Vec::new();
assert!(!signatures_equivalent_modulo_param_rename(
&lhs_inputs,
&SortExpr::from("Ob"),
&rhs_inputs,
&SortExpr::from("Ob"),
));
}
// Parameter lists that differ only in names are equivalent.
#[test]
fn sort_params_rename_alpha_equivalent() {
let lhs = vec![SortParam::new("a", "Ob"), SortParam::new("b", "Ob")];
let rhs = vec![SortParam::new("p", "Ob"), SortParam::new("q", "Ob")];
assert!(sort_params_equivalent_modulo_rename(&lhs, &rhs));
}
// NOTE: despite the name, this asserts the two lists ARE equivalent: renaming
// `Gamma` to `G` makes the dependent sort `Ty(Gamma)` match `Ty(G)`.
#[test]
fn sort_params_rename_detects_dependent_difference() {
let lhs = vec![
SortParam::new("Gamma", "Ctx"),
SortParam::new(
"A",
SortExpr::App {
name: Arc::from("Ty"),
args: vec![Term::var("Gamma")],
},
),
];
let rhs = vec![
SortParam::new("G", "Ctx"),
SortParam::new(
"A",
SortExpr::App {
name: Arc::from("Ty"),
args: vec![Term::var("G")],
},
),
];
assert!(sort_params_equivalent_modulo_rename(&lhs, &rhs));
}
// After renaming `Gamma` to `G`, `Ty(G)` still differs from `Ty(A)`.
#[test]
fn sort_params_rename_rejects_genuine_difference() {
let lhs = vec![
SortParam::new("Gamma", "Ctx"),
SortParam::new(
"A",
SortExpr::App {
name: Arc::from("Ty"),
args: vec![Term::var("Gamma")],
},
),
];
let rhs = vec![
SortParam::new("G", "Ctx"),
SortParam::new(
"A",
SortExpr::App {
name: Arc::from("Ty"),
args: vec![Term::var("A")],
},
),
];
assert!(!sort_params_equivalent_modulo_rename(&lhs, &rhs));
}
// Property-based tests over randomly generated sort expressions.
mod property {
use super::*;
use proptest::prelude::*;
// Picks a sort name from a small fixed pool.
fn arb_name() -> impl Strategy<Value = Arc<str>> {
prop::sample::select(&["S", "T", "Hom", "Tm", "Ob"][..]).prop_map(Arc::from)
}
// Generates a term of at most `depth` nested applications over the
// variables x, y, z and the operations f, g.
fn arb_term(depth: usize) -> BoxedStrategy<Term> {
if depth == 0 {
prop::sample::select(&["x", "y", "z"][..])
.prop_map(|s| Term::var(Arc::from(s)))
.boxed()
} else {
let leaf = prop::sample::select(&["x", "y", "z"][..])
.prop_map(|s| Term::var(Arc::from(s)));
let app = (
prop::sample::select(&["f", "g"][..]).prop_map(Arc::from),
prop::collection::vec(arb_term(depth - 1), 0..=2),
)
.prop_map(|(op, args)| Term::App { op, args });
prop_oneof![leaf, app].boxed()
}
}
// Generates either a bare name or an application built with the smart
// constructor (so generated values are always normalized).
fn arb_sort_expr() -> BoxedStrategy<SortExpr> {
prop_oneof![
arb_name().prop_map(SortExpr::Name),
(arb_name(), prop::collection::vec(arb_term(1), 0..=3))
.prop_map(|(name, args)| SortExpr::app(name, args))
]
.boxed()
}
// Generates a substitution with up to three bindings; later duplicate
// keys overwrite earlier ones.
fn arb_subst() -> BoxedStrategy<FxHashMap<Arc<str>, Term>> {
prop::collection::vec(
(
prop::sample::select(&["x", "y", "z"][..]).prop_map(Arc::from),
arb_term(1),
),
0..=3,
)
.prop_map(|pairs| {
let mut m = FxHashMap::default();
for (k, v) in pairs {
m.insert(k, v);
}
m
})
.boxed()
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(256))]
// The empty substitution leaves every expression unchanged.
#[test]
fn subst_empty_is_identity(e in arb_sort_expr()) {
let empty = FxHashMap::default();
prop_assert_eq!(e.subst(&empty), e);
}
// Substitution never changes the head.
#[test]
fn subst_preserves_head(e in arb_sort_expr(), sigma in arb_subst()) {
let after = e.subst(&sigma);
prop_assert_eq!(e.head(), after.head());
}
// Normalizing twice equals normalizing once.
#[test]
fn normalization_is_idempotent(e in arb_sort_expr()) {
let n1 = e.normalize();
let n2 = n1.clone().normalize();
prop_assert_eq!(n1, n2);
}
// Every signature is equivalent to itself.
#[test]
fn sig_equivalence_is_reflexive(
raw_inputs in prop::collection::vec(
(prop::sample::select(&["x", "y", "z"][..]).prop_map(Arc::from), arb_sort_expr()),
0..=3,
),
output in arb_sort_expr(),
) {
let inputs: Vec<(Arc<str>, SortExpr, crate::op::Implicit)> = raw_inputs
.into_iter()
.map(|(n, s)| (n, s, crate::op::Implicit::No))
.collect();
prop_assert!(signatures_equivalent_modulo_param_rename(
&inputs, &output, &inputs, &output,
));
}
// Renaming the single parameter consistently preserves equivalence.
#[test]
fn sig_equivalence_under_alpha_rename(
sort_name in arb_name(),
first in prop::sample::select(&["x", "y", "z"][..]).prop_map(Arc::from),
replacement in prop::sample::select(&["p", "q", "r"][..]).prop_map(Arc::from),
) {
let lhs_inputs: Vec<(Arc<str>, SortExpr, crate::op::Implicit)> = vec![(
Arc::clone(&first),
SortExpr::Name(Arc::clone(&sort_name)),
crate::op::Implicit::No,
)];
let lhs_output = SortExpr::App {
name: Arc::clone(&sort_name),
args: vec![Term::Var(Arc::clone(&first))],
};
let rhs_inputs: Vec<(Arc<str>, SortExpr, crate::op::Implicit)> = vec![(
Arc::clone(&replacement),
SortExpr::Name(Arc::clone(&sort_name)),
crate::op::Implicit::No,
)];
let rhs_output = SortExpr::App {
name: sort_name,
args: vec![Term::Var(replacement)],
};
prop_assert!(signatures_equivalent_modulo_param_rename(
&lhs_inputs, &lhs_output, &rhs_inputs, &rhs_output,
));
}
// For any name, the two spellings are equal and hash identically.
#[test]
fn name_and_empty_app_hash_equal(name in arb_name()) {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let a = SortExpr::Name(Arc::clone(&name));
let b = SortExpr::App { name, args: Vec::new() };
let mut ha = DefaultHasher::new();
a.hash(&mut ha);
let mut hb = DefaultHasher::new();
b.hash(&mut hb);
prop_assert_eq!(ha.finish(), hb.finish());
prop_assert_eq!(a, b);
}
}
}
// Pins the JSON spelling of each coercion class (the variant name as-is).
#[test]
fn coercion_class_serde_wire_format_is_pascal_case() {
let cases = [
(CoercionClass::Iso, "\"Iso\""),
(CoercionClass::Retraction, "\"Retraction\""),
(CoercionClass::Projection, "\"Projection\""),
(CoercionClass::Opaque, "\"Opaque\""),
];
for (value, expected) in cases {
match serde_json::to_string(&value) {
Ok(s) => assert_eq!(s, expected, "unexpected wire format"),
Err(e) => panic!("serde failed to serialize a plain enum: {e}"),
}
}
}
}