1use std::{
19 collections::{HashMap, HashSet},
20 ffi::CString,
21 fmt::Debug,
22 ops::Deref,
23 sync::Arc,
24};
25
26use anyhow::{Context, Result};
27use tokio::sync::{Mutex, OnceCell};
28use tracing::Instrument;
29use wasmtime::{AsContextMut, Caller, Linker, Memory, MemoryType, Module};
30
31use crate::{
32 builtins::traits::Builtin,
33 funcs::{self, Func},
34 types::{AbiVersion, Addr, BuiltinId, EntrypointId, Heap, NulStr, Value},
35 DefaultContext, EvaluationContext,
36};
37
38async fn alloc_str<V: Into<Vec<u8>>, T: Send>(
40 opa_malloc: &funcs::OpaMalloc,
41 mut store: impl AsContextMut<Data = T>,
42 memory: &Memory,
43 value: V,
44) -> Result<Heap> {
45 let value = CString::new(value)?;
46 let value = value.as_bytes_with_nul();
47 let heap = opa_malloc.call(&mut store, value.len()).await?;
48
49 memory.write(
50 &mut store,
51 heap.ptr
52 .try_into()
53 .context("opa_malloc returned an invalid pointer value")?,
54 value,
55 )?;
56
57 Ok(heap)
58}
59
60async fn load_json<V: serde::Serialize, T: Send>(
62 opa_malloc: &funcs::OpaMalloc,
63 opa_free: &funcs::OpaFree,
64 opa_json_parse: &funcs::OpaJsonParse,
65 mut store: impl AsContextMut<Data = T>,
66 memory: &Memory,
67 data: &V,
68) -> Result<Value> {
69 let json = serde_json::to_vec(data)?;
70 let json = alloc_str(opa_malloc, &mut store, memory, json).await?;
71 let data = opa_json_parse.call(&mut store, &json).await?;
72 opa_free.call(&mut store, json).await?;
73 Ok(data)
74}
75
/// The builtins required by a module, resolved to their implementations,
/// together with the evaluation context they are called with.
struct LoadedBuiltins<C> {
    /// Builtin name and implementation, keyed by the numeric builtin ID
    builtins: HashMap<i32, (String, Box<dyn Builtin<C>>)>,

    /// Evaluation context handed to builtins; it is locked whenever a builtin
    /// is called or an evaluation starts
    context: Mutex<C>,
}
85
86impl<C> std::fmt::Debug for LoadedBuiltins<C> {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 f.debug_struct("LoadedBuiltins")
89 .field("builtins", &())
90 .finish()
91 }
92}
93
94impl<C> LoadedBuiltins<C>
95where
96 C: EvaluationContext,
97{
98 fn from_map(map: HashMap<String, BuiltinId>, context: C) -> Result<Self> {
100 let res: Result<_> = map
101 .into_iter()
102 .map(|(k, v)| {
103 let builtin = crate::builtins::resolve(&k)?;
104 Ok((v.0, (k, builtin)))
105 })
106 .collect();
107 Ok(Self {
108 builtins: res?,
109 context: Mutex::new(context),
110 })
111 }
112
113 async fn builtin<T: Send, const N: usize>(
115 &self,
116 mut caller: Caller<'_, T>,
117 memory: &Memory,
118 builtin_id: i32,
119 args: [i32; N],
120 ) -> Result<i32, anyhow::Error> {
121 let (name, builtin) = self
122 .builtins
123 .get(&builtin_id)
124 .with_context(|| format!("unknown builtin id {builtin_id}"))?;
125
126 let span = tracing::info_span!("builtin", %name);
127 let _enter = span.enter();
128
129 let opa_json_dump = funcs::OpaJsonDump::from_caller(&mut caller)?;
130 let opa_json_parse = funcs::OpaJsonParse::from_caller(&mut caller)?;
131 let opa_malloc = funcs::OpaMalloc::from_caller(&mut caller)?;
132 let opa_free = funcs::OpaFree::from_caller(&mut caller)?;
133
134 let mut args_json = Vec::with_capacity(N);
136 for arg in args {
137 args_json.push(opa_json_dump.call(&mut caller, &Value(arg)).await?);
138 }
139
140 let mut mapped_args = Vec::with_capacity(N);
142 for arg_json in args_json {
143 let arg = arg_json.read(&caller, memory)?;
144 mapped_args.push(arg.to_bytes());
145 }
146
147 let mut ctx = self.context.lock().await;
148
149 let ret = (async move { builtin.call(&mut ctx, &mapped_args).await })
151 .instrument(tracing::info_span!("builtin.call"))
152 .await?;
153
154 let json = alloc_str(&opa_malloc, &mut caller, memory, ret).await?;
155 let data = opa_json_parse.call(&mut caller, &json).await?;
156 opa_free.call(&mut caller, json).await?;
157
158 Ok(data.0)
159 }
160
161 async fn evaluation_start(&self) {
164 self.context.lock().await.evaluation_start();
165 }
166}
167
#[allow(clippy::missing_docs_in_private_items)]
/// An instantiated policy module, together with handles to the functions it
/// exports.
///
/// Built with [`Runtime::new`] or [`Runtime::new_with_evaluation_context`].
pub struct Runtime<C> {
    // ABI version read from the instance
    version: AbiVersion,
    // Memory created by the host and linked into the module as `env.memory`
    memory: Memory,
    // Entrypoint names mapped to their IDs
    entrypoints: HashMap<String, EntrypointId>,
    // Shared with the `opa_builtinN` host functions; set once, after the
    // module has been instantiated
    loaded_builtins: Arc<OnceCell<LoadedBuiltins<C>>>,

    // Handles to the functions looked up on the instance
    eval_func: funcs::Eval,
    opa_eval_ctx_new_func: funcs::OpaEvalCtxNew,
    opa_eval_ctx_set_input_func: funcs::OpaEvalCtxSetInput,
    opa_eval_ctx_set_data_func: funcs::OpaEvalCtxSetData,
    opa_eval_ctx_set_entrypoint_func: funcs::OpaEvalCtxSetEntrypoint,
    opa_eval_ctx_get_result_func: funcs::OpaEvalCtxGetResult,
    opa_malloc_func: funcs::OpaMalloc,
    opa_free_func: funcs::OpaFree,
    opa_json_parse_func: funcs::OpaJsonParse,
    opa_json_dump_func: funcs::OpaJsonDump,
    opa_heap_ptr_set_func: funcs::OpaHeapPtrSet,
    opa_heap_ptr_get_func: funcs::OpaHeapPtrGet,
    // Only looked up when the ABI version reports the eval fast path
    opa_eval_func: Option<funcs::OpaEval>,
}
191
192impl<C> Debug for Runtime<C> {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 f.debug_struct("Runtime")
195 .field("version", &self.version)
196 .field("memory", &self.memory)
197 .field("entrypoints", &self.entrypoints)
198 .finish_non_exhaustive()
199 }
200}
201
202impl Runtime<DefaultContext> {
203 #[allow(clippy::too_many_lines)]
217 pub async fn new<T: Send + 'static>(
218 store: impl AsContextMut<Data = T>,
219 module: &Module,
220 ) -> Result<Self> {
221 let context = DefaultContext::default();
222 Self::new_with_evaluation_context(store, module, context).await
223 }
224}
225
226impl<C> Runtime<C> {
227 #[allow(clippy::too_many_lines)]
241 pub async fn new_with_evaluation_context<T: Send + 'static>(
242 mut store: impl AsContextMut<Data = T>,
243 module: &Module,
244 context: C,
245 ) -> Result<Self>
246 where
247 C: EvaluationContext,
248 {
249 let ty = MemoryType::new(2, None);
250 let memory = Memory::new_async(&mut store, ty).await?;
251
252 let eventually_builtins = Arc::new(OnceCell::<LoadedBuiltins<C>>::new());
254
255 let mut linker = Linker::new(store.as_context_mut().engine());
256 linker.define(&store, "env", "memory", memory)?;
257
258 linker.func_wrap(
259 "env",
260 "opa_abort",
261 move |caller: Caller<'_, _>, addr: i32| -> wasmtime::Result<()> {
262 let addr = NulStr(addr);
263 let msg = addr
264 .read(&caller, &memory)
265 .map_err(wasmtime::Error::from_anyhow)?;
266 let msg = msg.to_string_lossy().into_owned();
267 tracing::error!("opa_abort: {}", msg);
268 Err(wasmtime::Error::msg(msg))
269 },
270 )?;
271
272 linker.func_wrap(
273 "env",
274 "opa_println",
275 move |caller: Caller<'_, _>, addr: i32| -> wasmtime::Result<()> {
276 let addr = NulStr(addr);
277 let msg = addr
278 .read(&caller, &memory)
279 .map_err(wasmtime::Error::from_anyhow)?;
280 tracing::info!("opa_print: {}", msg.to_string_lossy());
281 Ok(())
282 },
283 )?;
284
285 {
286 let eventually_builtins = eventually_builtins.clone();
287 linker.func_wrap_async(
288 "env",
289 "opa_builtin0",
290 move |caller: Caller<'_, _>, (builtin_id, _ctx): (i32, i32)| {
291 let eventually_builtins = eventually_builtins.clone();
292 Box::new(async move {
293 eventually_builtins
294 .get()
295 .context("builtins where never initialized")
296 .map_err(wasmtime::Error::from_anyhow)?
297 .builtin(caller, &memory, builtin_id, [])
298 .await
299 .map_err(wasmtime::Error::from_anyhow)
300 })
301 },
302 )?;
303 }
304
305 {
306 let eventually_builtins = eventually_builtins.clone();
307 linker.func_wrap_async(
308 "env",
309 "opa_builtin1",
310 move |caller: Caller<'_, _>, (builtin_id, _ctx, param1): (i32, i32, i32)| {
311 let eventually_builtins = eventually_builtins.clone();
312 Box::new(async move {
313 eventually_builtins
314 .get()
315 .context("builtins where never initialized")
316 .map_err(wasmtime::Error::from_anyhow)?
317 .builtin(caller, &memory, builtin_id, [param1])
318 .await
319 .map_err(wasmtime::Error::from_anyhow)
320 })
321 },
322 )?;
323 }
324
325 {
326 let eventually_builtins = eventually_builtins.clone();
327 linker.func_wrap_async(
328 "env",
329 "opa_builtin2",
330 move |caller: Caller<'_, _>,
331 (builtin_id, _ctx, param1, param2): (i32, i32, i32, i32)| {
332 let eventually_builtins = eventually_builtins.clone();
333 Box::new(async move {
334 eventually_builtins
335 .get()
336 .context("builtins where never initialized")
337 .map_err(wasmtime::Error::from_anyhow)?
338 .builtin(caller, &memory, builtin_id, [param1, param2])
339 .await
340 .map_err(wasmtime::Error::from_anyhow)
341 })
342 },
343 )?;
344 }
345
346 {
347 let eventually_builtins = eventually_builtins.clone();
348 linker.func_wrap_async(
349 "env",
350 "opa_builtin3",
351 move |caller: Caller<'_, _>,
352 (builtin_id,
353 _ctx,
354 param1,
355 param2,
356 param3): (i32, i32, i32, i32, i32)| {
357 let eventually_builtins = eventually_builtins.clone();
358 Box::new(async move {
359 eventually_builtins
360 .get()
361 .context("builtins where never initialized")
362 .map_err(wasmtime::Error::from_anyhow)?
363 .builtin(caller, &memory, builtin_id, [param1, param2, param3])
364 .await
365 .map_err(wasmtime::Error::from_anyhow)
366 })
367 },
368 )?;
369 }
370
371 {
372 let eventually_builtins = eventually_builtins.clone();
373 linker.func_wrap_async(
374 "env",
375 "opa_builtin4",
376 move |caller: Caller<'_, _>,
377 (builtin_id, _ctx, param1, param2, param3, param4): (
378 i32,
379 i32,
380 i32,
381 i32,
382 i32,
383 i32,
384 )| {
385 let eventually_builtins = eventually_builtins.clone();
386 Box::new(async move {
387 eventually_builtins
388 .get()
389 .context("builtins where never initialized")
390 .map_err(wasmtime::Error::from_anyhow)?
391 .builtin(
392 caller,
393 &memory,
394 builtin_id,
395 [param1, param2, param3, param4],
396 )
397 .await
398 .map_err(wasmtime::Error::from_anyhow)
399 })
400 },
401 )?;
402 }
403
404 let instance = linker.instantiate_async(&mut store, module).await?;
405
406 let version = AbiVersion::from_instance(&mut store, &instance)?;
407 tracing::debug!(%version, "Module ABI version");
408
409 let opa_json_dump_func = funcs::OpaJsonDump::from_instance(&mut store, &instance)?;
410
411 let builtins = funcs::Builtins::from_instance(&mut store, &instance)?
413 .call(&mut store)
414 .await?;
415 let builtins = opa_json_dump_func
416 .decode(&mut store, &memory, &builtins)
417 .await?;
418 let builtins = LoadedBuiltins::from_map(builtins, context)?;
419 eventually_builtins.set(builtins)?;
420
421 let entrypoints = funcs::Entrypoints::from_instance(&mut store, &instance)?
423 .call(&mut store)
424 .await?;
425 let entrypoints = opa_json_dump_func
426 .decode(&mut store, &memory, &entrypoints)
427 .await?;
428
429 let opa_eval_func = version
430 .has_eval_fastpath()
431 .then(|| funcs::OpaEval::from_instance(&mut store, &instance))
432 .transpose()?;
433
434 Ok(Self {
435 version,
436 memory,
437 entrypoints,
438 loaded_builtins: eventually_builtins,
439
440 eval_func: funcs::Eval::from_instance(&mut store, &instance)?,
441 opa_eval_ctx_new_func: funcs::OpaEvalCtxNew::from_instance(&mut store, &instance)?,
442 opa_eval_ctx_set_input_func: funcs::OpaEvalCtxSetInput::from_instance(
443 &mut store, &instance,
444 )?,
445 opa_eval_ctx_set_data_func: funcs::OpaEvalCtxSetData::from_instance(
446 &mut store, &instance,
447 )?,
448 opa_eval_ctx_set_entrypoint_func: funcs::OpaEvalCtxSetEntrypoint::from_instance(
449 &mut store, &instance,
450 )?,
451 opa_eval_ctx_get_result_func: funcs::OpaEvalCtxGetResult::from_instance(
452 &mut store, &instance,
453 )?,
454 opa_malloc_func: funcs::OpaMalloc::from_instance(&mut store, &instance)?,
455 opa_free_func: funcs::OpaFree::from_instance(&mut store, &instance)?,
456 opa_json_parse_func: funcs::OpaJsonParse::from_instance(&mut store, &instance)?,
457 opa_json_dump_func,
458 opa_heap_ptr_set_func: funcs::OpaHeapPtrSet::from_instance(&mut store, &instance)?,
459 opa_heap_ptr_get_func: funcs::OpaHeapPtrGet::from_instance(&mut store, &instance)?,
460 opa_eval_func,
461 })
462 }
463
464 async fn load_json<V: serde::Serialize, T: Send>(
466 &self,
467 store: impl AsContextMut<Data = T>,
468 data: &V,
469 ) -> Result<Value> {
470 load_json(
471 &self.opa_malloc_func,
472 &self.opa_free_func,
473 &self.opa_json_parse_func,
474 store,
475 &self.memory,
476 data,
477 )
478 .await
479 }
480
481 pub async fn without_data<T: Send>(
487 self,
488 store: impl AsContextMut<Data = T>,
489 ) -> Result<Policy<C>> {
490 let data = serde_json::Value::Object(serde_json::Map::default());
491 self.with_data(store, &data).await
492 }
493
494 pub async fn with_data<V: serde::Serialize, T: Send>(
500 self,
501 mut store: impl AsContextMut<Data = T>,
502 data: &V,
503 ) -> Result<Policy<C>> {
504 let data = self.load_json(&mut store, data).await?;
505 let heap_ptr = self.opa_heap_ptr_get_func.call(&mut store).await?;
506 Ok(Policy {
507 runtime: self,
508 data,
509 heap_ptr,
510 })
511 }
512
513 #[must_use]
516 pub fn default_entrypoint(&self) -> Option<&str> {
517 self.entrypoints
518 .iter()
519 .find_map(|(k, v)| (v.0 == 0).then_some(k.as_str()))
520 }
521
522 #[must_use]
524 pub fn entrypoints(&self) -> HashSet<&str> {
525 self.entrypoints.keys().map(String::as_str).collect()
526 }
527
528 #[must_use]
530 pub fn abi_version(&self) -> AbiVersion {
531 self.version
532 }
533}
534
/// A [`Runtime`] with its data document loaded, ready to evaluate
/// entrypoints.
///
/// Built with [`Runtime::with_data`] or [`Runtime::without_data`].
#[derive(Debug)]
pub struct Policy<C> {
    /// The runtime this policy was created from
    runtime: Runtime<C>,

    /// The data document, as loaded into the module
    data: Value,

    /// The heap pointer read right after the data was loaded
    heap_ptr: Addr,
}
547
548impl<C> Policy<C> {
549 pub async fn evaluate<V: serde::Serialize, R: for<'de> serde::Deserialize<'de>, T: Send>(
556 &self,
557 mut store: impl AsContextMut<Data = T>,
558 entrypoint: &str,
559 input: &V,
560 ) -> Result<R>
561 where
562 C: EvaluationContext,
563 {
564 let entrypoint = self
566 .runtime
567 .entrypoints
568 .get(entrypoint)
569 .with_context(|| format!("could not find entrypoint {entrypoint}"))?;
570
571 self.loaded_builtins
572 .get()
573 .context("builtins where never initialized")?
574 .evaluation_start()
575 .await;
576
577 if let Some(opa_eval) = &self.runtime.opa_eval_func {
579 let input = serde_json::to_vec(&input)?;
581 let input_heap = Heap {
582 ptr: self.heap_ptr.0,
583 len: input.len().try_into().context("input too long")?,
584 freed: true,
586 };
587
588 let current_pages = self.runtime.memory.size(&store);
590 let needed_pages = input_heap.pages();
591 if current_pages < needed_pages {
592 self.runtime
593 .memory
594 .grow_async(&mut store, needed_pages - current_pages)
595 .await?;
596 }
597
598 self.runtime.memory.write(
600 &mut store,
601 input_heap.ptr.try_into().context("invalid heap pointer")?,
602 &input[..],
603 )?;
604
605 let heap_ptr = Addr(input_heap.end());
606
607 let result = opa_eval
609 .call(&mut store, entrypoint, &self.data, &input_heap, &heap_ptr)
610 .await?;
611
612 let result = result.read(&store, &self.runtime.memory)?;
614 let result = serde_json::from_slice(result.to_bytes())?;
615 Ok(result)
616 } else {
617 self.runtime
619 .opa_heap_ptr_set_func
620 .call(&mut store, &self.heap_ptr)
621 .await?;
622
623 let input = self.runtime.load_json(&mut store, input).await?;
625
626 let ctx = self.runtime.opa_eval_ctx_new_func.call(&mut store).await?;
628
629 self.runtime
631 .opa_eval_ctx_set_data_func
632 .call(&mut store, &ctx, &self.data)
633 .await?;
634 self.runtime
636 .opa_eval_ctx_set_input_func
637 .call(&mut store, &ctx, &input)
638 .await?;
639
640 self.runtime
642 .opa_eval_ctx_set_entrypoint_func
643 .call(&mut store, &ctx, entrypoint)
644 .await?;
645
646 self.runtime.eval_func.call(&mut store, &ctx).await?;
648
649 let result = self
651 .runtime
652 .opa_eval_ctx_get_result_func
653 .call(&mut store, &ctx)
654 .await?;
655
656 let result = self
657 .runtime
658 .opa_json_dump_func
659 .decode(&mut store, &self.runtime.memory, &result)
660 .await?;
661
662 Ok(result)
663 }
664 }
665}
666
667impl<C> Deref for Policy<C> {
668 type Target = Runtime<C>;
669 fn deref(&self) -> &Self::Target {
670 &self.runtime
671 }
672}