use anyhow::{anyhow, Result};
use kdl::KdlDocument;
use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use tracing::debug;
use crate::{
AgentConfig, Config, GlobalRateLimitConfig, Limits, ListenerConfig, NamespaceConfig,
ObservabilityConfig, RouteConfig, ServerConfig, UpstreamConfig, WafConfig,
};
use super::parsers::{
parse_agent, parse_limits, parse_listener, parse_namespace, parse_observability, parse_route,
parse_server, parse_upstream, parse_waf,
};
/// The configuration sections parsed from a single KDL file, before merging.
///
/// Built by `PartialConfig::from_kdl` and later folded into a `ConfigBuilder`.
/// Nothing here is validated across files; cross-file duplicate checks happen
/// at merge time.
#[derive(Debug, Default)]
pub(super) struct PartialConfig {
    /// Path of the file this partial config was parsed from (used in diagnostics).
    pub source_file: PathBuf,
    /// The `server` section, if the file contained one.
    pub server: Option<ServerConfig>,
    /// Every `listener` node, in document order.
    pub listeners: Vec<ListenerConfig>,
    /// Every `route` node, in document order.
    pub routes: Vec<RouteConfig>,
    /// `upstream` nodes keyed by the name returned from `parse_upstream`.
    pub upstreams: HashMap<String, UpstreamConfig>,
    /// Every `agent` node, in document order.
    pub agents: Vec<AgentConfig>,
    /// The `waf` section, if present.
    pub waf: Option<WafConfig>,
    /// The `limits` section, if present.
    pub limits: Option<Limits>,
    /// The `observability` section, if present.
    pub observability: Option<ObservabilityConfig>,
    /// Every `namespace` node, in document order.
    pub namespaces: Vec<NamespaceConfig>,
    /// Raw paths from `include` directives, exactly as written (not resolved here).
    pub includes: Vec<PathBuf>,
}
impl PartialConfig {
pub fn from_kdl(doc: KdlDocument, source: &Path) -> Result<Self> {
let mut config = Self {
source_file: source.to_path_buf(),
..Default::default()
};
for node in doc.nodes() {
match node.name().value() {
"include" => {
if let Some(entry) = node.entries().first() {
if let Some(path_str) = entry.value().as_string() {
config.includes.push(PathBuf::from(path_str));
debug!("Found include directive: {}", path_str);
}
}
}
"server" if config.server.is_none() => {
config.server = Some(parse_server(node)?);
}
"listener" => {
config.listeners.push(parse_listener(node)?);
}
"route" => {
config.routes.push(parse_route(node)?);
}
"upstream" => {
let (name, upstream) = parse_upstream(node)?;
config.upstreams.insert(name, upstream);
}
"agent" => {
config.agents.push(parse_agent(node)?);
}
"waf" if config.waf.is_none() => {
config.waf = Some(parse_waf(node)?);
}
"limits" if config.limits.is_none() => {
config.limits = Some(parse_limits(node)?);
}
"observability" if config.observability.is_none() => {
config.observability = Some(parse_observability(node)?);
}
"namespace" => {
config.namespaces.push(parse_namespace(node)?);
}
"metadata" => {
}
_ => {
debug!(
"Ignoring unknown configuration node: {}",
node.name().value()
);
}
}
}
Ok(config)
}
}
/// Accumulates [`PartialConfig`]s from one or more files into a final [`Config`].
///
/// Listener, route, agent, and namespace IDs must be unique across everything
/// merged. Upstreams and the singleton sections (`server`, `waf`, `limits`,
/// `observability`) may be overridden by a later merge, which logs a warning.
#[derive(Default)]
pub(super) struct ConfigBuilder {
    server: Option<ServerConfig>,
    listeners: Vec<ListenerConfig>,
    routes: Vec<RouteConfig>,
    upstreams: HashMap<String, UpstreamConfig>,
    // NOTE(review): `merge` never populates this map (PartialConfig has no filters
    // field), so `build` always emits it empty — confirm this is intended.
    filters: HashMap<String, crate::FilterConfig>,
    agents: Vec<AgentConfig>,
    waf: Option<WafConfig>,
    limits: Option<Limits>,
    observability: Option<ObservabilityConfig>,
    namespaces: Vec<NamespaceConfig>,
    listener_ids: HashSet<String>,
    route_ids: HashSet<String>,
    // Never read: filters are not merged yet (see `filters`). Kept presumably as a
    // placeholder for future filter de-duplication.
    #[allow(dead_code)]
    filter_ids: HashSet<String>,
    agent_ids: HashSet<String>,
    namespace_ids: HashSet<String>,
}

impl ConfigBuilder {
    /// Creates an empty builder.
    pub fn new() -> Self {
        Self::default()
    }

    /// Merges one parsed file into the builder.
    ///
    /// Every duplicate-ID check runs before anything is modified, so when this
    /// returns an error the builder is left exactly as it was.
    ///
    /// # Errors
    ///
    /// Returns an error if a listener, route, agent, or namespace ID in `partial`
    /// either matches an ID that was already merged or repeats within `partial`.
    pub fn merge(&mut self, partial: PartialConfig) -> Result<()> {
        let source = partial.source_file.as_path();

        // Phase 1: validate, so a failure cannot leave a half-merged builder.
        Self::ensure_unique(
            "listener",
            &self.listener_ids,
            partial.listeners.iter().map(|l| l.id.as_str()),
            source,
        )?;
        Self::ensure_unique(
            "route",
            &self.route_ids,
            partial.routes.iter().map(|r| r.id.as_str()),
            source,
        )?;
        Self::ensure_unique(
            "agent",
            &self.agent_ids,
            partial.agents.iter().map(|a| a.id.as_str()),
            source,
        )?;
        Self::ensure_unique(
            "namespace",
            &self.namespace_ids,
            partial.namespaces.iter().map(|n| n.id.as_str()),
            source,
        )?;

        // Phase 2: apply. Nothing below can fail.
        self.listener_ids
            .extend(partial.listeners.iter().map(|l| l.id.clone()));
        self.listeners.extend(partial.listeners);

        self.route_ids
            .extend(partial.routes.iter().map(|r| r.id.clone()));
        self.routes.extend(partial.routes);

        for (name, upstream) in partial.upstreams {
            if self.upstreams.contains_key(&name) {
                tracing::warn!("Overriding upstream '{}' from {:?}", name, source);
            }
            self.upstreams.insert(name, upstream);
        }

        self.agent_ids
            .extend(partial.agents.iter().map(|a| a.id.clone()));
        self.agents.extend(partial.agents);

        self.namespace_ids
            .extend(partial.namespaces.iter().map(|n| n.id.clone()));
        self.namespaces.extend(partial.namespaces);

        Self::override_singleton(&mut self.server, partial.server, "server", source);
        Self::override_singleton(&mut self.waf, partial.waf, "WAF", source);
        Self::override_singleton(&mut self.limits, partial.limits, "limits", source);
        Self::override_singleton(
            &mut self.observability,
            partial.observability,
            "observability",
            source,
        );

        Ok(())
    }

    /// Consumes the builder and produces the final [`Config`].
    ///
    /// A missing `limits` or `observability` section falls back to that type's
    /// `Default`. `rate_limits`, `cache`, and `default_upstream` are not populated
    /// by this builder and are set to their default / `None`.
    ///
    /// # Errors
    ///
    /// Returns an error if no merged file provided a `server` section.
    pub fn build(self) -> Result<Config> {
        let server = self
            .server
            .ok_or_else(|| anyhow!("Server configuration is required"))?;

        Ok(Config {
            schema_version: crate::CURRENT_SCHEMA_VERSION.to_string(),
            server,
            listeners: self.listeners,
            routes: self.routes,
            upstreams: self.upstreams,
            filters: self.filters,
            agents: self.agents,
            waf: self.waf,
            namespaces: self.namespaces,
            limits: self.limits.unwrap_or_default(),
            observability: self.observability.unwrap_or_default(),
            rate_limits: GlobalRateLimitConfig::default(),
            cache: None,
            default_upstream: None,
        })
    }

    /// Fails on the first ID in `ids` that is already in `existing` or appears
    /// earlier in `ids`. The error message names the section `kind` and the `source` file.
    fn ensure_unique<'a>(
        kind: &str,
        existing: &HashSet<String>,
        ids: impl IntoIterator<Item = &'a str>,
        source: &Path,
    ) -> Result<()> {
        let mut seen: HashSet<&str> = HashSet::new();
        for id in ids {
            if existing.contains(id) || !seen.insert(id) {
                return Err(anyhow!("Duplicate {} '{}' in {:?}", kind, id, source));
            }
        }
        Ok(())
    }

    /// Replaces `slot` with `incoming` when the latter is present, warning if an
    /// earlier value was overwritten.
    fn override_singleton<T>(slot: &mut Option<T>, incoming: Option<T>, what: &str, source: &Path) {
        if let Some(value) = incoming {
            if slot.replace(value).is_some() {
                tracing::warn!("Overriding {} config from {:?}", what, source);
            }
        }
    }
}