use std::future::Future;
use std::net::SocketAddr;
use std::time::Duration;
use std::vec::IntoIter;
use anyhow::Result;
use tokio::time::timeout;
use wasmtime::{Caller, Linker};
use lunatic_common_api::{get_memory, IntoTrap};
use lunatic_error_api::ErrorCtx;
use crate::NetworkingCtx;
/// An iterator over the socket addresses produced by a DNS lookup.
///
/// `resolve` stores instances of this type as DNS resources, and `resolve_next`
/// advances them one address at a time.
#[derive(Debug)]
pub struct DnsIterator {
    iter: IntoIter<SocketAddr>,
}

impl DnsIterator {
    /// Wraps an owned iterator over resolved socket addresses.
    pub fn new(iter: IntoIter<SocketAddr>) -> Self {
        Self { iter }
    }
}

impl Iterator for DnsIterator {
    type Item = SocketAddr;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }

    // Forward the exact bounds of the underlying `vec::IntoIter`; the default
    // implementation would report `(0, None)` and lose that information.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

// Sound because `size_hint` is forwarded from `vec::IntoIter`, which is exact.
impl ExactSizeIterator for DnsIterator {}
/// Defines the DNS host functions `resolve`, `drop_dns_iterator` and
/// `resolve_next` in the `lunatic::networking` module of `linker`.
///
/// # Errors
///
/// Propagates any error the linker returns while defining one of the functions.
pub fn register<T: NetworkingCtx + ErrorCtx + Send + 'static>(
    linker: &mut Linker<T>,
) -> Result<()> {
    const MODULE: &str = "lunatic::networking";

    // `resolve` awaits a DNS lookup, so it goes through the async wrapper.
    linker.func_wrap4_async(MODULE, "resolve", resolve)?;
    linker.func_wrap(MODULE, "drop_dns_iterator", drop_dns_iterator)?;
    linker.func_wrap(MODULE, "resolve_next", resolve_next)?;

    Ok(())
}
/// Asynchronously resolves a host name read from guest memory.
///
/// Reads `name_str_len` bytes starting at `name_str_ptr`, interprets them as a
/// UTF-8 host name, and looks it up with `tokio::net::lookup_host`. A
/// `timeout_duration` of `u64::MAX` waits with no timeout. Any other value is a
/// timeout in milliseconds.
///
/// The function writes a little-endian `u64` id to `id_u64_ptr` and returns:
/// * `0`: the lookup succeeded. A [`DnsIterator`] over the results is stored in
///   the DNS resources, and its id is written.
/// * `1`: the lookup failed. The error is stored in the error resources, and its
///   id is written.
/// * `9027`: the lookup timed out, and `0` is written.
///
/// # Errors
///
/// Traps if the name range lies outside guest memory, the name is not valid
/// UTF-8, or the id cannot be written to `id_u64_ptr`.
fn resolve<T: NetworkingCtx + ErrorCtx + Send>(
    mut caller: Caller<T>,
    name_str_ptr: u32,
    name_str_len: u32,
    timeout_duration: u64,
    id_u64_ptr: u32,
) -> Box<dyn Future<Output = Result<u32>> + Send + '_> {
    Box::new(async move {
        let memory = get_memory(&mut caller)?;
        let (memory_slice, state) = memory.data_and_store_mut(&mut caller);
        // Bounds-check the guest-supplied range without computing `ptr + len` in
        // `u32`. That sum can overflow on hostile input and would panic the host in
        // debug builds.
        let buffer = memory_slice
            .get(name_str_ptr as usize..)
            .and_then(|tail| tail.get(..name_str_len as usize))
            .or_trap("lunatic::network::resolve")?;
        let name = std::str::from_utf8(buffer)
            .or_trap("lunatic::network::resolve::not_valid_utf8_string")?;

        let lookup = tokio::net::lookup_host(name);
        // `u64::MAX` is the sentinel for "no timeout"; otherwise the value is in ms.
        let lookup_result = if timeout_duration == u64::MAX {
            Ok(lookup.await)
        } else {
            timeout(Duration::from_millis(timeout_duration), lookup).await
        };

        let (iter_or_error_id, result) = match lookup_result {
            Ok(Ok(sockets)) => {
                // `DnsIterator` wraps a `vec::IntoIter`, so the results must be collected.
                #[allow(clippy::needless_collect)]
                let id = state.dns_resources_mut().add(DnsIterator::new(
                    sockets.collect::<Vec<SocketAddr>>().into_iter(),
                ));
                (id, 0)
            }
            Ok(Err(error)) => {
                let error_id = state.error_resources_mut().add(error.into());
                (error_id, 1)
            }
            // The lookup did not complete before the timeout elapsed.
            Err(_) => (0, 9027),
        };

        // The memory handle fetched above is still valid; no need to look it up again.
        memory
            .write(
                &mut caller,
                id_u64_ptr as usize,
                &iter_or_error_id.to_le_bytes(),
            )
            .or_trap("lunatic::networking::resolve")?;
        Ok(result)
    })
}
/// Removes the DNS iterator identified by `dns_iter_id` from the caller's DNS
/// resources. The removed iterator is dropped immediately.
///
/// # Errors
///
/// Traps if the DNS resources' `remove` does not yield an entry for `dns_iter_id`.
fn drop_dns_iterator<T: NetworkingCtx>(mut caller: Caller<T>, dns_iter_id: u64) -> Result<()> {
    let dns_resources = caller.data_mut().dns_resources_mut();
    dns_resources
        .remove(dns_iter_id)
        .or_trap("lunatic::networking::drop_dns_iterator")?;
    Ok(())
}
/// Advances the DNS iterator `dns_iter_id` and writes the next address into
/// guest memory.
///
/// For an IPv4 address, the function writes:
/// * `4` (as `u32`) to `addr_type_u32_ptr`
/// * the 4 address octets to `addr_u8_ptr`
/// * the port to `port_u16_ptr`
///
/// `flow_info_u32_ptr` and `scope_id_u32_ptr` are left untouched in the IPv4 case.
///
/// For an IPv6 address, the function writes:
/// * `6` (as `u32`) to `addr_type_u32_ptr`
/// * the 16 address octets to `addr_u8_ptr`
/// * the port, flow info and scope id to their respective pointers
///
/// All integers are little-endian.
///
/// Returns `0` when an address was written and `1` when the iterator is exhausted.
///
/// # Errors
///
/// Traps if `dns_iter_id` is not found or if any write to guest memory fails.
fn resolve_next<T: NetworkingCtx>(
    mut caller: Caller<T>,
    dns_iter_id: u64,
    addr_type_u32_ptr: u32,
    addr_u8_ptr: u32,
    port_u16_ptr: u32,
    flow_info_u32_ptr: u32,
    scope_id_u32_ptr: u32,
) -> Result<u32> {
    let memory = get_memory(&mut caller)?;
    let next_addr = caller
        .data_mut()
        .dns_resources_mut()
        .get_mut(dns_iter_id)
        .or_trap("lunatic::networking::resolve_next")?
        .next();

    let socket_addr = match next_addr {
        Some(addr) => addr,
        None => return Ok(1),
    };

    // Every write uses the same trap message, so funnel them through one helper.
    let mut put = |ptr: u32, bytes: &[u8]| {
        memory
            .write(&mut caller, ptr as usize, bytes)
            .or_trap("lunatic::networking::resolve_next")
    };

    match socket_addr {
        SocketAddr::V4(v4) => {
            put(addr_type_u32_ptr, &4u32.to_le_bytes())?;
            put(addr_u8_ptr, &v4.ip().octets())?;
            put(port_u16_ptr, &v4.port().to_le_bytes())?;
        }
        SocketAddr::V6(v6) => {
            put(addr_type_u32_ptr, &6u32.to_le_bytes())?;
            put(addr_u8_ptr, &v6.ip().octets())?;
            put(port_u16_ptr, &v6.port().to_le_bytes())?;
            put(flow_info_u32_ptr, &v6.flowinfo().to_le_bytes())?;
            put(scope_id_u32_ptr, &v6.scope_id().to_le_bytes())?;
        }
    }
    Ok(0)
}