1use std::{
2 collections::HashMap,
3 net::{IpAddr, Ipv4Addr, SocketAddr},
4};
5
6use async_trait::async_trait;
7pub use axum::http::request::Parts;
8pub use axum::http::Request;
9use axum::{
10 body::Bytes,
11 extract::{
12 connect_info::ConnectInfo,
13 rejection::{ExtensionRejection, JsonRejection, PathRejection, QueryRejection},
14 FromRef, FromRequest, FromRequestParts, Path, Query,
15 },
16 Json,
17};
18use futures::StreamExt;
19use hyper::{HeaderMap, Method, Uri, Version};
20use serde::de::DeserializeOwned;
21use serde_json::Value;
22
23use crate::Body;
24
/// String-to-string map used to hold the extracted path parameters and query parameters.
type HashMapRequest = HashMap<String, String>;
26
/// Request data collected by this type's `FromRequest` extractor: path and query
/// parameters, the `ConnectInfo` lookup result, the request body bytes, the caller's
/// inner state, and copies of the request method, URI, version and headers.
#[derive(Debug)]
pub struct Context<InnerState = ()> {
    // Path parameters from axum's `Path` extractor; empty if that extraction failed.
    params_map: HashMapRequest,
    // Query-string parameters from axum's `Query` extractor; empty if that extraction failed.
    query_map: HashMapRequest,
    // Raw result of `ConnectInfo<SocketAddr>` extraction; kept as a `Result` so a missing
    // peer address does not reject the request.
    connect_info: Result<ConnectInfo<SocketAddr>, ExtensionRejection>,
    // Request body bytes as read by the extractor.
    bytes: Bytes,
    // State obtained from the router's outer state via `FromRef`.
    inner_state: InnerState,
    headers: HeaderMap,
    method: Method,
    uri: Uri,
    version: Version,
}
39
40impl<InnerState> Context<InnerState> {
44 pub fn headers(&self) -> &HeaderMap {
45 &self.headers
46 }
47 pub fn method(&self) -> &Method {
48 &self.method
49 }
50 pub fn version(&self) -> &Version {
51 &self.version
52 }
53 pub fn uri(&self) -> &Uri {
54 &self.uri
55 }
56
57 pub fn addr(&self) -> SocketAddr {
58 match self.connect_info {
59 Ok(ConnectInfo(addr)) => addr,
60 Err(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 8080),
61 }
62 }
63
64 pub fn ip(&self) -> IpAddr {
65 self.addr().ip()
66 }
67
68 pub fn body(&self) -> String {
69 String::from_utf8(self.bytes.to_vec()).expect("")
70 }
71 pub fn bytes(&self) -> &Bytes {
72 &self.bytes
73 }
74 pub fn state(&self) -> &InnerState {
75 &self.inner_state
76 }
77
78 pub async fn parse_params<T: DeserializeOwned>(&self) -> Result<Json<T>, JsonRejection> {
79 let value = match serde_json::to_string(&self.params_map) {
80 Ok(data) => data,
81 Err(_) => String::new(),
82 };
83 let request = Request::builder()
84 .header("Content-Type", "application/json")
85 .body(Body::from(value));
86
87 let request = match request {
88 Ok(value) => value,
89 Err(_) => Request::default(),
90 };
91
92 Json::from_request(request, &()).await
93 }
94 pub fn all_params(&self) -> &HashMapRequest {
95 &self.params_map
96 }
97 pub fn params(&self, key: &'static str) -> String {
98 match self.params_map.get(key) {
99 Some(value) => value.clone(),
100 None => String::new(),
101 }
102 }
103 pub async fn parse_query<T: DeserializeOwned>(&self) -> Result<Json<T>, JsonRejection> {
104 let value = match serde_json::to_string(&self.query_map) {
105 Ok(data) => data,
106 Err(_) => String::new(),
107 };
108 let request = Request::builder()
109 .header("Content-Type", "application/json")
110 .body(Body::from(value));
111
112 let request = match request {
113 Ok(value) => value,
114 Err(_) => Request::default(),
115 };
116
117 Json::from_request(request, &()).await
118 }
119 pub fn query(&self, key: &'static str) -> String {
120 match self.query_map.get(key) {
121 Some(value) => value.clone(),
122 None => String::new(),
123 }
124 }
125 pub fn all_query(&self) -> &HashMapRequest {
126 &self.query_map
127 }
128
129 pub async fn payload<T: DeserializeOwned + Default>(&self) -> Result<Json<T>, JsonRejection> {
130 let request = Request::builder()
132 .header("Content-Type", "application/json")
133 .body(Body::from(self.bytes.clone()));
134
135 let request = match request {
136 Ok(value) => value,
137 Err(_) => Request::default(),
138 };
139
140 Json::from_request(request, &()).await
141 }
142
143 pub fn json(&self, payload: Value) -> Json<Value> {
144 Json(payload)
145 }
146
147 pub fn send(value: &str) -> &str {
148 value
149 }
150}
151
152#[async_trait]
153impl<OuterState, InnerState> FromRequest<OuterState, Body> for Context<InnerState>
154where
155 OuterState: Send + Sync + 'static,
156 InnerState: FromRef<OuterState> + Send + Sync,
157{
158 type Rejection = JsonRejection;
159
160 async fn from_request(
161 req: axum::http::Request<Body>,
162 state: &OuterState,
163 ) -> Result<Self, Self::Rejection> {
164 let inner_state = InnerState::from_ref(state);
165 let headers = req.headers().clone();
166 let method = req.method().clone();
167 let uri = req.uri().clone();
168 let version = req.version();
169 let (parts, body) = &mut req.into_parts();
170 let mut params_map = HashMap::new();
171 let mut query_map = HashMap::new();
172 let result_params: Result<Path<HashMapRequest>, PathRejection> =
173 Path::from_request_parts(parts, &()).await;
174
175 let connect_info: Result<ConnectInfo<SocketAddr>, ExtensionRejection> =
176 ConnectInfo::from_request_parts(parts, state).await;
177
178 if let Ok(params) = result_params {
179 match params {
180 Path(parse_params) => {
181 params_map = parse_params;
182 }
183 }
184 }
185
186 let result_query: Result<Query<HashMapRequest>, QueryRejection> =
187 Query::from_request_parts(parts, &()).await;
188 if let Ok(params) = result_query {
189 match params {
190 Query(parse_params) => {
191 query_map = parse_params;
192 }
193 }
194 }
195
196 let mut bytes = Bytes::new();
197 let n = body.map(|x| {
198 if let Ok(value) = x {
199 bytes = value
200 }
201 });
202 n.collect::<Vec<_>>().await;
204 Ok(Context {
205 version,
206 connect_info,
207 headers,
208 method,
209 uri,
210 bytes,
211 inner_state,
212 params_map,
213 query_map,
214 })
215 }
216}
217
/// Request data collected by this type's `FromRequestParts` extractor: path and query
/// parameters, the `ConnectInfo` lookup result, the caller's inner state, and copies of
/// the request method, URI, version and headers. Unlike `Context`, it holds no body.
#[derive(Debug)]
pub struct ContextPart<InnerState = ()> {
    // Path parameters from axum's `Path` extractor; empty if that extraction failed.
    params_map: HashMapRequest,
    // Raw result of `ConnectInfo<SocketAddr>` extraction; kept as a `Result` so a missing
    // peer address does not reject the request.
    connect_info: Result<ConnectInfo<SocketAddr>, ExtensionRejection>,
    // Query-string parameters from axum's `Query` extractor; empty if that extraction failed.
    query_map: HashMapRequest,
    // State obtained from the router's outer state via `FromRef`.
    inner_state: InnerState,
    headers: HeaderMap,
    method: Method,
    uri: Uri,
    version: Version,
}
229
230impl<InnerState> ContextPart<InnerState> {
234 pub fn headers(&self) -> &HeaderMap {
235 &self.headers
236 }
237 pub fn method(&self) -> &Method {
238 &self.method
239 }
240 pub fn version(&self) -> &Version {
241 &self.version
242 }
243 pub fn uri(&self) -> &Uri {
244 &self.uri
245 }
246 pub fn state(&self) -> &InnerState {
247 &self.inner_state
248 }
249
250 pub fn addr(&self) -> SocketAddr {
251 match self.connect_info {
252 Ok(ConnectInfo(addr)) => addr,
253 Err(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 8080),
254 }
255 }
256
257 pub fn ip(&self) -> IpAddr {
258 self.addr().ip()
259 }
260
261 pub async fn parse_params<T: DeserializeOwned>(&self) -> Result<Json<T>, JsonRejection> {
262 let value = match serde_json::to_string(&self.params_map) {
263 Ok(data) => data,
264 Err(_) => String::new(),
265 };
266 let request = Request::builder()
267 .header("Content-Type", "application/json")
268 .body(Body::from(value));
269
270 let request = match request {
271 Ok(value) => value,
272 Err(_) => Request::default(),
273 };
274
275 Json::from_request(request, &()).await
276 }
277 pub fn all_params(&self) -> &HashMapRequest {
278 &self.params_map
279 }
280 pub fn params(&self, key: &'static str) -> String {
281 match self.params_map.get(key) {
282 Some(value) => value.clone(),
283 None => String::new(),
284 }
285 }
286 pub async fn parse_query<T: DeserializeOwned>(&self) -> Result<Json<T>, JsonRejection> {
287 let value = match serde_json::to_string(&self.query_map) {
288 Ok(data) => data,
289 Err(_) => String::new(),
290 };
291 let request = Request::builder()
292 .header("Content-Type", "application/json")
293 .body(Body::from(value));
294
295 let request = match request {
296 Ok(value) => value,
297 Err(_) => Request::default(),
298 };
299
300 Json::from_request(request, &()).await
301 }
302 pub fn query(&self, key: &'static str) -> String {
303 match self.query_map.get(key) {
304 Some(value) => value.clone(),
305 None => String::new(),
306 }
307 }
308 pub fn all_query(&self) -> &HashMapRequest {
309 &self.query_map
310 }
311
312 pub fn json(&self, payload: Value) -> Json<Value> {
313 Json(payload)
314 }
315
316 pub fn send(value: &str) -> &str {
317 value
318 }
319}
320
321#[async_trait]
322impl<OuterState, InnerState> FromRequestParts<OuterState> for ContextPart<InnerState>
323where
324 OuterState: Send + Sync + 'static,
325 InnerState: FromRef<OuterState> + Send + Sync,
326{
327 type Rejection = JsonRejection;
328
329 async fn from_request_parts(
330 parts: &mut Parts,
331 state: &OuterState,
332 ) -> Result<Self, Self::Rejection> {
333 let inner_state = InnerState::from_ref(state);
334 let headers = parts.headers.clone();
335 let method = parts.method.clone();
336 let uri = parts.uri.clone();
337 let version = parts.version;
338 let mut params_map = HashMap::new();
339 let mut query_map = HashMap::new();
340 let result_params: Result<Path<HashMapRequest>, PathRejection> =
341 Path::from_request_parts(parts, &()).await;
342
343 let connect_info: Result<ConnectInfo<SocketAddr>, ExtensionRejection> =
344 ConnectInfo::from_request_parts(parts, state).await;
345
346 if let Ok(params) = result_params {
347 match params {
348 Path(parse_params) => {
349 params_map = parse_params;
350 }
351 }
352 }
353
354 let result_query: Result<Query<HashMapRequest>, QueryRejection> =
355 Query::from_request_parts(parts, &()).await;
356 if let Ok(params) = result_query {
357 match params {
358 Query(parse_params) => {
359 query_map = parse_params;
360 }
361 }
362 }
363
364 Ok(ContextPart {
365 version,
366 connect_info,
367 headers,
368 method,
369 uri,
370 inner_state,
371 params_map,
372 query_map,
373 })
374 }
375}