1pub mod pattern_trie;
2
3use std::{
4 fmt::{self, Debug, Write},
5 rc::Rc,
6 str::from_utf8,
7 time::Instant,
8};
9
10use regex::bytes::Regex;
11use sozu_command::{
12 logging::CachedTags,
13 proto::command::{
14 HeaderPosition, HstsConfig, PathRule as CommandPathRule, PathRuleKind, RedirectPolicy,
15 RedirectScheme, RulePosition,
16 },
17 response::HttpFrontend,
18 state::ClusterId,
19};
20
21use crate::metrics::names;
22use crate::{
23 protocol::{http::editor::HeaderEditMode, http::parser::Method},
24 router::pattern_trie::{TrieMatches, TrieNode, TrieSubMatch},
25 sozu_command::logging::ansi_palette,
26};
27
/// Builds the colored `ROUTER` prefix prepended to this module's log lines.
///
/// Expands to a `String` of the form `<open>ROUTER<reset>\t >>>`, where `open` and
/// `reset` are the first two values returned by `ansi_palette()` (presumably the
/// color-start and color-reset sequences).
macro_rules! log_module_context {
    () => {{
        let (open, reset, _, _, _) = ansi_palette();
        format!("{open}ROUTER{reset}\t >>>")
    }};
}
40
/// Errors returned by [`Router`] when adding, removing or looking up routes.
#[derive(thiserror::Error, Debug, PartialEq)]
pub enum RouterError {
    /// The frontend's path rule could not be turned into a [`PathRule`]
    /// (unknown rule kind or invalid regex).
    #[error("Could not parse rule from frontend path {0:?}")]
    InvalidPathRule(String),
    /// The frontend hostname could not be parsed into a [`DomainRule`].
    #[error("parsing hostname {hostname} failed")]
    InvalidDomain { hostname: String },
    /// The host rewrite template was rejected by [`RewriteParts::parse`].
    #[error("Could not parse host rewrite {0:?}")]
    InvalidHostRewrite(String),
    /// The path rewrite template was rejected by [`RewriteParts::parse`].
    #[error("Could not parse path rewrite {0:?}")]
    InvalidPathRewrite(String),
    /// The rule could not be inserted (duplicate rule, or a trie hostname that is not
    /// valid UTF-8 / IDNA). Carries the `Debug` rendering of the frontend.
    #[error("Could not add route {0}")]
    AddRoute(String),
    /// The rule could not be removed. Carries the `Debug` rendering of the frontend.
    #[error("Could not remove route {0}")]
    RemoveRoute(String),
    /// No pre-rule, trie rule or post-rule matched the request.
    #[error("no route for {method} {host} {path}")]
    RouteNotFound {
        /// Requested hostname.
        host: String,
        /// Requested path.
        path: String,
        /// Requested HTTP method.
        method: Method,
    },
}
62
/// HTTP routing table, consulted in three stages by [`Router::lookup`]:
/// the `pre` rules (first match wins), then the hostname trie, then the `post`
/// rules (first match wins).
pub struct Router {
    /// Rules checked before the trie, in insertion order.
    pre: Vec<(DomainRule, PathRule, MethodRule, Route)>,
    /// Hostname trie; each leaf holds the path/method rules registered for that domain.
    pub tree: TrieNode<Vec<(PathRule, MethodRule, Route)>>,
    /// Rules checked only when neither `pre` nor the trie produced a route.
    post: Vec<(DomainRule, PathRule, MethodRule, Route)>,
}
68
69impl Default for Router {
70 fn default() -> Self {
71 Self::new()
72 }
73}
74
75impl Router {
76 pub fn new() -> Router {
77 Router {
78 pre: Vec::new(),
79 tree: TrieNode::root(),
80 post: Vec::new(),
81 }
82 }
83
84 pub fn lookup(
93 &self,
94 hostname: &str,
95 path: &str,
96 method: &Method,
97 ) -> Result<RouteResult, RouterError> {
98 let hostname_b = hostname.as_bytes();
99 let path_b = path.as_bytes();
100 for (domain_rule, path_rule, method_rule, route) in &self.pre {
101 if domain_rule.matches(hostname_b)
102 && path_rule.matches(path_b) != PathRuleResult::None
103 && method_rule.matches(method) != MethodRuleResult::None
104 {
105 return Ok(RouteResult::new_no_trie(
106 hostname_b,
107 domain_rule,
108 path_b,
109 path_rule,
110 route,
111 ));
112 }
113 }
114
115 let trie_path: TrieMatches<'_, '_> = Vec::with_capacity(16);
116 if let Some(((_, path_rules), trie_matches)) =
117 self.tree.lookup_with_path(hostname_b, true, trie_path)
118 {
119 let mut prefix_length = 0;
120 let mut matched: Option<(&PathRule, &Route)> = None;
121
122 for (rule, method_rule, route) in path_rules {
123 match rule.matches(path_b) {
124 PathRuleResult::Regex | PathRuleResult::Equals => {
125 match method_rule.matches(method) {
126 MethodRuleResult::Equals => {
127 return Ok(RouteResult::new_with_trie(
128 hostname_b,
129 trie_matches,
130 path_b,
131 rule,
132 route,
133 ));
134 }
135 MethodRuleResult::All => {
136 prefix_length = path_b.len();
137 matched = Some((rule, route));
138 }
139 MethodRuleResult::None => {}
140 }
141 }
142 PathRuleResult::Prefix(size) => {
143 if size >= prefix_length {
144 match method_rule.matches(method) {
145 MethodRuleResult::Equals => {
147 debug_assert!(
151 size >= prefix_length,
152 "longest-prefix selection must never shrink the match length",
153 );
154 prefix_length = size;
155 matched = Some((rule, route));
156 }
157 MethodRuleResult::All => {
158 debug_assert!(
159 size >= prefix_length,
160 "longest-prefix selection must never shrink the match length",
161 );
162 prefix_length = size;
163 matched = Some((rule, route));
164 }
165 MethodRuleResult::None => {}
166 }
167 }
168 }
169 PathRuleResult::None => {}
170 }
171 }
172
173 if let Some((path_rule, route)) = matched {
174 return Ok(RouteResult::new_with_trie(
175 hostname_b,
176 trie_matches,
177 path_b,
178 path_rule,
179 route,
180 ));
181 }
182 }
183
184 for (domain_rule, path_rule, method_rule, route) in self.post.iter() {
185 if domain_rule.matches(hostname_b)
186 && path_rule.matches(path_b) != PathRuleResult::None
187 && method_rule.matches(method) != MethodRuleResult::None
188 {
189 return Ok(RouteResult::new_no_trie(
190 hostname_b,
191 domain_rule,
192 path_b,
193 path_rule,
194 route,
195 ));
196 }
197 }
198
199 Err(RouterError::RouteNotFound {
200 host: hostname.to_owned(),
201 path: path.to_owned(),
202 method: method.to_owned(),
203 })
204 }
205
206 pub fn add_http_front(&mut self, front: &HttpFrontend) -> Result<(), RouterError> {
213 self.add_http_front_with_hsts_origin(front, HstsOrigin::Explicit)
214 }
215
216 pub fn add_http_front_with_hsts_origin(
223 &mut self,
224 front: &HttpFrontend,
225 hsts_origin: HstsOrigin,
226 ) -> Result<(), RouterError> {
227 let path_rule = PathRule::from_config(front.path.clone())
228 .ok_or(RouterError::InvalidPathRule(front.path.to_string()))?;
229
230 let method_rule = MethodRule::new(front.method.clone());
231
232 let has_policy = front.redirect.is_some()
237 || front.redirect_scheme.is_some()
238 || front.redirect_template.is_some()
239 || front.rewrite_host.is_some()
240 || front.rewrite_path.is_some()
241 || front.rewrite_port.is_some()
242 || front.required_auth.unwrap_or(false)
243 || !front.headers.is_empty()
244 || front.hsts.is_some();
245
246 let domain =
247 front
248 .hostname
249 .parse::<DomainRule>()
250 .map_err(|_| RouterError::InvalidDomain {
251 hostname: front.hostname.clone(),
252 })?;
253
254 let route = if has_policy {
255 let redirect = front
256 .redirect
257 .and_then(|r| RedirectPolicy::try_from(r).ok())
258 .unwrap_or(RedirectPolicy::Forward);
259 let redirect_scheme = front
260 .redirect_scheme
261 .and_then(|s| RedirectScheme::try_from(s).ok())
262 .unwrap_or(RedirectScheme::UseSame);
263 let frontend = Frontend::new(
264 &domain,
265 &path_rule,
266 front,
267 redirect,
268 redirect_scheme,
269 front.redirect_template.clone(),
270 front.rewrite_host.clone(),
271 front.rewrite_path.clone(),
272 front.rewrite_port.and_then(|p| u16::try_from(p).ok()),
273 &front.headers,
274 front.required_auth.unwrap_or(false),
275 hsts_origin,
276 )?;
277 Route::Frontend(Rc::new(frontend))
278 } else {
279 match &front.cluster_id {
280 Some(cluster_id) => Route::ClusterId(cluster_id.clone()),
281 None => Route::Deny,
282 }
283 };
284
285 let success = match front.position {
286 RulePosition::Pre => self.add_pre_rule(&domain, &path_rule, &method_rule, &route),
287 RulePosition::Post => self.add_post_rule(&domain, &path_rule, &method_rule, &route),
288 RulePosition::Tree => {
289 self.add_tree_rule(front.hostname.as_bytes(), &path_rule, &method_rule, &route)
290 }
291 };
292 if !success {
293 return Err(RouterError::AddRoute(format!("{front:?}")));
294 }
295 Ok(())
296 }
297
298 pub fn remove_http_front(&mut self, front: &HttpFrontend) -> Result<(), RouterError> {
299 let path_rule = PathRule::from_config(front.path.clone())
300 .ok_or(RouterError::InvalidPathRule(front.path.to_string()))?;
301
302 let method_rule = MethodRule::new(front.method.clone());
303
304 let remove_success = match front.position {
305 RulePosition::Pre => {
306 let domain = front.hostname.parse::<DomainRule>().map_err(|_| {
307 RouterError::InvalidDomain {
308 hostname: front.hostname.clone(),
309 }
310 })?;
311
312 self.remove_pre_rule(&domain, &path_rule, &method_rule)
313 }
314 RulePosition::Post => {
315 let domain = front.hostname.parse::<DomainRule>().map_err(|_| {
316 RouterError::InvalidDomain {
317 hostname: front.hostname.clone(),
318 }
319 })?;
320
321 self.remove_post_rule(&domain, &path_rule, &method_rule)
322 }
323 RulePosition::Tree => {
324 self.remove_tree_rule(front.hostname.as_bytes(), &path_rule, &method_rule)
325 }
326 };
327 if !remove_success {
328 return Err(RouterError::RemoveRoute(format!("{front:?}")));
329 }
330 Ok(())
331 }
332
333 pub fn add_tree_rule(
334 &mut self,
335 hostname: &[u8],
336 path: &PathRule,
337 method: &MethodRule,
338 cluster: &Route,
339 ) -> bool {
340 let hostname = match from_utf8(hostname) {
341 Err(_) => return false,
342 Ok(h) => h,
343 };
344
345 match ::idna::domain_to_ascii(hostname) {
346 Ok(hostname) => {
347 let mut empty = true;
349 if let Some((_, paths)) = self.tree.domain_lookup_mut(hostname.as_bytes(), false) {
350 empty = false;
351 let before = paths.len();
352 if !paths.iter().any(|(p, m, _)| p == path && m == method) {
353 paths.push((path.to_owned(), method.to_owned(), cluster.to_owned()));
354 debug_assert_eq!(
357 paths.len(),
358 before + 1,
359 "appending a tree rule must grow the leaf's rule list by exactly one",
360 );
361 debug_assert!(
362 paths.iter().any(|(p, m, _)| p == path && m == method),
363 "the freshly appended (path, method) rule must be present after insert",
364 );
365 return true;
366 }
367 }
368
369 if empty {
370 let inserted_host = hostname.clone().into_bytes();
375 self.tree.domain_insert(
376 hostname.into_bytes(),
377 vec![(path.to_owned(), method.to_owned(), cluster.to_owned())],
378 );
379 debug_assert!(
388 self.tree
389 .domain_lookup_mut(&inserted_host, false)
390 .is_some_and(|(_, paths)| paths
391 .iter()
392 .any(|(p, m, _)| p == path && m == method)),
393 "a freshly inserted tree domain must resolve to its inserted rule",
394 );
395 return true;
396 }
397
398 false
399 }
400 Err(_) => false,
401 }
402 }
403
404 pub fn remove_tree_rule(
405 &mut self,
406 hostname: &[u8],
407 path: &PathRule,
408 method: &MethodRule,
409 ) -> bool {
411 let hostname = match from_utf8(hostname) {
412 Err(_) => return false,
413 Ok(h) => h,
414 };
415
416 match ::idna::domain_to_ascii(hostname) {
417 Ok(hostname) => {
418 let should_delete = {
419 let paths_opt = self.tree.domain_lookup_mut(hostname.as_bytes(), false);
420
421 if let Some((_, paths)) = paths_opt {
422 paths.retain(|(p, m, _)| p != path || m != method);
423 debug_assert!(
426 !paths.iter().any(|(p, m, _)| p == path && m == method),
427 "remove must evict every matching (path, method) rule from the leaf",
428 );
429 }
430
431 paths_opt
432 .as_ref()
433 .map(|(_, paths)| paths.is_empty())
434 .unwrap_or(false)
435 };
436
437 if should_delete {
438 let removed_host = hostname.clone().into_bytes();
439 self.tree.domain_remove(&hostname.into_bytes());
440 debug_assert!(
448 self.tree.domain_lookup_mut(&removed_host, false).is_none(),
449 "a domain whose last rule was removed must be unreachable",
450 );
451 }
452
453 true
454 }
455 Err(_) => false,
456 }
457 }
458
459 pub fn refresh_inheriting_hsts(&mut self, new_hsts: Option<&HstsConfig>) -> usize {
503 let mut refreshed = 0usize;
504 let new_edit = build_listener_hsts_edit(new_hsts);
512 let new_edit_ref = new_edit.as_ref();
513 let promote_lightweight = new_edit_ref.is_some();
514 let mut visit = |route: &mut Route| match route {
515 Route::Frontend(rc) => {
516 if rc.inherits_listener_hsts {
517 let new_frontend = rebuild_with_listener_hsts(rc, new_edit_ref);
518 *rc = Rc::new(new_frontend);
519 refreshed += 1;
520 }
521 }
522 Route::ClusterId(id) => {
523 if promote_lightweight {
524 let promoted = rebuild_with_listener_hsts(
525 &Frontend::minimal_forward(id.clone()),
526 new_edit_ref,
527 );
528 *route = Route::Frontend(Rc::new(promoted));
529 refreshed += 1;
530 }
531 }
532 Route::Deny => {
533 if promote_lightweight {
534 let promoted =
535 rebuild_with_listener_hsts(&Frontend::minimal_deny(), new_edit_ref);
536 *route = Route::Frontend(Rc::new(promoted));
537 refreshed += 1;
538 }
539 }
540 };
541
542 for (_, _, _, route) in self.pre.iter_mut() {
543 visit(route);
544 }
545 self.tree.for_each_value_mut(&mut |paths| {
546 for (_, _, route) in paths.iter_mut() {
547 visit(route);
548 }
549 });
550 for (_, _, _, route) in self.post.iter_mut() {
551 visit(route);
552 }
553 refreshed
554 }
555
556 pub fn add_pre_rule(
557 &mut self,
558 domain: &DomainRule,
559 path: &PathRule,
560 method: &MethodRule,
561 cluster_id: &Route,
562 ) -> bool {
563 let before = self.pre.len();
564 if !self
565 .pre
566 .iter()
567 .any(|(d, p, m, _)| d == domain && p == path && m == method)
568 {
569 self.pre.push((
570 domain.to_owned(),
571 path.to_owned(),
572 method.to_owned(),
573 cluster_id.to_owned(),
574 ));
575 debug_assert_eq!(
579 self.pre.len(),
580 before + 1,
581 "adding a unique pre-rule must push exactly one entry",
582 );
583 debug_assert!(
584 self.pre
585 .iter()
586 .any(|(d, p, m, _)| d == domain && p == path && m == method),
587 "the freshly added pre-rule must be present",
588 );
589 true
590 } else {
591 debug_assert_eq!(
592 self.pre.len(),
593 before,
594 "a duplicate pre-rule must not change the list length",
595 );
596 false
597 }
598 }
599
600 pub fn add_post_rule(
601 &mut self,
602 domain: &DomainRule,
603 path: &PathRule,
604 method: &MethodRule,
605 cluster_id: &Route,
606 ) -> bool {
607 let before = self.post.len();
608 if !self
609 .post
610 .iter()
611 .any(|(d, p, m, _)| d == domain && p == path && m == method)
612 {
613 self.post.push((
614 domain.to_owned(),
615 path.to_owned(),
616 method.to_owned(),
617 cluster_id.to_owned(),
618 ));
619 debug_assert_eq!(
620 self.post.len(),
621 before + 1,
622 "adding a unique post-rule must push exactly one entry",
623 );
624 debug_assert!(
625 self.post
626 .iter()
627 .any(|(d, p, m, _)| d == domain && p == path && m == method),
628 "the freshly added post-rule must be present",
629 );
630 true
631 } else {
632 debug_assert_eq!(
633 self.post.len(),
634 before,
635 "a duplicate post-rule must not change the list length",
636 );
637 false
638 }
639 }
640
641 pub fn remove_pre_rule(
642 &mut self,
643 domain: &DomainRule,
644 path: &PathRule,
645 method: &MethodRule,
646 ) -> bool {
647 let before = self.pre.len();
648 match self
649 .pre
650 .iter()
651 .position(|(d, p, m, _)| d == domain && p == path && m == method)
652 {
653 None => {
654 debug_assert_eq!(
655 self.pre.len(),
656 before,
657 "a no-op pre-rule removal must not change the list length",
658 );
659 false
660 }
661 Some(index) => {
662 debug_assert!(index < self.pre.len(), "found index must be in bounds");
663 self.pre.remove(index);
664 debug_assert_eq!(
666 self.pre.len() + 1,
667 before,
668 "removing a pre-rule must drop exactly one entry",
669 );
670 debug_assert!(
671 !self
672 .pre
673 .iter()
674 .any(|(d, p, m, _)| d == domain && p == path && m == method),
675 "the removed pre-rule must no longer be present",
676 );
677 true
678 }
679 }
680 }
681
682 pub fn remove_post_rule(
683 &mut self,
684 domain: &DomainRule,
685 path: &PathRule,
686 method: &MethodRule,
687 ) -> bool {
688 let before = self.post.len();
689 match self
690 .post
691 .iter()
692 .position(|(d, p, m, _)| d == domain && p == path && m == method)
693 {
694 None => {
695 debug_assert_eq!(
696 self.post.len(),
697 before,
698 "a no-op post-rule removal must not change the list length",
699 );
700 false
701 }
702 Some(index) => {
703 debug_assert!(index < self.post.len(), "found index must be in bounds");
704 self.post.remove(index);
705 debug_assert_eq!(
706 self.post.len() + 1,
707 before,
708 "removing a post-rule must drop exactly one entry",
709 );
710 debug_assert!(
711 !self
712 .post
713 .iter()
714 .any(|(d, p, m, _)| d == domain && p == path && m == method),
715 "the removed post-rule must no longer be present",
716 );
717 true
718 }
719 }
720 }
721
722 pub fn has_hostname(&self, hostname: &str) -> bool {
727 let hostname_b = hostname.as_bytes();
728
729 for (domain_rule, _, _, _) in &self.pre {
731 if domain_rule.matches(hostname_b) {
732 return true;
733 }
734 }
735
736 if let Ok(ascii_hostname) = ::idna::domain_to_ascii(hostname) {
738 if self
739 .tree
740 .domain_lookup(ascii_hostname.as_bytes(), false)
741 .is_some()
742 {
743 return true;
744 }
745 }
746
747 for (domain_rule, _, _, _) in &self.post {
749 if domain_rule.matches(hostname_b) {
750 return true;
751 }
752 }
753
754 false
755 }
756}
757
/// Hostname matcher used by `pre`/`post` rules; also used to validate every
/// frontend hostname before insertion.
#[derive(Clone, Debug)]
pub enum DomainRule {
    /// Matches any hostname (configured as `*`).
    Any,
    /// Matches exactly this IDNA-converted hostname (byte comparison).
    Exact(String),
    /// `*.suffix`: matches a hostname made of one non-empty label without dots,
    /// followed by `.suffix`.
    Wildcard(String),
    /// Anchored regex built by `convert_regex_domain_rule` from a hostname containing
    /// `/…/` segments.
    Regex(Regex),
}
770
/// Converts a hostname containing `/regex/` labels into an anchored regex source.
///
/// The hostname is processed label by label:
/// - a label written as `/…/` contributes its inner text as a raw regex fragment;
/// - any other label is copied verbatim (NOTE(review): it is not regex-escaped, so
///   metacharacters in plain labels are interpreted by the regex engine);
/// - every separating `.` becomes a literal `\.`.
///
/// The whole result is wrapped in `\A` … `\z` so it must match the full hostname.
///
/// Returns `None` when:
/// - a `/` is never closed;
/// - a `/…/` label is followed by anything other than `.` or the end of input;
/// - the hostname is empty or ends with `.` (an empty trailing label). These last
///   cases previously indexed past the end of the input and panicked.
fn convert_regex_domain_rule(hostname: &str) -> Option<String> {
    let mut result = String::from("\\A");

    let s = hostname.as_bytes();
    let mut index = 0;
    loop {
        // Reached only at the start or right after a '.', so running out of input here
        // means an empty hostname or a trailing dot: reject instead of panicking.
        if index >= s.len() {
            return None;
        }

        if s[index] == b'/' {
            // Regex label: copy everything up to the closing '/'.
            let close = s[index + 1..].iter().position(|&b| b == b'/')? + index + 1;
            result.push_str(std::str::from_utf8(&s[index + 1..close]).ok()?);
            index = close + 1;
        } else {
            // Plain label: copy everything up to the next '.' (or the end).
            let end = s[index..]
                .iter()
                .position(|&b| b == b'.')
                .map_or(s.len(), |offset| index + offset);
            result.push_str(std::str::from_utf8(&s[index..end]).ok()?);
            index = end;
        }

        match s.get(index) {
            None => {
                result.push_str("\\z");
                return Some(result);
            }
            Some(b'.') => {
                result.push_str("\\.");
                index += 1;
            }
            Some(_) => return None,
        }
    }
}
830
831impl DomainRule {
832 pub fn matches(&self, hostname: &[u8]) -> bool {
833 match self {
834 DomainRule::Any => true,
835 DomainRule::Wildcard(s) => {
836 debug_assert_eq!(
840 s.as_bytes().first(),
841 Some(&b'*'),
842 "a Wildcard rule must retain its leading '*'",
843 );
844 let suffix = &s.as_bytes()[1..];
845 let matched = hostname
846 .strip_suffix(suffix)
847 .is_some_and(|prefix| !prefix.is_empty() && !prefix.contains(&b'.'));
848 debug_assert!(
852 !matched || hostname.len() > suffix.len(),
853 "a wildcard match requires a non-empty leftmost label before the suffix",
854 );
855 matched
856 }
857 DomainRule::Exact(s) => s.as_bytes() == hostname,
858 DomainRule::Regex(r) => {
859 let start = Instant::now();
860 let is_a_match = r.is_match(hostname);
861 let now = Instant::now();
862 time!(
863 names::event_loop::REGEX_MATCHING_TIME,
864 (now - start).as_millis()
865 );
866 is_a_match
867 }
868 }
869 }
870}
871
872impl std::cmp::PartialEq for DomainRule {
873 fn eq(&self, other: &Self) -> bool {
874 match (self, other) {
875 (DomainRule::Any, DomainRule::Any) => true,
876 (DomainRule::Wildcard(s1), DomainRule::Wildcard(s2)) => s1 == s2,
877 (DomainRule::Exact(s1), DomainRule::Exact(s2)) => s1 == s2,
878 (DomainRule::Regex(r1), DomainRule::Regex(r2)) => r1.as_str() == r2.as_str(),
879 _ => false,
880 }
881 }
882}
883
884impl std::str::FromStr for DomainRule {
885 type Err = ();
886
887 fn from_str(s: &str) -> Result<Self, Self::Err> {
888 Ok(if s == "*" {
889 DomainRule::Any
890 } else if s.contains('/') {
891 match convert_regex_domain_rule(s) {
892 Some(s) => match regex::bytes::Regex::new(&s) {
893 Ok(r) => DomainRule::Regex(r),
894 Err(_) => return Err(()),
895 },
896 None => return Err(()),
897 }
898 } else if s.contains('*') {
899 if s.starts_with('*') {
900 match ::idna::domain_to_ascii(s) {
901 Ok(r) => DomainRule::Wildcard(r),
902 Err(_) => return Err(()),
903 }
904 } else {
905 return Err(());
906 }
907 } else {
908 match ::idna::domain_to_ascii(s) {
909 Ok(r) => DomainRule::Exact(r),
910 Err(_) => return Err(()),
911 }
912 })
913 }
914}
915
916#[derive(Clone, Debug)]
917pub enum PathRule {
918 Prefix(String),
919 Regex(Regex),
920 Equals(String),
921}
922
923#[derive(PartialEq, Eq)]
924pub enum PathRuleResult {
925 Regex,
926 Prefix(usize),
927 Equals,
928 None,
929}
930
931impl PathRule {
932 pub fn matches(&self, path: &[u8]) -> PathRuleResult {
933 match self {
934 PathRule::Prefix(prefix) => {
935 if path.starts_with(prefix.as_bytes()) {
936 debug_assert!(
940 prefix.len() <= path.len(),
941 "a matching prefix cannot be longer than the path it matched",
942 );
943 PathRuleResult::Prefix(prefix.len())
944 } else {
945 PathRuleResult::None
946 }
947 }
948 PathRule::Regex(regex) => {
949 let start = Instant::now();
950 let is_a_match = regex.is_match(path);
951 let now = Instant::now();
952 time!(
953 names::event_loop::REGEX_MATCHING_TIME,
954 (now - start).as_millis()
955 );
956
957 if is_a_match {
958 PathRuleResult::Regex
959 } else {
960 PathRuleResult::None
961 }
962 }
963 PathRule::Equals(pattern) => {
964 if path == pattern.as_bytes() {
965 PathRuleResult::Equals
966 } else {
967 PathRuleResult::None
968 }
969 }
970 }
971 }
972
973 pub fn from_config(rule: CommandPathRule) -> Option<Self> {
974 match PathRuleKind::try_from(rule.kind) {
975 Ok(PathRuleKind::Prefix) => Some(PathRule::Prefix(rule.value)),
976 Ok(PathRuleKind::Regex) => Regex::new(&rule.value).ok().map(PathRule::Regex),
977 Ok(PathRuleKind::Equals) => Some(PathRule::Equals(rule.value)),
978 Err(_) => None,
979 }
980 }
981}
982
983impl std::cmp::PartialEq for PathRule {
984 fn eq(&self, other: &Self) -> bool {
985 match (self, other) {
986 (PathRule::Prefix(s1), PathRule::Prefix(s2)) => s1 == s2,
987 (PathRule::Regex(r1), PathRule::Regex(r2)) => r1.as_str() == r2.as_str(),
988 _ => false,
989 }
990 }
991}
992
993#[derive(Clone, Debug, PartialEq, Eq)]
994pub struct MethodRule {
995 pub inner: Option<Method>,
996}
997
998#[derive(PartialEq, Eq)]
999pub enum MethodRuleResult {
1000 All,
1001 Equals,
1002 None,
1003}
1004
1005impl MethodRule {
1006 pub fn new(method: Option<String>) -> Self {
1007 MethodRule {
1008 inner: method.map(|s| Method::new(s.as_bytes())),
1009 }
1010 }
1011
1012 pub fn matches(&self, method: &Method) -> MethodRuleResult {
1013 match self.inner {
1014 None => MethodRuleResult::All,
1015 Some(ref m) => {
1016 if method == m {
1017 MethodRuleResult::Equals
1018 } else {
1019 MethodRuleResult::None
1020 }
1021 }
1022 }
1023 }
1024}
1025
/// What a matched rule resolves to.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Route {
    /// Reject the request (e.g. a frontend with neither a cluster nor a policy).
    Deny,
    /// Forward to this cluster, with no per-frontend policy.
    ClusterId(ClusterId),
    /// Full per-frontend policy (redirects, rewrites, header edits, auth, HSTS).
    Frontend(Rc<Frontend>),
}
1059
1060fn build_listener_hsts_edit(new_hsts: Option<&HstsConfig>) -> Option<HeaderEdit> {
1087 let cfg = new_hsts?;
1088 if !matches!(cfg.enabled, Some(true)) {
1089 return None;
1090 }
1091 let rendered = render_hsts(cfg)?;
1092 let mode = if matches!(cfg.force_replace_backend, Some(true)) {
1093 HeaderEditMode::Set
1094 } else {
1095 HeaderEditMode::SetIfAbsent
1096 };
1097 Some(HeaderEdit {
1098 key: Rc::from(&b"strict-transport-security"[..]),
1099 val: rendered.into_bytes().into(),
1100 mode,
1101 })
1102}
1103
1104fn rebuild_with_listener_hsts(frontend: &Frontend, new_edit: Option<&HeaderEdit>) -> Frontend {
1120 let mut headers_response: Vec<HeaderEdit> = frontend
1122 .headers_response
1123 .iter()
1124 .filter(|edit| !edit.key.eq_ignore_ascii_case(b"strict-transport-security"))
1125 .cloned()
1126 .collect();
1127
1128 if let Some(edit) = new_edit {
1131 headers_response.push(edit.clone());
1132 }
1133
1134 Frontend {
1135 headers_response: headers_response.into(),
1136 ..frontend.clone()
1138 }
1139}
1140
1141pub fn render_hsts(cfg: &HstsConfig) -> Option<String> {
1154 let max_age = cfg.max_age?;
1155 let mut s = format!("max-age={max_age}");
1156 if matches!(cfg.include_subdomains, Some(true)) {
1157 s.push_str("; includeSubDomains");
1158 }
1159 if matches!(cfg.preload, Some(true)) {
1160 s.push_str("; preload");
1161 }
1162 Some(s)
1163}
1164
1165#[derive(Clone, PartialEq, Eq)]
1177pub struct HeaderEdit {
1178 pub key: Rc<[u8]>,
1179 pub val: Rc<[u8]>,
1180 pub mode: HeaderEditMode,
1181}
1182
1183impl Debug for HeaderEdit {
1184 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1185 f.write_fmt(format_args!(
1186 "({:?}, {:?}, {:?})",
1187 String::from_utf8_lossy(&self.key),
1188 String::from_utf8_lossy(&self.val),
1189 self.mode,
1190 ))
1191 }
1192}
1193
/// One piece of a parsed rewrite template.
#[derive(Debug, Clone, PartialEq, Eq)]
enum RewritePart {
    /// Literal text, copied as-is.
    String(String),
    /// `$HOST[n]`: the n-th hostname capture.
    Host(usize),
    /// `$PATH[n]`: the n-th path capture.
    Path(usize),
}

/// A host or path rewrite template, parsed into literal text and capture references.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RewriteParts(Vec<RewritePart>);

impl RewriteParts {
    /// Parses `template`. `$HOST[n]` / `$PATH[n]` reference the n-th hostname / path
    /// capture, and every other byte is literal text.
    ///
    /// `host_cap_cap` / `path_cap_cap` are the number of captures available.
    /// Referencing an index at or beyond them is an error. Each successfully parsed
    /// reference raises `used_index_host` / `used_index_path` to at least one past its
    /// index; the counters are never lowered, so several templates can share them.
    ///
    /// Returns `None` when:
    /// - a `$` is not followed by `HOST[` or `PATH[`;
    /// - the index is missing or overflows `usize`;
    /// - the closing `]` is missing;
    /// - the index is out of range.
    ///
    /// References parsed before the failing one may already have raised the counters.
    pub fn parse(
        template: &str,
        host_cap_cap: usize,
        path_cap_cap: usize,
        used_index_host: &mut usize,
        used_index_path: &mut usize,
    ) -> Option<Self> {
        let bytes = template.as_bytes();
        let mut parts = Vec::new();
        let mut pos = 0;

        while pos < bytes.len() {
            if bytes[pos] != b'$' {
                // Literal run up to the next '$' (or the end of the template).
                let end = bytes[pos..]
                    .iter()
                    .position(|&b| b == b'$')
                    .map_or(bytes.len(), |offset| pos + offset);
                parts.push(RewritePart::String(template[pos..end].to_owned()));
                pos = end;
                continue;
            }

            let tail = &bytes[pos..];
            let is_host = if tail.starts_with(b"$HOST[") {
                true
            } else if tail.starts_with(b"$PATH[") {
                false
            } else {
                return None;
            };
            pos += b"$HOST[".len();

            let digit_count = bytes[pos..]
                .iter()
                .take_while(|b| b.is_ascii_digit())
                .count();
            if digit_count == 0 {
                return None;
            }
            // Checked arithmetic: an absurdly long index is rejected, not wrapped.
            let index = bytes[pos..pos + digit_count]
                .iter()
                .try_fold(0usize, |acc, &digit| {
                    acc.checked_mul(10)?.checked_add(usize::from(digit - b'0'))
                })?;
            pos += digit_count;

            if bytes.get(pos) != Some(&b']') {
                return None;
            }
            pos += 1;

            let (cap, used) = if is_host {
                (host_cap_cap, &mut *used_index_host)
            } else {
                (path_cap_cap, &mut *used_index_path)
            };
            if index >= cap {
                return None;
            }
            *used = (*used).max(index + 1);
            parts.push(if is_host {
                RewritePart::Host(index)
            } else {
                RewritePart::Path(index)
            });
        }

        debug_assert!(
            parts.iter().all(|part| match part {
                RewritePart::Host(idx) => *idx < host_cap_cap,
                RewritePart::Path(idx) => *idx < path_cap_cap,
                RewritePart::String(_) => true,
            }),
            "a parsed rewrite template must only reference captures within the rule's caps",
        );
        debug_assert!(
            *used_index_host <= host_cap_cap && *used_index_path <= path_cap_cap,
            "the highest referenced capture index cannot exceed the cap",
        );
        Some(Self(parts))
    }

    /// Expands the template with the given captures. References to captures missing
    /// from the slices expand to the empty string.
    pub fn run(&self, host_captures: &[&str], path_captures: &[&str]) -> String {
        // Text contributed by a single part.
        fn piece<'a>(part: &'a RewritePart, host: &[&'a str], path: &[&'a str]) -> &'a str {
            match part {
                RewritePart::String(text) => text.as_str(),
                RewritePart::Host(i) => host.get(*i).copied().unwrap_or(""),
                RewritePart::Path(i) => path.get(*i).copied().unwrap_or(""),
            }
        }

        // Size the output exactly so it is allocated once.
        let total: usize = self
            .0
            .iter()
            .map(|part| piece(part, host_captures, path_captures).len())
            .sum();
        let mut out = String::with_capacity(total);
        for part in &self.0 {
            // Writing into a `String` cannot fail.
            let _ = out.write_str(piece(part, host_captures, path_captures));
        }
        debug_assert_eq!(
            out.len(),
            total,
            "rewrite output length must equal the pre-computed one-pass capacity",
        );
        out
    }
}
1353
/// Full per-frontend routing policy, stored as [`Route::Frontend`].
///
/// Built by [`Frontend::new`] for frontends that carry a redirect, rewrite, auth,
/// header or HSTS setting, and rebuilt by [`Router::refresh_inheriting_hsts`].
#[derive(Debug, Clone)]
pub struct Frontend {
    /// Cluster requests are forwarded to, if any.
    pub cluster_id: Option<ClusterId>,
    /// Redirect policy; `Unauthorized` marks a denying frontend.
    pub redirect: RedirectPolicy,
    /// Scheme used when redirecting.
    pub redirect_scheme: RedirectScheme,
    /// Custom redirect template; an empty template is stored as `None`.
    pub redirect_template: Option<String>,
    /// Number of hostname captures the domain rule provides, or 0 when no rewrite
    /// template references a host capture.
    pub capture_cap_host: usize,
    /// Number of path captures the path rule provides, or 0 when no rewrite template
    /// references a path capture.
    pub capture_cap_path: usize,
    /// Parsed host rewrite template.
    pub rewrite_host: Option<RewriteParts>,
    /// Parsed path rewrite template.
    pub rewrite_path: Option<RewriteParts>,
    /// Port for the rewritten request, if configured.
    pub rewrite_port: Option<u16>,
    /// Header edits applied to requests.
    pub headers_request: Rc<[HeaderEdit]>,
    /// Header edits applied to responses, including the HSTS header when enabled.
    pub headers_response: Rc<[HeaderEdit]>,
    /// Whether the frontend requires authentication.
    pub required_auth: bool,
    /// Cached log tags of the frontend.
    pub tags: Option<Rc<CachedTags>>,
    /// `true` when the HSTS config was inherited from the listener default (and is
    /// present); such frontends are rebuilt by [`Router::refresh_inheriting_hsts`].
    pub inherits_listener_hsts: bool,
}
1397
/// Where a frontend's HSTS configuration came from.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum HstsOrigin {
    /// HSTS (or its absence) was configured on the frontend itself.
    Explicit,
    /// HSTS was copied from the listener default. When an HSTS config is present, the
    /// frontend is flagged `inherits_listener_hsts` and follows later listener
    /// changes through [`Router::refresh_inheriting_hsts`].
    InheritedFromListenerDefault,
}
1418
1419impl PartialEq for Frontend {
1420 fn eq(&self, other: &Self) -> bool {
1421 self.cluster_id == other.cluster_id
1425 && self.redirect == other.redirect
1426 && self.redirect_scheme == other.redirect_scheme
1427 && self.redirect_template == other.redirect_template
1428 && self.rewrite_host == other.rewrite_host
1429 && self.rewrite_path == other.rewrite_path
1430 && self.rewrite_port == other.rewrite_port
1431 && self.headers_request == other.headers_request
1432 && self.headers_response == other.headers_response
1433 && self.required_auth == other.required_auth
1434 }
1435}
1436
1437impl Eq for Frontend {}
1438
1439impl std::hash::Hash for Frontend {
1440 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
1441 self.cluster_id.hash(state);
1442 (self.redirect as i32).hash(state);
1445 (self.redirect_scheme as i32).hash(state);
1446 self.redirect_template.hash(state);
1447 self.required_auth.hash(state);
1448 }
1449}
1450
1451impl PartialOrd for Frontend {
1452 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
1453 Some(self.cmp(other))
1454 }
1455}
1456
1457impl Ord for Frontend {
1458 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
1459 self.cluster_id
1460 .cmp(&other.cluster_id)
1461 .then_with(|| (self.redirect as i32).cmp(&(other.redirect as i32)))
1462 .then_with(|| (self.redirect_scheme as i32).cmp(&(other.redirect_scheme as i32)))
1463 .then_with(|| self.redirect_template.cmp(&other.redirect_template))
1464 .then_with(|| self.required_auth.cmp(&other.required_auth))
1465 }
1466}
1467
1468impl Frontend {
1469 #[allow(clippy::too_many_arguments)]
1490 pub fn new(
1491 domain_rule: &DomainRule,
1492 path_rule: &PathRule,
1493 front: &HttpFrontend,
1494 redirect: RedirectPolicy,
1495 redirect_scheme: RedirectScheme,
1496 redirect_template: Option<String>,
1497 rewrite_host: Option<String>,
1498 rewrite_path: Option<String>,
1499 rewrite_port: Option<u16>,
1500 headers: &[sozu_command::proto::command::Header],
1501 required_auth: bool,
1502 hsts_origin: HstsOrigin,
1503 ) -> Result<Self, RouterError> {
1504 let hsts = front.hsts.as_ref();
1511 let inherits_listener_hsts =
1512 matches!(hsts_origin, HstsOrigin::InheritedFromListenerDefault) && hsts.is_some();
1513 let cluster_id = front.cluster_id.clone();
1514 let tags = front
1515 .tags
1516 .clone()
1517 .map(|tags| Rc::new(CachedTags::new(tags)));
1518
1519 let redirect_template = redirect_template.filter(|s| !s.is_empty());
1524 let rewrite_host = rewrite_host.filter(|s| !s.is_empty());
1525 let rewrite_path = rewrite_path.filter(|s| !s.is_empty());
1526
1527 let deny = match (&cluster_id, redirect) {
1528 (_, RedirectPolicy::Unauthorized) => true,
1529 (None, RedirectPolicy::Forward) => {
1530 warn!(
1531 "{} Frontend[domain: {:?}, path: {:?}]: forward on clusterless frontends are unauthorized",
1532 log_module_context!(),
1533 domain_rule,
1534 path_rule,
1535 );
1536 true
1537 }
1538 _ => false,
1539 };
1540 if deny {
1541 let mut deny_headers_response: Vec<HeaderEdit> = Vec::new();
1549 if let Some(cfg) = hsts
1550 && matches!(cfg.enabled, Some(true))
1551 && let Some(rendered) = render_hsts(cfg)
1552 {
1553 let mode = if matches!(cfg.force_replace_backend, Some(true)) {
1554 HeaderEditMode::Set
1555 } else {
1556 HeaderEditMode::SetIfAbsent
1557 };
1558 deny_headers_response.push(HeaderEdit {
1559 key: Rc::from(&b"strict-transport-security"[..]),
1560 val: rendered.into_bytes().into(),
1561 mode,
1562 });
1563 crate::incr!(names::http::HSTS_FRONTEND_ADDED);
1564 }
1565
1566 return Ok(Self {
1567 cluster_id,
1568 redirect: RedirectPolicy::Unauthorized,
1569 redirect_scheme,
1570 redirect_template: None,
1571 capture_cap_host: 0,
1572 capture_cap_path: 0,
1573 rewrite_host: None,
1574 rewrite_path: None,
1575 rewrite_port: None,
1576 headers_request: Rc::new([]),
1577 headers_response: deny_headers_response.into(),
1578 required_auth,
1579 tags,
1580 inherits_listener_hsts,
1581 });
1582 }
1583
1584 let mut capture_cap_host = match domain_rule {
1589 DomainRule::Any => 1,
1590 DomainRule::Exact(_) => 1,
1591 DomainRule::Wildcard(_) => 2,
1592 DomainRule::Regex(regex) => regex.captures_len(),
1593 };
1594 let mut capture_cap_path = match path_rule {
1595 PathRule::Equals(_) => 1,
1596 PathRule::Prefix(_) => 2,
1597 PathRule::Regex(regex) => regex.captures_len(),
1598 };
1599 let mut used_capture_host = 0usize;
1600 let mut used_capture_path = 0usize;
1601 let rewrite_host_parts = if let Some(p) = rewrite_host {
1602 Some(
1603 RewriteParts::parse(
1604 &p,
1605 capture_cap_host,
1606 capture_cap_path,
1607 &mut used_capture_host,
1608 &mut used_capture_path,
1609 )
1610 .ok_or(RouterError::InvalidHostRewrite(p))?,
1611 )
1612 } else {
1613 None
1614 };
1615 let rewrite_path_parts = if let Some(p) = rewrite_path {
1616 Some(
1617 RewriteParts::parse(
1618 &p,
1619 capture_cap_host,
1620 capture_cap_path,
1621 &mut used_capture_host,
1622 &mut used_capture_path,
1623 )
1624 .ok_or(RouterError::InvalidPathRewrite(p))?,
1625 )
1626 } else {
1627 None
1628 };
1629 if used_capture_host == 0 {
1632 capture_cap_host = 0;
1633 }
1634 if used_capture_path == 0 {
1635 capture_cap_path = 0;
1636 }
1637
1638 let mut headers_request = Vec::new();
1639 let mut headers_response = Vec::new();
1640 for header in headers {
1641 let edit = HeaderEdit {
1642 key: header.key.as_bytes().into(),
1643 val: header.val.as_bytes().into(),
1644 mode: HeaderEditMode::Append,
1645 };
1646 match header.position() {
1647 HeaderPosition::Request => headers_request.push(edit),
1648 HeaderPosition::Response => headers_response.push(edit),
1649 HeaderPosition::Both => {
1650 headers_request.push(edit.clone());
1651 headers_response.push(edit);
1652 }
1653 HeaderPosition::Unspecified => {
1659 warn!(
1660 "{} dropping Header {{ key: {:?}, val: {:?} }} with HEADER_POSITION_UNSPECIFIED",
1661 log_module_context!(),
1662 header.key,
1663 header.val,
1664 );
1665 }
1666 }
1667 }
1668
1669 if let Some(cfg) = hsts
1677 && matches!(cfg.enabled, Some(true))
1678 {
1679 if let Some(rendered) = render_hsts(cfg) {
1680 let mode = if matches!(cfg.force_replace_backend, Some(true)) {
1686 HeaderEditMode::Set
1687 } else {
1688 HeaderEditMode::SetIfAbsent
1689 };
1690 headers_response.push(HeaderEdit {
1691 key: Rc::from(&b"strict-transport-security"[..]),
1692 val: rendered.into_bytes().into(),
1693 mode,
1694 });
1695 crate::incr!(names::http::HSTS_FRONTEND_ADDED);
1696 } else {
1697 warn!(
1705 "{} HSTS enabled = true on frontend {:?} but render_hsts \
1706 returned None (max_age missing). Frontend will not emit \
1707 Strict-Transport-Security; the config layer that built \
1708 this HstsConfig must substitute DEFAULT_HSTS_MAX_AGE.",
1709 log_module_context!(),
1710 cluster_id,
1711 );
1712 crate::incr!(names::http::HSTS_UNRENDERED);
1713 }
1714 }
1715
1716 Ok(Frontend {
1717 cluster_id,
1718 redirect,
1719 redirect_scheme,
1720 redirect_template,
1721 capture_cap_host,
1722 capture_cap_path,
1723 rewrite_host: rewrite_host_parts,
1724 rewrite_path: rewrite_path_parts,
1725 rewrite_port,
1726 headers_request: headers_request.into(),
1727 headers_response: headers_response.into(),
1728 required_auth,
1729 tags,
1730 inherits_listener_hsts,
1731 })
1732 }
1733
1734 pub(crate) fn minimal_forward(cluster_id: ClusterId) -> Self {
1745 Self {
1746 cluster_id: Some(cluster_id),
1747 redirect: RedirectPolicy::Forward,
1748 redirect_scheme: RedirectScheme::UseSame,
1749 redirect_template: None,
1750 capture_cap_host: 0,
1751 capture_cap_path: 0,
1752 rewrite_host: None,
1753 rewrite_path: None,
1754 rewrite_port: None,
1755 headers_request: Rc::new([]),
1756 headers_response: Rc::new([]),
1757 required_auth: false,
1758 tags: None,
1759 inherits_listener_hsts: true,
1760 }
1761 }
1762
1763 pub(crate) fn minimal_deny() -> Self {
1774 Self {
1775 cluster_id: None,
1776 redirect: RedirectPolicy::Unauthorized,
1777 redirect_scheme: RedirectScheme::UseSame,
1778 redirect_template: None,
1779 capture_cap_host: 0,
1780 capture_cap_path: 0,
1781 rewrite_host: None,
1782 rewrite_path: None,
1783 rewrite_port: None,
1784 headers_request: Rc::new([]),
1785 headers_response: Rc::new([]),
1786 required_auth: false,
1787 tags: None,
1788 inherits_listener_hsts: true,
1789 }
1790 }
1791}
1792
/// Outcome of a router lookup: what the proxy should do with a request whose
/// host, path and method matched a route.
///
/// Lightweight routes produce it directly through [`RouteResult::forward`] or
/// [`RouteResult::deny`]. Full frontends produce it by copying the frontend's
/// settings and evaluating its host/path rewrite templates against the
/// captures taken from the request.
#[derive(Debug, Clone, PartialEq)]
pub struct RouteResult {
    /// Cluster the request is attributed to; `None` for clusterless denials.
    pub cluster_id: Option<ClusterId>,
    /// Routing policy (`Forward`, `Unauthorized`, …) copied from the frontend.
    pub redirect: RedirectPolicy,
    /// Redirect scheme copied from the frontend; `UseSame` for `forward`/`deny`.
    pub redirect_scheme: RedirectScheme,
    /// Redirect template copied from the frontend, if any.
    pub redirect_template: Option<String>,
    /// Host produced by the frontend's host rewrite template; `None` when no
    /// rewrite is configured or the request is denied.
    pub rewritten_host: Option<String>,
    /// Path produced by the frontend's path rewrite template; `None` when no
    /// rewrite is configured or the request is denied.
    pub rewritten_path: Option<String>,
    /// Port override copied from the frontend; `None` when the request is denied.
    pub rewritten_port: Option<u16>,
    /// Header edits to apply on the request side; empty for denied requests.
    pub headers_request: Rc<[HeaderEdit]>,
    /// Header edits to apply on the response side. These are kept even for
    /// denied requests, so a deny frontend can still emit e.g. its HSTS header.
    pub headers_response: Rc<[HeaderEdit]>,
    /// Whether the frontend requires authentication (`false` for `forward`/`deny`).
    pub required_auth: bool,
    /// Cached logging tags copied from the frontend, if any.
    pub tags: Option<Rc<CachedTags>>,
}
1821
1822impl RouteResult {
1823 pub fn deny(cluster_id: Option<ClusterId>) -> Self {
1825 Self {
1826 cluster_id,
1827 redirect: RedirectPolicy::Unauthorized,
1828 redirect_scheme: RedirectScheme::UseSame,
1829 redirect_template: None,
1830 rewritten_host: None,
1831 rewritten_path: None,
1832 rewritten_port: None,
1833 headers_request: Rc::new([]),
1834 headers_response: Rc::new([]),
1835 required_auth: false,
1836 tags: None,
1837 }
1838 }
1839
1840 pub fn forward(cluster_id: ClusterId) -> Self {
1843 Self {
1844 cluster_id: Some(cluster_id),
1845 redirect: RedirectPolicy::Forward,
1846 redirect_scheme: RedirectScheme::UseSame,
1847 redirect_template: None,
1848 rewritten_host: None,
1849 rewritten_path: None,
1850 rewritten_port: None,
1851 headers_request: Rc::new([]),
1852 headers_response: Rc::new([]),
1853 required_auth: false,
1854 tags: None,
1855 }
1856 }
1857
1858 fn from_frontend(
1861 frontend: &Frontend,
1862 captures_host: Vec<&str>,
1863 path: &[u8],
1864 path_rule: &PathRule,
1865 ) -> Self {
1866 if frontend.redirect == RedirectPolicy::Unauthorized {
1874 return Self {
1875 cluster_id: frontend.cluster_id.clone(),
1876 redirect: RedirectPolicy::Unauthorized,
1877 redirect_scheme: frontend.redirect_scheme,
1878 redirect_template: frontend.redirect_template.clone(),
1879 rewritten_host: None,
1880 rewritten_path: None,
1881 rewritten_port: None,
1882 headers_request: Rc::new([]),
1883 headers_response: frontend.headers_response.clone(),
1884 required_auth: frontend.required_auth,
1885 tags: frontend.tags.clone(),
1886 };
1887 }
1888
1889 let mut captures_path: Vec<&str> = Vec::with_capacity(frontend.capture_cap_path);
1890 if frontend.capture_cap_path > 0 {
1891 captures_path.push(from_utf8(path).unwrap_or_default());
1892 match path_rule {
1893 PathRule::Prefix(prefix) => {
1894 let tail_start = prefix.len().min(path.len());
1895 captures_path.push(from_utf8(&path[tail_start..]).unwrap_or_default());
1896 }
1897 PathRule::Regex(regex) => {
1898 if let Some(caps) = regex.captures(path) {
1899 captures_path.extend(caps.iter().skip(1).map(|c| {
1900 c.map(|m| from_utf8(m.as_bytes()).unwrap_or_default())
1901 .unwrap_or("")
1902 }));
1903 }
1904 }
1905 PathRule::Equals(_) => {}
1906 }
1907 }
1908
1909 Self {
1910 cluster_id: frontend.cluster_id.clone(),
1911 redirect: frontend.redirect,
1912 redirect_scheme: frontend.redirect_scheme,
1913 redirect_template: frontend.redirect_template.clone(),
1914 rewritten_host: frontend
1915 .rewrite_host
1916 .as_ref()
1917 .map(|rewrite| rewrite.run(&captures_host, &captures_path)),
1918 rewritten_path: frontend
1919 .rewrite_path
1920 .as_ref()
1921 .map(|rewrite| rewrite.run(&captures_host, &captures_path)),
1922 rewritten_port: frontend.rewrite_port,
1923 headers_request: frontend.headers_request.clone(),
1924 headers_response: frontend.headers_response.clone(),
1925 required_auth: frontend.required_auth,
1926 tags: frontend.tags.clone(),
1927 }
1928 }
1929
1930 fn new_no_trie<'a>(
1935 domain: &'a [u8],
1936 domain_rule: &DomainRule,
1937 path: &'a [u8],
1938 path_rule: &PathRule,
1939 route: &Route,
1940 ) -> Self {
1941 let frontend = match route {
1942 Route::Frontend(f) => f.clone(),
1943 Route::ClusterId(id) => return Self::forward(id.clone()),
1944 Route::Deny => return Self::deny(None),
1945 };
1946 let mut captures_host: Vec<&str> = Vec::with_capacity(frontend.capture_cap_host);
1947 if frontend.capture_cap_host > 0 {
1948 captures_host.push(from_utf8(domain).unwrap_or_default());
1949 match domain_rule {
1950 DomainRule::Wildcard(suffix) => {
1951 let head_end = domain.len().saturating_sub(suffix.len().saturating_sub(1));
1952 captures_host.push(from_utf8(&domain[..head_end]).unwrap_or_default());
1953 }
1954 DomainRule::Regex(regex) => {
1955 if let Some(caps) = regex.captures(domain) {
1956 captures_host.extend(caps.iter().skip(1).map(|c| {
1957 c.map(|m| from_utf8(m.as_bytes()).unwrap_or_default())
1958 .unwrap_or("")
1959 }));
1960 }
1961 }
1962 DomainRule::Any | DomainRule::Exact(_) => {}
1963 }
1964 }
1965 Self::from_frontend(&frontend, captures_host, path, path_rule)
1966 }
1967
1968 fn new_with_trie<'a, 'b>(
1973 domain: &'a [u8],
1974 domain_submatches: TrieMatches<'a, 'b>,
1975 path: &'a [u8],
1976 path_rule: &PathRule,
1977 route: &Route,
1978 ) -> Self {
1979 let frontend = match route {
1980 Route::Frontend(f) => f.clone(),
1981 Route::ClusterId(id) => return Self::forward(id.clone()),
1982 Route::Deny => return Self::deny(None),
1983 };
1984 let mut captures_host: Vec<&str> = Vec::with_capacity(frontend.capture_cap_host);
1985 if frontend.capture_cap_host > 0 {
1986 captures_host.push(from_utf8(domain).unwrap_or_default());
1987 for submatch in &domain_submatches {
1988 match submatch {
1989 TrieSubMatch::Wildcard(part) => {
1990 captures_host.push(from_utf8(part).unwrap_or_default());
1991 }
1992 TrieSubMatch::Regexp(part, regex) => {
1993 if let Some(caps) = regex.captures(part) {
1994 captures_host.extend(caps.iter().skip(1).map(|c| {
1995 c.map(|m| from_utf8(m.as_bytes()).unwrap_or_default())
1996 .unwrap_or("")
1997 }));
1998 }
1999 }
2000 }
2001 }
2002 }
2003 Self::from_frontend(&frontend, captures_host, path, path_rule)
2004 }
2005}
2006
2007#[cfg(test)]
2008mod tests {
2009 use super::*;
2010
2011 #[test]
2012 fn render_hsts_max_age_only() {
2013 let cfg = HstsConfig {
2014 enabled: Some(true),
2015 max_age: Some(31_536_000),
2016 include_subdomains: None,
2017 preload: None,
2018 force_replace_backend: None,
2019 };
2020 assert_eq!(render_hsts(&cfg), Some("max-age=31536000".to_owned()));
2021 }
2022
2023 #[test]
2024 fn render_hsts_with_include_subdomains() {
2025 let cfg = HstsConfig {
2026 enabled: Some(true),
2027 max_age: Some(31_536_000),
2028 include_subdomains: Some(true),
2029 preload: None,
2030 force_replace_backend: None,
2031 };
2032 assert_eq!(
2033 render_hsts(&cfg),
2034 Some("max-age=31536000; includeSubDomains".to_owned())
2035 );
2036 }
2037
2038 #[test]
2039 fn render_hsts_with_preload_only() {
2040 let cfg = HstsConfig {
2041 enabled: Some(true),
2042 max_age: Some(63_072_000),
2043 include_subdomains: None,
2044 preload: Some(true),
2045 force_replace_backend: None,
2046 };
2047 assert_eq!(
2048 render_hsts(&cfg),
2049 Some("max-age=63072000; preload".to_owned())
2050 );
2051 }
2052
2053 #[test]
2054 fn render_hsts_full() {
2055 let cfg = HstsConfig {
2056 enabled: Some(true),
2057 max_age: Some(31_536_000),
2058 include_subdomains: Some(true),
2059 preload: Some(true),
2060 force_replace_backend: None,
2061 };
2062 assert_eq!(
2063 render_hsts(&cfg),
2064 Some("max-age=31536000; includeSubDomains; preload".to_owned())
2065 );
2066 }
2067
2068 #[test]
2069 fn render_hsts_kill_switch_max_age_zero() {
2070 let cfg = HstsConfig {
2071 enabled: Some(true),
2072 max_age: Some(0),
2073 include_subdomains: Some(true),
2074 preload: None,
2075 force_replace_backend: None,
2076 };
2077 assert_eq!(
2081 render_hsts(&cfg),
2082 Some("max-age=0; includeSubDomains".to_owned())
2083 );
2084 }
2085
2086 #[test]
2087 fn render_hsts_omitted_when_max_age_missing() {
2088 let cfg = HstsConfig {
2089 enabled: Some(true),
2090 max_age: None,
2091 include_subdomains: Some(true),
2092 preload: None,
2093 force_replace_backend: None,
2094 };
2095 assert_eq!(render_hsts(&cfg), None);
2099 }
2100
2101 #[test]
2102 fn rebuild_with_listener_hsts_replaces_existing_entry() {
2103 let frontend = Frontend {
2107 cluster_id: Some("api".to_owned()),
2108 redirect: RedirectPolicy::Forward,
2109 redirect_scheme: RedirectScheme::UseSame,
2110 redirect_template: None,
2111 capture_cap_host: 0,
2112 capture_cap_path: 0,
2113 rewrite_host: None,
2114 rewrite_path: None,
2115 rewrite_port: None,
2116 headers_request: Rc::new([]),
2117 headers_response: Rc::from(vec![
2118 HeaderEdit {
2119 key: Rc::from(&b"x-cache"[..]),
2120 val: Rc::from(&b"hit"[..]),
2121 mode: HeaderEditMode::Append,
2122 },
2123 HeaderEdit {
2124 key: Rc::from(&b"strict-transport-security"[..]),
2125 val: Rc::from(&b"max-age=31536000"[..]),
2126 mode: HeaderEditMode::SetIfAbsent,
2127 },
2128 ]),
2129 required_auth: false,
2130 tags: None,
2131 inherits_listener_hsts: true,
2132 };
2133 let new_hsts = HstsConfig {
2134 enabled: Some(true),
2135 max_age: Some(63_072_000),
2136 include_subdomains: Some(true),
2137 preload: None,
2138 force_replace_backend: None,
2139 };
2140 let new_edit = build_listener_hsts_edit(Some(&new_hsts));
2141 let rebuilt = rebuild_with_listener_hsts(&frontend, new_edit.as_ref());
2142
2143 let response: Vec<_> = rebuilt.headers_response.iter().collect();
2144 assert_eq!(response.len(), 2, "x-cache + new STS, no leftover STS");
2145 assert_eq!(&*response[0].key, b"x-cache");
2146 assert_eq!(&*response[1].key, b"strict-transport-security");
2147 assert_eq!(
2148 &*response[1].val,
2149 b"max-age=63072000; includeSubDomains".as_slice()
2150 );
2151 assert!(rebuilt.inherits_listener_hsts);
2152 }
2153
2154 #[test]
2155 fn rebuild_with_listener_hsts_strips_when_none() {
2156 let frontend = Frontend {
2159 cluster_id: Some("api".to_owned()),
2160 redirect: RedirectPolicy::Forward,
2161 redirect_scheme: RedirectScheme::UseSame,
2162 redirect_template: None,
2163 capture_cap_host: 0,
2164 capture_cap_path: 0,
2165 rewrite_host: None,
2166 rewrite_path: None,
2167 rewrite_port: None,
2168 headers_request: Rc::new([]),
2169 headers_response: Rc::from(vec![
2170 HeaderEdit {
2171 key: Rc::from(&b"x-cache"[..]),
2172 val: Rc::from(&b"hit"[..]),
2173 mode: HeaderEditMode::Append,
2174 },
2175 HeaderEdit {
2176 key: Rc::from(&b"strict-transport-security"[..]),
2177 val: Rc::from(&b"max-age=31536000"[..]),
2178 mode: HeaderEditMode::SetIfAbsent,
2179 },
2180 ]),
2181 required_auth: false,
2182 tags: None,
2183 inherits_listener_hsts: true,
2184 };
2185 let new_edit = build_listener_hsts_edit(None);
2186 let rebuilt = rebuild_with_listener_hsts(&frontend, new_edit.as_ref());
2187 let response: Vec<_> = rebuilt.headers_response.iter().collect();
2188 assert_eq!(response.len(), 1);
2189 assert_eq!(&*response[0].key, b"x-cache");
2190 }
2191
2192 #[test]
2193 fn rebuild_with_listener_hsts_disabled_strips() {
2194 let frontend = Frontend {
2197 cluster_id: Some("api".to_owned()),
2198 redirect: RedirectPolicy::Forward,
2199 redirect_scheme: RedirectScheme::UseSame,
2200 redirect_template: None,
2201 capture_cap_host: 0,
2202 capture_cap_path: 0,
2203 rewrite_host: None,
2204 rewrite_path: None,
2205 rewrite_port: None,
2206 headers_request: Rc::new([]),
2207 headers_response: Rc::from(vec![HeaderEdit {
2208 key: Rc::from(&b"strict-transport-security"[..]),
2209 val: Rc::from(&b"max-age=31536000"[..]),
2210 mode: HeaderEditMode::SetIfAbsent,
2211 }]),
2212 required_auth: false,
2213 tags: None,
2214 inherits_listener_hsts: true,
2215 };
2216 let new_hsts = HstsConfig {
2217 enabled: Some(false),
2218 max_age: None,
2219 include_subdomains: None,
2220 preload: None,
2221 force_replace_backend: None,
2222 };
2223 let new_edit = build_listener_hsts_edit(Some(&new_hsts));
2224 let rebuilt = rebuild_with_listener_hsts(&frontend, new_edit.as_ref());
2225 assert_eq!(rebuilt.headers_response.len(), 0);
2226 }
2227
2228 #[test]
2229 fn refresh_inheriting_hsts_skips_explicit_overrides() {
2230 use crate::router::pattern_trie::TrieNode;
2234 let mut router = Router {
2235 pre: Vec::new(),
2236 tree: TrieNode::root(),
2237 post: Vec::new(),
2238 };
2239 let inheriting = Frontend {
2240 cluster_id: Some("api".to_owned()),
2241 redirect: RedirectPolicy::Forward,
2242 redirect_scheme: RedirectScheme::UseSame,
2243 redirect_template: None,
2244 capture_cap_host: 0,
2245 capture_cap_path: 0,
2246 rewrite_host: None,
2247 rewrite_path: None,
2248 rewrite_port: None,
2249 headers_request: Rc::new([]),
2250 headers_response: Rc::from(vec![HeaderEdit {
2251 key: Rc::from(&b"strict-transport-security"[..]),
2252 val: Rc::from(&b"max-age=31536000"[..]),
2253 mode: HeaderEditMode::SetIfAbsent,
2254 }]),
2255 required_auth: false,
2256 tags: None,
2257 inherits_listener_hsts: true,
2258 };
2259 let explicit = Frontend {
2260 cluster_id: Some("legacy".to_owned()),
2261 redirect: RedirectPolicy::Forward,
2262 redirect_scheme: RedirectScheme::UseSame,
2263 redirect_template: None,
2264 capture_cap_host: 0,
2265 capture_cap_path: 0,
2266 rewrite_host: None,
2267 rewrite_path: None,
2268 rewrite_port: None,
2269 headers_request: Rc::new([]),
2270 headers_response: Rc::from(vec![HeaderEdit {
2271 key: Rc::from(&b"strict-transport-security"[..]),
2272 val: Rc::from(&b"max-age=300"[..]),
2273 mode: HeaderEditMode::SetIfAbsent,
2274 }]),
2275 required_auth: false,
2276 tags: None,
2277 inherits_listener_hsts: false,
2278 };
2279 router.pre.push((
2280 DomainRule::Any,
2281 PathRule::Prefix("/api".to_owned()),
2282 MethodRule::new(None),
2283 Route::Frontend(Rc::new(inheriting)),
2284 ));
2285 router.post.push((
2286 DomainRule::Any,
2287 PathRule::Prefix("/legacy".to_owned()),
2288 MethodRule::new(None),
2289 Route::Frontend(Rc::new(explicit)),
2290 ));
2291
2292 let new_hsts = HstsConfig {
2293 enabled: Some(true),
2294 max_age: Some(63_072_000),
2295 include_subdomains: Some(true),
2296 preload: None,
2297 force_replace_backend: None,
2298 };
2299 let count = router.refresh_inheriting_hsts(Some(&new_hsts));
2300 assert_eq!(count, 1, "only the inheriting frontend should refresh");
2301
2302 if let Route::Frontend(rc) = &router.pre[0].3 {
2303 let response: Vec<_> = rc.headers_response.iter().collect();
2304 assert_eq!(
2305 &*response.last().unwrap().val,
2306 b"max-age=63072000; includeSubDomains".as_slice(),
2307 "inheriting frontend's STS must reflect the new listener default"
2308 );
2309 } else {
2310 panic!("pre[0] should be Route::Frontend");
2311 }
2312 if let Route::Frontend(rc) = &router.post[0].3 {
2313 let response: Vec<_> = rc.headers_response.iter().collect();
2314 assert_eq!(
2315 &*response.last().unwrap().val,
2316 b"max-age=300".as_slice(),
2317 "explicit override must be preserved unchanged"
2318 );
2319 } else {
2320 panic!("post[0] should be Route::Frontend");
2321 }
2322 }
2323
2324 #[test]
2325 fn refresh_inheriting_hsts_promotes_clusterid_on_enable() {
2326 use crate::router::pattern_trie::TrieNode;
2336 let mut router = Router {
2337 pre: Vec::new(),
2338 tree: TrieNode::root(),
2339 post: vec![(
2340 DomainRule::Any,
2341 PathRule::Prefix("/".to_owned()),
2342 MethodRule::new(None),
2343 Route::ClusterId("api".to_owned()),
2344 )],
2345 };
2346
2347 let new_hsts = HstsConfig {
2348 enabled: Some(true),
2349 max_age: Some(31_536_000),
2350 include_subdomains: Some(true),
2351 preload: None,
2352 force_replace_backend: None,
2353 };
2354 let count = router.refresh_inheriting_hsts(Some(&new_hsts));
2355 assert_eq!(count, 1, "the ClusterId entry must be promoted + counted");
2356
2357 let Route::Frontend(rc) = &router.post[0].3 else {
2358 panic!("post[0] should now be Route::Frontend, not the original Route::ClusterId");
2359 };
2360 assert_eq!(rc.cluster_id.as_deref(), Some("api"));
2361 assert_eq!(
2362 rc.redirect,
2363 RedirectPolicy::Forward,
2364 "promoted entry must keep Forward semantics so lookup yields the same backend"
2365 );
2366 assert!(
2367 rc.inherits_listener_hsts,
2368 "promoted entry must mark itself inheriting so the next patch refreshes it"
2369 );
2370 let response: Vec<_> = rc.headers_response.iter().collect();
2371 assert_eq!(
2372 response.len(),
2373 1,
2374 "promoted entry carries exactly one STS edit, no operator headers"
2375 );
2376 assert_eq!(&*response[0].key, b"strict-transport-security");
2377 assert_eq!(
2378 &*response[0].val,
2379 b"max-age=31536000; includeSubDomains".as_slice()
2380 );
2381 }
2382
2383 #[test]
2384 fn refresh_inheriting_hsts_promotes_deny_on_enable() {
2385 use crate::router::pattern_trie::TrieNode;
2391 let mut router = Router {
2392 pre: Vec::new(),
2393 tree: TrieNode::root(),
2394 post: vec![(
2395 DomainRule::Any,
2396 PathRule::Prefix("/forbidden".to_owned()),
2397 MethodRule::new(None),
2398 Route::Deny,
2399 )],
2400 };
2401
2402 let new_hsts = HstsConfig {
2403 enabled: Some(true),
2404 max_age: Some(31_536_000),
2405 include_subdomains: None,
2406 preload: None,
2407 force_replace_backend: None,
2408 };
2409 let count = router.refresh_inheriting_hsts(Some(&new_hsts));
2410 assert_eq!(count, 1);
2411
2412 let Route::Frontend(rc) = &router.post[0].3 else {
2413 panic!("post[0] should now be Route::Frontend, not the original Route::Deny");
2414 };
2415 assert_eq!(rc.cluster_id, None, "promoted Deny stays clusterless");
2416 assert_eq!(
2417 rc.redirect,
2418 RedirectPolicy::Unauthorized,
2419 "promoted Deny must keep Unauthorized so lookup yields a 401"
2420 );
2421 assert!(rc.inherits_listener_hsts);
2422 let response: Vec<_> = rc.headers_response.iter().collect();
2423 assert_eq!(response.len(), 1);
2424 assert_eq!(&*response[0].key, b"strict-transport-security");
2425 assert_eq!(&*response[0].val, b"max-age=31536000".as_slice());
2426 }
2427
2428 #[test]
2429 fn refresh_inheriting_hsts_skips_lightweight_on_disable() {
2430 use crate::router::pattern_trie::TrieNode;
2440 let make_router = || Router {
2441 pre: vec![(
2442 DomainRule::Any,
2443 PathRule::Prefix("/".to_owned()),
2444 MethodRule::new(None),
2445 Route::ClusterId("api".to_owned()),
2446 )],
2447 tree: TrieNode::root(),
2448 post: vec![(
2449 DomainRule::Any,
2450 PathRule::Prefix("/forbidden".to_owned()),
2451 MethodRule::new(None),
2452 Route::Deny,
2453 )],
2454 };
2455
2456 for (label, hsts) in [
2457 ("none", None),
2458 (
2459 "disabled",
2460 Some(HstsConfig {
2461 enabled: Some(false),
2462 max_age: None,
2463 include_subdomains: None,
2464 preload: None,
2465 force_replace_backend: None,
2466 }),
2467 ),
2468 (
2469 "enabled-without-max-age",
2470 Some(HstsConfig {
2471 enabled: Some(true),
2472 max_age: None,
2473 include_subdomains: None,
2474 preload: None,
2475 force_replace_backend: None,
2476 }),
2477 ),
2478 ] {
2479 let mut router = make_router();
2480 let count = router.refresh_inheriting_hsts(hsts.as_ref());
2481 assert_eq!(count, 0, "no promotion expected for {label}");
2482 assert!(
2483 matches!(router.pre[0].3, Route::ClusterId(_)),
2484 "{label}: ClusterId must stay lightweight"
2485 );
2486 assert!(
2487 matches!(router.post[0].3, Route::Deny),
2488 "{label}: Deny must stay lightweight"
2489 );
2490 }
2491 }
2492
2493 #[test]
2494 fn refresh_inheriting_hsts_promoted_entry_refreshes_on_subsequent_patches() {
2495 use crate::router::pattern_trie::TrieNode;
2500 let mut router = Router {
2501 pre: Vec::new(),
2502 tree: TrieNode::root(),
2503 post: vec![(
2504 DomainRule::Any,
2505 PathRule::Prefix("/".to_owned()),
2506 MethodRule::new(None),
2507 Route::ClusterId("api".to_owned()),
2508 )],
2509 };
2510
2511 let first_patch = HstsConfig {
2512 enabled: Some(true),
2513 max_age: Some(31_536_000),
2514 include_subdomains: None,
2515 preload: None,
2516 force_replace_backend: None,
2517 };
2518 assert_eq!(router.refresh_inheriting_hsts(Some(&first_patch)), 1);
2519
2520 let second_patch = HstsConfig {
2521 enabled: Some(true),
2522 max_age: Some(63_072_000),
2523 include_subdomains: Some(true),
2524 preload: None,
2525 force_replace_backend: None,
2526 };
2527 assert_eq!(
2528 router.refresh_inheriting_hsts(Some(&second_patch)),
2529 1,
2530 "the previously promoted entry must be re-counted via the path-1 branch"
2531 );
2532
2533 let Route::Frontend(rc) = &router.post[0].3 else {
2534 panic!("post[0] should still be Route::Frontend after the second patch");
2535 };
2536 let response: Vec<_> = rc.headers_response.iter().collect();
2537 assert_eq!(
2538 response.len(),
2539 1,
2540 "second patch must REPLACE the existing STS edit, not append a duplicate"
2541 );
2542 assert_eq!(
2543 &*response[0].val,
2544 b"max-age=63072000; includeSubDomains".as_slice()
2545 );
2546 }
2547
2548 #[test]
2549 fn refresh_inheriting_hsts_promoted_entry_loses_hsts_on_disable_patch() {
2550 use crate::router::pattern_trie::TrieNode;
2557 let mut router = Router {
2558 pre: vec![(
2559 DomainRule::Any,
2560 PathRule::Prefix("/".to_owned()),
2561 MethodRule::new(None),
2562 Route::ClusterId("api".to_owned()),
2563 )],
2564 tree: TrieNode::root(),
2565 post: Vec::new(),
2566 };
2567
2568 let enable = HstsConfig {
2569 enabled: Some(true),
2570 max_age: Some(31_536_000),
2571 include_subdomains: None,
2572 preload: None,
2573 force_replace_backend: None,
2574 };
2575 assert_eq!(router.refresh_inheriting_hsts(Some(&enable)), 1);
2576
2577 let disable = HstsConfig {
2578 enabled: Some(false),
2579 max_age: None,
2580 include_subdomains: None,
2581 preload: None,
2582 force_replace_backend: None,
2583 };
2584 assert_eq!(
2585 router.refresh_inheriting_hsts(Some(&disable)),
2586 1,
2587 "the promoted entry must still be touched on disable to strip its STS edit"
2588 );
2589
2590 let Route::Frontend(rc) = &router.pre[0].3 else {
2591 panic!("pre[0] should still be Route::Frontend (no demotion)");
2592 };
2593 assert_eq!(rc.cluster_id.as_deref(), Some("api"));
2594 assert_eq!(
2595 rc.headers_response.len(),
2596 0,
2597 "disable patch must strip the STS edit from the promoted entry"
2598 );
2599 }
2600
2601 #[test]
2602 fn refresh_inheriting_hsts_promotes_clusterid_in_trie_on_enable() {
2603 use crate::router::pattern_trie::TrieNode;
2610 let mut router = Router {
2611 pre: Vec::new(),
2612 tree: TrieNode::root(),
2613 post: Vec::new(),
2614 };
2615 let path_rule = PathRule::Prefix("/".to_owned());
2616 let method_rule = MethodRule::new(None);
2617 assert!(router.add_tree_rule(
2618 b"example.com",
2619 &path_rule,
2620 &method_rule,
2621 &Route::ClusterId("api".to_owned()),
2622 ));
2623
2624 let new_hsts = HstsConfig {
2625 enabled: Some(true),
2626 max_age: Some(31_536_000),
2627 include_subdomains: Some(true),
2628 preload: None,
2629 force_replace_backend: None,
2630 };
2631 let count = router.refresh_inheriting_hsts(Some(&new_hsts));
2632 assert_eq!(
2633 count, 1,
2634 "trie-resident ClusterId must be promoted + counted"
2635 );
2636
2637 let (_, paths) = router
2638 .tree
2639 .domain_lookup_mut(b"example.com", false)
2640 .expect("trie leaf still present after refresh");
2641 assert_eq!(paths.len(), 1);
2642 let Route::Frontend(rc) = &paths[0].2 else {
2643 panic!("trie leaf should now be Route::Frontend, not Route::ClusterId");
2644 };
2645 assert_eq!(rc.cluster_id.as_deref(), Some("api"));
2646 assert_eq!(rc.redirect, RedirectPolicy::Forward);
2647 assert!(rc.inherits_listener_hsts);
2648 let response: Vec<_> = rc.headers_response.iter().collect();
2649 assert_eq!(response.len(), 1);
2650 assert_eq!(&*response[0].key, b"strict-transport-security");
2651 assert_eq!(
2652 &*response[0].val,
2653 b"max-age=31536000; includeSubDomains".as_slice()
2654 );
2655 }
2656
2657 #[test]
2658 fn convert_regex() {
2659 assert_eq!(
2662 convert_regex_domain_rule("www.example.com")
2663 .unwrap()
2664 .as_str(),
2665 "\\Awww\\.example\\.com\\z"
2666 );
2667 assert_eq!(
2668 convert_regex_domain_rule("*.example.com").unwrap().as_str(),
2669 "\\A*\\.example\\.com\\z"
2670 );
2671 assert_eq!(
2672 convert_regex_domain_rule("test.*.example.com")
2673 .unwrap()
2674 .as_str(),
2675 "\\Atest\\.*\\.example\\.com\\z"
2676 );
2677 assert_eq!(
2678 convert_regex_domain_rule("css./cdn[a-z0-9]+/.example.com")
2679 .unwrap()
2680 .as_str(),
2681 "\\Acss\\.cdn[a-z0-9]+\\.example\\.com\\z"
2682 );
2683
2684 assert_eq!(
2685 convert_regex_domain_rule("css./cdn[a-z0-9]+.example.com"),
2686 None
2687 );
2688 assert_eq!(
2689 convert_regex_domain_rule("css./cdn[a-z0-9]+/a.example.com"),
2690 None
2691 );
2692 }
2693
2694 #[test]
2699 fn regex_domain_rule_rejects_suffix_and_prefix() {
2700 let rule: DomainRule = "/example\\.com/".parse().unwrap();
2701 assert!(rule.matches(b"example.com"));
2702 assert!(!rule.matches(b"attacker.example.com"));
2703 assert!(!rule.matches(b"example.com.evil.org"));
2704 assert!(!rule.matches(b"prefixexample.com"));
2705 assert!(!rule.matches(b"example.commercial"));
2706 }
2707
2708 #[test]
2714 fn regex_domain_rule_multi_segment_segments_are_isolated() {
2715 let pattern = convert_regex_domain_rule("/seg1/.foo./seg2/.com")
2716 .expect("multi-segment regex hostname must compile");
2717 assert_eq!(pattern.as_str(), "\\Aseg1\\.foo\\.seg2\\.com\\z");
2718 }
2719
2720 #[test]
2721 fn parse_domain_rule() {
2722 assert_eq!("*".parse::<DomainRule>().unwrap(), DomainRule::Any);
2723 assert_eq!(
2724 "www.example.com".parse::<DomainRule>().unwrap(),
2725 DomainRule::Exact("www.example.com".to_string())
2726 );
2727 assert_eq!(
2728 "*.example.com".parse::<DomainRule>().unwrap(),
2729 DomainRule::Wildcard("*.example.com".to_string())
2730 );
2731 assert_eq!("test.*.example.com".parse::<DomainRule>(), Err(()));
2732 assert_eq!(
2733 "/cdn[0-9]+/.example.com".parse::<DomainRule>().unwrap(),
2734 DomainRule::Regex(Regex::new("\\Acdn[0-9]+\\.example\\.com\\z").unwrap())
2735 );
2736 }
2737
2738 #[test]
2739 fn match_domain_rule() {
2740 assert!(DomainRule::Any.matches("www.example.com".as_bytes()));
2741 assert!(
2742 DomainRule::Exact("www.example.com".to_string()).matches("www.example.com".as_bytes())
2743 );
2744 assert!(
2745 DomainRule::Wildcard("*.example.com".to_string()).matches("www.example.com".as_bytes())
2746 );
2747 assert!(
2748 !DomainRule::Wildcard("*.example.com".to_string())
2749 .matches("test.www.example.com".as_bytes())
2750 );
2751 assert!(
2752 "/cdn[0-9]+/.example.com"
2753 .parse::<DomainRule>()
2754 .unwrap()
2755 .matches("cdn1.example.com".as_bytes())
2756 );
2757 assert!(
2758 !"/cdn[0-9]+/.example.com"
2759 .parse::<DomainRule>()
2760 .unwrap()
2761 .matches("www.example.com".as_bytes())
2762 );
2763 assert!(
2764 !"/cdn[0-9]+/.example.com"
2765 .parse::<DomainRule>()
2766 .unwrap()
2767 .matches("cdn10.exampleAcom".as_bytes())
2768 );
2769 }
2770
2771 #[test]
2772 fn match_domain_rule_wildcard_short_hostname_does_not_panic() {
2773 let rule = DomainRule::Wildcard("*.foo.example.com".to_string());
2774
2775 assert!(!rule.matches(b""));
2777
2778 assert!(!rule.matches(b"a.b"));
2780 assert!(!rule.matches(b"x"));
2781
2782 assert!(!rule.matches(b".foo.example.com"));
2785
2786 assert!(!rule.matches(b"y.x.foo.example.com"));
2788
2789 assert!(rule.matches(b"x.foo.example.com"));
2792 }
2793
2794 #[test]
2795 fn router_lookup_wildcard_pre_rule_short_hostname_does_not_panic() {
2796 let mut router = Router::new();
2797
2798 assert!(router.add_pre_rule(
2801 &"*.foo.example.com".parse::<DomainRule>().unwrap(),
2802 &PathRule::Prefix("/".to_string()),
2803 &MethodRule::new(Some("GET".to_string())),
2804 &Route::ClusterId("wildcard".to_string()),
2805 ));
2806
2807 let method = Method::new(&b"GET"[..]);
2808
2809 assert!(router.lookup("", "/", &method).is_err());
2811 assert!(router.lookup("x", "/", &method).is_err());
2812 assert!(router.lookup("a.b", "/", &method).is_err());
2813
2814 assert!(router.lookup(".foo.example.com", "/", &method).is_err());
2816
2817 assert_eq!(
2819 router.lookup("x.foo.example.com", "/", &method),
2820 Ok(RouteResult::forward("wildcard".to_string()))
2821 );
2822 }
2823
#[test]
/// `PathRule::Prefix` matching: an empty prefix accepts every path, and a
/// non-empty prefix accepts the path itself and anything below it but not an
/// unrelated shorter path.
fn match_path_rule() {
    // (rule prefix, request path, whether a match is expected)
    let cases = [
        ("", "/", true),
        ("", "/hello", true),
        ("/hello", "/hello", true),
        ("/hello", "/hello/world", true),
        ("/hello", "/", false),
    ];

    for (prefix, path, should_match) in cases {
        let outcome = PathRule::Prefix(prefix.to_string()).matches(path.as_bytes());
        let matched = outcome != PathRuleResult::None;
        assert_eq!(
            matched, should_match,
            "prefix {prefix:?} against path {path:?}"
        );
    }
}
2842
#[test]
/// A wildcard domain node can carry several path rules. First a catch-all `""`
/// prefix is added on `*.sozu.io`, then a more specific `/api` prefix.
/// Afterwards `/api` must reach `api`, while `/ap`, which is shorter than the
/// `/api` prefix, still falls back to `base`.
fn multiple_children_on_a_wildcard() {
    let get_rule = || MethodRule::new(Some("GET".to_string()));
    let forward_to = |cluster: &str| RouteResult::forward(cluster.to_string());
    let mut router = Router::new();

    let catch_all_added = router.add_tree_rule(
        b"*.sozu.io",
        &PathRule::Prefix("".to_string()),
        &get_rule(),
        &Route::ClusterId("base".to_string()),
    );
    assert!(catch_all_added);
    println!("{:#?}", router.tree);
    assert_eq!(
        router.lookup("www.sozu.io", "/api", &Method::Get),
        Ok(forward_to("base"))
    );

    let api_added = router.add_tree_rule(
        b"*.sozu.io",
        &PathRule::Prefix("/api".to_string()),
        &get_rule(),
        &Route::ClusterId("api".to_string()),
    );
    assert!(api_added);
    println!("{:#?}", router.tree);

    assert_eq!(
        router.lookup("www.sozu.io", "/ap", &Method::Get),
        Ok(forward_to("base"))
    );
    assert_eq!(
        router.lookup("www.sozu.io", "/api", &Method::Get),
        Ok(forward_to("api"))
    );
}
2881
#[test]
/// An exact hostname (`api.sozu.io`) can be added next to an existing wildcard
/// (`*.sozu.io`). Requests for the exact host go to its own cluster, and other
/// hosts under the wildcard keep reaching `base`.
fn multiple_children_including_one_with_wildcard() {
    let get_rule = || MethodRule::new(Some("GET".to_string()));
    let forward_to = |cluster: &str| RouteResult::forward(cluster.to_string());
    let mut router = Router::new();

    let wildcard_added = router.add_tree_rule(
        b"*.sozu.io",
        &PathRule::Prefix("".to_string()),
        &get_rule(),
        &Route::ClusterId("base".to_string()),
    );
    assert!(wildcard_added);
    println!("{:#?}", router.tree);
    assert_eq!(
        router.lookup("www.sozu.io", "/api", &Method::Get),
        Ok(forward_to("base"))
    );

    let exact_added = router.add_tree_rule(
        b"api.sozu.io",
        &PathRule::Prefix("".to_string()),
        &get_rule(),
        &Route::ClusterId("api".to_string()),
    );
    assert!(exact_added);
    println!("{:#?}", router.tree);

    assert_eq!(
        router.lookup("www.sozu.io", "/api", &Method::Get),
        Ok(forward_to("base"))
    );
    assert_eq!(
        router.lookup("api.sozu.io", "/api", &Method::Get),
        Ok(forward_to("api"))
    );
}
2920
#[test]
/// Two hostname patterns containing a `/.*/` regex segment are inserted. Both
/// must resolve. Removing `www./.*/.io` must make `www.sozu.io` unroutable while
/// leaving `www.doc./.*/.io` intact.
fn router_insert_remove_through_regex() {
    let get_rule = || MethodRule::new(Some("GET".to_string()));
    let catch_all = || PathRule::Prefix("".to_string());
    let mut router = Router::new();

    let patterns: [(&[u8], &str); 2] = [(b"www./.*/.io", "base"), (b"www.doc./.*/.io", "doc")];
    for (hostname, cluster) in patterns {
        assert!(router.add_tree_rule(
            hostname,
            &catch_all(),
            &get_rule(),
            &Route::ClusterId(cluster.to_string()),
        ));
        println!("{:#?}", router.tree);
    }

    assert_eq!(
        router.lookup("www.sozu.io", "/", &Method::Get),
        Ok(RouteResult::forward("base".to_string()))
    );
    assert_eq!(
        router.lookup("www.doc.sozu.io", "/", &Method::Get),
        Ok(RouteResult::forward("doc".to_string()))
    );

    let removed = router.remove_tree_rule(b"www./.*/.io", &catch_all(), &get_rule());
    assert!(removed);
    println!("{:#?}", router.tree);

    assert!(router.lookup("www.sozu.io", "/", &Method::Get).is_err());
    assert_eq!(
        router.lookup("www.doc.sozu.io", "/", &Method::Get),
        Ok(RouteResult::forward("doc".to_string()))
    );
}
2959
#[test]
/// End-to-end routing over the different rule kinds:
/// - A `*` pre rule for the ACME challenge path wins over the tree rule for
///   `www.example.com`.
/// - An exact tree rule forwards other paths on `www.example.com`.
/// - A wildcard domain with a regex path rule only matches paths accepted by
///   the regex.
/// - A regex domain (`/test[0-9]/.example.com`) routes matching hosts.
fn match_router() {
    let get_rule = || MethodRule::new(Some("GET".to_string()));
    let get = Method::new(&b"GET"[..]);
    let mut router = Router::new();

    assert!(router.add_pre_rule(
        &"*".parse::<DomainRule>().unwrap(),
        &PathRule::Prefix("/.well-known/acme-challenge".to_string()),
        &get_rule(),
        &Route::ClusterId("acme".to_string()),
    ));

    let tree_rules = [
        (
            "www.example.com",
            PathRule::Prefix("/".to_string()),
            "example",
        ),
        (
            "*.test.example.com",
            PathRule::Regex(Regex::new("/hello[A-Z]+/").unwrap()),
            "examplewildcard",
        ),
        (
            "/test[0-9]/.example.com",
            PathRule::Prefix("/".to_string()),
            "exampleregex",
        ),
    ];
    for (hostname, path_rule, cluster) in tree_rules {
        assert!(router.add_tree_rule(
            hostname.as_bytes(),
            &path_rule,
            &get_rule(),
            &Route::ClusterId(cluster.to_string()),
        ));
    }

    assert_eq!(
        router.lookup("www.example.com", "/helloA", &get),
        Ok(RouteResult::forward("example".to_string()))
    );
    assert_eq!(
        router.lookup("www.example.com", "/.well-known/acme-challenge", &get),
        Ok(RouteResult::forward("acme".to_string()))
    );
    assert!(router.lookup("www.test.example.com", "/", &get).is_err());
    assert_eq!(
        router.lookup("www.test.example.com", "/helloAB/", &get),
        Ok(RouteResult::forward("examplewildcard".to_string()))
    );
    assert_eq!(
        router.lookup("test1.example.com", "/helloAB/", &get),
        Ok(RouteResult::forward("exampleregex".to_string()))
    );
}
3019
#[test]
/// `has_hostname` must see hostnames registered through the tree, the pre
/// rules and the post rules. It must stop reporting a hostname once its only
/// rule has been removed from any of those three places.
fn has_hostname_checks_tree_pre_and_post() {
    let mut router = Router::new();

    // An empty router knows no hostname.
    assert!(!router.has_hostname("www.example.com"));

    // Tree rule: visible after insertion, gone after removal.
    assert!(router.add_tree_rule(
        b"www.example.com",
        &PathRule::Prefix("/".to_string()),
        &MethodRule::new(Some("GET".to_string())),
        &Route::ClusterId("cluster1".to_string())
    ));
    assert!(router.has_hostname("www.example.com"));
    assert!(!router.has_hostname("api.example.com"));

    assert!(router.remove_tree_rule(
        b"www.example.com",
        &PathRule::Prefix("/".to_string()),
        &MethodRule::new(Some("GET".to_string()))
    ));
    assert!(!router.has_hostname("www.example.com"));

    // Pre rule: visible after insertion, without resurrecting the removed
    // tree hostname.
    assert!(router.add_pre_rule(
        &DomainRule::Exact("api.example.com".to_string()),
        &PathRule::Prefix("/".to_string()),
        &MethodRule::new(None),
        &Route::ClusterId("cluster2".to_string())
    ));
    assert!(router.has_hostname("api.example.com"));
    assert!(!router.has_hostname("www.example.com"));

    // Post rule: visible after insertion.
    assert!(router.add_post_rule(
        &DomainRule::Exact("cdn.example.com".to_string()),
        &PathRule::Prefix("/".to_string()),
        &MethodRule::new(None),
        &Route::ClusterId("cluster3".to_string())
    ));
    assert!(router.has_hostname("cdn.example.com"));

    // Removing the pre rule hides its hostname but leaves the post rule alone.
    assert!(router.remove_pre_rule(
        &DomainRule::Exact("api.example.com".to_string()),
        &PathRule::Prefix("/".to_string()),
        &MethodRule::new(None),
    ));
    assert!(!router.has_hostname("api.example.com"));
    assert!(router.has_hostname("cdn.example.com"));

    // Removing the post rule must hide its hostname too. Without this step
    // the post side of the check was only tested for insertion.
    assert!(router.remove_post_rule(
        &DomainRule::Exact("cdn.example.com".to_string()),
        &PathRule::Prefix("/".to_string()),
        &MethodRule::new(None),
    ));
    assert!(!router.has_hostname("cdn.example.com"));
}
3073
#[test]
/// A hostname with two path rules in the tree stays known while either rule
/// remains. It is only reported as unknown once the last one is removed.
fn has_hostname_false_after_last_route_removed() {
    let get_rule = || MethodRule::new(Some("GET".to_string()));
    let host = "www.example.com";
    let mut router = Router::new();

    for (prefix, cluster) in [("/", "cluster1"), ("/api", "cluster2")] {
        assert!(router.add_tree_rule(
            host.as_bytes(),
            &PathRule::Prefix(prefix.to_string()),
            &get_rule(),
            &Route::ClusterId(cluster.to_string()),
        ));
    }
    assert!(router.has_hostname(host));

    // One rule left: the hostname must still be reported.
    let removed_root = router.remove_tree_rule(
        host.as_bytes(),
        &PathRule::Prefix("/".to_string()),
        &get_rule(),
    );
    assert!(removed_root);
    assert!(router.has_hostname(host));

    // No rule left: the hostname must disappear.
    let removed_api = router.remove_tree_rule(
        host.as_bytes(),
        &PathRule::Prefix("/api".to_string()),
        &get_rule(),
    );
    assert!(removed_api);
    assert!(!router.has_hostname(host));
}
3109}