Skip to main content

nestforge_core/
request.rs

1use std::{ops::Deref, sync::Arc};
2
3use axum::{
4    extract::{FromRequest, FromRequestParts, Path, Query as AxumQuery},
5    http::{request::Parts, Extensions, HeaderMap},
6};
7use serde::de::DeserializeOwned;
8
9use crate::{HttpException, RequestContext, Validate};
10
11pub trait Pipe<Input>: Send + Sync + 'static {
12    type Output;
13
14    fn transform(value: Input, ctx: &RequestContext) -> Result<Self::Output, HttpException>;
15}
16
17pub trait RequestDecorator: Send + Sync + 'static {
18    type Output: Send + 'static;
19
20    fn extract(ctx: &RequestContext, parts: &Parts) -> Result<Self::Output, HttpException>;
21}
22
/// Identifier of the current request, looked up from the request extensions.
///
/// Backed by an `Arc<str>`, so cloning only bumps a reference count.
#[derive(Debug, Clone)]
pub struct RequestId(pub Arc<str>);

impl RequestId {
    /// Wraps any string-like value as a request id.
    pub fn new(value: impl Into<String>) -> Self {
        let owned: String = value.into();
        Self(Arc::from(owned))
    }

    /// Returns the id as a freshly allocated `String`.
    pub fn into_inner(self) -> String {
        String::from(&*self.0)
    }

    /// Borrows the id as a string slice.
    pub fn value(&self) -> &str {
        &self.0
    }
}

impl Deref for RequestId {
    type Target = str;

    fn deref(&self) -> &str {
        &self.0
    }
}
47
48pub fn request_id_from_extensions(extensions: &Extensions) -> Option<String> {
49    extensions
50        .get::<RequestId>()
51        .map(|request_id| request_id.value().to_string())
52}
53
54impl<S> FromRequestParts<S> for RequestId
55where
56    S: Send + Sync,
57{
58    type Rejection = HttpException;
59
60    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
61        parts
62            .extensions
63            .get::<RequestId>()
64            .cloned()
65            .ok_or_else(|| HttpException::internal_server_error("Request id not available"))
66    }
67}
68
/// Route path-parameter extractor.
///
/// Lets a handler take `id: Param<u64>` instead of destructuring `Path(id): Path<u64>`.
/// The wrapped value is reachable through `Deref` or the public field.
#[derive(Debug, Clone, Copy)]
pub struct Param<T>(pub T);

impl<T> Deref for Param<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Param(inner) = self;
        inner
    }
}
88
89pub struct Decorated<T>
90where
91    T: RequestDecorator,
92{
93    value: T::Output,
94    _marker: std::marker::PhantomData<T>,
95}
96
97impl<T> Deref for Decorated<T>
98where
99    T: RequestDecorator,
100{
101    type Target = T::Output;
102
103    fn deref(&self) -> &Self::Target {
104        &self.value
105    }
106}
107
108impl<T> Decorated<T>
109where
110    T: RequestDecorator,
111{
112    pub fn into_inner(self) -> T::Output {
113        self.value
114    }
115}
116
117impl<S, T> FromRequestParts<S> for Decorated<T>
118where
119    S: Send + Sync,
120    T: RequestDecorator,
121{
122    type Rejection = HttpException;
123
124    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
125        let ctx = RequestContext::from_parts(parts);
126        let value = T::extract(&ctx, parts)?;
127
128        Ok(Self {
129            value,
130            _marker: std::marker::PhantomData,
131        })
132    }
133}
134
135impl<T> Param<T> {
136    pub fn into_inner(self) -> T {
137        self.0
138    }
139
140    pub fn value(self) -> T {
141        self.0
142    }
143}
144
145/*
146Extract Param<T> from route path params.
147*/
148impl<S, T> FromRequestParts<S> for Param<T>
149where
150    S: Send + Sync,
151    T: DeserializeOwned + Send + 'static,
152{
153    type Rejection = HttpException;
154
155    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
156        let request_id = request_id_from_extensions(&parts.extensions);
157        let Path(value) = Path::<T>::from_request_parts(parts, state)
158            .await
159            .map_err(|_| {
160                HttpException::bad_request("Invalid route parameter")
161                    .with_optional_request_id(request_id)
162            })?;
163
164        Ok(Self(value))
165    }
166}
167
168pub struct PipedParam<T, P>
169where
170    P: Pipe<T>,
171{
172    value: P::Output,
173    _marker: std::marker::PhantomData<(T, P)>,
174}
175
176impl<T, P> Deref for PipedParam<T, P>
177where
178    P: Pipe<T>,
179{
180    type Target = P::Output;
181
182    fn deref(&self) -> &Self::Target {
183        &self.value
184    }
185}
186
187impl<T, P> PipedParam<T, P>
188where
189    P: Pipe<T>,
190{
191    pub fn into_inner(self) -> P::Output {
192        self.value
193    }
194}
195
196impl<S, T, P> FromRequestParts<S> for PipedParam<T, P>
197where
198    S: Send + Sync,
199    T: DeserializeOwned + Send + 'static,
200    P: Pipe<T>,
201{
202    type Rejection = HttpException;
203
204    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
205        let Path(value) = Path::<T>::from_request_parts(parts, state)
206            .await
207            .map_err(|_| {
208                HttpException::bad_request("Invalid route parameter")
209                    .with_optional_request_id(request_id_from_extensions(&parts.extensions))
210            })?;
211        let ctx = RequestContext::from_parts(parts);
212        let value = P::transform(value, &ctx)?;
213
214        Ok(Self {
215            value,
216            _marker: std::marker::PhantomData,
217        })
218    }
219}
220
/// Wrapper for URL query-string parameters deserialized into `T`.
#[derive(Debug, Clone)]
pub struct Query<T>(pub T);

impl<T> Deref for Query<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Query(inner) = self;
        inner
    }
}

impl<T> Query<T> {
    /// Consumes the wrapper and returns the query value.
    pub fn into_inner(self) -> T {
        let Query(inner) = self;
        inner
    }

    /// Alias of [`Query::into_inner`]; also consumes the wrapper.
    pub fn value(self) -> T {
        self.into_inner()
    }
}
241
242impl<S, T> FromRequestParts<S> for Query<T>
243where
244    S: Send + Sync,
245    T: DeserializeOwned + Send + 'static,
246{
247    type Rejection = HttpException;
248
249    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
250        let request_id = request_id_from_extensions(&parts.extensions);
251        let AxumQuery(value) = AxumQuery::<T>::from_request_parts(parts, state)
252            .await
253            .map_err(|_| {
254                HttpException::bad_request("Invalid query parameters")
255                    .with_optional_request_id(request_id)
256            })?;
257
258        Ok(Self(value))
259    }
260}
261
262pub struct PipedQuery<T, P>
263where
264    P: Pipe<T>,
265{
266    value: P::Output,
267    _marker: std::marker::PhantomData<(T, P)>,
268}
269
270impl<T, P> Deref for PipedQuery<T, P>
271where
272    P: Pipe<T>,
273{
274    type Target = P::Output;
275
276    fn deref(&self) -> &Self::Target {
277        &self.value
278    }
279}
280
281impl<T, P> PipedQuery<T, P>
282where
283    P: Pipe<T>,
284{
285    pub fn into_inner(self) -> P::Output {
286        self.value
287    }
288}
289
290impl<S, T, P> FromRequestParts<S> for PipedQuery<T, P>
291where
292    S: Send + Sync,
293    T: DeserializeOwned + Send + 'static,
294    P: Pipe<T>,
295{
296    type Rejection = HttpException;
297
298    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
299        let AxumQuery(value) = AxumQuery::<T>::from_request_parts(parts, state)
300            .await
301            .map_err(|_| {
302                HttpException::bad_request("Invalid query parameters")
303                    .with_optional_request_id(request_id_from_extensions(&parts.extensions))
304            })?;
305        let ctx = RequestContext::from_parts(parts);
306        let value = P::transform(value, &ctx)?;
307
308        Ok(Self {
309            value,
310            _marker: std::marker::PhantomData,
311        })
312    }
313}
314
315#[derive(Debug, Clone)]
316pub struct Headers(pub HeaderMap);
317
318impl Headers {
319    pub fn get(&self, name: &str) -> Option<&axum::http::HeaderValue> {
320        self.0.get(name)
321    }
322}
323
324impl Deref for Headers {
325    type Target = HeaderMap;
326
327    fn deref(&self) -> &Self::Target {
328        &self.0
329    }
330}
331
332impl<S> FromRequestParts<S> for Headers
333where
334    S: Send + Sync,
335{
336    type Rejection = HttpException;
337
338    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
339        Ok(Self(parts.headers.clone()))
340    }
341}
342
/// Request cookies keyed by cookie name.
#[derive(Debug, Clone, Default)]
pub struct Cookies {
    values: std::collections::BTreeMap<String, String>,
}

impl Cookies {
    /// Builds a cookie set from `(name, value)` pairs; a repeated name keeps its last value.
    pub fn new<I, K, V>(pairs: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: Into<String>,
        V: Into<String>,
    {
        let mut values = std::collections::BTreeMap::new();
        for (key, value) in pairs {
            values.insert(key.into(), value.into());
        }
        Self { values }
    }

    /// Returns the value of cookie `name`, if present.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.values.get(name).map(|value| value.as_str())
    }
}
367
368impl<S> FromRequestParts<S> for Cookies
369where
370    S: Send + Sync,
371{
372    type Rejection = HttpException;
373
374    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
375        let mut cookies = std::collections::BTreeMap::new();
376        if let Some(header) = parts.headers.get(axum::http::header::COOKIE) {
377            if let Ok(raw) = header.to_str() {
378                for pair in raw.split(';') {
379                    let trimmed = pair.trim();
380                    if let Some((name, value)) = trimmed.split_once('=') {
381                        cookies.insert(name.trim().to_string(), value.trim().to_string());
382                    }
383                }
384            }
385        }
386
387        Ok(Self { values: cookies })
388    }
389}
390
/// JSON request-body extractor.
///
/// Lets a handler take `body: Body<CreateUserDto>` instead of destructuring
/// `Json(dto): Json<CreateUserDto>`.
///
/// Derives `Debug` and `Clone` (available when `T` implements them), like the `Param` and
/// `Query` wrappers in this module.
#[derive(Debug, Clone)]
pub struct Body<T>(pub T);

impl<T> Deref for Body<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<T> Body<T> {
    /// Consumes the wrapper and returns the deserialized body.
    pub fn into_inner(self) -> T {
        self.0
    }

    /// Alias of [`Body::into_inner`]; also consumes the wrapper.
    pub fn value(self) -> T {
        self.0
    }
}
419
420/*
421Extract Body<T> from JSON request body.
422*/
423impl<S, T> FromRequest<S> for Body<T>
424where
425    S: Send + Sync,
426    T: DeserializeOwned + Send + 'static,
427{
428    type Rejection = HttpException;
429
430    async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
431        let request_id = request_id_from_extensions(req.extensions());
432        let axum::Json(value) = axum::Json::<T>::from_request(req, state)
433            .await
434            .map_err(|_| {
435                HttpException::bad_request("Invalid JSON body").with_optional_request_id(request_id)
436            })?;
437
438        Ok(Self(value))
439    }
440}
441
442pub struct PipedBody<T, P>
443where
444    P: Pipe<T>,
445{
446    value: P::Output,
447    _marker: std::marker::PhantomData<(T, P)>,
448}
449
450impl<T, P> Deref for PipedBody<T, P>
451where
452    P: Pipe<T>,
453{
454    type Target = P::Output;
455
456    fn deref(&self) -> &Self::Target {
457        &self.value
458    }
459}
460
461impl<T, P> PipedBody<T, P>
462where
463    P: Pipe<T>,
464{
465    pub fn into_inner(self) -> P::Output {
466        self.value
467    }
468}
469
470impl<S, T, P> FromRequest<S> for PipedBody<T, P>
471where
472    S: Send + Sync,
473    T: DeserializeOwned + Send + 'static,
474    P: Pipe<T>,
475{
476    type Rejection = HttpException;
477
478    async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
479        let ctx = RequestContext::from_request(&req);
480        let axum::Json(value) = axum::Json::<T>::from_request(req, state)
481            .await
482            .map_err(|_| {
483                HttpException::bad_request("Invalid JSON body")
484                    .with_optional_request_id(ctx.request_id.clone())
485            })?;
486        let value = P::transform(value, &ctx)?;
487
488        Ok(Self {
489            value,
490            _marker: std::marker::PhantomData,
491        })
492    }
493}
494
/// JSON request-body extractor whose value has passed `Validate::validate`.
///
/// Derives `Debug` and `Clone` (available when `T` implements them), like the `Param` and
/// `Query` wrappers in this module.
#[derive(Debug, Clone)]
pub struct ValidatedBody<T>(pub T);

impl<T> Deref for ValidatedBody<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<T> ValidatedBody<T> {
    /// Consumes the wrapper and returns the validated body.
    pub fn into_inner(self) -> T {
        self.0
    }

    /// Alias of [`ValidatedBody::into_inner`]; also consumes the wrapper.
    pub fn value(self) -> T {
        self.0
    }
}
514
515impl<S, T> FromRequest<S> for ValidatedBody<T>
516where
517    S: Send + Sync,
518    T: DeserializeOwned + Validate + Send + 'static,
519{
520    type Rejection = HttpException;
521
522    async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
523        let request_id = request_id_from_extensions(req.extensions());
524        let axum::Json(value) = axum::Json::<T>::from_request(req, state)
525            .await
526            .map_err(|_| {
527                HttpException::bad_request("Invalid JSON body")
528                    .with_optional_request_id(request_id.clone())
529            })?;
530
531        value.validate().map_err(|errors| {
532            HttpException::bad_request_validation(errors).with_optional_request_id(request_id)
533        })?;
534
535        Ok(Self(value))
536    }
537}