1pub mod runtime;
4
5mod buf;
6mod chunk;
7mod date;
8mod error;
9
10pub use self::{chunk::ChunkReader, error::ServeError};
11
12use std::{
13 io::SeekFrom,
14 path::{Component, Path, PathBuf},
15};
16
17use http::{
18 header::{HeaderValue, ACCEPT_RANGES, CONTENT_LENGTH, CONTENT_RANGE, CONTENT_TYPE, LAST_MODIFIED, RANGE},
19 Method, Request, Response, StatusCode,
20};
21use mime_guess::mime;
22
23use self::{
24 buf::buf_write_header,
25 runtime::{AsyncFs, ChunkRead, Meta},
26};
27
/// Serves files located beneath a base directory in response to HTTP requests.
///
/// `FS` is the asynchronous filesystem used to open files; with the `tokio` feature it
/// defaults to [`runtime::TokioFs`].
#[cfg(feature = "tokio")]
#[derive(Clone)]
pub struct ServeDir<FS = runtime::TokioFs>
where
    FS: AsyncFs,
{
    /// Directory that request paths are resolved against.
    base_path: PathBuf,
    /// Filesystem implementation used to open the requested files.
    async_fs: FS,
    /// Chunk size handed to the body reader when streaming a file.
    chunk_size: usize,
}
35
36#[cfg(not(feature = "tokio"))]
37#[derive(Clone)]
38pub struct ServeDir<FS: AsyncFs> {
39 chunk_size: usize,
40 base_path: PathBuf,
41 async_fs: FS,
42}
43
// Gated on the `tokio` feature, the same feature that makes `runtime::TokioFs` the default
// `FS` of `ServeDir`, rather than on the implicit `default` feature. This keeps the
// constructor available to users who build with `default-features = false` and enable
// `tokio` explicitly.
#[cfg(feature = "tokio")]
impl ServeDir<runtime::TokioFs> {
    /// Construct a [`ServeDir`] that serves files beneath `path` through
    /// [`runtime::TokioFs`].
    ///
    /// Equivalent to `ServeDir::with_fs(path, runtime::TokioFs)`.
    pub fn new(path: impl Into<PathBuf>) -> Self {
        Self::with_fs(path, runtime::TokioFs)
    }
}
51
#[cfg(feature = "tokio-uring")]
impl ServeDir<runtime::TokioUringFs> {
    /// Construct a [`ServeDir`] that serves files beneath `path` through
    /// [`runtime::TokioUringFs`].
    pub fn new_tokio_uring(path: impl Into<PathBuf>) -> Self {
        let fs = runtime::TokioUringFs;
        Self::with_fs(path, fs)
    }
}
59
60impl<FS: AsyncFs> ServeDir<FS> {
61 pub fn with_fs(path: impl Into<PathBuf>, async_fs: FS) -> Self {
64 Self {
65 chunk_size: 4096,
66 base_path: path.into(),
67 async_fs,
68 }
69 }
70
71 pub fn chunk_size(&mut self, size: usize) -> &mut Self {
75 self.chunk_size = size;
76 self
77 }
78
79 pub async fn serve<Ext>(&self, req: &Request<Ext>) -> Result<Response<ChunkReader<FS::File>>, ServeError> {
92 if !matches!(*req.method(), Method::HEAD | Method::GET) {
93 return Err(ServeError::MethodNotAllowed);
94 }
95
96 let path = self.path_check(req.uri().path())?;
97
98 if path.is_dir() {
100 return Err(ServeError::InvalidPath);
101 }
102
103 let ct = mime_guess::from_path(&path)
104 .first_raw()
105 .unwrap_or_else(|| mime::APPLICATION_OCTET_STREAM.as_ref());
106
107 let mut file = self.async_fs.open(path).await?;
108
109 let modified = date::mod_date_check(req, &mut file)?;
110
111 let mut res = Response::new(());
112
113 let mut size = file.len();
114
115 if let Some(range) = req
116 .headers()
117 .get(RANGE)
118 .and_then(|h| h.to_str().ok())
119 .and_then(|range| http_range_header::parse_range_header(range).ok())
120 .map(|range| range.validate(size))
121 {
122 let (start, end) = range
123 .map_err(|_| ServeError::RangeNotSatisfied(size))?
124 .pop()
125 .expect("http_range_header produced empty range")
126 .into_inner();
127
128 file.seek(SeekFrom::Start(start)).await?;
129
130 *res.status_mut() = StatusCode::PARTIAL_CONTENT;
131 let val = buf_write_header!(0, "bytes {start}-{end}/{size}");
132 res.headers_mut().insert(CONTENT_RANGE, val);
133
134 size = end - start + 1;
135 }
136
137 res.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static(ct));
138 res.headers_mut().insert(CONTENT_LENGTH, HeaderValue::from(size));
139 res.headers_mut()
140 .insert(ACCEPT_RANGES, HeaderValue::from_static("bytes"));
141
142 if let Some(modified) = modified {
143 let val = date::date_to_header(modified);
144 res.headers_mut().insert(LAST_MODIFIED, val);
145 }
146
147 let stream = if matches!(*req.method(), Method::HEAD) {
148 ChunkReader::empty()
149 } else {
150 ChunkReader::reader(file, size, self.chunk_size)
151 };
152
153 Ok(res.map(|_| stream))
154 }
155}
156
157impl<FS: AsyncFs> ServeDir<FS> {
158 fn path_check(&self, path: &str) -> Result<PathBuf, ServeError> {
159 let path = path.trim_start_matches('/').as_bytes();
160
161 let path_decoded = percent_encoding::percent_decode(path)
162 .decode_utf8()
163 .map_err(|_| ServeError::InvalidPath)?;
164 let path_decoded = Path::new(&*path_decoded);
165
166 let mut path = self.base_path.clone();
167
168 for component in path_decoded.components() {
169 match component {
170 Component::Normal(comp) => {
171 if Path::new(&comp)
172 .components()
173 .any(|c| !matches!(c, Component::Normal(_)))
174 {
175 return Err(ServeError::InvalidPath);
176 }
177 path.push(comp)
178 }
179 Component::CurDir => {}
180 Component::Prefix(_) | Component::RootDir | Component::ParentDir => {
181 return Err(ServeError::InvalidPath)
182 }
183 }
184 }
185
186 Ok(path)
187 }
188}
189
#[cfg(test)]
mod test {
    use core::future::poll_fn;

    use futures_core::stream::Stream;

    use super::*;

    fn assert_send<F: Send>(_: &F) {}

    /// Drain a body stream into a `String`.
    ///
    /// Panics on an `Err` chunk or non-UTF-8 data. A failing read is then reported at its
    /// source instead of silently ending the loop and surfacing later as a truncated-body
    /// mismatch.
    async fn collect_body<S, B, E>(stream: S) -> String
    where
        S: Stream<Item = Result<B, E>>,
        B: AsRef<[u8]>,
    {
        let mut stream = Box::pin(stream);
        let mut body = String::new();
        while let Some(chunk) = poll_fn(|cx| stream.as_mut().poll_next(cx)).await {
            let Ok(bytes) = chunk else {
                panic!("body stream yielded an error");
            };
            body.push_str(std::str::from_utf8(bytes.as_ref()).expect("body is not valid UTF-8"));
        }
        body
    }

    #[tokio::test]
    async fn tokio_fs_assert_send() {
        let dir = ServeDir::new("sample");
        let req = Request::builder().uri("/test.txt").body(()).unwrap();

        let fut = dir.serve(&req);

        assert_send(&fut);

        let res = fut.await.unwrap();

        assert_send(&res);
    }

    #[tokio::test]
    async fn method() {
        let dir = ServeDir::new("sample");
        let req = Request::builder()
            .method(Method::POST)
            .uri("/test.txt")
            .body(())
            .unwrap();

        let e = dir.serve(&req).await.err().unwrap();
        assert!(matches!(e, ServeError::MethodNotAllowed));
    }

    #[tokio::test]
    async fn head_method_body_check() {
        let dir = ServeDir::new("sample");
        let req = Request::builder()
            .method(Method::HEAD)
            .uri("/test.txt")
            .body(())
            .unwrap();

        let res = dir.serve(&req).await.unwrap();

        assert_eq!(
            res.headers().get(CONTENT_LENGTH).unwrap(),
            HeaderValue::from("hello, world!".len())
        );

        let mut stream = Box::pin(res.into_body());

        // NOTE(review): a lower bound above the upper bound is unusual for `size_hint`. This
        // pins what `ChunkReader::empty` currently reports (presumably a deliberate
        // "no body" marker); confirm it is intentional.
        assert_eq!(stream.size_hint(), (usize::MAX, Some(0)));

        let body_chunk = poll_fn(|cx| stream.as_mut().poll_next(cx)).await;

        assert!(body_chunk.is_none())
    }

    #[tokio::test]
    async fn invalid_path() {
        let dir = ServeDir::new("sample");
        let req = Request::builder().uri("/../test.txt").body(()).unwrap();
        assert!(matches!(dir.serve(&req).await.err(), Some(ServeError::InvalidPath)));
    }

    #[tokio::test]
    async fn invalid_path_percent_encoded() {
        let dir = ServeDir::new("sample");
        // Each of these decodes to a `..` or root component only after percent-decoding.
        for uri in ["/%2e%2e/test.txt", "/..%2Ftest.txt", "/%2Fetc/passwd"] {
            let req = Request::builder().uri(uri).body(()).unwrap();
            assert!(
                matches!(dir.serve(&req).await.err(), Some(ServeError::InvalidPath)),
                "{uri} was not rejected"
            );
        }
    }

    #[tokio::test]
    async fn response_headers() {
        let dir = ServeDir::new("sample");
        let req = Request::builder().uri("/test.txt").body(()).unwrap();
        let res = dir.serve(&req).await.unwrap();
        assert_eq!(
            res.headers().get(CONTENT_TYPE).unwrap(),
            HeaderValue::from_static("text/plain")
        );
        assert_eq!(
            res.headers().get(ACCEPT_RANGES).unwrap(),
            HeaderValue::from_static("bytes")
        );
        assert_eq!(
            res.headers().get(CONTENT_LENGTH).unwrap(),
            HeaderValue::from("hello, world!".len())
        );
    }

    #[tokio::test]
    async fn body_size_hint() {
        let dir = ServeDir::new("sample");
        let req = Request::builder().uri("/test.txt").body(()).unwrap();
        let res = dir.serve(&req).await.unwrap();
        let (lower, Some(upper)) = res.body().size_hint() else {
            panic!("ChunkReadStream does not have a size")
        };
        assert_eq!(lower, upper);
        assert_eq!(lower, "hello, world!".len());
    }

    async fn _basic<FS: AsyncFs>(dir: ServeDir<FS>) {
        let req = Request::builder().uri("/test.txt").body(()).unwrap();

        let body = dir.serve(&req).await.unwrap().into_body();

        let (low, high) = body.size_hint();

        assert_eq!(low, high.unwrap());
        assert_eq!(low, "hello, world!".len());

        assert_eq!(collect_body(body).await, "hello, world!");
    }

    #[tokio::test]
    async fn basic() {
        _basic(ServeDir::new("sample")).await;
    }

    #[cfg(all(target_os = "linux", feature = "tokio-uring"))]
    #[test]
    fn basic_tokio_uring() {
        tokio_uring::start(_basic(ServeDir::new_tokio_uring("sample")));
    }

    async fn test_range<FS: AsyncFs>(dir: ServeDir<FS>) {
        let req = Request::builder()
            .uri("/test.txt")
            .header("range", "bytes=2-12")
            .body(())
            .unwrap();
        let res = dir.serve(&req).await.unwrap();
        assert_eq!(res.status(), StatusCode::PARTIAL_CONTENT);
        assert_eq!(
            res.headers().get(CONTENT_TYPE).unwrap(),
            HeaderValue::from_static("text/plain")
        );
        assert_eq!(
            res.headers().get(CONTENT_RANGE).unwrap(),
            HeaderValue::from_static("bytes 2-12/13")
        );
        assert_eq!(
            res.headers().get(CONTENT_LENGTH).unwrap(),
            HeaderValue::from("llo, world!".len())
        );

        assert_eq!(collect_body(res.into_body()).await, "llo, world!");
    }

    #[tokio::test]
    async fn ranged() {
        test_range(ServeDir::new("sample")).await;
    }

    #[cfg(all(target_os = "linux", feature = "tokio-uring"))]
    #[test]
    fn ranged_tokio_uring() {
        tokio_uring::start(test_range(ServeDir::new_tokio_uring("sample")))
    }
}