httproxide_hyper_reverse_proxy/
lib.rs1#[macro_use]
112extern crate tracing;
113
114use hyper::body::HttpBody;
115use hyper::client::connect::Connect;
116use hyper::header::{HeaderMap, HeaderName, HeaderValue};
117use hyper::http::header::{InvalidHeaderValue, ToStrError};
118use hyper::http::uri::InvalidUri;
119use hyper::upgrade::OnUpgrade;
120use hyper::{Body, Client, Error, Request, Response, StatusCode};
121use lazy_static::lazy_static;
122use std::net::IpAddr;
123use thiserror::Error as ThisError;
124use tokio::io::copy_bidirectional;
125
126lazy_static! {
127 static ref TE_HEADER: HeaderName = HeaderName::from_static("te");
128 static ref CONNECTION_HEADER: HeaderName = HeaderName::from_static("connection");
129 static ref UPGRADE_HEADER: HeaderName = HeaderName::from_static("upgrade");
130 static ref TRAILER_HEADER: HeaderName = HeaderName::from_static("trailer");
131 static ref TRAILERS_HEADER: HeaderName = HeaderName::from_static("trailers");
132 static ref HOP_HEADERS: [HeaderName; 9] = [
134 CONNECTION_HEADER.clone(),
135 TE_HEADER.clone(),
136 TRAILER_HEADER.clone(),
137 HeaderName::from_static("keep-alive"),
138 HeaderName::from_static("proxy-connection"),
139 HeaderName::from_static("proxy-authenticate"),
140 HeaderName::from_static("proxy-authorization"),
141 HeaderName::from_static("transfer-encoding"),
142 HeaderName::from_static("upgrade"),
143 ];
144
145 static ref X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
146}
147
148#[derive(Debug, ThisError)]
149pub enum ProxyError {
150 #[error("{0}")]
151 InvalidUri(#[from] InvalidUri),
152 #[error("{0}")]
153 HyperError(#[from] Error),
154 #[error("ForwardHeaderError")]
155 ForwardHeaderError,
156 #[error("UpgradeError: {0}")]
157 UpgradeError(String),
158}
159
160impl From<ToStrError> for ProxyError {
161 fn from(_err: ToStrError) -> ProxyError {
162 ProxyError::ForwardHeaderError
163 }
164}
165
166impl From<InvalidHeaderValue> for ProxyError {
167 fn from(_err: InvalidHeaderValue) -> ProxyError {
168 ProxyError::ForwardHeaderError
169 }
170}
171
172fn remove_hop_headers(headers: &mut HeaderMap) {
173 debug!("Removing hop headers");
174
175 for header in &*HOP_HEADERS {
176 headers.remove(header);
177 }
178}
179
180fn get_upgrade_type(headers: &HeaderMap) -> Option<String> {
181 #[allow(clippy::blocks_in_if_conditions)]
182 if headers
183 .get(&*CONNECTION_HEADER)
184 .map(|value| {
185 value
186 .to_str()
187 .unwrap()
188 .split(',')
189 .any(|e| e.trim() == *UPGRADE_HEADER)
190 })
191 .unwrap_or(false)
192 {
193 if let Some(upgrade_value) = headers.get(&*UPGRADE_HEADER) {
194 debug!(
195 "Found upgrade header with value: {}",
196 upgrade_value.to_str().unwrap().to_owned()
197 );
198
199 return Some(upgrade_value.to_str().unwrap().to_owned());
200 }
201 }
202
203 None
204}
205
206fn remove_connection_headers(headers: &mut HeaderMap) {
207 if headers.get(&*CONNECTION_HEADER).is_some() {
208 debug!("Removing connection headers");
209
210 let value = headers.get(&*CONNECTION_HEADER).cloned().unwrap();
211
212 for name in value.to_str().unwrap().split(',') {
213 if !name.trim().is_empty() {
214 headers.remove(name.trim());
215 }
216 }
217 }
218}
219
220fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
221 info!("Creating proxied response");
222
223 remove_hop_headers(response.headers_mut());
224 remove_connection_headers(response.headers_mut());
225
226 response
227}
228
229fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> String {
230 debug!("Building forward uri");
231
232 let split_url = forward_url.split('?').collect::<Vec<&str>>();
233
234 let mut base_url: &str = split_url.get(0).unwrap_or(&"");
235 let forward_url_query: &str = split_url.get(1).unwrap_or(&"");
236
237 let path2 = req.uri().path();
238
239 if base_url.ends_with('/') {
240 let mut path1_chars = base_url.chars();
241 path1_chars.next_back();
242
243 base_url = path1_chars.as_str();
244 }
245
246 let total_length = base_url.len()
247 + path2.len()
248 + 1
249 + forward_url_query.len()
250 + req.uri().query().map(|e| e.len()).unwrap_or(0);
251
252 debug!("Creating url with capacity to {}", total_length);
253
254 let mut url = String::with_capacity(total_length);
255
256 url.push_str(base_url);
257 url.push_str(path2);
258
259 if !forward_url_query.is_empty() || req.uri().query().map(|e| !e.is_empty()).unwrap_or(false) {
260 debug!("Adding query parts to url");
261 url.push('?');
262 url.push_str(forward_url_query);
263
264 if forward_url_query.is_empty() {
265 debug!("Using request query");
266
267 url.push_str(req.uri().query().unwrap_or(""));
268 } else {
269 debug!("Merging request and forward_url query");
270
271 let request_query_items = req.uri().query().unwrap_or("").split('&').map(|el| {
272 let parts = el.split('=').collect::<Vec<&str>>();
273 (parts[0], if parts.len() > 1 { parts[1] } else { "" })
274 });
275
276 let forward_query_items = forward_url_query
277 .split('&')
278 .map(|el| {
279 let parts = el.split('=').collect::<Vec<&str>>();
280 parts[0]
281 })
282 .collect::<Vec<_>>();
283
284 for (key, value) in request_query_items {
285 if !forward_query_items.iter().any(|e| e == &key) {
286 url.push('&');
287 url.push_str(key);
288 url.push('=');
289 url.push_str(value);
290 }
291 }
292
293 if url.ends_with('&') {
294 let mut parts = url.chars();
295 parts.next_back();
296
297 url = parts.as_str().to_string();
298 }
299 }
300 }
301
302 debug!("Built forwarding url from request: {}", url);
303
304 url.parse().unwrap()
305}
306
307fn create_proxied_request<B>(
308 client_ip: IpAddr,
309 forward_url: &str,
310 mut request: Request<B>,
311 upgrade_type: Option<&String>,
312) -> Result<Request<B>, ProxyError> {
313 info!("Creating proxied request");
314
315 let contains_te_trailers_value = request
316 .headers()
317 .get(&*TE_HEADER)
318 .map(|value| {
319 value
320 .to_str()
321 .unwrap()
322 .split(',')
323 .any(|e| e.trim() == *TRAILERS_HEADER)
324 })
325 .unwrap_or(false);
326
327 let uri: hyper::Uri = forward_uri(forward_url, &request).parse()?;
328
329 debug!("Setting headers of proxied request");
330
331 *request.uri_mut() = uri;
332
333 remove_hop_headers(request.headers_mut());
334 remove_connection_headers(request.headers_mut());
335
336 if contains_te_trailers_value {
337 debug!("Setting up trailer headers");
338
339 request
340 .headers_mut()
341 .insert(&*TE_HEADER, HeaderValue::from_static("trailers"));
342 }
343
344 if let Some(value) = upgrade_type {
345 debug!("Repopulate upgrade headers");
346
347 request
348 .headers_mut()
349 .insert(&*UPGRADE_HEADER, value.parse().unwrap());
350 request
351 .headers_mut()
352 .insert(&*CONNECTION_HEADER, HeaderValue::from_static("UPGRADE"));
353 }
354
355 match request.headers_mut().entry(&*X_FORWARDED_FOR) {
357 hyper::header::Entry::Vacant(entry) => {
358 debug!("X-Fowraded-for header was vacant");
359 entry.insert(client_ip.to_string().parse()?);
360 }
361
362 hyper::header::Entry::Occupied(entry) => {
363 debug!("X-Fowraded-for header was occupied");
364 let client_ip_str = client_ip.to_string();
365 let mut addr =
366 String::with_capacity(entry.get().as_bytes().len() + 2 + client_ip_str.len());
367
368 addr.push_str(std::str::from_utf8(entry.get().as_bytes()).unwrap());
369 addr.push(',');
370 addr.push(' ');
371 addr.push_str(&client_ip_str);
372 }
373 }
374
375 debug!("Created proxied request");
376
377 Ok(request)
378}
379
380pub async fn call<'a, C, B>(
381 client_ip: IpAddr,
382 forward_uri: &str,
383 mut request: Request<B>,
384 client: &'a Client<C, B>,
385) -> Result<Response<Body>, ProxyError>
386where
387 C: Connect + Clone + Send + Sync + 'static,
388 B: HttpBody + Send + 'static,
389 B::Data: Send,
390 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
391{
392 info!(
393 "Received proxy call from {} to {}, client: {}",
394 request.uri().to_string(),
395 forward_uri,
396 client_ip
397 );
398
399 let request_upgrade_type = get_upgrade_type(request.headers());
400 let request_upgraded = request.extensions_mut().remove::<OnUpgrade>();
401
402 let proxied_request = create_proxied_request(
403 client_ip,
404 forward_uri,
405 request,
406 request_upgrade_type.as_ref(),
407 )?;
408 let mut response = client.request(proxied_request).await?;
409
410 if response.status() == StatusCode::SWITCHING_PROTOCOLS {
411 let response_upgrade_type = get_upgrade_type(response.headers());
412
413 if request_upgrade_type != response_upgrade_type {
414 return Err(ProxyError::UpgradeError(format!(
415 "backend tried to switch to protocol {:?} when {:?} was requested",
416 response_upgrade_type, request_upgrade_type
417 )));
418 };
419 let request_upgraded = match request_upgraded {
420 Some(v) => v,
421 None => {
422 return Err(ProxyError::UpgradeError(
423 "request does not have an upgrade extension".to_string(),
424 ))
425 }
426 };
427 let mut response_upgraded = match response.extensions_mut().remove::<OnUpgrade>() {
428 Some(v) => v.await?,
429 None => {
430 return Err(ProxyError::UpgradeError(
431 "response does not have an upgrade extension".to_string(),
432 ))
433 }
434 };
435
436 debug!("Responding to a connection upgrade response");
437 tokio::spawn(async move {
438 let mut request_upgraded = match request_upgraded.await {
439 Ok(v) => v,
440 Err(e) => {
441 warn!("failed to upgrade request: {}", e);
442 return;
443 }
444 };
445
446 if let Some(err) = copy_bidirectional(&mut response_upgraded, &mut request_upgraded)
447 .await
448 .err()
449 {
450 if err.kind() != std::io::ErrorKind::UnexpectedEof {
451 warn!("coping between upgraded connections failed: {}", err);
452 }
453 }
454 });
455
456 Ok(response)
457 } else {
458 let proxied_response = create_proxied_response(response);
459
460 debug!("Responding to call with response");
461 Ok(proxied_response)
462 }
463}
464
465pub struct ReverseProxy<C, B> {
466 client: Client<C, B>,
467}
468
469impl<C, B> ReverseProxy<C, B>
470where
471 C: Connect + Clone + Send + Sync + 'static,
472 B: HttpBody + Send + 'static,
473 B::Data: Send,
474 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
475{
476 pub fn new(client: Client<C, B>) -> Self {
477 Self { client }
478 }
479
480 pub async fn call(
481 &self,
482 client_ip: IpAddr,
483 forward_uri: &str,
484 request: Request<B>,
485 ) -> Result<Response<Body>, ProxyError> {
486 call::<C, B>(client_ip, forward_uri, request, &self.client).await
487 }
488}
489
#[cfg(feature = "__bench")]
pub mod benches {
    //! Public wrappers exposing the crate's private helpers to benchmarks.
    //! Apart from `hop_headers`, each wrapper discards the helper's result.

    use crate::{HeaderName, IpAddr, Request, Response};

    /// Returns the hop-by-hop header list used by the proxy.
    pub fn hop_headers() -> &'static [HeaderName] {
        &*super::HOP_HEADERS
    }

    /// Runs `create_proxied_response` on `response` and drops the result.
    pub fn create_proxied_response<T>(response: Response<T>) {
        super::create_proxied_response(response);
    }

    /// Runs `forward_uri` for `req` and drops the built URL.
    pub fn forward_uri<B>(forward_url: &str, req: &Request<B>) {
        super::forward_uri(forward_url, req);
    }

    /// Runs `create_proxied_request` and drops the rewritten request.
    ///
    /// # Panics
    /// Panics if `create_proxied_request` returns an error.
    pub fn create_proxied_request<B>(
        client_ip: IpAddr,
        forward_url: &str,
        request: Request<B>,
        upgrade_type: Option<&String>,
    ) {
        super::create_proxied_request(client_ip, forward_url, request, upgrade_type).unwrap();
    }
}