jsonrpsee_server/middleware/http/
proxy_get_request.rs1use crate::transport::http;
30use crate::{HttpBody, HttpRequest, HttpResponse};
31use futures_util::{FutureExt, TryFutureExt};
32use http_body_util::BodyExt;
33use hyper::body::Bytes;
34use hyper::header::{ACCEPT, CONTENT_TYPE};
35use hyper::http::HeaderValue;
36use hyper::{Method, StatusCode, Uri};
37use jsonrpsee_core::BoxError;
38use jsonrpsee_types::{ErrorCode, ErrorObject, Id, Request};
39use std::collections::HashMap;
40use std::future::Future;
41use std::pin::Pin;
42use std::str::FromStr;
43use std::sync::Arc;
44use std::task::{Context, Poll};
45use tower::{Layer, Service};
46
47#[derive(Debug, thiserror::Error)]
49pub enum ProxyGetRequestError {
50 #[error("ProxyGetRequestLayer path must be unique, got duplicated `{0}`")]
52 DuplicatedPath(String),
53 #[error("ProxyGetRequestLayer path must start with `/`, got `{0}`")]
55 InvalidPath(String),
56}
57
/// Tower layer that turns `GET` requests on configured paths into JSON-RPC method calls.
///
/// Build it with [`ProxyGetRequestLayer::new`]; applying it wraps a service in [`ProxyGetRequest`].
#[derive(Debug, Clone)]
pub struct ProxyGetRequestLayer {
    /// Maps an HTTP path to the JSON-RPC method it proxies to. Held in an `Arc` so every
    /// service produced by this layer shares one table instead of copying it.
    methods: Arc<HashMap<String, String>>,
}
67
68impl ProxyGetRequestLayer {
69 pub fn new<P, M>(pairs: impl IntoIterator<Item = (P, M)>) -> Result<Self, ProxyGetRequestError>
74 where
75 P: Into<String>,
76 M: Into<String>,
77 {
78 let mut methods = HashMap::new();
79
80 for (path, method) in pairs {
81 let path = path.into();
82 let method = method.into();
83
84 if !path.starts_with('/') {
85 return Err(ProxyGetRequestError::InvalidPath(path));
86 }
87
88 if let Some(path) = methods.insert(path, method) {
89 return Err(ProxyGetRequestError::DuplicatedPath(path));
90 }
91 }
92
93 Ok(Self { methods: Arc::new(methods) })
94 }
95}
96
97impl<S> Layer<S> for ProxyGetRequestLayer {
98 type Service = ProxyGetRequest<S>;
99
100 fn layer(&self, inner: S) -> Self::Service {
101 ProxyGetRequest { inner, methods: self.methods.clone() }
102 }
103}
104
/// Service that rewrites `GET` requests on configured paths into JSON-RPC calls.
///
/// A `GET` to a registered path is sent to `inner` as a `POST` carrying a JSON-RPC request for
/// the mapped method. If the reply contains a `result`, only that value is returned as the body;
/// otherwise a `500 Internal Server Error` response is produced. All other requests are forwarded
/// to `inner` as-is, with only the body type adapted.
#[derive(Debug, Clone)]
pub struct ProxyGetRequest<S> {
    /// Wrapped service that handles the (possibly rewritten) request.
    inner: S,
    /// Maps an HTTP path to the JSON-RPC method it proxies to.
    methods: Arc<HashMap<String, String>>,
}
123
124impl<S, B> Service<HttpRequest<B>> for ProxyGetRequest<S>
125where
126 S: Service<HttpRequest, Response = HttpResponse>,
127 S::Response: 'static,
128 S::Error: Into<BoxError> + 'static,
129 S::Future: Send + 'static,
130 B: http_body::Body<Data = Bytes> + Send + 'static,
131 B::Data: Send,
132 B::Error: Into<BoxError>,
133{
134 type Response = S::Response;
135 type Error = BoxError;
136 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
137
138 #[inline]
139 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
140 self.inner.poll_ready(cx).map_err(Into::into)
141 }
142
143 fn call(&mut self, mut req: HttpRequest<B>) -> Self::Future {
144 let path = req.uri().path();
145 let method = self.methods.get(path);
146
147 match (method, req.method()) {
148 (Some(method), &Method::GET) => {
150 *req.method_mut() = Method::POST;
152 *req.uri_mut() = if let Some(query) = req.uri().query() {
154 Uri::from_str(&format!("/?{}", query)).expect("The query comes from a valid URI; qed")
155 } else {
156 Uri::from_static("/")
157 };
158 req.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
160 req.headers_mut().insert(ACCEPT, HeaderValue::from_static("application/json"));
161
162 let bytes =
164 serde_json::to_vec(&Request::borrowed(method, None, Id::Number(0))).expect("Valid request; qed");
165 let req = req.map(|_| HttpBody::from(bytes));
166
167 let fut = self.inner.call(req);
169
170 async move {
171 let res = fut.await.map_err(Into::into)?;
172
173 let (parts, body) = res.into_parts();
174 let mut body = http_body_util::BodyStream::new(body);
175 let mut bytes = Vec::new();
176
177 while let Some(frame) = body.frame().await {
178 let data = frame?.into_data().map_err(|e| format!("{e:?}"))?;
179 bytes.extend(data);
180 }
181
182 #[derive(serde::Deserialize)]
183 struct SuccessResponse<'a> {
184 #[serde(borrow)]
185 result: &'a serde_json::value::RawValue,
186 }
187
188 let mut response = if let Ok(payload) = serde_json::from_slice::<SuccessResponse>(&bytes) {
189 http::response::ok_response(payload.result.to_string())
190 } else {
191 internal_proxy_error(&bytes)
192 };
193
194 response.extensions_mut().extend(parts.extensions);
195
196 Ok(response)
197 }
198 .boxed()
199 }
200 _ => {
202 let req = req.map(HttpBody::new);
203 self.inner.call(req).map_err(Into::into).boxed()
204 }
205 }
206 }
207}
208
209fn internal_proxy_error(bytes: &[u8]) -> HttpResponse {
210 #[derive(serde::Deserialize)]
211 struct ErrorResponse<'a> {
212 #[serde(borrow)]
213 error: ErrorObject<'a>,
214 }
215
216 let error = serde_json::from_slice::<ErrorResponse>(bytes)
217 .map(|payload| payload.error)
218 .unwrap_or_else(|_| ErrorObject::from(ErrorCode::InternalError));
219
220 http::response::from_template(
221 StatusCode::INTERNAL_SERVER_ERROR,
222 serde_json::to_string(&error).expect("JSON serialization infallible; qed"),
223 "application/json; charset=utf-8",
224 )
225}