#![allow(clippy::cast_possible_truncation)]
use crate::PolicyDecision;
use anyhow::anyhow;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{
borrow::Cow, collections::HashMap, io::copy, mem::ManuallyDrop, string::String, sync::Arc,
};
use wasmtime::{Caller, Engine, Instance, Linker, Memory, MemoryType, Module, Store};
/// Boxed, thread-safe callback that receives a string message; used for the
/// abort and println handlers stored in [`OpaBuilder`].
type StrHandler = Box<dyn Fn(&str) + Send + Sync>;
/// Collects the settings used to instantiate an [`Opa`] policy.
///
/// Every field starts at its `Default` value and can be overridden through the
/// chained setter methods before one of the `build*` methods is called.
#[derive(Default)]
pub struct OpaBuilder {
/// Handler for the `opa_abort` host import; a panicking default is used when `None`.
abort_cb: Option<StrHandler>,
/// Handler for the `opa_println` host import; a default that prints to stdout is used when `None`.
println_cb: Option<StrHandler>,
/// Optional maximum, in pages, passed to the type of the memory handed to the module.
buffer_max_mem_pages: Option<u32>,
/// Engine used to load the module and to create the linker and store.
engine: Engine,
}
impl OpaBuilder {
#[must_use]
pub fn on_abort<F>(mut self, f: F) -> Self
where
F: Fn(&str) + Send + Sync + 'static,
{
self.abort_cb = Some(Box::new(f));
self
}
#[must_use]
pub fn on_println<F>(mut self, f: F) -> Self
where
F: Fn(&str) + Send + Sync + 'static,
{
self.println_cb = Some(Box::new(f));
self
}
#[must_use]
pub fn max_memory_pages(mut self, opt: impl Into<Option<u32>>) -> Self {
self.buffer_max_mem_pages = opt.into();
self
}
#[must_use]
pub fn with_engine(mut self, engine: Engine) -> Self {
self.engine = engine;
self
}
/// Builds an [`Opa`] instance from a policy bundle.
///
/// With the `wasmtime-aot` feature enabled, a precompiled module stored in
/// `bundle.wasmtime_bytes` is preferred. Otherwise, with the
/// `wasmtime-cranelift` feature enabled, the first entry of
/// `bundle.wasm_policies` is compiled from its WASM bytes.
///
/// # Errors
///
/// Returns an error when the bundle has no usable module for the enabled
/// features, or when loading or instantiating the module fails.
#[cfg(feature = "bundle")]
pub fn build_from_bundle(self, bundle: &crate::bundle::Bundle) -> Result<Opa, anyhow::Error> {
    #[cfg(feature = "wasmtime-aot")]
    {
        if let Some(precompiled) = &bundle.wasmtime_bytes {
            // SAFETY: NOTE(review): `Module::deserialize` trusts its input. This
            // relies on `wasmtime_bytes` being an artifact produced by wasmtime for
            // a compatible engine, which is not verified here — only load bundles
            // from trusted sources.
            let module = unsafe { Module::deserialize(&self.engine, precompiled)? };
            return self.build_module(module);
        }
    }
    #[cfg(feature = "wasmtime-cranelift")]
    {
        let policy = bundle
            .wasm_policies
            .first()
            .ok_or_else(|| anyhow!("the bundle must contain at least one WASM module"))?;
        return self.build(&policy.bytes);
    }
    // Only reachable when no compiler backend is enabled and the bundle has
    // no precompiled module.
    #[allow(unreachable_code)]
    Err(anyhow!("the bundle contains no precompiled WASM binary."))
}
/// Compiles `wasm_bytes` with the builder's engine and instantiates the
/// resulting module.
///
/// # Errors
///
/// Returns an error when the binary cannot be compiled or the module
/// cannot be instantiated and initialised.
#[cfg(feature = "wasmtime-cranelift")]
pub fn build(self, wasm_bytes: impl AsRef<[u8]>) -> Result<Opa, anyhow::Error> {
    let binary: &[u8] = wasm_bytes.as_ref();
    let compiled = Module::from_binary(&self.engine, binary)?;
    self.build_module(compiled)
}
#[allow(clippy::needless_pass_by_value)]
fn build_module(self, module: Module) -> Result<Opa, anyhow::Error> {
let engine = self.engine;
let mut linker = Linker::<()>::new(&engine);
let mut store = Store::new(&engine, ());
let env_buffer = Memory::new(&mut store, MemoryType::new(2, self.buffer_max_mem_pages))?;
let on_abort = Arc::<Box<dyn Fn(&str) + Send + Sync>>::from(
self.abort_cb.unwrap_or_else(|| Box::new(default_opa_abort)),
);
let on_abort1 = on_abort.clone();
let on_println = self
.println_cb
.unwrap_or_else(|| Box::new(default_opa_println));
linker.define("env", "memory", env_buffer)?;
linker.func_wrap(
"env",
"opa_abort",
move |caller: Caller<'_, ()>, addr: u32| {
let addr = addr as usize;
let mem = env_buffer.data(&caller);
let s = null_terminated_str(&mem[addr..]).unwrap_or("invalid string in memory");
on_abort1(s);
},
)?;
linker.func_wrap(
"env",
"opa_println",
move |caller: Caller<'_, ()>, addr: u32| {
let addr = addr as usize;
let mem = env_buffer.data(&caller);
match null_terminated_str(&mem[addr..]) {
Some(s) => on_println(s),
None => on_abort("invalid string in memory"),
}
},
)?;
linker.func_wrap("env", "opa_builtin0", move |_id: u32, _ctx: u32| 0_u32)?;
linker.func_wrap(
"env",
"opa_builtin1",
move |_id: u32, _ctx: u32, _1: u32| 0_u32,
)?;
linker.func_wrap(
"env",
"opa_builtin2",
move |_id: u32, _ctx: u32, _1: u32, _2: u32| 0_u32,
)?;
linker.func_wrap(
"env",
"opa_builtin3",
move |_id: u32, _ctx: u32, _1: u32, _2: u32, _3: u32| 0_u32,
)?;
linker.func_wrap(
"env",
"opa_builtin4",
move |_id: u32, _ctx: u32, _1: u32, _2: u32, _3: u32, _4: u32| 0_u32,
)?;
let instance = linker.instantiate(&mut store, &module)?;
env_buffer.data(&mut store);
let mut opa = Opa {
store,
instance,
env_buffer,
entrypoints: HashMap::default(),
data_heap_ptr: Addr(0),
data_addr: None,
input_heap_ptr: Addr(0),
minor_version: 0,
};
opa.init()?;
Ok(opa)
}
}
/// An instantiated policy module together with the bookkeeping needed to
/// evaluate it.
///
/// Created through [`Opa::new`] and the [`OpaBuilder`] it returns.
#[derive(Debug)]
pub struct Opa {
/// Store the instance and its memory live in.
store: Store<()>,
/// The instantiated policy module.
instance: Instance,
/// Memory supplied to the module as the `env.memory` import.
env_buffer: Memory,
/// Value of the module's `opa_wasm_abi_minor_version` global, or `0` when
/// it is missing or not a non-negative `i32`.
minor_version: usize,
/// Entrypoint name to id table decoded from the module's `entrypoints` export.
entrypoints: HashMap<String, u32>,
/// Heap pointer captured at initialisation; `set_data` rewinds to it
/// before writing new data.
data_heap_ptr: Addr,
/// Address of the parsed data document; `None` until `set_data` succeeded.
data_addr: Option<Addr>,
/// Heap pointer recorded after the data was written; evaluation input is
/// placed starting here.
input_heap_ptr: Addr,
}
impl Opa {
/// Starts building an [`Opa`] instance by returning a default-configured
/// [`OpaBuilder`].
///
/// Despite its name this does not return `Self`; finish with one of the
/// builder's `build*` methods.
#[allow(clippy::new_ret_no_self)]
#[must_use]
pub fn new() -> OpaBuilder {
OpaBuilder::default()
}
/// Iterates over the names of the entrypoints known to this policy, in
/// the (unspecified) order of the underlying hash map.
pub fn entrypoints(&self) -> impl Iterator<Item = &str> {
self.entrypoints.keys().map(String::as_str)
}
/// Replaces the data document the policy is evaluated against.
///
/// The heap is rewound to the position captured at initialisation, `data`
/// is serialised and parsed into the policy's memory, and the heap pointer
/// that results is remembered as the place where evaluation input starts.
///
/// # Errors
///
/// Returns an error when serialising `data` fails or one of the calls into
/// the policy module fails.
pub fn set_data(&mut self, data: &impl Serialize) -> Result<(), anyhow::Error> {
    let data_base = self.data_heap_ptr;
    self.set_heap_ptr(data_base)?;

    let parsed_data = self.write_json(data)?;
    self.data_addr = Some(parsed_data);

    let after_data = self.heap_ptr()?;
    self.input_heap_ptr = after_data;
    Ok(())
}
/// Evaluates `entrypoint` against `input` and deserialises the result.
///
/// Modules whose ABI minor version is at least 2 are evaluated with a
/// single `opa_eval` call; older modules go through a temporary
/// [`EvalContext`] that is destroyed before returning.
///
/// # Errors
///
/// Returns an error when no data has been set, the entrypoint is unknown,
/// (de)serialisation fails, or a call into the policy module fails.
pub fn eval<I, O>(&mut self, entrypoint: &str, input: &I) -> Result<O, anyhow::Error>
where
    I: Serialize,
    O: DeserializeOwned,
{
    if self.minor_version >= 2 {
        self.eval_once(entrypoint, input)
    } else {
        let mut context = EvalContext::create(self, input)?;
        let value: O = context.eval(entrypoint)?;
        context.destroy()?;
        Ok(value)
    }
}
/// Creates a reusable evaluation context holding `input`.
///
/// The context borrows this instance mutably for its whole lifetime and
/// can evaluate several entrypoints against the same input.
///
/// # Errors
///
/// Returns an error when no data has been set, serialising `input` fails,
/// or a call into the policy module fails.
pub fn eval_context<'c>(
    &'c mut self,
    input: &impl Serialize,
) -> Result<EvalContext<'c>, anyhow::Error> {
    let context = EvalContext::create(self, input)?;
    Ok(context)
}
/// Evaluates the policy decision `P`: its `POLICY_PATH` is used as the
/// entrypoint and the result is deserialised into `P::Output`.
///
/// # Errors
///
/// Fails under the same conditions as [`Opa::eval`].
pub fn decide<P: PolicyDecision>(
    &mut self,
    input: &P::Input,
) -> Result<P::Output, anyhow::Error> {
    self.eval::<P::Input, P::Output>(P::POLICY_PATH, input)
}
}
impl Opa {
fn init(&mut self) -> Result<(), anyhow::Error> {
self.data_heap_ptr = self.heap_ptr()?;
self.input_heap_ptr = self.data_heap_ptr;
let opa_entrypoints = self
.instance
.get_typed_func::<(), u32, _>(&mut self.store, "entrypoints")?;
let ep_addr = opa_entrypoints.call(&mut self.store, ())?;
self.entrypoints = self.json_at(ep_addr.into())?;
self.minor_version = self
.instance
.get_global(&mut self.store, "opa_wasm_abi_minor_version")
.and_then(|global| global.get(&mut self.store).i32())
.and_then(|int| int.try_into().ok())
.unwrap_or(0);
Ok(())
}
fn bytes_at(&self, addr: Addr) -> Option<&[u8]> {
let data = self.env_buffer.data(&self.store);
null_terminated_slice(&data[addr.into()..])
}
#[allow(dead_code)]
fn str_at(&self, addr: usize) -> Option<&str> {
let data = self.env_buffer.data(&self.store);
null_terminated_str(&data[addr..])
}
fn json_at<T: DeserializeOwned>(&mut self, addr: Addr) -> Result<T, anyhow::Error> {
let opa_json_dump = self
.instance
.get_typed_func::<(u32,), u32, _>(&mut self.store, "opa_json_dump")?;
let json_addr: Addr = opa_json_dump.call(&mut self.store, (addr.into(),))?.into();
let json_result = serde_json::from_slice::<T>(self.bytes_at(json_addr).unwrap());
Ok(json_result?)
}
fn write_json(&mut self, value: &impl Serialize) -> Result<Addr, anyhow::Error> {
let opa_json_parse = self
.instance
.get_typed_func::<(u32, u32), u32, _>(&mut self.store, "opa_json_parse")?;
let json = serde_json::to_vec(value)?;
let json_size = json.len();
let json_bytes_addr = self.write_bytes(json)?;
let parsed_json_addr =
opa_json_parse.call(&mut self.store, (json_bytes_addr.into(), json_size as _))?;
Ok(parsed_json_addr.into())
}
/// Allocates room for `bytes` inside the policy's memory, copies them
/// there and returns the address of the copy.
fn write_bytes(&mut self, bytes: impl AsRef<[u8]>) -> Result<Addr, anyhow::Error> {
    let source = bytes.as_ref();
    let (addr, destination) = self.alloc(source.len())?;
    // `alloc` hands back a slice of exactly `source.len()` bytes.
    destination.copy_from_slice(source);
    Ok(addr)
}
fn alloc(&mut self, len: usize) -> Result<(Addr, &mut [u8]), anyhow::Error> {
let opa_malloc = self
.instance
.get_typed_func::<(u32,), u32, _>(&mut self.store, "opa_malloc")?;
let addr = opa_malloc.call(&mut self.store, (len as _,))?;
let data =
&mut self.env_buffer.data_mut(&mut self.store)[addr as usize..addr as usize + len];
Ok((addr.into(), data))
}
/// Releases the block at `addr` through the module's `opa_free` export.
#[allow(dead_code)]
fn free(&mut self, addr: Addr) -> Result<(), anyhow::Error> {
    let release = self
        .instance
        .get_typed_func::<(u32,), (), _>(&mut self.store, "opa_free")?;
    let raw_addr = u32::from(addr);
    release.call(&mut self.store, (raw_addr,))?;
    Ok(())
}
/// Evaluates `entrypoint` with a single call to the module's `opa_eval`
/// export.
///
/// The serialised input is written directly at the input heap position
/// (growing the memory first when it does not fit). `opa_eval` is then
/// called with the parsed data address, the input's address and length,
/// and the heap pointer just past the input. The NUL-terminated JSON at
/// the returned address is parsed as a list of `{ "result": … }` objects
/// and the last element's result is returned. The heap pointer is rewound
/// to the input position afterwards.
///
/// # Errors
///
/// Returns an error when no data has been set, the input is too large for
/// the 32-bit address space, the entrypoint is unknown, (de)serialisation
/// fails, the evaluation yields no result, or a call into the module
/// fails.
fn eval_once<I, O>(&mut self, entrypoint: &str, input: &I) -> Result<O, anyhow::Error>
where
    I: Serialize,
    O: DeserializeOwned,
{
    let opa_eval = self.instance.get_typed_func::<(
        u32, // reserved
        u32, // entrypoint_id
        u32, // parsed data addr
        u32, // input json str addr
        u32, // input length
        u32, // heap_ptr
        u32, // format
    ), u32, _>(&mut self.store, "opa_eval")?;
    let data_addr = self.data_addr.ok_or_else(|| {
        anyhow!("no data provided, `set_data` must be called at least once first")
    })?;
    let input_bytes = serde_json::to_vec(input)?;
    let input_idx = self.input_heap_ptr.0 as usize;
    let input_len = u32::try_from(input_bytes.len())
        .map_err(|err| anyhow::anyhow!("input data is too large: {err}"))?;
    // Heap pointer handed to the module: the first byte after the input.
    // Checked so that an oversized input is an error rather than a wrapped
    // (or, in debug builds, panicking) addition.
    let heap_after_input = self
        .input_heap_ptr
        .0
        .checked_add(input_len)
        .ok_or_else(|| anyhow!("input data does not fit in the policy's address space"))?;
    let input_end = heap_after_input as usize;

    // Grow the memory by whole pages when the input does not fit yet.
    let memory_size = self.env_buffer.data_size(&mut self.store);
    if memory_size < input_end {
        let delta = round_up(input_end - memory_size);
        self.env_buffer.grow(&mut self.store, delta as _)?;
    }
    let data = self.env_buffer.data_mut(&mut self.store);
    copy(&mut &*input_bytes, &mut &mut data[input_idx..])?;

    let entrypoint = self.entrypoint_id(entrypoint)?;
    let out_addr = opa_eval.call(
        &mut self.store,
        (
            0,
            entrypoint,
            data_addr.into(),
            self.input_heap_ptr.0,
            input_len,
            heap_after_input,
            0,
        ),
    )?;
    let mut out: Vec<OpaOutput<O>> = serde_json::from_slice(
        self.bytes_at(Addr(out_addr))
            .ok_or_else(|| anyhow::anyhow!("invalid output returned from evaluation"))?,
    )?;
    // Everything the evaluation allocated is discarded again.
    self.set_heap_ptr(self.input_heap_ptr)?;
    out.pop()
        .map(|v| v.result)
        .ok_or_else(|| anyhow::anyhow!("evaluation yielded no result (it's most likely a bug)"))
}
fn heap_ptr(&mut self) -> Result<Addr, anyhow::Error> {
let opa_heap_ptr_get = self
.instance
.get_typed_func::<(), u32, _>(&mut self.store, "opa_heap_ptr_get")?;
Ok(Addr(opa_heap_ptr_get.call(&mut self.store, ())?))
}
/// Moves the heap pointer to `addr` through the module's
/// `opa_heap_ptr_set` export.
fn set_heap_ptr(&mut self, addr: Addr) -> Result<(), anyhow::Error> {
    let setter = self
        .instance
        .get_typed_func::<(u32,), (), _>(&mut self.store, "opa_heap_ptr_set")?;
    let raw_addr = u32::from(addr);
    setter.call(&mut self.store, (raw_addr,))?;
    Ok(())
}
/// Looks up the numeric id of `entrypoint`.
///
/// Dots in the name are treated as path separators and replaced by `/`
/// before the lookup, so `a.b` and `a/b` resolve to the same entry.
///
/// # Errors
///
/// Returns an error when the (normalised) name is not in the entrypoint
/// table.
fn entrypoint_id(&mut self, entrypoint: &str) -> Result<u32, anyhow::Error> {
    // Only allocate when there is actually something to replace.
    let key: Cow<'_, str> = if !entrypoint.contains('.') {
        Cow::Borrowed(entrypoint)
    } else {
        Cow::Owned(entrypoint.replace('.', "/"))
    };
    match self.entrypoints.get(&*key) {
        Some(id) => Ok(*id),
        None => Err(anyhow!("invalid entrypoint `{}`", key)),
    }
}
}
/// An evaluation context created inside a policy module, holding the data
/// and the input for one or more evaluations.
///
/// Mutably borrows the [`Opa`] instance for its whole lifetime. Dropping it
/// (or calling `destroy`) rewinds the heap to the start of the input area.
pub struct EvalContext<'c> {
/// The policy instance the context lives in.
opa: &'c mut Opa,
/// Address returned by the module's `opa_eval_ctx_new` export.
ctx_addr: Addr,
}
impl<'c> EvalContext<'c> {
fn create(opa: &'c mut Opa, input: &impl Serialize) -> Result<Self, anyhow::Error> {
let opa_eval_ctx_new = opa
.instance
.get_typed_func::<(), u32, _>(&mut opa.store, "opa_eval_ctx_new")?;
let opa_eval_ctx_set_input = opa
.instance
.get_typed_func::<(u32, u32), (), _>(&mut opa.store, "opa_eval_ctx_set_input")?;
let opa_eval_ctx_set_data = opa
.instance
.get_typed_func::<(u32, u32), (), _>(&mut opa.store, "opa_eval_ctx_set_data")?;
opa.set_heap_ptr(opa.input_heap_ptr)?;
let data_addr = opa.data_addr.ok_or_else(|| {
anyhow!("no data provided, `set_data` must be called at least once first")
})?;
let input_addr = opa.write_json(input)?;
let ctx_addr = opa_eval_ctx_new.call(&mut opa.store, ())?;
opa_eval_ctx_set_data.call(&mut opa.store, (ctx_addr, data_addr.into()))?;
opa_eval_ctx_set_input.call(&mut opa.store, (ctx_addr, input_addr.into()))?;
Ok(EvalContext {
opa,
ctx_addr: ctx_addr.into(),
})
}
pub fn eval<O>(&mut self, entrypoint: &str) -> Result<O, anyhow::Error>
where
O: DeserializeOwned,
{
let opa_eval_ctx_set_entrypoint = self.opa.instance.get_typed_func::<(u32, u32), (), _>(
&mut self.opa.store,
"opa_eval_ctx_set_entrypoint",
)?;
let opa_eval_ctx_get_result = self
.opa
.instance
.get_typed_func::<(u32,), u32, _>(&mut self.opa.store, "opa_eval_ctx_get_result")?;
let opa_eval = self
.opa
.instance
.get_typed_func::<(u32,), u32, _>(&mut self.opa.store, "eval")?;
let entrypoint_id = self.opa.entrypoint_id(entrypoint)?;
opa_eval_ctx_set_entrypoint
.call(&mut self.opa.store, (self.ctx_addr.into(), entrypoint_id))?;
let start_heap = self.opa.heap_ptr()?;
opa_eval.call(&mut self.opa.store, (self.ctx_addr.into(),))?;
let result_addr =
opa_eval_ctx_get_result.call(&mut self.opa.store, (self.ctx_addr.into(),))?;
let result: Result<Vec<OpaOutput<O>>, _> = self.opa.json_at(result_addr.into());
self.opa.set_heap_ptr(start_heap)?;
result.and_then(|mut out| {
out.pop()
.map(|r| r.result)
.ok_or_else(|| anyhow!("the query produced no results"))
})
}
pub fn destroy(mut self) -> Result<(), anyhow::Error> {
self.destroy_mut()?;
let _ = ManuallyDrop::new(self);
Ok(())
}
/// Rewinds the policy heap to the start of the input area, discarding
/// whatever was allocated for this context.
fn destroy_mut(&mut self) -> Result<(), anyhow::Error> {
    let rewind_to = self.opa.input_heap_ptr;
    self.opa.set_heap_ptr(rewind_to)
}
}
impl Drop for EvalContext<'_> {
    /// Rewinds the policy heap when the context goes out of scope.
    ///
    /// # Panics
    ///
    /// Panics when the cleanup fails, unless the thread is already
    /// unwinding (a second panic would abort the process).
    fn drop(&mut self) {
        match self.destroy_mut() {
            Err(err) if !std::thread::panicking() => panic!("{err:?}"),
            _ => {}
        }
    }
}
/// A 32-bit address (byte offset) inside the policy module's memory.
#[derive(Debug, Clone, Copy)]
#[repr(transparent)]
struct Addr(u32);

/// Wraps a raw 32-bit address.
impl From<u32> for Addr {
    fn from(raw: u32) -> Self {
        Addr(raw)
    }
}

/// Converts a host-side offset; values above `u32::MAX` are truncated.
impl From<usize> for Addr {
    fn from(offset: usize) -> Self {
        Addr(offset as u32)
    }
}

/// Unwraps the raw 32-bit address.
impl From<Addr> for u32 {
    fn from(Addr(raw): Addr) -> Self {
        raw
    }
}

/// Converts the address into a host-side offset usable for slicing.
impl From<Addr> for usize {
    fn from(Addr(raw): Addr) -> Self {
        raw as usize
    }
}
/// Returns the prefix of `slice` that precedes its first NUL byte, or
/// `None` when `slice` contains no NUL byte.
fn null_terminated_slice(slice: &[u8]) -> Option<&[u8]> {
    let terminator = slice.iter().position(|&byte| byte == 0)?;
    Some(&slice[..terminator])
}
/// Returns the UTF-8 string that precedes the first NUL byte of `slice`.
///
/// Yields `None` when `slice` contains no NUL byte or when the bytes before
/// it are not valid UTF-8.
fn null_terminated_str(slice: &[u8]) -> Option<&str> {
    let terminator = slice.iter().position(|&byte| byte == 0)?;
    std::str::from_utf8(&slice[..terminator]).ok()
}
/// Abort handler used when none was configured: panics with the message
/// reported by the policy.
fn default_opa_abort(error: &str) {
    panic!("OPA abort was called: {}", error);
}
/// Print handler used when none was configured: writes the value, followed
/// by a newline, to standard output.
fn default_opa_println(value: &str) {
    println!("{}", value);
}
/// One element of the JSON list produced by an evaluation; only its
/// `result` member is read.
#[derive(Deserialize)]
struct OpaOutput<R> {
/// The value the evaluated entrypoint produced.
result: R,
}
/// Returns the number of 64 KiB pages needed to hold `bytes` bytes,
/// rounding a partially used page up to a whole one.
fn round_up(bytes: usize) -> usize {
    const PAGE_SIZE: usize = 64 * 1024;
    let full_pages = bytes / PAGE_SIZE;
    let remainder = bytes % PAGE_SIZE;
    if remainder == 0 {
        full_pages
    } else {
        full_pages + 1
    }
}