Skip to main content

cloudillo_proxy/
handler.rs

1//! HTTP forwarding and WebSocket tunneling for reverse proxy
2
3use axum::http::{header, HeaderMap, HeaderName, HeaderValue, Uri};
4use hyper::body::Incoming;
5use hyper_util::{client::legacy::Client, rt::TokioExecutor};
6use std::sync::Arc;
7use std::time::Duration;
8
9use crate::prelude::*;
10use crate::ProxySiteEntry;
11
/// Headers that must not be forwarded between client and backend (hop-by-hop).
///
/// They describe a single transport connection rather than the end-to-end
/// message, so the proxy strips them in both directions (RFC 9110 §7.6.1).
/// Entries are lowercase; `is_hop_by_hop` compares them case-insensitively.
///
/// * `trailer` is the actual header name; `trailers` (the spelling used in
///   RFC 2616 §13.5.1) is kept so previously stripped headers stay stripped.
/// * `proxy-connection` is non-standard but still sent by some clients.
/// * `upgrade` is hop-by-hop as well; the WebSocket paths keep it explicitly
///   when forwarding an upgrade handshake.
const HOP_BY_HOP_HEADERS: &[&str] = &[
	"connection",
	"keep-alive",
	"proxy-authenticate",
	"proxy-authorization",
	"proxy-connection",
	"te",
	"trailer",
	"trailers",
	"transfer-encoding",
	"upgrade",
];
22
23/// Check if a header is a hop-by-hop header that should be stripped
24fn is_hop_by_hop(name: &HeaderName) -> bool {
25	HOP_BY_HOP_HEADERS.iter().any(|h| name.as_str().eq_ignore_ascii_case(h))
26}
27
28/// Check if a request is a WebSocket upgrade request
29fn is_websocket_upgrade(headers: &HeaderMap) -> bool {
30	headers
31		.get(header::UPGRADE)
32		.and_then(|v| v.to_str().ok())
33		.is_some_and(|v| v.eq_ignore_ascii_case("websocket"))
34}
35
36/// Build the backend URI from the proxy site entry and the original request URI
37fn build_backend_uri(entry: &ProxySiteEntry, original_uri: &Uri) -> ClResult<Uri> {
38	let mut backend = entry.backend_url.clone();
39	let combined_path = format!("{}{}", backend.path().trim_end_matches('/'), original_uri.path());
40	backend.set_path(&combined_path);
41	backend.set_query(original_uri.query());
42	debug!("Proxy backend URI: {} (combined_path={:?})", backend.as_str(), combined_path);
43	backend
44		.as_str()
45		.parse::<Uri>()
46		.map_err(|e| Error::Internal(format!("failed to build backend URI: {}", e)))
47}
48
49/// Copy non-hop-by-hop headers from source to destination
50fn copy_headers(src: &HeaderMap, dst: &mut HeaderMap, is_websocket: bool) {
51	for (name, value) in src {
52		// Skip hop-by-hop headers (but keep Upgrade for WebSocket)
53		if is_hop_by_hop(name) {
54			if is_websocket && name == header::UPGRADE {
55				dst.insert(name.clone(), value.clone());
56			}
57			continue;
58		}
59		dst.append(name.clone(), value.clone());
60	}
61}
62
63/// Handle a proxy request - main entry point for the proxy handler
64pub async fn handle_proxy_request(
65	entry: Arc<ProxySiteEntry>,
66	req: hyper::Request<Incoming>,
67	peer_addr: &str,
68) -> Result<hyper::Response<Incoming>, Error> {
69	let is_ws = is_websocket_upgrade(req.headers()) && entry.config.websocket.unwrap_or(true);
70
71	if is_ws {
72		return handle_websocket_proxy(entry, req, peer_addr).await;
73	}
74
75	let backend_uri = build_backend_uri(&entry, req.uri())?;
76
77	// Build the backend request
78	let mut backend_headers = HeaderMap::new();
79	copy_headers(req.headers(), &mut backend_headers, false);
80
81	// Host header handling
82	let preserve_host = entry.config.preserve_host.unwrap_or(true);
83	if preserve_host {
84		// Keep original Host header
85		if let Some(host) = req.headers().get(header::HOST) {
86			backend_headers.insert(header::HOST, host.clone());
87		}
88	} else if let Some(host) = entry.backend_url.host_str() {
89		// Rewrite to backend host
90		let host_val = if let Some(port) = entry.backend_url.port() {
91			format!("{}:{}", host, port)
92		} else {
93			host.to_string()
94		};
95		if let Ok(hv) = HeaderValue::from_str(&host_val) {
96			backend_headers.insert(header::HOST, hv);
97		}
98	}
99
100	// Add forwarding headers (always on for "basic" type)
101	let forward_headers = if entry.proxy_type.as_ref() == "basic" {
102		true
103	} else {
104		entry.config.forward_headers.unwrap_or(true)
105	};
106	if forward_headers {
107		if let Ok(hv) = HeaderValue::from_str(peer_addr) {
108			backend_headers.insert(HeaderName::from_static("x-forwarded-for"), hv.clone());
109			backend_headers.insert(HeaderName::from_static("x-real-ip"), hv);
110		}
111		backend_headers.insert(
112			HeaderName::from_static("x-forwarded-proto"),
113			HeaderValue::from_static("https"),
114		);
115		if let Ok(hv) = HeaderValue::from_str(&entry.domain) {
116			backend_headers.insert(HeaderName::from_static("x-forwarded-host"), hv);
117		}
118	}
119
120	// Add custom headers
121	if let Some(custom_headers) = &entry.config.custom_headers {
122		for (name, value) in custom_headers {
123			if let (Ok(hn), Ok(hv)) =
124				(HeaderName::from_bytes(name.as_bytes()), HeaderValue::from_str(value))
125			{
126				backend_headers.insert(hn, hv);
127			}
128		}
129	}
130
131	// Build the request
132	let method = req.method().clone();
133	let body = req.into_body();
134
135	let mut backend_req = hyper::Request::builder().method(method).uri(backend_uri);
136
137	if let Some(headers) = backend_req.headers_mut() {
138		*headers = backend_headers;
139	}
140
141	let backend_req = backend_req
142		.body(body)
143		.map_err(|e| Error::Internal(format!("failed to build backend request: {}", e)))?;
144
145	// Set up timeouts
146	let connect_timeout =
147		Duration::from_secs(u64::from(entry.config.connect_timeout_secs.unwrap_or(5)));
148	let read_timeout = Duration::from_secs(u64::from(entry.config.read_timeout_secs.unwrap_or(30)));
149
150	// Send the request to the backend
151	let scheme = entry.backend_url.scheme();
152	match send_backend_request(scheme, connect_timeout, read_timeout, backend_req).await {
153		Ok(mut backend_resp) => {
154			// Strip hop-by-hop headers from response
155			let headers_to_remove: Vec<HeaderName> = backend_resp
156				.headers()
157				.keys()
158				.filter(|name| is_hop_by_hop(name))
159				.cloned()
160				.collect();
161			for name in headers_to_remove {
162				backend_resp.headers_mut().remove(&name);
163			}
164			Ok(backend_resp)
165		}
166		Err(e @ Error::Timeout) => {
167			warn!("Proxy backend timeout for {}", entry.domain);
168			Err(e)
169		}
170		Err(e) => {
171			warn!("Proxy backend error for {}: {}", entry.domain, e);
172			Err(e)
173		}
174	}
175}
176
177/// Handle a WebSocket proxy request via upgrade tunneling
178async fn handle_websocket_proxy(
179	entry: Arc<ProxySiteEntry>,
180	req: hyper::Request<Incoming>,
181	peer_addr: &str,
182) -> Result<hyper::Response<Incoming>, Error> {
183	// For WebSocket upgrade, we use hyper's low-level connection handling
184	// to establish a bidirectional tunnel
185	let backend_uri = build_backend_uri(&entry, req.uri())?;
186
187	let mut backend_headers = HeaderMap::new();
188	// Copy all headers including WebSocket-specific ones
189	for (name, value) in req.headers() {
190		if is_hop_by_hop(name) && name != header::UPGRADE {
191			continue;
192		}
193		backend_headers.append(name.clone(), value.clone());
194	}
195
196	// Host header
197	let preserve_host = entry.config.preserve_host.unwrap_or(true);
198	if !preserve_host {
199		if let Some(host) = entry.backend_url.host_str() {
200			let host_val = if let Some(port) = entry.backend_url.port() {
201				format!("{}:{}", host, port)
202			} else {
203				host.to_string()
204			};
205			if let Ok(hv) = HeaderValue::from_str(&host_val) {
206				backend_headers.insert(header::HOST, hv);
207			}
208		}
209	}
210
211	// Add forwarding headers (always on for "basic" type)
212	let forward_headers = if entry.proxy_type.as_ref() == "basic" {
213		true
214	} else {
215		entry.config.forward_headers.unwrap_or(true)
216	};
217	if forward_headers {
218		if let Ok(hv) = HeaderValue::from_str(peer_addr) {
219			backend_headers.insert(HeaderName::from_static("x-forwarded-for"), hv);
220		}
221		backend_headers.insert(
222			HeaderName::from_static("x-forwarded-proto"),
223			HeaderValue::from_static("https"),
224		);
225	}
226
227	// Ensure Connection: Upgrade is present
228	backend_headers.insert(header::CONNECTION, HeaderValue::from_static("Upgrade"));
229
230	let method = req.method().clone();
231	let body = req.into_body();
232
233	let mut backend_req = hyper::Request::builder().method(method).uri(backend_uri);
234
235	if let Some(headers) = backend_req.headers_mut() {
236		*headers = backend_headers;
237	}
238
239	let backend_req = backend_req
240		.body(body)
241		.map_err(|e| Error::Internal(format!("failed to build ws backend request: {}", e)))?;
242
243	// Connect to backend
244	let connect_timeout =
245		Duration::from_secs(u64::from(entry.config.connect_timeout_secs.unwrap_or(5)));
246
247	let scheme = entry.backend_url.scheme();
248	match send_backend_request(scheme, connect_timeout, connect_timeout, backend_req).await {
249		Ok(backend_resp) => Ok(backend_resp),
250		Err(e @ Error::Timeout) => {
251			warn!("WebSocket proxy backend timeout for {}", entry.domain);
252			Err(e)
253		}
254		Err(e) => {
255			warn!("WebSocket proxy backend error for {}: {}", entry.domain, e);
256			Err(e)
257		}
258	}
259}
260
261/// Send a request to a backend, choosing HTTP or HTTPS connector based on scheme
262async fn send_backend_request(
263	scheme: &str,
264	connect_timeout: Duration,
265	timeout: Duration,
266	req: hyper::Request<Incoming>,
267) -> Result<hyper::Response<Incoming>, Error> {
268	let result = if scheme == "https" {
269		let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
270			.with_native_roots()
271			.map_err(|_| Error::ConfigError("no native root CA certificates found".into()))?
272			.https_only()
273			.enable_http1()
274			.build();
275		let client: Client<_, Incoming> = Client::builder(TokioExecutor::new())
276			.pool_idle_timeout(connect_timeout)
277			.build(https_connector);
278		tokio::time::timeout(timeout, client.request(req)).await
279	} else {
280		let http_connector = hyper_util::client::legacy::connect::HttpConnector::new();
281		let client: Client<_, Incoming> = Client::builder(TokioExecutor::new())
282			.pool_idle_timeout(connect_timeout)
283			.build(http_connector);
284		tokio::time::timeout(timeout, client.request(req)).await
285	};
286	match result {
287		Ok(Ok(resp)) => Ok(resp),
288		Ok(Err(_)) => Err(Error::NetworkError("bad gateway".into())),
289		Err(_) => Err(Error::Timeout),
290	}
291}
292
#[cfg(test)]
mod tests {
	use super::*;
	use crate::ProxySiteConfig;

	/// Builds a `basic` proxy entry for `test.example.com` that targets `backend`.
	fn entry_for(backend: &str) -> ProxySiteEntry {
		ProxySiteEntry {
			site_id: 1,
			domain: "test.example.com".into(),
			proxy_type: "basic".into(),
			backend_url: url::Url::parse(backend).unwrap(),
			config: ProxySiteConfig::default(),
		}
	}

	/// Maps `request` through `entry` and renders the resulting backend URI.
	fn backend_for(entry: &ProxySiteEntry, request: &str) -> String {
		let uri: Uri = request.parse().unwrap();
		build_backend_uri(entry, &uri).unwrap().to_string()
	}

	#[test]
	fn test_is_hop_by_hop() {
		for hop in ["connection", "keep-alive", "transfer-encoding"] {
			assert!(is_hop_by_hop(&HeaderName::from_static(hop)));
		}
		for end_to_end in ["content-type", "host"] {
			assert!(!is_hop_by_hop(&HeaderName::from_static(end_to_end)));
		}
	}

	#[test]
	fn test_build_backend_uri() {
		let entry = entry_for("http://localhost:3000");
		assert_eq!(
			backend_for(&entry, "/api/test?foo=bar"),
			"http://localhost:3000/api/test?foo=bar"
		);
	}

	#[test]
	fn test_build_backend_uri_root_path() {
		let entry = entry_for("http://localhost:3000");
		assert_eq!(backend_for(&entry, "/"), "http://localhost:3000/");
	}

	#[test]
	fn test_build_backend_uri_with_path_prefix() {
		let entry = entry_for("http://backend:3000/a/");

		// Root request should preserve the base path
		assert_eq!(backend_for(&entry, "/"), "http://backend:3000/a/");

		// Subpath request should join with base path
		assert_eq!(backend_for(&entry, "/foo"), "http://backend:3000/a/foo");

		// Subpath with query should work too
		assert_eq!(
			backend_for(&entry, "/api/test?key=val"),
			"http://backend:3000/a/api/test?key=val"
		);
	}
}
361
362// vim: ts=4