use astrid_core::session_token::{
HandshakeRequest, HandshakeResponse, PROTOCOL_VERSION, SessionToken,
};
use crate::engine::wasm::host::util;
use crate::engine::wasm::host_state::HostState;
use extism::{CurrentPlugin, Error, UserData, Val};
/// Maximum number of streams that may be registered in `HostState::active_streams`
/// at once. Both `astrid_net_accept_impl` and `astrid_net_poll_accept_impl` check it
/// before accepting and again just before registering a new stream.
const MAX_ACTIVE_STREAMS: usize = 8;
pub(crate) fn astrid_net_bind_unix_impl(
_: &mut CurrentPlugin,
_: &[Val],
outputs: &mut [Val],
user_data: UserData<HostState>,
) -> Result<(), Error> {
let ud = user_data.get()?;
let state = ud
.lock()
.map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
if let Some(ref gate) = state.security {
let capsule_id = state.capsule_id.as_str().to_owned();
let gate = gate.clone();
let handle = state.runtime_handle.clone();
let semaphore = state.host_semaphore.clone();
util::bounded_block_on(&handle, &semaphore, async move {
gate.check_net_bind(&capsule_id).await
})
.map_err(|e| Error::msg(format!("security denied net_bind: {e}")))?;
}
outputs[0] = Val::I64(1);
Ok(())
}
/// Host function: blocks until a client connection on the CLI socket listener passes
/// every admission check, registers it, and returns its stream handle.
///
/// Admission checks, in order:
/// 1. the stream cap (`MAX_ACTIVE_STREAMS`) is not already reached;
/// 2. (Unix only) the peer runs under the same effective UID as this process;
/// 3. when a session token is configured, the client completes `validate_handshake`.
///
/// Connections failing check 2 or 3 are dropped and the loop waits for the next one.
/// The cap is checked again after the handshake, just before registration.
///
/// On success `outputs[0]` receives the new handle as a decimal string and a
/// `client.v1.connected` IPC event is published (publish errors are ignored).
///
/// # Errors
/// - the host-state lock is poisoned;
/// - the stream cap is reached (before accepting or after the handshake);
/// - no CLI socket listener is present in `HostState`;
/// - `accept` fails with an I/O error;
/// - the cancel token fires while accepting or handshaking ("capsule unloading");
/// - the stream handle ID counter would overflow.
pub(crate) fn astrid_net_accept_impl(
    plugin: &mut CurrentPlugin,
    _: &[Val],
    outputs: &mut [Val],
    user_data: UserData<HostState>,
) -> Result<(), Error> {
    let ud = user_data.get()?;
    // Snapshot everything the accept loop needs; the host-state lock is released
    // at the end of this block, before any blocking socket I/O.
    let (listener_arc, rt_handle, cancel_token, session_token, host_semaphore) = {
        let state = ud
            .lock()
            .map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
        // Fail fast if the cap is already reached, before waiting on the listener.
        let stream_count = state.active_streams.len();
        if stream_count >= MAX_ACTIVE_STREAMS {
            tracing::warn!(
                max = MAX_ACTIVE_STREAMS,
                current = stream_count,
                "accept: connection cap reached, rejecting"
            );
            return Err(Error::msg(format!(
                "connection cap reached ({stream_count}/{MAX_ACTIVE_STREAMS})"
            )));
        }
        let listener = state
            .cli_socket_listener
            .clone()
            .ok_or_else(|| Error::msg("No CLI Socket Listener available in HostState"))?;
        (
            listener,
            state.runtime_handle.clone(),
            state.cancel_token.clone(),
            state.session_token.clone(),
            state.host_semaphore.clone(),
        )
    };
    // Keep accepting until one connection passes every check; rejected
    // connections are dropped and the loop waits for the next client.
    let stream = loop {
        let accept_result =
            util::bounded_block_on_cancellable(&rt_handle, &host_semaphore, &cancel_token, async {
                let l = listener_arc.lock().await;
                l.accept().await
            });
        // `None` means the cancel token fired before accept completed.
        let (stream, _addr) = match accept_result {
            Some(result) => result?,
            None => return Err(Error::msg("capsule unloading")),
        };
        // Reject peers whose UID differs from this process's effective UID.
        #[cfg(unix)]
        if let Err(reason) = verify_peer_credentials(&stream) {
            tracing::warn!(
                security_event = true,
                reason = %reason,
                "Rejected socket connection: peer credential check failed"
            );
            drop(stream);
            continue;
        }
        let mut stream = stream;
        // With a session token configured the client must complete the handshake
        // before the stream is handed to the capsule; without one, accept as-is.
        if let Some(ref token) = session_token {
            let handshake_result = util::bounded_block_on_cancellable(
                &rt_handle,
                &host_semaphore,
                &cancel_token,
                validate_handshake(&mut stream, token),
            );
            match handshake_result {
                None => return Err(Error::msg("capsule unloading")),
                Some(Ok(())) => break stream,
                Some(Err(reason)) => {
                    tracing::warn!(
                        security_event = true,
                        reason = %reason,
                        "Rejected socket connection: handshake failed"
                    );
                    drop(stream);
                    continue;
                },
            }
        } else {
            break stream;
        }
    };
    // Re-check the cap: the host-state lock was not held while accepting and
    // handshaking, so the stream count may have changed since the first check.
    let mut state = ud
        .lock()
        .map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
    let stream_count = state.active_streams.len();
    if stream_count >= MAX_ACTIVE_STREAMS {
        tracing::warn!(
            max = MAX_ACTIVE_STREAMS,
            current = stream_count,
            "accept: connection cap reached post-handshake, dropping authenticated stream"
        );
        drop(stream);
        return Err(Error::msg(format!(
            "connection cap reached ({stream_count}/{MAX_ACTIVE_STREAMS})"
        )));
    }
    // Allocate the next handle ID; overflow is reported as an error instead of wrapping.
    let handle_id = state.next_stream_id;
    state.next_stream_id = state
        .next_stream_id
        .checked_add(1)
        .ok_or_else(|| Error::msg("stream handle ID space exhausted"))?;
    debug_assert!(
        !state.active_streams.contains_key(&handle_id),
        "stream handle ID collision"
    );
    state.active_streams.insert(
        handle_id,
        std::sync::Arc::new(tokio::sync::Mutex::new(stream)),
    );
    let connected_msg = astrid_events::ipc::IpcMessage::new(
        "client.v1.connected",
        astrid_events::ipc::IpcPayload::Connect,
        state.capsule_uuid,
    );
    // Best-effort notification: a publish failure does not fail the accept.
    let _ = state.event_bus.publish(astrid_events::AstridEvent::Ipc {
        metadata: astrid_events::EventMetadata::new("net_accept"),
        message: connected_msg,
    });
    // The handle is returned to the guest as a decimal string.
    let mem = plugin.memory_new(handle_id.to_string())?;
    outputs[0] = plugin.memory_to_val(mem);
    Ok(())
}
pub(crate) fn astrid_net_read_impl(
plugin: &mut CurrentPlugin,
inputs: &[Val],
outputs: &mut [Val],
user_data: UserData<HostState>,
) -> Result<(), Error> {
let handle_str = util::get_safe_string(plugin, &inputs[0], 1024)?;
let handle_id: u64 = handle_str
.parse()
.map_err(|_| Error::msg("Invalid stream handle"))?;
let ud = user_data.get()?;
let (stream_arc, rt_handle, cancel_token, host_semaphore) = {
let state = ud
.lock()
.map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
let stream = state
.active_streams
.get(&handle_id)
.ok_or_else(|| Error::msg("Stream handle not found"))?
.clone();
(
stream,
state.runtime_handle.clone(),
state.cancel_token.clone(),
state.host_semaphore.clone(),
)
};
use tokio::io::AsyncReadExt;
let result =
util::bounded_block_on_cancellable(&rt_handle, &host_semaphore, &cancel_token, async {
let mut stream = stream_arc.lock().await;
let mut len_buf = [0u8; 4];
match tokio::time::timeout(
std::time::Duration::from_millis(50),
stream.read_exact(&mut len_buf),
)
.await
{
Err(_) => return Ok(Vec::new()), Ok(Err(e)) => return Err(Error::msg(format!("socket read error: {e}"))),
Ok(Ok(_)) => {}, }
let len = u32::from_be_bytes(len_buf) as usize;
if len > 10 * 1024 * 1024 {
return Err(Error::msg("Payload too large (max 10MB)"));
}
let mut payload = vec![0u8; len];
let timeout_ms = 5000 + (len as u64 / 1024);
tokio::time::timeout(
std::time::Duration::from_millis(timeout_ms),
stream.read_exact(&mut payload),
)
.await
.map_err(|_| Error::msg("Payload read timed out"))?
.map_err(|e| Error::msg(format!("socket payload read error: {e}")))?;
Ok(payload)
});
let result = match result {
Some(r) => r,
None => Ok(Vec::new()),
};
let result = result?;
if result.is_empty() {
let mem = plugin.memory_new("")?;
outputs[0] = plugin.memory_to_val(mem);
} else {
let mem = plugin.memory_new(&result)?;
outputs[0] = plugin.memory_to_val(mem);
}
Ok(())
}
pub(crate) fn astrid_net_write_impl(
plugin: &mut CurrentPlugin,
inputs: &[Val],
_: &mut [Val],
user_data: UserData<HostState>,
) -> Result<(), Error> {
let handle_str = util::get_safe_string(plugin, &inputs[0], 1024)?;
let handle_id: u64 = handle_str
.parse()
.map_err(|_| Error::msg("Invalid stream handle"))?;
let data = util::get_safe_bytes(plugin, &inputs[1], 10 * 1024 * 1024)?;
let ud = user_data.get()?;
let (stream_arc, rt_handle, host_semaphore, cancel_token) = {
let state = ud
.lock()
.map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
let stream = state
.active_streams
.get(&handle_id)
.ok_or_else(|| Error::msg("Stream handle not found"))?
.clone();
(
stream,
state.runtime_handle.clone(),
state.host_semaphore.clone(),
state.cancel_token.clone(),
)
};
use tokio::io::AsyncWriteExt;
let result =
util::bounded_block_on_cancellable(&rt_handle, &host_semaphore, &cancel_token, async {
let mut stream = stream_arc.lock().await;
let len = u32::try_from(data.len())
.map_err(|_| std::io::Error::other("write payload too large for length prefix"))?;
stream.write_all(&len.to_be_bytes()).await?;
stream.write_all(&data).await?;
stream.flush().await?;
Ok::<(), std::io::Error>(())
});
match result {
Some(inner) => inner?,
None => return Err(Error::msg("capsule unloading")),
}
Ok(())
}
pub(crate) fn astrid_net_close_stream_impl(
plugin: &mut CurrentPlugin,
inputs: &[Val],
_: &mut [Val],
user_data: UserData<HostState>,
) -> Result<(), Error> {
let handle_str = util::get_safe_string(plugin, &inputs[0], 1024)?;
let handle_id: u64 = handle_str
.parse()
.map_err(|_| Error::msg("Invalid stream handle"))?;
let ud = user_data.get()?;
let mut state = ud
.lock()
.map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
if state.active_streams.remove(&handle_id).is_some() {
let msg = astrid_events::ipc::IpcMessage::new(
"client.v1.disconnect",
astrid_events::ipc::IpcPayload::Disconnect {
reason: Some("stream_closed".to_string()),
},
state.capsule_uuid,
);
let _ = state.event_bus.publish(astrid_events::AstridEvent::Ipc {
metadata: astrid_events::EventMetadata::new("net_close_stream"),
message: msg,
});
}
Ok(())
}
pub(crate) fn astrid_net_poll_accept_impl(
plugin: &mut CurrentPlugin,
_: &[Val],
outputs: &mut [Val],
user_data: UserData<HostState>,
) -> Result<(), Error> {
let ud = user_data.get()?;
let (listener_arc, rt_handle, cancel_token, session_token, host_semaphore, stream_count) = {
let state = ud
.lock()
.map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
let listener = state
.cli_socket_listener
.clone()
.ok_or_else(|| Error::msg("No CLI Socket Listener available in HostState"))?;
(
listener,
state.runtime_handle.clone(),
state.cancel_token.clone(),
state.session_token.clone(),
state.host_semaphore.clone(),
state.active_streams.len(),
)
};
if stream_count >= MAX_ACTIVE_STREAMS {
tracing::warn!(
max = MAX_ACTIVE_STREAMS,
current = stream_count,
"poll_accept: connection cap reached, rejecting"
);
let mem = plugin.memory_new("")?;
outputs[0] = plugin.memory_to_val(mem);
return Ok(());
}
let accept_result =
util::bounded_block_on_cancellable(&rt_handle, &host_semaphore, &cancel_token, async {
let l = listener_arc.lock().await;
tokio::time::timeout(std::time::Duration::from_millis(10), l.accept()).await
});
let (stream, _addr) = match accept_result {
None => {
let mem = plugin.memory_new("")?;
outputs[0] = plugin.memory_to_val(mem);
return Ok(());
},
Some(Err(_)) => {
let mem = plugin.memory_new("")?;
outputs[0] = plugin.memory_to_val(mem);
return Ok(());
},
Some(Ok(Err(e))) => return Err(Error::msg(format!("accept error: {e}"))),
Some(Ok(Ok(pair))) => pair,
};
#[cfg(unix)]
if let Err(reason) = verify_peer_credentials(&stream) {
tracing::warn!(
security_event = true,
reason = %reason,
"poll_accept: rejected connection (peer credential check failed)"
);
drop(stream);
let mem = plugin.memory_new("")?;
outputs[0] = plugin.memory_to_val(mem);
return Ok(());
}
let mut stream = stream;
if let Some(ref token) = session_token {
let handshake_result = util::bounded_block_on_cancellable(
&rt_handle,
&host_semaphore,
&cancel_token,
validate_handshake(&mut stream, token),
);
match handshake_result {
None => {
let mem = plugin.memory_new("")?;
outputs[0] = plugin.memory_to_val(mem);
return Ok(());
},
Some(Err(reason)) => {
tracing::warn!(
security_event = true,
reason = %reason,
"poll_accept: rejected connection (handshake failed)"
);
drop(stream);
let mem = plugin.memory_new("")?;
outputs[0] = plugin.memory_to_val(mem);
return Ok(());
},
Some(Ok(())) => {},
}
}
let mut state = ud
.lock()
.map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
if state.active_streams.len() >= MAX_ACTIVE_STREAMS {
drop(stream);
let mem = plugin.memory_new("")?;
outputs[0] = plugin.memory_to_val(mem);
return Ok(());
}
let handle_id = state.next_stream_id;
state.next_stream_id = state
.next_stream_id
.checked_add(1)
.ok_or_else(|| Error::msg("stream handle ID space exhausted"))?;
debug_assert!(
!state.active_streams.contains_key(&handle_id),
"stream handle ID collision"
);
state.active_streams.insert(
handle_id,
std::sync::Arc::new(tokio::sync::Mutex::new(stream)),
);
let connected_msg = astrid_events::ipc::IpcMessage::new(
"client.v1.connected",
astrid_events::ipc::IpcPayload::Connect,
state.capsule_uuid,
);
let _ = state.event_bus.publish(astrid_events::AstridEvent::Ipc {
metadata: astrid_events::EventMetadata::new("net_poll_accept"),
message: connected_msg,
});
let mem = plugin.memory_new(handle_id.to_string())?;
outputs[0] = plugin.memory_to_val(mem);
Ok(())
}
/// Bound applied separately to each handshake I/O step: the length-prefix read, the
/// payload read, and writing the handshake response.
const HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
/// Largest accepted handshake payload in bytes. A larger length prefix is rejected
/// before the payload buffer is allocated.
const MAX_HANDSHAKE_SIZE: usize = 4096;
/// Performs the server side of the socket handshake on a freshly accepted stream.
///
/// Steps:
/// 1. Read a big-endian `u32` length prefix (at most `MAX_HANDSHAKE_SIZE`).
/// 2. Read a JSON `HandshakeRequest` of that length.
/// 3. Check the protocol version.
/// 4. Check the session token, compared via `SessionToken::ct_eq`.
///
/// Version or token rejections send a best-effort `HandshakeResponse::error` first.
/// Malformed requests (oversized, truncated, invalid JSON) are rejected without a
/// response. On success `HandshakeResponse::ok()` is sent. Each read, and the
/// response write, is bounded by `HANDSHAKE_TIMEOUT`.
///
/// # Errors
/// Returns a human-readable reason (intended for logging) on:
/// - timeout or I/O failure;
/// - an oversized or malformed payload;
/// - a protocol version mismatch;
/// - an invalid session token;
/// - failure to send the success response.
async fn validate_handshake(
    stream: &mut tokio::net::UnixStream,
    expected_token: &SessionToken,
) -> Result<(), String> {
    use tokio::io::AsyncReadExt;

    let mut len_buf = [0u8; 4];
    tokio::time::timeout(HANDSHAKE_TIMEOUT, stream.read_exact(&mut len_buf))
        .await
        .map_err(|_| format!("handshake timed out ({}s)", HANDSHAKE_TIMEOUT.as_secs()))?
        .map_err(|e| format!("handshake read error: {e}"))?;
    let len = u32::from_be_bytes(len_buf) as usize;
    if len > MAX_HANDSHAKE_SIZE {
        return Err(format!("handshake too large: {len} bytes"));
    }
    let mut payload = vec![0u8; len];
    tokio::time::timeout(HANDSHAKE_TIMEOUT, stream.read_exact(&mut payload))
        .await
        .map_err(|_| "handshake payload timed out".to_string())?
        .map_err(|e| format!("handshake payload read error: {e}"))?;
    let request: HandshakeRequest =
        serde_json::from_slice(&payload).map_err(|e| format!("invalid handshake JSON: {e}"))?;

    if request.protocol_version != PROTOCOL_VERSION {
        let reason = format!(
            "Protocol version mismatch (client={}, server={}). \
             Restart the daemon with `astrid daemon restart`.",
            request.protocol_version, PROTOCOL_VERSION,
        );
        if let Err(e) =
            send_handshake_response_timed(stream, &HandshakeResponse::error(&reason)).await
        {
            tracing::warn!(error = %e, "Failed to send handshake error response for protocol mismatch");
        }
        return Err(reason);
    }

    // An unparseable token and a wrong token are deliberately handled the same way,
    // so the client cannot distinguish the two cases.
    let authenticated = SessionToken::from_hex(&request.token)
        .map(|client_token| expected_token.ct_eq(&client_token))
        .unwrap_or(false);
    if !authenticated {
        if let Err(e) = send_handshake_response_timed(
            stream,
            &HandshakeResponse::error("authentication failed"),
        )
        .await
        {
            tracing::warn!(error = %e, "Failed to send handshake error response");
        }
        return Err("invalid session token".to_string());
    }

    send_handshake_response_timed(stream, &HandshakeResponse::ok())
        .await
        .map_err(|e| format!("failed to send handshake response: {e}"))?;

    // The client version string is untrusted. Cap its length and strip control
    // characters (e.g. newlines) so it cannot forge or split log lines.
    let safe_version: String = request
        .client_version
        .chars()
        .filter(|c| !c.is_control())
        .take(64)
        .collect();
    tracing::info!(
        client_version = %safe_version,
        "Socket handshake succeeded"
    );
    Ok(())
}
/// Sends `response` with `send_handshake_response`, bounded by `HANDSHAKE_TIMEOUT`.
///
/// # Errors
/// Propagates the underlying write error, or returns an `io::Error` of kind `Other`
/// when the deadline elapses.
async fn send_handshake_response_timed(
    stream: &mut tokio::net::UnixStream,
    response: &HandshakeResponse,
) -> Result<(), std::io::Error> {
    match tokio::time::timeout(HANDSHAKE_TIMEOUT, send_handshake_response(stream, response)).await
    {
        Ok(sent) => sent,
        Err(_elapsed) => Err(std::io::Error::other(
            "handshake response write timed out (5s)",
        )),
    }
}
/// Writes `response` as one frame: a big-endian `u32` JSON length, then the JSON
/// bytes, then a flush.
///
/// # Errors
/// Returns an `io::Error` if serialization fails, if the encoded response does not
/// fit a `u32` length, or if any write or flush fails.
async fn send_handshake_response(
    stream: &mut tokio::net::UnixStream,
    response: &HandshakeResponse,
) -> Result<(), std::io::Error> {
    use tokio::io::AsyncWriteExt;

    let body = match serde_json::to_vec(response) {
        Ok(encoded) => encoded,
        Err(e) => {
            return Err(std::io::Error::other(format!(
                "serialize handshake response: {e}"
            )))
        },
    };
    let Ok(body_len) = u32::try_from(body.len()) else {
        return Err(std::io::Error::other("handshake response too large"));
    };
    stream.write_all(&body_len.to_be_bytes()).await?;
    stream.write_all(&body).await?;
    stream.flush().await
}
/// Accepts the peer only if its UID equals this process's effective UID.
///
/// # Errors
/// Returns a reason string when the peer credentials cannot be read, or when the
/// peer UID does not match.
#[cfg(unix)]
fn verify_peer_credentials(stream: &tokio::net::UnixStream) -> Result<(), String> {
    let cred = stream
        .peer_cred()
        .map_err(|e| format!("failed to check peer credentials: {e}"))?;
    let (peer_uid, my_uid) = (cred.uid(), nix::unistd::geteuid().as_raw());
    if peer_uid == my_uid {
        Ok(())
    } else {
        Err(format!(
            "peer UID {peer_uid} does not match daemon UID {my_uid}"
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the connection cap so that changing it is a deliberate, reviewed edit.
    #[test]
    fn max_active_streams_pinned() {
        let expected: usize = 8;
        assert_eq!(
            MAX_ACTIVE_STREAMS, expected,
            "MAX_ACTIVE_STREAMS changed; update this test deliberately"
        );
    }
}