// File: borderless_runtime/rt/agent.rs

1use std::sync::Arc;
2use std::time::Instant;
3
4use ahash::HashMap;
5use borderless::__private::registers::*;
6use borderless::agents::Init;
7use borderless::common::{Introduction, Revocation, Symbols};
8use borderless::events::Events;
9use borderless::{events::CallAction, AgentId, BorderlessId};
10use borderless_kv_store::backend::lmdb::Lmdb;
11use borderless_kv_store::Db;
12use http::StatusCode;
13use parking_lot::Mutex as SyncMutex;
14use tokio::sync::{mpsc, Mutex};
15use wasmtime::{Caller, Config, Engine, ExternType, FuncType, Linker, Module};
16
17use super::vm::{ActiveEntity, Commit};
18use super::{
19    code_store::CodeStore,
20    vm::{self, VmState},
21};
22use crate::db::controller::Controller;
23use crate::log_shim::*;
24use crate::{
25    error::{ErrorKind, Result},
26    AGENT_SUB_DB, SUBSCRIPTION_REL_SUB_DB,
27};
28
29pub mod tasks;
30
31pub type SharedRuntime<S> = Arc<Mutex<Runtime<S>>>;
32
33pub struct Runtime<S = Lmdb>
34where
35    S: Db,
36{
37    linker: Linker<VmState<S>>,
38    engine: Engine,
39    agent_store: CodeStore<S>,
40    mutability_lock: MutLock,
41    executor: Option<Vec<u8>>,
42}
43
44impl<S: Db> Runtime<S> {
45    pub fn new(storage: &S, agent_store: CodeStore<S>, lock: MutLock) -> Result<Self> {
46        let start = Instant::now();
47        // Create agent sub-db (in case it does not exist)
48        let _ = storage.create_sub_db(AGENT_SUB_DB)?;
49        let _ = storage.create_sub_db(SUBSCRIPTION_REL_SUB_DB)?;
50
51        // Generate engine ( with async enabled )
52        let mut config = Config::new();
53        config.cranelift_opt_level(wasmtime::OptLevel::Speed);
54        config.async_support(true); // <- BIG difference
55        let engine = Engine::new(&config)?;
56
57        let mut linker: Linker<VmState<S>> = Linker::new(&engine);
58
59        // NOTE: We have to wrap the functions into a closure here, because they must be monomorphized
60        // (as a generic function cannot be made into a function pointer)
61        linker.func_wrap(
62            "env",
63            "print",
64            |caller: Caller<'_, VmState<S>>, ptr, len, level| vm::print(caller, ptr, len, level),
65        )?;
66        linker.func_wrap(
67            "env",
68            "read_register",
69            |caller: Caller<'_, VmState<S>>, register_id, ptr| {
70                vm::read_register(caller, register_id, ptr)
71            },
72        )?;
73        linker.func_wrap(
74            "env",
75            "register_len",
76            |caller: Caller<'_, VmState<S>>, register_id| vm::register_len(caller, register_id),
77        )?;
78        linker.func_wrap(
79            "env",
80            "write_register",
81            |caller: Caller<'_, VmState<S>>, register_id, wasm_ptr, wasm_ptr_len| {
82                vm::write_register(caller, register_id, wasm_ptr, wasm_ptr_len)
83            },
84        )?;
85        linker.func_wrap(
86            "env",
87            "storage_read",
88            |caller: Caller<'_, VmState<S>>, base_key, sub_key, register_id| {
89                vm::storage_read(caller, base_key, sub_key, register_id)
90            },
91        )?;
92        linker.func_wrap(
93            "env",
94            "storage_write",
95            |caller: Caller<'_, VmState<S>>, base_key, sub_key, value_ptr, value_len| {
96                vm::storage_write(caller, base_key, sub_key, value_ptr, value_len)
97            },
98        )?;
99        linker.func_wrap(
100            "env",
101            "storage_remove",
102            |caller: Caller<'_, VmState<S>>, base_key, sub_key| {
103                vm::storage_remove(caller, base_key, sub_key)
104            },
105        )?;
106        linker.func_wrap(
107            "env",
108            "storage_has_key",
109            |caller: Caller<'_, VmState<S>>, base_key, sub_key| {
110                vm::storage_has_key(caller, base_key, sub_key)
111            },
112        )?;
113        linker.func_wrap(
114            "env",
115            "storage_cursor",
116            |caller: Caller<'_, VmState<S>>, base_key| vm::storage_cursor(caller, base_key),
117        )?;
118
119        // NOTE: Those functions introduce side-effects;
120        // they should only be used by us or during development of a contract
121        linker.func_wrap("env", "storage_gen_sub_key", vm::storage_gen_sub_key)?;
122        linker.func_wrap("env", "tic", |caller: Caller<'_, VmState<S>>| {
123            vm::tic(caller)
124        })?;
125        linker.func_wrap("env", "toc", |caller: Caller<'_, VmState<S>>| {
126            vm::toc(caller)
127        })?;
128        linker.func_wrap("env", "rand", vm::rand)?;
129
130        // --- TODO: Playground for the new async api
131        linker.func_wrap_async(
132            "env",
133            "send_http_rq",
134            |caller: Caller<'_, VmState<S>>, (rq_head, rq_body, rs_head, rs_body, err)| {
135                Box::new(vm::async_abi::send_http_rq(
136                    caller, rq_head, rq_body, rs_head, rs_body, err,
137                ))
138            },
139        )?;
140        linker.func_wrap_async(
141            "env",
142            "send_ws_msg",
143            |caller: Caller<'_, VmState<S>>, (msg_ptr, msg_len)| {
144                Box::new(vm::async_abi::send_ws_msg(caller, msg_ptr, msg_len))
145            },
146        )?;
147
148        linker.func_wrap("env", "timestamp", |caller: Caller<'_, VmState<S>>| {
149            vm::timestamp(caller)
150        })?;
151
152        info!("Initialized runtime in: {:?}", start.elapsed());
153
154        Ok(Self {
155            linker,
156            engine,
157            agent_store,
158            mutability_lock: lock,
159            executor: None,
160        })
161    }
162
163    pub fn into_shared(self) -> Arc<Mutex<Self>> {
164        Arc::new(Mutex::new(self))
165    }
166
167    /// Returns a copy of the underlying db handle
168    pub fn get_db(&self) -> S {
169        self.agent_store.get_db()
170    }
171
172    /// Check whether a sw-agent exists
173    pub fn agent_exists(&self, aid: &AgentId) -> Result<bool> {
174        let db = self.get_db();
175        let controller = Controller::new(&db);
176        controller.agent_exists(aid)
177    }
178
179    /// Check whether a sw-agent is revoked
180    pub fn agent_revoked(&self, aid: &AgentId) -> Result<bool> {
181        let db = self.get_db();
182        let controller = Controller::new(&db);
183        controller.agent_revoked(aid)
184    }
185
186    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(%agent_id), err))]
187    pub fn instantiate_sw_agent(&mut self, agent_id: AgentId, module_bytes: &[u8]) -> Result<()> {
188        let module = Module::new(&self.engine, module_bytes)?;
189        check_module(&self.engine, &module)?;
190        self.agent_store.insert_swagent(agent_id, module)?;
191        Ok(())
192    }
193
194    /// Sanity check for introductions
195    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
196    pub async fn check_module_and_state(
197        &mut self,
198        module_bytes: Vec<u8>,
199        state: serde_json::Value,
200    ) -> Result<(bool, Vec<String>)> {
201        let module = Module::new(&self.engine, module_bytes)?;
202        check_module(&self.engine, &module)?;
203        let mut store = self.agent_store.create_store(&self.engine)?;
204        let instance = self.linker.instantiate(&mut store, &module)?;
205
206        // Prepare registers
207        store
208            .data_mut()
209            .set_register(REGISTER_INPUT, state.to_string().into_bytes());
210
211        // Get function
212        let func = instance.get_typed_func::<(), ()>(&mut store, "parse_state")?;
213
214        // Prepare execution
215        store.data_mut().prepare_exec(ActiveEntity::None)?;
216
217        // Call the actual function on the wasm side
218        let success = match func.call_async(&mut store, ()).await {
219            Ok(()) => true,
220            Err(_e) => false,
221        };
222        let log = store.data_mut().finish_exec(None)?;
223        Ok((success, log.into_iter().map(|l| l.msg).collect()))
224    }
225
226    /// Sets the currently active executor
227    ///
228    /// This buffers the [`BorderlessId`] of the executor, to later write it into the dedicated register,
229    /// so that the wasm side can query it.
230    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(%executor_id), err))]
231    pub fn set_executor(&mut self, executor_id: BorderlessId) -> Result<()> {
232        let bytes = executor_id.into_bytes().to_vec();
233        self.executor = Some(bytes);
234        Ok(())
235    }
236
237    /// Registers a new websocket client
238    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid), err))]
239    pub fn register_ws(&mut self, aid: AgentId) -> Result<mpsc::Receiver<Vec<u8>>> {
240        let (tx, rx) = mpsc::channel(4);
241        self.mutability_lock.insert_ws_sender(&aid, tx);
242        Ok(rx)
243    }
244
245    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid), err))]
246    pub async fn initialize(&mut self, aid: &AgentId) -> Result<Init> {
247        let (instance, mut store) = self
248            .agent_store
249            .get_agent(aid, &self.engine, &mut self.linker)
250            .await?
251            .ok_or_else(|| ErrorKind::MissingAgent { aid: *aid })?;
252
253        // Buffered registers
254        store
255            .data_mut()
256            .set_register(REGISTER_EXECUTOR, self.executor.clone().unwrap_or_default());
257
258        // Call the actual function on the wasm side
259        let func = instance.get_typed_func::<(), ()>(&mut store, "on_init")?;
260        store
261            .data_mut()
262            .prepare_exec(ActiveEntity::agent(*aid, false))?;
263
264        if let Err(e) = func.call_async(&mut store, ()).await {
265            warn!("initialize failed with error: {e}");
266        }
267        let output = store.data().get_register(REGISTER_OUTPUT);
268        store.data_mut().finish_exec(None)?;
269
270        // Return output events
271        let bytes = output.ok_or_else(|| ErrorKind::MissingRegisterValue("init-output"))?;
272        let init = Init::from_bytes(&bytes)?;
273
274        Ok(init)
275    }
276
277    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid), err))]
278    pub async fn process_ws_msg(&mut self, aid: &AgentId, msg: Vec<u8>) -> Result<Option<Events>> {
279        self.call_mut(aid, msg, "on_ws_msg", Commit::Other).await
280    }
281
282    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid), err))]
283    pub async fn on_ws_open(&mut self, aid: &AgentId) -> Result<Option<Events>> {
284        self.call_mut(aid, Vec::new(), "on_ws_open", Commit::Other)
285            .await
286    }
287
288    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid), err))]
289    pub async fn on_ws_error(&mut self, aid: &AgentId) -> Result<Option<Events>> {
290        self.call_mut(aid, Vec::new(), "on_ws_error", Commit::Other)
291            .await
292    }
293
294    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid), err))]
295    pub async fn on_ws_close(&mut self, aid: &AgentId) -> Result<Option<Events>> {
296        self.call_mut(aid, Vec::new(), "on_ws_close", Commit::Other)
297            .await
298    }
299
300    // TODO: If the initial state from the introduction cannot be parsed, the agent should *not* be saved !!
301    // Currently, this creates an agent, where decoding the state will constantly explode during runtime !!!
302    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %introduction.id), err))]
303    pub async fn process_introduction(&mut self, introduction: Introduction) -> Result<()> {
304        let aid = match introduction.id {
305            borderless::prelude::Id::Contract { .. } => return Err(ErrorKind::InvalidIdType.into()),
306            borderless::prelude::Id::Agent { agent_id } => agent_id,
307        };
308        // NOTE: The input for the introduction is not the introduction, but only the initial state!
309        // The introduction itself is commited by the VmState
310        let initial_state = introduction.initial_state.to_string().into_bytes();
311        let res = self
312            .call_mut(
313                &aid,
314                initial_state,
315                "process_introduction",
316                Commit::Introduction(introduction),
317            )
318            .await?;
319        assert!(res.is_none(), "introductions should not write events");
320        Ok(())
321    }
322
323    // TODO: Calling process revocation on an already revoked agent should generate an error
324    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %revocation.id), err))]
325    pub async fn process_revocation(&mut self, revocation: Revocation) -> Result<()> {
326        let aid = match revocation.id {
327            borderless::prelude::Id::Contract { .. } => return Err(ErrorKind::InvalidIdType.into()),
328            borderless::prelude::Id::Agent { agent_id } => agent_id,
329        };
330        let input = revocation.to_bytes()?;
331        let res = self
332            .call_mut(
333                &aid,
334                input,
335                "process_revocation",
336                Commit::Revocation(revocation),
337            )
338            .await?;
339        assert!(res.is_none(), "revocations should not write events");
340        Ok(())
341    }
342
343    // OK; Just to get some stuff going; I want to just simply call an action, and execute an http-request with it.
344    // That's more than enough to test stuff out.
345    // TODO: Logging ?
346    #[must_use = "You have to handle the output events of this function"]
347    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid), err))]
348    pub async fn process_action(
349        &mut self,
350        aid: &AgentId,
351        action: CallAction,
352    ) -> Result<Option<Events>> {
353        // Parse action
354        let input = action.to_bytes()?;
355        self.call_mut(aid, input, "process_action", Commit::Other)
356            .await
357    }
358
359    /// Helper function for mutable calls
360    async fn call_mut(
361        &mut self,
362        aid: &AgentId,
363        input: Vec<u8>,
364        method: &'static str,
365        commit: Commit,
366    ) -> Result<Option<Events>> {
367        let (instance, mut store) = self
368            .agent_store
369            .get_agent(aid, &self.engine, &mut self.linker)
370            .await?
371            .ok_or_else(|| ErrorKind::MissingAgent { aid: *aid })?;
372
373        if self.agent_revoked(&aid)? {
374            return Err(ErrorKind::RevokedAgent { aid: *aid }.into());
375        }
376
377        let state = self.mutability_lock.get_lock_state(aid);
378        let _guard = state.lock.lock().await;
379
380        // Prepare registers
381        store.data_mut().set_register(REGISTER_INPUT, input);
382
383        // Buffered registers
384        store
385            .data_mut()
386            .set_register(REGISTER_EXECUTOR, self.executor.clone().unwrap_or_default());
387
388        // Inject ws-sender (if any)
389        if let Some(tx) = state.ws_sender {
390            store.data_mut().register_ws(tx)?;
391        }
392
393        // Call the actual function on the wasm side
394        let func = instance.get_typed_func::<(), ()>(&mut store, method)?;
395        store
396            .data_mut()
397            .prepare_exec(ActiveEntity::agent(*aid, true))?;
398
399        let commit = match func.call_async(&mut store, ()).await {
400            Ok(()) => Some(commit),
401            Err(e) => {
402                warn!("{method} failed with error: {e}");
403                None
404            }
405        };
406        let output = store.data().get_register(REGISTER_OUTPUT);
407        let _logs = store.data_mut().finish_exec(commit)?;
408
409        // Return output events
410        match output {
411            Some(bytes) => Ok(Some(Events::from_bytes(&bytes)?)),
412            None => Ok(None),
413        }
414    }
415
416    // --- NOTE: Maybe we should create a separate runtime for the HTTP handling ?
417
418    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid, %path), err))]
419    pub async fn http_get_state(&mut self, aid: &AgentId, path: String) -> Result<(u16, Vec<u8>)> {
420        // Get instance
421        let (instance, mut store) = self
422            .agent_store
423            .get_agent(aid, &self.engine, &mut self.linker)
424            .await?
425            .ok_or_else(|| ErrorKind::MissingAgent { aid: *aid })?;
426
427        // Prepare registers
428        store
429            .data_mut()
430            .set_register(REGISTER_INPUT_HTTP_PATH, path.into_bytes());
431
432        // Buffered registers
433        store
434            .data_mut()
435            .set_register(REGISTER_EXECUTOR, self.executor.clone().unwrap_or_default());
436
437        // Get function
438        let func = instance.get_typed_func::<(), ()>(&mut store, "http_get_state")?;
439
440        // Prepare execution
441        store
442            .data_mut()
443            .prepare_exec(ActiveEntity::agent(*aid, false))?;
444
445        // Call the function
446        if let Err(e) = func.call_async(&mut store, ()).await {
447            warn!("http_get_state failed with error: {e}");
448        }
449        let status = store.data().get_register(REGISTER_OUTPUT_HTTP_STATUS);
450        let result = store.data().get_register(REGISTER_OUTPUT_HTTP_RESULT);
451
452        // Finish the execution ( and commit nothing )
453        let _log = store.data_mut().finish_exec(None)?;
454
455        // Parse status
456        let status = status.ok_or_else(|| ErrorKind::MissingRegisterValue("http-status"))?;
457        let status_bytes = status
458            .try_into()
459            .map_err(|_| ErrorKind::InvalidRegisterValue {
460                register: "http-status",
461                expected_type: "u16",
462            })?;
463        let status = u16::from_be_bytes(status_bytes);
464
465        // Check result
466        let result = result.ok_or_else(|| ErrorKind::MissingRegisterValue("http-result"))?;
467
468        Ok((status, result))
469    }
470
471    // TODO: This will directly execute the action and return a list of events
472    //
473    // The question is, what should be returned via the web-api ?
474    /// Uses a POST request to parse and generate a [`CallAction`] object.
475    ///
476    /// The return type is a nested result. The outer result type should convert to a server error,
477    /// as it represents errors in the runtime itself.
478    /// The inner error type comes from the wasm code and contains the error status and message.
479    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(agent_id = %aid, %path, %writer), err))]
480    pub async fn http_post_action(
481        &mut self,
482        aid: &AgentId,
483        path: String,
484        payload: Vec<u8>,
485        writer: &BorderlessId, // TODO: I think the writer makes no sense here and is an artifact
486    ) -> Result<std::result::Result<(Events, CallAction), (u16, String)>> {
487        // Check whether agent exists
488        let Some((instance, mut store)) = self
489            .agent_store
490            .get_agent(aid, &self.engine, &mut self.linker)
491            .await?
492        else {
493            return Ok(Err((
494                StatusCode::BAD_REQUEST.as_u16(),
495                ErrorKind::MissingAgent { aid: *aid }.to_string(),
496            )));
497        };
498        // Check whether agent is revoked
499        if self.agent_revoked(&aid)? {
500            return Ok(Err((
501                StatusCode::BAD_REQUEST.as_u16(),
502                ErrorKind::RevokedAgent { aid: *aid }.to_string(),
503            )));
504        }
505
506        let state = self.mutability_lock.get_lock_state(aid);
507        let _guard = state.lock.lock().await;
508
509        // NOTE: We cannot convert the payload into a call-action on-spot, as we might call a nested route.
510        // To be precise - we *could* do it here, but I think it is cleaner to leave this logic up to the wasm module,
511        // as otherwise we may have to duplicate the logic here (and if it changes in the macro, we have to sync this with the code of the runtime etc.).
512        store
513            .data_mut()
514            .set_register(REGISTER_INPUT_HTTP_PATH, path.into_bytes());
515
516        store
517            .data_mut()
518            .set_register(REGISTER_INPUT_HTTP_PAYLOAD, payload);
519
520        store
521            .data_mut()
522            .set_register(REGISTER_WRITER, writer.into_bytes().into());
523
524        // Buffered registers
525        store
526            .data_mut()
527            .set_register(REGISTER_EXECUTOR, self.executor.clone().unwrap_or_default());
528
529        // Prepare mutable execution
530        store
531            .data_mut()
532            .prepare_exec(ActiveEntity::agent(*aid, true))?;
533
534        // Get function
535        let func = instance.get_typed_func::<(), ()>(&mut store, "http_post_action")?;
536
537        // Call the function
538        if let Err(e) = func.call_async(&mut store, ()).await {
539            warn!("http_get_state failed with error: {e}");
540        }
541        let status = store.data().get_register(REGISTER_OUTPUT_HTTP_STATUS);
542        let result = store.data().get_register(REGISTER_OUTPUT_HTTP_RESULT);
543        let output = store.data().get_register(REGISTER_OUTPUT);
544
545        // Finish the execution
546        // NOTE: This will clear all the registers !
547        let _log = store.data_mut().finish_exec(Some(Commit::Other))?;
548
549        // Parse status
550        let status = status.ok_or_else(|| ErrorKind::MissingRegisterValue("http-status"))?;
551        let status_bytes = status
552            .try_into()
553            .map_err(|_| ErrorKind::InvalidRegisterValue {
554                register: "http-status",
555                expected_type: "u16",
556            })?;
557        let status = u16::from_be_bytes(status_bytes);
558
559        // Check result
560        let result = result.ok_or_else(|| ErrorKind::MissingRegisterValue("http-result"))?;
561
562        if status == 200 {
563            let events = match output {
564                Some(b) => Events::from_bytes(&b)?,
565                None => Events::default(),
566            };
567            let action = CallAction::from_bytes(&result)?;
568            Ok(Ok((events, action)))
569        } else {
570            let error = String::from_utf8(result).map_err(|_| ErrorKind::InvalidRegisterValue {
571                register: "http-result",
572                expected_type: "string",
573            })?;
574            Ok(Err((status, error)))
575        }
576    }
577
578    /// Returns the symbols of the contract
579    pub async fn get_symbols(&mut self, aid: &AgentId) -> Result<Option<Symbols>> {
580        let (instance, mut store) = self
581            .agent_store
582            .get_agent(aid, &self.engine, &mut self.linker)
583            .await?
584            .ok_or_else(|| ErrorKind::MissingAgent { aid: *aid })?;
585
586        store.data_mut().prepare_exec(ActiveEntity::None)?;
587
588        // In case the contract does not export any symbols, just return 'None'
589        if let Err(e) = instance
590            .get_typed_func::<(), ()>(&mut store, "get_symbols")
591            .and_then(|func| func.call(&mut store, ()))
592        {
593            error!("get_symbols failed with error: {e}");
594        }
595        let output = store.data().get_register(REGISTER_OUTPUT);
596        store.data_mut().finish_exec(None)?;
597
598        let bytes = match output {
599            Some(b) => b,
600            None => return Ok(None),
601        };
602        let symbols = Symbols::from_bytes(&bytes)?;
603        Ok(Some(symbols))
604    }
605
606    pub fn available_agents(&self) -> Result<Vec<AgentId>> {
607        self.agent_store.available_swagents()
608    }
609}
610
611// NOTE: We Mis-Use the lock here to also carry persistent state - e.g. for the websocket
612#[derive(Default, Clone)]
613pub struct Lock {
614    lock: Arc<Mutex<()>>,
615    ws_sender: Option<mpsc::Sender<Vec<u8>>>,
616}
617
618/// Global mutability lock for all SW-Agents
619///
620/// Since we can only allow one mutable agent execution at a given time, we need a mechanism to ensure that.
621/// The `MutLock` ensures this on a per-agent basis. It holds `RwLock`s for all agents and provides threadsafe access.
622///
623/// The logic is similar but not identical to rusts ownership rules. While there can be only one read-write (mutable) execution,
624/// there can be multiple read-only (immutable) executions even if there is an ongoing read-write execution !
625/// The reason behind this is basically that read-only executions do not produce storage operations that would change the state in the database.
626/// In the `VmState`, all write operations are buffered until the execution is finished. If there would be two executions in parallel,
627/// we might end up commiting changes to a state, that has already changed under the hood - which is not what we want.
628/// However, if there is a writer thread, the readers do not care, and also the writer does not care about the readers.
629/// The readers will use the old state, until the new one is commited by the runtime.
630///
631/// Note: In contrast to [`borderless_runtime::rt::contract::MutLock`],
632/// this version uses asynchronous locks for the agents, and a synchronous lock only for the access of the hashmap.
633///
634/// Additionally, this double-functions as the provider of over-arching state per agent,
635/// e.g. the lock also contains the websocket sender, for agents that use websockets.
636#[derive(Clone, Default)]
637pub struct MutLock {
638    map: Arc<SyncMutex<HashMap<AgentId, Lock>>>,
639}
640
641impl MutLock {
642    /// Returns the `RwLock` for the given agent.
643    ///
644    /// If the agent-id is unknown, a new lock is created.
645    pub fn get_lock_state(&self, aid: &AgentId) -> Lock {
646        let mut map = self.map.lock();
647        let lock = map.entry(*aid).or_default();
648        lock.clone()
649    }
650
651    /// Inserts a new ws-sender for the agent
652    ///
653    /// Panics if the sender already contained a lock
654    pub fn insert_ws_sender(&self, aid: &AgentId, ws_sender: mpsc::Sender<Vec<u8>>) {
655        let mut map = self.map.lock();
656        let lock = map.entry(*aid).or_default();
657        assert!(
658            lock.ws_sender.is_none(),
659            "Cannot register websocket twice on the same agent"
660        );
661        lock.ws_sender = Some(ws_sender);
662    }
663}
664
665// NOTE: We could also check, if the websocket functions are exported,
666// and do a consistency check, if the module uses a websocket.
667// But maybe that's overkill.
668fn check_module(engine: &Engine, module: &Module) -> Result<()> {
669    let functions = [
670        "on_init",
671        "on_shutdown",
672        "process_action",
673        "process_introduction",
674        "process_revocation",
675        "http_get_state",
676        "http_post_action",
677        "parse_state",
678        "get_symbols",
679    ];
680    for func in functions {
681        let exp = module
682            .get_export(func)
683            .ok_or_else(|| ErrorKind::MissingExport { func })?;
684        if let ExternType::Func(func_type) = exp {
685            if !func_type.matches(&FuncType::new(engine, [], [])) {
686                return Err(ErrorKind::InvalidFuncType { func }.into());
687            }
688        } else {
689            return Err(ErrorKind::InvalidExport { func }.into());
690        }
691    }
692    Ok(())
693}
694
#[cfg(test)]
mod tests {
    use super::*;

    /// WAT module exporting every function `check_module` requires,
    /// all bound to the same empty placeholder function.
    const ALL_EXPORTS: &str = r#"
(module
  ;; Declare the function `placeholder`
  (func $placeholder)

  ;; Export the functions so they can be called from outside the module
  (export "on_init" (func $placeholder))
  (export "on_shutdown" (func $placeholder))
  (export "process_action" (func $placeholder))
  (export "process_introduction" (func $placeholder))
  (export "process_revocation" (func $placeholder))
  (export "http_get_state" (func $placeholder))
  (export "http_post_action" (func $placeholder))
  (export "parse_state" (func $placeholder))
  (export "get_symbols" (func $placeholder))
)
"#;

    /// Returns `original` without every line that contains `pattern`, joined with `\n`.
    fn remove_line_with_pattern(original: &str, pattern: &str) -> String {
        original
            .lines()
            .filter(|line| !line.contains(pattern))
            .collect::<Vec<_>>()
            .join("\n")
    }

    /// Builds a synchronous engine; validating exports needs no async support.
    fn sync_engine() -> Engine {
        let mut config = Config::new();
        config
            .cranelift_opt_level(wasmtime::OptLevel::Speed)
            .async_support(false);
        Engine::new(&config).unwrap()
    }

    #[test]
    fn missing_exports() {
        let engine = sync_engine();

        // Removing any single one of these exports must make the module invalid
        let required = [
            "on_init",
            "on_shutdown",
            "process_action",
            "process_introduction",
            "process_revocation",
            "http_get_state",
            "http_post_action",
            "parse_state",
            "get_symbols",
        ];
        for name in &required {
            let stripped = remove_line_with_pattern(ALL_EXPORTS, name);
            let parsed = Module::new(&engine, &stripped);
            assert!(parsed.is_ok());
            assert!(check_module(&engine, &parsed.unwrap()).is_err());
        }

        // The complete module passes the check
        let complete = Module::new(&engine, ALL_EXPORTS);
        assert!(complete.is_ok());
        assert!(check_module(&engine, &complete.unwrap()).is_ok());
    }
}