Skip to main content

opa_wasm/
policy.rs

1// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! The policy evaluation logic, which includes the [`Policy`] and [`Runtime`]
16//! structures.
17
18use std::{
19    collections::{HashMap, HashSet},
20    ffi::CString,
21    fmt::Debug,
22    ops::Deref,
23    sync::Arc,
24};
25
26use anyhow::{Context, Result};
27use tokio::sync::{Mutex, OnceCell};
28use tracing::Instrument;
29use wasmtime::{AsContextMut, Caller, Linker, Memory, MemoryType, Module};
30
31use crate::{
32    builtins::traits::Builtin,
33    funcs::{self, Func},
34    types::{AbiVersion, Addr, BuiltinId, EntrypointId, Heap, NulStr, Value},
35    DefaultContext, EvaluationContext,
36};
37
38/// Utility to allocate a string in the Wasm memory and return a pointer to it.
39async fn alloc_str<V: Into<Vec<u8>>, T: Send>(
40    opa_malloc: &funcs::OpaMalloc,
41    mut store: impl AsContextMut<Data = T>,
42    memory: &Memory,
43    value: V,
44) -> Result<Heap> {
45    let value = CString::new(value)?;
46    let value = value.as_bytes_with_nul();
47    let heap = opa_malloc.call(&mut store, value.len()).await?;
48
49    memory.write(
50        &mut store,
51        heap.ptr
52            .try_into()
53            .context("opa_malloc returned an invalid pointer value")?,
54        value,
55    )?;
56
57    Ok(heap)
58}
59
60/// Utility to load a JSON value into the WASM memory.
61async fn load_json<V: serde::Serialize, T: Send>(
62    opa_malloc: &funcs::OpaMalloc,
63    opa_free: &funcs::OpaFree,
64    opa_json_parse: &funcs::OpaJsonParse,
65    mut store: impl AsContextMut<Data = T>,
66    memory: &Memory,
67    data: &V,
68) -> Result<Value> {
69    let json = serde_json::to_vec(data)?;
70    let json = alloc_str(opa_malloc, &mut store, memory, json).await?;
71    let data = opa_json_parse.call(&mut store, &json).await?;
72    opa_free.call(&mut store, json).await?;
73    Ok(data)
74}
75
76/// A structure which holds the builtins referenced by the policy.
77struct LoadedBuiltins<C> {
78    /// A map of builtin IDs to the name and the builtin itself.
79    builtins: HashMap<i32, (String, Box<dyn Builtin<C>>)>,
80
81    /// The inner [`EvaluationContext`] which will be passed when calling
82    /// some builtins
83    context: Mutex<C>,
84}
85
86impl<C> std::fmt::Debug for LoadedBuiltins<C> {
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        f.debug_struct("LoadedBuiltins")
89            .field("builtins", &())
90            .finish()
91    }
92}
93
94impl<C> LoadedBuiltins<C>
95where
96    C: EvaluationContext,
97{
98    /// Resolve the builtins from a map of builtin IDs to their names.
99    fn from_map(map: HashMap<String, BuiltinId>, context: C) -> Result<Self> {
100        let res: Result<_> = map
101            .into_iter()
102            .map(|(k, v)| {
103                let builtin = crate::builtins::resolve(&k)?;
104                Ok((v.0, (k, builtin)))
105            })
106            .collect();
107        Ok(Self {
108            builtins: res?,
109            context: Mutex::new(context),
110        })
111    }
112
113    /// Call the given builtin given its ID and arguments.
114    async fn builtin<T: Send, const N: usize>(
115        &self,
116        mut caller: Caller<'_, T>,
117        memory: &Memory,
118        builtin_id: i32,
119        args: [i32; N],
120    ) -> Result<i32, anyhow::Error> {
121        let (name, builtin) = self
122            .builtins
123            .get(&builtin_id)
124            .with_context(|| format!("unknown builtin id {builtin_id}"))?;
125
126        let span = tracing::info_span!("builtin", %name);
127        let _enter = span.enter();
128
129        let opa_json_dump = funcs::OpaJsonDump::from_caller(&mut caller)?;
130        let opa_json_parse = funcs::OpaJsonParse::from_caller(&mut caller)?;
131        let opa_malloc = funcs::OpaMalloc::from_caller(&mut caller)?;
132        let opa_free = funcs::OpaFree::from_caller(&mut caller)?;
133
134        // Call opa_json_dump on each argument
135        let mut args_json = Vec::with_capacity(N);
136        for arg in args {
137            args_json.push(opa_json_dump.call(&mut caller, &Value(arg)).await?);
138        }
139
140        // Extract the JSON value of each argument
141        let mut mapped_args = Vec::with_capacity(N);
142        for arg_json in args_json {
143            let arg = arg_json.read(&caller, memory)?;
144            mapped_args.push(arg.to_bytes());
145        }
146
147        let mut ctx = self.context.lock().await;
148
149        // Actually call the function
150        let ret = (async move { builtin.call(&mut ctx, &mapped_args).await })
151            .instrument(tracing::info_span!("builtin.call"))
152            .await?;
153
154        let json = alloc_str(&opa_malloc, &mut caller, memory, ret).await?;
155        let data = opa_json_parse.call(&mut caller, &json).await?;
156        opa_free.call(&mut caller, json).await?;
157
158        Ok(data.0)
159    }
160
161    /// Called when the policy evaluation starts, to reset the context and
162    /// record the evaluation starting time
163    async fn evaluation_start(&self) {
164        self.context.lock().await.evaluation_start();
165    }
166}
167
168/// An instance of a policy with builtins and entrypoints resolved, but with no
169/// data provided yet
170#[allow(clippy::missing_docs_in_private_items)]
171pub struct Runtime<C> {
172    version: AbiVersion,
173    memory: Memory,
174    entrypoints: HashMap<String, EntrypointId>,
175    loaded_builtins: Arc<OnceCell<LoadedBuiltins<C>>>,
176
177    eval_func: funcs::Eval,
178    opa_eval_ctx_new_func: funcs::OpaEvalCtxNew,
179    opa_eval_ctx_set_input_func: funcs::OpaEvalCtxSetInput,
180    opa_eval_ctx_set_data_func: funcs::OpaEvalCtxSetData,
181    opa_eval_ctx_set_entrypoint_func: funcs::OpaEvalCtxSetEntrypoint,
182    opa_eval_ctx_get_result_func: funcs::OpaEvalCtxGetResult,
183    opa_malloc_func: funcs::OpaMalloc,
184    opa_free_func: funcs::OpaFree,
185    opa_json_parse_func: funcs::OpaJsonParse,
186    opa_json_dump_func: funcs::OpaJsonDump,
187    opa_heap_ptr_set_func: funcs::OpaHeapPtrSet,
188    opa_heap_ptr_get_func: funcs::OpaHeapPtrGet,
189    opa_eval_func: Option<funcs::OpaEval>,
190}
191
192impl<C> Debug for Runtime<C> {
193    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        f.debug_struct("Runtime")
195            .field("version", &self.version)
196            .field("memory", &self.memory)
197            .field("entrypoints", &self.entrypoints)
198            .finish_non_exhaustive()
199    }
200}
201
202impl Runtime<DefaultContext> {
203    /// Load a new WASM policy module into the given store, with the default
204    /// evaluation context.
205    ///
206    /// # Errors
207    ///
208    /// It will raise an error if one of the following condition is met:
209    ///
210    ///  - the provided [`wasmtime::Store`] isn't an async one
211    ///  - the [`wasmtime::Module`] was created with a different
212    ///    [`wasmtime::Engine`] than the [`wasmtime::Store`]
213    ///  - the WASM module is not a valid OPA WASM compiled policy, and lacks
214    ///    some of the exported functions
215    ///  - it failed to load the entrypoints or the builtins list
216    #[allow(clippy::too_many_lines)]
217    pub async fn new<T: Send + 'static>(
218        store: impl AsContextMut<Data = T>,
219        module: &Module,
220    ) -> Result<Self> {
221        let context = DefaultContext::default();
222        Self::new_with_evaluation_context(store, module, context).await
223    }
224}
225
226impl<C> Runtime<C> {
227    /// Load a new WASM policy module into the given store, with a given
228    /// evaluation context.
229    ///
230    /// # Errors
231    ///
232    /// It will raise an error if one of the following condition is met:
233    ///
234    ///  - the provided [`wasmtime::Store`] isn't an async one
235    ///  - the [`wasmtime::Module`] was created with a different
236    ///    [`wasmtime::Engine`] than the [`wasmtime::Store`]
237    ///  - the WASM module is not a valid OPA WASM compiled policy, and lacks
238    ///    some of the exported functions
239    ///  - it failed to load the entrypoints or the builtins list
240    #[allow(clippy::too_many_lines)]
241    pub async fn new_with_evaluation_context<T: Send + 'static>(
242        mut store: impl AsContextMut<Data = T>,
243        module: &Module,
244        context: C,
245    ) -> Result<Self>
246    where
247        C: EvaluationContext,
248    {
249        let ty = MemoryType::new(2, None);
250        let memory = Memory::new_async(&mut store, ty).await?;
251
252        // TODO: make the context configurable and reset it on evaluation
253        let eventually_builtins = Arc::new(OnceCell::<LoadedBuiltins<C>>::new());
254
255        let mut linker = Linker::new(store.as_context_mut().engine());
256        linker.define(&store, "env", "memory", memory)?;
257
258        linker.func_wrap(
259            "env",
260            "opa_abort",
261            move |caller: Caller<'_, _>, addr: i32| -> wasmtime::Result<()> {
262                let addr = NulStr(addr);
263                let msg = addr
264                    .read(&caller, &memory)
265                    .map_err(wasmtime::Error::from_anyhow)?;
266                let msg = msg.to_string_lossy().into_owned();
267                tracing::error!("opa_abort: {}", msg);
268                Err(wasmtime::Error::msg(msg))
269            },
270        )?;
271
272        linker.func_wrap(
273            "env",
274            "opa_println",
275            move |caller: Caller<'_, _>, addr: i32| -> wasmtime::Result<()> {
276                let addr = NulStr(addr);
277                let msg = addr
278                    .read(&caller, &memory)
279                    .map_err(wasmtime::Error::from_anyhow)?;
280                tracing::info!("opa_print: {}", msg.to_string_lossy());
281                Ok(())
282            },
283        )?;
284
285        {
286            let eventually_builtins = eventually_builtins.clone();
287            linker.func_wrap_async(
288                "env",
289                "opa_builtin0",
290                move |caller: Caller<'_, _>, (builtin_id, _ctx): (i32, i32)| {
291                    let eventually_builtins = eventually_builtins.clone();
292                    Box::new(async move {
293                        eventually_builtins
294                            .get()
295                            .context("builtins where never initialized")
296                            .map_err(wasmtime::Error::from_anyhow)?
297                            .builtin(caller, &memory, builtin_id, [])
298                            .await
299                            .map_err(wasmtime::Error::from_anyhow)
300                    })
301                },
302            )?;
303        }
304
305        {
306            let eventually_builtins = eventually_builtins.clone();
307            linker.func_wrap_async(
308                "env",
309                "opa_builtin1",
310                move |caller: Caller<'_, _>, (builtin_id, _ctx, param1): (i32, i32, i32)| {
311                    let eventually_builtins = eventually_builtins.clone();
312                    Box::new(async move {
313                        eventually_builtins
314                            .get()
315                            .context("builtins where never initialized")
316                            .map_err(wasmtime::Error::from_anyhow)?
317                            .builtin(caller, &memory, builtin_id, [param1])
318                            .await
319                            .map_err(wasmtime::Error::from_anyhow)
320                    })
321                },
322            )?;
323        }
324
325        {
326            let eventually_builtins = eventually_builtins.clone();
327            linker.func_wrap_async(
328                "env",
329                "opa_builtin2",
330                move |caller: Caller<'_, _>,
331                      (builtin_id, _ctx, param1, param2): (i32, i32, i32, i32)| {
332                    let eventually_builtins = eventually_builtins.clone();
333                    Box::new(async move {
334                        eventually_builtins
335                            .get()
336                            .context("builtins where never initialized")
337                            .map_err(wasmtime::Error::from_anyhow)?
338                            .builtin(caller, &memory, builtin_id, [param1, param2])
339                            .await
340                            .map_err(wasmtime::Error::from_anyhow)
341                    })
342                },
343            )?;
344        }
345
346        {
347            let eventually_builtins = eventually_builtins.clone();
348            linker.func_wrap_async(
349                "env",
350                "opa_builtin3",
351                move |caller: Caller<'_, _>,
352                      (builtin_id,
353                      _ctx,
354                      param1,
355                      param2,
356                      param3): (i32, i32, i32, i32, i32)| {
357                    let eventually_builtins = eventually_builtins.clone();
358                    Box::new(async move {
359                        eventually_builtins
360                            .get()
361                            .context("builtins where never initialized")
362                            .map_err(wasmtime::Error::from_anyhow)?
363                            .builtin(caller, &memory, builtin_id, [param1, param2, param3])
364                            .await
365                            .map_err(wasmtime::Error::from_anyhow)
366                    })
367                },
368            )?;
369        }
370
371        {
372            let eventually_builtins = eventually_builtins.clone();
373            linker.func_wrap_async(
374                "env",
375                "opa_builtin4",
376                move |caller: Caller<'_, _>,
377                      (builtin_id, _ctx, param1, param2, param3, param4): (
378                    i32,
379                    i32,
380                    i32,
381                    i32,
382                    i32,
383                    i32,
384                )| {
385                    let eventually_builtins = eventually_builtins.clone();
386                    Box::new(async move {
387                        eventually_builtins
388                            .get()
389                            .context("builtins where never initialized")
390                            .map_err(wasmtime::Error::from_anyhow)?
391                            .builtin(
392                                caller,
393                                &memory,
394                                builtin_id,
395                                [param1, param2, param3, param4],
396                            )
397                            .await
398                            .map_err(wasmtime::Error::from_anyhow)
399                    })
400                },
401            )?;
402        }
403
404        let instance = linker.instantiate_async(&mut store, module).await?;
405
406        let version = AbiVersion::from_instance(&mut store, &instance)?;
407        tracing::debug!(%version, "Module ABI version");
408
409        let opa_json_dump_func = funcs::OpaJsonDump::from_instance(&mut store, &instance)?;
410
411        // Load the builtins map
412        let builtins = funcs::Builtins::from_instance(&mut store, &instance)?
413            .call(&mut store)
414            .await?;
415        let builtins = opa_json_dump_func
416            .decode(&mut store, &memory, &builtins)
417            .await?;
418        let builtins = LoadedBuiltins::from_map(builtins, context)?;
419        eventually_builtins.set(builtins)?;
420
421        // Load the entrypoints map
422        let entrypoints = funcs::Entrypoints::from_instance(&mut store, &instance)?
423            .call(&mut store)
424            .await?;
425        let entrypoints = opa_json_dump_func
426            .decode(&mut store, &memory, &entrypoints)
427            .await?;
428
429        let opa_eval_func = version
430            .has_eval_fastpath()
431            .then(|| funcs::OpaEval::from_instance(&mut store, &instance))
432            .transpose()?;
433
434        Ok(Self {
435            version,
436            memory,
437            entrypoints,
438            loaded_builtins: eventually_builtins,
439
440            eval_func: funcs::Eval::from_instance(&mut store, &instance)?,
441            opa_eval_ctx_new_func: funcs::OpaEvalCtxNew::from_instance(&mut store, &instance)?,
442            opa_eval_ctx_set_input_func: funcs::OpaEvalCtxSetInput::from_instance(
443                &mut store, &instance,
444            )?,
445            opa_eval_ctx_set_data_func: funcs::OpaEvalCtxSetData::from_instance(
446                &mut store, &instance,
447            )?,
448            opa_eval_ctx_set_entrypoint_func: funcs::OpaEvalCtxSetEntrypoint::from_instance(
449                &mut store, &instance,
450            )?,
451            opa_eval_ctx_get_result_func: funcs::OpaEvalCtxGetResult::from_instance(
452                &mut store, &instance,
453            )?,
454            opa_malloc_func: funcs::OpaMalloc::from_instance(&mut store, &instance)?,
455            opa_free_func: funcs::OpaFree::from_instance(&mut store, &instance)?,
456            opa_json_parse_func: funcs::OpaJsonParse::from_instance(&mut store, &instance)?,
457            opa_json_dump_func,
458            opa_heap_ptr_set_func: funcs::OpaHeapPtrSet::from_instance(&mut store, &instance)?,
459            opa_heap_ptr_get_func: funcs::OpaHeapPtrGet::from_instance(&mut store, &instance)?,
460            opa_eval_func,
461        })
462    }
463
464    /// Load a JSON value into the WASM memory
465    async fn load_json<V: serde::Serialize, T: Send>(
466        &self,
467        store: impl AsContextMut<Data = T>,
468        data: &V,
469    ) -> Result<Value> {
470        load_json(
471            &self.opa_malloc_func,
472            &self.opa_free_func,
473            &self.opa_json_parse_func,
474            store,
475            &self.memory,
476            data,
477        )
478        .await
479    }
480
481    /// Instanciate the policy with an empty `data` object
482    ///
483    /// # Errors
484    ///
485    /// If it failed to load the empty data object in memory
486    pub async fn without_data<T: Send>(
487        self,
488        store: impl AsContextMut<Data = T>,
489    ) -> Result<Policy<C>> {
490        let data = serde_json::Value::Object(serde_json::Map::default());
491        self.with_data(store, &data).await
492    }
493
494    /// Instanciate the policy with the given `data` object
495    ///
496    /// # Errors
497    ///
498    /// If it failed to serialize and load the `data` object
499    pub async fn with_data<V: serde::Serialize, T: Send>(
500        self,
501        mut store: impl AsContextMut<Data = T>,
502        data: &V,
503    ) -> Result<Policy<C>> {
504        let data = self.load_json(&mut store, data).await?;
505        let heap_ptr = self.opa_heap_ptr_get_func.call(&mut store).await?;
506        Ok(Policy {
507            runtime: self,
508            data,
509            heap_ptr,
510        })
511    }
512
513    /// Get the default entrypoint of this module. May return [`None`] if no
514    /// entrypoint with ID 0 was found
515    #[must_use]
516    pub fn default_entrypoint(&self) -> Option<&str> {
517        self.entrypoints
518            .iter()
519            .find_map(|(k, v)| (v.0 == 0).then_some(k.as_str()))
520    }
521
522    /// Get the list of entrypoints found in this module.
523    #[must_use]
524    pub fn entrypoints(&self) -> HashSet<&str> {
525        self.entrypoints.keys().map(String::as_str).collect()
526    }
527
528    /// Get the ABI version detected for this module
529    #[must_use]
530    pub fn abi_version(&self) -> AbiVersion {
531        self.version
532    }
533}
534
535/// An instance of a policy, ready to be executed
536#[derive(Debug)]
537pub struct Policy<C> {
538    /// The runtime this policy instance belongs to
539    runtime: Runtime<C>,
540
541    /// The data object loaded for this policy
542    data: Value,
543
544    /// A pointer to the heap, used for efficient allocations
545    heap_ptr: Addr,
546}
547
548impl<C> Policy<C> {
549    /// Evaluate a policy with the given entrypoint and input.
550    ///
551    /// # Errors
552    ///
553    /// Returns an error if the policy evaluation failed, or if this policy did
554    /// not belong to the given store.
555    pub async fn evaluate<V: serde::Serialize, R: for<'de> serde::Deserialize<'de>, T: Send>(
556        &self,
557        mut store: impl AsContextMut<Data = T>,
558        entrypoint: &str,
559        input: &V,
560    ) -> Result<R>
561    where
562        C: EvaluationContext,
563    {
564        // Lookup the entrypoint
565        let entrypoint = self
566            .runtime
567            .entrypoints
568            .get(entrypoint)
569            .with_context(|| format!("could not find entrypoint {entrypoint}"))?;
570
571        self.loaded_builtins
572            .get()
573            .context("builtins where never initialized")?
574            .evaluation_start()
575            .await;
576
577        // Take the fast path if it is awailable
578        if let Some(opa_eval) = &self.runtime.opa_eval_func {
579            // Write the input
580            let input = serde_json::to_vec(&input)?;
581            let input_heap = Heap {
582                ptr: self.heap_ptr.0,
583                len: input.len().try_into().context("input too long")?,
584                // Not managed by a malloc
585                freed: true,
586            };
587
588            // Check if we need to grow the memory first
589            let current_pages = self.runtime.memory.size(&store);
590            let needed_pages = input_heap.pages();
591            if current_pages < needed_pages {
592                self.runtime
593                    .memory
594                    .grow_async(&mut store, needed_pages - current_pages)
595                    .await?;
596            }
597
598            // Write the JSON input to memory
599            self.runtime.memory.write(
600                &mut store,
601                input_heap.ptr.try_into().context("invalid heap pointer")?,
602                &input[..],
603            )?;
604
605            let heap_ptr = Addr(input_heap.end());
606
607            // Call the eval fast-path
608            let result = opa_eval
609                .call(&mut store, entrypoint, &self.data, &input_heap, &heap_ptr)
610                .await?;
611
612            // Read back the JSON-formatted result
613            let result = result.read(&store, &self.runtime.memory)?;
614            let result = serde_json::from_slice(result.to_bytes())?;
615            Ok(result)
616        } else {
617            // Reset the heap pointer
618            self.runtime
619                .opa_heap_ptr_set_func
620                .call(&mut store, &self.heap_ptr)
621                .await?;
622
623            // Load the input
624            let input = self.runtime.load_json(&mut store, input).await?;
625
626            // Create a new evaluation context
627            let ctx = self.runtime.opa_eval_ctx_new_func.call(&mut store).await?;
628
629            // Set the data location
630            self.runtime
631                .opa_eval_ctx_set_data_func
632                .call(&mut store, &ctx, &self.data)
633                .await?;
634            // Set the input location
635            self.runtime
636                .opa_eval_ctx_set_input_func
637                .call(&mut store, &ctx, &input)
638                .await?;
639
640            // Set the entrypoint
641            self.runtime
642                .opa_eval_ctx_set_entrypoint_func
643                .call(&mut store, &ctx, entrypoint)
644                .await?;
645
646            // Evaluate the policy
647            self.runtime.eval_func.call(&mut store, &ctx).await?;
648
649            // Get the results back
650            let result = self
651                .runtime
652                .opa_eval_ctx_get_result_func
653                .call(&mut store, &ctx)
654                .await?;
655
656            let result = self
657                .runtime
658                .opa_json_dump_func
659                .decode(&mut store, &self.runtime.memory, &result)
660                .await?;
661
662            Ok(result)
663        }
664    }
665}
666
667impl<C> Deref for Policy<C> {
668    type Target = Runtime<C>;
669    fn deref(&self) -> &Self::Target {
670        &self.runtime
671    }
672}