cardinal_proxy/
lib.rs

1pub mod context_provider;
2pub mod req;
3pub mod retry;
4mod utils;
5
6use crate::context_provider::CardinalContextProvider;
7use crate::req::ReqCtx;
8use crate::retry::RetryState;
9use crate::utils::requests::{
10    compose_upstream_url, execution_context_from_request, parse_origin, rewrite_request_path,
11    set_upstream_host_headers,
12};
13use bytes::Bytes;
14use cardinal_base::context::CardinalContext;
15use cardinal_base::destinations::container::DestinationContainer;
16use cardinal_plugins::plugin_executor::CardinalPluginExecutor;
17use cardinal_plugins::request_context::RequestContext;
18use cardinal_plugins::runner::MiddlewareResult;
19use pingora::http::ResponseHeader;
20use pingora::prelude::*;
21use pingora::protocols::Digest;
22use pingora::upstreams::peer::Peer;
23use std::sync::Arc;
24use std::time::Duration;
25use tracing::{debug, error, info, warn};
26
27pub mod pingora {
28    pub use pingora::*;
29}
30
/// Outcome of the context provider's health-check probe for an incoming request.
///
/// `CardinalProxy::request_filter` answers the request directly for `Ready` and
/// `Unavailable`; for `None` it continues with normal routing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HealthCheckStatus {
    /// No health-check response is produced; routing continues normally.
    None,
    /// Healthy: the proxy answers `200 OK` with a small plain-text body.
    Ready,
    /// Unhealthy: the proxy answers with `status_code`.
    Unavailable {
        /// HTTP status code sent to the client.
        status_code: u16,
        /// Optional explanation; it is logged by the proxy, not sent to the client.
        reason: Option<String>,
    },
}
40
41#[derive(Clone)]
42pub struct StaticContextProvider {
43    context: Arc<CardinalContext>,
44}
45
46impl StaticContextProvider {
47    pub fn new(context: Arc<CardinalContext>) -> Self {
48        Self { context }
49    }
50}
51
52impl CardinalContextProvider for StaticContextProvider {
53    fn resolve(&self, _session: &Session, _ctx: &mut ReqCtx) -> Option<Arc<CardinalContext>> {
54        Some(self.context.clone())
55    }
56}
57
58#[async_trait::async_trait]
59impl CardinalPluginExecutor for StaticContextProvider {}
60
61pub struct CardinalProxy {
62    provider: Arc<dyn CardinalContextProvider>,
63    plugin_executor: Arc<dyn CardinalPluginExecutor>,
64}
65
66impl CardinalProxy {
67    pub fn new(context: Arc<CardinalContext>) -> Self {
68        Self::builder(context).build()
69    }
70
71    pub fn with_provider(
72        provider: Arc<dyn CardinalContextProvider>,
73        plugin_executor: Arc<dyn CardinalPluginExecutor>,
74    ) -> Self {
75        Self {
76            provider,
77            plugin_executor,
78        }
79    }
80
81    pub fn builder(context: Arc<CardinalContext>) -> CardinalProxyBuilder {
82        CardinalProxyBuilder::new(context)
83    }
84}
85
86pub struct CardinalProxyBuilder {
87    provider: Arc<dyn CardinalContextProvider>,
88    plugin_executor: Arc<dyn CardinalPluginExecutor>,
89}
90
91impl CardinalProxyBuilder {
92    pub fn new(context: Arc<CardinalContext>) -> Self {
93        Self {
94            provider: Arc::new(StaticContextProvider::new(context.clone())),
95            plugin_executor: Arc::new(StaticContextProvider::new(context)),
96        }
97    }
98
99    pub fn from_context_provider(
100        provider: Arc<dyn CardinalContextProvider>,
101        plugin_executor: Arc<dyn CardinalPluginExecutor>,
102    ) -> Self {
103        Self {
104            provider,
105            plugin_executor,
106        }
107    }
108
109    pub fn with_context_provider(
110        mut self,
111        provider: Arc<dyn CardinalContextProvider>,
112        plugin_executor: Arc<dyn CardinalPluginExecutor>,
113    ) -> Self {
114        self.provider = provider;
115        self.plugin_executor = plugin_executor;
116        self
117    }
118
119    pub fn build(self) -> CardinalProxy {
120        CardinalProxy::with_provider(self.provider, self.plugin_executor)
121    }
122}
123
124#[async_trait::async_trait]
125impl ProxyHttp for CardinalProxy {
126    type CTX = ReqCtx;
127
128    fn new_ctx(&self) -> Self::CTX {
129        self.provider.ctx()
130    }
131
132    async fn early_request_filter(&self, _session: &mut Session, _ctx: &mut Self::CTX) -> Result<()>
133    where
134        Self::CTX: Send + Sync,
135    {
136        self.provider.early_request_filter(_session, _ctx).await
137    }
138
139    async fn logging(&self, _session: &mut Session, _e: Option<&Error>, ctx: &mut Self::CTX)
140    where
141        Self::CTX: Send + Sync,
142    {
143        self.provider.logging(_session, _e, ctx);
144    }
145
146    async fn request_body_filter(
147        &self,
148        _session: &mut Session,
149        _body: &mut Option<Bytes>,
150        _end_of_stream: bool,
151        _ctx: &mut Self::CTX,
152    ) -> Result<()>
153    where
154        Self::CTX: Send + Sync,
155    {
156        self.provider
157            .request_body_filter(_session, _body, _end_of_stream, _ctx)
158            .await
159    }
160
161    fn response_body_filter(
162        &self,
163        _session: &mut Session,
164        _body: &mut Option<Bytes>,
165        _end_of_stream: bool,
166        _ctx: &mut Self::CTX,
167    ) -> Result<Option<Duration>>
168    where
169        Self::CTX: Send + Sync,
170    {
171        self.provider
172            .response_body_filter(_session, _body, _end_of_stream, _ctx)
173    }
174
175    async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<bool> {
176        let path = session.req_header().uri.path().to_string();
177        info!(%path, "Request received");
178
179        match self.provider.health_check(session) {
180            HealthCheckStatus::None => {}
181            HealthCheckStatus::Ready => {
182                debug!(%path, "Health check ready");
183                // Build a 200 OK header
184                let mut resp = ResponseHeader::build(200, None)?;
185                resp.insert_header("Content-Type", "text/plain")?;
186                resp.set_content_length("healthy\n".len())?;
187
188                // Send header + body to the client
189                session
190                    .write_response_header(Box::new(resp), /*end_of_stream*/ false)
191                    .await?;
192                session
193                    .write_response_body(Some(Bytes::from_static(b"healthy\n")), /*end*/ true)
194                    .await?;
195
196                // Returning Ok(true) means "handled", stop further processing.
197                return Ok(true);
198            }
199            HealthCheckStatus::Unavailable {
200                status_code,
201                reason,
202            } => {
203                if let Some(reason) = reason {
204                    warn!(%path, status = status_code, reason = %reason, "Health check failed");
205                } else {
206                    warn!(%path, status = status_code, "Health check failed");
207                }
208                let _ = session.respond_error(status_code).await;
209                return Ok(true);
210            }
211        }
212
213        let context = match self.provider.resolve(session, ctx) {
214            Some(ctx) => ctx,
215            None => {
216                warn!(%path, "No context found for request host");
217                let _ = session.respond_error(421).await;
218                return Ok(true);
219            }
220        };
221
222        let destination_container = context
223            .get::<DestinationContainer>()
224            .await
225            .map_err(|_| Error::new_str("Destination Container is not present"))?;
226
227        let force_path = context.config.server.force_path_parameter;
228        let backend =
229            match destination_container.get_backend_for_request(session.req_header(), force_path) {
230                Some(b) => b,
231                None => {
232                    warn!(%path, "No matching backend, returning 404");
233                    let _ = session.respond_error(404).await;
234                    return Ok(true);
235                }
236            };
237
238        let destination_name = backend.destination.name.clone();
239        let _ = set_upstream_host_headers(session, &backend);
240        info!(backend_id = %destination_name, "Routing to backend");
241
242        rewrite_request_path(session.req_header_mut(), &destination_name, force_path);
243
244        let mut request_state = RequestContext::new(
245            context.clone(),
246            backend,
247            execution_context_from_request(session),
248            self.plugin_executor.clone(),
249        );
250
251        self.provider
252            .request_context_initialized(session, &mut request_state)
253            .await;
254
255        let plugin_runner = request_state.plugin_runner.clone();
256
257        let run_filters = plugin_runner
258            .run_request_filters(session, &mut request_state)
259            .await;
260
261        let res = match run_filters {
262            Ok(filter_result) => filter_result,
263            Err(err) => {
264                error!(%err, "Error running request filters");
265                let _ = session.respond_error(500).await;
266                return Ok(true);
267            }
268        };
269
270        ctx.set_resolved_request(request_state);
271
272        match res {
273            MiddlewareResult::Continue(resp_headers) => {
274                ctx.ctx_base
275                    .resolved_request
276                    .as_mut()
277                    .unwrap()
278                    .response_headers = Some(resp_headers);
279
280                Ok(false)
281            }
282            MiddlewareResult::Responded => Ok(true),
283        }
284    }
285
286    fn fail_to_connect(
287        &self,
288        _session: &mut Session,
289        _peer: &HttpPeer,
290        ctx: &mut Self::CTX,
291        mut e: Box<Error>,
292    ) -> Box<Error> {
293        let backend_config = ctx.req_unsafe().backend.destination.retry.clone();
294        if let Some(mut retry_state) = ctx.retry_state.take() {
295            retry_state.register_attempt();
296            if retry_state.can_retry() {
297                e.set_retry(true);
298                ctx.retry_state = Some(retry_state);
299            } else {
300                ctx.retry_state = None;
301            }
302        } else if let Some(retry_config) = backend_config {
303            let mut retry_state = RetryState::from(retry_config);
304            retry_state.register_attempt();
305            if retry_state.can_retry() {
306                e.set_retry(true);
307                ctx.retry_state = Some(retry_state);
308            } else {
309                ctx.retry_state = None;
310            }
311        }
312
313        e
314    }
315
316    async fn upstream_peer(
317        &self,
318        _session: &mut Session,
319        ctx: &mut Self::CTX,
320    ) -> Result<Box<HttpPeer>> {
321        if let Some(retry_state) = ctx.retry_state.as_mut() {
322            if !retry_state.sleep_if_retry_allowed().await {
323                ctx.retry_state = None;
324                return Err(Error::new_str("Retry attempts exhausted"));
325            }
326        }
327
328        let backend = &ctx.req_unsafe().backend;
329        // Determine origin parts for TLS and SNI
330        let (host, port, is_tls) = parse_origin(&backend.destination.url)
331            .map_err(|_| Error::new_str("Origin could not be parsed "))?;
332        let hostport = format!("{host}:{port}");
333
334        // Compose full upstream URL for logging with normalized scheme
335        let path_and_query = _session
336            .req_header()
337            .uri
338            .path_and_query()
339            .map(|pq| pq.as_str())
340            .unwrap_or("/");
341        let upstream_url = compose_upstream_url(is_tls, &host, port, path_and_query);
342
343        info!(%upstream_url, backend_id = %&backend.destination.name, is_tls, sni = %host, "Forwarding to upstream");
344        debug!(upstream_origin = %hostport, "Connecting to upstream origin");
345
346        let mut peer = HttpPeer::new(&hostport, is_tls, host);
347        if let Some(opts) = peer.get_mut_peer_options() {
348            // Allow both HTTP/1.1 and HTTP/2 so plain HTTP backends keep working.
349            opts.set_http_version(2, 1);
350            if let Some(timeout) = &backend.destination.timeout {
351                opts.idle_timeout = timeout
352                    .idle
353                    .as_ref()
354                    .map(|idle| Duration::from_millis(*idle));
355                opts.write_timeout = timeout
356                    .write
357                    .as_ref()
358                    .map(|idle| Duration::from_millis(*idle));
359                opts.total_connection_timeout = timeout
360                    .connect
361                    .as_ref()
362                    .map(|idle| Duration::from_millis(*idle));
363                opts.read_timeout = timeout
364                    .read
365                    .as_ref()
366                    .map(|idle| Duration::from_millis(*idle));
367            }
368        }
369        let peer = Box::new(peer);
370        Ok(peer)
371    }
372
373    async fn connected_to_upstream(
374        &self,
375        _session: &mut Session,
376        reused: bool,
377        peer: &HttpPeer,
378        #[cfg(unix)] _fd: std::os::unix::io::RawFd,
379        #[cfg(windows)] _sock: std::os::windows::io::RawSocket,
380        _digest: Option<&Digest>,
381        ctx: &mut Self::CTX,
382    ) -> Result<()> {
383        ctx.retry_state = None;
384        let backend_id = ctx.req_unsafe().backend.destination.name.to_string();
385
386        info!(backend_id, reused, peer = %peer, "Connected to upstream");
387        Ok(())
388    }
389
390    async fn response_filter(
391        &self,
392        session: &mut Session,
393        upstream_response: &mut ResponseHeader,
394        ctx: &mut Self::CTX,
395    ) -> Result<()> {
396        if let Some(resp_headers) = ctx.req_unsafe_mut().response_headers.take() {
397            for (key, val) in resp_headers {
398                let _ = upstream_response.insert_header(key, val);
399            }
400        }
401
402        {
403            // Run response filters first
404            {
405                let runner = {
406                    let req = ctx.req_unsafe_mut();
407                    req.plugin_runner.clone()
408                };
409
410                runner
411                    .run_response_filters(
412                        session,
413                        {
414                            let req = ctx.req_unsafe_mut();
415                            req
416                        },
417                        upstream_response,
418                    )
419                    .await;
420            }
421
422            ctx.set("status", upstream_response.status.as_str());
423
424            // Safe to get another mutable reference now
425            let req = ctx.req_unsafe_mut();
426
427            if !req.cardinal_context.config.server.log_upstream_response {
428                return Ok(());
429            }
430
431            let status = upstream_response.status.as_u16();
432            let location = upstream_response
433                .headers
434                .get("location")
435                .and_then(|v| v.to_str().ok())
436                .map(str::to_string);
437            let backend_id = &req.backend.destination.name;
438
439            match location {
440                Some(loc) => info!(backend_id, status, location = %loc, "Upstream responded"),
441                None => info!(backend_id, status, "Upstream responded"),
442            }
443        }
444
445        Ok(())
446    }
447}