1use std::{
4 collections::Bound,
5 io::{Seek, SeekFrom},
6 path::{Path, PathBuf},
7 str::FromStr,
8 time::SystemTime,
9};
10use tokio::io::AsyncReadExt;
11use tokio_util::io::ReaderStream;
12
13use viz_core::{
14 headers::{
15 AcceptRanges, ContentLength, ContentRange, ContentType, ETag, HeaderMap, HeaderMapExt,
16 IfMatch, IfModifiedSince, IfNoneMatch, IfUnmodifiedSince, LastModified, Range,
17 },
18 Handler, IntoResponse, Method, Request, RequestExt, Response, ResponseExt, Result, StatusCode,
19};
20
21mod directory;
22mod error;
23
24use directory::Directory;
25pub use error::Error;
26
/// A [`Handler`] that always responds with the contents of one file,
/// regardless of the route parameters of the incoming request.
#[derive(Clone, Debug)]
pub struct File {
    /// Location of the file on disk; checked for existence in [`File::new`].
    path: PathBuf,
}
32
33impl File {
34 #[must_use]
40 pub fn new(path: impl Into<PathBuf>) -> Self {
41 let path = path.into();
42
43 assert!(path.exists(), "{} not found", path.to_string_lossy());
44
45 Self { path }
46 }
47}
48
49#[viz_core::async_trait]
50impl Handler<Request> for File {
51 type Output = Result<Response>;
52
53 async fn call(&self, req: Request) -> Self::Output {
54 serve(&self.path, req.headers())
55 }
56}
57
/// A [`Handler`] that serves files below a root directory, using the first
/// route parameter as the path relative to that root.
#[derive(Clone, Debug)]
pub struct Dir {
    /// Root directory; checked for existence in [`Dir::new`].
    path: PathBuf,
    /// Whether a directory without `index.html` is rendered as a listing.
    listing: bool,
    /// Optional entry names forwarded to the listing renderer (`Directory::new`).
    unlisted: Option<Vec<&'static str>>,
}
65
66impl Dir {
67 #[must_use]
73 pub fn new(path: impl Into<PathBuf>) -> Self {
74 let path = path.into();
75
76 assert!(path.exists(), "{} not found", path.to_string_lossy());
77
78 Self {
79 path,
80 listing: false,
81 unlisted: None,
82 }
83 }
84
85 #[must_use]
87 pub const fn listing(mut self) -> Self {
88 self.listing = true;
89 self
90 }
91
92 #[must_use]
94 pub fn unlisted(mut self, unlisted: Vec<&'static str>) -> Self {
95 self.unlisted.replace(unlisted);
96 self
97 }
98}
99
100#[viz_core::async_trait]
101impl Handler<Request> for Dir {
102 type Output = Result<Response>;
103
104 async fn call(&self, req: Request) -> Self::Output {
105 if req.method() != Method::GET {
106 Err(Error::MethodNotAllowed)?;
107 }
108
109 let mut prev = false;
110 let mut path = self.path.clone();
111
112 if let Some(param) = req.route_info().params.first().map(|(_, v)| v) {
113 let p = percent_encoding::percent_decode_str(param)
114 .decode_utf8()
115 .map_err(|_| Error::InvalidPath)?;
116 sanitize_path(&mut path, &p)?;
117 prev = true;
118 }
119
120 if !path.exists() {
121 Err(StatusCode::NOT_FOUND.into_error())?;
122 }
123
124 if path.is_file() {
125 return serve(&path, req.headers());
126 }
127
128 let index = path.join("index.html");
129 if index.exists() {
130 return serve(&index, req.headers());
131 }
132
133 if self.listing {
134 return Directory::new(req.path(), prev, &path, self.unlisted.as_ref())
135 .ok_or_else(|| StatusCode::INTERNAL_SERVER_ERROR.into_error())
136 .map(IntoResponse::into_response);
137 }
138
139 Ok(StatusCode::NOT_FOUND.into_response())
140 }
141}
142
143fn sanitize_path<'a>(path: &'a mut PathBuf, p: &'a str) -> Result<()> {
144 for seg in p.split('/') {
145 if seg.starts_with("..") {
146 return Err(StatusCode::NOT_FOUND.into_error());
147 }
148 if seg.contains('\\') {
149 return Err(StatusCode::NOT_FOUND.into_error());
150 }
151 path.push(seg);
152 }
153 Ok(())
154}
155
156fn extract_etag(mtime: &SystemTime, size: u64) -> Option<ETag> {
157 ETag::from_str(&format!(
158 r#""{}-{}""#,
159 mtime
160 .duration_since(SystemTime::UNIX_EPOCH)
161 .ok()?
162 .as_millis(),
163 size
164 ))
165 .ok()
166}
167
168#[inline]
169fn serve(path: &Path, headers: &HeaderMap) -> Result<Response> {
170 let mut file = std::fs::File::open(path).map_err(Error::Io)?;
171 let metadata = file
172 .metadata()
173 .map_err(|_| StatusCode::NOT_FOUND.into_error())?;
174
175 let mut etag = None;
176 let mut last_modified = None;
177 let mut content_range = None;
178 let mut max = metadata.len();
179
180 if let Ok(modified) = metadata.modified() {
181 etag = extract_etag(&modified, max);
182
183 if matches!((headers.typed_get::<IfMatch>(), &etag), (Some(if_match), Some(etag)) if !if_match.precondition_passes(etag))
184 || matches!(headers.typed_get::<IfUnmodifiedSince>(), Some(if_unmodified_since) if !if_unmodified_since.precondition_passes(modified))
185 {
186 Err(Error::PreconditionFailed)?;
187 }
188
189 if matches!((headers.typed_get::<IfNoneMatch>(), &etag), (Some(if_no_match), Some(etag)) if !if_no_match.precondition_passes(etag))
190 || matches!(headers.typed_get::<IfModifiedSince>(), Some(if_modified_since) if !if_modified_since.is_modified(modified))
191 {
192 return Ok(StatusCode::NOT_MODIFIED.into_response());
193 }
194
195 last_modified.replace(LastModified::from(modified));
196 }
197
198 if let Some((start, end)) = headers
200 .typed_get::<Range>()
201 .and_then(|range| range.satisfiable_ranges(100).next())
202 {
203 let start = match start {
204 Bound::Included(n) => n,
205 Bound::Excluded(n) => n + 1,
206 Bound::Unbounded => 0,
207 };
208 let end = match end {
209 Bound::Included(n) => n + 1,
210 Bound::Excluded(n) => n,
211 Bound::Unbounded => max,
212 };
213
214 if end < start || end > max {
215 Err(Error::RangeUnsatisfied(max))?;
216 }
217
218 if start != 0 || end != max {
219 if let Ok(range) = ContentRange::bytes(start..end, max) {
220 max = end - start;
221 content_range.replace(range);
222 file.seek(SeekFrom::Start(start)).map_err(Error::Io)?;
223 }
224 }
225 }
226
227 let mut res = if content_range.is_some() {
228 Response::stream(ReaderStream::new(tokio::fs::File::from_std(file).take(max)))
230 } else {
231 Response::stream(ReaderStream::new(tokio::fs::File::from_std(file)))
232 };
233
234 let headers = res.headers_mut();
235
236 headers.typed_insert(AcceptRanges::bytes());
237 headers.typed_insert(ContentLength(max));
238 headers.typed_insert(ContentType::from(
239 mime_guess::from_path(path).first_or_octet_stream(),
240 ));
241
242 if let Some(etag) = etag {
243 headers.typed_insert(etag);
244 }
245
246 if let Some(last_modified) = last_modified {
247 headers.typed_insert(last_modified);
248 }
249
250 if let Some(content_range) = content_range {
251 headers.typed_insert(content_range);
252 *res.status_mut() = StatusCode::PARTIAL_CONTENT;
253 };
254
255 Ok(res)
256}
257
#[cfg(test)]
mod tests {
    use super::{Dir, File};
    use std::sync::Arc;
    use viz_core::{
        types::{Params, RouteInfo},
        Handler, IntoResponse, Request, Result, StatusCode,
    };

    /// Builds a request for `uri` whose route info matches the `/*` pattern
    /// with the wildcard parameter `*1` bound to `value`.
    fn wildcard_request(value: &'static str, uri: &'static str) -> Request {
        let mut req: Request = Request::default();
        req.extensions_mut().insert(Arc::new(RouteInfo {
            id: 2,
            pattern: "/*".to_string(),
            params: Into::<Params>::into(vec![("*1", value)]),
        }));
        *req.uri_mut() = uri.parse().unwrap();
        req
    }

    #[tokio::test]
    async fn file() -> Result<()> {
        let serve = File::new("src/serve.rs");

        // `File` ignores route parameters, so both requests get the same file.
        for &(param, uri) in &[("serve.rs", "/serve.rs"), ("serve", "/serve")] {
            let result = serve.call(wildcard_request(param, uri)).await;
            assert_eq!(result.unwrap().status(), StatusCode::OK);
        }

        Ok(())
    }

    #[tokio::test]
    async fn dir() -> Result<()> {
        let serve = Dir::new("src/serve");

        let found = serve.call(wildcard_request("list.tpl", "/list.tpl")).await;
        assert_eq!(found.unwrap().status(), StatusCode::OK);

        let missing = serve
            .call(wildcard_request("list", "/list"))
            .await
            .map_err(IntoResponse::into_response);
        assert_eq!(missing.unwrap_err().status(), StatusCode::NOT_FOUND);

        Ok(())
    }
}