1use anyhow::Result;
16pub use mangle_common::{Host, HostVal};
17use wasmtime::{Engine, ExternRef, Linker, Module, Rooted, Store};
18
19#[cfg(feature = "csv_storage")]
20pub mod csv_host;
21
22pub mod composite_host;
23
24pub struct Vm {
25 engine: Engine,
26}
27
/// Store data for a single [`Vm::execute`] call.
struct HostWrapper<H> {
    /// Host that every `env.*` import registered in `Vm::execute` forwards to.
    host: H,
    /// String table passed to `execute`; not read by any import registered in `Vm::execute`.
    strings: Vec<String>,
    /// Name table passed to `execute`; not read by any import registered in `Vm::execute`.
    names: Vec<String>,
}
33
34fn extract_hv<T>(val: &Option<Rooted<ExternRef>>, caller: &wasmtime::Caller<'_, T>) -> HostVal {
36 let n = val
37 .as_ref()
38 .and_then(|r| r.data(caller).ok())
39 .flatten()
40 .and_then(|d| d.downcast_ref::<u32>().copied())
41 .unwrap_or(0);
42 HostVal(n)
43}
44
45fn make_ref<H>(
47 caller: &mut wasmtime::Caller<'_, HostWrapper<H>>,
48 hv: HostVal,
49) -> Result<Option<Rooted<ExternRef>>> {
50 let r = ExternRef::new(caller, hv.0)?;
51 Ok(Some(r))
52}
53
54impl Vm {
55 pub fn new() -> Result<Self> {
56 let engine = Engine::default();
57 Ok(Self { engine })
58 }
59
60 pub fn execute<H: Host + Send + Sync + 'static>(
61 &self,
62 wasm: &[u8],
63 host: H,
64 strings: Vec<String>,
65 names: Vec<String>,
66 ) -> Result<()> {
67 let module = Module::new(&self.engine, wasm)?;
68 let mut store = Store::new(
69 &self.engine,
70 HostWrapper {
71 host,
72 strings,
73 names,
74 },
75 );
76
77 let mut linker = Linker::new(&self.engine);
78
79 linker.func_wrap(
81 "env",
82 "scan_start",
83 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>, rel_id: i32| -> i32 {
84 caller.data_mut().host.scan_start(rel_id)
85 },
86 )?;
87
88 linker.func_wrap(
90 "env",
91 "scan_next",
92 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>, iter_id: i32| -> i32 {
93 caller.data_mut().host.scan_next(iter_id)
94 },
95 )?;
96
97 linker.func_wrap(
99 "env",
100 "get_col",
101 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
102 ptr: i32,
103 idx: i32|
104 -> Result<Option<Rooted<ExternRef>>> {
105 let hv = caller.data_mut().host.get_col(ptr, idx);
106 make_ref(&mut caller, hv)
107 },
108 )?;
109
110 linker.func_wrap(
112 "env",
113 "insert_begin",
114 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>, rel_id: i32| {
115 caller.data_mut().host.insert_begin(rel_id);
116 },
117 )?;
118
119 linker.func_wrap(
121 "env",
122 "insert_push",
123 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
124 val: Option<Rooted<ExternRef>>| {
125 let hv = extract_hv(&val, &caller);
126 caller.data_mut().host.insert_push(hv);
127 },
128 )?;
129
130 linker.func_wrap(
132 "env",
133 "insert_end",
134 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>| {
135 caller.data_mut().host.insert_end();
136 },
137 )?;
138
139 linker.func_wrap(
141 "env",
142 "scan_delta_start",
143 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>, rel_id: i32| -> i32 {
144 caller.data_mut().host.scan_delta_start(rel_id)
145 },
146 )?;
147
148 linker.func_wrap(
150 "env",
151 "merge_deltas",
152 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>| -> i32 {
153 caller.data_mut().host.merge_deltas()
154 },
155 )?;
156
157 linker.func_wrap(
159 "env",
160 "debuglog",
161 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
162 val: Option<Rooted<ExternRef>>| {
163 let hv = extract_hv(&val, &caller);
164 caller.data_mut().host.debuglog(hv);
165 },
166 )?;
167
168 linker.func_wrap(
170 "env",
171 "scan_index_start",
172 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
173 rel_id: i32,
174 col_idx: i32,
175 val: Option<Rooted<ExternRef>>|
176 -> i32 {
177 let hv = extract_hv(&val, &caller);
178 caller
179 .data_mut()
180 .host
181 .scan_index_start(rel_id, col_idx, hv)
182 },
183 )?;
184
185 linker.func_wrap(
187 "env",
188 "scan_aggregate_start",
189 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
190 rel_id: i32,
191 ptr: i32,
192 len: i32|
193 -> i32 {
194 let mem = caller
195 .get_export("memory")
196 .expect("memory export not found")
197 .into_memory()
198 .expect("not a memory");
199 let data = mem.data(&caller);
200 let start = ptr as usize;
201 let end = start + (len as usize) * 4;
202 let bytes = &data[start..end];
203 let mut desc = Vec::with_capacity(len as usize);
204 for chunk in bytes.chunks_exact(4) {
205 desc.push(i32::from_le_bytes(chunk.try_into().unwrap()));
206 }
207 caller
208 .data_mut()
209 .host
210 .scan_aggregate_start(rel_id, desc)
211 },
212 )?;
213
214 linker.func_wrap(
216 "env",
217 "const_number",
218 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
219 n: i64|
220 -> Result<Option<Rooted<ExternRef>>> {
221 let hv = caller.data_mut().host.const_number(n);
222 make_ref(&mut caller, hv)
223 },
224 )?;
225
226 linker.func_wrap(
228 "env",
229 "const_float",
230 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
231 bits: i64|
232 -> Result<Option<Rooted<ExternRef>>> {
233 let hv = caller.data_mut().host.const_float(bits);
234 make_ref(&mut caller, hv)
235 },
236 )?;
237
238 linker.func_wrap(
240 "env",
241 "const_string",
242 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
243 id: i32|
244 -> Result<Option<Rooted<ExternRef>>> {
245 let hv = caller.data_mut().host.const_string(id);
246 make_ref(&mut caller, hv)
247 },
248 )?;
249
250 linker.func_wrap(
252 "env",
253 "const_name",
254 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
255 id: i32|
256 -> Result<Option<Rooted<ExternRef>>> {
257 let hv = caller.data_mut().host.const_name(id);
258 make_ref(&mut caller, hv)
259 },
260 )?;
261
262 linker.func_wrap(
264 "env",
265 "const_time",
266 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
267 nanos: i64|
268 -> Result<Option<Rooted<ExternRef>>> {
269 let hv = caller.data_mut().host.const_time(nanos);
270 make_ref(&mut caller, hv)
271 },
272 )?;
273
274 linker.func_wrap(
276 "env",
277 "const_duration",
278 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
279 nanos: i64|
280 -> Result<Option<Rooted<ExternRef>>> {
281 let hv = caller.data_mut().host.const_duration(nanos);
282 make_ref(&mut caller, hv)
283 },
284 )?;
285
286 macro_rules! binop {
288 ($name:expr, $method:ident) => {
289 linker.func_wrap(
290 "env",
291 $name,
292 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
293 a: Option<Rooted<ExternRef>>,
294 b: Option<Rooted<ExternRef>>|
295 -> Result<Option<Rooted<ExternRef>>> {
296 let a_hv = extract_hv(&a, &caller);
297 let b_hv = extract_hv(&b, &caller);
298 let result = caller.data_mut().host.$method(a_hv, b_hv);
299 make_ref(&mut caller, result)
300 },
301 )?;
302 };
303 }
304 binop!("val_add", val_add);
305 binop!("val_sub", val_sub);
306 binop!("val_mul", val_mul);
307 binop!("val_div", val_div);
308
309 linker.func_wrap(
311 "env",
312 "val_sqrt",
313 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
314 a: Option<Rooted<ExternRef>>|
315 -> Result<Option<Rooted<ExternRef>>> {
316 let a_hv = extract_hv(&a, &caller);
317 let result = caller.data_mut().host.val_sqrt(a_hv);
318 make_ref(&mut caller, result)
319 },
320 )?;
321
322 macro_rules! cmpop {
324 ($name:expr, $method:ident) => {
325 linker.func_wrap(
326 "env",
327 $name,
328 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
329 a: Option<Rooted<ExternRef>>,
330 b: Option<Rooted<ExternRef>>|
331 -> i32 {
332 let a_hv = extract_hv(&a, &caller);
333 let b_hv = extract_hv(&b, &caller);
334 caller.data_mut().host.$method(a_hv, b_hv)
335 },
336 )?;
337 };
338 }
339 cmpop!("val_eq", val_eq);
340 cmpop!("val_neq", val_neq);
341 cmpop!("val_lt", val_lt);
342 cmpop!("val_le", val_le);
343 cmpop!("val_gt", val_gt);
344 cmpop!("val_ge", val_ge);
345
346 binop!("str_concat", str_concat);
348
349 linker.func_wrap(
351 "env",
352 "str_replace",
353 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
354 s: Option<Rooted<ExternRef>>,
355 old: Option<Rooted<ExternRef>>,
356 new: Option<Rooted<ExternRef>>,
357 count: Option<Rooted<ExternRef>>|
358 -> Result<Option<Rooted<ExternRef>>> {
359 let s_hv = extract_hv(&s, &caller);
360 let old_hv = extract_hv(&old, &caller);
361 let new_hv = extract_hv(&new, &caller);
362 let count_hv = extract_hv(&count, &caller);
363 let result = caller.data_mut().host.str_replace(s_hv, old_hv, new_hv, count_hv);
364 make_ref(&mut caller, result)
365 },
366 )?;
367
368 linker.func_wrap(
370 "env",
371 "val_to_string",
372 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
373 a: Option<Rooted<ExternRef>>|
374 -> Result<Option<Rooted<ExternRef>>> {
375 let a_hv = extract_hv(&a, &caller);
376 let result = caller.data_mut().host.val_to_string(a_hv);
377 make_ref(&mut caller, result)
378 },
379 )?;
380
381 linker.func_wrap(
383 "env",
384 "compound_begin",
385 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>, kind: i32| {
386 caller.data_mut().host.compound_begin(kind);
387 },
388 )?;
389
390 linker.func_wrap(
392 "env",
393 "compound_push",
394 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
395 val: Option<Rooted<ExternRef>>| {
396 let hv = extract_hv(&val, &caller);
397 caller.data_mut().host.compound_push(hv);
398 },
399 )?;
400
401 linker.func_wrap(
403 "env",
404 "compound_end",
405 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>|
406 -> Result<Option<Rooted<ExternRef>>> {
407 let result = caller.data_mut().host.compound_end();
408 make_ref(&mut caller, result)
409 },
410 )?;
411
412 binop!("compound_get", compound_get);
414
415 linker.func_wrap(
417 "env",
418 "compound_len",
419 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
420 a: Option<Rooted<ExternRef>>|
421 -> Result<Option<Rooted<ExternRef>>> {
422 let a_hv = extract_hv(&a, &caller);
423 let result = caller.data_mut().host.compound_len(a_hv);
424 make_ref(&mut caller, result)
425 },
426 )?;
427
428 linker.func_wrap(
430 "env",
431 "pair_first",
432 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
433 a: Option<Rooted<ExternRef>>|
434 -> Result<Option<Rooted<ExternRef>>> {
435 let a_hv = extract_hv(&a, &caller);
436 let result = caller.data_mut().host.pair_first(a_hv);
437 make_ref(&mut caller, result)
438 },
439 )?;
440
441 linker.func_wrap(
443 "env",
444 "pair_second",
445 |mut caller: wasmtime::Caller<'_, HostWrapper<H>>,
446 a: Option<Rooted<ExternRef>>|
447 -> Result<Option<Rooted<ExternRef>>> {
448 let a_hv = extract_hv(&a, &caller);
449 let result = caller.data_mut().host.pair_second(a_hv);
450 make_ref(&mut caller, result)
451 },
452 )?;
453
454 let instance = linker.instantiate(&mut store, &module)?;
455 let run = instance.get_typed_func::<(), ()>(&mut store, "run")?;
456 run.call(&mut store, ())?;
457
458 Ok(())
459 }
460}
461
/// A [`Host`] that ignores every call.
///
/// Stateless, so it is cheap to copy and has an obvious `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct DummyHost;
464impl Host for DummyHost {
465 fn scan_start(&mut self, _rel_id: i32) -> i32 { 0 }
466 fn scan_delta_start(&mut self, _rel_id: i32) -> i32 { 0 }
467 fn scan_next(&mut self, _iter_id: i32) -> i32 { 0 }
468 fn merge_deltas(&mut self) -> i32 { 0 }
469 fn scan_aggregate_start(&mut self, _rel_id: i32, _desc: Vec<i32>) -> i32 { 0 }
470 fn scan_index_start(&mut self, _rel_id: i32, _col_idx: i32, _val: HostVal) -> i32 { 0 }
471 fn get_col(&mut self, _ptr: i32, _idx: i32) -> HostVal { HostVal(0) }
472 fn insert_begin(&mut self, _rel_id: i32) {}
473 fn insert_push(&mut self, _val: HostVal) {}
474 fn insert_end(&mut self) {}
475 fn const_number(&mut self, _n: i64) -> HostVal { HostVal(0) }
476 fn const_float(&mut self, _bits: i64) -> HostVal { HostVal(0) }
477 fn const_string(&mut self, _id: i32) -> HostVal { HostVal(0) }
478 fn const_name(&mut self, _id: i32) -> HostVal { HostVal(0) }
479 fn const_time(&mut self, _nanos: i64) -> HostVal { HostVal(0) }
480 fn const_duration(&mut self, _nanos: i64) -> HostVal { HostVal(0) }
481 fn val_add(&mut self, _a: HostVal, _b: HostVal) -> HostVal { HostVal(0) }
482 fn val_sub(&mut self, _a: HostVal, _b: HostVal) -> HostVal { HostVal(0) }
483 fn val_mul(&mut self, _a: HostVal, _b: HostVal) -> HostVal { HostVal(0) }
484 fn val_div(&mut self, _a: HostVal, _b: HostVal) -> HostVal { HostVal(0) }
485 fn val_sqrt(&mut self, _a: HostVal) -> HostVal { HostVal(0) }
486 fn val_eq(&mut self, _a: HostVal, _b: HostVal) -> i32 { 0 }
487 fn val_neq(&mut self, _a: HostVal, _b: HostVal) -> i32 { 0 }
488 fn val_lt(&mut self, _a: HostVal, _b: HostVal) -> i32 { 0 }
489 fn val_le(&mut self, _a: HostVal, _b: HostVal) -> i32 { 0 }
490 fn val_gt(&mut self, _a: HostVal, _b: HostVal) -> i32 { 0 }
491 fn val_ge(&mut self, _a: HostVal, _b: HostVal) -> i32 { 0 }
492 fn str_concat(&mut self, _a: HostVal, _b: HostVal) -> HostVal { HostVal(0) }
493 fn str_replace(&mut self, _s: HostVal, _old: HostVal, _new: HostVal, _count: HostVal) -> HostVal { HostVal(0) }
494 fn val_to_string(&mut self, _val: HostVal) -> HostVal { HostVal(0) }
495 fn compound_begin(&mut self, _kind: i32) {}
496 fn compound_push(&mut self, _val: HostVal) {}
497 fn compound_end(&mut self) -> HostVal { HostVal(0) }
498 fn compound_get(&mut self, _compound: HostVal, _key: HostVal) -> HostVal { HostVal(0) }
499 fn compound_len(&mut self, _compound: HostVal) -> HostVal { HostVal(0) }
500 fn pair_first(&mut self, _compound: HostVal) -> HostVal { HostVal(0) }
501 fn pair_second(&mut self, _compound: HostVal) -> HostVal { HostVal(0) }
502 fn debuglog(&mut self, _val: HostVal) {}
503}
504
505#[cfg(test)]
506mod tests {
507 use super::*;
508 use fxhash::FxHashSet;
509 use mangle_analysis::{rewrite_unit, LoweringContext, Program};
510 use mangle_ast as ast;
511 use mangle_codegen::{Codegen, WasmImportsBackend};
512 use mangle_parse::Parser;
513 use std::collections::HashMap;
514
515 #[test]
516 fn test_e2e_execution() -> Result<()> {
517 let arena = ast::Arena::new_with_global_interner();
518 let foo = arena.predicate_sym("foo", Some(1));
519 let bar = arena.predicate_sym("bar", Some(1));
520 let x = arena.variable("X");
521
522 let clause = ast::Clause {
523 head: arena.atom(foo, &[x]),
524 head_time: None,
525 premises: arena
526 .alloc_slice_copy(&[arena.alloc(ast::Term::Atom(arena.atom(bar, &[x])))]),
527 transform: &[],
528 };
529 let unit = ast::Unit {
530 decls: &[],
531 clauses: arena.alloc_slice_copy(&[&clause]),
532 };
533
534 let ctx = LoweringContext::new(&arena);
535 let mut ir = ctx.lower_unit(&unit);
536
537 let mut codegen = Codegen::new(&mut ir, WasmImportsBackend);
538 let compiled = codegen.generate();
539
540 let vm = Vm::new()?;
541 vm.execute(&compiled.wasm, DummyHost, compiled.strings, compiled.names)?;
542
543 Ok(())
544 }
545
546 #[test]
547 fn test_e2e_function() -> Result<()> {
548 let arena = ast::Arena::new_with_global_interner();
549 let foo = arena.predicate_sym("foo", Some(1));
550 let plus = arena.function_sym("fn:plus", Some(2));
551
552 let c1 = arena.const_(ast::Const::Number(1));
553 let c2 = arena.const_(ast::Const::Number(2));
554
555 let head_arg = arena.apply_fn(plus, &[c1, c2]);
556 let clause = ast::Clause {
557 head: arena.atom(foo, &[head_arg]),
558 head_time: None,
559 premises: &[],
560 transform: &[],
561 };
562
563 let unit = ast::Unit {
564 decls: &[],
565 clauses: arena.alloc_slice_copy(&[&clause]),
566 };
567
568 let ctx = LoweringContext::new(&arena);
569 let mut ir = ctx.lower_unit(&unit);
570
571 let mut codegen = Codegen::new(&mut ir, WasmImportsBackend);
572 let compiled = codegen.generate();
573
574 let vm = Vm::new()?;
575 vm.execute(&compiled.wasm, DummyHost, compiled.strings, compiled.names)?;
576 Ok(())
577 }
578
579 #[derive(Debug, Clone, PartialEq)]
582 enum Val {
583 Number(i64),
584 Float(f64),
585 String(String),
586 Time(i64),
587 Duration(i64),
588 Compound(i32, Vec<HostVal>),
590 }
591
592 struct MemHost {
593 values: Vec<Val>,
595 data: HashMap<i32, Vec<Vec<HostVal>>>,
597 iters: HashMap<i32, (i32, usize)>,
598 next_iter_id: i32,
599 pending_rel: i32,
601 pending_tuple: Vec<HostVal>,
602 compound_kind: i32,
604 compound_elems: Vec<HostVal>,
605 strings: Vec<String>,
607 names: Vec<String>,
608 }
609
610 impl MemHost {
611 fn new(strings: Vec<String>, names: Vec<String>) -> Self {
612 Self {
613 values: Vec::new(),
614 data: HashMap::new(),
615 iters: HashMap::new(),
616 next_iter_id: 1,
617 pending_rel: 0,
618 pending_tuple: Vec::new(),
619 compound_kind: 0,
620 compound_elems: Vec::new(),
621 strings,
622 names,
623 }
624 }
625
626 fn hash_name(name: &str) -> i32 {
627 let mut hash: u32 = 5381;
628 for c in name.bytes() {
629 hash = ((hash << 5).wrapping_add(hash)).wrapping_add(c as u32);
630 }
631 hash as i32
632 }
633
634 fn alloc(&mut self, v: Val) -> HostVal {
635 let idx = self.values.len() as u32;
636 self.values.push(v);
637 HostVal(idx)
638 }
639
640 fn get_val(&self, hv: HostVal) -> &Val {
641 &self.values[hv.0 as usize]
642 }
643
644 fn val_to_str(&self, hv: HostVal) -> String {
645 match self.get_val(hv) {
646 Val::Number(n) => n.to_string(),
647 Val::Float(f) => f.to_string(),
648 Val::String(s) => s.clone(),
649 Val::Time(t) => format!("time({})", t),
650 Val::Duration(d) => format!("duration({})", d),
651 Val::Compound(kind, _) => format!("compound(kind={})", kind),
652 }
653 }
654
655 fn add_number_fact(&mut self, rel: &str, args: &[i64]) {
656 let id = Self::hash_name(rel);
657 let hvs: Vec<HostVal> = args.iter().map(|n| self.alloc(Val::Number(*n))).collect();
658 self.data.entry(id).or_default().push(hvs);
659 }
660
661 fn add_string_fact(&mut self, rel: &str, args: &[&str]) {
662 let id = Self::hash_name(rel);
663 let hvs: Vec<HostVal> = args.iter().map(|s| self.alloc(Val::String(s.to_string()))).collect();
664 self.data.entry(id).or_default().push(hvs);
665 }
666
667 fn get_number_facts(&self, rel: &str) -> Vec<Vec<i64>> {
668 let id = Self::hash_name(rel);
669 self.data
670 .get(&id)
671 .map(|tuples| {
672 tuples
673 .iter()
674 .map(|t| {
675 t.iter()
676 .map(|hv| match self.get_val(*hv) {
677 Val::Number(n) => *n,
678 _ => 0,
679 })
680 .collect()
681 })
682 .collect()
683 })
684 .unwrap_or_default()
685 }
686
687 fn get_string_facts(&self, rel: &str) -> Vec<Vec<String>> {
688 let id = Self::hash_name(rel);
689 self.data
690 .get(&id)
691 .map(|tuples| {
692 tuples
693 .iter()
694 .map(|t| {
695 t.iter()
696 .map(|hv| match self.get_val(*hv) {
697 Val::String(s) => s.clone(),
698 other => format!("{:?}", other),
699 })
700 .collect()
701 })
702 .collect()
703 })
704 .unwrap_or_default()
705 }
706
707 fn get_val_facts(&self, rel: &str) -> Vec<Vec<Val>> {
708 let id = Self::hash_name(rel);
709 self.data
710 .get(&id)
711 .map(|tuples| {
712 tuples
713 .iter()
714 .map(|t| t.iter().map(|hv| self.get_val(*hv).clone()).collect())
715 .collect()
716 })
717 .unwrap_or_default()
718 }
719 }
720
721 impl Host for MemHost {
722 fn scan_start(&mut self, rel_id: i32) -> i32 {
723 let id = self.next_iter_id;
724 self.next_iter_id += 1;
725 self.iters.insert(id, (rel_id, 0));
726 id
727 }
728 fn scan_delta_start(&mut self, rel_id: i32) -> i32 {
729 self.scan_start(rel_id)
730 }
731 fn scan_next(&mut self, iter_id: i32) -> i32 {
732 if let Some((rel_id, idx)) = self.iters.get_mut(&iter_id)
733 && let Some(tuples) = self.data.get(rel_id)
734 && *idx < tuples.len()
735 {
736 let ptr = (iter_id << 16) | (*idx as i32 + 1);
737 *idx += 1;
738 return ptr;
739 }
740 0
741 }
742 fn merge_deltas(&mut self) -> i32 { 0 }
743 fn scan_aggregate_start(&mut self, _rel_id: i32, _desc: Vec<i32>) -> i32 { 0 }
744 fn scan_index_start(&mut self, _rel_id: i32, _col_idx: i32, _val: HostVal) -> i32 { 0 }
745
746 fn get_col(&mut self, ptr: i32, col_idx: i32) -> HostVal {
747 let iter_id = ptr >> 16;
748 let tuple_idx = (ptr & 0xFFFF) - 1;
749 if let Some((rel_id, _)) = self.iters.get(&iter_id)
750 && let Some(tuples) = self.data.get(rel_id)
751 {
752 return tuples[tuple_idx as usize][col_idx as usize];
753 }
754 HostVal(0)
755 }
756
757 fn insert_begin(&mut self, rel_id: i32) {
758 self.pending_rel = rel_id;
759 self.pending_tuple.clear();
760 }
761 fn insert_push(&mut self, val: HostVal) {
762 self.pending_tuple.push(val);
763 }
764 fn insert_end(&mut self) {
765 let tuple = std::mem::take(&mut self.pending_tuple);
766 self.data.entry(self.pending_rel).or_default().push(tuple);
767 }
768
769 fn const_number(&mut self, n: i64) -> HostVal { self.alloc(Val::Number(n)) }
770 fn const_float(&mut self, bits: i64) -> HostVal { self.alloc(Val::Float(f64::from_bits(bits as u64))) }
771 fn const_string(&mut self, id: i32) -> HostVal {
772 let s = self.strings.get((id - 1) as usize).cloned().unwrap_or_default();
773 self.alloc(Val::String(s))
774 }
775 fn const_name(&mut self, id: i32) -> HostVal {
776 let s = self.names.get((id - 1) as usize).cloned().unwrap_or_default();
777 self.alloc(Val::String(s))
778 }
779 fn const_time(&mut self, nanos: i64) -> HostVal { self.alloc(Val::Time(nanos)) }
780 fn const_duration(&mut self, nanos: i64) -> HostVal { self.alloc(Val::Duration(nanos)) }
781
782 fn val_add(&mut self, a: HostVal, b: HostVal) -> HostVal {
783 let result = match (self.get_val(a), self.get_val(b)) {
784 (Val::Number(a), Val::Number(b)) => Val::Number(a + b),
785 (Val::Float(a), Val::Float(b)) => Val::Float(a + b),
786 (Val::Number(a), Val::Float(b)) => Val::Float(*a as f64 + b),
787 (Val::Float(a), Val::Number(b)) => Val::Float(a + *b as f64),
788 _ => Val::Number(0),
789 };
790 self.alloc(result)
791 }
792 fn val_sub(&mut self, a: HostVal, b: HostVal) -> HostVal {
793 let result = match (self.get_val(a), self.get_val(b)) {
794 (Val::Number(a), Val::Number(b)) => Val::Number(a - b),
795 (Val::Float(a), Val::Float(b)) => Val::Float(a - b),
796 (Val::Number(a), Val::Float(b)) => Val::Float(*a as f64 - b),
797 (Val::Float(a), Val::Number(b)) => Val::Float(a - *b as f64),
798 _ => Val::Number(0),
799 };
800 self.alloc(result)
801 }
802 fn val_mul(&mut self, a: HostVal, b: HostVal) -> HostVal {
803 let result = match (self.get_val(a), self.get_val(b)) {
804 (Val::Number(a), Val::Number(b)) => Val::Number(a * b),
805 (Val::Float(a), Val::Float(b)) => Val::Float(a * b),
806 (Val::Number(a), Val::Float(b)) => Val::Float(*a as f64 * b),
807 (Val::Float(a), Val::Number(b)) => Val::Float(a * *b as f64),
808 _ => Val::Number(0),
809 };
810 self.alloc(result)
811 }
812 fn val_div(&mut self, a: HostVal, b: HostVal) -> HostVal {
813 let result = match (self.get_val(a), self.get_val(b)) {
814 (Val::Number(a), Val::Number(b)) if *b != 0 => Val::Number(a / b),
815 (Val::Float(a), Val::Float(b)) => Val::Float(a / b),
816 (Val::Number(a), Val::Float(b)) => Val::Float(*a as f64 / b),
817 (Val::Float(a), Val::Number(b)) => Val::Float(a / *b as f64),
818 _ => Val::Number(0),
819 };
820 self.alloc(result)
821 }
822 fn val_sqrt(&mut self, a: HostVal) -> HostVal {
823 let result = match self.get_val(a) {
824 Val::Float(f) => Val::Float(f.sqrt()),
825 Val::Number(n) => Val::Float((*n as f64).sqrt()),
826 _ => Val::Float(0.0),
827 };
828 self.alloc(result)
829 }
830 fn val_eq(&mut self, a: HostVal, b: HostVal) -> i32 {
831 (self.get_val(a) == self.get_val(b)) as i32
832 }
833 fn val_neq(&mut self, a: HostVal, b: HostVal) -> i32 {
834 (self.get_val(a) != self.get_val(b)) as i32
835 }
836 fn val_lt(&mut self, a: HostVal, b: HostVal) -> i32 {
837 match (self.get_val(a), self.get_val(b)) {
838 (Val::Number(a), Val::Number(b)) => (a < b) as i32,
839 (Val::Float(a), Val::Float(b)) => (a < b) as i32,
840 _ => 0,
841 }
842 }
843 fn val_le(&mut self, a: HostVal, b: HostVal) -> i32 {
844 match (self.get_val(a), self.get_val(b)) {
845 (Val::Number(a), Val::Number(b)) => (a <= b) as i32,
846 (Val::Float(a), Val::Float(b)) => (a <= b) as i32,
847 _ => 0,
848 }
849 }
850 fn val_gt(&mut self, a: HostVal, b: HostVal) -> i32 {
851 match (self.get_val(a), self.get_val(b)) {
852 (Val::Number(a), Val::Number(b)) => (a > b) as i32,
853 (Val::Float(a), Val::Float(b)) => (a > b) as i32,
854 _ => 0,
855 }
856 }
857 fn val_ge(&mut self, a: HostVal, b: HostVal) -> i32 {
858 match (self.get_val(a), self.get_val(b)) {
859 (Val::Number(a), Val::Number(b)) => (a >= b) as i32,
860 (Val::Float(a), Val::Float(b)) => (a >= b) as i32,
861 _ => 0,
862 }
863 }
864 fn str_concat(&mut self, a: HostVal, b: HostVal) -> HostVal {
865 let sa = self.val_to_str(a);
866 let sb = self.val_to_str(b);
867 self.alloc(Val::String(format!("{}{}", sa, sb)))
868 }
869 fn str_replace(&mut self, s: HostVal, old: HostVal, new: HostVal, count: HostVal) -> HostVal {
870 let s_str = self.val_to_str(s);
871 let old_str = self.val_to_str(old);
872 let new_str = self.val_to_str(new);
873 let count_val = match self.get_val(count) {
874 Val::Number(n) => *n,
875 _ => -1,
876 };
877 let result = if count_val < 0 {
878 s_str.replace(&old_str, &new_str)
879 } else {
880 s_str.replacen(&old_str, &new_str, count_val as usize)
881 };
882 self.alloc(Val::String(result))
883 }
884 fn val_to_string(&mut self, val: HostVal) -> HostVal {
885 let s = self.val_to_str(val);
886 self.alloc(Val::String(s))
887 }
888 fn compound_begin(&mut self, kind: i32) {
889 self.compound_kind = kind;
890 self.compound_elems.clear();
891 }
892 fn compound_push(&mut self, val: HostVal) {
893 self.compound_elems.push(val);
894 }
895 fn compound_end(&mut self) -> HostVal {
896 let elems = std::mem::take(&mut self.compound_elems);
897 self.alloc(Val::Compound(self.compound_kind, elems))
898 }
899 fn compound_get(&mut self, compound: HostVal, key: HostVal) -> HostVal {
900 if let Val::Compound(kind, elems) = self.get_val(compound).clone() {
901 match kind {
902 0 => {
903 if let Val::Number(idx) = self.get_val(key) {
905 return elems.get(*idx as usize).copied().unwrap_or(HostVal(0));
906 }
907 }
908 2 | 3 => {
909 for i in (0..elems.len()).step_by(2) {
911 if i + 1 < elems.len() && self.get_val(elems[i]) == self.get_val(key) {
912 return elems[i + 1];
913 }
914 }
915 }
916 _ => {}
917 }
918 }
919 HostVal(0)
920 }
921 fn compound_len(&mut self, compound: HostVal) -> HostVal {
922 if let Val::Compound(kind, elems) = self.get_val(compound).clone() {
923 let len = match kind {
924 0 | 1 => elems.len() as i64, 2 | 3 => (elems.len() / 2) as i64, _ => 0,
927 };
928 return self.alloc(Val::Number(len));
929 }
930 self.alloc(Val::Number(0))
931 }
932 fn pair_first(&mut self, compound: HostVal) -> HostVal {
933 if let Val::Compound(_, elems) = self.get_val(compound).clone() {
934 return elems.first().copied().unwrap_or(HostVal(0));
935 }
936 HostVal(0)
937 }
938 fn pair_second(&mut self, compound: HostVal) -> HostVal {
939 if let Val::Compound(_, elems) = self.get_val(compound).clone() {
940 return elems.get(1).copied().unwrap_or(HostVal(0));
941 }
942 HostVal(0)
943 }
944 fn debuglog(&mut self, val: HostVal) {
945 eprintln!("WASM LOG: {:?}", self.get_val(val));
946 }
947 }
948
949 use std::sync::{Arc, Mutex};
952
953 #[derive(Clone)]
954 struct SharedMemHost {
955 inner: Arc<Mutex<MemHost>>,
956 }
957
958 macro_rules! delegate_host {
959 () => {
960 fn scan_start(&mut self, rel_id: i32) -> i32 { self.inner.lock().unwrap().scan_start(rel_id) }
961 fn scan_delta_start(&mut self, rel_id: i32) -> i32 { self.inner.lock().unwrap().scan_delta_start(rel_id) }
962 fn scan_next(&mut self, iter_id: i32) -> i32 { self.inner.lock().unwrap().scan_next(iter_id) }
963 fn merge_deltas(&mut self) -> i32 { self.inner.lock().unwrap().merge_deltas() }
964 fn scan_aggregate_start(&mut self, rel_id: i32, desc: Vec<i32>) -> i32 { self.inner.lock().unwrap().scan_aggregate_start(rel_id, desc) }
965 fn scan_index_start(&mut self, rel_id: i32, col_idx: i32, val: HostVal) -> i32 { self.inner.lock().unwrap().scan_index_start(rel_id, col_idx, val) }
966 fn get_col(&mut self, ptr: i32, idx: i32) -> HostVal { self.inner.lock().unwrap().get_col(ptr, idx) }
967 fn insert_begin(&mut self, rel_id: i32) { self.inner.lock().unwrap().insert_begin(rel_id) }
968 fn insert_push(&mut self, val: HostVal) { self.inner.lock().unwrap().insert_push(val) }
969 fn insert_end(&mut self) { self.inner.lock().unwrap().insert_end() }
970 fn const_number(&mut self, n: i64) -> HostVal { self.inner.lock().unwrap().const_number(n) }
971 fn const_float(&mut self, bits: i64) -> HostVal { self.inner.lock().unwrap().const_float(bits) }
972 fn const_string(&mut self, id: i32) -> HostVal { self.inner.lock().unwrap().const_string(id) }
973 fn const_name(&mut self, id: i32) -> HostVal { self.inner.lock().unwrap().const_name(id) }
974 fn const_time(&mut self, nanos: i64) -> HostVal { self.inner.lock().unwrap().const_time(nanos) }
975 fn const_duration(&mut self, nanos: i64) -> HostVal { self.inner.lock().unwrap().const_duration(nanos) }
976 fn val_add(&mut self, a: HostVal, b: HostVal) -> HostVal { self.inner.lock().unwrap().val_add(a, b) }
977 fn val_sub(&mut self, a: HostVal, b: HostVal) -> HostVal { self.inner.lock().unwrap().val_sub(a, b) }
978 fn val_mul(&mut self, a: HostVal, b: HostVal) -> HostVal { self.inner.lock().unwrap().val_mul(a, b) }
979 fn val_div(&mut self, a: HostVal, b: HostVal) -> HostVal { self.inner.lock().unwrap().val_div(a, b) }
980 fn val_sqrt(&mut self, a: HostVal) -> HostVal { self.inner.lock().unwrap().val_sqrt(a) }
981 fn val_eq(&mut self, a: HostVal, b: HostVal) -> i32 { self.inner.lock().unwrap().val_eq(a, b) }
982 fn val_neq(&mut self, a: HostVal, b: HostVal) -> i32 { self.inner.lock().unwrap().val_neq(a, b) }
983 fn val_lt(&mut self, a: HostVal, b: HostVal) -> i32 { self.inner.lock().unwrap().val_lt(a, b) }
984 fn val_le(&mut self, a: HostVal, b: HostVal) -> i32 { self.inner.lock().unwrap().val_le(a, b) }
985 fn val_gt(&mut self, a: HostVal, b: HostVal) -> i32 { self.inner.lock().unwrap().val_gt(a, b) }
986 fn val_ge(&mut self, a: HostVal, b: HostVal) -> i32 { self.inner.lock().unwrap().val_ge(a, b) }
987 fn str_concat(&mut self, a: HostVal, b: HostVal) -> HostVal { self.inner.lock().unwrap().str_concat(a, b) }
988 fn str_replace(&mut self, s: HostVal, old: HostVal, new: HostVal, count: HostVal) -> HostVal { self.inner.lock().unwrap().str_replace(s, old, new, count) }
989 fn val_to_string(&mut self, val: HostVal) -> HostVal { self.inner.lock().unwrap().val_to_string(val) }
990 fn compound_begin(&mut self, kind: i32) { self.inner.lock().unwrap().compound_begin(kind) }
991 fn compound_push(&mut self, val: HostVal) { self.inner.lock().unwrap().compound_push(val) }
992 fn compound_end(&mut self) -> HostVal { self.inner.lock().unwrap().compound_end() }
993 fn compound_get(&mut self, compound: HostVal, key: HostVal) -> HostVal { self.inner.lock().unwrap().compound_get(compound, key) }
994 fn compound_len(&mut self, compound: HostVal) -> HostVal { self.inner.lock().unwrap().compound_len(compound) }
995 fn pair_first(&mut self, compound: HostVal) -> HostVal { self.inner.lock().unwrap().pair_first(compound) }
996 fn pair_second(&mut self, compound: HostVal) -> HostVal { self.inner.lock().unwrap().pair_second(compound) }
997 fn debuglog(&mut self, val: HostVal) { self.inner.lock().unwrap().debuglog(val) }
998 };
999 }
1000
1001 impl Host for SharedMemHost {
1002 delegate_host!();
1003 }
1004
1005 fn run_host_wasm(compiled: &mangle_codegen::CompiledModule, host: MemHost) -> Result<MemHost> {
1007 let shared_host = SharedMemHost {
1008 inner: Arc::new(Mutex::new(host)),
1009 };
1010 let vm = Vm::new()?;
1011 vm.execute(
1012 &compiled.wasm,
1013 shared_host.clone(),
1014 compiled.strings.clone(),
1015 compiled.names.clone(),
1016 )?;
1017 let host = Arc::try_unwrap(shared_host.inner)
1018 .map_err(|_| anyhow::anyhow!("Arc still shared"))?
1019 .into_inner()
1020 .unwrap();
1021 Ok(host)
1022 }
1023
1024 fn run_wasm_program(source: &str) -> Result<MemHost> {
1026 let arena = ast::Arena::new_with_global_interner();
1027 let mut parser = Parser::new(&arena, source.as_bytes(), arena.alloc_str("test"));
1028 parser.next_token().map_err(|e| anyhow::anyhow!(e))?;
1029 let unit = parser.parse_unit()?;
1030 let unit = rewrite_unit(&arena, &unit);
1031
1032 let mut program = Program::new(&arena);
1033 let mut all_preds = FxHashSet::default();
1034 let mut idb_preds = FxHashSet::default();
1035 for clause in unit.clauses {
1036 program.add_clause(&arena, clause);
1037 idb_preds.insert(clause.head.sym);
1038 all_preds.insert(clause.head.sym);
1039 for premise in clause.premises {
1040 match premise {
1041 ast::Term::Atom(atom) => { all_preds.insert(atom.sym); }
1042 ast::Term::NegAtom(atom) => { all_preds.insert(atom.sym); }
1043 ast::Term::TemporalAtom(atom, _) => { all_preds.insert(atom.sym); }
1044 _ => {}
1045 }
1046 }
1047 }
1048 for pred in all_preds {
1049 if !idb_preds.contains(&pred) {
1050 program.ext_preds.push(pred);
1051 }
1052 }
1053 let stratified = program.stratify().map_err(|e| anyhow::anyhow!(e))?;
1054
1055 let ctx = LoweringContext::new(&arena);
1056 let mut ir = ctx.lower_unit(&unit);
1057
1058 let mut codegen = Codegen::new_with_stratified(&mut ir, &stratified, WasmImportsBackend);
1059 let compiled = codegen.generate();
1060 let host = MemHost::new(compiled.strings.clone(), compiled.names.clone());
1061 run_host_wasm(&compiled, host)
1062 }
1063
1064 #[test]
1065 fn test_e2e_mem_store() -> Result<()> {
1066 let arena = ast::Arena::new_with_global_interner();
1067 let p = arena.predicate_sym("p", Some(1));
1068 let q = arena.predicate_sym("q", Some(1));
1069 let x = arena.variable("X");
1070
1071 let clause = ast::Clause {
1072 head: arena.atom(p, &[x]),
1073 head_time: None,
1074 premises: arena.alloc_slice_copy(&[arena.alloc(ast::Term::Atom(arena.atom(q, &[x])))]),
1075 transform: &[],
1076 };
1077 let unit = ast::Unit {
1078 decls: &[],
1079 clauses: arena.alloc_slice_copy(&[&clause]),
1080 };
1081
1082 let ctx = LoweringContext::new(&arena);
1083 let mut ir = ctx.lower_unit(&unit);
1084
1085 let mut codegen = Codegen::new(&mut ir, WasmImportsBackend);
1086 let compiled = codegen.generate();
1087
1088 let mut host = MemHost::new(compiled.strings.clone(), compiled.names.clone());
1089 host.add_number_fact("q", &[10]);
1090 host.add_number_fact("q", &[20]);
1091
1092 let host = run_host_wasm(&compiled, host)?;
1093 let results = host.get_number_facts("p");
1094
1095 assert!(results.iter().any(|t| t[0] == 10), "expected 10 in results: {:?}", results);
1096 assert!(results.iter().any(|t| t[0] == 20), "expected 20 in results: {:?}", results);
1097
1098 Ok(())
1099 }
1100
1101 #[test]
1104 fn test_wasm_string_constant() -> Result<()> {
1105 let host = run_wasm_program(r#"
1106 p("hello").
1107 q(X) :- p(X).
1108 "#)?;
1109 let results = host.get_string_facts("q");
1110 assert_eq!(results.len(), 1);
1111 assert_eq!(results[0][0], "hello");
1112 Ok(())
1113 }
1114
1115 #[test]
1116 fn test_wasm_string_equality() -> Result<()> {
1117 let host = run_wasm_program(r#"
1118 p("hello"). p("world").
1119 q(X) :- p(X), X = "hello".
1120 "#)?;
1121 let results = host.get_string_facts("q");
1122 assert_eq!(results.len(), 1);
1123 assert_eq!(results[0][0], "hello");
1124 Ok(())
1125 }
1126
1127 #[test]
1128 fn test_wasm_string_concat() -> Result<()> {
1129 let host = run_wasm_program(r#"
1130 p("hello", "world").
1131 q(R) :- p(A, B) |> let R = fn:string:concat(A, " ", B).
1132 "#)?;
1133 let results = host.get_string_facts("q");
1134 assert_eq!(results.len(), 1);
1135 assert_eq!(results[0][0], "hello world");
1136 Ok(())
1137 }
1138
1139 #[test]
1140 fn test_wasm_string_replace() -> Result<()> {
1141 let host = run_wasm_program(r#"
1142 p("foo-bar-baz").
1143 q(R) :- p(S) |> let R = fn:string:replace(S, "-", "_", -1).
1144 "#)?;
1145 let results = host.get_string_facts("q");
1146 assert_eq!(results.len(), 1);
1147 assert_eq!(results[0][0], "foo_bar_baz");
1148 Ok(())
1149 }
1150
1151 #[test]
1152 fn test_wasm_number_to_string() -> Result<()> {
1153 let host = run_wasm_program(r#"
1154 p(42).
1155 q(R) :- p(X) |> let R = fn:number:to_string(X).
1156 "#)?;
1157 let results = host.get_string_facts("q");
1158 assert_eq!(results.len(), 1);
1159 assert_eq!(results[0][0], "42");
1160 Ok(())
1161 }
1162
1163 #[test]
1166 fn test_wasm_list_construction() -> Result<()> {
1167 let host = run_wasm_program(r#"
1168 p(1, 2, 3).
1169 q(L) :- p(A, B, C) |> let L = fn:list(A, B, C).
1170 "#)?;
1171 let results = host.get_val_facts("q");
1172 assert_eq!(results.len(), 1);
1173 if let Val::Compound(kind, elems) = &results[0][0] {
1174 assert_eq!(*kind, 0); assert_eq!(elems.len(), 3);
1176 } else {
1177 panic!("Expected compound, got {:?}", results[0][0]);
1178 }
1179 Ok(())
1180 }
1181
1182 #[test]
1183 fn test_wasm_pair_construction_and_access() -> Result<()> {
1184 let host = run_wasm_program(r#"
1185 p(10, 20).
1186 mid(P) :- p(A, B) |> let P = fn:pair(A, B).
1187 q(F) :- mid(P) |> let F = fn:pair:first(P).
1188 r(S) :- mid(P) |> let S = fn:pair:second(P).
1189 "#)?;
1190 let q_results = host.get_number_facts("q");
1191 let r_results = host.get_number_facts("r");
1192 assert_eq!(q_results.len(), 1);
1193 assert_eq!(q_results[0], vec![10]);
1194 assert_eq!(r_results.len(), 1);
1195 assert_eq!(r_results[0], vec![20]);
1196 Ok(())
1197 }
1198
1199 #[test]
1200 fn test_wasm_list_construction_and_len() -> Result<()> {
1201 let host = run_wasm_program(r#"
1202 p(10, 20, 30).
1203 mid(L) :- p(A, B, C) |> let L = fn:list(A, B, C).
1204 q(N) :- mid(L) |> let N = fn:len(L).
1205 "#)?;
1206 let results = host.get_number_facts("q");
1207 assert_eq!(results.len(), 1);
1208 assert_eq!(results[0], vec![3]);
1209 Ok(())
1210 }
1211
1212 #[test]
1213 fn test_wasm_list_get() -> Result<()> {
1214 let host = run_wasm_program(r#"
1215 p(10, 20, 30).
1216 mid(L) :- p(A, B, C) |> let L = fn:list(A, B, C).
1217 q(E) :- mid(L) |> let E = fn:list:get(L, 1).
1218 "#)?;
1219 let results = host.get_number_facts("q");
1220 assert_eq!(results.len(), 1);
1221 assert_eq!(results[0], vec![20]);
1222 Ok(())
1223 }
1224
1225 #[test]
1226 fn test_wasm_struct_construction_and_get() -> Result<()> {
1227 let host = run_wasm_program(r#"
1228 p("alice", 30).
1229 mid(S) :- p(Name, Age) |> let S = fn:struct(/name, Name, /age, Age).
1230 q(V) :- mid(S) |> let V = fn:struct:get(S, /name).
1231 "#)?;
1232 let results = host.get_string_facts("q");
1233 assert_eq!(results.len(), 1);
1234 assert_eq!(results[0][0], "alice");
1235 Ok(())
1236 }
1237}