use std::{path::Path, sync::Arc};
use thiserror::Error;
use crate::{
action::{ApiAction, ApiResponse},
controller::RuntimeApiController,
error::ApiError,
schemas::{
BalloonConfig, BootSourceConfig, ConfigFile, CpuConfig, DriveConfig, EntropyConfig,
HotplugMemoryConfig, InstanceAction, LoggerConfig, MachineConfig, MetricsConfig,
MmdsConfig, MmdsContents, NetworkInterfaceConfig, PmemConfig, SerialConfig, VsockConfig,
common::{MAX_DRIVES, MAX_NICS, MAX_PMEM},
},
};
/// Errors produced while loading a `--config-file` and replaying it onto the
/// runtime API controller.
///
/// `Io` and `Parse` come from [`parse_config_file`]. The remaining variants come
/// from [`replay_config`], either during its up-front checks or while dispatching
/// actions.
///
/// NOTE(review): `Io`, `Parse` and `Dispatch` both interpolate `{source}` into
/// their `Display` text and expose it via `#[source]`. Reporters that walk the
/// `source()` chain may therefore print the underlying cause twice. Confirm
/// whether callers print the chain before changing the messages.
#[derive(Debug, Error)]
pub enum ReplayError {
    /// Reading the config file from disk failed.
    #[error("failed to read --config-file {path}: {source}")]
    Io {
        /// The config-file path as rendered by `Path::display`.
        path: String,
        #[source]
        source: std::io::Error,
    },
    /// The file was read but is not valid JSON for `ConfigFile`.
    #[error("failed to parse --config-file {path}: {source}")]
    Parse {
        /// The config-file path as rendered by `Path::display`.
        path: String,
        #[source]
        source: serde_json::Error,
    },
    /// The config has no `boot-source` member, which replay requires.
    #[error("--config-file is missing the required `boot-source` member")]
    MissingBootSource,
    /// A list-valued member (`drives`, `network-interfaces`, `pmem`) has more
    /// entries than its configured cap.
    #[error("--config-file collection {field} exceeds {max} entries")]
    CollectionCap {
        /// The JSON member name that exceeded its cap.
        field: &'static str,
        /// The cap that was exceeded.
        max: usize,
    },
    /// A raw section failed its `TryFrom` conversion into the validated schema
    /// type. The payload is the conversion's error message.
    #[error("--config-file rejected: {0}")]
    Validation(String),
    /// The controller returned an error instead of a response for an action.
    #[error("--config-file action `{label}` failed: {source}")]
    Dispatch {
        /// The action label, from `ApiAction::label`.
        label: &'static str,
        #[source]
        source: ApiError,
    },
    /// The controller answered an action with `ApiResponse::Fault`.
    #[error("--config-file action `{label}` rejected by VMM (status {status}): {fault_message}")]
    Fault {
        /// The action label, from `ApiAction::label`.
        label: &'static str,
        /// Status code carried by the fault response.
        status: u16,
        /// Fault message carried by the fault response.
        fault_message: String,
    },
}
/// Reads the file at `path` and deserializes it as a [`ConfigFile`].
///
/// # Errors
///
/// * [`ReplayError::Io`] if the file cannot be read.
/// * [`ReplayError::Parse`] if its contents are not valid `ConfigFile` JSON.
///
/// Both variants carry the path rendered with `Path::display`.
pub async fn parse_config_file(path: impl AsRef<Path>) -> Result<ConfigFile, ReplayError> {
    let file = path.as_ref();
    // The display string is only built on the error paths.
    let shown = || file.display().to_string();

    let raw = match tokio::fs::read(file).await {
        Ok(contents) => contents,
        Err(source) => {
            return Err(ReplayError::Io {
                path: shown(),
                source,
            })
        }
    };

    serde_json::from_slice::<ConfigFile>(&raw).map_err(|source| ReplayError::Parse {
        path: shown(),
        source,
    })
}
pub async fn replay_config(
controller: &Arc<RuntimeApiController>,
cfg: ConfigFile,
start_microvm: bool,
) -> Result<(), ReplayError> {
if cfg.boot_source.is_none() {
return Err(ReplayError::MissingBootSource);
}
if cfg.drives.len() > MAX_DRIVES {
return Err(ReplayError::CollectionCap {
field: "drives",
max: MAX_DRIVES,
});
}
if cfg.network_interfaces.len() > MAX_NICS {
return Err(ReplayError::CollectionCap {
field: "network-interfaces",
max: MAX_NICS,
});
}
if cfg.pmem.len() > MAX_PMEM {
return Err(ReplayError::CollectionCap {
field: "pmem",
max: MAX_PMEM,
});
}
if let Some(raw) = cfg.machine_config {
let validated = MachineConfig::try_from(raw).map_err(ReplayError::Validation)?;
dispatch(controller, ApiAction::PutMachineConfig(validated)).await?;
}
if let Some(raw) = cfg.cpu_config {
let validated = CpuConfig::try_from(raw).map_err(ReplayError::Validation)?;
dispatch(controller, ApiAction::PutCpuConfig(validated)).await?;
}
if let Some(raw) = cfg.boot_source {
let validated = BootSourceConfig::try_from(raw).map_err(ReplayError::Validation)?;
dispatch(controller, ApiAction::PutBootSource(validated)).await?;
}
for raw in cfg.drives {
let validated = DriveConfig::try_from(raw).map_err(ReplayError::Validation)?;
dispatch(controller, ApiAction::PutDrive(validated)).await?;
}
for raw in cfg.network_interfaces {
let validated = NetworkInterfaceConfig::try_from(raw).map_err(ReplayError::Validation)?;
dispatch(controller, ApiAction::PutNetwork(validated)).await?;
}
if let Some(raw) = cfg.vsock {
let validated = VsockConfig::try_from(raw).map_err(ReplayError::Validation)?;
dispatch(controller, ApiAction::PutVsock(validated)).await?;
}
if let Some(raw) = cfg.mmds_config {
let validated = MmdsConfig::try_from(raw).map_err(ReplayError::Validation)?;
dispatch(controller, ApiAction::PutMmdsConfig(validated)).await?;
}
if let Some(value) = cfg.mmds {
dispatch(controller, ApiAction::PutMmds(MmdsContents::new(value))).await?;
}
if let Some(raw) = cfg.balloon {
let validated = BalloonConfig::try_from(raw).map_err(ReplayError::Validation)?;
dispatch(controller, ApiAction::PutBalloon(validated)).await?;
}
if let Some(raw) = cfg.entropy {
let validated = EntropyConfig::try_from(raw).map_err(ReplayError::Validation)?;
dispatch(controller, ApiAction::PutEntropy(validated)).await?;
}
if let Some(raw) = cfg.serial {
let validated = SerialConfig::try_from(raw).map_err(ReplayError::Validation)?;
dispatch(controller, ApiAction::PutSerial(validated)).await?;
}
for raw in cfg.pmem {
let validated = PmemConfig::try_from(raw).map_err(ReplayError::Validation)?;
dispatch(controller, ApiAction::PutPmem(validated)).await?;
}
if let Some(raw) = cfg.hotplug_memory {
let validated = HotplugMemoryConfig::try_from(raw).map_err(ReplayError::Validation)?;
dispatch(controller, ApiAction::PutHotplugMemory(validated)).await?;
}
if let Some(raw) = cfg.logger {
let validated = LoggerConfig::try_from(raw).map_err(ReplayError::Validation)?;
dispatch(controller, ApiAction::PutLogger(validated)).await?;
}
if let Some(raw) = cfg.metrics {
let validated = MetricsConfig::try_from(raw).map_err(ReplayError::Validation)?;
dispatch(controller, ApiAction::PutMetrics(validated)).await?;
}
if start_microvm {
dispatch(controller, ApiAction::Action(InstanceAction::InstanceStart)).await?;
}
Ok(())
}
/// Sends one `action` to the controller and maps the outcome onto [`ReplayError`].
///
/// Succeeds on `NoContent` or `Json` responses. Fails with
/// [`ReplayError::Dispatch`] when the controller returns an error, and with
/// [`ReplayError::Fault`] when it answers with a fault.
async fn dispatch(
    controller: &Arc<RuntimeApiController>,
    action: ApiAction,
) -> Result<(), ReplayError> {
    // Take the label first: `action` is moved into the controller call below.
    let label = action.label();
    match controller.dispatch(action).await {
        Err(source) => Err(ReplayError::Dispatch { label, source }),
        Ok(ApiResponse::Fault {
            status,
            fault_message,
        }) => Err(ReplayError::Fault {
            label,
            status,
            fault_message,
        }),
        Ok(ApiResponse::NoContent | ApiResponse::Json(_)) => Ok(()),
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::controller::{ControllerSnapshot, TimeoutTable};

    /// Builds a controller plus the receiving end of its action channel.
    fn build_controller() -> (Arc<RuntimeApiController>, crate::controller::ActionReceiver) {
        let snap = ControllerSnapshot::new("anonymous", "1.16.0", "1.16.0 (squib test)");
        let (c, rx) = RuntimeApiController::new(snap, TimeoutTable::from_spec(), 64);
        (Arc::new(c), rx)
    }

    /// Spawns a task that acks every received action with `NoContent`. The task
    /// returns the action labels in arrival order once the channel closes.
    fn drain_acker(
        mut rx: crate::controller::ActionReceiver,
    ) -> tokio::task::JoinHandle<Vec<&'static str>> {
        tokio::spawn(async move {
            let mut labels = Vec::new();
            while let Some((action, ack)) = rx.recv().await {
                labels.push(action.label());
                let _ = ack.send(ApiResponse::NoContent);
            }
            labels
        })
    }

    /// Drops the controller so the channel closes, then collects the labels
    /// the acker saw.
    async fn finish(
        c: Arc<RuntimeApiController>,
        drain: tokio::task::JoinHandle<Vec<&'static str>>,
    ) -> Vec<&'static str> {
        drop(c);
        tokio::time::timeout(Duration::from_millis(500), drain)
            .await
            .unwrap()
            .unwrap()
    }

    #[tokio::test]
    async fn test_should_reject_config_without_boot_source() {
        // Bound to `_rx` rather than `_` so the receiver is not dropped immediately.
        let (c, _rx) = build_controller();
        let cfg = ConfigFile::default();
        let res = replay_config(&c, cfg, false).await;
        assert!(matches!(res, Err(ReplayError::MissingBootSource)));
    }

    #[tokio::test]
    async fn test_should_not_dispatch_anything_without_boot_source() {
        let (c, rx) = build_controller();
        let drain = drain_acker(rx);
        let cfg: ConfigFile =
            serde_json::from_str(r#"{"machine-config":{"vcpu_count":1,"mem_size_mib":256}}"#)
                .unwrap();
        let res = replay_config(&c, cfg, true).await;
        assert!(matches!(res, Err(ReplayError::MissingBootSource)));
        let labels = finish(c, drain).await;
        assert!(labels.is_empty(), "unexpected dispatches: {labels:?}");
    }

    #[tokio::test]
    async fn test_should_replay_minimal_config_in_order() {
        let (c, rx) = build_controller();
        let drain = drain_acker(rx);
        let cfg: ConfigFile = serde_json::from_str(
            r#"{
                "boot-source": {"kernel_image_path":"/tmp/k"},
                "machine-config": {"vcpu_count":1,"mem_size_mib":256}
            }"#,
        )
        .unwrap();
        replay_config(&c, cfg, false).await.unwrap();
        let labels = finish(c, drain).await;
        // Compare the whole sequence so extra or missing dispatches are caught too.
        assert_eq!(labels, ["PUT /machine-config", "PUT /boot-source"]);
    }

    #[tokio::test]
    async fn test_should_dispatch_instance_start_when_requested() {
        let (c, rx) = build_controller();
        let drain = drain_acker(rx);
        let cfg: ConfigFile = serde_json::from_str(
            r#"{"boot-source":{"kernel_image_path":"/tmp/k"},
                "machine-config":{"vcpu_count":1,"mem_size_mib":256}}"#,
        )
        .unwrap();
        replay_config(&c, cfg, true).await.unwrap();
        let labels = finish(c, drain).await;
        assert_eq!(
            labels,
            ["PUT /machine-config", "PUT /boot-source", "PUT /actions"]
        );
    }

    #[tokio::test]
    async fn test_should_reject_drives_over_cap() {
        let (c, _rx) = build_controller();
        let mut cfg = ConfigFile {
            boot_source: Some(crate::schemas::boot_source::RawBootSourceConfig {
                kernel_image_path: "/tmp/k".into(),
                initrd_path: None,
                boot_args: None,
            }),
            ..ConfigFile::default()
        };
        // One entry more than the cap.
        for i in 0..=MAX_DRIVES {
            cfg.drives.push(crate::schemas::drive::RawDriveConfig {
                drive_id: format!("d_{i}"),
                path_on_host: "/tmp/x".into(),
                is_root_device: false,
                is_read_only: false,
                cache_type: crate::schemas::drive::CacheType::Unsafe,
                io_engine: crate::schemas::drive::IoEngine::Sync,
                partuuid: None,
                rate_limiter: None,
                socket: None,
            });
        }
        let res = replay_config(&c, cfg, false).await;
        assert!(matches!(
            res,
            Err(ReplayError::CollectionCap {
                field: "drives",
                ..
            })
        ));
    }
}