Skip to main content

opa_wasm/
policy.rs

1// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! The policy evaluation logic, which includes the [`Policy`] and [`Runtime`]
16//! structures.
17
18use std::{
19    collections::{HashMap, HashSet},
20    ffi::CString,
21    fmt::Debug,
22    ops::Deref,
23    sync::Arc,
24};
25
26use anyhow::{Context, Result};
27use tokio::sync::{Mutex, OnceCell};
28use tracing::Instrument;
29use wasmtime::{AsContextMut, Caller, Linker, Memory, Module};
30
31use crate::{
32    builtins::traits::Builtin,
33    funcs::{self, Func},
34    types::{AbiVersion, Addr, BuiltinId, EntrypointId, Heap, NulStr, Value},
35    DefaultContext, EvaluationContext,
36};
37
38/// Utility to allocate a string in the Wasm memory and return a pointer to it.
39async fn alloc_str<V: Into<Vec<u8>>, T: Send>(
40    opa_malloc: &funcs::OpaMalloc,
41    mut store: impl AsContextMut<Data = T>,
42    memory: &Memory,
43    value: V,
44) -> Result<Heap> {
45    let value = CString::new(value)?;
46    let value = value.as_bytes_with_nul();
47    let heap = opa_malloc.call(&mut store, value.len()).await?;
48
49    memory.write(
50        &mut store,
51        heap.ptr
52            .try_into()
53            .context("opa_malloc returned an invalid pointer value")?,
54        value,
55    )?;
56
57    Ok(heap)
58}
59
60/// Utility to load a JSON value into the WASM memory.
61async fn load_json<V: serde::Serialize, T: Send>(
62    opa_malloc: &funcs::OpaMalloc,
63    opa_free: &funcs::OpaFree,
64    opa_json_parse: &funcs::OpaJsonParse,
65    mut store: impl AsContextMut<Data = T>,
66    memory: &Memory,
67    data: &V,
68) -> Result<Value> {
69    let json = serde_json::to_vec(data)?;
70    let json = alloc_str(opa_malloc, &mut store, memory, json).await?;
71    let data = opa_json_parse.call(&mut store, &json).await?;
72    opa_free.call(&mut store, json).await?;
73    Ok(data)
74}
75
76/// A structure which holds the builtins referenced by the policy.
77struct LoadedBuiltins<C> {
78    /// A map of builtin IDs to the name and the builtin itself.
79    builtins: HashMap<i32, (String, Box<dyn Builtin<C>>)>,
80
81    /// The inner [`EvaluationContext`] which will be passed when calling
82    /// some builtins
83    context: Mutex<C>,
84}
85
86impl<C> std::fmt::Debug for LoadedBuiltins<C> {
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        f.debug_struct("LoadedBuiltins")
89            .field("builtins", &())
90            .finish()
91    }
92}
93
94impl<C> LoadedBuiltins<C>
95where
96    C: EvaluationContext,
97{
98    /// Resolve the builtins from a map of builtin IDs to their names.
99    fn from_map(map: HashMap<String, BuiltinId>, context: C) -> Result<Self> {
100        let res: Result<_> = map
101            .into_iter()
102            .map(|(k, v)| {
103                let builtin = crate::builtins::resolve(&k)?;
104                Ok((v.0, (k, builtin)))
105            })
106            .collect();
107        Ok(Self {
108            builtins: res?,
109            context: Mutex::new(context),
110        })
111    }
112
113    /// Call the given builtin given its ID and arguments.
114    async fn builtin<T: Send, const N: usize>(
115        &self,
116        mut caller: Caller<'_, T>,
117        memory: &Memory,
118        builtin_id: i32,
119        args: [i32; N],
120    ) -> Result<i32, anyhow::Error> {
121        let (name, builtin) = self
122            .builtins
123            .get(&builtin_id)
124            .with_context(|| format!("unknown builtin id {builtin_id}"))?;
125
126        let span = tracing::info_span!("builtin", %name);
127        let _enter = span.enter();
128
129        let opa_json_dump = funcs::OpaJsonDump::from_caller(&mut caller)?;
130        let opa_json_parse = funcs::OpaJsonParse::from_caller(&mut caller)?;
131        let opa_malloc = funcs::OpaMalloc::from_caller(&mut caller)?;
132        let opa_free = funcs::OpaFree::from_caller(&mut caller)?;
133
134        // Call opa_json_dump on each argument
135        let mut args_json = Vec::with_capacity(N);
136        for arg in args {
137            args_json.push(opa_json_dump.call(&mut caller, &Value(arg)).await?);
138        }
139
140        // Extract the JSON value of each argument
141        let mut mapped_args = Vec::with_capacity(N);
142        for arg_json in args_json {
143            let arg = arg_json.read(&caller, memory)?;
144            mapped_args.push(arg.to_bytes());
145        }
146
147        let mut ctx = self.context.lock().await;
148
149        // Actually call the function
150        let ret = (async move { builtin.call(&mut ctx, &mapped_args).await })
151            .instrument(tracing::info_span!("builtin.call"))
152            .await?;
153
154        let json = alloc_str(&opa_malloc, &mut caller, memory, ret).await?;
155        let data = opa_json_parse.call(&mut caller, &json).await?;
156        opa_free.call(&mut caller, json).await?;
157
158        Ok(data.0)
159    }
160
161    /// Called when the policy evaluation starts, to reset the context and
162    /// record the evaluation starting time
163    async fn evaluation_start(&self) {
164        self.context.lock().await.evaluation_start();
165    }
166}
167
168/// An instance of a policy with builtins and entrypoints resolved, but with no
169/// data provided yet
170#[allow(clippy::missing_docs_in_private_items)]
171pub struct Runtime<C> {
172    version: AbiVersion,
173    memory: Memory,
174    entrypoints: HashMap<String, EntrypointId>,
175    loaded_builtins: Arc<OnceCell<LoadedBuiltins<C>>>,
176
177    eval_func: funcs::Eval,
178    opa_eval_ctx_new_func: funcs::OpaEvalCtxNew,
179    opa_eval_ctx_set_input_func: funcs::OpaEvalCtxSetInput,
180    opa_eval_ctx_set_data_func: funcs::OpaEvalCtxSetData,
181    opa_eval_ctx_set_entrypoint_func: funcs::OpaEvalCtxSetEntrypoint,
182    opa_eval_ctx_get_result_func: funcs::OpaEvalCtxGetResult,
183    opa_malloc_func: funcs::OpaMalloc,
184    opa_free_func: funcs::OpaFree,
185    opa_json_parse_func: funcs::OpaJsonParse,
186    opa_json_dump_func: funcs::OpaJsonDump,
187    opa_heap_ptr_set_func: funcs::OpaHeapPtrSet,
188    opa_heap_ptr_get_func: funcs::OpaHeapPtrGet,
189    opa_eval_func: Option<funcs::OpaEval>,
190}
191
192impl<C> Debug for Runtime<C> {
193    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        f.debug_struct("Runtime")
195            .field("version", &self.version)
196            .field("memory", &self.memory)
197            .field("entrypoints", &self.entrypoints)
198            .finish_non_exhaustive()
199    }
200}
201
202impl Runtime<DefaultContext> {
203    /// Load a new WASM policy module into the given store, with the default
204    /// evaluation context.
205    ///
206    /// # Errors
207    ///
208    /// It will raise an error if one of the following condition is met:
209    ///
210    ///  - the provided [`wasmtime::Store`] isn't an async one
211    ///  - the [`wasmtime::Module`] was created with a different
212    ///    [`wasmtime::Engine`] than the [`wasmtime::Store`]
213    ///  - the WASM module is not a valid OPA WASM compiled policy, and lacks
214    ///    some of the exported functions
215    ///  - it failed to load the entrypoints or the builtins list
216    #[allow(clippy::too_many_lines)]
217    pub async fn new<T: Send + 'static>(
218        store: impl AsContextMut<Data = T>,
219        module: &Module,
220    ) -> Result<Self> {
221        let context = DefaultContext::default();
222        Self::new_with_evaluation_context(store, module, context).await
223    }
224}
225
226impl<C> Runtime<C> {
227    /// Load a new WASM policy module into the given store, with a given
228    /// evaluation context.
229    ///
230    /// # Errors
231    ///
232    /// It will raise an error if one of the following condition is met:
233    ///
234    ///  - the provided [`wasmtime::Store`] isn't an async one
235    ///  - the [`wasmtime::Module`] was created with a different
236    ///    [`wasmtime::Engine`] than the [`wasmtime::Store`]
237    ///  - the WASM module is not a valid OPA WASM compiled policy, and lacks
238    ///    some of the exported functions
239    ///  - it failed to load the entrypoints or the builtins list
240    #[allow(clippy::too_many_lines)]
241    pub async fn new_with_evaluation_context<T: Send + 'static>(
242        mut store: impl AsContextMut<Data = T>,
243        module: &Module,
244        context: C,
245    ) -> Result<Self>
246    where
247        C: EvaluationContext,
248    {
249        // Read the memory type from the module imports, and create a memory
250        // instance that matches whatever OPA's compiler declared (min, max,
251        // 32/64-bit, shared).
252        let ty = module
253            .imports()
254            .find(|i| i.module() == "env" && i.name() == "memory")
255            .and_then(|i| i.ty().memory().cloned())
256            .context("missing or invalid 'env.memory' import")?;
257        let memory = Memory::new_async(&mut store, ty).await?;
258
259        // TODO: make the context configurable and reset it on evaluation
260        let eventually_builtins = Arc::new(OnceCell::<LoadedBuiltins<C>>::new());
261
262        let mut linker = Linker::new(store.as_context_mut().engine());
263        linker.define(&store, "env", "memory", memory)?;
264
265        linker.func_wrap(
266            "env",
267            "opa_abort",
268            move |caller: Caller<'_, _>, addr: i32| -> wasmtime::Result<()> {
269                let addr = NulStr(addr);
270                let msg = addr
271                    .read(&caller, &memory)
272                    .map_err(wasmtime::Error::from_anyhow)?;
273                let msg = msg.to_string_lossy().into_owned();
274                tracing::error!("opa_abort: {}", msg);
275                Err(wasmtime::Error::msg(msg))
276            },
277        )?;
278
279        linker.func_wrap(
280            "env",
281            "opa_println",
282            move |caller: Caller<'_, _>, addr: i32| -> wasmtime::Result<()> {
283                let addr = NulStr(addr);
284                let msg = addr
285                    .read(&caller, &memory)
286                    .map_err(wasmtime::Error::from_anyhow)?;
287                tracing::info!("opa_print: {}", msg.to_string_lossy());
288                Ok(())
289            },
290        )?;
291
292        {
293            let eventually_builtins = eventually_builtins.clone();
294            linker.func_wrap_async(
295                "env",
296                "opa_builtin0",
297                move |caller: Caller<'_, _>, (builtin_id, _ctx): (i32, i32)| {
298                    let eventually_builtins = eventually_builtins.clone();
299                    Box::new(async move {
300                        eventually_builtins
301                            .get()
302                            .context("builtins where never initialized")
303                            .map_err(wasmtime::Error::from_anyhow)?
304                            .builtin(caller, &memory, builtin_id, [])
305                            .await
306                            .map_err(wasmtime::Error::from_anyhow)
307                    })
308                },
309            )?;
310        }
311
312        {
313            let eventually_builtins = eventually_builtins.clone();
314            linker.func_wrap_async(
315                "env",
316                "opa_builtin1",
317                move |caller: Caller<'_, _>, (builtin_id, _ctx, param1): (i32, i32, i32)| {
318                    let eventually_builtins = eventually_builtins.clone();
319                    Box::new(async move {
320                        eventually_builtins
321                            .get()
322                            .context("builtins where never initialized")
323                            .map_err(wasmtime::Error::from_anyhow)?
324                            .builtin(caller, &memory, builtin_id, [param1])
325                            .await
326                            .map_err(wasmtime::Error::from_anyhow)
327                    })
328                },
329            )?;
330        }
331
332        {
333            let eventually_builtins = eventually_builtins.clone();
334            linker.func_wrap_async(
335                "env",
336                "opa_builtin2",
337                move |caller: Caller<'_, _>,
338                      (builtin_id, _ctx, param1, param2): (i32, i32, i32, i32)| {
339                    let eventually_builtins = eventually_builtins.clone();
340                    Box::new(async move {
341                        eventually_builtins
342                            .get()
343                            .context("builtins where never initialized")
344                            .map_err(wasmtime::Error::from_anyhow)?
345                            .builtin(caller, &memory, builtin_id, [param1, param2])
346                            .await
347                            .map_err(wasmtime::Error::from_anyhow)
348                    })
349                },
350            )?;
351        }
352
353        {
354            let eventually_builtins = eventually_builtins.clone();
355            linker.func_wrap_async(
356                "env",
357                "opa_builtin3",
358                move |caller: Caller<'_, _>,
359                      (builtin_id,
360                      _ctx,
361                      param1,
362                      param2,
363                      param3): (i32, i32, i32, i32, i32)| {
364                    let eventually_builtins = eventually_builtins.clone();
365                    Box::new(async move {
366                        eventually_builtins
367                            .get()
368                            .context("builtins where never initialized")
369                            .map_err(wasmtime::Error::from_anyhow)?
370                            .builtin(caller, &memory, builtin_id, [param1, param2, param3])
371                            .await
372                            .map_err(wasmtime::Error::from_anyhow)
373                    })
374                },
375            )?;
376        }
377
378        {
379            let eventually_builtins = eventually_builtins.clone();
380            linker.func_wrap_async(
381                "env",
382                "opa_builtin4",
383                move |caller: Caller<'_, _>,
384                      (builtin_id, _ctx, param1, param2, param3, param4): (
385                    i32,
386                    i32,
387                    i32,
388                    i32,
389                    i32,
390                    i32,
391                )| {
392                    let eventually_builtins = eventually_builtins.clone();
393                    Box::new(async move {
394                        eventually_builtins
395                            .get()
396                            .context("builtins where never initialized")
397                            .map_err(wasmtime::Error::from_anyhow)?
398                            .builtin(
399                                caller,
400                                &memory,
401                                builtin_id,
402                                [param1, param2, param3, param4],
403                            )
404                            .await
405                            .map_err(wasmtime::Error::from_anyhow)
406                    })
407                },
408            )?;
409        }
410
411        let instance = linker.instantiate_async(&mut store, module).await?;
412
413        let version = AbiVersion::from_instance(&mut store, &instance)?;
414        tracing::debug!(%version, "Module ABI version");
415
416        let opa_json_dump_func = funcs::OpaJsonDump::from_instance(&mut store, &instance)?;
417
418        // Load the builtins map
419        let builtins = funcs::Builtins::from_instance(&mut store, &instance)?
420            .call(&mut store)
421            .await?;
422        let builtins = opa_json_dump_func
423            .decode(&mut store, &memory, &builtins)
424            .await?;
425        let builtins = LoadedBuiltins::from_map(builtins, context)?;
426        eventually_builtins.set(builtins)?;
427
428        // Load the entrypoints map
429        let entrypoints = funcs::Entrypoints::from_instance(&mut store, &instance)?
430            .call(&mut store)
431            .await?;
432        let entrypoints = opa_json_dump_func
433            .decode(&mut store, &memory, &entrypoints)
434            .await?;
435
436        let opa_eval_func = version
437            .has_eval_fastpath()
438            .then(|| funcs::OpaEval::from_instance(&mut store, &instance))
439            .transpose()?;
440
441        Ok(Self {
442            version,
443            memory,
444            entrypoints,
445            loaded_builtins: eventually_builtins,
446
447            eval_func: funcs::Eval::from_instance(&mut store, &instance)?,
448            opa_eval_ctx_new_func: funcs::OpaEvalCtxNew::from_instance(&mut store, &instance)?,
449            opa_eval_ctx_set_input_func: funcs::OpaEvalCtxSetInput::from_instance(
450                &mut store, &instance,
451            )?,
452            opa_eval_ctx_set_data_func: funcs::OpaEvalCtxSetData::from_instance(
453                &mut store, &instance,
454            )?,
455            opa_eval_ctx_set_entrypoint_func: funcs::OpaEvalCtxSetEntrypoint::from_instance(
456                &mut store, &instance,
457            )?,
458            opa_eval_ctx_get_result_func: funcs::OpaEvalCtxGetResult::from_instance(
459                &mut store, &instance,
460            )?,
461            opa_malloc_func: funcs::OpaMalloc::from_instance(&mut store, &instance)?,
462            opa_free_func: funcs::OpaFree::from_instance(&mut store, &instance)?,
463            opa_json_parse_func: funcs::OpaJsonParse::from_instance(&mut store, &instance)?,
464            opa_json_dump_func,
465            opa_heap_ptr_set_func: funcs::OpaHeapPtrSet::from_instance(&mut store, &instance)?,
466            opa_heap_ptr_get_func: funcs::OpaHeapPtrGet::from_instance(&mut store, &instance)?,
467            opa_eval_func,
468        })
469    }
470
471    /// Load a JSON value into the WASM memory
472    async fn load_json<V: serde::Serialize, T: Send>(
473        &self,
474        store: impl AsContextMut<Data = T>,
475        data: &V,
476    ) -> Result<Value> {
477        load_json(
478            &self.opa_malloc_func,
479            &self.opa_free_func,
480            &self.opa_json_parse_func,
481            store,
482            &self.memory,
483            data,
484        )
485        .await
486    }
487
488    /// Instanciate the policy with an empty `data` object
489    ///
490    /// # Errors
491    ///
492    /// If it failed to load the empty data object in memory
493    pub async fn without_data<T: Send>(
494        self,
495        store: impl AsContextMut<Data = T>,
496    ) -> Result<Policy<C>> {
497        let data = serde_json::Value::Object(serde_json::Map::default());
498        self.with_data(store, &data).await
499    }
500
501    /// Instanciate the policy with the given `data` object
502    ///
503    /// # Errors
504    ///
505    /// If it failed to serialize and load the `data` object
506    pub async fn with_data<V: serde::Serialize, T: Send>(
507        self,
508        mut store: impl AsContextMut<Data = T>,
509        data: &V,
510    ) -> Result<Policy<C>> {
511        let data = self.load_json(&mut store, data).await?;
512        let heap_ptr = self.opa_heap_ptr_get_func.call(&mut store).await?;
513        Ok(Policy {
514            runtime: self,
515            data,
516            heap_ptr,
517        })
518    }
519
520    /// Get the default entrypoint of this module. May return [`None`] if no
521    /// entrypoint with ID 0 was found
522    #[must_use]
523    pub fn default_entrypoint(&self) -> Option<&str> {
524        self.entrypoints
525            .iter()
526            .find_map(|(k, v)| (v.0 == 0).then_some(k.as_str()))
527    }
528
529    /// Get the list of entrypoints found in this module.
530    #[must_use]
531    pub fn entrypoints(&self) -> HashSet<&str> {
532        self.entrypoints.keys().map(String::as_str).collect()
533    }
534
535    /// Get the ABI version detected for this module
536    #[must_use]
537    pub fn abi_version(&self) -> AbiVersion {
538        self.version
539    }
540}
541
542/// An instance of a policy, ready to be executed
543#[derive(Debug)]
544pub struct Policy<C> {
545    /// The runtime this policy instance belongs to
546    runtime: Runtime<C>,
547
548    /// The data object loaded for this policy
549    data: Value,
550
551    /// A pointer to the heap, used for efficient allocations
552    heap_ptr: Addr,
553}
554
555impl<C> Policy<C> {
556    /// Evaluate a policy with the given entrypoint and input.
557    ///
558    /// # Errors
559    ///
560    /// Returns an error if the policy evaluation failed, or if this policy did
561    /// not belong to the given store.
562    pub async fn evaluate<V: serde::Serialize, R: for<'de> serde::Deserialize<'de>, T: Send>(
563        &self,
564        mut store: impl AsContextMut<Data = T>,
565        entrypoint: &str,
566        input: &V,
567    ) -> Result<R>
568    where
569        C: EvaluationContext,
570    {
571        // Lookup the entrypoint
572        let entrypoint = self
573            .runtime
574            .entrypoints
575            .get(entrypoint)
576            .with_context(|| format!("could not find entrypoint {entrypoint}"))?;
577
578        self.loaded_builtins
579            .get()
580            .context("builtins where never initialized")?
581            .evaluation_start()
582            .await;
583
584        // Take the fast path if it is awailable
585        if let Some(opa_eval) = &self.runtime.opa_eval_func {
586            // Write the input
587            let input = serde_json::to_vec(&input)?;
588            let input_heap = Heap {
589                ptr: self.heap_ptr.0,
590                len: input.len().try_into().context("input too long")?,
591                // Not managed by a malloc
592                freed: true,
593            };
594
595            // Check if we need to grow the memory first
596            let current_pages = self.runtime.memory.size(&store);
597            let needed_pages = input_heap.pages();
598            if current_pages < needed_pages {
599                self.runtime
600                    .memory
601                    .grow_async(&mut store, needed_pages - current_pages)
602                    .await?;
603            }
604
605            // Write the JSON input to memory
606            self.runtime.memory.write(
607                &mut store,
608                input_heap.ptr.try_into().context("invalid heap pointer")?,
609                &input[..],
610            )?;
611
612            let heap_ptr = Addr(input_heap.end());
613
614            // Call the eval fast-path
615            let result = opa_eval
616                .call(&mut store, entrypoint, &self.data, &input_heap, &heap_ptr)
617                .await?;
618
619            // Read back the JSON-formatted result
620            let result = result.read(&store, &self.runtime.memory)?;
621            let result = serde_json::from_slice(result.to_bytes())?;
622            Ok(result)
623        } else {
624            // Reset the heap pointer
625            self.runtime
626                .opa_heap_ptr_set_func
627                .call(&mut store, &self.heap_ptr)
628                .await?;
629
630            // Load the input
631            let input = self.runtime.load_json(&mut store, input).await?;
632
633            // Create a new evaluation context
634            let ctx = self.runtime.opa_eval_ctx_new_func.call(&mut store).await?;
635
636            // Set the data location
637            self.runtime
638                .opa_eval_ctx_set_data_func
639                .call(&mut store, &ctx, &self.data)
640                .await?;
641            // Set the input location
642            self.runtime
643                .opa_eval_ctx_set_input_func
644                .call(&mut store, &ctx, &input)
645                .await?;
646
647            // Set the entrypoint
648            self.runtime
649                .opa_eval_ctx_set_entrypoint_func
650                .call(&mut store, &ctx, entrypoint)
651                .await?;
652
653            // Evaluate the policy
654            self.runtime.eval_func.call(&mut store, &ctx).await?;
655
656            // Get the results back
657            let result = self
658                .runtime
659                .opa_eval_ctx_get_result_func
660                .call(&mut store, &ctx)
661                .await?;
662
663            let result = self
664                .runtime
665                .opa_json_dump_func
666                .decode(&mut store, &self.runtime.memory, &result)
667                .await?;
668
669            Ok(result)
670        }
671    }
672}
673
674impl<C> Deref for Policy<C> {
675    type Target = Runtime<C>;
676    fn deref(&self) -> &Self::Target {
677        &self.runtime
678    }
679}