#![deny(missing_docs)]
use serde::{Deserialize, Serialize};
/// Retry support; the implementation lives in the separate `retry` module file.
pub mod retry;
pub mod base64_bytes {
use base64::{engine::general_purpose::STANDARD, Engine};
use serde::{Deserialize, Deserializer, Serializer};
pub fn serialize<S: Serializer>(data: &[u8], serializer: S) -> Result<S::Ok, S::Error> {
serializer.serialize_str(&STANDARD.encode(data))
}
pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Vec<u8>, D::Error> {
let s = String::deserialize(deserializer)?;
STANDARD.decode(&s).map_err(serde::de::Error::custom)
}
}
/// Version of this wire protocol (e.g. reported in `AgentResponse::Pong` and
/// carried by `HostMessage::Auth`).
pub const PROTOCOL_VERSION: u32 = 1;

/// Largest JSON payload, in bytes, accepted by `decode_message` (32 MiB).
///
/// The 4-byte length header is not counted against this limit.
pub const MAX_FRAME_SIZE: u32 = 32 << 20;

/// Chunk size in bytes for streaming image layers (16 MiB).
///
/// NOTE(review): presumably used to split `ExportLayer` output into
/// `DataChunk` frames; confirm against the agent implementation.
pub const LAYER_CHUNK_SIZE: usize = 16 << 20;

/// Largest payload, in bytes, meant to be sent in a single `FileWrite` (1 MiB).
///
/// NOTE(review): presumably larger files switch to `FileWriteBegin` /
/// `FileWriteChunk`; confirm at the call sites.
pub const FILE_WRITE_SINGLE_SHOT_MAX: usize = 1 << 20;

/// Payload size of each `FileWriteChunk`.
///
/// It is kept equal to the single-shot limit. A unit test checks that the
/// base64-encoded chunk still fits inside `MAX_FRAME_SIZE`.
pub const FILE_WRITE_CHUNK_SIZE: usize = FILE_WRITE_SINGLE_SHOT_MAX;

/// Upper bound, in bytes, for a complete chunked file transfer (4 GiB).
///
/// NOTE(review): presumably checked against `FileWriteBegin::total_size`.
pub const FILE_TRANSFER_MAX_TOTAL: u64 = 4 << 30;
/// Port numbers for the host/guest channels.
///
/// NOTE(review): these are `u32` like vsock ports (see the `cid` module);
/// confirm the transport that binds them.
pub mod ports {
    /// Workload control channel.
    ///
    /// NOTE(review): presumably carries `HostMessage` / `GuestMessage` frames.
    pub const WORKLOAD_CONTROL: u32 = 5_000;
    /// Workload log channel.
    pub const WORKLOAD_LOGS: u32 = 5_001;
    /// Agent control channel.
    ///
    /// NOTE(review): presumably carries `AgentRequest` / `AgentResponse` frames.
    pub const AGENT_CONTROL: u32 = 6_000;
    /// SSH agent channel (per the name).
    pub const SSH_AGENT: u32 = 6_001;
    /// DNS filter channel (per the name).
    pub const DNS_FILTER: u32 = 6_002;
}
/// Context identifiers (CIDs) for addressing the host and the guest.
///
/// The values match the Linux vsock constants `VMADDR_CID_HOST` (2) and
/// `VMADDR_CID_ANY` (all bits set). NOTE(review): presumably used with
/// `AF_VSOCK`; confirm.
pub mod cid {
    /// CID of the host.
    pub const HOST: u32 = 2;
    /// CID of the guest VM.
    pub const GUEST: u32 = 3;
    /// Wildcard CID with every bit set (`u32::MAX`).
    pub const ANY: u32 = 0xFFFF_FFFF;
}
/// A request sent to the agent.
///
/// It is serialized as a JSON object whose `"method"` field holds the
/// snake_case variant name, with the variant's fields alongside it. Examples
/// are `{"method":"ping"}` and `{"method":"prepare_overlay",...}`.
/// `Vec<u8>` fields travel as base64 strings via [`base64_bytes`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "method", rename_all = "snake_case")]
pub enum AgentRequest {
/// Liveness check; presumably answered with [`AgentResponse::Pong`].
Ping,
/// Pull an image.
Pull {
/// Image reference, e.g. `alpine:latest`.
image: String,
/// Optional OCI platform string, e.g. `linux/arm64`.
oci_platform: Option<String>,
/// Optional registry credentials; omitted from the JSON when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
auth: Option<RegistryAuth>,
},
/// Query metadata for an image.
///
/// NOTE(review): presumably answered with [`ImageInfo`] data.
Query {
/// Image reference to look up.
image: String,
},
/// List the images known to the agent.
ListImages,
/// Garbage-collect unused storage.
GarbageCollect {
/// NOTE(review): presumably reports what would be removed without deleting.
dry_run: bool,
/// Defaults to `false` when absent from the JSON.
#[serde(default)]
purge_all: bool,
},
/// Prepare an overlay filesystem for a workload.
///
/// NOTE(review): presumably answered with [`OverlayInfo`] data.
PrepareOverlay {
/// Image whose layers form the overlay's lower layers (presumably).
image: String,
/// Identifier of the workload the overlay belongs to.
workload_id: String,
},
/// Remove the overlay created for a workload.
CleanupOverlay {
/// Identifier of the workload whose overlay should be removed.
workload_id: String,
},
/// Format the agent's storage.
FormatStorage,
/// Report storage state; presumably answered with [`StorageStatus`] data.
StorageStatus,
/// Test network reachability of a URL.
NetworkTest {
/// URL to probe.
url: String,
},
/// Ask the agent to shut down.
Shutdown,
/// Export one layer of an image.
///
/// NOTE(review): presumably streamed back as [`AgentResponse::DataChunk`]s.
ExportLayer {
/// Digest of the image that owns the layer.
image_digest: String,
/// Position of the layer in the image (NOTE(review): presumably zero-based).
layer_index: usize,
},
/// Execute a command directly in the VM.
VmExec {
/// Program and arguments.
command: Vec<String>,
/// Environment `(name, value)` pairs; empty when absent.
#[serde(default)]
env: Vec<(String, String)>,
/// Optional working directory.
workdir: Option<String>,
/// Optional timeout in milliseconds (per the field name).
#[serde(default)]
timeout_ms: Option<u64>,
/// Defaults to `false` when absent.
#[serde(default)]
interactive: bool,
/// Whether to allocate a TTY; defaults to `false` when absent.
#[serde(default)]
tty: bool,
/// Defaults to `false` when absent.
#[serde(default)]
background: bool,
},
/// Run a command inside a container built from an image.
Run {
/// Image reference to run.
image: String,
/// Program and arguments.
command: Vec<String>,
/// Environment `(name, value)` pairs; empty when absent.
#[serde(default)]
env: Vec<(String, String)>,
/// Optional working directory.
workdir: Option<String>,
/// Optional user to run as; omitted from the JSON when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
user: Option<String>,
/// Mounts as `(String, String, bool)` tuples; empty when absent.
///
/// NOTE(review): presumably `(source, target, read_only)`; confirm order and meaning.
#[serde(default)]
mounts: Vec<(String, String, bool)>,
/// Optional timeout in milliseconds (per the field name).
#[serde(default)]
timeout_ms: Option<u64>,
/// Defaults to `false` when absent.
#[serde(default)]
interactive: bool,
/// Whether to allocate a TTY; defaults to `false` when absent.
#[serde(default)]
tty: bool,
/// Defaults to `false` when absent.
///
/// NOTE(review): how this differs from `background` is not visible here.
#[serde(default)]
detached: bool,
/// Optional id of an overlay to reuse across runs; omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
persistent_overlay_id: Option<String>,
/// Defaults to `false` when absent.
#[serde(default)]
background: bool,
},
/// Bytes to forward to a running command's stdin (base64 on the wire).
Stdin {
/// Raw input bytes.
#[serde(with = "base64_bytes")]
data: Vec<u8>,
},
/// Resize the terminal of an interactive command.
Resize {
/// Terminal width in columns.
cols: u16,
/// Terminal height in rows.
rows: u16,
},
/// Write a whole file in a single request.
FileWrite {
/// Destination path.
path: String,
/// File contents (base64 on the wire).
#[serde(with = "base64_bytes")]
data: Vec<u8>,
/// Optional permission bits, e.g. `0o600`.
#[serde(default)]
mode: Option<u32>,
},
/// Start a chunked file write; followed by [`AgentRequest::FileWriteChunk`]s.
FileWriteBegin {
/// Destination path.
path: String,
/// Optional permission bits, e.g. `0o600`.
#[serde(default)]
mode: Option<u32>,
/// Total number of bytes the chunks will carry.
total_size: u64,
},
/// One chunk of a file write started by [`AgentRequest::FileWriteBegin`].
FileWriteChunk {
/// Chunk bytes (base64 on the wire).
#[serde(with = "base64_bytes")]
data: Vec<u8>,
/// NOTE(review): presumably `true` on the final chunk.
done: bool,
},
/// Read a file from the guest.
FileRead {
/// Path of the file to read.
path: String,
},
}
/// A response, or a streamed event, sent back by the agent.
///
/// It is serialized as a JSON object whose `"status"` field holds the
/// snake_case variant name, e.g. `{"status":"pong","version":1}`.
/// `Vec<u8>` fields travel as base64 strings via [`base64_bytes`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum AgentResponse {
/// Success, with an optional JSON payload.
Ok {
/// Optional payload; omitted from the JSON when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
data: Option<serde_json::Value>,
},
/// Reply to [`AgentRequest::Ping`] (presumably).
Pong {
/// Protocol version of the responder, presumably [`PROTOCOL_VERSION`].
version: u32,
},
/// Progress update for a long-running operation.
Progress {
/// Human-readable status, e.g. `Pulling layer 1/3`.
message: String,
/// Optional completion percentage (NOTE(review): presumably 0–100).
#[serde(default, skip_serializing_if = "Option::is_none")]
percent: Option<u8>,
/// Optional layer identifier, e.g. `sha256:abc123`.
#[serde(default, skip_serializing_if = "Option::is_none")]
layer: Option<String>,
},
/// Failure.
Error {
/// Human-readable error message.
message: String,
/// Machine-readable code, typically one of [`error_codes`]; omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
code: Option<String>,
},
/// A command finished; carries its exit status and buffered output.
///
/// NOTE(review): presumably used for non-streaming `VmExec`/`Run`; confirm.
Completed {
/// Process exit code.
exit_code: i32,
/// Captured stdout (base64 on the wire).
#[serde(with = "base64_bytes")]
stdout: Vec<u8>,
/// Captured stderr (base64 on the wire).
#[serde(with = "base64_bytes")]
stderr: Vec<u8>,
},
/// NOTE(review): presumably signals that a streaming command has started.
Started,
/// A chunk of stdout from a running command (base64 on the wire).
Stdout {
/// Output bytes.
#[serde(with = "base64_bytes")]
data: Vec<u8>,
},
/// A chunk of stderr from a running command (base64 on the wire).
Stderr {
/// Output bytes.
#[serde(with = "base64_bytes")]
data: Vec<u8>,
},
/// A streaming command exited.
Exited {
/// Process exit code.
exit_code: i32,
},
/// One chunk of a streamed binary payload.
///
/// NOTE(review): presumably used by `ExportLayer` / `FileRead`.
DataChunk {
/// Chunk bytes (base64 on the wire).
#[serde(with = "base64_bytes")]
data: Vec<u8>,
/// NOTE(review): presumably `true` on the final chunk.
done: bool,
},
}
/// Machine-readable error codes carried in the `code` field of [`AgentResponse::Error`].
///
/// Every constant's value is identical to its name. NOTE(review): the
/// descriptions below are inferred from the names; confirm them against the
/// sites that emit each code.
pub mod error_codes {
/// The request was malformed or unsupported.
pub const INVALID_REQUEST: &str = "INVALID_REQUEST";
/// A referenced resource was not found.
pub const NOT_FOUND: &str = "NOT_FOUND";
/// Unexpected internal failure.
pub const INTERNAL_ERROR: &str = "INTERNAL_ERROR";
/// Pulling an image failed.
pub const PULL_FAILED: &str = "PULL_FAILED";
/// Querying image metadata failed.
pub const QUERY_FAILED: &str = "QUERY_FAILED";
/// Running a command failed.
pub const RUN_FAILED: &str = "RUN_FAILED";
/// Executing a command failed.
pub const EXEC_FAILED: &str = "EXEC_FAILED";
/// Spawning a process failed.
pub const SPAWN_FAILED: &str = "SPAWN_FAILED";
/// A mount operation failed.
pub const MOUNT_FAILED: &str = "MOUNT_FAILED";
/// A file read or write failed.
pub const FILE_IO_FAILED: &str = "FILE_IO_FAILED";
/// Preparing an overlay failed.
pub const OVERLAY_FAILED: &str = "OVERLAY_FAILED";
/// A cleanup step failed.
pub const CLEANUP_FAILED: &str = "CLEANUP_FAILED";
/// Formatting storage failed.
pub const FORMAT_FAILED: &str = "FORMAT_FAILED";
/// Reading status failed.
pub const STATUS_FAILED: &str = "STATUS_FAILED";
/// Listing failed.
pub const LIST_FAILED: &str = "LIST_FAILED";
/// Garbage collection failed.
pub const GC_FAILED: &str = "GC_FAILED";
/// Creating a resource failed.
pub const CREATE_FAILED: &str = "CREATE_FAILED";
/// Starting a resource failed.
pub const START_FAILED: &str = "START_FAILED";
/// Stopping a resource failed.
pub const STOP_FAILED: &str = "STOP_FAILED";
/// Deleting a resource failed.
pub const DELETE_FAILED: &str = "DELETE_FAILED";
/// Exporting data failed.
pub const EXPORT_FAILED: &str = "EXPORT_FAILED";
/// Serializing a response failed; emitted by `AgentResponse::ok_with_data`.
pub const SERIALIZATION_ERROR: &str = "SERIALIZATION_ERROR";
/// A message exceeded a size limit (NOTE(review): presumably `MAX_FRAME_SIZE`).
pub const MESSAGE_TOO_LARGE: &str = "MESSAGE_TOO_LARGE";
/// Waiting for completion failed.
pub const WAIT_FAILED: &str = "WAIT_FAILED";
}
impl AgentResponse {
    /// Builds an [`AgentResponse::Error`] from a message and an error code.
    ///
    /// `code` is normally one of the [`error_codes`] constants.
    pub fn error(message: impl Into<String>, code: &str) -> Self {
        Self::Error {
            message: message.into(),
            code: Some(code.to_owned()),
        }
    }

    /// Builds an [`AgentResponse::Error`] whose message is the `Display` text of `err`.
    pub fn from_err<E: std::fmt::Display>(err: E, code: &str) -> Self {
        Self::error(err.to_string(), code)
    }

    /// Builds an [`AgentResponse::Ok`] carrying `data` unchanged.
    pub fn ok(data: Option<serde_json::Value>) -> Self {
        Self::Ok { data }
    }

    /// Serializes `data` to JSON and wraps it in [`AgentResponse::Ok`].
    ///
    /// If serialization fails, this returns an [`AgentResponse::Error`] with
    /// code [`error_codes::SERIALIZATION_ERROR`] instead of panicking.
    pub fn ok_with_data<T: serde::Serialize>(data: T) -> Self {
        serde_json::to_value(data)
            .map(|value| Self::ok(Some(value)))
            .unwrap_or_else(|e| {
                Self::error(
                    format!("failed to serialize response: {}", e),
                    error_codes::SERIALIZATION_ERROR,
                )
            })
    }

    /// Converts a `Result` into a response.
    ///
    /// `Ok(data)` goes through [`AgentResponse::ok_with_data`], and `Err(e)`
    /// through [`AgentResponse::from_err`] tagged with `error_code`.
    pub fn from_result<T, E>(result: Result<T, E>, error_code: &str) -> Self
    where
        T: serde::Serialize,
        E: std::fmt::Display,
    {
        result.map_or_else(|e| Self::from_err(e, error_code), Self::ok_with_data)
    }
}
/// Metadata describing an image held by the agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageInfo {
/// Image reference the image is known by.
pub reference: String,
/// Image digest.
pub digest: String,
/// Image size (NOTE(review): presumably in bytes).
pub size: u64,
/// Optional creation timestamp string.
pub created: Option<String>,
/// CPU architecture from the image config.
pub architecture: String,
/// Operating system from the image config.
pub os: String,
/// Number of layers.
pub layer_count: usize,
/// Layer identifiers (NOTE(review): presumably digests, base layer first).
pub layers: Vec<String>,
/// Image entrypoint; empty when absent.
#[serde(default)]
pub entrypoint: Vec<String>,
/// Default command; empty when absent.
#[serde(default)]
pub cmd: Vec<String>,
/// Environment entries (NOTE(review): presumably `KEY=VALUE`); empty when absent.
#[serde(default)]
pub env: Vec<String>,
/// Default working directory; `None` when absent.
#[serde(default)]
pub workdir: Option<String>,
/// Default user; `None` when absent.
#[serde(default)]
pub user: Option<String>,
}
/// Paths of an overlay prepared for a workload.
///
/// NOTE(review): the field names follow overlayfs terminology (merged root,
/// `upperdir`, `workdir`); confirm against the agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OverlayInfo {
/// Path of the combined root filesystem.
pub rootfs_path: String,
/// Path of the writable upper layer.
pub upper_path: String,
/// Path of the overlay work directory.
pub work_path: String,
}
/// Snapshot of the agent's storage state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageStatus {
/// Whether storage is usable.
pub ready: bool,
/// Total capacity in bytes.
pub total_bytes: u64,
/// Bytes currently in use.
pub used_bytes: u64,
/// Number of stored layers.
pub layer_count: usize,
/// Number of stored images.
pub image_count: usize,
}
/// Credentials used to authenticate against an image registry (see `AgentRequest::Pull`).
///
/// `Debug` is implemented by hand so the password never appears in `{:?}`
/// output. That output includes the derived `Debug` of any request that
/// embeds these credentials.
#[derive(Clone, Serialize, Deserialize)]
pub struct RegistryAuth {
    /// Registry username.
    pub username: String,
    /// Registry password or token; redacted from `Debug` output.
    pub password: String,
}

impl std::fmt::Debug for RegistryAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RegistryAuth")
            .field("username", &self.username)
            // Never leak the secret into logs or panic messages.
            .field("password", &"<redacted>")
            .finish()
    }
}
/// Message sent from the host to the guest.
///
/// It is serialized as a JSON object tagged by a `"type"` field holding the
/// snake_case variant name. NOTE(review): presumably exchanged on
/// [`ports::WORKLOAD_CONTROL`]. NOTE(review): the derived `Debug` prints
/// `Auth::token` verbatim, so avoid `{:?}`-logging this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum HostMessage {
/// Authenticate the host to the guest.
Auth {
/// Shared secret presented by the host.
token: String,
/// Host protocol version, presumably compared with [`PROTOCOL_VERSION`].
protocol_version: u32,
},
/// Run a command.
Run {
/// Host-chosen id (NOTE(review): presumably echoed in [`GuestMessage`] events).
request_id: u64,
/// Program and arguments.
command: Vec<String>,
/// Environment `(name, value)` pairs; required on the wire (no serde default).
env: Vec<(String, String)>,
/// Optional working directory.
workdir: Option<String>,
},
/// Execute a command, optionally with a TTY.
Exec {
/// Host-chosen request id.
request_id: u64,
/// Program and arguments.
command: Vec<String>,
/// Whether to allocate a TTY.
tty: bool,
},
/// Deliver a signal to the process started by `request_id`.
Signal {
/// Id of the target request.
request_id: u64,
/// Signal number (NOTE(review): presumably a POSIX signal number).
signal: i32,
},
/// Ask the guest to stop.
Stop {
/// Grace period in milliseconds (per the field name).
timeout_ms: u64,
},
}
/// Message sent from the guest to the host.
///
/// It is serialized as a JSON object tagged by a `"type"` field holding the
/// snake_case variant name. `Vec<u8>` fields travel as base64 strings.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum GuestMessage {
/// The host's [`HostMessage::Auth`] was accepted (presumably).
AuthOk,
/// The host's [`HostMessage::Auth`] was rejected (presumably).
AuthFailed,
/// The guest is ready to accept requests.
Ready,
/// The process for `request_id` has started.
Started {
/// Id of the originating request.
request_id: u64,
},
/// A chunk of stdout for `request_id`.
Stdout {
/// Id of the originating request.
request_id: u64,
/// Output bytes (base64 on the wire).
#[serde(with = "base64_bytes")]
data: Vec<u8>,
/// NOTE(review): presumably set when output was cut short; confirm.
truncated: bool,
},
/// A chunk of stderr for `request_id`.
Stderr {
/// Id of the originating request.
request_id: u64,
/// Output bytes (base64 on the wire).
#[serde(with = "base64_bytes")]
data: Vec<u8>,
/// NOTE(review): presumably set when output was cut short; confirm.
truncated: bool,
},
/// The process for `request_id` exited.
Exit {
/// Id of the originating request.
request_id: u64,
/// Exit code.
code: i32,
/// Free-form description of why the process ended.
reason: String,
},
/// An error, optionally tied to a specific request.
Error {
/// Id of the related request, or `None` for connection-level errors.
request_id: Option<u64>,
/// Human-readable error message.
message: String,
},
}
/// Wrapper that adds optional tracing metadata around a message body.
///
/// `body` is flattened, so on the wire the envelope is the body's own JSON
/// object plus an optional `"trace_id"` key. A bare body without `trace_id`
/// therefore also deserializes as an envelope, with `trace_id == None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Envelope<T> {
/// Optional trace id; omitted when `None` and defaulted to `None` when absent.
#[serde(skip_serializing_if = "Option::is_none", default)]
pub trace_id: Option<String>,
/// The wrapped message; its fields are inlined into the envelope's JSON object.
#[serde(flatten)]
pub body: T,
}
impl<T> Envelope<T> {
pub fn new(body: T) -> Self {
Self {
trace_id: None,
body,
}
}
pub fn with_trace_id(body: T, trace_id: Option<String>) -> Self {
Self { trace_id, body }
}
}
pub fn encode_message<T: Serialize>(msg: &T) -> Result<Vec<u8>, serde_json::Error> {
let json = serde_json::to_vec(msg)?;
let len = json.len() as u32;
let mut buf = Vec::with_capacity(4 + json.len());
buf.extend_from_slice(&len.to_be_bytes());
buf.extend_from_slice(&json);
Ok(buf)
}
/// Decodes one length-prefixed JSON frame from the start of `data`.
///
/// The frame is a big-endian `u32` payload length followed by that many bytes
/// of JSON. Any bytes after the frame are ignored.
///
/// # Errors
///
/// - [`DecodeError::TooShort`] if `data` holds fewer than 4 bytes.
/// - [`DecodeError::TooLarge`] if the declared length exceeds [`MAX_FRAME_SIZE`].
/// - [`DecodeError::Incomplete`] if fewer payload bytes than declared follow the header.
/// - [`DecodeError::Json`] if the payload is not valid JSON for `T`.
pub fn decode_message<T: for<'de> Deserialize<'de>>(data: &[u8]) -> Result<T, DecodeError> {
    if data.len() < 4 {
        return Err(DecodeError::TooShort);
    }
    let (header, payload) = data.split_at(4);

    let declared = u32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize;
    if declared > MAX_FRAME_SIZE as usize {
        return Err(DecodeError::TooLarge(declared));
    }

    let body = payload.get(..declared).ok_or(DecodeError::Incomplete {
        expected: declared,
        got: payload.len(),
    })?;
    serde_json::from_slice(body).map_err(DecodeError::Json)
}
/// Errors returned by [`decode_message`].
#[derive(Debug)]
pub enum DecodeError {
/// Fewer than 4 bytes were supplied, so the length header is missing.
TooShort,
/// The header declares a payload larger than [`MAX_FRAME_SIZE`]; holds the declared length.
TooLarge(usize),
/// The buffer ends before the declared payload does.
Incomplete {
/// Payload length declared by the header.
expected: usize,
/// Payload bytes actually present after the header.
got: usize,
},
/// The payload was not valid JSON for the requested type.
Json(serde_json::Error),
}
impl std::fmt::Display for DecodeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DecodeError::TooShort => write!(f, "data too short for length header"),
DecodeError::TooLarge(size) => write!(f, "frame too large: {} bytes", size),
DecodeError::Incomplete { expected, got } => {
write!(
f,
"incomplete frame: expected {} bytes, got {}",
expected, got
)
}
DecodeError::Json(e) => write!(f, "JSON decode error: {}", e),
}
}
}
/// Marks [`DecodeError`] as a standard error type.
///
/// `source()` keeps its default of `None`. The `Display` impl already embeds
/// the inner `serde_json::Error` text for `DecodeError::Json`.
/// NOTE(review): if `source()` is added, drop that text from `Display`.
/// Otherwise error-chain reporters will print it twice.
impl std::error::Error for DecodeError {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_decode_roundtrip() {
        let req = AgentRequest::Pull {
            image: "alpine:latest".to_string(),
            oci_platform: Some("linux/arm64".to_string()),
            auth: None,
        };
        let encoded = encode_message(&req).unwrap();
        let decoded: AgentRequest = decode_message(&encoded).unwrap();
        let AgentRequest::Pull {
            image,
            oci_platform,
            auth,
        } = decoded
        else {
            panic!("expected Pull variant, got {:?}", decoded);
        };
        assert_eq!(image, "alpine:latest");
        assert_eq!(oci_platform, Some("linux/arm64".to_string()));
        assert!(auth.is_none());
    }

    #[test]
    fn test_encode_decode_with_auth() {
        let req = AgentRequest::Pull {
            image: "ghcr.io/owner/repo:latest".to_string(),
            oci_platform: None,
            auth: Some(RegistryAuth {
                username: "testuser".to_string(),
                password: "testpass".to_string(),
            }),
        };
        let encoded = encode_message(&req).unwrap();
        let decoded: AgentRequest = decode_message(&encoded).unwrap();
        let AgentRequest::Pull {
            image,
            oci_platform,
            auth,
        } = decoded
        else {
            panic!("expected Pull variant, got {:?}", decoded);
        };
        assert_eq!(image, "ghcr.io/owner/repo:latest");
        assert!(oci_platform.is_none());
        let auth = auth.expect("auth should be Some");
        assert_eq!(auth.username, "testuser");
        assert_eq!(auth.password, "testpass");
    }

    #[test]
    fn test_encode_header_matches_payload_length() {
        let bytes = encode_message(&AgentRequest::Ping).unwrap();
        let header = [bytes[0], bytes[1], bytes[2], bytes[3]];
        assert_eq!(u32::from_be_bytes(header) as usize, bytes.len() - 4);
    }

    #[test]
    fn test_decode_too_short() {
        let data = [0u8; 2];
        let result: Result<AgentRequest, _> = decode_message(&data);
        assert!(matches!(result, Err(DecodeError::TooShort)));
    }

    #[test]
    fn test_decode_incomplete() {
        // The header claims a 100-byte payload, but only 2 bytes follow it.
        let mut data = vec![0, 0, 0, 100];
        data.extend_from_slice(b"{}");
        let result: Result<AgentRequest, _> = decode_message(&data);
        assert!(matches!(
            result,
            Err(DecodeError::Incomplete {
                expected: 100,
                got: 2
            })
        ));
    }

    #[test]
    fn test_decode_too_large() {
        // A bare header is enough: the size check runs before any payload is read.
        let oversized = MAX_FRAME_SIZE + 1;
        let data = oversized.to_be_bytes();
        let result: Result<AgentRequest, _> = decode_message(&data);
        assert!(matches!(result, Err(DecodeError::TooLarge(n)) if n == oversized as usize));
    }

    #[test]
    fn test_decode_invalid_json_body() {
        // A complete frame whose JSON object lacks the required `method` tag.
        let mut data = vec![0, 0, 0, 2];
        data.extend_from_slice(b"{}");
        let result: Result<AgentRequest, _> = decode_message(&data);
        assert!(matches!(result, Err(DecodeError::Json(_))));
    }

    #[test]
    fn test_decode_error_display() {
        assert_eq!(
            DecodeError::TooShort.to_string(),
            "data too short for length header"
        );
        assert_eq!(
            DecodeError::TooLarge(7).to_string(),
            "frame too large: 7 bytes"
        );
        assert_eq!(
            DecodeError::Incomplete {
                expected: 10,
                got: 3
            }
            .to_string(),
            "incomplete frame: expected 10 bytes, got 3"
        );
    }

    #[test]
    fn test_agent_request_serialization() {
        let req = AgentRequest::Ping;
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("ping"));
        let req = AgentRequest::PrepareOverlay {
            image: "ubuntu:22.04".to_string(),
            workload_id: "wl-123".to_string(),
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("prepare_overlay"));
    }

    #[test]
    fn test_agent_response_serialization() {
        let resp = AgentResponse::Pong {
            version: PROTOCOL_VERSION,
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("pong"));
        let resp = AgentResponse::Progress {
            message: "Pulling layer 1/3".to_string(),
            percent: Some(33),
            layer: Some("sha256:abc123".to_string()),
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("progress"));
    }

    #[test]
    fn test_agent_response_helpers() {
        let AgentResponse::Error { message, code } =
            AgentResponse::error("boom", error_codes::NOT_FOUND)
        else {
            panic!("expected Error variant");
        };
        assert_eq!(message, "boom");
        assert_eq!(code.as_deref(), Some(error_codes::NOT_FOUND));

        let AgentResponse::Ok { data } =
            AgentResponse::from_result::<u32, String>(Ok(5), error_codes::INTERNAL_ERROR)
        else {
            panic!("expected Ok variant");
        };
        assert_eq!(data, Some(serde_json::Value::from(5u32)));

        let AgentResponse::Error { message, code } = AgentResponse::from_result::<u32, String>(
            Err("bad".to_string()),
            error_codes::RUN_FAILED,
        ) else {
            panic!("expected Error variant");
        };
        assert_eq!(message, "bad");
        assert_eq!(code.as_deref(), Some(error_codes::RUN_FAILED));
    }

    #[test]
    fn file_write_begin_roundtrips() {
        let req = AgentRequest::FileWriteBegin {
            path: "/tmp/target".into(),
            mode: Some(0o600),
            total_size: 123_456_789,
        };
        let bytes = encode_message(&req).unwrap();
        let back: AgentRequest = decode_message(&bytes).unwrap();
        match back {
            AgentRequest::FileWriteBegin {
                path,
                mode,
                total_size,
            } => {
                assert_eq!(path, "/tmp/target");
                assert_eq!(mode, Some(0o600));
                assert_eq!(total_size, 123_456_789);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn file_write_chunk_roundtrips_binary_data() {
        let payload: Vec<u8> = (0u8..=255).collect();
        let req = AgentRequest::FileWriteChunk {
            data: payload.clone(),
            done: true,
        };
        let bytes = encode_message(&req).unwrap();
        let back: AgentRequest = decode_message(&bytes).unwrap();
        match back {
            AgentRequest::FileWriteChunk { data, done } => {
                assert_eq!(data, payload);
                assert!(done);
            }
            _ => panic!("wrong variant"),
        }
    }

    /// Upper-bound estimate of a frame carrying `raw` payload bytes as base64.
    fn estimated_frame_bytes(raw: u64) -> u64 {
        // Base64 turns every 3 input bytes (rounded up) into 4 output characters.
        let encoded_bytes = raw.div_ceil(3) * 4;
        // Generous allowance for JSON keys, the tag and the remaining fields.
        let json_overhead = 256u64;
        encoded_bytes + json_overhead
    }

    #[test]
    fn file_write_size_constants_are_frame_safe() {
        let chunk_bytes = FILE_WRITE_CHUNK_SIZE as u64;
        let total = estimated_frame_bytes(chunk_bytes);
        assert!(
            total < MAX_FRAME_SIZE as u64,
            "FILE_WRITE_CHUNK_SIZE of {} bytes would produce a frame \
             of ~{} bytes which exceeds MAX_FRAME_SIZE of {}",
            chunk_bytes,
            total,
            MAX_FRAME_SIZE
        );
        assert!(FILE_WRITE_SINGLE_SHOT_MAX <= FILE_WRITE_CHUNK_SIZE);
    }

    #[test]
    fn layer_chunk_size_is_frame_safe() {
        let total = estimated_frame_bytes(LAYER_CHUNK_SIZE as u64);
        assert!(
            total < MAX_FRAME_SIZE as u64,
            "LAYER_CHUNK_SIZE of {} bytes would produce a frame of ~{} bytes \
             which exceeds MAX_FRAME_SIZE of {}",
            LAYER_CHUNK_SIZE,
            total,
            MAX_FRAME_SIZE
        );
    }

    #[test]
    fn test_ports_constants() {
        assert_eq!(ports::WORKLOAD_CONTROL, 5000);
        assert_eq!(ports::WORKLOAD_LOGS, 5001);
        assert_eq!(ports::AGENT_CONTROL, 6000);
        assert_eq!(ports::SSH_AGENT, 6001);
        assert_eq!(ports::DNS_FILTER, 6002);
    }

    #[test]
    fn test_cid_constants() {
        assert_eq!(cid::HOST, 2);
        assert_eq!(cid::GUEST, 3);
        assert_eq!(cid::ANY, u32::MAX);
    }

    #[test]
    fn test_envelope_serialization_with_trace_id() {
        let req = AgentRequest::Ping;
        let envelope = Envelope::with_trace_id(&req, Some("abc123".to_string()));
        let json = serde_json::to_string(&envelope).unwrap();
        assert!(json.contains("\"trace_id\":\"abc123\""));
        assert!(json.contains("\"method\":\"ping\""));
        let parsed: Envelope<AgentRequest> = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.trace_id.as_deref(), Some("abc123"));
        assert!(matches!(parsed.body, AgentRequest::Ping));
    }

    #[test]
    fn test_envelope_without_trace_id() {
        let req = AgentRequest::Ping;
        let envelope = Envelope::new(&req);
        let json = serde_json::to_string(&envelope).unwrap();
        assert!(!json.contains("trace_id"));
        assert!(json.contains("\"method\":\"ping\""));
    }

    #[test]
    fn test_envelope_backward_compat_bare_request() {
        // A message from a peer that predates envelopes must still parse both
        // ways: as a bare request, and as an envelope with no trace id.
        let bare_json = r#"{"method":"ping"}"#;

        let bare: AgentRequest =
            serde_json::from_str(bare_json).expect("bare request should parse");
        assert!(matches!(bare, AgentRequest::Ping));

        let envelope: Envelope<AgentRequest> =
            serde_json::from_str(bare_json).expect("bare request should parse as an Envelope");
        assert!(envelope.trace_id.is_none());
        assert!(matches!(envelope.body, AgentRequest::Ping));
    }
}