use std::sync::Arc;
use wasmtime::component::Resource;
use crate::engine::wasm::bindings::astrid::net::host::{
self as net, ErrorCode, TcpListener, TcpStream, UdpSocket, UnixListener,
};
use crate::engine::wasm::host::http::is_safe_ip;
use crate::engine::wasm::host::util;
use crate::engine::wasm::host_state::{HostState, NetStream, TcpStreamSlot};
mod handshake;
mod stream;
mod tcp_listener;
mod tcp_stream;
mod udp_socket;
mod unix_listener;
use stream::CONNECT_TIMEOUT;
/// Upper bound on the number of streams counted in `HostState::net_stream_count`.
///
/// `connect_tcp` refuses to open another stream with `ErrorCode::Quota` once the
/// counter has reached this value.
pub(super) const MAX_ACTIVE_STREAMS: usize = 8;
/// Data-less marker that `bind_unix` pushes into the resource table to back the
/// guest-visible `UnixListener` handle.
///
/// NOTE(review): the listener itself presumably lives in
/// `HostState::cli_socket_listener` — confirm against the `unix_listener` module.
pub(super) struct UnixListenerSlot;
/// Data-less marker type for a TCP listener resource.
///
/// Nothing in this file constructs it (`bind_tcp` always returns
/// `ErrorCode::CapabilityDenied`), which is why `dead_code` is allowed.
#[allow(dead_code)]
pub(super) struct TcpListenerSlot;
/// Data-less marker type for a UDP socket resource.
///
/// Nothing in this file constructs it (`udp_bind` always returns
/// `ErrorCode::CapabilityDenied`), which is why `dead_code` is allowed.
#[allow(dead_code)]
pub(super) struct UdpSocketSlot;
/// Performs cheap syntactic checks on a guest-supplied host string.
///
/// A host is accepted when it is between 1 and 255 bytes long (inclusive) and
/// contains no NUL byte.
///
/// # Errors
///
/// Returns `ErrorCode::AddressNotAvailable` for every rejected input.
pub(super) fn validate_host(host: &str) -> Result<(), ErrorCode> {
    let length_ok = (1..=255).contains(&host.len());
    let has_nul = host.as_bytes().contains(&0);
    if length_ok && !has_nul {
        Ok(())
    } else {
        Err(ErrorCode::AddressNotAvailable)
    }
}
/// Translates a `std::io::Error` into the guest-facing `ErrorCode`.
///
/// Reset, aborted and broken-pipe conditions are all reported as
/// `ErrorCode::ConnectionReset`. Kinds without a dedicated mapping are carried
/// through as `ErrorCode::Unknown` with the error's display text.
pub(super) fn map_io_err(err: std::io::Error) -> ErrorCode {
    use std::io::ErrorKind as Kind;

    let kind = err.kind();
    if matches!(
        kind,
        Kind::ConnectionReset | Kind::ConnectionAborted | Kind::BrokenPipe
    ) {
        return ErrorCode::ConnectionReset;
    }
    match kind {
        Kind::WouldBlock => ErrorCode::WouldBlock,
        Kind::ConnectionRefused => ErrorCode::ConnectionRefused,
        Kind::TimedOut => ErrorCode::Timeout,
        Kind::AddrInUse => ErrorCode::AddressInUse,
        Kind::AddrNotAvailable => ErrorCode::AddressNotAvailable,
        _ => ErrorCode::Unknown(err.to_string()),
    }
}
/// Emits one debug-level event on the `astrid.audit.net` target for a network call.
///
/// Every event carries the capsule id, the effective principal and the operation
/// name `op` (as field `fn`). A successful `result` additionally records `bytes`;
/// a failed one records the error's `Debug` form instead.
pub(super) fn audit_net<T, E: std::fmt::Debug>(
    state: &HostState,
    op: &'static str,
    bytes: u64,
    result: &Result<T, E>,
) {
    // These two locals must keep their names: the `%name` shorthand below uses
    // the variable name as the emitted field name.
    let capsule_id = state.capsule_id.as_str();
    let principal = state.effective_principal();
    if let Err(e) = result {
        tracing::debug!(
            target: "astrid.audit.net",
            %capsule_id,
            %principal,
            fn = op,
            error = ?e,
            "audit",
        );
    } else {
        tracing::debug!(
            target: "astrid.audit.net",
            %capsule_id,
            %principal,
            fn = op,
            bytes,
            "audit",
        );
    }
}
/// Fetches a clone of the `NetStream` stored under resource id `rep`.
///
/// # Errors
///
/// Returns `ErrorCode::InvalidHandle` when the table lookup fails for any reason.
pub(super) fn net_stream(
    table: &wasmtime::component::ResourceTable,
    rep: u32,
) -> Result<NetStream, ErrorCode> {
    let handle = Resource::<NetStream>::new_borrow(rep);
    match table.get(&handle) {
        Ok(entry) => Ok(entry.clone()),
        Err(_) => Err(ErrorCode::InvalidHandle),
    }
}
/// Runs `op` against the mutable TCP slot stored under resource id `rep`.
///
/// # Errors
///
/// * `ErrorCode::InvalidHandle` — the table lookup failed.
/// * `ErrorCode::NotTcp` — the entry is a Unix stream; `op` is not invoked.
pub(super) fn with_tcp_slot_mut<F>(
    table: &mut wasmtime::component::ResourceTable,
    rep: u32,
    op: F,
) -> Result<(), ErrorCode>
where
    F: FnOnce(&mut TcpStreamSlot),
{
    let handle = Resource::<NetStream>::new_borrow(rep);
    match table.get_mut(&handle) {
        Err(_) => Err(ErrorCode::InvalidHandle),
        Ok(NetStream::Unix(_)) => Err(ErrorCode::NotTcp),
        Ok(NetStream::Tcp(slot)) => {
            op(slot);
            Ok(())
        },
    }
}
/// Locks the TCP stream stored under resource id `rep` and runs `op` on it.
///
/// The lock is taken inside `util::bounded_block_on_cancellable`, driven by the
/// host's runtime handle, blocking semaphore and cancel token.
///
/// # Errors
///
/// * `ErrorCode::InvalidHandle` — the table lookup failed.
/// * `ErrorCode::NotTcp` — the entry is a Unix stream.
/// * `ErrorCode::Closed` — the blocking helper returned `None` instead of a result.
/// * Whatever error `op` itself returns.
pub(super) fn with_tcp_stream<T, F>(state: &mut HostState, rep: u32, op: F) -> Result<T, ErrorCode>
where
    F: FnOnce(&tokio::net::TcpStream) -> Result<T, ErrorCode>,
{
    let slot = match net_stream(&state.resource_table, rep)? {
        NetStream::Tcp(slot) => slot,
        NetStream::Unix(_) => return Err(ErrorCode::NotTcp),
    };
    let handle = state.runtime_handle.clone();
    let permits = state.blocking_semaphore.clone();
    let cancel = state.cancel_token.clone();
    let outcome = util::bounded_block_on_cancellable(&handle, &permits, &cancel, async move {
        let guard = slot.stream.lock().await;
        op(&guard)
    });
    match outcome {
        Some(inner) => inner,
        None => Err(ErrorCode::Closed),
    }
}
impl net::Host for HostState {
fn bind_unix(&mut self) -> Result<Resource<UnixListener>, ErrorCode> {
if let Some(ref gate) = self.security {
let capsule_id = self.capsule_id.as_str().to_owned();
let gate = gate.clone();
let handle = self.runtime_handle.clone();
let semaphore = self.blocking_semaphore.clone();
let check = util::bounded_block_on(&handle, &semaphore, async move {
gate.check_net_bind(&capsule_id).await
});
if check.is_err() {
return Err(ErrorCode::CapabilityDenied);
}
}
if self.cli_socket_listener.is_none() {
return Err(ErrorCode::Unknown(
"no pre-bound Unix listener configured".to_string(),
));
}
let res = self
.resource_table
.push(UnixListenerSlot)
.map_err(|e| ErrorCode::Unknown(format!("resource table: {e}")))?;
Ok(Resource::new_own(res.rep()))
}
fn bind_tcp(&mut self, _host: String, _port: u16) -> Result<Resource<TcpListener>, ErrorCode> {
Err(ErrorCode::CapabilityDenied)
}
fn connect_tcp(&mut self, host: String, port: u16) -> Result<Resource<TcpStream>, ErrorCode> {
validate_host(&host)?;
if let Some(ref gate) = self.security {
let capsule_id = self.capsule_id.as_str().to_owned();
let host_for_check = host.clone();
let gate = gate.clone();
let rt = self.runtime_handle.clone();
let semaphore = self.blocking_semaphore.clone();
let check = util::bounded_block_on(&rt, &semaphore, async move {
gate.check_net_connect(&capsule_id, &host_for_check, port)
.await
});
if check.is_err() {
return Err(ErrorCode::CapabilityDenied);
}
}
if self.net_stream_count >= MAX_ACTIVE_STREAMS {
return Err(ErrorCode::Quota);
}
let rt_handle = self.runtime_handle.clone();
let blocking_semaphore = self.blocking_semaphore.clone();
let cancel_token = self.cancel_token.clone();
let connect_result = util::bounded_block_on_cancellable(
&rt_handle,
&blocking_semaphore,
&cancel_token,
async {
tokio::time::timeout(CONNECT_TIMEOUT, async {
let addrs: Vec<std::net::SocketAddr> =
tokio::net::lookup_host((host.as_str(), port))
.await
.map_err(|_| ErrorCode::NameUnresolvable)?
.collect();
if addrs.is_empty() {
return Err(ErrorCode::NameUnresolvable);
}
for addr in &addrs {
if !is_safe_ip(addr.ip()) {
return Err(ErrorCode::AirlockRejected);
}
}
tokio::net::TcpStream::connect(&addrs[..])
.await
.map_err(map_io_err)
})
.await
.map_err(|_| ErrorCode::Timeout)
.and_then(|inner| inner)
},
);
let stream = match connect_result {
Some(Ok(s)) => s,
Some(Err(e)) => return Err(e),
None => return Err(ErrorCode::Closed),
};
if self.net_stream_count >= MAX_ACTIVE_STREAMS {
drop(stream);
return Err(ErrorCode::Quota);
}
let net_stream = NetStream::Tcp(TcpStreamSlot {
stream: Arc::new(tokio::sync::Mutex::new(stream)),
read_timeout: None,
write_timeout: None,
});
let res = self
.resource_table
.push(net_stream)
.map_err(|e| ErrorCode::Unknown(format!("resource table: {e}")))?;
self.net_stream_count += 1;
let result: Result<Resource<TcpStream>, ErrorCode> = Ok(Resource::new_own(res.rep()));
audit_net(self, "astrid:net/host.connect-tcp", 0, &result);
result
}
fn udp_bind(&mut self, _host: String, _port: u16) -> Result<Resource<UdpSocket>, ErrorCode> {
Err(ErrorCode::CapabilityDenied)
}
fn lookup_host(&mut self, host: String) -> Result<Vec<String>, ErrorCode> {
validate_host(&host)?;
if let Some(ref gate) = self.security {
let capsule_id = self.capsule_id.as_str().to_owned();
let host_for_check = host.clone();
let gate = gate.clone();
let rt = self.runtime_handle.clone();
let semaphore = self.blocking_semaphore.clone();
let check = util::bounded_block_on(&rt, &semaphore, async move {
gate.check_net_connect(&capsule_id, &host_for_check, 0)
.await
});
if check.is_err() {
return Err(ErrorCode::CapabilityDenied);
}
}
let rt = self.runtime_handle.clone();
let sem = self.blocking_semaphore.clone();
let host_owned = host.clone();
let resolved: Vec<std::net::SocketAddr> = util::bounded_block_on(&rt, &sem, async move {
tokio::net::lookup_host((host_owned.as_str(), 0))
.await
.map(|it| it.collect::<Vec<_>>())
})
.map_err(|_| ErrorCode::NameUnresolvable)?;
let mut out = Vec::new();
for addr in resolved {
if is_safe_ip(addr.ip()) {
out.push(addr.to_string());
}
}
Ok(out)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the stream cap so that changing it requires touching this test too.
    #[test]
    fn max_active_streams_pinned() {
        assert_eq!(MAX_ACTIVE_STREAMS, 8);
    }

    /// Ordinary DNS names and literal IPv4/IPv6 addresses pass validation.
    #[test]
    fn validate_host_accepts_normal_names() {
        for host in [
            "example.com",
            "fulcrum.unicity.network",
            "127.0.0.1",
            "::1",
        ] {
            assert!(validate_host(host).is_ok());
        }
    }

    /// The empty string is never a valid host.
    #[test]
    fn validate_host_rejects_empty() {
        let outcome = validate_host("");
        assert!(outcome.is_err());
    }

    /// An embedded NUL byte is rejected.
    #[test]
    fn validate_host_rejects_null_bytes() {
        let outcome = validate_host("evil\0.com");
        assert!(outcome.is_err());
    }

    /// 256 bytes is one past the accepted maximum.
    #[test]
    fn validate_host_rejects_overlength() {
        let too_long: String = std::iter::repeat('a').take(256).collect();
        assert!(validate_host(&too_long).is_err());
    }

    /// Exactly 255 bytes is still accepted.
    #[test]
    fn validate_host_accepts_max_length() {
        let at_limit: String = std::iter::repeat('a').take(255).collect();
        assert!(validate_host(&at_limit).is_ok());
    }
}