cardinal_base/destinations/container.rs

1use crate::context::CardinalContext;
2use crate::provider::Provider;
3use crate::router::CardinalRouter;
4use async_trait::async_trait;
5use cardinal_config::{Destination, Middleware, MiddlewareType};
6use cardinal_errors::CardinalError;
7use pingora::http::RequestHeader;
8use std::collections::BTreeMap;
9use std::sync::Arc;
10
11pub struct DestinationWrapper {
12    pub destination: Destination,
13    pub router: CardinalRouter,
14    pub has_routes: bool,
15    inbound_middleware: Vec<Middleware>,
16    outbound_middleware: Vec<Middleware>,
17}
18
19impl DestinationWrapper {
20    pub fn new(destination: Destination, router: Option<CardinalRouter>) -> Self {
21        let inbound_middleware = destination
22            .middleware
23            .iter()
24            .filter(|&e| e.r#type == MiddlewareType::Inbound)
25            .cloned()
26            .collect();
27        let outbound_middleware = destination
28            .middleware
29            .iter()
30            .filter(|&e| e.r#type == MiddlewareType::Outbound)
31            .cloned()
32            .collect();
33
34        Self {
35            has_routes: !destination.routes.is_empty(),
36            destination,
37            router: router.unwrap_or_default(),
38            inbound_middleware,
39            outbound_middleware,
40        }
41    }
42
43    pub fn get_inbound_middleware(&self) -> &Vec<Middleware> {
44        &self.inbound_middleware
45    }
46
47    pub fn get_outbound_middleware(&self) -> &Vec<Middleware> {
48        &self.outbound_middleware
49    }
50}
51
52pub struct DestinationContainer {
53    destinations: BTreeMap<String, Arc<DestinationWrapper>>,
54}
55
56impl DestinationContainer {
57    pub fn get_backend_for_request(
58        &self,
59        req: &RequestHeader,
60        force_parameter: bool,
61    ) -> Option<Arc<DestinationWrapper>> {
62        let candidate_id = if !force_parameter {
63            extract_subdomain(req)
64        } else {
65            first_path_segment(req)
66        };
67
68        self.destinations.get(&candidate_id?).cloned()
69    }
70}
71
72#[async_trait]
73impl Provider for DestinationContainer {
74    async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
75        let destinations = ctx
76            .config
77            .destinations
78            .clone()
79            .into_iter()
80            .map(|(key, destination)| {
81                let router =
82                    destination
83                        .routes
84                        .iter()
85                        .fold(CardinalRouter::new(), |mut r, route| {
86                            let _ = r.add(route.method.as_str(), route.path.as_str());
87                            r
88                        });
89                (
90                    key,
91                    Arc::new(DestinationWrapper::new(destination, Some(router))),
92                )
93            })
94            .collect::<BTreeMap<_, _>>();
95
96        Ok(Self { destinations })
97    }
98}
99
100fn first_path_segment(req: &RequestHeader) -> Option<String> {
101    let path = req.uri.path();
102    path.strip_prefix('/')
103        .and_then(|p| p.split('/').next())
104        .filter(|s| !s.is_empty())
105        .map(|s| s.to_ascii_lowercase())
106}
107
108fn extract_subdomain(req: &RequestHeader) -> Option<String> {
109    let host = req.uri.host().map(|h| h.to_string()).or_else(|| {
110        req.headers
111            .get("host")
112            .and_then(|v| v.to_str().ok())
113            .map(|s| s.to_string())
114    })?;
115
116    let host_no_port = host.split(':').next()?.to_ascii_lowercase();
117
118    // Only treat as valid when there is a true subdomain: at least sub.domain.tld
119    let parts: Vec<&str> = host_no_port.split('.').collect();
120    if parts.len() < 3 {
121        return None;
122    }
123
124    let first = parts[0];
125    if first.is_empty() || first == "www" {
126        None
127    } else {
128        Some(first.to_string())
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Method, Uri};

    /// Builds a bare `GET` request for the given path-and-query.
    fn get(path_and_query: &str) -> RequestHeader {
        RequestHeader::build(Method::GET, path_and_query.as_bytes(), None)
            .expect("test request should build")
    }

    /// Builds a `GET` request for `path` that carries an explicit `Host` header.
    fn get_with_host(host: &str, path: &str) -> RequestHeader {
        let mut request = get(path);
        request
            .insert_header("host", host)
            .expect("test host header should be accepted");
        request
    }

    /// Shorthand: first path segment of a `GET` to `path`.
    fn segment_of(path: &str) -> Option<String> {
        first_path_segment(&get(path))
    }

    /// Shorthand: subdomain of a `GET /any` sent with the given `Host` header.
    fn subdomain_of(host: &str) -> Option<String> {
        extract_subdomain(&get_with_host(host, "/any"))
    }

    #[test]
    fn first_segment_basic() {
        assert_eq!(segment_of("/api/users").as_deref(), Some("api"));
    }

    #[test]
    fn first_segment_root_none() {
        assert!(segment_of("/").is_none());
    }

    #[test]
    fn first_segment_case_insensitive() {
        assert_eq!(segment_of("/API/v1").as_deref(), Some("api"));
    }

    #[test]
    fn first_segment_trailing_slash() {
        assert_eq!(segment_of("/api/").as_deref(), Some("api"));
    }

    #[test]
    fn subdomain_from_host_header_basic() {
        assert_eq!(subdomain_of("api.mygateway.com").as_deref(), Some("api"));
    }

    #[test]
    fn subdomain_from_host_header_with_port() {
        assert_eq!(subdomain_of("api.mygateway.com:8080").as_deref(), Some("api"));
    }

    #[test]
    fn subdomain_www_is_ignored() {
        assert!(subdomain_of("www.mygateway.com").is_none());
    }

    #[test]
    fn subdomain_requires_at_least_domain_and_tld() {
        assert!(subdomain_of("localhost").is_none());
    }

    #[test]
    fn apex_domain_returns_none() {
        assert!(subdomain_of("mygateway.com").is_none());
    }

    #[test]
    fn subdomain_from_uri_authority() {
        let mut request = get("/any");
        let absolute: Uri = "http://API.Example.com/any"
            .parse()
            .expect("test uri should parse");
        request.set_uri(absolute);
        assert_eq!(extract_subdomain(&request).as_deref(), Some("api"));
    }
}