use crate::config::models::ServiceProtocol;
use crate::server::service::ServiceManager;
use crate::server::upstream::UpstreamManager;
use arc_swap::ArcSwap;
use jokoway_transformer::{RequestTransformer, ResponseTransformer};
use pingora::http::RequestHeader;
use std::collections::HashMap;
use std::sync::Arc;
pub use crate::server::service::{ALL_PROTOCOLS, HTTP_PROTOCOLS, HTTPS_PROTOCOLS};
/// Result of a successful [`Router::match_request`] lookup.
pub struct RouteMatch {
    /// Upstream to forward to: the matched service's `host` value.
    pub upstream_name: Arc<str>,
    /// Request transformer configured on the matched route, if any.
    pub req_transformer: Option<Arc<dyn RequestTransformer>>,
    /// Response transformer configured on the matched route, if any.
    pub res_transformer: Option<Arc<dyn ResponseTransformer>>,
    /// `max_retries` value copied from the matched route.
    pub max_retries: u32,
}
/// Candidate-service lookup for one host, or for the catch-all service set.
///
/// Every `usize` stored here is a position into the service list returned by
/// `ServiceManager::get_all()` at the time the index was built.
pub struct RouteIndex {
    /// Path pattern (always starting with `/`) -> sorted, deduplicated indices
    /// of the services that may match a request for that path.
    pub path_router: matchit::Router<Vec<usize>>,
    /// Sorted, deduplicated indices of services checked for every path:
    /// services with a route that yields no matchit pattern (or no routes at
    /// all), plus services whose pattern could not be inserted into `path_router`.
    pub fallback_indices: Vec<usize>,
    /// Sorted, deduplicated union of every index held by this structure.
    /// Not consulted by the matching code in this file.
    pub all_indices: Vec<usize>,
}
/// Maps incoming requests to services using host- and path-keyed indices that
/// are rebuilt when the service manager reports a change.
pub struct Router {
    /// Source of the service list that all stored indices point into.
    service_manager: Arc<ServiceManager>,
    /// Exposed through [`Router::upstream_manager`]; not used for matching.
    upstream_manager: Arc<UpstreamManager>,
    /// One route index per host name required by some route.
    host_index: ArcSwap<HashMap<String, Arc<RouteIndex>>>,
    /// Index for services reachable without a specific host (wildcard routes
    /// or services with no routes).
    // NOTE(review): this stores `Arc<Arc<RouteIndex>>`; `ArcSwap<RouteIndex>`
    // would avoid the double indirection.
    catch_all_index: ArcSwap<Arc<RouteIndex>>,
    /// Protocols passed to `ServiceManager::get_indices_for_protocols` when
    /// (re)building the indices.
    protocols: Vec<ServiceProtocol>,
}
impl Router {
    /// Creates a router over the services that `service_manager` selects for
    /// `protocols`, and subscribes it to service-set changes.
    ///
    /// The change callback captures only a `Weak` reference, so it does not
    /// keep the router alive; once the router is dropped the callback is a no-op.
    pub fn new(
        service_manager: Arc<ServiceManager>,
        upstream_manager: Arc<UpstreamManager>,
        protocols: &[ServiceProtocol],
    ) -> Arc<Self> {
        let (host_index, catch_all_index) = Self::build_indices(&service_manager, protocols);
        let router = Arc::new(Router {
            service_manager: Arc::clone(&service_manager),
            upstream_manager,
            host_index: ArcSwap::from_pointee(host_index),
            catch_all_index: ArcSwap::from_pointee(catch_all_index),
            protocols: protocols.to_vec(),
        });
        // NOTE(review): a service change landing between `build_indices` above
        // and this registration is not picked up until the next change —
        // confirm whether the service set can change concurrently here.
        let router_weak = Arc::downgrade(&router);
        service_manager.add_services_changed_callback(move || {
            if let Some(router) = router_weak.upgrade() {
                router.refresh_indices();
            }
        });
        router
    }

    /// Builds the per-host indices and the catch-all index from the service
    /// manager's current state.
    ///
    /// * A service is filed under every host returned by its routes'
    ///   `get_required_hosts()`.
    /// * A service with a wildcard route, or with no routes, is also filed in
    ///   the catch-all index.
    /// * Within an index, a service goes to the fallback list (scanned for
    ///   every path) when any route yields no usable matchit pattern, or when it
    ///   has no routes. Otherwise it is registered under each of its patterns.
    ///
    /// Stored values are positions into `service_manager.get_all()`.
    fn build_indices(
        service_manager: &ServiceManager,
        protocols: &[ServiceProtocol],
    ) -> (HashMap<String, Arc<RouteIndex>>, Arc<RouteIndex>) {
        /// Services collected for one index before the index is built.
        #[derive(Default)]
        struct Bucket {
            /// Services checked for every request path.
            fallback: Vec<usize>,
            /// Path pattern -> services registered under it.
            paths: HashMap<String, Vec<usize>>,
        }

        impl Bucket {
            /// Adds service `idx`: to `fallback` when `scan_always` is set,
            /// otherwise under each pattern in `paths`.
            fn record(&mut self, idx: usize, paths: &[String], scan_always: bool) {
                if scan_always {
                    self.fallback.push(idx);
                } else {
                    for path in paths {
                        self.paths.entry(path.clone()).or_default().push(idx);
                    }
                }
            }
        }

        let service_indices = service_manager.get_indices_for_protocols(protocols);
        let all_services = service_manager.get_all();

        let mut host_buckets: HashMap<String, Bucket> = HashMap::new();
        let mut catch_all = Bucket::default();

        for &idx in &service_indices {
            let service = &all_services[idx];
            // Services without routes are filed under the catch-all fallback list.
            let has_no_routes = service.routes.is_empty();
            let mut is_wildcard = has_no_routes;
            let mut scan_always = has_no_routes;
            let mut hosts: std::collections::HashSet<String> = std::collections::HashSet::new();
            let mut paths: Vec<String> = Vec::new();

            for route in &service.routes {
                let (route_hosts, route_is_wildcard) = route.matcher.get_required_hosts();
                is_wildcard |= route_is_wildcard;
                hosts.extend(route_hosts);
                match route.matcher.get_matchit_routes() {
                    Some(route_paths) => {
                        let before = paths.len();
                        paths.extend(route_paths);
                        // An empty pattern list gives the index nothing to key
                        // on; scan the service rather than silently dropping it
                        // (the route matcher still decides the final match).
                        if paths.len() == before {
                            scan_always = true;
                        }
                    }
                    None => scan_always = true,
                }
            }

            if is_wildcard {
                catch_all.record(idx, &paths, scan_always);
            }
            for host in hosts {
                host_buckets
                    .entry(host)
                    .or_default()
                    .record(idx, &paths, scan_always);
            }
        }

        let host_index: HashMap<String, Arc<RouteIndex>> = host_buckets
            .into_iter()
            .map(|(host, bucket)| (host, Self::build_route_index(bucket.paths, bucket.fallback)))
            .collect();
        let catch_all_index = Self::build_route_index(catch_all.paths, catch_all.fallback);
        (host_index, catch_all_index)
    }

    /// Builds one [`RouteIndex`] from path patterns and a list of services that
    /// must be scanned for every path.
    ///
    /// Patterns are first normalized to start with `/`; patterns that collapse
    /// onto the same key are merged rather than conflicting in matchit. Each
    /// pattern's entry is then widened with the services of every other pattern
    /// of the form `<prefix>/{*rest}` whose `<prefix>/` it starts with.
    /// `matchit::Router::at` yields a single pattern's value, so without this
    /// widening a general prefix service would be hidden behind a more specific
    /// pattern.
    ///
    /// Patterns that matchit rejects are logged and their services are moved to
    /// the fallback list. Every index list in the result is sorted and
    /// deduplicated, which `find_in_route_index` relies on.
    fn build_route_index(
        path_map: HashMap<String, Vec<usize>>,
        mut fallback: Vec<usize>,
    ) -> Arc<RouteIndex> {
        // Normalize before any comparison so prefix widening and matchit
        // insertion see the same keys.
        let mut normalized: HashMap<String, Vec<usize>> = HashMap::with_capacity(path_map.len());
        for (path, indices) in path_map {
            let path = if path.starts_with('/') {
                path
            } else {
                format!("/{path}")
            };
            normalized.entry(path).or_default().extend(indices);
        }

        let mut path_router = matchit::Router::new();
        let mut all_indices = fallback.clone();
        for (path, indices) in &normalized {
            let mut combined = indices.clone();
            for (other_path, other_indices) in &normalized {
                if other_path == path {
                    continue;
                }
                if let Some(prefix) = other_path.strip_suffix("/{*rest}")
                    && path
                        .strip_prefix(prefix)
                        .is_some_and(|rest| rest.starts_with('/'))
                {
                    combined.extend_from_slice(other_indices);
                }
            }
            combined.sort_unstable();
            combined.dedup();
            all_indices.extend_from_slice(&combined);
            if let Err(e) = path_router.insert(path.as_str(), combined.clone()) {
                log::warn!(
                    "Failed to insert route '{}' into matchit: {}. Falling back to linear scan.",
                    path,
                    e
                );
                fallback.extend(combined);
            }
        }

        fallback.sort_unstable();
        fallback.dedup();
        all_indices.sort_unstable();
        all_indices.dedup();
        Arc::new(RouteIndex {
            path_router,
            fallback_indices: fallback,
            all_indices,
        })
    }

    /// Rebuilds both indices from the service manager's current state and
    /// publishes them.
    ///
    /// The two indices are swapped one after the other, so a concurrent
    /// `match_request` may briefly see a new host index together with the old
    /// catch-all index (or the reverse).
    pub fn refresh_indices(&self) {
        let (host_index, catch_all_index) =
            Self::build_indices(&self.service_manager, &self.protocols);
        self.host_index.store(Arc::new(host_index));
        self.catch_all_index.store(Arc::new(catch_all_index));
        log::debug!("Router indices refreshed");
    }

    /// Finds the first route that matches `req_header` and accepts
    /// `client_protocol`.
    ///
    /// Host indices are consulted first, in this order: the URI's host, the
    /// raw `Host` header, and the `Host` header with any `:port` suffix
    /// removed. Each distinct key is tried once. If none of these yields a
    /// match, the catch-all index is searched. A host-specific match therefore
    /// wins even over a catch-all service with a lower index.
    pub fn match_request(
        &self,
        req_header: &RequestHeader,
        client_protocol: ServiceProtocol,
    ) -> Option<RouteMatch> {
        let all_services = self.service_manager.get_all();
        let uri_host = req_header.uri.host();
        let header_host = req_header.headers.get("Host").and_then(|v| v.to_str().ok());
        // Clients send "name:port" in `Host` for non-default ports, while the
        // index is keyed by the host names in route rules, so also try the
        // bare name.
        let candidates = [uri_host, header_host, header_host.map(Self::strip_port)];

        if candidates.iter().any(Option::is_some) {
            let host_index = self.host_index.load();
            for (pos, candidate) in candidates.iter().enumerate() {
                let Some(host) = *candidate else {
                    continue;
                };
                // Skip keys that an earlier candidate already covered.
                if candidates[..pos].contains(&Some(host)) {
                    continue;
                }
                if let Some(route_index) = host_index.get(host)
                    && let Some(m) = Self::find_in_route_index(
                        &all_services,
                        route_index,
                        req_header,
                        &client_protocol,
                    )
                {
                    return Some(m);
                }
            }
        }

        let catch_all_index = self.catch_all_index.load();
        Self::find_in_route_index(
            &all_services,
            &catch_all_index,
            req_header,
            &client_protocol,
        )
    }

    /// Returns `host` without a trailing `:port` suffix (the port may be empty).
    ///
    /// Bracketed IPv6 literals keep their brackets (`"[::1]:8080"` -> `"[::1]"`).
    /// Values with a non-numeric suffix, or with more than one `:` outside
    /// brackets, are returned unchanged.
    fn strip_port(host: &str) -> &str {
        if host.starts_with('[') {
            return match host.find(']') {
                Some(end) => &host[..=end],
                None => host,
            };
        }
        match host.rsplit_once(':') {
            Some((name, port))
                if !name.contains(':') && port.bytes().all(|b| b.is_ascii_digit()) =>
            {
                name
            }
            _ => host,
        }
    }

    /// Checks the candidate services of `route_index` against `req_header` and
    /// returns the first matching route.
    ///
    /// The candidates are the indices that matchit associates with the request
    /// path, plus `fallback_indices`. Both lists are sorted, so they are merged
    /// and visited in ascending index order, with shared indices visited once.
    /// A service with a non-empty `protocols` list is skipped unless it contains
    /// `client_protocol`. Within a service, routes are tried in order.
    #[inline]
    fn find_in_route_index(
        all_services: &[crate::server::service::RuntimeService],
        route_index: &RouteIndex,
        req_header: &RequestHeader,
        client_protocol: &ServiceProtocol,
    ) -> Option<RouteMatch> {
        use std::cmp::Ordering;

        let path_indices: &[usize] = match route_index.path_router.at(req_header.uri.path()) {
            Ok(matched) => matched.value.as_slice(),
            Err(_) => &[],
        };
        let fallback_indices: &[usize] = &route_index.fallback_indices;
        let (mut p, mut f) = (0usize, 0usize);

        loop {
            // Pop the smaller head of the two sorted lists (both on a tie).
            let check_idx = match (path_indices.get(p), fallback_indices.get(f)) {
                (Some(&from_path), Some(&from_fallback)) => match from_path.cmp(&from_fallback) {
                    Ordering::Less => {
                        p += 1;
                        from_path
                    }
                    Ordering::Greater => {
                        f += 1;
                        from_fallback
                    }
                    Ordering::Equal => {
                        p += 1;
                        f += 1;
                        from_path
                    }
                },
                (Some(&from_path), None) => {
                    p += 1;
                    from_path
                }
                (None, Some(&from_fallback)) => {
                    f += 1;
                    from_fallback
                }
                (None, None) => return None,
            };

            let Some(service) = all_services.get(check_idx) else {
                continue;
            };
            if !service.protocols.is_empty() && !service.protocols.contains(client_protocol) {
                continue;
            }
            for route in &service.routes {
                if route.matcher.matches(req_header) {
                    return Some(RouteMatch {
                        upstream_name: service.host.clone(),
                        req_transformer: route.req_transformer.clone(),
                        res_transformer: route.res_transformer.clone(),
                        max_retries: route.max_retries,
                    });
                }
            }
        }
    }

    /// Upstream manager supplied at construction.
    pub fn upstream_manager(&self) -> &Arc<UpstreamManager> {
        &self.upstream_manager
    }

    /// Service manager whose services this router indexes.
    pub fn service_manager(&self) -> &Arc<ServiceManager> {
        &self.service_manager
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::models::{JokowayConfig, Route, Service, ServiceProtocol};
    use crate::extensions::dns::DnsResolver;
    use crate::server::context::{AppContext, Context};

    /// Route with the given name and rule and no explicit priority.
    fn route(name: &str, rule: &str) -> Route {
        Route {
            name: name.to_string(),
            rule: rule.to_string(),
            priority: None,
            ..Default::default()
        }
    }

    /// Service with exactly one route.
    fn service(
        name: &str,
        host: &str,
        protocols: Vec<ServiceProtocol>,
        only_route: Route,
    ) -> Service {
        Service {
            name: name.to_string(),
            host: host.to_string(),
            protocols,
            routes: vec![only_route],
            ..Default::default()
        }
    }

    /// Protocol list allowing plain HTTP only.
    fn http() -> Vec<ServiceProtocol> {
        vec![ServiceProtocol::Http]
    }

    /// Configuration containing exactly `services`, everything else default.
    fn config_with(services: Vec<Service>) -> JokowayConfig {
        JokowayConfig {
            services: services.into_iter().map(Arc::new).collect(),
            ..Default::default()
        }
    }

    /// Builds the service and upstream managers for `config`.
    ///
    /// The first element is the `AppContext` the upstream manager was built
    /// from. Bind it (e.g. `_app_ctx`) so it lives for the whole test.
    fn setup(
        config: JokowayConfig,
    ) -> (impl Sized, Arc<ServiceManager>, Arc<UpstreamManager>) {
        let service_manager = Arc::new(
            ServiceManager::new(Arc::new(config.clone()))
                .expect("Failed to create ServiceManager"),
        );
        let app_ctx = AppContext::new();
        app_ctx.insert(config.clone());
        app_ctx.insert(DnsResolver::new(&config));
        let (upstream_manager, _) =
            UpstreamManager::new(&app_ctx).expect("Failed to create UpstreamManager");
        (app_ctx, service_manager, Arc::new(upstream_manager))
    }

    /// Request header with the given method and path and a `Host` header.
    fn request(method: &'static str, path: &'static str, host: &'static str) -> RequestHeader {
        let mut req = RequestHeader::build(method, path.as_bytes(), None).unwrap();
        req.insert_header("Host", host).unwrap();
        req
    }

    /// Upstream name chosen by `router` for `req`, or `None` if nothing matched.
    fn matched_upstream(
        router: &Router,
        req: &RequestHeader,
        protocol: ServiceProtocol,
    ) -> Option<String> {
        router
            .match_request(req, protocol)
            .map(|m| m.upstream_name.to_string())
    }

    /// Number of distinct service indices referenced by all of `router`'s indices.
    fn count_unique_services(router: &Router) -> usize {
        let host_index = router.host_index.load();
        let catch_all_index = router.catch_all_index.load();
        host_index
            .values()
            .flat_map(|route_index| route_index.all_indices.iter())
            .chain(catch_all_index.all_indices.iter())
            .collect::<std::collections::HashSet<_>>()
            .len()
    }

    fn create_test_config() -> JokowayConfig {
        config_with(vec![
            service(
                "http_only",
                "http_backend",
                vec![ServiceProtocol::Http],
                route("http_route", "Host(`example.com`)"),
            ),
            service(
                "https_only",
                "https_backend",
                vec![ServiceProtocol::Https],
                route("https_route", "Host(`secure.example.com`)"),
            ),
            service(
                "dual_protocol",
                "dual_backend",
                vec![ServiceProtocol::Http, ServiceProtocol::Https],
                route("dual_route", "Host(`dual.example.com`)"),
            ),
            service(
                "no_protocol",
                "default_backend",
                vec![],
                route("default_route", "Host(`default.example.com`)"),
            ),
        ])
    }

    #[test]
    fn test_protocol_restriction() {
        let (_app_ctx, service_manager, upstream_manager) = setup(create_test_config());

        let http_router =
            Router::new(service_manager.clone(), upstream_manager.clone(), &HTTP_PROTOCOLS);
        assert_eq!(count_unique_services(&http_router), 3);

        let https_router =
            Router::new(service_manager.clone(), upstream_manager.clone(), &HTTPS_PROTOCOLS);
        assert_eq!(count_unique_services(&https_router), 3);

        let all_router = Router::new(service_manager, upstream_manager, &ALL_PROTOCOLS);
        assert_eq!(count_unique_services(&all_router), 4);
    }

    #[test]
    fn test_router_refresh_on_service_changes() {
        let (_app_ctx, service_manager, upstream_manager) = setup(create_test_config());
        let http_router =
            Router::new(service_manager.clone(), upstream_manager, &HTTP_PROTOCOLS);
        assert_eq!(count_unique_services(&http_router), 3);

        service_manager
            .add_service(service(
                "new_http_service",
                "new_backend",
                vec![ServiceProtocol::Http],
                route("new_route", "Host(`new.example.com`)"),
            ))
            .expect("Failed to add service");
        assert_eq!(count_unique_services(&http_router), 4);

        service_manager
            .remove_service("http_only")
            .expect("Failed to remove service");
        assert_eq!(count_unique_services(&http_router), 3);
    }

    #[test]
    fn test_router_match_request_scenarios() {
        let config = config_with(vec![
            service("service_a", "backend_a", http(), route("route_a", "Host(`a.com`)")),
            service("service_b", "backend_b", http(), route("route_b", "Host(`b.com`)")),
            service(
                "service_hybrid",
                "backend_hybrid",
                http(),
                route("route_hybrid", "Host(`c.com`) || PathPrefix(`/c`)"),
            ),
            service(
                "service_wild",
                "backend_wild",
                http(),
                route("route_wild", "PathPrefix(`/wild`)"),
            ),
            service(
                "service_complex_no_wild",
                "backend_complex_no_wild",
                http(),
                route(
                    "route_complex_no_wild",
                    "Host(`c.com`) || Host(`a.com`) && PathPrefix(`/c`)",
                ),
            ),
        ]);
        let (_app_ctx, service_manager, upstream_manager) = setup(config);
        let router = Router::new(service_manager, upstream_manager, &HTTP_PROTOCOLS);

        // (path, Host header, expected upstream, what the case checks)
        let cases = [
            ("/", "a.com", Some("backend_a"), "Should match service_a"),
            ("/foo", "b.com", Some("backend_b"), "Should match service_b"),
            (
                "/anything",
                "c.com",
                Some("backend_hybrid"),
                "Should match service_hybrid by host",
            ),
            (
                "/c/foo",
                "other.com",
                Some("backend_hybrid"),
                "Should match service_hybrid by path (catch-all)",
            ),
            (
                "/wild/bar",
                "random.com",
                Some("backend_wild"),
                "Should match service_wild",
            ),
            ("/nomatch", "other.com", None, "Should not match anything"),
            (
                "/anything",
                "c.com",
                Some("backend_hybrid"),
                "Should match complex rule via c.com",
            ),
            ("/c/foo", "a.com", Some("backend_a"), "Should match service_a"),
            (
                "/c/foo",
                "other.com",
                Some("backend_hybrid"),
                "Should match service_hybrid (catch-all)",
            ),
        ];
        for (path, host, expected, why) in cases {
            let actual =
                matched_upstream(&router, &request("GET", path, host), ServiceProtocol::Http);
            assert_eq!(actual.as_deref(), expected, "{why}");
        }

        assert_eq!(router.catch_all_index.load().all_indices.len(), 2);
    }

    #[test]
    fn test_overlapping_route_regression() {
        let config = config_with(vec![
            service(
                "specific_post",
                "backend_1",
                http(),
                route(
                    "r1",
                    "Host(`api.com`) && PathPrefix(`/api/users`) && Method(`POST`)",
                ),
            ),
            service(
                "general_api",
                "backend_2",
                http(),
                route("r2", "Host(`api.com`) && PathPrefix(`/api`)"),
            ),
        ]);
        let (_app_ctx, service_manager, upstream_manager) = setup(config);
        let router = Router::new(service_manager, upstream_manager, &HTTP_PROTOCOLS);

        let post = request("POST", "/api/users/123", "api.com");
        assert_eq!(
            matched_upstream(&router, &post, ServiceProtocol::Http).as_deref(),
            Some("backend_1"),
            "POST /api/users should match"
        );

        let get = request("GET", "/api/users/123", "api.com");
        assert_eq!(
            matched_upstream(&router, &get, ServiceProtocol::Http).as_deref(),
            Some("backend_2"),
            "GET /api/users should fall back to general_api"
        );
    }

    #[test]
    fn test_complex_matchit_fallback_and_isolation() {
        let config = config_with(vec![
            service(
                "api_health",
                "backend_health",
                http(),
                route("r1", "Host(`api.com`) && Path(`/api/v1/health`)"),
            ),
            service(
                "api_v1_post",
                "backend_v1_post",
                http(),
                route("r2", "Host(`api.com`) && PathPrefix(`/api/v1`) && Method(`POST`)"),
            ),
            service(
                "api_regex_fallback",
                "backend_regex",
                http(),
                route("r4", "Host(`api.com`) && PathRegexp(`^/api/v[0-9]+/special$`)"),
            ),
            service(
                "api_general",
                "backend_general",
                http(),
                route("r3", "Host(`api.com`) && PathPrefix(`/api`)"),
            ),
            service(
                "api_test_env",
                "backend_test_env",
                http(),
                route("r5", "Host(`api.test.com`) && PathPrefix(`/api`)"),
            ),
            service(
                "catch_all_wildcard",
                "backend_wildcard",
                http(),
                route("r6", "PathPrefix(`/wildcard`)"),
            ),
        ]);
        let (_app_ctx, service_manager, upstream_manager) = setup(config);
        let router = Router::new(service_manager, upstream_manager, &HTTP_PROTOCOLS);

        // (method, path, Host header, expected upstream, what the case checks)
        let cases = [
            (
                "GET",
                "/api/v1/health",
                "api.com",
                "backend_health",
                "Should match health",
            ),
            (
                "POST",
                "/api/v1/users",
                "api.com",
                "backend_v1_post",
                "Should match v1 post",
            ),
            (
                "GET",
                "/api/v1/users",
                "api.com",
                "backend_general",
                "Should fall back to general",
            ),
            (
                "GET",
                "/api/v2/special",
                "api.com",
                "backend_regex",
                "Should match regex",
            ),
            (
                "GET",
                "/api/v1/health",
                "api.test.com",
                "backend_test_env",
                "Should match test env, not health exact match",
            ),
            (
                "GET",
                "/wildcard/foo",
                "random.com",
                "backend_wildcard",
                "Should match wildcard",
            ),
        ];
        for (method, path, host, expected, why) in cases {
            let actual =
                matched_upstream(&router, &request(method, path, host), ServiceProtocol::Http);
            assert_eq!(actual.as_deref(), Some(expected), "{why}");
        }
    }

    #[test]
    fn test_dynamic_protocol_rejection() {
        let config = config_with(vec![
            service(
                "ws_only",
                "backend_ws",
                vec![ServiceProtocol::Ws],
                Route {
                    priority: Some(10),
                    ..route("r1", "Host(`ws.com`)")
                },
            ),
            service(
                "http_fallback",
                "backend_http",
                http(),
                Route {
                    priority: Some(5),
                    ..route("r2", "Host(`ws.com`)")
                },
            ),
        ]);
        let (_app_ctx, service_manager, upstream_manager) = setup(config);
        let router = Router::new(service_manager, upstream_manager, &ALL_PROTOCOLS);

        let req = request("GET", "/", "ws.com");
        assert_eq!(
            matched_upstream(&router, &req, ServiceProtocol::Http).as_deref(),
            Some("backend_http"),
            "Should fallback to HTTP service"
        );
        assert_eq!(
            matched_upstream(&router, &req, ServiceProtocol::Ws).as_deref(),
            Some("backend_ws"),
            "Should match WS service"
        );
        assert_eq!(matched_upstream(&router, &req, ServiceProtocol::Https), None);
    }
}