axum/extract/
nested_path.rs1use std::{
2 sync::Arc,
3 task::{Context, Poll},
4};
5
6use crate::extract::Request;
7use axum_core::extract::FromRequestParts;
8use http::request::Parts;
9use tower_layer::{layer_fn, Layer};
10use tower_service::Service;
11
12use super::rejection::NestedPathRejection;
13
/// Extractor that gives access to the path at which the current router was
/// nested.
///
/// The value lives in the request extensions and is written while the request
/// passes through a nested router. When routers are nested inside each other
/// the prefixes are combined, so a handler reached through
/// `Router::new().nest("/api", Router::new().nest("/v2", ..))` sees `/api/v2`.
///
/// When an outer prefix ends in `/`, that trailing slash is trimmed before
/// the inner prefix is appended. A single nesting level is reported
/// verbatim, so nesting at `"/api/"` yields `/api/`.
///
/// Extracting `NestedPath` from a request that was never routed through a
/// nested router fails with [`NestedPathRejection`].
#[derive(Debug, Clone)]
pub struct NestedPath(Arc<str>);
41
42impl NestedPath {
43 #[must_use]
45 pub fn as_str(&self) -> &str {
46 &self.0
47 }
48}
49
50impl<S> FromRequestParts<S> for NestedPath
51where
52 S: Send + Sync,
53{
54 type Rejection = NestedPathRejection;
55
56 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
57 match parts.extensions.get::<Self>() {
58 Some(nested_path) => Ok(nested_path.clone()),
59 None => Err(NestedPathRejection),
60 }
61 }
62}
63
64#[derive(Clone)]
65pub(crate) struct SetNestedPath<S> {
66 inner: S,
67 path: Arc<str>,
68}
69
70impl<S> SetNestedPath<S> {
71 pub(crate) fn layer(path: &str) -> impl Layer<S, Service = Self> + Clone {
72 let path = Arc::from(path);
73 layer_fn(move |inner| Self {
74 inner,
75 path: Arc::clone(&path),
76 })
77 }
78}
79
80impl<S, B> Service<Request<B>> for SetNestedPath<S>
81where
82 S: Service<Request<B>>,
83{
84 type Response = S::Response;
85 type Error = S::Error;
86 type Future = S::Future;
87
88 #[inline]
89 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
90 self.inner.poll_ready(cx)
91 }
92
93 fn call(&mut self, mut req: Request<B>) -> Self::Future {
94 if let Some(prev) = req.extensions_mut().get_mut::<NestedPath>() {
95 let new_path = if prev.as_str() == "/" {
96 Arc::clone(&self.path)
97 } else {
98 format!("{}{}", prev.as_str().trim_end_matches('/'), self.path).into()
99 };
100 prev.0 = new_path;
101 } else {
102 req.extensions_mut()
103 .insert(NestedPath(Arc::clone(&self.path)));
104 };
105
106 self.inner.call(req)
107 }
108}
109
#[cfg(test)]
mod tests {
    use axum_core::response::Response;
    use http::StatusCode;

    use crate::{
        extract::{NestedPath, Request},
        middleware::{from_fn, Next},
        routing::get,
        test_helpers::*,
        Router,
    };

    /// Sends `GET uri` to `app` and asserts that the response status is
    /// `200 OK`.
    ///
    /// The handlers and middleware under test make their own assertions on
    /// the extracted `NestedPath`. This helper only checks the status code.
    async fn expect_ok(app: Router, uri: &str) {
        let client = TestClient::new(app);
        let res = client.get(uri).await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[crate::test]
    async fn one_level_of_nesting() {
        let handler = get(|nested_path: NestedPath| {
            assert_eq!(nested_path.as_str(), "/api");
            async {}
        });
        let api = Router::new().route("/users", handler);

        expect_ok(Router::new().nest("/api", api), "/api/users").await;
    }

    #[crate::test]
    async fn one_level_of_nesting_with_trailing_slash() {
        let handler = get(|nested_path: NestedPath| {
            assert_eq!(nested_path.as_str(), "/api/");
            async {}
        });
        let api = Router::new().route("/users", handler);

        expect_ok(Router::new().nest("/api/", api), "/api/users").await;
    }

    #[crate::test]
    async fn two_levels_of_nesting() {
        let handler = get(|nested_path: NestedPath| {
            assert_eq!(nested_path.as_str(), "/api/v2");
            async {}
        });
        let api = Router::new().route("/users", handler);
        let v2 = Router::new().nest("/v2", api);

        expect_ok(Router::new().nest("/api", v2), "/api/v2/users").await;
    }

    #[crate::test]
    async fn two_levels_of_nesting_with_trailing_slash() {
        let handler = get(|nested_path: NestedPath| {
            assert_eq!(nested_path.as_str(), "/api/v2");
            async {}
        });
        let api = Router::new().route("/users", handler);
        let v2 = Router::new().nest("/v2", api);

        expect_ok(Router::new().nest("/api/", v2), "/api/v2/users").await;
    }

    #[crate::test]
    async fn in_fallbacks() {
        let fallback = get(|nested_path: NestedPath| {
            assert_eq!(nested_path.as_str(), "/api");
            async {}
        });
        let api = Router::new().fallback(fallback);

        expect_ok(Router::new().nest("/api", api), "/api/doesnt-exist").await;
    }

    #[crate::test]
    async fn in_middleware() {
        async fn middleware(nested_path: NestedPath, req: Request, next: Next) -> Response {
            assert_eq!(nested_path.as_str(), "/api");
            next.run(req).await
        }

        let api = Router::new()
            .route("/users", get(|| async {}))
            .layer(from_fn(middleware));

        expect_ok(Router::new().nest("/api", api), "/api/users").await;
    }
}