1use std::{ops::Deref, sync::Arc};
2
3use axum::{
4 extract::{FromRequest, FromRequestParts, Path, Query as AxumQuery},
5 http::{request::Parts, Extensions, HeaderMap},
6};
7use serde::de::DeserializeOwned;
8
9use crate::{HttpException, RequestContext, Validate};
10
11pub trait Pipe<Input>: Send + Sync + 'static {
12 type Output;
13
14 fn transform(value: Input, ctx: &RequestContext) -> Result<Self::Output, HttpException>;
15}
16
17pub trait RequestDecorator: Send + Sync + 'static {
18 type Output: Send + 'static;
19
20 fn extract(ctx: &RequestContext, parts: &Parts) -> Result<Self::Output, HttpException>;
21}
22
/// Identifier attached to a request.
///
/// Stored as an `Arc<str>`, so cloning only bumps a reference count.
#[derive(Debug, Clone)]
pub struct RequestId(pub Arc<str>);

impl RequestId {
    /// Builds an id from anything convertible into a `String`.
    pub fn new(value: impl Into<String>) -> Self {
        let owned: String = value.into();
        Self(Arc::from(owned))
    }

    /// Consumes the id and returns its text as a newly allocated `String`.
    pub fn into_inner(self) -> String {
        String::from(&*self.0)
    }

    /// Borrows the id's text.
    pub fn value(&self) -> &str {
        &self.0
    }
}

impl Deref for RequestId {
    type Target = str;

    fn deref(&self) -> &str {
        self.value()
    }
}
47
48pub fn request_id_from_extensions(extensions: &Extensions) -> Option<String> {
49 extensions.get::<RequestId>().map(|request_id| request_id.value().to_string())
50}
51
52impl<S> FromRequestParts<S> for RequestId
53where
54 S: Send + Sync,
55{
56 type Rejection = HttpException;
57
58 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
59 parts
60 .extensions
61 .get::<RequestId>()
62 .cloned()
63 .ok_or_else(|| HttpException::internal_server_error("Request id not available"))
64 }
65}
66
/// Route path parameter(s) deserialized into `T`.
#[derive(Debug, Clone, Copy)]
pub struct Param<T>(pub T);

impl<T> Deref for Param<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Self(inner) = self;
        inner
    }
}
86
87pub struct Decorated<T>
88where
89 T: RequestDecorator,
90{
91 value: T::Output,
92 _marker: std::marker::PhantomData<T>,
93}
94
95impl<T> Deref for Decorated<T>
96where
97 T: RequestDecorator,
98{
99 type Target = T::Output;
100
101 fn deref(&self) -> &Self::Target {
102 &self.value
103 }
104}
105
106impl<T> Decorated<T>
107where
108 T: RequestDecorator,
109{
110 pub fn into_inner(self) -> T::Output {
111 self.value
112 }
113}
114
115impl<S, T> FromRequestParts<S> for Decorated<T>
116where
117 S: Send + Sync,
118 T: RequestDecorator,
119{
120 type Rejection = HttpException;
121
122 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
123 let ctx = RequestContext::from_parts(parts);
124 let value = T::extract(&ctx, parts)?;
125
126 Ok(Self {
127 value,
128 _marker: std::marker::PhantomData,
129 })
130 }
131}
132
133impl<T> Param<T> {
134 pub fn into_inner(self) -> T {
135 self.0
136 }
137
138 pub fn value(self) -> T {
139 self.0
140 }
141}
142
143impl<S, T> FromRequestParts<S> for Param<T>
147where
148 S: Send + Sync,
149 T: DeserializeOwned + Send + 'static,
150{
151 type Rejection = HttpException;
152
153 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
154 let request_id = request_id_from_extensions(&parts.extensions);
155 let Path(value) = Path::<T>::from_request_parts(parts, state)
156 .await
157 .map_err(|_| {
158 HttpException::bad_request("Invalid route parameter")
159 .with_optional_request_id(request_id)
160 })?;
161
162 Ok(Self(value))
163 }
164}
165
166pub struct PipedParam<T, P>
167where
168 P: Pipe<T>,
169{
170 value: P::Output,
171 _marker: std::marker::PhantomData<(T, P)>,
172}
173
174impl<T, P> Deref for PipedParam<T, P>
175where
176 P: Pipe<T>,
177{
178 type Target = P::Output;
179
180 fn deref(&self) -> &Self::Target {
181 &self.value
182 }
183}
184
185impl<T, P> PipedParam<T, P>
186where
187 P: Pipe<T>,
188{
189 pub fn into_inner(self) -> P::Output {
190 self.value
191 }
192}
193
194impl<S, T, P> FromRequestParts<S> for PipedParam<T, P>
195where
196 S: Send + Sync,
197 T: DeserializeOwned + Send + 'static,
198 P: Pipe<T>,
199{
200 type Rejection = HttpException;
201
202 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
203 let Path(value) = Path::<T>::from_request_parts(parts, state)
204 .await
205 .map_err(|_| {
206 HttpException::bad_request("Invalid route parameter")
207 .with_optional_request_id(request_id_from_extensions(&parts.extensions))
208 })?;
209 let ctx = RequestContext::from_parts(parts);
210 let value = P::transform(value, &ctx)?;
211
212 Ok(Self {
213 value,
214 _marker: std::marker::PhantomData,
215 })
216 }
217}
218
/// URL query string deserialized into `T`.
#[derive(Debug, Clone)]
pub struct Query<T>(pub T);

impl<T> Deref for Query<T> {
    type Target = T;

    fn deref(&self) -> &T {
        match self {
            Query(inner) => inner,
        }
    }
}

impl<T> Query<T> {
    /// Consumes the wrapper and returns the deserialized query.
    pub fn into_inner(self) -> T {
        match self {
            Query(inner) => inner,
        }
    }

    /// Alias of [`Query::into_inner`].
    pub fn value(self) -> T {
        self.into_inner()
    }
}
239
240impl<S, T> FromRequestParts<S> for Query<T>
241where
242 S: Send + Sync,
243 T: DeserializeOwned + Send + 'static,
244{
245 type Rejection = HttpException;
246
247 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
248 let request_id = request_id_from_extensions(&parts.extensions);
249 let AxumQuery(value) = AxumQuery::<T>::from_request_parts(parts, state)
250 .await
251 .map_err(|_| {
252 HttpException::bad_request("Invalid query parameters")
253 .with_optional_request_id(request_id)
254 })?;
255
256 Ok(Self(value))
257 }
258}
259
260pub struct PipedQuery<T, P>
261where
262 P: Pipe<T>,
263{
264 value: P::Output,
265 _marker: std::marker::PhantomData<(T, P)>,
266}
267
268impl<T, P> Deref for PipedQuery<T, P>
269where
270 P: Pipe<T>,
271{
272 type Target = P::Output;
273
274 fn deref(&self) -> &Self::Target {
275 &self.value
276 }
277}
278
279impl<T, P> PipedQuery<T, P>
280where
281 P: Pipe<T>,
282{
283 pub fn into_inner(self) -> P::Output {
284 self.value
285 }
286}
287
288impl<S, T, P> FromRequestParts<S> for PipedQuery<T, P>
289where
290 S: Send + Sync,
291 T: DeserializeOwned + Send + 'static,
292 P: Pipe<T>,
293{
294 type Rejection = HttpException;
295
296 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
297 let AxumQuery(value) = AxumQuery::<T>::from_request_parts(parts, state)
298 .await
299 .map_err(|_| {
300 HttpException::bad_request("Invalid query parameters")
301 .with_optional_request_id(request_id_from_extensions(&parts.extensions))
302 })?;
303 let ctx = RequestContext::from_parts(parts);
304 let value = P::transform(value, &ctx)?;
305
306 Ok(Self {
307 value,
308 _marker: std::marker::PhantomData,
309 })
310 }
311}
312
313#[derive(Debug, Clone)]
314pub struct Headers(pub HeaderMap);
315
316impl Headers {
317 pub fn get(&self, name: &str) -> Option<&axum::http::HeaderValue> {
318 self.0.get(name)
319 }
320}
321
322impl Deref for Headers {
323 type Target = HeaderMap;
324
325 fn deref(&self) -> &Self::Target {
326 &self.0
327 }
328}
329
330impl<S> FromRequestParts<S> for Headers
331where
332 S: Send + Sync,
333{
334 type Rejection = HttpException;
335
336 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
337 Ok(Self(parts.headers.clone()))
338 }
339}
340
/// Cookie name/value pairs, keyed by cookie name.
#[derive(Debug, Clone, Default)]
pub struct Cookies {
    values: std::collections::BTreeMap<String, String>,
}

impl Cookies {
    /// Builds a cookie set from `(name, value)` pairs.
    ///
    /// When a name repeats, the later pair replaces the earlier one.
    pub fn new<I, K, V>(pairs: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: Into<String>,
        V: Into<String>,
    {
        let mut values = std::collections::BTreeMap::new();
        for (key, value) in pairs {
            values.insert(key.into(), value.into());
        }
        Self { values }
    }

    /// Looks up the value of the cookie called `name`.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.values.get(name).map(|value| value.as_str())
    }
}
365
366impl<S> FromRequestParts<S> for Cookies
367where
368 S: Send + Sync,
369{
370 type Rejection = HttpException;
371
372 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
373 let mut cookies = std::collections::BTreeMap::new();
374 if let Some(header) = parts.headers.get(axum::http::header::COOKIE) {
375 if let Ok(raw) = header.to_str() {
376 for pair in raw.split(';') {
377 let trimmed = pair.trim();
378 if let Some((name, value)) = trimmed.split_once('=') {
379 cookies.insert(name.trim().to_string(), value.trim().to_string());
380 }
381 }
382 }
383 }
384
385 Ok(Self { values: cookies })
386 }
387}
388
/// JSON request body deserialized into `T`.
pub struct Body<T>(pub T);

impl<T> Deref for Body<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Body(inner) = self;
        inner
    }
}

impl<T> Body<T> {
    /// Consumes the wrapper and returns the deserialized body.
    pub fn into_inner(self) -> T {
        let Body(inner) = self;
        inner
    }

    /// Alias of [`Body::into_inner`].
    pub fn value(self) -> T {
        self.into_inner()
    }
}
417
418impl<S, T> FromRequest<S> for Body<T>
422where
423 S: Send + Sync,
424 T: DeserializeOwned + Send + 'static,
425{
426 type Rejection = HttpException;
427
428 async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
429 let request_id = request_id_from_extensions(req.extensions());
430 let axum::Json(value) = axum::Json::<T>::from_request(req, state)
431 .await
432 .map_err(|_| {
433 HttpException::bad_request("Invalid JSON body").with_optional_request_id(request_id)
434 })?;
435
436 Ok(Self(value))
437 }
438}
439
440pub struct PipedBody<T, P>
441where
442 P: Pipe<T>,
443{
444 value: P::Output,
445 _marker: std::marker::PhantomData<(T, P)>,
446}
447
448impl<T, P> Deref for PipedBody<T, P>
449where
450 P: Pipe<T>,
451{
452 type Target = P::Output;
453
454 fn deref(&self) -> &Self::Target {
455 &self.value
456 }
457}
458
459impl<T, P> PipedBody<T, P>
460where
461 P: Pipe<T>,
462{
463 pub fn into_inner(self) -> P::Output {
464 self.value
465 }
466}
467
468impl<S, T, P> FromRequest<S> for PipedBody<T, P>
469where
470 S: Send + Sync,
471 T: DeserializeOwned + Send + 'static,
472 P: Pipe<T>,
473{
474 type Rejection = HttpException;
475
476 async fn from_request(
477 req: axum::extract::Request,
478 state: &S,
479 ) -> Result<Self, Self::Rejection> {
480 let ctx = RequestContext::from_request(&req);
481 let axum::Json(value) = axum::Json::<T>::from_request(req, state)
482 .await
483 .map_err(|_| {
484 HttpException::bad_request("Invalid JSON body")
485 .with_optional_request_id(ctx.request_id.clone())
486 })?;
487 let value = P::transform(value, &ctx)?;
488
489 Ok(Self {
490 value,
491 _marker: std::marker::PhantomData,
492 })
493 }
494}
495
/// JSON request body deserialized into `T` and checked with `Validate`.
pub struct ValidatedBody<T>(pub T);

impl<T> Deref for ValidatedBody<T> {
    type Target = T;

    fn deref(&self) -> &T {
        match self {
            ValidatedBody(inner) => inner,
        }
    }
}

impl<T> ValidatedBody<T> {
    /// Consumes the wrapper and returns the validated body.
    pub fn into_inner(self) -> T {
        match self {
            ValidatedBody(inner) => inner,
        }
    }

    /// Alias of [`ValidatedBody::into_inner`].
    pub fn value(self) -> T {
        self.into_inner()
    }
}
515
516impl<S, T> FromRequest<S> for ValidatedBody<T>
517where
518 S: Send + Sync,
519 T: DeserializeOwned + Validate + Send + 'static,
520{
521 type Rejection = HttpException;
522
523 async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
524 let request_id = request_id_from_extensions(req.extensions());
525 let axum::Json(value) = axum::Json::<T>::from_request(req, state)
526 .await
527 .map_err(|_| {
528 HttpException::bad_request("Invalid JSON body")
529 .with_optional_request_id(request_id.clone())
530 })?;
531
532 value
533 .validate()
534 .map_err(|errors| {
535 HttpException::bad_request_validation(errors)
536 .with_optional_request_id(request_id)
537 })?;
538
539 Ok(Self(value))
540 }
541}