use std::{
path::{Path, PathBuf},
sync::Arc,
};
use axum::{
Router,
extract::{MatchedPath, Request},
http::{HeaderValue, Response as HttpResponse, header},
routing::{get, patch, put},
serve as axum_serve,
};
use tokio::net::UnixListener;
use tower_http::{
limit::RequestBodyLimitLayer, set_header::SetResponseHeaderLayer, trace::TraceLayer,
};
use tracing::{Span, info, info_span};
use crate::{
controller::RuntimeApiController,
handlers::{
delete_drive, delete_network, delete_pmem, fallback, get_balloon, get_balloon_statistics,
get_hotplug_memory, get_machine_config, get_mmds, get_root, get_version, get_vm_config,
patch_balloon, patch_balloon_hinting, patch_balloon_statistics, patch_drive,
patch_hotplug_memory, patch_machine_config, patch_mmds, patch_network, patch_pmem,
patch_vm, put_actions, put_balloon, put_boot_source, put_cpu_config, put_drive,
put_entropy, put_hotplug_memory, put_logger, put_machine_config, put_metrics, put_mmds,
put_mmds_config, put_network, put_pmem, put_serial, put_snapshot_create, put_snapshot_load,
put_vsock,
},
};
/// Value the API router writes into the `Server` response header.
pub const FIRECRACKER_SERVER_HEADER: &str = "Firecracker API";
/// Request-body limit used by `ServeOptions::new` (50 KiB).
pub const DEFAULT_MAX_PAYLOAD: usize = 50 * 1024;
/// Lower bound applied by `ServeOptions::with_max_payload_size` (1 KiB).
pub const MIN_MAX_PAYLOAD: usize = 1 << 10;
/// Upper bound applied by `ServeOptions::with_max_payload_size` (1 MiB).
pub const MAX_MAX_PAYLOAD: usize = 1 << 20;
/// Settings for the Unix-socket API server.
#[derive(Debug, Clone)]
pub struct ServeOptions {
    /// Filesystem path of the Unix domain socket the server binds.
    pub socket_path: PathBuf,
    /// Maximum request-body size in bytes, handed to the router's body-limit layer.
    pub max_payload_size: usize,
}

impl ServeOptions {
    /// Creates options for `socket_path` using [`DEFAULT_MAX_PAYLOAD`] as the body limit.
    pub fn new(socket_path: impl Into<PathBuf>) -> Self {
        let socket_path: PathBuf = socket_path.into();
        Self {
            max_payload_size: DEFAULT_MAX_PAYLOAD,
            socket_path,
        }
    }

    /// Returns `self` with the body limit set to `bytes`.
    ///
    /// Out-of-range values are not rejected. They are silently pulled into
    /// `[MIN_MAX_PAYLOAD, MAX_MAX_PAYLOAD]`.
    #[must_use]
    pub fn with_max_payload_size(self, bytes: usize) -> Self {
        // Equivalent to `clamp` because MIN_MAX_PAYLOAD <= MAX_MAX_PAYLOAD.
        let max_payload_size = bytes.max(MIN_MAX_PAYLOAD).min(MAX_MAX_PAYLOAD);
        Self {
            max_payload_size,
            ..self
        }
    }
}
pub fn router(controller: Arc<RuntimeApiController>, max_payload: usize) -> Router {
let server_header_value = HeaderValue::from_static(FIRECRACKER_SERVER_HEADER);
let server_layer = SetResponseHeaderLayer::overriding(header::SERVER, server_header_value);
let instance_id = controller.snapshot().instance_info.id.clone();
let trace_layer = TraceLayer::new_for_http()
.make_span_with(move |req: &Request<_>| {
let matched = req
.extensions()
.get::<MatchedPath>()
.map_or("<unmatched>", MatchedPath::as_str);
info_span!(
"squib_api_request",
instance_id = %instance_id,
method = %req.method(),
path = %matched,
)
})
.on_request(())
.on_response(|_resp: &HttpResponse<_>, _latency: std::time::Duration, _span: &Span| {});
Router::new()
.route("/", get(get_root))
.route("/version", get(get_version))
.route("/vm/config", get(get_vm_config))
.route("/vm", patch(patch_vm))
.route(
"/machine-config",
get(get_machine_config)
.put(put_machine_config)
.patch(patch_machine_config),
)
.route("/boot-source", put(put_boot_source))
.route(
"/drives/{id}",
put(put_drive).patch(patch_drive).delete(delete_drive),
)
.route(
"/network-interfaces/{id}",
put(put_network).patch(patch_network).delete(delete_network),
)
.route("/vsock", put(put_vsock))
.route("/mmds", get(get_mmds).put(put_mmds).patch(patch_mmds))
.route("/mmds/config", put(put_mmds_config))
.route(
"/balloon",
get(get_balloon).put(put_balloon).patch(patch_balloon),
)
.route(
"/balloon/statistics",
get(get_balloon_statistics).patch(patch_balloon_statistics),
)
.route("/balloon/hinting/{op}", patch(patch_balloon_hinting))
.route("/entropy", put(put_entropy))
.route("/serial", put(put_serial))
.route(
"/pmem/{id}",
put(put_pmem).patch(patch_pmem).delete(delete_pmem),
)
.route(
"/hotplug/memory",
get(get_hotplug_memory)
.put(put_hotplug_memory)
.patch(patch_hotplug_memory),
)
.route("/cpu-config", put(put_cpu_config))
.route("/actions", put(put_actions))
.route("/snapshot/create", put(put_snapshot_create))
.route("/snapshot/load", put(put_snapshot_load))
.route("/logger", put(put_logger))
.route("/metrics", put(put_metrics))
.fallback(fallback)
.with_state(controller)
.layer(server_layer)
.layer(RequestBodyLimitLayer::new(max_payload))
.layer(trace_layer)
}
pub async fn bind_listener(opts: &ServeOptions) -> std::io::Result<UnixListener> {
if opts.socket_path.exists() {
tokio::fs::remove_file(&opts.socket_path).await?;
}
UnixListener::bind(&opts.socket_path)
}
pub async fn serve_bound(
listener: UnixListener,
opts: ServeOptions,
controller: Arc<RuntimeApiController>,
) -> std::io::Result<()> {
info!(
socket = %opts.socket_path.display(),
max_payload_size = opts.max_payload_size,
"squib-api listening",
);
let app = router(controller, opts.max_payload_size);
axum_serve(listener, app).await
}
/// Binds the socket described by `opts` and then serves the API on it.
///
/// # Errors
///
/// Returns the binding error from [`bind_listener`], or whatever
/// [`serve_bound`] returns once serving stops.
pub async fn serve(
    opts: ServeOptions,
    controller: Arc<RuntimeApiController>,
) -> std::io::Result<()> {
    // Bind first so socket problems are reported before anything is served.
    let bound = bind_listener(&opts).await?;
    serve_bound(bound, opts, controller).await
}
/// Unlinks the file at `path`, treating an already-absent path as success.
///
/// # Errors
///
/// Returns every `remove_file` error except [`std::io::ErrorKind::NotFound`].
pub async fn unlink_socket_if_exists(path: &Path) -> std::io::Result<()> {
    tokio::fs::remove_file(path)
        .await
        .or_else(|err| match err.kind() {
            std::io::ErrorKind::NotFound => Ok(()),
            _ => Err(err),
        })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::{ControllerSnapshot, TimeoutTable};

    /// Socket path used only to construct `ServeOptions`; nothing binds it here.
    const SOCK: &str = "/tmp/squib.sock";

    /// Builds a controller from a fixed test snapshot.
    ///
    /// The second value returned by `RuntimeApiController::new` is dropped
    /// when this function returns.
    fn ctl() -> Arc<RuntimeApiController> {
        let snapshot = ControllerSnapshot::new("anonymous", "1.16.0", "1.16.0 (squib 0.0.0-test)");
        let (controller, _rx) =
            RuntimeApiController::new(snapshot, TimeoutTable::from_spec(), 16);
        Arc::new(controller)
    }

    #[test]
    fn test_should_build_router_against_controller() {
        drop(router(ctl(), DEFAULT_MAX_PAYLOAD));
    }

    #[test]
    fn test_should_default_payload_limit_to_51200() {
        assert_eq!(ServeOptions::new(SOCK).max_payload_size, DEFAULT_MAX_PAYLOAD);
    }

    #[test]
    fn test_should_clamp_payload_limit_to_lower_bound() {
        let clamped = ServeOptions::new(SOCK).with_max_payload_size(0);
        assert_eq!(clamped.max_payload_size, MIN_MAX_PAYLOAD);
    }

    #[test]
    fn test_should_clamp_payload_limit_to_upper_bound() {
        let clamped = ServeOptions::new(SOCK).with_max_payload_size(usize::MAX);
        assert_eq!(clamped.max_payload_size, MAX_MAX_PAYLOAD);
    }
}