rocket_include_dir/
lib.rs1use std::{ops::Deref, path::PathBuf};
9
10use include_dir::File;
11use rocket::{
12 fs::Options,
13 http::{
14 ext::IntoOwned,
15 uri::{fmt::Path, Segments},
16 ContentType, Method, Status,
17 },
18 outcome::IntoOutcome,
19 response::{self, Redirect, Responder},
20 route::{Handler, Outcome},
21 Data, Request, Route,
22};
23
24pub use include_dir::include_dir;
25pub use include_dir::Dir;
26
27#[derive(Clone, Copy)]
45pub struct StaticFiles {
46 dir: &'static Dir<'static>,
47 options: Options,
48 rank: isize,
49}
50
51impl From<&'static Dir<'static>> for StaticFiles {
52 fn from(dir: &'static Dir<'static>) -> Self {
53 Self {
54 dir,
55 options: Options::default(),
56 rank: Self::DEFAULT_RANK,
57 }
58 }
59}
60
61impl StaticFiles {
62 const DEFAULT_RANK: isize = 10;
63
64 pub fn new(dir: &'static Dir<'static>, options: Options) -> Self {
69 Self {
70 dir,
71 options,
72 rank: Self::DEFAULT_RANK,
73 }
74 }
75
76 pub fn options(mut self, options: Options) -> Self {
78 self.options = options;
79 self
80 }
81
82 pub fn rank(mut self, rank: isize) -> Self {
84 self.rank = rank;
85 self
86 }
87}
88
89fn respond_with<'r>(
90 req: &'r Request<'_>,
91 path: PathBuf,
92 file: &'r File<'r>,
93) -> response::Result<'r> {
94 let mut response = file.contents().respond_to(req)?;
95 if let Some(ext) = path.extension() {
96 if let Some(ct) = ContentType::from_extension(&ext.to_string_lossy()) {
97 response.set_header(ct);
98 }
99 }
100
101 Ok(response)
102}
103
104#[rocket::async_trait]
105impl Handler for StaticFiles {
106 async fn handle<'r>(&self, req: &'r Request<'_>, data: Data<'r>) -> Outcome<'r> {
107 let options = self.options;
109 let allow_dotfiles = options.contains(Options::DotFiles);
111 let path = req
112 .segments::<Segments<'_, Path>>(0..)
113 .ok()
114 .and_then(|segments| segments.to_path_buf(allow_dotfiles).ok());
115
116 match path {
117 Some(p) => {
118 if let Some(path) = self.dir.get_dir(&p) {
119 if options.contains(Options::NormalizeDirs) && !req.uri().path().ends_with('/')
120 {
121 let normal = req
122 .uri()
123 .map_path(|p| format!("{}/", p))
124 .expect("adding a trailing slash to a known good path => valid path")
125 .into_owned();
126
127 return Redirect::permanent(normal)
128 .respond_to(req)
129 .or_forward((data, Status::InternalServerError));
130 }
131 if !options.contains(Options::Index) {
132 return Outcome::forward(data, Status::NotFound);
133 }
134 path.get_entry("index.html")
135 .and_then(|f| f.as_file())
136 .ok_or(Status::NotFound)
137 .and_then(|path| respond_with(req, p.join("index.html"), path))
138 .or_forward((data, Status::NotFound))
139 } else if let Some(path) = self.dir.get_file(&p) {
140 respond_with(req, p, path).or_forward((data, Status::NotFound))
141 } else {
142 Outcome::forward(data, Status::NotFound)
143 }
144 }
145 None => Outcome::forward(data, Status::NotFound),
146 }
147 }
148}
149
150impl Into<Route> for StaticFiles {
151 fn into(self) -> Route {
152 Route::ranked(self.rank, Method::Get, "/<path..>", self)
153 }
154}
155
156impl Into<Vec<Route>> for StaticFiles {
157 fn into(self) -> Vec<Route> {
158 vec![self.into()]
159 }
160}
161
#[cfg(test)]
mod tests {
    use include_dir::include_dir;
    use rocket::{build, local::blocking::Client, Build, Rocket};

    use super::*;

    /// Builds a Rocket instance with the embedded `static` directory mounted at `/`.
    fn launch() -> Rocket<Build> {
        static PROJECT_DIR: Dir = include_dir!("static");
        let server = StaticFiles::new(&PROJECT_DIR, Options::default());
        build().mount("/", server)
    }

    #[test]
    fn it_works() {
        // NOTE(review): presumably the working directory is changed so the test
        // cannot pass by reading files from disk; confirm the intent.
        std::env::set_current_dir("/tmp").expect("Requires /tmp directory");

        let client = Client::tracked(launch()).expect("valid rocket instance");

        // A path with no embedded file is reported as not found.
        let missing = client.get("/test-doesnt-exist").dispatch();
        assert_eq!(missing.status(), Status::NotFound);

        // An embedded file is served successfully.
        let found = client.get("/test.txt").dispatch();
        assert_eq!(found.status(), Status::Ok);
    }
}