1use std::{convert::TryFrom, sync::Arc};
2
3use selium_abi::hostcalls::Hostcall;
4use selium_abi::{RkyvEncode, encode_rkyv};
5use tracing::{debug, trace};
6use wasmtime::{Caller, Linker};
7
8use crate::{
9 KernelError,
10 futures::FutureSharedState,
11 guest_data::{
12 GuestError, GuestInt, GuestResult, GuestUint, read_rkyv_value, write_poll_result,
13 },
14 registry::InstanceRegistry,
15};
16
17pub trait Contract {
21 type Input: RkyvEncode + Send;
22 type Output: RkyvEncode + Send;
23
24 fn to_future(
25 &self,
26 caller: &mut Caller<'_, InstanceRegistry>,
27 input: Self::Input,
28 ) -> impl Future<Output = GuestResult<Self::Output>> + Send + 'static;
29}
30
/// Pairs a driver (a [`Contract`] implementation in every `impl` of this type) with the
/// wasm import module name its `create` / `poll` / `drop` host functions are linked under.
pub struct Operation<Driver> {
    /// Import module name used for all three host functions.
    module: &'static str,
    /// Produces the future serviced by each `create` call.
    driver: Driver,
}
36
37pub trait LinkableOperation: Send + Sync {
39 fn link(&self, linker: &mut Linker<InstanceRegistry>) -> Result<(), KernelError>;
40}
41
42struct OperationLinker<Driver> {
43 operation: Arc<Operation<Driver>>,
44}
45
46impl<Driver> LinkableOperation for OperationLinker<Driver>
47where
48 Driver: Contract + Send + Sync + 'static,
49 for<'a> <Driver::Input as rkyv::Archive>::Archived: 'a
50 + rkyv::Deserialize<Driver::Input, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>
51 + rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, rkyv::rancor::Error>>,
52 for<'a> <Driver::Output as rkyv::Archive>::Archived: 'a
53 + rkyv::Deserialize<Driver::Output, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>
54 + rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, rkyv::rancor::Error>>,
55{
56 fn link(&self, linker: &mut Linker<InstanceRegistry>) -> Result<(), KernelError> {
57 self.operation.link(linker)
58 }
59}
60
61impl<Driver> Operation<Driver>
62where
63 Driver: Contract,
64 for<'a> <Driver::Input as rkyv::Archive>::Archived: 'a
65 + rkyv::Deserialize<Driver::Input, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>
66 + rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, rkyv::rancor::Error>>,
67 for<'a> <Driver::Output as rkyv::Archive>::Archived: 'a
68 + rkyv::Deserialize<Driver::Output, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>
69 + rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, rkyv::rancor::Error>>,
70{
71 pub fn new(driver: Driver, module: &'static str) -> Arc<Self> {
72 Arc::new(Self { driver, module })
73 }
74
75 pub fn from_hostcall(
77 driver: Driver,
78 hostcall: &'static Hostcall<Driver::Input, Driver::Output>,
79 ) -> Arc<Self> {
80 Self::new(driver, hostcall.name())
81 }
82}
83
84impl<Driver> Operation<Driver>
85where
86 Driver: Contract + Send + Sync + 'static,
87 for<'a> <Driver::Input as rkyv::Archive>::Archived: 'a
88 + rkyv::Deserialize<Driver::Input, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>
89 + rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, rkyv::rancor::Error>>,
90 for<'a> <Driver::Output as rkyv::Archive>::Archived: 'a
91 + rkyv::Deserialize<Driver::Output, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>
92 + rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, rkyv::rancor::Error>>,
93{
94 pub fn link(
95 self: &Arc<Self>,
96 linker: &mut Linker<InstanceRegistry>,
97 ) -> Result<(), KernelError> {
98 let this = self.clone();
99 linker.func_wrap(
100 self.module,
101 "create",
102 move |caller: Caller<'_, InstanceRegistry>, args_ptr: GuestInt, args_len: GuestUint| {
103 this.create(caller, args_ptr, args_len).map_err(Into::into)
104 },
105 )?;
106
107 let this = self.clone();
108 linker.func_wrap(
109 self.module,
110 "poll",
111 move |caller: Caller<'_, InstanceRegistry>,
112 state_id: GuestUint,
113 task_id: GuestUint,
114 result_ptr: GuestInt,
115 result_capacity: GuestUint| {
116 this.poll(caller, state_id, task_id, result_ptr, result_capacity)
117 .map_err(Into::into)
118 },
119 )?;
120
121 let this = self.clone();
122 linker.func_wrap(
123 self.module,
124 "drop",
125 move |caller: Caller<'_, InstanceRegistry>,
126 state_id: GuestUint,
127 result_ptr: GuestInt,
128 result_capacity: GuestUint| {
129 this.drop(caller, state_id, result_ptr, result_capacity)
130 .map_err(Into::into)
131 },
132 )?;
133
134 Ok(())
135 }
136
137 fn create(
138 self: &Arc<Self>,
139 mut caller: Caller<'_, InstanceRegistry>,
140 ptr: GuestInt,
141 len: GuestUint,
142 ) -> Result<GuestUint, KernelError> {
143 trace!("Creating future for {}", self.module);
144
145 let input = read_rkyv_value::<Driver::Input>(&mut caller, ptr, len)?;
146 let task = self.driver.to_future(&mut caller, input);
147 let state = FutureSharedState::new();
148 let shared = Arc::clone(&state);
149 tokio::spawn(async move {
150 let result = task.await.and_then(|out| {
151 encode_rkyv(&out)
152 .map_err(|err| GuestError::Kernel(KernelError::Driver(err.to_string())))
153 });
154 shared.resolve(result);
155 });
156
157 let handle = caller.data_mut().insert_future(Arc::clone(&state))?;
158
159 GuestUint::try_from(handle).map_err(KernelError::IntConvert)
160 }
161
162 fn poll(
163 self: &Arc<Self>,
164 mut caller: Caller<'_, InstanceRegistry>,
165 state_id: GuestUint,
166 task_id: GuestUint,
167 ptr: GuestInt,
168 capacity: GuestUint,
169 ) -> Result<GuestUint, KernelError> {
170 trace!("Polling future for {}", self.module);
171
172 let state_id = usize::try_from(state_id)?;
173 let task_id = usize::try_from(task_id)?;
174
175 if let Some(base) = mailbox_base(&mut caller) {
176 caller.data().refresh_mailbox(base);
177 }
178
179 let guest_result = {
180 let registry = caller.data_mut();
181 match registry.future_state(state_id) {
182 Some(state) => {
183 let waker = registry.waker(task_id).ok_or_else(|| {
184 KernelError::Driver("guest mailbox unavailable".to_string())
185 })?;
186 state.register_waker(waker);
187
188 match state.take_result() {
189 None => Err(GuestError::WouldBlock),
190 Some(output) => {
191 registry.remove_future(state_id);
192 output
193 }
194 }
195 }
196 None => Err(GuestError::NotFound),
197 }
198 };
199
200 let written = write_poll_result(
201 &mut caller,
202 ptr,
203 capacity,
204 guest_result.inspect_err(|e| {
205 if !matches!(e, GuestError::WouldBlock) {
206 debug!("Future failed with error: {e}");
207 }
208 }),
209 )?;
210 Ok(written as GuestUint)
211 }
212
213 fn drop(
214 self: &Arc<Self>,
215 mut caller: Caller<'_, InstanceRegistry>,
216 state_id: GuestUint,
217 ptr: GuestInt,
218 capacity: GuestUint,
219 ) -> Result<GuestUint, KernelError> {
220 trace!("Dropping future for {}", self.module);
221
222 let state_id = usize::try_from(state_id)?;
223
224 let guest_result = {
225 let registry = caller.data_mut();
226 if let Some(state) = registry.remove_future(state_id) {
227 state.abandon();
228 Ok(Vec::new())
229 } else {
230 Err(GuestError::NotFound)
231 }
232 };
233
234 let written = write_poll_result(&mut caller, ptr, capacity, guest_result)?;
235 Ok(written as GuestUint)
236 }
237}
238
239impl<Driver> Operation<Driver>
240where
241 Driver: Contract + Send + Sync + 'static,
242 for<'a> <Driver::Input as rkyv::Archive>::Archived: 'a
243 + rkyv::Deserialize<Driver::Input, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>
244 + rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, rkyv::rancor::Error>>,
245 for<'a> <Driver::Output as rkyv::Archive>::Archived: 'a
246 + rkyv::Deserialize<Driver::Output, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>
247 + rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, rkyv::rancor::Error>>,
248{
249 pub fn as_linkable(self: &Arc<Self>) -> Arc<dyn LinkableOperation> {
250 Arc::new(OperationLinker {
251 operation: Arc::clone(self),
252 })
253 }
254}
255
256fn mailbox_base(caller: &mut Caller<'_, InstanceRegistry>) -> Option<usize> {
257 caller
258 .get_export("memory")
259 .and_then(|export| export.into_memory())
260 .map(|memory| memory.data_ptr(&mut *caller) as usize)
261}