Skip to main content

mangle_vm/
lib.rs

1// Copyright 2025 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::Result;
16pub use mangle_common::Host;
17use wasmtime::{Engine, Linker, Module, Store};
18
19#[cfg(feature = "csv_storage")]
20pub mod csv_host;
21
22pub mod composite_host;
23
/// Runs WebAssembly modules (such as those produced by `mangle_codegen`) against a
/// caller-supplied [`Host`].
///
/// Holds a single wasmtime [`Engine`] that is reused by every [`Vm::execute`] call;
/// each call builds its own store and linker on top of it.
pub struct Vm {
    engine: Engine,
}

/// Per-execution wasmtime store data: owns the [`Host`] that the `env` imports
/// linked by [`Vm::execute`] forward to (reached via `caller.data_mut().0`).
struct HostWrapper<H>(H);
29
30impl Vm {
31    pub fn new() -> Result<Self> {
32        let engine = Engine::default();
33        Ok(Self { engine })
34    }
35
36    pub fn execute<H: Host + Send + 'static>(&self, wasm: &[u8], host: H) -> Result<()> {
37        let module = Module::new(&self.engine, wasm)?;
38        let mut store = Store::new(&self.engine, HostWrapper(host));
39
40        let mut linker = Linker::new(&self.engine);
41
42        linker.func_wrap(
43            "env",
44            "scan_start",
45            |mut caller: wasmtime::Caller<'_, HostWrapper<H>>, rel_id: i32| -> i32 {
46                caller.data_mut().0.scan_start(rel_id)
47            },
48        )?;
49
50        linker.func_wrap(
51            "env",
52            "scan_delta_start",
53            |mut caller: wasmtime::Caller<'_, HostWrapper<H>>, rel_id: i32| -> i32 {
54                caller.data_mut().0.scan_delta_start(rel_id)
55            },
56        )?;
57
58        linker.func_wrap(
59            "env",
60            "scan_index_start",
61            |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
62             rel_id: i32,
63             col_idx: i32,
64             val: i64|
65             -> i32 { caller.data_mut().0.scan_index_start(rel_id, col_idx, val) },
66        )?;
67
68        linker.func_wrap(
69            "env",
70            "scan_aggregate_start",
71            |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
72             rel_id: i32,
73             ptr: i32,
74             len: i32|
75             -> i32 {
76                let mem = caller
77                    .get_export("memory")
78                    .expect("memory export not found")
79                    .into_memory()
80                    .expect("not a memory");
81
82                let data = mem.data(&caller);
83
84                // Read description from memory
85
86                // Safety: Bounds check is implicitly done by slice indexing, will panic if OOB.
87
88                // ptr is byte offset. len is number of i32s.
89
90                let start = ptr as usize;
91
92                let end = start + (len as usize) * 4;
93
94                let bytes = &data[start..end];
95
96                let mut desc = Vec::with_capacity(len as usize);
97
98                for chunk in bytes.chunks_exact(4) {
99                    let val = i32::from_le_bytes(chunk.try_into().unwrap());
100
101                    desc.push(val);
102                }
103
104                caller.data_mut().0.scan_aggregate_start(rel_id, desc)
105            },
106        )?;
107
108        linker.func_wrap(
109            "env",
110            "scan_next",
111            |mut caller: wasmtime::Caller<'_, HostWrapper<H>>, iter_id: i32| -> i32 {
112                caller.data_mut().0.scan_next(iter_id)
113            },
114        )?;
115
116        linker.func_wrap(
117            "env",
118            "get_col",
119            |mut caller: wasmtime::Caller<'_, HostWrapper<H>>, ptr: i32, idx: i32| -> i64 {
120                caller.data_mut().0.get_col(ptr, idx)
121            },
122        )?;
123
124        linker.func_wrap(
125            "env",
126            "insert",
127            |mut caller: wasmtime::Caller<'_, HostWrapper<H>>, rel_id: i32, val: i64| {
128                caller.data_mut().0.insert(rel_id, val);
129            },
130        )?;
131
132        linker.func_wrap(
133            "env",
134            "merge_deltas",
135            |mut caller: wasmtime::Caller<'_, HostWrapper<H>>| -> i32 {
136                caller.data_mut().0.merge_deltas()
137            },
138        )?;
139
140        linker.func_wrap(
141            "env",
142            "debuglog",
143            |mut caller: wasmtime::Caller<'_, HostWrapper<H>>, val: i64| {
144                caller.data_mut().0.debuglog(val);
145            },
146        )?;
147
148        let instance = linker.instantiate(&mut store, &module)?;
149        let run = instance.get_typed_func::<(), ()>(&mut store, "run")?;
150
151        run.call(&mut store, ())?;
152
153        Ok(())
154    }
155}
156
157// Minimal dummy host for tests that don't need storage
158pub struct DummyHost;
159impl Host for DummyHost {
160    fn scan_start(&mut self, _rel_id: i32) -> i32 {
161        0
162    }
163    fn scan_delta_start(&mut self, _rel_id: i32) -> i32 {
164        0
165    }
166    fn scan_index_start(&mut self, _rel_id: i32, _col_idx: i32, _val: i64) -> i32 {
167        0
168    }
169    fn scan_aggregate_start(&mut self, _rel_id: i32, _description: Vec<i32>) -> i32 {
170        0
171    }
172    fn scan_next(&mut self, _iter_id: i32) -> i32 {
173        0
174    }
175    fn get_col(&mut self, _ptr: i32, _idx: i32) -> i64 {
176        0
177    }
178    fn insert(&mut self, _rel_id: i32, _val: i64) {}
179    fn merge_deltas(&mut self) -> i32 {
180        0
181    }
182    fn debuglog(&mut self, _val: i64) {}
183}
184
185#[cfg(test)]
186mod tests {
187    use super::*;
188    use mangle_analysis::LoweringContext;
189    use mangle_ast as ast;
190    use mangle_codegen::{Codegen, WasmImportsBackend};
191    use std::collections::HashMap;
192
193    #[test]
194    fn test_e2e_execution() -> Result<()> {
195        let arena = ast::Arena::new_with_global_interner();
196        let foo = arena.predicate_sym("foo", Some(1));
197        let bar = arena.predicate_sym("bar", Some(1));
198        let x = arena.variable("X");
199
200        let clause = ast::Clause {
201            head: arena.atom(foo, &[x]),
202            premises: arena
203                .alloc_slice_copy(&[arena.alloc(ast::Term::Atom(arena.atom(bar, &[x])))]),
204            transform: &[],
205        };
206        let unit = ast::Unit {
207            decls: &[],
208            clauses: arena.alloc_slice_copy(&[&clause]),
209        };
210
211        let ctx = LoweringContext::new(&arena);
212        let mut ir = ctx.lower_unit(&unit);
213
214        let mut codegen = Codegen::new(&mut ir, WasmImportsBackend);
215        let wasm = codegen.generate();
216
217        let vm = Vm::new()?;
218        vm.execute(&wasm, DummyHost)?;
219
220        Ok(())
221    }
222
223    #[test]
224    fn test_e2e_function() -> Result<()> {
225        let arena = ast::Arena::new_with_global_interner();
226        let foo = arena.predicate_sym("foo", Some(1));
227        let plus = arena.function_sym("fn:plus", Some(2));
228
229        let c1 = arena.const_(ast::Const::Number(1));
230        let c2 = arena.const_(ast::Const::Number(2));
231
232        let head_arg = arena.apply_fn(plus, &[c1, c2]);
233        let clause = ast::Clause {
234            head: arena.atom(foo, &[head_arg]),
235            premises: &[],
236            transform: &[],
237        };
238
239        let unit = ast::Unit {
240            decls: &[],
241            clauses: arena.alloc_slice_copy(&[&clause]),
242        };
243
244        let ctx = LoweringContext::new(&arena);
245        let mut ir = ctx.lower_unit(&unit);
246
247        let mut codegen = Codegen::new(&mut ir, WasmImportsBackend);
248        let wasm = codegen.generate();
249
250        let vm = Vm::new()?;
251        vm.execute(&wasm, DummyHost)?;
252        Ok(())
253    }
254
255    // --- Real Implementation Test ---
256
257    struct MemHost {
258        // Map rel_id -> List of Tuples (Vec<i64>)
259        data: HashMap<i32, Vec<Vec<i64>>>,
260        // Iterator state: iter_id -> (rel_id, current_index)
261        iters: HashMap<i32, (i32, usize)>,
262        next_iter_id: i32,
263    }
264
265    impl MemHost {
266        fn new() -> Self {
267            Self {
268                data: HashMap::new(),
269                iters: HashMap::new(),
270                next_iter_id: 1,
271            }
272        }
273
274        fn hash_name(name: &str) -> i32 {
275            let mut hash: u32 = 5381;
276            for c in name.bytes() {
277                hash = ((hash << 5).wrapping_add(hash)).wrapping_add(c as u32);
278            }
279            hash as i32
280        }
281
282        fn add_fact(&mut self, rel: &str, args: Vec<i64>) {
283            let id = Self::hash_name(rel);
284            self.data.entry(id).or_default().push(args);
285        }
286
287        fn get_facts(&self, rel: &str) -> Vec<Vec<i64>> {
288            let id = Self::hash_name(rel);
289            self.data.get(&id).cloned().unwrap_or_default()
290        }
291    }
292
293    impl Host for MemHost {
294        fn scan_start(&mut self, rel_id: i32) -> i32 {
295            let id = self.next_iter_id;
296            self.next_iter_id += 1;
297            self.iters.insert(id, (rel_id, 0));
298            id
299        }
300
301        fn scan_delta_start(&mut self, rel_id: i32) -> i32 {
302            self.scan_start(rel_id)
303        }
304
305        fn scan_index_start(&mut self, _rel_id: i32, _col_idx: i32, _val: i64) -> i32 {
306            0 // TODO: Actual index implementation
307        }
308
309        fn scan_aggregate_start(&mut self, _rel_id: i32, _description: Vec<i32>) -> i32 {
310            0 // TODO: Actual aggregate implementation
311        }
312
313        fn scan_next(&mut self, iter_id: i32) -> i32 {
314            if let Some((rel_id, idx)) = self.iters.get_mut(&iter_id)
315                && let Some(tuples) = self.data.get(rel_id)
316                && *idx < tuples.len()
317            {
318                // Return (iter_id << 16) | (idx + 1)
319                let ptr = (iter_id << 16) | (*idx as i32 + 1);
320                *idx += 1;
321                return ptr;
322            }
323            0 // Null
324        }
325
326        fn get_col(&mut self, ptr: i32, col_idx: i32) -> i64 {
327            let iter_id = ptr >> 16;
328            let tuple_idx = (ptr & 0xFFFF) - 1;
329
330            if let Some((rel_id, _)) = self.iters.get(&iter_id)
331                && let Some(tuples) = self.data.get(rel_id)
332            {
333                return tuples[tuple_idx as usize][col_idx as usize];
334            }
335            0
336        }
337
338        fn insert(&mut self, rel_id: i32, val: i64) {
339            self.data.entry(rel_id).or_default().push(vec![val]);
340        }
341
342        fn merge_deltas(&mut self) -> i32 {
343            0
344        }
345        fn debuglog(&mut self, val: i64) {
346            eprintln!("WASM LOG: {}", val);
347        }
348    }
349
350    #[test]
351    fn test_e2e_mem_store() -> Result<()> {
352        let arena = ast::Arena::new_with_global_interner();
353        // p(X) :- q(X).
354        // q is extensional.
355        let p = arena.predicate_sym("p", Some(1));
356        let q = arena.predicate_sym("q", Some(1));
357        let x = arena.variable("X");
358
359        let clause = ast::Clause {
360            head: arena.atom(p, &[x]),
361            premises: arena.alloc_slice_copy(&[arena.alloc(ast::Term::Atom(arena.atom(q, &[x])))]),
362            transform: &[],
363        };
364
365        let unit = ast::Unit {
366            decls: &[],
367            clauses: arena.alloc_slice_copy(&[&clause]),
368        };
369
370        let ctx = LoweringContext::new(&arena);
371        let mut ir = ctx.lower_unit(&unit);
372
373        let mut codegen = Codegen::new(&mut ir, WasmImportsBackend);
374        let wasm = codegen.generate();
375
376        // Setup Host
377        let mut host = MemHost::new();
378        host.add_fact("q", vec![10]);
379        host.add_fact("q", vec![20]);
380
381        let vm = Vm::new()?;
382
383        use std::sync::{Arc, Mutex};
384
385        #[derive(Clone)]
386        struct SharedMemHost {
387            inner: Arc<Mutex<MemHost>>,
388        }
389
390        impl Host for SharedMemHost {
391            fn scan_start(&mut self, rel_id: i32) -> i32 {
392                self.inner.lock().unwrap().scan_start(rel_id)
393            }
394            fn scan_delta_start(&mut self, rel_id: i32) -> i32 {
395                self.inner.lock().unwrap().scan_delta_start(rel_id)
396            }
397            fn scan_index_start(&mut self, rel_id: i32, col_idx: i32, val: i64) -> i32 {
398                self.inner
399                    .lock()
400                    .unwrap()
401                    .scan_index_start(rel_id, col_idx, val)
402            }
403            fn scan_aggregate_start(&mut self, rel_id: i32, description: Vec<i32>) -> i32 {
404                self.inner
405                    .lock()
406                    .unwrap()
407                    .scan_aggregate_start(rel_id, description)
408            }
409            fn scan_next(&mut self, iter_id: i32) -> i32 {
410                self.inner.lock().unwrap().scan_next(iter_id)
411            }
412            fn get_col(&mut self, ptr: i32, idx: i32) -> i64 {
413                self.inner.lock().unwrap().get_col(ptr, idx)
414            }
415            fn insert(&mut self, rel_id: i32, val: i64) {
416                self.inner.lock().unwrap().insert(rel_id, val);
417            }
418            fn merge_deltas(&mut self) -> i32 {
419                self.inner.lock().unwrap().merge_deltas()
420            }
421            fn debuglog(&mut self, val: i64) {
422                self.inner.lock().unwrap().debuglog(val);
423            }
424        }
425
426        let shared_host = SharedMemHost {
427            inner: Arc::new(Mutex::new(host)),
428        };
429
430        vm.execute(&wasm, shared_host.clone())?; // Clone increments Arc ref
431
432        let final_host = shared_host.inner.lock().unwrap();
433        let results = final_host.get_facts("p");
434
435        // Check contains 10 and 20
436        assert!(results.iter().any(|t| t[0] == 10));
437        assert!(results.iter().any(|t| t[0] == 20));
438
439        Ok(())
440    }
441
442    #[cfg(feature = "csv_storage")]
443    #[test]
444    fn test_e2e_csv_host() -> Result<()> {
445        use crate::csv_host::CsvHost;
446        use std::io::Write;
447        use std::sync::{Arc, Mutex};
448        use tempfile::NamedTempFile;
449
450        // 1. Create a CSV file
451        let mut file = NamedTempFile::new()?;
452        writeln!(file, "10")?;
453        writeln!(file, "20")?;
454        let path = file.path().to_path_buf();
455
456        // 2. Setup CsvHost
457        let mut host = CsvHost::new();
458        host.add_file("q", path);
459
460        // 3. Compile Program: p(X) :- q(X).
461        let arena = ast::Arena::new_with_global_interner();
462        let p = arena.predicate_sym("p", Some(1));
463        let q = arena.predicate_sym("q", Some(1));
464        let x = arena.variable("X");
465
466        let clause = ast::Clause {
467            head: arena.atom(p, &[x]),
468            premises: arena.alloc_slice_copy(&[arena.alloc(ast::Term::Atom(arena.atom(q, &[x])))]),
469            transform: &[],
470        };
471        let unit = ast::Unit {
472            decls: &[],
473            clauses: arena.alloc_slice_copy(&[&clause]),
474        };
475
476        let ctx = LoweringContext::new(&arena);
477        let mut ir = ctx.lower_unit(&unit);
478
479        let mut codegen = Codegen::new(&mut ir, WasmImportsBackend);
480        let wasm = codegen.generate();
481
482        // 4. Execute
483        // Wrapper for shared host
484        #[derive(Clone)]
485        struct SharedCsvHost {
486            inner: Arc<Mutex<CsvHost>>,
487        }
488        impl Host for SharedCsvHost {
489            fn scan_start(&mut self, rel_id: i32) -> i32 {
490                self.inner.lock().unwrap().scan_start(rel_id)
491            }
492            fn scan_delta_start(&mut self, rel_id: i32) -> i32 {
493                self.inner.lock().unwrap().scan_delta_start(rel_id)
494            }
495            fn scan_index_start(&mut self, rel_id: i32, col_idx: i32, val: i64) -> i32 {
496                self.inner
497                    .lock()
498                    .unwrap()
499                    .scan_index_start(rel_id, col_idx, val)
500            }
501            fn scan_aggregate_start(&mut self, rel_id: i32, description: Vec<i32>) -> i32 {
502                self.inner
503                    .lock()
504                    .unwrap()
505                    .scan_aggregate_start(rel_id, description)
506            }
507            fn scan_next(&mut self, iter_id: i32) -> i32 {
508                self.inner.lock().unwrap().scan_next(iter_id)
509            }
510            fn get_col(&mut self, ptr: i32, idx: i32) -> i64 {
511                self.inner.lock().unwrap().get_col(ptr, idx)
512            }
513            fn insert(&mut self, rel_id: i32, val: i64) {
514                self.inner.lock().unwrap().insert(rel_id, val);
515            }
516            fn merge_deltas(&mut self) -> i32 {
517                self.inner.lock().unwrap().merge_deltas()
518            }
519            fn debuglog(&mut self, val: i64) {
520                self.inner.lock().unwrap().debuglog(val);
521            }
522        }
523
524        let shared_host = SharedCsvHost {
525            inner: Arc::new(Mutex::new(host)),
526        };
527        let vm = Vm::new()?;
528
529        vm.execute(&wasm, shared_host)?;
530
531        Ok(())
532    }
533}