use std::{
collections::{BTreeMap, HashMap, HashSet},
env, fmt,
fs::{File, create_dir_all, metadata},
io::{ErrorKind, Read},
net::SocketAddr,
ops::Range,
path::PathBuf,
};
use crate::{
ObjectKind,
certificate::split_certificate_chain,
logging::AccessLogFormat,
proto::command::{
ActivateListener, AddBackend, AddCertificate, CertificateAndKey, Cluster,
CustomHttpAnswers, Header, HeaderPosition, HealthCheckConfig, HstsConfig,
HttpListenerConfig, HttpsListenerConfig, ListenerType, LoadBalancingAlgorithms,
LoadBalancingParams, LoadMetric, MetricDetail, MetricsConfiguration, PathRule,
ProtobufAccessLogFormat, ProxyProtocolConfig, RedirectPolicy, RedirectScheme, Request,
RequestHttpFrontend, RequestTcpFrontend, RulePosition, ServerConfig, ServerMetricsConfig,
SocketAddress, TcpListenerConfig, TlsVersion, WorkerRequest, request::RequestType,
},
};
/// Default TLS cipher names. `ListenerBuilder::to_tls` falls back to this list for both
/// `cipher_list` and `cipher_suites` when the builder leaves them unset.
pub const DEFAULT_CIPHER_LIST: [&str; 9] = [
"TLS13_AES_256_GCM_SHA384",
"TLS13_AES_128_GCM_SHA256",
"TLS13_CHACHA20_POLY1305_SHA256",
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
];
/// Signature algorithms that `ListenerBuilder::to_tls` always places in the HTTPS listener
/// configuration; the builder exposes no field to override them.
pub const DEFAULT_SIGNATURE_ALGORITHMS: [&str; 9] = [
"ECDSA+SHA256",
"ECDSA+SHA384",
"ECDSA+SHA512",
"RSA+SHA256",
"RSA+SHA384",
"RSA+SHA512",
"RSA-PSS+SHA256",
"RSA-PSS+SHA384",
"RSA-PSS+SHA512",
];
/// Key-exchange groups used by `ListenerBuilder::to_tls` when `groups_list` is unset.
pub const DEFAULT_GROUPS_LIST: [&str; 4] = ["X25519MLKEM768", "x25519", "P-256", "P-384"];
/// ALPN protocols used by `ListenerBuilder::to_tls` when `alpn_protocols` is unset or empty.
pub const DEFAULT_ALPN_PROTOCOLS: [&str; 2] = ["h2", "http/1.1"];
// Listener timeout fallbacks: the `ListenerBuilder::to_*` methods use these when neither the
// builder nor the global `Config` passed to them supplies a value.
// NOTE(review): units are presumably seconds — confirm against the consumers.
pub const DEFAULT_FRONT_TIMEOUT: u32 = 60;
pub const DEFAULT_BACK_TIMEOUT: u32 = 30;
pub const DEFAULT_CONNECT_TIMEOUT: u32 = 3;
pub const DEFAULT_REQUEST_TIMEOUT: u32 = 10;
pub const DEFAULT_WORKER_TIMEOUT: u32 = 10;
/// Default value of `ListenerBuilder::sticky_name` (see `default_sticky_name`).
pub const DEFAULT_STICKY_NAME: &str = "SOZUBALANCEID";
pub const DEFAULT_ZOMBIE_CHECK_INTERVAL: u32 = 1_800;
pub const DEFAULT_ACCEPT_QUEUE_TIMEOUT: u32 = 60;
/// HSTS `max_age` (in seconds, i.e. 365 days) applied by `FileHstsConfig::to_proto` when HSTS
/// is enabled without an explicit `max_age`; also the threshold of its preload warning.
pub const DEFAULT_HSTS_MAX_AGE: u32 = 31_536_000;
pub const DEFAULT_EVICT_ON_QUEUE_FULL: bool = false;
pub const DEFAULT_WORKER_COUNT: u16 = 2;
pub const DEFAULT_WORKER_AUTOMATIC_RESTART: bool = true;
pub const DEFAULT_AUTOMATIC_STATE_SAVE: bool = false;
pub const DEFAULT_MIN_BUFFERS: u64 = 1;
pub const DEFAULT_MAX_BUFFERS: u64 = 1_000;
pub const DEFAULT_BUFFER_SIZE: u64 = 16_393;
// NOTE(review): presumably the `minimum` reported by `ConfigError::BufferSizeTooSmallForH2` —
// the check itself is not in this part of the file; confirm.
pub const H2_MIN_BUFFER_SIZE: u64 = 16_393;
pub const DEFAULT_MAX_CONNECTIONS: usize = 10_000;
pub const DEFAULT_COMMAND_BUFFER_SIZE: u64 = 1_000_000;
pub const DEFAULT_MAX_COMMAND_BUFFER_SIZE: u64 = 2_000_000;
pub const DEFAULT_DISABLE_CLUSTER_METRICS: bool = false;
pub const MAX_LOOP_ITERATIONS: usize = 100000;
/// Value of `send_tls13_tickets` used by `ListenerBuilder::to_tls` when the builder leaves it
/// unset.
pub const DEFAULT_SEND_TLS_13_TICKETS: u64 = 4;
pub const DEFAULT_LOG_TARGET: &str = "stdout";
pub const DEFAULT_MAX_CONNECTIONS_PER_IP: u64 = 0;
pub const DEFAULT_RETRY_AFTER: u32 = 60;
/// Names the option reported as conflicting in [`ConfigError::Incompatible`] (its `kind`
/// field, rendered with `{:?}` in the error message).
#[derive(Debug)]
pub enum IncompatibilityKind {
PublicAddress,
ProxyProtocol,
}
/// Describes what is absent when a [`ConfigError::Missing`] error is raised.
#[derive(Debug)]
pub enum MissingKind {
/// A named configuration field is missing (e.g. `"hostname"` for an HTTP frontend).
Field(String),
Protocol,
SavedState,
}
/// Errors produced while loading, validating or converting the configuration.
///
/// The user-facing text of every variant comes from its `#[error(...)]` attribute.
#[derive(thiserror::Error, Debug)]
pub enum ConfigError {
#[error("env path not found: {0}")]
Env(String),
/// A file could not be opened; carries the path and the underlying I/O error.
#[error("Could not open file {path_to_open}: {io_error}")]
FileOpen {
path_to_open: String,
io_error: std::io::Error,
},
/// A file was opened but could not be read; carries the path and the underlying I/O error.
#[error("Could not read file {path_to_read}: {io_error}")]
FileRead {
path_to_read: String,
io_error: std::io::Error,
},
#[error(
"the field {kind:?} of {object:?} with id or address {id} is incompatible with the rest of the options"
)]
Incompatible {
kind: IncompatibilityKind,
object: ObjectKind,
id: String,
},
/// A field that has no meaning for a TCP frontend was set; carries the field name.
#[error("Invalid '{0}' field for a TCP frontend")]
InvalidFrontendConfig(String),
#[error("invalid path {0:?}")]
InvalidPath(PathBuf),
#[error("listening address {0:?} is already used in the configuration")]
ListenerAddressAlreadyInUse(SocketAddr),
#[error("missing {0:?}")]
Missing(MissingKind),
#[error("could not get parent directory for file {0}")]
NoFileParent(String),
// NOTE(review): the carried `String` is not interpolated into the message below, so it is
// only visible through `Debug` — confirm whether that is intended.
#[error("Could not get the path of the saved state")]
SaveStatePath(String),
#[error("Can not determine path to sozu socket: {0}")]
SocketPathError(String),
#[error("toml decoding error: {0}")]
DeserializeToml(String),
#[error("Can not set this frontend on a {0:?} listener")]
WrongFrontendProtocol(ListenerProtocol),
/// A `ListenerBuilder` was asked for a listener config of a protocol other than its own.
#[error("Can not build a {expected:?} listener from a {found:?} config")]
WrongListenerProtocol {
expected: ListenerProtocol,
found: Option<ListenerProtocol>,
},
/// An ALPN entry other than `"h2"` or `"http/1.1"` was configured.
#[error("Invalid ALPN protocol '{0}'. Valid values: \"h2\", \"http/1.1\"")]
InvalidAlpnProtocol(String),
/// `disable_http11` is set while the effective ALPN list still contains `"http/1.1"`.
#[error(
"disable_http11 = true is incompatible with alpn_protocols containing \"http/1.1\" \
on listener {address}. The proxy would advertise http/1.1 then refuse every \
connection that negotiates it. Drop \"http/1.1\" from alpn_protocols or unset \
disable_http11."
)]
DisableHttp11WithHttp11Alpn { address: String },
#[error(
"buffer_size = {buffer_size} is below the H2 minimum of {minimum} but \
{listeners} HTTPS listener(s) advertise H2 ALPN. The H2 mux deadlocks \
on full-size frames with smaller buffers. Raise buffer_size to >= {minimum} \
or remove \"h2\" from those listeners' alpn_protocols."
)]
BufferSizeTooSmallForH2 {
buffer_size: u64,
minimum: u64,
listeners: usize,
},
#[error(
"invalid redirect policy '{0}'. Valid values: \"forward\", \"permanent\", \"unauthorized\""
)]
InvalidRedirectPolicy(String),
#[error(
"invalid redirect scheme '{0}'. Valid values: \"use-same\", \"use-http\", \"use-https\""
)]
InvalidRedirectScheme(String),
/// The `position` of the header edit at `headers[index]` is not a recognised value.
#[error(
"invalid header position '{position}' at headers[{index}]. Valid values: \"request\", \"response\", \"both\""
)]
InvalidHeaderPosition { index: usize, position: String },
/// The key or value (named by `field`) of the header edit at `headers[index]` was rejected.
#[error(
"invalid header bytes in {field} at headers[{index}]: control characters \
(NUL / CR / LF / other C0) are forbidden in header keys and values"
)]
InvalidHeaderBytes { index: usize, field: &'static str },
/// An HSTS block omitted `enabled`; carries the scope that was being converted.
#[error("invalid HSTS config at {0}: `enabled` is required when an [hsts] block is present")]
HstsEnabledRequired(String),
/// An HSTS block was attached to a plain-HTTP listener or frontend.
#[error(
"invalid HSTS config at {0}: HSTS is only valid on HTTPS listeners and frontends \
(RFC 6797 §7.2 forbids the header over plaintext HTTP)"
)]
HstsOnPlainHttp(String),
}
/// Description of a single listener, either deserialized from the configuration file or
/// assembled with the `with_*` methods, and converted into a protobuf listener config by
/// `to_http`, `to_tls` or `to_tcp`.
///
/// Unknown keys are rejected at deserialization time (`deny_unknown_fields`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ListenerBuilder {
pub address: SocketAddr,
/// Must match the `to_*` method being called, otherwise that method returns
/// `ConfigError::WrongListenerProtocol`.
pub protocol: Option<ListenerProtocol>,
pub public_address: Option<SocketAddr>,
// The `answer_*` fields are paths to files whose content is read when the listener config
// is built (see `read_http_answer_file`).
pub answer_301: Option<String>,
pub answer_400: Option<String>,
pub answer_401: Option<String>,
pub answer_404: Option<String>,
pub answer_408: Option<String>,
pub answer_413: Option<String>,
pub answer_421: Option<String>,
pub answer_502: Option<String>,
pub answer_503: Option<String>,
pub answer_504: Option<String>,
pub answer_507: Option<String>,
pub answer_429: Option<String>,
/// `to_tls` uses TLS 1.2 and TLS 1.3 when this is `None`.
pub tls_versions: Option<Vec<TlsVersion>>,
pub cipher_list: Option<Vec<String>>,
pub cipher_suites: Option<Vec<String>>,
pub groups_list: Option<Vec<String>>,
/// Treated as `false` when `None`.
pub expect_proxy: Option<bool>,
#[serde(default = "default_sticky_name")]
pub sticky_name: String,
// `certificate`, `certificate_chain` and `key` are file paths loaded by `to_tls`.
pub certificate: Option<String>,
pub certificate_chain: Option<String>,
pub key: Option<String>,
// Timeouts left at `None` are filled from the global `Config` handed to `to_*`, then from
// the `DEFAULT_*_TIMEOUT` constants.
pub front_timeout: Option<u32>,
pub back_timeout: Option<u32>,
pub connect_timeout: Option<u32>,
pub request_timeout: Option<u32>,
pub config: Option<Config>,
pub send_tls13_tickets: Option<u64>,
/// Only `"h2"` and `"http/1.1"` are accepted by `to_tls`.
pub alpn_protocols: Option<Vec<String>>,
// The `h2_*` options are copied verbatim into the HTTP and HTTPS listener configs.
pub h2_max_rst_stream_per_window: Option<u32>,
pub h2_max_ping_per_window: Option<u32>,
pub h2_max_settings_per_window: Option<u32>,
pub h2_max_empty_data_per_window: Option<u32>,
pub h2_max_window_update_stream0_per_window: Option<u32>,
pub sozu_id_header: Option<String>,
pub h2_max_continuation_frames: Option<u32>,
pub h2_max_glitch_count: Option<u32>,
pub h2_initial_connection_window: Option<u32>,
pub h2_max_concurrent_streams: Option<u32>,
pub h2_stream_shrink_ratio: Option<u32>,
pub h2_max_rst_stream_lifetime: Option<u64>,
pub h2_max_rst_stream_abusive_lifetime: Option<u64>,
pub h2_max_rst_stream_emitted_lifetime: Option<u64>,
pub h2_max_header_list_size: Option<u32>,
pub h2_max_header_table_size: Option<u32>,
pub h2_stream_idle_timeout_seconds: Option<u32>,
pub h2_graceful_shutdown_deadline_seconds: Option<u32>,
pub strict_sni_binding: Option<bool>,
/// Rejected by `to_tls` when `true` while the effective ALPN list contains `"http/1.1"`.
pub disable_http11: Option<bool>,
pub elide_x_real_ip: Option<bool>,
pub send_x_real_ip: Option<bool>,
/// Answers keyed by code. Each value is either the answer itself or a `file://` path whose
/// content is loaded (see `resolve_answer_source`); empty values are skipped.
pub answers: Option<BTreeMap<String, String>>,
/// Listener-level HSTS settings; `to_http` rejects any value with
/// `ConfigError::HstsOnPlainHttp`.
pub hsts: Option<FileHstsConfig>,
}
/// Serde default for `ListenerBuilder::sticky_name`: an owned copy of
/// [`DEFAULT_STICKY_NAME`].
pub fn default_sticky_name() -> String {
    String::from(DEFAULT_STICKY_NAME)
}
impl ListenerBuilder {
pub fn new_http(address: SocketAddress) -> ListenerBuilder {
Self::new(address, ListenerProtocol::Http)
}
pub fn new_tcp(address: SocketAddress) -> ListenerBuilder {
Self::new(address, ListenerProtocol::Tcp)
}
pub fn new_https(address: SocketAddress) -> ListenerBuilder {
Self::new(address, ListenerProtocol::Https)
}
fn new(address: SocketAddress, protocol: ListenerProtocol) -> ListenerBuilder {
ListenerBuilder {
address: address.into(),
answer_301: None,
answer_401: None,
answer_400: None,
answer_404: None,
answer_408: None,
answer_413: None,
answer_421: None,
answer_502: None,
answer_503: None,
answer_504: None,
answer_507: None,
answer_429: None,
back_timeout: None,
certificate_chain: None,
certificate: None,
cipher_list: None,
cipher_suites: None,
groups_list: None,
config: None,
connect_timeout: None,
expect_proxy: None,
front_timeout: None,
key: None,
protocol: Some(protocol),
public_address: None,
request_timeout: None,
send_tls13_tickets: None,
sticky_name: DEFAULT_STICKY_NAME.to_string(),
tls_versions: None,
alpn_protocols: None,
h2_max_rst_stream_per_window: None,
h2_max_ping_per_window: None,
h2_max_settings_per_window: None,
h2_max_empty_data_per_window: None,
h2_max_window_update_stream0_per_window: None,
sozu_id_header: None,
h2_max_continuation_frames: None,
h2_max_glitch_count: None,
h2_initial_connection_window: None,
h2_max_concurrent_streams: None,
h2_stream_shrink_ratio: None,
h2_max_rst_stream_lifetime: None,
h2_max_rst_stream_abusive_lifetime: None,
h2_max_rst_stream_emitted_lifetime: None,
h2_max_header_list_size: None,
h2_max_header_table_size: None,
h2_stream_idle_timeout_seconds: None,
h2_graceful_shutdown_deadline_seconds: None,
strict_sni_binding: None,
disable_http11: None,
elide_x_real_ip: None,
send_x_real_ip: None,
answers: None,
hsts: None,
}
}
/// Overrides the public address when `public_address` is `Some`; `None` leaves the current
/// value untouched.
pub fn with_public_address(&mut self, public_address: Option<SocketAddr>) -> &mut Self {
    if public_address.is_some() {
        self.public_address = public_address;
    }
    self
}
/// Sets the path of the 404 answer file when one is given; `None` keeps the current value.
pub fn with_answer_404_path<S: ToString>(&mut self, answer_404_path: Option<S>) -> &mut Self {
    let rendered = answer_404_path.map(|path| path.to_string());
    if rendered.is_some() {
        self.answer_404 = rendered;
    }
    self
}
/// Sets the path of the 503 answer file when one is given; `None` keeps the current value.
pub fn with_answer_503_path<S: ToString>(&mut self, answer_503_path: Option<S>) -> &mut Self {
    let rendered = answer_503_path.map(|path| path.to_string());
    if rendered.is_some() {
        self.answer_503 = rendered;
    }
    self
}
/// Replaces the accepted TLS versions with `tls_versions`.
pub fn with_tls_versions(&mut self, tls_versions: Vec<TlsVersion>) -> &mut Self {
    let versions = Some(tls_versions);
    self.tls_versions = versions;
    self
}
/// Stores `cipher_list` as given; passing `None` clears any previous value, so `to_tls` will
/// use `DEFAULT_CIPHER_LIST`.
pub fn with_cipher_list(&mut self, cipher_list: Option<Vec<String>>) -> &mut Self {
    let builder = self;
    builder.cipher_list = cipher_list;
    builder
}
/// Stores `cipher_suites` as given; passing `None` clears any previous value.
pub fn with_cipher_suites(&mut self, cipher_suites: Option<Vec<String>>) -> &mut Self {
    let builder = self;
    builder.cipher_suites = cipher_suites;
    builder
}
/// Stores `alpn_protocols` as given; passing `None` clears any previous value. The entries
/// are only validated later, by `to_tls`.
pub fn with_alpn_protocols(&mut self, alpn_protocols: Option<Vec<String>>) -> &mut Self {
    let builder = self;
    builder.alpn_protocols = alpn_protocols;
    builder
}
/// Sets the `elide_x_real_ip` flag explicitly.
pub fn with_elide_x_real_ip(&mut self, elide_x_real_ip: bool) -> &mut Self {
    let flag = Some(elide_x_real_ip);
    self.elide_x_real_ip = flag;
    self
}
/// Sets the `send_x_real_ip` flag explicitly.
pub fn with_send_x_real_ip(&mut self, send_x_real_ip: bool) -> &mut Self {
    let flag = Some(send_x_real_ip);
    self.send_x_real_ip = flag;
    self
}
/// Sets the `expect_proxy` flag explicitly.
pub fn with_expect_proxy(&mut self, expect_proxy: bool) -> &mut Self {
    let flag = Some(expect_proxy);
    self.expect_proxy = flag;
    self
}
/// Replaces the sticky cookie name when one is given; `None` keeps the current name.
pub fn with_sticky_name<S: ToString>(&mut self, sticky_name: Option<S>) -> &mut Self {
    match sticky_name {
        Some(name) => self.sticky_name = name.to_string(),
        None => {}
    }
    self
}
/// Sets the path of the certificate file that `to_tls` will load.
pub fn with_certificate<S: ToString>(&mut self, certificate: S) -> &mut Self {
    let path = certificate.to_string();
    self.certificate = Some(path);
    self
}
/// Sets the path of the certificate chain file that `to_tls` will load and split.
///
/// This used to write into `self.certificate`, which silently replaced the leaf certificate
/// path and left `certificate_chain` unset; it now targets the intended field.
pub fn with_certificate_chain(&mut self, certificate_chain: String) -> &mut Self {
    self.certificate_chain = Some(certificate_chain);
    self
}
/// Sets the path of the private key file that `to_tls` will load.
///
/// NOTE(review): the type parameter `S` is not used by any argument, so it cannot be
/// inferred and callers have to name it explicitly (e.g. `with_key::<String>(..)`). The
/// signature is kept as is for compatibility.
pub fn with_key<S: ToString>(&mut self, key: String) -> &mut Self {
    let path = Some(key);
    self.key = path;
    self
}
/// Stores `front_timeout` as given; `None` clears a previous override.
pub fn with_front_timeout(&mut self, front_timeout: Option<u32>) -> &mut Self {
    let builder = self;
    builder.front_timeout = front_timeout;
    builder
}
/// Stores `back_timeout` as given; `None` clears a previous override.
pub fn with_back_timeout(&mut self, back_timeout: Option<u32>) -> &mut Self {
    let builder = self;
    builder.back_timeout = back_timeout;
    builder
}
/// Stores `connect_timeout` as given; `None` clears a previous override.
pub fn with_connect_timeout(&mut self, connect_timeout: Option<u32>) -> &mut Self {
    let builder = self;
    builder.connect_timeout = connect_timeout;
    builder
}
/// Stores `request_timeout` as given; `None` clears a previous override.
pub fn with_request_timeout(&mut self, request_timeout: Option<u32>) -> &mut Self {
    let builder = self;
    builder.request_timeout = request_timeout;
    builder
}
/// Registers (or replaces) the answer source for one code, creating the `answers` map on
/// first use.
pub fn with_answer<S: ToString, P: ToString>(&mut self, code: S, path: P) -> &mut Self {
    let table = self.answers.get_or_insert_with(BTreeMap::new);
    let key = code.to_string();
    let source = path.to_string();
    table.insert(key, source);
    self
}
/// Replaces the whole `answers` map, discarding entries added earlier.
pub fn with_answers(&mut self, answers: BTreeMap<String, String>) -> &mut Self {
    let table = Some(answers);
    self.answers = table;
    self
}
/// Reads every configured `answer_*` file into a `CustomHttpAnswers`.
///
/// Unset paths yield `None` for that answer. The first file that cannot be opened or read
/// aborts with the corresponding `ConfigError`. The result is always `Some` on success.
fn get_http_answers(&self) -> Result<Option<CustomHttpAnswers>, ConfigError> {
    let load = |source: &Option<String>| read_http_answer_file(source);
    Ok(Some(CustomHttpAnswers {
        answer_301: load(&self.answer_301)?,
        answer_400: load(&self.answer_400)?,
        answer_401: load(&self.answer_401)?,
        answer_404: load(&self.answer_404)?,
        answer_408: load(&self.answer_408)?,
        answer_413: load(&self.answer_413)?,
        answer_421: load(&self.answer_421)?,
        answer_502: load(&self.answer_502)?,
        answer_503: load(&self.answer_503)?,
        answer_504: load(&self.answer_504)?,
        answer_507: load(&self.answer_507)?,
        answer_429: load(&self.answer_429)?,
    }))
}
/// Builds the code → body map of listener answers.
///
/// The legacy `answer_*` files are loaded first, then the entries of `answers` are resolved
/// with `load_answers` and merged on top, so an `answers` entry wins over the legacy field
/// for the same code.
fn get_listener_answers(&self) -> Result<BTreeMap<String, String>, ConfigError> {
    // Kept in the same order as the original per-field expansion so the first failing file
    // is reported identically.
    let legacy_sources: [(&str, &Option<String>); 12] = [
        ("301", &self.answer_301),
        ("400", &self.answer_400),
        ("401", &self.answer_401),
        ("404", &self.answer_404),
        ("408", &self.answer_408),
        ("413", &self.answer_413),
        ("421", &self.answer_421),
        ("502", &self.answer_502),
        ("503", &self.answer_503),
        ("504", &self.answer_504),
        ("507", &self.answer_507),
        ("429", &self.answer_429),
    ];
    let mut merged = BTreeMap::new();
    for (code, source) in legacy_sources {
        if let Some(body) = read_http_answer_file(source)? {
            merged.insert(code.to_owned(), body);
        }
    }
    if let Some(configured) = self.answers.as_ref() {
        merged.extend(load_answers(configured)?);
    }
    Ok(merged)
}
/// Fills each timeout that is still `None` with the value from the global `config`;
/// timeouts already set on the builder are kept.
fn assign_config_timeouts(&mut self, config: &Config) {
    self.front_timeout = self.front_timeout.or(Some(config.front_timeout));
    self.back_timeout = self.back_timeout.or(Some(config.back_timeout));
    self.connect_timeout = self.connect_timeout.or(Some(config.connect_timeout));
    self.request_timeout = self.request_timeout.or(Some(config.request_timeout));
}
/// Converts this builder into an `HttpListenerConfig`.
///
/// When `config` is given, timeouts missing on the builder are first filled from it (this
/// mutates `self`); anything still unset falls back to the `DEFAULT_*_TIMEOUT` constants.
///
/// # Errors
/// * `ConfigError::WrongListenerProtocol` if the builder is not an HTTP listener.
/// * `ConfigError::HstsOnPlainHttp` if an `hsts` block is present.
/// * `ConfigError::FileOpen` / `ConfigError::FileRead` if an answer file cannot be loaded.
pub fn to_http(&mut self, config: Option<&Config>) -> Result<HttpListenerConfig, ConfigError> {
    match self.protocol {
        Some(ListenerProtocol::Http) => {}
        other => {
            return Err(ConfigError::WrongListenerProtocol {
                expected: ListenerProtocol::Http,
                found: other,
            });
        }
    }
    if self.hsts.is_some() {
        let scope = format!("HTTP listener {}", self.address);
        return Err(ConfigError::HstsOnPlainHttp(scope));
    }
    if let Some(global) = config {
        self.assign_config_timeouts(global);
    }
    let http_answers = self.get_http_answers()?;
    let answers = self.get_listener_answers()?;
    Ok(HttpListenerConfig {
        address: self.address.into(),
        public_address: self.public_address.map(|addr| addr.into()),
        expect_proxy: self.expect_proxy.unwrap_or(false),
        sticky_name: self.sticky_name.clone(),
        front_timeout: self.front_timeout.unwrap_or(DEFAULT_FRONT_TIMEOUT),
        back_timeout: self.back_timeout.unwrap_or(DEFAULT_BACK_TIMEOUT),
        connect_timeout: self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT),
        request_timeout: self.request_timeout.unwrap_or(DEFAULT_REQUEST_TIMEOUT),
        http_answers,
        answers,
        h2_max_rst_stream_per_window: self.h2_max_rst_stream_per_window,
        h2_max_ping_per_window: self.h2_max_ping_per_window,
        h2_max_settings_per_window: self.h2_max_settings_per_window,
        h2_max_empty_data_per_window: self.h2_max_empty_data_per_window,
        h2_max_window_update_stream0_per_window: self.h2_max_window_update_stream0_per_window,
        h2_max_continuation_frames: self.h2_max_continuation_frames,
        h2_max_glitch_count: self.h2_max_glitch_count,
        h2_initial_connection_window: self.h2_initial_connection_window,
        h2_max_concurrent_streams: self.h2_max_concurrent_streams,
        h2_stream_shrink_ratio: self.h2_stream_shrink_ratio,
        h2_max_rst_stream_lifetime: self.h2_max_rst_stream_lifetime,
        h2_max_rst_stream_abusive_lifetime: self.h2_max_rst_stream_abusive_lifetime,
        h2_max_rst_stream_emitted_lifetime: self.h2_max_rst_stream_emitted_lifetime,
        h2_max_header_list_size: self.h2_max_header_list_size,
        h2_max_header_table_size: self.h2_max_header_table_size,
        h2_stream_idle_timeout_seconds: self.h2_stream_idle_timeout_seconds,
        h2_graceful_shutdown_deadline_seconds: self.h2_graceful_shutdown_deadline_seconds,
        sozu_id_header: self.sozu_id_header.clone(),
        // Both flags are always sent explicitly, defaulting to `false`.
        elide_x_real_ip: self.elide_x_real_ip.or(Some(false)),
        send_x_real_ip: self.send_x_real_ip.or(Some(false)),
        ..Default::default()
    })
}
/// Converts this builder into an `HttpsListenerConfig`.
///
/// Unset TLS options fall back to the module defaults, the ALPN list is validated and
/// de-duplicated, certificate material is loaded from disk, and timeouts missing on the
/// builder are filled from `config` (when given) and then from the `DEFAULT_*_TIMEOUT`
/// constants.
///
/// # Errors
/// * `ConfigError::WrongListenerProtocol` if the builder is not an HTTPS listener.
/// * `ConfigError::InvalidAlpnProtocol` for an ALPN entry other than `"h2"` / `"http/1.1"`.
/// * `ConfigError::DisableHttp11WithHttp11Alpn` if `disable_http11` is `true` while the
///   effective ALPN list contains `"http/1.1"`.
/// * `ConfigError::FileOpen` / `ConfigError::FileRead` if an answer file cannot be loaded.
/// * Whatever `FileHstsConfig::to_proto` returns for an invalid `hsts` block.
///
/// A key, certificate or certificate chain that fails to load does NOT produce an error: the
/// failure is logged and the field is left empty.
pub fn to_tls(&mut self, config: Option<&Config>) -> Result<HttpsListenerConfig, ConfigError> {
if self.protocol != Some(ListenerProtocol::Https) {
return Err(ConfigError::WrongListenerProtocol {
expected: ListenerProtocol::Https,
found: self.protocol.to_owned(),
});
}
// Both `cipher_list` and `cipher_suites` default to `DEFAULT_CIPHER_LIST`.
let default_cipher_list = DEFAULT_CIPHER_LIST.into_iter().map(String::from).collect();
let cipher_list = self.cipher_list.clone().unwrap_or(default_cipher_list);
let cipher_suites = self
.cipher_suites
.clone()
.unwrap_or_else(|| DEFAULT_CIPHER_LIST.into_iter().map(String::from).collect());
// Signature algorithms are not configurable on the builder.
let signature_algorithms: Vec<String> = DEFAULT_SIGNATURE_ALGORITHMS
.into_iter()
.map(String::from)
.collect();
let groups_list = self
.groups_list
.clone()
.unwrap_or_else(|| DEFAULT_GROUPS_LIST.into_iter().map(String::from).collect());
let alpn_protocols: Vec<String> = match &self.alpn_protocols {
// An explicit, non-empty list: validate every entry before using it.
Some(protos) if !protos.is_empty() => {
for proto in protos {
match proto.as_str() {
"h2" | "http/1.1" => {}
other => return Err(ConfigError::InvalidAlpnProtocol(other.to_owned())),
}
}
if self.disable_http11.unwrap_or(false) && protos.iter().any(|p| p == "http/1.1") {
return Err(ConfigError::DisableHttp11WithHttp11Alpn {
address: self.address.to_string(),
});
}
if !protos.iter().any(|p| p == "http/1.1") {
warn!(
"ALPN protocols do not include 'http/1.1'. Clients without H2 support will fail TLS negotiation."
);
}
// Drop duplicates while keeping the first occurrence of each protocol in order.
let mut seen = std::collections::HashSet::new();
protos
.iter()
.filter(|p| seen.insert(p.as_str()))
.cloned()
.collect()
}
// Unset or empty list: use the defaults, which conflict with `disable_http11` as long as
// they contain "http/1.1".
_ => {
if self.disable_http11.unwrap_or(false)
&& DEFAULT_ALPN_PROTOCOLS.contains(&"http/1.1")
{
return Err(ConfigError::DisableHttp11WithHttp11Alpn {
address: self.address.to_string(),
});
}
DEFAULT_ALPN_PROTOCOLS
.iter()
.map(|s| s.to_string())
.collect()
}
};
// TLS 1.2 and 1.3 are used when no versions are configured.
let versions = match self.tls_versions {
None => vec![TlsVersion::TlsV12 as i32, TlsVersion::TlsV13 as i32],
Some(ref v) => v.iter().map(|v| *v as i32).collect(),
};
// Best effort: a load failure is logged and turned into `None` rather than an error.
let key = self.key.as_ref().and_then(|path| {
Config::load_file(path)
.map_err(|e| {
error!("cannot load key at path '{}': {:?}", path, e);
e
})
.ok()
});
let certificate = self.certificate.as_ref().and_then(|path| {
Config::load_file(path)
.map_err(|e| {
error!("cannot load certificate at path '{}': {:?}", path, e);
e
})
.ok()
});
// Same best-effort policy; a missing or unreadable chain yields the default (empty) value.
let certificate_chain = self
.certificate_chain
.as_ref()
.and_then(|path| {
Config::load_file(path)
.map_err(|e| {
error!("cannot load certificate chain at path '{}': {:?}", path, e);
e
})
.ok()
})
.map(split_certificate_chain)
.unwrap_or_default();
let http_answers = self.get_http_answers()?;
let answers = self.get_listener_answers()?;
// Timeouts are filled in only after the fallible answer loading above, so `self` is left
// unmodified when that loading fails.
if let Some(config) = config {
self.assign_config_timeouts(config);
}
let https_listener_config = HttpsListenerConfig {
address: self.address.into(),
sticky_name: self.sticky_name.clone(),
public_address: self.public_address.map(|a| a.into()),
cipher_list,
versions,
expect_proxy: self.expect_proxy.unwrap_or(false),
key,
certificate,
certificate_chain,
front_timeout: self.front_timeout.unwrap_or(DEFAULT_FRONT_TIMEOUT),
back_timeout: self.back_timeout.unwrap_or(DEFAULT_BACK_TIMEOUT),
connect_timeout: self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT),
request_timeout: self.request_timeout.unwrap_or(DEFAULT_REQUEST_TIMEOUT),
cipher_suites,
signature_algorithms,
groups_list,
// The listener config is always produced with `active` set to `false`.
active: false,
send_tls13_tickets: self
.send_tls13_tickets
.unwrap_or(DEFAULT_SEND_TLS_13_TICKETS),
http_answers,
answers,
alpn_protocols,
h2_max_rst_stream_per_window: self.h2_max_rst_stream_per_window,
h2_max_ping_per_window: self.h2_max_ping_per_window,
h2_max_settings_per_window: self.h2_max_settings_per_window,
h2_max_empty_data_per_window: self.h2_max_empty_data_per_window,
h2_max_window_update_stream0_per_window: self.h2_max_window_update_stream0_per_window,
h2_max_continuation_frames: self.h2_max_continuation_frames,
h2_max_glitch_count: self.h2_max_glitch_count,
h2_initial_connection_window: self.h2_initial_connection_window,
h2_max_concurrent_streams: self.h2_max_concurrent_streams,
h2_stream_shrink_ratio: self.h2_stream_shrink_ratio,
h2_max_rst_stream_lifetime: self.h2_max_rst_stream_lifetime,
h2_max_rst_stream_abusive_lifetime: self.h2_max_rst_stream_abusive_lifetime,
h2_max_rst_stream_emitted_lifetime: self.h2_max_rst_stream_emitted_lifetime,
h2_max_header_list_size: self.h2_max_header_list_size,
h2_max_header_table_size: self.h2_max_header_table_size,
strict_sni_binding: self.strict_sni_binding,
disable_http11: self.disable_http11,
h2_stream_idle_timeout_seconds: self.h2_stream_idle_timeout_seconds,
h2_graceful_shutdown_deadline_seconds: self.h2_graceful_shutdown_deadline_seconds,
sozu_id_header: self.sozu_id_header.clone(),
elide_x_real_ip: Some(self.elide_x_real_ip.unwrap_or(false)),
send_x_real_ip: Some(self.send_x_real_ip.unwrap_or(false)),
// The listener-level HSTS block is validated with the scope label "listener".
hsts: match self.hsts.as_ref() {
Some(h) => Some(h.to_proto("listener")?),
None => None,
},
};
Ok(https_listener_config)
}
/// Converts this builder into a `TcpListenerConfig`.
///
/// When `config` is given, timeouts missing on the builder are first filled from it (this
/// mutates `self`); anything still unset falls back to the `DEFAULT_*_TIMEOUT` constants.
/// The listener is produced with `active` set to `false`.
///
/// # Errors
/// `ConfigError::WrongListenerProtocol` if the builder is not a TCP listener.
pub fn to_tcp(&mut self, config: Option<&Config>) -> Result<TcpListenerConfig, ConfigError> {
    match self.protocol {
        Some(ListenerProtocol::Tcp) => {}
        other => {
            return Err(ConfigError::WrongListenerProtocol {
                expected: ListenerProtocol::Tcp,
                found: other,
            });
        }
    }
    if let Some(global) = config {
        self.assign_config_timeouts(global);
    }
    let tcp_listener = TcpListenerConfig {
        address: self.address.into(),
        public_address: self.public_address.map(|addr| addr.into()),
        expect_proxy: self.expect_proxy.unwrap_or(false),
        front_timeout: self.front_timeout.unwrap_or(DEFAULT_FRONT_TIMEOUT),
        back_timeout: self.back_timeout.unwrap_or(DEFAULT_BACK_TIMEOUT),
        connect_timeout: self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT),
        active: false,
    };
    Ok(tcp_listener)
}
}
/// Reads the whole file at `path` as UTF-8 text.
///
/// Returns `Ok(None)` when no path is configured.
///
/// # Errors
/// `ConfigError::FileOpen` if the file cannot be opened, `ConfigError::FileRead` if its
/// content cannot be read into a string.
fn read_http_answer_file(path: &Option<String>) -> Result<Option<String>, ConfigError> {
    let Some(location) = path else {
        return Ok(None);
    };
    let mut handle = match File::open(location) {
        Ok(handle) => handle,
        Err(io_error) => {
            return Err(ConfigError::FileOpen {
                path_to_open: location.to_owned(),
                io_error,
            });
        }
    };
    let mut body = String::new();
    if let Err(io_error) = handle.read_to_string(&mut body) {
        return Err(ConfigError::FileRead {
            path_to_read: location.to_owned(),
            io_error,
        });
    }
    Ok(Some(body))
}
/// Resolves one answer source to its body.
///
/// A value starting with `file://` is treated as a path (everything after the prefix) and the
/// file content is returned; any other value is returned unchanged as the body itself.
///
/// # Errors
/// `ConfigError::FileOpen` / `ConfigError::FileRead` when the referenced file cannot be
/// opened or read.
pub fn resolve_answer_source(value: &str) -> Result<String, ConfigError> {
    let Some(location) = value.strip_prefix("file://") else {
        return Ok(value.to_owned());
    };
    let mut handle = match File::open(location) {
        Ok(handle) => handle,
        Err(io_error) => {
            return Err(ConfigError::FileOpen {
                path_to_open: location.to_owned(),
                io_error,
            });
        }
    };
    let mut body = String::new();
    match handle.read_to_string(&mut body) {
        Ok(_) => Ok(body),
        Err(io_error) => Err(ConfigError::FileRead {
            path_to_read: location.to_owned(),
            io_error,
        }),
    }
}
/// Resolves every non-empty entry of `answers` with `resolve_answer_source`.
///
/// Entries whose value is the empty string are skipped. Resolution stops at the first entry
/// (in key order) that fails and that error is returned.
pub fn load_answers(
    answers: &BTreeMap<String, String>,
) -> Result<BTreeMap<String, String>, ConfigError> {
    answers
        .iter()
        .filter(|(_, source)| !source.is_empty())
        .map(|(code, source)| Ok((code.to_owned(), resolve_answer_source(source)?)))
        .collect()
}
/// Configuration-file counterpart of the protobuf `MetricDetail` enum.
///
/// Serialized in lowercase (`"process"`, `"frontend"`, `"cluster"`, `"backend"`); the
/// default is `Cluster`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MetricDetailLevel {
Process,
Frontend,
Cluster,
Backend,
}
impl Default for MetricDetailLevel {
fn default() -> Self {
Self::Cluster
}
}
impl From<MetricDetailLevel> for MetricDetail {
fn from(level: MetricDetailLevel) -> Self {
match level {
MetricDetailLevel::Process => MetricDetail::DetailProcess,
MetricDetailLevel::Frontend => MetricDetail::DetailFrontend,
MetricDetailLevel::Cluster => MetricDetail::DetailCluster,
MetricDetailLevel::Backend => MetricDetail::DetailBackend,
}
}
}
impl From<MetricDetail> for MetricDetailLevel {
fn from(detail: MetricDetail) -> Self {
match detail {
MetricDetail::DetailProcess => MetricDetailLevel::Process,
MetricDetail::DetailFrontend => MetricDetailLevel::Frontend,
MetricDetail::DetailCluster => MetricDetailLevel::Cluster,
MetricDetail::DetailBackend => MetricDetailLevel::Backend,
}
}
}
/// Metrics section of the configuration file. Unknown keys are rejected.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct MetricsConfig {
/// Required socket address. NOTE(review): presumably where metrics are sent — confirm.
pub address: SocketAddr,
/// Defaults to `false` when omitted.
#[serde(default)]
pub tagged_metrics: bool,
/// Defaults to `None` when omitted.
#[serde(default)]
pub prefix: Option<String>,
/// Defaults to `MetricDetailLevel::Cluster` when omitted.
#[serde(default)]
pub detail: MetricDetailLevel,
}
/// How a frontend `path` is matched; serialized as `PREFIX`, `REGEX` or `EQUALS`.
///
/// `FileClusterFrontendConfig::to_http_front` turns these into `PathRule::prefix`,
/// `PathRule::regex` and `PathRule::equals` respectively.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
#[serde(deny_unknown_fields)]
pub enum PathRuleType {
Prefix,
Regex,
Equals,
}
/// A frontend entry as written in the configuration file, converted with `to_tcp_front` or
/// `to_http_front`. Unknown keys are rejected.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct FileClusterFrontendConfig {
pub address: SocketAddr,
/// Required for HTTP frontends, rejected for TCP frontends.
pub hostname: Option<String>,
/// Path pattern; an absent path becomes an empty prefix rule. Rejected for TCP frontends.
pub path: Option<String>,
/// Interpretation of `path`; treated as a prefix when absent.
pub path_type: Option<PathRuleType>,
pub method: Option<String>,
// `certificate`, `key` and `certificate_chain` are file paths loaded by `to_http_front`.
pub certificate: Option<String>,
pub key: Option<String>,
pub certificate_chain: Option<String>,
#[serde(default)]
pub tls_versions: Vec<TlsVersion>,
#[serde(default)]
pub position: RulePosition,
pub tags: Option<BTreeMap<String, String>>,
/// Parsed by `parse_redirect_policy` (`forward`, `permanent`, `unauthorized`).
pub redirect: Option<String>,
/// Parsed by `parse_redirect_scheme` (`use-same`, `use-http`, `use-https`).
pub redirect_scheme: Option<String>,
pub redirect_template: Option<String>,
pub rewrite_host: Option<String>,
pub rewrite_path: Option<String>,
pub rewrite_port: Option<u32>,
pub required_auth: Option<bool>,
/// Header edits, each validated by `parse_header_edit`.
pub headers: Option<Vec<HeaderEditConfig>>,
/// Only accepted by `to_http_front` when both `key` and `certificate` are set.
pub hsts: Option<FileHstsConfig>,
}
/// One header edit as written in the configuration file; validated and converted by
/// `parse_header_edit`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct HeaderEditConfig {
/// `"request"`, `"response"` or `"both"` (matched case-insensitively).
pub position: String,
/// Header name; must be a non-empty token (see `header_name_is_valid_token`).
pub key: String,
/// Header value; must not contain the control bytes rejected by
/// `header_value_contains_forbidden_controls`.
pub value: String,
}
/// An `[hsts]` block as written in the configuration file; converted by `to_proto`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct FileHstsConfig {
/// Mandatory in practice: `to_proto` fails with `ConfigError::HstsEnabledRequired` when
/// this is `None`.
pub enabled: Option<bool>,
/// In seconds. Defaults to `DEFAULT_HSTS_MAX_AGE` when HSTS is enabled and this is unset.
pub max_age: Option<u32>,
pub include_subdomains: Option<bool>,
pub preload: Option<bool>,
pub force_replace_backend: Option<bool>,
}
impl FileHstsConfig {
    /// Validates this block and converts it into the protobuf `HstsConfig`.
    ///
    /// `scope` is a human-readable label of the listener or frontend being converted; it
    /// only appears in the error and in the warnings.
    ///
    /// When HSTS is enabled without a `max_age`, `DEFAULT_HSTS_MAX_AGE` is used. Warnings
    /// (not errors) are logged for a `max_age` strictly between 0 and one day, and for
    /// `preload = true` combined with a `max_age` below `DEFAULT_HSTS_MAX_AGE` or without
    /// `include_subdomains = true`.
    ///
    /// # Errors
    /// `ConfigError::HstsEnabledRequired` when `enabled` is not set.
    pub fn to_proto(&self, scope: &str) -> Result<HstsConfig, ConfigError> {
        let Some(enabled) = self.enabled else {
            return Err(ConfigError::HstsEnabledRequired(scope.to_owned()));
        };
        let max_age = if enabled {
            self.max_age.or(Some(DEFAULT_HSTS_MAX_AGE))
        } else {
            self.max_age
        };
        // 0 is deliberately excluded from the warning below.
        if let Some(value) = max_age.filter(|seconds| (1..86_400).contains(seconds)) {
            warn!(
                "HSTS max_age = {}s on {} is below 1 day — this is almost certainly a \
                misconfiguration. RFC 6797 §11.4 reserves max_age = 0 as the explicit kill \
                switch.",
                value, scope
            );
        }
        let include_subdomains = self.include_subdomains;
        let preload = self.preload;
        if preload == Some(true) {
            let max_age_value = max_age.unwrap_or(0);
            if max_age_value < DEFAULT_HSTS_MAX_AGE {
                warn!(
                    "HSTS preload = true on {} with max_age = {}s; the Chrome HSTS preload \
                    list requires max_age >= {} (https://hstspreload.org/).",
                    scope, max_age_value, DEFAULT_HSTS_MAX_AGE
                );
            }
            if include_subdomains != Some(true) {
                warn!(
                    "HSTS preload = true on {} without include_subdomains = true; the Chrome \
                    HSTS preload list requires includeSubDomains \
                    (https://hstspreload.org/).",
                    scope
                );
            }
        }
        Ok(HstsConfig {
            enabled: Some(enabled),
            max_age,
            include_subdomains,
            preload,
            force_replace_backend: self.force_replace_backend,
        })
    }
}
impl FileClusterFrontendConfig {
/// Converts this entry into a `TcpFrontendConfig`, keeping only the address and the tags.
///
/// Fields that only make sense for HTTP(S) frontends are rejected. The original code checked
/// `hostname` twice (the second check was unreachable) and never looked at `key`; the
/// duplicate is replaced by the `key` check so that all three TLS file options are refused
/// consistently.
///
/// # Errors
/// `ConfigError::InvalidFrontendConfig` naming the first offending field, checked in this
/// order: `hostname`, `path` (reported as `path_prefix`), `certificate`, `key`,
/// `certificate_chain`.
pub fn to_tcp_front(&self) -> Result<TcpFrontendConfig, ConfigError> {
    if self.hostname.is_some() {
        return Err(ConfigError::InvalidFrontendConfig("hostname".to_string()));
    }
    if self.path.is_some() {
        return Err(ConfigError::InvalidFrontendConfig(
            "path_prefix".to_string(),
        ));
    }
    if self.certificate.is_some() {
        return Err(ConfigError::InvalidFrontendConfig(
            "certificate".to_string(),
        ));
    }
    if self.key.is_some() {
        return Err(ConfigError::InvalidFrontendConfig("key".to_string()));
    }
    if self.certificate_chain.is_some() {
        return Err(ConfigError::InvalidFrontendConfig(
            "certificate_chain".to_string(),
        ));
    }
    Ok(TcpFrontendConfig {
        address: self.address,
        tags: self.tags.clone(),
    })
}
pub fn to_http_front(&self, _cluster_id: &str) -> Result<HttpFrontendConfig, ConfigError> {
let hostname = match &self.hostname {
Some(hostname) => hostname.to_owned(),
None => {
return Err(ConfigError::Missing(MissingKind::Field(
"hostname".to_string(),
)));
}
};
let key_opt = match self.key.as_ref() {
None => None,
Some(path) => {
let key = Config::load_file(path)?;
Some(key)
}
};
let certificate_opt = match self.certificate.as_ref() {
None => None,
Some(path) => {
let certificate = Config::load_file(path)?;
Some(certificate)
}
};
let certificate_chain = match self.certificate_chain.as_ref() {
None => None,
Some(path) => {
let certificate_chain = Config::load_file(path)?;
Some(split_certificate_chain(certificate_chain))
}
};
let path = match (self.path.as_ref(), self.path_type.as_ref()) {
(None, _) => PathRule::prefix("".to_string()),
(Some(s), Some(PathRuleType::Prefix)) => PathRule::prefix(s.to_string()),
(Some(s), Some(PathRuleType::Regex)) => PathRule::regex(s.to_string()),
(Some(s), Some(PathRuleType::Equals)) => PathRule::equals(s.to_string()),
(Some(s), None) => PathRule::prefix(s.clone()),
};
let redirect = match self.redirect.as_deref() {
Some(v) => Some(parse_redirect_policy(v)?),
None => None,
};
let redirect_scheme = match self.redirect_scheme.as_deref() {
Some(v) => Some(parse_redirect_scheme(v)?),
None => None,
};
let headers = match self.headers.as_ref() {
Some(entries) => {
let mut out = Vec::with_capacity(entries.len());
for (index, entry) in entries.iter().enumerate() {
out.push(parse_header_edit(index, entry)?);
}
out
}
None => Vec::new(),
};
let frontend_serves_https = key_opt.is_some() && certificate_opt.is_some();
let hsts = match self.hsts.as_ref() {
Some(h) => {
if !frontend_serves_https {
return Err(ConfigError::HstsOnPlainHttp(format!(
"frontend {_cluster_id}/{hostname}"
)));
}
Some(h.to_proto(&format!("frontend {_cluster_id}/{hostname}"))?)
}
None => None,
};
Ok(HttpFrontendConfig {
address: self.address,
hostname,
certificate: certificate_opt,
key: key_opt,
certificate_chain,
tls_versions: self.tls_versions.clone(),
position: self.position,
path,
method: self.method.clone(),
tags: self.tags.clone(),
redirect,
redirect_scheme,
redirect_template: self.redirect_template.clone(),
rewrite_host: self.rewrite_host.clone(),
rewrite_path: self.rewrite_path.clone(),
rewrite_port: self.rewrite_port,
required_auth: self.required_auth,
headers,
hsts,
})
}
}
/// Parses a redirect policy name, ignoring ASCII case.
///
/// Accepts `forward`, `permanent` and `unauthorized`.
///
/// # Errors
/// `ConfigError::InvalidRedirectPolicy` carrying the original, un-normalized value.
pub(crate) fn parse_redirect_policy(value: &str) -> Result<RedirectPolicy, ConfigError> {
    let policy = if value.eq_ignore_ascii_case("forward") {
        RedirectPolicy::Forward
    } else if value.eq_ignore_ascii_case("permanent") {
        RedirectPolicy::Permanent
    } else if value.eq_ignore_ascii_case("unauthorized") {
        RedirectPolicy::Unauthorized
    } else {
        return Err(ConfigError::InvalidRedirectPolicy(value.to_owned()));
    };
    Ok(policy)
}
/// Parses a redirect scheme name, ignoring ASCII case.
///
/// Accepts `use-same`, `use-http` and `use-https`, each also in its underscore spelling
/// (`use_same`, ...).
///
/// # Errors
/// `ConfigError::InvalidRedirectScheme` carrying the original, un-normalized value.
pub(crate) fn parse_redirect_scheme(value: &str) -> Result<RedirectScheme, ConfigError> {
    let lowered = value.to_ascii_lowercase();
    let name = lowered.as_str();
    if matches!(name, "use-same" | "use_same") {
        Ok(RedirectScheme::UseSame)
    } else if matches!(name, "use-http" | "use_http") {
        Ok(RedirectScheme::UseHttp)
    } else if matches!(name, "use-https" | "use_https") {
        Ok(RedirectScheme::UseHttps)
    } else {
        Err(ConfigError::InvalidRedirectScheme(value.to_owned()))
    }
}
/// Validates one header edit from the configuration file and converts it to the protobuf
/// `Header`.
///
/// `index` is the position of `entry` in its `headers` list and is only used in error
/// reports. Checks run in this order: position, key, value.
///
/// # Errors
/// * `ConfigError::InvalidHeaderPosition` when `position` is not `request`, `response` or
///   `both` (ASCII case-insensitive).
/// * `ConfigError::InvalidHeaderBytes` with `field: "key"` when the key is not a valid
///   token, or `field: "value"` when the value contains a forbidden control byte.
pub(crate) fn parse_header_edit(
    index: usize,
    entry: &HeaderEditConfig,
) -> Result<Header, ConfigError> {
    let lowered = entry.position.to_ascii_lowercase();
    let position = match lowered.as_str() {
        "request" => HeaderPosition::Request,
        "response" => HeaderPosition::Response,
        "both" => HeaderPosition::Both,
        _ => {
            return Err(ConfigError::InvalidHeaderPosition {
                index,
                position: entry.position.clone(),
            });
        }
    };
    let key_is_token = header_name_is_valid_token(entry.key.as_bytes());
    if !key_is_token {
        return Err(ConfigError::InvalidHeaderBytes {
            index,
            field: "key",
        });
    }
    let value_has_controls = header_value_contains_forbidden_controls(entry.value.as_bytes());
    if value_has_controls {
        return Err(ConfigError::InvalidHeaderBytes {
            index,
            field: "value",
        });
    }
    Ok(Header {
        position: position as i32,
        key: entry.key.clone(),
        val: entry.value.clone(),
    })
}
/// Returns `true` when `bytes` is a non-empty sequence made only of `tchar`
/// bytes (see [`is_tchar`]).
pub(crate) fn header_name_is_valid_token(bytes: &[u8]) -> bool {
    !bytes.is_empty() && bytes.iter().copied().all(is_tchar)
}

/// Returns `true` when `b` is an ASCII letter, an ASCII digit, or one of the
/// symbols ``! # $ % & ' * + - . ^ _ ` | ~``.
fn is_tchar(b: u8) -> bool {
    const SYMBOLS: &[u8] = b"!#$%&'*+-.^_`|~";
    b.is_ascii_alphanumeric() || SYMBOLS.contains(&b)
}
/// Returns `true` when `bytes` contains an ASCII control byte other than
/// horizontal tab, i.e. any of `0x00..=0x08`, `0x0A..=0x1F` or `0x7F`.
pub(crate) fn header_value_contains_forbidden_controls(bytes: &[u8]) -> bool {
    bytes
        .iter()
        .any(|byte| byte.is_ascii_control() && *byte != b'\t')
}
/// Protocol served by a listener.
///
/// Serialized in lowercase (`http`, `https`, `tcp`) because of
/// `rename_all = "lowercase"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "lowercase")]
pub enum ListenerProtocol {
/// HTTP listener.
Http,
/// HTTPS listener.
Https,
/// TCP listener.
Tcp,
}
/// Protocol of a cluster as written in the configuration file.
///
/// Serialized in lowercase (`http`, `tcp`). It selects which variant of
/// `ClusterConfig` is produced by `FileClusterConfig::to_cluster_config`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "lowercase")]
pub enum FileClusterProtocolConfig {
/// The cluster becomes a `ClusterConfig::Http`.
Http,
/// The cluster becomes a `ClusterConfig::Tcp`.
Tcp,
}
/// Serde default for `FileHealthCheckConfig::interval`.
fn default_health_check_interval() -> u32 {
    const DEFAULT_INTERVAL: u32 = 10;
    DEFAULT_INTERVAL
}

/// Serde default for `FileHealthCheckConfig::timeout`.
fn default_health_check_timeout() -> u32 {
    const DEFAULT_TIMEOUT: u32 = 5;
    DEFAULT_TIMEOUT
}

/// Serde default shared by `FileHealthCheckConfig::healthy_threshold` and
/// `FileHealthCheckConfig::unhealthy_threshold`.
fn default_health_check_threshold() -> u32 {
    const DEFAULT_THRESHOLD: u32 = 3;
    DEFAULT_THRESHOLD
}
/// Health check settings of a cluster as written in the configuration file.
///
/// Converted field-for-field into the protobuf `HealthCheckConfig` by
/// `FileHealthCheckConfig::to_proto`. Unknown keys are rejected at
/// deserialization time.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct FileHealthCheckConfig {
/// URI to probe. The only field without a default, so it is mandatory.
pub uri: String,
/// Defaults to 10 when omitted. NOTE(review): unit is not visible here — presumably seconds, confirm.
#[serde(default = "default_health_check_interval")]
pub interval: u32,
/// Defaults to 5 when omitted. NOTE(review): unit is not visible here — presumably seconds, confirm.
#[serde(default = "default_health_check_timeout")]
pub timeout: u32,
/// Defaults to 3 when omitted.
#[serde(default = "default_health_check_threshold")]
pub healthy_threshold: u32,
/// Defaults to 3 when omitted.
#[serde(default = "default_health_check_threshold")]
pub unhealthy_threshold: u32,
/// Defaults to 0 (the `u32` default) when omitted.
#[serde(default)]
pub expected_status: u32,
}
impl FileHealthCheckConfig {
    /// Builds the protobuf [`HealthCheckConfig`] carrying the same values.
    ///
    /// No validation is performed here.
    pub fn to_proto(&self) -> HealthCheckConfig {
        // Exhaustive destructuring: adding a field to the struct without
        // forwarding it becomes a compile error.
        let Self {
            uri,
            interval,
            timeout,
            healthy_threshold,
            unhealthy_threshold,
            expected_status,
        } = self;
        HealthCheckConfig {
            uri: uri.clone(),
            interval: *interval,
            timeout: *timeout,
            healthy_threshold: *healthy_threshold,
            unhealthy_threshold: *unhealthy_threshold,
            expected_status: *expected_status,
        }
    }
}
/// Checks that a health check configuration is usable.
///
/// The checks run in this order and the first failure is reported:
/// `interval`, `timeout`, `healthy_threshold` and `unhealthy_threshold` must
/// be non-zero, the URI must start with `/`, and the URI must not contain any
/// byte below `0x20` other than horizontal tab.
///
/// # Errors
///
/// Returns a static, human-readable description of the first violated rule.
pub fn validate_health_check_config(cfg: &HealthCheckConfig) -> Result<(), &'static str> {
    let must_be_positive = [
        (cfg.interval, "health check interval must be > 0"),
        (cfg.timeout, "health check timeout must be > 0"),
        (
            cfg.healthy_threshold,
            "health check healthy_threshold must be > 0",
        ),
        (
            cfg.unhealthy_threshold,
            "health check unhealthy_threshold must be > 0",
        ),
    ];
    for (value, message) in must_be_positive {
        if value == 0 {
            return Err(message);
        }
    }

    if !cfg.uri.starts_with('/') {
        return Err("health check URI must start with '/'");
    }

    // CR, LF and NUL are all below 0x20, so a single range test covers them.
    let has_forbidden_byte = cfg.uri.bytes().any(|b| b < 0x20 && b != b'\t');
    if has_forbidden_byte {
        return Err("health check URI must not contain CR, LF, NUL, or other C0 control bytes");
    }

    Ok(())
}
/// Cluster section as written in the configuration file.
///
/// `FileClusterConfig::to_cluster_config` turns it into a `ClusterConfig`,
/// whose variant depends on `protocol`. Unknown keys are rejected at
/// deserialization time.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct FileClusterConfig {
/// Frontends of the cluster; converted to HTTP or TCP frontends depending on `protocol`.
pub frontends: Vec<FileClusterFrontendConfig>,
/// Backends of the cluster, forwarded unchanged.
pub backends: Vec<BackendConfig>,
/// Selects whether an HTTP or a TCP cluster is built.
pub protocol: FileClusterProtocolConfig,
/// Read for HTTP clusters only; treated as `false` when absent.
pub sticky_session: Option<bool>,
/// Read for HTTP clusters only; treated as `false` when absent.
pub https_redirect: Option<bool>,
/// Read for TCP clusters only; treated as `false` when absent. Combined with the
/// `expect_proxy` flag of the listeners to pick the `ProxyProtocolConfig`.
#[serde(default)]
pub send_proxy: Option<bool>,
#[serde(default)]
pub load_balancing: LoadBalancingAlgorithms,
/// Read for HTTP clusters only: path of a file whose content becomes the 503 answer.
/// A file that cannot be loaded is logged and ignored.
pub answer_503: Option<String>,
#[serde(default)]
pub load_metric: Option<LoadMetric>,
/// Forwarded for HTTP clusters only.
pub http2: Option<bool>,
/// Passed through `load_answers` during conversion; an empty map is used when absent.
pub answers: Option<BTreeMap<String, String>>,
pub https_redirect_port: Option<u32>,
/// Becomes an empty list when absent.
pub authorized_hashes: Option<Vec<String>>,
pub www_authenticate: Option<String>,
pub max_connections_per_ip: Option<u64>,
pub retry_after: Option<u32>,
/// Converted with `FileHealthCheckConfig::to_proto` when present.
#[serde(default)]
pub health_check: Option<FileHealthCheckConfig>,
}
/// One backend of a cluster.
///
/// Turned into an `AddBackend` request by the `generate_requests` methods of
/// `HttpClusterConfig` and `TcpClusterConfig`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct BackendConfig {
/// Address of the backend.
pub address: SocketAddr,
/// Load balancing weight; 100 is used when absent.
pub weight: Option<u8>,
/// Forwarded unchanged in the `AddBackend` request.
pub sticky_id: Option<String>,
/// Forwarded unchanged in the `AddBackend` request.
pub backup: Option<bool>,
/// Explicit backend id; when absent, `"{cluster_id}-{index}-{address}"` is generated.
pub backend_id: Option<String>,
}
impl FileClusterConfig {
pub fn to_cluster_config(
self,
cluster_id: &str,
expect_proxy: &HashSet<SocketAddr>,
) -> Result<ClusterConfig, ConfigError> {
match self.protocol {
FileClusterProtocolConfig::Tcp => {
let mut has_expect_proxy = None;
let mut frontends = Vec::new();
for f in self.frontends {
if expect_proxy.contains(&f.address) {
match has_expect_proxy {
Some(true) => {}
Some(false) => {
return Err(ConfigError::Incompatible {
object: ObjectKind::Cluster,
id: cluster_id.to_owned(),
kind: IncompatibilityKind::ProxyProtocol,
});
}
None => has_expect_proxy = Some(true),
}
} else {
match has_expect_proxy {
Some(false) => {}
Some(true) => {
return Err(ConfigError::Incompatible {
object: ObjectKind::Cluster,
id: cluster_id.to_owned(),
kind: IncompatibilityKind::ProxyProtocol,
});
}
None => has_expect_proxy = Some(false),
}
}
let tcp_frontend = f.to_tcp_front()?;
frontends.push(tcp_frontend);
}
let send_proxy = self.send_proxy.unwrap_or(false);
let expect_proxy = has_expect_proxy.unwrap_or(false);
let proxy_protocol = match (send_proxy, expect_proxy) {
(true, true) => Some(ProxyProtocolConfig::RelayHeader),
(true, false) => Some(ProxyProtocolConfig::SendHeader),
(false, true) => Some(ProxyProtocolConfig::ExpectHeader),
_ => None,
};
let answers = match self.answers.as_ref() {
Some(map) => load_answers(map)?,
None => BTreeMap::new(),
};
Ok(ClusterConfig::Tcp(TcpClusterConfig {
cluster_id: cluster_id.to_string(),
frontends,
backends: self.backends,
proxy_protocol,
load_balancing: self.load_balancing,
load_metric: self.load_metric,
answers,
https_redirect_port: self.https_redirect_port,
authorized_hashes: self.authorized_hashes.unwrap_or_default(),
www_authenticate: self.www_authenticate,
max_connections_per_ip: self.max_connections_per_ip,
retry_after: self.retry_after,
health_check: self.health_check.as_ref().map(|hc| hc.to_proto()),
}))
}
FileClusterProtocolConfig::Http => {
let mut frontends = Vec::new();
for frontend in self.frontends {
let http_frontend = frontend.to_http_front(cluster_id)?;
frontends.push(http_frontend);
}
let answer_503 = self.answer_503.as_ref().and_then(|path| {
Config::load_file(path)
.map_err(|e| {
error!("cannot load 503 error page at path '{}': {:?}", path, e);
e
})
.ok()
});
let answers = match self.answers.as_ref() {
Some(map) => load_answers(map)?,
None => BTreeMap::new(),
};
Ok(ClusterConfig::Http(HttpClusterConfig {
cluster_id: cluster_id.to_string(),
frontends,
backends: self.backends,
sticky_session: self.sticky_session.unwrap_or(false),
https_redirect: self.https_redirect.unwrap_or(false),
load_balancing: self.load_balancing,
load_metric: self.load_metric,
answer_503,
http2: self.http2,
answers,
https_redirect_port: self.https_redirect_port,
authorized_hashes: self.authorized_hashes.unwrap_or_default(),
www_authenticate: self.www_authenticate,
max_connections_per_ip: self.max_connections_per_ip,
retry_after: self.retry_after,
health_check: self.health_check.as_ref().map(|hc| hc.to_proto()),
}))
}
}
}
}
/// Resolved HTTP or HTTPS frontend of a cluster.
///
/// `HttpFrontendConfig::generate_requests` treats the frontend as HTTPS when
/// both `key` and `certificate` are set, and as plain HTTP otherwise.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct HttpFrontendConfig {
/// Address of the listener this frontend is attached to.
pub address: SocketAddr,
pub hostname: String,
pub path: PathRule,
pub method: Option<String>,
/// Certificate sent in the `AddCertificate` request; together with `key` it makes the frontend HTTPS.
pub certificate: Option<String>,
/// Private key sent in the `AddCertificate` request; together with `certificate` it makes the frontend HTTPS.
pub key: Option<String>,
/// Sent as an empty chain when absent.
pub certificate_chain: Option<Vec<String>>,
/// TLS versions attached to the certificate in the `AddCertificate` request.
#[serde(default)]
pub tls_versions: Vec<TlsVersion>,
#[serde(default)]
pub position: RulePosition,
/// Sent as an empty map when absent.
pub tags: Option<BTreeMap<String, String>>,
#[serde(default)]
pub redirect: Option<RedirectPolicy>,
#[serde(default)]
pub redirect_scheme: Option<RedirectScheme>,
#[serde(default)]
pub redirect_template: Option<String>,
#[serde(default)]
pub rewrite_host: Option<String>,
#[serde(default)]
pub rewrite_path: Option<String>,
#[serde(default)]
pub rewrite_port: Option<u32>,
#[serde(default)]
pub required_auth: Option<bool>,
/// Header edits forwarded unchanged in the frontend request.
#[serde(default)]
pub headers: Vec<Header>,
#[serde(default)]
pub hsts: Option<HstsConfig>,
}
impl HttpFrontendConfig {
    /// Builds the requests that register this frontend for `cluster_id`.
    ///
    /// When both a key and a certificate are configured the result is an
    /// `AddCertificate` request followed by an `AddHttpsFrontend` request;
    /// otherwise it is a single `AddHttpFrontend` request. The frontend
    /// payload is the same in both cases.
    pub fn generate_requests(&self, cluster_id: &str) -> Vec<Request> {
        let frontend = RequestHttpFrontend {
            cluster_id: Some(cluster_id.to_string()),
            address: self.address.into(),
            hostname: self.hostname.clone(),
            path: self.path.clone(),
            method: self.method.clone(),
            position: self.position.into(),
            tags: self.tags.clone().unwrap_or_default(),
            redirect: self.redirect.map(|policy| policy as i32),
            required_auth: self.required_auth,
            redirect_scheme: self.redirect_scheme.map(|scheme| scheme as i32),
            redirect_template: self.redirect_template.clone(),
            rewrite_host: self.rewrite_host.clone(),
            rewrite_path: self.rewrite_path.clone(),
            rewrite_port: self.rewrite_port,
            headers: self.headers.clone(),
            hsts: self.hsts,
        };

        match (&self.key, &self.certificate) {
            (Some(key), Some(certificate)) => {
                let add_certificate = AddCertificate {
                    address: self.address.into(),
                    certificate: CertificateAndKey {
                        key: key.clone(),
                        certificate: certificate.clone(),
                        certificate_chain: self.certificate_chain.clone().unwrap_or_default(),
                        versions: self
                            .tls_versions
                            .iter()
                            .map(|version| *version as i32)
                            .collect(),
                        names: vec![],
                    },
                    expired_at: None,
                };
                vec![
                    RequestType::AddCertificate(add_certificate).into(),
                    RequestType::AddHttpsFrontend(frontend).into(),
                ]
            }
            _ => vec![RequestType::AddHttpFrontend(frontend).into()],
        }
    }
}
/// Resolved HTTP cluster: its settings, frontends and backends.
///
/// `HttpClusterConfig::generate_requests` turns it into the `AddCluster`,
/// frontend and `AddBackend` requests.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct HttpClusterConfig {
/// Identifier used in every generated request.
pub cluster_id: String,
pub frontends: Vec<HttpFrontendConfig>,
pub backends: Vec<BackendConfig>,
pub sticky_session: bool,
pub https_redirect: bool,
pub load_balancing: LoadBalancingAlgorithms,
pub load_metric: Option<LoadMetric>,
/// Custom 503 answer, forwarded unchanged in the `AddCluster` request.
pub answer_503: Option<String>,
pub http2: Option<bool>,
/// Custom answers, forwarded unchanged in the `AddCluster` request.
#[serde(default)]
pub answers: BTreeMap<String, String>,
#[serde(default)]
pub https_redirect_port: Option<u32>,
#[serde(default)]
pub authorized_hashes: Vec<String>,
#[serde(default)]
pub www_authenticate: Option<String>,
#[serde(default)]
pub max_connections_per_ip: Option<u64>,
#[serde(default)]
pub retry_after: Option<u32>,
#[serde(default)]
pub health_check: Option<HealthCheckConfig>,
}
impl HttpClusterConfig {
    /// Builds the requests that declare this cluster.
    ///
    /// The result contains, in order: one `AddCluster` request, the requests
    /// of every frontend, then one `AddBackend` request per backend. A
    /// backend without an explicit id gets
    /// `"{cluster_id}-{index}-{address}"`, and a backend without a weight
    /// gets 100.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub fn generate_requests(&self) -> Result<Vec<Request>, ConfigError> {
        let cluster = Cluster {
            cluster_id: self.cluster_id.clone(),
            sticky_session: self.sticky_session,
            https_redirect: self.https_redirect,
            proxy_protocol: None,
            load_balancing: self.load_balancing as i32,
            answer_503: self.answer_503.clone(),
            load_metric: self.load_metric.map(|metric| metric as i32),
            http2: self.http2,
            answers: self.answers.clone(),
            https_redirect_port: self.https_redirect_port,
            authorized_hashes: self.authorized_hashes.clone(),
            www_authenticate: self.www_authenticate.clone(),
            max_connections_per_ip: self.max_connections_per_ip,
            retry_after: self.retry_after,
            health_check: self.health_check.clone(),
        };
        let mut requests: Vec<Request> = vec![RequestType::AddCluster(cluster).into()];

        requests.extend(
            self.frontends
                .iter()
                .flat_map(|frontend| frontend.generate_requests(&self.cluster_id)),
        );

        requests.extend(
            self.backends
                .iter()
                .enumerate()
                .map(|(index, backend)| -> Request {
                    let backend_id = match &backend.backend_id {
                        Some(id) => id.clone(),
                        None => format!("{}-{}-{}", self.cluster_id, index, backend.address),
                    };
                    RequestType::AddBackend(AddBackend {
                        cluster_id: self.cluster_id.clone(),
                        backend_id,
                        address: backend.address.into(),
                        load_balancing_parameters: Some(LoadBalancingParams {
                            weight: i32::from(backend.weight.unwrap_or(100)),
                        }),
                        sticky_id: backend.sticky_id.clone(),
                        backup: backend.backup,
                    })
                    .into()
                }),
        );

        Ok(requests)
    }
}
/// Resolved TCP frontend of a cluster.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TcpFrontendConfig {
/// Address of the listener this frontend is attached to.
pub address: SocketAddr,
/// Sent as an empty map in the `AddTcpFrontend` request when absent.
pub tags: Option<BTreeMap<String, String>>,
}
/// Resolved TCP cluster: its settings, frontends and backends.
///
/// `TcpClusterConfig::generate_requests` turns it into the `AddCluster`,
/// `AddTcpFrontend` and `AddBackend` requests.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TcpClusterConfig {
/// Identifier used in every generated request.
pub cluster_id: String,
pub frontends: Vec<TcpFrontendConfig>,
pub backends: Vec<BackendConfig>,
/// PROXY protocol behaviour, forwarded in the `AddCluster` request.
#[serde(default)]
pub proxy_protocol: Option<ProxyProtocolConfig>,
pub load_balancing: LoadBalancingAlgorithms,
pub load_metric: Option<LoadMetric>,
#[serde(default)]
pub answers: BTreeMap<String, String>,
#[serde(default)]
pub https_redirect_port: Option<u32>,
#[serde(default)]
pub authorized_hashes: Vec<String>,
#[serde(default)]
pub www_authenticate: Option<String>,
#[serde(default)]
pub max_connections_per_ip: Option<u64>,
#[serde(default)]
pub retry_after: Option<u32>,
#[serde(default)]
pub health_check: Option<HealthCheckConfig>,
}
impl TcpClusterConfig {
pub fn generate_requests(&self) -> Result<Vec<Request>, ConfigError> {
let mut v = vec![
RequestType::AddCluster(Cluster {
cluster_id: self.cluster_id.clone(),
sticky_session: false,
https_redirect: false,
proxy_protocol: self.proxy_protocol.map(|s| s as i32),
load_balancing: self.load_balancing as i32,
load_metric: self.load_metric.map(|s| s as i32),
answer_503: None,
http2: None,
answers: self.answers.clone(),
https_redirect_port: self.https_redirect_port,
authorized_hashes: self.authorized_hashes.clone(),
www_authenticate: self.www_authenticate.clone(),
max_connections_per_ip: self.max_connections_per_ip,
retry_after: self.retry_after,
health_check: self.health_check.clone(),
})
.into(),
];
for frontend in &self.frontends {
v.push(
RequestType::AddTcpFrontend(RequestTcpFrontend {
cluster_id: self.cluster_id.clone(),
address: frontend.address.into(),
tags: frontend.tags.clone().unwrap_or(BTreeMap::new()),
})
.into(),
);
}
for (backend_count, backend) in self.backends.iter().enumerate() {
let load_balancing_parameters = Some(LoadBalancingParams {
weight: backend.weight.unwrap_or(100) as i32,
});
v.push(
RequestType::AddBackend(AddBackend {
cluster_id: self.cluster_id.clone(),
backend_id: backend.backend_id.clone().unwrap_or_else(|| {
format!("{}-{}-{}", self.cluster_id, backend_count, backend.address)
}),
address: backend.address.into(),
load_balancing_parameters,
sticky_id: backend.sticky_id.clone(),
backup: backend.backup,
})
.into(),
);
}
Ok(v)
}
}
/// Resolved cluster configuration, either HTTP or TCP.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ClusterConfig {
/// Cluster reached through HTTP/HTTPS frontends.
Http(HttpClusterConfig),
/// Cluster reached through TCP frontends.
Tcp(TcpClusterConfig),
}
impl ClusterConfig {
pub fn generate_requests(&self) -> Result<Vec<Request>, ConfigError> {
match *self {
ClusterConfig::Http(ref http) => http.generate_requests(),
ClusterConfig::Tcp(ref tcp) => tcp.generate_requests(),
}
}
}
/// Raw content of the TOML configuration file.
///
/// Almost every setting is optional here; `ConfigBuilder` applies the
/// defaults and produces the resolved `Config`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Default, Deserialize)]
pub struct FileConfig {
/// When absent, `ConfigBuilder::into_config` uses `sozu.sock` in the current working directory.
pub command_socket: Option<String>,
pub command_buffer_size: Option<u64>,
pub max_command_buffer_size: Option<u64>,
pub max_connections: Option<usize>,
/// The resolved value is capped at the resolved `max_buffers`.
pub min_buffers: Option<u64>,
pub max_buffers: Option<u64>,
pub buffer_size: Option<u64>,
/// Clamped by `ConfigBuilder::new` between `ServerConfig::MIN_SLAB_ENTRIES_PER_CONNECTION`
/// and `ServerConfig::MAX_SLAB_ENTRIES_PER_CONNECTION`.
#[serde(default)]
pub slab_entries_per_connection: Option<u64>,
#[serde(default)]
pub basic_auth_max_credential_bytes: Option<u64>,
#[serde(default)]
pub max_connections_per_ip: Option<u64>,
#[serde(default)]
pub retry_after: Option<u32>,
#[serde(default)]
pub splice_pipe_capacity_bytes: Option<u64>,
#[serde(default)]
pub command_allowed_uids: Option<Vec<u32>>,
/// Must be set when `automatic_state_save` is `true`, otherwise building the config fails.
pub saved_state: Option<String>,
#[serde(default)]
pub automatic_state_save: Option<bool>,
/// Defaults to `"info"`.
pub log_level: Option<String>,
/// Defaults to `"stdout"`.
pub log_target: Option<String>,
#[serde(default)]
pub log_colored: bool,
#[serde(default)]
pub audit_logs_target: Option<String>,
#[serde(default)]
pub audit_logs_json_target: Option<String>,
#[serde(default)]
pub access_logs_target: Option<String>,
#[serde(default)]
pub access_logs_format: Option<AccessLogFormat>,
#[serde(default)]
pub access_logs_colored: Option<bool>,
pub worker_count: Option<u16>,
pub worker_automatic_restart: Option<bool>,
pub metrics: Option<MetricsConfig>,
pub disable_cluster_metrics: Option<bool>,
/// Explicitly declared listeners; two listeners may not share an address.
pub listeners: Option<Vec<ListenerBuilder>>,
/// Clusters keyed by cluster id.
pub clusters: Option<HashMap<String, FileClusterConfig>>,
/// Defaults to `false`.
pub handle_process_affinity: Option<bool>,
/// Defaults to 1000. NOTE(review): unit is not visible here — presumably milliseconds, confirm.
pub ctl_command_timeout: Option<u64>,
pub pid_file_path: Option<String>,
/// Defaults to `true`.
pub activate_listeners: Option<bool>,
#[serde(default)]
pub front_timeout: Option<u32>,
#[serde(default)]
pub back_timeout: Option<u32>,
#[serde(default)]
pub connect_timeout: Option<u32>,
#[serde(default)]
pub zombie_check_interval: Option<u32>,
#[serde(default)]
pub accept_queue_timeout: Option<u32>,
#[serde(default)]
pub evict_on_queue_full: Option<bool>,
#[serde(default)]
pub request_timeout: Option<u32>,
#[serde(default)]
pub worker_timeout: Option<u32>,
}
impl FileConfig {
    /// Reads and deserializes the TOML configuration file at `path`.
    ///
    /// # Errors
    ///
    /// * Any error returned by `Config::load_file`.
    /// * [`ConfigError::DeserializeToml`] when the content is not valid for
    ///   this structure; the error is also displayed through
    ///   `display_toml_error`.
    /// * [`ConfigError::ListenerAddressAlreadyInUse`] when two declared
    ///   listeners share the same address.
    pub fn load_from_path(path: &str) -> Result<FileConfig, ConfigError> {
        let data = Config::load_file(path)?;

        let config = toml::from_str::<FileConfig>(&data).map_err(|e| {
            display_toml_error(&data, &e);
            ConfigError::DeserializeToml(e.to_string())
        })?;

        // `insert` returns false when the address was already present.
        let mut seen_addresses: HashSet<SocketAddr> = HashSet::new();
        for listener in config.listeners.iter().flatten() {
            if !seen_addresses.insert(listener.address) {
                return Err(ConfigError::ListenerAddressAlreadyInUse(listener.address));
            }
        }

        Ok(config)
    }
}
/// Builds a resolved `Config` out of a `FileConfig`.
pub struct ConfigBuilder {
// Configuration as read from the file; source of listeners and clusters.
file: FileConfig,
// Protocol of every listener address registered so far, declared or implicit.
known_addresses: HashMap<SocketAddr, ListenerProtocol>,
// Addresses of the declared listeners that set `expect_proxy = true`.
expect_proxy_addresses: HashSet<SocketAddr>,
// Configuration being assembled; returned (cloned) by `into_config`.
built: Config,
}
impl ConfigBuilder {
    /// Starts a builder from a file configuration, applying the default of
    /// every scalar setting that the file leaves out.
    ///
    /// `config_path` is stored verbatim in the resulting `Config`. Listeners
    /// and clusters are not processed here; see [`ConfigBuilder::into_config`].
    pub fn new<S>(file_config: FileConfig, config_path: S) -> Self
    where
        S: ToString,
    {
        let built = Config {
            accept_queue_timeout: file_config
                .accept_queue_timeout
                .unwrap_or(DEFAULT_ACCEPT_QUEUE_TIMEOUT),
            evict_on_queue_full: file_config
                .evict_on_queue_full
                .unwrap_or(DEFAULT_EVICT_ON_QUEUE_FULL),
            activate_listeners: file_config.activate_listeners.unwrap_or(true),
            automatic_state_save: file_config
                .automatic_state_save
                .unwrap_or(DEFAULT_AUTOMATIC_STATE_SAVE),
            back_timeout: file_config.back_timeout.unwrap_or(DEFAULT_BACK_TIMEOUT),
            buffer_size: file_config.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE),
            command_buffer_size: file_config
                .command_buffer_size
                .unwrap_or(DEFAULT_COMMAND_BUFFER_SIZE),
            config_path: config_path.to_string(),
            connect_timeout: file_config
                .connect_timeout
                .unwrap_or(DEFAULT_CONNECT_TIMEOUT),
            ctl_command_timeout: file_config.ctl_command_timeout.unwrap_or(1_000),
            front_timeout: file_config.front_timeout.unwrap_or(DEFAULT_FRONT_TIMEOUT),
            handle_process_affinity: file_config.handle_process_affinity.unwrap_or(false),
            access_logs_target: file_config.access_logs_target.clone(),
            audit_logs_target: file_config.audit_logs_target.clone(),
            audit_logs_json_target: file_config.audit_logs_json_target.clone(),
            access_logs_format: file_config.access_logs_format.clone(),
            access_logs_colored: file_config.access_logs_colored,
            log_level: file_config
                .log_level
                .clone()
                .unwrap_or_else(|| String::from("info")),
            log_target: file_config
                .log_target
                .clone()
                .unwrap_or_else(|| String::from("stdout")),
            log_colored: file_config.log_colored,
            max_buffers: file_config.max_buffers.unwrap_or(DEFAULT_MAX_BUFFERS),
            max_command_buffer_size: file_config
                .max_command_buffer_size
                .unwrap_or(DEFAULT_MAX_COMMAND_BUFFER_SIZE),
            max_connections: file_config
                .max_connections
                .unwrap_or(DEFAULT_MAX_CONNECTIONS),
            metrics: file_config.metrics.clone(),
            disable_cluster_metrics: file_config
                .disable_cluster_metrics
                .unwrap_or(DEFAULT_DISABLE_CLUSTER_METRICS),
            // min_buffers can never exceed max_buffers.
            min_buffers: std::cmp::min(
                file_config.min_buffers.unwrap_or(DEFAULT_MIN_BUFFERS),
                file_config.max_buffers.unwrap_or(DEFAULT_MAX_BUFFERS),
            ),
            pid_file_path: file_config.pid_file_path.clone(),
            request_timeout: file_config
                .request_timeout
                .unwrap_or(DEFAULT_REQUEST_TIMEOUT),
            saved_state: file_config.saved_state.clone(),
            worker_automatic_restart: file_config
                .worker_automatic_restart
                .unwrap_or(DEFAULT_WORKER_AUTOMATIC_RESTART),
            worker_count: file_config.worker_count.unwrap_or(DEFAULT_WORKER_COUNT),
            zombie_check_interval: file_config
                .zombie_check_interval
                .unwrap_or(DEFAULT_ZOMBIE_CHECK_INTERVAL),
            worker_timeout: file_config.worker_timeout.unwrap_or(DEFAULT_WORKER_TIMEOUT),
            slab_entries_per_connection: file_config.slab_entries_per_connection.map(|n| {
                n.clamp(
                    ServerConfig::MIN_SLAB_ENTRIES_PER_CONNECTION,
                    ServerConfig::MAX_SLAB_ENTRIES_PER_CONNECTION,
                )
            }),
            command_allowed_uids: file_config.command_allowed_uids.clone(),
            basic_auth_max_credential_bytes: file_config.basic_auth_max_credential_bytes,
            max_connections_per_ip: file_config
                .max_connections_per_ip
                .unwrap_or(DEFAULT_MAX_CONNECTIONS_PER_IP),
            retry_after: file_config.retry_after.unwrap_or(DEFAULT_RETRY_AFTER),
            splice_pipe_capacity_bytes: file_config.splice_pipe_capacity_bytes,
            ..Default::default()
        };

        Self {
            file: file_config,
            known_addresses: HashMap::new(),
            expect_proxy_addresses: HashSet::new(),
            built,
        }
    }

    /// Resolves `listener` as an HTTPS listener and appends it to the config.
    fn push_tls_listener(&mut self, mut listener: ListenerBuilder) -> Result<(), ConfigError> {
        let listener = listener.to_tls(Some(&self.built))?;
        self.built.https_listeners.push(listener);
        Ok(())
    }

    /// Resolves `listener` as an HTTP listener and appends it to the config.
    fn push_http_listener(&mut self, mut listener: ListenerBuilder) -> Result<(), ConfigError> {
        let listener = listener.to_http(Some(&self.built))?;
        self.built.http_listeners.push(listener);
        Ok(())
    }

    /// Resolves `listener` as a TCP listener and appends it to the config.
    fn push_tcp_listener(&mut self, mut listener: ListenerBuilder) -> Result<(), ConfigError> {
        let listener = listener.to_tcp(Some(&self.built))?;
        self.built.tcp_listeners.push(listener);
        Ok(())
    }

    /// Registers the explicitly declared listeners.
    ///
    /// # Errors
    ///
    /// * `ListenerAddressAlreadyInUse` when an address is declared twice.
    /// * `Missing(Protocol)` when a listener has no protocol.
    /// * `Incompatible` (`PublicAddress`) when a listener sets both
    ///   `public_address` and `expect_proxy = true`.
    /// * Any error from resolving the listener itself.
    fn populate_listeners(&mut self, listeners: Vec<ListenerBuilder>) -> Result<(), ConfigError> {
        // The vector is owned: consume it instead of cloning every listener.
        for listener in listeners {
            if self.known_addresses.contains_key(&listener.address) {
                return Err(ConfigError::ListenerAddressAlreadyInUse(listener.address));
            }

            let protocol = listener
                .protocol
                .ok_or(ConfigError::Missing(MissingKind::Protocol))?;

            self.known_addresses.insert(listener.address, protocol);
            if listener.expect_proxy == Some(true) {
                self.expect_proxy_addresses.insert(listener.address);
            }

            if listener.public_address.is_some() && listener.expect_proxy == Some(true) {
                return Err(ConfigError::Incompatible {
                    object: ObjectKind::Listener,
                    id: listener.address.to_string(),
                    kind: IncompatibilityKind::PublicAddress,
                });
            }

            match protocol {
                ListenerProtocol::Https => self.push_tls_listener(listener)?,
                ListenerProtocol::Http => self.push_http_listener(listener)?,
                ListenerProtocol::Tcp => self.push_tcp_listener(listener)?,
            }
        }
        Ok(())
    }

    /// Resolves every cluster and checks its frontends against the listeners.
    ///
    /// A frontend whose address matches no known listener gets an implicit
    /// listener created for it (HTTPS when the frontend has a certificate,
    /// HTTP otherwise, TCP for TCP clusters). An HTTP frontend without a
    /// certificate on an HTTPS listener inherits the certificate, chain and
    /// key of that listener when the listener has a certificate.
    ///
    /// # Errors
    ///
    /// `WrongFrontendProtocol` when a frontend is attached to a listener of
    /// an incompatible protocol, plus any error from the cluster or listener
    /// conversions.
    fn populate_clusters(
        &mut self,
        file_cluster_configs: HashMap<String, FileClusterConfig>,
    ) -> Result<(), ConfigError> {
        for (id, file_cluster_config) in file_cluster_configs {
            let mut cluster_config =
                file_cluster_config.to_cluster_config(id.as_str(), &self.expect_proxy_addresses)?;

            match cluster_config {
                ClusterConfig::Http(ref mut http) => {
                    for frontend in http.frontends.iter_mut() {
                        match self.known_addresses.get(&frontend.address) {
                            Some(ListenerProtocol::Tcp) => {
                                return Err(ConfigError::WrongFrontendProtocol(
                                    ListenerProtocol::Tcp,
                                ));
                            }
                            Some(ListenerProtocol::Http) => {
                                // A certificate makes no sense on a plain HTTP listener.
                                if frontend.certificate.is_some() {
                                    return Err(ConfigError::WrongFrontendProtocol(
                                        ListenerProtocol::Http,
                                    ));
                                }
                            }
                            Some(ListenerProtocol::Https) => {
                                if frontend.certificate.is_none() {
                                    // Fall back to the certificate of the listener, if it has one.
                                    if let Some(https_listener) =
                                        self.built.https_listeners.iter().find(|listener| {
                                            listener.address == frontend.address.into()
                                                && listener.certificate.is_some()
                                        })
                                    {
                                        frontend
                                            .certificate
                                            .clone_from(&https_listener.certificate);
                                        frontend.certificate_chain =
                                            Some(https_listener.certificate_chain.clone());
                                        frontend.key.clone_from(&https_listener.key);
                                    }
                                    if frontend.certificate.is_none() {
                                        debug!("known addresses: {:?}", self.known_addresses);
                                        debug!("frontend: {:?}", frontend);
                                        return Err(ConfigError::WrongFrontendProtocol(
                                            ListenerProtocol::Https,
                                        ));
                                    }
                                }
                            }
                            None => {
                                // No listener declared for this address: create one.
                                let file_listener_protocol = if frontend.certificate.is_some() {
                                    self.push_tls_listener(ListenerBuilder::new(
                                        frontend.address.into(),
                                        ListenerProtocol::Https,
                                    ))?;
                                    ListenerProtocol::Https
                                } else {
                                    self.push_http_listener(ListenerBuilder::new(
                                        frontend.address.into(),
                                        ListenerProtocol::Http,
                                    ))?;
                                    ListenerProtocol::Http
                                };
                                self.known_addresses
                                    .insert(frontend.address, file_listener_protocol);
                            }
                        }
                    }
                }
                ClusterConfig::Tcp(ref tcp) => {
                    for frontend in &tcp.frontends {
                        match self.known_addresses.get(&frontend.address) {
                            Some(ListenerProtocol::Http) | Some(ListenerProtocol::Https) => {
                                return Err(ConfigError::WrongFrontendProtocol(
                                    ListenerProtocol::Http,
                                ));
                            }
                            Some(ListenerProtocol::Tcp) => {}
                            None => {
                                // No listener declared for this address: create one.
                                self.push_tcp_listener(ListenerBuilder::new(
                                    frontend.address.into(),
                                    ListenerProtocol::Tcp,
                                ))?;
                                self.known_addresses
                                    .insert(frontend.address, ListenerProtocol::Tcp);
                            }
                        }
                    }
                }
            }

            self.built.clusters.insert(id, cluster_config);
        }
        Ok(())
    }

    /// Processes listeners and clusters, runs the cross-setting checks and
    /// returns the resolved configuration.
    ///
    /// # Errors
    ///
    /// * Any error from registering listeners or clusters.
    /// * `BufferSizeTooSmallForH2` when an HTTPS listener advertises `h2`
    ///   and `buffer_size` is below `H2_MIN_BUFFER_SIZE`.
    /// * `Env` / `InvalidPath` when no `command_socket` is configured and the
    ///   default path cannot be derived from the current working directory.
    /// * `Missing(SavedState)` when `automatic_state_save` is enabled without
    ///   a `saved_state` path.
    pub fn into_config(&mut self) -> Result<Config, ConfigError> {
        if let Some(listeners) = &self.file.listeners {
            self.populate_listeners(listeners.clone())?;
        }

        if let Some(file_cluster_configs) = &self.file.clusters {
            self.populate_clusters(file_cluster_configs.clone())?;
        }

        let h2_listeners = self
            .built
            .https_listeners
            .iter()
            .filter(|l| l.alpn_protocols.iter().any(|p| p == "h2"))
            .count();
        if h2_listeners > 0 && self.built.buffer_size < H2_MIN_BUFFER_SIZE {
            return Err(ConfigError::BufferSizeTooSmallForH2 {
                buffer_size: self.built.buffer_size,
                minimum: H2_MIN_BUFFER_SIZE,
                listeners: h2_listeners,
            });
        }

        if let Some(cap) = self.built.basic_auth_max_credential_bytes {
            let third = self.built.buffer_size / 3;
            if cap >= third {
                warn!(
                    "basic_auth_max_credential_bytes = {} is >= buffer_size / 3 ({}); \
                    a hostile peer can pin ~33% of the per-frontend buffer per failed auth \
                    attempt. Consider lowering basic_auth_max_credential_bytes (typical \
                    credentials are <100 bytes) or raising buffer_size.",
                    cap, third
                );
            }
        }

        if self.built.evict_on_queue_full && self.built.max_connections < 100 {
            // `max(1)`: with max_connections = 0 the division would panic
            // instead of just emitting the warning.
            let pct = 100usize.div_ceil(self.built.max_connections.max(1));
            warn!(
                "evict_on_queue_full enabled with max_connections = {}; the eviction batch \
                clamps to 1, equivalent to ~{}% of capacity per cap event (the knob is \
                documented as 1%). Confirm this is intended.",
                self.built.max_connections, pct
            );
        }

        // Only look at the working directory when no socket path is
        // configured: an explicit `command_socket` must not fail because the
        // current directory is unreadable.
        let command_socket_path = match self.file.command_socket.clone() {
            Some(path) => path,
            None => {
                let mut path =
                    env::current_dir().map_err(|e| ConfigError::Env(e.to_string()))?;
                path.push("sozu.sock");
                path.to_str()
                    .ok_or_else(|| ConfigError::InvalidPath(path.clone()))?
                    .to_owned()
            }
        };

        if let (None, Some(true)) = (&self.file.saved_state, &self.file.automatic_state_save) {
            return Err(ConfigError::Missing(MissingKind::SavedState));
        }

        Ok(Config {
            command_socket: command_socket_path,
            ..self.built.clone()
        })
    }
}
/// Resolved configuration, with defaults applied.
///
/// Produced by `ConfigBuilder::into_config`.
#[derive(Clone, PartialEq, Eq, Serialize, Default, Deserialize)]
pub struct Config {
/// Path of the configuration file, as given to `ConfigBuilder::new`.
pub config_path: String,
/// Path of the command socket.
pub command_socket: String,
pub command_buffer_size: u64,
pub max_command_buffer_size: u64,
pub max_connections: usize,
/// Never greater than `max_buffers` when built by `ConfigBuilder::new`.
pub min_buffers: u64,
pub max_buffers: u64,
pub buffer_size: u64,
pub saved_state: Option<String>,
#[serde(default)]
pub automatic_state_save: bool,
pub log_level: String,
pub log_target: String,
pub log_colored: bool,
#[serde(default)]
pub audit_logs_target: Option<String>,
#[serde(default)]
pub audit_logs_json_target: Option<String>,
#[serde(default)]
pub access_logs_target: Option<String>,
pub access_logs_format: Option<AccessLogFormat>,
pub access_logs_colored: Option<bool>,
pub worker_count: u16,
pub worker_automatic_restart: bool,
pub metrics: Option<MetricsConfig>,
/// When `true`, `generate_config_messages` appends a `ConfigureMetrics(Disabled)` request.
#[serde(default = "default_disable_cluster_metrics")]
pub disable_cluster_metrics: bool,
/// HTTP listeners, declared in the file or created implicitly for cluster frontends.
pub http_listeners: Vec<HttpListenerConfig>,
/// HTTPS listeners, declared in the file or created implicitly for cluster frontends.
pub https_listeners: Vec<HttpsListenerConfig>,
/// TCP listeners, declared in the file or created implicitly for cluster frontends.
pub tcp_listeners: Vec<TcpListenerConfig>,
/// Resolved clusters keyed by cluster id.
pub clusters: HashMap<String, ClusterConfig>,
pub handle_process_affinity: bool,
pub ctl_command_timeout: u64,
pub pid_file_path: Option<String>,
/// When `true`, `generate_config_messages` appends an `ActivateListener` request per listener.
pub activate_listeners: bool,
#[serde(default = "default_front_timeout")]
pub front_timeout: u32,
#[serde(default = "default_back_timeout")]
pub back_timeout: u32,
#[serde(default = "default_connect_timeout")]
pub connect_timeout: u32,
#[serde(default = "default_zombie_check_interval")]
pub zombie_check_interval: u32,
#[serde(default = "default_accept_queue_timeout")]
pub accept_queue_timeout: u32,
#[serde(default = "default_evict_on_queue_full")]
pub evict_on_queue_full: bool,
#[serde(default = "default_request_timeout")]
pub request_timeout: u32,
#[serde(default = "default_worker_timeout")]
pub worker_timeout: u32,
#[serde(default)]
pub slab_entries_per_connection: Option<u64>,
#[serde(default)]
pub command_allowed_uids: Option<Vec<u32>>,
#[serde(default)]
pub basic_auth_max_credential_bytes: Option<u64>,
#[serde(default = "default_max_connections_per_ip")]
pub max_connections_per_ip: u64,
#[serde(default = "default_retry_after")]
pub retry_after: u32,
#[serde(default)]
pub splice_pipe_capacity_bytes: Option<u64>,
}
/// Serde default for `Config::front_timeout`.
fn default_front_timeout() -> u32 {
DEFAULT_FRONT_TIMEOUT
}
/// Serde default for `Config::back_timeout`.
fn default_back_timeout() -> u32 {
DEFAULT_BACK_TIMEOUT
}
/// Serde default for `Config::connect_timeout`.
fn default_connect_timeout() -> u32 {
DEFAULT_CONNECT_TIMEOUT
}
/// Serde default for `Config::request_timeout`.
fn default_request_timeout() -> u32 {
DEFAULT_REQUEST_TIMEOUT
}
/// Serde default for `Config::zombie_check_interval`.
fn default_zombie_check_interval() -> u32 {
DEFAULT_ZOMBIE_CHECK_INTERVAL
}
/// Serde default for `Config::accept_queue_timeout`.
fn default_accept_queue_timeout() -> u32 {
DEFAULT_ACCEPT_QUEUE_TIMEOUT
}
/// Serde default for `Config::evict_on_queue_full`.
fn default_evict_on_queue_full() -> bool {
DEFAULT_EVICT_ON_QUEUE_FULL
}
/// Serde default for `Config::disable_cluster_metrics`.
fn default_disable_cluster_metrics() -> bool {
DEFAULT_DISABLE_CLUSTER_METRICS
}
/// Serde default for `Config::worker_timeout`.
fn default_worker_timeout() -> u32 {
DEFAULT_WORKER_TIMEOUT
}
/// Serde default for `Config::max_connections_per_ip`.
fn default_max_connections_per_ip() -> u64 {
DEFAULT_MAX_CONNECTIONS_PER_IP
}
/// Serde default for `Config::retry_after`.
fn default_retry_after() -> u32 {
DEFAULT_RETRY_AFTER
}
impl Config {
/// Loads the configuration file at `path` and resolves it into a `Config`.
///
/// The `saved_state` entry of the result is replaced by the value
/// returned by `saved_state_path`.
///
/// # Errors
///
/// Propagates the errors of `FileConfig::load_from_path`,
/// `ConfigBuilder::into_config` and `saved_state_path`.
pub fn load_from_path(path: &str) -> Result<Config, ConfigError> {
    let mut builder = ConfigBuilder::new(FileConfig::load_from_path(path)?, path);
    let mut config = builder.into_config()?;
    let resolved_saved_state = config.saved_state_path()?;
    config.saved_state = resolved_saved_state;
    Ok(config)
}
pub fn generate_config_messages(&self) -> Result<Vec<WorkerRequest>, ConfigError> {
let mut v = Vec::new();
let mut count = 0u8;
for listener in &self.http_listeners {
v.push(WorkerRequest {
id: format!("CONFIG-{count}"),
content: RequestType::AddHttpListener(listener.clone()).into(),
});
count += 1;
}
for listener in &self.https_listeners {
v.push(WorkerRequest {
id: format!("CONFIG-{count}"),
content: RequestType::AddHttpsListener(listener.clone()).into(),
});
count += 1;
}
for listener in &self.tcp_listeners {
v.push(WorkerRequest {
id: format!("CONFIG-{count}"),
content: RequestType::AddTcpListener(*listener).into(),
});
count += 1;
}
for cluster in self.clusters.values() {
let mut orders = cluster.generate_requests()?;
for content in orders.drain(..) {
v.push(WorkerRequest {
id: format!("CONFIG-{count}"),
content,
});
count += 1;
}
}
if self.activate_listeners {
for listener in &self.http_listeners {
v.push(WorkerRequest {
id: format!("CONFIG-{count}"),
content: RequestType::ActivateListener(ActivateListener {
address: listener.address,
proxy: ListenerType::Http.into(),
from_scm: false,
})
.into(),
});
count += 1;
}
for listener in &self.https_listeners {
v.push(WorkerRequest {
id: format!("CONFIG-{count}"),
content: RequestType::ActivateListener(ActivateListener {
address: listener.address,
proxy: ListenerType::Https.into(),
from_scm: false,
})
.into(),
});
count += 1;
}
for listener in &self.tcp_listeners {
v.push(WorkerRequest {
id: format!("CONFIG-{count}"),
content: RequestType::ActivateListener(ActivateListener {
address: listener.address,
proxy: ListenerType::Tcp.into(),
from_scm: false,
})
.into(),
});
count += 1;
}
}
if self.disable_cluster_metrics {
v.push(WorkerRequest {
id: format!("CONFIG-{count}"),
content: RequestType::ConfigureMetrics(MetricsConfiguration::Disabled.into())
.into(),
});
}
Ok(v)
}
/// Resolves the command socket path against the directory of the
/// configuration file.
///
/// The parent directory of `command_socket` is pushed onto the directory of
/// `config_path` and canonicalized, then the socket file name is appended.
///
/// # Errors
///
/// * `NoFileParent` when `config_path` has no parent directory.
/// * `SocketPathError` when the directory cannot be canonicalized, when
///   `command_socket` has no file name, or when the result is not valid
///   UTF-8.
pub fn command_socket_path(&self) -> Result<String, ConfigError> {
    let config_path_buf = PathBuf::from(&self.config_path);
    let Some(config_parent) = config_path_buf.parent() else {
        return Err(ConfigError::NoFileParent(
            config_path_buf.to_string_lossy().to_string(),
        ));
    };
    let mut config_dir = config_parent.to_path_buf();

    let socket_path = PathBuf::from(&self.command_socket);
    let mut socket_parent_dir = if let Some(path) = socket_path.parent() {
        config_dir.push(path);
        config_dir.canonicalize().map_err(|io_error| {
            ConfigError::SocketPathError(format!(
                "Could not canonicalize path {config_dir:?}: {io_error}"
            ))
        })?
    } else {
        config_dir
    };

    let socket_name = socket_path.file_name().ok_or_else(|| {
        ConfigError::SocketPathError(format!(
            "could not get command socket file name from {socket_path:?}"
        ))
    })?;
    socket_parent_dir.push(socket_name);

    match socket_parent_dir.to_str() {
        Some(command_socket_path) => Ok(command_socket_path.to_string()),
        None => Err(ConfigError::SocketPathError(format!(
            "Invalid socket path {socket_parent_dir:?}"
        ))),
    }
}
/// Resolves the saved-state file path against the configuration file's directory.
///
/// Returns `Ok(None)` when no `saved_state` is configured. Otherwise the configured
/// path is joined onto the directory of the configuration file (created if missing),
/// an empty state file is created there when none exists, and the canonicalized
/// path is returned.
///
/// # Errors
///
/// Returns `ConfigError::SaveStatePath` when the configuration path has no parent
/// directory, when the directory or the state file cannot be created, when the path
/// cannot be canonicalized, or when it is not valid UTF-8.
fn saved_state_path(&self) -> Result<Option<String>, ConfigError> {
    let Some(path) = self.saved_state.as_ref() else {
        return Ok(None);
    };
    debug!("saved_stated path in the config: {}", path);

    let config_path = PathBuf::from(self.config_path.clone());
    debug!("Config path buffer: {:?}", config_path);

    let config_dir = config_path.parent().ok_or_else(|| {
        ConfigError::SaveStatePath(format!(
            "Could not get parent directory of config file {config_path:?}"
        ))
    })?;
    debug!("Config folder: {:?}", config_dir);

    if !config_dir.exists() {
        create_dir_all(config_dir).map_err(|io_error| {
            ConfigError::SaveStatePath(format!(
                "failed to create state parent directory '{config_dir:?}': {io_error}"
            ))
        })?;
    }

    let mut saved_state_path_raw = config_dir.to_path_buf();
    saved_state_path_raw.push(path);
    debug!(
        "Looking for saved state on the path {:?}",
        saved_state_path_raw
    );

    // Probe and create the file at the resolved location, not at the raw configured
    // path: the latter is interpreted relative to the current working directory and
    // may differ from the path that is canonicalized and returned below.
    let state_file_missing = matches!(
        metadata(&saved_state_path_raw),
        Err(ref err) if err.kind() == ErrorKind::NotFound
    );
    if state_file_missing {
        info!("Create an empty state file at {:?}", saved_state_path_raw);
        File::create(&saved_state_path_raw).map_err(|io_error| {
            ConfigError::SaveStatePath(format!(
                "failed to create state file '{saved_state_path_raw:?}': {io_error}"
            ))
        })?;
    }

    // Keep the canonicalized result instead of discarding it, so that the returned
    // path does not depend on the current working directory.
    let canonical_path = saved_state_path_raw.canonicalize().map_err(|io_error| {
        ConfigError::SaveStatePath(format!(
            "could not get saved state path from config file input {path:?}: {io_error}"
        ))
    })?;

    let stringified_path = canonical_path
        .to_str()
        .ok_or_else(|| ConfigError::SaveStatePath(format!("Invalid path {canonical_path:?}")))?
        .to_string();
    Ok(Some(stringified_path))
}
/// Reads the whole file at `path` into a `String`.
///
/// # Errors
///
/// Returns `ConfigError::FileRead`, carrying the path and the underlying I/O error,
/// when the file cannot be read as a string.
pub fn load_file(path: &str) -> Result<String, ConfigError> {
    match std::fs::read_to_string(path) {
        Ok(contents) => Ok(contents),
        Err(io_error) => Err(ConfigError::FileRead {
            path_to_read: path.to_owned(),
            io_error,
        }),
    }
}
/// Reads the whole file at `path` into a byte vector.
///
/// # Errors
///
/// Returns `ConfigError::FileRead`, carrying the path and the underlying I/O error,
/// when the file cannot be read.
pub fn load_file_bytes(path: &str) -> Result<Vec<u8>, ConfigError> {
    match std::fs::read(path) {
        Ok(bytes) => Ok(bytes),
        Err(io_error) => Err(ConfigError::FileRead {
            path_to_read: path.to_owned(),
            io_error,
        }),
    }
}
}
/// Hand-written `Debug` output for `Config`, listing an explicit selection of fields.
impl fmt::Debug for Config {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("Config");

        // Paths, command channel and buffers.
        repr.field("config_path", &self.config_path);
        repr.field("command_socket", &self.command_socket);
        repr.field("command_buffer_size", &self.command_buffer_size);
        repr.field("max_command_buffer_size", &self.max_command_buffer_size);
        repr.field("max_connections", &self.max_connections);
        repr.field("min_buffers", &self.min_buffers);
        repr.field("max_buffers", &self.max_buffers);
        repr.field("buffer_size", &self.buffer_size);
        repr.field("saved_state", &self.saved_state);
        repr.field("automatic_state_save", &self.automatic_state_save);

        // Logging.
        repr.field("log_level", &self.log_level);
        repr.field("log_target", &self.log_target);
        repr.field("access_logs_target", &self.access_logs_target);
        repr.field("audit_logs_target", &self.audit_logs_target);
        repr.field("audit_logs_json_target", &self.audit_logs_json_target);
        repr.field("access_logs_format", &self.access_logs_format);

        // Workers, metrics and process management.
        repr.field("worker_count", &self.worker_count);
        repr.field("worker_automatic_restart", &self.worker_automatic_restart);
        repr.field("metrics", &self.metrics);
        repr.field("disable_cluster_metrics", &self.disable_cluster_metrics);
        repr.field("handle_process_affinity", &self.handle_process_affinity);
        repr.field("ctl_command_timeout", &self.ctl_command_timeout);
        repr.field("pid_file_path", &self.pid_file_path);
        repr.field("activate_listeners", &self.activate_listeners);

        // Timeouts and queueing.
        repr.field("front_timeout", &self.front_timeout);
        repr.field("back_timeout", &self.back_timeout);
        repr.field("connect_timeout", &self.connect_timeout);
        repr.field("zombie_check_interval", &self.zombie_check_interval);
        repr.field("accept_queue_timeout", &self.accept_queue_timeout);
        repr.field("evict_on_queue_full", &self.evict_on_queue_full);
        repr.field("request_timeout", &self.request_timeout);
        repr.field("worker_timeout", &self.worker_timeout);

        repr.finish()
    }
}
/// Prints a TOML deserialization error for `file` on standard output.
///
/// The first line carries the error itself; when the error exposes a span, a second
/// line reports the start and end positions of that span.
fn display_toml_error(file: &str, error: &toml::de::Error) {
    println!("error parsing the configuration file '{file}': {error}");
    if let Some(Range { start, end }) = error.span() {
        // `println!` rather than `print!`: the message must be newline-terminated so it
        // is not glued to subsequent output nor left pending in a line-buffered stdout.
        println!("error parsing the configuration file '{file}' at position: {start}, {end}");
    }
}
impl ServerConfig {
    /// Slab entries per connection used when none (or `0`) is configured.
    pub const DEFAULT_SLAB_ENTRIES_PER_CONNECTION: u64 = 4;
    /// Lower bound applied to a configured, non-zero value.
    pub const MIN_SLAB_ENTRIES_PER_CONNECTION: u64 = 2;
    /// Upper bound applied to a configured, non-zero value.
    pub const MAX_SLAB_ENTRIES_PER_CONNECTION: u64 = 32;

    /// Returns the number of slab entries per connection actually in effect.
    ///
    /// An absent or zero `slab_entries_per_connection` yields the default; any other
    /// value is clamped to the `[MIN, MAX]` range.
    pub fn effective_slab_entries_per_connection(&self) -> u64 {
        let configured = self.slab_entries_per_connection.unwrap_or(0);
        if configured == 0 {
            Self::DEFAULT_SLAB_ENTRIES_PER_CONNECTION
        } else {
            configured.clamp(
                Self::MIN_SLAB_ENTRIES_PER_CONNECTION,
                Self::MAX_SLAB_ENTRIES_PER_CONNECTION,
            )
        }
    }

    /// Returns the slab capacity: the effective entries per connection multiplied by
    /// `max_connections`, plus a fixed 10 extra entries.
    pub fn slab_capacity(&self) -> u64 {
        let per_connection = self.effective_slab_entries_per_connection();
        let for_connections = per_connection * self.max_connections;
        10 + for_connections
    }
}
/// Builds the `ServerConfig` from the server-level fields of a `Config`.
impl From<&Config> for ServerConfig {
    fn from(config: &Config) -> Self {
        // Metrics are optional: convert them only when configured.
        let metrics = config.metrics.clone().map(|source| {
            let detail = MetricDetail::from(source.detail) as i32;
            ServerMetricsConfig {
                address: source.address.to_string(),
                tagged_metrics: source.tagged_metrics,
                prefix: source.prefix,
                detail: Some(detail),
            }
        });
        let access_log_format = ProtobufAccessLogFormat::from(&config.access_logs_format) as i32;
        let max_connections = config.max_connections as u64;

        Self {
            // Connections and timeouts.
            max_connections,
            max_connections_per_ip: Some(config.max_connections_per_ip),
            front_timeout: config.front_timeout,
            back_timeout: config.back_timeout,
            connect_timeout: config.connect_timeout,
            zombie_check_interval: config.zombie_check_interval,
            accept_queue_timeout: config.accept_queue_timeout,
            evict_on_queue_full: Some(config.evict_on_queue_full),
            retry_after: Some(config.retry_after),
            // Buffers and sizing.
            min_buffers: config.min_buffers,
            max_buffers: config.max_buffers,
            buffer_size: config.buffer_size,
            command_buffer_size: config.command_buffer_size,
            max_command_buffer_size: config.max_command_buffer_size,
            slab_entries_per_connection: config.slab_entries_per_connection,
            splice_pipe_capacity_bytes: config.splice_pipe_capacity_bytes,
            basic_auth_max_credential_bytes: config.basic_auth_max_credential_bytes,
            // Logging and metrics.
            log_level: config.log_level.clone(),
            log_target: config.log_target.clone(),
            log_colored: config.log_colored,
            access_logs_target: config.access_logs_target.clone(),
            audit_logs_target: config.audit_logs_target.clone(),
            audit_logs_json_target: config.audit_logs_json_target.clone(),
            access_log_format,
            metrics,
        }
    }
}
#[cfg(test)]
mod tests {
use toml::to_string;
use super::*;
/// Enabled HSTS without an explicit `max_age` must convert with `DEFAULT_HSTS_MAX_AGE`.
#[test]
fn hsts_to_proto_enabled_substitutes_default_max_age() {
    let file_hsts = FileHstsConfig {
        force_replace_backend: None,
        preload: None,
        include_subdomains: None,
        max_age: None,
        enabled: Some(true),
    };
    let converted = file_hsts.to_proto("test").expect("should validate");
    assert_eq!(converted.enabled, Some(true));
    assert_eq!(converted.max_age, Some(DEFAULT_HSTS_MAX_AGE));
}
/// Explicit `max_age`, `include_subdomains` and `preload` values must survive conversion.
#[test]
fn hsts_to_proto_explicit_max_age_kept() {
    let file_hsts = FileHstsConfig {
        force_replace_backend: None,
        preload: Some(true),
        include_subdomains: Some(true),
        max_age: Some(63_072_000),
        enabled: Some(true),
    };
    let converted = file_hsts.to_proto("test").expect("should validate");
    assert_eq!(converted.max_age, Some(63_072_000));
    assert_eq!(converted.include_subdomains, Some(true));
    assert_eq!(converted.preload, Some(true));
}
/// An explicitly disabled HSTS block must validate and stay disabled after conversion.
#[test]
fn hsts_to_proto_disabled_keeps_zero_intent() {
    let file_hsts = FileHstsConfig {
        force_replace_backend: None,
        preload: None,
        include_subdomains: None,
        max_age: None,
        enabled: Some(false),
    };
    let converted = file_hsts.to_proto("test").expect("should validate");
    assert_eq!(converted.enabled, Some(false));
}
/// `max_age = 0` on an enabled HSTS block must be accepted and kept as zero.
#[test]
fn hsts_to_proto_kill_switch_max_age_zero_allowed() {
    let file_hsts = FileHstsConfig {
        force_replace_backend: None,
        preload: None,
        include_subdomains: None,
        max_age: Some(0),
        enabled: Some(true),
    };
    let converted = file_hsts
        .to_proto("test")
        .expect("kill-switch must validate");
    assert_eq!(converted.max_age, Some(0));
}
/// An HSTS block without `enabled` must be rejected with `HstsEnabledRequired`,
/// carrying the scope passed to `to_proto`.
#[test]
fn hsts_to_proto_missing_enabled_errors() {
    let file_hsts = FileHstsConfig {
        force_replace_backend: None,
        preload: None,
        include_subdomains: None,
        max_age: Some(31_536_000),
        enabled: None,
    };
    let error = file_hsts.to_proto("test").unwrap_err();
    match error {
        ConfigError::HstsEnabledRequired(scope) => assert_eq!(scope, "test"),
        unexpected => panic!("expected HstsEnabledRequired, got {unexpected:?}"),
    }
}
/// Converting an HTTP listener that carries an HSTS block must fail with
/// `HstsOnPlainHttp`, and the reported scope must mention "HTTP listener".
#[test]
fn hsts_rejected_on_http_listener() {
    let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
    let mut builder = ListenerBuilder::new(address, ListenerProtocol::Http);
    let hsts_block = FileHstsConfig {
        force_replace_backend: None,
        preload: None,
        include_subdomains: None,
        max_age: Some(31_536_000),
        enabled: Some(true),
    };
    builder.hsts = Some(hsts_block);

    let error = builder.to_http(None).unwrap_err();
    match error {
        ConfigError::HstsOnPlainHttp(scope) => assert!(
            scope.contains("HTTP listener"),
            "expected scope to mention 'HTTP listener', got {scope:?}"
        ),
        unexpected => panic!("expected HstsOnPlainHttp, got {unexpected:?}"),
    }
}
/// Converting a frontend carrying an HSTS block into an HTTP front must fail with
/// `HstsOnPlainHttp`, and the reported scope must mention both the cluster id and
/// the hostname.
#[test]
fn hsts_rejected_on_http_frontend() {
    let hsts_block = FileHstsConfig {
        force_replace_backend: None,
        preload: None,
        include_subdomains: None,
        max_age: Some(31_536_000),
        enabled: Some(true),
    };
    let frontend = FileClusterFrontendConfig {
        address: "127.0.0.1:8080".parse().unwrap(),
        hostname: Some(String::from("example.com")),
        path: None,
        path_type: None,
        method: None,
        certificate: None,
        key: None,
        certificate_chain: None,
        tls_versions: Vec::new(),
        position: RulePosition::Tree,
        tags: None,
        redirect: None,
        redirect_scheme: None,
        redirect_template: None,
        rewrite_host: None,
        rewrite_path: None,
        rewrite_port: None,
        required_auth: None,
        headers: None,
        hsts: Some(hsts_block),
    };

    let error = frontend.to_http_front("api").unwrap_err();
    match error {
        ConfigError::HstsOnPlainHttp(scope) => {
            let mentions_both = scope.contains("api") && scope.contains("example.com");
            assert!(
                mentions_both,
                "expected scope to mention 'api' and 'example.com', got {scope:?}"
            );
        }
        unexpected => panic!("expected HstsOnPlainHttp, got {unexpected:?}"),
    }
}
/// Serializes two listener builders and a `FileConfig` to TOML and prints the results;
/// the test fails only if the final `FileConfig` serialization returns an error.
#[test]
fn serialize() {
    let http_address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
    let http = ListenerBuilder::new(http_address, ListenerProtocol::Http)
        .with_answer_404_path(Some("404.html"))
        .to_owned();
    println!("http: {:?}", to_string(&http));

    let https_address = SocketAddress::new_v4(127, 0, 0, 1, 8443);
    let https = ListenerBuilder::new(https_address, ListenerProtocol::Https)
        .with_answer_404_path(Some("404.html"))
        .to_owned();
    println!("https: {:?}", to_string(&https));

    let metrics = MetricsConfig {
        address: "127.0.0.1:8125".parse().unwrap(),
        tagged_metrics: false,
        prefix: Some(String::from("sozu-metrics")),
        detail: MetricDetailLevel::default(),
    };
    let config = FileConfig {
        command_socket: Some(String::from("./command_folder/sock")),
        worker_count: Some(2),
        worker_automatic_restart: Some(true),
        max_connections: Some(500),
        min_buffers: Some(1),
        max_buffers: Some(500),
        buffer_size: Some(16393),
        metrics: Some(metrics),
        listeners: Some(vec![http, https]),
        ..Default::default()
    };

    let serialized = to_string(&config);
    println!("config: {:?}", serialized);
    let encoded = serialized.unwrap();
    println!("conf:\n{encoded}");
}
/// Loads `assets/config.toml` through `Config::load_from_path` and prints the result.
#[test]
fn parse() {
    let path = "assets/config.toml";
    let config = match Config::load_from_path(path) {
        Ok(loaded) => loaded,
        Err(load_error) => panic!("Cannot load config from path {path}: {load_error:?}"),
    };
    println!("config: {config:#?}");
}
/// Two HTTP and two HTTPS listeners with differing `expect_proxy` values must each
/// keep their own flag once the configuration is built.
#[test]
fn multiple_listeners_preserve_per_address_expect_proxy() {
    let raw_toml = r#"
command_socket = "/tmp/sozu_test.sock"
worker_count = 1
[[listeners]]
protocol = "http"
address = "172.16.20.1:80"
expect_proxy = true
[[listeners]]
protocol = "http"
address = "10.22.0.1:80"
expect_proxy = false
[[listeners]]
protocol = "https"
address = "192.168.1.1:443"
expect_proxy = true
[[listeners]]
protocol = "https"
address = "192.168.2.1:443"
expect_proxy = false
"#;
    let parsed: FileConfig = toml::from_str(raw_toml).expect("Could not parse TOML config");
    let declared = parsed.listeners.as_ref().expect("No listeners found");
    assert_eq!(declared.len(), 4);

    let config = ConfigBuilder::new(parsed, "/tmp/test_config.toml")
        .into_config()
        .expect("Could not build config");
    assert_eq!(config.http_listeners.len(), 2);
    assert_eq!(config.https_listeners.len(), 2);

    // Plain HTTP listeners.
    let http_proxy_addr: SocketAddr = "172.16.20.1:80".parse().unwrap();
    let http_proxy = config
        .http_listeners
        .iter()
        .find(|listener| SocketAddr::from(listener.address) == http_proxy_addr)
        .expect("Listener on 172.16.20.1:80 not found");
    let http_direct_addr: SocketAddr = "10.22.0.1:80".parse().unwrap();
    let http_direct = config
        .http_listeners
        .iter()
        .find(|listener| SocketAddr::from(listener.address) == http_direct_addr)
        .expect("Listener on 10.22.0.1:80 not found");
    assert!(http_proxy.expect_proxy);
    assert!(!http_direct.expect_proxy);

    // HTTPS listeners.
    let https_proxy_addr: SocketAddr = "192.168.1.1:443".parse().unwrap();
    let https_proxy = config
        .https_listeners
        .iter()
        .find(|listener| SocketAddr::from(listener.address) == https_proxy_addr)
        .expect("Listener on 192.168.1.1:443 not found");
    let https_direct_addr: SocketAddr = "192.168.2.1:443".parse().unwrap();
    let https_direct = config
        .https_listeners
        .iter()
        .find(|listener| SocketAddr::from(listener.address) == https_direct_addr)
        .expect("Listener on 192.168.2.1:443 not found");
    assert!(https_proxy.expect_proxy);
    assert!(!https_direct.expect_proxy);
}
/// With `activate_listeners = true` and two HTTP listeners, the generated messages
/// must contain two `AddHttpListener` and two HTTP `ActivateListener` requests.
#[test]
fn multiple_listeners_generate_correct_worker_requests() {
    let raw_toml = r#"
command_socket = "/tmp/sozu_test.sock"
worker_count = 1
activate_listeners = true
[[listeners]]
protocol = "http"
address = "172.16.20.1:80"
expect_proxy = true
[[listeners]]
protocol = "http"
address = "10.22.0.1:80"
expect_proxy = false
"#;
    let parsed: FileConfig = toml::from_str(raw_toml).expect("Could not parse TOML config");
    let config = ConfigBuilder::new(parsed, "/tmp/test_config.toml")
        .into_config()
        .expect("Could not build config");
    let messages = config
        .generate_config_messages()
        .expect("Could not generate config messages");

    // The two request kinds are mutually exclusive, so a single pass counts both.
    let mut add_listener_count: usize = 0;
    let mut activate_listener_count: usize = 0;
    for message in &messages {
        match message.content.request_type {
            Some(RequestType::AddHttpListener(_)) => add_listener_count += 1,
            Some(RequestType::ActivateListener(ActivateListener { proxy, .. }))
                if proxy == ListenerType::Http as i32 =>
            {
                activate_listener_count += 1
            }
            _ => {}
        }
    }

    assert_eq!(add_listener_count, 2);
    assert_eq!(activate_listener_count, 2);
}
/// Two listeners declared on the same address must make the config build fail.
#[test]
fn duplicate_listener_address_rejected() {
    let raw_toml = r#"
command_socket = "/tmp/sozu_test.sock"
worker_count = 1
[[listeners]]
protocol = "http"
address = "0.0.0.0:80"
[[listeners]]
protocol = "http"
address = "0.0.0.0:80"
"#;
    let parsed: FileConfig = toml::from_str(raw_toml).expect("Could not parse TOML config");
    let builder = ConfigBuilder::new(parsed, "/tmp/test_config.toml");
    let outcome = builder.into_config();
    assert!(
        outcome.is_err(),
        "Should reject duplicate listener addresses"
    );
}
/// A `buffer_size` of 8192 with an HTTPS listener using default ALPN must be rejected
/// with `BufferSizeTooSmallForH2` reporting the size, the 16393 minimum and one listener.
#[test]
fn buffer_size_below_h2_minimum_rejected() {
    let raw_toml = r#"
command_socket = "/tmp/sozu_test.sock"
worker_count = 1
buffer_size = 8192
[[listeners]]
protocol = "https"
address = "127.0.0.1:8443"
"#;
    let parsed: FileConfig = toml::from_str(raw_toml).expect("Could not parse TOML config");
    let result = ConfigBuilder::new(parsed, "/tmp/test_config.toml").into_config();
    let rejected_as_expected = matches!(
        result,
        Err(ConfigError::BufferSizeTooSmallForH2 {
            buffer_size: 8192,
            minimum: 16_393,
            listeners: 1,
        })
    );
    if !rejected_as_expected {
        panic!("expected BufferSizeTooSmallForH2, got {result:?}");
    }
}
/// The same small `buffer_size` must be accepted when the HTTPS listener only
/// advertises `http/1.1` through ALPN.
#[test]
fn buffer_size_below_h2_minimum_accepted_when_no_h2_listener() {
    let raw_toml = r#"
command_socket = "/tmp/sozu_test.sock"
worker_count = 1
buffer_size = 8192
[[listeners]]
protocol = "https"
address = "127.0.0.1:8443"
alpn_protocols = ["http/1.1"]
"#;
    let parsed: FileConfig = toml::from_str(raw_toml).expect("Could not parse TOML config");
    let builder = ConfigBuilder::new(parsed, "/tmp/test_config.toml");
    let result = builder.into_config();
    assert!(
        result.is_ok(),
        "non-H2 HTTPS listener with sub-16393 buffer should be accepted: {result:?}"
    );
}
/// A `buffer_size` of exactly 16393 must be accepted with a default HTTPS listener.
#[test]
fn buffer_size_at_h2_minimum_accepted() {
    let raw_toml = r#"
command_socket = "/tmp/sozu_test.sock"
worker_count = 1
buffer_size = 16393
[[listeners]]
protocol = "https"
address = "127.0.0.1:8443"
"#;
    let parsed: FileConfig = toml::from_str(raw_toml).expect("Could not parse TOML config");
    let builder = ConfigBuilder::new(parsed, "/tmp/test_config.toml");
    let result = builder.into_config();
    assert!(
        result.is_ok(),
        "buffer_size at the H2 minimum should be accepted: {result:?}"
    );
}
/// Without any ALPN setting, the TLS config must advertise `h2` then `http/1.1`.
#[test]
fn alpn_protocols_default() {
    let address = SocketAddress::new_v4(127, 0, 0, 1, 8443);
    let mut https_builder = ListenerBuilder::new_https(address);
    let tls = https_builder.to_tls(None).expect("to_tls should succeed");
    assert_eq!(tls.alpn_protocols, vec!["h2", "http/1.1"]);
}
/// A custom ALPN list containing only `http/1.1` must be kept as is.
#[test]
fn alpn_protocols_custom() {
    let address = SocketAddress::new_v4(127, 0, 0, 1, 8443);
    let mut https_builder = ListenerBuilder::new_https(address);
    let requested = vec![String::from("http/1.1")];
    https_builder.with_alpn_protocols(Some(requested));
    let tls = https_builder.to_tls(None).expect("to_tls should succeed");
    assert_eq!(tls.alpn_protocols, vec!["http/1.1"]);
}
/// An unsupported ALPN protocol (`h3`) must be rejected with an error naming it.
#[test]
fn alpn_protocols_invalid_rejected() {
    let address = SocketAddress::new_v4(127, 0, 0, 1, 8443);
    let mut https_builder = ListenerBuilder::new_https(address);
    let requested = vec![String::from("h3")];
    https_builder.with_alpn_protocols(Some(requested));

    let outcome = https_builder.to_tls(None);
    assert!(outcome.is_err());
    let err = outcome.unwrap_err();
    let rendered = err.to_string();
    assert!(
        rendered.contains("h3"),
        "error should mention the invalid protocol: {err}"
    );
}
/// An explicitly empty ALPN list must fall back to `h2` then `http/1.1`.
#[test]
fn alpn_protocols_empty_uses_default() {
    let address = SocketAddress::new_v4(127, 0, 0, 1, 8443);
    let mut https_builder = ListenerBuilder::new_https(address);
    https_builder.with_alpn_protocols(Some(Vec::new()));
    let tls = https_builder.to_tls(None).expect("to_tls should succeed");
    assert_eq!(tls.alpn_protocols, vec!["h2", "http/1.1"]);
}
/// A repeated ALPN entry must appear only once in the resulting TLS config.
#[test]
fn alpn_protocols_deduplicated() {
    let address = SocketAddress::new_v4(127, 0, 0, 1, 8443);
    let mut https_builder = ListenerBuilder::new_https(address);
    let requested = vec![
        String::from("h2"),
        String::from("h2"),
        String::from("http/1.1"),
    ];
    https_builder.with_alpn_protocols(Some(requested));
    let tls = https_builder.to_tls(None).expect("to_tls should succeed");
    assert_eq!(tls.alpn_protocols, vec!["h2", "http/1.1"]);
}
/// The order of a custom ALPN list (`http/1.1` before `h2`) must be preserved.
#[test]
fn alpn_protocols_order_preserved() {
    let address = SocketAddress::new_v4(127, 0, 0, 1, 8443);
    let mut https_builder = ListenerBuilder::new_https(address);
    let requested = vec![String::from("http/1.1"), String::from("h2")];
    https_builder.with_alpn_protocols(Some(requested));
    let tls = https_builder.to_tls(None).expect("to_tls should succeed");
    assert_eq!(tls.alpn_protocols, vec!["http/1.1", "h2"]);
}
/// A CRLF sequence inside a header value must be rejected with `InvalidHeaderBytes`
/// pointing at index 0 and the `value` field.
#[test]
fn parse_header_edit_rejects_crlf_in_value() {
    let edit = HeaderEditConfig {
        position: String::from("request"),
        key: String::from("X-Test"),
        value: String::from("value\r\nEvil-Header: stolen"),
    };
    let error = parse_header_edit(0, &edit).expect_err("CRLF in value must be rejected");
    match error {
        ConfigError::InvalidHeaderBytes { index, field } => {
            assert_eq!(index, 0);
            assert_eq!(field, "value");
        }
        unexpected => panic!("expected InvalidHeaderBytes, got {unexpected:?}"),
    }
}
/// A line feed inside a header key must be rejected with `InvalidHeaderBytes`
/// pointing at the given index (2) and the `key` field.
#[test]
fn parse_header_edit_rejects_lf_in_key() {
    let edit = HeaderEditConfig {
        position: String::from("response"),
        key: String::from("X-\nTest"),
        value: String::from("ok"),
    };
    let error = parse_header_edit(2, &edit).expect_err("LF in key must be rejected");
    match error {
        ConfigError::InvalidHeaderBytes { index, field } => {
            assert_eq!(index, 2);
            assert_eq!(field, "key");
        }
        unexpected => panic!("expected InvalidHeaderBytes, got {unexpected:?}"),
    }
}
/// A NUL byte inside a header value must be rejected with `InvalidHeaderBytes`.
#[test]
fn parse_header_edit_rejects_nul() {
    let edit = HeaderEditConfig {
        position: String::from("both"),
        key: String::from("X-Test"),
        value: String::from("with\0nul"),
    };
    let outcome = parse_header_edit(0, &edit);
    assert!(matches!(
        outcome,
        Err(ConfigError::InvalidHeaderBytes { .. })
    ));
}
/// A horizontal tab inside a header value must be accepted and kept verbatim.
#[test]
fn parse_header_edit_accepts_tab_in_value() {
    let edit = HeaderEditConfig {
        position: String::from("request"),
        key: String::from("X-Test"),
        value: String::from("with\ttab"),
    };
    let parsed = parse_header_edit(0, &edit).expect("tab in value must be accepted");
    assert_eq!(parsed.val, "with\ttab");
}
/// A horizontal tab inside a header key must be rejected, blaming the `key` field.
#[test]
fn parse_header_edit_rejects_tab_in_key() {
    let edit = HeaderEditConfig {
        position: String::from("request"),
        key: String::from("Host\t"),
        value: String::from("ok"),
    };
    let error = parse_header_edit(0, &edit).expect_err("HTAB in key must be rejected");
    match error {
        ConfigError::InvalidHeaderBytes { field, .. } => assert_eq!(field, "key"),
        unexpected => panic!("expected InvalidHeaderBytes{{field=\"key\"}}, got {unexpected:?}"),
    }
}
/// A space inside a header key must be rejected with `InvalidHeaderBytes`.
#[test]
fn parse_header_edit_rejects_space_in_key() {
    let edit = HeaderEditConfig {
        position: String::from("request"),
        key: String::from("X Test"),
        value: String::from("ok"),
    };
    let error = parse_header_edit(0, &edit).expect_err("SP in key must be rejected");
    let is_invalid_bytes = matches!(error, ConfigError::InvalidHeaderBytes { .. });
    assert!(is_invalid_bytes);
}
/// An empty header key must be rejected with `InvalidHeaderBytes` on the `key` field.
#[test]
fn parse_header_edit_rejects_empty_key() {
    let edit = HeaderEditConfig {
        position: String::from("request"),
        key: String::new(),
        value: String::from("ok"),
    };
    let error = parse_header_edit(0, &edit).expect_err("empty key must be rejected");
    let blames_key = matches!(
        error,
        ConfigError::InvalidHeaderBytes { field: "key", .. }
    );
    assert!(blames_key);
}
/// A well-formed key/value pair must be accepted and returned unchanged.
#[test]
fn parse_header_edit_accepts_clean_value() {
    let edit = HeaderEditConfig {
        position: String::from("request"),
        key: String::from("X-Tenant"),
        value: String::from("alpha"),
    };
    let parsed = parse_header_edit(0, &edit).expect("clean value must be accepted");
    assert_eq!(parsed.key, "X-Tenant");
    assert_eq!(parsed.val, "alpha");
}
/// A source without a `file://` scheme must be returned verbatim as the answer body.
#[test]
fn resolve_answer_source_bare_string_is_literal() {
    let source = "HTTP/1.1 503 Service Unavailable\r\n\r\nbusy";
    let resolved = resolve_answer_source(source).expect("bare-string source must resolve");
    assert_eq!(resolved, "HTTP/1.1 503 Service Unavailable\r\n\r\nbusy");
}
/// An empty source must resolve successfully to an empty body.
#[test]
fn resolve_answer_source_empty_string_is_legitimate() {
    let source = "";
    let resolved = resolve_answer_source(source).expect("empty source must resolve");
    assert_eq!(resolved, "");
}
/// A `file://` source pointing at a nonexistent file must fail with `FileOpen`.
#[test]
fn resolve_answer_source_file_scheme_missing_file_errors() {
    let source = "file:///nonexistent/sozu-test/never.http";
    let error = resolve_answer_source(source).expect_err("missing path must error");
    let is_file_open = matches!(error, ConfigError::FileOpen { .. });
    assert!(is_file_open);
}
/// A `file://` source with an empty path must fail with `FileOpen`.
#[test]
fn resolve_answer_source_file_scheme_empty_path_errors() {
    let source = "file://";
    let error = resolve_answer_source(source).expect_err("empty path must error");
    let is_file_open = matches!(error, ConfigError::FileOpen { .. });
    assert!(is_file_open);
}
}