soph_server/request/
mod.rs

1use crate::{async_trait, traits::RequestTrait};
2
3#[cfg(feature = "request-validate")]
4pub use validate::Validate;
5
6#[cfg(feature = "request-validate")]
7mod validate;
8
9#[async_trait]
10impl RequestTrait for axum::extract::Request {
11    fn content_type(&self) -> mime::Mime {
12        if has_content_type(self.headers(), &mime::APPLICATION_WWW_FORM_URLENCODED) {
13            return mime::APPLICATION_WWW_FORM_URLENCODED;
14        }
15
16        if has_content_type(self.headers(), &mime::APPLICATION_JSON) {
17            return mime::APPLICATION_JSON;
18        }
19
20        if has_content_type(self.headers(), &mime::MULTIPART_FORM_DATA) {
21            return mime::MULTIPART_FORM_DATA;
22        }
23
24        mime::TEXT_PLAIN
25    }
26
27    fn header<K>(&self, key: K) -> Option<&str>
28    where
29        K: axum::http::header::AsHeaderName,
30    {
31        self.headers().get(key).and_then(|value| value.to_str().ok())
32    }
33
34    fn user_agent(&self) -> &str {
35        self.headers()
36            .get(axum::http::header::USER_AGENT)
37            .map_or("", |h| h.to_str().unwrap_or(""))
38    }
39
40    async fn path<T: serde::de::DeserializeOwned + Send + 'static>(
41        &mut self,
42    ) -> Result<T, axum::extract::rejection::PathRejection> {
43        use axum::{extract::Path, RequestExt};
44
45        let Path(path) = self.extract_parts::<Path<T>>().await?;
46
47        Ok(path)
48    }
49
50    #[cfg(feature = "request-query")]
51    async fn query<T: serde::de::DeserializeOwned + 'static>(
52        &mut self,
53    ) -> crate::ServerResult<T, axum::extract::rejection::QueryRejection> {
54        use axum::{extract::Query, RequestExt};
55
56        let Query(query) = self.extract_parts::<Query<T>>().await?;
57
58        Ok(query)
59    }
60
61    #[cfg(feature = "request-form")]
62    async fn form<T: serde::de::DeserializeOwned>(
63        self,
64    ) -> crate::ServerResult<T, axum::extract::rejection::FormRejection> {
65        use axum::{extract::FromRequest, Form};
66        let Form(payload) = Form::<T>::from_request(self, &()).await?;
67
68        Ok(payload)
69    }
70
71    #[cfg(feature = "request-json")]
72    async fn json<T: serde::de::DeserializeOwned>(
73        self,
74    ) -> crate::ServerResult<T, axum::extract::rejection::JsonRejection> {
75        use axum::{extract::FromRequest, Json};
76        let Json(payload) = Json::<T>::from_request(self, &()).await?;
77
78        Ok(payload)
79    }
80
81    #[cfg(feature = "request-multipart")]
82    async fn multipart<T: serde::de::DeserializeOwned>(self) -> crate::ServerResult<T> {
83        use axum::extract::{FromRequest, Multipart};
84        let mut multipart = Multipart::from_request(self, &()).await?;
85
86        let mut data = serde_json::Map::new();
87
88        while let Some(field) = multipart.next_field().await? {
89            if let Some(name) = field.name() {
90                if let Some(file_name) = field.file_name() {
91                    data.insert(name.to_owned(), file_name.into());
92                } else {
93                    data.insert(name.to_owned(), field.text().await?.into());
94                }
95            }
96        }
97
98        let payload = serde_json::from_value::<T>(data.into())?;
99
100        Ok(payload)
101    }
102
103    #[cfg(feature = "request-multipart")]
104    async fn file(self, key: &str) -> crate::ServerResult<Option<axum::body::Bytes>> {
105        use axum::extract::{FromRequest, Multipart};
106        let mut multipart = Multipart::from_request(self, &()).await?;
107
108        while let Some(field) = multipart.next_field().await? {
109            if let Some(name) = field.name() {
110                if field.file_name().is_some() && name == key {
111                    return Ok(Some(field.bytes().await?));
112                }
113            }
114        }
115
116        Ok(None)
117    }
118
119    #[cfg(feature = "request-validate")]
120    async fn validate<T: serde::de::DeserializeOwned + validator::Validate>(self) -> crate::ServerResult<T> {
121        let payload: T = if self.content_type() == mime::APPLICATION_JSON {
122            self.json::<T>().await?
123        } else if self.content_type() == mime::APPLICATION_WWW_FORM_URLENCODED {
124            self.form::<T>().await?
125        } else if self.content_type() == mime::MULTIPART_FORM_DATA {
126            self.multipart::<T>().await?
127        } else {
128            return Err(crate::error::Error::UnsupportedMediaType);
129        };
130
131        payload.validate()?;
132
133        Ok(payload)
134    }
135
136    #[cfg(feature = "request-id")]
137    fn id(&self) -> Option<&str> {
138        self.header(crate::config::X_REQUEST_ID)
139    }
140
141    #[cfg(feature = "request-auth")]
142    fn user(&self) -> Option<soph_auth::support::UserClaims> {
143        self.extensions().get::<soph_auth::support::UserClaims>().cloned()
144    }
145
146    #[cfg(feature = "request-auth")]
147    fn token(&self) -> crate::ServerResult<String> {
148        use crate::{error::Error, traits::ErrorTrait};
149        use soph_config::support::config;
150
151        let config = config().parse::<soph_auth::config::Auth>().map_err(Error::wrap)?;
152
153        match config
154            .jwt
155            .location
156            .as_ref()
157            .unwrap_or(&soph_auth::config::JwtLocation::Bearer)
158        {
159            soph_auth::config::JwtLocation::Query { name } => Ok(extract_token_from_query(name, self.uri())?),
160            soph_auth::config::JwtLocation::Cookie { name } => Ok(extract_token_from_cookie(
161                name,
162                &axum_extra::extract::cookie::CookieJar::from_headers(self.headers()),
163            )?),
164            soph_auth::config::JwtLocation::Bearer => Ok(extract_token_from_header(self.headers())
165                .map_err(|e| soph_auth::error::Error::Unauthorized(e.to_string()))?),
166        }
167    }
168}
169
170fn has_content_type(headers: &axum::http::header::HeaderMap, expected_content_type: &mime::Mime) -> bool {
171    let content_type = if let Some(content_type) = headers.get(axum::http::header::CONTENT_TYPE) {
172        content_type
173    } else {
174        return false;
175    };
176
177    let content_type = if let Ok(content_type) = content_type.to_str() {
178        content_type
179    } else {
180        return false;
181    };
182
183    content_type.starts_with(expected_content_type.as_ref())
184}
185
/// Extracts the credential from an `Authorization: Bearer <token>` header.
///
/// The scheme name is matched case-insensitively (RFC 7235 §2.1), and surrounding whitespace
/// around the token is trimmed.
///
/// # Errors
/// Returns `Unauthorized` when the header is missing, when `HeaderValue::to_str` fails, when
/// the scheme is not `Bearer`, or when the token is empty.
#[cfg(feature = "request-auth")]
fn extract_token_from_header(headers: &axum::http::header::HeaderMap) -> crate::ServerResult<String> {
    use axum::http::header::AUTHORIZATION;
    use soph_auth::error::Error;

    // `ok_or_else` so the error message is only formatted on the failure path.
    let value = headers
        .get(AUTHORIZATION)
        .ok_or_else(|| Error::Unauthorized(format!("header {} token not found", AUTHORIZATION)))?
        .to_str()
        .map_err(|err| Error::Unauthorized(err.to_string()))?;

    let token = value
        .split_once(' ')
        .filter(|(scheme, _)| scheme.eq_ignore_ascii_case("Bearer"))
        .map(|(_, token)| token.trim())
        // An empty bearer credential can never be a valid token.
        .filter(|token| !token.is_empty())
        .ok_or_else(|| Error::Unauthorized(format!("error strip {} value", AUTHORIZATION)))?;

    Ok(token.to_owned())
}
200
/// Returns the value of the cookie called `name` from `jar`.
///
/// # Errors
/// Returns `Unauthorized("token is not found")` when the jar has no such cookie.
#[cfg(feature = "request-auth")]
fn extract_token_from_cookie(name: &str, jar: &axum_extra::extract::cookie::CookieJar) -> crate::ServerResult<String> {
    use soph_auth::error::Error;

    // Read the value directly instead of formatting the cookie and stripping `name=`.
    // `Cookie`'s `Display` form can append attributes after the value, and formatting
    // allocates needlessly.
    let token = jar
        .get(name)
        .map(|cookie| cookie.value().to_owned())
        .ok_or_else(|| Error::Unauthorized("token is not found".to_string()))?;

    Ok(token)
}
213
/// Returns the query-string parameter called `name` from `uri`.
///
/// # Errors
/// Returns `Unauthorized` when the query string cannot be parsed into string pairs, or when
/// the parameter is absent.
#[cfg(feature = "request-auth")]
fn extract_token_from_query(name: &str, uri: &axum::http::Uri) -> crate::ServerResult<String> {
    use axum::extract::Query;
    use soph_auth::error::Error;
    use std::collections::HashMap;

    let Query(mut parameters): Query<HashMap<String, String>> =
        Query::try_from_uri(uri).map_err(|err| Error::Unauthorized(err.to_string()))?;

    // Move the value out of the locally owned map rather than cloning it, and only build the
    // error message when the parameter is actually missing.
    let token = parameters
        .remove(name)
        .ok_or_else(|| Error::Unauthorized(format!("`{name}` query parameter not found")))?;

    Ok(token)
}