1use std::{ops::Deref, sync::Arc};
2
3use axum::{
4 extract::{FromRequest, FromRequestParts, Path, Query as AxumQuery},
5 http::{request::Parts, Extensions, HeaderMap},
6};
7use serde::de::DeserializeOwned;
8
9use crate::{HttpException, RequestContext, Validate};
10
11pub trait Pipe<Input>: Send + Sync + 'static {
12 type Output;
13
14 fn transform(value: Input, ctx: &RequestContext) -> Result<Self::Output, HttpException>;
15}
16
17pub trait RequestDecorator: Send + Sync + 'static {
18 type Output: Send + 'static;
19
20 fn extract(ctx: &RequestContext, parts: &Parts) -> Result<Self::Output, HttpException>;
21}
22
/// Identifier of an incoming request, looked up from the request extensions.
///
/// The id is held in an `Arc<str>`, so cloning only bumps a reference count.
#[derive(Debug, Clone)]
pub struct RequestId(pub Arc<str>);

impl RequestId {
    /// Wraps any string-like value as a request id.
    pub fn new(value: impl Into<String>) -> Self {
        let owned: String = value.into();
        RequestId(Arc::from(owned))
    }

    /// Returns the id as a newly allocated `String`.
    pub fn into_inner(self) -> String {
        String::from(&*self.0)
    }

    /// Borrows the id as a string slice.
    pub fn value(&self) -> &str {
        &self.0
    }
}

impl Deref for RequestId {
    type Target = str;

    fn deref(&self) -> &str {
        &self.0
    }
}
47
48pub fn request_id_from_extensions(extensions: &Extensions) -> Option<String> {
49 extensions
50 .get::<RequestId>()
51 .map(|request_id| request_id.value().to_string())
52}
53
54impl<S> FromRequestParts<S> for RequestId
55where
56 S: Send + Sync,
57{
58 type Rejection = HttpException;
59
60 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
61 parts
62 .extensions
63 .get::<RequestId>()
64 .cloned()
65 .ok_or_else(|| HttpException::internal_server_error("Request id not available"))
66 }
67}
68
/// Route-path parameter extractor: `T` is deserialized from the matched path.
#[derive(Debug, Clone, Copy)]
pub struct Param<T>(pub T);

impl<T> Deref for Param<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Param(inner) = self;
        inner
    }
}
88
89pub struct Decorated<T>
90where
91 T: RequestDecorator,
92{
93 value: T::Output,
94 _marker: std::marker::PhantomData<T>,
95}
96
97impl<T> Deref for Decorated<T>
98where
99 T: RequestDecorator,
100{
101 type Target = T::Output;
102
103 fn deref(&self) -> &Self::Target {
104 &self.value
105 }
106}
107
108impl<T> Decorated<T>
109where
110 T: RequestDecorator,
111{
112 pub fn into_inner(self) -> T::Output {
113 self.value
114 }
115}
116
117impl<S, T> FromRequestParts<S> for Decorated<T>
118where
119 S: Send + Sync,
120 T: RequestDecorator,
121{
122 type Rejection = HttpException;
123
124 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
125 let ctx = RequestContext::from_parts(parts);
126 let value = T::extract(&ctx, parts)?;
127
128 Ok(Self {
129 value,
130 _marker: std::marker::PhantomData,
131 })
132 }
133}
134
135impl<T> Param<T> {
136 pub fn into_inner(self) -> T {
137 self.0
138 }
139
140 pub fn value(self) -> T {
141 self.0
142 }
143}
144
145impl<S, T> FromRequestParts<S> for Param<T>
149where
150 S: Send + Sync,
151 T: DeserializeOwned + Send + 'static,
152{
153 type Rejection = HttpException;
154
155 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
156 let request_id = request_id_from_extensions(&parts.extensions);
157 let Path(value) = Path::<T>::from_request_parts(parts, state)
158 .await
159 .map_err(|_| {
160 HttpException::bad_request("Invalid route parameter")
161 .with_optional_request_id(request_id)
162 })?;
163
164 Ok(Self(value))
165 }
166}
167
168pub struct PipedParam<T, P>
169where
170 P: Pipe<T>,
171{
172 value: P::Output,
173 _marker: std::marker::PhantomData<(T, P)>,
174}
175
176impl<T, P> Deref for PipedParam<T, P>
177where
178 P: Pipe<T>,
179{
180 type Target = P::Output;
181
182 fn deref(&self) -> &Self::Target {
183 &self.value
184 }
185}
186
187impl<T, P> PipedParam<T, P>
188where
189 P: Pipe<T>,
190{
191 pub fn into_inner(self) -> P::Output {
192 self.value
193 }
194}
195
196impl<S, T, P> FromRequestParts<S> for PipedParam<T, P>
197where
198 S: Send + Sync,
199 T: DeserializeOwned + Send + 'static,
200 P: Pipe<T>,
201{
202 type Rejection = HttpException;
203
204 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
205 let Path(value) = Path::<T>::from_request_parts(parts, state)
206 .await
207 .map_err(|_| {
208 HttpException::bad_request("Invalid route parameter")
209 .with_optional_request_id(request_id_from_extensions(&parts.extensions))
210 })?;
211 let ctx = RequestContext::from_parts(parts);
212 let value = P::transform(value, &ctx)?;
213
214 Ok(Self {
215 value,
216 _marker: std::marker::PhantomData,
217 })
218 }
219}
220
/// Query-string extractor: `T` is deserialized from the request URI's query.
#[derive(Debug, Clone)]
pub struct Query<T>(pub T);

impl<T> Deref for Query<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Query(inner) = self;
        inner
    }
}

impl<T> Query<T> {
    /// Unwraps the deserialized query value.
    pub fn into_inner(self) -> T {
        let Query(inner) = self;
        inner
    }

    /// Alias of [`Query::into_inner`].
    pub fn value(self) -> T {
        self.into_inner()
    }
}
241
242impl<S, T> FromRequestParts<S> for Query<T>
243where
244 S: Send + Sync,
245 T: DeserializeOwned + Send + 'static,
246{
247 type Rejection = HttpException;
248
249 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
250 let request_id = request_id_from_extensions(&parts.extensions);
251 let AxumQuery(value) = AxumQuery::<T>::from_request_parts(parts, state)
252 .await
253 .map_err(|_| {
254 HttpException::bad_request("Invalid query parameters")
255 .with_optional_request_id(request_id)
256 })?;
257
258 Ok(Self(value))
259 }
260}
261
262pub struct PipedQuery<T, P>
263where
264 P: Pipe<T>,
265{
266 value: P::Output,
267 _marker: std::marker::PhantomData<(T, P)>,
268}
269
270impl<T, P> Deref for PipedQuery<T, P>
271where
272 P: Pipe<T>,
273{
274 type Target = P::Output;
275
276 fn deref(&self) -> &Self::Target {
277 &self.value
278 }
279}
280
281impl<T, P> PipedQuery<T, P>
282where
283 P: Pipe<T>,
284{
285 pub fn into_inner(self) -> P::Output {
286 self.value
287 }
288}
289
290impl<S, T, P> FromRequestParts<S> for PipedQuery<T, P>
291where
292 S: Send + Sync,
293 T: DeserializeOwned + Send + 'static,
294 P: Pipe<T>,
295{
296 type Rejection = HttpException;
297
298 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
299 let AxumQuery(value) = AxumQuery::<T>::from_request_parts(parts, state)
300 .await
301 .map_err(|_| {
302 HttpException::bad_request("Invalid query parameters")
303 .with_optional_request_id(request_id_from_extensions(&parts.extensions))
304 })?;
305 let ctx = RequestContext::from_parts(parts);
306 let value = P::transform(value, &ctx)?;
307
308 Ok(Self {
309 value,
310 _marker: std::marker::PhantomData,
311 })
312 }
313}
314
315#[derive(Debug, Clone)]
316pub struct Headers(pub HeaderMap);
317
318impl Headers {
319 pub fn get(&self, name: &str) -> Option<&axum::http::HeaderValue> {
320 self.0.get(name)
321 }
322}
323
324impl Deref for Headers {
325 type Target = HeaderMap;
326
327 fn deref(&self) -> &Self::Target {
328 &self.0
329 }
330}
331
332impl<S> FromRequestParts<S> for Headers
333where
334 S: Send + Sync,
335{
336 type Rejection = HttpException;
337
338 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
339 Ok(Self(parts.headers.clone()))
340 }
341}
342
/// Cookie name/value pairs, stored in a `BTreeMap` and looked up by exact name.
#[derive(Debug, Clone, Default)]
pub struct Cookies {
    values: std::collections::BTreeMap<String, String>,
}

impl Cookies {
    /// Builds a cookie set from `(name, value)` pairs; when a name repeats, the later
    /// pair replaces the earlier one.
    pub fn new<I, K, V>(pairs: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: Into<String>,
        V: Into<String>,
    {
        let mut values: std::collections::BTreeMap<String, String> =
            std::collections::BTreeMap::new();
        for (name, value) in pairs {
            values.insert(name.into(), value.into());
        }
        Cookies { values }
    }

    /// Returns the value of cookie `name`, if present.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.values.get(name).map(|value| value.as_str())
    }
}
367
368impl<S> FromRequestParts<S> for Cookies
369where
370 S: Send + Sync,
371{
372 type Rejection = HttpException;
373
374 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
375 let mut cookies = std::collections::BTreeMap::new();
376 if let Some(header) = parts.headers.get(axum::http::header::COOKIE) {
377 if let Ok(raw) = header.to_str() {
378 for pair in raw.split(';') {
379 let trimmed = pair.trim();
380 if let Some((name, value)) = trimmed.split_once('=') {
381 cookies.insert(name.trim().to_string(), value.trim().to_string());
382 }
383 }
384 }
385 }
386
387 Ok(Self { values: cookies })
388 }
389}
390
/// JSON request-body extractor: `T` is deserialized from the body.
pub struct Body<T>(pub T);

impl<T> Deref for Body<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Body(inner) = self;
        inner
    }
}

impl<T> Body<T> {
    /// Unwraps the deserialized body.
    pub fn into_inner(self) -> T {
        let Body(inner) = self;
        inner
    }

    /// Alias of [`Body::into_inner`].
    pub fn value(self) -> T {
        self.into_inner()
    }
}
419
420impl<S, T> FromRequest<S> for Body<T>
424where
425 S: Send + Sync,
426 T: DeserializeOwned + Send + 'static,
427{
428 type Rejection = HttpException;
429
430 async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
431 let request_id = request_id_from_extensions(req.extensions());
432 let axum::Json(value) = axum::Json::<T>::from_request(req, state)
433 .await
434 .map_err(|_| {
435 HttpException::bad_request("Invalid JSON body").with_optional_request_id(request_id)
436 })?;
437
438 Ok(Self(value))
439 }
440}
441
442pub struct PipedBody<T, P>
443where
444 P: Pipe<T>,
445{
446 value: P::Output,
447 _marker: std::marker::PhantomData<(T, P)>,
448}
449
450impl<T, P> Deref for PipedBody<T, P>
451where
452 P: Pipe<T>,
453{
454 type Target = P::Output;
455
456 fn deref(&self) -> &Self::Target {
457 &self.value
458 }
459}
460
461impl<T, P> PipedBody<T, P>
462where
463 P: Pipe<T>,
464{
465 pub fn into_inner(self) -> P::Output {
466 self.value
467 }
468}
469
470impl<S, T, P> FromRequest<S> for PipedBody<T, P>
471where
472 S: Send + Sync,
473 T: DeserializeOwned + Send + 'static,
474 P: Pipe<T>,
475{
476 type Rejection = HttpException;
477
478 async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
479 let ctx = RequestContext::from_request(&req);
480 let axum::Json(value) = axum::Json::<T>::from_request(req, state)
481 .await
482 .map_err(|_| {
483 HttpException::bad_request("Invalid JSON body")
484 .with_optional_request_id(ctx.request_id.clone())
485 })?;
486 let value = P::transform(value, &ctx)?;
487
488 Ok(Self {
489 value,
490 _marker: std::marker::PhantomData,
491 })
492 }
493}
494
/// JSON request-body extractor whose `T` has also passed `Validate::validate`.
pub struct ValidatedBody<T>(pub T);

impl<T> Deref for ValidatedBody<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let ValidatedBody(inner) = self;
        inner
    }
}

impl<T> ValidatedBody<T> {
    /// Unwraps the validated body.
    pub fn into_inner(self) -> T {
        let ValidatedBody(inner) = self;
        inner
    }

    /// Alias of [`ValidatedBody::into_inner`].
    pub fn value(self) -> T {
        self.into_inner()
    }
}
514
515impl<S, T> FromRequest<S> for ValidatedBody<T>
516where
517 S: Send + Sync,
518 T: DeserializeOwned + Validate + Send + 'static,
519{
520 type Rejection = HttpException;
521
522 async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
523 let request_id = request_id_from_extensions(req.extensions());
524 let axum::Json(value) = axum::Json::<T>::from_request(req, state)
525 .await
526 .map_err(|_| {
527 HttpException::bad_request("Invalid JSON body")
528 .with_optional_request_id(request_id.clone())
529 })?;
530
531 value.validate().map_err(|errors| {
532 HttpException::bad_request_validation(errors).with_optional_request_id(request_id)
533 })?;
534
535 Ok(Self(value))
536 }
537}