1use anyhow::Result;
16pub use mangle_common::Host;
17use wasmtime::{Engine, Linker, Module, Store};
18
19#[cfg(feature = "csv_storage")]
20pub mod csv_host;
21
22pub mod composite_host;
23
24pub struct Vm {
25 engine: Engine,
26}
27
28struct HostWrapper<H>(H);
29
30impl Vm {
31 pub fn new() -> Result<Self> {
32 let engine = Engine::default();
33 Ok(Self { engine })
34 }
35
36 pub fn execute<H: Host + Send + 'static>(&self, wasm: &[u8], host: H) -> Result<()> {
37 let module = Module::new(&self.engine, wasm)?;
38 let mut store = Store::new(&self.engine, HostWrapper(host));
39
40 let mut linker = Linker::new(&self.engine);
41
42 linker.func_wrap(
43 "env",
44 "scan_start",
45 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>, rel_id: i32| -> i32 {
46 caller.data_mut().0.scan_start(rel_id)
47 },
48 )?;
49
50 linker.func_wrap(
51 "env",
52 "scan_delta_start",
53 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>, rel_id: i32| -> i32 {
54 caller.data_mut().0.scan_delta_start(rel_id)
55 },
56 )?;
57
58 linker.func_wrap(
59 "env",
60 "scan_index_start",
61 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
62 rel_id: i32,
63 col_idx: i32,
64 val: i64|
65 -> i32 { caller.data_mut().0.scan_index_start(rel_id, col_idx, val) },
66 )?;
67
68 linker.func_wrap(
69 "env",
70 "scan_aggregate_start",
71 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
72 rel_id: i32,
73 ptr: i32,
74 len: i32|
75 -> i32 {
76 let mem = caller
77 .get_export("memory")
78 .expect("memory export not found")
79 .into_memory()
80 .expect("not a memory");
81
82 let data = mem.data(&caller);
83
84 let start = ptr as usize;
91
92 let end = start + (len as usize) * 4;
93
94 let bytes = &data[start..end];
95
96 let mut desc = Vec::with_capacity(len as usize);
97
98 for chunk in bytes.chunks_exact(4) {
99 let val = i32::from_le_bytes(chunk.try_into().unwrap());
100
101 desc.push(val);
102 }
103
104 caller.data_mut().0.scan_aggregate_start(rel_id, desc)
105 },
106 )?;
107
108 linker.func_wrap(
109 "env",
110 "scan_next",
111 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>, iter_id: i32| -> i32 {
112 caller.data_mut().0.scan_next(iter_id)
113 },
114 )?;
115
116 linker.func_wrap(
117 "env",
118 "get_col",
119 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>, ptr: i32, idx: i32| -> i64 {
120 caller.data_mut().0.get_col(ptr, idx)
121 },
122 )?;
123
124 linker.func_wrap(
125 "env",
126 "insert",
127 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>, rel_id: i32, val: i64| {
128 caller.data_mut().0.insert(rel_id, val);
129 },
130 )?;
131
132 linker.func_wrap(
133 "env",
134 "merge_deltas",
135 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>| -> i32 {
136 caller.data_mut().0.merge_deltas()
137 },
138 )?;
139
140 linker.func_wrap(
141 "env",
142 "debuglog",
143 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>, val: i64| {
144 caller.data_mut().0.debuglog(val);
145 },
146 )?;
147
148 let instance = linker.instantiate(&mut store, &module)?;
149 let run = instance.get_typed_func::<(), ()>(&mut store, "run")?;
150
151 run.call(&mut store, ())?;
152
153 Ok(())
154 }
155}
156
157pub struct DummyHost;
159impl Host for DummyHost {
160 fn scan_start(&mut self, _rel_id: i32) -> i32 {
161 0
162 }
163 fn scan_delta_start(&mut self, _rel_id: i32) -> i32 {
164 0
165 }
166 fn scan_index_start(&mut self, _rel_id: i32, _col_idx: i32, _val: i64) -> i32 {
167 0
168 }
169 fn scan_aggregate_start(&mut self, _rel_id: i32, _description: Vec<i32>) -> i32 {
170 0
171 }
172 fn scan_next(&mut self, _iter_id: i32) -> i32 {
173 0
174 }
175 fn get_col(&mut self, _ptr: i32, _idx: i32) -> i64 {
176 0
177 }
178 fn insert(&mut self, _rel_id: i32, _val: i64) {}
179 fn merge_deltas(&mut self) -> i32 {
180 0
181 }
182 fn debuglog(&mut self, _val: i64) {}
183}
184
185#[cfg(test)]
186mod tests {
187 use super::*;
188 use mangle_analysis::LoweringContext;
189 use mangle_ast as ast;
190 use mangle_codegen::{Codegen, WasmImportsBackend};
191 use std::collections::HashMap;
192
193 #[test]
194 fn test_e2e_execution() -> Result<()> {
195 let arena = ast::Arena::new_with_global_interner();
196 let foo = arena.predicate_sym("foo", Some(1));
197 let bar = arena.predicate_sym("bar", Some(1));
198 let x = arena.variable("X");
199
200 let clause = ast::Clause {
201 head: arena.atom(foo, &[x]),
202 premises: arena
203 .alloc_slice_copy(&[arena.alloc(ast::Term::Atom(arena.atom(bar, &[x])))]),
204 transform: &[],
205 };
206 let unit = ast::Unit {
207 decls: &[],
208 clauses: arena.alloc_slice_copy(&[&clause]),
209 };
210
211 let ctx = LoweringContext::new(&arena);
212 let mut ir = ctx.lower_unit(&unit);
213
214 let mut codegen = Codegen::new(&mut ir, WasmImportsBackend);
215 let wasm = codegen.generate();
216
217 let vm = Vm::new()?;
218 vm.execute(&wasm, DummyHost)?;
219
220 Ok(())
221 }
222
223 #[test]
224 fn test_e2e_function() -> Result<()> {
225 let arena = ast::Arena::new_with_global_interner();
226 let foo = arena.predicate_sym("foo", Some(1));
227 let plus = arena.function_sym("fn:plus", Some(2));
228
229 let c1 = arena.const_(ast::Const::Number(1));
230 let c2 = arena.const_(ast::Const::Number(2));
231
232 let head_arg = arena.apply_fn(plus, &[c1, c2]);
233 let clause = ast::Clause {
234 head: arena.atom(foo, &[head_arg]),
235 premises: &[],
236 transform: &[],
237 };
238
239 let unit = ast::Unit {
240 decls: &[],
241 clauses: arena.alloc_slice_copy(&[&clause]),
242 };
243
244 let ctx = LoweringContext::new(&arena);
245 let mut ir = ctx.lower_unit(&unit);
246
247 let mut codegen = Codegen::new(&mut ir, WasmImportsBackend);
248 let wasm = codegen.generate();
249
250 let vm = Vm::new()?;
251 vm.execute(&wasm, DummyHost)?;
252 Ok(())
253 }
254
255 struct MemHost {
258 data: HashMap<i32, Vec<Vec<i64>>>,
260 iters: HashMap<i32, (i32, usize)>,
262 next_iter_id: i32,
263 }
264
265 impl MemHost {
266 fn new() -> Self {
267 Self {
268 data: HashMap::new(),
269 iters: HashMap::new(),
270 next_iter_id: 1,
271 }
272 }
273
274 fn hash_name(name: &str) -> i32 {
275 let mut hash: u32 = 5381;
276 for c in name.bytes() {
277 hash = ((hash << 5).wrapping_add(hash)).wrapping_add(c as u32);
278 }
279 hash as i32
280 }
281
282 fn add_fact(&mut self, rel: &str, args: Vec<i64>) {
283 let id = Self::hash_name(rel);
284 self.data.entry(id).or_default().push(args);
285 }
286
287 fn get_facts(&self, rel: &str) -> Vec<Vec<i64>> {
288 let id = Self::hash_name(rel);
289 self.data.get(&id).cloned().unwrap_or_default()
290 }
291 }
292
293 impl Host for MemHost {
294 fn scan_start(&mut self, rel_id: i32) -> i32 {
295 let id = self.next_iter_id;
296 self.next_iter_id += 1;
297 self.iters.insert(id, (rel_id, 0));
298 id
299 }
300
301 fn scan_delta_start(&mut self, rel_id: i32) -> i32 {
302 self.scan_start(rel_id)
303 }
304
305 fn scan_index_start(&mut self, _rel_id: i32, _col_idx: i32, _val: i64) -> i32 {
306 0 }
308
309 fn scan_aggregate_start(&mut self, _rel_id: i32, _description: Vec<i32>) -> i32 {
310 0 }
312
313 fn scan_next(&mut self, iter_id: i32) -> i32 {
314 if let Some((rel_id, idx)) = self.iters.get_mut(&iter_id)
315 && let Some(tuples) = self.data.get(rel_id)
316 && *idx < tuples.len()
317 {
318 let ptr = (iter_id << 16) | (*idx as i32 + 1);
320 *idx += 1;
321 return ptr;
322 }
323 0 }
325
326 fn get_col(&mut self, ptr: i32, col_idx: i32) -> i64 {
327 let iter_id = ptr >> 16;
328 let tuple_idx = (ptr & 0xFFFF) - 1;
329
330 if let Some((rel_id, _)) = self.iters.get(&iter_id)
331 && let Some(tuples) = self.data.get(rel_id)
332 {
333 return tuples[tuple_idx as usize][col_idx as usize];
334 }
335 0
336 }
337
338 fn insert(&mut self, rel_id: i32, val: i64) {
339 self.data.entry(rel_id).or_default().push(vec![val]);
340 }
341
342 fn merge_deltas(&mut self) -> i32 {
343 0
344 }
345 fn debuglog(&mut self, val: i64) {
346 eprintln!("WASM LOG: {}", val);
347 }
348 }
349
350 #[test]
351 fn test_e2e_mem_store() -> Result<()> {
352 let arena = ast::Arena::new_with_global_interner();
353 let p = arena.predicate_sym("p", Some(1));
356 let q = arena.predicate_sym("q", Some(1));
357 let x = arena.variable("X");
358
359 let clause = ast::Clause {
360 head: arena.atom(p, &[x]),
361 premises: arena.alloc_slice_copy(&[arena.alloc(ast::Term::Atom(arena.atom(q, &[x])))]),
362 transform: &[],
363 };
364
365 let unit = ast::Unit {
366 decls: &[],
367 clauses: arena.alloc_slice_copy(&[&clause]),
368 };
369
370 let ctx = LoweringContext::new(&arena);
371 let mut ir = ctx.lower_unit(&unit);
372
373 let mut codegen = Codegen::new(&mut ir, WasmImportsBackend);
374 let wasm = codegen.generate();
375
376 let mut host = MemHost::new();
378 host.add_fact("q", vec![10]);
379 host.add_fact("q", vec![20]);
380
381 let vm = Vm::new()?;
382
383 use std::sync::{Arc, Mutex};
384
385 #[derive(Clone)]
386 struct SharedMemHost {
387 inner: Arc<Mutex<MemHost>>,
388 }
389
390 impl Host for SharedMemHost {
391 fn scan_start(&mut self, rel_id: i32) -> i32 {
392 self.inner.lock().unwrap().scan_start(rel_id)
393 }
394 fn scan_delta_start(&mut self, rel_id: i32) -> i32 {
395 self.inner.lock().unwrap().scan_delta_start(rel_id)
396 }
397 fn scan_index_start(&mut self, rel_id: i32, col_idx: i32, val: i64) -> i32 {
398 self.inner
399 .lock()
400 .unwrap()
401 .scan_index_start(rel_id, col_idx, val)
402 }
403 fn scan_aggregate_start(&mut self, rel_id: i32, description: Vec<i32>) -> i32 {
404 self.inner
405 .lock()
406 .unwrap()
407 .scan_aggregate_start(rel_id, description)
408 }
409 fn scan_next(&mut self, iter_id: i32) -> i32 {
410 self.inner.lock().unwrap().scan_next(iter_id)
411 }
412 fn get_col(&mut self, ptr: i32, idx: i32) -> i64 {
413 self.inner.lock().unwrap().get_col(ptr, idx)
414 }
415 fn insert(&mut self, rel_id: i32, val: i64) {
416 self.inner.lock().unwrap().insert(rel_id, val);
417 }
418 fn merge_deltas(&mut self) -> i32 {
419 self.inner.lock().unwrap().merge_deltas()
420 }
421 fn debuglog(&mut self, val: i64) {
422 self.inner.lock().unwrap().debuglog(val);
423 }
424 }
425
426 let shared_host = SharedMemHost {
427 inner: Arc::new(Mutex::new(host)),
428 };
429
430 vm.execute(&wasm, shared_host.clone())?; let final_host = shared_host.inner.lock().unwrap();
433 let results = final_host.get_facts("p");
434
435 assert!(results.iter().any(|t| t[0] == 10));
437 assert!(results.iter().any(|t| t[0] == 20));
438
439 Ok(())
440 }
441
442 #[cfg(feature = "csv_storage")]
443 #[test]
444 fn test_e2e_csv_host() -> Result<()> {
445 use crate::csv_host::CsvHost;
446 use std::io::Write;
447 use std::sync::{Arc, Mutex};
448 use tempfile::NamedTempFile;
449
450 let mut file = NamedTempFile::new()?;
452 writeln!(file, "10")?;
453 writeln!(file, "20")?;
454 let path = file.path().to_path_buf();
455
456 let mut host = CsvHost::new();
458 host.add_file("q", path);
459
460 let arena = ast::Arena::new_with_global_interner();
462 let p = arena.predicate_sym("p", Some(1));
463 let q = arena.predicate_sym("q", Some(1));
464 let x = arena.variable("X");
465
466 let clause = ast::Clause {
467 head: arena.atom(p, &[x]),
468 premises: arena.alloc_slice_copy(&[arena.alloc(ast::Term::Atom(arena.atom(q, &[x])))]),
469 transform: &[],
470 };
471 let unit = ast::Unit {
472 decls: &[],
473 clauses: arena.alloc_slice_copy(&[&clause]),
474 };
475
476 let ctx = LoweringContext::new(&arena);
477 let mut ir = ctx.lower_unit(&unit);
478
479 let mut codegen = Codegen::new(&mut ir, WasmImportsBackend);
480 let wasm = codegen.generate();
481
482 #[derive(Clone)]
485 struct SharedCsvHost {
486 inner: Arc<Mutex<CsvHost>>,
487 }
488 impl Host for SharedCsvHost {
489 fn scan_start(&mut self, rel_id: i32) -> i32 {
490 self.inner.lock().unwrap().scan_start(rel_id)
491 }
492 fn scan_delta_start(&mut self, rel_id: i32) -> i32 {
493 self.inner.lock().unwrap().scan_delta_start(rel_id)
494 }
495 fn scan_index_start(&mut self, rel_id: i32, col_idx: i32, val: i64) -> i32 {
496 self.inner
497 .lock()
498 .unwrap()
499 .scan_index_start(rel_id, col_idx, val)
500 }
501 fn scan_aggregate_start(&mut self, rel_id: i32, description: Vec<i32>) -> i32 {
502 self.inner
503 .lock()
504 .unwrap()
505 .scan_aggregate_start(rel_id, description)
506 }
507 fn scan_next(&mut self, iter_id: i32) -> i32 {
508 self.inner.lock().unwrap().scan_next(iter_id)
509 }
510 fn get_col(&mut self, ptr: i32, idx: i32) -> i64 {
511 self.inner.lock().unwrap().get_col(ptr, idx)
512 }
513 fn insert(&mut self, rel_id: i32, val: i64) {
514 self.inner.lock().unwrap().insert(rel_id, val);
515 }
516 fn merge_deltas(&mut self) -> i32 {
517 self.inner.lock().unwrap().merge_deltas()
518 }
519 fn debuglog(&mut self, val: i64) {
520 self.inner.lock().unwrap().debuglog(val);
521 }
522 }
523
524 let shared_host = SharedCsvHost {
525 inner: Arc::new(Mutex::new(host)),
526 };
527 let vm = Vm::new()?;
528
529 vm.execute(&wasm, shared_host)?;
530
531 Ok(())
532 }
533}