1use async_trait::async_trait;
9use bytes::Bytes;
10use flate2::{Compression as FlateCompression, write::GzEncoder};
11use pingora::http::ResponseHeader;
12use pingora::prelude::*;
13use pingora::proxy::FailToProxy;
14use std::io::Write;
15use std::sync::Arc;
16use tokio::io::AsyncReadExt;
17use tracing::{debug, error, info};
18
19use crate::acme::ChallengeStore;
20use crate::config::Config;
21use crate::middleware::{
22 CompressionEncoding, CompressionState, MiddlewareContext, MiddlewareStack, StaticResponseBody,
23};
24use crate::router::Router;
25
/// Pingora `ProxyHttp` handler implementing the WarpDrive reverse proxy.
pub struct WarpDriveProxy {
    // Shared runtime configuration (simple-mode target, feature flags).
    config: Arc<Config>,
    // Middleware pipeline built from `config`; applied to requests/responses.
    middleware: MiddlewareStack,
    // TOML-based router for advanced mode; `None` means simple env-var mode.
    router: Option<Router>,
    // Shared store of pending ACME HTTP-01 challenge tokens.
    challenge_store: ChallengeStore,
}
41
42impl WarpDriveProxy {
43 pub fn new(
45 config: Arc<Config>,
46 router: Option<Router>,
47 challenge_store: ChallengeStore,
48 ) -> Self {
49 info!("Initializing WarpDrive proxy handler");
50
51 if router.is_some() {
52 info!(" Mode: Advanced (TOML routing)");
53 } else {
54 info!(" Mode: Simple (env vars)");
55 info!(" Target: {}:{}", config.target_host, config.target_port);
56 }
57
58 info!(" Forward headers: {}", config.forward_headers);
59 info!(" X-Sendfile: {}", config.x_sendfile_enabled);
60 info!(" Compression: {}", config.gzip_compression_enabled);
61 info!(" HTTP/2 cleartext (h2c): {}", config.h2c_enabled);
62 info!(" Logging: {}", config.log_requests);
63
64 let middleware = MiddlewareStack::new(config.clone());
65
66 Self {
67 config,
68 middleware,
69 router,
70 challenge_store,
71 }
72 }
73}
74
#[async_trait]
impl ProxyHttp for WarpDriveProxy {
    // Per-request state threaded through the middleware stack.
    type CTX = MiddlewareContext;

    fn new_ctx(&self) -> Self::CTX {
        MiddlewareContext::default()
    }

    /// Select the upstream peer for this request.
    ///
    /// Advanced mode delegates to the router; routing or peer-selection
    /// failures surface as HTTP 502. Simple mode always targets the
    /// configured host/port in cleartext (TLS off, empty SNI).
    async fn upstream_peer(
        &self,
        session: &mut Session,
        _ctx: &mut Self::CTX,
    ) -> Result<Box<HttpPeer>> {
        if let Some(router) = &self.router {
            let upstream = router
                .select_upstream(session)
                .map_err(|e| Error::because(ErrorType::HTTPStatus(502), e.to_string(), e))?;
            upstream
                .get_peer()
                .await
                .map_err(|e| Error::because(ErrorType::HTTPStatus(502), e.to_string(), e))
        } else {
            debug!(
                "Selecting upstream peer: {}:{}",
                self.config.target_host, self.config.target_port
            );

            // false = no TLS; SNI is empty since we connect in cleartext.
            Ok(Box::new(HttpPeer::new(
                (self.config.target_host.as_str(), self.config.target_port),
                false,
                String::new(),
            )))
        }
    }

    /// Early request hook. Returning `Ok(true)` means the response was
    /// fully written here and the request is NOT forwarded upstream.
    async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<bool> {
        let path = session.req_header().uri.path();
        // Serve ACME HTTP-01 challenges directly from the in-memory store.
        if let Some(token) = path.strip_prefix("/.well-known/acme-challenge/") {
            if let Some(key_auth) = self.challenge_store.get(token).await {
                debug!("Serving ACME challenge for token: {}", token);

                let mut response = ResponseHeader::build(200, None)?;
                response.insert_header("Content-Type", "text/plain")?;
                response.insert_header("Content-Length", key_auth.len().to_string())?;

                session
                    .write_response_header(Box::new(response), false)
                    .await?;
                session
                    .write_response_body(Some(Bytes::from(key_auth)), true)
                    .await?;

                return Ok(true);
            } else {
                // Unknown token: fall through and let the upstream handle it.
                debug!("ACME challenge token not found: {}", token);
            }
        }

        // Reject HTTP/1.x requests missing a Host header with 400.
        // NOTE(review): HTTP/1.0 does not mandate Host; this rejects 1.0
        // requests without one as well — confirm that strictness is intended.
        let version = session.req_header().version;
        if version == http::Version::HTTP_11 || version == http::Version::HTTP_10 {
            if session.req_header().headers.get("Host").is_none() {
                if let Err(err) = session.respond_error(400).await {
                    error!(
                        "Failed to send 400 response for missing Host header: {}",
                        err
                    );
                    return Err(err);
                }
                return Ok(true);
            }
        }
        // Run the middleware stack; it may record a static response in `ctx`.
        self.middleware.apply_request_filters(session, ctx).await?;

        if let Some(static_response) = ctx.static_response.take() {
            debug!("Serving static file response");

            session
                .write_response_header(Box::new(static_response.header), false)
                .await?;

            match static_response.body {
                // Small bodies are held in memory and written in one shot.
                StaticResponseBody::InMemory(body) => {
                    session.write_response_body(Some(body), true).await?;
                }
                // Large bodies are streamed from disk in fixed-size chunks.
                StaticResponseBody::Stream(path) => {
                    let mut file = match tokio::fs::File::open(&path).await {
                        Ok(file) => file,
                        Err(err) => {
                            error!(
                                "Failed to open static file for streaming {:?}: {}",
                                path, err
                            );
                            return Err(Error::explain(
                                ErrorType::HTTPStatus(500),
                                "Failed to open static file",
                            ));
                        }
                    };

                    // 64 KiB chunks bound memory use regardless of file size.
                    let mut buffer = vec![0u8; 64 * 1024];
                    loop {
                        let bytes_read = file.read(&mut buffer).await.map_err(|err| {
                            error!("Failed to read static file chunk {:?}: {}", path, err);
                            Error::explain(ErrorType::HTTPStatus(500), "Failed to read static file")
                        })?;

                        if bytes_read == 0 {
                            break;
                        }

                        session
                            .write_response_body(
                                Some(Bytes::copy_from_slice(&buffer[..bytes_read])),
                                false,
                            )
                            .await?;
                    }

                    // Empty write with end=true signals end of body.
                    session.write_response_body(None, true).await?;
                }
            }

            return Ok(true);
        }

        // Not handled locally: proxy to upstream.
        Ok(false)
    }

    /// Mutate the request before it is sent upstream: preserve the client's
    /// Host header and apply any route-level path transformation.
    async fn upstream_request_filter(
        &self,
        session: &mut Session,
        upstream_request: &mut RequestHeader,
        _ctx: &mut Self::CTX,
    ) -> Result<()> {
        // Forward the original Host header; reject non-ASCII values with 400.
        if let Some(host) = session.req_header().headers.get("Host") {
            let host_str = host.to_str().map_err(|_| {
                Error::explain(
                    ErrorType::HTTPStatus(400),
                    "Invalid Host header (non-ASCII characters)",
                )
            })?;
            upstream_request.insert_header("Host", host_str)?;
        }

        // Advanced mode: apply the matched route's path rewrite, keeping the
        // original query string intact.
        if let Some(router) = &self.router {
            if let Some(route) = router.find_matching_route(session) {
                let original_path = session.req_header().uri.path();
                let transformed_path = route.transform_path(original_path);

                if transformed_path != original_path {
                    debug!(
                        "Transforming path: {} -> {}",
                        original_path, transformed_path
                    );

                    let mut parts = upstream_request.uri.clone().into_parts();
                    let path_and_query = if let Some(query) = session.req_header().uri.query() {
                        format!("{}?{}", transformed_path, query)
                    } else {
                        transformed_path.to_string()
                    };

                    parts.path_and_query = Some(path_and_query.parse().map_err(|_| {
                        Error::explain(
                            ErrorType::HTTPStatus(500),
                            "Failed to construct transformed URI",
                        )
                    })?);

                    upstream_request.set_uri(http::Uri::from_parts(parts).map_err(|_| {
                        Error::explain(
                            ErrorType::HTTPStatus(500),
                            "Failed to apply path transformation",
                        )
                    })?);
                }
            }
        }

        debug!(
            "Forwarding request: {} {} to upstream",
            upstream_request.method, upstream_request.uri
        );

        Ok(())
    }

    /// Run the middleware response filters on the upstream response headers.
    async fn response_filter(
        &self,
        session: &mut Session,
        upstream_response: &mut ResponseHeader,
        ctx: &mut Self::CTX,
    ) -> Result<()> {
        self.middleware
            .apply_response_filters(session, upstream_response, ctx)
            .await
    }

    /// Log upstream connect failures; the error is returned unchanged so
    /// pingora's default handling (including retry semantics) applies.
    fn fail_to_connect(
        &self,
        _session: &mut Session,
        _peer: &HttpPeer,
        _ctx: &mut Self::CTX,
        e: Box<Error>,
    ) -> Box<Error> {
        error!("Failed to connect to upstream: {}", e);
        e
    }

    /// Map proxy failures to downstream status codes.
    async fn fail_to_proxy(
        &self,
        _session: &mut Session,
        e: &Error,
        _ctx: &mut Self::CTX,
    ) -> FailToProxy {
        error!("Failed to proxy request: {}", e);

        // NOTE(review): ReadError is treated as "request body too large"
        // (413); presumably this corresponds to a body-size-limit read
        // failure — confirm against the middleware that enforces the limit.
        let error_code = if e.etype() == &ErrorType::ReadError {
            413
        } else {
            502
        };

        FailToProxy {
            error_code,
            // The downstream connection is in an unknown state; do not reuse.
            can_reuse_downstream: false,
        }
    }

    /// Per-chunk body filter handling (in priority order):
    /// 1. X-Sendfile: discard upstream body, emit the local file instead.
    /// 2. Streaming responses: pass chunks through untouched.
    /// 3. Pending compression: buffer the full body, then emit one
    ///    compressed chunk at end-of-stream.
    fn response_body_filter(
        &self,
        _session: &mut Session,
        body: &mut Option<Bytes>,
        end_of_stream: bool,
        ctx: &mut Self::CTX,
    ) -> Result<Option<std::time::Duration>> {
        if ctx.sendfile.active {
            // Drop whatever the upstream sent; the file replaces it.
            if let Some(chunk) = body {
                chunk.clear();
            }

            if !ctx.sendfile.served {
                // Emit the file body exactly once.
                if let Some(file_body) = ctx.sendfile.body.take() {
                    *body = Some(file_body);
                } else {
                    *body = None;
                }
                ctx.sendfile.served = true;
            } else {
                *body = None;
            }

            return Ok(None);
        }

        if ctx.streaming {
            // Streaming responses are forwarded as-is (no buffering).
            return Ok(None);
        }

        if let CompressionState::Pending { buffer, encoding } = &mut ctx.compression {
            // Accumulate the full body before compressing; suppress
            // intermediate chunks by clearing/nulling them.
            if let Some(chunk) = body {
                buffer.extend_from_slice(&chunk[..]);
                chunk.clear();
            }

            if end_of_stream {
                let compressed = match encoding {
                    CompressionEncoding::Brotli => brotli_compress(buffer)?,
                    CompressionEncoding::Gzip => gzip_compress(buffer)?,
                };
                *body = Some(Bytes::from(compressed));
                ctx.compression = CompressionState::Complete;
            } else {
                *body = None;
            }

            return Ok(None);
        }

        Ok(None)
    }
}
413
#[cfg(test)]
mod tests {
    use super::*;

    /// The handler should retain the config it was constructed with.
    #[test]
    fn test_proxy_creation() {
        let cfg = Arc::new(Config::default());
        let store = ChallengeStore::default();
        let handler = WarpDriveProxy::new(Arc::clone(&cfg), None, store);

        assert_eq!(handler.config.target_port, 3000);
    }
}
429
430fn gzip_compress(buffer: &[u8]) -> Result<Vec<u8>> {
431 let mut encoder = GzEncoder::new(
432 Vec::with_capacity(buffer.len() / 2 + 16),
433 FlateCompression::default(),
434 );
435 encoder.write_all(buffer).map_err(|_| {
436 Error::explain(
437 ErrorType::HTTPStatus(500),
438 "Failed to compress response body",
439 )
440 })?;
441 encoder.finish().map_err(|_| {
442 Error::explain(
443 ErrorType::HTTPStatus(500),
444 "Failed to finalize compressed body",
445 )
446 })
447}
448
449fn brotli_compress(buffer: &[u8]) -> Result<Vec<u8>> {
450 use brotli::enc::BrotliEncoderParams;
451
452 let mut output = Vec::with_capacity(buffer.len() / 2 + 16);
453 let params = BrotliEncoderParams {
454 quality: 6, ..Default::default()
456 };
457
458 brotli::BrotliCompress(&mut std::io::Cursor::new(buffer), &mut output, ¶ms).map_err(
459 |_| {
460 Error::explain(
461 ErrorType::HTTPStatus(500),
462 "Failed to compress response body with Brotli",
463 )
464 },
465 )?;
466
467 Ok(output)
468}