Skip to main content

panproto_gat/
model.rs

1use std::fmt;
2use std::sync::Arc;
3
4use rustc_hash::FxHashMap;
5
6use crate::error::GatError;
7use crate::morphism::TheoryMorphism;
8
9/// A value in a model interpretation.
10///
11/// `ModelValue` represents the elements that sorts are interpreted as,
12/// and the values that operations produce and consume.
13#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
14#[non_exhaustive]
15pub enum ModelValue {
16    /// A string value.
17    Str(String),
18    /// A 64-bit integer value.
19    Int(i64),
20    /// A boolean value.
21    Bool(bool),
22    /// A list of values.
23    List(Vec<Self>),
24    /// A map of key-value pairs.
25    Map(FxHashMap<String, Self>),
26    /// A null / absent value.
27    Null,
28}
29
30/// An operation interpreter: a function from argument values to a result value.
31///
32/// Wrapped in `Arc` so that `Model` can be cloned and sent across threads.
33type OpInterp = Arc<dyn Fn(&[ModelValue]) -> Result<ModelValue, GatError> + Send + Sync>;
34
35/// A model (interpretation) of a theory in Set.
36///
37/// Maps each sort to a carrier set of values and each operation to a
38/// function on those values. Models are the semantic counterpart of
39/// theories: a theory describes structure abstractly, while a model
40/// provides a concrete instantiation.
41///
42/// `Model` does not derive `Serialize`/`Deserialize` because `op_interp`
43/// contains function pointers (`Arc<dyn Fn(...)>`) which cannot be serialized.
44pub struct Model {
45    /// The name of the theory this model interprets.
46    pub theory: String,
47    /// Sort interpretations: each sort name maps to its carrier set.
48    pub sort_interp: FxHashMap<String, Vec<ModelValue>>,
49    /// Operation interpretations: each operation name maps to a function.
50    pub op_interp: FxHashMap<String, OpInterp>,
51}
52
53impl fmt::Debug for Model {
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        f.debug_struct("Model")
56            .field("theory", &self.theory)
57            .field("sort_interp", &self.sort_interp)
58            .field("op_interp_keys", &self.op_interp.keys().collect::<Vec<_>>())
59            .finish()
60    }
61}
62
63impl Model {
64    /// Create a new model for a given theory.
65    #[must_use]
66    pub fn new(theory: impl Into<String>) -> Self {
67        Self {
68            theory: theory.into(),
69            sort_interp: FxHashMap::default(),
70            op_interp: FxHashMap::default(),
71        }
72    }
73
74    /// Add a sort interpretation (carrier set).
75    pub fn add_sort(&mut self, name: impl Into<String>, values: Vec<ModelValue>) {
76        self.sort_interp.insert(name.into(), values);
77    }
78
79    /// Add an operation interpretation.
80    pub fn add_op<F>(&mut self, name: impl Into<String>, f: F)
81    where
82        F: Fn(&[ModelValue]) -> Result<ModelValue, GatError> + Send + Sync + 'static,
83    {
84        self.op_interp.insert(name.into(), Arc::new(f));
85    }
86
87    /// Evaluate an operation by name on the given arguments.
88    ///
89    /// # Errors
90    ///
91    /// Returns [`GatError::OpNotFound`] if the operation is not in this model,
92    /// or [`GatError::ModelError`] if the operation function itself fails.
93    pub fn eval(&self, op_name: &str, args: &[ModelValue]) -> Result<ModelValue, GatError> {
94        let f = self
95            .op_interp
96            .get(op_name)
97            .ok_or_else(|| GatError::OpNotFound(op_name.to_owned()))?;
98        f(args)
99    }
100}
101
102/// Migrate a model along a theory morphism.
103///
104/// Given a morphism from theory A to theory B and a model of B, produce
105/// a model of A by reindexing sort and operation interpretations via the
106/// morphism's mappings.
107///
108/// Sort interpretations are renamed: if the morphism maps sort `S` to `T`,
109/// then the new model's interpretation for `S` is taken from the original
110/// model's interpretation for `T`.
111///
112/// Operation interpretations are renamed analogously.
113///
114/// # Errors
115///
116/// Returns [`GatError::ModelError`] if a mapped sort or operation is missing
117/// from the source model.
118pub fn migrate_model(morphism: &TheoryMorphism, model: &Model) -> Result<Model, GatError> {
119    let mut new_model = Model::new(&model.theory);
120
121    // Reindex sort interpretations.
122    for (domain_sort, codomain_sort) in &morphism.sort_map {
123        let values = model
124            .sort_interp
125            .get(codomain_sort.as_ref())
126            .ok_or_else(|| {
127                GatError::ModelError(format!(
128                    "sort interpretation for '{codomain_sort}' not found in model"
129                ))
130            })?;
131        new_model
132            .sort_interp
133            .insert(domain_sort.to_string(), values.clone());
134    }
135
136    // Reindex operation interpretations.
137    for (domain_op, codomain_op) in &morphism.op_map {
138        let interp = model.op_interp.get(codomain_op.as_ref()).ok_or_else(|| {
139            GatError::ModelError(format!(
140                "operation interpretation for '{codomain_op}' not found in model"
141            ))
142        })?;
143        new_model
144            .op_interp
145            .insert(domain_op.to_string(), Arc::clone(interp));
146    }
147
148    Ok(new_model)
149}
150
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use super::*;

    /// Shorthand for an integer [`ModelValue`].
    fn int(v: i64) -> ModelValue {
        ModelValue::Int(v)
    }

    /// Serializes `value` to JSON and parses it back.
    fn json_roundtrip(value: &ModelValue) -> ModelValue {
        let encoded = serde_json::to_string(value).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    #[test]
    fn integer_monoid_model() {
        let mut model = Model::new("Monoid");

        // Carrier = {0, 1, ..., 9}.
        model.add_sort("Carrier", (0..10).map(int).collect());

        // `mul` is interpreted as integer addition.
        model.add_op("mul", |args: &[ModelValue]| {
            if let (ModelValue::Int(a), ModelValue::Int(b)) = (&args[0], &args[1]) {
                Ok(ModelValue::Int(a + b))
            } else {
                Err(GatError::ModelError("expected Int arguments".to_owned()))
            }
        });

        // `unit` is interpreted as the constant 0.
        model.add_op("unit", |_args: &[ModelValue]| Ok(ModelValue::Int(0)));

        let mul = |x: ModelValue, y: ModelValue| model.eval("mul", &[x, y]).unwrap();
        let unit = || model.eval("unit", &[]).unwrap();

        // Direct evaluations.
        assert_eq!(mul(int(3), int(4)), int(7));
        assert_eq!(unit(), int(0));

        // Left and right identity: mul(unit, x) = x = mul(x, unit).
        assert_eq!(mul(unit(), int(5)), int(5));
        assert_eq!(mul(int(5), unit()), int(5));

        // Associativity: mul(a, mul(b, c)) = mul(mul(a, b), c).
        let lhs = mul(int(1), mul(int(2), int(3)));
        let rhs = mul(mul(int(1), int(2)), int(3));
        assert_eq!(lhs, rhs);
    }

    #[test]
    fn migrate_model_renames_sorts_and_ops() {
        let mut target = Model::new("M2");
        target.add_sort("Carrier", vec![int(0), int(1)]);
        target.add_op("times", |args: &[ModelValue]| match (&args[0], &args[1]) {
            (ModelValue::Int(a), ModelValue::Int(b)) => Ok(ModelValue::Int(a * b)),
            _ => Err(GatError::ModelError("expected Int".to_owned())),
        });
        target.add_op("one", |_: &[ModelValue]| Ok(ModelValue::Int(1)));

        // Morphism M1 -> M2 sending mul to times and unit to one.
        let sort_map = HashMap::from([(Arc::from("Carrier"), Arc::from("Carrier"))]);
        let op_map = HashMap::from([
            (Arc::from("mul"), Arc::from("times")),
            (Arc::from("unit"), Arc::from("one")),
        ]);
        let morphism = TheoryMorphism::new("rename", "M1", "M2", sort_map, op_map);

        let migrated = migrate_model(&morphism, &target).unwrap();

        // The migrated model is keyed by the domain-side names...
        assert!(migrated.sort_interp.contains_key("Carrier"));
        for op in ["mul", "unit"] {
            assert!(migrated.op_interp.contains_key(op));
        }

        // ...and the reindexed operations behave like the originals.
        assert_eq!(migrated.eval("mul", &[int(3), int(4)]).unwrap(), int(12));
        assert_eq!(migrated.eval("unit", &[]).unwrap(), int(1));
    }

    #[test]
    fn migrate_model_missing_sort_fails() {
        let empty = Model::new("Empty");
        let sort_map = HashMap::from([(Arc::from("S"), Arc::from("Missing"))]);
        let morphism = TheoryMorphism::new("bad", "X", "Empty", sort_map, HashMap::new());

        assert!(matches!(
            migrate_model(&morphism, &empty),
            Err(GatError::ModelError(_))
        ));
    }

    #[test]
    fn eval_missing_op_fails() {
        let empty = Model::new("Empty");
        assert!(matches!(
            empty.eval("nonexistent", &[]),
            Err(GatError::OpNotFound(_))
        ));
    }

    #[test]
    fn model_value_serialization_roundtrip() {
        let samples = [
            ModelValue::Str("hello".to_owned()),
            int(42),
            ModelValue::Bool(true),
            ModelValue::List(vec![int(1), int(2)]),
            ModelValue::Map(FxHashMap::from_iter([(
                "key".to_owned(),
                ModelValue::Str("val".to_owned()),
            )])),
            ModelValue::Null,
        ];

        for sample in &samples {
            assert_eq!(sample, &json_roundtrip(sample));
        }
    }

    #[test]
    fn model_value_nested_roundtrip() {
        let inner = ModelValue::Map(FxHashMap::from_iter([(
            "inner".to_owned(),
            ModelValue::Bool(false),
        )]));
        let nested = ModelValue::Map(FxHashMap::from_iter([(
            "list".to_owned(),
            ModelValue::List(vec![int(1), inner]),
        )]));

        assert_eq!(nested, json_roundtrip(&nested));
    }

    #[test]
    fn model_debug_format() {
        let rendered = format!("{:?}", Model::new("Test"));
        assert!(rendered.contains("Test"));
    }
}