1use std::str;
2
3use thiserror::Error;
4use wasmtime::{AsContext, Caller};
5
6use crate::{
7 KernelError,
8 drivers::Capability,
9 registry::{InstanceRegistry, RegistryError},
10};
11use selium_abi::{
12 DRIVER_ERROR_MESSAGE_CODE, DRIVER_RESULT_PENDING, RkyvEncode, WORD_SIZE, decode_rkyv,
13 driver_encode_error, driver_encode_ready, encode_driver_error_message, encode_rkyv,
14};
15pub use selium_abi::{GuestInt, GuestUint};
16
/// Result type used by guest-facing helpers in this module.
///
/// The error type defaults to [`GuestError`] but can be overridden via `E`.
pub type GuestResult<T, E = GuestError> = Result<T, E>;
18
/// Errors surfaced to a guest by host helpers.
///
/// When reported through `GuestError::encode_for_guest`, every variant except
/// `WouldBlock` has its `Display` text (the `#[error(...)]` strings below)
/// encoded and written into the guest's buffer; `WouldBlock` is instead
/// reported as `DRIVER_RESULT_PENDING`.
#[derive(Error, Debug)]
pub enum GuestError {
    /// A guest-supplied argument was rejected, e.g. a pointer/length that does
    /// not fit in `usize` or a byte that is not a valid `Capability`.
    #[error("invalid argument")]
    InvalidArgument,
    /// Guest bytes that were expected to be text are not valid UTF-8.
    #[error("Invalid UTF-8 in guest input")]
    InvalidUtf8,
    /// A requested range lies outside the available guest memory / buffer.
    #[error("Invalid guest memory slice")]
    MemorySlice,
    /// The requested resource does not exist (not produced by helpers in this
    /// module — presumably raised by callers; verify).
    #[error("resource not found")]
    NotFound,
    /// The guest lacks permission for the operation (not produced in this
    /// module — presumably raised by callers; verify).
    #[error("permission denied")]
    PermissionDenied,

    /// Wraps a [`KernelError`]; convertible via `From`/`?`.
    #[error("The kernel encountered an error. Please report this to your administrator.")]
    Kernel(#[from] KernelError),
    /// Wraps a [`RegistryError`]; convertible via `From`/`?`.
    #[error("The kernel Registry encountered an error. Please report this to your administrator.")]
    Registry(#[from] RegistryError),
    /// A stable identifier is already taken (not produced in this module).
    #[error("Stable identifier already exists")]
    StableIdExists,
    /// Free-form failure from another subsystem; the string is embedded in the
    /// message shown to the guest.
    #[error("internal error: {0}")]
    Subsystem(String),
    /// The operation cannot complete yet; reported to the guest as
    /// `DRIVER_RESULT_PENDING` rather than as an error message.
    #[error("This function would block")]
    WouldBlock,
}
45
46impl GuestError {
47 fn encode_for_guest(
48 self,
49 caller: &mut Caller<'_, InstanceRegistry>,
50 ptr: GuestInt,
51 len: GuestUint,
52 ) -> Result<GuestUint, KernelError> {
53 if matches!(self, GuestError::WouldBlock) {
54 return Ok(DRIVER_RESULT_PENDING);
55 }
56
57 let bytes = encode_driver_error_message(&self.to_string())
58 .map_err(|err| KernelError::Driver(err.to_string()))?;
59 write_encoded(caller, ptr, len, &bytes)?;
60 Ok(driver_encode_error(DRIVER_ERROR_MESSAGE_CODE))
61 }
62}
63
64pub fn write_poll_result(
65 caller: &mut Caller<'_, InstanceRegistry>,
66 ptr: GuestInt,
67 len: GuestUint,
68 result: GuestResult<Vec<u8>>,
69) -> Result<GuestUint, KernelError> {
70 match result {
71 Ok(bytes) => write_encoded(caller, ptr, len, &bytes),
72 Err(err) => err.encode_for_guest(caller, ptr, len),
73 }
74}
75
76pub fn write_rkyv_value<T>(
77 caller: &mut Caller<'_, InstanceRegistry>,
78 ptr: GuestInt,
79 len: GuestUint,
80 value: T,
81) -> Result<GuestUint, KernelError>
82where
83 T: RkyvEncode,
84{
85 let bytes = encode_value(&value)?;
86 write_encoded(caller, ptr, len, &bytes)
87}
88
89pub fn read_rkyv_value<T>(
90 caller: &mut Caller<'_, InstanceRegistry>,
91 ptr: GuestInt,
92 len: GuestUint,
93) -> Result<T, KernelError>
94where
95 T: rkyv::Archive + Sized,
96 for<'a> T::Archived: 'a
97 + rkyv::Deserialize<T, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>
98 + rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, rkyv::rancor::Error>>,
99{
100 let bytes = read_guest_bytes(caller, ptr, len)?;
101 decode_value(&bytes)
102}
103
104fn encode_value<T>(value: &T) -> Result<Vec<u8>, KernelError>
105where
106 T: RkyvEncode,
107{
108 encode_rkyv(value).map_err(|err| KernelError::Driver(err.to_string()))
109}
110
111fn decode_value<T>(bytes: &[u8]) -> Result<T, KernelError>
112where
113 T: rkyv::Archive + Sized,
114 for<'a> T::Archived: 'a
115 + rkyv::Deserialize<T, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>
116 + rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'a, rkyv::rancor::Error>>,
117{
118 decode_rkyv(bytes).map_err(|err| KernelError::Driver(err.to_string()))
119}
120
121fn read_guest_bytes(
122 caller: &mut Caller<'_, InstanceRegistry>,
123 ptr: GuestInt,
124 len: GuestUint,
125) -> Result<Vec<u8>, KernelError> {
126 let memory = caller
127 .get_export("memory")
128 .and_then(|export| export.into_memory())
129 .ok_or(KernelError::MemoryMissing)?;
130
131 let start = usize::try_from(ptr).map_err(KernelError::IntConvert)?;
132 let len = usize::try_from(len).map_err(KernelError::IntConvert)?;
133 let end = start.checked_add(len).ok_or(KernelError::MemoryCapacity)?;
134
135 let ctx = caller.as_context();
136 let data = memory
137 .data(&ctx)
138 .get(start..end)
139 .ok_or(KernelError::MemoryCapacity)?;
140 Ok(data.to_vec())
141}
142
143fn write_encoded(
144 caller: &mut Caller<'_, InstanceRegistry>,
145 ptr: GuestInt,
146 len: GuestUint,
147 bytes: &[u8],
148) -> Result<GuestUint, KernelError> {
149 let memory = caller
150 .get_export("memory")
151 .and_then(|export| export.into_memory())
152 .ok_or(KernelError::MemoryMissing)?;
153 let capacity = usize::try_from(len).map_err(KernelError::IntConvert)?;
154 if capacity < bytes.len() {
155 return Err(KernelError::MemoryCapacity);
156 }
157
158 let offset = usize::try_from(ptr).map_err(KernelError::IntConvert)?;
159 memory.write(caller, offset, bytes)?;
160
161 encode_ready_len(bytes.len())
162}
163
164pub fn read_u32(data: &[u8], index: usize) -> GuestResult<u32> {
165 let offset = index * WORD_SIZE;
166 let bytes = data
167 .get(offset..offset + WORD_SIZE)
168 .ok_or(GuestError::MemorySlice)?;
169 Ok(u32::from_le_bytes(
170 bytes.try_into().map_err(|_| GuestError::MemorySlice)?,
171 ))
172}
173
174pub fn read_i32(data: &[u8], index: usize) -> GuestResult<i32> {
175 let offset = index * WORD_SIZE;
176 let bytes = data
177 .get(offset..offset + WORD_SIZE)
178 .ok_or(GuestError::MemorySlice)?;
179 Ok(i32::from_le_bytes(
180 bytes.try_into().map_err(|_| GuestError::MemorySlice)?,
181 ))
182}
183
184pub fn read_utf8(
185 memory: &wasmtime::Memory,
186 ctx: &impl AsContext<Data = InstanceRegistry>,
187 ptr: GuestInt,
188 len: GuestUint,
189) -> GuestResult<String> {
190 let ptr = usize::try_from(ptr).map_err(|_| GuestError::InvalidArgument)?;
191 let len = usize::try_from(len).map_err(|_| GuestError::InvalidArgument)?;
192 let data = memory
193 .data(ctx)
194 .get(ptr..ptr + len)
195 .ok_or(GuestError::MemorySlice)?;
196 let s = str::from_utf8(data).map_err(|_| GuestError::InvalidUtf8)?;
197 Ok(s.to_string())
198}
199
200pub fn read_capabilities(
201 memory: &wasmtime::Memory,
202 ctx: &impl AsContext<Data = InstanceRegistry>,
203 ptr: GuestInt,
204 count: GuestUint,
205) -> GuestResult<Vec<Capability>> {
206 let ptr = usize::try_from(ptr).map_err(|_| GuestError::InvalidArgument)?;
207 let count = usize::try_from(count).map_err(|_| GuestError::InvalidArgument)?;
208 let data = memory
209 .data(ctx)
210 .get(ptr..ptr + count)
211 .ok_or(GuestError::MemorySlice)?;
212 let mut caps = Vec::with_capacity(count);
213 for byte in data {
214 caps.push(Capability::try_from(*byte).map_err(|_| GuestError::InvalidArgument)?);
215 }
216 Ok(caps)
217}
218pub fn encode_ready_len(len: usize) -> Result<GuestUint, KernelError> {
219 let guest_len = GuestUint::try_from(len).map_err(|_| KernelError::MemoryCapacity)?;
220 driver_encode_ready(guest_len).ok_or(KernelError::MemoryCapacity)
221}