1use std::collections::HashMap;
14use std::net::SocketAddr;
15use std::path::PathBuf;
16use std::sync::Arc;
17
18use axum::Router;
19use axum::body::Body;
20use axum::extract::{Path, State};
21use axum::http::{HeaderMap, Response};
22use axum::routing::any;
23use tokio::sync::RwLock;
24use tower_http::services::ServeDir;
25
26type RouteHandlerFn = Arc<dyn Fn(&str, &HeaderMap) -> RouteResponse + Send + Sync>;
27
/// Response produced by a route handler registered on a test server.
///
/// `content_type` is sent as the `content-type` header; the entries in `headers` are added after
/// it, in order.
#[derive(Debug, Clone)]
pub struct RouteResponse {
    /// HTTP status code.
    pub status: u16,
    /// Value of the `content-type` response header.
    pub content_type: String,
    /// Raw response body.
    pub body: Vec<u8>,
    /// Extra response headers as `(name, value)` pairs.
    pub headers: Vec<(String, String)>,
}

impl RouteResponse {
    /// Shared constructor: the given status, content type and body, with no extra headers.
    fn from_parts(status: u16, content_type: &str, body: &[u8]) -> Self {
        Self {
            status,
            content_type: content_type.to_owned(),
            body: body.to_vec(),
            headers: Vec::new(),
        }
    }

    /// `200` response with `content-type: text/html` and `body` as its UTF-8 bytes.
    #[must_use]
    pub fn html(body: &str) -> Self {
        Self::from_parts(200, "text/html", body.as_bytes())
    }

    /// `200` response with `content-type: application/json` and `body` as its UTF-8 bytes.
    #[must_use]
    pub fn json(body: &str) -> Self {
        Self::from_parts(200, "application/json", body.as_bytes())
    }

    /// `200` response with `content-type: text/plain` and `body` as its UTF-8 bytes.
    #[must_use]
    pub fn text(body: &str) -> Self {
        Self::from_parts(200, "text/plain", body.as_bytes())
    }

    /// Empty-bodied response with status `code` and `content-type: text/plain`.
    #[must_use]
    pub fn status(code: u16) -> Self {
        Self::from_parts(code, "text/plain", &[])
    }
}
77
/// A request received by the test server, captured before any routing decision is made.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecordedRequest {
    /// Request path as matched by the router, always starting with `/`.
    pub path: String,
    /// HTTP method name, e.g. `GET`.
    pub method: String,
    /// Request headers whose values are valid strings; for a repeated header only the last
    /// value is kept.
    pub headers: HashMap<String, String>,
    /// Raw request body.
    pub body: Vec<u8>,
}
86
87struct ServerState {
88 routes: RwLock<HashMap<String, RouteHandlerFn>>,
89 requests: RwLock<Vec<RecordedRequest>>,
90 assets_dir: PathBuf,
91 spa: bool,
92}
93
94pub struct TestServer {
99 addr: SocketAddr,
100 state: Arc<ServerState>,
101 shutdown_tx: tokio::sync::oneshot::Sender<()>,
102 handle: tokio::task::JoinHandle<()>,
103}
104
105impl TestServer {
106 pub async fn start(assets_dir: impl Into<PathBuf>) -> ferridriver::error::Result<Self> {
112 Self::start_with_options(assets_dir.into(), 0, false).await
113 }
114
115 pub async fn start_spa(assets_dir: impl Into<PathBuf>) -> ferridriver::error::Result<Self> {
117 Self::start_with_options(assets_dir.into(), 0, true).await
118 }
119
120 pub async fn from_config(config: &crate::config::WebServerConfig) -> ferridriver::error::Result<Self> {
122 let dir = config.static_dir.as_deref().unwrap_or(".");
123 Self::start_with_options(PathBuf::from(dir), config.port, config.spa).await
124 }
125
126 async fn start_with_options(assets_dir: PathBuf, port: u16, spa: bool) -> ferridriver::error::Result<Self> {
127 let state = Arc::new(ServerState {
128 routes: RwLock::new(HashMap::new()),
129 requests: RwLock::new(Vec::new()),
130 assets_dir: assets_dir.clone(),
131 spa,
132 });
133
134 let state2 = state.clone();
135 let fallback = ServeDir::new(&assets_dir).append_index_html_on_directories(true);
136
137 let app = Router::new()
138 .route("/{*path}", any(handle_request))
139 .route("/", any(handle_request))
140 .with_state(state2)
141 .fallback_service(fallback);
142
143 let bind_addr = format!("127.0.0.1:{port}");
144 let listener = tokio::net::TcpListener::bind(&bind_addr).await?;
145 let addr = listener.local_addr()?;
146 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
147
148 let handle = tokio::spawn(async move {
149 axum::serve(listener, app)
150 .with_graceful_shutdown(async {
151 let _ = shutdown_rx.await;
152 })
153 .await
154 .ok();
155 });
156
157 Ok(Self {
158 addr,
159 state,
160 shutdown_tx,
161 handle,
162 })
163 }
164
165 #[must_use]
167 pub fn url(&self) -> String {
168 format!("http://{}", self.addr)
169 }
170
171 #[must_use]
173 pub fn prefix(&self) -> String {
174 self.url()
175 }
176
177 #[must_use]
179 pub fn empty_page(&self) -> String {
180 format!("{}/empty.html", self.url())
181 }
182
183 pub async fn set_route(&self, path: &str, handler: RouteHandlerFn) {
185 self.state.routes.write().await.insert(path.to_string(), handler);
186 }
187
188 pub async fn set_content(&self, path: &str, content_type: &str, body: &str) {
190 let ct = content_type.to_string();
191 let b = body.as_bytes().to_vec();
192 self
193 .set_route(
194 path,
195 Arc::new(move |_, _| RouteResponse {
196 status: 200,
197 content_type: ct.clone(),
198 body: b.clone(),
199 headers: vec![],
200 }),
201 )
202 .await;
203 }
204
205 pub async fn requests(&self) -> Vec<RecordedRequest> {
207 self.state.requests.read().await.clone()
208 }
209
210 pub async fn requests_for(&self, path: &str) -> Vec<RecordedRequest> {
212 self
213 .state
214 .requests
215 .read()
216 .await
217 .iter()
218 .filter(|r| r.path.starts_with(path))
219 .cloned()
220 .collect()
221 }
222
223 pub async fn clear_requests(&self) {
225 self.state.requests.write().await.clear();
226 }
227
228 pub async fn stop(self) {
230 let _ = self.shutdown_tx.send(());
231 let _ = self.handle.await;
232 }
233}
234
235async fn handle_request(
236 State(state): State<Arc<ServerState>>,
237 path: Option<Path<String>>,
238 headers: HeaderMap,
239 method: axum::http::Method,
240 body: axum::body::Bytes,
241) -> Response<Body> {
242 let request_path = format!("/{}", path.as_ref().map(|p| p.as_str()).unwrap_or(""));
243
244 let mut header_map = HashMap::new();
246 for (name, value) in &headers {
247 if let Ok(v) = value.to_str() {
248 header_map.insert(name.to_string(), v.to_string());
249 }
250 }
251 state.requests.write().await.push(RecordedRequest {
252 path: request_path.clone(),
253 method: method.to_string(),
254 headers: header_map,
255 body: body.to_vec(),
256 });
257
258 let routes = state.routes.read().await;
260 if let Some(handler) = routes.get(&request_path) {
261 let resp = handler(&request_path, &headers);
262 let mut builder = Response::builder().status(resp.status);
263 builder = builder.header("content-type", &resp.content_type);
264 builder = builder.header("access-control-allow-origin", "*");
265 for (k, v) in &resp.headers {
266 builder = builder.header(k.as_str(), v.as_str());
267 }
268 return builder.body(Body::from(resp.body)).unwrap_or_else(|_| {
269 Response::builder()
270 .status(500)
271 .body(Body::empty())
272 .expect("empty 500 response")
273 });
274 }
275 drop(routes);
276
277 let file_path = state.assets_dir.join(request_path.trim_start_matches('/'));
280 if file_path.exists() && file_path.is_file() {
281 let content_type = mime_guess::from_path(&file_path).first_or_octet_stream().to_string();
282 match tokio::fs::read(&file_path).await {
283 Ok(contents) => Response::builder()
284 .status(200)
285 .header("content-type", content_type)
286 .header("access-control-allow-origin", "*")
287 .body(Body::from(contents))
288 .expect("static file response"),
289 Err(_) => Response::builder()
290 .status(500)
291 .body(Body::empty())
292 .expect("empty 500 response"),
293 }
294 } else if state.spa {
295 let index = state.assets_dir.join("index.html");
297 if index.exists() {
298 match tokio::fs::read(&index).await {
299 Ok(contents) => Response::builder()
300 .status(200)
301 .header("content-type", "text/html")
302 .header("access-control-allow-origin", "*")
303 .body(Body::from(contents))
304 .expect("SPA index.html response"),
305 Err(_) => Response::builder()
306 .status(500)
307 .body(Body::empty())
308 .expect("empty 500 response"),
309 }
310 } else {
311 Response::builder()
312 .status(404)
313 .header("content-type", "text/plain")
314 .body(Body::from("Not Found (SPA: no index.html)"))
315 .expect("404 response")
316 }
317 } else {
318 Response::builder()
319 .status(404)
320 .header("content-type", "text/plain")
321 .body(Body::from("Not Found"))
322 .expect("404 response")
323 }
324}
325
326pub struct WebServerManager {
331 servers: Vec<RunningServer>,
332}
333
334enum RunningServer {
335 Static(Box<StaticEntry>),
336 Command(Box<CommandEntry>),
337}
338
339struct StaticEntry {
340 server: TestServer,
341 name: String,
342}
343
344struct CommandEntry {
345 child: tokio::process::Child,
346 url: String,
347 name: String,
348 graceful: Option<crate::config::GracefulShutdown>,
349}
350
351impl WebServerManager {
352 pub async fn start(configs: &[crate::config::WebServerConfig]) -> ferridriver::error::Result<Self> {
359 let mut servers = Vec::with_capacity(configs.len());
360 for config in configs {
361 let display_name = config.name.clone().unwrap_or_else(|| "WebServer".to_string());
362 if let Some(ref dir) = config.static_dir {
363 let server = TestServer::start_with_options(PathBuf::from(dir), config.port, config.spa).await?;
364 tracing::info!(name = %display_name, "[{display_name}] Static server ready at {} (serving {})", server.url(), dir);
365 servers.push(RunningServer::Static(Box::new(StaticEntry {
366 server,
367 name: display_name,
368 })));
369 } else if let Some(ref command) = config.command {
370 let url = config.url.as_deref().ok_or_else(|| {
371 ferridriver::FerriError::invalid_argument(
372 "webServer.url",
373 format!("webServer command requires 'url' to wait for: {command}"),
374 )
375 })?;
376
377 if config.reuse_existing_server && http_probe(url, config.ignore_https_errors).await {
381 tracing::info!(name = %display_name, "[{display_name}] Reusing existing server at {url}");
382 servers.push(RunningServer::Command(Box::new(CommandEntry {
387 child: tokio::process::Command::new("true").spawn()?,
388 url: url.to_string(),
389 name: display_name,
390 graceful: config.graceful_shutdown.clone(),
391 })));
392 continue;
393 }
394
395 let cwd = config.cwd.as_deref().unwrap_or(".");
396 let child = spawn_command(command, cwd, &config.env)?;
397 wait_for_url(url, config.timeout, config.ignore_https_errors, &display_name).await?;
398 tracing::info!(name = %display_name, "[{display_name}] Dev server ready at {url} (command: {command})");
399 servers.push(RunningServer::Command(Box::new(CommandEntry {
400 child,
401 url: url.to_string(),
402 name: display_name,
403 graceful: config.graceful_shutdown.clone(),
404 })));
405 } else {
406 return Err(ferridriver::FerriError::invalid_argument(
407 "webServer",
408 "webServer config must have either 'command' or 'staticDir'",
409 ));
410 }
411 }
412 Ok(Self { servers })
413 }
414
415 #[must_use]
417 pub fn first_url(&self) -> Option<String> {
418 self.servers.first().map(|s| match s {
419 RunningServer::Static(entry) => entry.server.url(),
420 RunningServer::Command(entry) => entry.url.clone(),
421 })
422 }
423
424 pub fn test_server(&self) -> Option<&TestServer> {
426 self.servers.first().and_then(|s| match s {
427 RunningServer::Static(entry) => Some(&entry.server),
428 RunningServer::Command(_) => None,
429 })
430 }
431
432 pub async fn stop(self) {
438 for server in self.servers {
439 match server {
440 RunningServer::Static(entry) => {
441 let StaticEntry { server, name } = *entry;
442 tracing::info!(name = %name, "[{name}] Stopping static server");
443 server.stop().await;
444 },
445 RunningServer::Command(entry) => {
446 let CommandEntry {
447 mut child,
448 name,
449 graceful,
450 ..
451 } = *entry;
452 stop_child(&mut child, &name, graceful.as_ref()).await;
453 },
454 }
455 }
456 }
457}
458
459async fn stop_child(child: &mut tokio::process::Child, name: &str, graceful: Option<&crate::config::GracefulShutdown>) {
460 let Some(g) = graceful else {
461 tracing::info!(name = %name, "[{name}] Hard-killing child process");
462 let _ = child.kill().await;
463 return;
464 };
465
466 let Some(pid) = child.id() else {
467 let _ = child.wait().await;
469 return;
470 };
471
472 let signum = parse_signal(&g.signal);
473 tracing::info!(
474 name = %name,
475 "[{name}] Sending {} (graceful_shutdown), waiting up to {}ms before SIGKILL",
476 g.signal,
477 g.timeout
478 );
479 #[cfg(unix)]
480 send_signal(pid, signum);
481 #[cfg(not(unix))]
482 {
483 let _ = (pid, signum);
484 let _ = child.kill().await;
485 return;
486 }
487
488 let timeout = std::time::Duration::from_millis(g.timeout);
489 if tokio::time::timeout(timeout, child.wait()).await.is_ok() {
490 tracing::info!(name = %name, "[{name}] Process exited gracefully");
491 } else {
492 tracing::warn!(
493 name = %name,
494 "[{name}] Process did not exit within {}ms — escalating to SIGKILL",
495 g.timeout
496 );
497 let _ = child.kill().await;
498 }
499}
500
501fn parse_signal(name: &str) -> libc::c_int {
502 match name.trim().to_ascii_uppercase().as_str() {
503 "SIGINT" => libc::SIGINT,
504 "SIGKILL" => libc::SIGKILL,
505 _ => libc::SIGTERM,
506 }
507}
508
509#[cfg(unix)]
510#[allow(unsafe_code)]
511fn send_signal(pid: u32, signum: libc::c_int) {
512 #[allow(clippy::cast_possible_wrap)]
515 let pid_i = pid as libc::pid_t;
516 unsafe {
521 libc::kill(pid_i, signum);
522 }
523}
524
525fn spawn_command(
526 command: &str,
527 cwd: &str,
528 env: &std::collections::BTreeMap<String, String>,
529) -> ferridriver::error::Result<tokio::process::Child> {
530 let mut cmd = if cfg!(target_os = "windows") {
531 let mut c = tokio::process::Command::new("cmd");
532 c.args(["/C", command]);
533 c
534 } else {
535 let mut c = tokio::process::Command::new("sh");
541 c.args(["-c", &format!("exec {command}")]);
542 c
543 };
544 cmd.current_dir(cwd);
545 for (k, v) in env {
546 cmd.env(k, v);
547 }
548 cmd
549 .stdin(std::process::Stdio::null())
550 .stdout(std::process::Stdio::piped())
551 .stderr(std::process::Stdio::piped());
552 cmd
553 .spawn()
554 .map_err(|e| ferridriver::FerriError::backend(format!("spawn '{command}': {e}")))
555}
556
557#[must_use]
562pub fn build_probe_client(ignore_https_errors: bool) -> reqwest::Client {
563 reqwest::Client::builder()
564 .danger_accept_invalid_certs(ignore_https_errors)
565 .timeout(std::time::Duration::from_secs(5))
566 .build()
567 .unwrap_or_else(|_| reqwest::Client::new())
568}
569
570pub async fn http_probe(url: &str, ignore_https_errors: bool) -> bool {
574 let client = build_probe_client(ignore_https_errors);
575 match probe_status(&client, url).await {
576 Some(s) if (200..404).contains(&s) => true,
577 Some(404) => {
578 let index_url = if url.ends_with('/') {
580 format!("{url}index.html")
581 } else {
582 format!("{url}/index.html")
583 };
584 matches!(probe_status(&client, &index_url).await, Some(s) if (200..404).contains(&s))
585 },
586 _ => false,
587 }
588}
589
590async fn probe_status(client: &reqwest::Client, url: &str) -> Option<u16> {
591 match client.get(url).send().await {
592 Ok(resp) => Some(resp.status().as_u16()),
593 Err(_) => None,
594 }
595}
596
597async fn wait_for_url(
599 url: &str,
600 timeout_ms: u64,
601 ignore_https_errors: bool,
602 name: &str,
603) -> ferridriver::error::Result<()> {
604 let deadline = tokio::time::Instant::now() + std::time::Duration::from_millis(timeout_ms);
605
606 let mut delays = [100u64, 250, 500].iter().copied();
608
609 loop {
610 if tokio::time::Instant::now() >= deadline {
611 return Err(ferridriver::FerriError::timeout(
612 format!("[{name}] webServer {url}"),
613 timeout_ms,
614 ));
615 }
616 if http_probe(url, ignore_https_errors).await {
617 return Ok(());
618 }
619 let delay = delays.next().unwrap_or(1000);
620 tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
621 }
622}