rem_constraint/
constraint.rs

1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::fmt::Display;
4
5use rem_utils::annotation::Annotated;
6
7/// Abstract encoding of a Local Constraint
8pub trait LocalConstraint: Any + Display + Clone {
9    /// static CHR rules for the constraint system
10    const CHR_RULES: &'static str;
11
12    /// parse a single constraint rule
13    fn parse(s: &str) -> nom::IResult<&str, Self>;
14
15    /// Collect CHR rules from a function definition
16    fn collect<'a>(fun: &Annotated<'a, &'a syn::ItemFn>) -> Vec<Self>;
17}
18
19trait LocalConstraintSystem {
20    fn analyze<'a>(&mut self, fun: &Annotated<'a, &'a syn::ItemFn>);
21    fn constraints(&self) -> Vec<Box<dyn Any>>;
22}
23
24struct ConstraintSystem<C: LocalConstraint> {
25    constraints: Vec<C>,
26}
27
28impl<C: LocalConstraint> Default for ConstraintSystem<C> {
29    fn default() -> Self {
30        ConstraintSystem {
31            constraints: vec![],
32        }
33    }
34}
35
36impl<C: LocalConstraint + 'static> LocalConstraintSystem for ConstraintSystem<C> {
37    fn analyze<'a>(&mut self, fun: &Annotated<'a, &'a syn::ItemFn>) {
38        self.constraints = C::collect(fun);
39        // println!("collected");
40        // for x in &self.constraints {
41        //     println!("collected constraints: {}", x);
42        // }
43        self.constraints = crate::chr::chr_solve(&self.constraints);
44    }
45
46    fn constraints(&self) -> Vec<Box<dyn Any>> {
47        let constraints: Vec<C> = self.constraints.clone();
48        let constraints: Vec<Box<dyn Any>> = constraints
49            .into_iter()
50            .map(|v| {
51                let boxed: Box<dyn Any> = Box::new(v);
52                boxed
53            })
54            .collect::<Vec<_>>();
55        constraints
56    }
57}
58
59pub struct ConstraintManager {
60    /// mapping of type ids to a name + constraint system
61    constraint_systems: HashMap<TypeId, (&'static str, Box<dyn LocalConstraintSystem>)>,
62}
63
64impl Default for ConstraintManager {
65    fn default() -> Self {
66        ConstraintManager {
67            constraint_systems: HashMap::new(),
68        }
69    }
70}
71
72impl Display for ConstraintManager {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        write!(f, "ConstraintManager(")?;
75        for (_, (name, _)) in self.constraint_systems.iter() {
76            write!(f, "{}, ", name)?;
77        }
78        write!(f, ")")?;
79        Ok(())
80    }
81}
82
83impl ConstraintManager {
84    pub fn add_constraint<C: LocalConstraint>(&mut self) {
85        let id = TypeId::of::<C>();
86        let lcs = ConstraintSystem::<C>::default();
87        let name = std::any::type_name::<C>();
88        self.constraint_systems.insert(id, (name, Box::new(lcs)));
89    }
90
91    pub fn get_constraints<C: LocalConstraint>(&self) -> Vec<C> {
92        let id = TypeId::of::<C>();
93        let constraint_system = self.constraint_systems.get(&id);
94
95        match constraint_system {
96            Some((_, lcs)) => lcs
97                .constraints()
98                .into_iter()
99                .map(|boxed| std::boxed::Box::into_inner(boxed.downcast::<C>().unwrap()))
100                .collect(),
101            None => vec![],
102        }
103    }
104
105    pub fn analyze<'a>(&mut self, fun: &Annotated<'a, &'a syn::ItemFn>) {
106        for (_k, (_, v)) in self.constraint_systems.iter_mut() {
107            v.analyze(fun)
108        }
109    }
110}