use serde::{Deserialize, Serialize};
#[cfg(feature = "tls")]
use std::sync::Arc;
/// Newtype around a shared `rustls::ClientConfig` so it can be stored inside
/// the `#[derive]`d `TransportConfig` enum.
///
/// The wrapper supplies the trait impls that enum derive needs: equality is
/// by `Arc` identity, serialization emits a fixed placeholder string, and
/// deserialization always fails (the config must be injected at runtime).
/// Cloning only bumps the `Arc` reference count.
#[cfg(feature = "tls")]
#[derive(Clone, Debug)]
pub struct RustlsClientConfig(pub Arc<rustls::ClientConfig>);
#[cfg(feature = "tls")]
impl PartialEq for RustlsClientConfig {
    /// Identity comparison: two wrappers are equal only when they point at
    /// the very same `Arc` allocation. The rustls settings themselves are
    /// never inspected.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.0, &other.0);
        Arc::ptr_eq(lhs, rhs)
    }
}
// Equality is `Arc` pointer identity (see the `PartialEq` impl), which is
// reflexive, symmetric and transitive, so the full `Eq` contract holds.
#[cfg(feature = "tls")]
impl Eq for RustlsClientConfig {}
#[cfg(feature = "tls")]
impl Serialize for RustlsClientConfig {
    /// Writes a fixed placeholder string; the live rustls configuration is
    /// never emitted.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        const PLACEHOLDER: &str = "<RustlsClientConfig>";
        serializer.serialize_str(PLACEHOLDER)
    }
}
#[cfg(feature = "tls")]
impl<'de> Deserialize<'de> for RustlsClientConfig {
    /// Always returns an error without reading any input: a rustls client
    /// configuration has to be supplied programmatically.
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let failure = <D::Error as serde::de::Error>::custom(
            "RustlsClientConfig cannot be deserialized from text; inject it at runtime",
        );
        Err(failure)
    }
}
/// How a client connection is secured.
///
/// Serialized as an internally tagged object. The `mode` key holds the
/// snake_case variant name (`"tls"`, `"plaintext"`, or — with the `tls`
/// feature — `"rustls_config"`), and the variant's fields sit alongside it.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "mode", rename_all = "snake_case")]
pub enum TransportConfig {
    /// TLS configured from filesystem paths.
    Tls {
        /// Custom CA certificate used to verify the server, if any.
        /// NOTE(review): what happens when this is `None` (e.g. falling back
        /// to default roots) is decided by the connector, not here — confirm.
        ca_cert_path: Option<String>,
        /// Client certificate path for mutual TLS. The config only counts as
        /// mTLS when `client_key_path` is also set.
        client_cert_path: Option<String>,
        /// Client private key path for mutual TLS; pairs with `client_cert_path`.
        client_key_path: Option<String>,
        /// Disables server certificate verification (insecure). Deserializes
        /// to `false` when the key is absent.
        #[serde(default)]
        allow_invalid_certificates: bool,
        /// Disables hostname verification (insecure). Deserializes to `false`
        /// when the key is absent.
        #[serde(default)]
        allow_invalid_hostnames: bool,
    },
    /// Unencrypted connection.
    Plaintext,
    /// A pre-built rustls client configuration injected at runtime.
    ///
    /// It serializes as a placeholder string and always fails to deserialize,
    /// so this mode cannot be loaded from a config file.
    #[cfg(feature = "tls")]
    RustlsConfig {
        /// Shared rustls configuration, compared by `Arc` identity.
        config: RustlsClientConfig,
    },
}
impl Default for TransportConfig {
fn default() -> Self {
Self::tls()
}
}
impl TransportConfig {
pub fn tls() -> Self {
Self::Tls {
ca_cert_path: None,
client_cert_path: None,
client_key_path: None,
allow_invalid_certificates: false,
allow_invalid_hostnames: false,
}
}
pub fn tls_with_ca_cert_path(ca_cert_path: Option<String>) -> Self {
Self::Tls {
ca_cert_path,
client_cert_path: None,
client_key_path: None,
allow_invalid_certificates: false,
allow_invalid_hostnames: false,
}
}
pub fn mtls(
ca_cert_path: Option<String>,
client_cert_path: String,
client_key_path: String,
) -> Self {
Self::Tls {
ca_cert_path,
client_cert_path: Some(client_cert_path),
client_key_path: Some(client_key_path),
allow_invalid_certificates: false,
allow_invalid_hostnames: false,
}
}
pub fn tls_insecure_skip_verify() -> Self {
Self::Tls {
ca_cert_path: None,
client_cert_path: None,
client_key_path: None,
allow_invalid_certificates: true,
allow_invalid_hostnames: true,
}
}
pub const fn plaintext() -> Self {
Self::Plaintext
}
#[cfg(feature = "tls")]
pub fn rustls_config(config: Arc<rustls::ClientConfig>) -> Self {
Self::RustlsConfig {
config: RustlsClientConfig(config),
}
}
pub fn is_tls(&self) -> bool {
#[cfg(feature = "tls")]
if matches!(self, Self::RustlsConfig { .. }) {
return true;
}
matches!(self, Self::Tls { .. })
}
pub fn is_mtls(&self) -> bool {
matches!(
self,
Self::Tls {
client_cert_path: Some(_),
client_key_path: Some(_),
..
}
)
}
pub fn ca_cert_path(&self) -> Option<&str> {
match self {
Self::Tls {
ca_cert_path: Some(path),
..
} => Some(path.as_str()),
_ => None,
}
}
pub fn allow_invalid_certificates(&self) -> bool {
match self {
Self::Tls {
allow_invalid_certificates,
..
} => *allow_invalid_certificates,
_ => false,
}
}
pub fn allow_invalid_hostnames(&self) -> bool {
match self {
Self::Tls {
allow_invalid_hostnames,
..
} => *allow_invalid_hostnames,
_ => false,
}
}
pub fn client_cert_path(&self) -> Option<&str> {
match self {
Self::Tls {
client_cert_path: Some(path),
..
} => Some(path.as_str()),
_ => None,
}
}
pub fn client_key_path(&self) -> Option<&str> {
match self {
Self::Tls {
client_key_path: Some(path),
..
} => Some(path.as_str()),
_ => None,
}
}
pub fn warn_if_insecure(&self, source_label: &str) {
if self.allow_invalid_certificates() {
tracing::warn!(
target: "rustcdc::transport::security",
source = source_label,
flag = "allow_invalid_certificates",
"TLS certificate verification is disabled — do not use in production"
);
}
if self.allow_invalid_hostnames() {
tracing::warn!(
target: "rustcdc::transport::security",
source = source_label,
flag = "allow_invalid_hostnames",
"TLS hostname verification is disabled — do not use in production"
);
}
}
}
#[cfg(test)]
mod tests {
    use super::TransportConfig;

    #[test]
    fn tls_defaults_to_strict_verification() {
        let transport = TransportConfig::tls();
        assert!(transport.is_tls());
        assert!(!transport.allow_invalid_certificates());
        assert!(!transport.allow_invalid_hostnames());
    }

    #[test]
    fn tls_insecure_skip_verify_sets_insecure_flags() {
        let transport = TransportConfig::tls_insecure_skip_verify();
        assert!(transport.is_tls());
        assert!(transport.allow_invalid_certificates());
        assert!(transport.allow_invalid_hostnames());
    }

    #[test]
    fn default_is_strict_tls() {
        assert_eq!(TransportConfig::default(), TransportConfig::tls());
    }

    #[test]
    fn plaintext_reports_no_tls_settings() {
        let transport = TransportConfig::plaintext();
        assert!(!transport.is_tls());
        assert!(!transport.is_mtls());
        assert_eq!(transport.ca_cert_path(), None);
        assert_eq!(transport.client_cert_path(), None);
        assert_eq!(transport.client_key_path(), None);
        assert!(!transport.allow_invalid_certificates());
        assert!(!transport.allow_invalid_hostnames());
    }

    #[test]
    fn tls_with_ca_cert_path_exposes_ca_without_client_identity() {
        let transport = TransportConfig::tls_with_ca_cert_path(Some("ca.pem".to_string()));
        assert!(transport.is_tls());
        assert!(!transport.is_mtls());
        assert_eq!(transport.ca_cert_path(), Some("ca.pem"));
        assert_eq!(transport.client_cert_path(), None);
        assert_eq!(transport.client_key_path(), None);
    }

    #[test]
    fn mtls_exposes_client_identity() {
        let transport = TransportConfig::mtls(
            Some("ca.pem".to_string()),
            "client.pem".to_string(),
            "client.key".to_string(),
        );
        assert!(transport.is_tls());
        assert!(transport.is_mtls());
        assert_eq!(transport.ca_cert_path(), Some("ca.pem"));
        assert_eq!(transport.client_cert_path(), Some("client.pem"));
        assert_eq!(transport.client_key_path(), Some("client.key"));
        assert!(!transport.allow_invalid_certificates());
        assert!(!transport.allow_invalid_hostnames());
    }

    #[test]
    fn certificate_without_key_is_not_mtls() {
        let transport = TransportConfig::Tls {
            ca_cert_path: None,
            client_cert_path: Some("client.pem".to_string()),
            client_key_path: None,
            allow_invalid_certificates: false,
            allow_invalid_hostnames: false,
        };
        assert!(transport.is_tls());
        assert!(!transport.is_mtls());
        assert_eq!(transport.client_cert_path(), Some("client.pem"));
        assert_eq!(transport.client_key_path(), None);
    }
}