use std::fs;
use std::path::{Path, PathBuf};
use std::time::UNIX_EPOCH;
use crate::handler::Handler;
use crate::mime;
use crate::proto::{Method, Request, Response, StatusCode};
/// A [`Handler`] that serves files from a directory tree on disk.
///
/// Only `GET` and `HEAD` are accepted. Successful responses carry an `ETag`
/// and, when the file's modification time is available, a `Last-Modified`
/// header; conditional requests and single byte-range requests are honoured.
#[derive(Debug, Clone)]
pub struct StaticFiles {
    /// Directory that request paths are resolved beneath.
    root: PathBuf,
    /// File name served for requests that resolve to a directory
    /// (`index.html` unless changed with `StaticFiles::index`).
    index: String,
}
impl StaticFiles {
    /// Creates a handler serving files from beneath `root`.
    ///
    /// Requests that resolve to a directory are answered with that directory's
    /// `index.html`; see [`StaticFiles::index`] to change the name.
    pub fn new(root: impl Into<PathBuf>) -> StaticFiles {
        StaticFiles {
            root: root.into(),
            index: "index.html".to_owned(),
        }
    }

    /// Sets the file name served when a request resolves to a directory.
    pub fn index(mut self, name: impl Into<String>) -> StaticFiles {
        self.index = name.into();
        self
    }

    /// Lexically maps a request path onto a path beneath `self.root`.
    ///
    /// The path is percent-decoded, then split on `/`; empty and `.` segments
    /// are skipped. Returns `None` if any segment is `..` or contains a NUL or a
    /// backslash, so encoded forms such as `%2e%2e` or `%5c` cannot climb out of
    /// the root. Symlinks are not examined here; [`StaticFiles::within_root`]
    /// covers them once the path is known to exist.
    fn resolve(&self, req_path: &str) -> Option<PathBuf> {
        let decoded = percent_decode(req_path);
        let mut out = self.root.clone();
        for seg in decoded.split('/') {
            if seg.is_empty() || seg == "." {
                continue;
            }
            if seg == ".." {
                return None;
            }
            // `/` cannot survive the split above; the check is kept as a guard.
            if seg.contains('\0') || seg.contains('/') || seg.contains('\\') {
                return None;
            }
            out.push(seg);
        }
        Some(out)
    }

    /// Returns `true` if `path`, with all symlinks resolved, lies inside the
    /// canonical root. Returns `false` when either path cannot be canonicalized
    /// (for example because it does not exist).
    fn within_root(&self, path: &Path) -> bool {
        match (fs::canonicalize(&self.root), fs::canonicalize(path)) {
            (Ok(root), Ok(target)) => target.starts_with(root),
            _ => false,
        }
    }

    /// Builds the `Location` that redirects a directory request to its
    /// slash-terminated form, preserving the query string.
    ///
    /// Leading slashes are collapsed to one: a target such as `//name/` would be
    /// read by browsers as a protocol-relative URL pointing at host `name`.
    fn directory_location(req: &Request) -> String {
        let rest = req.path().trim_start_matches('/');
        let mut loc = String::from("/");
        if !rest.is_empty() {
            loc.push_str(rest);
            loc.push('/');
        }
        if let Some(q) = req.query() {
            loc.push('?');
            loc.push_str(q);
        }
        loc
    }

    /// Answers `req` from the filesystem.
    ///
    /// Status codes produced:
    /// - `405` (with `Allow: GET, HEAD`) for any other method;
    /// - `403` when the path is rejected by `resolve` or escapes the root;
    /// - `404` when nothing, or no regular file, exists at the path;
    /// - `301` to the slash-terminated URL for a directory requested without a
    ///   trailing `/`;
    /// - `304` when the request's validators match (see `conditional_hit`);
    /// - `206` for a satisfiable single byte range, `416` for an unsatisfiable
    ///   single range;
    /// - `500` if the file cannot be read;
    /// - `200` otherwise (including multi-range requests, which are ignored).
    ///
    /// NOTE(review): the whole file is read into memory, and `HEAD` responses are
    /// built with the full body; presumably the connection layer omits the body
    /// for `HEAD` — confirm.
    fn serve(&self, req: &Request) -> Response {
        if !matches!(req.method(), Method::Get | Method::Head) {
            return Response::status(StatusCode::METHOD_NOT_ALLOWED).header("Allow", "GET, HEAD");
        }
        let Some(mut path) = self.resolve(req.path()) else {
            return Response::status(StatusCode::FORBIDDEN);
        };
        let Ok(meta) = fs::metadata(&path) else {
            return Response::status(StatusCode::NOT_FOUND);
        };
        // Confine the target before any response that depends on what it is (the
        // directory redirect below), so a symlink leading outside the root cannot
        // be probed through it.
        if !self.within_root(&path) {
            return Response::status(StatusCode::FORBIDDEN);
        }
        let meta = if meta.is_dir() {
            if !req.path().ends_with('/') {
                return Response::redirect(
                    StatusCode::MOVED_PERMANENTLY,
                    Self::directory_location(req),
                );
            }
            path.push(&self.index);
            let index_meta = match fs::metadata(&path) {
                Ok(m) if m.is_file() => m,
                _ => return Response::status(StatusCode::NOT_FOUND),
            };
            // The index file may itself be a symlink, so it is confined separately.
            if !self.within_root(&path) {
                return Response::status(StatusCode::FORBIDDEN);
            }
            index_meta
        } else if meta.is_file() {
            meta
        } else {
            return Response::status(StatusCode::NOT_FOUND);
        };

        let content_type = mime::from_path(path.to_string_lossy().as_ref());
        let mtime_secs = meta
            .modified()
            .ok()
            .and_then(|t| t.duration_since(UNIX_EPOCH).ok())
            .map(|d| d.as_secs());
        let last_modified = mtime_secs.map(crate::proto::http_date);
        // Validator built from size and whole-second mtime: a file rewritten within
        // the same second at the same length keeps its ETag.
        let etag = format!("\"{:x}-{:x}\"", meta.len(), mtime_secs.unwrap_or(0));

        if conditional_hit(req, &etag, last_modified.as_deref()) {
            let mut resp = Response::new(StatusCode::NOT_MODIFIED).header("ETag", etag);
            if let Some(lm) = last_modified {
                resp = resp.header("Last-Modified", lm);
            }
            return resp;
        }

        let bytes = match fs::read(&path) {
            Ok(b) => b,
            Err(_) => return Response::status(StatusCode::INTERNAL_SERVER_ERROR),
        };
        let total = bytes.len() as u64;

        if let Some(range) = req.headers().get("range") {
            if let Some((start, end)) = parse_single_range(range, total) {
                // `parse_single_range` guarantees `start <= end < total`.
                let slice = bytes[start as usize..=end as usize].to_vec();
                let mut resp = Response::new(StatusCode::PARTIAL_CONTENT)
                    .header("Content-Type", content_type)
                    .header("Accept-Ranges", "bytes")
                    .header("Content-Range", format!("bytes {start}-{end}/{total}"))
                    .header("ETag", etag);
                if let Some(lm) = last_modified {
                    resp = resp.header("Last-Modified", lm);
                }
                return resp.body(slice);
            }
            // Only single ranges are supported. A multi-range request may well be
            // satisfiable, so it gets the full representation (a server may ignore
            // `Range`) rather than a misleading 416.
            if range.trim_start().starts_with("bytes=") && !range.contains(',') {
                return Response::new(StatusCode::RANGE_NOT_SATISFIABLE)
                    .header("Content-Range", format!("bytes */{total}"));
            }
        }

        let mut resp = Response::new(StatusCode::OK)
            .header("Content-Type", content_type)
            .header("Accept-Ranges", "bytes")
            .header("ETag", etag);
        if let Some(lm) = last_modified {
            resp = resp.header("Last-Modified", lm);
        }
        resp.body(bytes)
    }
}
/// Delegates every request to the inherent `serve` method.
impl Handler for StaticFiles {
    fn handle(&self, req: &Request) -> Response {
        Self::serve(self, req)
    }
}
/// Reports whether the request's validators match the current file, in which
/// case the caller answers `304 Not Modified`.
///
/// When `If-None-Match` is present it alone decides, and `If-Modified-Since` is
/// ignored. Entity tags are compared weakly (a `W/` prefix is disregarded on
/// either side), as required for `If-None-Match`. `If-Modified-Since` matches
/// only when it is byte-identical to `last_modified`. Any other date counts as a
/// miss, which at worst costs a full response.
fn conditional_hit(req: &Request, etag: &str, last_modified: Option<&str>) -> bool {
    /// Trims whitespace and drops a weak-validator prefix, leaving the quoted
    /// opaque tag.
    fn opaque(tag: &str) -> &str {
        let tag = tag.trim();
        tag.strip_prefix("W/").unwrap_or(tag)
    }

    if let Some(inm) = req.headers().get("if-none-match") {
        let inm = inm.trim();
        let ours = opaque(etag);
        return inm == "*" || inm.split(',').any(|t| opaque(t) == ours);
    }
    if let (Some(ims), Some(lm)) = (req.headers().get("if-modified-since"), last_modified) {
        return ims == lm;
    }
    false
}
/// Parses a `Range` header that names exactly one byte range.
///
/// Returns the inclusive `(start, end)` offsets, clamped to a representation of
/// `total` bytes. Returns `None` in any of these cases:
/// - the header is not a `bytes=` range;
/// - it lists several ranges;
/// - it is malformed;
/// - it cannot be satisfied (every range against an empty representation is
///   unsatisfiable).
fn parse_single_range(value: &str, total: u64) -> Option<(u64, u64)> {
    let spec = value.trim().strip_prefix("bytes=")?;
    if total == 0 || spec.contains(',') {
        return None;
    }
    let (first, last) = spec.split_once('-')?;
    let first = first.trim();
    let last = last.trim();
    let last_index = total - 1;

    let (start, end) = if first.is_empty() {
        // Suffix form `-N`: the final N bytes, or the whole body if N is larger.
        if last.is_empty() {
            return None;
        }
        let suffix: u64 = last.parse().ok()?;
        if suffix == 0 {
            return None;
        }
        (total - suffix.min(total), last_index)
    } else {
        let start: u64 = first.parse().ok()?;
        let end = if last.is_empty() {
            last_index
        } else {
            last.parse::<u64>().ok()?.min(last_index)
        };
        (start, end)
    };

    (start <= end && start < total).then_some((start, end))
}
/// Decodes `%XX` escapes (hex digits of either case) in `s`.
///
/// A `%` that is not followed by two hex digits is kept literally. Decoded bytes
/// that do not form valid UTF-8 are replaced with U+FFFD.
fn percent_decode(s: &str) -> String {
    let src = s.as_bytes();
    let mut decoded: Vec<u8> = Vec::with_capacity(src.len());
    let mut pos = 0;
    while let Some(&byte) = src.get(pos) {
        let escaped = if byte == b'%' && pos + 2 < src.len() {
            let hi = (src[pos + 1] as char).to_digit(16);
            let lo = (src[pos + 2] as char).to_digit(16);
            hi.zip(lo).map(|(h, l)| (h * 16 + l) as u8)
        } else {
            None
        };
        match escaped {
            Some(value) => {
                decoded.push(value);
                pos += 3;
            }
            None => {
                decoded.push(byte);
                pos += 1;
            }
        }
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn range_parsing() {
        assert_eq!(parse_single_range("bytes=0-4", 10), Some((0, 4)));
        assert_eq!(parse_single_range("bytes=5-", 10), Some((5, 9)));
        assert_eq!(parse_single_range("bytes=-3", 10), Some((7, 9)));
        assert_eq!(parse_single_range("bytes=8-100", 10), Some((8, 9)));
        assert_eq!(parse_single_range("bytes=0-4,6-7", 10), None);
        assert_eq!(parse_single_range("bytes=20-30", 10), None);
    }

    #[test]
    fn range_edge_cases() {
        // A suffix longer than the representation selects all of it.
        assert_eq!(parse_single_range("bytes=-30", 10), Some((0, 9)));
        assert_eq!(parse_single_range("bytes=9-", 10), Some((9, 9)));
        assert_eq!(parse_single_range("bytes=-0", 10), None);
        assert_eq!(parse_single_range("bytes=-", 10), None);
        assert_eq!(parse_single_range("bytes=5-3", 10), None);
        assert_eq!(parse_single_range("bytes=abc", 10), None);
        // Nothing is satisfiable against an empty file.
        assert_eq!(parse_single_range("bytes=0-0", 0), None);
        assert_eq!(parse_single_range("items=0-4", 10), None);
    }

    #[test]
    fn percent_decoding() {
        assert_eq!(percent_decode("/a%20b"), "/a b");
        assert_eq!(percent_decode("/%2e%2e"), "/..");
        assert_eq!(percent_decode("/bad%2"), "/bad%2");
        assert_eq!(percent_decode("%zz"), "%zz");
        assert_eq!(percent_decode("%C3%A9"), "é");
        assert_eq!(percent_decode("%ff"), "\u{FFFD}");
    }

    #[test]
    fn traversal_rejected() {
        let sf = StaticFiles::new("/srv/www");
        assert!(sf.resolve("/../etc/passwd").is_none());
        assert!(sf.resolve("/a/%2e%2e/b").is_none());
        assert!(sf.resolve("/a%00b").is_none());
        assert!(sf.resolve("/a%5cb").is_none());
        assert_eq!(
            sf.resolve("/sub/file.txt"),
            Some(PathBuf::from("/srv/www/sub/file.txt"))
        );
    }

    #[test]
    fn resolve_skips_empty_and_dot_segments() {
        let sf = StaticFiles::new("/srv/www");
        assert_eq!(sf.resolve("/./a//b/"), Some(PathBuf::from("/srv/www/a/b")));
        assert_eq!(sf.resolve("/"), Some(PathBuf::from("/srv/www")));
    }
}