use serde::{
de::{MapAccess, Visitor},
Deserialize, Deserializer, Serialize,
};
use std::{
collections::HashMap,
sync::{atomic::AtomicUsize, Arc},
};
use crate::traits::Handler;
use tracing::trace;
/// Top-level configuration: a map from a string key to a [`Route`].
// NOTE(review): keys are presumably route names — confirm against the config loader.
pub type Config = HashMap<String, Route>;
/// A map from a string key to an [`Endpoint`].
// NOTE(review): keys are presumably publisher names — confirm against callers.
pub type PublisherConfig = HashMap<String, Endpoint>;
/// A route connecting an `input` endpoint to an `output` endpoint.
///
/// The [`RouteOptions`] fields are flattened into the same map as `input` and `output`.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct Route {
/// Input endpoint; required when deserializing.
pub input: Endpoint,
/// Output endpoint; when omitted it defaults to an endpoint of type `null`.
#[serde(default = "default_output_endpoint")]
pub output: Endpoint,
/// Route-level options, flattened into the route's own map.
#[serde(flatten, default)]
pub options: RouteOptions,
}
impl Default for Route {
fn default() -> Self {
Self {
input: Endpoint::null(),
output: Endpoint::null(),
options: RouteOptions::default(),
}
}
}
/// Tuning options for a [`Route`]; every field has a serde default.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct RouteOptions {
/// Free-form description; skipped when serializing if empty.
#[serde(default, skip_serializing_if = "String::is_empty")]
pub description: String,
/// Defaults to 1. The schema declares a minimum of 1; serde itself does not enforce it.
#[serde(default = "default_concurrency")]
#[cfg_attr(feature = "schema", schemars(range(min = 1)))]
pub concurrency: usize,
/// Defaults to 1. The schema declares a minimum of 1; serde itself does not enforce it.
#[serde(default = "default_batch_size")]
#[cfg_attr(feature = "schema", schemars(range(min = 1)))]
pub batch_size: usize,
/// Defaults to 4096.
#[serde(default = "default_commit_concurrency_limit")]
pub commit_concurrency_limit: usize,
}
impl Default for RouteOptions {
fn default() -> Self {
Self {
description: String::new(),
concurrency: default_concurrency(),
batch_size: default_batch_size(),
commit_concurrency_limit: default_commit_concurrency_limit(),
}
}
}
/// Serde default for `RouteOptions::concurrency` (`1`).
pub(crate) fn default_concurrency() -> usize {
    1_usize
}

/// Serde default for `RouteOptions::batch_size` (`1`).
pub(crate) fn default_batch_size() -> usize {
    1_usize
}

/// Serde default for `RouteOptions::commit_concurrency_limit` (`4096`).
pub(crate) fn default_commit_concurrency_limit() -> usize {
    4_096_usize
}
/// Serde default for `Route::output`: an endpoint whose type is [`EndpointType::Null`].
fn default_output_endpoint() -> Endpoint {
    let endpoint_type = EndpointType::Null;
    Endpoint::new(endpoint_type)
}
/// Serde default for `RetryMiddleware::max_attempts` (`3`).
fn default_retry_attempts() -> usize {
    3_usize
}

/// Serde default for `RetryMiddleware::initial_interval_ms` (`100`).
fn default_initial_interval_ms() -> u64 {
    100_u64
}

/// Serde default for `RetryMiddleware::max_interval_ms` (`5000`).
fn default_max_interval_ms() -> u64 {
    5_000_u64
}

/// Serde default for `RetryMiddleware::multiplier` (`2.0`).
fn default_multiplier() -> f64 {
    2.0_f64
}

/// Serde default for `MqttConfig::clean_session` (`false`).
fn default_clean_session() -> bool {
    false
}

/// Serde default for `CookieJarMiddleware::cookie_metadata_key` (`"cookie"`).
fn default_cookie_metadata_key() -> String {
    String::from("cookie")
}

/// Serde default for `CookieJarMiddleware::set_cookie_metadata_key` (`"set-cookie"`).
fn default_set_cookie_metadata_key() -> String {
    String::from("set-cookie")
}
/// Returns `true` if `name` is the (case-sensitive) config key of a built-in endpoint type.
///
/// `"custom"` is deliberately absent from this list. A single unknown key is turned into
/// `EndpointType::Custom` by the `Endpoint` deserializer, while a known key reports its
/// original parse error.
fn is_known_endpoint_name(name: &str) -> bool {
    const BUILTIN_ENDPOINTS: &[&str] = &[
        "aws",
        "kafka",
        "nats",
        "file",
        "static",
        "memory",
        "sled",
        "amqp",
        "mongodb",
        "mqtt",
        "http",
        "websocket",
        "ibmmq",
        "zeromq",
        "grpc",
        "fanout",
        "ref",
        "switch",
        "response",
        "reader",
        "null",
        "sqlx",
    ];
    BUILTIN_ENDPOINTS.iter().any(|known| *known == name)
}
/// An endpoint: a concrete [`EndpointType`] plus a list of [`Middleware`]s and an optional
/// in-process handler.
///
/// `Deserialize` is implemented manually rather than derived.
// NOTE(review): with the manual `Deserialize` impl, `#[serde(default)]` and
// `deny_unknown_fields` here do not affect deserialization; they are presumably kept for the
// schemars schema — confirm.
#[derive(Serialize, Clone, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct Endpoint {
/// Middlewares attached to this endpoint.
#[serde(default)]
pub middlewares: Vec<Middleware>,
/// Endpoint kind and its configuration, flattened so the type key (e.g. `kafka`) appears
/// directly in the endpoint map.
#[serde(flatten)]
pub endpoint_type: EndpointType,
/// Optional handler; never serialized, excluded from the schema, and always `None` after
/// deserialization.
#[serde(skip_serializing)]
#[cfg_attr(feature = "schema", schemars(skip))]
pub handler: Option<Arc<dyn Handler>>,
}
impl std::fmt::Debug for Endpoint {
    /// Formats the endpoint's fields. The handler is summarized by a fixed placeholder
    /// string instead of being formatted itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let handler_summary = match &self.handler {
            Some(_) => "Some(<Handler>)",
            None => "None",
        };
        let mut out = f.debug_struct("Endpoint");
        out.field("middlewares", &self.middlewares);
        out.field("endpoint_type", &self.endpoint_type);
        out.field("handler", &handler_summary);
        out.finish()
    }
}
impl<'de> Deserialize<'de> for Endpoint {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct EndpointVisitor;
impl<'de> Visitor<'de> for EndpointVisitor {
type Value = Endpoint;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a map representing an endpoint or null")
}
fn visit_unit<E>(self) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Endpoint::new(EndpointType::Null))
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
let mut temp_map = serde_json::Map::new();
let mut middlewares_val = None;
while let Some((key, value)) = map.next_entry::<String, serde_json::Value>()? {
if key == "middlewares" {
middlewares_val = Some(value);
} else {
temp_map.insert(key, value);
}
}
let temp_val = serde_json::Value::Object(temp_map);
let endpoint_type: EndpointType = match serde_json::from_value(temp_val.clone()) {
Ok(et) => et,
Err(original_err) => {
if let serde_json::Value::Object(map) = &temp_val {
if map.len() == 1 {
let (name, config) = map.iter().next().unwrap();
if is_known_endpoint_name(name) {
return Err(serde::de::Error::custom(original_err));
}
trace!("Falling back to Custom endpoint for key: {}", name);
EndpointType::Custom {
name: name.clone(),
config: config.clone(),
}
} else if map.is_empty() {
EndpointType::Null
} else {
return Err(serde::de::Error::custom(
"Invalid endpoint configuration: multiple keys found or unknown endpoint type",
));
}
} else {
return Err(serde::de::Error::custom("Invalid endpoint configuration"));
}
}
};
let middlewares = match middlewares_val {
Some(val) => {
deserialize_middlewares_from_value(val).map_err(serde::de::Error::custom)?
}
None => Vec::new(),
};
Ok(Endpoint {
middlewares,
endpoint_type,
handler: None,
})
}
}
deserializer.deserialize_any(EndpointVisitor)
}
}
/// Returns `true` if `name` is the (case-sensitive) config key of a built-in [`Middleware`]
/// variant, including `"custom"`.
fn is_known_middleware_name(name: &str) -> bool {
    const BUILTIN_MIDDLEWARES: &[&str] = &[
        "deduplication",
        "metrics",
        "dlq",
        "retry",
        "random_panic",
        "delay",
        "weak_join",
        "limiter",
        "buffer",
        "cookie_jar",
        "custom",
    ];
    BUILTIN_MIDDLEWARES.iter().any(|known| *known == name)
}
/// Parses an endpoint's `middlewares` value into a list of [`Middleware`].
///
/// Accepted shapes:
/// - an array of single-key objects, e.g. `[{"retry": {...}}, {"metrics": {}}]`;
/// - an object whose keys are all numeric indices (`{"0": {...}, "1": {...}}`), which is
///   ordered by index and then treated like an array;
/// - `null`, which yields an empty list.
///
/// Entries keyed by a built-in middleware name must deserialize as that [`Middleware`]
/// variant. Any other single-key entry becomes [`Middleware::Custom`].
///
/// # Errors
/// Returns an error if:
/// - the value has any other type;
/// - the object form contains a non-numeric key;
/// - an entry is not a single-key object;
/// - a built-in middleware's configuration fails to deserialize.
fn deserialize_middlewares_from_value(value: serde_json::Value) -> anyhow::Result<Vec<Middleware>> {
    let entries: Vec<serde_json::Value> = match value {
        serde_json::Value::Array(arr) => arr,
        serde_json::Value::Object(map) => {
            let mut indexed = Vec::with_capacity(map.len());
            for (key, value) in map {
                // Reject non-numeric keys instead of dropping them, so a map-style list such
                // as `{"retry": {...}}` is reported rather than silently ignored.
                let index = key.parse::<usize>().map_err(|_| {
                    anyhow::anyhow!(
                        "Invalid middlewares key '{}': expected a numeric index when middlewares are given as an object",
                        key
                    )
                })?;
                indexed.push((index, value));
            }
            indexed.sort_by_key(|(index, _)| *index);
            indexed.into_iter().map(|(_, value)| value).collect()
        }
        // An explicitly empty/null `middlewares:` entry means "no middlewares".
        serde_json::Value::Null => Vec::new(),
        _ => return Err(anyhow::anyhow!("Expected an array or object")),
    };
    entries.into_iter().map(parse_middleware_entry).collect()
}

/// Parses one middleware list entry, which must be an object with exactly one key.
fn parse_middleware_entry(item: serde_json::Value) -> anyhow::Result<Middleware> {
    let name = match &item {
        serde_json::Value::Object(map) if map.len() == 1 => map.keys().next().cloned(),
        _ => None,
    };
    let name = match name {
        Some(name) => name,
        None => {
            return Err(anyhow::anyhow!(
                "Invalid middleware configuration: {:?}",
                item
            ))
        }
    };
    if is_known_middleware_name(&name) {
        return serde_json::from_value::<Middleware>(item).map_err(|e| {
            anyhow::anyhow!("Failed to deserialize known middleware '{}': {}", name, e)
        });
    }
    // Any other single key is passed through as a custom middleware with its raw config.
    let config = match item {
        serde_json::Value::Object(mut map) => map.remove(&name).unwrap_or_default(),
        // Unreachable: `name` is only `Some` when `item` is a single-key object.
        _ => serde_json::Value::Null,
    };
    Ok(Middleware::Custom { name, config })
}
/// All supported endpoint kinds.
///
/// Externally tagged with lowercase keys, e.g. `{"kafka": {...}}`, `{"mongodb": {...}}`,
/// `{"websocket": {...}}`.
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(rename_all = "lowercase")]
pub enum EndpointType {
Aws(AwsConfig),
Kafka(KafkaConfig),
Nats(NatsConfig),
File(FileConfig),
/// Configured with a single string value.
Static(String),
/// Configured with a single string value.
Ref(String),
Memory(MemoryConfig),
Sled(SledConfig),
Amqp(AmqpConfig),
MongoDb(MongoDbConfig),
Mqtt(MqttConfig),
Http(HttpConfig),
WebSocket(WebSocketConfig),
IbmMq(IbmMqConfig),
ZeroMq(ZeroMqConfig),
Grpc(GrpcConfig),
Sqlx(SqlxConfig),
/// A list of nested endpoints.
Fanout(Vec<Endpoint>),
Switch(SwitchConfig),
Response(ResponseConfig),
/// Wraps a single nested endpoint.
Reader(Box<Endpoint>),
/// An endpoint type not built into this enum, identified by `name` with an opaque JSON
/// `config`. The `Endpoint` deserializer also produces this for a single unrecognized key.
Custom {
name: String,
config: serde_json::Value,
},
/// The default variant.
#[default]
Null,
}
impl EndpointType {
pub fn name(&self) -> &'static str {
match self {
EndpointType::Aws(_) => "aws",
EndpointType::Kafka(_) => "kafka",
EndpointType::Nats(_) => "nats",
EndpointType::File(_) => "file",
EndpointType::Static(_) => "static",
EndpointType::Ref(_) => "ref",
EndpointType::Memory(_) => "memory",
EndpointType::Sled(_) => "sled",
EndpointType::Amqp(_) => "amqp",
EndpointType::MongoDb(_) => "mongodb",
EndpointType::Mqtt(_) => "mqtt",
EndpointType::Http(_) => "http",
EndpointType::WebSocket(_) => "websocket",
EndpointType::IbmMq(_) => "ibmmq",
EndpointType::ZeroMq(_) => "zeromq",
EndpointType::Grpc(_) => "grpc",
EndpointType::Sqlx(_) => "sqlx",
EndpointType::Fanout(_) => "fanout",
EndpointType::Switch(_) => "switch",
EndpointType::Response(_) => "response",
EndpointType::Reader(_) => "reader",
EndpointType::Custom { .. } => "custom",
EndpointType::Null => "null",
}
}
pub fn is_core(&self) -> bool {
matches!(
self,
EndpointType::File(_)
| EndpointType::Static(_)
| EndpointType::Ref(_)
| EndpointType::Memory(_)
| EndpointType::Fanout(_)
| EndpointType::Switch(_)
| EndpointType::Response(_)
| EndpointType::Reader(_)
| EndpointType::Custom { .. }
| EndpointType::Null
)
}
}
/// Middleware that can be attached to an [`Endpoint`].
///
/// Externally tagged with snake_case keys, e.g. `{"random_panic": {...}}`,
/// `{"cookie_jar": {...}}`.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(rename_all = "snake_case")]
pub enum Middleware {
Deduplication(DeduplicationMiddleware),
Metrics(MetricsMiddleware),
// Boxed, presumably to keep `Middleware` small since it embeds a whole `Endpoint`.
Dlq(Box<DeadLetterQueueMiddleware>),
Retry(RetryMiddleware),
RandomPanic(RandomPanicMiddleware),
Delay(DelayMiddleware),
WeakJoin(WeakJoinMiddleware),
Limiter(LimiterMiddleware),
Buffer(BufferMiddleware),
CookieJar(CookieJarMiddleware),
/// A middleware not built into this enum, identified by `name` with an opaque JSON
/// `config`.
Custom {
name: String,
config: serde_json::Value,
},
}
/// Settings for the `deduplication` middleware. Both fields are required.
// NOTE(review): presumably `sled_path` is the on-disk store and `ttl_seconds` the entry
// lifetime — confirm in the middleware implementation.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct DeduplicationMiddleware {
pub sled_path: String,
pub ttl_seconds: u64,
}
/// Settings for the `metrics` middleware; it takes no options.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct MetricsMiddleware {}
/// Settings for the `dlq` (dead-letter queue) middleware, holding its target [`Endpoint`].
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct DeadLetterQueueMiddleware {
pub endpoint: Endpoint,
}
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct RetryMiddleware {
#[serde(default = "default_retry_attempts")]
pub max_attempts: usize,
#[serde(default = "default_initial_interval_ms")]
pub initial_interval_ms: u64,
#[serde(default = "default_max_interval_ms")]
pub max_interval_ms: u64,
#[serde(default = "default_multiplier")]
pub multiplier: f64,
}
/// Settings for the `delay` middleware. `delay_ms` is required.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct DelayMiddleware {
pub delay_ms: u64,
}
/// Settings for the `limiter` middleware. `messages_per_second` is required.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct LimiterMiddleware {
pub messages_per_second: f64,
}
/// Settings for the `buffer` middleware. Both fields are required.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct BufferMiddleware {
pub max_messages: usize,
pub max_delay_ms: u64,
}
/// Settings for the `cookie_jar` middleware. Every field is optional in configuration.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct CookieJarMiddleware {
/// Defaults to `None`.
#[serde(default)]
pub shared_scope: Option<String>,
/// Defaults to `"cookie"`.
#[serde(default = "default_cookie_metadata_key")]
pub cookie_metadata_key: String,
/// Defaults to `"set-cookie"`.
#[serde(default = "default_set_cookie_metadata_key")]
pub set_cookie_metadata_key: String,
/// Defaults to an empty list.
#[serde(default)]
pub capture_metadata_keys: Vec<String>,
/// Defaults to `None`.
#[serde(default)]
pub export_metadata_prefix: Option<String>,
/// Defaults to an empty map.
#[serde(default)]
pub inject_metadata: HashMap<String, String>,
}
impl Default for CookieJarMiddleware {
fn default() -> Self {
Self {
shared_scope: None,
cookie_metadata_key: default_cookie_metadata_key(),
set_cookie_metadata_key: default_set_cookie_metadata_key(),
capture_metadata_keys: Vec::new(),
export_metadata_prefix: None,
inject_metadata: HashMap::new(),
}
}
}
/// Settings for the `weak_join` middleware. All fields are required.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct WeakJoinMiddleware {
pub group_by: String,
pub expected_count: usize,
pub timeout_ms: u64,
}
/// Fault behavior selected by the `random_panic` middleware's `mode` field.
///
/// Serialized in snake_case (e.g. `json_format_error`); defaults to `Panic`.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(rename_all = "snake_case")]
pub enum FaultMode {
#[default]
Panic,
Disconnect,
Timeout,
JsonFormatError,
Nack,
}
impl std::fmt::Display for FaultMode {
    /// Writes the snake_case name of the mode, matching its serde representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            FaultMode::Panic => "panic",
            FaultMode::Disconnect => "disconnect",
            FaultMode::Timeout => "timeout",
            FaultMode::JsonFormatError => "json_format_error",
            FaultMode::Nack => "nack",
        };
        f.write_str(label)
    }
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(deny_unknown_fields)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
pub struct RandomPanicMiddleware {
#[serde(default)]
pub mode: FaultMode,
#[cfg_attr(feature = "schema", schemars(range(min = 1)))]
#[serde(default)]
pub trigger_on_message: Option<usize>,
#[serde(default = "default_true")]
pub enabled: bool,
#[serde(skip, default = "default_atomic_usize_arc")]
#[cfg_attr(feature = "schema", schemars(skip))]
pub message_count: Arc<AtomicUsize>,
}
/// Serde default helper returning `true` (used, e.g., by `RandomPanicMiddleware::enabled`).
fn default_true() -> bool {
    true
}

/// Creates a new, unshared counter starting at zero.
fn default_atomic_usize_arc() -> Arc<AtomicUsize> {
    Arc::new(AtomicUsize::default())
}
/// Deserializes an optional boolean, mapping an explicit `null` to `false`.
fn deserialize_null_as_false<'de, D>(deserializer: D) -> Result<bool, D::Error>
where
    D: Deserializer<'de>,
{
    Option::<bool>::deserialize(deserializer).map(Option::unwrap_or_default)
}
/// Configuration for the `aws` endpoint type. All fields are optional.
///
/// Fields annotated with `format = "password"` are marked as such in the generated schema.
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct AwsConfig {
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub queue_url: Option<String>,
pub topic_arn: Option<String>,
pub region: Option<String>,
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub endpoint_url: Option<String>,
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub access_key: Option<String>,
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub secret_key: Option<String>,
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub session_token: Option<String>,
/// Schema range 1..=10 (not enforced by serde).
#[cfg_attr(feature = "schema", schemars(range(min = 1, max = 10)))]
pub max_messages: Option<i32>,
/// Schema range 0..=20 (not enforced by serde).
#[cfg_attr(feature = "schema", schemars(range(min = 0, max = 20)))]
pub wait_time_seconds: Option<i32>,
/// Defaults to `false`.
#[serde(default)]
pub binary_payload_mode: bool,
}
impl AwsConfig {
    /// Creates an empty configuration, identical to `AwsConfig::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets `queue_url`.
    pub fn with_queue_url(self, queue_url: impl Into<String>) -> Self {
        Self {
            queue_url: Some(queue_url.into()),
            ..self
        }
    }

    /// Sets `topic_arn`.
    pub fn with_topic_arn(self, topic_arn: impl Into<String>) -> Self {
        Self {
            topic_arn: Some(topic_arn.into()),
            ..self
        }
    }

    /// Sets `region`.
    pub fn with_region(self, region: impl Into<String>) -> Self {
        Self {
            region: Some(region.into()),
            ..self
        }
    }

    /// Sets `endpoint_url`.
    pub fn with_endpoint_url(self, endpoint_url: impl Into<String>) -> Self {
        Self {
            endpoint_url: Some(endpoint_url.into()),
            ..self
        }
    }

    /// Sets both `access_key` and `secret_key`.
    pub fn with_credentials(
        self,
        access_key: impl Into<String>,
        secret_key: impl Into<String>,
    ) -> Self {
        Self {
            access_key: Some(access_key.into()),
            secret_key: Some(secret_key.into()),
            ..self
        }
    }
}
/// Configuration for the `kafka` endpoint type.
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct KafkaConfig {
/// Required; may also be given under the alias `brokers`.
#[serde(alias = "brokers")]
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub url: String,
pub topic: Option<String>,
pub username: Option<String>,
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub password: Option<String>,
/// Defaults to `TlsConfig::default()`.
#[serde(default)]
pub tls: TlsConfig,
pub group_id: Option<String>,
/// Defaults to `false`.
#[serde(default)]
pub delayed_ack: bool,
/// Extra producer key/value options, in order; defaults to `None`.
#[serde(default)]
pub producer_options: Option<Vec<(String, String)>>,
/// Extra consumer key/value options, in order; defaults to `None`.
#[serde(default)]
pub consumer_options: Option<Vec<(String, String)>>,
}
impl KafkaConfig {
    /// Creates a configuration with the given `url`; all other fields take their defaults.
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            ..Self::default()
        }
    }

    /// Sets `topic`.
    pub fn with_topic(self, topic: impl Into<String>) -> Self {
        Self {
            topic: Some(topic.into()),
            ..self
        }
    }

    /// Sets `group_id`.
    pub fn with_group_id(self, group_id: impl Into<String>) -> Self {
        Self {
            group_id: Some(group_id.into()),
            ..self
        }
    }

    /// Sets both `username` and `password`.
    pub fn with_credentials(
        self,
        username: impl Into<String>,
        password: impl Into<String>,
    ) -> Self {
        Self {
            username: Some(username.into()),
            password: Some(password.into()),
            ..self
        }
    }

    /// Appends one producer option, creating the list if it was `None`.
    pub fn with_producer_option(
        mut self,
        key: impl Into<String>,
        value: impl Into<String>,
    ) -> Self {
        self.producer_options
            .get_or_insert_with(Vec::new)
            .push((key.into(), value.into()));
        self
    }

    /// Appends one consumer option, creating the list if it was `None`.
    pub fn with_consumer_option(
        mut self,
        key: impl Into<String>,
        value: impl Into<String>,
    ) -> Self {
        self.consumer_options
            .get_or_insert_with(Vec::new)
            .push((key.into(), value.into()));
        self
    }
}
/// Configuration for the `sled` endpoint type. `path` is required.
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct SledConfig {
pub path: String,
pub tree: Option<String>,
/// Defaults to `false`.
#[serde(default)]
pub read_from_start: bool,
/// Defaults to `false`.
#[serde(default)]
pub delete_after_read: bool,
}
impl SledConfig {
    /// Creates a configuration for the given `path`; all other fields take their defaults.
    pub fn new(path: impl Into<String>) -> Self {
        Self {
            path: path.into(),
            ..Self::default()
        }
    }

    /// Sets `tree`.
    pub fn with_tree(self, tree: impl Into<String>) -> Self {
        Self {
            tree: Some(tree.into()),
            ..self
        }
    }

    /// Sets `read_from_start`.
    pub fn with_read_from_start(self, read_from_start: bool) -> Self {
        Self {
            read_from_start,
            ..self
        }
    }
}
/// Payload format of a file endpoint; serialized in snake_case, defaults to `Normal`.
#[derive(Debug, Deserialize, Serialize, Clone, Default, PartialEq, Eq)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(rename_all = "snake_case")]
pub enum FileFormat {
#[default]
Normal,
Json,
Text,
Raw,
}
/// Configuration for the `file` endpoint type. `path` is required.
///
/// `mode` is flattened: its `mode` tag and variant fields sit beside `path`, e.g.
/// `{"path": "...", "mode": "consume", "delete": true}`. Unlike most configs here, this
/// struct does not use `deny_unknown_fields`.
// NOTE(review): with serde, a flattened `Option` that fails to parse becomes `None`, so a
// malformed mode may silently fall back to `effective_mode()`'s default — confirm intended.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
pub struct FileConfig {
pub path: String,
pub delimiter: Option<String>,
#[serde(flatten, default)]
pub mode: Option<FileConsumerMode>,
/// Defaults to [`FileFormat::Normal`].
#[serde(default)]
pub format: FileFormat,
}
/// Consumption mode of a file endpoint, internally tagged by a `mode` key with values
/// `consume`, `subscribe` or `group_subscribe`. `Default` is `Consume { delete: false }`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "mode", rename_all = "snake_case")]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
pub enum FileConsumerMode {
Consume {
#[serde(default)]
delete: bool,
},
Subscribe {
#[serde(default)]
delete: bool,
},
GroupSubscribe {
group_id: String,
#[serde(default)]
read_from_tail: bool,
},
}
impl Default for FileConsumerMode {
fn default() -> Self {
Self::Consume { delete: false }
}
}
impl FileConfig {
    /// Creates a configuration for `path` with an explicit default mode
    /// (`Consume { delete: false }`), no delimiter and the default format.
    pub fn new(path: impl Into<String>) -> Self {
        Self {
            path: path.into(),
            delimiter: None,
            mode: Some(FileConsumerMode::default()),
            format: FileFormat::default(),
        }
    }

    /// Sets `mode`.
    pub fn with_mode(self, mode: FileConsumerMode) -> Self {
        Self {
            mode: Some(mode),
            ..self
        }
    }

    /// Returns the configured mode, or [`FileConsumerMode::default`] when none is set.
    pub fn effective_mode(&self) -> FileConsumerMode {
        match &self.mode {
            Some(mode) => mode.clone(),
            None => FileConsumerMode::default(),
        }
    }
}
/// Configuration for the `nats` endpoint type. `url` is required.
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct NatsConfig {
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub url: String,
pub subject: Option<String>,
pub stream: Option<String>,
pub username: Option<String>,
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub password: Option<String>,
/// Defaults to `TlsConfig::default()`.
#[serde(default)]
pub tls: TlsConfig,
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub token: Option<String>,
/// Defaults to `false`.
#[serde(default)]
pub request_reply: bool,
pub request_timeout_ms: Option<u64>,
/// Defaults to `false`.
#[serde(default)]
pub delayed_ack: bool,
/// Defaults to `false`.
#[serde(default)]
pub no_jetstream: bool,
/// Defaults to `false`.
#[serde(default)]
pub subscriber_mode: bool,
pub stream_max_messages: Option<i64>,
pub deliver_policy: Option<NatsDeliverPolicy>,
pub stream_max_bytes: Option<i64>,
pub prefetch_count: Option<usize>,
}
/// Deliver policy for NATS consumers; serialized in snake_case, defaults to `All`.
#[derive(Debug, Deserialize, Serialize, Clone, Default, PartialEq, Eq)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(rename_all = "snake_case")]
pub enum NatsDeliverPolicy {
#[default]
All,
Last,
New,
LastPerSubject,
}
impl NatsConfig {
    /// Creates a configuration with the given `url`; all other fields take their defaults.
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            ..Self::default()
        }
    }

    /// Sets `subject`.
    pub fn with_subject(self, subject: impl Into<String>) -> Self {
        Self {
            subject: Some(subject.into()),
            ..self
        }
    }

    /// Sets `stream`.
    pub fn with_stream(self, stream: impl Into<String>) -> Self {
        Self {
            stream: Some(stream.into()),
            ..self
        }
    }

    /// Sets `deliver_policy`.
    pub fn with_deliver_policy(self, policy: NatsDeliverPolicy) -> Self {
        Self {
            deliver_policy: Some(policy),
            ..self
        }
    }

    /// Sets both `username` and `password`.
    pub fn with_credentials(
        self,
        username: impl Into<String>,
        password: impl Into<String>,
    ) -> Self {
        Self {
            username: Some(username.into()),
            password: Some(password.into()),
            ..self
        }
    }
}
/// Configuration for the `memory` endpoint type. `topic` is required.
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct MemoryConfig {
pub topic: String,
pub capacity: Option<usize>,
/// Defaults to `false`.
#[serde(default)]
pub request_reply: bool,
pub request_timeout_ms: Option<u64>,
/// Defaults to `false`.
#[serde(default)]
pub subscribe_mode: bool,
/// Defaults to `false`.
#[serde(default)]
pub enable_nack: bool,
}
impl MemoryConfig {
    /// Creates a configuration for `topic` with the given `capacity`; all other fields take
    /// their defaults.
    pub fn new(topic: impl Into<String>, capacity: Option<usize>) -> Self {
        Self {
            topic: topic.into(),
            capacity,
            ..Self::default()
        }
    }

    /// Sets `subscribe_mode`.
    pub fn with_subscribe(mut self, subscribe_mode: bool) -> Self {
        self.subscribe_mode = subscribe_mode;
        self
    }

    /// Sets `request_reply`.
    pub fn with_request_reply(self, request_reply: bool) -> Self {
        Self {
            request_reply,
            ..self
        }
    }
}
/// Configuration for the `amqp` endpoint type. `url` is required.
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct AmqpConfig {
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub url: String,
pub queue: Option<String>,
/// Defaults to `false`.
#[serde(default)]
pub subscribe_mode: bool,
pub username: Option<String>,
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub password: Option<String>,
/// Defaults to `TlsConfig::default()`.
#[serde(default)]
pub tls: TlsConfig,
pub exchange: Option<String>,
pub prefetch_count: Option<u16>,
/// Defaults to `false`.
#[serde(default)]
pub no_persistence: bool,
/// Defaults to `false`.
#[serde(default)]
pub no_declare_queue: bool,
/// Defaults to `false`.
#[serde(default)]
pub delayed_ack: bool,
}
impl AmqpConfig {
    /// Creates a configuration with the given `url`; all other fields take their defaults.
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            ..Self::default()
        }
    }

    /// Sets `queue`.
    pub fn with_queue(self, queue: impl Into<String>) -> Self {
        Self {
            queue: Some(queue.into()),
            ..self
        }
    }

    /// Sets `exchange`.
    pub fn with_exchange(self, exchange: impl Into<String>) -> Self {
        Self {
            exchange: Some(exchange.into()),
            ..self
        }
    }

    /// Sets both `username` and `password`.
    pub fn with_credentials(
        self,
        username: impl Into<String>,
        password: impl Into<String>,
    ) -> Self {
        Self {
            username: Some(username.into()),
            password: Some(password.into()),
            ..self
        }
    }
}
/// Payload format of a MongoDB endpoint; serialized in lowercase, defaults to `Normal`.
#[derive(Debug, Deserialize, Serialize, Clone, Default, PartialEq, Eq)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(rename_all = "lowercase")]
pub enum MongoDbFormat {
#[default]
Normal,
Json,
Text,
Raw,
}
/// Configuration for the `mongodb` endpoint type. `url` and `database` are required.
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct MongoDbConfig {
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub url: String,
pub collection: Option<String>,
pub username: Option<String>,
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub password: Option<String>,
/// Defaults to `TlsConfig::default()`.
#[serde(default)]
pub tls: TlsConfig,
pub database: String,
pub polling_interval_ms: Option<u64>,
pub reply_polling_ms: Option<u64>,
/// Defaults to `false`.
#[serde(default)]
pub request_reply: bool,
/// Defaults to `false`.
#[serde(default)]
pub change_stream: bool,
pub request_timeout_ms: Option<u64>,
pub ttl_seconds: Option<u64>,
pub capped_size_bytes: Option<i64>,
/// Defaults to [`MongoDbFormat::Normal`].
#[serde(default)]
pub format: MongoDbFormat,
pub cursor_id: Option<String>,
pub receive_query: Option<String>,
pub meta_collection: Option<String>,
}
impl MongoDbConfig {
    /// Creates a configuration for `url` and `database`; all other fields take their
    /// defaults.
    pub fn new(url: impl Into<String>, database: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            database: database.into(),
            ..Self::default()
        }
    }

    /// Sets `collection`.
    pub fn with_collection(self, collection: impl Into<String>) -> Self {
        Self {
            collection: Some(collection.into()),
            ..self
        }
    }

    /// Sets both `username` and `password`.
    pub fn with_credentials(
        self,
        username: impl Into<String>,
        password: impl Into<String>,
    ) -> Self {
        Self {
            username: Some(username.into()),
            password: Some(password.into()),
            ..self
        }
    }

    /// Sets `change_stream`.
    pub fn with_change_stream(self, change_stream: bool) -> Self {
        Self {
            change_stream,
            ..self
        }
    }
}
/// Configuration for the `mqtt` endpoint type. `url` is required.
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct MqttConfig {
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub url: String,
pub topic: Option<String>,
pub username: Option<String>,
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub password: Option<String>,
/// Defaults to `TlsConfig::default()`.
#[serde(default)]
pub tls: TlsConfig,
pub client_id: Option<String>,
pub queue_capacity: Option<usize>,
pub max_inflight: Option<u16>,
pub qos: Option<u8>,
/// Defaults to `false`.
#[serde(default = "default_clean_session")]
pub clean_session: bool,
pub keep_alive_seconds: Option<u64>,
/// Defaults to [`MqttProtocol::V5`].
#[serde(default)]
pub protocol: MqttProtocol,
pub session_expiry_interval: Option<u32>,
/// Defaults to `false`.
#[serde(default)]
pub delayed_ack: bool,
}
impl MqttConfig {
    /// Creates a configuration with the given `url`; all other fields take their defaults.
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            ..Self::default()
        }
    }

    /// Sets `topic`.
    pub fn with_topic(self, topic: impl Into<String>) -> Self {
        Self {
            topic: Some(topic.into()),
            ..self
        }
    }

    /// Sets `client_id`.
    pub fn with_client_id(self, client_id: impl Into<String>) -> Self {
        Self {
            client_id: Some(client_id.into()),
            ..self
        }
    }

    /// Sets both `username` and `password`.
    pub fn with_credentials(
        self,
        username: impl Into<String>,
        password: impl Into<String>,
    ) -> Self {
        Self {
            username: Some(username.into()),
            password: Some(password.into()),
            ..self
        }
    }
}
/// MQTT protocol version; serialized in lowercase (`v5`, `v3`), defaults to `V5`.
#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq, Eq)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(rename_all = "lowercase")]
pub enum MqttProtocol {
#[default]
V5,
V3,
}
/// Configuration for the `zeromq` endpoint type. `url` is required.
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct ZeroMqConfig {
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub url: String,
/// Defaults to `None`.
#[serde(default)]
pub socket_type: Option<ZeroMqSocketType>,
pub topic: Option<String>,
/// Defaults to `false`.
#[serde(default)]
pub bind: bool,
/// Defaults to `None`.
#[serde(default)]
pub internal_buffer_size: Option<usize>,
}
impl ZeroMqConfig {
    /// Creates a configuration with the given `url`; all other fields take their defaults.
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            ..Self::default()
        }
    }

    /// Sets `socket_type`.
    pub fn with_socket_type(self, socket_type: ZeroMqSocketType) -> Self {
        Self {
            socket_type: Some(socket_type),
            ..self
        }
    }

    /// Sets `bind`.
    pub fn with_bind(self, bind: bool) -> Self {
        Self { bind, ..self }
    }
}
/// ZeroMQ socket type; serialized in lowercase (e.g. `push`, `sub`). Has no default.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(rename_all = "lowercase")]
pub enum ZeroMqSocketType {
Push,
Pull,
Pub,
Sub,
Req,
Rep,
}
/// Configuration for the `grpc` endpoint type. `url` is required.
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct GrpcConfig {
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub url: String,
pub topic: Option<String>,
pub timeout_ms: Option<u64>,
/// Defaults to `TlsConfig::default()`.
#[serde(default)]
pub tls: TlsConfig,
/// Defaults to `false`.
#[serde(default)]
pub server_mode: bool,
#[serde(default)]
pub initial_stream_window_size: Option<u32>,
#[serde(default)]
pub initial_connection_window_size: Option<u32>,
#[serde(default)]
pub concurrency_limit_per_connection: Option<usize>,
#[serde(default)]
pub http2_keepalive_interval_ms: Option<u64>,
#[serde(default)]
pub http2_keepalive_timeout_ms: Option<u64>,
#[serde(default)]
pub max_decoding_message_size: Option<usize>,
}
impl GrpcConfig {
    /// Creates a configuration with the given `url`; all other fields take their defaults.
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            ..Self::default()
        }
    }

    /// Sets `topic`.
    pub fn with_topic(self, topic: impl Into<String>) -> Self {
        Self {
            topic: Some(topic.into()),
            ..self
        }
    }

    /// Sets `server_mode`.
    pub fn with_server_mode(self, server_mode: bool) -> Self {
        Self {
            server_mode,
            ..self
        }
    }
}
/// Configuration for the `http` endpoint type. `url` is required.
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct HttpConfig {
pub url: String,
pub path: Option<String>,
pub method: Option<String>,
/// Defaults to `TlsConfig::default()`.
#[serde(default)]
pub tls: TlsConfig,
pub workers: Option<usize>,
pub message_id_header: Option<String>,
pub request_timeout_ms: Option<u64>,
pub internal_buffer_size: Option<usize>,
/// Defaults to `false`.
#[serde(default)]
pub fire_and_forget: bool,
/// Omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub batch_concurrency: Option<usize>,
/// Omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tcp_keepalive_ms: Option<u64>,
/// Omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub pool_idle_timeout_ms: Option<u64>,
/// Defaults to `false`.
#[serde(default)]
pub compression_enabled: bool,
#[serde(default)]
pub compression_threshold_bytes: Option<usize>,
pub concurrency_limit: Option<usize>,
/// Optional `(username, password)` pair. Accepted as a two-element array, as an object
/// with string keys `"0"` and `"1"`, or as null (see `deserialize_basic_auth`). Omitted from
/// serialized output when `None`.
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
#[serde(
default,
skip_serializing_if = "Option::is_none",
deserialize_with = "deserialize_basic_auth"
)]
pub basic_auth: Option<(String, String)>,
/// Extra request headers; omitted from serialized output when empty.
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub custom_headers: HashMap<String, String>,
}
/// Configuration for the `websocket` endpoint type. `url` is required.
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct WebSocketConfig {
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub url: String,
pub path: Option<String>,
pub message_id_header: Option<String>,
pub internal_buffer_size: Option<usize>,
}
fn deserialize_basic_auth<'de, D>(deserializer: D) -> Result<Option<(String, String)>, D::Error>
where
D: Deserializer<'de>,
{
let val = serde_json::Value::deserialize(deserializer)?;
match val {
serde_json::Value::Null => Ok(None),
serde_json::Value::Array(arr) => {
if arr.len() != 2 {
return Err(serde::de::Error::custom("basic_auth must have 2 elements"));
}
let u = arr[0]
.as_str()
.ok_or_else(|| serde::de::Error::custom("basic_auth[0] must be string"))?
.to_string();
let p = arr[1]
.as_str()
.ok_or_else(|| serde::de::Error::custom("basic_auth[1] must be string"))?
.to_string();
Ok(Some((u, p)))
}
serde_json::Value::Object(map) => {
let u = map
.get("0")
.and_then(|v| v.as_str())
.ok_or_else(|| serde::de::Error::custom("basic_auth map missing '0'"))?
.to_string();
let p = map
.get("1")
.and_then(|v| v.as_str())
.ok_or_else(|| serde::de::Error::custom("basic_auth map missing '1'"))?
.to_string();
Ok(Some((u, p)))
}
_ => Err(serde::de::Error::custom("invalid type for basic_auth")),
}
}
impl HttpConfig {
    /// Creates a configuration with the given `url`; all other fields take their defaults.
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            ..Self::default()
        }
    }

    /// Sets `workers`.
    pub fn with_workers(self, workers: usize) -> Self {
        Self {
            workers: Some(workers),
            ..self
        }
    }

    /// Sets `method`.
    pub fn with_method(self, method: impl Into<String>) -> Self {
        Self {
            method: Some(method.into()),
            ..self
        }
    }

    /// Sets `path`.
    pub fn with_path(self, path: impl Into<String>) -> Self {
        Self {
            path: Some(path.into()),
            ..self
        }
    }
}
impl WebSocketConfig {
    /// Creates a configuration with the given `url`; all other fields take their defaults.
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            ..Self::default()
        }
    }

    /// Sets `path`.
    pub fn with_path(self, path: impl Into<String>) -> Self {
        Self {
            path: Some(path.into()),
            ..self
        }
    }
}
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct IbmMqConfig {
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub url: String,
pub queue: Option<String>,
pub topic: Option<String>,
pub queue_manager: String,
pub channel: String,
pub username: Option<String>,
#[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
pub password: Option<String>,
pub cipher_spec: Option<String>,
#[serde(default)]
pub tls: TlsConfig,
#[serde(default = "default_max_message_size")]
pub max_message_size: usize,
#[serde(default = "default_wait_timeout_ms")]
pub wait_timeout_ms: i32,
#[serde(default)]
pub internal_buffer_size: Option<usize>,
#[serde(default)]
pub disable_status_inq: bool,
}
impl IbmMqConfig {
    /// Builds an IBM MQ configuration from the three required connection parameters.
    ///
    /// Every other field, including `disable_status_inq` (`false`), comes from `Default`.
    pub fn new(
        url: impl Into<String>,
        queue_manager: impl Into<String>,
        channel: impl Into<String>,
    ) -> Self {
        Self {
            url: url.into(),
            queue_manager: queue_manager.into(),
            channel: channel.into(),
            ..Self::default()
        }
    }

    /// Sets the queue name.
    pub fn with_queue(self, queue: impl Into<String>) -> Self {
        Self {
            queue: Some(queue.into()),
            ..self
        }
    }

    /// Sets the topic name.
    pub fn with_topic(self, topic: impl Into<String>) -> Self {
        Self {
            topic: Some(topic.into()),
            ..self
        }
    }

    /// Sets both the username and the password.
    pub fn with_credentials(
        self,
        username: impl Into<String>,
        password: impl Into<String>,
    ) -> Self {
        Self {
            username: Some(username.into()),
            password: Some(password.into()),
            ..self
        }
    }
}
/// Serde default for `IbmMqConfig::max_message_size`: 4 MiB.
fn default_max_message_size() -> usize {
    4 << 20
}
/// Serde default for `IbmMqConfig::wait_timeout_ms`: one second, in milliseconds.
fn default_wait_timeout_ms() -> i32 {
    1_000
}
// Selects among several endpoints, presumably by the value found under `metadata_key`
// in each message's metadata — confirm against the switch endpoint implementation.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct SwitchConfig {
    // Metadata key whose value picks the case.
    pub metadata_key: String,
    // Case value -> endpoint. Secret extraction visits each case under
    // `<prefix>__SWITCH__CASES__<sanitized case>`.
    pub cases: HashMap<String, Endpoint>,
    // Optional fallback endpoint, boxed because `Endpoint` nests recursively.
    pub default: Option<Box<Endpoint>>,
}
// Option-less configuration marker. Because of `deny_unknown_fields` and the absence of
// fields, any key supplied under it is rejected during deserialization.
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct ResponseConfig {
}
// Settings for an SQL database endpoint backed by sqlx.
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct SqlxConfig {
    // Database URL. Marked as a password-format field in the JSON schema, and moved into
    // the extracted secrets when it embeds userinfo.
    #[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
    pub url: String,
    // Optional credentials; both are moved into the extracted secrets.
    #[serde(default)]
    pub username: Option<String>,
    #[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
    #[serde(default)]
    pub password: Option<String>,
    // Table name; required when deserializing.
    pub table: String,
    // Optional query overrides — presumably replacing generated INSERT/SELECT statements;
    // confirm in the sqlx endpoint.
    pub insert_query: Option<String>,
    pub select_query: Option<String>,
    // Both default to `false`.
    #[serde(default)]
    pub delete_after_read: bool,
    #[serde(default)]
    pub auto_create_table: bool,
    // Optional polling interval in milliseconds.
    pub polling_interval_ms: Option<u64>,
    #[serde(default)]
    pub tls: TlsConfig,
    // Optional connection-pool tuning (presumably passed to the sqlx pool options; the
    // `_ms` fields are milliseconds by name) — confirm in the endpoint code.
    pub max_connections: Option<u32>,
    pub min_connections: Option<u32>,
    pub acquire_timeout_ms: Option<u64>,
    pub idle_timeout_ms: Option<u64>,
    pub max_lifetime_ms: Option<u64>,
}
// TLS / mTLS settings shared by the broker configurations.
#[derive(Debug, Deserialize, Serialize, Clone, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct TlsConfig {
    // Whether TLS is enabled. Defaults to `false`. The custom deserializer is defined
    // elsewhere; by its name it presumably maps an explicit `null` to `false` — confirm.
    #[serde(default, deserialize_with = "deserialize_null_as_false")]
    pub required: bool,
    // File paths for the CA bundle, client/server certificate and private key.
    pub ca_file: Option<String>,
    pub cert_file: Option<String>,
    pub key_file: Option<String>,
    // Password for the certificate/key material; moved into the extracted secrets.
    #[cfg_attr(feature = "schema", schemars(extend("format"="password")))]
    pub cert_password: Option<String>,
    // Defaults to `false`; when `true`, presumably disables certificate verification.
    #[serde(default)]
    pub accept_invalid_certs: bool,
}
impl TlsConfig {
    /// Returns a configuration with every option at its default (TLS not required).
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the CA bundle path and marks TLS as required.
    pub fn with_ca_file(self, ca_file: impl Into<String>) -> Self {
        Self {
            ca_file: Some(ca_file.into()),
            required: true,
            ..self
        }
    }

    /// Sets the certificate and private-key paths and marks TLS as required.
    pub fn with_client_cert(
        self,
        cert_file: impl Into<String>,
        key_file: impl Into<String>,
    ) -> Self {
        Self {
            cert_file: Some(cert_file.into()),
            key_file: Some(key_file.into()),
            required: true,
            ..self
        }
    }

    /// Sets `accept_invalid_certs`; does not touch `required`.
    pub fn with_insecure(self, accept_invalid_certs: bool) -> Self {
        Self {
            accept_invalid_certs,
            ..self
        }
    }

    /// `true` when TLS is required and both a certificate and a key file are set.
    pub fn is_mtls_client_configured(&self) -> bool {
        matches!(
            (self.required, &self.cert_file, &self.key_file),
            (true, Some(_), Some(_))
        )
    }

    /// `true` when TLS is required and both a certificate and a key file are set
    /// (the same requirements as [`Self::is_mtls_client_configured`]).
    pub fn is_tls_server_configured(&self) -> bool {
        self.is_mtls_client_configured()
    }

    /// `true` when TLS is required, a CA file is set, or a certificate/key pair is set.
    pub fn is_tls_client_configured(&self) -> bool {
        let has_cert_pair = self.cert_file.is_some() && self.key_file.is_some();
        self.required || self.ca_file.is_some() || has_cert_pair
    }

    /// Returns `url` unchanged if it already starts with `http://` or `https://`
    /// (case-insensitive); otherwise prefixes `https://` when TLS is required and
    /// `http://` when it is not.
    pub fn normalize_url(&self, url: &str) -> String {
        let starts_with_scheme = |scheme: &str| {
            url.get(..scheme.len())
                .is_some_and(|head| head.eq_ignore_ascii_case(scheme))
        };
        if starts_with_scheme("http://") || starts_with_scheme("https://") {
            return url.to_string();
        }
        let scheme = if self.required { "https" } else { "http" };
        format!("{scheme}://{url}")
    }
}
/// Moves secret values out of a configuration value into a flat key/value map.
///
/// Implementations in this file clear sensitive fields in place (via `Option::take` or
/// `std::mem::take`) and insert them into `secrets`. Keys are built from `prefix` plus
/// `__`-separated upper-case segments, e.g. `MQB__TEST_ROUTE__INPUT__KAFKA__PASSWORD`.
pub trait SecretExtractor {
    /// Extracts this value's secrets into `secrets`, scoping every key under `prefix`.
    fn extract_secrets(&mut self, prefix: &str, secrets: &mut HashMap<String, String>);
}
/// Moves entries whose key looks secret-bearing out of `values` and into `secrets`.
///
/// A key counts as sensitive when its ASCII-lowercased form contains any of the markers
/// below. Examples: `X-API-Key`, `Authorization`, `X-Access-Token`, `X-Api-Secret`,
/// `Cookie`. Each extracted entry is stored under the sanitized key
/// `<prefix>__<field_name>__<key>`. Non-matching entries (e.g. `X-Trace-Id`) stay in place.
///
/// NOTE(review): two map keys that sanitize to the same name (e.g. `X-Api-Key` and
/// `x_api_key`) collide in `secrets`, and the later insert wins.
fn extract_sensitive_string_map_entries(
    values: &mut HashMap<String, String>,
    prefix: &str,
    field_name: &str,
    secrets: &mut HashMap<String, String>,
) {
    // Case-insensitive substrings marking a key as secret-bearing. "auth" covers
    // Authorization / Proxy-Authorization; "cookie" covers session cookies.
    const SENSITIVE_MARKERS: [&str; 8] = [
        "key",
        "token",
        "auth",
        "secret",
        "password",
        "passwd",
        "credential",
        "cookie",
    ];
    // Collect first: entries cannot be removed while iterating over `values.keys()`.
    let secret_keys: Vec<String> = values
        .keys()
        .filter(|key| {
            let lowered = key.to_ascii_lowercase();
            SENSITIVE_MARKERS
                .iter()
                .any(|marker| lowered.contains(*marker))
        })
        .cloned()
        .collect();
    for key in secret_keys {
        if let Some(value) = values.remove(&key) {
            secrets.insert(
                sanitize_secret_key(&format!("{prefix}__{field_name}__{key}")),
                value,
            );
        }
    }
}
/// Returns `true` when the authority component of `url` carries userinfo (`user[:pass]@`).
///
/// Only the authority is inspected, so `@` in the path, query or fragment
/// (e.g. `https://host/path/user@example.com`) does not count. Where the authority starts:
///
/// - after `scheme://`, when `url` begins with a syntactically valid RFC 3986 scheme;
/// - after a leading `//`, for network-path references;
/// - at the start of the string otherwise.
///
/// The last case means scheme-less broker addresses such as `user:pass@localhost:9092`
/// are detected too. A `://` that appears later, e.g. inside a query string, is not
/// mistaken for the scheme separator.
fn url_has_userinfo(url: &str) -> bool {
    // RFC 3986 §3.1: scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )
    let is_scheme = |candidate: &str| {
        let mut chars = candidate.chars();
        chars.next().is_some_and(|first| first.is_ascii_alphabetic())
            && chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
    };
    let after_scheme = match url.find("://") {
        Some(idx) if is_scheme(&url[..idx]) => &url[idx + 3..],
        // No valid scheme: any later "://" belongs to the path/query, not to the scheme.
        _ => url.strip_prefix("//").unwrap_or(url),
    };
    let authority_end = after_scheme
        .find(['/', '?', '#'])
        .unwrap_or(after_scheme.len());
    after_scheme[..authority_end].contains('@')
}
/// Turns an arbitrary string into an environment-variable-style key.
///
/// ASCII letters are upper-cased, and ASCII digits and `_` are kept. Every other character
/// (including each non-ASCII character) becomes a single `_`.
fn sanitize_secret_key(key: &str) -> String {
    let mut sanitized = String::with_capacity(key.len());
    for c in key.chars() {
        match c.to_ascii_uppercase() {
            kept @ ('A'..='Z' | '0'..='9' | '_') => sanitized.push(kept),
            _ => sanitized.push('_'),
        }
    }
    sanitized
}
/// Moves `url` into `secrets` (leaving an empty string behind) when it embeds userinfo.
///
/// The secret key is the sanitized form of `<prefix>__<field_name>`. Empty URLs and URLs
/// without userinfo are left untouched.
fn extract_sensitive_url(
    url: &mut String,
    prefix: &str,
    field_name: &str,
    secrets: &mut HashMap<String, String>,
) {
    if url.is_empty() || !url_has_userinfo(url) {
        return;
    }
    let secret_key = sanitize_secret_key(&format!("{prefix}__{field_name}"));
    secrets.insert(secret_key, std::mem::take(url));
}
fn extract_sensitive_optional_url(
url: &mut Option<String>,
prefix: &str,
field_name: &str,
secrets: &mut HashMap<String, String>,
) {
if url.as_ref().is_some_and(|url| url_has_userinfo(url)) {
if let Some(url) = url.take() {
secrets.insert(
sanitize_secret_key(&format!("{}__{}", prefix, field_name)),
url,
);
}
}
}
impl SecretExtractor for Route {
    /// Extracts secrets from the input endpoint (`<prefix>__INPUT`) and then from the
    /// output endpoint (`<prefix>__OUTPUT`).
    fn extract_secrets(&mut self, prefix: &str, secrets: &mut HashMap<String, String>) {
        let input_prefix = format!("{prefix}__INPUT");
        self.input.extract_secrets(&input_prefix, secrets);
        let output_prefix = format!("{prefix}__OUTPUT");
        self.output.extract_secrets(&output_prefix, secrets);
    }
}
impl SecretExtractor for Endpoint {
    /// Extracts secrets from each middleware (scoped as `<prefix>__MIDDLEWARES__<index>`),
    /// then from the endpoint's own type configuration under `prefix`.
    fn extract_secrets(&mut self, prefix: &str, secrets: &mut HashMap<String, String>) {
        self.middlewares
            .iter_mut()
            .enumerate()
            .for_each(|(index, middleware)| {
                let scoped = format!("{prefix}__MIDDLEWARES__{index}");
                middleware.extract_secrets(&scoped, secrets);
            });
        self.endpoint_type.extract_secrets(prefix, secrets);
    }
}
impl SecretExtractor for EndpointType {
    /// Dispatches secret extraction to the configuration held by the active variant.
    ///
    /// Transport configs are visited under `<prefix>__<VARIANT>`. Composite endpoints
    /// recurse into their children:
    ///
    /// - fan-out children under `FANOUT__<i>`;
    /// - switch cases under `SWITCH__CASES__<sanitized case>`;
    /// - the switch fallback under `SWITCH__DEFAULT`;
    /// - the reader's inner endpoint under `READER`.
    ///
    /// Variants not listed contribute no secrets.
    fn extract_secrets(&mut self, prefix: &str, secrets: &mut HashMap<String, String>) {
        let scoped = |segment: &str| format!("{prefix}__{segment}");
        match self {
            EndpointType::Aws(cfg) => cfg.extract_secrets(&scoped("AWS"), secrets),
            EndpointType::Kafka(cfg) => cfg.extract_secrets(&scoped("KAFKA"), secrets),
            EndpointType::Nats(cfg) => cfg.extract_secrets(&scoped("NATS"), secrets),
            EndpointType::Amqp(cfg) => cfg.extract_secrets(&scoped("AMQP"), secrets),
            EndpointType::MongoDb(cfg) => cfg.extract_secrets(&scoped("MONGODB"), secrets),
            EndpointType::Mqtt(cfg) => cfg.extract_secrets(&scoped("MQTT"), secrets),
            EndpointType::Http(cfg) => cfg.extract_secrets(&scoped("HTTP"), secrets),
            EndpointType::WebSocket(cfg) => cfg.extract_secrets(&scoped("WEBSOCKET"), secrets),
            EndpointType::IbmMq(cfg) => cfg.extract_secrets(&scoped("IBMMQ"), secrets),
            EndpointType::ZeroMq(cfg) => cfg.extract_secrets(&scoped("ZEROMQ"), secrets),
            EndpointType::Sqlx(cfg) => cfg.extract_secrets(&scoped("SQLX"), secrets),
            EndpointType::Grpc(cfg) => cfg.extract_secrets(&scoped("GRPC"), secrets),
            EndpointType::Fanout(endpoints) => {
                for (index, child) in endpoints.iter_mut().enumerate() {
                    child.extract_secrets(&scoped(&format!("FANOUT__{index}")), secrets);
                }
            }
            EndpointType::Switch(cfg) => {
                for (case, child) in cfg.cases.iter_mut() {
                    let case_prefix =
                        scoped(&format!("SWITCH__CASES__{}", sanitize_secret_key(case)));
                    child.extract_secrets(&case_prefix, secrets);
                }
                if let Some(fallback) = cfg.default.as_mut() {
                    fallback.extract_secrets(&scoped("SWITCH__DEFAULT"), secrets);
                }
            }
            EndpointType::Reader(inner) => inner.extract_secrets(&scoped("READER"), secrets),
            _ => {}
        }
    }
}
impl SecretExtractor for Middleware {
    /// Only the DLQ middleware carries secrets: its target endpoint is visited under
    /// `<prefix>__DLQ__ENDPOINT`. Every other middleware is a no-op.
    fn extract_secrets(&mut self, prefix: &str, secrets: &mut HashMap<String, String>) {
        let Middleware::Dlq(dlq) = self else {
            return;
        };
        dlq.endpoint
            .extract_secrets(&format!("{prefix}__DLQ__ENDPOINT"), secrets);
    }
}
impl SecretExtractor for AwsConfig {
    /// Moves AWS credentials into `secrets`:
    ///
    /// - the access key, secret key and session token, always;
    /// - the queue and endpoint URLs, only when they embed userinfo.
    fn extract_secrets(&mut self, prefix: &str, secrets: &mut HashMap<String, String>) {
        let credentials = [
            ("ACCESS_KEY", self.access_key.take()),
            ("SECRET_KEY", self.secret_key.take()),
            ("SESSION_TOKEN", self.session_token.take()),
        ];
        for (field, value) in credentials {
            if let Some(value) = value {
                secrets.insert(format!("{prefix}__{field}"), value);
            }
        }
        extract_sensitive_optional_url(&mut self.queue_url, prefix, "QUEUE_URL", secrets);
        extract_sensitive_optional_url(&mut self.endpoint_url, prefix, "ENDPOINT_URL", secrets);
    }
}
impl SecretExtractor for KafkaConfig {
    /// Moves Kafka secrets into `secrets`:
    ///
    /// - the URL, only when it embeds userinfo;
    /// - `username` and `password`;
    /// - the TLS certificate password.
    fn extract_secrets(&mut self, prefix: &str, secrets: &mut HashMap<String, String>) {
        extract_sensitive_url(&mut self.url, prefix, "URL", secrets);
        let credentials = [
            ("USERNAME", self.username.take()),
            ("PASSWORD", self.password.take()),
        ];
        for (field, value) in credentials {
            if let Some(value) = value {
                secrets.insert(format!("{prefix}__{field}"), value);
            }
        }
        self.tls.extract_secrets(&format!("{prefix}__TLS"), secrets);
    }
}
impl SecretExtractor for NatsConfig {
    /// Moves NATS secrets into `secrets`:
    ///
    /// - the URL, only when it embeds userinfo;
    /// - `username`, `password` and `token`;
    /// - the TLS certificate password.
    fn extract_secrets(&mut self, prefix: &str, secrets: &mut HashMap<String, String>) {
        extract_sensitive_url(&mut self.url, prefix, "URL", secrets);
        let credentials = [
            ("USERNAME", self.username.take()),
            ("PASSWORD", self.password.take()),
            ("TOKEN", self.token.take()),
        ];
        for (field, value) in credentials {
            if let Some(value) = value {
                secrets.insert(format!("{prefix}__{field}"), value);
            }
        }
        self.tls.extract_secrets(&format!("{prefix}__TLS"), secrets);
    }
}
impl SecretExtractor for AmqpConfig {
    /// Moves AMQP secrets into `secrets`:
    ///
    /// - the URL, only when it embeds userinfo;
    /// - `username` and `password`;
    /// - the TLS certificate password.
    fn extract_secrets(&mut self, prefix: &str, secrets: &mut HashMap<String, String>) {
        extract_sensitive_url(&mut self.url, prefix, "URL", secrets);
        let credentials = [
            ("USERNAME", self.username.take()),
            ("PASSWORD", self.password.take()),
        ];
        for (field, value) in credentials {
            if let Some(value) = value {
                secrets.insert(format!("{prefix}__{field}"), value);
            }
        }
        self.tls.extract_secrets(&format!("{prefix}__TLS"), secrets);
    }
}
impl SecretExtractor for MongoDbConfig {
    /// Moves MongoDB secrets into `secrets`:
    ///
    /// - the URL, only when it embeds userinfo;
    /// - `username` and `password`;
    /// - the TLS certificate password.
    fn extract_secrets(&mut self, prefix: &str, secrets: &mut HashMap<String, String>) {
        extract_sensitive_url(&mut self.url, prefix, "URL", secrets);
        let credentials = [
            ("USERNAME", self.username.take()),
            ("PASSWORD", self.password.take()),
        ];
        for (field, value) in credentials {
            if let Some(value) = value {
                secrets.insert(format!("{prefix}__{field}"), value);
            }
        }
        self.tls.extract_secrets(&format!("{prefix}__TLS"), secrets);
    }
}
impl SecretExtractor for MqttConfig {
    /// Moves MQTT secrets into `secrets`:
    ///
    /// - the URL, only when it embeds userinfo;
    /// - `username` and `password`;
    /// - the TLS certificate password.
    fn extract_secrets(&mut self, prefix: &str, secrets: &mut HashMap<String, String>) {
        extract_sensitive_url(&mut self.url, prefix, "URL", secrets);
        let credentials = [
            ("USERNAME", self.username.take()),
            ("PASSWORD", self.password.take()),
        ];
        for (field, value) in credentials {
            if let Some(value) = value {
                secrets.insert(format!("{prefix}__{field}"), value);
            }
        }
        self.tls.extract_secrets(&format!("{prefix}__TLS"), secrets);
    }
}
impl SecretExtractor for HttpConfig {
    /// Moves HTTP secrets into `secrets`:
    ///
    /// - the URL, only when it embeds userinfo;
    /// - the basic-auth pair, as `BASIC_AUTH__0` (user) and `BASIC_AUTH__1` (password);
    /// - custom headers whose names look secret-bearing;
    /// - the TLS certificate password.
    fn extract_secrets(&mut self, prefix: &str, secrets: &mut HashMap<String, String>) {
        extract_sensitive_url(&mut self.url, prefix, "URL", secrets);
        if let Some((user, pass)) = self.basic_auth.take() {
            for (index, part) in [user, pass].into_iter().enumerate() {
                secrets.insert(format!("{prefix}__BASIC_AUTH__{index}"), part);
            }
        }
        extract_sensitive_string_map_entries(
            &mut self.custom_headers,
            prefix,
            "CUSTOM_HEADERS",
            secrets,
        );
        self.tls.extract_secrets(&format!("{prefix}__TLS"), secrets);
    }
}
impl SecretExtractor for WebSocketConfig {
    /// The WebSocket URL is the only candidate secret; it is moved only when it embeds
    /// userinfo.
    fn extract_secrets(&mut self, prefix: &str, secrets: &mut HashMap<String, String>) {
        let url = &mut self.url;
        extract_sensitive_url(url, prefix, "URL", secrets);
    }
}
impl SecretExtractor for IbmMqConfig {
    /// Moves IBM MQ secrets into `secrets`:
    ///
    /// - the URL, only when it embeds userinfo;
    /// - `username` and `password`;
    /// - the TLS certificate password.
    fn extract_secrets(&mut self, prefix: &str, secrets: &mut HashMap<String, String>) {
        extract_sensitive_url(&mut self.url, prefix, "URL", secrets);
        let credentials = [
            ("USERNAME", self.username.take()),
            ("PASSWORD", self.password.take()),
        ];
        for (field, value) in credentials {
            if let Some(value) = value {
                secrets.insert(format!("{prefix}__{field}"), value);
            }
        }
        self.tls.extract_secrets(&format!("{prefix}__TLS"), secrets);
    }
}
impl SecretExtractor for ZeroMqConfig {
    /// The ZeroMQ URL is the only candidate secret; it is moved only when it embeds
    /// userinfo.
    fn extract_secrets(&mut self, prefix: &str, secrets: &mut HashMap<String, String>) {
        let url = &mut self.url;
        extract_sensitive_url(url, prefix, "URL", secrets);
    }
}
impl SecretExtractor for SqlxConfig {
    /// Moves database secrets into `secrets`:
    ///
    /// - the URL, only when it embeds userinfo;
    /// - `username` and `password`;
    /// - the TLS certificate password.
    fn extract_secrets(&mut self, prefix: &str, secrets: &mut HashMap<String, String>) {
        extract_sensitive_url(&mut self.url, prefix, "URL", secrets);
        let credentials = [
            ("USERNAME", self.username.take()),
            ("PASSWORD", self.password.take()),
        ];
        for (field, value) in credentials {
            if let Some(value) = value {
                secrets.insert(format!("{prefix}__{field}"), value);
            }
        }
        self.tls.extract_secrets(&format!("{prefix}__TLS"), secrets);
    }
}
impl SecretExtractor for GrpcConfig {
    /// Moves the gRPC URL (only when it embeds userinfo) and the TLS certificate password
    /// into `secrets`.
    fn extract_secrets(&mut self, prefix: &str, secrets: &mut HashMap<String, String>) {
        extract_sensitive_url(&mut self.url, prefix, "URL", secrets);
        let tls_prefix = format!("{prefix}__TLS");
        self.tls.extract_secrets(&tls_prefix, secrets);
    }
}
impl SecretExtractor for TlsConfig {
    /// Moves `cert_password` (if set) into `secrets` as `<prefix>__CERT_PASSWORD`. File
    /// paths are left in place.
    fn extract_secrets(&mut self, prefix: &str, secrets: &mut HashMap<String, String>) {
        let Some(password) = self.cert_password.take() else {
            return;
        };
        secrets.insert(format!("{prefix}__CERT_PASSWORD"), password);
    }
}
/// Strips secrets from every route in `config` and returns them as a flat map.
///
/// Each route is scoped under the sanitized prefix `MQB__<ROUTE_NAME>`. For example, route
/// `test_route` yields keys like `MQB__TEST_ROUTE__INPUT__KAFKA__PASSWORD`. `config` is
/// modified in place: extracted fields are cleared.
pub fn extract_config_secrets(config: &mut Config) -> HashMap<String, String> {
    let mut collected = HashMap::new();
    config.iter_mut().for_each(|(name, route)| {
        let route_prefix = sanitize_secret_key(&format!("MQB__{name}"));
        route.extract_secrets(&route_prefix, &mut collected);
    });
    collected
}
#[cfg(test)]
mod tests {
    use super::*;
    use config::{Config as ConfigBuilder, Environment};

    // Reference route used by the YAML test. YAML nesting is indentation-sensitive:
    // middlewares and the transport key (`kafka` / `nats`) sit under `input` / `output`,
    // and `concurrency` is a flattened route option.
    const TEST_YAML: &str = r#"
kafka_to_nats:
  concurrency: 10
  input:
    middlewares:
      - deduplication:
          sled_path: "/tmp/mq-bridge/dedup_db"
          ttl_seconds: 3600
      - metrics: {}
      - retry:
          max_attempts: 5
          initial_interval_ms: 200
      - random_panic:
          mode: nack
      - dlq:
          endpoint:
            nats:
              subject: "dlq-subject"
              url: "nats://localhost:4222"
    kafka:
      topic: "input-topic"
      url: "localhost:9092"
      group_id: "my-consumer-group"
      tls:
        required: true
        ca_file: "/path_to_ca"
        cert_file: "/path_to_cert"
        key_file: "/path_to_key"
        cert_password: "password"
        accept_invalid_certs: true
  output:
    middlewares:
      - metrics: {}
      - dlq:
          endpoint:
            file:
              path: "error.out"
    nats:
      subject: "output-subject"
      url: "nats://localhost:4222"
"#;

    /// Asserts that `config` matches the route described by `TEST_YAML`.
    fn assert_config_values(config: &Config) {
        assert_eq!(config.len(), 1);
        let route = config.get("kafka_to_nats").expect("Route should exist");
        assert_eq!(route.options.concurrency, 10);
        let input = &route.input;
        assert_eq!(input.middlewares.len(), 5);
        let mut has_dedup = false;
        let mut has_metrics = false;
        let mut has_dlq = false;
        let mut has_retry = false;
        let mut has_random_panic = false;
        // Exhaustive match so that a new middleware variant forces this test to be revisited.
        for middleware in &input.middlewares {
            match middleware {
                Middleware::Deduplication(dedup) => {
                    assert_eq!(dedup.sled_path, "/tmp/mq-bridge/dedup_db");
                    assert_eq!(dedup.ttl_seconds, 3600);
                    has_dedup = true;
                }
                Middleware::Metrics(_) => {
                    has_metrics = true;
                }
                Middleware::Custom { .. } => {}
                Middleware::Dlq(dlq) => {
                    assert!(dlq.endpoint.middlewares.is_empty());
                    if let EndpointType::Nats(nats_cfg) = &dlq.endpoint.endpoint_type {
                        assert_eq!(nats_cfg.subject, Some("dlq-subject".to_string()));
                        assert_eq!(nats_cfg.url, "nats://localhost:4222");
                    } else {
                        panic!("DLQ endpoint should be NATS");
                    }
                    has_dlq = true;
                }
                Middleware::Retry(retry) => {
                    assert_eq!(retry.max_attempts, 5);
                    assert_eq!(retry.initial_interval_ms, 200);
                    has_retry = true;
                }
                Middleware::RandomPanic(rp) => {
                    assert!(rp.mode == FaultMode::Nack);
                    has_random_panic = true;
                }
                Middleware::Delay(_) => {}
                Middleware::WeakJoin(_) => {}
                Middleware::Limiter(_) => {}
                Middleware::Buffer(_) => {}
                Middleware::CookieJar(_) => {}
            }
        }
        if let EndpointType::Kafka(kafka) = &input.endpoint_type {
            assert_eq!(kafka.topic, Some("input-topic".to_string()));
            assert_eq!(kafka.url, "localhost:9092");
            assert_eq!(kafka.group_id, Some("my-consumer-group".to_string()));
            let tls = &kafka.tls;
            assert!(tls.required);
            assert_eq!(tls.ca_file.as_deref(), Some("/path_to_ca"));
            assert!(tls.accept_invalid_certs);
        } else {
            panic!("Input endpoint should be Kafka");
        }
        assert!(has_dedup);
        assert!(has_metrics);
        assert!(has_dlq);
        assert!(has_retry);
        assert!(has_random_panic);
        let output = &route.output;
        assert_eq!(output.middlewares.len(), 2);
        assert!(matches!(output.middlewares[0], Middleware::Metrics(_)));
        if let EndpointType::Nats(nats) = &output.endpoint_type {
            assert_eq!(nats.subject, Some("output-subject".to_string()));
            assert_eq!(nats.url, "nats://localhost:4222");
        } else {
            panic!("Output endpoint should be NATS");
        }
    }

    #[test]
    fn test_deserialize_from_yaml() {
        let result: Result<Config, _> = serde_yaml_ng::from_str(TEST_YAML);
        println!("Deserialized from YAML: {:#?}", result);
        let config = result.expect("Failed to deserialize TEST_YAML");
        assert_config_values(&config);
    }

    #[test]
    fn test_deserialize_from_env() {
        // SAFETY: `set_var` is unsafe because another thread may read the process environment
        // concurrently. Only this test reads these `MQB__KAFKA_TO_NATS__*` keys, but the
        // harness runs tests in parallel. NOTE(review): serialize env-mutating tests if
        // more of them appear.
        unsafe {
            std::env::set_var("MQB__KAFKA_TO_NATS__CONCURRENCY", "10");
            std::env::set_var("MQB__KAFKA_TO_NATS__INPUT__KAFKA__TOPIC", "input-topic");
            std::env::set_var("MQB__KAFKA_TO_NATS__INPUT__KAFKA__URL", "localhost:9092");
            std::env::set_var(
                "MQB__KAFKA_TO_NATS__INPUT__KAFKA__GROUP_ID",
                "my-consumer-group",
            );
            std::env::set_var("MQB__KAFKA_TO_NATS__INPUT__KAFKA__TLS__REQUIRED", "true");
            std::env::set_var(
                "MQB__KAFKA_TO_NATS__INPUT__KAFKA__TLS__CA_FILE",
                "/path_to_ca",
            );
            std::env::set_var(
                "MQB__KAFKA_TO_NATS__INPUT__KAFKA__TLS__ACCEPT_INVALID_CERTS",
                "true",
            );
            std::env::set_var(
                "MQB__KAFKA_TO_NATS__OUTPUT__NATS__SUBJECT",
                "output-subject",
            );
            std::env::set_var(
                "MQB__KAFKA_TO_NATS__OUTPUT__NATS__URL",
                "nats://localhost:4222",
            );
            std::env::set_var(
                "MQB__KAFKA_TO_NATS__INPUT__MIDDLEWARES__0__DLQ__ENDPOINT__NATS__SUBJECT",
                "dlq-subject",
            );
            std::env::set_var(
                "MQB__KAFKA_TO_NATS__INPUT__MIDDLEWARES__0__DLQ__ENDPOINT__NATS__URL",
                "nats://localhost:4222",
            );
        }
        let builder = ConfigBuilder::builder()
            .add_source(
                Environment::with_prefix("MQB")
                    .separator("__")
                    .try_parsing(true),
            );
        let config: Config = builder
            .build()
            .expect("Failed to build config")
            .try_deserialize()
            .expect("Failed to deserialize config");
        assert_eq!(config.get("kafka_to_nats").unwrap().options.concurrency, 10);
        if let EndpointType::Kafka(k) = &config.get("kafka_to_nats").unwrap().input.endpoint_type {
            assert_eq!(k.topic, Some("input-topic".to_string()));
            assert!(k.tls.required);
        } else {
            panic!("Expected Kafka endpoint");
        }
        let input = &config.get("kafka_to_nats").unwrap().input;
        assert_eq!(input.middlewares.len(), 1);
        assert!(
            matches!(&input.middlewares[0], Middleware::Dlq(_)),
            "Expected DLQ middleware"
        );
    }

    #[test]
    fn test_extract_secrets() {
        let mut config = Config::new();
        let mut route = Route::default();
        let mut kafka_config = KafkaConfig::new("kafka://user:pass@localhost:9092");
        kafka_config.username = Some("user".to_string());
        kafka_config.password = Some("pass".to_string());
        kafka_config.tls.cert_password = Some("certpass".to_string());
        route.input = Endpoint {
            endpoint_type: EndpointType::Kafka(kafka_config),
            middlewares: vec![],
            handler: None,
        };
        let mut http_config = HttpConfig::new("http://httpuser:httppass@localhost");
        http_config.basic_auth = Some(("httpuser".to_string(), "httppass".to_string()));
        http_config
            .custom_headers
            .insert("X-API-Key".to_string(), "http-api-key".to_string());
        http_config.custom_headers.insert(
            "X-Access-Token".to_string(),
            "http-access-token".to_string(),
        );
        http_config.custom_headers.insert(
            "X-Authentication".to_string(),
            "http-authentication".to_string(),
        );
        http_config.custom_headers.insert(
            "Authorization".to_string(),
            "Bearer secret-token".to_string(),
        );
        http_config
            .custom_headers
            .insert("X-Trace-Id".to_string(), "trace-value".to_string());
        route.output = Endpoint {
            endpoint_type: EndpointType::Http(http_config),
            middlewares: vec![],
            handler: None,
        };
        config.insert("test_route".to_string(), route);
        let secrets = extract_config_secrets(&mut config);
        assert_eq!(
            secrets
                .get("MQB__TEST_ROUTE__INPUT__KAFKA__URL")
                .map(|s| s.as_str()),
            Some("kafka://user:pass@localhost:9092")
        );
        assert_eq!(
            secrets
                .get("MQB__TEST_ROUTE__INPUT__KAFKA__USERNAME")
                .map(|s| s.as_str()),
            Some("user")
        );
        assert_eq!(
            secrets
                .get("MQB__TEST_ROUTE__INPUT__KAFKA__PASSWORD")
                .map(|s| s.as_str()),
            Some("pass")
        );
        assert_eq!(
            secrets
                .get("MQB__TEST_ROUTE__INPUT__KAFKA__TLS__CERT_PASSWORD")
                .map(|s| s.as_str()),
            Some("certpass")
        );
        assert_eq!(
            secrets
                .get("MQB__TEST_ROUTE__OUTPUT__HTTP__URL")
                .map(|s| s.as_str()),
            Some("http://httpuser:httppass@localhost")
        );
        assert_eq!(
            secrets
                .get("MQB__TEST_ROUTE__OUTPUT__HTTP__BASIC_AUTH__0")
                .map(|s| s.as_str()),
            Some("httpuser")
        );
        assert_eq!(
            secrets
                .get("MQB__TEST_ROUTE__OUTPUT__HTTP__BASIC_AUTH__1")
                .map(|s| s.as_str()),
            Some("httppass")
        );
        assert_eq!(
            secrets
                .get("MQB__TEST_ROUTE__OUTPUT__HTTP__CUSTOM_HEADERS__X_API_KEY")
                .map(|s| s.as_str()),
            Some("http-api-key")
        );
        assert_eq!(
            secrets
                .get("MQB__TEST_ROUTE__OUTPUT__HTTP__CUSTOM_HEADERS__X_ACCESS_TOKEN")
                .map(|s| s.as_str()),
            Some("http-access-token")
        );
        assert_eq!(
            secrets
                .get("MQB__TEST_ROUTE__OUTPUT__HTTP__CUSTOM_HEADERS__X_AUTHENTICATION")
                .map(|s| s.as_str()),
            Some("http-authentication")
        );
        assert_eq!(
            secrets
                .get("MQB__TEST_ROUTE__OUTPUT__HTTP__CUSTOM_HEADERS__AUTHORIZATION")
                .map(|s| s.as_str()),
            Some("Bearer secret-token")
        );
        let route = config.get("test_route").unwrap();
        // The `else` panics keep these checks from passing vacuously on an unexpected variant.
        if let EndpointType::Kafka(k) = &route.input.endpoint_type {
            assert!(k.url.is_empty());
            assert!(k.username.is_none());
            assert!(k.password.is_none());
            assert!(k.tls.cert_password.is_none());
        } else {
            panic!("Input endpoint should be Kafka");
        }
        if let EndpointType::Http(h) = &route.output.endpoint_type {
            assert!(h.url.is_empty());
            assert!(h.basic_auth.is_none());
            assert!(!h.custom_headers.contains_key("X-API-Key"));
            assert!(!h.custom_headers.contains_key("X-Access-Token"));
            assert!(!h.custom_headers.contains_key("X-Authentication"));
            assert!(!h.custom_headers.contains_key("Authorization"));
            assert_eq!(
                h.custom_headers.get("X-Trace-Id").map(|s| s.as_str()),
                Some("trace-value")
            );
        } else {
            panic!("Output endpoint should be HTTP");
        }
    }

    #[test]
    fn test_extract_sensitive_url_only_strips_authority_credentials() {
        let mut config = Config::new();
        let path_at_route = Route {
            output: Endpoint {
                endpoint_type: EndpointType::Http(HttpConfig::new(
                    "https://example.com/path/user@example.com?email=a@b.test",
                )),
                middlewares: vec![],
                handler: None,
            },
            ..Default::default()
        };
        config.insert("path_at_route".to_string(), path_at_route);
        let credential_route = Route {
            output: Endpoint {
                endpoint_type: EndpointType::Http(HttpConfig::new(
                    "https://user:pass@example.com/path",
                )),
                middlewares: vec![],
                handler: None,
            },
            ..Default::default()
        };
        config.insert("credential_route".to_string(), credential_route);
        let query_at_route = Route {
            output: Endpoint {
                endpoint_type: EndpointType::Http(HttpConfig::new(
                    "https://example.com?next=a@b.test",
                )),
                middlewares: vec![],
                handler: None,
            },
            ..Default::default()
        };
        config.insert("query_at_route".to_string(), query_at_route);
        let fragment_at_route = Route {
            output: Endpoint {
                endpoint_type: EndpointType::Http(HttpConfig::new(
                    "https://example.com#user@example.com",
                )),
                middlewares: vec![],
                handler: None,
            },
            ..Default::default()
        };
        config.insert("fragment_at_route".to_string(), fragment_at_route);
        let secrets = extract_config_secrets(&mut config);
        if let EndpointType::Http(http) = &config.get("path_at_route").unwrap().output.endpoint_type
        {
            assert_eq!(
                http.url,
                "https://example.com/path/user@example.com?email=a@b.test"
            );
        } else {
            panic!("path_at_route output should be HTTP");
        }
        if let EndpointType::Http(http) =
            &config.get("query_at_route").unwrap().output.endpoint_type
        {
            assert_eq!(http.url, "https://example.com?next=a@b.test");
        } else {
            panic!("query_at_route output should be HTTP");
        }
        if let EndpointType::Http(http) = &config
            .get("fragment_at_route")
            .unwrap()
            .output
            .endpoint_type
        {
            assert_eq!(http.url, "https://example.com#user@example.com");
        } else {
            panic!("fragment_at_route output should be HTTP");
        }
        if let EndpointType::Http(http) =
            &config.get("credential_route").unwrap().output.endpoint_type
        {
            assert!(http.url.is_empty());
        } else {
            panic!("credential_route output should be HTTP");
        }
        assert_eq!(
            secrets
                .get("MQB__CREDENTIAL_ROUTE__OUTPUT__HTTP__URL")
                .map(String::as_str),
            Some("https://user:pass@example.com/path")
        );
        assert!(!secrets.contains_key("MQB__PATH_AT_ROUTE__OUTPUT__HTTP__URL"));
        assert!(!secrets.contains_key("MQB__QUERY_AT_ROUTE__OUTPUT__HTTP__URL"));
        assert!(!secrets.contains_key("MQB__FRAGMENT_AT_ROUTE__OUTPUT__HTTP__URL"));
    }

    #[test]
    fn test_file_config_inference() {
        let yaml = r#"
mode: group_subscribe
path: "/tmp/test"
group_id: "my_group"
"#;
        let config: FileConfig = serde_yaml_ng::from_str(yaml).unwrap();
        match config.mode {
            Some(FileConsumerMode::GroupSubscribe { group_id, .. }) => {
                assert_eq!(group_id, "my_group")
            }
            _ => panic!("Expected GroupSubscribe"),
        }
        let yaml_queue = r#"
mode: consume
path: "/tmp/test"
"#;
        let config_queue: FileConfig = serde_yaml_ng::from_str(yaml_queue).unwrap();
        match config_queue.mode {
            Some(FileConsumerMode::Consume { delete }) => assert!(!delete),
            _ => panic!("Expected Consume"),
        }
    }
}
#[cfg(all(test, feature = "schema"))]
mod schema_tests {
    use super::*;

    /// Regenerates `mq-bridge.schema.json` in the crate root from the [`Config`] type.
    /// This test deliberately writes to the source tree.
    #[test]
    fn generate_json_schema() {
        let schema = schemars::schema_for!(Config);
        let rendered = serde_json::to_string_pretty(&schema).unwrap();
        let target =
            std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("mq-bridge.schema.json");
        std::fs::write(target, rendered).expect("Failed to write schema file");
    }
}