hyper_method_override_middleware/
lib.rs1use hyper::{service::Service, Method, Request};
17use std::borrow::Borrow;
18use std::task::{Context, Poll};
19use url::form_urlencoded;
20
/// Service middleware that may rewrite the HTTP method of an incoming request
/// before handing it to the wrapped service.
///
/// Only `POST` requests whose query string carries a `_method` parameter naming
/// `DELETE`, `PATCH` or `PUT` are rewritten; every other request is forwarded
/// untouched.
#[derive(Clone, Debug)]
pub struct MethodOverrideMiddleware<T> {
    /// The wrapped service; it receives every request after the override step.
    inner_service: T,
}
25
26impl<T> MethodOverrideMiddleware<T> {
27 pub fn new(inner_service: T) -> Self {
28 Self { inner_service }
29 }
30}
31
32impl<InnerService, Body> Service<Request<Body>> for MethodOverrideMiddleware<InnerService>
33where
34 InnerService: Service<Request<Body>>,
35{
36 type Response = InnerService::Response;
37 type Error = InnerService::Error;
38 type Future = InnerService::Future;
39
40 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
41 self.inner_service.poll_ready(cx)
42 }
43
44 fn call(&mut self, mut req: Request<Body>) -> Self::Future {
45 if let Some(new_method) = override_method(&req) {
46 *req.method_mut() = new_method;
47 }
48 self.inner_service.call(req)
49 }
50}
51
52fn override_method<Body>(req: &Request<Body>) -> Option<Method> {
53 if req.method() != &Method::POST {
54 return None;
55 }
56
57 form_urlencoded::parse(req.uri().query().unwrap_or("").as_bytes())
58 .find(|(param_name, _)| param_name == "_method")
59 .and_then(|(_, method)| match method.borrow() {
60 "DELETE" => Some(Method::DELETE),
61 "PATCH" => Some(Method::PATCH),
62 "PUT" => Some(Method::PUT),
63 _ => None,
64 })
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use hyper::service::{make_service_fn, service_fn};
    use hyper::{Body, Request, Response, Server};
    use std::convert::Infallible;
    use std::net::SocketAddr;

    /// Test handler: replies with the `Debug` form of the method it received,
    /// i.e. the method as seen after the middleware ran.
    async fn handle(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let body = format!("{:?}", req.method()).into();
        Ok(Response::new(body))
    }

    /// Sends a body-less `method` request to `http://{addr}{path_and_query}`
    /// and returns the response body as text. Panics on any transport error.
    async fn send(addr: SocketAddr, method: Method, path_and_query: &str) -> String {
        let url = reqwest::Url::parse(&format!("http://{}{}", addr, path_and_query)).unwrap();
        reqwest::Client::new()
            .execute(reqwest::Request::new(method, url))
            .await
            .unwrap()
            .text()
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn override_test() {
        // Port 0 lets the OS pick a free port, so the test cannot collide with
        // another process or a concurrent test run on a fixed port.
        let bind_addr: SocketAddr = ([127, 0, 0, 1], 0).into();
        let server = Server::bind(&bind_addr).serve(make_service_fn(|_| async {
            let service = MethodOverrideMiddleware::new(service_fn(handle));
            Ok::<_, hyper::Error>(service)
        }));
        // Capture the real address before the server is moved into the task.
        let addr = server.local_addr();
        tokio::spawn(server);

        // Requests without an override reach the handler unchanged.
        assert_eq!(send(addr, Method::GET, "/").await, "GET");
        assert_eq!(send(addr, Method::PUT, "/").await, "PUT");
        assert_eq!(send(addr, Method::POST, "/").await, "POST");
        assert_eq!(send(addr, Method::PATCH, "/").await, "PATCH");

        // POST with an allowed `_method` is rewritten.
        assert_eq!(
            send(addr, Method::POST, "/?a=1&b=2&_method=PATCH").await,
            "PATCH"
        );
        assert_eq!(
            send(addr, Method::POST, "/?a=1&b=2&_method=PUT").await,
            "PUT"
        );
        assert_eq!(
            send(addr, Method::POST, "/?a=1&b=2&_method=DELETE").await,
            "DELETE"
        );

        // POST with a `_method` outside the allow-list stays POST.
        assert_eq!(
            send(addr, Method::POST, "/?a=1&b=2&_method=GET").await,
            "POST"
        );
        assert_eq!(
            send(addr, Method::POST, "/?_method=OPTIONS").await,
            "POST"
        );

        // Non-POST requests ignore `_method` entirely.
        assert_eq!(
            send(addr, Method::GET, "/?_method=PATCH").await,
            "GET"
        );
        assert_eq!(
            send(addr, Method::DELETE, "/?_method=PUT").await,
            "DELETE"
        );
        assert_eq!(
            send(addr, Method::PATCH, "/?_method=DELETE").await,
            "PATCH"
        );
    }
}