warpdrive_proxy/proxy/
handler.rs

1//! HTTP proxy handler implementation using Pingora
2//!
3//! This module provides the core ProxyHttp implementation for WarpDrive, handling:
4//! - Upstream peer selection (forwarding to target server)
5//! - Request/response filtering via middleware chain
6//! - Error handling (502 Bad Gateway responses)
7
8use async_trait::async_trait;
9use bytes::Bytes;
10use flate2::{Compression as FlateCompression, write::GzEncoder};
11use pingora::http::ResponseHeader;
12use pingora::prelude::*;
13use pingora::proxy::FailToProxy;
14use std::io::Write;
15use std::sync::Arc;
16use tokio::io::AsyncReadExt;
17use tracing::{debug, error, info};
18
19use crate::acme::ChallengeStore;
20use crate::config::Config;
21use crate::middleware::{
22    CompressionEncoding, CompressionState, MiddlewareContext, MiddlewareStack, StaticResponseBody,
23};
24use crate::router::Router;
25
26/// WarpDrive HTTP proxy handler
27///
28/// Implements Pingora's ProxyHttp trait to forward HTTP requests to an upstream server.
29/// Uses a middleware stack for request/response processing.
30/// Supports multi-upstream routing via optional Router.
31pub struct WarpDriveProxy {
32    /// Shared configuration
33    config: Arc<Config>,
34    /// Middleware stack for request/response processing
35    middleware: MiddlewareStack,
36    /// Optional router for multi-upstream routing (None = simple mode)
37    router: Option<Router>,
38    /// ACME challenge store for HTTP-01 validation
39    challenge_store: ChallengeStore,
40}
41
42impl WarpDriveProxy {
43    /// Create a new proxy handler with the given configuration
44    pub fn new(
45        config: Arc<Config>,
46        router: Option<Router>,
47        challenge_store: ChallengeStore,
48    ) -> Self {
49        info!("Initializing WarpDrive proxy handler");
50
51        if router.is_some() {
52            info!("  Mode: Advanced (TOML routing)");
53        } else {
54            info!("  Mode: Simple (env vars)");
55            info!("  Target: {}:{}", config.target_host, config.target_port);
56        }
57
58        info!("  Forward headers: {}", config.forward_headers);
59        info!("  X-Sendfile: {}", config.x_sendfile_enabled);
60        info!("  Compression: {}", config.gzip_compression_enabled);
61        info!("  HTTP/2 cleartext (h2c): {}", config.h2c_enabled);
62        info!("  Logging: {}", config.log_requests);
63
64        let middleware = MiddlewareStack::new(config.clone());
65
66        Self {
67            config,
68            middleware,
69            router,
70            challenge_store,
71        }
72    }
73}
74
75#[async_trait]
76impl ProxyHttp for WarpDriveProxy {
77    type CTX = MiddlewareContext;
78
79    fn new_ctx(&self) -> Self::CTX {
80        MiddlewareContext::default()
81    }
82
83    /// Select the upstream peer (required method)
84    ///
85    /// This is called for each request to determine which upstream server to connect to.
86    /// - Advanced mode (TOML): Routes based on path/host/method and selects from LoadBalancer
87    /// - Simple mode (env): Always forwards to configured target_host:target_port
88    async fn upstream_peer(
89        &self,
90        session: &mut Session,
91        _ctx: &mut Self::CTX,
92    ) -> Result<Box<HttpPeer>> {
93        if let Some(router) = &self.router {
94            // Advanced mode: Use router to select upstream
95            let upstream = router
96                .select_upstream(session)
97                .map_err(|e| Error::because(ErrorType::HTTPStatus(502), e.to_string(), e))?;
98            upstream
99                .get_peer()
100                .await
101                .map_err(|e| Error::because(ErrorType::HTTPStatus(502), e.to_string(), e))
102        } else {
103            // Simple mode: Use configured target
104            debug!(
105                "Selecting upstream peer: {}:{}",
106                self.config.target_host, self.config.target_port
107            );
108
109            Ok(Box::new(HttpPeer::new(
110                (self.config.target_host.as_str(), self.config.target_port),
111                false,         // HTTP (not HTTPS)
112                String::new(), // No SNI for HTTP
113            )))
114        }
115    }
116
117    /// Apply request filters before sending to upstream
118    ///
119    /// This executes the middleware chain on the incoming request.
120    /// Returns false to indicate request should not be bypassed.
121    async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<bool> {
122        // Handle ACME HTTP-01 challenges
123        let path = session.req_header().uri.path();
124        if let Some(token) = path.strip_prefix("/.well-known/acme-challenge/") {
125            if let Some(key_auth) = self.challenge_store.get(token).await {
126                debug!("Serving ACME challenge for token: {}", token);
127
128                // Respond with key authorization
129                let mut response = ResponseHeader::build(200, None)?;
130                response.insert_header("Content-Type", "text/plain")?;
131                response.insert_header("Content-Length", key_auth.len().to_string())?;
132
133                session
134                    .write_response_header(Box::new(response), false)
135                    .await?;
136                session
137                    .write_response_body(Some(Bytes::from(key_auth)), true)
138                    .await?;
139
140                return Ok(true); // Bypass - we handled the request
141            } else {
142                debug!("ACME challenge token not found: {}", token);
143                // Let it pass through - might be handled by upstream
144            }
145        }
146
147        // Validate Host/:authority header based on HTTP version
148        // HTTP/1.1: Requires "Host" header
149        // HTTP/2+: Uses ":authority" pseudo-header (no "Host" header)
150        let version = session.req_header().version;
151        if version == http::Version::HTTP_11 || version == http::Version::HTTP_10 {
152            // HTTP/1.x requires Host header
153            if session.req_header().headers.get("Host").is_none() {
154                if let Err(err) = session.respond_error(400).await {
155                    error!(
156                        "Failed to send 400 response for missing Host header: {}",
157                        err
158                    );
159                    return Err(err);
160                }
161                return Ok(true); // Bypass request - we already responded
162            }
163        }
164        // HTTP/2+ uses :authority pseudo-header - handled by Pingora internally
165
166        // Apply middleware request filters
167        self.middleware.apply_request_filters(session, ctx).await?;
168
169        // Check if static files middleware set a response
170        if let Some(static_response) = ctx.static_response.take() {
171            debug!("Serving static file response");
172
173            session
174                .write_response_header(Box::new(static_response.header), false)
175                .await?;
176
177            match static_response.body {
178                StaticResponseBody::InMemory(body) => {
179                    session.write_response_body(Some(body), true).await?;
180                }
181                StaticResponseBody::Stream(path) => {
182                    let mut file = match tokio::fs::File::open(&path).await {
183                        Ok(file) => file,
184                        Err(err) => {
185                            error!(
186                                "Failed to open static file for streaming {:?}: {}",
187                                path, err
188                            );
189                            return Err(Error::explain(
190                                ErrorType::HTTPStatus(500),
191                                "Failed to open static file",
192                            ));
193                        }
194                    };
195
196                    let mut buffer = vec![0u8; 64 * 1024];
197                    loop {
198                        let bytes_read = file.read(&mut buffer).await.map_err(|err| {
199                            error!("Failed to read static file chunk {:?}: {}", path, err);
200                            Error::explain(ErrorType::HTTPStatus(500), "Failed to read static file")
201                        })?;
202
203                        if bytes_read == 0 {
204                            break;
205                        }
206
207                        session
208                            .write_response_body(
209                                Some(Bytes::copy_from_slice(&buffer[..bytes_read])),
210                                false,
211                            )
212                            .await?;
213                    }
214
215                    session.write_response_body(None, true).await?;
216                }
217            }
218
219            return Ok(true); // Bypass - we handled the request
220        }
221
222        // Don't bypass request
223        Ok(false)
224    }
225
226    /// Modify request before sending to upstream (optional)
227    ///
228    /// This preserves the original Host header and applies path transformation if configured.
229    async fn upstream_request_filter(
230        &self,
231        session: &mut Session,
232        upstream_request: &mut RequestHeader,
233        _ctx: &mut Self::CTX,
234    ) -> Result<()> {
235        // Preserve the original Host header for upstream request
236        // Note: Host header validation happens earlier in request_filter()
237        if let Some(host) = session.req_header().headers.get("Host") {
238            // Convert header value to string
239            // Invalid headers are handled by returning 400 in request_filter
240            let host_str = host.to_str().map_err(|_| {
241                Error::explain(
242                    ErrorType::HTTPStatus(400),
243                    "Invalid Host header (non-ASCII characters)",
244                )
245            })?;
246            upstream_request.insert_header("Host", host_str)?;
247        }
248
249        // Apply path transformation if using router
250        if let Some(router) = &self.router {
251            if let Some(route) = router.find_matching_route(session) {
252                let original_path = session.req_header().uri.path();
253                let transformed_path = route.transform_path(original_path);
254
255                if transformed_path != original_path {
256                    debug!(
257                        "Transforming path: {} -> {}",
258                        original_path, transformed_path
259                    );
260
261                    // Build new URI with transformed path
262                    let mut parts = upstream_request.uri.clone().into_parts();
263                    let path_and_query = if let Some(query) = session.req_header().uri.query() {
264                        format!("{}?{}", transformed_path, query)
265                    } else {
266                        transformed_path.to_string()
267                    };
268
269                    parts.path_and_query = Some(path_and_query.parse().map_err(|_| {
270                        Error::explain(
271                            ErrorType::HTTPStatus(500),
272                            "Failed to construct transformed URI",
273                        )
274                    })?);
275
276                    upstream_request.set_uri(http::Uri::from_parts(parts).map_err(|_| {
277                        Error::explain(
278                            ErrorType::HTTPStatus(500),
279                            "Failed to apply path transformation",
280                        )
281                    })?);
282                }
283            }
284        }
285
286        debug!(
287            "Forwarding request: {} {} to upstream",
288            upstream_request.method, upstream_request.uri
289        );
290
291        Ok(())
292    }
293
294    /// Apply response filters after receiving from upstream
295    ///
296    /// This executes the middleware chain on the upstream response.
297    async fn response_filter(
298        &self,
299        session: &mut Session,
300        upstream_response: &mut ResponseHeader,
301        ctx: &mut Self::CTX,
302    ) -> Result<()> {
303        self.middleware
304            .apply_response_filters(session, upstream_response, ctx)
305            .await
306    }
307
308    /// Handle upstream connection errors (optional)
309    ///
310    /// This is called when we fail to connect to the upstream server.
311    /// We log the error and return it (Pingora will send 502 Bad Gateway).
312    fn fail_to_connect(
313        &self,
314        _session: &mut Session,
315        _peer: &HttpPeer,
316        _ctx: &mut Self::CTX,
317        e: Box<Error>,
318    ) -> Box<Error> {
319        error!("Failed to connect to upstream: {}", e);
320        // TODO: In future, serve custom bad gateway page from config.bad_gateway_page
321        e
322    }
323
324    /// Handle upstream errors during request/response (optional)
325    ///
326    /// This is called when the proxy operation fails after connection.
327    /// We can customize the error response and connection reuse behavior.
328    async fn fail_to_proxy(
329        &self,
330        _session: &mut Session,
331        e: &Error,
332        _ctx: &mut Self::CTX,
333    ) -> FailToProxy {
334        error!("Failed to proxy request: {}", e);
335
336        // Determine error code based on error type
337        let error_code = if e.etype() == &ErrorType::ReadError {
338            // Could be request entity too large
339            // TODO: Better error type detection
340            413
341        } else {
342            // Return 502 Bad Gateway for other errors
343            502
344        };
345
346        // Don't reuse downstream connection on proxy errors
347        // to avoid sending partial/corrupted responses
348        FailToProxy {
349            error_code,
350            can_reuse_downstream: false,
351        }
352    }
353
354    /// Process response body chunks for compression/sendfile support
355    fn response_body_filter(
356        &self,
357        _session: &mut Session,
358        body: &mut Option<Bytes>,
359        end_of_stream: bool,
360        ctx: &mut Self::CTX,
361    ) -> Result<Option<std::time::Duration>> {
362        // X-Sendfile has highest priority: replace upstream body with file contents
363        if ctx.sendfile.active {
364            if let Some(chunk) = body {
365                chunk.clear();
366            }
367
368            if !ctx.sendfile.served {
369                if let Some(file_body) = ctx.sendfile.body.take() {
370                    *body = Some(file_body);
371                } else {
372                    *body = None;
373                }
374                ctx.sendfile.served = true;
375            } else {
376                *body = None;
377            }
378
379            return Ok(None);
380        }
381
382        // Streaming responses: pass through immediately without buffering
383        // This enables real-time SSE, WebSockets, and other streaming protocols
384        if ctx.streaming {
385            // Body passes through unmodified - no buffering
386            return Ok(None);
387        }
388
389        // Compression: buffer upstream body until end-of-stream then emit compressed bytes
390        if let CompressionState::Pending { buffer, encoding } = &mut ctx.compression {
391            if let Some(chunk) = body {
392                buffer.extend_from_slice(&chunk[..]);
393                chunk.clear();
394            }
395
396            if end_of_stream {
397                let compressed = match encoding {
398                    CompressionEncoding::Brotli => brotli_compress(buffer)?,
399                    CompressionEncoding::Gzip => gzip_compress(buffer)?,
400                };
401                *body = Some(Bytes::from(compressed));
402                ctx.compression = CompressionState::Complete;
403            } else {
404                *body = None;
405            }
406
407            return Ok(None);
408        }
409
410        Ok(None)
411    }
412}
413
414// Note: X-Forwarded-* header logic is now handled by HeadersMiddleware
415
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing the proxy without a router keeps the default target port.
    #[test]
    fn test_proxy_creation() {
        let proxy = WarpDriveProxy::new(
            Arc::new(Config::default()),
            None,
            ChallengeStore::default(),
        );

        assert_eq!(proxy.config.target_port, 3000);
    }
}
429
430fn gzip_compress(buffer: &[u8]) -> Result<Vec<u8>> {
431    let mut encoder = GzEncoder::new(
432        Vec::with_capacity(buffer.len() / 2 + 16),
433        FlateCompression::default(),
434    );
435    encoder.write_all(buffer).map_err(|_| {
436        Error::explain(
437            ErrorType::HTTPStatus(500),
438            "Failed to compress response body",
439        )
440    })?;
441    encoder.finish().map_err(|_| {
442        Error::explain(
443            ErrorType::HTTPStatus(500),
444            "Failed to finalize compressed body",
445        )
446    })
447}
448
449fn brotli_compress(buffer: &[u8]) -> Result<Vec<u8>> {
450    use brotli::enc::BrotliEncoderParams;
451
452    let mut output = Vec::with_capacity(buffer.len() / 2 + 16);
453    let params = BrotliEncoderParams {
454        quality: 6, // Brotli quality 6 (balance of speed vs size, similar to gzip default)
455        ..Default::default()
456    };
457
458    brotli::BrotliCompress(&mut std::io::Cursor::new(buffer), &mut output, &params).map_err(
459        |_| {
460            Error::explain(
461                ErrorType::HTTPStatus(500),
462                "Failed to compress response body with Brotli",
463            )
464        },
465    )?;
466
467    Ok(output)
468}