axum_safe_path/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![doc = include_str!("../README.md")]
3#![allow(forbidden_lint_groups)]
4
5use std::{
6    error::Error,
7    fmt,
8    path::{self, Component, PathBuf},
9    str::FromStr,
10};
11
12use axum::{
13    extract::{FromRequestParts, Path, rejection::PathRejection},
14    http::{StatusCode, request::Parts},
15    response::{IntoResponse, Response},
16};
17
/// Text reported for a detected traversal attempt.
///
/// Used as the `Display` output of [`SafePathRejection::TraversalAttack`], as the
/// body of its `400 Bad Request` response, and as the serde deserialization
/// error message.
const REJECTION_MESSAGE: &str = "Invalid path: possible traversal attack detected";
19
/// A traversal-safe path extractor for Axum.
///
/// This extractor wraps [`axum::extract::Path`] and rejects any path that
/// contains a parent-directory component (`..`), a root directory (e.g. a
/// leading `/`), or a path prefix. Path prefixes such as the drive letter `C:`
/// are only recognised as such on Windows; on other platforms `C:` is an
/// ordinary file-name component. Accepted paths consist solely of ordinary and
/// `.` components, which prevents directory traversal attacks.
///
/// The same check is applied when parsing with [`FromStr`] and, with the
/// `serde` feature enabled, when deserializing.
///
/// A rejected request yields [`SafePathRejection::TraversalAttack`], which
/// responds with `400 Bad Request`.
// `PartialEq`/`Eq`/`Hash` let callers compare validated paths and use them as
// map/set keys; `PathBuf` already implements all three.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct SafePath(pub PathBuf);
28
29impl AsRef<path::Path> for SafePath {
30    fn as_ref(&self) -> &path::Path {
31        self.0.as_ref()
32    }
33}
34
35impl FromStr for SafePath {
36    type Err = SafePathRejection;
37
38    fn from_str(s: &str) -> Result<Self, Self::Err> {
39        if is_traversal_attack(s) {
40            Err(SafePathRejection::TraversalAttack)
41        } else {
42            Ok(Self(PathBuf::from(s)))
43        }
44    }
45}
46
47/// Rejection type for [`SafePath`].
48#[derive(Debug)]
49pub enum SafePathRejection {
50    /// Possible traversal attack detected
51    TraversalAttack,
52    /// The underlying [`Path`] extractor failed
53    PathExtraction(PathRejection),
54}
55
56impl fmt::Display for SafePathRejection {
57    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58        match self {
59            Self::TraversalAttack => f.write_str(REJECTION_MESSAGE),
60            Self::PathExtraction(err) => write!(f, "{err}"),
61        }
62    }
63}
64
65impl Error for SafePathRejection {
66    fn source(&self) -> Option<&(dyn Error + 'static)> {
67        match self {
68            Self::TraversalAttack => None,
69            Self::PathExtraction(err) => Some(err),
70        }
71    }
72}
73
74impl IntoResponse for SafePathRejection {
75    fn into_response(self) -> Response {
76        match self {
77            Self::TraversalAttack => (StatusCode::BAD_REQUEST, REJECTION_MESSAGE).into_response(),
78            Self::PathExtraction(inner) => inner.into_response(),
79        }
80    }
81}
82
/// Returns `true` if `path` contains a parent-directory (`..`), root-directory,
/// or prefix component.
///
/// Only `.` and ordinary named components are accepted, so an empty path is
/// considered safe. `Component::Prefix` (e.g. `C:`) is only produced by the
/// standard library on Windows.
fn is_traversal_attack(path: impl AsRef<path::Path>) -> bool {
    // Exhaustive on purpose: a new `Component` variant should force a decision here.
    path.as_ref().components().any(|component| match component {
        Component::Prefix(_) | Component::RootDir | Component::ParentDir => true,
        Component::CurDir | Component::Normal(_) => false,
    })
}
93
94impl<S> FromRequestParts<S> for SafePath
95where
96    S: Send + Sync,
97{
98    type Rejection = SafePathRejection;
99
100    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
101        let Path(path) = Path::from_request_parts(parts, state)
102            .await
103            .map_err(SafePathRejection::PathExtraction)?;
104
105        (!is_traversal_attack(&path))
106            .then_some(Self(path))
107            .ok_or(SafePathRejection::TraversalAttack)
108    }
109}
110
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for SafePath {
    /// Deserializes a [`PathBuf`] and rejects it with a custom serde error if
    /// it contains a `..`, root-directory, or prefix component.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::de::Deserializer<'de>,
    {
        let raw = PathBuf::deserialize(deserializer)?;

        if is_traversal_attack(&raw) {
            return Err(serde::de::Error::custom(REJECTION_MESSAGE));
        }
        Ok(Self(raw))
    }
}
126
#[cfg(test)]
mod validation_tests {
    use super::*;

    #[test]
    fn valid_paths() {
        for candidate in [
            "",
            ".",
            "./foo/bar.txt",
            "a/b/c/d",
            "foo.txt",
            "foo/./bar.txt",
            "foo/bar.txt",
        ] {
            assert!(!is_traversal_attack(candidate), "wrongly rejected {candidate:?}");
        }
    }

    #[test]
    fn invalid_parent_dir() {
        for candidate in ["..", "../foo.txt", "foo/../bar.txt", "foo/bar/.."] {
            assert!(is_traversal_attack(candidate), "failed to reject {candidate:?}");
        }
    }

    #[test]
    fn invalid_absolute_paths() {
        for candidate in ["/etc/passwd", "/foo/bar.txt"] {
            assert!(is_traversal_attack(candidate), "failed to reject {candidate:?}");
        }
    }

    #[test]
    #[cfg(windows)]
    fn invalid_windows_paths() {
        for candidate in ["C:\\Users\\Admin", "\\Windows"] {
            assert!(is_traversal_attack(candidate), "failed to reject {candidate:?}");
        }
    }

    // `FromStr` is a public entry point in its own right, so cover it directly.
    #[test]
    fn from_str_accepts_relative_paths() {
        let parsed = SafePath::from_str("foo/bar.txt").ok().map(|safe| safe.0);
        assert_eq!(parsed, Some(PathBuf::from("foo/bar.txt")));
    }

    #[test]
    fn from_str_rejects_traversal() {
        for candidate in ["..", "../secret.txt", "foo/../../bar", "/etc/passwd"] {
            assert!(
                matches!(
                    SafePath::from_str(candidate),
                    Err(SafePathRejection::TraversalAttack)
                ),
                "failed to reject {candidate:?}"
            );
        }
    }

    #[test]
    fn traversal_rejection_displays_message() {
        assert_eq!(
            SafePathRejection::TraversalAttack.to_string(),
            REJECTION_MESSAGE
        );
    }
}
163
#[cfg(all(test, feature = "serde"))]
// Test-only: an `unwrap` failure here is a test failure, which is what we want.
#[allow(clippy::unwrap_used)]
mod serde_tests {
    use super::*;

    #[test]
    fn roundtrip() {
        let original = SafePath(PathBuf::from("foo/bar.txt"));

        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, r#""foo/bar.txt""#);

        let restored = serde_json::from_str::<SafePath>(&json).unwrap();
        assert_eq!(restored.0, original.0);
    }

    #[test]
    fn invalid_json() {
        let outcome = serde_json::from_str::<SafePath>(r#""../secret.txt""#);
        assert!(outcome.is_err());
    }
}
186
#[cfg(test)]
// Test-only: an `unwrap` failure here is a test failure, which is what we want.
#[allow(clippy::unwrap_used)]
mod path_integration_tests {
    use axum::{Router, routing::get};
    use axum_test::TestServer;

    use super::*;

    /// Echoes the extracted path back so tests can inspect it.
    async fn handler(SafePath(path): SafePath) -> String {
        format!("Path: {}", path.display())
    }

    /// Builds a test server routing the `/path/{*path}` wildcard to `handler`.
    fn build_server() -> TestServer {
        let app = Router::new().route("/path/{*path}", get(handler));
        TestServer::new(app).unwrap()
    }

    #[tokio::test]
    async fn successful_path() {
        let server = build_server();
        let res = server.get("/path/foo/bar.txt").await;

        assert_eq!(res.status_code(), StatusCode::OK);
        assert_eq!(res.text(), "Path: foo/bar.txt");
    }

    #[tokio::test]
    async fn rejected_path() {
        let server = build_server();
        // The doubled slash makes the captured wildcard value an absolute path.
        let res = server.get("/path//etc/passwd").await;

        assert_eq!(res.status_code(), StatusCode::BAD_REQUEST);
        assert_eq!(res.text(), REJECTION_MESSAGE);
    }
}
219
#[cfg(all(test, feature = "json"))]
// Test-only: `unwrap` failures are test failures; the lint-group allowance
// mirrors the crate-level attribute.
#[allow(clippy::unwrap_used, forbidden_lint_groups)]
mod json_integration_tests {
    use axum::{Json, Router, routing::post};
    use axum_test::TestServer;
    use serde_json::json;

    use super::*;

    /// JSON request body; `path` is validated while deserializing.
    #[derive(serde::Deserialize)]
    struct Payload {
        path: SafePath,
    }

    /// Echoes the deserialized path back to the client.
    async fn json_handler(Json(payload): Json<Payload>) -> String {
        format!("Path: {}", payload.path.0.display())
    }

    /// Builds a test server with `json_handler` mounted at `/`.
    fn build_server() -> TestServer {
        let app = Router::new().route("/", post(json_handler));
        TestServer::new(app).unwrap()
    }

    #[tokio::test]
    async fn successful_json_path() {
        let server = build_server();
        let body = json!({ "path": "foo/bar.txt" });
        let res = server.post("/").json(&body).await;

        assert_eq!(res.status_code(), StatusCode::OK);
        assert_eq!(res.text(), "Path: foo/bar.txt");
    }

    #[tokio::test]
    async fn rejected_json_path() {
        let server = build_server();
        let body = json!({ "path": "../secret.txt" });
        let res = server.post("/").json(&body).await;

        assert_eq!(res.status_code(), StatusCode::UNPROCESSABLE_ENTITY);
        assert!(res.text().contains(REJECTION_MESSAGE));
    }
}
266
#[cfg(all(test, feature = "form"))]
// Test-only: `unwrap` failures are test failures; the lint-group allowance
// mirrors the crate-level attribute.
#[allow(clippy::unwrap_used, forbidden_lint_groups)]
mod form_integration_tests {
    use axum::{Form, Router, routing::post};
    use axum_test::TestServer;

    use super::*;

    /// Form body received by the handler; `path` is validated while deserializing.
    #[derive(serde::Deserialize)]
    struct Payload {
        path: SafePath,
    }

    /// Form body sent by the tests, carrying a raw, unvalidated path.
    #[derive(serde::Serialize)]
    struct TestPayload<'a> {
        path: &'a str,
    }

    /// Echoes the deserialized path back to the client.
    async fn form_handler(Form(payload): Form<Payload>) -> String {
        format!("Path: {}", payload.path.0.display())
    }

    /// Builds a test server with `form_handler` mounted at `/`.
    fn build_server() -> TestServer {
        let app = Router::new().route("/", post(form_handler));
        TestServer::new(app).unwrap()
    }

    #[tokio::test]
    async fn successful_form_path() {
        let server = build_server();
        let payload = TestPayload {
            path: "foo/bar.txt",
        };
        let res = server.post("/").form(&payload).await;

        assert_eq!(res.status_code(), StatusCode::OK);
        assert_eq!(res.text(), "Path: foo/bar.txt");
    }

    #[tokio::test]
    async fn rejected_form_path() {
        let server = build_server();
        let payload = TestPayload {
            path: "../secret.txt",
        };
        let res = server.post("/").form(&payload).await;

        assert_eq!(res.status_code(), StatusCode::UNPROCESSABLE_ENTITY);
        assert!(res.text().contains(REJECTION_MESSAGE));
    }
}
320}