rama_http/layer/validate_request/
accept_header.rs

1use super::ValidateRequest;
2use crate::{
3    dep::mime::{Mime, MimeIter},
4    header, Request, Response, StatusCode,
5};
6use rama_core::Context;
7use std::{fmt, marker::PhantomData, sync::Arc};
8
9/// Type that performs validation of the Accept header.
10pub struct AcceptHeader<ResBody = crate::Body> {
11    header_value: Arc<Mime>,
12    _ty: PhantomData<fn() -> ResBody>,
13}
14
15impl<ResBody> AcceptHeader<ResBody> {
16    /// Create a new `AcceptHeader`.
17    ///
18    /// # Panics
19    ///
20    /// Panics if `header_value` is not in the form: `type/subtype`, such as `application/json`
21    pub(super) fn new(header_value: &str) -> Self
22    where
23        ResBody: Default,
24    {
25        Self {
26            header_value: Arc::new(
27                header_value
28                    .parse::<Mime>()
29                    .expect("value is not a valid header value"),
30            ),
31            _ty: PhantomData,
32        }
33    }
34}
35
36impl<ResBody> Clone for AcceptHeader<ResBody> {
37    fn clone(&self) -> Self {
38        Self {
39            header_value: self.header_value.clone(),
40            _ty: PhantomData,
41        }
42    }
43}
44
45impl<ResBody> fmt::Debug for AcceptHeader<ResBody> {
46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47        f.debug_struct("AcceptHeader")
48            .field("header_value", &self.header_value)
49            .finish()
50    }
51}
52
53impl<S, B, ResBody> ValidateRequest<S, B> for AcceptHeader<ResBody>
54where
55    S: Clone + Send + Sync + 'static,
56    B: Send + Sync + 'static,
57    ResBody: Default + Send + 'static,
58{
59    type ResponseBody = ResBody;
60
61    async fn validate(
62        &self,
63        ctx: Context<S>,
64        req: Request<B>,
65    ) -> Result<(Context<S>, Request<B>), Response<Self::ResponseBody>> {
66        if !req.headers().contains_key(header::ACCEPT) {
67            return Ok((ctx, req));
68        }
69        if req
70            .headers()
71            .get_all(header::ACCEPT)
72            .into_iter()
73            .filter_map(|header| header.to_str().ok())
74            .any(|h| {
75                MimeIter::new(h)
76                    .map(|mim| {
77                        if let Ok(mim) = mim {
78                            let typ = self.header_value.type_();
79                            let subtype = self.header_value.subtype();
80                            match (mim.type_(), mim.subtype()) {
81                                (t, s) if t == typ && s == subtype => true,
82                                (t, mime::STAR) if t == typ => true,
83                                (mime::STAR, mime::STAR) => true,
84                                _ => false,
85                            }
86                        } else {
87                            false
88                        }
89                    })
90                    .reduce(|acc, mim| acc || mim)
91                    .unwrap_or(false)
92            })
93        {
94            return Ok((ctx, req));
95        }
96        let mut res = Response::new(ResBody::default());
97        *res.status_mut() = StatusCode::NOT_ACCEPTABLE;
98        Err(res)
99    }
100}