1use std::fmt;
2use std::sync::Arc;
3
4use rustc_hash::FxHashMap;
5
6use crate::error::GatError;
7use crate::morphism::TheoryMorphism;
8
/// A value that can populate a sort's carrier or be passed to / returned from
/// an operation interpretation in a [`Model`].
///
/// Derives serde `Serialize`/`Deserialize`. The enum is `#[non_exhaustive]`, so
/// code outside this crate must include a wildcard arm when matching on it.
// NOTE(review): serde derives encode variants by index in some (non
// self-describing) formats, so reordering variants may change the encoding.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub enum ModelValue {
    /// A string value.
    Str(String),
    /// A signed 64-bit integer.
    Int(i64),
    /// A boolean.
    Bool(bool),
    /// An ordered list of nested values.
    List(Vec<Self>),
    /// A string-keyed map of nested values (an `FxHashMap`, so iteration order
    /// is unspecified).
    Map(FxHashMap<String, Self>),
    /// The absence of a value.
    Null,
}
29
/// Interpretation of one operation: a closure from argument values to a result
/// value or a [`GatError`].
///
/// Held behind `Arc` so the same closure can be shared by several models
/// (`migrate_model` uses `Arc::clone`) without copying it. The `Send + Sync`
/// bounds keep the closure usable from other threads.
type OpInterp = Arc<dyn Fn(&[ModelValue]) -> Result<ModelValue, GatError> + Send + Sync>;
34
/// A model of a theory: carrier values for each sort and a closure for each
/// operation, both looked up by name.
///
/// `Debug` is implemented by hand because the stored closures (`dyn Fn`) have
/// no `Debug` representation.
pub struct Model {
    /// Name of the theory this model is for.
    pub theory: String,
    /// Carrier values of each sort, keyed by sort name.
    pub sort_interp: FxHashMap<String, Vec<ModelValue>>,
    /// Interpretation of each operation, keyed by operation name.
    pub op_interp: FxHashMap<String, OpInterp>,
}
52
53impl fmt::Debug for Model {
54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55 f.debug_struct("Model")
56 .field("theory", &self.theory)
57 .field("sort_interp", &self.sort_interp)
58 .field("op_interp_keys", &self.op_interp.keys().collect::<Vec<_>>())
59 .finish()
60 }
61}
62
63impl Model {
64 #[must_use]
66 pub fn new(theory: impl Into<String>) -> Self {
67 Self {
68 theory: theory.into(),
69 sort_interp: FxHashMap::default(),
70 op_interp: FxHashMap::default(),
71 }
72 }
73
74 pub fn add_sort(&mut self, name: impl Into<String>, values: Vec<ModelValue>) {
76 self.sort_interp.insert(name.into(), values);
77 }
78
79 pub fn add_op<F>(&mut self, name: impl Into<String>, f: F)
81 where
82 F: Fn(&[ModelValue]) -> Result<ModelValue, GatError> + Send + Sync + 'static,
83 {
84 self.op_interp.insert(name.into(), Arc::new(f));
85 }
86
87 pub fn eval(&self, op_name: &str, args: &[ModelValue]) -> Result<ModelValue, GatError> {
94 let f = self
95 .op_interp
96 .get(op_name)
97 .ok_or_else(|| GatError::OpNotFound(op_name.to_owned()))?;
98 f(args)
99 }
100}
101
102pub fn migrate_model(morphism: &TheoryMorphism, model: &Model) -> Result<Model, GatError> {
119 let mut new_model = Model::new(&model.theory);
120
121 for (domain_sort, codomain_sort) in &morphism.sort_map {
123 let values = model
124 .sort_interp
125 .get(codomain_sort.as_ref())
126 .ok_or_else(|| {
127 GatError::ModelError(format!(
128 "sort interpretation for '{codomain_sort}' not found in model"
129 ))
130 })?;
131 new_model
132 .sort_interp
133 .insert(domain_sort.to_string(), values.clone());
134 }
135
136 for (domain_op, codomain_op) in &morphism.op_map {
138 let interp = model.op_interp.get(codomain_op.as_ref()).ok_or_else(|| {
139 GatError::ModelError(format!(
140 "operation interpretation for '{codomain_op}' not found in model"
141 ))
142 })?;
143 new_model
144 .op_interp
145 .insert(domain_op.to_string(), Arc::clone(interp));
146 }
147
148 Ok(new_model)
149}
150
// Unit tests for `Model`, `migrate_model`, and `ModelValue` serialization.
#[cfg(test)]
// `unwrap` is acceptable here: an unexpected `Err` should fail the test by panicking.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// Shorthand for an integer `ModelValue`.
    fn int_val(v: i64) -> ModelValue {
        ModelValue::Int(v)
    }

    /// Builds an integer model of "Monoid" (`mul` = addition, `unit` = 0) and
    /// checks evaluation, both identity laws, and associativity on one triple.
    #[test]
    fn integer_monoid_model() {
        let mut model = Model::new("Monoid");

        // Carrier is the integers 0..=9. The ops below do not restrict results to it.
        let carrier: Vec<ModelValue> = (0..10).map(int_val).collect();
        model.add_sort("Carrier", carrier);

        // `mul` is interpreted as integer addition.
        model.add_op("mul", |args: &[ModelValue]| match (&args[0], &args[1]) {
            (ModelValue::Int(a), ModelValue::Int(b)) => Ok(ModelValue::Int(a + b)),
            _ => Err(GatError::ModelError("expected Int arguments".to_owned())),
        });

        model.add_op("unit", |_args: &[ModelValue]| Ok(ModelValue::Int(0)));

        let result = model.eval("mul", &[int_val(3), int_val(4)]).unwrap();
        assert_eq!(result, int_val(7));

        let result = model.eval("unit", &[]).unwrap();
        assert_eq!(result, int_val(0));

        // Left identity: unit * 5 == 5.
        let zero = model.eval("unit", &[]).unwrap();
        let result = model.eval("mul", &[zero, int_val(5)]).unwrap();
        assert_eq!(result, int_val(5));

        // Right identity: 5 * unit == 5.
        let zero = model.eval("unit", &[]).unwrap();
        let result = model.eval("mul", &[int_val(5), zero]).unwrap();
        assert_eq!(result, int_val(5));

        // Associativity: 1 * (2 * 3) == (1 * 2) * 3.
        let bc = model.eval("mul", &[int_val(2), int_val(3)]).unwrap();
        let lhs = model.eval("mul", &[int_val(1), bc]).unwrap();
        let ab = model.eval("mul", &[int_val(1), int_val(2)]).unwrap();
        let rhs = model.eval("mul", &[ab, int_val(3)]).unwrap();
        assert_eq!(lhs, rhs);
    }

    /// Migrates a model along a morphism whose op map sends `mul`/`unit` to
    /// `times`/`one`. Checks that the migrated model exposes the domain names
    /// and runs the original closures.
    #[test]
    fn migrate_model_renames_sorts_and_ops() {
        let mut model = Model::new("M2");
        model.add_sort("Carrier", vec![int_val(0), int_val(1)]);
        model.add_op("times", |args: &[ModelValue]| match (&args[0], &args[1]) {
            (ModelValue::Int(a), ModelValue::Int(b)) => Ok(ModelValue::Int(a * b)),
            _ => Err(GatError::ModelError("expected Int".to_owned())),
        });
        model.add_op("one", |_: &[ModelValue]| Ok(ModelValue::Int(1)));

        // Map entries are (domain name, codomain name), as read by `migrate_model`.
        let sort_map =
            std::collections::HashMap::from([(Arc::from("Carrier"), Arc::from("Carrier"))]);
        let op_map = std::collections::HashMap::from([
            (Arc::from("mul"), Arc::from("times")),
            (Arc::from("unit"), Arc::from("one")),
        ]);

        let morphism = TheoryMorphism::new("rename", "M1", "M2", sort_map, op_map);
        let migrated = migrate_model(&morphism, &model).unwrap();

        assert!(migrated.sort_interp.contains_key("Carrier"));
        assert!(migrated.op_interp.contains_key("mul"));
        assert!(migrated.op_interp.contains_key("unit"));

        // `mul` now runs the `times` closure: 3 * 4 == 12.
        let result = migrated.eval("mul", &[int_val(3), int_val(4)]).unwrap();
        assert_eq!(result, int_val(12));

        // `unit` now runs the `one` closure.
        let result = migrated.eval("unit", &[]).unwrap();
        assert_eq!(result, int_val(1));
    }

    /// A morphism that references a sort absent from the model must yield
    /// `GatError::ModelError`.
    #[test]
    fn migrate_model_missing_sort_fails() {
        let model = Model::new("Empty");

        let sort_map = std::collections::HashMap::from([(Arc::from("S"), Arc::from("Missing"))]);

        let morphism = TheoryMorphism::new(
            "bad",
            "X",
            "Empty",
            sort_map,
            std::collections::HashMap::new(),
        );
        let result = migrate_model(&morphism, &model);
        assert!(matches!(result, Err(GatError::ModelError(_))));
    }

    /// Evaluating an unregistered operation must yield `GatError::OpNotFound`.
    #[test]
    fn eval_missing_op_fails() {
        let model = Model::new("Empty");
        let result = model.eval("nonexistent", &[]);
        assert!(matches!(result, Err(GatError::OpNotFound(_))));
    }

    /// Every `ModelValue` variant survives a JSON serialize/deserialize round-trip.
    #[test]
    fn model_value_serialization_roundtrip() {
        let values = vec![
            ModelValue::Str("hello".to_owned()),
            ModelValue::Int(42),
            ModelValue::Bool(true),
            ModelValue::List(vec![ModelValue::Int(1), ModelValue::Int(2)]),
            ModelValue::Map(FxHashMap::from_iter([(
                "key".to_owned(),
                ModelValue::Str("val".to_owned()),
            )])),
            ModelValue::Null,
        ];

        for val in &values {
            let json = serde_json::to_string(val).unwrap();
            let roundtripped: ModelValue = serde_json::from_str(&json).unwrap();
            assert_eq!(val, &roundtripped);
        }
    }

    /// A map containing a list containing a map round-trips through JSON unchanged.
    #[test]
    fn model_value_nested_roundtrip() {
        let nested = ModelValue::Map(FxHashMap::from_iter([(
            "list".to_owned(),
            ModelValue::List(vec![
                ModelValue::Int(1),
                ModelValue::Map(FxHashMap::from_iter([(
                    "inner".to_owned(),
                    ModelValue::Bool(false),
                )])),
            ]),
        )]));

        let json = serde_json::to_string(&nested).unwrap();
        let roundtripped: ModelValue = serde_json::from_str(&json).unwrap();
        assert_eq!(nested, roundtripped);
    }

    /// The hand-written `Debug` impl includes the theory name.
    #[test]
    fn model_debug_format() {
        let model = Model::new("Test");
        let debug_str = format!("{model:?}");
        assert!(debug_str.contains("Test"));
    }
}