#![deny(
clippy::unwrap_used,
clippy::expect_used,
clippy::indexing_slicing,
clippy::panic
)]
use std::{sync::Arc, time::Duration};
use arc_swap::ArcSwap;
use parking_lot::Mutex;
use squib_core::LifecyclePhase;
use tokio::sync::{mpsc, oneshot};
use tracing::error;
use crate::{
action::{ActionClass, ApiAction, ApiResponse},
error::ApiError,
schemas::{InstanceAction, InstanceInfo, MAX_DRIVES, MAX_NICS, MAX_PMEM, VmState},
};
/// Per-action-class timeout budgets.
///
/// Each field holds the `Duration` that `TimeoutTable::for_class` returns for
/// the `ActionClass` variant of the matching name.
#[derive(Debug, Clone, Copy)]
pub struct TimeoutTable {
/// Budget for `ActionClass::PreBootConfig` actions.
pub pre_boot_config: Duration,
/// Budget for `ActionClass::InstanceStart` actions.
pub instance_start: Duration,
/// Budget for `ActionClass::SnapshotCreate` actions.
pub snapshot_create: Duration,
/// Budget for `ActionClass::SnapshotLoad` actions.
pub snapshot_load: Duration,
/// Budget for `ActionClass::VmStateChange` actions.
pub vm_state_change: Duration,
/// Budget for `ActionClass::BalloonResize` actions.
pub balloon_resize: Duration,
/// Budget for `ActionClass::Other` actions.
pub other: Duration,
}
impl TimeoutTable {
    /// Returns the built-in timeout table.
    ///
    /// Three distinct budgets are used: 5 seconds (pre-boot configuration,
    /// VM state changes and everything unclassified), 30 seconds (instance
    /// start and balloon resize) and 5 minutes (snapshot create / load).
    #[must_use]
    pub const fn from_spec() -> Self {
        let quick = Duration::from_secs(5);
        let moderate = Duration::from_secs(30);
        let snapshot = Duration::from_mins(5);
        Self {
            pre_boot_config: quick,
            instance_start: moderate,
            snapshot_create: snapshot,
            snapshot_load: snapshot,
            vm_state_change: quick,
            balloon_resize: moderate,
            other: quick,
        }
    }

    /// Looks up the timeout budget configured for `class`.
    #[must_use]
    pub const fn for_class(&self, class: ActionClass) -> Duration {
        // The table is `Copy`, so pull every budget out once and let the
        // match simply pick the one that belongs to `class`.
        let Self {
            pre_boot_config,
            instance_start,
            snapshot_create,
            snapshot_load,
            vm_state_change,
            balloon_resize,
            other,
        } = *self;
        match class {
            ActionClass::PreBootConfig => pre_boot_config,
            ActionClass::InstanceStart => instance_start,
            ActionClass::SnapshotCreate => snapshot_create,
            ActionClass::SnapshotLoad => snapshot_load,
            ActionClass::VmStateChange => vm_state_change,
            ActionClass::BalloonResize => balloon_resize,
            ActionClass::Other => other,
        }
    }
}
impl Default for TimeoutTable {
fn default() -> Self {
Self::from_spec()
}
}
/// Read-mostly view of the controller's state.
///
/// `RuntimeApiController` keeps one of these behind an `ArcSwap`; it is read
/// through `RuntimeApiController::snapshot` and replaced wholesale through
/// `RuntimeApiController::store_snapshot`.
#[derive(Debug, Clone)]
pub struct ControllerSnapshot {
/// Instance description (id, state, VMM version, application name).
pub instance_info: InstanceInfo,
/// Firecracker version string supplied at construction time.
pub firecracker_version: String,
/// JSON VM configuration; starts out as an empty JSON object.
pub vm_config: Arc<serde_json::Value>,
/// Lifecycle phase consulted by `RuntimeApiController::validate_phase`.
pub phase: LifecyclePhase,
}
impl ControllerSnapshot {
    /// Creates the snapshot of a freshly created, not yet started instance.
    ///
    /// The instance state is `VmState::NotStarted`, the lifecycle phase is
    /// `LifecyclePhase::Uninitialized`, the application name is
    /// `"Firecracker"` and the VM configuration is an empty JSON object.
    pub fn new(
        instance_id: impl Into<String>,
        firecracker_version: impl Into<String>,
        vmm_version: impl Into<String>,
    ) -> Self {
        let firecracker_version: String = firecracker_version.into();
        let instance_info = InstanceInfo {
            id: instance_id.into(),
            state: VmState::NotStarted,
            vmm_version: vmm_version.into(),
            app_name: "Firecracker".into(),
        };
        let vm_config = Arc::new(serde_json::json!({}));
        Self {
            instance_info,
            firecracker_version,
            vm_config,
            phase: LifecyclePhase::Uninitialized,
        }
    }
}
/// Sending half of the action queue: each item pairs an `ApiAction` with the
/// oneshot sender on which the response to that action is delivered.
pub type ActionSender = mpsc::Sender<(ApiAction, oneshot::Sender<ApiResponse>)>;
/// Receiving half of the action queue; carries the same
/// `(action, response sender)` pairs as `ActionSender`.
pub type ActionReceiver = mpsc::Receiver<(ApiAction, oneshot::Sender<ApiResponse>)>;
/// Mutable bookkeeping used for cross-field validation of incoming actions.
#[derive(Debug)]
pub struct LimitsState {
/// Upper bound (MiB) accepted for a `PutMachineConfig` memory size.
pub host_ram_mib: u64,
/// Memory size (MiB) recorded from the last successful `PutMachineConfig`;
/// `None` until one has succeeded.
pub mem_size_mib: Option<u64>,
/// Number of `PutDrive` actions that completed successfully.
pub running_drives: u32,
/// Number of `PutNetwork` actions that completed successfully.
pub running_nics: u32,
/// Number of `PutPmem` actions that completed successfully.
pub running_pmem: u32,
}
impl LimitsState {
    /// Creates the initial bookkeeping for a host RAM cap of `host_ram_mib`
    /// MiB: no memory size recorded yet and every device counter at zero.
    #[must_use]
    pub fn from_host_ram_mib(host_ram_mib: u64) -> Self {
        // Every per-class counter starts from the same empty value.
        const NO_DEVICES: u32 = 0;
        Self {
            mem_size_mib: None,
            running_drives: NO_DEVICES,
            running_nics: NO_DEVICES,
            running_pmem: NO_DEVICES,
            host_ram_mib,
        }
    }
}
impl Default for LimitsState {
fn default() -> Self {
Self::from_host_ram_mib(1024 * 1024)
}
}
/// Front door for API actions: validates them, forwards them to the VMM over
/// a bounded channel and waits for the response under a per-class timeout.
#[derive(Debug)]
pub struct RuntimeApiController {
/// Atomically swappable controller snapshot (see `snapshot` / `store_snapshot`).
snapshot: ArcSwap<ControllerSnapshot>,
/// Sending half of the action queue consumed by the VMM side.
vmm_tx: ActionSender,
/// Per-class timeout budgets used by `dispatch`.
timeouts: TimeoutTable,
/// Bookkeeping for cross-field validation, guarded by a `parking_lot` mutex.
limits: Mutex<LimitsState>,
}
impl RuntimeApiController {
#[must_use]
pub fn new(
snapshot: ControllerSnapshot,
timeouts: TimeoutTable,
capacity: usize,
) -> (Self, ActionReceiver) {
Self::new_with_limits(snapshot, timeouts, capacity, LimitsState::default())
}
#[must_use]
pub fn new_with_limits(
snapshot: ControllerSnapshot,
timeouts: TimeoutTable,
capacity: usize,
limits: LimitsState,
) -> (Self, ActionReceiver) {
let (tx, rx) = mpsc::channel(capacity);
let controller = Self {
snapshot: ArcSwap::from(Arc::new(snapshot)),
vmm_tx: tx,
timeouts,
limits: Mutex::new(limits),
};
(controller, rx)
}
#[must_use]
pub fn limits_snapshot(&self) -> LimitsSnapshot {
let g = self.limits.lock();
LimitsSnapshot {
host_ram_mib: g.host_ram_mib,
mem_size_mib: g.mem_size_mib,
running_drives: g.running_drives,
running_nics: g.running_nics,
running_pmem: g.running_pmem,
}
}
fn validate_cross_field(&self, action: &ApiAction) -> Result<(), ApiError> {
let g = self.limits.lock();
match action {
ApiAction::PutMachineConfig(cfg) => {
let req = cfg.mem_size_mib.get();
if req > g.host_ram_mib {
return Err(ApiError::BadRequest(format!(
"mem_size_mib={req} exceeds host RAM cap of {host} MiB",
host = g.host_ram_mib,
)));
}
}
ApiAction::PutBalloon(b) => {
cross_check_balloon(b.amount_mib, g.mem_size_mib)?;
}
ApiAction::PatchBalloon(u) => {
cross_check_balloon(u.amount_mib, g.mem_size_mib)?;
}
ApiAction::PutDrive(_) if u64::from(g.running_drives) >= MAX_DRIVES_AS_U64 => {
return Err(ApiError::BadRequest(format!(
"drives: per-class cap {MAX_DRIVES} exceeded"
)));
}
ApiAction::PutNetwork(_) if u64::from(g.running_nics) >= MAX_NICS_AS_U64 => {
return Err(ApiError::BadRequest(format!(
"network_interfaces: per-class cap {MAX_NICS} exceeded"
)));
}
ApiAction::PutPmem(_) if u64::from(g.running_pmem) >= MAX_PMEM_AS_U64 => {
return Err(ApiError::BadRequest(format!(
"pmem: per-class cap {MAX_PMEM} exceeded"
)));
}
_ => {}
}
Ok(())
}
}
/// Bookkeeping update to apply to `LimitsState` once an action has succeeded.
///
/// The value is computed from the action before it is sent (see
/// `ActionCounterKick::for_action`) and applied afterwards via `apply`.
#[derive(Debug, Clone, Copy)]
enum ActionCounterKick {
/// The action does not affect the limits bookkeeping.
None,
/// Record the given memory size (MiB) as the configured machine memory.
SetMemSize(u64),
/// Count one more drive.
AddDrive,
/// Count one more network interface.
AddNic,
/// Count one more pmem device.
AddPmem,
}
impl ActionCounterKick {
    /// Derives the bookkeeping update that a successful `action` implies.
    fn for_action(action: &ApiAction) -> Self {
        // The machine-config case is the only one carrying a payload.
        if let ApiAction::PutMachineConfig(cfg) = action {
            return Self::SetMemSize(cfg.mem_size_mib.get());
        }
        match action {
            ApiAction::PutDrive(_) => Self::AddDrive,
            ApiAction::PutNetwork(_) => Self::AddNic,
            ApiAction::PutPmem(_) => Self::AddPmem,
            _ => Self::None,
        }
    }

    /// Applies this update to the limits bookkeeping of `ctl`.
    ///
    /// `None` returns without touching the limits mutex; counters saturate
    /// instead of overflowing.
    fn apply(self, ctl: &RuntimeApiController) {
        /// Saturating increment shared by the three device counters.
        fn bump(counter: &mut u32) {
            *counter = counter.saturating_add(1);
        }

        match self {
            Self::None => {}
            Self::SetMemSize(mib) => ctl.limits.lock().mem_size_mib = Some(mib),
            Self::AddDrive => bump(&mut ctl.limits.lock().running_drives),
            Self::AddNic => bump(&mut ctl.limits.lock().running_nics),
            Self::AddPmem => bump(&mut ctl.limits.lock().running_pmem),
        }
    }
}
/// Plain-data copy of `LimitsState`, as returned by
/// `RuntimeApiController::limits_snapshot`. Fields mirror `LimitsState`
/// one-to-one.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LimitsSnapshot {
/// Host RAM cap in MiB.
pub host_ram_mib: u64,
/// Recorded machine memory size in MiB, if any.
pub mem_size_mib: Option<u64>,
/// Drive counter at the time of the snapshot.
pub running_drives: u32,
/// Network-interface counter at the time of the snapshot.
pub running_nics: u32,
/// Pmem counter at the time of the snapshot.
pub running_pmem: u32,
}
// `u64` copies of the per-class caps from `crate::schemas`, so that
// `validate_cross_field` can compare them against `u64::from(<u32 counter>)`.
// NOTE(review): `as` would silently truncate a cap wider than 64 bits; this
// assumes the schema constants are small — confirm their declared type.
const MAX_DRIVES_AS_U64: u64 = MAX_DRIVES as u64;
const MAX_NICS_AS_U64: u64 = MAX_NICS as u64;
const MAX_PMEM_AS_U64: u64 = MAX_PMEM as u64;
/// Validates a balloon target against the configured machine memory.
///
/// When no memory size has been recorded yet (`mem_size_mib` is `None`) the
/// check is skipped. Otherwise the balloon may claim at most
/// `mem_size_mib - 32` MiB (saturating at zero); a larger `amount_mib` is
/// rejected with `ApiError::BadRequest`.
fn cross_check_balloon(amount_mib: u64, mem_size_mib: Option<u64>) -> Result<(), ApiError> {
    match mem_size_mib {
        // Nothing to compare against yet.
        None => Ok(()),
        Some(mem) => {
            let max_balloon = mem.saturating_sub(32);
            if amount_mib <= max_balloon {
                Ok(())
            } else {
                Err(ApiError::BadRequest(format!(
                    "balloon amount_mib={amount_mib} exceeds max ({max_balloon} = mem_size_mib {mem} - 32)"
                )))
            }
        }
    }
}
impl RuntimeApiController {
    /// Returns the most recently stored controller snapshot.
    #[must_use]
    pub fn snapshot(&self) -> Arc<ControllerSnapshot> {
        self.snapshot.load_full()
    }

    /// Replaces the stored controller snapshot with `snap`.
    pub fn store_snapshot(&self, snap: ControllerSnapshot) {
        self.snapshot.store(Arc::new(snap));
    }

    /// Checks whether `action` is admissible in the current lifecycle phase.
    ///
    /// The checks run in this order; the first one that matches decides:
    ///
    /// 1. `SendCtrlAltDel` is always rejected.
    /// 2. `ApiAction::Shutdown` is always admitted.
    /// 3. A pre-boot phase rejects actions that are not pre-boot actions.
    /// 4. A post-boot phase rejects actions that are not post-boot actions.
    /// 5. `LifecyclePhase::Starting` rejects everything that got this far.
    /// 6. `LifecyclePhase::Shutdown` rejects everything that got this far.
    ///
    /// # Errors
    ///
    /// `ApiError::BadRequest` for rules 1 and 3-5, `ApiError::Internal` for
    /// rule 6.
    pub fn validate_phase(&self, action: &ApiAction) -> Result<(), ApiError> {
        if let ApiAction::Action(InstanceAction::SendCtrlAltDel) = action {
            return Err(ApiError::BadRequest(
                "Invalid action: SendCtrlAltDel is x86-only and not supported on aarch64".into(),
            ));
        }
        if matches!(action, ApiAction::Shutdown) {
            return Ok(());
        }
        // Only the remaining rules depend on the phase, so it is read here.
        let phase = self.snapshot.load().phase;
        if phase.is_pre_boot() && !action.is_pre_boot() {
            return Err(ApiError::BadRequest(
                "The requested operation is not allowed before the microVM has booted".into(),
            ));
        }
        if phase.is_post_boot() && !action.is_post_boot() {
            return Err(ApiError::BadRequest(
                "The requested operation is not supported after the microVM has booted".into(),
            ));
        }
        if matches!(phase, LifecyclePhase::Starting) {
            return Err(ApiError::BadRequest(
                "The requested operation cannot be served during boot orchestration".into(),
            ));
        }
        if matches!(phase, LifecyclePhase::Shutdown) {
            return Err(ApiError::Internal("VMM is shut down".into()));
        }
        Ok(())
    }

    /// Validates `action`, forwards it to the VMM and waits for its response.
    ///
    /// The per-class timeout from the `TimeoutTable` bounds the *whole* round
    /// trip: waiting for a free slot in the bounded action queue as well as
    /// waiting for the response. Without the former, a stalled VMM loop with
    /// a full queue would block the caller indefinitely.
    ///
    /// On a `NoContent` or `Json` response the limits bookkeeping is updated
    /// (memory size / device counters); any other response leaves it
    /// untouched.
    ///
    /// # Errors
    ///
    /// * whatever `validate_phase` or the cross-field validation reject;
    /// * `ApiError::Internal` when the receiving side of the queue, or the
    ///   response sender, has been dropped;
    /// * `ApiError::Timeout` when the round trip exceeds the class budget.
    pub async fn dispatch(&self, action: ApiAction) -> Result<ApiResponse, ApiError> {
        self.validate_phase(&action)?;
        self.validate_cross_field(&action)?;
        let class = action.class();
        let timeout = self.timeouts.for_class(class);
        let label = action.label();
        let counter_kick = ActionCounterKick::for_action(&action);
        let (resp_tx, resp_rx) = oneshot::channel();

        // Tracks how far the round trip got, so a timeout can be reported
        // accurately. If the timeout fires while `send` is still waiting for
        // capacity, the send future is dropped and the action is never
        // delivered to the VMM.
        let mut enqueued = false;
        let round_trip = async {
            self.vmm_tx
                .send((action, resp_tx))
                .await
                .map_err(|_| ApiError::Internal("VMM event loop is gone".into()))?;
            enqueued = true;
            resp_rx
                .await
                .map_err(|_| ApiError::Internal("VMM event loop is gone".into()))
        };
        let outcome = tokio::time::timeout(timeout, round_trip).await;

        match outcome {
            Ok(Ok(resp)) => {
                // Only successful responses change what the VMM holds.
                if matches!(resp, ApiResponse::NoContent | ApiResponse::Json { .. }) {
                    counter_kick.apply(self);
                }
                Ok(resp)
            }
            Ok(Err(err)) => Err(err),
            Err(_) => {
                if enqueued {
                    error!(
                        action = label,
                        timeout_secs = timeout.as_secs(),
                        "VMM action timed out; the action remains pending at the VMM",
                    );
                } else {
                    error!(
                        action = label,
                        timeout_secs = timeout.as_secs(),
                        "VMM action timed out waiting for queue capacity; the action was not enqueued",
                    );
                }
                Err(ApiError::Timeout(class.label()))
            }
        }
    }

    /// Returns a copy of the timeout table this controller dispatches with.
    #[must_use]
    pub fn timeouts(&self) -> TimeoutTable {
        self.timeouts
    }

    /// Returns a clone of the sending half of the action queue.
    ///
    /// Actions sent through it go straight to the queue; `dispatch` is not
    /// involved, so none of its validation or bookkeeping runs for them.
    #[must_use]
    pub fn action_sender(&self) -> ActionSender {
        self.vmm_tx.clone()
    }
}
// Unit tests for phase validation, dispatch round trips, cross-field limits
// and the built-in timeout table. The crate-level panic/unwrap lints are
// relaxed here because tests are expected to fail by panicking.
#[cfg(test)]
#[allow(
clippy::unwrap_used,
clippy::expect_used,
clippy::indexing_slicing,
clippy::panic
)]
mod tests {
use squib_core::LifecyclePhase;
use super::*;
use crate::schemas::{BootSourceConfig, EntropyConfig, VmStateChange};
// Builds a controller whose snapshot is in `phase` (instance state derived
// from `phase.wire_state()`), with the built-in timeouts and a queue
// capacity of 16.
fn ctl(phase: LifecyclePhase) -> (RuntimeApiController, ActionReceiver) {
let mut snap = ControllerSnapshot::new("anonymous", "1.16.0", "1.16.0 (squib 0.1.0)");
snap.phase = phase;
snap.instance_info.state = phase.wire_state().into();
RuntimeApiController::new(snap, TimeoutTable::from_spec(), 16)
}
// Minimal boot source: kernel path only, no initrd and no boot args.
fn boot_source() -> BootSourceConfig {
BootSourceConfig::try_from(crate::schemas::boot_source::RawBootSourceConfig {
kernel_image_path: "/tmp/k".into(),
initrd_path: None,
boot_args: None,
})
.unwrap()
}
// PutBootSource passes phase validation while Uninitialized.
#[test]
fn test_should_admit_pre_boot_action_in_uninitialized() {
let (c, _rx) = ctl(LifecyclePhase::Uninitialized);
let action = ApiAction::PutBootSource(boot_source());
c.validate_phase(&action).unwrap();
}
// PatchVm(Paused) is rejected with BadRequest while Uninitialized.
#[test]
fn test_should_reject_post_boot_action_in_uninitialized() {
let (c, _rx) = ctl(LifecyclePhase::Uninitialized);
let action = ApiAction::PatchVm(VmStateChange::Paused);
let err = c.validate_phase(&action).unwrap_err();
assert!(matches!(err, ApiError::BadRequest(_)));
}
// PutEntropy is rejected with BadRequest while Running.
#[test]
fn test_should_reject_pre_boot_action_in_running() {
let (c, _rx) = ctl(LifecyclePhase::Running);
let action = ApiAction::PutEntropy(EntropyConfig::default());
let err = c.validate_phase(&action).unwrap_err();
assert!(matches!(err, ApiError::BadRequest(_)));
}
// PatchVm(Paused) passes phase validation while Running.
#[test]
fn test_should_admit_pause_in_running() {
let (c, _rx) = ctl(LifecyclePhase::Running);
let action = ApiAction::PatchVm(VmStateChange::Paused);
c.validate_phase(&action).unwrap();
}
// SendCtrlAltDel is rejected with BadRequest and the fault message names
// the action.
#[test]
fn test_should_reject_send_ctrl_alt_del_with_upstream_message() {
let (c, _rx) = ctl(LifecyclePhase::Running);
let action = ApiAction::Action(InstanceAction::SendCtrlAltDel);
let err = c.validate_phase(&action).unwrap_err();
assert!(matches!(err, ApiError::BadRequest(_)));
assert!(err.fault_message().contains("SendCtrlAltDel"));
}
// In the Shutdown phase PutEntropy fails with Internal rather than
// BadRequest.
#[test]
fn test_should_reject_anything_in_shutdown() {
let (c, _rx) = ctl(LifecyclePhase::Shutdown);
let action = ApiAction::PutEntropy(EntropyConfig::default());
let err = c.validate_phase(&action).unwrap_err();
assert!(matches!(err, ApiError::Internal(_)));
}
// In the Starting phase PutEntropy is rejected (error kind not asserted).
#[test]
fn test_should_reject_during_starting_phase() {
let (c, _rx) = ctl(LifecyclePhase::Starting);
let action = ApiAction::PutEntropy(EntropyConfig::default());
assert!(c.validate_phase(&action).is_err());
}
// Nothing ever reads from `_rx`, so the response never arrives and the
// 50 ms pre-boot budget elapses: dispatch must return ApiError::Timeout.
// NOTE(review): the "504" in the name presumably refers to how the HTTP
// layer renders ApiError::Timeout; that mapping is not visible here.
#[tokio::test]
async fn test_should_surface_504_on_action_timeout() {
let mut snap = ControllerSnapshot::new("anonymous", "1.16.0", "1.16.0 (squib test)");
snap.phase = LifecyclePhase::Uninitialized;
snap.instance_info.state = VmState::NotStarted;
let mut t = TimeoutTable::from_spec();
t.pre_boot_config = Duration::from_millis(50);
let (c, _rx) = RuntimeApiController::new(snap, t, 16);
let action = ApiAction::PutBootSource(boot_source());
let res = c.dispatch(action).await;
assert!(matches!(res, Err(ApiError::Timeout(_))));
}
// A stub VMM task acknowledges the action with NoContent, which dispatch
// hands back unchanged.
#[tokio::test]
async fn test_should_dispatch_to_vmm_and_return_no_content() {
let (c, mut rx) = ctl(LifecyclePhase::Uninitialized);
let action = ApiAction::PutBootSource(boot_source());
tokio::spawn(async move {
if let Some((_action, ack)) = rx.recv().await {
let _ = ack.send(ApiResponse::NoContent);
}
});
let resp = c.dispatch(action).await.unwrap();
assert!(matches!(resp, ApiResponse::NoContent));
}
// The stub VMM task drops the response sender without answering; dispatch
// must report ApiError::Internal.
#[tokio::test]
async fn test_should_surface_500_when_event_loop_drops_response() {
let (c, rx) = ctl(LifecyclePhase::Uninitialized);
let action = ApiAction::PutBootSource(boot_source());
tokio::spawn(async move {
let mut rx = rx;
if let Some((_action, ack)) = rx.recv().await {
drop(ack);
}
});
let res = c.dispatch(action).await;
assert!(matches!(res, Err(ApiError::Internal(_))));
}
// Machine config with one vCPU and `mem_mib` MiB of memory; everything
// else off or unset.
fn machine_cfg(mem_mib: u64) -> crate::schemas::MachineConfig {
crate::schemas::MachineConfig::try_from(crate::schemas::machine_config::RawMachineConfig {
vcpu_count: 1,
mem_size_mib: mem_mib,
smt: false,
track_dirty_pages: false,
cpu_template: None,
huge_pages: None,
})
.unwrap()
}
// Read-only, non-root drive named `id`, backed by `/tmp/<id>.img`.
fn drive_cfg(id: &str) -> crate::schemas::DriveConfig {
crate::schemas::DriveConfig::try_from(crate::schemas::drive::RawDriveConfig {
drive_id: id.into(),
path_on_host: format!("/tmp/{id}.img"),
is_root_device: false,
is_read_only: true,
cache_type: crate::schemas::drive::CacheType::Unsafe,
io_engine: crate::schemas::drive::IoEngine::default(),
partuuid: None,
rate_limiter: None,
socket: None,
})
.unwrap()
}
// Balloon config targeting `amount_mib` MiB with every optional feature off.
fn balloon_cfg(amount_mib: u64) -> crate::schemas::BalloonConfig {
crate::schemas::BalloonConfig::try_from(crate::schemas::balloon::RawBalloonConfig {
amount_mib,
deflate_on_oom: false,
stats_polling_interval_s: 0,
free_page_hinting: false,
free_page_reporting: false,
})
.unwrap()
}
// With a 256 MiB host RAM cap, a 1024 MiB machine config is rejected
// before anything is sent to the VMM (no task is draining `_rx`).
#[tokio::test]
async fn test_should_reject_machine_config_above_host_ram_cap() {
let mut snap = ControllerSnapshot::new("test", "1.16.0", "1.16.0 (squib test)");
snap.phase = LifecyclePhase::Uninitialized;
let limits = LimitsState::from_host_ram_mib(256);
let (c, _rx) =
RuntimeApiController::new_with_limits(snap, TimeoutTable::from_spec(), 16, limits);
let action = ApiAction::PutMachineConfig(machine_cfg(1024));
let err = c.dispatch(action).await.unwrap_err();
assert!(matches!(err, ApiError::BadRequest(_)));
assert!(err.fault_message().contains("host RAM cap"));
}
// After a successful 256 MiB machine config the recorded memory size is
// 256, so a 256 MiB balloon exceeds the `mem - 32` ceiling and is rejected.
#[tokio::test]
async fn test_should_reject_balloon_above_mem_minus_32() {
let (c, mut rx) = ctl(LifecyclePhase::Uninitialized);
tokio::spawn(async move {
while let Some((_action, ack)) = rx.recv().await {
let _ = ack.send(ApiResponse::NoContent);
}
});
c.dispatch(ApiAction::PutMachineConfig(machine_cfg(256)))
.await
.unwrap();
assert_eq!(c.limits_snapshot().mem_size_mib, Some(256));
let err = c
.dispatch(ApiAction::PutBalloon(balloon_cfg(256)))
.await
.unwrap_err();
assert!(matches!(err, ApiError::BadRequest(_)));
assert!(err.fault_message().contains("exceeds max"));
}
// Without a recorded memory size the balloon ceiling cannot be computed,
// so the balloon request is forwarded and succeeds.
#[tokio::test]
async fn test_should_defer_balloon_cap_check_when_mem_size_not_yet_set() {
let (c, mut rx) = ctl(LifecyclePhase::Uninitialized);
tokio::spawn(async move {
while let Some((_action, ack)) = rx.recv().await {
let _ = ack.send(ApiResponse::NoContent);
}
});
let resp = c
.dispatch(ApiAction::PutBalloon(balloon_cfg(64)))
.await
.unwrap();
assert!(matches!(resp, ApiResponse::NoContent));
}
// Eight successful drive PUTs bring the counter to 8; the next one is
// rejected by the per-class cap.
// NOTE(review): the literal 8 assumes MAX_DRIVES == 8 — confirm against
// crate::schemas.
#[tokio::test]
async fn test_should_enforce_drives_class_cap_via_running_count() {
let (c, mut rx) = ctl(LifecyclePhase::Uninitialized);
tokio::spawn(async move {
while let Some((_action, ack)) = rx.recv().await {
let _ = ack.send(ApiResponse::NoContent);
}
});
for i in 0..8 {
c.dispatch(ApiAction::PutDrive(drive_cfg(&format!("d{i}"))))
.await
.unwrap();
}
assert_eq!(c.limits_snapshot().running_drives, 8);
let err = c
.dispatch(ApiAction::PutDrive(drive_cfg("d9")))
.await
.unwrap_err();
assert!(matches!(err, ApiError::BadRequest(_)));
assert!(err.fault_message().contains("drives"));
}
// A Fault response is returned to the caller as Ok(..) but must not bump
// the drive counter.
#[tokio::test]
async fn test_should_not_bump_running_count_on_vmm_fault() {
let (c, mut rx) = ctl(LifecyclePhase::Uninitialized);
tokio::spawn(async move {
while let Some((_action, ack)) = rx.recv().await {
let _ = ack.send(ApiResponse::Fault {
status: 400,
fault_message: "stub VMM rejected this".into(),
});
}
});
let _ = c
.dispatch(ApiAction::PutDrive(drive_cfg("d0")))
.await
.unwrap();
assert_eq!(c.limits_snapshot().running_drives, 0);
}
// Spot-checks four entries of the built-in timeout table.
#[test]
fn test_should_apply_default_timeouts_per_spec() {
let t = TimeoutTable::from_spec();
assert_eq!(
t.for_class(ActionClass::PreBootConfig),
Duration::from_secs(5)
);
assert_eq!(
t.for_class(ActionClass::InstanceStart),
Duration::from_secs(30)
);
assert_eq!(
t.for_class(ActionClass::SnapshotCreate),
Duration::from_mins(5)
);
assert_eq!(
t.for_class(ActionClass::SnapshotLoad),
Duration::from_mins(5)
);
}
}