opa_wasm/
policy.rs

1// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! The policy evaluation logic, which includes the [`Policy`] and [`Runtime`]
16//! structures.
17
18use std::{
19    collections::{HashMap, HashSet},
20    ffi::CString,
21    fmt::Debug,
22    ops::Deref,
23    sync::Arc,
24};
25
26use anyhow::{Context, Result};
27use tokio::sync::{Mutex, OnceCell};
28use tracing::Instrument;
29use wasmtime::{AsContextMut, Caller, Linker, Memory, MemoryType, Module};
30
31use crate::{
32    builtins::traits::Builtin,
33    funcs::{self, Func},
34    types::{AbiVersion, Addr, BuiltinId, EntrypointId, Heap, NulStr, Value},
35    DefaultContext, EvaluationContext,
36};
37
38/// Utility to allocate a string in the Wasm memory and return a pointer to it.
39async fn alloc_str<V: Into<Vec<u8>>, T: Send>(
40    opa_malloc: &funcs::OpaMalloc,
41    mut store: impl AsContextMut<Data = T>,
42    memory: &Memory,
43    value: V,
44) -> Result<Heap> {
45    let value = CString::new(value)?;
46    let value = value.as_bytes_with_nul();
47    let heap = opa_malloc.call(&mut store, value.len()).await?;
48
49    memory.write(
50        &mut store,
51        heap.ptr
52            .try_into()
53            .context("opa_malloc returned an invalid pointer value")?,
54        value,
55    )?;
56
57    Ok(heap)
58}
59
60/// Utility to load a JSON value into the WASM memory.
61async fn load_json<V: serde::Serialize, T: Send>(
62    opa_malloc: &funcs::OpaMalloc,
63    opa_free: &funcs::OpaFree,
64    opa_json_parse: &funcs::OpaJsonParse,
65    mut store: impl AsContextMut<Data = T>,
66    memory: &Memory,
67    data: &V,
68) -> Result<Value> {
69    let json = serde_json::to_vec(data)?;
70    let json = alloc_str(opa_malloc, &mut store, memory, json).await?;
71    let data = opa_json_parse.call(&mut store, &json).await?;
72    opa_free.call(&mut store, json).await?;
73    Ok(data)
74}
75
76/// A structure which holds the builtins referenced by the policy.
77struct LoadedBuiltins<C> {
78    /// A map of builtin IDs to the name and the builtin itself.
79    builtins: HashMap<i32, (String, Box<dyn Builtin<C>>)>,
80
81    /// The inner [`EvaluationContext`] which will be passed when calling
82    /// some builtins
83    context: Mutex<C>,
84}
85
86impl<C> std::fmt::Debug for LoadedBuiltins<C> {
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        f.debug_struct("LoadedBuiltins")
89            .field("builtins", &())
90            .finish()
91    }
92}
93
94impl<C> LoadedBuiltins<C>
95where
96    C: EvaluationContext,
97{
98    /// Resolve the builtins from a map of builtin IDs to their names.
99    fn from_map(map: HashMap<String, BuiltinId>, context: C) -> Result<Self> {
100        let res: Result<_> = map
101            .into_iter()
102            .map(|(k, v)| {
103                let builtin = crate::builtins::resolve(&k)?;
104                Ok((v.0, (k, builtin)))
105            })
106            .collect();
107        Ok(Self {
108            builtins: res?,
109            context: Mutex::new(context),
110        })
111    }
112
113    /// Call the given builtin given its ID and arguments.
114    async fn builtin<T: Send, const N: usize>(
115        &self,
116        mut caller: Caller<'_, T>,
117        memory: &Memory,
118        builtin_id: i32,
119        args: [i32; N],
120    ) -> Result<i32, anyhow::Error> {
121        let (name, builtin) = self
122            .builtins
123            .get(&builtin_id)
124            .with_context(|| format!("unknown builtin id {builtin_id}"))?;
125
126        let span = tracing::info_span!("builtin", %name);
127        let _enter = span.enter();
128
129        let opa_json_dump = funcs::OpaJsonDump::from_caller(&mut caller)?;
130        let opa_json_parse = funcs::OpaJsonParse::from_caller(&mut caller)?;
131        let opa_malloc = funcs::OpaMalloc::from_caller(&mut caller)?;
132        let opa_free = funcs::OpaFree::from_caller(&mut caller)?;
133
134        // Call opa_json_dump on each argument
135        let mut args_json = Vec::with_capacity(N);
136        for arg in args {
137            args_json.push(opa_json_dump.call(&mut caller, &Value(arg)).await?);
138        }
139
140        // Extract the JSON value of each argument
141        let mut mapped_args = Vec::with_capacity(N);
142        for arg_json in args_json {
143            let arg = arg_json.read(&caller, memory)?;
144            mapped_args.push(arg.to_bytes());
145        }
146
147        let mut ctx = self.context.lock().await;
148
149        // Actually call the function
150        let ret = (async move { builtin.call(&mut ctx, &mapped_args).await })
151            .instrument(tracing::info_span!("builtin.call"))
152            .await?;
153
154        let json = alloc_str(&opa_malloc, &mut caller, memory, ret).await?;
155        let data = opa_json_parse.call(&mut caller, &json).await?;
156        opa_free.call(&mut caller, json).await?;
157
158        Ok(data.0)
159    }
160
161    /// Called when the policy evaluation starts, to reset the context and
162    /// record the evaluation starting time
163    async fn evaluation_start(&self) {
164        self.context.lock().await.evaluation_start();
165    }
166}
167
168/// An instance of a policy with builtins and entrypoints resolved, but with no
169/// data provided yet
170#[allow(clippy::missing_docs_in_private_items)]
171pub struct Runtime<C> {
172    version: AbiVersion,
173    memory: Memory,
174    entrypoints: HashMap<String, EntrypointId>,
175    loaded_builtins: Arc<OnceCell<LoadedBuiltins<C>>>,
176
177    eval_func: funcs::Eval,
178    opa_eval_ctx_new_func: funcs::OpaEvalCtxNew,
179    opa_eval_ctx_set_input_func: funcs::OpaEvalCtxSetInput,
180    opa_eval_ctx_set_data_func: funcs::OpaEvalCtxSetData,
181    opa_eval_ctx_set_entrypoint_func: funcs::OpaEvalCtxSetEntrypoint,
182    opa_eval_ctx_get_result_func: funcs::OpaEvalCtxGetResult,
183    opa_malloc_func: funcs::OpaMalloc,
184    opa_free_func: funcs::OpaFree,
185    opa_json_parse_func: funcs::OpaJsonParse,
186    opa_json_dump_func: funcs::OpaJsonDump,
187    opa_heap_ptr_set_func: funcs::OpaHeapPtrSet,
188    opa_heap_ptr_get_func: funcs::OpaHeapPtrGet,
189    opa_eval_func: Option<funcs::OpaEval>,
190}
191
192impl<C> Debug for Runtime<C> {
193    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        f.debug_struct("Runtime")
195            .field("version", &self.version)
196            .field("memory", &self.memory)
197            .field("entrypoints", &self.entrypoints)
198            .finish_non_exhaustive()
199    }
200}
201
202impl Runtime<DefaultContext> {
203    /// Load a new WASM policy module into the given store, with the default
204    /// evaluation context.
205    ///
206    /// # Errors
207    ///
208    /// It will raise an error if one of the following condition is met:
209    ///
210    ///  - the provided [`wasmtime::Store`] isn't an async one
211    ///  - the [`wasmtime::Module`] was created with a different
212    ///    [`wasmtime::Engine`] than the [`wasmtime::Store`]
213    ///  - the WASM module is not a valid OPA WASM compiled policy, and lacks
214    ///    some of the exported functions
215    ///  - it failed to load the entrypoints or the builtins list
216    #[allow(clippy::too_many_lines)]
217    pub async fn new<T: Send + 'static>(
218        store: impl AsContextMut<Data = T>,
219        module: &Module,
220    ) -> Result<Self> {
221        let context = DefaultContext::default();
222        Self::new_with_evaluation_context(store, module, context).await
223    }
224}
225
226impl<C> Runtime<C> {
227    /// Load a new WASM policy module into the given store, with a given
228    /// evaluation context.
229    ///
230    /// # Errors
231    ///
232    /// It will raise an error if one of the following condition is met:
233    ///
234    ///  - the provided [`wasmtime::Store`] isn't an async one
235    ///  - the [`wasmtime::Module`] was created with a different
236    ///    [`wasmtime::Engine`] than the [`wasmtime::Store`]
237    ///  - the WASM module is not a valid OPA WASM compiled policy, and lacks
238    ///    some of the exported functions
239    ///  - it failed to load the entrypoints or the builtins list
240    #[allow(clippy::too_many_lines)]
241    pub async fn new_with_evaluation_context<T: Send + 'static>(
242        mut store: impl AsContextMut<Data = T>,
243        module: &Module,
244        context: C,
245    ) -> Result<Self>
246    where
247        C: EvaluationContext,
248    {
249        let ty = MemoryType::new(2, None);
250        let memory = Memory::new_async(&mut store, ty).await?;
251
252        // TODO: make the context configurable and reset it on evaluation
253        let eventually_builtins = Arc::new(OnceCell::<LoadedBuiltins<C>>::new());
254
255        let mut linker = Linker::new(store.as_context_mut().engine());
256        linker.define(&store, "env", "memory", memory)?;
257
258        linker.func_wrap(
259            "env",
260            "opa_abort",
261            move |caller: Caller<'_, _>, addr: i32| -> Result<(), anyhow::Error> {
262                let addr = NulStr(addr);
263                let msg = addr.read(&caller, &memory)?;
264                let msg = msg.to_string_lossy().into_owned();
265                tracing::error!("opa_abort: {}", msg);
266                anyhow::bail!(msg)
267            },
268        )?;
269
270        linker.func_wrap(
271            "env",
272            "opa_println",
273            move |caller: Caller<'_, _>, addr: i32| {
274                let addr = NulStr(addr);
275                let msg = addr.read(&caller, &memory)?;
276                tracing::info!("opa_print: {}", msg.to_string_lossy());
277                Ok(())
278            },
279        )?;
280
281        {
282            let eventually_builtins = eventually_builtins.clone();
283            linker.func_wrap_async(
284                "env",
285                "opa_builtin0",
286                move |caller: Caller<'_, _>, (builtin_id, _ctx): (i32, i32)| {
287                    let eventually_builtins = eventually_builtins.clone();
288                    Box::new(async move {
289                        eventually_builtins
290                            .get()
291                            .context("builtins where never initialized")?
292                            .builtin(caller, &memory, builtin_id, [])
293                            .await
294                    })
295                },
296            )?;
297        }
298
299        {
300            let eventually_builtins = eventually_builtins.clone();
301            linker.func_wrap_async(
302                "env",
303                "opa_builtin1",
304                move |caller: Caller<'_, _>, (builtin_id, _ctx, param1): (i32, i32, i32)| {
305                    let eventually_builtins = eventually_builtins.clone();
306                    Box::new(async move {
307                        eventually_builtins
308                            .get()
309                            .context("builtins where never initialized")?
310                            .builtin(caller, &memory, builtin_id, [param1])
311                            .await
312                    })
313                },
314            )?;
315        }
316
317        {
318            let eventually_builtins = eventually_builtins.clone();
319            linker.func_wrap_async(
320                "env",
321                "opa_builtin2",
322                move |caller: Caller<'_, _>,
323                      (builtin_id, _ctx, param1, param2): (i32, i32, i32, i32)| {
324                    let eventually_builtins = eventually_builtins.clone();
325                    Box::new(async move {
326                        eventually_builtins
327                            .get()
328                            .context("builtins where never initialized")?
329                            .builtin(caller, &memory, builtin_id, [param1, param2])
330                            .await
331                    })
332                },
333            )?;
334        }
335
336        {
337            let eventually_builtins = eventually_builtins.clone();
338            linker.func_wrap_async(
339                "env",
340                "opa_builtin3",
341                move |caller: Caller<'_, _>,
342                      (builtin_id,
343                      _ctx,
344                      param1,
345                      param2,
346                      param3): (i32, i32, i32, i32, i32)| {
347                    let eventually_builtins = eventually_builtins.clone();
348                    Box::new(async move {
349                        eventually_builtins
350                            .get()
351                            .context("builtins where never initialized")?
352                            .builtin(caller, &memory, builtin_id, [param1, param2, param3])
353                            .await
354                    })
355                },
356            )?;
357        }
358
359        {
360            let eventually_builtins = eventually_builtins.clone();
361            linker.func_wrap_async(
362                "env",
363                "opa_builtin4",
364                move |caller: Caller<'_, _>,
365                      (builtin_id, _ctx, param1, param2, param3, param4): (
366                    i32,
367                    i32,
368                    i32,
369                    i32,
370                    i32,
371                    i32,
372                )| {
373                    let eventually_builtins = eventually_builtins.clone();
374                    Box::new(async move {
375                        eventually_builtins
376                            .get()
377                            .context("builtins where never initialized")?
378                            .builtin(
379                                caller,
380                                &memory,
381                                builtin_id,
382                                [param1, param2, param3, param4],
383                            )
384                            .await
385                    })
386                },
387            )?;
388        }
389
390        let instance = linker.instantiate_async(&mut store, module).await?;
391
392        let version = AbiVersion::from_instance(&mut store, &instance)?;
393        tracing::debug!(%version, "Module ABI version");
394
395        let opa_json_dump_func = funcs::OpaJsonDump::from_instance(&mut store, &instance)?;
396
397        // Load the builtins map
398        let builtins = funcs::Builtins::from_instance(&mut store, &instance)?
399            .call(&mut store)
400            .await?;
401        let builtins = opa_json_dump_func
402            .decode(&mut store, &memory, &builtins)
403            .await?;
404        let builtins = LoadedBuiltins::from_map(builtins, context)?;
405        eventually_builtins.set(builtins)?;
406
407        // Load the entrypoints map
408        let entrypoints = funcs::Entrypoints::from_instance(&mut store, &instance)?
409            .call(&mut store)
410            .await?;
411        let entrypoints = opa_json_dump_func
412            .decode(&mut store, &memory, &entrypoints)
413            .await?;
414
415        let opa_eval_func = version
416            .has_eval_fastpath()
417            .then(|| funcs::OpaEval::from_instance(&mut store, &instance))
418            .transpose()?;
419
420        Ok(Self {
421            version,
422            memory,
423            entrypoints,
424            loaded_builtins: eventually_builtins,
425
426            eval_func: funcs::Eval::from_instance(&mut store, &instance)?,
427            opa_eval_ctx_new_func: funcs::OpaEvalCtxNew::from_instance(&mut store, &instance)?,
428            opa_eval_ctx_set_input_func: funcs::OpaEvalCtxSetInput::from_instance(
429                &mut store, &instance,
430            )?,
431            opa_eval_ctx_set_data_func: funcs::OpaEvalCtxSetData::from_instance(
432                &mut store, &instance,
433            )?,
434            opa_eval_ctx_set_entrypoint_func: funcs::OpaEvalCtxSetEntrypoint::from_instance(
435                &mut store, &instance,
436            )?,
437            opa_eval_ctx_get_result_func: funcs::OpaEvalCtxGetResult::from_instance(
438                &mut store, &instance,
439            )?,
440            opa_malloc_func: funcs::OpaMalloc::from_instance(&mut store, &instance)?,
441            opa_free_func: funcs::OpaFree::from_instance(&mut store, &instance)?,
442            opa_json_parse_func: funcs::OpaJsonParse::from_instance(&mut store, &instance)?,
443            opa_json_dump_func,
444            opa_heap_ptr_set_func: funcs::OpaHeapPtrSet::from_instance(&mut store, &instance)?,
445            opa_heap_ptr_get_func: funcs::OpaHeapPtrGet::from_instance(&mut store, &instance)?,
446            opa_eval_func,
447        })
448    }
449
450    /// Load a JSON value into the WASM memory
451    async fn load_json<V: serde::Serialize, T: Send>(
452        &self,
453        store: impl AsContextMut<Data = T>,
454        data: &V,
455    ) -> Result<Value> {
456        load_json(
457            &self.opa_malloc_func,
458            &self.opa_free_func,
459            &self.opa_json_parse_func,
460            store,
461            &self.memory,
462            data,
463        )
464        .await
465    }
466
467    /// Instanciate the policy with an empty `data` object
468    ///
469    /// # Errors
470    ///
471    /// If it failed to load the empty data object in memory
472    pub async fn without_data<T: Send>(
473        self,
474        store: impl AsContextMut<Data = T>,
475    ) -> Result<Policy<C>> {
476        let data = serde_json::Value::Object(serde_json::Map::default());
477        self.with_data(store, &data).await
478    }
479
480    /// Instanciate the policy with the given `data` object
481    ///
482    /// # Errors
483    ///
484    /// If it failed to serialize and load the `data` object
485    pub async fn with_data<V: serde::Serialize, T: Send>(
486        self,
487        mut store: impl AsContextMut<Data = T>,
488        data: &V,
489    ) -> Result<Policy<C>> {
490        let data = self.load_json(&mut store, data).await?;
491        let heap_ptr = self.opa_heap_ptr_get_func.call(&mut store).await?;
492        Ok(Policy {
493            runtime: self,
494            data,
495            heap_ptr,
496        })
497    }
498
499    /// Get the default entrypoint of this module. May return [`None`] if no
500    /// entrypoint with ID 0 was found
501    #[must_use]
502    pub fn default_entrypoint(&self) -> Option<&str> {
503        self.entrypoints
504            .iter()
505            .find_map(|(k, v)| (v.0 == 0).then_some(k.as_str()))
506    }
507
508    /// Get the list of entrypoints found in this module.
509    #[must_use]
510    pub fn entrypoints(&self) -> HashSet<&str> {
511        self.entrypoints.keys().map(String::as_str).collect()
512    }
513
514    /// Get the ABI version detected for this module
515    #[must_use]
516    pub fn abi_version(&self) -> AbiVersion {
517        self.version
518    }
519}
520
521/// An instance of a policy, ready to be executed
522#[derive(Debug)]
523pub struct Policy<C> {
524    /// The runtime this policy instance belongs to
525    runtime: Runtime<C>,
526
527    /// The data object loaded for this policy
528    data: Value,
529
530    /// A pointer to the heap, used for efficient allocations
531    heap_ptr: Addr,
532}
533
534impl<C> Policy<C> {
535    /// Evaluate a policy with the given entrypoint and input.
536    ///
537    /// # Errors
538    ///
539    /// Returns an error if the policy evaluation failed, or if this policy did
540    /// not belong to the given store.
541    pub async fn evaluate<V: serde::Serialize, R: for<'de> serde::Deserialize<'de>, T: Send>(
542        &self,
543        mut store: impl AsContextMut<Data = T>,
544        entrypoint: &str,
545        input: &V,
546    ) -> Result<R>
547    where
548        C: EvaluationContext,
549    {
550        // Lookup the entrypoint
551        let entrypoint = self
552            .runtime
553            .entrypoints
554            .get(entrypoint)
555            .with_context(|| format!("could not find entrypoint {entrypoint}"))?;
556
557        self.loaded_builtins
558            .get()
559            .context("builtins where never initialized")?
560            .evaluation_start()
561            .await;
562
563        // Take the fast path if it is awailable
564        if let Some(opa_eval) = &self.runtime.opa_eval_func {
565            // Write the input
566            let input = serde_json::to_vec(&input)?;
567            let input_heap = Heap {
568                ptr: self.heap_ptr.0,
569                len: input.len().try_into().context("input too long")?,
570                // Not managed by a malloc
571                freed: true,
572            };
573
574            // Check if we need to grow the memory first
575            let current_pages = self.runtime.memory.size(&store);
576            let needed_pages = input_heap.pages();
577            if current_pages < needed_pages {
578                self.runtime
579                    .memory
580                    .grow_async(&mut store, needed_pages - current_pages)
581                    .await?;
582            }
583
584            // Write the JSON input to memory
585            self.runtime.memory.write(
586                &mut store,
587                input_heap.ptr.try_into().context("invalid heap pointer")?,
588                &input[..],
589            )?;
590
591            let heap_ptr = Addr(input_heap.end());
592
593            // Call the eval fast-path
594            let result = opa_eval
595                .call(&mut store, entrypoint, &self.data, &input_heap, &heap_ptr)
596                .await?;
597
598            // Read back the JSON-formatted result
599            let result = result.read(&store, &self.runtime.memory)?;
600            let result = serde_json::from_slice(result.to_bytes())?;
601            Ok(result)
602        } else {
603            // Reset the heap pointer
604            self.runtime
605                .opa_heap_ptr_set_func
606                .call(&mut store, &self.heap_ptr)
607                .await?;
608
609            // Load the input
610            let input = self.runtime.load_json(&mut store, input).await?;
611
612            // Create a new evaluation context
613            let ctx = self.runtime.opa_eval_ctx_new_func.call(&mut store).await?;
614
615            // Set the data location
616            self.runtime
617                .opa_eval_ctx_set_data_func
618                .call(&mut store, &ctx, &self.data)
619                .await?;
620            // Set the input location
621            self.runtime
622                .opa_eval_ctx_set_input_func
623                .call(&mut store, &ctx, &input)
624                .await?;
625
626            // Set the entrypoint
627            self.runtime
628                .opa_eval_ctx_set_entrypoint_func
629                .call(&mut store, &ctx, entrypoint)
630                .await?;
631
632            // Evaluate the policy
633            self.runtime.eval_func.call(&mut store, &ctx).await?;
634
635            // Get the results back
636            let result = self
637                .runtime
638                .opa_eval_ctx_get_result_func
639                .call(&mut store, &ctx)
640                .await?;
641
642            let result = self
643                .runtime
644                .opa_json_dump_func
645                .decode(&mut store, &self.runtime.memory, &result)
646                .await?;
647
648            Ok(result)
649        }
650    }
651}
652
653impl<C> Deref for Policy<C> {
654    type Target = Runtime<C>;
655    fn deref(&self) -> &Self::Target {
656        &self.runtime
657    }
658}