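//! `nestforge_core/request.rs`: request extraction primitives — the `Pipe` and
//! `RequestDecorator` traits, plus extractors for request ids, route parameters,
//! query strings, headers, cookies, and JSON bodies.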

1use std::{ops::Deref, sync::Arc};
2
3use axum::{
4    extract::{FromRequest, FromRequestParts, Path, Query as AxumQuery},
5    http::{request::Parts, Extensions, HeaderMap},
6};
7use serde::de::DeserializeOwned;
8
9use crate::{HttpException, RequestContext, Validate};
10
11/**
12 * Pipe Transform Trait
13 *
14 * A transformation pipe transforms input data (such as query parameters or request bodies)
15 * into a desired output type. Pipes can also perform validation as part of the transformation.
16 *
17 * # Type Parameters
18 * - `Input`: The input type to transform from
19 *
20 * # Associated Types
21 * - `Output`: The resulting type after transformation
22 *
23 * # Example
24 * ```rust
25 * struct ParseIntPipe;
26 * impl Pipe<String> for ParseIntPipe {
27 *     type Output = i32;
28 *     fn transform(value: String, ctx: &RequestContext) -> Result<Self::Output, HttpException> {
29 *         value.parse().map_err(|_| HttpException::bad_request("Invalid integer"))
30 *     }
31 * }
32 * ```
33 */
34pub trait Pipe<Input>: Send + Sync + 'static {
35    type Output;
36
37    fn transform(value: Input, ctx: &RequestContext) -> Result<Self::Output, HttpException>;
38}
39
40/**
41 * Request Decorator Trait
42 *
43 * A custom request decorator (extractor) defines custom logic for extracting
44 * data from an incoming HTTP request. This is the foundation for creating
45 * custom extractors beyond the built-in ones.
46 *
47 * # Type Parameters
48 * - `Self`: The decorator type implementing the trait
49 *
50 * # Associated Types
51 * - `Output`: The type that gets extracted from the request
52 *
53 * # Example
54 * ```rust
55 * struct UserAgent;
56 * impl RequestDecorator for UserAgent {
57 *     type Output = String;
58 *     fn extract(ctx: &RequestContext, parts: &Parts) -> Result<Self::Output, HttpException> {
59 *         parts.headers
60 *             .get("user-agent")
61 *             .and_then(|h| h.to_str().ok())
62 *             .map(|s| s.to_string())
63 *             .ok_or_else(|| HttpException::bad_request("Missing User-Agent"))
64 *     }
65 * }
66 * ```
67 */
68pub trait RequestDecorator: Send + Sync + 'static {
69    type Output: Send + 'static;
70
71    fn extract(ctx: &RequestContext, parts: &Parts) -> Result<Self::Output, HttpException>;
72}
73
/// Identifier of the current request.
///
/// Read from the request extensions, either directly as a handler argument or
/// through `request_id_from_extensions`. It is backed by an `Arc<str>`, so
/// cloning bumps a reference count instead of copying the text. Useful for
/// logging, tracing, and tagging error responses.
#[derive(Debug, Clone)]
pub struct RequestId(pub Arc<str>);

impl RequestId {
    /// Creates a request id from any string-like value.
    pub fn new(value: impl Into<String>) -> Self {
        let owned: String = value.into();
        Self(Arc::from(owned))
    }

    /// Consumes the id and returns its text as an owned `String`.
    pub fn into_inner(self) -> String {
        String::from(self.value())
    }

    /// Borrows the id's text.
    pub fn value(&self) -> &str {
        &self.0
    }
}

impl Deref for RequestId {
    type Target = str;

    fn deref(&self) -> &str {
        self.value()
    }
}
105
106pub fn request_id_from_extensions(extensions: &Extensions) -> Option<String> {
107    extensions
108        .get::<RequestId>()
109        .map(|request_id| request_id.value().to_string())
110}
111
112impl<S> FromRequestParts<S> for RequestId
113where
114    S: Send + Sync,
115{
116    type Rejection = HttpException;
117
118    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
119        parts
120            .extensions
121            .get::<RequestId>()
122            .cloned()
123            .ok_or_else(|| HttpException::internal_server_error("Request id not available"))
124    }
125}
126
/// Route-parameter extractor, used in place of `axum::extract::Path`.
///
/// Usage: `fn get_user(id: Param<u64>)`
#[derive(Debug, Clone, Copy)]
pub struct Param<T>(pub T);

impl<T> Deref for Param<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Param(inner) = self;
        inner
    }
}
142
143pub struct Decorated<T>
144where
145    T: RequestDecorator,
146{
147    value: T::Output,
148    _marker: std::marker::PhantomData<T>,
149}
150
151impl<T> Deref for Decorated<T>
152where
153    T: RequestDecorator,
154{
155    type Target = T::Output;
156
157    fn deref(&self) -> &Self::Target {
158        &self.value
159    }
160}
161
162impl<T> Decorated<T>
163where
164    T: RequestDecorator,
165{
166    pub fn into_inner(self) -> T::Output {
167        self.value
168    }
169}
170
171impl<S, T> FromRequestParts<S> for Decorated<T>
172where
173    S: Send + Sync,
174    T: RequestDecorator,
175{
176    type Rejection = HttpException;
177
178    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
179        let ctx = RequestContext::from_parts(parts);
180        let value = T::extract(&ctx, parts)?;
181
182        Ok(Self {
183            value,
184            _marker: std::marker::PhantomData,
185        })
186    }
187}
188
189impl<T> Param<T> {
190    pub fn into_inner(self) -> T {
191        self.0
192    }
193
194    pub fn value(self) -> T {
195        self.0
196    }
197}
198
199impl<S, T> FromRequestParts<S> for Param<T>
200where
201    S: Send + Sync,
202    T: DeserializeOwned + Send + 'static,
203{
204    type Rejection = HttpException;
205
206    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
207        let request_id = request_id_from_extensions(&parts.extensions);
208        let Path(value) = Path::<T>::from_request_parts(parts, state)
209            .await
210            .map_err(|_| {
211                HttpException::bad_request("Invalid route parameter")
212                    .with_optional_request_id(request_id)
213            })?;
214
215        Ok(Self(value))
216    }
217}
218
219pub struct PipedParam<T, P>
220where
221    P: Pipe<T>,
222{
223    value: P::Output,
224    _marker: std::marker::PhantomData<(T, P)>,
225}
226
227impl<T, P> Deref for PipedParam<T, P>
228where
229    P: Pipe<T>,
230{
231    type Target = P::Output;
232
233    fn deref(&self) -> &Self::Target {
234        &self.value
235    }
236}
237
238impl<T, P> PipedParam<T, P>
239where
240    P: Pipe<T>,
241{
242    pub fn into_inner(self) -> P::Output {
243        self.value
244    }
245}
246
247impl<S, T, P> FromRequestParts<S> for PipedParam<T, P>
248where
249    S: Send + Sync,
250    T: DeserializeOwned + Send + 'static,
251    P: Pipe<T>,
252{
253    type Rejection = HttpException;
254
255    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
256        let Path(value) = Path::<T>::from_request_parts(parts, state)
257            .await
258            .map_err(|_| {
259                HttpException::bad_request("Invalid route parameter")
260                    .with_optional_request_id(request_id_from_extensions(&parts.extensions))
261            })?;
262        let ctx = RequestContext::from_parts(parts);
263        let value = P::transform(value, &ctx)?;
264
265        Ok(Self {
266            value,
267            _marker: std::marker::PhantomData,
268        })
269    }
270}
271
/// Query-string extractor, used in place of `axum::extract::Query`.
///
/// Usage: `fn search(q: Query<SearchDto>)`
#[derive(Debug, Clone)]
pub struct Query<T>(pub T);

impl<T> Deref for Query<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Query(inner) = self;
        inner
    }
}

impl<T> Query<T> {
    /// Unwraps the deserialized query parameters.
    pub fn into_inner(self) -> T {
        let Query(inner) = self;
        inner
    }

    /// Unwraps the deserialized query parameters; same as [`Query::into_inner`].
    pub fn value(self) -> T {
        self.into_inner()
    }
}
297
298impl<S, T> FromRequestParts<S> for Query<T>
299where
300    S: Send + Sync,
301    T: DeserializeOwned + Send + 'static,
302{
303    type Rejection = HttpException;
304
305    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
306        let request_id = request_id_from_extensions(&parts.extensions);
307        let AxumQuery(value) = AxumQuery::<T>::from_request_parts(parts, state)
308            .await
309            .map_err(|_| {
310                HttpException::bad_request("Invalid query parameters")
311                    .with_optional_request_id(request_id)
312            })?;
313
314        Ok(Self(value))
315    }
316}
317
318pub struct PipedQuery<T, P>
319where
320    P: Pipe<T>,
321{
322    value: P::Output,
323    _marker: std::marker::PhantomData<(T, P)>,
324}
325
326impl<T, P> Deref for PipedQuery<T, P>
327where
328    P: Pipe<T>,
329{
330    type Target = P::Output;
331
332    fn deref(&self) -> &Self::Target {
333        &self.value
334    }
335}
336
337impl<T, P> PipedQuery<T, P>
338where
339    P: Pipe<T>,
340{
341    pub fn into_inner(self) -> P::Output {
342        self.value
343    }
344}
345
346impl<S, T, P> FromRequestParts<S> for PipedQuery<T, P>
347where
348    S: Send + Sync,
349    T: DeserializeOwned + Send + 'static,
350    P: Pipe<T>,
351{
352    type Rejection = HttpException;
353
354    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
355        let AxumQuery(value) = AxumQuery::<T>::from_request_parts(parts, state)
356            .await
357            .map_err(|_| {
358                HttpException::bad_request("Invalid query parameters")
359                    .with_optional_request_id(request_id_from_extensions(&parts.extensions))
360            })?;
361        let ctx = RequestContext::from_parts(parts);
362        let value = P::transform(value, &ctx)?;
363
364        Ok(Self {
365            value,
366            _marker: std::marker::PhantomData,
367        })
368    }
369}
370
371/// Extractor for HTTP headers.
372#[derive(Debug, Clone)]
373pub struct Headers(pub HeaderMap);
374
375impl Headers {
376    pub fn get(&self, name: &str) -> Option<&axum::http::HeaderValue> {
377        self.0.get(name)
378    }
379}
380
381impl Deref for Headers {
382    type Target = HeaderMap;
383
384    fn deref(&self) -> &Self::Target {
385        &self.0
386    }
387}
388
389impl<S> FromRequestParts<S> for Headers
390where
391    S: Send + Sync,
392{
393    type Rejection = HttpException;
394
395    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
396        Ok(Self(parts.headers.clone()))
397    }
398}
399
/// Extractor for HTTP cookies.
///
/// Holds the name/value pairs parsed from the request's `Cookie` header,
/// keyed by cookie name.
#[derive(Debug, Clone, Default)]
pub struct Cookies {
    values: std::collections::BTreeMap<String, String>,
}

impl Cookies {
    /// Builds a cookie set from `(name, value)` pairs. When a name repeats,
    /// the last pair wins.
    pub fn new<I, K, V>(pairs: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: Into<String>,
        V: Into<String>,
    {
        let mut values = std::collections::BTreeMap::new();
        for (key, value) in pairs {
            values.insert(key.into(), value.into());
        }
        Self { values }
    }

    /// Returns the value of cookie `name`, if present.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.values.get(name).map(|value| value.as_str())
    }
}
427
428impl<S> FromRequestParts<S> for Cookies
429where
430    S: Send + Sync,
431{
432    type Rejection = HttpException;
433
434    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
435        let mut cookies = std::collections::BTreeMap::new();
436        if let Some(header) = parts.headers.get(axum::http::header::COOKIE) {
437            if let Ok(raw) = header.to_str() {
438                for pair in raw.split(';') {
439                    let trimmed = pair.trim();
440                    if let Some((name, value)) = trimmed.split_once('=') {
441                        cookies.insert(name.trim().to_string(), value.trim().to_string());
442                    }
443                }
444            }
445        }
446
447        Ok(Self { values: cookies })
448    }
449}
450
/// JSON request-body extractor, used in place of `axum::extract::Json`.
///
/// Usage: `fn create_user(body: Body<CreateUserDto>)`
///
/// Derives `Debug` and `Clone` (when `T` implements them), matching the other
/// extractor wrappers [`Param`] and [`Query`].
#[derive(Debug, Clone)]
pub struct Body<T>(pub T);

impl<T> Deref for Body<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<T> Body<T> {
    /// Unwraps the deserialized body.
    pub fn into_inner(self) -> T {
        self.0
    }

    /// Unwraps the deserialized body; same as [`Body::into_inner`].
    pub fn value(self) -> T {
        self.0
    }
}
475
476/*
477Extract Body<T> from JSON request body.
478*/
479impl<S, T> FromRequest<S> for Body<T>
480where
481    S: Send + Sync,
482    T: DeserializeOwned + Send + 'static,
483{
484    type Rejection = HttpException;
485
486    async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
487        let request_id = request_id_from_extensions(req.extensions());
488        let axum::Json(value) = axum::Json::<T>::from_request(req, state)
489            .await
490            .map_err(|_| {
491                HttpException::bad_request("Invalid JSON body").with_optional_request_id(request_id)
492            })?;
493
494        Ok(Self(value))
495    }
496}
497
498pub struct PipedBody<T, P>
499where
500    P: Pipe<T>,
501{
502    value: P::Output,
503    _marker: std::marker::PhantomData<(T, P)>,
504}
505
506impl<T, P> Deref for PipedBody<T, P>
507where
508    P: Pipe<T>,
509{
510    type Target = P::Output;
511
512    fn deref(&self) -> &Self::Target {
513        &self.value
514    }
515}
516
517impl<T, P> PipedBody<T, P>
518where
519    P: Pipe<T>,
520{
521    pub fn into_inner(self) -> P::Output {
522        self.value
523    }
524}
525
526impl<S, T, P> FromRequest<S> for PipedBody<T, P>
527where
528    S: Send + Sync,
529    T: DeserializeOwned + Send + 'static,
530    P: Pipe<T>,
531{
532    type Rejection = HttpException;
533
534    async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
535        let ctx = RequestContext::from_request(&req);
536        let axum::Json(value) = axum::Json::<T>::from_request(req, state)
537            .await
538            .map_err(|_| {
539                HttpException::bad_request("Invalid JSON body")
540                    .with_optional_request_id(ctx.request_id.clone())
541            })?;
542        let value = P::transform(value, &ctx)?;
543
544        Ok(Self {
545            value,
546            _marker: std::marker::PhantomData,
547        })
548    }
549}
550
/// A wrapper that validates the request body.
///
/// Requires the inner type to implement `nestforge::Validate`. Malformed JSON
/// is rejected with a bad-request "Invalid JSON body" error. If validation
/// fails, it returns `400 Bad Request` with validation details.
///
/// Derives `Debug` and `Clone` (when `T` implements them), matching the other
/// extractor wrappers [`Param`] and [`Query`].
#[derive(Debug, Clone)]
pub struct ValidatedBody<T>(pub T);

impl<T> Deref for ValidatedBody<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<T> ValidatedBody<T> {
    /// Unwraps the validated body.
    pub fn into_inner(self) -> T {
        self.0
    }

    /// Unwraps the validated body; same as [`ValidatedBody::into_inner`].
    pub fn value(self) -> T {
        self.0
    }
}
574
575impl<S, T> FromRequest<S> for ValidatedBody<T>
576where
577    S: Send + Sync,
578    T: DeserializeOwned + Validate + Send + 'static,
579{
580    type Rejection = HttpException;
581
582    async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
583        let request_id = request_id_from_extensions(req.extensions());
584        let axum::Json(value) = axum::Json::<T>::from_request(req, state)
585            .await
586            .map_err(|_| {
587                HttpException::bad_request("Invalid JSON body")
588                    .with_optional_request_id(request_id.clone())
589            })?;
590
591        value.validate().map_err(|errors| {
592            HttpException::bad_request_validation(errors).with_optional_request_id(request_id)
593        })?;
594
595        Ok(Self(value))
596    }
597}