use std::time::Duration;
use anyhow::Result;
use tokio::io::{AsyncRead, AsyncWrite};
use arcbox_protocol::agent::PingResponse;
use crate::rpc::{
AGENT_VERSION, ErrorResponse, RpcRequest, RpcResponse, parse_request, read_message,
write_response,
};
use super::disk::handle_disk_trim;
use super::kubernetes::{
handle_delete_kubernetes, handle_kubernetes_kubeconfig, handle_kubernetes_status,
handle_start_kubernetes, handle_stop_kubernetes,
};
use super::runtime::{handle_ensure_runtime, handle_runtime_status};
use super::sandbox::handle_sandbox_message;
use super::system_info::handle_get_system_info;
use super::vsock::is_peer_closed_error;
enum RequestResult {
Single(RpcResponse),
}
pub(super) async fn handle_connection<S>(mut stream: S) -> Result<()>
where
S: AsyncRead + AsyncWrite + Unpin,
{
loop {
let (msg_type, trace_id, payload) = match read_message(&mut stream).await {
Ok(msg) => msg,
Err(e) => {
if is_peer_closed_error(&e) {
tracing::debug!("Client disconnected");
return Ok(());
}
return Err(e);
}
};
tracing::info!(
trace_id = %trace_id,
"Received message type {:?}, payload_len={}",
msg_type,
payload.len()
);
if msg_type.is_sandbox_request() {
if let Err(e) = handle_sandbox_message(&mut stream, msg_type, &trace_id, &payload).await
{
tracing::warn!(trace_id = %trace_id, error = %e, "sandbox handler error");
}
continue;
}
let result = match parse_request(msg_type, &payload) {
Ok(request) => handle_request(request).await,
Err(e) => {
tracing::warn!(trace_id = %trace_id, "Failed to parse request: {}", e);
RequestResult::Single(RpcResponse::Error(ErrorResponse::new(
400,
format!("invalid request: {}", e),
)))
}
};
match result {
RequestResult::Single(response) => {
write_response(&mut stream, &response, &trace_id).await?;
}
}
}
}
async fn handle_request(request: RpcRequest) -> RequestResult {
match request {
RpcRequest::Ping(req) => RequestResult::Single(handle_ping(req)),
RpcRequest::GetSystemInfo => RequestResult::Single(handle_get_system_info().await),
RpcRequest::EnsureRuntime(req) => RequestResult::Single(handle_ensure_runtime(req).await),
RpcRequest::RuntimeStatus(req) => RequestResult::Single(handle_runtime_status(req).await),
RpcRequest::StartKubernetes(req) => {
RequestResult::Single(handle_start_kubernetes(req).await)
}
RpcRequest::StopKubernetes(req) => RequestResult::Single(handle_stop_kubernetes(req).await),
RpcRequest::DeleteKubernetes(req) => {
RequestResult::Single(handle_delete_kubernetes(req).await)
}
RpcRequest::KubernetesStatus(req) => {
RequestResult::Single(handle_kubernetes_status(req).await)
}
RpcRequest::KubernetesKubeconfig(req) => {
RequestResult::Single(handle_kubernetes_kubeconfig(req).await)
}
RpcRequest::Shutdown(req) => RequestResult::Single(handle_shutdown(req)),
RpcRequest::MmapReadFile(req) => RequestResult::Single(handle_mmap_read_file(req)),
RpcRequest::DiskTrim(_) => RequestResult::Single(handle_disk_trim().await),
}
}
fn handle_shutdown(req: arcbox_protocol::agent::ShutdownRequest) -> RpcResponse {
let grace = if req.timeout_seconds == 0 {
Duration::from_secs(u64::from(
arcbox_constants::timeouts::GUEST_SHUTDOWN_GRACE_SECS,
))
} else {
Duration::from_secs(u64::from(req.timeout_seconds))
};
tracing::info!(grace_secs = grace.as_secs(), "Shutdown requested by host");
std::thread::spawn(move || {
std::thread::sleep(Duration::from_millis(100));
crate::shutdown::poweroff(grace);
});
RpcResponse::Shutdown(arcbox_protocol::agent::ShutdownResponse { accepted: true })
}
pub(super) fn sync_clock_from_host(timestamp_secs: i64) -> bool {
if timestamp_secs <= 0 {
return false;
}
let ts = libc::timespec {
tv_sec: timestamp_secs,
tv_nsec: 0,
};
let ret = unsafe { libc::clock_settime(libc::CLOCK_REALTIME, &ts) };
if ret != 0 {
tracing::warn!(
timestamp_secs,
error = %std::io::Error::last_os_error(),
"failed to set clock from host"
);
return false;
}
true
}
fn handle_mmap_read_file(req: arcbox_protocol::agent::MmapReadFileRequest) -> RpcResponse {
use std::os::unix::io::AsRawFd;
const PAGE_SIZE: u64 = 4096;
if req.length == 0 {
return RpcResponse::MmapReadFile(arcbox_protocol::agent::MmapReadFileResponse {
data: Vec::new(),
bytes_read: 0,
});
}
if req.length > 64 * 1024 * 1024 {
return RpcResponse::Error(crate::rpc::ErrorResponse::new(
libc::EINVAL,
"MmapReadFile length exceeds 64 MiB cap",
));
}
let file = match std::fs::OpenOptions::new().read(true).open(&req.path) {
Ok(f) => f,
Err(e) => {
return RpcResponse::Error(crate::rpc::ErrorResponse::new(
e.raw_os_error().unwrap_or(libc::EIO),
format!("open {}: {e}", req.path),
));
}
};
let page_base = req.offset & !(PAGE_SIZE - 1);
let inner_off = (req.offset - page_base) as usize;
let total_len = inner_off + req.length as usize;
let mapped_len = total_len.div_ceil(PAGE_SIZE as usize) * PAGE_SIZE as usize;
let ptr = unsafe {
libc::mmap(
std::ptr::null_mut(),
mapped_len,
libc::PROT_READ,
libc::MAP_SHARED,
file.as_raw_fd(),
#[allow(clippy::cast_possible_wrap)]
{
page_base as libc::off_t
},
)
};
if ptr == libc::MAP_FAILED {
let e = std::io::Error::last_os_error();
return RpcResponse::Error(crate::rpc::ErrorResponse::new(
e.raw_os_error().unwrap_or(libc::EIO),
format!("mmap {}: {e}", req.path),
));
}
let data = unsafe {
let slice =
std::slice::from_raw_parts(ptr.cast::<u8>().add(inner_off), req.length as usize);
slice.to_vec()
};
unsafe {
libc::munmap(ptr, mapped_len);
}
let bytes_read = data.len() as u64;
RpcResponse::MmapReadFile(arcbox_protocol::agent::MmapReadFileResponse { data, bytes_read })
}
fn handle_ping(req: arcbox_protocol::agent::PingRequest) -> RpcResponse {
tracing::debug!("Ping request: {:?}", req.message);
sync_clock_from_host(req.timestamp_secs);
RpcResponse::Ping(PingResponse {
message: if req.message.is_empty() {
"pong".to_string()
} else {
format!("pong: {}", req.message)
},
version: AGENT_VERSION.to_string(),
})
}