use anyhow::{Context, Result, bail};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::time::timeout;
#[cfg(unix)]
use tokio_vsock::{VsockAddr, VsockStream};
/// Default vsock port used to reach the guest agent; [`VsockClient`] uses it
/// unless overridden with `VsockClient::with_port`.
#[allow(dead_code)]
pub const AGENT_PORT: u32 = 52_000;
/// Context ID of the host end of a vsock connection (the value Linux names
/// `VMADDR_CID_HOST`).
#[allow(dead_code)]
pub const HOST_CID: u32 = 2;
/// Kind of operation requested from the guest agent.
///
/// Serialized as the request's `"type"` field in snake_case
/// (e.g. `WriteFile` becomes `"write_file"`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum RequestType {
    /// Execute `command`, optionally with `cwd`/`env`, and report exit code and output.
    Run,
    /// Start an interactive session sized `rows` x `cols`; the reply carries `session_id`.
    Shell,
    /// Send base64-encoded `input_base64` bytes to session `session_id`.
    ShellInput,
    /// Change the `rows` x `cols` size of session `session_id`.
    ShellResize,
    /// Close session `session_id`; the reply may carry an exit code.
    ShellClose,
    /// Liveness probe.
    Ping,
    /// Ask the agent to shut down.
    Shutdown,
    /// Write base64-encoded `content_base64` to `path`.
    WriteFile,
    /// Read `path`; the reply carries base64-encoded `content_base64`.
    ReadFile,
    /// Remove `path`.
    RemoveFile,
    /// Create directory `path` (`recursive` presumably also creates parents).
    Mkdir,
}
/// One request frame sent to the guest agent as JSON.
///
/// Every optional field left as `None` is omitted from the serialized object;
/// which fields are meaningful depends on `request_type`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AgentRequest {
    /// Caller-chosen identifier for this request.
    pub id: String,
    /// Operation to perform, serialized under the key `"type"`.
    #[serde(rename = "type")]
    pub request_type: RequestType,
    /// Argument vector for `Run` (and optionally `Shell`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub command: Option<Vec<String>>,
    /// Working directory for `Run`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cwd: Option<String>,
    /// Environment variables for `Run` / `Shell`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub env: Option<HashMap<String, String>>,
    /// Guest path for the file and directory operations.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub path: Option<String>,
    /// Base64-encoded file body for `WriteFile`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content_base64: Option<String>,
    /// `Mkdir` flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub recursive: Option<bool>,
    /// Target session for `ShellInput`, `ShellResize` and `ShellClose`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// Terminal height for `Shell` / `ShellResize`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rows: Option<u16>,
    /// Terminal width for `Shell` / `ShellResize`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cols: Option<u16>,
    /// Base64-encoded bytes for `ShellInput`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_base64: Option<String>,
}
/// One response frame received from the guest agent.
///
/// All payload fields are optional; absent keys deserialize to `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AgentResponse {
    /// Identifier of the response (expected to echo the request's `id`).
    pub id: String,
    /// Process exit code for `Run` / `ShellClose` replies.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exit_code: Option<i32>,
    /// Captured standard output for `Run` replies.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stdout: Option<String>,
    /// Captured standard error for `Run` replies.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stderr: Option<String>,
    /// Failure reported by the agent; the clients in this module turn it into an `Err`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// Base64-encoded file body for `ReadFile` replies.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content_base64: Option<String>,
    /// Session identifier returned by `Shell`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// Presumably base64 shell output accompanying `ShellEvent::Output`
    /// (not consumed in this module — confirm against the agent).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_base64: Option<String>,
    /// Shell lifecycle marker, if this frame is a shell event.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub shell_event: Option<ShellEvent>,
}
/// Shell lifecycle marker carried in [`AgentResponse::shell_event`],
/// serialized in snake_case (`"started"`, `"output"`, `"exited"`).
///
/// NOTE(review): the exact meaning of each event is defined by the guest
/// agent and is not consumed in this module.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ShellEvent {
    /// Presumably: a shell session has been started.
    Started,
    /// Presumably: the frame carries shell output.
    Output,
    /// Presumably: the shell process has exited.
    Exited,
}
/// Outcome of a command executed by the guest agent.
#[derive(Clone, Debug)]
#[allow(dead_code)]
pub struct RunResult {
    /// Exit status; the clients in this module substitute `-1` when the agent omits it.
    pub exit_code: i32,
    /// Captured standard output (empty when the agent sent none).
    pub stdout: String,
    /// Captured standard error (empty when the agent sent none).
    pub stderr: String,
}
/// Persistent connection to the guest agent through Firecracker's host-side
/// vsock Unix socket, reused across requests (unlike [`VsockClient`], which
/// opens a fresh connection per request).
#[cfg(unix)]
pub struct VsockConnection {
    /// Socket to Firecracker's vsock proxy, already past the `CONNECT` handshake.
    stream: tokio::net::UnixStream,
    /// Per-phase read timeout in seconds.
    timeout_secs: u64,
}
#[cfg(unix)]
impl VsockConnection {
pub async fn connect(uds_path: impl AsRef<std::path::Path>, port: u32) -> Result<Self> {
use tokio::net::UnixStream;
let mut stream = timeout(
Duration::from_secs(30),
UnixStream::connect(uds_path.as_ref()),
)
.await
.context("Connection timeout")?
.context("Failed to connect to Firecracker vsock socket")?;
let connect_cmd = format!("CONNECT {}\n", port);
stream
.write_all(connect_cmd.as_bytes())
.await
.context("Failed to send CONNECT")?;
stream.flush().await?;
let mut response_buf = [0u8; 32];
let n = timeout(Duration::from_secs(5), stream.read(&mut response_buf))
.await
.context("Timeout waiting for CONNECT response")?
.context("Failed to read CONNECT response")?;
let response_str = std::str::from_utf8(&response_buf[..n])
.context("Invalid CONNECT response")?
.trim();
if !response_str.starts_with("OK ") {
bail!("Firecracker vsock CONNECT failed: {}", response_str);
}
Ok(Self {
stream,
timeout_secs: 30,
})
}
pub async fn run_command(&mut self, command: &[String]) -> Result<RunResult> {
let request = AgentRequest {
id: uuid::Uuid::new_v4().to_string(),
request_type: RequestType::Run,
command: Some(command.to_vec()),
cwd: None,
env: None,
path: None,
content_base64: None,
recursive: None,
session_id: None,
rows: None,
cols: None,
input_base64: None,
};
let response = self.send_request(&request).await?;
if let Some(error) = response.error {
bail!("Guest agent error: {}", error);
}
Ok(RunResult {
exit_code: response.exit_code.unwrap_or(-1),
stdout: response.stdout.unwrap_or_default(),
stderr: response.stderr.unwrap_or_default(),
})
}
async fn send_request(&mut self, request: &AgentRequest) -> Result<AgentResponse> {
let request_bytes = serde_json::to_vec(request)?;
let len = request_bytes.len() as u32;
self.stream.write_all(&len.to_le_bytes()).await?;
self.stream.write_all(&request_bytes).await?;
self.stream.flush().await?;
let mut len_bytes = [0u8; 4];
timeout(
Duration::from_secs(self.timeout_secs),
self.stream.read_exact(&mut len_bytes),
)
.await
.context("Read timeout")?
.context("Failed to read response length")?;
let len = u32::from_le_bytes(len_bytes) as usize;
if len > 10 * 1024 * 1024 {
bail!("Response too large: {} bytes", len);
}
let mut response_bytes = vec![0u8; len];
timeout(
Duration::from_secs(self.timeout_secs),
self.stream.read_exact(&mut response_bytes),
)
.await
.context("Read timeout")?
.context("Failed to read response body")?;
let response: AgentResponse =
serde_json::from_slice(&response_bytes).context("Failed to parse response")?;
Ok(response)
}
#[allow(dead_code)]
pub async fn ping(&mut self) -> bool {
let request = AgentRequest {
id: "ping".to_string(),
request_type: RequestType::Ping,
command: None,
cwd: None,
env: None,
path: None,
content_base64: None,
recursive: None,
session_id: None,
rows: None,
cols: None,
input_base64: None,
};
self.send_request(&request).await.is_ok()
}
}
/// Client that opens a fresh connection to the guest agent for every request.
///
/// It uses native vsock (`cid`/`port`), or Firecracker's Unix-socket vsock
/// proxy when `uds_path` is set.
#[allow(dead_code)]
pub struct VsockClient {
    /// Guest context ID for native vsock; ignored when `uds_path` is set.
    cid: u32,
    /// Agent port (native vsock port, or the port named in Firecracker's `CONNECT`).
    port: u32,
    /// Timeout in seconds applied to connecting and to each I/O phase.
    timeout_secs: u64,
    /// Firecracker host-side vsock socket; `Some` selects the Firecracker path.
    uds_path: Option<std::path::PathBuf>,
}
#[allow(dead_code)]
impl VsockClient {
pub fn new(cid: u32) -> Self {
Self {
cid,
port: AGENT_PORT,
timeout_secs: 30,
uds_path: None,
}
}
pub fn for_firecracker(uds_path: impl Into<std::path::PathBuf>) -> Self {
Self {
cid: 0, port: AGENT_PORT,
timeout_secs: 30,
uds_path: Some(uds_path.into()),
}
}
#[allow(dead_code)]
pub fn with_port(mut self, port: u32) -> Self {
self.port = port;
self
}
#[allow(dead_code)]
pub fn with_timeout(mut self, secs: u64) -> Self {
self.timeout_secs = secs;
self
}
#[cfg(unix)]
pub async fn run_command(&self, command: &[String]) -> Result<RunResult> {
let request = AgentRequest {
id: uuid::Uuid::new_v4().to_string(),
request_type: RequestType::Run,
command: Some(command.to_vec()),
cwd: None,
env: None,
path: None,
content_base64: None,
recursive: None,
session_id: None,
rows: None,
cols: None,
input_base64: None,
};
let response = self.send_request(&request).await?;
if let Some(error) = response.error {
bail!("Guest agent error: {}", error);
}
Ok(RunResult {
exit_code: response.exit_code.unwrap_or(-1),
stdout: response.stdout.unwrap_or_default(),
stderr: response.stderr.unwrap_or_default(),
})
}
#[cfg(unix)]
#[allow(dead_code)]
pub async fn run_command_with_env(
&self,
command: &[String],
cwd: Option<&str>,
env: Option<HashMap<String, String>>,
) -> Result<RunResult> {
let request = AgentRequest {
id: uuid::Uuid::new_v4().to_string(),
request_type: RequestType::Run,
command: Some(command.to_vec()),
cwd: cwd.map(|s| s.to_string()),
env,
path: None,
content_base64: None,
recursive: None,
session_id: None,
rows: None,
cols: None,
input_base64: None,
};
let response = self.send_request(&request).await?;
if let Some(error) = response.error {
bail!("Guest agent error: {}", error);
}
Ok(RunResult {
exit_code: response.exit_code.unwrap_or(-1),
stdout: response.stdout.unwrap_or_default(),
stderr: response.stderr.unwrap_or_default(),
})
}
#[cfg(unix)]
#[allow(dead_code)]
pub async fn ping(&self) -> Result<bool> {
let request = AgentRequest {
id: uuid::Uuid::new_v4().to_string(),
request_type: RequestType::Ping,
command: None,
cwd: None,
env: None,
path: None,
content_base64: None,
recursive: None,
session_id: None,
rows: None,
cols: None,
input_base64: None,
};
match self.send_request(&request).await {
Ok(_) => Ok(true),
Err(_) => Ok(false),
}
}
#[cfg(unix)]
#[allow(dead_code)]
pub async fn shutdown(&self) -> Result<()> {
let request = AgentRequest {
id: uuid::Uuid::new_v4().to_string(),
request_type: RequestType::Shutdown,
command: None,
cwd: None,
env: None,
path: None,
content_base64: None,
recursive: None,
session_id: None,
rows: None,
cols: None,
input_base64: None,
};
let _ = self.send_request(&request).await;
Ok(())
}
#[cfg(unix)]
pub async fn write_file(&self, path: &str, content: &[u8]) -> Result<()> {
use base64::{Engine, engine::general_purpose::STANDARD};
let request = AgentRequest {
id: uuid::Uuid::new_v4().to_string(),
request_type: RequestType::WriteFile,
command: None,
cwd: None,
env: None,
path: Some(path.to_string()),
content_base64: Some(STANDARD.encode(content)),
recursive: None,
session_id: None,
rows: None,
cols: None,
input_base64: None,
};
let response = self.send_request(&request).await?;
if let Some(error) = response.error {
bail!("Failed to write file: {}", error);
}
Ok(())
}
#[cfg(unix)]
pub async fn read_file(&self, path: &str) -> Result<Vec<u8>> {
use base64::{Engine, engine::general_purpose::STANDARD};
let request = AgentRequest {
id: uuid::Uuid::new_v4().to_string(),
request_type: RequestType::ReadFile,
command: None,
cwd: None,
env: None,
path: Some(path.to_string()),
content_base64: None,
recursive: None,
session_id: None,
rows: None,
cols: None,
input_base64: None,
};
let response = self.send_request(&request).await?;
if let Some(error) = response.error {
bail!("Failed to read file: {}", error);
}
let content_base64 = response
.content_base64
.ok_or_else(|| anyhow::anyhow!("No content in response"))?;
let content = STANDARD
.decode(&content_base64)
.context("Failed to decode file content")?;
Ok(content)
}
#[cfg(unix)]
pub async fn remove_file(&self, path: &str) -> Result<()> {
let request = AgentRequest {
id: uuid::Uuid::new_v4().to_string(),
request_type: RequestType::RemoveFile,
command: None,
cwd: None,
env: None,
path: Some(path.to_string()),
content_base64: None,
recursive: None,
session_id: None,
rows: None,
cols: None,
input_base64: None,
};
let response = self.send_request(&request).await?;
if let Some(error) = response.error {
bail!("Failed to remove file: {}", error);
}
Ok(())
}
#[cfg(unix)]
pub async fn mkdir(&self, path: &str, recursive: bool) -> Result<()> {
let request = AgentRequest {
id: uuid::Uuid::new_v4().to_string(),
request_type: RequestType::Mkdir,
command: None,
cwd: None,
env: None,
path: Some(path.to_string()),
content_base64: None,
recursive: Some(recursive),
session_id: None,
rows: None,
cols: None,
input_base64: None,
};
let response = self.send_request(&request).await?;
if let Some(error) = response.error {
bail!("Failed to create directory: {}", error);
}
Ok(())
}
#[cfg(unix)]
#[allow(dead_code)]
pub async fn start_shell(
&self,
command: Option<Vec<String>>,
rows: u16,
cols: u16,
env: Option<HashMap<String, String>>,
) -> Result<String> {
let request = AgentRequest {
id: uuid::Uuid::new_v4().to_string(),
request_type: RequestType::Shell,
command,
cwd: None,
env,
path: None,
content_base64: None,
recursive: None,
session_id: None,
rows: Some(rows),
cols: Some(cols),
input_base64: None,
};
let response = self.send_request(&request).await?;
if let Some(error) = response.error {
bail!("Failed to start shell: {}", error);
}
response
.session_id
.ok_or_else(|| anyhow::anyhow!("No session ID in shell response"))
}
#[cfg(unix)]
#[allow(dead_code)]
pub async fn shell_input(&self, session_id: &str, data: &[u8]) -> Result<()> {
use base64::{Engine, engine::general_purpose::STANDARD};
let request = AgentRequest {
id: uuid::Uuid::new_v4().to_string(),
request_type: RequestType::ShellInput,
command: None,
cwd: None,
env: None,
path: None,
content_base64: None,
recursive: None,
session_id: Some(session_id.to_string()),
rows: None,
cols: None,
input_base64: Some(STANDARD.encode(data)),
};
let response = self.send_request(&request).await?;
if let Some(error) = response.error {
bail!("Failed to send shell input: {}", error);
}
Ok(())
}
#[cfg(unix)]
#[allow(dead_code)]
pub async fn shell_resize(&self, session_id: &str, rows: u16, cols: u16) -> Result<()> {
let request = AgentRequest {
id: uuid::Uuid::new_v4().to_string(),
request_type: RequestType::ShellResize,
command: None,
cwd: None,
env: None,
path: None,
content_base64: None,
recursive: None,
session_id: Some(session_id.to_string()),
rows: Some(rows),
cols: Some(cols),
input_base64: None,
};
let response = self.send_request(&request).await?;
if let Some(error) = response.error {
bail!("Failed to resize shell: {}", error);
}
Ok(())
}
#[cfg(unix)]
#[allow(dead_code)]
pub async fn shell_close(&self, session_id: &str) -> Result<i32> {
let request = AgentRequest {
id: uuid::Uuid::new_v4().to_string(),
request_type: RequestType::ShellClose,
command: None,
cwd: None,
env: None,
path: None,
content_base64: None,
recursive: None,
session_id: Some(session_id.to_string()),
rows: None,
cols: None,
input_base64: None,
};
let response = self.send_request(&request).await?;
if let Some(error) = response.error {
bail!("Failed to close shell: {}", error);
}
Ok(response.exit_code.unwrap_or(-1))
}
#[cfg(unix)]
async fn send_request(&self, request: &AgentRequest) -> Result<AgentResponse> {
if let Some(ref uds_path) = self.uds_path {
self.send_request_via_firecracker(request, uds_path).await
} else {
self.send_request_via_native_vsock(request).await
}
}
#[cfg(unix)]
async fn send_request_via_native_vsock(&self, request: &AgentRequest) -> Result<AgentResponse> {
let addr = VsockAddr::new(self.cid, self.port);
let mut stream = timeout(
Duration::from_secs(self.timeout_secs),
VsockStream::connect(addr),
)
.await
.context("Connection timeout")?
.context("Failed to connect to guest agent")?;
self.send_and_receive(&mut stream, request).await
}
#[cfg(unix)]
async fn send_request_via_firecracker(
&self,
request: &AgentRequest,
uds_path: &std::path::Path,
) -> Result<AgentResponse> {
use tokio::net::UnixStream;
let mut stream = timeout(
Duration::from_secs(self.timeout_secs),
UnixStream::connect(uds_path),
)
.await
.context("Connection timeout")?
.context("Failed to connect to Firecracker vsock socket")?;
let connect_cmd = format!("CONNECT {}\n", self.port);
stream
.write_all(connect_cmd.as_bytes())
.await
.context("Failed to send CONNECT")?;
stream.flush().await?;
let mut response_buf = [0u8; 32];
let n = timeout(Duration::from_secs(5), stream.read(&mut response_buf))
.await
.context("Timeout waiting for CONNECT response")?
.context("Failed to read CONNECT response")?;
let response_str = std::str::from_utf8(&response_buf[..n])
.context("Invalid CONNECT response")?
.trim();
if !response_str.starts_with("OK ") {
bail!("Firecracker vsock CONNECT failed: {}", response_str);
}
self.send_and_receive(&mut stream, request).await
}
#[cfg(unix)]
async fn send_and_receive<S>(
&self,
stream: &mut S,
request: &AgentRequest,
) -> Result<AgentResponse>
where
S: AsyncReadExt + AsyncWriteExt + Unpin,
{
let request_bytes = serde_json::to_vec(request)?;
let len = request_bytes.len() as u32;
stream.write_all(&len.to_le_bytes()).await?;
stream.write_all(&request_bytes).await?;
stream.flush().await?;
let mut len_bytes = [0u8; 4];
timeout(
Duration::from_secs(self.timeout_secs),
stream.read_exact(&mut len_bytes),
)
.await
.context("Read timeout")?
.context("Failed to read response length")?;
let len = u32::from_le_bytes(len_bytes) as usize;
if len > 10 * 1024 * 1024 {
bail!("Response too large: {} bytes", len);
}
let mut response_bytes = vec![0u8; len];
timeout(
Duration::from_secs(self.timeout_secs),
stream.read_exact(&mut response_bytes),
)
.await
.context("Read timeout")?
.context("Failed to read response body")?;
let response: AgentResponse =
serde_json::from_slice(&response_bytes).context("Failed to parse response")?;
Ok(response)
}
#[cfg(not(unix))]
pub async fn run_command(&self, _command: &[String]) -> Result<RunResult> {
bail!("Vsock is only supported on Unix platforms");
}
#[cfg(not(unix))]
#[allow(dead_code)]
pub async fn run_command_with_env(
&self,
_command: &[String],
_cwd: Option<&str>,
_env: Option<HashMap<String, String>>,
) -> Result<RunResult> {
bail!("Vsock is only supported on Unix platforms");
}
#[cfg(not(unix))]
#[allow(dead_code)]
pub async fn ping(&self) -> Result<bool> {
bail!("Vsock is only supported on Unix platforms");
}
#[cfg(not(unix))]
#[allow(dead_code)]
pub async fn shutdown(&self) -> Result<()> {
bail!("Vsock is only supported on Unix platforms");
}
}
/// Polls the guest agent at `cid` over native vsock until it answers a ping
/// or `timeout_secs` elapses.
///
/// The agent is always probed at least once, so `timeout_secs == 0` means
/// "check right now" rather than failing without trying. Each probe uses a
/// 5-second I/O timeout, and failed probes are retried every 100 ms.
///
/// # Errors
///
/// `Guest agent not available after N seconds` if no probe succeeds in time.
#[cfg(unix)]
#[allow(dead_code)]
pub async fn wait_for_agent(cid: u32, timeout_secs: u64) -> Result<()> {
    let client = VsockClient::new(cid).with_timeout(5);
    let deadline = std::time::Instant::now() + Duration::from_secs(timeout_secs);
    loop {
        if client.ping().await.unwrap_or(false) {
            return Ok(());
        }
        if std::time::Instant::now() >= deadline {
            bail!("Guest agent not available after {} seconds", timeout_secs);
        }
        tokio::time::sleep(Duration::from_millis(100)).await;
    }
}
/// Non-Unix fallback: vsock is unavailable here, so this always fails.
#[cfg(not(unix))]
#[allow(dead_code)]
pub async fn wait_for_agent(_cid: u32, _timeout_secs: u64) -> Result<()> {
    Err(anyhow::anyhow!("Vsock is only supported on Unix platforms"))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request with the given `id`/`request_type` and every optional field unset.
    fn bare_request(id: &str, request_type: RequestType) -> AgentRequest {
        AgentRequest {
            id: id.to_string(),
            request_type,
            command: None,
            cwd: None,
            env: None,
            path: None,
            content_base64: None,
            recursive: None,
            session_id: None,
            rows: None,
            cols: None,
            input_base64: None,
        }
    }

    #[test]
    fn test_request_serialize() {
        let request = AgentRequest {
            command: Some(vec!["ls".to_string(), "-la".to_string()]),
            cwd: Some("/app".to_string()),
            ..bare_request("test-123", RequestType::Run)
        };
        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"type\":\"run\""));
        assert!(json.contains("\"command\":[\"ls\",\"-la\"]"));
        assert!(json.contains("\"cwd\":\"/app\""));
    }

    #[test]
    fn test_write_file_request_serialize() {
        let request = AgentRequest {
            path: Some("/tmp/test.txt".to_string()),
            content_base64: Some("SGVsbG8gV29ybGQ=".to_string()),
            ..bare_request("test-456", RequestType::WriteFile)
        };
        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"type\":\"write_file\""));
        assert!(json.contains("\"path\":\"/tmp/test.txt\""));
        assert!(json.contains("\"content_base64\":\"SGVsbG8gV29ybGQ=\""));
    }

    /// Unset optional fields must not appear on the wire at all (not even as `null`).
    #[test]
    fn test_unset_fields_are_omitted() {
        let json = serde_json::to_string(&bare_request("p", RequestType::Ping)).unwrap();
        assert_eq!(json, r#"{"id":"p","type":"ping"}"#);
    }

    /// Multi-word request types serialize in snake_case alongside their fields.
    #[test]
    fn test_shell_resize_request_serialize() {
        let request = AgentRequest {
            session_id: Some("s1".to_string()),
            rows: Some(24),
            cols: Some(80),
            ..bare_request("r", RequestType::ShellResize)
        };
        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"type\":\"shell_resize\""));
        assert!(json.contains("\"session_id\":\"s1\""));
        assert!(json.contains("\"rows\":24"));
        assert!(json.contains("\"cols\":80"));
    }

    #[test]
    fn test_response_deserialize() {
        let json = r#"{
            "id": "test-123",
            "exit_code": 0,
            "stdout": "hello world\n",
            "stderr": ""
        }"#;
        let response: AgentResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.id, "test-123");
        assert_eq!(response.exit_code, Some(0));
        assert_eq!(response.stdout, Some("hello world\n".to_string()));
    }

    /// Shell frames parse their session, payload and snake_case event; absent keys become `None`.
    #[test]
    fn test_shell_output_response_deserialize() {
        let json = r#"{"id":"s","session_id":"abc","output_base64":"aGk=","shell_event":"output"}"#;
        let response: AgentResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.session_id.as_deref(), Some("abc"));
        assert_eq!(response.output_base64.as_deref(), Some("aGk="));
        assert!(matches!(response.shell_event, Some(ShellEvent::Output)));
        assert!(response.exit_code.is_none());
        assert!(response.error.is_none());
    }
}