1use std::{
14 path::{Path, PathBuf},
15 sync::Arc,
16};
17
18use axum::{
19 Router,
20 extract::{MatchedPath, Request},
21 http::{HeaderValue, Response as HttpResponse, header},
22 routing::{get, patch, put},
23 serve as axum_serve,
24};
25use tokio::net::UnixListener;
26use tower_http::{
27 limit::RequestBodyLimitLayer, set_header::SetResponseHeaderLayer, trace::TraceLayer,
28};
29use tracing::{Span, info, info_span};
30
31use crate::{
32 controller::RuntimeApiController,
33 handlers::{
34 delete_drive, delete_network, delete_pmem, fallback, get_balloon, get_balloon_statistics,
35 get_hotplug_memory, get_machine_config, get_mmds, get_root, get_version, get_vm_config,
36 patch_balloon, patch_balloon_hinting, patch_balloon_statistics, patch_drive,
37 patch_hotplug_memory, patch_machine_config, patch_mmds, patch_network, patch_pmem,
38 patch_vm, put_actions, put_balloon, put_boot_source, put_cpu_config, put_drive,
39 put_entropy, put_hotplug_memory, put_logger, put_machine_config, put_metrics, put_mmds,
40 put_mmds_config, put_network, put_pmem, put_serial, put_snapshot_create, put_snapshot_load,
41 put_vsock,
42 },
43};
44
/// Value the API router writes into the `Server` response header.
pub const FIRECRACKER_SERVER_HEADER: &str = "Firecracker API";

/// Request-body cap used when the caller does not pick one (50 KiB).
pub const DEFAULT_MAX_PAYLOAD: usize = 50 * 1024;

/// Smallest body cap accepted by [`ServeOptions::with_max_payload_size`] (1 KiB).
pub const MIN_MAX_PAYLOAD: usize = 1024;

/// Largest body cap accepted by [`ServeOptions::with_max_payload_size`] (1 MiB).
pub const MAX_MAX_PAYLOAD: usize = 1024 * 1024;

/// Listener settings for the runtime API server.
#[derive(Debug, Clone)]
pub struct ServeOptions {
    /// Filesystem path of the Unix socket the API is served on.
    pub socket_path: PathBuf,
    /// Maximum accepted request-body size, in bytes.
    pub max_payload_size: usize,
}

impl ServeOptions {
    /// Creates options for `socket_path` using [`DEFAULT_MAX_PAYLOAD`] as the body cap.
    pub fn new(socket_path: impl Into<PathBuf>) -> Self {
        let socket_path: PathBuf = socket_path.into();
        Self {
            max_payload_size: DEFAULT_MAX_PAYLOAD,
            socket_path,
        }
    }

    /// Returns these options with the body cap set to `bytes`.
    ///
    /// Out-of-range values are silently clamped into
    /// `MIN_MAX_PAYLOAD..=MAX_MAX_PAYLOAD` rather than rejected.
    #[must_use]
    pub fn with_max_payload_size(self, bytes: usize) -> Self {
        let max_payload_size = bytes.clamp(MIN_MAX_PAYLOAD, MAX_MAX_PAYLOAD);
        Self {
            max_payload_size,
            ..self
        }
    }
}
85
86pub fn router(controller: Arc<RuntimeApiController>, max_payload: usize) -> Router {
95 let server_header_value = HeaderValue::from_static(FIRECRACKER_SERVER_HEADER);
96 let server_layer = SetResponseHeaderLayer::overriding(header::SERVER, server_header_value);
97 let instance_id = controller.snapshot().instance_info.id.clone();
98 let trace_layer = TraceLayer::new_for_http()
99 .make_span_with(move |req: &Request<_>| {
100 let matched = req
103 .extensions()
104 .get::<MatchedPath>()
105 .map_or("<unmatched>", MatchedPath::as_str);
106 info_span!(
107 "squib_api_request",
108 instance_id = %instance_id,
109 method = %req.method(),
110 path = %matched,
111 )
112 })
113 .on_request(())
114 .on_response(|_resp: &HttpResponse<_>, _latency: std::time::Duration, _span: &Span| {});
115
116 Router::new()
117 .route("/", get(get_root))
119 .route("/version", get(get_version))
120 .route("/vm/config", get(get_vm_config))
121 .route("/vm", patch(patch_vm))
123 .route(
124 "/machine-config",
125 get(get_machine_config)
126 .put(put_machine_config)
127 .patch(patch_machine_config),
128 )
129 .route("/boot-source", put(put_boot_source))
130 .route(
131 "/drives/{id}",
132 put(put_drive).patch(patch_drive).delete(delete_drive),
133 )
134 .route(
135 "/network-interfaces/{id}",
136 put(put_network).patch(patch_network).delete(delete_network),
137 )
138 .route("/vsock", put(put_vsock))
139 .route("/mmds", get(get_mmds).put(put_mmds).patch(patch_mmds))
140 .route("/mmds/config", put(put_mmds_config))
141 .route(
142 "/balloon",
143 get(get_balloon).put(put_balloon).patch(patch_balloon),
144 )
145 .route(
146 "/balloon/statistics",
147 get(get_balloon_statistics).patch(patch_balloon_statistics),
148 )
149 .route("/balloon/hinting/{op}", patch(patch_balloon_hinting))
150 .route("/entropy", put(put_entropy))
151 .route("/serial", put(put_serial))
152 .route(
153 "/pmem/{id}",
154 put(put_pmem).patch(patch_pmem).delete(delete_pmem),
155 )
156 .route(
157 "/hotplug/memory",
158 get(get_hotplug_memory)
159 .put(put_hotplug_memory)
160 .patch(patch_hotplug_memory),
161 )
162 .route("/cpu-config", put(put_cpu_config))
163 .route("/actions", put(put_actions))
164 .route("/snapshot/create", put(put_snapshot_create))
165 .route("/snapshot/load", put(put_snapshot_load))
166 .route("/logger", put(put_logger))
167 .route("/metrics", put(put_metrics))
168 .fallback(fallback)
169 .with_state(controller)
170 .layer(server_layer)
171 .layer(RequestBodyLimitLayer::new(max_payload))
172 .layer(trace_layer)
173}
174
175pub async fn bind_listener(opts: &ServeOptions) -> std::io::Result<UnixListener> {
183 if opts.socket_path.exists() {
184 tokio::fs::remove_file(&opts.socket_path).await?;
185 }
186 UnixListener::bind(&opts.socket_path)
187}
188
189pub async fn serve_bound(
194 listener: UnixListener,
195 opts: ServeOptions,
196 controller: Arc<RuntimeApiController>,
197) -> std::io::Result<()> {
198 info!(
199 socket = %opts.socket_path.display(),
200 max_payload_size = opts.max_payload_size,
201 "squib-api listening",
202 );
203
204 let app = router(controller, opts.max_payload_size);
205 axum_serve(listener, app).await
206}
207
208pub async fn serve(
214 opts: ServeOptions,
215 controller: Arc<RuntimeApiController>,
216) -> std::io::Result<()> {
217 let listener = bind_listener(&opts).await?;
218 serve_bound(listener, opts, controller).await
219}
220
221pub async fn unlink_socket_if_exists(path: &Path) -> std::io::Result<()> {
225 match tokio::fs::remove_file(path).await {
226 Ok(()) => Ok(()),
227 Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
228 Err(err) => Err(err),
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::{ControllerSnapshot, TimeoutTable};

    /// Builds a controller over a fixed test snapshot.
    ///
    /// The second value returned by `RuntimeApiController::new` is dropped when
    /// this function returns.
    fn ctl() -> Arc<RuntimeApiController> {
        let snap = ControllerSnapshot::new("anonymous", "1.16.0", "1.16.0 (squib 0.0.0-test)");
        let (c, _rx) = RuntimeApiController::new(snap, TimeoutTable::from_spec(), 16);
        Arc::new(c)
    }

    #[test]
    fn test_should_build_router_against_controller() {
        let _ = router(ctl(), DEFAULT_MAX_PAYLOAD);
    }

    #[test]
    fn test_should_default_payload_limit_to_51200() {
        let opts = ServeOptions::new("/tmp/squib.sock");
        assert_eq!(opts.max_payload_size, DEFAULT_MAX_PAYLOAD);
    }

    #[test]
    fn test_should_store_socket_path() {
        let opts = ServeOptions::new("/tmp/squib.sock");
        assert_eq!(opts.socket_path, PathBuf::from("/tmp/squib.sock"));
    }

    #[test]
    fn test_should_clamp_payload_limit_to_lower_bound() {
        let opts = ServeOptions::new("/tmp/squib.sock").with_max_payload_size(0);
        assert_eq!(opts.max_payload_size, MIN_MAX_PAYLOAD);
    }

    #[test]
    fn test_should_clamp_payload_limit_to_upper_bound() {
        let opts = ServeOptions::new("/tmp/squib.sock").with_max_payload_size(usize::MAX);
        assert_eq!(opts.max_payload_size, MAX_MAX_PAYLOAD);
    }

    #[test]
    fn test_should_keep_in_range_payload_limit_unchanged() {
        let opts = ServeOptions::new("/tmp/squib.sock").with_max_payload_size(4_096);
        assert_eq!(opts.max_payload_size, 4_096);
    }

    #[test]
    fn test_should_accept_payload_limit_at_exact_bounds() {
        let low = ServeOptions::new("/tmp/squib.sock").with_max_payload_size(MIN_MAX_PAYLOAD);
        assert_eq!(low.max_payload_size, MIN_MAX_PAYLOAD);
        let high = ServeOptions::new("/tmp/squib.sock").with_max_payload_size(MAX_MAX_PAYLOAD);
        assert_eq!(high.max_payload_size, MAX_MAX_PAYLOAD);
    }
}