cardinal_proxy/
lib.rs

1mod utils;
2
3use crate::utils::requests::{
4    compose_upstream_url, parse_origin, rewrite_request_path, set_upstream_host_headers,
5};
6use cardinal_base::context::CardinalContext;
7use cardinal_base::destinations::container::{DestinationContainer, DestinationWrapper};
8use cardinal_plugins::runner::{MiddlewareResult, PluginRunner};
9use pingora::http::ResponseHeader;
10use pingora::prelude::*;
11use pingora::protocols::Digest;
12use pingora::upstreams::peer::Peer;
13use std::collections::HashMap;
14use std::sync::Arc;
15use tracing::{debug, error, info, warn};
16
17pub mod pingora {
18    pub use pingora::*;
19}
20
21pub trait CardinalContextProvider: Send + Sync {
22    fn resolve(&self, session: &Session) -> Option<Arc<CardinalContext>>;
23}
24
25#[derive(Clone)]
26pub struct StaticContextProvider {
27    context: Arc<CardinalContext>,
28}
29
30impl StaticContextProvider {
31    pub fn new(context: Arc<CardinalContext>) -> Self {
32        Self { context }
33    }
34}
35
36impl CardinalContextProvider for StaticContextProvider {
37    fn resolve(&self, _session: &Session) -> Option<Arc<CardinalContext>> {
38        Some(self.context.clone())
39    }
40}
41
42pub struct CardinalProxy {
43    provider: Arc<dyn CardinalContextProvider>,
44}
45
46impl CardinalProxy {
47    pub fn new(context: Arc<CardinalContext>) -> Self {
48        Self::builder(context).build()
49    }
50
51    pub fn with_provider(provider: Arc<dyn CardinalContextProvider>) -> Self {
52        Self { provider }
53    }
54
55    pub fn builder(context: Arc<CardinalContext>) -> CardinalProxyBuilder {
56        CardinalProxyBuilder::new(context)
57    }
58}
59
60pub struct CardinalProxyBuilder {
61    provider: Arc<dyn CardinalContextProvider>,
62}
63
64impl CardinalProxyBuilder {
65    pub fn new(context: Arc<CardinalContext>) -> Self {
66        Self {
67            provider: Arc::new(StaticContextProvider::new(context)),
68        }
69    }
70
71    pub fn from_context_provider(provider: Arc<dyn CardinalContextProvider>) -> Self {
72        Self { provider }
73    }
74
75    pub fn with_context_provider(mut self, provider: Arc<dyn CardinalContextProvider>) -> Self {
76        self.provider = provider;
77        self
78    }
79
80    pub fn build(self) -> CardinalProxy {
81        CardinalProxy::with_provider(self.provider)
82    }
83}
84
85#[async_trait::async_trait]
86impl ProxyHttp for CardinalProxy {
87    type CTX = Option<RequestContext>;
88
89    fn new_ctx(&self) -> Self::CTX {
90        None
91    }
92
93    async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<bool> {
94        let path = session.req_header().uri.path().to_string();
95        info!(%path, "Request received");
96
97        let context = match self.provider.resolve(session) {
98            Some(ctx) => ctx,
99            None => {
100                warn!(%path, "No context found for request host");
101                let _ = session.respond_error(421).await;
102                return Ok(true);
103            }
104        };
105
106        let destination_container = context
107            .get::<DestinationContainer>()
108            .await
109            .map_err(|_| Error::new_str("Destination Container is not present"))?;
110
111        let force_path = context.config.server.force_path_parameter;
112        let backend =
113            match destination_container.get_backend_for_request(session.req_header(), force_path) {
114                Some(b) => b.clone(),
115                None => {
116                    warn!(%path, "No matching backend, returning 404");
117                    let _ = session.respond_error(404).await;
118                    return Ok(true);
119                }
120            };
121
122        let destination_name = backend.destination.name.clone();
123        let _ = set_upstream_host_headers(session, &backend);
124        info!(backend_id = %destination_name, "Routing to backend");
125
126        rewrite_request_path(session.req_header_mut(), &destination_name, force_path);
127
128        let mut request_state = RequestContext::new(context.clone());
129        request_state.backend = Some(backend.clone());
130
131        let run_filters = request_state
132            .runner
133            .run_request_filters(session, backend)
134            .await;
135
136        let res = match run_filters {
137            Ok(filter_result) => filter_result,
138            Err(err) => {
139                error!(%err, "Error running request filters");
140                let _ = session.respond_error(500).await;
141                return Ok(true);
142            }
143        };
144
145        *ctx = Some(request_state);
146
147        match res {
148            MiddlewareResult::Continue(resp_headers) => {
149                ctx.as_mut().unwrap().response_headers = Some(resp_headers);
150
151                Ok(false)
152            }
153            MiddlewareResult::Responded => Ok(true),
154        }
155    }
156
157    async fn upstream_peer(
158        &self,
159        _session: &mut Session,
160        ctx: &mut Self::CTX,
161    ) -> Result<Box<HttpPeer>> {
162        if let Some(state) = ctx.as_ref() {
163            if let Some(backend) = state.backend.as_ref() {
164                // Determine origin parts for TLS and SNI
165                let (host, port, is_tls) = parse_origin(&backend.destination.url)
166                    .map_err(|_| Error::new_str("Origin could not be parsed "))?;
167                let hostport = format!("{host}:{port}");
168
169                // Compose full upstream URL for logging with normalized scheme
170                let path_and_query = _session
171                    .req_header()
172                    .uri
173                    .path_and_query()
174                    .map(|pq| pq.as_str())
175                    .unwrap_or("/");
176                let upstream_url = compose_upstream_url(is_tls, &host, port, path_and_query);
177
178                info!(%upstream_url, backend_id = %backend.destination.name, is_tls, sni = %host, "Forwarding to upstream");
179                debug!(upstream_origin = %hostport, "Connecting to upstream origin");
180
181                let mut peer = HttpPeer::new(&hostport, is_tls, host);
182                if let Some(opts) = peer.get_mut_peer_options() {
183                    // Allow both HTTP/1.1 and HTTP/2 so plain HTTP backends keep working.
184                    opts.set_http_version(2, 1);
185                }
186                let peer = Box::new(peer);
187                Ok(peer)
188            } else {
189                Err(Error::new(ErrorType::InternalError))
190            }
191        } else {
192            Err(Error::new(ErrorType::InternalError))
193        }
194    }
195
196    async fn connected_to_upstream(
197        &self,
198        _session: &mut Session,
199        reused: bool,
200        peer: &HttpPeer,
201        #[cfg(unix)] _fd: std::os::unix::io::RawFd,
202        #[cfg(windows)] _sock: std::os::windows::io::RawSocket,
203        _digest: Option<&Digest>,
204        ctx: &mut Self::CTX,
205    ) -> Result<()> {
206        let backend_id = ctx
207            .as_ref()
208            .and_then(|state| state.backend.as_ref())
209            .map(|b| b.destination.name.as_str())
210            .unwrap_or("<unknown>");
211        info!(backend_id, reused, peer = %peer, "Connected to upstream");
212        Ok(())
213    }
214
215    async fn response_filter(
216        &self,
217        session: &mut Session,
218        upstream_response: &mut ResponseHeader,
219        ctx: &mut Self::CTX,
220    ) -> Result<()> {
221        if let Some(state) = ctx.as_mut() {
222            if let Some(resp_headers) = state.response_headers.take() {
223                for (key, val) in resp_headers {
224                    let _ = upstream_response.insert_header(key, val);
225                }
226            }
227
228            if let Some(backend) = state.backend.clone() {
229                state
230                    .runner
231                    .run_response_filters(session, backend, upstream_response)
232                    .await;
233            }
234
235            if !state.context.config.server.log_upstream_response {
236                return Ok(());
237            }
238
239            let status = upstream_response.status.as_u16();
240            let location = upstream_response
241                .headers
242                .get("location")
243                .and_then(|v| v.to_str().ok())
244                .map(|s| s.to_string());
245            let backend_id = state
246                .backend
247                .as_ref()
248                .map(|b| b.destination.name.as_str())
249                .unwrap_or("<unknown>");
250            if let Some(loc) = location {
251                info!(backend_id, status, location = %loc, "Upstream responded");
252            } else {
253                info!(backend_id, status, "Upstream responded");
254            }
255        }
256
257        Ok(())
258    }
259}
260
261pub struct RequestContext {
262    pub context: Arc<CardinalContext>,
263    pub backend: Option<Arc<DestinationWrapper>>,
264    pub runner: PluginRunner,
265    pub response_headers: Option<HashMap<String, String>>,
266}
267
268impl RequestContext {
269    fn new(context: Arc<CardinalContext>) -> Self {
270        let runner = PluginRunner::new(context.clone());
271        Self {
272            context,
273            backend: None,
274            runner,
275            response_headers: None,
276        }
277    }
278}