use serde::Deserialize;
use serde::Serialize;
use std::net::SocketAddr;
use std::path::PathBuf;
/// Compression setting stored in `ConfigDropshot::compression`.
///
/// Because of `rename_all = "lowercase"`, the variants are (de)serialized as
/// `"none"` and `"gzip"`. `None` is the `Default` variant.
// NOTE(review): how the server acts on this value is not visible in this
// file — presumably it selects HTTP response compression; confirm at the
// point of use.
#[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum CompressionConfig {
/// Default variant; (de)serialized as `"none"`.
#[default]
None,
/// (De)serialized as `"gzip"`.
Gzip,
}
/// Alias for `rustls::ServerConfig`; this is the payload type of
/// `ConfigTls::Dynamic`.
pub type RawTlsConfig = rustls::ServerConfig;
/// Dropshot configuration values.
///
/// Serde never (de)serializes this struct field-by-field: the `from` / `into`
/// container attributes route both directions through
/// `DeserializedConfigDropshot`, which has the same fields plus one extra
/// field that exists only to reject a renamed key.
///
/// The per-field defaults mentioned below are the values assigned by this
/// type's `Default` implementation.
// NOTE(review): the runtime meaning of each field is implemented elsewhere;
// the descriptions below only state what this file establishes.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(
from = "DeserializedConfigDropshot",
into = "DeserializedConfigDropshot"
)]
pub struct ConfigDropshot {
/// Socket address (IP and port). Defaults to `127.0.0.1:0`.
pub bind_address: SocketAddr,
/// Byte count; defaults to `1024`. Presumably the default cap on request
/// body size — confirm against the request-handling code.
pub default_request_body_max_bytes: usize,
/// Handler task mode; defaults to `HandlerTaskMode::Detached`.
pub default_handler_task_mode: HandlerTaskMode,
/// List of strings; defaults to empty. Presumably names of request
/// headers to include in logs — confirm against the logging code.
pub log_headers: Vec<String>,
/// Compression setting; defaults to `CompressionConfig::default()`.
pub compression: CompressionConfig,
}
/// Mode stored in `ConfigDropshot::default_handler_task_mode`.
///
/// Because of `rename_all = "kebab-case"`, the variants are (de)serialized as
/// `"cancel-on-disconnect"` and `"detached"`.
// NOTE(review): the behavioral difference between the two modes is
// implemented outside this file. Going by the names, it presumably decides
// whether a request handler is cancelled when the client disconnects or is
// left to run to completion — confirm before relying on this.
#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum HandlerTaskMode {
/// (De)serialized as `"cancel-on-disconnect"`.
CancelOnDisconnect,
/// (De)serialized as `"detached"`; the value used by
/// `ConfigDropshot::default()`.
Detached,
}
/// TLS configuration, expressed in one of three forms.
///
/// Unlike `ConfigDropshot`, this type derives neither `Serialize` nor
/// `Deserialize`.
// NOTE(review): the expected encoding of the certificate and key material
// (e.g. PEM vs. DER) is not established in this file — confirm where these
// values are consumed.
#[derive(Clone, Debug)]
pub enum ConfigTls {
/// Certificate and key given as filesystem paths.
AsFile {
/// Path to the certificate file.
cert_file: PathBuf,
/// Path to the key file.
key_file: PathBuf,
},
/// Certificate(s) and key given as in-memory byte buffers.
AsBytes { certs: Vec<u8>, key: Vec<u8> },
/// An already-constructed `rustls::ServerConfig` (see `RawTlsConfig`).
Dynamic(RawTlsConfig),
}
impl Default for ConfigDropshot {
fn default() -> Self {
ConfigDropshot {
bind_address: "127.0.0.1:0".parse().unwrap(),
default_request_body_max_bytes: 1024,
default_handler_task_mode: HandlerTaskMode::Detached,
log_headers: Default::default(),
compression: CompressionConfig::default(),
}
}
}
/// Serde-facing shadow of `ConfigDropshot`.
///
/// `ConfigDropshot` names this type in its `#[serde(from = ..., into = ...)]`
/// attributes, so this struct's derived impls define the actual serialized
/// shape. It mirrors `ConfigDropshot` field-for-field and adds
/// `request_body_max_bytes`, whose only job is to make deserialization fail
/// when that old key is supplied.
///
/// `#[serde(default)]` at the container level means any field missing from
/// the input is taken from this type's `Default` implementation.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(default)]
struct DeserializedConfigDropshot {
bind_address: SocketAddr,
default_request_body_max_bytes: usize,
// Trap for the old key name. It is deserialized through
// `deserialize_invalid_request_body_max_bytes` and never written back out
// (`skip_serializing`). `InvalidConfig` has no variants, so the only value
// this field can hold is `None`.
#[serde(
deserialize_with = "deserialize_invalid_request_body_max_bytes",
skip_serializing
)]
request_body_max_bytes: Option<InvalidConfig>,
default_handler_task_mode: HandlerTaskMode,
log_headers: Vec<String>,
compression: CompressionConfig,
}
impl From<DeserializedConfigDropshot> for ConfigDropshot {
fn from(v: DeserializedConfigDropshot) -> Self {
ConfigDropshot {
bind_address: v.bind_address,
default_request_body_max_bytes: v.default_request_body_max_bytes,
default_handler_task_mode: v.default_handler_task_mode,
log_headers: v.log_headers,
compression: v.compression,
}
}
}
impl From<ConfigDropshot> for DeserializedConfigDropshot {
fn from(v: ConfigDropshot) -> Self {
DeserializedConfigDropshot {
bind_address: v.bind_address,
default_request_body_max_bytes: v.default_request_body_max_bytes,
request_body_max_bytes: None,
default_handler_task_mode: v.default_handler_task_mode,
log_headers: v.log_headers,
compression: v.compression,
}
}
}
impl Default for DeserializedConfigDropshot {
    /// Produces the defaults by converting `ConfigDropshot::default()`, so
    /// both representations share a single source of default values.
    fn default() -> Self {
        let public_defaults = ConfigDropshot::default();
        Self::from(public_defaults)
    }
}
/// Uninhabited marker type.
///
/// The enum has no variants, so no value of it can ever be constructed and an
/// `Option<InvalidConfig>` can only be `None`. In this file it is the field
/// type of `DeserializedConfigDropshot::request_body_max_bytes` and the
/// success type of the `deserialize_invalid*` functions.
#[derive(Clone, Debug, PartialEq)]
pub enum InvalidConfig {}
/// `deserialize_with` hook for the `request_body_max_bytes` field of
/// `DeserializedConfigDropshot`.
///
/// Delegates to `deserialize_invalid`, supplying a message that tells the
/// user the key has been renamed to `default_request_body_max_bytes`.
fn deserialize_invalid_request_body_max_bytes<'de, D>(
    deserializer: D,
) -> Result<Option<InvalidConfig>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    // Text surfaced to the user when the old key name appears in the input.
    const RENAMED_KEY_MSG: &str = "request_body_max_bytes has been renamed to \
         default_request_body_max_bytes";

    deserialize_invalid(deserializer, RENAMED_KEY_MSG)
}
fn deserialize_invalid<'de, D>(
deserializer: D,
msg: &'static str,
) -> Result<Option<InvalidConfig>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
struct V {
msg: &'static str,
}
impl<'de> serde::de::Visitor<'de> for V {
type Value = Option<InvalidConfig>;
fn expecting(
&self,
formatter: &mut std::fmt::Formatter,
) -> std::fmt::Result {
write!(formatter, "the field to be absent ({})", self.msg)
}
fn visit_some<D>(self, _: D) -> Result<Self::Value, D::Error>
where
D: serde::Deserializer<'de>,
{
Err(D::Error::custom(self.msg))
}
}
deserializer.deserialize_any(V { msg })
}