// selium_kernel/operation.rs

1use std::{convert::TryFrom, sync::Arc};
2
3use selium_abi::hostcalls::Hostcall;
4use selium_abi::{RkyvEncode, encode_rkyv};
5use tracing::{debug, trace};
6use wasmtime::{Caller, Linker};
7
8use crate::{
9    KernelError,
10    futures::FutureSharedState,
11    guest_data::{
12        GuestError, GuestInt, GuestResult, GuestUint, read_rkyv_value, write_poll_result,
13    },
14    registry::InstanceRegistry,
15};
16
17/// `Contract` is used by kernel drivers to define a consistent method for guest execution.
18/// This allows [`Operation`]s to expose the driver contract to the guest without having
19/// to know its internal structure.
20pub trait Contract {
21    type Input: RkyvEncode + Send;
22    type Output: RkyvEncode + Send;
23
24    fn to_future(
25        &self,
26        caller: &mut Caller<'_, InstanceRegistry>,
27        input: Self::Input,
28    ) -> impl Future<Output = GuestResult<Self::Output>> + Send + 'static;
29}
30
/// A driver-backed asynchronous system task that a guest can start, poll and drop without
/// blocking.
pub struct Operation<Driver> {
    /// Driver that builds the future for each guest `create` call.
    driver: Driver,
    /// Wasm import module name under which this operation's hostcalls are registered.
    module: &'static str,
}
36
37/// Trait object for operations that can be linked into a Wasmtime linker.
38pub trait LinkableOperation: Send + Sync {
39    fn link(&self, linker: &mut Linker<InstanceRegistry>) -> Result<(), KernelError>;
40}
41
42struct OperationLinker<Driver> {
43    operation: Arc<Operation<Driver>>,
44}
45
46impl<Driver> LinkableOperation for OperationLinker<Driver>
47where
48    Driver: Contract + Send + Sync + 'static,
49    for<'a> <Driver::Input as rkyv::Archive>::Archived: 'a
50        + rkyv::Deserialize<Driver::Input, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>
51        + rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, rkyv::rancor::Error>>,
52    for<'a> <Driver::Output as rkyv::Archive>::Archived: 'a
53        + rkyv::Deserialize<Driver::Output, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>
54        + rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, rkyv::rancor::Error>>,
55{
56    fn link(&self, linker: &mut Linker<InstanceRegistry>) -> Result<(), KernelError> {
57        self.operation.link(linker)
58    }
59}
60
61impl<Driver> Operation<Driver>
62where
63    Driver: Contract,
64    for<'a> <Driver::Input as rkyv::Archive>::Archived: 'a
65        + rkyv::Deserialize<Driver::Input, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>
66        + rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, rkyv::rancor::Error>>,
67    for<'a> <Driver::Output as rkyv::Archive>::Archived: 'a
68        + rkyv::Deserialize<Driver::Output, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>
69        + rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, rkyv::rancor::Error>>,
70{
71    pub fn new(driver: Driver, module: &'static str) -> Arc<Self> {
72        Arc::new(Self { driver, module })
73    }
74
75    /// Create an operation from a canonical hostcall descriptor.
76    pub fn from_hostcall(
77        driver: Driver,
78        hostcall: &'static Hostcall<Driver::Input, Driver::Output>,
79    ) -> Arc<Self> {
80        Self::new(driver, hostcall.name())
81    }
82}
83
84impl<Driver> Operation<Driver>
85where
86    Driver: Contract + Send + Sync + 'static,
87    for<'a> <Driver::Input as rkyv::Archive>::Archived: 'a
88        + rkyv::Deserialize<Driver::Input, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>
89        + rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, rkyv::rancor::Error>>,
90    for<'a> <Driver::Output as rkyv::Archive>::Archived: 'a
91        + rkyv::Deserialize<Driver::Output, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>
92        + rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, rkyv::rancor::Error>>,
93{
94    pub fn link(
95        self: &Arc<Self>,
96        linker: &mut Linker<InstanceRegistry>,
97    ) -> Result<(), KernelError> {
98        let this = self.clone();
99        linker.func_wrap(
100            self.module,
101            "create",
102            move |caller: Caller<'_, InstanceRegistry>, args_ptr: GuestInt, args_len: GuestUint| {
103                this.create(caller, args_ptr, args_len).map_err(Into::into)
104            },
105        )?;
106
107        let this = self.clone();
108        linker.func_wrap(
109            self.module,
110            "poll",
111            move |caller: Caller<'_, InstanceRegistry>,
112                  state_id: GuestUint,
113                  task_id: GuestUint,
114                  result_ptr: GuestInt,
115                  result_capacity: GuestUint| {
116                this.poll(caller, state_id, task_id, result_ptr, result_capacity)
117                    .map_err(Into::into)
118            },
119        )?;
120
121        let this = self.clone();
122        linker.func_wrap(
123            self.module,
124            "drop",
125            move |caller: Caller<'_, InstanceRegistry>,
126                  state_id: GuestUint,
127                  result_ptr: GuestInt,
128                  result_capacity: GuestUint| {
129                this.drop(caller, state_id, result_ptr, result_capacity)
130                    .map_err(Into::into)
131            },
132        )?;
133
134        Ok(())
135    }
136
137    fn create(
138        self: &Arc<Self>,
139        mut caller: Caller<'_, InstanceRegistry>,
140        ptr: GuestInt,
141        len: GuestUint,
142    ) -> Result<GuestUint, KernelError> {
143        trace!("Creating future for {}", self.module);
144
145        let input = read_rkyv_value::<Driver::Input>(&mut caller, ptr, len)?;
146        let task = self.driver.to_future(&mut caller, input);
147        let state = FutureSharedState::new();
148        let shared = Arc::clone(&state);
149        tokio::spawn(async move {
150            let result = task.await.and_then(|out| {
151                encode_rkyv(&out)
152                    .map_err(|err| GuestError::Kernel(KernelError::Driver(err.to_string())))
153            });
154            shared.resolve(result);
155        });
156
157        let handle = caller.data_mut().insert_future(Arc::clone(&state))?;
158
159        GuestUint::try_from(handle).map_err(KernelError::IntConvert)
160    }
161
162    fn poll(
163        self: &Arc<Self>,
164        mut caller: Caller<'_, InstanceRegistry>,
165        state_id: GuestUint,
166        task_id: GuestUint,
167        ptr: GuestInt,
168        capacity: GuestUint,
169    ) -> Result<GuestUint, KernelError> {
170        trace!("Polling future for {}", self.module);
171
172        let state_id = usize::try_from(state_id)?;
173        let task_id = usize::try_from(task_id)?;
174
175        if let Some(base) = mailbox_base(&mut caller) {
176            caller.data().refresh_mailbox(base);
177        }
178
179        let guest_result = {
180            let registry = caller.data_mut();
181            match registry.future_state(state_id) {
182                Some(state) => {
183                    let waker = registry.waker(task_id).ok_or_else(|| {
184                        KernelError::Driver("guest mailbox unavailable".to_string())
185                    })?;
186                    state.register_waker(waker);
187
188                    match state.take_result() {
189                        None => Err(GuestError::WouldBlock),
190                        Some(output) => {
191                            registry.remove_future(state_id);
192                            output
193                        }
194                    }
195                }
196                None => Err(GuestError::NotFound),
197            }
198        };
199
200        let written = write_poll_result(
201            &mut caller,
202            ptr,
203            capacity,
204            guest_result.inspect_err(|e| {
205                if !matches!(e, GuestError::WouldBlock) {
206                    debug!("Future failed with error: {e}");
207                }
208            }),
209        )?;
210        Ok(written as GuestUint)
211    }
212
213    fn drop(
214        self: &Arc<Self>,
215        mut caller: Caller<'_, InstanceRegistry>,
216        state_id: GuestUint,
217        ptr: GuestInt,
218        capacity: GuestUint,
219    ) -> Result<GuestUint, KernelError> {
220        trace!("Dropping future for {}", self.module);
221
222        let state_id = usize::try_from(state_id)?;
223
224        let guest_result = {
225            let registry = caller.data_mut();
226            if let Some(state) = registry.remove_future(state_id) {
227                state.abandon();
228                Ok(Vec::new())
229            } else {
230                Err(GuestError::NotFound)
231            }
232        };
233
234        let written = write_poll_result(&mut caller, ptr, capacity, guest_result)?;
235        Ok(written as GuestUint)
236    }
237}
238
239impl<Driver> Operation<Driver>
240where
241    Driver: Contract + Send + Sync + 'static,
242    for<'a> <Driver::Input as rkyv::Archive>::Archived: 'a
243        + rkyv::Deserialize<Driver::Input, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>
244        + rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, rkyv::rancor::Error>>,
245    for<'a> <Driver::Output as rkyv::Archive>::Archived: 'a
246        + rkyv::Deserialize<Driver::Output, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>
247        + rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, rkyv::rancor::Error>>,
248{
249    pub fn as_linkable(self: &Arc<Self>) -> Arc<dyn LinkableOperation> {
250        Arc::new(OperationLinker {
251            operation: Arc::clone(self),
252        })
253    }
254}
255
256fn mailbox_base(caller: &mut Caller<'_, InstanceRegistry>) -> Option<usize> {
257    caller
258        .get_export("memory")
259        .and_then(|export| export.into_memory())
260        .map(|memory| memory.data_ptr(&mut *caller) as usize)
261}