lunatic_distributed_api/
lib.rs

1use std::{future::Future, sync::Arc, time::Duration};
2
3use anyhow::{anyhow, Result};
4use lunatic_common_api::{get_memory, write_to_guest_vec, IntoTrap};
5use lunatic_distributed::{
6    distributed::message::{ClientError, Spawn, Val},
7    DistributedCtx,
8};
9use lunatic_error_api::ErrorCtx;
10use lunatic_process::{
11    env::Environment,
12    message::{DataMessage, Message},
13};
14use lunatic_process_api::ProcessCtx;
15use rcgen::{Certificate, CertificateParams, CertificateSigningRequest, KeyPair};
16use tokio::time::timeout;
17use wasmtime::{Caller, Linker, ResourceLimiter};
18
19// Register the lunatic distributed APIs to the linker
20pub fn register<T, E>(linker: &mut Linker<T>) -> Result<()>
21where
22    T: DistributedCtx<E> + ProcessCtx<T> + Send + ResourceLimiter + ErrorCtx + 'static,
23    E: Environment + 'static,
24    for<'a> &'a T: Send,
25{
26    linker.func_wrap("lunatic::distributed", "nodes_count", nodes_count)?;
27    linker.func_wrap("lunatic::distributed", "get_nodes", get_nodes)?;
28    linker.func_wrap("lunatic::distributed", "node_id", node_id)?;
29    linker.func_wrap("lunatic::distributed", "module_id", module_id)?;
30    linker.func_wrap8_async("lunatic::distributed", "spawn", spawn)?;
31    linker.func_wrap2_async("lunatic::distributed", "send", send)?;
32    linker.func_wrap4_async(
33        "lunatic::distributed",
34        "send_receive_skip_search",
35        send_receive_skip_search,
36    )?;
37    linker.func_wrap5_async(
38        "lunatic::distributed",
39        "exec_lookup_nodes",
40        exec_lookup_nodes,
41    )?;
42    linker.func_wrap(
43        "lunatic::distributed",
44        "copy_lookup_nodes_results",
45        copy_lookup_nodes_results,
46    )?;
47    linker.func_wrap1_async("lunatic::distributed", "test_root_cert", test_root_cert)?;
48    linker.func_wrap5_async(
49        "lunatic::distributed",
50        "default_server_certificates",
51        default_server_certificates,
52    )?;
53    linker.func_wrap7_async("lunatic::distributed", "sign_node", sign_node)?;
54    Ok(())
55}
56
57// Returns the number of registered nodes
58fn nodes_count<T, E>(caller: Caller<T>) -> u32
59where
60    T: DistributedCtx<E>,
61    E: Environment,
62{
63    caller
64        .data()
65        .distributed()
66        .map(|d| d.control.node_count())
67        .unwrap_or(0) as u32
68}
69
70// Copy node ids into guest memory. Returns the number of nodes copied.
71//
72// Traps:
73// * If any memory outside the guest heap space is referenced.
74fn get_nodes<T, E>(mut caller: Caller<T>, nodes_ptr: u32, nodes_len: u32) -> Result<u32>
75where
76    T: DistributedCtx<E>,
77    E: Environment,
78{
79    let memory = get_memory(&mut caller)?;
80    let node_ids = caller
81        .data()
82        .distributed()
83        .map(|d| d.control.node_ids())
84        .unwrap_or_else(|_| vec![]);
85    let copy_nodes_len = node_ids.len().min(nodes_len as usize);
86    memory
87        .data_mut(&mut caller)
88        .get_mut(
89            nodes_ptr as usize..(nodes_ptr as usize + std::mem::size_of::<u64>() * copy_nodes_len),
90        )
91        .or_trap("lunatic::distributed::get_nodes::memory")?
92        .copy_from_slice(unsafe { node_ids[..copy_nodes_len].align_to::<u8>().1 });
93    Ok(copy_nodes_len as u32)
94}
95
96// Submits a lookup node query to the control server and waits for the results.
97//
98// Filtering is done based on tags which are `key=value` user defined node
99// metadata, see CLI flag `tag`.
100//
101// Traps:
102// * If the query is not a valid UTF-8 string
103// * if any memory outside the guest heap space is referenced
104fn exec_lookup_nodes<T, E>(
105    mut caller: Caller<T>,
106    query_ptr: u32,
107    query_len: u32,
108    query_id_ptr: u32,
109    nodes_len_ptr: u32,
110    error_ptr: u32,
111) -> Box<dyn Future<Output = Result<u32>> + Send + '_>
112where
113    T: DistributedCtx<E> + ErrorCtx + Send + 'static,
114    E: Environment + 'static,
115    for<'a> &'a T: Send,
116{
117    Box::new(async move {
118        let memory = get_memory(&mut caller)?;
119        let query_str = memory
120            .data(&caller)
121            .get(query_ptr as usize..(query_ptr + query_len) as usize)
122            .or_trap("lunatic::distributed::lookup_nodes::query_ptr")?;
123        let query = std::str::from_utf8(query_str)
124            .or_trap("lunatic::distributed::lookup_nodes::query_str_utf8")?;
125        let distributed = caller.data().distributed()?;
126        match distributed.control.lookup_nodes(query).await {
127            Ok((query_id, nodes_len)) => {
128                memory
129                    .write(&mut caller, query_id_ptr as usize, &query_id.to_le_bytes())
130                    .or_trap("lunatic::distributed::lookup_nodes::query_id")?;
131                memory
132                    .write(
133                        &mut caller,
134                        nodes_len_ptr as usize,
135                        &nodes_len.to_le_bytes(),
136                    )
137                    .or_trap("lunatic::distributed::lookup_nodes::nodes_len")?;
138                Ok(0)
139            }
140            Err(error) => {
141                let error_id = caller.data_mut().error_resources_mut().add(error);
142                memory
143                    .write(&mut caller, error_ptr as usize, &error_id.to_le_bytes())
144                    .or_trap("lunatic::distributed::lookup_nodes::error_ptr")?;
145                Ok(1)
146            }
147        }
148    })
149}
150
151// Copies node ids to guest memory from the lookup node query result, returns number of node ids copied.
152//
153// Traps:
154// * If any memory outside the guest heap space is referenced.
155fn copy_lookup_nodes_results<T, E>(
156    mut caller: Caller<T>,
157    query_id: u64,
158    nodes_ptr: u32,
159    nodes_len: u32,
160    error_ptr: u32,
161) -> Result<i32>
162where
163    T: DistributedCtx<E> + ErrorCtx,
164    E: Environment,
165{
166    let memory = get_memory(&mut caller)?;
167    if let Some(query_results) = caller
168        .data()
169        .distributed()
170        .map(|d| d.control.query_result(&query_id))?
171    {
172        let nodes = query_results.1;
173        let copy_nodes_len = nodes.len().min(nodes_len as usize);
174        let memory = get_memory(&mut caller)?;
175        memory
176            .data_mut(&mut caller)
177            .get_mut(
178                nodes_ptr as usize
179                    ..(nodes_ptr as usize + std::mem::size_of::<u64>() * copy_nodes_len),
180            )
181            .or_trap("lunatic::distributed::copy_lookup_nodes_results::memory")?
182            .copy_from_slice(unsafe { nodes[..copy_nodes_len].align_to::<u8>().1 });
183        Ok(copy_nodes_len as i32)
184    } else {
185        let error = anyhow!("Invalid query id");
186        let error_id = caller.data_mut().error_resources_mut().add(error);
187        memory
188            .write(&mut caller, error_ptr as usize, &error_id.to_le_bytes())
189            .or_trap("lunatic::distributed::copy_lookup_nodes_results::error_ptr")?;
190        Ok(-1)
191    }
192}
193
194fn test_root_cert<T, E>(
195    mut caller: Caller<T>,
196    len_ptr: u32,
197) -> Box<dyn Future<Output = Result<u32>> + Send + '_>
198where
199    T: DistributedCtx<E> + Send,
200    E: Environment,
201{
202    Box::new(async move {
203        let memory = get_memory(&mut caller)?;
204        let root_cert = lunatic_distributed::control::cert::test_root_cert()
205            .or_trap("lunatic::distributed::test_root_cert")?;
206
207        let cert_pem = root_cert
208            .serialize_pem()
209            .or_trap("lunatic::distributed::test_root_cert")?;
210        let key_pair_pem = root_cert.serialize_private_key_pem();
211
212        let data = bincode::serialize(&(cert_pem, key_pair_pem))
213            .or_trap("lunatic::distributed::test_root_cert")?;
214        let ptr = write_to_guest_vec(&mut caller, &memory, &data, len_ptr)
215            .await
216            .or_trap("lunatic::distributed::test_root_cert")?;
217
218        Ok(ptr)
219    })
220}
221
222fn default_server_certificates<T, E>(
223    mut caller: Caller<T>,
224    cert_pem_ptr: u32,
225    cert_pem_len: u32,
226    pk_pem_ptr: u32,
227    pk_pem_len: u32,
228    len_ptr: u32,
229) -> Box<dyn Future<Output = Result<u32>> + Send + '_>
230where
231    T: DistributedCtx<E> + Send,
232    E: Environment,
233{
234    Box::new(async move {
235        let memory = get_memory(&mut caller)?;
236
237        let cert_pem_bytes = memory
238            .data(&caller)
239            .get(cert_pem_ptr as usize..(cert_pem_ptr + cert_pem_len) as usize)
240            .or_trap("lunatic::distributed::spawn::default_server_certificates")?;
241        let cert_pem = std::str::from_utf8(cert_pem_bytes)
242            .or_trap("lunatic::distributed::default_server_certificates")?;
243
244        let pk_pem_bytes = memory
245            .data(&caller)
246            .get(pk_pem_ptr as usize..(pk_pem_ptr + pk_pem_len) as usize)
247            .or_trap("lunatic::distributed::default_server_certificates")?;
248        let pk_pem = std::str::from_utf8(pk_pem_bytes)
249            .or_trap("lunatic::distributed::default_server_certificates")?;
250
251        let key_pair = KeyPair::from_pem(pk_pem)
252            .or_trap("lunatic::distributed::default_server_certificates")?;
253        let cert_params = CertificateParams::from_ca_cert_pem(cert_pem, key_pair)
254            .or_trap("lunatic::distributed::default_server_certificates")?;
255
256        let root_cert = Certificate::from_params(cert_params)
257            .or_trap("lunatic::distributed::default_server_certificates")?;
258
259        let (ctrl_cert, ctrl_pk) =
260            lunatic_distributed::control::cert::default_server_certificates(&root_cert)?;
261
262        let data = bincode::serialize(&(ctrl_cert, ctrl_pk))
263            .or_trap("lunatic::distributed::default_server_certificates")?;
264        let ptr = write_to_guest_vec(&mut caller, &memory, &data, len_ptr)
265            .await
266            .or_trap("lunatic::distributed::default_server_certificates")?;
267
268        Ok(ptr)
269    })
270}
271
272#[allow(clippy::too_many_arguments)]
273fn sign_node<T, E>(
274    mut caller: Caller<T>,
275    cert_pem_ptr: u32,
276    cert_pem_len: u32,
277    pk_pem_ptr: u32,
278    pk_pem_len: u32,
279    csr_pem_ptr: u32,
280    csr_pem_len: u32,
281    len_ptr: u32,
282) -> Box<dyn Future<Output = Result<u32>> + Send + '_>
283where
284    T: DistributedCtx<E> + Send,
285    E: Environment,
286{
287    Box::new(async move {
288        let memory = get_memory(&mut caller)?;
289
290        let cert_pem_bytes = memory
291            .data(&caller)
292            .get(cert_pem_ptr as usize..(cert_pem_ptr + cert_pem_len) as usize)
293            .or_trap("lunatic::distributed::spawn::sign_node")?;
294        let cert_pem =
295            std::str::from_utf8(cert_pem_bytes).or_trap("lunatic::distributed::sign_node")?;
296
297        let pk_pem_bytes = memory
298            .data(&caller)
299            .get(pk_pem_ptr as usize..(pk_pem_ptr + pk_pem_len) as usize)
300            .or_trap("lunatic::distributed::sign_node")?;
301        let pk_pem =
302            std::str::from_utf8(pk_pem_bytes).or_trap("lunatic::distributed::sign_node")?;
303
304        let csr_pem_bytes = memory
305            .data(&caller)
306            .get(csr_pem_ptr as usize..(csr_pem_ptr + csr_pem_len) as usize)
307            .or_trap("lunatic::distributed::sign_node")?;
308        let csr_pem =
309            std::str::from_utf8(csr_pem_bytes).or_trap("lunatic::distributed::sign_node")?;
310
311        let key_pair = KeyPair::from_pem(pk_pem).or_trap("lunatic::distributed::sign_node")?;
312        let cert_params = CertificateParams::from_ca_cert_pem(cert_pem, key_pair)
313            .or_trap("lunatic::distributed::sign_node")?;
314
315        let ca_cert =
316            Certificate::from_params(cert_params).or_trap("lunatic::distributed::sign_node")?;
317
318        let Ok(cert_pem) = CertificateSigningRequest::from_pem(csr_pem)
319            .and_then(|sign_request| sign_request.serialize_pem_with_signer(&ca_cert)) else {
320                return Ok(0);
321            };
322
323        let data = bincode::serialize(&cert_pem).or_trap("lunatic::distributed::sign_node")?;
324        let ptr = write_to_guest_vec(&mut caller, &memory, &data, len_ptr)
325            .await
326            .or_trap("lunatic::distributed::sign_node")?;
327
328        Ok(ptr)
329    })
330}
331
332// Similar to a local spawn, it spawns a new process using the passed in function inside a module
333// as the entry point. The process is spawned on a node with id `node_id`.
334//
335// If `config_id` is 0, the same config is used as in the process calling this function.
336//
337// The function arguments are passed as an array with the following structure:
338// [0 byte = type ID; 1..17 bytes = value as u128, ...]
339// The type ID follows the WebAssembly binary convention:
340//  - 0x7F => i32
341//  - 0x7E => i64
342//  - 0x7B => v128
343// If any other value is used as type ID, this function will trap. If your type
344// would ordinarily occupy fewer than 16 bytes (e.g. in an i32 or i64), you MUST
345// first convert it to an i128.
346//
347// Returns:
348// * 0      on success - The ID of the newly created process is written to `id_ptr`
349// * 1      If node does not exist
350// * 2      If module does not exist
351// * 9027   If node connection error occurred
352//
353// Traps:
354// * If the function string is not a valid utf8 string.
355// * If the params array is in a wrong format.
356// * If any memory outside the guest heap space is referenced.
357#[allow(clippy::too_many_arguments)]
358fn spawn<T, E>(
359    mut caller: Caller<T>,
360    node_id: u64,
361    config_id: i64,
362    module_id: u64,
363    func_str_ptr: u32,
364    func_str_len: u32,
365    params_ptr: u32,
366    params_len: u32,
367    id_ptr: u32,
368) -> Box<dyn Future<Output = Result<u32>> + Send + '_>
369where
370    T: DistributedCtx<E> + ResourceLimiter + Send + ErrorCtx + 'static,
371    E: Environment,
372    for<'a> &'a T: Send,
373{
374    Box::new(async move {
375        if !caller.data().can_spawn() {
376            return Err(anyhow!(
377                "Process doesn't have permissions to spawn sub-processes"
378            ));
379        }
380        let memory = get_memory(&mut caller)?;
381        let func_str = memory
382            .data(&caller)
383            .get(func_str_ptr as usize..(func_str_ptr + func_str_len) as usize)
384            .or_trap("lunatic::distributed::spawn::func_str")?;
385
386        let function =
387            std::str::from_utf8(func_str).or_trap("lunatic::distributed::spawn::func_str_utf8")?;
388
389        let params = memory
390            .data(&caller)
391            .get(params_ptr as usize..(params_ptr + params_len) as usize)
392            .or_trap("lunatic::distributed::spawn::params")?;
393        let params = params
394            .chunks_exact(17)
395            .map(|chunk| {
396                let value = u128::from_le_bytes(chunk[1..].try_into()?);
397                let result = match chunk[0] {
398                    0x7F => Val::I32(value as i32),
399                    0x7E => Val::I64(value as i64),
400                    0x7B => Val::V128(value),
401                    _ => return Err(anyhow!("Unsupported type ID")),
402                };
403                Ok(result)
404            })
405            .collect::<Result<Vec<_>>>()?;
406
407        let state = caller.data();
408
409        let config = match config_id {
410            -1 => state.config().clone(),
411            config_id => Arc::new(
412                caller
413                    .data()
414                    .config_resources()
415                    .get(config_id as u64)
416                    .or_trap("lunatic::distributed::spawn: Config ID doesn't exist")?
417                    .clone(),
418            ),
419        };
420        let config: Vec<u8> =
421            rmp_serde::to_vec(config.as_ref()).map_err(|_| anyhow!("Error serializing config"))?;
422
423        log::debug!("Spawn on node {node_id}, mod {module_id}, fn {function}, params {params:?}");
424
425        let (process_or_error_id, ret) = match state
426            .distributed()?
427            .node_client
428            .spawn(
429                node_id,
430                Spawn {
431                    environment_id: state.environment_id(),
432                    function: function.to_string(),
433                    module_id,
434                    params,
435                    config,
436                },
437            )
438            .await
439        {
440            Ok(process_id) => (process_id, 0),
441            Err(error) => {
442                let (code, message): (u32, String) = match error {
443                    ClientError::Unexpected(cause) => Err(anyhow!(cause)),
444                    ClientError::NodeNotFound => Ok((1, "Node does not exist.".to_string())),
445                    ClientError::ModuleNotFound => Ok((2, "Module does not exist.".to_string())),
446                    ClientError::Connection(cause) => Ok((9027, cause)),
447                    _ => Err(anyhow!("unreachable")),
448                }?;
449                (
450                    caller
451                        .data_mut()
452                        .error_resources_mut()
453                        .add(anyhow!(message)),
454                    code,
455                )
456            }
457        };
458
459        memory
460            .write(
461                &mut caller,
462                id_ptr as usize,
463                &process_or_error_id.to_le_bytes(),
464            )
465            .or_trap("lunatic::distributed::spawn::write_id")?;
466
467        Ok(ret)
468    })
469}
470
471// Sends the message in scratch area to a process running on a node with id `node_id`.
472//
473// There are no guarantees that the message will be received.
474//
475// Returns:
476// * 0      If message sent
477// * 1      If process_id does not exist
478// * 2      If node_id does not exist
479// * 9027   If node connection error occurred
480//
481// Traps:
482// * If it's called before creating the next message.
483// * If the message contains resources
484fn send<T, E>(
485    mut caller: Caller<T>,
486    node_id: u64,
487    process_id: u64,
488) -> Box<dyn Future<Output = Result<u32>> + Send + '_>
489where
490    T: DistributedCtx<E> + ProcessCtx<T> + Send + ErrorCtx + 'static,
491    E: Environment,
492    for<'a> &'a T: Send,
493{
494    Box::new(async move {
495        let message = caller
496            .data_mut()
497            .message_scratch_area()
498            .take()
499            .or_trap("lunatic::distributed::send::no_message")?;
500
501        if let Message::Data(DataMessage {
502            tag,
503            buffer,
504            resources,
505            ..
506        }) = message
507        {
508            if !resources.is_empty() {
509                return Err(anyhow!("Cannot send resources to remote nodes."));
510            }
511
512            let state = caller.data();
513            match state
514                .distributed()?
515                .node_client
516                .message_process(node_id, state.environment_id(), process_id, tag, buffer)
517                .await
518            {
519                Ok(_) => Ok(0),
520                Err(error) => match error {
521                    ClientError::Unexpected(cause) => Err(anyhow!(cause)),
522                    ClientError::ProcessNotFound => Ok(1),
523                    ClientError::NodeNotFound => Ok(2),
524                    ClientError::Connection(_) => Ok(9027),
525                    _ => Err(anyhow!("unreachable")),
526                },
527            }
528        } else {
529            Err(anyhow!("Only Message::Data can be sent across nodes."))
530        }
531    })
532}
533
534// Sends the message to a process on a node with id `node_id` and waits for a reply,
535// but doesn't look through existing messages in the mailbox queue while waiting.
536// This is an optimization that only makes sense with tagged messages.
537// In a request/reply scenario we can tag the request message with an
538// unique tag and just wait on it specifically.
539//
540// This operation needs to be an atomic host function, if we jumped back into the guest we could
541// miss out on the incoming message before `receive` is called.
542//
543// If timeout is specified (value different from u64::MAX), the function will return on timeout
544// expiration with value 9027.
545//
546// Returns:
547// * 0    If message arrived.
548// * 1    If process_id does not exist
549// * 2    If node_id does not exist
550// * 9027 If call timed out.
551//
552// Traps:
553// * If it's called with wrong data in the scratch area.
554// * If the message contains resources
555fn send_receive_skip_search<T, E>(
556    mut caller: Caller<T>,
557    node_id: u64,
558    process_id: u64,
559    wait_on_tag: i64,
560    timeout_duration: u64,
561) -> Box<dyn Future<Output = Result<u32>> + Send + '_>
562where
563    T: DistributedCtx<E> + ProcessCtx<T> + Send + 'static,
564    E: Environment,
565    for<'a> &'a T: Send,
566{
567    Box::new(async move {
568        let message = caller
569            .data_mut()
570            .message_scratch_area()
571            .take()
572            .or_trap("lunatic::distributed::send_receive_skip_search")?;
573
574        if let Message::Data(DataMessage {
575            tag,
576            buffer,
577            resources,
578            ..
579        }) = message
580        {
581            if !resources.is_empty() {
582                return Err(anyhow!("Cannot send resources to remote nodes."));
583            }
584
585            let state = caller.data();
586            let code = match state
587                .distributed()?
588                .node_client
589                .message_process(node_id, state.environment_id(), process_id, tag, buffer)
590                .await
591            {
592                Ok(_) => Ok(0),
593                Err(error) => match error {
594                    ClientError::ProcessNotFound => Ok(1),
595                    ClientError::NodeNotFound => Ok(2),
596                    ClientError::Unexpected(cause) => Err(anyhow!(cause)),
597                    _ => Err(anyhow!("unreachable")),
598                },
599            }?;
600
601            if code != 0 {
602                return Ok(code);
603            }
604
605            let tags = [wait_on_tag];
606            let pop_skip_search = caller.data_mut().mailbox().pop_skip_search(Some(&tags));
607            if let Ok(message) = match timeout_duration {
608                // Without timeout
609                u64::MAX => Ok(pop_skip_search.await),
610                // With timeout
611                t => timeout(Duration::from_millis(t), pop_skip_search).await,
612            } {
613                // Put the message into the scratch area
614                caller.data_mut().message_scratch_area().replace(message);
615                Ok(0)
616            } else {
617                Ok(9027)
618            }
619        } else {
620            Err(anyhow!("Only Message::Data can be sent across nodes."))
621        }
622    })
623}
624
625// Returns the id of the node that the current process is running on
626fn node_id<T, E>(caller: Caller<T>) -> u64
627where
628    T: DistributedCtx<E>,
629    E: Environment,
630{
631    caller
632        .data()
633        .distributed()
634        .as_ref()
635        .map(|d| d.node_id())
636        .unwrap_or(0)
637}
638
639// Returns id of the module that the current process is spawned from
640fn module_id<T, E>(caller: Caller<T>) -> u64
641where
642    T: DistributedCtx<E>,
643    E: Environment,
644{
645    caller.data().module_id()
646}