1use std::sync::Arc;
2
3use futures_core::Stream;
4use futures_util::{FutureExt, StreamExt};
5use parking_lot::Mutex;
6use wasmrs::{
7 BoxFlux, BoxMono, Frame, Handlers, IncomingMono, IncomingStream, Metadata, OperationHandler, OutgoingMono,
8 OutgoingStream, Payload, RSocket, RawPayload, WasmSocket,
9};
10use wasmrs_frames::PayloadError;
11use wasmrs_runtime::{spawn, ConditionallySend, UnboundedReceiver};
12use wasmrs_rx::*;
13
14use crate::context::{EngineProvider, SharedContext};
15
16type Result<T> = std::result::Result<T, crate::errors::Error>;
17
18#[must_use]
19#[allow(missing_debug_implementations)]
20pub struct Host {
22 engine: Box<dyn EngineProvider + Send + Sync>,
23 mtu: usize,
24 handlers: Arc<Mutex<Handlers>>,
25}
26
27impl Host {
28 pub async fn new<E: EngineProvider + Send + Sync + 'static>(engine: E) -> Result<Self> {
30 let host = Host {
31 engine: Box::new(engine),
32 mtu: 256,
33 handlers: Default::default(),
34 };
35
36 Ok(host)
37 }
38
39 pub async fn new_context(&self, host_buffer_size: u32, guest_buffer_size: u32) -> Result<CallContext> {
41 let mut socket = WasmSocket::new(
42 HostServer {
43 handlers: self.handlers.clone(),
44 },
45 wasmrs::SocketSide::Host,
46 );
47 let rx = socket.take_rx().unwrap();
48 let socket = Arc::new(socket);
49
50 let context = self.engine.new_context(socket.clone()).await?;
51
52 context.init(host_buffer_size, guest_buffer_size).await?;
53
54 CallContext::new(self.mtu, socket, context, rx)
55 }
56
57 pub fn register_request_response(
59 &self,
60 ns: impl AsRef<str>,
61 op: impl AsRef<str>,
62 handler: OperationHandler<IncomingMono, OutgoingMono>,
63 ) -> usize {
64 self.handlers.lock().register_request_response(ns, op, handler)
65 }
66
67 pub fn register_request_stream(
69 &self,
70 ns: impl AsRef<str>,
71 op: impl AsRef<str>,
72 handler: OperationHandler<IncomingMono, OutgoingStream>,
73 ) -> usize {
74 self.handlers.lock().register_request_stream(ns, op, handler)
75 }
76
77 pub fn register_request_channel(
79 &self,
80 ns: impl AsRef<str>,
81 op: impl AsRef<str>,
82 handler: OperationHandler<IncomingStream, OutgoingStream>,
83 ) -> usize {
84 self.handlers.lock().register_request_channel(ns, op, handler)
85 }
86
87 pub fn register_fire_and_forget(
89 &self,
90 ns: impl AsRef<str>,
91 op: impl AsRef<str>,
92 handler: OperationHandler<IncomingMono, ()>,
93 ) -> usize {
94 self.handlers.lock().register_fire_and_forget(ns, op, handler)
95 }
96}
97
98fn spawn_writer(mut rx: UnboundedReceiver<Frame>, context: SharedContext) -> tokio::task::JoinHandle<()> {
99 spawn("host:spawn_writer", async move {
100 while let Some(frame) = rx.recv().await {
101 let _ = context.write_frame(frame).await;
102 }
103 })
104}
105
106#[allow(missing_debug_implementations)]
107#[derive(Clone)]
108pub struct HostServer {
110 handlers: Arc<Mutex<Handlers>>,
111}
112
113fn parse_payload(req: RawPayload) -> Payload {
114 if let Some(mut md_bytes) = req.metadata {
115 let md = Metadata::decode(&mut md_bytes).unwrap();
116 Payload::new(md, req.data.unwrap())
117 } else {
118 panic!("No metadata found in payload.");
119 }
120}
121
122impl RSocket for HostServer {
123 fn fire_and_forget(&self, req: RawPayload) -> BoxMono<(), PayloadError> {
124 let payload = parse_payload(req);
125 let handler = self
126 .handlers
127 .lock()
128 .get_fnf_handler(payload.metadata.index.unwrap())
129 .unwrap();
130 handler(futures_util::future::ready(Ok(payload)).boxed()).unwrap();
131 futures_util::future::ready(Ok(())).boxed()
132 }
133
134 fn request_response(&self, req: RawPayload) -> BoxMono<RawPayload, PayloadError> {
135 let payload = parse_payload(req);
136 let handler = self
137 .handlers
138 .lock()
139 .get_request_response_handler(payload.metadata.index.unwrap())
140 .unwrap();
141
142 handler(futures_util::future::ready(Ok(payload)).boxed()).unwrap()
143 }
144
145 fn request_stream(&self, req: RawPayload) -> BoxFlux<RawPayload, PayloadError> {
146 let payload = parse_payload(req);
147 let handler = self
148 .handlers
149 .lock()
150 .get_request_stream_handler(payload.metadata.index.unwrap())
151 .unwrap();
152 handler(futures_util::future::ready(Ok(payload)).boxed()).unwrap()
153 }
154
155 fn request_channel<
156 T: Stream<Item = std::result::Result<RawPayload, PayloadError>> + ConditionallySend + Unpin + 'static,
157 >(
158 &self,
159 mut reqs: T,
160 ) -> BoxFlux<RawPayload, PayloadError> {
161 let (out_tx, out_rx) = FluxChannel::<RawPayload, PayloadError>::new_parts();
162 let handlers = self.handlers.clone();
163 tokio::spawn(async move {
164 let (inner_tx, inner_rx) = FluxChannel::new_parts();
165 let first = match reqs.next().await {
166 None => {
167 let _ = out_tx.send_result(Err(PayloadError::application_error("No first payload.", None)));
168 return;
169 }
170 Some(Err(e)) => {
171 let _ = out_tx.send_result(Err(e));
172 return;
173 }
174 Some(Ok(p)) => p,
175 };
176
177 let payload = parse_payload(first);
178 let handler = handlers
179 .lock()
180 .get_request_channel_handler(payload.metadata.index.unwrap())
181 .unwrap();
182 let _ = inner_tx.send(payload);
183 let mut out = handler(inner_rx.boxed()).unwrap();
184 tokio::spawn(async move {
185 while let Some(p) = out.next().await {
186 let _ = out_tx.send_result(p);
187 }
188 out_tx.complete();
189 });
190 tokio::spawn(async move {
191 while let Some(p) = reqs.next().await {
192 let _ = inner_tx.send_result(p.map(parse_payload));
193 }
194 inner_tx.complete();
195 });
196 });
197 out_rx.boxed()
198 }
199}
200
201pub struct CallContext {
203 socket: Arc<WasmSocket<HostServer>>,
204 context: SharedContext,
205 writer: tokio::task::JoinHandle<()>,
206}
207
208impl std::fmt::Debug for CallContext {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 f.debug_struct("WasmRsCallContext")
211 .field("state", &self.socket)
212 .finish()
213 }
214}
215
216impl CallContext {
217 fn new(
218 _mtu: usize,
219 socket: Arc<WasmSocket<HostServer>>,
220 context: SharedContext,
221 rx: UnboundedReceiver<Frame>,
222 ) -> Result<Self> {
223 let writer = spawn_writer(rx, context.clone());
224
225 Ok(Self {
226 socket,
227 context,
228 writer,
229 })
230 }
231
232 pub fn get_import(&self, namespace: &str, operation: &str) -> Option<u32> {
234 self.context.get_import(namespace, operation)
235 }
236
237 pub fn get_export(&self, namespace: &str, operation: &str) -> Option<u32> {
239 self.context.get_export(namespace, operation)
240 }
241
242 #[must_use]
244 pub fn get_exports(&self) -> Vec<String> {
245 self.context.get_operation_list().get_exports()
246 }
247
248 pub fn dump_operations(&self) {
250 println!("{:#?}", self.context.get_operation_list());
251 }
252
253 pub fn is_alive(&self) -> bool {
255 !self.writer.is_finished()
256 }
257}
258
259impl RSocket for CallContext {
260 fn fire_and_forget(&self, payload: RawPayload) -> BoxMono<(), PayloadError> {
261 self.socket.fire_and_forget(payload)
262 }
263
264 fn request_response(&self, payload: RawPayload) -> BoxMono<RawPayload, PayloadError> {
265 self.socket.request_response(payload)
266 }
267
268 fn request_stream(&self, payload: RawPayload) -> BoxFlux<RawPayload, PayloadError> {
269 self.socket.request_stream(payload)
270 }
271
272 fn request_channel<
273 T: Stream<Item = std::result::Result<RawPayload, PayloadError>> + ConditionallySend + Unpin + 'static,
274 >(
275 &self,
276 stream: T,
277 ) -> BoxFlux<RawPayload, PayloadError> {
278 self.socket.request_channel(stream)
279 }
280}