1use std::ops::{Deref, DerefMut};
2
3use anyhow::Result;
4use async_trait::async_trait;
5use surrealism_types::controller::AsyncMemoryController;
6use surrealism_types::err::PrefixError;
7use surrealism_types::serialize::SerializableRange;
8use surrealism_types::transfer::AsyncTransfer;
9use wasmtime::{Caller, Linker};
10
11use crate::config::SurrealismConfig;
12use crate::controller::StoreData;
13use crate::kv::KVStore;
14
/// Unwraps a `Result` inside a host-function body, bailing out with `-1` on error.
///
/// On `Ok(val)` the macro evaluates to `val`. On `Err(e)` it prints
/// `"<$error>: <e>"` to stderr and makes the enclosing function return `-1`,
/// the value handed back to the guest for a failed host call. `$error` is only
/// evaluated on the error path.
///
/// The error is formatted with `{:#}` so that `anyhow` errors print their full
/// context chain (`outer: inner: ...`) rather than only the outermost message;
/// for plain `Display` types the alternate flag makes no difference.
macro_rules! host_try_or_return {
    ($error:expr,$expr:expr) => {
        match $expr {
            Ok(val) => val,
            Err(e) => {
                eprintln!("{}: {:#}", $error, e);
                return -1;
            }
        }
    };
}
26
/// Expands to the type `u32`, whatever type it is given.
///
/// Used in type position to emit one `u32` per declared host-function
/// argument: every argument crosses the Wasm boundary as a plain `u32`, so the
/// Rust type named in the declaration is discarded here.
macro_rules! force_u32 {
    ($discarded:ty) => { u32 };
}
33
/// Registers an async host function named `$name` in the `env` module of `$linker`.
///
/// Syntax:
/// `register_host_function!(linker, "name", |[mut] ctl: Ctl, a: A, b: B| -> Result<R> { body })`
///
/// Each declared argument arrives from the guest as a `u32` and is decoded with
/// `<A>::receive(..)` before `body` runs. `body` must evaluate to
/// `anyhow::Result<R>`; the declared `R` is enforced by a type annotation. The
/// value is passed to `.transfer(..)`, and the resulting `u32` is returned to the
/// guest as an `i32`. Decode or transfer failures are logged to stderr and
/// reported to the guest as `-1`.
///
/// The expansion ends in `?`, so it must be used inside a function returning
/// `anyhow::Result`. It also expects `Caller`, `StoreData`, `HostController`,
/// `host_try_or_return!` and `force_u32!` to be nameable at the call site.
#[macro_export]
macro_rules! register_host_function {
    // Internal rule holding the single shared expansion. It is listed first so
    // the public rules below are never tried against `@expand` input.
    (@expand $linker:expr, $name:expr, $controller:ident, $controller_ty:ty, $ret:ty, $body:tt, $($arg:ident : $arg_ty:ty),+) => {{
        $linker
            .func_wrap_async(
                "env",
                $name,
                // Trailing commas keep both the pattern and the parameter type a
                // tuple even for a single argument: `(a,): (u32,)`, not `(a): (u32)`.
                |caller: Caller<'_, StoreData>, ($($arg,)+): ($(force_u32!($arg_ty),)+)| {
                    Box::new(async move {
                        let mut $controller: $controller_ty = HostController::from(caller);
                        $(let $arg = host_try_or_return!("Failed to receive argument", <$arg_ty>::receive($arg.into(), &mut $controller).await);)+
                        // Annotating the result enforces the declared return type, so
                        // a body that drifts from the guest-facing signature fails to
                        // compile instead of transferring an unexpected type.
                        let result: ::anyhow::Result<$ret> = $body;
                        (*host_try_or_return!("Transfer error", result.transfer(&mut $controller).await)) as i32
                    })
                },
            )
            .prefix_err(|| "failed to register host function")?
    }};
    ($linker:expr, $name:expr, |mut $controller:ident : $controller_ty:ty, $($arg:ident : $arg_ty:ty),+| -> Result<$ret:ty> $body:tt) => {
        $crate::register_host_function!(@expand $linker, $name, $controller, $controller_ty, $ret, $body, $($arg : $arg_ty),+)
    };
    ($linker:expr, $name:expr, |$controller:ident : $controller_ty:ty, $($arg:ident : $arg_ty:ty),+| -> Result<$ret:ty> $body:tt) => {
        $crate::register_host_function!(@expand $linker, $name, $controller, $controller_ty, $ret, $body, $($arg : $arg_ty),+)
    };
}
106
/// `Result::and_then` without a closure.
///
/// If `$expr` is `Ok(v)`, binds `v` to `$x` and evaluates `$body`, which must
/// itself produce a `Result`; an `Err` is passed through unchanged. Because
/// `$body` is spliced inline rather than wrapped in a closure, it may `.await`.
macro_rules! map_ok {
    ($expr:expr => |$x:ident| $body:expr) => {
        match $expr {
            Err(err) => Err(err),
            Ok($x) => $body,
        }
    };
}
115
116#[async_trait]
119pub trait InvocationContext: Send + Sync {
120 async fn sql(
121 &mut self,
122 config: &SurrealismConfig,
123 query: String,
124 vars: surrealdb_types::Object,
125 ) -> Result<surrealdb_types::Value>;
126 async fn run(
127 &mut self,
128 config: &SurrealismConfig,
129 fnc: String,
130 version: Option<String>,
131 args: Vec<surrealdb_types::Value>,
132 ) -> Result<surrealdb_types::Value>;
133
134 fn kv(&mut self) -> Result<&dyn KVStore>;
135
136 fn stdout(&mut self, output: &str) -> Result<()> {
138 print!("{}", output);
140 Ok(())
141 }
142
143 fn stderr(&mut self, output: &str) -> Result<()> {
145 eprint!("{}", output);
147 Ok(())
148 }
149}
150
151pub trait Host: InvocationContext {}
153
154pub fn implement_host_functions(linker: &mut Linker<StoreData>) -> Result<()> {
155 #[rustfmt::skip]
157 register_host_function!(linker, "__sr_sql", |mut controller: HostController, sql: String, vars: Vec<(String, surrealdb_types::Value)>| -> Result<surrealdb_types::Value> {
158 let vars = surrealdb_types::Object::from_iter(vars.into_iter());
159 let config = controller.config().clone();
160 controller.context_mut().sql(&config, sql, vars).await
161 });
162
163 #[rustfmt::skip]
165 register_host_function!(linker, "__sr_run", |mut controller: HostController, fnc: String, version: Option<String>, args: Vec<surrealdb_types::Value>| -> Result<surrealdb_types::Value> {
166 let config = controller.config().clone();
167 controller.context_mut().run(&config, fnc, version, args).await
168 });
169
170 #[rustfmt::skip]
172 register_host_function!(linker, "__sr_kv_get", |mut controller: HostController, key: String| -> Result<Option<surrealdb_types::Value>> {
173 map_ok!(controller.context_mut().kv() => |kv| kv.get(key).await)
174 });
175
176 #[rustfmt::skip]
177 register_host_function!(linker, "__sr_kv_set", |mut controller: HostController, key: String, value: surrealdb_types::Value| -> Result<()> {
178 map_ok!(controller.context_mut().kv() => |kv| kv.set(key, value).await)
179 });
180
181 #[rustfmt::skip]
182 register_host_function!(linker, "__sr_kv_del", |mut controller: HostController, key: String| -> Result<()> {
183 map_ok!(controller.context_mut().kv() => |kv| kv.del(key).await)
184 });
185
186 #[rustfmt::skip]
187 register_host_function!(linker, "__sr_kv_exists", |mut controller: HostController, key: String| -> Result<bool> {
188 map_ok!(controller.context_mut().kv() => |kv| kv.exists(key).await)
189 });
190
191 #[rustfmt::skip]
192 register_host_function!(linker, "__sr_kv_del_rng", |mut controller: HostController, range: SerializableRange<String>| -> Result<()> {
193 map_ok!(controller.context_mut().kv() => |kv| kv.del_rng(range.beg, range.end).await)
194 });
195
196 #[rustfmt::skip]
197 register_host_function!(linker, "__sr_kv_get_batch", |mut controller: HostController, keys: Vec<String>| -> Result<Vec<Option<surrealdb_types::Value>>> {
198 map_ok!(controller.context_mut().kv() => |kv| kv.get_batch(keys).await)
199 });
200
201 #[rustfmt::skip]
202 register_host_function!(linker, "__sr_kv_set_batch", |mut controller: HostController, entries: Vec<(String, surrealdb_types::Value)>| -> Result<()> {
203 map_ok!(controller.context_mut().kv() => |kv| kv.set_batch(entries).await)
204 });
205
206 #[rustfmt::skip]
207 register_host_function!(linker, "__sr_kv_del_batch", |mut controller: HostController, keys: Vec<String>| -> Result<()> {
208 map_ok!(controller.context_mut().kv() => |kv| kv.del_batch(keys).await)
209 });
210
211 #[rustfmt::skip]
212 register_host_function!(linker, "__sr_kv_keys", |mut controller: HostController, range: SerializableRange<String>| -> Result<Vec<String>> {
213 map_ok!(controller.context_mut().kv() => |kv| kv.keys(range.beg, range.end).await)
214 });
215
216 #[rustfmt::skip]
217 register_host_function!(linker, "__sr_kv_values", |mut controller: HostController, range: SerializableRange<String>| -> Result<Vec<surrealdb_types::Value>> {
218 map_ok!(controller.context_mut().kv() => |kv| kv.values(range.beg, range.end).await)
219 });
220
221 #[rustfmt::skip]
222 register_host_function!(linker, "__sr_kv_entries", |mut controller: HostController, range: SerializableRange<String>| -> Result<Vec<(String, surrealdb_types::Value)>> {
223 map_ok!(controller.context_mut().kv() => |kv| kv.entries(range.beg, range.end).await)
224 });
225
226 #[rustfmt::skip]
227 register_host_function!(linker, "__sr_kv_count", |mut controller: HostController, range: SerializableRange<String>| -> Result<u64> {
228 map_ok!(controller.context_mut().kv() => |kv| kv.count(range.beg, range.end).await)
229 });
230
231 Ok(())
232}
233
/// Per-call wrapper around the wasmtime [`Caller`] handed to each host function.
///
/// Gives host functions access to the store's `StoreData` (its `config` and
/// `context`). Through its `AsyncMemoryController` impl, it also reaches the
/// guest's `memory`, `__sr_alloc` and `__sr_free` exports.
struct HostController<'a>(Caller<'a, StoreData>);
235
236impl<'a> HostController<'a> {
237 pub fn context_mut(&mut self) -> &mut dyn InvocationContext {
239 &mut *self.0.data_mut().context
240 }
241
242 pub fn config(&self) -> &SurrealismConfig {
243 &self.0.data().config
244 }
245}
246
247impl<'a> From<Caller<'a, StoreData>> for HostController<'a> {
248 fn from(caller: Caller<'a, StoreData>) -> Self {
249 Self(caller)
250 }
251}
252
253impl<'a> Deref for HostController<'a> {
254 type Target = Caller<'a, StoreData>;
255 fn deref(&self) -> &Self::Target {
256 &self.0
257 }
258}
259
260impl<'a> DerefMut for HostController<'a> {
261 fn deref_mut(&mut self) -> &mut Self::Target {
262 &mut self.0
263 }
264}
265
266#[async_trait]
267impl<'a> AsyncMemoryController for HostController<'a> {
268 async fn alloc(&mut self, len: u32) -> Result<u32> {
269 let alloc_func = self
270 .get_export("__sr_alloc")
271 .ok_or_else(|| anyhow::anyhow!("Export __sr_alloc not found"))?
272 .into_func()
273 .ok_or_else(|| anyhow::anyhow!("Export __sr_alloc is not a function"))?;
274 let result =
275 alloc_func.typed::<(u32,), u32>(&mut self.0)?.call_async(&mut self.0, (len,)).await?;
276 if result == 0 {
277 anyhow::bail!("Memory allocation failed");
278 }
279 Ok(result)
280 }
281
282 async fn free(&mut self, ptr: u32, len: u32) -> Result<()> {
283 let free_func = self
284 .get_export("__sr_free")
285 .ok_or_else(|| anyhow::anyhow!("Export __sr_free not found"))?
286 .into_func()
287 .ok_or_else(|| anyhow::anyhow!("Export __sr_free is not a function"))?;
288 let result = free_func
289 .typed::<(u32, u32), u32>(&mut self.0)?
290 .call_async(&mut self.0, (ptr, len))
291 .await?;
292 if result == 0 {
293 anyhow::bail!("Memory deallocation failed");
294 }
295 Ok(())
296 }
297
298 fn mut_mem(&mut self, ptr: u32, len: u32) -> Result<&mut [u8]> {
299 let memory = self
300 .get_export("memory")
301 .ok_or_else(|| anyhow::anyhow!("Export memory not found"))?
302 .into_memory()
303 .ok_or_else(|| anyhow::anyhow!("Export memory is not a memory"))?;
304 let mem = memory.data_mut(&mut self.0);
305 if (ptr as usize) + (len as usize) > mem.len() {
306 anyhow::bail!(
307 "[ERROR] Out of bounds: ptr + len = {} > mem.len() = {}",
308 (ptr as usize) + (len as usize),
309 mem.len()
310 );
311 }
312 Ok(&mut mem[(ptr as usize)..(ptr as usize) + (len as usize)])
313 }
314}