use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use validator::Validate;
use zentinel_common::types::{HealthCheckType, LoadBalancingAlgorithm};
/// `SameSite` attribute applied to sticky-session cookies.
///
/// In configuration files the variants are written in lowercase
/// (`lax`, `strict`, `none`). `Lax` is the default when the field is omitted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SameSitePolicy {
    /// `SameSite=Lax` (default).
    #[default]
    Lax,
    /// `SameSite=Strict`.
    Strict,
    /// `SameSite=None`.
    None,
}
impl std::fmt::Display for SameSitePolicy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SameSitePolicy::Lax => write!(f, "Lax"),
SameSitePolicy::Strict => write!(f, "Strict"),
SameSitePolicy::None => write!(f, "None"),
}
}
}
/// Cookie-based session affinity settings for an upstream.
///
/// `cookie_name` and `cookie_ttl_secs` are required. The remaining fields
/// fall back to the serde defaults named on each field.
///
/// NOTE(review): browsers reject `SameSite=None` cookies that are not also
/// `Secure`. Nothing here prevents `cookie_same_site = none` together with
/// `cookie_secure = false`, so confirm whether a caller validates that pairing.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct StickySessionConfig {
    /// Name of the affinity cookie.
    pub cookie_name: String,
    /// Cookie lifetime, in seconds.
    pub cookie_ttl_secs: u64,
    /// Cookie `Path` attribute; defaults to `"/"`.
    #[serde(default = "default_cookie_path")]
    pub cookie_path: String,
    /// Whether the cookie carries the `Secure` attribute; defaults to `true`.
    #[serde(default = "default_cookie_secure")]
    pub cookie_secure: bool,
    /// Cookie `SameSite` policy; defaults to [`SameSitePolicy::Lax`].
    #[serde(default)]
    pub cookie_same_site: SameSitePolicy,
    /// Algorithm used when no affinity applies; defaults to round robin.
    #[serde(default = "default_sticky_fallback")]
    pub fallback: LoadBalancingAlgorithm,
}
/// Serde default for `StickySessionConfig::cookie_path`: the site root.
fn default_cookie_path() -> String {
    String::from("/")
}

/// Serde default for `StickySessionConfig::cookie_secure`: mark cookies `Secure`.
fn default_cookie_secure() -> bool {
    true
}
/// Serde default for `StickySessionConfig::fallback`: round-robin selection.
fn default_sticky_fallback() -> LoadBalancingAlgorithm {
    LoadBalancingAlgorithm::RoundRobin
}
/// Configuration for one upstream pool: its targets and how to reach them.
///
/// Validation (via `validator`) requires at least one entry in `targets`.
/// Optional sections (`sticky_session`, `health_check`, `tls`) are absent
/// unless configured. The other sections use their `Default` impls when omitted.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize, Validate)]
pub struct UpstreamConfig {
    /// Identifier of this upstream.
    pub id: String,
    /// Backend targets; must not be empty.
    #[validate(length(min = 1, message = "At least one target is required"))]
    pub targets: Vec<UpstreamTarget>,
    /// Target selection algorithm; defaults to round robin.
    #[serde(default = "default_lb_algorithm")]
    pub load_balancing: LoadBalancingAlgorithm,
    /// Optional cookie-based session affinity.
    pub sticky_session: Option<StickySessionConfig>,
    /// Optional active health checking.
    pub health_check: Option<HealthCheck>,
    /// Connection pool limits.
    #[serde(default)]
    pub connection_pool: ConnectionPoolConfig,
    /// Per-phase timeouts.
    #[serde(default)]
    pub timeouts: UpstreamTimeouts,
    /// Optional TLS settings for connections to the targets.
    pub tls: Option<UpstreamTlsConfig>,
    /// Allowed HTTP protocol versions.
    #[serde(default)]
    pub http_version: HttpVersionConfig,
}
/// HTTP protocol version bounds and HTTP/2 tuning for an upstream.
///
/// Each omitted field takes the serde default named on it.
/// `h2_ping_interval_secs` defaults to `0`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct HttpVersionConfig {
    /// Lowest allowed major version (default `1`).
    #[serde(default = "default_min_http_version")]
    pub min_version: u8,
    /// Highest allowed major version (default `2`).
    #[serde(default = "default_max_http_version")]
    pub max_version: u8,
    /// HTTP/2 ping interval in seconds (default `0`).
    /// NOTE(review): `0` presumably disables pings; confirm against the consumer.
    #[serde(default)]
    pub h2_ping_interval_secs: u64,
    /// Maximum concurrent HTTP/2 streams (default `100`).
    #[serde(default = "default_max_h2_streams")]
    pub max_h2_streams: usize,
}
impl Default for HttpVersionConfig {
    /// Builds the same values serde produces when every field is omitted.
    fn default() -> Self {
        let min_version = default_min_http_version();
        let max_version = default_max_http_version();
        let max_h2_streams = default_max_h2_streams();
        Self {
            min_version,
            max_version,
            h2_ping_interval_secs: 0,
            max_h2_streams,
        }
    }
}
/// Serde default for `HttpVersionConfig::min_version`.
fn default_min_http_version() -> u8 {
    1
}

/// Serde default for `HttpVersionConfig::max_version`.
// Previously written as `2 }` on a single line, which rustfmt rejects.
fn default_max_http_version() -> u8 {
    2
}

/// Serde default for `HttpVersionConfig::max_h2_streams`.
fn default_max_h2_streams() -> usize {
    100
}
/// A single backend endpoint within an upstream.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize, Validate)]
pub struct UpstreamTarget {
    /// Address of the target as written in configuration.
    pub address: String,
    /// Relative weight for weighted algorithms; defaults to `1`.
    #[serde(default = "default_weight")]
    pub weight: u32,
    /// Optional per-target request cap.
    pub max_requests: Option<u32>,
    /// Free-form key/value labels; empty when omitted.
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}
/// Active health-check settings for an upstream.
///
/// The check kind is read from the `type` key. Omitted timing and threshold
/// fields use the `default_health_check_*` and `default_*_threshold` helpers.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct HealthCheck {
    /// Kind of probe to run (serialized as `type`).
    #[serde(rename = "type")]
    pub check_type: HealthCheckType,
    /// Seconds between probes.
    #[serde(default = "default_health_check_interval")]
    pub interval_secs: u64,
    /// Seconds before a probe is considered failed.
    #[serde(default = "default_health_check_timeout")]
    pub timeout_secs: u64,
    /// Consecutive successes needed to mark a target healthy.
    #[serde(default = "default_healthy_threshold")]
    pub healthy_threshold: u32,
    /// Consecutive failures needed to mark a target unhealthy.
    #[serde(default = "default_unhealthy_threshold")]
    pub unhealthy_threshold: u32,
}
/// Connection pool limits applied per upstream.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ConnectionPoolConfig {
    /// Maximum open connections (see `default_max_connections_per_target`).
    #[serde(default = "default_max_connections_per_target")]
    pub max_connections: usize,
    /// Maximum idle connections kept around.
    #[serde(default = "default_max_idle_connections")]
    pub max_idle: usize,
    /// Seconds an idle connection may linger.
    #[serde(default = "default_idle_timeout")]
    pub idle_timeout_secs: u64,
    /// Optional cap on a connection's total lifetime, in seconds.
    pub max_lifetime_secs: Option<u64>,
}
impl Default for ConnectionPoolConfig {
    /// Mirrors the serde defaults, with no lifetime cap.
    fn default() -> Self {
        let max_connections = default_max_connections_per_target();
        let max_idle = default_max_idle_connections();
        let idle_timeout_secs = default_idle_timeout();
        Self {
            max_connections,
            max_idle,
            idle_timeout_secs,
            max_lifetime_secs: None,
        }
    }
}
/// Per-phase timeouts for upstream traffic, all in seconds.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct UpstreamTimeouts {
    /// Connection establishment timeout.
    #[serde(default = "default_connect_timeout")]
    pub connect_secs: u64,
    /// Whole-request timeout.
    #[serde(default = "default_upstream_request_timeout")]
    pub request_secs: u64,
    /// Read timeout.
    #[serde(default = "default_read_timeout")]
    pub read_secs: u64,
    /// Write timeout.
    #[serde(default = "default_write_timeout")]
    pub write_secs: u64,
}
impl Default for UpstreamTimeouts {
    /// Mirrors the serde defaults for every timeout.
    fn default() -> Self {
        let connect_secs = default_connect_timeout();
        let request_secs = default_upstream_request_timeout();
        let read_secs = default_read_timeout();
        let write_secs = default_write_timeout();
        Self {
            connect_secs,
            request_secs,
            read_secs,
            write_secs,
        }
    }
}
/// TLS settings for connections from the proxy to upstream targets.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct UpstreamTlsConfig {
    /// Optional SNI host name to send.
    pub sni: Option<String>,
    /// Defaults to `false`. As the name indicates, this presumably disables
    /// certificate verification, so enable it only for testing.
    #[serde(default)]
    pub insecure_skip_verify: bool,
    /// Optional client certificate path.
    pub client_cert: Option<PathBuf>,
    /// Optional client private key path.
    pub client_key: Option<PathBuf>,
    /// Optional CA certificate path.
    pub ca_cert: Option<PathBuf>,
}
/// Connection parameters for one upstream peer.
///
/// Unlike the configuration types, this struct is not (de)serializable.
/// NOTE(review): presumably it is built from an `UpstreamConfig` at runtime;
/// confirm with the constructing code.
#[derive(Clone)]
#[derive(Debug)]
pub struct UpstreamPeer {
    /// Address to connect to.
    pub address: String,
    /// Whether the connection uses TLS.
    pub tls: bool,
    /// Host name associated with the peer.
    pub host: String,
    /// Connect timeout, in seconds.
    pub connect_timeout_secs: u64,
    /// Read timeout, in seconds.
    pub read_timeout_secs: u64,
    /// Write timeout, in seconds.
    pub write_timeout_secs: u64,
}
/// Serde default for `UpstreamConfig::load_balancing`: round-robin selection.
fn default_lb_algorithm() -> LoadBalancingAlgorithm {
    LoadBalancingAlgorithm::RoundRobin
}
/// Serde default for `UpstreamTarget::weight`.
fn default_weight() -> u32 {
    1
}

/// Serde default for `HealthCheck::interval_secs`.
fn default_health_check_interval() -> u64 {
    10
}

/// Serde default for `HealthCheck::timeout_secs`.
fn default_health_check_timeout() -> u64 {
    5
}

/// Serde default for `HealthCheck::healthy_threshold`.
fn default_healthy_threshold() -> u32 {
    2
}

/// Serde default for `HealthCheck::unhealthy_threshold`.
fn default_unhealthy_threshold() -> u32 {
    3
}

/// Serde default for `ConnectionPoolConfig::max_connections`.
fn default_max_connections_per_target() -> usize {
    100
}

/// Serde default for `ConnectionPoolConfig::max_idle`.
fn default_max_idle_connections() -> usize {
    20
}

/// Serde default for `ConnectionPoolConfig::idle_timeout_secs`.
fn default_idle_timeout() -> u64 {
    60
}

/// Serde default for `UpstreamTimeouts::connect_secs`; also visible crate-wide.
pub(crate) fn default_connect_timeout() -> u64 {
    10
}

/// Serde default for `UpstreamTimeouts::request_secs`.
fn default_upstream_request_timeout() -> u64 {
    60
}

/// Serde default for `UpstreamTimeouts::read_secs`; also visible crate-wide.
pub(crate) fn default_read_timeout() -> u64 {
    30
}

/// Serde default for `UpstreamTimeouts::write_secs`; also visible crate-wide.
pub(crate) fn default_write_timeout() -> u64 {
    30
}