cardinal_proxy/
lib.rs

1mod utils;
2
3use crate::utils::requests::{
4    compose_upstream_url, execution_context_from_request, parse_origin, rewrite_request_path,
5    set_upstream_host_headers,
6};
7use bytes::Bytes;
8use cardinal_base::context::CardinalContext;
9use cardinal_base::destinations::container::DestinationContainer;
10use cardinal_plugins::request_context::{RequestContext, RequestContextBase};
11use cardinal_plugins::runner::MiddlewareResult;
12use pingora::http::ResponseHeader;
13use pingora::prelude::*;
14use pingora::protocols::Digest;
15use pingora::upstreams::peer::Peer;
16use std::sync::Arc;
17use tracing::{debug, error, info, warn};
18
19pub mod pingora {
20    pub use pingora::*;
21}
22
/// Outcome of [`CardinalContextProvider::health_check`] for a single request.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare statuses directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HealthCheckStatus {
    /// The request is not a health check; the proxy continues with normal routing.
    None,
    /// The request is a health check and the service is healthy; the proxy answers it
    /// directly with `200` and the body `healthy\n`.
    Ready,
    /// The request is a health check and the service is unavailable; the proxy answers it
    /// directly with an error response.
    Unavailable {
        /// HTTP status code sent back to the client.
        status_code: u16,
        /// Optional explanation; the proxy only writes it to the warning log.
        reason: Option<String>,
    },
}
32
33pub trait CardinalContextProvider: Send + Sync {
34    fn resolve(&self, session: &Session) -> Option<Arc<CardinalContext>>;
35    fn health_check(&self, _session: &Session) -> HealthCheckStatus {
36        HealthCheckStatus::None
37    }
38
39    fn logging(&self, _session: &mut Session, _e: Option<&Error>, _ctx: &mut RequestContextBase) {}
40}
41
42#[derive(Clone)]
43pub struct StaticContextProvider {
44    context: Arc<CardinalContext>,
45}
46
47impl StaticContextProvider {
48    pub fn new(context: Arc<CardinalContext>) -> Self {
49        Self { context }
50    }
51}
52
53impl CardinalContextProvider for StaticContextProvider {
54    fn resolve(&self, _session: &Session) -> Option<Arc<CardinalContext>> {
55        Some(self.context.clone())
56    }
57}
58
59pub struct CardinalProxy {
60    provider: Arc<dyn CardinalContextProvider>,
61}
62
63impl CardinalProxy {
64    pub fn new(context: Arc<CardinalContext>) -> Self {
65        Self::builder(context).build()
66    }
67
68    pub fn with_provider(provider: Arc<dyn CardinalContextProvider>) -> Self {
69        Self { provider }
70    }
71
72    pub fn builder(context: Arc<CardinalContext>) -> CardinalProxyBuilder {
73        CardinalProxyBuilder::new(context)
74    }
75}
76
77pub struct CardinalProxyBuilder {
78    provider: Arc<dyn CardinalContextProvider>,
79}
80
81impl CardinalProxyBuilder {
82    pub fn new(context: Arc<CardinalContext>) -> Self {
83        Self {
84            provider: Arc::new(StaticContextProvider::new(context)),
85        }
86    }
87
88    pub fn from_context_provider(provider: Arc<dyn CardinalContextProvider>) -> Self {
89        Self { provider }
90    }
91
92    pub fn with_context_provider(mut self, provider: Arc<dyn CardinalContextProvider>) -> Self {
93        self.provider = provider;
94        self
95    }
96
97    pub fn build(self) -> CardinalProxy {
98        CardinalProxy::with_provider(self.provider)
99    }
100}
101
102#[async_trait::async_trait]
103impl ProxyHttp for CardinalProxy {
104    type CTX = RequestContextBase;
105
106    fn new_ctx(&self) -> Self::CTX {
107        RequestContextBase::default()
108    }
109
110    async fn logging(&self, _session: &mut Session, _e: Option<&Error>, ctx: &mut Self::CTX)
111    where
112        Self::CTX: Send + Sync,
113    {
114        self.provider.logging(_session, _e, ctx);
115    }
116
117    async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<bool> {
118        let path = session.req_header().uri.path().to_string();
119        info!(%path, "Request received");
120
121        match self.provider.health_check(session) {
122            HealthCheckStatus::None => {}
123            HealthCheckStatus::Ready => {
124                debug!(%path, "Health check ready");
125                // Build a 200 OK header
126                let mut resp = ResponseHeader::build(200, None)?;
127                resp.insert_header("Content-Type", "text/plain")?;
128                resp.set_content_length("healthy\n".len())?;
129
130                // Send header + body to the client
131                session
132                    .write_response_header(Box::new(resp), /*end_of_stream*/ false)
133                    .await?;
134                session
135                    .write_response_body(Some(Bytes::from_static(b"healthy\n")), /*end*/ true)
136                    .await?;
137
138                // Returning Ok(true) means "handled", stop further processing.
139                return Ok(true);
140            }
141            HealthCheckStatus::Unavailable {
142                status_code,
143                reason,
144            } => {
145                if let Some(reason) = reason {
146                    warn!(%path, status = status_code, reason = %reason, "Health check failed");
147                } else {
148                    warn!(%path, status = status_code, "Health check failed");
149                }
150                let _ = session.respond_error(status_code).await;
151                return Ok(true);
152            }
153        }
154
155        let context = match self.provider.resolve(session) {
156            Some(ctx) => ctx,
157            None => {
158                warn!(%path, "No context found for request host");
159                let _ = session.respond_error(421).await;
160                return Ok(true);
161            }
162        };
163
164        let destination_container = context
165            .get::<DestinationContainer>()
166            .await
167            .map_err(|_| Error::new_str("Destination Container is not present"))?;
168
169        let force_path = context.config.server.force_path_parameter;
170        let backend =
171            match destination_container.get_backend_for_request(session.req_header(), force_path) {
172                Some(b) => b,
173                None => {
174                    warn!(%path, "No matching backend, returning 404");
175                    let _ = session.respond_error(404).await;
176                    return Ok(true);
177                }
178            };
179
180        let destination_name = backend.destination.name.clone();
181        let _ = set_upstream_host_headers(session, &backend);
182        info!(backend_id = %destination_name, "Routing to backend");
183
184        rewrite_request_path(session.req_header_mut(), &destination_name, force_path);
185
186        let mut request_state = RequestContext::new(
187            context.clone(),
188            backend,
189            execution_context_from_request(session),
190        );
191
192        let plugin_runner = request_state.plugin_runner.clone();
193
194        let run_filters = plugin_runner
195            .run_request_filters(session, &mut request_state)
196            .await;
197
198        let res = match run_filters {
199            Ok(filter_result) => filter_result,
200            Err(err) => {
201                error!(%err, "Error running request filters");
202                let _ = session.respond_error(500).await;
203                return Ok(true);
204            }
205        };
206
207        ctx.set_resolved_request(request_state);
208
209        match res {
210            MiddlewareResult::Continue(resp_headers) => {
211                ctx.resolved_request.as_mut().unwrap().response_headers = Some(resp_headers);
212
213                Ok(false)
214            }
215            MiddlewareResult::Responded => Ok(true),
216        }
217    }
218
219    async fn upstream_peer(
220        &self,
221        _session: &mut Session,
222        ctx: &mut Self::CTX,
223    ) -> Result<Box<HttpPeer>> {
224        // Determine origin parts for TLS and SNI
225        let (host, port, is_tls) = parse_origin(&ctx.req_unsafe().backend.destination.url)
226            .map_err(|_| Error::new_str("Origin could not be parsed "))?;
227        let hostport = format!("{host}:{port}");
228
229        // Compose full upstream URL for logging with normalized scheme
230        let path_and_query = _session
231            .req_header()
232            .uri
233            .path_and_query()
234            .map(|pq| pq.as_str())
235            .unwrap_or("/");
236        let upstream_url = compose_upstream_url(is_tls, &host, port, path_and_query);
237
238        info!(%upstream_url, backend_id = %ctx.req_unsafe().backend.destination.name, is_tls, sni = %host, "Forwarding to upstream");
239        debug!(upstream_origin = %hostport, "Connecting to upstream origin");
240
241        let mut peer = HttpPeer::new(&hostport, is_tls, host);
242        if let Some(opts) = peer.get_mut_peer_options() {
243            // Allow both HTTP/1.1 and HTTP/2 so plain HTTP backends keep working.
244            opts.set_http_version(2, 1);
245        }
246        let peer = Box::new(peer);
247        Ok(peer)
248    }
249
250    async fn connected_to_upstream(
251        &self,
252        _session: &mut Session,
253        reused: bool,
254        peer: &HttpPeer,
255        #[cfg(unix)] _fd: std::os::unix::io::RawFd,
256        #[cfg(windows)] _sock: std::os::windows::io::RawSocket,
257        _digest: Option<&Digest>,
258        ctx: &mut Self::CTX,
259    ) -> Result<()> {
260        let backend_id = ctx.req_unsafe().backend.destination.name.to_string();
261
262        info!(backend_id, reused, peer = %peer, "Connected to upstream");
263        Ok(())
264    }
265
266    async fn response_filter(
267        &self,
268        session: &mut Session,
269        upstream_response: &mut ResponseHeader,
270        ctx: &mut Self::CTX,
271    ) -> Result<()> {
272        if let Some(resp_headers) = ctx.req_unsafe_mut().response_headers.take() {
273            for (key, val) in resp_headers {
274                let _ = upstream_response.insert_header(key, val);
275            }
276        }
277
278        let req = ctx.req_unsafe_mut();
279
280        let runner = req.plugin_runner.clone();
281
282        runner
283            .run_response_filters(session, req, upstream_response)
284            .await;
285
286        if !req.cardinal_context.config.server.log_upstream_response {
287            return Ok(());
288        }
289
290        let status = upstream_response.status.as_u16();
291        let location = upstream_response
292            .headers
293            .get("location")
294            .and_then(|v| v.to_str().ok())
295            .map(|s| s.to_string());
296        let backend_id = &req.backend.destination.name;
297        if let Some(loc) = location {
298            info!(backend_id, status, location = %loc, "Upstream responded");
299        } else {
300            info!(backend_id, status, "Upstream responded");
301        }
302
303        Ok(())
304    }
305}