// borderless_runtime/rt/agent.rs

1use std::sync::Arc;
2use std::time::Instant;
3
4use ahash::HashMap;
5use borderless::__private::registers::*;
6use borderless::agents::Init;
7use borderless::common::{Introduction, Revocation, Symbols};
8use borderless::events::Events;
9use borderless::{events::CallAction, AgentId, BorderlessId};
10use borderless_kv_store::backend::lmdb::Lmdb;
11use borderless_kv_store::Db;
12use parking_lot::Mutex as SyncMutex;
13use tokio::sync::{mpsc, Mutex};
14use wasmtime::{Caller, Config, Engine, ExternType, FuncType, Linker, Module};
15
16use super::vm::{ActiveEntity, Commit};
17use super::{
18    code_store::CodeStore,
19    vm::{self, VmState},
20};
21use crate::log_shim::*;
22use crate::{
23    error::{ErrorKind, Result},
24    AGENT_SUB_DB, SUBSCRIPTION_REL_SUB_DB,
25};
26
27pub mod tasks;
28
29pub type SharedRuntime<S> = Arc<Mutex<Runtime<S>>>;
30
31pub struct Runtime<S = Lmdb>
32where
33    S: Db,
34{
35    linker: Linker<VmState<S>>,
36    engine: Engine,
37    agent_store: CodeStore<S>,
38    mutability_lock: MutLock,
39    executor: Option<Vec<u8>>,
40}
41
42impl<S: Db> Runtime<S> {
43    pub fn new(storage: &S, agent_store: CodeStore<S>, lock: MutLock) -> Result<Self> {
44        let start = Instant::now();
45        // Create agent sub-db (in case it does not exist)
46        let _ = storage.create_sub_db(AGENT_SUB_DB)?;
47        let _ = storage.create_sub_db(SUBSCRIPTION_REL_SUB_DB)?;
48
49        // Generate engine ( with async enabled )
50        let mut config = Config::new();
51        config.cranelift_opt_level(wasmtime::OptLevel::Speed);
52        config.async_support(true); // <- BIG difference
53        let engine = Engine::new(&config)?;
54
55        let mut linker: Linker<VmState<S>> = Linker::new(&engine);
56
57        // NOTE: We have to wrap the functions into a closure here, because they must be monomorphized
58        // (as a generic function cannot be made into a function pointer)
59        linker.func_wrap(
60            "env",
61            "print",
62            |caller: Caller<'_, VmState<S>>, ptr, len, level| vm::print(caller, ptr, len, level),
63        )?;
64        linker.func_wrap(
65            "env",
66            "read_register",
67            |caller: Caller<'_, VmState<S>>, register_id, ptr| {
68                vm::read_register(caller, register_id, ptr)
69            },
70        )?;
71        linker.func_wrap(
72            "env",
73            "register_len",
74            |caller: Caller<'_, VmState<S>>, register_id| vm::register_len(caller, register_id),
75        )?;
76        linker.func_wrap(
77            "env",
78            "write_register",
79            |caller: Caller<'_, VmState<S>>, register_id, wasm_ptr, wasm_ptr_len| {
80                vm::write_register(caller, register_id, wasm_ptr, wasm_ptr_len)
81            },
82        )?;
83        linker.func_wrap(
84            "env",
85            "storage_read",
86            |caller: Caller<'_, VmState<S>>, base_key, sub_key, register_id| {
87                vm::storage_read(caller, base_key, sub_key, register_id)
88            },
89        )?;
90        linker.func_wrap(
91            "env",
92            "storage_write",
93            |caller: Caller<'_, VmState<S>>, base_key, sub_key, value_ptr, value_len| {
94                vm::storage_write(caller, base_key, sub_key, value_ptr, value_len)
95            },
96        )?;
97        linker.func_wrap(
98            "env",
99            "storage_remove",
100            |caller: Caller<'_, VmState<S>>, base_key, sub_key| {
101                vm::storage_remove(caller, base_key, sub_key)
102            },
103        )?;
104        linker.func_wrap(
105            "env",
106            "storage_has_key",
107            |caller: Caller<'_, VmState<S>>, base_key, sub_key| {
108                vm::storage_has_key(caller, base_key, sub_key)
109            },
110        )?;
111        linker.func_wrap(
112            "env",
113            "storage_cursor",
114            |caller: Caller<'_, VmState<S>>, base_key| vm::storage_cursor(caller, base_key),
115        )?;
116
117        // NOTE: Those functions introduce side-effects;
118        // they should only be used by us or during development of a contract
119        linker.func_wrap("env", "storage_gen_sub_key", vm::storage_gen_sub_key)?;
120        linker.func_wrap("env", "tic", |caller: Caller<'_, VmState<S>>| {
121            vm::tic(caller)
122        })?;
123        linker.func_wrap("env", "toc", |caller: Caller<'_, VmState<S>>| {
124            vm::toc(caller)
125        })?;
126        linker.func_wrap("env", "rand", vm::rand)?;
127
128        // --- TODO: Playground for the new async api
129        linker.func_wrap_async(
130            "env",
131            "send_http_rq",
132            |caller: Caller<'_, VmState<S>>, (rq_head, rq_body, rs_head, rs_body, err)| {
133                Box::new(vm::async_abi::send_http_rq(
134                    caller, rq_head, rq_body, rs_head, rs_body, err,
135                ))
136            },
137        )?;
138        linker.func_wrap_async(
139            "env",
140            "send_ws_msg",
141            |caller: Caller<'_, VmState<S>>, (msg_ptr, msg_len)| {
142                Box::new(vm::async_abi::send_ws_msg(caller, msg_ptr, msg_len))
143            },
144        )?;
145
146        linker.func_wrap("env", "timestamp", |caller: Caller<'_, VmState<S>>| {
147            vm::timestamp(caller)
148        })?;
149
150        info!("Initialized runtime in: {:?}", start.elapsed());
151
152        Ok(Self {
153            linker,
154            engine,
155            agent_store,
156            mutability_lock: lock,
157            executor: None,
158        })
159    }
160
161    pub fn into_shared(self) -> Arc<Mutex<Self>> {
162        Arc::new(Mutex::new(self))
163    }
164
165    /// Returns a copy of the underlying db handle
166    pub fn get_db(&self) -> S {
167        self.agent_store.get_db()
168    }
169
170    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(%agent_id), err))]
171    pub fn instantiate_sw_agent(&mut self, agent_id: AgentId, module_bytes: &[u8]) -> Result<()> {
172        let module = Module::new(&self.engine, module_bytes)?;
173        check_module(&self.engine, &module)?;
174        self.agent_store.insert_swagent(agent_id, module)?;
175        Ok(())
176    }
177
178    /// Sanity check for introductions
179    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
180    pub async fn check_module_and_state(
181        &mut self,
182        module_bytes: Vec<u8>,
183        state: serde_json::Value,
184    ) -> Result<(bool, Vec<String>)> {
185        let module = Module::new(&self.engine, module_bytes)?;
186        check_module(&self.engine, &module)?;
187        let mut store = self.agent_store.create_store(&self.engine)?;
188        let instance = self.linker.instantiate(&mut store, &module)?;
189
190        // Prepare registers
191        store
192            .data_mut()
193            .set_register(REGISTER_INPUT, state.to_string().into_bytes());
194
195        // Get function
196        let func = instance.get_typed_func::<(), ()>(&mut store, "parse_state")?;
197
198        // Prepare execution
199        store.data_mut().prepare_exec(ActiveEntity::None)?;
200
201        // Call the actual function on the wasm side
202        let success = match func.call_async(&mut store, ()).await {
203            Ok(()) => true,
204            Err(_e) => false,
205        };
206        let log = store.data_mut().finish_exec(None)?;
207        Ok((success, log.into_iter().map(|l| l.msg).collect()))
208    }
209
210    /// Sets the currently active executor
211    ///
212    /// This buffers the [`BorderlessId`] of the executor, to later write it into the dedicated register,
213    /// so that the wasm side can query it.
214    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(%executor_id), err))]
215    pub fn set_executor(&mut self, executor_id: BorderlessId) -> Result<()> {
216        let bytes = executor_id.into_bytes().to_vec();
217        self.executor = Some(bytes);
218        Ok(())
219    }
220
221    /// Registers a new websocket client
222    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid), err))]
223    pub fn register_ws(&mut self, aid: AgentId) -> Result<mpsc::Receiver<Vec<u8>>> {
224        let (tx, rx) = mpsc::channel(4);
225        self.mutability_lock.insert_ws_sender(&aid, tx);
226        Ok(rx)
227    }
228
229    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid), err))]
230    pub async fn initialize(&mut self, aid: &AgentId) -> Result<Init> {
231        let (instance, mut store) = self
232            .agent_store
233            .get_agent(aid, &self.engine, &mut self.linker)
234            .await?
235            .ok_or_else(|| ErrorKind::MissingAgent { aid: *aid })?;
236
237        // Buffered registers
238        store
239            .data_mut()
240            .set_register(REGISTER_EXECUTOR, self.executor.clone().unwrap_or_default());
241
242        // Call the actual function on the wasm side
243        let func = instance.get_typed_func::<(), ()>(&mut store, "on_init")?;
244        store
245            .data_mut()
246            .prepare_exec(ActiveEntity::agent(*aid, false))?;
247
248        if let Err(e) = func.call_async(&mut store, ()).await {
249            warn!("initialize failed with error: {e}");
250        }
251        let output = store.data().get_register(REGISTER_OUTPUT);
252        store.data_mut().finish_exec(None)?;
253
254        // Return output events
255        let bytes = output.ok_or_else(|| ErrorKind::MissingRegisterValue("init-output"))?;
256        let init = Init::from_bytes(&bytes)?;
257
258        Ok(init)
259    }
260
261    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid), err))]
262    pub async fn process_ws_msg(&mut self, aid: &AgentId, msg: Vec<u8>) -> Result<Option<Events>> {
263        self.call_mut(aid, msg, "on_ws_msg", Commit::Other).await
264    }
265
266    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid), err))]
267    pub async fn on_ws_open(&mut self, aid: &AgentId) -> Result<Option<Events>> {
268        self.call_mut(aid, Vec::new(), "on_ws_open", Commit::Other)
269            .await
270    }
271
272    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid), err))]
273    pub async fn on_ws_error(&mut self, aid: &AgentId) -> Result<Option<Events>> {
274        self.call_mut(aid, Vec::new(), "on_ws_error", Commit::Other)
275            .await
276    }
277
278    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid), err))]
279    pub async fn on_ws_close(&mut self, aid: &AgentId) -> Result<Option<Events>> {
280        self.call_mut(aid, Vec::new(), "on_ws_close", Commit::Other)
281            .await
282    }
283
284    // TODO: If the initial state from the introduction cannot be parsed, the agent should *not* be saved !!
285    // Currently, this creates an agent, where decoding the state will constantly explode during runtime !!!
286    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %introduction.id), err))]
287    pub async fn process_introduction(&mut self, introduction: Introduction) -> Result<()> {
288        let aid = match introduction.id {
289            borderless::prelude::Id::Contract { .. } => return Err(ErrorKind::InvalidIdType.into()),
290            borderless::prelude::Id::Agent { agent_id } => agent_id,
291        };
292        // NOTE: The input for the introduction is not the introduction, but only the initial state!
293        // The introduction itself is commited by the VmState
294        let initial_state = introduction.initial_state.to_string().into_bytes();
295        let res = self
296            .call_mut(
297                &aid,
298                initial_state,
299                "process_introduction",
300                Commit::Introduction(introduction),
301            )
302            .await?;
303        assert!(res.is_none(), "introductions should not write events");
304        Ok(())
305    }
306
307    // TODO: Calling process revocation on an already revoked agent should generate an error
308    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %revocation.id), err))]
309    pub async fn process_revocation(&mut self, revocation: Revocation) -> Result<()> {
310        let aid = match revocation.id {
311            borderless::prelude::Id::Contract { .. } => return Err(ErrorKind::InvalidIdType.into()),
312            borderless::prelude::Id::Agent { agent_id } => agent_id,
313        };
314        // NOTE: The input for the introduction is not the introduction, but only the initial state!
315        // The introduction itself is commited by the VmState
316        let input = revocation.to_bytes()?;
317        let res = self
318            .call_mut(
319                &aid,
320                input,
321                "process_revocation",
322                Commit::Revocation(revocation),
323            )
324            .await?;
325        assert!(res.is_none(), "revocations should not write events");
326        Ok(())
327    }
328
329    // OK; Just to get some stuff going; I want to just simply call an action, and execute an http-request with it.
330    // That's more than enough to test stuff out.
331    // TODO: Logging ?
332    #[must_use = "You have to handle the output events of this function"]
333    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid), err))]
334    pub async fn process_action(
335        &mut self,
336        aid: &AgentId,
337        action: CallAction,
338    ) -> Result<Option<Events>> {
339        // Parse action
340        let input = action.to_bytes()?;
341        self.call_mut(aid, input, "process_action", Commit::Other)
342            .await
343    }
344
345    /// Helper function for mutable calls
346    async fn call_mut(
347        &mut self,
348        aid: &AgentId,
349        input: Vec<u8>,
350        method: &'static str,
351        commit: Commit,
352    ) -> Result<Option<Events>> {
353        let (instance, mut store) = self
354            .agent_store
355            .get_agent(aid, &self.engine, &mut self.linker)
356            .await?
357            .ok_or_else(|| ErrorKind::MissingAgent { aid: *aid })?;
358
359        let state = self.mutability_lock.get_lock_state(aid);
360        let _guard = state.lock.lock().await;
361
362        // Prepare registers
363        store.data_mut().set_register(REGISTER_INPUT, input);
364
365        // Buffered registers
366        store
367            .data_mut()
368            .set_register(REGISTER_EXECUTOR, self.executor.clone().unwrap_or_default());
369
370        // Inject ws-sender (if any)
371        if let Some(tx) = state.ws_sender {
372            store.data_mut().register_ws(tx)?;
373        }
374
375        // Call the actual function on the wasm side
376        let func = instance.get_typed_func::<(), ()>(&mut store, method)?;
377        store
378            .data_mut()
379            .prepare_exec(ActiveEntity::agent(*aid, true))?;
380
381        let commit = match func.call_async(&mut store, ()).await {
382            Ok(()) => Some(commit),
383            Err(e) => {
384                warn!("{method} failed with error: {e}");
385                None
386            }
387        };
388        let output = store.data().get_register(REGISTER_OUTPUT);
389        let _logs = store.data_mut().finish_exec(commit)?;
390
391        // Return output events
392        match output {
393            Some(bytes) => Ok(Some(Events::from_bytes(&bytes)?)),
394            None => Ok(None),
395        }
396    }
397
398    // --- NOTE: Maybe we should create a separate runtime for the HTTP handling ?
399
400    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid, %path), err))]
401    pub async fn http_get_state(&mut self, aid: &AgentId, path: String) -> Result<(u16, Vec<u8>)> {
402        // Get instance
403        let (instance, mut store) = self
404            .agent_store
405            .get_agent(aid, &self.engine, &mut self.linker)
406            .await?
407            .ok_or_else(|| ErrorKind::MissingAgent { aid: *aid })?;
408
409        // Prepare registers
410        store
411            .data_mut()
412            .set_register(REGISTER_INPUT_HTTP_PATH, path.into_bytes());
413
414        // Buffered registers
415        store
416            .data_mut()
417            .set_register(REGISTER_EXECUTOR, self.executor.clone().unwrap_or_default());
418
419        // Get function
420        let func = instance.get_typed_func::<(), ()>(&mut store, "http_get_state")?;
421
422        // Prepare execution
423        store
424            .data_mut()
425            .prepare_exec(ActiveEntity::agent(*aid, false))?;
426
427        // Call the function
428        if let Err(e) = func.call_async(&mut store, ()).await {
429            warn!("http_get_state failed with error: {e}");
430        }
431        let status = store.data().get_register(REGISTER_OUTPUT_HTTP_STATUS);
432        let result = store.data().get_register(REGISTER_OUTPUT_HTTP_RESULT);
433
434        // Finish the execution ( and commit nothing )
435        let _log = store.data_mut().finish_exec(None)?;
436
437        // Parse status
438        let status = status.ok_or_else(|| ErrorKind::MissingRegisterValue("http-status"))?;
439        let status_bytes = status
440            .try_into()
441            .map_err(|_| ErrorKind::InvalidRegisterValue {
442                register: "http-status",
443                expected_type: "u16",
444            })?;
445        let status = u16::from_be_bytes(status_bytes);
446
447        // Check result
448        let result = result.ok_or_else(|| ErrorKind::MissingRegisterValue("http-result"))?;
449
450        Ok((status, result))
451    }
452
453    // TODO: This will directly execute the action and return a list of events
454    //
455    // The question is, what should be returned via the web-api ?
456    /// Uses a POST request to parse and generate a [`CallAction`] object.
457    ///
458    /// The return type is a nested result. The outer result type should convert to a server error,
459    /// as it represents errors in the runtime itself.
460    /// The inner error type comes from the wasm code and contains the error status and message.
461    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid, %path, %writer), err))]
462    pub async fn http_post_action(
463        &mut self,
464        aid: &AgentId,
465        path: String,
466        payload: Vec<u8>,
467        writer: &BorderlessId, // TODO: I think the writer makes no sense here and is an artifact
468    ) -> Result<std::result::Result<(Events, CallAction), (u16, String)>> {
469        let (instance, mut store) = self
470            .agent_store
471            .get_agent(aid, &self.engine, &mut self.linker)
472            .await?
473            .ok_or_else(|| ErrorKind::MissingAgent { aid: *aid })?;
474
475        let state = self.mutability_lock.get_lock_state(aid);
476        let _guard = state.lock.lock().await;
477
478        // NOTE: We cannot convert the payload into a call-action on-spot, as we might call a nested route.
479        // To be precise - we *could* do it here, but I think it is cleaner to leave this logic up to the wasm module,
480        // as otherwise we may have to duplicate the logic here (and if it changes in the macro, we have to sync this with the code of the runtime etc.).
481        store
482            .data_mut()
483            .set_register(REGISTER_INPUT_HTTP_PATH, path.into_bytes());
484
485        store
486            .data_mut()
487            .set_register(REGISTER_INPUT_HTTP_PAYLOAD, payload);
488
489        store
490            .data_mut()
491            .set_register(REGISTER_WRITER, writer.into_bytes().into());
492
493        // Buffered registers
494        store
495            .data_mut()
496            .set_register(REGISTER_EXECUTOR, self.executor.clone().unwrap_or_default());
497
498        // Prepare mutable execution
499        store
500            .data_mut()
501            .prepare_exec(ActiveEntity::agent(*aid, true))?;
502
503        // Get function
504        let func = instance.get_typed_func::<(), ()>(&mut store, "http_post_action")?;
505
506        // Call the function
507        if let Err(e) = func.call_async(&mut store, ()).await {
508            warn!("http_get_state failed with error: {e}");
509        }
510        let status = store.data().get_register(REGISTER_OUTPUT_HTTP_STATUS);
511        let result = store.data().get_register(REGISTER_OUTPUT_HTTP_RESULT);
512        let output = store.data().get_register(REGISTER_OUTPUT);
513
514        // Finish the execution
515        // NOTE: This will clear all the registers !
516        let _log = store.data_mut().finish_exec(Some(Commit::Other))?;
517
518        // Parse status
519        let status = status.ok_or_else(|| ErrorKind::MissingRegisterValue("http-status"))?;
520        let status_bytes = status
521            .try_into()
522            .map_err(|_| ErrorKind::InvalidRegisterValue {
523                register: "http-status",
524                expected_type: "u16",
525            })?;
526        let status = u16::from_be_bytes(status_bytes);
527
528        // Check result
529        let result = result.ok_or_else(|| ErrorKind::MissingRegisterValue("http-result"))?;
530
531        if status == 200 {
532            let events = match output {
533                Some(b) => Events::from_bytes(&b)?,
534                None => Events::default(),
535            };
536            let action = CallAction::from_bytes(&result)?;
537            Ok(Ok((events, action)))
538        } else {
539            let error = String::from_utf8(result).map_err(|_| ErrorKind::InvalidRegisterValue {
540                register: "http-result",
541                expected_type: "string",
542            })?;
543            Ok(Err((status, error)))
544        }
545    }
546
547    /// Returns the symbols of the contract
548    pub async fn get_symbols(&mut self, aid: &AgentId) -> Result<Option<Symbols>> {
549        let (instance, mut store) = self
550            .agent_store
551            .get_agent(aid, &self.engine, &mut self.linker)
552            .await?
553            .ok_or_else(|| ErrorKind::MissingAgent { aid: *aid })?;
554
555        store.data_mut().prepare_exec(ActiveEntity::None)?;
556
557        // In case the contract does not export any symbols, just return 'None'
558        if let Err(e) = instance
559            .get_typed_func::<(), ()>(&mut store, "get_symbols")
560            .and_then(|func| func.call(&mut store, ()))
561        {
562            error!("get_symbols failed with error: {e}");
563        }
564        let output = store.data().get_register(REGISTER_OUTPUT);
565        store.data_mut().finish_exec(None)?;
566
567        let bytes = match output {
568            Some(b) => b,
569            None => return Ok(None),
570        };
571        let symbols = Symbols::from_bytes(&bytes)?;
572        Ok(Some(symbols))
573    }
574
575    pub fn available_agents(&self) -> Result<Vec<AgentId>> {
576        self.agent_store.available_swagents()
577    }
578}
579
580// NOTE: We Mis-Use the lock here to also carry persistent state - e.g. for the websocket
581#[derive(Default, Clone)]
582pub struct Lock {
583    lock: Arc<Mutex<()>>,
584    ws_sender: Option<mpsc::Sender<Vec<u8>>>,
585}
586
587/// Global mutability lock for all SW-Agents
588///
589/// Since we can only allow one mutable agent execution at a given time, we need a mechanism to ensure that.
590/// The `MutLock` ensures this on a per-agent basis. It holds `RwLock`s for all agents and provides threadsafe access.
591///
592/// The logic is similar but not identical to rusts ownership rules. While there can be only one read-write (mutable) execution,
593/// there can be multiple read-only (immutable) executions even if there is an ongoing read-write execution !
594/// The reason behind this is basically that read-only executions do not produce storage operations that would change the state in the database.
595/// In the `VmState`, all write operations are buffered until the execution is finished. If there would be two executions in parallel,
596/// we might end up commiting changes to a state, that has already changed under the hood - which is not what we want.
597/// However, if there is a writer thread, the readers do not care, and also the writer does not care about the readers.
598/// The readers will use the old state, until the new one is commited by the runtime.
599///
600/// Note: In contrast to [`borderless_runtime::rt::contract::MutLock`],
601/// this version uses asynchronous locks for the agents, and a synchronous lock only for the access of the hashmap.
602///
603/// Additionally, this double-functions as the provider of over-arching state per agent,
604/// e.g. the lock also contains the websocket sender, for agents that use websockets.
605#[derive(Clone, Default)]
606pub struct MutLock {
607    map: Arc<SyncMutex<HashMap<AgentId, Lock>>>,
608}
609
610impl MutLock {
611    /// Returns the `RwLock` for the given agent.
612    ///
613    /// If the agent-id is unknown, a new lock is created.
614    pub fn get_lock_state(&self, aid: &AgentId) -> Lock {
615        let mut map = self.map.lock();
616        let lock = map.entry(*aid).or_default();
617        lock.clone()
618    }
619
620    /// Inserts a new ws-sender for the agent
621    ///
622    /// Panics if the sender already contained a lock
623    pub fn insert_ws_sender(&self, aid: &AgentId, ws_sender: mpsc::Sender<Vec<u8>>) {
624        let mut map = self.map.lock();
625        let lock = map.entry(*aid).or_default();
626        assert!(
627            lock.ws_sender.is_none(),
628            "Cannot register websocket twice on the same agent"
629        );
630        lock.ws_sender = Some(ws_sender);
631    }
632}
633
634// NOTE: We could also check, if the websocket functions are exported,
635// and do a consistency check, if the module uses a websocket.
636// But maybe that's overkill.
637fn check_module(engine: &Engine, module: &Module) -> Result<()> {
638    let functions = [
639        "on_init",
640        "on_shutdown",
641        "process_action",
642        "process_introduction",
643        "process_revocation",
644        "http_get_state",
645        "http_post_action",
646        "parse_state",
647        "get_symbols",
648    ];
649    for func in functions {
650        let exp = module
651            .get_export(func)
652            .ok_or_else(|| ErrorKind::MissingExport { func })?;
653        if let ExternType::Func(func_type) = exp {
654            if !func_type.matches(&FuncType::new(engine, [], [])) {
655                return Err(ErrorKind::InvalidFuncType { func }.into());
656            }
657        } else {
658            return Err(ErrorKind::InvalidExport { func }.into());
659        }
660    }
661    Ok(())
662}
663
#[cfg(test)]
mod tests {
    use super::*;

    /// WAT module exporting every function required by `check_module`,
    /// all bound to the same empty `() -> ()` function.
    const ALL_EXPORTS: &str = r#"
(module
  ;; Declare the function `placeholder`
  (func $placeholder)

  ;; Export the functions so they can be called from outside the module
  (export "on_init" (func $placeholder))
  (export "on_shutdown" (func $placeholder))
  (export "process_action" (func $placeholder))
  (export "process_introduction" (func $placeholder))
  (export "process_revocation" (func $placeholder))
  (export "http_get_state" (func $placeholder))
  (export "http_post_action" (func $placeholder))
  (export "parse_state" (func $placeholder))
  (export "get_symbols" (func $placeholder))
)
"#;

    /// Drops every line of `original` that contains `pattern` and joins the rest with `\n`.
    fn remove_line_with_pattern(original: &str, pattern: &str) -> String {
        let kept: Vec<&str> = original
            .lines()
            .filter(|line| !line.contains(pattern))
            .collect();
        kept.join("\n")
    }

    #[test]
    fn missing_exports() {
        // The export check does not depend on async support, so a plain engine suffices.
        let mut config = Config::new();
        config.cranelift_opt_level(wasmtime::OptLevel::Speed);
        config.async_support(false);
        let engine = Engine::new(&config).unwrap();

        // Dropping any one of these exports must make the check fail
        let required = [
            "on_init",
            "on_shutdown",
            "process_action",
            "process_introduction",
            "process_revocation",
            "http_get_state",
            "http_post_action",
            "parse_state",
            "get_symbols",
        ];
        for name in required {
            let stripped = Module::new(&engine, remove_line_with_pattern(ALL_EXPORTS, name));
            assert!(stripped.is_ok());
            assert!(check_module(&engine, &stripped.unwrap()).is_err());
        }

        // With every export present, the check passes
        let complete = Module::new(&engine, ALL_EXPORTS);
        assert!(complete.is_ok());
        assert!(check_module(&engine, &complete.unwrap()).is_ok());
    }
}