1#![forbid(unsafe_code)]
4
5pub mod runtime;
6
7mod buf;
8mod chunk;
9mod date;
10mod error;
11
12pub use self::{chunk::ChunkReader, error::ServeError};
13
14use std::{
15 io::SeekFrom,
16 path::{Component, Path, PathBuf},
17};
18
19use http::{
20 header::{HeaderValue, ACCEPT_RANGES, CONTENT_LENGTH, CONTENT_RANGE, CONTENT_TYPE, LAST_MODIFIED, RANGE},
21 Method, Request, Response, StatusCode,
22};
23use mime_guess::mime;
24
25use self::{
26 buf::buf_write_header,
27 runtime::{AsyncFs, ChunkRead, Meta},
28};
29
/// Serves files from a base directory in response to HTTP requests.
///
/// With the `tokio` feature enabled the filesystem backend `FS` defaults to
/// `runtime::TokioFs`.
#[cfg(feature = "tokio")]
#[derive(Clone)]
pub struct ServeDir<FS = runtime::TokioFs>
where
    FS: AsyncFs,
{
    /// Chunk size handed to the body reader when streaming a file.
    chunk_size: usize,
    /// Directory that request paths are resolved against.
    base_path: PathBuf,
    /// Backend used to open files.
    async_fs: FS,
}
37
38#[cfg(not(feature = "tokio"))]
39#[derive(Clone)]
40pub struct ServeDir<FS: AsyncFs> {
41 chunk_size: usize,
42 base_path: PathBuf,
43 async_fs: FS,
44}
45
// Gated on `tokio` rather than `default`: `default` is just a feature name, and gating on it
// made `new` vanish for `default-features = false, features = ["tokio"]` builds even though
// `runtime::TokioFs` (the `tokio`-gated default backend of `ServeDir`) is available there.
#[cfg(feature = "tokio")]
impl ServeDir<runtime::TokioFs> {
    /// Creates a [`ServeDir`] rooted at `path` that reads files through `runtime::TokioFs`.
    ///
    /// Bodies are streamed in chunks of 4096 by default; use [`ServeDir::chunk_size`] to
    /// change that.
    pub fn new(path: impl Into<PathBuf>) -> Self {
        Self::with_fs(path, runtime::TokioFs)
    }
}
53
#[cfg(feature = "tokio-uring")]
impl ServeDir<runtime::TokioUringFs> {
    /// Creates a [`ServeDir`] rooted at `path` whose file I/O goes through
    /// `runtime::TokioUringFs`.
    pub fn new_tokio_uring(path: impl Into<PathBuf>) -> Self {
        let uring_fs = runtime::TokioUringFs;
        Self::with_fs(path, uring_fs)
    }
}
61
62impl<FS: AsyncFs> ServeDir<FS> {
63 pub fn with_fs(path: impl Into<PathBuf>, async_fs: FS) -> Self {
66 Self {
67 chunk_size: 4096,
68 base_path: path.into(),
69 async_fs,
70 }
71 }
72
73 pub fn chunk_size(&mut self, size: usize) -> &mut Self {
77 self.chunk_size = size;
78 self
79 }
80
81 pub async fn serve<Ext>(&self, req: &Request<Ext>) -> Result<Response<ChunkReader<FS::File>>, ServeError> {
94 if !matches!(*req.method(), Method::HEAD | Method::GET) {
95 return Err(ServeError::MethodNotAllowed);
96 }
97
98 let path = self.path_check(req.uri().path())?;
99
100 if path.is_dir() {
102 return Err(ServeError::InvalidPath);
103 }
104
105 let ct = mime_guess::from_path(&path)
106 .first_raw()
107 .unwrap_or_else(|| mime::APPLICATION_OCTET_STREAM.as_ref());
108
109 let mut file = self.async_fs.open(path).await?;
110
111 let modified = date::mod_date_check(req, &mut file)?;
112
113 let mut res = Response::new(());
114
115 let mut size = file.len();
116
117 if let Some(range) = req
118 .headers()
119 .get(RANGE)
120 .and_then(|h| h.to_str().ok())
121 .and_then(|range| http_range_header::parse_range_header(range).ok())
122 .map(|range| range.validate(size))
123 {
124 let (start, end) = range
125 .map_err(|_| ServeError::RangeNotSatisfied(size))?
126 .pop()
127 .expect("http_range_header produced empty range")
128 .into_inner();
129
130 file.seek(SeekFrom::Start(start)).await?;
131
132 *res.status_mut() = StatusCode::PARTIAL_CONTENT;
133 let val = buf_write_header!(0, "bytes {start}-{end}/{size}");
134 res.headers_mut().insert(CONTENT_RANGE, val);
135
136 size = end - start + 1;
137 }
138
139 res.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static(ct));
140 res.headers_mut().insert(CONTENT_LENGTH, HeaderValue::from(size));
141 res.headers_mut()
142 .insert(ACCEPT_RANGES, HeaderValue::from_static("bytes"));
143
144 if let Some(modified) = modified {
145 let val = date::date_to_header(modified);
146 res.headers_mut().insert(LAST_MODIFIED, val);
147 }
148
149 let stream = if matches!(*req.method(), Method::HEAD) {
150 ChunkReader::empty()
151 } else {
152 ChunkReader::reader(file, size, self.chunk_size)
153 };
154
155 Ok(res.map(|_| stream))
156 }
157}
158
159impl<FS: AsyncFs> ServeDir<FS> {
160 fn path_check(&self, path: &str) -> Result<PathBuf, ServeError> {
161 let path = path.trim_start_matches('/').as_bytes();
162
163 let path_decoded = percent_encoding::percent_decode(path)
164 .decode_utf8()
165 .map_err(|_| ServeError::InvalidPath)?;
166 let path_decoded = Path::new(&*path_decoded);
167
168 let mut path = self.base_path.clone();
169
170 for component in path_decoded.components() {
171 match component {
172 Component::Normal(comp) => {
173 if Path::new(&comp)
174 .components()
175 .any(|c| !matches!(c, Component::Normal(_)))
176 {
177 return Err(ServeError::InvalidPath);
178 }
179 path.push(comp)
180 }
181 Component::CurDir => {}
182 Component::Prefix(_) | Component::RootDir | Component::ParentDir => {
183 return Err(ServeError::InvalidPath)
184 }
185 }
186 }
187
188 Ok(path)
189 }
190}
191
#[cfg(test)]
mod test {
    use core::future::poll_fn;

    use futures_core::stream::Stream;

    use super::*;

    /// Compile-time assertion that `F: Send`; the reference only drives type inference.
    fn assert_send<F: Send>(_: &F) {}

    /// Drains a response body stream into a `String`.
    ///
    /// Panics on an error chunk or on non-UTF-8 bytes, so a failed read shows up as a test
    /// failure with a clear message rather than as a silently truncated body.
    async fn collect_body<S, T, E>(stream: S) -> String
    where
        S: Stream<Item = Result<T, E>>,
        T: AsRef<[u8]>,
    {
        let mut stream = Box::pin(stream);
        let mut body = String::new();

        while let Some(chunk) = poll_fn(|cx| stream.as_mut().poll_next(cx)).await {
            let Ok(bytes) = chunk else {
                panic!("response body stream yielded an error");
            };
            body.push_str(std::str::from_utf8(bytes.as_ref()).expect("response body is not valid UTF-8"));
        }

        body
    }

    #[tokio::test]
    async fn tokio_fs_assert_send() {
        let dir = ServeDir::new("sample");
        let req = Request::builder().uri("/test.txt").body(()).unwrap();

        let fut = dir.serve(&req);

        assert_send(&fut);

        let res = fut.await.unwrap();

        assert_send(&res);
    }

    #[tokio::test]
    async fn method() {
        let dir = ServeDir::new("sample");
        let req = Request::builder()
            .method(Method::POST)
            .uri("/test.txt")
            .body(())
            .unwrap();

        let e = dir.serve(&req).await.err().unwrap();
        assert!(matches!(e, ServeError::MethodNotAllowed));
    }

    #[tokio::test]
    async fn head_method_body_check() {
        let dir = ServeDir::new("sample");
        let req = Request::builder()
            .method(Method::HEAD)
            .uri("/test.txt")
            .body(())
            .unwrap();

        let res = dir.serve(&req).await.unwrap();

        assert_eq!(
            res.headers().get(CONTENT_LENGTH).unwrap(),
            HeaderValue::from("hello, world!".len())
        );

        let mut stream = Box::pin(res.into_body());

        assert_eq!(stream.size_hint(), (usize::MAX, Some(0)));

        let body_chunk = poll_fn(|cx| stream.as_mut().poll_next(cx)).await;

        assert!(body_chunk.is_none())
    }

    #[tokio::test]
    async fn invalid_path() {
        let dir = ServeDir::new("sample");
        let req = Request::builder().uri("/../test.txt").body(()).unwrap();
        assert!(matches!(dir.serve(&req).await.err(), Some(ServeError::InvalidPath)));
    }

    #[tokio::test]
    async fn percent_encoded_traversal_is_rejected() {
        // `%2e%2e` decodes to `..`; traversal must be caught after percent-decoding.
        let dir = ServeDir::new("sample");
        let req = Request::builder().uri("/%2e%2e/test.txt").body(()).unwrap();
        assert!(matches!(dir.serve(&req).await.err(), Some(ServeError::InvalidPath)));
    }

    #[tokio::test]
    async fn directory_is_rejected() {
        // "/" resolves to the base directory itself, which must not be served.
        let dir = ServeDir::new("sample");
        let req = Request::builder().uri("/").body(()).unwrap();
        assert!(matches!(dir.serve(&req).await.err(), Some(ServeError::InvalidPath)));
    }

    #[tokio::test]
    async fn range_not_satisfiable() {
        let dir = ServeDir::new("sample");
        let req = Request::builder()
            .uri("/test.txt")
            .header("range", "bytes=20-30")
            .body(())
            .unwrap();
        // 13 == "hello, world!".len(): the error carries the full file length.
        assert!(matches!(
            dir.serve(&req).await.err(),
            Some(ServeError::RangeNotSatisfied(13))
        ));
    }

    #[tokio::test]
    async fn response_headers() {
        let dir = ServeDir::new("sample");
        let req = Request::builder().uri("/test.txt").body(()).unwrap();
        let res = dir.serve(&req).await.unwrap();
        assert_eq!(
            res.headers().get(CONTENT_TYPE).unwrap(),
            HeaderValue::from_static("text/plain")
        );
        assert_eq!(
            res.headers().get(ACCEPT_RANGES).unwrap(),
            HeaderValue::from_static("bytes")
        );
        assert_eq!(
            res.headers().get(CONTENT_LENGTH).unwrap(),
            HeaderValue::from("hello, world!".len())
        );
    }

    #[tokio::test]
    async fn body_size_hint() {
        let dir = ServeDir::new("sample");
        let req = Request::builder().uri("/test.txt").body(()).unwrap();
        let res = dir.serve(&req).await.unwrap();
        let (lower, Some(upper)) = res.body().size_hint() else {
            panic!("ChunkReadStream does not have a size")
        };
        assert_eq!(lower, upper);
        assert_eq!(lower, "hello, world!".len());
    }

    async fn _basic<FS: AsyncFs>(dir: ServeDir<FS>) {
        let req = Request::builder().uri("/test.txt").body(()).unwrap();

        let body = dir.serve(&req).await.unwrap().into_body();

        let (low, high) = body.size_hint();

        assert_eq!(low, high.unwrap());
        assert_eq!(low, "hello, world!".len());

        assert_eq!(collect_body(body).await, "hello, world!");
    }

    #[tokio::test]
    async fn basic() {
        _basic(ServeDir::new("sample")).await;
    }

    #[cfg(all(target_os = "linux", feature = "tokio-uring"))]
    #[test]
    fn basic_tokio_uring() {
        tokio_uring::start(_basic(ServeDir::new_tokio_uring("sample")));
    }

    async fn test_range<FS: AsyncFs>(dir: ServeDir<FS>) {
        let req = Request::builder()
            .uri("/test.txt")
            .header("range", "bytes=2-12")
            .body(())
            .unwrap();
        let res = dir.serve(&req).await.unwrap();
        assert_eq!(res.status(), StatusCode::PARTIAL_CONTENT);
        assert_eq!(
            res.headers().get(CONTENT_TYPE).unwrap(),
            HeaderValue::from_static("text/plain")
        );
        assert_eq!(
            res.headers().get(CONTENT_RANGE).unwrap(),
            HeaderValue::from_static("bytes 2-12/13")
        );
        assert_eq!(
            res.headers().get(CONTENT_LENGTH).unwrap(),
            HeaderValue::from("llo, world!".len())
        );

        assert_eq!("llo, world!", collect_body(res.into_body()).await);
    }

    #[tokio::test]
    async fn ranged() {
        test_range(ServeDir::new("sample")).await;
    }

    #[cfg(all(target_os = "linux", feature = "tokio-uring"))]
    #[test]
    fn ranged_tokio_uring() {
        tokio_uring::start(test_range(ServeDir::new_tokio_uring("sample")))
    }
}