use std::{future::Future, sync::Arc, time::Duration};
use anyhow::{anyhow, Result};
use lunatic_common_api::{get_memory, IntoTrap};
use lunatic_distributed::{
distributed::message::{ClientError, Spawn, Val},
DistributedCtx,
};
use lunatic_error_api::ErrorCtx;
use lunatic_process::{
env::Environment,
message::{DataMessage, Message},
};
use lunatic_process_api::ProcessCtx;
use tokio::time::timeout;
use wasmtime::{Caller, Linker, ResourceLimiter, Trap};
/// Registers every `lunatic::distributed` host function with `linker`.
///
/// Functions registered through `func_wrap` run synchronously; the ones registered through
/// the `func_wrapN_async` variants return futures (spawning, sending and node lookups).
///
/// # Errors
/// Propagates the first registration error reported by the linker.
pub fn register<T, E>(linker: &mut Linker<T>) -> Result<()>
where
    T: DistributedCtx<E> + ProcessCtx<T> + Send + ResourceLimiter + ErrorCtx + 'static,
    E: Environment + 'static,
    for<'a> &'a T: Send,
{
    const NAMESPACE: &str = "lunatic::distributed";

    linker.func_wrap(NAMESPACE, "nodes_count", nodes_count)?;
    linker.func_wrap(NAMESPACE, "get_nodes", get_nodes)?;
    linker.func_wrap(NAMESPACE, "node_id", node_id)?;
    linker.func_wrap(NAMESPACE, "module_id", module_id)?;
    linker.func_wrap8_async(NAMESPACE, "spawn", spawn)?;
    linker.func_wrap2_async(NAMESPACE, "send", send)?;
    linker.func_wrap3_async(NAMESPACE, "send_receive_skip_search", send_receive_skip_search)?;
    linker.func_wrap5_async(NAMESPACE, "exec_lookup_nodes", exec_lookup_nodes)?;
    linker.func_wrap(NAMESPACE, "copy_lookup_nodes_results", copy_lookup_nodes_results)?;

    Ok(())
}
/// Returns the node count reported by the distributed control client, or `0` when the
/// calling process has no distributed state.
fn nodes_count<T, E>(caller: Caller<T>) -> u32
where
    T: DistributedCtx<E>,
    E: Environment,
{
    match caller.data().distributed() {
        Ok(distributed) => distributed.control.node_count() as u32,
        Err(_) => 0,
    }
}
/// Copies the IDs of the nodes known to the distributed control client into guest memory.
///
/// At most `nodes_len` IDs are written, each as a little-endian `u64`, starting at
/// `nodes_ptr`. Returns the number of IDs actually written. When the calling process has
/// no distributed state, nothing is written and `0` is returned.
///
/// # Errors
/// Traps if the guest memory can't be obtained or the destination range is out of bounds.
fn get_nodes<T, E>(mut caller: Caller<T>, nodes_ptr: u32, nodes_len: u32) -> Result<u32, Trap>
where
    T: DistributedCtx<E>,
    E: Environment,
{
    let memory = get_memory(&mut caller)?;
    let node_ids = caller
        .data()
        .distributed()
        .map(|d| d.control.node_ids())
        .unwrap_or_else(|_| vec![]);
    let copy_nodes_len = node_ids.len().min(nodes_len as usize);

    let start = nodes_ptr as usize;
    let end = start + std::mem::size_of::<u64>() * copy_nodes_len;
    let dest = memory
        .data_mut(&mut caller)
        .get_mut(start..end)
        .or_trap("lunatic::distributed::get_nodes::memory")?;
    // Wasm linear memory is little-endian. Serialize each ID explicitly instead of
    // reinterpreting host-endian bytes (`align_to`), which would be wrong on big-endian hosts
    // and needed `unsafe`.
    for (slot, id) in dest
        .chunks_exact_mut(std::mem::size_of::<u64>())
        .zip(&node_ids[..copy_nodes_len])
    {
        slot.copy_from_slice(&id.to_le_bytes());
    }
    Ok(copy_nodes_len as u32)
}
/// Starts a node lookup for the UTF-8 query stored at `query_ptr..query_ptr + query_len`.
///
/// On success, writes the little-endian bytes of the returned query ID to `query_id_ptr`
/// and of the result count to `nodes_len_ptr`, then returns `0`. If the lookup fails, the
/// error is stored as an error resource, its ID is written to `error_ptr`, and `1` is
/// returned.
///
/// # Errors
/// Traps if the query range is out of bounds or not valid UTF-8, if the process has no
/// distributed state, or if any of the output writes fail.
fn exec_lookup_nodes<T, E>(
    mut caller: Caller<T>,
    query_ptr: u32,
    query_len: u32,
    query_id_ptr: u32,
    nodes_len_ptr: u32,
    error_ptr: u32,
) -> Box<dyn Future<Output = Result<u32, Trap>> + Send + '_>
where
    T: DistributedCtx<E> + ErrorCtx + Send + 'static,
    E: Environment + 'static,
    for<'a> &'a T: Send,
{
    Box::new(async move {
        let memory = get_memory(&mut caller)?;
        // Widen before adding: both values are guest-controlled `u32`s, and adding them in
        // `u32` could overflow (panicking the host in debug builds).
        let query_start = query_ptr as usize;
        let query_str = memory
            .data(&caller)
            .get(query_start..query_start + query_len as usize)
            .or_trap("lunatic::distributed::lookup_nodes::query_ptr")?;
        let query = std::str::from_utf8(query_str)
            .or_trap("lunatic::distributed::lookup_nodes::query_str_utf8")?;
        let distributed = caller.data().distributed()?;
        match distributed.control.lookup_nodes(query).await {
            Ok((query_id, nodes_len)) => {
                memory
                    .write(&mut caller, query_id_ptr as usize, &query_id.to_le_bytes())
                    .or_trap("lunatic::distributed::lookup_nodes::query_id")?;
                memory
                    .write(&mut caller, nodes_len_ptr as usize, &nodes_len.to_le_bytes())
                    .or_trap("lunatic::distributed::lookup_nodes::nodes_len")?;
                Ok(0)
            }
            Err(error) => {
                let error_id = caller.data_mut().error_resources_mut().add(error);
                memory
                    .write(&mut caller, error_ptr as usize, &error_id.to_le_bytes())
                    .or_trap("lunatic::distributed::lookup_nodes::error_ptr")?;
                Ok(1)
            }
        }
    })
}
/// Copies the node IDs produced by the lookup identified by `query_id` into guest memory.
///
/// At most `nodes_len` IDs are written, each as a little-endian `u64`, starting at
/// `nodes_ptr`, and the number written is returned. If the control client has no result
/// for `query_id`, an "Invalid query id" error resource is created, its ID is written to
/// `error_ptr`, and `-1` is returned.
///
/// # Errors
/// Traps if the process has no distributed state, the guest memory can't be obtained, or
/// a destination range is out of bounds.
fn copy_lookup_nodes_results<T, E>(
    mut caller: Caller<T>,
    query_id: u64,
    nodes_ptr: u32,
    nodes_len: u32,
    error_ptr: u32,
) -> Result<i32, Trap>
where
    T: DistributedCtx<E> + ErrorCtx,
    E: Environment,
{
    // Fetched once and reused by both branches.
    let memory = get_memory(&mut caller)?;
    if let Some(query_results) = caller
        .data()
        .distributed()
        .map(|d| d.control.query_result(&query_id))?
    {
        let nodes = query_results.1;
        let copy_nodes_len = nodes.len().min(nodes_len as usize);

        let start = nodes_ptr as usize;
        let end = start + std::mem::size_of::<u64>() * copy_nodes_len;
        let dest = memory
            .data_mut(&mut caller)
            .get_mut(start..end)
            .or_trap("lunatic::distributed::copy_lookup_nodes_results::memory")?;
        // Wasm linear memory is little-endian. Serialize explicitly rather than
        // reinterpreting host-endian bytes with `unsafe align_to`.
        for (slot, id) in dest
            .chunks_exact_mut(std::mem::size_of::<u64>())
            .zip(&nodes[..copy_nodes_len])
        {
            slot.copy_from_slice(&id.to_le_bytes());
        }
        Ok(copy_nodes_len as i32)
    } else {
        let error = anyhow!("Invalid query id");
        let error_id = caller.data_mut().error_resources_mut().add(error);
        memory
            .write(&mut caller, error_ptr as usize, &error_id.to_le_bytes())
            .or_trap("lunatic::distributed::copy_lookup_nodes_results::error_ptr")?;
        Ok(-1)
    }
}
/// Spawns `function` from module `module_id` on the remote node `node_id`.
///
/// The function name is read as UTF-8 from `func_str_ptr..func_str_ptr + func_str_len`.
/// Parameters are read from `params_ptr..params_ptr + params_len` as 17-byte records:
/// - one type-ID byte: `0x7F` = i32, `0x7E` = i64, `0x7B` = v128;
/// - a 16-byte little-endian value.
///
/// Trailing bytes that don't form a full record are ignored (`chunks_exact`). A
/// `config_id` of `-1` reuses the caller's own config; any other value is looked up in
/// the caller's config resources.
///
/// Returns `0` on success and writes the new process ID to `id_ptr`. On a recognised
/// failure, an error resource is created, its ID is written to `id_ptr`, and one of these
/// codes is returned:
/// - `1`: node not found;
/// - `2`: module not found;
/// - `9027`: connection error.
///
/// # Errors
/// Traps if:
/// - the caller may not spawn;
/// - a guest memory range is invalid or the function name is not UTF-8;
/// - a parameter has an unsupported type ID;
/// - the config ID is unknown or the config can't be serialized;
/// - the client reports an unexpected error.
// The argument count is dictated by the guest-facing host-function ABI.
#[allow(clippy::too_many_arguments)]
fn spawn<T, E>(
    mut caller: Caller<T>,
    node_id: u64,
    config_id: i64,
    module_id: u64,
    func_str_ptr: u32,
    func_str_len: u32,
    params_ptr: u32,
    params_len: u32,
    id_ptr: u32,
) -> Box<dyn Future<Output = Result<u32, Trap>> + Send + '_>
where
    T: DistributedCtx<E> + ResourceLimiter + Send + ErrorCtx + 'static,
    E: Environment,
    for<'a> &'a T: Send,
{
    Box::new(async move {
        if !caller.data().can_spawn() {
            return Err(anyhow!("Process doesn't have permissions to spawn sub-processes").into());
        }
        let memory = get_memory(&mut caller)?;
        // Widen pointers and lengths before adding. They are guest-controlled `u32`s, and
        // `ptr + len` in `u32` could overflow (panicking the host in debug builds).
        let func_str_start = func_str_ptr as usize;
        let func_str = memory
            .data(&caller)
            .get(func_str_start..func_str_start + func_str_len as usize)
            .or_trap("lunatic::distributed::spawn::func_str")?;
        let function =
            std::str::from_utf8(func_str).or_trap("lunatic::distributed::spawn::func_str_utf8")?;
        let params_start = params_ptr as usize;
        let params = memory
            .data(&caller)
            .get(params_start..params_start + params_len as usize)
            .or_trap("lunatic::distributed::spawn::params")?;
        let params = params
            .chunks_exact(17)
            .map(|chunk| {
                let value = u128::from_le_bytes(chunk[1..].try_into()?);
                let result = match chunk[0] {
                    0x7F => Val::I32(value as i32),
                    0x7E => Val::I64(value as i64),
                    0x7B => Val::V128(value),
                    _ => return Err(anyhow!("Unsupported type ID")),
                };
                Ok(result)
            })
            .collect::<Result<Vec<_>>>()?;
        let state = caller.data();
        let config = match config_id {
            -1 => state.config().clone(),
            config_id => Arc::new(
                state
                    .config_resources()
                    .get(config_id as u64)
                    .or_trap("lunatic::process::spawn: Config ID doesn't exist")?
                    .clone(),
            ),
        };
        let config: Vec<u8> =
            bincode::serialize(config.as_ref()).map_err(|_| anyhow!("Error serializing config"))?;
        log::debug!("Spawn on node {node_id}, mod {module_id}, fn {function}, params {params:?}");
        let (process_or_error_id, ret) = match state
            .distributed()?
            .node_client
            .spawn(
                node_id,
                Spawn {
                    environment_id: state.environment_id(),
                    function: function.to_string(),
                    module_id,
                    params,
                    config,
                },
            )
            .await
        {
            Ok(process_id) => (process_id, 0),
            Err(error) => {
                let (code, message): (u32, String) = match error {
                    ClientError::Unexpected(cause) => Err(Trap::new(cause)),
                    ClientError::NodeNotFound => Ok((1, "Node does not exist.".to_string())),
                    ClientError::ModuleNotFound => Ok((2, "Module does not exist.".to_string())),
                    ClientError::Connection(cause) => Ok((9027, cause)),
                    _ => Err(Trap::new("unreachable")),
                }?;
                (
                    caller
                        .data_mut()
                        .error_resources_mut()
                        .add(anyhow!(message)),
                    code,
                )
            }
        };
        memory
            .write(
                &mut caller,
                id_ptr as usize,
                &process_or_error_id.to_le_bytes(),
            )
            .or_trap("lunatic::distributed::spawn::write_id")?;
        Ok(ret)
    })
}
/// Sends the message held in the caller's scratch area to process `process_id` on the
/// remote node `node_id`.
///
/// Return codes:
/// - `0`: success;
/// - `1`: the remote process was not found;
/// - `2`: the node was not found;
/// - `9027`: connection error.
///
/// # Errors
/// Traps if:
/// - the scratch area is empty;
/// - the message is not `Message::Data`;
/// - the message carries resources;
/// - the client reports an unexpected error.
fn send<T, E>(
    mut caller: Caller<T>,
    node_id: u64,
    process_id: u64,
) -> Box<dyn Future<Output = Result<u32, Trap>> + Send + '_>
where
    T: DistributedCtx<E> + ProcessCtx<T> + Send + ErrorCtx + 'static,
    E: Environment,
    for<'a> &'a T: Send,
{
    Box::new(async move {
        let message = caller
            .data_mut()
            .message_scratch_area()
            .take()
            .or_trap("lunatic::message::send::no_message")?;

        // Only plain data messages without attached resources can cross node boundaries.
        let (tag, buffer, resources) = match message {
            Message::Data(DataMessage {
                tag,
                buffer,
                resources,
                ..
            }) => (tag, buffer, resources),
            _ => return Err(Trap::new("Only Message::Data can be sent across nodes.")),
        };
        if !resources.is_empty() {
            return Err(Trap::new("Cannot send resources to remote nodes."));
        }

        let state = caller.data();
        let outcome = state
            .distributed()?
            .node_client
            .message_process(node_id, state.environment_id(), process_id, tag, buffer)
            .await;

        match outcome {
            Ok(_) => Ok(0),
            Err(ClientError::Unexpected(cause)) => Err(Trap::new(cause)),
            Err(ClientError::ProcessNotFound) => Ok(1),
            Err(ClientError::NodeNotFound) => Ok(2),
            Err(ClientError::Connection(_)) => Ok(9027),
            Err(_) => Err(Trap::new("unreachable")),
        }
    })
}
/// Sends the message held in the caller's scratch area to process `process_id` on node
/// `node_id`, then waits for a message from the caller's own mailbox.
///
/// If the outgoing message has a tag, that tag is passed to `pop_skip_search` as its tag
/// filter (presumably restricting which reply is taken; confirm against the mailbox
/// implementation). `timeout_duration` is in milliseconds, and `u64::MAX` waits with no
/// timeout. On success the received message replaces the scratch-area contents.
///
/// Return codes:
/// - `0`: success;
/// - `1`: the remote process was not found;
/// - `2`: the node was not found;
/// - `9027`: the wait timed out, or the connection to the node failed. The connection
///   case matches how `send` reports it.
///
/// # Errors
/// Traps if:
/// - the scratch area is empty;
/// - the message is not `Message::Data`;
/// - the message carries resources;
/// - the client reports an unexpected error.
fn send_receive_skip_search<T, E>(
    mut caller: Caller<T>,
    node_id: u64,
    process_id: u64,
    timeout_duration: u64,
) -> Box<dyn Future<Output = Result<u32, Trap>> + Send + '_>
where
    T: DistributedCtx<E> + ProcessCtx<T> + Send + 'static,
    E: Environment,
    for<'a> &'a T: Send,
{
    Box::new(async move {
        let message = caller
            .data_mut()
            .message_scratch_area()
            .take()
            .or_trap("lunatic::message::send::no_message")?;
        // Capture the outgoing tag before `message` is destructured and moved below.
        let mut _tags = [0; 1];
        let tags = if let Some(tag) = message.tag() {
            _tags = [tag];
            Some(&_tags[..])
        } else {
            None
        };
        if let Message::Data(DataMessage {
            tag,
            buffer,
            resources,
            ..
        }) = message
        {
            if !resources.is_empty() {
                return Err(Trap::new("Cannot send resources to remote nodes."));
            }
            let state = caller.data();
            let code = match state
                .distributed()?
                .node_client
                .message_process(node_id, state.environment_id(), process_id, tag, buffer)
                .await
            {
                Ok(_) => Ok(0),
                Err(error) => match error {
                    ClientError::ProcessNotFound => Ok(1),
                    ClientError::NodeNotFound => Ok(2),
                    // A dropped connection is a real runtime condition, not an unreachable
                    // state. Report it with the same code `send` uses instead of trapping.
                    ClientError::Connection(_) => Ok(9027),
                    ClientError::Unexpected(cause) => Err(Trap::new(cause)),
                    _ => Err(Trap::new("unreachable")),
                },
            }?;
            if code != 0 {
                return Ok(code);
            }
            let pop_skip_search = caller.data_mut().mailbox().pop_skip_search(tags);
            if let Ok(message) = match timeout_duration {
                u64::MAX => Ok(pop_skip_search.await),
                t => timeout(Duration::from_millis(t), pop_skip_search).await,
            } {
                caller.data_mut().message_scratch_area().replace(message);
                Ok(0)
            } else {
                // Timed out waiting for the reply.
                Ok(9027)
            }
        } else {
            Err(Trap::new("Only Message::Data can be sent across nodes."))
        }
    })
}
/// Returns the ID of the local node as reported by the caller's distributed state, or `0`
/// when the calling process has no distributed state.
fn node_id<T, E>(caller: Caller<T>) -> u64
where
    T: DistributedCtx<E>,
    E: Environment,
{
    if let Ok(distributed) = caller.data().distributed() {
        distributed.node_id()
    } else {
        0
    }
}
/// Returns the module ID stored in the caller's `DistributedCtx` state.
fn module_id<T, E>(caller: Caller<T>) -> u64
where
    T: DistributedCtx<E>,
    E: Environment,
{
    let state = caller.data();
    state.module_id()
}