modus_lib/
unification.rs

1// Modus, a language for building container images
2// Copyright (C) 2022 University College London
3
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU Affero General Public License as
6// published by the Free Software Foundation, either version 3 of the
7// License, or (at your option) any later version.
8
9// This program is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12// GNU Affero General Public License for more details.
13
14// You should have received a copy of the GNU Affero General Public License
15// along with this program.  If not, see <https://www.gnu.org/licenses/>.
16
17use std::collections::HashMap;
18
19use crate::logic;
20use logic::{Clause, Ground, IRTerm, Literal, Predicate};
21
/// A mapping from terms to the terms that replace them, defaulting to `IRTerm`.
///
/// Keys are usually variables, but `IRTerm::substitute` looks up the whole
/// term first, so any term (including an entire `IRTerm::Array`) may be a key.
pub type Substitution<T = IRTerm> = HashMap<T, T>;
23
24impl Ground for Substitution {
25    fn is_ground(&self) -> bool {
26        self.values().all(|v| v.is_ground())
27    }
28}
29
/// Types to which a [`Substitution`] over terms of type `T` can be applied.
pub trait Substitute<T> {
    /// The value produced by applying a substitution.
    type Output;

    /// Applies `s` to `self` and returns the result.
    ///
    /// `self` is taken by shared reference, so the receiver is left untouched.
    fn substitute(&self, s: &Substitution<T>) -> Self::Output;
}
35
/// Types that can produce a renamed counterpart of themselves.
///
/// NOTE(review): no implementation is visible in this file. The callers here
/// apply `rename` to individual variables and use the result as that
/// variable's replacement, so it presumably yields a fresh variable — confirm
/// against the implementation.
pub trait Rename<T> {
    /// Returns the renamed value.
    fn rename(&self) -> T;
}
39
/// Types whose variables can be renamed while reporting the renaming used.
pub trait RenameWithSubstitution<T> {
    /// The type of the renamed value.
    type Output;

    /// Returns the renamed value together with the substitution that maps
    /// each original variable to its replacement.
    fn rename_with_sub(&self) -> (Self::Output, Substitution<T>);
}
45
46impl Substitute<IRTerm> for IRTerm {
47    type Output = IRTerm;
48    fn substitute(&self, s: &Substitution<IRTerm>) -> Self::Output {
49        // check hashmap first, in case intention was to replace entire
50        // the IRTerm::Array
51        if let Some(v) = s.get(self) {
52            v.clone()
53        } else if let IRTerm::Array(ts) = self {
54            IRTerm::Array(ts.iter().map(|t| t.substitute(s)).collect())
55        } else {
56            self.clone()
57        }
58    }
59}
60
61impl RenameWithSubstitution<IRTerm> for IRTerm {
62    type Output = IRTerm;
63    fn rename_with_sub(&self) -> (Self::Output, Substitution<IRTerm>) {
64        let s: Substitution<IRTerm> = self
65            .variables(true)
66            .iter()
67            .map(|r| {
68                [(r.clone(), r.rename())]
69                    .iter()
70                    .cloned()
71                    .collect::<Substitution<IRTerm>>()
72            })
73            .reduce(|mut l, r| {
74                l.extend(r);
75                l
76            })
77            .unwrap();
78        (self.substitute(&s), s)
79    }
80}
81
82impl Substitute<IRTerm> for Literal<IRTerm> {
83    type Output = Literal;
84    fn substitute(&self, s: &Substitution<IRTerm>) -> Self::Output {
85        Literal {
86            positive: self.positive,
87            position: self.position.clone(),
88            predicate: self.predicate.clone(),
89            args: self.args.iter().map(|t| t.substitute(s)).collect(),
90        }
91    }
92}
93
94impl RenameWithSubstitution<IRTerm> for Literal<IRTerm> {
95    type Output = Literal<IRTerm>;
96    fn rename_with_sub(&self) -> (Self::Output, Substitution<IRTerm>) {
97        let s: Substitution = self
98            .variables(true)
99            .iter()
100            .map(|r| {
101                [(r.clone(), r.rename())]
102                    .iter()
103                    .cloned()
104                    .collect::<Substitution>()
105            })
106            .reduce(|mut l, r| {
107                l.extend(r);
108                l
109            })
110            .unwrap();
111        (self.substitute(&s), s)
112    }
113}
114
115impl Substitute<IRTerm> for Vec<Literal<IRTerm>> {
116    type Output = Vec<Literal<IRTerm>>;
117    fn substitute(&self, s: &Substitution<IRTerm>) -> Self::Output {
118        self.iter().map(|l| l.substitute(s)).collect()
119    }
120}
121
122impl RenameWithSubstitution<IRTerm> for Vec<Literal<IRTerm>> {
123    type Output = Vec<Literal<IRTerm>>;
124    fn rename_with_sub(&self) -> (Self::Output, Substitution<IRTerm>) {
125        let s: Substitution<IRTerm> = self
126            .iter()
127            .flat_map(|e| e.variables(true))
128            .map(|r| {
129                [(r.clone(), r.rename())]
130                    .iter()
131                    .cloned()
132                    .collect::<Substitution<IRTerm>>()
133            })
134            .reduce(|mut l, r| {
135                l.extend(r);
136                l
137            })
138            .unwrap();
139        (self.substitute(&s), s)
140    }
141}
142
143impl Substitute<IRTerm> for Clause<IRTerm> {
144    type Output = Clause<IRTerm>;
145    fn substitute(&self, s: &Substitution<IRTerm>) -> Self::Output {
146        Clause {
147            head: self.head.substitute(s),
148            body: self.body.iter().map(|t| t.substitute(s)).collect(),
149        }
150    }
151}
152
153impl RenameWithSubstitution<IRTerm> for Clause<IRTerm> {
154    type Output = Clause<IRTerm>;
155    fn rename_with_sub(&self) -> (Self::Output, Substitution<IRTerm>) {
156        let s: Substitution<IRTerm> = self
157            .variables(true)
158            .iter()
159            .map(|r| {
160                [(r.clone(), r.rename())]
161                    .iter()
162                    .cloned()
163                    .collect::<Substitution<IRTerm>>()
164            })
165            .reduce(|mut l, r| {
166                l.extend(r);
167                l
168            })
169            .unwrap_or_default();
170        (self.substitute(&s), s)
171    }
172}
173
174pub fn compose_no_extend(
175    l: &Substitution<IRTerm>,
176    r: &Substitution<IRTerm>,
177) -> Substitution<IRTerm> {
178    let mut result = Substitution::<IRTerm>::new();
179    for (k, v) in l {
180        result.insert(k.clone(), v.substitute(r));
181    }
182    result
183}
184
185pub fn compose_extend(l: &Substitution<IRTerm>, r: &Substitution<IRTerm>) -> Substitution<IRTerm> {
186    let mut result = compose_no_extend(l, r);
187    result.extend(r.clone());
188    result
189}
190
191impl Literal<IRTerm> {
192    pub fn unify(&self, other: &Literal<IRTerm>) -> Option<Substitution<IRTerm>> {
193        fn unify_arglist(current: &[IRTerm], other: &[IRTerm]) -> Option<Substitution<IRTerm>> {
194            let mut s = Substitution::<IRTerm>::new();
195            for (self_term, other_term) in current.iter().zip(other) {
196                let self_term_subs = self_term.substitute(&s);
197                let other_term_subs = other_term.substitute(&s);
198                if self_term_subs != other_term_subs {
199                    match (self_term_subs.clone(), other_term_subs.clone()) {
200                        // cannot unify if they are both different constants
201                        (IRTerm::Constant(_), IRTerm::Constant(_)) => return None,
202
203                        (IRTerm::Array(terms1), IRTerm::Array(terms2)) => {
204                            if terms1.len() == terms2.len() {
205                                if let Some(new_sub) = unify_arglist(&terms1, &terms2) {
206                                    s = compose_extend(&s, &new_sub);
207                                } else {
208                                    return None;
209                                }
210                            } else {
211                                return None;
212                            }
213                        }
214                        (IRTerm::Array(_), IRTerm::Constant(_))
215                        | (IRTerm::Constant(_), IRTerm::Array(_)) => return None,
216                        (IRTerm::Array(ts), v) | (v, IRTerm::Array(ts)) => {
217                            let mut upd = Substitution::<IRTerm>::new();
218                            upd.insert(v.clone(), IRTerm::Array(ts));
219                            s = compose_extend(&s, &upd);
220                        }
221
222                        (IRTerm::Constant(_), v) => {
223                            let mut upd = Substitution::<IRTerm>::new();
224                            upd.insert(v.clone(), self_term_subs.clone());
225                            s = compose_extend(&s, &upd);
226                        }
227                        (v1, v2) => {
228                            let mut upd = Substitution::<IRTerm>::new();
229                            upd.insert(v1.clone(), v2.clone());
230                            s = compose_extend(&s, &upd);
231                        }
232                    }
233                }
234            }
235            Some(s)
236        }
237
238        if self.signature() != other.signature() {
239            return None;
240        }
241        unify_arglist(&self.args, &other.args)
242    }
243}
244
#[cfg(test)]
mod tests {
    use serial_test::serial;

    use super::*;

    #[test]
    fn simple_unifier() {
        let l: logic::Literal = "a(X, \"c\")".parse().unwrap();
        let m: logic::Literal = "a(\"d\", Y)".parse().unwrap();
        let result = l.unify(&m);
        assert!(result.is_some());
        let mgu = result.unwrap();
        // Compare ignoring source positions: the two literals were parsed
        // from different text, so their spans need not agree.
        assert!(l.substitute(&mgu).eq_ignoring_position(&m.substitute(&mgu)));
    }

    #[test]
    fn complex_unifier() {
        let l: logic::Literal = "p(Y, Y, V, W)".parse().unwrap();
        let m: logic::Literal = "p(X, Z, \"a\", U)".parse().unwrap();
        let result = l.unify(&m);
        assert!(result.is_some());
        let mgu = result.unwrap();
        assert!(l.substitute(&mgu).eq_ignoring_position(&m.substitute(&mgu)));
        assert_eq!(
            mgu.get(&logic::IRTerm::UserVariable("Y".into())),
            Some(&logic::IRTerm::UserVariable("Z".into()))
        );
        assert_eq!(
            mgu.get(&logic::IRTerm::UserVariable("X".into())),
            Some(&logic::IRTerm::UserVariable("Z".into()))
        );
        assert_eq!(
            mgu.get(&logic::IRTerm::UserVariable("V".into())),
            Some(&logic::IRTerm::Constant("a".into()))
        );
        assert_eq!(
            mgu.get(&logic::IRTerm::UserVariable("W".into())),
            Some(&logic::IRTerm::UserVariable("U".into()))
        );
    }

    #[test]
    fn array_unifier() {
        let l: logic::Literal = "p([X, \"b\"], X)".parse().unwrap();
        let m: logic::Literal = "p([\"a\", Y], X)".parse().unwrap();
        let result = l.unify(&m);
        assert!(result.is_some());
        let mgu = result.unwrap();
        assert!(l.substitute(&mgu).eq_ignoring_position(&m.substitute(&mgu)));
        assert_eq!(
            mgu.get(&logic::IRTerm::UserVariable("X".into())),
            Some(&logic::IRTerm::Constant("a".into()))
        );
        assert_eq!(
            mgu.get(&logic::IRTerm::UserVariable("Y".into())),
            Some(&logic::IRTerm::Constant("b".into()))
        );
    }

    #[test]
    fn complex_array_unifier() {
        let l: logic::Literal = "p(A, \"b\")".parse().unwrap();
        let m: logic::Literal = "p([\"a\", X], X)".parse().unwrap();
        let result = l.unify(&m);
        assert!(result.is_some());
        let mgu = result.unwrap();
        assert!(l.substitute(&mgu).eq_ignoring_position(&m.substitute(&mgu)));
        assert_eq!(
            mgu.get(&logic::IRTerm::UserVariable("X".into())),
            Some(&logic::IRTerm::Constant("b".into()))
        );
        assert_eq!(
            mgu.get(&logic::IRTerm::UserVariable("A".into())),
            Some(&logic::IRTerm::Array(vec![
                IRTerm::Constant("a".into()),
                IRTerm::Constant("b".into())
            ]))
        );
    }

    #[test]
    fn simple_non_unifiable() {
        // Different arity, hence different signatures.
        let l: logic::Literal = "a(X, \"b\")".parse().unwrap();
        let m: logic::Literal = "a(Y)".parse().unwrap();
        let result = l.unify(&m);
        assert!(result.is_none());
    }

    #[test]
    fn complex_non_unifiable() {
        let l: logic::Literal = "q(X, \"a\", X, \"b\")".parse().unwrap();
        let m: logic::Literal = "q(Y, \"a\", \"a\", Y)".parse().unwrap();
        let result = l.unify(&m);
        assert!(result.is_none());
    }

    #[test]
    #[serial]
    fn simple_renaming() {
        let l: logic::Literal = "a(X, X, Y)".parse().unwrap();
        let (m, _) = l.rename_with_sub();
        assert!(l != m);
        assert!(m.args[0] == m.args[1]);
        assert!(m.args[0] != m.args[2]);
    }

    #[test]
    #[serial]
    fn complex_renaming() {
        let l: logic::Literal = "a([ X, Y ], X)".parse().unwrap();
        let (m, _) = l.rename_with_sub();
        assert!(l != m);
        if let IRTerm::Array(ts) = &m.args[0] {
            assert_eq!(ts[0], m.args[1]);
            assert_ne!(ts[1], m.args[1]);
        } else {
            panic!()
        }
    }
}
365}