axum/extract/
nested_path.rs

1use std::{
2    sync::Arc,
3    task::{Context, Poll},
4};
5
6use crate::extract::Request;
7use axum_core::extract::FromRequestParts;
8use http::request::Parts;
9use tower_layer::{layer_fn, Layer};
10use tower_service::Service;
11
12use super::rejection::NestedPathRejection;
13
/// Access the path at which the matched route is nested.
///
/// This is the prefix the router was nested under. When routers are nested
/// inside one another the prefixes are joined, so nesting `/v2` inside `/api`
/// yields `/api/v2`.
///
/// This can, for example, be used when doing redirects that should stay
/// inside the nested router.
///
/// Extraction fails with `NestedPathRejection` when no nested path was
/// recorded for the request.
///
/// # Example
///
/// ```
/// use axum::{
///     Router,
///     extract::NestedPath,
///     routing::get,
/// };
///
/// let api = Router::new().route(
///     "/users",
///     get(|path: NestedPath| async move {
///         // `path` will be "/api" because that's what this
///         // router is nested at when we build `app`
///         let path = path.as_str();
///     })
/// );
///
/// let app = Router::new().nest("/api", api);
/// # let _: Router = app;
/// ```
#[derive(Debug, Clone)]
pub struct NestedPath(Arc<str>);

impl NestedPath {
    /// Returns the nested path as a string slice, for example `"/api"`.
    ///
    /// This borrows from the shared `Arc<str>`, so it never allocates.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}
49
50impl<S> FromRequestParts<S> for NestedPath
51where
52    S: Send + Sync,
53{
54    type Rejection = NestedPathRejection;
55
56    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
57        match parts.extensions.get::<Self>() {
58            Some(nested_path) => Ok(nested_path.clone()),
59            None => Err(NestedPathRejection),
60        }
61    }
62}
63
64#[derive(Clone)]
65pub(crate) struct SetNestedPath<S> {
66    inner: S,
67    path: Arc<str>,
68}
69
70impl<S> SetNestedPath<S> {
71    pub(crate) fn layer(path: &str) -> impl Layer<S, Service = Self> + Clone {
72        let path = Arc::from(path);
73        layer_fn(move |inner| Self {
74            inner,
75            path: Arc::clone(&path),
76        })
77    }
78}
79
80impl<S, B> Service<Request<B>> for SetNestedPath<S>
81where
82    S: Service<Request<B>>,
83{
84    type Response = S::Response;
85    type Error = S::Error;
86    type Future = S::Future;
87
88    #[inline]
89    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
90        self.inner.poll_ready(cx)
91    }
92
93    fn call(&mut self, mut req: Request<B>) -> Self::Future {
94        if let Some(prev) = req.extensions_mut().get_mut::<NestedPath>() {
95            let new_path = if prev.as_str() == "/" {
96                Arc::clone(&self.path)
97            } else {
98                format!("{}{}", prev.as_str().trim_end_matches('/'), self.path).into()
99            };
100            prev.0 = new_path;
101        } else {
102            req.extensions_mut()
103                .insert(NestedPath(Arc::clone(&self.path)));
104        };
105
106        self.inner.call(req)
107    }
108}
109
#[cfg(test)]
mod tests {
    use axum_core::response::Response;
    use http::StatusCode;

    use crate::{
        extract::{NestedPath, Request},
        middleware::{from_fn, Next},
        routing::get,
        test_helpers::*,
        Router,
    };

    // A single `nest` exposes exactly the prefix it was nested under.
    #[crate::test]
    async fn one_level_of_nesting() {
        let users = Router::new().route(
            "/users",
            get(|path: NestedPath| async move {
                assert_eq!(path.as_str(), "/api");
            }),
        );

        let client = TestClient::new(Router::new().nest("/api", users));

        let response = client.get("/api/users").await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // A trailing slash on a single nest prefix is preserved verbatim.
    #[crate::test]
    async fn one_level_of_nesting_with_trailing_slash() {
        let users = Router::new().route(
            "/users",
            get(|path: NestedPath| async move {
                assert_eq!(path.as_str(), "/api/");
            }),
        );

        let client = TestClient::new(Router::new().nest("/api/", users));

        let response = client.get("/api/users").await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // Prefixes of nested routers are concatenated.
    #[crate::test]
    async fn two_levels_of_nesting() {
        let users = Router::new().route(
            "/users",
            get(|path: NestedPath| async move {
                assert_eq!(path.as_str(), "/api/v2");
            }),
        );

        let v2 = Router::new().nest("/v2", users);
        let client = TestClient::new(Router::new().nest("/api", v2));

        let response = client.get("/api/v2/users").await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // An outer trailing slash does not produce a double slash when joined.
    #[crate::test]
    async fn two_levels_of_nesting_with_trailing_slash() {
        let users = Router::new().route(
            "/users",
            get(|path: NestedPath| async move {
                assert_eq!(path.as_str(), "/api/v2");
            }),
        );

        let v2 = Router::new().nest("/v2", users);
        let client = TestClient::new(Router::new().nest("/api/", v2));

        let response = client.get("/api/v2/users").await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // Fallbacks of a nested router can extract the nested path too.
    #[crate::test]
    async fn in_fallbacks() {
        let nested = Router::new().fallback(get(|path: NestedPath| async move {
            assert_eq!(path.as_str(), "/api");
        }));

        let client = TestClient::new(Router::new().nest("/api", nested));

        let response = client.get("/api/doesnt-exist").await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    // Middleware layered on a nested router can extract the nested path.
    #[crate::test]
    async fn in_middleware() {
        async fn check_nested_path(path: NestedPath, req: Request, next: Next) -> Response {
            assert_eq!(path.as_str(), "/api");
            next.run(req).await
        }

        let nested = Router::new()
            .route("/users", get(|| async {}))
            .layer(from_fn(check_nested_path));

        let client = TestClient::new(Router::new().nest("/api", nested));

        let response = client.get("/api/users").await;
        assert_eq!(response.status(), StatusCode::OK);
    }
}