cardinal_proxy/
lib.rs

1mod utils;
2
3use crate::utils::requests::{
4    compose_upstream_url, execution_context_from_request, parse_origin, rewrite_request_path,
5    set_upstream_host_headers,
6};
7use bytes::Bytes;
8use cardinal_base::context::CardinalContext;
9use cardinal_base::destinations::container::DestinationContainer;
10use cardinal_plugins::request_context::{RequestContext, RequestContextBase};
11use cardinal_plugins::runner::MiddlewareResult;
12use pingora::http::ResponseHeader;
13use pingora::prelude::*;
14use pingora::protocols::Digest;
15use pingora::upstreams::peer::Peer;
16use std::sync::Arc;
17use tracing::{debug, error, info, warn};
18
19pub mod pingora {
20    pub use pingora::*;
21}
22
/// Outcome of [`CardinalContextProvider::health_check`] for an incoming request.
///
/// `PartialEq`/`Eq` are derived so providers and tests can compare statuses
/// directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HealthCheckStatus {
    /// The request is not a health check; normal proxying continues.
    None,
    /// Healthy: the proxy answers `200` with a `healthy` plain-text body.
    Ready,
    /// Unhealthy: the proxy answers with an error response using `status_code`.
    Unavailable {
        /// HTTP status code sent back to the client.
        status_code: u16,
        /// Optional human-readable cause; only written to the warning log.
        reason: Option<String>,
    },
}
32
33pub trait CardinalContextProvider: Send + Sync {
34    fn ctx(&self) -> RequestContextBase {
35        RequestContextBase::default()
36    }
37
38    fn resolve(
39        &self,
40        session: &Session,
41        ctx: &mut RequestContextBase,
42    ) -> Option<Arc<CardinalContext>>;
43    fn health_check(&self, _session: &Session) -> HealthCheckStatus {
44        HealthCheckStatus::None
45    }
46
47    fn logging(&self, _session: &mut Session, _e: Option<&Error>, _ctx: &mut RequestContextBase) {}
48}
49
50#[derive(Clone)]
51pub struct StaticContextProvider {
52    context: Arc<CardinalContext>,
53}
54
55impl StaticContextProvider {
56    pub fn new(context: Arc<CardinalContext>) -> Self {
57        Self { context }
58    }
59}
60
61impl CardinalContextProvider for StaticContextProvider {
62    fn resolve(
63        &self,
64        _session: &Session,
65        _ctx: &mut RequestContextBase,
66    ) -> Option<Arc<CardinalContext>> {
67        Some(self.context.clone())
68    }
69}
70
71pub struct CardinalProxy {
72    provider: Arc<dyn CardinalContextProvider>,
73}
74
75impl CardinalProxy {
76    pub fn new(context: Arc<CardinalContext>) -> Self {
77        Self::builder(context).build()
78    }
79
80    pub fn with_provider(provider: Arc<dyn CardinalContextProvider>) -> Self {
81        Self { provider }
82    }
83
84    pub fn builder(context: Arc<CardinalContext>) -> CardinalProxyBuilder {
85        CardinalProxyBuilder::new(context)
86    }
87}
88
89pub struct CardinalProxyBuilder {
90    provider: Arc<dyn CardinalContextProvider>,
91}
92
93impl CardinalProxyBuilder {
94    pub fn new(context: Arc<CardinalContext>) -> Self {
95        Self {
96            provider: Arc::new(StaticContextProvider::new(context)),
97        }
98    }
99
100    pub fn from_context_provider(provider: Arc<dyn CardinalContextProvider>) -> Self {
101        Self { provider }
102    }
103
104    pub fn with_context_provider(mut self, provider: Arc<dyn CardinalContextProvider>) -> Self {
105        self.provider = provider;
106        self
107    }
108
109    pub fn build(self) -> CardinalProxy {
110        CardinalProxy::with_provider(self.provider)
111    }
112}
113
114#[async_trait::async_trait]
115impl ProxyHttp for CardinalProxy {
116    type CTX = RequestContextBase;
117
118    fn new_ctx(&self) -> Self::CTX {
119        self.provider.ctx()
120    }
121
122    async fn logging(&self, _session: &mut Session, _e: Option<&Error>, ctx: &mut Self::CTX)
123    where
124        Self::CTX: Send + Sync,
125    {
126        self.provider.logging(_session, _e, ctx);
127    }
128
129    async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<bool> {
130        let path = session.req_header().uri.path().to_string();
131        info!(%path, "Request received");
132
133        match self.provider.health_check(session) {
134            HealthCheckStatus::None => {}
135            HealthCheckStatus::Ready => {
136                debug!(%path, "Health check ready");
137                // Build a 200 OK header
138                let mut resp = ResponseHeader::build(200, None)?;
139                resp.insert_header("Content-Type", "text/plain")?;
140                resp.set_content_length("healthy\n".len())?;
141
142                // Send header + body to the client
143                session
144                    .write_response_header(Box::new(resp), /*end_of_stream*/ false)
145                    .await?;
146                session
147                    .write_response_body(Some(Bytes::from_static(b"healthy\n")), /*end*/ true)
148                    .await?;
149
150                // Returning Ok(true) means "handled", stop further processing.
151                return Ok(true);
152            }
153            HealthCheckStatus::Unavailable {
154                status_code,
155                reason,
156            } => {
157                if let Some(reason) = reason {
158                    warn!(%path, status = status_code, reason = %reason, "Health check failed");
159                } else {
160                    warn!(%path, status = status_code, "Health check failed");
161                }
162                let _ = session.respond_error(status_code).await;
163                return Ok(true);
164            }
165        }
166
167        let context = match self.provider.resolve(session, ctx) {
168            Some(ctx) => ctx,
169            None => {
170                warn!(%path, "No context found for request host");
171                let _ = session.respond_error(421).await;
172                return Ok(true);
173            }
174        };
175
176        let destination_container = context
177            .get::<DestinationContainer>()
178            .await
179            .map_err(|_| Error::new_str("Destination Container is not present"))?;
180
181        let force_path = context.config.server.force_path_parameter;
182        let backend =
183            match destination_container.get_backend_for_request(session.req_header(), force_path) {
184                Some(b) => b,
185                None => {
186                    warn!(%path, "No matching backend, returning 404");
187                    let _ = session.respond_error(404).await;
188                    return Ok(true);
189                }
190            };
191
192        let destination_name = backend.destination.name.clone();
193        let _ = set_upstream_host_headers(session, &backend);
194        info!(backend_id = %destination_name, "Routing to backend");
195
196        rewrite_request_path(session.req_header_mut(), &destination_name, force_path);
197
198        let mut request_state = RequestContext::new(
199            context.clone(),
200            backend,
201            execution_context_from_request(session),
202        );
203
204        let plugin_runner = request_state.plugin_runner.clone();
205
206        let run_filters = plugin_runner
207            .run_request_filters(session, &mut request_state)
208            .await;
209
210        let res = match run_filters {
211            Ok(filter_result) => filter_result,
212            Err(err) => {
213                error!(%err, "Error running request filters");
214                let _ = session.respond_error(500).await;
215                return Ok(true);
216            }
217        };
218
219        ctx.set_resolved_request(request_state);
220
221        match res {
222            MiddlewareResult::Continue(resp_headers) => {
223                ctx.resolved_request.as_mut().unwrap().response_headers = Some(resp_headers);
224
225                Ok(false)
226            }
227            MiddlewareResult::Responded => Ok(true),
228        }
229    }
230
231    async fn upstream_peer(
232        &self,
233        _session: &mut Session,
234        ctx: &mut Self::CTX,
235    ) -> Result<Box<HttpPeer>> {
236        // Determine origin parts for TLS and SNI
237        let (host, port, is_tls) = parse_origin(&ctx.req_unsafe().backend.destination.url)
238            .map_err(|_| Error::new_str("Origin could not be parsed "))?;
239        let hostport = format!("{host}:{port}");
240
241        // Compose full upstream URL for logging with normalized scheme
242        let path_and_query = _session
243            .req_header()
244            .uri
245            .path_and_query()
246            .map(|pq| pq.as_str())
247            .unwrap_or("/");
248        let upstream_url = compose_upstream_url(is_tls, &host, port, path_and_query);
249
250        info!(%upstream_url, backend_id = %ctx.req_unsafe().backend.destination.name, is_tls, sni = %host, "Forwarding to upstream");
251        debug!(upstream_origin = %hostport, "Connecting to upstream origin");
252
253        let mut peer = HttpPeer::new(&hostport, is_tls, host);
254        if let Some(opts) = peer.get_mut_peer_options() {
255            // Allow both HTTP/1.1 and HTTP/2 so plain HTTP backends keep working.
256            opts.set_http_version(2, 1);
257        }
258        let peer = Box::new(peer);
259        Ok(peer)
260    }
261
262    async fn connected_to_upstream(
263        &self,
264        _session: &mut Session,
265        reused: bool,
266        peer: &HttpPeer,
267        #[cfg(unix)] _fd: std::os::unix::io::RawFd,
268        #[cfg(windows)] _sock: std::os::windows::io::RawSocket,
269        _digest: Option<&Digest>,
270        ctx: &mut Self::CTX,
271    ) -> Result<()> {
272        let backend_id = ctx.req_unsafe().backend.destination.name.to_string();
273
274        info!(backend_id, reused, peer = %peer, "Connected to upstream");
275        Ok(())
276    }
277
278    async fn response_filter(
279        &self,
280        session: &mut Session,
281        upstream_response: &mut ResponseHeader,
282        ctx: &mut Self::CTX,
283    ) -> Result<()> {
284        if let Some(resp_headers) = ctx.req_unsafe_mut().response_headers.take() {
285            for (key, val) in resp_headers {
286                let _ = upstream_response.insert_header(key, val);
287            }
288        }
289
290        {
291            let req = ctx.req_unsafe_mut();
292
293            let runner = req.plugin_runner.clone();
294
295            runner
296                .run_response_filters(session, req, upstream_response)
297                .await;
298
299            if !req.cardinal_context.config.server.log_upstream_response {
300                return Ok(());
301            }
302
303            let status = upstream_response.status.as_u16();
304            let location = upstream_response
305                .headers
306                .get("location")
307                .and_then(|v| v.to_str().ok())
308                .map(|s| s.to_string());
309            let backend_id = &req.backend.destination.name;
310            if let Some(loc) = location {
311                info!(backend_id, status, location = %loc, "Upstream responded");
312            } else {
313                info!(backend_id, status, "Upstream responded");
314            }
315        }
316
317        ctx.set("status", upstream_response.status.as_str());
318
319        Ok(())
320    }
321}