1use std::collections::HashMap;
18
19use crate::logic;
20use logic::{Clause, Ground, IRTerm, Literal, Predicate};
21
22pub type Substitution<T = IRTerm> = HashMap<T, T>;
23
24impl Ground for Substitution {
25 fn is_ground(&self) -> bool {
26 self.values().all(|v| v.is_ground())
27 }
28}
29
30pub trait Substitute<T> {
31 type Output;
32
33 fn substitute(&self, s: &Substitution<T>) -> Self::Output;
34}
35
/// Values that can produce a renamed counterpart of type `T`.
pub trait Rename<T> {
    /// Returns the renamed counterpart of `self`.
    ///
    /// NOTE(review): presumably implementations hand out a fresh name on each
    /// call -- confirm against the implementors.
    fn rename(&self) -> T;
}
39
40pub trait RenameWithSubstitution<T> {
41 type Output;
42
43 fn rename_with_sub(&self) -> (Self::Output, Substitution<T>);
44}
45
46impl Substitute<IRTerm> for IRTerm {
47 type Output = IRTerm;
48 fn substitute(&self, s: &Substitution<IRTerm>) -> Self::Output {
49 if let Some(v) = s.get(self) {
52 v.clone()
53 } else if let IRTerm::Array(ts) = self {
54 IRTerm::Array(ts.iter().map(|t| t.substitute(s)).collect())
55 } else {
56 self.clone()
57 }
58 }
59}
60
61impl RenameWithSubstitution<IRTerm> for IRTerm {
62 type Output = IRTerm;
63 fn rename_with_sub(&self) -> (Self::Output, Substitution<IRTerm>) {
64 let s: Substitution<IRTerm> = self
65 .variables(true)
66 .iter()
67 .map(|r| {
68 [(r.clone(), r.rename())]
69 .iter()
70 .cloned()
71 .collect::<Substitution<IRTerm>>()
72 })
73 .reduce(|mut l, r| {
74 l.extend(r);
75 l
76 })
77 .unwrap();
78 (self.substitute(&s), s)
79 }
80}
81
82impl Substitute<IRTerm> for Literal<IRTerm> {
83 type Output = Literal;
84 fn substitute(&self, s: &Substitution<IRTerm>) -> Self::Output {
85 Literal {
86 positive: self.positive,
87 position: self.position.clone(),
88 predicate: self.predicate.clone(),
89 args: self.args.iter().map(|t| t.substitute(s)).collect(),
90 }
91 }
92}
93
94impl RenameWithSubstitution<IRTerm> for Literal<IRTerm> {
95 type Output = Literal<IRTerm>;
96 fn rename_with_sub(&self) -> (Self::Output, Substitution<IRTerm>) {
97 let s: Substitution = self
98 .variables(true)
99 .iter()
100 .map(|r| {
101 [(r.clone(), r.rename())]
102 .iter()
103 .cloned()
104 .collect::<Substitution>()
105 })
106 .reduce(|mut l, r| {
107 l.extend(r);
108 l
109 })
110 .unwrap();
111 (self.substitute(&s), s)
112 }
113}
114
115impl Substitute<IRTerm> for Vec<Literal<IRTerm>> {
116 type Output = Vec<Literal<IRTerm>>;
117 fn substitute(&self, s: &Substitution<IRTerm>) -> Self::Output {
118 self.iter().map(|l| l.substitute(s)).collect()
119 }
120}
121
122impl RenameWithSubstitution<IRTerm> for Vec<Literal<IRTerm>> {
123 type Output = Vec<Literal<IRTerm>>;
124 fn rename_with_sub(&self) -> (Self::Output, Substitution<IRTerm>) {
125 let s: Substitution<IRTerm> = self
126 .iter()
127 .flat_map(|e| e.variables(true))
128 .map(|r| {
129 [(r.clone(), r.rename())]
130 .iter()
131 .cloned()
132 .collect::<Substitution<IRTerm>>()
133 })
134 .reduce(|mut l, r| {
135 l.extend(r);
136 l
137 })
138 .unwrap();
139 (self.substitute(&s), s)
140 }
141}
142
143impl Substitute<IRTerm> for Clause<IRTerm> {
144 type Output = Clause<IRTerm>;
145 fn substitute(&self, s: &Substitution<IRTerm>) -> Self::Output {
146 Clause {
147 head: self.head.substitute(s),
148 body: self.body.iter().map(|t| t.substitute(s)).collect(),
149 }
150 }
151}
152
153impl RenameWithSubstitution<IRTerm> for Clause<IRTerm> {
154 type Output = Clause<IRTerm>;
155 fn rename_with_sub(&self) -> (Self::Output, Substitution<IRTerm>) {
156 let s: Substitution<IRTerm> = self
157 .variables(true)
158 .iter()
159 .map(|r| {
160 [(r.clone(), r.rename())]
161 .iter()
162 .cloned()
163 .collect::<Substitution<IRTerm>>()
164 })
165 .reduce(|mut l, r| {
166 l.extend(r);
167 l
168 })
169 .unwrap_or_default();
170 (self.substitute(&s), s)
171 }
172}
173
174pub fn compose_no_extend(
175 l: &Substitution<IRTerm>,
176 r: &Substitution<IRTerm>,
177) -> Substitution<IRTerm> {
178 let mut result = Substitution::<IRTerm>::new();
179 for (k, v) in l {
180 result.insert(k.clone(), v.substitute(r));
181 }
182 result
183}
184
185pub fn compose_extend(l: &Substitution<IRTerm>, r: &Substitution<IRTerm>) -> Substitution<IRTerm> {
186 let mut result = compose_no_extend(l, r);
187 result.extend(r.clone());
188 result
189}
190
191impl Literal<IRTerm> {
192 pub fn unify(&self, other: &Literal<IRTerm>) -> Option<Substitution<IRTerm>> {
193 fn unify_arglist(current: &[IRTerm], other: &[IRTerm]) -> Option<Substitution<IRTerm>> {
194 let mut s = Substitution::<IRTerm>::new();
195 for (self_term, other_term) in current.iter().zip(other) {
196 let self_term_subs = self_term.substitute(&s);
197 let other_term_subs = other_term.substitute(&s);
198 if self_term_subs != other_term_subs {
199 match (self_term_subs.clone(), other_term_subs.clone()) {
200 (IRTerm::Constant(_), IRTerm::Constant(_)) => return None,
202
203 (IRTerm::Array(terms1), IRTerm::Array(terms2)) => {
204 if terms1.len() == terms2.len() {
205 if let Some(new_sub) = unify_arglist(&terms1, &terms2) {
206 s = compose_extend(&s, &new_sub);
207 } else {
208 return None;
209 }
210 } else {
211 return None;
212 }
213 }
214 (IRTerm::Array(_), IRTerm::Constant(_))
215 | (IRTerm::Constant(_), IRTerm::Array(_)) => return None,
216 (IRTerm::Array(ts), v) | (v, IRTerm::Array(ts)) => {
217 let mut upd = Substitution::<IRTerm>::new();
218 upd.insert(v.clone(), IRTerm::Array(ts));
219 s = compose_extend(&s, &upd);
220 }
221
222 (IRTerm::Constant(_), v) => {
223 let mut upd = Substitution::<IRTerm>::new();
224 upd.insert(v.clone(), self_term_subs.clone());
225 s = compose_extend(&s, &upd);
226 }
227 (v1, v2) => {
228 let mut upd = Substitution::<IRTerm>::new();
229 upd.insert(v1.clone(), v2.clone());
230 s = compose_extend(&s, &upd);
231 }
232 }
233 }
234 }
235 Some(s)
236 }
237
238 if self.signature() != other.signature() {
239 return None;
240 }
241 unify_arglist(&self.args, &other.args)
242 }
243}
244
#[cfg(test)]
mod tests {
    use serial_test::serial;

    use super::*;

    /// Parses `src` into a literal, panicking if it is malformed.
    fn lit(src: &str) -> logic::Literal {
        src.parse().unwrap()
    }

    /// Shorthand for a user variable term.
    fn var(name: &'static str) -> IRTerm {
        logic::IRTerm::UserVariable(name.into())
    }

    /// Shorthand for a constant term.
    fn constant(value: &'static str) -> IRTerm {
        logic::IRTerm::Constant(value.into())
    }

    /// Asserts that `mgu` binds the user variable `name` to `expected`.
    fn assert_bound(mgu: &Substitution<IRTerm>, name: &'static str, expected: IRTerm) {
        assert_eq!(mgu.get(&var(name)), Some(&expected));
    }

    #[test]
    fn simple_unifier() {
        let l = lit("a(X, \"c\")");
        let m = lit("a(\"d\", Y)");
        // Panics (failing the test) if the literals do not unify.
        let mgu = l.unify(&m).unwrap();
        assert_eq!(l.substitute(&mgu), m.substitute(&mgu))
    }

    #[test]
    fn complex_unifier() {
        let l = lit("p(Y, Y, V, W)");
        let m = lit("p(X, Z, \"a\", U)");
        let mgu = l.unify(&m).unwrap();
        assert!(l.substitute(&mgu).eq_ignoring_position(&m.substitute(&mgu)));
        assert_bound(&mgu, "Y", var("Z"));
        assert_bound(&mgu, "X", var("Z"));
        assert_bound(&mgu, "V", constant("a"));
        assert_bound(&mgu, "W", var("U"));
    }

    #[test]
    fn array_unifier() {
        let l = lit("p([X, \"b\"], X)");
        let m = lit("p([\"a\", Y], X)");
        let mgu = l.unify(&m).unwrap();
        assert!(l.substitute(&mgu).eq_ignoring_position(&m.substitute(&mgu)));
        assert_bound(&mgu, "X", constant("a"));
        assert_bound(&mgu, "Y", constant("b"));
    }

    #[test]
    fn compex_array_unifier() {
        let l = lit("p(A, \"b\")");
        let m = lit("p([\"a\", X], X)");
        let mgu = l.unify(&m).unwrap();
        assert!(l.substitute(&mgu).eq_ignoring_position(&m.substitute(&mgu)));
        assert_bound(&mgu, "X", constant("b"));
        assert_bound(
            &mgu,
            "A",
            logic::IRTerm::Array(vec![constant("a"), constant("b")]),
        );
    }

    #[test]
    fn simple_non_unifiable() {
        // Different arity: the literals must not unify.
        let l = lit("a(X, \"b\")");
        let m = lit("a(Y)");
        assert!(l.unify(&m).is_none());
    }

    #[test]
    fn complex_non_unifiable() {
        let l = lit("q(X, \"a\", X, \"b\")");
        let m = lit("q(Y, \"a\", \"a\", Y)");
        assert!(l.unify(&m).is_none());
    }

    #[test]
    #[serial]
    fn simple_renaming() {
        let l = lit("a(X, X, Y)");
        let (m, _) = l.rename_with_sub();
        assert_ne!(l, m);
        // Both occurrences of X get the same new name; Y gets a different one.
        assert_eq!(m.args[0], m.args[1]);
        assert_ne!(m.args[0], m.args[2]);
    }

    #[test]
    #[serial]
    fn complex_renaming() {
        let l = lit("a([ X, Y ], X)");
        let (m, _) = l.rename_with_sub();
        assert_ne!(l, m);
        match &m.args[0] {
            IRTerm::Array(ts) => {
                // Renaming reaches inside arrays and stays consistent.
                assert_eq!(ts[0], m.args[1]);
                assert_ne!(ts[1], m.args[1]);
            }
            _ => panic!(),
        }
    }
}
365}