Skip to main content

mangle_engine/
naive.rs

// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use crate::Engine;
16use crate::Result;
17use crate::ast::{Arena, Atom, BaseTerm, Term};
18use anyhow::anyhow;
19use fxhash::FxHashMap;
20
/// Naive bottom-up evaluation engine: repeatedly applies the program's rules
/// to the fact store until no new facts are added.
///
/// The engine carries no state, so it is cheap to construct (`Naive {}` or
/// `Naive::default()`) and to copy.
#[derive(Debug, Default, Clone, Copy)]
pub struct Naive {}
22
23impl<'e> Engine<'e> for Naive {
24    fn eval<'p>(
25        &'e self,
26        store: &'e impl mangle_common::FactStore<'e>,
27        program: &'p mangle_analysis::StratifiedProgram<'p>,
28    ) -> Result<()> {
29        // Initial facts.
30        for pred in program.intensional_preds() {
31            for rule in program.rules(pred) {
32                if rule.premises.is_empty() {
33                    store.add(program.arena(), rule.head)?;
34                }
35            }
36        }
37        loop {
38            let mut fact_added = false;
39            for pred in program.intensional_preds() {
40                for rule in program.rules(pred) {
41                    if rule.premises.is_empty() {
42                        continue; // initial facts added previously
43                    }
44                    let arena = Arena::new_with_global_interner();
45                    let subst: FxHashMap<u32, &BaseTerm> = FxHashMap::default();
46                    let mut all_ok = true;
47                    let subst = std::cell::RefCell::new(subst);
48                    for premise in rule.premises.iter() {
49                        let ok: bool = match premise {
50                            Term::Atom(query) => {
51                                // TODO: eval ApplyFn terms.
52                                let query = query.apply_subst(&arena, &subst.borrow());
53                                let found = std::cell::RefCell::new(false);
54                                let _ = store.get(query.sym, query.args, &|atom: &'_ Atom<'_>| {
55                                    let mut mismatch = false;
56                                    for (i, arg) in query.args.iter().enumerate() {
57                                        match arg {
58                                            BaseTerm::Variable(v) => {
59                                                let own_atom_ref = arena
60                                                    .copy_base_term(store.arena(), atom.args[i]);
61                                                subst.borrow_mut().insert(v.0, own_atom_ref);
62                                            }
63                                            c @ BaseTerm::Const(_) => {
64                                                if *c == atom.args[i] {
65                                                    continue;
66                                                }
67                                                mismatch = true;
68                                                break;
69                                            }
70                                            _ => {
71                                                return Err(anyhow!(format!(
72                                                    "Unsupported term: {arg}"
73                                                )));
74                                            }
75                                        }
76                                    }
77                                    *found.borrow_mut() = !mismatch;
78                                    Ok(())
79                                });
80                                !*found.borrow()
81                            }
82                            Term::NegAtom(query) => {
83                                let query = query.apply_subst(&arena, &subst.borrow());
84
85                                let found = std::cell::RefCell::new(false);
86                                let _ = store.get(query.sym, query.args, &|_| {
87                                    *found.borrow_mut() = true;
88                                    Ok(())
89                                });
90                                !*found.borrow()
91                            }
92                            Term::Eq(left, right) => {
93                                let left = left.apply_subst(&arena, &subst.borrow());
94                                let right = right.apply_subst(&arena, &subst.borrow());
95                                left == right
96                            }
97                            Term::Ineq(left, right) => {
98                                let left = left.apply_subst(&arena, &subst.borrow());
99                                let right = right.apply_subst(&arena, &subst.borrow());
100                                left != right
101                            }
102                            Term::TemporalAtom(query, _interval) => {
103                                // Treat like Atom for now; temporal filtering is TODO.
104                                let query = query.apply_subst(&arena, &subst.borrow());
105                                let found = std::cell::RefCell::new(false);
106                                let _ = store.get(query.sym, query.args, &|atom: &'_ Atom<'_>| {
107                                    let mut mismatch = false;
108                                    for (i, arg) in query.args.iter().enumerate() {
109                                        match arg {
110                                            BaseTerm::Variable(v) => {
111                                                let own_atom_ref = arena
112                                                    .copy_base_term(store.arena(), atom.args[i]);
113                                                subst.borrow_mut().insert(v.0, own_atom_ref);
114                                            }
115                                            c @ BaseTerm::Const(_) => {
116                                                if *c == atom.args[i] {
117                                                    continue;
118                                                }
119                                                mismatch = true;
120                                                break;
121                                            }
122                                            _ => {
123                                                return Err(anyhow!(format!(
124                                                    "Unsupported term: {arg}"
125                                                )));
126                                            }
127                                        }
128                                    }
129                                    *found.borrow_mut() = !mismatch;
130                                    Ok(())
131                                });
132                                !*found.borrow()
133                            }
134                        };
135                        if !ok {
136                            all_ok = false;
137                            break;
138                        }
139                    }
140                    if all_ok {
141                        let head = rule.head.apply_subst(&arena, &subst.borrow());
142                        fact_added = store.add(program.arena(), head)?;
143                    }
144                }
145            }
146            if !fact_added {
147                break;
148            }
149        }
150        Ok(())
151    }
152}
153
154#[cfg(test)]
155mod test {
156    use super::*;
157    use crate::ast;
158    use anyhow::Result;
159    use mangle_analysis::Program;
160    use mangle_common::{TableConfig, TableStoreImpl, TableStoreSchema};
161
162    #[test]
163    pub fn test_naive() -> Result<()> {
164        let arena = Arena::new_with_global_interner();
165        let edge = arena.predicate_sym("edge", Some(2));
166        let reachable = arena.predicate_sym("reachable", Some(2));
167        let mut schema: TableStoreSchema = FxHashMap::default();
168        schema.insert(edge, TableConfig::InMemory);
169        schema.insert(reachable, TableConfig::InMemory);
170        let store = TableStoreImpl::new(&arena, &schema);
171
172        use crate::factstore::FactStore;
173        store.add(
174            &arena,
175            arena.atom(
176                edge,
177                &[
178                    &ast::BaseTerm::Const(ast::Const::Number(10)),
179                    &ast::BaseTerm::Const(ast::Const::Number(20)),
180                ],
181            ),
182        )?;
183        store.add(
184            &arena,
185            arena.atom(
186                edge,
187                &[
188                    &ast::BaseTerm::Const(ast::Const::Number(20)),
189                    &ast::BaseTerm::Const(ast::Const::Number(30)),
190                ],
191            ),
192        )?;
193        store.add(
194            &arena,
195            arena.atom(
196                edge,
197                &[
198                    &ast::BaseTerm::Const(ast::Const::Number(30)),
199                    &ast::BaseTerm::Const(ast::Const::Number(40)),
200                ],
201            ),
202        )?;
203
204        let mut simple = Program::new(&arena);
205        // Manually set ext_preds since Program::new doesn't take them anymore?
206        // Wait, Program::new initializes ext_preds to empty. The struct definition shows it as public.
207        simple.ext_preds = vec![edge];
208
209        let head = arena.alloc(ast::Atom {
210            sym: reachable,
211            args: arena.alloc_slice_copy(&[arena.variable("X"), arena.variable("Y")]),
212        });
213        // Add a clause.
214        simple.add_clause(
215            &arena,
216            arena.alloc(ast::Clause {
217                head,
218                head_time: None,
219                premises: arena.alloc_slice_copy(&[arena.alloc(ast::Term::Atom(
220                    arena.atom(edge, &[arena.variable("X"), arena.variable("Y")]),
221                ))]),
222                transform: &[],
223            }),
224        );
225        simple.add_clause(
226            &arena,
227            arena.alloc(ast::Clause {
228                head: arena.atom(reachable, &[arena.variable("X"), arena.variable("Z")]),
229                head_time: None,
230                premises: arena.alloc_slice_copy(&[
231                    arena.alloc(ast::Term::Atom(
232                        arena.atom(edge, &[arena.variable("X"), arena.variable("Y")]),
233                    )),
234                    arena.alloc(ast::Term::Atom(
235                        arena.atom(reachable, &[arena.variable("Y"), arena.variable("X")]),
236                    )),
237                ]),
238                transform: &[],
239            }),
240        );
241
242        let stratified_program = simple.stratify().unwrap();
243
244        let engine = Naive {};
245        engine.eval(&store, &stratified_program)?;
246
247        use crate::factstore::ReadOnlyFactStore;
248        assert!(
249            store
250                .contains(
251                    &arena,
252                    arena.atom(
253                        edge,
254                        &[
255                            &ast::BaseTerm::Const(ast::Const::Number(30)),
256                            &ast::BaseTerm::Const(ast::Const::Number(40))
257                        ],
258                    )
259                )
260                .unwrap()
261        );
262        Ok(())
263    }
264}