1use std::{ops::Deref, sync::Arc};
2
3use axum::{
4 extract::{FromRequest, FromRequestParts, Path, Query as AxumQuery},
5 http::{request::Parts, Extensions, HeaderMap},
6};
7use serde::de::DeserializeOwned;
8
9use crate::{HttpException, RequestContext, Validate};
10
11pub trait Pipe<Input>: Send + Sync + 'static {
35 type Output;
36
37 fn transform(value: Input, ctx: &RequestContext) -> Result<Self::Output, HttpException>;
38}
39
40pub trait RequestDecorator: Send + Sync + 'static {
69 type Output: Send + 'static;
70
71 fn extract(ctx: &RequestContext, parts: &Parts) -> Result<Self::Output, HttpException>;
72}
73
/// Identifier of the current request.
///
/// The text is held in an `Arc<str>`, so cloning only bumps a reference count.
/// Extractors in this module read it from the request extensions; something
/// upstream (presumably middleware — confirm) is expected to insert it.
#[derive(Debug, Clone)]
pub struct RequestId(pub Arc<str>);

impl RequestId {
    /// Creates a request id from anything convertible into a `String`.
    pub fn new(value: impl Into<String>) -> Self {
        let owned: String = value.into();
        Self(Arc::from(owned))
    }

    /// Consumes the id and returns a freshly allocated copy of its text.
    pub fn into_inner(self) -> String {
        String::from(&*self.0)
    }

    /// Borrows the id text.
    pub fn value(&self) -> &str {
        &self.0
    }
}

impl Deref for RequestId {
    type Target = str;

    fn deref(&self) -> &str {
        self.value()
    }
}
105
106pub fn request_id_from_extensions(extensions: &Extensions) -> Option<String> {
107 extensions
108 .get::<RequestId>()
109 .map(|request_id| request_id.value().to_string())
110}
111
112impl<S> FromRequestParts<S> for RequestId
113where
114 S: Send + Sync,
115{
116 type Rejection = HttpException;
117
118 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
119 parts
120 .extensions
121 .get::<RequestId>()
122 .cloned()
123 .ok_or_else(|| HttpException::internal_server_error("Request id not available"))
124 }
125}
126
/// Route path parameter(s) deserialized into `T`.
///
/// Dereferences to the wrapped value.
#[derive(Debug, Clone, Copy)]
pub struct Param<T>(pub T);

impl<T> Deref for Param<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Self(inner) = self;
        inner
    }
}
142
143pub struct Decorated<T>
144where
145 T: RequestDecorator,
146{
147 value: T::Output,
148 _marker: std::marker::PhantomData<T>,
149}
150
151impl<T> Deref for Decorated<T>
152where
153 T: RequestDecorator,
154{
155 type Target = T::Output;
156
157 fn deref(&self) -> &Self::Target {
158 &self.value
159 }
160}
161
162impl<T> Decorated<T>
163where
164 T: RequestDecorator,
165{
166 pub fn into_inner(self) -> T::Output {
167 self.value
168 }
169}
170
171impl<S, T> FromRequestParts<S> for Decorated<T>
172where
173 S: Send + Sync,
174 T: RequestDecorator,
175{
176 type Rejection = HttpException;
177
178 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
179 let ctx = RequestContext::from_parts(parts);
180 let value = T::extract(&ctx, parts)?;
181
182 Ok(Self {
183 value,
184 _marker: std::marker::PhantomData,
185 })
186 }
187}
188
189impl<T> Param<T> {
190 pub fn into_inner(self) -> T {
191 self.0
192 }
193
194 pub fn value(self) -> T {
195 self.0
196 }
197}
198
199impl<S, T> FromRequestParts<S> for Param<T>
200where
201 S: Send + Sync,
202 T: DeserializeOwned + Send + 'static,
203{
204 type Rejection = HttpException;
205
206 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
207 let request_id = request_id_from_extensions(&parts.extensions);
208 let Path(value) = Path::<T>::from_request_parts(parts, state)
209 .await
210 .map_err(|_| {
211 HttpException::bad_request("Invalid route parameter")
212 .with_optional_request_id(request_id)
213 })?;
214
215 Ok(Self(value))
216 }
217}
218
219pub struct PipedParam<T, P>
220where
221 P: Pipe<T>,
222{
223 value: P::Output,
224 _marker: std::marker::PhantomData<(T, P)>,
225}
226
227impl<T, P> Deref for PipedParam<T, P>
228where
229 P: Pipe<T>,
230{
231 type Target = P::Output;
232
233 fn deref(&self) -> &Self::Target {
234 &self.value
235 }
236}
237
238impl<T, P> PipedParam<T, P>
239where
240 P: Pipe<T>,
241{
242 pub fn into_inner(self) -> P::Output {
243 self.value
244 }
245}
246
247impl<S, T, P> FromRequestParts<S> for PipedParam<T, P>
248where
249 S: Send + Sync,
250 T: DeserializeOwned + Send + 'static,
251 P: Pipe<T>,
252{
253 type Rejection = HttpException;
254
255 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
256 let Path(value) = Path::<T>::from_request_parts(parts, state)
257 .await
258 .map_err(|_| {
259 HttpException::bad_request("Invalid route parameter")
260 .with_optional_request_id(request_id_from_extensions(&parts.extensions))
261 })?;
262 let ctx = RequestContext::from_parts(parts);
263 let value = P::transform(value, &ctx)?;
264
265 Ok(Self {
266 value,
267 _marker: std::marker::PhantomData,
268 })
269 }
270}
271
/// Query-string parameters deserialized into `T`.
///
/// Dereferences to the wrapped value.
#[derive(Debug, Clone)]
pub struct Query<T>(pub T);

impl<T> Deref for Query<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Self(inner) = self;
        inner
    }
}
287
288impl<T> Query<T> {
289 pub fn into_inner(self) -> T {
290 self.0
291 }
292
293 pub fn value(self) -> T {
294 self.0
295 }
296}
297
298impl<S, T> FromRequestParts<S> for Query<T>
299where
300 S: Send + Sync,
301 T: DeserializeOwned + Send + 'static,
302{
303 type Rejection = HttpException;
304
305 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
306 let request_id = request_id_from_extensions(&parts.extensions);
307 let AxumQuery(value) = AxumQuery::<T>::from_request_parts(parts, state)
308 .await
309 .map_err(|_| {
310 HttpException::bad_request("Invalid query parameters")
311 .with_optional_request_id(request_id)
312 })?;
313
314 Ok(Self(value))
315 }
316}
317
318pub struct PipedQuery<T, P>
319where
320 P: Pipe<T>,
321{
322 value: P::Output,
323 _marker: std::marker::PhantomData<(T, P)>,
324}
325
326impl<T, P> Deref for PipedQuery<T, P>
327where
328 P: Pipe<T>,
329{
330 type Target = P::Output;
331
332 fn deref(&self) -> &Self::Target {
333 &self.value
334 }
335}
336
337impl<T, P> PipedQuery<T, P>
338where
339 P: Pipe<T>,
340{
341 pub fn into_inner(self) -> P::Output {
342 self.value
343 }
344}
345
346impl<S, T, P> FromRequestParts<S> for PipedQuery<T, P>
347where
348 S: Send + Sync,
349 T: DeserializeOwned + Send + 'static,
350 P: Pipe<T>,
351{
352 type Rejection = HttpException;
353
354 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
355 let AxumQuery(value) = AxumQuery::<T>::from_request_parts(parts, state)
356 .await
357 .map_err(|_| {
358 HttpException::bad_request("Invalid query parameters")
359 .with_optional_request_id(request_id_from_extensions(&parts.extensions))
360 })?;
361 let ctx = RequestContext::from_parts(parts);
362 let value = P::transform(value, &ctx)?;
363
364 Ok(Self {
365 value,
366 _marker: std::marker::PhantomData,
367 })
368 }
369}
370
371#[derive(Debug, Clone)]
373pub struct Headers(pub HeaderMap);
374
375impl Headers {
376 pub fn get(&self, name: &str) -> Option<&axum::http::HeaderValue> {
377 self.0.get(name)
378 }
379}
380
381impl Deref for Headers {
382 type Target = HeaderMap;
383
384 fn deref(&self) -> &Self::Target {
385 &self.0
386 }
387}
388
389impl<S> FromRequestParts<S> for Headers
390where
391 S: Send + Sync,
392{
393 type Rejection = HttpException;
394
395 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
396 Ok(Self(parts.headers.clone()))
397 }
398}
399
/// Cookie name-to-value pairs taken from the request.
///
/// Backed by a `BTreeMap`, so each name maps to exactly one value.
#[derive(Debug, Clone, Default)]
pub struct Cookies {
    values: std::collections::BTreeMap<String, String>,
}

impl Cookies {
    /// Builds a cookie set from `(name, value)` pairs.
    ///
    /// When a name appears more than once, the last pair for it is kept.
    pub fn new<I, K, V>(pairs: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: Into<String>,
        V: Into<String>,
    {
        let mut values = std::collections::BTreeMap::new();
        values.extend(
            pairs
                .into_iter()
                .map(|(name, content)| (name.into(), content.into())),
        );
        Self { values }
    }

    /// Looks up the value of the cookie called `name`.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.values.get(name).map(|content| content.as_str())
    }
}
427
428impl<S> FromRequestParts<S> for Cookies
429where
430 S: Send + Sync,
431{
432 type Rejection = HttpException;
433
434 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
435 let mut cookies = std::collections::BTreeMap::new();
436 if let Some(header) = parts.headers.get(axum::http::header::COOKIE) {
437 if let Ok(raw) = header.to_str() {
438 for pair in raw.split(';') {
439 let trimmed = pair.trim();
440 if let Some((name, value)) = trimmed.split_once('=') {
441 cookies.insert(name.trim().to_string(), value.trim().to_string());
442 }
443 }
444 }
445 }
446
447 Ok(Self { values: cookies })
448 }
449}
450
/// Request body parsed from JSON into `T`.
///
/// Dereferences to the wrapped value.
pub struct Body<T>(pub T);

impl<T> Deref for Body<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Self(inner) = self;
        inner
    }
}

impl<T> Body<T> {
    /// Consumes the wrapper and returns the parsed body.
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }

    /// Alias of [`Body::into_inner`].
    pub fn value(self) -> T {
        self.into_inner()
    }
}
475
476impl<S, T> FromRequest<S> for Body<T>
480where
481 S: Send + Sync,
482 T: DeserializeOwned + Send + 'static,
483{
484 type Rejection = HttpException;
485
486 async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
487 let request_id = request_id_from_extensions(req.extensions());
488 let axum::Json(value) = axum::Json::<T>::from_request(req, state)
489 .await
490 .map_err(|_| {
491 HttpException::bad_request("Invalid JSON body").with_optional_request_id(request_id)
492 })?;
493
494 Ok(Self(value))
495 }
496}
497
498pub struct PipedBody<T, P>
499where
500 P: Pipe<T>,
501{
502 value: P::Output,
503 _marker: std::marker::PhantomData<(T, P)>,
504}
505
506impl<T, P> Deref for PipedBody<T, P>
507where
508 P: Pipe<T>,
509{
510 type Target = P::Output;
511
512 fn deref(&self) -> &Self::Target {
513 &self.value
514 }
515}
516
517impl<T, P> PipedBody<T, P>
518where
519 P: Pipe<T>,
520{
521 pub fn into_inner(self) -> P::Output {
522 self.value
523 }
524}
525
526impl<S, T, P> FromRequest<S> for PipedBody<T, P>
527where
528 S: Send + Sync,
529 T: DeserializeOwned + Send + 'static,
530 P: Pipe<T>,
531{
532 type Rejection = HttpException;
533
534 async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
535 let ctx = RequestContext::from_request(&req);
536 let axum::Json(value) = axum::Json::<T>::from_request(req, state)
537 .await
538 .map_err(|_| {
539 HttpException::bad_request("Invalid JSON body")
540 .with_optional_request_id(ctx.request_id.clone())
541 })?;
542 let value = P::transform(value, &ctx)?;
543
544 Ok(Self {
545 value,
546 _marker: std::marker::PhantomData,
547 })
548 }
549}
550
/// JSON request body deserialized into `T` and checked with `Validate`.
///
/// Dereferences to the wrapped value.
pub struct ValidatedBody<T>(pub T);

impl<T> Deref for ValidatedBody<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Self(inner) = self;
        inner
    }
}

impl<T> ValidatedBody<T> {
    /// Consumes the wrapper and returns the validated body.
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }

    /// Alias of [`ValidatedBody::into_inner`].
    pub fn value(self) -> T {
        self.into_inner()
    }
}
574
575impl<S, T> FromRequest<S> for ValidatedBody<T>
576where
577 S: Send + Sync,
578 T: DeserializeOwned + Validate + Send + 'static,
579{
580 type Rejection = HttpException;
581
582 async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
583 let request_id = request_id_from_extensions(req.extensions());
584 let axum::Json(value) = axum::Json::<T>::from_request(req, state)
585 .await
586 .map_err(|_| {
587 HttpException::bad_request("Invalid JSON body")
588 .with_optional_request_id(request_id.clone())
589 })?;
590
591 value.validate().map_err(|errors| {
592 HttpException::bad_request_validation(errors).with_optional_request_id(request_id)
593 })?;
594
595 Ok(Self(value))
596 }
597}