use std::{cell::RefCell, collections::HashMap, rc::Rc, time::Duration};
use mio::{Interest, Token, net::TcpStream};
use sozu_command::{
logging::ansi_palette,
proto::command::{ListenerType, RedirectPolicy, RedirectScheme},
};
#[cfg(debug_assertions)]
use super::DebugEvent;
use super::{BackendStatus, Connection, Context, GlobalStreamId, Position, StreamState};
use crate::{
BackendConnectionError, L7ListenerHandler, L7Proxy, ListenerHandler, ProxySession, Readiness,
RetrieveClusterError,
backends::{Backend, BackendError},
protocol::http::editor::{HeaderEditMode, HeaderEditSnapshot, HttpContext},
router::{HeaderEdit, RouteResult},
server::CONN_RETRIES,
socket::SessionTcpStream,
timer::TimeoutContainer,
};
use crate::metrics::names;
// Builds the colored log prefix used by every message emitted from this module.
//
// * `log_module_context!()` yields the bare `MUX-ROUTER` tag.
// * `log_module_context!(ctx)` accepts any expression whose reference coerces to
//   `&HttpContext`; it prepends the context's `log_context()` and appends the frontend
//   address, method and authority of the request.
macro_rules! log_module_context {
    () => {{
        let (open, reset, _, _, _) = ansi_palette();
        format!("{open}MUX-ROUTER{reset}\t >>>")
    }};
    ($http_context:expr) => {{
        let (open, reset, grey, gray, white) = ansi_palette();
        // Normalize `&mut HttpContext`, `&HttpContext`, plain fields, ... to a shared
        // reference so the fields below can be borrowed uniformly.
        let http_ctx: &HttpContext = &$http_context;
        let ctx = http_ctx.log_context();
        // Borrowed (not moved) so non-`Copy` fields format exactly as before.
        let frontend = &http_ctx.session_address;
        let method = &http_ctx.method;
        let authority = &http_ctx.authority;
        format!(
            "{gray}{ctx}{reset}\t{open}MUX-ROUTER{reset}\t{grey}Session{reset}({gray}frontend{reset}={white}{frontend:?}{reset}, {gray}method{reset}={white}{method:?}{reset}, {gray}authority{reset}={white}{authority:?}{reset})\t >>>"
        )
    }};
}
/// Backend side of a mux session: owns the backend connections opened on behalf of the
/// frontend streams and links each stream to one of them (see `Router::connect`).
#[derive(Debug)]
pub struct Router {
    /// Backend connections, keyed by the session token their socket was registered with.
    pub backends: HashMap<Token, Connection<SessionTcpStream>>,
    /// Backend timeout as configured; not read in this file.
    pub configured_backend_timeout: Duration,
    /// Timeout armed on a freshly opened backend connection while it connects.
    pub configured_connect_timeout: Duration,
    // Initialized to `Readiness::new()`; not read in this file.
    // NOTE(review): presumably used by callers when no backend readiness applies — confirm.
    pub(super) fallback_readiness: Readiness,
}
impl Router {
    /// Creates a router with no backend connections and the given timeouts.
    pub fn new(configured_backend_timeout: Duration, configured_connect_timeout: Duration) -> Self {
        Self {
            backends: HashMap::new(),
            configured_backend_timeout,
            configured_connect_timeout,
            fallback_readiness: Readiness::new(),
        }
    }

    /// Links the frontend stream `stream_id` to a backend connection.
    ///
    /// The stream must be in [`StreamState::Link`]. The request is first routed to a
    /// cluster (`route_from_request`), then the per-(cluster, client IP) limit is
    /// enforced. After that, either an existing backend connection of that cluster is
    /// reused, or a new one is opened, registered on the event loop under a fresh session
    /// token and stored in `self.backends`. Every call counts as one attempt for the
    /// stream.
    ///
    /// # Errors
    ///
    /// * `MaxSessionsMemory`: the stream is not in `Link` state, a backend refuses the
    ///   stream, a reused token is missing, or the new socket cannot be registered.
    /// * `MaxConnectionRetries`: the stream already used `CONN_RETRIES` attempts.
    /// * `RetrieveClusterError(..)`: routing failed or a redirect/401 must be answered,
    ///   including the cluster's HTTPS redirect on a plain HTTP listener.
    /// * `TooManyConnectionsPerIp`: the client IP is at the cluster's limit
    ///   (`retry_after_seconds` is set on the stream context).
    /// * `MaxBuffers`: an H2 client connection could not be created.
    /// * `Backend(..)`: no backend could be selected or connected.
    pub(super) fn connect<L: ListenerHandler + L7ListenerHandler>(
        &mut self,
        stream_id: GlobalStreamId,
        context: &mut Context<L>,
        session: Rc<RefCell<dyn ProxySession>>,
        proxy: Rc<RefCell<dyn L7Proxy>>,
        frontend_token: Token,
    ) -> Result<(), BackendConnectionError> {
        let stream = &mut context.streams[stream_id];
        if !matches!(stream.state, StreamState::Link) {
            error!(
                "{} stream {} expected to be in Link state, got {:?}",
                log_module_context!(stream.context),
                stream_id,
                stream.state
            );
            return Err(BackendConnectionError::MaxSessionsMemory);
        }
        #[cfg(debug_assertions)]
        context
            .debug
            .push(DebugEvent::Str(stream.context.get_route()));
        // Every call is one connection attempt; give up once CONN_RETRIES is reached.
        if stream.attempts >= CONN_RETRIES {
            incr!(
                "backend.connect.retries_exhausted",
                stream.context.cluster_id.as_deref(),
                stream.context.backend_id.as_deref()
            );
            return Err(BackendConnectionError::MaxConnectionRetries(
                stream.context.cluster_id.clone(),
            ));
        }
        stream.attempts += 1;
        // Split the stream into two disjoint mutable borrows so routing can edit both the
        // request (`front`) and its context in the same call.
        let (front_ref, stream_context_ref) = {
            let stream_split = &mut *stream;
            (&mut stream_split.front, &mut stream_split.context)
        };
        let cluster_id = self
            .route_from_request(stream_context_ref, front_ref, &context.listener, &proxy)
            .map_err(BackendConnectionError::RetrieveClusterError)?;
        let stream_context = &mut stream.context;
        stream_context.cluster_id = Some(cluster_id.to_owned());
        // Cluster settings; an unknown cluster falls back to "no stickiness, no redirect,
        // HTTP/1, no per-IP limit".
        let (
            frontend_should_stick,
            frontend_should_redirect_https,
            h2,
            cluster_max_connections_per_ip,
            cluster_retry_after,
        ) = proxy
            .borrow()
            .clusters()
            .get(&cluster_id)
            .map(|cluster| {
                (
                    cluster.sticky_session,
                    cluster.https_redirect,
                    cluster.http2.unwrap_or(false),
                    cluster.max_connections_per_ip,
                    cluster.retry_after,
                )
            })
            .unwrap_or((false, false, false, None, None));
        if frontend_should_redirect_https && matches!(proxy.borrow().kind(), ListenerType::Http) {
            return Err(BackendConnectionError::RetrieveClusterError(
                RetrieveClusterError::HttpsRedirect,
            ));
        }
        // Per-(cluster, client IP) connection limit. The IP is only tracked for this
        // frontend once the limit check has passed.
        let session_ip = stream_context.session_address.map(|sa| sa.ip());
        if let Some(ip) = session_ip {
            let sessions_rc = proxy.borrow().sessions();
            let at_limit = sessions_rc.borrow().cluster_ip_at_limit(
                frontend_token,
                &cluster_id,
                &ip,
                cluster_max_connections_per_ip,
            );
            if at_limit {
                let retry_after = sessions_rc
                    .borrow()
                    .effective_retry_after(cluster_retry_after);
                // `None` when the effective delay is 0.
                stream_context.retry_after_seconds = Some(retry_after).filter(|v| *v > 0);
                return Err(BackendConnectionError::TooManyConnectionsPerIp {
                    cluster_id: cluster_id.to_owned(),
                });
            }
            sessions_rc
                .borrow_mut()
                .track_cluster_ip(frontend_token, cluster_id.clone(), ip);
        }
        // Look for a reusable connection to the same cluster:
        // * H2: the connected, non-draining H2 connection with the fewest streams that is
        //   still below the peer's SETTINGS_MAX_CONCURRENT_STREAMS; an H2 connection that
        //   is still connecting is only used when no connected one qualified.
        // * H1: the first keep-alive connection (the scan stops there).
        let mut reuse_token = None;
        let mut best_h2_stream_count = usize::MAX;
        for (token, backend) in &self.backends {
            match (h2, backend.position()) {
                (_, Position::Server) => {
                    error!(
                        "{} Backend connection unexpectedly behaves like a server",
                        log_module_context!(stream_context)
                    );
                    continue;
                }
                (_, Position::Client(_, _, BackendStatus::Disconnecting)) => {}
                (true, Position::Client(other_cluster_id, _, BackendStatus::Connected)) => {
                    if *other_cluster_id == cluster_id && !backend.is_draining() {
                        let Connection::H2(h2c) = backend else {
                            continue;
                        };
                        let stream_count = h2c.streams.len();
                        if stream_count
                            >= h2c.peer_settings.settings_max_concurrent_streams as usize
                        {
                            continue;
                        }
                        if stream_count < best_h2_stream_count {
                            best_h2_stream_count = stream_count;
                            reuse_token = Some(*token);
                        }
                    }
                }
                (true, Position::Client(other_cluster_id, _, BackendStatus::Connecting(_))) => {
                    // Fallback only: a connected candidate always takes precedence.
                    if *other_cluster_id == cluster_id
                        && best_h2_stream_count == usize::MAX
                        && matches!(backend, Connection::H2(_))
                    {
                        reuse_token = Some(*token)
                    }
                }
                (true, Position::Client(other_cluster_id, _, BackendStatus::KeepAlive)) => {
                    if *other_cluster_id == cluster_id && matches!(backend, Connection::H2(_)) {
                        error!(
                            "{} ConnectionH2 unexpectedly behaves like H1 with KeepAlive",
                            log_module_context!(stream_context)
                        );
                    }
                }
                (false, Position::Client(old_cluster_id, _, BackendStatus::KeepAlive)) => {
                    if *old_cluster_id == cluster_id {
                        reuse_token = Some(*token);
                        break;
                    }
                }
                (false, Position::Client(_, _, BackendStatus::Connected))
                | (false, Position::Client(_, _, BackendStatus::Connecting(_))) => {}
            }
        }
        trace!(
            "{} connect: (stick={}, h2={}) -> (reuse={:?})",
            log_module_context!(stream_context),
            frontend_should_stick,
            h2,
            reuse_token
        );
        if let Some(token) = reuse_token {
            incr!(names::backend::POOL_HIT);
            trace!(
                "{} reused backend: {:#?}",
                log_module_context!(stream_context),
                self.backends.get(&token)
            );
            let Some(backend_conn) = self.backends.get_mut(&token) else {
                error!(
                    "{} reused backend token {:?} missing from backends map",
                    log_module_context!(stream_context),
                    token
                );
                return Err(BackendConnectionError::MaxSessionsMemory);
            };
            if !backend_conn.start_stream(stream_id, context) {
                error!(
                    "{} Backend rejected stream start (max concurrent streams reached)",
                    log_module_context!(context.http_context(stream_id))
                );
                return Err(BackendConnectionError::MaxSessionsMemory);
            }
            // Copy the reused backend's identity into the stream context and metrics.
            if let Some(backend_conn) = self.backends.get(&token) {
                if let Position::Client(_, backend_ref, _) = backend_conn.position() {
                    let backend = backend_ref.borrow();
                    let stream = &mut context.streams[stream_id];
                    stream.context.backend_id = Some(backend.backend_id.to_owned());
                    stream.context.backend_address = Some(backend.address);
                    stream.metrics.backend_id = Some(backend.backend_id.to_owned());
                    stream.metrics.backend_start();
                    stream.metrics.backend_connected();
                }
            }
            context.link_stream(stream_id, token);
            return Ok(());
        }
        incr!(names::backend::POOL_MISS);
        // No reusable connection: open, configure and register a new one.
        let token = {
            let (socket, backend) = self.backend_from_request(
                &cluster_id,
                frontend_should_stick,
                stream_context,
                proxy.clone(),
                &context.listener,
            )?;
            if let Err(e) = socket.set_nodelay(true) {
                error!(
                    "{} error setting nodelay on back socket({:?}): {:?}",
                    log_module_context!(context.http_context(stream_id)),
                    socket,
                    e
                );
            }
            let backend_peer = Some(backend.borrow().address);
            let socket = SessionTcpStream::new(socket, context.session_ulid, backend_peer);
            let timeout_container = TimeoutContainer::new_empty(self.configured_connect_timeout);
            let flood_config = context.listener.borrow().get_h2_flood_config();
            let connection_config = context.listener.borrow().get_h2_connection_config();
            let stream_idle_timeout = context.listener.borrow().get_h2_stream_idle_timeout();
            let graceful_shutdown_deadline = context
                .listener
                .borrow()
                .get_h2_graceful_shutdown_deadline();
            // Captured now because `backend` is moved into the connection below.
            let backend_id_for_gauge = backend.borrow().backend_id.to_owned();
            let mut connection = if h2 {
                match Connection::new_h2_client(
                    context.session_ulid,
                    socket,
                    cluster_id.to_owned(),
                    backend,
                    context.pool.clone(),
                    timeout_container,
                    flood_config,
                    connection_config,
                    stream_idle_timeout,
                    graceful_shutdown_deadline,
                ) {
                    Some(connection) => connection,
                    None => return Err(BackendConnectionError::MaxBuffers),
                }
            } else {
                Connection::new_h1_client(
                    context.session_ulid,
                    socket,
                    cluster_id.to_owned(),
                    backend,
                    timeout_container,
                )
            };
            // NOTE(review): on the early returns below `connection` (and the backend
            // handle moved into it) is simply dropped; if selecting the backend bumped a
            // per-backend connection counter it is not rolled back here — verify.
            if !connection.start_stream(stream_id, context) {
                error!(
                    "{} Backend rejected stream start (max concurrent streams reached)",
                    log_module_context!(context.http_context(stream_id))
                );
                return Err(BackendConnectionError::MaxSessionsMemory);
            }
            let stream = &mut context.streams[stream_id];
            stream.metrics.backend_start();
            stream.metrics.backend_id = stream.context.backend_id.to_owned();
            gauge_add!(names::backend::CONNECTIONS, 1);
            gauge_add!(names::backend::POOL_SIZE, 1);
            gauge_add!(
                names::backend::CONNECTIONS_PER_BACKEND,
                1,
                Some(&cluster_id),
                Some(&backend_id_for_gauge)
            );
            let token = proxy.borrow().add_session(session);
            {
                let socket_ref = connection.socket_mut();
                if let Err(e) = proxy.borrow().register_socket(
                    socket_ref,
                    token,
                    Interest::READABLE | Interest::WRITABLE,
                ) {
                    error!(
                        "{} error registering back socket: {:?} — rolling back",
                        log_module_context!(context.http_context(stream_id)),
                        e
                    );
                    // Undo the gauges and the session added just above.
                    gauge_add!(names::backend::CONNECTIONS, -1);
                    gauge_add!(names::backend::POOL_SIZE, -1);
                    gauge_add!(
                        names::backend::CONNECTIONS_PER_BACKEND,
                        -1,
                        Some(&cluster_id),
                        Some(&backend_id_for_gauge)
                    );
                    proxy.borrow().remove_session(token);
                    return Err(BackendConnectionError::MaxSessionsMemory);
                }
            }
            connection.timeout_container().set(token);
            self.backends.insert(token, connection);
            token
        };
        context.link_stream(stream_id, token);
        Ok(())
    }

    /// Resolves the cluster that should serve the request held in `front`/`context`.
    ///
    /// Besides the frontend lookup, this method:
    /// * enforces strict SNI binding (when `context.strict_sni_binding` is set and an SNI
    ///   was received): the `:authority` must be covered by the certificate names, or,
    ///   when those are unknown, be equal to the SNI;
    /// * records the request authority in `context.original_authority`;
    /// * prepares redirects (`redirect_location`, `redirect_status`,
    ///   `frontend_redirect_template`);
    /// * enforces deny routes and basic authentication;
    /// * applies host/path rewrites and request header edits to `front`, and snapshots
    ///   the response header edits into `context.headers_response`.
    ///
    /// # Errors
    ///
    /// `HttpsRedirect` when a redirect must be answered, `UnauthorizedRoute` for denied,
    /// cluster-less or unauthenticated requests, `SniAuthorityMismatch` on a strict-SNI
    /// violation, plus any error from request extraction or frontend lookup.
    fn route_from_request<L: ListenerHandler + L7ListenerHandler>(
        &mut self,
        context: &mut HttpContext,
        front: &mut super::GenericHttpStream,
        listener: &Rc<RefCell<L>>,
        proxy: &Rc<RefCell<dyn L7Proxy>>,
    ) -> Result<String, RetrieveClusterError> {
        let (host, uri, method) = match context.extract_route() {
            Ok(tuple) => tuple,
            Err(cluster_error) => {
                error!(
                    "{} Malformed request in connect (should be caught at parsing) {:?}: {}",
                    log_module_context!(context),
                    context,
                    cluster_error
                );
                return Err(cluster_error);
            }
        };
        let captured_authority = host.to_owned();
        if let Some(sni) = context
            .tls_server_name
            .as_deref()
            .filter(|_| context.strict_sni_binding)
        {
            let matched: Option<&str> = match context.tls_cert_names.as_deref() {
                Some(cert_names) => authority_matched_cert_name(host, cert_names),
                None => {
                    if authority_matches_sni(host, sni) {
                        Some(sni)
                    } else {
                        None
                    }
                }
            };
            match matched {
                Some(matched_name) => {
                    // Covered by the certificate but not equal to the SNI: on h2 this is
                    // connection coalescing, counted and logged.
                    if !authority_matches_sni(host, sni) && context.tls_alpn == Some("h2") {
                        incr!(names::h2::COALESCING_ACCEPTED);
                        debug!(
                            "{} accepted coalesced authority {:?} (SNI {:?}, matched SAN {:?})",
                            log_module_context!(context),
                            host,
                            sni,
                            matched_name,
                        );
                    }
                }
                None => {
                    incr!(names::http::SNI_AUTHORITY_MISMATCH);
                    warn!(
                        "{} rejecting request: TLS cert SANs do not cover :authority {:?} (SNI {:?})",
                        log_module_context!(context),
                        host,
                        sni,
                    );
                    return Err(RetrieveClusterError::SniAuthorityMismatch {
                        sni: sni.to_owned(),
                        authority: host.to_owned(),
                    });
                }
            }
        }
        let route_result = listener.borrow().frontend_from_request(host, uri, method);
        let route = match route_result {
            Ok(route) => route,
            Err(frontend_error) => {
                trace!("{} {}", log_module_context!(context), frontend_error);
                return Err(RetrieveClusterError::RetrieveFrontend(frontend_error));
            }
        };
        context.original_authority = Some(captured_authority);
        let RouteResult {
            cluster_id,
            redirect,
            redirect_scheme,
            redirect_template,
            rewritten_host,
            rewritten_path,
            rewritten_port,
            headers_request,
            headers_response,
            required_auth: frontend_required_auth,
            ..
        } = route;
        // On HTTPS, snapshot the `Set`/`SetIfAbsent` response edits now so that the early
        // returns below (redirect, 401) leave them in the context; the complete list
        // replaces them once routing succeeds.
        if matches!(context.protocol, crate::Protocol::HTTPS) {
            snapshot_response_edits(&mut context.headers_response, &headers_response, |e| {
                matches!(e.mode, HeaderEditMode::SetIfAbsent | HeaderEditMode::Set)
            });
        }
        let (legacy_https_redirect, https_redirect_port, authorized_hashes, www_authenticate) =
            match cluster_id.as_deref() {
                Some(id) => proxy
                    .borrow()
                    .clusters()
                    .get(id)
                    .map(|c| {
                        (
                            c.https_redirect,
                            c.https_redirect_port,
                            c.authorized_hashes.clone(),
                            c.www_authenticate.clone(),
                        )
                    })
                    .unwrap_or((false, None, Vec::new(), None)),
                None => (false, None, Vec::new(), None),
            };
        let redirect_status = match redirect {
            RedirectPolicy::Permanent => Some(301u16),
            RedirectPolicy::Found => Some(302u16),
            RedirectPolicy::PermanentRedirect => Some(308u16),
            RedirectPolicy::Forward | RedirectPolicy::Unauthorized => None,
        };
        if let Some(status_code) = redirect_status {
            let scheme = resolve_redirect_scheme(redirect_scheme, context);
            // A frontend port rewrite wins over the cluster's HTTPS redirect port.
            let port = rewritten_port.map(|p| p as u32).or(https_redirect_port);
            context.redirect_location = Some(build_redirect_location(
                scheme,
                context,
                port,
                rewritten_host.as_deref(),
                rewritten_path.as_deref(),
            ));
            context.frontend_redirect_template = redirect_template;
            context.redirect_status = Some(status_code);
            return Err(RetrieveClusterError::HttpsRedirect);
        }
        // Deny routes, and routes without a cluster, are answered with a 401.
        let cluster_id = match cluster_id {
            Some(cluster_id) if !matches!(redirect, RedirectPolicy::Unauthorized) => cluster_id,
            _ => {
                context.www_authenticate = www_authenticate;
                trace!("{} RouteResult::deny", log_module_context!(context));
                return Err(RetrieveClusterError::UnauthorizedRoute);
            }
        };
        // Cluster-level HTTPS redirect on a plain HTTP listener: only the location is
        // prepared here; `connect` emits the `HttpsRedirect` error.
        if legacy_https_redirect && matches!(proxy.borrow().kind(), ListenerType::Http) {
            let port = https_redirect_port;
            context.redirect_location =
                Some(build_redirect_location("https", context, port, None, None));
        }
        if frontend_required_auth
            && !crate::protocol::mux::auth::check_basic(front, &authorized_hashes)
        {
            context.www_authenticate = www_authenticate;
            trace!(
                "{} basic-auth check failed; emitting 401",
                log_module_context!(context)
            );
            return Err(RetrieveClusterError::UnauthorizedRoute);
        }
        apply_request_rewrites_and_headers(
            front,
            context,
            rewritten_host.as_deref(),
            rewritten_path.as_deref(),
            &headers_request,
        );
        snapshot_response_edits(&mut context.headers_response, &headers_response, |_| true);
        Ok(cluster_id)
    }

    /// Selects a backend of `cluster_id` and returns the socket obtained for it.
    ///
    /// When `frontend_should_stick` is set, the request's sticky session (if any) drives
    /// the selection, and `context.sticky_name` / `context.sticky_session` are filled in.
    /// The sticky session is the backend's `sticky_id`, or its `backend_id` when it has
    /// none. `context.backend_id` and `context.backend_address` are always set on success.
    ///
    /// # Errors
    ///
    /// `BackendConnectionError::Backend` wrapping the backend map's error.
    pub fn backend_from_request<L: ListenerHandler + L7ListenerHandler>(
        &mut self,
        cluster_id: &str,
        frontend_should_stick: bool,
        context: &mut HttpContext,
        proxy: Rc<RefCell<dyn L7Proxy>>,
        listener: &Rc<RefCell<L>>,
    ) -> Result<(TcpStream, Rc<RefCell<Backend>>), BackendConnectionError> {
        let (backend, conn) = self
            .get_backend_for_sticky_session(
                cluster_id,
                frontend_should_stick,
                context.sticky_session_found.as_deref(),
                proxy,
            )
            .map_err(|backend_error| {
                trace!("{} {}", log_module_context!(context), backend_error);
                BackendConnectionError::Backend(backend_error)
            })?;
        if frontend_should_stick {
            context.sticky_name = listener.borrow().get_sticky_name().to_string();
            context.sticky_session = Some(
                backend
                    .borrow()
                    .sticky_id
                    .clone()
                    .unwrap_or_else(|| backend.borrow().backend_id.to_owned()),
            );
        }
        context.backend_id = Some(backend.borrow().backend_id.to_owned());
        context.backend_address = Some(backend.borrow().address);
        Ok((conn, backend))
    }

    /// Asks the proxy's backend map for a backend and socket.
    ///
    /// Uses `backend_from_sticky_session` when the cluster is sticky and the request
    /// carried a sticky session, and `backend_from_cluster_id` otherwise.
    fn get_backend_for_sticky_session(
        &self,
        cluster_id: &str,
        frontend_should_stick: bool,
        sticky_session: Option<&str>,
        proxy: Rc<RefCell<dyn L7Proxy>>,
    ) -> Result<(Rc<RefCell<Backend>>, TcpStream), BackendError> {
        match (frontend_should_stick, sticky_session) {
            (true, Some(sticky_session)) => proxy
                .borrow()
                .backends()
                .borrow_mut()
                .backend_from_sticky_session(cluster_id, sticky_session),
            _ => proxy
                .borrow()
                .backends()
                .borrow_mut()
                .backend_from_cluster_id(cluster_id),
        }
    }
}
/// Applies a frontend's request rewrites and request header edits to `kawa` in place.
///
/// * `rewritten_host` replaces the status-line authority and the `Host` header. It also
///   adds `X-Forwarded-Host` carrying `context.original_authority`, and drops any
///   client-sent `X-Forwarded-Host`.
/// * `rewritten_path` replaces both the status-line path and URI.
/// * Each entry of `headers_request` with an empty value removes every header of that
///   name (ASCII case-insensitive). Any other entry is inserted as a new header at the
///   end of the header section. Setting `Host` / `X-Forwarded-Host` this way also
///   removes the client's copy.
///
/// If the operator's edits set their own non-empty `Host` / `X-Forwarded-Host` while
/// the host is being rewritten, the operator value wins. The synthesized header is then
/// not added, so the request never carries two `Host` headers (RFC 9112 §3.2 requires
/// servers to reject that).
fn apply_request_rewrites_and_headers(
    kawa: &mut super::GenericHttpStream,
    context: &mut HttpContext,
    rewritten_host: Option<&str>,
    rewritten_path: Option<&str>,
    headers_request: &[HeaderEdit],
) {
    use kawa::{Block, Pair, Store};
    if rewritten_host.is_none() && rewritten_path.is_none() && headers_request.is_empty() {
        return;
    }
    // Only needed to synthesize X-Forwarded-Host when the host is rewritten.
    let original_authority: Option<String> = if rewritten_host.is_some() {
        context.original_authority.clone()
    } else {
        None
    };
    if rewritten_host.is_some() || rewritten_path.is_some() {
        if let kawa::StatusLine::Request {
            authority,
            path,
            uri,
            ..
        } = &mut kawa.detached.status_line
        {
            if let Some(new_host) = rewritten_host {
                *authority = Store::from_string(new_host.to_owned());
            }
            if let Some(new_path) = rewritten_path {
                *path = Store::from_string(new_path.to_owned());
                *uri = Store::from_string(new_path.to_owned());
            }
        }
    }
    let host_lower = b"host";
    let xfh_lower = b"x-forwarded-host";
    let rewriting_host = rewritten_host.is_some();
    // Lowercased names of headers to remove from the request.
    let mut keys_to_drop: Vec<Vec<u8>> = Vec::with_capacity(headers_request.len() + 2);
    let mut to_insert: Vec<Block> = Vec::with_capacity(headers_request.len() + 2);
    // The operator edits Host / X-Forwarded-Host at all (set or delete)...
    let mut operator_overrides_host = false;
    let mut operator_overrides_xfh = false;
    // ...and, more specifically, supplies a non-empty value for it.
    let mut operator_sets_host = false;
    let mut operator_sets_xfh = false;
    for edit in headers_request {
        let key_is_host = edit.key.eq_ignore_ascii_case(host_lower);
        let key_is_xfh = edit.key.eq_ignore_ascii_case(xfh_lower);
        operator_overrides_host |= key_is_host;
        operator_overrides_xfh |= key_is_xfh;
        if edit.val.is_empty() {
            keys_to_drop.push(edit.key.iter().map(u8::to_ascii_lowercase).collect());
        } else {
            operator_sets_host |= key_is_host;
            operator_sets_xfh |= key_is_xfh;
            to_insert.push(Block::Header(Pair {
                key: Store::from_slice(&edit.key),
                val: Store::from_slice(&edit.val),
            }));
        }
    }
    if rewriting_host || operator_overrides_host {
        keys_to_drop.push(host_lower.to_vec());
    }
    if rewriting_host || operator_overrides_xfh {
        keys_to_drop.push(xfh_lower.to_vec());
    }
    let buf = kawa.storage.buffer();
    if !keys_to_drop.is_empty() {
        // Headers whose key was elided (`Store::Empty`) are always kept. `keys_to_drop`
        // is lowercase, so a case-insensitive compare avoids lowering each key.
        kawa.blocks.retain(|block| match block {
            Block::Header(Pair { key, .. }) if !matches!(key, Store::Empty) => {
                let key_bytes = key.data(buf);
                !keys_to_drop
                    .iter()
                    .any(|k| k.eq_ignore_ascii_case(key_bytes))
            }
            _ => true,
        });
    }
    let end_header_idx = super::shared::end_of_headers_index(kawa);
    if rewriting_host {
        // Synthesized headers go first, followed by the operator's edits. An explicit
        // operator value suppresses the synthesized one to avoid a duplicate header.
        let mut synth: Vec<Block> = Vec::with_capacity(2 + to_insert.len());
        if let Some(new_host) = rewritten_host.filter(|_| !operator_sets_host) {
            synth.push(Block::Header(Pair {
                key: Store::Static(b"Host"),
                val: Store::from_string(new_host.to_owned()),
            }));
        }
        if let Some(orig) = original_authority.as_deref().filter(|_| !operator_sets_xfh) {
            synth.push(Block::Header(Pair {
                key: Store::Static(b"X-Forwarded-Host"),
                val: Store::from_string(orig.to_owned()),
            }));
        }
        synth.append(&mut to_insert);
        to_insert = synth;
    }
    if !to_insert.is_empty() {
        // Insert before the end-of-headers marker (or at the end if none is found).
        let insert_at = end_header_idx.unwrap_or(kawa.blocks.len());
        for (offset, block) in to_insert.into_iter().enumerate() {
            kawa.blocks.insert(insert_at + offset, block);
        }
    }
}
/// Replaces the contents of `target` with owned copies of the edits in `src` accepted
/// by `filter`, keeping their original order.
fn snapshot_response_edits<F>(target: &mut Vec<HeaderEditSnapshot>, src: &[HeaderEdit], filter: F)
where
    F: Fn(&HeaderEdit) -> bool,
{
    target.clear();
    let snapshots = src
        .iter()
        .filter(|candidate| filter(candidate))
        .map(|kept| HeaderEditSnapshot {
            key: kept.key.to_vec(),
            val: kept.val.to_vec(),
            mode: kept.mode,
        });
    target.extend(snapshots);
}
fn resolve_redirect_scheme(scheme: RedirectScheme, context: &HttpContext) -> &'static str {
match scheme {
RedirectScheme::UseHttps => "https",
RedirectScheme::UseHttp => "http",
RedirectScheme::UseSame => {
if context.tls_server_name.is_some() {
"https"
} else {
"http"
}
}
}
}
fn build_redirect_location(
scheme: &str,
context: &HttpContext,
port: Option<u32>,
host_override: Option<&str>,
path_override: Option<&str>,
) -> String {
let authority = host_override
.or(context.authority.as_deref())
.unwrap_or_default();
let path = path_override.or(context.path.as_deref()).unwrap_or("/");
let host_only = match authority.rsplit_once(':') {
Some((host, port_part))
if !port_part.is_empty() && port_part.bytes().all(|b| b.is_ascii_digit()) =>
{
host
}
_ => authority,
};
let port_suffix = match port {
Some(80) if scheme == "http" => String::new(),
Some(443) if scheme == "https" => String::new(),
Some(p) => format!(":{p}"),
None => String::new(),
};
format!("{scheme}://{host_only}{port_suffix}{path}")
}
/// Returns `true` when the host part of `authority` names the same host as the TLS SNI.
///
/// Behavior:
/// * A numeric `:port` suffix on the authority is ignored.
/// * A single trailing dot (absolute-form FQDN) is ignored on either side, consistent
///   with [`authority_matched_cert_name`].
/// * The comparison is ASCII case-insensitive. The SNI is expected to be lowercase
///   already, but a mixed-case value is tolerated.
/// * Wildcards are never expanded, and an empty host never matches.
pub(crate) fn authority_matches_sni(authority: &str, sni_lowercased: &str) -> bool {
    let host = normalize_authority_host(authority);
    let sni = sni_lowercased.strip_suffix('.').unwrap_or(sni_lowercased);
    !host.is_empty() && host.eq_ignore_ascii_case(sni)
}

/// Removes a trailing `:port` (non-empty, all ASCII digits) from an authority.
/// Bracketed IPv6 literals without a port, such as `[::1]`, are left unchanged.
fn strip_authority_port(authority: &str) -> &str {
    match authority.rsplit_once(':') {
        Some((h, port)) if !port.is_empty() && port.bytes().all(|b| b.is_ascii_digit()) => h,
        _ => authority,
    }
}

/// Host part of an authority, without its port and without one trailing dot.
fn normalize_authority_host(authority: &str) -> &str {
    let host = strip_authority_port(authority);
    host.strip_suffix('.').unwrap_or(host)
}

/// Returns the first entry of `names` (certificate SAN/CN names) that covers the host
/// of `authority`, or `None`.
///
/// The host is normalized like in [`authority_matches_sni`] (port and one trailing dot
/// removed). Matching is ASCII case-insensitive. A wildcard is only honored as a whole
/// left-most label (`*.example.com`), covers exactly one non-empty label, and never
/// matches the bare suffix. Entries with a `*` anywhere else, and the degenerate `*.`,
/// are ignored.
pub(crate) fn authority_matched_cert_name<'a>(
    authority: &str,
    names: &'a [String],
) -> Option<&'a str> {
    let host = normalize_authority_host(authority);
    if host.is_empty() {
        return None;
    }
    names
        .iter()
        .map(String::as_str)
        .find(|entry| cert_name_covers(entry, host))
}

/// Whether a single certificate name covers an already-normalized, non-empty `host`.
fn cert_name_covers(entry: &str, host: &str) -> bool {
    match entry.strip_prefix("*.") {
        Some(suffix) => {
            // `*.` alone or a second wildcard is not a usable pattern.
            if suffix.is_empty() || suffix.contains('*') {
                return false;
            }
            match host.split_once('.') {
                Some((leftmost, rest)) => !leftmost.is_empty() && rest.eq_ignore_ascii_case(suffix),
                None => false,
            }
        }
        None => !entry.contains('*') && host.eq_ignore_ascii_case(entry),
    }
}
/// Unit tests for `authority_matches_sni`: exact and case-insensitive equality with the
/// SNI, port stripping (including bracketed IPv6), and rejection of look-alike hosts and
/// of wildcard SNI values.
#[cfg(test)]
mod tests {
    use super::authority_matches_sni;
    #[test]
    fn match_exact() {
        assert!(authority_matches_sni("example.com", "example.com"));
    }
    #[test]
    fn match_different_case() {
        assert!(authority_matches_sni("Example.COM", "example.com"));
    }
    #[test]
    fn match_authority_with_port() {
        assert!(authority_matches_sni("example.com:8443", "example.com"));
    }
    #[test]
    fn reject_different_host() {
        assert!(!authority_matches_sni(
            "tenant-b.example.com",
            "tenant-a.example.com"
        ));
    }
    // A prefix or an extension of the SNI must not be accepted as the same host.
    #[test]
    fn reject_substring_attack() {
        assert!(!authority_matches_sni("example.co", "example.com"));
        assert!(!authority_matches_sni("example.commons", "example.com"));
    }
    // The SNI is compared literally; a `*` in it is not a pattern.
    #[test]
    fn reject_wildcard_not_expanded() {
        assert!(!authority_matches_sni("foo.example.com", "*.example.com"));
    }
    // The colons inside `[::1]` must not be mistaken for a port separator.
    #[test]
    fn ipv6_bracketed_literal_with_port() {
        assert!(authority_matches_sni("[::1]:8443", "[::1]"));
    }
    #[test]
    fn ipv6_bracketed_without_port() {
        assert!(authority_matches_sni("[::1]", "[::1]"));
    }
}
/// Unit tests for `authority_matched_cert_name`: exact and left-most-label wildcard
/// matching of certificate names, case-insensitivity, port and trailing-dot handling,
/// and rejection of apex, multi-label, mid-name-wildcard and look-alike hosts.
#[cfg(test)]
mod authority_matched_cert_name_tests {
    use super::authority_matched_cert_name;
    #[test]
    fn cert_name_match_exact_single_san() {
        let names = vec!["example.com".to_owned()];
        assert_eq!(
            authority_matched_cert_name("example.com", &names),
            Some("example.com"),
        );
    }
    // The matched entry itself (the wildcard pattern) is returned.
    #[test]
    fn cert_name_match_wildcard_left_most() {
        let names = vec!["*.cleverapps.io".to_owned()];
        assert_eq!(
            authority_matched_cert_name("staging-3.cleverapps.io", &names),
            Some("*.cleverapps.io"),
        );
    }
    // A wildcard needs a label to replace; the bare suffix is not covered.
    #[test]
    fn cert_name_reject_wildcard_apex() {
        let names = vec!["*.example.com".to_owned()];
        assert_eq!(authority_matched_cert_name("example.com", &names), None);
    }
    // A wildcard covers exactly one label.
    #[test]
    fn cert_name_reject_wildcard_two_labels() {
        let names = vec!["*.example.com".to_owned()];
        assert_eq!(authority_matched_cert_name("a.b.example.com", &names), None,);
    }
    // Wildcards are only honored in the left-most label.
    #[test]
    fn cert_name_reject_wildcard_not_left_most() {
        let names = vec!["foo.*.example.com".to_owned()];
        assert_eq!(
            authority_matched_cert_name("foo.bar.example.com", &names),
            None,
        );
    }
    #[test]
    fn cert_name_match_case_insensitive() {
        let names = vec!["EXAMPLE.com".to_owned()];
        assert!(authority_matched_cert_name("Example.COM", &names).is_some());
    }
    #[test]
    fn cert_name_match_with_port() {
        let names = vec!["example.com".to_owned()];
        assert!(authority_matched_cert_name("example.com:8443", &names).is_some());
    }
    // Absolute-form FQDNs (trailing dot), with or without a port, match like the
    // relative form.
    #[test]
    fn cert_name_match_absolute_form_trailing_dot() {
        let names = vec!["example.com".to_owned()];
        assert!(authority_matched_cert_name("example.com.", &names).is_some());
        assert!(authority_matched_cert_name("example.com.:8443", &names).is_some());
        let wildcard = vec!["*.example.com".to_owned()];
        assert!(authority_matched_cert_name("foo.example.com.", &wildcard).is_some());
    }
    // IDNs are compared in their ASCII (A-label) form.
    #[test]
    fn cert_name_match_idn_a_label() {
        let names = vec!["xn--bcher-kva.example.com".to_owned()];
        assert!(authority_matched_cert_name("xn--bcher-kva.example.com", &names).is_some());
    }
    #[test]
    fn cert_name_reject_empty_names() {
        assert_eq!(authority_matched_cert_name("example.com", &[]), None);
    }
    #[test]
    fn cert_name_match_multi_san_one_hit() {
        let names = vec!["foo.com".to_owned(), "*.example.org".to_owned()];
        assert_eq!(
            authority_matched_cert_name("bar.example.org", &names),
            Some("*.example.org"),
        );
    }
    #[test]
    fn cert_name_reject_substring_attack() {
        let names = vec!["*.example.com".to_owned()];
        assert_eq!(authority_matched_cert_name("example.commons", &names), None,);
    }
    #[test]
    fn cert_name_ipv6_bracketed_literal_with_port() {
        let names = vec!["[::1]".to_owned()];
        assert!(authority_matched_cert_name("[::1]:8443", &names).is_some());
    }
}