use std::fmt;
use std::sync::Arc;
use rustc_hash::FxHashMap;
use crate::error::GatError;
use crate::morphism::TheoryMorphism;
/// A concrete value that can inhabit a sort in a [`Model`].
///
/// Scalars (`Str`, `Int`, `Bool`), the empty value `Null`, and the recursive
/// containers `List` and `Map` are supported. The type derives `serde`
/// traits so values can be serialized and deserialized. It is marked
/// `#[non_exhaustive]` so that more variants can be added without breaking
/// downstream `match`es.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub enum ModelValue {
    /// A string value.
    Str(String),
    /// A signed 64-bit integer value.
    Int(i64),
    /// A boolean value.
    Bool(bool),
    /// An ordered sequence of nested values.
    List(Vec<ModelValue>),
    /// A string-keyed map of nested values.
    Map(FxHashMap<String, ModelValue>),
    /// The absence of a value.
    Null,
}
/// Interpretation of a single operation: a closure from an argument slice to
/// either a result value or a [`GatError`].
///
/// It is stored behind `Arc` so one interpretation can be shared by several
/// models (see `migrate_model`). The closure must be `Send + Sync` so models
/// can be shared across threads.
type OpInterp = Arc<dyn Fn(&[ModelValue]) -> Result<ModelValue, GatError> + Send + Sync>;
/// A model (concrete interpretation) of a named theory.
///
/// Each sort is interpreted by a list of [`ModelValue`]s, and each operation
/// by a Rust closure. `Debug` is implemented by hand because `dyn Fn`
/// closures do not implement `Debug`.
pub struct Model {
    /// Name of the theory this model interprets.
    pub theory: String,
    /// Carrier values for each sort, keyed by sort name.
    pub sort_interp: FxHashMap<String, Vec<ModelValue>>,
    /// Interpretation of each operation, keyed by operation name.
    pub op_interp: FxHashMap<String, OpInterp>,
}
impl fmt::Debug for Model {
    /// Prints the theory name and the sort carriers in full.
    ///
    /// Operations are listed by name only (as `op_interp_keys`), because the
    /// underlying closures cannot be formatted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let op_names: Vec<&String> = self.op_interp.keys().collect();
        let mut builder = f.debug_struct("Model");
        builder.field("theory", &self.theory);
        builder.field("sort_interp", &self.sort_interp);
        builder.field("op_interp_keys", &op_names);
        builder.finish()
    }
}
impl Model {
#[must_use]
pub fn new(theory: impl Into<String>) -> Self {
Self {
theory: theory.into(),
sort_interp: FxHashMap::default(),
op_interp: FxHashMap::default(),
}
}
pub fn add_sort(&mut self, name: impl Into<String>, values: Vec<ModelValue>) {
self.sort_interp.insert(name.into(), values);
}
pub fn add_op<F>(&mut self, name: impl Into<String>, f: F)
where
F: Fn(&[ModelValue]) -> Result<ModelValue, GatError> + Send + Sync + 'static,
{
self.op_interp.insert(name.into(), Arc::new(f));
}
pub fn eval(&self, op_name: &str, args: &[ModelValue]) -> Result<ModelValue, GatError> {
let f = self
.op_interp
.get(op_name)
.ok_or_else(|| GatError::OpNotFound(op_name.to_owned()))?;
f(args)
}
}
/// Pulls a model back along a theory morphism.
///
/// Suppose `morphism` is `F: D → C`: its `sort_map` and `op_map` send domain
/// names to codomain names. Given a `model` of `C`, this builds a model of
/// `D`:
/// - each domain sort `s` is interpreted by the carrier of `F(s)` in `model`;
/// - each domain operation `o` is interpreted by the closure registered for
///   `F(o)`.
///
/// Sort carriers are cloned. Operation closures are shared with `model` via
/// [`Arc::clone`] rather than copied.
///
/// The result is labelled with the morphism's domain theory, since it
/// interprets the domain's sorts and operations. The caller must pass a
/// model of the morphism's codomain; this is not checked here.
///
/// # Errors
///
/// Returns [`GatError::ModelError`] if `model` has no interpretation for the
/// image of some mapped sort or operation.
pub fn migrate_model(morphism: &TheoryMorphism, model: &Model) -> Result<Model, GatError> {
    // The pulled-back model interprets the *domain* theory. Reusing
    // `model.theory` (the codomain's name) would mislabel it.
    let mut new_model = Model::new(&*morphism.domain);

    for (domain_sort, codomain_sort) in &morphism.sort_map {
        let values = model
            .sort_interp
            .get(codomain_sort.as_ref())
            .ok_or_else(|| {
                GatError::ModelError(format!(
                    "sort interpretation for '{codomain_sort}' not found in model"
                ))
            })?;
        // Carriers are plain data, so the domain sort gets its own copy.
        new_model
            .sort_interp
            .insert(domain_sort.to_string(), values.clone());
    }

    for (domain_op, codomain_op) in &morphism.op_map {
        let interp = model.op_interp.get(codomain_op.as_ref()).ok_or_else(|| {
            GatError::ModelError(format!(
                "operation interpretation for '{codomain_op}' not found in model"
            ))
        })?;
        // `dyn Fn` is not `Clone`; share the same closure by bumping the refcount.
        new_model
            .op_interp
            .insert(domain_op.to_string(), Arc::clone(interp));
    }

    Ok(new_model)
}
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use super::*;

    /// Shorthand for an integer [`ModelValue`].
    fn int_val(v: i64) -> ModelValue {
        ModelValue::Int(v)
    }

    /// Builds a name-to-name map, in the shape passed to
    /// `TheoryMorphism::new` below, from `(domain, codomain)` pairs.
    fn name_map(pairs: &[(&str, &str)]) -> HashMap<Arc<str>, Arc<str>> {
        pairs
            .iter()
            .map(|&(from, to)| (Arc::from(from), Arc::from(to)))
            .collect()
    }

    #[test]
    fn integer_monoid_model() {
        let mut model = Model::new("Monoid");
        model.add_sort("Carrier", (0..10).map(int_val).collect());
        model.add_op("mul", |args: &[ModelValue]| match (&args[0], &args[1]) {
            (ModelValue::Int(a), ModelValue::Int(b)) => Ok(ModelValue::Int(a + b)),
            _ => Err(GatError::ModelError("expected Int arguments".to_owned())),
        });
        model.add_op("unit", |_args: &[ModelValue]| Ok(ModelValue::Int(0)));

        let mul = |x: ModelValue, y: ModelValue| model.eval("mul", &[x, y]).unwrap();
        let unit = || model.eval("unit", &[]).unwrap();

        // Direct evaluation of each operation.
        assert_eq!(mul(int_val(3), int_val(4)), int_val(7));
        assert_eq!(unit(), int_val(0));

        // Left and right identity laws.
        assert_eq!(mul(unit(), int_val(5)), int_val(5));
        assert_eq!(mul(int_val(5), unit()), int_val(5));

        // Associativity on a sample triple: 1 * (2 * 3) == (1 * 2) * 3.
        assert_eq!(
            mul(int_val(1), mul(int_val(2), int_val(3))),
            mul(mul(int_val(1), int_val(2)), int_val(3)),
        );
    }

    #[test]
    fn migrate_model_renames_sorts_and_ops() {
        let mut source = Model::new("M2");
        source.add_sort("Carrier", vec![int_val(0), int_val(1)]);
        source.add_op("times", |args: &[ModelValue]| match (&args[0], &args[1]) {
            (ModelValue::Int(a), ModelValue::Int(b)) => Ok(ModelValue::Int(a * b)),
            _ => Err(GatError::ModelError("expected Int".to_owned())),
        });
        source.add_op("one", |_: &[ModelValue]| Ok(ModelValue::Int(1)));

        let morphism = TheoryMorphism::new(
            "rename",
            "M1",
            "M2",
            name_map(&[("Carrier", "Carrier")]),
            name_map(&[("mul", "times"), ("unit", "one")]),
        );
        let migrated = migrate_model(&morphism, &source).unwrap();

        assert!(migrated.sort_interp.contains_key("Carrier"));
        assert!(migrated.op_interp.contains_key("mul"));
        assert!(migrated.op_interp.contains_key("unit"));

        // The renamed operations keep the source model's behaviour.
        assert_eq!(
            migrated.eval("mul", &[int_val(3), int_val(4)]).unwrap(),
            int_val(12)
        );
        assert_eq!(migrated.eval("unit", &[]).unwrap(), int_val(1));
    }

    #[test]
    fn migrate_model_missing_sort_fails() {
        let empty = Model::new("Empty");
        let morphism = TheoryMorphism::new(
            "bad",
            "X",
            "Empty",
            name_map(&[("S", "Missing")]),
            HashMap::new(),
        );
        assert!(matches!(
            migrate_model(&morphism, &empty),
            Err(GatError::ModelError(_))
        ));
    }

    #[test]
    fn eval_missing_op_fails() {
        let empty = Model::new("Empty");
        assert!(matches!(
            empty.eval("nonexistent", &[]),
            Err(GatError::OpNotFound(_))
        ));
    }

    #[test]
    fn model_value_serialization_roundtrip() {
        let samples = [
            ModelValue::Str("hello".to_owned()),
            ModelValue::Int(42),
            ModelValue::Bool(true),
            ModelValue::List(vec![ModelValue::Int(1), ModelValue::Int(2)]),
            ModelValue::Map(FxHashMap::from_iter([(
                "key".to_owned(),
                ModelValue::Str("val".to_owned()),
            )])),
            ModelValue::Null,
        ];
        for original in &samples {
            let encoded = serde_json::to_string(original).unwrap();
            let decoded: ModelValue = serde_json::from_str(&encoded).unwrap();
            assert_eq!(original, &decoded);
        }
    }

    #[test]
    fn model_value_nested_roundtrip() {
        let inner = ModelValue::Map(FxHashMap::from_iter([(
            "inner".to_owned(),
            ModelValue::Bool(false),
        )]));
        let nested = ModelValue::Map(FxHashMap::from_iter([(
            "list".to_owned(),
            ModelValue::List(vec![ModelValue::Int(1), inner]),
        )]));

        let encoded = serde_json::to_string(&nested).unwrap();
        let decoded: ModelValue = serde_json::from_str(&encoded).unwrap();
        assert_eq!(nested, decoded);
    }

    #[test]
    fn model_debug_format() {
        let rendered = format!("{:?}", Model::new("Test"));
        assert!(rendered.contains("Test"));
    }
}