use std::fs;
use std::io::{Read as _, Write as _};
use std::net::{TcpListener, TcpStream};
use std::path::{Path, PathBuf};
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering},
};
use std::thread::{self, JoinHandle};
use std::time::Duration;
use anyhow::{Context as _, Result, anyhow};
use percent_encoding::percent_decode_str;
/// A small HTTP server that serves files from a directory on `127.0.0.1`.
///
/// Connections are accepted and answered one at a time on a single background thread.
/// Dropping the value sets the stop flag and joins that thread.
pub(super) struct StaticServer {
    /// Set to `true` to ask the accept loop to exit.
    stop: Arc<AtomicBool>,
    /// The background accept-loop thread; `None` once it has been taken for joining.
    handle: Option<JoinHandle<()>>,
    /// Port the listener is actually bound to (read back from the socket, so it is the
    /// OS-assigned port when 0 was requested).
    port: u16,
}
impl StaticServer {
    /// Binds `127.0.0.1:port` and starts serving files under `root` on a background thread.
    ///
    /// Pass `port = 0` to let the OS choose a free port; the port actually bound is the one
    /// [`StaticServer::base_url`] reports. Requests are served one at a time on that thread.
    ///
    /// # Errors
    /// Returns an error if the address cannot be bound, the listener cannot be made
    /// non-blocking, or its local address cannot be read.
    pub(super) fn start(root: PathBuf, port: u16) -> Result<Self> {
        let listener =
            TcpListener::bind(("127.0.0.1", port)).with_context(|| format!("bind port {port}"))?;
        // Non-blocking accept lets the loop below re-check `stop` instead of parking in
        // accept() forever.
        listener
            .set_nonblocking(true)
            .context("set static server nonblocking")?;
        // Read the port back: it differs from the argument when 0 was requested.
        let port = listener
            .local_addr()
            .context("read static server addr")?
            .port();
        let stop = Arc::new(AtomicBool::new(false));
        let thread_stop = Arc::clone(&stop);
        let handle = thread::spawn(move || {
            while !thread_stop.load(Ordering::Acquire) {
                match listener.accept() {
                    Ok((stream, _)) => {
                        // A failed request (client hung up, bad socket) only affects that
                        // one client; keep serving the others.
                        let _ = serve_static_request(stream, &root);
                    }
                    Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => {
                        // No pending connection: back off briefly before polling again.
                        thread::sleep(Duration::from_millis(10));
                    }
                    // Failures tied to a single connection (signal interruption, a client
                    // resetting before accept completed) leave the listener usable, so they
                    // must not shut the whole server down.
                    Err(error)
                        if matches!(
                            error.kind(),
                            std::io::ErrorKind::Interrupted
                                | std::io::ErrorKind::ConnectionAborted
                                | std::io::ErrorKind::ConnectionReset
                        ) => {}
                    Err(_) => break,
                }
            }
        });
        Ok(Self {
            stop,
            handle: Some(handle),
            port,
        })
    }

    /// Returns the server's root URL, e.g. `http://127.0.0.1:8080/`.
    pub(super) fn base_url(&self) -> String {
        format!("http://127.0.0.1:{}/", self.port)
    }

    /// Consumes the server and blocks until its background thread exits.
    ///
    /// The stop flag is only set when the server is dropped, so while this call is
    /// blocked the thread exits only after a non-recoverable `accept` error.
    ///
    /// # Errors
    /// Returns an error if the server thread panicked.
    pub(super) fn wait_forever(mut self) -> Result<()> {
        if let Some(handle) = self.handle.take() {
            handle
                .join()
                .map_err(|_| anyhow!("static server thread panicked"))?;
        }
        Ok(())
    }
}
impl Drop for StaticServer {
fn drop(&mut self) {
self.stop.store(true, Ordering::Release);
let _ = TcpStream::connect(("127.0.0.1", self.port));
if let Some(handle) = self.handle.take() {
let _ = handle.join();
}
}
}
fn serve_static_request(mut stream: TcpStream, root: &Path) -> Result<()> {
let mut buf = [0; 2048];
let n = stream.read(&mut buf).context("read request")?;
let request = String::from_utf8_lossy(&buf[..n]);
let mut request_parts = request
.lines()
.next()
.unwrap_or_default()
.split_whitespace();
let method = request_parts.next().unwrap_or_default();
let request_path = request_parts.next().map_or("/", |path| path);
if method != "GET" && method != "HEAD" {
return write_http_response(
&mut stream,
"405 Method Not Allowed",
"text/plain",
b"method not allowed",
method == "HEAD",
);
}
let file = resolve_static_path(root, request_path);
if let Ok(bytes) = fs::read(&file) {
write_http_response(
&mut stream,
"200 OK",
content_type(&file),
&bytes,
method == "HEAD",
)?;
} else {
write_http_response(
&mut stream,
"404 Not Found",
"text/plain; charset=utf-8",
b"not found",
method == "HEAD",
)?;
}
Ok(())
}
/// Maps an HTTP request target to a file path under `root`.
///
/// The query string is dropped and the path is percent-decoded. An empty path (after
/// stripping leading slashes) maps to `root/index.html`, and a target ending in `/` maps
/// to `index.html` inside that directory. Targets that fail to decode, or that could
/// escape `root`, map to `root/__invalid__`.
/// NOTE(review): that sentinel relies on no file named `__invalid__` existing under `root`.
pub(super) fn resolve_static_path(root: &Path, request_path: &str) -> PathBuf {
    let request_path_without_query = request_path
        .split_once('?')
        .map_or(request_path, |(path, _query)| path);
    let Some(decoded) = percent_decode_path(request_path_without_query) else {
        return root.join("__invalid__");
    };
    let clean = decoded.trim_start_matches('/');
    if clean.is_empty() {
        return root.join("index.html");
    }
    // Reject explicit `.`/`..` segments. This is checked on the string because
    // `Path::components` silently normalizes away interior `.` segments.
    let has_dot_segment = clean
        .split('/')
        .any(|component| component == "." || component == "..");
    // Also require every platform path component to be a plain name. This catches escapes
    // the `/` split cannot see, such as `..\` traversal, drive prefixes (`C:`) and `\`
    // roots on Windows. Any of those would make `Path::join` discard `root` entirely.
    let only_normal_components = Path::new(clean)
        .components()
        .all(|component| matches!(component, std::path::Component::Normal(_)));
    if has_dot_segment || !only_normal_components {
        return root.join("__invalid__");
    }
    let path = root.join(clean);
    if request_path_without_query.ends_with('/') {
        path.join("index.html")
    } else {
        path
    }
}
/// Returns the `Content-Type` header value to serve `path` with, chosen by its extension.
///
/// Matching is ASCII case-insensitive, so `logo.PNG` is served as `image/png`. Unknown
/// or missing (or non-UTF-8) extensions fall back to `application/octet-stream`.
pub(super) fn content_type(path: &Path) -> &'static str {
    let Some(extension) = path.extension().and_then(|extension| extension.to_str()) else {
        return "application/octet-stream";
    };
    match extension.to_ascii_lowercase().as_str() {
        "css" => "text/css; charset=utf-8",
        "gif" => "image/gif",
        "htm" | "html" => "text/html; charset=utf-8",
        "ico" => "image/x-icon",
        // Browsers refuse to execute module scripts served with a non-JavaScript MIME
        // type, so `.mjs` must be covered too.
        "js" | "mjs" => "text/javascript; charset=utf-8",
        "json" => "application/json; charset=utf-8",
        "jpg" | "jpeg" => "image/jpeg",
        "png" => "image/png",
        "svg" => "image/svg+xml",
        "txt" => "text/plain; charset=utf-8",
        // `WebAssembly.instantiateStreaming` requires exactly this type.
        "wasm" => "application/wasm",
        "webp" => "image/webp",
        "woff" => "font/woff",
        "woff2" => "font/woff2",
        _ => "application/octet-stream",
    }
}
fn write_http_response(
stream: &mut TcpStream,
status: &str,
content_type: &str,
body: &[u8],
head_only: bool,
) -> Result<()> {
let header = format!(
"HTTP/1.1 {status}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
body.len()
);
stream.write_all(header.as_bytes())?;
if !head_only {
stream.write_all(body)?;
}
Ok(())
}
/// Percent-decodes a request path into an owned `String`.
///
/// Returns `None` when `path` contains a malformed `%` escape (checked before any
/// decoding) or when the decoded bytes are not valid UTF-8.
fn percent_decode_path(path: &str) -> Option<String> {
    has_valid_percent_escapes(path)
        .then(|| percent_decode_str(path).decode_utf8().ok())
        .flatten()
        .map(|decoded| decoded.into_owned())
}
/// Reports whether every `%` in `path` begins a complete two-hex-digit escape (`%XX`).
///
/// A `%` at the end of the string, one followed by a single character, or one followed by
/// any non-hex byte makes the whole path invalid.
fn has_valid_percent_escapes(path: &str) -> bool {
    // The text after each `%` (up to the next `%`) must open with two hex digits. Hex
    // digits are never `%`, so a valid escape can never straddle two of these chunks.
    path.split('%').skip(1).all(|after_percent| {
        matches!(
            after_percent.as_bytes(),
            [hi, lo, ..] if hi.is_ascii_hexdigit() && lo.is_ascii_hexdigit()
        )
    })
}