1use std::{future::Future, sync::Arc, time::Duration};
2
3use anyhow::{anyhow, Result};
4use lunatic_common_api::{get_memory, write_to_guest_vec, IntoTrap};
5use lunatic_distributed::{
6 distributed::message::{ClientError, Spawn, Val},
7 DistributedCtx,
8};
9use lunatic_error_api::ErrorCtx;
10use lunatic_process::{
11 env::Environment,
12 message::{DataMessage, Message},
13};
14use lunatic_process_api::ProcessCtx;
15use rcgen::{Certificate, CertificateParams, CertificateSigningRequest, KeyPair};
16use tokio::time::timeout;
17use wasmtime::{Caller, Linker, ResourceLimiter};
18
19pub fn register<T, E>(linker: &mut Linker<T>) -> Result<()>
21where
22 T: DistributedCtx<E> + ProcessCtx<T> + Send + ResourceLimiter + ErrorCtx + 'static,
23 E: Environment + 'static,
24 for<'a> &'a T: Send,
25{
26 linker.func_wrap("lunatic::distributed", "nodes_count", nodes_count)?;
27 linker.func_wrap("lunatic::distributed", "get_nodes", get_nodes)?;
28 linker.func_wrap("lunatic::distributed", "node_id", node_id)?;
29 linker.func_wrap("lunatic::distributed", "module_id", module_id)?;
30 linker.func_wrap8_async("lunatic::distributed", "spawn", spawn)?;
31 linker.func_wrap2_async("lunatic::distributed", "send", send)?;
32 linker.func_wrap4_async(
33 "lunatic::distributed",
34 "send_receive_skip_search",
35 send_receive_skip_search,
36 )?;
37 linker.func_wrap5_async(
38 "lunatic::distributed",
39 "exec_lookup_nodes",
40 exec_lookup_nodes,
41 )?;
42 linker.func_wrap(
43 "lunatic::distributed",
44 "copy_lookup_nodes_results",
45 copy_lookup_nodes_results,
46 )?;
47 linker.func_wrap1_async("lunatic::distributed", "test_root_cert", test_root_cert)?;
48 linker.func_wrap5_async(
49 "lunatic::distributed",
50 "default_server_certificates",
51 default_server_certificates,
52 )?;
53 linker.func_wrap7_async("lunatic::distributed", "sign_node", sign_node)?;
54 Ok(())
55}
56
57fn nodes_count<T, E>(caller: Caller<T>) -> u32
59where
60 T: DistributedCtx<E>,
61 E: Environment,
62{
63 caller
64 .data()
65 .distributed()
66 .map(|d| d.control.node_count())
67 .unwrap_or(0) as u32
68}
69
70fn get_nodes<T, E>(mut caller: Caller<T>, nodes_ptr: u32, nodes_len: u32) -> Result<u32>
75where
76 T: DistributedCtx<E>,
77 E: Environment,
78{
79 let memory = get_memory(&mut caller)?;
80 let node_ids = caller
81 .data()
82 .distributed()
83 .map(|d| d.control.node_ids())
84 .unwrap_or_else(|_| vec![]);
85 let copy_nodes_len = node_ids.len().min(nodes_len as usize);
86 memory
87 .data_mut(&mut caller)
88 .get_mut(
89 nodes_ptr as usize..(nodes_ptr as usize + std::mem::size_of::<u64>() * copy_nodes_len),
90 )
91 .or_trap("lunatic::distributed::get_nodes::memory")?
92 .copy_from_slice(unsafe { node_ids[..copy_nodes_len].align_to::<u8>().1 });
93 Ok(copy_nodes_len as u32)
94}
95
96fn exec_lookup_nodes<T, E>(
105 mut caller: Caller<T>,
106 query_ptr: u32,
107 query_len: u32,
108 query_id_ptr: u32,
109 nodes_len_ptr: u32,
110 error_ptr: u32,
111) -> Box<dyn Future<Output = Result<u32>> + Send + '_>
112where
113 T: DistributedCtx<E> + ErrorCtx + Send + 'static,
114 E: Environment + 'static,
115 for<'a> &'a T: Send,
116{
117 Box::new(async move {
118 let memory = get_memory(&mut caller)?;
119 let query_str = memory
120 .data(&caller)
121 .get(query_ptr as usize..(query_ptr + query_len) as usize)
122 .or_trap("lunatic::distributed::lookup_nodes::query_ptr")?;
123 let query = std::str::from_utf8(query_str)
124 .or_trap("lunatic::distributed::lookup_nodes::query_str_utf8")?;
125 let distributed = caller.data().distributed()?;
126 match distributed.control.lookup_nodes(query).await {
127 Ok((query_id, nodes_len)) => {
128 memory
129 .write(&mut caller, query_id_ptr as usize, &query_id.to_le_bytes())
130 .or_trap("lunatic::distributed::lookup_nodes::query_id")?;
131 memory
132 .write(
133 &mut caller,
134 nodes_len_ptr as usize,
135 &nodes_len.to_le_bytes(),
136 )
137 .or_trap("lunatic::distributed::lookup_nodes::nodes_len")?;
138 Ok(0)
139 }
140 Err(error) => {
141 let error_id = caller.data_mut().error_resources_mut().add(error);
142 memory
143 .write(&mut caller, error_ptr as usize, &error_id.to_le_bytes())
144 .or_trap("lunatic::distributed::lookup_nodes::error_ptr")?;
145 Ok(1)
146 }
147 }
148 })
149}
150
151fn copy_lookup_nodes_results<T, E>(
156 mut caller: Caller<T>,
157 query_id: u64,
158 nodes_ptr: u32,
159 nodes_len: u32,
160 error_ptr: u32,
161) -> Result<i32>
162where
163 T: DistributedCtx<E> + ErrorCtx,
164 E: Environment,
165{
166 let memory = get_memory(&mut caller)?;
167 if let Some(query_results) = caller
168 .data()
169 .distributed()
170 .map(|d| d.control.query_result(&query_id))?
171 {
172 let nodes = query_results.1;
173 let copy_nodes_len = nodes.len().min(nodes_len as usize);
174 let memory = get_memory(&mut caller)?;
175 memory
176 .data_mut(&mut caller)
177 .get_mut(
178 nodes_ptr as usize
179 ..(nodes_ptr as usize + std::mem::size_of::<u64>() * copy_nodes_len),
180 )
181 .or_trap("lunatic::distributed::copy_lookup_nodes_results::memory")?
182 .copy_from_slice(unsafe { nodes[..copy_nodes_len].align_to::<u8>().1 });
183 Ok(copy_nodes_len as i32)
184 } else {
185 let error = anyhow!("Invalid query id");
186 let error_id = caller.data_mut().error_resources_mut().add(error);
187 memory
188 .write(&mut caller, error_ptr as usize, &error_id.to_le_bytes())
189 .or_trap("lunatic::distributed::copy_lookup_nodes_results::error_ptr")?;
190 Ok(-1)
191 }
192}
193
194fn test_root_cert<T, E>(
195 mut caller: Caller<T>,
196 len_ptr: u32,
197) -> Box<dyn Future<Output = Result<u32>> + Send + '_>
198where
199 T: DistributedCtx<E> + Send,
200 E: Environment,
201{
202 Box::new(async move {
203 let memory = get_memory(&mut caller)?;
204 let root_cert = lunatic_distributed::control::cert::test_root_cert()
205 .or_trap("lunatic::distributed::test_root_cert")?;
206
207 let cert_pem = root_cert
208 .serialize_pem()
209 .or_trap("lunatic::distributed::test_root_cert")?;
210 let key_pair_pem = root_cert.serialize_private_key_pem();
211
212 let data = bincode::serialize(&(cert_pem, key_pair_pem))
213 .or_trap("lunatic::distributed::test_root_cert")?;
214 let ptr = write_to_guest_vec(&mut caller, &memory, &data, len_ptr)
215 .await
216 .or_trap("lunatic::distributed::test_root_cert")?;
217
218 Ok(ptr)
219 })
220}
221
222fn default_server_certificates<T, E>(
223 mut caller: Caller<T>,
224 cert_pem_ptr: u32,
225 cert_pem_len: u32,
226 pk_pem_ptr: u32,
227 pk_pem_len: u32,
228 len_ptr: u32,
229) -> Box<dyn Future<Output = Result<u32>> + Send + '_>
230where
231 T: DistributedCtx<E> + Send,
232 E: Environment,
233{
234 Box::new(async move {
235 let memory = get_memory(&mut caller)?;
236
237 let cert_pem_bytes = memory
238 .data(&caller)
239 .get(cert_pem_ptr as usize..(cert_pem_ptr + cert_pem_len) as usize)
240 .or_trap("lunatic::distributed::spawn::default_server_certificates")?;
241 let cert_pem = std::str::from_utf8(cert_pem_bytes)
242 .or_trap("lunatic::distributed::default_server_certificates")?;
243
244 let pk_pem_bytes = memory
245 .data(&caller)
246 .get(pk_pem_ptr as usize..(pk_pem_ptr + pk_pem_len) as usize)
247 .or_trap("lunatic::distributed::default_server_certificates")?;
248 let pk_pem = std::str::from_utf8(pk_pem_bytes)
249 .or_trap("lunatic::distributed::default_server_certificates")?;
250
251 let key_pair = KeyPair::from_pem(pk_pem)
252 .or_trap("lunatic::distributed::default_server_certificates")?;
253 let cert_params = CertificateParams::from_ca_cert_pem(cert_pem, key_pair)
254 .or_trap("lunatic::distributed::default_server_certificates")?;
255
256 let root_cert = Certificate::from_params(cert_params)
257 .or_trap("lunatic::distributed::default_server_certificates")?;
258
259 let (ctrl_cert, ctrl_pk) =
260 lunatic_distributed::control::cert::default_server_certificates(&root_cert)?;
261
262 let data = bincode::serialize(&(ctrl_cert, ctrl_pk))
263 .or_trap("lunatic::distributed::default_server_certificates")?;
264 let ptr = write_to_guest_vec(&mut caller, &memory, &data, len_ptr)
265 .await
266 .or_trap("lunatic::distributed::default_server_certificates")?;
267
268 Ok(ptr)
269 })
270}
271
272#[allow(clippy::too_many_arguments)]
273fn sign_node<T, E>(
274 mut caller: Caller<T>,
275 cert_pem_ptr: u32,
276 cert_pem_len: u32,
277 pk_pem_ptr: u32,
278 pk_pem_len: u32,
279 csr_pem_ptr: u32,
280 csr_pem_len: u32,
281 len_ptr: u32,
282) -> Box<dyn Future<Output = Result<u32>> + Send + '_>
283where
284 T: DistributedCtx<E> + Send,
285 E: Environment,
286{
287 Box::new(async move {
288 let memory = get_memory(&mut caller)?;
289
290 let cert_pem_bytes = memory
291 .data(&caller)
292 .get(cert_pem_ptr as usize..(cert_pem_ptr + cert_pem_len) as usize)
293 .or_trap("lunatic::distributed::spawn::sign_node")?;
294 let cert_pem =
295 std::str::from_utf8(cert_pem_bytes).or_trap("lunatic::distributed::sign_node")?;
296
297 let pk_pem_bytes = memory
298 .data(&caller)
299 .get(pk_pem_ptr as usize..(pk_pem_ptr + pk_pem_len) as usize)
300 .or_trap("lunatic::distributed::sign_node")?;
301 let pk_pem =
302 std::str::from_utf8(pk_pem_bytes).or_trap("lunatic::distributed::sign_node")?;
303
304 let csr_pem_bytes = memory
305 .data(&caller)
306 .get(csr_pem_ptr as usize..(csr_pem_ptr + csr_pem_len) as usize)
307 .or_trap("lunatic::distributed::sign_node")?;
308 let csr_pem =
309 std::str::from_utf8(csr_pem_bytes).or_trap("lunatic::distributed::sign_node")?;
310
311 let key_pair = KeyPair::from_pem(pk_pem).or_trap("lunatic::distributed::sign_node")?;
312 let cert_params = CertificateParams::from_ca_cert_pem(cert_pem, key_pair)
313 .or_trap("lunatic::distributed::sign_node")?;
314
315 let ca_cert =
316 Certificate::from_params(cert_params).or_trap("lunatic::distributed::sign_node")?;
317
318 let Ok(cert_pem) = CertificateSigningRequest::from_pem(csr_pem)
319 .and_then(|sign_request| sign_request.serialize_pem_with_signer(&ca_cert)) else {
320 return Ok(0);
321 };
322
323 let data = bincode::serialize(&cert_pem).or_trap("lunatic::distributed::sign_node")?;
324 let ptr = write_to_guest_vec(&mut caller, &memory, &data, len_ptr)
325 .await
326 .or_trap("lunatic::distributed::sign_node")?;
327
328 Ok(ptr)
329 })
330}
331
332#[allow(clippy::too_many_arguments)]
358fn spawn<T, E>(
359 mut caller: Caller<T>,
360 node_id: u64,
361 config_id: i64,
362 module_id: u64,
363 func_str_ptr: u32,
364 func_str_len: u32,
365 params_ptr: u32,
366 params_len: u32,
367 id_ptr: u32,
368) -> Box<dyn Future<Output = Result<u32>> + Send + '_>
369where
370 T: DistributedCtx<E> + ResourceLimiter + Send + ErrorCtx + 'static,
371 E: Environment,
372 for<'a> &'a T: Send,
373{
374 Box::new(async move {
375 if !caller.data().can_spawn() {
376 return Err(anyhow!(
377 "Process doesn't have permissions to spawn sub-processes"
378 ));
379 }
380 let memory = get_memory(&mut caller)?;
381 let func_str = memory
382 .data(&caller)
383 .get(func_str_ptr as usize..(func_str_ptr + func_str_len) as usize)
384 .or_trap("lunatic::distributed::spawn::func_str")?;
385
386 let function =
387 std::str::from_utf8(func_str).or_trap("lunatic::distributed::spawn::func_str_utf8")?;
388
389 let params = memory
390 .data(&caller)
391 .get(params_ptr as usize..(params_ptr + params_len) as usize)
392 .or_trap("lunatic::distributed::spawn::params")?;
393 let params = params
394 .chunks_exact(17)
395 .map(|chunk| {
396 let value = u128::from_le_bytes(chunk[1..].try_into()?);
397 let result = match chunk[0] {
398 0x7F => Val::I32(value as i32),
399 0x7E => Val::I64(value as i64),
400 0x7B => Val::V128(value),
401 _ => return Err(anyhow!("Unsupported type ID")),
402 };
403 Ok(result)
404 })
405 .collect::<Result<Vec<_>>>()?;
406
407 let state = caller.data();
408
409 let config = match config_id {
410 -1 => state.config().clone(),
411 config_id => Arc::new(
412 caller
413 .data()
414 .config_resources()
415 .get(config_id as u64)
416 .or_trap("lunatic::distributed::spawn: Config ID doesn't exist")?
417 .clone(),
418 ),
419 };
420 let config: Vec<u8> =
421 rmp_serde::to_vec(config.as_ref()).map_err(|_| anyhow!("Error serializing config"))?;
422
423 log::debug!("Spawn on node {node_id}, mod {module_id}, fn {function}, params {params:?}");
424
425 let (process_or_error_id, ret) = match state
426 .distributed()?
427 .node_client
428 .spawn(
429 node_id,
430 Spawn {
431 environment_id: state.environment_id(),
432 function: function.to_string(),
433 module_id,
434 params,
435 config,
436 },
437 )
438 .await
439 {
440 Ok(process_id) => (process_id, 0),
441 Err(error) => {
442 let (code, message): (u32, String) = match error {
443 ClientError::Unexpected(cause) => Err(anyhow!(cause)),
444 ClientError::NodeNotFound => Ok((1, "Node does not exist.".to_string())),
445 ClientError::ModuleNotFound => Ok((2, "Module does not exist.".to_string())),
446 ClientError::Connection(cause) => Ok((9027, cause)),
447 _ => Err(anyhow!("unreachable")),
448 }?;
449 (
450 caller
451 .data_mut()
452 .error_resources_mut()
453 .add(anyhow!(message)),
454 code,
455 )
456 }
457 };
458
459 memory
460 .write(
461 &mut caller,
462 id_ptr as usize,
463 &process_or_error_id.to_le_bytes(),
464 )
465 .or_trap("lunatic::distributed::spawn::write_id")?;
466
467 Ok(ret)
468 })
469}
470
471fn send<T, E>(
485 mut caller: Caller<T>,
486 node_id: u64,
487 process_id: u64,
488) -> Box<dyn Future<Output = Result<u32>> + Send + '_>
489where
490 T: DistributedCtx<E> + ProcessCtx<T> + Send + ErrorCtx + 'static,
491 E: Environment,
492 for<'a> &'a T: Send,
493{
494 Box::new(async move {
495 let message = caller
496 .data_mut()
497 .message_scratch_area()
498 .take()
499 .or_trap("lunatic::distributed::send::no_message")?;
500
501 if let Message::Data(DataMessage {
502 tag,
503 buffer,
504 resources,
505 ..
506 }) = message
507 {
508 if !resources.is_empty() {
509 return Err(anyhow!("Cannot send resources to remote nodes."));
510 }
511
512 let state = caller.data();
513 match state
514 .distributed()?
515 .node_client
516 .message_process(node_id, state.environment_id(), process_id, tag, buffer)
517 .await
518 {
519 Ok(_) => Ok(0),
520 Err(error) => match error {
521 ClientError::Unexpected(cause) => Err(anyhow!(cause)),
522 ClientError::ProcessNotFound => Ok(1),
523 ClientError::NodeNotFound => Ok(2),
524 ClientError::Connection(_) => Ok(9027),
525 _ => Err(anyhow!("unreachable")),
526 },
527 }
528 } else {
529 Err(anyhow!("Only Message::Data can be sent across nodes."))
530 }
531 })
532}
533
534fn send_receive_skip_search<T, E>(
556 mut caller: Caller<T>,
557 node_id: u64,
558 process_id: u64,
559 wait_on_tag: i64,
560 timeout_duration: u64,
561) -> Box<dyn Future<Output = Result<u32>> + Send + '_>
562where
563 T: DistributedCtx<E> + ProcessCtx<T> + Send + 'static,
564 E: Environment,
565 for<'a> &'a T: Send,
566{
567 Box::new(async move {
568 let message = caller
569 .data_mut()
570 .message_scratch_area()
571 .take()
572 .or_trap("lunatic::distributed::send_receive_skip_search")?;
573
574 if let Message::Data(DataMessage {
575 tag,
576 buffer,
577 resources,
578 ..
579 }) = message
580 {
581 if !resources.is_empty() {
582 return Err(anyhow!("Cannot send resources to remote nodes."));
583 }
584
585 let state = caller.data();
586 let code = match state
587 .distributed()?
588 .node_client
589 .message_process(node_id, state.environment_id(), process_id, tag, buffer)
590 .await
591 {
592 Ok(_) => Ok(0),
593 Err(error) => match error {
594 ClientError::ProcessNotFound => Ok(1),
595 ClientError::NodeNotFound => Ok(2),
596 ClientError::Unexpected(cause) => Err(anyhow!(cause)),
597 _ => Err(anyhow!("unreachable")),
598 },
599 }?;
600
601 if code != 0 {
602 return Ok(code);
603 }
604
605 let tags = [wait_on_tag];
606 let pop_skip_search = caller.data_mut().mailbox().pop_skip_search(Some(&tags));
607 if let Ok(message) = match timeout_duration {
608 u64::MAX => Ok(pop_skip_search.await),
610 t => timeout(Duration::from_millis(t), pop_skip_search).await,
612 } {
613 caller.data_mut().message_scratch_area().replace(message);
615 Ok(0)
616 } else {
617 Ok(9027)
618 }
619 } else {
620 Err(anyhow!("Only Message::Data can be sent across nodes."))
621 }
622 })
623}
624
625fn node_id<T, E>(caller: Caller<T>) -> u64
627where
628 T: DistributedCtx<E>,
629 E: Environment,
630{
631 caller
632 .data()
633 .distributed()
634 .as_ref()
635 .map(|d| d.node_id())
636 .unwrap_or(0)
637}
638
639fn module_id<T, E>(caller: Caller<T>) -> u64
641where
642 T: DistributedCtx<E>,
643 E: Environment,
644{
645 caller.data().module_id()
646}