Skip to main content

nestforge_core/
request.rs

1use std::{ops::Deref, sync::Arc};
2
3use axum::{
4    extract::{FromRequest, FromRequestParts, Path, Query as AxumQuery},
5    http::{request::Parts, Extensions, HeaderMap},
6};
7use serde::de::DeserializeOwned;
8
9use crate::{HttpException, RequestContext, Validate};
10
11pub trait Pipe<Input>: Send + Sync + 'static {
12    type Output;
13
14    fn transform(value: Input, ctx: &RequestContext) -> Result<Self::Output, HttpException>;
15}
16
17pub trait RequestDecorator: Send + Sync + 'static {
18    type Output: Send + 'static;
19
20    fn extract(ctx: &RequestContext, parts: &Parts) -> Result<Self::Output, HttpException>;
21}
22
/// Identifier attached to an incoming request.
///
/// The text is stored as `Arc<str>`, so cloning a `RequestId` shares one allocation.
#[derive(Debug, Clone)]
pub struct RequestId(pub Arc<str>);

impl RequestId {
    /// Builds a request id from anything convertible into a `String`.
    pub fn new(value: impl Into<String>) -> Self {
        let owned: String = value.into();
        Self(Arc::from(owned))
    }

    /// Consumes the id and returns an owned copy of its text.
    pub fn into_inner(self) -> String {
        String::from(&*self.0)
    }

    /// Borrows the id as a string slice.
    pub fn value(&self) -> &str {
        &self.0
    }
}

impl Deref for RequestId {
    type Target = str;

    fn deref(&self) -> &str {
        &self.0
    }
}
47
48pub fn request_id_from_extensions(extensions: &Extensions) -> Option<String> {
49    extensions.get::<RequestId>().map(|request_id| request_id.value().to_string())
50}
51
52impl<S> FromRequestParts<S> for RequestId
53where
54    S: Send + Sync,
55{
56    type Rejection = HttpException;
57
58    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
59        parts
60            .extensions
61            .get::<RequestId>()
62            .cloned()
63            .ok_or_else(|| HttpException::internal_server_error("Request id not available"))
64    }
65}
66
/// Path-parameter wrapper.
///
/// Lets a handler declare `id: Param<u64>` instead of destructuring
/// `Path(id): Path<u64>`.
#[derive(Debug, Clone, Copy)]
pub struct Param<T>(pub T);

impl<T> Deref for Param<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Self(inner) = self;
        inner
    }
}
86
87pub struct Decorated<T>
88where
89    T: RequestDecorator,
90{
91    value: T::Output,
92    _marker: std::marker::PhantomData<T>,
93}
94
95impl<T> Deref for Decorated<T>
96where
97    T: RequestDecorator,
98{
99    type Target = T::Output;
100
101    fn deref(&self) -> &Self::Target {
102        &self.value
103    }
104}
105
106impl<T> Decorated<T>
107where
108    T: RequestDecorator,
109{
110    pub fn into_inner(self) -> T::Output {
111        self.value
112    }
113}
114
115impl<S, T> FromRequestParts<S> for Decorated<T>
116where
117    S: Send + Sync,
118    T: RequestDecorator,
119{
120    type Rejection = HttpException;
121
122    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
123        let ctx = RequestContext::from_parts(parts);
124        let value = T::extract(&ctx, parts)?;
125
126        Ok(Self {
127            value,
128            _marker: std::marker::PhantomData,
129        })
130    }
131}
132
133impl<T> Param<T> {
134    pub fn into_inner(self) -> T {
135        self.0
136    }
137
138    pub fn value(self) -> T {
139        self.0
140    }
141}
142
143/*
144Extract Param<T> from route path params.
145*/
146impl<S, T> FromRequestParts<S> for Param<T>
147where
148    S: Send + Sync,
149    T: DeserializeOwned + Send + 'static,
150{
151    type Rejection = HttpException;
152
153    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
154        let request_id = request_id_from_extensions(&parts.extensions);
155        let Path(value) = Path::<T>::from_request_parts(parts, state)
156            .await
157            .map_err(|_| {
158                HttpException::bad_request("Invalid route parameter")
159                    .with_optional_request_id(request_id)
160            })?;
161
162        Ok(Self(value))
163    }
164}
165
166pub struct PipedParam<T, P>
167where
168    P: Pipe<T>,
169{
170    value: P::Output,
171    _marker: std::marker::PhantomData<(T, P)>,
172}
173
174impl<T, P> Deref for PipedParam<T, P>
175where
176    P: Pipe<T>,
177{
178    type Target = P::Output;
179
180    fn deref(&self) -> &Self::Target {
181        &self.value
182    }
183}
184
185impl<T, P> PipedParam<T, P>
186where
187    P: Pipe<T>,
188{
189    pub fn into_inner(self) -> P::Output {
190        self.value
191    }
192}
193
194impl<S, T, P> FromRequestParts<S> for PipedParam<T, P>
195where
196    S: Send + Sync,
197    T: DeserializeOwned + Send + 'static,
198    P: Pipe<T>,
199{
200    type Rejection = HttpException;
201
202    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
203        let Path(value) = Path::<T>::from_request_parts(parts, state)
204            .await
205            .map_err(|_| {
206                HttpException::bad_request("Invalid route parameter")
207                    .with_optional_request_id(request_id_from_extensions(&parts.extensions))
208            })?;
209        let ctx = RequestContext::from_parts(parts);
210        let value = P::transform(value, &ctx)?;
211
212        Ok(Self {
213            value,
214            _marker: std::marker::PhantomData,
215        })
216    }
217}
218
/// Wrapper around a value deserialized from the request's query string.
#[derive(Debug, Clone)]
pub struct Query<T>(pub T);

impl<T> Query<T> {
    /// Consumes the wrapper and returns the deserialized query value.
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }

    /// Alias of [`Query::into_inner`].
    pub fn value(self) -> T {
        self.into_inner()
    }
}

impl<T> Deref for Query<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Self(inner) = self;
        inner
    }
}
239
240impl<S, T> FromRequestParts<S> for Query<T>
241where
242    S: Send + Sync,
243    T: DeserializeOwned + Send + 'static,
244{
245    type Rejection = HttpException;
246
247    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
248        let request_id = request_id_from_extensions(&parts.extensions);
249        let AxumQuery(value) = AxumQuery::<T>::from_request_parts(parts, state)
250            .await
251            .map_err(|_| {
252                HttpException::bad_request("Invalid query parameters")
253                    .with_optional_request_id(request_id)
254            })?;
255
256        Ok(Self(value))
257    }
258}
259
260pub struct PipedQuery<T, P>
261where
262    P: Pipe<T>,
263{
264    value: P::Output,
265    _marker: std::marker::PhantomData<(T, P)>,
266}
267
268impl<T, P> Deref for PipedQuery<T, P>
269where
270    P: Pipe<T>,
271{
272    type Target = P::Output;
273
274    fn deref(&self) -> &Self::Target {
275        &self.value
276    }
277}
278
279impl<T, P> PipedQuery<T, P>
280where
281    P: Pipe<T>,
282{
283    pub fn into_inner(self) -> P::Output {
284        self.value
285    }
286}
287
288impl<S, T, P> FromRequestParts<S> for PipedQuery<T, P>
289where
290    S: Send + Sync,
291    T: DeserializeOwned + Send + 'static,
292    P: Pipe<T>,
293{
294    type Rejection = HttpException;
295
296    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
297        let AxumQuery(value) = AxumQuery::<T>::from_request_parts(parts, state)
298            .await
299            .map_err(|_| {
300                HttpException::bad_request("Invalid query parameters")
301                    .with_optional_request_id(request_id_from_extensions(&parts.extensions))
302            })?;
303        let ctx = RequestContext::from_parts(parts);
304        let value = P::transform(value, &ctx)?;
305
306        Ok(Self {
307            value,
308            _marker: std::marker::PhantomData,
309        })
310    }
311}
312
313#[derive(Debug, Clone)]
314pub struct Headers(pub HeaderMap);
315
316impl Headers {
317    pub fn get(&self, name: &str) -> Option<&axum::http::HeaderValue> {
318        self.0.get(name)
319    }
320}
321
322impl Deref for Headers {
323    type Target = HeaderMap;
324
325    fn deref(&self) -> &Self::Target {
326        &self.0
327    }
328}
329
330impl<S> FromRequestParts<S> for Headers
331where
332    S: Send + Sync,
333{
334    type Rejection = HttpException;
335
336    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
337        Ok(Self(parts.headers.clone()))
338    }
339}
340
/// Cookies sent with the request, keyed by cookie name.
#[derive(Debug, Clone, Default)]
pub struct Cookies {
    values: std::collections::BTreeMap<String, String>,
}

impl Cookies {
    /// Builds a cookie set from `(name, value)` pairs.
    ///
    /// If a name appears more than once, the last pair for that name is kept.
    pub fn new<I, K, V>(pairs: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: Into<String>,
        V: Into<String>,
    {
        let mut values: std::collections::BTreeMap<String, String> =
            std::collections::BTreeMap::new();
        for (name, value) in pairs {
            values.insert(name.into(), value.into());
        }
        Self { values }
    }

    /// Returns the value of the cookie called `name`, if present.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.values.get(name).map(|value| value.as_str())
    }
}
365
366impl<S> FromRequestParts<S> for Cookies
367where
368    S: Send + Sync,
369{
370    type Rejection = HttpException;
371
372    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
373        let mut cookies = std::collections::BTreeMap::new();
374        if let Some(header) = parts.headers.get(axum::http::header::COOKIE) {
375            if let Ok(raw) = header.to_str() {
376                for pair in raw.split(';') {
377                    let trimmed = pair.trim();
378                    if let Some((name, value)) = trimmed.split_once('=') {
379                        cookies.insert(name.trim().to_string(), value.trim().to_string());
380                    }
381                }
382            }
383        }
384
385        Ok(Self { values: cookies })
386    }
387}
388
/// JSON request-body wrapper.
///
/// Lets a handler declare `body: Body<CreateUserDto>` instead of destructuring
/// `Json(dto): Json<CreateUserDto>`.
///
/// Derives `Debug` and `Clone` (when `T` does), for parity with [`Param`] and [`Query`].
#[derive(Debug, Clone)]
pub struct Body<T>(pub T);

impl<T> Deref for Body<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<T> Body<T> {
    /// Consumes the wrapper and returns the deserialized body.
    pub fn into_inner(self) -> T {
        self.0
    }

    /// Alias of [`Body::into_inner`].
    pub fn value(self) -> T {
        self.0
    }
}
417
418/*
419Extract Body<T> from JSON request body.
420*/
421impl<S, T> FromRequest<S> for Body<T>
422where
423    S: Send + Sync,
424    T: DeserializeOwned + Send + 'static,
425{
426    type Rejection = HttpException;
427
428    async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
429        let request_id = request_id_from_extensions(req.extensions());
430        let axum::Json(value) = axum::Json::<T>::from_request(req, state)
431            .await
432            .map_err(|_| {
433                HttpException::bad_request("Invalid JSON body").with_optional_request_id(request_id)
434            })?;
435
436        Ok(Self(value))
437    }
438}
439
440pub struct PipedBody<T, P>
441where
442    P: Pipe<T>,
443{
444    value: P::Output,
445    _marker: std::marker::PhantomData<(T, P)>,
446}
447
448impl<T, P> Deref for PipedBody<T, P>
449where
450    P: Pipe<T>,
451{
452    type Target = P::Output;
453
454    fn deref(&self) -> &Self::Target {
455        &self.value
456    }
457}
458
459impl<T, P> PipedBody<T, P>
460where
461    P: Pipe<T>,
462{
463    pub fn into_inner(self) -> P::Output {
464        self.value
465    }
466}
467
468impl<S, T, P> FromRequest<S> for PipedBody<T, P>
469where
470    S: Send + Sync,
471    T: DeserializeOwned + Send + 'static,
472    P: Pipe<T>,
473{
474    type Rejection = HttpException;
475
476    async fn from_request(
477        req: axum::extract::Request,
478        state: &S,
479    ) -> Result<Self, Self::Rejection> {
480        let ctx = RequestContext::from_request(&req);
481        let axum::Json(value) = axum::Json::<T>::from_request(req, state)
482            .await
483            .map_err(|_| {
484                HttpException::bad_request("Invalid JSON body")
485                    .with_optional_request_id(ctx.request_id.clone())
486            })?;
487        let value = P::transform(value, &ctx)?;
488
489        Ok(Self {
490            value,
491            _marker: std::marker::PhantomData,
492        })
493    }
494}
495
/// JSON request-body wrapper whose value has been checked with `Validate::validate`
/// when produced by this module's `FromRequest` impl.
///
/// Note that the field is public, so a `ValidatedBody` built by hand carries no such
/// guarantee. Derives `Debug` and `Clone` (when `T` does), for parity with the other
/// request wrappers.
#[derive(Debug, Clone)]
pub struct ValidatedBody<T>(pub T);

impl<T> Deref for ValidatedBody<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<T> ValidatedBody<T> {
    /// Consumes the wrapper and returns the validated body.
    pub fn into_inner(self) -> T {
        self.0
    }

    /// Alias of [`ValidatedBody::into_inner`].
    pub fn value(self) -> T {
        self.0
    }
}
515
516impl<S, T> FromRequest<S> for ValidatedBody<T>
517where
518    S: Send + Sync,
519    T: DeserializeOwned + Validate + Send + 'static,
520{
521    type Rejection = HttpException;
522
523    async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
524        let request_id = request_id_from_extensions(req.extensions());
525        let axum::Json(value) = axum::Json::<T>::from_request(req, state)
526            .await
527            .map_err(|_| {
528                HttpException::bad_request("Invalid JSON body")
529                    .with_optional_request_id(request_id.clone())
530            })?;
531
532        value
533            .validate()
534            .map_err(|errors| {
535                HttpException::bad_request_validation(errors)
536                    .with_optional_request_id(request_id)
537            })?;
538
539        Ok(Self(value))
540    }
541}