use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use validator::Validate;
use zentinel_common::budget::{CostAttributionConfig, TokenBudgetConfig};
use zentinel_common::types::{ByteSize, CircuitBreakerConfig, Priority, RetryPolicy};
use crate::filters::RateLimitKey;
/// Configuration for a single route.
///
/// Only `id` and `matches` are required when deserializing: `upstream` is an
/// `Option` (absent → `None`) and every other field carries `#[serde(default)]`.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct RouteConfig {
    /// Route identifier.
    pub id: String,
    /// Route priority; `Priority::default()` when omitted.
    #[serde(default)]
    pub priority: Priority,
    /// Match conditions attached to this route.
    pub matches: Vec<MatchCondition>,
    /// Name of the upstream this route targets, if any.
    pub upstream: Option<String>,
    /// Kind of service behind the route; `ServiceType::default()` when omitted.
    #[serde(default)]
    pub service_type: ServiceType,
    /// Per-route policies; `RoutePolicies::default()` when omitted.
    #[serde(default)]
    pub policies: RoutePolicies,
    /// Filters for this route, referenced by name (presumably filter ids — confirm).
    #[serde(default)]
    pub filters: Vec<String>,
    /// Built-in handler. Unlike the other fields, the serialized key is
    /// kebab-case: `builtin-handler`.
    #[serde(default)]
    #[serde(rename = "builtin-handler")]
    pub builtin_handler: Option<BuiltinHandler>,
    /// WAF toggle; `false` when omitted.
    #[serde(default)]
    pub waf_enabled: bool,
    /// Optional route-level circuit breaker settings.
    #[serde(default)]
    pub circuit_breaker: Option<CircuitBreakerConfig>,
    /// Optional route-level retry policy.
    #[serde(default)]
    pub retry_policy: Option<RetryPolicy>,
    /// Optional static-file serving settings.
    #[serde(default)]
    pub static_files: Option<StaticFileConfig>,
    /// Optional API schema validation settings.
    #[serde(default)]
    pub api_schema: Option<ApiSchemaConfig>,
    /// Optional inference (LLM) settings.
    #[serde(default)]
    pub inference: Option<InferenceConfig>,
    /// Optional custom error page settings.
    #[serde(default)]
    pub error_pages: Option<ErrorPageConfig>,
    /// WebSocket toggle; `false` when omitted.
    #[serde(default)]
    pub websocket: bool,
    /// WebSocket inspection toggle; `false` when omitted.
    #[serde(default)]
    pub websocket_inspection: bool,
    /// Optional traffic-shadowing settings.
    #[serde(default)]
    pub shadow: Option<ShadowConfig>,
    /// Optional fallback-upstream settings.
    #[serde(default)]
    pub fallback: Option<FallbackConfig>,
}
/// A single request-matching condition for a route.
///
/// Serialized with serde's default external tagging. Variant keys are
/// snake_case: `path_prefix`, `path`, `path_regex`, `host`, `header`,
/// `method`, `query_param`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MatchCondition {
    /// Path prefix to compare against.
    PathPrefix(String),
    /// Exact path to compare against.
    Path(String),
    /// Path regular expression (pattern text only; compiled elsewhere).
    PathRegex(String),
    /// Host to compare against.
    Host(String),
    /// Header condition. `value: None` presumably means "header present" —
    /// confirm in the matcher.
    Header {
        name: String,
        value: Option<String>,
    },
    /// Allowed HTTP methods.
    Method(Vec<String>),
    /// Query-parameter condition. `value: None` presumably means "parameter
    /// present" — confirm in the matcher.
    QueryParam {
        name: String,
        value: Option<String>,
    },
}
/// Kind of service a route fronts.
///
/// Serialized in snake_case (`web`, `api`, `static`, `builtin`, `inference`).
/// `Web` is the default.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum ServiceType {
    /// Default service type.
    #[default]
    Web,
    /// API service.
    Api,
    /// Static-file service.
    Static,
    /// Served by a built-in handler.
    Builtin,
    /// Inference (LLM) service.
    Inference,
}
/// Built-in handlers a route can dispatch to instead of an upstream.
///
/// Serialized in snake_case, e.g. `not_found`, `cache_purge`, `cache_stats`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum BuiltinHandler {
    /// `status`
    Status,
    /// `health`
    Health,
    /// `metrics`
    Metrics,
    /// `not_found`
    NotFound,
    /// `config`
    Config,
    /// `upstreams`
    Upstreams,
    /// `cache_purge`
    CachePurge,
    /// `cache_stats`
    CacheStats,
}
/// Request/response policies attached to a route.
///
/// Every field may be omitted. `Option` fields become `None`, and the rest use
/// their `#[serde(default)]`s. `failure_mode` falls back to
/// `default_failure_mode()`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct RoutePolicies {
    /// Header changes for the request side; empty by default.
    #[serde(default)]
    pub request_headers: HeaderModifications,
    /// Header changes for the response side; empty by default.
    #[serde(default)]
    pub response_headers: HeaderModifications,
    /// Timeout in seconds, if configured.
    pub timeout_secs: Option<u64>,
    /// Maximum body size, if configured.
    pub max_body_size: Option<ByteSize>,
    /// Rate limit, if configured.
    pub rate_limit: Option<RateLimitPolicy>,
    /// Failure handling mode; `default_failure_mode()` when omitted.
    #[serde(default = "default_failure_mode")]
    pub failure_mode: FailureMode,
    /// Request buffering toggle; `false` when omitted.
    #[serde(default)]
    pub buffer_requests: bool,
    /// Response buffering toggle; `false` when omitted.
    #[serde(default)]
    pub buffer_responses: bool,
    /// Per-route cache settings, if configured.
    #[serde(default)]
    pub cache: Option<RouteCacheConfig>,
}
/// Per-route HTTP cache settings.
///
/// Every field is optional when deserializing. Omitted scalars come from the
/// `default_*` helpers named on each field. Omitted lists are empty, and
/// omitted booleans are `false`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteCacheConfig {
    /// Cache toggle; `false` when omitted.
    #[serde(default)]
    pub enabled: bool,
    /// TTL in seconds; `default_cache_ttl()` when omitted.
    #[serde(default = "default_cache_ttl")]
    pub default_ttl_secs: u64,
    /// Size limit in bytes; `default_max_cache_size()` when omitted.
    #[serde(default = "default_max_cache_size")]
    pub max_size_bytes: usize,
    /// Toggle for caching private responses; `false` when omitted.
    #[serde(default)]
    pub cache_private: bool,
    /// Stale-while-revalidate window in seconds; `default_stale_while_revalidate()` when omitted.
    #[serde(default = "default_stale_while_revalidate")]
    pub stale_while_revalidate_secs: u64,
    /// Stale-if-error window in seconds; `default_stale_if_error()` when omitted.
    #[serde(default = "default_stale_if_error")]
    pub stale_if_error_secs: u64,
    /// HTTP methods eligible for caching; `default_cacheable_methods()` when omitted.
    #[serde(default = "default_cacheable_methods")]
    pub cacheable_methods: Vec<String>,
    /// Status codes eligible for caching; `default_cacheable_status_codes()` when omitted.
    #[serde(default = "default_cacheable_status_codes")]
    pub cacheable_status_codes: Vec<u16>,
    /// Extra headers to vary the cache on.
    #[serde(default)]
    pub vary_headers: Vec<String>,
    /// Query parameters to ignore (presumably when building the cache key — confirm).
    #[serde(default)]
    pub ignore_query_params: Vec<String>,
    /// File extensions excluded from caching.
    #[serde(default)]
    pub exclude_extensions: Vec<String>,
    /// Paths excluded from caching.
    #[serde(default)]
    pub exclude_paths: Vec<String>,
}
impl Default for RouteCacheConfig {
    /// Produces the same value serde builds when every field is omitted.
    ///
    /// Caching is off, private responses are not cached, all lists except
    /// methods and status codes are empty, and scalars come from the
    /// `default_*` helpers.
    fn default() -> Self {
        Self {
            // Boolean switches.
            enabled: false,
            cache_private: false,
            // Numeric limits and windows.
            default_ttl_secs: default_cache_ttl(),
            max_size_bytes: default_max_cache_size(),
            stale_while_revalidate_secs: default_stale_while_revalidate(),
            stale_if_error_secs: default_stale_if_error(),
            // Eligibility lists.
            cacheable_methods: default_cacheable_methods(),
            cacheable_status_codes: default_cacheable_status_codes(),
            // No extra vary headers or exclusions.
            vary_headers: Vec::default(),
            ignore_query_params: Vec::default(),
            exclude_extensions: Vec::default(),
            exclude_paths: Vec::default(),
        }
    }
}
/// Serde default for `RouteCacheConfig::default_ttl_secs`: one hour.
fn default_cache_ttl() -> u64 {
    60 * 60
}
/// Serde default for `RouteCacheConfig::max_size_bytes`: 10 MiB.
fn default_max_cache_size() -> usize {
    10 << 20
}
/// Serde default for `RouteCacheConfig::stale_while_revalidate_secs`: one minute.
fn default_stale_while_revalidate() -> u64 {
    60
}
/// Serde default for `RouteCacheConfig::stale_if_error_secs`: five minutes.
fn default_stale_if_error() -> u64 {
    5 * 60
}
/// Serde default for `RouteCacheConfig::cacheable_methods`: `GET` and `HEAD`.
fn default_cacheable_methods() -> Vec<String> {
    ["GET", "HEAD"].iter().map(|method| method.to_string()).collect()
}
/// Serde default for `RouteCacheConfig::cacheable_status_codes`.
fn default_cacheable_status_codes() -> Vec<u16> {
    [200_u16, 203, 204, 206, 300, 301, 308, 404, 410].to_vec()
}
/// Cache storage settings.
///
/// Every field may be omitted when deserializing:
/// - `enabled` defaults to `default_cache_enabled()`;
/// - `backend` defaults to `CacheBackend::default()`;
/// - `Option` fields default to `None`;
/// - the remaining scalars come from their named `default_*` helpers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStorageConfig {
    /// Cache toggle; `default_cache_enabled()` when omitted.
    #[serde(default = "default_cache_enabled")]
    pub enabled: bool,
    /// Storage backend; `CacheBackend::default()` when omitted.
    #[serde(default)]
    pub backend: CacheBackend,
    /// Size limit in bytes; `default_cache_storage_size()` when omitted.
    #[serde(default = "default_cache_storage_size")]
    pub max_size_bytes: usize,
    /// Optional eviction limit in bytes.
    #[serde(default)]
    pub eviction_limit_bytes: Option<usize>,
    /// Lock timeout in seconds; `default_cache_lock_timeout()` when omitted.
    #[serde(default = "default_cache_lock_timeout")]
    pub lock_timeout_secs: u64,
    /// Optional on-disk storage directory.
    #[serde(default)]
    pub disk_path: Option<PathBuf>,
    /// Disk shard count; `default_disk_shards()` when omitted.
    #[serde(default = "default_disk_shards")]
    pub disk_shards: u32,
    /// Optional on-disk size limit in bytes.
    #[serde(default)]
    pub disk_max_size_bytes: Option<usize>,
    /// Cache status header toggle; `false` when omitted.
    #[serde(default)]
    pub status_header: bool,
    /// Status header name setting; `default_status_header_name()` when omitted.
    #[serde(default = "default_status_header_name")]
    pub status_header_name: String,
}
/// Serde default for `CacheStorageConfig::status_header_name`.
fn default_status_header_name() -> String {
    String::from("zentinel")
}
impl Default for CacheStorageConfig {
    /// Produces the same value serde builds when every field is omitted.
    ///
    /// The result is an enabled in-memory cache with no eviction or disk
    /// limits, no disk path, and the status header switched off.
    fn default() -> Self {
        Self {
            enabled: true,
            status_header: false,
            backend: CacheBackend::Memory,
            // Optional limits/paths are unset.
            eviction_limit_bytes: None,
            disk_path: None,
            disk_max_size_bytes: None,
            // Values shared with the serde `default = "..."` helpers.
            max_size_bytes: default_cache_storage_size(),
            lock_timeout_secs: default_cache_lock_timeout(),
            disk_shards: default_disk_shards(),
            status_header_name: default_status_header_name(),
        }
    }
}
/// Where cached objects are stored.
///
/// Serialized in snake_case (`memory`, `disk`, `hybrid`). `Memory` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum CacheBackend {
    /// In-memory storage (default).
    #[default]
    Memory,
    /// On-disk storage.
    Disk,
    /// Presumably a memory + disk tiered combination — confirm in the storage layer.
    Hybrid,
}
/// Serde default for `CacheStorageConfig::enabled`: on.
fn default_cache_enabled() -> bool {
    true
}
/// Serde default for `CacheStorageConfig::max_size_bytes`: 100 MiB.
fn default_cache_storage_size() -> usize {
    100 << 20
}
/// Serde default for `CacheStorageConfig::lock_timeout_secs`: ten seconds.
fn default_cache_lock_timeout() -> u64 {
    10
}
/// Serde default for `CacheStorageConfig::disk_shards`.
fn default_disk_shards() -> u32 {
    1 << 4
}
/// A set of header edits; every collection is empty when omitted.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct HeaderModifications {
    /// Header renames (presumably old name → new name — confirm in the applier).
    #[serde(default)]
    pub rename: HashMap<String, String>,
    /// Headers to set (name → value).
    #[serde(default)]
    pub set: HashMap<String, String>,
    /// Headers to add (name → value).
    #[serde(default)]
    pub add: HashMap<String, String>,
    /// Header names to remove.
    #[serde(default)]
    pub remove: Vec<String>,
}
/// Route-level request rate limit; all three fields are required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitPolicy {
    /// Sustained rate in requests per second.
    pub requests_per_second: u32,
    /// Burst allowance.
    pub burst: u32,
    /// What requests are grouped by for limiting.
    pub key: RateLimitKey,
}
/// Route failure handling mode, serialized as `open` / `closed`.
///
/// `Closed` is the default. Presumably "open" lets traffic through when a
/// check cannot complete and "closed" rejects it — confirm at the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum FailureMode {
    /// `open`
    Open,
    /// `closed` (default)
    #[default]
    Closed,
}
/// Serde default for `RoutePolicies::failure_mode`.
///
/// Spelled out explicitly rather than via `FailureMode::default()` so this
/// default stays fixed even if the enum's `#[default]` moves.
pub(crate) fn default_failure_mode() -> FailureMode {
    FailureMode::Closed
}
/// Static-file serving settings; only `root` is required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StaticFileConfig {
    /// Directory files are served from.
    pub root: PathBuf,
    /// Index file name; `default_index_file()` when omitted.
    #[serde(default = "default_index_file")]
    pub index: String,
    /// Directory listing toggle; `false` when omitted.
    #[serde(default)]
    pub directory_listing: bool,
    /// `Cache-Control` value; `default_cache_control()` when omitted.
    #[serde(default = "default_cache_control")]
    pub cache_control: String,
    /// Compression toggle; `true` when omitted.
    #[serde(default = "default_true")]
    pub compress: bool,
    /// Extra MIME mappings (presumably extension → MIME type — confirm).
    #[serde(default)]
    pub mime_types: HashMap<String, String>,
    /// Optional fallback file/path (usage defined by the static-file handler).
    pub fallback: Option<String>,
}
/// Serde default for `StaticFileConfig::index`.
fn default_index_file() -> String {
    String::from("index.html")
}
/// Serde default for `StaticFileConfig::cache_control`: public, one hour.
fn default_cache_control() -> String {
    String::from("public, max-age=3600")
}
/// Shared serde default for boolean fields that are on unless disabled.
fn default_true() -> bool {
    true
}
/// API schema validation settings.
///
/// Schema sources are all optional. Request validation defaults to on;
/// response validation and strict mode default to off.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiSchemaConfig {
    /// Schema file path, if any.
    pub schema_file: Option<PathBuf>,
    /// Inline schema text, if any.
    pub schema_content: Option<String>,
    /// Inline request schema as JSON, if any.
    pub request_schema: Option<serde_json::Value>,
    /// Inline response schema as JSON, if any.
    pub response_schema: Option<serde_json::Value>,
    /// Request validation toggle; `true` when omitted.
    #[serde(default = "default_true")]
    pub validate_requests: bool,
    /// Response validation toggle; `false` when omitted.
    #[serde(default)]
    pub validate_responses: bool,
    /// Strict mode toggle; `false` when omitted.
    #[serde(default)]
    pub strict_mode: bool,
}
/// Custom error page settings for a route.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorPageConfig {
    /// Pages keyed by HTTP status code; empty when omitted.
    #[serde(default)]
    pub pages: HashMap<u16, ErrorPage>,
    /// Format used when no page-specific one applies; `ErrorFormat::default()` when omitted.
    #[serde(default)]
    pub default_format: ErrorFormat,
    /// Stack-trace inclusion toggle; `false` when omitted.
    #[serde(default)]
    pub include_stack_trace: bool,
    /// Template directory, if any.
    pub template_dir: Option<PathBuf>,
}
/// One custom error page; `format` is required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorPage {
    /// Output format for this page.
    pub format: ErrorFormat,
    /// Template path, if any.
    pub template: Option<PathBuf>,
    /// Message text, if any.
    pub message: Option<String>,
    /// Extra response headers; empty when omitted.
    #[serde(default)]
    pub headers: HashMap<String, String>,
}
/// Error body format, serialized in snake_case (`html`, `json`, `text`, `xml`).
///
/// `Html` is the default.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum ErrorFormat {
    /// HTML (default).
    #[default]
    Html,
    /// JSON.
    Json,
    /// Plain text.
    Text,
    /// XML.
    Xml,
}
/// Traffic shadowing (mirroring) settings; only `upstream` is required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShadowConfig {
    /// Upstream that receives shadow traffic.
    pub upstream: String,
    /// Share of traffic to shadow; `default_shadow_percentage()` when omitted.
    /// Not range-checked here.
    #[serde(default = "default_shadow_percentage")]
    pub percentage: f64,
    /// Optional `(header name, value)` pair for sampling.
    pub sample_header: Option<(String, String)>,
    /// Timeout in milliseconds; `default_shadow_timeout_ms()` when omitted.
    #[serde(default = "default_shadow_timeout_ms")]
    pub timeout_ms: u64,
    /// Body buffering toggle; `false` when omitted.
    #[serde(default)]
    pub buffer_body: bool,
    /// Body size cap in bytes; `default_shadow_max_body_bytes()` when omitted.
    #[serde(default = "default_shadow_max_body_bytes")]
    pub max_body_bytes: usize,
}
/// Serde default for `ShadowConfig::percentage`: shadow everything.
fn default_shadow_percentage() -> f64 {
    100.0_f64
}
/// Serde default for `ShadowConfig::timeout_ms`: five seconds.
fn default_shadow_timeout_ms() -> u64 {
    5 * 1_000
}
/// Serde default for `ShadowConfig::max_body_bytes`: 1 MiB.
fn default_shadow_max_body_bytes() -> usize {
    1 << 20
}
/// Inference (LLM) settings for a route.
///
/// Everything is optional: `provider` defaults to `InferenceProvider::default()`
/// and every other field is an `Option`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct InferenceConfig {
    /// Provider family; `InferenceProvider::default()` when omitted.
    #[serde(default)]
    pub provider: InferenceProvider,
    /// Header name carrying the model, if any.
    pub model_header: Option<String>,
    /// Token rate limit, if any.
    pub rate_limit: Option<TokenRateLimit>,
    /// Token budget, if any.
    pub budget: Option<TokenBudgetConfig>,
    /// Cost attribution settings, if any.
    pub cost_attribution: Option<CostAttributionConfig>,
    /// Upstream selection strategy, if any.
    pub routing: Option<InferenceRouting>,
    /// Model → upstream routing, if any.
    pub model_routing: Option<ModelRoutingConfig>,
    /// Guardrail agents, if any.
    pub guardrails: Option<GuardrailsConfig>,
}
/// LLM API provider family used by inference routes.
///
/// Serde spells the variants in snake_case: `generic`, `open_ai`, `anthropic`.
/// `InferenceProvider::as_str` reports `OpenAi` as `"openai"`. That spelling is
/// also accepted on input via `alias`, so the crate's own name for the provider
/// deserializes back to the same variant. Output stays `open_ai` for
/// backward compatibility.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum InferenceProvider {
    /// Provider-agnostic handling (default).
    #[default]
    Generic,
    /// OpenAI; deserializes from `open_ai` or `openai`.
    #[serde(alias = "openai")]
    OpenAi,
    /// Anthropic.
    Anthropic,
}
impl InferenceProvider {
    /// Returns a static lowercase name for the provider.
    ///
    /// Note: `OpenAi` maps to `"openai"` here, while serde's `snake_case`
    /// rename spells that variant `"open_ai"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Anthropic => "anthropic",
            Self::OpenAi => "openai",
            Self::Generic => "generic",
        }
    }
}
/// Token-based rate limit for inference routes; `tokens_per_minute` is required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenRateLimit {
    /// Token allowance per minute.
    pub tokens_per_minute: u64,
    /// Optional request allowance per minute.
    pub requests_per_minute: Option<u64>,
    /// Burst allowance in tokens; `default_burst_tokens()` when omitted.
    #[serde(default = "default_burst_tokens")]
    pub burst_tokens: u64,
    /// Token estimation method; `TokenEstimation::default()` when omitted.
    #[serde(default)]
    pub estimation_method: TokenEstimation,
}
/// Serde default for `TokenRateLimit::burst_tokens`.
fn default_burst_tokens() -> u64 {
    10_000
}
/// Token-count estimation method, serialized as `chars` / `words` / `tiktoken`.
///
/// `Chars` is the default. The variant names suggest character-, word- and
/// tiktoken-based estimation; confirm in the estimator.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum TokenEstimation {
    /// `chars` (default)
    #[default]
    Chars,
    /// `words`
    Words,
    /// `tiktoken`
    Tiktoken,
}
/// Upstream selection settings for inference routes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceRouting {
    /// Selection strategy; `InferenceRoutingStrategy::default()` when omitted.
    #[serde(default)]
    pub strategy: InferenceRoutingStrategy,
    /// Optional header name reporting upstream queue depth.
    pub queue_depth_header: Option<String>,
}
/// Strategy for picking an inference upstream.
///
/// Serialized in snake_case. `LeastTokensQueued` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum InferenceRoutingStrategy {
    /// `least_tokens_queued` (default)
    #[default]
    LeastTokensQueued,
    /// `round_robin`
    RoundRobin,
    /// `least_latency`
    LeastLatency,
}
/// Routes inference requests to upstreams by model name.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ModelRoutingConfig {
    /// Model pattern → upstream mappings; empty when omitted.
    #[serde(default)]
    pub mappings: Vec<ModelUpstreamMapping>,
    /// Upstream used when no mapping applies, if any.
    pub default_upstream: Option<String>,
}
/// One model → upstream mapping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelUpstreamMapping {
    /// Model name pattern (matching syntax defined by the router — confirm).
    pub model_pattern: String,
    /// Target upstream.
    pub upstream: String,
    /// Optional provider override for this upstream.
    pub provider: Option<InferenceProvider>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct FallbackConfig {
#[serde(default)]
pub upstreams: Vec<FallbackUpstream>,
#[serde(default)]
pub triggers: FallbackTriggers,
#[serde(default = "default_max_fallback_attempts")]
pub max_attempts: u32,
}
/// One fallback upstream; only `upstream` is required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FallbackUpstream {
    /// Upstream name.
    pub upstream: String,
    /// Provider family; `InferenceProvider::default()` when omitted.
    #[serde(default)]
    pub provider: InferenceProvider,
    /// Model name mapping (presumably primary model → fallback model — confirm).
    #[serde(default)]
    pub model_mapping: HashMap<String, String>,
    /// Skip-when-unhealthy toggle; `false` when omitted.
    #[serde(default)]
    pub skip_if_unhealthy: bool,
}
/// Conditions under which fallback is attempted.
///
/// When omitted, health and connection triggers default to on, budget
/// exhaustion to off, the latency threshold to unset, and the status-code
/// list to empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FallbackTriggers {
    /// Health-failure trigger; `true` when omitted.
    #[serde(default = "default_true")]
    pub on_health_failure: bool,
    /// Budget-exhausted trigger; `false` when omitted.
    #[serde(default)]
    pub on_budget_exhausted: bool,
    /// Optional latency threshold in milliseconds.
    #[serde(default)]
    pub on_latency_threshold_ms: Option<u64>,
    /// Status codes that trigger fallback; empty when omitted.
    #[serde(default)]
    pub on_error_codes: Vec<u16>,
    /// Connection-error trigger; `true` when omitted.
    #[serde(default = "default_true")]
    pub on_connection_error: bool,
}
impl Default for FallbackTriggers {
    /// Same values serde produces when every trigger field is omitted.
    fn default() -> Self {
        Self {
            // Enabled out of the box.
            on_health_failure: true,
            on_connection_error: true,
            // Opt-in triggers.
            on_budget_exhausted: false,
            on_latency_threshold_ms: None,
            on_error_codes: vec![],
        }
    }
}
/// Serde default for `FallbackConfig::max_attempts`.
fn default_max_fallback_attempts() -> u32 {
    const MAX_FALLBACK_ATTEMPTS: u32 = 3;
    MAX_FALLBACK_ATTEMPTS
}
/// Guardrail agents for an inference route; each is optional.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct GuardrailsConfig {
    /// Prompt-injection detection, if configured.
    pub prompt_injection: Option<PromptInjectionConfig>,
    /// PII detection, if configured.
    pub pii_detection: Option<PiiDetectionConfig>,
}
/// Prompt-injection guardrail; `agent` is required.
///
/// Note that `enabled` defaults to `false`, so a present-but-bare section is
/// inert.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptInjectionConfig {
    /// On/off switch; `false` when omitted.
    #[serde(default)]
    pub enabled: bool,
    /// Agent that performs the check.
    pub agent: String,
    /// Action on detection; `GuardrailAction::default()` when omitted.
    #[serde(default)]
    pub action: GuardrailAction,
    /// Status code for blocked requests; `default_guardrail_block_status()` when omitted.
    #[serde(default = "default_guardrail_block_status")]
    pub block_status: u16,
    /// Optional message for blocked requests.
    pub block_message: Option<String>,
    /// Agent timeout in ms; `default_prompt_injection_timeout_ms()` when omitted.
    #[serde(default = "default_prompt_injection_timeout_ms")]
    pub timeout_ms: u64,
    /// Behavior if the agent fails; `GuardrailFailureMode::default()` when omitted.
    #[serde(default)]
    pub failure_mode: GuardrailFailureMode,
}
/// PII-detection guardrail; `agent` is required and `enabled` defaults to `false`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PiiDetectionConfig {
    /// On/off switch; `false` when omitted.
    #[serde(default)]
    pub enabled: bool,
    /// Agent that performs the check.
    pub agent: String,
    /// Action on detection; `PiiAction::default()` when omitted.
    #[serde(default)]
    pub action: PiiAction,
    /// PII categories to look for; empty when omitted (meaning decided by the agent — confirm).
    #[serde(default)]
    pub categories: Vec<String>,
    /// Agent timeout in ms; `default_pii_detection_timeout_ms()` when omitted.
    #[serde(default = "default_pii_detection_timeout_ms")]
    pub timeout_ms: u64,
    /// Behavior if the agent fails; `GuardrailFailureMode::default()` when omitted.
    #[serde(default)]
    pub failure_mode: GuardrailFailureMode,
}
/// Action taken when a guardrail fires, serialized as `block` / `log` / `warn`.
///
/// `Log` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum GuardrailAction {
    /// `block`
    Block,
    /// `log` (default)
    #[default]
    Log,
    /// `warn`
    Warn,
}
/// Action taken when PII is detected, serialized as `log` / `redact` / `block`.
///
/// `Log` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PiiAction {
    /// `log` (default)
    #[default]
    Log,
    /// `redact`
    Redact,
    /// `block`
    Block,
}
/// Behavior when a guardrail agent fails, serialized as `open` / `closed`.
///
/// Defaults to `Open`. This differs from the route-level `FailureMode`, whose
/// default is `Closed`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum GuardrailFailureMode {
    /// `open` (default)
    #[default]
    Open,
    /// `closed`
    Closed,
}
/// Serde default for `PromptInjectionConfig::block_status`: 400 Bad Request.
fn default_guardrail_block_status() -> u16 {
    400_u16
}
/// Serde default for `PromptInjectionConfig::timeout_ms`: half a second.
fn default_prompt_injection_timeout_ms() -> u64 {
    1_000 / 2
}
/// Serde default for `PiiDetectionConfig::timeout_ms`: one second.
fn default_pii_detection_timeout_ms() -> u64 {
    1_000
}