1use std::{
2 convert::TryInto,
3 future::Future,
4 io::{Read, Write},
5};
6
7use anyhow::{anyhow, Result};
8use lunatic_common_api::{get_memory, IntoTrap};
9use lunatic_networking_api::NetworkingCtx;
10use lunatic_process_api::ProcessCtx;
11use tokio::time::{timeout, Duration};
12use wasmtime::{Caller, Linker};
13
14use lunatic_process::{
15 message::{DataMessage, Message},
16 state::ProcessState,
17 Signal,
18};
19
20pub fn register<T: ProcessState + ProcessCtx<T> + NetworkingCtx + Send + 'static>(
22 linker: &mut Linker<T>,
23) -> Result<()> {
24 linker.func_wrap("lunatic::message", "create_data", create_data)?;
25 linker.func_wrap("lunatic::message", "write_data", write_data)?;
26 linker.func_wrap("lunatic::message", "read_data", read_data)?;
27 linker.func_wrap("lunatic::message", "seek_data", seek_data)?;
28 linker.func_wrap("lunatic::message", "get_tag", get_tag)?;
29 linker.func_wrap("lunatic::message", "get_process_id", get_process_id)?;
30 linker.func_wrap("lunatic::message", "data_size", data_size)?;
31 linker.func_wrap("lunatic::message", "push_module", push_module)?;
32 linker.func_wrap("lunatic::message", "take_module", take_module)?;
33 linker.func_wrap("lunatic::message", "push_tcp_stream", push_tcp_stream)?;
34 linker.func_wrap("lunatic::message", "take_tcp_stream", take_tcp_stream)?;
35 linker.func_wrap("lunatic::message", "push_tls_stream", push_tls_stream)?;
36 linker.func_wrap("lunatic::message", "take_tls_stream", take_tls_stream)?;
37 linker.func_wrap("lunatic::message", "send", send)?;
38 linker.func_wrap3_async(
39 "lunatic::message",
40 "send_receive_skip_search",
41 send_receive_skip_search,
42 )?;
43 linker.func_wrap3_async("lunatic::message", "receive", receive)?;
44 linker.func_wrap("lunatic::message", "push_udp_socket", push_udp_socket)?;
45 linker.func_wrap("lunatic::message", "take_udp_socket", take_udp_socket)?;
46
47 Ok(())
48}
49
50fn create_data<T: ProcessState + ProcessCtx<T>>(
127 mut caller: Caller<T>,
128 tag: i64,
129 buffer_capacity: u64,
130) {
131 let tag = match tag {
132 0 => None,
133 tag => Some(tag),
134 };
135 let message = DataMessage::new(tag, buffer_capacity as usize);
136 caller
137 .data_mut()
138 .message_scratch_area()
139 .replace(Message::Data(message));
140}
141
142fn write_data<T: ProcessState + ProcessCtx<T>>(
148 mut caller: Caller<T>,
149 data_ptr: u32,
150 data_len: u32,
151) -> Result<u32> {
152 let memory = get_memory(&mut caller)?;
153 let mut message = caller
154 .data_mut()
155 .message_scratch_area()
156 .take()
157 .or_trap("lunatic::message::write_data")?;
158 let buffer = memory
159 .data(&caller)
160 .get(data_ptr as usize..(data_ptr as usize + data_len as usize))
161 .or_trap("lunatic::message::write_data")?;
162 let bytes = match &mut message {
163 Message::Data(data) => data.write(buffer).or_trap("lunatic::message::write_data")?,
164 Message::LinkDied(_) => {
165 return Err(anyhow!("Unexpected `Message::LinkDied` in scratch area"))
166 }
167 Message::ProcessDied(_) => {
168 return Err(anyhow!("Unexpected `Message::ProcessDied` in scratch area"))
169 }
170 };
171 caller.data_mut().message_scratch_area().replace(message);
173
174 Ok(bytes as u32)
175}
176
177fn read_data<T: ProcessState + ProcessCtx<T>>(
183 mut caller: Caller<T>,
184 data_ptr: u32,
185 data_len: u32,
186) -> Result<u32> {
187 let memory = get_memory(&mut caller)?;
188 let mut message = caller
189 .data_mut()
190 .message_scratch_area()
191 .take()
192 .or_trap("lunatic::message::read_data")?;
193 let buffer = memory
194 .data_mut(&mut caller)
195 .get_mut(data_ptr as usize..(data_ptr as usize + data_len as usize))
196 .or_trap("lunatic::message::read_data")?;
197 let bytes = match &mut message {
198 Message::Data(data) => data.read(buffer).or_trap("lunatic::message::read_data")?,
199 Message::LinkDied(_) => {
200 return Err(anyhow!("Unexpected `Message::LinkDied` in scratch area"))
201 }
202 Message::ProcessDied(_) => {
203 return Err(anyhow!("Unexpected `Message::ProcessDied` in scratch area"))
204 }
205 };
206 caller.data_mut().message_scratch_area().replace(message);
208
209 Ok(bytes as u32)
210}
211
212fn seek_data<T: ProcessState + ProcessCtx<T>>(mut caller: Caller<T>, index: u64) -> Result<()> {
219 let mut message = caller
220 .data_mut()
221 .message_scratch_area()
222 .as_mut()
223 .or_trap("lunatic::message::seek_data")?;
224 match &mut message {
225 Message::Data(data) => data.seek(index as usize),
226 Message::LinkDied(_) => {
227 return Err(anyhow!("Unexpected `Message::LinkDied` in scratch area"))
228 }
229 Message::ProcessDied(_) => {
230 return Err(anyhow!("Unexpected `Message::ProcessDied` in scratch area"))
231 }
232 };
233 Ok(())
234}
235
236fn get_tag<T: ProcessState + ProcessCtx<T>>(mut caller: Caller<T>) -> Result<i64> {
241 let message = caller
242 .data_mut()
243 .message_scratch_area()
244 .as_ref()
245 .or_trap("lunatic::message::get_tag")?;
246 Ok(message.tag().unwrap_or(0))
247}
248
249fn get_process_id<T: ProcessState + ProcessCtx<T>>(mut caller: Caller<T>) -> Result<u64> {
254 let message = caller
255 .data_mut()
256 .message_scratch_area()
257 .as_ref()
258 .or_trap("lunatic::message::get_process_id")?;
259 Ok(message.process_id().unwrap_or(0))
260}
261
262fn data_size<T: ProcessState + ProcessCtx<T>>(mut caller: Caller<T>) -> Result<u64> {
267 let message = caller
268 .data_mut()
269 .message_scratch_area()
270 .as_ref()
271 .or_trap("lunatic::message::data_size")?;
272 let bytes = match message {
273 Message::Data(data) => data.size(),
274 Message::LinkDied(_) => {
275 return Err(anyhow!("Unexpected `Message::LinkDied` in scratch area"))
276 }
277 Message::ProcessDied(_) => {
278 return Err(anyhow!("Unexpected `Message::ProcessDied` in scratch area"))
279 }
280 };
281
282 Ok(bytes as u64)
283}
284
285fn push_module<T: ProcessState + ProcessCtx<T> + NetworkingCtx + 'static>(
292 mut caller: Caller<T>,
293 module_id: u64,
294) -> Result<u64> {
295 let module = caller
296 .data()
297 .module_resources()
298 .get(module_id)
299 .or_trap("lunatic::message::push_module")?
300 .clone();
301 let message = caller
302 .data_mut()
303 .message_scratch_area()
304 .as_mut()
305 .or_trap("lunatic::message::push_module")?;
306 let index = match message {
307 Message::Data(data) => data.add_resource(module) as u64,
308 Message::LinkDied(_) => {
309 return Err(anyhow!("Unexpected `Message::LinkDied` in scratch area"))
310 }
311 Message::ProcessDied(_) => {
312 return Err(anyhow!("Unexpected `Message::ProcessDied` in scratch area"))
313 }
314 };
315 Ok(index)
316}
317
318fn take_module<T: ProcessState + ProcessCtx<T> + NetworkingCtx + 'static>(
325 mut caller: Caller<T>,
326 index: u64,
327) -> Result<u64> {
328 let message = caller
329 .data_mut()
330 .message_scratch_area()
331 .as_mut()
332 .or_trap("lunatic::message::take_module")?;
333 let module = match message {
334 Message::Data(data) => data
335 .take_module(index as usize)
336 .or_trap("lunatic::message::take_module")?,
337 Message::LinkDied(_) => {
338 return Err(anyhow!("Unexpected `Message::LinkDied` in scratch area"))
339 }
340 Message::ProcessDied(_) => {
341 return Err(anyhow!("Unexpected `Message::ProcessDied` in scratch area"))
342 }
343 };
344 Ok(caller.data_mut().module_resources_mut().add(module))
345}
346
347fn push_tcp_stream<T: ProcessState + ProcessCtx<T> + NetworkingCtx>(
354 mut caller: Caller<T>,
355 stream_id: u64,
356) -> Result<u64> {
357 let stream = caller
358 .data_mut()
359 .tcp_stream_resources_mut()
360 .remove(stream_id)
361 .or_trap("lunatic::message::push_tcp_stream")?;
362 let message = caller
363 .data_mut()
364 .message_scratch_area()
365 .as_mut()
366 .or_trap("lunatic::message::push_tcp_stream")?;
367 let index = match message {
368 Message::Data(data) => data.add_resource(stream) as u64,
369 Message::LinkDied(_) => {
370 return Err(anyhow!("Unexpected `Message::LinkDied` in scratch area"))
371 }
372 Message::ProcessDied(_) => {
373 return Err(anyhow!("Unexpected `Message::ProcessDied` in scratch area"))
374 }
375 };
376 Ok(index)
377}
378
379fn take_tcp_stream<T: ProcessState + ProcessCtx<T> + NetworkingCtx>(
386 mut caller: Caller<T>,
387 index: u64,
388) -> Result<u64> {
389 let message = caller
390 .data_mut()
391 .message_scratch_area()
392 .as_mut()
393 .or_trap("lunatic::message::take_tcp_stream")?;
394 let tcp_stream = match message {
395 Message::Data(data) => data
396 .take_tcp_stream(index as usize)
397 .or_trap("lunatic::message::take_tcp_stream")?,
398 Message::LinkDied(_) => {
399 return Err(anyhow!("Unexpected `Message::LinkDied` in scratch area"))
400 }
401 Message::ProcessDied(_) => {
402 return Err(anyhow!("Unexpected `Message::ProcessDied` in scratch area"))
403 }
404 };
405 Ok(caller.data_mut().tcp_stream_resources_mut().add(tcp_stream))
406}
407
408fn push_tls_stream<T: ProcessState + ProcessCtx<T> + NetworkingCtx>(
417 mut caller: Caller<T>,
418 stream_id: u64,
419) -> Result<u64> {
420 let resources = caller.data_mut().tls_stream_resources_mut();
421 let stream = resources
422 .remove(stream_id)
423 .or_trap("lunatic::message::push_tls_stream")?;
424 let message = caller
425 .data_mut()
426 .message_scratch_area()
427 .as_mut()
428 .or_trap("lunatic::message::push_tls_stream")?;
429 let index = match message {
430 Message::Data(data) => data.add_resource(stream) as u64,
431 Message::LinkDied(_) => {
432 return Err(anyhow!("Unexpected `Message::LinkDied` in scratch area"))
433 }
434 Message::ProcessDied(_) => {
435 return Err(anyhow!("Unexpected `Message::ProcessDied` in scratch area"))
436 }
437 };
438 Ok(index)
439}
440
441fn take_tls_stream<T: ProcessState + ProcessCtx<T> + NetworkingCtx>(
448 mut caller: Caller<T>,
449 index: u64,
450) -> Result<u64> {
451 let message = caller
452 .data_mut()
453 .message_scratch_area()
454 .as_mut()
455 .or_trap("lunatic::message::take_tls_stream")?;
456 let tls_stream = match message {
457 Message::Data(data) => data
458 .take_tls_stream(index as usize)
459 .or_trap("lunatic::message::take_tls_stream")?,
460 Message::LinkDied(_) => {
461 return Err(anyhow!("Unexpected `Message::LinkDied` in scratch area"))
462 }
463 Message::ProcessDied(_) => {
464 return Err(anyhow!("Unexpected `Message::ProcessDied` in scratch area"))
465 }
466 };
467 Ok(caller.data_mut().tls_stream_resources_mut().add(tls_stream))
468}
469
470fn send<T: ProcessState + ProcessCtx<T>>(mut caller: Caller<T>, process_id: u64) -> Result<u32> {
478 let message = caller
479 .data_mut()
480 .message_scratch_area()
481 .take()
482 .or_trap("lunatic::message::send::no_message")?;
483
484 if let Some(process) = caller.data_mut().environment().get_process(process_id) {
485 process.send(Signal::Message(message));
486 }
487
488 Ok(0)
489}
490
491fn send_receive_skip_search<T: ProcessState + ProcessCtx<T> + Send>(
510 mut caller: Caller<T>,
511 process_id: u64,
512 wait_on_tag: i64,
513 timeout_duration: u64,
514) -> Box<dyn Future<Output = Result<u32>> + Send + '_> {
515 Box::new(async move {
516 let message = caller
517 .data_mut()
518 .message_scratch_area()
519 .take()
520 .or_trap("lunatic::message::send_receive_skip_search")?;
521
522 if let Some(process) = caller.data_mut().environment().get_process(process_id) {
523 process.send(Signal::Message(message));
524 }
525
526 let tags = [wait_on_tag];
527 let pop_skip_search_tag = caller.data_mut().mailbox().pop_skip_search(Some(&tags));
528 if let Ok(message) = match timeout_duration {
529 u64::MAX => Ok(pop_skip_search_tag.await),
531 t => timeout(Duration::from_millis(t), pop_skip_search_tag).await,
533 } {
534 caller.data_mut().message_scratch_area().replace(message);
536 Ok(0)
537 } else {
538 Ok(9027)
539 }
540 })
541}
542
543fn receive<T: ProcessState + ProcessCtx<T> + Send>(
565 mut caller: Caller<T>,
566 tag_ptr: u32,
567 tag_len: u32,
568 timeout_duration: u64,
569) -> Box<dyn Future<Output = Result<u32>> + Send + '_> {
570 Box::new(async move {
571 let tags = if tag_len > 0 {
572 let memory = get_memory(&mut caller)?;
573 let buffer = memory
574 .data(&caller)
575 .get(tag_ptr as usize..(tag_ptr + tag_len * 8) as usize)
576 .or_trap("lunatic::message::receive")?;
577
578 let tags: Vec<i64> = buffer
580 .chunks_exact(8)
581 .map(|chunk| i64::from_le_bytes(chunk.try_into().expect("works")))
582 .collect();
583 Some(tags)
584 } else {
585 None
586 };
587
588 let pop = caller.data_mut().mailbox().pop(tags.as_deref());
589 if let Ok(message) = match timeout_duration {
590 u64::MAX => Ok(pop.await),
592 t => timeout(Duration::from_millis(t), pop).await,
594 } {
595 let result = match message {
596 Message::Data(_) => 0,
597 Message::LinkDied(_) => 1,
598 Message::ProcessDied(_) => 2,
599 };
600 caller.data_mut().message_scratch_area().replace(message);
602 Ok(result)
603 } else {
604 Ok(9027)
605 }
606 })
607}
608
609fn push_udp_socket<T: ProcessState + ProcessCtx<T> + NetworkingCtx>(
616 mut caller: Caller<T>,
617 socket_id: u64,
618) -> Result<u64> {
619 let data = caller.data_mut();
620 let socket = data
621 .udp_resources_mut()
622 .remove(socket_id)
623 .or_trap("lunatic::message::push_udp_socket")?;
624 let message = data
625 .message_scratch_area()
626 .as_mut()
627 .or_trap("lunatic::message::push_udp_socket")?;
628 let index = match message {
629 Message::Data(data) => data.add_resource(socket) as u64,
630 Message::LinkDied(_) => {
631 return Err(anyhow!("Unexpected `Message::LinkDied` in scratch area"))
632 }
633 Message::ProcessDied(_) => {
634 return Err(anyhow!("Unexpected `Message::ProcessDied` in scratch area"))
635 }
636 };
637 Ok(index)
638}
639
640fn take_udp_socket<T: ProcessState + ProcessCtx<T> + NetworkingCtx>(
647 mut caller: Caller<T>,
648 index: u64,
649) -> Result<u64> {
650 let message = caller
651 .data_mut()
652 .message_scratch_area()
653 .as_mut()
654 .or_trap("lunatic::message::take_udp_socket")?;
655 let udp_socket = match message {
656 Message::Data(data) => data
657 .take_udp_socket(index as usize)
658 .or_trap("lunatic::message::take_udp_socket")?,
659 Message::LinkDied(_) => {
660 return Err(anyhow!("Unexpected `Message::LinkDied` in scratch area"))
661 }
662 Message::ProcessDied(_) => {
663 return Err(anyhow!("Unexpected `Message::ProcessDied` in scratch area"))
664 }
665 };
666 Ok(caller.data_mut().udp_resources_mut().add(udp_socket))
667}