use std::path::PathBuf;
use serde::Serialize;
#[cfg(target_os = "linux")]
use std::time::Duration;
#[cfg(target_os = "linux")]
use bytes::Bytes;
#[cfg(target_os = "linux")]
use http_body_util::Full;
#[cfg(target_os = "linux")]
use hyper::body::Incoming;
#[cfg(target_os = "linux")]
use hyper::client::conn::http1;
#[cfg(target_os = "linux")]
use hyper::{Method, Request, Response, StatusCode};
#[cfg(target_os = "linux")]
use hyper_util::rt::TokioIo;
#[cfg(target_os = "linux")]
use tokio::net::UnixStream;
#[cfg(target_os = "linux")]
use tokio::time::timeout;
#[cfg(target_os = "linux")]
use tracing::instrument;
#[cfg(target_os = "linux")]
use cellos_core::CellosError;
/// Total time `wait_for_ready` keeps polling the API socket before it gives up.
#[cfg(target_os = "linux")]
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Upper bound applied to one API request (connect, handshake and send) in `send_json`.
#[cfg(target_os = "linux")]
const REQUEST_TIMEOUT: Duration = Duration::from_secs(5);
/// Pause between two consecutive socket probes in `wait_for_ready`.
#[cfg(target_os = "linux")]
const SOCKET_POLL_INTERVAL: Duration = Duration::from_millis(50);
/// Handle for talking to a Firecracker API server through a Unix domain socket.
///
/// The struct only stores the socket path. All behaviour lives in an `impl` block that is
/// compiled on Linux only, so on other targets the type exists but has no methods.
#[derive(Clone, Debug)]
pub struct FirecrackerApiClient {
// The field is read exclusively by the Linux-only `impl`; the `allow` keeps non-Linux builds
// free of an unused-field warning.
#[allow(dead_code)]
/// Filesystem path of the API socket.
socket_path: PathBuf,
}
#[cfg(target_os = "linux")]
impl FirecrackerApiClient {
pub fn new(socket_path: impl Into<PathBuf>) -> Self {
Self {
socket_path: socket_path.into(),
}
}
#[instrument(skip(self), fields(socket = %self.socket_path.display()))]
pub async fn wait_for_ready(&self) -> Result<(), CellosError> {
let deadline = tokio::time::Instant::now() + CONNECT_TIMEOUT;
loop {
if self.socket_path.exists() {
if UnixStream::connect(&self.socket_path).await.is_ok() {
tracing::debug!("firecracker socket ready");
return Ok(());
}
}
if tokio::time::Instant::now() >= deadline {
return Err(CellosError::Host(format!(
"timed out waiting for Firecracker socket at {}",
self.socket_path.display()
)));
}
tokio::time::sleep(SOCKET_POLL_INTERVAL).await;
}
}
#[instrument(skip(self, body), fields(socket = %self.socket_path.display(), path = path))]
pub async fn put<T: Serialize>(&self, path: &str, body: &T) -> Result<StatusCode, CellosError> {
self.send_json(Method::PUT, path, body).await
}
#[instrument(skip(self, body), fields(socket = %self.socket_path.display(), path = path))]
pub async fn patch<T: Serialize>(
&self,
path: &str,
body: &T,
) -> Result<StatusCode, CellosError> {
self.send_json(Method::PATCH, path, body).await
}
async fn send_json<T: Serialize>(
&self,
method: Method,
path: &str,
body: &T,
) -> Result<StatusCode, CellosError> {
let body_bytes = serde_json::to_vec(body)
.map_err(|e| CellosError::Host(format!("serialize firecracker request: {e}")))?;
let req = Request::builder()
.method(method)
.uri(format!("http://localhost{path}"))
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.header("Host", "localhost")
.body(Full::new(Bytes::from(body_bytes)))
.map_err(|e| CellosError::Host(format!("build firecracker request: {e}")))?;
let status = timeout(REQUEST_TIMEOUT, self.send_request(req))
.await
.map_err(|_| {
CellosError::Host(format!(
"firecracker API request to {path} timed out after {}s",
REQUEST_TIMEOUT.as_secs()
))
})??;
Ok(status)
}
async fn send_request(&self, req: Request<Full<Bytes>>) -> Result<StatusCode, CellosError> {
let stream = UnixStream::connect(&self.socket_path).await.map_err(|e| {
CellosError::Host(format!(
"connect to firecracker socket {}: {e}",
self.socket_path.display()
))
})?;
let io = TokioIo::new(stream);
let (mut sender, conn) = http1::handshake::<_, Full<Bytes>>(io)
.await
.map_err(|e| CellosError::Host(format!("firecracker HTTP handshake: {e}")))?;
tokio::spawn(async move {
if let Err(e) = conn.await {
tracing::debug!(error = %e, "firecracker connection task ended");
}
});
let resp: Response<Incoming> = sender
.send_request(req)
.await
.map_err(|e| CellosError::Host(format!("firecracker HTTP send: {e}")))?;
Ok(resp.status())
}
}
/// JSON-serializable machine configuration: vCPU count, memory size and dirty-page tracking.
///
/// Presumably the request body for Firecracker's machine-configuration endpoint; the endpoint
/// is chosen by the caller — TODO confirm.
#[derive(Debug, Serialize)]
pub struct MachineConfig {
/// Number of virtual CPUs.
pub vcpu_count: u32,
/// Guest memory size; going by the field name the unit is MiB.
pub mem_size_mib: u32,
/// Dirty-page tracking flag.
// NOTE(review): `#[serde(default)]` only influences deserialization. This type derives only
// `Serialize`, so the attribute currently has no effect and the field is always emitted.
#[serde(default)]
pub track_dirty_pages: bool,
}
/// JSON-serializable description of what the VM boots: a kernel image plus optional arguments.
#[derive(Debug, Serialize)]
pub struct BootSource {
/// Path of the kernel image.
pub kernel_image_path: String,
/// Kernel command line; the key is left out of the JSON entirely when this is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub boot_args: Option<String>,
}
/// JSON-serializable description of a block device attached to the VM.
#[derive(Debug, Serialize)]
pub struct Drive {
/// Identifier of the drive.
pub drive_id: String,
/// Path of the backing file or device on the host.
pub path_on_host: String,
/// Whether this drive is the root device.
pub is_root_device: bool,
/// Whether the drive is exposed read-only.
pub is_read_only: bool,
}
/// JSON-serializable description of a vsock device.
#[derive(Debug, Serialize)]
pub struct VsockDevice {
/// Context identifier (CID) assigned to the guest.
pub guest_cid: u32,
/// Path of the Unix domain socket that backs the device (presumably on the host — confirm).
pub uds_path: String,
}
/// JSON-serializable description of a guest network interface.
#[derive(Debug, Serialize)]
pub struct NetworkInterface {
/// Identifier of the interface.
pub iface_id: String,
/// MAC address given to the guest side, as a string.
pub guest_mac: String,
/// Name of the host-side network device (presumably a TAP device — confirm).
pub host_dev_name: String,
}
/// JSON-serializable wrapper that carries a single action in its `action_type` field.
#[derive(Debug, Serialize)]
pub struct InstanceAction {
/// The action to perform.
pub action_type: InstanceActionType,
}
/// Actions that can be placed in an [`InstanceAction`].
///
/// Each variant serializes to the JSON string given in its `rename`, which here is the same as
/// the variant name.
#[derive(Debug, Serialize)]
pub enum InstanceActionType {
/// Serialized as `"InstanceStart"`.
#[serde(rename = "InstanceStart")]
InstanceStart,
/// Serialized as `"SendCtrlAltDel"`.
#[serde(rename = "SendCtrlAltDel")]
SendCtrlAltDel,
}
/// JSON-serializable request carrying the desired VM state in its `state` field.
///
/// Presumably sent with `patch`; the endpoint is chosen by the caller — TODO confirm.
#[derive(Debug, Serialize)]
pub struct VmStatePatch {
/// Target state.
pub state: VmState,
}
/// VM states that can be requested through [`VmStatePatch`].
///
/// Variants serialize to their names as JSON strings (`"Paused"`, `"Resumed"`).
#[derive(Debug, Serialize)]
pub enum VmState {
/// Serialized as `"Paused"`.
Paused,
/// Serialized as `"Resumed"`.
Resumed,
}
/// JSON-serializable request describing a snapshot to create.
#[derive(Debug, Serialize)]
pub struct SnapshotCreate {
/// Kind of snapshot to take.
pub snapshot_type: SnapshotType,
/// Path of the snapshot file.
pub snapshot_path: String,
/// Path of the guest-memory file.
pub mem_file_path: String,
}
/// Snapshot kinds usable in [`SnapshotCreate`].
///
/// Variants serialize to their names as JSON strings (`"Full"`, `"Diff"`).
#[derive(Debug, Serialize)]
pub enum SnapshotType {
/// Serialized as `"Full"`.
Full,
/// Serialized as `"Diff"`.
Diff,
}
/// JSON-serializable request describing a snapshot to load.
// NOTE(review): the two `#[serde(default)]` attributes below only influence deserialization.
// This type derives only `Serialize`, so they currently have no effect and both flags are
// always emitted.
#[derive(Debug, Serialize)]
pub struct SnapshotLoad {
/// Path of the snapshot file.
pub snapshot_path: String,
/// Where the guest memory is loaded from.
pub mem_backend: MemBackend,
/// Flag for diff snapshots.
#[serde(default)]
pub enable_diff_snapshots: bool,
/// Flag for resuming the VM (presumably right after the load completes — confirm).
#[serde(default)]
pub resume_vm: bool,
}
/// JSON-serializable description of the guest-memory source used by [`SnapshotLoad`].
#[derive(Debug, Serialize)]
pub struct MemBackend {
/// Kind of backend.
pub backend_type: MemBackendType,
/// Path associated with the backend; how it is interpreted presumably depends on
/// `backend_type` — confirm.
pub backend_path: String,
}
/// Backend kinds usable in [`MemBackend`].
///
/// Variants serialize to their names as JSON strings (`"File"`, `"Uffd"`).
#[derive(Debug, Serialize)]
pub enum MemBackendType {
/// Serialized as `"File"`.
File,
/// Serialized as `"Uffd"` (presumably userfaultfd — confirm).
Uffd,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `value` as a compact JSON string; panics when serialization fails.
    fn to_json(value: &impl serde::Serialize) -> String {
        serde_json::to_string(value).unwrap()
    }

    /// The numeric machine-config fields appear under their own names.
    #[test]
    fn machine_config_serializes() {
        let rendered = to_json(&MachineConfig {
            vcpu_count: 1,
            mem_size_mib: 128,
            track_dirty_pages: false,
        });
        assert!(rendered.contains("\"vcpu_count\":1"));
        assert!(rendered.contains("\"mem_size_mib\":128"));
    }

    /// `InstanceStart` shows up by name in the payload.
    #[test]
    fn instance_start_action_serializes() {
        let rendered = to_json(&InstanceAction {
            action_type: InstanceActionType::InstanceStart,
        });
        assert!(rendered.contains("InstanceStart"));
    }

    /// `SendCtrlAltDel` shows up by name in the payload.
    #[test]
    fn send_ctrl_alt_del_serializes() {
        let rendered = to_json(&InstanceAction {
            action_type: InstanceActionType::SendCtrlAltDel,
        });
        assert!(rendered.contains("SendCtrlAltDel"));
    }

    /// A `None` boot-args value leaves the key out altogether.
    #[test]
    fn boot_source_omits_optional_boot_args() {
        let rendered = to_json(&BootSource {
            kernel_image_path: "/vmlinux".into(),
            boot_args: None,
        });
        assert!(!rendered.contains("boot_args"));
    }

    /// A `Some` boot-args value is emitted together with its content.
    #[test]
    fn boot_source_includes_boot_args_when_set() {
        let rendered = to_json(&BootSource {
            kernel_image_path: "/vmlinux".into(),
            boot_args: Some("console=ttyS0 reboot=k panic=1".into()),
        });
        assert!(rendered.contains("boot_args"));
        assert!(rendered.contains("console=ttyS0"));
    }

    /// The vsock payload carries the CID and the socket path.
    #[test]
    fn vsock_device_serializes() {
        let rendered = to_json(&VsockDevice {
            guest_cid: 3,
            uds_path: "/tmp/cellos-vsock.socket".into(),
        });
        assert!(rendered.contains("\"guest_cid\":3"));
        assert!(rendered.contains("cellos-vsock.socket"));
    }

    /// All three network-interface fields are emitted verbatim.
    #[test]
    fn network_interface_serializes() {
        let rendered = to_json(&NetworkInterface {
            iface_id: "eth0".into(),
            guest_mac: "AA:FC:00:00:00:01".into(),
            host_dev_name: "cfc-abcd1234".into(),
        });
        assert!(rendered.contains("\"iface_id\":\"eth0\""));
        assert!(rendered.contains("\"guest_mac\":\"AA:FC:00:00:00:01\""));
        assert!(rendered.contains("\"host_dev_name\":\"cfc-abcd1234\""));
    }

    /// The paused state serializes as the string `Paused`.
    #[test]
    fn vm_state_patch_paused_serializes() {
        let json = to_json(&VmStatePatch {
            state: VmState::Paused,
        });
        assert!(json.contains("\"state\":\"Paused\""), "got {json}");
    }

    /// A full snapshot request names its type and both file paths.
    #[test]
    fn snapshot_create_serializes_full_with_paths() {
        let json = to_json(&SnapshotCreate {
            snapshot_type: SnapshotType::Full,
            snapshot_path: "/tmp/cellos-pool-0.snap".into(),
            mem_file_path: "/tmp/cellos-pool-0.mem".into(),
        });
        assert!(json.contains("\"snapshot_type\":\"Full\""), "got {json}");
        assert!(json.contains("/tmp/cellos-pool-0.snap"));
        assert!(json.contains("/tmp/cellos-pool-0.mem"));
    }

    /// A snapshot load with a file backend emits the backend type, the resume flag and paths.
    #[test]
    fn snapshot_load_serializes_with_file_backend_and_resume() {
        let backend = MemBackend {
            backend_type: MemBackendType::File,
            backend_path: "/tmp/cellos-pool-0.mem".into(),
        };
        let json = to_json(&SnapshotLoad {
            snapshot_path: "/tmp/cellos-pool-0.snap".into(),
            mem_backend: backend,
            enable_diff_snapshots: false,
            resume_vm: true,
        });
        assert!(json.contains("\"backend_type\":\"File\""), "got {json}");
        assert!(json.contains("\"resume_vm\":true"));
        assert!(json.contains("/tmp/cellos-pool-0.snap"));
        assert!(json.contains("/tmp/cellos-pool-0.mem"));
    }
}