1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![doc = include_str!("../README.md")]
3#![allow(forbidden_lint_groups)]
4
5use std::{
6 error::Error,
7 fmt,
8 path::{self, Component, PathBuf},
9 str::FromStr,
10};
11
12use axum::{
13 extract::{FromRequestParts, Path, rejection::PathRejection},
14 http::{StatusCode, request::Parts},
15 response::{IntoResponse, Response},
16};
17
/// Human-readable message reported whenever a path is rejected as a possible
/// traversal attack (used for `Display`, the HTTP response body, and the
/// serde deserialization error).
const REJECTION_MESSAGE: &str = "Invalid path: possible traversal attack detected";
19
/// A filesystem path wrapper intended to hold only paths that stay inside
/// their base directory.
///
/// The `FromStr`, `FromRequestParts` and (with the `serde` feature)
/// `Deserialize` implementations in this crate refuse any path that contains
/// a `..` component, a root directory, or a path prefix.
///
/// Note that the inner field is public: building `SafePath(path)` directly
/// performs no validation.
///
/// Equality, ordering-free comparison and hashing are delegated to the
/// wrapped [`PathBuf`], so a `SafePath` can be compared directly or stored in
/// hash-based collections.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct SafePath(pub PathBuf);
28
29impl AsRef<path::Path> for SafePath {
30 fn as_ref(&self) -> &path::Path {
31 self.0.as_ref()
32 }
33}
34
35impl FromStr for SafePath {
36 type Err = SafePathRejection;
37
38 fn from_str(s: &str) -> Result<Self, Self::Err> {
39 if is_traversal_attack(s) {
40 Err(SafePathRejection::TraversalAttack)
41 } else {
42 Ok(Self(PathBuf::from(s)))
43 }
44 }
45}
46
47#[derive(Debug)]
49pub enum SafePathRejection {
50 TraversalAttack,
52 PathExtraction(PathRejection),
54}
55
56impl fmt::Display for SafePathRejection {
57 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58 match self {
59 Self::TraversalAttack => f.write_str(REJECTION_MESSAGE),
60 Self::PathExtraction(err) => write!(f, "{err}"),
61 }
62 }
63}
64
65impl Error for SafePathRejection {
66 fn source(&self) -> Option<&(dyn Error + 'static)> {
67 match self {
68 Self::TraversalAttack => None,
69 Self::PathExtraction(err) => Some(err),
70 }
71 }
72}
73
74impl IntoResponse for SafePathRejection {
75 fn into_response(self) -> Response {
76 match self {
77 Self::TraversalAttack => (StatusCode::BAD_REQUEST, REJECTION_MESSAGE).into_response(),
78 Self::PathExtraction(inner) => inner.into_response(),
79 }
80 }
81}
82
/// Reports whether `path` could escape the directory it is joined onto.
///
/// Returns `true` as soon as a component is a parent-directory reference
/// (`..`), a root directory, or a path prefix; returns `false` when every
/// component is either `.` or a normal name (including for an empty path).
fn is_traversal_attack(path: impl AsRef<path::Path>) -> bool {
    for part in path.as_ref().components() {
        match part {
            Component::Prefix(_) | Component::RootDir | Component::ParentDir => return true,
            Component::CurDir | Component::Normal(_) => {}
        }
    }

    false
}
93
94impl<S> FromRequestParts<S> for SafePath
95where
96 S: Send + Sync,
97{
98 type Rejection = SafePathRejection;
99
100 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
101 let Path(path) = Path::from_request_parts(parts, state)
102 .await
103 .map_err(SafePathRejection::PathExtraction)?;
104
105 (!is_traversal_attack(&path))
106 .then_some(Self(path))
107 .ok_or(SafePathRejection::TraversalAttack)
108 }
109}
110
/// Deserializes a [`SafePath`] from any serde format, validating the path so
/// that untrusted payloads cannot smuggle in traversal components.
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for SafePath {
    /// Deserializes a `PathBuf` and fails with a custom error carrying the
    /// rejection message when the path is flagged by `is_traversal_attack`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::de::Deserializer<'de>,
    {
        let candidate = <PathBuf as serde::Deserialize>::deserialize(deserializer)?;

        if is_traversal_attack(&candidate) {
            return Err(<D::Error as serde::de::Error>::custom(REJECTION_MESSAGE));
        }

        Ok(Self(candidate))
    }
}
126
/// Unit tests for the component-level validation in `is_traversal_attack`.
#[cfg(test)]
mod validation_tests {
    use super::*;

    /// `true` when the validator lets `candidate` through.
    fn accepted(candidate: &str) -> bool {
        !is_traversal_attack(candidate)
    }

    /// `true` when the validator refuses `candidate`.
    fn rejected(candidate: &str) -> bool {
        is_traversal_attack(candidate)
    }

    #[test]
    fn valid_paths() {
        assert!(accepted(""));
        assert!(accepted("."));
        assert!(accepted("./foo/bar.txt"));
        assert!(accepted("a/b/c/d"));
        assert!(accepted("foo.txt"));
        assert!(accepted("foo/./bar.txt"));
        assert!(accepted("foo/bar.txt"));
    }

    #[test]
    fn invalid_parent_dir() {
        assert!(rejected(".."));
        assert!(rejected("../foo.txt"));
        assert!(rejected("foo/../bar.txt"));
        assert!(rejected("foo/bar/.."));
    }

    #[test]
    fn invalid_absolute_paths() {
        assert!(rejected("/etc/passwd"));
        assert!(rejected("/foo/bar.txt"));
    }

    // Drive letters and backslash roots are only parsed as such on Windows.
    #[test]
    #[cfg(windows)]
    fn invalid_windows_paths() {
        assert!(rejected("C:\\Users\\Admin"));
        assert!(rejected("\\Windows"));
    }
}
163
/// Tests for the serde `Serialize` / `Deserialize` support of `SafePath`.
#[cfg(all(test, feature = "serde"))]
#[allow(clippy::unwrap_used)]
mod serde_tests {
    use super::*;

    /// A valid path serializes to a plain JSON string and deserializes back
    /// to the same path.
    #[test]
    fn roundtrip() {
        let original = SafePath(PathBuf::from("foo/bar.txt"));

        let encoded = serde_json::to_string(&original).unwrap();
        assert_eq!(encoded, r#""foo/bar.txt""#);

        let decoded = serde_json::from_str::<SafePath>(&encoded).unwrap();
        assert_eq!(decoded.0, original.0);
    }

    /// A path containing `..` must be refused during deserialization.
    #[test]
    fn invalid_json() {
        let outcome = serde_json::from_str::<SafePath>(r#""../secret.txt""#);
        assert!(outcome.is_err());
    }
}
186
/// End-to-end tests for `SafePath` used as a path extractor in a router.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod path_integration_tests {
    use axum::{Router, routing::get};
    use axum_test::TestServer;

    use super::*;

    /// Echoes the extracted path back to the client.
    async fn handler(safe: SafePath) -> String {
        let SafePath(inner) = safe;
        format!("Path: {}", inner.display())
    }

    /// Builds a test server exposing `handler` behind a wildcard route.
    fn test_server() -> TestServer {
        let router = Router::new().route("/path/{*path}", get(handler));
        TestServer::new(router).unwrap()
    }

    #[tokio::test]
    async fn successful_path() {
        let server = test_server();

        let response = server.get("/path/foo/bar.txt").await;

        assert_eq!(response.status_code(), StatusCode::OK);
        assert_eq!(response.text(), "Path: foo/bar.txt");
    }

    #[tokio::test]
    async fn rejected_path() {
        let server = test_server();

        let response = server.get("/path//etc/passwd").await;

        assert_eq!(response.status_code(), StatusCode::BAD_REQUEST);
        assert_eq!(response.text(), REJECTION_MESSAGE);
    }
}
219
/// End-to-end tests for `SafePath` embedded in a JSON request body.
#[cfg(all(test, feature = "json"))]
#[allow(clippy::unwrap_used, forbidden_lint_groups)]
mod json_integration_tests {
    use axum::{Json, Router, routing::post};
    use axum_test::TestServer;
    use serde_json::json;

    use super::*;

    /// Request body carrying the path under test.
    #[derive(serde::Deserialize)]
    struct Payload {
        path: SafePath,
    }

    /// Echoes the deserialized path back to the client.
    async fn json_handler(Json(payload): Json<Payload>) -> String {
        let Payload {
            path: SafePath(inner),
        } = payload;
        format!("Path: {}", inner.display())
    }

    /// Builds a test server exposing `json_handler` on `POST /`.
    fn test_server() -> TestServer {
        let router = Router::new().route("/", post(json_handler));
        TestServer::new(router).unwrap()
    }

    #[tokio::test]
    async fn successful_json_path() {
        let server = test_server();
        let body = json!({ "path": "foo/bar.txt" });

        let response = server.post("/").json(&body).await;

        assert_eq!(response.status_code(), StatusCode::OK);
        assert_eq!(response.text(), "Path: foo/bar.txt");
    }

    #[tokio::test]
    async fn rejected_json_path() {
        let server = test_server();
        let body = json!({ "path": "../secret.txt" });

        let response = server.post("/").json(&body).await;

        assert_eq!(response.status_code(), StatusCode::UNPROCESSABLE_ENTITY);
        assert!(response.text().contains(REJECTION_MESSAGE));
    }
}
266
/// End-to-end tests for `SafePath` embedded in a URL-encoded form body.
#[cfg(all(test, feature = "form"))]
#[allow(clippy::unwrap_used, forbidden_lint_groups)]
mod form_integration_tests {
    use axum::{Form, Router, routing::post};
    use axum_test::TestServer;

    use super::*;

    /// Server-side view of the form: the path is validated on deserialization.
    #[derive(serde::Deserialize)]
    struct Payload {
        path: SafePath,
    }

    /// Client-side view of the form: a raw string, so invalid paths can be sent.
    #[derive(serde::Serialize)]
    struct TestPayload<'a> {
        path: &'a str,
    }

    /// Echoes the deserialized path back to the client.
    async fn form_handler(Form(payload): Form<Payload>) -> String {
        let Payload {
            path: SafePath(inner),
        } = payload;
        format!("Path: {}", inner.display())
    }

    /// Builds a test server exposing `form_handler` on `POST /`.
    fn test_server() -> TestServer {
        let router = Router::new().route("/", post(form_handler));
        TestServer::new(router).unwrap()
    }

    #[tokio::test]
    async fn successful_form_path() {
        let server = test_server();
        let body = TestPayload {
            path: "foo/bar.txt",
        };

        let response = server.post("/").form(&body).await;

        assert_eq!(response.status_code(), StatusCode::OK);
        assert_eq!(response.text(), "Path: foo/bar.txt");
    }

    #[tokio::test]
    async fn rejected_form_path() {
        let server = test_server();
        let body = TestPayload {
            path: "../secret.txt",
        };

        let response = server.post("/").form(&body).await;

        assert_eq!(response.status_code(), StatusCode::UNPROCESSABLE_ENTITY);
        assert!(response.text().contains(REJECTION_MESSAGE));
    }
}