cardinal_proxy/
lib.rs

1mod utils;
2
3use crate::utils::requests::{
4    compose_upstream_url, parse_origin, rewrite_request_path, set_upstream_host_headers,
5};
6use cardinal_base::context::CardinalContext;
7use cardinal_base::destinations::container::{DestinationContainer, DestinationWrapper};
8use cardinal_plugins::runner::{MiddlewareResult, PluginRunner};
9use pingora::http::ResponseHeader;
10use pingora::prelude::*;
11use pingora::protocols::Digest;
12use pingora::upstreams::peer::Peer;
13use std::sync::Arc;
14use tracing::{debug, error, info, warn};
15
16pub mod pingora {
17    pub use pingora::*;
18}
19
20pub trait CardinalContextProvider: Send + Sync {
21    fn resolve(&self, session: &Session) -> Option<Arc<CardinalContext>>;
22}
23
24#[derive(Clone)]
25pub struct StaticContextProvider {
26    context: Arc<CardinalContext>,
27}
28
29impl StaticContextProvider {
30    pub fn new(context: Arc<CardinalContext>) -> Self {
31        Self { context }
32    }
33}
34
35impl CardinalContextProvider for StaticContextProvider {
36    fn resolve(&self, _session: &Session) -> Option<Arc<CardinalContext>> {
37        Some(self.context.clone())
38    }
39}
40
41pub struct CardinalProxy {
42    provider: Arc<dyn CardinalContextProvider>,
43}
44
45impl CardinalProxy {
46    pub fn new(context: Arc<CardinalContext>) -> Self {
47        Self::builder(context).build()
48    }
49
50    pub fn with_provider(provider: Arc<dyn CardinalContextProvider>) -> Self {
51        Self { provider }
52    }
53
54    pub fn builder(context: Arc<CardinalContext>) -> CardinalProxyBuilder {
55        CardinalProxyBuilder::new(context)
56    }
57}
58
59pub struct CardinalProxyBuilder {
60    provider: Arc<dyn CardinalContextProvider>,
61}
62
63impl CardinalProxyBuilder {
64    pub fn new(context: Arc<CardinalContext>) -> Self {
65        Self {
66            provider: Arc::new(StaticContextProvider::new(context)),
67        }
68    }
69
70    pub fn from_context_provider(provider: Arc<dyn CardinalContextProvider>) -> Self {
71        Self { provider }
72    }
73
74    pub fn with_context_provider(mut self, provider: Arc<dyn CardinalContextProvider>) -> Self {
75        self.provider = provider;
76        self
77    }
78
79    pub fn build(self) -> CardinalProxy {
80        CardinalProxy::with_provider(self.provider)
81    }
82}
83
84#[async_trait::async_trait]
85impl ProxyHttp for CardinalProxy {
86    type CTX = Option<RequestContext>;
87
88    fn new_ctx(&self) -> Self::CTX {
89        None
90    }
91
92    async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<bool> {
93        let path = session.req_header().uri.path().to_string();
94        info!(%path, "Request received");
95
96        let context = match self.provider.resolve(session) {
97            Some(ctx) => ctx,
98            None => {
99                warn!(%path, "No context found for request host");
100                let _ = session.respond_error(421).await;
101                return Ok(true);
102            }
103        };
104
105        let destination_container = context
106            .get::<DestinationContainer>()
107            .await
108            .map_err(|_| Error::new_str("Destination Container is not present"))?;
109
110        let force_path = context.config.server.force_path_parameter;
111        let backend =
112            match destination_container.get_backend_for_request(session.req_header(), force_path) {
113                Some(b) => b.clone(),
114                None => {
115                    warn!(%path, "No matching backend, returning 404");
116                    let _ = session.respond_error(404).await;
117                    return Ok(true);
118                }
119            };
120
121        let destination_name = backend.destination.name.clone();
122        let _ = set_upstream_host_headers(session, &backend);
123        info!(backend_id = %destination_name, "Routing to backend");
124
125        rewrite_request_path(session.req_header_mut(), &destination_name, force_path);
126
127        let mut request_state = RequestContext::new(context.clone());
128        request_state.backend = Some(backend.clone());
129
130        let run_filters = request_state
131            .runner
132            .run_request_filters(session, backend)
133            .await;
134
135        let res = match run_filters {
136            Ok(filter_result) => filter_result,
137            Err(err) => {
138                error!(%err, "Error running request filters");
139                let _ = session.respond_error(500).await;
140                return Ok(true);
141            }
142        };
143
144        *ctx = Some(request_state);
145
146        match res {
147            MiddlewareResult::Continue => Ok(false),
148            MiddlewareResult::Responded => Ok(true),
149        }
150    }
151
152    async fn upstream_peer(
153        &self,
154        _session: &mut Session,
155        ctx: &mut Self::CTX,
156    ) -> Result<Box<HttpPeer>> {
157        if let Some(state) = ctx.as_ref() {
158            if let Some(backend) = state.backend.as_ref() {
159                // Determine origin parts for TLS and SNI
160                let (host, port, is_tls) = parse_origin(&backend.destination.url)
161                    .map_err(|_| Error::new_str("Origin could not be parsed "))?;
162                let hostport = format!("{host}:{port}");
163
164                // Compose full upstream URL for logging with normalized scheme
165                let path_and_query = _session
166                    .req_header()
167                    .uri
168                    .path_and_query()
169                    .map(|pq| pq.as_str())
170                    .unwrap_or("/");
171                let upstream_url = compose_upstream_url(is_tls, &host, port, path_and_query);
172
173                info!(%upstream_url, backend_id = %backend.destination.name, is_tls, sni = %host, "Forwarding to upstream");
174                debug!(upstream_origin = %hostport, "Connecting to upstream origin");
175
176                let mut peer = HttpPeer::new(&hostport, is_tls, host);
177                if let Some(opts) = peer.get_mut_peer_options() {
178                    // Allow both HTTP/1.1 and HTTP/2 so plain HTTP backends keep working.
179                    opts.set_http_version(2, 1);
180                }
181                let peer = Box::new(peer);
182                Ok(peer)
183            } else {
184                Err(Error::new(ErrorType::InternalError))
185            }
186        } else {
187            Err(Error::new(ErrorType::InternalError))
188        }
189    }
190
191    async fn connected_to_upstream(
192        &self,
193        _session: &mut Session,
194        reused: bool,
195        peer: &HttpPeer,
196        #[cfg(unix)] _fd: std::os::unix::io::RawFd,
197        #[cfg(windows)] _sock: std::os::windows::io::RawSocket,
198        _digest: Option<&Digest>,
199        ctx: &mut Self::CTX,
200    ) -> Result<()> {
201        let backend_id = ctx
202            .as_ref()
203            .and_then(|state| state.backend.as_ref())
204            .map(|b| b.destination.name.as_str())
205            .unwrap_or("<unknown>");
206        info!(backend_id, reused, peer = %peer, "Connected to upstream");
207        Ok(())
208    }
209
210    async fn response_filter(
211        &self,
212        session: &mut Session,
213        upstream_response: &mut ResponseHeader,
214        ctx: &mut Self::CTX,
215    ) -> Result<()> {
216        if let Some(state) = ctx.as_mut() {
217            if let Some(backend) = state.backend.clone() {
218                state
219                    .runner
220                    .run_response_filters(session, backend, upstream_response)
221                    .await;
222            }
223
224            if !state.context.config.server.log_upstream_response {
225                return Ok(());
226            }
227
228            let status = upstream_response.status.as_u16();
229            let location = upstream_response
230                .headers
231                .get("location")
232                .and_then(|v| v.to_str().ok())
233                .map(|s| s.to_string());
234            let backend_id = state
235                .backend
236                .as_ref()
237                .map(|b| b.destination.name.as_str())
238                .unwrap_or("<unknown>");
239            if let Some(loc) = location {
240                info!(backend_id, status, location = %loc, "Upstream responded");
241            } else {
242                info!(backend_id, status, "Upstream responded");
243            }
244        }
245
246        Ok(())
247    }
248}
249
250pub struct RequestContext {
251    pub context: Arc<CardinalContext>,
252    pub backend: Option<Arc<DestinationWrapper>>,
253    pub runner: PluginRunner,
254}
255
256impl RequestContext {
257    fn new(context: Arc<CardinalContext>) -> Self {
258        let runner = PluginRunner::new(context.clone());
259        Self {
260            context,
261            backend: None,
262            runner,
263        }
264    }
265}