use anyhow::Result;
/// Vsock port on which the agent accepts connections.
pub const AGENT_PORT: u32 = 1024;
#[cfg(target_os = "linux")]
mod linux {
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::{Context, Result};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::{mpsc, RwLock};
use tokio_vsock::{VsockAddr, VsockListener, VMADDR_CID_ANY};
use super::AGENT_PORT;
use crate::container::{ContainerHandle, ContainerRuntime, ContainerState};
use crate::log_watcher::{watch_log_file, LogWatchOptions};
use crate::rpc::{
parse_request, read_message, write_response, ErrorResponse, RpcRequest, RpcResponse,
AGENT_VERSION,
};
use arcbox_protocol::agent::{
ContainerInfo, CreateContainerResponse, ExecOutput, ListContainersResponse, LogEntry,
LogsRequest, PingResponse, SystemInfo,
};
/// Mutable state shared by every connection handled by the agent.
pub struct AgentState {
    /// Registry of containers known to this agent.
    pub runtime: ContainerRuntime,
}

impl AgentState {
    /// Creates a state holding a freshly constructed container runtime.
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for AgentState {
    /// Builds the state around `ContainerRuntime::new()`.
    fn default() -> Self {
        let runtime = ContainerRuntime::new();
        AgentState { runtime }
    }
}
/// Outcome of dispatching one request, as consumed by `handle_connection`.
enum RequestResult {
    /// A single response that is written back to the client once.
    Single(RpcResponse),
    /// A stream of log entries: each received entry is sent to the client as
    /// `RpcResponse::LogEntry`, and the sender is used to signal cancellation
    /// when writing to the client fails.
    Stream(mpsc::Receiver<LogEntry>, mpsc::Sender<()>),
}
pub struct Agent {
state: Arc<RwLock<AgentState>>,
}
impl Agent {
pub fn new() -> Self {
Self {
state: Arc::new(RwLock::new(AgentState::new())),
}
}
pub async fn run(&self) -> Result<()> {
let addr = VsockAddr::new(VMADDR_CID_ANY, AGENT_PORT);
let mut listener =
VsockListener::bind(addr).context("failed to bind vsock listener")?;
tracing::info!("Agent listening on vsock port {}", AGENT_PORT);
loop {
match listener.accept().await {
Ok((stream, peer_addr)) => {
tracing::info!("Accepted connection from {:?}", peer_addr);
let state = Arc::clone(&self.state);
tokio::spawn(async move {
if let Err(e) = handle_connection(stream, state).await {
tracing::error!("Connection error: {}", e);
}
});
}
Err(e) => {
tracing::error!("Accept error: {}", e);
}
}
}
}
}
impl Default for Agent {
fn default() -> Self {
Self::new()
}
}
async fn handle_connection<S>(mut stream: S, state: Arc<RwLock<AgentState>>) -> Result<()>
where
S: AsyncRead + AsyncWrite + Unpin,
{
loop {
let (msg_type, payload) = match read_message(&mut stream).await {
Ok(msg) => msg,
Err(e) => {
if e.to_string().contains("failed to read message header") {
tracing::debug!("Client disconnected");
return Ok(());
}
return Err(e);
}
};
tracing::debug!("Received message type {:?}", msg_type);
let result = match parse_request(msg_type, &payload) {
Ok(request) => handle_request(request, &state).await,
Err(e) => {
tracing::warn!("Failed to parse request: {}", e);
RequestResult::Single(RpcResponse::Error(ErrorResponse::new(
400,
format!("invalid request: {}", e),
)))
}
};
match result {
RequestResult::Single(response) => {
write_response(&mut stream, &response).await?;
}
RequestResult::Stream(mut log_rx, cancel_tx) => {
tracing::debug!("Starting log stream");
loop {
tokio::select! {
entry = log_rx.recv() => {
match entry {
Some(log_entry) => {
let response = RpcResponse::LogEntry(log_entry);
if let Err(e) = write_response(&mut stream, &response).await {
tracing::debug!("Client disconnected during streaming: {}", e);
let _ = cancel_tx.send(()).await;
return Ok(());
}
}
None => {
tracing::debug!("Log stream ended");
break;
}
}
}
_ = tokio::time::sleep(tokio::time::Duration::from_secs(30)) => {
continue;
}
}
}
}
}
}
}
/// Routes a parsed request to its handler.
///
/// Every request produces a single response except `Logs`, whose handler
/// decides on its own whether to return a single response or a stream.
async fn handle_request(
    request: RpcRequest,
    state: &Arc<RwLock<AgentState>>,
) -> RequestResult {
    let response = match request {
        // The logs handler builds its own `RequestResult`.
        RpcRequest::Logs(req) => return handle_logs(req, state).await,
        RpcRequest::Ping(req) => handle_ping(req),
        RpcRequest::GetSystemInfo => handle_get_system_info().await,
        RpcRequest::CreateContainer(req) => handle_create_container(req, state).await,
        RpcRequest::StartContainer(req) => handle_start_container(&req.id, state).await,
        RpcRequest::StopContainer(req) => {
            handle_stop_container(&req.id, req.timeout, state).await
        }
        RpcRequest::RemoveContainer(req) => {
            handle_remove_container(&req.id, req.force, state).await
        }
        RpcRequest::ListContainers(req) => handle_list_containers(req.all, state).await,
        RpcRequest::Exec(req) => handle_exec(req, state).await,
    };
    RequestResult::Single(response)
}
/// Answers a ping with `"pong"` (or `"pong: <message>"` when the request
/// carried a message) together with the agent version.
fn handle_ping(req: arcbox_protocol::agent::PingRequest) -> RpcResponse {
    tracing::debug!("Ping request: {:?}", req.message);

    let message = if !req.message.is_empty() {
        format!("pong: {}", req.message)
    } else {
        "pong".to_string()
    };
    let version = AGENT_VERSION.to_string();

    RpcResponse::Ping(PingResponse { message, version })
}
/// Collects a system snapshot and wraps it in a `SystemInfo` response.
async fn handle_get_system_info() -> RpcResponse {
    RpcResponse::SystemInfo(collect_system_info())
}
/// Gathers a best-effort snapshot of the guest system.
///
/// Reads `uname`, `/proc/meminfo`, `/proc/loadavg` and `/proc/uptime`. Any
/// source that cannot be read or parsed is skipped, leaving the matching
/// fields at their `SystemInfo::default()` values.
fn collect_system_info() -> SystemInfo {
    let mut info = SystemInfo::default();

    if let Ok(uts) = nix::sys::utsname::uname() {
        info.kernel_version = uts.release().to_string_lossy().to_string();
        info.os_name = uts.sysname().to_string_lossy().to_string();
        info.os_version = uts.version().to_string_lossy().to_string();
        info.arch = uts.machine().to_string_lossy().to_string();
        info.hostname = uts.nodename().to_string_lossy().to_string();
    }

    if let Ok(meminfo) = std::fs::read_to_string("/proc/meminfo") {
        // The second whitespace-separated column is multiplied by 1024 to get bytes.
        let column_bytes = |line: &str| -> Option<u64> {
            let kb = line.split_whitespace().nth(1)?;
            kb.parse::<u64>().ok().map(|kb_val| kb_val * 1024)
        };
        for line in meminfo.lines() {
            if line.starts_with("MemTotal:") {
                if let Some(bytes) = column_bytes(line) {
                    info.total_memory = bytes;
                }
            } else if line.starts_with("MemAvailable:") {
                if let Some(bytes) = column_bytes(line) {
                    info.available_memory = bytes;
                }
            }
        }
    }

    info.cpu_count = match std::thread::available_parallelism() {
        Ok(count) => count.get() as u32,
        Err(_) => 1,
    };

    if let Ok(loadavg) = std::fs::read_to_string("/proc/loadavg") {
        let fields: Vec<&str> = loadavg.split_whitespace().collect();
        // Only the first three columns (1, 5 and 15 minute averages) are used,
        // and only when all three columns are present.
        if fields.len() >= 3 {
            for field in &fields[..3] {
                if let Ok(load) = field.parse::<f64>() {
                    info.load_average.push(load);
                }
            }
        }
    }

    if let Ok(uptime) = std::fs::read_to_string("/proc/uptime") {
        let seconds = uptime
            .split_whitespace()
            .next()
            .and_then(|secs| secs.parse::<f64>().ok());
        if let Some(secs_val) = seconds {
            // Fractional seconds are discarded by the cast.
            info.uptime = secs_val as u64;
        }
    }

    info
}
/// Registers a new container in the `Created` state and returns its id.
///
/// The id is a freshly generated UUID v4. The command line is the entrypoint
/// followed by the command; if both are empty it falls back to `/bin/sh`.
/// An empty working directory defaults to `/`.
async fn handle_create_container(
    req: arcbox_protocol::agent::CreateContainerRequest,
    state: &Arc<RwLock<AgentState>>,
) -> RpcResponse {
    tracing::info!("CreateContainer: name={}, image={}", req.name, req.image);

    let container_id = uuid::Uuid::new_v4().to_string();
    let env: Vec<(String, String)> = req.env.into_iter().collect();

    // Start from the entrypoint, then append or substitute the command.
    let mut command = req.entrypoint;
    if command.is_empty() {
        command = req.cmd;
    } else {
        command.extend(req.cmd);
    }
    if command.is_empty() {
        command.push("/bin/sh".to_string());
    }

    let working_dir = if req.working_dir.is_empty() {
        "/".to_string()
    } else {
        req.working_dir
    };

    let handle = ContainerHandle {
        id: container_id.clone(),
        name: req.name,
        image: req.image,
        command,
        env,
        working_dir,
        state: ContainerState::Created,
        pid: None,
        exit_code: None,
        created_at: chrono::Utc::now(),
    };

    // The write guard is a temporary and is released at the end of this statement.
    state.write().await.runtime.add_container(handle);

    RpcResponse::CreateContainer(CreateContainerResponse { id: container_id })
}
async fn handle_start_container(id: &str, state: &Arc<RwLock<AgentState>>) -> RpcResponse {
tracing::info!("StartContainer: id={}", id);
let mut state = state.write().await;
match state.runtime.start_container(id).await {
Ok(()) => RpcResponse::Empty,
Err(e) => {
tracing::error!("Failed to start container {}: {}", id, e);
RpcResponse::Error(ErrorResponse::new(500, format!("failed to start: {}", e)))
}
}
}
async fn handle_stop_container(
id: &str,
timeout: u32,
state: &Arc<RwLock<AgentState>>,
) -> RpcResponse {
tracing::info!("StopContainer: id={}, timeout={}s", id, timeout);
let mut state = state.write().await;
match state.runtime.stop_container(id, timeout).await {
Ok(()) => RpcResponse::Empty,
Err(e) => {
tracing::error!("Failed to stop container {}: {}", id, e);
RpcResponse::Error(ErrorResponse::new(500, format!("failed to stop: {}", e)))
}
}
}
async fn handle_remove_container(
id: &str,
force: bool,
state: &Arc<RwLock<AgentState>>,
) -> RpcResponse {
tracing::info!("RemoveContainer: id={}, force={}", id, force);
let mut state = state.write().await;
match state.runtime.remove_container(id, force).await {
Ok(()) => RpcResponse::Empty,
Err(e) => {
tracing::error!("Failed to remove container {}: {}", id, e);
RpcResponse::Error(ErrorResponse::new(500, format!("failed to remove: {}", e)))
}
}
}
/// Lists containers known to the runtime as protocol `ContainerInfo` records.
///
/// `all` is forwarded unchanged to the runtime's `list_containers`.
async fn handle_list_containers(all: bool, state: &Arc<RwLock<AgentState>>) -> RpcResponse {
    tracing::debug!("ListContainers: all={}", all);

    let agent_state = state.read().await;
    let mut containers: Vec<ContainerInfo> = Vec::new();
    for h in agent_state.runtime.list_containers(all).iter() {
        containers.push(ContainerInfo {
            id: h.id.clone(),
            name: h.name.clone(),
            image: h.image.clone(),
            state: h.state.as_str().to_string(),
            status: format_status(h),
            created: h.created_at.timestamp(),
        });
    }

    RpcResponse::ListContainers(ListContainersResponse { containers })
}
/// Renders a human-readable status line for a container.
///
/// Running containers include their PID when known; stopped containers
/// report `Exited (<code>)` when an exit code was recorded.
fn format_status(handle: &ContainerHandle) -> String {
    match handle.state {
        ContainerState::Created => String::from("Created"),
        ContainerState::Running => match handle.pid {
            Some(pid) => format!("Running (PID: {})", pid),
            None => String::from("Running"),
        },
        ContainerState::Stopped => match handle.exit_code {
            Some(code) => format!("Exited ({})", code),
            None => String::from("Stopped"),
        },
    }
}
/// Executes a command, optionally gated on a running container.
///
/// * An empty command is rejected with a 400 error.
/// * An empty `container_id` runs the command directly via `execute_on_host`.
/// * Otherwise the container must exist (else 404) and be in the `Running`
///   state (else 400) before the command is handed to `execute_on_host`.
///
/// The state lock is only held while validating the container. It is released
/// before the command runs, so a long-running command cannot block requests
/// that need the write lock (create/start/stop/remove).
async fn handle_exec(
    req: arcbox_protocol::agent::ExecRequest,
    state: &Arc<RwLock<AgentState>>,
) -> RpcResponse {
    tracing::info!(
        "Exec: container_id={}, cmd={:?}",
        req.container_id,
        req.cmd
    );

    if req.cmd.is_empty() {
        return RpcResponse::Error(ErrorResponse::new(400, "empty command"));
    }

    if !req.container_id.is_empty() {
        // Scope the read guard so it is dropped before the command is awaited.
        let agent_state = state.read().await;
        match agent_state.runtime.get_container(&req.container_id) {
            None => {
                return RpcResponse::Error(ErrorResponse::new(
                    404,
                    format!("container not found: {}", req.container_id),
                ));
            }
            Some(container) if container.state != ContainerState::Running => {
                return RpcResponse::Error(ErrorResponse::new(
                    400,
                    format!("container is not running: {}", req.container_id),
                ));
            }
            Some(_) => {}
        }
    }

    execute_on_host(req).await
}
async fn execute_on_host(req: arcbox_protocol::agent::ExecRequest) -> RpcResponse {
let env: HashMap<String, String> = req.env.into_iter().collect();
let env_vec: Vec<(String, String)> = env.into_iter().collect();
let working_dir = if req.working_dir.is_empty() {
None
} else {
Some(req.working_dir.as_str())
};
match crate::exec::exec(&req.cmd, working_dir, &env_vec, None).await {
Ok(result) => {
let mut output = ExecOutput::default();
output.stream = "stdout".to_string();
output.data = result.stdout;
output.exit_code = result.exit_code;
output.done = true;
RpcResponse::ExecOutput(output)
}
Err(e) => RpcResponse::Error(ErrorResponse::new(500, format!("exec failed: {}", e))),
}
}
/// Serves a container's logs from `/var/log/containers/<id>.log`.
///
/// * Unknown container: 404 error response.
/// * `follow` set: delegates to `handle_logs_stream`.
/// * Otherwise the whole file is read and returned as one `LogEntry`,
///   limited to the last `tail` lines when `tail > 0`. A missing log file
///   yields an empty `stdout` entry; other read failures yield a 500 error.
async fn handle_logs(req: LogsRequest, state: &Arc<RwLock<AgentState>>) -> RequestResult {
    tracing::info!(
        "Logs: container_id={}, follow={}, tail={}",
        req.container_id,
        req.follow,
        req.tail
    );

    // The read guard is a temporary and is released at the end of this statement.
    let container_known = state
        .read()
        .await
        .runtime
        .get_container(&req.container_id)
        .is_some();
    if !container_known {
        return RequestResult::Single(RpcResponse::Error(ErrorResponse::new(
            404,
            format!("container not found: {}", req.container_id),
        )));
    }

    let log_path = format!("/var/log/containers/{}.log", req.container_id);
    if req.follow {
        return handle_logs_stream(req, log_path).await;
    }

    let now_nanos = || chrono::Utc::now().timestamp_nanos_opt().unwrap_or(0);

    // NOTE(review): this is a synchronous read inside an async fn; a large
    // log file will occupy the executor thread while it is loaded.
    let log_data = match std::fs::read_to_string(&log_path) {
        Ok(contents) => contents,
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
            return RequestResult::Single(RpcResponse::LogEntry(LogEntry {
                stream: "stdout".to_string(),
                data: Vec::new(),
                timestamp: now_nanos(),
            }));
        }
        Err(err) => {
            return RequestResult::Single(RpcResponse::Error(ErrorResponse::new(
                500,
                format!("failed to read logs: {}", err),
            )));
        }
    };

    let lines: Vec<&str> = log_data.lines().collect();
    // Index of the first line to return; 0 means "everything".
    let first = if req.tail > 0 {
        lines.len().saturating_sub(req.tail as usize)
    } else {
        0
    };
    let output = lines[first..].join("\n");

    let stream = if req.stdout { "stdout" } else { "stderr" }.to_string();
    RequestResult::Single(RpcResponse::LogEntry(LogEntry {
        stream,
        data: output.into_bytes(),
        timestamp: now_nanos(),
    }))
}
async fn handle_logs_stream(req: LogsRequest, log_path: String) -> RequestResult {
let options = LogWatchOptions {
stdout: req.stdout,
stderr: req.stderr,
timestamps: req.timestamps,
tail: req.tail,
since: req.since,
until: req.until,
};
let (cancel_tx, cancel_rx) = mpsc::channel::<()>(1);
match watch_log_file(&log_path, options, cancel_rx).await {
Ok(log_rx) => RequestResult::Stream(log_rx, cancel_tx),
Err(e) => {
tracing::error!("Failed to start log watcher: {}", e);
RequestResult::Single(RpcResponse::Error(ErrorResponse::new(
500,
format!("failed to start log stream: {}", e),
)))
}
}
}
}
#[cfg(not(target_os = "linux"))]
/// Placeholder agent used on non-Linux targets.
mod stub {
    use anyhow::Result;
    use super::AGENT_PORT;

    /// Agent stand-in that serves nothing.
    #[derive(Default)]
    pub struct Agent;

    impl Agent {
        /// Creates the stub agent.
        pub fn new() -> Self {
            Agent
        }

        /// Logs that the agent is in stub mode, then idles forever, emitting
        /// a debug heartbeat after every 60-second sleep. Never returns.
        pub async fn run(&self) -> Result<()> {
            tracing::warn!("Agent is running in stub mode (non-Linux platform)");
            tracing::info!("Agent would listen on vsock port {}", AGENT_PORT);

            let heartbeat_every = tokio::time::Duration::from_secs(60);
            loop {
                tokio::time::sleep(heartbeat_every).await;
                tracing::debug!("Agent stub heartbeat");
            }
        }
    }
}
#[cfg(target_os = "linux")]
pub use linux::Agent;
#[cfg(not(target_os = "linux"))]
pub use stub::Agent;
/// Creates the platform-selected [`Agent`] and runs it.
///
/// # Errors
///
/// Propagates whatever error `Agent::run` returns.
pub async fn run() -> Result<()> {
    Agent::new().run().await
}
#[cfg(test)]
// Smoke tests for the platform-selected `Agent`.
mod tests {
use super::*;
// Constructing an agent must succeed without panicking.
#[test]
fn test_agent_creation() {
let _agent = Agent::new();
}
}