use anyhow::Result;
use std::collections::HashMap;
use std::path::PathBuf;
use tracing::trace;
use zentinel_common::types::{HealthCheckType, LoadBalancingAlgorithm};
use crate::upstreams::*;
use super::helpers::{get_first_arg_string, get_int_entry};
pub fn parse_upstreams(node: &kdl::KdlNode) -> Result<HashMap<String, UpstreamConfig>> {
trace!("Parsing upstreams configuration block");
let mut upstreams = HashMap::new();
if let Some(children) = node.children() {
for child in children.nodes() {
if child.name().value() == "upstream" {
let id = get_first_arg_string(child).ok_or_else(|| {
anyhow::anyhow!(
"Upstream requires an ID argument, e.g., upstream \"backend\" {{ ... }}"
)
})?;
trace!(upstream_id = %id, "Parsing upstream");
let mut targets = Vec::new();
if let Some(upstream_children) = child.children() {
for target_node in upstream_children.nodes() {
if target_node.name().value() == "target" {
if let Some(address) = get_first_arg_string(target_node) {
let weight = target_node
.entries()
.iter()
.find(|e| e.name().map(|n| n.value()) == Some("weight"))
.and_then(|e| e.value().as_integer())
.map(|v| v as u32)
.unwrap_or(1);
trace!(
upstream_id = %id,
address = %address,
weight = weight,
"Parsed target"
);
targets.push(UpstreamTarget {
address,
weight,
max_requests: None,
metadata: HashMap::new(),
});
}
}
}
}
if targets.is_empty() {
return Err(anyhow::anyhow!(
"Upstream '{}' requires at least one target, e.g., target \"127.0.0.1:8081\"",
id
));
}
let load_balancing_node = child.children().and_then(|c| {
c.nodes()
.iter()
.find(|n| n.name().value() == "load-balancing")
});
let load_balancing = load_balancing_node
.and_then(get_first_arg_string)
.map(|s| parse_load_balancing(&s))
.unwrap_or(LoadBalancingAlgorithm::RoundRobin);
let sticky_session = if load_balancing == LoadBalancingAlgorithm::Sticky {
load_balancing_node
.and_then(|n| n.children())
.map(parse_sticky_session_config)
} else {
None
};
if sticky_session.is_some() {
trace!(
upstream_id = %id,
cookie_name = ?sticky_session.as_ref().map(|s| &s.cookie_name),
"Parsed sticky session configuration"
);
}
let health_check = child
.children()
.and_then(|c| {
c.nodes()
.iter()
.find(|n| n.name().value() == "health-check")
})
.and_then(|n| parse_health_check(n).ok());
if health_check.is_some() {
trace!(
upstream_id = %id,
"Parsed health check configuration"
);
}
let http_version = child
.children()
.and_then(|c| {
c.nodes()
.iter()
.find(|n| n.name().value() == "http-version")
})
.map(parse_http_version)
.unwrap_or_default();
if http_version.max_version >= 2 {
trace!(
upstream_id = %id,
max_version = http_version.max_version,
"HTTP/2 enabled for upstream"
);
}
let connection_pool = child
.children()
.and_then(|c| {
c.nodes()
.iter()
.find(|n| n.name().value() == "connection-pool")
})
.map(parse_connection_pool)
.unwrap_or_default();
let timeouts = child
.children()
.and_then(|c| c.nodes().iter().find(|n| n.name().value() == "timeouts"))
.map(parse_upstream_timeouts)
.unwrap_or_default();
let tls = child
.children()
.and_then(|c| c.nodes().iter().find(|n| n.name().value() == "tls"))
.map(parse_upstream_tls);
if tls.is_some() {
trace!(
upstream_id = %id,
"Parsed TLS configuration"
);
}
trace!(
upstream_id = %id,
target_count = targets.len(),
load_balancing = ?load_balancing,
has_health_check = health_check.is_some(),
has_tls = tls.is_some(),
http_version = http_version.max_version,
max_connections = connection_pool.max_connections,
connect_timeout = timeouts.connect_secs,
"Parsed upstream"
);
upstreams.insert(
id.clone(),
UpstreamConfig {
id,
targets,
load_balancing,
sticky_session,
health_check,
connection_pool,
timeouts,
tls,
http_version,
},
);
}
}
}
trace!(
upstream_count = upstreams.len(),
"Finished parsing upstreams"
);
Ok(upstreams)
}
/// Maps a load-balancing algorithm name from the configuration to its enum variant.
///
/// Matching is case-insensitive, and hyphens are treated as underscores so that
/// `"least-connections"` and `"least_connections"` select the same algorithm. Previously
/// only `"weighted-round-robin"` was accepted in hyphenated form; every other hyphenated
/// name silently fell through to round-robin.
///
/// Unrecognised names fall back to [`LoadBalancingAlgorithm::RoundRobin`].
fn parse_load_balancing(s: &str) -> LoadBalancingAlgorithm {
    let normalized = s.to_lowercase().replace('-', "_");
    match normalized.as_str() {
        "round_robin" | "roundrobin" => LoadBalancingAlgorithm::RoundRobin,
        "least_connections" | "leastconnections" => LoadBalancingAlgorithm::LeastConnections,
        "weighted" | "weighted_round_robin" => LoadBalancingAlgorithm::Weighted,
        "ip_hash" | "iphash" => LoadBalancingAlgorithm::IpHash,
        "random" => LoadBalancingAlgorithm::Random,
        "consistent_hash" | "consistenthash" => LoadBalancingAlgorithm::ConsistentHash,
        "power_of_two_choices" | "p2c" => LoadBalancingAlgorithm::PowerOfTwoChoices,
        "adaptive" => LoadBalancingAlgorithm::Adaptive,
        "least_tokens_queued" | "leasttokensqueued" | "least_tokens" => {
            LoadBalancingAlgorithm::LeastTokensQueued
        }
        "maglev" => LoadBalancingAlgorithm::Maglev,
        "locality_aware" | "localityaware" | "locality" => LoadBalancingAlgorithm::LocalityAware,
        "peak_ewma" | "peakewma" | "ewma" => LoadBalancingAlgorithm::PeakEwma,
        "deterministic_subset" | "subset" | "subsetting" => {
            LoadBalancingAlgorithm::DeterministicSubset
        }
        "weighted_least_connections" | "weighted_least_conn" | "wlc" => {
            LoadBalancingAlgorithm::WeightedLeastConnections
        }
        "sticky" | "sticky_session" | "stickysession" => LoadBalancingAlgorithm::Sticky,
        // Unknown names keep the historical behaviour of defaulting to round-robin.
        _ => LoadBalancingAlgorithm::RoundRobin,
    }
}
fn parse_sticky_session_config(children: &kdl::KdlDocument) -> StickySessionConfig {
let nodes = children.nodes();
let cookie_name =
find_string_entry(nodes, "cookie-name").unwrap_or_else(|| "SERVERID".to_string());
let cookie_ttl_secs = find_int_entry(nodes, "cookie-ttl")
.map(|v| v as u64)
.or_else(|| find_string_entry(nodes, "cookie-ttl").and_then(|s| parse_duration_string(&s)))
.unwrap_or(3600);
let cookie_path = find_string_entry(nodes, "cookie-path").unwrap_or_else(|| "/".to_string());
let cookie_secure = nodes
.iter()
.find(|n| n.name().value() == "cookie-secure")
.and_then(|n| n.entries().first())
.and_then(|e| e.value().as_bool())
.unwrap_or(true);
let cookie_same_site = find_string_entry(nodes, "cookie-same-site")
.map(|s| match s.to_lowercase().as_str() {
"strict" => SameSitePolicy::Strict,
"none" => SameSitePolicy::None,
_ => SameSitePolicy::Lax,
})
.unwrap_or_default();
let fallback = find_string_entry(nodes, "fallback")
.map(|s| parse_load_balancing(&s))
.unwrap_or(LoadBalancingAlgorithm::RoundRobin);
StickySessionConfig {
cookie_name,
cookie_ttl_secs,
cookie_path,
cookie_secure,
cookie_same_site,
fallback,
}
}
/// Parses a human-readable duration such as `"30s"`, `"5m"`, `"1.5h"` or `"2 days"` into
/// whole seconds.
///
/// The string must consist of a decimal number followed by a unit. Recognised units
/// (case-insensitive) are seconds (`s`, `sec`, `secs`, `second`, `seconds`), minutes
/// (`m`, `min`, `mins`, `minute`, `minutes`), hours (`h`, `hr`, `hrs`, `hour`, `hours`)
/// and days (`d`, `day`, `days`). Fractional results are truncated toward zero.
///
/// Returns `None` for empty input, a missing or unknown unit, an unparsable number, or a
/// negative duration.
fn parse_duration_string(s: &str) -> Option<u64> {
    let s = s.trim();

    // `str::find` yields a byte offset, which is what `split_at` requires. The previous
    // `chars().position(..)` produced a character index and could panic on input with
    // multi-byte characters ahead of the unit.
    let unit_start = s.find(char::is_alphabetic)?;
    let (num_part, unit_part) = s.split_at(unit_start);

    let value: f64 = num_part.trim().parse().ok()?;
    // A negative duration is meaningless; reject it instead of letting the float-to-int
    // cast clamp it to zero.
    if value < 0.0 {
        return None;
    }

    let multiplier: u64 = match unit_part.to_lowercase().as_str() {
        "s" | "sec" | "secs" | "second" | "seconds" => 1,
        "m" | "min" | "mins" | "minute" | "minutes" => 60,
        "h" | "hr" | "hrs" | "hour" | "hours" => 3600,
        "d" | "day" | "days" => 86400,
        _ => return None,
    };

    // Float-to-integer `as` casts saturate, so absurdly large inputs clamp to `u64::MAX`.
    Some((value * multiplier as f64) as u64)
}
fn parse_http_version(node: &kdl::KdlNode) -> HttpVersionConfig {
let get_child_int = |name: &str| -> Option<i128> {
node.children()
.and_then(|c| c.nodes().iter().find(|n| n.name().value() == name))
.and_then(|n| n.entries().first())
.and_then(|e| e.value().as_integer())
};
let min_version = get_child_int("min-version").map(|v| v as u8).unwrap_or(1);
let max_version = get_child_int("max-version").map(|v| v as u8).unwrap_or(2);
let h2_ping_interval_secs = get_child_int("h2-ping-interval")
.map(|v| v as u64)
.unwrap_or(0);
let max_h2_streams = get_child_int("max-h2-streams")
.map(|v| v as usize)
.unwrap_or(100);
HttpVersionConfig {
min_version,
max_version,
h2_ping_interval_secs,
max_h2_streams,
}
}
fn parse_connection_pool(node: &kdl::KdlNode) -> ConnectionPoolConfig {
let max_connections = get_int_entry(node, "max-connections")
.map(|v| v as usize)
.unwrap_or(100);
let max_idle = get_int_entry(node, "max-idle")
.map(|v| v as usize)
.unwrap_or(20);
let idle_timeout_secs = get_int_entry(node, "idle-timeout")
.map(|v| v as u64)
.unwrap_or(60);
let max_lifetime_secs = get_int_entry(node, "max-lifetime").map(|v| v as u64);
ConnectionPoolConfig {
max_connections,
max_idle,
idle_timeout_secs,
max_lifetime_secs,
}
}
fn parse_upstream_timeouts(node: &kdl::KdlNode) -> UpstreamTimeouts {
let connect_secs = get_int_entry(node, "connect")
.map(|v| v as u64)
.unwrap_or(10);
let request_secs = get_int_entry(node, "request")
.map(|v| v as u64)
.unwrap_or(60);
let read_secs = get_int_entry(node, "read").map(|v| v as u64).unwrap_or(30);
let write_secs = get_int_entry(node, "write").map(|v| v as u64).unwrap_or(30);
UpstreamTimeouts {
connect_secs,
request_secs,
read_secs,
write_secs,
}
}
/// Builds an [`UpstreamTlsConfig`] from a `tls` block.
///
/// All settings are optional: `sni`, `client-cert`, `client-key` and `ca-cert` stay
/// `None` when absent, and `insecure-skip-verify` defaults to `false`.
fn parse_upstream_tls(node: &kdl::KdlNode) -> UpstreamTlsConfig {
    // Reads a string-valued child and converts it into a filesystem path.
    let path_setting = |key: &str| find_string_entry_from_node(node, key).map(PathBuf::from);

    let skip_verify = match find_bool_entry_from_node(node, "insecure-skip-verify") {
        Some(flag) => flag,
        None => false,
    };

    UpstreamTlsConfig {
        sni: find_string_entry_from_node(node, "sni"),
        insecure_skip_verify: skip_verify,
        client_cert: path_setting("client-cert"),
        client_key: path_setting("client-key"),
        ca_cert: path_setting("ca-cert"),
    }
}
/// Returns the first entry of the child node called `name` as an owned string.
///
/// Yields `None` when `node` has no children block, no child with that name, the child
/// has no entries, or its first entry is not a string.
fn find_string_entry_from_node(node: &kdl::KdlNode, name: &str) -> Option<String> {
    let body = node.children()?;
    let wanted = body
        .nodes()
        .iter()
        .find(|candidate| candidate.name().value() == name)?;
    let first_entry = wanted.entries().first()?;
    first_entry.value().as_string().map(String::from)
}
/// Returns the first entry of the child node called `name` as a boolean.
///
/// Yields `None` when `node` has no children block, no child with that name, the child
/// has no entries, or its first entry is not a boolean.
fn find_bool_entry_from_node(node: &kdl::KdlNode, name: &str) -> Option<bool> {
    let body = node.children()?;
    let wanted = body
        .nodes()
        .iter()
        .find(|candidate| candidate.name().value() == name)?;
    let first_entry = wanted.entries().first()?;
    first_entry.value().as_bool()
}
fn parse_health_check(node: &kdl::KdlNode) -> Result<HealthCheck> {
let check_type = node
.children()
.and_then(|c| c.nodes().iter().find(|n| n.name().value() == "type"))
.map(|type_node| {
let type_name = get_first_arg_string(type_node).unwrap_or_else(|| "tcp".to_string());
match type_name.to_lowercase().as_str() {
"http" => {
let path = type_node
.children()
.and_then(|c| c.nodes().iter().find(|n| n.name().value() == "path"))
.and_then(get_first_arg_string)
.unwrap_or_else(|| "/health".to_string());
let expected_status = type_node
.children()
.and_then(|c| {
c.nodes()
.iter()
.find(|n| n.name().value() == "expected-status")
})
.and_then(get_first_arg_string)
.and_then(|s| s.parse().ok())
.unwrap_or(200);
let host = type_node
.children()
.and_then(|c| c.nodes().iter().find(|n| n.name().value() == "host"))
.and_then(get_first_arg_string);
HealthCheckType::Http {
path,
expected_status,
host,
}
}
"grpc" => {
let service = type_node
.children()
.and_then(|c| c.nodes().iter().find(|n| n.name().value() == "service"))
.and_then(get_first_arg_string)
.unwrap_or_else(|| "grpc.health.v1.Health".to_string());
HealthCheckType::Grpc { service }
}
"inference" => {
let endpoint = type_node
.children()
.and_then(|c| c.nodes().iter().find(|n| n.name().value() == "endpoint"))
.and_then(get_first_arg_string)
.unwrap_or_else(|| "/v1/models".to_string());
let expected_models = type_node
.children()
.and_then(|c| {
c.nodes()
.iter()
.find(|n| n.name().value() == "expected-models")
})
.map(|n| {
n.entries()
.iter()
.filter_map(|e| e.value().as_string().map(|s| s.to_string()))
.collect::<Vec<_>>()
})
.unwrap_or_default();
let readiness = type_node
.children()
.and_then(|c| c.nodes().iter().find(|n| n.name().value() == "readiness"))
.map(|n| Box::new(parse_inference_readiness(n)));
HealthCheckType::Inference {
endpoint,
expected_models,
readiness,
}
}
_ => HealthCheckType::Tcp,
}
})
.unwrap_or(HealthCheckType::Tcp);
let interval_secs = get_int_entry(node, "interval-secs").unwrap_or(10) as u64;
let timeout_secs = get_int_entry(node, "timeout-secs").unwrap_or(5) as u64;
let healthy_threshold = get_int_entry(node, "healthy-threshold").unwrap_or(2) as u32;
let unhealthy_threshold = get_int_entry(node, "unhealthy-threshold").unwrap_or(3) as u32;
Ok(HealthCheck {
check_type,
interval_secs,
timeout_secs,
healthy_threshold,
unhealthy_threshold,
})
}
/// Parses a `readiness` block of an inference health check.
///
/// Up to four optional sections are read: `inference-probe`, `model-status`,
/// `queue-depth` and `warmth-detection`. A section is only produced when its node exists
/// and has a children block; otherwise the corresponding field is `None`. If `node`
/// itself has no children, the default configuration is returned.
fn parse_inference_readiness(node: &kdl::KdlNode) -> zentinel_common::InferenceReadinessConfig {
    use zentinel_common::{
        ColdModelAction, InferenceProbeConfig, InferenceReadinessConfig, ModelStatusConfig,
        QueueDepthConfig, WarmthDetectionConfig,
    };

    /// Child nodes of the first section called `name`, provided that section has a body.
    fn section_nodes<'a>(doc: &'a kdl::KdlDocument, name: &str) -> Option<&'a [kdl::KdlNode]> {
        let section = doc.nodes().iter().find(|n| n.name().value() == name)?;
        Some(section.children()?.nodes())
    }

    let sections = if let Some(doc) = node.children() {
        doc
    } else {
        return InferenceReadinessConfig::default();
    };

    let inference_probe =
        section_nodes(sections, "inference-probe").map(|settings| InferenceProbeConfig {
            endpoint: find_string_entry(settings, "endpoint")
                .unwrap_or_else(|| "/v1/completions".to_string()),
            model: find_string_entry(settings, "model").unwrap_or_default(),
            prompt: find_string_entry(settings, "prompt").unwrap_or_else(|| ".".to_string()),
            max_tokens: find_int_entry(settings, "max-tokens").unwrap_or(1) as u32,
            timeout_secs: find_int_entry(settings, "timeout-secs").unwrap_or(30) as u64,
            max_latency_ms: find_int_entry(settings, "max-latency-ms").map(|v| v as u64),
        });

    let model_status = section_nodes(sections, "model-status").map(|settings| {
        // Each string entry on the `models` node names one model to check.
        let models = match settings.iter().find(|n| n.name().value() == "models") {
            Some(list) => list
                .entries()
                .iter()
                .filter_map(|e| e.value().as_string().map(String::from))
                .collect(),
            None => Default::default(),
        };
        ModelStatusConfig {
            endpoint_pattern: find_string_entry(settings, "endpoint-pattern")
                .unwrap_or_else(|| "/v1/models/{model}/status".to_string()),
            models,
            expected_status: find_string_entry(settings, "expected-status")
                .unwrap_or_else(|| "ready".to_string()),
            status_field: find_string_entry(settings, "status-field")
                .unwrap_or_else(|| "status".to_string()),
            timeout_secs: find_int_entry(settings, "timeout-secs").unwrap_or(5) as u64,
        }
    });

    let queue_depth = section_nodes(sections, "queue-depth").map(|settings| QueueDepthConfig {
        header: find_string_entry(settings, "header"),
        body_field: find_string_entry(settings, "body-field"),
        endpoint: find_string_entry(settings, "endpoint"),
        degraded_threshold: find_int_entry(settings, "degraded-threshold").unwrap_or(50) as u64,
        unhealthy_threshold: find_int_entry(settings, "unhealthy-threshold").unwrap_or(200) as u64,
        timeout_secs: find_int_entry(settings, "timeout-secs").unwrap_or(5) as u64,
    });

    let warmth_detection = section_nodes(sections, "warmth-detection").map(|settings| {
        // Unrecognised action names are treated as log-only; a missing setting uses the
        // enum's default.
        let cold_action = match find_string_entry(settings, "cold-action").as_deref() {
            Some("mark-degraded") | Some("mark_degraded") => ColdModelAction::MarkDegraded,
            Some("mark-unhealthy") | Some("mark_unhealthy") => ColdModelAction::MarkUnhealthy,
            Some(_) => ColdModelAction::LogOnly,
            None => ColdModelAction::default(),
        };
        WarmthDetectionConfig {
            sample_size: find_int_entry(settings, "sample-size").unwrap_or(10) as u32,
            cold_threshold_multiplier: find_float_entry(settings, "cold-threshold-multiplier")
                .unwrap_or(3.0),
            idle_cold_timeout_secs: find_int_entry(settings, "idle-cold-timeout-secs")
                .unwrap_or(300) as u64,
            cold_action,
        }
    });

    InferenceReadinessConfig {
        inference_probe,
        model_status,
        queue_depth,
        warmth_detection,
    }
}
/// Looks up the first node called `name` in `nodes` and returns its first entry as an
/// owned string.
///
/// Yields `None` when no such node exists, it has no entries, or the first entry is not
/// a string.
fn find_string_entry(nodes: &[kdl::KdlNode], name: &str) -> Option<String> {
    let wanted = nodes
        .iter()
        .find(|candidate| candidate.name().value() == name)?;
    let first_entry = wanted.entries().first()?;
    first_entry.value().as_string().map(String::from)
}
fn find_int_entry(nodes: &[kdl::KdlNode], name: &str) -> Option<i64> {
nodes
.iter()
.find(|n| n.name().value() == name)
.and_then(|n| n.entries().first())
.and_then(|e| e.value().as_integer().map(|v| v as i64))
}
/// Looks up the first node called `name` in `nodes` and returns its first entry as an
/// `f64`.
///
/// A float value is returned as-is; an integer value is converted to `f64`. Yields `None`
/// when no such node exists, it has no entries, or the first entry is neither a float nor
/// an integer.
fn find_float_entry(nodes: &[kdl::KdlNode], name: &str) -> Option<f64> {
    let wanted = nodes
        .iter()
        .find(|candidate| candidate.name().value() == name)?;
    let value = wanted.entries().first()?.value();
    match value.as_float() {
        Some(float) => Some(float),
        None => value.as_integer().map(|int| int as f64),
    }
}
// Unit tests covering how the `tls` block of an upstream is parsed.
#[cfg(test)]
mod tests {
use super::*;
// Parses `input` as a KDL document and runs `parse_upstreams` on its first top-level
// node. Panics if the text is not valid KDL or the document is empty.
fn parse_kdl_upstreams(input: &str) -> Result<HashMap<String, UpstreamConfig>> {
let doc: kdl::KdlDocument = input.parse().unwrap();
let node = doc.nodes().first().unwrap();
parse_upstreams(node)
}
// A `tls` block with every supported setting populates every field.
#[test]
fn test_parse_upstream_tls_full() {
let kdl = r#"
upstreams {
upstream "secure-backend" {
target "10.0.0.1:8443"
tls {
sni "backend.example.com"
insecure-skip-verify #false
client-cert "/path/to/client.crt"
client-key "/path/to/client.key"
ca-cert "/path/to/ca.crt"
}
}
}
"#;
let upstreams = parse_kdl_upstreams(kdl).unwrap();
let upstream = upstreams.get("secure-backend").unwrap();
assert!(upstream.tls.is_some());
let tls = upstream.tls.as_ref().unwrap();
assert_eq!(tls.sni, Some("backend.example.com".to_string()));
assert!(!tls.insecure_skip_verify);
assert_eq!(tls.client_cert, Some(PathBuf::from("/path/to/client.crt")));
assert_eq!(tls.client_key, Some(PathBuf::from("/path/to/client.key")));
assert_eq!(tls.ca_cert, Some(PathBuf::from("/path/to/ca.crt")));
}
// With only `sni` set, verification stays enabled and no certificate paths are set.
#[test]
fn test_parse_upstream_tls_minimal() {
let kdl = r#"
upstreams {
upstream "simple-tls" {
target "10.0.0.1:443"
tls {
sni "api.example.com"
}
}
}
"#;
let upstreams = parse_kdl_upstreams(kdl).unwrap();
let upstream = upstreams.get("simple-tls").unwrap();
assert!(upstream.tls.is_some());
let tls = upstream.tls.as_ref().unwrap();
assert_eq!(tls.sni, Some("api.example.com".to_string()));
// `insecure-skip-verify` defaults to false when it is not specified.
assert!(!tls.insecure_skip_verify); assert!(tls.client_cert.is_none());
assert!(tls.client_key.is_none());
assert!(tls.ca_cert.is_none());
}
// `insecure-skip-verify #true` is honoured even when no other TLS setting is present.
#[test]
fn test_parse_upstream_tls_insecure() {
let kdl = r#"
upstreams {
upstream "dev-backend" {
target "localhost:8443"
tls {
insecure-skip-verify #true
}
}
}
"#;
let upstreams = parse_kdl_upstreams(kdl).unwrap();
let upstream = upstreams.get("dev-backend").unwrap();
assert!(upstream.tls.is_some());
let tls = upstream.tls.as_ref().unwrap();
assert!(tls.insecure_skip_verify);
assert!(tls.sni.is_none());
}
// An upstream without a `tls` block has no TLS configuration at all.
#[test]
fn test_parse_upstream_no_tls() {
let kdl = r#"
upstreams {
upstream "plain-http" {
target "10.0.0.1:8080"
}
}
"#;
let upstreams = parse_kdl_upstreams(kdl).unwrap();
let upstream = upstreams.get("plain-http").unwrap();
assert!(upstream.tls.is_none());
}
// Client certificate and key can be configured without `sni` or `ca-cert`.
#[test]
fn test_parse_upstream_tls_mtls_only() {
let kdl = r#"
upstreams {
upstream "mtls-backend" {
target "10.0.0.1:443"
tls {
client-cert "/certs/client.pem"
client-key "/certs/client-key.pem"
}
}
}
"#;
let upstreams = parse_kdl_upstreams(kdl).unwrap();
let upstream = upstreams.get("mtls-backend").unwrap();
assert!(upstream.tls.is_some());
let tls = upstream.tls.as_ref().unwrap();
assert_eq!(tls.client_cert, Some(PathBuf::from("/certs/client.pem")));
assert_eq!(tls.client_key, Some(PathBuf::from("/certs/client-key.pem")));
assert!(tls.sni.is_none());
assert!(tls.ca_cert.is_none());
}
}