tla-checker 0.3.9

A TLA+ model checker written in Rust
Documentation
use std::borrow::Cow;
use std::collections::{BTreeMap, BTreeSet};

use crate::ast::{State, Value};

/// Collection of symmetric sets used to rewrite states into a canonical form.
///
/// Each entry is a set of values that are treated as interchangeable by
/// `canonicalize`: elements of one set may be renamed to other elements of the
/// same set. Sets are applied in the order they were registered.
#[derive(Debug, Clone)]
pub struct SymmetryConfig {
    /// Registered symmetric sets; `add_symmetric_set` only stores sets with
    /// at least two elements.
    symmetric_sets: Vec<BTreeSet<Value>>,
}

impl SymmetryConfig {
    /// Creates a configuration with no symmetric sets registered.
    pub fn new() -> Self {
        Self {
            symmetric_sets: Vec::new(),
        }
    }

    /// Registers `elements` as a set of interchangeable values.
    ///
    /// Sets with fewer than two elements are silently ignored, since there is
    /// nothing to permute in them.
    pub fn add_symmetric_set(&mut self, elements: BTreeSet<Value>) {
        if elements.len() > 1 {
            self.symmetric_sets.push(elements);
        }
    }

    /// Returns `true` when no symmetric set has been registered.
    pub fn is_empty(&self) -> bool {
        self.symmetric_sets.is_empty()
    }

    /// Rewrites `state` by renaming the elements of every registered symmetric set.
    ///
    /// Returns `Cow::Borrowed(state)` when no set is registered. Otherwise the
    /// sets are applied in registration order, each pass working on the output
    /// of the previous one, and the final state is returned as `Cow::Owned`.
    ///
    /// NOTE(review): the renaming is driven by a first-occurrence walk whose
    /// traversal of `Value::Set` follows `BTreeSet` order, i.e. it depends on
    /// element identity. Two states that differ only by a permutation are
    /// therefore not guaranteed to map to the same result — confirm that this
    /// partial reduction is acceptable for callers.
    pub fn canonicalize<'a>(&self, state: &'a State) -> Cow<'a, State> {
        let mut sets = self.symmetric_sets.iter();
        let first = match sets.next() {
            Some(sym_set) => sym_set,
            None => return Cow::Borrowed(state),
        };

        // Every pass builds a fresh `State`, so the first pass can read the
        // borrowed input directly; no up-front clone is needed.
        let canonical = sets.fold(self.canonicalize_for_set(state, first), |current, sym_set| {
            self.canonicalize_for_set(&current, sym_set)
        });
        Cow::Owned(canonical)
    }

    /// Applies one symmetric set to `state`.
    ///
    /// The i-th element in first-occurrence order is renamed to the i-th
    /// element of `sym_set` in its sorted (`BTreeSet`) order.
    fn canonicalize_for_set(&self, state: &State, sym_set: &BTreeSet<Value>) -> State {
        let ordering = self.compute_element_ordering(state, sym_set);

        // `zip` stops at the shorter side, which is exactly what the bounds
        // guard on an indexed loop would achieve.
        let mapping: BTreeMap<Value, Value> = ordering
            .into_iter()
            .zip(sym_set.iter().cloned())
            .collect();

        self.apply_mapping(state, &mapping)
    }

    /// Lists the elements of `sym_set` in the order they are first encountered
    /// while walking `state.values`, followed by the elements that never
    /// occur, in `sym_set`'s own order.
    fn compute_element_ordering(&self, state: &State, sym_set: &BTreeSet<Value>) -> Vec<Value> {
        let mut seen: BTreeSet<Value> = BTreeSet::new();
        let mut ordering = Vec::with_capacity(sym_set.len());

        for value in &state.values {
            self.collect_elements_in_order(value, sym_set, &mut seen, &mut ordering);
        }

        // Elements absent from the state go last, in ascending order.
        ordering.extend(sym_set.difference(&seen).cloned());
        ordering
    }

    /// Depth-first walk over `value` that appends each not-yet-seen member of
    /// `sym_set` to `ordering` (and records it in `seen`).
    ///
    /// Traversal order per variant:
    /// * `Set` / `Tuple` — elements in iteration order;
    /// * `Fn` — entries sorted by value descending, then key ascending, and for
    ///   each entry the value is visited before the key;
    /// * `Record` — field values in the map's iteration order (keys are not visited);
    /// * `Bool` / `Int` / `Str` — nothing to descend into.
    fn collect_elements_in_order(
        &self,
        value: &Value,
        sym_set: &BTreeSet<Value>,
        seen: &mut BTreeSet<Value>,
        ordering: &mut Vec<Value>,
    ) {
        // A first sighting of a symmetric element is recorded and not descended into.
        // An already-seen element deliberately falls through to the match below.
        if sym_set.contains(value) && !seen.contains(value) {
            seen.insert(value.clone());
            ordering.push(value.clone());
            return;
        }

        match value {
            Value::Set(s) => {
                for elem in s {
                    self.collect_elements_in_order(elem, sym_set, seen, ordering);
                }
            }
            Value::Fn(f) => {
                let mut entries: Vec<_> = f.iter().collect();
                entries.sort_by(|a, b| b.1.cmp(a.1).then_with(|| a.0.cmp(b.0)));
                for (k, v) in entries {
                    self.collect_elements_in_order(v, sym_set, seen, ordering);
                    self.collect_elements_in_order(k, sym_set, seen, ordering);
                }
            }
            Value::Record(r) => {
                for v in r.values() {
                    self.collect_elements_in_order(v, sym_set, seen, ordering);
                }
            }
            Value::Tuple(t) => {
                for elem in t {
                    self.collect_elements_in_order(elem, sym_set, seen, ordering);
                }
            }
            Value::Bool(_) | Value::Int(_) | Value::Str(_) => {}
        }
    }

    /// Builds a new state in which every value of `state` has been rewritten
    /// through `mapping`.
    fn apply_mapping(&self, state: &State, mapping: &BTreeMap<Value, Value>) -> State {
        let values = state
            .values
            .iter()
            .map(|value| self.apply_mapping_to_value(value, mapping))
            .collect();
        State { values }
    }

    /// Rewrites a single value through `mapping`.
    ///
    /// A value that is itself a key of `mapping` is replaced wholesale.
    /// Otherwise containers are rebuilt with their contents rewritten
    /// recursively (both keys and values for `Fn`, only values for `Record`),
    /// and scalars are cloned unchanged.
    fn apply_mapping_to_value(&self, value: &Value, mapping: &BTreeMap<Value, Value>) -> Value {
        if let Some(mapped) = mapping.get(value) {
            return mapped.clone();
        }

        let remap = |inner: &Value| self.apply_mapping_to_value(inner, mapping);

        match value {
            Value::Bool(_) | Value::Int(_) | Value::Str(_) => value.clone(),
            Value::Set(s) => Value::Set(s.iter().map(remap).collect()),
            Value::Fn(f) => Value::Fn(f.iter().map(|(k, v)| (remap(k), remap(v))).collect()),
            Value::Record(r) => {
                Value::Record(r.iter().map(|(k, v)| (k.clone(), remap(v))).collect())
            }
            Value::Tuple(t) => Value::Tuple(t.iter().map(remap).collect()),
        }
    }
}

impl Default for SymmetryConfig {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// Wraps a string literal in `Value::Str`.
    fn str_val(s: &str) -> Value {
        Value::Str(Arc::from(s))
    }

    /// Builds a config and offers it one symmetric set made of `names`.
    fn config_with(names: &[&str]) -> SymmetryConfig {
        let mut config = SymmetryConfig::new();
        config.add_symmetric_set(names.iter().map(|&name| str_val(name)).collect());
        config
    }

    /// Builds a state holding exactly one value.
    fn single_value_state(value: Value) -> State {
        State {
            values: vec![value],
        }
    }

    #[test]
    fn empty_symmetry_returns_same_state() {
        let state = single_value_state(Value::Int(42));

        let canonical = SymmetryConfig::new().canonicalize(&state);

        assert_eq!(*canonical, state);
    }

    #[test]
    fn single_element_set_no_change() {
        // A one-element set is offered to the config; the state must come back untouched.
        let config = config_with(&["a"]);
        let state = single_value_state(str_val("a"));

        let canonical = config.canonicalize(&state);

        assert_eq!(*canonical, state);
    }

    #[test]
    fn canonicalize_swaps_elements_based_on_first_occurrence() {
        let config = config_with(&["p1", "p2", "p3"]);
        let votes = BTreeMap::from([
            (str_val("p2"), Value::Int(1)),
            (str_val("p1"), Value::Int(0)),
            (str_val("p3"), Value::Int(0)),
        ]);
        let state = single_value_state(Value::Fn(votes));

        let canonical = config.canonicalize(&state);

        match canonical.values.first() {
            Some(Value::Fn(cvotes)) => {
                // The entry with the highest vote is visited first, so its key becomes "p1".
                assert_eq!(cvotes.get(&str_val("p1")), Some(&Value::Int(1)));
                assert_eq!(cvotes.get(&str_val("p2")), Some(&Value::Int(0)));
                assert_eq!(cvotes.get(&str_val("p3")), Some(&Value::Int(0)));
            }
            _ => panic!("expected Fn value"),
        }
    }

    #[test]
    fn canonicalize_produces_identical_states_for_symmetric_inputs() {
        let config = config_with(&["a", "b"]);
        let state1 = single_value_state(str_val("a"));
        let state2 = single_value_state(str_val("b"));

        let c1 = config.canonicalize(&state1);
        let c2 = config.canonicalize(&state2);

        assert_eq!(c1, c2);
    }

    #[test]
    fn canonicalize_handles_nested_structures() {
        let config = config_with(&["x", "y"]);
        let state = single_value_state(Value::Set(BTreeSet::from([str_val("y")])));

        let canonical = config.canonicalize(&state);

        match canonical.values.first() {
            Some(Value::Set(s)) => {
                // "y" is the first element encountered, so it is renamed to "x".
                assert!(s.contains(&str_val("x")));
                assert!(!s.contains(&str_val("y")));
            }
            _ => panic!("expected Set value"),
        }
    }
}